# -------------------------------------------------------------------------
# This file is part of the MindStudio project.
# Copyright (c) 2025 Huawei Technologies Co.,Ltd.
#
# MindStudio is licensed under Mulan PSL v2.
# You can use this software according to the terms and conditions of the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
#
#          http://license.coscl.org.cn/MulanPSL2
#
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.
# -------------------------------------------------------------------------

import numpy as np
import pandas as pd

from ms_service_profiler.constant import US_PER_MS
from ms_service_profiler.utils.log import logger
from ms_service_profiler.utils.timer import timer, Timer
from ms_service_profiler.processor.processor_base import ProcessorBase


class ProcessorReq(ProcessorBase):
    """Turn raw profiler events into batch- and request-level data frames.

    ``parse`` is the entry point. It calls ``parse_batch`` to build batch events and
    attributes, then ``parse_req`` to build request events, attributes and queue
    records. Finally it derives per-request TTFT (``calc_ttft``) and queue wait time
    (``calc_que_wait``).
    """

    # Class-wide flag: the "invalid token_id_list row" warning is logged at most once.
    _batch_token_iter_warning_issued = False

    # Event names selected by parse_batch; further names may be appended after "forward".
    _BATCH_EVENT_NAMES = ["BatchSchedule", "modelExec", "batchFrameworkProcessing",
                          "Execute", "preprocess", "forward", "modelRunnerExec"]
    # Batch events whose res_list / rid_list populate batch_attr_df.
    _SCHEDULE_EVENT_NAMES = ["BatchSchedule", "batchFrameworkProcessing"]
    # Batch events joined onto scheduled requests by batch_id in _process_batch_events.
    _EXEC_EVENT_NAMES = ["modelExec", "Execute", "modelRunnerExec"]
    # Request-level events copied directly into req_event_df.
    _HTTP_EVENT_NAMES = ["httpReq", "httpRes", "decode", "detokenize", "DecodeEnd", "sendResponse", "FINISHED"]

    @property
    def name(self):
        """Identifier of this processor."""
        return "ProcessorReq"

    @staticmethod
    def ensure_list_size(lst, ensure_size, fill_value=0):
        """Return ``lst`` right-padded with ``fill_value`` to at least ``ensure_size`` items.

        ``lst`` itself is returned when it is already long enough; otherwise a new
        list is returned and ``lst`` is left unmodified.
        """
        if len(lst) >= ensure_size:
            return lst
        return lst + [fill_value] * (ensure_size - len(lst))

    @staticmethod
    def parse_node_role(data_df: pd.DataFrame):
        """Map each ``pid`` to its node role in a PD-disaggregated deployment.

        ``prefillRes`` rows mark role 1 (prefill) and ``decodeRes`` rows mark role 2
        (decode). If a pid emits both, the last row wins.

        Returns:
            dict: ``{pid: 1 | 2}``; pids with neither event are absent.
        """
        role_df = data_df[data_df["name"].isin(["prefillRes", "decodeRes"])]
        role_dict = dict(zip(role_df['pid'], role_df['name'].map(dict(prefillRes=1, decodeRes=2))))
        return role_dict

    @classmethod
    def batch_token_iter_to_batch_type(cls, token_iter_list):
        """Classify a batch from its ``token_id_list`` value.

        Returns:
            int: ``2`` if every entry is truthy (non-zero), ``0`` if zero and non-zero
            entries are mixed, ``1`` otherwise. "Otherwise" covers all-zero, empty,
            NaN-containing, or not a list/tuple; a non-list is warned about once.
            ``parse_batch`` compares these codes with the roles from
            ``parse_node_role`` (1 = prefill, 2 = decode).
        """
        # None, scalar NaN and any other unexpected type all fail this check.
        if not isinstance(token_iter_list, (list, tuple)):
            if not cls._batch_token_iter_warning_issued:
                logger.warning(f"Warning: Skipping invalid row type {type(token_iter_list)}: {token_iter_list}")
                cls._batch_token_iter_warning_issued = True
            return 1

        if not token_iter_list:
            return 1

        # pd.isna() returns a plain False for tuples (no .any()), so normalise to a list.
        if pd.isna(list(token_iter_list)).any():
            return 1

        if all(token_iter_list):  # every entry non-zero
            return 2
        if any(token_iter_list):  # mix of zero and non-zero entries
            return 0
        return 1  # all entries zero

    @staticmethod
    def _non_empty_mask(series: pd.Series) -> pd.Series:
        """Return a boolean mask that is True where the value has ``len() > 0``.

        Values without a length (e.g. NaN or None) count as empty instead of raising.
        """
        def has_items(value):
            try:
                return len(value) > 0
            except TypeError:
                return False

        # astype(bool) keeps the dtype boolean even for an empty series, so the mask
        # is always used as a row filter rather than a column selector.
        return series.map(has_items).astype(bool)

    @timer(logger.debug)
    def parse_batch(self, data_df: pd.DataFrame):
        """Extract batch-level events and attributes from the raw event frame.

        Args:
            data_df: Raw events. Must contain ``name``, ``res_list``, ``token_id_list``
                and ``rid_list``. ``pid``, ``start_time`` and ``end_time`` are also
                read; ``blocks`` is copied when present.

        Returns:
            tuple[pd.DataFrame, pd.DataFrame]: ``(batch_event_df, batch_attr_df)``.
            ``batch_event_df`` has one row per kept batch event, with batch_id set to
            ``str(res_list)``. ``batch_attr_df`` has one row per schedule event whose
            res_list, rid_list and token_id_list are all non-empty. Both frames keep
            their column schema and are empty when the input is empty or lacks a
            required column.
        """
        batch_event_df = pd.DataFrame(columns=["batch_id", "event", "start_time", "end_time", "pid", "blocks"])
        batch_attr_df = pd.DataFrame(columns=["batch_id", "req_list", "req_id_list", "batch_size", "batch_type"])

        if data_df is None or data_df.empty or not self._validate_data_columns(data_df):
            return batch_event_df, batch_attr_df

        role_dict = self.parse_node_role(data_df)
        batch_data_df = data_df[data_df["name"].isin(self._BATCH_EVENT_NAMES)]

        # A batch is identified by the string form of its res_list; duplicate ids are not disambiguated.
        batch_id_df = batch_data_df["res_list"].map(str)

        # PD-disaggregation: drop rows whose token-based batch type contradicts the node role.
        role_batch_type = batch_data_df["pid"].map(role_dict)
        iter_batch_type = batch_data_df["token_id_list"].map(self.batch_token_iter_to_batch_type)

        no_role_mask = role_batch_type.isna()  # node without a PD role: always kept
        type_matches_mask = iter_batch_type == role_batch_type  # classification agrees with the role
        # On a decode node, rows classified as prefill are kept only for the last one per
        # (event name, batch id). Grouping the full frame with the flag as an extra key gives
        # a mask over exactly batch_data_df's index, so the OR below needs no alignment.
        decode_as_prefill = (role_batch_type == 2) & (iter_batch_type == 1)
        is_last_in_group = batch_data_df.groupby(
            [batch_data_df["name"], batch_id_df, decode_as_prefill]
        ).cumcount(ascending=False) == 0
        last_decode_as_prefill_mask = decode_as_prefill & is_last_in_group

        keep_mask = no_role_mask | type_matches_mask | last_decode_as_prefill_mask
        batch_data_df = batch_data_df[keep_mask]
        kept_batch_ids = batch_id_df[keep_mask]

        batch_event_df["batch_id"] = kept_batch_ids
        batch_event_df["event"] = batch_data_df["name"]
        batch_event_df["pid"] = batch_data_df["pid"]
        # blocks is optional in the raw data; the column stays NaN when absent.
        if "blocks" in batch_data_df.columns:
            batch_event_df["blocks"] = batch_data_df["blocks"]
        batch_event_df["start_time"] = batch_data_df["start_time"]
        batch_event_df["end_time"] = batch_data_df["end_time"]

        schedule_mask = batch_data_df["name"].isin(self._SCHEDULE_EVENT_NAMES)
        schedule_data_df = batch_data_df[schedule_mask]
        schedule_batch_ids = kept_batch_ids[schedule_mask]

        # Keep only schedule rows whose res_list, rid_list and token_id_list are all non-empty.
        # batch_id is filtered with the same mask, so dropped rows do not linger in
        # batch_attr_df with NaN request lists.
        non_empty_mask = (
            self._non_empty_mask(schedule_data_df["res_list"])
            & self._non_empty_mask(schedule_data_df["rid_list"])
            & self._non_empty_mask(schedule_data_df["token_id_list"])
        )
        schedule_data_df = schedule_data_df[non_empty_mask]

        batch_attr_df["batch_id"] = schedule_batch_ids[non_empty_mask]
        batch_attr_df["req_list"] = schedule_data_df["res_list"]
        batch_attr_df["req_id_list"] = schedule_data_df["rid_list"]
        batch_attr_df["batch_size"] = schedule_data_df["rid_list"].map(len)
        # Prefer the node role when known, otherwise the token-based classification.
        # Reindexing prevents an empty batch_attr_df from being implicitly expanded to the
        # index of every batch event.
        batch_type = pd.concat([role_batch_type, iter_batch_type]).groupby(level=0).first()
        batch_attr_df["batch_type"] = batch_type.reindex(batch_attr_df.index)

        return batch_event_df, batch_attr_df

    @timer(logger.debug)
    def parse_req(self, data_df: pd.DataFrame, batch_event_df: pd.DataFrame, batch_attr_df: pd.DataFrame):
        """Build request-level events, attributes and queue records.

        Args:
            data_df: Raw events.
            batch_event_df: First element returned by ``parse_batch``.
            batch_attr_df: Second element returned by ``parse_batch``.

        Returns:
            tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
            ``(req_event_df, req_attr_df, req_queue_df)``. All three are empty with
            their column schema when the input is empty or lacks a required column.
        """
        req_event_df = pd.DataFrame(columns=["rid", "event", "iter", "start_time", "end_time", "batch_id"])
        req_attr_df = pd.DataFrame(columns=["rid", "recv_token", "reply_token", "ttft"])
        req_queue_df = pd.DataFrame(columns=["rid", "start_time", "end_time", "event", "status"])

        if data_df is None or data_df.empty:
            return req_event_df, req_attr_df, req_queue_df

        if not self._validate_data_columns(data_df):
            return req_event_df, req_attr_df, req_queue_df

        req_event_df = self._process_http_events(data_df, req_event_df)
        req_attr_df = self._process_request_attributes(data_df)
        req_queue_df = self._process_request_queue(data_df)
        # Per-request rows for every executed batch the request was scheduled in.
        req_event_df = self._process_batch_events(req_event_df, batch_event_df, batch_attr_df)
        # Per-request rows for the schedule event of the first iteration (iter == 0).
        req_event_df = self._process_batch_schedule_events(req_event_df, batch_event_df, batch_attr_df)

        return req_event_df, req_attr_df, req_queue_df

    def _validate_data_columns(self, data_df: pd.DataFrame) -> bool:
        """Return True if ``data_df`` has every column the batch/request parsing needs."""
        required_columns = ["name", "res_list", "token_id_list", "rid_list"]
        return all(col in data_df for col in required_columns)

    def _process_http_events(self, data_df: pd.DataFrame, req_event_df: pd.DataFrame) -> pd.DataFrame:
        """Fill ``req_event_df`` with the request-level events listed in ``_HTTP_EVENT_NAMES``.

        Adds an ``end_flag`` column copied from ``endFlag`` when that column exists,
        or set to None otherwise.
        """
        http_event_df = data_df[data_df["name"].isin(self._HTTP_EVENT_NAMES)]
        req_event_df["rid"] = http_event_df["rid"]
        req_event_df["event"] = http_event_df["name"]
        req_event_df["start_time"] = http_event_df["start_time"]
        req_event_df["end_time"] = http_event_df["end_time"]
        req_event_df["end_flag"] = http_event_df.get("endFlag", None)
        return req_event_df

    def _process_request_attributes(self, data_df: pd.DataFrame) -> pd.DataFrame:
        """Collect per-request token counts.

        Reads the optional ``recvTokenSize=`` / ``replyTokenSize=`` columns and skips
        NaN values; for duplicate rids the last value wins. The returned frame is
        indexed by rid and also carries it in a ``rid`` column.
        """
        rid_recv_token_map = {}
        rid_reply_token_map = {}

        if "recvTokenSize=" in data_df:
            recv_token_df = data_df[data_df["recvTokenSize="].notna()]
            rid_recv_token_map = recv_token_df.set_index('rid')['recvTokenSize='].to_dict()

        if "replyTokenSize=" in data_df:
            reply_token_df = data_df[data_df["replyTokenSize="].notna()]
            rid_reply_token_map = reply_token_df.set_index('rid')['replyTokenSize='].to_dict()

        req_attr_df = pd.DataFrame({'recv_token': rid_recv_token_map, 'reply_token': rid_reply_token_map})
        req_attr_df['rid'] = req_attr_df.index
        return req_attr_df

    def _process_request_queue(self, data_df: pd.DataFrame) -> pd.DataFrame:
        """Build one row per (request id, Enqueue/Dequeue event).

        A queue event's ``rid`` may hold several comma-separated ids. It is split and
        exploded so each id gets its own row, and blank ids are discarded. When a
        ``status`` column exists, only rows with status ``"waiting"`` are kept and the
        column is carried through. Rows with a missing rid are dropped rather than
        stringified into a fake id such as ``"nan"``.
        """
        has_status = "status" in data_df
        mask = data_df["name"].isin(["Dequeue", "Enqueue"]) & data_df["rid"].notna()
        if has_status:
            mask &= data_df["status"] == "waiting"

        tmp = data_df.loc[mask, :].copy(deep=False)
        tmp["rid"] = tmp["rid"].astype(str).str.strip().str.split(r"\s*,\s*")

        exploded = tmp.explode("rid")
        exploded = exploded[exploded["rid"].str.strip() != ""]

        selected_columns = ["rid", "start_time", "end_time", "event"]
        if has_status:
            selected_columns.append("status")

        return exploded.rename(columns={"name": "event"})[selected_columns]

    @staticmethod
    def _explode_req_list(df: pd.DataFrame):
        """Explode ``req_list`` into one row per entry and copy out the entry fields.

        Dict entries contribute ``rid`` and ``iter`` columns, plus
        ``num_scheduled_tokens=`` when any entry has that key. Non-dict entries
        (e.g. NaN from an empty list) yield None in those columns.

        Returns:
            tuple[pd.DataFrame, bool]: the exploded frame and whether the
            ``num_scheduled_tokens=`` column was added.
        """
        exploded = df.explode("req_list")

        def extract(key):
            return exploded["req_list"].map(lambda x: x.get(key) if isinstance(x, dict) else None)

        exploded["rid"] = extract("rid")
        exploded["iter"] = extract("iter")

        has_num_scheduled_tokens = exploded["req_list"].apply(
            lambda x: isinstance(x, dict) and "num_scheduled_tokens=" in x
        ).any()
        if has_num_scheduled_tokens:
            exploded["num_scheduled_tokens="] = extract("num_scheduled_tokens=")

        return exploded, has_num_scheduled_tokens

    def _process_batch_events(self, req_event_df: pd.DataFrame, batch_event_df: pd.DataFrame,
                              batch_attr_df: pd.DataFrame) -> pd.DataFrame:
        """Append one row per (scheduled request, execution event of its batch).

        Requests from ``batch_attr_df.req_list`` are left-joined by batch_id with the
        events in ``_EXEC_EVENT_NAMES``. A request whose batch has no such event
        still gets a row, with NaN event/time columns.
        """
        model_exec_df = batch_event_df[batch_event_df["event"].isin(self._EXEC_EVENT_NAMES)]

        exploded, has_num_scheduled_tokens = self._explode_req_list(batch_attr_df)
        merged = exploded.join(model_exec_df.set_index('batch_id'), on='batch_id')

        selected_columns = ["rid", "event", "iter", "start_time", "end_time", "batch_id", "batch_size"]
        if has_num_scheduled_tokens:
            selected_columns.append("num_scheduled_tokens=")

        return pd.concat([req_event_df, merged[selected_columns]], ignore_index=True)

    def _process_batch_schedule_events(self, req_event_df: pd.DataFrame, batch_event_df: pd.DataFrame,
                                       batch_attr_df: pd.DataFrame) -> pd.DataFrame:
        """Append schedule-event rows for requests in their first iteration (``iter == 0``).

        Schedule events are joined with ``batch_attr_df`` on batch_id and their
        req_list is exploded. Rows without a rid are dropped. When
        ``num_scheduled_tokens=`` is present, only rows whose value is a number
        greater than 0 are kept; non-numeric values count as "not > 0" instead of
        raising.
        """
        schedule_events = batch_event_df[batch_event_df["event"].isin(self._SCHEDULE_EVENT_NAMES)]
        if schedule_events.empty:
            return req_event_df

        # Joining only the schedule events yields the same rows as joining all batch events
        # and filtering afterwards, without the wasted work.
        schedule_data_joined = schedule_events.join(
            batch_attr_df.set_index('batch_id'), on='batch_id', rsuffix='_attr'
        )
        exploded, has_num_scheduled_tokens = self._explode_req_list(schedule_data_joined)

        prefill_schedule = exploded[exploded['iter'] == 0]
        if prefill_schedule.empty:
            return req_event_df

        selected_columns = ["rid", "event", "iter", "start_time", "end_time", "batch_id", "batch_size"]
        if has_num_scheduled_tokens:
            selected_columns.append("num_scheduled_tokens=")
        prefill_start_df = prefill_schedule[selected_columns].dropna(subset=['rid'])

        if has_num_scheduled_tokens:
            scheduled_tokens = pd.to_numeric(prefill_start_df['num_scheduled_tokens='], errors="coerce")
            prefill_start_df = prefill_start_df[scheduled_tokens > 0]

        return pd.concat([req_event_df, prefill_start_df], ignore_index=True)

    @timer(logger.debug)
    def calc_ttft(self, req_event_df: pd.DataFrame):
        """Compute time-to-first-token per request.

        Candidate events are taken from ``req_event_df``:

        - every ``httpReq`` event and every ``iter == 0`` event;
        - the first ``detokenize``/``decode`` row per rid;
        - the first ``sendResponse`` row per rid.

        ``groupby().first()`` takes the first non-null value per column. For each rid,
        ttft is the latest ``end_time`` minus the earliest ``start_time`` among these
        events. A rid is reported only if it has more than one candidate event and
        its first non-null event (in concatenation order) is ``httpReq``.

        Returns:
            pd.DataFrame: columns ``rid, start_time, end_time, ttft``, in the raw
            timestamp units. Empty with the same columns when nothing can be computed.
        """
        req_ttft_df = pd.DataFrame(columns=["rid", "start_time", "end_time", "ttft"])

        if req_event_df is None or req_event_df.empty:
            return req_ttft_df

        arrival_and_first_iter = req_event_df[(req_event_df["event"] == "httpReq") | (req_event_df["iter"] == 0)]
        first_decode = (
            req_event_df[req_event_df["event"].isin(["detokenize", "decode"])]
            .groupby("rid")
            .first()
            .reset_index()
        )
        # The first sendResponse (not the last) marks the first token returned to the client.
        first_send_response = (
            req_event_df[req_event_df["event"] == "sendResponse"]
            .groupby("rid")
            .first()
            .reset_index()
        )

        parts = [df for df in (arrival_and_first_iter, first_decode, first_send_response) if not df.empty]
        if not parts:
            return req_ttft_df
        calc_df = pd.concat(parts, ignore_index=True)

        group_by_df = calc_df.groupby("rid").agg(
            start_time=("start_time", "min"),
            end_time=("end_time", "max"),
            event_first=("event", "first"),
            event_count=("event", "count"),
        ).reset_index()

        req_ttft_df = group_by_df[
            (group_by_df['event_count'] > 1) &
            (group_by_df['event_first'] == 'httpReq')
        ].copy()
        req_ttft_df["ttft"] = req_ttft_df['end_time'] - req_ttft_df['start_time']
        return req_ttft_df.drop(columns=['event_first', 'event_count'])

    @timer(logger.debug)
    def calc_que_wait(self, req_queue_df: pd.DataFrame):
        """Compute queue wait time per request.

        Groups by rid and subtracts the earliest Enqueue ``start_time`` from the
        latest Dequeue ``end_time``. Both are instantaneous points, so their start
        and end times are equal. A rid seen on only one side gets NaN.

        Returns:
            pd.DataFrame: columns ``rid, que_wait_time`` in the raw timestamp units.
        """
        req_que_wait_df = pd.DataFrame(columns=["rid", "que_wait_time"])

        if req_queue_df is None or req_queue_df.empty:
            return req_que_wait_df

        enq = req_queue_df[req_queue_df["event"] == "Enqueue"]
        deq = req_queue_df[req_queue_df["event"] == "Dequeue"]

        enq_agg = enq.groupby("rid")["start_time"].min().rename("enq_start")
        deq_agg = deq.groupby("rid")["end_time"].max().rename("deq_end")

        # Wait time keeps the input timestamp units (no conversion here).
        req_que_wait_df = (
            pd.concat([enq_agg, deq_agg], axis=1)
            .assign(que_wait_time=lambda x: (x["deq_end"] - x["enq_start"]))
            .reset_index()
            .loc[:, ["rid", "que_wait_time"]]
        )

        return req_que_wait_df

    def parse(self, data_df: pd.DataFrame):
        """Run the full batch/request pipeline on the raw event frame.

        Returns:
            dict: frames keyed ``req_event_df``, ``req_attr_df``, ``batch_event_df``,
            ``batch_attr_df``, ``req_ttft_df`` and ``req_que_wait_df``.
        """
        batch_event_df, batch_attr_df = self.parse_batch(data_df)
        req_event_df, req_attr_df, req_queue_df = (
            self.parse_req(data_df, batch_event_df, batch_attr_df)
        )

        req_ttft_df = self.calc_ttft(req_event_df)
        req_queue_wait_time_df = self.calc_que_wait(req_queue_df)
        # req_attr_df is indexed by rid while req_ttft_df has a positional index, so match
        # on rid values; plain column assignment would align unrelated index labels.
        ttft_by_rid = req_ttft_df.set_index("rid")["ttft"]
        req_attr_df["ttft"] = req_attr_df["rid"].map(ttft_by_rid)

        # ttft and que_wait_time are raw values in microseconds; the exporter converts units.

        return {
            "req_event_df": req_event_df,
            "req_attr_df": req_attr_df,
            "batch_event_df": batch_event_df,
            "batch_attr_df": batch_attr_df,
            "req_ttft_df": req_ttft_df,
            "req_que_wait_df": req_queue_wait_time_df
        }