import math
import unittest
import numpy as np
import torch
import torch_npu
import torch.nn as nn
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices
class TestNPUFusedMatmul(TestCase):
    """Compare ``torch_npu.npu_fused_matmul`` with an eager PyTorch reference.

    Each test builds float16 inputs on the NPU, computes the result once with
    plain PyTorch ops (``supported_op_exec``) and once with the fused operator
    (``custom_op_exec``), and asserts that both agree within a relative
    tolerance.
    """

    def supported_op_exec(self, x1, x2, bias, x3, fused_op_type):
        """Compute the reference result with unfused PyTorch ops.

        Args:
            x1: Left matmul operand.
            x2: Right matmul operand.
            bias: Optional tensor added to the matmul result, or ``None``.
            x3: Second operand for the ``"add"`` / ``"mul"`` modes; unused
                by the GELU modes.
            fused_op_type: One of ``"add"``, ``"mul"``, ``"gelu_erf"`` or
                ``"gelu_tanh"``. Any other value returns the matmul result
                (plus bias, if given) unchanged.

        Returns:
            The reference output tensor.
        """
        res = torch.matmul(x1, x2)
        if bias is not None:
            # NOTE(review): assumes the fused operator applies bias to the
            # matmul result before the fused epilogue -- confirm against the
            # operator documentation before relying on bias test cases.
            res = torch.add(res, bias)

        if fused_op_type == "add":
            res = torch.add(res, x3)
        elif fused_op_type == "mul":
            res = torch.mul(res, x3)
        elif fused_op_type == "gelu_erf":
            res = nn.GELU()(res)
        elif fused_op_type == "gelu_tanh":
            res = nn.GELU('tanh')(res)
        return res

    def custom_op_exec(self, x1, x2, bias, x3, fused_op_type):
        """Run the fused NPU operator with the given arguments.

        ``bias`` is forwarded to the operator instead of being replaced by
        ``None``, so that callers supplying a bias actually exercise it.
        """
        return torch_npu.npu_fused_matmul(
            x1, x2, bias=bias, x3=x3, fused_op_type=fused_op_type)

    def _check_fused_matmul(self, fused_op_type, with_x3):
        """Build seeded inputs and compare reference and fused outputs.

        Args:
            fused_op_type: Fusion mode passed to both implementations.
            with_x3: When true, a (16, 16) ``x3`` tensor is generated;
                otherwise ``None`` is passed for ``x3``.
        """
        # Seed first and draw x1, x2, x3 in this fixed order so every test
        # sees reproducible inputs.
        torch.manual_seed(0)
        x1 = torch.randn((16, 48), dtype=torch.float16).npu()
        x2 = torch.randn((48, 16), dtype=torch.float16).npu()
        x3 = None
        if with_x3:
            x3 = torch.randn((16, 16), dtype=torch.float16).npu()

        supported_output = self.supported_op_exec(
            x1, x2, None, x3, fused_op_type)
        custom_output = self.custom_op_exec(
            x1, x2, None, x3, fused_op_type)
        self.assertRtolEqual(supported_output, custom_output, 0.001)

    @SupportedDevices(['Ascend950'])
    def test_npu_fused_matmul_add(self, device="npu"):
        """Fused matmul followed by element-wise add of ``x3``."""
        self._check_fused_matmul("add", with_x3=True)

    @SupportedDevices(['Ascend950'])
    def test_npu_fused_matmul_mul(self, device="npu"):
        """Fused matmul followed by element-wise multiply by ``x3``."""
        self._check_fused_matmul("mul", with_x3=True)

    @SupportedDevices(['Ascend950'])
    def test_npu_fused_matmul_gelu_erf(self, device="npu"):
        """Fused matmul followed by the default ``nn.GELU()`` activation."""
        self._check_fused_matmul("gelu_erf", with_x3=False)


# Run the suite through the torch_npu test runner when executed as a script.
if __name__ == "__main__":
    run_tests()