import os
import pytest
import torch
import torch_npu
import mindspeed.megatron_adaptor

from megatron.training.global_vars import set_args
from megatron.training.arguments import parse_args
from megatron.core.tensor_parallel import mappings

import megatron.core.parallel_state as Utils
from tests_extend.unit_tests.common import DistributedTest


class TestTPMapping(DistributedTest):
    """Distributed checks for ``megatron.core.tensor_parallel.mappings``.

    Each test reads its rank from the ``LOCAL_RANK`` environment variable,
    calls ``initialize_model_parallel(4, 2)``, exercises one mapping class
    (its ``symbolic`` / ``backward`` static methods and, where used, the
    public wrapper function) and compares the result with a rank-dependent
    expected tensor.

    The parallel state is destroyed in a ``finally`` clause so that a failing
    assertion does not leave the model-parallel state initialised.
    """

    world_size = 8
    # These statements run once, when the class body is evaluated: parse the
    # training arguments, publish them through ``set_args`` and record the
    # world size on the parallel-state module (read back below through
    # ``Utils.world_size``).
    args = parse_args(None, True)
    set_args(args)
    Utils.world_size = 8

    def test_CopyToModelParallelRegion(self):
        """``backward`` must yield the summed ranks; forward paths are identity.

        The expected sum is 6 on ranks 0-3 (0+1+2+3) and 22 on ranks 4-7
        (4+5+6+7).
        """
        rank = int(os.environ['LOCAL_RANK'])
        Utils.initialize_model_parallel(4, 2)
        try:
            input_data = torch.ones((1)).cuda() * rank
            output_data = mappings._CopyToModelParallelRegion.backward(None, input_data)
            expected = torch.ones(1).cuda() * (22 if rank >= 4 else 6)
            assert torch.equal(output_data, expected)
            # The same tensor object is deliberately reused for the two
            # pass-through checks below.
            assert torch.equal(
                input_data,
                mappings.copy_to_tensor_model_parallel_region(input_data),
            )
            assert torch.equal(
                input_data,
                mappings._CopyToModelParallelRegion.symbolic(None, input_data),
            )
        finally:
            Utils.destroy_model_parallel()

    def test_ReduceFromModelParallelRegion(self):
        """``symbolic`` and the wrapper must yield the summed ranks.

        ``backward`` is expected to hand its input back unchanged.
        """
        rank = int(os.environ['LOCAL_RANK'])
        Utils.initialize_model_parallel(4, 2)
        try:
            input_data = torch.ones((1)).cuda() * rank
            output_data = mappings._ReduceFromModelParallelRegion.symbolic(None, input_data)
            expected = torch.ones(1).cuda() * (22 if rank >= 4 else 6)
            assert torch.equal(output_data, expected)

            # A fresh per-rank tensor is built before the wrapper is checked.
            input_data = torch.ones((1)).cuda() * rank
            assert torch.equal(
                mappings.reduce_from_tensor_model_parallel_region(input_data),
                expected,
            )
            assert torch.equal(
                input_data,
                mappings._ReduceFromModelParallelRegion.backward(None, input_data),
            )
        finally:
            Utils.destroy_model_parallel()

    def test_ScatterToModelParallelRegion(self):
        """Scatter must keep one column per rank; ``backward`` must gather.

        For an ``(8, 4)`` input each rank is expected to receive column
        ``rank % 4`` as an ``(8, 1)`` tensor.
        """
        rank = int(os.environ['LOCAL_RANK'])
        Utils.initialize_model_parallel(4, 2)
        try:
            input_data = torch.rand((8, 4)).cuda()
            output_data = mappings.scatter_to_tensor_model_parallel_region(input_data)
            # Integer arithmetic; same value as the former float-division form.
            req_dim = rank % (Utils.world_size // 2)
            expected_slice = input_data[:, req_dim].reshape((8, 1))
            assert torch.equal(output_data, expected_slice)
            output_data = mappings._ScatterToModelParallelRegion.symbolic(None, input_data)
            assert torch.equal(output_data, expected_slice)

            input_data = torch.ones(8).cuda() * rank
            actual_output_data = mappings._ScatterToModelParallelRegion.backward(None, input_data)
            # Four consecutive runs of eight values: 0, 1, 2, 3 (shifted by 4
            # on ranks 4-7).
            expected_output = torch.cat(
                [torch.ones(8) * i for i in range(4)]
            ).cuda()
            if rank >= 4:
                expected_output = expected_output + 4
            assert torch.equal(actual_output_data, expected_output)
        finally:
            Utils.destroy_model_parallel()

    def test_GatherFromModelParallelRegion(self):
        """``backward`` must keep one column per rank; forward paths gather."""
        rank = int(os.environ['LOCAL_RANK'])
        Utils.initialize_model_parallel(4, 2)
        try:
            input_data = torch.rand((8, 4)).cuda()
            req_dim = rank % (Utils.world_size // 2)
            output_data = mappings._GatherFromModelParallelRegion.backward(None, input_data)
            assert torch.equal(output_data, input_data[:, req_dim].reshape((8, 1)))

            input_data = torch.ones(8).cuda() * rank
            actual_output_data = mappings.gather_from_tensor_model_parallel_region(input_data)
            expected_output = torch.cat(
                [torch.ones(8) * i for i in range(4)]
            ).cuda()
            if rank >= 4:
                expected_output = expected_output + 4
            assert torch.equal(actual_output_data, expected_output)
            assert torch.equal(
                mappings._GatherFromModelParallelRegion.symbolic(None, input_data),
                expected_output,
            )
        finally:
            Utils.destroy_model_parallel()

    def test_ScatterToSequenceParallelRegion(self):
        """Sequence scatter must keep two consecutive rows per rank.

        For an ``(8, 4)`` input each rank is expected to receive rows
        ``2 * (rank % 4)`` and the one after it.
        """
        rank = int(os.environ['LOCAL_RANK'])
        Utils.initialize_model_parallel(4, 2)
        try:
            input_data = torch.rand((8, 4)).cuda()
            req_dim = (rank % (Utils.world_size // 2)) * 2
            expected_rows = input_data[req_dim:req_dim + 2, :]
            output_data = mappings._ScatterToSequenceParallelRegion.symbolic(None, input_data)
            assert torch.equal(output_data, expected_rows)
            output_data = mappings.scatter_to_sequence_parallel_region(input_data)
            assert torch.equal(output_data, expected_rows)

            input_data = torch.ones(4).cuda() * rank
            # NOTE(review): this calls _ScatterToModelParallelRegion.backward,
            # not the sequence-parallel class this test is named after.
            # Possibly a copy-paste; kept as-is, confirm before switching.
            output_data = mappings._ScatterToModelParallelRegion.backward(None, input_data)
            expected_output = torch.concat(
                [torch.ones(4) * i for i in range(4)]
            ).cuda()
            if rank >= 4:
                expected_output = expected_output + 4
            assert torch.equal(output_data, expected_output)
        finally:
            Utils.destroy_model_parallel()

    def test_GatherFromSequenceParallelRegion(self):
        """Forward paths must gather per-rank values; ``backward`` reduce-scatters.

        ``backward`` is called with a minimal stand-in context object and its
        first returned element is compared with ``4 * (rank % 4)``.
        """
        rank = int(os.environ['LOCAL_RANK'])
        Utils.initialize_model_parallel(4, 2)
        try:
            input_data = torch.ones(4).cuda() * rank
            output_data = mappings.gather_from_sequence_parallel_region(input_data)
            expected_output = torch.concat(
                [torch.ones(4) * i for i in range(4)]
            ).cuda()
            if rank >= 4:
                expected_output = expected_output + 4
            assert torch.equal(output_data, expected_output)
            assert torch.equal(
                mappings._GatherFromSequenceParallelRegion.symbolic(None, input_data),
                expected_output,
            )

            # A (4, 4) tensor whose i-th row is filled with the value i.
            input_data = torch.vstack(
                [torch.ones(4) * i for i in range(4)]
            ).cuda()

            class Ctx:
                # Stand-in for the autograd context handed to ``backward``.
                tensor_parallel_output_grad = True
                output_split_sizes = None
                group = None
                use_global_buffer = False

            output_data = mappings._GatherFromSequenceParallelRegion.backward(Ctx(), input_data)
            expected_output = torch.ones((1, 4)).cuda() * 4 * (rank % 4)
            assert torch.equal(output_data[0], expected_output)
        finally:
            Utils.destroy_model_parallel()

    def test_ReduceScatterToSequenceParallelRegion(self):
        """Forward paths must reduce-scatter; ``backward`` must gather.

        With every rank contributing rows ``0..3`` the forward result on a
        rank is expected to be a row filled with ``4 * (rank % 4)``.
        """
        rank = int(os.environ['LOCAL_RANK'])
        Utils.initialize_model_parallel(4, 2)
        try:
            input_data = torch.vstack(
                [torch.ones(4) * i for i in range(4)]
            ).cuda()
            output_data = mappings.reduce_scatter_to_sequence_parallel_region(input_data)
            expected_output = torch.ones(4).cuda() * 4 * (rank % 4)
            assert torch.equal(output_data[0], expected_output)
            assert torch.equal(
                mappings._ReduceScatterToSequenceParallelRegion.symbolic(None, input_data),
                expected_output.reshape((1, 4)),
            )

            input_data = torch.ones(4).cuda() * rank

            class Ctx:
                # Stand-in for the autograd context handed to ``backward``.
                input_split_sizes = None
                group = None
                use_global_buffer = False

            output_data = mappings._ReduceScatterToSequenceParallelRegion.backward(Ctx(), input_data)
            expected_output = torch.concat(
                [torch.ones(4) * i for i in range(4)]
            ).cuda()
            if rank >= 4:
                expected_output = expected_output + 4
            assert torch.equal(output_data[0], expected_output)
        finally:
            Utils.destroy_model_parallel()