# Snapshot of commit 8de4d5c0, created on 2024-12-23 (commit history).
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch_npu

from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_distributed import skipIfUnsupportMultiNPU


class HcclGatherTest(TestCase):
    """Multi-process tests for ``gather`` and ``gather_object`` over HCCL.

    Each test spawns one worker process per rank. Workers run the collective
    with destination rank 0, report their result to the parent through a
    queue, and then wait for a release token from the parent before exiting.
    """

    @classmethod
    def _init_dist_hccl(cls, rank, world_size):
        """Select NPU ``rank`` and join the HCCL process group.

        Args:
            rank: Rank of the calling process; also used as the NPU index.
            world_size: Total number of participating processes.

        Returns:
            The ``torch.distributed`` module, so the caller can invoke
            collectives on the returned object.
        """
        # Rendezvous settings are hard-coded to a local, single-node setup.
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29500'
        os.environ['HCCL_WHITELIST_DISABLE'] = '1'
        torch_npu.npu.set_device(rank)
        dist.init_process_group(backend='hccl', world_size=world_size, rank=rank)
        return dist

    @classmethod
    def _test_gather(cls, test_data, world_size, init_pg, c2p, p2c):
        """Worker entry point: gather one tensor per rank onto rank 0.

        Args:
            test_data: Tuple ``(rank, input_tensor, output_tensor_list)``.
            world_size: Total number of participating processes.
            init_pg: Callable ``(rank, world_size)`` returning the object on
                which ``gather`` is invoked.
            c2p: Child-to-parent queue; receives ``(rank, dst, result)``.
            p2c: Parent-to-child queue; the worker blocks on it before exiting.
        """
        rank, input1, output1 = test_data
        pg = init_pg(rank, world_size)
        dst = 0
        input1 = input1.npu()
        output1 = [i.npu() for i in output1]
        if rank == dst:
            pg.gather(input1, output1, dst=dst)
            # Move results back to the host before sending them to the parent.
            result = [t.cpu() for t in output1]
        else:
            # Non-destination ranks pass an empty gather list.
            pg.gather(input1, [], dst=dst)
            result = []
        c2p.put((rank, dst, result))
        # Stay alive until the parent has consumed the result
        # (presumably so tensors shared via the queue remain valid).
        p2c.get()

    @classmethod
    def _test_gather_object(cls, test_data, world_size, init_pg, c2p, p2c):
        """Worker entry point: gather one object per rank onto rank 0.

        Takes the same arguments as ``_test_gather``; the input is passed to
        ``gather_object`` as-is, without being moved to the NPU first.
        """
        rank, input1, output1 = test_data
        pg = init_pg(rank, world_size)
        dst = 0
        output1 = list(output1)
        if rank == dst:
            pg.gather_object(input1, output1, dst=dst)
            result = [t.cpu() for t in output1]
        else:
            # Non-destination ranks pass an empty gather list.
            pg.gather_object(input1, [], dst=dst)
            result = []
        c2p.put((rank, dst, result))
        # Stay alive until the parent has consumed the result.
        p2c.get()

    def _test_multiprocess(self, f, init_pg, proc_data, world_size):
        """Run ``f`` in ``world_size`` spawned processes and check rank dst's output.

        Args:
            f: Worker callable with the signature of ``_test_gather``.
            init_pg: Process-group initializer forwarded to every worker.
            proc_data: Tuple ``(input, output_list, expected_list)``.
            world_size: Number of worker processes to spawn.
        """
        input1, output1, expected = proc_data
        ctx = mp.get_context('spawn')
        c2p = ctx.Queue(world_size)
        p2c = ctx.Queue(world_size)
        ps = []
        for i in range(world_size):
            p = ctx.Process(
                target=f,
                args=((i, input1, output1), world_size, init_pg, c2p, p2c))
            p.start()
            ps.append(p)

        try:
            for _ in range(world_size):
                rank, dst, output = c2p.get()
                if rank != dst:
                    continue
                # zip() truncates silently, so a short or empty result would
                # otherwise pass; check the length explicitly first.
                self.assertEqual(
                    len(output), len(expected),
                    ("rank {} Expect receive {} tensors but got {}.").format(
                        rank, len(expected), len(output)))
                for i, j in zip(output, expected):
                    self.assertEqual(i, j,
                                     ("rank {} Expect receive tensor {} but got {}.").format(rank, expected, output))
        finally:
            # Always release and reap the workers, even when an assertion
            # fails; otherwise they stay blocked on p2c.get() forever.
            for _ in ps:
                p2c.put(0)
            for p in ps:
                p.join()

    def _run_gather_cases(self, worker):
        """Run ``worker`` for every configured world size and dtype.

        Every rank contributes a one-element tensor holding the world size, so
        the destination rank is expected to receive ``world_size`` copies of it.
        """
        world_sizes = [2]
        dtypes = [torch.float32, torch.float16, torch.int32, torch.uint8]
        for world_size in world_sizes:
            for _dtype in dtypes:
                _input = torch.tensor([world_size], dtype=_dtype)
                _output = [torch.empty([1], dtype=_dtype) for _ in range(world_size)]
                _expected = [torch.tensor([world_size], dtype=_dtype) for _ in range(world_size)]
                proc_data = (_input, _output, _expected)
                self._test_multiprocess(worker,
                                        HcclGatherTest._init_dist_hccl, proc_data, world_size)

    @skipIfUnsupportMultiNPU(2)
    def test_gather_dist(self):
        """Check ``gather`` across 2 ranks for several dtypes."""
        self._run_gather_cases(HcclGatherTest._test_gather)

    @skipIfUnsupportMultiNPU(2)
    def test_gather_object_dist(self):
        """Check ``gather_object`` across 2 ranks for several dtypes."""
        self._run_gather_cases(HcclGatherTest._test_gather_object)


# Run the test suite only when this file is executed directly as a script.
if "__main__" == __name__:
    run_tests()