import unittest

from unittest.mock import MagicMock, patch



import torch

import torch.distributed as dist

from torch_npu.testing.testcase import TestCase, run_tests



from mx_driving.dataset.utils import DynamicSampler, DynamicDistributedSampler, ReplicasDistributedSampler

from mx_driving.dataset.agent_dataset import DynamicBatchSampler, AgentDynamicBatchSampler, AgentDynamicBatchDataLoader





class TestDynamicSampler(TestCase):
    """Tests for the ``DynamicSampler`` base class."""

    def test_abstract_methods(self):
        """Calling ``DynamicSampler()`` with no arguments must raise ``TypeError``."""
        self.assertRaises(TypeError, DynamicSampler)





class TestDynamicDistributedSampler(TestCase):
    """Unit tests for ``DynamicDistributedSampler`` using a mocked bucketed dataset."""

    def setUp(self):
        """Build a mock dataset with three buckets holding ten indices in total."""
        super().setUp()
        self.mock_dataset = MagicMock()
        self.mock_dataset.buckets = [[1, 2, 3], [4, 5], [6, 7, 8, 9, 10]]
        self.mock_dataset.__len__ = MagicMock(return_value=10)

    def _flattened_buckets(self):
        """Return every index stored in the mock dataset's buckets as one flat list."""
        return [index for bucket in self.mock_dataset.buckets for index in bucket]

    @patch('torch.distributed.is_available')
    @patch('torch.distributed.get_world_size')
    @patch('torch.distributed.get_rank')
    def test_init_with_distributed(self, mock_get_rank, mock_get_world_size, mock_is_available):
        """World size and rank are taken from (patched) ``torch.distributed``."""
        mock_is_available.return_value = True
        mock_get_world_size.return_value = 2
        mock_get_rank.return_value = 0

        sampler = DynamicDistributedSampler(self.mock_dataset)

        self.assertEqual(sampler.num_replicas, 2)
        self.assertEqual(sampler.rank, 0)
        self.assertEqual(sampler.num_samples, 5)
        self.assertEqual(sampler.total_size, 10)

    def test_init_without_distributed(self):
        """Without explicit replicas/rank and no distributed support, expect ``RuntimeError``."""
        with patch('torch.distributed.is_available', return_value=False):
            with self.assertRaises(RuntimeError):
                DynamicDistributedSampler(self.mock_dataset)

    def test_init_with_invalid_dataset(self):
        """Malformed ``buckets`` must be rejected with ``ValueError``.

        ``num_replicas`` and ``rank`` are passed explicitly and
        ``is_available`` is patched, exactly as the other tests in this class
        do, so the outcome does not depend on the real ``torch.distributed``
        state. Otherwise an error raised for an uninitialised process group
        could mask the bucket validation or be mistaken for it.
        """
        invalid_dataset = MagicMock()
        invalid_bucket_layouts = (
            "not a list",                     # buckets is not a list at all
            [[1, 2], "not a list", [3, 4]],   # one bucket is not a list
        )

        with patch('torch.distributed.is_available', return_value=False):
            for buckets in invalid_bucket_layouts:
                invalid_dataset.buckets = buckets
                with self.assertRaises(ValueError):
                    DynamicDistributedSampler(invalid_dataset, num_replicas=2, rank=0)

    def test_init_with_custom_replicas_and_rank(self):
        """Explicit ``num_replicas``/``rank`` are honoured; 10 items over 4 replicas pad to 12."""
        with patch('torch.distributed.is_available', return_value=False):
            sampler = DynamicDistributedSampler(self.mock_dataset, num_replicas=4, rank=1)

        self.assertEqual(sampler.num_replicas, 4)
        self.assertEqual(sampler.rank, 1)
        self.assertEqual(sampler.num_samples, 3)
        self.assertEqual(sampler.total_size, 12)

    def test_bucket_arange(self):
        """``bucket_arange`` returns all ten bucket indices, in any order."""
        with patch('torch.distributed.is_available', return_value=False):
            sampler = DynamicDistributedSampler(self.mock_dataset, num_replicas=2, rank=0)
        sampler.epoch = 1

        result = sampler.bucket_arange()

        self.assertEqual(len(result), 10)
        # Order may differ, so compare the sorted contents.
        self.assertEqual(sorted(result), sorted(self._flattened_buckets()))

    def test_iter_with_shuffle(self):
        """With ``bucket_arange`` stubbed, rank 0 of 2 yields the first five indices."""
        with patch('torch.distributed.is_available', return_value=False):
            sampler = DynamicDistributedSampler(self.mock_dataset, num_replicas=2, rank=0, shuffle=True)

        # Stub the arrangement so the expected output is deterministic.
        with patch.object(sampler, 'bucket_arange', return_value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]):
            indices = list(iter(sampler))
            self.assertEqual(indices, [1, 2, 3, 4, 5])
            self.assertEqual(len(indices), sampler.num_samples)

    def test_iter_without_shuffle(self):
        """Without shuffling, rank 0 of 2 yields the first five indices in bucket order."""
        with patch('torch.distributed.is_available', return_value=False):
            sampler = DynamicDistributedSampler(self.mock_dataset, num_replicas=2, rank=0, shuffle=False)

        indices = list(iter(sampler))

        expected = self._flattened_buckets()[:5]
        self.assertEqual(indices, expected)
        self.assertEqual(len(indices), sampler.num_samples)

    def test_set_epoch(self):
        """``set_epoch`` stores the given value on ``sampler.epoch``."""
        with patch('torch.distributed.is_available', return_value=False):
            sampler = DynamicDistributedSampler(self.mock_dataset, num_replicas=2, rank=0)

        sampler.set_epoch(5)

        self.assertEqual(sampler.epoch, 5)





class TestReplicasDistributedSampler(TestCase):
    """Tests for ``ReplicasDistributedSampler`` against a mocked three-bucket dataset."""

    def setUp(self):
        """Prepare a mock dataset whose buckets hold six indices in total."""
        super().setUp()
        dataset = MagicMock()
        dataset.buckets = [[1, 2], [3, 4], [5, 6]]
        dataset.__len__ = MagicMock(return_value=6)
        self.mock_dataset = dataset

    def _new_sampler(self, **extra):
        """Create a sampler for rank 0 of 2 replicas with ``is_available`` patched to False."""
        with patch('torch.distributed.is_available', return_value=False):
            return ReplicasDistributedSampler(self.mock_dataset, num_replicas=2, rank=0, **extra)

    def test_init_with_valid_dataset(self):
        """Constructor exposes the expected replica, rank and size attributes."""
        sampler = self._new_sampler()

        expected_attrs = {
            'num_replicas': 2,
            'rank': 0,
            'num_samples': 3,
            'total_size': 6,
        }
        for attr_name, attr_value in expected_attrs.items():
            self.assertEqual(getattr(sampler, attr_name), attr_value)

    def test_bucket_arange(self):
        """``bucket_arange`` yields six entries matching the bucket contents when sorted."""
        sampler = self._new_sampler()
        sampler.epoch = 1

        arranged = sampler.bucket_arange()
        self.assertEqual(len(arranged), 6)

        flat_indices = [idx for bucket in self.mock_dataset.buckets for idx in bucket]
        self.assertEqual(sorted(arranged), sorted(flat_indices))

    def test_iter_with_shuffle(self):
        """With ``bucket_arange`` stubbed, iteration yields ``[1, 2, 3]`` for rank 0."""
        sampler = self._new_sampler(shuffle=True)

        stubbed_order = [1, 2, 3, 4, 5, 6]
        with patch.object(sampler, 'bucket_arange', return_value=stubbed_order):
            iterator = iter(sampler)
            indices = list(iterator)
            self.assertEqual(indices, [1, 2, 3])
            self.assertEqual(len(indices), sampler.num_samples)





class TestDynamicBatchSampler(TestCase):
    """Unit tests for ``DynamicBatchSampler`` batching over a mocked index sampler."""

    def setUp(self):
        """Create a mock dataset with per-index ``agents_num`` and a mock sampler over 0..4."""
        super().setUp()
        self.mock_dataset = MagicMock()
        self.mock_dataset.agents_num = {0: 15, 1: 25, 2: 5, 3: 35, 4: 45}

        self.mock_sampler = MagicMock()
        # Configure MagicMock's built-in ``__iter__`` with a list instead of a
        # pre-built iterator: a list is re-iterated on every ``iter()`` call,
        # whereas a single iterator object is exhausted after one pass and
        # would silently yield nothing on any later iteration.
        self.mock_sampler.__iter__.return_value = [0, 1, 2, 3, 4]

    def test_iter_without_drop_last(self):
        """With ``drop_last=False`` the trailing partial batch ``[4]`` is kept."""
        batch_sampler = DynamicBatchSampler(
            self.mock_dataset, self.mock_sampler, batch_size=2, drop_last=False)

        batches = list(batch_sampler)

        self.assertEqual(batches, [[0, 1], [2, 3], [4]])

    def test_iter_with_drop_last(self):
        """With ``drop_last=True`` the trailing partial batch is discarded."""
        batch_sampler = DynamicBatchSampler(
            self.mock_dataset, self.mock_sampler, batch_size=2, drop_last=True)

        batches = list(batch_sampler)

        self.assertEqual(batches, [[0, 1], [2, 3]])





class TestAgentDynamicBatchSampler(TestCase):
    """Tests for ``AgentDynamicBatchSampler`` with ``torch.distributed`` patched to 2 ranks."""

    def setUp(self):
        """Prepare a mock dataset with buckets ``[[0, 1], [2, 3], [4]]`` and length 5."""
        super().setUp()
        dataset = MagicMock()
        dataset.buckets = [[0, 1], [2, 3], [4]]
        dataset.__len__ = MagicMock(return_value=5)
        self.mock_dataset = dataset

    @staticmethod
    def _two_rank_world():
        """Return a patcher making ``torch.distributed`` report available, world size 2, rank 0."""
        return patch.multiple(
            'torch.distributed',
            is_available=MagicMock(return_value=True),
            get_world_size=MagicMock(return_value=2),
            get_rank=MagicMock(return_value=0),
        )

    def test_init_with_drop_last(self):
        """``drop_last=True`` gives 2 samples per rank and a total size of 4."""
        with self._two_rank_world():
            sampler = AgentDynamicBatchSampler(self.mock_dataset, drop_last=True)

            self.assertEqual(sampler.num_samples, 2)
            self.assertEqual(sampler.total_size, 4)

    def test_iter_with_drop_last(self):
        """Rank 0 yields ``[0, 2]`` from the stubbed arrangement when dropping the tail."""
        with self._two_rank_world():
            sampler = AgentDynamicBatchSampler(self.mock_dataset, drop_last=True)

            stubbed_order = [0, 1, 2, 3, 4]
            with patch.object(sampler, 'bucket_arange', return_value=stubbed_order):
                indices = list(sampler)
                self.assertEqual(indices, [0, 2])
                self.assertEqual(len(indices), sampler.num_samples)

    def test_iter_without_drop_last(self):
        """Rank 0 yields ``[0, 2, 4]`` from the stubbed arrangement when keeping the tail."""
        with self._two_rank_world():
            sampler = AgentDynamicBatchSampler(self.mock_dataset, drop_last=False)

            stubbed_order = [0, 1, 2, 3, 4]
            with patch.object(sampler, 'bucket_arange', return_value=stubbed_order):
                indices = list(sampler)
                self.assertEqual(indices, [0, 2, 4])
                self.assertEqual(len(indices), sampler.num_samples)





class TestAgentDynamicBatchDataLoader(TestCase):
    """Tests how ``AgentDynamicBatchDataLoader`` wires its sampler, batch sampler and collater."""

    def setUp(self):
        """Prepare a mock dataset carrying an ``agents_num`` mapping for five indices."""
        super().setUp()
        dataset = MagicMock()
        dataset.agents_num = {0: 15, 1: 25, 2: 5, 3: 35, 4: 45}
        self.mock_dataset = dataset

    @patch('mx_driving.dataset.agent_dataset.AgentDynamicBatchSampler')
    @patch('mx_driving.dataset.agent_dataset.DynamicBatchSampler')
    @patch('mx_driving.dataset.agent_dataset.Collater')
    def test_init(self, mock_collater, mock_batch_sampler, mock_sampler):
        """The loader builds each (patched) collaborator once and exposes the instances."""
        # Stand-in instances returned by the three patched classes.
        sampler_obj = MagicMock()
        batch_sampler_obj = MagicMock()
        collater_obj = MagicMock()
        mock_sampler.return_value = sampler_obj
        mock_batch_sampler.return_value = batch_sampler_obj
        mock_collater.return_value = collater_obj

        loader_kwargs = {
            'batch_size': 2,
            'train_batch_size': 2,
            'shuffle': True,
            'follow_batch': ['agent'],
            'exclude_keys': ['exclude_key'],
        }
        dataloader = AgentDynamicBatchDataLoader(self.mock_dataset, **loader_kwargs)

        # Each collaborator must have been constructed exactly once with these arguments.
        mock_sampler.assert_called_once_with(self.mock_dataset, shuffle=True)
        mock_batch_sampler.assert_called_once_with(self.mock_dataset, sampler_obj, 2)
        mock_collater.assert_called_once_with(['agent'], ['exclude_key'])

        # The loader must expose the dataset and the constructed collaborators.
        self.assertEqual(dataloader.dataset, self.mock_dataset)
        self.assertEqual(dataloader.collate_fn, collater_obj)
        self.assertEqual(dataloader.batch_sampler, batch_sampler_obj)



# When executed as a script, run the tests through the run_tests helper
# imported from torch_npu.testing.testcase.
if __name__ == '__main__':

    run_tests()