import torch
from torch.testing._internal.common_utils import run_tests
from testutils import TestUtils
import torch_npu
class Test_issue54(TestUtils):
    """Compare eager and inductor-compiled results of a captured aten graph.

    The graph in ``func_layernorm`` is a fixed sequence of
    ``torch.ops.aten`` calls with hard-coded shapes. ``test_issue54`` runs
    it once eagerly and once through ``torch.compile`` and checks that both
    outputs agree within ``atol=rtol=1e-2``.

    NOTE(review): judging by the class name this reproduces "issue 54";
    the op pattern looks like the Q/K/V projection part of a 4-head
    attention block rather than a layer norm -- confirm against the issue.
    """

    def func_layernorm(self, args):
        """Execute the captured graph and return ``(bmm, view_13)``.

        Args:
            args: 8-tuple ``(add_3, primals_6, primals_7, view, primals_9,
                permute_1, primals_10, primals_11)``. ``add_3`` is unpacked
                but not used by any op below.

        Returns:
            A 2-tuple: the ``[256, 512, 512]`` batched matmul of the first
            two projections, and the third projection reshaped to
            ``[256, 512, 64]``.
        """
        (
            _add_3,
            weight_0,
            bias_0,
            flat_input,
            bias_1,
            weight_1_t,
            weight_2,
            bias_2,
        ) = args

        aten = torch.ops.aten

        # Shapes used repeatedly by the view/expand calls below.
        shape_3d = [64, 512, 256]
        shape_4d = [64, 512, 4, 64]
        shape_4d_swapped = [64, 4, 512, 64]
        shape_4d_transposed = [64, 4, 64, 512]
        swap_axes_1_2 = [0, 2, 1, 3]

        # Projection 0: addmm against the transposed [256, 256] weight.
        weight_0_t = aten.permute.default(weight_0, [1, 0])
        proj_0 = aten.addmm.default(bias_0, flat_input, weight_0_t)  # [32768, 256]
        proj_0_3d = aten.view.default(proj_0, shape_3d)

        # Projection 1: its weight arrives already permuted (``permute_1``).
        proj_1 = aten.addmm.default(bias_1, flat_input, weight_1_t)  # [32768, 256]
        proj_1_3d = aten.view.default(proj_1, shape_3d)
        proj_1_4d = aten.view.default(proj_1_3d, shape_4d)
        proj_1_swapped = aten.permute.default(proj_1_4d, swap_axes_1_2)  # [64, 4, 512, 64]

        # Projection 2.
        weight_2_t = aten.permute.default(weight_2, [1, 0])
        proj_2 = aten.addmm.default(bias_2, flat_input, weight_2_t)  # [32768, 256]
        proj_2_3d = aten.view.default(proj_2, shape_3d)

        # Left operand of the bmm: projection 0 as [256, 512, 64].
        proj_0_4d = aten.view.default(proj_0_3d, shape_4d)
        proj_0_swapped = aten.permute.default(proj_0_4d, swap_axes_1_2)
        # Right operand source: projection 1 with its last two axes exchanged.
        proj_1_transposed = aten.permute.default(proj_1_swapped, [0, 1, 3, 2])  # [64, 4, 64, 512]
        lhs_expanded = aten.expand.default(proj_0_swapped, shape_4d_swapped)
        lhs_contig = aten.clone.default(
            lhs_expanded, memory_format=torch.contiguous_format
        )
        lhs = aten.view.default(lhs_contig, [256, 512, 64])

        # Right operand of the bmm as [256, 64, 512].
        rhs_expanded = aten.expand.default(proj_1_transposed, shape_4d_transposed)
        rhs_contig = aten.clone.default(
            rhs_expanded, memory_format=torch.contiguous_format
        )
        rhs = aten.view.default(rhs_contig, [256, 64, 512])

        product = aten.bmm.default(lhs, rhs)  # [256, 512, 512]

        # Second output: projection 2 laid out as [256, 512, 64].
        proj_2_4d = aten.view.default(proj_2_3d, shape_4d)
        proj_2_swapped = aten.permute.default(proj_2_4d, swap_axes_1_2)
        proj_2_expanded = aten.expand.default(proj_2_swapped, shape_4d_swapped)
        proj_2_contig = aten.clone.default(
            proj_2_expanded, memory_format=torch.contiguous_format
        )
        proj_2_flat = aten.view.default(proj_2_contig, [256, 512, 64])

        return product, proj_2_flat

    def test_issue54(self):
        """Check that the inductor-compiled graph matches eager execution."""
        device = 'npu'

        def rand(*shape):
            # Random float32 tensor of the given shape on the target device.
            return torch.randn(shape, device=device, dtype=torch.float32)

        # Tuple order matches the unpacking in ``func_layernorm``:
        # add_3, primals_6, primals_7, view, primals_9, permute_1,
        # primals_10, primals_11.
        args = (
            rand(64, 512, 256),
            rand(256, 256),
            rand(256),
            rand(32768, 256),
            rand(256),
            rand(256, 256),
            rand(256, 256),
            rand(256),
        )

        ref = self.func_layernorm(args)

        compile_options = {"unroll_reductions_threshold": 1, "aggressive_fusion": True}
        func = torch.compile(
            self.func_layernorm,
            backend="inductor",
            dynamic=False,
            options=compile_options,
        )
        calc = func(args)

        # Both graph outputs must agree within the same tolerance.
        for idx in (0, 1):
            self.assertEqual(ref[idx], calc[idx], atol=1e-2, rtol=1e-2)
# When executed as a script, hand control to the ``run_tests`` helper imported
# from ``torch.testing._internal.common_utils``.
if __name__ == "__main__":
    run_tests()