import math
import unittest
import numpy as np
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices
# Module-level NPU configuration; these statements run once when the module is
# imported, before any test in this file executes.
# Request non-JIT compile mode (jit_compile=False).
# NOTE(review): presumably this makes operators use prebuilt kernels instead of
# compiling them on the fly - confirm against the torch_npu documentation.
torch.npu.set_compile_mode(jit_compile=False)
# Turn off the "allow internal format" switch.
# NOTE(review): presumably keeps tensors in standard layouts rather than
# NPU-private ones - confirm against the torch_npu documentation.
torch.npu.config.allow_internal_format = False
class TestScatterList(TestCase):
    """Check ``torch_npu.npu_scatter_list`` against an indexing-based reference."""

    def supported_op_exec(self, var_list, indice, updates, axis=-2):
        """Compute the expected scatter result with plain tensor indexing.

        For every tensor ``var_list[i]`` the slot selected by ``indice[i]``
        along ``axis`` is overwritten with the matching slice of
        ``updates[i]`` (``updates[i][j][0]`` for ``axis=-2``,
        ``updates[i][j][k][0]`` for ``axis=-1``).

        The indexing pattern assumes 3-D tensors in ``var_list`` that all
        share the leading dimensions of ``var_list[0]``. No mask is taken
        into account: every tensor in the list is updated.

        Args:
            var_list: Sequence of tensors to scatter into. The inputs are
                left untouched; the result is built on clones.
            indice: Indexable of per-tensor positions along ``axis``.
            updates: Tensor holding the replacement values.
            axis: Axis to scatter along; only ``-2`` and ``-1`` are handled.

        Returns:
            A new list of tensors holding the expected values.

        Raises:
            ValueError: If ``axis`` is neither ``-2`` nor ``-1``.
        """
        if axis not in (-2, -1):
            # Returning the inputs unchanged would look like a valid
            # reference result, so fail loudly instead.
            raise ValueError(
                "supported_op_exec only handles axis -2 or -1, got {}".format(axis)
            )

        # Work on copies: writing into the caller's tensors would hand the
        # operator under test inputs that already contain the expected
        # values, so the comparison could never fail.
        expected = [var.clone() for var in var_list]
        if not expected:
            return expected

        outer = expected[0].shape[0]
        if axis == -2:
            for i, item in enumerate(expected):
                for j in range(outer):
                    item[j][indice[i]] = updates[i][j][0]
        else:
            inner = expected[0].shape[1]
            for i, item in enumerate(expected):
                for j in range(outer):
                    for k in range(inner):
                        item[j][k][indice[i]] = updates[i][j][k][0]
        return expected

    def custom_op_exec(self, var_list, indice, updates, mask, reduce='update', axis=-2):
        """Run the NPU operator under test and return its result.

        Args:
            var_list: Sequence of tensors to scatter into.
            indice: Per-tensor positions along ``axis``.
            updates: Tensor holding the replacement values.
            mask: Mask tensor forwarded to the operator unchanged.
            reduce: Reduction mode forwarded to the operator.
            axis: Axis forwarded to the operator.

        Returns:
            Whatever ``torch_npu.npu_scatter_list`` returns for the arguments.
        """
        return torch_npu.npu_scatter_list(var_list, indice, updates, mask, reduce, axis)

    @SupportedDevices(['Ascend910B'])
    @unittest.skip("Temporarily skipping")
    def test_npu_scatter_list(self, device="npu"):
        """Scatter a row of ones into eight zero tensors and compare results."""
        # The check only runs on torch newer than 2.0; otherwise the test
        # returns without asserting anything.
        # NOTE(review): relies on torch's version object comparing sensibly
        # against a plain string - confirm.
        if not torch.__version__ > '2.0':
            return

        num_tensors = 8
        var_list = [
            torch.zeros([4, 4096, 256], dtype=torch.float16).npu()
            for _ in range(num_tensors)
        ]
        indice = torch.zeros([num_tensors], dtype=torch.int32).npu()
        updates = torch.ones([num_tensors, 4, 1, 256], dtype=torch.float16).npu()
        mask = torch.ones([num_tensors], dtype=torch.uint8).npu()

        supported_output = self.supported_op_exec(var_list, indice, updates, axis=-2)
        custom_output = self.custom_op_exec(var_list, indice, updates, mask)

        # Guard against a short result before comparing element by element.
        self.assertEqual(len(supported_output), len(custom_output))
        for expected, actual in zip(supported_output, custom_output):
            self.assertEqual(expected, actual)
# Script entry point: when the file is executed directly, hand control to the
# run_tests helper imported from torch_npu.testing.testcase.
if __name__ == "__main__":
    run_tests()