import torch
import numpy as np
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import create_common_tensor
class TestBroadCastTensors(TestCase):
    """Check that ``torch.broadcast_tensors`` on NPU matches the CPU result.

    Every test builds random CPU tensors, broadcasts them once on CPU and
    once on NPU, and compares both broadcast outputs with
    ``assertRtolEqual``.
    """

    def cpu_op_exec(self, input1, input2):
        """Broadcast two CPU tensors against each other.

        Args:
            input1: First CPU tensor.
            input2: Second CPU tensor.

        Returns:
            A pair of numpy arrays, one per broadcast output.
        """
        output1, output2 = torch.broadcast_tensors(input1, input2)
        return output1.numpy(), output2.numpy()

    def npu_op_exec(self, input1, input2, npu_format=None):
        """Broadcast two tensors on NPU and bring the results back to host.

        Args:
            input1: First tensor; moved to NPU with ``.npu()``.
            input2: Second tensor; moved to NPU with ``.npu()``.
            npu_format: Optional format id passed to
                ``torch_npu.npu_format_cast`` for both inputs before
                broadcasting. ``None`` skips the cast.

        Returns:
            A pair of numpy arrays, one per broadcast output.
        """
        input1 = input1.npu()
        input2 = input2.npu()
        if npu_format is not None:
            input1 = torch_npu.npu_format_cast(input1, npu_format)
            input2 = torch_npu.npu_format_cast(input2, npu_format)
        output1, output2 = torch.broadcast_tensors(input1, input2)
        return output1.cpu().numpy(), output2.cpu().numpy()

    def _assert_matches_cpu(self, cpu_input1, cpu_input2, npu_format=None):
        """Run the op on CPU and NPU and assert both outputs agree."""
        cpu_output1, cpu_output2 = self.cpu_op_exec(cpu_input1, cpu_input2)
        npu_output1, npu_output2 = self.npu_op_exec(
            cpu_input1, cpu_input2, npu_format
        )
        self.assertRtolEqual(cpu_output1, npu_output1)
        self.assertRtolEqual(cpu_output2, npu_output2)

    def test_broadcast_tensors_common_shape_format(self, device='npu'):
        """Broadcast freshly created float32 tensors of compatible shapes.

        ``device`` is not used by the test body; it is kept so the
        signature stays unchanged.
        """
        # Each row: [shape of input1, shape of input2, dtype].
        shape_format = [
            [[1, 3], (2, 1), torch.float32],
            [[1, 9], (5, 1), torch.float32],
            [[3, 1], (1, 3), torch.float32],
        ]
        for shape1, shape2, dtype in shape_format:
            cpu_input1 = torch.randn(shape1, dtype=dtype)
            cpu_input2 = torch.randn(shape2, dtype=dtype)
            self._assert_matches_cpu(cpu_input1, cpu_input2)

    def test_broadcast_tensors_discontiguous_shape(self, device='npu'):
        """Broadcast strided views obtained by step-2 slicing.

        ``device`` is not used by the test body; it is kept so the
        signature stays unchanged.
        """
        # Each row: [base shape of input1, base shape of input2, dtype].
        shape_format = [
            [[1, 6], (4, 1), torch.float32],
            [[1, 18], (10, 1), torch.float32],
        ]
        for shape1, shape2, dtype in shape_format:
            # Keep every second column / every second row so the inputs
            # are views with a stride of 2 rather than fresh tensors.
            cpu_input1 = torch.randn(shape1, dtype=dtype)[:, ::2]
            cpu_input2 = torch.randn(shape2, dtype=dtype)[::2, :]
            self._assert_matches_cpu(cpu_input1, cpu_input2)

    def test_broadcast_tensors_format(self, device='npu'):
        """Broadcast sliced inputs after casting them to several NPU formats.

        ``device`` is not used by the test body; it is kept so the
        signature stays unchanged.
        """
        # Each row: [base shape of input1, base shape of input2, dtype,
        # format id handed to npu_format_cast].
        # NOTE(review): the ids are presumably ACL storage-format enum
        # values -- confirm their meaning against the torch_npu docs.
        shape_format = [
            [[2, 3, 4], (3, 1), torch.float32, 0],
            [[2, 3, 4], (3, 1), torch.float32, 2],
            [[2, 3, 4], (3, 1), torch.float32, 3],
            [[2, 3, 4], (3, 1), torch.float32, 4],
            [[2, 3, 4], (3, 1), torch.float32, 29],
            [[2, 3, 4], (3, 1), torch.float32, 30],
        ]
        for shape1, shape2, dtype, npu_format in shape_format:
            # After slicing the inputs have shapes (2, 2, 4) and (2, 1).
            cpu_input1 = torch.randn(shape1, dtype=dtype)[:, ::2]
            cpu_input2 = torch.randn(shape2, dtype=dtype)[::2, :]
            self._assert_matches_cpu(cpu_input1, cpu_input2, npu_format)
# Run the test cases only when this file is executed as a script.
if __name__ == "__main__":
    run_tests()