import torch
import torch_npu
import math
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices
class TestIndexing(TestCase):
    """Check ``torch_npu.npu_indexing`` against NumPy basic indexing.

    Each case indexes a NumPy array on the host to obtain the reference
    value, runs the NPU operator with ``begin``/``end``/``strides`` lists and
    the five bit masks, and compares the two results.
    """

    def setUp(self):
        """Create the shared (2, 3, 4, 5) float32 input as NumPy array and NPU tensor."""
        super().setUp()
        self._cpu_input, self._npu_input = self.create_tensor((2, 3, 4, 5))

    @staticmethod
    def create_tensor(shape, dtype=torch.float32):
        """Build matching host and device inputs.

        Args:
            shape: Target shape; ``math.prod(shape)`` elements are generated.
            dtype: Torch dtype passed to ``torch.arange``.

        Returns:
            ``(numpy_array, npu_tensor)``, both derived from the same
            ``torch.arange(...).reshape(shape)`` tensor.
        """
        tensor = torch.arange(math.prod(shape), dtype=dtype).reshape(shape)
        return tensor.numpy(), tensor.npu()

    def cpu_op_exec(self, input_tensor, begin, end, strides):
        """Return ``input_tensor`` indexed with one ``slice(b, e, s)`` per entry.

        ``begin``, ``end`` and ``strides`` are zipped, so the shortest of the
        three determines how many leading axes are sliced.
        """
        slices = tuple(slice(b, e, s) for b, e, s in zip(begin, end, strides))
        return input_tensor[slices]

    def npu_op_exec(
        self,
        input_tensor,
        begin,
        end,
        strides,
        begin_mask=0,
        end_mask=0,
        ellipsis_mask=0,
        new_axis_mask=0,
        shrink_axis_mask=0,
        out=None,
    ):
        """Run ``torch_npu.npu_indexing`` and return the result as a NumPy array.

        Args:
            input_tensor: Tensor handed to the operator.
            begin, end, strides: Per-position index specification lists.
            begin_mask, end_mask, ellipsis_mask, new_axis_mask,
            shrink_axis_mask: Integer bit masks forwarded positionally.
            out: Optional preallocated tensor; when given, the ``.out``
                variant of the operator is called with ``out=out``.

        Returns:
            The operator result copied to the host via ``.cpu().numpy()``.
        """
        op_args = (
            input_tensor,
            begin,
            end,
            strides,
            begin_mask,
            end_mask,
            ellipsis_mask,
            new_axis_mask,
            shrink_axis_mask,
        )
        if out is not None:
            result = torch_npu.npu_indexing.out(*op_args, out=out)
        else:
            result = torch_npu.npu_indexing(*op_args)
        return result.cpu().numpy()

    def _assert_matches_reference(self, expected, begin, end, strides, **masks):
        """Run the operator on the shared NPU input and compare with ``expected``.

        ``masks`` are forwarded to ``npu_op_exec`` as keyword arguments.
        """
        actual = self.npu_op_exec(self._npu_input, begin, end, strides, **masks)
        self.assertRtolEqual(expected, actual)

    @SupportedDevices(["Ascend950"])
    def test_npu_indexing_0_all_demensions(self):
        """Cover inputs of rank 1, 2, 3, 5 and 8 with int8 and float16 data."""
        data_list = [
            ([9], [1], [8], [2]),
            ([6, 7], [1, 2], [4, 6], [3, 1]),
            ([10, 11, 12], [3, 2, 1], [10, 11, 12], [1, 3, 5]),
            ([2, 3, 4, 5, 6], [1, 0, 2, 1, 1], [2, 3, 4, 4, 5], [2, 1, 7, 3, 1]),
            (
                [2, 2, 2, 2, 2, 2, 2, 2],
                [0, 0, 0, 0, 0, 0, 0, 0],
                [2, 2, 2, 2, 2, 2, 2, 2],
                [1, 1, 1, 1, 1, 1, 1, 1],
            ),
        ]
        dtype_list = [torch.int8, torch.float16]
        for shape, begin, end, strides in data_list:
            for dtype in dtype_list:
                # subTest reports which (shape, dtype) pair failed and lets
                # the remaining combinations still run.
                with self.subTest(shape=shape, dtype=dtype):
                    cpu_input, npu_input = self.create_tensor(shape, dtype)
                    cpu_output = self.cpu_op_exec(cpu_input, begin, end, strides)
                    npu_output = self.npu_op_exec(npu_input, begin, end, strides)
                    self.assertRtolEqual(cpu_output, npu_output)

    @SupportedDevices(["Ascend950"])
    def test_npu_indexing_1_basic_slicing_all_dimensions(self):
        """Basic slicing: full spec, ``out=`` variant, negative indices, short and empty specs."""
        expected = self._cpu_input[0:2, 1:3, 2:4, 3:5]
        self._assert_matches_reference(expected, [0, 1, 2, 3], [2, 3, 4, 5], [1, 1, 1, 1])

        # Same slice through the out= variant: the preallocated tensor itself
        # is inspected afterwards, not the returned value.
        npu_out = torch.zeros(16, dtype=self._npu_input.dtype).reshape(2, 2, 2, 2).npu()
        self.npu_op_exec(self._npu_input, [0, 1, 2, 3], [2, 3, 4, 5], [1, 1, 1, 1], out=npu_out)
        self.assertRtolEqual(expected, npu_out.cpu().numpy())

        # Negative begin/end values.
        self._assert_matches_reference(
            self._cpu_input[0:-1, 1:-1, -3:-1, -2:5:1], [0, 1, -3, -2], [-1, -1, -1, 5], [1, 1, 1, 1]
        )
        # Spec shorter than the input rank.
        self._assert_matches_reference(self._cpu_input[0:2, 1:3], [0, 1], [2, 3], [1, 1])
        # Empty spec is expected to return the whole input.
        self._assert_matches_reference(self._cpu_input[...], [], [], [])

    @SupportedDevices(["Ascend950"])
    def test_npu_indexing_2_strides_all_directions(self):
        """Positive, negative and mixed stride combinations."""
        self._assert_matches_reference(
            self._cpu_input[0:2:2, 0:3:1, 0:4:2, 0:5:3], [0, 0, 0, 0], [2, 3, 4, 5], [2, 1, 2, 3]
        )
        self._assert_matches_reference(
            self._cpu_input[1:0:-1, 2:0:-2, 3:0:-1, 4:0:-3], [1, 2, 3, 4], [0, 0, 0, 0], [-1, -2, -1, -3]
        )
        self._assert_matches_reference(
            self._cpu_input[0:2:1, 3:0:-1, 0:4:2, 5:0:-2], [0, 3, 0, 5], [2, 0, 4, 0], [1, -1, 2, -2]
        )

    @SupportedDevices(["Ascend950"])
    def test_npu_indexing_3_begin_mask_all_combinations(self):
        """``begin_mask`` cases: a set bit corresponds to an open start in the reference slice."""
        self._assert_matches_reference(
            self._cpu_input[:2, 1:3, 2:4, 3:5], [0, 1, 2, 3], [2, 3, 4, 5], [1, 1, 1, 1], begin_mask=0b0001
        )
        self._assert_matches_reference(
            self._cpu_input[0:2, :3, 2:4, 3:5], [0, 0, 2, 3], [2, 3, 4, 5], [1, 1, 1, 1], begin_mask=0b0010
        )
        self._assert_matches_reference(
            self._cpu_input[:2, :3, :4, :5], [0, 0, 0, 0], [2, 3, 4, 5], [1, 1, 1, 1], begin_mask=0b1111
        )
        # begin_mask and end_mask combined on disjoint positions.
        self._assert_matches_reference(
            self._cpu_input[:2, 1:, 0:, :5],
            [0, 1, 0, 2],
            [2, 3, 3, 5],
            [1, 1, 1, 1],
            begin_mask=0b1001,
            end_mask=0b0110,
        )
        # Mask bit beyond the length of the spec: reference is the unmasked slice.
        self._assert_matches_reference(self._cpu_input[0:2, 1:3], [0, 1], [2, 3], [1, 1], begin_mask=0b1000)

    @SupportedDevices(["Ascend950"])
    def test_npu_indexing_4_end_mask_all_combinations(self):
        """``end_mask`` cases: a set bit corresponds to an open stop in the reference slice."""
        self._assert_matches_reference(
            self._cpu_input[0:, 1:3, 2:4, 3:5], [0, 1, 2, 3], [0, 3, 4, 5], [1, 1, 1, 1], end_mask=0b0001
        )
        self._assert_matches_reference(
            self._cpu_input[0:, 1:, 2:, 3:], [0, 1, 2, 3], [0, 0, 0, 0], [1, 1, 1, 1], end_mask=0b1111
        )
        # Mask bit beyond the length of the spec: reference is the unmasked slice.
        self._assert_matches_reference(self._cpu_input[0:2, 1:3], [0, 1], [2, 3], [1, 1], end_mask=0b1000)

    @SupportedDevices(["Ascend950"])
    def test_npu_indexing_5_shrink_axis_mask_all_combinations(self):
        """``shrink_axis_mask`` cases: a set bit corresponds to an integer index in the reference."""
        self._assert_matches_reference(
            self._cpu_input[0, 1:3, 0:2, 1:5:2], [0, 1, 0, 1], [1, 3, 2, 5], [1, 1, 1, 2], shrink_axis_mask=0b0001
        )
        self._assert_matches_reference(
            self._cpu_input[0, 1, 0:3:2, 1:5], [0, 1, 0, 1], [1, 2, 3, 5], [3, 2, 2, 1], shrink_axis_mask=0b0011
        )

        # All four axes shrunk: both sides are single elements, so compare
        # the Python scalars instead of arrays.
        cpu_output = self._cpu_input[0, 1, 2, 3]
        npu_output = self.npu_op_exec(
            self._npu_input, [0, 1, 2, 3], [1, 2, 3, 4], [1, 1, 1, 1], shrink_axis_mask=0b1111
        )
        self.assertEqual(cpu_output.item(), npu_output.item())

        self._assert_matches_reference(
            self._cpu_input[1:2, 1, 0:3, 2], [1, 1, 0, 2], [2, 2, 3, 3], [1, 1, 1, 1], shrink_axis_mask=0b1010
        )
        # Mask bit beyond the length of the spec: reference is the unmasked slice.
        self._assert_matches_reference(self._cpu_input[0:2, 1:3], [0, 1], [2, 3], [1, 1], shrink_axis_mask=0b1000)

    @SupportedDevices(["Ascend950"])
    def test_npu_indexing_6_new_axis_mask_all_combinations(self):
        """``new_axis_mask`` cases: a set bit corresponds to ``None`` in the reference index."""
        self._assert_matches_reference(
            self._cpu_input[None, 0:2, 1:3, 1:4, 2:5],
            [0, 0, 1, 1, 2],
            [0, 2, 3, 4, 5],
            [1, 1, 1, 1, 1],
            new_axis_mask=0b000001,
        )
        self._assert_matches_reference(
            self._cpu_input[1:2, None, 1:3, None, 2:4, 1:5],
            [1, 0, 1, 0, 2, 1],
            [2, 0, 3, 0, 4, 5],
            [1, 1, 1, 1, 1, 1],
            new_axis_mask=0b001010,
        )
        self._assert_matches_reference(
            self._cpu_input[None, None, None, None, 1:2, 2:3, 1:4, 2:5],
            [0, 0, 0, 0, 1, 2, 1, 2],
            [0, 0, 0, 0, 2, 3, 4, 5],
            [1, 1, 1, 1, 1, 1, 1, 1],
            new_axis_mask=0b00001111,
        )
        # Mask bit beyond the length of the spec: reference is the unmasked slice.
        self._assert_matches_reference(self._cpu_input[0:2, 1:3], [0, 1], [2, 3], [1, 1], new_axis_mask=0b1000)

    @SupportedDevices(["Ascend950"])
    def test_npu_indexing_7_ellipsis_mask_various_positions(self):
        """Ellipsis placed at the start, middle and end of the index spec."""
        self._assert_matches_reference(
            self._cpu_input[0:1, 0:3, ..., 2:3], [0, 0, 0, 2], [1, 3, 4, 3], [1, 1, 1, 1], ellipsis_mask=0b0100
        )
        self._assert_matches_reference(
            self._cpu_input[..., 0:3, 1:2, 2:3], [0, 0, 1, 2], [2, 3, 2, 3], [1, 1, 1, 1], ellipsis_mask=0b0001
        )
        self._assert_matches_reference(
            self._cpu_input[0:1, 1:2, 0:4, ...], [0, 1, 0, 0], [1, 2, 4, 5], [1, 1, 1, 1], ellipsis_mask=0b1000
        )
        # Four explicit slices plus an ellipsis on a rank-4 input.
        self._assert_matches_reference(
            self._cpu_input[0:2, ..., 2:3, 1:3, 0:5:2],
            [0, 0, 2, 1, 0],
            [2, 3, 3, 3, 5],
            [1, 1, 1, 1, 2],
            ellipsis_mask=0b00010,
        )
        # Mask bit beyond the length of the spec: reference is the unmasked slice.
        self._assert_matches_reference(self._cpu_input[0:2, 1:3], [0, 1], [2, 3], [1, 1], ellipsis_mask=0b1000)

    @SupportedDevices(["Ascend950"])
    def test_npu_indexing_8_complex_combinations_all_masks(self):
        """Several masks used together in one call."""
        self._assert_matches_reference(
            self._cpu_input[:2, 1:3:2, 4::-1, 2:5],
            begin=[1, 1, 4, 2],
            end=[2, 3, -1, 5],
            strides=[1, 2, -1, 1],
            begin_mask=0b0001,
            end_mask=0b0100,
        )
        self._assert_matches_reference(
            self._cpu_input[0, None, None, 1, 0:4, 2:3],
            begin=[0, 0, 0, 1, 0, 2],
            end=[1, 0, 0, 2, 4, 3],
            strides=[1, 1, 1, 1, 1, 1],
            new_axis_mask=0b000110,
            shrink_axis_mask=0b001001,
        )

    @SupportedDevices(["Ascend950"])
    def test_npu_indexing_9_edge_cases_with_all_masks(self):
        """Boundary cases: empty result, negative begins, masks with negative begins."""
        # begin == end on every axis yields an empty result.
        self._assert_matches_reference(
            self._cpu_input[0:0, 1:1, 2:2, 3:3], [0, 1, 2, 3], [0, 1, 2, 3], [1, 1, 1, 1]
        )
        self._assert_matches_reference(
            self._cpu_input[-2:2:1, -3:3:1, -4:4:1, -5:5:1], [-2, -3, -4, -5], [2, 3, 4, 5], [1, 1, 1, 1]
        )
        self._assert_matches_reference(
            self._cpu_input[:, -2:3:1, :, -1:5:1],
            begin=[0, -2, 0, -1],
            end=[2, 3, 4, 5],
            strides=[1, 1, 1, 1],
            begin_mask=0b0101,
            end_mask=0b0101,
        )

    @SupportedDevices(["Ascend950"])
    def test_npu_indexing_10_ellipsis_with_other_masks(self):
        """Ellipsis combined with the other masks."""
        self._assert_matches_reference(
            self._cpu_input[..., 1:, :2],
            begin=[0, 1, 0],
            end=[2, 3, 2],
            strides=[1, 1, 1],
            ellipsis_mask=0b0001,
            begin_mask=0b100,
            end_mask=0b010,
        )
        self._assert_matches_reference(
            self._cpu_input[..., ::-1, :],
            begin=[0, 0, 4, 0],
            end=[2, 3, 0, 5],
            strides=[1, 1, -1, 1],
            ellipsis_mask=0b1000,
            end_mask=0b0100,
        )
        # The same bit is set in ellipsis, new-axis and shrink masks; the
        # reference treats that position as an ellipsis.
        self._assert_matches_reference(
            self._cpu_input[..., 1:3, 0:2],
            begin=[0, 1, 0],
            end=[2, 3, 2],
            strides=[1, 1, 1],
            ellipsis_mask=0b001,
            new_axis_mask=0b001,
            shrink_axis_mask=0b001,
        )

    @SupportedDevices(["Ascend950"])
    def test_npu_indexing_11_exception_scenarios(self):
        """Invalid arguments must raise; out-of-range but valid arguments must not."""
        # (begin, end, strides, mask kwargs) tuples that are expected to be rejected.
        invalid_cases = [
            # Five index entries for the rank-4 input.
            ([0, 1, 2, 3, 0], [2, 3, 4, 5, 6], [1, 1, 1, 1, 1], {}),
            # Stride of zero.
            ([0, 0, 0, 0], [2, 3, 4, 5], [0, 1, 1, 1], {}),
            # Negative stride on a position selected by shrink_axis_mask.
            ([0, 1, 0, 0], [1, 2, 4, 5], [-1, 1, 1, 1], {"shrink_axis_mask": 0b0011}),
            # begin shorter than end/strides.
            ([0, 0, 0], [2, 3, 4, 5], [1, 1, 1, 1], {}),
            # end shorter than begin/strides.
            ([0, 0, 0, 0], [2, 3, 4], [1, 1, 1, 1], {}),
            # strides shorter than begin/end.
            ([0, 0, 0, 0], [2, 3, 4, 5], [1, 1, 1], {}),
            # Two bits set in ellipsis_mask.
            ([0, 0, 0, 0], [2, 3, 4, 5], [1, 1, 1, 1], {"ellipsis_mask": 0b1010}),
            # Five new axes on top of the rank-4 input (nine positions in total).
            (
                [0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 2, 3, 4, 5],
                [1, 1, 1, 1, 1, 1, 1, 1, 1],
                {"new_axis_mask": 0b11111},
            ),
        ]
        for begin, end, strides, masks in invalid_cases:
            # subTest identifies the offending case and keeps checking the others.
            with self.subTest(begin=begin, end=end, strides=strides, **masks):
                with self.assertRaises(Exception):
                    self.npu_op_exec(self._npu_input, begin, end, strides, **masks)

        # begin > end with a positive stride gives an empty result, not an error.
        result = self.npu_op_exec(self._npu_input, [1, 0, 0, 0], [0, 3, 4, 5], [1, 1, 1, 1])
        self.assertEqual(result.size, 0)

        # Out-of-range begin/end values are expected to select the whole input.
        result = self.npu_op_exec(self._npu_input, [-10, -10, -10, -10], [10, 10, 10, 10], [1, 1, 1, 1])
        expected = self._cpu_input[:, :, :, :]
        self.assertRtolEqual(expected, result)

        # begin == end with a negative stride gives an empty result.
        result = self.npu_op_exec(self._npu_input, [0, 0, 0, 0], [0, 0, 0, 0], [-1, -1, -1, -1])
        self.assertEqual(result.size, 0)
# Run the test cases through the torch_npu test runner when executed as a script.
if __name__ == "__main__":
    run_tests()