import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_npu
def npu_mxfp8_quant(x, dst_type=torch.float8_e4m3fn, output_layout="TND", axis=-1):
"""
Perform MXFP8 dynamic quantization for NPU fused attention.
Args:
x: Input tensor with shape (batch, seq_len, num_heads, head_dim) in BSND format.
dst_type: Target data type for quantized output, default is FP8 E4M3.
output_layout: Output layout for FA v2. Supports "TND" now.
Returns:
tuple: (quantized_tensor, scale_tensor)
"""
b, s, n, d = x.shape
if output_layout == "TND":
x_quant, x_scale = torch_npu.npu_dynamic_mx_quant(
x.view(b * s, n, d),
dst_type=dst_type,
axis=axis
)
else:
raise ValueError("Flash Attention MXFP8 quantization currently only supports the TND layout")
return x_quant, x_scale
def _actual_seq_lens(batch_size, seq_len, device):
    """
    Build cumulative sequence end offsets for equal-length sequences.

    Args:
        batch_size: Number of sequences in the batch.
        seq_len: Length shared by every sequence.
        device: Device on which the result tensor is created.

    Returns:
        1-D int64 tensor ``[seq_len, 2 * seq_len, ..., batch_size * seq_len]``.
    """
    first_end = seq_len
    # Exclusive upper bound: one step past the last sequence's end offset.
    stop = seq_len * (batch_size + 1)
    return torch.arange(first_end, stop, seq_len, dtype=torch.int64, device=device)
def npu_mxfp8_attn(q, k, v, dst_type=torch.float8_e4m3fn, softmax_scale=None):
"""
Perform MXFP8 quantized attention computation on NPU.
Q/K/V are dynamically quantized with npu_dynamic_mx_quant, then passed to
npu_fused_infer_attention_score_v2 with MXFP8 quant modes.
"""
b, s, n, d = q.shape
kv_s = k.shape[1]
out_dtype = q.dtype
if softmax_scale is None:
softmax_scale = 1.0 / math.sqrt(d)
q_quant, q_scale = npu_mxfp8_quant(q, dst_type=dst_type, output_layout="TND")
k_quant, k_scale = npu_mxfp8_quant(k, dst_type=dst_type, output_layout="TND")
v_quant, v_scale = npu_mxfp8_quant(v, dst_type=dst_type, output_layout="TND", axis=0)
try:
attn_out = torch_npu.npu_fused_infer_attention_score_v2(
q_quant,
k_quant,
v_quant,
actual_seq_qlen=_actual_seq_lens(b, s, q.device),
actual_seq_kvlen=_actual_seq_lens(b, kv_s, k.device),
dequant_scale_query=q_scale,
dequant_scale_key=k_scale,
dequant_scale_value=v_scale,
num_query_heads=n,
num_key_value_heads=k.shape[2],
softmax_scale=softmax_scale,
input_layout="TND",
sparse_mode=0,
query_quant_mode=6,
key_quant_mode=6,
value_quant_mode=8,
query_dtype=torch.float8_e4m3fn,
key_dtype=torch.float8_e4m3fn,
value_dtype=torch.float8_e4m3fn,
dequant_scale_query_dtype=torch_npu.float8_e8m0fnu,
dequant_scale_key_dtype=torch_npu.float8_e8m0fnu,
dequant_scale_value_dtype=torch_npu.float8_e8m0fnu,
out_dtype=out_dtype
)[0]
except RuntimeError as exc:
raise RuntimeError(
"MXFP8 FA quantization is supported only with CANN 9.1.0 or later."
) from exc
return attn_out.view(b, s, n, d)