import types

import pytest
import torch
import torch.fx as fx
from tensor_cast.config import performance_model as perf_config
from tensor_cast.compilation.shape_prop import shape_propagation
from tensor_cast.device import TEST_DEVICE
from tensor_cast.model_config import AttentionQuantConfig, QuantConfig
from tensor_cast.performance_model.op_benchmark import (
    OpBenchmark,
    get_op_impl,
    register_op_impl,
)
from tensor_cast.performance_model.op_estimator_registry import (
    _op_estimator_table,
    get_op_estimator,
    register_op_estimator,
)
from tensor_cast.performance_model.op_invoke_info import OpInvokeInfo
from tensor_cast.quantize_utils import AttentionQuantType


class _NonDefaultEpsRMSNormModule(torch.nn.Module):
    """Small RMS-norm module with a caller-supplied epsilon.

    ``forward`` normalizes the raw input and also normalizes the residual sum
    through two separately written routes (sum built inline vs. sum bound to
    a name first), so a caller can compare the two results.
    """

    def __init__(self, eps: float = 1e-5):
        super().__init__()
        # Scale vector of ones; its length fixes the last dimension at 4.
        self.weight = torch.nn.Parameter(torch.ones(4, dtype=torch.float32))
        self.eps = eps

    def _rms_norm(self, hidden_states):
        """Normalize over the last dimension in float32, cast back, then scale.

        The mean of squares is taken along dim -1 (kept for broadcasting) and
        ``self.eps`` is added before the reciprocal square root.
        """
        original_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square + self.eps)
        return self.weight * normalized.to(original_dtype)

    def forward(self, hidden_states, residual):
        """Return ``(rms, add_rms, add_rms2, added)``.

        ``rms`` is the norm of ``hidden_states`` alone; ``add_rms`` and
        ``add_rms2`` are both the norm of ``hidden_states + residual``, and
        ``added`` is that sum itself.
        """
        plain_norm = self._rms_norm(hidden_states)
        # Route 1: the sum is formed inline as the call argument.
        inline_sum_norm = self._rms_norm(hidden_states + residual)
        # Route 2: the sum is bound to a name first, then normalized.
        summed = hidden_states + residual
        named_sum_norm = self._rms_norm(summed)
        return plain_norm, inline_sum_norm, named_sum_norm, summed


def test_register_and_get_op_estimator():
    """Register an estimator under a fresh key and look it up again.

    The registry entry is restored in a ``finally`` block so that a failing
    assertion cannot leave the throwaway estimator behind in the shared
    ``_op_estimator_table`` for later tests.
    """
    # A fresh ``object()`` cannot collide with any real op key.
    op_key = object()
    previous = _op_estimator_table.get(None, {}).get(op_key)

    try:

        @register_op_estimator(op_key, None, True)
        def _estimator(op_invoke_info, device_profile):
            return "ok"

        assert get_op_estimator(op_key, None) is _estimator
    finally:
        # Presumably the registration lands in the ``None`` bucket of the
        # table (this mirrors the lookup above); tolerate a missing bucket
        # so cleanup itself never masks the real test failure.
        bucket = _op_estimator_table.get(None)
        if bucket is not None:
            if previous is None:
                bucket.pop(op_key, None)
            else:
                bucket[op_key] = previous


def test_rms_norm_non_default_eps_path_consistency():
    """Both residual-sum routes of the module must agree for eps=1e-6."""
    shape = (2, 4)
    hidden_states = torch.randn(shape, dtype=torch.float32)
    residual = torch.randn(shape, dtype=torch.float32)

    outputs = _NonDefaultEpsRMSNormModule(eps=1e-6)(hidden_states, residual)
    # The first output (norm of the raw input) is not checked here.
    inline_sum_norm, named_sum_norm, summed = outputs[1:]

    torch.testing.assert_close(inline_sum_norm, named_sum_norm, rtol=1e-6, atol=1e-6)
    # The returned sum must match a freshly computed one exactly.
    torch.testing.assert_close(summed, hidden_states + residual, rtol=0.0, atol=0.0)


def test_quant_attention_config_can_target_single_layer():
    """An attention quant config stored under key 0 can be read back."""
    layer_index = 0
    quant_config = QuantConfig()

    # Each scale gets its own scalar tensor of 1.0.
    scale_kwargs = {
        field: torch.tensor(1.0)
        for field in ("query_scale", "kv_scale", "attention_prob_scale")
    }
    quant_config.attention_configs[layer_index] = AttentionQuantConfig(
        quant_type=AttentionQuantType.INT8,
        **scale_kwargs,
    )

    stored = quant_config.attention_configs
    assert layer_index in stored
    assert stored[layer_index].quant_type == AttentionQuantType.INT8


def test_multistream_count_nodes_helper_behavior():
    """Counting nodes by target finds the single ``aten.neg`` call in a tiny graph."""
    neg_op = torch.ops.aten.neg.default

    # Build the graph: placeholder -> neg -> output.
    graph = fx.Graph()
    graph_input = graph.placeholder("x")
    negated = graph.call_function(neg_op, args=(graph_input,))
    graph.output(negated)
    gm = fx.GraphModule({}, graph)

    matching_nodes = [node for node in gm.graph.nodes if node.target == neg_op]
    assert len(matching_nodes) == 1


def test_grouped_matmul_meta_ops_preserve_shapes_and_dtype():
    """Check output shapes and dtypes of the grouped-matmul ops on meta tensors."""
    ops = torch.ops.tensor_cast

    def meta(*shape):
        # Shape-only tensor on the meta device.
        return torch.empty(shape, device="meta")

    x = [meta(2, 3), meta(1, 3)]
    w = [meta(3, 4), meta(3, 4)]
    bias = [None, meta(4)]
    scales = [meta(1), meta(1)]

    plain_out = ops.grouped_matmul.default(x, w, bias)
    assert plain_out.shape == (3, 4)

    quant_out = ops.grouped_matmul_quant.default(
        x, w, scales, [None, None], scales, [None, None], bias, None
    )
    assert quant_out.shape == (3, 4)

    # An explicit dtype as the last argument is expected on the output.
    int4_out = ops.grouped_matmul_quant_int4.default(
        x, w, scales, [None, None], scales, [None, None], bias, torch.float16
    )
    assert int4_out.dtype == torch.float16

    fp8_out = ops.grouped_matmul_fp8.default(x, w, scales, scales, bias, torch.bfloat16)
    assert fp8_out.dtype == torch.bfloat16

    # With ``None`` as the last argument the result is expected to be float32.
    mxfp4_out = ops.grouped_matmul_mxfp4.default(x, w, scales, scales, bias, None)
    assert mxfp4_out.dtype == torch.float32

    # Empty group lists: the swiglu variants still return a (0, 0) / float32 result.
    swiglu_out = ops.grouped_matmul_swiglu.default([], [], [])
    assert swiglu_out.shape == (0, 0)

    quant_swiglu_out = ops.grouped_matmul_quant_swiglu.default([], [], [], [], [], [], [], None)
    assert quant_swiglu_out.dtype == torch.float32

    fp8_swiglu_out = ops.grouped_matmul_fp8_swiglu.default([], [], [], [], [], None)
    assert fp8_swiglu_out.shape == (0, 0)


def test_communication_meta_ops_compute_collective_shapes(monkeypatch):
    """Check output shapes/dtypes of the collective and fused-linear ops on meta tensors."""
    ops = torch.ops.tensor_cast
    x = torch.empty((4, 3), device="meta")

    # Plain collectives on a (4, 3) input.
    all_to_all_out = ops.all_to_all.default(x, [1, 3], [2, 2], 0, [0, 1])
    assert all_to_all_out.shape == (4, 3)
    all_reduce_out = ops.all_reduce.default(x, 0, [0, 1])
    assert all_reduce_out.shape == x.shape
    reduce_scatter_out = ops.reduce_scatter.default(x, 0, 0, [0, 1])
    assert reduce_scatter_out.shape == (2, 3)
    all_gather_out = ops.all_gather.default(x, 1, 0, [0, 1])
    assert all_gather_out.shape == (4, 6)

    matmul_out = ops.matmul_all_reduce.default(x, torch.empty((3, 5), device="meta"), None, 0, [0])
    assert matmul_out.shape == (4, 5)

    # Replace the underlying linear ops with stubs that return a fixed
    # float16 (4, 5) meta tensor, then call the fused *_all_reduce variants.
    linear_out = torch.empty((4, 5), device="meta", dtype=torch.float16)
    stubbed_linear_ops = (
        "static_quant_linear",
        "static_quant_linear_int4",
        "fp8_linear",
        "mxfp4_linear",
    )
    for linear_name in stubbed_linear_ops:
        monkeypatch.setattr(getattr(ops, linear_name), "default", lambda *args: linear_out)

    quant_args = (
        x,
        torch.empty((3, 5), device="meta"),
        torch.empty((1,), device="meta"),
        *([None] * 5),
        0,
        [0],
    )
    quant_out = ops.static_quant_linear_all_reduce.default(*quant_args)
    assert quant_out.shape == (4, 5)
    int4_out = ops.static_quant_linear_int4_all_reduce.default(*quant_args)
    assert int4_out.dtype == torch.float16

    fp_args = (
        x,
        torch.empty((3, 5), device="meta"),
        torch.empty((1,), device="meta"),
        torch.empty((1,), device="meta"),
        None,
        None,
        0,
        [0],
    )
    fp8_out = ops.fp8_linear_all_reduce.default(*fp_args)
    assert fp8_out.shape == (4, 5)
    mxfp4_out = ops.mxfp4_linear_all_reduce.default(*fp_args)
    assert mxfp4_out.shape == (4, 5)


def test_shape_propagation_records_tensor_metadata():
    """After shape propagation the node feeding the output carries a (2, 3) ``tensor_meta``."""

    class Tiny(torch.nn.Module):
        def forward(self, x):
            return x + 1

    traced = fx.symbolic_trace(Tiny())
    meta_input = torch.empty((2, 3), device="meta")
    result = shape_propagation(traced, [meta_input])

    # The first "output" node's sole argument is the value the graph returns.
    output_node = next(filter(lambda node: node.op == "output", result.graph.nodes))
    produced_node = output_node.args[0]
    recorded_shape = tuple(produced_node.meta["tensor_meta"].shape)
    assert recorded_shape == (2, 3)


def test_op_benchmark_registry_runtime_and_quantize(monkeypatch):
    """Exercise the op-impl registry, the CPU quantize impl, and ``OpBenchmark`` basics."""
    cpu = torch.device("cpu")

    # Registered CPU implementation of the quantize op.
    quantize_impl = get_op_impl(torch.ops.tensor_cast.quantize.default, cpu)
    x = torch.tensor([1.1, 2.1])
    scale = torch.tensor([1.0, 1.0])
    quantized = quantize_impl(x, scale, torch.tensor([1.0, -1.0]))
    expected = torch.tensor([2, 1], dtype=torch.int8)
    assert torch.equal(quantized, expected)

    # Registering the same (name, device) pair twice must be rejected.
    # NOTE(review): the first registration is never removed, so running this
    # test twice in one process would presumably trip "already registered"
    # on the first call -- confirm whether the registry offers an unregister.
    op_name = "unit_test_op"
    register_op_impl(op_name, "cpu")(lambda tensor: tensor)
    with pytest.raises(ValueError, match="already registered"):
        register_op_impl(op_name, "cpu")(lambda tensor: tensor)

    benchmark = OpBenchmark(TEST_DEVICE)
    assert benchmark.runtime_device == cpu

    # Keep the timed run minimal: no warmup, a single measured iteration.
    for setting, value in (("warmup_runs", 0), ("benchmark_runs", 1)):
        monkeypatch.setattr(perf_config.empirical, setting, value)
    meta_args = (torch.empty((2, 2), device="meta"),)
    result = benchmark.do_bench(lambda tensor: tensor + 1, meta_args, {})
    assert result.execution_time_s >= 0

    # With an explicit override configured, a new benchmark reports that device.
    monkeypatch.setattr(perf_config.empirical, "runtime_device_override", cpu)
    try:
        inferred = OpBenchmark(TEST_DEVICE).infer_runtime_device()
        assert inferred == cpu
    finally:
        # Clear the override for the remainder of this test.
        monkeypatch.setattr(perf_config.empirical, "runtime_device_override", None)

    class FakeTensorCastOp:
        namespace = "tensor_cast"
        is_view = False

    # An op in the tensor_cast namespace without a registered impl must raise.
    info = OpInvokeInfo(FakeTensorCastOp(), (), {}, None, cache_key="unit")
    with pytest.raises(ValueError, match="No implementation registered"):
        benchmark.benchmark(info)


def test_op_benchmark_handles_non_tensor_cast_ops(monkeypatch):
    """Benchmarking an ``aten`` op hands the op itself to ``do_bench``."""

    def _run_directly(op_impl, args, kwargs):
        # Stand-in for ``do_bench``: just invoke the op and return its result.
        return op_impl(*args, **kwargs)

    benchmark = OpBenchmark(TEST_DEVICE)
    monkeypatch.setattr(benchmark, "do_bench", _run_directly)

    # Minimal stand-in exposing only ``func``, ``args`` and ``kwargs``.
    info = types.SimpleNamespace(
        func=torch.ops.aten.neg.default,
        args=(torch.tensor([1.0]),),
        kwargs={},
    )

    outcome = benchmark.benchmark(info)
    assert torch.equal(outcome, torch.tensor([-1.0]))