"""
"""
import sys
from dataclasses import dataclass, field
from typing import List, Set, Optional
import logging
import pytest
import pypto
# Symbolic names for the integers 0..3 (by their names, shape-axis indices).
SHAPE_DIM0, SHAPE_DIM1, SHAPE_DIM2, SHAPE_DIM3 = range(4)

# Named integer literals: every NUM_<k> constant simply holds the value <k>.
NUM_NEG1 = -1
NUM_0, NUM_1, NUM_2, NUM_3 = 0, 1, 2, 3
NUM_4, NUM_8, NUM_16, NUM_32, NUM_64, NUM_128 = 4, 8, 16, 32, 64, 128
NUM_100 = 100
NUM_1127 = 1127
NUM_1024, NUM_2048, NUM_4096, NUM_8192 = 1024, 2048, 4096, 8192

# Reciprocal of 2048 as a float (2 ** -11). The name indicates it is a
# scale factor meant to keep FP32 values inside FP16 range when converted.
AVOID_FP32_TO_FP16_OVERFLOW_SCALE = 1.0 / NUM_2048
@dataclass
class LightningIndexerTileConfig:
    """Group of tile lists handed to the Lightning Indexer as one object.

    All five fields are required and are stored exactly as given; the class
    performs no validation or copying.
    """

    # Flat list of ints.
    weight_tile: List[int]
    # List of int lists (one inner list per entry).
    c1_tile: List[List[int]]
    # Flat list of ints.
    v1_tile: List[int]
    # Flat list of ints.
    topk_tile: List[int]
    # Flat list of ints.
    adds_tile: List[int]
@dataclass
class LightningIndexerParams:
    """Scalar problem parameters for one Lightning Indexer invocation.

    Nine integer fields are required; ``is_quant`` is optional and defaults
    to ``False``. Values are stored as given without validation.
    """

    # Integer size parameters (all required, positional order matters).
    b: int
    s1: int
    index_n1: int
    qk_nope: int
    qk_rope: int
    n2: int
    block_size: int
    block_num: int
    selected_count: int

    # Quantisation flag; off unless the caller turns it on.
    is_quant: bool = False
@dataclass
class LightningIndexerInputs:
    """Everything the Lightning Indexer needs, gathered into one record.

    Holds the pypto tensors, the optional tensors (which may be ``None``),
    the tile configuration, the unroll list and the scalar parameters. All
    fields are required at construction time and stored as given.
    """

    # Tensors that are always present.
    query: pypto.Tensor
    key: pypto.Tensor
    weights: pypto.Tensor
    act_seq_key: pypto.Tensor
    block_table: pypto.Tensor
    topk_res: pypto.Tensor

    # Tensors that may be absent (None).
    q_scale: Optional[pypto.Tensor]
    k_scale: Optional[pypto.Tensor]
    tmp_out: Optional[pypto.Tensor]
    topk_value: Optional[pypto.Tensor]

    # Non-tensor configuration.
    tile_config: LightningIndexerTileConfig
    unroll_list: Set[int]
    params: LightningIndexerParams
@dataclass
class LightningIndexerBuildConfig:
    """Defaulted problem sizes and tile lists for building a test case.

    Every field has a default, so ``LightningIndexerBuildConfig()`` yields a
    complete configuration. List-valued fields use ``default_factory`` so
    each instance receives its own fresh lists.
    """

    # Scalar sizes.
    b: int = NUM_4
    s1: int = NUM_2
    index_n1: int = NUM_64
    qk_nope: int = NUM_128
    qk_rope: int = NUM_0
    n2: int = NUM_1
    block_size: int = NUM_128
    block_num: int = NUM_1127
    selected_count: int = NUM_2048

    # Quantised path is enabled by default here.
    is_quant: bool = True

    # Tile lists: three [n, n] pairs for c1, then flat lists for the rest.
    c1_tile: List[List[int]] = field(
        default_factory=lambda: [[n, n] for n in (NUM_64, NUM_128, NUM_128)]
    )
    v1_tile: List[int] = field(default_factory=lambda: [NUM_64, NUM_128])
    topk_tile: List[int] = field(default_factory=lambda: [NUM_1, NUM_4096])
    adds_tile: List[int] = field(
        default_factory=lambda: [NUM_1] * 3 + [NUM_4096]
    )
def setup_lightning_indexer_topk_config():
    """Set the pypto pass options used for the Lightning Indexer top-k case.

    Calls ``pypto.set_pass_options`` once with ``cube_l1_reuse_setting``
    set to ``{-1: 32}`` and ``vec_nbuffer_setting`` set to ``{-1: 16}``.
    Returns nothing.
    """
    # NOTE(review): the -1 key is presumably a "default/all" selector in
    # pypto's pass options - confirm against the pypto documentation.
    l1_reuse = {NUM_NEG1: NUM_32}
    nbuffer = {NUM_NEG1: NUM_16}
    pypto.set_pass_options(
        cube_l1_reuse_setting=l1_reuse,
        vec_nbuffer_setting=nbuffer,
    )
def build_lightning_indexer_topk_args(
    cfg: Optional[LightningIndexerBuildConfig] = None,
):
    """Build the tensors, configuration and metadata for a top-k case.

    Args:
        cfg: Problem sizes, quantisation flag and tile lists. When omitted
            (or ``None``), a fresh ``LightningIndexerBuildConfig()`` with
            its defaults is created for this call.

    Returns:
        A ``(args, meta)`` tuple. ``args`` is a ``LightningIndexerInputs``
        holding the created ``pypto`` tensors, the tile configuration, the
        unroll list and the scalar parameters. ``meta`` is a plain dict
        describing the same case (sizes, per-tensor shapes under ``"dims"``,
        tile lists under ``"tiles"`` and the sorted unroll list).
    """
    # A config instance placed in the signature would be created once, at
    # function-definition time, and shared by every defaulted call. Its tile
    # lists are handed out by reference below (tile_cfg / meta["tiles"]), so
    # a caller mutating one result would silently change later calls.
    if cfg is None:
        cfg = LightningIndexerBuildConfig()

    # NOTE(review): d_bf16 is bound to DT_FP16, not a bfloat16 dtype, exactly
    # as in the original code - confirm this is intentional.
    d_bf16 = pypto.DT_FP16
    d_i32 = pypto.DT_INT32
    d_int8 = pypto.DT_INT8
    d_f16 = pypto.DT_FP16

    index_d = cfg.qk_nope + cfg.qk_rope
    max_block_num = NUM_1024

    # Query/key are int8 on the quantised path; the scale dtype is FP16
    # regardless of the path.
    qk_dtype = d_int8 if cfg.is_quant else d_bf16
    scale_dtype = d_f16

    # Single source of truth for every tensor shape. The keys double as the
    # tensor names passed to pypto and as the keys of meta["dims"].
    dims = {
        "query": [cfg.b, cfg.s1, cfg.index_n1, index_d],
        "key": [cfg.block_num, cfg.block_size, cfg.n2, index_d],
        "weights": [cfg.b, cfg.s1, cfg.index_n1],
        "actSeqKey": [cfg.b],
        "blockTable": [cfg.b, max_block_num],
        "topkRes": [cfg.b, cfg.s1, cfg.n2, cfg.selected_count],
        "qScale": ([cfg.b, cfg.s1, cfg.index_n1, 1] if cfg.is_quant else None),
        "kScale": (
            [cfg.block_num, cfg.block_size, cfg.n2, 1] if cfg.is_quant else None
        ),
    }

    def _make_tensor(name, dtype):
        # Pass a copy so the list stored in meta["dims"] is never the same
        # object as the one given to pypto (the original used separate lists).
        return pypto.tensor(list(dims[name]), dtype, name)

    # Tensors are created in the same order as before.
    query = _make_tensor("query", qk_dtype)
    key = _make_tensor("key", qk_dtype)
    weights = _make_tensor("weights", d_bf16)
    act_seq_key = _make_tensor("actSeqKey", d_i32)
    block_table = _make_tensor("blockTable", d_i32)
    topk_res = _make_tensor("topkRes", d_i32)
    # Scale tensors exist only on the quantised path.
    q_scale = _make_tensor("qScale", scale_dtype) if cfg.is_quant else None
    k_scale = _make_tensor("kScale", scale_dtype) if cfg.is_quant else None

    tile_cfg = LightningIndexerTileConfig(
        weight_tile=[NUM_64, NUM_128],
        c1_tile=cfg.c1_tile,
        v1_tile=cfg.v1_tile,
        topk_tile=cfg.topk_tile,
        adds_tile=cfg.adds_tile,
    )

    # NOTE(review): LightningIndexerInputs annotates unroll_list as Set[int]
    # but a list is passed here, as in the original - confirm which
    # container the consumer expects before changing either side.
    unroll_list: List[int] = [1, 2, 4, 8, 16, 32, 64]

    params = LightningIndexerParams(
        b=cfg.b,
        s1=cfg.s1,
        index_n1=cfg.index_n1,
        qk_nope=cfg.qk_nope,
        qk_rope=cfg.qk_rope,
        n2=cfg.n2,
        block_size=cfg.block_size,
        block_num=cfg.block_num,
        selected_count=cfg.selected_count,
        is_quant=cfg.is_quant,
    )

    args = LightningIndexerInputs(
        query=query,
        key=key,
        weights=weights,
        act_seq_key=act_seq_key,
        block_table=block_table,
        topk_res=topk_res,
        q_scale=q_scale,
        k_scale=k_scale,
        tmp_out=None,
        topk_value=None,
        tile_config=tile_cfg,
        unroll_list=unroll_list,
        params=params,
    )

    meta = {
        "B": cfg.b,
        "S1": cfg.s1,
        "indexN1": cfg.index_n1,
        "indexD": index_d,
        "N2": cfg.n2,
        "blockSize": cfg.block_size,
        "blockNum": cfg.block_num,
        "maxBlockNum": max_block_num,
        "selectedCount": cfg.selected_count,
        "isQuant": cfg.is_quant,
        "dims": dims,
        "tiles": {
            "weightTile": tile_cfg.weight_tile,
            "c1Tile": tile_cfg.c1_tile,
            "v1Tile": tile_cfg.v1_tile,
            "topkTile": tile_cfg.topk_tile,
            "addsTile": tile_cfg.adds_tile,
        },
        "unrollList": sorted(unroll_list),
    }
    return args, meta