import os
import unittest
import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch_npu

from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import create_common_tensor, SupportedDevices
from torch_npu.testing.common_distributed import skipIfUnsupportMultiNPU


class TestMoeDistributeCombine(TestCase):

    @classmethod
    def _init_dist_hccl(cls, rank, world_size, ep_world_size, tp_world_size):
        """Initialise HCCL for ``rank`` and build its EP and TP sub-groups.

        EP groups contain the ranks ``i, i + tp_world_size, ...`` (one group
        per ``i`` in ``range(tp_world_size)``); TP groups contain
        ``tp_world_size`` consecutive ranks (one group per index in
        ``range(ep_world_size)``).  Every group is created by every rank, but
        only the handle whose member list contains ``rank`` is kept.

        Returns:
            A tuple ``(dist, ep_group, tp_group)`` where ``dist`` is the
            ``torch.distributed`` module itself.
        """
        rendezvous_env = (
            ('MASTER_ADDR', '127.0.0.1'),
            ('MASTER_PORT', '50000'),
            ('HCCL_WHITELIST_DISABLE', '1'),
        )
        for env_key, env_value in rendezvous_env:
            os.environ[env_key] = env_value

        torch_npu.npu.set_device(rank)
        dist.init_process_group(backend='hccl', world_size=world_size, rank=rank)

        # Member lists are computed up front, before any group is created.
        ep_members = [list(range(tp_idx, world_size, tp_world_size))
                      for tp_idx in range(tp_world_size)]
        tp_members = [list(range(ep_idx * tp_world_size, (ep_idx + 1) * tp_world_size))
                      for ep_idx in range(ep_world_size)]

        # EP groups are created first, then TP groups, in list order.
        for members in ep_members:
            handle = dist.new_group(backend='hccl', ranks=members)
            if rank in members:
                own_ep_group = handle
        for members in tp_members:
            handle = dist.new_group(backend='hccl', ranks=members)
            if rank in members:
                own_tp_group = handle
        return dist, own_ep_group, own_tp_group

    @classmethod
    def _test_npu_moe_distribute_combine_v2(cls, rank, input_list):
        """Worker body: run ``npu_moe_distribute_combine_v2`` on one rank.

        ``input_list`` is a 17-element sequence (see the unpacking below).
        The operator output is copied to the host and sent to the parent as
        ``(rank, tensor)`` on ``c2p``; the worker then blocks on ``p2c`` until
        the parent posts a token.
        """
        (expand_x, scales1_list, scales2_list, topk1_list, topk2_list,
         elastic_info, assist_info_for_combine, ep_send_counts, tp_send_counts,
         ep_world_size, tp_world_size, globalBS, sharedExpertRankNum,
         moeExpertNum, init_pg, c2p, p2c) = input_list

        ep_rank = rank // tp_world_size
        tp_rank = rank % tp_world_size
        # TP rank 0 consumes the "1" data set, every other TP rank the "2" set.
        if tp_rank == 0:
            topk_source, scales_source = topk1_list, scales1_list
        else:
            topk_source, scales_source = topk2_list, scales2_list
        topk = topk_source[ep_rank]
        expert_scales = scales_source[ep_rank]

        _, ep_group, tp_group = init_pg(rank, ep_world_size * tp_world_size, ep_world_size, tp_world_size)
        ep_hcomm_name = ep_group._get_backend(torch.device('npu')).get_hccl_comm_name(rank)
        tp_hcomm_name = tp_group._get_backend(torch.device('npu')).get_hccl_comm_name(rank)

        # Move every tensor argument to the NPU (same order as before).
        (expand_x, topk, assist_info_for_combine, ep_send_counts,
         tp_send_counts, expert_scales, elastic_info) = (
            host_tensor.npu() for host_tensor in (
                expand_x, topk, assist_info_for_combine, ep_send_counts,
                tp_send_counts, expert_scales, elastic_info))

        out = torch_npu.npu_moe_distribute_combine_v2(
            expand_x=expand_x,
            expert_ids=topk,
            assist_info_for_combine=assist_info_for_combine,
            ep_send_counts=ep_send_counts,
            expert_scales=expert_scales,
            elastic_info=elastic_info,
            group_ep=ep_hcomm_name,
            ep_world_size=ep_world_size,
            ep_rank_id=int(ep_rank),
            moe_expert_num=moeExpertNum,
            tp_send_counts=tp_send_counts,
            group_tp=tp_hcomm_name,
            tp_world_size=tp_world_size,
            tp_rank_id=int(tp_rank),
            expert_shard_type=0,
            shared_expert_num=int(sharedExpertRankNum > 0),
            shared_expert_rank_num=sharedExpertRankNum,
            global_bs=globalBS,
        )

        c2p.put((rank, out.cpu()))
        # Wait for the parent's token before returning.
        p2c.get()

    def _test_multiprocess(self, f, init_pg, input_list):
        """Spawn one worker per rank, run ``f`` in each and verify the outputs.

        Args:
            f: Worker callable invoked as ``f(rank, worker_args)``.
            init_pg: Callable forwarded to the worker to set up the process
                groups.
            input_list: 15-element sequence, in this order: expected outputs,
                expand_x per rank, scales1/scales2 lists, topk1/topk2 lists,
                elastic_info list, assist-index per rank, EP counts per rank,
                TP counts per rank, ``ep_world_size``, ``tp_world_size``,
                ``globalBS``, ``sharedExpertRankNum``, ``moeExpertNum``.

        Raises:
            AssertionError: if any rank's output differs from its expectation.
        """
        # ``elastic_info`` is part of what the caller passes and of what the
        # worker unpacks, so it has to be unpacked and forwarded here as well.
        (expt_out_list, expand_x_list, scales1_list, scales2_list, topk1_list,
         topk2_list, elastic_info, idx_list, ep_recvCount_list, tp_recvCount_list,
         ep_world_size, tp_world_size, globalBS, sharedExpertRankNum,
         moeExpertNum) = input_list
        world_size = ep_world_size * tp_world_size
        ctx = mp.get_context('spawn')
        c2p = ctx.Queue(world_size)
        p2c = ctx.Queue(world_size)
        ps = []

        try:
            for i in range(world_size):
                # NOTE(review): elastic_info holds one tensor per EP rank, so it
                # is indexed by the EP rank id like the worker does for the
                # topk/scales lists -- confirm this is the intended mapping.
                worker_args = [
                    expand_x_list[i], scales1_list, scales2_list, topk1_list,
                    topk2_list, elastic_info[i // tp_world_size], idx_list[i],
                    ep_recvCount_list[i], tp_recvCount_list[i], ep_world_size,
                    tp_world_size, globalBS, sharedExpertRankNum, moeExpertNum,
                    init_pg, c2p, p2c]
                p = ctx.Process(target=f, args=(i, worker_args))
                p.start()
                ps.append(p)

            for _ in range(world_size):
                rank, output = c2p.get()
                self.assertEqual(output, expt_out_list[rank],
                                 ("rank {} Expect receive tensor {} but got {}.").format(rank, expt_out_list[rank], output))
        finally:
            # Always release and reap the workers, even when an assertion
            # failed, so no spawned process is left blocked on ``p2c``.
            for _ in ps:
                p2c.put(0)
            for p in ps:
                p.join()

    def _chunk_tensor(self, tensor, num_chunks):
        chunk_size = tensor.size(0) // num_chunks
        chunks = []
        for i in range(num_chunks):
            chunk = tensor[i * chunk_size:(i + 1) * chunk_size]
            chunks.append(chunk)
        return chunks

    def _construct_idx(self, tensor, ep_world_size):
        num_groups = ep_world_size
        group_size = tensor.size(0) // num_groups
        split_tensors = torch.split(tensor, group_size)
        count_tensor = torch.zeros_like(tensor)
        for i, split_tensor in enumerate(split_tensors):
            start_idx = i * group_size
            end_idx = start_idx + group_size
            count_dict = {}
            for j, num in enumerate(split_tensor):
                num = num.item()
                count_dict[num] = count_dict.get(num, -1) + 1
                count_tensor[start_idx + j] = count_dict[num]
        return count_tensor.to(torch.int32)

    def _gen_recvCount(self, tensor, bs, ep_world_size, moeExpertNum, sharedExpertRankNum):
        segment_length = tensor.numel() // ep_world_size
        result_tensor = torch.zeros(moeExpertNum, ep_world_size, dtype=torch.int32)
        for i in range(ep_world_size):
            start_idx = i * segment_length
            end_idx = start_idx + segment_length if i < ep_world_size - 1 else tensor.numel()
            segment = tensor[start_idx:end_idx]
            counts = torch.bincount(segment, minlength=moeExpertNum)
            result_tensor[:, i] = counts
        shared = torch.zeros(sharedExpertRankNum, ep_world_size, dtype=torch.int32)
        for i in range(sharedExpertRankNum):
            for j in range(ep_world_size):
                if i == j:
                    shared[i][j] = bs
                elif j >= sharedExpertRankNum:
                    shared[i][j] = int(bs // sharedExpertRankNum)
        result_tensor = torch.cat((shared, result_tensor), dim=0)
        return result_tensor.flatten()

    def _construct_excepted_result(self, x1_list, x2_list, topk1_list, topk2_list, bs, h, k, globalBS,
                                   sharedExpertRankNum, moeExpertNum, ep_world_size, tp_world_size, scales1, scales2):
        """Build the per-rank operator inputs and the expected combine outputs.

        The "1" and "2" data sets are processed in parallel; every returned
        list holds ``2 * ep_world_size`` entries, interleaved as
        ``[set1, set2, set1, set2, ...]``.

        Args:
            x1_list, x2_list: Lists of tensors; each list is concatenated
                along dim 0 and viewed as ``(-1, h)``.
            topk1_list, topk2_list: Lists of expert-id tensors; each list is
                concatenated and viewed as ``(-1, k)``.
            bs, h, k, globalBS: Sizes used for slicing and reshaping.
            sharedExpertRankNum, moeExpertNum, ep_world_size, tp_world_size:
                Layout parameters used for slicing and padding.
            scales1, scales2: Tensors whose flattened elements scale the
                re-ordered expanded rows one by one.

        Returns:
            ``(expand_x_list, out_list, idx_list, ep_recvCount_list,
            tp_recvCount_list)``.

        Notes:
            * Calls ``torch_npu.npu_moe_init_routing`` on ``.npu()`` tensors,
              so an NPU device is needed.
            * The padding rows in ``expand_x_list`` are drawn with
              ``np.random.uniform`` and therefore differ between runs.
            * NOTE(review): appears to require ``sharedExpertRankNum >= 1``
              (it is used as a divisor and ``torch.cat`` is applied to the
              shared lists) -- confirm.
            * NOTE(review): "excepted" in the name is presumably a typo for
              "expected"; the name is left as is.
        """
        # row_idx[b][j] == j * globalBS + b; ``mapping`` sends that value back
        # to the flat row-major position b * k + j.
        col_idx = torch.arange(0, globalBS * k, dtype=torch.int32)
        row_idx = col_idx.view(k, -1).permute(1, 0)
        mapping = dict(zip(map(int, row_idx.flatten()), map(int, col_idx.flatten())))
        row_idx = row_idx.reshape([globalBS, k]).contiguous()

        x1 = torch.cat(x1_list, dim=0).view(-1, h)
        x2 = torch.cat(x2_list, dim=0).view(-1, h)
        topk1 = torch.cat(topk1_list, dim=0).view(-1, k)
        topk2 = torch.cat(topk2_list, dim=0).view(-1, k)

        # Run the routing kernel on the NPU for both data sets.
        expandX1, expand_row1, expand_expert1 = torch_npu.npu_moe_init_routing(x1.npu(), row_idx=row_idx.npu(),
                                                                       expert_idx=topk1.npu(),
                                                                       active_num=globalBS)
        expandX2, expand_row2, expand_expert2 = torch_npu.npu_moe_init_routing(x2.npu(), row_idx=row_idx.npu(),
                                                                       expert_idx=topk2.npu(),
                                                                       active_num=globalBS)

        # Bring the routing results back to the host.
        expandX1 = expandX1.cpu().view(-1, h)
        expandX2 = expandX2.cpu().view(-1, h)
        expand_row1 = expand_row1.cpu()
        expand_row2 = expand_row2.cpu()
        expand_expert1 = expand_expert1.cpu()
        expand_expert2 = expand_expert2.cpu()
        # middle_idx[expand_row[j]] = mapping[j]; used further down as the
        # destination slot of each expanded row.
        j = 0
        result_idx = np.zeros(globalBS * k).astype(int)
        for i in expand_row1:
            result_idx[int(i)] = mapping[j]
            j += 1
        middle_idx1 = torch.tensor(result_idx.astype(np.int32))
        j = 0
        result_idx = np.zeros(globalBS * k).astype(int)
        for i in expand_row2:
            result_idx[int(i)] = mapping[j]
            j += 1
        middle_idx2 = torch.tensor(result_idx.astype(np.int32))

        # For shared rank i, gather the bs-row chunks of x1 with chunk index
        # i, i + S, i + 2S, ... (S = sharedExpertRankNum), cast to float16.
        shared_list = []
        shared_tokens = []
        for i in range(sharedExpertRankNum):
            tmp_list = []
            shared_tokens.append(bs * (int(moeExpertNum / sharedExpertRankNum) + 1))
            tmp_list.append(x1[(bs * i):(bs * (i + 1)), :])
            for j in range(int(moeExpertNum / sharedExpertRankNum)):
                tmp_list.append(x1[(bs * (i + (j + 1) * sharedExpertRankNum)):(bs * (i + (j + 1) * sharedExpertRankNum + 1)), :])
            tmp_list = torch.cat(tmp_list, dim=0).to(torch.float16)
            shared_list.append(tmp_list)
        shared_x1 = shared_list
        # Row counts: the shared-rank entries first, then the histogram of the
        # expert ids returned by the routing kernel.
        token1 = torch.cat((torch.tensor(shared_tokens), torch.bincount(expand_expert1)))
        token2 = torch.cat((torch.tensor(shared_tokens), torch.bincount(expand_expert2)))
        # Same gathering for the second data set.
        shared_list = []
        for i in range(sharedExpertRankNum):
            tmp_list = []
            tmp_list.append(x2[(bs * i):(bs * (i + 1)), :])
            for j in range(int(moeExpertNum / sharedExpertRankNum)):
                tmp_list.append(x2[(bs * (i + (j + 1) * sharedExpertRankNum)):(bs * (i + (j + 1) * sharedExpertRankNum + 1)), :])
            tmp_list = torch.cat(tmp_list, dim=0).to(torch.float16)
            shared_list.append(tmp_list)
        shared_x2 = shared_list

        # Operator input for the shared ranks: both orderings of (set1, set2).
        # Each shared block is then doubled (presumably to model the sum over
        # the two TP ranks -- TODO confirm).
        expand_x_list = []
        for i in range(sharedExpertRankNum):
            expand_x_list.append(torch.cat((shared_x1[i], shared_x2[i])))
            expand_x_list.append(torch.cat((shared_x2[i], shared_x1[i])))
            shared_x1[i] = shared_x1[i] + shared_x1[i]
            shared_x2[i] = shared_x2[i] + shared_x2[i]
        sums1 = 0
        sums2 = 0
        local = int(moeExpertNum // (ep_world_size - sharedExpertRankNum))
        A = int(globalBS * local)
        # Operator input for the remaining ranks: consecutive slices of the
        # expanded rows sized by token1[i] / token2[i], padded with random
        # rows up to tp_world_size * A rows.
        for i in range(sharedExpertRankNum, ep_world_size):
            start1 = sums1
            end1 = sums1 + int(token1[i])
            sums1 = end1
            start2 = sums2
            end2 = sums2 + int(token2[i])
            sums2 = end2
            pad = torch.tensor(np.random.uniform(0, 1, size=[tp_world_size * A - int(token1[i]) - int(token2[i]), h])).to(torch.float16)
            expand_x_list.append(torch.cat((expandX1[start1:end1, :], expandX2[start2:end2, :], pad)))
            expand_x_list.append(torch.cat((expandX2[start2:end2, :], expandX1[start1:end1, :], pad)))
            # The consumed slices are doubled in place; later steps read the
            # doubled values.
            expandX1[start1:end1, :] = expandX1[start1:end1, :] + expandX1[start1:end1, :]
            expandX2[start2:end2, :] = expandX2[start2:end2, :] + expandX2[start2:end2, :]

        # Per-EP-rank "shared" contribution: bs rows per rank, flattened.
        shared_x1 = torch.cat(shared_x1, dim=0).view(-1, h)
        shared_x2 = torch.cat(shared_x2, dim=0).view(-1, h)
        combine_x1_shared = []
        combine_x2_shared = []
        for i in range(ep_world_size):
            if i < sharedExpertRankNum:
                start_idx = i * int(globalBS // sharedExpertRankNum)
                end_idx = i * int(globalBS // sharedExpertRankNum) + bs
                combine_x1_shared.append(shared_x1[start_idx:end_idx, :])
                combine_x2_shared.append(shared_x2[start_idx:end_idx, :])
            else:
                # NOTE(review): this branch slices expandX1 / expandX2 rather
                # than shared_x1 / shared_x2 -- confirm that is intended.
                startIdx = int((i % sharedExpertRankNum * globalBS // sharedExpertRankNum) + i // sharedExpertRankNum * bs)
                endIdx = startIdx + int(bs)
                combine_x1_shared.append(expandX1[startIdx:endIdx, :])
                combine_x2_shared.append(expandX2[startIdx:endIdx, :])
        combine_x1_shared = torch.cat(combine_x1_shared, dim=0).flatten()
        combine_x2_shared = torch.cat(combine_x2_shared, dim=0).flatten()

        # Scatter expanded row i to slot middle_idx1[i] (as float32), scale
        # slot-by-slot with scales1, then sum every run of k slots.
        result_list = [None] * len(middle_idx1)
        for i, pos in enumerate(middle_idx1):
            result_list[int(pos)] = expandX1[i].to(torch.float32)
        result_list = [t * s for t, s in zip(result_list, scales1.flatten())]
        group_sums = []
        for i in range(globalBS):
            start_idx = i * k
            end_idx = start_idx + k
            group_tensors = result_list[start_idx:end_idx]
            group_sum = torch.stack(group_tensors).sum(dim=0)
            group_sums.append(group_sum)
        combine_x1 = torch.cat(group_sums) + combine_x1_shared

        # Same reduction for the second data set.
        result_list = [None] * len(middle_idx2)
        for i, pos in enumerate(middle_idx2):
            result_list[int(pos)] = expandX2[i].to(torch.float32)
        result_list = [t * s for t, s in zip(result_list, scales2.flatten())]
        group_sums = []
        for i in range(globalBS):
            start_idx = i * k
            end_idx = start_idx + k
            group_tensors = result_list[start_idx:end_idx]
            group_sum = torch.stack(group_tensors).sum(dim=0)
            group_sums.append(group_sum)
        combine_x2 = torch.cat(group_sums) + combine_x2_shared

        # Expected outputs: bs rows per EP rank, set1 and set2 interleaved.
        combine_x1 = combine_x1.to(torch.float16).view(-1, h)
        combine_x2 = combine_x2.to(torch.float16).view(-1, h)
        out_list = []
        sums = 0
        for _ in range(ep_world_size):
            start_idx = sums
            sums = sums + bs
            end_idx = sums
            out_list.append(combine_x1[start_idx:end_idx, :])
            out_list.append(combine_x2[start_idx:end_idx, :])

        # Per-rank occurrence indices derived from the flattened expert ids.
        topk1_list = torch.cat(topk1_list).flatten()
        topk2_list = torch.cat(topk2_list).flatten()
        idx1 = self._construct_idx(topk1_list, ep_world_size)
        idx2 = self._construct_idx(topk2_list, ep_world_size)
        idx1 = self._chunk_tensor(idx1, ep_world_size)
        idx2 = self._chunk_tensor(idx2, ep_world_size)
        idx_list = []
        for i in range(ep_world_size):
            idx_list.append(idx1[i])
            idx_list.append(idx2[i])

        # Turn the flat count tables into running sums, restarting after
        # every ep_world_size entries; only the first ep_world_size such
        # blocks are touched.
        recvCount1 = self._gen_recvCount(topk1_list, bs, ep_world_size, moeExpertNum, sharedExpertRankNum)
        recvCount2 = self._gen_recvCount(topk2_list, bs, ep_world_size, moeExpertNum, sharedExpertRankNum)
        for i in range(ep_world_size):
            sums1 = 0
            sums2 = 0
            for j in range(ep_world_size):
                sums1 = recvCount1[i * ep_world_size + j] + sums1
                recvCount1[i * ep_world_size + j] = sums1
                sums2 = recvCount2[i * ep_world_size + j] + sums2
                recvCount2[i * ep_world_size + j] = sums2
        recvCount1 = self._chunk_tensor(recvCount1, ep_world_size)
        recvCount2 = self._chunk_tensor(recvCount2, ep_world_size)
        ep_recvCount_list = []
        for i in range(ep_world_size):
            ep_recvCount_list.append(recvCount1[i])
            ep_recvCount_list.append(recvCount2[i])

        # Both entries of an EP rank get the same [token1[i], token2[i]] pair.
        tp_recvCount_list = []
        for i in range(ep_world_size):
            tp_recvCount_list.append(torch.tensor([int(token1[i]), int(token2[i])]).to(torch.int32))
            tp_recvCount_list.append(torch.tensor([int(token1[i]), int(token2[i])]).to(torch.int32))

        return expand_x_list, out_list, idx_list, ep_recvCount_list, tp_recvCount_list

    @skipIfUnsupportMultiNPU(16)
    @SupportedDevices(['Ascend910_'])
    def test_npu_moe_distribute_combine_v2(self):
        """Check ``npu_moe_distribute_combine_v2`` on an 8 (EP) x 2 (TP) layout.

        Random inputs are generated per EP rank, the host-side reference is
        computed, and one worker process per rank compares the operator
        output against it.
        """
        ep_world_size, tp_world_size = 8, 2
        bs, h, k = 8, 7168, 7
        sharedExpertRankNum, moeExpertNum = 1, 7
        globalBS = bs * ep_world_size
        data_format = -1

        # [dtype, format, shape] descriptors for create_common_tensor.
        x1_shape = [np.float16, data_format, [bs, h]]
        x2_shape = [np.float16, data_format, [bs, h]]
        scales1_shape = [np.float32, data_format, [bs, k]]
        scales2_shape = [np.float32, data_format, [bs, k]]
        elastic_info_shape = [np.int32, data_format, [4 + ep_world_size]]

        # Every token selects experts 0..k-1; all ranks share this tensor.
        topk = torch.tile(torch.arange(k), (bs,)).int().view(-1, k)
        topk1_list = [topk] * ep_world_size
        topk2_list = [topk] * ep_world_size

        x1_list, x2_list = [], []
        scales1_list, scales2_list = [], []
        elastic_info = []
        for _rank in range(ep_world_size):
            cpu_x1, _npu = create_common_tensor(x1_shape, -1, 1)
            cpu_x2, _npu = create_common_tensor(x2_shape, -1, 1)
            x1_list.append(cpu_x1)
            x2_list.append(cpu_x2)
            cpu_scales1, _npu = create_common_tensor(scales1_shape, -1, 1)
            cpu_scales2, _npu = create_common_tensor(scales2_shape, -1, 1)
            scales1_list.append(cpu_scales1)
            scales2_list.append(cpu_scales2)
            cpu_elastic, _npu = create_common_tensor(elastic_info_shape, -1, 1)
            elastic_info.append(cpu_elastic)

        reference = self._construct_excepted_result(
            x1_list, x2_list, topk1_list, topk2_list, bs, h, k, globalBS,
            sharedExpertRankNum, moeExpertNum, ep_world_size, tp_world_size,
            torch.cat(scales1_list), torch.cat(scales2_list))
        expand_x_list, expt_out_list, idx_list, ep_recvCount_list, tp_recvCount_list = reference

        self._test_multiprocess(
            TestMoeDistributeCombine._test_npu_moe_distribute_combine_v2,
            TestMoeDistributeCombine._init_dist_hccl,
            [expt_out_list, expand_x_list, scales1_list, scales2_list,
             topk1_list, topk2_list, elastic_info, idx_list, ep_recvCount_list,
             tp_recvCount_list, ep_world_size, tp_world_size, globalBS,
             sharedExpertRankNum, moeExpertNum])


# Run the test suite through the torch_npu test runner when executed as a script.
if __name__ == '__main__':
    run_tests()