"""Generate per-operator TC vs Profiling comparison data.
Reads TC chrome traces (from --performance-model profiling --chrome-trace)
for exact per-op invocation counts and latencies. No layer_multiplier needed.
Uses pre-extracted forward-pass trace CSVs from the repo for profiling
ground truth. All inputs are checked into the repository.
Usage (run from the repository root; all default paths are relative to it):
python3.10 tools/perf_data_analysis/generate_op_comparison.py \
[--trace-dir docs/perf_database/forward_pass_traces] \
[--data-dir <db_path>] \
[--output results/op_comparison.json]
Writes per-kernel-type rows (default: results/op_comparison.json) comparing
TC (HIT empirical + MISS analytic) against Profiling per forward pass.
"""
import argparse
import ast
import csv
import json
import logging
import os
import sys
from collections import defaultdict
from typing import Any

import yaml
# Default profiling-database directory (expected to hold op_mapping.yaml).
# Like every default path in this script it is relative to the repo root.
DEFAULT_DATA_DIR = "/".join(
    [
        "tensor_cast/performance_model/profiling_database/data",
        "ATLAS_800_A3_752T_128G_DIE",
        "vllm_ascend",
        "vllm0.18.0_torch2.9.0_cann8.5",
    ]
)
# Directory holding the pre-extracted forward-pass profiling CSVs.
TRACE_DIR = "docs/perf_database/forward_pass_traces"
# Status for kernel types seen in profiling but absent from TC
# (Chinese for "not modeled by TC").
TC_UNMODELED_STATUS = "TC\u672a\u5efa\u6a21"
# Separator between scenario name and title in the console report.
SCENARIO_SEPARATOR = "\N{EM DASH}"
# Console table column labels (Chinese text kept as escapes; glosses below).
TC_LATENCY_HEADER = "TC(\N{GREEK SMALL LETTER MU}s)"
TC_COUNT_HEADER = "TC\u6b21\u6570"  # "TC call count"
PROF_LATENCY_HEADER = "Prof(\N{GREEK SMALL LETTER MU}s)"
PROF_COUNT_HEADER = "Prof\u6b21\u6570"  # "Prof call count"
LATENCY_ERROR_HEADER = "\u5ef6\u8fdf\u8bef\u5dee"  # "latency error"
COUNT_ERROR_HEADER = "\u6b21\u6570\u8bef\u5dee"  # "count error"
PROF_SHARE_HEADER = "Prof\u5360\u6bd4"  # "Prof share"
STATUS_HEADER = "\u72b6\u6001"  # "status"
# TC chrome traces (from --performance-model profiling --chrome-trace),
# keyed by scenario name.
TC_TRACE_FILES = {
    "Qwen3 PF": "results/qwen3_pf_trace.json",
    "Qwen3 DC": "results/qwen3_dc_trace.json",
    "DSv3 PF": "results/dsv3_pf_trace.json",
    "DSv3 DC": "results/dsv3_dc_trace.json",
}
def get_scenarios(trace_dir: str = TRACE_DIR):
    """Return the comparison scenarios as a list of dicts.

    Each dict has keys ``name``, ``m6`` (M6 results JSON), ``trace_csv``
    (forward-pass profiling CSV located under *trace_dir*) and ``tc_trace``
    (TC chrome trace looked up in ``TC_TRACE_FILES``). Only ``trace_csv``
    depends on *trace_dir*.
    """
    # (scenario name, M6 results JSON, forward-pass CSV file name)
    scenario_table = (
        ("Qwen3 PF", "results/qwen3_prefill_m6.json", "qwen3-32b_pf_4112tok.csv"),
        ("Qwen3 DC", "results/qwen3_decode_m6.json", "qwen3-32b_dc_16tok.csv"),
        ("DSv3 PF", "results/dsv3_prefill_m6.json", "dsv3_pf_4099tok.csv"),
        ("DSv3 DC", "results/dsv3_decode_m6.json", "dsv3_dc_1tok.csv"),
    )
    return [
        {
            "name": scenario_name,
            "m6": m6_path,
            "trace_csv": f"{trace_dir}/{csv_name}",
            "tc_trace": TC_TRACE_FILES[scenario_name],
        }
        for scenario_name, m6_path, csv_name in scenario_table
    ]
def parse_profiling_by_type(trace_csv: str) -> dict[str, dict[str, Any]]:
    """Aggregate a forward-pass trace CSV into per-kernel-type totals.

    Args:
        trace_csv: Path to a pre-extracted forward-pass trace CSV (e.g. from
            docs/perf_database/forward_pass_traces/). Columns read: ``Type``,
            ``Duration(us)`` and, for hcom rows, ``Start Time(us)``.

    Returns:
        ``{kernel_type: {"total_us": float, "count": int}}``.

    Row handling:
      * rows with an empty or missing ``Type`` are ignored;
      * ``hcom_*`` kernels are grouped by ``(int(start_time), Type)`` and only
        the max duration per group is kept (parallel HCCL + AICPU streams),
        each group counting as one invocation;
      * kernel types ending in ``AicpuKernel`` are dropped;
      * every other kernel is summed directly.
    A blank or missing duration counts as 0 us. An unparsable hcom start time
    falls into start-time bucket 0.
    """
    stats: dict[str, dict[str, Any]] = defaultdict(
        lambda: {"total_us": 0.0, "count": 0}
    )
    hcom_groups: dict[tuple[int, str], float] = {}
    # newline="" is what the csv module requires; "utf-8-sig" also strips a
    # leading BOM, which would otherwise turn the first header into
    # "\ufeffType" and make every row look type-less.
    with open(trace_csv, encoding="utf-8-sig", newline="") as f:
        for row in csv.DictReader(f):
            # DictReader yields None for cells missing from short rows.
            kt = (row.get("Type") or "").strip()
            if not kt:
                continue
            dur_text = (row.get("Duration(us)") or "").strip()
            dur = float(dur_text) if dur_text else 0.0
            if kt.startswith("hcom_"):
                start_text = (row.get("Start Time(us)") or "0").strip()
                try:
                    start_key = int(float(start_text))
                except (ValueError, OverflowError):
                    start_key = 0
                key = (start_key, kt)
                hcom_groups[key] = max(hcom_groups.get(key, 0.0), dur)
            elif kt.endswith("AicpuKernel"):
                continue
            else:
                stats[kt]["total_us"] += dur
                stats[kt]["count"] += 1
    for (_, kt), dur in hcom_groups.items():
        stats[kt]["total_us"] += dur
        stats[kt]["count"] += 1
    return dict(stats)
def load_op_mapping(data_dir: str) -> dict:
    """Load ``op_mapping.yaml`` from *data_dir* for func_name -> entry lookup.

    Returns an empty dict when the file does not exist or its document is
    empty/falsy, so every op is treated as unmapped.

    Raises:
        ValueError: if the YAML document's top level is not a mapping (it
            would otherwise fail later with an opaque AttributeError).
    """
    mapping_path = os.path.join(data_dir, "op_mapping.yaml")
    if not os.path.exists(mapping_path):
        return {}
    # Binary mode lets PyYAML detect UTF-8/UTF-16 itself instead of decoding
    # with the platform's locale encoding.
    with open(mapping_path, "rb") as f:
        mapping = yaml.safe_load(f)
    if not mapping:
        return {}
    if not isinstance(mapping, dict):
        raise ValueError(
            f"{mapping_path}: expected a mapping at the top level, "
            f"got {type(mapping).__name__}"
        )
    return mapping
def _get_op_entry(func_name: str, op_mapping: dict) -> dict | None:
"""Look up op_mapping entry for a func_name."""
short = func_name.replace("torch.ops.", "")
entry = op_mapping.get(short) or op_mapping.get(func_name)
if entry and isinstance(entry, dict):
return entry
return None
def resolve_kernel_type(func_name: str, op_mapping: dict) -> str | None:
    """Map *func_name* to its ``kernel_type`` through op_mapping.

    Returns None when there is no usable entry for the op or the entry has
    no ``kernel_type`` key.
    """
    entry = _get_op_entry(func_name, op_mapping)
    if not entry:
        return None
    return entry.get("kernel_type")
def is_zero_cost(func_name: str, op_mapping: dict) -> bool:
    """Return True when op_mapping flags *func_name* with a truthy ``zero_cost``."""
    entry = _get_op_entry(func_name, op_mapping)
    return bool(entry and entry.get("zero_cost"))
def _find_profiling_pid(events: list[dict[str, Any]], trace_path: str) -> Any:
    """Return the pid of the profiling-model process among *events*.

    First looks for a ``process_name`` metadata ("M") event whose name
    contains "profiling" or "empirical" (case-insensitive). Failing that,
    falls back to the single pid that owns complete ("X") events and prints
    a warning to stderr.

    Raises:
        ValueError: if no process matches by name and the "X" events do not
            belong to exactly one pid.
    """
    for ev in events:
        if ev.get("ph") == "M" and ev.get("name") == "process_name":
            proc_name = ev.get("args", {}).get("name", "").lower()
            if "profiling" in proc_name or "empirical" in proc_name:
                return ev["pid"]
    pids = {ev["pid"] for ev in events if ev.get("ph") == "X"}
    if len(pids) != 1:
        raise ValueError(
            f"Cannot identify profiling PID in {trace_path}. Found PIDs: {pids}"
        )
    (sole_pid,) = pids
    print(
        f"Warning: profiling PID not found by name in {trace_path}, "
        f"using sole PID {sole_pid}",
        file=sys.stderr,
    )
    return sole_pid
def _parse_sub_kernel_durations(raw: Any) -> list[tuple[str, float]] | None:
    """Parse a ``sub_kernel_durations`` trace arg into (kernel_type, us) pairs.

    *raw* may be a Python-literal string such as
    ``"[('MatMul', 12.5), ('Add', 1.0)]"`` (parsed with ``ast.literal_eval``,
    never ``eval``) or an already-decoded JSON list of 2-element lists.
    Returns None when *raw* is empty or is not a sequence of
    ``(str, int | float)`` pairs; the caller then splits the op's duration
    evenly across its kernel types.
    """
    if not raw:
        return None
    if isinstance(raw, str):
        try:
            raw = ast.literal_eval(raw)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            return None
    if not isinstance(raw, (list, tuple)):
        return None
    pairs: list[tuple[str, float]] = []
    for item in raw:
        if not (isinstance(item, (list, tuple)) and len(item) == 2):
            return None
        name, dur = item
        if not isinstance(name, str) or not isinstance(dur, (int, float)):
            return None
        pairs.append((name, dur))
    return pairs or None
def _add_func_name(bucket: dict[str, Any], op_name: str) -> None:
    """Record *op_name* (with ``torch.ops.`` removed) in ``bucket["func_names"]`` once."""
    fn_short = op_name.replace("torch.ops.", "")
    if fn_short not in bucket["func_names"]:
        bucket["func_names"].append(fn_short)
def extract_tc_from_chrome_trace(
    trace_path: str, op_mapping: dict
) -> dict[str, dict[str, Any]]:
    """Extract per-kernel-type TC stats from a TC chrome trace.

    The trace holds one complete ("X") event per process_op call across all
    layers, so counts are exact invocation counts (no layer_multiplier).
    Only events of the profiling-model process are used, and ops flagged
    ``zero_cost`` in *op_mapping* are skipped.

    * HIT ops carry ``args["kernel_type"]`` (comma-separated for multi-kernel
      ops). Time comes from ``args["sub_kernel_durations"]`` when it parses;
      otherwise the event's ``dur`` is split evenly over the kernel types.
    * MISS ops (no ``kernel_type``, or ``"?"``) are resolved via
      *op_mapping*, falling back to the op name itself, and get the full
      ``dur``.

    Args:
        trace_path: Path to the chrome-trace JSON file.
        op_mapping: Mapping as returned by ``load_op_mapping``.

    Returns:
        ``{kernel_type: {"total_us": float, "count": int, "hit_count": int,
        "miss_count": int, "func_names": [str, ...],
        "status": "HIT" | "MISS" | "PARTIAL"}}``.

    Raises:
        ValueError: if the profiling process cannot be identified.
    """
    # JSON is UTF-8 by definition; don't decode with the locale's encoding.
    with open(trace_path, encoding="utf-8") as f:
        trace = json.load(f)
    events = trace.get("traceEvents", [])
    profiling_pid = _find_profiling_pid(events, trace_path)
    stats: dict[str, dict[str, Any]] = defaultdict(
        lambda: {
            "total_us": 0.0,
            "count": 0,
            "hit_count": 0,
            "miss_count": 0,
            "func_names": [],
        }
    )
    for ev in events:
        if ev.get("ph") != "X" or ev.get("pid") != profiling_pid:
            continue
        op_name = ev["name"]
        if is_zero_cost(op_name, op_mapping):
            continue
        dur_us = ev.get("dur", 0)
        args = ev.get("args", {})
        kt_from_args = args.get("kernel_type")

        if not kt_from_args or kt_from_args == "?":
            # MISS: attribute the whole event duration to the mapped kernel type.
            kt = resolve_kernel_type(op_name, op_mapping) or op_name
            bucket = stats[kt]
            bucket["miss_count"] += 1
            bucket["total_us"] += dur_us
            bucket["count"] += 1
            _add_func_name(bucket, op_name)
            continue

        # HIT: one or more kernel types reported by the profiling data source.
        sub_kts = [s.strip() for s in str(kt_from_args).split(",") if s.strip()]
        skd = _parse_sub_kernel_durations(args.get("sub_kernel_durations"))
        if skd:
            skd_names = {name for name, _ in skd}
            sub_kt_names = set(sub_kts)
            if skd_names != sub_kt_names:
                logging.getLogger(__name__).debug(
                    "sub_kernel_durations names %s != kernel_type names %s for %s",
                    skd_names,
                    sub_kt_names,
                    op_name,
                )
            hit_parts = skd
        else:
            per_sub_dur = dur_us / len(sub_kts) if sub_kts else dur_us
            hit_parts = [(sub_kt, per_sub_dur) for sub_kt in sub_kts]
        for sk_name, sk_dur in hit_parts:
            bucket = stats[sk_name]
            bucket["hit_count"] += 1
            bucket["total_us"] += sk_dur
            bucket["count"] += 1
            _add_func_name(bucket, op_name)

        # all_reduce ops whose kernel_type list omits hcom_allReduce_ still
        # add one hcom_allReduce_ invocation, but no time. NOTE(review):
        # presumably so TC invocation counts line up with the profiling hcom
        # kernels -- confirm intent.
        if "all_reduce" in op_name.lower() and "hcom_allReduce_" not in sub_kts:
            bucket = stats["hcom_allReduce_"]
            bucket["hit_count"] += 1
            bucket["count"] += 1
            _add_func_name(bucket, op_name)

    for kt_stats in stats.values():
        if kt_stats["miss_count"] > 0 and kt_stats["hit_count"] > 0:
            kt_stats["status"] = "PARTIAL"
        elif kt_stats["hit_count"] > 0:
            kt_stats["status"] = "HIT"
        else:
            kt_stats["status"] = "MISS"
    return dict(stats)
def _comparison_row(
    kt: str, tc_info: dict[str, Any], prof_info: dict[str, Any], n_fwd: float
) -> dict[str, Any]:
    """Return one output row comparing TC and per-forward profiling for *kt*.

    Profiling totals and counts are divided by *n_fwd* (both become 0 when
    ``n_fwd <= 0``). Error percentages are ``(TC - Prof) / Prof * 100`` and
    are reported as 0 when the profiling value is 0.
    """
    tc_total_us = tc_info["total_us"]
    tc_count = tc_info["count"]
    if n_fwd > 0:
        prof_per_fwd_us = prof_info["total_us"] / n_fwd
        prof_count = prof_info["count"] / n_fwd
    else:
        prof_per_fwd_us = 0
        prof_count = 0
    error_pct = (
        (tc_total_us - prof_per_fwd_us) / prof_per_fwd_us * 100
        if prof_per_fwd_us > 0
        else 0
    )
    count_error_pct = (
        (tc_count - prof_count) / prof_count * 100 if prof_count > 0 else 0
    )
    return {
        "kernel_type": kt,
        "tc_funcs": ", ".join(tc_info["func_names"][:3]),
        "tc_total_us": round(tc_total_us, 1),
        "tc_count": tc_count,
        "prof_per_fwd_us": round(prof_per_fwd_us, 1),
        "prof_count": int(round(prof_count)),
        "error_pct": round(error_pct, 1),
        "count_error_pct": round(count_error_pct, 1),
        "status": tc_info["status"],
    }
def build_comparison(scenario: dict, op_mapping: dict) -> list[dict]:
    """Build per-operator comparison rows: TC vs Profiling.

    TC data comes from the scenario's chrome trace (exact per-invocation
    counts); profiling data comes from its forward-pass trace CSV. Kernel
    types seen only in profiling get zero TC values and status
    ``TC_UNMODELED_STATUS``. Rows are sorted by the larger of TC and
    profiling latency, descending, with ties kept in kernel-type order.

    Args:
        scenario: One entry from ``get_scenarios``.
        op_mapping: Mapping as returned by ``load_op_mapping``.
    """
    with open(scenario["m6"], encoding="utf-8") as f:
        m6_data = json.load(f)
    # NOTE(review): the profiling CSV is described as a single forward pass,
    # yet its totals are divided by the M6 run's n_forward_passes (default
    # 1). Confirm the M6 files report 1 for these traces.
    n_fwd = m6_data.get("n_forward_passes", 1)
    tc_stats = extract_tc_from_chrome_trace(scenario["tc_trace"], op_mapping)
    prof_stats = parse_profiling_by_type(scenario["trace_csv"])

    # Read-only defaults for kernel types missing from one side.
    unmodeled = {
        "total_us": 0,
        "count": 0,
        "status": TC_UNMODELED_STATUS,
        "func_names": [],
    }
    no_profiling = {"total_us": 0, "count": 0}
    rows = [
        _comparison_row(
            kt,
            tc_stats.get(kt, unmodeled),
            prof_stats.get(kt, no_profiling),
            n_fwd,
        )
        for kt in sorted(tc_stats.keys() | prof_stats.keys())
    ]
    # Stable sort: equal keys keep the alphabetical kernel-type order above.
    rows.sort(
        key=lambda x: max(x["tc_total_us"], x["prof_per_fwd_us"]),
        reverse=True,
    )
    return rows
def _print_scenario_table(name: str, rows: list[dict], top_n: int = 15) -> None:
    """Print one scenario's comparison table to stdout.

    Shows the first *top_n* rows (the largest, given ``build_comparison``'s
    ordering) and a TOTAL line that sums over all rows. The "Prof share"
    column is each row's fraction of that profiling total.
    """
    total_tc = sum(r["tc_total_us"] for r in rows)
    total_prof = sum(r["prof_per_fwd_us"] for r in rows)
    print(f"\n{'=' * 115}")
    print(f"{name} {SCENARIO_SEPARATOR} TC vs Profiling per-forward")
    print(f"{'=' * 115}")
    print(
        f"{'kernel_type':<35} {TC_LATENCY_HEADER:>12} {TC_COUNT_HEADER:>7} "
        f"{PROF_LATENCY_HEADER:>12} {PROF_COUNT_HEADER:>8} "
        f"{LATENCY_ERROR_HEADER:>9} {COUNT_ERROR_HEADER:>9} "
        f"{PROF_SHARE_HEADER:>8} {STATUS_HEADER:<6}"
    )
    print("-" * 115)
    for r in rows[:top_n]:
        pf_pct = r["prof_per_fwd_us"] / total_prof * 100 if total_prof > 0 else 0
        print(
            f"{r['kernel_type']:<35} {r['tc_total_us']:>12.1f} "
            f"{r['tc_count']:>7d} "
            f"{r['prof_per_fwd_us']:>12.1f} {r['prof_count']:>8.0f} "
            f"{r['error_pct']:>8.1f}% {r['count_error_pct']:>8.1f}% "
            f"{pf_pct:>7.1f}% {r['status']:<6}"
        )
    print(f"{'TOTAL':<35} {total_tc:>12.1f} {'':>7} {total_prof:>12.1f}")
def main():
    """CLI entry point: compare TC vs profiling for every scenario.

    Scenarios whose input files are missing are reported as SKIP and left out
    of the output. All remaining results are written as JSON to --output.
    """
    parser = argparse.ArgumentParser(
        description="Generate per-operator TC vs Profiling comparison data"
    )
    parser.add_argument(
        "--trace-dir",
        default=TRACE_DIR,
        help=(
            "Directory with forward-pass trace CSVs "
            "(default: docs/perf_database/forward_pass_traces)"
        ),
    )
    parser.add_argument(
        "--data-dir",
        default=DEFAULT_DATA_DIR,
        help="Path to profiling database directory (with op_mapping.yaml)",
    )
    parser.add_argument(
        "--output",
        default="results/op_comparison.json",
        help="Output JSON path",
    )
    args = parser.parse_args()
    op_mapping = load_op_mapping(args.data_dir)
    all_results = {}
    for sc in get_scenarios(args.trace_dir):
        name = sc["name"]
        try:
            rows = build_comparison(sc, op_mapping)
        except FileNotFoundError as e:
            print(f"SKIP {name}: {e}")
            continue
        all_results[name] = rows
        _print_scenario_table(name, rows)
    os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
    # ensure_ascii=False writes non-ASCII text (e.g. the TC-unmodeled status
    # label) verbatim, so the file encoding must not depend on the locale.
    with open(args.output, "w", encoding="utf-8") as f:
        json.dump(all_results, f, indent=2, ensure_ascii=False)
    print(f"\nSaved to {args.output}")
if __name__ == "__main__":
    main()