import contextlib
import copy
import itertools
import os
import unittest
import weakref
from unittest.mock import patch
import numpy as np
import math
import random
import torch
import torch._dynamo
import torch._functorch.config
import torch._prims as prims
import torch_npu
import torch_npu.testing
from torch_npu.testing.common_utils import SupportedDevices
import torch.testing._internal.optests as optests
from torch import distributed as dist
from torch._dynamo.testing import rand_strided
from torch._subclasses.fake_tensor import (
FakeTensor,
FakeTensorMode,
FakeTensorConverter,
DynamicOutputShapeException,
UnsupportedOperatorException,
)
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from torch.testing import FileCheck
from torch.testing._internal.common_device_type import instantiate_device_type_tests, OpDTypes
from torch.testing._internal.common_device_type import ops
from torch.testing._internal.common_utils import (
TestCase, TEST_WITH_TORCHDYNAMO, run_tests, skipIfCrossRef, skipIfRocm, skipIfTorchDynamo, parametrize,
instantiate_parametrized_tests)
from torch.testing._internal.custom_op_db import custom_op_db
from torch.utils._mode_utils import no_dispatch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_flatten
# True when an NPU device is available; tests decorated with
# ``@unittest.skipIf(not RUN_NPU, "requires npu")`` are skipped otherwise.
RUN_NPU = torch.npu.is_available()
def _get_test_torch_version():
torch_npu_version = torch_npu.__version__
version_list = torch_npu_version.split('.')
if len(version_list) > 2:
return f'v{version_list[0]}.{version_list[1]}'
else:
raise RuntimeError("Invalid torch_npu version.")
class FakeTensorTest(TestCase):
    """Behavioural tests for FakeTensorMode on CPU, meta and (when available) NPU tensors."""

    def checkType(self, t, device_str, size):
        """Assert ``t`` is a FakeTensor with device type ``device_str`` and shape ``size``."""
        self.assertIsInstance(t, FakeTensor)
        self.assertEqual(t.device.type, device_str)
        self.assertEqual(list(t.size()), size)

    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_npu_initialized(self):
        """Forward and backward on NPU fake tensors must not error."""
        with FakeTensorMode():
            p = torch.randn(4, 2, requires_grad=True, device='npu')
            x = torch.randn(8, 4, device='npu')
            y = torch.mm(x, p).square().sum()
            y.backward()

    def test_basic(self):
        x = torch.empty(2, 2, device="cpu")
        y = torch.empty(4, 2, 2, device="cpu")
        with FakeTensorMode() as mode:
            x = mode.from_tensor(x)
            y = mode.from_tensor(y)
            z = x + y
            self.assertEqual(z.shape, (4, 2, 2))
            self.assertEqual(z.device, torch.device("cpu"))
            self.assertTrue(isinstance(z, FakeTensor))

    # NOTE(review): _get_test_torch_version() only ever returns "v<major>.<minor>",
    # so the "v2.3.1" entry can never match and this test effectively runs only
    # on v2.1. Confirm whether "v2.3" was intended (and whether `memoized_only`
    # still exists there) before changing it.
    @unittest.skipIf(_get_test_torch_version() not in ["v2.1", "v2.3.1"],
                     "Skipping test for these torch versions.")
    def test_basic_forced_memo_only(self):
        """``memoized_only=True`` returns the memoised fake for known tensors and None otherwise."""
        x = torch.empty(2, 2, device="cpu")
        y = torch.empty(4, 2, 2, device="cpu")
        with FakeTensorMode() as mode:
            # Keep a reference so the memo entry for `x` stays alive.
            x_fake = mode.from_tensor(x)
            x2 = mode.from_tensor(x, memoized_only=True)
            self.assertTrue(x2 is not None)
            y = mode.from_tensor(y, memoized_only=True)
            self.assertIs(y, None)

    def test_custom_op_fallback(self):
        """A custom op with only a CPU kernel is reported as unsupported under fake mode."""
        from torch.library import Library, impl
        test_lib = Library("my_test_op", "DEF")
        test_lib.define('foo(Tensor self) -> Tensor')

        @impl(test_lib, 'foo', 'CPU')
        def foo_impl(self):
            return self.cos()

        x = torch.empty(2, 2, device="cpu")
        with self.assertRaisesRegex(UnsupportedOperatorException, "my_test_op.foo.default"):
            with FakeTensorMode(allow_fallback_kernels=True) as mode:
                x = mode.from_tensor(x)
                torch.ops.my_test_op.foo(x)

    def test_parameter_instantiation(self):
        with FakeTensorMode():
            x = torch.rand([4])
            y = torch.nn.parameter.Parameter(x)
            self.assertTrue(isinstance(y, torch.nn.Parameter))

    @unittest.skipIf(not dist.is_available(), "requires distributed")
    def test_fsdp_flat_param(self):
        # FlatParameter moved to a private module after torch 2.1.
        if "2.1." in torch.__version__:
            from torch.distributed.fsdp.flat_param import FlatParameter
        else:
            from torch.distributed.fsdp._flat_param import FlatParameter
        with FakeTensorMode():
            data = torch.randn(2, 2)
            param = FlatParameter(data, requires_grad=True)
        self.assertIsInstance(param, FlatParameter)
        self.assertIsInstance(param, torch.nn.Parameter)
        self.assertIsInstance(param, FakeTensor)

    def test_non_parameter_grad(self):
        mode = FakeTensorMode()
        t = torch.rand([4], requires_grad=True)
        fake_t = mode.from_tensor(t)
        self.assertEqual(fake_t.requires_grad, t.requires_grad)

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_index_npu_with_cpu(self):
        """Indexing an NPU tensor with a CPU index yields an NPU result."""
        with FakeTensorMode():
            x = torch.rand([2048], device='npu')
            out = x[torch.zeros([36], dtype=torch.int64)]
            self.checkType(out, "npu", [36])

    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_shape_take_not_device(self):
        """``resize_as_`` takes the shape of its argument but keeps its own device."""
        with FakeTensorMode():
            x = torch.empty(1, device="cpu")
            y = torch.empty(8, 8, device="npu")
            out = x.resize_as_(y)
            self.assertEqual(out.shape, (8, 8))
            self.assertEqual(out.device.type, "cpu")
            self.assertTrue(isinstance(out, FakeTensor))

    def test_repr(self):
        with FakeTensorMode():
            x = torch.empty(2, 2, device="cpu")
            self.assertEqual(repr(x), 'FakeTensor(..., size=(2, 2))')
            x = torch.empty(2, 2, device="meta")
            self.assertEqual(repr(x), "FakeTensor(..., device='meta', size=(2, 2))")

    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_zero_dim(self):
        """A 0-dim CPU tensor combined with an NPU tensor yields an NPU result."""
        with FakeTensorMode():
            x = torch.tensor(0.)
            y = torch.rand([4, 4], device="npu")
            out = x + y
            self.assertEqual(out.shape, (4, 4))
            self.assertEqual(out.device, y.device)
            self.assertTrue(isinstance(out, FakeTensor))

    def test_nan_to_num(self):
        with FakeTensorMode():
            for dtype in [torch.float16, torch.float32]:
                x = torch.rand([4], dtype=dtype)
                y = torch.nan_to_num(x, nan=None)
                z = torch.nan_to_num(x, 0.0)
                self.assertEqual(dtype, y.dtype)
                self.assertEqual(dtype, z.dtype)

    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_throw(self):
        """Mixing non-scalar CPU and NPU operands must raise."""
        x = torch.tensor(0.)
        with FakeTensorMode() as mode:
            x_conv = mode.from_tensor(x)
            y = torch.rand([4, 4], device="npu")
            z = torch.rand([4, 4], device="cpu")
            self.assertRaises(Exception, lambda: torch.lerp(x_conv, y, z))

    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_type_as(self):
        with FakeTensorMode():
            x = torch.rand([16, 1], device="cpu")
            y = torch.rand([4, 4], device="npu")
            out = x.type_as(y)
            self.assertEqual(out.device.type, "npu")
            self.assertTrue(isinstance(out, FakeTensor))

    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_setitem(self):
        for device in ["cpu", "npu"]:
            with FakeTensorMode():
                x = torch.rand([16, 1], device=device)
                x[..., 0] = 0

    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_device_inplace_copy(self):
        """``copy_`` keeps the destination's device."""
        with FakeTensorMode():
            x = torch.rand([8, 8], device="cpu")
            y = torch.rand([8, 8], device="npu")
            self.assertEqual(x.copy_(y).device.type, "cpu")
            self.assertEqual(y.copy_(x).device.type, "npu")

    def test_fake_dispatch_keys(self):
        with FakeTensorMode():
            x = torch.rand([4])
            f = FileCheck().check("CPU").check("ADInplaceOrView").check("AutogradCPU").check("AutocastCPU")
            f.run(torch._C._dispatch_key_set(x))
            with torch.inference_mode():
                x = torch.rand([4])
                y = x + x
                FileCheck().check("CPU").check("AutocastCPU").run(torch._C._dispatch_key_set(y))
                FileCheck().check_not("ADInplaceOrView").check_not("Autograd").run(torch._C._dispatch_key_set(y))

    def test_constructor(self):
        with FakeTensorMode():
            x = torch.rand([4, 4], device="cpu")
        self.assertTrue(isinstance(x, FakeTensor))
        self.assertTrue(x.device.type == "cpu")

    def test_mode(self):
        with FakeTensorMode():
            y = torch.rand([4], device="cpu")
            out = y + y
        self.assertTrue(isinstance(out, FakeTensor))

    def test_full(self):
        """``torch.full`` must agree between real and fake execution (checked by CrossRefFakeMode)."""
        with torch._subclasses.CrossRefFakeMode():
            torch.full((4, 4), 1)

    def check_function_with_fake(self, fn):
        """Run ``fn`` for real and under FakeTensorMode, then compare output metadata.

        Outputs are flattened with pytree. Tensor leaves must match in shape,
        dtype, device and strides. Non-tensor leaves must be non-tensors on
        both sides.
        """
        out = fn()
        with torch._subclasses.FakeTensorMode():
            out_fake = fn()
        # tree_flatten returns (leaves, spec); only the leaves are compared.
        real_leaves, _ = tree_flatten(out)
        fake_leaves, _ = tree_flatten(out_fake)
        self.assertEqual(len(real_leaves), len(fake_leaves))
        for a, b in zip(real_leaves, fake_leaves):
            # `a` comes from real execution, so test for Tensor, not FakeTensor.
            if not isinstance(a, torch.Tensor):
                self.assertFalse(isinstance(b, torch.Tensor))
                continue
            prims.utils.compare_tensor_meta(a, b, check_strides=True)

    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_non_kwarg_device(self):
        """``.to(device)`` passed positionally is honoured, and a same-device move is a no-op."""
        with FakeTensorMode():
            x = torch.rand([16, 1], device="cpu")
            y = x.to(torch.device("cpu"))
            self.assertIs(x, y)
            z = x.to(torch.device("npu"))
            self.assertEqual(z.device.type, "npu")

    def test_non_overlapping_stride_zero(self):
        def foo():
            x = torch.empty_strided([1, 3, 427, 640], (0, 1, 1920, 3))
            return x.half()
        self.check_function_with_fake(foo)

    def test_fake_mode_error(self):
        """Using a real tensor inside FakeTensorMode raises a conversion error."""
        x = torch.rand([4, 4])
        with self.assertRaisesRegex(Exception, "Please convert all Tensors"):
            with FakeTensorMode():
                x[0]

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
    def test_fake_grad_copy(self):
        """``from_tensor`` also fakes ``.grad``."""
        x = torch.rand([4, 4], requires_grad=True)
        x.grad = torch.rand([4, 4])
        mode = FakeTensorMode()
        fake_x = mode.from_tensor(x)
        prims.utils.compare_tensor_meta(fake_x, x)
        prims.utils.compare_tensor_meta(fake_x.grad, x.grad)
        self.assertTrue(isinstance(fake_x.grad, FakeTensor))

    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_like_constructor(self):
        with FakeTensorMode():
            x = torch.rand([4, 4])
            y = torch.ones_like(x)
            self.assertTrue(isinstance(y, FakeTensor))
            self.assertEqual(y.device.type, "cpu")
            z = torch.ones_like(x, device="npu")
            self.assertTrue(isinstance(z, FakeTensor))
            self.assertEqual(z.device.type, "npu")

    def test_binary_op_type_promotion(self):
        """float / int64 promotes to float."""
        with FakeTensorMode():
            x = torch.empty([2, 2], dtype=torch.float)
            y = torch.empty([2, 2], dtype=torch.int64)
            # Tensor division never raises ZeroDivisionError (it yields inf/nan),
            # so no exception handling is needed here.
            out = x / y
            self.assertEqual(out.dtype, torch.float)
            self.assertEqual(out.device.type, "cpu")

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
    def test_from_numpy(self):
        with FakeTensorMode():
            x = torch.tensor(np.zeros([4, 4]))
            self.checkType(x, "cpu", [4, 4])

    def test_randperm(self):
        x = torch.randperm(10)
        y = torch.randperm(5, device="cpu")
        with FakeTensorMode():
            x1 = torch.randperm(10)
            prims.utils.compare_tensor_meta(x, x1)
            y1 = torch.randperm(5, device="cpu")
            prims.utils.compare_tensor_meta(y, y1)

    def test_print_in_fake_mode(self):
        """Printing a real tensor inside fake mode must not produce a FakeTensor repr."""
        x = torch.zeros(2)
        with FakeTensorMode():
            out = str(x)
        self.assertNotIn("FakeTensor", out)

    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_upsample_bilinear_small_channels(self):
        """Real and fake upsample outputs must agree on metadata, including strides."""
        out = []
        mode = FakeTensorMode()
        # The first pass runs for real, the second under the fake mode.
        for context in [contextlib.nullcontext, lambda: mode]:
            with context():
                arg0_1 = torch.empty_strided((3, 427, 640), (1, 1920, 3), dtype=torch.float32, device='npu')
                unsqueeze = torch.ops.aten.unsqueeze.default(arg0_1, 0)
                out.append(torch.ops.aten.upsample_bilinear2d.default(unsqueeze, [800, 1199], False))
        self.assertTrue(out[1].is_contiguous())
        self.checkMetaProps(out[0], out[1])

    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_cpu_fallback(self):
        with FakeTensorMode(allow_fallback_kernels=False):
            filters = torch.randn(8, 4, 3, 3).npu()
            inputs = torch.randn(1, 4, 5, 5).npu()
            out = torch.nn.functional.conv2d(inputs, filters, padding=1)
            self.assertEqual(out.device.type, "npu")
            self.assertEqual(list(out.size()), [1, 8, 5, 5])
        with FakeTensorMode(allow_fallback_kernels=True):
            # Intentionally mismatched channel counts: must still raise.
            filters = torch.randn(8, 20, 3, 3).npu()
            inputs = torch.randn(1, 7, 10, 5).npu()
            with self.assertRaises(RuntimeError):
                torch.nn.functional.conv2d(inputs, filters, padding=1)
        with FakeTensorMode(allow_fallback_kernels=True):
            filters = torch.randn(8, 4, 3, 3).npu()
            inputs = torch.randn(1, 4, 5, 5).npu()
            out = torch.nn.functional.conv2d(inputs, filters, padding=1)
            self.assertEqual(out.device.type, "npu")
            self.assertEqual(list(out.size()), [1, 8, 5, 5])

    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_out_multi_device(self):
        with FakeTensorMode():
            x = torch.rand([4])
            y = torch.rand([4], device="npu")
            with self.assertRaisesRegex(Exception, "found two different devices"):
                torch.sin(x, out=y)
            with self.assertRaisesRegex(Exception, "found two different devices"):
                x.add_(y)

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_normalize_device(self):
        """"npu" and "npu:<current>" are treated as the same device."""
        with FakeTensorMode():
            x = torch.empty(1, device="npu")
            y = torch.empty(1, device=f"npu:{torch.npu.current_device()}")
            out = x + y
        self.checkType(out, "npu", [1])

    def test_recursive_invocation(self):
        """Running an op must not reset a manually set ``in_kernel_invocation`` flag."""
        mode = FakeTensorMode()
        with mode:
            x = torch.tensor(2)
            mode.in_kernel_invocation = True
            x + x
            self.assertTrue(mode.in_kernel_invocation)

    @unittest.skip("skip test_lstm now")
    def test_lstm(self):
        with FakeTensorMode(allow_fallback_kernels=False):
            N = 5
            L = 4
            H_in = 2
            hidden_size = 3
            proj_size = 2
            num_layers = 2
            bidir = False
            D = 2 if bidir else 1
            H_out = proj_size if proj_size > 0 else hidden_size
            lstm = torch.nn.LSTM(input_size=H_in, hidden_size=hidden_size,
                                 num_layers=num_layers, proj_size=proj_size, batch_first=False,
                                 bias=True, bidirectional=bidir, device='npu')
            h_0 = torch.randn((num_layers * D, N, H_out), requires_grad=False, device='npu')
            c_0 = torch.randn((num_layers * D, N, hidden_size), requires_grad=False, device='npu')
            inp = torch.randn((L, N, H_in), requires_grad=False, device='npu')
            (output, (h_n, c_n)) = lstm(inp, (h_0, c_0))
            output.sum().backward()
            self.assertEqual(output.shape, (L, N, D * H_out))
            self.assertEqual(h_n.shape, (D * num_layers, N, H_out))
            self.assertEqual(c_n.shape, (D * num_layers, N, hidden_size))

    def test_data_dependent_operator(self):
        """``nonzero`` has a data-dependent output shape and cannot be faked without fallbacks."""
        with FakeTensorMode(allow_fallback_kernels=False):
            x = torch.rand([10, 10])
            self.assertRaises(DynamicOutputShapeException, lambda: torch.nonzero(x))

    def checkMetaProps(self, t1, t2):
        """Assert ``t1`` and ``t2`` agree on shape, dtype, device and strides."""
        prims.utils.compare_tensor_meta(t1, t2, check_strides=True)

    @skipIfCrossRef
    def test_deepcopy(self):
        """FakeCopyMode deep-copies modules into fake tensors, preserving aliasing."""
        with FakeTensorMode() as mode:
            pass
        mod = torch.nn.BatchNorm2d(10)
        with torch._subclasses.fake_tensor.FakeCopyMode(mode):
            mod_copied = copy.deepcopy(mod)

        def check_copy(mod, mod_copied):
            for name, param in itertools.chain(mod.named_parameters(), mod.named_buffers()):
                param_copied = getattr(mod_copied, name)
                self.checkMetaProps(param, param_copied)
                self.assertTrue(isinstance(param_copied, FakeTensor))
                self.assertEqual(isinstance(param, torch.nn.Parameter), isinstance(param_copied, torch.nn.Parameter))
                self.assertEqual(param.requires_grad, param_copied.requires_grad)

        check_copy(mod, mod_copied)

        class ModuleNew(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = torch.rand([10, 2])
                self.b = self.a
                self.c = self.a[0]

        mod = ModuleNew()
        with torch._subclasses.fake_tensor.FakeCopyMode(mode):
            mod_copied = copy.deepcopy(mod)
        self.assertIs(mod_copied.a, mod_copied.b)
        self.assertEqual(mod_copied.b.storage()._cdata, mod_copied.a.storage()._cdata)

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_new(self):
        with FakeTensorMode():
            a = torch.rand([16, 1])
            self.checkType(a.new(10, 10), "cpu", [10, 10])
            self.checkType(a.new([1, 2, 3, 4]), "cpu", [4])
            b = torch.rand([4, 4], device='npu')
            self.checkType(b.new(device='npu'), "npu", [0])
            self.checkType(a.new(torch.rand([1])), "cpu", [1])

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
    def test_scalar_inputs(self):
        with FakeTensorMode():
            self.checkType(torch.div(3, 2), "cpu", [])
            ten = torch.zeros(2, dtype=torch.int32) * 2.0
            self.assertEqual(ten.dtype, torch.float)
            self.checkType(ten, "cpu", [2])

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
    def test_allow_meta(self):
        """Meta tensors are accepted only while ``fake_tensor_allow_meta`` is enabled."""
        def run_meta():
            with FakeTensorMode():
                x = torch.rand([4], device="meta")
                return x + x

        self.checkType(run_meta(), "meta", [4])
        with patch.object(torch._functorch.config, "fake_tensor_allow_meta", False):
            self.assertRaises(Exception, run_meta)

    def test_embedding_bag_meta(self):
        def f():
            embedding = torch.nn.EmbeddingBag(10, 3, mode='sum', device='meta')
            ipt = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long)
            offsets = torch.tensor([0, 4], dtype=torch.long)
            return embedding(ipt, offsets)

        real_out = f()
        with FakeTensorMode():
            fake_out = f()
        for real_t, fake_t in zip(real_out, fake_out):
            self.assertEqual(real_t.size(), fake_t.size())
            self.assertEqual(real_t.device, fake_t.device)

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
    def test_mixed_real_and_fake_inputs(self):
        """With ``allow_non_fake_inputs=True``, real module params mix with fake activations."""
        class _TestPattern(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 1, 1)
                self.bn = torch.nn.BatchNorm2d(1)

            def forward(self, ipt):
                running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
                # Tensor division never raises ZeroDivisionError, so no handler is needed.
                scale_factor = self.bn.weight / running_std
                weight_shape = [1] * len(self.conv.weight.shape)
                weight_shape[0] = -1
                bias_shape = [1] * len(self.conv.weight.shape)
                bias_shape[1] = -1
                scaled_weight = self.conv.weight * scale_factor.reshape(weight_shape)
                zero_bias = torch.zeros_like(self.conv.bias, dtype=ipt.dtype)
                conv = self.conv._conv_forward(ipt, scaled_weight, zero_bias)
                conv_orig = conv / scale_factor.reshape(bias_shape)
                conv_orig = conv_orig + self.conv.bias.reshape(bias_shape)
                conv = self.bn(conv_orig)
                return conv

        mod = _TestPattern()
        with FakeTensorMode(allow_non_fake_inputs=True):
            out = mod(torch.randn(1, 1, 3, 3))
        self.checkType(out, "cpu", (1, 1, 3, 3))

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_batch_norm_gather_stats_update_fake_tensor(self):
        with FakeTensorMode():
            input_t = torch.rand(2, 3, 4, 4, device="npu")
            mean = torch.rand(4, 3, device="npu")
            invstd = torch.rand(4, 3, device="npu")
            running_mean = torch.rand(3, device="npu")
            running_var = torch.rand(3, device="npu")
            counts = torch.tensor([1, 1, 1, 1], dtype=torch.float32, device="npu")
            batch_mean, batch_invstd = torch_npu.batch_norm_gather_stats_update(
                input_t, mean, invstd, running_mean, running_var, 0.1, 1e-5, counts
            )
            self.checkType(batch_mean, "npu", [3])
            self.checkType(batch_invstd, "npu", [3])

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_aten_copy_multi_device(self):
        """``aten.copy`` results follow the device of the first argument (or ``out``)."""
        with FakeTensorMode():
            x1 = torch.rand(4, device="cpu")
            x2 = torch.rand(4, device="npu")
            copy1 = torch.ops.aten.copy.default(x1, x2)
            copy2 = torch.ops.aten.copy.default(x2, x1)
            out = torch.empty(4, device="cpu")
            torch.ops.aten.copy.out(x1, x2, out=out)
        self.checkType(copy1, "cpu", (4,))
        self.checkType(copy2, "npu", (4,))
        self.checkType(out, "cpu", (4,))

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_aten_index_multi_device(self):
        with FakeTensorMode():
            x1 = torch.rand(4, 4, device="cpu")
            x2 = torch.rand(4, 4, device="npu")
            i1 = torch.tensor([0, 1], device="npu")
            i2 = torch.tensor([0, 1], device="cpu")
            r1 = torch.ops.aten.index(x1, i1)
            r2 = torch.ops.aten.index(x2, i2)
            y1 = torch.rand(4, device="cpu")
            y2 = torch.rand(4, device="npu")
            j1 = torch.tensor([2], device="npu")
            j2 = torch.tensor([2], device="cpu")
            r3 = torch.ops.aten.index_put.default(x1, j1, y1)
            r4 = torch.ops.aten.index_put.default(x2, j2, y2)
        self.checkType(r1, "cpu", ())
        self.checkType(r2, "npu", ())
        self.checkType(r3, "cpu", (4, 4))
        self.checkType(r4, "npu", (4, 4))

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_aten_slice_scatter_multi_device(self):
        with FakeTensorMode():
            x1 = torch.rand(4, 4, device="cpu")
            y1 = torch.rand(2, 4, device="npu")
            x2 = torch.rand(4, 4, device="npu")
            y2 = torch.rand(2, 4, device="cpu")
            out = torch.empty(4, 4, device="cpu")
            r1 = torch.ops.aten.slice_scatter.default(x1, y1, start=2)
            r2 = torch.ops.aten.slice_scatter.default(x2, y2, start=2)
            r3 = torch.ops.aten.slice_scatter.out(x1, y1, out=out, start=2)
        self.checkType(r1, "cpu", (4, 4))
        self.checkType(r2, "npu", (4, 4))
        self.checkType(r3, "cpu", (4, 4))
        self.checkType(out, "cpu", (4, 4))

    def test__adaptive_avg_pool2d_backward(self):
        """The gradient keeps the channels_last memory format of the input."""
        with FakeTensorMode():
            grad_out = torch.rand(2, 3, 4, 4)
            inp = torch.rand(2, 3, 4, 4).to(memory_format=torch.channels_last)
            grad_in = torch.ops.aten._adaptive_avg_pool2d_backward(grad_out, inp)
            self.assertTrue(torch._prims_common.suggest_memory_format(grad_in) == torch.channels_last)
class FakeTensorConstHandling(TestCase):
def assertConst(self, *args):
for arg in args:
self.assertTrue(arg.constant is not None)
def assertNotConst(self, *args):
for arg in args:
self.assertTrue(arg.constant is None)
def test_simple(self):
with FakeTensorMode():
x = torch.tensor(4.)
self.assertEqual(x.item(), 4.)
def test_inplace_add(self):
with FakeTensorMode():
x = torch.tensor(4.)
y = x.add_(1)
self.assertEqual(x.item(), 5.)
self.assertEqual(y.item(), 5.)
self.assertConst(x, y)
def test_shared_storages(self):
with FakeTensorMode():
x = torch.tensor([4.])
y = x[:]
self.assertEqual(x.storage()._cdata, y.storage()._cdata)
self.assertEqual(x.constant.storage()._cdata, y.constant.storage()._cdata)
def test_constant_invalidation(self):
with FakeTensorMode():
x = torch.tensor([1.])
self.assertConst(x)
y = torch.rand([1])
x.add_(y)
self.assertNotConst(x)
def test_inplace_view_invalidation(self):
with FakeTensorMode():
x = torch.tensor([1])
self.assertConst(x)
x.resize_([2])
self.assertEqual(x.size(0), 2)
self.assertNotConst(x)
def test_fake_tensor_in_intlist_repro(self):
def fn(tensors):
max_size = torch.tensor([800, 1216], dtype=torch.int64)
batch_shape = [len(tensors)] + list(tensors[0].shape[:-2]) + list(max_size)
return tensors[0].new_full(batch_shape, 0.0)
with self.assertRaises(torch._subclasses.fake_tensor.DataDependentOutputException):
with torch._subclasses.fake_tensor.FakeTensorMode():
a = torch.randn(3, 800, 1199)
b = torch.randn(3, 800, 800)
inputs = [a, b]
ref = fn(inputs)
def test_fake_tensor_batch_norm_cpu(self):
with torch._subclasses.CrossRefFakeMode():
m = torch.nn.Sequential(
torch.nn.BatchNorm2d(10),
torch.nn.ReLU(),
)
m.eval()
out = m(torch.randn([2, 10, 8, 8]))
def test_shared_storage_invalidation(self):
with FakeTensorMode():
x = torch.tensor([1.])
y = x[:]
self.assertConst(x, y)
y.add_(torch.rand([1]))
self.assertNotConst(x, y)
def test_aliased_const_write(self):
with FakeTensorMode():
x = torch.tensor([1])
y = x.expand([4])
self.assertNotConst(y)
y[0] = 1
self.assertNotConst(x)
def test_constant_propagate_through_functions(self):
with FakeTensorMode():
y = torch.div(4, 4, rounding_mode='trunc')
self.assertConst(y)
def contains_type(type_tmp: torch._C.Type, maybe_contained_type: torch._C.Type):
    """Return True if ``maybe_contained_type`` is a subtype of ``type_tmp`` or of any
    type nested inside it, searched recursively through ``containedTypes()``.
    """
    if maybe_contained_type.isSubtypeOf(type_tmp):
        return True
    for inner in type_tmp.containedTypes():
        if contains_type(inner, maybe_contained_type):
            return True
    return False
class FakeTensorOpInfoTest(TestCase):
@ops(custom_op_db, dtypes=OpDTypes.any_one)
def test_fake(self, device, dtype, op):
if "2.1." in torch.__version__:
data_dependent_outputs = {
'NumpyNMSCustomOp',
'NumpyNonzeroCustomOp',
}
sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
for sample_input in sample_inputs_itr:
args = (sample_input.input,) + sample_input.args
kwargs = sample_input.kwargs
if "2.1." in torch.__version__:
optests.fake_check(op, args, kwargs, op.name in data_dependent_outputs)
else:
optests.fake_check(op, args, kwargs)
class FakeTensorConverterTest(TestCase):
    """Tests for FakeTensorConverter memoisation, storage aliasing and weak-reference cleanup."""

    # torch_npu releases whose converter is invoked as ``converter(mode, t)``
    # and whose meta converter needs an explicit expired-storage sweep.
    _LEGACY_CONVERTER_VERSIONS = ("v2.1", "v2.3")

    @classmethod
    def _legacy_converter_api(cls):
        """Return True when the installed torch_npu uses the legacy converter API."""
        return _get_test_torch_version() in cls._LEGACY_CONVERTER_VERSIONS

    @classmethod
    def _convert(cls, converter, mode, t):
        """Convert real tensor ``t`` to a fake tensor using the available converter API."""
        if cls._legacy_converter_api():
            return converter(mode, t)
        return converter.from_real_tensor(mode, t)

    def test_memoized_conversion_to_meta(self):
        x = torch.rand(2, 2, 2)
        mode = FakeTensorMode()
        self.assertTrue(mode.from_tensor(x) is mode.from_tensor(x))

    def test_memoized_conversion_from_meta(self):
        x = torch.rand(2, 2).to(device="meta")
        mode = FakeTensorMode()
        converter = mode.fake_tensor_converter
        self.assertTrue(converter.from_meta_and_device(mode, x, "cpu") is converter.from_meta_and_device(mode, x, "cpu"))

    def test_separate_tensor_storages_view(self):
        """A view and its base convert to fake tensors that share one storage."""
        x = torch.rand(2, 2, 2)
        y = x[0]
        mode = FakeTensorMode()
        converter = mode.fake_tensor_converter
        x_conv = self._convert(converter, mode, x)
        y_conv = self._convert(converter, mode, y)
        self.assertEqual(torch._C._storage_id(x_conv), torch._C._storage_id(y_conv))

    @skipIfTorchDynamo("see pytorch torchdynamo issue 1991")
    def test_separate_tensor_storages_non_view(self):
        """Tensors aliased via ``set_`` share fake storage, and memo entries expire with them."""
        x = torch.rand(2, 2, 2)
        y = torch.rand(4, 2)
        y.set_(x.storage())
        mode = FakeTensorMode()
        converter = mode.fake_tensor_converter
        legacy = self._legacy_converter_api()
        x_conv = self._convert(converter, mode, x)
        y_conv = self._convert(converter, mode, y)
        stor_id = torch._C._storage_id(x_conv)
        self.assertEqual(stor_id, torch._C._storage_id(y_conv))
        del x
        if not legacy:
            # NOTE: on non-legacy releases the fake tensor is also dropped before
            # the memo is checked; presumably the entry's lifetime depends on it.
            del x_conv
        self.assertEqual(len(converter.tensor_memo), 1)
        if legacy:
            converter.meta_converter.check_for_expired_weak_storages()
        self.assertEqual(len(converter.meta_converter.storage_memo), 1)
        del y
        if not legacy:
            del y_conv
        self.assertEqual(len(converter.tensor_memo), 0)
        if legacy:
            converter.meta_converter.check_for_expired_weak_storages()
        self.assertEqual(len(converter.meta_converter.storage_memo), 0)

    @skipIfTorchDynamo("see pytorch torchdynamo issue 1991")
    def test_dead_weak_ref(self):
        """Storage stays shared with a view even after the base's fake tensor is dropped."""
        x = torch.rand(2, 2, 2)
        y = x[0]
        mode = FakeTensorMode()
        converter = FakeTensorConverter()
        legacy = self._legacy_converter_api()
        x_conv = self._convert(converter, mode, x)
        x_conv_storage = torch._C._storage_id(x_conv) if legacy else x_conv.untyped_storage()
        del x_conv
        self.assertFalse(x in converter.tensor_memo)
        y_conv = self._convert(converter, mode, y)
        if legacy:
            self.assertEqual(x_conv_storage, torch._C._storage_id(y_conv))
        else:
            self.assertEqual(x_conv_storage, y_conv.untyped_storage())

    @skipIfTorchDynamo("see pytorch torchdynamo issue 1991")
    def test_dead_key(self):
        """Repeated conversions are memoised, and the entry is dropped once the source dies."""
        x = torch.rand(2, 2, 2)
        mode = FakeTensorMode()
        converter = FakeTensorConverter()
        x_conv = self._convert(converter, mode, x)
        self.assertEqual(len(converter.tensor_memo), 1)
        x_conv2 = self._convert(converter, mode, x)
        self.assertIs(x_conv2, x_conv)
        del x
        if not self._legacy_converter_api():
            del x_conv
            del x_conv2
        self.assertEqual(len(converter.tensor_memo), 0)

    def test_no_active_mode(self):
        """Ops on fake tensors outside any active mode still dispatch through their own mode."""
        with FakeTensorMode() as mode:
            x = torch.empty(2, 2, device="cpu")
            y = torch.empty(2, 2, device="cpu")
        out = x + y
        self.assertEqual(mode, out.fake_mode)
        self.assertTrue(isinstance(out, FakeTensor))
        self.assertEqual(out.device.type, "cpu")

    def test_multiple_modes(self):
        t = torch.rand([4])
        t2 = torch.rand([4])
        with FakeTensorMode() as m:
            with FakeTensorMode() as m2:
                t_fake = m.from_tensor(t)
                t2_fake = m2.from_tensor(t2)
                with self.assertRaisesRegex(Exception, "Mixing fake modes"):
                    t_fake + t2_fake

    def test_separate_mode_error(self):
        """Combining fake tensors created under two different modes must raise."""
        with FakeTensorMode():
            x = torch.empty(2, 2, device="cpu")
        with FakeTensorMode():
            y = torch.empty(2, 2, device="cpu")
        # The operation itself must raise. The previous `lambda: x, y` called a
        # zero-argument lambda with `y`, so it "passed" on a TypeError instead.
        self.assertRaises(Exception, lambda: x + y)

    @skipIfTorchDynamo("see pytorch torchdynamo issue 1991")
    def test_no_ref_cycle(self):
        """Neither the mode nor its fake tensor is kept alive by a reference cycle."""
        x = torch.rand([4])
        mode = FakeTensorMode()
        y = mode.from_tensor(x)
        self.assertEqual(len(mode.fake_tensor_converter.tensor_memo), 1)
        mode_weak = weakref.ref(mode)
        # Track the fake tensor itself (previously this also referenced `mode`).
        y_weak = weakref.ref(y)
        del mode
        del y
        self.assertIsNone(mode_weak())
        self.assertIsNone(y_weak())
class FakeTensorOperatorInvariants(TestCase):
@staticmethod
def get_aten_op(schema):
namespace, name = schema.name.split("::")
overload = schema.overload_name if schema.overload_name else "default"
try:
namespace == "aten"
except AttributeError:
print("AttributeError: torch.ops.aten has no attribute")
return getattr(getattr(torch.ops.aten, name), overload)
@staticmethod
def get_all_aten_schemas():
for schema in torch._C._jit_get_all_schemas():
namespace = schema.name.split("::")[0]
if namespace != "aten":
continue
yield schema
def test_non_kwarg_only_device(self):
for schema in self.get_all_aten_schemas():
ten_type = torch._C.TensorType.get()
if not any(
contains_type(arg.type, ten_type)
for arg in itertools.chain(schema.arguments, schema.returns)
):
continue
opt_device = torch._C.OptionalType(torch._C.DeviceObjType.get())
has_non_kwarg_device = any(
not arg.kwarg_only and arg.type.isSubtypeOf(opt_device)
for arg in schema.arguments
)
if has_non_kwarg_device:
self.assertTrue(
self.get_aten_op(schema) in torch._subclasses.fake_tensor._device_not_kwarg_ops
)
def test_tensor_constructors_all_have_kwarg_device(self):
for schema in self.get_all_aten_schemas():
op = self.get_aten_op(schema)
if not torch._subclasses.fake_tensor._is_tensor_constructor(op):
continue
opt_device = torch._C.OptionalType(torch._C.DeviceObjType.get())
has_kwarg_device = any(
arg.kwarg_only and arg.type.isSubtypeOf(opt_device)
for arg in schema.arguments
)
self.assertTrue(
has_kwarg_device or op == torch.ops.aten._list_to_tensor.default
)
@unittest.expectedFailure
def test_sparse_new(self):
with FakeTensorMode():
indices = torch.randn(1, 1, dtype=torch.int64)
values = torch.randn(1)
extra = (2,)
sparse = torch.randn(1).to_sparse()
sparse2 = sparse.new(indices, values, extra)
def test_tensor_new(self):
with FakeTensorMode():
x = torch.Tensor([1, 2, 3])
self.assertIsInstance(x, FakeTensor)
def test_like_ops(self):
for schema in self.get_all_aten_schemas():
if "_like" == schema.name[-5:]:
op = self.get_aten_op(schema)
self.assertIn(op, torch._subclasses.fake_tensor._like_tensor_constructors)
def test_embedding_bag_private(self):
args = [
torch.ones(6, 1),
torch.ones(6, dtype=torch.int64),
torch.arange(2, dtype=torch.int64),
False,
2,
]
ref_out = torch.ops.aten._embedding_bag(*args)
with FakeTensorMode() as m:
meta_args = [m.from_tensor(a) if isinstance(a, torch.Tensor) else a for a in args]
meta_out = torch.ops.aten._embedding_bag(*meta_args)
self.assertEqual(len(ref_out), len(meta_out))
for ref_o, meta_o in zip(ref_out, meta_out):
self.assertEqual(ref_o.size(), meta_o.size())
def test_cross_entropy_loss(self):
inp = torch.randn(3, 5)
target = torch.randint(5, (3,), dtype=torch.long)
weight = torch.rand(5)
fn = torch.nn.functional.cross_entropy
for w in (weight, None):
args = (inp, target, w)
ref = fn(*args)
with FakeTensorMode() as m:
meta_args = [m.from_tensor(a) if isinstance(a, torch.Tensor) else a for a in args]
meta_out = torch.nn.functional.cross_entropy(*meta_args, label_smoothing=0.5)
self.assertEqual(ref.size(), meta_out.size())
@unittest.skip("skip test_flash_attention now")
def test_flash_attention(self):
class Repro(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, arg1, arg2, arg3):
torch.ops.aten._scaled_dot_product_flash_attention(arg1, arg2, arg3)
args_new = [
((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, "npu"),
((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, "npu"),
((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, "npu"),
]
args = [rand_strided(bsz, num_heads, seq_len, head_dim) for
(bsz, num_heads, seq_len, head_dim) in args_new]
try:
with torch._subclasses.CrossRefFakeMode():
Repro()(*args)
except RuntimeError as e:
self.assertTrue("output[0]" not in str(e))
self.assertTrue("found mismatched tensor metadata for output[6]: Devices cpu and cuda:0 are not equal!" in str(e))
@unittest.skip("skip ci err")
@skipIfRocm
@unittest.skipIf(not RUN_NPU, "requires npu")
def test_conv_c1_backward(self):
class Repro(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, arg1, arg2, arg3):
torch.ops.aten.convolution_backward.default(
arg1,
arg2,
arg3,
[1],
[1, 1],
[1, 1],
[1, 1],
False,
[0, 0],
1,
[True, True, False],
)
args_new = [
((16, 1, 128, 128), (16384, 16384, 128, 1), torch.float16, "npu"),
((16, 64, 128, 128), (1048576, 1, 8192, 64), torch.float16, "npu"),
((1, 64, 3, 3), (576, 9, 3, 1), torch.float16, "npu"),
]
args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args_new]
with torch._subclasses.CrossRefFakeMode():
Repro()(*args)
def test_no_dispatch_with_like_function(self):
    """``zeros_like`` issued under ``no_dispatch`` must not reach an active dispatch mode."""
    class CountingMode(TorchDispatchMode):
        """Dispatch mode that counts every op routed through it."""

        def __init__(self):
            super().__init__()
            self.count = 0

        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            self.count += 1
            # kwargs defaults to None; unpacking None with ** would raise TypeError.
            return func(*args, **(kwargs or {}))

    with FakeTensorMode():
        x = torch.randn(2)
        with CountingMode() as mode:
            with no_dispatch():
                torch.zeros_like(x)
    self.assertEqual(mode.count, 0)
class FakeTensorPropTest(TestCase):
def test_fake_tensor_prop_on_nn_module(self):
class ToyNnModuleWithParameters(torch.nn.Module):
def __init__(self):
super().__init__()
self.layer1 = torch.nn.Linear(4, 3)
self.layer2 = torch.nn.Linear(3, 2)
def forward(self, value):
value = self.layer1(value)
value = torch.relu(value)
value = self.layer2(value)
return value
model = ToyNnModuleWithParameters()
value = torch.randn(5, 4)
graph_model = torch.fx.symbolic_trace(model, (value,))
with FakeTensorMode() as fake_tensor_mode:
def to_fake_tensor(x):
if isinstance(x, torch.Tensor) and not isinstance(x, FakeTensor):
return fake_tensor_mode.from_tensor(x)
return x
chain_iter = itertools.chain(graph_model.named_parameters(), graph_model.named_buffers())
fake_parameters_and_buffers = {
k: to_fake_tensor(v)
for k, v in chain_iter
}
with torch.nn.utils.stateless._reparametrize_module(
graph_model, fake_parameters_and_buffers
):
result = FakeTensorProp(graph_model, fake_tensor_mode).propagate(value)
self.assertTrue(isinstance(result, FakeTensor))
self.assertEqual(result.shape, (5, 2))
failed = False
try:
FakeTensorProp(graph_model).propagate(value)
except AssertionError:
failed = True
self.assertTrue(failed)
def test_fake_tensor_prop_on_nn_module_with_optional_args(self):
class OptionalArgumentInBetween(torch.nn.Module):
def __init__(self):
super().__init__()
self.layer1 = torch.nn.Linear(4, 3)
self.layer2 = torch.nn.Linear(3, 2)
def forward(self, value, another_value=None, another_optional_value=None):
if another_value is None:
another_value = torch.rand_like(value)
if another_optional_value is None:
another_optional_value = torch.rand_like(value)
value = value + another_value + another_optional_value
return value * value
fake_mode = FakeTensorMode(allow_non_fake_inputs=True, allow_fallback_kernels=False)
with fake_mode:
model = OptionalArgumentInBetween()
value = torch.randn(5, 4)
another_optional_value = torch.randn(5, 4)
graph_model = torch.fx.symbolic_trace(model, (value, None, another_optional_value))
FakeTensorProp(graph_model, fake_mode).propagate(value, None, another_optional_value)
class TestFastGelu(TestCase):
    """FakeTensorMode shape checks for ``fast_gelu`` / ``npu_fast_gelu`` forward and backward."""

    @staticmethod
    def _make_input():
        """Return a (2, 3) NPU leaf tensor with ``requires_grad`` set.

        Intended to be called inside an active FakeTensorMode.
        """
        a = torch.randn(2, 3).npu()
        a.requires_grad = True
        return a

    def test_fast_gelu(self):
        with FakeTensorMode():
            a = self._make_input()
            result = torch_npu.fast_gelu(a)
            self.assertEqual(a.shape, result.shape)

    def test_fast_gelu_backward(self):
        with FakeTensorMode():
            a = self._make_input()
            result = torch_npu.fast_gelu(a)
            result.sum().backward()
            self.assertEqual(a.shape, a.grad.shape)

    def test_npu_fast_gelu(self):
        with FakeTensorMode():
            a = self._make_input()
            result = torch_npu.npu_fast_gelu(a)
            self.assertEqual(a.shape, result.shape)

    def test_npu_fast_gelu_backward(self):
        with FakeTensorMode():
            a = self._make_input()
            result = torch_npu.npu_fast_gelu(a)
            result.sum().backward()
            self.assertEqual(a.shape, a.grad.shape)
class TestGelu(TestCase):
    """``npu_gelu`` under FakeTensorMode: forward keeps the shape, backward fills the input grad."""

    @staticmethod
    def _leaf_input():
        """Fake (10, 3, 4) NPU tensor with requires_grad; call inside FakeTensorMode."""
        t = torch.randn(10, 3, 4).npu()
        t.requires_grad = True
        return t

    def test_npu_gelu(self):
        with FakeTensorMode():
            inp = self._leaf_input()
            out = torch.ops.npu.npu_gelu(inp)
            self.assertTrue(out.shape == inp.shape)

    def test_npu_gelu_backward(self):
        with FakeTensorMode():
            inp = self._leaf_input()
            out = torch.ops.npu.npu_gelu(inp)
            out.sum().backward()
            self.assertTrue(inp.grad.shape == inp.shape)
class TestGeluMul(TestCase):
    """``npu_gelu_mul`` halves the last dimension and preserves the input dtype."""

    def test_npu_gelu_mul(self):
        cases = [
            (dtype, mode)
            for dtype in (torch.float32, torch.float16, torch.bfloat16)
            for mode in ("none", "tanh")
        ]
        with FakeTensorMode():
            in_shape = [100, 400]
            for dtype, mode in cases:
                inp = torch.rand(in_shape, dtype=dtype, device="npu")
                *leading, last = inp.shape
                want_shape = (*leading, last // 2)
                out = torch_npu.npu_gelu_mul(inp, approximate=mode)
                self.assertTrue(out.shape == want_shape)
                self.assertTrue(out.dtype == inp.dtype)
class TestIncreFlashAttention(TestCase):
    """Shape/dtype checks for ``npu_incre_flash_attention`` under FakeTensorMode."""

    def testIncreFlashAttention(self):
        """fp16 q/k/v: output keeps q's shape."""
        with FakeTensorMode():
            q = torch.randn(1, 40, 1, 128, dtype=torch.float16).npu()
            k = torch.randn(1, 40, 1, 128, dtype=torch.float16).npu()
            v = torch.randn(1, 40, 1, 128, dtype=torch.float16).npu()
            q.requires_grad = True
            k.requires_grad = True
            v.requires_grad = True
            res = torch.ops.npu.npu_incre_flash_attention(q, k, v)
            self.assertEqual(q.shape, res.shape)

    def test_incre_flash_attention_int8_in(self):
        """int8 q/k/v: output keeps q's shape and is float16."""
        with FakeTensorMode():
            q = torch.randint(1, 40, (1, 128), dtype=torch.int8).npu()
            k = torch.randint(1, 40, (1, 128), dtype=torch.int8).npu()
            v = torch.randint(1, 40, (1, 128), dtype=torch.int8).npu()
            res = torch.ops.npu.npu_incre_flash_attention(q, k, v)
            self.assertEqual(q.shape, res.shape)
            self.assertEqual(res.dtype, torch.half)

    def test_incre_flash_attention_kv_padding_size(self):
        """Passing kv_padding_size with explicit actual_seq_lengths keeps q's shape."""
        with FakeTensorMode():
            q = torch.randn(1, 40, 1, 128, dtype=torch.float16).npu()
            k = torch.randn(1, 40, 1, 128, dtype=torch.float16).npu()
            v = torch.randn(1, 40, 1, 128, dtype=torch.float16).npu()
            kv_padding_size = torch.randint(1, 2, (1,)).npu()
            # A single sequence of length 1.
            actual_seq = [1]
            res = torch.ops.npu.npu_incre_flash_attention(q, k, v, actual_seq_lengths=actual_seq,
                                                          kv_padding_size=kv_padding_size)
            self.assertEqual(q.shape, res.shape)
class TestNpuBmmV2(TestCase):
    """``npu_bmmV2`` must match ``torch.matmul``'s output shape and keep the input dtype."""

    def test_npu_bmmV2(self):
        with FakeTensorMode():
            lhs = torch.randn(10, 3, 4).npu()
            rhs = torch.randn(10, 4, 5).npu()
            # An empty output_size list is passed through unchanged.
            out = torch_npu.npu_bmmV2(lhs, rhs, [])
            self.assertEqual(out.dtype, lhs.dtype)
            reference_shape = torch.matmul(lhs, rhs).shape
            self.assertEqual(out.shape, reference_shape)
class TestNpuBlockSparseAttentionMeta(TestCase):
    """FakeTensorMode checks for ``npu_block_sparse_attention``: forward, backward and autograd.

    All tests live in this single class definition; defining a second class
    with the same name would silently replace this one.
    """

    # Problem size shared by every test: batch, query heads, sequence length, head dim.
    _DIMS = (1, 2, 4, 8)
    _NUM_KV_HEADS = 2
    _BLOCK_SHAPE = (128, 128)

    def _scale_value(self):
        """Softmax scale 1 / sqrt(D)."""
        return 1.0 / (self._DIMS[3] ** 0.5)

    def _make_inputs(self, requires_grad=False):
        """Build BNSD fp16 query/key/value and an all-ones int32 block-sparse mask.

        Call inside an active FakeTensorMode.
        Returns (query, key, value, block_sparse_mask).
        """
        B, N, S, D = self._DIMS
        # Ceiling division: number of blocks covering S along q and along kv.
        ceil_q = (S + self._BLOCK_SHAPE[0] - 1) // self._BLOCK_SHAPE[0]
        ceil_kv = (S + self._BLOCK_SHAPE[1] - 1) // self._BLOCK_SHAPE[1]
        query = torch.randn(B, N, S, D, dtype=torch.float16).npu()
        key = torch.randn(B, self._NUM_KV_HEADS, S, D, dtype=torch.float16).npu()
        value = torch.randn(B, self._NUM_KV_HEADS, S, D, dtype=torch.float16).npu()
        if requires_grad:
            query.requires_grad = True
            key.requires_grad = True
            value.requires_grad = True
        block_sparse_mask = torch.ones(B, N, ceil_q, ceil_kv, dtype=torch.int32).npu()
        return query, key, value, block_sparse_mask

    def _forward(self, query, key, value, block_sparse_mask):
        """Run npu_block_sparse_attention with the BNSD settings shared by all tests."""
        return torch_npu.npu_block_sparse_attention(
            query, key, value, block_sparse_mask, list(self._BLOCK_SHAPE),
            q_input_layout="BNSD", kv_input_layout="BNSD",
            num_key_value_heads=self._NUM_KV_HEADS, scale_value=self._scale_value(), inner_precise=1,
        )

    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_npu_block_sparse_attention_meta(self):
        B, N, S, D = self._DIMS
        with FakeTensorMode():
            query, key, value, block_sparse_mask = self._make_inputs()
            attention_out, softmax_lse = self._forward(query, key, value, block_sparse_mask)
            self.assertEqual(attention_out.shape, (B, N, S, D))
            self.assertEqual(attention_out.dtype, query.dtype)
            self.assertEqual(softmax_lse.shape, (B, N, S, 1))
            self.assertEqual(softmax_lse.dtype, torch.float32)

    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_npu_block_sparse_attention_backward_meta(self):
        B, N, S, D = self._DIMS
        with FakeTensorMode():
            query, key, value, block_sparse_mask = self._make_inputs()
            attention_out, softmax_lse = self._forward(query, key, value, block_sparse_mask)
            d_out = torch.randn(B, N, S, D, dtype=torch.float16).npu()
            d_query, d_key, d_value = torch_npu.npu_block_sparse_attention_backward(
                d_out, query, key, value, attention_out, softmax_lse, block_sparse_mask,
                block_shape=list(self._BLOCK_SHAPE),
                actual_seq_lengths=[S] * B, actual_seq_lengths_kv=[S] * B,
                q_input_layout="BNSD", kv_input_layout="BNSD",
                num_key_value_heads=self._NUM_KV_HEADS, scale_value=self._scale_value(),
            )
            # Each gradient must mirror the shape and dtype of its input.
            for grad, ref in ((d_query, query), (d_key, key), (d_value, value)):
                self.assertEqual(grad.shape, ref.shape)
                self.assertEqual(grad.dtype, ref.dtype)

    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_npu_block_sparse_attention_autograd(self):
        """Check the autograd binding of npu_block_sparse_attention declared in derivatives.yaml.

        After the forward pass, ``.backward()`` must propagate gradients to
        query, key and value.
        """
        with FakeTensorMode():
            query, key, value, block_sparse_mask = self._make_inputs(requires_grad=True)
            attention_out, _ = self._forward(query, key, value, block_sparse_mask)
            attention_out.sum().backward()
            for leaf in (query, key, value):
                self.assertIsNotNone(leaf.grad)
            for leaf in (query, key, value):
                self.assertEqual(leaf.grad.shape, leaf.shape)
class TestNpuWeightQuantBatchMatmul2(TestCase):
    """``npu_weight_quant_batchmatmul`` with hifloat8 weights: (2, 64) x (64, 128) -> (2, 128) fp16."""

    def test_npu_wqbmmV2(self):
        with FakeTensorMode():
            # (shape, dtype) for x, weight and antiquant_scale, drawn in that order.
            operand_specs = (
                ((2, 64), torch.float16),
                ((64, 128), torch.int8),
                ((1, 128), torch.float16),
            )
            x, weight, antiquant_scale = [
                torch.randint(-3, 3, shape, dtype=dt).npu() for shape, dt in operand_specs
            ]
            # Only this tensor's shape and dtype are used as the expectation.
            reference = torch.randint(-1, 1, (2, 128), dtype=torch.float16).npu()
            out = torch_npu.npu_weight_quant_batchmatmul(
                x, weight, antiquant_scale, weight_dtype=torch_npu.hifloat8)
            self.assertTrue(reference.shape == out.shape)
            self.assertTrue(reference.dtype == out.dtype)
class TestNpuDropout(TestCase):
    """``_npu_dropout``: fake outputs match eager output shapes; backward reaches the input."""

    def test_npu_dropout(self):
        # Eager reference on a real NPU tensor, computed outside FakeTensorMode.
        # NOTE(review): this part needs a physical NPU but has no skip guard.
        real_in = torch.randn(2, 3).npu()
        real_in.requires_grad = True
        eager_out = torch_npu._npu_dropout(real_in, 0.5)
        with FakeTensorMode():
            fake_in = torch.randn(2, 3).npu()
            fake_in.requires_grad = True
            fake_out = torch_npu._npu_dropout(fake_in, 0.5)
            for idx in (0, 1):
                self.assertTrue(fake_out[idx].shape == eager_out[idx].shape)

    def test_npu_dropout_backward(self):
        with FakeTensorMode():
            fake_in = torch.randn(2, 3).npu()
            fake_in.requires_grad = True
            outputs = torch_npu._npu_dropout(fake_in, 0.5)
            outputs[0].sum().backward()
            self.assertTrue(fake_in.grad.shape == fake_in.shape)
class TestNpuDtypeCast(TestCase):
    """``npu_dtype_cast`` changes only the dtype; backward gives an input-shaped grad."""

    def test_npu_dtype_cast(self):
        target_dtype = torch.float16
        with FakeTensorMode():
            src = torch.randn((2, 3), dtype=torch.float32).npu()
            casted = torch_npu.npu_dtype_cast(src, target_dtype)
            self.assertEqual(casted.dtype, target_dtype)
            self.assertEqual(casted.shape, src.shape)

    def test_npu_dtype_cast_backward(self):
        target_dtype = torch.float16
        with FakeTensorMode():
            src = torch.randn((2, 3), dtype=torch.float32).npu()
            src.requires_grad = True
            casted = torch_npu.npu_dtype_cast(src, target_dtype)
            casted.sum().backward()
            self.assertEqual(casted.dtype, target_dtype)
            self.assertEqual(src.shape, src.grad.shape)
class TestNpuRotaryMul(TestCase):
    """``npu_rotary_mul`` keeps the embedding's shape and dtype under FakeTensorMode."""

    def test_npu_rotary_mul(self):
        with FakeTensorMode():
            # NOTE(review): the embedding's last dim is 125 while cos/sin use 128;
            # this looks like a typo for 128 — confirm whether it is intentional.
            emb = torch.randn(2, 8192, 5, 125, dtype=torch.float16, requires_grad=True).npu()
            cos, sin = (
                torch.randn(1, 8192, 1, 128, dtype=torch.float16, requires_grad=True).npu()
                for _ in range(2)
            )
            out = torch.ops.npu.npu_rotary_mul(emb, cos, sin)
            self.assertEqual(emb.shape, out.shape)
            self.assertEqual(emb.dtype, out.dtype)

    @unittest.skip("skip test_npu_rotary_mul_matrix now")
    def test_npu_rotary_mul_matrix(self):
        dim = 128
        half = dim // 2
        with FakeTensorMode():
            emb = torch.randn(2, 2, 8192, dim, dtype=torch.bfloat16).npu()
            cos = torch.randn(1, 1, 8192, dim, dtype=torch.bfloat16).npu()
            sin = torch.randn(1, 1, 8192, dim, dtype=torch.bfloat16).npu()
            # Block matrix [[0, I], [-I, 0]] of size dim x dim.
            rotate = torch.zeros(dim, dim, dtype=torch.bfloat16)
            rotate[:half, half:] = torch.eye(half)
            rotate[half:, :half] = -torch.eye(half)
            out = torch_npu.npu_rotary_mul(emb, cos, sin, "half", rotate.npu())
            self.assertEqual(emb.shape, out.shape)
            self.assertEqual(emb.dtype, out.dtype)
class TestNpuRotaryMulBackward(TestCase):
    """``npu_rotary_mul_backward`` returns grads shaped like embedding, cosine and sine."""

    def test_npu_rotary_mul_backward(self):
        with FakeTensorMode():
            grad, embedding = (
                torch.randn(4, 2048, 40, 128, dtype=torch.float16).npu() for _ in range(2)
            )
            cosine, sine = (
                torch.randn(1, 2048, 1, 128, dtype=torch.float16).npu() for _ in range(2)
            )
            ret = torch.ops.npu.npu_rotary_mul_backward(grad, embedding, cosine, sine)
            for idx, ref in enumerate((embedding, cosine, sine)):
                self.assertEqual(ret[idx].shape, ref.shape)
class TestNpuTranspose(TestCase):
    """``npu_transpose`` must produce the same shape as ``Tensor.permute``."""

    def test_npu_transpose(self):
        order = [1, 0, 2, 3]
        with FakeTensorMode():
            src = torch.randn((5, 3, 6, 4)).npu()
            permuted_shape = src.permute(order).shape
            transposed = torch_npu.npu_transpose(src, order)
            self.assertEqual(transposed.shape, permuted_shape)
class TestPromptFlashAttention(TestCase):
    """``npu_prompt_flash_attention`` keeps q's shape under FakeTensorMode."""

    def testPromptFlashAttention(self):
        # torch 2.1 builds use a 3-D (1, 1024, 1024) input; others use 4-D (1, 40, 1, 128).
        qkv_shape = (1, 1024, 1024) if "2.1." in torch.__version__ else (1, 40, 1, 128)
        with FakeTensorMode():
            q, k, v = [torch.randn(*qkv_shape, dtype=torch.float16).npu() for _ in range(3)]
            for t in (q, k, v):
                t.requires_grad = True
            res = torch.ops.npu.npu_prompt_flash_attention(q, k, v)
            self.assertTrue(q.shape == res.shape)
class TestFusedInferAttentionScore(TestCase):
    """``npu_fused_infer_attention_score``'s attention output keeps q's shape."""

    @unittest.skipIf("2.1." not in torch.__version__, "skip this test for torch version other than 2.1")
    def testFusedInferAttentionScore(self):
        with FakeTensorMode():
            q, k, v = [torch.randn(1, 1024, 1024, dtype=torch.float16).npu() for _ in range(3)]
            for t in (q, k, v):
                t.requires_grad = True
            atten_out, _softmax_lse = torch.ops.npu.npu_fused_infer_attention_score(q, k, v)
            self.assertTrue(q.shape == atten_out.shape)
class TestFusedInferAttentionV2(TestCase):
    """Output-shape checks for ``npu_fused_infer_attention_score_v2`` under FakeTensorMode."""

    @unittest.skipIf("2.1." not in torch.__version__, "skip this test for torch version other than 2.1")
    def testFusedInferAttentionV2(self):
        with FakeTensorMode():
            q, k, v = [torch.randn(1, 1024, 1024, dtype=torch.float16).npu() for _ in range(3)]
            for t in (q, k, v):
                t.requires_grad = True
            atten_out, _ = torch.ops.npu.npu_fused_infer_attention_score_v2(q, k, v)
            self.assertTrue(q.shape == atten_out.shape)

    @unittest.skipIf("2.1." not in torch.__version__, "skip this test for torch version other than 2.1")
    def testFusedInferAttentionV2Pa(self):
        with FakeTensorMode():
            q = torch.randn(128, 1, 128, dtype=torch.bfloat16).npu()
            k, v = [torch.randn(1, 1, 8, 128, 16, dtype=torch.bfloat16).npu() for _ in range(2)]
            block_table = torch.randint(0, 10, (1, 1), dtype=torch.int32).npu()
            atten_out, _ = torch.ops.npu.npu_fused_infer_attention_score_v2(q, k, v, block_table=block_table)
            self.assertTrue(q.shape == atten_out.shape)

    def _check_d_unequal(self, qk_shape, v_shape, input_layout, golden_shape):
        """Run FIA v2 where v's last dim differs from q/k's and compare the output shape.

        q and k share ``qk_shape``; the output is compared with a tensor of
        ``golden_shape``.
        """
        with FakeTensorMode():
            q = torch.randn(*qk_shape, dtype=torch.float16).npu()
            k = torch.randn(*qk_shape, dtype=torch.float16).npu()
            v = torch.randn(*v_shape, dtype=torch.float16).npu()
            for t in (q, k, v):
                t.requires_grad = True
            softmax_scale = 1 / 0.0078125
            atten_out, _ = torch.ops.npu.npu_fused_infer_attention_score_v2(
                q, k, v, num_query_heads=8, input_layout=input_layout,
                softmax_scale=softmax_scale, pre_tokens=65535, next_tokens=65535)
            golden_output = torch.randn(*golden_shape, dtype=torch.float16).npu()
            self.assertTrue(golden_output.shape == atten_out.shape)

    def testFusedInferAttentionV2_bnsd_bsnd_d_unequal(self):
        self._check_d_unequal((32, 8, 2048, 192), (32, 8, 2048, 128), "BNSD_BSND", (32, 2048, 8, 128))

    def testFusedInferAttentionV2_bsnd_d_unequal(self):
        self._check_d_unequal((32, 2048, 8, 192), (32, 2048, 8, 128), "BSND", (32, 2048, 8, 128))

    def testFusedInferAttentionV2_bsh_d_unequal(self):
        self._check_d_unequal((32, 2048, 1536), (32, 2048, 1024), "BSH", (32, 2048, 1024))

    def testFusedInferAttentionV2_tnd_d_unequal(self):
        self._check_d_unequal((32, 8, 192), (32, 8, 128), "TND", (32, 8, 128))
class TestFlashAttentionScore(TestCase):
    """Output shape/dtype checks for ``npu_fusion_attention`` (BNSD) under FakeTensorMode."""

    def testFlashAttentionScore(self):
        with FakeTensorMode():
            q, k, v = (torch.randn(1, 40, 16, 128, dtype=torch.float16).npu() for _ in range(3))
            # Shape/dtype reference for outputs 1 and 2 (presumably softmax max / sum).
            stats_ref = torch.randn(1, 40, 16, 8, dtype=torch.float32).npu()
            for t in (q, k, v):
                t.requires_grad = True
            res = torch.ops.npu.npu_fusion_attention(q, k, v, head_num=40, input_layout="BNSD")
            self.assertEqual(q.shape, res[0].shape)
            self.assertEqual(q.dtype, res[0].dtype)
            for idx in (1, 2):
                self.assertEqual(stats_ref.shape, res[idx].shape)
                self.assertEqual(stats_ref.dtype, res[idx].dtype)
class TestFlashAttentionScoreGrad(TestCase):
    """Gradient shape/dtype checks for ``npu_fusion_attention_grad`` (BNSD)."""

    def testFlashAttentionScoreGrad(self):
        with FakeTensorMode():
            q, k, v, dy, attention_in = [
                torch.randn(1, 40, 16, 128, dtype=torch.float16).npu() for _ in range(5)
            ]
            softmax_max, softmax_sum = [
                torch.randn(1, 40, 16, 8, dtype=torch.float32).npu() for _ in range(2)
            ]
            for t in (q, k, v, dy, attention_in, softmax_max, softmax_sum):
                t.requires_grad = True
            res = torch.ops.npu.npu_fusion_attention_grad(
                q, k, v, dy, head_num=40, input_layout="BNSD",
                softmax_max=softmax_max, softmax_sum=softmax_sum, attention_in=attention_in)
            # Outputs 0..2 are compared against q, k and k respectively.
            for idx, ref in enumerate((q, k, k)):
                self.assertEqual(ref.shape, res[idx].shape)
                self.assertEqual(ref.dtype, res[idx].dtype)
class TestFlashAttentionTensorScore(TestCase):
    """Output shape/dtype checks for ``npu_fusion_attention_v3`` (BNSD) under FakeTensorMode."""

    def testFlashAttentionScore(self):
        with FakeTensorMode():
            q, k, v = (torch.randn(1, 40, 16, 128, dtype=torch.float16).npu() for _ in range(3))
            # Shape/dtype reference for outputs 1 and 2 (presumably softmax max / sum).
            stats_ref = torch.randn(1, 40, 16, 8, dtype=torch.float32).npu()
            for t in (q, k, v):
                t.requires_grad = True
            res = torch.ops.npu.npu_fusion_attention_v3(q, k, v, head_num=40, input_layout="BNSD")
            self.assertEqual(q.shape, res[0].shape)
            self.assertEqual(q.dtype, res[0].dtype)
            for idx in (1, 2):
                self.assertEqual(stats_ref.shape, res[idx].shape)
                self.assertEqual(stats_ref.dtype, res[idx].dtype)
class TestFlashAttentionTensorScoreGrad(TestCase):
    """Gradient shape/dtype checks for ``npu_fusion_attention_grad_v3`` (BNSD)."""

    def testFlashAttentionScoreGrad(self):
        with FakeTensorMode():
            q, k, v, dy, attention_in = [
                torch.randn(1, 40, 16, 128, dtype=torch.float16).npu() for _ in range(5)
            ]
            softmax_max, softmax_sum = [
                torch.randn(1, 40, 16, 8, dtype=torch.float32).npu() for _ in range(2)
            ]
            for t in (q, k, v, dy, attention_in, softmax_max, softmax_sum):
                t.requires_grad = True
            res = torch.ops.npu.npu_fusion_attention_grad_v3(
                q, k, v, dy, head_num=40, input_layout="BNSD",
                softmax_max=softmax_max, softmax_sum=softmax_sum, attention_in=attention_in)
            # Outputs 0..2 are compared against q, k and k respectively.
            for idx, ref in enumerate((q, k, k)):
                self.assertEqual(ref.shape, res[idx].shape)
                self.assertEqual(ref.dtype, res[idx].dtype)
class TestFusedFloydAttention(TestCase):
    """Output shape/dtype checks for ``npu_fused_floyd_attention`` under FakeTensorMode."""

    def testFusedFloydAttention(self):
        batch_size = 2
        seq_len_q = 40
        seq_len_kv = 64
        head_num = 16
        head_dim = 128
        dtype_qkv = torch.float16
        dtype_mask = torch.float16
        with FakeTensorMode():
            query_ik = torch.randn(batch_size, seq_len_q, head_num, head_dim, dtype=dtype_qkv).npu()
            key_ij = torch.randn(batch_size, seq_len_kv, head_num, head_dim, dtype=dtype_qkv).npu()
            value_ij = torch.randn_like(key_ij, dtype=dtype_qkv).npu()
            key_jk = torch.randn_like(key_ij, dtype=dtype_qkv).npu()
            value_jk = torch.randn_like(key_ij, dtype=dtype_qkv).npu()
            # Build the mask with dtype_mask (it was declared but unused, leaving
            # the mask float32), matching TestFusedFloydAttentionGrad.
            atten_mask = torch.randn(batch_size, 1, seq_len_q, seq_len_kv,
                                     dtype=dtype_mask, requires_grad=False).npu()
            out0, out1, out2 = torch_npu.npu_fused_floyd_attention(
                query_ik, key_ij, value_ij, key_jk, value_jk,
                atten_mask=atten_mask, scale_value=1.0)
            expected_out0_shape = (batch_size, seq_len_q, head_num, head_dim, 8)
            self.assertEqual(out0.shape, expected_out0_shape)
            self.assertEqual(out0.dtype, torch.float32)
            self.assertEqual(out1.shape, out0.shape)
            self.assertEqual(out1.dtype, out0.dtype)
            self.assertEqual(out2.shape, query_ik.shape)
            self.assertEqual(out2.dtype, query_ik.dtype)
class TestFusedFloydAttentionGrad(TestCase):
    """Each gradient from ``npu_fused_floyd_attention_backward`` mirrors its input's shape/dtype."""

    def testFusedFloydAttentionGrad(self):
        bsz, sq, skv, heads, hdim = 2, 40, 64, 16, 128
        qkv_dtype = torch.float16
        stat_dtype = torch.float32
        mask_dtype = torch.float16
        with FakeTensorMode():
            query_ik = torch.randn(bsz, sq, heads, hdim, dtype=qkv_dtype).npu()
            key_ij = torch.randn(bsz, skv, heads, hdim, dtype=qkv_dtype).npu()
            value_ij, key_jk, value_jk = [
                torch.randn_like(key_ij, dtype=qkv_dtype).npu() for _ in range(3)
            ]
            atten_mask = torch.randn(bsz, 1, sq, skv, dtype=mask_dtype, requires_grad=False).npu()
            grad_output = torch.randn_like(query_ik, dtype=qkv_dtype).npu()
            stat_shape = (bsz, sq, heads, hdim, 8)
            softmax_max, softmax_sum, attention_out = [
                torch.randn(stat_shape, dtype=stat_dtype).npu() for _ in range(3)
            ]
            inputs = (query_ik, key_ij, value_ij, key_jk, value_jk)
            for t in inputs:
                t.requires_grad = True
            dquery, dkey_0, dvalue_0, dkey_1, dvalue_1 = torch_npu.npu_fused_floyd_attention_backward(
                grad_output, *inputs,
                attention_out, softmax_max, softmax_sum,
                atten_mask=atten_mask,
                scale_value=1.0)
            for grad, ref in zip((dquery, dkey_0, dvalue_0, dkey_1, dvalue_1), inputs):
                self.assertEqual(grad.shape, ref.shape)
                self.assertEqual(grad.dtype, ref.dtype)
class TestNpuMoeComputeExpertTokens(TestCase):
    """``npu_moe_compute_expert_tokens`` output matches the sorted experts' shape and dtype."""

    @unittest.skipIf('2.3.1' in torch.__version__, "skip this ut for this torch version")
    def test_npu_moe_compute_expert_tokens(self):
        num_experts = 10
        with FakeTensorMode():
            experts = torch.tensor(list(range(num_experts)), dtype=torch.int32).npu()
            sorted_experts, _indices = torch.sort(experts)
            out = torch.ops.npu.npu_moe_compute_expert_tokens(sorted_experts, num_experts)
            self.assertEqual(sorted_experts.shape, out.shape)
            self.assertEqual(sorted_experts.dtype, out.dtype)
class TestMaskedSoftmaxWithRelPosBias(TestCase):
    """``npu_masked_softmax_with_rel_pos_bias`` keeps x's shape under FakeTensorMode."""

    @unittest.skipIf("2.1." not in torch.__version__, "skip this test for torch version other than 2.1")
    def testMaskedSoftmaxWithRelPosBias(self):
        with FakeTensorMode():
            # NOTE(review): unlike most tests in this file these tensors are not
            # moved with .npu(); confirm that CPU fake inputs are intended.
            x = torch.randn(96, 2, 2, 32, 32, dtype=torch.float)
            relative_pos_bias = torch.randn(1, 1, 2, 32, 32, dtype=torch.float)
            atten_mask = torch.randn(1, 2, 1, 32, 32, dtype=torch.float)
            for t in (x, atten_mask, relative_pos_bias):
                t.requires_grad = True
            out = torch.ops.npu.npu_masked_softmax_with_rel_pos_bias(x, atten_mask, relative_pos_bias)
            self.assertTrue(out.shape == x.shape)
class TestNpuMoeInitRouting(TestCase):
    """``npu_moe_init_routing`` outputs match golden shapes/dtypes for 3 tokens x top-2."""

    @unittest.skipIf("2.1." not in torch.__version__, "skip this test for torch version other than 2.1")
    def testNpuMoeInitRouting(self):
        with FakeTensorMode():
            x = torch.randn(3, 4, dtype=torch.float).npu()
            row_idx = torch.randint(0, 6, (3, 2), dtype=torch.int32).npu()
            expert_idx = torch.randint(0, 10, (3, 2), dtype=torch.int32).npu()
            # Goldens for expanded_x, expanded_row_idx, expanded_expert_idx (shape/dtype only).
            goldens = (
                torch.randn(6, 4, dtype=torch.float).npu(),
                torch.randint(0, 6, (6, ), dtype=torch.int32).npu(),
                torch.randint(0, 6, (6, ), dtype=torch.int32).npu(),
            )
            expanded_x, expanded_row_idx, expanded_expert_idx = torch.ops.npu.npu_moe_init_routing(
                x, row_idx, expert_idx, active_num=3)
            outputs = (expanded_x, expanded_row_idx, expanded_expert_idx)
            for out, gold in zip(outputs, goldens):
                self.assertTrue(out.dtype == gold.dtype)
            for out, gold in zip(outputs, goldens):
                self.assertTrue(out.shape == gold.shape)
class TestNpuMoeGatingTopKSoftmax(TestCase):
    """``npu_moe_gating_top_k_softmax`` with k=2 yields (3, 2) fp32 values and int32 indices."""

    @unittest.skipIf("2.1." not in torch.__version__, "skip this test for torch version other than 2.1")
    def testNpuMoeGatingTopKSoftmax(self):
        with FakeTensorMode():
            x = torch.randn(3, 4, dtype=torch.float).npu()
            # Goldens for y, expert_idx, row_idx (shape/dtype only).
            goldens = (
                torch.randn(3, 2, dtype=torch.float).npu(),
                torch.randint(-1, 1, (3, 2), dtype=torch.int32).npu(),
                torch.randint(-1, 1, (3, 2), dtype=torch.int32).npu(),
            )
            y, expert_idx, row_idx = torch.ops.npu.npu_moe_gating_top_k_softmax(x, None, k=2)
            outputs = (y, expert_idx, row_idx)
            for out, gold in zip(outputs, goldens):
                self.assertTrue(out.dtype == gold.dtype)
            for out, gold in zip(outputs, goldens):
                self.assertTrue(out.shape == gold.shape)
class TestNpuRopeQuantKVCache(TestCase):
    """Meta checks for ``npu_rope_quant_kvcache`` (currently skipped)."""

    @unittest.skip("skip test_npu_rope_quant_kvcache_meta now")
    def test_npu_rope_quant_kvcache_meta(self):
        def uniform_npu(shape, np_dtype, torch_dtype):
            """Draw U(0, 1) numpy data cast to np_dtype and return it as an NPU tensor."""
            data = np.random.uniform(0, 1, shape).astype(np_dtype)
            return torch.from_numpy(data).to(torch_dtype).npu()

        with FakeTensorMode() as fake_mode:
            in_x = uniform_npu([1, 1, 128 * 3], np.float16, torch.float16)
            in_cos = uniform_npu([1, 1, 1, 128], np.float16, torch.float16)
            in_sin = uniform_npu([1, 1, 1, 128], np.float16, torch.float16)
            # astype(int8) truncates U(0, 1) samples, so the caches are all zeros.
            in_k_cache = uniform_npu([1, 2, 1, 128], np.int8, torch.int8)
            in_v_cache = uniform_npu([1, 2, 1, 128], np.int8, torch.int8)
            in_indices = torch.tensor([0]).to(torch.int32).npu()
            in_scale_k = torch.randn([128], dtype=torch.float32).npu()
            in_scale_v = torch.randn([128], dtype=torch.float32).npu()
            size_splits = [128, 128, 128]
            (fake_x, fake_cos, fake_sin, fake_indices,
             fake_k_cache, fake_v_cache, fake_scale_k, fake_scale_v) = (
                fake_mode.from_tensor(t)
                for t in (in_x, in_cos, in_sin, in_indices,
                          in_k_cache, in_v_cache, in_scale_k, in_scale_v)
            )
            for t in (fake_x, fake_cos, fake_sin, fake_k_cache, fake_v_cache, fake_scale_k, fake_scale_v):
                self.assertIsNotNone(t)
            q_result, k_result, c_result, k_cache_result, v_cache_result = torch.ops.npu.npu_rope_quant_kvcache(
                fake_x, fake_cos, fake_sin, fake_k_cache, fake_v_cache,
                fake_indices, fake_scale_k, fake_scale_v, size_splits)
            checks = (
                (q_result, in_x, torch.Size([1, 1, 1, 128])),
                (k_cache_result, in_k_cache, in_k_cache.shape),
            )
            for result, ref, want_shape in checks:
                self.assertEqual(result.shape, want_shape)
                self.assertEqual(result.dtype, ref.dtype)
                self.assertEqual(result.device, ref.device)
                self.assertTrue(isinstance(result, FakeTensor))
class TestGeGlu(TestCase):
    """Fake-tensor shape checks for ``torch.ops.npu.npu_geglu``."""

    # Must start with "test": unittest only collects methods with that prefix,
    # so the former name ``TestGeGlu`` meant this check never ran.
    def test_npu_geglu(self):
        """Both outputs halve ``dim`` and keep every other dimension of ``x``."""
        with FakeTensorMode():
            x = torch.randn(2, 10, 64, dtype=torch.float)
            dim = -1
            approximate = 1
            activate_left = False
            y, gelu = torch.ops.npu.npu_geglu(x, dim, approximate, activate_left)
            dim_num = x.dim()
            # Normalise a negative dim so it can be compared with loop indices.
            if dim < 0:
                dim += dim_num
            for index in range(dim_num):
                if index != dim:
                    self.assertEqual(y.size(index), x.size(index))
                    self.assertEqual(gelu.size(index), x.size(index))
                else:
                    # The split dimension is halved in both outputs.
                    self.assertEqual(y.size(index) * 2, x.size(index))
                    self.assertEqual(gelu.size(index) * 2, x.size(index))
class TestScatterUpdateMeta(TestCase):
    """Fake-tensor checks for ``scatter_update`` and its in-place variant ``scatter_update_``."""

    def _make_scatter_update_fakes(self, mode, dtype):
        """Build one case's NPU inputs and convert each to a fake tensor under ``mode``.

        Returns ``(base, (fake_self, fake_indices, fake_updates))`` where ``base`` is
        the original ``self`` input; every fake is asserted non-None first.
        """
        base = torch.randn(4, 4, 32, 256, dtype=dtype).npu()
        indices = torch.tensor([1, 1, 1, 1]).npu()
        updates = torch.randn(4, 4, 1, 256, dtype=dtype).npu()
        fakes = tuple(mode.from_tensor(t) for t in (base, indices, updates))
        for fake in fakes:
            self.assertIsNotNone(fake)
        return base, fakes

    def _assert_fake_like(self, result, reference):
        """Assert ``result`` is a FakeTensor with ``reference``'s shape, dtype and device."""
        for attr in ("shape", "dtype", "device"):
            self.assertEqual(getattr(result, attr), getattr(reference, attr))
        self.assertTrue(isinstance(result, FakeTensor))

    def test_scatter_update_meta(self):
        with FakeTensorMode() as mode:
            in_self, (fake_self, fake_indices, fake_updates) = \
                self._make_scatter_update_fakes(mode, torch.float16)
            fake_result = torch.ops.npu.scatter_update(fake_self, fake_indices, fake_updates, -2)
            self._assert_fake_like(fake_result, in_self)
            # The out-of-place op must hand back a different tensor object.
            self.assertIsNot(fake_result, fake_self)
            self.assertIsNot(fake_result, in_self)

    def test_scatter_update__meta(self):
        with FakeTensorMode() as mode:
            in_self, (fake_self, fake_indices, fake_updates) = \
                self._make_scatter_update_fakes(mode, torch.float32)
            fake_result = torch.ops.npu.scatter_update_(fake_self, fake_indices, fake_updates, -2)
            self._assert_fake_like(fake_result, in_self)
            # The in-place op must return the same fake object it was given.
            self.assertIs(fake_result, fake_self)
            self.assertIsNot(fake_result, in_self)
class TestNpuQuantScatterMeta(TestCase):
    """Fake-tensor checks for ``npu_quant_scatter`` and its in-place variant ``npu_quant_scatter_``."""

    def _make_quant_scatter_fakes(self, mode):
        """Build int8 ``var``, int32 ``indices``, bf16 ``updates``/``quant_scales`` and their fakes.

        Returns ``(var, fakes)`` with ``fakes`` in the op's positional order;
        every fake is asserted non-None first.
        """
        def npu_from_uniform(low, high, shape, np_dtype, torch_dtype):
            host = np.random.uniform(low, high, shape).astype(np_dtype)
            return torch.from_numpy(host).to(torch_dtype).npu()

        var = npu_from_uniform(0, 1, [1, 1, 32], np.int8, torch.int8)
        indices = npu_from_uniform(0, 1, [1], np.int32, torch.int32)
        updates = npu_from_uniform(1, 2, [1, 1, 32], np.float16, torch.bfloat16)
        quant_scales = npu_from_uniform(0, 1, [1, 1, 32], np.float16, torch.bfloat16)
        fakes = tuple(mode.from_tensor(t) for t in (var, indices, updates, quant_scales))
        for fake in fakes:
            self.assertIsNotNone(fake)
        return var, fakes

    def _assert_fake_like(self, result, reference):
        """Assert ``result`` is a FakeTensor with ``reference``'s shape, dtype and device."""
        for attr in ("shape", "dtype", "device"):
            self.assertEqual(getattr(result, attr), getattr(reference, attr))
        self.assertTrue(isinstance(result, FakeTensor))

    def test_npu_quant_scatter_meta(self):
        with FakeTensorMode() as mode:
            in_var, fakes = self._make_quant_scatter_fakes(mode)
            fake_result = torch.ops.npu.npu_quant_scatter(*fakes, None, -2, -1, "update")
            self._assert_fake_like(fake_result, in_var)
            # Out-of-place: a new tensor object is expected.
            self.assertIsNot(fake_result, fakes[0])
            self.assertIsNot(fake_result, in_var)

    def test_npu_quant_scatter__meta(self):
        with FakeTensorMode() as mode:
            in_var, fakes = self._make_quant_scatter_fakes(mode)
            fake_result = torch.ops.npu.npu_quant_scatter_(*fakes, None, -2, -1, "update")
            self._assert_fake_like(fake_result, in_var)
            # In-place: the fake ``var`` itself is expected back.
            self.assertIs(fake_result, fakes[0])
            self.assertIsNot(fake_result, in_var)
class TestNpuApplyRotoryPosEmbMeta(TestCase):
    """Fake-tensor checks for ``npu_apply_rotary_pos_emb`` with layout ``"BSND"`` and mode ``"half"``."""

    def test_npu_apply_rotary_pos_emb_meta(self):
        qk_shape = [4, 1024, 16, 128]
        cos_sin_shape = [4, 1024, 1, 128]
        with FakeTensorMode() as mode:
            data_query, data_key, data_cos, data_sin = [
                torch.from_numpy(np.random.uniform(0, 1, shape).astype(np.float16)).to(torch.float16).npu()
                for shape in (qk_shape, qk_shape, cos_sin_shape, cos_sin_shape)
            ]
            fakes = [mode.from_tensor(t) for t in (data_query, data_key, data_cos, data_sin)]
            for fake in fakes:
                self.assertIsNotNone(fake)
            fake_query_result, fake_key_result = torch.ops.npu.npu_apply_rotary_pos_emb(*fakes, "BSND", "half")
            # Each output is expected to match its input's shape, dtype and device.
            for result, reference in ((fake_query_result, data_query), (fake_key_result, data_key)):
                self.assertEqual(result.shape, reference.shape)
                self.assertEqual(result.dtype, reference.dtype)
                self.assertEqual(result.device, reference.device)
                self.assertTrue(isinstance(result, FakeTensor))
class TestMmAllReduce(TestCase):
    """Fake-tensor checks for ``npu_mm_all_reduce_base`` output shape and dtype."""

    def test_mm_all_reduce(self):
        """fp16 (128, 256) @ (256, 128) yields an fp16 (128, 128) result."""
        with FakeTensorMode():
            dst_dtype = torch.float16
            x1 = torch.randn(128, 256, dtype=torch.float16).npu()
            x2 = torch.randn(256, 128, dtype=torch.float16).npu()
            hcom = "fake group info"
            output = torch_npu.npu_mm_all_reduce_base(x1, x2, hcom, reduce_op="sum")
            self.assertEqual(output.shape, (128, 128))
            self.assertEqual(output.dtype, dst_dtype)

    def test_mm_all_reduce_quant(self):
        """int8 operands with an int64 ``dequant_scale`` yield an fp16 result."""
        with FakeTensorMode():
            dst_dtype = torch.float16
            x1 = torch.randn(128, 256, dtype=torch.float16).to(torch.int8).npu()
            x2 = torch.randn(256, 128, dtype=torch.float16).to(torch.int8).npu()
            dequant = torch.randn(128, dtype=torch.float16).to(torch.int64).npu()
            hcom = "fake group info"
            output = torch_npu.npu_mm_all_reduce_base(x1, x2, hcom, reduce_op="sum", dequant_scale=dequant)
            self.assertEqual(output.shape, (128, 128))
            self.assertEqual(output.dtype, dst_dtype)

    # Distinct name required: reusing ``test_mm_all_reduce_quant`` would make this
    # definition replace the int64-dequant test above in the class namespace.
    def test_mm_all_reduce_quant_bf16(self):
        """int8 operands with a bf16 ``dequant_scale`` yield a bf16 result."""
        with FakeTensorMode():
            dst_dtype = torch.bfloat16
            x1 = torch.randn(128, 256, dtype=torch.float16).to(torch.int8).npu()
            x2 = torch.randn(256, 128, dtype=torch.float16).to(torch.int8).npu()
            dequant = torch.randn(128, dtype=torch.bfloat16).npu()
            hcom = "fake group info"
            output = torch_npu.npu_mm_all_reduce_base(x1, x2, hcom, reduce_op="sum", dequant_scale=dequant)
            self.assertEqual(output.shape, (128, 128))
            self.assertEqual(output.dtype, dst_dtype)
class TestMmReduceScatter(TestCase):
    """Fake-tensor checks for ``npu_mm_reduce_scatter_base``: output rows shrink by ``world_size``."""

    def test_mm_reduce_scatter(self):
        """fp16 (m, k) @ (k, n) yields an fp16 (m // world_size, n) result."""
        with FakeTensorMode():
            dst_dtype = torch.float16
            m, k, n = 128, 512, 256
            x1 = torch.randn(m, k, dtype=torch.float16).npu()
            x2 = torch.randn(k, n, dtype=torch.float16).npu()
            hcom = "fake group info"
            world_size = 8
            output = torch_npu.npu_mm_reduce_scatter_base(x1, x2, hcom, world_size, reduce_op="sum")
            self.assertEqual(output.shape, (m // world_size, n))
            self.assertEqual(output.dtype, dst_dtype)

    # Named for the op under test; it was previously copy-pasted as
    # ``test_mm_all_reduce_quant`` although it exercises reduce-scatter.
    def test_mm_reduce_scatter_quant(self):
        """int8 operands with fp32 ``x1_scale``/``x2_scale`` yield a bf16 (m // world_size, n) result."""
        with FakeTensorMode():
            world_size = 8
            m, k, n = 128, 512, 256
            dst_dtype = torch.bfloat16
            x1 = torch.randint(-10, 10, size=(m, k), dtype=torch.int8).npu()
            x2 = torch.randint(-10, 10, size=(k, n), dtype=torch.int8).npu()
            x1_scale = torch.randn((m, 1), dtype=torch.float32).npu()
            x2_scale = torch.randn((1, n), dtype=torch.float32).npu()
            hcom = "fake group info"
            output = torch_npu.npu_mm_reduce_scatter_base(x1, x2, hcom, world_size, x1_scale=x1_scale,
                                                          x2_scale=x2_scale, reduce_op="sum")
            self.assertEqual(output.shape, (m // world_size, n))
            self.assertEqual(output.dtype, dst_dtype)
class TestAllGatherBaseMm(TestCase):
    """Fake-tensor checks for ``npu_all_gather_base_mm``: outputs have ``world_size`` times x1's rows."""

    def test_all_gather_base_mm(self):
        with FakeTensorMode():
            m, k, n = 128, 512, 256
            world_size = 8
            half = torch.float16
            lhs = torch.randn(m, k, dtype=half).npu()
            rhs = torch.randn(k, n, dtype=half).npu()
            output, gather_out = torch_npu.npu_all_gather_base_mm(lhs, rhs, "fake group info", world_size)
            gathered_rows = m * world_size
            self.assertEqual(output.shape, (gathered_rows, n))
            self.assertEqual(output.dtype, half)
            # gather_out is expected to be (m * world_size, k) in the input dtype.
            self.assertEqual(gather_out.shape, (gathered_rows, k))
            self.assertEqual(gather_out.dtype, half)

    def test_all_gather_base_mm_gather_out_false(self):
        with FakeTensorMode():
            m, k, n = 128, 512, 256
            world_size = 8
            half = torch.float16
            lhs = torch.randn(m, k, dtype=half).npu()
            rhs = torch.randn(k, n, dtype=half).npu()
            # With gather_output=False only the matmul result is checked.
            output, _ = torch_npu.npu_all_gather_base_mm(lhs, rhs, "fake group info", world_size,
                                                         gather_output=False)
            self.assertEqual(output.shape, (m * world_size, n))
            self.assertEqual(output.dtype, half)

    def test_all_gather_base_mm_quant(self):
        with FakeTensorMode():
            world_size = 8
            m, k, n = 128, 512, 256
            lhs = torch.randint(-10, 10, size=(m, k), dtype=torch.int8).npu()
            rhs = torch.randint(-10, 10, size=(k, n), dtype=torch.int8).npu()
            lhs_scale = torch.randn((m, 1), dtype=torch.float32).npu()
            rhs_scale = torch.randn((1, n), dtype=torch.float32).npu()
            output, gather_out = torch_npu.npu_all_gather_base_mm(
                lhs, rhs, "fake group info", world_size, x1_scale=lhs_scale, x2_scale=rhs_scale)
            gathered_rows = m * world_size
            self.assertEqual(output.shape, (gathered_rows, n))
            self.assertEqual(output.dtype, torch.bfloat16)
            # The gathered tensor keeps x1's (int8) dtype.
            self.assertEqual(gather_out.shape, (gathered_rows, k))
            self.assertEqual(gather_out.dtype, lhs.dtype)
class TestNpuDeepNorm(TestCase):
    """Fake-tensor checks for ``npu_deep_norm``: ``y`` mirrors ``x``; mean and rstd are (2, 1)."""

    def test_npu_deep_norm(self):
        with FakeTensorMode():
            x, gx = (torch.randn((2, 3), dtype=torch.float32).npu() for _ in range(2))
            beta, gamma = (torch.randn((3,), dtype=torch.float32).npu() for _ in range(2))
            mean, rstd, y = torch_npu.npu_deep_norm(x, gx, beta, gamma)
            self.assertEqual(y.dtype, x.dtype)
            self.assertEqual(y.shape, x.shape)
            self.assertEqual(y.device, x.device)
            # Both statistics are checked as one value per row of x, with a trailing singleton dim.
            for stat in (rstd, mean):
                self.assertEqual(stat.shape, torch.Size([2, 1]))
                self.assertEqual(stat.device, x.device)
class TestRmsNormQuant(TestCase):
    """Fake-tensor checks for ``npu_rms_norm_quant`` output shape and dtype."""

    def test_rms_norm_quant(self):
        # (input dtype, extra keyword arguments, expected output dtype)
        cases = (
            (torch.float16, {}, torch.int8),
            (torch.bfloat16, {}, torch.int8),
            # dst_dtype=291 is expected to produce float8_e5m2 output.
            (torch.float16, {"epsilon": 1e-06, "dst_dtype": 291}, torch.float8_e5m2),
        )
        with FakeTensorMode():
            for dtype, extra_kwargs, expected_dtype in cases:
                # subTest reports which case failed and keeps checking the others.
                with self.subTest(dtype=dtype, **extra_kwargs):
                    x = torch.randn([2, 16], dtype=dtype).npu()
                    gamma = torch.randn([16], dtype=dtype).npu()
                    beta = torch.randn([16], dtype=dtype).npu()
                    scale = torch.randn(1, dtype=dtype).npu()
                    offset = torch.randint(-10, 10, (1,), dtype=torch.int8).npu()
                    y = torch_npu.npu_rms_norm_quant(x, gamma, beta, scale, offset, **extra_kwargs)
                    # assertEqual shows both values on failure, unlike assertTrue(a == b).
                    self.assertEqual(y.shape, x.shape)
                    self.assertEqual(y.dtype, expected_dtype)
class TestRmsNormQuantV2(TestCase):
    """Fake-tensor checks for ``npu_rms_norm_quant_v2`` output shape and dtype."""

    def test_rms_norm_quant_v2(self):
        # (input dtype, gamma/beta shape, div_mode, dst_dtype, expected y dtype)
        cases = (
            (torch.float16, [16], True, 1, torch.int8),
            (torch.bfloat16, [16], True, 1, torch.int8),
            (torch.float16, [16], False, 291, torch.float8_e5m2),
            # gamma/beta with a leading singleton dimension.
            (torch.float16, [1, 16], False, 1, torch.int8),
        )
        with FakeTensorMode():
            for dtype, weight_shape, div_mode, dst_dtype, expected_dtype in cases:
                # subTest reports which case failed and keeps checking the others.
                with self.subTest(dtype=dtype, weight_shape=weight_shape,
                                  div_mode=div_mode, dst_dtype=dst_dtype):
                    x = torch.randn([2, 16], dtype=dtype).npu()
                    gamma = torch.randn(weight_shape, dtype=dtype).npu()
                    scale = torch.randn(1, dtype=dtype).npu()
                    offset = torch.randint(-10, 10, (1,), dtype=dtype).npu()
                    beta = torch.randn(weight_shape, dtype=dtype).npu()
                    y, _rstd = torch_npu.npu_rms_norm_quant_v2(
                        x, gamma, scale, offset=offset, beta=beta, epsilon=1e-06,
                        div_mode=div_mode, dst_dtype=dst_dtype)
                    # assertEqual shows both values on failure, unlike assertTrue(a == b).
                    self.assertEqual(y.shape, x.shape)
                    self.assertEqual(y.dtype, expected_dtype)
class TestRmsNormDynamicMxQuant(TestCase):
    """Fake-tensor checks for ``npu_rms_norm_dynamic_mx_quant`` output shapes and dtypes."""

    def test_npu_rms_norm_dynamic_mx_quant_meta(self):
        with FakeTensorMode():
            hidden = 64
            x = torch.randn([8, hidden], dtype=torch.float16, device='npu')
            gamma = torch.ones([hidden], dtype=torch.float16, device='npu')
            beta = torch.zeros([hidden], dtype=torch.float16, device='npu')
            y_npu, mxscale_npu, rstd_npu = torch_npu.npu_rms_norm_dynamic_mx_quant(
                x, gamma, beta=beta, epsilon=1e-6, scale_alg=0, round_mode="rint",
                dst_type=torch_npu.float8_e5m2,
            )
            expectations = (
                (y_npu, x.shape, torch.float8_e5m2),
                (mxscale_npu, torch.Size([8, 1, 2]), torch.uint8),
                (rstd_npu, torch.Size([8, 1]), torch.float32),
            )
            for tensor, expected_shape, expected_dtype in expectations:
                self.assertEqual(tensor.shape, expected_shape)
                self.assertEqual(tensor.dtype, expected_dtype)
class TestNpuRmsNorm(TestCase):
    """Fake-tensor checks for ``npu_rms_norm`` forward outputs and backward gradients."""

    def test_npu_rms_norm(self):
        with FakeTensorMode():
            x = torch.randn((2, 3), dtype=torch.float32).npu()
            gamma = torch.randn((3,), dtype=torch.float32).npu()
            y, rstd = torch_npu.npu_rms_norm(x, gamma)
            for attr in ("dtype", "shape", "device"):
                self.assertEqual(getattr(y, attr), getattr(x, attr))
            # rstd is checked as (2, 1): one value per row of x.
            self.assertEqual(rstd.shape, torch.Size([2, 1]))
            self.assertEqual(rstd.device, x.device)

    def test_npu_rms_norm_backward(self):
        with FakeTensorMode():
            x = torch.randn((2, 3), dtype=torch.float32).npu()
            gamma = torch.randn((3,), dtype=torch.float32).npu()
            x.requires_grad_()
            gamma.requires_grad_()
            y = torch_npu.npu_rms_norm(x, gamma)[0]
            y.backward(torch.randn((2, 3), dtype=torch.float32).npu())
            grad_x, grad_gamma = x.grad, gamma.grad
            self.assertEqual(grad_x.dtype, x.dtype)
            self.assertEqual(grad_x.shape, x.shape)
            self.assertEqual(grad_x.device, x.device)
            self.assertEqual(grad_gamma.shape, gamma.shape)
            self.assertEqual(grad_gamma.device, gamma.device)
class TestDistributeBarrier(TestCase):
    """Fake-tensor check that ``_npu_distribute_barrier`` returns a tensor shaped like ``x_ref``."""

    def test_npu_distribute_barrier(self):
        def int32_randn(*shape):
            return torch.randn(*shape).to(torch.int32)

        with FakeTensorMode():
            ep_world_size = 16
            x_ref = int32_randn(1)
            time_out = int32_randn(1)
            # elastic_info is built as (1, 4 + 2 * ep_world_size).
            elastic_info = int32_randn(1, 4 + 2 * ep_world_size)
            result = torch_npu._npu_distribute_barrier(
                x_ref=x_ref,
                group="group_ep",
                world_size=ep_world_size,
                time_out=time_out,
                elastic_info=elastic_info,
            )
            self.assertEqual(result.shape, x_ref.shape)
class TestQuantLightningIndexer(TestCase):
    """Quantized lightning-indexer comparison case (currently never executed; see note below)."""

    # NOTE(review): this method name does not start with "test", so unittest never
    # collects it and the body below never runs. Before enabling it, confirm that:
    #   * self.cpu_op_exec and self.npu_op_exec_eager exist — neither is defined in
    #     this class as far as the visible code shows;
    #   * self.assertRtolEqual exists — it does not appear to be a method of
    #     torch.testing._internal.common_utils.TestCase;
    #   * .equal() can produce a concrete result under FakeTensorMode — it is a
    #     data-dependent comparison and may raise on fake tensors.
    def quant_lightning_indexer_result(self):
        """Intended to compare ``self.cpu_op_exec`` output with ``self.npu_op_exec_eager`` output.

        Builds int8 query/key tensors (query layout ``"BSND"``, key layout
        ``'PA_BSND'``), fp16 dequant scales and weights, int32 sequence lengths and
        a block table, then asserts that the two outputs are exactly equal.
        """
        with FakeTensorMode():
            b = 1
            t = None
            s1 = 4
            s2 = 512
            act_seq_q = None
            act_seq_k = None
            sparse_mode = 0
            sparse_count = 2048
            n1 = 64
            n2 = 1
            d = 128
            block_size = 128
            layout_query = "BSND"
            layout_key = 'PA_BSND'
            query_quant_mode = 0
            key_quant_mode = 0
            np.random.seed(0)
            # Number of key blocks per batch: ceil(s2 / block_size).
            max_block_table_num = (s2 + block_size - 1) // block_size
            block_table = torch.tensor([range(b * max_block_table_num)], dtype = torch.int32).reshape(b, -1)
            key = torch.tensor(np.random.uniform(-128, 127, (b * max_block_table_num, block_size, n2, d))).to(torch.int8)
            key_dequant_scale = torch.tensor(np.random.uniform(0, 10, (b * max_block_table_num, block_size, n2)))
            key_dequant_scale = key_dequant_scale.to(torch.float16)
            if layout_query == 'BSND':
                query = torch.tensor(np.random.uniform(-128, 127, (b, s1, n1, d))).to(torch.int8)
                query_dequant_scale = torch.tensor(np.random.uniform(0, 10, (b, s1, n1))).to(torch.float16)
                weights = torch.tensor(np.random.uniform(0, 0.01, (b, s1, n1))).to(torch.float16)
                # uniform(s, s, ...) fills every batch entry with exactly s when no explicit lengths are given.
                actual_seq_lengths_query = torch.tensor(np.random.uniform(s1, s1, (b))).to(torch.int32) \
                    if act_seq_q is None else torch.tensor(act_seq_q).to(torch.int32)
                actual_seq_lengths_key = torch.tensor(np.random.uniform(s2, s2, (b))).to(torch.int32) \
                    if act_seq_k is None else torch.tensor(act_seq_k).to(torch.int32)
            else:
                # NOTE(review): unreachable while layout_query is hard-coded to "BSND";
                # with t = None and act_seq_q/act_seq_k = None these calls would fail.
                query = torch.tensor(np.random.uniform(-128, 127, (t, n1, d))).to(torch.int8)
                query_dequant_scale = torch.tensor(np.random.uniform(0, 10, (t, n1))).to(torch.float16)
                weights = torch.tensor(np.random.uniform(0, 0.01, (t, n1))).to(torch.float16)
                actual_seq_lengths_query = torch.tensor(act_seq_q).to(torch.int32)
                actual_seq_lengths_key = torch.tensor(act_seq_k).to(torch.int32)
            cpu_out = self.cpu_op_exec(query.cpu(), key.cpu(), weights.cpu(), query_dequant_scale.cpu(), key_dequant_scale.cpu(),
                                       actual_seq_lengths_query.cpu(), actual_seq_lengths_key.cpu(), block_table.cpu(),
                                       layout_query, sparse_count, sparse_mode)
            npu_eager_out = self.npu_op_exec_eager(query.npu(), key.npu(), weights.npu(), query_dequant_scale.npu(), key_dequant_scale.npu(),
                                                   actual_seq_lengths_query.npu(), actual_seq_lengths_key.npu(), block_table.npu(),
                                                   query_quant_mode, key_quant_mode,
                                                   layout_query, layout_key, sparse_count, sparse_mode)
            # Exact (bitwise) equality between the two results is required.
            res = npu_eager_out.equal(cpu_out)
            self.assertRtolEqual(res, True)
class TestNpuAddRmsNorm(TestCase):
    """Fake-tensor checks for ``npu_add_rms_norm``: ``y`` and ``x`` mirror ``x1``; rstd is (2, 1)."""

    def test_npu_add_rms_norm(self):
        with FakeTensorMode():
            x1, x2 = (torch.randn((2, 3), dtype=torch.float32).npu() for _ in range(2))
            gamma = torch.randn((3,), dtype=torch.float32).npu()
            y, rstd, x_out = torch_npu.npu_add_rms_norm(x1, x2, gamma)
            for attr in ("dtype", "shape", "device"):
                self.assertEqual(getattr(y, attr), getattr(x1, attr))
            self.assertEqual(rstd.shape, torch.Size([2, 1]))
            self.assertEqual(rstd.device, x1.device)
            for attr in ("dtype", "shape", "device"):
                self.assertEqual(getattr(x_out, attr), getattr(x1, attr))
class TestNpuAddRmsNormV2(TestCase):
    """Fake-tensor check for ``npu_add_rms_norm_v2``, whose single result is checked as an fp32 (2, 1) rstd."""

    def test_npu_add_rms_norm_v2(self):
        with FakeTensorMode():
            x1, x2 = (torch.randn((2, 3), dtype=torch.float32).npu() for _ in range(2))
            gamma = torch.randn((3,), dtype=torch.float32).npu()
            rstd = torch_npu.npu_add_rms_norm_v2(x1, x2, gamma)
            expected = (torch.Size([2, 1]), x1.device, torch.float32)
            self.assertEqual(rstd.shape, expected[0])
            self.assertEqual(rstd.device, expected[1])
            self.assertEqual(rstd.dtype, expected[2])
class TestNpuAddRmsNormV2Functional(TestCase):
    """Fake-tensor checks for ``npu_add_rms_norm_v2_functional`` (returns rstd, y, x in that order)."""

    def test_npu_add_rms_norm_v2_functional(self):
        with FakeTensorMode():
            x1, x2 = (torch.randn((2, 3), dtype=torch.float32).npu() for _ in range(2))
            gamma = torch.randn((3,), dtype=torch.float32).npu()
            rstd, y, x_out = torch_npu.npu_add_rms_norm_v2_functional(x1, x2, gamma)
            for attr in ("dtype", "shape", "device"):
                self.assertEqual(getattr(y, attr), getattr(x1, attr))
            self.assertEqual(rstd.shape, torch.Size([2, 1]))
            self.assertEqual(rstd.device, x1.device)
            for attr in ("dtype", "shape", "device"):
                self.assertEqual(getattr(x_out, attr), getattr(x1, attr))
class TestNpuFfnWorkerBatching(TestCase):
    """Fake-tensor shape checks for ``npu_ffn_worker_batching``."""

    def test_npu_ffn_worker_batching(self):
        with FakeTensorMode():
            x = torch.randn((1024,)).npu().to(torch.int8)
            expert_num = 2
            max_out_shape = [16, 8, 9, 7168]
            # Expected leading dim of most outputs: product of the first three max_out_shape entries.
            max_tokens = max_out_shape[0] * max_out_shape[1] * max_out_shape[2]
            (y, group_list, session_ids, micro_batch_ids, token_ids,
             expert_offsets, dynamic_scale, actual_token_num) = torch_npu.npu_ffn_worker_batching(
                x, expert_num, max_out_shape, token_dtype=2, need_schedule=1, layer_num=1)
            # assertEqual shows both values on failure, unlike assertTrue(a == b).
            self.assertEqual(y.shape[0], max_tokens)
            self.assertEqual(y.shape[1], max_out_shape[3])
            self.assertEqual(group_list.shape[0], expert_num)
            self.assertEqual(group_list.shape[1], 2)
            for out in (session_ids, micro_batch_ids, token_ids, expert_offsets, dynamic_scale):
                self.assertEqual(out.shape[0], max_tokens)
            self.assertEqual(actual_token_num.shape[0], 1)
class TestFFN(TestCase):
    """Fake-tensor check that ``npu_ffn`` returns a tensor shaped like its input."""

    def test_npu_ffn_meta(self):
        with FakeTensorMode():
            x = torch.randn(1, 320, dtype=torch.float16).npu()
            w1 = torch.randn(320, 2560, dtype=torch.float16).npu()
            w2 = torch.randn(2560, 320, dtype=torch.float16).npu()
            activation = "gelu"
            res = torch_npu.npu_ffn(x, w1, w2, activation, inner_precise=1, output_dtype=torch.float16)
            # assertEqual shows both shapes on failure, unlike assertTrue(a == b).
            self.assertEqual(res.shape, x.shape)
class TestNpuDynamicQuant(TestCase):
    """Fake-tensor checks for ``npu_dynamic_quant`` called with ``smooth_scales``."""

    def test_npu_dynamic_quant(self):
        def on_npu(shape, dtype):
            return torch.randn(shape).npu().to(dtype)

        with FakeTensorMode():
            input_npu = on_npu((4, 2048, 1024), torch.float16)
            smooth_scales_npu = on_npu(1024, torch.float16)
            # Reference tensors carrying the expected dtype/shape/device of each output.
            expected_output = on_npu((4, 2048, 1024), torch.int8)
            expected_scale = on_npu((4, 2048), torch.float32)
            actual_output, actual_scale = torch_npu.npu_dynamic_quant(input_npu, smooth_scales=smooth_scales_npu)
            for actual, expected in ((actual_output, expected_output), (actual_scale, expected_scale)):
                self.assertEqual(actual.dtype, expected.dtype)
                self.assertEqual(actual.shape, expected.shape)
                self.assertEqual(actual.device, expected.device)
class TestDynamicQuantAsymmetric(TestCase):
    """Fake-tensor checks for ``npu_dynamic_quant_asymmetric`` called with ``smooth_scales``."""

    def test_npu_dynamic_quant_asymmetric(self):
        def on_npu(shape, dtype):
            return torch.randn(shape).npu().to(dtype)

        with FakeTensorMode():
            input_npu = on_npu((4, 2048, 1024), torch.float16)
            smooth_scales_npu = on_npu(1024, torch.float16)
            # Reference tensors carrying the expected dtype/shape/device of each output.
            expected_output = on_npu((4, 2048, 1024), torch.int8)
            expected_scale = on_npu((4, 2048), torch.float32)
            expected_offset = on_npu((4, 2048), torch.float32)
            actual_output, actual_scale, actual_offset = torch_npu.npu_dynamic_quant_asymmetric(
                input_npu, smooth_scales=smooth_scales_npu)
            pairs = (
                (actual_output, expected_output),
                (actual_scale, expected_scale),
                (actual_offset, expected_offset),
            )
            for actual, expected in pairs:
                self.assertEqual(actual.dtype, expected.dtype)
                self.assertEqual(actual.shape, expected.shape)
                self.assertEqual(actual.device, expected.device)
class TestGroupedMatmul(TestCase):
def test_npu_grouped_matmul_meta_0(self):
    """split_item=0, group_type=-1, no group_list: output i is (rows of x[i], cols of w[i])."""
    with FakeTensorMode():
        torch.manual_seed(0)
        x = [torch.randn(256, 256, dtype=torch.float16).npu(),
             torch.randn(1024, 256, dtype=torch.float16).npu(),
             torch.randn(512, 1024, dtype=torch.float16).npu()]
        w = [torch.randn(256, 256, dtype=torch.float16).npu(),
             torch.randn(256, 1024, dtype=torch.float16).npu(),
             torch.randn(1024, 128, dtype=torch.float16).npu()]
        b = [torch.randn(256, dtype=torch.float16).npu(),
             torch.randn(1024, dtype=torch.float16).npu(),
             torch.randn(128, dtype=torch.float16).npu()]
        res = torch_npu.npu_grouped_matmul(x, w, bias=b, group_list=None, split_item=0, group_type=-1)
        # Pin the output count so zip() below cannot silently skip a group.
        self.assertEqual(len(res), len(w))
        for xi, wi, ri in zip(x, w, res):
            self.assertEqual(ri.shape[0], xi.shape[0])
            self.assertEqual(ri.shape[1], wi.shape[1])
def test_npu_grouped_matmul_meta_1(self):
    """split_item=1, group_type=-1: x[0] is split by group_list into one output per weight."""
    with FakeTensorMode():
        torch.manual_seed(0)
        x = [torch.randn(1792, 1024, dtype=torch.float16).npu()]
        w = [torch.randn(1024, 256, dtype=torch.float16).npu(),
             torch.randn(1024, 1024, dtype=torch.float16).npu(),
             torch.randn(1024, 128, dtype=torch.float16).npu()]
        b = [torch.randn(256, dtype=torch.float16).npu(),
             torch.randn(1024, dtype=torch.float16).npu(),
             torch.randn(128, dtype=torch.float16).npu()]
        group_list = [256, 1280, 1792]
        res = torch_npu.npu_grouped_matmul(x, w, bias=b, group_list=group_list, split_item=1, group_type=-1)
        self.assertEqual(len(res), len(w))
        # group_list holds cumulative end offsets: group i has group_list[i] - group_list[i - 1] rows.
        starts = [0] + group_list[:-1]
        for start, end, wi, ri in zip(starts, group_list, w, res):
            self.assertEqual(ri.shape[0], end - start)
            self.assertEqual(ri.shape[1], wi.shape[1])
def test_npu_grouped_matmul_meta_2(self):
    """split_item=2, group_type=0: output 0 stacks all x rows and has w[0]'s column count."""
    with FakeTensorMode():
        torch.manual_seed(0)
        x = [torch.randn(256, 256, dtype=torch.float16).npu(),
             torch.randn(1024, 256, dtype=torch.float16).npu(),
             torch.randn(512, 1024, dtype=torch.float16).npu()]
        w = [torch.randn(256, 256, dtype=torch.float16).npu(),
             torch.randn(256, 256, dtype=torch.float16).npu(),
             torch.randn(1024, 256, dtype=torch.float16).npu()]
        b = [torch.randn(256, dtype=torch.float16).npu(),
             torch.randn(256, dtype=torch.float16).npu(),
             torch.randn(256, dtype=torch.float16).npu()]
        res = torch_npu.npu_grouped_matmul(x, w, bias=b, group_list=None, split_item=2, group_type=0)
        self.assertEqual(res[0].shape[0], sum(xi.shape[0] for xi in x))
        self.assertEqual(res[0].shape[1], w[0].shape[1])
def test_npu_grouped_matmul_meta_3(self):
    """split_item=3, group_type=0, list weights: output 0 keeps x[0]'s rows and w[0]'s columns."""
    with FakeTensorMode():
        torch.manual_seed(0)
        x = [torch.randn(1792, 1024, dtype=torch.float16).npu()]
        w = [torch.randn(1024, 256, dtype=torch.float16).npu(),
             torch.randn(1024, 256, dtype=torch.float16).npu(),
             torch.randn(1024, 256, dtype=torch.float16).npu()]
        b = [torch.randn(256, dtype=torch.float16).npu(),
             torch.randn(256, dtype=torch.float16).npu(),
             torch.randn(256, dtype=torch.float16).npu()]
        group_list = [256, 1280, 1792]
        res = torch_npu.npu_grouped_matmul(x, w, bias=b, group_list=group_list, split_item=3, group_type=0)
        self.assertEqual(res[0].shape[0], x[0].shape[0])
        self.assertEqual(res[0].shape[1], w[0].shape[1])
def test_npu_grouped_matmul_meta_4(self):
    """split_item=3, group_type=0, one stacked 3-D weight: columns come from w[0]'s last dim."""
    with FakeTensorMode():
        torch.manual_seed(0)
        x = [torch.randn(1792, 1024, dtype=torch.float16).npu()]
        w = [torch.randn(3, 1024, 256, dtype=torch.float16).npu()]
        b = [torch.randn(256, dtype=torch.float16).npu(),
             torch.randn(256, dtype=torch.float16).npu(),
             torch.randn(256, dtype=torch.float16).npu()]
        group_list = [256, 1280, 1792]
        res = torch_npu.npu_grouped_matmul(x, w, bias=b, group_list=group_list, split_item=3, group_type=0)
        self.assertEqual(res[0].shape[0], x[0].shape[0])
        self.assertEqual(res[0].shape[1], w[0].shape[2])
def test_npu_grouped_matmul_meta_5(self):
    """int8 x with int32 w, split_item=0: output i is (rows of x[i], 8 * cols of w[i])."""
    with FakeTensorMode():
        torch.manual_seed(0)
        x = [torch.randint(0, 10, (256, 256), dtype=torch.int8).npu(),
             torch.randint(0, 10, (1024, 256), dtype=torch.int8).npu(),
             torch.randint(0, 10, (512, 1024), dtype=torch.int8).npu()]
        w = [torch.randint(0, 10, (256, 256), dtype=torch.int32).npu(),
             torch.randint(0, 10, (256, 1024), dtype=torch.int32).npu(),
             torch.randint(0, 10, (1024, 128), dtype=torch.int32).npu()]
        b = [torch.randn(256, dtype=torch.float16).npu(),
             torch.randn(1024, dtype=torch.float16).npu(),
             torch.randn(128, dtype=torch.float16).npu()]
        res = torch_npu.npu_grouped_matmul(x, w, bias=b, group_list=None, split_item=0, group_type=-1)
        self.assertEqual(len(res), len(w))
        for xi, wi, ri in zip(x, w, res):
            self.assertEqual(ri.shape[0], xi.shape[0])
            # Output width is 8x the int32 weight's columns (presumably 8 packed int4 values each — confirm).
            self.assertEqual(ri.shape[1], wi.shape[1] * 8)
def test_npu_grouped_matmul_meta_6(self):
    """int8 x with int32 w, split_item=1, group_type=0: x[0] is split by group_list, 8x-wide outputs."""
    with FakeTensorMode():
        torch.manual_seed(0)
        x = [torch.randint(0, 10, (1792, 1024), dtype=torch.int8).npu()]
        w = [torch.randint(0, 10, (1024, 256), dtype=torch.int32).npu(),
             torch.randint(0, 10, (1024, 1024), dtype=torch.int32).npu(),
             torch.randint(0, 10, (1024, 128), dtype=torch.int32).npu()]
        b = [torch.randn(256, dtype=torch.float16).npu(),
             torch.randn(1024, dtype=torch.float16).npu(),
             torch.randn(128, dtype=torch.float16).npu()]
        group_list = [256, 1280, 1792]
        res = torch_npu.npu_grouped_matmul(x, w, bias=b, group_list=group_list, split_item=1, group_type=0)
        self.assertEqual(len(res), len(w))
        # group_list holds cumulative end offsets: group i has group_list[i] - group_list[i - 1] rows.
        starts = [0] + group_list[:-1]
        for start, end, wi, ri in zip(starts, group_list, w, res):
            self.assertEqual(ri.shape[0], end - start)
            self.assertEqual(ri.shape[1], wi.shape[1] * 8)
def test_npu_grouped_matmul_meta_7(self):
    """int8 x with int32 w, split_item=2: output 0 stacks all x rows and is 8x w[0]'s columns wide."""
    with FakeTensorMode():
        torch.manual_seed(0)
        x = [torch.randint(0, 10, (256, 256), dtype=torch.int8).npu(),
             torch.randint(0, 10, (1024, 256), dtype=torch.int8).npu(),
             torch.randint(0, 10, (512, 1024), dtype=torch.int8).npu()]
        w = [torch.randint(0, 10, (256, 256), dtype=torch.int32).npu(),
             torch.randint(0, 10, (256, 1024), dtype=torch.int32).npu(),
             torch.randint(0, 10, (1024, 128), dtype=torch.int32).npu()]
        b = [torch.randn(256, dtype=torch.float16).npu(),
             torch.randn(256, dtype=torch.float16).npu(),
             torch.randn(256, dtype=torch.float16).npu()]
        res = torch_npu.npu_grouped_matmul(x, w, bias=b, group_list=None, split_item=2, group_type=0)
        self.assertEqual(res[0].shape[0], sum(xi.shape[0] for xi in x))
        self.assertEqual(res[0].shape[1], w[0].shape[1] * 8)
def test_npu_grouped_matmul_meta_8(self):
    """One int8 activation, three int32 weights, split_item=3.

    The first output keeps every activation row and has ``8 * weight.shape[1]`` columns.
    """
    with FakeTensorMode():
        torch.manual_seed(0)
        activations = [torch.randint(0, 10, (1792, 1024), dtype=torch.int8).npu()]
        weights = [torch.randint(0, 10, (1024, 256), dtype=torch.int32).npu() for _ in range(3)]
        biases = [torch.randn(256, dtype=torch.float16).npu() for _ in range(3)]
        out = torch_npu.npu_grouped_matmul(activations, weights, bias=biases, group_list=[256, 1280, 1792],
                                           split_item=3, group_type=0)
        self.assertEqual(activations[0].shape[0], out[0].shape[0])
        self.assertEqual(weights[0].shape[1] * 8, out[0].shape[1])
def test_npu_grouped_matmul_meta_9(self):
    """int8 x int8 grouped matmul with a stacked (3, K, N) weight, int64 scale and int32 output."""
    with FakeTensorMode():
        torch.manual_seed(0)
        x_in = torch.randint(2, 3, size=(1792, 1024), dtype=torch.int8).npu()
        w_in = torch.randint(2, 3, size=(3, 1024, 256), dtype=torch.int8).npu()
        scale_in = torch.randint(2, 3, size=(3, 256), dtype=torch.int64).npu()
        groups = torch.tensor([256, 1280, 1792]).to(torch.int64).npu()
        res = torch_npu.npu_grouped_matmul([x_in], [w_in], bias=None, scale=[scale_in], group_list=groups,
                                           split_item=3, group_type=0, output_dtype=torch.int32)
        # One output: all activation rows, and the weight's last dimension as columns.
        self.assertEqual(x_in.shape[0], res[0].shape[0])
        self.assertEqual(w_in.shape[2], res[0].shape[1])
def test_npu_grouped_matmul_meta_10(self):
    """float32 grouped matmul: one activation, three 2-D weights, python-list group_list, split_item=3."""
    with FakeTensorMode():
        torch.manual_seed(0)
        features_in, features_out = 1024, 256
        x = [torch.randn(1792, features_in, dtype=torch.float32).npu()]
        w = [torch.randn(features_in, features_out, dtype=torch.float32).npu() for _ in range(3)]
        b = [torch.randn(features_out, dtype=torch.float32).npu() for _ in range(3)]
        res = torch_npu.npu_grouped_matmul(x, w, bias=b, group_list=[256, 1280, 1792], split_item=3, group_type=0)
        # The single output keeps every activation row and the weights' last dimension.
        self.assertEqual(x[0].shape[0], res[0].shape[0])
        self.assertEqual(w[0].shape[1], res[0].shape[1])
def test_npu_grouped_matmul_meta_11(self):
    """float32 grouped matmul with a single stacked (3, K, N) weight and split_item=3."""
    with FakeTensorMode():
        torch.manual_seed(0)
        act = torch.randn(1792, 1024, dtype=torch.float32).npu()
        stacked_weight = torch.randn(3, 1024, 256, dtype=torch.float32).npu()
        biases = [torch.randn(256, dtype=torch.float32).npu() for _ in range(3)]
        out = torch_npu.npu_grouped_matmul([act], [stacked_weight], bias=biases, group_list=[256, 1280, 1792],
                                           split_item=3, group_type=0)
        self.assertEqual(act.shape[0], out[0].shape[0])
        self.assertEqual(stacked_weight.shape[2], out[0].shape[1])
def test_npu_grouped_matmul_meta_12(self):
    """float32 grouped matmul with three 2-D weights and group_list passed as an int64 tensor."""
    with FakeTensorMode():
        torch.manual_seed(0)
        inputs = [torch.randn(1792, 1024, dtype=torch.float32).npu()]
        weights = [torch.randn(1024, 256, dtype=torch.float32).npu() for _ in range(3)]
        biases = [torch.randn(256, dtype=torch.float32).npu() for _ in range(3)]
        group_list = torch.tensor([256, 1280, 1792]).to(torch.int64).npu()
        outputs = torch_npu.npu_grouped_matmul(inputs, weights, bias=biases, group_list=group_list,
                                               split_item=3, group_type=0)
        self.assertEqual(inputs[0].shape[0], outputs[0].shape[0])
        self.assertEqual(weights[0].shape[1], outputs[0].shape[1])
def test_npu_grouped_matmul_meta_950_fp8_1(self):
    """float8_e4m3fn grouped matmul with a single group and float16 output."""
    with FakeTensorMode():
        torch.manual_seed(0)
        m, k, n = 16, 256, 32
        # Build int8 data and reinterpret the bytes as float8_e4m3fn.
        x_fp8 = torch.randint(2, 3, size=(m, k), dtype=torch.int8).view(torch.float8_e4m3fn).npu()
        w_fp8 = torch.randint(2, 3, size=(1, k, n), dtype=torch.int8).view(torch.float8_e4m3fn).npu()
        group_list = torch.tensor([m]).to(torch.int64).npu()
        weight_scale = torch.randint(2, 3, size=(1, n), dtype=torch.int64).npu()
        res = torch_npu.npu_grouped_matmul([x_fp8], [w_fp8], bias=None, scale=[weight_scale], group_list=group_list,
                                           split_item=2, group_type=0, output_dtype=torch.float16)
        self.assertEqual(x_fp8.shape[0], res[0].shape[0])
        self.assertEqual(w_fp8.shape[2], res[0].shape[1])
def test_npu_grouped_matmul_meta_950_fp8_group_type_2(self):
    """float8 grouped matmul with group_type=2; the group sizes (32 + 32 + 64) sum to K = 128.

    The result is expected to be float16 with shape (3, M, N): one leading slot per group.
    """
    with FakeTensorMode():
        torch.manual_seed(0)
        k, m, n = 128, 4096, 7168
        x_km = torch.randint(2, 3, size=(k, m), dtype=torch.int8).view(torch.float8_e4m3fn).npu()
        w_kn = torch.randint(2, 3, size=(k, n), dtype=torch.int8).view(torch.float8_e4m3fn).npu()
        group_list = torch.tensor([32, 32, 64]).to(torch.int64).npu()
        # 56 == n // 128 — presumably one scale per 128-column block; confirm.
        weight_scale = torch.randn(3, 56, dtype=torch.float32).npu()
        token_scale = torch.randn(3, m, dtype=torch.float32).npu()
        # The activation and the per-token scale are passed as transposed views.
        res = torch_npu.npu_grouped_matmul([x_km.t()], [w_kn], bias=None, scale=[weight_scale],
                                           per_token_scale=[token_scale.t()], group_list=group_list,
                                           split_item=2, group_type=2, group_list_type=0,
                                           output_dtype=torch.float16)
        self.assertEqual(res[0].dtype, torch.float16)
        self.assertEqual(tuple(res[0].shape), (3, x_km.shape[1], w_kn.shape[1]))
def test_npu_grouped_matmul_meta_weight_float32_transpose(self):
    """Grouped matmul with transposed views of a float32 weight and of uint8 antiquant scales."""
    with FakeTensorMode():
        torch.manual_seed(0)
        K, M, N, E = 1024, 16, 256, 2
        group_size = 32
        x1 = torch.randint(0, 10, (M, K), dtype=torch.int8).view(torch.float8_e4m3fn).npu()
        w1 = torch.randint(0, 10, (E, N, K), dtype=torch.float32).npu()
        antiquant_scale_raw = torch.randint(0, 10, (E, N, K // group_size), dtype=torch.uint8).npu()
        per_token_scale_raw = torch.randint(0, 10, (M, K // group_size), dtype=torch.uint8).npu()
        # group_list [8, 8] sums to M; with group_list_type=1 these appear to be per-group counts.
        res = torch_npu.npu_grouped_matmul(
            [x1], [w1.transpose(-1, -2)],
            antiquant_scale=[antiquant_scale_raw.transpose(-1, -2)],
            per_token_scale=[per_token_scale_raw],
            group_list_type=1, group_list=[8, 8], split_item=3, group_type=0)
        self.assertEqual(x1.shape[0], res[0].shape[0])
        # w1 is (E, N, K) before the transpose, so its dim 1 is N.
        self.assertEqual(w1.shape[1], res[0].shape[1])
def test_npu_grouped_matmul_meta_quant_empty_tensor_m0(self):
    """With zero activation rows (M == 0) the first output must have length 0."""
    with FakeTensorMode():
        torch.manual_seed(0)
        k, m, n, e = 128, 0, 128, 2
        x_fp8 = torch.randint(0, 10, (m, k), dtype=torch.int8).view(torch.float8_e4m3fn).npu()
        w_fp8 = torch.randint(0, 10, (e, n, k), dtype=torch.int8).view(torch.float8_e5m2).npu()
        group_list = torch.tensor([0, 0]).to(torch.int64).npu()
        weight_scale = torch.randint(0, 10, (e, n), dtype=torch.int64).npu()
        res = torch_npu.npu_grouped_matmul([x_fp8], [w_fp8], bias=None, scale=[weight_scale],
                                           group_list=group_list, split_item=2, group_type=0,
                                           group_list_type=0, output_dtype=torch.float16)
        self.assertEqual(len(res[0]), 0)
def test_npu_grouped_matmul_meta_quant_empty_tensor_k0(self):
    """An empty reduction dimension (K == 0) must be rejected with RuntimeError."""
    with FakeTensorMode():
        torch.manual_seed(0)
        k, m, n, e = 0, 128, 128, 2
        x_fp8 = torch.randint(0, 10, (m, k), dtype=torch.int8).view(torch.float8_e4m3fn).npu()
        w_fp8 = torch.randint(0, 10, (e, k, n), dtype=torch.int8).view(torch.float8_e5m2).npu()
        group_list = torch.tensor([0, 0]).to(torch.int64).npu()
        weight_scale = torch.randint(0, 10, (e, n), dtype=torch.int64).npu()
        call_kwargs = dict(bias=None, scale=[weight_scale], group_list=group_list, split_item=2,
                           group_type=0, group_list_type=0, output_dtype=torch.float16)
        with self.assertRaises(RuntimeError):
            torch_npu.npu_grouped_matmul([x_fp8], [w_fp8], **call_kwargs)
class TestQuantMatmul(TestCase):
    def test_npu_quant_matmul_meta(self):
        """Check ``npu_quant_matmul`` shape/dtype inference across several quantization setups.

        Everything runs under FakeTensorMode, so only metadata is compared: a reference
        fake tensor is built with the expected shape/dtype and matched against the result.
        """

        def assert_meta_equal(reference, result):
            self.assertEqual(reference.shape, result.shape)
            self.assertEqual(reference.dtype, result.dtype)

        with FakeTensorMode():
            # Case 1: int8 x int8, float32 scale/offset, int32 bias, no output_dtype (int8 result).
            a = torch.randint(-1, 1, (1, 1, 1024), dtype=torch.int8).npu()
            b = torch.randint(-1, 1, (1, 1024, 100), dtype=torch.int8).npu()
            reference = torch.randint(-1, 1, (1, 1, 100), dtype=torch.int8).npu()
            scale_fp32 = torch.randn(1, dtype=torch.float32).npu()
            offset_fp32 = torch.randn(1, dtype=torch.float32).npu()
            bias = torch.randint(-1, 1, (1, 1, 100), dtype=torch.int32).npu()
            assert_meta_equal(reference, torch_npu.npu_quant_matmul(a, b, scale_fp32, offset=offset_fp32, bias=bias))

            # Case 2: bfloat16 scale and bias, bfloat16 output.
            reference = torch.randint(-1, 1, (1, 1, 100), dtype=torch.bfloat16).npu()
            scale_bf16 = torch.randn(1, dtype=torch.bfloat16).npu()
            bias = torch.randint(-1, 1, (1, 1, 100), dtype=torch.bfloat16).npu()
            result = torch_npu.npu_quant_matmul(a, b, scale_bf16, offset=None, bias=bias, output_dtype=torch.bfloat16)
            assert_meta_equal(reference, result)

            # Case 3: per-token scale with float32 bias, float16 output.
            reference = torch.randint(-1, 1, (1, 1, 100), dtype=torch.float16).npu()
            bias = torch.randint(-1, 1, (1, 1, 100), dtype=torch.float32).npu()
            pertoken = torch.randn(1, dtype=torch.float32).npu()
            result = torch_npu.npu_quant_matmul(a, b, scale_fp32, offset=None, pertoken_scale=pertoken,
                                                bias=bias, output_dtype=torch.float16)
            assert_meta_equal(reference, result)

            # Case 4: int32 bias, int32 output.
            reference = torch.randint(-1, 1, (1, 1, 100), dtype=torch.int32).npu()
            bias = torch.randint(-1, 1, (1, 1, 100), dtype=torch.int32).npu()
            result = torch_npu.npu_quant_matmul(a, b, scale_fp32, offset=None, pertoken_scale=None,
                                                bias=bias, output_dtype=torch.int32)
            assert_meta_equal(reference, result)

            # Case 5: int32 operands; the expected (16, 40) output has 8x the columns of the (64, 5) weight.
            a = torch.randint(-1, 1, (16, 8), dtype=torch.int32).npu()
            b = torch.randint(-1, 1, (64, 5), dtype=torch.int32).npu()
            reference = torch.randint(-1, 1, (16, 40), dtype=torch.float16).npu()
            scale_fp32 = torch.randn(1, dtype=torch.float32).npu()
            bias = torch.randint(-1, 1, (40,), dtype=torch.int32).npu()
            result = torch_npu.npu_quant_matmul(a, b, scale_fp32, offset=None, bias=bias, output_dtype=torch.float16)
            assert_meta_equal(reference, result)

            # Case 6: same int32 layout with a per-token scale, bfloat16 output.
            a = torch.randint(-1, 1, (16, 8), dtype=torch.int32).npu()
            b = torch.randint(-1, 1, (64, 5), dtype=torch.int32).npu()
            reference = torch.randint(-1, 1, (16, 40), dtype=torch.bfloat16).npu()
            pertoken = torch.randn(16, dtype=torch.float32).npu()
            scale_fp32 = torch.randn(1, dtype=torch.float32).npu()
            bias = torch.randint(-1, 1, (40,), dtype=torch.int32).npu()
            result = torch_npu.npu_quant_matmul(a, b, scale_fp32, offset=None, pertoken_scale=pertoken,
                                                bias=bias, output_dtype=torch.bfloat16)
            assert_meta_equal(reference, result)

            # Case 7: int8 activation, int32 weight, grouped int64 weight scales (group_sizes=[0, 0, 256]).
            a = torch.randint(-8, 8, (1, 8192), dtype=torch.int8).npu()
            b = torch.randint(-8, 8, (8192, 128), dtype=torch.int32).npu()
            reference = torch.randn([1, 128 * 8], dtype=torch.float16).npu()
            y_offset = torch.randn([128 * 8, ], dtype=torch.float32).npu()
            x1_scale = torch.randn([1, 1], dtype=torch.float32).npu()
            x2_scale = torch.randint(1, 3, [8192 // 256, 128 * 8], dtype=torch.int64).npu()
            result = torch_npu.npu_quant_matmul(a, b, x2_scale, offset=y_offset, pertoken_scale=x1_scale,
                                                bias=None, output_dtype=torch.float16, group_sizes=[0, 0, 256])
            assert_meta_equal(reference, result)

            # Case 8: float8 activation, float32 weight, uint8 scale and int64 y_scale, bfloat16 output.
            a = torch.randint(-1, 1, (1, 64), dtype=torch.float8_e4m3fn).npu()
            b = torch.randn((64, 16), dtype=torch.float32).npu()
            scale_u8 = torch.randint(-1, 1, (2, 128), dtype=torch.uint8).npu()
            y_scale = torch.randint(-1, 1, (1, 128), dtype=torch.int64).npu()
            reference = torch.randn((1, 128), dtype=torch.bfloat16).npu()
            result = torch_npu.npu_quant_matmul(a, b, scale_u8, pertoken_scale=None, scale_dtype=None,
                                                pertoken_scale_dtype=None, y_scale=y_scale,
                                                output_dtype=torch.bfloat16, group_sizes=[0, 0, 32])
            assert_meta_equal(reference, result)

            # Case 9: hifloat8 inputs and output; only the shape is checked.
            a = torch.randint(-8, 8, (1, 8192), dtype=torch.int8).npu()
            b = torch.randint(-8, 8, (8192, 128), dtype=torch.int8).npu()
            reference = torch.randn([1, 128], dtype=torch.float16).npu()
            scale_fp32 = torch.randn(1, dtype=torch.float32).npu()
            result = torch_npu.npu_quant_matmul(a, b, scale_fp32, offset=None, pertoken_scale=None,
                                                bias=None, output_dtype=torch_npu.hifloat8,
                                                x1_dtype=torch_npu.hifloat8, x2_dtype=torch_npu.hifloat8)
            self.assertEqual(reference.shape, result.shape)
class TestQuantMatmulReduceSum(TestCase):
    def test_quant_matmul_reduce_sum(self):
        """For (b, m, k) x (b, k, n) inputs the output's first two dims are (m, n) and its dtype is bfloat16."""
        batch, rows, cols, depth = 2, 3, 4, 5
        with FakeTensorMode():
            lhs = torch.randint(-1, 1, (batch, rows, depth), dtype=torch.int8).npu()
            rhs = torch.randint(-1, 1, (batch, depth, cols), dtype=torch.int8).npu()
            lhs_scale = torch.ones((batch, rows), dtype=torch.float32).npu()
            rhs_scale = torch.ones((cols,), dtype=torch.bfloat16).npu()
            out = torch_npu.npu_quant_matmul_reduce_sum(lhs, rhs, x1_scale=lhs_scale, x2_scale=rhs_scale)
            self.assertEqual(out.shape[0], rows)
            self.assertEqual(out.shape[1], cols)
            self.assertEqual(out.dtype, torch.bfloat16)
class TestQuantMatmulGelu(TestCase):
    @unittest.skip("Skipping due to outdated CANN version; please update CANN to the latest version and remove this skip")
    def test_npu_quant_matmul_gelu_meta(self):
        """Check ``npu_quant_matmul_gelu`` output shape/dtype for several operand layouts.

        Covers 2-D int8 operands (tanh and erf approximations, with and without bias),
        a bfloat16 ``x2_scale``, 3-D batched int8 operands, and int32 operands whose
        packed dims are ``k // 8`` and ``n // 8``.
        """

        def assert_meta(expected, actual):
            self.assertEqual(expected.shape, actual.shape)
            self.assertEqual(expected.dtype, actual.dtype)

        with FakeTensorMode():
            m, k, n = 128, 256, 512
            lhs = torch.randint(-5, 5, (m, k), dtype=torch.int8).npu()
            rhs = torch.randint(-5, 5, (k, n), dtype=torch.int8).npu()
            lhs_scale = torch.randn(m, dtype=torch.float32).abs() * 0.01
            rhs_scale = torch.randn(n, dtype=torch.float32).abs() * 0.01
            expected_fp16 = torch.randn((m, n), dtype=torch.float16).npu()
            out = torch_npu.npu_quant_matmul_gelu(lhs, rhs, lhs_scale.npu(), rhs_scale.npu(), approximate="gelu_tanh")
            assert_meta(expected_fp16, out)

            bias = torch.randn(n, dtype=torch.float32) * 0.1
            out = torch_npu.npu_quant_matmul_gelu(lhs, rhs, lhs_scale.npu(), rhs_scale.npu(), bias=bias.npu(),
                                                  approximate="gelu_erf")
            assert_meta(expected_fp16, out)

            # A bfloat16 x2_scale is expected to produce a bfloat16 output.
            rhs_scale_bf16 = torch.randn(n, dtype=torch.bfloat16).abs() * 0.01
            expected_bf16 = torch.randn((m, n), dtype=torch.bfloat16).npu()
            out = torch_npu.npu_quant_matmul_gelu(lhs, rhs, lhs_scale.npu(), rhs_scale_bf16.npu(),
                                                  approximate="gelu_tanh")
            assert_meta(expected_bf16, out)

            # Batched 3-D operands with the default approximation.
            batch, m, k, n = 4, 64, 128, 256
            lhs = torch.randint(-5, 5, (batch, m, k), dtype=torch.int8).npu()
            rhs = torch.randint(-5, 5, (batch, k, n), dtype=torch.int8).npu()
            lhs_scale = torch.randn(m, dtype=torch.float32).abs() * 0.01
            rhs_scale = torch.randn(n, dtype=torch.float32).abs() * 0.01
            expected = torch.randn((batch, m, n), dtype=torch.float16).npu()
            assert_meta(expected, torch_npu.npu_quant_matmul_gelu(lhs, rhs, lhs_scale.npu(), rhs_scale.npu()))

            # int32 operands with packed dims k // 8 and n // 8; the output still has n columns.
            m, k, n = 64, 128, 256
            lhs = torch.randint(-8, 8, (m, k // 8), dtype=torch.int32).npu()
            rhs = torch.randint(-8, 8, (k // 8, n // 8), dtype=torch.int32).npu()
            lhs_scale = torch.randn(m, dtype=torch.float32).abs() * 0.01
            rhs_scale = torch.randn(n, dtype=torch.float32).abs() * 0.01
            expected = torch.randn((m, n), dtype=torch.float16).npu()
            out = torch_npu.npu_quant_matmul_gelu(lhs, rhs, lhs_scale.npu(), rhs_scale.npu(), approximate="gelu_tanh")
            assert_meta(expected, out)
            bias_int32 = torch.randint(-5, 5, (n,), dtype=torch.int32)
            out = torch_npu.npu_quant_matmul_gelu(lhs, rhs, lhs_scale.npu(), rhs_scale.npu(),
                                                  bias=bias_int32.npu(), approximate="gelu_erf")
            assert_meta(expected, out)
class TestMatmulCompressDequant(TestCase):
    @unittest.skip("Skip until CANN supports aclnnMatmulCompressDequant; do not execute")
    def test_npu_matmul_compress_dequant_meta(self):
        """npu_matmul_compress_dequant is expected to return an (m, n) float16 tensor."""
        m, k, n = 16, 256, 128
        with FakeTensorMode():
            activations = torch.ones((m, k), dtype=torch.int8).npu()
            # The compressed weight is a flat 1-D int8 buffer of k * n - 1 elements.
            compressed_weight = torch.zeros((k * n - 1,), dtype=torch.int8).npu()
            index = torch.zeros(8, dtype=torch.int8).npu()
            bias = torch.zeros((1, n), dtype=torch.int32).npu()
            float_scale = torch.ones((1, n), dtype=torch.float32).npu()
            dequant_scale = torch_npu.npu_trans_quant_param(float_scale).to(torch.uint64)
            result = torch_npu.npu_matmul_compress_dequant(activations, compressed_weight, index, bias, dequant_scale)
            self.assertEqual(result.shape, (m, n))
            self.assertEqual(result.dtype, torch.float16)
class TestRecurrentGatedDeltaRule(TestCase):
    def test_recurrent_gated_delta_rule(self):
        """Shape/dtype inference for npu_recurrent_gated_delta_rule and its functional variant.

        Each call receives its own clone of ``state``, so the original ``state`` is left
        untouched whether or not an op writes to the state it is given. Previously a
        second clone was created but discarded, and ``state`` itself was passed to the
        functional variant.
        """
        with FakeTensorMode():
            (b, mtp, nk, nv, dk, dv) = (64, 2, 4, 8, 128, 128)
            actual_seq_lengths = (torch.ones(b) * mtp).npu().to(torch.int32)
            T = b * mtp
            state = torch.rand((T, nv, dv, dk), dtype=torch.bfloat16).npu()
            query = torch.rand((T, nk, dk), dtype=torch.bfloat16).npu()
            key = torch.rand((T, nk, dk), dtype=torch.bfloat16).npu()
            value = torch.rand((T, nv, dv), dtype=torch.bfloat16).npu()
            g = torch.rand((T, nv), dtype=torch.float32).npu()
            beta = torch.rand((T, nv), dtype=torch.bfloat16).npu()
            ssm_state_indices = (torch.arange(T).npu()).to(torch.int32)
            query = torch.nn.functional.normalize(query, p=2, dim=-1)
            key = torch.nn.functional.normalize(key, p=2, dim=-1)
            scale = 0.5
            num_accepted_tokens = torch.randint(1, mtp + 1, (b,)).npu().to(torch.int32)
            common_kwargs = dict(beta=beta, scale=scale, actual_seq_lengths=actual_seq_lengths,
                                 ssm_state_indices=ssm_state_indices, g=g,
                                 num_accepted_tokens=num_accepted_tokens)

            state_copy = state.clone()
            out = torch_npu.npu_recurrent_gated_delta_rule(query, key, value, state_copy, **common_kwargs)
            self.assertEqual(out.shape, torch.Size([T, nv, dv]))
            self.assertEqual(out.dtype, torch.bfloat16)

            # Use a fresh clone for the functional variant (the clone used to be created and then ignored).
            state_copy = state.clone()
            out_functional, state_out = torch_npu.npu_recurrent_gated_delta_rule_functional(
                query, key, value, state_copy, **common_kwargs)
            self.assertEqual(out_functional.shape, torch.Size([T, nv, dv]))
            self.assertEqual(state_out.shape, torch.Size([T, nv, dv, dk]))
            self.assertEqual(state_out.dtype, torch.bfloat16)
            self.assertEqual(out_functional.dtype, torch.bfloat16)
class TestTranQuantParam(TestCase):
    def test_npu_trans_quant_param_meta(self):
        """npu_trans_quant_param returns an int64 tensor; each case below pins its expected shape."""
        cases = (
            # (scale shape, offset shape, expected output shape)
            ((1,), (4,), (4,)),
            ((1, 4), (1, 4), (1, 4)),
        )
        with FakeTensorMode():
            for scale_shape, offset_shape, out_shape in cases:
                expected = torch.randint(-1, 1, out_shape, dtype=torch.int64).npu()
                scale = torch.randn(*scale_shape, dtype=torch.float32).npu()
                offset = torch.randn(*offset_shape, dtype=torch.float32).npu()
                result = torch_npu.npu_trans_quant_param(scale, offset)
                self.assertEqual(result.shape, expected.shape)
                self.assertEqual(result.dtype, expected.dtype)
class TestAttentionUpdate(TestCase):
    def test_npu_attention_update_meta(self):
        """Feed K (lse, local_out) pairs to npu_attention_update; expect an (N, H) output of the input dtype."""
        rows, width, num_parts = 8, 64, 2
        update_type = 0
        dtype = torch.float32
        with FakeTensorMode():
            lse_parts = [torch.randn([rows], dtype=dtype).npu() for _ in range(num_parts)]
            out_parts = [torch.randn([rows, width], dtype=dtype).npu() for _ in range(num_parts)]
            out, _lse = torch_npu.npu_attention_update(lse_parts, out_parts, update_type)
            self.assertEqual(out.shape, torch.Size([rows, width]))
            self.assertEqual(out.dtype, dtype)
class TestRingAttentionUpdate(TestCase):
    def test_npu_ring_attention_update_meta_sbh(self):
        """No input_layout passed (SBH per the test name): outputs mirror the metadata of the 'prev' inputs."""
        attn_shape = (4, 2, 32)
        stat_shape = (2, 2, 4, 8)
        with FakeTensorMode():
            prev_attn_out, cur_attn_out = (
                torch.randn(attn_shape, dtype=torch.float16).npu() for _ in range(2))
            prev_softmax_max, prev_softmax_sum, cur_softmax_max, cur_softmax_sum = (
                torch.randn(stat_shape, dtype=torch.float32).abs().npu() for _ in range(4))
            attn_out, softmax_max, softmax_sum = torch_npu.npu_ring_attention_update(
                prev_attn_out, prev_softmax_max, prev_softmax_sum,
                cur_attn_out, cur_softmax_max, cur_softmax_sum)
            self.assertEqual(attn_out.shape, prev_attn_out.shape)
            self.assertEqual(attn_out.dtype, prev_attn_out.dtype)
            self.assertEqual(attn_out.device.type, "npu")
            self.assertEqual(softmax_max.shape, prev_softmax_max.shape)
            self.assertEqual(softmax_max.dtype, torch.float32)
            self.assertEqual(softmax_sum.shape, prev_softmax_sum.shape)
            self.assertEqual(softmax_sum.dtype, torch.float32)

    def test_npu_ring_attention_update_meta_tnd(self):
        """TND layout with actual_seq_qlen: outputs mirror the shapes of the 'prev' inputs."""
        attn_shape = (5, 2, 64)
        stat_shape = (5, 2, 8)
        with FakeTensorMode():
            prev_attn_out, cur_attn_out = (
                torch.randn(attn_shape, dtype=torch.bfloat16).npu() for _ in range(2))
            prev_softmax_max, prev_softmax_sum, cur_softmax_max, cur_softmax_sum = (
                torch.randn(stat_shape, dtype=torch.float32).abs().npu() for _ in range(4))
            actual_seq_qlen = torch.tensor([2, 5], dtype=torch.int64).npu()
            attn_out, softmax_max, softmax_sum = torch_npu.npu_ring_attention_update(
                prev_attn_out, prev_softmax_max, prev_softmax_sum,
                cur_attn_out, cur_softmax_max, cur_softmax_sum,
                actual_seq_qlen=actual_seq_qlen, input_layout="TND", input_softmax_layout="TND")
            self.assertEqual(attn_out.shape, prev_attn_out.shape)
            self.assertEqual(attn_out.dtype, prev_attn_out.dtype)
            self.assertEqual(attn_out.device.type, "npu")
            self.assertEqual(softmax_max.shape, prev_softmax_max.shape)
            self.assertEqual(softmax_sum.shape, prev_softmax_sum.shape)
class TestAntiQuant(TestCase):
    @unittest.skipIf(torch.__version__ < '2.1.0',
                     "OP `AntiQuant` is supported on torch v2.1 and above, skip this test for torch version below 2.1")
    def test_npu_anti_quant_meta(self):
        """npu_anti_quant keeps the input shape, and its byte size equals the input cast to dst_dtype."""
        target_dtype = torch.float16
        with FakeTensorMode():
            quantized = torch.randint(low=-128, high=127, size=(20, 100), dtype=torch.int8).npu()
            scale = torch.randn(100, dtype=torch.float).npu()
            offset = torch.randn(100, dtype=torch.float).npu()
            result = torch_npu.npu_anti_quant(quantized, scale, offset=offset, dst_dtype=target_dtype)
            self.assertEqual(quantized.shape, result.shape)
            as_target = quantized.to(target_dtype)
            self.assertEqual(as_target.numel() * as_target.element_size(),
                             result.numel() * result.element_size())
class TestNpuKroneckerQuant(TestCase):
    def test_npu_kronecker_quant_meta(self):
        """A (16, 64, 64) fp16 input with two (64, 64) factors yields int32 (16, 64, 8) and a float32 scale of shape (16,)."""
        with FakeTensorMode():
            x = torch.randn(16, 64, 64).half().npu()
            left_factor, right_factor = (torch.randn(64, 64).half().npu() for _ in range(2))
            out, quant_scale = torch_npu.npu_kronecker_quant(x, left_factor, right_factor)
            self.assertEqual(out.dtype, torch.int32)
            self.assertEqual(quant_scale.dtype, torch.float)
            self.assertEqual(out.shape, torch.Size([16, 64, 8]))
            self.assertEqual(quant_scale.shape, torch.Size([16]))
class TestNpuLinear(TestCase):
    def test_npu_linear_meta(self):
        """npu_linear keeps the input dtype and infers the same shape as torch.nn.functional.linear."""
        with FakeTensorMode():
            inp = torch.randn(16, 128).npu()
            weight = torch.randn(32, 128).npu()
            bias = torch.randn(32,).npu()
            out = torch_npu.npu_linear(inp, weight, bias)
            self.assertEqual(out.dtype, inp.dtype)
            reference = torch.nn.functional.linear(inp, weight, bias)
            self.assertEqual(out.shape, reference.shape)
class TestMoeFinalizeRouting(TestCase):
    def test_npu_moe_finalize_routing_meta(self):
        """npu_moe_finalize_routing returns a tensor with the shape and dtype of ``skip1``."""
        num_rows, top_k, token_len, expert_num = 50, 4, 10, 16
        with FakeTensorMode():
            # NOTE(review): these inputs are never moved with .npu(), unlike most cases in this file — confirm intent.
            expanded_permuted_rows = torch.randn(num_rows * top_k, token_len).to(torch.float32)
            skip1, skip2_optional = (torch.randn(num_rows, token_len).to(torch.float32) for _ in range(2))
            bias, scales = (torch.randn(num_rows, top_k).to(torch.float32) for _ in range(2))
            expanded_src_to_dst_row = torch.arange(num_rows * top_k).to(torch.int32)
            expert_for_source_row = torch.randint(low=0, high=expert_num, size=(num_rows, top_k)).to(torch.int32)
            result = torch_npu.npu_moe_finalize_routing(expanded_permuted_rows, skip1, skip2_optional, bias, scales,
                                                        expanded_src_to_dst_row, expert_for_source_row)
            self.assertEqual(result.shape, skip1.shape)
            self.assertEqual(result.dtype, skip1.dtype)
class TestGMMFinalizeRouting(TestCase):
    """FakeTensorMode checks for ``torch_npu.npu_grouped_matmul_finalize_routing``.

    Each case builds inputs, runs the op on fake NPU tensors and asserts the result is
    a float32 tensor of shape ``(output_bs, n)``.
    """

    @staticmethod
    def _make_routing(batch, group_num, top_k, m):
        """Return ``(logit, row_index)`` for a random top-``top_k`` routing of ``batch`` rows over ``group_num`` groups.

        ``m`` must equal ``batch * top_k``, because the (batch, top_k) softmax logits are reshaped to length ``m``.
        """
        logit_ori = torch.normal(0, 0.1, (batch, group_num), dtype=torch.float32)
        routing = torch.argsort(logit_ori, 1)[:, -top_k:]
        # Use the fully qualified name: this module never imports `F`.
        logit = torch.nn.functional.softmax(
            logit_ori[torch.arange(batch).reshape(-1, 1).repeat(1, top_k), routing],
            dim=1,
            dtype=torch.float32
        ).reshape(m)
        row_index = (torch.argsort(routing.reshape(-1)) // top_k).to(torch.int64)
        return logit, row_index

    @staticmethod
    def _make_a8w4_inputs(m, k, n, group_num):
        """Build the a8w4 inputs (before ``.npu()``): int8 ``x``, int32 ``weight`` of width ``n // 8`` and quant params.

        Returns ``(x, weight, scale, bias, offset, pertoken_scale)``.
        """
        quant_group_size = k
        x = torch.randint(-10, 10, (m, k), dtype=torch.int8)
        weight = torch.randint(-10, 10, (group_num, k, n // 8), dtype=torch.int32)
        scale_np = np.random.normal(0, 0.01, (group_num, 1, n)).astype(np.float32)
        per_group_scale = np.ones([group_num, k // quant_group_size, n]).astype(np.float32)
        # Round the combined scale through float16, then reinterpret its float32 bits as uint32.
        scale_bits = (scale_np * per_group_scale).astype(np.float16).astype(np.float32).view(np.uint32)
        # Put each 32-bit word in the even slots of a zeroed uint32 buffer and reinterpret word pairs as int64.
        scale_words = np.zeros((group_num, k // quant_group_size, n * 2), dtype=np.uint32)
        scale_words[..., ::2] = scale_bits
        scale = torch.from_numpy(scale_words.view(np.int64))
        bias = torch.normal(0, 0.01, (group_num, n), dtype=torch.float32)
        offset = torch.randint(-5, 5, (group_num, k // quant_group_size, n), dtype=torch.float32)
        pertoken_scale = torch.normal(0, 0.01, (m, ), dtype=torch.float32)
        return x, weight, scale, bias, offset, pertoken_scale

    def _assert_output_meta(self, result, output_bs, n):
        """Assert ``result`` is a float32 tensor of shape ``(output_bs, n)``."""
        self.assertEqual(result.shape, torch.Size([output_bs, n]))
        self.assertEqual(result.dtype, torch.float32)

    def _check_a8w4_case(self):
        """Run the a8w4 case (int8 x, int32 weight of width n // 8) and check the output metadata."""
        with FakeTensorMode():
            m, k, n, batch, top_k, group_num = 8, 2048, 7168, 1, 8, 8
            x, weight, scale, bias, offset, pertoken_scale = self._make_a8w4_inputs(m, k, n, group_num)
            # NOTE(review): group_list is float32 here but int64 in the cumsum case — confirm the expected dtype.
            group_list = torch.tensor([batch] * group_num, dtype=torch.float32)
            logit, row_index = self._make_routing(batch, group_num, top_k, m)
            shared_input = torch.normal(0, 0.1, (batch // 4, n), dtype=torch.bfloat16)
            output_bs = batch
            result = torch_npu.npu_grouped_matmul_finalize_routing(
                x.npu(), weight.npu(), group_list.npu(), scale=scale.npu(),
                bias=bias.npu(), offset=offset.npu(),
                pertoken_scale=pertoken_scale.npu(), shared_input=shared_input.npu(),
                logit=logit.npu(), row_index=row_index.npu(),
                shared_input_offset=batch // 2, output_bs=output_bs
            ).to("cpu")
            self._assert_output_meta(result, output_bs, n)

    def test_npu_grouped_matmul_finalise_routing_meta(self):
        with FakeTensorMode():
            m, k, n, batch, top_k, group_num = 576, 2048, 7168, 72, 8, 8
            x = torch.randint(-10, 10, (m, k), dtype=torch.int8)
            weight = torch.randint(-10, 10, (group_num, k, n), dtype=torch.int8)
            scale = torch.normal(0, 0.01, (group_num, n), dtype=torch.float32)
            pertoken_scale = torch.normal(0, 0.01, (m, ), dtype=torch.float32)
            group_list = torch.tensor([batch] * group_num, dtype=torch.float32)
            logit, row_index = self._make_routing(batch, group_num, top_k, m)
            shared_input = torch.normal(0, 0.1, (batch // 4, n), dtype=torch.bfloat16)
            output_bs = batch
            result = torch_npu.npu_grouped_matmul_finalize_routing(
                x.npu(), weight.npu(), group_list.npu(), scale=scale.npu(),
                pertoken_scale=pertoken_scale.npu(), shared_input=shared_input.npu(),
                logit=logit.npu(), row_index=row_index.npu(),
                shared_input_offset=batch // 2, output_bs=output_bs
            ).to("cpu")
            self._assert_output_meta(result, output_bs, n)

    def test_npu_grouped_matmul_finalise_routing_a8w4_meta(self):
        self._check_a8w4_case()

    def test_npu_grouped_matmul_finalise_routing_a8w4_with_tuningconfig_meta(self):
        # NOTE(review): despite its name, this case passes no tuning-config argument and is identical to
        # test_npu_grouped_matmul_finalise_routing_a8w4_meta; add the intended argument once confirmed.
        self._check_a8w4_case()

    @unittest.skip("Skipping due to outdated CANN version; please update CANN to the latest version and remove this skip.")
    def test_npu_grouped_matmul_finalise_routing_sharedinput_none_grouplist_cumsum_meta(self):
        """shared_input=None, cumulative int64 group_list (group_list_type=0), format-cast weight."""
        with FakeTensorMode():
            m, k, n, batch, top_k, group_num = 576, 2048, 7168, 72, 8, 8
            x = torch.randint(-10, 10, (m, k), dtype=torch.int8)
            weight = torch.randint(-10, 10, (group_num, k, n), dtype=torch.int8)
            scale = torch.normal(0, 0.01, (group_num, n), dtype=torch.float32)
            pertoken_scale = torch.normal(0, 0.01, (m, 1), dtype=torch.float32).reshape(m)
            # group_list_type=0 is paired with a cumulative-sum group_list.
            group_list = torch.cumsum(torch.tensor([batch] * group_num, dtype=torch.int64), dim=0)
            # Routing is built with torch.nn.functional.softmax; the old inline `F.softmax` hit an unimported name.
            logit, row_index = self._make_routing(batch, group_num, top_k, m)
            # Format id 29; the original name `weightNz` suggests the FRACTAL_NZ layout — confirm.
            weight_nz = torch_npu.npu_format_cast(weight.npu(), 29)
            output_bs = batch
            result = torch_npu.npu_grouped_matmul_finalize_routing(
                x.npu(), weight_nz, group_list.npu(), scale=scale.npu(),
                pertoken_scale=pertoken_scale.npu(), shared_input=None,
                logit=logit.npu(), row_index=row_index.npu(),
                shared_input_offset=batch // 2, output_bs=output_bs, group_list_type=0
            ).to("cpu")
            self._assert_output_meta(result, output_bs, n)
class TestTransposeBatchMatmul(TestCase):
    """Meta (FakeTensorMode) checks for ``torch_npu.npu_transpose_batchmatmul``.

    Every case builds float16 operands x1 of shape (M, Batch, K) and x2 of
    shape (Batch, K, N), calls the op with ``perm_x1=[1, 0, 2]`` and
    ``perm_y=[1, 0, 2]``, and asserts only the result's shape and dtype.
    """

    @staticmethod
    def _fp16_operands():
        """Create the shared operands; must be called inside FakeTensorMode.

        Returns:
            (M, K, N, Batch, x1, x2)
        """
        m, k, n, batch = 32, 512, 128, 16
        lhs = torch.randn((m, batch, k), dtype=torch.float16)
        rhs = torch.randn((batch, k, n), dtype=torch.float16)
        return m, k, n, batch, lhs, rhs

    def _assert_output(self, out, shape, dtype):
        """Assert that ``out`` has exactly ``shape`` and ``dtype``."""
        self.assertTrue(out.shape == torch.Size(shape))
        self.assertTrue(out.dtype == dtype)

    @unittest.skip("skip test_npu_transpose_batchmatmul")
    def test_npu_transpose_batchmatmul_meta_1(self):
        with FakeTensorMode():
            m, _, n, batch, lhs, rhs = self._fp16_operands()
            scale = torch.randint(1, 10, (batch * n, ), dtype=torch.int64)
            out = torch_npu.npu_transpose_batchmatmul(
                lhs.npu(), rhs.npu(), scale=scale.npu(),
                perm_x1=[1, 0, 2], perm_y=[1, 0, 2])
            # With an int64 scale of length Batch * N the expected result is
            # int8 with shape (M, 1, Batch * N).
            self._assert_output(out, (m, 1, batch * n), torch.int8)

    @unittest.skip("skip test_npu_transpose_batchmatmul")
    def test_npu_transpose_batchmatmul_meta_2(self):
        with FakeTensorMode():
            m, _, n, batch, lhs, rhs = self._fp16_operands()
            out = torch_npu.npu_transpose_batchmatmul(
                lhs.npu(), rhs.npu(), perm_x1=[1, 0, 2], perm_y=[1, 0, 2])
            self._assert_output(out, (m, batch, n), torch.float16)

    @unittest.skip("skip test_npu_transpose_batchmatmul")
    def test_npu_transpose_batchmatmul_meta_3(self):
        with FakeTensorMode():
            m, _, n, batch, lhs, rhs = self._fp16_operands()
            split = 4
            out = torch_npu.npu_transpose_batchmatmul(
                lhs.npu(), rhs.npu(), scale=None,
                perm_x1=[1, 0, 2], perm_y=[1, 0, 2],
                batch_split_factor=split)
            # Expected layout: (batch_split_factor, M, Batch * N // batch_split_factor).
            self._assert_output(out, (split, m, batch * n // split), torch.float16)

    @unittest.skip("skip test_npu_transpose_batchmatmul")
    def test_npu_transpose_batchmatmul_meta_4(self):
        with FakeTensorMode():
            m, _, n, batch, lhs, rhs = self._fp16_operands()
            # x2 is cast to ACL format 29 before the call.
            rhs_nz = torch_npu.npu_format_cast(rhs.npu(), acl_format=29)
            out = torch_npu.npu_transpose_batchmatmul(
                lhs.npu(), rhs_nz.npu(), perm_x1=[1, 0, 2], perm_y=[1, 0, 2])
            self._assert_output(out, (m, batch, n), torch.float16)
class TestTransposeQuantBatchMatmul(TestCase):
    """Meta (FakeTensorMode) checks for ``torch_npu.npu_transpose_quant_batchmatmul``.

    Operands are FP8 tensors x1 of shape (M, Batch, K) and x2 of shape
    (Batch, K, N). The result is expected to be (M, Batch, N) in the requested
    output ``dtype``.
    """

    _M, _K, _N, _BATCH = 32, 512, 128, 16

    @classmethod
    def _fp8_operands(cls, fp8_dtype):
        """Return (x1, x2) cast from small int8 values to ``fp8_dtype``.

        Must be called inside FakeTensorMode.
        """
        x1 = torch.randint(-5, 5, (cls._M, cls._BATCH, cls._K), dtype=torch.int8).to(fp8_dtype)
        x2 = torch.randint(-5, 5, (cls._BATCH, cls._K, cls._N), dtype=torch.int8).to(fp8_dtype)
        return x1, x2

    def _assert_output(self, result, out_dtype):
        """Assert the result is (M, Batch, N) with dtype ``out_dtype``."""
        self.assertEqual(result.shape, torch.Size([self._M, self._BATCH, self._N]))
        self.assertEqual(result.dtype, out_dtype)

    def _run_per_channel(self, fp8_dtype, out_dtype):
        """Use per-channel scales: float32 x1_scale of shape (M,) and x2_scale of shape (N,)."""
        with FakeTensorMode():
            x1, x2 = self._fp8_operands(fp8_dtype)
            x1_scale = torch.randint(-3, 3, (self._M, ), dtype=torch.float32)
            x2_scale = torch.randint(-3, 3, (self._N, ), dtype=torch.float32)
            result = torch_npu.npu_transpose_quant_batchmatmul(
                x1.npu(), x2.npu(), dtype=out_dtype,
                x1_scale=x1_scale.npu(), x2_scale=x2_scale.npu(),
                perm_x1=[1, 0, 2], perm_y=[1, 0, 2])
            self._assert_output(result, out_dtype)

    @unittest.skip("skip test_npu_transpose_quant_batchmatmul")
    def test_npu_transpose_quant_batchmatmul_meta_1(self):
        self._run_per_channel(torch.float8_e5m2, torch.float16)

    @unittest.skip("skip test_npu_transpose_quant_batchmatmul")
    def test_npu_transpose_quant_batchmatmul_meta_2(self):
        # Each scale argument gets its own per-operand scale tensor.
        self._run_per_channel(torch.float8_e4m3fn, torch.float16)

    @unittest.skip("skip test_npu_transpose_quant_batchmatmul")
    def test_npu_transpose_quant_batchmatmul_meta_3(self):
        self._run_per_channel(torch.float8_e4m3fn, torch.bfloat16)

    @unittest.skip("skip test_npu_transpose_quant_batchmatmul")
    def test_npu_transpose_quant_batchmatmul_meta_4(self):
        self._run_per_channel(torch.float8_e5m2, torch.bfloat16)

    @unittest.skip("skip test_npu_transpose_quant_batchmatmul")
    def test_npu_transpose_quant_batchmatmul_meta_5(self):
        """MX case: uint8 scales reinterpreted as float8_e8m0fnu, group_sizes=[0, 0, 32]."""
        with FakeTensorMode():
            x1, x2 = self._fp8_operands(torch.float8_e4m3fn)
            k_blocks = self._K // 64
            x1_scale = torch.randint(0, 3, (self._M, self._BATCH, k_blocks, 2), dtype=torch.uint8)
            x2_scale = torch.randint(0, 3, (self._BATCH, k_blocks, self._N, 2), dtype=torch.uint8)
            # Bit-reinterpret (no value conversion) the uint8 scales as e8m0.
            x1_scale = x1_scale.view(torch.float8_e8m0fnu)
            x2_scale = x2_scale.view(torch.float8_e8m0fnu)
            result = torch_npu.npu_transpose_quant_batchmatmul(
                x1.npu(), x2.npu(), dtype=torch.float16,
                x1_scale=x1_scale.npu(), x2_scale=x2_scale.npu(),
                group_sizes=[0, 0, 32], perm_x1=[1, 0, 2], perm_x2=[0, 1, 2], perm_y=[1, 0, 2])
            self._assert_output(result, torch.float16)
class TestNpuPrefetch(TestCase):
    """Argument-validation checks for ``torch_npu.npu_prefetch`` under FakeTensorMode."""

    def test_npu_prefetch(self):
        # (trailing positional args after (input, None), exact error message)
        bad_calls = (
            ((-1,), "The max_size should be greater than zero, but got -1."),
            ((10, -1), "The offset should be nonnegative, but got -1."),
        )
        with FakeTensorMode():
            tensor = torch.randn(8, 8).npu()
            for extra_args, message in bad_calls:
                with self.assertRaises(RuntimeError) as ctx:
                    torch_npu.npu_prefetch(tensor, None, *extra_args)
                self.assertEqual(str(ctx.exception), message)
class TestMoeDistributeDispatch(TestCase):
    """Meta check of every output of ``torch_npu.npu_moe_distribute_dispatch`` (quant_mode=2)."""

    def test_moe_distribute_dispatchA2(self):
        quant_mode = 2
        ep_world_size = 16
        tp_world_size = 0
        bs, h, k = 8, 7168, 8
        shared_expert_rank_num = 0
        moe_expert_num = 16
        expert_num_per_rank = 1
        global_bs = bs * ep_world_size
        # The world size here equals ep_world_size.
        total_expert_num = ep_world_size * expert_num_per_rank
        local_moe_expert_num = moe_expert_num // ep_world_size
        a = global_bs * min(local_moe_expert_num, k)
        ep_recv_cnt_num = ep_world_size * local_moe_expert_num + global_bs * 2 * k * (ep_world_size // 8)
        with FakeTensorMode():
            x = torch.randn(bs, h).to(torch.bfloat16)
            expert_ids = torch.randn(bs, k).to(torch.int32)
            scales = torch.randn(total_expert_num, h).to(torch.float32)
            expert_scales = torch.randn(bs, k).to(torch.float32)
            result = torch_npu.npu_moe_distribute_dispatch(
                x, expert_ids, "group_ep", ep_world_size, 0, moe_expert_num,
                scales=scales, x_active_mask=None, expert_scales=expert_scales,
                group_tp="", tp_world_size=0, tp_rank_id=0, expert_shard_type=0,
                shared_expert_num=0, shared_expert_rank_num=shared_expert_rank_num,
                quant_mode=quant_mode, global_bs=global_bs, expert_token_nums_type=1)
            # (leading dims to verify, dtype) for each output, in output order.
            expected = (
                ((a, h), torch.int8),
                ((a,), torch.float32),
                ((bs * k,), torch.int32),
                ((local_moe_expert_num,), torch.int64),
                ((ep_recv_cnt_num,), torch.int32),
                ((tp_world_size,), torch.int32),
                ((a,), torch.float32),
            )
            for position, (dims, dtype) in enumerate(expected):
                for axis, size in enumerate(dims):
                    self.assertEqual(result[position].shape[axis], size)
                self.assertEqual(result[position].dtype, dtype)
class TestMoeDistributeDispatchV2(TestCase):
    """Meta checks for ``torch_npu.npu_moe_distribute_dispatch_v2``.

    Each case varies ``tp_world_size`` (0 or 1) and ``quant_mode`` and checks
    the expanded output and/or its scale tensor.
    """

    def _run_dispatch_v2(self, tp_world_size, quant_mode, scales=None, y_dtype=None):
        """Run the op once under FakeTensorMode.

        Returns:
            (result, a, h, local_moe_expert_num), where ``a`` is the expected
            leading dimension of the expanded output and ``h`` the hidden size.
        """
        ep_world_size, bs, h, k, moe_expert_num = 16, 8, 7168, 8, 16
        global_bs = bs * ep_world_size
        local_moe_expert_num = moe_expert_num // ep_world_size
        a = global_bs * min(local_moe_expert_num, k)
        with FakeTensorMode():
            x = torch.randn(bs, h).to(torch.bfloat16)
            expert_ids = torch.randn(bs, k).to(torch.int32)
            result = torch_npu.npu_moe_distribute_dispatch_v2(
                x, expert_ids, "group_ep", ep_world_size, 0, moe_expert_num,
                scales=scales, x_active_mask=None, expert_scales=None,
                group_tp="", tp_world_size=tp_world_size, tp_rank_id=0,
                expert_shard_type=0, shared_expert_num=0, shared_expert_rank_num=0,
                quant_mode=quant_mode, global_bs=global_bs, expert_token_nums_type=1,
                y_dtype=y_dtype)
        return result, a, h, local_moe_expert_num

    def _check_expanded(self, tp_world_size, quant_mode, expected_dtype):
        """Check y is (a, h) of ``expected_dtype`` and the second output is (a,)."""
        result, a, h, _ = self._run_dispatch_v2(tp_world_size=tp_world_size, quant_mode=quant_mode)
        self.assertEqual(result[0].shape, torch.Size([a, h]))
        self.assertEqual(result[0].dtype, expected_dtype)
        self.assertEqual(result[1].shape, torch.Size([a]))

    def _check_pergroup(self, tp_world_size):
        """quant_mode=3: expected scale columns are ceil(h / 128)."""
        result, a, h, _ = self._run_dispatch_v2(
            tp_world_size=tp_world_size, quant_mode=3, y_dtype=torch.float8_e5m2)
        self.assertEqual(result[1].shape, torch.Size([a, math.ceil(h / 128)]))

    def _check_mx(self, tp_world_size):
        """quant_mode=4: ceil(h / 32) scale columns rounded up to even, stored as uint8."""
        result, a, h, _ = self._run_dispatch_v2(
            tp_world_size=tp_world_size, quant_mode=4, y_dtype=torch.float8_e4m3fn)
        scale_cols = (math.ceil(h / 32) + 1) // 2 * 2
        self.assertEqual(result[1].shape, torch.Size([a, scale_cols]))
        self.assertEqual(result[1].dtype, torch.uint8)

    def test_tp0_pertoken(self):
        self._check_expanded(0, 2, torch.int8)

    def test_tp1_pertoken(self):
        self._check_expanded(1, 2, torch.int8)

    def test_tp0_pergroup(self):
        self._check_pergroup(0)

    def test_tp1_pergroup(self):
        self._check_pergroup(1)

    def test_tp0_mx(self):
        self._check_mx(0)

    def test_tp1_mx(self):
        self._check_mx(1)

    def test_tp0_no_quant(self):
        self._check_expanded(0, 0, torch.bfloat16)
class TestMoeDistributeCombineAddRmsNorm(TestCase):
    """Meta check for ``torch_npu.npu_moe_distribute_combine_add_rms_norm``."""

    def test_moe_distribute_combine_add_rms_norm(self):
        tokens, top_k, hidden = 32, 8, 7168
        with FakeTensorMode():
            expand_x = torch.randn(tokens, hidden).to(torch.bfloat16)
            expert_ids = torch.randn(tokens, top_k).to(torch.int32)
            expand_idx = torch.randn(tokens, top_k).to(torch.int32)
            ep_send_counts = torch.randn(8).to(torch.int32)
            expert_scales = torch.randn(tokens, top_k).to(torch.float32)
            residual_x = torch.randn(tokens, 1, hidden).to(torch.bfloat16)
            gamma = torch.randn(hidden).to(torch.bfloat16)
            result = torch_npu.npu_moe_distribute_combine_add_rms_norm(
                expand_x=expand_x, expert_ids=expert_ids, expand_idx=expand_idx,
                ep_send_counts=ep_send_counts, expert_scales=expert_scales,
                residual_x=residual_x, gamma=gamma, group_ep="groupe_ep",
                ep_world_size=16, ep_rank_id=0, moe_expert_num=16)
            # (shape, dtype) for each of the three outputs, in output order.
            expected = (
                ((tokens, 1, hidden), torch.bfloat16),
                ((tokens, 1, 1), torch.float32),
                ((tokens, 1, hidden), torch.bfloat16),
            )
            for out_idx, (dims, dtype) in enumerate(expected):
                out = result[out_idx]
                for axis, size in enumerate(dims):
                    self.assertEqual(out.shape[axis], size)
                self.assertEqual(out.dtype, dtype)
class TestMoeDistributeDispatchSetup(TestCase):
    """Meta check for ``torch_npu.npu_moe_distribute_dispatch_setup``."""

    def test_moe_distribute_dispatch_setup(self):
        with FakeTensorMode():
            x = torch.randn(32, 256).to(torch.bfloat16)
            expert_ids = torch.randn(32, 8).to(torch.int32)
            result = torch_npu.npu_moe_distribute_dispatch_setup(
                x=x, expert_ids=expert_ids, group_ep="group",
                ep_world_size=16, ep_rank_id=0, moe_expert_num=16)
            # (output index, leading dims, dtype)
            checks = [
                (0, (288, 256), torch.bfloat16),
                (1, (256,), torch.int32),
                (2, (4864,), torch.int32),
            ]
            for idx, dims, dtype in checks:
                for axis, size in enumerate(dims):
                    self.assertEqual(result[idx].shape[axis], size)
                self.assertEqual(result[idx].dtype, dtype)
class TestMoeDistributeDispatchTeardown(TestCase):
    """Meta check for ``torch_npu.npu_moe_distribute_dispatch_teardown``."""

    def test_moe_distribute_dispatch_teardown(self):
        with FakeTensorMode():
            x = torch.randn(32, 256).to(torch.bfloat16)
            y = torch.randn(256, 512).to(torch.bfloat16)
            expert_ids = torch.randn(32, 8).to(torch.int32)
            comm_cmd_info = torch.randn(8704).to(torch.int32)
            result = torch_npu.npu_moe_distribute_dispatch_teardown(
                x=x, y=y, expert_ids=expert_ids, comm_cmd_info=comm_cmd_info,
                group_ep="group", ep_world_size=16, ep_rank_id=0, moe_expert_num=16)
            # Output index -> (leading dims, dtype).
            expectations = {
                0: ((512, 256), torch.bfloat16),
                1: ((512,), torch.float32),
                2: ((512 * 128,), torch.int32),
                3: ((1,), torch.int64),
            }
            for idx, (dims, dtype) in expectations.items():
                for axis, size in enumerate(dims):
                    self.assertEqual(result[idx].shape[axis], size)
                self.assertEqual(result[idx].dtype, dtype)
class TestMoeDistributeCombineSetup(TestCase):
    """Meta check for ``torch_npu.npu_moe_distribute_combine_setup``."""

    def test_moe_distribute_combine_setup(self):
        with FakeTensorMode():
            expand_x = torch.randn(32, 256).to(torch.bfloat16)
            expert_ids = torch.randn(32, 8).to(torch.int32)
            assist_info_for_combine = torch.randn(32 * 128).to(torch.int32)
            outputs = torch_npu.npu_moe_distribute_combine_setup(
                expand_x=expand_x, expert_ids=expert_ids,
                assist_info_for_combine=assist_info_for_combine,
                group_ep="group", ep_world_size=16, ep_rank_id=0, moe_expert_num=16)
            quant_out, cmd_info = outputs[0], outputs[1]
            self.assertEqual(quant_out.shape[0], 32)
            self.assertEqual(quant_out.shape[1], 512)
            self.assertEqual(quant_out.dtype, torch.int8)
            self.assertEqual(cmd_info.shape[0], 768)
            self.assertEqual(cmd_info.dtype, torch.int32)
class TestMoeDistributeCombineTeardown(TestCase):
    """Meta check for ``torch_npu.npu_moe_distribute_combine_teardown``."""

    def test_moe_distribute_combine_teardown(self):
        with FakeTensorMode():
            expand_x = torch.randn(32, 256).to(torch.bfloat16)
            quant_expand_x = torch.randn(32, 1024).to(torch.int8)
            expert_ids = torch.randn(32, 8).to(torch.int32)
            expand_idx = torch.randn(256).to(torch.int32)
            expert_scales = torch.randn(256).to(torch.float32)
            comm_cmd_info = torch.randn(768).to(torch.int32)
            combined = torch_npu.npu_moe_distribute_combine_teardown(
                expand_x=expand_x, quant_expand_x=quant_expand_x,
                expert_ids=expert_ids, expand_idx=expand_idx,
                expert_scales=expert_scales, comm_cmd_info=comm_cmd_info,
                group_ep="group", ep_world_size=16, ep_rank_id=0, moe_expert_num=16)
            # Expected output: (32, 256) in bfloat16.
            for axis, size in enumerate((32, 256)):
                self.assertEqual(combined.shape[axis], size)
            self.assertEqual(combined.dtype, torch.bfloat16)
class TestAddRmsNormQuant(TestCase):
    """Meta checks for ``torch_npu.npu_add_rms_norm_quant`` with fp16 and bf16 inputs."""

    def _assert_quant_outputs(self, x1, outputs):
        """y1/y2 must be int8 with x1's shape; x_out must match x1's shape and dtype."""
        y1, y2, x_out = outputs
        for quantized in (y1, y2):
            self.assertTrue(quantized.shape == x1.shape)
            self.assertTrue(quantized.dtype == torch.int8)
        self.assertTrue(x_out.shape == x1.shape)
        self.assertTrue(x_out.dtype == x1.dtype)

    def test_npu_add_rms_norm_quant(self):
        with FakeTensorMode():
            # float16 activations with float32 scales and int32 zero points.
            x1 = torch.randn([2, 16], dtype=torch.float16).npu()
            x2 = torch.randn([2, 16], dtype=torch.float16).npu()
            gamma = torch.randn([16, ], dtype=torch.float16).npu()
            scales1 = torch.randn([16, ], dtype=torch.float32).npu()
            zero_points1 = torch.randint(-10, 10, [16, ], dtype=torch.int32).npu()
            self._assert_quant_outputs(
                x1, torch_npu.npu_add_rms_norm_quant(x1, x2, gamma, scales1, zero_points1))

            # bfloat16 activations with bfloat16 scales and zero points.
            x1 = torch.randn([2, 16], dtype=torch.bfloat16).npu()
            x2 = torch.randn([2, 16], dtype=torch.bfloat16).npu()
            gamma = torch.randn([16, ], dtype=torch.bfloat16).npu()
            scales1 = torch.randn([16, ], dtype=torch.bfloat16).npu()
            zero_points1 = torch.randn([16, ], dtype=torch.bfloat16).npu()
            self._assert_quant_outputs(
                x1, torch_npu.npu_add_rms_norm_quant(x1, x2, gamma, scales1, zero_points1))
class TestAddRmsNormDynamicQuant(TestCase):
    """Meta checks for ``torch_npu.npu_add_rms_norm_dynamic_quant``."""

    @staticmethod
    def _make_inputs(dtype):
        """Return (x1, x2, gamma, beta, smooth_scale1, smooth_scale2) on 'npu' in ``dtype``."""
        x1 = torch.randn([2, 16], dtype=dtype, device='npu')
        x2 = torch.randn([2, 16], dtype=dtype, device='npu')
        gamma = torch.ones([16, ], dtype=dtype, device='npu')
        beta = torch.zeros([16, ], dtype=dtype, device='npu')
        smooth_scale1 = torch.ones([16, ], dtype=dtype, device='npu')
        smooth_scale2 = torch.ones([16, ], dtype=dtype, device='npu')
        return x1, x2, gamma, beta, smooth_scale1, smooth_scale2

    @staticmethod
    def _call(inputs, **extra):
        """Invoke the op on ``inputs``, forwarding any extra keyword arguments."""
        x1, x2, gamma, beta, smooth_scale1, smooth_scale2 = inputs
        return torch_npu.npu_add_rms_norm_dynamic_quant(
            x1, x2, gamma, smooth_scale1=smooth_scale1, smooth_scale2=smooth_scale2, beta=beta, **extra
        )

    def _assert_int8_outputs(self, x1, outputs):
        """Check the five outputs for the default (int8) output dtype."""
        y1, y2, x_out, s1, s2 = outputs
        for y in (y1, y2):
            self.assertEqual(y.shape, x1.shape)
            self.assertEqual(y.dtype, torch.int8)
        self.assertEqual(x_out.shape, x1.shape)
        self.assertEqual(x_out.dtype, x1.dtype)
        for s in (s1, s2):
            self.assertEqual(s.shape, x1.shape[:-1])
            self.assertEqual(s.dtype, torch.float32)

    def test_npu_add_rms_norm_dynamic_quant_meta(self):
        with FakeTensorMode():
            fp16 = self._make_inputs(torch.float16)
            self._assert_int8_outputs(fp16[0], self._call(fp16))

            bf16 = self._make_inputs(torch.bfloat16)
            self._assert_int8_outputs(bf16[0], self._call(bf16))

            # quint4x2 output: y is expected as int32 with the last dim divided by 8.
            y1, y2, _, _, _ = self._call(bf16, y_dtype=torch.quint4x2)
            x1_shape = bf16[0].shape
            packed_shape = (*x1_shape[:-1], x1_shape[-1] // 8)
            for y in (y1, y2):
                self.assertEqual(y.shape, packed_shape)
                self.assertEqual(y.dtype, torch.int32)
class TestAddRmsNormDynamicMxQuant(TestCase):
    """Meta check for ``torch_npu.npu_add_rms_norm_dynamic_mx_quant`` with float8_e5m2 output."""

    def test_npu_add_rms_norm_dynamic_mx_quant_meta(self):
        rows, cols = 8, 64
        with FakeTensorMode():
            x1 = torch.randn([rows, cols], dtype=torch.float16, device='npu')
            x2 = torch.randn([rows, cols], dtype=torch.float16, device='npu')
            gamma = torch.ones([cols, ], dtype=torch.float16, device='npu')
            beta = torch.zeros([cols, ], dtype=torch.float16, device='npu')
            outputs = torch_npu.npu_add_rms_norm_dynamic_mx_quant(
                x1, x2, gamma, beta=beta, epsilon=1e-6, scale_alg=0,
                round_mode="rint", dst_type=torch_npu.float8_e5m2
            )
            y, x_out, mxscale, rstd = outputs
            self.assertEqual(y.shape, x1.shape)
            self.assertEqual(y.dtype, torch.float8_e5m2)
            self.assertEqual(x_out.shape, x1.shape)
            self.assertEqual(x_out.dtype, x1.dtype)
            self.assertEqual(mxscale.shape, torch.Size([rows, 1, 2]))
            self.assertEqual(mxscale.dtype, torch.uint8)
            self.assertEqual(rstd.shape, torch.Size([rows, 1]))
            self.assertEqual(rstd.dtype, torch.float32)
class TestMoeUpdateExpert(TestCase):
    """Meta check for ``torch_npu.npu_moe_update_expert`` with balance_mode=0."""

    def test_moe_update_expert(self):
        dtype = torch.bfloat16
        bs, k, f = 32, 8, 4
        moe_expert_num = 256
        with FakeTensorMode():
            expert_ids = torch.randn(bs, k).to(torch.int32)
            eplb_table = torch.randn(moe_expert_num, f).to(torch.int32)
            expert_scales = torch.sort(torch.rand(bs, k, dtype=dtype), dim=-1, descending=True).values
            pruning_threshold = torch.rand(k, dtype=dtype)
            # Active mask: a random-length run of True followed by False, length bs.
            num_true = np.random.randint(0, bs + 1)
            mask_np = np.concatenate([np.ones(num_true, dtype=bool), np.zeros(bs - num_true, dtype=bool)])
            active_mask = torch.from_numpy(mask_np).to(torch.bool)
            result = torch_npu.npu_moe_update_expert(
                expert_ids=expert_ids, eplb_table=eplb_table,
                expert_scales=expert_scales, pruning_threshold=pruning_threshold,
                active_mask=active_mask, local_rank_id=0, world_size=8, balance_mode=0)
            # Both outputs are (bs, k); the first is int32, the second bool.
            for idx, expected_dtype in ((0, torch.int32), (1, torch.bool)):
                self.assertEqual(result[idx].shape[0], bs)
                self.assertEqual(result[idx].shape[1], k)
                self.assertEqual(result[idx].dtype, expected_dtype)
class TestNpuMropeMeta(TestCase):
    """Meta check for ``torch_npu.npu_mrope``: outputs mirror query/key shape and dtype."""

    def test_npu_mrope_meta(self):
        dtype = torch.bfloat16
        num_tokens, num_q_heads, head_size = 8, 8, 128
        num_kv_heads = num_q_heads
        max_seq_len = num_tokens
        rotary_dim = head_size
        with FakeTensorMode():
            host_tensors = (
                torch.arange(num_tokens, dtype=torch.int64),
                torch.randn(num_tokens, num_q_heads * head_size, dtype=dtype),
                torch.rand(num_tokens, num_kv_heads * head_size, dtype=dtype),
                torch.rand(max_seq_len, rotary_dim, dtype=dtype),
            )
            positions_npu, query_npu, key_npu, cos_sin_cache_npu = (t.npu() for t in host_tensors)
            query_out, key_out = torch_npu.npu_mrope(
                positions_npu,
                query_npu,
                key_npu,
                cos_sin_cache_npu,
                head_size,
                mrope_section=[0, 0, 0],
                rotary_mode='half',
                cache_mode='default',
            )
            for produced, source in ((query_out, query_npu), (key_out, key_npu)):
                self.assertEqual(produced.shape, source.shape)
                self.assertEqual(produced.dtype, source.dtype)
class TestGatherSparseIndex(TestCase):
    """Meta check for ``torch_npu.npu_gather_sparse_index``."""

    def test_npu_gather_sparse_index(self):
        with FakeTensorMode():
            table = torch.randn([16, 32], dtype=torch.float32).npu()
            idx = torch.randint(0, 16, [4, 8]).npu()
            gathered = torch_npu.npu_gather_sparse_index(table, idx)
            # Expected: index shape followed by the table's trailing dim, table dtype.
            self.assertTrue(gathered.shape == torch.Size([*idx.shape, table.shape[-1]]))
            self.assertTrue(gathered.dtype == table.dtype)
class TestNpuTopKTopP(TestCase):
    """Meta check for ``torch_npu.npu_top_k_top_p``: output matches logits shape and dtype."""

    def test_npu_top_k_top_p_meta(self):
        vocab_size, batch_size = 152064, 128
        dtype = torch.float32
        # Per-row k values are drawn from [10, k_max).
        k_max = min(1024, vocab_size)
        with FakeTensorMode():
            logits = torch.randn(batch_size, vocab_size).to(dtype)
            p = torch.rand(batch_size).to(dtype)
            k = torch.randint(10, k_max, (batch_size,)).to(torch.int32)
            filtered = torch_npu.npu_top_k_top_p(logits.npu(), p.npu(), k.npu())
            self.assertEqual(filtered.dtype, dtype)
            self.assertEqual(filtered.shape, logits.shape)
class TestNpuQkvRmsNormRopeCache(TestCase):
    """Meta check for ``torch_npu.npu_qkv_rms_norm_rope_cache`` in PA_NZ cache mode.

    With ``is_output_qkv=True`` the op returns three tensors (bound here as the
    *_before_quant outputs); each must be (B * S, heads * D) in the input dtype.
    """

    @staticmethod
    def _int8_cache(block_num, heads, head_dim, block_size):
        """Allocate an int8 cache of shape (block_num, heads * D // 32, block_size, 32)."""
        return torch.empty(
            block_num, heads * head_dim // 32, block_size, 32,
            dtype=torch.int8
        ).npu()

    def test_npu_qkv_rms_norm_rope_cache_meta(self):
        B, S = 16, 3
        Nq, Nk, Nv = 16, 1, 1
        D = 128
        N = Nq + Nk + Nv
        block_size = 128
        block_num = (S + block_size - 1) // block_size * B
        dtype = torch.float16
        rows = B * S
        with FakeTensorMode():
            qkv = torch.empty(rows, N * D, dtype=dtype).npu()
            q_gamma = torch.empty(D, dtype=dtype).npu()
            k_gamma = torch.empty(D, dtype=dtype).npu()
            cos = torch.empty(rows, D, dtype=dtype).npu()
            sin = torch.empty(rows, D, dtype=dtype).npu()
            index = torch.full((rows,), -1, dtype=torch.int64).npu()
            q_out = torch.empty(rows, Nq * D, dtype=dtype).npu()
            k_cache = self._int8_cache(block_num, Nk, D, block_size)
            v_cache = self._int8_cache(block_num, Nv, D, block_size)
            k_scale = torch.empty(Nk, D, dtype=torch.float32).npu()
            v_scale = torch.empty(Nv, D, dtype=torch.float32).npu()
            outputs = torch_npu.npu_qkv_rms_norm_rope_cache(
                qkv=qkv, q_gamma=q_gamma, k_gamma=k_gamma, cos=cos, sin=sin,
                index=index, q_out=q_out, k_cache=k_cache, v_cache=v_cache,
                qkv_size=[B, S, N, D], head_nums=[Nq, Nk, Nv],
                k_scale=k_scale, v_scale=v_scale, k_offset=None, v_offset=None,
                epsilon=1e-6, cache_mode="PA_NZ", is_output_qkv=True,
            )
            q_before, k_before, v_before = outputs
            for tensor, heads in ((q_before, Nq), (k_before, Nk), (v_before, Nv)):
                self.assertEqual(tensor.shape, torch.Size([rows, heads * D]))
                self.assertEqual(tensor.dtype, dtype)
class TestNpuMoeTokenPermuteAndUnpermute(TestCase):
    """Meta (FakeTensorMode) checks for the npu_moe_token_permute/unpermute family.

    Covers the forward ops, their ``*_grad`` counterparts, a probs/grad dtype
    mismatch and ``probs=None`` in ``npu_moe_token_unpermute_grad``.
    """

    def test_npu_moe_token_permute_unpermute_meta(self):
        """Forward meta check: permute then unpermute for topk in {1, 4}."""
        dtype = torch.bfloat16
        num_tokens = 1000
        num_output_tokens = None
        hidden_size = 6144
        num_experts = 128
        with FakeTensorMode():
            tokens = torch.randn(num_tokens, hidden_size).npu().to(dtype)
            num_output_tokens = 0 if num_output_tokens is None else num_output_tokens
            for topk in [1, 4]:
                flatten_size = num_tokens * topk
                indices = torch.randint(0, num_experts, (num_tokens, topk)).npu()
                permuted_tokens, sorted_indices = torch_npu.npu_moe_token_permute(tokens, indices)
                # A positive num_output_tokens caps the permuted row count; a
                # non-positive one is added to the full topk-expanded count.
                if num_output_tokens > 0:
                    permuted_output_shape_0 = min(num_output_tokens, flatten_size)
                else:
                    permuted_output_shape_0 = flatten_size + num_output_tokens
                self.assertEqual(permuted_tokens.dtype, dtype)
                self.assertEqual(permuted_tokens.shape[0], permuted_output_shape_0)
                self.assertEqual(permuted_tokens.shape[1], hidden_size)
                self.assertEqual(sorted_indices.dtype, torch.int32)
                self.assertEqual(sorted_indices.shape[0], indices.numel())
                probs = (torch.ones_like(indices) / topk).npu().to(dtype)
                unpermuted_tokens = torch_npu.npu_moe_token_unpermute(permuted_tokens, sorted_indices, probs=probs)
                self.assertEqual(unpermuted_tokens.dtype, dtype)
                # Integer division keeps the expected row count an int
                # (sorted_indices has num_tokens * topk entries).
                self.assertEqual(unpermuted_tokens.shape, (sorted_indices.size(0) // topk, hidden_size))

    def test_npu_moe_token_permute_unpermute_grad_meta(self):
        """Meta test for backward functions with topk=1 and topk=4"""
        dtype = torch.bfloat16
        num_tokens = 1000
        hidden_size = 6144
        num_experts = 128
        with FakeTensorMode():
            for topk in [1, 4]:
                tokens = torch.randn(num_tokens, hidden_size).npu().to(dtype)
                indices = torch.randint(0, num_experts, (num_tokens, topk)).npu()
                probs = (torch.ones(num_tokens, topk) / topk).npu().to(dtype)
                permuted_tokens, sorted_indices = torch_npu.npu_moe_token_permute(tokens, indices)
                expected_permuted_shape = (num_tokens * topk, hidden_size)
                self.assertEqual(permuted_tokens.shape, expected_permuted_shape)
                grad_unpermuted_tokens = torch.randn(num_tokens, hidden_size).npu().to(dtype)
                grad_permuted_tokens, grad_probs = torch_npu.npu_moe_token_unpermute_grad(
                    permuted_tokens=permuted_tokens,
                    grad_unpermuted_tokens=grad_unpermuted_tokens,
                    sorted_indices=sorted_indices,
                    probs=probs,
                    padded_mode=False,
                    restore_shape=None,
                )
                self.assertEqual(grad_permuted_tokens.dtype, permuted_tokens.dtype)
                self.assertEqual(grad_permuted_tokens.shape, permuted_tokens.shape)
                self.assertIsNotNone(grad_probs)
                self.assertEqual(grad_probs.dtype, probs.dtype)
                self.assertEqual(grad_probs.shape, probs.shape)
                grad_permuted_tokens_for_permute = torch.randn_like(permuted_tokens)
                grad_tokens = torch_npu.npu_moe_token_permute_grad(
                    tokens=tokens,
                    grad_permuted_tokens=grad_permuted_tokens_for_permute,
                    indices=indices,
                    sorted_indices=sorted_indices,
                    padded_mode=False,
                )
                self.assertEqual(grad_tokens.dtype, tokens.dtype)
                self.assertEqual(grad_tokens.shape, tokens.shape)
                grad_tokens_v2 = torch_npu.npu_moe_token_permute_grad_v2(
                    grad_permuted_tokens=grad_permuted_tokens_for_permute,
                    sorted_indices=sorted_indices,
                    tokens_size_0=tokens.shape[0],
                    tokens_dtype=tokens.dtype,
                    num_topK=topk,
                    padded_mode=False,
                )
                self.assertEqual(grad_tokens_v2.dtype, tokens.dtype)
                self.assertEqual(grad_tokens_v2.shape, tokens.shape)

    def test_npu_moe_token_unpermute_grad_mixed_dtype(self):
        """Test grad_probs dtype matches probs when probs and grad_unpermuted_tokens have different dtypes"""
        num_tokens = 100
        hidden_size = 64
        num_experts = 8
        topk = 4
        with FakeTensorMode():
            tokens = torch.randn(num_tokens, hidden_size).npu().to(torch.bfloat16)
            indices = torch.randint(0, num_experts, (num_tokens, topk)).npu()
            permuted_tokens, sorted_indices = torch_npu.npu_moe_token_permute(tokens, indices)
            probs = (torch.ones(num_tokens, topk) / topk).npu().to(torch.float32)
            grad_unpermuted_tokens = torch.randn(num_tokens, hidden_size).npu().to(torch.bfloat16)
            grad_permuted_tokens, grad_probs = torch_npu.npu_moe_token_unpermute_grad(
                permuted_tokens=permuted_tokens,
                grad_unpermuted_tokens=grad_unpermuted_tokens,
                sorted_indices=sorted_indices,
                probs=probs,
                padded_mode=False,
                restore_shape=None,
            )
            self.assertEqual(grad_permuted_tokens.dtype, permuted_tokens.dtype)
            self.assertIsNotNone(grad_probs)
            self.assertEqual(grad_probs.dtype, probs.dtype)
            self.assertEqual(grad_probs.shape, probs.shape)

    def test_npu_moe_token_unpermute_grad_probs_none(self):
        """Test npu_moe_token_unpermute_grad when probs is None"""
        num_tokens = 100
        hidden_size = 64
        num_experts = 8
        topk = 1
        with FakeTensorMode():
            tokens = torch.randn(num_tokens, hidden_size).npu().to(torch.bfloat16)
            indices = torch.randint(0, num_experts, (num_tokens, topk)).npu()
            permuted_tokens, sorted_indices = torch_npu.npu_moe_token_permute(tokens, indices)
            grad_unpermuted_tokens = torch.randn(num_tokens, hidden_size).npu().to(torch.bfloat16)
            grad_permuted_tokens, grad_probs = torch_npu.npu_moe_token_unpermute_grad(
                permuted_tokens=permuted_tokens,
                grad_unpermuted_tokens=grad_unpermuted_tokens,
                sorted_indices=sorted_indices,
                probs=None,
                padded_mode=False,
                restore_shape=None,
            )
            self.assertEqual(grad_permuted_tokens.dtype, permuted_tokens.dtype)
            self.assertEqual(grad_permuted_tokens.shape, permuted_tokens.shape)
            self.assertIsNone(grad_probs)
class TestNpuMoeUnpermuteWithRoutingMap(TestCase):
    def test_npu_moe_token_unpermute_with_routing_map_meta(self):
        """Check fake-mode output metadata of the routing-map unpermute ops.

        Exercises both ``npu_moe_token_unpermute_with_routing_map`` and the private
        ``_npu_moe_token_unpermute_with_routing_map``, with and without ``probs``.
        """
        token_num = 40
        hidden_size = 20
        expert_num = 20
        top_k = 20
        capacity = 20
        out_token_num_pad = expert_num * capacity
        restore_shape = [token_num, hidden_size]
        drop_and_pad = True

        def generate_bool_matrix(m, n, k):
            # (m, n) mask with exactly k set entries per row, returned as int8.
            matrix = torch.zeros((m, n), dtype=torch.bool)
            for i in range(m):
                indices = torch.randperm(n)[:k]
                matrix[i, indices] = True
            return matrix.to(torch.int8)

        with FakeTensorMode():
            for need_probs in [True, False]:
                with self.subTest(need_probs=need_probs):
                    routing_map = generate_bool_matrix(token_num, expert_num, top_k)
                    routing_map_npu = routing_map.npu()
                    permuted_tokens_npu = torch.randn([out_token_num_pad, hidden_size]).npu()
                    permuted_tokens_npu.requires_grad_(True)
                    # For each expert (row of the transposed map), stable-sort token positions so
                    # set entries come first, keep the first `capacity`, and flatten to int32.
                    sorted_indices = (
                        routing_map.T.contiguous()
                        .argsort(dim=-1, descending=True, stable=True)[:, :capacity]
                        .contiguous()
                        .to(torch.int32)
                        .view(-1)
                    )
                    sorted_indices_npu = sorted_indices.npu()
                    probs_npu = torch.randn([token_num, expert_num]).npu()
                    probs_npu.requires_grad_(True)
                    probs_arg = probs_npu if need_probs else None

                    unpermuted_tokens = torch_npu.npu_moe_token_unpermute_with_routing_map(
                        permuted_tokens_npu, sorted_indices_npu, restore_shape, probs=probs_arg,
                        routing_map=routing_map_npu, drop_and_pad=drop_and_pad)
                    permuted_tokens_, _out_index, permuted_token_id, permute_probs = \
                        torch_npu._npu_moe_token_unpermute_with_routing_map(
                            permuted_tokens_npu, sorted_indices_npu, restore_shape, probs=probs_arg,
                            routing_map=routing_map_npu, drop_and_pad=drop_and_pad)

                    self.assertEqual(unpermuted_tokens.dtype, permuted_tokens_npu.dtype)
                    self.assertEqual(unpermuted_tokens.shape[0], restore_shape[0])
                    self.assertEqual(unpermuted_tokens.shape[1], restore_shape[1])
                    self.assertEqual(permuted_tokens_.dtype, permuted_tokens_npu.dtype)
                    self.assertEqual(permuted_tokens_.shape[0], restore_shape[0])
                    self.assertEqual(permuted_tokens_.shape[1], restore_shape[1])
                    self.assertEqual(permuted_token_id.dtype, sorted_indices_npu.dtype)
                    self.assertEqual(permuted_token_id.shape, sorted_indices_npu.shape)
                    if need_probs:
                        self.assertEqual(permute_probs.dtype, probs_npu.dtype)
                        self.assertEqual(permute_probs.shape, sorted_indices_npu.shape)
class TestNpuAttentionWorkerCombine(TestCase):
    def test_npu_attention_worker_combine(self):
        """Check fake-mode output shapes of npu_attention_worker_combine.

        ``y`` must be (batch_size, hidden) and the next-layer id a 1-element tensor.
        """
        batch_size = 16
        hidden = 7168
        with FakeTensorMode():
            schedule_context = torch.randn((1024,)).npu().to(torch.int8)
            expert_scales = torch.rand(batch_size, 8, dtype=torch.float32).npu()
            layer_id = torch.randint(1, 20, (1,), dtype=torch.int32).npu()
            y_npu, next_layer_id_npu = torch_npu.npu_attention_worker_combine(
                schedule_context, expert_scales, layer_id, hidden, token_dtype=0, need_schedule=0)
            # assertEqual reports both values on failure, unlike assertTrue(a == b).
            self.assertEqual(y_npu.shape[0], batch_size)
            self.assertEqual(y_npu.shape[1], hidden)
            self.assertEqual(next_layer_id_npu.shape[0], 1)
class TestNpuMoeTokenPermuteWithRoutingMap(TestCase):
    def test_npu_moe_token_permute_with_routing_map_meta(self):
        """Check fake-mode forward and backward metadata of npu_moe_token_permute_with_routing_map.

        Forward outputs are (permuted tokens, permuted probs, int32 indices), all with
        ``num_out_tokens`` rounded down to a multiple of the expert count as leading dim.
        After backward, the input gradients must match their inputs in dtype and shape.
        """
        num_out_tokens = 6
        drop_and_pad = True
        with FakeTensorMode():
            x_npu = torch.randn((3, 4), dtype=torch.float).npu().requires_grad_()
            routing_map_npu = torch.tensor(
                [[True, True], [True, True], [True, True]], dtype=torch.bool).npu()
            probs_npu = torch.randn([3, 2], dtype=torch.float).npu().requires_grad_()
            num_experts = routing_map_npu.size(1)
            out_token = num_out_tokens // num_experts * num_experts

            x1, x2, x3 = torch_npu.npu_moe_token_permute_with_routing_map(
                x_npu, routing_map_npu, probs=probs_npu, num_out_tokens=num_out_tokens,
                drop_and_pad=drop_and_pad)
            self.assertEqual(x1.dtype, x_npu.dtype)
            self.assertEqual(x2.dtype, probs_npu.dtype)
            self.assertEqual(x3.dtype, torch.int32)
            self.assertEqual(x1.shape, torch.Size([out_token, x_npu.size(1)]))
            self.assertEqual(x2.shape, torch.Size([out_token]))
            self.assertEqual(x3.shape, torch.Size([out_token]))

            (x1.sum() + x2.sum()).backward()
            # Compare each gradient against the NPU leaf it belongs to.
            self.assertEqual(x_npu.grad.dtype, x_npu.dtype)
            self.assertEqual(probs_npu.grad.dtype, probs_npu.dtype)
            self.assertEqual(x_npu.grad.shape, x_npu.shape)
            self.assertEqual(probs_npu.grad.shape, probs_npu.shape)
# Materialize the generated test variants for the generic test classes.
# FakeTensorTest (defined near the top of this file) is expanded for its
# @parametrize-decorated methods; FakeTensorOpInfoTest (defined elsewhere in this
# file — not visible here) gets device-specific copies, restricted to CPU.
# NOTE(review): instantiate_device_type_tests is given globals() as the scope in
# which the generated classes are registered.
instantiate_parametrized_tests(FakeTensorTest)
instantiate_device_type_tests(FakeTensorOpInfoTest, globals(), only_for="cpu")
class TestGroupedMatmulSwigluQuant(TestCase):
    def _assert_npu_grouped_matmul_swiglu_quant_shape(self, x, weight, groupList, weightScale, xScale, output_shape):
        """Run npu_grouped_matmul_swiglu_quant on NPU copies of the inputs and check outputs.

        Expects output0 of ``output_shape`` in int8, output1 of shape ``[x.size(0)]``
        in float32, and a 0-dim float32 output2.
        """
        output0_npu, output1_npu, output2_npu = torch_npu.npu_grouped_matmul_swiglu_quant(
            x.npu(), weight.npu(), groupList.npu(), weightScale.npu(), xScale.npu(), bias=None, offset=None)
        self.assertEqual(output0_npu.shape, output_shape)
        self.assertEqual(output0_npu.dtype, torch.int8)
        self.assertEqual(output1_npu.shape, torch.Size([x.size(0)]))
        self.assertEqual(output1_npu.dtype, torch.float32)
        self.assertEqual(output2_npu.shape, torch.Size([]))
        self.assertEqual(output2_npu.dtype, torch.float32)

    def test_npu_grouped_matmul_swiglu_quant(self):
        """ND int8 weight of shape (E, K, N); output last dim is N // 2."""
        with FakeTensorMode():
            E = 16
            M = 512
            K = 7168
            N = 4096
            x = torch.randint(-128, 127, (M, K), dtype=torch.int8)
            weight = torch.randint(-128, 127, (E, K, N), dtype=torch.int8)
            weightScale = torch.randn(E, N)
            xScale = torch.randn(M)
            groupList = torch.randn(E)
            self._assert_npu_grouped_matmul_swiglu_quant_shape(
                x, weight, groupList, weightScale, xScale, torch.Size([M, N // 2]))

    def test_npu_grouped_matmul_swiglu_quant_nz_shape(self):
        """5-D int8 weight (E, N // 32, K // 16, 16, 32); output last dim is N // 2."""
        with FakeTensorMode():
            E = 2
            M = 8
            K = 64
            N = 128
            x = torch.randint(-128, 127, (M, K), dtype=torch.int8)
            weight = torch.randint(-128, 127, (E, N // 32, K // 16, 16, 32), dtype=torch.int8)
            weightScale = torch.randn(E, N)
            xScale = torch.randn(M)
            groupList = torch.zeros(E, dtype=torch.int64)
            self._assert_npu_grouped_matmul_swiglu_quant_shape(
                x, weight, groupList, weightScale, xScale, torch.Size([M, N // 2]))

    def test_npu_grouped_matmul_swiglu_quant_nz_int32_pack_shape(self):
        """5-D int32-packed weight (E, N // 64, K // 16, 16, 8); output last dim is N // 2."""
        with FakeTensorMode():
            E = 2
            M = 8
            K = 64
            N = 128
            x = torch.randint(-128, 127, (M, K), dtype=torch.int8)
            weight = torch.randint(-128, 127, (E, N // 64, K // 16, 16, 8), dtype=torch.int32)
            weightScale = torch.randn(E, N)
            xScale = torch.randn(M)
            groupList = torch.zeros(E, dtype=torch.int64)
            self._assert_npu_grouped_matmul_swiglu_quant_shape(
                x, weight, groupList, weightScale, xScale, torch.Size([M, N // 2]))

    def test_npu_grouped_matmul_swiglu_quant_nz_per_group_shape(self):
        """Same int32-packed weight with a 3-D (E, k_group_num, N) weight scale."""
        with FakeTensorMode():
            E = 2
            M = 8
            K = 64
            N = 128
            k_group_num = 2
            x = torch.randint(-128, 127, (M, K), dtype=torch.int8)
            weight = torch.randint(-128, 127, (E, N // 64, K // 16, 16, 8), dtype=torch.int32)
            weightScale = torch.randn(E, k_group_num, N)
            xScale = torch.randn(M)
            groupList = torch.zeros(E, dtype=torch.int64)
            self._assert_npu_grouped_matmul_swiglu_quant_shape(
                x, weight, groupList, weightScale, xScale, torch.Size([M, N // 2]))
class TestGroupedMatmulSwigluQuantV2(TestCase):
    """Fake-tensor output-shape checks for torch_npu.npu_grouped_matmul_swiglu_quant_v2."""

    def _assert_npu_grouped_matmul_swiglu_quant_v2_shape(
            self, x, weight, weight_scale, x_scale, group_list, output_shape, output_scale_shape, **kwargs):
        """Run the V2 op on NPU copies of the inputs and check the two output shapes.

        ``weight`` and ``weight_scale`` are each wrapped in a single-element list before
        the call; ``kwargs`` are forwarded unchanged. Only shapes are checked, not dtypes.
        """
        output, output_scale = torch_npu.npu_grouped_matmul_swiglu_quant_v2(
            x.npu(), [weight.npu()], [weight_scale.npu()], x_scale.npu(), group_list.npu(), **kwargs)
        self.assertTrue(output.shape == output_shape)
        self.assertTrue(output_scale.shape == output_scale_shape)

    @unittest.skip("npu_format_cast has no fake impl; V2 meta shape regressions are covered by the cases below.")
    def test_npu_grouped_matmul_swiglu_quant_v2(self):
        with FakeTensorMode():
            E = 16
            M = 512
            K = 7168
            N = 4096
            x = torch.randint(-128, 127, (M, K), dtype=torch.int8)
            weight = torch.randint(-128, 127, (E, K, N), dtype=torch.int8)
            # NOTE(review): 29 is presumably the NZ format id for npu_format_cast — confirm.
            weight_npu = torch_npu.npu_format_cast(weight.npu(), 29)
            weight_npu = [weight_npu]
            weightScale = torch.randn(E, N)
            xScale = torch.randn(M)
            groupList = torch.randn(E)
            # Expected metadata only: int8 output with last dim N // 2, float32 per-row scale.
            output0 = torch.empty([M, N // 2], dtype=torch.int8, device=x.device)
            output1 = torch.empty([M], dtype=torch.float32, device=x.device)
            output0_npu, output1_npu = torch_npu.npu_grouped_matmul_swiglu_quant_v2(x.npu(), weight_npu, [weightScale.npu()], xScale.npu(), groupList.npu(), smooth_scale=None,
                weight_assist_matrix=None, bias=None, dequant_mode=0, dequant_dtype=6, quant_mode=0,
                quant_dtype=1, group_list_type=0, tuning_config=None)
            self.assertTrue(output0_npu.shape == output0.shape)
            self.assertTrue(output0_npu.dtype == output0.dtype)
            self.assertTrue(output1_npu.shape == output1.shape)
            self.assertTrue(output1_npu.dtype == output1.dtype)

    def test_npu_grouped_matmul_swiglu_quant_v2_nz_int32_pack_shape(self):
        # int32-packed weight (E, K, N // 8) with int64 weight scale and a weight assist matrix.
        with FakeTensorMode():
            E = 2
            M = 8
            K = 64
            N = 512
            x = torch.randint(-128, 127, (M, K), dtype=torch.int8)
            weight = torch.randint(-128, 127, (E, K, N // 8), dtype=torch.int32)
            weight_scale = torch.randint(0, 10, (E, N), dtype=torch.int64)
            weight_assist = torch.randn(E, N)
            x_scale = torch.randn(M)
            group_list = torch.zeros(E, dtype=torch.int64)
            self._assert_npu_grouped_matmul_swiglu_quant_v2_shape(
                x, weight, weight_scale, x_scale, group_list,
                torch.Size([M, N // 2]), torch.Size([M]),
                weight_assist_matrix=[weight_assist.npu()], dequant_mode=0,
                dequant_dtype=6, quant_dtype=1)

    def test_npu_grouped_matmul_swiglu_quant_v2_nz_mx_shape(self):
        # float8_e4m3fn inputs with uint8 scales reinterpreted as float8_e8m0fnu
        # (dequant_mode=2 / quant_mode=2); the output scale has a trailing dim of 2.
        with FakeTensorMode():
            E = 2
            M = 8
            K = 128
            N = 512
            x = torch.empty((M, K), dtype=torch.float8_e4m3fn)
            weight = torch.empty((E, K, N // 8), dtype=torch.float8_e4m3fn)
            weight_scale = torch.empty((E, math.ceil(K / 64), N, 2), dtype=torch.uint8)
            x_scale = torch.empty((M, math.ceil(K / 64), 2), dtype=torch.uint8)
            group_list = torch.zeros(E, dtype=torch.int64)
            self._assert_npu_grouped_matmul_swiglu_quant_v2_shape(
                x, weight, weight_scale, x_scale, group_list,
                torch.Size([M, N // 2]), torch.Size([M, math.ceil((N // 2) / 64), 2]),
                dequant_mode=2, dequant_dtype=6, quant_mode=2,
                quant_dtype=24,
                weight_scale_dtype=torch_npu.float8_e8m0fnu,
                x_scale_dtype=torch_npu.float8_e8m0fnu)

    def _assert_npu_grouped_matmul_swiglu_quant_v2_ascend950_mxfp4_shape(
            self, case_name, e, m, k, n, output_tensor_dtype, group_list_type,
            expected_output_shape, expected_output_scale_shape):
        """Check one Ascend950 MXFP4 case (fp4 x/weight given as uint8) inside a subTest.

        ``output_tensor_dtype`` selects the ``quant_dtype`` argument from the table below.
        """
        # fp8 outputs are selected by integer codes; fp4 passes the torch_npu dtype wrapper.
        output_dtype_to_quant_dtype = {
            "float8_e5m2": 23,
            "float8_e4m3fn": 24,
            "float4_e2m1": torch_npu.float4_e2m1fn_x2,
        }
        # NOTE(review): without a native torch.float4_e2m1fn_x2, the expected last output
        # dim is halved — presumably because two fp4 values are then packed per element; confirm.
        if output_tensor_dtype == "float4_e2m1" and not hasattr(torch, "float4_e2m1fn_x2"):
            expected_output_shape = (expected_output_shape[0], expected_output_shape[1] // 2)
        # x and weight are uint8 storage; their logical fp4 dtype is declared via x_dtype/weight_dtype.
        x = torch.empty((m, k), dtype=torch.uint8)
        weight = torch.empty((e, k, n), dtype=torch.uint8)
        weight_scale = torch.empty((e, math.ceil(k / 64), n, 2), dtype=torch.uint8)
        x_scale = torch.empty((m, math.ceil(k / 64), 2), dtype=torch.uint8)
        group_list = torch.zeros(e, dtype=torch.int64)
        with self.subTest(case=case_name):
            self._assert_npu_grouped_matmul_swiglu_quant_v2_shape(
                x, weight, weight_scale, x_scale, group_list,
                torch.Size(expected_output_shape), torch.Size(expected_output_scale_shape),
                dequant_mode=2, dequant_dtype=6, quant_mode=2,
                quant_dtype=output_dtype_to_quant_dtype[output_tensor_dtype],
                group_list_type=group_list_type, tuning_config=[0],
                weight_scale_dtype=torch_npu.float8_e8m0fnu,
                x_scale_dtype=torch_npu.float8_e8m0fnu,
                x_dtype=torch_npu.float4_e2m1fn_x2,
                weight_dtype=torch_npu.float4_e2m1fn_x2)

    @unittest.skipIf(
        not hasattr(torch_npu, "float4_e2m1fn_x2"),
        "torch_npu float4 dtype wrappers are required for Ascend950 MXFP4 fake shape cases.")
    def test_npu_grouped_matmul_swiglu_quant_v2_ascend950_mxfp4_shape(self):
        with FakeTensorMode():
            # Each case: (case_name, e, m, k, n, output_tensor_dtype, group_list_type,
            #             expected_output_shape, expected_output_scale_shape).
            cases = [
                ("fp8_e4m3fn_unaligned_shape", 2, 1775, 980, 4352, "float8_e4m3fn", 0, (1775, 2176), (1775, 34, 2)),
                ("fp8_e5m2_large_n", 4, 2048, 1024, 7168, "float8_e5m2", 0, (2048, 3584), (2048, 56, 2)),
                ("fp4_e2m1_grouped_shape", 4, 534, 130, 3200, "float4_e2m1", 1, (534, 1600), (534, 25, 2)),
                ("fp4_e2m1_large_n", 16, 2048, 128, 7168, "float4_e2m1", 1, (2048, 3584), (2048, 56, 2)),
                ("fp4_e2m1_large_shape", 4, 7168, 4096, 4096, "float4_e2m1", 1, (7168, 2048), (7168, 32, 2)),
            ]
            for case in cases:
                self._assert_npu_grouped_matmul_swiglu_quant_v2_ascend950_mxfp4_shape(*case)
@unittest.skip("skip until CANN is updated to support aclnnDynamicBlockQuant")
class TestNpuDynamicBlockQuant(TestCase):
    def test_npu_dynamic_block_quant_meta(self):
        """Compare real-NPU and fake-mode output shapes of npu_dynamic_block_quant.

        For each input shape the op runs once on real NPU data and once under
        FakeTensorMode; the quantized output and scale shapes must agree.
        """
        for in_shape in ((3, 4), (3, 4, 5)):
            real_in = torch.rand(*in_shape).to("npu").to(torch.float16)
            real_y, real_scale = torch_npu.npu_dynamic_block_quant(real_in, dst_type=1)
            with FakeTensorMode():
                fake_in = torch.rand(*in_shape).to("npu").to(torch.float16)
                fake_y, fake_scale = torch_npu.npu_dynamic_block_quant(fake_in, dst_type=1)
            self.assertEqual(real_y.shape, fake_y.shape)
            self.assertEqual(real_scale.shape, fake_scale.shape)
class TestQuantGroupedMatmulInplaceAdd(TestCase):
    def _assert_add_quant_gmm_result(self, res, y, x1, x2):
        """Check a quant grouped-matmul-add result against its inputs.

        The result must be 3-D float32, with dim 0 == ``x2.shape[0]``,
        dim 1 == ``x1.shape[1]``, dim 2 == ``x2.shape[2]``, and the same shape as ``y``.
        """
        self.assertEqual(res.dim(), 3)
        self.assertEqual(res.shape[1], x1.shape[1])
        self.assertEqual(res.shape[0], x2.shape[0])
        self.assertEqual(res.shape[2], x2.shape[2])
        self.assertEqual(res.dtype, torch.float32)
        self.assertEqual(res.shape, y.shape)

    def test_npu_quant_grouped_matmul_inplace_add_meta_0(self):
        """Check fake-mode metadata of npu_add_quant_gmm_ (in-place) and npu_add_quant_gmm."""
        with FakeTensorMode():
            torch.manual_seed(0)
            y = torch.randint(-1, 1, (4, 576, 7168), dtype=torch.float32).npu()
            x1 = torch.randint(-1, 1, (512, 576), dtype=torch.int8).npu()
            x2 = torch.randint(-1, 1, (4, 512, 7168), dtype=torch.int8).npu()
            x2_scale = torch.randint(-1, 1, (4, 7168), dtype=torch.float32).npu()
            x1_scale = torch.randint(-1, 1, (4,), dtype=torch.float32).npu()
            group_list = torch.tensor([8, 181, 415, 512], dtype=torch.int64).npu()
            # Both variants take identical arguments; only the in-place/out-of-place form differs.
            common_kwargs = dict(x1_scale=x1_scale, group_list_type=0, group_sizes=None,
                                 x1_dtype=torch_npu.hifloat8, x2_dtype=torch_npu.hifloat8)
            res_1 = torch_npu.npu_add_quant_gmm_(y, x1.t(), x2, x2_scale, group_list, **common_kwargs)
            self._assert_add_quant_gmm_result(res_1, y, x1, x2)
            res_2 = torch_npu.npu_add_quant_gmm(y, x1.t(), x2, x2_scale, group_list, **common_kwargs)
            self._assert_add_quant_gmm_result(res_2, y, x1, x2)
@unittest.skip("skip until CANN is updated to support aclnnMlaPrologV3WeightNz")
class TestMlaProLogV3(TestCase):
    def testMlaProLogV3(self):
        """Check fake-mode output shapes of npu_mla_prolog_v3.

        ``query`` is expected as (token_x.size(0), token_x.size(1), w_uk.size(0), w_uk.size(2))
        and ``query_rope`` as (token_x.size(0), token_x.size(1), w_uk.size(0), rope_sin.size(2)).
        The dequant-scale and query-norm outputs are expected to be empty (shape [0])
        for this call.
        """
        with FakeTensorMode():
            block_num = math.ceil(2 * 6144 / 128)
            token_x = torch.randint(-100, 100, (2, 1, 7168), dtype=torch.int8).npu()
            w_dq = torch.randint(-100, 100, (7168, 1536), dtype=torch.int8).npu()
            w_uq_qr = torch.randint(-100, 100, (1536, 32 * (128 + 64)), dtype=torch.int8).npu()
            w_uk = torch.rand(32, 128, 512, dtype=torch.bfloat16).npu()
            w_dkv_kr = torch.randint(-100, 100, (7168, 512 + 64), dtype=torch.int8).npu()
            rmsnorm_gamma_cq = torch.rand(1536, dtype=torch.bfloat16).npu()
            rmsnorm_gamma_ckv = torch.rand(512, dtype=torch.bfloat16).npu()
            rope_sin = torch.rand(2, 1, 64, dtype=torch.bfloat16).npu()
            rope_cos = torch.rand(2, 1, 64, dtype=torch.bfloat16).npu()
            kv_cache = torch.randint(-100, 100, (1, block_num * 128 * 1 * 512), dtype=torch.int8).npu()
            kr_cache = torch.rand(1, block_num * 128 * 1 * 64, dtype=torch.bfloat16).npu()

            batch, seq = token_x.size(0), token_x.size(1)
            expected_query_shape = torch.Size([batch, seq, w_uk.size(0), w_uk.size(2)])
            expected_query_rope_shape = torch.Size([batch, seq, w_uk.size(0), rope_sin.size(2)])
            empty_shape = torch.Size([0])

            (query_mla, query_rope_mla, dequant_scale_q_nope_mla,
             query_norm_mla, dequant_scale_q_norm_mla) = torch_npu.npu_mla_prolog_v3(
                token_x, w_dq, w_uq_qr, w_uk, w_dkv_kr, rmsnorm_gamma_cq,
                rmsnorm_gamma_ckv, rope_sin, rope_cos, kv_cache, kr_cache)
            self.assertEqual(query_mla.shape, expected_query_shape)
            self.assertEqual(query_rope_mla.shape, expected_query_rope_shape)
            self.assertEqual(dequant_scale_q_nope_mla.shape, empty_shape)
            self.assertEqual(query_norm_mla.shape, empty_shape)
            self.assertEqual(dequant_scale_q_norm_mla.shape, empty_shape)
@unittest.skip("skip until CANN is updated to support aclnnMlaPrologV3WeightNz")
class TestMlaProLogV3Functional(TestCase):
    def testMlaProLogV3Functional(self):
        """Check fake-mode output shapes of npu_mla_prolog_v3_functional.

        Besides the query outputs, the functional variant returns the KV and rope
        caches, which must keep the shapes of the caches passed in.
        """
        with FakeTensorMode():
            block_num = math.ceil(2 * 6144 / 128)
            token_x = torch.randint(-100, 100, (2, 1, 7168), dtype=torch.int8).npu()
            w_dq = torch.randint(-100, 100, (7168, 1536), dtype=torch.int8).npu()
            w_uq_qr = torch.randint(-100, 100, (1536, 32 * (128 + 64)), dtype=torch.int8).npu()
            w_uk = torch.rand(32, 128, 512, dtype=torch.bfloat16).npu()
            w_dkv_kr = torch.randint(-100, 100, (7168, 512 + 64), dtype=torch.int8).npu()
            rmsnorm_gamma_cq = torch.rand(1536, dtype=torch.bfloat16).npu()
            rmsnorm_gamma_ckv = torch.rand(512, dtype=torch.bfloat16).npu()
            rope_sin = torch.rand(2, 1, 64, dtype=torch.bfloat16).npu()
            rope_cos = torch.rand(2, 1, 64, dtype=torch.bfloat16).npu()
            kv_cache = torch.randint(-100, 100, (1, block_num * 128 * 1 * 512), dtype=torch.int8).npu()
            kr_cache = torch.rand(1, block_num * 128 * 1 * 64, dtype=torch.bfloat16).npu()

            batch, seq = token_x.size(0), token_x.size(1)
            expected_query_shape = torch.Size([batch, seq, w_uk.size(0), w_uk.size(2)])
            expected_query_rope_shape = torch.Size([batch, seq, w_uk.size(0), rope_sin.size(2)])

            (query_mla, query_rope_mla, dequant_scale_q_nope_mla, _, _,
             kv_cache_mla, kr_cache_mla) = torch_npu.npu_mla_prolog_v3_functional(
                token_x, w_dq, w_uq_qr, w_uk, w_dkv_kr, rmsnorm_gamma_cq,
                rmsnorm_gamma_ckv, rope_sin, rope_cos, kv_cache, kr_cache)
            self.assertEqual(query_mla.shape, expected_query_shape)
            self.assertEqual(query_rope_mla.shape, expected_query_rope_shape)
            # Expected empty ([0]) for this call, as in the non-functional variant's checks.
            self.assertEqual(dequant_scale_q_nope_mla.shape, torch.Size([0]))
            self.assertEqual(kv_cache_mla.shape, kv_cache.shape)
            self.assertEqual(kr_cache_mla.shape, kr_cache.shape)
class TestGatherPAKVCacheFunctional(TestCase):
    def test_npu_gather_pa_kv_cache_functional_meta(self):
        """Check that npu_gather_pa_kv_cache_functional returns outputs shaped like key_out/value_out."""
        batch, n_blocks, heads, blk, dim = 2, 8, 4, 64, 64
        max_blocks_per_seq = 5
        total = sum([5, 8])
        with FakeTensorMode():
            k_cache = torch.empty((n_blocks, blk, heads, dim), dtype=torch.float16)
            v_cache = torch.empty_like(k_cache)
            tables = torch.empty((batch, max_blocks_per_seq), dtype=torch.int32)
            lens = torch.empty((batch,), dtype=torch.int32)
            k_out = torch.empty((total, heads, dim), dtype=torch.float16)
            v_out = torch.empty_like(k_out)
            offsets = torch.empty((batch,), dtype=torch.int32)
            k_res, v_res = torch_npu.npu_gather_pa_kv_cache_functional(
                k_cache, v_cache, tables, lens, k_out, v_out,
                seq_offset=offsets, is_seq_lens_cumsum=False,
            )
            self.assertEqual(k_res.shape, k_out.shape)
            self.assertEqual(v_res.shape, v_out.shape)
class TestGroupedMatmulAdd(TestCase):
    def test_npu_grouped_matmul_add_meta_0(self):
        """Check fake-mode output metadata of the in-place npu_grouped_matmul_add_.

        The result must be 3-D float32 and shaped like ``y``, with dim 1 equal to
        ``x.shape[1]`` and dim 2 equal to ``weight.shape[1]``.
        """
        with FakeTensorMode():
            torch.manual_seed(0)
            y = torch.randint(-1, 1, (4, 576, 7168), dtype=torch.float32).npu()
            x = torch.randint(-1, 1, (512, 576), dtype=torch.float16).npu()
            weight = torch.randint(-1, 1, (512, 7168), dtype=torch.float16).npu()
            group_list = torch.tensor([8, 181, 415, 512], dtype=torch.int64).npu()
            res_1 = torch_npu.npu_grouped_matmul_add_(y, x, weight, group_list, transpose_x=True,
                                                      transpose_weight=False, group_type=2,
                                                      group_list_type=0)
            self.assertEqual(res_1.dim(), 3)
            self.assertEqual(res_1.shape[1], x.shape[1])
            self.assertEqual(res_1.shape[2], weight.shape[1])
            self.assertEqual(res_1.dtype, torch.float32)
            self.assertEqual(res_1.shape, y.shape)
class TestQuantAllReduce(TestCase):
    def test_npu_quant_all_reduce(self):
        """Check that npu_quant_all_reduce returns a result with the same shape as ``x``."""
        with FakeTensorMode():
            torch.manual_seed(0)
            x = torch.randint(-5, 5, (128, 5120), dtype=torch.int8).npu()
            scales = torch.randint(-5, 5, (128, 40), dtype=torch.float32).npu()
            hcom = "fake group info"
            world_size = 2
            res = torch_npu.npu_quant_all_reduce(x, scales, hcom, world_size)
            # Same rank and same size in every dimension (reported with values on failure).
            self.assertEqual(res.shape, x.shape)
class TestQuantReduceScatter(TestCase):
    def test_npu_quant_reduce_scatter(self):
        """Check npu_quant_reduce_scatter output shape: dim 0 is split across ``world_size`` ranks."""
        with FakeTensorMode():
            torch.manual_seed(0)
            x = torch.randint(-5, 5, (128, 5120), dtype=torch.int8).npu()
            scales = torch.randint(-5, 5, (128, 40), dtype=torch.float32).npu()
            hcom = "fake group info"
            world_size = 2
            res = torch_npu.npu_quant_reduce_scatter(x, scales, hcom, world_size)
            self.assertEqual(res.shape, torch.Size([x.shape[0] // world_size, x.shape[1]]))
class TestMatmulAlltoAll(TestCase):
    def test_npu_matmul_all_to_all(self):
        """Check fake-mode metadata of npu_matmul_all_to_all.

        Expects a 2-D result in ``x1``'s dtype with ``x1.shape[0] * world_size`` rows
        and ``x2.shape[1] // world_size`` columns.
        """
        with FakeTensorMode():
            torch.manual_seed(0)
            x1 = torch.randint(-1, 2, (16, 32), dtype=torch.float16).npu()
            x2 = torch.randint(-1, 2, (32, 32), dtype=torch.float16).npu()
            bias = torch.randint(-1, 2, (32,), dtype=torch.float16).npu()
            hcom = "fake group info"
            world_size = 2
            res = torch_npu.npu_matmul_all_to_all(x1, x2, hcom, world_size, bias=bias, all2all_axes=[-1, -2])
            self.assertEqual(res.dim(), 2)
            self.assertEqual(res.shape[0], x1.shape[0] * world_size)
            # Integer division: a shape dimension is an int, not a float.
            self.assertEqual(res.shape[1], x2.shape[1] // world_size)
            self.assertEqual(res.dtype, x1.dtype)
class TestQuantMatmulAlltoAll(TestCase):
    def test_npu_quant_matmul_all_to_all(self):
        """Check fake-mode metadata of npu_quant_matmul_all_to_all.

        Expects a 2-D float32 result with ``x1.shape[0] * world_size`` rows and
        ``x2.shape[1] // world_size`` columns.
        """
        with FakeTensorMode():
            torch.manual_seed(0)
            x1 = torch.randint(-1, 2, (16, 32), dtype=torch.float8_e4m3fn).npu()
            x2 = torch.randint(-1, 2, (32, 32), dtype=torch.float8_e4m3fn).npu()
            bias = torch.randint(-1, 2, (32,), dtype=torch.float32).npu()
            x1_scale = torch.randint(-1, 2, (32,), dtype=torch.float32).npu()
            x2_scale = torch.randint(-1, 2, (32,), dtype=torch.float32).npu()
            hcom = "fake group info"
            world_size = 2
            common_scale = None
            x1_offset = None
            x2_offset = None
            x1_quant_mode = 3
            x2_quant_mode = 2
            common_quant_mode = 0
            group_size = [0]
            all2all_axes = [-1, -2]
            comm_quant_dtype = -1
            x1_dtype = None
            x2_dtype = None
            x1_scale_dtype = None
            x2_scale_dtype = None
            output_scale_dtype = None
            comm_scale_dtype = None
            y_dtype = None
            # Arguments are passed positionally, in the op's declared order.
            res = torch_npu.npu_quant_matmul_all_to_all(
                x1, x2, hcom, world_size, bias, x1_scale, x2_scale, common_scale, x1_offset, x2_offset,
                x1_quant_mode, x2_quant_mode, common_quant_mode, group_size, all2all_axes, comm_quant_dtype,
                x1_dtype, x2_dtype, x1_scale_dtype, x2_scale_dtype, output_scale_dtype, comm_scale_dtype,
                y_dtype)
            self.assertEqual(res.dim(), 2)
            self.assertEqual(res.shape[0], x1.shape[0] * world_size)
            # Integer division: a shape dimension is an int, not a float.
            self.assertEqual(res.shape[1], x2.shape[1] // world_size)
            self.assertEqual(res.dtype, torch.float32)
class TestAlltoAllMatmul(TestCase):
    def test_npu_all_to_all_matmul(self):
        """Check fake-mode metadata of npu_all_to_all_matmul.

        Expects a 2-D result in ``x1``'s dtype with ``x1.shape[0] // world_size`` rows
        and ``x2.shape[1]`` columns.
        """
        with FakeTensorMode():
            torch.manual_seed(0)
            x1 = torch.randint(-1, 2, (16, 16), dtype=torch.float16).npu()
            x2 = torch.randint(-1, 2, (32, 32), dtype=torch.float16).npu()
            bias = torch.randint(-1, 2, (32,), dtype=torch.float16).npu()
            hcom = "fake group info"
            world_size = 2
            res, _ = torch_npu.npu_all_to_all_matmul(
                x1, x2, hcom, world_size, bias=bias, all2all_axes=[-2, -1], all2all_out_flag=True)
            self.assertEqual(res.dim(), 2)
            # Integer division: a shape dimension is an int, not a float.
            self.assertEqual(res.shape[0], x1.shape[0] // world_size)
            self.assertEqual(res.shape[1], x2.shape[1])
            self.assertEqual(res.dtype, x1.dtype)
class TestAlltoAllQuantMatmul(TestCase):
    def test_npu_all_to_all_quant_matmul(self):
        """Check fake-mode output shape of npu_all_to_all_quant_matmul.

        Expects a 2-D result with ``x1.shape[0] // world_size`` rows and
        ``x2.shape[1]`` columns.
        """
        with FakeTensorMode():
            torch.manual_seed(0)
            x1 = torch.randint(-1, 2, (16, 16), dtype=torch.float16).npu()
            x2 = torch.randint(-1, 2, (32, 32), dtype=torch.float8_e4m3fn).npu()
            bias = torch.randint(-1, 2, (32,), dtype=torch.float32).npu()
            x2_scale = torch.randint(-1, 2, (32,), dtype=torch.float32).npu()
            hcom = "fake group info"
            world_size = 2
            res, _ = torch_npu.npu_all_to_all_quant_matmul(
                x1, x2, hcom, world_size, bias=bias, all2all_axes=[-2, -1],
                all2all_out_flag=True, x2_scale=x2_scale)
            self.assertEqual(res.dim(), 2)
            # Integer division: a shape dimension is an int, not a float.
            self.assertEqual(res.shape[0], x1.shape[0] // world_size)
            self.assertEqual(res.shape[1], x2.shape[1])
class TestNpuDSA(TestCase):
def setup_sparse_flash_attention_test_params(self, requires_grad=False, return_softmax_lse=False):
    """Build NPU inputs and options for npu_sparse_flash_attention.

    Args:
        requires_grad: if True, mark query, key, value, query_rope and key_rope
            as requiring grad.
        return_softmax_lse: forwarded unchanged as the ``return_softmax_lse`` option.

    Returns:
        dict of NPU tensors and scalar options consumed by
        ``call_npu_sparse_flash_attention``.
    """
    scale_value = 0.041666666666666664
    sparse_block_size = 1
    query_type = torch.float16
    sparse_block_count = 2048
    b = 4
    s1 = 1
    s2 = 8192
    n1 = 128
    n2 = 1
    dn = 512
    dr = 64
    s2_act = 4096
    attention_mode = 2
    layout = 'BSND'
    sparse_mode = 3
    query = torch.tensor(np.random.uniform(-10, 10, (b, s1, n1, dn))).to(query_type)
    key = torch.tensor(np.random.uniform(-5, 10, (b, s2, n2, dn))).to(query_type)
    value = key.clone()
    # One random set of sparse indices, repeated for every (batch, query position, kv head).
    idxs = random.sample(range(s2_act - s1 + 1), sparse_block_count)
    sparse_indices = torch.tensor([idxs for _ in range(b * s1 * n2)]).reshape(
        b, s1, n2, sparse_block_count).to(torch.int32)
    query_rope = torch.tensor(np.random.uniform(-10, 10, (b, s1, n1, dr))).to(query_type)
    key_rope = torch.tensor(np.random.uniform(-10, 10, (b, s2, n2, dr))).to(query_type)
    act_seq_q = torch.tensor([s1] * b).to(torch.int32)
    act_seq_kv = torch.tensor([s2_act] * b).to(torch.int32)
    query_npu = query.npu()
    key_npu = key.npu()
    value_npu = value.npu()
    sparse_indices_npu = sparse_indices.npu()
    query_rope_npu = query_rope.npu()
    key_rope_npu = key_rope.npu()
    act_seq_q_npu = act_seq_q.npu()
    act_seq_kv_npu = act_seq_kv.npu()
    if requires_grad:
        for tensor in (query_npu, key_npu, value_npu, query_rope_npu, key_rope_npu):
            tensor.requires_grad = True
    return {
        'query': query_npu,
        'key': key_npu,
        'value': value_npu,
        'sparse_indices': sparse_indices_npu,
        'query_rope': query_rope_npu,
        'key_rope': key_rope_npu,
        'act_seq_q': act_seq_q_npu,
        'act_seq_kv': act_seq_kv_npu,
        'scale_value': scale_value,
        'sparse_block_size': sparse_block_size,
        'layout': layout,
        'sparse_mode': sparse_mode,
        'attention_mode': attention_mode,
        'return_softmax_lse': return_softmax_lse,
    }
def call_npu_sparse_flash_attention(self, params):
    """Invoke torch_npu.npu_sparse_flash_attention with the tensors and options in ``params``.

    No block table is passed; the same layout is used for query and KV, and the
    pre/next token windows are both set to the maximum signed 64-bit value.
    """
    max_tokens = (1 << 63) - 1
    shared_layout = params['layout']
    return torch_npu.npu_sparse_flash_attention(
        params['query'],
        params['key'],
        params['value'],
        params['sparse_indices'],
        params['scale_value'],
        block_table=None,
        actual_seq_lengths_query=params['act_seq_q'],
        actual_seq_lengths_kv=params['act_seq_kv'],
        query_rope=params['query_rope'],
        key_rope=params['key_rope'],
        sparse_block_size=params['sparse_block_size'],
        layout_query=shared_layout,
        layout_kv=shared_layout,
        sparse_mode=params['sparse_mode'],
        pre_tokens=max_tokens,
        next_tokens=max_tokens,
        attention_mode=params['attention_mode'],
        return_softmax_lse=params['return_softmax_lse'],
    )
def test_dsa_npu_sparse_flash_attention(self):
    """Check fake-mode output metadata of npu_sparse_flash_attention.

    The attention output must match ``query`` in dtype and 4-D shape. The softmax
    max/sum outputs are checked (float32, (B, N2, S1, N1 // N2)) only when
    ``return_softmax_lse`` is set in the params.
    """
    with FakeTensorMode():
        params = self.setup_sparse_flash_attention_test_params()
        out, softmax_max, softmax_sum = self.call_npu_sparse_flash_attention(params)
        q, k = params['query'], params['key']
        batch, seq_q, heads_q, head_dim = q.size(0), q.size(1), q.size(2), q.size(3)
        heads_kv = k.size(2)
        expected_out = torch.empty([batch, seq_q, heads_q, head_dim], dtype=q.dtype)
        lse_shape = [batch, heads_kv, seq_q, heads_q // heads_kv]
        expected_max = torch.empty(lse_shape, dtype=torch.float32)
        expected_sum = torch.empty(lse_shape, dtype=torch.float32)
        self.assertEqual(out.dtype, expected_out.dtype)
        self.assertEqual(out.shape, expected_out.shape)
        if not params['return_softmax_lse']:
            return
        self.assertEqual(softmax_max.dtype, expected_max.dtype)
        self.assertEqual(softmax_max.shape, expected_max.shape)
        self.assertEqual(softmax_sum.dtype, expected_sum.dtype)
        self.assertEqual(softmax_sum.shape, expected_sum.shape)
@unittest.skip("skip sparse_flash_attention_grad now")
def test_dsa_npu_sparse_flash_attention_grad(self):
    """Check gradient metadata of npu_sparse_flash_attention under FakeTensorMode.

    After backpropagating the sum of all three outputs, each differentiable input's
    ``.grad`` must match that input in dtype and shape.
    """
    with FakeTensorMode():
        params = self.setup_sparse_flash_attention_test_params(requires_grad=True)
        out, softmax_max, softmax_sum = self.call_npu_sparse_flash_attention(params)
        (out.sum() + softmax_max.sum() + softmax_sum.sum()).backward()
        for name in ('query', 'key', 'value', 'query_rope', 'key_rope'):
            leaf = params[name]
            self.assertEqual(leaf.grad.dtype, leaf.dtype)
            self.assertEqual(leaf.grad.shape, leaf.shape)
def setup_lightning_indexer_test(self, requires_grad=False, return_value=False):
    """Build NPU inputs for the npu_lightning_indexer tests.

    ``key`` is allocated as ``b * (s2 // block_size)`` blocks of
    ``block_size`` tokens (``layout_key='PA_BSND'``). ``block_table`` holds
    the block ids ``0 .. b * s2 // block_size - 1`` in row-major order, one
    row per batch.

    Args:
        requires_grad: when True, query, key and weights require grad.
        return_value: stored under ``'return_value'`` and forwarded to the op
            by ``call_npu_lightning_indexer``.

    Returns:
        dict with the tensors and options read by ``call_npu_lightning_indexer``.
    """
    b = 1
    s1 = 1
    s2 = 8192
    n1 = 64
    n2 = 1
    d = 128
    block_size = 256
    layout_query = 'BSND'
    query = torch.tensor(np.random.uniform(-10, 10, (b, s1, n1, d))).to(torch.bfloat16)
    key = torch.tensor(np.random.uniform(-10, 10, (b * (s2 // block_size), block_size, n2, d))).to(torch.bfloat16)
    weights = torch.tensor(np.random.uniform(-1, 1, (b, s1, n1))).to(torch.bfloat16)
    # Every batch uses the full query / key length. np.full states that directly;
    # np.random.uniform(x, x, b) always returns x and only burns RNG state.
    actual_seq_lengths_query = torch.tensor(np.full(b, s1)).to(torch.int32)
    actual_seq_lengths_key = torch.tensor(np.full(b, s2)).to(torch.int32)
    block_table = torch.tensor([range(b * s2 // block_size)], dtype=torch.int32).reshape(b, -1)
    layout_key = 'PA_BSND'
    sparse_count = 2048
    sparse_mode = 3
    query = query.npu()
    key = key.npu()
    weights = weights.npu()
    actual_seq_lengths_query = actual_seq_lengths_query.npu()
    actual_seq_lengths_key = actual_seq_lengths_key.npu()
    block_table = block_table.npu()
    if requires_grad:
        for tensor in (query, key, weights):
            tensor.requires_grad = True
    return {
        'query': query,
        'key': key,
        'weights': weights,
        'actual_seq_lengths_query': actual_seq_lengths_query,
        'actual_seq_lengths_key': actual_seq_lengths_key,
        'block_table': block_table,
        'layout_query': layout_query,
        'layout_key': layout_key,
        'sparse_count': sparse_count,
        'sparse_mode': sparse_mode,
        'return_value': return_value
    }
def call_npu_lightning_indexer(self, inputs):
    """Run torch_npu.npu_lightning_indexer on a dict from setup_lightning_indexer_test.

    Query, key and weights are passed positionally; every other entry is
    forwarded as a keyword argument of the same name.
    """
    query, key, weights = inputs['query'], inputs['key'], inputs['weights']
    option_names = (
        'actual_seq_lengths_query', 'actual_seq_lengths_key', 'block_table',
        'layout_query', 'layout_key', 'sparse_count', 'sparse_mode', 'return_value',
    )
    options = {name: inputs[name] for name in option_names}
    return torch_npu.npu_lightning_indexer(query, key, weights, **options)
def test_dsa_npu_lightning_indexer(self):
    """Check the fake-mode output metadata of npu_lightning_indexer.

    The index output is int32 with shape
    ``[query.size(0), query.size(1), key.size(2), sparse_count]``. The value
    output has the same shape in the query dtype. It is checked only when
    ``return_value`` was requested.
    """
    with FakeTensorMode():
        params = self.setup_lightning_indexer_test()
        indices, values = self.call_npu_lightning_indexer(params)
        query, key = params['query'], params['key']
        out_shape = [query.size(0), query.size(1), key.size(2), params['sparse_count']]
        reference_indices = torch.empty(out_shape, dtype=torch.int32)
        reference_values = torch.empty(out_shape, dtype=query.dtype)
        self.assertEqual(indices.dtype, reference_indices.dtype)
        self.assertEqual(indices.shape, reference_indices.shape)
        if params['return_value']:
            self.assertEqual(values.dtype, reference_values.dtype)
            self.assertEqual(values.shape, reference_values.shape)
@unittest.skip("skip lightning_indexer_grad now")
def test_dsa_npu_lightning_indexer_grad(self):
    """Check fake-mode gradients flowing back through npu_lightning_indexer.

    The loss is built from the value output, so the op is asked to produce it
    (``return_value=True``). test_dsa_npu_lightning_indexer only treats that
    output as meaningful when ``return_value`` is set.
    """
    with FakeTensorMode():
        params = self.setup_lightning_indexer_test(requires_grad=True, return_value=True)
        _, npu_values_out = self.call_npu_lightning_indexer(params)
        loss = npu_values_out.sum()
        loss.backward()
        for name in ('query', 'key', 'weights'):
            tensor = params[name]
            grad = tensor.grad
            self.assertEqual(grad.dtype, tensor.dtype)
            self.assertEqual(grad.shape, tensor.shape)
def gen_npu_sparse_lightning_indexer_grad_kl_loss_inputs(self, seqlens_list_array, seqlens_list_kv_array, isTnd):
    """Create random NPU inputs for npu_sparse_lightning_indexer_grad_kl_loss.

    Args:
        seqlens_list_array: per-batch query lengths (read only when ``isTnd``).
        seqlens_list_kv_array: per-batch key lengths (read only when ``isTnd``).
        isTnd: when True, squeeze the size-1 batch dimension away (TND
            tensors); otherwise return BSND tensors.

    Returns:
        (query, key, query_index, key_index, query_rope, key_rope, weights,
        sparse_indices, softmax_max, softmax_sum).

        For each query position ``i`` (with query length ``q_len`` and key
        length ``kv_len``), ``sparse_indices`` holds
        ``min(visible, topK)`` random key ids drawn from ``[0, visible)``.
        The rest of the row is -1. Here ``visible = kv_len - q_len + i + 1``,
        falling back to ``kv_len`` when that is not positive.
    """
    B = 1
    NQuery = 64
    NQueryIndex = 64
    N2 = 1
    S1 = 128
    S2 = 128
    topK = 2048
    D = 512
    DIndex = 128
    DR = 64
    output_dtype = torch.float16
    npu = torch.device('npu')

    def _visible_kv_len(q_len, kv_len, q_pos):
        # Number of keys query position ``q_pos`` may select from; non-positive
        # counts fall back to the whole key length.
        visible = int(kv_len - q_len + q_pos + 1)
        return visible if visible > 0 else kv_len

    q = torch.randn(B, S1, NQuery, D, dtype=output_dtype, device=npu)
    k = torch.randn(B, S2, N2, D, dtype=output_dtype, device=npu)
    q_index = torch.randn(B, S1, NQueryIndex, DIndex, dtype=output_dtype, device=npu)
    k_index = torch.randn(B, S2, N2, DIndex, dtype=output_dtype, device=npu)
    if DR != 0:
        q_rope = torch.randn(B, S1, NQuery, DR, dtype=output_dtype, device=npu)
        k_rope = torch.randn(B, S2, N2, DR, dtype=output_dtype, device=npu)
    else:
        q_rope = None
        k_rope = None
    weights = torch.randn(B, S1, NQueryIndex, dtype=output_dtype, device=npu)
    # Map standard-normal draws so that +/-kk sigma lands on [a, b].
    a = -0.05
    b = 0.05
    kk = 3.0
    scale = (b - a) / (2 * kk)
    shift = (a + b) / 2
    weights = weights * scale + shift

    if isTnd:
        sparse_indices = torch.zeros(S1, N2, topK).to(torch.int32).npu()
        t_idx = 0
        for b_idx in range(B):
            q_len = seqlens_list_array[b_idx]
            kv_len = seqlens_list_kv_array[b_idx]
            for s1_idx in range(q_len):
                visible = _visible_kv_len(q_len, kv_len, s1_idx)
                count = min(visible, topK)
                sparse_indices[t_idx, :, :count] = torch.randint(0, visible, (count,)).to(torch.int32).npu()
                sparse_indices[t_idx, :, count:] = -1
                t_idx += 1
        q_tnd = q.squeeze(dim=0)
        k_tnd = k.squeeze(dim=0)
        q_index_tnd = q_index.squeeze(dim=0)
        k_index_tnd = k_index.squeeze(dim=0)
        if q_rope is not None:
            q_rope_tnd = q_rope.squeeze(dim=0)
            k_rope_tnd = k_rope.squeeze(dim=0)
        else:
            q_rope_tnd = None
            k_rope_tnd = None
        weights_tnd = weights.squeeze(dim=0)
        softmax_max = torch.randn(N2, S1, NQueryIndex, dtype=torch.float, device=npu)
        softmax_sum = torch.randn(N2, S1, NQueryIndex, dtype=torch.float, device=npu)
        return (q_tnd, k_tnd, q_index_tnd, k_index_tnd, q_rope_tnd, k_rope_tnd, weights_tnd,
                sparse_indices, softmax_max, softmax_sum)

    sparse_indices = torch.zeros(B, S1, N2, topK).to(torch.int32).npu()
    for s1_idx in range(S1):
        visible = _visible_kv_len(S1, S2, s1_idx)
        count = min(visible, topK)
        sparse_indices[:, s1_idx, 0, :count] = torch.randint(0, visible, (count,)).to(torch.int32).npu()
        sparse_indices[:, s1_idx, 0, count:] = -1
    softmax_max = torch.randn(B, N2, S1, NQueryIndex, dtype=torch.float, device=npu)
    softmax_sum = torch.randn(B, N2, S1, NQueryIndex, dtype=torch.float, device=npu)
    return q, k, q_index, k_index, q_rope, k_rope, weights, sparse_indices, softmax_max, softmax_sum
def test_dsa_npu_sparse_lightning_indexer_grad_kl_loss(self):
    """Check fake-mode outputs of npu_sparse_lightning_indexer_grad_kl_loss (TND layout).

    The three gradients must match query_index, key_index and weights in
    dtype and shape. The loss must be float32 with shape [1].
    """
    with FakeTensorMode():
        seq_q = [128]
        seq_kv = [128]
        layout = 'TND'
        use_tnd = True
        mode = 3
        softmax_scale = 1.0
        inputs = self.gen_npu_sparse_lightning_indexer_grad_kl_loss_inputs(seq_q, seq_kv, use_tnd)
        (q, k, q_index, k_index, q_rope, k_rope,
         weights, sparse_indices, softmax_max, softmax_sum) = inputs
        d_query_index, d_key_index, d_weights, loss = torch_npu.npu_sparse_lightning_indexer_grad_kl_loss(
            q, k, q_index, k_index, weights, sparse_indices, softmax_max, softmax_sum, softmax_scale,
            query_rope=q_rope, key_rope=k_rope, actual_seq_qlen=seq_q, actual_seq_klen=seq_kv,
            layout=layout, sparse_mode=mode, pre_tokens=65536, next_tokens=65536)
        reference_loss = torch.empty([1], dtype=torch.float32)
        for grad, source in ((d_query_index, q_index), (d_key_index, k_index), (d_weights, weights)):
            self.assertEqual(grad.dtype, source.dtype)
            self.assertEqual(grad.shape, source.shape)
        self.assertEqual(loss.dtype, reference_loss.dtype)
        self.assertEqual(loss.shape, reference_loss.shape)
def gen_npu_dense_lightning_indexer_grad_kl_loss_inputs(self, isTnd):
    """Create random NPU inputs for npu_dense_lightning_indexer_grad_kl_loss.

    Args:
        isTnd: when True, squeeze the size-1 batch dimension away (TND
            tensors) and return ``[S1]`` / ``[S2]`` as the sequence-length
            lists; otherwise return BSND tensors and ``None`` for both lists.

    Returns:
        (query, key, query_index, key_index, query_rope, key_rope, weights,
        softmax_max, softmax_sum, actual_seq_qlen, actual_seq_klen).
    """
    B = 1
    N1 = 64
    N2 = N1
    N1_index = 64
    N2_index = 1
    S1 = 128
    S2 = 256
    D = 128
    Dr = 64
    output_dtype = torch.float16
    npu = torch.device('npu')

    def fp16_randn(*shape):
        return torch.randn(*shape, dtype=output_dtype, device=npu)

    q = fp16_randn(B, S1, N1, D)
    k = fp16_randn(B, S2, N2, D)
    q_index = fp16_randn(B, S1, N1_index, D)
    k_index = fp16_randn(B, S2, N2_index, D)
    if Dr == 0:
        q_rope, k_rope = None, None
    else:
        q_rope, k_rope = fp16_randn(B, S1, N1, Dr), fp16_randn(B, S2, N2, Dr)
    weights = fp16_randn(B, S1, N1_index)

    if not isTnd:
        softmax_max = (torch.randn(B, N2, S1, 1, dtype=torch.float32, device=npu).abs() + 0.4) * D
        softmax_sum = torch.ones(B, N2, S1, 1, dtype=torch.float32, device=npu)
        return q, k, q_index, k_index, q_rope, k_rope, weights, softmax_max, softmax_sum, None, None

    def drop_batch(tensor):
        # Rope tensors may be absent; everything else loses its leading dim.
        return None if tensor is None else tensor.squeeze(dim=0)

    squeezed = [drop_batch(t) for t in (q, k, q_index, k_index, q_rope, k_rope, weights)]
    softmax_max = (torch.randn(N2, S1, 1, dtype=torch.float32, device=npu).abs() + 0.4) * D
    softmax_sum = torch.ones(N2, S1, 1, dtype=torch.float32, device=npu)
    return (*squeezed, softmax_max, softmax_sum, [S1], [S2])
def test_dsa_npu_dense_lightning_indexer_grad_kl_loss(self):
    """Check fake-mode outputs of npu_dense_lightning_indexer_grad_kl_loss (BSND layout).

    The three gradients must match query_index, key_index and weights in
    dtype and shape. The loss must be float32 with shape [1].
    """
    with FakeTensorMode():
        q_dtype = torch.float16
        B, N1, N2, N1_index, N2_index, S1, S2, D, Dr = 1, 64, 64, 64, 1, 128, 256, 128, 64

        def fp16_randn(*shape):
            return torch.randn(*shape, dtype=q_dtype)

        query = fp16_randn(B, S1, N1, D)
        key = fp16_randn(B, S2, N2, D)
        query_index = fp16_randn(B, S1, N1_index, D)
        key_index = fp16_randn(B, S2, N2_index, D)
        query_rope = fp16_randn(B, S1, N1, Dr)
        key_rope = fp16_randn(B, S2, N2, Dr)
        weights = fp16_randn(B, S1, N1_index)
        softmax_max = (torch.randn(B, N2, S1, 1, dtype=torch.float32).abs() + 0.4) * D
        softmax_sum = torch.ones(B, N2, S1, 1, dtype=torch.float32)
        softmax_max_index = (torch.randn(B, 1, S1, dtype=torch.float32).abs() + 0.4) * D * N1_index
        softmax_sum_index = torch.ones(B, 1, S1, dtype=torch.float32)
        softmax_scale = 1.0
        d_query_index, d_key_index, d_weights, loss = torch_npu.npu_dense_lightning_indexer_grad_kl_loss(
            query, key, query_index, key_index, weights,
            softmax_max, softmax_sum, softmax_max_index, softmax_sum_index, softmax_scale,
            query_rope=query_rope, key_rope=key_rope, actual_seq_qlen=[S1],
            actual_seq_klen=[S2], layout="BSND", sparse_mode=3)
        reference_loss = torch.empty([1], dtype=torch.float32)
        for grad, source in ((d_query_index, query_index), (d_key_index, key_index), (d_weights, weights)):
            self.assertEqual(grad.dtype, source.dtype)
            self.assertEqual(grad.shape, source.shape)
        self.assertEqual(loss.dtype, reference_loss.dtype)
        self.assertEqual(loss.shape, reference_loss.shape)
class TestNpuConv2d(TestCase):
    """Fake-tensor (meta) checks for torch_npu.npu_conv2d and npu_conv2d_backward."""

    def test_npu_conv2d_meta_0(self):
        """A 5x5 kernel with padding 2 and stride 1 keeps the 32x32 spatial size."""
        with FakeTensorMode():
            images = torch.randn(1, 3, 32, 32, dtype=torch.float).npu()
            kernel = torch.randn(6, 3, 5, 5, dtype=torch.float).npu()
            conv_bias = torch.randn(6, dtype=torch.float).npu()
            conv_params = ([1, 1], [2, 2], [1, 1], 1)  # stride, padding, dilation, groups
            result = torch_npu.npu_conv2d(images, kernel, conv_bias, *conv_params)
            reference = torch.empty([1, 6, 32, 32], dtype=torch.float)
            self.assertEqual(result.dtype, reference.dtype)
            self.assertEqual(result.shape, reference.shape)

    def test_npu_conv2d_backward_meta(self):
        """Meta backward gradients must match the real NPU ones in shape and dtype."""
        conv_params = ((1, 1), (1, 1), (1, 1), 1)  # stride, padding, dilation, groups
        grad_mask = [True, True, True]

        def forward_then_backward():
            image = torch.randn(4, 3, 28, 28, device="npu", requires_grad=True)
            kernel = torch.randn(4, 3, 4, 4, device="npu", requires_grad=True)
            conv_bias = torch.randn(4, device="npu", requires_grad=True)
            result = torch_npu.npu_conv2d(image, kernel, conv_bias, *conv_params)
            upstream = torch.ones_like(result, device="npu")
            return torch_npu.npu_conv2d_backward(image, upstream, kernel, *conv_params, grad_mask)

        real_input_grad, real_weight_grad, real_bias_grad = forward_then_backward()
        with FakeTensorMode():
            fake_input_grad, fake_weight_grad, fake_bias_grad = forward_then_backward()
            pairs = (
                (real_input_grad, fake_input_grad),
                (real_weight_grad, fake_weight_grad),
                (real_bias_grad, fake_bias_grad),
            )
            for real, fake in pairs:
                self.assertEqual(real.shape, fake.shape)
                self.assertEqual(real.dtype, fake.dtype)
class TestNpuApplyAdamW(TestCase):
    """Fake-tensor check for the out= form of torch_npu.npu_apply_adam_w."""

    def test_npu_apply_adam_w_meta_0(self):
        """Each returned tensor must keep the dtype and shape of its out= buffer."""
        with FakeTensorMode():
            # Scalar hyper-parameters, drawn in a fixed order.
            var_value = np.random.uniform(10.0, 20.0)
            m_value = np.random.uniform(5.0, 10.0)
            beta1_power = np.random.uniform(0.0, 1.0)
            beta2_power = np.random.uniform(0.0, 1.0)
            lr = np.random.uniform(0.0001, 0.1)
            weight_decay = np.random.uniform(0.001, 0.1)
            beta1 = np.random.uniform(0.5, 1.0)
            beta2 = np.random.uniform(0.5, 1.0)
            eps = np.random.uniform(0.00001, 0.01)
            buffer_shape = (21130, 512)
            var = torch.empty(buffer_shape, device="npu").fill_(var_value)
            m = torch.empty(buffer_shape, device="npu").fill_(m_value)
            v = torch.empty(buffer_shape, device="npu").zero_()
            grad = torch.empty(buffer_shape, device="npu").zero_()
            outputs = torch_npu.npu_apply_adam_w(
                beta1_power, beta2_power, lr, weight_decay, beta1, beta2, eps, grad,
                None,   # max_grad_norm
                False,  # amsgrad
                True,   # maximize
                out=(var, m, v),
            )
            var_out, m_out, v_out = outputs
            for produced, buffer in ((var_out, var), (m_out, m), (v_out, v)):
                self.assertEqual(produced.dtype, buffer.dtype)
                self.assertEqual(produced.shape, buffer.shape)
class TestNpuCrossEntropyLoss(TestCase):
    """Fake-tensor checks for torch_npu.npu_cross_entropy_loss and its backward."""

    def test_npu_cross_entropy_loss_meta_0(self):
        """With reduction="sum" the loss is float32 of shape [1]; log_prob is [N, C]."""
        with FakeTensorMode():
            num_rows, num_classes = 4096, 8080
            logits = torch.randn(num_rows, num_classes, dtype=torch.float32,
                                 requires_grad=True).npu()
            labels = torch.arange(0, num_rows, dtype=torch.int64).npu()
            loss, log_prob, _, _ = torch_npu.npu_cross_entropy_loss(
                logits, labels, reduction="sum", ignore_index=100)
            reference_loss = torch.empty([1], dtype=torch.float32)
            reference_log_prob = torch.empty([num_rows, num_classes], dtype=torch.float32)
            self.assertEqual(loss.dtype, reference_loss.dtype)
            self.assertEqual(loss.shape, reference_loss.shape)
            self.assertEqual(log_prob.dtype, reference_log_prob.dtype)
            self.assertEqual(log_prob.shape, reference_log_prob.shape)

    @SupportedDevices(['Ascend910B'])
    def test_npu_cross_entropy_loss_backward_meta(self):
        """The fake-mode input gradient must match the real one in shape and dtype."""
        rows = cols = 8

        def forward_then_backward():
            logits = torch.randn(rows, cols, device="npu", requires_grad=True)
            labels = torch.arange(0, rows, device="npu")
            loss, _, _, _ = torch_npu.npu_cross_entropy_loss(logits, labels, reduction="sum")
            loss.backward()
            return logits

        real_logits = forward_then_backward()
        with FakeTensorMode():
            fake_logits = forward_then_backward()
            self.assertEqual(real_logits.grad.shape, fake_logits.grad.shape)
            self.assertEqual(real_logits.grad.dtype, fake_logits.grad.dtype)
class TestNpuFusionAttention(TestCase):
    """Fake-tensor check for torch_npu.npu_fusion_attention with BSND inputs."""

    def test_npu_fusion_attention_input_layout_is_BSND(self):
        """Output 0 mirrors the query; outputs 1 and 2 are float32 of shape (B, N, S, 8)."""
        with FakeTensorMode():
            batch, heads, seq_len, head_dim = 4, 8, 32, 32
            qkv_shape = (batch, seq_len, heads, head_dim)
            query, key, value = (
                torch.randn(qkv_shape, dtype=torch.float32, device="npu", requires_grad=True)
                for _ in range(3)
            )
            softmax_scale = 0.08838
            softmax_reference = torch.randn((batch, heads, seq_len, 8), dtype=torch.float32).npu()
            outputs = torch_npu.npu_fusion_attention(
                query, key, value, head_num=heads, input_layout="BSND", scale=softmax_scale)
            self.assertEqual(query.shape, outputs[0].shape)
            self.assertEqual(query.dtype, outputs[0].dtype)
            for stat in (outputs[1], outputs[2]):
                self.assertEqual(softmax_reference.shape, stat.shape)
                self.assertEqual(softmax_reference.dtype, stat.dtype)
class TestNpuRepeatInterleave(TestCase):
    """Fake-tensor check for the backward of torch.repeat_interleave on NPU."""

    def test_npu_repeat_interleave_backward(self):
        """The fake-mode input gradient must match the real one in shape and dtype.

        The gradients are compared rather than the inputs: both inputs are
        created as (2, 2), so comparing them would never exercise the backward
        meta.
        """
        repeats_value = 3
        x = torch.randn(2, 2, device="npu", requires_grad=True)
        output = torch.repeat_interleave(x, repeats_value)
        grad = torch.randn(output.size(), device="npu")
        output.backward(grad)
        with FakeTensorMode():
            x_fake_tensor = torch.randn(2, 2, device="npu", requires_grad=True)
            output_fake_tensor = torch.repeat_interleave(x_fake_tensor, repeats_value)
            grad_fake_tensor = torch.randn(output_fake_tensor.size(), device="npu")
            output_fake_tensor.backward(grad_fake_tensor)
            self.assertEqual(x.grad.shape, x_fake_tensor.grad.shape)
            self.assertEqual(x.grad.dtype, x_fake_tensor.grad.dtype)
class TestQuantMatmulInplaceAdd(TestCase):
    """Fake-tensor checks for npu_add_quant_matmul_ (in-place) and npu_add_quant_matmul."""

    def _assert_add_quant_matmul_result(self, res, x1, x2):
        """Assert ``res`` is a 2-D float32 tensor of shape (x1.shape[0], x2.shape[1])."""
        self.assertEqual(len(res.shape), 2)
        self.assertEqual(x1.shape[0], res.shape[0])
        self.assertEqual(x2.shape[1], res.shape[1])
        self.assertEqual(res.dtype, torch.float32)

    def test_npu_quant_batch_matmul_inplace_add(self):
        """Both the in-place and the out-of-place variants produce (M, N) float32 results."""
        M = 7168
        N = 576
        K = 512
        with FakeTensorMode():
            y = torch.randint(-1, 1, (M, N), dtype=torch.float32).npu()
            x1 = torch.randint(-1, 1, (M, math.ceil(K / 64)), dtype=torch.int8).npu()
            x2 = torch.randint(-1, 1, (K, N), dtype=torch.int8).npu()
            x2_scale = torch.randint(-1, 1, (M, math.ceil(K / 64), 2), dtype=torch.float32).npu()
            x1_scale = torch.randint(-1, 1, (math.ceil(K / 64), 576, 2), dtype=torch.float32).npu()
            res_1 = torch_npu.npu_add_quant_matmul_(y, x1, x2, x2_scale, x1_scale=x1_scale, group_sizes=None)
            self._assert_add_quant_matmul_result(res_1, x1, x2)
            res_2 = torch_npu.npu_add_quant_matmul(y, x1, x2, x2_scale, x1_scale=x1_scale, group_sizes=None)
            # Check the out-of-place result itself; previously res_1 was re-checked here.
            self._assert_add_quant_matmul_result(res_2, x1, x2)
class TestNpuDeformableConv2d(TestCase):
    """Fake-tensor shape check for torch_npu.npu_deformable_conv2d."""

    def test_npu_deformable_conv2d(self):
        """Expect a (16, 32, 38, 38) output and a (16, 32, 190, 190) deformed feature map."""
        with FakeTensorMode():
            feature_map = torch.rand(16, 32, 32, 32).npu()
            kernel = torch.rand(32, 32, 5, 5).npu()
            offsets = torch.rand(16, 75, 38, 38).npu()
            conv_kwargs = {
                'kernel_size': [5, 5],
                'stride': [1, 1, 1, 1],
                'padding': [4, 6, 8, 2],
            }
            out, deform_out = torch_npu.npu_deformable_conv2d(feature_map, kernel, offsets, None, **conv_kwargs)
            self.assertEqual(out.shape, (16, 32, 38, 38))
            self.assertEqual(deform_out.shape, (16, 32, 190, 190))
class TestNpuPsRoiPooling(TestCase):
    """Fake-tensor shape check for torch_npu.npu_ps_roi_pooling."""

    def test_npu_ps_roi_pooling(self):
        """Pooling 8x5x2 ROIs over the feature map yields a (16, 2, 2, 2) result."""
        with FakeTensorMode():
            features = torch.rand((8, 68, 32, 32, 1), dtype=torch.float).npu()
            rois = torch.rand((8, 5, 2), dtype=torch.float).npu()
            pooled = torch_npu.npu_ps_roi_pooling(features, rois, 0.5, 2, 2)
            self.assertEqual(pooled.shape, (16, 2, 2, 2))
class TestNpuConvolution(TestCase):
    """Fake-tensor shape check for a 5-D torch_npu.npu_convolution."""

    def test_npu_convolution(self):
        """A 3x3x3 kernel with padding 1 and stride 1 keeps the 4x14x14 extent."""
        with FakeTensorMode():
            volume = torch.empty([1, 128, 4, 14, 14], dtype=torch.float).uniform_(-1, 1).npu()
            kernel = torch.empty([1, 128, 3, 3, 3], dtype=torch.float).uniform_(-1, 1).npu()
            stride, padding, dilation = [1, 1, 1], [1, 1, 1], [1, 1, 1]
            groups = 1
            result = torch_npu.npu_convolution(volume, kernel, None, stride, padding, dilation, groups)
            self.assertEqual(result.shape, (1, 1, 4, 14, 14))
class TestNpuConvolutionTranspose(TestCase):
    """Fake-tensor shape check for torch_npu.npu_convolution_transpose."""

    def test_npu_convolution_transpose(self):
        """3x3 kernel, padding 1, stride 1: spatial size 3x3 kept, 2 output channels."""
        with FakeTensorMode():
            def uniform_npu(shape):
                return torch.empty(shape, dtype=torch.float).uniform_(-1, 1).npu()

            feature_map = uniform_npu([1, 3, 3, 3])
            kernel = uniform_npu([3, 2, 3, 3])
            conv_bias = uniform_npu([2])
            result = torch_npu.npu_convolution_transpose(
                feature_map, kernel, conv_bias,
                [1, 1],  # padding
                [0, 0],  # output_padding
                [1, 1],  # stride
                [1, 1],  # dilation
                1,       # groups
            )
            self.assertEqual(result.shape, (1, 2, 3, 3))
class TestBatchNormReduce(TestCase):
    """Fake-tensor check for torch_npu.batch_norm_reduce."""

    def test_batch_norm_reduce(self):
        """Both outputs have length input.shape[1] and the input's dtype."""
        with FakeTensorMode():
            activations = torch.randn(2, 3, 12, 12, device="npu", requires_grad=True)
            first, second = torch_npu.batch_norm_reduce(activations, 1e-5)
            channels = activations.shape[1]
            for reduced in (first, second):
                self.assertEqual(channels, reduced.shape[0])
                self.assertEqual(activations.dtype, reduced.dtype)
class TestMatmul(TestCase):
    """Fake-tensor checks for matmul backward (autograd and the aten meta op) on NPU."""

    @staticmethod
    def _randn(shape, dtype, requires_grad=False):
        """Return a random NPU tensor of ``shape``; ``()`` yields a 0-d tensor."""
        # torch.randn(()) already produces a scalar, so no special case is needed.
        return torch.randn(shape, dtype=dtype, device="npu", requires_grad=requires_grad)

    def _run_fake_matmul_backward(self, grad_shape, mat1_shape, mat2_shape, grad_dtype, mat1_dtype, mat2_dtype):
        """Call aten.matmul_backward on fake tensors and return (self_grad, other_grad)."""
        with FakeTensorMode():
            grad = self._randn(grad_shape, grad_dtype)
            mat1 = self._randn(mat1_shape, mat1_dtype, requires_grad=True)
            mat2 = self._randn(mat2_shape, mat2_dtype, requires_grad=True)
            return torch.ops.aten.matmul_backward.default(grad, mat1, mat2, [True, True])

    def _assert_fake_backward_matches_real(self, mat1_shape, mat2_shape, dtype):
        """Run matmul + backward for real and under FakeTensorMode; compare grad metadata."""
        mat1 = self._randn(mat1_shape, dtype, requires_grad=True)
        mat2 = self._randn(mat2_shape, dtype, requires_grad=True)
        output = torch.matmul(mat1, mat2)
        grad = torch.randn(output.size(), dtype=output.dtype, device="npu")
        output.backward(grad)
        with FakeTensorMode():
            fake_mat1 = self._randn(mat1_shape, dtype, requires_grad=True)
            fake_mat2 = self._randn(mat2_shape, dtype, requires_grad=True)
            fake_output = torch.matmul(fake_mat1, fake_mat2)
            fake_grad = torch.randn(fake_output.size(), dtype=fake_output.dtype, device="npu")
            fake_output.backward(fake_grad)
            self.assertEqual(mat1.grad.shape, fake_mat1.grad.shape)
            self.assertEqual(mat1.grad.dtype, fake_mat1.grad.dtype)
            self.assertEqual(mat2.grad.shape, fake_mat2.grad.shape)
            self.assertEqual(mat2.grad.dtype, fake_mat2.grad.dtype)

    def test_matmul_backward(self):
        """Broadcast 3-D x 4-D matmul: fake-mode gradients match the real ones."""
        # The original test used the default dtype for every tensor.
        self._assert_fake_backward_matches_real((8, 4, 8), (8, 8, 8, 4), torch.get_default_dtype())

    @unittest.skipIf(
        os.getenv("TORCH_NPU_USE_COMPATIBLE_IMPL") == "1",
        "matmul_backward meta registration is disabled in compatible impl",
    )
    def test_matmul_backward_meta_shapes_and_dtypes(self):
        """Table-driven check of aten.matmul_backward meta output shapes and dtypes."""
        test_cases = [
            {
                "name": "vector_vector",
                "grad_shape": (),
                "mat1_shape": (8,),
                "mat2_shape": (8,),
                "grad_dtype": torch.float32,
                "mat1_dtype": torch.float16,
                "mat2_dtype": torch.float16,
                "expected_self_shape": (1, 8),
                "expected_other_shape": (8,),
                "expected_self_dtype": torch.float32,
                "expected_other_dtype": torch.float16,
            },
            {
                "name": "vector_matrix",
                "grad_shape": (6,),
                "mat1_shape": (5,),
                "mat2_shape": (5, 6),
                "grad_dtype": torch.float32,
                "mat1_dtype": torch.float16,
                "mat2_dtype": torch.float16,
                "expected_self_shape": (1, 5),
                "expected_other_shape": (5, 6),
                "expected_self_dtype": torch.float32,
                "expected_other_dtype": torch.float16,
            },
            {
                "name": "matrix_vector_squeeze_other_grad",
                "grad_shape": (3,),
                "mat1_shape": (3, 5),
                "mat2_shape": (5,),
                "grad_dtype": torch.float32,
                "mat1_dtype": torch.float32,
                "mat2_dtype": torch.float16,
                "expected_self_shape": (3, 5),
                "expected_other_shape": (5,),
                "expected_self_dtype": torch.float32,
                "expected_other_dtype": torch.float32,
            },
            {
                "name": "matrix_batched_matrix_special_self_grad",
                "grad_shape": (2, 3, 4),
                "mat1_shape": (3, 5),
                "mat2_shape": (2, 5, 4),
                "grad_dtype": torch.float16,
                "mat1_dtype": torch.float16,
                "mat2_dtype": torch.float16,
                "expected_self_shape": (3, 5),
                "expected_other_shape": (2, 5, 4),
                "expected_self_dtype": torch.float16,
                "expected_other_dtype": torch.float16,
            },
            {
                "name": "batched_matrix_matrix_special_other_grad",
                "grad_shape": (2, 3, 6),
                "mat1_shape": (2, 3, 5),
                "mat2_shape": (5, 6),
                "grad_dtype": torch.float16,
                "mat1_dtype": torch.float16,
                "mat2_dtype": torch.float16,
                "expected_self_shape": (2, 3, 5),
                "expected_other_shape": (5, 6),
                "expected_self_dtype": torch.float16,
                "expected_other_dtype": torch.float16,
            },
            {
                "name": "leading_singleton_self_batch",
                "grad_shape": (1, 4, 3, 6),
                "mat1_shape": (1, 4, 3, 5),
                "mat2_shape": (5, 6),
                "grad_dtype": torch.float32,
                "mat1_dtype": torch.float16,
                "mat2_dtype": torch.float16,
                "expected_self_shape": (1, 4, 3, 5),
                "expected_other_shape": (5, 6),
                "expected_self_dtype": torch.float32,
                "expected_other_dtype": torch.float16,
            },
            {
                "name": "leading_singleton_other_batch",
                "grad_shape": (1, 4, 3, 6),
                "mat1_shape": (3, 5),
                "mat2_shape": (1, 4, 5, 6),
                "grad_dtype": torch.float32,
                "mat1_dtype": torch.float32,
                "mat2_dtype": torch.float16,
                "expected_self_shape": (3, 5),
                "expected_other_shape": (1, 4, 5, 6),
                "expected_self_dtype": torch.float32,
                "expected_other_dtype": torch.float32,
            },
        ]
        for case in test_cases:
            with self.subTest(case=case["name"]):
                self_grad, other_grad = self._run_fake_matmul_backward(
                    case["grad_shape"],
                    case["mat1_shape"],
                    case["mat2_shape"],
                    case["grad_dtype"],
                    case["mat1_dtype"],
                    case["mat2_dtype"],
                )
                self.assertEqual(self_grad.shape, torch.Size(case["expected_self_shape"]))
                self.assertEqual(other_grad.shape, torch.Size(case["expected_other_shape"]))
                self.assertEqual(self_grad.dtype, case["expected_self_dtype"])
                self.assertEqual(other_grad.dtype, case["expected_other_dtype"])

    @unittest.skipIf(
        os.getenv("TORCH_NPU_USE_COMPATIBLE_IMPL") == "1",
        "matmul_backward meta registration is disabled in compatible impl",
    )
    def test_matmul_backward_fake_autograd_matches_real_grad_meta(self):
        """Autograd through fake tensors matches real NPU gradients for mixed ranks."""
        test_cases = [
            ((8,), (8, 4), torch.float16),
            ((4, 8), (2, 8, 6), torch.float16),
            ((1, 4, 3, 5), (5, 6), torch.float32),
        ]
        for mat1_shape, mat2_shape, dtype in test_cases:
            with self.subTest(mat1_shape=mat1_shape, mat2_shape=mat2_shape, dtype=dtype):
                self._assert_fake_backward_matches_real(mat1_shape, mat2_shape, dtype)
class TestKLDivLoss(TestCase):
    """Fake-tensor check for KLDivLoss backward on NPU."""

    def test_kl_div_backward(self):
        """The fake-mode gradient of the prediction matches the real one in shape and dtype."""
        criterion = torch.nn.KLDivLoss(reduction="sum", log_target=True).to("npu")

        def backward_through_loss():
            prediction = torch.randn(4, 4, device="npu", requires_grad=True)
            target = torch.randn(4, 4, device="npu")
            criterion(prediction.log_softmax(dim=-1), target).backward()
            return prediction

        real_prediction = backward_through_loss()
        with FakeTensorMode():
            fake_prediction = backward_through_loss()
            self.assertEqual(real_prediction.grad.shape, fake_prediction.grad.shape)
            self.assertEqual(real_prediction.grad.dtype, fake_prediction.grad.dtype)
class TestNpuDropoutGenMaskMeta(TestCase):
    """Meta registration check for torch_npu._npu_dropout_gen_mask."""

    def test_npu_dropout_gen_mask(self):
        """The meta mask must match the real NPU mask in shape and dtype."""
        size = (2, 1024, 768)
        mask_args = (size, 0.5, 1, 0)
        real_mask = torch_npu._npu_dropout_gen_mask(torch.randn(size, device="npu"), *mask_args)
        with FakeTensorMode():
            meta_input = torch.empty(size, device='meta')
            fake_mask = torch_npu._npu_dropout_gen_mask(meta_input, *mask_args)
            self.assertEqual(fake_mask.shape, real_mask.shape)
            self.assertEqual(fake_mask.dtype, real_mask.dtype)
class TestNpuSoftmaxCrossEntropyWithLogitsMeta(TestCase):
    """Meta registration check for torch_npu.npu_softmax_cross_entropy_with_logits."""

    def test_npu_softmax_cross_entropy_with_logits(self):
        """The meta loss must match the real NPU loss in shape and dtype."""
        real_loss = torch_npu.npu_softmax_cross_entropy_with_logits(
            torch.randn(32, 1000, device="npu"),
            torch.randint(0, 1000, (32,), dtype=torch.long, device="npu"),
        )
        with FakeTensorMode():
            meta_logits = torch.empty(32, 1000, device='meta')
            meta_labels = torch.empty(32, dtype=torch.long, device='meta')
            fake_loss = torch_npu.npu_softmax_cross_entropy_with_logits(meta_logits, meta_labels)
            self.assertEqual(fake_loss.shape, real_loss.shape)
            self.assertEqual(fake_loss.dtype, real_loss.dtype)
class TestNpuSoftmaxCrossEntropyWithLogitsBackwardMeta(TestCase):
    """Meta registration check for npu_softmax_cross_entropy_with_logits_backward."""

    def test_npu_softmax_cross_entropy_with_logits_backward(self):
        """The meta gradient must match the real NPU gradient in shape and dtype."""
        logits = torch.randn(32, 1000, device="npu")
        labels = torch.randint(0, 1000, (32,), dtype=torch.long, device="npu")
        loss = torch_npu.npu_softmax_cross_entropy_with_logits(logits, labels)
        real_grad = torch_npu.npu_softmax_cross_entropy_with_logits_backward(
            torch.ones_like(loss), logits, labels)
        with FakeTensorMode():
            meta_logits = torch.empty(32, 1000, device='meta')
            meta_labels = torch.empty(32, dtype=torch.long, device='meta')
            meta_upstream = torch.empty(32, device='meta')
            fake_grad = torch_npu.npu_softmax_cross_entropy_with_logits_backward(
                meta_upstream, meta_logits, meta_labels)
            self.assertEqual(fake_grad.shape, real_grad.shape)
            self.assertEqual(fake_grad.dtype, real_grad.dtype)
class TestNpuIndexAddMeta(TestCase):
    """Meta registration check for torch_npu._npu_index_add."""

    @unittest.skip("Skip until CANN supports aclnnIndexAddV2; do not execute")
    def test_npu_index_add(self):
        """The meta result must match the real NPU result in shape and dtype."""
        real_out = torch_npu._npu_index_add(
            torch.randn(10, 5, device="npu"),
            torch.randint(0, 10, (2,), dtype=torch.int64, device="npu"),
            torch.randn(2, 5, device="npu"),
            alpha=1,
        )
        with FakeTensorMode():
            meta_x = torch.empty(10, 5, device='meta')
            meta_index = torch.empty(2, dtype=torch.int64, device='meta')
            meta_source = torch.empty(2, 5, device='meta')
            fake_out = torch_npu._npu_index_add(meta_x, meta_index, meta_source, alpha=1)
            self.assertEqual(fake_out.shape, real_out.shape)
            self.assertEqual(fake_out.dtype, real_out.dtype)
class TestNpuIndexAddInplaceMeta(TestCase):
    """Meta registration check for the in-place torch_npu._npu_index_add_."""

    @unittest.skip("Skip until CANN supports aclnnIndexAddV2; do not execute")
    def test_npu_index_add_(self):
        """The meta result matches the real one and is the very tensor that was modified."""
        real_out = torch_npu._npu_index_add_(
            torch.randn(10, 5, device="npu"),
            torch.randint(0, 10, (2,), dtype=torch.int64, device="npu"),
            torch.randn(2, 5, device="npu"),
            alpha=1,
        )
        with FakeTensorMode():
            meta_x = torch.empty(10, 5, device='meta')
            meta_index = torch.empty(2, dtype=torch.int64, device='meta')
            meta_source = torch.empty(2, 5, device='meta')
            fake_out = torch_npu._npu_index_add_(meta_x, meta_index, meta_source, alpha=1)
            self.assertEqual(fake_out.shape, real_out.shape)
            self.assertEqual(fake_out.dtype, real_out.dtype)
            self.assertIs(fake_out, meta_x)
class TestNpuReshapeMeta(TestCase):
    """Meta registration check for torch_npu.npu_reshape."""

    def test_npu_reshape(self):
        """Reshaping (2, 3, 4) to (6, 4) gives matching meta and real metadata."""
        target_shape = (6, 4)
        real_out = torch_npu.npu_reshape(torch.randn(2, 3, 4, device="npu"), target_shape, can_refresh=False)
        with FakeTensorMode():
            meta_x = torch.empty(2, 3, 4, device='meta')
            fake_out = torch_npu.npu_reshape(meta_x, target_shape, can_refresh=False)
            self.assertEqual(fake_out.shape, real_out.shape)
            self.assertEqual(fake_out.dtype, real_out.dtype)
class TestSwigluMxQuant(TestCase):
    """Fake-tensor check for torch_npu.npu_swiglu_mx_quant."""

    def test_npu_swiglu_mx_quant_meta(self):
        """A [4, 8] float16 input gives a uint8 [4, 2] output and a uint8 [4, 1, 2] scale."""
        with FakeTensorMode():
            x = torch.randn([4, 8], dtype=torch.float16, device='npu')
            quant_options = dict(
                group_index=None, activate_dim=-1, activate_left=False, swiglu_mode=0,
                clamp_limit=7, glu_alpha=1.702, glu_bias=1, group_mode=0, axis=-1,
                dst_type=torch_npu.float4_e2m1fn_x2, round_mode="rint", scale_alg=0, max_dtype_value=0,
            )
            quantized, mx_scale = torch_npu.npu_swiglu_mx_quant(x, **quant_options)
            self.assertEqual(quantized.shape, torch.Size([4, 2]))
            self.assertEqual(quantized.dtype, torch.uint8)
            self.assertEqual(mx_scale.shape, torch.Size([4, 1, 2]))
            self.assertEqual(mx_scale.dtype, torch.uint8)
class TestSwigluGroupQuantBackward(TestCase):
    """Shape/dtype inference of ``npu_swiglu_group_quant_backward`` under FakeTensorMode."""

    def test_npu_swiglu_group_quant_backward_meta(self):
        with FakeTensorMode():
            # Inputs in call order: grad_y, x, weight, y_origin (all fp32 on the NPU).
            grad_y, x, weight, y_origin = (
                torch.randn(shape, dtype=torch.float32, device='npu')
                for shape in ([4, 8], [4, 16], [4, 1], [4, 8])
            )
            grad_x, grad_weight = torch_npu.npu_swiglu_group_quant_backward(
                grad_y, x,
                weight=weight,
                y_origin=y_origin,
                group_index=None,
                clamp_limit=0.0,
            )

            # Gradients mirror the shapes of the tensors they differentiate.
            self.assertEqual(grad_x.shape, torch.Size([4, 16]))
            self.assertEqual(grad_x.dtype, torch.float32)
            self.assertEqual(grad_weight.shape, torch.Size([4, 1]))
            self.assertEqual(grad_weight.dtype, torch.float32)
@unittest.skip("skip until CANN is updated to support aclnnDynamicBlockMxQuant")
class TestNpuDynamicBlockMxQuant(TestCase):
    """Compare eager and fake outputs of ``npu_dynamic_block_mx_quant`` for 2-D and 3-D inputs."""

    def test_npu_dynamic_block_mx_quant_meta(self):
        for size in ((64, 256), (32, 64, 256)):
            real_input = torch.rand(*size).to("npu").to(torch.float16)
            real_y, real_scale1, real_scale2 = torch_npu.npu_dynamic_block_mx_quant(real_input)

            with FakeTensorMode():
                fake_input = torch.rand(*size).to("npu").to(torch.float16)
                fake_y, fake_scale1, fake_scale2 = torch_npu.npu_dynamic_block_mx_quant(fake_input)

            # Pair each eager output with its fake counterpart: shapes first, then dtypes.
            pairs = (
                (real_y, fake_y),
                (real_scale1, fake_scale1),
                (real_scale2, fake_scale2),
            )
            for real, fake in pairs:
                self.assertEqual(real.shape, fake.shape)
            for real, fake in pairs:
                self.assertEqual(real.dtype, fake.dtype)
class TestNpuMultiHeadAttention(TestCase):
    """Output shape/dtype checks for the fake kernel of ``npu_multi_head_attention``."""

    def test_npu_multi_head_attention_meta(self):
        with FakeTensorMode():
            batch, tgt_len, src_len = 2, 16, 16
            attn_head_num, attn_dim_per_head = 8, 64
            weight_col = attn_head_num * attn_dim_per_head
            dropout_prob = 0.5
            softmax_use_float = True

            def half_npu(*size):
                # fp16 tensor created by randn and then moved to the NPU.
                return torch.randn(*size, dtype=torch.float16).npu()

            query = half_npu(batch * tgt_len, weight_col)
            key = half_npu(batch * src_len, weight_col)
            value = half_npu(batch * src_len, weight_col)
            query_weight = half_npu(weight_col, weight_col)
            key_weight = half_npu(weight_col, weight_col)
            value_weight = half_npu(weight_col, weight_col)
            attn_mask = half_npu(batch * attn_head_num * tgt_len * src_len)
            out_proj_weight = half_npu(weight_col, weight_col)
            query_bias = half_npu(1, weight_col)
            key_bias = half_npu(1, weight_col)
            value_bias = half_npu(1, weight_col)
            out_proj_bias = half_npu(1, weight_col)
            dropout_mask_input = torch.randint(
                0, 1, (weight_col,), dtype=torch.uint8).npu()

            outputs = torch_npu.npu_multi_head_attention(
                query, key, value,
                query_weight, key_weight, value_weight,
                attn_mask, out_proj_weight,
                query_bias, key_bias, value_bias, out_proj_bias,
                attn_head_num=attn_head_num,
                attn_dim_per_head=attn_dim_per_head,
                src_len=src_len,
                tgt_len=tgt_len,
                dropout_mask=dropout_mask_input,
                dropout_prob=dropout_prob,
                softmax_use_float=softmax_use_float,
            )
            (y, dropout_mask, query_res, key_res, value_res,
             attn_scores, attn_res, context) = outputs

            q_head_shape = (batch, attn_head_num, tgt_len, attn_dim_per_head)
            kv_head_shape = (batch, attn_head_num, src_len, attn_dim_per_head)
            score_shape = (batch, attn_head_num, tgt_len, src_len)

            self.assertEqual(y.shape, (batch * tgt_len, weight_col))
            self.assertEqual(y.dtype, query.dtype)
            self.assertEqual(dropout_mask.shape,
                             (batch * attn_head_num * tgt_len * src_len // 8,))
            self.assertEqual(dropout_mask.dtype, torch.uint8)
            self.assertEqual(query_res.shape, q_head_shape)
            self.assertEqual(key_res.shape, kv_head_shape)
            self.assertEqual(value_res.shape, kv_head_shape)
            self.assertEqual(attn_scores.shape, score_shape)
            self.assertEqual(attn_res.shape, score_shape)
            self.assertEqual(context.shape, (batch * tgt_len, weight_col))
class TestNpuMultiHeadAttentionBackward(TestCase):
    """Gradient-shape checks for the fake kernel of ``npu_multi_head_attention_backward``.

    The forward op is run first under FakeTensorMode so that its saved
    intermediates can be fed to the backward op, exactly as autograd would.
    """

    def test_npu_multi_head_attention_backward_meta(self):
        with FakeTensorMode():
            batch = 8
            attn_head_num = 16
            attn_dim_per_head = 64
            src_len = 64
            tgt_len = 64
            dropout_prob = 0.0
            softmax_use_float = True
            weight_col = attn_head_num * attn_dim_per_head

            def half_npu(*size):
                # fp16 tensor created by randn and then moved to the NPU.
                return torch.randn(*size, dtype=torch.float16).npu()

            query = half_npu(batch * tgt_len, weight_col)
            key = half_npu(batch * src_len, weight_col)
            # value is sized by the source length, like key; its gradient is
            # asserted against batch * src_len below.
            value = half_npu(batch * src_len, weight_col)
            query_weight = half_npu(weight_col, weight_col)
            key_weight = half_npu(weight_col, weight_col)
            value_weight = half_npu(weight_col, weight_col)
            out_proj_weight = half_npu(weight_col, weight_col)
            attn_mask = half_npu(batch, attn_head_num, tgt_len, src_len)
            query_bias = half_npu(weight_col)
            key_bias = half_npu(weight_col)
            value_bias = half_npu(weight_col)
            out_proj_bias = half_npu(weight_col)
            dropout_mask = torch.randint(
                0, 1, (weight_col,), dtype=torch.uint8).npu()

            # Forward pass: outputs follow the order returned by npu_multi_head_attention.
            (y, dropout_mask_out, query_res, key_res, value_res,
             attn_scores, attn_res, context) = torch_npu.npu_multi_head_attention(
                query, key, value, query_weight, key_weight, value_weight,
                attn_mask, out_proj_weight,
                query_bias, key_bias, value_bias, out_proj_bias,
                dropout_mask, attn_head_num, attn_dim_per_head, src_len, tgt_len,
                dropout_prob, softmax_use_float)
            y_grad = torch.ones_like(y)

            (query_weight_grad, key_weight_grad, value_weight_grad, out_proj_weight_grad,
             query_grad, key_grad, value_grad,
             query_bias_grad, key_bias_grad, value_bias_grad,
             out_proj_bias_grad) = torch_npu.npu_multi_head_attention_backward(
                query, key, value, query_weight, key_weight, value_weight, out_proj_weight,
                query_bias, key_bias, value_bias, out_proj_bias,
                query_res, key_res, value_res, attn_scores, attn_res, context,
                y_grad, dropout_mask_out,
                attn_head_num, attn_dim_per_head, src_len, tgt_len,
                dropout_prob, softmax_use_float)

            self.assertEqual(query_weight_grad.shape, (weight_col, weight_col))
            self.assertEqual(key_weight_grad.shape, (weight_col, weight_col))
            self.assertEqual(value_weight_grad.shape, (weight_col, weight_col))
            self.assertEqual(out_proj_weight_grad.shape, (weight_col, weight_col))
            self.assertEqual(query_grad.shape, (batch * tgt_len, weight_col))
            self.assertEqual(key_grad.shape, (batch * src_len, weight_col))
            self.assertEqual(value_grad.shape, (batch * src_len, weight_col))
            self.assertEqual(query_bias_grad.shape, (1, weight_col))
            self.assertEqual(key_bias_grad.shape, (1, weight_col))
            self.assertEqual(value_bias_grad.shape, (1, weight_col))
            self.assertEqual(out_proj_bias_grad.shape, (1, weight_col))
class TestNpuLstmCell(TestCase):
    """Output-shape checks for the fake kernel of ``npu_lstm_cell``."""

    def test_npu_lstm_cell_meta(self):
        with FakeTensorMode():
            batch_size, hidden_size, input_size = 2, 64, 128
            gate_rows = 4 * hidden_size

            def fp16_npu(*size):
                return torch.randn(*size, dtype=torch.float16).npu()

            input_ = fp16_npu(batch_size, input_size)
            w_ih = fp16_npu(gate_rows, input_size)
            w_hh = fp16_npu(gate_rows, hidden_size)
            h = fp16_npu(batch_size, hidden_size)
            c = fp16_npu(batch_size, hidden_size)
            b_ih = fp16_npu(gate_rows)
            b_hh = fp16_npu(gate_rows)

            (y_output, h_out, c_out, i_output, j_output,
             f_output, o_output, tanhc) = torch_npu.npu_lstm_cell(
                input_, w_ih, w_hh, h, c, b_ih, b_hh)

            # NOTE(review): the expected trailing dim is hidden_size // 4 rather than
            # hidden_size; confirm this against the npu_lstm_cell meta registration.
            out_dim = hidden_size // 4
            seq_shape = (1, batch_size, out_dim)
            state_shape = (batch_size, out_dim)

            self.assertEqual(y_output.shape, seq_shape)
            self.assertEqual(h_out.shape, state_shape)
            self.assertEqual(c_out.shape, state_shape)
            for gate in (i_output, j_output, f_output, o_output, tanhc):
                self.assertEqual(gate.shape, seq_shape)
class TestNpuLstmCellBackward(TestCase):
    """Gradient-shape checks for the fake kernel of ``npu_lstm_cell_backward``."""

    def test_npu_lstm_cell_backward_meta(self):
        with FakeTensorMode():
            batch_size = 2
            hidden_size = 64
            input_size = 128

            def fp16_npu(*size):
                return torch.randn(*size, dtype=torch.float16).npu()

            # Named ``input_`` (not ``input``) so the builtin is not shadowed,
            # matching the naming used by TestNpuLstmCell.
            input_ = fp16_npu(batch_size, input_size)
            w_ih = fp16_npu(4 * hidden_size, input_size)
            w_hh = fp16_npu(4 * hidden_size, hidden_size)
            h = fp16_npu(batch_size, hidden_size)
            c = fp16_npu(batch_size, hidden_size)

            # Stand-ins for the forward outputs that the backward op takes as inputs.
            y_output = fp16_npu(1, batch_size, hidden_size)
            h_output = fp16_npu(batch_size, hidden_size)
            c_output = fp16_npu(batch_size, hidden_size)
            i = fp16_npu(1, batch_size, hidden_size)
            j = fp16_npu(1, batch_size, hidden_size)
            f = fp16_npu(1, batch_size, hidden_size)
            o = fp16_npu(1, batch_size, hidden_size)
            tanhc = fp16_npu(1, batch_size, hidden_size)

            # Upstream gradients for y, h and c.
            grad_y = fp16_npu(1, batch_size, hidden_size)
            grad_h = fp16_npu(batch_size, hidden_size)
            grad_c = fp16_npu(batch_size, hidden_size)

            (grad_input, grad_wih, grad_whh, grad_bias1, grad_bias2,
             grad_ht, grad_ct) = torch_npu.npu_lstm_cell_backward(
                grad_y, grad_h, grad_c, input_, w_ih, w_hh, h, c,
                y_output, h_output, c_output, i, j, f, o, tanhc)

            self.assertEqual(grad_input.shape, input_.shape)
            self.assertEqual(grad_wih.shape, w_ih.shape)
            self.assertEqual(grad_whh.shape, w_hh.shape)
            self.assertEqual(grad_bias1.shape, (4 * hidden_size,))
            self.assertEqual(grad_bias2.shape, (4 * hidden_size,))
            self.assertEqual(grad_ht.shape, h.shape)
            self.assertEqual(grad_ct.shape, c.shape)
class TestNpuFusedAttentionScoreFwd(TestCase):
    """Output shape/dtype checks for the fake kernel of ``npu_fused_attention_score_fwd``."""

    def test_npu_fused_attention_score_fwd_meta(self):
        with FakeTensorMode():
            batch_size, num_heads, seq_len, head_dim = 2, 8, 16, 64
            qkv_shape = (batch_size, num_heads, seq_len, head_dim)
            mask_numel = batch_size * num_heads * seq_len * seq_len

            query_layer = torch.randn(*qkv_shape, dtype=torch.float16).npu()
            key_layer = torch.randn(*qkv_shape, dtype=torch.float16).npu()
            value_layer = torch.randn(*qkv_shape, dtype=torch.float16).npu()
            attention_mask = torch.randn(mask_numel, dtype=torch.float16).npu()
            scale = 1.0 / (head_dim ** 0.5)
            keep_prob = 1.0

            attention_score, softmax_output, drop_mask = torch_npu.npu_fused_attention_score_fwd(
                query_layer, key_layer, value_layer, attention_mask, scale, keep_prob)

            self.assertEqual(attention_score.shape,
                             (batch_size * seq_len, num_heads * head_dim))
            self.assertEqual(softmax_output.shape,
                             (batch_size, num_heads, seq_len, seq_len))
            self.assertEqual(drop_mask.shape, (mask_numel,))
            self.assertEqual(drop_mask.dtype, torch.uint8)
class TestNpuFusedAttentionScoreBackward(TestCase):
    """Gradient-shape checks for the fake kernel of ``npu_fused_attention_score_backward``."""

    def test_npu_fused_attention_score_backward_meta(self):
        with FakeTensorMode():
            batch_size = 2
            num_heads = 8
            seq_len = 16
            head_dim = 64
            qkv_shape = (batch_size, num_heads, seq_len, head_dim)

            query_layer = torch.randn(*qkv_shape, dtype=torch.float16).npu()
            key_layer = torch.randn(*qkv_shape, dtype=torch.float16).npu()
            value_layer = torch.randn(*qkv_shape, dtype=torch.float16).npu()
            softmax_output = torch.randn(
                batch_size, num_heads, seq_len, seq_len, dtype=torch.float16).npu()
            drop_mask = torch.randint(
                0, 1, (batch_size * num_heads * seq_len * seq_len,), dtype=torch.uint8).npu()
            grad_output = torch.randn(*qkv_shape, dtype=torch.float16).npu()
            scale = 1.0 / (head_dim ** 0.5)
            # NOTE(review): assumes the backward schema mirrors the forward op
            # (..., mask, scale, keep_prob, <transpose flags>), so the slot after
            # ``scale`` is the float keep probability. Pass an explicit 1.0 (no
            # dropout) instead of a bare ``False``, which would mean keep_prob=0.0.
            # Confirm against the op schema.
            keep_prob = 1.0
            query_dx, key_dw, value_dw = torch_npu.npu_fused_attention_score_backward(
                grad_output, softmax_output, query_layer, key_layer, value_layer, drop_mask,
                scale, keep_prob, False, False, False)

            self.assertEqual(query_dx.shape, grad_output.shape)
            self.assertEqual(key_dw.shape, grad_output.shape)
            self.assertEqual(value_dw.shape, grad_output.shape)
@unittest.skip("skip until CANN is updated to support aclnnRotateQuant.")
class TestRotateQuant(TestCase):
    """Fake-tensor shape/dtype checks for ``npu_rotate_quant`` across ``dst_dtype`` choices."""

    def _fake_rotate_quant(self, rows, cols, rot_dim, **quant_kwargs):
        """Run ``npu_rotate_quant`` on fake bf16 inputs: x is (rows, cols), rotation is (rot_dim, rot_dim)."""
        with FakeTensorMode():
            x = torch.randn(rows, cols, dtype=torch.bfloat16).npu()
            rotation = torch.randn(rot_dim, rot_dim, dtype=torch.bfloat16).npu()
            return torch_npu.npu_rotate_quant(x, rotation, **quant_kwargs)

    @SupportedDevices(['Ascend910B'])
    def test_npu_rotate_quant_int8(self):
        M, N, K = 512, 1024, 1024
        quantized, scale = self._fake_rotate_quant(M, N, K, dst_dtype=torch.int8)
        self.assertEqual(quantized.shape, torch.Size([M, N]))
        self.assertEqual(quantized.dtype, torch.int8)
        self.assertEqual(scale.shape, torch.Size([M]))
        self.assertEqual(scale.dtype, torch.float32)

    @SupportedDevices(['Ascend910B'])
    def test_npu_rotate_quant_int4(self):
        M, N, K = 512, 1024, 1024
        quantized, scale = self._fake_rotate_quant(M, N, K, dst_dtype=torch.quint4x2)
        # int4 results are reported as int32 with N // 8 columns.
        self.assertEqual(quantized.shape, torch.Size([M, N // 8]))
        self.assertEqual(quantized.dtype, torch.int32)
        self.assertEqual(scale.shape, torch.Size([M]))
        self.assertEqual(scale.dtype, torch.float32)

    @SupportedDevices(['Ascend950'])
    def test_npu_rotate_quant_mxfp4(self):
        M, N, K = 512, 1024, 1024
        quantized, scale = self._fake_rotate_quant(
            M, N, K, dst_dtype=torch_npu.float4_e2m1fn_x2, axis=-1, round_mode="rint")
        self.assertEqual(quantized.shape, torch.Size([M, N // 2]))
        self.assertEqual(quantized.dtype, torch.uint8)
        self.assertEqual(scale.shape, torch.Size([M, 16, 2]))
        self.assertEqual(scale.dtype, torch_npu.float8_e8m0fnu)

    @SupportedDevices(['Ascend950'])
    def test_npu_rotate_quant_mxfp8_e5m2(self):
        M, N, K = 512, 1024, 1024
        quantized, scale = self._fake_rotate_quant(
            M, N, K, dst_dtype=torch.float8_e5m2, axis=-1, round_mode="rint")
        self.assertEqual(quantized.shape, torch.Size([M, N]))
        self.assertEqual(quantized.dtype, torch.float8_e5m2)
        self.assertEqual(scale.shape, torch.Size([M, 16, 2]))
        self.assertEqual(scale.dtype, torch_npu.float8_e8m0fnu)

    @SupportedDevices(['Ascend950'])
    def test_npu_rotate_quant_mxfp8_e4m3fn(self):
        M, N, K = 512, 1024, 1024
        quantized, scale = self._fake_rotate_quant(
            M, N, K, dst_dtype=torch.float8_e4m3fn, axis=-1, round_mode="rint")
        self.assertEqual(quantized.shape, torch.Size([M, N]))
        self.assertEqual(quantized.dtype, torch.float8_e4m3fn)
        self.assertEqual(scale.shape, torch.Size([M, 16, 2]))
        self.assertEqual(scale.dtype, torch_npu.float8_e8m0fnu)
@unittest.skip("skip until CANN is updated to support aclnnRmsNormGradQuant.")
class TestNpuRmsNormBackwardQuant(TestCase):
    """Compare eager and fake outputs of ``_npu_rms_norm_backward_quant``."""

    @staticmethod
    def _make_inputs():
        """Return (dy, x, rstd, gamma, scale_x, offset_x) moved to the NPU."""
        return (
            torch.randn(2, 4, 64, dtype=torch.float16).to("npu"),
            torch.randn(2, 4, 64, dtype=torch.float16).to("npu"),
            torch.randn(2, 4, dtype=torch.float32).to("npu"),
            torch.randn(64, dtype=torch.float16).to("npu"),
            torch.randn(1, dtype=torch.float32).to("npu"),
            torch.zeros(1, dtype=torch.int32).to("npu"),
        )

    def test_npu_rms_norm_backward_quant_meta(self):
        dy, x, rstd, gamma, scale_x, offset_x = self._make_inputs()
        actual_dx, actual_dgamma = torch_npu._npu_rms_norm_backward_quant(
            dy, x, rstd, gamma, scale_x, offset_x=offset_x, dst_type=1)

        with FakeTensorMode():
            f_dy, f_x, f_rstd, f_gamma, f_scale_x, f_offset_x = self._make_inputs()
            fake_dx, fake_dgamma = torch_npu._npu_rms_norm_backward_quant(
                f_dy, f_x, f_rstd, f_gamma, f_scale_x,
                offset_x=f_offset_x, dst_type=1)

        self.assertEqual(actual_dx.shape, fake_dx.shape)
        self.assertEqual(actual_dgamma.shape, fake_dgamma.shape)
class FakeTensorFormatCastTest(TestCase):
    """Check that ``_npu_format_cast`` and its overloads keep tensor metadata under FakeTensorMode."""

    # ACL storage-format enum values used as the target format argument.
    ACL_FORMAT_ND = 2
    ACL_FORMAT_NC1HWC0 = 4
    ACL_FORMAT_FRACTAL_NZ = 29
    ACL_FORMAT_NCDHW = 30
    ACL_FORMAT_NDC1HWC0 = 32
    ACL_FORMAT_FRACTAL_Z_3D = 33

    def _check_output(self, out, shape, dtype=None, expect_fake=False):
        """Assert ``out`` has ``shape``; optionally also its dtype and that it is a FakeTensor."""
        if expect_fake:
            self.assertTrue(isinstance(out, FakeTensor))
        self.assertEqual(list(out.shape), shape)
        if dtype is not None:
            self.assertEqual(out.dtype, dtype)

    def test_format_cast_2d_shape(self):
        with FakeTensorMode():
            src = torch.randn(16, 32, device='npu')
            out = torch.ops.npu._npu_format_cast(src, self.ACL_FORMAT_FRACTAL_NZ)
            self._check_output(out, [16, 32], torch.float32, expect_fake=True)

    def test_format_cast_4d_shape(self):
        with FakeTensorMode():
            src = torch.randn(2, 32, 4, 4, device='npu', dtype=torch.float16)
            out = torch.ops.npu._npu_format_cast(src, self.ACL_FORMAT_NC1HWC0)
            self._check_output(out, [2, 32, 4, 4], torch.float16, expect_fake=True)

    def test_format_cast_5d_shape(self):
        with FakeTensorMode():
            src = torch.randn(1, 16, 2, 4, 4, device='npu')
            out = torch.ops.npu._npu_format_cast(src, self.ACL_FORMAT_NDC1HWC0)
            self._check_output(out, [1, 16, 2, 4, 4], expect_fake=True)

    def test_format_cast_nd_to_nd(self):
        with FakeTensorMode():
            src = torch.randn(8, 16, device='npu')
            out = torch.ops.npu._npu_format_cast(src, self.ACL_FORMAT_ND)
            self._check_output(out, [8, 16], expect_fake=True)

    def test_format_cast_non_aligned_shape(self):
        with FakeTensorMode():
            src = torch.randn(15, 17, device='npu', dtype=torch.float16)
            out = torch.ops.npu._npu_format_cast(src, self.ACL_FORMAT_FRACTAL_NZ)
            self._check_output(out, [15, 17], torch.float16)

    def test_format_cast_int8_shape(self):
        with FakeTensorMode():
            src = torch.randint(-128, 127, (32, 64), device='npu', dtype=torch.int8)
            out = torch.ops.npu._npu_format_cast(src, self.ACL_FORMAT_FRACTAL_NZ)
            self._check_output(out, [32, 64], torch.int8)

    def test_format_cast_preserves_device(self):
        with FakeTensorMode():
            src = torch.randn(4, 8, device='npu')
            out = torch.ops.npu._npu_format_cast(src, self.ACL_FORMAT_FRACTAL_NZ)
            self.assertEqual(out.device.type, 'npu')

    def test_format_cast_aclnn_overload_shape(self):
        with FakeTensorMode():
            src = torch.randn(16, 32, device='npu', dtype=torch.float16)
            out = torch.ops.npu._npu_format_cast.aclnn(src, self.ACL_FORMAT_FRACTAL_NZ, 5)
            self._check_output(out, [16, 32], torch.float16, expect_fake=True)

    def test_format_cast_aclnn_int8(self):
        with FakeTensorMode():
            src = torch.randint(-128, 127, (32, 64), device='npu', dtype=torch.int8)
            out = torch.ops.npu._npu_format_cast.aclnn(src, self.ACL_FORMAT_FRACTAL_NZ, 1)
            self._check_output(out, [32, 64], torch.int8)

    def test_format_cast_input_dtype_overload_shape(self):
        with FakeTensorMode():
            src = torch.randn(16, 32, device='npu', dtype=torch.float16)
            out = torch.ops.npu._npu_format_cast.input_dtype(src, self.ACL_FORMAT_FRACTAL_NZ, 5, 5)
            self._check_output(out, [16, 32], expect_fake=True)

    def test_format_cast_input_dtype_int8(self):
        with FakeTensorMode():
            src = torch.randint(-128, 127, (32, 64), device='npu', dtype=torch.int8)
            out = torch.ops.npu._npu_format_cast.input_dtype(src, self.ACL_FORMAT_FRACTAL_NZ, 1, 1)
            self._check_output(out, [32, 64], torch.int8)

    def test_format_cast_bfloat16(self):
        with FakeTensorMode():
            src = torch.randn(8, 16, device='npu', dtype=torch.bfloat16)
            out = torch.ops.npu._npu_format_cast(src, self.ACL_FORMAT_FRACTAL_NZ)
            self._check_output(out, [8, 16], torch.bfloat16)

    def test_format_cast_float32(self):
        with FakeTensorMode():
            src = torch.randn(8, 16, device='npu', dtype=torch.float32)
            out = torch.ops.npu._npu_format_cast(src, self.ACL_FORMAT_FRACTAL_NZ)
            self._check_output(out, [8, 16], torch.float32)

    def test_format_cast_3d_shape(self):
        with FakeTensorMode():
            src = torch.randn(4, 15, 17, device='npu', dtype=torch.float16)
            out = torch.ops.npu._npu_format_cast(src, self.ACL_FORMAT_FRACTAL_NZ)
            self._check_output(out, [4, 15, 17])
if __name__ == "__main__":
    # Script entry point: hand control to PyTorch's test runner.
    run_tests()