from functools import partial
from typing import List, Optional, Union
import torch
from torch.distributed.tensor import DTensor
import torch_npu
from einops import rearrange
from mindspeed.fsdp.parallel_engine_config import QuantizeConfig
from mindspeed.fsdp.quantization.core.pre_quant_weight import PreQuantWeight
from mindspeed.ops.npu_moe_token_permute import npu_moe_token_permute
from mindspeed.ops.npu_moe_token_unpermute import npu_moe_token_unpermute
from mindspeed.fsdp.quantization.core.post_quant_weight import PostQuantWeight
@torch._dynamo.allow_in_graph
class gmm_with_hp_or_lp_weight(torch.autograd.Function):
    """Grouped matmul on MX-quantized operands with a hand-written backward.

    The weight handed to ``forward`` may be:

    * a ``PostQuantWeight`` - its stored quantized data/scales are reused;
    * a ``PreQuantWeight`` - its ``_tensor`` is quantized on the fly;
    * any other tensor - quantized on the fly (float32 is cast to bfloat16
      first).
    """

    @classmethod
    def gmm_apply(cls, x, weight, bias, tokens_per_expert, config, grad_enabled, to_shape, ori_weight=None):
        """Turn per-expert token counts into a cumulative group list and run the op.

        Args:
            x: Activations fed to the grouped matmul.
            weight: Weight tensor (see the class docstring for accepted kinds).
            bias: Bias forwarded to ``npu_grouped_matmul``; may be None.
            tokens_per_expert: Per-expert token counts. A Python list is turned
                into an int64 tensor on the NPU; anything else is passed to
                ``torch.cumsum`` unchanged.
            config: Quantization config providing ``get_key_dtype``.
            grad_enabled: Whether the backward-layout weight must be prepared.
            to_shape: Shape the weight is viewed as before quantization.
            ori_weight: Tensor saved for backward; ``forward`` falls back to
                ``weight`` when this is None.

        Returns:
            The result of ``cls.apply`` (first output of the grouped matmul).
        """
        if isinstance(tokens_per_expert, list):
            tokens_per_expert = torch.tensor(tokens_per_expert, device="npu", dtype=torch.int64)
        group_list = torch.cumsum(tokens_per_expert, dim=0)
        # NOTE: ``forward`` takes (grad_enabled, config) in the opposite order
        # to this method. The trailing 0 is ``group_list_type``, matching the
        # cumulative-sum group list built above.
        return cls.apply(x, weight, bias, group_list, grad_enabled, config, to_shape, ori_weight, 0)

    @classmethod
    def forward(cls, ctx, x, weight, bias, group_list, grad_enabled, config: QuantizeConfig, to_shape, ori_weight=None,
                group_list_type=0):
        """Quantize ``x`` and ``weight`` and run the forward grouped matmul.

        Args:
            ctx: Autograd context; receives everything ``backward`` needs.
            x: Activations; quantized along their last axis.
            weight: Weight tensor; viewed as ``to_shape`` before quantization.
            bias: Bias forwarded to ``npu_grouped_matmul``; may be None.
            group_list: Group boundaries. A CPU tensor is moved to the NPU, a
                non-tensor is converted to an int64 NPU tensor.
            grad_enabled: When True (and the weight is not a ``PostQuantWeight``)
                the backward-layout quantized weight is produced as well and
                kept on ``ctx``.
            config: Quantization config; ``get_key_dtype`` is queried for
                ``"weight"`` and ``"inputs"``.
            to_shape: Target shape for ``weight.view``.
            ori_weight: Tensor saved for backward instead of ``weight``;
                defaults to ``weight``.
            group_list_type: Forwarded verbatim to ``npu_grouped_matmul``.

        Returns:
            The first element of the ``npu_grouped_matmul`` output, in
            ``x.dtype``.
        """

        def get_quantized_weight(weight, grad_enabled, config):
            # Reset first so backward never reads values left from an earlier
            # code path; the PostQuantWeight branch leaves both as None.
            ctx.weight_bwd, ctx.weight_scale_bwd = None, None
            if isinstance(weight, PostQuantWeight):
                return weight._weight_fwd, weight._scale_fwd
            if isinstance(weight, PreQuantWeight):
                weight = weight._tensor
            if weight.dtype == torch.float32:
                weight = weight.to(torch.bfloat16)
            if grad_enabled:
                # One call yields both layouts; the backward pair is kept on ctx.
                ctx.weight_bwd, ctx.weight_scale_bwd, weight_fwd, weight_scale_fwd = torch_npu.npu_dynamic_mx_quant_with_dual_axis(
                    weight,
                    dst_type=config.get_key_dtype("weight"),
                )
            else:
                weight_fwd, weight_scale_fwd = torch_npu.npu_dynamic_mx_quant(weight, axis=-2,
                                                                              dst_type=config.get_key_dtype("weight"))
            return weight_fwd, weight_scale_fwd

        if isinstance(group_list, torch.Tensor):
            if group_list.device.type == "cpu":
                group_list = group_list.npu()
        else:
            group_list = torch.tensor(group_list, device="npu", dtype=torch.int64)

        # Must be resolved before ``weight`` is rebound to its reshaped view.
        ori_weight = weight if ori_weight is None else ori_weight
        weight = weight.view(to_shape)
        weight_fwd, weight_scale_fwd = get_quantized_weight(weight, grad_enabled, config)
        x_mxfp8, x_scale = torch_npu.npu_dynamic_mx_quant(x, axis=-1, dst_type=config.get_key_dtype("inputs"))
        output = torch_npu.npu_grouped_matmul(
            [x_mxfp8],
            [weight_fwd],
            bias=bias,
            scale=[weight_scale_fwd],
            per_token_scale=[x_scale],
            group_list=group_list,
            group_type=0,
            output_dtype=x.dtype,
            group_list_type=group_list_type,
            scale_dtype=torch_npu.float8_e8m0fnu,
            per_token_scale_dtype=torch_npu.float8_e8m0fnu,
            split_item=3,
        )
        ctx.config = config
        ctx.save_for_backward(x, ori_weight, group_list)
        ctx.group_list_type = group_list_type
        # Backward only needs to know whether a bias was supplied. Keeping a
        # flag instead of the tensor avoids holding an input tensor on ctx
        # outside of save_for_backward.
        ctx.has_bias = bias is not None
        ctx.to_shape = to_shape
        return output[0]

    @classmethod
    def backward(cls, ctx, grad_outputs):
        """Compute gradients for ``x``, ``weight`` and (if present) ``bias``.

        Args:
            ctx: Context populated by ``forward``.
            grad_outputs: Gradient of the loss w.r.t. the forward output.

        Returns:
            A 9-tuple aligned with ``forward``'s inputs: ``(grad_x,
            grad_weight, grad_bias, None, None, None, None, None, None)``.
            ``grad_weight`` is viewed as the saved weight's shape;
            ``grad_bias`` is None when no bias was given.
        """
        x, weight, group_list = ctx.saved_tensors
        if isinstance(weight, DTensor):
            weight = weight.to_local()
        if isinstance(weight, PostQuantWeight):
            # Pre-quantized weights carry their own backward-layout data.
            weight_bwd, weight_scale_bwd = weight._weight_bwd, weight._scale_bwd
            weight_bwd = weight_bwd.view(ctx.to_shape)
        else:
            weight_bwd, weight_scale_bwd = ctx.weight_bwd, ctx.weight_scale_bwd
        group_list_type = ctx.group_list_type

        grad_bias = None
        if ctx.has_bias:
            # Collapse every leading dim and sum over rows.
            grad_bias = grad_outputs.reshape(-1, grad_outputs.shape[-1]).sum(dim=0)

        # grad_x: grad_outputs times the weight with its last two dims swapped.
        grad_mxfp8, grad_scale = torch_npu.npu_dynamic_mx_quant(grad_outputs, axis=-1,
                                                                dst_type=ctx.config.get_key_dtype("grads"))
        grad_x = torch_npu.npu_grouped_matmul(
            [grad_mxfp8],
            [rearrange(weight_bwd, "n h f -> n f h")],
            scale=[rearrange(weight_scale_bwd, "n h f g -> n f h g")],
            per_token_scale=[grad_scale],
            group_list=group_list,
            group_type=0,
            output_dtype=grad_outputs.dtype,
            group_list_type=group_list_type,
            scale_dtype=torch_npu.float8_e8m0fnu,
            per_token_scale_dtype=torch_npu.float8_e8m0fnu,
            split_item=3,
        )[0]

        # grad_weight: x and grad_outputs are re-quantized with the grouped
        # quantizer (group list cast to int32) and multiplied as x^T @ grad.
        x_mxfp8, x_scale = torch_npu.npu_grouped_dynamic_mx_quant(x, group_list.to(torch.int32), round_mode="rint",
                                                                  dst_type=ctx.config.get_key_dtype("inputs"),
                                                                  blocksize=32)
        grad_mxfp8, grad_scale = torch_npu.npu_grouped_dynamic_mx_quant(
            grad_outputs, group_list.to(torch.int32), round_mode="rint", dst_type=ctx.config.get_key_dtype("grads"),
            blocksize=32
        )
        # NOTE(review): group_type=2 presumably groups along the reduction
        # (token) axis - confirm against the torch_npu documentation.
        grad_weights = torch_npu.npu_grouped_matmul(
            [x_mxfp8.t()],
            [grad_mxfp8],
            scale=[grad_scale],
            per_token_scale=[rearrange(x_scale, "n h f -> h n f")],
            group_list=group_list,
            group_type=2,
            output_dtype=x.dtype,
            group_list_type=group_list_type,
            scale_dtype=torch_npu.float8_e8m0fnu,
            per_token_scale_dtype=torch_npu.float8_e8m0fnu,
            split_item=3,
        )[0]
        return grad_x, grad_weights.view(weight.shape), grad_bias, None, None, None, None, None, None
def mx_quant_group_gemm(
        x: torch.Tensor,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
        tokens_per_expert: Union[List[int], torch.Tensor] = None,
        grad_enabled: bool = True,
        config: QuantizeConfig = None,
        to_shape: list = None,
        ori_weight: torch.Tensor = None,
) -> torch.Tensor:
    """
    Run an MX-quantized grouped GEMM for MoE (Mixture-of-Experts) experts.

    Functional entry point for ``gmm_with_hp_or_lp_weight.gmm_apply``; every
    argument is forwarded unchanged.

    Args:
        x: Activation tensor. It is quantized along its last axis and the
            result is returned in ``x.dtype``.
        weight: Expert weights. May be a plain high-precision tensor
            (quantized on the fly), a ``PreQuantWeight`` or a
            ``PostQuantWeight`` (already-quantized data is reused).
        bias: Optional bias handed to the grouped matmul. If None, no bias
            is added.
        tokens_per_expert: Number of tokens assigned to each expert, as a
            list[int] or a torch.Tensor. It is converted to a cumulative sum
            before the matmul.
        grad_enabled: Whether the backward-layout quantized weight should be
            prepared (True for training, False for inference).
        config: Quantization configuration object. The default of None is
            not usable: ``config.get_key_dtype`` is called unconditionally.
        to_shape: Shape the weight tensor is viewed as for the GEMM.
        ori_weight: Original weight, saved for the backward pass. When None,
            ``weight`` itself is saved.

    Returns:
        Output tensor of the grouped matmul (its first output).
    """
    # gmm_apply's positional order is:
    # (x, weight, bias, tokens_per_expert, config, grad_enabled, to_shape, ori_weight)
    return gmm_with_hp_or_lp_weight.gmm_apply(
        x, weight, bias, tokens_per_expert, config, grad_enabled, to_shape, ori_weight
    )
class MXFP8GMM(torch.nn.Module):
    """MoE expert block whose two projections run as MXFP8 grouped GEMMs.

    ``__init__`` does not create ``gate_up_proj`` / ``down_proj``; they are
    attached by ``from_float`` (or already exist on a module that
    ``from_float`` re-classes in place). ``ep_forward`` also reads
    ``self.num_local_experts``, which this class never sets and which must be
    provided from outside. ``act_fn`` is stored but the activation applied
    here is always ``torch_npu.npu_swiglu``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        # Positional arguments are accepted but ignored. Each setting is an
        # optional keyword (attribute name, keyword name) defaulting to None.
        for attr_name, kwarg_key in (
            ("config", "config"),
            ("num_experts", "num_experts"),
            ("hidden_dim", "hidden_size"),
            ("intermediate_size", "moe_intermediate_size"),
            ("act_fn", "act_fn"),
        ):
            setattr(self, attr_name, kwargs.pop(kwarg_key, None))

    def _expert_mlp(self, tokens, gate_up_weight, down_weight, tokens_per_expert, expert_count):
        """Apply gate/up projection, SwiGLU and down projection to grouped tokens.

        ``gate_up_weight`` / ``down_weight`` are the tensors multiplied in the
        forward pass; the module's own ``gate_up_proj`` / ``down_proj`` are
        always passed as the tensors saved for backward.
        """
        projected = mx_quant_group_gemm(
            x=tokens,
            weight=gate_up_weight,
            bias=None,
            tokens_per_expert=tokens_per_expert,
            config=self.config,
            grad_enabled=torch.is_grad_enabled(),
            to_shape=[expert_count, self.hidden_dim, -1],
            ori_weight=self.gate_up_proj,
        )
        activated = torch_npu.npu_swiglu(projected, dim=-1)
        return mx_quant_group_gemm(
            x=activated,
            weight=down_weight,
            bias=None,
            tokens_per_expert=tokens_per_expert,
            config=self.config,
            grad_enabled=torch.is_grad_enabled(),
            to_shape=[expert_count, -1, self.hidden_dim],
            ori_weight=self.down_proj,
        )

    def forward(self, hidden_states, routing_weights=None, selected_experts=None):
        """Permute tokens by expert, run the expert MLP, then un-permute.

        Args:
            hidden_states: Token activations.
            routing_weights: Passed as ``probs`` to the un-permute op.
            selected_experts: Expert index per token; required despite the
                None default, since ``.to`` is called on it.

        Returns:
            The un-permuted expert output.
        """
        expert_ids = selected_experts.to(torch.int32)
        permuted_tokens, row_ids_map = npu_moe_token_permute(hidden_states, expert_ids)
        # One histogram bin per expert id in [0, num_experts].
        tokens_per_expert = torch.histc(selected_experts, bins=self.num_experts, min=0, max=self.num_experts)
        expert_out = self._expert_mlp(
            permuted_tokens, self.gate_up_proj, self.down_proj, tokens_per_expert, self.num_experts
        )
        return npu_moe_token_unpermute(expert_out, row_ids_map, probs=routing_weights)

    def ep_forward(self, hidden_states, tokens_per_expert):
        """Run the expert MLP on tokens that are already grouped by expert.

        The weights are taken through ``to_local()`` for the matmul, while the
        un-localized parameters are what gets saved for backward.

        Args:
            hidden_states: Token activations, already grouped.
            tokens_per_expert: Token count for each local expert.

        Returns:
            Output of the down projection (no un-permute is applied here).
        """
        local_gate_up = self.gate_up_proj.to_local()
        local_down = self.down_proj.to_local()
        return self._expert_mlp(
            hidden_states, local_gate_up, local_down, tokens_per_expert, self.num_local_experts
        )

    @classmethod
    def from_float(
            cls,
            mod: torch.nn.Module,
            config: Optional[QuantizeConfig] = None,
            name: Optional[str] = None,
    ):
        """Convert an existing expert module into an ``MXFP8GMM``.

        Args:
            mod: Source module; must expose ``gate_up_proj`` and ``down_proj``
                (and, for the low-precision all-gather path, ``num_experts``,
                ``hidden_dim``, ``intermediate_size`` and ``act_fn``).
            config: Quantization config; defaults to
                ``QuantizeConfig(recipe_name="mxfp8")``.
            name: Stored as ``_name`` and passed to each ``PreQuantWeight``.

        Returns:
            ``mod`` itself, re-classed in place, when
            ``config.enable_fsdp_low_precision_all_gather`` is falsy;
            otherwise a new module whose weights are wrapped in
            ``PreQuantWeight`` parameters.
        """
        if config is None:
            config = QuantizeConfig(recipe_name="mxfp8")

        if not config.enable_fsdp_low_precision_all_gather:
            # Simple path: keep the module and its parameters, swap the class.
            mod.__class__ = cls
            mod.config = config
            mod._name = name
            return mod

        with torch.device("meta"):
            new_mod = cls(
                config=config,
                num_experts=mod.num_experts,
                hidden_size=mod.hidden_dim,
                moe_intermediate_size=mod.intermediate_size,
                act_fn=mod.act_fn,
            )
        new_mod.gate_up_proj = mod.gate_up_proj
        new_mod.down_proj = mod.down_proj

        def as_pre_quant(param, quant_shape):
            # Wrap ``param`` so it carries a quantizer bound to ``quant_shape``.
            quantizer = partial(weight_quant, dst_type=config.get_key_dtype("weight"), new_shape=quant_shape)
            wrapped = PreQuantWeight(param, quantizer, config, param.dtype, name=name)
            return torch.nn.Parameter(wrapped, requires_grad=param.requires_grad)

        new_mod.gate_up_proj = as_pre_quant(
            new_mod.gate_up_proj, (-1, mod.hidden_dim, mod.intermediate_size * 2)
        )
        new_mod.down_proj = as_pre_quant(
            new_mod.down_proj, (-1, mod.intermediate_size, mod.hidden_dim)
        )
        new_mod._name = name
        return new_mod
def weight_quant(weight, dst_type, new_shape):
    """Quantize ``weight`` once and return its forward and backward variants.

    The weight is reshaped to ``new_shape`` for the quantization call; the two
    quantized data tensors are reshaped back to the weight's original shape,
    while the scales are returned exactly as the op produced them.

    Args:
        weight: Tensor to quantize.
        dst_type: Target dtype passed to the quantization op.
        new_shape: Shape the weight is viewed as while quantizing.

    Returns:
        ``(weight_fwd, weight_scale_fwd, weight_bwd, weight_scale_bwd)``.
    """
    stored_shape = weight.shape
    # The op's results are unpacked as the backward pair first, then forward.
    bwd_data, bwd_scale, fwd_data, fwd_scale = torch_npu.npu_dynamic_mx_quant_with_dual_axis(
        weight.reshape(new_shape),
        dst_type=dst_type,
    )
    return (
        fwd_data.reshape(stored_shape),
        fwd_scale,
        bwd_data.reshape(stored_shape),
        bwd_scale,
    )