"""Tests for generate_comm_microbench.py.
Test strategy
-------------
Tests are split into two categories:
1. Pure-logic tests (no NPU required) — run in any environment:
resolve_topology_tier, build_group_for_tier, _apply_dispatch_overhead,
_active_iters_for_msg, _build_run_op (CPU path), _parse_kernel_comm_duration,
_iter_configs, TestMainKernelPath, TestMainControlFlow, TestCommMicrobenchCli,
and the run_comm_bench.sh smoke tests (TestRunCommBenchShellMultiNode, needs bash).
2. NPU integration tests — marked @pytest.mark.npu, skipped by default
(run with: pytest -m npu).
These verify real hardware behavior: actual kernel Duration values,
profiler output format, no_sync pipeline overlap, end-to-end CSV output.
They require torch_npu and a physical NPU device (world_size=1, single card).
torch.distributed is initialized automatically inside each test
(hccl backend, rank=0, world_size=1).
"""
import csv
import inspect
import sys
from pathlib import Path
from unittest import mock
import pytest
pytest.importorskip("torch", reason="torch not installed")
from tools.perf_data_collection.comm_bench.generate_comm_microbench import (
_DISPATCH_OVERHEAD,
PROFILER_ACTIVE_ITERS,
PROFILER_ACTIVE_ITERS_LARGE,
PROFILER_LARGE_MSG_THRESHOLD,
_active_iters_for_msg,
_apply_dispatch_overhead,
_build_run_op,
_parse_kernel_comm_duration,
_run_bench_event,
_run_bench_kernel,
build_group_for_tier,
resolve_topology_tier,
run_benchmark,
)
class TestResolveTopologyTier:
    """Check tier resolution against the CommAnalyticModel logic for ATLAS_800_A3."""

    GRID = [48, 8, 2]

    def _tier_of(self, ranks):
        """Resolve the topology tier of *ranks* on the class-level grid."""
        return resolve_topology_tier(list(ranks), self.GRID)

    def test_nd16_intra_pod_tier1(self):
        tier = self._tier_of(range(16))
        assert tier == 1

    def test_nd8_intra_pod_tier1(self):
        tier = self._tier_of(range(8))
        assert tier == 1

    def test_nd2_die_level_tier2(self):
        tier = self._tier_of((0, 1))
        assert tier == 2

    def test_nd4_spans_nodes_tier1(self):
        tier = self._tier_of((0, 1, 2, 3))
        assert tier == 1

    def test_nd128_inter_pod_tier0(self):
        tier = self._tier_of(range(128))
        assert tier == 0

    def test_single_rank_returns_innermost(self):
        # A one-rank group is expected to resolve to tier 2.
        tier = self._tier_of((5,))
        assert tier == 2
class TestBuildGroupForTier:
    """Check that build_group_for_tier anchors groups correctly to the requested tier."""

    GRID = [48, 8, 2]

    def test_tier1_nd16_from_rank0(self):
        expected = list(range(16))
        assert build_group_for_tier(0, 16, 1, self.GRID) == expected

    def test_tier1_nd8_from_rank0(self):
        expected = list(range(8))
        assert build_group_for_tier(0, 8, 1, self.GRID) == expected

    def test_tier2_nd2_from_rank0(self):
        expected = [0, 1]
        assert build_group_for_tier(0, 2, 2, self.GRID) == expected

    def test_group_resolves_back_to_same_tier(self):
        """A group built for a tier must resolve back to that same tier."""
        cases = [(16, 1), (8, 1), (2, 2)]
        for num_devices, wanted_tier in cases:
            ranks = build_group_for_tier(0, num_devices, wanted_tier, self.GRID)
            resolved = resolve_topology_tier(ranks, self.GRID)
            assert resolved == wanted_tier

    def test_exceeds_span_raises(self):
        """Requesting 4 devices at tier 2 must raise ValueError."""
        with pytest.raises(ValueError, match="exceeds span size"):
            build_group_for_tier(0, 4, 2, self.GRID)
class TestApplyDispatchOverhead:
    """Check that _apply_dispatch_overhead corrects Duration(us) and bandwidth."""

    def _row(self, op_type, nd, duration_us, msg_bytes=1048576):
        """Build a synthetic result row; *op_type* is accepted but not stored."""
        bandwidth = round(msg_bytes / (duration_us * 1e-6) / 1e9, 2)
        row = {}
        row["message_bytes"] = msg_bytes
        row["num_devices"] = nd
        row["dtype"] = "DT_BF16"
        row["topology_tier"] = 1
        row["Duration(us)"] = duration_us
        row["bandwidth_gbps"] = bandwidth
        return row

    def test_all_known_entries_applied(self):
        """Every entry in _DISPATCH_OVERHEAD must increase Duration(us) by its value."""
        base = 100.0
        for key, overhead in _DISPATCH_OVERHEAD.items():
            op_type, nd = key
            corrected = _apply_dispatch_overhead(self._row(op_type, nd, base), op_type)
            assert corrected["Duration(us)"] == pytest.approx(100.0 + overhead, abs=0.01), (
                f"op={op_type} nd={nd}: expected {100.0 + overhead}"
            )

    def test_bandwidth_recalculated_after_overhead(self):
        """bandwidth_gbps must be recomputed from the corrected duration."""
        source_row = self._row("all_gather", 16, 100.0, msg_bytes=1048576)
        corrected = _apply_dispatch_overhead(source_row, "all_gather")
        new_duration = 100.0 + _DISPATCH_OVERHEAD[("all_gather", 16)]
        wanted_bw = round(1048576 / (new_duration * 1e-6) / 1e9, 2)
        assert corrected["bandwidth_gbps"] == pytest.approx(wanted_bw, abs=0.01)

    def test_no_overhead_for_missing_op(self):
        """For all_to_all the very same row object is expected back."""
        source_row = self._row("all_to_all", 16, 100.0)
        assert _apply_dispatch_overhead(source_row, "all_to_all") is source_row

    def test_no_overhead_for_unknown_nd(self):
        """For all_reduce with nd=4 the very same row object is expected back."""
        source_row = self._row("all_reduce", 4, 100.0)
        assert _apply_dispatch_overhead(source_row, "all_reduce") is source_row

    def test_original_row_not_mutated(self):
        """Applying the correction must leave the input row's duration untouched."""
        source_row = self._row("all_reduce", 16, 100.0)
        duration_before = source_row["Duration(us)"]
        _apply_dispatch_overhead(source_row, "all_reduce")
        assert source_row["Duration(us)"] == duration_before
class TestActiveItersForMsg:
    """Check which profiler active-iteration count is chosen around the size threshold."""

    def test_small_msg_returns_full_active(self):
        iters = _active_iters_for_msg(65536)
        assert iters == PROFILER_ACTIVE_ITERS

    def test_just_below_threshold_returns_full_active(self):
        just_below = PROFILER_LARGE_MSG_THRESHOLD - 1
        assert _active_iters_for_msg(just_below) == PROFILER_ACTIVE_ITERS

    def test_at_threshold_returns_one(self):
        # The threshold itself is expected to take the large-message count.
        iters = _active_iters_for_msg(PROFILER_LARGE_MSG_THRESHOLD)
        assert iters == PROFILER_ACTIVE_ITERS_LARGE

    def test_above_threshold_returns_one(self):
        well_above = PROFILER_LARGE_MSG_THRESHOLD * 4
        assert _active_iters_for_msg(well_above) == PROFILER_ACTIVE_ITERS_LARGE
class TestBuildRunOp:
    """Check that _build_run_op yields callable closures for every op type (CPU path)."""

    def _group(self, nd):
        """Placeholder group factory; always returns None."""
        return None

    def test_all_ops_return_callable(self):
        """Each supported op type must produce a callable."""
        ranks = list(range(4))
        for op_type in ("all_reduce", "all_gather", "reduce_scatter", "all_to_all"):
            run_op = _build_run_op(
                op_type,
                65536,
                "torch.bfloat16",
                "cpu",
                group=None,
                group_ranks=ranks,
            )
            assert callable(run_op), f"{op_type} should return a callable"

    def test_all_reduce_tensor_shape(self):
        """A 65536-byte bfloat16 all_reduce must capture a tensor of 32768 elements."""
        import torch

        run_op = _build_run_op(
            "all_reduce",
            65536,
            "torch.bfloat16",
            "cpu",
            group=None,
            group_ranks=[0],
        )
        # Collect every tensor captured by the closure.
        captured = set()
        for cell in run_op.__closure__:
            if not hasattr(cell, "cell_contents"):
                continue
            if isinstance(cell.cell_contents, torch.Tensor):
                captured.add(cell.cell_contents)
        shapes = [tensor.shape for tensor in captured]
        wanted = torch.Size([32768])
        found = any(shape == wanted for shape in shapes)
        assert found, f"Expected tensor of shape [32768], got {shapes}"
def _npu_dist_init():
    """Initialize torch.distributed (hccl, world_size=1) on npu:0 if not already done.

    A FileStore is used so no MASTER_ADDR/MASTER_PORT env vars are needed.
    The store file lives in a fresh per-process temporary directory, so a
    stale store left by an earlier or crashed run is never picked up.

    Returns:
        tuple: ``(rank, world_size, device)`` where ``device`` is ``"npu:<rank>"``.
    """
    import atexit
    import os
    import shutil
    import tempfile

    import torch
    import torch.distributed as dist
    import torch_npu  # noqa: F401  (unused directly; presumably imported for its side effects)

    if not dist.is_initialized():
        torch.npu.set_device(0)
        # Unique directory per process avoids colliding with a leftover store file.
        store_dir = tempfile.mkdtemp(prefix="npu_test_store_")
        atexit.register(shutil.rmtree, store_dir, ignore_errors=True)
        store = dist.FileStore(os.path.join(store_dir, "store"), 1)
        dist.init_process_group(backend="hccl", rank=0, world_size=1, store=store)

    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = f"npu:{rank}"
    return rank, world_size, device
def test_run_bench_profiler_batch_signature():
    """_run_bench_profiler_batch must expose parse_fn=None and no_sync=False parameters."""
    from tools.perf_data_collection.comm_bench.generate_comm_microbench import (
        _run_bench_profiler_batch,
    )

    params = inspect.signature(_run_bench_profiler_batch).parameters

    assert "parse_fn" in params
    assert params["parse_fn"].default is None

    assert "no_sync" in params
    assert params["no_sync"].default is False
@pytest.mark.npu
def test_run_bench_event_returns_positive_duration():
    """On a real NPU, _run_bench_event must return a positive float duration (µs)."""
    import torch
    import torch.distributed as dist

    _rank, _world_size, device = _npu_dist_init()
    payload = torch.zeros(1024, dtype=torch.bfloat16, device=device)

    def run_op():
        dist.all_reduce(payload)

    duration = _run_bench_event(run_op, is_npu=True)

    assert isinstance(duration, float)
    assert duration > 0.0, f"Expected positive duration, got {duration}"
@pytest.mark.npu
def test_run_bench_kernel_leader_returns_positive_duration():
    """On a real NPU (world_size=1), the leader gets a positive duration from _run_bench_kernel."""
    import torch
    import torch.distributed as dist

    _rank, _world_size, device = _npu_dist_init()
    payload = torch.zeros(1024, dtype=torch.bfloat16, device=device)

    def run_op():
        dist.all_reduce(payload)

    duration = _run_bench_kernel(run_op, "all_reduce", is_npu=True, is_leader=True)

    assert duration is not None, "Leader should return a duration"
    assert duration > 0.0, f"Expected positive duration, got {duration}"
@pytest.mark.npu
def test_run_bench_kernel_no_sync_returns_positive_duration():
    """With no_sync=True (HCCL pipeline overlap), _run_bench_kernel still returns a positive duration."""
    import torch
    import torch.distributed as dist

    _rank, _world_size, device = _npu_dist_init()
    payload = torch.zeros(1024, dtype=torch.bfloat16, device=device)

    def run_op():
        dist.all_reduce(payload)

    duration = _run_bench_kernel(
        run_op,
        "all_reduce",
        is_npu=True,
        is_leader=True,
        no_sync=True,
    )

    assert duration is not None, "Leader (no_sync) should return a duration"
    assert duration > 0.0, f"Expected positive duration, got {duration}"
@pytest.mark.npu
def test_run_bench_profiler_batch_small_msg():
    """Small message (<512KB): the leader gets a dict holding a positive duration."""
    import torch.distributed as dist

    from tools.perf_data_collection.comm_bench.generate_comm_microbench import (
        _run_bench_profiler_batch,
    )

    _rank, world_size, device = _npu_dist_init()
    size = 65536
    all_ranks = list(range(world_size))

    durations = _run_bench_profiler_batch(
        op_type="all_reduce",
        msg_bytes_list=[size],
        dtype_str="torch.bfloat16",
        device=device,
        group=dist.group.WORLD,
        group_ranks=all_ranks,
        is_npu=True,
        is_leader=True,
        parse_fn=None,
        no_sync=True,
    )

    assert isinstance(durations, dict), "Should return a dict"
    assert size in durations, f"Result missing key {size}"
    measured = durations[size]
    assert measured > 0.0, f"Expected positive duration, got {measured}"
@pytest.mark.npu
def test_run_bench_profiler_batch_large_msg():
    """Large message (>=512KB, active=1 path): the result holds a positive duration."""
    import torch.distributed as dist

    from tools.perf_data_collection.comm_bench.generate_comm_microbench import (
        _run_bench_profiler_batch,
    )

    _rank, world_size, device = _npu_dist_init()
    # The threshold itself is used as the message size for the large path.
    size = PROFILER_LARGE_MSG_THRESHOLD
    all_ranks = list(range(world_size))

    durations = _run_bench_profiler_batch(
        op_type="all_reduce",
        msg_bytes_list=[size],
        dtype_str="torch.bfloat16",
        device=device,
        group=dist.group.WORLD,
        group_ranks=all_ranks,
        is_npu=True,
        is_leader=True,
        parse_fn=None,
        no_sync=True,
    )

    assert isinstance(durations, dict), "Should return a dict"
    assert size in durations, f"Result missing key {size}"
    measured = durations[size]
    assert measured > 0.0, f"Expected positive duration, got {measured}"
@pytest.mark.npu
def test_run_benchmark_kernel_mode_writes_csv(tmp_path):
    """Kernel mode must return a positive result row and write exactly one CSV row."""
    _rank, world_size, _device = _npu_dist_init()
    ranks = list(range(world_size))
    tier = resolve_topology_tier(ranks, [48, 8, 2])
    out_file = tmp_path / "out.csv"
    csv_path = str(out_file)

    row = run_benchmark(
        op_type="all_reduce",
        message_bytes=65536,
        group_ranks=ranks,
        topology_tier=tier,
        dtype_str="torch.bfloat16",
        output_csv=csv_path,
        bench_mode="kernel",
    )

    # Returned row.
    assert row is not None, "Should return a result dict"
    assert row["Duration(us)"] > 0.0
    assert row["bandwidth_gbps"] > 0.0
    assert row["message_bytes"] == 65536

    # Persisted CSV.
    assert Path(csv_path).exists(), "CSV file should have been written"
    with open(csv_path, encoding="utf-8") as handle:
        written = [record for record in csv.DictReader(handle)]
    assert len(written) == 1
    assert float(written[0]["Duration(us)"]) > 0.0
def _write_kernel_details(path: Path, rows: list):
"""Write a minimal kernel_details.csv with Type, Name, Duration(us) columns."""
with path.open("w", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=["Type", "Name", "Duration(us)"])
writer.writeheader()
writer.writerows(rows)
class TestParseKernelCommDuration:
    """Check the CSV parsing and AivKernel deduplication of _parse_kernel_comm_duration."""

    @staticmethod
    def _entry(kind, name, duration):
        """Build one kernel_details.csv row dict."""
        return {"Type": kind, "Name": name, "Duration(us)": duration}

    def _make_prof_dir(self, tmp_path, rows):
        """Create a fake profiler output directory with kernel_details.csv."""
        _write_kernel_details(tmp_path / "kernel_details.csv", rows)
        return str(tmp_path)

    def test_returns_durations_for_matching_op(self, tmp_path):
        """hcom_allReduce_ rows without AivKernel should be returned."""
        prof_dir = self._make_prof_dir(
            tmp_path,
            [
                self._entry("hcom_allReduce_", "hcom_allReduce_0", "100.0"),
                self._entry("hcom_allReduce_", "hcom_allReduce_1", "120.0"),
            ],
        )
        durations = _parse_kernel_comm_duration(prof_dir, "all_reduce")
        assert durations == [100.0, 120.0]

    def test_excludes_aivkernel_rows(self, tmp_path):
        """Rows with 'AivKernel' in Name must be filtered out (deduplication)."""
        prof_dir = self._make_prof_dir(
            tmp_path,
            [
                self._entry("hcom_allReduce_", "hcom_allReduce_0", "100.0"),
                self._entry("hcom_allReduce_", "AivKernel_hcom_allReduce_", "100.0"),
            ],
        )
        durations = _parse_kernel_comm_duration(prof_dir, "all_reduce")
        assert durations == [100.0]

    def test_excludes_zero_duration_rows(self, tmp_path):
        """Rows with Duration=0 must be excluded (spurious profiler entries)."""
        prof_dir = self._make_prof_dir(
            tmp_path,
            [
                self._entry("hcom_allReduce_", "hcom_allReduce_0", "0"),
                self._entry("hcom_allReduce_", "hcom_allReduce_1", "150.0"),
            ],
        )
        durations = _parse_kernel_comm_duration(prof_dir, "all_reduce")
        assert durations == [150.0]

    def test_ignores_other_op_types(self, tmp_path):
        """Rows for a different op type must not appear in results."""
        prof_dir = self._make_prof_dir(
            tmp_path,
            [
                self._entry("hcom_allGather_", "hcom_allGather_0", "200.0"),
                self._entry("hcom_allReduce_", "hcom_allReduce_0", "100.0"),
            ],
        )
        durations = _parse_kernel_comm_duration(prof_dir, "all_reduce")
        assert durations == [100.0]

    def test_empty_directory_returns_empty_list(self, tmp_path):
        """No kernel_details.csv in directory → return [] without raising."""
        durations = _parse_kernel_comm_duration(str(tmp_path), "all_reduce")
        assert durations == []

    def test_nested_csv_is_found(self, tmp_path):
        """kernel_details.csv nested in subdirectory should be discovered."""
        nested_dir = tmp_path / "rank0" / "profiler_output"
        nested_dir.mkdir(parents=True)
        _write_kernel_details(
            nested_dir / "kernel_details.csv",
            [self._entry("hcom_allReduce_", "hcom_allReduce_0", "80.0")],
        )
        durations = _parse_kernel_comm_duration(str(tmp_path), "all_reduce")
        assert durations == [80.0]
class TestMainKernelPath:
    """Source-inspection regression tests for the kernel branch of main().

    They guard against accidental removal of the kernel-mode code path, which
    is a CANN constraint: kernel mode must use a batch profiler session rather
    than per-point profiler restarts to avoid ring-buffer pressure.
    """

    def _main_source(self):
        """Return the source text of generate_comm_microbench.main."""
        from tools.perf_data_collection.comm_bench import (
            generate_comm_microbench as mod,
        )

        return inspect.getsource(mod.main)

    def test_kernel_branch_exists(self):
        """main() must have an explicit NPU branch via _has_torch_npu()."""
        main_text = self._main_source()
        assert "_has_torch_npu()" in main_text, (
            "main() must use _has_torch_npu() to gate the NPU profiler path (CANN constraint)"
        )

    def test_kernel_branch_uses_parse_kernel_fn(self):
        """Kernel branch must pass _parse_kernel_comm_duration as parse_fn."""
        main_text = self._main_source()
        assert "_parse_kernel_comm_duration" in main_text, (
            "kernel batch branch must use _parse_kernel_comm_duration "
            "to parse kernel_details.csv instead of operator_details.csv"
        )

    def test_profiler_batch_returns_empty_on_no_durations(self):
        """_run_bench_profiler_batch must return {} (not raise) when parse returns empty."""
        from tools.perf_data_collection.comm_bench import (
            generate_comm_microbench as mod,
        )

        batch_text = inspect.getsource(mod._run_bench_profiler_batch)
        has_empty_return = "return {}" in batch_text
        assert has_empty_return, "_run_bench_profiler_batch must return empty dict for tolerant error handling"
        assert "if not durations:" in batch_text
def test_has_torch_npu_returns_false_without_torch_npu():
    """_has_torch_npu must return False when torch_npu is not importable.

    Only imports of ``torch_npu`` are made to fail; every other import is
    forwarded to the real ``__import__``. Unrelated imports therefore still
    resolve to real modules instead of MagicMock placeholders.
    """
    import builtins

    from tools.perf_data_collection.comm_bench.generate_comm_microbench import (
        _has_torch_npu,
    )

    real_import = builtins.__import__

    def _mock_import(name, *args, **kwargs):
        if name == "torch_npu" or name.startswith("torch_npu."):
            raise ImportError(f"No module named {name!r}")
        return real_import(name, *args, **kwargs)

    with mock.patch("builtins.__import__", side_effect=_mock_import):
        result = _has_torch_npu()

    # Assert outside the patch so pytest's own machinery imports normally.
    assert result is False
class TestIterConfigs:
    """Check that _iter_configs yields the expected (op, msg, nd, tier, ranks) tuples."""

    GRID = [48, 8, 2]
    OPS = ["all_reduce", "all_gather"]
    BYTES_GRID = [1024, 65536, 1048576]

    @staticmethod
    def _generate(*args, **kwargs):
        """Import _iter_configs lazily and forward all arguments to it."""
        from tools.perf_data_collection.comm_bench.generate_comm_microbench import (
            _iter_configs,
        )

        return _iter_configs(*args, **kwargs)

    def test_generates_all_combinations(self):
        """Two ops x three sizes at one (nd, tier) pair give six configs."""
        configs = self._generate(
            self.OPS,
            [16],
            topology_tiers=[1],
            grid_shape=self.GRID,
            bytes_grid=self.BYTES_GRID,
        )
        assert len(configs) == 6
        seen_ops = {entry[0] for entry in configs}
        seen_tiers = {entry[3] for entry in configs}
        assert seen_ops == {"all_reduce", "all_gather"}
        assert seen_tiers == {1}

    def test_auto_resolve_tier(self):
        """With topology_tiers=None, a 16-device group resolves to tier 1."""
        configs = self._generate(
            self.OPS,
            [16],
            topology_tiers=None,
            grid_shape=self.GRID,
            bytes_grid=self.BYTES_GRID,
        )
        assert len(configs) > 0
        for _op, _msg, _nd, tier, _ranks in configs:
            assert tier == 1

    def test_auto_resolve_with_small_group_tier2(self):
        """With topology_tiers=None, a 2-device group resolves to tier 2."""
        configs = self._generate(
            self.OPS,
            [2],
            topology_tiers=None,
            grid_shape=self.GRID,
            bytes_grid=self.BYTES_GRID,
        )
        assert len(configs) > 0
        for _op, _msg, _nd, tier, _ranks in configs:
            assert tier == 2

    def test_multiple_num_devices_and_tiers(self):
        """nd=[16, 2] with tiers=[1, 2] yields six configs including (16, 1) and (2, 2)."""
        configs = self._generate(
            ["all_reduce"],
            [16, 2],
            topology_tiers=[1, 2],
            grid_shape=self.GRID,
            bytes_grid=[1024, 65536],
        )
        assert len(configs) == 6
        pairs = {(entry[2], entry[3]) for entry in configs}
        assert (16, 1) in pairs
        assert (2, 2) in pairs

    def test_skips_unreachable_group(self, capsys):
        """A group that cannot fit its tier is dropped with a WARNING on stderr."""
        configs = self._generate(
            ["all_reduce"],
            [4],
            topology_tiers=[2],
            grid_shape=self.GRID,
            bytes_grid=[1024],
        )
        assert len(configs) == 0
        captured = capsys.readouterr()
        assert "WARNING" in captured.err
        assert "exceeds span size" in captured.err
def test_build_argparser_exposes_database_path_and_removes_legacy_flags():
    """--database-path is accepted; --output-dir and --do-run are rejected."""
    from tools.perf_data_collection.comm_bench.generate_comm_microbench import (
        build_argparser,
    )

    parser = build_argparser()
    parsed = parser.parse_args(["--database-path", "db"])

    assert parsed.database_path == "db"
    # The namespace must not carry the legacy attributes.
    assert not hasattr(parsed, "bench_mode")
    assert not hasattr(parsed, "run")

    # Legacy flags must make argparse exit.
    with pytest.raises(SystemExit):
        parser.parse_args(["--output-dir", "db"])
    with pytest.raises(SystemExit):
        parser.parse_args(["--do-run"])
class TestCommMicrobenchCli:
    """CLI parsing checks for generate_comm_microbench.build_argparser."""

    def test_database_path_replaces_output_dir_and_do_run(self):
        """A full command line parses as expected and the legacy flags are rejected."""
        from tools.perf_data_collection.comm_bench import (
            generate_comm_microbench as mod,
        )

        parser = mod.build_argparser()
        command_line = ["--database-path", "db"]
        command_line += ["--ops", "all_reduce"]
        command_line += ["--num-devices", "16", "2"]
        command_line += ["--grid-shape", "48", "8", "2"]
        parsed = parser.parse_args(command_line)

        assert parsed.database_path == "db"
        assert parsed.ops == ["all_reduce"]
        assert parsed.num_devices == [16, 2]
        assert parsed.grid_shape == [48, 8, 2]
        assert not hasattr(parsed, "bench_mode")
        assert not hasattr(parsed, "run")

        with pytest.raises(SystemExit):
            parser.parse_args(["--output-dir", "db"])
        with pytest.raises(SystemExit):
            parser.parse_args(["--do-run"])
@pytest.mark.skipif(
    sys.platform == "win32",
    reason="bash stub (chmod 0o755 + #!/bin/bash) is not portable to Windows CI",
)
class TestRunCommBenchShellMultiNode:
    """Smoke tests for the multi-node (inter-pod) branch of run_comm_bench.sh.

    The shell script is a thin dispatcher; these tests assert that the
    NNODES>=2 path calls torchrun with the expected multi-node flags and
    forwards --topology-tier 0 to the Python script.

    Strategy: put a stub ``torchrun`` first on PATH that records its argv to
    a log file, run the script, then inspect the recorded calls.
    """

    # Marker line the argv-recording stub writes after each torchrun call.
    _END_MARKER = "---END---"

    @staticmethod
    def _script_path():
        """Return the path of run_comm_bench.sh relative to the repository root."""
        repo_root = Path(__file__).resolve().parents[3]
        return repo_root / "tools" / "perf_data_collection" / "comm_bench" / "run_comm_bench.sh"

    @staticmethod
    def _install_torchrun_stub(tmp_path, *, record_argv=True):
        """Create an executable ``torchrun`` stub under ``tmp_path/bin``.

        With ``record_argv=True`` the stub appends one argv entry per line and
        then the end marker to the log. Otherwise it appends a single
        ``CALLED`` line. Returns ``(stub_dir, log_file)``.
        """
        import textwrap

        stub_dir = tmp_path / "bin"
        stub_dir.mkdir()
        log_file = tmp_path / "torchrun.log"
        if record_argv:
            body = textwrap.dedent(f"""\
                #!/bin/bash
                for a in "$@"; do
                    printf '%s\\n' "$a" >> "{log_file}"
                done
                printf '%s\\n' '---END---' >> "{log_file}"
                exit 0
                """)
        else:
            body = textwrap.dedent(f"""\
                #!/bin/bash
                printf 'CALLED\\n' >> "{log_file}"
                exit 0
                """)
        stub = stub_dir / "torchrun"
        stub.write_text(body)
        stub.chmod(0o755)
        return stub_dir, log_file

    def _run_script(self, tmp_path, env_overrides, *, record_argv=True):
        """Run run_comm_bench.sh with the stub torchrun first on PATH.

        Returns ``(completed_process, log_file)``.
        """
        import os
        import subprocess

        stub_dir, log_file = self._install_torchrun_stub(tmp_path, record_argv=record_argv)
        env = os.environ.copy()
        # Prepend the stub dir so the script resolves our torchrun first.
        env["PATH"] = f"{stub_dir}{os.pathsep}{env.get('PATH', os.defpath)}"
        env.update(env_overrides)
        proc = subprocess.run(
            ["bash", str(self._script_path()), str(tmp_path / "out")],
            env=env,
            capture_output=True,
            text=True,
            timeout=30,
        )
        return proc, log_file

    def _recorded_calls(self, log_file):
        """Split the stub log into one argv list per torchrun call."""
        calls, current = [], []
        for line in log_file.read_text().splitlines():
            if line == self._END_MARKER:
                if current:
                    calls.append(current)
                current = []
            else:
                current.append(line)
        return calls

    def test_multinode_dispatches_inter_pod_torchrun(self, tmp_path):
        """NNODES=2 must trigger three torchrun rounds carrying the multi-node flags."""
        proc, log_file = self._run_script(
            tmp_path,
            {
                "NNODES": "2",
                "NODE_RANK": "1",
                "MASTER_ADDR": "127.0.0.1",
                "QUICK": "1",
            },
        )
        assert proc.returncode == 0, proc.stderr
        calls = self._recorded_calls(log_file)
        assert len(calls) == 3, f"expected 3 inter-pod rounds, got {len(calls)}"
        for argv in calls:
            assert "--nnodes=2" in argv
            assert "--node_rank=1" in argv
            assert "--master_addr=127.0.0.1" in argv
            assert "--topology-tier" in argv
            assert argv[argv.index("--topology-tier") + 1] == "0"
            assert argv[argv.index("--num-devices") + 1] == "32"

    def test_multinode_aborts_when_world_size_below_min_group(self, tmp_path):
        """world_size < 32 yields an empty ND_LIST; the script must abort.

        Regression for the case NPROC=1, NNODES=2 (world_size=2): the
        reachable-group loop starts at 32, so ND_LIST is empty and a bare
        ``--num-devices`` would otherwise be forwarded to torchrun. The
        guard must exit non-zero with a clear error and never call torchrun.
        """
        proc, log_file = self._run_script(
            tmp_path,
            {
                "NNODES": "2",
                "NODE_RANK": "0",
                "MASTER_ADDR": "127.0.0.1",
                "NPROC": "1",
            },
            record_argv=False,
        )
        assert proc.returncode == 1, proc.stdout + proc.stderr
        assert "empty device list" in proc.stderr
        assert "WORLD_SIZE=2" in proc.stderr
        assert not log_file.exists(), "torchrun was called despite empty ND_LIST"

    def test_multinode_nd_list_scales_past_legacy_ceiling(self, tmp_path):
        """ND_LIST is generated dynamically up to WORLD_SIZE, not capped at 768.

        Regression for the hardcoded ``32 64 128 256 384 512 768`` sequence:
        a 1024-rank cluster must collect a 1024 group, and the list forwarded
        to ``--num-devices`` must be ascending, unique, and bounded by
        WORLD_SIZE (no value above it).
        """
        proc, log_file = self._run_script(
            tmp_path,
            {
                "NNODES": "64",
                "NODE_RANK": "0",
                "MASTER_ADDR": "127.0.0.1",
                "QUICK": "1",
            },
        )
        assert proc.returncode == 0, proc.stderr
        # Flat read: the first --num-devices ... --topology-tier span is inspected.
        argv = log_file.read_text().splitlines()
        start = argv.index("--num-devices") + 1
        end = argv.index("--topology-tier")
        nd = [int(x) for x in argv[start:end]]
        assert nd == [32, 64, 128, 256, 384, 512, 768, 1024], nd
        assert max(nd) == 1024
        assert all(v <= 1024 for v in nd)
        assert nd == sorted(nd)
        assert len(nd) == len(set(nd))

    def test_multinode_port_is_base_plus_round_index(self, tmp_path):
        """Each round's --master_port is MASTER_PORT base + round index (1..3).

        Regression for the fragile ``MASTER_PORT=$((MASTER_PORT+1))`` outer-var
        mutation: ports must be derived from base + idx so the dispatch is
        subshell-safe and the three rounds use distinct, ordered ports.
        """
        proc, log_file = self._run_script(
            tmp_path,
            {
                "NNODES": "2",
                "NODE_RANK": "0",
                "MASTER_ADDR": "127.0.0.1",
                "MASTER_PORT": "30000",
                "QUICK": "1",
            },
        )
        assert proc.returncode == 0, proc.stderr
        calls = self._recorded_calls(log_file)
        ports = [a.split("=", 1)[1] for c in calls for a in c if a.startswith("--master_port=")]
        assert ports == ["30001", "30002", "30003"], ports
class TestMainControlFlow:
"""Verify main() dispatches to the correct branch and calls expected helpers.
The NPU path (``if _has_torch_npu():``) uses
``_run_bench_profiler_batch`` and splits messages by size.
The CPU path (``else:``) uses ``run_benchmark(bench_mode="event")``.
All tests mock ``torch.distributed`` functions directly to avoid
requiring a torchrun environment, and use synthetic ``_iter_configs``
output to control what configs are processed.
"""
SMALL_BYTES = 65536
LARGE_BYTES = 1048576
@staticmethod
def _patch_dist():
"""Patch torch.distributed functions for single-process testing.
Returns (context_manager, mock_dist) where mock_dist is the
torch.distributed module with patched functions.
"""
import torch.distributed as dist_module
initialized = [False]
def _init_pg(*args, **kwargs):
initialized[0] = True
def _is_init():
return initialized[0]
return mock.patch.multiple(
dist_module,
is_initialized=mock.MagicMock(side_effect=_is_init),
init_process_group=mock.MagicMock(side_effect=_init_pg),
get_rank=mock.MagicMock(return_value=0),
get_world_size=mock.MagicMock(return_value=1),
new_group=mock.MagicMock(return_value=mock.MagicMock()),
barrier=mock.MagicMock(return_value=None),
destroy_process_group=mock.MagicMock(return_value=None),
)
@staticmethod
def _make_configs(ops=None, num_devices=16, bytes_list=None):
"""Build synthetic _iter_configs output.
Returns list of (op_type, msg_bytes, nd, tier, group_ranks) tuples.
"""
ops = ops or ["all_reduce"]
bytes_list = bytes_list or [65536]
configs = []
for op in ops:
group_ranks = list(range(num_devices))
for mb in bytes_list:
configs.append((op, mb, num_devices, 1, group_ranks))
return configs
def test_cpu_fallback_calls_run_benchmark(self):
"""When _has_torch_npu()=False, CPU path calls run_benchmark(bench_mode="event")."""
from tools.perf_data_collection.comm_bench import (
generate_comm_microbench as mod,
)
configs = self._make_configs(bytes_list=[self.SMALL_BYTES])
with (
self._patch_dist(),
mock.patch.object(mod, "_has_torch_npu", return_value=False),
mock.patch.object(mod, "_iter_configs", return_value=configs),
mock.patch.object(mod, "build_argparser"),
mock.patch.object(mod, "print_logo"),
mock.patch.object(mod, "run_benchmark") as mock_run,
mock.patch("sys.exit", side_effect=SystemExit(1)),
):
mod.main()
assert mock_run.call_count == 2 * len(configs), (
f"Expected {2 * len(configs)} calls, got {mock_run.call_count}"
)
for call_args in mock_run.call_args_list:
assert call_args.kwargs["bench_mode"] == "event"
def test_cpu_fallback_passes_correct_csv(self):
"""CPU fallback passes _csv_for_op result as output_csv to run_benchmark."""
from tools.perf_data_collection.comm_bench import (
generate_comm_microbench as mod,
)
configs = self._make_configs(bytes_list=[self.SMALL_BYTES])
with (
self._patch_dist(),
mock.patch.object(mod, "_has_torch_npu", return_value=False),
mock.patch.object(mod, "_iter_configs", return_value=configs),
mock.patch.object(mod, "build_argparser"),
mock.patch.object(mod, "print_logo"),
mock.patch.object(mod, "run_benchmark") as mock_run,
mock.patch("sys.exit", side_effect=SystemExit(1)),
):
mod.main()
for i, call_args in enumerate(mock_run.call_args_list):
assert call_args.kwargs["bench_mode"] == "event"
if i < len(configs):
assert call_args.kwargs["output_csv"] is None
def test_npu_path_calls_profiler_batch_for_small_msgs(self):
"""NPU path with small messages uses _run_bench_profiler_batch."""
from tools.perf_data_collection.comm_bench import (
generate_comm_microbench as mod,
)
parser = mod.build_argparser()
args = parser.parse_args(["--ops", "all_reduce", "--num-devices", "16"])
configs = self._make_configs(bytes_list=[self.SMALL_BYTES])
with (
self._patch_dist(),
mock.patch.object(mod, "_has_torch_npu", return_value=True),
mock.patch.object(mod, "_iter_configs", return_value=configs),
mock.patch.object(mod, "build_argparser", return_value=parser),
mock.patch.object(parser, "parse_args", return_value=args),
mock.patch.object(mod, "print_logo"),
mock.patch.object(mod, "run_benchmark"),
mock.patch.object(mod, "_run_bench_profiler_batch", return_value={self.SMALL_BYTES: 100.0}),
mock.patch.object(mod, "_append_csv"),
mock.patch("sys.exit", side_effect=SystemExit(1)),
mock.patch("torch.npu", create=True),
mock.patch.dict("sys.modules", {"torch_npu": mock.MagicMock()}),
):
mod.main()
mod._run_bench_profiler_batch.assert_called_once()
pos_args, kw_args = mod._run_bench_profiler_batch.call_args
assert pos_args[0] == "all_reduce"
assert self.SMALL_BYTES in pos_args[1]
assert kw_args["parse_fn"] is not None
assert kw_args["no_sync"] is True
def test_npu_path_calls_profiler_batch_for_large_msgs(self):
"""NPU path with large messages uses per-msg _run_bench_profiler_batch sessions."""
from tools.perf_data_collection.comm_bench import (
generate_comm_microbench as mod,
)
configs = self._make_configs(bytes_list=[self.LARGE_BYTES])
with (
self._patch_dist(),
mock.patch.object(mod, "_has_torch_npu", return_value=True),
mock.patch.object(mod, "_iter_configs", return_value=configs),
mock.patch.object(mod, "build_argparser"),
mock.patch.object(mod, "print_logo"),
mock.patch.object(mod, "run_benchmark"),
mock.patch.object(mod, "_run_bench_profiler_batch", return_value={self.LARGE_BYTES: 500.0}),
mock.patch.object(mod, "_append_csv"),
mock.patch("sys.exit", side_effect=SystemExit(1)),
mock.patch("torch.npu", create=True),
mock.patch.dict("sys.modules", {"torch_npu": mock.MagicMock()}),
):
mod.main()
expected_calls = mod.PROFILER_LARGE_MSG_SESSIONS
assert mod._run_bench_profiler_batch.call_count >= expected_calls
def test_npu_path_mixed_small_and_large(self):
"""NPU path processes small msgs in batch and large msgs per-session."""
from tools.perf_data_collection.comm_bench import (
generate_comm_microbench as mod,
)
configs = self._make_configs(
bytes_list=[self.SMALL_BYTES, self.LARGE_BYTES],
)
with (
self._patch_dist(),
mock.patch.object(mod, "_has_torch_npu", return_value=True),
mock.patch.object(mod, "_iter_configs", return_value=configs),
mock.patch.object(mod, "build_argparser"),
mock.patch.object(mod, "print_logo"),
mock.patch.object(mod, "run_benchmark"),
mock.patch.object(
mod, "_run_bench_profiler_batch", return_value={self.SMALL_BYTES: 100.0, self.LARGE_BYTES: 500.0}
),
mock.patch.object(mod, "_append_csv"),
mock.patch("sys.exit", side_effect=SystemExit(1)),
mock.patch("torch.npu", create=True),
mock.patch.dict("sys.modules", {"torch_npu": mock.MagicMock()}),
):
mod.main()
expected_batch_calls = 1 + mod.PROFILER_LARGE_MSG_SESSIONS
assert mod._run_bench_profiler_batch.call_count >= expected_batch_calls
def test_npu_path_writes_csv_for_small_msgs(self):
    """The small-message result from the mocked batch reaches ``_append_csv``.

    Inspects the most recent ``_append_csv`` call and checks that its second
    positional argument (the row) carries the mocked duration and the
    SMALL_BYTES message size.
    """
    from tools.perf_data_collection.comm_bench import (
        generate_comm_microbench as mod,
    )

    small_only_configs = self._make_configs(bytes_list=[self.SMALL_BYTES])
    fake_durations = {self.SMALL_BYTES: 100.0}

    with (
        self._patch_dist(),
        mock.patch.object(mod, "_has_torch_npu", return_value=True),
        mock.patch.object(mod, "_iter_configs", return_value=small_only_configs),
        mock.patch.object(mod, "build_argparser"),
        mock.patch.object(mod, "print_logo"),
        mock.patch.object(mod, "run_benchmark"),
        mock.patch.object(mod, "_run_bench_profiler_batch", return_value=fake_durations),
        mock.patch.object(mod, "_append_csv") as mock_append,
        mock.patch("sys.exit", side_effect=SystemExit(1)),
        mock.patch("torch.npu", create=True),
        mock.patch.dict("sys.modules", {"torch_npu": mock.MagicMock()}),
    ):
        mod.main()

        mock_append.assert_called()
        # call_args is the last recorded call; the row is positional argument #1.
        written_row = mock_append.call_args.args[1]
        assert written_row["Duration(us)"] == 100.0
        assert written_row["message_bytes"] == self.SMALL_BYTES
def test_npu_path_writes_csv_for_large_msgs(self):
    """The large-message result from the mocked batch reaches ``_append_csv``.

    Inspects the most recent ``_append_csv`` call and checks that its second
    positional argument (the row) carries the mocked duration and the
    LARGE_BYTES message size.
    """
    from tools.perf_data_collection.comm_bench import (
        generate_comm_microbench as mod,
    )

    large_only_configs = self._make_configs(bytes_list=[self.LARGE_BYTES])
    fake_durations = {self.LARGE_BYTES: 500.0}

    with (
        self._patch_dist(),
        mock.patch.object(mod, "_has_torch_npu", return_value=True),
        mock.patch.object(mod, "_iter_configs", return_value=large_only_configs),
        mock.patch.object(mod, "build_argparser"),
        mock.patch.object(mod, "print_logo"),
        mock.patch.object(mod, "run_benchmark"),
        mock.patch.object(mod, "_run_bench_profiler_batch", return_value=fake_durations),
        mock.patch.object(mod, "_append_csv") as mock_append,
        mock.patch("sys.exit", side_effect=SystemExit(1)),
        mock.patch("torch.npu", create=True),
        mock.patch.dict("sys.modules", {"torch_npu": mock.MagicMock()}),
    ):
        mod.main()

        mock_append.assert_called()
        # call_args is the last recorded call; the row is positional argument #1.
        written_row = mock_append.call_args.args[1]
        assert written_row["Duration(us)"] == 500.0
        assert written_row["message_bytes"] == self.LARGE_BYTES
def test_npu_path_calls_dist_barrier(self):
    """Running main() on the mocked NPU path calls ``dist.barrier()`` at least once."""
    import torch.distributed as dist_module
    from tools.perf_data_collection.comm_bench import (
        generate_comm_microbench as mod,
    )

    small_only_configs = self._make_configs(bytes_list=[self.SMALL_BYTES])
    fake_durations = {self.SMALL_BYTES: 100.0}

    with (
        self._patch_dist(),
        mock.patch.object(mod, "_has_torch_npu", return_value=True),
        mock.patch.object(mod, "_iter_configs", return_value=small_only_configs),
        mock.patch.object(mod, "build_argparser"),
        mock.patch.object(mod, "print_logo"),
        mock.patch.object(mod, "run_benchmark"),
        mock.patch.object(mod, "_run_bench_profiler_batch", return_value=fake_durations),
        mock.patch.object(mod, "_append_csv"),
        mock.patch("sys.exit", side_effect=SystemExit(1)),
        mock.patch("torch.npu", create=True),
        mock.patch.dict("sys.modules", {"torch_npu": mock.MagicMock()}),
    ):
        mod.main()

        # Checked while the patches are live; presumably _patch_dist() swaps
        # torch.distributed.barrier for a mock -- verify against that helper.
        dist_module.barrier.assert_called()
def test_npu_path_calls_destroy_process_group(self):
    """Running main() on the mocked NPU path calls ``dist.destroy_process_group()`` exactly once."""
    import torch.distributed as dist_module
    from tools.perf_data_collection.comm_bench import (
        generate_comm_microbench as mod,
    )

    small_only_configs = self._make_configs(bytes_list=[self.SMALL_BYTES])
    fake_durations = {self.SMALL_BYTES: 100.0}

    with (
        self._patch_dist(),
        mock.patch.object(mod, "_has_torch_npu", return_value=True),
        mock.patch.object(mod, "_iter_configs", return_value=small_only_configs),
        mock.patch.object(mod, "build_argparser"),
        mock.patch.object(mod, "print_logo"),
        mock.patch.object(mod, "run_benchmark"),
        mock.patch.object(mod, "_run_bench_profiler_batch", return_value=fake_durations),
        mock.patch.object(mod, "_append_csv"),
        mock.patch("sys.exit", side_effect=SystemExit(1)),
        mock.patch("torch.npu", create=True),
        mock.patch.dict("sys.modules", {"torch_npu": mock.MagicMock()}),
    ):
        mod.main()

        # Checked while the patches are live; presumably _patch_dist() swaps
        # torch.distributed.destroy_process_group for a mock -- verify against that helper.
        dist_module.destroy_process_group.assert_called_once()