from enum import IntEnum, auto
import pandas as pd
from ms_service_profiler.plugins.base import PluginBase
from ms_service_profiler.utils.timer import timer
from ms_service_profiler.utils.log import logger
from ms_service_profiler.utils.error import KeyExcept
class HostRole(IntEnum):
    """Role attributed to a traced (hostname, pid) process: prefill, decode or other."""

    # Values spelled out explicitly; they equal what ``auto()`` assigns (1, 2, 3).
    PREFILL = 1
    DECODE = 2
    OTHER = 3
class BatchType(IntEnum):
    """Coarse batch classification used when relabelling batches.

    Members are compared numerically: taking the ``min()`` of a group prefers
    DECODE over OTHER, so the member order below is significant.
    """

    # Values spelled out explicitly; they equal what ``auto()`` assigns (1..4).
    PREFILL = 1
    DECODE = 2
    MIX = 3
    OTHER = 4
class PluginTrace(PluginBase):
    """Annotate ``data['tx_data_df']`` with batch-type and batch-size columns."""

    name = "plugin_trace"
    depends = ["plugin_common", "plugin_req_status"]

    @staticmethod
    def map_batch_type(batch_type, batch_type_mapping):
        """Translate one numeric ``batchType`` code into its label.

        Returns None for a missing value and "Other" for codes not present in
        ``batch_type_mapping``.
        """
        if pd.isna(batch_type):
            return None
        return batch_type_mapping.get(batch_type, "Other")

    @staticmethod
    def _infer_roles(tx_data_df):
        """Return a Series giving each row the HostRole of its (hostname, pid) process.

        A process's role is taken from its "prefillRes" / "decodeRes" rows (the last
        one wins if both occur). Rows whose ``batch_type`` is missing, or whose
        process has no such row, get a missing value.

        Built with a comprehension rather than ``DataFrame.apply(axis=1)``: on an
        empty selection ``apply`` returns a DataFrame instead of a Series, and
        assigning that to a single column raised ValueError.
        """
        role_rows = tx_data_df[tx_data_df["name"].isin(["prefillRes", "decodeRes"])]
        role_map = dict(zip(
            zip(role_rows["hostname"], role_rows["pid"]),
            role_rows["name"].map({"prefillRes": HostRole.PREFILL, "decodeRes": HostRole.DECODE}),
        ))
        roles = pd.Series(
            [role_map.get(key) for key in zip(tx_data_df["hostname"], tx_data_df["pid"])],
            index=tx_data_df.index,
            dtype=object,
        )
        # Only rows that already carry a batch type get a role (as before).
        return roles.where(tx_data_df["batch_type"].notna())

    @staticmethod
    def _resolve_decode_conflicts(conflict_rows):
        """Relabel decode-role rows whose batch type is not "Decode".

        Each row's ``rid_list`` is exploded into one entry per request id. For every
        (hostname, pid, name, rid) the last entry in frame order is marked DECODE and
        all others OTHER. A row becomes "Decode" if any of its entries was marked
        DECODE, otherwise "Other".

        Returns a Series of labels indexed by the labels of ``conflict_rows``.
        """
        if conflict_rows.empty:
            return pd.Series(dtype=object)
        key_cols = ["hostname", "pid", "name", "rid_list"]
        # Keep the originating row label in an explicit column and give the exploded
        # frame a fresh unique index. The old approach relied on reset_index() naming
        # its column "index", which only happens when the index is unnamed.
        exploded = (
            conflict_rows[key_cols]
            .assign(_row=conflict_rows.index)
            .explode("rid_list", ignore_index=True)
        )
        exploded["batchtype"] = BatchType.OTHER
        last_rows = exploded.groupby(key_cols).tail(1).index
        exploded.loc[last_rows, "batchtype"] = BatchType.DECODE
        str_batch_type_map = {
            BatchType.PREFILL: "Prefill",
            BatchType.DECODE: "Decode",
            BatchType.MIX: "Prefill, Decode",
            BatchType.OTHER: "Other",
        }
        # DECODE < OTHER, so min() yields DECODE when any request id of the row was last.
        return exploded.groupby("_row")["batchtype"].min().map(str_batch_type_map)

    @staticmethod
    def fix_batch_type(tx_data_df):
        """Normalise the ``batch_type`` column in place and return the frame.

        If a ``batchType`` column exists, its codes are translated via the mapping
        below and the frame is returned straight away. Otherwise each process's role
        is inferred from its prefillRes/decodeRes rows and stored in a ``role``
        column. Labels that contradict that role are then corrected:

        * prefill-role rows not labelled "Prefill" become "Other";
        * decode-role rows not labelled "Decode" become "Decode" or "Other"
          (see ``_resolve_decode_conflicts``).

        Missing columns named in ``KeyExcept`` are handled by it (``ignore=True``).
        Presumably the frame is then returned with the changes made so far;
        confirm against KeyExcept.
        """
        with KeyExcept('name', 'hostname', 'pid', "batch_type", "rid_list", ignore=True, msg=""):
            batch_type_mapping = {
                0: "Prefill",
                1: "Decode",
                2: "Extend",
                3: "Mixed",
                5: "Dummy",
            }
            if 'batchType' in tx_data_df.columns:
                tx_data_df['batch_type'] = tx_data_df['batchType'].apply(
                    lambda x: PluginTrace.map_batch_type(x, batch_type_mapping)
                )
                return tx_data_df
            tx_data_df['role'] = PluginTrace._infer_roles(tx_data_df)
            # Boolean masks select exactly the offending rows, even if index labels repeat.
            prefill_conflict = (tx_data_df['role'] == HostRole.PREFILL) & (tx_data_df['batch_type'] != "Prefill")
            tx_data_df.loc[prefill_conflict, "batch_type"] = "Other"
            decode_conflict = (tx_data_df['role'] == HostRole.DECODE) & (tx_data_df['batch_type'] != "Decode")
            decode_batch_type = PluginTrace._resolve_decode_conflicts(tx_data_df[decode_conflict])
            tx_data_df.loc[decode_batch_type.index, "batch_type"] = decode_batch_type
        # Outside the ``with`` so the frame (never None) is returned even when an error is suppressed.
        return tx_data_df

    @classmethod
    @timer(logger.debug)
    def parse(cls, data):
        """Add batch columns to ``data['tx_data_df']`` and return ``data``.

        Adds ``batch_type``, ``batch_size``, ``prefill_batch_size`` and
        ``decode_batch_size``. The two split sizes are filled only for rows
        labelled 'Prefill, Decode'; other rows get None. ``data`` is returned
        unchanged when it has no ``tx_data_df``.
        """
        with KeyExcept('token_id_list', 'batch_type', 'rid_list', ignore=True,
                       msg="ignoring current process by default."):
            tx_data_df = data.get('tx_data_df')
            if tx_data_df is None:
                return data
            if 'batch_type' not in tx_data_df.columns:
                tx_data_df['batch_type'] = None
            tx_data_df['batch_type'] = [
                extract_batch_type(token_list, batch_type)
                for token_list, batch_type in zip(tx_data_df['token_id_list'], tx_data_df['batch_type'])
            ]
            tx_data_df = cls.fix_batch_type(tx_data_df)
            tx_data_df['batch_size'] = [extract_batch_size(x) for x in tx_data_df['rid_list']]
            # Build both columns from one list of pairs: unpacking ``zip(*pairs)``
            # into two targets raised ValueError on an empty frame.
            mixed_sizes = [
                extract_batch_size_when_pd_mixed(token_list) if batch_type == 'Prefill, Decode' else (None, None)
                for token_list, batch_type in zip(tx_data_df['token_id_list'], tx_data_df['batch_type'])
            ]
            tx_data_df['prefill_batch_size'] = [prefill for prefill, _ in mixed_sizes]
            tx_data_df['decode_batch_size'] = [decode for _, decode in mixed_sizes]
            data['tx_data_df'] = tx_data_df
        # Outside the ``with`` so callers always get ``data`` back, even when an error is suppressed.
        return data
def extract_batch_type(token_list, batch_type):
    """Return the batch label for one trace row.

    An explicit ``batch_type`` is returned unchanged. Otherwise the label is
    inferred from ``token_list``: a ``0`` entry counts as prefill and any
    positive entry counts as decode (``None`` entries are ignored).

    Returns 'Prefill, Decode', 'Prefill', 'Decode', or None when nothing can be
    inferred.

    ``None`` and pandas missing markers (NaN / pd.NA) are both treated as "no
    value". Cells absent from a DataFrame column hold NaN rather than None, and
    previously a NaN label skipped inference and a NaN token list crashed on
    ``0 in nan``.
    """
    if not (batch_type is None or (pd.api.types.is_scalar(batch_type) and pd.isna(batch_type))):
        return batch_type
    if token_list is None or (pd.api.types.is_scalar(token_list) and pd.isna(token_list)):
        return None
    has_prefill = 0 in token_list
    has_decode = any(x > 0 for x in token_list if x is not None)
    if has_prefill and has_decode:
        return 'Prefill, Decode'
    if has_prefill:
        return 'Prefill'
    if has_decode:
        return 'Decode'
    return None
def extract_batch_size(rid_list):
    """Return the number of request ids in ``rid_list`` as a string.

    Returns None when ``rid_list`` is missing, i.e. None or a pandas missing
    marker such as NaN. Previously a NaN cell crashed in ``len()``.
    """
    if rid_list is None or (pd.api.types.is_scalar(rid_list) and pd.isna(rid_list)):
        return None
    # len() is already an int; the count is returned as text, as before.
    return str(len(rid_list))
def extract_batch_size_when_pd_mixed(token_list):
    """Split a mixed batch into ``(prefill_count, decode_count)``.

    Entries equal to 0 are counted as prefill; every remaining entry is counted
    as decode. ``token_list`` must support ``count`` and ``len`` (e.g. a list).
    """
    zeros = token_list.count(0)
    return zeros, len(token_list) - zeros