from argparse import Namespace
import types

import torch
import torch_npu
import pytest
from pytest_mock import MockFixture

import mindspeed.megatron_adaptor


from mindspeed.core.pipeline_parallel.variable_seq_length.adaptor import (
    mindspeed_communicate,
    mindspeed_commuticate_shapes,
)
from mindspeed.core.pipeline_parallel.variable_seq_length.communicate import (
    communicate_impl,
    communicate_shapes_impl,
)


def test_mindspeed_communicate(mocker: MockFixture):
    mocker.patch(
        "mindspeed.core.pipeline_parallel.variable_seq_length.adaptor.communicate_impl",
        return_value=(1, 2),
    )
    ret = mindspeed_communicate(
        tensor_send_next=None,
        tensor_send_prev=None,
        recv_prev=False,
        recv_next=False,
        tensor_shape=None,
        config=None,
        wait_on_reqs=False,
    )
    assert ret == (1, 2)


def test_mindspeed_communicate_shapes(mocker: MockFixture):
    mocker.patch(
        "mindspeed.core.pipeline_parallel.variable_seq_length.adaptor.communicate_shapes_impl",
        return_value=(1, 2),
    )
    ret = mindspeed_commuticate_shapes(
        tensor_send_next=None,
        tensor_send_prev=None,
        recv_prev=False,
        recv_next=False,
        config=None,
    )
    assert ret == (1, 2)


@pytest.mark.parametrize(
    " config, expected",
    [
        (
            Namespace(
                use_ring_exchange_p2p=False,
                batch_p2p_comm=True,
                batch_p2p_sync=True,
            ),
            ([0, 0, 0]),
        ),
        (
            Namespace(
                use_ring_exchange_p2p=False,
                batch_p2p_comm=False,
                batch_p2p_sync=True,
            ),
            ([0, 0, 0]),
        ),
    ],
)
def test_communicate_shapes_impl(mocker: MockFixture, config, expected):
    mocker.patch("torch.cuda.synchronize")
    ret = communicate_shapes_impl(
        tensor_send_next=torch.tensor([1, 2, 3]),
        tensor_send_prev=torch.tensor([4, 5, 6]),
        recv_prev=True,
        recv_next=True,
        config=config,
        get_pipeline_model_parallel_group=lambda: None,
        get_pipeline_model_parallel_next_rank=lambda: 1,
        get_pipeline_model_parallel_prev_rank=lambda: 2,
        batched_p2p_ops=lambda **kwargs: [],
        p2p_ops=lambda **kwargs: [],
    )
    assert ret[0] == expected


@pytest.mark.parametrize(
    " config, expected",
    [
        (
            Namespace(
                use_ring_exchange_p2p=False,
                batch_p2p_comm=False,
                batch_p2p_sync=True,
                variable_seq_lengths=True,
                pipeline_dtype=torch.float32,
            ),
            (0, 0, []),
        ),
    ],
)
def test_communicate_impl(mocker: MockFixture, config, expected):
    mocker.patch("torch.cuda.synchronize")
    mocker.patch(
        "mindspeed.core.pipeline_parallel.variable_seq_length.communicate.communicate_shapes_impl",
        return_value=([0, 0, 0], [0, 0, 0]),
    )
    prev, next, reqs = communicate_impl(
        tensor_send_next=torch.tensor([1, 2, 3]),
        tensor_send_prev=torch.tensor([4, 5, 6]),
        recv_prev=True,
        recv_next=True,
        tensor_shape=[1, 2, 3],
        config=config,
        get_pipeline_model_parallel_group=lambda: [1],
        get_pipeline_model_parallel_next_rank=lambda: [1],
        get_pipeline_model_parallel_prev_rank=lambda: [2],
        batched_p2p_ops=lambda **kwargs: {},
        p2p_ops=lambda **kwargs: {},
        original_batched_p2p_ops=lambda **kwargs: {},
        original_p2p_ops=lambda **kwargs: {}
    )
    assert prev.sum() == expected[0]
    assert next.sum() == expected[1]
    assert len(reqs) == len(expected[2])