import math
import unittest
import copy
import random
import torch
import numpy as np
import torch_npu
import torch_npu.npu.utils as utils
import re
from ml_dtypes import float8_e4m3fn
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices
def s8_saturation(inputdata):
inputdata = torch.where(inputdata > 127, 127, inputdata)
inputdata = torch.where(inputdata < -128, -128, inputdata)
return inputdata.to(torch.int8)
def s9_saturation(inputdata):
inputdata = torch.where(inputdata > 255, 255, inputdata)
inputdata = torch.where(inputdata < -256, -256, inputdata)
return inputdata
def rotate_half(x):
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.concatenate((-x2, x1), dim=-1)
def dynamic_quant(inputdata, smooth_scale=None):
T = inputdata.size(0)
H = inputdata.size(1)
y = torch.zeros(T, H).to(torch.int32)
scale = torch.zeros(T).to(torch.float32)
inputdata = inputdata.reshape(T, H).to(torch.float32)
smooth_scale = smooth_scale.to(torch.float32)
for bs_index in range(T):
abs_bs_tensor = torch.abs(inputdata[bs_index, :] * smooth_scale[0, :])
scale_bs = abs_bs_tensor.max() / 127
scale[bs_index] = scale_bs
y[bs_index:] = torch.round(inputdata[bs_index:] * smooth_scale[0, :] / scale_bs)
return y, scale
def dynamic_quant_without_smooth_scale(inputdata, out_deqq_shape_shape):
T = inputdata.size(0)
N = inputdata.size(1)
H = inputdata.size(2)
quant_loops = inputdata.size(0)
eles_with_one_scale = inputdata.size(1) * inputdata.size(2)
if len(out_deqq_shape_shape) == 3:
quant_loops = inputdata.size(0) * inputdata.size(1)
eles_with_one_scale = inputdata.size(2)
y = torch.zeros(quant_loops, eles_with_one_scale).to(torch.int32)
scale = torch.zeros(quant_loops).to(torch.float32)
inputdata = inputdata.reshape(quant_loops, eles_with_one_scale).to(torch.float32)
max_values, _ = torch.max(torch.abs(inputdata), dim=-1, keepdim=True)
scale = max_values / 127
y = torch.round(inputdata / scale)
y = s8_saturation(y)
if len(out_deqq_shape_shape) == 2:
print(f"[INFO]dynamic_quant_without_smooth_scale in per_token mode")
return y.reshape(T, N, H), scale.reshape(quant_loops, 1).to(torch.float64)
else:
print(f"[INFO]dynamic_quant_without_smooth_scale in per_head mode")
return y.reshape(T, N, H), scale.reshape(T, N, 1).to(torch.float64)
def dequant(inputdata, deq_scale_q_nope, quant_scale_ckv):
org_shape = inputdata.size()
quant_loops = inputdata.size(0)
eles_with_one_scale = inputdata.size(1) * inputdata.size(2)
if len(deq_scale_q_nope.size()) == 3:
quant_loops = inputdata.size(0) * inputdata.size(1)
eles_with_one_scale = inputdata.size(2)
inputdata = inputdata.reshape(quant_loops, eles_with_one_scale)
deq_scale_q_nope = deq_scale_q_nope.reshape(quant_loops, 1)
for quant_idx in range(quant_loops):
inputdata[quant_idx, :] = inputdata[quant_idx, :] * quant_scale_ckv / (deq_scale_q_nope[quant_idx, 0])
return inputdata.reshape(org_shape)
def quant(x, qscale):
scaled_values = (x * qscale).round().to(torch.float32)
s9_res = s9_saturation(scaled_values)
s8_res_cal = s8_saturation(s9_res)
return s8_res_cal
def ieee_754_conversion(sign, exponent_raw, mantissa, exp_len=8, mant_len=7):
"""Convert binary data into the floating point value"""
sign_mult = -1 if sign == 1 else 1
exponent = exponent_raw - (2 ** (exp_len - 1) - 1)
mant_mult = 1
for b in range(mant_len - 1, -1, -1):
if mantissa & (2 ** b):
mant_mult += 1 / (2 ** (mant_len - b))
return sign_mult * (2 ** exponent) * mant_mult
def trans_torch_fp8_e8m0_to_bf16(array):
new_array = torch.zeros_like(array).to(torch.bfloat16)
for i in range(array.size(0)):
for j in range(array.size(1)):
new_value = ieee_754_conversion(0, int(array[i][j]), 0)
new_array[i][j] = new_value
return new_array
def get_dtype_range(dt):
if "bfloat16" in str(dt):
return -float.fromhex("0x1.FEp127"), float.fromhex("0x1.FEp127")
if "uint4" in str(dt):
return 0, 15
if "int4" in str(dt):
return -8, 7
if "bool" in str(dt):
return 0, 1
if "float4_e2m1" in str(dt):
return -float.fromhex("0x1.8p2"), float.fromhex("0x1.8p2")
if "float4_e1m2" in str(dt):
return -float.fromhex("0x1.Cp0"), float.fromhex("0x1.Cp0")
if "float8_e8m0" in str(dt):
return float.fromhex("0x1.p-127"), float.fromhex("0x1.p127")
if "float8_e5m2" in str(dt):
return -float.fromhex("0x1.Cp15"), float.fromhex("0x1.Cp15")
if "float8_e4m3fn" in str(dt):
return -float.fromhex("0x1.Cp8"), float.fromhex("0x1.Cp8")
if "hifloat8" in str(dt):
return -float.fromhex("0x1.p15"), float.fromhex("0x1.p15")
if "complex32" in str(dt):
dt = "float16"
numpy_dtype = np.dtype(dt)
if numpy_dtype.kind in "iu":
numpy_info = np.iinfo(numpy_dtype)
else:
numpy_info = np.finfo(numpy_dtype)
return numpy_info.min, numpy_info.max
def _mx_reshape_to_blocks(fp_array: np.ndarray, axis: int, block_size: int):
fp_array = np.expand_dims(fp_array, axis=axis + 1)
orig_shape = fp_array.shape
pad = [[0, 0] for _ in range(len(orig_shape))]
pad_size = orig_shape[axis] % block_size
pad[axis][1] = block_size - pad_size
if pad_size > 0:
fp_array = np.pad(fp_array, pad, 'constant')
padded_shape = fp_array.shape
reshape = list(padded_shape)
reshape[axis + 1] = block_size
reshape[axis] = reshape[axis] // block_size
fp_array = fp_array.reshape(reshape)
return fp_array, orig_shape, padded_shape
def _mx_calculate_share_exp(fp_array: np.ndarray, scale_axis: int, mx_ele_dtype: str):
FP32_EXPONENT_BIAS = 127
FP32_MIN_NORMAL = 2 ** (-FP32_EXPONENT_BIAS + 1)
max_norm = get_dtype_range(mx_ele_dtype)[1]
ele_emax = int(np.log2(max_norm))
fp_abs_max = np.max(np.abs(fp_array), axis=scale_axis, keepdim=True)
res = np.floor(
np.log2(fp_abs_max.astype(np.float32) + FP32_MIN_NORMAL * (fp_abs_max == 0))
) - ele_emax
res = res + FP32_EXPONENT_BIAS
res[fp_abs_max == 0] = -float("inf")
return res
def _mx_round_mantissa(fp_array: np.ndarray, round_mode: str):
if round_mode in ("rint", "even"):
fp_array = np.rint(fp_array)
elif round_mode in ("round", "nearest"):
sign = np.signbit(fp_array)
rounded_abs = np.floor(np.abs(fp_array) + np.array([0.5], dtype=fp_array.dtype))
fp_array = np.where(sign, -round_abs, rounded_abs)
elif round_mode == "floor":
fp_array = np.floor(fp_array)
elif round_mode == "ceil":
fp_array = np.ceil(fp_array)
elif round_mode == "trunc":
fp_array = np.trunc(fp_array)
else:
raise Exception(f"Unrecognized round method {round_mode}")
return fp_array
def _mx_quantize_to_element_format(fp_array: np.ndarray, share_exp: np.ndarray,
mx_ele_dtype: str, round_mode: str):
mx_dtype = str(mx_ele_dtype)
match = re.search(r'e(\d+)m(\d+)', mx_dtype)
if match:
exp_bits = int(match.group(1))
mantissa_bits = int(match.group(2))
else:
raise ValueError(f'mx element dtype [{mx_ele_dtype}] is not recognized.')
ret = fp_array / (2 ** (share_exp - 127))
private_exp = np.floor(np.log2(np.abs(ret.astype(np.float32)) + (ret == 0))).astype(fp_array.dtype, copy=False)
min_exp = 0 if "float4_e1m2" in mx_dtype else -(2 ** (exp_bits - 1)) + 2
private_exp = private_exp.clip(min=min_exp)
ret = ret / (2 ** private_exp) * (2 ** mantissa_bits)
ret = _mx_round_mantissa(ret, round_mode)
ret = ret / (2 ** mantissa_bits) * (2 ** private_exp)
max_norm = get_dtype_range(mx_dtype)[1]
np.clip(ret, a_min=-max_norm, a_max=max_norm, out=ret)
return ret
def pad_to_even(tensor: np.ndarray, axis: int) ->np.ndarray:
if not isinstance(tensor, np.ndarray):
raise ValueError("Input must be a numpy ndarray.")
if axis < 0 or axis >= tensor.ndim:
raise ValueError(f"Axis {axis} is out of bounds for tensor with {tensor.ndim} dimensions.")
shape = tensor.shape
length = shape[axis]
if length % 2 == 0:
return tensor
pad_width = [(0, 0)] * tensor.ndim
pad_width[axis] = (0, 1)
padded_tensor = np.pad(tensor, pad_width, mode='constant', constant_values=2 ** -127)
return padded_tensor
def interleave(tensor:np.ndarray, axis: int, n_group: int = 2)-> np.ndarray:
if not isinstance(tensor, np.ndarray):
raise ValueError("Input must be a numpy ndarray")
if axis < 0 or axis >= tensor.ndim:
raise ValueError(f"Axis {axis} is out of bounds for tensor with {tensor.ndim} dimensions.")
length = tensor.shape[axis]
if length % n_group != 0:
raise ValueError(f"Axis length ({length}) must be divisible by n_group ({n_group})")
group_length = length // n_group
shape = list(tensor.shape)
new_shape = (
shape[:axis] +
[group_length, 2] +
shape[axis + 1:])
reshaped = tensor.reshape(new_shape)
transpose_order = (
list(range(0, axis + 1)) +
list(range(axis + 2, len(new_shape))) +
[axis + 1,])
transposed = reshaped.transpose(transpose_order)
return transposed
def _mx_undo_reshape_to_blocks(fp_array: np.ndarray, axis: int, orig_shape: tuple, padded_shape: tuple):
fp_array = fp_array.reshape(padded_shape)
if tuple(padded_shape) != tuple(orig_shape):
slices = [slice(0, x) for x in orig_shape]
fp_array = fp_array[tuple(slices)]
fp_array = np.squeeze(fp_array, axis=axis + 1)
return fp_array
def numpy_float8_e4m3fn():
try:
return float8_e4m3fn
except ModuleNotFoundError:
raise RuntimeError("ml_dtypes is needed to support float8_e4m3fn dtype!!! "
"Please install with pip3 install ml-dtypes")
def dynamic_mx_quant_cq(fp_array: np.ndarray, mx_ele_dtype: str = "float4_e2m1",
axis: int = -1, block_size: int = 32, round_mode: str = "rint") -> tuple:
if not isinstance(fp_array, np.ndarray):
raise RuntimeError(f"Input tensor to be quantized should be numpy array. But got {type(fp_array)}")
if fp_array.dtype.name not in ("bfloat16", "float16", "float32"):
raise RuntimeError(f"Dtype of input tensor to be quantized is not supported: {fp_array.dtype.name}")
if mx_ele_dtype not in ("float4_e2m1", "float4_e1m2", "float8_e4m3fn", "float8_e5m2"):
raise NotImplementedError(f"Not support {mx_ele_dtype} yet!")
axis = len(fp_array.shape) + axis if axis < 0 else axis
fp_array, orig_shape, padded_shape = _mx_reshape_to_blocks(fp_array, axis, block_size)
share_exp = _mx_calculate_share_exp(fp_array, scale_axis=axis + 1, mx_ele_dtype=mx_ele_dtype)
scale_emax = 2 ** 7 - 1
share_exp[(share_exp - 127) > scale_emax] = float("NaN")
share_exp[(share_exp - 127) < -scale_emax] = -scale_emax
ele_array = _mx_quantize_to_element_format(fp_array, share_exp, mx_ele_dtype, round_mode)
ele_array =_mx_undo_reshape_to_blocks(ele_array, axis, orig_shape, padded_shape)
share_exp = np.squeeze(share_exp, axis=axis + 1)
scale_array = share_exp
if ele_array.dtype.name == "bfloat16":
ele_array = ele_array.astype("float32", copy=False)
ele_array = np.nan_to_num(ele_array, nan=0.0, copy=False)
ele_array = ele_array.astype(numpy_float8_e4m3fn(), copy=False)
scale_array_pad = pad_to_even(scale_array, axis=axis)
result_shape = copy.deepcopy(list(scale_array_pad.shape))
result_shape.append(2)
result_shape[axis] = scale_array_pad.shape[axis] // 2
if axis != (len(fp_array.shape) - 1):
scale_array_pad = interleave(scale_array_pad, axis=axis)
scale_array_pad = scale_array_pad.reshape(result_shape)
scale_array = scale_array_pad.astype("uint8", copy=False)
return scale_array, ele_array
def quant_ckv_per_tensor(input, quant_scale_ckv):
scaled_value = input * quant_scale_ckv
scaled_value = np.round(scaled_value, 8)
scaled_value = scaled_value.astype(numpy_float8_e4m3fn(), copy=False)
return scaled_value
def dynamic_mx_quant_qn(x):
scale_max = np.float32(448.0)
x = x.astype("float32")
input_mul = x
input_abs = np.abs(input_mul)
input_max = np.max(input_abs, axis=-1, keepdims=True)
scale = input_max * (np.float32(1.0) / scale_max)
input_scaled = input_mul / scale
output_data = input_scaled.astype("float8_e4m3fn", copy=False)
return scale, output_data
class TestPromptFlashAttetion(TestCase):
def baseline(self, token_x, weight_dq, weight_uq_qr, weight_uk, weight_dkv_kr, rmsnorm_gamma_cq,
rmsnorm_gamma_ckv, rope_sin, rope_cos, kv_cache, kr_cache, cache_index, rmsnorm_epsilon_cq, rmsnorm_epsilon_ckv,
cache_mode, dequant_scale_x, dequant_scale_w_dq, dequant_scale_w_uqqr, dequant_scale_w_dkvkr, quant_scale_ckv,
smooth_scale_cq, query_norm_flag, weight_quant_mode, kv_cache_quant_mode,
query_quant_mode, ckvkr_repo_mode, quant_scale_repo_mode, tile_size, k_nope_clip_alpha, mla_param, qc_qr_scale=1.0, kc_scale=1.0):
B = mla_param['B']
S1 = mla_param['S1']
S2 = mla_param['S2']
D = mla_param['D']
Dr = mla_param['Dr']
N1 = mla_param['N1']
N2 = mla_param['N2']
He = mla_param['He']
Hckv = mla_param['Hckv']
Hcq = mla_param['Hcq']
BlockNum = mla_param['BlockNum']
BlockSize = mla_param['BlockSize']
T = mla_param['T']
out_deqq_shape_shape = mla_param["out_deqq_shape_shape"]
index_table = cache_index
cos = rope_cos
sin = rope_sin
dequant_scale_qcqr = None
dequant_scale_q_nope = None
if not mla_param["t_flag"]:
T = B * S1
token_x = token_x.reshape(T, He)
cos = cos.reshape(T, Dr)
sin = sin.reshape(T, Dr)
index_table = index_table.reshape(T)
token_x = token_x.to(torch.int32)
w_dq = weight_dq.to(torch.int32)
matmul1_res = torch.matmul(token_x, w_dq).to(torch.int32)
matmul1_res = matmul1_res.to(torch.float32)
for t_index in range(T):
matmul1_res[t_index, :] = matmul1_res[t_index, :] * dequant_scale_x[t_index, 0]
for h_index in range(Hcq):
matmul1_res[:, h_index] = matmul1_res[:, h_index] * dequant_scale_w_dq[0, h_index]
ep1 = float(rmsnorm_epsilon_cq)
gamma1 = rmsnorm_gamma_cq
norm1_res = matmul1_res / torch.sqrt(torch.mean(matmul1_res ** 2, dim=-1, keepdim=True) + ep1)
norm1_res *= gamma1
norm1_res *= qc_qr_scale
weight_uq_qr = weight_uq_qr.to(torch.int32)
norm1_res, dequant_scale_qcqr = dynamic_quant(norm1_res, smooth_scale_cq)
w_uq_qr = weight_uq_qr
matmul2_res = torch.matmul(norm1_res, w_uq_qr).to(torch.int32)
matmul2_res = matmul2_res.to(torch.float32)
for t_index in range(T):
matmul2_res[t_index, :] = matmul2_res[t_index, :] * dequant_scale_qcqr[t_index]
for nddr_index in range(matmul2_res.shape[1]):
matmul2_res[:, nddr_index] = matmul2_res[:, nddr_index] * dequant_scale_w_uqqr[0, nddr_index]
matmul2_res = matmul2_res.reshape(T, N1, D + Dr)
splitd1_res1 = matmul2_res[:, :, :D]
splitd1_res2 = matmul2_res[:, :, D:]
w_uk = weight_uk.to(torch.float32)
splitd1_res1 = splitd1_res1.transpose(0, 1)
splitd1_res1 = splitd1_res1.to(torch.bfloat16).to(torch.float32)
query_mla = torch.zeros((N1, T, Hckv))
for n1_index in range(N1):
query_mla[n1_index, :, :] = torch.matmul(splitd1_res1[n1_index, :, :], w_uk[n1_index, :, :]).to(torch.float32)
query_mla = query_mla.transpose(0, 1)
query_mla = query_mla.to(torch.bfloat16).to(torch.float32)
query_mla, dequant_scale_q_nope = dynamic_quant_without_smooth_scale(query_mla, out_deqq_shape_shape)
query_mla = query_mla if mla_param["t_flag"] else query_mla.reshape(B, S1, N1, Hckv)
expanded_cos = cos.unsqueeze(1).repeat(1, N1, 1)
expanded_sin = sin.unsqueeze(1).repeat(1, N1, 1)
q = splitd1_res2.reshape(T, N1, int(Dr / 2), 2).transpose(3, 2).reshape(T, N1, Dr)
query_rope_mla = (q * expanded_cos) + (rotate_half(q) * expanded_sin)
query_rope_mla = query_rope_mla.to(torch.bfloat16).to(torch.float32)
query_rope_mla = dequant(query_rope_mla, dequant_scale_q_nope, quant_scale_ckv)
query_rope_mla = query_rope_mla if mla_param["t_flag"] else query_rope_mla.reshape(B, S1, N1, Dr)
w_kv_kr = weight_dkv_kr.to(torch.int32)
matmul4_res = torch.matmul(token_x, w_kv_kr).to(torch.int32).to(torch.float32)
for t_index in range(T):
matmul4_res[t_index, :] = matmul4_res[t_index, :] * dequant_scale_x[t_index, 0]
for h_index in range(Hckv + Dr):
matmul4_res[:, h_index] = matmul4_res[:, h_index] * dequant_scale_w_dkvkr[0, h_index]
splitd2_res1 = matmul4_res[:, :Hckv]
splitd2_res2 = matmul4_res[:, Hckv:]
ep2 = float(rmsnorm_epsilon_ckv)
gamma2 = rmsnorm_gamma_ckv
norm2_res = splitd2_res1 / torch.sqrt(torch.mean(splitd2_res1 ** 2, dim=-1, keepdim=True) + ep2)
norm2_res *= gamma2
norm2_res *= kc_scale
norm2_res = quant(norm2_res, quant_scale_ckv)
kv_cache = copy.deepcopy(kv_cache)
kv_cache_out_mla_shape = kv_cache.shape
kv_cache = kv_cache.reshape(BlockNum * BlockSize, N2, Hckv)
for i in range(T):
for j in range(N2):
kv_cache[index_table[i], j, :] = norm2_res[i, :]
kv_cache_out_mla = kv_cache.reshape(kv_cache_out_mla_shape)
k = splitd2_res2.reshape(T, 1, int(Dr / 2), 2).transpose(3, 2).reshape(T, Dr)
rotary2_res = (k * cos) + (rotate_half(k) * sin)
kr_cache = copy.deepcopy(kr_cache)
kr_cache_out_mla_shape = kr_cache.shape
kr_cache = kr_cache.reshape(BlockNum * BlockSize, N2, Dr)
for i in range(T):
for j in range(N2):
kr_cache[index_table[i], j, :] = rotary2_res[i, :]
kr_cache_out_mla = kr_cache.reshape(kr_cache_out_mla_shape)
return query_mla, query_rope_mla, kv_cache_out_mla, kr_cache_out_mla, dequant_scale_q_nope
def baseline_mxfp8(self, token_x, weight_dq, weight_uq_qr, weight_uk, weight_dkv_kr, rmsnorm_gamma_cq,
rmsnorm_gamma_ckv, rope_sin, rope_cos, kv_cache, kr_cache, cache_index, rmsnorm_epsilon_cq, rmsnorm_epsilon_ckv,
cache_mode, dequant_scale_x, dequant_scale_w_dq, dequant_scale_w_uqqr, dequant_scale_w_dkvkr, quant_scale_ckv,
smooth_scale_cq, query_norm_flag, weight_quant_mode, kv_cache_quant_mode,
query_quant_mode, ckvkr_repo_mode, quant_scale_repo_mode, tile_size, k_nope_clip_alpha, mla_param, qc_qr_scale=1.0, kc_scale=1.0):
B = mla_param['B']
S1 = mla_param['S1']
S2 = mla_param['S2']
D = mla_param['D']
Dr = mla_param['Dr']
N1 = mla_param['N1']
N2 = mla_param['N2']
He = mla_param['He']
Hckv = mla_param['Hckv']
Hcq = mla_param['Hcq']
BlockNum = mla_param['BlockNum']
BlockSize = mla_param['BlockSize']
T = mla_param['T']
out_deqq_shape_shape = mla_param["out_deqq_shape_shape"]
index_table = cache_index
cos = rope_cos
sin = rope_sin
dequant_scale_qcqr = None
dequant_scale_q_nope = None
if not mla_param["t_flag"]:
T = B * S1
token_x = token_x.reshape(T, He)
cos = cos.reshape(T, Dr)
sin = sin.reshape(T, Dr)
index_table = index_table.reshape(T)
token_x_new = token_x
token_x_new = token_x_new.to(torch.bfloat16)
token_x = token_x.to(torch.bfloat16)
dequant_scale_x = trans_torch_fp8_e8m0_to_bf16(dequant_scale_x)
xs0 = dequant_scale_x.shape[0]
xs1 = dequant_scale_x.shape[1]
grp_size = 32
for xs0_idx in range(xs0):
for xs1_idx in range(xs1):
token_cur = token_x_new[xs0_idx : xs0_idx + 1, xs1_idx * grp_size : (xs1_idx + 1) * grp_size]
scale_x = dequant_scale_x[xs0_idx : xs0_idx + 1, xs1_idx : xs1_idx + 1]
scale_x = torch.full((1, grp_size), scale_x.item())
token_cur = token_cur * scale_x
token_x_new[xs0_idx : xs0_idx + 1, xs1_idx * grp_size : (xs1_idx + 1) * grp_size] = token_cur
w_dq = weight_dq.to(torch.bfloat16)
deq_scale_w_dq = trans_torch_fp8_e8m0_to_bf16(dequant_scale_w_dq)
dqs0 = deq_scale_w_dq.shape[0]
dqs1 = deq_scale_w_dq.shape[1]
for dqs0_idx in range(dqs0):
for dqs1_idx in range(dqs1):
scale_w_dq = deq_scale_w_dq[dqs0_idx : dqs0_idx + 1, dqs1_idx : dqs1_idx + 1]
w_dq[dqs1_idx * grp_size: (dqs1_idx + 1) * grp_size, dqs0_idx : dqs0_idx + 1] *= scale_w_dq
token_x_new = token_x_new.to(torch.float32)
w_dq = w_dq.to(torch.float32)
matmul1_res = torch.matmul(token_x_new, w_dq).to(torch.float32)
matmul1_res = matmul1_res.to(torch.float32)
ep1 = float(rmsnorm_epsilon_cq)
gamma1 = rmsnorm_gamma_cq
norm1_res = matmul1_res / torch.sqrt(torch.mean(matmul1_res ** 2, dim=-1, keepdim=True) + ep1)
norm1_res *= gamma1
w_uq_qr = weight_uq_qr.to(torch.bfloat16)
deq_scale_uqqr = trans_torch_fp8_e8m0_to_bf16(dequant_scale_w_uqqr)
uqqrs0 = deq_scale_uqqr.shape[0]
uqqrs1 = deq_scale_uqqr.shape[1]
for uqqrs0_idx in range(uqqrs0):
for uqqrs1_idx in range(uqqrs1):
scale_uqqr = deq_scale_uqqr[uqqrs0_idx : uqqrs0_idx + 1, uqqrs1_idx : uqqrs1_idx + 1]
w_uq_qr[uqqrs1_idx * grp_size: (uqqrs1_idx + 1) * grp_size, uqqrs0_idx : uqqrs0_idx + 1] *= scale_uqqr
norm1_res = norm1_res.to(torch.bfloat16)
deq_scale_qcqr_np, norm1_res_np = dynamic_mx_quant_cq(norm1_res.float().numpy(), "float8_e4m3fn")
norm1_res = torch.tensor(norm1_res_np.astype(np.float32)).to(torch.float8_e4m3fn)
deq_scale_qcqr = torch.from_numpy(deq_scale_qcqr_np)
deq_scale_qcqr = deq_scale_qcqr.reshape(deq_scale_qcqr.shape[0], deq_scale_qcqr.shape[1] * deq_scale_qcqr.shape[2])
norm1_res = norm1_res.to(torch.bfloat16)
deq_scale_qcqr = trans_torch_fp8_e8m0_to_bf16(deq_scale_qcqr)
qcqrs0 = deq_scale_qcqr.shape[0]
qcqrs1 = deq_scale_qcqr.shape[1]
for qcqrs0_idx in range(qcqrs0):
for qcqrs1_idx in range(qcqrs1):
normal_cur = norm1_res[qcqrs0_idx : qcqrs0_idx + 1, qcqrs1_idx * grp_size : (qcqrs1_idx + 1) * grp_size]
scale_qcqr = deq_scale_qcqr[qcqrs0_idx : qcqrs0_idx + 1, qcqrs1_idx : qcqrs1_idx + 1]
scale_qcqr = torch.full((1, grp_size), scale_qcqr.item())
normal_cur *= scale_qcqr
norm1_res[qcqrs0_idx : qcqrs0_idx + 1, qcqrs1_idx * grp_size : (qcqrs1_idx + 1) * grp_size] = normal_cur
norm1_res = norm1_res.to(torch.float32)
w_uq_qr = w_uq_qr.to(torch.float32)
matmul2_res = torch.matmul(norm1_res, w_uq_qr).to(torch.float32)
matmul2_res = matmul2_res.reshape(T, N1, D + Dr)
splitd1_res1 = matmul2_res[:, :, :D]
splitd1_res2 = matmul2_res[:, :, D:]
w_uk = weight_uk.to(torch.bfloat16)
splitd1_res1 = splitd1_res1.transpose(0, 1)
splitd1_res1 = splitd1_res1.to(torch.bfloat16)
query_mla = torch.zeros((N1, T, Hckv))
for n1_index in range(N1):
query_mla[n1_index, :, :] = torch.matmul(splitd1_res1[n1_index, :, :].to(torch.float32), w_uk[n1_index, :, :].to(torch.float32)).to(torch.float32)
query_mla = query_mla.transpose(0, 1)
query_mla = query_mla.to(torch.bfloat16).to(torch.float32)
deq_scale_q_nope_np, out_np = dynamic_mx_quant_qn(query_mla.numpy())
query_mla = torch.tensor(out_np.astype(np.float32)).to(torch.float8_e4m3fn)
dequant_scale_q_nope = torch.from_numpy(deq_scale_q_nope_np)
query_mla = query_mla if mla_param["t_flag"] else query_mla.reshape(B, S1, N1, Hckv)
expanded_cos = cos.unsqueeze(1).repeat(1, N1, 1)
expanded_sin = sin.unsqueeze(1).repeat(1, N1, 1)
q = splitd1_res2.reshape(T, N1, int(Dr / 2), 2).transpose(3, 2).reshape(T, N1, Dr)
query_rope_mla = (q * expanded_cos) + (rotate_half(q) * expanded_sin)
query_rope_mla = query_rope_mla.to(torch.bfloat16).to(torch.float32)
query_rope_mla = dequant(query_rope_mla, dequant_scale_q_nope, quant_scale_ckv)
query_rope_mla = query_rope_mla if mla_param["t_flag"] else query_rope_mla.reshape(B, S1, N1, Dr)
w_kv_kr = weight_dkv_kr.to(torch.bfloat16)
deq_scale_dkvkr = trans_torch_fp8_e8m0_to_bf16(dequant_scale_w_dkvkr)
dkvkrs0 = deq_scale_dkvkr.shape[0]
dkvkrs1 = deq_scale_dkvkr.shape[1]
for dkvkrs0_idx in range(dkvkrs0):
for dkvkrs1_idx in range(dkvkrs1):
scale_dkvkr = deq_scale_dkvkr[dkvkrs0_idx : dkvkrs0_idx + 1, dkvkrs1_idx : dkvkrs1_idx + 1]
w_kv_kr[dkvkrs1_idx * grp_size: (dkvkrs1_idx + 1) * grp_size, dkvkrs0_idx : dkvkrs0_idx + 1] *= scale_dkvkr
matmul4_res = torch.matmul(token_x_new.to(torch.float32), w_kv_kr.to(torch.float32)).to(torch.float32)
splitd2_res1 = matmul4_res[:, :Hckv]
splitd2_res2 = matmul4_res[:, Hckv:]
ep2 = float(rmsnorm_epsilon_ckv)
gamma2 = rmsnorm_gamma_ckv
norm2_res = splitd2_res1 / torch.sqrt(torch.mean(splitd2_res1 ** 2, dim=-1, keepdim=True) + ep2)
norm2_res *= gamma2
norm2_res_np = quant_ckv_per_tensor(norm2_res.numpy(), quant_scale_ckv.numpy())
norm2_res = torch.tensor(norm2_res_np.astype(np.float32)).to(torch.float8_e4m3fn)
kv_cache = copy.deepcopy(kv_cache)
kv_cache_out_mla_shape = kv_cache.shape
kv_cache = kv_cache.reshape(BlockNum * BlockSize, N2, Hckv)
for i in range(T):
for j in range(N2):
kv_cache[index_table.reshape(T)[i], j, :] = norm2_res[i, :]
kv_cache_out_mla = kv_cache.reshape(kv_cache_out_mla_shape)
k = splitd2_res2.reshape(T, 1, int(Dr / 2), 2).transpose(3, 2).reshape(T, Dr)
rotary2_res = (k * cos) + (rotate_half(k) * sin)
kr_cache = copy.deepcopy(kr_cache)
kr_cache_out_mla_shape = kr_cache.shape
kr_cache = kr_cache.reshape(BlockNum * BlockSize, N2, Dr)
for i in range(T):
for j in range(N2):
kr_cache[index_table.reshape(T)[i], j, :] = rotary2_res[i, :]
kr_cache_out_mla = kr_cache.reshape(kr_cache_out_mla_shape)
return query_mla, query_rope_mla, kv_cache_out_mla, kr_cache_out_mla, dequant_scale_q_nope
def mla_prolog_npu_v3(self, token_x, weight_dq, weight_uq_qr, weight_uk, weight_dkv_kr, rmsnorm_gamma_cq,
rmsnorm_gamma_ckv, rope_sin, rope_cos, kv_cache, kr_cache, cache_index, rmsnorm_epsilon_cq, rmsnorm_epsilon_ckv,
cache_mode, dequant_scale_x, dequant_scale_w_dq, dequant_scale_w_uqqr, dequant_scale_w_dkvkr,
quant_scale_ckv, smooth_scale_cq, query_norm_flag, weight_quant_mode, kv_cache_quant_mode,
query_quant_mode, ckvkr_repo_mode, quant_scale_repo_mode, tile_size, k_nope_clip_alpha, qc_qr_scale=1.0, kc_scale=1.0):
return torch_npu.npu_mla_prolog_v3(
token_x, weight_dq, weight_uq_qr, weight_uk, weight_dkv_kr, rmsnorm_gamma_cq, rmsnorm_gamma_ckv,
rope_sin, rope_cos, kv_cache, kr_cache, cache_index=cache_index, rmsnorm_epsilon_cq=rmsnorm_epsilon_cq, rmsnorm_epsilon_ckv=rmsnorm_epsilon_ckv,
cache_mode=cache_mode, dequant_scale_x=dequant_scale_x, dequant_scale_w_dq=dequant_scale_w_dq,
dequant_scale_w_uq_qr=dequant_scale_w_uqqr, dequant_scale_w_dkv_kr=dequant_scale_w_dkvkr,
quant_scale_ckv=quant_scale_ckv, smooth_scales_cq=smooth_scale_cq, k_nope_clip_alpha=k_nope_clip_alpha, query_norm_flag=query_norm_flag,
weight_quant_mode=weight_quant_mode, kv_cache_quant_mode=kv_cache_quant_mode, query_quant_mode=query_quant_mode, ckvkr_repo_mode=ckvkr_repo_mode,
quant_scale_repo_mode=quant_scale_repo_mode, tile_size=tile_size, qc_qr_scale=qc_qr_scale, kc_scale=kc_scale)
def npu_mla_prolog_v3_functional(self, token_x, weight_dq, weight_uq_qr, weight_uk, weight_dkv_kr, rmsnorm_gamma_cq,
rmsnorm_gamma_ckv, rope_sin, rope_cos, kv_cache, kr_cache, cache_index, rmsnorm_epsilon_cq, rmsnorm_epsilon_ckv,
cache_mode, dequant_scale_x, dequant_scale_w_dq, dequant_scale_w_uqqr, dequant_scale_w_dkvkr,
quant_scale_ckv, smooth_scale_cq, query_norm_flag, weight_quant_mode, kv_cache_quant_mode,
query_quant_mode, ckvkr_repo_mode, quant_scale_repo_mode, tile_size, k_nope_clip_alpha, qc_qr_scale=1.0, kc_scale=1.0):
return torch_npu.npu_mla_prolog_v3_functional(
token_x, weight_dq, weight_uq_qr, weight_uk, weight_dkv_kr, rmsnorm_gamma_cq, rmsnorm_gamma_ckv,
rope_sin, rope_cos, kv_cache, kr_cache, cache_index=cache_index, rmsnorm_epsilon_cq=rmsnorm_epsilon_cq, rmsnorm_epsilon_ckv=rmsnorm_epsilon_ckv,
cache_mode=cache_mode, dequant_scale_x=dequant_scale_x, dequant_scale_w_dq=dequant_scale_w_dq,
dequant_scale_w_uq_qr=dequant_scale_w_uqqr, dequant_scale_w_dkv_kr=dequant_scale_w_dkvkr,
quant_scale_ckv=quant_scale_ckv, smooth_scales_cq=smooth_scale_cq, k_nope_clip_alpha=k_nope_clip_alpha, query_norm_flag=query_norm_flag,
weight_quant_mode=weight_quant_mode, kv_cache_quant_mode=kv_cache_quant_mode, query_quant_mode=query_quant_mode, ckvkr_repo_mode=ckvkr_repo_mode,
quant_scale_repo_mode=quant_scale_repo_mode, tile_size=tile_size, qc_qr_scale=qc_qr_scale, kc_scale=kc_scale)
@unittest.skip("Skipping due to outdated CANN version; please update CANN to the latest version and remove this skip")
@SupportedDevices(['Ascend910B'])
def test_op_exec_mla_prolog_npu_v3(self):
B = 2
He = 7168
Hcq = 1536
Hckv = 512
N = 32
D = 128
Dr = 64
Skv = 6144
S = 1
Nkv = 1
BlockSize = 128
BlockNum = math.ceil(B * Skv / BlockSize)
T = 8
token_x = torch.randint(-100, 100, (B, S, He), dtype=torch.int8).npu()
w_dq = torch.randint(-100, 100, (He, Hcq), dtype=torch.int8).npu()
w_dq_cast = torch_npu.npu_format_cast(w_dq.contiguous(), 29)
w_uq_qr = torch.randint(-100, 100, (Hcq, N * (D + Dr)), dtype=torch.int8).npu()
w_uq_qr_cast = torch_npu.npu_format_cast(w_uq_qr.contiguous(), 29)
w_uk = torch.rand(N, D, Hckv, dtype=torch.bfloat16).npu()
w_dkv_kr = torch.randint(-100, 100, (He, Hckv + Dr), dtype=torch.int8).npu()
w_dkv_kr_cast = torch_npu.npu_format_cast(w_dkv_kr.contiguous(), 29)
rmsnorm_gamma_cq = torch.rand(Hcq, dtype=torch.bfloat16).npu()
rmsnorm_gamma_ckv = torch.rand(Hckv, dtype=torch.bfloat16).npu()
rope_sin = torch.rand(B, S, Dr, dtype=torch.bfloat16).npu()
rope_cos = torch.rand(B, S, Dr, dtype=torch.bfloat16).npu()
cache_index = torch.randint(0, B * S, (B, S), dtype=torch.int64).npu()
kv_cache = torch.randint(-100, 100, (1, BlockNum * BlockSize * Nkv * Hckv), dtype=torch.int8).npu()
kv_cache = kv_cache.view(BlockNum, BlockSize, Nkv, Hckv)
kr_cache = torch.rand(1, BlockNum * BlockSize * Nkv * Dr, dtype=torch.bfloat16).npu()
kr_cache = kr_cache.view(BlockNum, BlockSize, Nkv, Dr)
rmsnorm_epsilon_cq = 1.0e-5
rmsnorm_epsilon_ckv = 1.0e-5
cache_mode = "PA_BSND"
qc_qr_scale = 10.0
kc_scale = 10.0
dequant_scale_x = torch.rand(B * S, 1, dtype=torch.float32).npu()
dequant_scale_w_dq = torch.rand(1, Hcq, dtype=torch.float32).npu()
dequant_scale_w_uqqr = torch.rand(1, N * (D + Dr), dtype=torch.float32).npu()
dequant_scale_w_dkvkr = torch.rand(1, Hckv + Dr, dtype=torch.float32).npu()
quant_scale_ckv = torch.rand(1, dtype=torch.float32).npu()
smooth_scale_cq = torch.ones(1, Hcq, dtype=torch.float32).npu()
mla_param = {
'B': B,
'He': He,
'Hcq': Hcq,
'Hckv': Hckv,
'N1': N,
'D': D,
'Dr': Dr,
'S2': Skv,
'S1': S,
'N2': Nkv,
'BlockNum': BlockNum,
'BlockSize': BlockSize,
't_flag': False,
'T': T,
"out_deqq_shape_shape": [B * S, N, 1]
}
kv_cache_copy = kv_cache.clone()
kr_cache_copy = kr_cache.clone()
query_mla, query_rope_mla, dequant_scale_q_nope_mla, query_norm_mla, dequant_scale_q_norm_mla = self.mla_prolog_npu_v3(token_x, w_dq_cast, w_uq_qr_cast, w_uk, w_dkv_kr_cast, rmsnorm_gamma_cq,
rmsnorm_gamma_ckv, rope_sin, rope_cos, kv_cache_copy, kr_cache_copy, cache_index, rmsnorm_epsilon_cq, rmsnorm_epsilon_ckv,
cache_mode, dequant_scale_x, dequant_scale_w_dq, dequant_scale_w_uqqr, dequant_scale_w_dkvkr, quant_scale_ckv, smooth_scale_cq, query_norm_flag=True, weight_quant_mode=2, kv_cache_quant_mode=1,
query_quant_mode=1, ckvkr_repo_mode=0, quant_scale_repo_mode=0, tile_size=128, k_nope_clip_alpha=None, qc_qr_scale=qc_qr_scale, kc_scale=kc_scale)
query_mla = query_mla.cpu()
query_rope_mla = query_rope_mla.cpu()
kv_cache_copy = kv_cache_copy.cpu()
kr_cache_copy = kr_cache_copy.cpu()
dequant_scale_q_nope_mla = dequant_scale_q_nope_mla.cpu()
token_x = token_x.cpu()
w_dq = w_dq.cpu()
w_dq_cast = w_dq_cast.cpu()
w_uq_qr_cast = w_uq_qr_cast.cpu()
w_uk = w_uk.cpu()
w_dkv_kr_cast = w_dkv_kr_cast.cpu()
rmsnorm_gamma_cq = rmsnorm_gamma_cq.cpu()
rmsnorm_gamma_ckv = rmsnorm_gamma_ckv.cpu()
rope_sin = rope_sin.cpu()
rope_cos = rope_cos.cpu()
cache_index = cache_index.cpu()
kv_cache = kv_cache.cpu()
kr_cache = kr_cache.cpu()
dequant_scale_x = dequant_scale_x.cpu()
dequant_scale_w_dq = dequant_scale_w_dq.cpu()
dequant_scale_w_uqqr = dequant_scale_w_uqqr.cpu()
dequant_scale_w_dkvkr = dequant_scale_w_dkvkr.cpu()
quant_scale_ckv = quant_scale_ckv.cpu()
smooth_scale_cq = smooth_scale_cq.cpu()
query, query_rope, kv_cache_out, kr_cache_out, dequant_scale_q_nope = self.baseline(token_x, w_dq_cast, w_uq_qr_cast, w_uk, w_dkv_kr_cast, rmsnorm_gamma_cq,
rmsnorm_gamma_ckv, rope_sin, rope_cos, kv_cache, kr_cache, cache_index, rmsnorm_epsilon_cq, rmsnorm_epsilon_ckv,
cache_mode, dequant_scale_x, dequant_scale_w_dq, dequant_scale_w_uqqr, dequant_scale_w_dkvkr,
quant_scale_ckv, smooth_scale_cq, query_norm_flag=True, weight_quant_mode=2, kv_cache_quant_mode=1,
query_quant_mode=1, ckvkr_repo_mode=0, quant_scale_repo_mode=0, tile_size=128, k_nope_clip_alpha=None, mla_param=mla_param, qc_qr_scale=qc_qr_scale, kc_scale=kc_scale)
self.assertRtolEqual(query_mla, query, prec=1, prec16=1)
self.assertRtolEqual(query_rope_mla.to(torch.float32), query_rope.to(torch.float32), prec=0.005, prec16=0.005)
self.assertRtolEqual(kv_cache_copy.to(torch.float32), kv_cache_out.to(torch.float32), prec=0.005, prec16=0.005)
self.assertRtolEqual(kr_cache_copy.to(torch.float32), kr_cache_out.to(torch.float32), prec=0.005, prec16=0.005)
self.assertRtolEqual(dequant_scale_q_nope_mla.to(torch.float32), dequant_scale_q_nope.to(torch.float32), prec=0.005, prec16=0.005)
@unittest.skip("Skipping due to outdated CANN version; please update CANN to the latest version and remove this skip")
@SupportedDevices(['Ascend910B'])
def test_op_exec_mla_prolog_npu_v3_functional(self):
B = 2
He = 7168
Hcq = 1536
Hckv = 512
N = 32
D = 128
Dr = 64
Skv = 6144
S = 1
Nkv = 1
BlockSize = 128
BlockNum = math.ceil(B * Skv / BlockSize)
T = 8
token_x = torch.randint(-100, 100, (B, S, He), dtype=torch.int8).npu()
w_dq = torch.randint(-100, 100, (He, Hcq), dtype=torch.int8).npu()
w_dq_cast = torch_npu.npu_format_cast(w_dq.contiguous(), 29)
w_uq_qr = torch.randint(-100, 100, (Hcq, N * (D + Dr)), dtype=torch.int8).npu()
w_uq_qr_cast = torch_npu.npu_format_cast(w_uq_qr.contiguous(), 29)
w_uk = torch.rand(N, D, Hckv, dtype=torch.bfloat16).npu()
w_dkv_kr = torch.randint(-100, 100, (He, Hckv + Dr), dtype=torch.int8).npu()
w_dkv_kr_cast = torch_npu.npu_format_cast(w_dkv_kr.contiguous(), 29)
rmsnorm_gamma_cq = torch.rand(Hcq, dtype=torch.bfloat16).npu()
rmsnorm_gamma_ckv = torch.rand(Hckv, dtype=torch.bfloat16).npu()
rope_sin = torch.rand(B, S, Dr, dtype=torch.bfloat16).npu()
rope_cos = torch.rand(B, S, Dr, dtype=torch.bfloat16).npu()
cache_index = torch.randint(0, B * S, (B, S), dtype=torch.int64).npu()
kv_cache = torch.randint(-100, 100, (1, BlockNum * BlockSize * Nkv * Hckv), dtype=torch.int8).npu()
kv_cache = kv_cache.view(BlockNum, BlockSize, Nkv, Hckv)
kr_cache = torch.rand(1, BlockNum * BlockSize * Nkv * Dr, dtype=torch.bfloat16).npu()
kr_cache = kr_cache.view(BlockNum, BlockSize, Nkv, Dr)
rmsnorm_epsilon_cq = 1.0e-5
rmsnorm_epsilon_ckv = 1.0e-5
cache_mode = "PA_BSND"
qc_qr_scale = 10.0
kc_scale = 10.0
dequant_scale_x = torch.rand(B * S, 1, dtype=torch.float32).npu()
dequant_scale_w_dq = torch.rand(1, Hcq, dtype=torch.float32).npu()
dequant_scale_w_uqqr = torch.rand(1, N * (D + Dr), dtype=torch.float32).npu()
dequant_scale_w_dkvkr = torch.rand(1, Hckv + Dr, dtype=torch.float32).npu()
quant_scale_ckv = torch.rand(1, dtype=torch.float32).npu()
smooth_scale_cq = torch.ones(1, Hcq, dtype=torch.float32).npu()
mla_param = {
'B': B,
'He': He,
'Hcq': Hcq,
'Hckv': Hckv,
'N1': N,
'D': D,
'Dr': Dr,
'S2': Skv,
'S1': S,
'N2': Nkv,
'BlockNum': BlockNum,
'BlockSize': BlockSize,
't_flag': False,
'T': T,
"out_deqq_shape_shape": [B * S, N, 1]
}
query_mla, query_rope_mla, dequant_scale_q_nope_mla, _, _, kv_cache_mla, kr_cache_mla = self.npu_mla_prolog_v3_functional(token_x, w_dq_cast, w_uq_qr_cast, w_uk, w_dkv_kr_cast, rmsnorm_gamma_cq,
rmsnorm_gamma_ckv, rope_sin, rope_cos, kv_cache, kr_cache, cache_index, rmsnorm_epsilon_cq, rmsnorm_epsilon_ckv,
cache_mode, dequant_scale_x, dequant_scale_w_dq, dequant_scale_w_uqqr, dequant_scale_w_dkvkr, quant_scale_ckv, smooth_scale_cq, query_norm_flag=True, weight_quant_mode=2, kv_cache_quant_mode=1,
query_quant_mode=1, ckvkr_repo_mode=0, quant_scale_repo_mode=0, tile_size=128, k_nope_clip_alpha=None, qc_qr_scale=qc_qr_scale, kc_scale=kc_scale)
query_mla = query_mla.cpu()
query_rope_mla = query_rope_mla.cpu()
kv_cache_mla = kv_cache_mla.cpu()
kr_cache_mla = kr_cache_mla.cpu()
dequant_scale_q_nope_mla = dequant_scale_q_nope_mla.cpu()
token_x = token_x.cpu()
w_dq = w_dq.cpu()
w_dq_cast = w_dq_cast.cpu()
w_uq_qr_cast = w_uq_qr_cast.cpu()
w_uk = w_uk.cpu()
w_dkv_kr_cast = w_dkv_kr_cast.cpu()
rmsnorm_gamma_cq = rmsnorm_gamma_cq.cpu()
rmsnorm_gamma_ckv = rmsnorm_gamma_ckv.cpu()
rope_sin = rope_sin.cpu()
rope_cos = rope_cos.cpu()
cache_index = cache_index.cpu()
kv_cache = kv_cache.cpu()
kr_cache = kr_cache.cpu()
dequant_scale_x = dequant_scale_x.cpu()
dequant_scale_w_dq = dequant_scale_w_dq.cpu()
dequant_scale_w_uqqr = dequant_scale_w_uqqr.cpu()
dequant_scale_w_dkvkr = dequant_scale_w_dkvkr.cpu()
quant_scale_ckv = quant_scale_ckv.cpu()
smooth_scale_cq = smooth_scale_cq.cpu()
query, query_rope, kv_cache_out, kr_cache_out, dequant_scale_q_nope = self.baseline(token_x, w_dq_cast, w_uq_qr_cast, w_uk, w_dkv_kr_cast, rmsnorm_gamma_cq,
rmsnorm_gamma_ckv, rope_sin, rope_cos, kv_cache, kr_cache, cache_index, rmsnorm_epsilon_cq, rmsnorm_epsilon_ckv,
cache_mode, dequant_scale_x, dequant_scale_w_dq, dequant_scale_w_uqqr, dequant_scale_w_dkvkr,
quant_scale_ckv, smooth_scale_cq, query_norm_flag=True, weight_quant_mode=2, kv_cache_quant_mode=1,
query_quant_mode=1, ckvkr_repo_mode=0, quant_scale_repo_mode=0, tile_size=128, k_nope_clip_alpha=None, mla_param=mla_param, qc_qr_scale=qc_qr_scale, kc_scale=kc_scale)
self.assertRtolEqual(query_mla, query, prec=1, prec16=1)
self.assertRtolEqual(query_rope_mla.to(torch.float32), query_rope.to(torch.float32), prec=0.005, prec16=0.005)
self.assertRtolEqual(kv_cache_mla.to(torch.float32), kv_cache_out.to(torch.float32), prec=0.005, prec16=0.005)
self.assertRtolEqual(kr_cache_mla.to(torch.float32), kr_cache_out.to(torch.float32), prec=0.005, prec16=0.005)
self.assertRtolEqual(dequant_scale_q_nope_mla.to(torch.float32), dequant_scale_q_nope.to(torch.float32), prec=0.005, prec16=0.005)
@unittest.skip("Skipping due to outdated CANN version; please update CANN to the latest version and remove this skip")
@SupportedDevices(['Ascend950'])
def test_op_exec_mla_prolog_npu_v3_mxfp8(self):
B = 2
He = 7168
Hcq = 1536
Hckv = 512
N = 32
D = 128
Dr = 64
Skv = 6144
S = 1
Nkv = 1
BlockSize = 128
BlockNum = math.ceil(B * Skv / BlockSize)
T = 8
token_x = torch.rand(B, S, He).to(torch.float8_e4m3fn).npu()
w_dq = torch.rand(He, Hcq).to(torch.float8_e4m3fn).npu()
w_dq_cast = torch_npu.npu_format_cast(w_dq.contiguous(), 29)
w_uq_qr = torch.rand(Hcq, N * (D + Dr)).to(torch.float8_e4m3fn).npu()
w_uq_qr_cast = torch_npu.npu_format_cast(w_uq_qr.contiguous(), 29)
w_uk = torch.rand(N, D, Hckv, dtype=torch.bfloat16).npu()
w_dkv_kr = torch.rand(He, Hckv + Dr).to(torch.float8_e4m3fn).npu()
w_dkv_kr_cast = torch_npu.npu_format_cast(w_dkv_kr.contiguous(), 29)
rmsnorm_gamma_cq = torch.rand(Hcq, dtype=torch.bfloat16).npu()
rmsnorm_gamma_ckv = torch.rand(Hckv, dtype=torch.bfloat16).npu()
rope_sin = torch.rand(B, S, Dr, dtype=torch.bfloat16).npu()
rope_cos = torch.rand(B, S, Dr, dtype=torch.bfloat16).npu()
cache_index = torch.randint(0, B * S, (B, S), dtype=torch.int64).npu()
kv_cache = torch.rand(1, BlockNum * BlockSize * Nkv * Hckv).to(torch.float8_e4m3fn).npu()
kv_cache = kv_cache.view(BlockNum, BlockSize, Nkv, Hckv)
kr_cache = torch.rand(1, BlockNum * BlockSize * Nkv * Dr, dtype=torch.bfloat16).npu()
kr_cache = kr_cache.view(BlockNum, BlockSize, Nkv, Dr)
rmsnorm_epsilon_cq = 1.0e-5
rmsnorm_epsilon_ckv = 1.0e-5
cache_mode = "PA_BSND"
qc_qr_scale = 10.0
kc_scale = 10.0
dequant_scale_x = torch.randint(0, 100, (B * S, He // 32)).to(torch.int8).npu()
dequant_scale_w_dq = torch.randint(0, 100, (Hcq, He // 32)).to(torch.int8).npu()
dequant_scale_w_uqqr = torch.randint(0, 100, (N * (D + Dr), Hcq // 32)).to(torch.int8).npu()
dequant_scale_w_dkvkr = torch.randint(0, 100, (Hckv + Dr, He // 32)).to(torch.int8).npu()
quant_scale_ckv = torch.rand(1, dtype=torch.float32).npu()
smooth_scale_cq = None
mla_param = {
'B': B,
'He': He,
'Hcq': Hcq,
'Hckv': Hckv,
'N1': N,
'D': D,
'Dr': Dr,
'S2': Skv,
'S1': S,
'N2': Nkv,
'BlockNum': BlockNum,
'BlockSize': BlockSize,
't_flag': False,
'T': T,
"out_deqq_shape_shape": [B * S, N, 1]
}
kv_cache_copy = kv_cache.clone()
kr_cache_copy = kr_cache.clone()
query_mla, query_rope_mla, dequant_scale_q_nope_mla, query_norm_mla, dequant_scale_q_norm_mla = self.mla_prolog_npu_v3(token_x, w_dq_cast, w_uq_qr_cast, w_uk, w_dkv_kr_cast, rmsnorm_gamma_cq,
rmsnorm_gamma_ckv, rope_sin, rope_cos, kv_cache_copy, kr_cache_copy, cache_index, rmsnorm_epsilon_cq, rmsnorm_epsilon_ckv,
cache_mode, dequant_scale_x, dequant_scale_w_dq, dequant_scale_w_uqqr, dequant_scale_w_dkvkr, quant_scale_ckv, smooth_scale_cq, query_norm_flag=False, weight_quant_mode=3, kv_cache_quant_mode=1,
query_quant_mode=1, ckvkr_repo_mode=0, quant_scale_repo_mode=0, tile_size=128, k_nope_clip_alpha=None, qc_qr_scale=qc_qr_scale, kc_scale=kc_scale)
query_mla = query_mla.cpu()
query_rope_mla = query_rope_mla.cpu()
kv_cache_copy = kv_cache_copy.cpu()
kr_cache_copy = kr_cache_copy.cpu()
dequant_scale_q_nope_mla = dequant_scale_q_nope_mla.cpu()
token_x = token_x.cpu()
w_dq = w_dq.cpu()
w_dq_cast = w_dq_cast.cpu()
w_uq_qr_cast = w_uq_qr_cast.cpu()
w_uk = w_uk.cpu()
w_dkv_kr_cast = w_dkv_kr_cast.cpu()
rmsnorm_gamma_cq = rmsnorm_gamma_cq.cpu()
rmsnorm_gamma_ckv = rmsnorm_gamma_ckv.cpu()
rope_sin = rope_sin.cpu()
rope_cos = rope_cos.cpu()
cache_index = cache_index.cpu()
kv_cache = kv_cache.cpu()
kr_cache = kr_cache.cpu()
dequant_scale_x = dequant_scale_x.cpu()
dequant_scale_w_dq = dequant_scale_w_dq.cpu()
dequant_scale_w_uqqr = dequant_scale_w_uqqr.cpu()
dequant_scale_w_dkvkr = dequant_scale_w_dkvkr.cpu()
quant_scale_ckv = quant_scale_ckv.cpu()
smooth_scale_cq = None
query, query_rope, kv_cache_out, kr_cache_out, dequant_scale_q_nope = self.baseline_mxfp8(token_x, w_dq, w_uq_qr, w_uk, w_dkv_kr, rmsnorm_gamma_cq,
rmsnorm_gamma_ckv, rope_sin, rope_cos, kv_cache, kr_cache, cache_index, rmsnorm_epsilon_cq, rmsnorm_epsilon_ckv,
cache_mode, dequant_scale_x, dequant_scale_w_dq, dequant_scale_w_uqqr, dequant_scale_w_dkvkr,
quant_scale_ckv, smooth_scale_cq, query_norm_flag=False, weight_quant_mode=3, kv_cache_quant_mode=1,
query_quant_mode=1, ckvkr_repo_mode=0, quant_scale_repo_mode=0, tile_size=128, k_nope_clip_alpha=None, mla_param=mla_param, qc_qr_scale=qc_qr_scale, kc_scale=kc_scale)
self.assertRtolEqual(query_mla.to(torch.float32), query.to(torch.float32), prec=0.005, prec16=0.005)
self.assertRtolEqual(query_rope_mla.to(torch.float32), query_rope.to(torch.float32), prec=0.005, prec16=0.005)
self.assertRtolEqual(kv_cache_copy.to(torch.float32), kv_cache_out.to(torch.float32), prec=0.005, prec16=0.005)
self.assertRtolEqual(kr_cache_copy.to(torch.float32), kr_cache_out.to(torch.float32), prec=0.005, prec16=0.005)
self.assertRtolEqual(dequant_scale_q_nope_mla.to(torch.float32), dequant_scale_q_nope.to(torch.float32), prec=0.005, prec16=0.005)
@unittest.skip("Skipping due to outdated CANN version; please update CANN to the latest version and remove this skip")
@SupportedDevices(['Ascend950'])
def test_op_exec_mla_prolog_npu_v3_functional_mxfp8(self):
B = 2
He = 7168
Hcq = 1536
Hckv = 512
N = 32
D = 128
Dr = 64
Skv = 6144
S = 1
Nkv = 1
BlockSize = 128
BlockNum = math.ceil(B * Skv / BlockSize)
T = 8
token_x = torch.rand(B, S, He).to(torch.float8_e4m3fn).npu()
w_dq = torch.rand(He, Hcq).to(torch.float8_e4m3fn).npu()
w_dq_cast = torch_npu.npu_format_cast(w_dq.contiguous(), 29)
w_uq_qr = torch.rand(Hcq, N * (D + Dr)).to(torch.float8_e4m3fn).npu()
w_uq_qr_cast = torch_npu.npu_format_cast(w_uq_qr.contiguous(), 29)
w_uk = torch.rand(N, D, Hckv, dtype=torch.bfloat16).npu()
w_dkv_kr = torch.rand(He, Hckv + Dr).to(torch.float8_e4m3fn).npu()
w_dkv_kr_cast = torch_npu.npu_format_cast(w_dkv_kr.contiguous(), 29)
rmsnorm_gamma_cq = torch.rand(Hcq, dtype=torch.bfloat16).npu()
rmsnorm_gamma_ckv = torch.rand(Hckv, dtype=torch.bfloat16).npu()
rope_sin = torch.rand(B, S, Dr, dtype=torch.bfloat16).npu()
rope_cos = torch.rand(B, S, Dr, dtype=torch.bfloat16).npu()
cache_index = torch.randint(0, B * S, (B, S), dtype=torch.int64).npu()
kv_cache = torch.rand(1, BlockNum * BlockSize * Nkv * Hckv).to(torch.float8_e4m3fn).npu()
kv_cache = kv_cache.view(BlockNum, BlockSize, Nkv, Hckv)
kr_cache = torch.rand(1, BlockNum * BlockSize * Nkv * Dr, dtype=torch.bfloat16).npu()
kr_cache = kr_cache.view(BlockNum, BlockSize, Nkv, Dr)
rmsnorm_epsilon_cq = 1.0e-5
rmsnorm_epsilon_ckv = 1.0e-5
cache_mode = "PA_BSND"
qc_qr_scale = 10.0
kc_scale = 10.0
dequant_scale_x = torch.randint(0, 100, (B * S, He // 32)).to(torch.int8).npu()
dequant_scale_w_dq = torch.randint(0, 100, (Hcq, He // 32)).to(torch.int8).npu()
dequant_scale_w_uqqr = torch.randint(0, 100, (N * (D + Dr), Hcq // 32)).to(torch.int8).npu()
dequant_scale_w_dkvkr = torch.randint(0, 100, (Hckv + Dr, He // 32)).to(torch.int8).npu()
quant_scale_ckv = torch.rand(1, dtype=torch.float32).npu()
smooth_scale_cq = None
mla_param = {
'B': B,
'He': He,
'Hcq': Hcq,
'Hckv': Hckv,
'N1': N,
'D': D,
'Dr': Dr,
'S2': Skv,
'S1': S,
'N2': Nkv,
'BlockNum': BlockNum,
'BlockSize': BlockSize,
't_flag': False,
'T': T,
"out_deqq_shape_shape": [B * S, N, 1]
}
query_mla, query_rope_mla, dequant_scale_q_nope_mla, _, _, kv_cache_mla, kr_cache_mla = self.npu_mla_prolog_v3_functional(token_x, w_dq_cast, w_uq_qr_cast, w_uk, w_dkv_kr_cast, rmsnorm_gamma_cq,
rmsnorm_gamma_ckv, rope_sin, rope_cos, kv_cache, kr_cache, cache_index, rmsnorm_epsilon_cq, rmsnorm_epsilon_ckv,
cache_mode, dequant_scale_x, dequant_scale_w_dq, dequant_scale_w_uqqr, dequant_scale_w_dkvkr, quant_scale_ckv, smooth_scale_cq, query_norm_flag=False, weight_quant_mode=3, kv_cache_quant_mode=1,
query_quant_mode=1, ckvkr_repo_mode=0, quant_scale_repo_mode=0, tile_size=128, k_nope_clip_alpha=None, qc_qr_scale=qc_qr_scale, kc_scale=kc_scale)
query_mla = query_mla.cpu()
query_rope_mla = query_rope_mla.cpu()
kv_cache_mla = kv_cache_mla.cpu()
kr_cache_mla = kr_cache_mla.cpu()
dequant_scale_q_nope_mla = dequant_scale_q_nope_mla.cpu()
token_x = token_x.cpu()
w_dq = w_dq.cpu()
w_dq_cast = w_dq_cast.cpu()
w_uq_qr_cast = w_uq_qr_cast.cpu()
w_uk = w_uk.cpu()
w_dkv_kr_cast = w_dkv_kr_cast.cpu()
rmsnorm_gamma_cq = rmsnorm_gamma_cq.cpu()
rmsnorm_gamma_ckv = rmsnorm_gamma_ckv.cpu()
rope_sin = rope_sin.cpu()
rope_cos = rope_cos.cpu()
cache_index = cache_index.cpu()
kv_cache = kv_cache.cpu()
kr_cache = kr_cache.cpu()
dequant_scale_x = dequant_scale_x.cpu()
dequant_scale_w_dq = dequant_scale_w_dq.cpu()
dequant_scale_w_uqqr = dequant_scale_w_uqqr.cpu()
dequant_scale_w_dkvkr = dequant_scale_w_dkvkr.cpu()
smooth_scale_cq = None
query, query_rope, kv_cache_out, kr_cache_out, dequant_scale_q_nope = self.baseline_mxfp8(token_x, w_dq, w_uq_qr, w_uk, w_dkv_kr, rmsnorm_gamma_cq,
rmsnorm_gamma_ckv, rope_sin, rope_cos, kv_cache, kr_cache, cache_index, rmsnorm_epsilon_cq, rmsnorm_epsilon_ckv,
cache_mode, dequant_scale_x, dequant_scale_w_dq, dequant_scale_w_uqqr, dequant_scale_w_dkvkr,
quant_scale_ckv, smooth_scale_cq, query_norm_flag=False, weight_quant_mode=3, kv_cache_quant_mode=1,
query_quant_mode=1, ckvkr_repo_mode=0, quant_scale_repo_mode=0, tile_size=128, k_nope_clip_alpha=None, mla_param=mla_param, qc_qr_scale=qc_qr_scale, kc_scale=kc_scale)
self.assertRtolEqual(query_mla.to(torch.float32), query.to(torch.float32), prec=0.005, prec16=0.005)
self.assertRtolEqual(query_rope_mla.to(torch.float32), query_rope.to(torch.float32), prec=0.005, prec16=0.005)
self.assertRtolEqual(kv_cache_mla.to(torch.float32), kv_cache_out.to(torch.float32), prec=0.005, prec16=0.005)
self.assertRtolEqual(kr_cache_mla.to(torch.float32), kr_cache_out.to(torch.float32), prec=0.005, prec16=0.005)
self.assertRtolEqual(dequant_scale_q_nope_mla.to(torch.float32), dequant_scale_q_nope.to(torch.float32), prec=0.005, prec16=0.005)
if __name__ == "__main__":
run_tests()