from unittest.mock import MagicMock, patch, PropertyMock

from mindspeed_mm.fsdp.tools.flops_tool.flops_qwen3_5 import Qwen35FlopsCounter





class TestQwen35FlopsCounter:
    """Unit tests for ``Qwen35FlopsCounter.estimate_flops`` delegation.

    ``_estimate_qwen3_5_family_flops`` is patched out in every test, so these
    tests check how ``estimate_flops`` hands work to the family estimator, not
    the numeric FLOPs values it produces.
    """

    def setup_method(self, method):
        """Build a small mock Qwen3.5 config and a counter wrapping it.

        The config object is passed to ``Qwen35FlopsCounter`` directly, so no
        ``AutoConfig`` patching is required here.
        """
        # MagicMock(**kwargs) sets each keyword as a plain attribute.
        self.mock_text_config = MagicMock(
            hidden_size=1024,
            vocab_size=10000,
            num_hidden_layers=2,
            num_attention_heads=16,
            num_key_value_heads=16,
            full_attention_interval=2,
            intermediate_size=4096,
            # Linear-attention layer dimensions.
            linear_num_key_heads=8,
            linear_key_head_dim=64,
            linear_num_value_heads=8,
            linear_value_head_dim=64,
            linear_conv_kernel_dim=16,
        )
        self.mock_vision_config = MagicMock(
            num_heads=8,
            depth=2,
            hidden_size=512,
            intermediate_size=2048,
            out_hidden_size=512,
            spatial_merge_size=2,
            in_channels=3,
            temporal_patch_size=2,
            patch_size=14,
        )
        mock_config = MagicMock(
            text_config=self.mock_text_config,
            vision_config=self.mock_vision_config,
        )

        self.counter = Qwen35FlopsCounter(config=mock_config)

    def test_estimate_flops_text_only_dense(self, mocker):
        """Text-only input with a dense (non-MoE) text config.

        Checks that ``estimate_flops`` delegates exactly once to the family
        estimator.
        """
        # MagicMock fabricates a (truthy) attribute on any access, so without
        # this the "dense" config would expose ``num_experts`` and could be
        # mistaken for MoE by an attribute-presence or truthiness check.
        self.mock_text_config.num_experts = None
        mock_estimate_family = mocker.patch.object(
            self.counter, "_estimate_qwen3_5_family_flops"
        )

        self.counter.estimate_flops(batch_seqlens=[128, 128], step_time=6.9)

        mock_estimate_family.assert_called_once()

    def test_estimate_flops_text_only_moe(self, mocker):
        """Text-only input with an MoE text config.

        Checks that ``estimate_flops`` still delegates exactly once to the
        family estimator when MoE fields are present on the config.
        """
        self.mock_text_config.num_experts = 8
        self.mock_text_config.num_experts_per_tok = 2
        self.mock_text_config.moe_intermediate_size = 1024
        self.mock_text_config.shared_expert_intermediate_size = 512
        mock_estimate_family = mocker.patch.object(
            self.counter, "_estimate_qwen3_5_family_flops"
        )

        self.counter.estimate_flops(batch_seqlens=[64], step_time=6.9)

        mock_estimate_family.assert_called_once()

    def test_estimate_flops_multimodal_with_vit(self, mocker):
        """Multimodal input that includes image sequences.

        Checks that ``images_seqlens`` is forwarded unchanged, as a keyword
        argument, to the family estimator.
        """
        mock_estimate_family = mocker.patch.object(
            self.counter, "_estimate_qwen3_5_family_flops"
        )
        # Patched only to keep real ViT math from running if estimate_flops
        # reaches it directly; no call assertion is made, because with the
        # family estimator mocked it may never be reached.
        mocker.patch.object(self.counter, "_estimate_qwen3_vit_flop")

        images_seqlens = [100]
        self.counter.estimate_flops(
            batch_seqlens=[256], images_seqlens=images_seqlens, step_time=6.9
        )

        mock_estimate_family.assert_called_once()
        _, kwargs = mock_estimate_family.call_args
        assert "images_seqlens" in kwargs
        assert kwargs["images_seqlens"] == images_seqlens