import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.decorator import Dtypes, instantiate_tests
@instantiate_tests
class TestNpuSlice(TestCase):
    """Compare ``torch_npu.npu_slice`` with an ``index_select`` reference."""

    def split_npu_slice(self, input1, offset, sizes):
        """Build the reference slice by narrowing one dimension at a time.

        Args:
            input1: Tensor to slice; every one of its dimensions is narrowed.
            offset: Start position per dimension, indexed by dimension number.
            sizes: Element count per dimension, indexed by dimension number.

        Returns:
            The tensor that keeps, along each dimension ``i``, positions
            ``offset[i]`` up to (not including) ``offset[i] + sizes[i]``.
        """
        for dim in range(input1.dim()):
            start = offset[dim]
            positions = list(range(start, start + sizes[dim]))
            # Pin the dtype: with sizes[dim] == 0 the list is empty and
            # torch.tensor would otherwise infer a floating-point dtype,
            # which index_select does not accept as an index.
            index = torch.tensor(positions, dtype=torch.long, device="npu")
            input1 = torch.index_select(input=input1, dim=dim, index=index)
        return input1

    def split_npu_slice_out(self, input1, offset, sizes, out=None):
        """Return a freshly cloned copy of the reference slice.

        Args:
            input1: Tensor to slice.
            offset: Start position per dimension.
            sizes: Element count per dimension.
            out: Accepted but never read or written here; the result is
                always a new tensor.
                NOTE(review): confirm whether callers expect ``out`` to be
                filled in place.

        Returns:
            A clone of ``split_npu_slice(input1, offset, sizes)``.
        """
        return self.split_npu_slice(input1, offset, sizes).clone()

    def npu_op_exec(self, input1, offset, sizes):
        """Run the operator under test, ``torch_npu.npu_slice``."""
        return torch_npu.npu_slice(input1, offset, sizes)

    def split_op_exec(self, input1, offset, sizes):
        """Run the ``index_select``-based reference implementation."""
        return self.split_npu_slice(input1, offset, sizes)

    @Dtypes(torch.float, torch.half, torch.int32, torch.uint8, torch.int8, torch.int16, torch.long)
    def test_slice(self, dtype):
        """Check a 2-D and a 3-D slice against the reference for ``dtype``."""
        row_pair = [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]
        cases = (
            (row_pair, [0, 0], [2, 2]),
            ([row_pair, row_pair, row_pair], [0, 0, 0], [2, 2, 2]),
        )
        for values, offset, sizes in cases:
            input_data = torch.tensor(values).npu().to(dtype)
            reference = self.split_op_exec(input_data, offset, sizes)
            npu_result = self.npu_op_exec(input_data, offset, sizes)
            self.assertRtolEqual(reference, npu_result)
# Script entry point: hand control to the torch_npu test runner.
if __name__ == '__main__':
    run_tests()