import importlib
import json
import sys
from pathlib import Path
from types import SimpleNamespace
import pytest
ST_PYTHON_DIR = Path(__file__).resolve().parents[2] / "st" / "python"
sys.path.insert(0, str(ST_PYTHON_DIR))
metric_assertions = importlib.import_module("metric.metric_assertions")
metric_log_health = importlib.import_module("metric.metric_log_health")
metric_runner = importlib.import_module("metric.metric_runner")
metric_scenarios = importlib.import_module("metric.metric_scenarios")
metric_smoke_utils = importlib.import_module("metric.metric_smoke_utils")
MetricExpectation = metric_assertions.MetricExpectation
assert_metric_expectations = metric_assertions.assert_metric_expectations
parse_prometheus_text = metric_assertions.parse_prometheus_text
wait_for_metric_expectations = metric_assertions.wait_for_metric_expectations
assert_metric_log_health = metric_log_health.assert_metric_log_health
scan_metric_log_health = metric_log_health.scan_metric_log_health
ScenarioContext = metric_scenarios.ScenarioContext
get_scenario = metric_scenarios.get_scenario
MetricSmokeError = metric_smoke_utils.MetricSmokeError
_parse_visible_devices = metric_smoke_utils._parse_visible_devices
PROMETHEUS_TEXT = """
# TYPE vllm_profiling_engine:generate:duration histogram
vllm_profiling_engine:generate:duration_sum{dp="0",phase="mixed",role="mixed"} 0.25
vllm_profiling_engine:generate:duration_count{dp="0",phase="mixed",role="mixed"} 1
# TYPE vllm_profiling_scheduler:duration histogram
vllm_profiling_scheduler:duration_sum{dp="0",phase="prefill",role="mixed"} 0.5
vllm_profiling_scheduler:duration_count{dp="0",phase="prefill",role="mixed"} 2
optional_nan NaN
optional_inf +Inf
"""
def test_parse_prometheus_text_preserves_names_labels_and_special_values():
samples = parse_prometheus_text(PROMETHEUS_TEXT)
generate = next(sample for sample in samples if sample.name.endswith("generate:duration_count"))
assert generate.value == 1
assert generate.labels == {"dp": "0", "phase": "mixed", "role": "mixed"}
assert any(sample.value != sample.value for sample in samples if sample.name == "optional_nan")
def test_expectation_supports_fallback_names_labels_and_rules(tmp_path):
report = assert_metric_expectations(
PROMETHEUS_TEXT,
[
MetricExpectation(
names=["old_name", "vllm_profiling_engine:generate:duration_count"],
rules=["exists", "finite", "gt:0"],
labels={"dp": "present", "phase": ["mixed", "decode"]},
)
],
report_path=tmp_path / "assertion_report.json",
)
assert len(report.passed) == 1
assert (tmp_path / "assertion_report.json").exists()
def test_histogram_requires_count_and_sum():
assert_metric_expectations(
PROMETHEUS_TEXT,
[
MetricExpectation(
names=["vllm_profiling_engine:generate:duration"],
histogram=True,
rules=["exists", "finite", "gte:0"],
labels={"dp": "0"},
)
],
)
with pytest.raises(MetricSmokeError):
assert_metric_expectations(
"metric_count 1\n",
[MetricExpectation(names=["metric"], histogram=True)],
)
def test_optional_and_forbidden_expectations():
report = assert_metric_expectations(
PROMETHEUS_TEXT,
[
MetricExpectation(names=["missing_optional"], required=False),
MetricExpectation(names=["must_not_exist"], forbidden=True),
],
)
assert len(report.skipped_optional) == 1
assert len(report.passed) == 1
def test_expectation_checks_distinct_label_values():
text = """
metric_total{dp="0"} 1
metric_total{dp="1"} 2
"""
assert_metric_expectations(
text,
[
MetricExpectation(
names=["metric_total"],
labels={"dp": "present"},
min_distinct_labels={"dp": 2},
)
],
)
def test_expectation_requires_metric_increase_from_baseline():
expectation = MetricExpectation(names=["request_count"], increase=True)
report = assert_metric_expectations(
"request_count 3\n",
[expectation],
baseline_text="request_count 2\n",
)
assert len(report.passed) == 1
with pytest.raises(MetricSmokeError):
assert_metric_expectations(
"request_count 2\n",
[expectation],
baseline_text="request_count 2\n",
)
def test_triggered_optional_metric_skips_when_not_increased():
report = assert_metric_expectations(
"eplb_update_count 2\n",
[
MetricExpectation(
names=["eplb_update_count"],
required=False,
increase=True,
triggered=True,
)
],
baseline_text="eplb_update_count 2\n",
)
assert len(report.skipped_optional) == 1
assert "should increase" in report.skipped_optional[0]["reason"]
def test_metric_polling_retries_until_expectation_passes():
samples = iter(["request_count 1\n", "request_count 2\n"])
calls = []
def fetcher():
calls.append(1)
return next(samples)
metrics_text, report = wait_for_metric_expectations(
fetcher=fetcher,
expectations=[MetricExpectation(names=["request_count"], increase=True)],
baseline_text="request_count 1\n",
timeout=1,
interval=0,
)
assert metrics_text == "request_count 2\n"
assert len(calls) == 2
assert len(report.passed) == 1
def test_metric_log_health_flags_only_metric_failures(tmp_path):
service_log = """
WARNING 06-11 NIXL is not available
[ms_service_metric.symbol] ERROR Failed to apply hook for module: boom
[ms_service_metric.metrics_manager] WARNING Failed to record metric scheduler:duration: bad labels
[ms_service_metric.symbol] WARNING No target available for symbol optional.module:func
"""
report = scan_metric_log_health(service_log)
assert len(report.fatal) == 2
assert len(report.warnings) == 1
with pytest.raises(MetricSmokeError):
assert_metric_log_health(service_log, report_path=tmp_path / "log_health_report.json")
assert (tmp_path / "log_health_report.json").exists()
def test_scenario_registry_and_vllm_args(tmp_path):
context = ScenarioContext(
model_path="/model",
devices=[0, 1],
port=8000,
artifact_dir=tmp_path,
)
basic = get_scenario("basic")
assert basic.build_vllm_args(context) == ["--enforce-eager", "--max-model-len", "2048"]
eplb = get_scenario("eplb")
eplb_args = eplb.build_vllm_args(context)
assert eplb_args[:5] == [
"--tensor-parallel-size",
"2",
"--enable-expert-parallel",
"--max-model-len",
"4096",
]
assert eplb.build_service_env(context) == {"DYNAMIC_EPLB": "true"}
eplb_config = json.loads(eplb_args[eplb_args.index("--additional-config") + 1])
assert eplb_config["eplb_config"]["expert_heat_collection_interval"] == 4
custom_eplb_context = ScenarioContext(
model_path="/model",
devices=[0, 1],
port=8000,
artifact_dir=tmp_path,
options={
"eplb_heat_collection_interval": 8,
"eplb_algorithm_execution_interval": 2,
"eplb_policy_type": 1,
"eplb_num_redundant_experts": 0,
"eplb_max_model_len": 8192,
},
)
custom_eplb_args = eplb.build_vllm_args(custom_eplb_context)
custom_eplb_config = json.loads(custom_eplb_args[custom_eplb_args.index("--additional-config") + 1])
assert custom_eplb_args[custom_eplb_args.index("--max-model-len") + 1] == "8192"
assert custom_eplb_config["eplb_config"] == {
"dynamic_eplb": True,
"expert_heat_collection_interval": 8,
"algorithm_execution_interval": 2,
"eplb_policy_type": 1,
"num_redundant_experts": 0,
}
dplb = get_scenario("dplb")
assert dplb.build_vllm_args(context, ["--trust-remote-code"]) == [
"--enforce-eager",
"--max-model-len",
"2048",
"--data-parallel-size",
"2",
"--data-parallel-hybrid-lb",
"--trust-remote-code",
]
with pytest.raises(MetricSmokeError):
get_scenario("unknown")
def test_scenario_expectations_cover_stable_and_triggered_metrics():
basic_expectations = get_scenario("basic").expectations()
eplb_expectations = get_scenario("eplb").expectations()
assert len([item for item in basic_expectations if item.required]) == 4
assert all(item.increase for item in basic_expectations if item.required)
assert len([item for item in basic_expectations if item.triggered]) == 1
assert len([item for item in eplb_expectations if item.required]) == 2
assert len([item for item in eplb_expectations if item.triggered]) == 3
def test_metric_runner_restarts_collection_before_service_start(monkeypatch, tmp_path):
events = []
class FakeScenario:
name = "basic"
startup_timeout = 1
metric_poll_timeout = 1
def preflight(self, _context, report):
report.pass_("scenario:basic", "ok")
def build_service_env(self, _context):
return {}
def build_vllm_args(self, _context, _extra_args):
return []
def expectations(self):
return []
def validate_startup(self, _service_output):
events.append("validate_startup")
def run_requests(self, _context):
events.append("run_requests")
return SimpleNamespace(get_output_tail=lambda _limit: "request log")
class FakeServer:
output_lines = ["service log"]
def __init__(self, **_kwargs):
events.append("server_init")
def ready_go(self):
events.append("ready_go")
return True
def get_output_tail(self, _limit=None):
return "service log"
def kill(self):
events.append("kill")
monkeypatch.setattr(metric_runner, "get_scenario", lambda _name: FakeScenario())
monkeypatch.setattr(metric_runner, "require_command", lambda name, _report: events.append(f"require:{name}"))
monkeypatch.setattr(metric_runner, "check_model_path", lambda *_args, **_kwargs: None)
monkeypatch.setattr(metric_runner, "check_npu", lambda devices, **_kwargs: devices)
monkeypatch.setattr(metric_runner, "choose_service_port", lambda *_args, **_kwargs: 8000)
monkeypatch.setattr(metric_runner, "prepare_prometheus_dir", lambda *_args, **_kwargs: str(tmp_path / "prom"))
monkeypatch.setattr(metric_runner, "install_metric_source", lambda **_kwargs: {})
monkeypatch.setattr(metric_runner, "collect_runtime_info", lambda: {})
monkeypatch.setattr(metric_runner, "restart_metric_collection", lambda: events.append("restart"))
monkeypatch.setattr(metric_runner, "ExecVLLMServer", FakeServer)
monkeypatch.setattr(metric_runner, "wait_http_ok", lambda _url: events.append("wait_http_ok"))
monkeypatch.setattr(metric_runner, "fetch_text", lambda _url: "metrics")
monkeypatch.setattr(
metric_runner,
"wait_for_metric_expectations",
lambda **_kwargs: ("metrics", SimpleNamespace(passed=[], failed=[], skipped_optional=[])),
)
monkeypatch.setattr(
metric_runner,
"assert_metric_log_health",
lambda *_args, **_kwargs: SimpleNamespace(fatal=[], warnings=[], scanned_metric_lines=0),
)
metric_runner.run_metric_scenario(
metric_runner.MetricRunnerConfig(
scenario="basic",
model_path="/model",
devices=[0],
workspace=str(tmp_path),
debug=True,
)
)
assert events.index("restart") < events.index("server_init")
assert events.index("restart") < events.index("ready_go")
assert events.index("wait_http_ok") < events.index("run_requests")
def test_explicit_devices_override_visible_devices(monkeypatch):
monkeypatch.setenv("ASCEND_RT_VISIBLE_DEVICES", "0")
assert _parse_visible_devices([0, 1]) == [0, 1]
assert _parse_visible_devices([]) == [0]