"""Distributed test framework for Ascend NPU tests."""
from __future__ import annotations
import inspect
import os
import time
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from _pytest.outcomes import Skipped
from _pytest.fixtures import FixtureLookupError
from utils import npu_available, npu_device_count, get_free_port
def _resolve_world_size(world_size_attr, fixture_kwargs):
"""Return the concrete world_size from a class attribute or fixture kwargs."""
if world_size_attr is not None:
return world_size_attr
if "world_size" in fixture_kwargs:
return fixture_kwargs["world_size"]
if "sharding_dims" in fixture_kwargs:
return fixture_kwargs["sharding_dims"][0]
return 2
class DistributedTest:
    """Base class for distributed tests.

    Usage::

        class TestExample(DistributedTest):
            world_size = 4

            def test_something(self):
                rank = int(os.environ["LOCAL_RANK"])
                ...

    Features:
        * ``@pytest.mark.parametrize`` on test methods.
        * ``@pytest.mark.world_size(N)`` to override *world_size* per test.
        * ``class_tmpdir`` — a per-class temporary directory shared across ranks.

    Attributes:
        is_dist_test: Always ``True`` here; presumably read by collection
            hooks outside this file to recognise distributed tests
            (TODO confirm against conftest).
        backend: Distributed backend (default ``"hccl"``).
        timeout: Seconds before a worker invocation is considered hung.
        world_size: Default world size. May be an ``int`` or an iterable of
            ints (the test is then launched once per entry). Set to ``None``
            when the world size should be resolved dynamically from fixture
            kwargs (e.g. a ``sharding_dims`` parametrization).
        reuse_dist_env: If ``True``, the process pool is kept in
            ``_pool_cache`` and reused by later launches with the same
            process count.
    """

    is_dist_test: bool = True
    world_size: int | None = 2
    backend: str = "hccl"
    timeout: int = 600
    reuse_dist_env: bool = True
    # Live pools keyed by process count. This dict is defined on the base
    # class, so every subclass that does not override it shares the same
    # cache. NOTE(review): nothing in this file evicts entries after a class
    # finishes; presumably a teardown hook elsewhere does - confirm.
    _pool_cache: dict[int, mp.Pool] = {}

    @pytest.fixture(autouse=True, scope="class")
    def class_tmpdir(self, tmpdir_factory):
        """Per-class temporary directory shared across all ranks."""
        return tmpdir_factory.mktemp(self.__class__.__name__)

    def run(self, **fixture_kwargs):
        """Execute the intercepted test method with resolved fixtures.

        Subclasses may override this when they need custom pre/post
        logic around the test call.

        Args:
            **fixture_kwargs: Fixture values forwarded to the test method
                named by ``self._current_test_name``.
        """
        getattr(self, self._current_test_name)(**fixture_kwargs)

    def __call__(self, request=None):
        """Resolve fixtures and world size, then launch the current test.

        The test is launched once for every entry of the resolved world
        size (a single ``int`` is treated as a one-element list), with a
        short pause after each launch.

        Args:
            request: The pytest request object of the test being run. It
                must expose ``function`` and ``getfixturevalue``. When
                omitted, a previously assigned ``_current_test_name`` is
                reused and no fixtures or marks are consulted.

        Raises:
            ValueError: If ``request`` is ``None`` and no test name has
                been set on the instance beforehand.
        """
        if request is not None:
            self._current_test_name = request.function.__name__
        elif not hasattr(self, "_current_test_name"):
            # Without a request there is no way to learn which test to run.
            raise ValueError(
                "DistributedTest needs a pytest request (or a preset "
                "_current_test_name) to know which test to run"
            )
        test_func = getattr(self, self._current_test_name)
        self._fixture_kwargs = self._get_fixture_kwargs(request, test_func)

        if not npu_available():
            pytest.skip("only supported in accelerator environments.")

        world_size = _resolve_world_size(self.world_size, self._fixture_kwargs)

        # A ``world_size`` mark placed directly on the test function wins
        # over the class attribute and the fixture kwargs.
        marks = [] if request is None else getattr(request.function, "pytestmark", [])
        for mark in marks:
            if mark.name == "world_size":
                world_size = mark.args[0]
                break

        if isinstance(world_size, int):
            world_size = [world_size]

        for procs in world_size:
            self._launch_procs(procs)
            time.sleep(0.5)

    def _get_fixture_kwargs(self, request, func):
        """Collect the fixture values *func* asks for by parameter name.

        Args:
            request: pytest request used to look fixtures up; a falsy
                value yields an empty dict.
            func: The test callable whose positional parameter names are
                inspected.

        Returns:
            Mapping of parameter name to fixture value. Parameters for
            which pytest knows no fixture are left out.
        """
        if not request:
            return {}
        # Filter instead of ``list.remove`` so a callable without a ``self``
        # parameter (e.g. a staticmethod) does not raise ValueError.
        params = [p for p in inspect.getfullargspec(func).args if p != "self"]
        fixture_kwargs = {}
        for p in params:
            try:
                fixture_kwargs[p] = request.getfixturevalue(p)
            except FixtureLookupError:
                # Deliberate best effort: not every parameter is a fixture.
                pass
        return fixture_kwargs

    def _launch_procs(self, num_procs):
        """Run the current test on *num_procs* worker processes.

        Skips the test when fewer than *num_procs* NPUs are available or
        when any worker reported a skip. Exits the whole pytest session
        when the workers do not finish within ``self.timeout`` seconds.
        Exceptions raised inside a worker propagate to the caller.

        Args:
            num_procs: Number of worker processes (ranks) to use.
        """
        if npu_device_count() < num_procs:
            pytest.skip(f"Not enough NPUs: {num_procs} required, {npu_device_count()} available")

        mp.set_start_method("forkserver", force=True)

        if self.reuse_dist_env:
            if num_procs not in self._pool_cache:
                self._pool_cache[num_procs] = mp.Pool(processes=num_procs)
            pool = self._pool_cache[num_procs]
        else:
            pool = mp.Pool(processes=num_procs)

        master_port = get_free_port()
        args = [(local_rank, num_procs, master_port) for local_rank in range(num_procs)]
        skip_msgs_async = pool.starmap_async(self._dist_run, args)

        try:
            skip_msgs = skip_msgs_async.get(self.timeout)
        except mp.TimeoutError:
            # The workers are hung, so the graceful teardown (which needs
            # them to run a barrier) cannot be used. Kill the pool and drop
            # it from the cache so nothing tries to reuse or close it later.
            self._abandon_pool(pool, num_procs)
            pytest.exit("Distributed test timed out, exiting", returncode=1)
        except BaseException:
            # A worker failed. A pool that is not cached for reuse would
            # otherwise never be closed; terminate it rather than risk
            # blocking in a collective teardown after a failure.
            if not self.reuse_dist_env:
                self._abandon_pool(pool, num_procs)
            raise

        self._close_pool(pool, num_procs)

        non_empty = [m for m in skip_msgs if m]
        if non_empty:
            assert len(set(non_empty)) == 1, (
                f"Multiple different skip messages received: {non_empty}"
            )
            pytest.skip(non_empty[0])

    def _abandon_pool(self, pool, num_procs):
        """Forcefully stop *pool* and forget it, skipping graceful teardown."""
        if self._pool_cache.get(num_procs) is pool:
            del self._pool_cache[num_procs]
        pool.terminate()
        pool.join()

    def _dist_run(self, local_rank, num_procs, master_port):
        """Initialize the distributed environment and execute the user function.

        Runs inside a worker process.

        Args:
            local_rank: Rank of this worker; also used as the global rank.
            num_procs: Total number of ranks in the group.
            master_port: TCP port on 127.0.0.1 used for rendezvous.

        Returns:
            An empty string when the test ran to completion, otherwise the
            (non-empty) skip reason raised by the test.
        """
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = str(master_port)
        os.environ["LOCAL_RANK"] = str(local_rank)
        os.environ["RANK"] = str(local_rank)
        os.environ["WORLD_SIZE"] = str(num_procs)
        os.environ["LOCAL_SIZE"] = str(num_procs)

        torch.npu.set_device(local_rank % torch.npu.device_count())

        # A reused worker already has a process group from an earlier test.
        if not dist.is_initialized():
            dist.init_process_group(
                backend=self.backend,
                world_size=num_procs,
                rank=local_rank,
                init_method=f"tcp://127.0.0.1:{master_port}",
            )

        skip_msg = ""
        try:
            self.run(**self._fixture_kwargs)
        except Skipped as skipped:
            # The parent only treats non-empty messages as skips, so a skip
            # raised without a reason must still produce a truthy value.
            skip_msg = getattr(skipped, "msg", "") or "Skipped in distributed worker"
        return skip_msg

    def _dist_destroy(self):
        """Synchronize all ranks and tear down the process group, if any."""
        if dist.is_initialized():
            dist.barrier()
            dist.destroy_process_group()

    def _close_pool(self, pool, num_procs, force=False):
        """Gracefully shut *pool* down unless it is being kept for reuse.

        When ``force`` is true or ``reuse_dist_env`` is false, every worker
        destroys its process group and the pool is closed and joined.
        Otherwise this is a no-op.

        Note that the pool is not removed from ``_pool_cache`` here; a
        caller that force-closes a cached pool must evict the entry itself.

        Args:
            pool: The pool to shut down.
            num_procs: Number of workers in *pool*.
            force: Shut down even when ``reuse_dist_env`` is true.
        """
        if force or not self.reuse_dist_env:
            pool.starmap(self._dist_destroy, [() for _ in range(num_procs)])
            pool.close()
            pool.join()