import pypto
import torch
from torch._dynamo import allow_in_graph
from dataclasses import dataclass
from typing import List
# Library fragment that adds operators to the ``pypto`` namespace
# (reachable as ``torch.ops.pypto.*``); ``compressor`` is declared here and
# receives Meta and NPU implementations via ``torch.library.impl`` below.
# NOTE(review): npu_compressor hands kv_state / score_state to the kernels as
# their ``*_state_out`` destinations and every implementation returns those
# same tensor objects, yet the schema declares no alias/mutation annotations
# (e.g. ``Tensor(a!)``) — confirm this is intended for torch.compile /
# functionalization.
pyptolib = torch.library.Library("pypto", "FRAGMENT")
pyptolib.define(
"compressor(Tensor x, Tensor kv_state, Tensor score_state, Tensor kv_block_table, \
Tensor score_block_table, Tensor sin, Tensor cos, Tensor wkv, Tensor wgate, Tensor ape, Tensor weight, \
Tensor hadamard, Tensor start_pos, int ratio, int rope_head_dim, bool rotate) -> (Tensor, Tensor, Tensor)"
)
@torch.library.impl(pyptolib, "compressor", "Meta")
def compressor(
    x,
    kv_state,
    score_state,
    kv_block_table,
    score_block_table,
    sin,
    cos,
    wkv,
    wgate,
    ape,
    weight,
    hadamard,
    start_pos,
    ratio,
    rope_head_dim,
    rotate,
):
    """Meta (shape-only) kernel for ``pypto::compressor``.

    Allocates an uninitialised ``out`` with ``min(B*S, B*S // ratio + B)`` rows
    (``B, S = x.shape[0], x.shape[1]``) and ``weight.shape[0]`` columns, using
    ``x``'s dtype and device. The two state tensors are handed back unchanged.
    """
    num_tokens = x.shape[0] * x.shape[1]
    num_rows = min(num_tokens, num_tokens // ratio + x.shape[0])
    return (
        torch.empty((num_rows, weight.shape[0]), dtype=x.dtype, device=x.device),
        kv_state,
        score_state,
    )
try:
    @torch.library.impl(pyptolib, "compressor", "NPU")
    def compressor(
        x,
        kv_state,
        score_state,
        kv_block_table,
        score_block_table,
        sin,
        cos,
        wkv,
        wgate,
        ape,
        weight,
        hadamard,
        start_pos,
        ratio,
        rope_head_dim,
        rotate,
    ):
        """NPU kernel for ``pypto::compressor``: forwards every argument to ``npu_compressor``."""
        return npu_compressor(
            x, kv_state, score_state, kv_block_table, score_block_table,
            sin, cos, wkv, wgate, ape, weight, hadamard,
            start_pos, ratio, rope_head_dim, rotate,
        )
except Exception as exc:
    # Registration is best-effort: failures are reported and never re-raised.
    # An error mentioning "could not parse dispatch key: NPU" means this torch
    # build has no NPU dispatch key (presumably the NPU backend is not loaded).
    if "could not parse dispatch key: NPU" not in str(exc):
        print(f"Skip: Unexpected error : {exc}")
    else:
        print("Skip: torchair not installed, skip NPU registration for operator 'compressor'")
def compressor_pypto(
    x,
    kv_state,
    score_state,
    kv_block_table,
    score_block_table,
    sin,
    cos,
    wkv,
    wgate,
    ape,
    weight,
    hadamard,
    start_pos,
    ratio,
    rope_head_dim,
    rotate,
):
    """Invoke the ``pypto::compressor`` custom op through the torch dispatcher.

    The dispatcher picks the implementation registered for the inputs: the
    Meta kernel for shape inference, or ``npu_compressor`` on NPU when that
    registration succeeded.

    Returns:
        ``(out, kv_state, score_state)`` as produced by the selected kernel.
    """
    op = torch.ops.pypto.compressor
    return op(
        x, kv_state, score_state, kv_block_table, score_block_table,
        sin, cos, wkv, wgate, ape, weight, hadamard,
        start_pos, ratio, rope_head_dim, rotate,
    )
@allow_in_graph
def npu_compressor(
    x,
    kv_state,
    score_state,
    kv_block_table,
    score_block_table,
    sin,
    cos,
    wkv,
    wgate,
    ape,
    weight,
    hadamard,
    start_pos,
    ratio,
    rope_head_dim,
    rotate,
):
    """Validate the inputs and run the PyPTO compressor kernel matching ``(ratio, rotate)``.

    Supported combinations are ``ratio=4`` with or without ``rotate`` and
    ``ratio=128`` without ``rotate``. ``kv_state`` / ``score_state`` are given
    to the kernel both as inputs and as its ``*_state_out`` destinations, and
    are returned as-is.

    Returns:
        ``(out, kv_state, score_state)``. ``out`` is zero-initialised with
        ``min(B*S, B*S // ratio + B)`` rows (``B, S = x.shape[0], x.shape[1]``)
        and ``weight.shape[0]`` columns; rows the kernel does not write stay zero.

    Raises:
        AssertionError: from ``check_args`` for unsupported shapes/dtypes.
        NotImplementedError: when no kernel exists for ``(ratio, rotate)``
            (e.g. ``ratio=128, rotate=True``, which ``check_args`` accepts),
            rather than silently returning an all-zero ``out``.
    """
    check_args(
        x, kv_state, score_state, kv_block_table, score_block_table, sin, cos, wkv, wgate, ape, weight, hadamard,
        start_pos, ratio, rope_head_dim, rotate
    )
    if ratio == 4:
        kernel = compressor_ratio_4_rotate_kernel if rotate else compressor_ratio_4_kernel
    elif ratio == 128 and not rotate:
        kernel = compressor_ratio_128_kernel
    else:
        raise NotImplementedError(
            f"compressor: no kernel for ratio={ratio}, rotate={rotate}; "
            "supported combinations are ratio=4 (rotate True or False) and ratio=128 (rotate=False)"
        )
    num_tokens = x.shape[0] * x.shape[1]
    out = torch.zeros((min(num_tokens, num_tokens // ratio + x.shape[0]), weight.shape[0]),
                      dtype=x.dtype, device=x.device)
    # Positional layout expected by the kernels: the inputs, then (rotate kernel
    # only) the Hadamard matrix, then out, the state outputs and start positions.
    # kv_state / score_state appear twice because they double as the kernels'
    # state output tensors.
    kernel_args = [x, kv_state, score_state, kv_block_table, score_block_table, sin, cos, wkv, wgate, ape, weight]
    if rotate:
        kernel_args.append(hadamard)
    kernel_args += [out, kv_state, score_state, start_pos]
    kernel(*kernel_args, ratio, rope_head_dim)
    return out, kv_state, score_state
def check_args(
    x,
    kv_state,
    score_state,
    kv_block_table,
    score_block_table,
    sin,
    cos,
    wkv,
    wgate,
    ape,
    weight,
    hadamard,
    start_pos,
    ratio,
    rope_head_dim,
    rotate,
):
    """Validate the arguments of the ``pypto::compressor`` operator.

    Derived sizes used by the checks:

    * ``d = weight.shape[0]``: must be 128 when ``rotate`` is True and 512
      otherwise.
    * ``coff = 2`` when ``ratio == 4`` (the ``overlap`` case) and 1 otherwise;
      the state caches, ``wkv``, ``wgate`` and ``ape`` are ``coff * d`` wide.
    * ``bsz = x.size(0)``: the block tables and ``start_pos`` need ``bsz`` rows.

    Raises:
        AssertionError: if any shape, dtype or scalar argument is unsupported.
            The failure messages never index past a tensor's rank, so a tensor
            with the wrong number of dimensions is reported instead of raising
            ``IndexError`` while the message is formatted.
    """
    def _axis(t, i):
        # Size of axis ``i`` of ``t``, or None when ``t`` has fewer dimensions.
        return t.size(i) if t.dim() > i else None

    overlap = ratio == 4
    coff = 1 + overlap
    d = _axis(weight, 0)
    bsz = _axis(x, 0)
    assert ratio == 4 or ratio == 128, f"ratio is {ratio}, expected 4 or 128"
    assert rope_head_dim == 64, f"rope_head_dim is {rope_head_dim}, expected 64"
    assert isinstance(rotate, bool), f"rotate dtype is {type(rotate)}, expected bool"
    assert weight.dim() == 1 and ((d == 128 and rotate) or (d == 512 and not rotate)), (
        f"weight dim num is {weight.dim()}, weight axis0 is {d}, "
        "expected 1, (d = 512 and rotate = False) or (d = 128 and rotate = True)"
    )
    assert x.dim() == 3 and x.size(1) in [1, 2, 3, 4] and x.size(2) == 4096, (
        f"x dim num is {x.dim()}, x axis1 is {_axis(x, 1)}, x axis2 is {_axis(x, 2)}, "
        "expected 3 dimensions, axis1 in [1, 2, 3, 4], axis2 == 4096"
    )
    # Paged state caches: [num_blocks, 128, coff * d].
    for name, t in (("kv_state", kv_state), ("score_state", score_state)):
        assert t.dim() == 3 and t.size(1) == 128 and t.size(2) == coff * d, (
            f"{name} dim num is {t.dim()}, {name} axis1 is {_axis(t, 1)}, "
            f"{name} axis2 is {_axis(t, 2)}, expected 3, 128, {coff * d}"
        )
    for name, t in (("kv_block_table", kv_block_table), ("score_block_table", score_block_table)):
        assert t.dim() == 2 and t.size(0) == bsz, (
            f"{name} dim num is {t.dim()}, {name} axis0 is {_axis(t, 0)}, expected 2, {bsz}"
        )
    expected_rope_axis0 = min(bsz * x.size(1), bsz * x.size(1) // ratio + bsz)
    for name, t in (("sin", sin), ("cos", cos)):
        assert t.dim() == 2 and t.size(0) == expected_rope_axis0 and t.size(1) == rope_head_dim, (
            f"{name} dim num is {t.dim()}, {name} axis0 is {_axis(t, 0)}, {name} axis1 is {_axis(t, 1)}, "
            f"expected 2, {expected_rope_axis0}, {rope_head_dim}"
        )
    for name, t in (("wkv", wkv), ("wgate", wgate)):
        assert t.dim() == 2 and t.size(1) == 4096 and t.size(0) == coff * d, (
            f"{name} dim num is {t.dim()}, {name} axis0 is {_axis(t, 0)}, {name} axis1 is {_axis(t, 1)}, "
            f"expected 2, {coff * d}, 4096"
        )
    assert ape.dim() == 2 and ape.size(0) == ratio and ape.size(1) == coff * d, (
        f"ape dim num is {ape.dim()}, ape axis0 is {_axis(ape, 0)}, ape axis1 is {_axis(ape, 1)}, "
        f"expected 2, {ratio}, {coff * d}"
    )
    assert hadamard.dim() == 2 and hadamard.size(0) == d and hadamard.size(1) == d, (
        f"hadamard dim num is {hadamard.dim()}, hadamard axis0 is {_axis(hadamard, 0)}, "
        f"hadamard axis1 is {_axis(hadamard, 1)}, expected 2, {d}, {d}"
    )
    assert start_pos.dim() == 1 and start_pos.size(0) == bsz, (
        f"start_pos dim num is {start_pos.dim()}, start_pos axis0 is {_axis(start_pos, 0)}, "
        f"expected 1, {bsz}"
    )
    # Dtype checks, in the same order as the kernels' expectations were listed.
    expected_dtypes = (
        ("x", x, torch.bfloat16),
        ("cos", cos, torch.bfloat16),
        ("sin", sin, torch.bfloat16),
        ("hadamard", hadamard, torch.bfloat16),
        ("kv_state", kv_state, torch.float32),
        ("score_state", score_state, torch.float32),
        ("kv_block_table", kv_block_table, torch.int32),
        ("score_block_table", score_block_table, torch.int32),
        ("start_pos", start_pos, torch.int32),
        ("wkv", wkv, torch.bfloat16),
        ("wgate", wgate, torch.bfloat16),
        ("ape", ape, torch.float32),
    )
    for name, t, expected in expected_dtypes:
        assert t.dtype == expected, f"{name}.dtype is {t.dtype}, expected {expected}"
@dataclass
class Rope2dTileConfig:
    """Vector tile shapes used by ``interleaved_rope_2d``.

    ``pypto.set_vec_tile_shapes(*two_dim_tile)`` is applied before the fp32
    casts of the 2-D operands, and ``pypto.set_vec_tile_shapes(*three_dim_tile)``
    for everything after them (the reshape/transpose rotation and the final
    combine).
    """

    # Tile shape for the 2-D casts, e.g. [1, 64].
    two_dim_tile: List[int]
    # Tile shape for the 3-D de-interleave / re-interleave steps, e.g. [1, 128, 128].
    three_dim_tile: List[int]
def softmax(x: pypto.Tensor, dim) -> pypto.Tensor:
    """Numerically stable softmax of ``x`` along ``dim``.

    The per-slice maximum is subtracted before exponentiation, so the result is
    ``exp(x - max) / sum(exp(x - max))`` with the reductions kept as size-1 axes.
    """
    shifted = pypto.sub(x, pypto.amax(x, dim, keepdim=True))
    numerator = pypto.exp(shifted)
    denominator = pypto.sum(numerator, dim, keepdim=True)
    return pypto.div(numerator, denominator)
def rms_norm(
    input_tensor: pypto.Tensor, gamma: pypto.Tensor, epsilon=1e-6
) -> pypto.Tensor:
    """RMS-normalise ``input_tensor`` over its last axis and scale by ``gamma``.

    Computes ``gamma * x / sqrt(mean(x**2, axis=-1) + epsilon)`` in float32 and
    casts the result back to ``input_tensor``'s dtype. ``gamma`` is reshaped to
    ``[1, ..., 1, gamma.shape[0]]`` so it broadcasts over the leading axes
    (assumed 1-D with the last axis' length — not checked here).
    """
    rank = len(input_tensor.shape)
    gamma_shape = [1] * rank
    gamma_shape[rank - 1] = gamma.shape[0]

    x32 = pypto.cast(input_tensor, pypto.DT_FP32)
    g32 = pypto.cast(pypto.reshape(gamma, gamma_shape), pypto.DT_FP32)

    # mean(x^2): each square is scaled by 1/n first, then reduced over the last axis.
    mean_sq = pypto.mul(pypto.mul(x32, x32), 1.0 / input_tensor.shape[rank - 1])
    mean_sq = pypto.sum(mean_sq, -1, keepdim=True)
    rms = pypto.sqrt(pypto.add(mean_sq, epsilon))
    inv_rms = pypto.div(pypto.full(rms.shape, 1.0, pypto.DT_FP32), rms)
    return pypto.cast(pypto.mul(g32, pypto.mul(x32, inv_rms)), input_tensor.dtype)
def rotate_half(input_tensor: pypto.Tensor) -> pypto.Tensor:
    """Return ``concat([-x2, x1], -1)`` where ``x1`` / ``x2`` are the halves of the last axis.

    The halves are taken as views of ``input_tensor``; the last axis is assumed
    to have even length.
    """
    chunk_size = 2
    # Copy the shape: it is edited below, and writing into the list returned by
    # ``.shape`` could otherwise alias (and corrupt) a shape owned elsewhere.
    shape = list(input_tensor.shape)
    rank = len(shape)
    half = shape[rank - 1] // chunk_size
    shape[rank - 1] = half
    offset1 = [0] * rank
    offset2 = [0] * rank
    offset2[rank - 1] = half
    x1 = pypto.view(input_tensor, shape, offset1)
    x2 = pypto.view(input_tensor, shape, offset2)
    return pypto.concat([x2 * (-1.0), x1], -1)
def interleaved_rope_2d(
    x: pypto.Tensor,
    cos: pypto.Tensor,
    sin: pypto.Tensor,
    rope_2d_config: Rope2dTileConfig,
):
    """Apply rotary embedding to a 2-D ``x`` whose rotary pairs are interleaved.

    With ``x`` of shape ``[rows, width]`` and pairs ``(x[2i], x[2i+1])``, the
    result is, per element,

        out[2i]   = x[2i]   * cos[2i]   - x[2i+1] * sin[2i]
        out[2i+1] = x[2i+1] * cos[2i+1] + x[2i]   * sin[2i+1]

    computed in float32 and cast back to ``x.dtype``. The rotated partner vector
    is built by de-interleaving (reshape + transpose), ``rotate_half`` and
    re-interleaving.
    """
    pypto.set_vec_tile_shapes(*rope_2d_config.two_dim_tile)
    x_fp32 = pypto.cast(x, pypto.DataType.DT_FP32)
    cos_fp32 = pypto.cast(cos, pypto.DataType.DT_FP32)
    sin_fp32 = pypto.cast(sin, pypto.DataType.DT_FP32)

    pypto.set_vec_tile_shapes(*rope_2d_config.three_dim_tile)
    # [a0, b0, a1, b1, ...] -> [a0, a1, ..., b0, b1, ...]
    pairs = pypto.transpose(pypto.reshape(x_fp32, [x.shape[0], x.shape[1] // 2, 2]), 1, 2)
    planar = rotate_half(pypto.reshape(pairs, x.shape))
    # [-b0, -b1, ..., a0, a1, ...] -> [-b0, a0, -b1, a1, ...]
    halves = pypto.transpose(pypto.reshape(planar, [x.shape[0], 2, x.shape[1] // 2]), 1, 2)
    partner = pypto.reshape(halves, x.shape)
    return pypto.cast(x_fp32 * cos_fp32 + partner * sin_fp32, x.dtype)
def scatter_update_3d(input, index, src):
    """Scatter-update a 3-D ``input`` with rows of a 3-D ``src`` along axis -2.

    Both tensors have their first two axes flattened, so that the 2-D
    ``pypto.scatter_update(dst, -2, index, src)`` can be used. The result is
    reshaped back to ``input``'s original shape.
    """
    orig_shape = input.shape
    width = src.shape[2]
    pypto.set_vec_tile_shapes(1, 24, width)
    src_2d = pypto.reshape(src, [src.shape[0] * src.shape[1], src.shape[2]])
    dst_2d = pypto.reshape(input, [input.shape[0] * input.shape[1], input.shape[2]])
    pypto.set_vec_tile_shapes(24, width)
    updated = pypto.scatter_update(dst_2d, -2, index, src_2d)
    return pypto.reshape(updated, orig_shape)
# PyPTO JIT kernel: compressor for ratio == 4 without rotation (d = 512,
# coff = 2, so every projected kv / score row is coff * d = 1024 wide).
#
# For every batch row r (processed in tiles of b = 64 rows):
#   * x[r] is projected with wkv / wgate.
#   * The positional bias ``ape`` is added to the scores; the ape row is the
#     position inside the ratio-token window.
#   * The new rows are appended to the paged caches kv_state_out /
#     score_state_out at their token position p:
#     block = block_table[r, p // block_size], offset = p % block_size.
#   * If the step reaches the end of a window, the window is pooled:
#       - The first d columns of the previous window's state (zeros / -inf
#         when start_pos < ratio) are stacked with the last d columns of the
#         current window.
#       - The stack is weighted by the softmax of the matching scores over the
#         2 * ratio rows, summed, and RMS-normalised with ``weight``.
#       - The last ``rope_head_dim`` channels get interleaved RoPE, and the row
#         is written to out[r].
#
# Steps that do not finish a window do not write ``out``.
# The ``ratio`` argument is overwritten with the constant 4.
@pypto.frontend.jit(
    pass_options={},
    runtime_options={
        "stitch_function_max_num": 128,
        "device_sched_mode": 3,
    },
)
def compressor_ratio_4_kernel(
    x: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC], pypto.DT_BF16),
    kv_state_total: pypto.Tensor([pypto.STATIC, pypto.STATIC, pypto.STATIC], pypto.DT_FP32),
    score_state_total: pypto.Tensor([pypto.STATIC, pypto.STATIC, pypto.STATIC], pypto.DT_FP32),
    kv_block_table: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_INT32),
    score_block_table: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_INT32),
    sin: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_BF16),
    cos: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_BF16),
    wkv: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_BF16),
    wgate: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_BF16),
    ape: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_FP32),
    weight: pypto.Tensor([pypto.STATIC], pypto.DT_FP32),
    out: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_BF16),
    kv_state_out: pypto.Tensor([...], pypto.DT_FP32),
    score_state_out: pypto.Tensor([...], pypto.DT_FP32),
    start_pos_dy: pypto.Tensor([...], pypto.DT_INT32),
    ratio,
    rope_head_dim
):
    dtype = x.dtype
    bsz, s1, h = x.shape
    x_tmp = pypto.reshape(x, [bsz * s1, h], inplace=True)
    ratio = 4
    coff = 2
    d = 512
    block_size = kv_state_total.shape[1]
    pypto.set_vec_tile_shapes(block_size)
    # Row indices 0..block_size-1, used as scatter positions inside a window.
    cache_index = pypto.arange(block_size)
    # Batch rows are handled in tiles of b rows; the last tile may be partial.
    b = 64
    b_loop = (bsz + b - 1) // b
    for b_idx in pypto.loop(b_loop, name="LOOP_COMP_1", idx_name="b_idx"):
        b_valid = (bsz - b_idx * b).min(b)
        x_view = pypto.view(x_tmp, [b * s1, h], [b_idx * b * s1, 0])
        pypto.set_cube_tile_shapes([128, 128], [256, 512], [128, 128])
        pypto.set_vec_tile_shapes(16, 2, 1024)
        # Projections of every token in the tile, later viewed as [b, s1, coff * d].
        kv_t = pypto.matmul(x_view, wkv, pypto.DT_FP32, b_trans=True)
        score_t = pypto.matmul(x_view, wgate, pypto.DT_FP32, b_trans=True)
        for _ in pypto.loop(1):
            pypto.set_pass_options(sg_set_scope=(1, True, False))
            kv_t = pypto.reshape(kv_t, [b, s1, coff * d], inplace=True)
            score_t = pypto.reshape(score_t, [b, s1, coff * d], inplace=True)
            cache_index = pypto.reshape(cache_index, [1, block_size], inplace=True)
        # One iteration per valid batch row of this tile.
        for c_idx in pypto.loop(b_valid, name="LOOP_COMP_2", idx_name="c_idx"):
            pypto.set_pass_options(sg_set_scope=(1, True, False))
            start_pos = start_pos_dy[b_idx * b + c_idx]
            if start_pos % ratio + s1 < ratio:
                # The new tokens stay inside the current window: only append them.
                pos = start_pos % ratio
                kv = pypto.view(kv_t, [1, s1, coff * d], [c_idx, 0, 0])
                score = pypto.view(score_t, [1, s1, coff * d], [c_idx, 0, 0])
                pypto.set_vec_tile_shapes(s1, 1024)
                ape_view = pypto.view(ape, [s1, coff * d], [pos, 0])
                pypto.set_vec_tile_shapes(1, s1, 1024)
                score = pypto.add(score, ape_view)
                kv_block_idx = kv_block_table[
                    b_idx * b + c_idx, start_pos // block_size
                ]
                score_block_idx = score_block_table[
                    b_idx * b + c_idx, start_pos // block_size
                ]
                cur_pos = start_pos % block_size
                pypto.set_vec_tile_shapes(1, s1, 1024)
                pypto.assemble(kv, [kv_block_idx, cur_pos, 0], kv_state_out)
                pypto.assemble(
                    score, [score_block_idx, cur_pos, 0], score_state_out
                )
            else:
                # This step reaches the end of the current window.
                pypto.set_vec_tile_shapes(1, 16, 1024)
                kv_block_idx = kv_block_table[
                    b_idx * b + c_idx, start_pos // block_size
                ]
                score_block_idx = score_block_table[
                    b_idx * b + c_idx, start_pos // block_size
                ]
                # Cached state of the current window (ratio rows, full 2 * d width).
                start = ((start_pos // ratio) * ratio) % block_size
                kv_state = pypto.view(
                    kv_state_total, [1, ratio, coff * d], [kv_block_idx, start, 0]
                )
                score_state = pypto.view(
                    score_state_total, [1, ratio, coff * d], [score_block_idx, start, 0]
                )
                if start_pos < ratio:
                    # No previous window: neutral kv and -inf scores (zero softmax weight).
                    pre_kv_state = pypto.full([1, ratio, d], 0.0, pypto.DT_FP32)
                    pre_score_state = pypto.full(
                        [1, ratio, d], float("-inf"), pypto.DT_FP32
                    )
                else:
                    # First d columns of the previous window's state.
                    pre_start = ((start_pos // ratio) * ratio - ratio) % block_size
                    pre_kv_block_idx = kv_block_table[
                        b_idx * b + c_idx, (start_pos - ratio) // block_size
                    ]
                    pre_score_block_idx = score_block_table[
                        b_idx * b + c_idx, (start_pos - ratio) // block_size
                    ]
                    pre_kv_state = pypto.view(
                        kv_state_total, [1, ratio, d], [pre_kv_block_idx, pre_start, 0]
                    )
                    pre_score_state = pypto.view(
                        score_state_total,
                        [1, ratio, d],
                        [pre_score_block_idx, pre_start, 0],
                    )
                pos = start_pos % ratio
                cur_pos = start_pos % block_size
                if pos + s1 == ratio:
                    # All new tokens belong to the current window.
                    kv = pypto.view(kv_t, [1, s1, coff * d], [c_idx, 0, 0])
                    score = pypto.view(score_t, [1, s1, coff * d], [c_idx, 0, 0])
                    pypto.set_vec_tile_shapes(s1, 1024)
                    ape_view = pypto.view(ape, [s1, coff * d], [pos, 0])
                    pypto.set_vec_tile_shapes(1, s1, 1024)
                    score = pypto.add(score, ape_view)
                    pypto.set_vec_tile_shapes(1, s1, 1024)
                    pypto.assemble(kv, [kv_block_idx, cur_pos, 0], kv_state_out)
                    pypto.assemble(
                        score, [score_block_idx, cur_pos, 0], score_state_out
                    )
                    index = pypto.view(cache_index, [1, s1], [0, pos])
                    kv_state = scatter_update_3d(kv_state, index, kv)
                    score_state = scatter_update_3d(score_state, index, score)
                else:
                    # The step crosses the window boundary.
                    #  - The first ratio - pos tokens finish this window.
                    #  - The remaining tokens start the next window at position
                    #    next_pos and are cached at that position's own block
                    #    and offset.
                    # Offset 0 is only right when the window boundary is also a
                    # block boundary (ratio == block_size), which generally does
                    # not hold for ratio 4.
                    next_pos = (start_pos // ratio) * ratio + ratio
                    next_start = next_pos % block_size
                    next_kv_block_idx = kv_block_table[b_idx * b + c_idx, next_pos // block_size]
                    next_score_block_idx = score_block_table[b_idx * b + c_idx, next_pos // block_size]
                    kv_pre = pypto.view(kv_t, [1, s1, coff * d], [c_idx, 0, 0], valid_shape=[1, ratio - pos, coff * d])
                    score_pre = pypto.view(score_t, [1, s1, coff * d], [c_idx, 0, 0],
                                           valid_shape=[1, ratio - pos, coff * d])
                    kv_next = pypto.view(kv_t, [1, s1, coff * d], [c_idx, ratio - pos, 0],
                                         valid_shape=[1, s1 - (ratio - pos), coff * d])
                    score_next = pypto.view(score_t, [1, s1, coff * d], [c_idx, ratio - pos, 0],
                                            valid_shape=[1, s1 - (ratio - pos), coff * d])
                    pypto.set_vec_tile_shapes(s1, 1024)
                    ape_view_pre = pypto.view(ape, [s1, coff * d], [pos, 0], valid_shape=[ratio - pos, coff * d])
                    ape_view_next = pypto.view(ape, [s1, coff * d], [0, 0], valid_shape=[s1 - (ratio - pos), coff * d])
                    pypto.set_vec_tile_shapes(1, s1, 1024)
                    score_pre = pypto.add(score_pre, ape_view_pre)
                    score_next = pypto.add(score_next, ape_view_next)
                    pypto.assemble(kv_pre, [kv_block_idx, cur_pos, 0], kv_state_out)
                    pypto.assemble(score_pre, [score_block_idx, cur_pos, 0], score_state_out)
                    pypto.assemble(kv_next, [next_kv_block_idx, next_start, 0], kv_state_out)
                    pypto.assemble(score_next, [next_score_block_idx, next_start, 0], score_state_out)
                    index = pypto.view(cache_index, [1, s1], [0, pos], valid_shape=[1, ratio - pos])
                    kv_state = scatter_update_3d(kv_state, index, kv_pre)
                    score_state = scatter_update_3d(score_state, index, score_pre)
                # Pool [previous first half ; current second half] over 2 * ratio rows.
                pypto.set_vec_tile_shapes(1, 8, 1024)
                kv_state_tmp = pypto.concat(
                    [pre_kv_state, kv_state[:, :, d:]], 1
                )
                score_state_tmp = pypto.concat(
                    [pre_score_state, score_state[:, :, d:]], 1
                )
                kv = kv_state_tmp * softmax(score_state_tmp, 1)
                kv = pypto.sum(kv, 1)
                pypto.set_vec_tile_shapes(1, 512)
                kv = rms_norm(pypto.cast(kv, dtype), weight)
                # RoPE only on the trailing rope_head_dim channels.
                kv_nope = kv[:, : d - rope_head_dim]
                kv_rope = kv[:, d - rope_head_dim :]
                sin_tile = pypto.view(
                    sin, kv_rope.shape, [b_idx * b + c_idx, 0]
                )
                cos_tile = pypto.view(
                    cos, kv_rope.shape, [b_idx * b + c_idx, 0]
                )
                rope2d_tile_config = Rope2dTileConfig(
                    [1, 64], [1, 128, 128]
                )
                kv_rope = interleaved_rope_2d(
                    kv_rope, cos_tile, sin_tile, rope2d_tile_config
                )
                pypto.set_vec_tile_shapes(1, 512)
                kv = pypto.concat([kv_nope, kv_rope], dim=-1)
                pypto.assemble(kv, [b_idx * b + c_idx, 0], out)
# PyPTO JIT kernel: compressor for ratio == 4 with the Hadamard rotation
# (d = 128, coff = 2, so every projected kv / score row is 256 wide).
#
# The per-token cache update and window pooling follow the same steps as the
# non-rotating ratio-4 kernel. The differences are:
#   * Pooled rows are written to a scratch tensor ``out_t``.
#   * Rows whose step does not finish a window get zeros in ``out_t``.
#   * A final pass writes ``out_t @ hadamard`` into ``out``, in tiles of
#     b rows.
#
# The ``ratio`` argument is overwritten with the constant 4.
@pypto.frontend.jit(
    pass_options={},
    runtime_options={
        "stitch_function_max_num": 128,
        "device_sched_mode": 3,
    },
)
def compressor_ratio_4_rotate_kernel(
    x_in: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC], pypto.DT_BF16),
    kv_state_total: pypto.Tensor([pypto.STATIC, pypto.STATIC, pypto.STATIC], pypto.DT_FP32),
    score_state_total: pypto.Tensor([pypto.STATIC, pypto.STATIC, pypto.STATIC], pypto.DT_FP32),
    kv_block_table: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_INT32),
    score_block_table: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_INT32),
    sin: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_BF16),
    cos: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_BF16),
    wkv: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_BF16),
    wgate: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_BF16),
    ape: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_FP32),
    weight: pypto.Tensor([pypto.STATIC], pypto.DT_FP32),
    hadamard: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_BF16),
    out: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_BF16),
    kv_state_out: pypto.Tensor([...], pypto.DT_FP32),
    score_state_out: pypto.Tensor([...], pypto.DT_FP32),
    start_pos_dy: pypto.Tensor([...], pypto.DT_INT32),
    ratio,
    rope_head_dim):
    bsz, s1, h = x_in.shape
    dtype = x_in.dtype
    x_tmp = pypto.reshape(x_in, [bsz * s1, h], inplace=True)
    ratio = 4
    coff = 2
    d = 128
    block_size = kv_state_total.shape[1]
    pypto.set_vec_tile_shapes(block_size)
    # Row indices 0..block_size-1, used as scatter positions inside a window.
    cache_index = pypto.arange(block_size)
    # Scratch [bsz, d] holding the pooled (pre-Hadamard) rows.
    out_t = pypto.Tensor([bsz, d], pypto.DT_BF16)
    pypto.set_vec_tile_shapes(1, 1)
    is_compress = pypto.SymbolicScalar(0)
    pypto.set_vec_tile_shapes(1, 256)
    # bf16 zero row written to out_t for steps that do not finish a window.
    zero = pypto.full([1, d], 0.0, pypto.DT_FP32)
    zero = pypto.cast(zero, pypto.DT_BF16)
    # Batch rows are handled in tiles of b rows; the last tile may be partial.
    b = 64
    b_loop = (bsz + b - 1) // b
    for b_idx in pypto.loop(b_loop, name="LOOP_COMP_1", idx_name="b_idx"):
        b_valid = (bsz - b_idx * b).min(b)
        x_view = pypto.view(x_tmp, [b * s1, h], [b_idx * b * s1, 0])
        pypto.set_cube_tile_shapes([128, 128], [256, 512], [64, 64])
        pypto.set_vec_tile_shapes(64, 2, 256)
        # Projections of every token in the tile, later viewed as [b, s1, coff * d].
        kv_t = pypto.matmul(x_view, wkv, pypto.DT_FP32, b_trans=True)
        score_t = pypto.matmul(x_view, wgate, pypto.DT_FP32, b_trans=True)
        for _ in pypto.loop(1):
            pypto.set_pass_options(sg_set_scope=(1, True, False))
            kv_t = pypto.reshape(kv_t, [b, s1, coff * d], inplace=True)
            score_t = pypto.reshape(score_t, [b, s1, coff * d], inplace=True)
            cache_index = pypto.reshape(cache_index, [1, block_size], inplace=True)
        # One iteration per valid batch row of this tile.
        for c_idx in pypto.loop(b_valid, name="LOOP_COMP_2", idx_name="c_idx"):
            pypto.set_pass_options(sg_set_scope=(1, True, False))
            start_pos = start_pos_dy[b_idx * b + c_idx]
            if start_pos % ratio + s1 < ratio:
                # The new tokens stay inside the current window: append them and emit a zero row.
                pos = start_pos % ratio
                kv = pypto.view(kv_t, [1, s1, coff * d], [c_idx, 0, 0])
                score = pypto.view(score_t, [1, s1, coff * d], [c_idx, 0, 0])
                pypto.set_vec_tile_shapes(s1, 256)
                ape_view = pypto.view(ape, [s1, coff * d], [pos, 0])
                pypto.set_vec_tile_shapes(1, s1, 256)
                score = pypto.add(score, ape_view)
                kv_block_idx = kv_block_table[
                    b_idx * b + c_idx, start_pos // block_size
                ]
                score_block_idx = score_block_table[
                    b_idx * b + c_idx, start_pos // block_size
                ]
                cur_pos = start_pos % block_size
                pypto.set_vec_tile_shapes(1, s1, 256)
                pypto.assemble(kv, [kv_block_idx, cur_pos, 0], kv_state_out)
                pypto.assemble(
                    score, [score_block_idx, cur_pos, 0], score_state_out
                )
                pypto.set_vec_tile_shapes(1, 256)
                pypto.assemble(zero, [b_idx * b + c_idx, 0], out_t)
            else:
                # This step reaches the end of the current window.
                # NOTE(review): is_compress is rebound to a named symbolic scalar + 1;
                # confirm this flag behaves as intended at runtime.
                is_compress = pypto.SymbolicScalar("is_compress") + 1
                pypto.set_vec_tile_shapes(1, 16, 256)
                kv_block_idx = kv_block_table[
                    b_idx * b + c_idx, start_pos // block_size
                ]
                score_block_idx = score_block_table[
                    b_idx * b + c_idx, start_pos // block_size
                ]
                # Cached state of the current window (ratio rows, full 2 * d width).
                start = ((start_pos // ratio) * ratio) % block_size
                kv_state = pypto.view(
                    kv_state_total, [1, ratio, coff * d], [kv_block_idx, start, 0]
                )
                score_state = pypto.view(
                    score_state_total, [1, ratio, coff * d], [score_block_idx, start, 0]
                )
                if start_pos < ratio:
                    # No previous window: neutral kv and -inf scores (zero softmax weight).
                    pre_kv_state = pypto.full([1, ratio, d], 0.0, pypto.DT_FP32)
                    pre_score_state = pypto.full(
                        [1, ratio, d], float("-inf"), pypto.DT_FP32
                    )
                else:
                    # First d columns of the previous window's state.
                    pre_start = ((start_pos // ratio) * ratio - ratio) % block_size
                    pre_kv_block_idx = kv_block_table[
                        b_idx * b + c_idx, (start_pos - ratio) // block_size
                    ]
                    pre_score_block_idx = score_block_table[
                        b_idx * b + c_idx, (start_pos - ratio) // block_size
                    ]
                    pre_kv_state = pypto.view(
                        kv_state_total, [1, ratio, d], [pre_kv_block_idx, pre_start, 0]
                    )
                    pre_score_state = pypto.view(
                        score_state_total,
                        [1, ratio, d],
                        [pre_score_block_idx, pre_start, 0],
                    )
                pos = start_pos % ratio
                cur_pos = start_pos % block_size
                if pos + s1 == ratio:
                    # All new tokens belong to the current window.
                    kv = pypto.view(kv_t, [1, s1, coff * d], [c_idx, 0, 0])
                    score = pypto.view(score_t, [1, s1, coff * d], [c_idx, 0, 0])
                    pypto.set_vec_tile_shapes(s1, 256)
                    ape_view = pypto.view(ape, [s1, coff * d], [pos, 0])
                    pypto.set_vec_tile_shapes(1, s1, 256)
                    score = pypto.add(score, ape_view)
                    pypto.set_vec_tile_shapes(1, s1, 256)
                    pypto.assemble(kv, [kv_block_idx, cur_pos, 0], kv_state_out)
                    pypto.assemble(
                        score, [score_block_idx, cur_pos, 0], score_state_out
                    )
                    index = pypto.view(cache_index, [1, s1], [0, pos])
                    kv_state = scatter_update_3d(kv_state, index, kv)
                    score_state = scatter_update_3d(score_state, index, score)
                else:
                    # The step crosses the window boundary.
                    #  - The first ratio - pos tokens finish this window.
                    #  - The remaining tokens start the next window at position
                    #    next_pos and are cached at that position's own block
                    #    and offset.
                    # Offset 0 is only right when the window boundary is also a
                    # block boundary (ratio == block_size), which generally does
                    # not hold for ratio 4.
                    next_pos = (start_pos // ratio) * ratio + ratio
                    next_start = next_pos % block_size
                    next_kv_block_idx = kv_block_table[b_idx * b + c_idx, next_pos // block_size]
                    next_score_block_idx = score_block_table[b_idx * b + c_idx, next_pos // block_size]
                    kv_pre = pypto.view(kv_t, [1, s1, coff * d], [c_idx, 0, 0], valid_shape=[1, ratio - pos, coff * d])
                    score_pre = pypto.view(score_t, [1, s1, coff * d], [c_idx, 0, 0],
                                           valid_shape=[1, ratio - pos, coff * d])
                    kv_next = pypto.view(kv_t, [1, s1, coff * d], [c_idx, ratio - pos, 0],
                                         valid_shape=[1, s1 - (ratio - pos), coff * d])
                    score_next = pypto.view(score_t, [1, s1, coff * d], [c_idx, ratio - pos, 0],
                                            valid_shape=[1, s1 - (ratio - pos), coff * d])
                    pypto.set_vec_tile_shapes(s1, 256)
                    ape_view_pre = pypto.view(ape, [s1, coff * d], [pos, 0], valid_shape=[ratio - pos, coff * d])
                    ape_view_next = pypto.view(ape, [s1, coff * d], [0, 0], valid_shape=[s1 - (ratio - pos), coff * d])
                    pypto.set_vec_tile_shapes(1, s1, 256)
                    score_pre = pypto.add(score_pre, ape_view_pre)
                    score_next = pypto.add(score_next, ape_view_next)
                    pypto.assemble(kv_pre, [kv_block_idx, cur_pos, 0], kv_state_out)
                    pypto.assemble(score_pre, [score_block_idx, cur_pos, 0], score_state_out)
                    pypto.assemble(kv_next, [next_kv_block_idx, next_start, 0], kv_state_out)
                    pypto.assemble(score_next, [next_score_block_idx, next_start, 0], score_state_out)
                    index = pypto.view(cache_index, [1, s1], [0, pos], valid_shape=[1, ratio - pos])
                    kv_state = scatter_update_3d(kv_state, index, kv_pre)
                    score_state = scatter_update_3d(score_state, index, score_pre)
                # Pool [previous first half ; current second half] over 2 * ratio rows.
                pypto.set_vec_tile_shapes(1, 8, 256)
                kv_state_tmp = pypto.concat(
                    [pre_kv_state, kv_state[:, :, d:]], 1
                )
                score_state_tmp = pypto.concat(
                    [pre_score_state, score_state[:, :, d:]], 1
                )
                kv = kv_state_tmp * softmax(score_state_tmp, 1)
                kv = pypto.sum(kv, 1)
                pypto.set_vec_tile_shapes(1, 256)
                kv = rms_norm(pypto.cast(kv, dtype), weight)
                # RoPE only on the trailing rope_head_dim channels.
                kv_nope = kv[:, : d - rope_head_dim]
                kv_rope = kv[:, d - rope_head_dim :]
                sin_tile = pypto.view(
                    sin, kv_rope.shape, [b_idx * b + c_idx, 0]
                )
                cos_tile = pypto.view(
                    cos, kv_rope.shape, [b_idx * b + c_idx, 0]
                )
                rope2d_tile_config = Rope2dTileConfig(
                    [1, 64], [1, 128, 128]
                )
                kv_rope = interleaved_rope_2d(
                    kv_rope, cos_tile, sin_tile, rope2d_tile_config
                )
                pypto.set_vec_tile_shapes(1, 256)
                kv = pypto.concat([kv_nope, kv_rope], dim=-1)
                pypto.assemble(kv, [b_idx * b + c_idx, 0], out_t)
    if is_compress > 0:
        # Barrier loop: submits the work issued so far before continuing.
        for _ in pypto.loop(1, submit_before_loop=True):
            assert True
    # Hadamard pass over every tile. Rows that never compressed hold zeros in
    # out_t, so their product stays zero.
    for b_idx in pypto.loop(b_loop, name="LOOP_HADAMARD", idx_name="b_idx"):
        b_valid = (bsz - b_idx * b).min(b)
        pypto.set_cube_tile_shapes([64, 64], [128, 128], [128, 128])
        pypto.set_vec_tile_shapes(64, 128)
        out_view = pypto.view(
            out_t, [b, d], [b_idx * b, 0], valid_shape=[b_valid, d]
        )
        out_view = pypto.matmul(out_view, hadamard, pypto.DT_BF16)
        pypto.assemble(out_view, [b_idx * b, 0], out)
# PyPTO JIT kernel: compressor for ratio == 128 without rotation (d = 512).
#
# Same flow as the ratio-4 kernels, but with no window overlap:
#   * kv / score rows are d wide.
#   * A finished window is pooled on its own (no previous-window half):
#     softmax of its scores over the window rows, weighted sum of kv,
#     RMS-norm with ``weight``, then RoPE on the last ``rope_head_dim``
#     channels, written to out[r].
#   * The window state is viewed as a whole cache block
#     ([1, block_size, d] at offset 0).
#   * Tokens past a window boundary are cached at offset 0 of the next
#     block.
# Both block-level choices rely on block_size == ratio; check_args requires
# 128-row state blocks.
#
# Steps that do not finish a window do not write ``out``.
# The ``ratio`` argument is overwritten with the constant 128.
@pypto.frontend.jit(
    pass_options={},
    runtime_options={
        "stitch_function_max_num": 128,
        "device_sched_mode": 3,
    },
)
def compressor_ratio_128_kernel(
    x: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC], pypto.DT_BF16),
    kv_state_total: pypto.Tensor([pypto.STATIC, pypto.STATIC, pypto.STATIC], pypto.DT_FP32),
    score_state_total: pypto.Tensor([pypto.STATIC, pypto.STATIC, pypto.STATIC], pypto.DT_FP32),
    kv_block_table: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_INT32),
    score_block_table: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_INT32),
    sin: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_BF16),
    cos: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_BF16),
    wkv: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_BF16),
    wgate: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_BF16),
    ape: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_FP32),
    weight: pypto.Tensor([pypto.STATIC], pypto.DT_FP32),
    out: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_BF16),
    kv_state_out: pypto.Tensor([...], pypto.DT_FP32),
    score_state_out: pypto.Tensor([...], pypto.DT_FP32),
    start_pos_dy: pypto.Tensor([...], pypto.DT_INT32),
    ratio,
    rope_head_dim):
    dtype = x.dtype
    bsz, s1, h = x.shape
    x_tmp = pypto.reshape(x, [bsz * s1, h], inplace=True)
    ratio = 128
    d = 512
    block_size = kv_state_total.shape[1]
    pypto.set_vec_tile_shapes(block_size)
    # Row indices 0..block_size-1, used as scatter positions inside a window.
    cache_index = pypto.arange(block_size)
    # Batch rows are handled in tiles of b rows; the last tile may be partial.
    b = 64
    b_loop = (bsz + b - 1) // b
    for b_idx in pypto.loop(b_loop, name="LOOP_COMP_1", idx_name="b_idx"):
        b_valid = (bsz - b_idx * b).min(b)
        x_view = pypto.view(x_tmp, [b * s1, h], [b_idx * b * s1, 0])
        pypto.set_cube_tile_shapes([128, 128], [256, 512], [64, 64])
        pypto.set_vec_tile_shapes(32, 2, 512)
        # Projections of every token in the tile, later viewed as [b, s1, d].
        kv_t = pypto.matmul(x_view, wkv, pypto.DT_FP32, b_trans=True)
        score_t = pypto.matmul(x_view, wgate, pypto.DT_FP32, b_trans=True)
        for _ in pypto.loop(1):
            pypto.set_pass_options(sg_set_scope=(1, True, False))
            kv_t = pypto.reshape(kv_t, [b, s1, d], inplace=True)
            score_t = pypto.reshape(score_t, [b, s1, d], inplace=True)
            cache_index = pypto.reshape(cache_index, [1, block_size], inplace=True)
        # One iteration per valid batch row of this tile.
        for c_idx in pypto.loop(b_valid, name="LOOP_COMP_2", idx_name="c_idx"):
            pypto.set_pass_options(sg_set_scope=(1, True, False))
            start_pos = start_pos_dy[b_idx * b + c_idx]
            if start_pos % ratio + s1 < ratio:
                # The new tokens stay inside the current window: only append them.
                pos = start_pos % ratio
                kv = pypto.view(kv_t, [1, s1, d], [c_idx, 0, 0])
                score = pypto.view(score_t, [1, s1, d], [c_idx, 0, 0])
                pypto.set_vec_tile_shapes(s1, 512)
                ape_view = pypto.view(ape, [s1, d], [pos, 0])
                pypto.set_vec_tile_shapes(1, s1, 512)
                score = pypto.add(score, ape_view)
                kv_block_idx = kv_block_table[
                    b_idx * b + c_idx, start_pos // block_size
                ]
                score_block_idx = score_block_table[
                    b_idx * b + c_idx, start_pos // block_size
                ]
                cur_pos = start_pos % block_size
                pypto.set_vec_tile_shapes(1, s1, 512)
                pypto.assemble(kv, [kv_block_idx, cur_pos, 0], kv_state_out)
                pypto.assemble(
                    score, [score_block_idx, cur_pos, 0], score_state_out
                )
            else:
                # This step reaches the end of the current window.
                pypto.set_vec_tile_shapes(1, 32, 512)
                kv_block_idx = kv_block_table[
                    b_idx * b + c_idx, start_pos // block_size
                ]
                score_block_idx = score_block_table[
                    b_idx * b + c_idx, start_pos // block_size
                ]
                # Cached state of the current window: the whole block.
                kv_state = pypto.view(
                    kv_state_total, [1, block_size, d], [kv_block_idx, 0, 0]
                )
                score_state = pypto.view(
                    score_state_total, [1, block_size, d], [score_block_idx, 0, 0]
                )
                pos = start_pos % ratio
                cur_pos = start_pos % block_size
                if pos + s1 == ratio:
                    # All new tokens belong to the current window.
                    kv = pypto.view(kv_t, [1, s1, d], [c_idx, 0, 0])
                    score = pypto.view(score_t, [1, s1, d], [c_idx, 0, 0])
                    pypto.set_vec_tile_shapes(s1, 512)
                    ape_view = pypto.view(ape, [s1, d], [pos, 0])
                    pypto.set_vec_tile_shapes(1, s1, 512)
                    score = pypto.add(score, ape_view)
                    pypto.assemble(kv, [kv_block_idx, cur_pos, 0], kv_state_out)
                    pypto.assemble(score, [score_block_idx, cur_pos, 0], score_state_out)
                    index = pypto.view(cache_index, [1, s1], [0, pos])
                    kv_state = scatter_update_3d(kv_state, index, kv)
                    score_state = scatter_update_3d(score_state, index, score)
                else:
                    # The step crosses the window boundary.
                    #  - The first ratio - pos tokens finish this window.
                    #  - The rest start the next window and are cached at offset 0
                    #    of block (start_pos + s1) // block_size. This is correct
                    #    here because window boundaries coincide with block
                    #    boundaries (block_size == ratio).
                    next_kv_block_idx = kv_block_table[b_idx * b + c_idx, (start_pos + s1) // block_size]
                    next_score_block_idx = score_block_table[b_idx * b + c_idx, (start_pos + s1) // block_size]
                    kv_pre = pypto.view(kv_t, [1, s1, d], [c_idx, 0, 0], valid_shape=[1, ratio - pos, d])
                    score_pre = pypto.view(score_t, [1, s1, d], [c_idx, 0, 0], valid_shape=[1, ratio - pos, d])
                    kv_next = pypto.view(kv_t, [1, s1, d], [c_idx, ratio - pos, 0],
                                         valid_shape=[1, s1 - (ratio - pos), d])
                    score_next = pypto.view(score_t, [1, s1, d], [c_idx, ratio - pos, 0],
                                            valid_shape=[1, s1 - (ratio - pos), d])
                    pypto.set_vec_tile_shapes(s1, 512)
                    ape_view_pre = pypto.view(ape, [s1, d], [pos, 0], valid_shape=[ratio - pos, d])
                    ape_view_next = pypto.view(ape, [s1, d], [0, 0], valid_shape=[s1 - (ratio - pos), d])
                    pypto.set_vec_tile_shapes(1, s1, 512)
                    score_pre = pypto.add(score_pre, ape_view_pre)
                    score_next = pypto.add(score_next, ape_view_next)
                    pypto.assemble(kv_pre, [kv_block_idx, cur_pos, 0], kv_state_out)
                    pypto.assemble(score_pre, [score_block_idx, cur_pos, 0], score_state_out)
                    pypto.assemble(kv_next, [next_kv_block_idx, 0, 0], kv_state_out)
                    pypto.assemble(score_next, [next_score_block_idx, 0, 0], score_state_out)
                    index = pypto.view(cache_index, [1, s1], [0, pos], valid_shape=[1, ratio - pos])
                    kv_state = scatter_update_3d(kv_state, index, kv_pre)
                    score_state = scatter_update_3d(score_state, index, score_pre)
                # Softmax over the window rows, weighted sum -> one row of width d.
                pypto.set_vec_tile_shapes(1, 128, 128)
                kv = kv_state * softmax(score_state, 1)
                kv = pypto.sum(kv, 1)
                pypto.set_vec_tile_shapes(1, 512)
                kv = rms_norm(pypto.cast(kv, dtype), weight)
                # RoPE only on the trailing rope_head_dim channels.
                kv_nope = kv[:, : d - rope_head_dim]
                kv_rope = kv[:, d - rope_head_dim :]
                sin_tile = pypto.view(
                    sin, kv_rope.shape, [b_idx * b + c_idx, 0]
                )
                cos_tile = pypto.view(
                    cos, kv_rope.shape, [b_idx * b + c_idx, 0]
                )
                rope2d_tile_config = Rope2dTileConfig(
                    [1, 64], [1, 128, 128]
                )
                kv_rope = interleaved_rope_2d(
                    kv_rope, cos_tile, sin_tile, rope2d_tile_config
                )
                pypto.set_vec_tile_shapes(1, 512)
                kv = pypto.concat([kv_nope, kv_rope], dim=-1)
                pypto.assemble(kv, [b_idx * b + c_idx, 0], out)