import torch
import torch_npu
import mie_ops
# Turn on torch_npu's ``allow_internal_format`` switch.
# NOTE(review): presumably required so that the ``npu_format_cast`` results
# produced later in this script keep their NPU-private layout - confirm
# against the torch_npu documentation.
torch.npu.config.allow_internal_format = True

# Extra profiler options, passed to ``torch_npu.profiler.profile`` through its
# ``experimental_config`` argument. ``_ExperimentalConfig`` is a private
# (underscore-prefixed) torch_npu class. Apart from the export formats and the
# profiling level, every option is explicitly switched off or left at None.
experimental_config = torch_npu.profiler._ExperimentalConfig(
    # Collection depth and AI Core metrics selection.
    profiler_level=torch_npu.profiler.ProfilerLevel.Level0,
    aic_metrics=torch_npu.profiler.AiCMetrics.AiCoreNone,
    # Result export: both the Text and the Db export types are requested.
    export_type=[torch_npu.profiler.ExportType.Text, torch_npu.profiler.ExportType.Db],
    data_simplification=False,
    # Optional collectors, all disabled.
    msprof_tx=False,
    l2_cache=False,
    op_attr=False,
    record_op_args=False,
    gc_detect_threshold=None,
)
# Profile one call of ``torch.ops.mie_ops.npu_mla_process`` on random inputs
# and print the four tensors it returns.
#
# The profiler schedule discards the first ``SKIP_FIRST_STEPS`` step(s) and
# then records ``ACTIVE_STEPS`` step(s). A schedule only advances when
# ``prof.step()`` is called, so the operator is executed once per step and
# ``prof.step()`` is invoked after every execution; without that the single
# call would fall into the skipped step and nothing would be recorded.
SKIP_FIRST_STEPS = 1
ACTIVE_STEPS = 1

# Quantised (int8) weight matrices. ``randint(4, 5)`` can only produce 4, so
# every element is 4. Each matrix is then cast to NPU format id 29.
# NOTE(review): 29 is presumably the ACL FRACTAL_NZ format id - confirm.
wdqkv_nd = torch.randint(4, 5, (2112, 7168), dtype=torch.int8).npu()
wdqkv = torch_npu.npu_format_cast(wdqkv_nd, 29)
wuq_nd = torch.randint(4, 5, (8 * 192, 1536), dtype=torch.int8).npu()
wuq = torch_npu.npu_format_cast(wuq_nd, 29)

# Keyword arguments for the operator, built once (outside the profiled
# region) so that every profiler step runs on identical inputs and the trace
# contains the operator rather than host-side input construction.
mla_inputs = {
    "input": torch.randn((6, 7168), dtype=torch.bfloat16).npu(),
    "gamma0": torch.randn(7168, dtype=torch.bfloat16).npu(),
    "beta0": torch.zeros(7168).to(torch.bfloat16).npu(),
    "wdqkv": wdqkv,
    "descale0": torch.randn((2112,), dtype=torch.float).npu(),
    "gamma1": torch.randn(1536, dtype=torch.bfloat16).npu(),
    "beta1": torch.zeros(1536).to(torch.bfloat16).npu(),
    "wuq": wuq,
    "descale1": torch.randn((8 * 192,), dtype=torch.float).npu(),
    "gamma2": torch.randn(512, dtype=torch.bfloat16).npu(),
    "cos": torch.randn((6, 64), dtype=torch.bfloat16).npu(),
    "sin": torch.randn((6, 64), dtype=torch.bfloat16).npu(),
    "wuk": torch.randn((8, 128, 512), dtype=torch.bfloat16).npu(),
    # Caches start out filled with NaN.
    # NOTE(review): presumably so that entries the operator actually writes
    # are easy to tell apart from untouched ones - confirm.
    "kv_cache": torch.full([8, 128, 1, 512], float('nan'), dtype=torch.bfloat16).npu(),
    "kv_cache_rope": torch.full([8, 128, 1, 64], float('nan'), dtype=torch.bfloat16).npu(),
    "slotmapping": torch.tensor([1, 3, 5, 7, 9, 11], dtype=torch.int32).npu(),
    "quant_scale0": torch.tensor([1], dtype=torch.bfloat16).npu(),
    "quant_offset0": torch.tensor([0], dtype=torch.int8).npu(),
    "bias0": torch.randint(4, 5, (1536,), dtype=torch.int32).npu(),
    "quant_scale1": torch.tensor([1], dtype=torch.bfloat16).npu(),
    "quant_offset1": torch.tensor([0], dtype=torch.int8).npu(),
    "bias1": torch.randint(4, 5, (1536,), dtype=torch.int32).npu(),
    "ctkv_scale": torch.tensor([1], dtype=torch.bfloat16).npu(),
    "q_nope_scale": torch.tensor([1], dtype=torch.bfloat16).npu(),
    "cache_mode_opt": "krope_ctkv",
    "quant_mode_opt": "per_tensor_quant_asymm",
}

with torch_npu.profiler.profile(
    activities=[
        torch_npu.profiler.ProfilerActivity.CPU,
        torch_npu.profiler.ProfilerActivity.NPU,
    ],
    schedule=torch_npu.profiler.schedule(
        wait=0,
        warmup=0,
        active=ACTIVE_STEPS,
        repeat=1,
        skip_first=SKIP_FIRST_STEPS,
    ),
    # Results are handed to the tensorboard trace handler for "./result".
    on_trace_ready=torch_npu.profiler.tensorboard_trace_handler("./result"),
    record_shapes=False,
    profile_memory=False,
    with_stack=False,
    with_modules=False,
    with_flops=False,
    experimental_config=experimental_config,
) as prof:
    for _ in range(SKIP_FIRST_STEPS + ACTIVE_STEPS):
        decode_q_nope, k_cache, decode_q_pe, v_cache = torch.ops.mie_ops.npu_mla_process(
            **mla_inputs
        )
        # Advance the schedule: skipped step -> active step -> trace export.
        prof.step()

# Printed after the profiler has closed so that formatting the tensors for
# display is not part of the profiled region.
print(f"decode_q_nope is {decode_q_nope} end, {k_cache} ,end {decode_q_pe} end {v_cache} end")