# 989aa0be — created on 2024-11-01 (see commit history)
import math
import unittest
import numpy as np
import torch

import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices

# Disable JIT kernel compilation for NPU ops in this module (jit_compile=False).
torch.npu.set_compile_mode(jit_compile=False)
# Disallow NPU-internal tensor formats; presumably this keeps outputs in a
# standard layout so they compare directly with the reference — confirm.
torch.npu.config.allow_internal_format = False


class TestNpuScatterNdUpdate(TestCase):
    """Compare ``torch_npu.npu_scatter_nd_update`` with a pure-PyTorch reference."""

    def supported_op_exec(self, var, indices_tensor, updates):
        """Reference (golden) scatter-nd-update.

        Each row along the last axis of ``indices_tensor`` is a coordinate into
        the leading ``K`` dimensions of ``var`` (``K = indices_tensor.shape[-1]``),
        and the matching slice of ``updates`` is written at that coordinate.
        ``var`` is modified in place and also returned, so callers that still
        need the original values must pass a copy.

        Args:
            var: Tensor to update (modified in place).
            indices_tensor: Integer tensor of shape ``(..., K)``.
            updates: Tensor or array-like whose layout is
                ``indices_tensor.shape[:-1] + var.shape[K:]``.

        Returns:
            ``var`` after every update has been written.
        """
        depth = indices_tensor.shape[-1]
        # One coordinate list per index row; leading batch dims are flattened.
        coords = indices_tensor.reshape(-1, depth).tolist()
        # Line the updates up with the flattened coordinates (numpy input allowed).
        values = torch.as_tensor(updates).reshape(len(coords), *var.shape[depth:])
        # Visit every index row. Bounding this loop by len(var) would silently
        # drop updates whenever there are more index rows than rows in var.
        for coord, value in zip(coords, values):
            var[tuple(coord)] = value
        return var

    def custom_op_exec(self, var, indices_tensor, updates):
        """Run the NPU kernel under test and return its output."""
        return torch_npu.npu_scatter_nd_update(var, indices_tensor, updates)

    def _assert_matches_reference(self, var, indices, updates, device="npu"):
        """Run both implementations on independent inputs and compare them.

        ``var``, ``indices`` and ``updates`` are CPU tensors. The reference runs
        on a CPU clone of ``var`` because it mutates its input. Sharing ``var``
        with the NPU op would hand the op an already-updated tensor and make
        the comparison pass trivially.
        """
        expected = self.supported_op_exec(var.clone(), indices, updates)
        actual = self.custom_op_exec(
            var.to(device), indices.to(device), updates.to(device)
        )
        self.assertEqual(expected, actual.cpu())

    @SupportedDevices(['Ascend910B'])
    def test_npu_scatter_nd_update(self, device="npu"):
        var = torch.zeros([3, 2], dtype=torch.float16)
        indices = torch.from_numpy(np.array([[0, 0], [1, 1]]))
        updates = torch.from_numpy(np.array([10, 20]).astype(np.float16))
        self._assert_matches_reference(var, indices, updates, device)

    @SupportedDevices(['Ascend910B'])
    def test_npu_scatter_nd_update_more_indices_than_rows(self, device="npu"):
        # Three index rows but only two rows in var: a reference bounded by
        # len(var) would skip the last update and disagree with the op.
        var = torch.zeros([2, 3], dtype=torch.float16)
        indices = torch.from_numpy(np.array([[0, 0], [0, 2], [1, 1]]))
        updates = torch.from_numpy(np.array([10, 20, 30]).astype(np.float16))
        self._assert_matches_reference(var, indices, updates, device)


# Script entry point: hand control to torch_npu's test runner.
if __name__ == "__main__":
    run_tests()