"""Tests for generate_comm_microbench.py.
Test strategy
-------------
Tests are split into two categories:
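1. Pure-logic tests (no NPU required) — run in any environment:
   resolve_topology_tier, build_group_for_tier, _apply_dispatch_overhead,
   _active_iters_for_msg, _build_run_op (CPU path), _parse_kernel_comm_duration,
   TestMainKernelPath, TestMainNoDoRun, and the run_comm_bench.sh dispatch
   smoke tests in TestRunCommBenchShellMultiNode (torchrun is stubbed; these
   need bash and are skipped on Windows).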
2. NPU integration tests — marked @pytest.mark.npu, skipped by default
(run with: pytest -m npu).
These verify real hardware behavior: actual kernel Duration values,
profiler output format, no_sync pipeline overlap, end-to-end CSV output.
They require torch_npu and a physical NPU device (world_size=1, single card).
dist is initialized automatically inside each test (hccl backend, rank=0, world_size=1).
"""
import csv
import inspect
import sys
from pathlib import Path
from unittest.mock import patch
import pytest
pytest.importorskip("torch", reason="torch not installed")
from tools.perf_data_collection.comm_bench.generate_comm_microbench import (
_DISPATCH_OVERHEAD,
PROFILER_ACTIVE_ITERS,
PROFILER_ACTIVE_ITERS_LARGE,
PROFILER_LARGE_MSG_THRESHOLD,
_active_iters_for_msg,
_apply_dispatch_overhead,
_build_run_op,
_parse_kernel_comm_duration,
_run_bench_event,
_run_bench_kernel,
build_group_for_tier,
resolve_topology_tier,
run_benchmark,
)
class TestResolveTopologyTier:
    """Tier-resolution checks for resolve_topology_tier on a [48, 8, 2] grid.

    The expected tiers are meant to match the CommAnalyticModel logic for
    ATLAS_800_A3.
    """

    GRID = [48, 8, 2]

    def test_nd16_intra_pod_tier1(self):
        """Ranks 0..15 resolve to tier 1."""
        ranks = list(range(16))
        assert resolve_topology_tier(ranks, self.GRID) == 1

    def test_nd8_intra_pod_tier1(self):
        """Ranks 0..7 resolve to tier 1."""
        ranks = list(range(8))
        assert resolve_topology_tier(ranks, self.GRID) == 1

    def test_nd2_die_level_tier2(self):
        """The pair [0, 1] resolves to tier 2."""
        ranks = [0, 1]
        assert resolve_topology_tier(ranks, self.GRID) == 2

    def test_nd4_spans_nodes_tier1(self):
        """Ranks 0..3 resolve to tier 1."""
        ranks = [0, 1, 2, 3]
        assert resolve_topology_tier(ranks, self.GRID) == 1

    def test_nd128_inter_pod_tier0(self):
        """Ranks 0..127 resolve to tier 0."""
        ranks = list(range(128))
        assert resolve_topology_tier(ranks, self.GRID) == 0

    def test_single_rank_returns_innermost(self):
        """A one-rank group resolves to tier 2, the last index of the grid."""
        lone_rank = [5]
        assert resolve_topology_tier(lone_rank, self.GRID) == 2
class TestBuildGroupForTier:
    """Checks that build_group_for_tier produces the expected rank lists on a [48, 8, 2] grid."""

    GRID = [48, 8, 2]

    def test_tier1_nd16_from_rank0(self):
        """16 devices at tier 1 starting from rank 0 yields ranks 0..15."""
        expected = list(range(16))
        assert build_group_for_tier(0, 16, 1, self.GRID) == expected

    def test_tier1_nd8_from_rank0(self):
        """8 devices at tier 1 starting from rank 0 yields ranks 0..7."""
        expected = list(range(8))
        assert build_group_for_tier(0, 8, 1, self.GRID) == expected

    def test_tier2_nd2_from_rank0(self):
        """2 devices at tier 2 starting from rank 0 yields [0, 1]."""
        assert build_group_for_tier(0, 2, 2, self.GRID) == [0, 1]

    def test_group_resolves_back_to_same_tier(self):
        """A group built for a tier must resolve back to that same tier."""
        cases = ((16, 1), (8, 1), (2, 2))
        for num_devices, wanted_tier in cases:
            ranks = build_group_for_tier(0, num_devices, wanted_tier, self.GRID)
            resolved = resolve_topology_tier(ranks, self.GRID)
            assert resolved == wanted_tier

    def test_exceeds_span_raises(self):
        """Requesting 4 devices at tier 2 must raise ValueError mentioning the span size."""
        with pytest.raises(ValueError, match="exceeds span size"):
            build_group_for_tier(0, 4, 2, self.GRID)
class TestApplyDispatchOverhead:
    """Checks for _apply_dispatch_overhead: duration bump, bandwidth refresh, pass-through."""

    def _row(self, op_type, nd, duration_us, msg_bytes=1048576):
        """Build a benchmark result row.

        ``op_type`` is accepted for call-site symmetry but is not stored in the
        row. Bandwidth is derived from ``msg_bytes`` and ``duration_us`` and
        rounded to two decimals.
        """
        seconds = duration_us * 1e-6
        bandwidth = round(msg_bytes / seconds / 1e9, 2)
        row = {}
        row["message_bytes"] = msg_bytes
        row["num_devices"] = nd
        row["dtype"] = "DT_BF16"
        row["topology_tier"] = 1
        row["Duration(us)"] = duration_us
        row["bandwidth_gbps"] = bandwidth
        return row

    def test_all_known_entries_applied(self):
        """Every entry in _DISPATCH_OVERHEAD must increase Duration(us) by its value."""
        for key, overhead in _DISPATCH_OVERHEAD.items():
            op_type, nd = key
            corrected = _apply_dispatch_overhead(self._row(op_type, nd, 100.0), op_type)
            wanted = pytest.approx(100.0 + overhead, abs=0.01)
            assert corrected["Duration(us)"] == wanted, (
                f"op={op_type} nd={nd}: expected {100.0 + overhead}"
            )

    def test_bandwidth_recalculated_after_overhead(self):
        """bandwidth_gbps must be recomputed from the overhead-corrected duration."""
        source_row = self._row("all_gather", 16, 100.0, msg_bytes=1048576)
        corrected = _apply_dispatch_overhead(source_row, "all_gather")
        corrected_dur = 100.0 + _DISPATCH_OVERHEAD[("all_gather", 16)]
        wanted_bw = round(1048576 / (corrected_dur * 1e-6) / 1e9, 2)
        assert corrected["bandwidth_gbps"] == pytest.approx(wanted_bw, abs=0.01)

    def test_no_overhead_for_missing_op(self):
        """An op without an overhead entry gets the very same row object back."""
        source_row = self._row("all_to_all", 16, 100.0)
        assert _apply_dispatch_overhead(source_row, "all_to_all") is source_row

    def test_no_overhead_for_unknown_nd(self):
        """A device count without an overhead entry gets the very same row object back."""
        source_row = self._row("all_reduce", 4, 100.0)
        assert _apply_dispatch_overhead(source_row, "all_reduce") is source_row

    def test_original_row_not_mutated(self):
        """Applying overhead must leave the input row's duration untouched."""
        source_row = self._row("all_reduce", 16, 100.0)
        before = source_row["Duration(us)"]
        _apply_dispatch_overhead(source_row, "all_reduce")
        assert source_row["Duration(us)"] == before
class TestActiveItersForMsg:
    """Checks that _active_iters_for_msg switches iteration counts at PROFILER_LARGE_MSG_THRESHOLD."""

    def test_small_msg_returns_full_active(self):
        """A 65536-byte message uses PROFILER_ACTIVE_ITERS."""
        iters = _active_iters_for_msg(65536)
        assert iters == PROFILER_ACTIVE_ITERS

    def test_just_below_threshold_returns_full_active(self):
        """One byte under the threshold still uses PROFILER_ACTIVE_ITERS."""
        size = PROFILER_LARGE_MSG_THRESHOLD - 1
        assert _active_iters_for_msg(size) == PROFILER_ACTIVE_ITERS

    def test_at_threshold_returns_one(self):
        """Exactly at the threshold the large-message count applies."""
        size = PROFILER_LARGE_MSG_THRESHOLD
        assert _active_iters_for_msg(size) == PROFILER_ACTIVE_ITERS_LARGE

    def test_above_threshold_returns_one(self):
        """Four times the threshold uses the large-message count."""
        size = PROFILER_LARGE_MSG_THRESHOLD * 4
        assert _active_iters_for_msg(size) == PROFILER_ACTIVE_ITERS_LARGE
class TestBuildRunOp:
    """Verify _build_run_op constructs callable closures for all op types (CPU path)."""

    def _group(self, nd):
        """Return the process-group placeholder (always None); not called by the tests in this class."""
        return None

    def test_all_ops_return_callable(self):
        """Each supported op type must yield a callable when built for the CPU device."""
        for op_type in ["all_reduce", "all_gather", "reduce_scatter", "all_to_all"]:
            run_op = _build_run_op(
                op_type,
                65536,
                "torch.bfloat16",
                "cpu",
                group=None,
                group_ranks=list(range(4)),
            )
            assert callable(run_op), f"{op_type} should return a callable"

    def test_all_reduce_tensor_shape(self):
        """The all_reduce closure must capture a tensor of shape [32768].

        Presumably 32768 = 65536 message bytes / 2 bytes per bfloat16 element.
        """
        import torch

        run_op = _build_run_op(
            "all_reduce",
            65536,
            "torch.bfloat16",
            "cpu",
            group=None,
            group_ranks=[0],
        )
        shapes = []
        # __closure__ is None for a callable that captures nothing; treat that
        # as "no captured tensors" so the assertion below reports the failure.
        for cell in getattr(run_op, "__closure__", None) or ():
            try:
                captured = cell.cell_contents
            except ValueError:
                # An empty cell raises ValueError, which hasattr() does not swallow.
                continue
            if isinstance(captured, torch.Tensor):
                shapes.append(captured.shape)
        assert any(s == torch.Size([32768]) for s in shapes), f"Expected tensor of shape [32768], got {shapes}"
def _npu_dist_init():
    """Initialize torch.distributed with the hccl backend (rank 0, world_size 1) if needed.

    A FileStore is used for rendezvous so MASTER_ADDR/MASTER_PORT env vars are
    not required. If a process group already exists it is reused as-is.

    Returns:
        Tuple ``(rank, world_size, device)`` where ``device`` is ``"npu:<rank>"``.
    """
    import os
    import tempfile

    import torch
    import torch.distributed as dist
    import torch_npu  # noqa: F401 - presumably imported for its side effect of enabling torch.npu / hccl

    if not dist.is_initialized():
        torch.npu.set_device(0)
        # Use a unique, empty store file per initialisation. A fixed name in the
        # shared temp dir could be left over from a crashed run, or be shared
        # with a concurrent run, and interfere with rendezvous.
        fd, store_path = tempfile.mkstemp(prefix="npu_test_store_")
        os.close(fd)
        store = dist.FileStore(store_path, 1)
        dist.init_process_group(backend="hccl", rank=0, world_size=1, store=store)
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = f"npu:{rank}"
    return rank, world_size, device
def test_run_bench_profiler_batch_signature():
    """_run_bench_profiler_batch must expose parse_fn (default None) and no_sync (default False)."""
    from tools.perf_data_collection.comm_bench.generate_comm_microbench import (
        _run_bench_profiler_batch,
    )

    params = inspect.signature(_run_bench_profiler_batch).parameters
    for param_name, wanted_default in (("parse_fn", None), ("no_sync", False)):
        assert param_name in params
        assert params[param_name].default is wanted_default
@pytest.mark.npu
def test_run_bench_event_returns_positive_duration():
    """_run_bench_event on a real NPU must return a positive float duration."""
    import torch
    import torch.distributed as dist

    _rank, _world_size, device = _npu_dist_init()
    payload = torch.zeros(1024, dtype=torch.bfloat16, device=device)

    def run_op():
        dist.all_reduce(payload)

    measured = _run_bench_event(run_op, is_npu=True)
    assert isinstance(measured, float)
    assert measured > 0.0, f"Expected positive duration, got {measured}"
@pytest.mark.npu
def test_run_bench_kernel_leader_returns_positive_duration():
    """_run_bench_kernel as leader on a real NPU (world_size=1) must return a positive duration."""
    import torch
    import torch.distributed as dist

    _rank, _world_size, device = _npu_dist_init()
    payload = torch.zeros(1024, dtype=torch.bfloat16, device=device)

    def run_op():
        dist.all_reduce(payload)

    measured = _run_bench_kernel(run_op, "all_reduce", is_npu=True, is_leader=True)
    assert measured is not None, "Leader should return a duration"
    assert measured > 0.0, f"Expected positive duration, got {measured}"
@pytest.mark.npu
def test_run_bench_kernel_no_sync_returns_positive_duration():
    """_run_bench_kernel with no_sync=True must still return a positive duration for the leader."""
    import torch
    import torch.distributed as dist

    _rank, _world_size, device = _npu_dist_init()
    payload = torch.zeros(1024, dtype=torch.bfloat16, device=device)

    def run_op():
        dist.all_reduce(payload)

    measured = _run_bench_kernel(
        run_op,
        "all_reduce",
        is_npu=True,
        is_leader=True,
        no_sync=True,
    )
    assert measured is not None, "Leader (no_sync) should return a duration"
    assert measured > 0.0, f"Expected positive duration, got {measured}"
@pytest.mark.npu
def test_run_bench_profiler_batch_small_msg():
    """_run_bench_profiler_batch with a 65536-byte message: leader gets a dict with a positive duration."""
    import torch.distributed as dist

    from tools.perf_data_collection.comm_bench.generate_comm_microbench import (
        _run_bench_profiler_batch,
    )

    _rank, world_size, device = _npu_dist_init()
    msg_bytes = 65536
    call_kwargs = {
        "op_type": "all_reduce",
        "msg_bytes_list": [msg_bytes],
        "dtype_str": "torch.bfloat16",
        "device": device,
        "group": dist.group.WORLD,
        "group_ranks": list(range(world_size)),
        "is_npu": True,
        "is_leader": True,
        "parse_fn": None,
        "no_sync": True,
    }
    results = _run_bench_profiler_batch(**call_kwargs)

    assert isinstance(results, dict), "Should return a dict"
    assert msg_bytes in results, f"Result missing key {msg_bytes}"
    assert results[msg_bytes] > 0.0, f"Expected positive duration, got {results[msg_bytes]}"
@pytest.mark.npu
def test_run_bench_profiler_batch_large_msg():
    """_run_bench_profiler_batch at PROFILER_LARGE_MSG_THRESHOLD bytes: leader gets a positive duration."""
    import torch.distributed as dist

    from tools.perf_data_collection.comm_bench.generate_comm_microbench import (
        _run_bench_profiler_batch,
    )

    _rank, world_size, device = _npu_dist_init()
    msg_bytes = PROFILER_LARGE_MSG_THRESHOLD
    call_kwargs = {
        "op_type": "all_reduce",
        "msg_bytes_list": [msg_bytes],
        "dtype_str": "torch.bfloat16",
        "device": device,
        "group": dist.group.WORLD,
        "group_ranks": list(range(world_size)),
        "is_npu": True,
        "is_leader": True,
        "parse_fn": None,
        "no_sync": True,
    }
    results = _run_bench_profiler_batch(**call_kwargs)

    assert isinstance(results, dict), "Should return a dict"
    assert msg_bytes in results, f"Result missing key {msg_bytes}"
    assert results[msg_bytes] > 0.0, f"Expected positive duration, got {results[msg_bytes]}"
@pytest.mark.npu
def test_run_benchmark_kernel_mode_writes_csv(tmp_path):
    """run_benchmark in kernel mode returns a populated result and writes exactly one CSV row."""
    _rank, world_size, _device = _npu_dist_init()
    ranks = list(range(world_size))
    tier = resolve_topology_tier(ranks, [48, 8, 2])
    csv_file = tmp_path / "out.csv"
    csv_path = str(csv_file)

    result = run_benchmark(
        op_type="all_reduce",
        message_bytes=65536,
        group_ranks=ranks,
        topology_tier=tier,
        dtype_str="torch.bfloat16",
        output_csv=csv_path,
        bench_mode="kernel",
    )

    # Returned row.
    assert result is not None, "Should return a result dict"
    assert result["Duration(us)"] > 0.0
    assert result["bandwidth_gbps"] > 0.0
    assert result["message_bytes"] == 65536

    # Persisted row.
    assert Path(csv_path).exists(), "CSV file should have been written"
    with open(csv_path, encoding="utf-8") as handle:
        written = list(csv.DictReader(handle))
    assert len(written) == 1
    first = written[0]
    assert float(first["Duration(us)"]) > 0.0
def _write_kernel_details(path: Path, rows: list):
"""Write a minimal kernel_details.csv with Type, Name, Duration(us) columns."""
with path.open("w", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=["Type", "Name", "Duration(us)"])
writer.writeheader()
writer.writerows(rows)
class TestParseKernelCommDuration:
    """Checks for _parse_kernel_comm_duration: AivKernel de-duplication and CSV parsing."""

    @staticmethod
    def _kernel_row(name, duration, kernel_type="hcom_allReduce_"):
        """Build one kernel_details.csv record; Type defaults to the all-reduce kernel type."""
        return {"Type": kernel_type, "Name": name, "Duration(us)": duration}

    def _make_prof_dir(self, tmp_path, rows):
        """Write kernel_details.csv directly under ``tmp_path`` and return that directory as a str."""
        _write_kernel_details(tmp_path / "kernel_details.csv", rows)
        return str(tmp_path)

    def test_returns_durations_for_matching_op(self, tmp_path):
        """hcom_allReduce_ rows without AivKernel should be returned."""
        records = [
            self._kernel_row("hcom_allReduce_0", "100.0"),
            self._kernel_row("hcom_allReduce_1", "120.0"),
        ]
        durations = _parse_kernel_comm_duration(self._make_prof_dir(tmp_path, records), "all_reduce")
        assert durations == [100.0, 120.0]

    def test_excludes_aivkernel_rows(self, tmp_path):
        """Rows with 'AivKernel' in Name must be filtered out (deduplication)."""
        records = [
            self._kernel_row("hcom_allReduce_0", "100.0"),
            self._kernel_row("AivKernel_hcom_allReduce_", "100.0"),
        ]
        durations = _parse_kernel_comm_duration(self._make_prof_dir(tmp_path, records), "all_reduce")
        assert durations == [100.0]

    def test_excludes_zero_duration_rows(self, tmp_path):
        """Rows with Duration=0 must be excluded (spurious profiler entries)."""
        records = [
            self._kernel_row("hcom_allReduce_0", "0"),
            self._kernel_row("hcom_allReduce_1", "150.0"),
        ]
        durations = _parse_kernel_comm_duration(self._make_prof_dir(tmp_path, records), "all_reduce")
        assert durations == [150.0]

    def test_ignores_other_op_types(self, tmp_path):
        """Rows for a different op type must not appear in results."""
        records = [
            self._kernel_row("hcom_allGather_0", "200.0", kernel_type="hcom_allGather_"),
            self._kernel_row("hcom_allReduce_0", "100.0"),
        ]
        durations = _parse_kernel_comm_duration(self._make_prof_dir(tmp_path, records), "all_reduce")
        assert durations == [100.0]

    def test_empty_directory_returns_empty_list(self, tmp_path):
        """No kernel_details.csv in directory → return [] without raising."""
        durations = _parse_kernel_comm_duration(str(tmp_path), "all_reduce")
        assert durations == []

    def test_nested_csv_is_found(self, tmp_path):
        """kernel_details.csv nested in subdirectory should be discovered."""
        nested_dir = tmp_path / "rank0" / "profiler_output"
        nested_dir.mkdir(parents=True)
        records = [self._kernel_row("hcom_allReduce_0", "80.0")]
        _write_kernel_details(nested_dir / "kernel_details.csv", records)
        durations = _parse_kernel_comm_duration(str(tmp_path), "all_reduce")
        assert durations == [80.0]
class TestMainKernelPath:
    """Source-text regression checks on main() and _run_bench_profiler_batch.

    These guard against accidental removal of the kernel-mode code path. The
    stated rationale is a CANN constraint: kernel mode must use one batch
    profiler session rather than per-point profiler restarts, to avoid
    ring-buffer pressure.
    """

    def _main_source(self):
        """Return the source text of the benchmark module's main()."""
        from tools.perf_data_collection.comm_bench import (
            generate_comm_microbench as mod,
        )

        main_fn = mod.main
        return inspect.getsource(main_fn)

    def test_kernel_branch_exists(self):
        """main() must have an explicit 'bench_mode == kernel' branch."""
        source = self._main_source()
        assert 'bench_mode == "kernel"' in source, (
            "main() must have an explicit kernel branch to avoid per-point profiler restart (CANN constraint)"
        )

    def test_kernel_branch_uses_parse_kernel_fn(self):
        """Kernel branch must pass _parse_kernel_comm_duration as parse_fn."""
        source = self._main_source()
        assert "_parse_kernel_comm_duration" in source, (
            "kernel batch branch must use _parse_kernel_comm_duration "
            "to parse kernel_details.csv instead of operator_details.csv"
        )

    def test_profiler_batch_returns_empty_on_no_durations(self):
        """_run_bench_profiler_batch must return {} (not raise) when parse returns empty."""
        from tools.perf_data_collection.comm_bench import (
            generate_comm_microbench as mod,
        )

        batch_source = inspect.getsource(mod._run_bench_profiler_batch)
        has_empty_return = "return {}" in batch_source
        assert has_empty_return, "_run_bench_profiler_batch must return empty dict for tolerant error handling"
        assert "if not durations:" in batch_source
class TestMainNoDoRun:
    """Checks main()'s refusal to run when the run flag is off."""

    def test_exits_with_error_when_no_do_run(self, capsys):
        """main() must exit with code 1 and print error when --do-run is not provided.

        NOTE(review): the fake namespace sets ``run=False``; presumably that is
        the argparse dest behind ``--do-run`` — confirm against build_argparser.
        """
        import argparse

        from tools.perf_data_collection.comm_bench import (
            generate_comm_microbench as mod,
        )

        fake_fields = {
            "run": False,
            "ops": ["all_reduce"],
            "num_devices": [16],
            "topology_tier": None,
            "grid_shape": [48, 8, 2],
            "bytes_grid": None,
            "dtype": "torch.bfloat16",
            "bench_mode": "kernel",
            "output_csv": None,
            "output_dir": None,
        }
        fake_args = argparse.Namespace(**fake_fields)

        with patch.object(mod, "build_argparser") as mock_parser:
            # main() is expected to call build_argparser().parse_args().
            mock_parser.return_value.parse_args.return_value = fake_args
            with pytest.raises(SystemExit) as exc_info:
                mod.main()

        assert exc_info.value.code == 1
        stderr_text = capsys.readouterr().err
        assert "--do-run is required" in stderr_text
@pytest.mark.skipif(
    sys.platform == "win32",
    reason="bash stub (chmod 0o755 + #!/bin/bash) is not portable to Windows CI",
)
class TestRunCommBenchShellMultiNode:
    """Smoke tests for the multi-node (inter-pod) branch of run_comm_bench.sh.

    The shell script is a thin dispatcher; these tests assert that the
    NNODES>=2 path calls torchrun with the expected multi-node flags and
    forwards --topology-tier 0 to the Python script.

    Strategy: put a stub ``torchrun`` first on PATH that records its argv to a
    log file, run the script with bash, then inspect the recorded calls.
    """

    # Environment variables these tests use to drive the script. They are
    # removed from the inherited environment before every run so a value in
    # the developer's or CI shell cannot change the outcome; each test then
    # sets exactly the ones it needs.
    _SCRIPT_ENV_VARS = ("NNODES", "NODE_RANK", "MASTER_ADDR", "MASTER_PORT", "NPROC", "QUICK")

    # Line the argv-recording stub appends after each torchrun invocation.
    _END_MARKER = "---END---"

    @staticmethod
    def _argv_recording_stub(log_file):
        """Return a bash stub that logs one argv entry per line, then the end marker."""
        import textwrap

        return textwrap.dedent(f"""\
            #!/bin/bash
            for a in "$@"; do
                printf '%s\\n' "$a" >> "{log_file}"
            done
            printf '%s\\n' '---END---' >> "{log_file}"
            exit 0
            """)

    @staticmethod
    def _call_marker_stub(log_file):
        """Return a bash stub that only records that torchrun was invoked."""
        import textwrap

        return textwrap.dedent(f"""\
            #!/bin/bash
            printf 'CALLED\\n' >> "{log_file}"
            exit 0
            """)

    def _run_script(self, tmp_path, overrides, stub_factory):
        """Run run_comm_bench.sh with a stubbed torchrun.

        Args:
            tmp_path: Per-test scratch directory (stub dir, log file, output dir).
            overrides: Environment variables to set for the script.
            stub_factory: Callable taking the log-file path and returning the
                stub script text.

        Returns:
            Tuple ``(completed_process, log_file_path)``.
        """
        import os
        import subprocess

        repo_root = Path(__file__).resolve().parents[3]
        script = repo_root / "tools" / "perf_data_collection" / "comm_bench" / "run_comm_bench.sh"

        stub_dir = tmp_path / "bin"
        stub_dir.mkdir()
        log_file = tmp_path / "torchrun.log"
        stub = stub_dir / "torchrun"
        stub.write_text(stub_factory(log_file))
        stub.chmod(0o755)

        env = os.environ.copy()
        for name in self._SCRIPT_ENV_VARS:
            env.pop(name, None)
        # Stub directory goes first so the script resolves our torchrun.
        env["PATH"] = os.pathsep.join([str(stub_dir), env.get("PATH", os.defpath)])
        env.update(overrides)

        proc = subprocess.run(
            ["bash", str(script), str(tmp_path / "out")],
            env=env,
            capture_output=True,
            text=True,
            timeout=30,
        )
        return proc, log_file

    @classmethod
    def _recorded_calls(cls, log_file):
        """Split the argv-recording stub's log into one argv list per torchrun call.

        Returns an empty list when the log file does not exist (torchrun was
        never invoked) so callers fail on an assertion, not FileNotFoundError.
        """
        if not log_file.exists():
            return []
        calls, current = [], []
        for line in log_file.read_text().splitlines():
            if line == cls._END_MARKER:
                if current:
                    calls.append(current)
                current = []
            else:
                current.append(line)
        return calls

    def test_multinode_dispatches_inter_pod_torchrun(self, tmp_path):
        """NNODES=2 must trigger three torchrun rounds with multi-node flags and tier 0.

        The expected ``--num-devices 32`` presumably relies on the script's
        default per-node process count being 16 (NPROC is left unset).
        """
        proc, log_file = self._run_script(
            tmp_path,
            {
                "NNODES": "2",
                "NODE_RANK": "1",
                "MASTER_ADDR": "127.0.0.1",
                "QUICK": "1",
            },
            self._argv_recording_stub,
        )
        assert proc.returncode == 0, proc.stderr
        calls = self._recorded_calls(log_file)
        assert len(calls) == 3, f"expected 3 inter-pod rounds, got {len(calls)}"
        for argv in calls:
            assert "--nnodes=2" in argv
            assert "--node_rank=1" in argv
            assert "--master_addr=127.0.0.1" in argv
            assert "--topology-tier" in argv
            assert argv[argv.index("--topology-tier") + 1] == "0"
            assert "--num-devices" in argv
            assert argv[argv.index("--num-devices") + 1] == "32"

    def test_multinode_aborts_when_world_size_below_min_group(self, tmp_path):
        """world_size < 32 yields an empty ND_LIST; the script must abort.

        Regression for the case NPROC=1, NNODES=2 (world_size=2): the
        reachable-group loop starts at 32, so ND_LIST is empty and a bare
        ``--num-devices`` would otherwise be forwarded to torchrun. The
        guard must exit non-zero with a clear error and never call torchrun.
        """
        proc, log_file = self._run_script(
            tmp_path,
            {
                "NNODES": "2",
                "NODE_RANK": "0",
                "MASTER_ADDR": "127.0.0.1",
                "NPROC": "1",
            },
            self._call_marker_stub,
        )
        assert proc.returncode == 1, proc.stdout + proc.stderr
        assert "empty device list" in proc.stderr
        assert "WORLD_SIZE=2" in proc.stderr
        assert not log_file.exists(), "torchrun was called despite empty ND_LIST"

    def test_multinode_nd_list_scales_past_legacy_ceiling(self, tmp_path):
        """ND_LIST is generated dynamically up to WORLD_SIZE, not capped at 768.

        Regression for the hardcoded ``32 64 128 256 384 512 768`` sequence:
        a 1024-rank cluster must collect a 1024 group, and the list forwarded
        to ``--num-devices`` must be ascending, unique, and bounded by
        WORLD_SIZE (no value above it).
        """
        proc, log_file = self._run_script(
            tmp_path,
            {
                "NNODES": "64",
                "NODE_RANK": "0",
                "MASTER_ADDR": "127.0.0.1",
                "QUICK": "1",
            },
            self._argv_recording_stub,
        )
        assert proc.returncode == 0, proc.stderr
        calls = self._recorded_calls(log_file)
        assert calls, "torchrun was never called"
        # The device list is everything between the two flags in the first call.
        argv = calls[0]
        assert "--num-devices" in argv, argv
        assert "--topology-tier" in argv, argv
        start = argv.index("--num-devices") + 1
        end = argv.index("--topology-tier")
        nd = [int(x) for x in argv[start:end]]
        assert nd == [32, 64, 128, 256, 384, 512, 768, 1024], nd
        assert max(nd) == 1024
        assert all(v <= 1024 for v in nd)
        assert nd == sorted(nd)
        assert len(nd) == len(set(nd))

    def test_multinode_port_is_base_plus_round_index(self, tmp_path):
        """Each round's --master_port is MASTER_PORT base + round index (1..3).

        Regression for the fragile ``MASTER_PORT=$((MASTER_PORT+1))`` outer-var
        mutation: ports must be derived from base + idx so the dispatch is
        subshell-safe and the three rounds use distinct, ordered ports.
        """
        proc, log_file = self._run_script(
            tmp_path,
            {
                "NNODES": "2",
                "NODE_RANK": "0",
                "MASTER_ADDR": "127.0.0.1",
                "MASTER_PORT": "30000",
                "QUICK": "1",
            },
            self._argv_recording_stub,
        )
        assert proc.returncode == 0, proc.stderr
        calls = self._recorded_calls(log_file)
        ports = [a.split("=", 1)[1] for c in calls for a in c if a.startswith("--master_port=")]
        assert ports == ["30001", "30002", "30003"], ports