from typing import Optional
import torch
import torch_npu
from ..constants import USAGE_WITH_TRANS
from ..quantization import FP8GlobalStateManager
from ..quantized_tensor import QuantizedTensor
from ..tensor.grouped_tensor import GroupedTensor
from ..utils import get_quant_dtype, view_as_n_dim
def _align_group_list_device(group_list, device):
    """Return ``group_list`` on ``device`` if it is a tensor stored elsewhere.

    Args:
        group_list: Grouped-GEMM split metadata; may be a tensor or any other
            object.
        device: Device the split metadata should live on.

    Returns:
        ``group_list`` unchanged when it is not a ``torch.Tensor`` or already
        lives on ``device``; otherwise the result of moving it to ``device``.
    """
    if not isinstance(group_list, torch.Tensor):
        return group_list
    on_other_device = group_list.device != device
    return group_list.to(device=device) if on_other_device else group_list
class MXFP8MatMul(torch.autograd.Function):
    """Autograd function computing a matmul through torch_npu MX-quantized kernels.

    ``forward`` quantizes ``x`` (flattened with ``view_as_n_dim``) and ``weight``
    with the forward dtype taken from the active FP8 recipe and multiplies them
    with ``torch_npu.npu_quant_matmul``. ``backward`` quantizes the incoming
    gradient with the backward dtype and produces gradients for ``x`` and
    ``weight``.
    """

    @staticmethod
    def _mx_quant_matmul(lhs, rhs, rhs_scale, lhs_scale, out_dtype):
        """Call ``npu_quant_matmul`` with the settings shared by every GEMM here.

        ``rhs_scale`` is passed as the positional scale argument and
        ``lhs_scale`` as ``pertoken_scale``; both scale dtypes are
        ``torch_npu.float8_e8m0fnu`` and ``group_sizes`` is ``[1, 1, 32]``.
        """
        e8m0 = torch_npu.float8_e8m0fnu
        return torch_npu.npu_quant_matmul(
            lhs,
            rhs,
            rhs_scale,
            pertoken_scale=lhs_scale,
            output_dtype=out_dtype,
            scale_dtype=e8m0,
            pertoken_scale_dtype=e8m0,
            group_sizes=[1, 1, 32],
        )

    @staticmethod
    def forward(ctx, x: torch.Tensor, weight: torch.Tensor, need_grad: bool = True):
        """Quantize both operands and return their product in ``x.dtype``.

        Args:
            ctx: Autograd context used to stash state for ``backward``.
            x: Input activation; flattened with ``view_as_n_dim`` before
                quantization and the result is reshaped back when ``x`` is not
                two-dimensional.
            weight: Weight operand; its quantized form is transposed for the
                matmul.
            need_grad: When true, the dual-axis quantization kernel is used and
                its third and fourth outputs are kept on ``ctx`` for
                ``backward``. Otherwise only ``axis=-1`` quantization is done
                and the raw ``x`` / ``weight`` are saved instead.
        """
        fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
        quant_dtypes = get_quant_dtype(fp8_recipe.fp8_format)
        fwd_dtype = quant_dtypes.fwd

        ctx.qdtype = quant_dtypes
        ctx.output_dtype = x.dtype
        x_flat = view_as_n_dim(x)

        if need_grad:
            # The dual-axis kernel returns four tensors; the last two are the
            # ones backward() consumes directly.
            x_q, x_s, ctx.x, ctx.x_scale = torch_npu.npu_dynamic_mx_quant_with_dual_axis(
                x_flat, dst_type=fwd_dtype
            )
            w_q, w_s, ctx.w, ctx.w_scale = torch_npu.npu_dynamic_mx_quant_with_dual_axis(
                weight, dst_type=fwd_dtype
            )
        else:
            x_q, x_s = torch_npu.npu_dynamic_mx_quant(x_flat, axis=-1, dst_type=fwd_dtype)
            w_q, w_s = torch_npu.npu_dynamic_mx_quant(weight, axis=-1, dst_type=fwd_dtype)
            # Raw operands are kept so backward() can re-quantize along axis -2.
            ctx.save_for_backward(x, weight)

        result = MXFP8MatMul._mx_quant_matmul(
            x_q, w_q.t(), w_s.transpose(0, 1), x_s, x.dtype
        )

        if x.dim() != 2:
            # Restore the leading dimensions that view_as_n_dim flattened.
            result = result.reshape(*x.shape[:-1], *result.shape[1:])
        if weight.requires_grad:
            result.requires_grad = True
        return result

    @staticmethod
    def backward(ctx, grads: torch.Tensor):
        """Return ``(dx, dw, None, None, None)`` for the incoming ``grads``.

        The gradient is quantized with the dual-axis kernel using the backward
        dtype. Quantized operands come from ``ctx`` when ``forward`` stored
        them, otherwise they are rebuilt from the saved tensors with
        ``axis=-2`` quantization.
        """
        quant_dtypes = ctx.qdtype
        g_dx_q, g_dx_scale, g_dw_q, g_dw_scale = (
            torch_npu.npu_dynamic_mx_quant_with_dual_axis(
                view_as_n_dim(grads), dst_type=quant_dtypes.bwd
            )
        )

        if not hasattr(ctx, "x"):
            saved_x, saved_w = ctx.saved_tensors
            w_q, w_s = torch_npu.npu_dynamic_mx_quant(
                saved_w, axis=-2, dst_type=quant_dtypes.fwd
            )
            x_q, x_s = torch_npu.npu_dynamic_mx_quant(
                view_as_n_dim(saved_x), axis=-2, dst_type=quant_dtypes.fwd
            )
        else:
            x_q, x_s = ctx.x, ctx.x_scale
            w_q, w_s = ctx.w, ctx.w_scale

        dx = MXFP8MatMul._mx_quant_matmul(
            g_dx_q, w_q, w_s, g_dx_scale, ctx.output_dtype
        )
        if grads.dim() != 2:
            dx = dx.reshape(*grads.shape[:-1], *dx.shape[1:])

        dw = MXFP8MatMul._mx_quant_matmul(
            g_dw_q.t(), x_q, x_s, g_dw_scale.transpose(0, 1), ctx.output_dtype
        )
        return dx, dw, None, None, None
def general_gemm(A, B, usage_a, usage_b, out_dtype, bias=None):
    """Multiply ``A`` by ``B`` and optionally add ``bias``.

    Args:
        A: Left operand. A ``QuantizedTensor`` delegates to ``A.matmul``.
        B: Right operand.
        usage_a: Usage tag for ``A``; on the dense path ``A`` is transposed
            with ``.t()`` when the tag is in ``USAGE_WITH_TRANS``.
        usage_b: Usage tag for ``B``; handled like ``usage_a``.
        out_dtype: Forwarded to ``A.matmul`` on the quantized path only; the
            dense path does not use it.
        bias: Optional value added to the product.

    Returns:
        The product, plus ``bias`` when one is given.
    """
    if isinstance(A, QuantizedTensor):
        product = A.matmul(B, usage_a, usage_b, out_dtype)
    else:
        lhs = A.t() if usage_a in USAGE_WITH_TRANS else A
        rhs = B.t() if usage_b in USAGE_WITH_TRANS else B
        product = torch.matmul(lhs, rhs)

    if bias is None:
        return product
    return product + bias
def general_gemm_add(main_grad, A, B, usage_a, usage_b, out_dtype):
    """Accumulate the product of ``A`` and ``B`` into ``main_grad`` in place.

    Nothing happens when either operand has a zero-sized dimension. A
    ``QuantizedTensor`` ``A`` delegates to ``A.matmul_add``; otherwise operands
    whose usage tag is in ``USAGE_WITH_TRANS`` are transposed with ``.t()`` and
    the product is added through ``main_grad.addmm_``. ``out_dtype`` is only
    forwarded on the quantized path.

    Returns:
        None.
    """
    every_dim_nonzero = all(A.shape) and all(B.shape)
    if not every_dim_nonzero:
        return

    if isinstance(A, QuantizedTensor):
        A.matmul_add(main_grad, B, usage_a, usage_b, out_dtype)
        return

    lhs = A.t() if usage_a in USAGE_WITH_TRANS else A
    rhs = B.t() if usage_b in USAGE_WITH_TRANS else B
    main_grad.addmm_(lhs, rhs)
def matmul_add(main_grad, inp, grad_out):
    """Add ``grad_out.t() @ inp`` to ``main_grad`` in place.

    The update is skipped when ``inp`` or ``grad_out`` has a zero-sized
    dimension.

    Returns:
        None.
    """
    if all(inp.shape) and all(grad_out.shape):
        main_grad.addmm_(grad_out.t(), inp)
def general_grouped_gemm(
    B,
    A,
    group_split: torch.Tensor,
    out: Optional[torch.Tensor] = None,
    layout="TN",
    use_bias: bool = False,
    biases=None,
    group_type: int = 0,
    group_list_type: int = 1,
    split_item: int = 3,
    out_dtype: torch.dtype = torch.bfloat16,
):
    """Run a grouped GEMM of ``A`` against ``B``.

    When both operands are ``GroupedTensor`` instances the call is dispatched:
    with ``out`` given, the product is accumulated into ``out`` through
    ``general_grouped_gemm_add_`` and ``None`` is returned; without ``out`` the
    result of ``general_grouped_gemm_for_grouped_tensor`` is returned.

    Otherwise the operands are treated as plain tensors and multiplied with
    ``torch_npu.npu_grouped_matmul``. NOTE: on this path ``out`` is not written
    to; the result is always a newly produced tensor.

    Args:
        B: Right operand; a tensor, a list of tensors, or a ``GroupedTensor``.
        A: Left operand; a tensor or a ``GroupedTensor``.
        group_split: Per-group split metadata handed to the kernel as
            ``group_list``.
        out: Accumulation target, used on the ``GroupedTensor`` path only.
        layout: Two-character string; ``layout[0] == "T"`` transposes the last
            two dims of every ``B`` entry, ``layout[1] == "T"`` transposes
            ``A``.
        use_bias: Biases are added only when this is true and ``biases`` is not
            ``None``.
        biases: Iterable of per-group biases.
        group_type: Forwarded to the kernel.
        group_list_type: Forwarded to the kernel.
        split_item: Forwarded to the kernel.
        out_dtype: Output dtype requested from the kernel.

    Returns:
        The grouped GEMM result, or ``None`` when accumulating into ``out``.
    """
    if isinstance(A, GroupedTensor) and isinstance(B, GroupedTensor):
        if out is not None:
            general_grouped_gemm_add_(out, B, A, group_split, layout, group_list_type)
            return None
        return general_grouped_gemm_for_grouped_tensor(
            B,
            A,
            group_split,
            layout,
            use_bias,
            biases,
            group_type,
            group_list_type,
            split_item,
            out_dtype,
        )

    if isinstance(B, torch.Tensor):
        B = [B]
    if layout[0] == "T":
        B = [b.transpose(-1, -2) for b in B]
    if layout[1] == "T":
        # contiguous() returns the tensor itself when it is already contiguous.
        A = A.contiguous().t()

    if A.shape[-1] == 0 and B[0].shape[0] == 0:
        # Zero-sized contraction dimension: skip the kernel entirely.
        return torch.zeros(
            len(group_split),
            *A.shape[:-1],
            *B[0].shape[1:],
            device=A.device,
            dtype=out_dtype,
        )

    # Match the GroupedTensor paths: the kernel receives split metadata that
    # lives on the same device as the operands. The caller's object is kept
    # untouched for the bias split below.
    kernel_group_list = _align_group_list_device(group_split, A.device)

    result = torch_npu.npu_grouped_matmul(
        [A],
        B,
        group_list=kernel_group_list,
        output_dtype=out_dtype,
        group_type=group_type,
        group_list_type=group_list_type,
        split_item=split_item,
    )[0]

    if use_bias and biases is not None:
        result = out_add_biases(result, biases, group_split)
    return result
def general_grouped_gemm_for_grouped_tensor(
    B: GroupedTensor,
    A: GroupedTensor,
    group_split: torch.Tensor,
    layout="TN",
    use_bias: bool = False,
    biases=None,
    group_type: int = 0,
    group_list_type: int = 1,
    split_item: int = 3,
    out_dtype: torch.dtype = torch.bfloat16,
):
    """Run ``torch_npu.npu_grouped_matmul`` on two ``GroupedTensor`` operands.

    Data and scales are fetched with ``A.get_data("L" + layout[1])`` and
    ``B.get_data("R" + layout[0])``. When both operands carry a quantizer the
    scales and the dtype hints required by the recipe are forwarded to the
    kernel.

    Args:
        B: Right operand.
        A: Left operand.
        group_split: Split metadata; moved to the device of ``A``'s data before
            the kernel call.
        layout: Two-character layout string used to pick the data views.
        use_bias: Biases are added only when this is true and ``biases`` is not
            ``None``.
        biases: Iterable of per-group biases.
        group_type: Forwarded to the kernel.
        group_list_type: Forwarded to the kernel.
        split_item: Forwarded to the kernel.
        out_dtype: Output dtype requested from the kernel.

    Returns:
        The grouped GEMM result. A zero tensor is returned without calling the
        kernel when the contraction dimension is empty and the recipe is not
        MXFP4.

    Raises:
        TypeError: If either operand is not a ``GroupedTensor``.
        RuntimeError: If exactly one operand has a quantizer.
    """
    if not isinstance(A, GroupedTensor) or not isinstance(B, GroupedTensor):
        raise TypeError(
            "general_grouped_gemm_for_grouped_tensor expects GroupedTensor inputs, "
            f"but got {type(A)}(A) with {type(B)}(B)"
        )
    if (A.quantizer is None) != (B.quantizer is None):
        raise RuntimeError("Mixed dense and quantized GroupedTensor GEMM is not supported")

    a_data, a_scale = A.get_data(f"L{layout[1]}")
    b_data, b_scale = B.get_data(f"R{layout[0]}")

    is_low_precision = A.quantizer is not None and B.quantizer is not None
    # Look the recipe up once; it drives both the empty-shape shortcut and the
    # kernel keyword arguments.
    recipe = A.quantizer._get_compatible_recipe() if is_low_precision else None

    if (
        (recipe is None or not recipe.mxfp4())
        and a_data.shape[-1] == 0
        and b_data.shape[0] == 0
    ):
        # Use the data tensor's device, the same source used for aligning
        # group_split below, rather than an attribute of the wrapper.
        return torch.zeros(
            len(group_split),
            *a_data.shape[:-1],
            *b_data.shape[1:],
            device=a_data.device,
            dtype=out_dtype,
        )

    group_split = _align_group_list_device(group_split, a_data.device)

    gmm_kwargs = {}
    if is_low_precision:
        gmm_kwargs["scale"] = [b_scale]
        gmm_kwargs["per_token_scale"] = [a_scale]
        if recipe.mxfp8():
            gmm_kwargs["scale_dtype"] = torch_npu.float8_e8m0fnu
            gmm_kwargs["per_token_scale_dtype"] = torch_npu.float8_e8m0fnu
        elif recipe.mxfp4():
            gmm_kwargs["scale_dtype"] = torch_npu.float8_e8m0fnu
            gmm_kwargs["per_token_scale_dtype"] = torch_npu.float8_e8m0fnu
            if not (A.quantizer.with_rht and B.quantizer.with_rht):
                gmm_kwargs["x_dtype"] = torch_npu.float4_e2m1fn_x2
                gmm_kwargs["weight_dtype"] = torch_npu.float4_e2m1fn_x2
        if A.quantizer.dtype == torch_npu.hifloat8:
            gmm_kwargs["x_dtype"] = torch_npu.hifloat8
            gmm_kwargs["weight_dtype"] = torch_npu.hifloat8

    out = torch_npu.npu_grouped_matmul(
        [a_data],
        [b_data],
        group_list=group_split,
        output_dtype=out_dtype,
        group_type=group_type,
        group_list_type=group_list_type,
        split_item=split_item,
        **gmm_kwargs,
    )[0]

    if use_bias and biases is not None:
        out = out_add_biases(out, biases, group_split)
    return out
def general_grouped_gemm_add_(
    out: torch.Tensor,
    B: GroupedTensor,
    A: GroupedTensor,
    group_split: torch.Tensor,
    layout="TN",
    group_list_type: int = 1,
):
    """Accumulate a grouped GEMM of two ``GroupedTensor`` operands into ``out``.

    ``out`` is viewed as ``(len(group_split), b_data.shape[-1], -1)`` and
    updated in place by ``torch_npu.npu_grouped_matmul_add_`` (no quantizers)
    or ``torch_npu.npu_add_quant_gmm_`` (MXFP8 recipe). Nothing is done when
    the contraction dimension is empty.

    Args:
        out: Accumulation buffer updated in place.
        B: Right operand; its data is fetched with ``"R" + layout[0]``.
        A: Left operand; its data is fetched with ``"L" + layout[1]``.
        group_split: Split metadata; moved to the device of ``A``'s data.
        layout: Two-character layout string used to pick the data views.
        group_list_type: Forwarded to the quantized kernel only.

    Returns:
        None.

    Raises:
        TypeError: If either operand is not a ``GroupedTensor``.
        RuntimeError: If exactly one operand has a quantizer, or the operands
            are quantized with a recipe other than MXFP8.
    """
    if not (isinstance(A, GroupedTensor) and isinstance(B, GroupedTensor)):
        raise TypeError("general_grouped_gemm_add_ expects GroupedTensor inputs")
    if (A.quantizer is None) != (B.quantizer is None):
        raise RuntimeError("Mixed dense and quantized GroupedTensor GEMM is not supported")

    lhs_data, lhs_scale = A.get_data(f"L{layout[1]}")
    rhs_data, rhs_scale = B.get_data(f"R{layout[0]}")
    if lhs_data.shape[-1] == 0 and rhs_data.shape[0] == 0:
        return

    group_split = _align_group_list_device(group_split, lhs_data.device)
    quantized = A.quantizer is not None and B.quantizer is not None
    num_groups = len(group_split)

    # Reject unsupported quantized formats before touching the output buffer.
    if quantized and not A.quantizer._get_compatible_recipe().mxfp8():
        raise RuntimeError(
            "Currently, general_grouped_gemm_add_ only supports BF16 and MXFP8 formats"
        )

    accum = out.view(num_groups, rhs_data.shape[-1], -1)
    if not quantized:
        torch_npu.npu_grouped_matmul_add_(accum, rhs_data, lhs_data, group_split)
        return

    e8m0 = torch_npu.float8_e8m0fnu
    torch_npu.npu_add_quant_gmm_(
        accum,
        rhs_data,
        lhs_data,
        lhs_scale,
        x1_scale=rhs_scale,
        group_list_type=group_list_type,
        group_list=group_split,
        x1_scale_dtype=e8m0,
        x2_scale_dtype=e8m0,
    )
def out_add_biases(out, biases, group_split):
    """Add one bias per group to ``out`` and return the re-joined tensor.

    ``out`` is split along dim 0 with ``torch.split`` using ``group_split``
    (converted to a list first when it is a tensor). The i-th bias is added to
    the i-th chunk, and the chunks are concatenated back along dim 0. Chunks
    beyond the number of biases are left unchanged.

    Args:
        out: Tensor to split and bias.
        biases: Iterable of biases, one per leading chunk.
        group_split: Split sizes, as a tensor or anything ``torch.split``
            accepts.

    Returns:
        A new tensor with the biased chunks concatenated along dim 0.
    """
    sizes = group_split.tolist() if isinstance(group_split, torch.Tensor) else group_split
    chunks = list(torch.split(out, sizes))
    for idx, group_bias in enumerate(biases):
        chunks[idx] = chunks[idx] + group_bias
    return torch.cat(chunks, dim=0)