import contextlib
import copy
import itertools
import unittest
import weakref
from unittest.mock import patch
import numpy as np
import torch
import torch._dynamo
import torch._functorch.config
import torch._prims as prims
import torch_npu
import torch_npu.testing
import torch.testing._internal.optests as optests
from torch import distributed as dist
from torch._dynamo.testing import rand_strided
from torch._subclasses.fake_tensor import (
FakeTensor,
FakeTensorMode,
FakeTensorConverter,
DynamicOutputShapeException,
UnsupportedOperatorException,
)
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from torch.testing import FileCheck
from torch.testing._internal.common_device_type import instantiate_device_type_tests, OpDTypes
from torch.testing._internal.common_device_type import ops
from torch.testing._internal.common_utils import (
TestCase, TEST_WITH_TORCHDYNAMO, run_tests, skipIfCrossRef, skipIfRocm, skipIfTorchDynamo, parametrize,
instantiate_parametrized_tests)
from torch.testing._internal.custom_op_db import custom_op_db
from torch.utils._mode_utils import no_dispatch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_flatten
# Whether an NPU device is available; NPU-only tests below are skipped when this is False.
RUN_NPU = torch.npu.is_available()
class FakeTensorTest(TestCase):
    """Tests of FakeTensor / FakeTensorMode behaviour on CPU and, when available, NPU tensors."""

    def checkType(self, t, device_str, size):
        """Assert that ``t`` is a FakeTensor on ``device_str`` whose shape equals ``size``."""
        self.assertIsInstance(t, FakeTensor)
        self.assertEqual(t.device.type, device_str)
        self.assertEqual(list(t.size()), size)

    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_npu_initialized(self):
        # Forward and backward through mm on NPU fake tensors must run without error.
        with FakeTensorMode():
            p = torch.randn(4, 2, requires_grad=True, device='npu')
            x = torch.randn(8, 4, device='npu')
            y = torch.mm(x, p).square().sum()
            y.backward()

    def test_basic(self):
        x = torch.empty(2, 2, device="cpu")
        y = torch.empty(4, 2, 2, device="cpu")
        with FakeTensorMode() as mode:
            x = mode.from_tensor(x)
            y = mode.from_tensor(y)
            z = x + y
            self.assertEqual(z.shape, (4, 2, 2))
            self.assertEqual(z.device, torch.device("cpu"))
            self.assertIsInstance(z, FakeTensor)

    def test_custom_op_fallback(self):
        from torch.library import Library, impl

        # Define a custom op that only has a CPU kernel registered.
        test_lib = Library("my_test_op", "DEF")
        test_lib.define('foo(Tensor self) -> Tensor')

        @impl(test_lib, 'foo', 'CPU')
        def foo_impl(self):
            return self.cos()

        x = torch.empty(2, 2, device="cpu")
        with self.assertRaisesRegex(UnsupportedOperatorException, "my_test_op.foo.default"):
            with FakeTensorMode(allow_fallback_kernels=True) as mode:
                x = mode.from_tensor(x)
                torch.ops.my_test_op.foo(x)

    def test_parameter_instantiation(self):
        with FakeTensorMode():
            x = torch.rand([4])
            y = torch.nn.parameter.Parameter(x)
            self.assertIsInstance(y, torch.nn.Parameter)

    @unittest.skipIf(not dist.is_available(), "requires distributed")
    def test_fsdp_flat_param(self):
        from torch.distributed.fsdp._flat_param import FlatParameter

        with FakeTensorMode() as m:
            data = torch.randn(2, 2)
            param = FlatParameter(data, requires_grad=True)
        self.assertIsInstance(param, FlatParameter)
        self.assertIsInstance(param, torch.nn.Parameter)
        self.assertIsInstance(param, FakeTensor)

    def test_non_parameter_grad(self):
        mode = FakeTensorMode()
        t = torch.rand([4], requires_grad=True)
        fake_t = mode.from_tensor(t)
        self.assertEqual(fake_t.requires_grad, t.requires_grad)

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_index_npu_with_cpu(self):
        # Indexing an NPU tensor with a CPU index tensor yields an NPU result.
        with FakeTensorMode():
            x = torch.rand([2048], device='npu')
            out = x[torch.zeros([36], dtype=torch.int64)]
            self.checkType(out, "npu", [36])

    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_shape_take_not_device(self):
        # resize_as_ takes the shape of its argument but keeps the device of ``x``.
        with FakeTensorMode():
            x = torch.empty(1, device="cpu")
            y = torch.empty(8, 8, device="npu")
            out = x.resize_as_(y)
            self.assertEqual(out.shape, (8, 8))
            self.assertEqual(out.device.type, "cpu")
            self.assertIsInstance(out, FakeTensor)

    def test_repr(self):
        with FakeTensorMode():
            x = torch.empty(2, 2, device="cpu")
            self.assertEqual(repr(x), 'FakeTensor(..., size=(2, 2))')
            x = torch.empty(2, 2, device="meta")
            self.assertEqual(repr(x), "FakeTensor(..., device='meta', size=(2, 2))")

    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_zero_dim(self):
        with FakeTensorMode() as mode:
            x = torch.tensor(0.)
            y = torch.rand([4, 4], device="npu")
            out = x + y
            self.assertEqual(out.shape, (4, 4))
            self.assertEqual(out.device, y.device)
            self.assertIsInstance(out, FakeTensor)

    def test_nan_to_num(self):
        with FakeTensorMode():
            for dtype in [torch.float16, torch.float32]:
                x = torch.rand([4], dtype=dtype)
                y = torch.nan_to_num(x, nan=None)
                z = torch.nan_to_num(x, 0.0)
                self.assertEqual(dtype, y.dtype)
                self.assertEqual(dtype, z.dtype)

    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_throw(self):
        x = torch.tensor(0.)
        with FakeTensorMode() as mode:
            x_conv = mode.from_tensor(x)
            y = torch.rand([4, 4], device="npu")
            z = torch.rand([4, 4], device="cpu")
            self.assertRaises(Exception, lambda: torch.lerp(x_conv, y, z))

    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_type_as(self):
        with FakeTensorMode():
            x = torch.rand([16, 1], device="cpu")
            y = torch.rand([4, 4], device="npu")
            out = x.type_as(y)
            self.assertEqual(out.device.type, "npu")
            self.assertIsInstance(out, FakeTensor)

    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_setitem(self):
        # Item assignment must run without error on both devices.
        for device in ["cpu", "npu"]:
            with FakeTensorMode():
                x = torch.rand([16, 1], device=device)
                x[..., 0] = 0

    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_device_inplace_copy(self):
        # copy_ keeps the device of the destination tensor.
        with FakeTensorMode():
            x = torch.rand([8, 8], device="cpu")
            y = torch.rand([8, 8], device="npu")
            self.assertEqual(x.copy_(y).device.type, "cpu")
            self.assertEqual(y.copy_(x).device.type, "npu")

    def test_fake_dispatch_keys(self):
        with FakeTensorMode():
            x = torch.rand([4])
            f = FileCheck().check("CPU").check("ADInplaceOrView").check("AutogradCPU").check("AutocastCPU")
            f.run(torch._C._dispatch_key_set(x))

            with torch.inference_mode():
                x = torch.rand([4])
                y = x + x
                FileCheck().check("CPU").check("AutocastCPU").run(torch._C._dispatch_key_set(y))
                FileCheck().check_not("ADInplaceOrView").check_not("Autograd").run(torch._C._dispatch_key_set(y))

    def test_constructor(self):
        with FakeTensorMode():
            x = torch.rand([4, 4], device="cpu")

        self.assertIsInstance(x, FakeTensor)
        self.assertTrue(x.device.type == "cpu")

    def test_mode(self):
        with FakeTensorMode():
            y = torch.rand([4], device="cpu")
            out = y + y

        self.assertIsInstance(out, FakeTensor)

    def test_full(self):
        # Must run without error under the cross-reference mode.
        with torch._subclasses.CrossRefFakeMode():
            y = torch.full((4, 4), 1)

    def check_function_with_fake(self, fn):
        """Run ``fn`` for real and under FakeTensorMode and compare output metadata.

        The outputs of both runs are flattened to their leaves. Tensor leaves
        are compared with ``compare_tensor_meta`` (including strides); for a
        non-tensor leaf the corresponding fake leaf must also be a non-tensor.
        """
        out = fn()
        with torch._subclasses.FakeTensorMode():
            out_fake = fn()

        # tree_flatten returns (leaves, spec); only the leaves are compared.
        real_leaves, _ = tree_flatten(out)
        fake_leaves, _ = tree_flatten(out_fake)
        self.assertEqual(len(real_leaves), len(fake_leaves))

        for a, b in zip(real_leaves, fake_leaves):
            if not isinstance(a, torch.Tensor):
                self.assertNotIsInstance(b, torch.Tensor)
                continue
            prims.utils.compare_tensor_meta(a, b, check_strides=True)

    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_non_kwarg_device(self):
        with FakeTensorMode():
            x = torch.rand([16, 1], device="cpu")
            y = x.to(torch.device("cpu"))
            self.assertIs(x, y)
            z = x.to(torch.device("npu"))
            self.assertEqual(z.device.type, "npu")

    def test_non_overlapping_stride_zero(self):
        def foo():
            x = torch.empty_strided([1, 3, 427, 640], (0, 1, 1920, 3))
            return x.half()

        self.check_function_with_fake(foo)

    def test_fake_mode_error(self):
        # Using a real tensor inside the mode must raise.
        x = torch.rand([4, 4])

        with self.assertRaisesRegex(Exception, "Please convert all Tensors"):
            with FakeTensorMode():
                y = x[0]

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
    def test_fake_grad_copy(self):
        x = torch.rand([4, 4], requires_grad=True)
        x.grad = torch.rand([4, 4])
        mode = FakeTensorMode()
        fake_x = mode.from_tensor(x)
        prims.utils.compare_tensor_meta(fake_x, x)
        prims.utils.compare_tensor_meta(fake_x.grad, x.grad)

        self.assertIsInstance(fake_x.grad, FakeTensor)

    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_like_constructor(self):
        with FakeTensorMode():
            x = torch.rand([4, 4])
            y = torch.ones_like(x)
            self.assertIsInstance(y, FakeTensor)
            self.assertEqual(y.device.type, "cpu")
            z = torch.ones_like(x, device="npu")
            self.assertIsInstance(z, FakeTensor)
            self.assertEqual(z.device.type, "npu")

    def test_binary_op_type_promotion(self):
        with FakeTensorMode():
            x = torch.empty([2, 2], dtype=torch.float)
            y = torch.empty([2, 2], dtype=torch.int64)
            # The result is asserted on directly; swallowing an exception here
            # would only leave ``out`` undefined for the checks below.
            out = x / y
            self.assertEqual(out.dtype, torch.float)
            self.assertEqual(out.device.type, "cpu")

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
    def test_from_numpy(self):
        with FakeTensorMode():
            x = torch.tensor(np.zeros([4, 4]))
            self.checkType(x, "cpu", [4, 4])

    def test_randperm(self):
        x = torch.randperm(10)
        y = torch.randperm(5, device="cpu")
        with FakeTensorMode():
            x1 = torch.randperm(10)
            prims.utils.compare_tensor_meta(x, x1)
            y1 = torch.randperm(5, device="cpu")
            prims.utils.compare_tensor_meta(y, y1)

    def test_print_in_fake_mode(self):
        # Stringifying a real tensor inside the mode must not produce a FakeTensor repr.
        x = torch.zeros(2)
        with FakeTensorMode():
            out = str(x)
        self.assertNotIn("FakeTensor", out)

    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_upsample_bilinear_small_channels(self):
        out = []
        mode = FakeTensorMode()
        # First pass runs for real (nullcontext), second pass under the fake mode.
        for context in [contextlib.nullcontext, lambda: mode]:
            with context():
                arg0_1 = torch.empty_strided((3, 427, 640), (1, 1920, 3), dtype=torch.float32, device='npu')
                unsqueeze = torch.ops.aten.unsqueeze.default(arg0_1, 0)
                out.append(torch.ops.aten.upsample_bilinear2d.default(unsqueeze, [800, 1199], False))

        self.assertTrue(out[1].is_contiguous())
        self.checkMetaProps(out[0], out[1])

    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_cpu_fallback(self):
        with FakeTensorMode(allow_fallback_kernels=False):
            filters = torch.randn(8, 4, 3, 3).npu()
            inputs = torch.randn(1, 4, 5, 5).npu()
            out = torch.nn.functional.conv2d(inputs, filters, padding=1)
            self.assertEqual(out.device.type, "npu")
            self.assertEqual(list(out.size()), [1, 8, 5, 5])

        with FakeTensorMode(allow_fallback_kernels=True):
            # Channel counts of filters (20) and inputs (7) do not match.
            filters = torch.randn(8, 20, 3, 3).npu()
            inputs = torch.randn(1, 7, 10, 5).npu()
            with self.assertRaises(RuntimeError):
                torch.nn.functional.conv2d(inputs, filters, padding=1)

        with FakeTensorMode(allow_fallback_kernels=True):
            filters = torch.randn(8, 4, 3, 3).npu()
            inputs = torch.randn(1, 4, 5, 5).npu()

            out = torch.nn.functional.conv2d(inputs, filters, padding=1)
            self.assertEqual(out.device.type, "npu")
            self.assertEqual(list(out.size()), [1, 8, 5, 5])

    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_out_multi_device(self):
        with FakeTensorMode():
            x = torch.rand([4])
            y = torch.rand([4], device="npu")

            with self.assertRaisesRegex(Exception, "found two different devices"):
                torch.sin(x, out=y)

            with self.assertRaisesRegex(Exception, "found two different devices"):
                x.add_(y)

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_normalize_device(self):
        # "npu" and "npu:<current index>" must be usable together in one op.
        with FakeTensorMode():
            x = torch.empty(1, device="npu")
            y = torch.empty(1, device=f"npu:{torch.npu.current_device()}")
            out = x + y
        self.checkType(out, "npu", [1])

    def test_recursive_invocation(self):
        mode = FakeTensorMode()
        with mode:
            x = torch.tensor(2)
            mode.in_kernel_invocation = True
            y = x + x
            # The flag set above must still be set after the op has run.
            self.assertTrue(mode.in_kernel_invocation)

    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_lstm(self):
        with FakeTensorMode(allow_fallback_kernels=False):
            N = 5
            L = 4
            H_in = 2
            hidden_size = 3
            proj_size = 2
            num_layers = 2
            bidir = False
            D = 2 if bidir else 1
            H_out = proj_size if proj_size > 0 else hidden_size

            lstm = torch.nn.LSTM(input_size=H_in, hidden_size=hidden_size,
                                 num_layers=num_layers, proj_size=proj_size, batch_first=False,
                                 bias=True, bidirectional=bidir, device='npu')

            h_0 = torch.randn((num_layers * D, N, H_out), requires_grad=False, device='npu')
            c_0 = torch.randn((num_layers * D, N, hidden_size), requires_grad=False, device='npu')
            inp = torch.randn((L, N, H_in), requires_grad=False, device='npu')
            (output, (h_n, c_n)) = lstm(inp, (h_0, c_0))
            output.sum().backward()

            self.assertEqual(output.shape, (L, N, D * H_out))
            self.assertEqual(h_n.shape, (D * num_layers, N, H_out))
            self.assertEqual(c_n.shape, (D * num_layers, N, hidden_size))

    def test_data_dependent_operator(self):
        with FakeTensorMode(allow_fallback_kernels=False):
            x = torch.rand([10, 10])

            self.assertRaises(DynamicOutputShapeException, lambda: torch.nonzero(x))

    def checkMetaProps(self, t1, t2):
        """Compare the tensor metadata of ``t1`` and ``t2``, including strides."""
        prims.utils.compare_tensor_meta(t1, t2, check_strides=True)

    @skipIfCrossRef
    def test_deepcopy(self):
        with FakeTensorMode() as mode:
            pass
        mod = torch.nn.BatchNorm2d(10)
        with torch._subclasses.fake_tensor.FakeCopyMode(mode):
            mod_copied = copy.deepcopy(mod)

        def check_copy(mod, mod_copied):
            for name, param in itertools.chain(mod.named_parameters(), mod.named_buffers()):
                param_copied = getattr(mod_copied, name)
                self.checkMetaProps(param, param_copied)
                self.assertIsInstance(param_copied, FakeTensor)
                self.assertEqual(isinstance(param, torch.nn.Parameter), isinstance(param_copied, torch.nn.Parameter))
                self.assertEqual(param.requires_grad, param_copied.requires_grad)

        check_copy(mod, mod_copied)

        class ModuleNew(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = torch.rand([10, 2])
                self.b = self.a
                self.c = self.a[0]

        mod = ModuleNew()
        with torch._subclasses.fake_tensor.FakeCopyMode(mode):
            mod_copied = copy.deepcopy(mod)

        # Attributes that alias one tensor before the copy must still alias afterwards.
        self.assertIs(mod_copied.a, mod_copied.b)
        self.assertEqual(mod_copied.b.storage()._cdata, mod_copied.a.storage()._cdata)

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_new(self):
        with FakeTensorMode():
            a = torch.rand([16, 1])
            self.checkType(a.new(10, 10), "cpu", [10, 10])
            self.checkType(a.new([1, 2, 3, 4]), "cpu", [4])
            b = torch.rand([4, 4], device='npu')
            self.checkType(b.new(device='npu'), "npu", [0])
            self.checkType(a.new(torch.rand([1])), "cpu", [1])

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
    def test_scalar_inputs(self):
        with FakeTensorMode():
            self.checkType(torch.div(3, 2), "cpu", [])
            ten = torch.zeros(2, dtype=torch.int32) * 2.0
            self.assertEqual(ten.dtype, torch.float)
            self.checkType(ten, "cpu", [2])

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
    def test_allow_meta(self):
        def run_meta():
            with FakeTensorMode():
                x = torch.rand([4], device="meta")
                return x + x

        self.checkType(run_meta(), "meta", [4])

        # With the config flag patched to False the same function must raise.
        with patch.object(torch._functorch.config, "fake_tensor_allow_meta", False):
            self.assertRaises(Exception, run_meta)

    def test_embedding_bag_meta(self):
        def f():
            embedding = torch.nn.EmbeddingBag(10, 3, mode='sum', device='meta')
            ipt = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long)
            offsets = torch.tensor([0, 4], dtype=torch.long)
            return embedding(ipt, offsets)

        real_out = f()
        with FakeTensorMode():
            fake_out = f()

        # Loop names chosen so they do not shadow the local function ``f``.
        for real_item, fake_item in zip(real_out, fake_out):
            self.assertEqual(real_item.size(), fake_item.size())
            self.assertEqual(real_item.device, fake_item.device)

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
    def test_mixed_real_and_fake_inputs(self):
        class _TestPattern(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 1, 1)
                self.bn = torch.nn.BatchNorm2d(1)

            def forward(self, ipt):
                running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
                scale_factor = self.bn.weight / running_std
                weight_shape = [1] * len(self.conv.weight.shape)
                weight_shape[0] = -1
                bias_shape = [1] * len(self.conv.weight.shape)
                bias_shape[1] = -1
                scaled_weight = self.conv.weight * scale_factor.reshape(weight_shape)
                zero_bias = torch.zeros_like(self.conv.bias, dtype=ipt.dtype)
                conv = self.conv._conv_forward(ipt, scaled_weight, zero_bias)
                conv_orig = conv / scale_factor.reshape(bias_shape)
                conv_orig = conv_orig + self.conv.bias.reshape(bias_shape)
                conv = self.bn(conv_orig)
                return conv

        mod = _TestPattern()
        # The module holds real parameters and is called with a tensor created
        # inside the mode, so allow_non_fake_inputs is required.
        with FakeTensorMode(allow_non_fake_inputs=True):
            out = mod(torch.randn(1, 1, 3, 3))
        self.checkType(out, "cpu", (1, 1, 3, 3))

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_aten_copy_multi_device(self):
        with FakeTensorMode():
            x1 = torch.rand(4, device="cpu")
            x2 = torch.rand(4, device="npu")
            copy1 = torch.ops.aten.copy.default(x1, x2)
            copy2 = torch.ops.aten.copy.default(x2, x1)
            out = torch.empty(4, device="cpu")
            torch.ops.aten.copy.out(x1, x2, out=out)
        self.checkType(copy1, "cpu", (4,))
        self.checkType(copy2, "npu", (4,))
        self.checkType(out, "cpu", (4,))

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_aten_index_multi_device(self):
        with FakeTensorMode():
            x1 = torch.rand(4, 4, device="cpu")
            x2 = torch.rand(4, 4, device="npu")
            i1 = torch.tensor([0, 1], device="npu")
            i2 = torch.tensor([0, 1], device="cpu")
            r1 = torch.ops.aten.index(x1, i1)
            r2 = torch.ops.aten.index(x2, i2)

            y1 = torch.rand(4, device="cpu")
            y2 = torch.rand(4, device="npu")
            j1 = torch.tensor([2], device="npu")
            j2 = torch.tensor([2], device="cpu")
            r3 = torch.ops.aten.index_put.default(x1, j1, y1)
            r4 = torch.ops.aten.index_put.default(x2, j2, y2)
        self.checkType(r1, "cpu", ())
        self.checkType(r2, "npu", ())
        self.checkType(r3, "cpu", (4, 4))
        self.checkType(r4, "npu", (4, 4))

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile")
    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_aten_slice_scatter_multi_device(self):
        with FakeTensorMode():
            x1 = torch.rand(4, 4, device="cpu")
            y1 = torch.rand(2, 4, device="npu")
            x2 = torch.rand(4, 4, device="npu")
            y2 = torch.rand(2, 4, device="cpu")
            out = torch.empty(4, 4, device="cpu")
            r1 = torch.ops.aten.slice_scatter.default(x1, y1, start=2)
            r2 = torch.ops.aten.slice_scatter.default(x2, y2, start=2)
            r3 = torch.ops.aten.slice_scatter.out(x1, y1, out=out, start=2)
        self.checkType(r1, "cpu", (4, 4))
        self.checkType(r2, "npu", (4, 4))
        self.checkType(r3, "cpu", (4, 4))
        self.checkType(out, "cpu", (4, 4))

    def test__adaptive_avg_pool2d_backward(self):
        with FakeTensorMode():
            grad_out = torch.rand(2, 3, 4, 4)
            inp = torch.rand(2, 3, 4, 4).to(memory_format=torch.channels_last)
            grad_in = torch.ops.aten._adaptive_avg_pool2d_backward(grad_out, inp)
            self.assertTrue(torch._prims_common.suggest_memory_format(grad_in) == torch.channels_last)
class FakeTensorConstHandling(TestCase):
def assertConst(self, *args):
for arg in args:
self.assertTrue(arg.constant is not None)
def assertNotConst(self, *args):
for arg in args:
self.assertTrue(arg.constant is None)
def test_simple(self):
with FakeTensorMode():
x = torch.tensor(4.)
self.assertEqual(x.item(), 4.)
def test_inplace_add(self):
with FakeTensorMode():
x = torch.tensor(4.)
y = x.add_(1)
self.assertEqual(x.item(), 5.)
self.assertEqual(y.item(), 5.)
self.assertConst(x, y)
def test_shared_storages(self):
with FakeTensorMode():
x = torch.tensor([4.])
y = x[:]
self.assertEqual(x.storage()._cdata, y.storage()._cdata)
self.assertEqual(x.constant.storage()._cdata, y.constant.storage()._cdata)
def test_constant_invalidation(self):
with FakeTensorMode():
x = torch.tensor([1.])
self.assertConst(x)
y = torch.rand([1])
x.add_(y)
self.assertNotConst(x)
def test_inplace_view_invalidation(self):
with FakeTensorMode():
x = torch.tensor([1])
self.assertConst(x)
x.resize_([2])
self.assertEqual(x.size(0), 2)
self.assertNotConst(x)
def test_fake_tensor_in_intlist_repro(self):
def fn(tensors):
max_size = torch.tensor([800, 1216], dtype=torch.int64)
batch_shape = [len(tensors)] + list(tensors[0].shape[:-2]) + list(max_size)
return tensors[0].new_full(batch_shape, 0.0)
with self.assertRaises(torch._subclasses.fake_tensor.DataDependentOutputException):
with torch._subclasses.fake_tensor.FakeTensorMode():
a = torch.randn(3, 800, 1199)
b = torch.randn(3, 800, 800)
inputs = [a, b]
ref = fn(inputs)
def test_fake_tensor_batch_norm_cpu(self):
with torch._subclasses.CrossRefFakeMode():
m = torch.nn.Sequential(
torch.nn.BatchNorm2d(10),
torch.nn.ReLU(),
)
m.eval()
out = m(torch.randn([2, 10, 8, 8]))
def test_shared_storage_invalidation(self):
with FakeTensorMode():
x = torch.tensor([1.])
y = x[:]
self.assertConst(x, y)
y.add_(torch.rand([1]))
self.assertNotConst(x, y)
def test_aliased_const_write(self):
with FakeTensorMode():
x = torch.tensor([1])
y = x.expand([4])
self.assertNotConst(y)
y[0] = 1
self.assertNotConst(x)
def test_constant_propagate_through_functions(self):
with FakeTensorMode():
y = torch.div(4, 4, rounding_mode='trunc')
self.assertConst(y)
def contains_type(type_tmp: torch._C.Type, maybe_contained_type: torch._C.Type):
    """Return True if ``maybe_contained_type`` is a subtype of ``type_tmp`` or of
    any type reachable through ``containedTypes()``, searched recursively."""
    if maybe_contained_type.isSubtypeOf(type_tmp):
        return True
    for inner in type_tmp.containedTypes():
        if contains_type(inner, maybe_contained_type):
            return True
    return False
class FakeTensorOpInfoTest(TestCase):
@ops(custom_op_db, dtypes=OpDTypes.any_one)
def test_fake(self, device, dtype, op):
sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
for sample_input in sample_inputs_itr:
args = (sample_input.input,) + sample_input.args
kwargs = sample_input.kwargs
optests.fake_check(op, args, kwargs)
class FakeTensorConverterTest(TestCase):
def test_memoized_conversion_to_meta(self):
x = torch.rand(2, 2, 2)
mode = FakeTensorMode()
self.assertTrue(mode.from_tensor(x) is mode.from_tensor(x))
def test_memoized_conversion_from_meta(self):
x = torch.rand(2, 2).to(device="meta")
mode = FakeTensorMode()
converter = mode.fake_tensor_converter
self.assertTrue(converter.from_meta_and_device(mode, x, "cpu") is converter.from_meta_and_device(mode, x, "cpu"))
def test_separate_tensor_storages_view(self):
x = torch.rand(2, 2, 2)
y = x[0]
mode = FakeTensorMode()
converter = mode.fake_tensor_converter
x_conv = converter.from_real_tensor(mode, x)
y_conv = converter.from_real_tensor(mode, y)
self.assertEqual(torch._C._storage_id(x_conv), torch._C._storage_id(y_conv))
@skipIfTorchDynamo("see pytorch torchdynamo issue 1991")
def test_separate_tensor_storages_non_view(self):
x = torch.rand(2, 2, 2)
y = torch.rand(4, 2)
y.set_(x.storage())
mode = FakeTensorMode()
converter = mode.fake_tensor_converter
x_conv = converter.from_real_tensor(mode, x)
y_conv = converter.from_real_tensor(mode, y)
stor_id = torch._C._storage_id(x_conv)
self.assertEqual(stor_id, torch._C._storage_id(y_conv))
del x
del x_conv
self.assertEqual(len(converter.tensor_memo), 1)
self.assertEqual(len(converter.meta_converter.storage_memo), 1)
del y
del y_conv
self.assertEqual(len(converter.tensor_memo), 0)
self.assertEqual(len(converter.meta_converter.storage_memo), 0)
@skipIfTorchDynamo("see pytorch torchdynamo issue 1991")
def test_dead_weak_ref(self):
x = torch.rand(2, 2, 2)
y = x[0]
mode = FakeTensorMode()
converter = FakeTensorConverter()
x_conv = converter.from_real_tensor(mode, x)
x_conv_storage = x_conv.untyped_storage()
del x_conv
self.assertFalse(x in converter.tensor_memo)
y_conv = converter.from_real_tensor(mode, y)
self.assertEqual(x_conv_storage, y_conv.untyped_storage())
@skipIfTorchDynamo("see pytorch torchdynamo issue 1991")
def test_dead_key(self):
x = torch.rand(2, 2, 2)
mode = FakeTensorMode()
converter = FakeTensorConverter()
x_conv = converter.from_real_tensor(mode, x)
self.assertEqual(len(converter.tensor_memo), 1)
x_conv2 = converter.from_real_tensor(mode, x)
self.assertIs(x_conv2, x_conv)
del x
del x_conv
del x_conv2
self.assertEqual(len(converter.tensor_memo), 0)
def test_no_active_mode(self):
with FakeTensorMode() as mode:
x = torch.empty(2, 2, device="cpu")
y = torch.empty(2, 2, device="cpu")
out = x + y
self.assertEqual(mode, out.fake_mode)
self.assertTrue(isinstance(out, FakeTensor))
self.assertEqual(out.device.type, "cpu")
def test_multiple_modes(self):
t = torch.rand([4])
t2 = torch.rand([4])
with FakeTensorMode() as m:
with FakeTensorMode() as m2:
t_fake = m.from_tensor(t)
t2_fake = m2.from_tensor(t2)
with self.assertRaisesRegex(Exception, "Mixing fake modes"):
t_fake + t2_fake
def test_separate_mode_error(self):
with FakeTensorMode():
x = torch.empty(2, 2, device="cpu")
with FakeTensorMode():
y = torch.empty(2, 2, device="cpu")
self.assertRaises(Exception, lambda: x, y)
@skipIfTorchDynamo("see pytorch torchdynamo issue 1991")
def test_no_ref_cycle(self):
x = torch.rand([4])
mode = FakeTensorMode()
y = mode.from_tensor(x)
self.assertEqual(len(mode.fake_tensor_converter.tensor_memo), 1)
mode_weak = weakref.ref(mode)
y_weak = weakref.ref(mode)
del mode
del y
self.assertIsNone(mode_weak())
self.assertIsNone(y_weak())
class FakeTensorOperatorInvariants(TestCase):
    """Checks on aten operator schemas and on a few individual ops under fake modes."""

    @staticmethod
    def get_aten_op(schema):
        """Return the ``torch.ops.aten`` overload object described by ``schema``.

        ``schema.name`` must have the form ``"aten::<name>"``. An empty
        ``schema.overload_name`` selects the ``default`` overload.

        Raises:
            ValueError: if the schema's namespace is not ``aten``.
        """
        namespace, name = schema.name.split("::")
        overload = schema.overload_name if schema.overload_name else "default"
        if namespace != "aten":
            raise ValueError(f"expected a schema in the aten namespace, got {schema.name!r}")
        return getattr(getattr(torch.ops.aten, name), overload)

    @staticmethod
    def get_all_aten_schemas():
        """Yield every registered JIT schema whose namespace is ``aten``."""
        for schema in torch._C._jit_get_all_schemas():
            namespace = schema.name.split("::")[0]
            if namespace != "aten":
                continue
            yield schema

    def test_non_kwarg_only_device(self):
        # These type objects do not depend on the schema, so build them once.
        ten_type = torch._C.TensorType.get()
        opt_device = torch._C.OptionalType(torch._C.DeviceObjType.get())

        for schema in self.get_all_aten_schemas():
            # Only schemas that take or return a tensor are of interest.
            if not any(
                contains_type(arg.type, ten_type)
                for arg in itertools.chain(schema.arguments, schema.returns)
            ):
                continue

            has_non_kwarg_device = any(
                not arg.kwarg_only and arg.type.isSubtypeOf(opt_device)
                for arg in schema.arguments
            )
            if has_non_kwarg_device:
                self.assertIn(
                    self.get_aten_op(schema), torch._subclasses.fake_tensor._device_not_kwarg_ops
                )

    def test_tensor_constructors_all_have_kwarg_device(self):
        opt_device = torch._C.OptionalType(torch._C.DeviceObjType.get())

        for schema in self.get_all_aten_schemas():
            op = self.get_aten_op(schema)
            if not torch._subclasses.fake_tensor._is_tensor_constructor(op):
                continue

            has_kwarg_device = any(
                arg.kwarg_only and arg.type.isSubtypeOf(opt_device)
                for arg in schema.arguments
            )

            # _list_to_tensor is the one constructor allowed to lack a kwarg-only device.
            self.assertTrue(
                has_kwarg_device or op == torch.ops.aten._list_to_tensor.default
            )

    @unittest.expectedFailure
    def test_sparse_new(self):
        with FakeTensorMode():
            indices = torch.randn(1, 1, dtype=torch.int64)
            values = torch.randn(1)
            extra = (2,)
            sparse = torch.randn(1).to_sparse()
            sparse2 = sparse.new(indices, values, extra)

    def test_tensor_new(self):
        with FakeTensorMode():
            x = torch.Tensor([1, 2, 3])
        self.assertIsInstance(x, FakeTensor)

    def test_like_ops(self):
        for schema in self.get_all_aten_schemas():
            if schema.name.endswith("_like"):
                op = self.get_aten_op(schema)
                self.assertIn(op, torch._subclasses.fake_tensor._like_tensor_constructors)

    def test_embedding_bag_private(self):
        args = [
            torch.ones(6, 1),
            torch.ones(6, dtype=torch.int64),
            torch.arange(2, dtype=torch.int64),
            False,
            2,
        ]

        ref_out = torch.ops.aten._embedding_bag(*args)
        with FakeTensorMode() as m:
            # Convert only the tensor arguments; the rest are passed through.
            meta_args = [m.from_tensor(a) if isinstance(a, torch.Tensor) else a for a in args]
            meta_out = torch.ops.aten._embedding_bag(*meta_args)

        self.assertEqual(len(ref_out), len(meta_out))
        for ref_o, meta_o in zip(ref_out, meta_out):
            self.assertEqual(ref_o.size(), meta_o.size())

    def test_cross_entropy_loss(self):
        inp = torch.randn(3, 5)
        target = torch.randint(5, (3,), dtype=torch.long)
        weight = torch.rand(5)
        fn = torch.nn.functional.cross_entropy
        for w in (weight, None):
            args = (inp, target, w)
            ref = fn(*args)
            with FakeTensorMode() as m:
                meta_args = [m.from_tensor(a) if isinstance(a, torch.Tensor) else a for a in args]
                meta_out = torch.nn.functional.cross_entropy(*meta_args, label_smoothing=0.5)

            # Only the output sizes are compared.
            self.assertEqual(ref.size(), meta_out.size())

    @skipIfRocm
    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_flash_attention(self):
        class Repro(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, arg1, arg2, arg3):
                torch.ops.aten._scaled_dot_product_flash_attention(arg1, arg2, arg3)

        # Each entry is (shape, stride, dtype, device), the arguments of rand_strided.
        args_new = [
            ((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, "npu"),
            ((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, "npu"),
            ((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, "npu"),
        ]

        args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args_new]
        try:
            with torch._subclasses.CrossRefFakeMode():
                Repro()(*args)
        except RuntimeError as e:
            # A mismatch is tolerated only if it is not on output[0] and matches
            # the message below.
            # NOTE(review): the expected message names cuda:0 although the inputs
            # are created on npu; confirm the intended device string.
            self.assertTrue("output[0]" not in str(e))
            self.assertTrue("found mismatched tensor metadata for output[6]: Devices cpu and cuda:0 are not equal!" in str(e))

    @skipIfRocm
    @unittest.skipIf(not RUN_NPU, "requires npu")
    def test_conv_c1_backward(self):
        class Repro(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, arg1, arg2, arg3):
                torch.ops.aten.convolution_backward.default(
                    arg1,
                    arg2,
                    arg3,
                    [1],
                    [1, 1],
                    [1, 1],
                    [1, 1],
                    False,
                    [0, 0],
                    1,
                    [True, True, False],
                )

        # Each entry is (shape, stride, dtype, device), the arguments of rand_strided.
        args_new = [
            ((16, 1, 128, 128), (16384, 16384, 128, 1), torch.float16, "npu"),
            ((16, 64, 128, 128), (1048576, 1, 8192, 64), torch.float16, "npu"),
            ((1, 64, 3, 3), (576, 9, 3, 1), torch.float16, "npu"),
        ]
        args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args_new]

        with torch._subclasses.CrossRefFakeMode():
            Repro()(*args)

    def test_no_dispatch_with_like_function(self):
        class CountingMode(TorchDispatchMode):
            """Dispatch mode that counts how many ops reach ``__torch_dispatch__``."""

            def __init__(self):
                super().__init__()
                self.count = 0

            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                self.count += 1
                # kwargs defaults to None, which cannot be unpacked with **.
                return func(*args, **({} if kwargs is None else kwargs))

        with FakeTensorMode():
            x = torch.randn(2)
            with CountingMode() as mode:
                with no_dispatch():
                    torch.zeros_like(x)

        # Nothing may reach the counting mode while no_dispatch() is active.
        self.assertEqual(mode.count, 0)
class FakeTensorPropTest(TestCase):
def test_fake_tensor_prop_on_nn_module(self):
class ToyNnModuleWithParameters(torch.nn.Module):
def __init__(self):
super().__init__()
self.layer1 = torch.nn.Linear(4, 3)
self.layer2 = torch.nn.Linear(3, 2)
def forward(self, value):
value = self.layer1(value)
value = torch.relu(value)
value = self.layer2(value)
return value
model = ToyNnModuleWithParameters()
value = torch.randn(5, 4)
graph_model = torch.fx.symbolic_trace(model, (value,))
with FakeTensorMode() as fake_tensor_mode:
def to_fake_tensor(x):
if isinstance(x, torch.Tensor) and not isinstance(x, FakeTensor):
return fake_tensor_mode.from_tensor(x)
return x
chain_iter = itertools.chain(graph_model.named_parameters(), graph_model.named_buffers())
fake_parameters_and_buffers = {
k: to_fake_tensor(v)
for k, v in chain_iter
}
with torch.nn.utils.stateless._reparametrize_module(
graph_model, fake_parameters_and_buffers
):
result = FakeTensorProp(graph_model, fake_tensor_mode).propagate(value)
self.assertTrue(isinstance(result, FakeTensor))
self.assertEqual(result.shape, (5, 2))
failed = False
try:
FakeTensorProp(graph_model).propagate(value)
except AssertionError:
failed = True
self.assertTrue(failed)
def test_fake_tensor_prop_on_nn_module_with_optional_args(self):
class OptionalArgumentInBetween(torch.nn.Module):
def __init__(self):
super().__init__()
self.layer1 = torch.nn.Linear(4, 3)
self.layer2 = torch.nn.Linear(3, 2)
def forward(self, value, another_value=None, another_optional_value=None):
if another_value is None:
another_value = torch.rand_like(value)
if another_optional_value is None:
another_optional_value = torch.rand_like(value)
value = value + another_value + another_optional_value
return value * value
fake_mode = FakeTensorMode(allow_non_fake_inputs=True, allow_fallback_kernels=False)
with fake_mode:
model = OptionalArgumentInBetween()
value = torch.randn(5, 4)
another_optional_value = torch.randn(5, 4)
graph_model = torch.fx.symbolic_trace(model, (value, None, another_optional_value))
FakeTensorProp(graph_model, fake_mode).propagate(value, None, another_optional_value)
class TestFastGelu(TestCase):
    """Shape checks for the torch_npu fast-gelu ops under FakeTensorMode."""

    def test_fast_gelu(self):
        with FakeTensorMode():
            inp = torch.randn(2, 3).npu()
            inp.requires_grad = True
            out = torch_npu.fast_gelu(inp)

            self.assertTrue(inp.shape == out.shape)

    def test_fast_gelu_backward(self):
        with FakeTensorMode():
            inp = torch.randn(2, 3).npu()
            inp.requires_grad = True
            out = torch_npu.fast_gelu(inp)
            loss = out.sum()
            loss.backward()
            # The gradient must have the same shape as the input.
            self.assertTrue(inp.shape == inp.grad.shape)

    def test_npu_fast_gelu(self):
        with FakeTensorMode():
            inp = torch.randn(2, 3).npu()
            inp.requires_grad = True
            out = torch_npu.npu_fast_gelu(inp)

            self.assertEqual(inp.shape, out.shape)
class TestGelu(TestCase):
    """Fake-tensor shape checks for ``torch.ops.npu.npu_gelu``."""

    def test_npu_gelu(self):
        """Forward output shape equals the input shape."""
        with FakeTensorMode():
            inp = torch.randn(10, 3, 4).npu()
            inp.requires_grad = True
            out = torch.ops.npu.npu_gelu(inp)
            self.assertTrue(inp.shape == out.shape)

    def test_npu_gelu_backward(self):
        """Gradient reaching the input is shaped like the input."""
        with FakeTensorMode():
            inp = torch.randn(10, 3, 4).npu()
            inp.requires_grad = True
            out = torch.ops.npu.npu_gelu(inp)
            loss = out.sum()
            loss.backward()
            self.assertTrue(inp.shape == inp.grad.shape)
class TestIncreFlashAttention(TestCase):
    """Fake-tensor shape/dtype checks for ``npu_incre_flash_attention``."""

    def testIncreFlashAttention(self):
        """float16 q/k/v that require grad give an output shaped like the query."""
        with FakeTensorMode():
            q = torch.randn(1, 40, 1, 128, dtype=torch.float16).npu()
            k = torch.randn(1, 40, 1, 128, dtype=torch.float16).npu()
            v = torch.randn(1, 40, 1, 128, dtype=torch.float16).npu()
            for tensor in (q, k, v):
                tensor.requires_grad = True
            res = torch.ops.npu.npu_incre_flash_attention(q, k, v)

            # assertEqual reports both operands on failure, so no debug
            # printing of the shapes is needed.
            self.assertEqual(q.shape, res.shape)

    def test_incre_flash_attention_int8_in(self):
        """int8 q/k/v give a half-precision output shaped like the query."""
        with FakeTensorMode():
            q = torch.randint(1, 40, (1, 128), dtype=torch.int8).npu()
            k = torch.randint(1, 40, (1, 128), dtype=torch.int8).npu()
            v = torch.randint(1, 40, (1, 128), dtype=torch.int8).npu()
            res = torch.ops.npu.npu_incre_flash_attention(q, k, v)

            self.assertEqual(q.shape, res.shape)
            self.assertEqual(res.dtype, torch.half)

    def test_incre_flash_attention_kv_padding_size(self):
        """Passing ``actual_seq_lengths`` and ``kv_padding_size`` keeps the query shape."""
        with FakeTensorMode():
            q = torch.randn(1, 40, 1, 128, dtype=torch.float16).npu()
            k = torch.randn(1, 40, 1, 128, dtype=torch.float16).npu()
            v = torch.randn(1, 40, 1, 128, dtype=torch.float16).npu()
            kv_padding_size = torch.randint(1, 2, (1,)).npu()
            actual_seq_shape = [1]
            actual_seq = np.random.uniform(1, 1, actual_seq_shape).astype(np.int64).tolist()
            res = torch.ops.npu.npu_incre_flash_attention(
                q, k, v, actual_seq_lengths=actual_seq, kv_padding_size=kv_padding_size
            )

            self.assertEqual(q.shape, res.shape)
class TestNpuBmmV2(TestCase):
    """Fake-tensor check for ``torch_npu.npu_bmmV2``."""

    def test_npu_bmmV2(self):
        """Result dtype follows the first operand; shape matches ``torch.matmul``."""
        with FakeTensorMode():
            lhs = torch.randn(10, 3, 4).npu()
            rhs = torch.randn(10, 4, 5).npu()
            out = torch_npu.npu_bmmV2(lhs, rhs, [])

            self.assertEqual(out.dtype, lhs.dtype)
            self.assertEqual(out.shape, torch.matmul(lhs, rhs).shape)
class TestNpuDropout(TestCase):
    """Fake-tensor checks for ``torch_npu._npu_dropout``."""

    def test_npu_dropout(self):
        """Both outputs under fake mode are shaped like those of a run outside it."""
        # Reference call made before entering FakeTensorMode.
        ref_in = torch.randn(2, 3).npu()
        ref_in.requires_grad = True
        ref_out = torch_npu._npu_dropout(ref_in, 0.5)

        with FakeTensorMode():
            fake_in = torch.randn(2, 3).npu()
            fake_in.requires_grad = True
            fake_out = torch_npu._npu_dropout(fake_in, 0.5)

            for idx in (0, 1):
                self.assertTrue(fake_out[idx].shape == ref_out[idx].shape)

    def test_npu_dropout_backward(self):
        """Gradient reaching the input is shaped like the input."""
        with FakeTensorMode():
            inp = torch.randn(2, 3).npu()
            inp.requires_grad = True
            outputs = torch_npu._npu_dropout(inp, 0.5)
            loss = outputs[0].sum()
            loss.backward()
            self.assertTrue(inp.shape == inp.grad.shape)
class TestNpuDtypeCast(TestCase):
    """Fake-tensor checks for ``torch_npu.npu_dtype_cast``."""

    def test_npu_dtype_cast(self):
        """Casting float32 to float16 changes the dtype and keeps the shape."""
        with FakeTensorMode():
            src = torch.randn((2, 3), dtype=torch.float32).npu()
            target = torch.float16
            casted = torch_npu.npu_dtype_cast(src, target)

            self.assertEqual(casted.dtype, target)
            self.assertEqual(casted.shape, src.shape)

    def test_npu_dtype_cast_backward(self):
        """Backward through the cast yields a gradient shaped like the input."""
        with FakeTensorMode():
            src = torch.randn((2, 3), dtype=torch.float32).npu()
            src.requires_grad = True
            target = torch.float16
            casted = torch_npu.npu_dtype_cast(src, target)
            loss = casted.sum()
            loss.backward()

            self.assertEqual(casted.dtype, target)
            self.assertEqual(src.shape, src.grad.shape)
class TestNpuRotaryMul(TestCase):
    """Fake-tensor check for ``torch.ops.npu.npu_rotary_mul``."""

    def test_npu_rotary_mul(self):
        """Output shape and dtype equal those of the embedding operand."""
        with FakeTensorMode():
            emb = torch.randn(2, 8192, 5, 125, dtype=torch.float16, requires_grad=True).npu()
            cos = torch.randn(1, 8192, 1, 128, dtype=torch.float16, requires_grad=True).npu()
            sin = torch.randn(1, 8192, 1, 128, dtype=torch.float16, requires_grad=True).npu()
            out = torch.ops.npu.npu_rotary_mul(emb, cos, sin)

            self.assertEqual(emb.shape, out.shape)
            self.assertEqual(emb.dtype, out.dtype)
class TestNpuRotaryMulBackward(TestCase):
    """Fake-tensor check for ``torch.ops.npu.npu_rotary_mul_backward``."""

    def test_npu_rotary_mul_backward(self):
        """The first three results are shaped like embedding, cosine and sine."""
        with FakeTensorMode():
            grad_out = torch.randn(4, 2048, 40, 128, dtype=torch.float16).npu()
            emb = torch.randn(4, 2048, 40, 128, dtype=torch.float16).npu()
            cos = torch.randn(1, 2048, 1, 128, dtype=torch.float16).npu()
            sin = torch.randn(1, 2048, 1, 128, dtype=torch.float16).npu()
            grads = torch.ops.npu.npu_rotary_mul_backward(grad_out, emb, cos, sin)

            for idx, reference in enumerate((emb, cos, sin)):
                self.assertEqual(grads[idx].shape, reference.shape)
class TestNpuTranspose(TestCase):
    """Fake-tensor check for ``torch_npu.npu_transpose``."""

    def test_npu_transpose(self):
        """Result shape equals that of ``Tensor.permute`` with the same permutation."""
        with FakeTensorMode():
            src = torch.randn((5, 3, 6, 4)).npu()
            order = [1, 0, 2, 3]
            expected_shape = src.permute(order).shape
            transposed = torch_npu.npu_transpose(src, order)

            self.assertEqual(transposed.shape, expected_shape)
class TestPromptFlashAttention(TestCase):
    """Fake-tensor shape check for ``npu_prompt_flash_attention``."""

    def testPromptFlashAttention(self):
        """float16 q/k/v that require grad give an output shaped like the query."""
        with FakeTensorMode():
            q = torch.randn(1, 40, 1, 128, dtype=torch.float16).npu()
            k = torch.randn(1, 40, 1, 128, dtype=torch.float16).npu()
            v = torch.randn(1, 40, 1, 128, dtype=torch.float16).npu()
            for tensor in (q, k, v):
                tensor.requires_grad = True
            res = torch.ops.npu.npu_prompt_flash_attention(q, k, v)

            # assertEqual reports both shapes on failure, so no debug
            # printing is needed.
            self.assertEqual(q.shape, res.shape)
class TestFlashAttentionScore(TestCase):
    """Fake-tensor check for ``torch.ops.npu.npu_fusion_attention``."""

    def testFlashAttentionScore(self):
        """Result 0 matches the query; results 1 and 2 match a (1, 40, 16, 8) float32 tensor."""
        with FakeTensorMode():
            query = torch.randn(1, 40, 16, 128, dtype=torch.float16).npu()
            key = torch.randn(1, 40, 16, 128, dtype=torch.float16).npu()
            value = torch.randn(1, 40, 16, 128, dtype=torch.float16).npu()
            # Reference tensor carrying the shape/dtype expected for results 1 and 2.
            softmax_ref = torch.randn(1, 40, 16, 8, dtype=torch.float32).npu()
            for tensor in (query, key, value):
                tensor.requires_grad = True

            outs = torch.ops.npu.npu_fusion_attention(
                query, key, value, head_num=40, input_layout="BNSD"
            )

            self.assertEqual(query.shape, outs[0].shape)
            self.assertEqual(query.dtype, outs[0].dtype)
            for idx in (1, 2):
                self.assertEqual(softmax_ref.shape, outs[idx].shape)
                self.assertEqual(softmax_ref.dtype, outs[idx].dtype)
class TestFlashAttentionScoreGrad(TestCase):
    """Fake-tensor check for ``torch.ops.npu.npu_fusion_attention_grad``."""

    def testFlashAttentionScoreGrad(self):
        """Result 0 matches the query; results 1 and 2 match the key."""
        with FakeTensorMode():
            query = torch.randn(1, 40, 16, 128, dtype=torch.float16).npu()
            key = torch.randn(1, 40, 16, 128, dtype=torch.float16).npu()
            value = torch.randn(1, 40, 16, 128, dtype=torch.float16).npu()
            grad_out = torch.randn(1, 40, 16, 128, dtype=torch.float16).npu()
            attn_in = torch.randn(1, 40, 16, 128, dtype=torch.float16).npu()
            sm_max = torch.randn(1, 40, 16, 8, dtype=torch.float32).npu()
            sm_sum = torch.randn(1, 40, 16, 8, dtype=torch.float32).npu()
            for tensor in (query, key, value, grad_out, attn_in, sm_max, sm_sum):
                tensor.requires_grad = True

            outs = torch.ops.npu.npu_fusion_attention_grad(
                query,
                key,
                value,
                grad_out,
                head_num=40,
                input_layout="BNSD",
                softmax_max=sm_max,
                softmax_sum=sm_sum,
                attention_in=attn_in,
            )

            # Results 1 and 2 are both compared against the key tensor.
            for idx, reference in enumerate((query, key, key)):
                self.assertEqual(reference.shape, outs[idx].shape)
                self.assertEqual(reference.dtype, outs[idx].dtype)
class TestNpuMoeComputeExpertTokens(TestCase):
    """Fake-tensor check for ``npu_moe_compute_expert_tokens``."""

    def test_npu_moe_compute_expert_tokens(self):
        """Result shape and dtype equal those of the sorted expert tensor."""
        with FakeTensorMode():
            expert_ids = torch.tensor(list(range(10)), dtype=torch.int32).npu()
            sorted_ids = torch.sort(expert_ids)[0]
            expert_count = 10
            tokens = torch.ops.npu.npu_moe_compute_expert_tokens(sorted_ids, expert_count)

            self.assertEqual(sorted_ids.shape, tokens.shape)
            self.assertEqual(sorted_ids.dtype, tokens.dtype)
class TestMaskedSoftmaxWithRelPosBias(TestCase):
    """Fake-tensor shape check for ``npu_masked_softmax_with_rel_pos_bias``."""

    def testMaskedSoftmaxWithRelPosBias(self):
        """The result is shaped like ``x``."""
        with FakeTensorMode():
            x = torch.randn(96, 2, 2, 32, 32, dtype=torch.float)
            relative_pos_bias = torch.randn(1, 1, 2, 32, 32, dtype=torch.float)
            atten_mask = torch.randn(1, 2, 1, 32, 32, dtype=torch.float)
            for tensor in (x, atten_mask, relative_pos_bias):
                tensor.requires_grad = True
            res = torch.ops.npu.npu_masked_softmax_with_rel_pos_bias(
                x, atten_mask, relative_pos_bias
            )

            # assertEqual reports both shapes on failure, so no debug
            # printing is needed.
            self.assertEqual(x.shape, res.shape)
class TestNpuRopeQuantKVCache(TestCase):
    """Fake-tensor check for ``torch.ops.npu.npu_rope_quant_kvcache`` (currently skipped)."""

    @unittest.skip("skip test_npu_rope_quant_kvcache_meta now")
    def test_npu_rope_quant_kvcache_meta(self):
        """Check shape/dtype/device/type of the query and key-cache results."""
        with FakeTensorMode() as mode:

            def uniform_npu(shape, np_dtype, torch_dtype):
                # numpy uniform[0, 1) sample -> torch tensor of ``torch_dtype`` on NPU
                sample = np.random.uniform(0, 1, shape).astype(np_dtype)
                return torch.from_numpy(sample).to(torch_dtype).npu()

            in_x = uniform_npu([1, 1, 128 * 3], np.float16, torch.float16)
            in_cos = uniform_npu([1, 1, 1, 128], np.float16, torch.float16)
            in_sin = uniform_npu([1, 1, 1, 128], np.float16, torch.float16)
            in_k_cache = uniform_npu([1, 2, 1, 128], np.int8, torch.int8)
            in_v_cache = uniform_npu([1, 2, 1, 128], np.int8, torch.int8)
            in_indices = torch.tensor([0]).to(torch.int32).npu()
            in_scale_k = torch.randn([128], dtype=torch.float32).npu()
            in_scale_v = torch.randn([128], dtype=torch.float32).npu()
            size_splits = [128, 128, 128]

            (
                fake_x,
                fake_cos,
                fake_sin,
                fake_indices,
                fake_k_cache,
                fake_v_cache,
                fake_scale_k,
                fake_scale_v,
            ) = [
                mode.from_tensor(real)
                for real in (
                    in_x,
                    in_cos,
                    in_sin,
                    in_indices,
                    in_k_cache,
                    in_v_cache,
                    in_scale_k,
                    in_scale_v,
                )
            ]
            for converted in (
                fake_x,
                fake_cos,
                fake_sin,
                fake_k_cache,
                fake_v_cache,
                fake_scale_k,
                fake_scale_v,
            ):
                self.assertIsNotNone(converted)

            # Five results are unpacked; only the first and fourth are inspected.
            q_out, _k_out, _third_out, k_cache_out, _v_cache_out = (
                torch.ops.npu.npu_rope_quant_kvcache(
                    fake_x,
                    fake_cos,
                    fake_sin,
                    fake_k_cache,
                    fake_v_cache,
                    fake_indices,
                    fake_scale_k,
                    fake_scale_v,
                    size_splits,
                )
            )

            self.assertEqual(q_out.shape, torch.Size([1, 1, 1, 128]))
            self.assertEqual(q_out.dtype, in_x.dtype)
            self.assertEqual(q_out.device, in_x.device)
            self.assertTrue(isinstance(q_out, FakeTensor))
            self.assertEqual(k_cache_out.shape, in_k_cache.shape)
            self.assertEqual(k_cache_out.dtype, in_k_cache.dtype)
            self.assertEqual(k_cache_out.device, in_k_cache.device)
            self.assertTrue(isinstance(k_cache_out, FakeTensor))
class TestGeGlu(TestCase):
    """Fake-tensor check for ``torch.ops.npu.npu_geglu``."""

    # NOTE(review): this method name does not begin with lowercase ``test``, so
    # the default unittest loader never collects it. It is left unrenamed so
    # the set of executed tests stays unchanged; before enabling it, confirm
    # what ``npu_geglu`` returns (the assertion below assumes a single tensor).
    def TestGeGlu(self):
        """Call ``npu_geglu`` on a (2, 10, 1024) float tensor and compare shapes."""
        with FakeTensorMode():
            x = torch.randn(2, 10, 1024, dtype=torch.float)
            dim = -1
            approximate = 1
            activate_left = False
            # Two statements that set ``requires_grad`` on ``atten_mask`` and
            # ``relative_pos_bias`` were removed: neither name is defined in
            # this test, so they could only raise NameError.
            res = torch.ops.npu.npu_geglu(x, dim, approximate, activate_left)

            self.assertTrue(x.shape == res.shape)
class TestScatterUpdateMeta(TestCase):
    """Fake-tensor checks for ``scatter_update`` and its in-place variant."""

    def _fake_inputs(self, mode, dtype):
        """Build NPU inputs of ``dtype``, convert them with ``mode`` and sanity-check them.

        Returns ``(in_self, fake_self, fake_indices, fake_updates)``.
        """
        in_self = torch.randn(4, 4, 32, 256, dtype=dtype).npu()
        in_indices = torch.tensor([1, 1, 1, 1]).npu()
        in_updates = torch.randn(4, 4, 1, 256, dtype=dtype).npu()
        fake_self, fake_indices, fake_updates = (
            mode.from_tensor(in_self),
            mode.from_tensor(in_indices),
            mode.from_tensor(in_updates),
        )
        for converted in (fake_self, fake_indices, fake_updates):
            self.assertIsNotNone(converted)
        return in_self, fake_self, fake_indices, fake_updates

    def _check_matches_input(self, fake_result, in_self):
        """Assert that the result mirrors ``in_self`` and is a FakeTensor."""
        self.assertEqual(fake_result.shape, in_self.shape)
        self.assertEqual(fake_result.dtype, in_self.dtype)
        self.assertEqual(fake_result.device, in_self.device)
        self.assertTrue(isinstance(fake_result, FakeTensor))

    def test_scatter_update_meta(self):
        """Out-of-place op: the result is a new object, distinct from both inputs."""
        with FakeTensorMode() as mode:
            in_self, fake_self, fake_indices, fake_updates = self._fake_inputs(
                mode, torch.float16
            )
            fake_result = torch.ops.npu.scatter_update(
                fake_self, fake_indices, fake_updates, -2
            )

            self._check_matches_input(fake_result, in_self)
            self.assertIsNot(fake_result, fake_self)
            self.assertIsNot(fake_result, in_self)

    def test_scatter_update__meta(self):
        """In-place op: the result is the very fake tensor that was passed in."""
        with FakeTensorMode() as mode:
            in_self, fake_self, fake_indices, fake_updates = self._fake_inputs(
                mode, torch.float32
            )
            fake_result = torch.ops.npu.scatter_update_(
                fake_self, fake_indices, fake_updates, -2
            )

            self._check_matches_input(fake_result, in_self)
            self.assertIs(fake_result, fake_self)
            self.assertIsNot(fake_result, in_self)
class TestNpuQuantScatterMeta(TestCase):
    """Fake-tensor checks for ``npu_quant_scatter`` and its in-place variant."""

    def _fake_inputs(self, mode):
        """Build the NPU inputs, convert them with ``mode`` and sanity-check them.

        Returns ``(in_var, fake_var, fake_indices, fake_updates, fake_quant_scales)``.
        """

        def uniform_npu(low, high, shape, np_dtype, torch_dtype):
            # numpy uniform sample -> torch tensor of ``torch_dtype`` on NPU
            sample = np.random.uniform(low, high, shape).astype(np_dtype)
            return torch.from_numpy(sample).to(torch_dtype).npu()

        in_var = uniform_npu(0, 1, [1, 1, 32], np.int8, torch.int8)
        in_indices = uniform_npu(0, 1, [1], np.int32, torch.int32)
        in_updates = uniform_npu(1, 2, [1, 1, 32], np.float16, torch.bfloat16)
        in_quant_scales = uniform_npu(0, 1, [1, 1, 32], np.float16, torch.bfloat16)

        fakes = [
            mode.from_tensor(real)
            for real in (in_var, in_indices, in_updates, in_quant_scales)
        ]
        for converted in fakes:
            self.assertIsNotNone(converted)
        fake_var, fake_indices, fake_updates, fake_quant_scales = fakes
        return in_var, fake_var, fake_indices, fake_updates, fake_quant_scales

    def _check_matches_input(self, fake_result, in_var):
        """Assert that the result mirrors ``in_var`` and is a FakeTensor."""
        self.assertEqual(fake_result.shape, in_var.shape)
        self.assertEqual(fake_result.dtype, in_var.dtype)
        self.assertEqual(fake_result.device, in_var.device)
        self.assertTrue(isinstance(fake_result, FakeTensor))

    def test_npu_quant_scatter_meta(self):
        """Out-of-place op: the result is a new object, distinct from both inputs."""
        with FakeTensorMode() as mode:
            in_var, fake_var, fake_indices, fake_updates, fake_quant_scales = (
                self._fake_inputs(mode)
            )
            fake_result = torch.ops.npu.npu_quant_scatter(
                fake_var, fake_indices, fake_updates, fake_quant_scales, None, -2, -1, "update"
            )

            self._check_matches_input(fake_result, in_var)
            self.assertIsNot(fake_result, fake_var)
            self.assertIsNot(fake_result, in_var)

    def test_npu_quant_scatter__meta(self):
        """In-place op: the result is the very fake tensor that was passed in."""
        with FakeTensorMode() as mode:
            in_var, fake_var, fake_indices, fake_updates, fake_quant_scales = (
                self._fake_inputs(mode)
            )
            fake_result = torch.ops.npu.npu_quant_scatter_(
                fake_var, fake_indices, fake_updates, fake_quant_scales, None, -2, -1, "update"
            )

            self._check_matches_input(fake_result, in_var)
            self.assertIs(fake_result, fake_var)
            self.assertIsNot(fake_result, in_var)
class TestNpuApplyRotoryPosEmbMeta(TestCase):
    """Fake-tensor check for ``torch.ops.npu.npu_apply_rotary_pos_emb``."""

    def test_npu_apply_rotary_pos_emb_meta(self):
        """Query/key results mirror their inputs' shape, dtype and device and are FakeTensors."""
        with FakeTensorMode() as mode:

            def uniform_fp16_npu(shape):
                # numpy uniform[0, 1) float16 sample -> float16 torch tensor on NPU
                sample = np.random.uniform(0, 1, shape).astype(np.float16)
                return torch.from_numpy(sample).to(torch.float16).npu()

            data_query = uniform_fp16_npu([4, 1024, 16, 128])
            data_key = uniform_fp16_npu([4, 1024, 16, 128])
            data_cos = uniform_fp16_npu([4, 1024, 1, 128])
            data_sin = uniform_fp16_npu([4, 1024, 1, 128])

            fakes = [
                mode.from_tensor(real)
                for real in (data_query, data_key, data_cos, data_sin)
            ]
            for converted in fakes:
                self.assertIsNotNone(converted)
            fake_query, fake_key, fake_cos, fake_sin = fakes

            query_out, key_out = torch.ops.npu.npu_apply_rotary_pos_emb(
                fake_query, fake_key, fake_cos, fake_sin, "BSH"
            )

            for produced, source in ((query_out, data_query), (key_out, data_key)):
                self.assertEqual(produced.shape, source.shape)
                self.assertEqual(produced.dtype, source.dtype)
                self.assertEqual(produced.device, source.device)
            for produced in (query_out, key_out):
                self.assertTrue(isinstance(produced, FakeTensor))
class TestMmAllReduce(TestCase):
    """Fake-tensor shape/dtype checks for ``torch_npu.npu_mm_all_reduce_base``."""

    def test_mm_all_reduce(self):
        """float16 operands give a (128, 128) float16 result."""
        with FakeTensorMode():
            dst_dtype = torch.float16
            x1 = torch.randn(128, 256, dtype=torch.float16).npu()
            x2 = torch.randn(256, 128, dtype=torch.float16).npu()
            hcom = "fake group info"
            output = torch_npu.npu_mm_all_reduce_base(x1, x2, hcom, reduce_op="sum")

            self.assertEqual(output.shape, (128, 128))
            self.assertEqual(output.dtype, dst_dtype)

    def test_mm_all_reduce_quant(self):
        """int8 operands with an int64 dequant scale are expected to give float16."""
        with FakeTensorMode():
            dst_dtype = torch.float16
            x1 = torch.randn(128, 256, dtype=torch.float16).to(torch.int8).npu()
            x2 = torch.randn(256, 128, dtype=torch.float16).to(torch.int8).npu()
            dequant = torch.randn(128, dtype=torch.float16).to(torch.int64).npu()
            hcom = "fake group info"
            output = torch_npu.npu_mm_all_reduce_base(
                x1, x2, hcom, reduce_op="sum", dequant_scale=dequant
            )

            self.assertEqual(output.shape, (128, 128))
            self.assertEqual(output.dtype, dst_dtype)

    # This variant used to be named ``test_mm_all_reduce_quant`` as well, which
    # replaced the int64-dequant test above in the class namespace so that it
    # never ran. A distinct name lets both be collected.
    def test_mm_all_reduce_quant_bf16(self):
        """int8 operands with a bfloat16 dequant scale are expected to give bfloat16."""
        with FakeTensorMode():
            dst_dtype = torch.bfloat16
            x1 = torch.randn(128, 256, dtype=torch.float16).to(torch.int8).npu()
            x2 = torch.randn(256, 128, dtype=torch.float16).to(torch.int8).npu()
            dequant = torch.randn(128, dtype=torch.bfloat16).npu()
            hcom = "fake group info"
            output = torch_npu.npu_mm_all_reduce_base(
                x1, x2, hcom, reduce_op="sum", dequant_scale=dequant
            )

            self.assertEqual(output.shape, (128, 128))
            self.assertEqual(output.dtype, dst_dtype)
class TestNpuDeepNorm(TestCase):
    """Fake-tensor check for ``torch_npu.npu_deep_norm``."""

    def test_npu_deep_norm(self):
        """``y`` mirrors ``x``; mean and rstd are (2, 1) tensors on the same device."""
        with FakeTensorMode():
            x = torch.randn((2, 3), dtype=torch.float32).npu()
            gx = torch.randn((2, 3), dtype=torch.float32).npu()
            beta = torch.randn((3,), dtype=torch.float32).npu()
            gamma = torch.randn((3,), dtype=torch.float32).npu()
            mean, rstd, y = torch_npu.npu_deep_norm(x, gx, beta, gamma)

            self.assertEqual(y.dtype, x.dtype)
            self.assertEqual(y.shape, x.shape)
            self.assertEqual(y.device, x.device)
            for stat in (rstd, mean):
                self.assertEqual(stat.shape, torch.Size([2, 1]))
                self.assertEqual(stat.device, x.device)
class TestNpuRmsNorm(TestCase):
    """Fake-tensor checks for ``torch_npu.npu_rms_norm``."""

    def test_npu_rms_norm(self):
        """``y`` mirrors ``x``; rstd is a (2, 1) tensor on the same device."""
        with FakeTensorMode():
            x = torch.randn((2, 3), dtype=torch.float32).npu()
            gamma = torch.randn((3,), dtype=torch.float32).npu()
            y, rstd = torch_npu.npu_rms_norm(x, gamma)

            self.assertEqual(y.dtype, x.dtype)
            self.assertEqual(y.shape, x.shape)
            self.assertEqual(y.device, x.device)
            self.assertEqual(rstd.shape, torch.Size([2, 1]))
            self.assertEqual(rstd.device, x.device)

    def test_npu_rms_norm_backward(self):
        """Gradients for ``x`` and ``gamma`` mirror their respective inputs."""
        with FakeTensorMode():
            x = torch.randn((2, 3), dtype=torch.float32).npu()
            gamma = torch.randn((3,), dtype=torch.float32).npu()
            for leaf in (x, gamma):
                leaf.requires_grad = True

            y = torch_npu.npu_rms_norm(x, gamma)[0]
            upstream_grad = torch.randn((2, 3), dtype=torch.float32).npu()
            y.backward(upstream_grad)
            grad_x, grad_gamma = x.grad, gamma.grad

            self.assertEqual(grad_x.dtype, x.dtype)
            self.assertEqual(grad_x.shape, x.shape)
            self.assertEqual(grad_x.device, x.device)
            self.assertEqual(grad_gamma.shape, gamma.shape)
            self.assertEqual(grad_gamma.device, gamma.device)
class TestNpuAddRmsNorm(TestCase):
    """Fake-tensor check for ``torch_npu.npu_add_rms_norm``."""

    def test_npu_add_rms_norm(self):
        """First and third results mirror ``x1``; rstd is (2, 1) on the same device."""
        with FakeTensorMode():
            x1 = torch.randn((2, 3), dtype=torch.float32).npu()
            x2 = torch.randn((2, 3), dtype=torch.float32).npu()
            gamma = torch.randn((3,), dtype=torch.float32).npu()
            y, rstd, x_out = torch_npu.npu_add_rms_norm(x1, x2, gamma)

            def check_mirrors_x1(produced):
                self.assertEqual(produced.dtype, x1.dtype)
                self.assertEqual(produced.shape, x1.shape)
                self.assertEqual(produced.device, x1.device)

            check_mirrors_x1(y)
            self.assertEqual(rstd.shape, torch.Size([2, 1]))
            self.assertEqual(rstd.device, x1.device)
            check_mirrors_x1(x_out)
class TestFFN(TestCase):
    """Fake-tensor shape check for ``torch_npu.npu_ffn``."""

    def test_npu_ffn_meta(self):
        """Result is shaped like the input ``x``."""
        with FakeTensorMode():
            x = torch.randn(1, 320, dtype=torch.float16).npu()
            weight_in = torch.randn(320, 2560, dtype=torch.float16).npu()
            weight_out = torch.randn(2560, 320, dtype=torch.float16).npu()
            out = torch_npu.npu_ffn(
                x, weight_in, weight_out, "gelu", inner_precise=1, output_dtype=torch.float16
            )

            self.assertTrue(x.shape == out.shape)
class TestNpuDynamicQuant(TestCase):
    """Fake-tensor check for ``torch_npu.npu_dynamic_quant``."""

    def test_npu_dynamic_quant(self):
        """Both results mirror hand-built reference tensors in dtype, shape and device."""
        with FakeTensorMode():
            inp = torch.randn((4, 2048, 1024)).npu().to(torch.float16)
            smooth = torch.randn((1024)).npu().to(torch.float16)
            # Reference tensors carrying the expected metadata of each result.
            want_output = torch.randn((4, 2048, 1024)).npu().to(torch.int8)
            want_scale = torch.randn((4, 2048)).npu().to(torch.float32)

            got_output, got_scale = torch_npu.npu_dynamic_quant(inp, smooth_scales=smooth)

            for got, want in ((got_output, want_output), (got_scale, want_scale)):
                self.assertEqual(got.dtype, want.dtype)
                self.assertEqual(got.shape, want.shape)
                self.assertEqual(got.device, want.device)
class TestDynamicQuantAsymmetric(TestCase):
    """Fake-tensor check for ``torch_npu.npu_dynamic_quant_asymmetric``."""

    def test_npu_dynamic_quant_asymmetric(self):
        """All three results mirror hand-built reference tensors in dtype, shape and device."""
        with FakeTensorMode():
            inp = torch.randn((4, 2048, 1024)).npu().to(torch.float16)
            smooth = torch.randn((1024)).npu().to(torch.float16)
            # Reference tensors carrying the expected metadata of each result.
            want_output = torch.randn((4, 2048, 1024)).npu().to(torch.int8)
            want_scale = torch.randn((4, 2048)).npu().to(torch.float32)
            want_offset = torch.randn((4, 2048)).npu().to(torch.float32)

            results = torch_npu.npu_dynamic_quant_asymmetric(inp, smooth_scales=smooth)
            got_output, got_scale, got_offset = results

            pairs = (
                (got_output, want_output),
                (got_scale, want_scale),
                (got_offset, want_offset),
            )
            for got, want in pairs:
                self.assertEqual(got.dtype, want.dtype)
                self.assertEqual(got.shape, want.shape)
                self.assertEqual(got.device, want.device)
class TestGroupedMatmul(TestCase):
    """Fake-tensor shape checks for ``torch_npu.npu_grouped_matmul``."""

    def test_npu_grouped_matmul_meta_0(self):
        """split_item=0, three x / three w: result i has x[i]'s rows and w[i]'s columns."""
        with FakeTensorMode():
            torch.manual_seed(0)
            x = [
                torch.randn(rows, cols, dtype=torch.float16).npu()
                for rows, cols in ((256, 256), (1024, 256), (512, 1024))
            ]
            w = [
                torch.randn(rows, cols, dtype=torch.float16).npu()
                for rows, cols in ((256, 256), (256, 1024), (1024, 128))
            ]
            b = [torch.randn(width, dtype=torch.float16).npu() for width in (256, 1024, 128)]

            res = torch_npu.npu_grouped_matmul(
                x, w, bias=b, group_list=None, split_item=0, group_type=-1
            )

            for idx in range(3):
                self.assertTrue(x[idx].shape[0] == res[idx].shape[0])
            for idx in range(3):
                self.assertTrue(w[idx].shape[1] == res[idx].shape[1])

    def test_npu_grouped_matmul_meta_1(self):
        """split_item=1, single x: result rows follow consecutive ``group_list`` differences."""
        with FakeTensorMode():
            torch.manual_seed(0)
            x = [torch.randn(1792, 1024, dtype=torch.float16).npu()]
            w = [
                torch.randn(1024, cols, dtype=torch.float16).npu()
                for cols in (256, 1024, 128)
            ]
            b = [torch.randn(width, dtype=torch.float16).npu() for width in (256, 1024, 128)]
            group_list = [256, 1280, 1792]

            res = torch_npu.npu_grouped_matmul(
                x, w, bias=b, group_list=group_list, split_item=1, group_type=-1
            )

            previous = 0
            for idx, boundary in enumerate(group_list):
                self.assertTrue(boundary - previous == res[idx].shape[0])
                previous = boundary
            for idx in range(3):
                self.assertTrue(w[idx].shape[1] == res[idx].shape[1])

    def test_npu_grouped_matmul_meta_2(self):
        """split_item=2, three x: the single result stacks the rows of every x."""
        with FakeTensorMode():
            torch.manual_seed(0)
            x = [
                torch.randn(rows, cols, dtype=torch.float16).npu()
                for rows, cols in ((256, 256), (1024, 256), (512, 1024))
            ]
            w = [
                torch.randn(rows, cols, dtype=torch.float16).npu()
                for rows, cols in ((256, 256), (256, 256), (1024, 256))
            ]
            b = [torch.randn(256, dtype=torch.float16).npu() for _ in range(3)]

            res = torch_npu.npu_grouped_matmul(
                x, w, bias=b, group_list=None, split_item=2, group_type=0
            )

            total_rows = sum(xi.shape[0] for xi in x)
            self.assertTrue(total_rows == res[0].shape[0])
            self.assertTrue(w[0].shape[1] == res[0].shape[1])

    def test_npu_grouped_matmul_meta_3(self):
        """split_item=3, single x and three w: the single result keeps x's rows."""
        with FakeTensorMode():
            torch.manual_seed(0)
            x = [torch.randn(1792, 1024, dtype=torch.float16).npu()]
            w = [torch.randn(1024, 256, dtype=torch.float16).npu() for _ in range(3)]
            b = [torch.randn(256, dtype=torch.float16).npu() for _ in range(3)]
            group_list = [256, 1280, 1792]

            res = torch_npu.npu_grouped_matmul(
                x, w, bias=b, group_list=group_list, split_item=3, group_type=0
            )

            self.assertTrue(x[0].shape[0] == res[0].shape[0])
            self.assertTrue(w[0].shape[1] == res[0].shape[1])

    def test_npu_grouped_matmul_meta_4(self):
        """split_item=3 with one 3-D weight: result columns come from the weight's last dim."""
        with FakeTensorMode():
            torch.manual_seed(0)
            x = [torch.randn(1792, 1024, dtype=torch.float16).npu()]
            w = [torch.randn(3, 1024, 256, dtype=torch.float16).npu()]
            b = [torch.randn(256, dtype=torch.float16).npu() for _ in range(3)]
            group_list = [256, 1280, 1792]

            res = torch_npu.npu_grouped_matmul(
                x, w, bias=b, group_list=group_list, split_item=3, group_type=0
            )

            self.assertTrue(x[0].shape[0] == res[0].shape[0])
            self.assertTrue(w[0].shape[2] == res[0].shape[1])

    def test_npu_grouped_matmul_meta_5(self):
        """group_type=2 with four weights and a tensor group_list: the result is 3-D."""
        with FakeTensorMode():
            torch.manual_seed(0)
            x_base = torch.randn(256, 1792, dtype=torch.float16).npu()
            x = [x_base.t()]
            w = [
                torch.randn(rows, 256, dtype=torch.float16).npu()
                for rows in (256, 256, 512, 768)
            ]
            group_list = torch.tensor([256, 512, 1024, 1792]).npu()

            res = torch_npu.npu_grouped_matmul(
                x,
                w,
                bias=None,
                group_list=group_list,
                split_item=2,
                group_type=2,
                group_list_type=0,
            )

            first = res[0]
            self.assertTrue(len(first.shape) == 3)
            self.assertTrue(group_list.shape[0] == first.shape[0])
            self.assertTrue(x_base.shape[1] == first.shape[1])
            self.assertTrue(w[0].shape[1] == first.shape[2])

    def test_npu_grouped_matmul_meta_6(self):
        """group_type=2 with a single weight and a tensor group_list: the result is 3-D."""
        with FakeTensorMode():
            torch.manual_seed(0)
            x_base = torch.randn(256, 1792, dtype=torch.float16).npu()
            x = [x_base.t()]
            weight = torch.randn(1792, 256, dtype=torch.float16).npu()
            w = [weight]
            group_list = torch.tensor([256, 512, 1024, 1792]).npu()

            res = torch_npu.npu_grouped_matmul(
                x,
                w,
                bias=None,
                group_list=group_list,
                split_item=3,
                group_type=2,
                group_list_type=0,
            )

            first = res[0]
            self.assertTrue(len(first.shape) == 3)
            self.assertTrue(group_list.shape[0] == first.shape[0])
            self.assertTrue(x_base.shape[1] == first.shape[1])
            self.assertTrue(weight.shape[1] == first.shape[2])

    def test_npu_grouped_matmul_meta_invalid_group_type(self):
        """group_type values 1, 3 and 100 each raise RuntimeError."""
        with FakeTensorMode():
            torch.manual_seed(0)
            x = [torch.randn(256, 256, dtype=torch.float16).npu()]
            w = [torch.randn(256, 256, dtype=torch.float16).npu()]
            group_list = torch.tensor([256]).npu()

            for bad_group_type in (1, 3, 100):
                with self.assertRaisesRegex(RuntimeError, "group_type only supports"):
                    torch_npu.npu_grouped_matmul(
                        x, w, group_list=group_list, split_item=2, group_type=bad_group_type
                    )
class TestQuantMatmul(TestCase):
    """Fake-tensor shape/dtype checks for ``torch_npu.npu_quant_matmul``."""

    def test_npu_quant_matmul_meta(self):
        """Run five input/output dtype combinations and compare against reference tensors."""
        with FakeTensorMode():
            # Case 1: int8 operands, no output_dtype given; int8 result expected.
            lhs = torch.randint(-1, 1, (1, 1, 1024), dtype=torch.int8).npu()
            rhs = torch.randint(-1, 1, (1, 1024, 100), dtype=torch.int8).npu()
            want_int8 = torch.randint(-1, 1, (1, 1, 100), dtype=torch.int8).npu()
            scale_fp32 = torch.randn(1, dtype=torch.float32).npu()
            offset_fp32 = torch.randn(1, dtype=torch.float32).npu()
            bias_case1 = torch.randint(-1, 1, (1, 1, 100), dtype=torch.int32).npu()
            got_int8 = torch_npu.npu_quant_matmul(
                lhs, rhs, scale_fp32, offset=offset_fp32, bias=bias_case1
            )
            self.assertTrue(want_int8.shape == got_int8.shape)
            self.assertTrue(want_int8.dtype == got_int8.dtype)

            # Case 2: bfloat16 scale and bias, output_dtype=bfloat16.
            want_bf16 = torch.randint(-1, 1, (1, 1, 100), dtype=torch.bfloat16).npu()
            scale_bf16 = torch.randn(1, dtype=torch.bfloat16).npu()
            bias_bf16 = torch.randint(-1, 1, (1, 1, 100), dtype=torch.bfloat16).npu()
            got_bf16 = torch_npu.npu_quant_matmul(
                lhs, rhs, scale_bf16, offset=None, bias=bias_bf16, output_dtype=torch.bfloat16
            )
            self.assertTrue(want_bf16.shape == got_bf16.shape)
            self.assertTrue(want_bf16.dtype == got_bf16.dtype)

            # Case 3: per-token scale with a float32 bias, output_dtype=float16.
            want_fp16 = torch.randint(-1, 1, (1, 1, 100), dtype=torch.float16).npu()
            bias_fp32 = torch.randint(-1, 1, (1, 1, 100), dtype=torch.float32).npu()
            pertoken_scale = torch.randn(1, dtype=torch.float32).npu()
            got_fp16 = torch_npu.npu_quant_matmul(
                lhs,
                rhs,
                scale_fp32,
                offset=None,
                pertoken_scale=pertoken_scale,
                bias=bias_fp32,
                output_dtype=torch.float16,
            )
            self.assertTrue(want_fp16.shape == got_fp16.shape)
            self.assertTrue(want_fp16.dtype == got_fp16.dtype)

            # Case 4: int32 bias, output_dtype=int32.
            want_int32 = torch.randint(-1, 1, (1, 1, 100), dtype=torch.int32).npu()
            bias_int32 = torch.randint(-1, 1, (1, 1, 100), dtype=torch.int32).npu()
            got_int32 = torch_npu.npu_quant_matmul(
                lhs,
                rhs,
                scale_fp32,
                offset=None,
                pertoken_scale=None,
                bias=bias_int32,
                output_dtype=torch.int32,
            )
            self.assertTrue(want_int32.shape == got_int32.shape)
            self.assertTrue(want_int32.dtype == got_int32.dtype)

            # Case 5: int32 operands of shape (16, 8) and (64, 5); a (16, 40)
            # float16 result is expected.
            lhs_i32 = torch.randint(-1, 1, (16, 8), dtype=torch.int32).npu()
            rhs_i32 = torch.randint(-1, 1, (64, 5), dtype=torch.int32).npu()
            want_last = torch.randint(-1, 1, (16, 40), dtype=torch.float16).npu()
            scale_last = torch.randn(1, dtype=torch.float32).npu()
            bias_last = torch.randint(-1, 1, (40,), dtype=torch.int32).npu()
            got_last = torch_npu.npu_quant_matmul(
                lhs_i32,
                rhs_i32,
                scale_last,
                offset=None,
                bias=bias_last,
                output_dtype=torch.float16,
            )
            self.assertTrue(want_last.shape == got_last.shape)
            self.assertTrue(want_last.dtype == got_last.dtype)
class TestTranQuantParam(TestCase):
    """Fake-tensor shape/dtype checks for ``torch_npu.npu_trans_quant_param``."""

    def test_npu_trans_quant_param_meta(self):
        """Compare the op's output shape and dtype against reference int64 tensors."""
        with FakeTensorMode():
            # Case 1: single-element scale, 1-D offset of length 4 -> expected (4,) int64.
            test_1_expect_ret = torch.randint(-1, 1, (4,), dtype=torch.int64).npu()
            test_1_scale = torch.randn(1, dtype=torch.float32).npu()
            test_1_offset = torch.randn(4, dtype=torch.float32).npu()
            res = torch_npu.npu_trans_quant_param(test_1_scale, test_1_offset)
            # assertEqual (rather than assertTrue(a == b)) reports both values on failure.
            self.assertEqual(res.shape, test_1_expect_ret.shape)
            self.assertEqual(res.dtype, test_1_expect_ret.dtype)

            # Case 2: 2-D scale and offset of shape (1, 4) -> expected (1, 4) int64.
            test_2_expect_ret = torch.randint(-1, 1, (1, 4), dtype=torch.int64).npu()
            test_2_scale = torch.randn(1, 4, dtype=torch.float32).npu()
            test_2_offset = torch.randn(1, 4, dtype=torch.float32).npu()
            res = torch_npu.npu_trans_quant_param(test_2_scale, test_2_offset)
            self.assertEqual(res.shape, test_2_expect_ret.shape)
            self.assertEqual(res.dtype, test_2_expect_ret.dtype)
class TestAntiQuant(TestCase):
    """Fake-tensor shape and storage-size checks for ``torch_npu.npu_anti_quant``."""

    # NOTE(review): this skips on every torch version whose string does not compare
    # equal to "2.1.0"; confirm that builds with a local version suffix are meant
    # to be skipped as well.
    @unittest.skipIf(torch.__version__ != "2.1.0",
                     "OP `AntiQuant` is only supported on torch v2.1.0, skip this ut for this torch version")
    def test_npu_anti_quant_meta(self):
        """The output keeps the input shape and has the byte size of the target dtype."""
        with FakeTensorMode():
            x = torch.randint(low=-128, high=127, size=(20, 100), dtype=torch.int8).npu()
            scale = torch.randn(100, dtype=torch.float).npu()
            offset = torch.randn(100, dtype=torch.float).npu()
            dst_dtype = torch.float16
            res = torch_npu.npu_anti_quant(x, scale, offset=offset, dst_dtype=dst_dtype)
            # assertEqual (rather than assertTrue(a == b)) reports both values on failure.
            self.assertEqual(x.shape, res.shape)

            # Cast the input to the destination dtype so the total byte counts
            # (element count * element size) can be compared directly.
            x = x.to(dst_dtype)
            self.assertEqual(x.numel() * x.element_size(), res.numel() * res.element_size())
class TestNpuLinear(TestCase):
    """Fake-tensor dtype/shape checks for ``torch_npu.npu_linear``."""

    def test_npu_linear_meta(self):
        """``npu_linear`` should match ``torch.nn.functional.linear`` in output shape
        and keep the dtype of its input."""
        with FakeTensorMode():
            activations = torch.randn(16, 128).npu()
            weight = torch.randn(32, 128).npu()
            bias = torch.randn(32).npu()

            result = torch_npu.npu_linear(activations, weight, bias)

            self.assertEqual(result.dtype, activations.dtype)
            # The stock functional linear serves as the shape reference.
            reference = torch.nn.functional.linear(activations, weight, bias)
            self.assertEqual(result.shape, reference.shape)
class TestMoeFinalizeRouting(TestCase):
    """Fake-tensor shape/dtype checks for ``torch_npu.npu_moe_finalize_routing``."""

    def test_npu_moe_finalize_routing_meta(self):
        """The result should have the same shape and dtype as ``skip1``."""
        with FakeTensorMode():
            num_rows = 50
            top_k = 4
            token_len = 10
            expert_num = 16

            # Inputs are built without an explicit `.npu()` transfer, as in the original test.
            expanded_permuted_rows = torch.randn(num_rows * top_k, token_len).to(torch.float32)
            skip1 = torch.randn(num_rows, token_len).to(torch.float32)
            skip2_optional = torch.randn(num_rows, token_len).to(torch.float32)
            bias = torch.randn(num_rows, top_k).to(torch.float32)
            scales = torch.randn(num_rows, top_k).to(torch.float32)
            expanded_src_to_dst_row = torch.arange(num_rows * top_k).to(torch.int32)
            expert_for_source_row = torch.randint(low=0, high=expert_num, size=(num_rows, top_k)).to(torch.int32)

            result = torch_npu.npu_moe_finalize_routing(expanded_permuted_rows, skip1, skip2_optional, bias, scales,
                                                        expanded_src_to_dst_row, expert_for_source_row)

            # assertEqual (rather than assertTrue(a == b)) reports both values on failure.
            self.assertEqual(result.shape, skip1.shape)
            self.assertEqual(result.dtype, skip1.dtype)
class TestNpuPrefetch(TestCase):
    """Argument-validation checks for ``torch_npu.npu_prefetch`` under fake tensors."""

    def test_npu_prefetch(self):
        """Negative size/offset arguments must raise ``RuntimeError`` with exact messages."""
        with FakeTensorMode():
            tensor = torch.randn(8, 8).npu()

            # Third positional argument of -1: expect the max_size complaint.
            with self.assertRaises(RuntimeError) as max_size_ctx:
                torch_npu.npu_prefetch(tensor, None, -1)
            self.assertEqual(
                str(max_size_ctx.exception),
                "The max_size should be greater than zero, but got -1.",
            )

            # Fourth positional argument of -1: expect the offset complaint.
            with self.assertRaises(RuntimeError) as offset_ctx:
                torch_npu.npu_prefetch(tensor, None, 10, -1)
            self.assertEqual(
                str(offset_ctx.exception),
                "The offset should be nonnegative, but got -1.",
            )
# Register the generated test variants at import time: parametrized cases for
# FakeTensorTest, and device-specific cases for FakeTensorOpInfoTest (restricted
# to "cpu" and injected into this module's globals).
instantiate_parametrized_tests(FakeTensorTest)
instantiate_device_type_tests(FakeTensorOpInfoTest, globals(), only_for="cpu")

# Run the suite only when this file is executed directly as a script.
if __name__ == "__main__":
    run_tests()