import torch
import numpy as np
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import create_common_tensor
class TestAsStrided(TestCase):
    """Check that ``torch.as_strided`` gives matching results for paired inputs.

    The paired inputs come from ``create_common_tensor`` (presumably a CPU
    tensor and its NPU counterpart -- confirm against that helper).
    """

    def cpu_op_exec(self, input1, size, stride, storage_offset):
        """Apply ``torch.as_strided`` to ``input1`` and return a NumPy array.

        Args:
            input1: tensor to take the strided view of.
            size: shape of the resulting view.
            stride: stride of the resulting view.
            storage_offset: offset passed through to ``torch.as_strided``.
        """
        strided_view = torch.as_strided(input1, size, stride, storage_offset)
        return strided_view.numpy()

    def npu_op_exec(self, input1, size, stride, storage_offset):
        """Apply ``torch.as_strided`` to ``input1`` and return a NumPy array.

        Takes the same arguments as ``cpu_op_exec``; the only difference is
        that the result is moved to the CPU with ``.cpu()`` before it is
        converted to NumPy.
        """
        strided_view = torch.as_strided(input1, size, stride, storage_offset)
        return strided_view.cpu().numpy()

    def test_as_strided(self):
        """Run every case through both helpers and compare the outputs."""
        # Each case is [tensor_spec, size, stride, storage_offset].
        # tensor_spec is handed unchanged to create_common_tensor
        # (presumably [dtype, format id, shape] -- confirm against that helper).
        cases = [
            [[np.float32, 0, [3, 3]], (2, 2), (1, 2), 0],
            [[np.float16, 0, [13, 23]], (10, 15), (1, 2), 1],
            [[np.int32, 0, [5, 5]], (3, 3), (1, 2), 1],
            [[np.float32, 2, [32, 8, 2]], (8, 6, 2), (5, 4, 1), 1],
            [[np.int32, 2, [8, 16]], (6, 3), (8, 2), 0],
        ]
        for tensor_spec, size, stride, storage_offset in cases:
            cpu_tensor, npu_tensor = create_common_tensor(tensor_spec, -100, 100)
            expected = self.cpu_op_exec(cpu_tensor, size, stride, storage_offset)
            actual = self.npu_op_exec(npu_tensor, size, stride, storage_offset)
            self.assertRtolEqual(expected, actual)
# Script entry point: hand control to the test runner only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    run_tests()