import unittest
import torch
import numpy as np
from torch.nn import functional as F
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import create_common_tensor
class TestBaddBmm(TestCase):
    """Check ``torch.baddbmm`` / ``Tensor.baddbmm_`` on NPU against the CPU result."""

    def generate_scalar(self, dtype, min1, max1):
        """Draw a random scalar to be used as ``beta`` or ``alpha``.

        Args:
            dtype: One of ``"float32"``, ``"float16"`` or ``"int32"``.
            min1: Lower bound of the sampling range.
            max1: Upper bound of the sampling range (exclusive).

        Returns:
            A sample from ``np.random.uniform`` for the float dtypes, or from
            ``np.random.randint`` for ``"int32"``.

        Raises:
            ValueError: If ``dtype`` is not one of the supported names.
        """
        if dtype in ("float32", "float16"):
            return np.random.uniform(min1, max1)
        if dtype == "int32":
            return np.random.randint(min1, max1)
        # Fail loudly instead of hitting an unbound local on an unknown dtype.
        raise ValueError(f"unsupported scalar dtype: {dtype!r}")

    def cpu_op_exec(self, input1, input2, input3, scalar1, scalar2):
        """Run out-of-place ``torch.baddbmm`` and return the result as a numpy array.

        ``scalar1`` is passed as ``beta`` and ``scalar2`` as ``alpha``.
        """
        output = torch.baddbmm(input1, input2, input3, beta=scalar1, alpha=scalar2)
        return output.numpy()

    def cpu_op_exec_(self, input1, input2, input3, scalar1, scalar2):
        """Run in-place ``baddbmm_`` on ``input1`` and return it as a numpy array.

        ``input1`` is modified by the call.
        """
        input1.baddbmm_(input2, input3, beta=scalar1, alpha=scalar2)
        return input1.numpy()

    def npu_op_exec(self, input1, input2, input3, scalar1, scalar2):
        """Run out-of-place ``torch.baddbmm``, move the result to CPU, return numpy."""
        output = torch.baddbmm(input1, input2, input3, beta=scalar1, alpha=scalar2)
        return output.to("cpu").numpy()

    def npu_op_exec_(self, input1, input2, input3, scalar1, scalar2):
        """Run in-place ``baddbmm_`` on ``input1``, move it to CPU, return numpy.

        ``input1`` is modified by the call.
        """
        input1.baddbmm_(input2, input3, beta=scalar1, alpha=scalar2)
        return input1.to("cpu").numpy()

    def test_baddbmm_common_shape_format(self):
        """Compare float32 ``baddbmm`` (out-of-place and in-place) between CPU and NPU."""
        # Each case: [self_spec, batch1_spec, batch2_spec, scalar dtype name].
        # A spec is [numpy dtype, format id, shape]; it is handed unchanged to
        # create_common_tensor (format -1 is presumably "default" -- TODO confirm).
        shape_format = [
            [
                [np.float16, -1, (1, 3, 5)],
                [np.float16, -1, (1, 3, 4)],
                [np.float16, -1, (1, 4, 5)],
                "float32",
            ],
            [
                [np.float16, -1, (6, 4, 3)],
                [np.float16, -1, (6, 4, 5)],
                [np.float16, -1, (6, 5, 3)],
                "float32",
            ],
            [
                [np.float16, -1, (175, 455, 22)],
                [np.float16, -1, (175, 455, 116)],
                [np.float16, -1, (175, 116, 22)],
                "float32",
            ],
            [
                [np.float16, -1, (25, 56, 12)],
                [np.float16, -1, (25, 56, 51)],
                [np.float16, -1, (25, 51, 12)],
                "float32",
            ],
        ]
        for self_spec, batch1_spec, batch2_spec, scalar_dtype in shape_format:
            cpu_input1, npu_input1 = create_common_tensor(self_spec, 0, 1)
            cpu_input2, npu_input2 = create_common_tensor(batch1_spec, 0, 1)
            cpu_input3, npu_input3 = create_common_tensor(batch2_spec, 0, 1)
            scalar1 = self.generate_scalar(scalar_dtype, 0, 2)
            scalar2 = self.generate_scalar(scalar_dtype, 0, 2)

            # Out-of-place variant; operands are converted with .float() per call.
            cpu_output = self.cpu_op_exec(
                cpu_input1.float(),
                cpu_input2.float(),
                cpu_input3.float(),
                scalar1,
                scalar2,
            )
            npu_output = self.npu_op_exec(
                npu_input1.float(),
                npu_input2.float(),
                npu_input3.float(),
                scalar1,
                scalar2,
            )
            self.assertRtolEqual(cpu_output, npu_output, prec=1.0e-3, prec16=1.0e-3)

            # In-place variant; .float() is called again so baddbmm_ receives a
            # freshly converted tensor rather than one used above.
            cpu_output_ = self.cpu_op_exec_(
                cpu_input1.float(),
                cpu_input2.float(),
                cpu_input3.float(),
                scalar1,
                scalar2,
            )
            npu_output_ = self.npu_op_exec_(
                npu_input1.float(),
                npu_input2.float(),
                npu_input3.float(),
                scalar1,
                scalar2,
            )
            self.assertRtolEqual(cpu_output_, npu_output_, prec=1.0e-3, prec16=1.0e-3)

    def test_baddbmm_float16_shape_format(self):
        """Compare float16 ``baddbmm`` on NPU with a float32-computed CPU reference."""

        def cpu_op_exec_fp16(input1, input2, input3, scalar1, scalar2):
            """Compute the CPU reference in float32 and cast the result to float16."""
            input1 = input1.to(torch.float32)
            input2 = input2.to(torch.float32)
            input3 = input3.to(torch.float32)
            output = torch.baddbmm(input1, input2, input3, beta=scalar1, alpha=scalar2)
            return output.numpy().astype(np.float16)

        # Same case layout as the float32 test: three tensor specs + scalar dtype.
        shape_format = [
            [
                [np.float16, -1, (1, 3, 5)],
                [np.float16, -1, (1, 3, 4)],
                [np.float16, -1, (1, 4, 5)],
                "float16",
            ],
            [
                [np.float16, -1, (500, 40, 300)],
                [np.float16, -1, (500, 40, 500)],
                [np.float16, -1, (500, 500, 300)],
                "float16",
            ],
            [
                [np.float16, -1, (175, 455, 22)],
                [np.float16, -1, (175, 455, 116)],
                [np.float16, -1, (175, 116, 22)],
                "float16",
            ],
            [
                [np.float16, -1, (25, 21, 11)],
                [np.float16, -1, (25, 21, 34)],
                [np.float16, -1, (25, 34, 11)],
                "float16",
            ],
        ]
        for self_spec, batch1_spec, batch2_spec, scalar_dtype in shape_format:
            cpu_input1, npu_input1 = create_common_tensor(self_spec, 0, 1)
            cpu_input2, npu_input2 = create_common_tensor(batch1_spec, 0, 1)
            cpu_input3, npu_input3 = create_common_tensor(batch2_spec, 0, 1)
            scalar1 = self.generate_scalar(scalar_dtype, 0, 2)
            scalar2 = self.generate_scalar(scalar_dtype, 0, 2)
            cpu_output = cpu_op_exec_fp16(
                cpu_input1, cpu_input2, cpu_input3, scalar1, scalar2
            )
            npu_output = self.npu_op_exec(
                npu_input1, npu_input2, npu_input3, scalar1, scalar2
            )
            self.assertRtolEqual(cpu_output, npu_output, prec=1.0e-3, prec16=1.0e-2)
# Run the test cases through the torch_npu test runner when executed as a script.
if __name__ == "__main__":
    run_tests()