import pytest
import torch
import torch_npu
import mindspeed.megatron_adaptor
from tests_extend.unit_tests.common import DistributedTest
import megatron.core.parallel_state as Utils
from megatron.training.global_vars import set_args
from megatron.training.arguments import parse_args
from megatron.core.tensor_parallel.random import CudaRNGStatesTracker
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed, get_cuda_rng_tracker
from megatron.core.tensor_parallel.random import checkpoint


class TestTPRandom(DistributedTest):
    """Distributed tests for ``megatron.core.tensor_parallel.random``.

    Covers the RNG state tracker (``CudaRNGStatesTracker``), model-parallel
    seeding (``model_parallel_cuda_manual_seed``) and the activation
    ``checkpoint`` helper.
    """

    # Number of ranks requested from the DistributedTest harness.
    world_size = 8
    # Megatron's global args are parsed and registered once, when the class
    # body is executed, so every test method can rely on them being set.
    args = parse_args(None, True)
    set_args(args)

    def test_cuda_rng_states_tracker(self):
        """Exercise set/get/reset/add/fork on a fresh ``CudaRNGStatesTracker``."""
        rng_tracker = CudaRNGStatesTracker()

        # set_states / get_states round-trip, then reset back to empty.
        rng_tracker.set_states({"state1": 1234})
        assert rng_tracker.get_states()["state1"] == 1234
        rng_tracker.reset()
        assert rng_tracker.get_states() == {}

        seed = 1111
        rng_tracker.add("state2", seed)
        # add() is called bare inside pytest.raises: wrapping it in `assert`
        # would raise AssertionError on its None return value and make these
        # checks pass even if the tracker accepted the duplicate.
        with pytest.raises(Exception):
            rng_tracker.add("state3", seed)  # seed already registered
        with pytest.raises(Exception):
            rng_tracker.add("state2", 111)  # name already registered
        assert rng_tracker.get_states()["state2"] is not None

        # fork() is a context manager; it has to be entered to run at all.
        # No random numbers are drawn inside, so the state tracked for
        # "state2" is expected to still be the freshly seeded one afterwards.
        with rng_tracker.fork("state2"):
            pass
        torch.cuda.manual_seed(seed)
        expected_state = torch.cuda.get_rng_state()
        assert torch.equal(rng_tracker.get_states()["state2"], expected_state)

    def test_model_parallel_cuda_manual_seed(self):
        """Seeding must register the 'model-parallel-rng' tracker state."""
        Utils.initialize_model_parallel(4, 2)
        try:
            model_parallel_cuda_manual_seed(0)
            rng_tracker = get_cuda_rng_tracker()
            assert rng_tracker.get_states()['model-parallel-rng'] is not None
        finally:
            # Always tear down so a failed assertion cannot leak
            # model-parallel state into subsequent tests.
            Utils.destroy_model_parallel()

    def test_checkpoint(self):
        """Check ``checkpoint`` output and its effect on the first input."""

        def add_first_two(*input_list):
            # Forward function under test: sum of the first two inputs.
            return input_list[0] + input_list[1]

        # With the second argument set to None the call must simply return
        # the forward result.
        result = checkpoint(add_first_two, None, torch.ones(16), torch.ones(16) * 2)
        assert torch.equal(torch.ones(16) * 3, result)

        Utils.initialize_model_parallel()
        try:
            input1 = torch.ones((4, 4))
            # With the second argument set to True the first input is expected
            # to be rewritten in place: afterwards it must compare equal to a
            # flat device tensor of ones with the same number of elements.
            checkpoint(add_first_two, True, input1, torch.ones((4, 4)) * 2)
            assert torch.equal(torch.ones(input1.numel()).cuda(), input1)
        finally:
            # Always tear down so a failed assertion cannot leak
            # model-parallel state into subsequent tests.
            Utils.destroy_model_parallel()