"""C9 comm data alignment: validate HCCL microbenchmark CSVs against CommAnalyticModel.
Loads per-op CSVs from an hccl/ directory and compares each measured Duration(us)
against the CommAnalyticModel analytic prediction for the same
(message_bytes, num_devices, topology_tier) combination.
Alignment criteria (configurable via --tolerance):
- PASS : 1/tolerance ≤ ratio ≤ tolerance (default tolerance=2.0)
- WARN : ratio outside tolerance but within tolerance*2
- FAIL : ratio outside tolerance*2
CSV filename → op_type mapping (matches op_mapping.yaml kernel_type):
hcom_allReduce_.csv → all_reduce
hcom_allGather_.csv → all_gather
hcom_reduceScatter_.csv → reduce_scatter
hcom_alltoallv_.csv → all_to_all
Usage:
python validate_comm_alignment.py --csv-dir ./hccl/v8.5/
python validate_comm_alignment.py --csv-dir ./hccl/v8.5/ --tolerance 1.5 --verbose
"""
import argparse
import csv
import math
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple
_REPO_ROOT = Path(__file__).resolve().parents[3]
if str(_REPO_ROOT) not in sys.path:
sys.path.insert(0, str(_REPO_ROOT))
from cli.logo import print_logo
@dataclass
class _Topology:
bandwidth_bytes_ps: float
latency_s: float
comm_efficiency: float = 0.7
_A3_TOPOLOGIES: Dict[int, _Topology] = {
0: _Topology(bandwidth_bytes_ps=196e9, latency_s=5.5e-6),
1: _Topology(bandwidth_bytes_ps=196e9, latency_s=0.5e-6),
2: _Topology(bandwidth_bytes_ps=224e9, latency_s=0.2e-6),
}
_CSV_TO_OP: Dict[str, str] = {
"hcom_allReduce_.csv": "all_reduce",
"hcom_allGather_.csv": "all_gather",
"hcom_reduceScatter_.csv": "reduce_scatter",
"hcom_alltoallv_.csv": "all_to_all",
}
_DEFAULT_TOLERANCE = 2.0
def analytic_predict_us(
op_type: str,
message_bytes: int,
num_devices: int,
topology_tier: int,
topologies: Optional[Dict[int, _Topology]] = None,
) -> float:
"""Return analytic latency estimate in microseconds.
Mirrors CommAnalyticModel ring/tree/recursive algorithm selection.
message_bytes semantics vary per op to match benchmark CSV conventions:
- all_reduce: total tensor size (formula uses 2*(n-1)*m/n)
- all_gather: per-rank chunk size (formula uses (n-1)*m)
- reduce_scatter: total tensor size (formula uses (n-1)*m/n)
Both sides (CSV measured values and analytic predictions) use the same
convention per op, so the ratio comparison remains valid.
"""
if topologies is None:
topologies = _A3_TOPOLOGIES
topo = topologies[topology_tier]
bw = topo.bandwidth_bytes_ps * topo.comm_efficiency
lat = topo.latency_s
n = num_devices
m = message_bytes
if n <= 1:
return 0.0
if op_type == "all_reduce":
time_ring = 2 * (n - 1) * lat + 2 * (n - 1) * m / n / bw
time_tree = 2 * math.log2(n) * lat + 2 * m / bw
return min(time_ring, time_tree) * 1e6
elif op_type == "all_gather":
time_ring = (n - 1) * lat + (n - 1) * m / bw
time_rec = math.log2(n) * lat + (n - 1) * m / bw
return min(time_ring, time_rec) * 1e6
elif op_type == "reduce_scatter":
time_ring = (n - 1) * lat + (n - 1) * m / n / bw
time_rec = math.log2(n) * lat + (n - 1) * m / n / bw
return min(time_ring, time_rec) * 1e6
elif op_type == "all_to_all":
time_pairwise = (n - 1) * lat + m / bw
time_bruck = math.log2(n) * lat + m / bw
return min(time_pairwise, time_bruck) * 1e6
else:
raise ValueError(f"Unknown op_type: {op_type!r}")
@dataclass
class AlignmentRow:
op_type: str
message_bytes: int
num_devices: int
topology_tier: int
measured_us: float
predicted_us: float
@property
def ratio(self) -> float:
if self.predicted_us == 0:
return float("inf")
return self.measured_us / self.predicted_us
def status(self, tolerance: float) -> str:
r = self.ratio
if 1.0 / tolerance <= r <= tolerance:
return "PASS"
if 1.0 / (tolerance * 2) <= r <= tolerance * 2:
return "WARN"
return "FAIL"
def load_comm_csv(csv_path: Path) -> List[dict]:
"""Load a microbenchmark comm CSV. Returns list of row dicts."""
rows = []
with csv_path.open(newline="") as f:
reader = csv.DictReader(f)
for row in reader:
rows.append(row)
return rows
def _parse_row(row: dict, op_type: str) -> Optional[AlignmentRow]:
"""Parse a CSV row into an AlignmentRow. Returns None on parse error."""
try:
message_bytes = int(row["message_bytes"])
num_devices = int(row["num_devices"])
topology_tier = int(row.get("topology_tier", 1))
duration_col = "Average Duration(us)" if "Average Duration(us)" in row else "Duration(us)"
measured_us = float(row[duration_col])
except (KeyError, ValueError):
return None
predicted_us = analytic_predict_us(op_type, message_bytes, num_devices, topology_tier)
return AlignmentRow(
op_type=op_type,
message_bytes=message_bytes,
num_devices=num_devices,
topology_tier=topology_tier,
measured_us=measured_us,
predicted_us=predicted_us,
)
@dataclass
class AlignmentReport:
rows: List[AlignmentRow]
tolerance: float
@property
def pass_count(self) -> int:
return sum(1 for r in self.rows if r.status(self.tolerance) == "PASS")
@property
def warn_count(self) -> int:
return sum(1 for r in self.rows if r.status(self.tolerance) == "WARN")
@property
def fail_count(self) -> int:
return sum(1 for r in self.rows if r.status(self.tolerance) == "FAIL")
@property
def mean_ratio(self) -> float:
if not self.rows:
return float("nan")
return sum(r.ratio for r in self.rows) / len(self.rows)
@property
def max_ratio(self) -> float:
if not self.rows:
return float("nan")
return max(r.ratio for r in self.rows)
@property
def min_ratio(self) -> float:
if not self.rows:
return float("nan")
return min(r.ratio for r in self.rows)
def ok(self) -> bool:
return self.fail_count == 0
def validate_csv(csv_path: Path, op_type: str, tolerance: float) -> AlignmentReport:
"""Validate a single comm CSV against analytic predictions."""
raw_rows = load_comm_csv(csv_path)
alignment_rows = []
for raw in raw_rows:
row = _parse_row(raw, op_type)
if row is not None:
alignment_rows.append(row)
return AlignmentReport(rows=alignment_rows, tolerance=tolerance)
def validate_directory(
csv_dir: Path,
tolerance: float = _DEFAULT_TOLERANCE,
) -> Tuple[Dict[str, AlignmentReport], bool]:
"""Validate all known comm CSVs in csv_dir.
Returns (reports_by_op, all_ok).
"""
reports: Dict[str, AlignmentReport] = {}
all_ok = True
for filename, op_type in _CSV_TO_OP.items():
csv_path = csv_dir / filename
if not csv_path.exists():
print(f" [SKIP] {filename} not found in {csv_dir}")
continue
report = validate_csv(csv_path, op_type, tolerance)
reports[op_type] = report
if not report.ok():
all_ok = False
return reports, all_ok
def _print_report(op_type: str, report: AlignmentReport, verbose: bool) -> None:
print(f"\n{'='*70}")
print(f" {op_type} ({len(report.rows)} rows, tolerance={report.tolerance:.1f}x)")
print(f" PASS={report.pass_count} WARN={report.warn_count} FAIL={report.fail_count}")
if report.rows:
print(f" ratio: mean={report.mean_ratio:.2f}x min={report.min_ratio:.2f}x max={report.max_ratio:.2f}x")
print(f"{'='*70}")
if verbose or report.fail_count > 0 or report.warn_count > 0:
header = f" {'status':<6} {'msg_bytes':>12} {'n_dev':>5} {'tier':>4} {'measured':>10} {'predicted':>10} {'ratio':>6}"
print(header)
print(" " + "-" * (len(header) - 2))
for row in report.rows:
status = row.status(report.tolerance)
if not verbose and status == "PASS":
continue
print(
f" {status:<6} {row.message_bytes:>12} {row.num_devices:>5} "
f"{row.topology_tier:>4} {row.measured_us:>10.2f} "
f"{row.predicted_us:>10.2f} {row.ratio:>6.2f}x"
)
def main() -> None:
parser = argparse.ArgumentParser(
description="C9 comm data alignment: validate HCCL CSVs vs CommAnalyticModel.",
)
parser.add_argument(
"--csv-dir", required=True,
help="Directory containing hcom_*.csv microbenchmark files",
)
parser.add_argument(
"--tolerance", type=float, default=_DEFAULT_TOLERANCE,
help=f"Acceptable ratio range [1/tol, tol] (default: {_DEFAULT_TOLERANCE})",
)
parser.add_argument(
"--verbose", action="store_true",
help="Print all rows, not just WARN/FAIL",
)
args = parser.parse_args()
print_logo()
csv_dir = Path(args.csv_dir)
if not csv_dir.is_dir():
print(f"ERROR: {csv_dir} is not a directory", file=sys.stderr)
sys.exit(1)
print(f"C9 Comm Alignment Validation")
print(f" csv_dir : {csv_dir}")
print(f" tolerance : {args.tolerance:.1f}x")
reports, all_ok = validate_directory(csv_dir, tolerance=args.tolerance)
for op_type, report in reports.items():
_print_report(op_type, report, verbose=args.verbose)
print(f"\n{'='*70}")
total_rows = sum(len(r.rows) for r in reports.values())
total_pass = sum(r.pass_count for r in reports.values())
total_warn = sum(r.warn_count for r in reports.values())
total_fail = sum(r.fail_count for r in reports.values())
print(f" TOTAL: {total_rows} rows PASS={total_pass} WARN={total_warn} FAIL={total_fail}")
print(f" Result: {'OK' if all_ok else 'FAILED'}")
print(f"{'='*70}")
sys.exit(0 if all_ok else 1)
if __name__ == "__main__":
main()