import unittest
from unittest.mock import patch
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_distributed import skipIfUnsupportMultiNPU
class HcclSendRecvDistTest(TestCase):
    """Point-to-point send/recv tests for the HCCL backend.

    Each test spawns one process per rank. In every worker rank 1 is the
    sender and rank 0 the receiver. Both report their tensor back to the
    parent over a queue so the parent can compare what was sent with what
    was received.
    """

    world_size = 2

    @classmethod
    def _setup_master_env(cls):
        """Export the rendezvous / HCCL environment shared by the init helpers."""
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29500'
        os.environ['HCCL_WHITELIST_DISABLE'] = '1'

    @classmethod
    def _init_dist_hccl(cls, rank, world_size):
        """Bind ``rank`` to its NPU, join the default HCCL group, return ``dist``.

        The returned ``torch.distributed`` module is used by the workers as
        ``pg.send(tensor, dst)`` / ``pg.recv(tensor, src)`` on the default group.
        """
        cls._setup_master_env()
        torch_npu.npu.set_device(rank)
        dist.init_process_group(backend='hccl', world_size=world_size, rank=rank)
        return dist

    @classmethod
    def _init_pg_hccl(cls, rank, world_size):
        """Initialise the default HCCL group and return a new group of all ranks.

        With the default ``world_size`` of 2 this is the group ``[0, 1]``. The
        returned object is used as ``pg.send([tensor], dst, tag).wait()``.
        """
        cls._init_dist_hccl(rank, world_size)
        return dist.new_group(list(range(world_size)))

    @classmethod
    def _send_recv_and_report(cls, pg, rank, xs, res, c2p, src=1, dst=0):
        """Send ``xs`` from ``src`` to ``dst`` through ``pg`` and report to the parent.

        The sender puts ``(dst, xs)`` and the receiver puts ``(src, res)`` on
        ``c2p``. With the defaults, key 0 carries the sent tensor and key 1 the
        received one, which is what ``_test_multiprocess`` compares. Ranks other
        than ``src``/``dst`` do nothing.
        """
        if rank == src:
            pg.send(xs, dst)
            c2p.put((dst, xs.cpu()))
        elif rank == dst:
            pg.recv(res, src)
            c2p.put((src, res.cpu()))

    @classmethod
    def _group_send_recv_and_report(cls, pg, rank, xs, res, c2p, tag=1, src=1, dst=0):
        """Process-group variant of ``_send_recv_and_report``.

        ``xs`` and ``res`` are single-element tensor lists. ``send``/``recv``
        return a work handle that is waited on before reporting element 0.
        """
        if rank == src:
            pg.send(xs, dst, tag).wait()
            c2p.put((dst, xs[0].cpu()))
        elif rank == dst:
            pg.recv(res, src, tag).wait()
            c2p.put((src, res[0].cpu()))

    @classmethod
    def _test_send_recv_dist(
            cls, rank, shared_tensors, world_size, init_pg, c2p, p2c):
        """Worker: plain send/recv of ``shared_tensors`` into a 2x2 buffer."""
        pg = init_pg(rank, world_size)
        res = torch.ones(2, 2).to(shared_tensors.dtype).to(f"npu:{rank}")
        xs = shared_tensors.to(f"npu:{rank}")
        cls._send_recv_and_report(pg, rank, xs, res, c2p)

    @classmethod
    def _test_send_recv_dist_with_internal_format_and_offset(cls, rank, shared_tensors, world_size, init_pg):
        """Worker: send/recv must reject internal-format views with a storage offset.

        Both ranks double the leading dimension, cast to format 29 (an internal
        NPU format; presumably FRACTAL_NZ), and slice off the first half. The
        resulting view has a non-zero storage offset, so ``send``/``recv`` is
        expected to raise ``RuntimeError``. An uncaught failure makes the
        process exit abnormally, which the parent detects.
        """
        torch_npu.npu.config.allow_internal_format = True
        pg = init_pg(rank, world_size)
        first_dim = shared_tensors.shape[0]
        # ``repeat`` takes per-dimension multipliers: duplicate only dim 0 and
        # keep every trailing dimension at its original size (factor 1).
        repeats = (2,) + (1,) * (shared_tensors.dim() - 1)
        res = torch.ones_like(shared_tensors).repeat(*repeats).to(f"npu:{rank}")
        res = torch_npu.npu_format_cast(res, 29)[first_dim:]
        xs = shared_tensors.repeat(*repeats).to(f"npu:{rank}")
        xs = torch_npu.npu_format_cast(xs, 29)[first_dim:]
        dst = 0
        src = 1
        # A throwaway TestCase instance gives access to assertRaisesRegex
        # inside this classmethod, which runs in a spawned process.
        test_case = TestCase()
        error_expect = "For a tensor of internal format, it's storage_offset must be 0"
        with test_case.assertRaisesRegex(RuntimeError, error_expect):
            if src == rank:
                pg.send(xs, dst)
            else:
                pg.recv(res, src)

    @classmethod
    def _test_send_recv_dist_with_p2p(
            cls, rank, shared_tensors, world_size, init_pg, c2p, p2c):
        """Worker: like ``_test_send_recv_dist``; ranks other than 0/1 stay idle."""
        pg = init_pg(rank, world_size)
        res = torch.ones(2, 2).to(shared_tensors.dtype).to(f"npu:{rank}")
        xs = shared_tensors.to(f"npu:{rank}")
        cls._send_recv_and_report(pg, rank, xs, res, c2p)

    @classmethod
    def _test_send_recv_dist_recv_with_format(
            cls, rank, shared_tensors, world_size, init_pg, c2p, p2c):
        """Worker: the receive buffer is cast to format 29 before ``recv``."""
        pg = init_pg(rank, world_size)
        res = torch_npu.npu_format_cast(torch.ones(2, 2).to(shared_tensors.dtype).to(f"npu:{rank}"), 29)
        xs = shared_tensors.to(f"npu:{rank}")
        cls._send_recv_and_report(pg, rank, xs, res, c2p)

    @classmethod
    def _test_send_recv_dist_send_with_format(
            cls, rank, shared_tensors, world_size, init_pg, c2p, p2c):
        """Worker: the sent tensor is cast to format 29 before ``send``."""
        pg = init_pg(rank, world_size)
        res = torch.ones(2, 2).to(shared_tensors.dtype).to(f"npu:{rank}")
        xs = torch_npu.npu_format_cast(shared_tensors.to(f"npu:{rank}"), 29)
        cls._send_recv_and_report(pg, rank, xs, res, c2p)

    @classmethod
    def _test_send_recv_group(
            cls, rank, shared_tensors, world_size, init_pg, c2p, p2c):
        """Worker: process-group send/recv with tensor lists and a tag."""
        pg = init_pg(rank, world_size)
        res = [torch.ones(2, 2).to(f"npu:{rank}")]
        xs = [shared_tensors.to(f"npu:{rank}")]
        cls._group_send_recv_and_report(pg, rank, xs, res, c2p)

    @classmethod
    def _test_send_recv_group_recv_with_format(
            cls, rank, shared_tensors, world_size, init_pg, c2p, p2c):
        """Worker: process-group variant with a format-29 receive buffer."""
        pg = init_pg(rank, world_size)
        res = [torch_npu.npu_format_cast(torch.ones(2, 2).to(f"npu:{rank}"), 29)]
        xs = [shared_tensors.to(f"npu:{rank}")]
        cls._group_send_recv_and_report(pg, rank, xs, res, c2p)

    @classmethod
    def _test_send_recv_group_send_with_format(
            cls, rank, shared_tensors, world_size, init_pg, c2p, p2c):
        """Worker: process-group variant with a format-29 sent tensor."""
        pg = init_pg(rank, world_size)
        res = [torch.ones(2, 2).to(f"npu:{rank}")]
        xs = [torch_npu.npu_format_cast(shared_tensors.to(f"npu:{rank}"), 29)]
        cls._group_send_recv_and_report(pg, rank, xs, res, c2p)

    def _test_multiprocess(self, f, shared_tensors, init_pg, ws=0):
        """Run worker ``f`` on ``ws`` spawned ranks and compare sent vs received.

        Args:
            f: worker called as ``f(rank, shared_tensors, ws, init_pg, c2p, p2c)``.
            shared_tensors: CPU tensor handed to every worker.
            init_pg: init helper the worker uses to obtain its send/recv handle.
            ws: number of ranks to spawn; ``0`` means ``self.world_size``.

        Exactly two messages are read from ``c2p``. Key 0 is the tensor the
        sender sent and any other key is the tensor rank 0 received. All ranks
        are then joined and must exit with code 0.
        """
        if not ws:
            ws = self.world_size
        ctx = mp.get_context('spawn')
        c2p = ctx.Queue(2)
        p2c = ctx.Queue(2)
        ps = []
        for i in range(ws):
            p = ctx.Process(
                target=f,
                args=(i, shared_tensors, ws, init_pg, c2p, p2c))
            p.start()
            ps.append(p)

        expected = 0
        result = 1
        for _ in range(2):
            key, output = c2p.get()
            if key == 0:
                expected = output
            else:
                result = output

        # No worker in this class reads ``p2c``; these writes are harmless.
        for _ in range(2):
            p2c.put(0)

        # Join without a timeout so no rank outlives this call (a lingering rank
        # could still be holding the fixed MASTER_PORT for the next call) and so
        # every exit code below is final.
        for p in ps:
            p.join()

        self.assertEqual(
            expected,
            result,
            "Expect rank 0 to receive tensor {} but got {}.".format(expected, result),
        )
        for p in ps:
            self.assertEqual(p.exitcode, 0, "subprocess exit with abnormal code.")

    def _test_multiprocess_with_error(self, f, shared_tensors, init_pg, ws=0):
        """Run worker ``f(rank, shared_tensors, ws, init_pg)`` on ``ws`` ranks.

        The worker performs its own assertions. This helper only requires every
        spawned process to exit with code 0. ``ws == 0`` means
        ``self.world_size``.
        """
        if not ws:
            ws = self.world_size
        ctx = mp.get_context('spawn')
        ps = []
        for i in range(ws):
            p = ctx.Process(
                target=f,
                args=(i, shared_tensors, ws, init_pg))
            p.start()
            ps.append(p)
        for p in ps:
            p.join()
            self.assertEqual(p.exitcode, 0, "subprocess exit with abnormal code.")

    @skipIfUnsupportMultiNPU(2)
    def test_send_recv_hccl_dist(self):
        self._test_multiprocess(
            HcclSendRecvDistTest._test_send_recv_dist,
            torch.randn(2, 2),
            HcclSendRecvDistTest._init_dist_hccl)

    @skipIfUnsupportMultiNPU(2)
    def test_send_recv_hccl_dist_with_internal_format_and_offset(self):
        self._test_multiprocess_with_error(
            HcclSendRecvDistTest._test_send_recv_dist_with_internal_format_and_offset,
            torch.randn(31, 31),
            HcclSendRecvDistTest._init_dist_hccl)

    @skipIfUnsupportMultiNPU(4)
    def test_send_recv_hccl_dist_with_p2p(self):
        # Spawned children inherit the patched environment at start time.
        with patch.dict(os.environ, {"P2P_HCCL_BUFFSIZE": "20"}):
            world_size = 4
            self._test_multiprocess(
                HcclSendRecvDistTest._test_send_recv_dist_with_p2p,
                torch.randn(2, 2),
                HcclSendRecvDistTest._init_dist_hccl,
                world_size)

    @skipIfUnsupportMultiNPU(2)
    def test_send_recv_hccl_dist_with_format(self):
        self._test_multiprocess(
            HcclSendRecvDistTest._test_send_recv_dist_recv_with_format,
            torch.randn(2, 2),
            HcclSendRecvDistTest._init_dist_hccl)
        self._test_multiprocess(
            HcclSendRecvDistTest._test_send_recv_dist_send_with_format,
            torch.randn(2, 2),
            HcclSendRecvDistTest._init_dist_hccl)

    @skipIfUnsupportMultiNPU(2)
    def test_send_recv_hccl_group(self):
        self._test_multiprocess(
            HcclSendRecvDistTest._test_send_recv_group,
            torch.randn(2, 2),
            HcclSendRecvDistTest._init_pg_hccl)

    @skipIfUnsupportMultiNPU(2)
    def test_send_recv_hccl_group_with_format(self):
        self._test_multiprocess(
            HcclSendRecvDistTest._test_send_recv_group_recv_with_format,
            torch.randn(2, 2),
            HcclSendRecvDistTest._init_pg_hccl)
        self._test_multiprocess(
            HcclSendRecvDistTest._test_send_recv_group_send_with_format,
            torch.randn(2, 2),
            HcclSendRecvDistTest._init_pg_hccl)

    @skipIfUnsupportMultiNPU(2)
    def test_send_recv_hccl_bool(self):
        self._test_multiprocess(
            HcclSendRecvDistTest._test_send_recv_dist,
            torch.randn(2, 2) > 0.5,
            HcclSendRecvDistTest._init_dist_hccl)
# Script entry point: hand control to the torch_npu test runner.
if __name__ == "__main__":
    run_tests()