from dataclasses import dataclass
from typing import Optional
import unittest
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices
@dataclass()
class GatherCacheParams:
    """Inputs for the pure-PyTorch reference paged-cache gather.

    All fields are required except ``seq_starts``, which defaults to ``None``.
    """

    # Paged cache the tokens are copied from (dim 0 = block, dim 1 = slot).
    src_cache: torch.Tensor
    # Destination tensor the gathered tokens are written into, in place.
    dst: torch.Tensor
    # One row of block indices into ``src_cache`` per batch entry.
    block_table: torch.Tensor
    # Cumulative per-sequence token counts; entry i is where sequence i starts in ``dst``.
    cu_seq_lens: torch.Tensor
    # Number of batch entries to gather.
    batch_size: int
    # Optional per-batch start offsets into the block table; None means no offset.
    seq_starts: Optional[torch.Tensor] = None
@dataclass
class TestPagedCacheLoadParams:
    """Configuration for one paged-cache-load test case.

    The class name starts with ``Test``, so pytest would try to collect it as a
    test class and emit a collection warning, because a dataclass defines
    ``__init__``. ``__test__ = False`` opts it out of collection without renaming it.
    """

    # Unannotated class attribute: pytest reads it, and dataclass does not
    # treat it as a field, so the generated __init__ is unchanged.
    __test__ = False

    kv_lora_rank: int  # width of the kv_c slice of each cache entry
    qk_rope_head_dim: int  # width of the k_pe slice of each cache entry
    block_size: int  # token slots per cache block
    num_blocks: int  # number of blocks in the cache
    max_seq_len: int  # inclusive upper bound for random sequence lengths
    batch_size: int  # number of sequences in the batch
    device: str  # torch device string the tensors are created on
@dataclass()
class PreparedData:
    """Everything a paged-cache-load test needs after setup.

    Field order is significant: instances are built positionally.
    """

    # Reference output produced by the pure-PyTorch gather.
    expected: torch.Tensor
    # Per-sequence token counts.
    seq_len_tensor: torch.Tensor
    # Block indices per batch entry.
    block_table: torch.Tensor
    # Per-batch start offsets, or None.
    seq_starts: Optional[torch.Tensor]
    # kv_c slice of the cache.
    cached_kv_c: torch.Tensor
    # k_pe slice of the cache.
    cached_k_pe: torch.Tensor
class TestPagedCacheLoadSeqStarts(TestCase):
    """Compare ``torch_npu.atb.npu_paged_cache_load`` (with ``seq_starts``)
    against a pure-PyTorch gather reference.

    Both tests build the same kind of random cache, block table and sequence
    layout through ``_prepare_data``. ``test_atb_paged_cache_load_out`` passes
    preallocated ``key``/``value`` tensors to the op. ``test_atb_paged_cache_load``
    uses the tensors the op returns.
    """

    def _create_mla_cache(
        self,
        num_blocks: int,
        block_size: int,
        entry_size: int,
        device: str,
    ) -> torch.Tensor:
        """Return a random float16 cache of shape ``[num_blocks, block_size, entry_size]``."""
        return torch.randn(
            num_blocks, block_size, entry_size, dtype=torch.float16, device=device
        )

    def _gather_cache_torch(
        self,
        gather_cache_params: GatherCacheParams,
    ) -> None:
        """Gather each sequence's tokens from a paged cache into ``dst``, in place.

        For batch entry ``bid``, rows ``cu_seq_lens[bid]:cu_seq_lens[bid + 1]`` of
        ``dst`` are filled block by block. The blocks come from
        ``block_table[bid, offset:]``, where ``offset`` is
        ``seq_starts[bid] // block_size`` (0 when ``seq_starts`` is None). Only
        whole-block offsets are applied; the remainder of ``seq_starts[bid]``
        modulo ``block_size`` is ignored.

        Fields of ``gather_cache_params``:
            src_cache: Source cache. Dim 0 indexes blocks and dim 1 indexes slots
                within a block. ``_create_mla_cache`` produces
                ``[NUM_BLOCKS, BLOCK_SIZE, ENTRIES]``. A
                ``[NUM_BLOCKS, BLOCK_SIZE, 1, ENTRIES]`` cache also works, because
                each per-block slice is ``squeeze(1)``-ed.
            dst: Destination ``[TOT_TOKENS, ENTRIES]``; same dtype as ``src_cache``.
            block_table: int32 block indices ``[BATCH, MAX_BLOCKS]``.
            cu_seq_lens: int32 cumulative sequence lengths ``[BATCH + 1]``.
            batch_size: Number of batch entries to process.
            seq_starts: Optional int32 per-batch start offsets ``[BATCH]``.

        Raises:
            AssertionError: If any of the dtype requirements above is violated.
        """
        src_cache = gather_cache_params.src_cache
        dst = gather_cache_params.dst
        block_table = gather_cache_params.block_table
        cu_seq_lens = gather_cache_params.cu_seq_lens
        seq_starts = gather_cache_params.seq_starts

        assert src_cache.dtype == dst.dtype, "src_cache and dst must have same dtype"
        assert block_table.dtype == torch.int32, "block_table must be int32"
        assert cu_seq_lens.dtype == torch.int32, "cu_seq_lens must be int32"
        if seq_starts is not None:
            assert seq_starts.dtype == torch.int32, "seq_starts must be int32"

        block_size = src_cache.size(1)
        for bid in range(gather_cache_params.batch_size):
            seq_start = cu_seq_lens[bid].item()
            seq_len = cu_seq_lens[bid + 1].item() - seq_start
            if seq_len == 0:
                continue
            # Ceil-divide: number of cache blocks this sequence spans.
            tot_blocks = (seq_len + block_size - 1) // block_size
            # Only whole blocks of seq_starts shift this batch's block-table row.
            offset = 0 if seq_starts is None else seq_starts[bid].item() // block_size
            batch_block_table = block_table[bid, offset: offset + tot_blocks]
            full_blocks, partial_block_size = divmod(seq_len, block_size)

            dst_start = seq_start
            for i in range(full_blocks):
                block_id = batch_block_table[i].item()
                dst[dst_start: dst_start + block_size] = src_cache[block_id].squeeze(1)
                dst_start += block_size
            if partial_block_size > 0:
                # Trailing block: only its first partial_block_size slots are used.
                block_id = batch_block_table[full_blocks].item()
                dst[dst_start: dst_start + partial_block_size] = src_cache[
                    block_id, :partial_block_size
                ].squeeze(1)

    def _prepare_data(self, test_params: TestPagedCacheLoadParams) -> PreparedData:
        """Build random inputs and the reference output for one test case.

        Returns a ``PreparedData`` containing:
          - ``expected``: ``[total_tokens, kv_lora_rank + qk_rope_head_dim]``,
            filled by ``_gather_cache_torch``;
          - random per-sequence lengths, block table and ``seq_starts``;
          - the cache split into its kv_c and k_pe parts, each viewed as
            ``[num_blocks, block_size, 1, dim]``.
        """
        kv_lora_rank = test_params.kv_lora_rank
        qk_rope_head_dim = test_params.qk_rope_head_dim
        block_size = test_params.block_size
        num_blocks = test_params.num_blocks
        max_seq_len = test_params.max_seq_len
        batch_size = test_params.batch_size
        device = test_params.device
        entry_size = kv_lora_rank + qk_rope_head_dim

        src_cache = self._create_mla_cache(num_blocks, block_size, entry_size, device)

        # Sequence lengths in [0, max_seq_len]; zero-length sequences are possible.
        seq_len_tensor = torch.randint(0, max_seq_len + 1, (batch_size,), device=device)
        # Python int so it can be used directly as a tensor dimension.
        total_tokens = int(seq_len_tensor.sum().item())
        cu_seq_lens = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
        cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32)

        # Every batch row is an independent random permutation of all block ids.
        block_table = torch.empty(
            (batch_size, num_blocks), dtype=torch.int32, device=device
        )
        for b in range(batch_size):
            block_table[b, :] = torch.randperm(num_blocks, device=device)

        expected = torch.zeros(
            (total_tokens, entry_size), dtype=src_cache.dtype, device=device
        )
        seq_starts = torch.randint(
            0, max_seq_len // 2 + 1, (batch_size,), dtype=torch.int32, device=device
        )
        self._gather_cache_torch(
            GatherCacheParams(
                src_cache=src_cache,
                dst=expected,
                block_table=block_table,
                cu_seq_lens=cu_seq_lens,
                batch_size=batch_size,
                seq_starts=seq_starts,
            )
        )

        cached_kv_c, cached_k_pe = src_cache.split(
            [kv_lora_rank, qk_rope_head_dim], dim=2
        )
        # Insert a singleton axis: [num_blocks, block_size, 1, dim].
        # NOTE(review): split() returns strided views, so these stay
        # non-contiguous (.to() is a no-op on the float16 cache). Confirm the op
        # accepts non-contiguous inputs.
        cached_kv_c = cached_kv_c.view(num_blocks, block_size, 1, kv_lora_rank).to(
            torch.float16
        )
        cached_k_pe = cached_k_pe.view(num_blocks, block_size, 1, qk_rope_head_dim).to(
            torch.float16
        )
        return PreparedData(
            expected=expected,
            seq_len_tensor=seq_len_tensor,
            block_table=block_table,
            seq_starts=seq_starts,
            cached_kv_c=cached_kv_c,
            cached_k_pe=cached_k_pe,
        )

    @staticmethod
    def _default_params() -> TestPagedCacheLoadParams:
        """Return the configuration shared by both NPU tests."""
        return TestPagedCacheLoadParams(
            kv_lora_rank=512,
            qk_rope_head_dim=64,
            block_size=16,
            num_blocks=1024,
            max_seq_len=512,
            batch_size=8,
            device="npu",
        )

    def _assert_matches_reference(
        self,
        test_params: TestPagedCacheLoadParams,
        prepared_data: PreparedData,
        kv_c: torch.Tensor,
        k_pe: torch.Tensor,
    ) -> None:
        """Concatenate the op's key/value outputs and compare them with the reference."""
        total_tokens = int(prepared_data.seq_len_tensor.sum().item())
        entry_size = test_params.kv_lora_rank + test_params.qk_rope_head_dim
        # Give both dimensions explicitly: view(0, -1) raises when every
        # sequence length is 0.
        torch_npu_result = torch.cat([kv_c, k_pe], dim=2).view(total_tokens, entry_size)
        self.assertRtolEqual(prepared_data.expected, torch_npu_result)

    @SupportedDevices(["Ascend910B"])
    def test_atb_paged_cache_load_out(self):
        """The op writes into caller-provided ``key``/``value`` tensors."""
        test_params = self._default_params()
        prepared_data = self._prepare_data(test_params)
        total_tokens = int(prepared_data.seq_len_tensor.sum().item())
        kv_c = torch.empty(
            (total_tokens, 1, test_params.kv_lora_rank),
            dtype=torch.float16,
            device=test_params.device,
        )
        k_pe = torch.empty(
            (total_tokens, 1, test_params.qk_rope_head_dim),
            dtype=torch.float16,
            device=test_params.device,
        )
        torch_npu.atb.npu_paged_cache_load(
            prepared_data.cached_kv_c,
            prepared_data.cached_k_pe,
            prepared_data.block_table,
            prepared_data.seq_len_tensor.int(),
            seq_starts=prepared_data.seq_starts,
            key=kv_c,
            value=k_pe,
        )
        self._assert_matches_reference(test_params, prepared_data, kv_c, k_pe)

    @SupportedDevices(["Ascend910B"])
    def test_atb_paged_cache_load(self):
        """The op allocates and returns the ``key``/``value`` tensors itself."""
        test_params = self._default_params()
        prepared_data = self._prepare_data(test_params)
        kv_c, k_pe = torch_npu.atb.npu_paged_cache_load(
            prepared_data.cached_kv_c,
            prepared_data.cached_k_pe,
            prepared_data.block_table,
            prepared_data.seq_len_tensor.int(),
            seq_starts=prepared_data.seq_starts,
        )
        self._assert_matches_reference(test_params, prepared_data, kv_c, k_pe)
if __name__ == "__main__":
    # Running this module directly hands control to the torch_npu test runner.
    run_tests()