"""
"""
import logging
from dataclasses import dataclass, field
from typing import List, Optional

import pypto
import pytest

from conftest import duration_estimate
# Named indices 0 and 1. Presumably used to index the first / second axis of a
# tensor shape; they are not referenced in this part of the file — TODO confirm
# against the callers further down.
SHAPE_DIM_0 = 0
SHAPE_DIM_1 = 1
@dataclass
class SelectedAttentionTileConfig:
    """Tile-size settings bundled into the selected-attention parameters.

    The element types mirror the defaults that ``SABuildConfig`` provides in
    this file: ``c1_tile`` and ``c2_tile`` are lists of ``[int, int]`` pairs,
    while ``v1_tile`` and ``v2_tile`` are flat lists of ints. The annotations
    are not enforced at runtime.
    """

    # Scalar tile sizes.
    # NOTE(review): presumably tiling along the query-head group and the
    # selected key/value sequence axes — confirm with the kernel code.
    g_tile: int
    s2_tile: int
    # Nested tile description: a list of [int, int] pairs.
    c1_tile: List[List[int]]
    # Flat tile description: a list of ints.
    v1_tile: List[int]
    # Nested tile description: a list of [int, int] pairs.
    c2_tile: List[List[int]]
    # Flat tile description: a list of ints.
    v2_tile: List[int]
@dataclass
class SASimpleParams:
    """Scalar attention parameters plus the tile configuration.

    ``build_selected_args`` in this file fills every field straight from the
    matching ``SABuildConfig`` attribute.
    """

    # Head counts. NOTE(review): presumably query heads and key/value heads.
    n_q: int
    n_kv: int
    # Scale factor; presumably applied to attention scores before softmax.
    softmax_scale: float
    # Top-k count; also used as a row-count factor for the kSlc / vSlc shapes.
    topk: int
    # Tile sizes, grouped in their own dataclass.
    tile: SelectedAttentionTileConfig
@dataclass
class SAInputs:
    """Tensor handles and parameters passed together as one argument bundle.

    The shapes noted below are those that ``build_selected_args`` in this file
    uses when it constructs the tensors; other producers may differ.
    """

    # Query tensors, built as [b * s1 * n_q, qk_nope_head_dim] and
    # [b * s1 * n_q, qk_rope_head_dim] respectively.
    q_nope: pypto.tensor
    q_rope: pypto.tensor
    # Selected key / value tensors, built with b * s1 * topk rows; the key
    # width is qk_nope_head_dim + qk_rope_head_dim, the value width is
    # kv_head_dim.
    k_slc: pypto.tensor
    v_slc: pypto.tensor
    # INT32 tensor of shape [b].
    # NOTE(review): presumably per-batch actual sequence lengths — confirm.
    kv_slc_act_seqs: pypto.tensor
    # Built as [b, s1, n_q, qk_nope_head_dim]; presumably the output buffer.
    attention_out: pypto.tensor
    # Scalar parameters and tile configuration.
    params: SASimpleParams
@dataclass
class SABuildConfig:
    """Sizes, scale and tile settings consumed by ``build_selected_args``.

    Every field has a default, so ``SABuildConfig()`` yields a complete
    configuration. List-valued fields use ``default_factory`` so each
    instance gets its own list objects.
    """

    # Shape factors: b and s1 multiply into the row counts of the q/k/v
    # tensors and are the leading dims of the output tensor.
    # NOTE(review): presumably batch size and query sequence length.
    b: int = 32
    s1: int = 4
    # Head counts (presumably query heads / key-value heads).
    n_q: int = 128
    n_kv: int = 1
    # Per-head widths: the key width is nope + rope, the value width is
    # kv_head_dim.
    qk_nope_head_dim: int = 512
    qk_rope_head_dim: int = 64
    kv_head_dim: int = 512
    # Top-k count; multiplies into the kSlc / vSlc row counts.
    topk: int = 2048
    # NOTE: 24 == sqrt(512 + 64), i.e. the square root of the default
    # nope + rope width. This is a fixed literal and is NOT recomputed when
    # the head dims are overridden.
    softmax_scale: float = 1.0 / 24.0
    # Scalar tile sizes.
    g_tile: int = 128
    s2_tile: int = 2048
    # Nested tile descriptions: three [int, int] pairs each.
    c1_tile: List[List[int]] = field(
        default_factory=lambda: [[128, 128], [64, 64], [256, 256]]
    )
    # Flat tile description.
    v1_tile: List[int] = field(default_factory=lambda: [16, 256])
    c2_tile: List[List[int]] = field(
        default_factory=lambda: [[128, 128], [128, 128], [128, 128]]
    )
    # Flat tile description.
    v2_tile: List[int] = field(default_factory=lambda: [64, 128])
def build_selected_args(cfg: Optional[SABuildConfig] = None):
    """Create the selected-attention tensors and bundle them with their params.

    Args:
        cfg: Build configuration. When omitted (or ``None``) a fresh
            ``SABuildConfig()`` is created for this call, so results of one
            call never share list objects with a later default call.

    Returns:
        A ``(args, meta)`` tuple:

        * ``args`` -- an ``SAInputs`` holding the six ``pypto.tensor``
          objects and an ``SASimpleParams`` (with its tile config).
        * ``meta`` -- a plain dict with the keys ``b``, ``s1``, ``nQ``,
          ``nKv``, ``dims`` (shape list per tensor name), ``topk``,
          ``softmaxScale`` and ``tiles``.

    Note:
        The tile lists in ``args.params.tile`` and ``meta["tiles"]`` are the
        same list objects that live on ``cfg``; they are not copied.
    """
    # A default of `SABuildConfig()` in the signature would be evaluated only
    # once, and its lists would then be shared by every default call.
    if cfg is None:
        cfg = SABuildConfig()

    d_type = pypto.DT_FP16
    i32 = pypto.DT_INT32

    # Row-count factor shared by the q/k/v shapes.
    bs = cfg.b * cfg.s1
    k_concat_dim = cfg.qk_nope_head_dim + cfg.qk_rope_head_dim

    q_nope_shape = [bs * cfg.n_q, cfg.qk_nope_head_dim]
    q_rope_shape = [bs * cfg.n_q, cfg.qk_rope_head_dim]
    k_slc_shape = [bs * cfg.topk, k_concat_dim]
    v_slc_shape = [bs * cfg.topk, cfg.kv_head_dim]
    kv_slc_act_seqs_shape = [cfg.b]
    attention_out_shape = [cfg.b, cfg.s1, cfg.n_q, cfg.qk_nope_head_dim]

    # Tensor construction order is kept as in the original in case
    # pypto.tensor registers tensors in creation order.
    q_nope = pypto.tensor(q_nope_shape, d_type, "qNope")
    q_rope = pypto.tensor(q_rope_shape, d_type, "qRope")
    k_slc = pypto.tensor(k_slc_shape, d_type, "kSlc")
    v_slc = pypto.tensor(v_slc_shape, d_type, "vSlc")
    kv_slc_act_seqs = pypto.tensor(kv_slc_act_seqs_shape, i32, "kvSlcActSeqs")
    attention_out = pypto.tensor(attention_out_shape, d_type, "attentionOut")

    tile = SelectedAttentionTileConfig(
        g_tile=cfg.g_tile,
        s2_tile=cfg.s2_tile,
        c1_tile=cfg.c1_tile,
        v1_tile=cfg.v1_tile,
        c2_tile=cfg.c2_tile,
        v2_tile=cfg.v2_tile,
    )
    params = SASimpleParams(
        n_q=cfg.n_q,
        n_kv=cfg.n_kv,
        softmax_scale=cfg.softmax_scale,
        topk=cfg.topk,
        tile=tile,
    )
    args = SAInputs(
        q_nope=q_nope,
        q_rope=q_rope,
        k_slc=k_slc,
        v_slc=v_slc,
        kv_slc_act_seqs=kv_slc_act_seqs,
        attention_out=attention_out,
        params=params,
    )

    # Plain-dict mirror of the configuration and the derived shapes.
    meta = {
        "b": cfg.b,
        "s1": cfg.s1,
        "nQ": cfg.n_q,
        "nKv": cfg.n_kv,
        "dims": {
            "qNope": q_nope_shape,
            "qRope": q_rope_shape,
            "kSlc": k_slc_shape,
            "vSlc": v_slc_shape,
            "kvSlcActSeqs": kv_slc_act_seqs_shape,
            "attentionOut": attention_out_shape,
        },
        "topk": cfg.topk,
        "softmaxScale": cfg.softmax_scale,
        "tiles": {
            "gTile": cfg.g_tile,
            "s2Tile": cfg.s2_tile,
            "c1Tile": cfg.c1_tile,
            "v1Tile": cfg.v1_tile,
            "c2Tile": cfg.c2_tile,
            "v2Tile": cfg.v2_tile,
        },
    }
    return args, meta