import torch
import torch.distributed as dist
from torch.testing._internal.common_fsdp import FSDPTest
import torch_npu
# Overwrite module-level settings on torch.testing._internal.common_fsdp at
# import time so they name the NPU device type and the HCCL backend.
# NOTE(review): presumably the helpers in common_fsdp read these globals when
# picking a device/backend — confirm against the installed torch version.
torch.testing._internal.common_fsdp.DEVICE_TYPE = "npu"
torch.testing._internal.common_fsdp.DISTRIBUTED_BACKEND = "hccl"
# Evaluated once, when this module is imported. torch.npu is presumably
# registered by the torch_npu import above — TODO confirm.
torch.testing._internal.common_fsdp.DEVICE_COUNT = torch.npu.device_count()
class FSDPNPUTest(FSDPTest):
    """FSDPTest variant whose ``_run`` uses the ``hccl`` backend and NPU devices."""

    @classmethod
    def _run(cls, rank, test_name, file_name, pipe, **kwargs):
        """Initialize the process group for one rank, run the test, and tear down.

        Args:
            rank: Rank of this process; stored on the test instance and passed
                to ``dist.init_process_group``.
            test_name: Name of the test; used to construct the test instance
                and forwarded to ``run_test``.
            file_name: Stored on the test instance as ``file_name``.
            pipe: Forwarded unchanged to ``run_test``.
            **kwargs: Only ``fake_pg`` is read. When truthy, the process group
                is created with the ``"fake"`` backend and a ``FakeStore``
                instead of ``"hccl"``.

        Raises:
            RuntimeError: Re-raised from ``dist.init_process_group`` unless its
                message contains ``"recompile"``.
            SystemExit: With ``TEST_SKIPS["backend_unavailable"].exit_code``
                when initialization fails with a ``RuntimeError`` whose
                message contains ``"recompile"``.
        """
        # Imported here because the file's top-level imports provide neither
        # name; without them the skip path below fails with NameError.
        import sys

        from torch.testing._internal.common_distributed import TEST_SKIPS

        self = cls(test_name)
        self.rank = rank
        self.file_name = file_name
        fake_pg = kwargs.get("fake_pg", False)

        # NOTE(review): world_size / init_method are not set here; presumably
        # they are provided by the FSDPTest base class — confirm.
        print(f"dist init r={self.rank}, world={self.world_size}")

        try:
            if fake_pg:
                # Import explicitly rather than relying on the
                # torch.testing._internal.distributed.fake_pg attribute chain
                # having been loaded by some other module.
                from torch.testing._internal.distributed.fake_pg import FakeStore

                dist.init_process_group(
                    backend="fake",
                    world_size=self.world_size,
                    rank=rank,
                    store=FakeStore(),
                )
            else:
                dist.init_process_group(
                    init_method=self.init_method,
                    backend="hccl",
                    world_size=int(self.world_size),
                    rank=self.rank,
                )
        except RuntimeError as e:
            # str(e) instead of e.args[0]: an exception with empty or
            # non-string args must not raise a secondary IndexError/TypeError
            # that hides the original failure.
            if "recompile" in str(e):
                sys.exit(TEST_SKIPS["backend_unavailable"].exit_code)
            raise

        # Map the rank onto one of the NPUs visible to this process.
        device_id = self.rank % torch.npu.device_count()
        torch.npu.set_device(device_id)
        device_ids = [device_id]

        # Barrier before the test so every rank has finished initialization
        # before any of them proceeds.
        dist.barrier(device_ids=device_ids)

        torch._dynamo.reset()
        self.run_test(test_name, pipe)
        torch._dynamo.reset()

        # Barrier after the test so all ranks finish before the process group
        # is destroyed.
        dist.barrier(device_ids=device_ids)
        dist.destroy_process_group()