"""Unit tests for FSDP gradient clipping helper functions and control flow."""

import math
import os
import types

import pytest


def _make_param_with_grad(values):
    torch = pytest.importorskip("torch")
    param = torch.nn.Parameter(torch.zeros(len(values), dtype=torch.float32))
    param.grad = torch.tensor(values, dtype=torch.float32)
    return param


class TestLocalPthSum:
    @pytest.mark.parametrize(
        "grads,p,expected",
        [
            ([[3.0, 4.0]], 2.0, 25.0),
            ([[1.0, 2.0, 2.0]], 2.0, 9.0),
            ([[1.0, -2.0, 3.0]], 1.0, 6.0),
            ([[2.0, -3.0]], 3.0, 35.0),
            ([[0.0, 0.0]], 2.0, 0.0),
            ([[1.5, -2.0]], 2.0, 6.25),
            ([[1.0], [2.0], [3.0]], 2.0, 14.0),
            ([[1.0, 2.0], [3.0, 4.0]], 2.0, 30.0),
            ([[1.0, -2.0], [-3.0, 4.0]], 1.0, 10.0),
            ([[2.0, 2.0], [2.0, 2.0]], 4.0, 64.0),
        ],
    )
    def test_local_pth_sum_matches_manual_sum_of_powers(self, grads, p, expected):
        from mindspeed_mm.fsdp.optimizer.clip_grad_norm import _local_pth_sum

        params = [_make_param_with_grad(values) for values in grads]

        assert _local_pth_sum(params, p).item() == pytest.approx(expected)

    def test_local_pth_sum_ignores_parameters_without_grad(self):
        torch = pytest.importorskip("torch")
        from mindspeed_mm.fsdp.optimizer.clip_grad_norm import _local_pth_sum

        with_grad = _make_param_with_grad([3.0, 4.0])
        without_grad = torch.nn.Parameter(torch.zeros(2, dtype=torch.float32))

        assert _local_pth_sum([with_grad, without_grad], 2.0).item() == pytest.approx(25.0)

    def test_local_pth_sum_preserves_current_empty_param_behavior(self, monkeypatch):
        pytest.importorskip("torch")
        import mindspeed_mm.fsdp.optimizer.clip_grad_norm as mod

        monkeypatch.setattr(mod, "get_device_type", lambda: "cpu")

        with pytest.raises(RuntimeError, match="empty"):
            mod._local_pth_sum([], 2.0)


class TestLocalMax:
    @pytest.mark.parametrize(
        "grads,expected",
        [
            ([[3.0, 4.0]], 4.0),
            ([[-3.0, 4.0]], 4.0),
            ([[0.0, 0.0]], 0.0),
            ([[1.0, 2.0], [9.0]], 9.0),
            ([[-1.0, -2.0], [-8.0, 3.0]], 8.0),
            ([[0.5], [0.25], [0.125]], 0.5),
        ],
    )
    def test_local_max_returns_largest_absolute_gradient(self, grads, expected):
        from mindspeed_mm.fsdp.optimizer.clip_grad_norm import _local_max

        params = [_make_param_with_grad(values) for values in grads]

        assert _local_max(params).item() == pytest.approx(expected)

    def test_local_max_ignores_parameters_without_grad(self):
        torch = pytest.importorskip("torch")
        from mindspeed_mm.fsdp.optimizer.clip_grad_norm import _local_max

        with_grad = _make_param_with_grad([1.0, -7.0])
        without_grad = torch.nn.Parameter(torch.zeros(3, dtype=torch.float32))

        assert _local_max([without_grad, with_grad]).item() == pytest.approx(7.0)

    def test_local_max_returns_zero_for_empty_param_list(self, monkeypatch):
        pytest.importorskip("torch")
        import mindspeed_mm.fsdp.optimizer.clip_grad_norm as mod

        monkeypatch.setattr(mod, "get_device_type", lambda: "cpu")

        value = mod._local_max([])

        assert value.item() == pytest.approx(0.0)


class TestFsdp2ReduceGroup:
    def test_fsdp2_reduce_group_uses_sum_for_finite_norms(self, monkeypatch):
        pytest.importorskip("torch")
        import mindspeed_mm.fsdp.optimizer.clip_grad_norm as mod

        calls = []

        def fake_all_reduce(value, op=None, group=None):
            calls.append((value.item(), op, group))
            value.add_(10.0)

        monkeypatch.setattr(mod.dist, "all_reduce", fake_all_reduce)

        group_a = object()
        group_b = object()
        result = mod._fsdp2_reduce_group(
            params=[_make_param_with_grad([3.0, 4.0])],
            norm_type=2.0,
            reduce_groups=[("a", group_a), ("b", group_b), ("none", None)],
        )

        assert result.item() == pytest.approx(45.0)
        assert calls == [
            (25.0, mod.dist.ReduceOp.SUM, group_a),
            (35.0, mod.dist.ReduceOp.SUM, group_b),
        ]

    def test_fsdp2_reduce_group_uses_max_for_inf_norm(self, monkeypatch):
        pytest.importorskip("torch")
        import mindspeed_mm.fsdp.optimizer.clip_grad_norm as mod

        calls = []

        def fake_all_reduce(value, op=None, group=None):
            calls.append((value.item(), op, group))
            value.fill_(max(value.item(), 9.0))

        monkeypatch.setattr(mod.dist, "all_reduce", fake_all_reduce)

        group = object()
        result = mod._fsdp2_reduce_group(
            params=[_make_param_with_grad([3.0, 4.0])],
            norm_type=float("inf"),
            reduce_groups=[("fsdp", group)],
        )

        assert result.item() == pytest.approx(9.0)
        assert calls == [(4.0, mod.dist.ReduceOp.MAX, group)]

    def test_fsdp2_reduce_group_skips_none_groups(self, monkeypatch):
        pytest.importorskip("torch")
        import mindspeed_mm.fsdp.optimizer.clip_grad_norm as mod

        calls = []
        monkeypatch.setattr(mod.dist, "all_reduce", lambda *args, **kwargs: calls.append((args, kwargs)))

        result = mod._fsdp2_reduce_group(
            params=[_make_param_with_grad([1.0, 2.0])],
            norm_type=2.0,
            reduce_groups=[("none", None)],
        )

        assert result.item() == pytest.approx(5.0)
        assert calls == []


class TestClipGradNormControlFlow:
    def test_clip_grad_norm_uses_ep_path_when_model_has_ep_param_groups(self, monkeypatch):
        pytest.importorskip("torch")
        import torch
        import mindspeed_mm.fsdp.optimizer.clip_grad_norm as mod

        model = torch.nn.Linear(2, 2)
        model._ep_param_groups = {"ep": set(), "non_ep": set()}
        sentinel = torch.tensor(12.0)

        monkeypatch.setattr(mod, "ep_fsdp2_clip_grad_norm", lambda *args, **kwargs: sentinel)

        assert mod.clip_grad_norm(model, max_norm=1.0) is sentinel

    def test_clip_grad_norm_compute_only_does_not_modify_gradients(self, monkeypatch):
        torch = pytest.importorskip("torch")
        import mindspeed_mm.fsdp.optimizer.clip_grad_norm as mod

        dummy_ps = types.SimpleNamespace(get_fsdp_group=lambda: None)
        monkeypatch.setattr(mod, "get_parallel_state", lambda: dummy_ps)

        model = torch.nn.Linear(2, 2, bias=False)
        model.weight.grad = torch.full_like(model.weight, 3.0)
        before = model.weight.grad.clone()

        returned = mod.clip_grad_norm(model, max_norm=0.0, norm_type=2.0)

        assert torch.allclose(model.weight.grad, before)
        assert returned.item() == pytest.approx(math.sqrt(float(torch.sum(before.float() ** 2))))

    def test_ep_fsdp2_clip_grad_norm_compute_only_returns_combined_norm(self, monkeypatch):
        torch = pytest.importorskip("torch")
        import mindspeed_mm.fsdp.optimizer.clip_grad_norm as mod

        ep_param = _make_param_with_grad([3.0])
        non_ep_param = _make_param_with_grad([4.0])

        model = types.SimpleNamespace(
            _ep_param_groups={
                "ep": {ep_param},
                "non_ep": {non_ep_param},
            }
        )
        dummy_ps = types.SimpleNamespace(
            get_fsdp_group=lambda: None,
            get_ep_group=lambda: None,
            get_efsdp_group=lambda: None,
            is_ep_enable=lambda: True,
        )
        monkeypatch.setattr(mod, "get_parallel_state", lambda: dummy_ps)

        returned = mod.ep_fsdp2_clip_grad_norm(model, max_norm=0.0, norm_type=2.0)

        assert returned.item() == pytest.approx(5.0)
        assert torch.allclose(ep_param.grad, torch.tensor([3.0]))
        assert torch.allclose(non_ep_param.grad, torch.tensor([4.0]))