""" """
import torch
import torch_npu
import pypto
import os
import pytest
import logging
from lightning_indexer_prolog_quant_v4_impl import *
from utils.golden.common_func import (
apply_rotary_pos_emb,
gen_uniform_data,
quant_golden,
)
from utils.compare import compare
def compute_quant_lightning_indexer_prolog(inputs, params):
    """Compute the torch reference outputs for the quantized indexer prolog.

    Args:
        inputs: Indexable sequence; positions 0-8 are read, in order, as
            qr, idx_wq_b, x, weights_proj, cos, sin, hadamard, qr_scale and
            idx_wq_b_scale.
        params: Mapping queried with ``.get`` for "t", "idx_nq", "head_dim"
            and "rope_dim".

    Returns:
        Tuple ``(q_int8, weights, q_scale)``:
        ``q_int8`` is the first result of ``quant_golden`` applied to the
        Hadamard-transformed query (presumably int8 -- confirm against
        ``quant_golden``), ``weights`` is a float16 tensor, and ``q_scale``
        is the second ``quant_golden`` result cast to float16 and reshaped
        to ``(t, idx_nq)``.
    """
    (
        qr,
        idx_wq_b,
        x,
        weights_proj,
        cos,
        sin,
        hadamard,
        qr_scale,
        idx_wq_b_scale,
    ) = [inputs[pos] for pos in range(9)]
    t, idx_nq, head_dim, rope_dim = (
        params.get(key) for key in ("t", "idx_nq", "head_dim", "rope_dim")
    )
    bf16 = torch.bfloat16

    # Query projection: matmul on int32 copies, then rescale in float32 by
    # the activation scale followed by the weight scale.
    accum = torch.matmul(qr.to(torch.int32), idx_wq_b.to(torch.int32))
    dequant = accum.to(torch.float32) * qr_scale
    dequant = dequant * idx_wq_b_scale.reshape(1, idx_nq * head_dim)

    # View as (t, idx_nq, head_dim); only the trailing rope_dim channels of
    # each head go through the rotary embedding.
    q_heads = dequant.to(bf16).reshape(t, idx_nq, head_dim)
    nope_part, rope_part = torch.split(
        q_heads, [head_dim - rope_dim, rope_dim], dim=-1
    )
    rotated = apply_rotary_pos_emb(rope_part, cos, sin)
    q_full = torch.cat([nope_part, rotated], dim=-1)

    # Multiply by the hadamard matrix in float32, round to bf16, quantize.
    hadamard_fp32 = hadamard.reshape(1, head_dim, head_dim).to(torch.float32)
    q_hadamard = torch.matmul(q_full.to(torch.float32), hadamard_fp32).to(bf16)
    q_int8, q_scale = quant_golden(q_hadamard)
    q_scale = q_scale.to(torch.float16).reshape(t, idx_nq)

    # Weights path: float32 matmul rounded through bf16, then scaled by
    # idx_nq**-0.5 and head_dim**-0.5 before the final float16 cast.
    weights = torch.matmul(x.to(torch.float32), weights_proj.to(torch.float32))
    weights = weights.to(bf16).to(torch.float32)
    weights = weights * (idx_nq**-0.5) * (head_dim**-0.5)
    return q_int8, weights.to(torch.float16), q_scale
def gen_quant_attention_post_golden(params):
    """Generate seeded test inputs and their reference outputs.

    Seeds torch's RNG with 42, creates the raw tensors through
    ``gen_uniform_data``, quantizes ``qr`` and ``idx_wq_b`` with
    ``quant_golden`` and evaluates
    ``compute_quant_lightning_indexer_prolog`` on the result.

    Args:
        params: Mapping queried with ``.get`` for "t", "idx_nq", "head_dim",
            "rope_dim", "q_lora_rank" and "h".

    Returns:
        Tuple ``(inputs, q_golden, weights_golden, q_scale_golden)`` where
        ``inputs`` is the 9-element list [qr, idx_wq_b, x, weights_proj, cos,
        sin, hadamard, qr_scale, idx_wq_b_scale].
    """
    torch.manual_seed(42)

    t = params.get("t")
    idx_nq = params.get("idx_nq")
    head_dim = params.get("head_dim")
    rope_dim = params.get("rope_dim")
    q_lora_rank = params.get("q_lora_rank")
    h = params.get("h")

    # The creation order below is kept fixed; presumably gen_uniform_data
    # draws from the seeded torch RNG, so reordering would change the data.
    qr_raw = gen_uniform_data([t, q_lora_rank], -1, 1, torch.int8)
    wq_b_raw = gen_uniform_data(
        [q_lora_rank, idx_nq * head_dim], -1, 1, torch.int8
    )
    x = gen_uniform_data([t, h], -1, 1, torch.bfloat16)
    weights_proj = gen_uniform_data([h, idx_nq], -0.1, 0.1, torch.bfloat16)
    cos = gen_uniform_data([t, rope_dim], -1, 1, torch.bfloat16)
    sin = gen_uniform_data([t, rope_dim], -1, 1, torch.bfloat16)
    hadamard = gen_uniform_data([head_dim, head_dim], -1, 1, torch.bfloat16)

    qr, qr_scale = quant_golden(qr_raw)
    # The weight is quantized on its transpose and transposed back afterwards.
    wq_b_quant_t, idx_wq_b_scale = quant_golden(wq_b_raw.t())
    idx_wq_b = wq_b_quant_t.t()

    inputs = [
        qr,
        idx_wq_b,
        x,
        weights_proj,
        cos,
        sin,
        hadamard,
        qr_scale,
        idx_wq_b_scale,
    ]
    golden = compute_quant_lightning_indexer_prolog(inputs, params)
    q_golden, weights_golden, q_scale_golden = golden
    return inputs, q_golden, weights_golden, q_scale_golden
class QuantLightningIndexerProlog(torch.nn.Module):
    """Module wrapper around ``torch.ops.pypto.quant_lightning_indexer_prolog``."""

    def forward(
        self,
        qr,
        idx_wq_b,
        x,
        weights_proj,
        cos,
        sin,
        hadamard,
        qr_scale,
        idx_wq_b_scale,
    ):
        """Issue 20 discarded adds on ``qr_scale``, then return the op result.

        All nine tensor arguments are forwarded to the pypto op unchanged and
        in the same order.
        """
        # The results of these adds are never used.
        # NOTE(review): presumably filler ops ahead of the custom op for the
        # graph-mode test -- confirm the intent with the author.
        for _step in range(20):
            torch.add(qr_scale, 0)

        op_args = (
            qr,
            idx_wq_b,
            x,
            weights_proj,
            cos,
            sin,
            hadamard,
            qr_scale,
            idx_wq_b_scale,
        )
        return torch.ops.pypto.quant_lightning_indexer_prolog(*op_args)
def do_indexer_prolog_quant_func(inputs, params, golden_list):
    """
    Run the kernel entry point directly on the NPU and compare with golden.

    ``inputs`` positions 0-8 are moved to the NPU; ``idx_wq_b`` and
    ``weights_proj`` are additionally cast to FRACTAL_NZ before the call.
    Output buffers are allocated here using "t", "idx_nq" and "head_dim"
    from ``params``. ``golden_list`` supplies the references for q, weights
    and q_scale, in that order.

    Tensor layout (input dtypes/shapes are not validated by this function):

    group     name            dtype    shape
    INPUT 0   qr              DT_INT8  (t, q_lora_rank)
    INPUT 1   idx_wq_b        DT_INT8  (q_lora_rank, idx_nq * head_dim)
    INPUT 2   x               DT_BF16  (t, h)
    INPUT 3   weights_proj    DT_BF16  (h, idx_nq)
    INPUT 4   cos             DT_BF16  (t, rope_dim)
    INPUT 5   sin             DT_BF16  (t, rope_dim)
    INPUT 6   hadamard        DT_BF16  (head_dim, head_dim)
    INPUT 7   qr_scale        DT_FP32  (t, 1)
    INPUT 8   idx_wq_b_scale  DT_FP32  (idx_nq * head_dim, 1)
    OUTPUT 0  q               DT_INT8  (t, idx_nq, head_dim) as allocated below
    OUTPUT 1  weights         DT_FP16  (t, idx_nq)
    OUTPUT 2  q_scale         DT_FP16  (t, idx_nq)
    CONFIGS   tile_config     IndexerPrologQuantConfig
    """
    torch_npu.npu.config.allow_internal_format = True
    torch.npu.set_device(int(os.environ.get("TILE_FWK_DEVICE_ID", 0)))

    # Host -> device copies of the nine inputs.
    (
        qr,
        idx_wq_b,
        x,
        weights_proj,
        cos,
        sin,
        hadamard,
        qr_scale,
        idx_wq_b_scale,
    ) = [inputs[pos].npu() for pos in range(9)]

    # The two weight matrices are handed over in FRACTAL_NZ format.
    nz_format = torch_npu.Format.FRACTAL_NZ
    idx_wq_b_nz = torch_npu.npu_format_cast(idx_wq_b, nz_format)
    weights_proj_nz = torch_npu.npu_format_cast(weights_proj, nz_format)

    t = params.get("t")
    idx_nq = params.get("idx_nq")
    head_dim = params.get("head_dim")

    # Zero-initialised output buffers; passed to the kernel after the inputs.
    q = torch.zeros([t, idx_nq, head_dim], dtype=torch.int8).npu()
    weights = torch.zeros([t, idx_nq], dtype=torch.float16).npu()
    q_scale = torch.zeros([t, idx_nq], dtype=torch.float16).npu()

    kernel_args = [
        qr,
        idx_wq_b_nz,
        x,
        weights_proj_nz,
        cos,
        sin,
        hadamard,
        qr_scale,
        idx_wq_b_scale,
        q,
        weights,
        q_scale,
    ]
    tile_config = IndexerPrologQuantConfig(unroll_list=[128, 64, 32, 16, 8, 1])
    quant_lightning_indexer_prolog_kernel(*kernel_args, tile_config)
    pypto.runtime._device_synchronize()

    # (device output, label, relative tolerance), in golden_list order.
    checks = (
        (q, "q", 0.005),
        (weights, "weights", 0.007825),
        (q_scale, "q_scale", 0.007825),
    )
    for idx, (result, label, rtol) in enumerate(checks):
        compare(result.cpu(), golden_list[idx], label, atol=0.0001, rtol=rtol)
def do_indexer_prolog_quant_torch_graph(inputs, golden_list):
    """
    Run the op through ``torch.compile`` with the torchair NPU backend.

    ``inputs`` positions 0-8 are moved to the NPU; ``idx_wq_b`` and
    ``weights_proj`` are additionally cast to FRACTAL_NZ.
    ``QuantLightningIndexerProlog`` is compiled with ``fullgraph=True``,
    ``dynamic=False`` and compiler mode "reduce-overhead". Its three results
    are compared with ``golden_list`` (q, weights, q_scale, in that order).

    Tensor layout (not validated by this function; output shapes are
    whatever the op returns):

    group     name            dtype    shape
    INPUT 0   qr              DT_INT8  (t, q_lora_rank)
    INPUT 1   idx_wq_b        DT_INT8  (q_lora_rank, idx_nq * head_dim)
    INPUT 2   x               DT_BF16  (t, h)
    INPUT 3   weights_proj    DT_BF16  (h, idx_nq)
    INPUT 4   cos             DT_BF16  (t, rope_dim)
    INPUT 5   sin             DT_BF16  (t, rope_dim)
    INPUT 6   hadamard        DT_BF16  (head_dim, head_dim)
    INPUT 7   qr_scale        DT_FP32  (t, 1)
    INPUT 8   idx_wq_b_scale  DT_FP32  (idx_nq * head_dim, 1)
    OUTPUT 0  q               DT_INT8
    OUTPUT 1  weights         DT_FP16
    OUTPUT 2  q_scale         DT_FP16
    """
    # Imported here so torchair is only needed when this path runs.
    import torchair as tng
    from torchair.configs.compiler_config import CompilerConfig

    torch_npu.npu.config.allow_internal_format = True
    torch.npu.set_device(int(os.environ.get("TILE_FWK_DEVICE_ID", 0)))

    # Host -> device copies of the nine inputs.
    (
        qr,
        idx_wq_b,
        x,
        weights_proj,
        cos,
        sin,
        hadamard,
        qr_scale,
        idx_wq_b_scale,
    ) = [inputs[pos].npu() for pos in range(9)]

    # The two weight matrices are handed over in FRACTAL_NZ format.
    nz_format = torch_npu.Format.FRACTAL_NZ
    idx_wq_b_nz = torch_npu.npu_format_cast(idx_wq_b, nz_format)
    weights_proj_nz = torch_npu.npu_format_cast(weights_proj, nz_format)

    compiler_config = CompilerConfig()
    compiler_config.mode = "reduce-overhead"
    backend = tng.get_npu_backend(compiler_config=compiler_config)
    model = torch.compile(
        QuantLightningIndexerProlog(),
        dynamic=False,
        fullgraph=True,
        backend=backend,
    )

    model_args = (
        qr,
        idx_wq_b_nz,
        x,
        weights_proj_nz,
        cos,
        sin,
        hadamard,
        qr_scale,
        idx_wq_b_scale,
    )
    q, weights, q_scale = model(*model_args)
    pypto.runtime._device_synchronize()

    # (device output, label, relative tolerance), in golden_list order.
    checks = (
        (q, "q", 0.005),
        (weights, "weights", 0.007825),
        (q_scale, "q_scale", 0.007825),
    )
    for idx, (result, label, rtol) in enumerate(checks):
        compare(result.cpu(), golden_list[idx], label, atol=0.0001, rtol=rtol)
def get_indexer_prolog_quant_config(case_name: str):
    """Return the integer configured for ``case_name``, or ``None`` if unknown.

    The caller in this file stores the returned value as ``params["t"]``.
    """
    value_by_case = dict(
        test_indexer_prolog_quant_b64_s1=64,
        test_indexer_prolog_quant_b64_s2=128,
        test_indexer_prolog_quant_b64_s4=256,
        test_indexer_prolog_quant_graph=511,
    )
    if case_name in value_by_case:
        return value_by_case[case_name]
    return None
def do_indexer_prolog_quant_entry(case_name: str, is_torch_graph: bool = False):
    """Generate golden data for ``case_name`` and run one execution path.

    Args:
        case_name: Key looked up through ``get_indexer_prolog_quant_config``
            to obtain ``t``.
        is_torch_graph: When true, run the ``torch.compile`` graph path;
            otherwise call the kernel entry point directly.

    Returns:
        ``False`` if no ``t`` is configured for ``case_name`` (an error is
        logged and nothing is run); ``True`` once the selected path has
        returned.
    """
    t = get_indexer_prolog_quant_config(case_name)
    # Validate the looked-up value itself. The previous ``if not params``
    # guard could never fire because the dict always held the fixed keys,
    # so an unknown case leaked ``t=None`` into golden generation.
    if t is None:
        logging.error("Can't get func to gen golden, Case(%s)", case_name)
        return False

    params = {
        "idx_nq": 64,
        "head_dim": 128,
        "rope_dim": 64,
        "q_lora_rank": 1024,
        "h": 4096,
        "t": t,
    }
    inputs, q_golden, weights_golden, q_scale_golden = gen_quant_attention_post_golden(
        params
    )
    golden_list = [q_golden, weights_golden, q_scale_golden]

    if is_torch_graph:
        print("\n =============== torch graph ====================")
        do_indexer_prolog_quant_torch_graph(inputs, golden_list)
    else:
        print("\n =============== st ====================")
        do_indexer_prolog_quant_func(inputs, params, golden_list)
    return True
@pytest.mark.skip(reason="large test case")
def test_indexer_prolog_quant_b64_s1():
    """Direct-kernel run of the lightning_indexer_prolog quant decode mtp=0 case."""
    case_name = "test_indexer_prolog_quant_b64_s1"
    do_indexer_prolog_quant_entry(case_name, is_torch_graph=False)
@pytest.mark.skip(reason="large test case")
def test_indexer_prolog_quant_b64_s2():
    """Direct-kernel run of the lightning_indexer_prolog quant decode mtp=1 case."""
    case_name = "test_indexer_prolog_quant_b64_s2"
    do_indexer_prolog_quant_entry(case_name, is_torch_graph=False)
@pytest.mark.skip(reason="large test case")
def test_indexer_prolog_quant_b64_s4():
    """Direct-kernel run of the lightning_indexer_prolog quant decode mtp=3 case."""
    case_name = "test_indexer_prolog_quant_b64_s4"
    do_indexer_prolog_quant_entry(case_name, is_torch_graph=False)
@pytest.mark.skip(reason="large test case")
def test_indexer_prolog_quant_graph():
    """Torch-graph run of the lightning_indexer_prolog quant typical case."""
    case_name = "test_indexer_prolog_quant_graph"
    do_indexer_prolog_quant_entry(case_name, is_torch_graph=True)
if __name__ == "__main__":
    # Script entry: configure INFO-level logging, then run the b64_s1 case
    # by calling its test function directly.
    log_format = "%(asctime)s - %(filename)s:%(lineno)d - %(levelname)s: %(message)s"
    logging.basicConfig(level=logging.INFO, format=log_format)
    test_indexer_prolog_quant_b64_s1()