""" """
import os
import math
import random
from dataclasses import dataclass
from typing import Type, Any, List, Tuple, Union
import torch
import pypto
import pytest
from numpy.testing import assert_allclose
import torch_npu
# Type aliases for arguments that may be supplied either as a torch.Tensor or
# as a plain Python list. Both aliases expand to the same Union; the 1D/2D
# suffix only documents the intended rank and is not enforced by the type.
TensorLike1D = Union[torch.Tensor, List[Any]]
TensorLike2D = Union[torch.Tensor, List[Any]]
@dataclass
class PageAttentionTestConfig:
    """
    Python counterpart (torch-flavoured) of the C++ template:
        template <typename IndexT, typename DataT>
        struct PageAttentionTestConfig { ... };
    The torch dtypes are stored directly so that tensors can be created
    from them later.
    """
    # Number of logical token indices to gather (length of the top-k list).
    topk_count: int
    # Number of entries in the page table (logical blocks).
    num_logical_blocks: int
    # Number of token rows in the source buffer.
    num_buffer_tokens: int
    # Number of columns per token row in the source buffer.
    hidden_dim: int
    # Number of tokens per block; used to split an index into block id/offset.
    block_size: int
    # dtype used for the page table and the top-k index tensor.
    index_dtype: torch.dtype = torch.int32
    # dtype used for the source buffer and the golden result.
    data_dtype: torch.dtype = torch.float16
def _to_list_1d(v: TensorLike1D) -> List[Any]:
    """Return the elements of ``v`` as a flat Python list.

    A ``torch.Tensor`` is flattened first, so a tensor of any rank yields a
    one-dimensional list. Any other iterable is copied into a new list
    without flattening.
    """
    is_tensor = isinstance(v, torch.Tensor)
    return v.flatten().tolist() if is_tensor else list(v)
def validate_config(cfg: PageAttentionTestConfig) -> Tuple[bool, str]:
    """Check that the sizes stored in ``cfg`` are mutually consistent.

    The checks run in this order and the first failure is reported:

    1. ``topk_count`` must not exceed ``num_logical_blocks * block_size``.
    2. ``num_buffer_tokens`` must be at least ``block_size``.
    3. ``block_size`` must be strictly positive.

    Checks 2 and 3 together guarantee
    ``num_buffer_tokens // block_size >= 1``, i.e. at least one physical
    block, without ever dividing by a zero ``block_size``.

    Args:
        cfg: Configuration to validate. Only ``topk_count``,
            ``num_logical_blocks``, ``num_buffer_tokens`` and ``block_size``
            are read.

    Returns:
        ``(True, "")`` when every check passes, otherwise
        ``(False, reason)`` where ``reason`` describes the failed check.
    """
    total_logical_tokens = cfg.num_logical_blocks * cfg.block_size
    if cfg.topk_count > total_logical_tokens:
        return False, "topk_count 必须 <= num_logical_blocks * block_size(topk 的 k 不能超过逻辑 token 总数)"

    if cfg.num_buffer_tokens < cfg.block_size:
        return False, "num_buffer_tokens 必须至少 >= block_size,才能容纳一个物理块"

    # Reject non-positive block sizes explicitly: a zero block size would
    # otherwise surface as a ZeroDivisionError in any later
    # ``num_buffer_tokens // block_size`` computation instead of a result.
    if cfg.block_size <= 0:
        return False, "block_size 必须 > 0"

    return True, ""
def make_buffer(cfg: PageAttentionTestConfig) -> torch.Tensor:
    """Build the deterministic source buffer used by the test.

    Element ``[t, h]`` holds ``10.0 * t + h`` converted to
    ``cfg.data_dtype``, so every row can be recognised from its values.

    Args:
        cfg: Supplies ``num_buffer_tokens`` (rows), ``hidden_dim`` (columns)
            and ``data_dtype``.

    Returns:
        A tensor of shape ``[num_buffer_tokens, hidden_dim]`` and dtype
        ``cfg.data_dtype``.
    """
    # Compute in float64 (the precision of a Python float) and cast once at
    # the end; values that do not fit ``data_dtype`` follow torch's tensor
    # cast semantics.
    token_ids = torch.arange(cfg.num_buffer_tokens, dtype=torch.float64).unsqueeze(1)
    channel_ids = torch.arange(cfg.hidden_dim, dtype=torch.float64).unsqueeze(0)
    # Broadcasting [N, 1] + [1, H] yields the full [N, H] grid in one step.
    return (10.0 * token_ids + channel_ids).to(cfg.data_dtype)
def make_page_table(cfg: PageAttentionTestConfig,
                    seed: int = 42) -> torch.Tensor:
    """Draw a random page table mapping logical blocks to physical blocks.

    Every entry is sampled with ``torch.randint`` from
    ``[0, num_buffer_tokens // block_size)`` using a private generator
    seeded with ``seed``, so the same ``cfg`` and ``seed`` always give the
    same table.

    Returns:
        A tensor of shape ``[1, num_logical_blocks]`` with dtype
        ``cfg.index_dtype``.
    """
    rng = torch.Generator()
    rng.manual_seed(seed)
    physical_block_count = cfg.num_buffer_tokens // cfg.block_size
    table_shape = (1, cfg.num_logical_blocks)
    return torch.randint(
        0,
        physical_block_count,
        table_shape,
        generator=rng,
        dtype=cfg.index_dtype,
    )
def make_topk_indices(cfg: PageAttentionTestConfig,
                      seed: int = 123) -> torch.Tensor:
    """Draw random logical token indices to be gathered.

    Every entry is sampled with ``torch.randint`` from
    ``[0, num_logical_blocks * block_size)`` using a private generator
    seeded with ``seed``, so the same ``cfg`` and ``seed`` always give the
    same indices.

    Returns:
        A tensor of shape ``[1, topk_count]`` with dtype
        ``cfg.index_dtype``.
    """
    rng = torch.Generator()
    rng.manual_seed(seed)
    logical_token_count = cfg.num_logical_blocks * cfg.block_size
    index_shape = (1, cfg.topk_count)
    return torch.randint(
        0,
        logical_token_count,
        index_shape,
        generator=rng,
        dtype=cfg.index_dtype,
    )
def compute_physical_index(
    logical_index: Union[int, torch.Tensor],
    page_table: torch.Tensor,
    cfg: PageAttentionTestConfig,
) -> int:
    """Translate a logical token index into a physical token index.

    The logical index is split into ``(block id, offset)`` by
    ``cfg.block_size``; the block id is looked up in ``page_table`` and the
    offset is re-applied inside the physical block.

    Args:
        logical_index: Logical token index, given as a Python int or a
            single-element tensor.
        page_table: Logical-to-physical block mapping. Both a 1-D tensor and
            the ``[1, num_logical_blocks]`` layout are accepted; the tensor
            is read in flattened order.
        cfg: Supplies ``block_size``.

    Returns:
        The physical token index as a Python int, ready for indexing.

    Raises:
        IndexError: If the logical block id falls outside the page table.
            Negative indices are rejected rather than wrapping around.
    """
    # int() covers both plain ints and single-element tensors.
    logical_index = int(logical_index)
    flat_table = page_table.reshape(-1)

    logical_block_id, block_offset = divmod(logical_index, cfg.block_size)
    if not 0 <= logical_block_id < flat_table.numel():
        raise IndexError(f"logical_index 越界: {logical_index}")

    physical_block_id = int(flat_table[logical_block_id].item())
    return physical_block_id * cfg.block_size + block_offset
def gather_golden(
    topk_indices: torch.Tensor,
    page_table: torch.Tensor,
    buffer: torch.Tensor,
    cfg: PageAttentionTestConfig,
) -> torch.Tensor:
    """Compute the reference (golden) result of the paged gather on the host.

    For every logical index in ``topk_indices`` the matching physical row is
    resolved through ``page_table`` and copied from ``buffer``.

    Args:
        topk_indices: Logical token indices, 1-D or shaped ``[1, topk_count]``.
        page_table: Logical-to-physical block mapping, 1-D or shaped
            ``[1, num_logical_blocks]``.
        buffer: Source rows, shaped ``[num_buffer_tokens, hidden_dim]``.
        cfg: Supplies the expected sizes, ``block_size`` and ``data_dtype``.

    Returns:
        A ``[topk_count, hidden_dim]`` tensor of dtype ``cfg.data_dtype``
        whose row ``j`` is the buffer row selected by ``topk_indices[j]``.

    Raises:
        RuntimeError: If an input has an unexpected shape or size, or if a
            logical or physical index is out of range.
    """
    # Shape problems are raised explicitly rather than asserted, so they are
    # still detected when Python runs with assertions disabled (-O).
    if topk_indices.dim() == 2:
        if topk_indices.shape[0] != 1:
            raise RuntimeError("topk_indices.shape[0] != 1")
        topk_indices = topk_indices.flatten()
    if page_table.dim() == 2:
        if page_table.shape[0] != 1:
            raise RuntimeError("page_table.shape[0] != 1")
        page_table = page_table.flatten()

    if topk_indices.numel() != cfg.topk_count:
        raise RuntimeError("topk_indices.size() != topk_count")
    if page_table.numel() != cfg.num_logical_blocks:
        raise RuntimeError("page_table.size() != num_logical_blocks")
    if buffer.numel() != cfg.num_buffer_tokens * cfg.hidden_dim:
        raise RuntimeError("buffer.size() != num_buffer_tokens * hidden_dim")
    if buffer.dim() != 2 or buffer.shape != (cfg.num_buffer_tokens, cfg.hidden_dim):
        raise RuntimeError("buffer shape mismatch")

    result = torch.empty((cfg.topk_count, cfg.hidden_dim),
                         dtype=cfg.data_dtype)
    total_logical_tokens = cfg.num_logical_blocks * cfg.block_size

    # Convert the indices to Python numbers once instead of per element.
    for row, raw_index in enumerate(topk_indices.tolist()):
        logical_index = int(raw_index)
        if logical_index < 0 or logical_index >= total_logical_tokens:
            raise RuntimeError(f"logical_index 越界: {logical_index}")

        physical_index = compute_physical_index(logical_index, page_table, cfg)
        if physical_index < 0 or physical_index >= cfg.num_buffer_tokens:
            raise RuntimeError(f"physical_index 越界: {physical_index}")

        result[row, :] = buffer[physical_index, :]
    return result
def test_vector_operator_gatherinub():
    """Run ``pypto.experimental.gather_in_ub`` on the device and compare it
    with the host-side golden result from ``gather_golden``.

    The device is selected through the ``TILE_FWK_DEVICE_ID`` environment
    variable (default 0). The pypto runtime is finalised in a ``finally``
    block so that a failing comparison or any other exception after
    initialisation does not leave the device initialised.
    """
    device_id = int(os.environ.get('TILE_FWK_DEVICE_ID', 0))
    torch.npu.set_device(device_id)
    pypto.runtime._device_init()
    try:
        cfg = PageAttentionTestConfig(
            topk_count=8,
            num_logical_blocks=3,
            num_buffer_tokens=32,
            hidden_dim=4,
            block_size=4,
            index_dtype=torch.int32,
            data_dtype=torch.float16
        )
        ok, err = validate_config(cfg)
        if not ok:
            raise RuntimeError(f"Config invalid: {err}")

        # Host-side inputs and the expected output.
        buffer = make_buffer(cfg)
        page_table = make_page_table(cfg, seed=42)
        topk_indices = make_topk_indices(cfg, seed=123)
        golden = gather_golden(topk_indices, page_table, buffer, cfg)

        # Symbolic tensors describing the operator's inputs and output.
        src_shapes = [cfg.num_buffer_tokens, cfg.hidden_dim]
        offsets_shapes = [1, cfg.topk_count]
        page_table_shapes = [1, cfg.num_logical_blocks]
        dst_shapes = [cfg.topk_count, cfg.hidden_dim]
        src = pypto.tensor(src_shapes, pypto.DT_FP16, "src")
        offsets = pypto.tensor(offsets_shapes, pypto.DT_INT32, "offsets")
        pto_page_table = pypto.tensor(
            page_table_shapes, pypto.DT_INT32, "page_table")
        dst = pypto.tensor(dst_shapes, pypto.DT_FP16, "dst")

        with pypto.function("MAIN", src, offsets, pto_page_table, dst):
            for _ in pypto.loop(1, name="b0", idx_name="bidx"):
                pypto.set_vec_tile_shapes(32, 64)
                dyn_src = pypto.view(
                    src, src_shapes, [0, 0], valid_shape=src_shapes)
                dyn_offsets = pypto.view(
                    offsets, offsets_shapes, [0, 0],
                    valid_shape=offsets_shapes)
                # NOTE(review): the trailing -2 is passed through unchanged;
                # presumably it selects the gather axis - confirm against the
                # gather_in_ub API.
                tmp = pypto.experimental.gather_in_ub(
                    dyn_src, dyn_offsets, pto_page_table, cfg.block_size, -2)
                pypto.assemble(tmp, [0, 0], dst)
                # NOTE(review): the nesting of this `del` was reconstructed
                # from source whose indentation was lost; confirm it belongs
                # at the end of the loop body.
                del dyn_src, dyn_offsets

        # Bind host tensors to the operator and run it once; `result`
        # is the tensor compared against the golden data afterwards.
        result = torch.zeros(dst_shapes, dtype=torch.float16)
        pto_a_tensor = pypto.from_torch(buffer, "buffer")
        pto_b_tensor = pypto.from_torch(topk_indices, "topk_indices")
        pto_c_tensor = pypto.from_torch(page_table, "page_table")
        pto_d_tensor = pypto.from_torch(result, "result")
        pypto.runtime._device_run_once_data_from_host(
            pto_a_tensor, pto_b_tensor, pto_c_tensor, pto_d_tensor)

        assert_allclose(result.flatten(), golden.flatten(), rtol=3e-3, atol=3e-3)
    finally:
        # Always release the runtime, even when the comparison fails.
        pypto.runtime._device_fini()