# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2026, Huawei Technologies Co., Ltd. All rights reserved.
#
# See LICENSE for license information.

"""Unit tests for multi_tensor operations in TransformerEngineNPU.

Aligned with upstream TransformerEngine test_multi_tensor.py.
"""

import sys
from typing import Tuple

import pytest
import torch

# Adjust path so we can import transformer_engine from TransformerEngineNPU
sys.path.insert(0, "transformer_engine")

from transformer_engine.pytorch.optimizers import (
    MultiTensorApply,
    multi_tensor_applier,
    multi_tensor_scale,
    multi_tensor_scale_tensor,
    multi_tensor_l2norm,
    multi_tensor_unscale_l2norm,
    multi_tensor_compute_scale_and_scale_inv,
    multi_tensor_compute_scale_inv_e8m0,
)

from transformer_engine.pytorch.quantization import is_mxfp8_available

# ---------------------------------------------------------------------------
# Device detection: prefer NPU, fall back to CUDA, then CPU
# ---------------------------------------------------------------------------
# Pick the first available backend in priority order.  ``torch.npu`` only
# exists when an NPU plugin has been loaded, hence the ``hasattr`` guard
# (presumably registered by the transformer_engine import above -- verify).
_device = torch.device(
    "npu"
    if hasattr(torch, "npu") and torch.npu.is_available()
    else "cuda"
    if torch.cuda.is_available()
    else "cpu"
)

# ---------------------------------------------------------------------------
# Shared test data
# ---------------------------------------------------------------------------
# Element counts of the two tensors ("a" and "b") that each test replicates
# ``repeat`` times.  Several pairs sit at or just past 2048 * 32 and 33333,
# which are also used as applier chunk sizes below.
input_size_pairs = [
    (7777 * 77, 555 * 555),
    (777, 555),
    (555, 2048 * 32 + 1),
    (2048 * 32 + 1, 555),
    (555, 2048 * 32),
    (2048 * 32, 555),
    (33333, 555),
    (555, 33333),
]

# One applier per chunk size under test; the constructor argument is the
# chunk size handed to the operator as its first positional argument.
appliers = [
    MultiTensorApply(chunk_size) for chunk_size in (2048 * 32, 333, 33333)
]

# Availability flag and the human-readable reason used as the skip message
# for the MXFP8-only test.
mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True)


# ---------------------------------------------------------------------------
# Reference implementation
# ---------------------------------------------------------------------------
def _scale_from_amax_tensor(
    amax: torch.Tensor,
    fp8_dtype: torch.dtype,
    eps: float,
    pow_2_scales: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Reference implementation for scale/scale_inv computation."""
    assert amax.dtype == torch.float32, "amax must be a float32 tensor."
    fp8_max = torch.finfo(fp8_dtype).max
    clamped = torch.max(amax, torch.tensor(eps, dtype=torch.float32, device=amax.device))

    scale = torch.div(
        torch.tensor(fp8_max, dtype=torch.float32, device=amax.device), clamped
    )
    scale = torch.where(
        torch.isinf(scale),
        torch.tensor(torch.finfo(torch.float32).max, dtype=torch.float32, device=amax.device),
        scale,
    )

    if pow_2_scales:
        _, exp = torch.frexp(scale)
        exp = exp - 1
        unity = torch.tensor(1.0, device=exp.device)
        torch.ldexp(unity, exp, out=scale)
        scale = torch.where(amax == float("inf"), 0.0, scale)

    scale = torch.where(amax == 0.0, 1.0, scale)
    scale_inv = torch.reciprocal(scale)
    return scale, scale_inv


# ===========================================================================
class TestMultiTensorScale:
    """Checks ``multi_tensor_scale`` output values and its overflow flag."""

    @pytest.mark.parametrize("input_size_pair", input_size_pairs)
    @pytest.mark.parametrize("applier", appliers)
    @pytest.mark.parametrize("repeat", [1, 55])
    @pytest.mark.parametrize("in_type", [torch.float32, torch.float16, torch.bfloat16])
    @pytest.mark.parametrize("out_type", [torch.float32, torch.float16, torch.bfloat16])
    @pytest.mark.parametrize("inplace", [False, True])
    def test_multi_tensor_scale(
        self, input_size_pair, applier, repeat, in_type, out_type, inplace
    ):
        """Scale tensors filled with 4.0 by 1/4 and probe inf/nan detection.

        The finite case expects every output to equal 1.0 with the overflow
        buffer left at zero.  Each poisoned case writes one nan/inf into an
        input tensor and expects the overflow buffer to become non-zero.
        """
        if inplace and out_type != in_type:
            pytest.skip("inplace=True and out_type != in_type is not supported")
        if (in_type, out_type) in (
            (torch.float16, torch.bfloat16),
            (torch.bfloat16, torch.float16),
        ):
            pytest.skip("float16 to bfloat16 is not necessary and vice versa")

        scale = 4.0
        sizea, sizeb = input_size_pair
        overflow_buf = torch.zeros(1, dtype=torch.int32, device=_device)
        ref = torch.tensor([1.0], dtype=torch.float32, device=_device)

        def build_lists():
            # Fresh (in_list, out_list); both names alias one list when inplace.
            sources = [
                torch.full([numel], scale, dtype=torch.float32, device=_device)
                for numel in (sizea, sizeb)
            ]
            outputs = [
                src.clone().to(out_type) for _ in range(repeat) for src in sources
            ]
            if inplace:
                return outputs, outputs
            return [out.clone().to(in_type) for out in outputs], outputs

        def launch(inputs, outputs):
            applier(multi_tensor_scale, overflow_buf, [inputs, outputs], 1.0 / scale)

        def run_case(poison=None):
            # ``poison`` is None for the finite check, otherwise a
            # (tensor index, element index, value) triple to inject.
            overflow_buf.zero_()
            inputs, outputs = build_lists()
            launch(inputs, outputs)

            if poison is None:
                assert all(torch.allclose(out, ref.to(out_type)) for out in outputs)
                assert overflow_buf.item() == 0
                return

            tensor_idx, elem_idx, bad_value = poison
            overflow_buf.zero_()
            inputs[tensor_idx][elem_idx] = bad_value
            launch(inputs, outputs)
            assert overflow_buf.item() > 0

        run_case()
        # First element of the first tensor, last element of the last tensor,
        # and the middle of an "a"-sized tensor halfway through the list.
        run_case((0, 0, float("nan")))
        run_case((2 * repeat - 1, sizeb - 1, float("inf")))
        run_case((2 * (repeat // 2), sizea // 2, float("inf")))


class TestMultiTensorScaleTensor:
    """Checks ``multi_tensor_scale_tensor``, where the factor is passed as a tensor."""

    @pytest.mark.parametrize("input_size_pair", input_size_pairs)
    @pytest.mark.parametrize("applier", appliers)
    @pytest.mark.parametrize("repeat", [1, 55])
    @pytest.mark.parametrize("in_type", [torch.float32, torch.float16, torch.bfloat16])
    @pytest.mark.parametrize("out_type", [torch.float32, torch.float16, torch.bfloat16])
    @pytest.mark.parametrize("inplace", [False, True])
    def test_multi_tensor_scale_tensor(
        self, input_size_pair, applier, repeat, in_type, out_type, inplace
    ):
        """Scale tensors filled with 4.0 by a 1/4 tensor and probe inf/nan detection.

        The finite case expects all outputs to equal 1.0 and a zero overflow
        buffer.  Each poisoned case injects a single nan/inf into one input
        and expects a non-zero overflow buffer afterwards.
        """
        if inplace and out_type != in_type:
            pytest.skip("inplace=True and out_type != in_type is not supported")
        if (in_type, out_type) in (
            (torch.float16, torch.bfloat16),
            (torch.bfloat16, torch.float16),
        ):
            pytest.skip("float16 to bfloat16 is not necessary and vice versa")

        scale = 4.0
        sizea, sizeb = input_size_pair
        inv_scale_tensor = torch.tensor([1.0 / scale], dtype=torch.float32, device=_device)
        overflow_buf = torch.zeros(1, dtype=torch.int32, device=_device)
        ref = torch.tensor([1.0], dtype=torch.float32, device=_device)

        def make_tensor_lists():
            # Returns (in_list, out_list); one shared list when running in place.
            a = torch.full([sizea], scale, dtype=torch.float32, device=_device)
            b = torch.full([sizeb], scale, dtype=torch.float32, device=_device)
            out_list = [
                src.clone().to(out_type) for _ in range(repeat) for src in (a, b)
            ]
            in_list = (
                out_list if inplace else [out.clone().to(in_type) for out in out_list]
            )
            return in_list, out_list

        def apply_scale(in_list, out_list):
            applier(
                multi_tensor_scale_tensor,
                overflow_buf,
                [in_list, out_list],
                inv_scale_tensor,
            )

        def check_finite():
            overflow_buf.zero_()
            in_list, out_list = make_tensor_lists()
            apply_scale(in_list, out_list)
            assert all(torch.allclose(out, ref.to(out_type)) for out in out_list)
            assert overflow_buf.item() == 0

        def check_flagged(t, ind, val):
            # Warm-up pass on clean data, then inject ``val`` and re-run.
            overflow_buf.zero_()
            in_list, out_list = make_tensor_lists()
            apply_scale(in_list, out_list)

            overflow_buf.zero_()
            in_list[t][ind] = val
            apply_scale(in_list, out_list)
            assert overflow_buf.item() > 0

        check_finite()
        check_flagged(0, 0, float("nan"))
        check_flagged(2 * repeat - 1, sizeb - 1, float("inf"))
        check_flagged(2 * (repeat // 2), sizea // 2, float("inf"))


class TestMultiTensorL2Norm:
    """Checks ``multi_tensor_l2norm`` against ``torch.Tensor.norm``."""

    @pytest.mark.parametrize("input_size_pair", input_size_pairs)
    @pytest.mark.parametrize("applier", appliers)
    @pytest.mark.parametrize("repeat", [1, 55])
    @pytest.mark.parametrize("in_type", [torch.float32, torch.float16, torch.bfloat16])
    @pytest.mark.parametrize("per_tensor", [False, True])
    def test_multi_tensor_l2norm(
        self, input_size_pair, applier, repeat, in_type, per_tensor
    ):
        """Compare the fused norm, and optionally per-tensor norms, to float32 references.

        Inputs are ``repeat`` copies of two constant tensors, so the global
        reference is the norm of one constant vector of the combined length.
        """
        sizea, sizeb = input_size_pair
        val = 4.0
        overflow_buf = torch.zeros(1, dtype=torch.int32, device=_device)

        a = torch.full([sizea], val, dtype=torch.float32, device=_device)
        b = torch.full([sizeb], val, dtype=torch.float32, device=_device)
        in_list = [src.clone().to(in_type) for _ in range(repeat) for src in (a, b)]

        # The second result is only inspected when per-tensor norms are requested.
        norm, norm_per_tensor = applier(
            multi_tensor_l2norm, overflow_buf, [in_list], per_tensor,
        )

        total_numel = (sizea + sizeb) * repeat
        reference = torch.full(
            [total_numel], val, dtype=torch.float32, device=_device,
        ).norm()
        torch.testing.assert_close(norm, reference.broadcast_to(norm.shape))

        if per_tensor:
            # Inputs alternate a, b, a, b, ... so rows of the (-1, 2) view
            # should each equal (||a||, ||b||).
            normab = torch.cat((a.norm().view(1), b.norm().view(1)))
            paired = norm_per_tensor.view(-1, 2)
            torch.testing.assert_close(paired, normab.broadcast_to(paired.shape))

        assert overflow_buf.item() == 0


class TestMultiTensorUnscaleL2Norm:
    """Checks ``multi_tensor_unscale_l2norm`` against norms of pre-scaled tensors."""

    @pytest.mark.parametrize("input_size_pair", input_size_pairs)
    @pytest.mark.parametrize("applier", appliers)
    @pytest.mark.parametrize("repeat", [1, 55])
    @pytest.mark.parametrize("in_type", [torch.float32, torch.float16, torch.bfloat16])
    @pytest.mark.parametrize("per_tensor", [False, True])
    def test_multi_tensor_unscale_l2norm(
        self, input_size_pair, applier, repeat, in_type, per_tensor
    ):
        """Expect the result to match the norm of the inputs multiplied by ``inv_scale``.

        The global reference is the norm of a constant ``val * inv_scale``
        vector of the combined length; per-tensor references are the norms
        of the two scaled source tensors.
        """
        sizea, sizeb = input_size_pair
        val = 4.0
        inv_scale = 0.5
        overflow_buf = torch.zeros(1, dtype=torch.int32, device=_device)
        inv_scale_tensor = torch.tensor([inv_scale], dtype=torch.float32, device=_device)

        a = torch.full([sizea], val, dtype=torch.float32, device=_device)
        b = torch.full([sizeb], val, dtype=torch.float32, device=_device)
        in_list = [src.clone().to(in_type) for _ in range(repeat) for src in (a, b)]

        norm, norm_per_tensor = applier(
            multi_tensor_unscale_l2norm,
            overflow_buf,
            [in_list],
            inv_scale_tensor,
            per_tensor,
        )

        total_numel = (sizea + sizeb) * repeat
        reference = torch.full(
            [total_numel], val * inv_scale, dtype=torch.float32, device=_device,
        ).norm()
        torch.testing.assert_close(norm, reference.broadcast_to(norm.shape))

        if per_tensor:
            # Inputs alternate a, b, ... so each row of the (-1, 2) view is
            # compared with (||a * inv_scale||, ||b * inv_scale||).
            scaled_norm_a = (a * inv_scale).norm().view(1)
            scaled_norm_b = (b * inv_scale).norm().view(1)
            normab = torch.cat((scaled_norm_a, scaled_norm_b))
            paired = norm_per_tensor.view(-1, 2)
            torch.testing.assert_close(paired, normab.broadcast_to(paired.shape))

        assert overflow_buf.item() == 0


class TestMultiTensorComputeScale:
    """Checks ``multi_tensor_compute_scale_and_scale_inv`` bit-for-bit against the reference."""

    @pytest.mark.parametrize("input_size_pair", input_size_pairs + [(1, 1)])
    @pytest.mark.parametrize("applier", appliers)
    @pytest.mark.parametrize("repeat", [1, 55])
    @pytest.mark.parametrize("fp8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
    @pytest.mark.parametrize("pow_2_scales", [False, True])
    @pytest.mark.parametrize("epsilon", [0.0, 100.0])
    def test_multi_tensor_compute_scale_and_scale_inv(
        self, input_size_pair, applier, repeat, fp8_dtype, pow_2_scales, epsilon,
    ):
        """Run the fused op on random non-negative amax tensors and demand exact equality."""
        sizea, sizeb = input_size_pair
        overflow_buf = torch.zeros(1, dtype=torch.int32, device=_device)

        # Absolute values of normal samples stand in for amax values.
        a = torch.randn([sizea], dtype=torch.float32, device=_device).abs()
        b = torch.randn([sizeb], dtype=torch.float32, device=_device).abs()
        max_fp8 = torch.finfo(fp8_dtype).max

        amax_list = [src.clone() for _ in range(repeat) for src in (a, b)]
        scale_list = [torch.empty_like(amax) for amax in amax_list]
        scale_inv_list = [torch.empty_like(amax) for amax in amax_list]

        applier(
            multi_tensor_compute_scale_and_scale_inv,
            overflow_buf,
            [amax_list, scale_list, scale_inv_list],
            max_fp8,
            pow_2_scales,
            epsilon,
        )

        for amax, scale, scale_inv in zip(amax_list, scale_list, scale_inv_list):
            expected_scale, expected_scale_inv = _scale_from_amax_tensor(
                amax, fp8_dtype, eps=epsilon, pow_2_scales=pow_2_scales,
            )
            # Zero tolerances: the outputs must match the reference exactly.
            torch.testing.assert_close(scale, expected_scale, rtol=0, atol=0)
            torch.testing.assert_close(scale_inv, expected_scale_inv, rtol=0, atol=0)


class TestMultiTensorComputeScaleInvE8M0:
    """Checks ``multi_tensor_compute_scale_inv_e8m0`` against a bit-level reference."""

    @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8)
    @pytest.mark.parametrize("input_size_pair", input_size_pairs + [(1, 1)])
    @pytest.mark.parametrize("applier", appliers)
    @pytest.mark.parametrize("repeat", [1, 55])
    def test_multi_tensor_compute_scale_inv_e8m0(
        self, input_size_pair, applier, repeat,
    ):
        """Compare uint8 outputs with exponents derived from ``amax / fp8_max`` bits.

        The reference takes the float32 exponent field of ``amax / fp8_max``
        and increments it when the mantissa is non-zero, with two exceptions
        spelled out at the rounding step below.
        """
        sizea, sizeb = input_size_pair
        a = torch.randn([sizea], dtype=torch.bfloat16, device=_device).abs()
        b = torch.randn([sizeb], dtype=torch.bfloat16, device=_device).abs()

        amax_list = [src.clone() for _ in range(repeat) for src in (a, b)]
        scale_inv_list = [torch.empty_like(amax).to(torch.uint8) for amax in amax_list]

        # No overflow buffer is passed to this op.
        applier(
            multi_tensor_compute_scale_inv_e8m0,
            None,
            [amax_list, scale_inv_list],
        )

        max_fp8 = torch.finfo(torch.float8_e4m3fn).max
        for amax, scale_inv in zip(amax_list, scale_inv_list):
            # Reinterpret the float32 quotient as raw bits.  amax is
            # non-negative (built with ``abs``), so the sign bit is clear and
            # the shift yields the 8-bit exponent field.
            bits = (amax.float() / max_fp8).view(torch.int32)
            exponent = bits >> 23
            mantissa = bits & 0x7FFFFF

            # Round the exponent up when any mantissa bit is set, except at
            # exponent 0xFE and for zero-exponent values whose mantissa is at
            # most 0x400000.
            has_fraction = mantissa > 0
            below_top_exponent = exponent != 0xFE
            small_subnormal = (exponent == 0) & (mantissa <= 0x400000)
            round_up = has_fraction & below_top_exponent & ~small_subnormal
            exponent = exponent + round_up.to(torch.int32)

            torch.testing.assert_close(exponent.to(torch.uint8), scale_inv)


class TestMultiTensorApply:
    """Checks how ``MultiTensorApply`` forwards its arguments to the wrapped op."""

    def test_chunk_size_injection(self):
        """The applier should call the op as ``op(chunk_size, buf, lists, *args)``."""
        expected_chunk = 12345
        seen = {}

        def capture_op(chunk_size, buf, lists, *args):
            # Record exactly what the applier passed through.
            seen.update(chunk_size=chunk_size, buf=buf, lists=lists, args=args)

        buf = torch.zeros(1, dtype=torch.int32)
        lists = [[torch.zeros(10), torch.zeros(20)]]
        MultiTensorApply(expected_chunk)(capture_op, buf, lists, "extra")

        assert seen["chunk_size"] == expected_chunk
        # Identity checks: the applier must hand over the very same objects.
        assert seen["buf"] is buf
        assert seen["lists"] is lists
        assert seen["args"] == ("extra",)

    def test_default_applier_exists(self):
        """The module-level applier exposes ``chunk_size`` equal to 2048 * 32."""
        default_chunk = 2048 * 32
        assert hasattr(multi_tensor_applier, "chunk_size")
        assert multi_tensor_applier.chunk_size == default_chunk


class TestImportPaths:
    """Verify all expected import paths work.

    Imports are done inside each test so that a missing export fails only
    the test that covers it.
    """

    def test_direct_import(self):
        """``multi_tensor_scale`` is importable from the ``optimizers`` subpackage."""
        from transformer_engine.pytorch.optimizers import multi_tensor_scale
        assert callable(multi_tensor_scale)

    def test_from_pytorch(self):
        """``multi_tensor_scale`` is importable from ``transformer_engine.pytorch``."""
        from transformer_engine.pytorch import multi_tensor_scale
        assert callable(multi_tensor_scale)

    def test_from_pytorch_applier(self):
        """``multi_tensor_applier`` is importable from ``transformer_engine.pytorch`` and callable."""
        from transformer_engine.pytorch import multi_tensor_applier
        assert callable(multi_tensor_applier)
        assert hasattr(multi_tensor_applier, "__call__")

    def test_from_pytorch_mtapply(self):
        """``transformer_engine.pytorch`` exposes the same ``MultiTensorApply`` class as ``optimizers``."""
        # Import both paths under distinct aliases.  Importing the bare name
        # rebinds ``MultiTensorApply`` locally, which would turn an
        # ``X is X`` comparison into a check that can never fail.
        from transformer_engine.pytorch import MultiTensorApply as exported
        from transformer_engine.pytorch.optimizers import (
            MultiTensorApply as canonical,
        )
        assert exported is canonical

    def test_import_all_exports(self):
        """Every multi-tensor symbol is importable from ``optimizers`` and callable."""
        from transformer_engine.pytorch.optimizers import (
            MultiTensorApply,
            multi_tensor_applier,
            multi_tensor_scale,
            multi_tensor_scale_tensor,
            multi_tensor_l2norm,
            multi_tensor_unscale_l2norm,
            multi_tensor_compute_scale_and_scale_inv,
            multi_tensor_compute_scale_inv_e8m0,
        )
        assert callable(multi_tensor_scale)
        assert callable(multi_tensor_scale_tensor)
        assert callable(multi_tensor_l2norm)
        assert callable(multi_tensor_unscale_l2norm)
        assert callable(multi_tensor_compute_scale_and_scale_inv)
        assert callable(multi_tensor_compute_scale_inv_e8m0)
        assert callable(MultiTensorApply)
        assert callable(multi_tensor_applier)