"""
"""
from dataclasses import dataclass
import logging
import pytest
import pypto
from conftest import duration_estimate
@dataclass
class NSASimpleParamsObj:
def __init__(self, block_size=128, topk=2048, **kwargs):
self.block_size = int(block_size)
self.topk = int(topk)
defaults = {
"n1": 128,
"n2": 1,
"idx_n_heads": 64,
"idx_head_dim": 128,
"q_lora_rank": 1536,
"qk_nope_head_dim": 128,
"qk_rope_head_dim": 64,
"kv_lora_rank": 512,
}
defaults.update(kwargs)
for k, v in defaults.items():
setattr(self, k, int(v))
@dataclass
class GatherInputs:
top_k_indices: pypto.tensor
k_nope_cache: pypto.tensor
k_rope_cache: pypto.tensor
block_table: pypto.tensor
act_seqs: pypto.tensor
gather_res: pypto.tensor
nsa_params: NSASimpleParamsObj
b: pypto.symbolic_scalar
s1: pypto.symbolic_scalar
def gather_after_prolog_compute(args: GatherInputs):
top_k_indices = args.top_k_indices
k_nope_cache = args.k_nope_cache
k_rope_cache = args.k_rope_cache
block_table = args.block_table
act_seqs = args.act_seqs
gather_res = args.gather_res
nsa_params = args.nsa_params
b = args.b
s1 = args.s1
d_n = k_nope_cache.shape[-1]
d_r = k_rope_cache.shape[-1]
n2 = top_k_indices.shape[2]
block_size = nsa_params.block_size
topk = nsa_params.topk
unroll_list = {64, 32, 16, 8, 4, 2, 1}
input_tensors = [top_k_indices, k_nope_cache,
k_rope_cache, block_table, act_seqs]
output_tensors = [gather_res]
with pypto.function("main", *input_tensors, *output_tensors):
for b_idx in pypto.loop(b, submit_before_loop=True):
for s1_idx in pypto.loop(s1):
for n2_idx in pypto.loop(n2):
pypto.set_semantic_label("gather0")
cur_kv_seq = act_seqs[b_idx]
top_k_loop = ((cur_kv_seq - s1 + 1 + s1_idx).max(0).min(topk))
for topk_idx in pypto.loop(top_k_loop, unroll_list=unroll_list):
pypto.set_vec_tile_shapes(1, 1, 1, 16)
topk_index = top_k_indices[
b_idx, s1_idx, n2_idx, topk_idx
]
block_idx_in_batch = (
topk_index // block_size
)
tail = topk_index % block_size
slc_block_idx = block_table[
b_idx, block_idx_in_batch
]
pypto.set_vec_tile_shapes(1, d_n)
kv_slc_block = pypto.view(
k_nope_cache,
[1, d_n],
[slc_block_idx * block_size + tail, 0],
)
kr_slc_block = pypto.view(
k_rope_cache,
[1, d_r],
[slc_block_idx * block_size + tail, 0],
)
pypto.set_semantic_label("gather1")
kv_slc_block_fp32 = pypto.cast(
kv_slc_block, pypto.DT_FP32
)
kr_slc_block_fp32 = pypto.cast(
kr_slc_block, pypto.DT_FP32
)
pypto.set_semantic_label("gather2")
kv_slc_block_fp16 = pypto.cast(
kv_slc_block_fp32, gather_res.dtype
)
kr_slc_block_fp16 = pypto.cast(
kr_slc_block_fp32, gather_res.dtype
)
ofs = (
b_idx * s1 * n2 * topk
+ s1_idx * n2 * topk
+ n2_idx * topk
+ topk_idx
)
pypto.assemble(
kv_slc_block_fp16, [ofs, 0], gather_res
)
pypto.assemble(
kr_slc_block_fp16,
[ofs, d_n],
gather_res,
)
@dataclass
class BuildConfig:
b: int = 32
s1: int = 4
n2: int = 128
d_n: int = 512
d_r: int = 64
block_size: int = 128
topk: int = 2048
num_blocks: int = 32
s2: int = 4096
def build_gather_args(cfg: BuildConfig = BuildConfig()):
cache_dtype = pypto.DT_FP16
index_dtype = pypto.DT_INT32
cache_rows = cfg.num_blocks * cfg.block_size
max_block_per_batch = cfg.s2 // cfg.block_size
top_k_indices = pypto.tensor(
[cfg.b, cfg.s1, cfg.n2, cfg.topk], index_dtype, "topKIndices"
)
k_nope_cache = pypto.tensor(
[cache_rows, cfg.d_n], cache_dtype, "kNopeCache")
k_rope_cache = pypto.tensor(
[cache_rows, cfg.d_r], cache_dtype, "kRopeCache")
block_table = pypto.tensor(
[cfg.b, max_block_per_batch], index_dtype, "blockTable")
act_seqs = pypto.tensor([cfg.b], index_dtype, "actSeqs")
gather_res = pypto.tensor(
[cfg.b * cfg.s1 * cfg.topk, cfg.d_n + cfg.d_r], cache_dtype, "gatherRes"
)
nsa_params = NSASimpleParamsObj(block_size=cfg.block_size, topk=cfg.topk)
args = GatherInputs(
top_k_indices=top_k_indices,
k_nope_cache=k_nope_cache,
k_rope_cache=k_rope_cache,
block_table=block_table,
act_seqs=act_seqs,
gather_res=gather_res,
nsa_params=nsa_params,
b=pypto.symbolic_scalar(cfg.b),
s1=pypto.symbolic_scalar(cfg.s1),
)
meta = {
"b": cfg.b,
"s1": cfg.s1,
"n2": cfg.n2,
"dims": {
"topKIndices": [cfg.b, cfg.s1, cfg.n2, cfg.topk],
"kNopeCache": [cache_rows, cfg.d_n],
"kRopeCache": [cache_rows, cfg.d_r],
"blockTable": [cfg.b, max_block_per_batch],
"actSeqs": [cfg.b],
"gatherRes": [cfg.b * cfg.s1 * cfg.topk, cfg.d_n + cfg.d_r],
},
"params": {
"blockSize": cfg.block_size,
"topk": cfg.topk,
"numBlocks": cfg.num_blocks,
"s2": cfg.s2,
"dN": cfg.d_n,
"dR": cfg.d_r,
},
"dtypes": {
"cache": str(cache_dtype),
"index": str(index_dtype),
},
}
return args, meta
@duration_estimate(32)
def test_gather_with_builder():
logging.basicConfig(level=logging.INFO)
args, meta = build_gather_args()
logging.info({"Sanity": meta})
gather_after_prolog_compute(args)
assert True