import os

import unittest

import numpy as np

import torch

import torch.distributed as dist

import torch.multiprocessing as mp

import torch_npu



from torch_npu.testing.testcase import TestCase, run_tests

from torch_npu.testing.common_utils import create_common_tensor, SupportedDevices

from torch_npu.testing.common_distributed import skipIfUnsupportMultiNPU







class TestAllGatherQuantMm(TestCase):
    """Multi-process check of ``torch_npu.npu_all_gather_quant_mm`` over HCCL.

    One worker process is spawned per rank. Each worker runs the operator on
    its own ``x1``/``x2`` pair and sends the two outputs back to the parent
    through a queue. The parent compares them with a reference built from
    ``torch.cat`` and ``torch.matmul``.
    """

    # Seconds to block on the result queue before re-checking worker health.
    _POLL_INTERVAL_S = 1
    # Seconds a worker is given to exit on its own once the test has failed.
    _JOIN_GRACE_S = 10

    @classmethod
    def _init_dist_hccl(cls, rank, world_size):
        """Select NPU ``rank`` and join the default HCCL process group.

        Args:
            rank: Index of this process; also used as the NPU device index.
            world_size: Total number of participating processes.

        Returns:
            The ``torch.distributed`` module, which the worker uses as its
            handle for group lookup and ``barrier()``.
        """
        # Rendezvous settings for a single-host run.
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '50000'
        os.environ['HCCL_WHITELIST_DISABLE'] = '1'
        torch_npu.npu.set_device(rank)
        dist.init_process_group(backend='hccl', world_size=world_size, rank=rank)
        return dist

    @classmethod
    def _test_npu_all_gather_quant_mm(cls, rank, input_list):
        """Worker entry point: run the operator on one rank and report back.

        Args:
            rank: Rank of this worker process.
            input_list: ``[x1_list, x2_list, world_size, init_pg, c2p]`` where
                the two lists hold one CPU tensor per rank, ``init_pg`` is the
                process-group initializer and ``c2p`` is the child-to-parent
                result queue.
        """
        x1_list, x2_list, world_size, init_pg, c2p = input_list
        x1 = x1_list[rank]
        x2 = x2_list[rank]
        pg = init_pg(rank, world_size)
        group = pg.distributed_c10d._get_default_group()
        hcom_name = group._get_backend(torch.device('npu')).get_hccl_comm_name(rank)

        x1 = x1.npu()
        x2 = x2.npu()
        out, gather_out = torch_npu.npu_all_gather_quant_mm(x1,
                                                            x2,
                                                            hcom_name,
                                                            world_size,
                                                            bias=None,
                                                            x1_scale=None,
                                                            x2_scale=None,
                                                            quant_scale=None,
                                                            gather_index=0,
                                                            gather_output=True,
                                                            comm_turn=0)

        # Results go back as CPU tensors tagged with the rank, because the
        # parent receives them in arbitrary order.
        c2p.put((rank, out.cpu(), gather_out.cpu()))
        pg.barrier()

    def _get_worker_result(self, c2p, ps):
        """Return the next item from ``c2p``, failing fast if workers are gone.

        A plain blocking ``get()`` would hang forever when a worker dies
        before publishing its result, so the queue is polled and worker exit
        codes are inspected between polls.

        Args:
            c2p: Child-to-parent result queue.
            ps: The started worker processes.

        Returns:
            The item a worker put on the queue.

        Raises:
            AssertionError: (via ``self.fail``) if a worker exited with a
                non-zero code, or all workers exited, while a result was
                still outstanding.
        """
        # Local import: only needed here, for the ``Empty`` exception type.
        import queue

        while True:
            try:
                return c2p.get(timeout=self._POLL_INTERVAL_S)
            except queue.Empty:
                exit_codes = [p.exitcode for p in ps]

            crashed = [(idx, code) for idx, code in enumerate(exit_codes)
                       if code not in (None, 0)]
            if crashed:
                self.fail("worker process(es) exited abnormally before all results "
                          "were received, (rank, exitcode): {}".format(crashed))

            if all(code is not None for code in exit_codes):
                # Every worker is gone. Poll once more in case a result was
                # published between the timed-out get() and the check above.
                try:
                    return c2p.get(timeout=self._POLL_INTERVAL_S)
                except queue.Empty:
                    self.fail("all worker processes exited but a result is still missing")

    def _stop_workers(self, ps):
        """Reap workers after a failure, terminating those that do not exit.

        Args:
            ps: The started worker processes.
        """
        grace = self._JOIN_GRACE_S
        for p in ps:
            p.join(grace)
            if not p.is_alive():
                continue
            p.terminate()
            p.join(self._JOIN_GRACE_S)
            if p.is_alive():
                p.kill()
                p.join()
            # The remaining workers have already had the same grace period.
            grace = 0

    def _test_multiprocess(self, f, init_pg, input_list):
        """Spawn one worker per rank and verify every reported result.

        Args:
            f: Worker function, called as ``f(rank, [x1, x2, world_size,
                init_pg, c2p])``.
            init_pg: Process-group initializer forwarded to the workers.
            input_list: ``[expt_out_list, expt_gather, x1, x2, world_size]``
                holding the per-rank expected outputs, the expected gathered
                tensor, the per-rank inputs and the number of ranks.
        """
        expt_out_list, expt_gather, x1, x2, world_size = input_list
        ctx = mp.get_context('spawn')
        c2p = ctx.Queue(world_size)
        ps = []

        try:
            for i in range(world_size):
                p = ctx.Process(
                    target=f,
                    args=(i, [x1, x2, world_size, init_pg, c2p]))
                p.start()
                ps.append(p)

            for _ in range(world_size):
                rank, output, gather_output = self._get_worker_result(c2p, ps)
                self.assertEqual(output, expt_out_list[rank],
                                 ("rank {} Expect receive tensor {} but got {}.").format(rank, expt_out_list[rank], output))
                self.assertEqual(gather_output, expt_gather,
                                 ("rank {} Expect receive tensor {} but got {}.").format(rank, expt_gather, gather_output))
        except BaseException:
            # Do not leave worker processes behind when the test fails.
            self._stop_workers(ps)
            raise

        for p in ps:
            p.join()

    def _construct_excepted_result(self, x1_list, x2_list, world_size):
        """Build the reference outputs for every rank.

        Args:
            x1_list: Per-rank ``x1`` CPU tensors.
            x2_list: Per-rank ``x2`` CPU tensors.
            world_size: Number of ranks.

        Returns:
            ``(out_list, gather_out)`` where ``gather_out`` is the
            concatenation of all ``x1`` tensors and ``out_list[i]`` is
            ``gather_out @ x2_list[i]`` cast to ``gather_out``'s dtype, on CPU.
        """
        gather_out = torch.cat(x1_list)
        out_dtype = gather_out.dtype
        # Upload the gathered operand once; it is the same for every rank.
        gather_out_npu = gather_out.npu()
        out_list = [
            torch.matmul(gather_out_npu, x2_list[i].npu()).to(out_dtype).cpu()
            for i in range(world_size)
        ]
        return out_list, gather_out

    @skipIfUnsupportMultiNPU(8)
    @SupportedDevices(['Ascend910B'])
    def test_npu_all_gather_quant_mm(self):
        """Run the operator on 8 ranks with float16 inputs and check results."""
        world_size = 8
        dtype = np.float16
        data_format = -1
        x1_shape = [dtype, data_format, [16, 512]]
        x2_shape = [dtype, data_format, [512, 256]]
        x1_list = []
        x2_list = []
        for _rank in range(world_size):
            # Only the first returned tensor is used; workers move it to
            # their own device themselves.
            x1, _ = create_common_tensor(x1_shape, -1, 1)
            x2, _ = create_common_tensor(x2_shape, -1, 1)
            x1_list.append(x1)
            x2_list.append(x2)
        expt_out_list, expt_gather = self._construct_excepted_result(x1_list, x2_list, world_size)
        self._test_multiprocess(TestAllGatherQuantMm._test_npu_all_gather_quant_mm,
                                TestAllGatherQuantMm._init_dist_hccl, [expt_out_list, expt_gather, x1_list, x2_list, world_size])





# Script entry point: hand control to run_tests() when this file is executed directly.
if __name__ == '__main__':

    run_tests()