import os
import copy
import shutil
import random
import torch
import torch_npu
import numpy as np
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import get_npu_device, SupportedDevices
def cumsum2index(seq_lens):
if seq_lens is None:
raise KeyError('seq_lens is None')
seq_lens = np.asarray(seq_lens)
batch = len(seq_lens) - 1
result = np.zeros(batch, dtype=seq_lens.dtype)
for i in range(batch):
result[i] = int(seq_lens[i + 1] - seq_lens[i])
return result
def gather_pa_kv_cache_nd(context, key_cache, value_cache, block_tables, seq_lens, key, value, seq_offset):
num_blocks, block_size, num_heads, head_size_k = key_cache.shape
is_seq_lens_cumsum = context.get('is_seq_lens_cumsum', False)
num_tokens = key.shape[0]
kv_rslt_id = 0
if is_seq_lens_cumsum:
seq_lens = cumsum2index(seq_lens)
accum_seq_len = 0
for i in range(len(seq_lens)):
block_table = block_tables[i]
seq_len = seq_lens[i]
if num_blocks > accum_seq_len and num_tokens <= (acc_seq_len + seq_len):
seq_len = num_tokens - accum_seq_len
accum_seq_len += seq_len
if seq_offset is None:
block_start = 0
else:
block_start = seq_offset[i] // block_size
for j in range(seq_len):
if kv_rslt_id >= key.shape[0]:
break
block_table_idx = block_start + j // block_size
if block_table_idx >= block_table.shape[0]:
is_filled_with_zero = True
block_id = -1
else:
is_filled_with_zero = False
block_id = block_table[block_table_idx]
block_offset = j % block_size
if block_id >= num_blocks or block_id < 0 or is_filled_with_zero:
temp_k = np.zeros_like(key_cache[0][0])
temp_v = np.zeros_like(value_cache[0][0])
else:
temp_k = key_cache[block_id][block_offset]
temp_v = value_cache[block_id][block_offset]
key[kv_rslt_id] = temp_k
value[kv_rslt_id] = temp_v
kv_rslt_id += 1
return [key, value]
def gather_pa_kv_cache_nz(context, key_cache, value_cache, block_tables, seq_lens, key, value, seq_offset):
num_blocks, _, block_size, elenum_aligned = key_cache.shape
num_tokens, num_heads, head_size_k = key.shape
num_tokens, num_heads, head_size_v = value.shape
is_seq_lens_cumsum = context.get('is_seq_lens_cumsum', False)
num_heads_k = num_heads * head_size_k
num_heads_v = num_heads * head_size_v
key = key.reshape((num_tokens, num_heads_k))
value = value.reshape((num_tokens, num_heads_v))
if is_seq_lens_cumsum:
seq_lens = cumsum2index(seq_lens)
kv_rslt_id = 0
for i in range(len(seq_lens)):
block_table = block_tables[i]
seq_len = seq_lens[i]
if seq_offset is None:
block_table = 0
else:
block_start = seq_offset[i] // block_size
for j in range(seq_len):
if kv_rslt_id >= key.shape[0]:
break
block_table_idx = block_start + j // block_size
if block_table_idx >= block_table.shape[0]:
block_id = -1
else:
block_id = block_table[block_table_idx]
block_offset = j % block_size
temp_k = np.zeros_like((num_heads_k,), dtype=key.dtype)
temp_v = np.zeros_like((num_heads_v,), dtype=value.dtype)
if block_id >= 0 and block_id < num_blocks:
for k in range(num_heads_k // elenum_aligned):
temp_k[k * elenum_aligned: (k + 1) * elenum_aligned] = \
key_cache[block_id][k][block_offset][:]
for k in range(num_heads_v // elenum_aligned):
temp_v[k * elenum_aligned: (k + 1) * elenum_aligned] = \
value_cache[block_id][k][block_offset][:]
key[kv_rslt_id] = temp_k
value[kv_rslt_id] = temp_v
kv_rslt_id += 1
key = key.reshape((num_tokens, num_heads, head_size_k))
value = value.reshape((num_tokens, num_heads, head_size_v))
return [key, value]
def golden_gather_pa_kv_cache(
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
key_out: torch.Tensor,
value_out: torch.Tensor,
seq_offset: torch.Tensor = None,
is_seq_lens_cumsum: bool = False,
cache_mode: str = 'Norm'
):
key_cache_np = key_cache.cpu().numpy()
value_cache_np = value_cache.cpu().numpy()
block_tables_np = block_tables.cpu().numpy()
seq_lens_np = seq_lens.cpu().numpy()
key_np = key.cpu().numpy().copy()
value_np = value.cpu().numpy().copy()
seq_offset_np = seq_offset.cpu().numpy() if seq_offset is not None else None
context = {
'is_seq_lens_cumsum': is_seq_lens_cumsum,
'cache_mode': cache_mode
}
if cache_mode == 'Norm':
key_result, value_result = gather_pa_kv_cache_nd(
context=context,
key_cache=key_cache_np,
value_cache=value_cache_np,
block_tables=block_tables_np,
seq_lens=seq_lens_np,
key=key_np,
value=value_np,
seq_offset=seq_offset_np
)
elif cache_mode == 'PA_NZ':
key_result, value_result = gather_pa_kv_cache_nz(
context=context,
key_cache=key_cache_np,
value_cache=value_cache_np,
block_tables=block_tables_np,
seq_lens=seq_lens_np,
key=key_np,
value=value_np,
seq_offset=seq_offset_np
)
else:
raise KeyError(f'cache mode can only be one of Norm or PA_NZ')
return torch.from_numpy(key_result), torch.from_numpy(value_result)
class GatherPaKvCacheModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, key_cache, value_cache, block_tables, seq_lens, key, value, seq_offset, is_seq_lens_cumsum):
output_key, output_value = torch_npu.npu_gather_pa_kv_cache_functional(
key_cache, value_cache, block_tables, seq_lens, key, value,
seq_offset=seq_offset, is_seq_lens_cumsum=is_seq_lens_cumsum
)
return output_key, output_value
class GatherPaKvCacheInplaceModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, key_cache, value_cache, block_tables, seq_lens, key, value, seq_offset, is_seq_lens_cumsum):
torch_npu.npu_gather_pa_kv_cache(
key_cache, value_cache, block_tables, seq_lens, key, value,
seq_offset=seq_offset, is_seq_lens_cumsum=is_seq_lens_cumsum
)
return key, value
class TestGatherPaKvCache(TestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
random.seed(0)
torch.manual_seed(0)
def _run_test(self, mode, api_impl_mode, input_dtype=torch.float16):
batch_size = 2
num_blocks = 8
head_num = 4
block_size = 64
head_dim = 64
max_blocks_per_sequence = 5
seq_lens_list = [random.randint(1, 10) for _ in range(batch_size)]
is_seq_lens_cumsum = True
if is_seq_lens_cumsum:
cumsum = [0]
for x in seq_lens_list:
cumsum.append(cumsum[-1] + x)
seq_lens = torch.tensor(seq_lens_list, dtype=torch.int32)
total_tokens = cumsum[-1]
else:
seq_lens = torch.tensor(seq_lens_list, dtype=torch.int32)
total_tokens = sum(seq_lens_list)
key_cache = torch.randn(num_blocks, block_size, head_num, head_dim, dtype=input_dtype)
value_cache = torch.randn(num_blocks, block_size, head_num, head_dim, dtype=input_dtype)
block_tables = torch.randint(0, num_blocks, (batch_size, max_blocks_per_sequence), dtype=torch.int32)
key_out = torch.zeros(total_tokens, head_num, head_dim, dtype=input_dtype)
value_out = torch.zeros(total_tokens, head_num, head_dim, dtype=input_dtype)
seq_offset = torch.randint(0, block_size, (batch_size,), dtype=torch.int32)
key_gold, value_gold = golden_gather_pa_kv_cache(
key_cache=key_cache,
value_cache=value_cache,
block_tables=block_tables,
seq_lens=seq_lens,
key_out=key_out,
value_out=value_out,
seq_offset=seq_offset,
is_seq_lens_cumsum=is_seq_lens_cumsum,
cache_mode='Norm'
)
key_cache_npu = key_cache.npu()
value_cache_npu = value_cache.npu()
block_tables_npu = block_tables.npu()
seq_lens_npu = seq_lens.npu()
key_out_npu = key_out.npu()
value_out_npu = value_out.npu()
seq_offset_npu = seq_offset.npu()
if mode == "inplace":
model = GatherPaKvCacheInplaceModel()
key_input_npu = key_out_npu
value_input_npu = value_out_npu
elif mode == "out_of_place":
model = GatherPaKvCacheModel()
key_input_npu = key_out_npu
value_input_npu = value_out_npu
else:
self.fail(f"Unsupported mode: {mode}")
if api_impl_mode == "eager":
model = torch.compile(model, backend="eager", dynamic=True)
else:
self.fail(f"Unsupported api_impl_mode: {api_impl_mode}")
with torch.npu_gard():
out_key_npu, out_value_npu = model(
key_cache_npu, value_cache_npu, block_tables_npu, seq_lens_npu,
key_input_npu, value_input_npu,
seq_offset=seq_offset_npu,
is_seq_lens_cumsum=is_seq_lens_cumsum
)
torch.npu.synchronize()
out_key_cpu = out_key_npu.cpu()
out_value_cpu = out_value_npu.cpu()
rtol = 1e-3 if input_dtype == torch.float16 else 1e-4
atol = 1e-3 if input_dtype == torch.float16 else 1e-4
self.assertRtolEqual(key_gold, out_key_cpu, rtol=rtol, atol=atol)
self.assertRtolEqual(value_gold, out_value_cpu, rtol=rtol, atol=atol)
@SupportedDevices(['Ascend950'])
def test_out_of_place_eager_fp16(self):
self._run_test("out_of_place", "eager", torch.float16)
@SupportedDevices(['Ascend950'])
def test_inplace_eager_fp16(self):
self._run_test("inplace", "eager", torch.float16)
if __name__ == "__main__":
run_tests()