import copy
import torch
from torch.testing._internal.common_utils import run_tests, parametrize, instantiate_parametrized_tests
from testutils import TestUtils
import torch_npu
class TestBroadcast(TestUtils):
    """Check that unsqueeze + broadcast_to + add gives the same result when
    run eagerly and when compiled with the ``inductor`` backend."""

    # Length of the extra axis inserted into the base shape for broadcasting.
    broadcast_size = 128

    def op_calc(self, a, b, dim, new_shape):
        """Return the element-wise sum of ``a`` and ``b`` after broadcasting.

        Each input first gets a size-1 axis inserted at position ``dim`` and
        is then broadcast to ``new_shape``; the two broadcast results are
        added together.
        """
        a = a.unsqueeze(dim).broadcast_to(new_shape)
        b = b.unsqueeze(dim).broadcast_to(new_shape)
        return a + b

    @parametrize('shape', [(8, 8, 256)])
    @parametrize('dtype', ['float32', 'int32', 'float16', 'bfloat16'])
    def test_view_cases(self, shape, dtype):
        """Compare eager and compiled ``op_calc`` for every insertion axis.

        For each ``dim`` the target shape is ``shape`` with an axis of length
        ``broadcast_size`` inserted at ``dim``.
        """
        # NOTE(review): _generate_tensor is not defined in this class;
        # presumably inherited from TestUtils -- confirm its contract there.
        a = self._generate_tensor(shape, dtype)
        b = self._generate_tensor(shape, dtype)
        compiled_op_calc = torch.compile(self.op_calc, backend="inductor", dynamic=False)

        for dim in [3, 2, 1, 0]:
            # subTest labels a failure with the offending axis and lets the
            # remaining axes still be checked.
            with self.subTest(dim=dim):
                # shape is a tuple of ints, so a plain list() copy is enough.
                new_shape = list(shape)
                new_shape.insert(dim, self.broadcast_size)

                std_broadcast = self.op_calc(a, b, dim, new_shape)
                inductor_broadcast = compiled_op_calc(a, b, dim, new_shape)

                # Both results are cast to float before the tolerance-based
                # comparison, whatever the parametrized dtype is.
                self.assertEqual(
                    std_broadcast.float(),
                    inductor_broadcast.float(),
                    atol=1e-3,
                    rtol=1e-3,
                )
# Generate the concrete test methods for the @parametrize-decorated tests
# declared on TestBroadcast.
instantiate_parametrized_tests(TestBroadcast)
# Run the test suite only when this file is executed directly as a script.
if __name__ == "__main__":
    run_tests()