import math
import unittest
import copy
import struct
from struct import pack, unpack
import numpy as np
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices
class TestQuantScatter(TestCase):
    """Tests for ``torch_npu.npu_quant_scatter``.

    The int8 path is compared against a reference built from
    ``torch_npu.npu_quantize`` followed by ``torch_npu.scatter_update``.
    The fp8 path is expected to raise a ``RuntimeError``.
    """

    def _make_inputs(self, var_torch_dtype):
        """Build ``(var, indices, updates, quant_scales)`` NPU tensors.

        Args:
            var_torch_dtype: torch dtype that ``var`` is converted to.

        Returns:
            ``var``: shape [1, 1, 32], dtype ``var_torch_dtype``.
            ``indices``: int32, shape [1].
            ``updates``: bfloat16, shape [1, 1, 32].
            ``quant_scales``: bfloat16, shape [1, 1, 32].
        """
        # uniform(0, 1) truncated by astype(int8)/astype(int32) is always 0.
        # So var starts zeroed and indices always selects position 0.
        var_data = np.random.uniform(0, 1, [1, 1, 32]).astype(np.int8)
        var = torch.from_numpy(var_data).to(var_torch_dtype).npu()
        indices_data = np.random.uniform(0, 1, [1]).astype(np.int32)
        indices = torch.from_numpy(indices_data).to(torch.int32).npu()
        updates_data = np.random.uniform(1, 2, [1, 1, 32]).astype(np.float16)
        updates = torch.from_numpy(updates_data).to(torch.bfloat16).npu()
        quant_scales_data = np.random.uniform(0, 1, [1, 1, 32]).astype(np.float16)
        quant_scales = torch.from_numpy(quant_scales_data).to(torch.bfloat16).npu()
        return var, indices, updates, quant_scales

    def supported_op_exec(self, var, indices, updates, quant_scales):
        """Reference implementation: quantize ``updates``, then scatter them into ``var``.

        ``quant_scales`` is flattened to 1-D before being passed to
        ``npu_quantize``. Using ``view(-1)`` instead of a hard-coded length
        keeps the reference valid for any scale size.
        """
        quant_scales_new = quant_scales.view(-1)
        updates_new = torch_npu.npu_quantize(updates, quant_scales_new, None, torch.qint8, -1).to(torch.int8)
        return torch_npu.scatter_update(var, indices, updates_new, -2)

    def custom_op_exec(self, var, indices, updates, quant_scales):
        """Run the fused ``npu_quant_scatter`` op with axis=-2, quant_axis=-1, reduce="update"."""
        return torch_npu.npu_quant_scatter(var, indices, updates, quant_scales, None, -2, -1, "update")

    @SupportedDevices(['Ascend910B'])
    def test_npu_quant_scatter(self, device="npu"):
        var1, indices1, updates1, quant_scales1 = self._make_inputs(torch.int8)
        # Give each path its own copies so the two executions cannot affect each other.
        var2 = var1.clone()
        indices2 = indices1.clone()
        updates2 = updates1.clone()
        quant_scales2 = quant_scales1.clone()
        supported_output = self.supported_op_exec(var1, indices1, updates1, quant_scales1)
        custom_output = self.custom_op_exec(var2, indices2, updates2, quant_scales2)
        self.assertRtolEqual(supported_output, custom_output, 0.001)

    @SupportedDevices(['Ascend950'])
    def test_npu_quant_scatter_fp8(self, device="npu"):
        var, indices, updates, quant_scales = self._make_inputs(torch.float8_e5m2)
        # The expected error text is not known here, so only the exception type is checked.
        # TODO(review): pin the message with assertRaisesRegex once it is confirmed.
        # NOTE(review): confirm that ``torch_npu.bfloat16`` exists as a dst_type constant.
        # If it does not, an AttributeError propagates instead of a RuntimeError.
        with self.assertRaises(RuntimeError):
            torch_npu.npu_quant_scatter(var, indices, updates, quant_scales, None, -2, -1,
                                        "update", torch_npu.bfloat16, "rint")
if __name__ == "__main__":
    # Run this module's tests through the torch_npu test runner when executed as a script.
    run_tests()