import unittest

import pytest
import torch
from parameterized import parameterized
from tensor_cast.compilation import get_backend
from tensor_cast.core.config_resolver import ConfigResolver
from tensor_cast.core.quantization.datatypes import QuantizeLinearAction
from tensor_cast.core.user_config import UserInputConfig
from tensor_cast.device import TEST_DEVICE
from tensor_cast.layers.attention import AttentionTensorCast
from tensor_cast.layers.quant_linear import TensorCastQuantLinear
from tensor_cast.model_config import ModelConfig, ParallelConfig
from tensor_cast.performance_model.analytic import AnalyticPerformanceModel
from tensor_cast.performance_model.memory_tracker import MemoryTracker
from tensor_cast.quantize_utils import LinearQuantType, QuantGranularity
from tensor_cast.runtime import Runtime
from tensor_cast.transformers.custom_model_registry import get_moe_config
from tensor_cast.transformers.model import TransformerModel
from tensor_cast.transformers.utils import AutoModelConfigLoader

from .test_common import count_events, get_quant_config


@pytest.mark.nightly
class GmmPassTestCase(unittest.TestCase):
    """Check which grouped-matmul ops are recorded for compiled Qwen3 models.

    Every test builds a ``TransformerModel`` with
    ``num_hidden_layers_override=1``, compiles it with the tensor_cast backend,
    runs a single forward pass on ``meta``-device inputs inside a ``Runtime``
    and then asserts, through ``count_events``, that the expected
    ``torch.ops.tensor_cast`` grouped-matmul op(s) are each reported once.
    """

    # Model used by the quantized tests.
    _QUANT_MODEL_ID = "Qwen/Qwen3-235B-A22B"
    # Sequence length of the dummy inputs fed to the model.
    _NUM_TOKENS = 100

    def setUp(self):
        """Clear torch.compile caches so each test compiles from scratch."""
        torch.compiler.reset()

    def _quantized_model_config(self, model_id, **quant_config_kwargs):
        """Build the ``ModelConfig`` shared by the quantized tests.

        Args:
            model_id: Model identifier whose config is loaded through
                ``AutoModelConfigLoader``.
            **quant_config_kwargs: Forwarded unchanged to ``get_quant_config``.

        Returns:
            A ``ModelConfig`` using ``TensorCastQuantLinear`` and
            ``AttentionTensorCast`` with ``num_hidden_layers_override=1``.
        """
        hf_config = AutoModelConfigLoader().load_config(model_id)
        moe_config = get_moe_config(hf_config.model_type)
        return ModelConfig(
            ParallelConfig(),
            get_quant_config(**quant_config_kwargs),
            quant_linear_cls=TensorCastQuantLinear,
            attention_cls=AttentionTensorCast,
            num_hidden_layers_override=1,
            moe_config=moe_config,
            hf_config=hf_config,
        )

    def _compile_and_run(self, model_id, model_config, num_tokens=_NUM_TOKENS):
        """Compile the model, run one forward pass and return the runtime.

        The inputs live on the ``meta`` device, so no real tensor data is
        allocated. The logits shape is checked inside the ``Runtime`` context,
        exactly where the forward pass happens.

        Args:
            model_id: Model identifier passed to ``TransformerModel``.
            model_config: Resolved configuration for the model.
            num_tokens: Sequence length of the dummy ``[1, num_tokens]`` inputs.

        Returns:
            The ``Runtime`` instance that was active during the forward pass,
            for use with ``count_events``.
        """
        model = TransformerModel(model_id, model_config)
        # fullgraph=True makes the compile fail on graph breaks rather than
        # silently falling back to eager execution.
        model = torch.compile(model, backend=get_backend(), fullgraph=True)

        inputs = torch.empty([1, num_tokens], dtype=torch.long, device="meta")
        position_ids = torch.empty([1, num_tokens], dtype=torch.long, device="meta")
        device_profile = TEST_DEVICE
        perf_model = AnalyticPerformanceModel(device_profile)
        with (
            Runtime(perf_model, device_profile, memory_tracker=MemoryTracker(device_profile)) as runtime,
            torch.no_grad(),
        ):
            outputs = model.forward(inputs, position_ids)
            self.assertEqual(outputs.shape, (1, num_tokens, model.vocab_size))
        return runtime

    def _assert_counted_once(self, runtime, op):
        """Assert that ``count_events`` reports exactly one event for ``op``."""
        self.assertEqual(count_events(runtime, op), 1)

    @parameterized.expand(
        [
            "Qwen/Qwen3-235B-A22B",
            "Qwen/Qwen3-VL-30B-A3B-Instruct",
        ]
    )
    def test_qwen3_fp(self, model_id):
        """With linear quantization disabled, ``grouped_matmul`` is counted once."""
        user_input = UserInputConfig(
            model_id=model_id,
            do_compile=True,
            num_hidden_layers_override=1,
            quantize_linear_action=QuantizeLinearAction.DISABLED,
        )
        model_config = ConfigResolver(user_input=user_input).resolve()

        runtime = self._compile_and_run(model_id, model_config)

        self._assert_counted_once(runtime, torch.ops.tensor_cast.grouped_matmul.default)

    def test_qwen3_static_int8(self):
        """W8A8 with a fixed activation scale yields the int8 grouped-matmul ops."""
        model_id = self._QUANT_MODEL_ID
        model_config = self._quantized_model_config(
            model_id,
            quant_type=LinearQuantType.W8A8,
            activation_scale=torch.tensor(1.0),
        )

        runtime = self._compile_and_run(model_id, model_config)

        self._assert_counted_once(runtime, torch.ops.tensor_cast.grouped_matmul_quant.default)
        self._assert_counted_once(runtime, torch.ops.tensor_cast.grouped_matmul_quant_swiglu.default)

    @parameterized.expand(
        [
            [LinearQuantType.W8A8],
            [LinearQuantType.W4A8],
            [LinearQuantType.FP8],
            [LinearQuantType.MXFP4],
        ]
    )
    def test_qwen3_dynamic_quant(self, quant_type):
        """Each quant type (no activation scale given) maps to its own op pair."""
        model_id = self._QUANT_MODEL_ID
        # MXFP4 is configured per-group with a group size of 32; every other
        # quant type here uses per-tensor weight granularity.
        is_mxfp4 = quant_type == LinearQuantType.MXFP4
        model_config = self._quantized_model_config(
            model_id,
            quant_type=quant_type,
            weight_quant_granularity=QuantGranularity.PER_GROUP if is_mxfp4 else QuantGranularity.PER_TENSOR,
            weight_group_size=32 if is_mxfp4 else None,
        )

        runtime = self._compile_and_run(model_id, model_config)

        if quant_type == LinearQuantType.W8A8:
            expected_op = torch.ops.tensor_cast.grouped_matmul_quant.default
            expected_swiglu_op = torch.ops.tensor_cast.grouped_matmul_quant_swiglu.default
        elif quant_type == LinearQuantType.W4A8:
            expected_op = torch.ops.tensor_cast.grouped_matmul_quant_int4.default
            expected_swiglu_op = torch.ops.tensor_cast.grouped_matmul_quant_int4_swiglu.default
        elif quant_type == LinearQuantType.FP8:
            expected_op = torch.ops.tensor_cast.grouped_matmul_fp8.default
            expected_swiglu_op = torch.ops.tensor_cast.grouped_matmul_fp8_swiglu.default
        elif quant_type == LinearQuantType.MXFP4:
            expected_op = torch.ops.tensor_cast.grouped_matmul_mxfp4.default
            expected_swiglu_op = torch.ops.tensor_cast.grouped_matmul_mxfp4_swiglu.default
        else:
            # Fail loudly instead of handing ``None`` to count_events when a
            # new quant type is added to the parameter list without a mapping.
            self.fail(f"No expected grouped-matmul ops defined for quant type {quant_type!r}")

        self._assert_counted_once(runtime, expected_op)
        self._assert_counted_once(runtime, expected_swiglu_op)