# Owner(s): ["module: meta tensors"]

import contextlib
import copy
import itertools
import os
import unittest
import weakref
from unittest.mock import patch
import numpy as np
import math
import random

import torch
import torch._dynamo
import torch._functorch.config
import torch._prims as prims

import torch_npu
import torch_npu.testing
from torch_npu.testing.common_utils import SupportedDevices
import torch.testing._internal.optests as optests
from torch import distributed as dist
from torch._dynamo.testing import rand_strided
from torch._subclasses.fake_tensor import (
    FakeTensor,
    FakeTensorMode,
    FakeTensorConverter,
    DynamicOutputShapeException,
    UnsupportedOperatorException,
)
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from torch.testing import FileCheck
from torch.testing._internal.common_device_type import instantiate_device_type_tests, OpDTypes
from torch.testing._internal.common_device_type import ops
from torch.testing._internal.common_utils import (
    TestCase, TEST_WITH_TORCHDYNAMO, run_tests, skipIfCrossRef, skipIfRocm, skipIfTorchDynamo, parametrize,
    instantiate_parametrized_tests)
from torch.testing._internal.custom_op_db import custom_op_db
from torch.utils._mode_utils import no_dispatch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_flatten

# True when ``torch.npu.is_available()`` reports a usable NPU. Tests that need
# real NPU execution are decorated with
# ``@unittest.skipIf(not RUN_NPU, "requires npu")`` and are skipped otherwise.
RUN_NPU = torch.npu.is_available()


def _get_test_torch_version():
    torch_npu_version = torch_npu.__version__
    version_list = torch_npu_version.split('.')
    if len(version_list) > 2:
        return f'v{version_list[0]}.{version_list[1]}'
    else:
        raise RuntimeError("Invalid torch_npu version.")


class FakeTensorTest(TestCase):
    """Behavioural tests for ``FakeTensorMode`` on CPU, meta and NPU devices."""

    def checkType(self, t, device_str, size):
        """Assert ``t`` is a FakeTensor on device type ``device_str`` with shape ``size``."""
        self.assertIsInstance(t, FakeTensor)
        self.assertEqual(t.device.type, device_str)
        self.assertEqual(list(t.size()), size)

    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_npu_initialized(self):
        """Forward and backward over fake NPU tensors run without raising."""
        with FakeTensorMode():
            p = torch.randn(4, 2, requires_grad=True, device='npu')
            x = torch.randn(8, 4, device='npu')
            y = torch.mm(x, p).square().sum()
            y.backward()

    def test_basic(self):
        """Broadcasting add of converted CPU tensors yields a CPU FakeTensor."""
        x = torch.empty(2, 2, device="cpu")
        y = torch.empty(4, 2, 2, device="cpu")
        with FakeTensorMode() as mode:
            x = mode.from_tensor(x)
            y = mode.from_tensor(y)
            z = x + y
            self.assertEqual(z.shape, (4, 2, 2))
            self.assertEqual(z.device, torch.device("cpu"))
            self.assertIsInstance(z, FakeTensor)

    # NOTE(review): _get_test_torch_version() only returns "v<major>.<minor>",
    # so the "v2.3.1" entry can never match and this test effectively runs only
    # on v2.1. Confirm whether "v2.3" was intended.
    @unittest.skipIf(_get_test_torch_version() not in ["v2.1", "v2.3.1"],
                     "Skipping test for these torch versions.")
    def test_basic_forced_memo_only(self):
        """``from_tensor(..., memoized_only=True)`` only returns already-converted tensors."""
        x = torch.empty(2, 2, device="cpu")
        y = torch.empty(4, 2, 2, device="cpu")
        with FakeTensorMode() as mode:
            # Keep a strong reference to the conversion result: the converter's
            # memo may hold converted tensors weakly.
            x_fake = mode.from_tensor(x)
            x2 = mode.from_tensor(x, memoized_only=True)
            self.assertIsNotNone(x2)
            y = mode.from_tensor(y, memoized_only=True)
            self.assertIsNone(y)

    def test_custom_op_fallback(self):
        """A CPU-only custom op raises UnsupportedOperatorException even with fallback enabled."""
        from torch.library import Library, impl

        test_lib = Library("my_test_op", "DEF")
        test_lib.define('foo(Tensor self) -> Tensor')

        # Registered for its side effect; the function is never called directly.
        @impl(test_lib, 'foo', 'CPU')
        def foo_impl(self):
            return self.cos()

        x = torch.empty(2, 2, device="cpu")
        with self.assertRaisesRegex(UnsupportedOperatorException, "my_test_op.foo.default"):
            with FakeTensorMode(allow_fallback_kernels=True) as mode:
                x = mode.from_tensor(x)
                torch.ops.my_test_op.foo(x)

    def test_parameter_instantiation(self):
        """Wrapping a fake tensor in ``nn.Parameter`` produces a Parameter."""
        with FakeTensorMode():
            x = torch.rand([4])
            y = torch.nn.parameter.Parameter(x)
            self.assertIsInstance(y, torch.nn.Parameter)

    @unittest.skipIf(not dist.is_available(), "requires distributed")
    def test_fsdp_flat_param(self):
        """FSDP's FlatParameter created under fake mode is both a Parameter and a FakeTensor."""
        # The module was renamed from ``flat_param`` to ``_flat_param`` after 2.1.
        if torch.__version__.startswith("2.1."):
            from torch.distributed.fsdp.flat_param import FlatParameter
        else:
            from torch.distributed.fsdp._flat_param import FlatParameter
        with FakeTensorMode():
            data = torch.randn(2, 2)
            param = FlatParameter(data, requires_grad=True)
        self.assertIsInstance(param, FlatParameter)
        self.assertIsInstance(param, torch.nn.Parameter)
        self.assertIsInstance(param, FakeTensor)

    def test_non_parameter_grad(self):
        """``from_tensor`` preserves ``requires_grad`` of a plain tensor."""
        mode = FakeTensorMode()
        t = torch.rand([4], requires_grad=True)
        fake_t = mode.from_tensor(t)
        self.assertEqual(fake_t.requires_grad, t.requires_grad)

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_index_npu_with_cpu(self):
        """Indexing an NPU tensor with a CPU index tensor yields an NPU result."""
        with FakeTensorMode():
            x = torch.rand([2048], device='npu')
            out = x[torch.zeros([36], dtype=torch.int64)]
            self.checkType(out, "npu", [36])

    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_shape_take_not_device(self):
        """``resize_as_`` takes the other tensor's shape but keeps its own device."""
        with FakeTensorMode():
            x = torch.empty(1, device="cpu")
            y = torch.empty(8, 8, device="npu")
            out = x.resize_as_(y)
            self.assertEqual(out.shape, (8, 8))
            self.assertEqual(out.device.type, "cpu")
            self.assertIsInstance(out, FakeTensor)

    def test_repr(self):
        """FakeTensor repr omits data and shows the device only when not CPU."""
        with FakeTensorMode():
            x = torch.empty(2, 2, device="cpu")
            self.assertEqual(repr(x), 'FakeTensor(..., size=(2, 2))')
            x = torch.empty(2, 2, device="meta")
            self.assertEqual(repr(x), "FakeTensor(..., device='meta', size=(2, 2))")

    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_zero_dim(self):
        """A 0-dim CPU tensor combined with an NPU tensor yields an NPU result."""
        with FakeTensorMode():
            x = torch.tensor(0.)
            y = torch.rand([4, 4], device="npu")
            out = x + y
            self.assertEqual(out.shape, (4, 4))
            self.assertEqual(out.device, y.device)
            self.assertIsInstance(out, FakeTensor)

    def test_nan_to_num(self):
        """``nan_to_num`` preserves the input dtype for half and float."""
        with FakeTensorMode():
            for dtype in [torch.float16, torch.float32]:
                x = torch.rand([4], dtype=dtype)
                y = torch.nan_to_num(x, nan=None)
                z = torch.nan_to_num(x, 0.0)
                self.assertEqual(dtype, y.dtype)
                self.assertEqual(dtype, z.dtype)

    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_throw(self):
        """``lerp`` over mixed NPU/CPU non-scalar operands raises."""
        x = torch.tensor(0.)
        with FakeTensorMode() as mode:
            x_conv = mode.from_tensor(x)
            y = torch.rand([4, 4], device="npu")
            z = torch.rand([4, 4], device="cpu")
            self.assertRaises(Exception, lambda: torch.lerp(x_conv, y, z))

    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_type_as(self):
        """``type_as`` moves a CPU tensor to the other tensor's NPU device."""
        with FakeTensorMode():
            x = torch.rand([16, 1], device="cpu")
            y = torch.rand([4, 4], device="npu")
            out = x.type_as(y)
            self.assertEqual(out.device.type, "npu")
            self.assertIsInstance(out, FakeTensor)

    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_setitem(self):
        """Scalar setitem works on fake tensors on both CPU and NPU."""
        for device in ["cpu", "npu"]:
            with FakeTensorMode():
                x = torch.rand([16, 1], device=device)
                x[..., 0] = 0

    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_device_inplace_copy(self):
        """Cross-device ``copy_`` keeps the destination's device."""
        with FakeTensorMode():
            x = torch.rand([8, 8], device="cpu")
            y = torch.rand([8, 8], device="npu")
            self.assertEqual(x.copy_(y).device.type, "cpu")
            self.assertEqual(y.copy_(x).device.type, "npu")

    def test_fake_dispatch_keys(self):
        """Fake CPU tensors carry CPU dispatch keys; autograd keys vanish in inference mode."""
        with FakeTensorMode():
            x = torch.rand([4])
            f = FileCheck().check("CPU").check("ADInplaceOrView").check("AutogradCPU").check("AutocastCPU")
            f.run(torch._C._dispatch_key_set(x))

            with torch.inference_mode():
                x = torch.rand([4])
                y = x + x
                FileCheck().check("CPU").check("AutocastCPU").run(torch._C._dispatch_key_set(y))
                FileCheck().check_not("ADInplaceOrView").check_not("Autograd").run(torch._C._dispatch_key_set(y))

    def test_constructor(self):
        """Factory functions under fake mode return CPU FakeTensors."""
        with FakeTensorMode():
            x = torch.rand([4, 4], device="cpu")

        self.assertIsInstance(x, FakeTensor)
        self.assertEqual(x.device.type, "cpu")

    def test_mode(self):
        """Ops on tensors created under fake mode return FakeTensors."""
        with FakeTensorMode():
            y = torch.rand([4], device="cpu")
            out = y + y

        self.assertIsInstance(out, FakeTensor)

    def test_full(self):
        """``torch.full`` passes CrossRefFakeMode's real-vs-fake metadata check."""
        with torch._subclasses.CrossRefFakeMode():
            torch.full((4, 4), 1)

    def check_function_with_fake(self, fn):
        """Run ``fn`` for real and under fake mode and compare output tensor metadata.

        Non-tensor leaves of the real output must correspond to non-tensor
        leaves of the fake output; tensor leaves are compared including strides.
        """
        out = fn()
        with torch._subclasses.FakeTensorMode():
            out_fake = fn()

        # tree_flatten returns (leaves, spec); compare the leaves, not the pair.
        real_leaves, _ = tree_flatten(out)
        fake_leaves, _ = tree_flatten(out_fake)
        self.assertEqual(len(real_leaves), len(fake_leaves))
        for real, fake in zip(real_leaves, fake_leaves):
            # Real outputs are plain tensors, never FakeTensors, so test for Tensor.
            if not isinstance(real, torch.Tensor):
                self.assertNotIsInstance(fake, torch.Tensor)
                continue

            prims.utils.compare_tensor_meta(real, fake, check_strides=True)

    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_non_kwarg_device(self):
        """``.to(device)`` with a positional device is a no-op for the same device and moves otherwise."""
        with FakeTensorMode():
            x = torch.rand([16, 1], device="cpu")
            y = x.to(torch.device("cpu"))
            self.assertIs(x, y)
            z = x.to(torch.device("npu"))
            self.assertEqual(z.device.type, "npu")

    def test_non_overlapping_stride_zero(self):
        """``half()`` on a tensor with a zero stride matches real metadata under fake mode."""
        def foo():
            x = torch.empty_strided([1, 3, 427, 640], (0, 1, 1920, 3))
            return x.half()

        self.check_function_with_fake(foo)

    def test_fake_mode_error(self):
        """Using a real tensor inside fake mode raises an error asking for conversion."""
        x = torch.rand([4, 4])

        with self.assertRaisesRegex(Exception, "Please convert all Tensors"):
            with FakeTensorMode():
                x[0]

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
    def test_fake_grad_copy(self):
        """Converting a tensor also converts its ``.grad`` into a FakeTensor."""
        x = torch.rand([4, 4], requires_grad=True)
        x.grad = torch.rand([4, 4])
        mode = FakeTensorMode()
        fake_x = mode.from_tensor(x)
        prims.utils.compare_tensor_meta(fake_x, x)
        prims.utils.compare_tensor_meta(fake_x.grad, x.grad)

        self.assertIsInstance(fake_x.grad, FakeTensor)

    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_like_constructor(self):
        """``*_like`` constructors honor an explicit ``device`` override."""
        with FakeTensorMode():
            x = torch.rand([4, 4])
            y = torch.ones_like(x)
            self.assertIsInstance(y, FakeTensor)
            self.assertEqual(y.device.type, "cpu")
            z = torch.ones_like(x, device="npu")
            self.assertIsInstance(z, FakeTensor)
            self.assertEqual(z.device.type, "npu")

    def test_binary_op_type_promotion(self):
        """float / int64 true division promotes to float on CPU."""
        with FakeTensorMode():
            x = torch.empty([2, 2], dtype=torch.float)
            y = torch.empty([2, 2], dtype=torch.int64)
            # Tensor division never raises ZeroDivisionError, so no guard is needed.
            out = x / y
            self.assertEqual(out.dtype, torch.float)
            self.assertEqual(out.device.type, "cpu")

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
    def test_from_numpy(self):
        """``torch.tensor(ndarray)`` under fake mode yields a CPU FakeTensor."""
        with FakeTensorMode():
            x = torch.tensor(np.zeros([4, 4]))
            self.checkType(x, "cpu", [4, 4])

    def test_randperm(self):
        """Fake ``randperm`` matches real metadata."""
        x = torch.randperm(10)
        y = torch.randperm(5, device="cpu")
        with FakeTensorMode():
            x1 = torch.randperm(10)
            prims.utils.compare_tensor_meta(x, x1)
            y1 = torch.randperm(5, device="cpu")
            prims.utils.compare_tensor_meta(y, y1)

    def test_print_in_fake_mode(self):
        """``str`` of a real tensor inside fake mode does not fail or mention FakeTensor."""
        x = torch.zeros(2)
        with FakeTensorMode():
            out = str(x)
        self.assertNotIn("FakeTensor", out)

    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_upsample_bilinear_small_channels(self):
        """Fake ``upsample_bilinear2d`` output metadata matches the real NPU output."""
        out = []
        mode = FakeTensorMode()
        # First iteration runs for real, second under fake mode.
        for context in (contextlib.nullcontext, lambda: mode):
            with context():
                arg0_1 = torch.empty_strided((3, 427, 640), (1, 1920, 3), dtype=torch.float32, device='npu')
                unsqueeze = torch.ops.aten.unsqueeze.default(arg0_1, 0)
                out.append(torch.ops.aten.upsample_bilinear2d.default(unsqueeze, [800, 1199], False))

        self.assertTrue(out[1].is_contiguous())
        self.checkMetaProps(out[0], out[1])

    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_cpu_fallback(self):
        """conv2d works with and without fallback kernels; bad shapes still raise."""
        with FakeTensorMode(allow_fallback_kernels=False):
            filters = torch.randn(8, 4, 3, 3).npu()
            inputs = torch.randn(1, 4, 5, 5).npu()
            out = torch.nn.functional.conv2d(inputs, filters, padding=1)
            self.assertEqual(out.device.type, "npu")
            self.assertEqual(list(out.size()), [1, 8, 5, 5])

        with FakeTensorMode(allow_fallback_kernels=True):
            # intentionally bad inputs
            filters = torch.randn(8, 20, 3, 3).npu()
            inputs = torch.randn(1, 7, 10, 5).npu()
            with self.assertRaises(RuntimeError):
                torch.nn.functional.conv2d(inputs, filters, padding=1)

        with FakeTensorMode(allow_fallback_kernels=True):
            filters = torch.randn(8, 4, 3, 3).npu()
            inputs = torch.randn(1, 4, 5, 5).npu()

            out = torch.nn.functional.conv2d(inputs, filters, padding=1)
            self.assertEqual(out.device.type, "npu")
            self.assertEqual(list(out.size()), [1, 8, 5, 5])

    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_out_multi_device(self):
        """``out=`` and in-place ops across CPU/NPU raise a two-devices error."""
        with FakeTensorMode():
            x = torch.rand([4])
            y = torch.rand([4], device="npu")

            with self.assertRaisesRegex(Exception, "found two different devices"):
                torch.sin(x, out=y)

            with self.assertRaisesRegex(Exception, "found two different devices"):
                x.add_(y)

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_normalize_device(self):
        """"npu" and "npu:<current>" are treated as the same device."""
        with FakeTensorMode():
            x = torch.empty(1, device="npu")
            y = torch.empty(1, device=f"npu:{torch.npu.current_device()}")
            out = x + y
        self.checkType(out, "npu", [1])

    def test_recursive_invocation(self):
        """An op run while ``in_kernel_invocation`` is set leaves the flag set."""
        mode = FakeTensorMode()
        with mode:
            x = torch.tensor(2)
            mode.in_kernel_invocation = True
            x + x
            self.assertTrue(mode.in_kernel_invocation)

    # Currently skip by configuring JSON. The NPU needs to be adapted.
    @unittest.skip("skip test_lstm now")
    def test_lstm(self):
        """Fake LSTM forward/backward produces correctly shaped outputs on NPU."""
        with FakeTensorMode(allow_fallback_kernels=False):
            N = 5
            L = 4
            H_in = 2
            hidden_size = 3
            proj_size = 2
            num_layers = 2
            bidir = False
            D = 2 if bidir else 1
            H_out = proj_size if proj_size > 0 else hidden_size

            lstm = torch.nn.LSTM(input_size=H_in, hidden_size=hidden_size,
                                 num_layers=num_layers, proj_size=proj_size, batch_first=False,
                                 bias=True, bidirectional=bidir, device='npu')

            h_0 = torch.randn((num_layers * D, N, H_out), requires_grad=False, device='npu')
            c_0 = torch.randn((num_layers * D, N, hidden_size), requires_grad=False, device='npu')
            inp = torch.randn((L, N, H_in), requires_grad=False, device='npu')
            (output, (h_n, c_n)) = lstm(inp, (h_0, c_0))
            output.sum().backward()

            self.assertEqual(output.shape, (L, N, D * H_out))
            self.assertEqual(h_n.shape, (D * num_layers, N, H_out))
            self.assertEqual(c_n.shape, (D * num_layers, N, hidden_size))

    def test_data_dependent_operator(self):
        """``nonzero`` raises DynamicOutputShapeException without fallback kernels."""
        with FakeTensorMode(allow_fallback_kernels=False):
            x = torch.rand([10, 10])

            self.assertRaises(DynamicOutputShapeException, lambda: torch.nonzero(x))

    def checkMetaProps(self, t1, t2):
        """Assert ``t1`` and ``t2`` have matching tensor metadata, including strides."""
        prims.utils.compare_tensor_meta(t1, t2, check_strides=True)

    @skipIfCrossRef
    def test_deepcopy(self):
        """FakeCopyMode deep-copies modules into FakeTensors while preserving aliasing."""
        with FakeTensorMode() as mode:
            pass
        mod = torch.nn.BatchNorm2d(10)
        with torch._subclasses.fake_tensor.FakeCopyMode(mode):
            mod_copied = copy.deepcopy(mod)

        def check_copy(mod, mod_copied):
            for name, param in itertools.chain(mod.named_parameters(), mod.named_buffers()):
                param_copied = getattr(mod_copied, name)
                self.checkMetaProps(param, param_copied)
                self.assertIsInstance(param_copied, FakeTensor)
                self.assertEqual(isinstance(param, torch.nn.Parameter), isinstance(param_copied, torch.nn.Parameter))
                self.assertEqual(param.requires_grad, param_copied.requires_grad)

        check_copy(mod, mod_copied)

        class ModuleNew(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = torch.rand([10, 2])
                self.b = self.a
                self.c = self.a[0]

        mod = ModuleNew()
        with torch._subclasses.fake_tensor.FakeCopyMode(mode):
            mod_copied = copy.deepcopy(mod)

        self.assertIs(mod_copied.a, mod_copied.b)
        self.assertEqual(mod_copied.b.storage()._cdata, mod_copied.a.storage()._cdata)

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_new(self):
        """``Tensor.new`` overloads return FakeTensors on the expected device."""
        with FakeTensorMode():
            a = torch.rand([16, 1])
            self.checkType(a.new(10, 10), "cpu", [10, 10])
            self.checkType(a.new([1, 2, 3, 4]), "cpu", [4])
            b = torch.rand([4, 4], device='npu')
            self.checkType(b.new(device='npu'), "npu", [0])
            self.checkType(a.new(torch.rand([1])), "cpu", [1])

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
    def test_scalar_inputs(self):
        """Ops on Python scalars produce CPU FakeTensors with promoted dtypes."""
        with FakeTensorMode():
            self.checkType(torch.div(3, 2), "cpu", [])
            ten = torch.zeros(2, dtype=torch.int32) * 2.0
            self.assertEqual(ten.dtype, torch.float)
            self.checkType(ten, "cpu", [2])

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
    def test_allow_meta(self):
        """Meta tensors are allowed unless ``fake_tensor_allow_meta`` is disabled."""
        def run_meta():
            with FakeTensorMode():
                x = torch.rand([4], device="meta")
                return x + x

        self.checkType(run_meta(), "meta", [4])

        with patch.object(torch._functorch.config, "fake_tensor_allow_meta", False):
            self.assertRaises(Exception, run_meta)

    def test_embedding_bag_meta(self):
        """A meta-device EmbeddingBag with CPU indices behaves the same real and fake."""
        def f():
            # This behavior was originally unintentional but we see people
            # relying on it
            embedding = torch.nn.EmbeddingBag(10, 3, mode='sum', device='meta')
            ipt = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long)
            offsets = torch.tensor([0, 4], dtype=torch.long)
            return embedding(ipt, offsets)

        real_out = f()
        with FakeTensorMode():
            fake_out = f()

        for real, fake in zip(real_out, fake_out):
            self.assertEqual(real.size(), fake.size())
            self.assertEqual(real.device, fake.device)

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
    def test_mixed_real_and_fake_inputs(self):
        """A module with real parameters accepts fake inputs when non-fake inputs are allowed."""
        class _TestPattern(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 1, 1)
                self.bn = torch.nn.BatchNorm2d(1)

            def forward(self, ipt):
                # Tensor division never raises ZeroDivisionError, so no guards are needed.
                running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
                scale_factor = self.bn.weight / running_std
                weight_shape = [1] * len(self.conv.weight.shape)
                weight_shape[0] = -1
                bias_shape = [1] * len(self.conv.weight.shape)
                bias_shape[1] = -1
                scaled_weight = self.conv.weight * scale_factor.reshape(weight_shape)
                zero_bias = torch.zeros_like(self.conv.bias, dtype=ipt.dtype)
                conv = self.conv._conv_forward(ipt, scaled_weight, zero_bias)
                conv_orig = conv / scale_factor.reshape(bias_shape)
                conv_orig = conv_orig + self.conv.bias.reshape(bias_shape)
                return self.bn(conv_orig)

        mod = _TestPattern()
        with FakeTensorMode(allow_non_fake_inputs=True):
            out = mod(torch.randn(1, 1, 3, 3))
        self.checkType(out, "cpu", (1, 1, 3, 3))

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_batch_norm_gather_stats_update_fake_tensor(self):
        """``torch_npu.batch_norm_gather_stats_update`` returns per-channel NPU FakeTensors."""
        with FakeTensorMode():
            input_t = torch.rand(2, 3, 4, 4, device="npu")
            mean = torch.rand(4, 3, device="npu")
            invstd = torch.rand(4, 3, device="npu")
            running_mean = torch.rand(3, device="npu")
            running_var = torch.rand(3, device="npu")
            counts = torch.tensor([1, 1, 1, 1], dtype=torch.float32, device="npu")
            batch_mean, batch_invstd = torch_npu.batch_norm_gather_stats_update(
                input_t, mean, invstd, running_mean, running_var, 0.1, 1e-5, counts
            )
        self.checkType(batch_mean, "npu", [3])
        self.checkType(batch_invstd, "npu", [3])

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_aten_copy_multi_device(self):
        """``aten.copy`` keeps the first argument's device, including the out= variant."""
        with FakeTensorMode():
            x1 = torch.rand(4, device="cpu")
            x2 = torch.rand(4, device="npu")
            copy1 = torch.ops.aten.copy.default(x1, x2)
            copy2 = torch.ops.aten.copy.default(x2, x1)
            out = torch.empty(4, device="cpu")
            torch.ops.aten.copy.out(x1, x2, out=out)
        self.checkType(copy1, "cpu", (4,))
        self.checkType(copy2, "npu", (4,))
        self.checkType(out, "cpu", (4,))

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_aten_index_multi_device(self):
        """``aten.index``/``index_put`` with cross-device indices keep the self tensor's device."""
        with FakeTensorMode():
            x1 = torch.rand(4, 4, device="cpu")
            x2 = torch.rand(4, 4, device="npu")
            i1 = torch.tensor([0, 1], device="npu")
            i2 = torch.tensor([0, 1], device="cpu")
            r1 = torch.ops.aten.index(x1, i1)
            r2 = torch.ops.aten.index(x2, i2)

            y1 = torch.rand(4, device="cpu")
            y2 = torch.rand(4, device="npu")
            j1 = torch.tensor([2], device="npu")
            j2 = torch.tensor([2], device="cpu")
            r3 = torch.ops.aten.index_put.default(x1, j1, y1)
            r4 = torch.ops.aten.index_put.default(x2, j2, y2)
        self.checkType(r1, "cpu", ())
        self.checkType(r2, "npu", ())
        self.checkType(r3, "cpu", (4, 4))
        self.checkType(r4, "npu", (4, 4))

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_aten_slice_scatter_multi_device(self):
        """``aten.slice_scatter`` keeps the base tensor's device, including the out= variant."""
        with FakeTensorMode():
            x1 = torch.rand(4, 4, device="cpu")
            y1 = torch.rand(2, 4, device="npu")
            x2 = torch.rand(4, 4, device="npu")
            y2 = torch.rand(2, 4, device="cpu")
            out = torch.empty(4, 4, device="cpu")
            r1 = torch.ops.aten.slice_scatter.default(x1, y1, start=2)
            r2 = torch.ops.aten.slice_scatter.default(x2, y2, start=2)
            r3 = torch.ops.aten.slice_scatter.out(x1, y1, out=out, start=2)
        self.checkType(r1, "cpu", (4, 4))
        self.checkType(r2, "npu", (4, 4))
        self.checkType(r3, "cpu", (4, 4))
        self.checkType(out, "cpu", (4, 4))

    def test__adaptive_avg_pool2d_backward(self):
        """The adaptive-avg-pool backward grad follows the channels_last input layout."""
        with FakeTensorMode():
            grad_out = torch.rand(2, 3, 4, 4)
            inp = torch.rand(2, 3, 4, 4).to(memory_format=torch.channels_last)
            grad_in = torch.ops.aten._adaptive_avg_pool2d_backward(grad_out, inp)
            self.assertEqual(torch._prims_common.suggest_memory_format(grad_in), torch.channels_last)


class FakeTensorConstHandling(TestCase):
    def assertConst(self, *args):
        for arg in args:
            self.assertTrue(arg.constant is not None)

    def assertNotConst(self, *args):
        for arg in args:
            self.assertTrue(arg.constant is None)

    def test_simple(self):
        with FakeTensorMode():
            x = torch.tensor(4.)
            self.assertEqual(x.item(), 4.)

    def test_inplace_add(self):
        with FakeTensorMode():
            x = torch.tensor(4.)
            y = x.add_(1)
            self.assertEqual(x.item(), 5.)
            self.assertEqual(y.item(), 5.)
            self.assertConst(x, y)

    def test_shared_storages(self):
        with FakeTensorMode():
            x = torch.tensor([4.])
            y = x[:]

            self.assertEqual(x.storage()._cdata, y.storage()._cdata)
            self.assertEqual(x.constant.storage()._cdata, y.constant.storage()._cdata)

    def test_constant_invalidation(self):
        with FakeTensorMode():
            x = torch.tensor([1.])
            self.assertConst(x)
            y = torch.rand([1])
            x.add_(y)
            self.assertNotConst(x)

    def test_inplace_view_invalidation(self):
        with FakeTensorMode():
            x = torch.tensor([1])
            self.assertConst(x)
            x.resize_([2])
            self.assertEqual(x.size(0), 2)
            self.assertNotConst(x)

    def test_fake_tensor_in_intlist_repro(self):

        def fn(tensors):
            max_size = torch.tensor([800, 1216], dtype=torch.int64)
            batch_shape = [len(tensors)] + list(tensors[0].shape[:-2]) + list(max_size)
            return tensors[0].new_full(batch_shape, 0.0)

        with self.assertRaises(torch._subclasses.fake_tensor.DataDependentOutputException):
            with torch._subclasses.fake_tensor.FakeTensorMode():
                a = torch.randn(3, 800, 1199)
                b = torch.randn(3, 800, 800)
                inputs = [a, b]
                ref = fn(inputs)

    def test_fake_tensor_batch_norm_cpu(self):
        with torch._subclasses.CrossRefFakeMode():
            m = torch.nn.Sequential(
                torch.nn.BatchNorm2d(10),
                torch.nn.ReLU(),
            )
            m.eval()
            out = m(torch.randn([2, 10, 8, 8]))

    def test_shared_storage_invalidation(self):
        with FakeTensorMode():
            x = torch.tensor([1.])
            y = x[:]
            self.assertConst(x, y)
            y.add_(torch.rand([1]))
            self.assertNotConst(x, y)

    def test_aliased_const_write(self):
        with FakeTensorMode():
            x = torch.tensor([1])
            y = x.expand([4])
            self.assertNotConst(y)
            y[0] = 1
            self.assertNotConst(x)

    def test_constant_propagate_through_functions(self):
        with FakeTensorMode():
            y = torch.div(4, 4, rounding_mode='trunc')
            self.assertConst(y)


def contains_type(type_tmp: torch._C.Type, maybe_contained_type: torch._C.Type):
    """Return True if ``maybe_contained_type`` appears anywhere within ``type_tmp``.

    A match is either ``maybe_contained_type`` being a subtype of ``type_tmp``
    itself, or (recursively) a subtype of any of ``type_tmp``'s contained types.
    """
    if maybe_contained_type.isSubtypeOf(type_tmp):
        return True
    inner_types = type_tmp.containedTypes()
    return any(contains_type(inner, maybe_contained_type) for inner in inner_types)


class FakeTensorOpInfoTest(TestCase):
    """Run ``optests.fake_check`` over the sample inputs of every op in ``custom_op_db``."""

    @ops(custom_op_db, dtypes=OpDTypes.any_one)
    def test_fake(self, device, dtype, op):
        """Check each custom op under fake tensors for every sample input."""
        # On torch 2.1 fake_check is called with a fourth argument, True for
        # ops listed in data_dependent_outputs; other versions take three.
        is_torch_2_1 = torch.__version__.startswith("2.1.")
        data_dependent_outputs = {
            'NumpyNMSCustomOp',
            'NumpyNonzeroCustomOp',
        }
        for sample_input in op.sample_inputs(device, dtype, requires_grad=False):
            args = (sample_input.input,) + sample_input.args
            kwargs = sample_input.kwargs
            if is_torch_2_1:
                optests.fake_check(op, args, kwargs, op.name in data_dependent_outputs)
            else:
                optests.fake_check(op, args, kwargs)


class FakeTensorConverterTest(TestCase):
    """Tests for FakeTensorConverter memoization and storage tracking."""

    @staticmethod
    def _uses_legacy_converter_api():
        """Return True for torch_npu v2.1 / v2.3.

        On those versions real tensors are converted with ``converter(mode, t)``
        and the tests call ``check_for_expired_weak_storages`` before inspecting
        ``storage_memo``; newer versions use ``converter.from_real_tensor``.
        """
        return _get_test_torch_version() in ["v2.1", "v2.3"]

    @classmethod
    def _convert(cls, converter, mode, t):
        """Convert real tensor ``t`` with whichever converter API this version exposes."""
        if cls._uses_legacy_converter_api():
            return converter(mode, t)
        return converter.from_real_tensor(mode, t)

    def test_memoized_conversion_to_meta(self):
        x = torch.rand(2, 2, 2)
        mode = FakeTensorMode()
        self.assertTrue(mode.from_tensor(x) is mode.from_tensor(x))

    def test_memoized_conversion_from_meta(self):
        x = torch.rand(2, 2).to(device="meta")
        mode = FakeTensorMode()
        converter = mode.fake_tensor_converter
        self.assertTrue(
            converter.from_meta_and_device(mode, x, "cpu")
            is converter.from_meta_and_device(mode, x, "cpu")
        )

    def test_separate_tensor_storages_view(self):
        x = torch.rand(2, 2, 2)
        y = x[0]
        mode = FakeTensorMode()
        converter = mode.fake_tensor_converter
        x_conv = self._convert(converter, mode, x)
        y_conv = self._convert(converter, mode, y)
        self.assertEqual(torch._C._storage_id(x_conv), torch._C._storage_id(y_conv))

    @skipIfTorchDynamo("see pytorch torchdynamo issue 1991")
    def test_separate_tensor_storages_non_view(self):
        legacy = self._uses_legacy_converter_api()
        x = torch.rand(2, 2, 2)
        y = torch.rand(4, 2)
        y.set_(x.storage())
        mode = FakeTensorMode()
        converter = mode.fake_tensor_converter
        x_conv = self._convert(converter, mode, x)
        y_conv = self._convert(converter, mode, y)
        stor_id = torch._C._storage_id(x_conv)
        self.assertEqual(stor_id, torch._C._storage_id(y_conv))
        del x
        if not legacy:
            # On newer versions the memo entry also goes away once the fake
            # tensor itself is released.
            del x_conv
        self.assertEqual(len(converter.tensor_memo), 1)
        if legacy:
            converter.meta_converter.check_for_expired_weak_storages()
        self.assertEqual(len(converter.meta_converter.storage_memo), 1)
        del y
        if not legacy:
            del y_conv
        self.assertEqual(len(converter.tensor_memo), 0)
        if legacy:
            converter.meta_converter.check_for_expired_weak_storages()
        self.assertEqual(len(converter.meta_converter.storage_memo), 0)

    @skipIfTorchDynamo("see pytorch torchdynamo issue 1991")
    def test_dead_weak_ref(self):
        legacy = self._uses_legacy_converter_api()
        x = torch.rand(2, 2, 2)
        y = x[0]
        mode = FakeTensorMode()
        converter = FakeTensorConverter()
        x_conv = self._convert(converter, mode, x)
        # Legacy versions compare storage ids; newer ones compare storages.
        if legacy:
            x_conv_storage = torch._C._storage_id(x_conv)
        else:
            x_conv_storage = x_conv.untyped_storage()
        del x_conv
        self.assertFalse(x in converter.tensor_memo)
        y_conv = self._convert(converter, mode, y)
        if legacy:
            self.assertEqual(x_conv_storage, torch._C._storage_id(y_conv))
        else:
            self.assertEqual(x_conv_storage, y_conv.untyped_storage())

    @skipIfTorchDynamo("see pytorch torchdynamo issue 1991")
    def test_dead_key(self):
        x = torch.rand(2, 2, 2)
        mode = FakeTensorMode()
        converter = FakeTensorConverter()
        x_conv = self._convert(converter, mode, x)
        self.assertEqual(len(converter.tensor_memo), 1)
        x_conv2 = self._convert(converter, mode, x)
        self.assertIs(x_conv2, x_conv)
        del x
        if not self._uses_legacy_converter_api():
            del x_conv
            del x_conv2
        self.assertEqual(len(converter.tensor_memo), 0)

    def test_no_active_mode(self):
        with FakeTensorMode() as mode:
            x = torch.empty(2, 2, device="cpu")
            y = torch.empty(2, 2, device="cpu")

        out = x + y
        self.assertEqual(mode, out.fake_mode)
        self.assertTrue(isinstance(out, FakeTensor))
        self.assertEqual(out.device.type, "cpu")

    def test_multiple_modes(self):
        t = torch.rand([4])
        t2 = torch.rand([4])
        with FakeTensorMode() as m:
            with FakeTensorMode() as m2:
                t_fake = m.from_tensor(t)
                t2_fake = m2.from_tensor(t2)

                with self.assertRaisesRegex(Exception, "Mixing fake modes"):
                    t_fake + t2_fake

    def test_separate_mode_error(self):
        with FakeTensorMode():
            x = torch.empty(2, 2, device="cpu")
        with FakeTensorMode():
            y = torch.empty(2, 2, device="cpu")
        # Combining fake tensors owned by two different modes must raise. The
        # operation has to run inside the assertion; passing ``lambda: x, y``
        # would call the lambda with ``y`` and only "pass" on a TypeError.
        with self.assertRaises(Exception):
            x + y

    @skipIfTorchDynamo("see pytorch torchdynamo issue 1991")
    def test_no_ref_cycle(self):
        x = torch.rand([4])
        mode = FakeTensorMode()
        y = mode.from_tensor(x)
        self.assertEqual(len(mode.fake_tensor_converter.tensor_memo), 1)
        mode_weak = weakref.ref(mode)
        # Track the fake tensor itself, not the mode a second time.
        y_weak = weakref.ref(y)
        del mode
        del y
        self.assertIsNone(mode_weak())
        self.assertIsNone(y_weak())


class FakeTensorOperatorInvariants(TestCase):
    """Invariants fake-tensor dispatch relies on for aten op schemas and behaviors."""

    @staticmethod
    def get_aten_op(schema):
        """Return the ``torch.ops.aten`` OpOverload that ``schema`` describes.

        A schema with an empty overload name maps to the ``default`` overload.

        Raises:
            ValueError: if ``schema`` is not in the ``aten`` namespace, since
                only aten ops can be resolved through ``torch.ops.aten``.
        """
        namespace, name = schema.name.split("::")
        if namespace != "aten":
            raise ValueError(f"expected an aten schema, got {schema.name!r}")
        overload = schema.overload_name if schema.overload_name else "default"
        return getattr(getattr(torch.ops.aten, name), overload)

    @staticmethod
    def get_all_aten_schemas():
        """Yield every registered JIT schema whose namespace is ``aten``."""
        for schema in torch._C._jit_get_all_schemas():
            if schema.name.split("::")[0] == "aten":
                yield schema

    def test_non_kwarg_only_device(self):
        # Loop-invariant JIT type objects.
        ten_type = torch._C.TensorType.get()
        opt_device = torch._C.OptionalType(torch._C.DeviceObjType.get())
        for schema in self.get_all_aten_schemas():
            if not any(
                contains_type(arg.type, ten_type)
                for arg in itertools.chain(schema.arguments, schema.returns)
            ):
                continue

            has_non_kwarg_device = any(
                not arg.kwarg_only and arg.type.isSubtypeOf(opt_device)
                for arg in schema.arguments
            )
            if has_non_kwarg_device:
                self.assertTrue(
                    self.get_aten_op(schema) in torch._subclasses.fake_tensor._device_not_kwarg_ops
                )

    def test_tensor_constructors_all_have_kwarg_device(self):
        opt_device = torch._C.OptionalType(torch._C.DeviceObjType.get())
        for schema in self.get_all_aten_schemas():
            op = self.get_aten_op(schema)
            if not torch._subclasses.fake_tensor._is_tensor_constructor(op):
                continue

            has_kwarg_device = any(
                arg.kwarg_only and arg.type.isSubtypeOf(opt_device)
                for arg in schema.arguments
            )

            self.assertTrue(
                has_kwarg_device or op == torch.ops.aten._list_to_tensor.default
            )

    @unittest.expectedFailure
    def test_sparse_new(self):
        with FakeTensorMode():
            indices = torch.randn(1, 1, dtype=torch.int64)
            values = torch.randn(1)
            extra = (2,)
            sparse = torch.randn(1).to_sparse()
            # This used to segfault, now it does not, but it still raises an
            # error
            sparse2 = sparse.new(indices, values, extra)

    def test_tensor_new(self):
        with FakeTensorMode():
            x = torch.Tensor([1, 2, 3])
        self.assertIsInstance(x, FakeTensor)

    def test_like_ops(self):
        for schema in self.get_all_aten_schemas():
            if schema.name.endswith("_like"):
                op = self.get_aten_op(schema)
                self.assertIn(op, torch._subclasses.fake_tensor._like_tensor_constructors)

    # at::_embedding_bag has no op info,
    # and returns extra tensors that at::embedding bag throws away
    def test_embedding_bag_private(self):
        args = [
            torch.ones(6, 1),
            torch.ones(6, dtype=torch.int64),
            torch.arange(2, dtype=torch.int64),
            False,
            2,  # mode = max
        ]

        ref_out = torch.ops.aten._embedding_bag(*args)
        with FakeTensorMode() as m:
            meta_args = [m.from_tensor(a) if isinstance(a, torch.Tensor) else a for a in args]
            meta_out = torch.ops.aten._embedding_bag(*meta_args)

        self.assertEqual(len(ref_out), len(meta_out))
        for ref_o, meta_o in zip(ref_out, meta_out):
            self.assertEqual(ref_o.size(), meta_o.size())

    def test_cross_entropy_loss(self):
        inp = torch.randn(3, 5)
        target = torch.randint(5, (3,), dtype=torch.long)
        weight = torch.rand(5)
        fn = torch.nn.functional.cross_entropy
        for w in (weight, None):
            args = (inp, target, w)
            ref = fn(*args)
            with FakeTensorMode() as m:
                meta_args = [m.from_tensor(a) if isinstance(a, torch.Tensor) else a for a in args]
                meta_out = torch.nn.functional.cross_entropy(*meta_args, label_smoothing=0.5)

            self.assertEqual(ref.size(), meta_out.size())

    # Requires SDPA support or pre-SM80 hardware; currently skipped via the JSON
    # config. The NPU path still needs to be adapted.
    @unittest.skip("skip test_flash_attention now")
    def test_flash_attention(self):
        class Repro(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, arg1, arg2, arg3):
                torch.ops.aten._scaled_dot_product_flash_attention(arg1, arg2, arg3)

        # Each entry is (size, stride, dtype, device), matching rand_strided's
        # positional parameters.
        args_new = [
            ((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, "npu"),
            ((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, "npu"),
            ((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, "npu"),
        ]

        args = [rand_strided(size, stride, dtype, device) for
                (size, stride, dtype, device) in args_new]
        try:
            with torch._subclasses.CrossRefFakeMode():
                Repro()(*args)
        except RuntimeError as e:
            # We expect the cross ref to succeed for the first output and to fail
            # for the rng state, see Note [Seed and Offset]
            self.assertTrue("output[0]" not in str(e))
            self.assertTrue("found mismatched tensor metadata for output[6]: Devices cpu and cuda:0 are not equal!" in str(e))

    @unittest.skip("skip ci err")
    @skipIfRocm
    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_conv_c1_backward(self):
        class Repro(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, arg1, arg2, arg3):
                torch.ops.aten.convolution_backward.default(
                    arg1,
                    arg2,
                    arg3,
                    [1],
                    [1, 1],
                    [1, 1],
                    [1, 1],
                    False,
                    [0, 0],
                    1,
                    [True, True, False],
                )

        args_new = [
            ((16, 1, 128, 128), (16384, 16384, 128, 1), torch.float16, "npu"),
            ((16, 64, 128, 128), (1048576, 1, 8192, 64), torch.float16, "npu"),
            ((1, 64, 3, 3), (576, 9, 3, 1), torch.float16, "npu"),
        ]
        args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args_new]

        with torch._subclasses.CrossRefFakeMode():
            Repro()(*args)

    def test_no_dispatch_with_like_function(self):
        class CountingMode(TorchDispatchMode):
            def __init__(self):
                super().__init__()
                self.count = 0

            # Override TorchDispatchMode.__torch_dispatch__ to count every op
            # that reaches this mode.
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                self.count += 1
                return func(*args, **(kwargs or {}))

        with FakeTensorMode():
            x = torch.randn(2)
            with CountingMode() as mode:
                with no_dispatch():
                    torch.zeros_like(x)

        # Under no_dispatch the counting mode must never be invoked.
        self.assertEqual(mode.count, 0)


class FakeTensorPropTest(TestCase):
    """Tests for running FakeTensorProp over traced nn.Modules."""

    def test_fake_tensor_prop_on_nn_module(self):
        class ToyNnModuleWithParameters(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.layer1 = torch.nn.Linear(4, 3)
                self.layer2 = torch.nn.Linear(3, 2)

            def forward(self, value):
                hidden = torch.relu(self.layer1(value))
                return self.layer2(hidden)

        model = ToyNnModuleWithParameters()
        value = torch.randn(5, 4)
        # FakeTensorProp needs a GraphModule, so trace the nn.Module first.
        graph_model = torch.fx.symbolic_trace(model, (value,))
        # Run FakeTensorProp on graph_model with the same FakeTensorMode that
        # produced its fake parameters and buffers. A dedicated API for running
        # FakeTensorProp on a GraphModule with parameters and buffers may
        # replace this manual setup in the future.
        with FakeTensorMode() as fake_tensor_mode:

            def to_fake_tensor(x):
                # Only real tensors need converting; everything else passes through.
                if not isinstance(x, torch.Tensor) or isinstance(x, FakeTensor):
                    return x
                return fake_tensor_mode.from_tensor(x)

            named_state = itertools.chain(graph_model.named_parameters(), graph_model.named_buffers())
            fake_parameters_and_buffers = {name: to_fake_tensor(tensor) for name, tensor in named_state}
            with torch.nn.utils.stateless._reparametrize_module(
                graph_model, fake_parameters_and_buffers
            ):
                # Same mode for the fake parameters/buffers and for
                # FakeTensorProp: propagation must succeed.
                result = FakeTensorProp(graph_model, fake_tensor_mode).propagate(value)
                self.assertIsInstance(result, FakeTensor)
                self.assertEqual(result.shape, (5, 2))
                # FakeTensorProp without the mode creates its own, which differs
                # from the one that owns the fake parameters, so this must fail
                # (AssertionError: tensor's device must be `meta`, got cpu instead).
                with self.assertRaises(AssertionError):
                    FakeTensorProp(graph_model).propagate(value)

    def test_fake_tensor_prop_on_nn_module_with_optional_args(self):
        class OptionalArgumentInBetween(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.layer1 = torch.nn.Linear(4, 3)
                self.layer2 = torch.nn.Linear(3, 2)

            def forward(self, value, another_value=None, another_optional_value=None):
                # Mimic huggingface's `forward` methods, which take several
                # optional arguments; e.g. GPT accepts
                # forward(self, input_ids, None, attention_mask, ...).
                # FakeTensorProp's from_real_tensor(...) therefore has to accept None.
                if another_value is None:
                    another_value = torch.rand_like(value)
                if another_optional_value is None:
                    another_optional_value = torch.rand_like(value)
                total = value + another_value + another_optional_value
                return total * total

        fake_mode = FakeTensorMode(allow_non_fake_inputs=True, allow_fallback_kernels=False)
        with fake_mode:
            model = OptionalArgumentInBetween()
            value = torch.randn(5, 4)
            another_optional_value = torch.randn(5, 4)
            call_args = (value, None, another_optional_value)
            graph_model = torch.fx.symbolic_trace(model, call_args)
            FakeTensorProp(graph_model, fake_mode).propagate(*call_args)


class TestFastGelu(TestCase):
    """FakeTensorMode shape checks for torch_npu.fast_gelu / npu_fast_gelu and their backward."""

    def test_fast_gelu(self):
        with FakeTensorMode():
            x = torch.randn(2, 3).npu()
            x.requires_grad = True
            out = torch_npu.fast_gelu(x)
            self.assertEqual(x.shape, out.shape)

    def test_fast_gelu_backward(self):
        with FakeTensorMode():
            x = torch.randn(2, 3).npu()
            x.requires_grad = True
            torch_npu.fast_gelu(x).sum().backward()
            self.assertEqual(x.shape, x.grad.shape)

    def test_npu_fast_gelu(self):
        with FakeTensorMode():
            x = torch.randn(2, 3).npu()
            x.requires_grad = True
            out = torch_npu.npu_fast_gelu(x)

            self.assertEqual(x.shape, out.shape)

    def test_npu_fast_gelu_backward(self):
        with FakeTensorMode():
            x = torch.randn(2, 3).npu()
            x.requires_grad = True
            torch_npu.npu_fast_gelu(x).sum().backward()
            self.assertEqual(x.shape, x.grad.shape)


class TestGelu(TestCase):
    """FakeTensorMode shape checks for torch.ops.npu.npu_gelu and its backward."""

    def test_npu_gelu(self):
        with FakeTensorMode():
            x = torch.randn(10, 3, 4).npu()
            x.requires_grad = True
            out = torch.ops.npu.npu_gelu(x)
            self.assertEqual(x.shape, out.shape)

    def test_npu_gelu_backward(self):
        with FakeTensorMode():
            x = torch.randn(10, 3, 4).npu()
            x.requires_grad = True
            torch.ops.npu.npu_gelu(x).sum().backward()
            self.assertEqual(x.shape, x.grad.shape)


class TestGeluMul(TestCase):
    """FakeTensorMode checks for torch_npu.npu_gelu_mul output shape and dtype."""

    def test_npu_gelu_mul(self):
        dtypes = [torch.float32, torch.float16, torch.bfloat16]
        modes = ["none", "tanh"]
        with FakeTensorMode():
            input_shape = [100, 400]
            for dtype, mode in itertools.product(dtypes, modes):
                input_tensor = torch.rand(input_shape, dtype=dtype, device="npu")
                # The output keeps every leading dim and halves the last one.
                *leading_dims, last_dim = input_tensor.shape
                expected_shape = (*leading_dims, last_dim // 2)
                output_npu = torch_npu.npu_gelu_mul(input_tensor, approximate=mode)
                self.assertEqual(output_npu.shape, expected_shape)
                self.assertEqual(output_npu.dtype, input_tensor.dtype)


class TestIncreFlashAttention(TestCase):
    """FakeTensorMode checks for torch.ops.npu.npu_incre_flash_attention."""

    def testIncreFlashAttention(self):
        with FakeTensorMode():
            q = torch.randn(1, 40, 1, 128, dtype=torch.float16).npu()
            k = torch.randn(1, 40, 1, 128, dtype=torch.float16).npu()
            v = torch.randn(1, 40, 1, 128, dtype=torch.float16).npu()
            q.requires_grad = True
            k.requires_grad = True
            v.requires_grad = True
            res = torch.ops.npu.npu_incre_flash_attention(q, k, v)

            self.assertTrue(q.shape == res.shape)

    def test_incre_flash_attention_int8_in(self):
        with FakeTensorMode():
            q = torch.randint(1, 40, (1, 128), dtype=torch.int8).npu()
            k = torch.randint(1, 40, (1, 128), dtype=torch.int8).npu()
            v = torch.randint(1, 40, (1, 128), dtype=torch.int8).npu()
            res = torch.ops.npu.npu_incre_flash_attention(q, k, v)

            # int8 inputs are expected to produce a half-precision output.
            self.assertTrue(q.shape == res.shape)
            self.assertTrue(res.dtype == torch.half)

    def test_incre_flash_attention_kv_padding_size(self):
        with FakeTensorMode():
            q = torch.randn(1, 40, 1, 128, dtype=torch.float16).npu()
            k = torch.randn(1, 40, 1, 128, dtype=torch.float16).npu()
            v = torch.randn(1, 40, 1, 128, dtype=torch.float16).npu()
            kv_padding_size = torch.randint(1, 2, (1,)).npu()

            # A single actual sequence length of 1. The previous
            # np.random.uniform(1, 1, [1]) draw always produced exactly [1].
            actual_seq = [1]
            res = torch.ops.npu.npu_incre_flash_attention(q, k, v, actual_seq_lengths=actual_seq,
                                                          kv_padding_size=kv_padding_size)

            self.assertTrue(q.shape == res.shape)


class TestNpuBmmV2(TestCase):
    """FakeTensorMode check that torch_npu.npu_bmmV2 matches matmul's shape and dtype."""

    def test_npu_bmmV2(self):
        with FakeTensorMode():
            lhs = torch.randn(10, 3, 4).npu()
            rhs = torch.randn(10, 4, 5).npu()
            out = torch_npu.npu_bmmV2(lhs, rhs, [])

            self.assertEqual(out.dtype, lhs.dtype)
            reference_shape = torch.matmul(lhs, rhs).shape
            self.assertEqual(out.shape, reference_shape)


class TestNpuBlockSparseAttentionMeta(TestCase):
    """FakeTensorMode checks for npu_block_sparse_attention and its backward (BNSD layout)."""

    # Problem size shared by every test in this class.
    B, N, S, D = 1, 2, 4, 8
    NUM_KV_HEADS = 2

    def _scale_value(self):
        """Softmax scale 1 / sqrt(D) used by every call in this class."""
        return 1.0 / (self.D ** 0.5)

    def _make_inputs(self, requires_grad=False):
        """Build the shared fake inputs; call inside an active FakeTensorMode.

        Returns ``(query, key, value, block_sparse_mask, block_shape)``.
        When ``requires_grad`` is True, query/key/value are marked as requiring grad.
        """
        B, N, S, D = self.B, self.N, self.S, self.D
        block_shape = [128, 128]  # blockY must be a multiple of 128
        # Number of blocks along the query / key-value sequence (ceil division).
        ceil_q = (S + block_shape[0] - 1) // block_shape[0]
        ceil_kv = (S + block_shape[1] - 1) // block_shape[1]

        query = torch.randn(B, N, S, D, dtype=torch.float16).npu()
        key = torch.randn(B, self.NUM_KV_HEADS, S, D, dtype=torch.float16).npu()
        value = torch.randn(B, self.NUM_KV_HEADS, S, D, dtype=torch.float16).npu()
        if requires_grad:
            for tensor in (query, key, value):
                tensor.requires_grad = True
        block_sparse_mask = torch.ones(B, N, ceil_q, ceil_kv, dtype=torch.int32).npu()
        return query, key, value, block_sparse_mask, block_shape

    def _forward(self, query, key, value, block_sparse_mask, block_shape):
        """Call npu_block_sparse_attention with the settings shared by all tests."""
        return torch_npu.npu_block_sparse_attention(
            query, key, value, block_sparse_mask, block_shape,
            q_input_layout="BNSD", kv_input_layout="BNSD",
            num_key_value_heads=self.NUM_KV_HEADS, scale_value=self._scale_value(), inner_precise=1,
        )

    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_npu_block_sparse_attention_meta(self):
        B, N, S, D = self.B, self.N, self.S, self.D
        with FakeTensorMode():
            query, key, value, block_sparse_mask, block_shape = self._make_inputs()
            attention_out, softmax_lse = self._forward(query, key, value, block_sparse_mask, block_shape)

            self.assertEqual(attention_out.shape, (B, N, S, D))
            self.assertEqual(attention_out.dtype, query.dtype)
            self.assertEqual(softmax_lse.shape, (B, N, S, 1))
            self.assertEqual(softmax_lse.dtype, torch.float32)

    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_npu_block_sparse_attention_backward_meta(self):
        B, N, S, D = self.B, self.N, self.S, self.D
        with FakeTensorMode():
            query, key, value, block_sparse_mask, block_shape = self._make_inputs()
            attention_out, softmax_lse = self._forward(query, key, value, block_sparse_mask, block_shape)
            d_out = torch.randn(B, N, S, D, dtype=torch.float16).npu()

            d_query, d_key, d_value = torch_npu.npu_block_sparse_attention_backward(
                d_out, query, key, value, attention_out, softmax_lse, block_sparse_mask,
                block_shape=block_shape,
                actual_seq_lengths=[S] * B, actual_seq_lengths_kv=[S] * B,
                q_input_layout="BNSD", kv_input_layout="BNSD",
                num_key_value_heads=self.NUM_KV_HEADS, scale_value=self._scale_value(),
            )

            self.assertEqual(d_query.shape, query.shape)
            self.assertEqual(d_query.dtype, query.dtype)
            self.assertEqual(d_key.shape, key.shape)
            self.assertEqual(d_key.dtype, key.dtype)
            self.assertEqual(d_value.shape, value.shape)
            self.assertEqual(d_value.dtype, value.dtype)

    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_npu_block_sparse_attention_autograd(self):
        """Check the derivatives.yaml forward/backward binding of npu_block_sparse_attention.

        After the forward pass, ``.backward()`` must propagate gradients to
        query/key/value.
        """
        with FakeTensorMode():
            query, key, value, block_sparse_mask, block_shape = self._make_inputs(requires_grad=True)
            attention_out, _ = self._forward(query, key, value, block_sparse_mask, block_shape)
            attention_out.sum().backward()

            for tensor in (query, key, value):
                self.assertIsNotNone(tensor.grad)
            self.assertEqual(query.grad.shape, query.shape)
            self.assertEqual(key.grad.shape, key.shape)
            self.assertEqual(value.grad.shape, value.shape)


class TestNpuWeightQuantBatchMatmul2(TestCase):
    """FakeTensorMode shape/dtype check for torch_npu.npu_weight_quant_batchmatmul."""

    def test_npu_wqbmmV2(self):
        with FakeTensorMode():
            x = torch.randint(-3, 3, (2, 64), dtype=torch.float16).npu()
            weight = torch.randint(-3, 3, (64, 128), dtype=torch.int8).npu()
            antiquant_scale = torch.randint(-3, 3, (1, 128), dtype=torch.float16).npu()
            # Only the shape and dtype of this reference tensor are compared.
            expected = torch.randint(-1, 1, (2, 128), dtype=torch.float16).npu()
            out = torch_npu.npu_weight_quant_batchmatmul(
                x, weight, antiquant_scale, weight_dtype=torch_npu.hifloat8
            )
            self.assertEqual(expected.shape, out.shape)
            self.assertEqual(expected.dtype, out.dtype)


class TestNpuDropout(TestCase):
    """Compare fake and eager output shapes of torch_npu._npu_dropout."""

    def test_npu_dropout(self):
        # Eager reference computed outside FakeTensorMode.
        real_input = torch.randn(2, 3).npu()
        real_input.requires_grad = True
        real_outputs = torch_npu._npu_dropout(real_input, 0.5)

        with FakeTensorMode():
            fake_input = torch.randn(2, 3).npu()
            fake_input.requires_grad = True
            fake_outputs = torch_npu._npu_dropout(fake_input, 0.5)
            self.assertEqual(fake_outputs[0].shape, real_outputs[0].shape)
            self.assertEqual(fake_outputs[1].shape, real_outputs[1].shape)

    def test_npu_dropout_backward(self):
        with FakeTensorMode():
            x = torch.randn(2, 3).npu()
            x.requires_grad = True
            outputs = torch_npu._npu_dropout(x, 0.5)
            outputs[0].sum().backward()
            self.assertEqual(x.shape, x.grad.shape)


class TestNpuDtypeCast(TestCase):
    """FakeTensorMode checks for torch_npu.npu_dtype_cast and its backward."""

    def test_npu_dtype_cast(self):
        target_dtype = torch.float16
        with FakeTensorMode():
            source = torch.randn((2, 3), dtype=torch.float32).npu()
            casted = torch_npu.npu_dtype_cast(source, target_dtype)

            self.assertEqual(casted.dtype, target_dtype)
            self.assertEqual(casted.shape, source.shape)

    def test_npu_dtype_cast_backward(self):
        target_dtype = torch.float16
        with FakeTensorMode():
            source = torch.randn((2, 3), dtype=torch.float32).npu()
            source.requires_grad = True
            casted = torch_npu.npu_dtype_cast(source, target_dtype)
            casted.sum().backward()
            self.assertEqual(casted.dtype, target_dtype)
            self.assertEqual(source.shape, source.grad.shape)


class TestNpuRotaryMul(TestCase):
    """FakeTensorMode shape/dtype checks for npu_rotary_mul."""

    def test_npu_rotary_mul(self):
        with FakeTensorMode():
            # NOTE(review): embedding's last dim is 125 while cosine/sine use
            # 128 — confirm this mismatch is intentional.
            embedding = torch.randn(2, 8192, 5, 125, dtype=torch.float16, requires_grad=True).npu()
            cosine = torch.randn(1, 8192, 1, 128, dtype=torch.float16, requires_grad=True).npu()
            sine = torch.randn(1, 8192, 1, 128, dtype=torch.float16, requires_grad=True).npu()
            out = torch.ops.npu.npu_rotary_mul(embedding, cosine, sine)

            self.assertEqual(embedding.shape, out.shape)
            self.assertEqual(embedding.dtype, out.dtype)

    @unittest.skip("skip test_npu_rotary_mul_matrix now")
    def test_npu_rotary_mul_matrix(self):
        with FakeTensorMode():
            embedding = torch.randn(2, 2, 8192, 128, dtype=torch.bfloat16).npu()
            cosine = torch.randn(1, 1, 8192, 128, dtype=torch.bfloat16).npu()
            sine = torch.randn(1, 1, 8192, 128, dtype=torch.bfloat16).npu()
            # Block matrix [[0, I], [-I, 0]] of size 128x128.
            half = 128 // 2
            rotate = torch.zeros(128, 128, dtype=torch.bfloat16)
            identity = torch.eye(half)
            rotate[:half, half:] = identity
            rotate[half:, :half] = -identity
            out = torch_npu.npu_rotary_mul(embedding, cosine, sine, "half", rotate.npu())

            self.assertEqual(embedding.shape, out.shape)
            self.assertEqual(embedding.dtype, out.dtype)


class TestNpuRotaryMulBackward(TestCase):
    """FakeTensorMode check that npu_rotary_mul_backward returns grads shaped like its inputs."""

    def test_npu_rotary_mul_backward(self):
        with FakeTensorMode():
            grad = torch.randn(4, 2048, 40, 128, dtype=torch.float16).npu()
            embedding = torch.randn(4, 2048, 40, 128, dtype=torch.float16).npu()
            cosine = torch.randn(1, 2048, 1, 128, dtype=torch.float16).npu()
            sine = torch.randn(1, 2048, 1, 128, dtype=torch.float16).npu()
            grads = torch.ops.npu.npu_rotary_mul_backward(grad, embedding, cosine, sine)

            grad_embedding, grad_cosine, grad_sine = grads[0], grads[1], grads[2]
            self.assertEqual(grad_embedding.shape, embedding.shape)
            self.assertEqual(grad_cosine.shape, cosine.shape)
            self.assertEqual(grad_sine.shape, sine.shape)


class TestNpuTranspose(TestCase):
    def test_npu_transpose(self):
        """``npu_transpose`` must infer the same shape as ``Tensor.permute`` with the same dims."""
        with FakeTensorMode():
            src = torch.randn((5, 3, 6, 4)).npu()
            dims = [1, 0, 2, 3]
            expected = src.permute(dims).shape
            actual = torch_npu.npu_transpose(src, dims).shape

            self.assertEqual(actual, expected)


class TestPromptFlashAttention(TestCase):
    def testPromptFlashAttention(self):
        """Fake-tensor output of ``npu_prompt_flash_attention`` must match the query shape."""
        with FakeTensorMode():
            # torch 2.1 builds use a 3-D (1, 1024, 1024) input; other versions a 4-D one.
            if "2.1." in torch.__version__:
                shape = (1, 1024, 1024)
            else:
                shape = (1, 40, 1, 128)
            q = torch.randn(*shape, dtype=torch.float16).npu()
            k = torch.randn(*shape, dtype=torch.float16).npu()
            v = torch.randn(*shape, dtype=torch.float16).npu()
            for tensor in (q, k, v):
                tensor.requires_grad = True
            res = torch.ops.npu.npu_prompt_flash_attention(q, k, v)

            # assertEqual reports both shapes on failure, unlike assertTrue(a == b).
            self.assertEqual(q.shape, res.shape)


class TestFusedInferAttentionScore(TestCase):
    @unittest.skipIf("2.1." not in torch.__version__, "skip this test for torch version other than 2.1")
    def testFusedInferAttentionScore(self):
        """Fake attention output of ``npu_fused_infer_attention_score`` must match the query shape."""
        with FakeTensorMode():
            q = torch.randn(1, 1024, 1024, dtype=torch.float16).npu()
            k = torch.randn(1, 1024, 1024, dtype=torch.float16).npu()
            v = torch.randn(1, 1024, 1024, dtype=torch.float16).npu()
            for tensor in (q, k, v):
                tensor.requires_grad = True
            # The op returns two outputs; only the attention output is checked here.
            atten_out, _ = torch.ops.npu.npu_fused_infer_attention_score(q, k, v)

            # assertEqual reports both shapes on failure, unlike assertTrue(a == b).
            self.assertEqual(q.shape, atten_out.shape)


class TestFusedInferAttentionV2(TestCase):
    def _d_unequal_output_shape(self, q_shape, v_shape, input_layout):
        """Return the fake attention-output shape of ``npu_fused_infer_attention_score_v2``.

        ``q`` and ``k`` are float16 tensors of ``q_shape``; ``v`` uses ``v_shape`` so its
        last dim may differ from q/k. The op is called with 8 query heads,
        ``softmax_scale = 1 / 0.0078125`` and pre/next tokens of 65535.
        """
        with FakeTensorMode():
            q = torch.randn(*q_shape, dtype=torch.float16).npu()
            k = torch.randn(*q_shape, dtype=torch.float16).npu()
            v = torch.randn(*v_shape, dtype=torch.float16).npu()
            for tensor in (q, k, v):
                tensor.requires_grad = True

            softmax_scale = 1 / 0.0078125
            atten_out, _ = torch.ops.npu.npu_fused_infer_attention_score_v2(
                q, k, v, num_query_heads=8, input_layout=input_layout,
                softmax_scale=softmax_scale, pre_tokens=65535, next_tokens=65535)
            return atten_out.shape

    @unittest.skipIf("2.1." not in torch.__version__, "skip this test for torch version other than 2.1")
    def testFusedInferAttentionV2(self):
        """Call with only q/k/v: the attention output must match the query shape."""
        with FakeTensorMode():
            q = torch.randn(1, 1024, 1024, dtype=torch.float16).npu()
            k = torch.randn(1, 1024, 1024, dtype=torch.float16).npu()
            v = torch.randn(1, 1024, 1024, dtype=torch.float16).npu()
            for tensor in (q, k, v):
                tensor.requires_grad = True
            atten_out, _ = torch.ops.npu.npu_fused_infer_attention_score_v2(q, k, v)

            self.assertEqual(q.shape, atten_out.shape)

    @unittest.skipIf("2.1." not in torch.__version__, "skip this test for torch version other than 2.1")
    def testFusedInferAttentionV2Pa(self):
        """Call with a ``block_table`` (the "Pa" case): the attention output must match the query shape."""
        with FakeTensorMode():
            q = torch.randn(128, 1, 128, dtype=torch.bfloat16).npu()
            k = torch.randn(1, 1, 8, 128, 16, dtype=torch.bfloat16).npu()
            v = torch.randn(1, 1, 8, 128, 16, dtype=torch.bfloat16).npu()
            block_table = torch.randint(0, 10, (1, 1), dtype=torch.int32).npu()
            atten_out, _ = torch.ops.npu.npu_fused_infer_attention_score_v2(q, k, v, block_table=block_table)

            self.assertEqual(q.shape, atten_out.shape)

    def testFusedInferAttentionV2_bnsd_bsnd_d_unequal(self):
        """BNSD_BSND layout: (32, 8, 2048, 192) q/k and 128-wide v give a (32, 2048, 8, 128) output."""
        out_shape = self._d_unequal_output_shape((32, 8, 2048, 192), (32, 8, 2048, 128), "BNSD_BSND")
        self.assertEqual(out_shape, torch.Size([32, 2048, 8, 128]))

    def testFusedInferAttentionV2_bsnd_d_unequal(self):
        """BSND layout: (32, 2048, 8, 192) q/k and 128-wide v give a (32, 2048, 8, 128) output."""
        out_shape = self._d_unequal_output_shape((32, 2048, 8, 192), (32, 2048, 8, 128), "BSND")
        self.assertEqual(out_shape, torch.Size([32, 2048, 8, 128]))

    def testFusedInferAttentionV2_bsh_d_unequal(self):
        """BSH layout: (32, 2048, 1536) q/k and 1024-wide v give a (32, 2048, 1024) output."""
        out_shape = self._d_unequal_output_shape((32, 2048, 1536), (32, 2048, 1024), "BSH")
        self.assertEqual(out_shape, torch.Size([32, 2048, 1024]))

    def testFusedInferAttentionV2_tnd_d_unequal(self):
        """TND layout: (32, 8, 192) q/k and 128-wide v give a (32, 8, 128) output."""
        out_shape = self._d_unequal_output_shape((32, 8, 192), (32, 8, 128), "TND")
        self.assertEqual(out_shape, torch.Size([32, 8, 128]))


class TestFlashAttentionScore(TestCase):
    def testFlashAttentionScore(self):
        """``npu_fusion_attention`` (BNSD): output 0 mirrors q; outputs 1-2 match a (1, 40, 16, 8) fp32 reference."""
        with FakeTensorMode():
            qkv_shape = (1, 40, 16, 128)
            q, k, v = (torch.randn(*qkv_shape, dtype=torch.float16).npu() for _ in range(3))
            stats_ref = torch.randn(1, 40, 16, 8, dtype=torch.float32).npu()
            for tensor in (q, k, v):
                tensor.requires_grad = True
            res = torch.ops.npu.npu_fusion_attention(q, k, v, head_num=40, input_layout="BNSD")

            self.assertEqual(q.shape, res[0].shape)
            self.assertEqual(q.dtype, res[0].dtype)
            for out in (res[1], res[2]):
                self.assertEqual(stats_ref.shape, out.shape)
                self.assertEqual(stats_ref.dtype, out.dtype)


class TestFlashAttentionScoreGrad(TestCase):
    def testFlashAttentionScoreGrad(self):
        """``npu_fusion_attention_grad`` outputs 0-2 must match q, k and v in shape and dtype."""
        with FakeTensorMode():
            q = torch.randn(1, 40, 16, 128, dtype=torch.float16).npu()
            k = torch.randn(1, 40, 16, 128, dtype=torch.float16).npu()
            v = torch.randn(1, 40, 16, 128, dtype=torch.float16).npu()
            dy = torch.randn(1, 40, 16, 128, dtype=torch.float16).npu()
            attention_in = torch.randn(1, 40, 16, 128, dtype=torch.float16).npu()
            softmax_max = torch.randn(1, 40, 16, 8, dtype=torch.float32).npu()
            softmax_sum = torch.randn(1, 40, 16, 8, dtype=torch.float32).npu()
            for tensor in (q, k, v, dy, attention_in, softmax_max, softmax_sum):
                tensor.requires_grad = True
            res = torch.ops.npu.npu_fusion_attention_grad(
                q, k, v, dy, head_num=40, input_layout="BNSD",
                softmax_max=softmax_max, softmax_sum=softmax_sum, attention_in=attention_in)

            # res[2] is the gradient w.r.t. v, so it is checked against v rather than k
            # (comparing to k only passes because k and v share shape/dtype here).
            for grad_out, ref in zip(res[:3], (q, k, v)):
                self.assertEqual(ref.shape, grad_out.shape)
                self.assertEqual(ref.dtype, grad_out.dtype)


class TestFlashAttentionTensorScore(TestCase):
    def testFlashAttentionScore(self):
        """``npu_fusion_attention_v3`` (BNSD): output 0 mirrors q; outputs 1-2 match a (1, 40, 16, 8) fp32 reference."""
        with FakeTensorMode():
            qkv_shape = (1, 40, 16, 128)
            q, k, v = (torch.randn(*qkv_shape, dtype=torch.float16).npu() for _ in range(3))
            stats_ref = torch.randn(1, 40, 16, 8, dtype=torch.float32).npu()
            for tensor in (q, k, v):
                tensor.requires_grad = True
            res = torch.ops.npu.npu_fusion_attention_v3(q, k, v, head_num=40, input_layout="BNSD")

            self.assertEqual(q.shape, res[0].shape)
            self.assertEqual(q.dtype, res[0].dtype)
            for out in (res[1], res[2]):
                self.assertEqual(stats_ref.shape, out.shape)
                self.assertEqual(stats_ref.dtype, out.dtype)


class TestFlashAttentionTensorScoreGrad(TestCase):
    def testFlashAttentionScoreGrad(self):
        """``npu_fusion_attention_grad_v3`` outputs 0-2 must match q, k and v in shape and dtype."""
        with FakeTensorMode():
            q = torch.randn(1, 40, 16, 128, dtype=torch.float16).npu()
            k = torch.randn(1, 40, 16, 128, dtype=torch.float16).npu()
            v = torch.randn(1, 40, 16, 128, dtype=torch.float16).npu()
            dy = torch.randn(1, 40, 16, 128, dtype=torch.float16).npu()
            attention_in = torch.randn(1, 40, 16, 128, dtype=torch.float16).npu()
            softmax_max = torch.randn(1, 40, 16, 8, dtype=torch.float32).npu()
            softmax_sum = torch.randn(1, 40, 16, 8, dtype=torch.float32).npu()
            for tensor in (q, k, v, dy, attention_in, softmax_max, softmax_sum):
                tensor.requires_grad = True
            res = torch.ops.npu.npu_fusion_attention_grad_v3(
                q, k, v, dy, head_num=40, input_layout="BNSD",
                softmax_max=softmax_max, softmax_sum=softmax_sum, attention_in=attention_in)

            # res[2] is the gradient w.r.t. v, so it is checked against v rather than k
            # (comparing to k only passes because k and v share shape/dtype here).
            for grad_out, ref in zip(res[:3], (q, k, v)):
                self.assertEqual(ref.shape, grad_out.shape)
                self.assertEqual(ref.dtype, grad_out.dtype)


class TestFusedFloydAttention(TestCase):
    def testFusedFloydAttention(self):
        """Fake-tensor shape/dtype inference for ``npu_fused_floyd_attention``."""
        with FakeTensorMode():
            batch_size = 2
            seq_len_q = 40
            seq_len_kv = 64
            head_num = 16
            head_dim = 128
            dtype_qkv = torch.float16
            dtype_mask = torch.float16

            query_ik = torch.randn(batch_size, seq_len_q, head_num, head_dim, dtype=dtype_qkv).npu()
            key_ij = torch.randn(batch_size, seq_len_kv, head_num, head_dim, dtype=dtype_qkv).npu()
            value_ij = torch.randn_like(key_ij, dtype=dtype_qkv).npu()
            key_jk = torch.randn_like(key_ij, dtype=dtype_qkv).npu()
            value_jk = torch.randn_like(key_ij, dtype=dtype_qkv).npu()
            # The mask uses the declared dtype_mask, as in TestFusedFloydAttentionGrad.
            atten_mask = torch.randn(batch_size, 1, seq_len_q, seq_len_kv, dtype=dtype_mask,
                                     requires_grad=False).npu()

            out0, out1, out2 = torch_npu.npu_fused_floyd_attention(
                query_ik, key_ij, value_ij, key_jk, value_jk,
                atten_mask=atten_mask, scale_value=1.0)
            expected_out0_shape = (batch_size, seq_len_q, head_num, head_dim, 8)

            # out0 and out1 share a float32 (B, Sq, N, D, 8) shape; out2 mirrors the query.
            self.assertEqual(out0.shape, expected_out0_shape)
            self.assertEqual(out0.dtype, torch.float32)
            self.assertEqual(out1.shape, out0.shape)
            self.assertEqual(out1.dtype, out0.dtype)

            self.assertEqual(out2.shape, query_ik.shape)
            self.assertEqual(out2.dtype, query_ik.dtype)


class TestFusedFloydAttentionGrad(TestCase):
    def testFusedFloydAttentionGrad(self):
        """Each gradient from ``npu_fused_floyd_attention_backward`` must mirror its q/k/v input."""
        with FakeTensorMode():
            b, s_q, s_kv, n, d = 2, 40, 64, 16, 128
            qkv_dtype = torch.float16
            stat_dtype = torch.float32
            mask_dtype = torch.float16

            query_ik = torch.randn(b, s_q, n, d, dtype=qkv_dtype).npu()
            key_ij = torch.randn(b, s_kv, n, d, dtype=qkv_dtype).npu()
            value_ij, key_jk, value_jk = (torch.randn_like(key_ij, dtype=qkv_dtype).npu() for _ in range(3))
            atten_mask = torch.randn(b, 1, s_q, s_kv, dtype=mask_dtype, requires_grad=False).npu()
            grad_output = torch.randn_like(query_ik, dtype=qkv_dtype).npu()
            stat_shape = (b, s_q, n, d, 8)
            softmax_max, softmax_sum, attention_out = (
                torch.randn(stat_shape, dtype=stat_dtype).npu() for _ in range(3))

            inputs = (query_ik, key_ij, value_ij, key_jk, value_jk)
            for tensor in inputs:
                tensor.requires_grad = True

            dquery, dkey_0, dvalue_0, dkey_1, dvalue_1 = torch_npu.npu_fused_floyd_attention_backward(
                grad_output, *inputs, attention_out, softmax_max, softmax_sum,
                atten_mask=atten_mask, scale_value=1.0)

            # Gradient i corresponds to inputs[i]; check shape then dtype for each.
            for grad, source in zip((dquery, dkey_0, dvalue_0, dkey_1, dvalue_1), inputs):
                self.assertEqual(grad.shape, source.shape)
                self.assertEqual(grad.dtype, source.dtype)


class TestNpuMoeComputeExpertTokens(TestCase):
    @unittest.skipIf('2.3.1' in torch.__version__, "skip this ut for this torch version")
    def test_npu_moe_compute_expert_tokens(self):
        """Fake output of ``npu_moe_compute_expert_tokens`` mirrors the sorted expert ids."""
        with FakeTensorMode():
            num_experts = 10
            experts = torch.tensor(list(range(num_experts)), dtype=torch.int32).npu()
            sorted_experts, _ = torch.sort(experts)
            ret = torch.ops.npu.npu_moe_compute_expert_tokens(sorted_experts, num_experts)

            self.assertEqual(sorted_experts.shape, ret.shape)
            self.assertEqual(sorted_experts.dtype, ret.dtype)


class TestMaskedSoftmaxWithRelPosBias(TestCase):
    # Meta shape inference.
    @unittest.skipIf("2.1." not in torch.__version__, "skip this test for torch version other than 2.1")
    def testMaskedSoftmaxWithRelPosBias(self):
        """Fake output of ``npu_masked_softmax_with_rel_pos_bias`` must match ``x``'s shape."""
        with FakeTensorMode():
            x = torch.randn(96, 2, 2, 32, 32, dtype=torch.float)
            relative_pos_bias = torch.randn(1, 1, 2, 32, 32, dtype=torch.float)
            atten_mask = torch.randn(1, 2, 1, 32, 32, dtype=torch.float)
            for tensor in (x, atten_mask, relative_pos_bias):
                tensor.requires_grad = True
            res = torch.ops.npu.npu_masked_softmax_with_rel_pos_bias(x, atten_mask, relative_pos_bias)

            # assertEqual reports both shapes on failure, unlike assertTrue(a == b).
            self.assertEqual(x.shape, res.shape)


class TestNpuMoeInitRouting(TestCase):
    # Meta shape inference.
    @unittest.skipIf("2.1." not in torch.__version__, "skip this test for torch version other than 2.1")
    def testNpuMoeInitRouting(self):
        """Fake outputs of ``npu_moe_init_routing`` for a (3, 4) input with (3, 2) routing indices."""
        with FakeTensorMode():
            x = torch.randn(3, 4, dtype=torch.float).npu()
            row_idx = torch.randint(0, 6, (3, 2), dtype=torch.int32).npu()
            expert_idx = torch.randint(0, 10, (3, 2), dtype=torch.int32).npu()
            active_num = 3
            expanded_x, expanded_row_idx, expanded_expert_idx = torch.ops.npu.npu_moe_init_routing(
                x, row_idx, expert_idx, active_num=active_num)

            # Expected dtypes/shapes stated directly; assertEqual reports both values on failure.
            self.assertEqual(expanded_x.dtype, torch.float32)
            self.assertEqual(expanded_row_idx.dtype, torch.int32)
            self.assertEqual(expanded_expert_idx.dtype, torch.int32)
            self.assertEqual(expanded_x.shape, torch.Size([6, 4]))
            self.assertEqual(expanded_row_idx.shape, torch.Size([6]))
            self.assertEqual(expanded_expert_idx.shape, torch.Size([6]))


class TestNpuMoeGatingTopKSoftmax(TestCase):
    # Meta shape inference.
    @unittest.skipIf("2.1." not in torch.__version__, "skip this test for torch version other than 2.1")
    def testNpuMoeGatingTopKSoftmax(self):
        """Fake outputs of ``npu_moe_gating_top_k_softmax`` with k=2 on a (3, 4) input."""
        with FakeTensorMode():
            x = torch.randn(3, 4, dtype=torch.float).npu()
            y, expert_idx, row_idx = torch.ops.npu.npu_moe_gating_top_k_softmax(x, None, k=2)

            # Expected dtypes/shapes stated directly; assertEqual reports both values on failure.
            self.assertEqual(y.dtype, torch.float32)
            self.assertEqual(expert_idx.dtype, torch.int32)
            self.assertEqual(row_idx.dtype, torch.int32)
            self.assertEqual(y.shape, torch.Size([3, 2]))
            self.assertEqual(expert_idx.shape, torch.Size([3, 2]))
            self.assertEqual(row_idx.shape, torch.Size([3, 2]))


class TestNpuRopeQuantKVCache(TestCase):
    @unittest.skip("skip test_npu_rope_quant_kvcache_meta now")
    def test_npu_rope_quant_kvcache_meta(self):
        """Meta inference of ``npu_rope_quant_kvcache``: check q and k-cache outputs (currently skipped)."""
        with FakeTensorMode() as mode:
            def sampled(shape, np_dtype, torch_dtype):
                # NumPy uniform [0, 1) sample, cast, then moved with .npu().
                host = np.random.uniform(0, 1, shape).astype(np_dtype)
                return torch.from_numpy(host).to(torch_dtype).npu()

            in_x = sampled([1, 1, 128 * 3], np.float16, torch.float16)
            in_cos = sampled([1, 1, 1, 128], np.float16, torch.float16)
            in_sin = sampled([1, 1, 1, 128], np.float16, torch.float16)
            in_k_cache = sampled([1, 2, 1, 128], np.int8, torch.int8)
            in_v_cache = sampled([1, 2, 1, 128], np.int8, torch.int8)
            in_indices = torch.tensor([0]).to(torch.int32).npu()
            in_scale_k = torch.randn([128], dtype=torch.float32).npu()
            in_scale_v = torch.randn([128], dtype=torch.float32).npu()
            size_splits = [128, 128, 128]

            (fake_x, fake_cos, fake_sin, fake_indices,
             fake_k_cache, fake_v_cache, fake_scale_k, fake_scale_v) = (
                mode.from_tensor(t) for t in (in_x, in_cos, in_sin, in_indices,
                                              in_k_cache, in_v_cache, in_scale_k, in_scale_v))
            for fake in (fake_x, fake_cos, fake_sin, fake_k_cache, fake_v_cache, fake_scale_k, fake_scale_v):
                self.assertIsNotNone(fake)

            q_result, _, _, k_cache_result, _ = torch.ops.npu.npu_rope_quant_kvcache(
                fake_x, fake_cos, fake_sin, fake_k_cache, fake_v_cache,
                fake_indices, fake_scale_k, fake_scale_v, size_splits)

            self.assertEqual(q_result.shape, torch.Size([1, 1, 1, 128]))
            self.assertEqual(q_result.dtype, in_x.dtype)
            self.assertEqual(q_result.device, in_x.device)
            self.assertTrue(isinstance(q_result, FakeTensor))
            self.assertEqual(k_cache_result.shape, in_k_cache.shape)
            self.assertEqual(k_cache_result.dtype, in_k_cache.dtype)
            self.assertEqual(k_cache_result.device, in_k_cache.device)
            self.assertTrue(isinstance(k_cache_result, FakeTensor))


class TestGeGlu(TestCase):
    # Meta shape inference.
    def test_geglu(self):
        """``npu_geglu`` halves ``x`` along ``dim`` for both outputs and keeps every other dim.

        The method name carries the ``test`` prefix so unittest discovers and runs it.
        """
        with FakeTensorMode():
            x = torch.randn(2, 10, 64, dtype=torch.float)
            dim = -1
            approximate = 1
            activate_left = False
            y, gelu = torch.ops.npu.npu_geglu(x, dim, approximate, activate_left)

            ndim = x.dim()
            # Normalise a negative dim to its positive index.
            split_dim = dim % ndim
            for index in range(ndim):
                # Along the split dim both outputs are half as large as x; elsewhere equal.
                factor = 2 if index == split_dim else 1
                self.assertEqual(y.size(index) * factor, x.size(index))
                self.assertEqual(gelu.size(index) * factor, x.size(index))


class TestScatterUpdateMeta(TestCase):

    def _make_fake_inputs(self, mode, dtype):
        """Build self/indices/updates, convert each with ``mode`` and check none is None.

        Returns ``(in_self, fake_self, fake_indices, fake_updates)``.
        """
        in_self = torch.randn(4, 4, 32, 256, dtype=dtype).npu()
        in_indices = torch.tensor([1, 1, 1, 1]).npu()
        in_updates = torch.randn(4, 4, 1, 256, dtype=dtype).npu()
        fakes = [mode.from_tensor(t) for t in (in_self, in_indices, in_updates)]
        for fake in fakes:
            self.assertIsNotNone(fake)
        return (in_self, *fakes)

    def _check_result(self, fake_result, fake_self, in_self, inplace):
        """Result must mirror ``in_self`` and be (in-place) or not be (out-of-place) ``fake_self``."""
        self.assertEqual(fake_result.shape, in_self.shape)
        self.assertEqual(fake_result.dtype, in_self.dtype)
        self.assertEqual(fake_result.device, in_self.device)
        self.assertTrue(isinstance(fake_result, FakeTensor))
        if inplace:
            self.assertIs(fake_result, fake_self)
        else:
            self.assertIsNot(fake_result, fake_self)
        self.assertIsNot(fake_result, in_self)

    def test_scatter_update_meta(self):
        with FakeTensorMode() as mode:
            in_self, fake_self, fake_indices, fake_updates = self._make_fake_inputs(mode, torch.float16)
            fake_result = torch.ops.npu.scatter_update(fake_self, fake_indices, fake_updates, -2)
            self._check_result(fake_result, fake_self, in_self, inplace=False)

    def test_scatter_update__meta(self):
        with FakeTensorMode() as mode:
            in_self, fake_self, fake_indices, fake_updates = self._make_fake_inputs(mode, torch.float32)
            fake_result = torch.ops.npu.scatter_update_(fake_self, fake_indices, fake_updates, -2)
            self._check_result(fake_result, fake_self, in_self, inplace=True)


class TestNpuQuantScatterMeta(TestCase):

    def _make_fake_inputs(self, mode):
        """Build var/indices/updates/quant_scales, convert with ``mode`` and check none is None.

        Returns ``(in_var, fake_var, fake_indices, fake_updates, fake_quant_scales)``.
        """
        def sampled(low, high, shape, np_dtype, torch_dtype):
            # NumPy uniform sample, cast, then moved with .npu().
            host = np.random.uniform(low, high, shape).astype(np_dtype)
            return torch.from_numpy(host).to(torch_dtype).npu()

        in_var = sampled(0, 1, [1, 1, 32], np.int8, torch.int8)
        in_indices = sampled(0, 1, [1], np.int32, torch.int32)
        in_updates = sampled(1, 2, [1, 1, 32], np.float16, torch.bfloat16)
        in_quant_scales = sampled(0, 1, [1, 1, 32], np.float16, torch.bfloat16)
        fakes = [mode.from_tensor(t) for t in (in_var, in_indices, in_updates, in_quant_scales)]
        for fake in fakes:
            self.assertIsNotNone(fake)
        return (in_var, *fakes)

    def _check_result(self, fake_result, fake_var, in_var, inplace):
        """Result must mirror ``in_var`` and be (in-place) or not be (out-of-place) ``fake_var``."""
        self.assertEqual(fake_result.shape, in_var.shape)
        self.assertEqual(fake_result.dtype, in_var.dtype)
        self.assertEqual(fake_result.device, in_var.device)
        self.assertTrue(isinstance(fake_result, FakeTensor))
        if inplace:
            self.assertIs(fake_result, fake_var)
        else:
            self.assertIsNot(fake_result, fake_var)
        self.assertIsNot(fake_result, in_var)

    def test_npu_quant_scatter_meta(self):
        with FakeTensorMode() as mode:
            in_var, fake_var, fake_indices, fake_updates, fake_quant_scales = self._make_fake_inputs(mode)
            fake_result = torch.ops.npu.npu_quant_scatter(fake_var, fake_indices, fake_updates, fake_quant_scales,
                                                          None, -2, -1, "update")
            self._check_result(fake_result, fake_var, in_var, inplace=False)

    def test_npu_quant_scatter__meta(self):
        with FakeTensorMode() as mode:
            in_var, fake_var, fake_indices, fake_updates, fake_quant_scales = self._make_fake_inputs(mode)
            fake_result = torch.ops.npu.npu_quant_scatter_(fake_var, fake_indices, fake_updates, fake_quant_scales,
                                                           None, -2, -1, "update")
            self._check_result(fake_result, fake_var, in_var, inplace=True)


class TestNpuApplyRotoryPosEmbMeta(TestCase):

    def test_npu_apply_rotary_pos_emb_meta(self):
        """Fake outputs of ``npu_apply_rotary_pos_emb`` ("BSND", "half") must mirror query and key."""
        with FakeTensorMode() as mode:
            def sampled_fp16(shape):
                # NumPy uniform [0, 1) sample converted to float16 and moved with .npu().
                host = np.random.uniform(0, 1, shape).astype(np.float16)
                return torch.from_numpy(host).to(torch.float16).npu()

            data_query = sampled_fp16([4, 1024, 16, 128])
            data_key = sampled_fp16([4, 1024, 16, 128])
            data_cos = sampled_fp16([4, 1024, 1, 128])
            data_sin = sampled_fp16([4, 1024, 1, 128])
            fake_query, fake_key, fake_cos, fake_sin = (
                mode.from_tensor(t) for t in (data_query, data_key, data_cos, data_sin))
            for fake in (fake_query, fake_key, fake_cos, fake_sin):
                self.assertIsNotNone(fake)

            fake_query_result, fake_key_result = torch.ops.npu.npu_apply_rotary_pos_emb(
                fake_query, fake_key, fake_cos, fake_sin, "BSND", "half")
            for result, source in ((fake_query_result, data_query), (fake_key_result, data_key)):
                self.assertEqual(result.shape, source.shape)
                self.assertEqual(result.dtype, source.dtype)
                self.assertEqual(result.device, source.device)
            self.assertTrue(isinstance(fake_query_result, FakeTensor))
            self.assertTrue(isinstance(fake_key_result, FakeTensor))


class TestMmAllReduce(TestCase):
    def test_mm_all_reduce(self):
        """fp16 x fp16 ``npu_mm_all_reduce_base`` infers a (128, 128) fp16 output."""
        with FakeTensorMode():
            dst_dtype = torch.float16
            x1 = torch.randn(128, 256, dtype=torch.float16).npu()
            x2 = torch.randn(256, 128, dtype=torch.float16).npu()
            hcom = "fake group info"
            output = torch_npu.npu_mm_all_reduce_base(x1, x2, hcom, reduce_op="sum")
            self.assertEqual(output.shape, (128, 128))
            self.assertEqual(output.dtype, dst_dtype)

    def test_mm_all_reduce_quant(self):
        """int8 inputs with an int64 ``dequant_scale`` infer a float16 output."""
        with FakeTensorMode():
            dst_dtype = torch.float16
            x1 = torch.randn(128, 256, dtype=torch.float16).to(torch.int8).npu()
            x2 = torch.randn(256, 128, dtype=torch.float16).to(torch.int8).npu()
            dequant = torch.randn(128, dtype=torch.float16).to(torch.int64).npu()
            hcom = "fake group info"
            output = torch_npu.npu_mm_all_reduce_base(x1, x2, hcom, reduce_op="sum", dequant_scale=dequant)
            self.assertEqual(output.shape, (128, 128))
            self.assertEqual(output.dtype, dst_dtype)

    # Distinct name: reusing test_mm_all_reduce_quant would shadow the int64 case above.
    def test_mm_all_reduce_quant_bf16(self):
        """int8 inputs with a bfloat16 ``dequant_scale`` infer a bfloat16 output."""
        with FakeTensorMode():
            dst_dtype = torch.bfloat16
            x1 = torch.randn(128, 256, dtype=torch.float16).to(torch.int8).npu()
            x2 = torch.randn(256, 128, dtype=torch.float16).to(torch.int8).npu()
            dequant = torch.randn(128, dtype=torch.bfloat16).npu()
            hcom = "fake group info"
            output = torch_npu.npu_mm_all_reduce_base(x1, x2, hcom, reduce_op="sum", dequant_scale=dequant)
            self.assertEqual(output.shape, (128, 128))
            self.assertEqual(output.dtype, dst_dtype)


class TestMmReduceScatter(TestCase):
    """Meta (FakeTensorMode) checks for ``torch_npu.npu_mm_reduce_scatter_base``.

    The output rows are expected to be split across ``world_size`` ranks,
    i.e. an output of shape ``(m // world_size, n)``.
    """

    def test_mm_reduce_scatter(self):
        """fp16 matmul + reduce-scatter keeps fp16 and scatters the rows."""
        with FakeTensorMode():
            dst_dtype = torch.float16
            m, k, n = 128, 512, 256
            x1 = torch.randn(m, k, dtype=torch.float16).npu()
            x2 = torch.randn(k, n, dtype=torch.float16).npu()
            hcom = "fake group info"
            world_size = 8
            output = torch_npu.npu_mm_reduce_scatter_base(x1, x2, hcom, world_size, reduce_op="sum")
            self.assertEqual(output.shape, (m // world_size, n))
            self.assertEqual(output.dtype, dst_dtype)

    def test_mm_reduce_scatter_quant(self):
        """int8 inputs with fp32 (m, 1) / (1, n) scales are expected to yield bf16."""
        # Previously named ``test_mm_all_reduce_quant`` (copied from
        # TestMmAllReduce), which misdescribed the op exercised here.
        with FakeTensorMode():
            world_size = 8
            m, k, n = 128, 512, 256
            dst_dtype = torch.bfloat16
            x1 = torch.randint(-10, 10, size=(m, k), dtype=torch.int8).npu()
            x2 = torch.randint(-10, 10, size=(k, n), dtype=torch.int8).npu()
            x1_scale = torch.randn((m, 1), dtype=torch.float32).npu()
            x2_scale = torch.randn((1, n), dtype=torch.float32).npu()
            hcom = "fake group info"
            output = torch_npu.npu_mm_reduce_scatter_base(x1, x2, hcom, world_size, x1_scale=x1_scale,
                                                          x2_scale=x2_scale, reduce_op="sum")
            self.assertEqual(output.shape, (m // world_size, n))
            self.assertEqual(output.dtype, dst_dtype)

class TestAllGatherBaseMm(TestCase):
    """Meta (FakeTensorMode) checks for ``torch_npu.npu_all_gather_base_mm``.

    With ``world_size`` ranks, the matmul output is expected to have
    ``m * world_size`` rows, and the gathered ``x1`` is ``(m * world_size, k)``.
    """

    def test_all_gather_base_mm(self):
        with FakeTensorMode():
            expected_dtype = torch.float16
            rows, inner, cols = 128, 512, 256
            lhs = torch.randn(rows, inner, dtype=torch.float16).npu()
            rhs = torch.randn(inner, cols, dtype=torch.float16).npu()
            group_name = "fake group info"
            ranks = 8
            mm_out, gathered = torch_npu.npu_all_gather_base_mm(lhs, rhs, group_name, ranks)
            gathered_rows = rows * ranks
            self.assertEqual(mm_out.shape, (gathered_rows, cols))
            self.assertEqual(mm_out.dtype, expected_dtype)
            self.assertEqual(gathered.shape, (gathered_rows, inner))
            self.assertEqual(gathered.dtype, expected_dtype)

    def test_all_gather_base_mm_gather_out_false(self):
        with FakeTensorMode():
            expected_dtype = torch.float16
            rows, inner, cols = 128, 512, 256
            lhs = torch.randn(rows, inner, dtype=torch.float16).npu()
            rhs = torch.randn(inner, cols, dtype=torch.float16).npu()
            group_name = "fake group info"
            ranks = 8
            # Only the matmul output is checked when gather_output=False.
            mm_out, _gathered = torch_npu.npu_all_gather_base_mm(lhs, rhs, group_name, ranks,
                                                                 gather_output=False)
            self.assertEqual(mm_out.shape, (rows * ranks, cols))
            self.assertEqual(mm_out.dtype, expected_dtype)

    def test_all_gather_base_mm_quant(self):
        with FakeTensorMode():
            ranks = 8
            rows, inner, cols = 128, 512, 256
            expected_dtype = torch.bfloat16
            lhs = torch.randint(-10, 10, size=(rows, inner), dtype=torch.int8).npu()
            rhs = torch.randint(-10, 10, size=(inner, cols), dtype=torch.int8).npu()
            lhs_scale = torch.randn((rows, 1), dtype=torch.float32).npu()
            rhs_scale = torch.randn((1, cols), dtype=torch.float32).npu()
            group_name = "fake group info"
            mm_out, gathered = torch_npu.npu_all_gather_base_mm(lhs, rhs, group_name, ranks,
                                                                x1_scale=lhs_scale,
                                                                x2_scale=rhs_scale)
            gathered_rows = rows * ranks
            self.assertEqual(mm_out.shape, (gathered_rows, cols))
            self.assertEqual(mm_out.dtype, expected_dtype)
            # The gathered tensor keeps x1's (int8) dtype.
            self.assertEqual(gathered.shape, (gathered_rows, inner))
            self.assertEqual(gathered.dtype, lhs.dtype)

class TestNpuDeepNorm(TestCase):
    """Meta (FakeTensorMode) checks for ``torch_npu.npu_deep_norm``."""

    def test_npu_deep_norm(self):
        def make_fp32(*shape):
            return torch.randn(shape, dtype=torch.float32).npu()

        with FakeTensorMode():
            x = make_fp32(2, 3)
            gx = make_fp32(2, 3)
            beta = make_fp32(3)
            gamma = make_fp32(3)

            mean, rstd, y = torch_npu.npu_deep_norm(x, gx, beta, gamma)

            self.assertEqual(y.dtype, x.dtype)
            self.assertEqual(y.shape, x.shape)
            self.assertEqual(y.device, x.device)
            # Statistics keep the leading dim and collapse the last one to 1.
            stats_shape = torch.Size([2, 1])
            for stat in (rstd, mean):
                self.assertEqual(stat.shape, stats_shape)
                self.assertEqual(stat.device, x.device)


class TestRmsNormQuant(TestCase):
    """Meta (FakeTensorMode) checks for ``torch_npu.npu_rms_norm_quant``."""

    def test_rms_norm_quant(self):
        # Each case: (input dtype, extra keyword args, expected output dtype).
        cases = (
            (torch.float16, {}, torch.int8),
            (torch.bfloat16, {}, torch.int8),
            # dst_dtype=291 is expected to select float8_e5m2 output.
            (torch.float16, {"epsilon": 1e-06, "dst_dtype": 291}, torch.float8_e5m2),
        )
        with FakeTensorMode():
            for dtype, extra_kwargs, expected_dtype in cases:
                x = torch.randn([2, 16], dtype=dtype).npu()
                gamma = torch.randn([16], dtype=dtype).npu()
                beta = torch.randn([16], dtype=dtype).npu()
                scale = torch.randn(1, dtype=dtype).npu()
                offset = torch.randint(-10, 10, (1,), dtype=torch.int8).npu()
                y = torch_npu.npu_rms_norm_quant(x, gamma, beta, scale, offset, **extra_kwargs)
                self.assertTrue(y.shape == x.shape)
                self.assertTrue(y.dtype == expected_dtype)


class TestRmsNormQuantV2(TestCase):
    """Meta (FakeTensorMode) checks for ``torch_npu.npu_rms_norm_quant_v2``."""

    def test_rms_norm_quant_v2(self):
        # Each case: (dtype, gamma/beta shape, div_mode, dst_dtype code, expected y dtype).
        cases = (
            (torch.float16, [16], True, 1, torch.int8),
            (torch.bfloat16, [16], True, 1, torch.int8),
            (torch.float16, [16], False, 291, torch.float8_e5m2),
            # gamma/beta given with an explicit leading dim of 1.
            (torch.float16, [1, 16], False, 1, torch.int8),
        )
        with FakeTensorMode():
            for dtype, param_shape, div_mode, dst_code, expected_dtype in cases:
                x = torch.randn([2, 16], dtype=dtype).npu()
                gamma = torch.randn(param_shape, dtype=dtype).npu()
                scale = torch.randn(1, dtype=dtype).npu()
                offset = torch.randint(-10, 10, (1,), dtype=dtype).npu()
                beta = torch.randn(param_shape, dtype=dtype).npu()
                y, _rstd = torch_npu.npu_rms_norm_quant_v2(
                    x, gamma, scale, offset=offset, beta=beta,
                    epsilon=1e-06, div_mode=div_mode, dst_dtype=dst_code,
                )
                self.assertTrue(y.shape == x.shape)
                self.assertTrue(y.dtype == expected_dtype)


class TestRmsNormDynamicMxQuant(TestCase):
    """Meta (FakeTensorMode) checks for ``torch_npu.npu_rms_norm_dynamic_mx_quant``."""

    def test_npu_rms_norm_dynamic_mx_quant_meta(self):
        rows, hidden = 8, 64
        with FakeTensorMode():
            x = torch.randn([rows, hidden], dtype=torch.float16, device='npu')
            gamma = torch.ones([hidden], dtype=torch.float16, device='npu')
            beta = torch.zeros([hidden], dtype=torch.float16, device='npu')
            y, mx_scale, rstd = torch_npu.npu_rms_norm_dynamic_mx_quant(
                x, gamma, beta=beta, epsilon=1e-6, scale_alg=0, round_mode="rint", dst_type=torch_npu.float8_e5m2
            )
            # (output, expected shape, expected dtype), checked in order.
            expectations = (
                (y, x.shape, torch.float8_e5m2),
                (mx_scale, torch.Size([8, 1, 2]), torch.uint8),
                (rstd, torch.Size([8, 1]), torch.float32),
            )
            for tensor, shape, dtype in expectations:
                self.assertEqual(tensor.shape, shape)
                self.assertEqual(tensor.dtype, dtype)


class TestNpuRmsNorm(TestCase):
    """Meta (FakeTensorMode) checks for ``torch_npu.npu_rms_norm`` forward and backward."""

    def test_npu_rms_norm(self):
        with FakeTensorMode():
            x = torch.randn((2, 3), dtype=torch.float32).npu()
            weight = torch.randn((3,), dtype=torch.float32).npu()

            y, rstd = torch_npu.npu_rms_norm(x, weight)

            # y mirrors the input's dtype, shape and device.
            for attr in ("dtype", "shape", "device"):
                self.assertEqual(getattr(y, attr), getattr(x, attr))
            self.assertEqual(rstd.shape, torch.Size([2, 1]))
            self.assertEqual(rstd.device, x.device)

    def test_npu_rms_norm_backward(self):
        with FakeTensorMode():
            x = torch.randn((2, 3), dtype=torch.float32).npu()
            weight = torch.randn((3,), dtype=torch.float32).npu()
            x.requires_grad = True
            weight.requires_grad = True

            y = torch_npu.npu_rms_norm(x, weight)[0]
            upstream = torch.randn((2, 3), dtype=torch.float32).npu()
            y.backward(upstream)

            grad_x, grad_w = x.grad, weight.grad
            for attr in ("dtype", "shape", "device"):
                self.assertEqual(getattr(grad_x, attr), getattr(x, attr))
            self.assertEqual(grad_w.shape, weight.shape)
            self.assertEqual(grad_w.device, weight.device)

class TestDistributeBarrier(TestCase):
    """Meta check for ``torch_npu._npu_distribute_barrier``: the result mirrors ``x_ref``'s shape."""

    def test_npu_distribute_barrier(self):
        ep_world_size = 16
        with FakeTensorMode():
            x_ref = torch.randn(1).to(torch.int32)
            time_out = torch.randn(1).to(torch.int32)
            # elastic_info has 4 + 2 * ep_world_size columns.
            elastic_info = torch.randn(1, 4 + 2 * ep_world_size).to(torch.int32)
            result = torch_npu._npu_distribute_barrier(
                x_ref=x_ref,
                group="group_ep",
                world_size=ep_world_size,
                time_out=time_out,
                elastic_info=elastic_info,
            )
            self.assertEqual(result.shape, x_ref.shape)

class TestQuantLightningIndexer(TestCase):
    # NOTE(review): this class currently contributes no runnable tests; see the
    # notes on ``quant_lightning_indexer_result`` below.
    def quant_lightning_indexer_result(self):
        """Compare a CPU reference against the NPU eager quant lightning indexer.

        NOTE(review): the method name does not start with ``test_``, so unittest
        never collects or runs it.
        NOTE(review): ``self.cpu_op_exec`` and ``self.npu_op_exec_eager`` are not
        defined in this class; calling this method as-is would likely raise
        AttributeError unless they are provided elsewhere — TODO confirm.
        NOTE(review): ``npu_eager_out.equal(cpu_out)`` compares tensor values,
        which is unlikely to be meaningful inside FakeTensorMode — verify intent.
        """
        with FakeTensorMode():
            b = 1
            t = None
            s1 = 4
            s2 = 512
            act_seq_q = None
            act_seq_k = None
            sparse_mode = 0
            sparse_count = 2048
            n1 = 64
            n2 = 1
            d = 128
            block_size = 128
            layout_query = "BSND"
            layout_key = 'PA_BSND'
            query_quant_mode = 0
            key_quant_mode = 0
            np.random.seed(0)
            # Build the blocked key cache: ceil(s2 / block_size) blocks per batch,
            # with block_table listing block indices 0..b*max_block_table_num-1.
            max_block_table_num = (s2 + block_size - 1) // block_size
            block_table = torch.tensor([range(b * max_block_table_num)], dtype = torch.int32).reshape(b, -1)
            # key: (num_blocks, block_size, n2, d) int8; scale: (num_blocks, block_size, n2) fp16.
            key = torch.tensor(np.random.uniform(-128, 127, (b * max_block_table_num, block_size, n2, d))).to(torch.int8)
            key_dequant_scale = torch.tensor(np.random.uniform(0, 10, (b * max_block_table_num, block_size, n2)))
            key_dequant_scale = key_dequant_scale.to(torch.float16)
            if layout_query == 'BSND':
                # Sequence lengths default to the full s1 / s2 when not given explicitly.
                query = torch.tensor(np.random.uniform(-128, 127, (b, s1, n1, d))).to(torch.int8)
                query_dequant_scale = torch.tensor(np.random.uniform(0, 10, (b, s1, n1))).to(torch.float16)
                weights = torch.tensor(np.random.uniform(0, 0.01, (b, s1, n1))).to(torch.float16)
                actual_seq_lengths_query = torch.tensor(np.random.uniform(s1, s1, (b))).to(torch.int32) \
                                        if act_seq_q is None else torch.tensor(act_seq_q).to(torch.int32)
                actual_seq_lengths_key = torch.tensor(np.random.uniform(s2, s2, (b))).to(torch.int32) \
                                        if act_seq_k is None else torch.tensor(act_seq_k).to(torch.int32)
            else:
                # NOTE(review): unreachable with layout_query hard-coded to "BSND";
                # with t = None and act_seq_* = None this branch would fail.
                query = torch.tensor(np.random.uniform(-128, 127, (t, n1, d))).to(torch.int8)
                query_dequant_scale = torch.tensor(np.random.uniform(0, 10, (t, n1))).to(torch.float16)
                weights = torch.tensor(np.random.uniform(0, 0.01, (t, n1))).to(torch.float16)
                actual_seq_lengths_query = torch.tensor(act_seq_q).to(torch.int32)
                actual_seq_lengths_key = torch.tensor(act_seq_k).to(torch.int32)

            cpu_out = self.cpu_op_exec(query.cpu(), key.cpu(), weights.cpu(), query_dequant_scale.cpu(), key_dequant_scale.cpu(),
                                    actual_seq_lengths_query.cpu(), actual_seq_lengths_key.cpu(), block_table.cpu(),
                                    layout_query, sparse_count, sparse_mode)

            npu_eager_out = self.npu_op_exec_eager(query.npu(), key.npu(), weights.npu(), query_dequant_scale.npu(), key_dequant_scale.npu(),
                                                actual_seq_lengths_query.npu(), actual_seq_lengths_key.npu(), block_table.npu(),
                                                query_quant_mode, key_quant_mode,
                                                layout_query, layout_key, sparse_count, sparse_mode)
            res = npu_eager_out.equal(cpu_out)
            # NOTE(review): assertRtolEqual is not provided by the TestCase imported
            # in this file's header as far as visible here — TODO confirm.
            self.assertRtolEqual(res, True)

class TestNpuAddRmsNorm(TestCase):
    """Meta (FakeTensorMode) checks for ``torch_npu.npu_add_rms_norm``."""

    def test_npu_add_rms_norm(self):
        with FakeTensorMode():
            lhs = torch.randn((2, 3), dtype=torch.float32).npu()
            rhs = torch.randn((2, 3), dtype=torch.float32).npu()
            weight = torch.randn((3,), dtype=torch.float32).npu()

            y, rstd, x_out = torch_npu.npu_add_rms_norm(lhs, rhs, weight)

            def assert_like_input(tensor):
                self.assertEqual(tensor.dtype, lhs.dtype)
                self.assertEqual(tensor.shape, lhs.shape)
                self.assertEqual(tensor.device, lhs.device)

            assert_like_input(y)
            self.assertEqual(rstd.shape, torch.Size([2, 1]))
            self.assertEqual(rstd.device, lhs.device)
            assert_like_input(x_out)


class TestNpuAddRmsNormV2(TestCase):
    """Meta check for ``torch_npu.npu_add_rms_norm_v2``, which returns only rstd here."""

    def test_npu_add_rms_norm_v2(self):
        with FakeTensorMode():
            x1, x2 = (torch.randn((2, 3), dtype=torch.float32).npu() for _ in range(2))
            gamma = torch.randn((3,), dtype=torch.float32).npu()

            rstd = torch_npu.npu_add_rms_norm_v2(x1, x2, gamma)

            self.assertEqual(rstd.shape, torch.Size([2, 1]))
            self.assertEqual(rstd.device, x1.device)
            self.assertEqual(rstd.dtype, torch.float32)


class TestNpuAddRmsNormV2Functional(TestCase):
    """Meta checks for ``torch_npu.npu_add_rms_norm_v2_functional`` (returns rstd, y, x)."""

    def test_npu_add_rms_norm_v2_functional(self):
        with FakeTensorMode():
            x1, x2 = (torch.randn((2, 3), dtype=torch.float32).npu() for _ in range(2))
            gamma = torch.randn((3,), dtype=torch.float32).npu()

            rstd, y, x_out = torch_npu.npu_add_rms_norm_v2_functional(x1, x2, gamma)

            def assert_like_input(tensor):
                self.assertEqual(tensor.dtype, x1.dtype)
                self.assertEqual(tensor.shape, x1.shape)
                self.assertEqual(tensor.device, x1.device)

            assert_like_input(y)
            self.assertEqual(rstd.shape, torch.Size([2, 1]))
            self.assertEqual(rstd.device, x1.device)
            assert_like_input(x_out)


class TestNpuFfnWorkerBatching(TestCase):
    """Meta checks for ``torch_npu.npu_ffn_worker_batching`` output shapes."""

    def test_npu_ffn_worker_batching(self):
        with FakeTensorMode():
            x = torch.randn((1024,)).npu().to(torch.int8)
            expert_num = 2
            max_out_shape = [16, 8, 9, 7168]
            # Per-row outputs are expected to have the product of the first
            # three max_out_shape dims as their leading size.
            total_rows = max_out_shape[0] * max_out_shape[1] * max_out_shape[2]
            outputs = torch_npu.npu_ffn_worker_batching(
                x, expert_num, max_out_shape, token_dtype=2, need_schedule=1, layer_num=1)
            (y, group_list, session_ids, micro_batch_ids, token_ids,
             expert_offsets, dynamic_scale, actual_token_num) = outputs

            self.assertTrue(y.shape[0] == total_rows)
            self.assertTrue(y.shape[1] == max_out_shape[3])
            self.assertTrue(group_list.shape[0] == expert_num)
            self.assertTrue(group_list.shape[1] == 2)
            for per_row in (session_ids, micro_batch_ids, token_ids, expert_offsets, dynamic_scale):
                self.assertTrue(per_row.shape[0] == total_rows)
            self.assertTrue(actual_token_num.shape[0] == 1)


class TestFFN(TestCase):
    """Meta check for ``torch_npu.npu_ffn``: the output keeps the input's shape."""

    def test_npu_ffn_meta(self):
        tokens, hidden, expanded = 1, 320, 2560
        with FakeTensorMode():
            x = torch.randn(tokens, hidden, dtype=torch.float16).npu()
            w1 = torch.randn(hidden, expanded, dtype=torch.float16).npu()
            w2 = torch.randn(expanded, hidden, dtype=torch.float16).npu()
            res = torch_npu.npu_ffn(x, w1, w2, "gelu", inner_precise=1, output_dtype=torch.float16)
            self.assertTrue(x.shape == res.shape)


class TestNpuDynamicQuant(TestCase):
    """Meta checks for ``torch_npu.npu_dynamic_quant`` (int8 output plus fp32 per-row scale)."""

    def test_npu_dynamic_quant(self):
        with FakeTensorMode():
            input_npu = torch.randn((4, 2048, 1024)).npu().to(torch.float16)
            smooth_scales_npu = torch.randn(1024).npu().to(torch.float16)

            # Reference tensors used only for their dtype/shape/device.
            expected_output = torch.randn((4, 2048, 1024)).npu().to(torch.int8)
            expected_scale = torch.randn((4, 2048)).npu().to(torch.float32)

            actual_output, actual_scale = torch_npu.npu_dynamic_quant(input_npu, smooth_scales=smooth_scales_npu)

            for actual, expected in ((actual_output, expected_output), (actual_scale, expected_scale)):
                self.assertEqual(actual.dtype, expected.dtype)
                self.assertEqual(actual.shape, expected.shape)
                self.assertEqual(actual.device, expected.device)


class TestDynamicQuantAsymmetric(TestCase):
    """Meta checks for ``torch_npu.npu_dynamic_quant_asymmetric`` (int8 output, fp32 scale and offset)."""

    def test_npu_dynamic_quant_asymmetric(self):
        with FakeTensorMode():
            input_npu = torch.randn((4, 2048, 1024)).npu().to(torch.float16)
            smooth_scales_npu = torch.randn(1024).npu().to(torch.float16)

            # Reference tensors used only for their dtype/shape/device.
            expected_output = torch.randn((4, 2048, 1024)).npu().to(torch.int8)
            expected_scale = torch.randn((4, 2048)).npu().to(torch.float32)
            expected_offset = torch.randn((4, 2048)).npu().to(torch.float32)

            actuals = torch_npu.npu_dynamic_quant_asymmetric(input_npu, smooth_scales=smooth_scales_npu)
            actual_output, actual_scale, actual_offset = actuals

            pairs = (
                (actual_output, expected_output),
                (actual_scale, expected_scale),
                (actual_offset, expected_offset),
            )
            for actual, expected in pairs:
                self.assertEqual(actual.dtype, expected.dtype)
                self.assertEqual(actual.shape, expected.shape)
                self.assertEqual(actual.device, expected.device)


class TestGroupedMatmul(TestCase):
    def test_npu_grouped_matmul_meta_0(self):
        """fp16, three x / three weights / three outputs (split_item=0): res[i] is (x[i] rows, w[i] cols)."""
        with FakeTensorMode():
            torch.manual_seed(0)
            x = [torch.randn(*shape, dtype=torch.float16).npu()
                 for shape in ((256, 256), (1024, 256), (512, 1024))]
            w = [torch.randn(*shape, dtype=torch.float16).npu()
                 for shape in ((256, 256), (256, 1024), (1024, 128))]
            b = [torch.randn(size, dtype=torch.float16).npu() for size in (256, 1024, 128)]

            res = torch_npu.npu_grouped_matmul(x, w, bias=b, group_list=None, split_item=0, group_type=-1)
            for i, xi in enumerate(x):
                self.assertTrue(xi.shape[0] == res[i].shape[0])
            for i, wi in enumerate(w):
                self.assertTrue(wi.shape[1] == res[i].shape[1])

    def test_npu_grouped_matmul_meta_1(self):
        """fp16, one x split by group_list into three outputs (split_item=1)."""
        with FakeTensorMode():
            torch.manual_seed(0)
            x = [torch.randn(1792, 1024, dtype=torch.float16).npu()]
            w = [torch.randn(1024, cols, dtype=torch.float16).npu() for cols in (256, 1024, 128)]
            b = [torch.randn(size, dtype=torch.float16).npu() for size in (256, 1024, 128)]
            group_list = [256, 1280, 1792]

            res = torch_npu.npu_grouped_matmul(x, w, bias=b, group_list=group_list, split_item=1, group_type=-1)
            # The assertions treat group_list as cumulative end offsets, so
            # group i covers group_list[i] - group_list[i - 1] rows (from 0).
            previous_end = 0
            for i, end in enumerate(group_list):
                self.assertTrue(end - previous_end == res[i].shape[0])
                previous_end = end
            for i, wi in enumerate(w):
                self.assertTrue(wi.shape[1] == res[i].shape[1])

    def test_npu_grouped_matmul_meta_2(self):
        """fp16, three x / three weights merged into one output (split_item=2)."""
        with FakeTensorMode():
            torch.manual_seed(0)
            x = [torch.randn(*shape, dtype=torch.float16).npu()
                 for shape in ((256, 256), (1024, 256), (512, 1024))]
            w = [torch.randn(*shape, dtype=torch.float16).npu()
                 for shape in ((256, 256), (256, 256), (1024, 256))]
            b = [torch.randn(256, dtype=torch.float16).npu() for _ in range(3)]

            res = torch_npu.npu_grouped_matmul(x, w, bias=b, group_list=None, split_item=2, group_type=0)
            # The single output stacks every input's rows.
            total_rows = sum(xi.shape[0] for xi in x)
            self.assertTrue(total_rows == res[0].shape[0])
            self.assertTrue(w[0].shape[1] == res[0].shape[1])

    def test_npu_grouped_matmul_meta_3(self):
        """fp16, one x / three weights grouped by group_list (split_item=3); checks the first output."""
        with FakeTensorMode():
            torch.manual_seed(0)
            x = [torch.randn(1792, 1024, dtype=torch.float16).npu()]
            w = [torch.randn(1024, 256, dtype=torch.float16).npu() for _ in range(3)]
            b = [torch.randn(256, dtype=torch.float16).npu() for _ in range(3)]

            res = torch_npu.npu_grouped_matmul(
                x, w, bias=b, group_list=[256, 1280, 1792], split_item=3, group_type=0)
            out = res[0]
            self.assertTrue(x[0].shape[0] == out.shape[0])
            self.assertTrue(w[0].shape[1] == out.shape[1])

    def test_npu_grouped_matmul_meta_4(self):
        """fp16, one x / one stacked 3-D weight (split_item=3); output width is the weight's last dim."""
        with FakeTensorMode():
            torch.manual_seed(0)
            x = [torch.randn(1792, 1024, dtype=torch.float16).npu()]
            # Leading weight dim (3) matches len(group_list).
            w = [torch.randn(3, 1024, 256, dtype=torch.float16).npu()]
            b = [torch.randn(256, dtype=torch.float16).npu() for _ in range(3)]

            res = torch_npu.npu_grouped_matmul(
                x, w, bias=b, group_list=[256, 1280, 1792], split_item=3, group_type=0)
            out = res[0]
            self.assertTrue(x[0].shape[0] == out.shape[0])
            self.assertTrue(w[0].shape[2] == out.shape[1])

    def test_npu_grouped_matmul_meta_5(self):
        """int8 x / int32 weights, three of each (split_item=0)."""
        with FakeTensorMode():
            torch.manual_seed(0)
            x = [torch.randint(0, 10, shape, dtype=torch.int8).npu()
                 for shape in ((256, 256), (1024, 256), (512, 1024))]
            w = [torch.randint(0, 10, shape, dtype=torch.int32).npu()
                 for shape in ((256, 256), (256, 1024), (1024, 128))]
            b = [torch.randn(size, dtype=torch.float16).npu() for size in (256, 1024, 128)]

            res = torch_npu.npu_grouped_matmul(x, w, bias=b, group_list=None, split_item=0, group_type=-1)
            # Output width is expected to be 8x the int32 weight width
            # (presumably eight int4 values packed per int32 — TODO confirm).
            pack_factor = 8
            for i, xi in enumerate(x):
                self.assertTrue(xi.shape[0] == res[i].shape[0])
            for i, wi in enumerate(w):
                self.assertTrue((wi.shape[1] * pack_factor) == res[i].shape[1])

    def test_npu_grouped_matmul_meta_6(self):
        """int8 x / int32 weights, one x split by group_list into three outputs (split_item=1)."""
        with FakeTensorMode():
            torch.manual_seed(0)
            x = [torch.randint(0, 10, (1792, 1024), dtype=torch.int8).npu()]
            w = [torch.randint(0, 10, (1024, cols), dtype=torch.int32).npu() for cols in (256, 1024, 128)]
            b = [torch.randn(size, dtype=torch.float16).npu() for size in (256, 1024, 128)]
            group_list = [256, 1280, 1792]

            res = torch_npu.npu_grouped_matmul(x, w, bias=b, group_list=group_list, split_item=1, group_type=0)
            # group_list is treated as cumulative end offsets.
            previous_end = 0
            for i, end in enumerate(group_list):
                self.assertTrue(end - previous_end == res[i].shape[0])
                previous_end = end
            # Output width is expected to be 8x the int32 weight width.
            pack_factor = 8
            for i, wi in enumerate(w):
                self.assertTrue((wi.shape[1] * pack_factor) == res[i].shape[1])

    def test_npu_grouped_matmul_meta_7(self):
        """int8 x / int32 weights, three of each merged into one output (split_item=2)."""
        with FakeTensorMode():
            torch.manual_seed(0)
            x = [torch.randint(0, 10, shape, dtype=torch.int8).npu()
                 for shape in ((256, 256), (1024, 256), (512, 1024))]
            w = [torch.randint(0, 10, shape, dtype=torch.int32).npu()
                 for shape in ((256, 256), (256, 1024), (1024, 128))]
            b = [torch.randn(256, dtype=torch.float16).npu() for _ in range(3)]

            res = torch_npu.npu_grouped_matmul(x, w, bias=b, group_list=None, split_item=2, group_type=0)
            total_rows = sum(xi.shape[0] for xi in x)
            self.assertTrue(total_rows == res[0].shape[0])
            # Output width is expected to be 8x the int32 weight width.
            self.assertTrue((w[0].shape[1] * 8) == res[0].shape[1])

    def test_npu_grouped_matmul_meta_8(self):
        """int8 x / three int32 weights grouped by group_list (split_item=3); checks the first output."""
        with FakeTensorMode():
            torch.manual_seed(0)
            x = [torch.randint(0, 10, (1792, 1024), dtype=torch.int8).npu()]
            w = [torch.randint(0, 10, (1024, 256), dtype=torch.int32).npu() for _ in range(3)]
            b = [torch.randn(256, dtype=torch.float16).npu() for _ in range(3)]

            res = torch_npu.npu_grouped_matmul(
                x, w, bias=b, group_list=[256, 1280, 1792], split_item=3, group_type=0)
            out = res[0]
            self.assertTrue(x[0].shape[0] == out.shape[0])
            # Output width is expected to be 8x the int32 weight width.
            self.assertTrue((w[0].shape[1] * 8) == out.shape[1])

    def test_npu_grouped_matmul_meta_9(self):
        """int8 x / stacked int8 weight with int64 scale and tensor group_list, int32 output."""
        with FakeTensorMode():
            torch.manual_seed(0)
            x = [torch.randint(2, 3, size=(1792, 1024), dtype=torch.int8).npu()]
            w = [torch.randint(2, 3, size=(3, 1024, 256), dtype=torch.int8).npu()]
            scale = [torch.randint(2, 3, size=(3, 256), dtype=torch.int64).npu()]
            group_list = torch.tensor([256, 1280, 1792]).to(torch.int64).npu()

            res = torch_npu.npu_grouped_matmul(
                x, w, bias=None, scale=scale, group_list=group_list,
                split_item=3, group_type=0, output_dtype=torch.int32)
            out = res[0]
            self.assertTrue(x[0].shape[0] == out.shape[0])
            self.assertTrue(w[0].shape[2] == out.shape[1])

    def test_npu_grouped_matmul_meta_10(self):
        """Meta shape check (fp32): single x, multiple 2-D weights, single output; list group_list."""
        with FakeTensorMode():
            torch.manual_seed(0)
            x1 = torch.randn(1792, 1024, dtype=torch.float32).npu()
            x = [x1]
            w = [torch.randn(1024, 256, dtype=torch.float32).npu() for _ in range(3)]
            b = [torch.randn(256, dtype=torch.float32).npu() for _ in range(3)]
            group_list = [256, 1280, 1792]
            split_item = 3

            res = torch_npu.npu_grouped_matmul(x, w, bias=b, group_list=group_list, split_item=split_item, group_type=0)
            self.assertEqual(res[0].shape[0], x[0].shape[0])
            self.assertEqual(res[0].shape[1], w[0].shape[1])

    def test_npu_grouped_matmul_meta_11(self):
        """Meta shape check (fp32): single x, single 3-D weight, single output; three biases."""
        with FakeTensorMode():
            torch.manual_seed(0)
            x1 = torch.randn(1792, 1024, dtype=torch.float32).npu()
            x = [x1]
            w1 = torch.randn(3, 1024, 256, dtype=torch.float32).npu()
            w = [w1]
            b = [torch.randn(256, dtype=torch.float32).npu() for _ in range(3)]
            group_list = [256, 1280, 1792]
            split_item = 3

            res = torch_npu.npu_grouped_matmul(x, w, bias=b, group_list=group_list, split_item=split_item, group_type=0)
            # w1 is (3, 1024, 256); the expected output width is its last dim.
            self.assertEqual(res[0].shape[0], x[0].shape[0])
            self.assertEqual(res[0].shape[1], w[0].shape[2])

    def test_npu_grouped_matmul_meta_12(self):
        """Meta shape check (fp32): single x, multiple 2-D weights, single output; int64 tensor group_list."""
        with FakeTensorMode():
            torch.manual_seed(0)
            x1 = torch.randn(1792, 1024, dtype=torch.float32).npu()
            x = [x1]
            w = [torch.randn(1024, 256, dtype=torch.float32).npu() for _ in range(3)]
            b = [torch.randn(256, dtype=torch.float32).npu() for _ in range(3)]
            group_list = torch.tensor([256, 1280, 1792]).to(torch.int64).npu()
            split_item = 3
            group_type = 0

            res = torch_npu.npu_grouped_matmul(x, w, bias=b, group_list=group_list, split_item=split_item, group_type=group_type)
            self.assertEqual(res[0].shape[0], x[0].shape[0])
            self.assertEqual(res[0].shape[1], w[0].shape[1])

    def test_npu_grouped_matmul_meta_950_fp8_1(self):
        """Meta shape check: float8_e4m3fn x and 3-D weight with an int64 scale, float16 output."""
        with FakeTensorMode():
            torch.manual_seed(0)
            # float8 inputs are created as int8 and reinterpreted via .view().
            x1 = torch.randint(2, 3, size=(16, 256), dtype=torch.int8).view(torch.float8_e4m3fn).npu()
            x = [x1]
            w1 = torch.randint(2, 3, size=(1, 256, 32), dtype=torch.int8).view(torch.float8_e4m3fn).npu()
            w = [w1]
            group_list = torch.tensor([16]).to(torch.int64).npu()
            split_item = 2
            scale1 = torch.randint(2, 3, size=(1, 32), dtype=torch.int64).npu()
            scale = [scale1]
            res = torch_npu.npu_grouped_matmul(x, w, bias=None, scale=scale, group_list=group_list, split_item=split_item, group_type=0, output_dtype=torch.float16)
            # Expect (x rows, weight last dim).
            self.assertEqual(res[0].shape[0], x1.shape[0])
            self.assertEqual(res[0].shape[1], w1.shape[2])

    def test_npu_grouped_matmul_meta_950_fp8_group_type_2(self):
        """Meta shape check for group_type=2 with float8 inputs and a per-token scale.

        x is passed as the transpose of a (128, 4096) tensor and the weight is
        (128, 7168). group_list [32, 32, 64] sums to the shared 128 dim. The test
        expects a 3-D float16 result of shape (3, 4096, 7168).
        """
        with FakeTensorMode():
            torch.manual_seed(0)
            x1 = torch.randint(2, 3, size=(128, 4096), dtype=torch.int8).view(torch.float8_e4m3fn).npu()
            x = [x1.t()]
            w1 = torch.randint(2, 3, size=(128, 7168), dtype=torch.int8).view(torch.float8_e4m3fn).npu()
            w = [w1]
            group_list = torch.tensor([32, 32, 64]).to(torch.int64).npu()
            split_item = 2
            scale1 = torch.randn(3, 56, dtype=torch.float32).npu()
            scale2 = torch.randn(3, 4096, dtype=torch.float32).npu()
            scale = [scale1]
            per_token_scale = [scale2.t()]
            res = torch_npu.npu_grouped_matmul(x, w, bias=None, scale=scale, per_token_scale=per_token_scale, group_list=group_list, split_item=split_item, group_type=2, group_list_type=0, output_dtype=torch.float16)
            out = res[0]
            self.assertEqual(out.dtype, torch.float16)
            self.assertEqual(out.shape[0], 3)
            self.assertEqual(len(out.shape), 3)
            self.assertEqual(out.shape[1], x1.shape[1])
            self.assertEqual(out.shape[2], w1.shape[1])

    def test_npu_grouped_matmul_meta_weight_float32_transpose(self):
        """Meta shape check: float8 x with a transposed float32 weight and uint8 scales.

        The weight is created as (E, N, K) and passed as its last-two-dims
        transpose; the antiquant scale gets the same transpose. group_list_type=1
        is used with group_list [8, 8], whose entries sum to M. The test expects an
        output of shape (M, N).
        """
        with FakeTensorMode():
            torch.manual_seed(0)
            K = 1024
            M = 16
            N = 256
            E = 2
            group_size = 32
            x1 = torch.randint(0, 10, (M, K), dtype=torch.int8).view(torch.float8_e4m3fn).npu()
            x = [x1]
            w1 = torch.randint(0, 10, (E, N, K), dtype=torch.float32).npu()
            w = [w1.transpose(-1, -2)]
            antiquant_scale1 = torch.randint(0, 10, (E, N, K // group_size), dtype=torch.uint8).npu()
            antiquant_scale = [antiquant_scale1.transpose(-1, -2)]
            per_token_scale1 = torch.randint(0, 10, (M, K // group_size), dtype=torch.uint8).npu()
            per_token_scale = [per_token_scale1]
            group_list = [8, 8]
            split_item = 3

            res = torch_npu.npu_grouped_matmul(x, w, antiquant_scale=antiquant_scale, per_token_scale=per_token_scale, group_list_type=1, group_list=group_list, split_item=split_item, group_type=0)
            self.assertEqual(res[0].shape[0], x1.shape[0])
            self.assertEqual(res[0].shape[1], w1.shape[1])

    def test_npu_grouped_matmul_meta_quant_empty_tensor_m0(self):
        """An x with zero rows (M == 0, all-zero group_list) should produce an empty output."""
        with FakeTensorMode():
            torch.manual_seed(0)
            K = 128
            M = 0
            N = 128
            E = 2
            x1 = torch.randint(0, 10, (M, K), dtype=torch.int8).view(torch.float8_e4m3fn).npu()
            # NOTE(review): weight is built as (E, N, K); N == K here, so the layout is not exercised.
            x2 = torch.randint(0, 10, (E, N, K), dtype=torch.int8).view(torch.float8_e5m2).npu()
            x = [x1]
            w = [x2]
            group_list = torch.tensor([0, 0]).to(torch.int64).npu()
            split_item = 2
            scale2 = torch.randint(0, 10, (E, N), dtype=torch.int64).npu()
            scale = [scale2]
            res = torch_npu.npu_grouped_matmul(x, w, bias=None, scale=scale, group_list=group_list, split_item=split_item, group_type=0, group_list_type=0, output_dtype=torch.float16)
            self.assertEqual(len(res[0]), 0)

    def test_npu_grouped_matmul_meta_quant_empty_tensor_k0(self):
        """A zero-length K dimension is expected to be rejected with a RuntimeError."""
        with FakeTensorMode():
            torch.manual_seed(0)
            M, N, K, E = 128, 128, 0, 2
            x_fp8 = torch.randint(0, 10, (M, K), dtype=torch.int8).view(torch.float8_e4m3fn).npu()
            weight = torch.randint(0, 10, (E, K, N), dtype=torch.int8).view(torch.float8_e5m2).npu()
            group_list = torch.tensor([0, 0]).to(torch.int64).npu()
            weight_scale = torch.randint(0, 10, (E, N), dtype=torch.int64).npu()
            call_kwargs = dict(
                bias=None,
                scale=[weight_scale],
                group_list=group_list,
                split_item=2,
                group_type=0,
                group_list_type=0,
                output_dtype=torch.float16,
            )
            with self.assertRaises(RuntimeError):
                torch_npu.npu_grouped_matmul([x_fp8], [weight], **call_kwargs)

class TestQuantMatmul(TestCase):
    """FakeTensor (meta) shape/dtype checks for ``torch_npu.npu_quant_matmul``."""

    def test_npu_quant_matmul_meta(self):
        """Check the inferred output shape (and dtype where set) across several input configurations."""
        with FakeTensorMode():
            # int8 x int8, float32 scale/offset, int32 bias -> int8 output (no output_dtype passed).
            x1 = torch.randint(-1, 1, (1, 1, 1024), dtype=torch.int8).npu()
            x2 = torch.randint(-1, 1, (1, 1024, 100), dtype=torch.int8).npu()
            scale = torch.randn(1, dtype=torch.float32).npu()
            offset = torch.randn(1, dtype=torch.float32).npu()
            bias = torch.randint(-1, 1, (1, 1, 100), dtype=torch.int32).npu()
            res = torch_npu.npu_quant_matmul(x1, x2, scale, offset=offset, bias=bias)
            self.assertEqual(res.shape, torch.Size([1, 1, 100]))
            self.assertEqual(res.dtype, torch.int8)

            # bfloat16 scale and bias with bfloat16 output.
            scale_bf16 = torch.randn(1, dtype=torch.bfloat16).npu()
            bias_bf16 = torch.randint(-1, 1, (1, 1, 100), dtype=torch.bfloat16).npu()
            res_bf16 = torch_npu.npu_quant_matmul(x1, x2, scale_bf16, offset=None, bias=bias_bf16, output_dtype=torch.bfloat16)
            self.assertEqual(res_bf16.shape, torch.Size([1, 1, 100]))
            self.assertEqual(res_bf16.dtype, torch.bfloat16)

            # Per-token scale with float32 bias and float16 output.
            bias_fp32 = torch.randint(-1, 1, (1, 1, 100), dtype=torch.float32).npu()
            pertoken_scale = torch.randn(1, dtype=torch.float32).npu()
            res_fp16 = torch_npu.npu_quant_matmul(x1, x2, scale, offset=None, pertoken_scale=pertoken_scale,
                                                  bias=bias_fp32, output_dtype=torch.float16)
            self.assertEqual(res_fp16.shape, torch.Size([1, 1, 100]))
            self.assertEqual(res_fp16.dtype, torch.float16)

            # int32 bias with int32 output.
            bias_int32 = torch.randint(-1, 1, (1, 1, 100), dtype=torch.int32).npu()
            res_int32 = torch_npu.npu_quant_matmul(x1, x2, scale, offset=None, pertoken_scale=None,
                                                   bias=bias_int32, output_dtype=torch.int32)
            self.assertEqual(res_int32.shape, torch.Size([1, 1, 100]))
            self.assertEqual(res_int32.dtype, torch.int32)

            # int32 x1/x2: the expected N of 40 is 5 * 8, so each int32 element appears
            # to pack eight int4 values (TODO confirm against the op spec).
            x1 = torch.randint(-1, 1, (16, 8), dtype=torch.int32).npu()
            x2 = torch.randint(-1, 1, (64, 5), dtype=torch.int32).npu()
            scale = torch.randn(1, dtype=torch.float32).npu()
            bias = torch.randint(-1, 1, (40,), dtype=torch.int32).npu()
            res = torch_npu.npu_quant_matmul(x1, x2, scale, offset=None, bias=bias, output_dtype=torch.float16)
            self.assertEqual(res.shape, torch.Size([16, 40]))
            self.assertEqual(res.dtype, torch.float16)

            # Same int32 input shapes plus a per-token scale, with bfloat16 output.
            x1 = torch.randint(-1, 1, (16, 8), dtype=torch.int32).npu()
            x2 = torch.randint(-1, 1, (64, 5), dtype=torch.int32).npu()
            pertoken_scale = torch.randn(16, dtype=torch.float32).npu()
            scale = torch.randn(1, dtype=torch.float32).npu()
            bias = torch.randint(-1, 1, (40,), dtype=torch.int32).npu()
            res = torch_npu.npu_quant_matmul(x1, x2, scale, offset=None, pertoken_scale=pertoken_scale, bias=bias, output_dtype=torch.bfloat16)
            self.assertEqual(res.shape, torch.Size([16, 40]))
            self.assertEqual(res.dtype, torch.bfloat16)

            # int8 x1 with int32 x2 and a per-group int64 x2 scale (group_sizes [0, 0, 256]).
            x1 = torch.randint(-8, 8, (1, 8192), dtype=torch.int8).npu()
            x2 = torch.randint(-8, 8, (8192, 128), dtype=torch.int32).npu()
            y_offset = torch.randn([128 * 8, ], dtype=torch.float32).npu()
            x1_scale = torch.randn([1, 1], dtype=torch.float32).npu()
            x2_scale = torch.randint(1, 3, [8192 // 256, 128 * 8], dtype=torch.int64).npu()
            group_size_list = [0, 0, 256]
            res = torch_npu.npu_quant_matmul(x1, x2, x2_scale, offset=y_offset, pertoken_scale=x1_scale,
                                             bias=None, output_dtype=torch.float16, group_sizes=group_size_list)
            self.assertEqual(res.shape, torch.Size([1, 128 * 8]))
            self.assertEqual(res.dtype, torch.float16)

            # float8_e4m3fn x1 with float32 x2, uint8 scale and int64 y_scale (group_sizes [0, 0, 32]).
            x1 = torch.randint(-1, 1, (1, 64), dtype=torch.float8_e4m3fn).npu()
            x2 = torch.randn((64, 16), dtype=torch.float32).npu()
            scale = torch.randint(-1, 1, (2, 128), dtype=torch.uint8).npu()
            y_scale = torch.randint(-1, 1, (1, 128), dtype=torch.int64).npu()
            res = torch_npu.npu_quant_matmul(x1, x2, scale, pertoken_scale=None, scale_dtype=None, pertoken_scale_dtype=None, y_scale=y_scale, output_dtype=torch.bfloat16, group_sizes=[0, 0, 32])
            self.assertEqual(res.shape, torch.Size([1, 128]))
            self.assertEqual(res.dtype, torch.bfloat16)

            # hifloat8 input/output dtypes: only the output shape is checked.
            x1 = torch.randint(-8, 8, (1, 8192), dtype=torch.int8).npu()
            x2 = torch.randint(-8, 8, (8192, 128), dtype=torch.int8).npu()
            scale = torch.randn(1, dtype=torch.float32).npu()
            res = torch_npu.npu_quant_matmul(x1, x2, scale, offset=None, pertoken_scale=None,
                                             bias=None, output_dtype=torch_npu.hifloat8, x1_dtype=torch_npu.hifloat8,
                                             x2_dtype=torch_npu.hifloat8)
            self.assertEqual(res.shape, torch.Size([1, 128]))


class TestQuantMatmulReduceSum(TestCase):
    """FakeTensor (meta) check for ``torch_npu.npu_quant_matmul_reduce_sum``."""

    def test_quant_matmul_reduce_sum(self):
        """Batched int8 (b, m, k) and (b, k, n) inputs; the test expects an (m, n) bfloat16 result."""
        with FakeTensorMode():
            b, m, n, k = (2, 3, 4, 5)
            x1 = torch.randint(-1, 1, (b, m, k), dtype=torch.int8).npu()
            x2 = torch.randint(-1, 1, (b, k, n), dtype=torch.int8).npu()
            x1_scale = torch.ones((b, m), dtype=torch.float32).npu()
            x2_scale = torch.ones((n,), dtype=torch.bfloat16).npu()
            y = torch_npu.npu_quant_matmul_reduce_sum(x1, x2, x1_scale=x1_scale, x2_scale=x2_scale)
            self.assertEqual(y.shape[0], m)
            self.assertEqual(y.shape[1], n)
            self.assertEqual(y.dtype, torch.bfloat16)


class TestQuantMatmulGelu(TestCase):
    """FakeTensor (meta) shape/dtype checks for ``torch_npu.npu_quant_matmul_gelu``."""

    @unittest.skip("Skipping due to outdated CANN version; please update CANN to the latest version and remove this skip")
    def test_npu_quant_matmul_gelu_meta(self):
        with FakeTensorMode():
            # A8W8: int8 x int8 with float32 scales -> float16 output.
            m, k, n = 128, 256, 512
            x1 = torch.randint(-5, 5, (m, k), dtype=torch.int8).npu()
            x2 = torch.randint(-5, 5, (k, n), dtype=torch.int8).npu()
            x1_scale = torch.randn(m, dtype=torch.float32).abs() * 0.01
            x2_scale = torch.randn(n, dtype=torch.float32).abs() * 0.01

            res = torch_npu.npu_quant_matmul_gelu(x1, x2, x1_scale.npu(), x2_scale.npu(), approximate="gelu_tanh")
            self.assertEqual(res.shape, torch.Size([m, n]))
            self.assertEqual(res.dtype, torch.float16)

            # A8W8 with a float32 bias.
            bias = torch.randn(n, dtype=torch.float32) * 0.1
            res_bias = torch_npu.npu_quant_matmul_gelu(x1, x2, x1_scale.npu(), x2_scale.npu(), bias=bias.npu(), approximate="gelu_erf")
            self.assertEqual(res_bias.shape, torch.Size([m, n]))
            self.assertEqual(res_bias.dtype, torch.float16)

            # A8W8 with a bfloat16 x2 scale -> bfloat16 output.
            x2_scale_bf16 = torch.randn(n, dtype=torch.bfloat16).abs() * 0.01
            res_bf16 = torch_npu.npu_quant_matmul_gelu(x1, x2, x1_scale.npu(), x2_scale_bf16.npu(), approximate="gelu_tanh")
            self.assertEqual(res_bf16.shape, torch.Size([m, n]))
            self.assertEqual(res_bf16.dtype, torch.bfloat16)

            # A8W8 with a leading batch dimension.
            batch, m, k, n = 4, 64, 128, 256
            x1_batch = torch.randint(-5, 5, (batch, m, k), dtype=torch.int8).npu()
            x2_batch = torch.randint(-5, 5, (batch, k, n), dtype=torch.int8).npu()
            x1_scale_batch = torch.randn(m, dtype=torch.float32).abs() * 0.01
            x2_scale_batch = torch.randn(n, dtype=torch.float32).abs() * 0.01
            res_batch = torch_npu.npu_quant_matmul_gelu(x1_batch, x2_batch, x1_scale_batch.npu(), x2_scale_batch.npu())
            self.assertEqual(res_batch.shape, torch.Size([batch, m, n]))
            self.assertEqual(res_batch.dtype, torch.float16)

            # A4W4: int32 inputs with k and n divided by 8, presumably packing eight
            # int4 values per element (TODO confirm against the op spec).
            m, k, n = 64, 128, 256
            k_packed = k // 8
            n_packed = n // 8
            x1_int32 = torch.randint(-8, 8, (m, k_packed), dtype=torch.int32).npu()
            x2_int32 = torch.randint(-8, 8, (k_packed, n_packed), dtype=torch.int32).npu()
            x1_scale_int32 = torch.randn(m, dtype=torch.float32).abs() * 0.01
            x2_scale_int32 = torch.randn(n, dtype=torch.float32).abs() * 0.01
            res_int32 = torch_npu.npu_quant_matmul_gelu(x1_int32, x2_int32, x1_scale_int32.npu(), x2_scale_int32.npu(), approximate="gelu_tanh")
            self.assertEqual(res_int32.shape, torch.Size([m, n]))
            self.assertEqual(res_int32.dtype, torch.float16)

            # A4W4 with an int32 bias.
            bias_int32 = torch.randint(-5, 5, (n,), dtype=torch.int32)
            res_int32_bias = torch_npu.npu_quant_matmul_gelu(x1_int32, x2_int32, x1_scale_int32.npu(), x2_scale_int32.npu(),
                                                             bias=bias_int32.npu(), approximate="gelu_erf")
            self.assertEqual(res_int32_bias.shape, torch.Size([m, n]))
            self.assertEqual(res_int32_bias.dtype, torch.float16)


class TestMatmulCompressDequant(TestCase):
    """FakeTensor (meta) check for ``torch_npu.npu_matmul_compress_dequant``."""

    @unittest.skip("Skip until CANN supports aclnnMatmulCompressDequant; do not execute")
    def test_npu_matmul_compress_dequant_meta(self):
        """An int8 (m, k) x1 and a 1-D compressed x2 should give a float16 (m, n) result."""
        with FakeTensorMode():
            m, k, n = 16, 256, 128
            x1 = torch.ones((m, k), dtype=torch.int8).npu()
            # x2 is the compressed weight: a 1-D tensor whose length is less than k * n.
            compressed_len = k * n - 1
            x2 = torch.zeros((compressed_len,), dtype=torch.int8).npu()
            compress_index = torch.zeros(8, dtype=torch.int8).npu()
            bias = torch.zeros((1, n), dtype=torch.int32).npu()
            scale_fp32 = torch.ones((1, n), dtype=torch.float32).npu()
            # The float scale goes through npu_trans_quant_param and is then cast to uint64.
            scale_u64 = torch_npu.npu_trans_quant_param(scale_fp32).to(torch.uint64)
            out = torch_npu.npu_matmul_compress_dequant(x1, x2, compress_index, bias, scale_u64)
            self.assertEqual(out.shape, (m, n))
            self.assertEqual(out.dtype, torch.float16)


class TestRecurrentGatedDeltaRule(TestCase):
    """FakeTensor (meta) checks for the recurrent gated delta rule ops."""

    def test_recurrent_gated_delta_rule(self):
        """Check output shapes/dtypes of the plain op and its ``_functional`` variant."""
        with FakeTensorMode():
            (b, mtp, nk, nv, dk, dv) = (64, 2, 4, 8, 128, 128)

            actual_seq_lengths = (torch.ones(b) * mtp).npu().to(torch.int32)
            T = b * mtp
            state = torch.rand((T, nv, dv, dk), dtype=torch.bfloat16).npu()
            query = torch.rand((T, nk, dk), dtype=torch.bfloat16).npu()
            key = torch.rand((T, nk, dk), dtype=torch.bfloat16).npu()
            value = torch.rand((T, nv, dv), dtype=torch.bfloat16).npu()
            g = torch.rand((T, nv), dtype=torch.float32).npu()
            beta = torch.rand((T, nv), dtype=torch.bfloat16).npu()
            ssm_state_indices = (torch.arange(T).npu()).to(torch.int32)
            query = torch.nn.functional.normalize(query, p=2, dim=-1)
            key = torch.nn.functional.normalize(key, p=2, dim=-1)
            scale = 0.5
            num_accepted_tokens = torch.randint(1, mtp + 1, (b,)).npu().to(torch.int32)

            # Each op call gets its own clone, so `state` itself is never passed to an op.
            state_copy = state.clone()
            out = torch_npu.npu_recurrent_gated_delta_rule(query, key, value, state_copy, beta=beta, scale=scale, actual_seq_lengths=actual_seq_lengths, ssm_state_indices=ssm_state_indices, g=g, num_accepted_tokens=num_accepted_tokens)
            self.assertEqual(out.shape, torch.Size([T, nv, dv]))
            self.assertEqual(out.dtype, torch.bfloat16)

            # The functional variant returns two tensors: an output and a state-shaped tensor.
            state_copy = state.clone()
            out_func, state_out = torch_npu.npu_recurrent_gated_delta_rule_functional(query, key, value, state_copy, beta=beta, scale=scale, actual_seq_lengths=actual_seq_lengths, ssm_state_indices=ssm_state_indices, g=g, num_accepted_tokens=num_accepted_tokens)

            self.assertEqual(out_func.shape, torch.Size([T, nv, dv]))
            self.assertEqual(state_out.shape, torch.Size([T, nv, dv, dk]))

            self.assertEqual(state_out.dtype, torch.bfloat16)
            self.assertEqual(out_func.dtype, torch.bfloat16)


class TestTranQuantParam(TestCase):
    """FakeTensor (meta) checks for ``torch_npu.npu_trans_quant_param``."""

    def test_npu_trans_quant_param_meta(self):
        """The result should be int64, shaped after the offset/scale inputs."""
        with FakeTensorMode():
            # Single-element scale with a length-4 offset -> shape (4,).
            scale_1d = torch.randn(1, dtype=torch.float32).npu()
            offset_1d = torch.randn(4, dtype=torch.float32).npu()
            res = torch_npu.npu_trans_quant_param(scale_1d, offset_1d)
            self.assertEqual(res.shape, torch.Size([4]))
            self.assertEqual(res.dtype, torch.int64)

            # Matching (1, 4) scale and offset -> shape (1, 4).
            scale_2d = torch.randn(1, 4, dtype=torch.float32).npu()
            offset_2d = torch.randn(1, 4, dtype=torch.float32).npu()
            res = torch_npu.npu_trans_quant_param(scale_2d, offset_2d)
            self.assertEqual(res.shape, torch.Size([1, 4]))
            self.assertEqual(res.dtype, torch.int64)


class TestAttentionUpdate(TestCase):
    """FakeTensor (meta) check for ``torch_npu.npu_attention_update``."""

    def test_npu_attention_update_meta(self):
        """K pairs of lse [N] and local_out [N, H] should give an [N, H] output of the same dtype."""
        with FakeTensorMode():
            N, H, K = 8, 64, 2
            update_type = 0
            dtype = torch.float32
            lse_list = [torch.randn([N], dtype=dtype).npu() for _ in range(K)]
            local_out_list = [torch.randn([N, H], dtype=dtype).npu() for _ in range(K)]
            out, lse_out = torch_npu.npu_attention_update(lse_list, local_out_list, update_type)
            # NOTE(review): lse_out is returned but its shape/dtype are not checked here.
            self.assertEqual(out.shape, torch.Size([N, H]))
            self.assertEqual(out.dtype, dtype)


class TestRingAttentionUpdate(TestCase):
    """FakeTensor (meta) checks for ``torch_npu.npu_ring_attention_update``."""

    @staticmethod
    def _abs_fp32_npu(shape):
        # Non-negative float32 tensor of the given shape, moved to the NPU.
        return torch.randn(shape, dtype=torch.float32).abs().npu()

    def test_npu_ring_attention_update_meta_sbh(self):
        """No layout arguments: outputs should mirror the previous-step inputs' shapes."""
        with FakeTensorMode():
            attn_shape, stats_shape = (4, 2, 32), (2, 2, 4, 8)
            prev_attn_out = torch.randn(attn_shape, dtype=torch.float16).npu()
            cur_attn_out = torch.randn(attn_shape, dtype=torch.float16).npu()
            prev_softmax_max, prev_softmax_sum, cur_softmax_max, cur_softmax_sum = [
                self._abs_fp32_npu(stats_shape) for _ in range(4)
            ]

            attn_out, softmax_max, softmax_sum = torch_npu.npu_ring_attention_update(
                prev_attn_out, prev_softmax_max, prev_softmax_sum,
                cur_attn_out, cur_softmax_max, cur_softmax_sum,
            )

            self.assertEqual(attn_out.shape, prev_attn_out.shape)
            self.assertEqual(attn_out.dtype, prev_attn_out.dtype)
            self.assertEqual(attn_out.device.type, "npu")
            self.assertEqual(softmax_max.shape, prev_softmax_max.shape)
            self.assertEqual(softmax_max.dtype, torch.float32)
            self.assertEqual(softmax_sum.shape, prev_softmax_sum.shape)
            self.assertEqual(softmax_sum.dtype, torch.float32)

    def test_npu_ring_attention_update_meta_tnd(self):
        """TND layout with actual_seq_qlen: outputs should mirror the previous-step inputs' shapes."""
        with FakeTensorMode():
            attn_shape, stats_shape = (5, 2, 64), (5, 2, 8)
            prev_attn_out = torch.randn(attn_shape, dtype=torch.bfloat16).npu()
            cur_attn_out = torch.randn(attn_shape, dtype=torch.bfloat16).npu()
            prev_softmax_max, prev_softmax_sum, cur_softmax_max, cur_softmax_sum = [
                self._abs_fp32_npu(stats_shape) for _ in range(4)
            ]
            actual_seq_qlen = torch.tensor([2, 5], dtype=torch.int64).npu()

            attn_out, softmax_max, softmax_sum = torch_npu.npu_ring_attention_update(
                prev_attn_out, prev_softmax_max, prev_softmax_sum,
                cur_attn_out, cur_softmax_max, cur_softmax_sum,
                actual_seq_qlen=actual_seq_qlen,
                input_layout="TND",
                input_softmax_layout="TND",
            )

            self.assertEqual(attn_out.shape, prev_attn_out.shape)
            self.assertEqual(attn_out.dtype, prev_attn_out.dtype)
            self.assertEqual(attn_out.device.type, "npu")
            self.assertEqual(softmax_max.shape, prev_softmax_max.shape)
            self.assertEqual(softmax_sum.shape, prev_softmax_sum.shape)


class TestAntiQuant(TestCase):
    """FakeTensor (meta) check for ``torch_npu.npu_anti_quant``."""

    @unittest.skipIf(torch.__version__ < '2.1.0',
                     "OP `AntiQuant` is supported on torch v2.1 and above, skip this test for torch version below 2.1")
    def test_npu_anti_quant_meta(self):
        """int8 -> float16 anti-quant should keep the shape and match a float16 tensor's byte size."""
        with FakeTensorMode():
            x = torch.randint(low=-128, high=127, size=(20, 100), dtype=torch.int8).npu()
            scale = torch.randn(100, dtype=torch.float).npu()
            offset = torch.randn(100, dtype=torch.float).npu()
            dst_dtype = torch.float16
            res = torch_npu.npu_anti_quant(x, scale, offset=offset, dst_dtype=dst_dtype)

            self.assertEqual(res.shape, x.shape)
            # x converted to dst_dtype gives the expected storage footprint of the result.
            x_dst = x.to(dst_dtype)
            self.assertEqual(res.numel() * res.element_size(), x_dst.numel() * x_dst.element_size())

class TestNpuKroneckerQuant(TestCase):
    """FakeTensor (meta) check for ``torch_npu.npu_kronecker_quant``."""

    def test_npu_kronecker_quant_meta(self):
        """A (16, 64, 64) fp16 input should yield an int32 (16, 64, 8) output and a float (16,) scale."""
        with FakeTensorMode():
            x = torch.randn(16, 64, 64).half().npu()
            p1, p2 = [torch.randn(64, 64).half().npu() for _ in range(2)]
            out, quant_scale = torch_npu.npu_kronecker_quant(x, p1, p2)

            self.assertEqual(out.dtype, torch.int32)
            self.assertEqual(quant_scale.dtype, torch.float)
            self.assertEqual(out.shape, torch.Size([16, 64, 8]))
            self.assertEqual(quant_scale.shape, torch.Size([16]))

class TestNpuLinear(TestCase):
    """FakeTensor (meta) check for ``torch_npu.npu_linear``."""

    def test_npu_linear_meta(self):
        """npu_linear should keep the input dtype and match F.linear's output shape."""
        with FakeTensorMode():
            inp = torch.randn(16, 128).npu()
            weight = torch.randn(32, 128).npu()
            bias = torch.randn(32,).npu()
            result = torch_npu.npu_linear(inp, weight, bias)

            self.assertEqual(result.dtype, inp.dtype)
            reference = torch.nn.functional.linear(inp, weight, bias)
            self.assertEqual(result.shape, reference.shape)


class TestMoeFinalizeRouting(TestCase):
    """FakeTensor (meta) check for ``torch_npu.npu_moe_finalize_routing``."""

    def test_npu_moe_finalize_routing_meta(self):
        """The finalized result should match skip1's shape and dtype."""
        with FakeTensorMode():
            num_rows = 50
            top_k = 4
            token_len = 10
            expert_num = 16
            # NOTE(review): these inputs are never moved with .npu(), unlike most tests here; confirm intended.
            expanded_permuted_rows = torch.randn(num_rows * top_k, token_len).to(torch.float32)
            skip1 = torch.randn(num_rows, token_len).to(torch.float32)
            skip2_optional = torch.randn(num_rows, token_len).to(torch.float32)
            bias = torch.randn(num_rows, top_k).to(torch.float32)
            scales = torch.randn(num_rows, top_k).to(torch.float32)
            expanded_src_to_dst_row = torch.arange(num_rows * top_k).to(torch.int32)
            expert_for_source_row = torch.randint(low=0, high=expert_num, size=(num_rows, top_k)).to(torch.int32)

            result = torch_npu.npu_moe_finalize_routing(expanded_permuted_rows, skip1, skip2_optional, bias, scales,
                                                        expanded_src_to_dst_row, expert_for_source_row)

            self.assertEqual(result.shape, skip1.shape)
            self.assertEqual(result.dtype, skip1.dtype)


class TestGMMFinalizeRouting(TestCase):
    """Fake-tensor checks for ``torch_npu.npu_grouped_matmul_finalize_routing``.

    Every case only inspects metadata: the output is expected to be a float32
    tensor of shape ``(output_bs, n)``.
    """

    @staticmethod
    def _routing_inputs(batch, group_num, top_k, m):
        """Build the ``(logit, row_index)`` routing tensors shared by every case.

        A random ``(batch, group_num)`` score matrix is ranked per row and the last
        ``top_k`` columns are taken as each row's experts. ``logit`` is the float32
        softmax over those selected scores flattened to ``(m,)``, so callers must
        pass ``m == batch * top_k``. ``row_index`` is int64.
        """
        logit_ori = torch.normal(0, 0.1, (batch, group_num), dtype=torch.float32)
        routing = torch.argsort(logit_ori, 1)[:, -top_k:]
        logit = torch.nn.functional.softmax(
            logit_ori[torch.arange(batch).reshape(-1, 1).repeat(1, top_k), routing],
            dim=1,
            dtype=torch.float32
        ).reshape(m)
        row_index = (torch.argsort(routing.reshape(-1)) // top_k).to(torch.int64)
        return logit, row_index

    def _assert_output(self, result, output_bs, n):
        """Assert ``result`` is a float32 tensor of shape ``(output_bs, n)``."""
        self.assertEqual(result.shape, torch.Size([output_bs, n]))
        self.assertEqual(result.dtype, torch.float32)

    def _run_a8w4(self):
        """Run the int8-activation / int32-weight (a8w4) case; call under FakeTensorMode.

        Returns ``(result, output_bs, n)`` suitable for :meth:`_assert_output`.
        """
        m, k, n, batch, top_k, group_num = 8, 2048, 7168, 1, 8, 8
        quant_group_size = k
        num_quant_groups = k // quant_group_size
        x = torch.randint(-10, 10, (m, k), dtype=torch.int8)
        # int32 weight with n // 8 columns — presumably eight int4 values packed per
        # int32 (the "w4" in a8w4); TODO confirm against the op documentation.
        weight = torch.randint(-10, 10, (group_num, k, n // 8), dtype=torch.int32)

        # Round the scale through float16, reinterpret its float32 bits as uint32,
        # write them into the even uint32 slots of a zeroed buffer twice as wide,
        # then reinterpret that buffer as int64 (one int64 per output channel).
        scale_np = np.random.normal(0, 0.01, (group_num, 1, n)).astype(np.float32)
        per_group_scale = np.ones([group_num, num_quant_groups, n]).astype(np.float32)
        scale_u32 = (scale_np * per_group_scale).astype(np.float16).astype(np.float32).view(np.uint32)
        scale_u64 = np.zeros((group_num, num_quant_groups, n * 2), dtype=np.uint32)
        scale_u64[..., ::2] = scale_u32
        scale = torch.from_numpy(scale_u64.view(np.int64))

        bias = torch.normal(0, 0.01, (group_num, n), dtype=torch.float32)
        offset = torch.randint(-5, 5, (group_num, num_quant_groups, n), dtype=torch.float32)
        pertoken_scale = torch.normal(0, 0.01, (m, ), dtype=torch.float32)
        group_list = torch.tensor([batch] * group_num, dtype=torch.float32)
        shared_input = torch.normal(0, 0.1, (batch // 4, n), dtype=torch.bfloat16)
        logit, row_index = self._routing_inputs(batch, group_num, top_k, m)
        output_bs = batch
        result = torch_npu.npu_grouped_matmul_finalize_routing(
            x.npu(), weight.npu(), group_list.npu(), scale=scale.npu(),
            bias=bias.npu(), offset=offset.npu(),
            pertoken_scale=pertoken_scale.npu(), shared_input=shared_input.npu(),
            logit=logit.npu(), row_index=row_index.npu(),
            shared_input_offset=batch // 2, output_bs=output_bs
        ).to("cpu")
        return result, output_bs, n

    def test_npu_grouped_matmul_finalise_routing_meta(self):
        with FakeTensorMode():
            m, k, n, batch, top_k, group_num = 576, 2048, 7168, 72, 8, 8
            x = torch.randint(-10, 10, (m, k), dtype=torch.int8)
            weight = torch.randint(-10, 10, (group_num, k, n), dtype=torch.int8)
            scale = torch.normal(0, 0.01, (group_num, n), dtype=torch.float32)
            pertoken_scale = torch.normal(0, 0.01, (m, ), dtype=torch.float32)
            group_list = torch.tensor([batch] * group_num, dtype=torch.float32)
            shared_input = torch.normal(0, 0.1, (batch // 4, n), dtype=torch.bfloat16)
            logit, row_index = self._routing_inputs(batch, group_num, top_k, m)
            output_bs = batch
            result = torch_npu.npu_grouped_matmul_finalize_routing(
                x.npu(), weight.npu(), group_list.npu(), scale=scale.npu(),
                pertoken_scale=pertoken_scale.npu(), shared_input=shared_input.npu(),
                logit=logit.npu(), row_index=row_index.npu(),
                shared_input_offset=batch // 2, output_bs=output_bs
            ).to("cpu")
            self._assert_output(result, output_bs, n)

    def test_npu_grouped_matmul_finalise_routing_a8w4_meta(self):
        with FakeTensorMode():
            result, output_bs, n = self._run_a8w4()
            self._assert_output(result, output_bs, n)

    def test_npu_grouped_matmul_finalise_routing_a8w4_with_tuningconfig_meta(self):
        # NOTE(review): despite its name, this case passes no tuning config and is
        # therefore identical to test_npu_grouped_matmul_finalise_routing_a8w4_meta.
        # TODO pass the intended tuning option once its keyword is confirmed.
        with FakeTensorMode():
            result, output_bs, n = self._run_a8w4()
            self._assert_output(result, output_bs, n)

    @unittest.skip("Skipping due to outdated CANN version; please update CANN to the latest version and remove this skip.")
    def test_npu_grouped_matmul_finalise_routing_sharedinput_none_grouplist_cumsum_meta(self):
        with FakeTensorMode():
            m, k, n, batch, top_k, group_num = 576, 2048, 7168, 72, 8, 8
            x = torch.randint(-10, 10, (m, k), dtype=torch.int8)
            weight = torch.randint(-10, 10, (group_num, k, n), dtype=torch.int8)
            scale = torch.normal(0, 0.01, (group_num, n), dtype=torch.float32)
            pertoken_scale = torch.normal(0, 0.01, (m, 1), dtype=torch.float32).reshape(m)
            # Cumulative group sizes, passed together with group_list_type=0 below.
            group_list = torch.cumsum(torch.tensor([batch] * group_num, dtype=torch.int64), dim=0)
            logit, row_index = self._routing_inputs(batch, group_num, top_k, m)
            output_bs = batch
            # Cast the weight to ACL format 29 (presumably the NZ layout) before the call.
            weight_nz = torch_npu.npu_format_cast(weight.npu(), 29)
            result = torch_npu.npu_grouped_matmul_finalize_routing(
                x.npu(), weight_nz, group_list.npu(), scale=scale.npu(),
                pertoken_scale=pertoken_scale.npu(), shared_input=None,
                logit=logit.npu(), row_index=row_index.npu(),
                shared_input_offset=batch // 2, output_bs=output_bs, group_list_type=0
            ).to("cpu")
            self._assert_output(result, output_bs, n)


class TestTransposeBatchMatmul(TestCase):
    """Fake-tensor checks for ``torch_npu.npu_transpose_batchmatmul`` (currently all skipped).

    Every case uses float16 operands x1 of shape (M, Batch, K) and x2 of shape
    (Batch, K, N) with ``perm_x1=[1, 0, 2]`` and ``perm_y=[1, 0, 2]``.
    """

    @staticmethod
    def _make_operands(m, k, n, batch):
        """Return float16 ``x1`` of shape (m, batch, k) and ``x2`` of shape (batch, k, n)."""
        x1 = torch.randn((m, batch, k), dtype=torch.float16)
        x2 = torch.randn((batch, k, n), dtype=torch.float16)
        return x1, x2

    def _assert_output(self, result, shape, dtype):
        """Assert ``result`` has exactly ``shape`` and ``dtype``."""
        self.assertEqual(result.shape, torch.Size(shape))
        self.assertEqual(result.dtype, dtype)

    @unittest.skip("skip test_npu_transpose_batchmatmul")
    def test_npu_transpose_batchmatmul_meta_1(self):
        with FakeTensorMode():
            M, K, N, Batch = 32, 512, 128, 16
            x1, x2 = self._make_operands(M, K, N, Batch)
            scale = torch.randint(1, 10, (Batch * N, ), dtype=torch.int64)
            result = torch_npu.npu_transpose_batchmatmul(x1.npu(), x2.npu(), scale=scale.npu(),
                                                         perm_x1=[1, 0, 2], perm_y=[1, 0, 2])
            # With an int64 ``scale`` of length Batch * N this case expects int8 (M, 1, Batch * N).
            self._assert_output(result, (M, 1, Batch * N), torch.int8)

    @unittest.skip("skip test_npu_transpose_batchmatmul")
    def test_npu_transpose_batchmatmul_meta_2(self):
        with FakeTensorMode():
            M, K, N, Batch = 32, 512, 128, 16
            x1, x2 = self._make_operands(M, K, N, Batch)
            result = torch_npu.npu_transpose_batchmatmul(x1.npu(), x2.npu(),
                                                         perm_x1=[1, 0, 2], perm_y=[1, 0, 2])
            self._assert_output(result, (M, Batch, N), torch.float16)

    @unittest.skip("skip test_npu_transpose_batchmatmul")
    def test_npu_transpose_batchmatmul_meta_3(self):
        with FakeTensorMode():
            M, K, N, Batch = 32, 512, 128, 16
            batch_split_factor = 4
            x1, x2 = self._make_operands(M, K, N, Batch)
            result = torch_npu.npu_transpose_batchmatmul(x1.npu(), x2.npu(), scale=None,
                                                         perm_x1=[1, 0, 2], perm_y=[1, 0, 2],
                                                         batch_split_factor=batch_split_factor)
            # The Batch * N columns are expected to be split across a new leading dimension.
            self._assert_output(result, (batch_split_factor, M, Batch * N // batch_split_factor),
                                torch.float16)

    @unittest.skip("skip test_npu_transpose_batchmatmul")
    def test_npu_transpose_batchmatmul_meta_4(self):
        with FakeTensorMode():
            M, K, N, Batch = 32, 512, 128, 16
            x1, x2 = self._make_operands(M, K, N, Batch)
            # x2 is already on the NPU after the cast, so no further ``.npu()`` is needed.
            x2_nz = torch_npu.npu_format_cast(x2.npu(), acl_format=29)
            result = torch_npu.npu_transpose_batchmatmul(x1.npu(), x2_nz,
                                                         perm_x1=[1, 0, 2], perm_y=[1, 0, 2])
            self._assert_output(result, (M, Batch, N), torch.float16)

class TestTransposeQuantBatchMatmul(TestCase):
    """Fake-tensor checks for ``torch_npu.npu_transpose_quant_batchmatmul`` (currently all skipped).

    Every case uses x1 of shape (M, Batch, K) and x2 of shape (Batch, K, N) and
    expects an (M, Batch, N) output in the requested ``dtype``.
    """

    def _run_with_1d_scales(self, fp8_dtype, out_dtype):
        """Run the op on ``fp8_dtype`` operands with float32 scales of shape (M,) and (N,).

        Asserts the output is (M, Batch, N) with dtype ``out_dtype``.
        """
        with FakeTensorMode():
            M, K, N, Batch = 32, 512, 128, 16
            x1 = torch.randint(-5, 5, (M, Batch, K), dtype=torch.int8).to(fp8_dtype)
            x2 = torch.randint(-5, 5, (Batch, K, N), dtype=torch.int8).to(fp8_dtype)
            x1_scale = torch.randint(-3, 3, (M, ), dtype=torch.float32)
            x2_scale = torch.randint(-3, 3, (N, ), dtype=torch.float32)
            result = torch_npu.npu_transpose_quant_batchmatmul(
                x1.npu(), x2.npu(), dtype=out_dtype,
                x1_scale=x1_scale.npu(), x2_scale=x2_scale.npu(),
                perm_x1=[1, 0, 2], perm_y=[1, 0, 2])
            self.assertEqual(result.shape, torch.Size([M, Batch, N]))
            self.assertEqual(result.dtype, out_dtype)

    @unittest.skip("skip test_npu_transpose_quant_batchmatmul")
    def test_npu_transpose_quant_batchmatmul_meta_1(self):
        self._run_with_1d_scales(torch.float8_e5m2, torch.float16)

    @unittest.skip("skip test_npu_transpose_quant_batchmatmul")
    def test_npu_transpose_quant_batchmatmul_meta_2(self):
        self._run_with_1d_scales(torch.float8_e4m3fn, torch.float16)

    @unittest.skip("skip test_npu_transpose_quant_batchmatmul")
    def test_npu_transpose_quant_batchmatmul_meta_3(self):
        self._run_with_1d_scales(torch.float8_e4m3fn, torch.bfloat16)

    @unittest.skip("skip test_npu_transpose_quant_batchmatmul")
    def test_npu_transpose_quant_batchmatmul_meta_4(self):
        self._run_with_1d_scales(torch.float8_e5m2, torch.bfloat16)

    @unittest.skip("skip test_npu_transpose_quant_batchmatmul")
    def test_npu_transpose_quant_batchmatmul_meta_5(self):
        with FakeTensorMode():
            M, K, N, Batch = 32, 512, 128, 16
            x1 = torch.randint(-5, 5, (M, Batch, K), dtype=torch.int8).to(torch.float8_e4m3fn)
            x2 = torch.randint(-5, 5, (Batch, K, N), dtype=torch.int8).to(torch.float8_e4m3fn)
            # Block scales: raw uint8 bytes reinterpreted as float8_e8m0fnu. K is covered
            # by K // 64 pairs, i.e. K // 32 scales — presumably matching group_sizes[2] == 32.
            x1_scale = torch.randint(0, 3, (M, Batch, K // 64, 2), dtype=torch.uint8).view(torch.float8_e8m0fnu)
            x2_scale = torch.randint(0, 3, (Batch, K // 64, N, 2), dtype=torch.uint8).view(torch.float8_e8m0fnu)
            result = torch_npu.npu_transpose_quant_batchmatmul(
                x1.npu(), x2.npu(), dtype=torch.float16,
                x1_scale=x1_scale.npu(), x2_scale=x2_scale.npu(),
                group_sizes=[0, 0, 32], perm_x1=[1, 0, 2], perm_x2=[0, 1, 2], perm_y=[1, 0, 2])
            self.assertEqual(result.shape, torch.Size([M, Batch, N]))
            self.assertEqual(result.dtype, torch.float16)


class TestNpuPrefetch(TestCase):
    def test_npu_prefetch(self):
        """npu_prefetch must reject max_size=-1 and offset=-1 with specific RuntimeError messages."""
        # (trailing positional args after the input tensor, expected error message)
        invalid_calls = (
            ((None, -1), "The max_size should be greater than zero, but got -1."),
            ((None, 10, -1), "The offset should be nonnegative, but got -1."),
        )
        with FakeTensorMode():
            input1 = torch.randn(8, 8).npu()
            for extra_args, expected_message in invalid_calls:
                with self.assertRaises(RuntimeError) as ctx:
                    torch_npu.npu_prefetch(input1, *extra_args)
                self.assertEqual(str(ctx.exception), expected_message)


class TestMoeDistributeDispatch(TestCase):
    def test_moe_distribute_dispatchA2(self):
        """Fake-tensor shape/dtype check for npu_moe_distribute_dispatch with quant_mode=2."""
        with FakeTensorMode():
            ep_world_size = 16
            tp_world_size = 0
            bs, h, k = 8, 7168, 8
            moe_expert_num = 16
            shared_expert_rank_num = 0
            quant_mode = 2
            expert_num_per_rank = 1

            global_bs = bs * ep_world_size
            total_expert_num = ep_world_size * expert_num_per_rank
            local_moe_expert_num = moe_expert_num // ep_world_size
            # Expected leading dimension of the expanded outputs.
            a = global_bs * min(local_moe_expert_num, k)
            ep_recv_cnt_num = ep_world_size * local_moe_expert_num + global_bs * 2 * k * (ep_world_size // 8)

            x = torch.randn(bs, h).to(torch.bfloat16)
            expert_ids = torch.randn(bs, k).to(torch.int32)
            scales = torch.randn(total_expert_num, h).to(torch.float32)
            expert_scales = torch.randn(bs, k).to(torch.float32)
            result = torch_npu.npu_moe_distribute_dispatch(
                x, expert_ids, "group_ep", ep_world_size, 0, moe_expert_num,
                scales=scales, x_active_mask=None, expert_scales=expert_scales,
                group_tp="", tp_world_size=0, tp_rank_id=0,
                expert_shard_type=0, shared_expert_num=0,
                shared_expert_rank_num=shared_expert_rank_num, quant_mode=quant_mode,
                global_bs=global_bs, expert_token_nums_type=1)

            self.assertEqual(result[0].shape[1], h)
            # (expected size of dim 0, expected dtype) for each output, in output order.
            expected_outputs = (
                (a, torch.int8),
                (a, torch.float32),
                (bs * k, torch.int32),
                (local_moe_expert_num, torch.int64),
                (ep_recv_cnt_num, torch.int32),
                (tp_world_size, torch.int32),
                (a, torch.float32),
            )
            for idx, (dim0, dtype) in enumerate(expected_outputs):
                self.assertEqual(result[idx].shape[0], dim0)
                self.assertEqual(result[idx].dtype, dtype)


class TestMoeDistributeDispatchV2(TestCase):
    """Fake-tensor checks for npu_moe_distribute_dispatch_v2 across quant modes and TP sizes."""

    def _run_dispatch_v2(self, tp_world_size, quant_mode, scales=None, y_dtype=None):
        """Invoke the op under FakeTensorMode with a fixed 16-rank EP configuration.

        Returns ``(result, a, h, local_moe_expert_num)`` where ``a`` is the expected
        leading dimension of the expanded outputs and ``h`` is the hidden size.
        """
        ep_world_size, moe_expert_num = 16, 16
        bs, h, k = 8, 7168, 8
        global_bs = bs * ep_world_size
        local_moe_expert_num = moe_expert_num // ep_world_size
        a = global_bs * min(local_moe_expert_num, k)

        with FakeTensorMode():
            x = torch.randn(bs, h).to(torch.bfloat16)
            expert_ids = torch.randn(bs, k).to(torch.int32)
            result = torch_npu.npu_moe_distribute_dispatch_v2(
                x, expert_ids, "group_ep", ep_world_size, 0, moe_expert_num,
                scales=scales, x_active_mask=None, expert_scales=None,
                group_tp="", tp_world_size=tp_world_size, tp_rank_id=0,
                expert_shard_type=0, shared_expert_num=0, shared_expert_rank_num=0,
                quant_mode=quant_mode, global_bs=global_bs, expert_token_nums_type=1,
                y_dtype=y_dtype)

        return result, a, h, local_moe_expert_num

    def _check_int8_quant(self, tp_world_size):
        """quant_mode=2: int8 (a, h) tokens plus a 1-D (a,) scale output."""
        result, a, h, _ = self._run_dispatch_v2(tp_world_size=tp_world_size, quant_mode=2)
        self.assertEqual(result[0].shape, torch.Size([a, h]))
        self.assertEqual(result[0].dtype, torch.int8)
        self.assertEqual(result[1].shape, torch.Size([a]))

    def _check_pergroup_quant(self, tp_world_size):
        """quant_mode=3 with float8_e5m2 output: scales have ceil(h / 128) columns."""
        result, a, h, _ = self._run_dispatch_v2(
            tp_world_size=tp_world_size, quant_mode=3, y_dtype=torch.float8_e5m2)
        self.assertEqual(result[1].shape, torch.Size([a, math.ceil(h / 128)]))

    def _check_mx_quant(self, tp_world_size):
        """quant_mode=4 with float8_e4m3fn output: uint8 scales, ceil(h / 32) rounded up to even."""
        result, a, h, _ = self._run_dispatch_v2(
            tp_world_size=tp_world_size, quant_mode=4, y_dtype=torch.float8_e4m3fn)
        scale_cols = (math.ceil(h / 32) + 1) // 2 * 2
        self.assertEqual(result[1].shape, torch.Size([a, scale_cols]))
        self.assertEqual(result[1].dtype, torch.uint8)

    def test_tp0_pertoken(self):
        self._check_int8_quant(tp_world_size=0)

    def test_tp1_pertoken(self):
        self._check_int8_quant(tp_world_size=1)

    def test_tp0_pergroup(self):
        self._check_pergroup_quant(tp_world_size=0)

    def test_tp1_pergroup(self):
        self._check_pergroup_quant(tp_world_size=1)

    def test_tp0_mx(self):
        self._check_mx_quant(tp_world_size=0)

    def test_tp1_mx(self):
        self._check_mx_quant(tp_world_size=1)

    def test_tp0_no_quant(self):
        result, a, h, _ = self._run_dispatch_v2(tp_world_size=0, quant_mode=0)
        self.assertEqual(result[0].shape, torch.Size([a, h]))
        self.assertEqual(result[0].dtype, torch.bfloat16)
        self.assertEqual(result[1].shape, torch.Size([a]))


class TestMoeDistributeCombineAddRmsNorm(TestCase):
    def test_moe_distribute_combine_add_rms_norm(self):
        """Fake-tensor shape/dtype check for npu_moe_distribute_combine_add_rms_norm."""
        # Expected (leading dims, dtype) of each returned tensor, in output order.
        expected_outputs = (
            ((32, 1, 7168), torch.bfloat16),
            ((32, 1, 1), torch.float32),
            ((32, 1, 7168), torch.bfloat16),
        )
        with FakeTensorMode():
            tensor_inputs = dict(
                expand_x=torch.randn(32, 7168).to(torch.bfloat16),
                expert_ids=torch.randn(32, 8).to(torch.int32),
                expand_idx=torch.randn(32, 8).to(torch.int32),
                ep_send_counts=torch.randn(8).to(torch.int32),
                expert_scales=torch.randn(32, 8).to(torch.float32),
                residual_x=torch.randn(32, 1, 7168).to(torch.bfloat16),
                gamma=torch.randn(7168).to(torch.bfloat16),
            )
            result = torch_npu.npu_moe_distribute_combine_add_rms_norm(
                **tensor_inputs, group_ep="groupe_ep", ep_world_size=16,
                ep_rank_id=0, moe_expert_num=16)

            for idx, (dims, dtype) in enumerate(expected_outputs):
                for axis, size in enumerate(dims):
                    self.assertEqual(result[idx].shape[axis], size)
                self.assertEqual(result[idx].dtype, dtype)


class TestMoeDistributeDispatchSetup(TestCase):
    def test_moe_distribute_dispatch_setup(self):
        """Fake-tensor shape/dtype check for npu_moe_distribute_dispatch_setup."""
        # Expected (leading dims, dtype) of each returned tensor, in output order.
        expected_outputs = (
            ((288, 256), torch.bfloat16),
            ((256,), torch.int32),
            ((4864,), torch.int32),
        )
        with FakeTensorMode():
            result = torch_npu.npu_moe_distribute_dispatch_setup(
                x=torch.randn(32, 256).to(torch.bfloat16),
                expert_ids=torch.randn(32, 8).to(torch.int32),
                group_ep="group",
                ep_world_size=16,
                ep_rank_id=0,
                moe_expert_num=16)

            for idx, (dims, dtype) in enumerate(expected_outputs):
                for axis, size in enumerate(dims):
                    self.assertEqual(result[idx].shape[axis], size)
                self.assertEqual(result[idx].dtype, dtype)


class TestMoeDistributeDispatchTeardown(TestCase):
    def test_moe_distribute_dispatch_teardown(self):
        """Fake-tensor shape/dtype check for npu_moe_distribute_dispatch_teardown."""
        # Expected (leading dims, dtype) of each returned tensor, in output order.
        expected_outputs = (
            ((512, 256), torch.bfloat16),
            ((512,), torch.float32),
            ((512 * 128,), torch.int32),
            ((1,), torch.int64),
        )
        with FakeTensorMode():
            result = torch_npu.npu_moe_distribute_dispatch_teardown(
                x=torch.randn(32, 256).to(torch.bfloat16),
                y=torch.randn(256, 512).to(torch.bfloat16),
                expert_ids=torch.randn(32, 8).to(torch.int32),
                comm_cmd_info=torch.randn(8704).to(torch.int32),
                group_ep="group",
                ep_world_size=16,
                ep_rank_id=0,
                moe_expert_num=16)

            for idx, (dims, dtype) in enumerate(expected_outputs):
                for axis, size in enumerate(dims):
                    self.assertEqual(result[idx].shape[axis], size)
                self.assertEqual(result[idx].dtype, dtype)


class TestMoeDistributeCombineSetup(TestCase):
    def test_moe_distribute_combine_setup(self):
        """Fake-tensor shape/dtype check for npu_moe_distribute_combine_setup."""
        # Expected (leading dims, dtype) of each returned tensor, in output order.
        expected_outputs = (
            ((32, 512), torch.int8),
            ((768,), torch.int32),
        )
        with FakeTensorMode():
            result = torch_npu.npu_moe_distribute_combine_setup(
                expand_x=torch.randn(32, 256).to(torch.bfloat16),
                expert_ids=torch.randn(32, 8).to(torch.int32),
                assist_info_for_combine=torch.randn(32 * 128).to(torch.int32),
                group_ep="group",
                ep_world_size=16,
                ep_rank_id=0,
                moe_expert_num=16)

            for idx, (dims, dtype) in enumerate(expected_outputs):
                for axis, size in enumerate(dims):
                    self.assertEqual(result[idx].shape[axis], size)
                self.assertEqual(result[idx].dtype, dtype)


class TestMoeDistributeCombineTeardown(TestCase):
    def test_moe_distribute_combine_teardown(self):
        """Fake-tensor shape/dtype check for npu_moe_distribute_combine_teardown."""
        with FakeTensorMode():
            result = torch_npu.npu_moe_distribute_combine_teardown(
                expand_x=torch.randn(32, 256).to(torch.bfloat16),
                quant_expand_x=torch.randn(32, 1024).to(torch.int8),
                expert_ids=torch.randn(32, 8).to(torch.int32),
                expand_idx=torch.randn(256).to(torch.int32),
                expert_scales=torch.randn(256).to(torch.float32),
                comm_cmd_info=torch.randn(768).to(torch.int32),
                group_ep="group",
                ep_world_size=16,
                ep_rank_id=0,
                moe_expert_num=16)

            # A single tensor is returned; check its leading dims and dtype.
            for axis, size in enumerate((32, 256)):
                self.assertEqual(result.shape[axis], size)
            self.assertEqual(result.dtype, torch.bfloat16)


class TestAddRmsNormQuant(TestCase):
    """Fake-tensor checks for ``torch_npu.npu_add_rms_norm_quant``."""

    def _check_outputs(self, x1, outputs):
        """Assert ``(y1, y2, x_out)`` all have x1's shape; y1/y2 are int8 and x_out keeps x1's dtype."""
        y1, y2, x_out = outputs
        # assertEqual (instead of assertTrue(a == b)) reports both values on failure.
        self.assertEqual(y1.shape, x1.shape)
        self.assertEqual(y1.dtype, torch.int8)
        self.assertEqual(y2.shape, x1.shape)
        self.assertEqual(y2.dtype, torch.int8)
        self.assertEqual(x_out.shape, x1.shape)
        self.assertEqual(x_out.dtype, x1.dtype)

    def test_npu_add_rms_norm_quant(self):
        with FakeTensorMode():
            # float16 activations with float32 scales and int32 zero points.
            x1 = torch.randn([2, 16], dtype=torch.float16).npu()
            x2 = torch.randn([2, 16], dtype=torch.float16).npu()
            gamma = torch.randn([16, ], dtype=torch.float16).npu()
            scales1 = torch.randn([16, ], dtype=torch.float32).npu()
            zero_points1 = torch.randint(-10, 10, [16, ], dtype=torch.int32).npu()
            self._check_outputs(x1, torch_npu.npu_add_rms_norm_quant(x1, x2, gamma, scales1, zero_points1))

            # bfloat16 activations with bfloat16 scales and zero points.
            x1 = torch.randn([2, 16], dtype=torch.bfloat16).npu()
            x2 = torch.randn([2, 16], dtype=torch.bfloat16).npu()
            gamma = torch.randn([16, ], dtype=torch.bfloat16).npu()
            scales1 = torch.randn([16, ], dtype=torch.bfloat16).npu()
            zero_points1 = torch.randn([16, ], dtype=torch.bfloat16).npu()
            self._check_outputs(x1, torch_npu.npu_add_rms_norm_quant(x1, x2, gamma, scales1, zero_points1))


class TestAddRmsNormDynamicQuant(TestCase):
    """Meta-kernel checks for ``torch_npu.npu_add_rms_norm_dynamic_quant``."""

    @staticmethod
    def _build_inputs(dtype):
        """Return ``(x1, x2, gamma, beta, smooth_scale1, smooth_scale2)`` npu tensors of ``dtype``."""
        lhs = torch.randn([2, 16], dtype=dtype, device='npu')
        rhs = torch.randn([2, 16], dtype=dtype, device='npu')
        weight = torch.ones([16, ], dtype=dtype, device='npu')
        bias = torch.zeros([16, ], dtype=dtype, device='npu')
        smooth_a = torch.ones([16, ], dtype=dtype, device='npu')
        smooth_b = torch.ones([16, ], dtype=dtype, device='npu')
        return lhs, rhs, weight, bias, smooth_a, smooth_b

    def _check_default_int8_outputs(self, inputs):
        """Call the op without ``y_dtype`` and check its five outputs.

        Expects int8 quantized outputs shaped like ``x1``, an ``x_out`` matching ``x1``,
        and float32 scales whose shape drops ``x1``'s last dimension.
        """
        lhs, rhs, weight, bias, smooth_a, smooth_b = inputs
        q1, q2, summed, scale_a, scale_b = torch_npu.npu_add_rms_norm_dynamic_quant(
            lhs, rhs, weight, smooth_scale1=smooth_a, smooth_scale2=smooth_b, beta=bias
        )
        for quantized in (q1, q2):
            self.assertEqual(quantized.shape, lhs.shape)
            self.assertEqual(quantized.dtype, torch.int8)
        self.assertEqual(summed.shape, lhs.shape)
        self.assertEqual(summed.dtype, lhs.dtype)
        for scale in (scale_a, scale_b):
            self.assertEqual(scale.shape, lhs.shape[:-1])
            self.assertEqual(scale.dtype, torch.float32)

    def test_npu_add_rms_norm_dynamic_quant_meta(self):
        with FakeTensorMode():
            self._check_default_int8_outputs(self._build_inputs(torch.float16))
            bf16_inputs = self._build_inputs(torch.bfloat16)
            self._check_default_int8_outputs(bf16_inputs)

            # y_dtype=torch.quint4x2 (int4): y1/y2 dtype int32, last_dim/8 (8 int4 packed per int32)
            lhs, rhs, weight, bias, smooth_a, smooth_b = bf16_inputs
            q1, q2, _summed, _scale_a, _scale_b = torch_npu.npu_add_rms_norm_dynamic_quant(
                lhs, rhs, weight, smooth_scale1=smooth_a, smooth_scale2=smooth_b, beta=bias,
                y_dtype=torch.quint4x2
            )
            packed_shape = (*lhs.shape[:-1], lhs.shape[-1] // 8)
            for quantized in (q1, q2):
                self.assertEqual(quantized.shape, packed_shape)
                self.assertEqual(quantized.dtype, torch.int32)


class TestAddRmsNormDynamicMxQuant(TestCase):
    """Meta-kernel check for ``torch_npu.npu_add_rms_norm_dynamic_mx_quant`` with an fp8 e5m2 target."""

    def test_npu_add_rms_norm_dynamic_mx_quant_meta(self):
        rows, cols = 8, 64
        half = torch.float16
        with FakeTensorMode():
            lhs = torch.randn([rows, cols], dtype=half, device='npu')
            rhs = torch.randn([rows, cols], dtype=half, device='npu')
            weight = torch.ones([cols, ], dtype=half, device='npu')
            bias = torch.zeros([cols, ], dtype=half, device='npu')
            quantized, summed, mx_scale, rstd = torch_npu.npu_add_rms_norm_dynamic_mx_quant(
                lhs, rhs, weight,
                beta=bias,
                epsilon=1e-6,
                scale_alg=0,
                round_mode="rint",
                dst_type=torch_npu.float8_e5m2,
            )
            self.assertEqual(quantized.shape, lhs.shape)
            self.assertEqual(quantized.dtype, torch.float8_e5m2)
            self.assertEqual(summed.shape, lhs.shape)
            self.assertEqual(summed.dtype, lhs.dtype)
            self.assertEqual(mx_scale.shape, torch.Size([rows, 1, 2]))
            self.assertEqual(mx_scale.dtype, torch.uint8)
            self.assertEqual(rstd.shape, torch.Size([rows, 1]))
            self.assertEqual(rstd.dtype, torch.float32)


class TestMoeUpdateExpert(TestCase):
    """Meta-kernel check for ``torch_npu.npu_moe_update_expert``."""

    def test_moe_update_expert(self):
        with FakeTensorMode():
            dtype = torch.bfloat16
            world_size, local_rank_id, balance_mode = 8, 0, 0
            bs, k, f = 32, 8, 4
            moe_expert_num = 256

            expert_ids = torch.randn(bs, k).to(torch.int32)
            eplb_table = torch.randn(moe_expert_num, f).to(torch.int32)
            expert_scales = torch.sort(torch.rand(bs, k, dtype=dtype), dim=-1, descending=True).values
            pruning_threshold = torch.rand(k, dtype=dtype)

            # Random count of leading True rows, followed by False rows, over the batch.
            active_rows = np.random.randint(0, bs + 1)
            mask_np = np.concatenate([
                np.ones(active_rows, dtype=bool),
                np.zeros(bs - active_rows, dtype=bool),
            ])
            active_mask = torch.from_numpy(mask_np).to(torch.bool)

            result = torch_npu.npu_moe_update_expert(
                expert_ids=expert_ids,
                eplb_table=eplb_table,
                expert_scales=expert_scales,
                pruning_threshold=pruning_threshold,
                active_mask=active_mask,
                local_rank_id=local_rank_id,
                world_size=world_size,
                balance_mode=balance_mode,
            )
            updated_ids, updated_mask = result[0], result[1]

            self.assertEqual(updated_ids.shape[0], bs)
            self.assertEqual(updated_ids.shape[1], k)
            self.assertEqual(updated_ids.dtype, torch.int32)

            self.assertEqual(updated_mask.shape[0], bs)
            self.assertEqual(updated_mask.shape[1], k)
            self.assertEqual(updated_mask.dtype, torch.bool)


class TestNpuMropeMeta(TestCase):
    """Meta-kernel check for ``torch_npu.npu_mrope``: outputs mirror the query/key shape and dtype."""

    def test_npu_mrope_meta(self):
        with FakeTensorMode():
            dtype = torch.bfloat16
            num_tokens = 8
            num_q_heads = 8
            head_size = 128
            num_kv_heads = num_q_heads
            max_seq_len = num_tokens
            rotary_dim = head_size

            host_tensors = (
                torch.arange(num_tokens, dtype=torch.int64),
                torch.randn(num_tokens, num_q_heads * head_size, dtype=dtype),
                torch.rand(num_tokens, num_kv_heads * head_size, dtype=dtype),
                torch.rand(max_seq_len, rotary_dim, dtype=dtype),
            )
            positions, query, key, cos_sin_cache = (t.npu() for t in host_tensors)

            rotated_query, rotated_key = torch_npu.npu_mrope(
                positions,
                query,
                key,
                cos_sin_cache,
                head_size,
                mrope_section=[0, 0, 0],
                rotary_mode='half',
                cache_mode='default',
            )

            self.assertEqual(rotated_query.shape, query.shape)
            self.assertEqual(rotated_query.dtype, query.dtype)
            self.assertEqual(rotated_key.shape, key.shape)
            self.assertEqual(rotated_key.dtype, key.dtype)


class TestGatherSparseIndex(TestCase):
    """Meta-kernel check for ``torch_npu.npu_gather_sparse_index``."""

    def test_npu_gather_sparse_index(self):
        with FakeTensorMode():
            table = torch.randn([16, 32], dtype=torch.float32).npu()
            idx = torch.randint(0, 16, [4, 8]).npu()
            # Expected output: the index shape (4, 8) followed by the table's trailing dim (32).
            reference = torch.randn([4, 8, 32], dtype=torch.float32).npu()
            gathered = torch_npu.npu_gather_sparse_index(table, idx)

            self.assertTrue(gathered.shape == reference.shape)
            self.assertTrue(gathered.dtype == reference.dtype)


class TestNpuTopKTopP(TestCase):
    """Meta-kernel check for ``torch_npu.npu_top_k_top_p``: output mirrors the logits."""

    def test_npu_top_k_top_p_meta(self):
        vocab_size, batch_size = 152064, 128
        dtype = torch.float32
        with FakeTensorMode():
            # Per-row k is drawn from [10, min(1024, vocab_size)).
            k_upper = min(1024, vocab_size)
            logits = torch.randn(batch_size, vocab_size).to(dtype)
            top_p = torch.rand(batch_size).to(dtype)
            top_k = torch.randint(10, k_upper, (batch_size,)).to(torch.int32)

            filtered = torch_npu.npu_top_k_top_p(logits.npu(), top_p.npu(), top_k.npu())

            self.assertEqual(filtered.dtype, dtype)
            self.assertEqual(filtered.shape, logits.shape)


class TestNpuQkvRmsNormRopeCache(TestCase):
    """Meta-kernel check for ``torch_npu.npu_qkv_rms_norm_rope_cache`` in ``PA_NZ`` cache mode."""

    def test_npu_qkv_rms_norm_rope_cache_meta(self):
        batch, seq = 16, 3
        q_heads, k_heads, v_heads = 16, 1, 1
        head_dim = 128
        total_heads = q_heads + k_heads + v_heads
        block_size = 128
        # ceil(seq / block_size) cache blocks per batch entry.
        block_num = (seq + block_size - 1) // block_size * batch
        dtype = torch.float16
        eps = 1e-6
        rows = batch * seq

        with FakeTensorMode():
            qkv = torch.empty(rows, total_heads * head_dim, dtype=dtype).npu()
            q_gamma = torch.empty(head_dim, dtype=dtype).npu()
            k_gamma = torch.empty(head_dim, dtype=dtype).npu()
            cos = torch.empty(rows, head_dim, dtype=dtype).npu()
            sin = torch.empty(rows, head_dim, dtype=dtype).npu()
            index = torch.full((rows,), -1, dtype=torch.int64).npu()

            q_out = torch.empty(rows, q_heads * head_dim, dtype=dtype).npu()

            # int8 caches shaped (block_num, heads * head_dim // 32, block_size, 32).
            k_cache = torch.empty(block_num, k_heads * head_dim // 32, block_size, 32, dtype=torch.int8).npu()
            v_cache = torch.empty(block_num, v_heads * head_dim // 32, block_size, 32, dtype=torch.int8).npu()

            k_scale = torch.empty(k_heads, head_dim, dtype=torch.float32).npu()
            v_scale = torch.empty(v_heads, head_dim, dtype=torch.float32).npu()

            outputs = torch_npu.npu_qkv_rms_norm_rope_cache(
                qkv=qkv,
                q_gamma=q_gamma,
                k_gamma=k_gamma,
                cos=cos,
                sin=sin,
                index=index,
                q_out=q_out,
                k_cache=k_cache,
                v_cache=v_cache,
                qkv_size=[batch, seq, total_heads, head_dim],
                head_nums=[q_heads, k_heads, v_heads],
                k_scale=k_scale,
                v_scale=v_scale,
                k_offset=None,
                v_offset=None,
                epsilon=eps,
                cache_mode="PA_NZ",
                is_output_qkv=True,
            )
            q_pre, k_pre, v_pre = outputs

            # Each pre-quant output is (rows, heads * head_dim) in the input dtype.
            for produced, heads in ((q_pre, q_heads), (k_pre, k_heads), (v_pre, v_heads)):
                self.assertEqual(produced.shape, torch.Size([rows, heads * head_dim]))
                self.assertEqual(produced.dtype, dtype)


class TestNpuMoeTokenPermuteAndUnpermute(TestCase):
    """Meta-kernel checks for npu_moe_token_permute / unpermute and their gradient ops."""

    def test_npu_moe_token_permute_unpermute_meta(self):
        dtype = torch.bfloat16
        num_tokens = 1000
        num_output_tokens = None
        hidden_size = 6144
        num_experts = 128
        with FakeTensorMode():
            tokens = torch.randn(num_tokens, hidden_size).npu().to(dtype)
            # None is normalised to 0 so the expected-row formula below takes the uncapped branch.
            num_output_tokens = 0 if num_output_tokens is None else num_output_tokens

            for topk in [1, 4]:
                flatten_size = num_tokens * topk
                indices = torch.randint(0, num_experts, (num_tokens, topk)).npu()
                permuted_tokens, sorted_indices = torch_npu.npu_moe_token_permute(tokens, indices)
                # A positive num_output_tokens caps the row count; otherwise it is added to flatten_size.
                if num_output_tokens > 0:
                    permuted_output_shape_0 = min(num_output_tokens, flatten_size)
                else:
                    permuted_output_shape_0 = flatten_size + num_output_tokens
                self.assertEqual(permuted_tokens.dtype, dtype)
                self.assertEqual(permuted_tokens.shape[0], permuted_output_shape_0)
                self.assertEqual(permuted_tokens.shape[1], hidden_size)
                self.assertEqual(sorted_indices.dtype, torch.int32)
                self.assertEqual(sorted_indices.shape[0], indices.numel())

                probs = (torch.ones_like(indices) / topk).npu().to(dtype)
                unpermuted_tokens = torch_npu.npu_moe_token_unpermute(permuted_tokens, sorted_indices, probs=probs)
                self.assertEqual(unpermuted_tokens.dtype, dtype)
                # Floor division keeps the expected row count an int (``/`` would produce a float).
                self.assertEqual(unpermuted_tokens.shape, (sorted_indices.size(0) // topk, hidden_size))

    def test_npu_moe_token_permute_unpermute_grad_meta(self):
        """Meta test for backward functions with topk=1 and topk=4"""
        dtype = torch.bfloat16
        num_tokens = 1000
        hidden_size = 6144
        num_experts = 128

        with FakeTensorMode():
            for topk in [1, 4]:
                tokens = torch.randn(num_tokens, hidden_size).npu().to(dtype)
                indices = torch.randint(0, num_experts, (num_tokens, topk)).npu()
                probs = (torch.ones(num_tokens, topk) / topk).npu().to(dtype)
                permuted_tokens, sorted_indices = torch_npu.npu_moe_token_permute(tokens, indices)
                expected_permuted_shape = (num_tokens * topk, hidden_size)
                self.assertEqual(permuted_tokens.shape, expected_permuted_shape)

                grad_unpermuted_tokens = torch.randn(num_tokens, hidden_size).npu().to(dtype)

                grad_permuted_tokens, grad_probs = torch_npu.npu_moe_token_unpermute_grad(
                    permuted_tokens=permuted_tokens,
                    grad_unpermuted_tokens=grad_unpermuted_tokens,
                    sorted_indices=sorted_indices,
                    probs=probs,
                    padded_mode=False,
                    restore_shape=None,
                )

                self.assertEqual(grad_permuted_tokens.dtype, permuted_tokens.dtype)
                self.assertEqual(grad_permuted_tokens.shape, permuted_tokens.shape)

                self.assertIsNotNone(grad_probs)
                self.assertEqual(grad_probs.dtype, probs.dtype)
                self.assertEqual(grad_probs.shape, probs.shape)

                grad_permuted_tokens_for_permute = torch.randn_like(permuted_tokens)
                grad_tokens = torch_npu.npu_moe_token_permute_grad(
                    tokens=tokens,
                    grad_permuted_tokens=grad_permuted_tokens_for_permute,
                    indices=indices,
                    sorted_indices=sorted_indices,
                    padded_mode=False,
                )

                self.assertEqual(grad_tokens.dtype, tokens.dtype)
                self.assertEqual(grad_tokens.shape, tokens.shape)

                # The v2 grad takes the token count/dtype directly instead of the tokens tensor.
                grad_tokens_v2 = torch_npu.npu_moe_token_permute_grad_v2(
                    grad_permuted_tokens=grad_permuted_tokens_for_permute,
                    sorted_indices=sorted_indices,
                    tokens_size_0=tokens.shape[0],
                    tokens_dtype=tokens.dtype,
                    num_topK=topk,
                    padded_mode=False,
                )

                self.assertEqual(grad_tokens_v2.dtype, tokens.dtype)
                self.assertEqual(grad_tokens_v2.shape, tokens.shape)

    def test_npu_moe_token_unpermute_grad_mixed_dtype(self):
        """Test grad_probs dtype matches probs when probs and grad_unpermuted_tokens have different dtypes"""
        num_tokens = 100
        hidden_size = 64
        num_experts = 8
        topk = 4

        with FakeTensorMode():
            tokens = torch.randn(num_tokens, hidden_size).npu().to(torch.bfloat16)
            indices = torch.randint(0, num_experts, (num_tokens, topk)).npu()
            permuted_tokens, sorted_indices = torch_npu.npu_moe_token_permute(tokens, indices)

            # float32 probs against bfloat16 gradients.
            probs = (torch.ones(num_tokens, topk) / topk).npu().to(torch.float32)
            grad_unpermuted_tokens = torch.randn(num_tokens, hidden_size).npu().to(torch.bfloat16)

            grad_permuted_tokens, grad_probs = torch_npu.npu_moe_token_unpermute_grad(
                permuted_tokens=permuted_tokens,
                grad_unpermuted_tokens=grad_unpermuted_tokens,
                sorted_indices=sorted_indices,
                probs=probs,
                padded_mode=False,
                restore_shape=None,
            )

            self.assertEqual(grad_permuted_tokens.dtype, permuted_tokens.dtype)
            self.assertIsNotNone(grad_probs)
            self.assertEqual(grad_probs.dtype, probs.dtype)
            self.assertEqual(grad_probs.shape, probs.shape)

    def test_npu_moe_token_unpermute_grad_probs_none(self):
        """Test npu_moe_token_unpermute_grad when probs is None"""
        num_tokens = 100
        hidden_size = 64
        num_experts = 8
        topk = 1

        with FakeTensorMode():
            tokens = torch.randn(num_tokens, hidden_size).npu().to(torch.bfloat16)
            indices = torch.randint(0, num_experts, (num_tokens, topk)).npu()
            permuted_tokens, sorted_indices = torch_npu.npu_moe_token_permute(tokens, indices)

            grad_unpermuted_tokens = torch.randn(num_tokens, hidden_size).npu().to(torch.bfloat16)

            grad_permuted_tokens, grad_probs = torch_npu.npu_moe_token_unpermute_grad(
                permuted_tokens=permuted_tokens,
                grad_unpermuted_tokens=grad_unpermuted_tokens,
                sorted_indices=sorted_indices,
                probs=None,
                padded_mode=False,
                restore_shape=None,
            )

            self.assertEqual(grad_permuted_tokens.dtype, permuted_tokens.dtype)
            self.assertEqual(grad_permuted_tokens.shape, permuted_tokens.shape)
            self.assertIsNone(grad_probs)


class TestNpuMoeUnpermuteWithRoutingMap(TestCase):
    """Meta-kernel checks for the public and private ``npu_moe_token_unpermute_with_routing_map`` ops."""

    def test_npu_moe_token_unpermute_with_routing_map_meta(self):
        token_num = 40
        hidden_size = 20
        expert_num = 20
        top_k = 20
        capacity = 20
        out_token_num_pad = expert_num * capacity

        def generate_bool_matrix(m, n, k):
            """Return an (m, n) int8 0/1 matrix with ``k`` ones at random columns of each row."""
            matrix = torch.zeros((m, n), dtype=torch.bool)

            for i in range(m):
                indices = torch.randperm(n)[:k]
                matrix[i, indices] = True

            return matrix.to(torch.int8)

        with FakeTensorMode():
            # Exercise both the probs-given and the probs=None call paths.
            for need_probs in [True, False]:
                routing_map = generate_bool_matrix(token_num, expert_num, top_k)
                routing_map_npu = routing_map.npu()

                permuted_tokens_npu = torch.randn([out_token_num_pad, hidden_size]).npu()
                permuted_tokens_npu.requires_grad_(True)

                # For each expert (row of the transposed map), token positions ordered by
                # descending routing-map value, truncated to ``capacity`` and flattened.
                routing_map_tmp = routing_map.T.contiguous()
                sorted_indices = routing_map_tmp.argsort(dim=-1, descending=True, stable=True)[:, :capacity].contiguous().to(torch.int32).view(-1)
                sorted_indices_npu = sorted_indices.npu()

                probs_npu = torch.randn([token_num, expert_num]).npu()
                probs_npu.requires_grad_(True)

                restore_shape = [token_num, hidden_size]
                drop_and_pad = True
                probs_arg = probs_npu if need_probs else None

                unpermuted_tokens = torch_npu.npu_moe_token_unpermute_with_routing_map(
                    permuted_tokens_npu, sorted_indices_npu, restore_shape, probs=probs_arg,
                    routing_map=routing_map_npu, drop_and_pad=drop_and_pad)
                permuted_tokens_, _out_index, permuted_token_id, permute_probs = \
                    torch_npu._npu_moe_token_unpermute_with_routing_map(
                        permuted_tokens_npu, sorted_indices_npu, restore_shape, probs=probs_arg,
                        routing_map=routing_map_npu, drop_and_pad=drop_and_pad)

                self.assertEqual(unpermuted_tokens.dtype, permuted_tokens_npu.dtype)
                self.assertEqual(unpermuted_tokens.shape[0], restore_shape[0])
                self.assertEqual(unpermuted_tokens.shape[1], restore_shape[1])

                self.assertEqual(permuted_tokens_.dtype, permuted_tokens_npu.dtype)
                self.assertEqual(permuted_tokens_.shape[0], restore_shape[0])
                self.assertEqual(permuted_tokens_.shape[1], restore_shape[1])

                self.assertEqual(permuted_token_id.dtype, sorted_indices_npu.dtype)
                self.assertEqual(permuted_token_id.shape, sorted_indices_npu.shape)

                # permute_probs is only checked when probs were supplied.
                if need_probs:
                    self.assertEqual(permute_probs.dtype, probs_npu.dtype)
                    self.assertEqual(permute_probs.shape, sorted_indices.shape)


class TestNpuAttentionWorkerCombine(TestCase):
    """Meta-kernel check for ``torch_npu.npu_attention_worker_combine``."""

    def test_npu_attention_worker_combine(self):
        with FakeTensorMode():
            batch = 16
            hidden = 7168
            schedule_context = torch.randn((1024,)).npu().to(torch.int8)
            expert_scales = torch.rand(batch, 8, dtype=torch.float32).npu()
            layer_id = torch.randint(1, 20, (1,), dtype=torch.int32).npu()

            combined, next_layer_id = torch_npu.npu_attention_worker_combine(
                schedule_context, expert_scales, layer_id, hidden,
                token_dtype=0, need_schedule=0,
            )

            # Output rows follow expert_scales' first dim; columns follow the hidden size argument.
            self.assertTrue(combined.shape[0] == batch)
            self.assertTrue(combined.shape[1] == hidden)
            self.assertTrue(next_layer_id.shape[0] == 1)


class TestNpuMoeTokenPermuteWithRoutingMap(TestCase):
    """Meta-kernel forward/backward checks for ``torch_npu.npu_moe_token_permute_with_routing_map``."""

    def test_npu_moe_token_permute_with_routing_map_meta(self):
        with FakeTensorMode():
            x = torch.randn((3, 4), dtype=torch.float)
            x.requires_grad = True
            routing_map = torch.tensor([[True, True], [True, True], [True, True]], dtype=torch.bool)
            numtoken = 6
            padMode = True
            probs = torch.randn([3, 2], dtype=torch.float)
            probs.requires_grad = True
            # Detach so the npu tensors are fresh leaves whose .grad is populated by backward().
            x_npu = x.npu().detach()
            x_npu.requires_grad = True
            routing_map_npu = routing_map.npu()
            probs_npu = probs.npu().detach()
            probs_npu.requires_grad = True

            # numtoken rounded down to a multiple of routing_map.size(1) (presumably the expert count).
            out_token = numtoken // routing_map.size(1) * routing_map.size(1)
            output0 = torch.empty(out_token, x_npu.size(1), dtype=x_npu.dtype, device=x_npu.device)
            output1 = torch.empty(out_token, dtype=probs_npu.dtype, device=probs_npu.device)
            output2 = torch.empty(out_token, dtype=torch.int32, device=x_npu.device)

            x1, x2, x3 = torch_npu.npu_moe_token_permute_with_routing_map(x_npu, routing_map_npu, probs=probs_npu, num_out_tokens=numtoken, drop_and_pad=padMode)
            self.assertEqual(x1.dtype, output0.dtype)
            self.assertEqual(x2.dtype, output1.dtype)
            self.assertEqual(x3.dtype, output2.dtype)
            self.assertEqual(x1.shape, output0.shape)
            self.assertEqual(x2.shape, output1.shape)
            self.assertEqual(x3.shape, output2.shape)

            (x1.sum() + x2.sum()).backward()
            # Gradient expectations come from the npu leaves that actually receive the gradients.
            expected_x_grad = torch.empty_like(x_npu)
            expected_probs_grad = torch.empty_like(probs_npu)
            self.assertEqual(x_npu.grad.dtype, expected_x_grad.dtype)
            self.assertEqual(probs_npu.grad.dtype, expected_probs_grad.dtype)
            self.assertEqual(x_npu.grad.shape, x_npu.shape)
            self.assertEqual(probs_npu.grad.shape, probs_npu.shape)


# Expand the @parametrize-decorated methods of FakeTensorTest into concrete test methods.
instantiate_parametrized_tests(FakeTensorTest)
# Generate device-specific copies of FakeTensorOpInfoTest into this module's globals(),
# restricted to the CPU device type.
instantiate_device_type_tests(FakeTensorOpInfoTest, globals(), only_for="cpu")


class TestGroupedMatmulSwigluQuant(TestCase):
    """Fake-tensor shape/dtype checks for ``torch_npu.npu_grouped_matmul_swiglu_quant``."""

    def _assert_npu_grouped_matmul_swiglu_quant_shape(self, x, weight, groupList, weightScale, xScale, output_shape):
        """Run the op on npu copies of the inputs and check all three outputs.

        Expects the first output to have ``output_shape`` and dtype int8, the second to be a
        float32 vector of length ``x.size(0)``, and the third to be a float32 scalar.
        """
        output0_npu, output1_npu, output2_npu = torch_npu.npu_grouped_matmul_swiglu_quant(
            x.npu(), weight.npu(), groupList.npu(), weightScale.npu(), xScale.npu(), bias=None, offset=None)
        self.assertTrue(output0_npu.shape == output_shape)
        self.assertTrue(output0_npu.dtype == torch.int8)
        self.assertTrue(output1_npu.shape == torch.Size([x.size(0)]))
        self.assertTrue(output1_npu.dtype == torch.float32)
        self.assertTrue(output2_npu.shape == torch.Size([]))
        self.assertTrue(output2_npu.dtype == torch.float32)

    def test_npu_grouped_matmul_swiglu_quant(self):
        with FakeTensorMode():
            # ND weight of shape (E, K, N); the swiglu output has N // 2 columns.
            E = 16
            M = 512
            K = 7168
            N = 4096
            x = torch.randint(-128, 127, (M, K), dtype=torch.int8)
            weight = torch.randint(-128, 127, (E, K, N), dtype=torch.int8)
            weightScale = torch.randn(E, N)
            xScale = torch.randn(M)
            groupList = torch.randn(E)
            self._assert_npu_grouped_matmul_swiglu_quant_shape(
                x, weight, groupList, weightScale, xScale, torch.Size([M, N // 2]))

    def test_npu_grouped_matmul_swiglu_quant_nz_shape(self):
        with FakeTensorMode():
            # A8W8 NZ: logical N is carried by weightScale, not by the physical weight shape.
            E = 2
            M = 8
            K = 64
            N = 128
            x = torch.randint(-128, 127, (M, K), dtype=torch.int8)
            weight = torch.randint(-128, 127, (E, N // 32, K // 16, 16, 32), dtype=torch.int8)
            weightScale = torch.randn(E, N)
            xScale = torch.randn(M)
            groupList = torch.zeros(E, dtype=torch.int64)
            self._assert_npu_grouped_matmul_swiglu_quant_shape(
                x, weight, groupList, weightScale, xScale, torch.Size([M, N // 2]))

    def test_npu_grouped_matmul_swiglu_quant_nz_int32_pack_shape(self):
        with FakeTensorMode():
            # A8W4 INT32 packed NZ: physical weight.size(2) is K/16, so N must come from weightScale.
            E = 2
            M = 8
            K = 64
            N = 128
            x = torch.randint(-128, 127, (M, K), dtype=torch.int8)
            weight = torch.randint(-128, 127, (E, N // 64, K // 16, 16, 8), dtype=torch.int32)
            weightScale = torch.randn(E, N)
            xScale = torch.randn(M)
            groupList = torch.zeros(E, dtype=torch.int64)
            self._assert_npu_grouped_matmul_swiglu_quant_shape(
                x, weight, groupList, weightScale, xScale, torch.Size([M, N // 2]))

    def test_npu_grouped_matmul_swiglu_quant_nz_per_group_shape(self):
        with FakeTensorMode():
            # A8W4 per-group NZ: weightScale has shape (E, K_group_num, N), and N is the tail axis.
            E = 2
            M = 8
            K = 64
            N = 128
            k_group_num = 2
            x = torch.randint(-128, 127, (M, K), dtype=torch.int8)
            weight = torch.randint(-128, 127, (E, N // 64, K // 16, 16, 8), dtype=torch.int32)
            weightScale = torch.randn(E, k_group_num, N)
            xScale = torch.randn(M)
            groupList = torch.zeros(E, dtype=torch.int64)
            self._assert_npu_grouped_matmul_swiglu_quant_shape(
                x, weight, groupList, weightScale, xScale, torch.Size([M, N // 2]))


class TestGroupedMatmulSwigluQuantV2(TestCase):
    """Fake-tensor shape checks for ``torch_npu.npu_grouped_matmul_swiglu_quant_v2``."""

    def _assert_npu_grouped_matmul_swiglu_quant_v2_shape(
            self, x, weight, weight_scale, x_scale, group_list, output_shape, output_scale_shape, **kwargs):
        """Run the op on npu copies of the inputs and compare both output shapes.

        ``weight`` and ``weight_scale`` are passed as single-element lists; every extra
        keyword argument is forwarded to the op unchanged.
        """
        npu_x = x.npu()
        npu_weights = [weight.npu()]
        npu_weight_scales = [weight_scale.npu()]
        npu_x_scale = x_scale.npu()
        npu_group_list = group_list.npu()
        produced, produced_scale = torch_npu.npu_grouped_matmul_swiglu_quant_v2(
            npu_x, npu_weights, npu_weight_scales, npu_x_scale, npu_group_list, **kwargs)
        self.assertTrue(produced.shape == output_shape)
        self.assertTrue(produced_scale.shape == output_scale_shape)

    @unittest.skip("npu_format_cast has no fake impl; V2 meta shape regressions are covered by the cases below.")
    def test_npu_grouped_matmul_swiglu_quant_v2(self):
        with FakeTensorMode():
            experts, rows, depth, cols = 16, 512, 7168, 4096
            x = torch.randint(-128, 127, (rows, depth), dtype=torch.int8)
            weight = torch.randint(-128, 127, (experts, depth, cols), dtype=torch.int8)
            # Format id 29 — presumably the FRACTAL_NZ layout; confirm against torch_npu format ids.
            weight_list = [torch_npu.npu_format_cast(weight.npu(), 29)]
            weight_scale = torch.randn(experts, cols)
            x_scale = torch.randn(rows)
            group_list = torch.randn(experts)
            expected_out = torch.empty([rows, cols // 2], dtype=torch.int8, device=x.device)
            expected_scale = torch.empty([rows], dtype=torch.float32, device=x.device)
            produced, produced_scale = torch_npu.npu_grouped_matmul_swiglu_quant_v2(
                x.npu(), weight_list, [weight_scale.npu()], x_scale.npu(), group_list.npu(),
                smooth_scale=None,
                weight_assist_matrix=None,
                bias=None,
                dequant_mode=0,
                dequant_dtype=6,
                quant_mode=0,
                quant_dtype=1,
                group_list_type=0,
                tuning_config=None,
            )
            self.assertTrue(produced.shape == expected_out.shape)
            self.assertTrue(produced.dtype == expected_out.dtype)
            self.assertTrue(produced_scale.shape == expected_scale.shape)
            self.assertTrue(produced_scale.dtype == expected_scale.dtype)

    def test_npu_grouped_matmul_swiglu_quant_v2_nz_int32_pack_shape(self):
        with FakeTensorMode():
            # A8W4 INT32 packed NZ: physical weight.size(2) is N/8, so meta must infer N from weightScale.
            experts, rows, depth, cols = 2, 8, 64, 512
            x = torch.randint(-128, 127, (rows, depth), dtype=torch.int8)
            weight = torch.randint(-128, 127, (experts, depth, cols // 8), dtype=torch.int32)
            weight_scale = torch.randint(0, 10, (experts, cols), dtype=torch.int64)
            weight_assist = torch.randn(experts, cols)
            x_scale = torch.randn(rows)
            group_list = torch.zeros(experts, dtype=torch.int64)
            self._assert_npu_grouped_matmul_swiglu_quant_v2_shape(
                x, weight, weight_scale, x_scale, group_list,
                torch.Size([rows, cols // 2]), torch.Size([rows]),
                weight_assist_matrix=[weight_assist.npu()],
                dequant_mode=0,
                dequant_dtype=6,
                quant_dtype=1,
            )

    def test_npu_grouped_matmul_swiglu_quant_v2_nz_mx_shape(self):
        with FakeTensorMode():
            # MX NZ: weightScale is 4D and its tail axis is fixed to 2; logical N is dim2 for non-transpose.
            experts, rows, depth, cols = 2, 8, 128, 512
            scale_blocks = math.ceil(depth / 64)
            x = torch.empty((rows, depth), dtype=torch.float8_e4m3fn)
            weight = torch.empty((experts, depth, cols // 8), dtype=torch.float8_e4m3fn)
            weight_scale = torch.empty((experts, scale_blocks, cols, 2), dtype=torch.uint8)
            x_scale = torch.empty((rows, scale_blocks, 2), dtype=torch.uint8)
            group_list = torch.zeros(experts, dtype=torch.int64)
            self._assert_npu_grouped_matmul_swiglu_quant_v2_shape(
                x, weight, weight_scale, x_scale, group_list,
                torch.Size([rows, cols // 2]), torch.Size([rows, math.ceil((cols // 2) / 64), 2]),
                dequant_mode=2,
                dequant_dtype=6,
                quant_mode=2,
                quant_dtype=24,
                weight_scale_dtype=torch_npu.float8_e8m0fnu,
                x_scale_dtype=torch_npu.float8_e8m0fnu,
            )

    def _assert_npu_grouped_matmul_swiglu_quant_v2_ascend950_mxfp4_shape(
            self, case_name, e, m, k, n, output_tensor_dtype, group_list_type,
            expected_output_shape, expected_output_scale_shape):
        """Build uint8 MX inputs for one Ascend950 case and check the op's output shapes.

        ``output_tensor_dtype`` selects the ``quant_dtype`` passed to the op; the check runs
        inside ``subTest(case=case_name)`` so each case is reported separately.
        """
        quant_dtype_by_output = {
            "float8_e5m2": 23,
            "float8_e4m3fn": 24,
            "float4_e2m1": torch_npu.float4_e2m1fn_x2,
        }
        # Without native torch FP4, public FP4 output is represented as packed uint8 (last dim halved).
        packed_fp4 = output_tensor_dtype == "float4_e2m1" and not hasattr(torch, "float4_e2m1fn_x2")
        if packed_fp4:
            first_dim, last_dim = expected_output_shape[0], expected_output_shape[1]
            expected_output_shape = (first_dim, last_dim // 2)
        scale_blocks = math.ceil(k / 64)
        x = torch.empty((m, k), dtype=torch.uint8)
        weight = torch.empty((e, k, n), dtype=torch.uint8)
        weight_scale = torch.empty((e, scale_blocks, n, 2), dtype=torch.uint8)
        x_scale = torch.empty((m, scale_blocks, 2), dtype=torch.uint8)
        group_list = torch.zeros(e, dtype=torch.int64)
        with self.subTest(case=case_name):
            self._assert_npu_grouped_matmul_swiglu_quant_v2_shape(
                x, weight, weight_scale, x_scale, group_list,
                torch.Size(expected_output_shape), torch.Size(expected_output_scale_shape),
                dequant_mode=2,
                dequant_dtype=6,
                quant_mode=2,
                quant_dtype=quant_dtype_by_output[output_tensor_dtype],
                group_list_type=group_list_type,
                tuning_config=[0],
                weight_scale_dtype=torch_npu.float8_e8m0fnu,
                x_scale_dtype=torch_npu.float8_e8m0fnu,
                x_dtype=torch_npu.float4_e2m1fn_x2,
                weight_dtype=torch_npu.float4_e2m1fn_x2,
            )

    @unittest.skipIf(
        not hasattr(torch_npu, "float4_e2m1fn_x2"),
        "torch_npu float4 dtype wrappers are required for Ascend950 MXFP4 fake shape cases.")
    def test_npu_grouped_matmul_swiglu_quant_v2_ascend950_mxfp4_shape(self):
        with FakeTensorMode():
            # Ascend950 MXFP4 representative shapes: CANN currently supports FLOAT4_E2M1 FP4 input.
            # Fields: (name, e, m, k, n, output dtype, group_list_type, output shape, output scale shape)
            cases = (
                ("fp8_e4m3fn_unaligned_shape", 2, 1775, 980, 4352, "float8_e4m3fn", 0, (1775, 2176), (1775, 34, 2)),
                ("fp8_e5m2_large_n", 4, 2048, 1024, 7168, "float8_e5m2", 0, (2048, 3584), (2048, 56, 2)),
                ("fp4_e2m1_grouped_shape", 4, 534, 130, 3200, "float4_e2m1", 1, (534, 1600), (534, 25, 2)),
                ("fp4_e2m1_large_n", 16, 2048, 128, 7168, "float4_e2m1", 1, (2048, 3584), (2048, 56, 2)),
                ("fp4_e2m1_large_shape", 4, 7168, 4096, 4096, "float4_e2m1", 1, (7168, 2048), (7168, 32, 2)),
            )
            for params in cases:
                self._assert_npu_grouped_matmul_swiglu_quant_v2_ascend950_mxfp4_shape(*params)


@unittest.skip("skip until CANN is updated to support aclnnDynamicBlockQuant")
class TestNpuDynamicBlockQuant(TestCase):
    def test_npu_dynamic_block_quant_meta(self):
        """Compare real-device and fake-tensor output shapes of npu_dynamic_block_quant.

        For each input shape the op runs once on a real NPU tensor and once under
        FakeTensorMode; both the quantized output and the scale must have matching shapes.
        """
        # Rank-2 input first, then rank-3.
        for input_shape in ((3, 4), (3, 4, 5)):
            real_input = torch.rand(*input_shape).to("npu").to(torch.float16)
            real_y, real_scale = torch_npu.npu_dynamic_block_quant(real_input, dst_type=1)
            with FakeTensorMode():
                fake_input = torch.rand(*input_shape).to("npu").to(torch.float16)
                fake_y, fake_scale = torch_npu.npu_dynamic_block_quant(fake_input, dst_type=1)
            self.assertEqual(real_y.shape, fake_y.shape)
            self.assertEqual(real_scale.shape, fake_scale.shape)


class TestQuantGroupedMatmulInplaceAdd(TestCase):
    def _check_add_quant_gmm_result(self, res, y, x1, x2):
        """Assert that ``res`` has the float32 (e, m, n) layout implied by the inputs.

        Per the test's layout: x1 is (k, m) (it is passed transposed), x2 is (e, k, n)
        and y is (e, m, n).
        """
        self.assertEqual(len(res.shape), 3)
        self.assertEqual(res.shape[1], x1.shape[1])
        self.assertEqual(res.shape[0], x2.shape[0])
        self.assertEqual(res.shape[2], x2.shape[2])
        self.assertEqual(res.dtype, torch.float32)
        # res is rank 3 (checked above), so this compares every axis with y.
        self.assertEqual(res.shape, y.shape)

    def test_npu_quant_grouped_matmul_inplace_add_meta_0(self):
        """Check fake-tensor shape/dtype of npu_add_quant_gmm_ (in-place) and npu_add_quant_gmm."""
        with FakeTensorMode():
            torch.manual_seed(0)
            y = torch.randint(-1, 1, (4, 576, 7168), dtype=torch.float32).npu()
            x1 = torch.randint(-1, 1, (512, 576), dtype=torch.int8).npu()
            x2 = torch.randint(-1, 1, (4, 512, 7168), dtype=torch.int8).npu()
            x2_scale = torch.randint(-1, 1, (4, 7168), dtype=torch.float32).npu()
            x1_scale = torch.randint(-1, 1, (4,), dtype=torch.float32).npu()
            group_list = torch.Tensor([8, 181, 415, 512]).to(torch.int64).npu()

            res_1 = torch_npu.npu_add_quant_gmm_(y, x1.t(), x2, x2_scale, group_list, x1_scale=x1_scale,
                                                 group_list_type=0, group_sizes=None,
                                                 x1_dtype=torch_npu.hifloat8, x2_dtype=torch_npu.hifloat8)
            self._check_add_quant_gmm_result(res_1, y, x1, x2)

            res_2 = torch_npu.npu_add_quant_gmm(y, x1.t(), x2, x2_scale, group_list, x1_scale=x1_scale,
                                                group_list_type=0, group_sizes=None,
                                                x1_dtype=torch_npu.hifloat8, x2_dtype=torch_npu.hifloat8)
            self._check_add_quant_gmm_result(res_2, y, x1, x2)


@unittest.skip("skip until CANN is updated to support aclnnMlaPrologV3WeightNz")
class TestMlaProLogV3(TestCase):
    def testMlaProLogV3(self):
        """Check fake-tensor output shapes of npu_mla_prolog_v3 called with only required inputs."""
        with FakeTensorMode():
            block_num = math.ceil(2 * 6144 / 128)
            token_x = torch.randint(-100, 100, (2, 1, 7168), dtype=torch.int8).npu()
            w_dq = torch.randint(-100, 100, (7168, 1536), dtype=torch.int8).npu()
            w_uq_qr = torch.randint(-100, 100, (1536, 32 * (128 + 64)), dtype=torch.int8).npu()
            w_uk = torch.rand(32, 128, 512, dtype=torch.bfloat16).npu()
            w_dkv_kr = torch.randint(-100, 100, (7168, 512 + 64), dtype=torch.int8).npu()
            rmsnorm_gamma_cq = torch.rand(1536, dtype=torch.bfloat16).npu()
            rmsnorm_gamma_ckv = torch.rand(512, dtype=torch.bfloat16).npu()
            rope_sin = torch.rand(2, 1, 64, dtype=torch.bfloat16).npu()
            rope_cos = torch.rand(2, 1, 64, dtype=torch.bfloat16).npu()
            kv_cache = torch.randint(-100, 100, (1, block_num * 128 * 1 * 512), dtype=torch.int8).npu()
            kr_cache = torch.rand(1, block_num * 128 * 1 * 64, dtype=torch.bfloat16).npu()

            # query: (token_x dim 0, token_x dim 1, w_uk dim 0, w_uk dim 2)
            query_shape = [token_x.size(0), token_x.size(1), w_uk.size(0), w_uk.size(2)]
            # query_rope: (token_x dim 0, token_x dim 1, w_uk dim 0, rope_sin dim 2)
            query_rope_shape = [token_x.size(0), token_x.size(1), w_uk.size(0), rope_sin.size(2)]

            # Expected outputs; the dequant scales and query_norm are expected to be empty ([0]).
            # NOTE(review): dequant_scale_q_norm is expected empty rather than (B*S, 1) here;
            # confirm this matches the op's behavior when the optional inputs are omitted.
            query = torch.empty(query_shape, dtype=rope_sin.dtype, device='meta')
            dequant_scale_q_nope = torch.empty([0], dtype=torch.float32, device='meta')
            query_norm = torch.empty([0], dtype=w_uq_qr.dtype, device='meta')
            dequant_scale_q_norm = torch.empty([0], dtype=torch.float32, device='meta')
            query_rope = torch.empty(query_rope_shape, dtype=torch.bfloat16, device='meta')

            (query_mla, query_rope_mla, dequant_scale_q_nope_mla,
             query_norm_mla, dequant_scale_q_norm_mla) = torch_npu.npu_mla_prolog_v3(
                token_x, w_dq, w_uq_qr, w_uk, w_dkv_kr, rmsnorm_gamma_cq,
                rmsnorm_gamma_ckv, rope_sin, rope_cos, kv_cache, kr_cache)

            self.assertEqual(query_mla.shape, query.shape)
            self.assertEqual(query_rope_mla.shape, query_rope.shape)
            self.assertEqual(dequant_scale_q_nope_mla.shape, dequant_scale_q_nope.shape)
            self.assertEqual(query_norm_mla.shape, query_norm.shape)
            self.assertEqual(dequant_scale_q_norm_mla.shape, dequant_scale_q_norm.shape)


@unittest.skip("skip until CANN is updated to support aclnnMlaPrologV3WeightNz")
class TestMlaProLogV3Functional(TestCase):
    def testMlaProLogV3Functional(self):
        """Check fake-tensor output shapes of npu_mla_prolog_v3_functional, including returned caches."""
        with FakeTensorMode():
            block_num = math.ceil(2 * 6144 / 128)
            token_x = torch.randint(-100, 100, (2, 1, 7168), dtype=torch.int8).npu()
            w_dq = torch.randint(-100, 100, (7168, 1536), dtype=torch.int8).npu()
            w_uq_qr = torch.randint(-100, 100, (1536, 32 * (128 + 64)), dtype=torch.int8).npu()
            w_uk = torch.rand(32, 128, 512, dtype=torch.bfloat16).npu()
            w_dkv_kr = torch.randint(-100, 100, (7168, 512 + 64), dtype=torch.int8).npu()
            rmsnorm_gamma_cq = torch.rand(1536, dtype=torch.bfloat16).npu()
            rmsnorm_gamma_ckv = torch.rand(512, dtype=torch.bfloat16).npu()
            rope_sin = torch.rand(2, 1, 64, dtype=torch.bfloat16).npu()
            rope_cos = torch.rand(2, 1, 64, dtype=torch.bfloat16).npu()
            kv_cache = torch.randint(-100, 100, (1, block_num * 128 * 1 * 512), dtype=torch.int8).npu()
            kr_cache = torch.rand(1, block_num * 128 * 1 * 64, dtype=torch.bfloat16).npu()

            # query: (token_x dim 0, token_x dim 1, w_uk dim 0, w_uk dim 2)
            query_shape = [token_x.size(0), token_x.size(1), w_uk.size(0), w_uk.size(2)]
            # query_rope: (token_x dim 0, token_x dim 1, w_uk dim 0, rope_sin dim 2)
            query_rope_shape = [token_x.size(0), token_x.size(1), w_uk.size(0), rope_sin.size(2)]

            query = torch.empty(query_shape, dtype=rope_sin.dtype, device='meta')
            dequant_scale_q_nope = torch.empty([0], dtype=torch.float32, device='meta')
            query_rope = torch.empty(query_rope_shape, dtype=torch.bfloat16, device='meta')

            # The functional variant also returns the (updated) kv/kr caches; the two query_norm
            # outputs are not checked here.
            (query_mla, query_rope_mla, dequant_scale_q_nope_mla, _, _,
             kv_cache_mla, kr_cache_mla) = torch_npu.npu_mla_prolog_v3_functional(
                token_x, w_dq, w_uq_qr, w_uk, w_dkv_kr, rmsnorm_gamma_cq,
                rmsnorm_gamma_ckv, rope_sin, rope_cos, kv_cache, kr_cache)

            self.assertEqual(query_mla.shape, query.shape)
            self.assertEqual(query_rope_mla.shape, query_rope.shape)
            self.assertEqual(dequant_scale_q_nope_mla.shape, dequant_scale_q_nope.shape)
            self.assertEqual(kv_cache_mla.shape, kv_cache.shape)
            self.assertEqual(kr_cache_mla.shape, kr_cache.shape)

class TestGatherPAKVCacheFunctional(TestCase):
    def test_npu_gather_pa_kv_cache_functional_meta(self):
        """Check that npu_gather_pa_kv_cache_functional returns tensors shaped like key_out/value_out."""
        with FakeTensorMode():
            batch_size, num_blocks, head_num = 2, 8, 4
            block_size, head_dim = 64, 64
            max_blocks_per_sequence = 5
            # The gathered outputs hold one row per token across all sequences.
            total_tokens = sum([5, 8])

            # NOTE(review): unlike most tests in this file these fake tensors are not moved to
            # 'npu'; confirm that is intended.
            key_cache = torch.empty((num_blocks, block_size, head_num, head_dim), dtype=torch.float16)
            value_cache = torch.empty_like(key_cache)
            block_tables = torch.empty((batch_size, max_blocks_per_sequence), dtype=torch.int32)
            seq_lens = torch.empty((batch_size,), dtype=torch.int32)
            key_out = torch.empty((total_tokens, head_num, head_dim), dtype=torch.float16)
            value_out = torch.empty_like(key_out)
            seq_offset = torch.empty((batch_size,), dtype=torch.int32)

            key_result, value_result = torch_npu.npu_gather_pa_kv_cache_functional(
                key_cache, value_cache, block_tables, seq_lens, key_out, value_out,
                seq_offset=seq_offset, is_seq_lens_cumsum=False)

            self.assertEqual(key_result.shape, key_out.shape)
            self.assertEqual(value_result.shape, value_out.shape)


class TestGroupedMatmulAdd(TestCase):
    def test_npu_grouped_matmul_add_meta_0(self):
        """Check fake-tensor shape/dtype of npu_grouped_matmul_add_ with a transposed x."""
        with FakeTensorMode():
            torch.manual_seed(0)
            y = torch.randint(-1, 1, (4, 576, 7168), dtype=torch.float32).npu()
            x = torch.randint(-1, 1, (512, 576), dtype=torch.float16).npu()
            weight = torch.randint(-1, 1, (512, 7168), dtype=torch.float16).npu()

            group_list = torch.Tensor([8, 181, 415, 512]).to(torch.int64).npu()

            res_1 = torch_npu.npu_grouped_matmul_add_(y, x, weight, group_list, transpose_x=True,
                                                      transpose_weight=False, group_type=2, group_list_type=0)
            # x is (k, m), weight is (k, n) and y is (e, m, n).
            self.assertEqual(len(res_1.shape), 3)
            self.assertEqual(res_1.shape[1], x.shape[1])
            self.assertEqual(res_1.shape[2], weight.shape[1])
            self.assertEqual(res_1.dtype, torch.float32)
            # res_1 is rank 3 (checked above), so this compares every axis with y.
            self.assertEqual(res_1.shape, y.shape)


class TestQuantAllReduce(TestCase):
    def test_npu_quant_all_reduce(self):
        """Check that npu_quant_all_reduce preserves the input shape under FakeTensorMode."""
        with FakeTensorMode():
            torch.manual_seed(0)
            # T-G quantization: int8 x of shape (128, 5120) with float32 scales of shape (128, 40).
            x = torch.randint(-5, 5, (128, 5120), dtype=torch.int8).npu()
            scales = torch.randint(-5, 5, (128, 40), dtype=torch.float32).npu()
            # Other parameters (placeholder communication group).
            hcom = "fake group info"
            world_size = 2
            res = torch_npu.npu_quant_all_reduce(x, scales, hcom, world_size)
            self.assertEqual(len(res.shape), len(x.shape))
            # bs (first) axis is unchanged.
            self.assertEqual(res.shape[0], x.shape[0])
            self.assertEqual(res.shape[1], x.shape[1])


class TestQuantReduceScatter(TestCase):
    def test_npu_quant_reduce_scatter(self):
        """Check that npu_quant_reduce_scatter splits the first axis across world_size ranks."""
        with FakeTensorMode():
            torch.manual_seed(0)
            # T-G quantization: int8 x of shape (128, 5120) with float32 scales of shape (128, 40).
            x = torch.randint(-5, 5, (128, 5120), dtype=torch.int8).npu()
            scales = torch.randint(-5, 5, (128, 40), dtype=torch.float32).npu()
            # Other parameters (placeholder communication group).
            hcom = "fake group info"
            world_size = 2
            # Assertions
            res = torch_npu.npu_quant_reduce_scatter(x, scales, hcom, world_size)
            self.assertEqual(len(res.shape), 2)
            self.assertEqual(res.shape[0], x.shape[0] // world_size)
            self.assertEqual(res.shape[1], x.shape[1])


class TestMatmulAlltoAll(TestCase):
    def test_npu_matmul_all_to_all(self):
        """Check fake-tensor shape/dtype of npu_matmul_all_to_all."""
        # Fake mode: fake tensors replace real ones, so tensor ops run without occupying device memory.
        with FakeTensorMode():
            # Set the random seed.
            torch.manual_seed(0)
            # Initialize inputs.
            x1 = torch.randint(-1, 2, (16, 32), dtype=torch.float16).npu()
            x2 = torch.randint(-1, 2, (32, 32), dtype=torch.float16).npu()
            bias = torch.randint(-1, 2, (32,), dtype=torch.float16).npu()
            # Other parameters (placeholder communication group).
            hcom = "fake group info"
            world_size = 2
            # Assertions
            res = torch_npu.npu_matmul_all_to_all(x1, x2, hcom, world_size, bias=bias, all2all_axes=[-1, -2])
            # Fake tensors carry no data, so only shape and dtype are checked.
            self.assertEqual(len(res.shape), 2)
            self.assertEqual(res.shape[0], x1.shape[0] * world_size)
            # Integer division keeps the expected value an int, like the shape dimension.
            self.assertEqual(res.shape[1], x2.shape[1] // world_size)
            self.assertEqual(res.dtype, x1.dtype)


class TestQuantMatmulAlltoAll(TestCase):
    def test_npu_quant_matmul_all_to_all(self):
        """Check fake-tensor shape/dtype of npu_quant_matmul_all_to_all with float8_e4m3fn inputs."""
        # Fake mode: fake tensors replace real ones, so tensor ops run without occupying device memory.
        with FakeTensorMode():
            # Set the random seed.
            torch.manual_seed(0)
            # Initialize inputs.
            x1 = torch.randint(-1, 2, (16, 32), dtype=torch.float8_e4m3fn).npu()
            x2 = torch.randint(-1, 2, (32, 32), dtype=torch.float8_e4m3fn).npu()
            bias = torch.randint(-1, 2, (32,), dtype=torch.float32).npu()
            x1_scale = torch.randint(-1, 2, (32,), dtype=torch.float32).npu()
            x2_scale = torch.randint(-1, 2, (32,), dtype=torch.float32).npu()
            # Other parameters (placeholder communication group).
            hcom = "fake group info"
            world_size = 2
            common_scale = None
            x1_offset = None
            x2_offset = None
            x1_quant_mode = 3
            x2_quant_mode = 2
            common_quant_mode = 0
            group_size = [0]
            all2all_axes = [-1, -2]
            comm_quant_dtype = -1
            x1_dtype = None
            x2_dtype = None
            x1_scale_dtype = None
            x2_scale_dtype = None
            output_scale_dtype = None
            comm_scale_dtype = None
            y_dtype = None

            # Assertions
            res = torch_npu.npu_quant_matmul_all_to_all(
                x1, x2, hcom, world_size, bias, x1_scale, x2_scale, common_scale, x1_offset, x2_offset,
                x1_quant_mode, x2_quant_mode, common_quant_mode, group_size, all2all_axes, comm_quant_dtype,
                x1_dtype, x2_dtype, x1_scale_dtype, x2_scale_dtype, output_scale_dtype, comm_scale_dtype,
                y_dtype)
            # Fake tensors carry no data, so only shape and dtype are checked.
            self.assertEqual(len(res.shape), 2)
            self.assertEqual(res.shape[0], x1.shape[0] * world_size)
            # Integer division keeps the expected value an int, like the shape dimension.
            self.assertEqual(res.shape[1], x2.shape[1] // world_size)
            self.assertEqual(res.dtype, torch.float32)


class TestAlltoAllMatmul(TestCase):
    def test_npu_all_to_all_matmul(self):
        """Check fake-tensor shape/dtype of npu_all_to_all_matmul."""
        # Fake mode: fake tensors replace real ones, so tensor ops run without occupying device memory.
        with FakeTensorMode():
            # Set the random seed.
            torch.manual_seed(0)
            # Initialize inputs.
            x1 = torch.randint(-1, 2, (16, 16), dtype=torch.float16).npu()
            x2 = torch.randint(-1, 2, (32, 32), dtype=torch.float16).npu()
            bias = torch.randint(-1, 2, (32,), dtype=torch.float16).npu()
            # Other parameters (placeholder communication group).
            hcom = "fake group info"
            world_size = 2
            # Assertions
            res, _ = torch_npu.npu_all_to_all_matmul(x1, x2, hcom, world_size, bias=bias,
                                                     all2all_axes=[-2, -1], all2all_out_flag=True)
            # Fake tensors carry no data, so only shape and dtype are checked.
            self.assertEqual(len(res.shape), 2)
            # Integer division keeps the expected value an int, like the shape dimension.
            self.assertEqual(res.shape[0], x1.shape[0] // world_size)
            self.assertEqual(res.shape[1], x2.shape[1])
            self.assertEqual(res.dtype, x1.dtype)

class TestAlltoAllQuantMatmul(TestCase):
    def test_npu_all_to_all_quant_matmul(self):
        """Check fake-tensor output shape of npu_all_to_all_quant_matmul with a float8_e4m3fn x2."""
        # Fake mode: fake tensors replace real ones, so tensor ops run without occupying device memory.
        with FakeTensorMode():
            # Set the random seed.
            torch.manual_seed(0)
            # Initialize inputs.
            x1 = torch.randint(-1, 2, (16, 16), dtype=torch.float16).npu()
            x2 = torch.randint(-1, 2, (32, 32), dtype=torch.float8_e4m3fn).npu()
            bias = torch.randint(-1, 2, (32,), dtype=torch.float32).npu()
            x2_scale = torch.randint(-1, 2, (32,), dtype=torch.float32).npu()
            # Other parameters (placeholder communication group).
            hcom = "fake group info"
            world_size = 2
            # Assertions
            res, _ = torch_npu.npu_all_to_all_quant_matmul(x1, x2, hcom, world_size, bias=bias,
                                                           all2all_axes=[-2, -1], all2all_out_flag=True,
                                                           x2_scale=x2_scale)
            # Fake tensors carry no data, so only the shape is checked.
            self.assertEqual(len(res.shape), 2)
            # Integer division keeps the expected value an int, like the shape dimension.
            self.assertEqual(res.shape[0], x1.shape[0] // world_size)
            self.assertEqual(res.shape[1], x2.shape[1])


class TestNpuDSA(TestCase):
    def setup_sparse_flash_attention_test_params(self, requires_grad=False, return_softmax_lse=False):
        """Build NPU inputs and options for npu_sparse_flash_attention in the BSND layout.

        Args:
            requires_grad: when True, mark query, key, value, query_rope and key_rope as
                requiring grad.
            return_softmax_lse: forwarded as the ``return_softmax_lse`` option.

        Returns:
            dict of tensors and scalar options consumed by ``call_npu_sparse_flash_attention``.
        """
        scale_value = 0.041666666666666664
        sparse_block_size = 1
        query_type = torch.float16
        sparse_block_count = 2048
        b = 4
        s1 = 1
        s2 = 8192
        n1 = 128
        n2 = 1
        dn = 512
        dr = 64
        # Per-batch KV length passed as actual_seq_lengths_kv (s2 is the allocated length).
        s2_act = 4096
        attention_mode = 2

        layout = 'BSND'
        sparse_mode = 3

        query = torch.tensor(np.random.uniform(-10, 10, (b, s1, n1, dn))).to(query_type)
        key = torch.tensor(np.random.uniform(-5, 10, (b, s2, n2, dn))).to(query_type)
        value = key.clone()
        # One random set of sparse_block_count distinct positions from the first
        # s2_act - s1 + 1 KV positions, reused for every (b, s1, n2) row.
        idxs = random.sample(range(s2_act - s1 + 1), sparse_block_count)
        sparse_indices = torch.tensor([idxs for _ in range(b * s1 * n2)]).reshape(
            b, s1, n2, sparse_block_count).to(torch.int32)
        query_rope = torch.tensor(np.random.uniform(-10, 10, (b, s1, n1, dr))).to(query_type)
        key_rope = torch.tensor(np.random.uniform(-10, 10, (b, s2, n2, dr))).to(query_type)
        act_seq_q = torch.tensor([s1] * b).to(torch.int32)
        act_seq_kv = torch.tensor([s2_act] * b).to(torch.int32)

        query_npu = query.npu()
        key_npu = key.npu()
        value_npu = value.npu()
        sparse_indices_npu = sparse_indices.npu()
        query_rope_npu = query_rope.npu()
        key_rope_npu = key_rope.npu()
        act_seq_q_npu = act_seq_q.npu()
        act_seq_kv_npu = act_seq_kv.npu()

        if requires_grad:
            for grad_input in (query_npu, key_npu, value_npu, query_rope_npu, key_rope_npu):
                grad_input.requires_grad = True

        return {
            'query': query_npu,
            'key': key_npu,
            'value': value_npu,
            'sparse_indices': sparse_indices_npu,
            'query_rope': query_rope_npu,
            'key_rope': key_rope_npu,
            'act_seq_q': act_seq_q_npu,
            'act_seq_kv': act_seq_kv_npu,
            'scale_value': scale_value,
            'sparse_block_size': sparse_block_size,
            'layout': layout,
            'sparse_mode': sparse_mode,
            'attention_mode': attention_mode,
            'return_softmax_lse': return_softmax_lse
        }

    def call_npu_sparse_flash_attention(self, params):
        """Invoke torch_npu.npu_sparse_flash_attention with a dict from the setup helper.

        Query and KV share ``params['layout']``; pre/next tokens are both set to 2**63 - 1.
        """
        int64_max = (1 << 63) - 1
        return torch_npu.npu_sparse_flash_attention(
            params['query'],
            params['key'],
            params['value'],
            params['sparse_indices'],
            params['scale_value'],
            block_table=None,
            actual_seq_lengths_query=params['act_seq_q'],
            actual_seq_lengths_kv=params['act_seq_kv'],
            query_rope=params['query_rope'],
            key_rope=params['key_rope'],
            sparse_block_size=params['sparse_block_size'],
            layout_query=params['layout'],
            layout_kv=params['layout'],
            sparse_mode=params['sparse_mode'],
            pre_tokens=int64_max,
            next_tokens=int64_max,
            attention_mode=params['attention_mode'],
            return_softmax_lse=params['return_softmax_lse'],
        )

    def test_dsa_npu_sparse_flash_attention(self):
        """Check fake-tensor dtype/shape of npu_sparse_flash_attention outputs."""
        with FakeTensorMode():
            params = self.setup_sparse_flash_attention_test_params()
            npu_out, npu_softmax_max, npu_softmax_sum = self.call_npu_sparse_flash_attention(params)

            query, key = params['query'], params['key']
            # The attention output mirrors the query's 4-D shape and dtype.
            expected_out_shape = torch.Size([query.size(0), query.size(1), query.size(2), query.size(3)])
            # Softmax statistics: (batch, kv heads, query seq, query heads per kv head).
            expected_lse_shape = torch.Size(
                [query.size(0), key.size(2), query.size(1), query.size(2) // key.size(2)])

            self.assertEqual(npu_out.dtype, query.dtype)
            self.assertEqual(npu_out.shape, expected_out_shape)

            if params['return_softmax_lse']:
                for lse_part in (npu_softmax_max, npu_softmax_sum):
                    self.assertEqual(lse_part.dtype, torch.float32)
                    self.assertEqual(lse_part.shape, expected_lse_shape)

    @unittest.skip("skip sparse_flash_attention_grad now")
    def test_dsa_npu_sparse_flash_attention_grad(self):
        """Backprop through npu_sparse_flash_attention; each grad must match its input's dtype/shape."""
        with FakeTensorMode():
            params = self.setup_sparse_flash_attention_test_params(requires_grad=True)
            npu_out, npu_softmax_max, npu_softmax_sum = self.call_npu_sparse_flash_attention(params)
            total = npu_out.sum() + npu_softmax_max.sum() + npu_softmax_sum.sum()
            total.backward()

            for name in ('query', 'key', 'value', 'query_rope', 'key_rope'):
                leaf = params[name]
                self.assertEqual(leaf.grad.dtype, leaf.dtype)
                self.assertEqual(leaf.grad.shape, leaf.shape)

    def setup_lightning_indexer_test(self, requires_grad=False, return_value=False):
        """Build NPU inputs for npu_lightning_indexer (BSND query, PA_BSND block-laid-out key).

        Args:
            requires_grad: when True, mark query, key and weights as requiring grad.
            return_value: forwarded as the ``return_value`` option.

        Returns:
            dict of tensors and options consumed by ``call_npu_lightning_indexer``.
        """
        b = 1
        s1 = 1
        s2 = 8192
        n1 = 64
        n2 = 1
        d = 128
        block_size = 256
        layout_query = 'BSND'

        query = torch.tensor(np.random.uniform(-10, 10, (b, s1, n1, d))).to(torch.bfloat16)
        # Key is stored as (b * s2 // block_size blocks, block_size, n2, d).
        key = torch.tensor(np.random.uniform(-10, 10, (b * (s2 // block_size), block_size, n2, d))).to(torch.bfloat16)
        weights = torch.tensor(np.random.uniform(-1, 1, (b, s1, n1))).to(torch.bfloat16)
        # Every sequence uses its full length: s1 queries and s2 keys per batch.
        actual_seq_lengths_query = torch.full((b,), s1, dtype=torch.int32)
        actual_seq_lengths_key = torch.full((b,), s2, dtype=torch.int32)
        # Block ids 0 .. b * s2 // block_size - 1, split evenly across batches.
        block_table = torch.tensor([range(b * s2 // block_size)], dtype=torch.int32).reshape(b, -1)
        layout_key = 'PA_BSND'
        sparse_count = 2048
        sparse_mode = 3

        query = query.npu()
        key = key.npu()
        weights = weights.npu()
        actual_seq_lengths_query = actual_seq_lengths_query.npu()
        actual_seq_lengths_key = actual_seq_lengths_key.npu()
        block_table = block_table.npu()

        if requires_grad:
            for grad_input in (query, key, weights):
                grad_input.requires_grad = True

        return {
            'query': query,
            'key': key,
            'weights': weights,
            'actual_seq_lengths_query': actual_seq_lengths_query,
            'actual_seq_lengths_key': actual_seq_lengths_key,
            'block_table': block_table,
            'layout_query': layout_query,
            'layout_key': layout_key,
            'sparse_count': sparse_count,
            'sparse_mode': sparse_mode,
            'return_value': return_value
        }

    def call_npu_lightning_indexer(self, inputs):
        """Invoke torch_npu.npu_lightning_indexer with a dict built by ``setup_lightning_indexer_test``."""
        option_keys = (
            'actual_seq_lengths_query', 'actual_seq_lengths_key', 'block_table',
            'layout_query', 'layout_key', 'sparse_count', 'sparse_mode', 'return_value',
        )
        options = {key: inputs[key] for key in option_keys}
        return torch_npu.npu_lightning_indexer(inputs['query'], inputs['key'], inputs['weights'], **options)

    def test_dsa_npu_lightning_indexer(self):
        """Check fake-tensor dtype/shape of npu_lightning_indexer outputs."""
        with FakeTensorMode():
            params = self.setup_lightning_indexer_test()
            npu_out, npu_values_out = self.call_npu_lightning_indexer(params)

            query, key = params['query'], params['key']
            # Both outputs: (query batch, query seq, key heads, sparse_count).
            expected_shape = torch.Size([query.size(0), query.size(1), key.size(2), params['sparse_count']])

            self.assertEqual(npu_out.dtype, torch.int32)
            self.assertEqual(npu_out.shape, expected_shape)
            if params['return_value']:
                self.assertEqual(npu_values_out.dtype, query.dtype)
                self.assertEqual(npu_values_out.shape, expected_shape)

    @unittest.skip("skip lightning_indexer_grad now")
    def test_dsa_npu_lightning_indexer_grad(self):
        """Backprop through the indexer values; each grad must match its input's dtype/shape."""
        with FakeTensorMode():
            # NOTE(review): setup uses the default return_value=False, so the values output
            # may not carry indexer values; confirm before un-skipping.
            params = self.setup_lightning_indexer_test(requires_grad=True)
            _, npu_values_out = self.call_npu_lightning_indexer(params)
            npu_values_out.sum().backward()

            for name in ('query', 'key', 'weights'):
                leaf = params[name]
                self.assertEqual(leaf.grad.dtype, leaf.dtype)
                self.assertEqual(leaf.grad.shape, leaf.shape)

    def gen_npu_sparse_lightning_indexer_grad_kl_loss_inputs(self, seqlens_list_array, seqlens_list_kv_array, isTnd):
        """Build random NPU inputs for npu_sparse_lightning_indexer_grad_kl_loss.

        Args:
            seqlens_list_array: per-batch query sequence lengths (only read when ``isTnd``).
            seqlens_list_kv_array: per-batch key/value sequence lengths (only read when ``isTnd``).
            isTnd: when True, squeeze dim 0 of every tensor (B == 1) to get TND-layout inputs and
                fill ``sparse_indices`` token by token; otherwise return BSND-layout tensors.

        Returns:
            (q, k, q_index, k_index, q_rope, k_rope, weights, sparse_indices, softmax_max, softmax_sum)
        """
        B = 1
        NQuery = 64
        NQueryIndex = 64
        N2 = 1
        S1 = 128
        S2 = 128
        topK = 2048
        D = 512
        DIndex = 128
        DR = 64
        output_dtype = torch.float16
        npu_device = torch.device('npu')
        q = torch.randn(B, S1, NQuery, D, dtype=output_dtype, device=npu_device)
        k = torch.randn(B, S2, N2, D, dtype=output_dtype, device=npu_device)

        q_index = torch.randn(B, S1, NQueryIndex, DIndex, dtype=output_dtype, device=npu_device)
        k_index = torch.randn(B, S2, N2, DIndex, dtype=output_dtype, device=npu_device)
        # Rope tensors are only generated when a rope head dim is configured.
        if DR != 0:
            q_rope = torch.randn(B, S1, NQuery, DR, dtype=output_dtype, device=npu_device)
            k_rope = torch.randn(B, S2, N2, DR, dtype=output_dtype, device=npu_device)
        else:
            q_rope = None
            k_rope = None
        weights = torch.randn(B, S1, NQueryIndex, dtype=output_dtype, device=npu_device)
        # Rescale the standard-normal weights to mean (a + b) / 2 and std (b - a) / (2 * kk).
        a = -0.05  # minimum value
        b = 0.05  # maximum value
        kk = 3.0  # controls the spread (3 sigma covers the vast majority of values)
        scale = (b - a) / (2 * kk)
        shift = (a + b) / 2
        weights = weights * scale + shift
        if isTnd:
            # One row per query token across all batches.
            # NOTE(review): S1 rows are allocated; this assumes sum(seqlens_list_array) <= S1.
            sparse_indices = torch.zeros(S1, N2, topK).to(torch.int32).npu()
            tIdx = 0
            for bIdx in range(B):
                for s1Idx in range(seqlens_list_array[bIdx]):
                    # KV positions visible to this token: kv_len - q_len + s1Idx + 1,
                    # falling back to the full KV length when that is non-positive.
                    s2RealSize = int((seqlens_list_kv_array[bIdx] - seqlens_list_array[bIdx]) + s1Idx + 1)
                    if s2RealSize <= 0:
                        s2RealSize = seqlens_list_kv_array[bIdx]
                    s2RealLen = min(s2RealSize, topK)
                    # Random indices in [0, s2RealSize) fill the first s2RealLen slots; the
                    # remaining (invalid) slots are set to -1.
                    sparse_indices[tIdx, :, 0: s2RealLen] = (
                        torch.randint(0, s2RealSize, (s2RealLen,)).to(torch.int32)).npu()
                    sparse_indices[tIdx, :, s2RealLen: topK] = -1
                    tIdx += 1
            q_tnd = q.squeeze(dim=0)
            k_tnd = k.squeeze(dim=0)
            q_index_tnd = q_index.squeeze(dim=0)
            k_index_tnd = k_index.squeeze(dim=0)
            if q_rope is not None:
                q_rope_tnd = q_rope.squeeze(dim=0)
                k_rope_tnd = k_rope.squeeze(dim=0)
            else:
                q_rope_tnd = None
                k_rope_tnd = None
            weights_tnd = weights.squeeze(dim=0)

            softmax_max = torch.randn(N2, S1, NQueryIndex, dtype=torch.float, device=npu_device)
            softmax_sum = torch.randn(N2, S1, NQueryIndex, dtype=torch.float, device=npu_device)
            return q_tnd, k_tnd, q_index_tnd, k_index_tnd, q_rope_tnd, k_rope_tnd, weights_tnd, sparse_indices, softmax_max, softmax_sum
        else:
            sparse_indices = torch.zeros(B, S1, N2, topK).to(torch.int32).npu()
            for s1Idx in range(S1):
                # KV positions visible to query row s1Idx (full S2 when non-positive).
                s2RealSize = int(S2 - S1 + s1Idx + 1)
                if s2RealSize <= 0:
                    s2RealSize = S2
                s2RealLen = min(s2RealSize, topK)
                sparse_indices[:, s1Idx, 0, 0: s2RealLen] = (
                    torch.randint(0, s2RealSize, (s2RealLen,)).to(torch.int32)).npu()
                sparse_indices[:, s1Idx, 0, s2RealLen: topK] = -1

            softmax_max = torch.randn(B, N2, S1, NQueryIndex, dtype=torch.float, device=npu_device)
            softmax_sum = torch.randn(B, N2, S1, NQueryIndex, dtype=torch.float, device=npu_device)
            return q, k, q_index, k_index, q_rope, k_rope, weights, sparse_indices, softmax_max, softmax_sum

    def test_dsa_npu_sparse_lightning_indexer_grad_kl_loss(self):
        """Check fake-tensor dtype/shape of npu_sparse_lightning_indexer_grad_kl_loss outputs (TND layout)."""
        with FakeTensorMode():
            actual_seq_qlen = [128]
            actual_seq_kvlen = [128]
            (q, k, q_index, k_index, q_rope, k_rope, weights,
             sparse_indices, softmax_max, softmax_sum) = self.gen_npu_sparse_lightning_indexer_grad_kl_loss_inputs(
                actual_seq_qlen, actual_seq_kvlen, True)

            d_query_index, d_key_index, d_weights, loss = torch_npu.npu_sparse_lightning_indexer_grad_kl_loss(
                q, k, q_index, k_index, weights, sparse_indices, softmax_max, softmax_sum, 1.0,
                query_rope=q_rope, key_rope=k_rope, actual_seq_qlen=actual_seq_qlen,
                actual_seq_klen=actual_seq_kvlen, layout='TND', sparse_mode=3,
                pre_tokens=65536, next_tokens=65536)

            # Each gradient mirrors its input; the loss is a single float32 element.
            for grad, reference in ((d_query_index, q_index), (d_key_index, k_index), (d_weights, weights)):
                self.assertEqual(grad.dtype, reference.dtype)
                self.assertEqual(grad.shape, reference.shape)
            self.assertEqual(loss.dtype, torch.float32)
            self.assertEqual(loss.shape, torch.Size([1]))

    def gen_npu_dense_lightning_indexer_grad_kl_loss_inputs(self, isTnd):
        """Build NPU inputs for npu_dense_lightning_indexer_grad_kl_loss.

        Args:
            isTnd: when truthy, the leading batch dimension (size 1 here) is
                squeezed out of every activation tensor; otherwise the 4-D
                (B, S, N, D) tensors are returned as created.

        Returns:
            An 11-tuple ``(query, key, query_index, key_index, query_rope,
            key_rope, weights, softmax_max, softmax_sum, actual_seq_qlen,
            actual_seq_klen)``. The rope tensors are ``None`` when the rope
            dimension is 0. The sequence-length lists are ``[S1]`` / ``[S2]``
            when ``isTnd`` is truthy and ``None`` otherwise.
        """
        batch = 1
        heads = 64
        kv_heads = heads
        heads_index = 64
        kv_heads_index = 1
        seq_q = 128
        seq_kv = 256
        head_dim = 128
        rope_dim = 64
        half = torch.float16
        npu = torch.device('npu')

        def make(*shape, dt=half):
            return torch.randn(*shape, dtype=dt, device=npu)

        # Creation order matches the original so the RNG stream is consumed identically.
        query = make(batch, seq_q, heads, head_dim)
        key = make(batch, seq_kv, kv_heads, head_dim)
        query_index = make(batch, seq_q, heads_index, head_dim)
        key_index = make(batch, seq_kv, kv_heads_index, head_dim)
        has_rope = rope_dim != 0
        query_rope = make(batch, seq_q, heads, rope_dim) if has_rope else None
        key_rope = make(batch, seq_kv, kv_heads, rope_dim) if has_rope else None
        weights = make(batch, seq_q, heads_index)

        if not isTnd:
            softmax_max = (make(batch, kv_heads, seq_q, 1, dt=torch.float32).abs() + 0.4) * head_dim
            softmax_sum = torch.ones(batch, kv_heads, seq_q, 1, dtype=torch.float32, device=npu)
            return (query, key, query_index, key_index, query_rope, key_rope, weights,
                    softmax_max, softmax_sum, None, None)

        def drop_batch(t):
            return None if t is None else t.squeeze(dim=0)

        softmax_max = (make(kv_heads, seq_q, 1, dt=torch.float32).abs() + 0.4) * head_dim
        softmax_sum = torch.ones(kv_heads, seq_q, 1, dtype=torch.float32, device=npu)
        return (drop_batch(query), drop_batch(key), drop_batch(query_index), drop_batch(key_index),
                drop_batch(query_rope), drop_batch(key_rope), drop_batch(weights),
                softmax_max, softmax_sum, [seq_q], [seq_kv])

    def test_dsa_npu_dense_lightning_indexer_grad_kl_loss(self):
        """Fake-tensor check for npu_dense_lightning_indexer_grad_kl_loss in BSND layout.

        The gradients must mirror ``query_index``, ``key_index`` and
        ``weights`` in dtype and shape; the loss must match a one-element
        float32 tensor.
        """
        with FakeTensorMode():
            half, fp32 = torch.float16, torch.float32
            batch, heads, kv_heads, heads_index, kv_heads_index = 1, 64, 64, 64, 1
            seq_q, seq_kv, head_dim, rope_dim = 128, 256, 128, 64

            # NOTE(review): unlike the sparse test, these tensors are created
            # without device='npu' — confirm this is intentional.
            query = torch.randn(batch, seq_q, heads, head_dim, dtype=half)
            key = torch.randn(batch, seq_kv, kv_heads, head_dim, dtype=half)
            query_index = torch.randn(batch, seq_q, heads_index, head_dim, dtype=half)
            key_index = torch.randn(batch, seq_kv, kv_heads_index, head_dim, dtype=half)
            query_rope = torch.randn(batch, seq_q, heads, rope_dim, dtype=half)
            key_rope = torch.randn(batch, seq_kv, kv_heads, rope_dim, dtype=half)
            weights = torch.randn(batch, seq_q, heads_index, dtype=half)
            # heads == kv_heads in this configuration.
            softmax_max = (torch.randn(batch, kv_heads, seq_q, 1, dtype=fp32).abs() + 0.4) * head_dim
            softmax_sum = torch.ones(batch, kv_heads, seq_q, 1, dtype=fp32)
            softmax_max_index = (torch.randn(batch, 1, seq_q, dtype=fp32).abs() + 0.4) * head_dim * heads_index
            softmax_sum_index = torch.ones(batch, 1, seq_q, dtype=fp32)

            d_query_index, d_key_index, d_weights, loss = torch_npu.npu_dense_lightning_indexer_grad_kl_loss(
                query, key, query_index, key_index, weights,
                softmax_max, softmax_sum, softmax_max_index, softmax_sum_index, 1.0,
                query_rope=query_rope, key_rope=key_rope,
                actual_seq_qlen=[seq_q], actual_seq_klen=[seq_kv],
                layout="BSND", sparse_mode=3,
            )
            loss_reference = torch.empty([1], dtype=fp32)

            for grad, source in ((d_query_index, query_index), (d_key_index, key_index), (d_weights, weights)):
                self.assertEqual(grad.dtype, source.dtype)
                self.assertEqual(grad.shape, source.shape)
            self.assertEqual(loss.dtype, loss_reference.dtype)
            self.assertEqual(loss.shape, loss_reference.shape)

class TestNpuConv2d(TestCase):
    """Fake-tensor (meta) checks for torch_npu.npu_conv2d and npu_conv2d_backward."""

    def test_npu_conv2d_meta_0(self):
        """A 5x5 kernel with padding 2 and stride 1 keeps the 32x32 spatial size."""
        with FakeTensorMode():
            x = torch.randn(1, 3, 32, 32, dtype=torch.float).npu()
            w = torch.randn(6, 3, 5, 5, dtype=torch.float).npu()
            b = torch.randn(6, dtype=torch.float).npu()
            out = torch_npu.npu_conv2d(
                x, w, b,
                [1, 1],  # stride
                [2, 2],  # padding
                [1, 1],  # dilation
                1,       # groups
            )
            reference = torch.empty([1, 6, 32, 32], dtype=torch.float)
            self.assertEqual(out.dtype, reference.dtype)
            self.assertEqual(out.shape, reference.shape)

    def test_npu_conv2d_backward_meta(self):
        """Fake npu_conv2d_backward grads must match the real NPU grads in shape and dtype."""
        conv_args = ((1, 1), (1, 1), (1, 1), 1)  # stride, padding, dilation, groups
        output_mask = [True, True, True]

        def forward_then_backward():
            x = torch.randn(4, 3, 28, 28, device="npu", requires_grad=True)
            w = torch.randn(4, 3, 4, 4, device="npu", requires_grad=True)
            b = torch.randn(4, device="npu", requires_grad=True)
            out = torch_npu.npu_conv2d(x, w, b, *conv_args)
            grad_out = torch.ones_like(out, device="npu")
            return torch_npu.npu_conv2d_backward(x, grad_out, w, *conv_args, output_mask)

        real_in, real_w, real_b = forward_then_backward()

        with FakeTensorMode():
            fake_in, fake_w, fake_b = forward_then_backward()

            for real, fake in ((real_in, fake_in), (real_w, fake_w), (real_b, fake_b)):
                self.assertEqual(real.shape, fake.shape)
                self.assertEqual(real.dtype, fake.dtype)


class TestNpuApplyAdamW(TestCase):
    """Fake-tensor (meta) check for torch_npu.npu_apply_adam_w with ``out=`` buffers."""

    def test_npu_apply_adam_w_meta_0(self):
        """Each returned tensor must keep the dtype and shape of its ``out`` buffer."""
        with FakeTensorMode():
            var_value = np.random.uniform(10.0, 20.0)
            m_value = np.random.uniform(5.0, 10.0)
            # Drawn in the same order as the positional parameters they feed.
            scalar_args = [np.random.uniform(low, high) for low, high in (
                (0.0, 1.0),         # beta1_power
                (0.0, 1.0),         # beta2_power
                (0.0001, 0.1),      # lr
                (0.001, 0.1),       # weight_decay
                (0.5, 1.0),         # beta1
                (0.5, 1.0),         # beta2
                (0.00001, 0.01),    # eps
            )]

            shape = (21130, 512)
            var = torch.empty(shape, device="npu").fill_(var_value)
            m = torch.empty(shape, device="npu").fill_(m_value)
            v = torch.empty(shape, device="npu").zero_()
            grad = torch.empty(shape, device="npu").zero_()

            var_out, m_out, v_out = torch_npu.npu_apply_adam_w(
                *scalar_args,
                grad,
                None,   # max_grad_norm
                False,  # amsgrad
                True,   # maximize
                out=(var, m, v),
            )

            for produced, buffer in ((var_out, var), (m_out, m), (v_out, v)):
                self.assertEqual(produced.dtype, buffer.dtype)
                self.assertEqual(produced.shape, buffer.shape)


class TestNpuCrossEntropyLoss(TestCase):
    """Fake-tensor (meta) checks for torch_npu.npu_cross_entropy_loss."""

    def test_npu_cross_entropy_loss_meta_0(self):
        """With reduction="sum", loss matches a 1-element float32 tensor and
        log_prob matches the (N, C) float32 logits."""
        with FakeTensorMode():
            num_rows, num_classes = 4096, 8080

            logits = torch.randn(num_rows, num_classes, dtype=torch.float32,
                                 requires_grad=True).npu()
            labels = torch.arange(0, num_rows, dtype=torch.int64).npu()
            loss, log_prob, _, _ = torch_npu.npu_cross_entropy_loss(
                logits, labels, reduction="sum", ignore_index=100
            )

            loss_reference = torch.empty([1], dtype=torch.float32)
            log_prob_reference = torch.empty([num_rows, num_classes], dtype=torch.float32)
            self.assertEqual(loss.dtype, loss_reference.dtype)
            self.assertEqual(loss.shape, loss_reference.shape)
            self.assertEqual(log_prob.dtype, log_prob_reference.dtype)
            self.assertEqual(log_prob.shape, log_prob_reference.shape)

    @SupportedDevices(['Ascend910B'])
    def test_npu_cross_entropy_loss_backward_meta(self):
        """The fake gradient w.r.t. the logits must match the real NPU gradient."""
        num_rows = 8
        num_classes = 8

        def run_and_backprop():
            logits = torch.randn(num_rows, num_classes, device="npu", requires_grad=True)
            labels = torch.arange(0, num_rows, device="npu")
            loss, _, _, _ = torch_npu.npu_cross_entropy_loss(logits, labels, reduction="sum")
            loss.backward()
            return logits

        real_logits = run_and_backprop()

        with FakeTensorMode():
            fake_logits = run_and_backprop()

            self.assertEqual(real_logits.grad.shape, fake_logits.grad.shape)
            self.assertEqual(real_logits.grad.dtype, fake_logits.grad.dtype)


class TestNpuFusionAttention(TestCase):
    """Fake-tensor (meta) check for torch_npu.npu_fusion_attention in BSND layout."""

    def test_npu_fusion_attention_input_layout_is_BSND(self):
        """Output 0 mirrors the query; outputs 1 and 2 match a float32 (B, N, S, 8) tensor."""
        with FakeTensorMode():
            batch, heads, seq, dim = 4, 8, 32, 32
            qkv_shape = (batch, seq, heads, dim)
            query, key, value = (
                torch.randn(qkv_shape, dtype=torch.float32, device="npu", requires_grad=True)
                for _ in range(3)
            )
            softmax_reference = torch.randn((batch, heads, seq, 8), dtype=torch.float32).npu()

            outputs = torch_npu.npu_fusion_attention(
                query, key, value, head_num=heads, input_layout="BSND", scale=0.08838
            )

            self.assertEqual(query.shape, outputs[0].shape)
            self.assertEqual(query.dtype, outputs[0].dtype)
            for idx in (1, 2):
                self.assertEqual(softmax_reference.shape, outputs[idx].shape)
                self.assertEqual(softmax_reference.dtype, outputs[idx].dtype)


class TestNpuRepeatInterleave(TestCase):
    """Fake-tensor check for the backward of torch.repeat_interleave on NPU."""

    def test_npu_repeat_interleave_backward(self):
        """The gradient produced under FakeTensorMode must match the real NPU gradient.

        The gradients (``.grad``) are compared, not the inputs: the inputs are
        created with identical shape/dtype, so comparing them would never
        exercise the backward meta implementation.
        """
        repeats_value = 3

        x = torch.randn(2, 2, device="npu", requires_grad=True)
        output = torch.repeat_interleave(x, repeats_value)
        output.backward(torch.randn(output.size(), device="npu"))

        with FakeTensorMode():
            x_fake_tensor = torch.randn(2, 2, device="npu", requires_grad=True)
            output_fake_tensor = torch.repeat_interleave(x_fake_tensor, repeats_value)
            output_fake_tensor.backward(torch.randn(output_fake_tensor.size(), device="npu"))

            self.assertEqual(output.shape, output_fake_tensor.shape)
            self.assertIsNotNone(x_fake_tensor.grad)
            self.assertEqual(x.grad.shape, x_fake_tensor.grad.shape)
            self.assertEqual(x.grad.dtype, x_fake_tensor.grad.dtype)


class TestQuantMatmulInplaceAdd(TestCase):
    """Fake-tensor (meta) checks for npu_add_quant_matmul_ and npu_add_quant_matmul."""

    def _assert_result_meta(self, res, x1, x2):
        """Assert ``res`` is a 2-D float32 tensor of shape (x1.shape[0], x2.shape[1])."""
        self.assertEqual(len(res.shape), 2)
        self.assertEqual(x1.shape[0], res.shape[0])
        self.assertEqual(x2.shape[1], res.shape[1])
        self.assertEqual(res.dtype, torch.float32)

    def test_npu_quant_batch_matmul_inplace_add(self):
        """Both the in-place and the out-of-place variant must produce (M, N) float32 results."""
        M = 7168
        N = 576
        K = 512
        # ceil(K / 64): presumably the number of 64-wide groups along K — TODO confirm against the op spec.
        k_groups = math.ceil(K / 64)
        with FakeTensorMode():
            y = torch.randint(-1, 1, (M, N), dtype=torch.float32).npu()
            x1 = torch.randint(-1, 1, (M, k_groups), dtype=torch.int8).npu()
            x2 = torch.randint(-1, 1, (K, N), dtype=torch.int8).npu()
            x2_scale = torch.randint(-1, 1, (M, k_groups, 2), dtype=torch.float32).npu()
            x1_scale = torch.randint(-1, 1, (k_groups, N, 2), dtype=torch.float32).npu()

            res_1 = torch_npu.npu_add_quant_matmul_(y, x1, x2, x2_scale, x1_scale=x1_scale, group_sizes=None)
            self._assert_result_meta(res_1, x1, x2)

            # The out-of-place variant was previously computed but never checked.
            res_2 = torch_npu.npu_add_quant_matmul(y, x1, x2, x2_scale, x1_scale=x1_scale, group_sizes=None)
            self._assert_result_meta(res_2, x1, x2)

class TestNpuDeformableConv2d(TestCase):
    """Fake-tensor (meta) check for torch_npu.npu_deformable_conv2d."""

    def test_npu_deformable_conv2d(self):
        """Checks both output shapes for a 5x5 kernel with asymmetric padding [4, 6, 8, 2]."""
        with FakeTensorMode():
            features = torch.rand(16, 32, 32, 32).npu()
            kernel = torch.rand(32, 32, 5, 5).npu()
            offsets = torch.rand(16, 75, 38, 38).npu()

            out, deform_out = torch_npu.npu_deformable_conv2d(
                features, kernel, offsets, None,
                kernel_size=[5, 5], stride=[1, 1, 1, 1], padding=[4, 6, 8, 2],
            )

            self.assertEqual(out.shape, (16, 32, 38, 38))
            self.assertEqual(deform_out.shape, (16, 32, 190, 190))


class TestNpuPsRoiPooling(TestCase):
    """Fake-tensor (meta) check for torch_npu.npu_ps_roi_pooling."""

    def test_npu_ps_roi_pooling(self):
        """Output shape for an (8, 68, 32, 32, 1) feature map and (8, 5, 2) ROIs."""
        with FakeTensorMode():
            feature_map = torch.rand((8, 68, 32, 32, 1), dtype=torch.float).npu()
            rois = torch.rand((8, 5, 2), dtype=torch.float).npu()

            pooled = torch_npu.npu_ps_roi_pooling(feature_map, rois, 0.5, 2, 2)
            self.assertEqual(pooled.shape, (16, 2, 2, 2))


class TestNpuConvolution(TestCase):
    """Fake-tensor (meta) check for 3-D torch_npu.npu_convolution."""

    def test_npu_convolution(self):
        """A 3x3x3 kernel with stride 1 and padding 1 keeps the (4, 14, 14) volume."""
        with FakeTensorMode():
            volume = torch.empty([1, 128, 4, 14, 14], dtype=torch.float).uniform_(-1, 1).npu()
            kernel = torch.empty([1, 128, 3, 3, 3], dtype=torch.float).uniform_(-1, 1).npu()
            stride, padding, dilation = [1, 1, 1], [1, 1, 1], [1, 1, 1]

            out = torch_npu.npu_convolution(volume, kernel, None, stride, padding, dilation, 1)
            self.assertEqual(out.shape, (1, 1, 4, 14, 14))


class TestNpuConvolutionTranspose(TestCase):
    """Fake-tensor (meta) check for 2-D torch_npu.npu_convolution_transpose."""

    def test_npu_convolution_transpose(self):
        """A 3x3 transposed conv with padding 1 and stride 1 keeps the 3x3 spatial size."""
        with FakeTensorMode():
            def uniform(shape):
                return torch.empty(shape, dtype=torch.float).uniform_(-1, 1).npu()

            features = uniform([1, 3, 3, 3])
            kernel = uniform([3, 2, 3, 3])
            bias = uniform([2])

            out = torch_npu.npu_convolution_transpose(
                features, kernel, bias,
                [1, 1],  # padding
                [0, 0],  # output_padding
                [1, 1],  # stride
                [1, 1],  # dilation
                1,       # groups
            )
            self.assertEqual(out.shape, (1, 2, 3, 3))


class TestBatchNormReduce(TestCase):
    """Fake-tensor (meta) check for torch_npu.batch_norm_reduce."""

    def test_batch_norm_reduce(self):
        """Both outputs have the input's channel count as dim 0 and keep its dtype."""
        with FakeTensorMode():
            x = torch.randn(2, 3, 12, 12, device="npu", requires_grad=True)
            first, second = torch_npu.batch_norm_reduce(x, 1e-5)

            for stat in (first, second):
                self.assertEqual(x.shape[1], stat.shape[0])
                self.assertEqual(x.dtype, stat.dtype)


class TestMatmul(TestCase):
    """Fake-tensor checks for matmul backward (autograd path and aten.matmul_backward meta)."""

    @staticmethod
    def _randn(shape, dtype, requires_grad=False):
        """Return a random NPU tensor of ``shape`` and ``dtype``.

        ``torch.randn`` already yields a 0-dim tensor for an empty shape, so
        scalar gradients need no special case.
        """
        return torch.randn(shape, dtype=dtype, device="npu", requires_grad=requires_grad)

    def _run_fake_matmul_backward(self, grad_shape, mat1_shape, mat2_shape, grad_dtype, mat1_dtype, mat2_dtype):
        """Call ``aten.matmul_backward`` on fake tensors and return its result.

        Callers unpack the result as ``(self_grad, other_grad)``; both grads
        are requested via the ``[True, True]`` mask.
        """
        with FakeTensorMode():
            grad = self._randn(grad_shape, grad_dtype)
            mat1 = self._randn(mat1_shape, mat1_dtype, requires_grad=True)
            mat2 = self._randn(mat2_shape, mat2_dtype, requires_grad=True)
            return torch.ops.aten.matmul_backward.default(grad, mat1, mat2, [True, True])

    def _assert_fake_backward_matches_real(self, mat1_shape, mat2_shape, dtype):
        """Run matmul + backward on real and fake NPU tensors and compare both input grads."""
        mat1 = self._randn(mat1_shape, dtype, requires_grad=True)
        mat2 = self._randn(mat2_shape, dtype, requires_grad=True)
        output = torch.matmul(mat1, mat2)
        grad = torch.randn(output.size(), dtype=output.dtype, device="npu")
        output.backward(grad)

        with FakeTensorMode():
            fake_mat1 = self._randn(mat1_shape, dtype, requires_grad=True)
            fake_mat2 = self._randn(mat2_shape, dtype, requires_grad=True)
            fake_output = torch.matmul(fake_mat1, fake_mat2)
            fake_grad = torch.randn(fake_output.size(), dtype=fake_output.dtype, device="npu")
            fake_output.backward(fake_grad)

        self.assertEqual(mat1.grad.shape, fake_mat1.grad.shape)
        self.assertEqual(mat1.grad.dtype, fake_mat1.grad.dtype)
        self.assertEqual(mat2.grad.shape, fake_mat2.grad.shape)
        self.assertEqual(mat2.grad.dtype, fake_mat2.grad.dtype)

    def test_matmul_backward(self):
        """Broadcast batched matmul (8, 4, 8) @ (8, 8, 8, 4) in the default dtype."""
        # Reuses the shared real-vs-fake helper instead of duplicating its body.
        self._assert_fake_backward_matches_real((8, 4, 8), (8, 8, 8, 4), torch.get_default_dtype())

    @unittest.skipIf(
        os.getenv("TORCH_NPU_USE_COMPATIBLE_IMPL") == "1",
        "matmul_backward meta registration is disabled in compatible impl",
    )
    def test_matmul_backward_meta_shapes_and_dtypes(self):
        """Table-driven check of aten.matmul_backward meta output shapes and dtypes."""
        test_cases = [
            {
                "name": "vector_vector",
                "grad_shape": (),
                "mat1_shape": (8,),
                "mat2_shape": (8,),
                "grad_dtype": torch.float32,
                "mat1_dtype": torch.float16,
                "mat2_dtype": torch.float16,
                "expected_self_shape": (1, 8),
                "expected_other_shape": (8,),
                "expected_self_dtype": torch.float32,
                "expected_other_dtype": torch.float16,
            },
            {
                "name": "vector_matrix",
                "grad_shape": (6,),
                "mat1_shape": (5,),
                "mat2_shape": (5, 6),
                "grad_dtype": torch.float32,
                "mat1_dtype": torch.float16,
                "mat2_dtype": torch.float16,
                "expected_self_shape": (1, 5),
                "expected_other_shape": (5, 6),
                "expected_self_dtype": torch.float32,
                "expected_other_dtype": torch.float16,
            },
            {
                "name": "matrix_vector_squeeze_other_grad",
                "grad_shape": (3,),
                "mat1_shape": (3, 5),
                "mat2_shape": (5,),
                "grad_dtype": torch.float32,
                "mat1_dtype": torch.float32,
                "mat2_dtype": torch.float16,
                "expected_self_shape": (3, 5),
                "expected_other_shape": (5,),
                "expected_self_dtype": torch.float32,
                "expected_other_dtype": torch.float32,
            },
            {
                "name": "matrix_batched_matrix_special_self_grad",
                "grad_shape": (2, 3, 4),
                "mat1_shape": (3, 5),
                "mat2_shape": (2, 5, 4),
                "grad_dtype": torch.float16,
                "mat1_dtype": torch.float16,
                "mat2_dtype": torch.float16,
                "expected_self_shape": (3, 5),
                "expected_other_shape": (2, 5, 4),
                "expected_self_dtype": torch.float16,
                "expected_other_dtype": torch.float16,
            },
            {
                "name": "batched_matrix_matrix_special_other_grad",
                "grad_shape": (2, 3, 6),
                "mat1_shape": (2, 3, 5),
                "mat2_shape": (5, 6),
                "grad_dtype": torch.float16,
                "mat1_dtype": torch.float16,
                "mat2_dtype": torch.float16,
                "expected_self_shape": (2, 3, 5),
                "expected_other_shape": (5, 6),
                "expected_self_dtype": torch.float16,
                "expected_other_dtype": torch.float16,
            },
            {
                "name": "leading_singleton_self_batch",
                "grad_shape": (1, 4, 3, 6),
                "mat1_shape": (1, 4, 3, 5),
                "mat2_shape": (5, 6),
                "grad_dtype": torch.float32,
                "mat1_dtype": torch.float16,
                "mat2_dtype": torch.float16,
                "expected_self_shape": (1, 4, 3, 5),
                "expected_other_shape": (5, 6),
                "expected_self_dtype": torch.float32,
                "expected_other_dtype": torch.float16,
            },
            {
                "name": "leading_singleton_other_batch",
                "grad_shape": (1, 4, 3, 6),
                "mat1_shape": (3, 5),
                "mat2_shape": (1, 4, 5, 6),
                "grad_dtype": torch.float32,
                "mat1_dtype": torch.float32,
                "mat2_dtype": torch.float16,
                "expected_self_shape": (3, 5),
                "expected_other_shape": (1, 4, 5, 6),
                "expected_self_dtype": torch.float32,
                "expected_other_dtype": torch.float32,
            },
        ]

        for case in test_cases:
            with self.subTest(case=case["name"]):
                self_grad, other_grad = self._run_fake_matmul_backward(
                    case["grad_shape"],
                    case["mat1_shape"],
                    case["mat2_shape"],
                    case["grad_dtype"],
                    case["mat1_dtype"],
                    case["mat2_dtype"],
                )
                self.assertEqual(self_grad.shape, torch.Size(case["expected_self_shape"]))
                self.assertEqual(other_grad.shape, torch.Size(case["expected_other_shape"]))
                self.assertEqual(self_grad.dtype, case["expected_self_dtype"])
                self.assertEqual(other_grad.dtype, case["expected_other_dtype"])

    @unittest.skipIf(
        os.getenv("TORCH_NPU_USE_COMPATIBLE_IMPL") == "1",
        "matmul_backward meta registration is disabled in compatible impl",
    )
    def test_matmul_backward_fake_autograd_matches_real_grad_meta(self):
        """Autograd through fake matmul must produce grads matching the real NPU run."""
        test_cases = [
            ((8,), (8, 4), torch.float16),
            ((4, 8), (2, 8, 6), torch.float16),
            ((1, 4, 3, 5), (5, 6), torch.float32),
        ]

        for mat1_shape, mat2_shape, dtype in test_cases:
            with self.subTest(mat1_shape=mat1_shape, mat2_shape=mat2_shape, dtype=dtype):
                self._assert_fake_backward_matches_real(mat1_shape, mat2_shape, dtype)


class TestKLDivLoss(TestCase):
    """Fake-tensor check for the backward of torch.nn.KLDivLoss on NPU."""

    def test_kl_div_backward(self):
        """The gradient w.r.t. the prediction must match between real and fake runs."""
        criterion = torch.nn.KLDivLoss(reduction="sum", log_target=True).to("npu")

        def forward_and_backward():
            pred = torch.randn(4, 4, device="npu", requires_grad=True)
            target = torch.randn(4, 4, device="npu")
            criterion(pred.log_softmax(dim=-1), target).backward()
            return pred

        real_pred = forward_and_backward()

        with FakeTensorMode():
            fake_pred = forward_and_backward()

            self.assertEqual(real_pred.grad.shape, fake_pred.grad.shape)
            self.assertEqual(real_pred.grad.dtype, fake_pred.grad.dtype)

class TestNpuDropoutGenMaskMeta(TestCase):
    """Meta check for torch_npu._npu_dropout_gen_mask."""

    def test_npu_dropout_gen_mask(self):
        """The fake mask must match the real NPU mask in shape and dtype."""
        size = (2, 1024, 768)
        mask_args = (size, 0.5, 1, 0)
        real_mask = torch_npu._npu_dropout_gen_mask(torch.randn(size, device="npu"), *mask_args)

        with FakeTensorMode():
            fake_input = torch.empty(size, device='meta')
            fake_mask = torch_npu._npu_dropout_gen_mask(fake_input, *mask_args)

            self.assertEqual(fake_mask.shape, real_mask.shape)
            self.assertEqual(fake_mask.dtype, real_mask.dtype)


class TestNpuSoftmaxCrossEntropyWithLogitsMeta(TestCase):
    """Meta check for torch_npu.npu_softmax_cross_entropy_with_logits."""

    def test_npu_softmax_cross_entropy_with_logits(self):
        """The fake loss must match the real NPU loss in shape and dtype."""
        batch, classes = 32, 1000
        real_loss = torch_npu.npu_softmax_cross_entropy_with_logits(
            torch.randn(batch, classes, device="npu"),
            torch.randint(0, classes, (batch,), dtype=torch.long, device="npu"),
        )

        with FakeTensorMode():
            fake_loss = torch_npu.npu_softmax_cross_entropy_with_logits(
                torch.empty(batch, classes, device='meta'),
                torch.empty(batch, dtype=torch.long, device='meta'),
            )

            self.assertEqual(fake_loss.shape, real_loss.shape)
            self.assertEqual(fake_loss.dtype, real_loss.dtype)


class TestNpuSoftmaxCrossEntropyWithLogitsBackwardMeta(TestCase):
    """Meta check for torch_npu.npu_softmax_cross_entropy_with_logits_backward."""

    def test_npu_softmax_cross_entropy_with_logits_backward(self):
        """The fake logits-gradient must match the real NPU result in shape and dtype."""
        batch, classes = 32, 1000
        logits = torch.randn(batch, classes, device="npu")
        labels = torch.randint(0, classes, (batch,), dtype=torch.long, device="npu")
        forward_loss = torch_npu.npu_softmax_cross_entropy_with_logits(logits, labels)
        real_grad = torch_npu.npu_softmax_cross_entropy_with_logits_backward(
            torch.ones_like(forward_loss), logits, labels
        )

        with FakeTensorMode():
            meta_logits = torch.empty(batch, classes, device='meta')
            meta_labels = torch.empty(batch, dtype=torch.long, device='meta')
            meta_upstream = torch.empty(batch, device='meta')
            fake_grad = torch_npu.npu_softmax_cross_entropy_with_logits_backward(
                meta_upstream, meta_logits, meta_labels
            )

            self.assertEqual(fake_grad.shape, real_grad.shape)
            self.assertEqual(fake_grad.dtype, real_grad.dtype)


class TestNpuIndexAddMeta(TestCase):
    """Meta check for the out-of-place torch_npu._npu_index_add."""

    @unittest.skip("Skip until CANN supports aclnnIndexAddV2; do not execute")
    def test_npu_index_add(self):
        """The fake result must match the real NPU result in shape and dtype."""
        rows, cols, picks = 10, 5, 2
        base = torch.randn(rows, cols, device="npu")
        positions = torch.randint(0, rows, (picks,), dtype=torch.int64, device="npu")
        updates = torch.randn(picks, cols, device="npu")
        real_out = torch_npu._npu_index_add(base, positions, updates, alpha=1)

        with FakeTensorMode():
            fake_out = torch_npu._npu_index_add(
                torch.empty(rows, cols, device='meta'),
                torch.empty(picks, dtype=torch.int64, device='meta'),
                torch.empty(picks, cols, device='meta'),
                alpha=1,
            )

            self.assertEqual(fake_out.shape, real_out.shape)
            self.assertEqual(fake_out.dtype, real_out.dtype)


class TestNpuIndexAddInplaceMeta(TestCase):
    """Meta check for the in-place torch_npu._npu_index_add_."""

    @unittest.skip("Skip until CANN supports aclnnIndexAddV2; do not execute")
    def test_npu_index_add_(self):
        """The fake result must match the real one and be the very tensor that was updated."""
        rows, cols, picks = 10, 5, 2
        base = torch.randn(rows, cols, device="npu")
        positions = torch.randint(0, rows, (picks,), dtype=torch.int64, device="npu")
        updates = torch.randn(picks, cols, device="npu")
        real_out = torch_npu._npu_index_add_(base, positions, updates, alpha=1)

        with FakeTensorMode():
            meta_base = torch.empty(rows, cols, device='meta')
            meta_positions = torch.empty(picks, dtype=torch.int64, device='meta')
            meta_updates = torch.empty(picks, cols, device='meta')
            fake_out = torch_npu._npu_index_add_(meta_base, meta_positions, meta_updates, alpha=1)

            self.assertEqual(fake_out.shape, real_out.shape)
            self.assertEqual(fake_out.dtype, real_out.dtype)
            self.assertIs(fake_out, meta_base)


class TestNpuReshapeMeta(TestCase):
    """Meta check for torch_npu.npu_reshape."""

    def test_npu_reshape(self):
        """Reshaping (2, 3, 4) to (6, 4) must agree between real NPU and fake runs."""
        target_shape = (6, 4)
        real_out = torch_npu.npu_reshape(torch.randn(2, 3, 4, device="npu"), target_shape, can_refresh=False)

        with FakeTensorMode():
            meta_in = torch.empty(2, 3, 4, device='meta')
            fake_out = torch_npu.npu_reshape(meta_in, target_shape, can_refresh=False)

            self.assertEqual(fake_out.shape, real_out.shape)
            self.assertEqual(fake_out.dtype, real_out.dtype)

class TestSwigluMxQuant(TestCase):
    def test_npu_swiglu_mx_quant_meta(self):
        with FakeTensorMode():
            x = torch.randn([4, 8], dtype=torch.float16, device='npu')
            group_index = None
            y_npu, mxscale_npu = torch_npu.npu_swiglu_mx_quant(
                x, group_index=group_index, activate_dim=-1, activate_left=False, swiglu_mode=0,
                clamp_limit=7, glu_alpha=1.702, glu_bias=1, group_mode=0, axis=-1,
                dst_type=torch_npu.float4_e2m1fn_x2, round_mode="rint", scale_alg=0, max_dtype_value=0
            )
            self.assertEqual(y_npu.shape, torch.Size([4, 2]))
            self.assertEqual(y_npu.dtype, torch.uint8)
            self.assertEqual(mxscale_npu.shape, torch.Size([4, 1, 2]))
            self.assertEqual(mxscale_npu.dtype, torch.uint8)


class TestSwigluGroupQuantBackward(TestCase):
    def test_npu_swiglu_group_quant_backward_meta(self):
        with FakeTensorMode():
            grad_y = torch.randn([4, 8], dtype=torch.float32, device='npu')
            x = torch.randn([4, 16], dtype=torch.float32, device='npu')
            weight = torch.randn([4, 1], dtype=torch.float32, device='npu')
            y_origin = torch.randn([4, 8], dtype=torch.float32, device='npu')
            group_index = None
            clamp_limit = 0.0
            grad_x_npu, grad_weight_npu = torch_npu.npu_swiglu_group_quant_backward(grad_y, x, weight=weight,
                                                                           y_origin=y_origin, group_index=group_index,
                                                                           clamp_limit=clamp_limit)
            self.assertEqual(grad_x_npu.shape, torch.Size([4, 16]))
            self.assertEqual(grad_x_npu.dtype, torch.float32)
            self.assertEqual(grad_weight_npu.shape, torch.Size([4, 1]))
            self.assertEqual(grad_weight_npu.dtype, torch.float32)


@unittest.skip("skip until CANN is updated to support aclnnDynamicBlockMxQuant")
class TestNpuDynamicBlockMxQuant(TestCase):
    def test_npu_dynamic_block_mx_quant_meta(self):
        # 2 dim
        x = torch.rand(64, 256).to("npu").to(torch.float16)
        actual_y, actual_scale1, actual_scale2 = torch_npu.npu_dynamic_block_mx_quant(x)
        with FakeTensorMode():
            fake_x = torch.rand(64, 256).to("npu").to(torch.float16)
            fake_y, fake_scale1, fake_scale2 = torch_npu.npu_dynamic_block_mx_quant(fake_x)
        self.assertEqual(actual_y.shape, fake_y.shape)
        self.assertEqual(actual_scale1.shape, fake_scale1.shape)
        self.assertEqual(actual_scale2.shape, fake_scale2.shape)
        self.assertEqual(actual_y.dtype, fake_y.dtype)
        self.assertEqual(actual_scale1.dtype, fake_scale1.dtype)
        self.assertEqual(actual_scale2.dtype, fake_scale2.dtype)

        # 3 dim
        x = torch.rand(32, 64, 256).to("npu").to(torch.float16)
        actual_y, actual_scale1, actual_scale2 = torch_npu.npu_dynamic_block_mx_quant(x)
        with FakeTensorMode():
            fake_x = torch.rand(32, 64, 256).to("npu").to(torch.float16)
            fake_y, fake_scale1, fake_scale2 = torch_npu.npu_dynamic_block_mx_quant(fake_x)
        self.assertEqual(actual_y.shape, fake_y.shape)
        self.assertEqual(actual_scale1.shape, fake_scale1.shape)
        self.assertEqual(actual_scale2.shape, fake_scale2.shape)
        self.assertEqual(actual_y.dtype, fake_y.dtype)
        self.assertEqual(actual_scale1.dtype, fake_scale1.dtype)
        self.assertEqual(actual_scale2.dtype, fake_scale2.dtype)


class TestNpuMultiHeadAttention(TestCase):
    def test_npu_multi_head_attention_meta(self):
        with FakeTensorMode():
            batch = 2
            tgt_len = 16
            src_len = 16
            attn_head_num = 8
            attn_dim_per_head = 64
            weight_col = attn_head_num * attn_dim_per_head
            dropout_prob = 0.5
            softmax_use_float = True
            query = torch.randn(batch * tgt_len, weight_col,
                                dtype=torch.float16).npu()
            key = torch.randn(batch * src_len, weight_col,
                              dtype=torch.float16).npu()
            value = torch.randn(batch * src_len, weight_col,
                                dtype=torch.float16).npu()
            query_weight = torch.randn(
                weight_col, weight_col, dtype=torch.float16).npu()
            key_weight = torch.randn(
                weight_col, weight_col, dtype=torch.float16).npu()
            value_weight = torch.randn(
                weight_col, weight_col, dtype=torch.float16).npu()
            attn_mask = torch.randn(
                batch * attn_head_num * tgt_len * src_len, dtype=torch.float16).npu()
            out_proj_weight = torch.randn(
                weight_col, weight_col, dtype=torch.float16).npu()
            query_bias = torch.randn(1, weight_col, dtype=torch.float16).npu()
            key_bias = torch.randn(1, weight_col, dtype=torch.float16).npu()
            value_bias = torch.randn(1, weight_col, dtype=torch.float16).npu()
            out_proj_bias = torch.randn(
                1, weight_col, dtype=torch.float16).npu()
            dropout_mask_input = torch.randint(
                0, 1, (weight_col, ), dtype=torch.uint8).npu()
            y, dropout_mask, query_res, key_res, value_res, attn_scores, attn_res, context = torch_npu.npu_multi_head_attention(
                query, key, value, query_weight, key_weight, value_weight, attn_mask, out_proj_weight,
                query_bias, key_bias, value_bias, out_proj_bias, attn_head_num=attn_head_num,
                attn_dim_per_head=attn_dim_per_head, src_len=src_len, tgt_len=tgt_len, dropout_mask=dropout_mask_input, dropout_prob=dropout_prob, softmax_use_float=softmax_use_float
            )

            self.assertEqual(y.shape, (batch * tgt_len, weight_col))
            self.assertEqual(y.dtype, query.dtype)
            self.assertEqual(dropout_mask.shape,
                             (batch * attn_head_num * tgt_len * src_len // 8,))
            self.assertEqual(dropout_mask.dtype, torch.uint8)
            self.assertEqual(
                query_res.shape, (batch, attn_head_num, tgt_len, attn_dim_per_head))
            self.assertEqual(
                key_res.shape, (batch, attn_head_num, src_len, attn_dim_per_head))
            self.assertEqual(
                value_res.shape, (batch, attn_head_num, src_len, attn_dim_per_head))
            self.assertEqual(attn_scores.shape,
                             (batch, attn_head_num, tgt_len, src_len))
            self.assertEqual(
                attn_res.shape, (batch, attn_head_num, tgt_len, src_len))
            self.assertEqual(context.shape, (batch * tgt_len, weight_col))


class TestNpuMultiHeadAttentionBackward(TestCase):
    def test_npu_multi_head_attention_backward_meta(self):
        with FakeTensorMode():
            batch = 8
            attn_head_num = 16
            attn_dim_per_head = 64
            src_len = 64
            tgt_len = 64
            dropout_prob = 0.0
            softmax_use_float = True
            dropout_prob = 0.0
            softmax_use_float = True
            weight_col = attn_head_num * attn_dim_per_head
            query = torch.randn(batch * tgt_len, weight_col,
                                dtype=torch.float16).npu()
            key = torch.randn(batch * src_len, weight_col,
                              dtype=torch.float16).npu()
            value = torch.randn(batch * tgt_len, weight_col,
                                dtype=torch.float16).npu()
            query_weight = torch.randn(
                weight_col, weight_col, dtype=torch.float16).npu()
            key_weight = torch.randn(
                weight_col, weight_col, dtype=torch.float16).npu()
            value_weight = torch.randn(
                weight_col, weight_col, dtype=torch.float16).npu()
            out_proj_weight = torch.randn(
                weight_col, weight_col, dtype=torch.float16).npu()
            attn_mask = torch.randn(
                batch, attn_head_num, tgt_len, src_len, dtype=torch.float16).npu()
            query_bias = torch.randn(weight_col, dtype=torch.float16).npu()
            key_bias = torch.randn(weight_col, dtype=torch.float16).npu()
            value_bias = torch.randn(weight_col, dtype=torch.float16).npu()
            out_proj_bias = torch.randn(weight_col, dtype=torch.float16).npu()
            dropout_mask = torch.randint(
                0, 1, (weight_col, ), dtype=torch.uint8).npu()

            result0, result1, result2, result3, result4, result5, result6, result7 = torch_npu.npu_multi_head_attention(
                query, key, value, query_weight, key_weight, value_weight, attn_mask, out_proj_weight, query_bias, key_bias, value_bias, out_proj_bias,  dropout_mask, attn_head_num, attn_dim_per_head, src_len, tgt_len, dropout_prob, softmax_use_float)

            grad = torch.ones_like(result0)
            query_weight_grad, key_weight_grad, value_weight_grad, out_proj_weight_grad, query_grad, key_grad, value_grad, query_bias_grad, key_bias_grad, value_bias_grad, out_proj_bias_grad = torch_npu.npu_multi_head_attention_backward(
                query, key, value, query_weight, key_weight, value_weight, out_proj_weight, query_bias, key_bias, value_bias, out_proj_bias,
                result2, result3, result4, result5, result6, result7, grad, result1, attn_head_num, attn_dim_per_head, src_len, tgt_len, dropout_prob, softmax_use_float
            )

            self.assertEqual(query_weight_grad.shape, (weight_col, weight_col))
            self.assertEqual(key_weight_grad.shape, (weight_col, weight_col))
            self.assertEqual(value_weight_grad.shape, (weight_col, weight_col))
            self.assertEqual(out_proj_weight_grad.shape,
                             (weight_col, weight_col))
            self.assertEqual(query_grad.shape, (batch * tgt_len, weight_col))
            self.assertEqual(key_grad.shape, (batch * src_len, weight_col))
            self.assertEqual(value_grad.shape, (batch * src_len, weight_col))
            self.assertEqual(query_bias_grad.shape, (1, weight_col))
            self.assertEqual(key_bias_grad.shape, (1, weight_col))
            self.assertEqual(value_bias_grad.shape, (1, weight_col))
            self.assertEqual(out_proj_bias_grad.shape, (1, weight_col))


class TestNpuLstmCell(TestCase):
    def test_npu_lstm_cell_meta(self):
        with FakeTensorMode():
            batch_size = 2
            hidden_size = 64
            input_size = 128

            input_ = torch.randn(batch_size, input_size,
                                 dtype=torch.float16).npu()
            w_ih = torch.randn(4 * hidden_size, input_size,
                               dtype=torch.float16).npu()
            w_hh = torch.randn(4 * hidden_size, hidden_size,
                               dtype=torch.float16).npu()
            h = torch.randn(batch_size, hidden_size, dtype=torch.float16).npu()
            c = torch.randn(batch_size, hidden_size, dtype=torch.float16).npu()
            b_ih = torch.randn(4 * hidden_size, dtype=torch.float16).npu()
            b_hh = torch.randn(4 * hidden_size, dtype=torch.float16).npu()

            y_output, h_out, c_out, i_output, j_output, f_output, o_output, tanhc = torch_npu.npu_lstm_cell(
                input_, w_ih, w_hh, h, c, b_ih, b_hh
            )
            self.assertEqual(y_output.shape, (1, batch_size, hidden_size // 4))
            self.assertEqual(h_out.shape, (batch_size, hidden_size // 4))
            self.assertEqual(c_out.shape, (batch_size, hidden_size // 4))
            self.assertEqual(i_output.shape, (1, batch_size, hidden_size // 4))
            self.assertEqual(j_output.shape, (1, batch_size, hidden_size // 4))
            self.assertEqual(f_output.shape, (1, batch_size, hidden_size // 4))
            self.assertEqual(o_output.shape, (1, batch_size, hidden_size // 4))
            self.assertEqual(tanhc.shape, (1, batch_size, hidden_size // 4))


class TestNpuLstmCellBackward(TestCase):
    def test_npu_lstm_cell_backward_meta(self):
        with FakeTensorMode():
            batch_size = 2
            hidden_size = 64
            input_size = 128

            input = torch.randn(batch_size, input_size,
                                dtype=torch.float16).npu()
            w_ih = torch.randn(4 * hidden_size, input_size,
                               dtype=torch.float16).npu()
            w_hh = torch.randn(4 * hidden_size, hidden_size,
                               dtype=torch.float16).npu()
            h = torch.randn(batch_size, hidden_size, dtype=torch.float16).npu()
            c = torch.randn(batch_size, hidden_size, dtype=torch.float16).npu()
            y_output = torch.randn(
                1, batch_size, hidden_size, dtype=torch.float16).npu()
            h_output = torch.randn(
                batch_size, hidden_size, dtype=torch.float16).npu()
            c_output = torch.randn(
                batch_size, hidden_size, dtype=torch.float16).npu()
            i = torch.randn(1, batch_size, hidden_size,
                            dtype=torch.float16).npu()
            j = torch.randn(1, batch_size, hidden_size,
                            dtype=torch.float16).npu()
            f = torch.randn(1, batch_size, hidden_size,
                            dtype=torch.float16).npu()
            o = torch.randn(1, batch_size, hidden_size,
                            dtype=torch.float16).npu()
            tanhc = torch.randn(1, batch_size, hidden_size,
                                dtype=torch.float16).npu()
            grad_y = torch.randn(1, batch_size, hidden_size,
                                 dtype=torch.float16).npu()
            grad_h = torch.randn(batch_size, hidden_size,
                                 dtype=torch.float16).npu()
            grad_c = torch.randn(batch_size, hidden_size,
                                 dtype=torch.float16).npu()

            grad_input, grad_wih, grad_whh, grad_bias1, grad_bias2, grad_ht, grad_ct = torch_npu.npu_lstm_cell_backward(
                grad_y, grad_h, grad_c, input, w_ih, w_hh, h, c, y_output, h_output, c_output, i, j, f, o, tanhc
            )

            self.assertEqual(grad_input.shape, input.shape)
            self.assertEqual(grad_wih.shape, w_ih.shape)
            self.assertEqual(grad_whh.shape, w_hh.shape)
            self.assertEqual(grad_bias1.shape, (4 * hidden_size,))
            self.assertEqual(grad_bias2.shape, (4 * hidden_size,))
            self.assertEqual(grad_ht.shape, h.shape)
            self.assertEqual(grad_ct.shape, c.shape)


class TestNpuFusedAttentionScoreFwd(TestCase):
    def test_npu_fused_attention_score_fwd_meta(self):
        with FakeTensorMode():
            batch_size = 2
            num_heads = 8
            seq_len = 16
            head_dim = 64

            query_layer = torch.randn(
                batch_size, num_heads, seq_len, head_dim, dtype=torch.float16).npu()
            key_layer = torch.randn(
                batch_size, num_heads, seq_len, head_dim, dtype=torch.float16).npu()
            value_layer = torch.randn(
                batch_size, num_heads, seq_len, head_dim, dtype=torch.float16).npu()
            attention_mask = torch.randn(
                batch_size * num_heads * seq_len * seq_len, dtype=torch.float16).npu()
            scale = 1.0 / (head_dim ** 0.5)
            keep_prob = 1.0
            attention_score, softmax_output, drop_mask = torch_npu.npu_fused_attention_score_fwd(
                query_layer, key_layer, value_layer, attention_mask, scale, keep_prob
            )

            self.assertEqual(attention_score.shape,
                             (batch_size * seq_len, num_heads * head_dim))
            self.assertEqual(softmax_output.shape,
                             (batch_size, num_heads, seq_len, seq_len))
            self.assertEqual(drop_mask.shape, (batch_size *
                             num_heads * seq_len * seq_len,))
            self.assertEqual(drop_mask.dtype, torch.uint8)


class TestNpuFusedAttentionScoreBackward(TestCase):
    def test_npu_fused_attention_score_backward_meta(self):
        with FakeTensorMode():
            batch_size = 2
            num_heads = 8
            seq_len = 16
            head_dim = 64

            query_layer = torch.randn(
                batch_size, num_heads, seq_len, head_dim, dtype=torch.float16).npu()
            key_layer = torch.randn(
                batch_size, num_heads, seq_len, head_dim, dtype=torch.float16).npu()
            value_layer = torch.randn(
                batch_size, num_heads, seq_len, head_dim, dtype=torch.float16).npu()
            softmax_output = torch.randn(
                batch_size, num_heads, seq_len, seq_len, dtype=torch.float16).npu()
            drop_mask = torch.randint(
                0, 1, (batch_size * num_heads * seq_len * seq_len, ), dtype=torch.uint8).npu()
            grad_output = torch.randn(
                batch_size, num_heads, seq_len, head_dim, dtype=torch.float16).npu()
            scale = 1.0 / (head_dim ** 0.5)

            query_dx, key_dw, value_dw = torch_npu.npu_fused_attention_score_backward(
                grad_output, softmax_output, query_layer, key_layer, value_layer, drop_mask, scale, False, False, False, False)

            self.assertEqual(query_dx.shape, grad_output.shape)
            self.assertEqual(key_dw.shape, grad_output.shape)
            self.assertEqual(value_dw.shape, grad_output.shape)

@unittest.skip("skip until CANN is updated to support aclnnRotateQuant.")
class TestRotateQuant(TestCase):
    """Fake-tensor shape/dtype checks for ``npu_rotate_quant`` across target dtypes.

    Every case feeds a bf16 ``(M, N)`` input and a bf16 ``(K, K)`` rotation
    matrix and checks the two tensors the op returns.
    """

    M = 512
    N = 1024
    K = 1024  # rotation is K x K; in these cases K == N.

    def _check_rotate_quant(self, quant_kwargs, out_shape, out_dtype, scale_shape, scale_dtype):
        """Run ``npu_rotate_quant`` under FakeTensorMode and check both outputs.

        Args:
            quant_kwargs: keyword arguments forwarded to ``npu_rotate_quant``.
            out_shape, out_dtype: expected shape/dtype of the first output.
            scale_shape, scale_dtype: expected shape/dtype of the second output.
        """
        with FakeTensorMode():
            x = torch.randn(self.M, self.N, dtype=torch.bfloat16).npu()
            rotation = torch.randn(self.K, self.K, dtype=torch.bfloat16).npu()
            output0_npu, output1_npu = torch_npu.npu_rotate_quant(x, rotation, **quant_kwargs)
            self.assertEqual(output0_npu.shape, torch.Size(out_shape))
            self.assertEqual(output0_npu.dtype, out_dtype)
            self.assertEqual(output1_npu.shape, torch.Size(scale_shape))
            self.assertEqual(output1_npu.dtype, scale_dtype)

    @SupportedDevices(['Ascend910B'])
    def test_npu_rotate_quant_int8(self):
        self._check_rotate_quant(
            {"dst_dtype": torch.int8},
            [self.M, self.N], torch.int8,
            [self.M], torch.float32)

    @SupportedDevices(['Ascend910B'])
    def test_npu_rotate_quant_int4(self):
        # int4 results come back in an int32 container with N // 8 columns.
        self._check_rotate_quant(
            {"dst_dtype": torch.quint4x2},
            [self.M, self.N // 8], torch.int32,
            [self.M], torch.float32)

    @SupportedDevices(['Ascend950'])
    def test_npu_rotate_quant_mxfp4(self):
        self._check_rotate_quant(
            {"dst_dtype": torch_npu.float4_e2m1fn_x2, "axis": -1, "round_mode": "rint"},
            [self.M, self.N // 2], torch.uint8,
            [self.M, 16, 2], torch_npu.float8_e8m0fnu)

    @SupportedDevices(['Ascend950'])
    def test_npu_rotate_quant_mxfp8_e5m2(self):
        self._check_rotate_quant(
            {"dst_dtype": torch.float8_e5m2, "axis": -1, "round_mode": "rint"},
            [self.M, self.N], torch.float8_e5m2,
            [self.M, 16, 2], torch_npu.float8_e8m0fnu)

    @SupportedDevices(['Ascend950'])
    def test_npu_rotate_quant_mxfp8_e4m3fn(self):
        self._check_rotate_quant(
            {"dst_dtype": torch.float8_e4m3fn, "axis": -1, "round_mode": "rint"},
            [self.M, self.N], torch.float8_e4m3fn,
            [self.M, 16, 2], torch_npu.float8_e8m0fnu)

@unittest.skip("skip until CANN is updated to support aclnnRmsNormGradQuant.")
class TestNpuRmsNormBackwardQuant(TestCase):
    def test_npu_rms_norm_backward_quant_meta(self):
        dy = torch.randn(2, 4, 64, dtype=torch.float16).to("npu")
        x = torch.randn(2, 4, 64, dtype=torch.float16).to("npu")
        rstd = torch.randn(2, 4, dtype=torch.float32).to("npu")
        gamma = torch.randn(64, dtype=torch.float16).to("npu")
        scale_x = torch.randn(1, dtype=torch.float32).to("npu")
        offset_x = torch.zeros(1, dtype=torch.int32).to("npu")

        actual_dx, actual_dgamma = torch_npu._npu_rms_norm_backward_quant(
            dy, x, rstd, gamma, scale_x, offset_x=offset_x, dst_type=1)

        with FakeTensorMode():
            fake_dy = torch.randn(2, 4, 64, dtype=torch.float16).to("npu")
            fake_x = torch.randn(2, 4, 64, dtype=torch.float16).to("npu")
            fake_rstd = torch.randn(2, 4, dtype=torch.float32).to("npu")
            fake_gamma = torch.randn(64, dtype=torch.float16).to("npu")
            fake_scale_x = torch.randn(1, dtype=torch.float32).to("npu")
            fake_offset_x = torch.zeros(1, dtype=torch.int32).to("npu")

            fake_dx, fake_dgamma = torch_npu._npu_rms_norm_backward_quant(
                fake_dy, fake_x, fake_rstd, fake_gamma, fake_scale_x,
                offset_x=fake_offset_x, dst_type=1)

        self.assertEqual(actual_dx.shape, fake_dx.shape)
        self.assertEqual(actual_dgamma.shape, fake_dgamma.shape)


class FakeTensorFormatCastTest(TestCase):
    """FakeTensor tests for _npu_format_cast meta registrations."""

    ACL_FORMAT_ND = 2
    ACL_FORMAT_NC1HWC0 = 4
    ACL_FORMAT_FRACTAL_NZ = 29
    ACL_FORMAT_NCDHW = 30
    ACL_FORMAT_NDC1HWC0 = 32
    ACL_FORMAT_FRACTAL_Z_3D = 33

    # ---- _npu_format_cast (base overload) ---- #

    def test_format_cast_2d_shape(self):
        with FakeTensorMode():
            x = torch.randn(16, 32, device='npu')
            out = torch.ops.npu._npu_format_cast(x, self.ACL_FORMAT_FRACTAL_NZ)
            self.assertTrue(isinstance(out, FakeTensor))
            self.assertEqual(list(out.shape), [16, 32])
            self.assertEqual(out.dtype, torch.float32)

    def test_format_cast_4d_shape(self):
        with FakeTensorMode():
            x = torch.randn(2, 32, 4, 4, device='npu', dtype=torch.float16)
            out = torch.ops.npu._npu_format_cast(x, self.ACL_FORMAT_NC1HWC0)
            self.assertTrue(isinstance(out, FakeTensor))
            self.assertEqual(list(out.shape), [2, 32, 4, 4])
            self.assertEqual(out.dtype, torch.float16)

    def test_format_cast_5d_shape(self):
        with FakeTensorMode():
            x = torch.randn(1, 16, 2, 4, 4, device='npu')
            out = torch.ops.npu._npu_format_cast(x, self.ACL_FORMAT_NDC1HWC0)
            self.assertTrue(isinstance(out, FakeTensor))
            self.assertEqual(list(out.shape), [1, 16, 2, 4, 4])

    def test_format_cast_nd_to_nd(self):
        with FakeTensorMode():
            x = torch.randn(8, 16, device='npu')
            out = torch.ops.npu._npu_format_cast(x, self.ACL_FORMAT_ND)
            self.assertTrue(isinstance(out, FakeTensor))
            self.assertEqual(list(out.shape), [8, 16])

    def test_format_cast_non_aligned_shape(self):
        with FakeTensorMode():
            x = torch.randn(15, 17, device='npu', dtype=torch.float16)
            out = torch.ops.npu._npu_format_cast(x, self.ACL_FORMAT_FRACTAL_NZ)
            self.assertEqual(list(out.shape), [15, 17])
            self.assertEqual(out.dtype, torch.float16)

    def test_format_cast_int8_shape(self):
        with FakeTensorMode():
            x = torch.randint(-128, 127, (32, 64), device='npu', dtype=torch.int8)
            out = torch.ops.npu._npu_format_cast(x, self.ACL_FORMAT_FRACTAL_NZ)
            self.assertEqual(list(out.shape), [32, 64])
            self.assertEqual(out.dtype, torch.int8)

    def test_format_cast_preserves_device(self):
        with FakeTensorMode():
            x = torch.randn(4, 8, device='npu')
            out = torch.ops.npu._npu_format_cast(x, self.ACL_FORMAT_FRACTAL_NZ)
            self.assertEqual(out.device.type, 'npu')

    # ---- _npu_format_cast.aclnn (with customize_dtype) ---- #

    def test_format_cast_aclnn_overload_shape(self):
        with FakeTensorMode():
            x = torch.randn(16, 32, device='npu', dtype=torch.float16)
            # customize_dtype=5 is float16
            out = torch.ops.npu._npu_format_cast.aclnn(x, self.ACL_FORMAT_FRACTAL_NZ, 5)
            self.assertTrue(isinstance(out, FakeTensor))
            self.assertEqual(list(out.shape), [16, 32])
            self.assertEqual(out.dtype, torch.float16)

    def test_format_cast_aclnn_int8(self):
        with FakeTensorMode():
            x = torch.randint(-128, 127, (32, 64), device='npu', dtype=torch.int8)
            # customize_dtype=1 is int8
            out = torch.ops.npu._npu_format_cast.aclnn(x, self.ACL_FORMAT_FRACTAL_NZ, 1)
            self.assertEqual(list(out.shape), [32, 64])
            self.assertEqual(out.dtype, torch.int8)

    # ---- _npu_format_cast.input_dtype (with customize_dtype + input_dtype) ---- #

    def test_format_cast_input_dtype_overload_shape(self):
        with FakeTensorMode():
            x = torch.randn(16, 32, device='npu', dtype=torch.float16)
            # customize_dtype=5 (fp16), input_dtype=5 (fp16)
            out = torch.ops.npu._npu_format_cast.input_dtype(x, self.ACL_FORMAT_FRACTAL_NZ, 5, 5)
            self.assertTrue(isinstance(out, FakeTensor))
            self.assertEqual(list(out.shape), [16, 32])

    def test_format_cast_input_dtype_int8(self):
        with FakeTensorMode():
            x = torch.randint(-128, 127, (32, 64), device='npu', dtype=torch.int8)
            # customize_dtype=1 (int8), input_dtype=1 (int8)
            out = torch.ops.npu._npu_format_cast.input_dtype(x, self.ACL_FORMAT_FRACTAL_NZ, 1, 1)
            self.assertEqual(list(out.shape), [32, 64])
            self.assertEqual(out.dtype, torch.int8)

    # ---- Multiple dtypes ---- #

    def test_format_cast_bfloat16(self):
        with FakeTensorMode():
            x = torch.randn(8, 16, device='npu', dtype=torch.bfloat16)
            out = torch.ops.npu._npu_format_cast(x, self.ACL_FORMAT_FRACTAL_NZ)
            self.assertEqual(out.dtype, torch.bfloat16)
            self.assertEqual(list(out.shape), [8, 16])

    def test_format_cast_float32(self):
        with FakeTensorMode():
            x = torch.randn(8, 16, device='npu', dtype=torch.float32)
            out = torch.ops.npu._npu_format_cast(x, self.ACL_FORMAT_FRACTAL_NZ)
            self.assertEqual(out.dtype, torch.float32)
            self.assertEqual(list(out.shape), [8, 16])

    def test_format_cast_3d_shape(self):
        with FakeTensorMode():
            x = torch.randn(4, 15, 17, device='npu', dtype=torch.float16)
            out = torch.ops.npu._npu_format_cast(x, self.ACL_FORMAT_FRACTAL_NZ)
            self.assertEqual(list(out.shape), [4, 15, 17])


if __name__ == "__main__":
    # Script entry point: hand control to run_tests (imported from
    # torch.testing._internal.common_utils) to run this module's test cases.
    run_tests()