import torch
from torch.testing._internal.common_utils import run_tests, parametrize, instantiate_parametrized_tests
from testutils import TestUtils
import torch_npu
class TestCat(TestUtils):
def op_calc(self, input_element, dim):
return torch.cat([input_element, input_element], dim)
@parametrize('shape', [(8, 16, 32, 64)])
@parametrize('dim', [-1])
@parametrize('dtype', ['bfloat16'])
def test_reduction_cases_shapes(self, shape, dim, dtype):
input_element = self._generate_tensor(shape, dtype)
std_cat = self.op_calc(input_element, dim)
compiled_op_calc = torch.compile(self.op_calc, backend="inductor")
inductor_cat = compiled_op_calc(input_element, dim)
self.assertEqual(std_cat, inductor_cat, atol=1e-1, rtol=1e-1, equal_nan=True)
def op_calc_non_contiguous(self, input_element, dim):
return torch.cat([input_element, input_element], dim)
@parametrize('shape', [(8, 16, 32, 64)])
@parametrize('dim', [1])
@parametrize('dtype', ['bfloat16'])
def test_cat_non_contiguous(self, shape, dim, dtype):
input_element = self._generate_tensor(shape, dtype)
input_element = input_element.transpose(-1, -2)
std_cat = self.op_calc_non_contiguous(input_element, dim)
compiled_op_calc = torch.compile(self.op_calc_non_contiguous, backend="inductor")
inductor_cat = compiled_op_calc(input_element, dim)
self.assertEqual(std_cat, inductor_cat, atol=1e-4, rtol=1e-4, equal_nan=True)
class PatternModel(torch.nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, *xs):
slices = [x[..., :sz] for x, sz in zip(xs, (128, 32, 48, 48, 48, 48, 48))]
output_tensor = torch.cat(slices, self.dim)
return output_tensor
@parametrize('shape', [(128, 50, 128)])
@parametrize('dim', [2])
@parametrize('dtype', ['float32', 'bfloat16'])
def test_model_input_is_concat(self, shape, dim, dtype):
inputs = [self._generate_tensor(shape, dtype) for _ in range(7)]
model = self.PatternModel(dim).to(dtype=getattr(torch, dtype))
model.eval()
with torch.no_grad():
eager_out = model(*inputs)
compiled_model = torch.compile(model, backend="inductor")
with torch.no_grad():
inductor_out = compiled_model(*inputs)
self.assertEqual(eager_out, inductor_out,
atol=1e-4, rtol=1e-4, equal_nan=True)
instantiate_parametrized_tests(TestCat)
if __name__ == "__main__":
run_tests()