import unittest
import copy
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices
class TestScatterPaCache(TestCase):
    """Check ``torch_npu.npu_scatter_pa_cache`` against a pure-PyTorch reference."""

    def supported_op_exec(self, key, keyCache, slotMapping):
        """Reference scatter of per-token ``key`` rows into a paged cache.

        Each entry of ``slotMapping`` is a flat slot index. It is split into
        ``(block_idx, block_offset)`` using the cache block size, taken from
        dimension 1 of ``keyCache``.

        Args:
            key: tensor indexed as ``key[token, :, :]``; presumably
                ``(num_tokens, num_heads, head_size)`` -- TODO confirm.
            keyCache: 4-D cache indexed as ``[block, offset, :, :]``;
                presumably ``(num_blocks, block_size, num_heads, head_size)``.
            slotMapping: 1-D tensor holding one flat slot index per token.

        Returns:
            A copy of ``keyCache`` with every token's key written into its slot.
            ``keyCache`` itself is left untouched.
        """
        num_tokens = key.shape[0]
        block_size = keyCache.shape[1]
        key_cache_out = keyCache.clone()
        # Pull the slot indices to the host once rather than indexing a
        # device tensor element by element inside the loop.
        slots = slotMapping.tolist()
        for i in range(num_tokens):
            # A flat slot addresses block `slot // block_size` at position
            # `slot % block_size`. The original code used the raw slot for
            # both, which overflows dim 0/1 as soon as slot >= block_size.
            block_idx, block_offset = divmod(slots[i], block_size)
            key_cache_out[block_idx, block_offset, :, :] = key[i, :, :]
        return key_cache_out

    def custom_op_exec(self, key, keyCache, slotMapping):
        """Run the NPU kernel under test and return its result.

        NOTE(review): it is not visible here whether the op also updates
        ``keyCache`` in place; the reference is computed on a copy first,
        so either behaviour is safe for the comparison below.
        """
        return torch_npu.npu_scatter_pa_cache(key, slotMapping, key_cache=keyCache)

    @SupportedDevices(['Ascend950'])
    def test_npu_scatter_pa_cache(self, device="npu"):
        """Scatter one token into every slot of a 16x16-block cache and compare."""
        num_blocks, block_size, num_heads, head_size = 16, 16, 16, 16
        # One token per slot, so every cache slot is written exactly once.
        num_tokens = num_blocks * block_size
        key = torch.randint(
            -1, 1, (num_tokens, num_heads, head_size), dtype=torch.float32
        ).npu()
        keyCache = torch.randint(
            -1, 1, (num_blocks, block_size, num_heads, head_size), dtype=torch.float32
        ).npu()
        slotMapping = torch.arange(0, num_tokens).to(torch.int32).npu()
        # Compute the reference first: it works on a clone, so a possible
        # in-place update by the custom op cannot leak into it.
        supported_output = self.supported_op_exec(key, keyCache, slotMapping)
        custom_output = self.custom_op_exec(key, keyCache, slotMapping)
        self.assertRtolEqual(supported_output, custom_output)
if __name__ == "__main__":
    # Only hand control to the torch_npu test runner when run as a script.
    run_tests()