import os
import unittest
import random
import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import create_common_tensor, SupportedDevices
from torch_npu.testing.common_distributed import skipIfUnsupportMultiNPU

class TestDistributeBarrier(TestCase):
    @classmethod
    def _init_dist_hccl(cls, rank, world_size, ep_world_size, tp_world_size):
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '50000'
        os.environ['HCCL_WHITELIST_DISABLE'] = '1'
        torch_npu.npu.set_device(rank)
        dist.init_process_group(backend='hccl', world_size = world_size, rank = rank)
        ep_ranks_list = []
        tp_ranks_list = []
        for i in range(tp_world_size):
            ep_ranks_list.append(list(range(i, world_size, tp_world_size)))
        for i in range(ep_world_size):
            tp_ranks_list.append(list(range(i * tp_world_size, (i + 1) * tp_world_size)))
        for i in range(tp_world_size):
            ep_group = dist.new_group(backend='hccl', ranks=ep_ranks_list[i])
            if rank in ep_ranks_list[i]:
                ep_group_tmp = ep_group
        for i in range(ep_world_size):
            tp_group = dist.new_group(backend='hccl', ranks=tp_ranks_list[i])
            if rank in tp_ranks_list[i]:
                tp_group_tmp = tp_group
        return dist, ep_group_tmp, tp_group_tmp

    @classmethod
    def gen_elastic_info(cls, is_elastic, world_size, shared_expert_rank_num, share_broken_card_num,
                         moe_broken_card_num, local_moe, ep_world_size, moe_expert_num):
        if not is_elastic: return None
        elastic_info = torch.zeros(4 + 2 * ep_world_size, dtype = torch.int32)
        elastic_info[0] = is_elastic
        elastic_info[1] = world_size - share_broken_card_num -moe_broken_card_num
        elastic_info[2] = shared_expert_rank_num
        elastic_info[3] = moe_expert_num - local_moe*(moe_broken_card_num + share_broken_card_num)
        table1 = [-1] * ep_world_size
        table2 = [-1] * ep_world_size

        if is_elastic:
            _ = [i for i in range(shared_expert_rank_num)]
            __= [i for i in range(shared_expert_rank_num, world_size)]
            random_seed=24
            torch.manual_seed(random_seed)
            elastic_rank = random.sample(_, elastic_info[2]) + random.sample(__, elastic_info[1] - elastic_info[2])
            elastic_rank.sort()

        for local_rank_id, ep_rank_id in enumerate(elastic_rank):
            if ep_rank_id < ep_world_size:
                table1[ep_rank_id] = local_rank_id
                table2[local_rank_id] = ep_rank_id
        for i in range(ep_world_size):
            elastic_info[4 + i] = table1[i]
        for i in range(ep_world_size):
            elastic_info[4 + ep_world_size + i] = table2[i]
        assert elastic_info.shape[0] == 4 + 2 * ep_world_size

        if is_elastic:
            table1 = elastic_info[4 : 4 + ep_world_size]
            table2 = elastic_info[4 + ep_world_size: 4 + 2 * ep_world_size]
        return elastic_info

    @classmethod
    def _test_npu_distribute_barrier(cls, rank, x_ref, time_out, elastic_info,
                                         ep_world_size, tp_world_size, init_pg, c2p, p2c ):

        pg, ep_group, tp_group = init_pg(rank, ep_world_size * tp_world_size, ep_world_size, tp_world_size)
        ep_hcomm_name = ep_group._get_backend(torch.device('npu')).get_hccl_comm_name(rank)
        tp_hcomm_name = tp_group._get_backend(torch.device('npu')).get_hccl_comm_name(rank)

        if rank in elastic_info[4 + ep_world_size: 4 + 2 * ep_world_size]:
            out = torch_npu._npu_distribute_barrier(x_ref = x_ref.npu(),
                                                   time_out = time_out.npu(),
                                                   elastic_info = elastic_info.npu(),
                                                   group = ep_hcomm_name,
                                                   world_size = ep_world_size)
        else:
            out = None
        if out is not None:
            c2p.put((rank, out.cpu()))
        else:
            c2p.put((rank, None))
        p2c.get()

    def _test_multiprocess(self, f, init_pg, input_list):
        expt_out_list, x_ref, time_out, elastic_info, \
            ep_world_size, tp_world_size = input_list
        ctx = mp.get_context('spawn')
        c2p = ctx.Queue(ep_world_size * tp_world_size)
        p2c = ctx.Queue(ep_world_size * tp_world_size)
        ps = []

        for i in range(ep_world_size * tp_world_size):
            p = ctx.Process(
                target=f,
                args=(i, x_ref, time_out, elastic_info, ep_world_size, tp_world_size, init_pg, c2p, p2c))
            p.start()
            ps.append(p)

        for _ in range(ep_world_size * tp_world_size):
            rank, output = c2p.get()
            self.assertEqual(output, expt_out_list[rank],
                             ("rank {} Expect receive tensor {} but got {}.").format(rank, expt_out_list[rank], output))

        for _ in range(ep_world_size * tp_world_size):
            p2c.put(0)

        for p in ps:
            p.join()

    def _construct_excepted_result(self, x_ref, elastic_info, ep_world_size):
        out_list = []
        for i in elastic_info[4 + ep_world_size: 4 + 2 * ep_world_size]:
            if i == -1:
                out_list.append(None)
            else:
                out_list.append(x_ref)
        return out_list

    @skipIfUnsupportMultiNPU(16)
    @SupportedDevices(['Ascend910_93', 'Ascend950'])
    def test_npu_distribute_barrier(self):
        ep_world_size = 8
        tp_world_size = 1
        bs = 8
        h = 7168
        k = 7
        shared_broken_card_num = 0
        shared_expert_rank_num = 0
        moe_broken_card_num = 0
        local_moe= 4
        moe_expert_num = local_moe * (ep_world_size * tp_world_size - shared_expert_rank_num)
        is_elastic = 1
        x_ref = torch.ones(1, dtype = torch.int32)
        time_out = torch.tensor([100000], dtype = torch.int32).npu()
        elastic_info_x1 = TestDistributeBarrier.gen_elastic_info(is_elastic, ep_world_size * tp_world_size,
            shared_expert_rank_num, shared_broken_card_num, moe_broken_card_num, local_moe,ep_world_size,
            moe_expert_num)

        for _ in range(ep_world_size):
            expt_out_list_1 = self._construct_excepted_result(x_ref, elastic_info_x1, ep_world_size)

        self._test_multiprocess(TestDistributeBarrier._test_npu_distribute_barrier,
                TestDistributeBarrier._init_dist_hccl, [expt_out_list_1, x_ref, time_out, elastic_info_x1,
                                                           ep_world_size, tp_world_size])
if __name__ == '__main__':
    run_tests()