import dataclasses
import shlex
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Union
# Option names below are in normalized form: leading dashes removed and the
# remaining dashes replaced by underscores.

# Flags: written without "=value" they parse as True and never consume the
# token that follows them.
_BOOL_OPTIONS = {
    "allow_graph_break",
    "compile",
    "compile_allow_graph_break",
    "decode",
    "disable_repetition",
    "dump_input_shapes",
    "enable_dispatch_ffn_combine",
    "enable_external_shared_experts",
    "enable_redundant_experts",
    "enable_sequence_parallel",
    "enable_shared_expert_tp",
    "host_external_shared_experts",
    "quantize_lmhead",
}

# Options whose textual value is converted with int().
_INT_OPTIONS = {
    "block_size",
    "context_length",
    "dp_size",
    "ep_size",
    "image_batch_size",
    "image_height",
    "image_width",
    "lmhead_dp_size",
    "lmhead_tp_size",
    "mlp_dp_size",
    "mlp_tp_size",
    "moe_dp_size",
    "moe_tp_size",
    "mxfp4_group_size",
    "num_devices",
    "num_hidden_layers_override",
    "num_mtp_tokens",
    "num_queries",
    "o_proj_dp_size",
    "o_proj_tp_size",
    "pp_size",
    "query_length",
    "tp_size",
}

# Options whose textual value is converted with float().
_FLOAT_OPTIONS = {
    "prefix_cache_hit_rate",
    "reserved_memory_gb",
}

# Options that may be given several times; their values are collected into a
# list in order of appearance instead of overwriting one another.
_REPEAT_OPTIONS = {
    "performance_model",
}
@dataclasses.dataclass(frozen=True)
class AdaptationContext:
    """Immutable record of a parsed simulation command.

    The instance is frozen, so attributes cannot be rebound after
    construction; the dictionaries it holds are ordinary mutable dicts.
    """

    # Token that the command parser identified as the model id.
    model_id: str
    # The command text collapsed onto a single line.
    raw_command: str
    # Normalized option name -> parsed value.
    normalized_args: Dict[str, Any]
    # Label -> path string for auxiliary files; empty unless supplied.
    artifacts: Dict[str, str] = dataclasses.field(default_factory=dict)
    # Defaults to 1; presumably a schema version for the serialized form.
    version: int = 1

    def to_dict(self) -> Dict[str, Any]:
        """Return the context as a plain dictionary.

        ``normalized_args`` and ``artifacts`` are shallow-copied, so adding
        or replacing keys in the result does not touch this instance.
        """
        args_copy = dict(self.normalized_args)
        artifacts_copy = dict(self.artifacts)
        payload: Dict[str, Any] = {"version": self.version}
        payload["model_id"] = self.model_id
        payload["raw_command"] = self.raw_command
        payload["normalized_args"] = args_copy
        payload["artifacts"] = artifacts_copy
        return payload
def _normalize_key(option: str) -> str:
    """Turn a command-line option name into an attribute-style key.

    All leading dashes are removed and every remaining dash becomes an
    underscore, e.g. ``"--tp-size"`` -> ``"tp_size"``.
    """
    without_prefix = option.lstrip("-")
    # Splitting on "-" and re-joining with "_" replaces every dash.
    return "_".join(without_prefix.split("-"))
def _coerce_value(key: str, value: str) -> Any:
    """Convert the textual *value* of option *key* to its Python type.

    Args:
        key: Normalized option name (see ``_normalize_key``).
        value: Raw text taken from the command line.

    Returns:
        * a ``bool`` when *key* is a boolean option and *value* is a
          recognized spelling (``1/true/yes/on`` or ``0/false/no/off``,
          case-insensitive); any other spelling is returned unchanged,
        * an ``int`` when *key* is in ``_INT_OPTIONS``,
        * a ``float`` when *key* is in ``_FLOAT_OPTIONS``,
        * *value* itself for every other option.

    Raises:
        ValueError: if an integer or float option carries text that cannot
            be converted; the message names the offending option.
    """
    if key in _BOOL_OPTIONS:
        # Without this, "--flag=false" would be stored as the non-empty
        # (hence truthy) string "false" and the flag would look enabled.
        spelled = value.strip().lower()
        if spelled in ("1", "true", "yes", "on"):
            return True
        if spelled in ("0", "false", "no", "off"):
            return False
        return value
    if key in _INT_OPTIONS:
        try:
            return int(value)
        except ValueError as err:
            raise ValueError(
                f"Option '{key}' expects an integer value, got {value!r}."
            ) from err
    if key in _FLOAT_OPTIONS:
        try:
            return float(value)
        except ValueError as err:
            raise ValueError(
                f"Option '{key}' expects a numeric value, got {value!r}."
            ) from err
    return value
def _find_model_id(tokens: List[str]) -> str:
    """Return the token that directly follows the generation entry point.

    The entry point is either one of the module names
    ``cli.inference.text_generate`` / ``cli.inference.video_generate`` or
    any token ending in ``text_generate.py`` / ``video_generate.py``. The
    first such token that has a successor wins.

    Raises:
        ValueError: if no entry-point token is followed by another token.
    """
    entry_modules = ("cli.inference.text_generate", "cli.inference.video_generate")
    entry_scripts = ("text_generate.py", "video_generate.py")
    # Pairing each token with its successor drops a trailing entry point
    # that has nothing after it.
    for current, following in zip(tokens, tokens[1:]):
        if current in entry_modules or current.endswith(entry_scripts):
            return following
    raise ValueError("Could not find model id in simulation command.")
def _iter_option_tokens(tokens: List[str], model_id: str) -> Iterable[str]:
    """Return the tokens that come after the first occurrence of *model_id*.

    When *model_id* is not among *tokens*, a copy of the whole token list
    is returned instead.
    """
    if model_id in tokens:
        first_option = tokens.index(model_id) + 1
    else:
        first_option = 0
    return tokens[first_option:]
def parse_simulation_command(command: str) -> AdaptationContext:
    """Parse a (possibly multi-line) simulation command into a context.

    The command is collapsed onto one line, tokenized with ``shlex.split``,
    and every ``--option`` found after the model id is recorded under its
    normalized name:

    * ``--name=value`` stores the coerced ``value``;
    * an option listed in ``_BOOL_OPTIONS`` stores ``True`` and never
      consumes the following token;
    * any other option consumes the following token as its value unless
      that token itself starts with ``--`` (or there is none), in which
      case ``True`` is stored;
    * options listed in ``_REPEAT_OPTIONS`` accumulate their values in a
      list, all other options keep the last value seen;
    * tokens that do not start with ``--`` and are not consumed as a value
      are ignored, as are tokens with an empty option name (a bare ``--``
      or ``--=value``).

    Args:
        command: Command text; lines may end with a shell continuation
            backslash.

    Returns:
        An ``AdaptationContext`` holding the model id, the single-line
        command and the parsed options.

    Raises:
        ValueError: if the command contains no tokens, the model id cannot
            be located, ``shlex.split`` rejects the text (for example an
            unclosed quote), or a value fails type coercion.
    """
    # Drop blank lines and trailing continuation backslashes, then join the
    # remaining lines with single spaces.
    raw_command = " ".join(
        line.strip().rstrip("\\")
        for line in command.strip().splitlines()
        if line.strip()
    )
    tokens = shlex.split(raw_command)
    if not tokens:
        raise ValueError("Simulation command is empty.")

    model_id = _find_model_id(tokens)
    option_tokens = list(_iter_option_tokens(tokens, model_id))
    normalized_args: Dict[str, Any] = {}

    index = 0
    while index < len(option_tokens):
        token = option_tokens[index]
        index += 1  # from here on, `index` points at the token after `token`
        if not token.startswith("--"):
            continue

        option, has_inline_value, raw_value = token.partition("=")
        key = _normalize_key(option)
        if not key:
            # A bare "--" or "--=value" has no option name; skipping it
            # avoids storing a meaningless "" entry.
            continue

        if has_inline_value:
            value = _coerce_value(key, raw_value)
        elif key in _BOOL_OPTIONS:
            value = True
        elif index < len(option_tokens) and not option_tokens[index].startswith("--"):
            value = _coerce_value(key, option_tokens[index])
            index += 1  # the value token has been consumed
        else:
            # No value available: treat the option as a flag.
            value = True

        if key in _REPEAT_OPTIONS:
            normalized_args.setdefault(key, []).append(value)
        else:
            normalized_args[key] = value

    return AdaptationContext(
        model_id=model_id,
        raw_command=raw_command,
        normalized_args=normalized_args,
    )
def load_command_text(path: Union[str, Path]) -> str:
    """Read the file at *path* and return its contents decoded as UTF-8."""
    command_path = Path(path)
    return command_path.read_text(encoding="utf-8")
def load_context_from_command_file(
    command_file: Union[str, Path],
    raw_insight_file: Optional[Union[str, Path]] = None,
    hints_file: Optional[Union[str, Path]] = None,
) -> AdaptationContext:
    """Parse the command stored in *command_file* and attach artifact paths.

    Args:
        command_file: File containing the simulation command text.
        raw_insight_file: Optional path recorded under the
            ``"raw_insight_file"`` artifact key.
        hints_file: Optional path recorded under the ``"hints_file"``
            artifact key.

    Returns:
        A new ``AdaptationContext`` equal to the parsed one except that its
        ``artifacts`` additionally hold the string form of every path that
        was supplied.
    """
    parsed = parse_simulation_command(load_command_text(command_file))
    supplied = (
        ("raw_insight_file", raw_insight_file),
        ("hints_file", hints_file),
    )
    merged = dict(parsed.artifacts)
    for label, location in supplied:
        if location is not None:
            merged[label] = str(location)
    # The context is frozen, so a modified copy is produced instead.
    return dataclasses.replace(parsed, artifacts=merged)
def apply_context_to_namespace(args: Any, context: AdaptationContext) -> None:
    """Copy the parsed command onto *args* as attributes, in place.

    ``args.model_id`` is assigned first; each normalized option is then set
    as an attribute of the same name, so an option that is itself called
    ``model_id`` replaces the value written first.
    """
    setattr(args, "model_id", context.model_id)
    parsed_options = context.normalized_args
    for name in parsed_options:
        setattr(args, name, parsed_options[name])