import torch
from torch.testing._internal.common_utils import run_tests, parametrize, instantiate_parametrized_tests
from testutils import TestUtils
import torch_npu
class TestSlice(TestUtils):
    """Check that a prefix slice followed by an add matches between eager and Inductor."""

    def op_calc(self, a, b, dim, step):
        """Keep the leading part of ``a`` and ``b`` along ``dim`` and add them.

        Args:
            a: First operand tensor.
            b: Second operand tensor; it is indexed with the same index tuple
                as ``a``, so it is expected to have the same rank.
            dim: Axis to truncate. A value outside ``[0, a.dim())`` leaves
                both tensors unsliced.
            step: Divisor applied to the size of ``dim``. Despite its name it
                is not a slice stride: the first ``a.shape[dim] // step``
                entries along ``dim`` are kept.

        Returns:
            The element-wise sum of the two (possibly truncated) tensors.
        """
        rank = a.dim()
        if 0 <= dim < rank:
            end = a.shape[dim] // step
            # Full slices everywhere except ``dim``; for a 4-D input this is
            # the same index as the literal ``a[::, :end:, ::, ::]`` form.
            index = [slice(None)] * rank
            index[dim] = slice(None, end)
            index = tuple(index)
            a = a[index]
            b = b[index]
        return a + b

    @parametrize('shape', [(8, 8, 256, 128)])
    @parametrize('dtype', ['float32', 'int32', 'float16', 'bfloat16', 'int64'])
    def test_view_cases(self, shape, dtype):
        """Compare eager and Inductor-compiled ``op_calc`` for every axis.

        The divisor passed as ``step`` is ``min(shape) // 2``. Inputs come
        from ``self._generate_tensor`` (defined outside this file —
        presumably on ``TestUtils``).
        """
        a = self._generate_tensor(shape, dtype)
        b = self._generate_tensor(shape, dtype)

        # Both values are loop-invariant, so build them once.
        divisor = min(shape) // 2
        compiled_op_calc = torch.compile(self.op_calc, backend="inductor")

        for dim in [3, 2, 1, 0]:
            # subTest reports which axis failed and lets the remaining axes run.
            with self.subTest(dim=dim):
                std_slice = self.op_calc(a, b, dim, divisor)
                inductor_slice = compiled_op_calc(a, b, dim, divisor)
                self.assertEqual(std_slice, inductor_slice, atol=1e-3, rtol=1e-3)
# Expand the @parametrize-decorated methods of TestSlice into concrete test cases.
instantiate_parametrized_tests(TestSlice)
# Run the tests only when this file is executed directly, not when imported.
if __name__ == "__main__":
    run_tests()