from collections.abc import Callable
from enum import Enum
import torch
import torch.distributed as dist
import torch_npu
from ...quantization.config import QuantConfig
from ...quantization.mode import QuantAlgorithm
from ...utils import ParametersInvalid
from ...utils.get_platform import NPUDevice, get_npu_device
from .moe_dataclass import (
MoEMlpComputeInput,
MoEPrepareInput,
MoERoutingInput,
MoETokenDispatchInput,
MoETokenDispatchOutput,
MoEWeights,
)
# Quantization algorithms accepted on the MoE path, grouped by numeric format.
# Integer family: activations are dynamically quantized to torch.int8 (see dynamic_quant).
MOE_INT_QUANT_ALGOS = (QuantAlgorithm.W8A8_DYNAMIC,)
# MXFP family: activations are dynamically quantized to torch.float8_e4m3fn (see dynamic_quant).
MOE_MXFP_QUANT_ALGOS = (QuantAlgorithm.W8A8_MXFP8,)
# Everything validate_moe_inputs accepts, including the unquantized path.
MOE_SUPPORTED_QUANT_ALGOS = (QuantAlgorithm.NO_QUANT,) + MOE_INT_QUANT_ALGOS + MOE_MXFP_QUANT_ALGOS
def _require_tensor(name: str, value: object) -> None:
    """Raise ``ParametersInvalid`` unless ``value`` is a ``torch.Tensor``."""
    if not isinstance(value, torch.Tensor):
        raise ParametersInvalid(f"{name} must be a torch.Tensor, but got {type(value)}.")


def _require_optional_tensor(name: str, value: object) -> None:
    """Raise ``ParametersInvalid`` unless ``value`` is ``None`` or a ``torch.Tensor``."""
    if value is not None and not isinstance(value, torch.Tensor):
        raise ParametersInvalid(f"{name} must be a torch.Tensor or None, but got {type(value)}.")


def _require_int(name: str, value: object) -> None:
    """Raise ``ParametersInvalid`` unless ``value`` is an ``int`` other than ``bool``."""
    # bool is a subclass of int, so it has to be rejected explicitly.
    if not isinstance(value, int) or isinstance(value, bool):
        raise ParametersInvalid(f"{name} must be an integer, but got {type(value)}.")


def _require_bool(name: str, value: object) -> None:
    """Raise ``ParametersInvalid`` unless ``value`` is a ``bool``."""
    if not isinstance(value, bool):
        raise ParametersInvalid(f"{name} must be a bool, but got {type(value)}.")


def _require_optional_group(name: str, value: object) -> None:
    """Raise ``ParametersInvalid`` unless ``value`` is ``None`` or a ``dist.ProcessGroup``."""
    if value is not None and not isinstance(value, dist.ProcessGroup):
        raise ParametersInvalid(f"{name} must be a dist.ProcessGroup or None, but got {type(value)}.")


def _resolve_moe_quant_algo(
    hidden_states: torch.Tensor,
    quant_config: QuantConfig | None,
    w13_weight_scale: torch.Tensor | None,
    w2_weight_scale: torch.Tensor | None,
) -> QuantAlgorithm:
    """Resolve the MoE quantization algorithm and check device support and weight scales.

    Raises:
        ParametersInvalid: if ``quant_config`` has the wrong type, the algorithm
            is unsupported (globally or on the current NPU), or the weight
            scales do not match the quantization mode.
    """
    if quant_config is not None and not isinstance(quant_config, QuantConfig):
        raise ParametersInvalid(f"quant_config must be a QuantConfig or None, but got {type(quant_config)}.")
    # No config, or a config whose quant_algo is falsy (e.g. None), means "no quantization".
    quant_algo = (
        QuantAlgorithm.NO_QUANT
        if quant_config is None
        else (quant_config.quant_algo or QuantAlgorithm.NO_QUANT)
    )
    if quant_algo not in MOE_SUPPORTED_QUANT_ALGOS:
        raise ParametersInvalid(f"Unsupported MoE quantization algorithm: {quant_algo}.")
    # Hardware support is only enforced for NPU-resident inputs.
    if hidden_states.device.type == "npu":
        npu_device = get_npu_device()
        if quant_algo in MOE_INT_QUANT_ALGOS and npu_device not in (NPUDevice.A2, NPUDevice.A3):
            raise ParametersInvalid("MoE integer quantization is only supported on A2 and A3.")
        if quant_algo in MOE_MXFP_QUANT_ALGOS and npu_device != NPUDevice.A5:
            raise ParametersInvalid("MoE MXFP quantization is only supported on A5.")
    if quant_algo == QuantAlgorithm.NO_QUANT and (w13_weight_scale is not None or w2_weight_scale is not None):
        raise ParametersInvalid("w13_weight_scale and w2_weight_scale must be None.")
    if quant_algo != QuantAlgorithm.NO_QUANT and (w13_weight_scale is None or w2_weight_scale is None):
        raise ParametersInvalid("w13_weight_scale and w2_weight_scale must be provided.")
    # When present, scales must be tensors, mirroring the check applied to the biases.
    _require_optional_tensor("w13_weight_scale", w13_weight_scale)
    _require_optional_tensor("w2_weight_scale", w2_weight_scale)
    return quant_algo


def _validate_moe_shapes(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    num_experts: int,
    w13_weight: torch.Tensor,
    w2_weight: torch.Tensor,
    w13_bias: torch.Tensor | None,
    w2_bias: torch.Tensor | None,
) -> None:
    """Check that activation, router and expert-weight shapes are mutually consistent.

    The layouts enforced below are::

        hidden_states: [..., hidden_size]
        router_logits: [..., num_experts]   (same leading dims as hidden_states)
        w13_weight:    [local_experts, hidden_size, 2 * intermediate_size]
        w2_weight:     [local_experts, intermediate_size, hidden_size]
        w13_bias:      [local_experts, 2 * intermediate_size]   (optional)
        w2_bias:       [local_experts, hidden_size]             (optional)
    """
    if hidden_states.dim() < 2:
        raise ParametersInvalid(f"hidden_states must be at least 2D, but got shape {tuple(hidden_states.shape)}.")
    if router_logits.dim() < 2:
        raise ParametersInvalid(f"router_logits must be at least 2D, but got shape {tuple(router_logits.shape)}.")
    if w13_weight.dim() != 3:
        raise ParametersInvalid(f"w13_weight must be a 3D tensor, but got shape {tuple(w13_weight.shape)}.")
    if w2_weight.dim() != 3:
        raise ParametersInvalid(
            "w2_weight must be a 3D tensor with shape "
            "[local_experts, intermediate_size, hidden_size], "
            f"but got shape {tuple(w2_weight.shape)}."
        )
    if router_logits.shape[-1] != num_experts:
        raise ParametersInvalid(
            f"The last dim of router_logits must match num_experts={num_experts}, "
            f"but got shape {tuple(router_logits.shape)}."
        )
    if hidden_states.shape[:-1] != router_logits.shape[:-1]:
        raise ParametersInvalid(
            "hidden_states and router_logits must have the same leading dimensions, "
            f"but got hidden_states={tuple(hidden_states.shape)}, "
            f"router_logits={tuple(router_logits.shape)}."
        )
    # Both weights are known to be 3D at this point, so their shapes unpack cleanly.
    local_experts, w13_in, w13_out = w13_weight.shape
    w2_experts, w2_in, w2_out = w2_weight.shape
    if local_experts < 1:
        raise ParametersInvalid(f"local_num_experts must be positive, but got {local_experts}.")
    if local_experts != w2_experts:
        raise ParametersInvalid(
            "w13_weight and w2_weight must have the same number of local experts, "
            f"but got w13_weight={tuple(w13_weight.shape)}, w2_weight={tuple(w2_weight.shape)}."
        )
    if hidden_states.shape[-1] != w13_in:
        raise ParametersInvalid(
            "hidden size must match w13_weight input dimension, "
            f"but got hidden_states={hidden_states.shape[-1]}, "
            f"w13_weight={w13_in}."
        )
    if w13_in != w2_out:
        raise ParametersInvalid(
            "w13_weight input dimension must match w2_weight output dimension, "
            f"but got w13_weight={w13_in}, w2_weight={w2_out}."
        )
    if w13_out != 2 * w2_in:
        raise ParametersInvalid(
            "w13_weight output dimension must be twice w2_weight input dimension, "
            f"but got w13_weight={w13_out}, w2_weight={w2_in}."
        )
    if w13_bias is not None and w13_bias.shape != w13_weight.shape[:1] + w13_weight.shape[2:]:
        raise ParametersInvalid(
            "w13_bias shape must match w13_weight expert and output dimensions, "
            f"but got bias={tuple(w13_bias.shape)}, weight={tuple(w13_weight.shape)}."
        )
    if w2_bias is not None and w2_bias.shape != w2_weight.shape[:1] + w2_weight.shape[2:]:
        raise ParametersInvalid(
            "w2_bias shape must match w2_weight expert and output dimensions, "
            f"but got bias={tuple(w2_bias.shape)}, weight={tuple(w2_weight.shape)}."
        )


def _validate_moe_expert_partition(
    num_experts: int,
    top_k: int,
    k_group: int,
    group_count: int,
    group_select_mode: int,
    ep_group: dist.ProcessGroup | None,
) -> None:
    """Check that experts split evenly across EP ranks and routing groups, and that top_k is reachable."""
    if ep_group is not None:
        ep_size = dist.get_world_size(ep_group)
        if num_experts % ep_size != 0:
            raise ParametersInvalid(
                "num_experts must be evenly divisible by ep_size, "
                f"but got num_experts={num_experts}, ep_size={ep_size}."
            )
    if num_experts % group_count != 0:
        raise ParametersInvalid(
            "num_experts must be evenly divisible by group_count, "
            f"but got num_experts={num_experts}, group_count={group_count}."
        )
    experts_per_group = num_experts // group_count
    if group_select_mode == 1 and experts_per_group < 2:
        raise ParametersInvalid(
            "group_select_mode=1 requires at least two experts per group, "
            f"but got experts_per_group={experts_per_group}."
        )
    # Experts are chosen from k_group selected groups, so at most
    # k_group * experts_per_group experts are available to top_k.
    if top_k > k_group * experts_per_group:
        raise ParametersInvalid(
            "top_k cannot exceed the number of experts in selected groups, "
            f"but got top_k={top_k}, k_group={k_group}, experts_per_group={experts_per_group}."
        )


def validate_moe_inputs(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    num_experts: int,
    top_k: int,
    w13_weight: torch.Tensor,
    w2_weight: torch.Tensor,
    w13_bias: torch.Tensor | None = None,
    w2_bias: torch.Tensor | None = None,
    quant_config: QuantConfig | None = None,
    w13_weight_scale: torch.Tensor | None = None,
    w2_weight_scale: torch.Tensor | None = None,
    tp_group: dist.ProcessGroup | None = None,
    ep_group: dist.ProcessGroup | None = None,
    dispatcher_type: str | None = None,
    tokens_full: bool = True,
    k_group: int = 1,
    group_count: int = 1,
    group_select_mode: int = 0,
    routing_method: str = "softmax",
    renormalize: bool = False,
    routed_scaling_factor: float = 1.0,
    custom_routing_function: Callable | None = None,
    reduce_results: bool = True,
) -> QuantAlgorithm:
    """Validate public MoE inputs and return the normalized quantization algorithm.

    Checks run in this order, and the first failure raises:

    1. Tensor argument types.
    2. Quantization settings, including NPU device support and weight scales.
    3. Scalar options.
    4. Process-group types.
    5. Tensor shapes.
    6. How experts partition across EP ranks and routing groups.

    Returns:
        The resolved ``QuantAlgorithm``. This is ``QuantAlgorithm.NO_QUANT``
        when ``quant_config`` is ``None`` or its ``quant_algo`` is falsy.

    Raises:
        ParametersInvalid: if any argument has the wrong type, an unsupported
            value, or a shape inconsistent with the other arguments.
    """
    for name, tensor in (
        ("hidden_states", hidden_states),
        ("router_logits", router_logits),
        ("w13_weight", w13_weight),
        ("w2_weight", w2_weight),
    ):
        _require_tensor(name, tensor)
    _require_optional_tensor("w13_bias", w13_bias)
    _require_optional_tensor("w2_bias", w2_bias)

    quant_algo = _resolve_moe_quant_algo(hidden_states, quant_config, w13_weight_scale, w2_weight_scale)

    # group_select_mode is type-checked like the other integer options; otherwise
    # True / 1.0 would slip through the `in (0, 1)` value check below.
    for name, value in (
        ("num_experts", num_experts),
        ("top_k", top_k),
        ("k_group", k_group),
        ("group_count", group_count),
        ("group_select_mode", group_select_mode),
    ):
        _require_int(name, value)
    if not isinstance(routed_scaling_factor, float):
        raise ParametersInvalid(f"routed_scaling_factor must be a float, but got {type(routed_scaling_factor)}.")
    for name, value in (
        ("tokens_full", tokens_full),
        ("renormalize", renormalize),
        ("reduce_results", reduce_results),
    ):
        _require_bool(name, value)
    if dispatcher_type not in (None, "static", "dynamic"):
        raise ParametersInvalid(f"dispatcher_type must be None, 'static', or 'dynamic', but got {dispatcher_type}.")
    if group_select_mode not in (0, 1):
        raise ParametersInvalid(f"group_select_mode must be 0 or 1, but got {group_select_mode}.")
    if routing_method not in ("softmax", "sigmoid"):
        raise ParametersInvalid(f"routing_method must be 'softmax' or 'sigmoid', but got {routing_method}.")
    if custom_routing_function is not None and not callable(custom_routing_function):
        raise ParametersInvalid("custom_routing_function must be callable if provided.")
    if not 1 <= top_k <= num_experts:
        raise ParametersInvalid(f"top_k must be in [1, num_experts], but got top_k={top_k}, num_experts={num_experts}.")
    if k_group < 1:
        raise ParametersInvalid(f"k_group must be positive, but got {k_group}.")
    if group_count < 1:
        raise ParametersInvalid(f"group_count must be positive, but got {group_count}.")
    if k_group > group_count:
        raise ParametersInvalid(f"k_group={k_group} exceeds group_count={group_count}.")

    _require_optional_group("tp_group", tp_group)
    _require_optional_group("ep_group", ep_group)

    _validate_moe_shapes(hidden_states, router_logits, num_experts, w13_weight, w2_weight, w13_bias, w2_bias)
    _validate_moe_expert_partition(num_experts, top_k, k_group, group_count, group_select_mode, ep_group)
    return quant_algo
class MoECommType(Enum):
    """Resolved MoE communication mode.

    ``set_moe_comm_context`` selects ``EP`` when the expert-parallel group has
    more than one rank. Otherwise it selects ``TP`` when the tensor-parallel
    group has more than one rank, and ``NONE`` when neither does.
    """

    NONE = "none"  # no MoE process group; get_moe_group() returns None
    TP = "tp"  # get_moe_group() returns the tensor-parallel group
    EP = "ep"  # get_moe_group() returns the expert-parallel group
# Mutable module state describing the current MoE invocation. It is written by
# set_moe_comm_context / set_moe_quant_algo and read by the get_moe_* / is_moe_*
# accessors.
_TP_GROUP = _EP_GROUP = None  # process groups recorded by set_moe_comm_context
_MOE_COMM_TYPE = MoECommType.NONE  # no communication until a context is set
_MOE_QUANT_ALGO = QuantAlgorithm.NO_QUANT  # unquantized until set_moe_quant_algo is called
def set_moe_comm_context(tp_group=None, ep_group=None) -> None:
    """Record the MoE process groups and resolve which one drives communication.

    Both groups are stored unconditionally. The communication mode is then:

    - ``EP`` if ``ep_group`` spans more than one rank;
    - else ``TP`` if ``tp_group`` spans more than one rank;
    - else ``NONE``.
    """
    global _TP_GROUP, _EP_GROUP, _MOE_COMM_TYPE

    def _spans_multiple_ranks(group) -> bool:
        # World size is only queried for groups that were actually provided.
        return group is not None and dist.get_world_size(group) > 1

    _TP_GROUP, _EP_GROUP = tp_group, ep_group
    if _spans_multiple_ranks(ep_group):
        resolved = MoECommType.EP
    elif _spans_multiple_ranks(tp_group):
        resolved = MoECommType.TP
    else:
        resolved = MoECommType.NONE
    _MOE_COMM_TYPE = resolved
def get_moe_comm_type() -> MoECommType:
    """Return the communication mode held in the module context.

    The mode is ``MoECommType.NONE`` until ``set_moe_comm_context`` selects otherwise.
    """
    current = _MOE_COMM_TYPE
    return current
def set_moe_quant_algo(quant_algo: QuantAlgorithm = QuantAlgorithm.NO_QUANT) -> None:
    """Record ``quant_algo`` as the quantization algorithm for the current MoE invocation.

    The value is stored as given. It is not checked against
    ``MOE_SUPPORTED_QUANT_ALGOS`` here.
    """
    global _MOE_QUANT_ALGO
    _MOE_QUANT_ALGO = quant_algo
def get_moe_quant_algo() -> QuantAlgorithm:
    """Return the quantization algorithm held in the module context.

    The value is ``QuantAlgorithm.NO_QUANT`` until ``set_moe_quant_algo`` stores another one.
    """
    current = _MOE_QUANT_ALGO
    return current
def dynamic_quant(hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Dynamically quantize ``hidden_states`` for the active MoE quantization algorithm.

    - ``W8A8_DYNAMIC``: ``torch_npu.npu_dynamic_quant`` with ``dst_type=torch.int8``.
    - ``W8A8_MXFP8``: ``torch_npu.npu_dynamic_mx_quant`` with ``dst_type=torch.float8_e4m3fn``.

    Returns:
        The result of the selected torch_npu op. Per the annotation, this is
        presumably a (quantized tensor, scale) pair.

    Raises:
        ParametersInvalid: for any other algorithm, including ``NO_QUANT``.
    """
    algo = _MOE_QUANT_ALGO
    if algo == QuantAlgorithm.W8A8_DYNAMIC:
        quantize, target_dtype = torch_npu.npu_dynamic_quant, torch.int8
    elif algo == QuantAlgorithm.W8A8_MXFP8:
        quantize, target_dtype = torch_npu.npu_dynamic_mx_quant, torch.float8_e4m3fn
    else:
        raise ParametersInvalid(f"Unsupported MoE quantization algorithm: {_MOE_QUANT_ALGO}.")
    return quantize(hidden_states, dst_type=target_dtype)
def get_init_routing_quant_mode(dynamic_scale: torch.Tensor | None) -> int:
    """Return npu_moe_init_routing_v2 quant mode: -1 no quant, 1 dynamic INT8, 3 MXFP8.

    Mode -1 is returned when MoE quantization is off, or when the caller
    already supplies ``dynamic_scale``. In the latter case the activations are
    presumably quantized already; TODO confirm against callers.

    Raises:
        ParametersInvalid: if quantization is on but the algorithm has no mode mapping.
    """
    if is_moe_quant() and dynamic_scale is None:
        if _MOE_QUANT_ALGO == QuantAlgorithm.W8A8_DYNAMIC:
            return 1
        if _MOE_QUANT_ALGO == QuantAlgorithm.W8A8_MXFP8:
            return 3
        raise ParametersInvalid(f"Unsupported MoE quantization algorithm: {_MOE_QUANT_ALGO}.")
    return -1
def is_moe_quant() -> bool:
    """Tell whether the recorded MoE algorithm is anything other than ``NO_QUANT``."""
    unquantized = QuantAlgorithm.NO_QUANT
    return _MOE_QUANT_ALGO != unquantized
def is_moe_int_quant() -> bool:
    """Tell whether the recorded MoE algorithm belongs to ``MOE_INT_QUANT_ALGOS``."""
    integer_algos = MOE_INT_QUANT_ALGOS
    return _MOE_QUANT_ALGO in integer_algos
def is_moe_mxfp_quant() -> bool:
    """Tell whether the recorded MoE algorithm belongs to ``MOE_MXFP_QUANT_ALGOS``."""
    mxfp_algos = MOE_MXFP_QUANT_ALGOS
    return _MOE_QUANT_ALGO in mxfp_algos
def set_moe_context(tp_group=None, ep_group=None, quant_algo: QuantAlgorithm = QuantAlgorithm.NO_QUANT) -> None:
    """Populate the whole MoE module context in one call.

    The communication context (process groups and resolved comm type) is set
    first. The quantization algorithm is set second.
    """
    set_moe_comm_context(tp_group, ep_group)
    set_moe_quant_algo(quant_algo)
def get_moe_group():
"""Return the process group selected for MoE communication."""
if _MOE_COMM_TYPE == MoECommType.EP:
return _EP_GROUP
if _MOE_COMM_TYPE == MoECommType.TP:
return _TP_GROUP
return None
def build_prepare_input(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    tokens_full: bool = True,
) -> MoEPrepareInput:
    """Wrap the prepare-stage arguments in a ``MoEPrepareInput``; values are passed through unchanged."""
    prepare_kwargs = {
        "hidden_states": hidden_states,
        "router_logits": router_logits,
        "tokens_full": tokens_full,
    }
    return MoEPrepareInput(**prepare_kwargs)
def build_routing_input(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    k_group: int = 1,
    group_count: int = 1,
    group_select_mode: int = 0,
    routing_method: str = "softmax",
    renormalize: bool = False,
    routed_scaling_factor: float = 1.0,
    custom_routing_function: Callable | None = None,
) -> MoERoutingInput:
    """Wrap expert-selection arguments in a ``MoERoutingInput``.

    ``routing_method`` is translated into the integer ``norm_type``: 0 for
    ``"softmax"``, 1 for anything else. ``validate_moe_inputs`` only admits
    ``"sigmoid"`` as the alternative.
    """
    if routing_method == "softmax":
        norm_type = 0
    else:
        norm_type = 1
    return MoERoutingInput(
        hidden_states=hidden_states,
        router_logits=router_logits,
        top_k=top_k,
        renormalize=renormalize,
        k_group=k_group,
        group_count=group_count,
        group_select_mode=group_select_mode,
        norm_type=norm_type,
        routed_scaling_factor=routed_scaling_factor,
        custom_routing_function=custom_routing_function,
    )
def build_moe_weights(
    w13_weight: torch.Tensor,
    w2_weight: torch.Tensor,
    w13_bias: torch.Tensor | None = None,
    w2_bias: torch.Tensor | None = None,
    w13_weight_scale: torch.Tensor | None = None,
    w2_weight_scale: torch.Tensor | None = None,
) -> MoEWeights:
    """Bundle expert weights, optional biases and optional quantization scales into ``MoEWeights``."""
    weight_fields = {
        "w13_weight": w13_weight,
        "w2_weight": w2_weight,
        "w13_bias": w13_bias,
        "w2_bias": w2_bias,
        "w13_weight_scale": w13_weight_scale,
        "w2_weight_scale": w2_weight_scale,
    }
    return MoEWeights(**weight_fields)
def build_token_dispatch_input(
    hidden_states: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    num_experts: int,
    top_k: int,
    weights: MoEWeights,
    dynamic_scale: torch.Tensor | None = None,
) -> MoETokenDispatchInput:
    """Wrap token-dispatch arguments in a ``MoETokenDispatchInput``.

    ``local_num_experts`` is taken from the leading dimension of ``weights.w13_weight``.
    """
    local_num_experts = weights.w13_weight.shape[0]
    return MoETokenDispatchInput(
        hidden_states=hidden_states,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        num_experts=num_experts,
        top_k=top_k,
        local_num_experts=local_num_experts,
        dynamic_scale=dynamic_scale,
    )
def build_mlp_compute_input(
    dispatch_output: MoETokenDispatchOutput,
    weights: MoEWeights,
    mlp_output_dtype: torch.dtype,
) -> MoEMlpComputeInput:
    """Assemble the grouped-MLP input.

    The result combines the dispatched tokens (with their dynamic scale and
    group layout) from ``dispatch_output`` with the expert ``weights`` and the
    requested output dtype.
    """
    dispatched = dispatch_output
    return MoEMlpComputeInput(
        hidden_states=dispatched.hidden_states,
        dynamic_scale=dispatched.dynamic_scale,
        group_list=dispatched.group_list,
        group_list_type=dispatched.group_list_type,
        weights=weights,
        mlp_output_dtype=mlp_output_dtype,
    )