"""
Distributed testing utilities for smoke tests.
This module provides utilities for distributed testing following
PyTorch and torchtitan patterns.
Key patterns from PyTorch/torchtitan:
1. DTensorTestBase - Base class for distributed tests
2. @with_comms decorator - Sets up process groups
3. world_size property - Specify number of ranks
Usage:
# For single-rank tests (no distributed needed)
def test_something(npu_device):
model = nn.Linear(64, 64).to(npu_device)
...
# For multi-rank distributed tests
class TestDistributed(DistributedTestBase):
@property
def world_size(self):
return 4
@with_comms
def test_dtensor(self):
mesh = init_device_mesh("npu", (4,))
...
"""
import functools
import os
import unittest
from collections.abc import Callable
from contextlib import contextmanager
import pytest
import torch
import torch.distributed as dist
class DistributedTestBase(unittest.TestCase):
    """
    Base class for distributed tests following PyTorch patterns.

    This class provides:
    - Automatic distributed environment setup/teardown
    - world_size property for specifying number of ranks
    - device_type property for backend selection

    The process group itself is sized from the ``RANK`` / ``WORLD_SIZE``
    environment variables (single-process fallback when they are absent);
    the ``world_size`` property is what tests and decorators consult.

    Similar to torch.testing._internal.distributed._tensor.common_dtensor.DTensorTestBase
    """

    @property
    def world_size(self) -> int:
        """Number of ranks for distributed tests. Override in subclasses."""
        return 1

    @property
    def device_type(self) -> str:
        """Device type for distributed tests. Returns 'npu' or 'cpu'."""
        if hasattr(torch, "npu") and torch.npu.is_available():
            return "npu"
        return "cpu"

    @property
    def backend(self) -> str:
        """Distributed backend to use: 'hccl' on NPU, 'gloo' on CPU."""
        if self.device_type == "npu":
            # NCCL is the NVIDIA GPU backend; Ascend NPUs communicate via HCCL.
            return "hccl"
        return "gloo"

    def setUp(self):
        """Set up distributed environment."""
        super().setUp()
        self.ensure_distributed_initialized()

    def tearDown(self):
        """Tear down distributed environment."""
        try:
            if dist.is_initialized():
                dist.destroy_process_group()
        finally:
            # Always run the base-class teardown, even if destroying the
            # process group raised.
            super().tearDown()

    def ensure_distributed_initialized(self):
        """Initialize distributed communication when it is not ready."""
        if not dist.is_initialized():
            self._init_distributed()

    def _init_distributed(self):
        """
        Initialize distributed process group.

        Rank information is taken from ``RANK`` / ``WORLD_SIZE`` (and the
        optional ``LOCAL_RANK``) when both are set; otherwise a single-process
        group (rank 0 of 1) is created.
        """
        if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
            rank = int(os.environ["RANK"])
            world_size = int(os.environ["WORLD_SIZE"])
            local_rank = int(os.environ.get("LOCAL_RANK", 0))
        else:
            rank, world_size, local_rank = 0, 1, 0

        # init_process_group is called without an init_method, i.e. with the
        # env:// rendezvous, which reads MASTER_ADDR / MASTER_PORT regardless
        # of world size. Provide defaults for every case (including the
        # single-process fallback) without overriding a launcher's values.
        os.environ.setdefault("MASTER_ADDR", "localhost")
        os.environ.setdefault("MASTER_PORT", "29500")

        # Bind this process to its device before the process group exists so
        # the group is created with the right device current.
        if self.device_type == "npu":
            torch.npu.set_device(local_rank)

        dist.init_process_group(
            backend=self.backend,
            rank=rank,
            world_size=world_size,
        )
def with_comms(func: Callable) -> Callable:
    """
    Wrap a test method so it runs with distributed communication set up.

    Modeled on torch.testing._internal.distributed._tensor.common_dtensor.with_comms

    The wrapper:
    1. Calls ``self.ensure_distributed_initialized()`` before the test body
    2. Destroys the process group afterwards, whether the test passed or raised

    Usage:
        class MyTest(DistributedTestBase):
            @property
            def world_size(self):
                return 4

            @with_comms
            def test_something(self):
                # Distributed environment is ready
                mesh = init_device_mesh(self.device_type, (self.world_size,))
    """

    @functools.wraps(func)
    def _run_with_process_group(self, *args, **kwargs):
        self.ensure_distributed_initialized()
        try:
            outcome = func(self, *args, **kwargs)
        finally:
            # Clean up only if a group is actually alive at this point.
            group_alive = dist.is_initialized()
            if group_alive:
                dist.destroy_process_group()
        return outcome

    return _run_with_process_group
def distributed_available() -> bool:
    """Report whether ``torch.distributed`` is available in this build."""
    is_available = dist.is_available()
    return is_available
def npu_available() -> bool:
    """Report whether ``torch.npu`` exists and says an NPU is available."""
    if not hasattr(torch, "npu"):
        # No NPU module is attached to torch at all.
        return False
    return torch.npu.is_available()
def skip_on_runtime_unsupported(
    error: RuntimeError, unsupported_markers: tuple[str, ...], reason: str
):
    """
    Turn a known "unsupported on this runtime/SOC" error into a test skip.

    Args:
        error: The exception caught by the test.
        unsupported_markers: Substrings that identify an unsupported op
            shape or dtype in the error message.
        reason: Skip reason passed to ``pytest.skip``.

    Raises:
        The original ``error`` when its message contains none of the markers.
    """
    error_text = str(error)
    for marker in unsupported_markers:
        if marker in error_text:
            pytest.skip(reason)
    # Not a recognised "unsupported" failure: surface it unchanged.
    raise error
@contextmanager
def distributed_context(
    rank: int,
    world_size: int,
    backend: str = "gloo",
    init_method: str = "tcp://localhost:29500",
):
    """
    Context manager for distributed testing.

    Creates the default process group on entry and destroys it on exit.

    Args:
        rank: Rank of the calling process.
        world_size: Total number of processes in the group.
        backend: Distributed backend name passed to ``init_process_group``.
        init_method: Rendezvous URL passed to ``init_process_group``.

    Yields:
        A dict with the ``"rank"`` and ``"world_size"`` that were requested.

    Usage:
        with distributed_context(rank=0, world_size=1) as ctx:
            model = nn.Linear(64, 64)
            # Run distributed code
    """
    # Initialise outside the try block: if this call fails (for example
    # because a default process group already exists), the cleanup below must
    # not run, otherwise it would destroy a group this context never created.
    dist.init_process_group(
        backend=backend,
        init_method=init_method,
        rank=rank,
        world_size=world_size,
    )
    try:
        yield {"rank": rank, "world_size": world_size}
    finally:
        # The body may already have torn the group down itself.
        if dist.is_initialized():
            dist.destroy_process_group()
def get_device_mesh(
    mesh_dims: tuple,
    mesh_dim_names: tuple | None = None,
):
    """
    Build a device mesh for distributed testing.

    The mesh is placed on "npu" when ``npu_available()`` is true, otherwise
    on "cpu".

    Args:
        mesh_dims: Tuple of mesh dimensions, e.g., (2, 4) for 2x4 mesh
        mesh_dim_names: Optional names for each dimension; omitted from the
            underlying call when empty or None

    Returns:
        DeviceMesh object
    """
    from torch.distributed.device_mesh import init_device_mesh

    if npu_available():
        target_device = "npu"
    else:
        target_device = "cpu"

    # Forward the names only when the caller actually supplied some.
    optional_kwargs = {"mesh_dim_names": mesh_dim_names} if mesh_dim_names else {}
    return init_device_mesh(target_device, mesh_dims, **optional_kwargs)
def create_dtensor(
    tensor: torch.Tensor,
    mesh,
    placements,
):
    """
    Distribute a local tensor over a device mesh.

    Thin wrapper around ``torch.distributed.tensor.distribute_tensor``.

    Args:
        tensor: Local tensor to distribute
        mesh: DeviceMesh to distribute over
        placements: Placements for distribution (e.g., [Shard(0)])

    Returns:
        DTensor
    """
    # Imported lazily so this module loads without the DTensor package.
    import torch.distributed.tensor as dtensor_api

    distributed = dtensor_api.distribute_tensor(tensor, mesh, placements)
    return distributed
def requires_distributed(func: Callable) -> Callable:
    """Skip the decorated test when ``distributed_available()`` is false."""

    @functools.wraps(func)
    def _skip_unless_distributed(*args, **kwargs):
        can_run = distributed_available()
        if not can_run:
            pytest.skip("Distributed not available")
        outcome = func(*args, **kwargs)
        return outcome

    return _skip_unless_distributed
def requires_npu(func: Callable) -> Callable:
    """Skip the decorated test when ``npu_available()`` is false."""

    @functools.wraps(func)
    def _skip_unless_npu(*args, **kwargs):
        has_npu = npu_available()
        if not has_npu:
            pytest.skip("NPU not available")
        outcome = func(*args, **kwargs)
        return outcome

    return _skip_unless_npu
def requires_world_size(min_size: int):
    """
    Build a decorator that skips a test method on too-small world sizes.

    The decorated method is skipped when ``self.world_size`` is below
    ``min_size``; otherwise it runs normally.

    Usage:
        @requires_world_size(4)
        def test_tp_4(self):
            # Requires at least 4 ranks
    """

    def _apply(func: Callable) -> Callable:
        @functools.wraps(func)
        def _checked(self, *args, **kwargs):
            if self.world_size < min_size:
                skip_reason = f"Requires world_size >= {min_size}, got {self.world_size}"
                pytest.skip(skip_reason)
            return func(self, *args, **kwargs)

        return _checked

    return _apply