import bisect
import json
import os
from collections.abc import Callable
from copy import deepcopy
from enum import Enum
from inspect import isfunction
from math import isinf, isclose, isnan
from pathlib import Path
from typing import Any, List, Tuple, Type, Optional, Union, Dict
import numpy as np
from loguru import logger
from pydantic import BaseModel, field_validator, Field, model_validator
from pydantic_settings import (
BaseSettings,
SettingsConfigDict,
PydanticBaseSettingsSource,
TomlConfigSettingsSource,
)
from ..common import is_vllm, is_mindie, ais_bench_exists
from ..io_utils import open_file
from ..config.custom_command import (
VllmBenchmarkCommandConfig,
MindieCommandConfig,
VllmCommandConfig,
AisBenchCommandConfig,
)
from . import base_config
from .base_config import (
INSTALL_PATH,
RUN_PATH,
ServiceType,
ms_serviceparam_optimizer_config_path,
)
# Module-level aliases for two values defined in base_config. They are bound once,
# at import time, to whatever base_config holds at that moment.
CUSTOM_OUTPUT = base_config.CUSTOM_OUTPUT
MODEL_EVAL_STATE_CONFIG_PATH = base_config.MODEL_EVAL_STATE_CONFIG_PATH
class MetricAlgorithm(BaseModel):
    """Pair of a metric name and the name of the algorithm applied to it."""

    # Metric identifier; this file uses "TTFT" and "TPOT" as values.
    metric: str = "TTFT"
    # Algorithm name. NOTE(review): only "average" appears in this file; the
    # full set of accepted names is defined by the consumer — confirm there.
    algorithm: str = "average"
class PerformanceConfig(BaseModel):
    """Metric/algorithm selection for the two latency figures read from a benchmark."""

    # Defaults to the "TTFT" metric with the "average" algorithm.
    time_to_first_token: MetricAlgorithm = MetricAlgorithm(metric="TTFT", algorithm="average")
    # Defaults to the "TPOT" metric with the "average" algorithm.
    time_per_output_token: MetricAlgorithm = MetricAlgorithm(metric="TPOT", algorithm="average")
# Maps an OptimizerConfigField.dtype name to the builtin used to cast values.
# Callers in this file look names up with dtype_func.get(name, <fallback>), so
# dtype names that are not listed here (e.g. "ratio", "enum") use the fallback.
dtype_func = {"int": int, "float": float, "str": str}
class ErrorSeverity(Enum):
    """Error severity level"""

    # NOTE(review): presumably mirrors the fatal_patterns / retryable_patterns
    # split of ErrorPatternConfig; this enum is not referenced in this file.
    FATAL = "fatal"
    RETRYABLE = "retryable"
class ErrorType(Enum):
    """Error type classification"""

    # Members are used as dictionary keys by ErrorPatternConfig, which maps each
    # error type to a list of pattern strings.
    OUT_OF_MEMORY = "out_of_memory"
    DEVICE_ERROR = "device_error"
    NETWORK_ERROR = "network_error"
    IO_ERROR = "io_error"
    UNKNOWN = "unknown"
class OptimizerConfigField(BaseModel):
    """Definition of one parameter handled by the optimizer.

    A field is treated as a constant (not searched) when ``constant`` is set or
    when ``min`` and ``max`` are (nearly) equal. ``dtype`` selects how values are
    decoded; besides the names in ``dtype_func`` this file also uses "bool",
    "enum", "range", "ratio", "share", "factories", "times",
    "ternary_factories" and "ternary_times", each of which interprets
    ``dtype_param`` differently.
    """

    name: str = "max_batch_size"
    # Dotted path of the setting inside the target configuration.
    config_position: str = "BackendConfig.ScheduleConfig.maxBatchSize"
    min: float = 0.0
    max: float = 100.0
    dtype: str = "float"
    value: Union[int, float, bool, str] = 0.0
    # dtype-specific payload: a list of options for "enum", a step for "range",
    # a target field name for "ratio"/"share", a dict for the product-based dtypes.
    dtype_param: Any = None
    constant: Optional[float] = None

    @model_validator(mode="after")
    def update_constant(self):
        """Validate the bounds and keep ``constant`` and ``min``/``max`` in sync.

        Raises:
            ValueError: if ``min`` is greater than ``max``.
        """
        if self.min > self.max:
            raise ValueError(f"min({self.min}) > max({self.max}). please check")
        # Compare against None rather than relying on truthiness so that an
        # explicit constant of 0 also collapses the search range.
        if self.constant is not None and not isclose(self.min, self.max):
            self.min = self.max = self.constant
        elif self.constant is None and isclose(self.min, self.max, rel_tol=1e-5) and self.dtype in dtype_func:
            # A degenerate range of a plain dtype is a constant by definition.
            self.constant = dtype_func.get(self.dtype, float)(self.max)
        return self

    def convert_dtype(self, value):
        """Cast ``value`` with the converter for this field's dtype (``float`` when unknown)."""
        if self.dtype == "str":
            return str(value)
        return dtype_func.get(self.dtype, float)(value)

    def find_available_value(self, value):
        """Return the admissible value closest to ``value`` for this field.

        - "str": the value converted to ``str``.
        - "enum": ``value`` itself when it is one of the options; otherwise the
          first option for string enums, or the option at the ``bisect_left``
          insertion point (last option when past the end) for numeric enums,
          which are expected to be sorted ascending.
        - anything else: the value cast to the field dtype and clamped to
          ``[min, max]``.
        """
        if self.dtype == "str":
            return str(value)
        if self.dtype == "enum":
            # Handle enums before any numeric cast: casting a string option such
            # as "fast" with float() would raise before the lookup is reached.
            enum_values = list(self.dtype_param) if isinstance(self.dtype_param, (list, tuple)) else []
            if not enum_values:
                return dtype_func.get(self.dtype, float)(value)
            if value in enum_values:
                return value
            if isinstance(enum_values[0], str):
                return enum_values[0]
            _index = bisect.bisect_left(enum_values, value)
            return enum_values[min(_index, len(enum_values) - 1)]
        _new_value = dtype_func.get(self.dtype, float)(value)
        if self.min <= _new_value <= self.max:
            return _new_value
        if _new_value < self.min:
            return dtype_func.get(self.dtype, float)(self.min)
        return dtype_func.get(self.dtype, float)(self.max)
# Default field definitions; MindieConfig.target_field uses this list as its default.
# Entries whose dtype is "ratio" or "factories" are derived fields: their final
# value is computed from another field by the handlers registered in
# DERIVED_FIELD_HANDLERS. Entries that omit `min` use the model default of 0.0.
default_support_field = [
    OptimizerConfigField(
        name="max_batch_size",
        config_position="BackendConfig.ScheduleConfig.maxBatchSize",
        min=10,
        max=1000,
        dtype="int",
    ),
    # Searched as a ratio; decoded to int(ratio * max_batch_size.value).
    OptimizerConfigField(
        name="max_prefill_batch_size",
        config_position="BackendConfig.ScheduleConfig.maxPrefillBatchSize",
        min=0.1,
        max=0.7,
        dtype="ratio",
        dtype_param="max_batch_size",
    ),
    OptimizerConfigField(
        name="prefill_time_ms_per_req",
        config_position="BackendConfig.ScheduleConfig.prefillTimeMsPerReq",
        max=1000,
        dtype="int",
    ),
    OptimizerConfigField(
        name="decode_time_ms_per_req",
        config_position="BackendConfig.ScheduleConfig.decodeTimeMsPerReq",
        max=1000,
        dtype="int",
    ),
    OptimizerConfigField(
        name="support_select_batch",
        config_position="BackendConfig.ScheduleConfig.supportSelectBatch",
        max=1,
        dtype="bool",
    ),
    OptimizerConfigField(
        name="max_prefill_token",
        config_position="BackendConfig.ScheduleConfig.maxPrefillTokens",
        min=4096,
        max=409600,
        dtype="int",
    ),
    # NOTE(review): the name below is misspelled ("deloy_mircroseconds"), but it is a
    # runtime identifier that external configuration may reference — left unchanged.
    OptimizerConfigField(
        name="max_queue_deloy_mircroseconds",
        config_position="BackendConfig.ScheduleConfig.maxQueueDelayMicroseconds",
        min=500,
        max=1000000,
        dtype="int",
    ),
    # For "enum" fields, min/max span the search interval that is split evenly
    # across the options listed in dtype_param.
    OptimizerConfigField(
        name="prefill_policy_type",
        config_position="BackendConfig.ScheduleConfig.prefillPolicyType",
        min=0,
        max=1,
        dtype="enum",
        dtype_param=[0, 1, 3],
    ),
    OptimizerConfigField(
        name="decode_policy_type",
        config_position="BackendConfig.ScheduleConfig.decodePolicyType",
        min=0,
        max=1,
        dtype="enum",
        dtype_param=[0, 1, 3],
    ),
    OptimizerConfigField(
        name="max_preempt_count",
        config_position="BackendConfig.ScheduleConfig.maxPreemptCount",
        min=0,
        max=1,
        dtype="ratio",
        dtype_param="max_batch_size",
    ),
    OptimizerConfigField(
        name="tp",
        config_position="BackendConfig.ModelDeployConfig.ModelConfig.0.tp",
        min=0,
        max=1,
        dtype="enum",
        dtype_param=[1, 2, 4, 8, 16],
    ),
    # min == max marks the field as constant for the search; its value is derived
    # as int(product / tp.value) by the "factories" handler.
    OptimizerConfigField(
        name="dp",
        config_position="BackendConfig.ModelDeployConfig.ModelConfig.0.dp",
        min=0,
        max=0,
        dtype="factories",
        dtype_param={"target_name": "tp", "product": 16, "dtype": "int"},
    ),
    OptimizerConfigField(
        name="moe_ep",
        config_position="BackendConfig.ModelDeployConfig.ModelConfig.0.moe_ep",
        min=0,
        max=1,
        dtype="enum",
        dtype_param=[1, 2, 4, 8, 16],
    ),
    # Derived as int(product / moe_ep.value) by the "factories" handler.
    OptimizerConfigField(
        name="moe_tp",
        config_position="BackendConfig.ModelDeployConfig.ModelConfig.0.moe_tp",
        min=0,
        max=0,
        dtype="factories",
        dtype_param={"target_name": "moe_ep", "product": 16, "dtype": "int"},
    ),
]
def range_to_enum(params_field: Tuple[OptimizerConfigField, ...]):
    """Rewrite every "range" field in place as an equivalent "enum" field.

    A range field describes the integers ``min, min + step, ...`` up to and
    including ``max``, with the step stored in ``dtype_param``. After conversion
    the field has ``dtype == "enum"``, ``dtype_param`` holds the generated
    values, and ``min``/``max`` become the normalized search interval ``0..1``.

    Fields that are not "range", have no step, have non-integer bounds/step, or
    have a non-positive step are left untouched (the latter two are logged).

    Args:
        params_field: Field definitions to convert; modified in place.
    """
    for v in params_field:
        if v.dtype != "range" or not v.dtype_param:
            continue
        try:
            _start = int(v.min)
            _end = int(v.max)
            _step = int(v.dtype_param)
        except (ValueError, TypeError):
            logger.error(f"Failed convert to int data, data: {v.min, v.max, v.dtype_param}")
            continue
        if _step <= 0:
            # range() rejects a zero step, and a negative one would produce an
            # empty option list that breaks enum decoding later on.
            logger.error(f"Range step must be a positive integer, data: {v.min, v.max, v.dtype_param}")
            continue
        # Inclusive upper bound: stopping at _end + 1 keeps every generated value
        # <= max even when the step does not divide (max - min) evenly.
        _enums = list(range(_start, _end + 1, _step))
        v.min = 0
        v.max = 1
        v.dtype_param = _enums
        v.dtype = "enum"
class DecodeContext(BaseModel):
    """
    Particle decode context, used by the balanced strategy to evenly distribute repair priority
    across different particles and iteration rounds.

    All attributes default to None; resolve_priority treats a missing particle_index or
    n_particles as "no context" and keeps the declared target order.

    Attributes:
        particle_index: Current particle index (0-based)
        n_particles: Total number of particles in the population
        iteration: Current iteration round (0-based), used by balanced strategy to alternate
                   direction between rounds, preventing the same particle from being locked
                   to the same repair order throughout the optimization process.
                   When None, falls back to pure particle index splitting.
    """

    particle_index: Optional[int] = None
    n_particles: Optional[int] = None
    iteration: Optional[int] = None
def resolve_priority(dtype_param: dict, context=None) -> list:
    """
    Decide in which order the target fields of a ternary_factories field are repaired.

    The policy is read from dtype_param["priority_policy"] (default "balanced"):
        - fixed:    use dtype_param["priority"] when it is a permutation of target_names;
                    otherwise log a warning and use target_names as declared.
        - balanced: particles in the upper half of the population use the reversed order,
                    and the direction is flipped again on odd iterations. Without a usable
                    context (None, or particle_index / n_particles missing) the declared
                    order is used.
        - any other policy: the declared order.

    Args:
        dtype_param: The dtype_param dict of the derived field.
        context: DecodeContext-like object, or None.
    Returns:
        A new list of field names, highest priority first. When fewer than two target
        names are declared, they are returned unchanged (as a list).
    """
    target_names = dtype_param.get("target_names", [])
    declared = list(target_names)
    if len(target_names) < 2:
        return declared

    policy = dtype_param.get("priority_policy", "balanced")

    if policy == "fixed":
        priority = list(dtype_param.get("priority", target_names))
        is_permutation = len(priority) == len(target_names) and set(priority) == set(target_names)
        if is_permutation:
            return priority
        logger.warning(f"Invalid fixed priority {priority}; fallback to target_names {target_names}.")
        return declared

    if policy != "balanced":
        return declared

    has_context = (
        context is not None
        and context.particle_index is not None
        and context.n_particles is not None
    )
    if not has_context:
        return declared

    # Upper half of the swarm starts reversed; odd rounds swap the two halves.
    flip = context.particle_index >= context.n_particles / 2
    if context.iteration is not None and context.iteration % 2 == 1:
        flip = not flip
    if flip:
        return declared[::-1]
    return declared
def _repair_ternary_factories_with_priority(
    v, simulate_run_info, params_field, product, min_val, max_val, conv, context=None
):
    """
    Repair a violated ternary_factories constraint by adjusting its two source fields.

    The two source fields are ordered by resolve_priority(): the first one is the
    "keep" field, the second the "adjust" field. Candidate values of each field are
    tried nearest-to-current first.

    Search order:
        1. stage1-fix-keep:    the keep field stays at its current value, only the
                               adjust field is varied.
        2. stage2-both-adjust: both fields are varied, keep candidates in the outer loop.

    A pair is accepted when both values are non-zero, `product` is divisible by their
    product (int dtype only) and conv(product / (keep * adjust)) lies within
    [min_val, max_val]. The first accepted pair is written into simulate_run_info
    together with the derived value, and an info line is logged.

    Args:
        v: Definition of the derived field being repaired.
        simulate_run_info: Field copies holding the decoded values (mutated on success).
        params_field: Original field definitions (source of the candidate values).
        product: The product value from dtype_param.
        min_val: Lower bound for the derived value (None means unbounded).
        max_val: Upper bound for the derived value (None means unbounded).
        conv: Conversion applied to the derived value (int / float).
        context: DecodeContext used by the balanced policy, or None.
    Returns:
        True when a valid pair was found and applied; False when the caller has to
        fall back to clamping.
    """
    spec = v.dtype_param
    if len(spec.get("target_names", [])) < 2:
        return False

    order = resolve_priority(spec, context)
    keep_name = order[0]
    adjust_name = order[1]

    definitions = {f.name: f for f in params_field}
    live = {f.name: f for f in simulate_run_info}

    keep_def = definitions.get(keep_name)
    adjust_def = definitions.get(adjust_name)
    if keep_def is None or adjust_def is None:
        return False

    keep_pool = _get_field_candidates(keep_def)
    adjust_pool = _get_field_candidates(adjust_def)
    if not keep_pool or not adjust_pool:
        return False

    cur_keep = live[keep_name].value if keep_name in live else 0
    cur_adjust = live[adjust_name].value if adjust_name in live else 0
    int_result = spec.get("dtype", "int") == "int"

    # sorted() is stable, so equally distant candidates keep their original order.
    keep_order = sorted(keep_pool, key=lambda cand: abs(cand - (cur_keep or 0)))
    adjust_order = sorted(adjust_pool, key=lambda cand: abs(cand - (cur_adjust or 0)))

    rejected = object()  # sentinel: the pair does not satisfy the constraint

    def solve(keep_val, adjust_val):
        """Return the derived value for the pair, or the `rejected` sentinel."""
        if not keep_val or not adjust_val:
            return rejected
        divisor = keep_val * adjust_val
        if divisor == 0:
            return rejected
        if int_result and product % divisor != 0:
            return rejected
        derived = conv(product / divisor)
        if min_val is not None and derived < min_val:
            return rejected
        if max_val is not None and derived > max_val:
            return rejected
        return derived

    def commit(keep_val, adjust_val, derived, stage):
        """Write the accepted combination back and log what changed."""
        previous = live[v.name].value if v.name in live else None
        live[keep_name].value = keep_val
        live[adjust_name].value = adjust_val
        live[v.name].value = derived
        if keep_val == cur_keep:
            keep_part = f"{keep_name}={keep_val}(kept)"
        else:
            keep_part = f"{keep_name}: {cur_keep}→{keep_val}"
        if adjust_val == cur_adjust:
            adjust_part = f"{adjust_name}={adjust_val}(kept)"
        else:
            adjust_part = f"{adjust_name}: {cur_adjust}→{adjust_val}"
        derived_part = f"{v.name}: {previous}→{derived}"
        logger.info(
            f"ternary_factories repair [{stage}] '{v.name}' "
            f"(policy={v.dtype_param.get('priority_policy', 'balanced')}): "
            f"{keep_part}, {adjust_part}, {derived_part} "
            f"(product={product})"
        )

    def attempts():
        """Yield (stage, keep, adjust) triples in search order."""
        for adjust_val in adjust_order:
            yield "stage1-fix-keep", cur_keep, adjust_val
        for keep_val in keep_order:
            for adjust_val in adjust_order:
                yield "stage2-both-adjust", keep_val, adjust_val

    for stage, keep_val, adjust_val in attempts():
        derived = solve(keep_val, adjust_val)
        if derived is not rejected:
            commit(keep_val, adjust_val, derived, stage)
            return True
    return False
def _get_field_candidates(field_def):
    """
    List the discrete values a field may take, for the ternary_factories repair search.

    - enum: the int/float entries of dtype_param (an empty list when dtype_param is empty or None)
    - int:  every integer in [int(min), int(max)], provided the span is between 0 and 256
    - anything else, or an int span outside that window: None (enumeration not possible)

    Args:
        field_def: OptimizerConfigField definition object
    Returns:
        List of candidate values, or None when the field cannot be enumerated
    """
    kind = field_def.dtype
    if kind == "enum":
        options = field_def.dtype_param
        if not options:
            return []
        return [opt for opt in options if isinstance(opt, (int, float))]
    if kind != "int":
        return None
    low = int(field_def.min)
    high = int(field_def.max)
    span = high - low
    if span < 0 or span > 256:
        return None
    return list(range(low, high + 1))
def _update_ratio_field(field, i, params_field, simulate_run_info, decode_context=None):
    """ratio handler: replace the decoded ratio at index i with int(ratio × target.value).

    The target is the first entry of simulate_run_info whose name equals
    field.dtype_param; an IndexError is raised when there is none.
    params_field and decode_context are accepted only for handler-signature parity.
    """
    derived = simulate_run_info[i]
    target_name = field.dtype_param
    matches = [candidate for candidate in simulate_run_info if candidate.name == target_name]
    target = matches[0]
    derived.value = int(derived.value * target.value)
def _update_factories_field(field, i, params_field, simulate_run_info, decode_context=None):
    """factories handler: set the field at index i to cast(product / target.value).

    dtype_param supplies "target_name", "product" and "dtype"; the cast comes from
    dtype_func (int when the dtype name is unknown). Nothing is written when the
    target value is 0. An IndexError is raised when no field carries the target name.
    """
    derived = simulate_run_info[i]
    spec = field.dtype_param
    target = [candidate for candidate in simulate_run_info if candidate.name == spec["target_name"]][0]
    if target.value == 0:
        # Division by zero is avoided by leaving the current value in place.
        return
    cast = dtype_func.get(spec["dtype"], int)
    derived.value = cast(spec["product"] / target.value)
def _update_times_field(field, i, params_field, simulate_run_info, decode_context=None):
    """times handler: set the field at index i to cast(product × target.value).

    dtype_param supplies "target_name", "product" and "dtype"; the cast comes from
    dtype_func (int when the dtype name is unknown). When the target value is None
    or a float NaN, a warning is logged and the field is left unchanged.
    """
    derived = simulate_run_info[i]
    spec = field.dtype_param
    target = [candidate for candidate in simulate_run_info if candidate.name == spec["target_name"]][0]
    factor = target.value
    unusable = factor is None or (isinstance(factor, float) and isnan(factor))
    if unusable:
        logger.warning(f"Target value for {field.name} is invalid, skipping times calculation")
        return
    derived.value = dtype_func.get(spec["dtype"], int)(spec["product"] * factor)
def _update_ternary_factories_field(field, i, params_field, simulate_run_info, decode_context=None):
    """
    ternary_factories handler: value = product / (field_a × field_b)

    dtype_param structure:
        {"target_names": ["field_a", "field_b"], "product": 16, "dtype": "int",
         "min_value": 1,    # optional lower bound; defaults to 1 for the int dtype, else unbounded
         "max_value": 16}   # optional upper bound

    Flow:
        1. Any target name missing from simulate_run_info, or a target value of 0:
           log a warning and leave the field untouched.
        2. Result within bounds (and, for int, product divisible by the divisor):
           store the result.
        3. Otherwise try _repair_ternary_factories_with_priority(), which rewrites the
           source fields and the derived value itself on success.
        4. Repair failed: clamp to the violated bound with a warning; if no bound was
           violated the only remaining cause is indivisibility, which raises.

    Raises:
        ValueError: int dtype, product not divisible by the divisor, repair failed and
                    no bound was violated.
    """
    derived = simulate_run_info[i]
    spec = field.dtype_param
    target_names = spec.get("target_names", [])

    target_ops = [op for op in simulate_run_info if op.name in target_names]
    present = {op.name for op in target_ops}
    missing = [name for name in target_names if name not in present]
    if missing:
        logger.warning(
            f"ternary_factories '{field.name}': target_names {missing} not found in fields. "
            f"Check for typos or case mismatch. Available fields: "
            f"{[op.name for op in simulate_run_info]}"
        )
        return

    divisor = 1
    for op in target_ops:
        if op.value == 0:
            logger.warning(f"Target value {op.name} is 0, skipping ternary_factories calculation for {field.name}")
            return
        divisor *= op.value

    product = spec.get("product", 1)
    dtype_name = spec.get("dtype", "int")
    is_int_dtype = dtype_name == "int"
    conv = dtype_func.get(dtype_name, int)
    result_value = conv(product / divisor)
    min_value = spec.get("min_value", 1 if is_int_dtype else None)
    max_value = spec.get("max_value", None)

    below_min = min_value is not None and result_value < min_value
    above_max = max_value is not None and result_value > max_value
    indivisible = is_int_dtype and product % divisor != 0

    if not (below_min or above_max or indivisible):
        derived.value = result_value
        return

    if _repair_ternary_factories_with_priority(
        field,
        simulate_run_info,
        params_field,
        product,
        min_value,
        max_value,
        conv,
        context=decode_context,
    ):
        # The repair routine has already written the source and derived values.
        return

    clamped = False
    if below_min:
        logger.warning(
            f"ternary_factories priority repair failed for '{field.name}'; "
            f"fallback to clamp: {result_value} → min_value {min_value}."
        )
        derived.value = conv(min_value)
        clamped = True
    if above_max:
        logger.warning(
            f"ternary_factories priority repair failed for '{field.name}'; "
            f"fallback to clamp: {result_value} → max_value {max_value}."
        )
        derived.value = conv(max_value)
        clamped = True
    if clamped:
        return

    if indivisible:
        raise ValueError(
            f"ternary_factories constraint violated for '{field.name}': "
            f"product={product} not divisible by divisor={divisor} "
            f"(targets={target_names}), and repair could not find valid source values."
        )
    derived.value = result_value
def _update_ternary_times_field(field, i, params_field, simulate_run_info, decode_context=None):
    """
    ternary_times handler: value = product × field_a × field_b

    dtype_param structure: {"target_names": ["field_a", "field_b"], "product": 1, "dtype": "int"}

    The field is left unchanged (with a warning) when a target name cannot be found in
    simulate_run_info or when a target value is None or a float NaN.
    """
    derived = simulate_run_info[i]
    spec = field.dtype_param
    target_names = spec.get("target_names", [])

    target_ops = [op for op in simulate_run_info if op.name in target_names]
    present = {op.name for op in target_ops}
    missing = [name for name in target_names if name not in present]
    if missing:
        logger.warning(
            f"ternary_times '{field.name}': target_names {missing} not found in fields. "
            f"Check for typos or case mismatch. Available fields: "
            f"{[op.name for op in simulate_run_info]}"
        )
        return

    result = spec.get("product", 1)
    for op in target_ops:
        factor = op.value
        if factor is None or (isinstance(factor, float) and isnan(factor)):
            logger.warning(f"Target value {op.name} for {field.name} is invalid, skipping ternary_times calculation")
            return
        result *= factor

    cast = dtype_func.get(spec.get("dtype", "int"), int)
    derived.value = cast(result)
def _update_share_field(field, i, params_field, simulate_run_info, decode_context=None):
    """share handler: set the field at index i to int(target.min + target.max - target.value).

    The target is the first entry of simulate_run_info named field.dtype_param.
    When no entry matches, the field keeps its current value.
    """
    derived = simulate_run_info[i]
    target = next((op for op in simulate_run_info if op.name == field.dtype_param), None)
    if target is None:
        return
    derived.value = int(target.min + target.max - target.value)
# Dispatch table used by update_optimizer_value: maps a derived dtype name to the
# function that computes the field's value from its dependency field(s).
# Every handler shares the signature
#     handler(field, i, params_field, simulate_run_info, decode_context=None)
# and writes its result into simulate_run_info[i].
DERIVED_FIELD_HANDLERS = {
    "ratio": _update_ratio_field,
    "share": _update_share_field,
    "factories": _update_factories_field,
    "times": _update_times_field,
    "ternary_factories": _update_ternary_factories_field,
    "ternary_times": _update_ternary_times_field,
}
def update_optimizer_value(
    params_field: Tuple[OptimizerConfigField, ...],
    simulate_run_info: Tuple[OptimizerConfigField, ...],
    support_select_is_false,
    decode_context: Optional["DecodeContext"] = None,
):
    """
    Fill in derived field values and apply the fixed business rules, in place.

    Fields are visited in definition order. For each field:

    1. If its dtype has an entry in DERIVED_FIELD_HANDLERS, that handler computes the
       value from the field(s) it depends on:
           ratio              int(self_ratio × target.value)
           share              int(target.min + target.max - target.value)
           factories          product / target.value        (skipped when target.value is 0)
           times              product × target.value        (skipped when target.value is None/NaN)
           ternary_factories  product / (field_a.value × field_b.value)
           ternary_times      product × field_a.value × field_b.value
    2. A field whose config_position contains "maxPrefillBatchSize" is raised from 0 to 1.
    3. When support_select_is_false is truthy, fields whose config_position contains
       "prefillTimeMsPerReq" or "decodeTimeMsPerReq" are set to 0.

    Args:
        params_field: Original field definitions; supply dtype, dtype_param and config_position.
        simulate_run_info: Copies of the fields, index-aligned with params_field; their
            values are modified in place.
        support_select_is_false: Pass True when the supportSelectBatch field value is False.
        decode_context: Optional DecodeContext forwarded unchanged to the handlers.
    """
    for index, definition in enumerate(params_field):
        derive = DERIVED_FIELD_HANDLERS.get(definition.dtype)
        if derive:
            derive(definition, index, params_field, simulate_run_info, decode_context)

        if "maxPrefillBatchSize" in definition.config_position:
            current = simulate_run_info[index]
            if current.value == 0:
                current.value = 1

        if not support_select_is_false:
            continue
        current = simulate_run_info[index]
        position = current.config_position
        if "prefillTimeMsPerReq" in position or "decodeTimeMsPerReq" in position:
            current.value = 0
def map_param_with_value(
    params: np.ndarray,
    params_field: Tuple[OptimizerConfigField, ...],
    decode_context: Optional["DecodeContext"] = None,
):
    """
    Decode an optimizer position vector into concrete field values.

    ``params`` holds one entry per non-constant field, in definition order. Constant
    fields (``constant`` set, or ``min`` ~= ``max``) consume no entry and keep their own
    value, cast to the field dtype when it is truthy and finite. The remaining fields are
    decoded by dtype:

        int    int(params[i])
        bool   params[i] > 0.5
        enum   string options: nearest option after normalizing to [min, max];
               other options:  [min, max] is split into len(options) equal segments
        other  float(params[i])

    After decoding, update_optimizer_value() fills in derived fields and applies the
    business rules. The "support select batch is false" flag passed to it is raised when
    a bool field identified as supportSelectBatch (by name or by config_position) decodes
    to False.

    Args:
        params: Position vector of the searched (non-constant) fields.
        params_field: Field definitions; not modified (each one is deep-copied).
        decode_context: Optional DecodeContext forwarded to update_optimizer_value.
    Returns:
        List of field copies, index-aligned with params_field, carrying the decoded values.
    """
    _simulate_run_info = []
    _support_select_is_false = False
    i = 0
    for v in params_field:
        _field = deepcopy(v)
        if _field.constant is not None or isclose(_field.min, _field.max, rel_tol=1e-5):
            _value = _field.value
            # math.isinf() only accepts real numbers, so string constants must bypass it.
            _is_infinite = not isinstance(_value, str) and isinf(_value)
            if _value and not _is_infinite:
                try:
                    _field.value = dtype_func.get(v.dtype, int)(_value)
                except (ValueError, TypeError) as e:
                    # Report the value that failed to convert; params[i] belongs to a
                    # different field and may not even exist at this point.
                    logger.warning(f"Failed in func {_value} for {v}, error: {e}")
            _simulate_run_info.append(_field)
            continue
        if v.dtype == "int":
            try:
                _field.value = int(params[i])
            except (ValueError, TypeError):
                logger.warning(f"Failed convert to int data, data: {params[i]}")
                _field.value = params[i]
        elif v.dtype == "bool":
            _field.value = bool(params[i] > 0.5)
            # The flag means "supportSelectBatch is False", so it is raised on the False
            # branch. config_position is checked as well because the default field is
            # named "support_select_batch" and would never match on name alone.
            if not _field.value and (
                "supportSelectBatch" in _field.name or "supportSelectBatch" in _field.config_position
            ):
                _support_select_is_false = True
        elif v.dtype == "enum":
            if v.dtype_param and len(v.dtype_param) > 0 and isinstance(v.dtype_param[0], str):
                num_options = len(v.dtype_param)
                if num_options == 1:
                    _field.value = v.dtype_param[0]
                else:
                    normalized = (params[i] - v.min) / (v.max - v.min) if v.max > v.min else 0
                    # +0.5 rounds to the nearest option index; then clamp to the valid range.
                    _enum_index = int(normalized * (num_options - 1) + 0.5)
                    _enum_index = max(0, min(_enum_index, num_options - 1))
                    _field.value = v.dtype_param[_enum_index]
            else:
                # One equal-width segment of [min, max] per option.
                segment = np.linspace(v.min, v.max, len(v.dtype_param) + 1)
                if params[i] <= v.min:
                    _field.value = v.dtype_param[0]
                elif params[i] >= v.max:
                    _field.value = v.dtype_param[-1]
                else:
                    _enum_index = np.searchsorted(segment, params[i]) - 1
                    _field.value = v.dtype_param[_enum_index]
        else:
            try:
                _field.value = float(params[i])
            except (ValueError, TypeError):
                logger.warning(f"Failed convert to float data, data: {params[i]}")
                _field.value = params[i]
        i += 1
        _simulate_run_info.append(_field)
    update_optimizer_value(
        params_field,
        tuple(_simulate_run_info),
        _support_select_is_false,
        decode_context,
    )
    return _simulate_run_info
def reverse_special_field(params_field: Tuple[OptimizerConfigField, ...], params: np.ndarray, concurrency: int):
    """
    Convert absolute field values back into the search-space form, in place.

    ``params`` holds one entry per non-constant field, in definition order. Two kinds of
    field are rewritten:

    - "ratio" fields: the entry becomes ``value / target.value``, where the target is the
      field named by ``dtype_param``. The entry is left as is when no target with a
      usable (non-zero, non-None) value exists.
    - fields named "CONCURRENCY" / "MAXCONCURRENCY": a ratio value of 0 maps to 1; a ratio
      value maps to ``value / concurrency`` when a positive concurrency is known; any other
      non-None value is used as is; a None value falls back to ``concurrency``.

    Args:
        params_field: Field definitions carrying the current values.
        params: Vector of encoded values for the non-constant fields; modified in place.
        concurrency: Reference concurrency, or None when it is unknown.
    Returns:
        The same array object as ``params``.
    """
    _params = params
    i = 0
    for v in params_field:
        if v.constant is not None or isclose(v.min, v.max, rel_tol=1e-5):
            continue
        if v.dtype == "ratio":
            for _op in params_field:
                # Only divide by a target that actually matched and has a usable value,
                # so a missing/zero/None target can never raise or reuse a stale match.
                if _op.name == v.dtype_param and _op.value is not None and _op.value != 0:
                    _params[i] = float(v.value / _op.value)
        if v.name in ["CONCURRENCY", "MAXCONCURRENCY"]:
            _is_ratio = v.dtype == "ratio"
            if v.value == 0 and _is_ratio:
                _params[i] = 1
            elif v.value is not None and _is_ratio and concurrency is not None and concurrency > 0:
                # concurrency is None when no batch-size field is being searched;
                # comparing None with 0 would raise TypeError.
                _params[i] = v.value / concurrency
            elif v.value is not None:
                _params[i] = v.value
            else:
                _params[i] = concurrency
        i += 1
    return _params
def field_to_param(params_field: Tuple[OptimizerConfigField, ...]):
    """
    Encode the current field values as an optimizer position vector.

    Constant fields (``constant`` set, or ``min`` ~= ``max``) are skipped. Every other
    field contributes one entry, by dtype:

        int    int(value); on failure a warning is logged and the raw value is used
        bool   1 or 0
        enum   the midpoint of the option's segment when [min, max] is split into
               len(options) equal parts; a value that is not yet an option is first
               added to the field's dtype_param (appended for strings, inserted in
               sorted position for numbers — the option list is modified in place)
        other  the raw value

    The value of the last non-constant field located at
    "BackendConfig.ScheduleConfig.maxBatchSize" or named "MAX_NUM_SEQS" /
    "max_batch_size" is passed on as the concurrency reference.

    Returns:
        The float array produced by reverse_special_field() for the encoded vector.
    """
    batch_size_names = ["MAX_NUM_SEQS", "max_batch_size"]
    concurrency = None
    encoded_values = []
    for v in params_field:
        if v.constant is not None or isclose(v.min, v.max, rel_tol=1e-5):
            continue

        if v.dtype == "int":
            try:
                encoded = int(v.value)
            except Exception as e:
                logger.warning(f"Failed in field to param, error: {e}")
                encoded = v.value
        elif v.dtype == "bool":
            encoded = 1 if v.value else 0
        elif v.dtype == "enum":
            options = v.dtype_param
            if v.value not in options and isinstance(v.value, str):
                options.append(v.value)
            if v.value not in options and isinstance(v.value, (int, float)):
                options.sort()
                bisect.insort_left(options, v.value)
            position = options.index(v.value)
            edges = np.linspace(v.min, v.max, len(options) + 1)
            encoded = (edges[position] + edges[position + 1]) / 2
        else:
            encoded = v.value
        encoded_values.append(encoded)

        if v.config_position == "BackendConfig.ScheduleConfig.maxBatchSize" or v.name in batch_size_names:
            concurrency = v.value

    vector = np.array(encoded_values, dtype=float)
    return reverse_special_field(params_field, vector, concurrency)
class PerformanceIndex(BaseModel):
    """Container for measured performance figures; every figure is optional and defaults to None."""

    # NOTE(review): units are not defined in this file — confirm against the producer.
    generate_speed: Optional[float] = None
    time_to_first_token: Optional[float] = None
    time_per_output_token: Optional[float] = None
    success_rate: Optional[float] = None
    throughput: Optional[float] = None
class DataStorageConfig(BaseModel):
    """Storage settings: the directory for stored data and a top-k count."""

    # Relative default; Settings replaces it with <output>/store.
    store_dir: Path = Path("store")
    # NOTE(review): presumably the number of best PSO results to keep — confirm with the consumer.
    pso_top_k: int = 3

    @field_validator("store_dir")
    @classmethod
    def create_path(cls, path: Path) -> Path:
        """Create store_dir (and missing parents) when the field is validated; mode 0o750 applies to a newly created leaf."""
        path.mkdir(parents=True, exist_ok=True, mode=0o750)
        return path
class LatencyModel(BaseModel):
    """File-system layout of the latency model artifacts.

    All paths other than ``base_path`` default to locations underneath ``base_path``
    (resolved to absolute paths). ``base_path`` must stay the first field: the
    ``default_factory`` lambdas read it from the already-validated data.
    """

    base_path: Path = Path("latency_model")
    model_path: Optional[Path] = Field(
        default_factory=lambda data: (data["base_path"].joinpath("bak/base/xgb_model.ubj").resolve())
    )
    # validate_default=True makes the validator below run for the default as well,
    # so this directory is created even when it is not configured explicitly.
    static_file_dir: Optional[Path] = Field(
        default_factory=lambda data: (data["base_path"].joinpath("model_static_file").resolve()),
        validate_default=True,
    )
    req_and_decode_file: Optional[Path] = Field(
        default_factory=lambda data: (data["base_path"].joinpath("req_id_and_decode_num.json").resolve())
    )
    cache_data: Optional[Path] = Field(default_factory=lambda data: data["base_path"].joinpath("cache").resolve())

    @field_validator("base_path", "cache_data", "static_file_dir")
    @classmethod
    def create_path(cls, path: Optional[Path]) -> Optional[Path]:
        """Create the directory (and missing parents) when the field is validated.

        ``None`` is passed through untouched, since two of the validated fields are
        Optional. New directories use mode 0o750, matching the other ``create_path``
        validators in this module.
        """
        if path is not None:
            path.mkdir(parents=True, exist_ok=True, mode=0o750)
        return path
def _get_mindie_config_paths():
    """Return the (config.json, config_bak.json) path pair for the MindIE service.

    The standard install location is used when its config.json exists as a file.
    Otherwise, if the MIES_INSTALL_PATH environment variable is set, the pair is
    located under "mindie_llm/conf" next to that install path (i.e. relative to its
    parent directory). If neither applies, the standard pair is returned even though
    the file is absent.
    """
    standard_pair = (
        Path("/usr/local/Ascend/mindie/latest/mindie-service/conf/config.json"),
        Path("/usr/local/Ascend/mindie/latest/mindie-service/conf/config_bak.json"),
    )
    if standard_pair[0].is_file():
        return standard_pair

    install_root = os.getenv("MIES_INSTALL_PATH")
    if not install_root:
        return standard_pair

    sibling_base = Path(install_root).parent
    return (
        sibling_base / "mindie_llm/conf/config.json",
        sibling_base / "mindie_llm/conf/config_bak.json",
    )
class MindieConfig(BaseModel):
    """Settings for the MindIE service: process names, paths, launch command and tunable fields."""

    # Comma-separated list of process names.
    process_name: str = "mindie, mindie-llm, mindieservice_daemon, mindie_llm"
    # Relative default; Settings replaces it with <output>/mindie.
    output: Path = Path("mindie")
    # Defaults to the resolved working directory at instantiation time.
    work_path: Path = Field(default_factory=lambda: Path(os.getcwd()).resolve())
    # Service config file and its backup, resolved by _get_mindie_config_paths().
    config_path: Path = Field(default_factory=lambda: _get_mindie_config_paths()[0])
    config_bak_path: Path = Field(default_factory=lambda: _get_mindie_config_paths()[1])
    command: MindieCommandConfig = MindieCommandConfig()
    # Fields handled by the optimizer; defaults to the module-level default_support_field list.
    target_field: List[OptimizerConfigField] = default_support_field
class AisBenchConfig(BaseModel):
    """Settings for the ais_bench benchmark tool."""

    process_name: str = "ais_bench"
    # Relative default; Settings.partial_update_aisbench rebases it onto Settings.output.
    output_path: Path = Path("ais_bench")
    # Defaults to the resolved working directory at instantiation time.
    work_path: Path = Field(default_factory=lambda: Path(os.getcwd()).resolve())
    command: AisBenchCommandConfig = AisBenchCommandConfig()
    performance_config: PerformanceConfig = PerformanceConfig()
    # Benchmark-side fields handled by the optimizer; empty by default.
    target_field: List[OptimizerConfigField] = Field(default_factory=list)
    # NOTE(review): the fields below are not read in this file; their meaning is
    # inferred from their names only — confirm against the ais_bench runner.
    model: str = ""
    path: str = ""
    host_ip: str = ""
    host_port: int = 0
    max_out_len: int = 0
    best_concurrency_coefficient: int = 3
    best_concurrency_threshold: int = 200
class VllmBenchmarkConfig(BaseModel):
    """Settings for the vLLM benchmark tool."""

    # Relative default; Settings.partial_update_vllm rebases it onto Settings.output.
    output_path: Path = Path("vllm")
    process_name: str = ""
    # Settings.partial_update_vllm copies host, port, model and served_model_name
    # from the vLLM service command into this command.
    command: VllmBenchmarkCommandConfig = VllmBenchmarkCommandConfig()
    performance_config: PerformanceConfig = PerformanceConfig()
    # Benchmark-side fields handled by the optimizer; empty by default.
    target_field: List[OptimizerConfigField] = Field(default_factory=list)
class VllmConfig(BaseModel):
    """Settings for the vLLM service: output location, process name, launch command and tunable fields."""

    # Relative default; Settings replaces it with <output>/vllm.
    output: Path = Path("vllm")
    process_name: str = "vllm"
    # Defaults to the resolved working directory at instantiation time.
    work_path: Path = Field(default_factory=lambda: Path(os.getcwd()).resolve())
    command: VllmCommandConfig = VllmCommandConfig()
    # Service-side fields handled by the optimizer; empty by default.
    target_field: List[OptimizerConfigField] = Field(default_factory=list)
class PsoOptions(BaseModel):
    """Numeric PSO coefficients.

    NOTE(review): c1 / c2 / w presumably are the cognitive, social and inertia
    coefficients of particle swarm optimization — confirm against the optimizer
    that consumes them.
    """

    c1: float = 2.0
    c2: float = 2.0
    w: float = 1.8
class PsoStrategy(BaseModel):
    """Strategy name for each PSO coefficient (same keys as PsoOptions).

    NOTE(review): "exp_decay" is the only strategy name visible in this file;
    the accepted names are defined by the optimizer — confirm there.
    """

    w: str = "exp_decay"
    c1: str = "exp_decay"
    c2: str = "exp_decay"
class ErrorPatternConfig(BaseModel):
    """Error pattern configuration - 3-tier design: ErrorType -> patterns -> severity"""

    # Patterns per error type that are classified as fatal. The default registers
    # OUT_OF_MEMORY and DEVICE_ERROR with empty pattern lists.
    fatal_patterns: Dict[ErrorType, List[str]] = Field(
        default_factory=lambda: {
            ErrorType.OUT_OF_MEMORY: [],
            ErrorType.DEVICE_ERROR: [],
        }
    )
    # Patterns per error type that are classified as retryable. The default registers
    # NETWORK_ERROR and IO_ERROR with empty pattern lists.
    retryable_patterns: Dict[ErrorType, List[str]] = Field(
        default_factory=lambda: {ErrorType.NETWORK_ERROR: [], ErrorType.IO_ERROR: []}
    )
class HealthCheckConfig(BaseModel):
    """Health check configuration"""

    # Error patterns for the service; uses the ErrorPatternConfig defaults.
    service_errors: ErrorPatternConfig = Field(default_factory=ErrorPatternConfig)
    # Error patterns for the benchmark; by default no fatal patterns are registered,
    # only the (empty) retryable NETWORK_ERROR / IO_ERROR lists.
    benchmark_errors: ErrorPatternConfig = Field(
        default_factory=lambda: ErrorPatternConfig(
            fatal_patterns={},
            retryable_patterns={ErrorType.NETWORK_ERROR: [], ErrorType.IO_ERROR: []},
        )
    )
    # NOTE(review): unit (lines vs. characters) is not visible in this file — confirm with the consumer.
    log_snippet_length: int = 50
class Settings(BaseSettings):
    """
    Settings class definition, initialized by reading configuration files

    Values are collected from the sources returned by settings_customise_sources:
    constructor arguments, environment variables (prefix "model_eval_state_"), the TOML
    files listed in model_config, and file secrets.

    Field order matters: several default_factory lambdas read the already-validated
    `output` value, so `output` has to be declared before them.
    """

    model_config = SettingsConfigDict(
        toml_file=[
            INSTALL_PATH.joinpath("model_eval_state.toml"),
            Path("~/model_eval_state.toml").expanduser(),
            RUN_PATH.joinpath("model_eval_state.toml"),
            INSTALL_PATH.joinpath("config.toml"),
            INSTALL_PATH.joinpath("optix/config.toml"),
            Path("~/config.toml").expanduser(),
            RUN_PATH.joinpath("config.toml"),
            ms_serviceparam_optimizer_config_path,
        ],
        env_prefix="model_eval_state_",
    )
    # Root output directory; defaults to <cwd>/result and is created by create_path.
    output: Path = Field(
        default_factory=lambda: Path(os.getcwd()).joinpath("result").resolve(),
        validate_default=True,
    )
    # Defaults to <output>/simulator.
    simulator_output: Path = Field(default_factory=lambda data: data["output"].joinpath("simulator").resolve())
    # --- optimizer parameters ---
    # NOTE(review): the numeric fields below are not interpreted in this file; units
    # and semantics are defined by their consumers.
    pso_options: PsoOptions = PsoOptions()
    pso_strategy: PsoStrategy = PsoStrategy()
    particles_time_out: int = 1 * 60 * 60
    wait_start_time: int = 1800
    n_particles: int = Field(default=5, gt=0, lt=1000)
    iters: int = Field(default=10, gt=0, lt=1000)
    ftol: float = -np.inf
    ftol_iter: int = 1
    ttft_penalty: float = 3.0
    tpot_penalty: float = 3.0
    success_rate_penalty: float = 5.0
    ttft_slo: float = Field(default=0.5, gt=0)
    tpot_slo: float = Field(default=0.05, gt=0)
    success_rate_slo: float = Field(default=1.0, gt=0)
    slo_coefficient: float = 0.1
    generate_speed_target: float = 5000.0
    sample_size: Optional[int] = None
    mem_coefficient: float = 0.8
    max_fine_tune: int = 10
    scaling_coefficient: float = 1.3
    step_size: float = 0.6
    theory_guided_enable: bool = True
    service: str = ServiceType.master.value
    # --- nested configuration; the defaults are placed underneath `output` ---
    latency_model: LatencyModel = Field(
        default_factory=lambda data: LatencyModel(base_path=data["output"].joinpath("latency_model")),
        validate_default=True,
    )
    vllm: VllmConfig = Field(
        default_factory=lambda data: VllmConfig(output=data["output"].joinpath("vllm")),
        validate_default=True,
    )
    mindie: MindieConfig = Field(
        default_factory=lambda data: MindieConfig(output=data["output"].joinpath("mindie")),
        validate_default=True,
    )
    ais_bench: AisBenchConfig = AisBenchConfig()
    vllm_benchmark: VllmBenchmarkConfig = VllmBenchmarkConfig()
    data_storage: DataStorageConfig = Field(
        default_factory=lambda data: DataStorageConfig(store_dir=data["output"].joinpath("store")),
        validate_default=True,
    )
    health_check: HealthCheckConfig = Field(default_factory=HealthCheckConfig)

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: Type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> Tuple[PydanticBaseSettingsSource, ...]:
        """Return the settings sources: init arguments, environment, TOML files, file secrets.

        dotenv_settings is accepted but not included in the returned tuple, so .env
        files are not consulted.
        """
        return (
            init_settings,
            env_settings,
            TomlConfigSettingsSource(settings_cls),
            file_secret_settings,
        )

    @field_validator("output", "simulator_output")
    @classmethod
    def create_path(cls, path: Path) -> Path:
        """Create the directory (and missing parents) when the field is validated; mode 0o750 applies to a newly created leaf."""
        path.mkdir(parents=True, exist_ok=True, mode=0o750)
        return path

    @model_validator(mode="after")
    def partial_update_vllm(self):
        """Complete the vLLM-related settings; does nothing unless is_vllm() is true.

        Relative default output paths are rebased onto `output`, the benchmark result
        directory is created, the benchmark command inherits host / port / model /
        served_model_name from the service command, and "range" target fields are
        converted to "enum".
        """
        if not is_vllm():
            return self
        # A path still equal to the class default means it was not customised.
        output = VllmConfig.model_fields["output"].default
        if self.vllm.output == output:
            self.vllm.output = self.output.joinpath(output)
        output = VllmBenchmarkConfig.model_fields["output_path"].default
        if self.vllm_benchmark.output_path == output:
            self.vllm_benchmark.output_path = self.output.joinpath(output)
        if self.vllm_benchmark.command.result_dir == VllmBenchmarkCommandConfig.model_fields["result_dir"].default:
            self.vllm_benchmark.command.result_dir = str(self.vllm_benchmark.output_path.joinpath("result"))
        Path(self.vllm_benchmark.command.result_dir).mkdir(parents=True, exist_ok=True, mode=0o750)
        # Keep the benchmark pointed at the service that is being tuned.
        self.vllm_benchmark.command.host = self.vllm.command.host
        self.vllm_benchmark.command.port = self.vllm.command.port
        self.vllm_benchmark.command.model = self.vllm.command.model
        self.vllm_benchmark.command.served_model_name = self.vllm.command.served_model_name
        if self.vllm.target_field:
            range_to_enum(self.vllm.target_field)
        if self.vllm_benchmark.target_field:
            range_to_enum(self.vllm_benchmark.target_field)
        return self

    @model_validator(mode="after")
    def partial_update_aisbench(self):
        """Complete the ais_bench settings; does nothing unless ais_bench_exists() is true.

        The relative default output path is rebased onto `output`, an empty command
        work_dir falls back to that output path, and "range" target fields are
        converted to "enum".
        """
        if not ais_bench_exists():
            return self
        output = AisBenchConfig.model_fields["output_path"].default
        if self.ais_bench.output_path == output:
            self.ais_bench.output_path = self.output.joinpath(output)
        if not self.ais_bench.command.work_dir:
            self.ais_bench.command.work_dir = str(self.ais_bench.output_path)
        if self.ais_bench.target_field:
            range_to_enum(self.ais_bench.target_field)
        return self

    @model_validator(mode="after")
    def partial_update_mindie(self):
        """Complete the data-storage and MindIE settings.

        The store_dir rebasing and the range-to-enum conversion of the MindIE target
        fields always run. The remaining steps run only when is_mindie() is true: the
        service config file must exist (otherwise an error is logged and validation
        continues) and must contain valid JSON.

        Raises:
            json.decoder.JSONDecodeError: the MindIE config file is not valid JSON.
        """
        if self.data_storage.store_dir == DataStorageConfig.model_fields["store_dir"].default:
            self.data_storage.store_dir = self.output.joinpath("store")
        range_to_enum(self.mindie.target_field)
        if not is_mindie():
            return self
        if not self.mindie.config_path.exists():
            logger.error(f"File Not Found. file: {self.mindie.config_path!r}")
            return self
        with open_file(self.mindie.config_path, "r") as f:
            try:
                # Parsed only to verify the file is well-formed; the result is discarded.
                json.load(f)
            except json.decoder.JSONDecodeError as e:
                logger.error(f"Failed in load {self.mindie.config_path!r}. error: {e}")
                raise e
        output = MindieConfig.model_fields["output"].default
        if self.mindie.output == output:
            self.mindie.output = self.output.joinpath(output)
        return self
# Optional factory installed through register_settings(); consulted by get_settings()
# when the settings object is first created.
custom_settings_func: Optional[Callable] = None
# Lazily created settings object; populated by the first get_settings() call.
settings = None
def get_settings() -> Settings:
    """
    Get the settings object

    The object is created on the first call and cached in the module-level
    `settings` variable; later calls return the cached object. When a factory has
    been registered (see `custom_settings_func`) it is used to build the object,
    otherwise `Settings()` is instantiated.

    Any callable is accepted as a factory — plain functions as well as
    functools.partial objects, bound methods and callable instances.

    Return: Settings() instance (or whatever the registered factory returns)
    """
    global settings
    if settings is None:
        # callable() instead of inspect.isfunction(): the latter rejects partials,
        # bound methods and callable objects, which silently discarded a registered
        # factory and fell back to Settings().
        if callable(custom_settings_func):
            settings = custom_settings_func()
        else:
            settings = Settings()
    return settings
def register_settings(func: Optional[Callable] = None) -> None:
    """
    Register custom settings - can provide a function to generate or provide new settings

    The function is stored in the module-level `custom_settings_func`. Calling this
    with no argument (or None) removes a previously registered function.

    Note: the cached `settings` object is not reset here, so a function registered
    after the settings object has been created does not replace that object.
    """
    global custom_settings_func
    custom_settings_func = func