import os
import tempfile
import pytest
def _reset_parallel_state_for_worker():
"""Reset global/singleton state inside spawned worker process.
pytest-xdist/torch.multiprocessing may execute multiple tests per worker.
Resetting avoids leaking a previously-created device mesh/parallelstate
"""
import mindspeed_mm.fsdp.distributed.parallel_state as ps_mod
ps_mod.get_device_type = lambda: "npu"
ps_mod._PARALLEL_STATE = None
from mindspeed_mm.fsdp.utils.decorators import Singleton
Singleton._instances = {}
def _init_pg(rank: int, world_size: int, init_file: str):
import torch.distributed as dist
import torch
if hasattr(torch, "npu"):
torch.npu.set_device(rank)
dist.init_process_group(
backend="hccl",
init_method=f"file://{init_file}",
rank=rank,
world_size=world_size
)
def _destroy_pg():
import torch.distributed as dist
if dist.is_initialized():
dist.destroy_process_group()
def _worker_l2_clip(rank: int, world_size: int, init_file: str):
pytest.importorskip("torch")
import torch
import torch.distributed as dist
try:
_init_pg(rank, world_size, init_file)
_reset_parallel_state_for_worker()
import mindspeed_mm.fsdp.distributed.parallel_state as ps_mod
ps_mod.init_parallel_state(
data_parallel_size=world_size,
fully_shard_parallel_size=1,
tensor_parallel_size=1,
ring_attention_size=1,
ulysses_parallel_size=1,
expert_parallel_size=1,
expert_fully_shard_parallel_size=1
)
import mindspeed_mm.fsdp.optimizer.clip_grad_norm as mod
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.w = torch.nn.Parameter(torch.zeros(1))
self.b = torch.nn.Parameter(torch.zeros(1))
m = M()
m = m.to(torch.device("npu", rank))
if rank == 0:
m.w.grad = torch.tensor([3.0], device=torch.device("npu", rank))
m.b.grad = torch.tensor([4.0], device=torch.device("npu", rank))
else:
m.w.grad = torch.tensor([0.0], device=torch.device("npu", rank))
m.b.grad = torch.tensor([0.0], device=torch.device("npu", rank))
returned = mod.clip_grad_norm(m, max_norm=1.0, norm_type=2.0)
expected = torch.tensor(5.0, device=torch.device("npu", rank))
gathered = [torch.zeros_like(returned) for _ in range(world_size)]
dist.all_gather(gathered, returned)
assert all(torch.allclose(x, expected, atol=1e-6) for x in gathered)
if rank == 0:
assert torch.allclose(m.w.grad, torch.tensor([0.6], device=torch.device("npu", rank)), atol=1e-6)
assert torch.allclose(m.b.grad, torch.tensor([0.8], device=torch.device("npu", rank)), atol=1e-6)
else:
assert torch.allclose(m.w.grad, torch.tensor([0.0], device=torch.device("npu", rank)), atol=1e-6)
assert torch.allclose(m.b.grad, torch.tensor([0.0], device=torch.device("npu", rank)), atol=1e-6)
finally:
_destroy_pg()
def _worker_inf_norm(rank: int, world_size: int, init_file: str):
pytest.importorskip("torch")
import torch
import torch.distributed as dist
try:
_init_pg(rank, world_size, init_file)
_reset_parallel_state_for_worker()
import mindspeed_mm.fsdp.distributed.parallel_state as ps_mod
ps_mod.init_parallel_state(
data_parallel_size=world_size,
fully_shard_parallel_size=1,
tensor_parallel_size=1,
ring_attention_size=1,
ulysses_parallel_size=1,
expert_parallel_size=1,
expert_fully_shard_parallel_size=1
)
import mindspeed_mm.fsdp.optimizer.clip_grad_norm as mod
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.p = torch.nn.Parameter(torch.zeros(4))
m = M()
m = m.to(torch.device("npu", rank))
if rank == 0:
m.p.grad = torch.full((4,), 2.0, device=torch.device("npu", rank))
else:
m.p.grad = torch.full((4,), 5.0, device=torch.device("npu", rank))
before = m.p.grad.clone()
returned = mod.clip_grad_norm(m, max_norm=0.0, norm_type=float("inf"))
gathered = [torch.zeros_like(returned) for _ in range(world_size)]
dist.all_gather(gathered, returned)
assert all(torch.allclose(x, torch.tensor(5.0, device=torch.device("npu", rank)), atol=1e-6) for x in gathered)
assert torch.allclose(m.p.grad, before)
finally:
_destroy_pg()
def _worker_ep_path(rank: int, world_size: int, init_file: str):
pytest.importorskip("torch")
import torch
import torch.distributed as dist
try:
_init_pg(rank, world_size, init_file)
_reset_parallel_state_for_worker()
import mindspeed_mm.fsdp.distributed.parallel_state as ps_mod
ps = ps_mod.init_parallel_state(
data_parallel_size=world_size,
fully_shard_parallel_size=1,
tensor_parallel_size=1,
ring_attention_size=1,
ulysses_parallel_size=1,
expert_parallel_size=world_size,
expert_fully_shard_parallel_size=1,
)
assert ps.is_ep_enable() is True
import mindspeed_mm.fsdp.optimizer.clip_grad_norm as mod
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.ep = torch.nn.Parameter(torch.zeros(1))
self.non_ep = torch.nn.Parameter(torch.zeros(1))
self._ep_param_groups = {"ep": {self.ep}, "non_ep": {self.non_ep}}
m = M()
m = m.to(torch.device("npu", rank))
if rank == 0:
m.ep.grad = torch.tensor([3.0], device=torch.device("npu", rank))
m.non_ep.grad = torch.tensor([4.0], device=torch.device("npu", rank))
else:
m.ep.grad = torch.tensor([0.0], device=torch.device("npu", rank))
m.non_ep.grad = torch.tensor([0.0], device=torch.device("npu", rank))
returned = mod.clip_grad_norm(m, max_norm=1.0, norm_type=2.0)
gathered = [torch.zeros_like(returned) for _ in range(world_size)]
dist.all_gather(gathered, returned)
assert all(torch.allclose(x, torch.tensor(5.0, device=torch.device("npu", rank)), atol=1e-6) for x in gathered)
if rank == 0:
assert torch.allclose(m.ep.grad, torch.tensor([0.6], device=torch.device("npu", rank)), atol=1e-6)
assert torch.allclose(m.non_ep.grad, torch.tensor([0.8], device=torch.device("npu", rank)), atol=1e-6)
else:
assert torch.allclose(m.ep.grad, torch.tensor([0.0], device=torch.device("npu", rank)), atol=1e-6)
assert torch.allclose(m.non_ep.grad, torch.tensor([0.0], device=torch.device("npu", rank)), atol=1e-6)
finally:
_destroy_pg()
@pytest.mark.parametrize(
"worker",
[_worker_l2_clip, _worker_inf_norm, _worker_ep_path],
)
def test_clip_grad_norm_multi_process(worker):
""" verify real gloo all_reduce paths across 2 ranks
This test convers:
- non-EP path (FSDP group SUM reduction)
- inf-norm path (MAX reduction)
- EP-aware path (efsdp then ep reductions; shared clip coefficient)
"""
pytest.importorskip("torch")
import torch
import torch.multiprocessing as mp
if not hasattr(torch, "npu") or torch.npu.device_count() < 2:
pytest.skip("需要至少2卡NPU才能运行该用例")
world_size = 2
with tempfile.NamedTemporaryFile(delete=False) as f:
init_file = f.name
try:
mp.spawn(worker, args=(world_size, init_file), nprocs=world_size, join=True)
finally:
try:
os.remove(init_file)
except OSError:
pass