import json
import logging
import re
import unittest
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import torch
from parameterized import parameterized
from tensor_cast.core.input_generator import generate_inputs
from tensor_cast.core.model_runner import ModelRunner, ModelRunnerMetrics
from tensor_cast.core.quantization.datatypes import (
QuantizeAttentionAction,
QuantizeLinearAction,
)
from tensor_cast.core.user_config import UserInputConfig
logger = logging.getLogger(__name__)
CASE_DIR = Path(__file__).resolve().parent / "cases"
@dataclass
class BasePerfRegressionCase:
name: str
description: str
initial_time_s: float = 0.0
baseline_time_s: float = 0.0
initial_tolerance: float = 0.10
baseline_tolerance: float = 0.20
operator_top_n: int = 10
operator_tolerance: float = 0.10
operators: list[dict[str, float]] = None
@dataclass
class TextPerfRegressionCase(BasePerfRegressionCase):
user_input: Optional[UserInputConfig] = None
@dataclass
class VideoPerfRegressionCase(BasePerfRegressionCase):
device: str = ""
model_id: str = ""
seq_len: int = 0
batch_size: int = 0
height: int = 0
width: int = 0
frame_num: int = 0
sample_step: int = 0
dtype: str = "float16"
use_cfg: bool = False
world_size: int = 1
ulysses_size: int = 1
cfg_parallel: bool = False
quantize_linear_action: QuantizeLinearAction = QuantizeLinearAction.DISABLED
def _parse_total_time_s(table_result: str, model_name: str = "analytic") -> float:
pattern = rf"Total time for {model_name}:\s*([\d.]+)\s*(ns|us|ms|s)"
m = re.search(pattern, table_result)
if not m:
raise ValueError(f"Could not find 'Total time for {model_name}' in output:\n{table_result}")
value = float(m.group(1))
unit = m.group(2)
return value * {"ns": 1e-9, "us": 1e-6, "ms": 1e-3, "s": 1.0}[unit]
def _parse_top_operators(
table_result: str,
top_n: int = 10,
model_name: str = "analytic",
) -> list[tuple[str, float, int]]:
lines = table_result.split("\n")
data_started = False
operators: list[tuple[str, float, int]] = []
for line in lines:
if f"{model_name} total" in line and f"{model_name} avg" in line:
data_started = True
continue
if not data_started:
continue
if line.startswith("-") and operators:
break
if not line.strip():
continue
parts = line.split()
if len(parts) < 4:
continue
op_name = parts[0]
time_str = parts[1] if len(parts) > 1 else ""
calls_str = parts[3] if len(parts) > 3 else "0"
m = re.match(r"([\d.]+)\s*(ns|us|ms|s)", time_str)
if not m:
continue
value = float(m.group(1))
unit = m.group(2)
time_s = value * {"ns": 1e-9, "us": 1e-6, "ms": 1e-3, "s": 1.0}[unit]
num_calls = int(calls_str)
operators.append((op_name, time_s, num_calls))
return operators[:top_n]
def _load_baseline_operators(case_name: str) -> Optional[dict[str, dict[str, float]]]:
filepath = CASE_DIR / f"{case_name}.json"
if not filepath.exists():
return None
with open(filepath, encoding="utf-8") as f:
data = json.load(f)
operators = data.get("operators", [])
if not operators:
return None
return {
op["name"]: {
"total_time_s": op["total_time_s"],
"num_calls": op.get("num_calls", 0),
}
for op in operators
}
def _save_baseline_operators(case_name: str, operators: list[tuple[str, float, int]]):
filepath = CASE_DIR / f"{case_name}.json"
if not filepath.exists():
raise FileNotFoundError(f"Case file not found: {filepath}")
with open(filepath, encoding="utf-8") as f:
data = json.load(f)
data["operators"] = [
{"name": name, "total_time_s": time_s, "num_calls": num_calls} for name, time_s, num_calls in operators
]
with open(filepath, "w", encoding="utf-8") as f:
json.dump(data, f, indent=2, ensure_ascii=False)
def _load_perf_regression_cases() -> list[BasePerfRegressionCase]:
cases: list[BasePerfRegressionCase] = []
for path in sorted(CASE_DIR.glob("*.json")):
with open(path, encoding="utf-8") as f:
data = json.load(f)
case_type = data.pop("type", "text")
if case_type == "video":
data["quantize_linear_action"] = QuantizeLinearAction[data["quantize_linear_action"]]
if data.get("model_id") and not Path(data["model_id"]).is_absolute():
candidate = (Path(__file__).resolve().parent / data["model_id"]).resolve()
if not candidate.exists():
candidate = (Path(__file__).resolve().parents[2] / data["model_id"]).resolve()
data["model_id"] = str(candidate)
cases.append(VideoPerfRegressionCase(**data))
else:
ui_data = data.pop("user_input", {})
for key in ("quantize_linear_action", "quantize_attention_action"):
if key in ui_data and isinstance(ui_data[key], str):
if key == "quantize_linear_action":
ui_data[key] = QuantizeLinearAction[ui_data[key]]
else:
ui_data[key] = QuantizeAttentionAction[ui_data[key]]
user_input = UserInputConfig(**ui_data)
cases.append(TextPerfRegressionCase(user_input=user_input, **data))
return cases
PERF_REGRESSION_CASES = _load_perf_regression_cases()
class TestPerformanceRegression(unittest.TestCase):
@classmethod
def setUpClass(cls):
logging.basicConfig(
level=logging.INFO,
format="[%(levelname)s] [%(name)s] %(message)s",
)
cls._time_results: list[tuple] = []
cls._op_results: list[dict] = []
cls._op_detail_rows: list[tuple] = []
@classmethod
def tearDownClass(cls):
_print_time_summary(cls._time_results)
_print_operator_summary(cls._op_results, cls._op_detail_rows)
@parameterized.expand(
[(case.name, case) for case in PERF_REGRESSION_CASES],
skip_on_empty=True,
)
def test_performance_regression(self, name: str, case: BasePerfRegressionCase):
torch.compiler.reset()
if isinstance(case, VideoPerfRegressionCase):
import io
from contextlib import redirect_stderr, redirect_stdout
from cli.inference.video_generate import (
run_inference as video_run_inference,
)
buf = io.StringIO()
with redirect_stdout(buf), redirect_stderr(buf):
video_run_inference(
device=case.device,
model_id=case.model_id,
batch_size=case.batch_size,
seq_len=case.seq_len,
height=case.height,
width=case.width,
frame_num=case.frame_num,
sample_step=case.sample_step,
dtype=case.dtype,
use_cfg=case.use_cfg,
world_size=case.world_size,
ulysses_size=case.ulysses_size,
cfg_parallel=case.cfg_parallel,
quantize_linear_action=case.quantize_linear_action,
)
table_result = buf.getvalue()
actual_time_s = _parse_total_time_s(table_result)
else:
model_runner = ModelRunner(case.user_input)
result = model_runner.run_inference(generate_inputs_func=generate_inputs)
if isinstance(result, ModelRunnerMetrics):
table_result = result.table_result
else:
self.fail(f"Unexpected result type: {type(result)}")
actual_time_s = _parse_total_time_s(table_result)
logger.info(
"[%s] Actual total time: %.6fs (%.3fms)",
name,
actual_time_s,
actual_time_s * 1000,
)
initial_passed = True
baseline_passed = True
initial_diff_pct = 0.0
baseline_diff_pct = 0.0
if case.initial_time_s > 0.0:
initial_diff_pct = (actual_time_s - case.initial_time_s) / case.initial_time_s
initial_passed = abs(initial_diff_pct) <= case.initial_tolerance
if case.baseline_time_s > 0.0:
baseline_diff_pct = (actual_time_s - case.baseline_time_s) / case.baseline_time_s
baseline_passed = abs(baseline_diff_pct) <= case.baseline_tolerance
time_overall = initial_passed and baseline_passed
if case.initial_time_s == 0.0 and case.baseline_time_s == 0.0:
time_status = "NO_BASELINE"
elif not initial_passed and not baseline_passed:
time_status = "FAIL(BOTH)"
elif not initial_passed:
time_status = "FAIL(INIT)"
elif not baseline_passed:
time_status = "FAIL(BASE)"
else:
time_status = "PASS"
self.__class__._time_results.append(
(
name,
time_overall,
actual_time_s,
case.initial_time_s,
initial_diff_pct * 100,
case.baseline_time_s,
baseline_diff_pct * 100,
time_status,
)
)
actual_operators = _parse_top_operators(table_result, top_n=case.operator_top_n)
baseline_operators = _load_baseline_operators(name)
op_passed = None
if baseline_operators is None:
case_path = CASE_DIR / f"{name}.json"
self.__class__._op_results.append(
{
"case_name": name,
"op_passed": None,
"op_status": "NO_BASELINE",
"violations": [],
}
)
logger.warning(
"[%s] No operator baseline found in case file: %s. "
"Please generate baseline explicitly before running regression tests.",
name,
case_path,
)
else:
baseline_top_n = dict(
sorted(
baseline_operators.items(),
key=lambda item: item[1]["total_time_s"],
reverse=True,
)[: case.operator_top_n]
)
violations: list[str] = []
actual_op_names = {op_name for op_name, _, _ in actual_operators}
baseline_op_names = set(baseline_top_n.keys())
missing_from_actual = baseline_op_names - actual_op_names
for op_name in sorted(missing_from_actual):
bl = baseline_top_n[op_name]
violations.append(
f" MISSING OPERATOR: {op_name} (baseline={bl['total_time_s'] * 1000:.3f}ms, #calls={bl['num_calls']}) - Not found in current results"
)
self.__class__._op_detail_rows.append(
(
name,
op_name,
f"{bl['total_time_s'] * 1000:.3f}ms",
"MISSING",
"N/A",
"WARN",
str(bl["num_calls"]),
"-",
)
)
for op_name, actual_op_time, actual_num_calls in actual_operators:
baseline = baseline_top_n.get(op_name)
if baseline is None:
violations.append(
f" NEW OPERATOR: {op_name} (actual={actual_op_time * 1000:.3f}ms, #calls={actual_num_calls}) - Not in baseline Top {case.operator_top_n}"
)
self.__class__._op_detail_rows.append(
(
name,
op_name,
"N/A",
f"{actual_op_time * 1000:.3f}ms",
"NEW",
"WARN",
"-",
str(actual_num_calls),
)
)
continue
baseline_op_time = baseline["total_time_s"]
baseline_num_calls = baseline["num_calls"]
if baseline_op_time == 0.0:
self.__class__._op_detail_rows.append(
(
name,
op_name,
"0.000ms",
f"{actual_op_time * 1000:.3f}ms",
"N/A",
"PASS",
"0",
str(actual_num_calls),
)
)
continue
op_diff_pct = (actual_op_time - baseline_op_time) / baseline_op_time
time_fail = abs(op_diff_pct) > case.operator_tolerance
calls_mismatch = baseline_num_calls > 0 and actual_num_calls != baseline_num_calls
op_row_status = "PASS" if (not time_fail and not calls_mismatch) else "WARN"
if op_row_status == "WARN":
violations.append(
f" {op_name}: {op_diff_pct * 100:+.2f}% "
f"(baseline={baseline_op_time * 1000:.3f}ms, actual={actual_op_time * 1000:.3f}ms)"
)
if baseline_num_calls > 0 and actual_num_calls != baseline_num_calls:
violations.append(
f" {op_name}: #CALLS MISMATCH baseline={baseline_num_calls}, actual={actual_num_calls}"
)
calls_str = f"{baseline_num_calls}/{actual_num_calls}{'!' if baseline_num_calls > 0 and actual_num_calls != baseline_num_calls else ''}"
self.__class__._op_detail_rows.append(
(
name,
op_name,
f"{baseline_op_time * 1000:.3f}ms",
f"{actual_op_time * 1000:.3f}ms",
f"{op_diff_pct * 100:+.2f}%",
op_row_status,
str(baseline_num_calls),
calls_str,
)
)
op_passed = len(violations) == 0
op_status = "PASS" if op_passed else "WARN"
self.__class__._op_results.append(
{
"case_name": name,
"op_passed": op_passed,
"op_status": op_status,
"violations": violations,
}
)
logger.info(
"[%s] Top-%d Operator Comparison (tolerance: ±%.0f%%):",
name,
case.operator_top_n,
case.operator_tolerance * 100,
)
logger.info(
" %-50s %10s %10s %10s %8s",
"Operator",
"Baseline",
"Actual",
"Diff%",
"#Calls",
)
logger.info(" %s %s %s %s %s", "-" * 50, "-" * 10, "-" * 10, "-" * 10, "-" * 8)
for op_name, actual_op_time, actual_num_calls in actual_operators:
baseline = baseline_top_n.get(op_name)
if baseline is None:
logger.info(
" %-50s %10s %9.3fms %10s %8d",
op_name,
"N/A",
actual_op_time * 1000,
"NEW",
actual_num_calls,
)
elif baseline["total_time_s"] == 0.0:
logger.info(
" %-50s %9.3fms %9.3fms %10s %8d",
op_name,
0.0,
actual_op_time * 1000,
"N/A",
actual_num_calls,
)
else:
diff = (actual_op_time - baseline["total_time_s"]) / baseline["total_time_s"] * 100
calls_flag = "!" if baseline["num_calls"] > 0 and actual_num_calls != baseline["num_calls"] else ""
logger.info(
" %-50s %9.3fms %9.3fms %+9.2f%% %7d%s",
op_name,
baseline["total_time_s"] * 1000,
actual_op_time * 1000,
diff,
actual_num_calls,
calls_flag,
)
if not op_passed:
pass
if not time_overall:
msg_parts = [f"\n[{case.name}] Total time regression anomaly detected!"]
msg_parts.append(f" Description: {case.description}")
msg_parts.append(f" Actual: {actual_time_s * 1000:.3f}ms")
if case.initial_time_s > 0.0:
msg_parts.append(
f" vs Initial: {case.initial_time_s * 1000:.3f}ms "
f"({initial_diff_pct * 100:+.2f}%, tolerance: ±{case.initial_tolerance * 100:.0f}%) "
f"{'PASS' if initial_passed else 'FAIL'}"
)
if case.baseline_time_s > 0.0:
msg_parts.append(
f" vs Baseline: {case.baseline_time_s * 1000:.3f}ms "
f"({baseline_diff_pct * 100:+.2f}%, tolerance: ±{case.baseline_tolerance * 100:.0f}%) "
f"{'PASS' if baseline_passed else 'FAIL'}"
)
self.fail("\n".join(msg_parts))
if op_passed is False:
violations = self.__class__._op_results[-1].get("violations", [])
msg_parts = [
f"\n[{case.name}] Operator-level differences detected (WARNING - not treated as test failure):"
]
msg_parts.append(f" Description: {case.description}")
msg_parts.append(f" Top-{case.operator_top_n} Operator Comparison (...)")
for v in violations:
msg_parts.append(v)
logger.warning("\n".join(msg_parts))
def _emit(text: str):
print(text)
def _print_time_summary(results: list[tuple]):
if not results:
return
total = len(results)
passed = sum(1 for r in results if r[1] is True)
failed = sum(1 for r in results if r[1] is False)
no_baseline = sum(1 for r in results if r[1] is None)
_ = no_baseline
_emit("")
_emit("=" * 120)
_emit(" [Test 1] Total Time Regression Summary")
_emit("=" * 120)
header = (
f"{'Case':<36} {'Actual':>10} "
f"{'Init':>10} {'InitDiff':>10} "
f"{'Baseline':>10} {'BaseDiff':>10} "
f"{'Status':>10}"
)
_emit(header)
_emit("-" * 120)
for name, ok, actual, init_time, init_diff, base_time, base_diff, status in results:
actual_str = f"{actual * 1000:.3f}ms"
init_str = f"{init_time * 1000:.3f}ms" if init_time > 0 else "N/A"
init_diff_str = f"{init_diff:+.2f}%" if init_time > 0 else "N/A"
base_str = f"{base_time * 1000:.3f}ms" if base_time > 0 else "N/A"
base_diff_str = f"{base_diff:+.2f}%" if base_time > 0 else "N/A"
_emit(
f"{name:<36} {actual_str:>10} {init_str:>10} {init_diff_str:>10} {base_str:>10} {base_diff_str:>10} {status:>10}"
)
_emit("-" * 120)
_emit(f"Total: {total} | Passed: {passed} | Failed: {failed} | No Baseline: {no_baseline}")
_emit("=" * 120)
_emit("")
def _print_operator_summary(op_results: list[dict], op_detail_rows: list[tuple]):
if not op_detail_rows:
return
total = len(op_results)
passed = sum(1 for r in op_results if r["op_passed"] is True)
warned = sum(1 for r in op_results if r["op_passed"] is False)
no_baseline = sum(1 for r in op_results if r["op_passed"] is None)
if warned == 0:
_emit("*** All Operator Top-N Checks Passed (no warnings) ***")
_emit("")
return
_emit("=" * 145)
_emit(" [Test 2] Operator-Level Differences Summary (WARNING only, not treated as failure)")
_emit("=" * 145)
header = (
f"{'Case':<32} {'Operator':<44} {'Baseline':>10} {'Actual':>10} {'Diff':>10} {'#Calls':>12} {'Status':>10}"
)
_emit(header)
_emit("-" * 145)
for (
case_name,
op_name,
baseline_str,
actual_str,
diff_str,
status,
_,
calls_str,
) in op_detail_rows:
if status == "WARN":
_emit(
f"{case_name:<32} {op_name:<44} {baseline_str:>10} {actual_str:>10} {diff_str:>10} {calls_str:>12} {status:>10}"
)
_emit("-" * 145)
_emit(f"Total Cases: {total} | Passed: {passed} | Warnings: {warned} | No Baseline: {no_baseline}")
_emit("=" * 145)
_emit("")
_emit("*** Operator Differences Detected (WARNING only, test NOT failed) ***")