from enum import Enum
from ms_service_profiler.plugins.base import PluginBase
from ms_service_profiler.plugins.plugin_metric import is_metric
from ms_service_profiler.utils.log import logger
from ms_service_profiler.utils.timer import timer
class ReqStatus(Enum):
    """Numeric request-status codes and their symbolic names.

    The integer value of each member is the status index that appears as the
    digit prefix of a request-status metric name; ``status_index_to_status_name``
    uses ``ReqStatus(index).name`` to translate such a prefix into the member
    name. Values are contiguous from 0 to 15.
    """

    WAITING = 0
    PENDING = 1
    RUNNING = 2
    RUNNING2 = 3
    SWAPPED = 4
    RECOMPUTE = 5
    SUSPENDED = 6
    END = 7
    STOP = 8
    PREFILL_HOLD = 9
    END_PRE = 10
    STOP_PRE = 11
    WAITING_PULL = 12
    PULLING = 13
    PULLED = 14
    D2D_PULLING = 15
class PluginReqStatus(PluginBase):
    """Plugin that rewrites request-status metrics in ``tx_data_df`` as names.

    Numeric status metric columns (and the keys of each ``message`` dict) are
    renamed to their ``ReqStatus`` names, and rows whose ``name`` is
    ``'ReqState'`` are relabelled with the status that is currently active.
    """

    name = "plugin_req_status"
    # NOTE(review): presumably the plugins that must run before this one; the
    # exact semantics are defined by PluginBase, which is not visible here.
    depends = ["plugin_common"]
    # Class-wide latch so the "no request-status fields" notice is logged once.
    _warned_no_request_status = False

    @classmethod
    @timer(logger.debug)
    def parse(cls, data):
        """Normalize request-status information stored in ``data['tx_data_df']``.

        Args:
            data: Mapping holding the ``'tx_data_df'`` DataFrame. The frame is
                expected to provide ``'message'`` and ``'name'`` columns.

        Returns:
            The same ``data`` object. It is returned untouched when the frame
            is missing, empty, or already has a ``'status'`` column. When no
            request-status columns are found it is returned without further
            normalization (the ``message`` keys have already been translated
            in place at that point). Otherwise ``data['tx_data_df']`` is
            replaced by the normalized frame.

        Raises:
            ValueError: Propagated from ``parse_message_state_name`` when a
                ``message`` cell is not a dict.
        """
        tx_data_df = data.get('tx_data_df')
        if tx_data_df is None or tx_data_df.empty:
            return data

        # NOTE(review): presumably a 'status' column means the frame has
        # already been normalized upstream - confirm with the producer.
        if 'status' in tx_data_df.columns:
            return data

        # Translate numeric status keys inside every message dict.
        tx_data_df['message'] = tx_data_df['message'].apply(parse_message_state_name)

        # Translate numeric status metric columns to their symbolic names.
        rename_mapping = {
            col: status_index_to_status_name(col)
            for col in tx_data_df.columns
            if is_req_status_metric(col)
        }
        tx_data_df = tx_data_df.rename(columns=rename_mapping)

        req_status = list(rename_mapping.values())
        if req_status:
            tx_data_df = rename_req_status(tx_data_df, req_status)
        else:
            # No numeric status columns: fall back to the named vLLM columns.
            vllm_req_status = ['WAITING+', 'RUNNING+', 'FINISHED+']
            valid_cols = [col for col in vllm_req_status if col in tx_data_df.columns]
            if not valid_cols:
                columns = tx_data_df.columns
                if 'QueueSize=' in columns and 'scope#QueueName' in columns:
                    logger.debug(
                        "Skip request status normalization because queue-based status schema is used"
                    )
                elif not cls._warned_no_request_status:
                    logger.info(
                        "Skip request status normalization because no request-status fields were found in current process data"
                    )
                    cls._warned_no_request_status = True
                # Always stop here: the latch above only silences the log
                # message. Continuing would call rename_req_status() with an
                # empty column list, which cannot pick an active status.
                return data
            tx_data_df = rename_req_status(tx_data_df, valid_cols)

        if 'domain' in tx_data_df.columns:
            # Back-fill each of name/domain from the other where one is missing.
            tx_data_df['name'] = tx_data_df['name'].fillna(tx_data_df['domain'])
            tx_data_df['domain'] = tx_data_df['domain'].fillna(tx_data_df['name'])

        data['tx_data_df'] = tx_data_df
        return data
def is_req_status_metric(metric):
    """Tell whether *metric* is a metric name with a numeric status prefix.

    A name qualifies when ``is_metric`` accepts it and everything except its
    final character consists solely of digits.

    Args:
        metric: Candidate column or message key.

    Returns:
        The falsy result of ``is_metric(metric)`` if that check fails,
        otherwise a bool saying whether the prefix is all digits.
    """
    return is_metric(metric) and metric[:-1].isdigit()
def status_index_to_status_name(metric):
    """Translate a numeric request-status metric name into its symbolic form.

    The digits before the final character are looked up in ``ReqStatus`` and
    replaced by the member name; the final character is kept as the suffix.
    For instance, a name made of ``2`` plus a suffix character becomes
    ``RUNNING`` plus that same suffix.

    Args:
        metric: Metric name or message key to translate.

    Returns:
        The translated name, or *metric* unchanged when it is not a
        request-status metric or its index is not a ``ReqStatus`` value.

    Raises:
        ValueError: If the prefix passes ``str.isdigit`` but cannot be parsed
            by ``int`` (possible for some non-decimal Unicode digits).
    """
    if not is_req_status_metric(metric):
        return metric

    prefix, suffix = metric[:-1], metric[-1]
    try:
        index = int(prefix)
    except ValueError as ex:
        raise ValueError(f"Invalid status index: {prefix}") from ex

    # Look the value up directly instead of building a list of every enum
    # value on each call; this helper runs once per column and per message key.
    try:
        status = ReqStatus(index)
    except ValueError:
        return metric
    return f"{status.name}{suffix}"
def parse_message_state_name(message):
    """Return a copy of *message* with numeric status keys made symbolic.

    Every key is passed through ``status_index_to_status_name``; values are
    carried over unchanged. If two keys translate to the same name, the one
    encountered later wins.

    Args:
        message: Dict of metric name to value.

    Returns:
        A new dict with translated keys.

    Raises:
        ValueError: If *message* is not a dict.
    """
    if isinstance(message, dict):
        return {
            status_index_to_status_name(raw_key): payload
            for raw_key, payload in message.items()
        }
    raise ValueError(f"Message must be a dict, but got {type(message)}")
def rename_req_status(tx_data_df, req_status):
    """Relabel ``'ReqState'`` rows with the name of their active status column.

    For each row whose ``name`` equals ``'ReqState'``, the name becomes the
    first column in *req_status* whose value is greater than zero, with any
    ``'+'`` removed from the column label. Rows with no positive status keep
    their current name. All ``'ReqState'`` rows get ``domain`` set to
    ``"RequestState"``.

    Args:
        tx_data_df: DataFrame with a ``'name'`` column and the status columns.
        req_status: List of status column labels to inspect.

    Returns:
        The same ``tx_data_df`` object, modified in place.
    """
    # Boolean frame: True where a status counter is positive.
    active = tx_data_df[req_status].gt(0)
    active.columns = active.columns.str.replace('+', '', regex=False)

    is_req_state = tx_data_df['name'] == 'ReqState'
    current_names = tx_data_df.loc[is_req_state, 'name']

    # idxmax on booleans yields the first True column; it also yields the
    # first column for all-False rows, hence the any() mask below.
    first_active = active.idxmax(axis=1)
    has_active = active.any(axis=1)

    tx_data_df.loc[is_req_state, 'name'] = first_active.where(has_active, current_names)
    tx_data_df.loc[is_req_state, 'domain'] = "RequestState"
    return tx_data_df