import random
import numpy as np
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices
class TestReshapeAndCache(TestCase):
    """Tests for ``torch_npu._npu_reshape_and_cache``.

    Each test builds random key/value tokens and a cache, computes the
    expected cache contents on the CPU with a reference implementation
    (``cal_nd`` for the ND layout, ``cal_nz`` for the NZ layout), runs the
    NPU kernel in place on a device copy of the cache, and compares results.
    """

    num_tokens = 14
    num_head = 1
    head_size = 16
    # Block size of the ND-layout caches in test_reshape_and_cache.
    block_size = 16
    num_blocks = 53535
    # Format id passed to torch_npu.npu_format_cast for the NZ caches
    # (29 is the ACL FRACTAL_NZ format).
    nz_format = 29
    # Block size (dim 2) of the NZ caches built in test_reshape_and_cache_int8.
    nz_block_size = 128

    def cal_nd(self, key, value, key_cache, value_cache, slot_mapping):
        """Reference scatter of tokens into ND-layout caches.

        Args:
            key: tensor indexed per token, as built by the tests:
                ``(num_tokens, num_head, head_size_k)``.
            value: ``(num_tokens, num_head, head_size_v)`` tensor.
            key_cache: ``(num_blocks, block_size, num_head, head_size_k)``.
            value_cache: ``(num_blocks, block_size, num_head, head_size_v)``.
            slot_mapping: one flat slot index per token; slot ``s`` maps to
                block ``s // block_size``, offset ``s % block_size``.
                Negative slots are skipped.

        Returns:
            ``(key_expect, value_expect)``: updated copies of the caches.
            The input caches are not modified.
        """
        key_expect = key_cache.clone()
        value_expect = value_cache.clone()
        # Take the block size from the cache itself so the reference always
        # agrees with the layout it is writing into.
        block_size = key_cache.shape[1]
        for i, slot in enumerate(slot_mapping):
            if slot < 0:
                continue
            block_index = slot // block_size
            block_offset = slot % block_size
            key_expect[block_index, block_offset] = key[i]
            value_expect[block_index, block_offset] = value[i]
        return key_expect, value_expect

    def cal_nz(self, key, value, key_cache, value_cache, slot_mapping):
        """Reference scatter of tokens into NZ-layout caches.

        The caches are indexed as ``[block, chunk, offset_in_block, :]``: each
        token's features are flattened and split into contiguous chunks whose
        width is the cache's last dimension (32 for the int8 key cache and 16
        for the bfloat16 value cache built in test_reshape_and_cache_int8).

        Args:
            key: ``(num_tokens, num_head, head_size_k)`` tensor.
            value: ``(num_tokens, num_head, head_size_v)`` tensor.
            key_cache: ``(num_blocks, num_chunks_k, block_size, chunk_k)``.
            value_cache: ``(num_blocks, num_chunks_v, block_size, chunk_v)``.
            slot_mapping: one flat slot index per token; negative slots are
                skipped, as in ``cal_nd``.

        Returns:
            ``[key_expect_nz, value_expect_nz]``: updated copies of the
            caches. The input caches are not modified.
        """
        key_expect_nz = key_cache.clone()
        # Must be a clone, not an alias: writing into value_cache itself would
        # leak the expected values into the tensor the test later uploads to
        # the NPU, making the value comparison pass vacuously.
        value_expect_nz = value_cache.clone()
        block_size = key_cache.shape[2]
        chunk_k = key_cache.shape[3]
        chunk_v = value_cache.shape[3]
        for i, slot in enumerate(slot_mapping):
            if slot < 0:
                continue
            block_index = slot // block_size
            block_offset = slot % block_size
            # Flatten (num_head, head_size) and lay it out as (num_chunks, chunk).
            key_expect_nz[block_index, :, block_offset, :] = key[i].reshape(-1, chunk_k)
            value_expect_nz[block_index, :, block_offset, :] = value[i].reshape(-1, chunk_v)
        return [key_expect_nz, value_expect_nz]

    @SupportedDevices(['Ascend910B'])
    def test_reshape_and_cache(self):
        """float16 key/value with random head sizes, ND-layout caches."""
        head_size_k = np.random.randint(1, 256)
        head_size_v = np.random.randint(1, 256)
        key = torch.rand((self.num_tokens, self.num_head, head_size_k), dtype=torch.float16)
        value = torch.rand((self.num_tokens, self.num_head, head_size_v), dtype=torch.float16)
        num_slots = self.block_size * self.num_blocks
        # Distinct slots, so no token overwrites another.
        slot_list = random.sample(range(num_slots), self.num_tokens)
        slot_mapping = np.array(slot_list).astype(np.int32)
        key_cache = torch.rand((self.num_blocks, self.block_size, self.num_head, head_size_k), dtype=torch.float16)
        value_cache = torch.rand((self.num_blocks, self.block_size, self.num_head, head_size_v), dtype=torch.float16)
        key_expect, value_expect = self.cal_nd(key, value, key_cache, value_cache, slot_mapping)

        key = key.npu()
        value = value.npu()
        key_cache = key_cache.npu()
        value_cache = value_cache.npu()
        slot_mapping = torch.from_numpy(slot_mapping).to(torch.int32).npu()
        # The kernel updates key_cache / value_cache in place.
        torch_npu._npu_reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)
        self.assertRtolEqual(key_expect, key_cache)
        self.assertRtolEqual(value_expect, value_cache)

    @SupportedDevices(['Ascend910B'])
    def test_reshape_and_cache_int8(self):
        """int8 keys and bfloat16 values written into NZ-format caches."""
        head_size_k = 512
        head_size_v = 64
        key = torch.randint(-128, 128, (self.num_tokens, self.num_head, head_size_k), dtype=torch.int8)
        value = torch.rand((self.num_tokens, self.num_head, head_size_v), dtype=torch.bfloat16)
        # NZ caches: (num_blocks, num_chunks, nz_block_size, chunk) with
        # 16 * 32 = 512 key features and 4 * 16 = 64 value features.
        # NOTE(review): at this num_blocks the int8 key cache alone is ~3.5 GB
        # (and cal_nz clones it) — confirm this footprint is intended.
        key_cache = torch.randint(-128, 128, (self.num_blocks, 16, self.nz_block_size, 32), dtype=torch.int8)
        value_cache = torch.rand((self.num_blocks, 4, self.nz_block_size, 16), dtype=torch.bfloat16)
        # Draw slots from the whole NZ cache; its block size differs from the
        # ND test's self.block_size.
        num_slots = self.nz_block_size * self.num_blocks
        slot_list = random.sample(range(num_slots), self.num_tokens)
        slot_mapping = np.array(slot_list).astype(np.int32)
        key_expect_nz, value_expect_nz = self.cal_nz(key, value, key_cache, value_cache, slot_mapping)

        key_cache = torch_npu.npu_format_cast(key_cache.npu(), self.nz_format)
        value_cache = torch_npu.npu_format_cast(value_cache.npu(), self.nz_format)
        key = key.npu()
        value = value.npu()
        slot_mapping = torch.from_numpy(slot_mapping).to(torch.int32).npu()
        # The kernel updates key_cache / value_cache in place.
        torch_npu._npu_reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)
        self.assertRtolEqual(key_expect_nz, key_cache)
        self.assertRtolEqual(value_expect_nz, value_cache)
if __name__ == '__main__':
    # Run this module's TestCase classes through torch_npu's test runner.
    run_tests()