"""
"""
import os
import sys
import math
import logging
from pathlib import Path
import torch
import torch_npu
import pytest
import pypto
from mla_prolog_quant_impl import MlaTileConfig
from utils.compare import compare
PRINT_DEBUG = False
def prep_env():
device_id = int(os.environ.get("TILE_FWK_DEVICE_ID", 0))
torch.npu.set_device(device_id)
torch_npu.npu.config.allow_internal_format = True
def quant(input_t, is_pertoken: bool = True, has_smooth=False, smooth_cq=None):
input_fp32 = input_t.to(torch.float32)
if has_smooth:
input_fp32 = input_fp32 * smooth_cq
abs_res = torch.abs(input_fp32)
reduce_idx = -1
if not is_pertoken:
reduce_idx = -2
logging.debug("This PerChannel Quant!!")
max_value = torch.max(abs_res, dim=reduce_idx, keepdims=True)[0]
scale_quant = 127 / max_value
out_fp32 = input_fp32 * scale_quant
out_int32 = torch.round(out_fp32).to(torch.int32)
out_fp16 = out_int32.to(torch.float16)
out_int8 = torch.trunc(out_fp16).to(torch.int8)
scale_dequant = 1 / scale_quant
return out_int8, scale_dequant
def rms_norm(x, gamma):
x_dtype = x.dtype
mean_coff = 1.0 / x.shape[-1]
x_f32 = x.to(torch.float32)
square = x_f32 * x_f32
mean_res = square * mean_coff
reduce_sum = torch.sum(mean_res, dim=-1, keepdims=True)
reduce_sqrt = torch.sqrt(reduce_sum)
res_div = x_f32 / reduce_sqrt
res = res_div * gamma
if x_dtype != torch.float32:
res = res.to(x_dtype)
return res
def scatter_update_4d(cache, key_states, indices, axis):
block_number, block_size, n2, d = cache.shape
res = cache.reshape(block_number * block_size * n2, d)
b, s1 = indices.shape
if axis == -2:
for b_i in range(b):
for s1_i in range(s1):
index_value = indices[b_i][s1_i]
res[index_value][:] = key_states[b_i * s1 + s1_i][:]
return res.reshape(block_number, block_size, n2, d)
def scatter_update_2d(cache, k_bsnd, cache_index, axis):
block_number, block_size, n_kv, d = cache.shape
res = cache.reshape(block_number * block_size * n_kv, d)
b, s1 = cache_index.shape
if axis == -2:
for b_i in range(b):
for s1_i in range(s1):
index_value = cache_index[b_i][s1_i]
res[index_value, :] = k_bsnd[b_i, s1_i, :, :]
return res.reshape(block_number, block_size, n_kv, d)
def apply_rotary_pos_emb_v2(q, k, cos, sin, unsqueeze_dim=2):
input_dtype = q.dtype
if input_dtype != torch.float32:
q = q.to(torch.float32)
k = k.to(torch.float32)
if cos.dtype != torch.float32:
cos = cos.to(torch.float32)
sin = sin.to(torch.float32)
cos = torch.unsqueeze(cos, dim=unsqueeze_dim)
sin = torch.unsqueeze(sin, dim=unsqueeze_dim)
b, s, h, d = q.shape
q = q.reshape(b, s, h, d // 2, 2).permute(0, 1, 2, 4, 3).reshape(b, s, h, d)
b, s, h, d = k.shape
k = k.reshape(b, s, h, d // 2, 2).permute(0, 1, 2, 4, 3).reshape(b, s, h, d)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
if input_dtype != torch.float32:
q_embed, k_embed = q_embed.to(input_dtype), k_embed.to(input_dtype)
return q_embed, k_embed
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
def single_rope(x, cos_in, sin_in):
x_dtype = x.dtype
b, s, n, d = x.shape
x_cast = x.to(torch.float32)
cos_cast = cos_in.to(torch.float32)
sin_cast = sin_in.to(torch.float32)
cos_re = cos_cast.unsqueeze(2)
sin_re = sin_cast.unsqueeze(2)
res = x_cast * cos_re + rotate_half(x_cast) * sin_re
return res.to(x_dtype)
def layer_norm(x: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor, eps=1e-6) -> torch.Tensor:
x_dtype = x.dtype
if x_dtype != torch.float32:
x = x.to(torch.float32)
mean = x.mean(dim=-1, keepdim=True)
var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)
x = (x - mean) / torch.sqrt(var + eps)
return (x * gamma.to(torch.float32) + beta.to(torch.float32)).to(x_dtype)
def gen_block_table(act_seq, block_size, s1):
b = act_seq.shape[0]
block_num = 0
block_num_each = []
max_kv = max(act_seq)
for cur_s in act_seq:
cur_block_num = math.ceil(cur_s / block_size)
block_num_each.append(cur_block_num)
block_num += cur_block_num
block_table_shape = [b, math.ceil(max_kv / block_size)]
block_idx_list = torch.arange(0, block_num, 1)
block_idx_list = block_idx_list[torch.randperm(block_idx_list.size(0))].to(torch.int32)
block_table = -torch.ones(block_table_shape, dtype=torch.int32)
block_idx = 0
block_table_bidx = 0
for cur_block in block_num_each:
for j in range(cur_block):
block_table[block_table_bidx, j] = block_idx_list[block_idx]
block_idx += 1
block_table_bidx += 1
cache_index = -torch.ones((b, s1), dtype=torch.int64)
for i in range(b):
cur_act = act_seq[i]
for j in range(s1):
pos = cur_act - s1 + j
block_idx_in_seq = pos // block_size
global_block_id = block_table[i, block_idx_in_seq]
offset_in_block = pos % block_size
global_index = global_block_id * block_size + offset_in_block
cache_index[i, j] = global_index
return block_num, block_table, cache_index
def gen_cache_tensor(k_cache_bsnd, block_table, block_num, block_size):
dtype = k_cache_bsnd.dtype
b, s2, n_kv, d = k_cache_bsnd.shape
k_cache = torch.zeros((block_num, block_size, n_kv, d), dtype=dtype)
s2_new = ((s2 + block_size - 1) // block_size) * block_size
k_cache_raw = torch.zeros((b, s2_new, n_kv, d), dtype=dtype)
k_cache_raw[:, :s2, :, :] = k_cache_bsnd
for b_idx in range(b):
for block_idx, cache_block_idx in enumerate(block_table[b_idx]):
block_offset = block_idx * block_size
if cache_block_idx == -1:
continue
else:
k_cache[cache_block_idx, :, :, :] = k_cache_raw[
b_idx, block_offset: (block_offset + block_size), :, :
]
return k_cache
def gen_mla_prolog_quant_v32_inputs(params, dtypes, actual_seq, is_quant=(False, False),
is_nz=False, has_smooth=False, block_size=128, cache_mode='BSND'):
dtype, w_dtype = dtypes
is_quant_a, is_quant_b = is_quant
b = params.get('b')
s = params.get('s')
s1 = params.get('s1')
h = params.get('h')
n = params.get('num_heads')
q_lora_rank = params.get('q_lora_rank')
qk_nope_head_dim = params.get('qk_nope_head_dim')
qk_rope_head_dim = params.get('qk_rope_head_dim')
kv_lora_rank = params.get('kv_lora_rank')
block_num, block_table, cache_index = gen_block_table(actual_seq, block_size, s1)
skv_max = actual_seq.max()
q_head_dim = qk_nope_head_dim + qk_rope_head_dim
x_shape = [b, s, h]
w_qa_shape = [h, q_lora_rank]
w_qb_shape = [q_lora_rank, n * q_head_dim]
w_kv_a_shape = [h, kv_lora_rank + qk_rope_head_dim]
w_kv_b_k_shape = [n, qk_nope_head_dim, kv_lora_rank]
gamma_cq_shape = [q_lora_rank]
gamma_ckv_shape = [kv_lora_rank]
cos_shape = [b, s, qk_rope_head_dim]
kv_bsnd_shape = [b, skv_max, 1, kv_lora_rank + qk_rope_head_dim]
kv_cache_shape = [block_num, block_size, 1, kv_lora_rank]
kr_cache_shape = [block_num, block_size, 1, qk_rope_head_dim]
kv_quant_scale_cache_shape = [block_num, block_size, 1, 4]
smooth_cq_shape = [1, q_lora_rank]
res = [None] * 17
x = torch.empty(x_shape).uniform_(-1, 1).to(dtype)
res[0] = x
w_dq = torch.empty(w_qa_shape).uniform_(-0.1, 0.1).to(w_dtype)
w_uqqr = torch.empty(w_qb_shape).uniform_(-0.1, 0.1).to(w_dtype)
w_dkvkr = torch.empty(w_kv_a_shape).uniform_(-0.1, 0.1).to(w_dtype)
res[4] = dict()
if is_quant_a:
w_dq, w_qa_scale = quant(w_dq, False)
w_dkvkr, w_kva_scale = quant(w_dkvkr, False)
res[4]['w_dq'] = w_qa_scale
res[4]['w_dkvkr'] = w_kva_scale
if is_nz:
w_dq = w_dq.reshape(h, q_lora_rank // 32, 32).permute(1, 0, 2)
w_dkvkr = w_dkvkr.reshape(h, (kv_lora_rank + qk_rope_head_dim) // 32, 32).permute(1, 0, 2)
else:
if is_nz:
w_dq = w_dq.reshape(h, q_lora_rank // 16, 16).permute(1, 0, 2)
w_dkvkr = w_dkvkr.reshape(h, (kv_lora_rank + qk_rope_head_dim) // 16, 16).permute(1, 0, 2)
if is_quant_b:
w_uqqr, w_qb_scale = quant(w_uqqr, False)
res[4]['w_uqqr'] = w_qb_scale
if has_smooth:
smooth_cq = torch.empty(smooth_cq_shape).uniform_(-1, 1).to(torch.float32)
res[3] = smooth_cq
if is_nz:
w_uqqr = w_uqqr.reshape(q_lora_rank, n * q_head_dim // 32, 32).permute(1, 0, 2)
res[1] = w_dq
res[2] = w_uqqr
res[5] = w_dkvkr
w_uk = torch.empty(w_kv_b_k_shape).uniform_(-0.1, 0.1).to(w_dtype)
res[6] = w_uk
gamma_cq = torch.empty(gamma_cq_shape).uniform_(-1, 1).to(dtype)
gamma_ckv = torch.empty(gamma_ckv_shape).uniform_(-1, 1).to(dtype)
res[7] = gamma_cq
res[8] = gamma_ckv
cos = torch.empty(cos_shape).uniform_(-0.1, 0.1).to(dtype)
sin = torch.empty(cos_shape).uniform_(-0.1, 0.1).to(dtype)
res[9] = cos
res[10] = sin
res[11] = cache_index
k_bsnd = torch.empty(kv_bsnd_shape).uniform_(-1, 1).to(dtype)
per_batch_max_num = math.ceil(skv_max / block_size)
k_tensor_bsnd = torch.zeros((b, per_batch_max_num * block_size, 1, kv_lora_rank + qk_rope_head_dim)).to(dtype)
k_tensor_bsnd[:, :k_bsnd.shape[1], :, :] = k_bsnd[:, :, :, :]
k_cache_tensor = torch.zeros([block_num, block_size, 1, kv_lora_rank + qk_rope_head_dim]).to(dtype)
for b_idx in range(b):
for block_i, kv_cache_blk_id in enumerate(block_table[b_idx]):
block_offset = block_i * block_size
if kv_cache_blk_id == -1:
continue
else:
k_cache_tensor[kv_cache_blk_id, 0:block_size, :, :] = k_tensor_bsnd[
b_idx, block_offset:(block_offset + block_size), :, :]
kv_cache = k_cache_tensor[:, :, :, : kv_lora_rank]
kr_cache = k_cache_tensor[:, :, :, kv_lora_rank:]
kv_quant_scale_cache = None
if is_quant_b:
kv_cache_split = kv_cache.reshape(-1, 4, kv_lora_rank // 4)
kv_cache, kv_quant_scale_cache = quant(kv_cache_split, True)
kv_cache = kv_cache.reshape(kv_cache_shape)
kv_quant_scale_cache = kv_quant_scale_cache.reshape(kv_quant_scale_cache_shape)
res[12] = kv_cache
res[13] = kr_cache
res[14] = kv_quant_scale_cache
res[15] = block_num
res[16] = block_table
return res
def gen_indexer_prolog_inputs(params, block_num, block_table, mla_inputs, mla_goldens):
q_lora_rank = params['q_lora_rank']
b = params['b']
s2 = params['s2']
h = params['h']
n2 = params['n2']
idx_head_dim = params['idx_head_dim']
dtype = params['dtype']
block_size = params['block_size']
idx_n_heads = params['idx_n_heads']
quant_dtype = torch.int8
x = mla_inputs['x']
cos = mla_inputs['cos']
sin = mla_inputs['sin']
cache_index = mla_inputs['cache_index']
rms_norm_out = mla_goldens['rms_norm_out']
rms_norm_scale_out = mla_goldens['rms_norm_scale_out']
w_idx_qb = torch.randint(low=-128, high=128, size=(q_lora_rank, idx_n_heads * idx_head_dim), dtype=quant_dtype)
w_idx_qb_scale = torch.empty((1, idx_n_heads * idx_head_dim), dtype=torch.float32).uniform_(-1, 1)
w_idx_k = torch.empty((h, idx_head_dim), dtype=dtype).uniform_(-1, 1)
w_idx_proj = torch.empty((h, idx_n_heads), dtype=dtype).uniform_(-1, 1)
ln_gamma = torch.ones((idx_head_dim,), dtype=dtype)
ln_beta = torch.zeros((idx_head_dim,), dtype=dtype)
hadamard_q = torch.empty((idx_head_dim, idx_head_dim), dtype=dtype).uniform_(-1, 1)
hadamard_k = torch.empty((idx_head_dim, idx_head_dim), dtype=dtype).uniform_(-1, 1)
k_cache_bsnd = torch.rand((b * s2 * n2, idx_head_dim), dtype=torch.float32) * 2 - 1
k_cache_bsnd, k_scale_cache_bsnd = quant(k_cache_bsnd)
k_cache_bsnd = k_cache_bsnd.reshape(b, s2, n2, idx_head_dim).to(dtype=quant_dtype)
k_scale_cache_bsnd = k_scale_cache_bsnd.reshape(b, s2, n2, 1).to(torch.float16)
k_cache = gen_cache_tensor(k_cache_bsnd, block_table, block_num, block_size)
k_scale_cache = gen_cache_tensor(k_scale_cache_bsnd, block_table, block_num, block_size)
return {
'token_x': x,
'q_norm': rms_norm_out,
'q_norm_scale': rms_norm_scale_out,
'w_idx_qb': w_idx_qb,
'w_idx_qb_scale': w_idx_qb_scale,
'w_idx_k': w_idx_k,
'w_idx_proj': w_idx_proj,
'layer_norm_gamma': ln_gamma,
'layer_norm_beta': ln_beta,
'cos_idx_rope': cos,
'sin_idx_rope': sin,
'hadamard_q': hadamard_q,
'hadamard_k': hadamard_k,
'idx_k_cache': k_cache,
'idx_k_scale_cache': k_scale_cache,
'idx_k_cache_index': cache_index,
'idx_block_table': block_table,
}
def mla_prolog_quant_v32_compute(inputs):
dtype = inputs.get('dtype')
is_quant_a = inputs.get('is_quant_a')
is_quant_b = inputs.get('is_quant_b')
has_smooth = inputs.get('has_smooth')
gamma_cq = inputs.get('gamma_cq')
gamma_ckv = inputs.get('gamma_ckv')
x = inputs.get('x')
w_dq = inputs.get('w_dq')
w_uqqr = inputs.get('w_uqqr')
w_uk = inputs.get('w_uk')
w_dkvkr = inputs.get('w_dkvkr')
cos = inputs.get('cos')
sin = inputs.get('sin')
kv_cache = inputs.get('kv_cache')
kr_cache = inputs.get('kr_cache')
kv_quant_scale_cache = None
if is_quant_b:
kv_quant_scale_cache = inputs.get('kv_quant_scale_cache')
cache_index = inputs.get('cache_index')
if is_quant_a:
w_qa_scale = inputs.get('w_qa_scale')
w_kva_scale = inputs.get('w_kva_scale')
if is_quant_b:
w_qb_scale = inputs.get('w_qb_scale')
if has_smooth:
smooth_cq = inputs.get('smooth_cq')
b, s, h = x.shape
qk_rope_head_dim = cos.shape[2]
n, qk_nope_head_dim, kv_lora_rank = w_uk.shape
q_head_dim = qk_nope_head_dim + qk_rope_head_dim
""" q """
x_2d = x.reshape(b * s, h)
if is_quant_a:
x_2d_quant, x_2d_scale_dequant = quant(x_2d, True)
q_a_proj = torch.matmul(x_2d_quant.to(torch.int32), w_dq.to(torch.int32))
""" dequant """
q_a_proj_fp32 = q_a_proj.to(torch.float32)
q_a_proj_fp32_dequant = q_a_proj_fp32 * x_2d_scale_dequant
q_a_proj = q_a_proj_fp32_dequant * w_qa_scale
else:
q_a_proj = torch.matmul(x_2d.to(torch.float32), w_dq.to(torch.float32))
q_a_proj = q_a_proj.to(torch.bfloat16)
q_a_layernorm = rms_norm(q_a_proj, gamma_cq)
q_a_layernorm_scale_dequant = None
if is_quant_b:
if has_smooth:
q_a_layernorm, q_a_layernorm_scale_dequant = quant(q_a_layernorm, True, True, smooth_cq)
else:
q_a_layernorm, q_a_layernorm_scale_dequant = quant(q_a_layernorm, True)
q_b_proj = torch.matmul(q_a_layernorm.to(torch.int32), w_uqqr.to(torch.int32)).to(
q_a_layernorm.device)
""" dequant """
q_b_proj_fp32 = q_b_proj.to(torch.float32)
q_b_proj_fp32_dequant = q_b_proj_fp32 * q_a_layernorm_scale_dequant
q_b_proj = q_b_proj_fp32_dequant * w_qb_scale
else:
q_b_proj = torch.matmul(q_a_layernorm.to(torch.float32), w_uqqr.to(torch.float32))
q_b_proj = q_b_proj.to(dtype)
q_reshape = q_b_proj.reshape(b, s, n, q_head_dim)
q_nope = q_reshape[:, :, :, 0:qk_nope_head_dim]
q_nope_r = q_nope.reshape(b * s, n, qk_nope_head_dim)
q_nope_t = q_nope_r.permute(1, 0, 2)
q_nope_new = torch.matmul(q_nope_t.to(torch.float32), w_uk.to(torch.float32))
q_nope_new = q_nope_new.to(dtype)
q_nope_new_t = q_nope_new.permute(1, 0, 2)
q_nope = q_nope_new_t.reshape(b, s, n, kv_lora_rank)
""" kv """
if is_quant_a:
x_2d_quant, x_2d_scale_dequant = quant(x_2d, True)
kv_a_proj = torch.matmul(x_2d_quant.to(torch.int32), w_dkvkr.to(torch.int32))
""" dequant """
kv_a_proj_fp32 = kv_a_proj.to(torch.float32)
kv_a_proj_fp32_dequant = kv_a_proj_fp32 * x_2d_scale_dequant
kv_a_proj = kv_a_proj_fp32_dequant * w_kva_scale
else:
kv_a_proj = torch.matmul(x_2d.to(torch.float32),
w_dkvkr.to(torch.float32))
kv_a_proj = kv_a_proj.to(dtype)
kv_reshape = kv_a_proj.reshape(b, s, kv_lora_rank + qk_rope_head_dim)
compressed_kv = kv_reshape[:, :, 0:kv_lora_rank]
compressed_kv_norm = rms_norm(compressed_kv, gamma_ckv)
compressed_kv_quant_scale = None
if is_quant_b:
compressed_kv_norm_split = compressed_kv_norm.reshape(b * s, 4, kv_lora_rank // 4)
compressed_kv_norm, compressed_kv_quant_scale = quant(compressed_kv_norm_split, True)
compressed_kv_quant_scale = compressed_kv_quant_scale.reshape(b, s, 1, 4)
compressed_kv_r = compressed_kv_norm.reshape(b, s, 1, kv_lora_rank)
k_nope = compressed_kv_r.reshape(b * s * 1, kv_lora_rank)
""" RoPE """
q_pe = q_reshape[:, :, :, qk_nope_head_dim:]
k_pe = kv_reshape[:, :, kv_lora_rank:]
k_pe_r = k_pe.reshape(b, s, 1, qk_rope_head_dim)
q_embed, k_embed = apply_rotary_pos_emb_v2(q_pe, k_pe_r, cos, sin, 2)
k_embed_r = k_embed.reshape(b * 1 * s, qk_rope_head_dim)
""" kv_cache output, [b,1,s2,kv_lora_rank] """
kv_cache_tmp = kv_cache.clone()
kv_cache_out = scatter_update_4d(kv_cache_tmp, k_nope, cache_index, -2)
""" kr_cache output, [b,1,s2,qk_rope_head_dim] """
kr_cache_tmp = kr_cache.clone()
kr_cache_out = scatter_update_4d(kr_cache_tmp, k_embed_r, cache_index, -2)
if is_quant_b:
compressed_kv_quant_scale = compressed_kv_quant_scale.reshape(-1, 4)
kv_quant_scale_cache_tmp = kv_quant_scale_cache.clone()
kv_quant_scale_cache_out = \
scatter_update_4d(kv_quant_scale_cache_tmp, compressed_kv_quant_scale, cache_index, -2)
else:
kv_quant_scale_cache_out = None
res = [q_nope, q_embed, q_a_layernorm, q_a_layernorm_scale_dequant, kv_cache_out, \
kr_cache_out, kv_quant_scale_cache_out]
return res
def indexer_prolog(inputs: dict, dims: dict):
b, t, n, d = dims['b'], dims['t'], dims['idx_n_heads'], dims['idx_head_dim']
s = t // b
rope_head_dim = dims['rope_head_dim']
x = inputs['token_x']
q_norm = inputs['q_norm']
q_norm_scale = inputs['q_norm_scale']
w_idx_qb = inputs['w_idx_qb']
w_idx_qb_scale = inputs['w_idx_qb_scale']
w_idx_k = inputs['w_idx_k']
w_idx_proj = inputs['w_idx_proj']
layer_norm_gamma = inputs['layer_norm_gamma']
layer_norm_beta = inputs['layer_norm_beta']
cos = inputs['cos_idx_rope']
sin = inputs['sin_idx_rope']
hadamard_q = inputs['hadamard_q']
hadamard_k = inputs['hadamard_k']
idx_k_cache = inputs['idx_k_cache']
idx_k_scale_cache = inputs['idx_k_scale_cache']
cache_index = inputs['idx_k_cache_index']
x_dtype = x.dtype
q = torch.matmul(q_norm.to(torch.int32), w_idx_qb.to(torch.int32))
q_fp32 = q.to(torch.float32)
q_fp32 = q_fp32 * q_norm_scale
q_fp32 = q_fp32 * w_idx_qb_scale.reshape(1, n * d)
q_bf16 = q_fp32.reshape(b, s, n, d).to(torch.bfloat16)
q_rope, q_nope = torch.split(q_bf16, [rope_head_dim, d - rope_head_dim], dim=-1)
q_rope = single_rope(q_rope, cos, sin)
q = torch.cat([q_rope, q_nope], dim=-1)
q = torch.matmul(q.to(torch.float32), hadamard_q.to(torch.float32)).to(x_dtype)
q_int8, q_scale = quant(q)
q_scale = q_scale.to(torch.float16)
k = torch.matmul(x.to(torch.float32), w_idx_k.to(torch.float32))
k = layer_norm(k, layer_norm_gamma, layer_norm_beta).to(x_dtype)
k_rope, k_nope = torch.split(k, [rope_head_dim, d - rope_head_dim], dim=-1)
k_rope = single_rope(k_rope.unsqueeze(2), cos, sin).squeeze(2)
k = torch.cat([k_rope, k_nope], dim=-1)
k = torch.matmul(k.to(torch.float32), hadamard_k.to(torch.float32)).to(x_dtype)
k_int8, k_scale = quant(k)
k_scale = k_scale.to(torch.float16)
k_cache = idx_k_cache.clone()
k_scale_cache = idx_k_scale_cache.clone()
scatter_update_2d(k_cache, k_int8.reshape(b, s, 1, d), cache_index, -2)
scatter_update_2d(k_scale_cache, k_scale.reshape(b, s, 1, 1), cache_index, -2)
weights = torch.matmul(x.to(torch.float32), \
w_idx_proj.to(torch.float32)).to(x_dtype).to(torch.float32)
weights = weights * (n ** -0.5) * (d ** -0.5)
weights = weights.to(torch.float16)
outputs = {'q_int8': q_int8, 'q_scale': q_scale,
'idx_k_cache_out': k_cache, 'idx_k_scale_cache_out': k_scale_cache,
'weights': weights}
return outputs
def gen_test_data(params):
q_lora_rank = params['q_lora_rank']
t = params['t']
h = params['h']
n1 = params['n1']
head_num = params['idx_n_heads']
idx_head_dim = params['idx_head_dim']
qk_rope_head_dim = params['qk_rope_head_dim']
kv_lora_rank = params['kv_lora_rank']
rope_head_dim = params['rope_head_dim']
dtype = params['dtype']
dtypes = (dtype, dtype)
actual_seq = params['actual_seq']
is_quant = (False, True)
is_nz = False
has_smooth = False
block_size = params['block_size']
cache_mode = 'PA_BSND'
(x, w_dq, w_uqqr, smooth_cq, scale_data, w_dkvkr, w_uk, gamma_cq, gamma_ckv, cos, sin, kv_len, kv_cache,
kr_cache, kv_quant_scale_cache, block_num, block_table) = \
gen_mla_prolog_quant_v32_inputs(params, dtypes, actual_seq, is_quant, is_nz,
has_smooth, block_size, cache_mode)
mla_inputs = {'dtype': dtype, 'is_quant_a': is_quant[0], 'is_quant_b': is_quant[1], 'has_smooth': has_smooth}
mla_inputs['cache_mode'] = cache_mode
mla_inputs['gamma_cq'] = gamma_cq
mla_inputs['gamma_ckv'] = gamma_ckv
mla_inputs['x'] = x
mla_inputs['w_dq'] = w_dq
mla_inputs['w_uqqr'] = w_uqqr
mla_inputs['w_uk'] = w_uk
mla_inputs['w_dkvkr'] = w_dkvkr
mla_inputs['cos'] = cos
mla_inputs['sin'] = sin
mla_inputs['kv_cache'] = kv_cache
mla_inputs['kr_cache'] = kr_cache
mla_inputs['kv_quant_scale_cache'] = kv_quant_scale_cache
mla_inputs['cache_index'] = kv_len
mla_inputs['w_qb_scale'] = scale_data['w_uqqr']
res = mla_prolog_quant_v32_compute(mla_inputs)
q_nope, q_rope, rms_norm_out, rms_norm_scale_out, kv_cache_out, kr_cache_out, \
kv_quant_scale_cache_out = res
mla_goldens = {}
mla_goldens['q_nope'] = q_nope
mla_goldens['q_rope'] = q_rope
mla_goldens['rms_norm_out'] = rms_norm_out
mla_goldens['rms_norm_scale_out'] = rms_norm_scale_out
mla_goldens['kv_cache_out'] = kv_cache_out
mla_goldens['kr_cache_out'] = kr_cache_out
mla_goldens['kv_quant_scale_cache_out'] = kv_quant_scale_cache_out
mla_inputs_npu = {}
for k, v in mla_inputs.items():
if isinstance(v, torch.Tensor):
mla_inputs_npu[k] = v.npu().contiguous()
else:
mla_inputs_npu[k] = v
mla_inputs_npu['x'] = mla_inputs_npu['x'].reshape(t, h).contiguous()
mla_inputs_npu['cos'] = mla_inputs_npu['cos'].reshape(t, qk_rope_head_dim).contiguous()
mla_inputs_npu['sin'] = mla_inputs_npu['sin'].reshape(t, qk_rope_head_dim).contiguous()
mla_inputs_npu['cache_index'] = mla_inputs_npu['cache_index'].reshape(t).contiguous()
mla_inputs_npu['w_qb_scale'] = mla_inputs_npu['w_qb_scale'].reshape(-1, 1).contiguous()
mla_inputs_npu['w_dq'] = torch_npu.npu_format_cast(mla_inputs_npu['w_dq'], torch_npu.Format.FRACTAL_NZ)
mla_inputs_npu['w_uqqr'] = torch_npu.npu_format_cast(mla_inputs_npu['w_uqqr'], torch_npu.Format.FRACTAL_NZ)
mla_inputs_npu['w_dkvkr'] = torch_npu.npu_format_cast(mla_inputs_npu['w_dkvkr'], torch_npu.Format.FRACTAL_NZ)
mla_goldens['q_nope'] = mla_goldens['q_nope'].reshape(t, n1, kv_lora_rank).contiguous().cpu()
mla_goldens['q_rope'] = mla_goldens['q_rope'].reshape(t, n1, qk_rope_head_dim).contiguous().cpu()
if PRINT_DEBUG:
logging.debug("mla_inputs_npu======")
for k, v in mla_inputs_npu.items():
if isinstance(v, torch.Tensor):
logging.debug(f'{k}: {v.shape}, {v.dtype}')
else:
logging.debug(f'{k}: {v}')
logging.debug("mla_goldens======")
for k, v in mla_goldens.items():
logging.debug(f'{k}: {v.shape}, {v.dtype}')
ip_inputs = gen_indexer_prolog_inputs(params, block_num, block_table, mla_inputs, mla_goldens)
ip_goldens = indexer_prolog(ip_inputs, params)
ip_inputs_npu = {}
for k, v in ip_inputs.items():
if isinstance(v, torch.Tensor):
ip_inputs_npu[k] = v.npu().contiguous()
else:
ip_inputs_npu[k] = v
ip_inputs_npu['token_x'] = ip_inputs_npu['token_x'].reshape(t, h).contiguous()
ip_inputs_npu['cos_idx_rope'] = ip_inputs_npu['cos_idx_rope'].reshape(t, rope_head_dim).contiguous()
ip_inputs_npu['sin_idx_rope'] = ip_inputs_npu['sin_idx_rope'].reshape(t, rope_head_dim).contiguous()
ip_inputs_npu['idx_k_cache_index'] = ip_inputs_npu['idx_k_cache_index'].reshape(t).contiguous()
ip_inputs_npu['w_idx_qb_scale'] = ip_inputs_npu['w_idx_qb_scale'].reshape(-1, 1).contiguous()
ip_inputs_npu['q_norm'] = ip_inputs_npu['q_norm'].reshape(t, q_lora_rank).contiguous()
ip_inputs_npu['q_norm_scale'] = ip_inputs_npu['q_norm_scale'].reshape(t, 1).contiguous()
ip_inputs_npu['w_idx_qb_nz'] = torch_npu.npu_format_cast(ip_inputs_npu['w_idx_qb'], torch_npu.Format.FRACTAL_NZ)
ip_inputs_npu['w_idx_k_nz'] = torch_npu.npu_format_cast(ip_inputs_npu['w_idx_k'], torch_npu.Format.FRACTAL_NZ)
ip_inputs_npu['w_idx_proj_nz'] = torch_npu.npu_format_cast(ip_inputs_npu['w_idx_proj'],
torch_npu.Format.FRACTAL_NZ)
ip_goldens['q_int8'] = ip_goldens['q_int8'].reshape(t, head_num, idx_head_dim).contiguous()
ip_goldens['q_scale'] = ip_goldens['q_scale'].reshape(t, head_num, 1).contiguous()
ip_goldens['weights'] = ip_goldens['weights'].reshape(t, head_num).contiguous()
if PRINT_DEBUG:
logging.debug("ip_inputs_npu======")
for k, v in ip_inputs_npu.items():
if isinstance(v, torch.Tensor):
logging.debug(f'{k}: {v.shape}, {v.dtype}')
else:
logging.debug(f'{k}: {v}')
logging.debug("ip_goldens======")
for k, v in ip_goldens.items():
logging.debug(f'{k}: {v.shape}, {v.dtype}')
inputs = {
'x': mla_inputs_npu['x'],
'w_dq': mla_inputs_npu['w_dq'],
'w_uqqr': mla_inputs_npu['w_uqqr'],
'w_qb_scale': mla_inputs_npu['w_qb_scale'],
'w_uk': mla_inputs_npu['w_uk'],
'w_dkvkr': mla_inputs_npu['w_dkvkr'],
'gamma_cq': mla_inputs_npu['gamma_cq'],
'gamma_ckv': mla_inputs_npu['gamma_ckv'],
'cos': mla_inputs_npu['cos'],
'sin': mla_inputs_npu['sin'],
'cache_index': mla_inputs_npu['cache_index'],
'kv_cache': mla_inputs_npu['kv_cache'],
'kr_cache': mla_inputs_npu['kr_cache'],
'kv_quant_scale_cache': mla_inputs_npu['kv_quant_scale_cache'],
'w_idx_qb_nz': ip_inputs_npu['w_idx_qb_nz'],
'w_idx_qb_scale': ip_inputs_npu['w_idx_qb_scale'],
'w_idx_k_nz': ip_inputs_npu['w_idx_k_nz'],
'w_idx_proj_nz': ip_inputs_npu['w_idx_proj_nz'],
'layer_norm_gamma': ip_inputs_npu['layer_norm_gamma'],
'layer_norm_beta': ip_inputs_npu['layer_norm_beta'],
'hadamard_q': ip_inputs_npu['hadamard_q'],
'hadamard_k': ip_inputs_npu['hadamard_k'],
'idx_k_cache': ip_inputs_npu['idx_k_cache'],
'idx_k_scale_cache': ip_inputs_npu['idx_k_scale_cache'],
}
goldens = {
'q_nope': mla_goldens['q_nope'],
'q_rope': mla_goldens['q_rope'],
'rms_norm_out': mla_goldens['rms_norm_out'],
'rms_norm_scale_out': mla_goldens['rms_norm_scale_out'],
'kv_cache_out': mla_goldens['kv_cache_out'],
'kr_cache_out': mla_goldens['kr_cache_out'],
'kv_quant_scale_cache_out': mla_goldens['kv_quant_scale_cache_out'],
'q_int8': ip_goldens['q_int8'].reshape(t, head_num, idx_head_dim),
'q_scale': ip_goldens['q_scale'].reshape(t, head_num, 1),
'idx_k_cache_out': ip_goldens['idx_k_cache_out'],
'idx_k_scale_cache_out': ip_goldens['idx_k_scale_cache_out'],
'weights': ip_goldens['weights'].reshape(t, head_num),
}
outputs = {
'q_nope': gen_zero_tensor(goldens['q_nope']),
'q_rope': gen_zero_tensor(goldens['q_rope']),
'kv_cache_out': mla_inputs_npu['kv_cache'],
'kr_cache_out': mla_inputs_npu['kr_cache'],
'kv_quant_scale_cache_out': mla_inputs_npu['kv_quant_scale_cache'],
'q_int8': gen_zero_tensor(goldens['q_int8']),
'q_scale': gen_zero_tensor(goldens['q_scale']),
'idx_k_cache_out': ip_inputs_npu['idx_k_cache'],
'idx_k_scale_cache_out': ip_inputs_npu['idx_k_scale_cache'],
'weights': gen_zero_tensor(goldens['weights'])
}
return inputs, outputs, goldens
def gen_zero_tensor(t):
return torch.zeros_like(t).npu()
def check(case_name, outputs, goldens):
compare(outputs['q_nope'].cpu(), goldens['q_nope'], 'qNope', 0.005, 0.0078125,
0.005)
compare(outputs['q_rope'].cpu(), goldens['q_rope'], 'qRope', 0.005, 0.0078125, 0.005)
compare(outputs['kv_cache_out'].cpu(), goldens['kv_cache_out'], 'kv', 1, 0, 0)
compare(outputs['kr_cache_out'].cpu(), goldens['kr_cache_out'], 'kr', 0.0001, 0.0078125, 0.005)
compare(outputs['kv_quant_scale_cache_out'].cpu(), goldens['kv_quant_scale_cache_out'], 'kScaleCache', 0.000025,
0.005, 0.005)
compare(outputs['q_int8'].cpu(), goldens['q_int8'], 'q_int8', 2, 0, 0)
compare(outputs['q_scale'].cpu(), goldens['q_scale'], 'q_scale', 0.000025, 0.006)
compare(outputs['idx_k_cache_out'].cpu(), goldens['idx_k_cache_out'], 'k_int8', 1, 0, 0)
compare(outputs['idx_k_scale_cache_out'].cpu(), goldens['idx_k_scale_cache_out'], 'k_scale', 0.000025, 0, 0.005)
compare(outputs['weights'].cpu(), goldens['weights'], 'weights', 0.000025, 0, 0.005)
logging.debug(f'=== {case_name}: PASS ===')
def convert_torch_tensor(tensor_dict, dynamic_axis_dict, name_prefix):
dynamic_count = 0
pypto_tensors = []
for name, tensor in tensor_dict.items():
if name in dynamic_axis_dict.keys():
dynamic_axis = dynamic_axis_dict[name]
pypto_tensors.append(pypto.from_torch(tensor, name_prefix + name, dynamic_axis=dynamic_axis))
dynamic_count += 1
else:
pypto_tensors.append(pypto.from_torch(tensor, name_prefix + name))
assert dynamic_count == len(dynamic_axis_dict)
return pypto_tensors
def do_test(case_name, params, mla_epsilon_cq, mla_epsilon_ckv, mla_cache_mode, mla_tile_config, ip_attrs,
ip_configs, rope_tile_shape, is_prefill=False):
prep_env()
logging.debug(f'=== run test case: {case_name} ===')
inputs, outputs, goldens = gen_test_data(params)
pto_inputs = [
inputs["x"],
inputs["w_dq"],
inputs["w_uqqr"],
inputs["w_qb_scale"],
inputs["w_uk"],
inputs["w_dkvkr"],
inputs["gamma_cq"],
inputs["gamma_ckv"],
inputs["cos"],
inputs["sin"],
inputs["cache_index"],
inputs["kv_cache"],
inputs["kr_cache"],
inputs["kv_quant_scale_cache"],
inputs["w_idx_qb_nz"],
inputs["w_idx_qb_scale"],
inputs["w_idx_k_nz"],
inputs["w_idx_proj_nz"],
inputs["layer_norm_gamma"],
inputs["layer_norm_beta"],
inputs["hadamard_q"],
inputs["hadamard_k"],
inputs["idx_k_cache"],
inputs["idx_k_scale_cache"],
]
pto_outputs = [
outputs["q_nope"],
outputs["q_rope"],
outputs["kv_cache_out"],
outputs["kr_cache_out"],
outputs["kv_quant_scale_cache_out"],
outputs["q_int8"],
outputs["q_scale"],
outputs["idx_k_cache_out"],
outputs["idx_k_scale_cache_out"],
outputs["weights"]
]
import mla_indexer_prolog_quant_impl as mla_lp_quant
if is_prefill:
fun = mla_lp_quant.mla_indexer_prolog_quant_p
else:
fun = mla_lp_quant.mla_indexer_prolog_quant_d
fun(*pto_inputs, *pto_outputs, mla_epsilon_cq, mla_epsilon_ckv, mla_cache_mode, mla_tile_config,
ip_attrs, ip_configs, rope_tile_shape)
torch_npu.npu.synchronize()
check(case_name, outputs, goldens)
params_base = {
'n1': 128,
'n2': 1,
'h': 7168,
'num_heads': 128,
'idx_n_heads': 64,
'q_lora_rank': 1536,
'qk_nope_head_dim': 128,
'idx_head_dim': 128,
'qk_rope_head_dim': 64,
'rope_head_dim': 64,
'kv_lora_rank': 512,
'block_size': 128,
'dtype': torch.bfloat16
}
@pytest.mark.soc("950", "910")
def test_b_4_s1_2_tilebs_8_d():
'''
mlaLp decode测试函数
'''
seed = 6
torch.manual_seed(seed)
b = 4
s1 = 2
s2 = 1024
params = params_base
params_base.update({
'b': b,
's': s1,
't': b * s1,
's1': s1,
's2': s2,
'actual_seq': torch.tensor([s2] * b, dtype=torch.int32).unsqueeze(-1),
})
mla_tile_config = MlaTileConfig()
mla_tile_config.tile_bs = 8
c0 = 16
m_tile_value = (min(32, mla_tile_config.tile_bs) + c0 - 1) // c0 * c0
mv_tile_value = min(8, mla_tile_config.tile_bs)
mla_tile_config.m_tile = m_tile_value
from mla_prolog_quant_impl import RopeTileShapeConfig
rope_tile_shape = RopeTileShapeConfig(two_dim=[128, 128], three_dim=[128, 128, 128], four_dim=[16, 128, 128, 128])
mla_tile_config.pre_quant_cube_tile = [m_tile_value, m_tile_value, 256, 256, 128, 128]
mla_tile_config.mv_tile = mv_tile_value
mla_tile_config.q_vec_tile0 = 1
mla_tile_config.q_vec_tile1 = 32
mla_tile_config.k_vec_tile0 = 2
mla_tile_config.k_vec_tile1 = 512
mla_tile_config.unroll_list = [8, 4, 2, 1]
mla_cache_mode = 'PA_BSND'
mla_epsilon_cq = 1e-5
mla_epsilon_ckv = 1e-5
import lightning_indexer_prolog_quant_impl as ip
ip_attrs = ip.IndexerPrologQuantAttr(
eps=1e-6,
layerout_query='TND',
layerout_key='PA_BSND',
)
ip_configs = ip.IndexerPrologQuantConfigs(
q_linear=[16, 16, 256, 256, 128, 128],
q_hd=[64, 64, 128, 128, 128, 128],
k_linear=[16, 16, 256, 256, 128, 128],
w_linear=[16, 16, 256, 256, 128, 128],
unroll_list=[128, 64, 32, 16, 8, 4, 2, 1],
cube_l1_reuse_setting={1: 4},
block_size=128,
t_sub_tile=1,
chunk_size=2,
vec_nbuffer_setting={-1: 1},
)
do_test("mla_prolog_indexer_prolog_quant.test_b_4_s1_2_tilebs_8",
params, mla_epsilon_cq, mla_epsilon_ckv,
mla_cache_mode, mla_tile_config, ip_attrs, ip_configs, rope_tile_shape, False)
@pytest.mark.skip(reason="prefill test cast")
def test_t_32_tilebs_16_p():
'''
mlaLp prefill测试函数
'''
seed = 7
torch.manual_seed(seed)
b = 16
s1 = 2
s2 = 1024
params = params_base
params_base.update({
'b': b,
's': s1,
't': b * s1,
's1': s1,
's2': s2,
'actual_seq': torch.tensor([s2] * b, dtype=torch.int32).unsqueeze(-1),
})
mla_tile_config = MlaTileConfig()
mla_tile_config.tile_bs = 16
c0 = 16
m_tile_value = (min(128, mla_tile_config.tile_bs) + c0 - 1) // c0 * c0
mv_tile_value = min(8, mla_tile_config.tile_bs)
mla_tile_config.m_tile = m_tile_value
mla_tile_config.pre_quant_cube_tile = [m_tile_value, m_tile_value, 256, 256, 128, 128]
mla_tile_config.mv_tile = mv_tile_value
mla_tile_config.q_vec_tile0 = 32
mla_tile_config.q_vec_tile1 = 128
mla_tile_config.k_vec_tile0 = 32
mla_tile_config.k_vec_tile1 = 512
mla_tile_config.cube_l1_reuse_setting = {0: 2, 1: 1, 2: 1, 3: 4, 4: 4, 5: 1}
mla_tile_config.unroll_list = [32, 16, 8, 4, 2, 1]
mla_tile_config.dynamic_unaligned_enable = True
from mla_prolog_quant_impl import RopeTileShapeConfig
rope_tile_shape = RopeTileShapeConfig(two_dim=[32, 64], three_dim=[32, 32, 128], four_dim=[16, 128, 128, 128])
mla_cache_mode = 'PA_BSND'
mla_epsilon_cq = 1e-5
mla_epsilon_ckv = 1e-5
import lightning_indexer_prolog_quant_impl as ip
ip_attrs = ip.IndexerPrologQuantAttr(
eps=1e-6,
layerout_query='TND',
layerout_key='PA_BSND',
)
ip_configs = ip.IndexerPrologQuantConfigs(
q_linear=[16, 16, 512, 512, 128, 128],
q_hd=[32, 32, 128, 128, 128, 128],
k_linear=[16, 16, 512, 512, 64, 64],
w_linear=[16, 16, 1024, 1024, 32, 32],
unroll_list=[32, 16, 8, 4, 2, 1],
cube_l1_reuse_setting={1: 4},
block_size=128,
t_sub_tile=1,
chunk_size=2,
vec_nbuffer_setting={-1: 1},
)
do_test("mla_prolog_indexer_prolog_prefill.test_t_32_tilebs_16",
params, mla_epsilon_cq, mla_epsilon_ckv,
mla_cache_mode, mla_tile_config, ip_attrs, ip_configs, rope_tile_shape, True)
@pytest.mark.skip(reason="large shape")
def test_t_512_tilebs_128_p():
'''
mlaLp prefill测试函数
'''
seed = 15
torch.manual_seed(seed)
b = 128
s1 = 4
s2 = 1024
params = params_base
params_base.update({
'b': b,
's': s1,
't': b * s1,
's1': s1,
's2': s2,
'actual_seq': torch.tensor([s2] * b, dtype=torch.int32).unsqueeze(-1),
})
mla_tile_config = MlaTileConfig()
mla_tile_config.tile_bs = 128
c0 = 16
m_tile_value = (min(128, mla_tile_config.tile_bs) + c0 - 1) // c0 * c0
mv_tile_value = min(8, mla_tile_config.tile_bs)
mla_tile_config.m_tile = m_tile_value
mla_tile_config.pre_quant_cube_tile = [m_tile_value, m_tile_value, 256, 256, 128, 128]
mla_tile_config.mv_tile = mv_tile_value
mla_tile_config.q_vec_tile0 = 32
mla_tile_config.q_vec_tile1 = 128
mla_tile_config.k_vec_tile0 = 32
mla_tile_config.k_vec_tile1 = 512
mla_tile_config.cube_l1_reuse_setting = {0: 2, 1: 1, 2: 1, 3: 4, 4: 4, 5: 1}
mla_tile_config.unroll_list = [128, 64, 32, 16, 8, 4, 2, 1]
mla_tile_config.dynamic_unaligned_enable = True
from mla_prolog_quant_impl import RopeTileShapeConfig
rope_tile_shape = RopeTileShapeConfig(two_dim=[32, 64], three_dim=[32, 32, 128], four_dim=[16, 128, 128, 128])
mla_cache_mode = 'PA_BSND'
mla_epsilon_cq = 1e-5
mla_epsilon_ckv = 1e-5
import lightning_indexer_prolog_quant_impl as ip
ip_attrs = ip.IndexerPrologQuantAttr(
eps=1e-6,
layerout_query='TND',
layerout_key='PA_BSND',
)
ip_configs = ip.IndexerPrologQuantConfigs(
q_linear=[128, 128, 256, 256, 256, 256],
q_hd=[128, 128, 64, 64, 128, 128],
k_linear=[64, 64, 256, 256, 128, 128],
w_linear=[32, 32, 512, 512, 64, 64],
unroll_list=[128, 64, 32, 16, 8, 4, 2, 1],
cube_l1_reuse_setting={1: 4, 3: 4},
block_size=128,
t_sub_tile=2,
chunk_size=1,
vec_nbuffer_setting={-1: 1},
)
do_test("mla_prolog_indexer_prolog_prefill.test_t_512_tilebs_128", params, mla_epsilon_cq, mla_epsilon_ckv,
mla_cache_mode, mla_tile_config, ip_attrs, ip_configs, rope_tile_shape, True)
if __name__ == '__main__':
if PRINT_DEBUG:
logging.basicConfig(format='%(asctime)s - %(filename)s:%(lineno)d - %(levelname)s: %(message)s',
level=logging.DEBUG)
test_b_4_s1_2_tilebs_8_d()