import unittest
from unittest.mock import Mock, patch
import torch
import numpy as np

from mindspeed_mm.utils.utils import dist_sort, EncoderBalanceComm
from tests.ut.utils import judge_expression


class TestDistSort(unittest.TestCase):
    """Test cases for dist_sort function"""

    def test_dist_sort_equal_distribution(self):
        """Test dist_sort when images are equally distributed"""
        image_num_list = np.array([4, 4, 4, 4])
        transfer, target = dist_sort(image_num_list)

        judge_expression(np.sum(transfer) == 0)
        judge_expression(all(t == 4 for t in target))

    def test_dist_sort_unequal_distribution(self):
        """Test dist_sort when images need redistribution"""
        image_num_list = np.array([8, 2, 6, 0])
        transfer, target = dist_sort(image_num_list)

        total_before = np.sum(image_num_list)
        total_after = np.sum(target)

        judge_expression(total_before == total_after)
        judge_expression(all(t == 4 for t in target))
        judge_expression(transfer.shape == (4, 4))
        judge_expression(np.sum(transfer) > 0)

    def test_dist_sort_with_remainder(self):
        """Test dist_sort when total images don't divide evenly"""
        image_num_list = np.array([7, 1, 3, 2])
        _, target = dist_sort(image_num_list)

        total_target = np.sum(target)
        judge_expression(total_target == 13)

        avg = 13 // 4
        judge_expression(all(t >= avg and t <= avg + 1 for t in target))


class TestEncoderBalanceComm(unittest.TestCase):
    """Test cases for EncoderBalanceComm autograd function"""

    def setUp(self):
        """Set up test fixtures"""
        self.device = torch.device('cpu')
        self.world_size = 4
        self.rank = 0

    @patch('torch.distributed.get_rank')
    @patch('torch.distributed.get_world_size')
    def test_forward_no_transfer_needed(self, mock_world_size, mock_rank):
        """Test forward when no load balancing is needed"""
        mock_rank.return_value = self.rank
        mock_world_size.return_value = self.world_size

        mock_group = Mock()

        input_tensor = torch.randn(4, 64, dtype=torch.float32)

        transfer = np.zeros((4, 4))
        target = [4, 4, 4, 4]

        result = EncoderBalanceComm.apply(input_tensor, mock_group, (transfer, target))

        judge_expression(torch.equal(result, input_tensor))

    @patch('torch.distributed.get_rank')
    @patch('torch.distributed.get_world_size')
    @patch('torch.distributed.all_to_all')
    def test_forward_with_transfer(self, mock_all_to_all, mock_world_size, mock_rank):
        """Test forward when load balancing is needed"""
        mock_rank.return_value = self.rank
        mock_world_size.return_value = self.world_size

        mock_group = Mock()

        input_tensor = torch.randn(8, 64, dtype=torch.float32)

        transfer = np.array([[0, 4, 0, 0],
                           [0, 0, 0, 0],
                           [0, 0, 0, 0],
                           [0, 0, 0, 0]])
        target = [4, 4, 4, 4]

        def mock_all_to_all_side_effect(recv_list, send_list, group=None):
            for i, send_tensor in enumerate(send_list):
                if i < len(recv_list) and send_tensor.numel() > 0:
                    recv_list[i].resize_(send_tensor.shape)
                    recv_list[i].copy_(send_tensor)

        mock_all_to_all.side_effect = mock_all_to_all_side_effect

        result = EncoderBalanceComm.apply(input_tensor, mock_group, (transfer, target))

        judge_expression(isinstance(result, torch.Tensor))
        judge_expression(result.dim() == 2)
        judge_expression(result.shape[1] == 64)

    def test_forward_skip_mode(self):
        """Test forward in skip mode"""
        mock_group = Mock()

        input_tensor = torch.randn(4, 64, dtype=torch.float32)

        transfer = np.zeros((4, 4))
        target = [4, 4, 4, 4]

        with patch('torch.distributed.get_rank', return_value=0), \
             patch('torch.distributed.get_world_size', return_value=4):

            result, returned_transfer = EncoderBalanceComm.apply(
                input_tensor, mock_group, (transfer, target), False, True
            )

            judge_expression(torch.equal(result, input_tensor))

    @patch('torch.distributed.get_rank')
    @patch('torch.distributed.get_world_size')
    def test_backward_no_transfer(self, mock_world_size, mock_rank):
        """Test backward when no transfer was needed"""
        mock_rank.return_value = self.rank
        mock_world_size.return_value = self.world_size

        mock_group = Mock()

        input_tensor = torch.randn(4, 64, dtype=torch.float32, requires_grad=True)

        transfer = np.zeros((4, 4))
        target = [4, 4, 4, 4]

        result = EncoderBalanceComm.apply(input_tensor, mock_group, (transfer, target))

        grad_output = torch.randn_like(result)
        result.backward(grad_output)

        judge_expression(input_tensor.grad is not None)
        judge_expression(torch.equal(input_tensor.grad, grad_output))

    def test_nopadding_flag(self):
        """Test the nopadding flag functionality"""
        mock_group = Mock()

        input_tensor = torch.randn(6, 64, dtype=torch.float32)

        transfer = np.zeros((4, 4))
        target = [4, 4, 4, 4]

        with patch('torch.distributed.get_rank', return_value=0), \
             patch('torch.distributed.get_world_size', return_value=4):

            result = EncoderBalanceComm.apply(
                input_tensor, mock_group, (transfer, target), False, False
            )

            judge_expression(result.shape[0] <= target[0] or np.sum(transfer) > 0)