from typing import Optional, Callable
import torch
from torch.nn import Parameter
import torch_npu
from megatron.core import parallel_state
from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding
from megatron.core.enums import Fp8Recipe
from megatron.core.extensions.transformer_engine import condition_init_method
from megatron.core.parallel_state import (
get_expert_model_parallel_world_size,
get_expert_model_parallel_rank,
get_expert_data_parallel_rank,
get_expert_tensor_parallel_group,
get_expert_tensor_parallel_world_size,
get_tensor_model_parallel_group,
get_tensor_model_parallel_world_size,
)
from megatron.core.tensor_parallel.layers import _initialize_affine_weight_cpu, _initialize_affine_weight_gpu
from megatron.core.transformer.mlp import apply_swiglu_sharded_factory
from megatron.core.transformer.moe.experts import expert_dist_ckpt_decorator
from megatron.core.transformer.utils import sharded_state_dict_default, make_sharded_tensors_for_checkpoint
from mindspeed.args_utils import get_full_args as get_args
from mindspeed.core.transformer.moe.grouped_matmul_util import MXFP8GMMFunction, MXFP832x32GMMFunction
def _get_partition_dim(parallel_mode):
    """Return the weight dimension sharded by ``parallel_mode``.

    Weights in this module are laid out as ``(output_size, input_size)``, so
    ``"column"`` maps to 0 (output features), ``"row"`` maps to 1 (input
    features), and any other value (including ``None``) maps to -1.
    """
    if parallel_mode == "column":
        return 0
    return 1 if parallel_mode == "row" else -1
def _set_explicit_expert_comm_attrs(param, partition_dim):
    """Tag ``param`` with tensor-parallel metadata for explicit expert communication.

    Sets ``tensor_model_parallel=False``, ``partition_dim=partition_dim`` and
    ``partition_stride=1``, in that order, overwriting any existing values.
    """
    attributes = {
        "tensor_model_parallel": False,
        "partition_dim": partition_dim,
        "partition_stride": 1,
    }
    for attr_name, attr_value in attributes.items():
        setattr(param, attr_name, attr_value)
class MindSpeedTEGroupedLinearGMM(torch.autograd.Function):
    """Grouped (per-expert) linear autograd function built on ``torch_npu.npu_grouped_matmul``.

    Forward multiplies each token group of the input by its matching weight from
    ``weight_input_T``; callers in this file pass ``w.T.contiguous()`` for every
    expert weight ``w``. The untransposed weights (``ori_weight``) are saved so the
    backward pass can compute the input gradient without re-transposing.
    """

    @staticmethod
    def forward(
        ctx, input_tensor: torch.Tensor, m_split=None, group_list_type=None, ori_weight=None, *weight_input_T
    ) -> torch.Tensor:
        """Run the grouped matmul.

        Args:
            ctx: autograd context.
            input_tensor: tokens, grouped contiguously per expert (assumed — confirm with caller).
            m_split: group sizes; converted to an int64 NPU tensor unless already a tensor.
            group_list_type: forwarded verbatim to ``npu_grouped_matmul`` (the caller in
                this file passes 1, presumably meaning per-group counts — see torch_npu docs).
            ori_weight: iterable of original (untransposed) weights, saved for backward.
            *weight_input_T: the weights actually used in the forward matmul.

        Returns:
            The first output of ``npu_grouped_matmul`` (``split_item=2``).
        """
        ctx.group_list = (
            m_split
            if isinstance(m_split, torch.Tensor)
            else torch.tensor(m_split, device='npu', dtype=torch.int64)
        )
        ctx.group_list_type = group_list_type
        outputs = torch_npu.npu_grouped_matmul(
            [input_tensor],
            weight_input_T,
            bias=None,
            group_list=ctx.group_list,
            split_item=2,
            group_type=0,
            group_list_type=ctx.group_list_type,
        )
        ctx.save_for_backward(input_tensor, *ori_weight)
        return outputs[0]

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for (input_tensor, m_split, group_list_type, ori_weight, *weight_input_T).

        Only ``input_tensor`` and each entry of ``weight_input_T`` receive gradients;
        the three middle arguments get ``None``.
        """
        saved = ctx.saved_tensors
        saved_input, saved_weights = saved[0], saved[1:]
        groups = ctx.group_list
        list_type = ctx.group_list_type

        # Input gradient: grad_output multiplied group-wise by the saved original weights.
        grad_input = torch_npu.npu_grouped_matmul(
            [grad_output],
            saved_weights,
            bias=None,
            group_list=groups,
            split_item=2,
            group_type=0,
            group_list_type=list_type,
        )[0]
        # Weight gradients: input^T times grad_output with group_type=2; the result is
        # unpacked along its first dimension, one entry per weight (per-group grads, presumably).
        grad_weights = torch_npu.npu_grouped_matmul(
            [saved_input.T],
            [grad_output],
            bias=None,
            group_list=groups,
            split_item=3,
            group_type=2,
            group_list_type=list_type,
        )[0]
        return (grad_input, None, None, None, *grad_weights)
def _grouped_linear_forward(cls, ctx, input_tensor, m_split, reuse_identity, *weight_views):
    """Shared forward body for the MXFP8 grouped-linear autograd functions.

    Converts ``m_split`` to an int64 NPU tensor when it is not already a tensor,
    stores its cumulative sum on ``ctx.group_list``, stacks ``weight_views`` along a
    new leading dim, and delegates the matmul to ``cls.op_forward``. The input and
    the stacked weight are saved for backward.

    Returns:
        Element 0 of ``cls.op_forward``'s result.
    """
    if isinstance(m_split, torch.Tensor):
        counts = m_split
    else:
        counts = torch.tensor(m_split, device='npu', dtype=torch.int64)
    ctx.group_list = counts.cumsum(dim=0)
    stacked_weight = torch.stack(weight_views, dim=0)
    op_result = cls.op_forward(ctx, input_tensor, stacked_weight, ctx.group_list, reuse_identity=reuse_identity)
    ctx.save_for_backward(input_tensor, stacked_weight)
    return op_result[0]
def _grouped_linear_backward(cls, ctx, grad_outputs):
    """Shared backward body for the MXFP8 grouped-linear autograd functions.

    Uses the input and stacked weight saved by ``_grouped_linear_forward`` and
    calls ``cls.op_dx`` before ``cls.op_dw``.

    Returns:
        ``(grad_input, None, None, *grad_weights)``, matching the forward inputs
        ``(input_tensor, m_split, reuse_identity, *weight_views)``.
    """
    saved_input, saved_weight = ctx.saved_tensors
    groups = ctx.group_list
    grad_input = cls.op_dx(ctx, grad_outputs, saved_weight, groups)[0]
    grad_weights = cls.op_dw(ctx, saved_input, grad_outputs, groups)[0]
    return (grad_input, None, None, *grad_weights)
class MindSpeedTEGroupedLinearMXFP8GMM(MXFP8GMMFunction):
    """Grouped-linear autograd function for the MXFP8 recipe.

    Both passes delegate to the module-level ``_grouped_linear_forward`` /
    ``_grouped_linear_backward`` helpers. Those helpers call ``cls.op_forward``,
    ``cls.op_dx`` and ``cls.op_dw``, which this class does not define and so
    presumably inherits from ``MXFP8GMMFunction``.
    """

    @classmethod
    def forward(cls, ctx, input_tensor: torch.Tensor, m_split=None, reuse_identity=None, *weight_views) -> torch.Tensor:
        """Grouped forward pass; see ``_grouped_linear_forward``."""
        output = _grouped_linear_forward(cls, ctx, input_tensor, m_split, reuse_identity, *weight_views)
        return output

    @classmethod
    def backward(cls, ctx, grad_outputs):
        """Return ``(grad_input, None, None, *grad_weights)``; see ``_grouped_linear_backward``."""
        gradients = _grouped_linear_backward(cls, ctx, grad_outputs)
        return gradients
class MindSpeedTEGroupedLinearMXFP832x32GMM(MXFP832x32GMMFunction):
    """Grouped-linear autograd function for the MXFP8 32x32 recipe.

    Both passes delegate to the module-level ``_grouped_linear_forward`` /
    ``_grouped_linear_backward`` helpers. Those helpers call ``cls.op_forward``,
    ``cls.op_dx`` and ``cls.op_dw``, which this class does not define and so
    presumably inherits from ``MXFP832x32GMMFunction``.
    """

    @classmethod
    def forward(cls, ctx, input_tensor: torch.Tensor, m_split=None, reuse_identity=None, *weight_views) -> torch.Tensor:
        """Grouped forward pass; see ``_grouped_linear_forward``."""
        output = _grouped_linear_forward(cls, ctx, input_tensor, m_split, reuse_identity, *weight_views)
        return output

    @classmethod
    def backward(cls, ctx, grad_outputs):
        """Return ``(grad_input, None, None, *grad_weights)``; see ``_grouped_linear_backward``."""
        gradients = _grouped_linear_backward(cls, ctx, grad_outputs)
        return gradients
class MindSpeedTEGroupedLinear(torch.nn.Module):
    """Grouped (per-expert) linear layer for Ascend NPUs, standing in for TE's ``GroupedLinear``.

    Registers ``num_gemms`` weight parameters named ``weight0`` ... ``weight{num_gemms - 1}``,
    each of shape ``(self.output_size, self.input_size)``. For expert layers that use expert
    tensor parallelism and/or expert parallelism (``explicit_expert_comm``), the sharded
    dimension is divided by the tensor-parallel size at construction time.

    Note: no bias parameters are ever created. ``bias`` is only recorded in ``self.use_bias``,
    and :meth:`forward` always returns ``None`` as its bias output.
    """

    def __init__(
        self,
        num_gemms: int,
        input_size: int,
        output_size: int,
        *,
        parallel_mode: Optional[str],
        config,
        init_method: Callable,
        bias: bool,
        skip_bias_add: bool,
        is_expert: bool = False,
        tp_comm_buffer_name: Optional[str] = None,
        **kwargs,
    ):
        """
        Args:
            num_gemms: number of independent weight matrices (one per local expert).
            input_size: full (unsharded) input feature size.
            output_size: full (unsharded) output feature size.
            parallel_mode: "column", "row" or None; selects the sharded weight dim.
            config: transformer config; reads ``expert_model_parallel_size``,
                ``use_cpu_initialization``, ``params_dtype`` and ``perform_initialization``.
            init_method: weight initializer.
            bias: stored as ``use_bias``. No bias parameter is created (see class docstring).
            skip_bias_add: combined with ``bias`` into ``te_return_bias``.
            is_expert: use the expert tensor-parallel group/size instead of the dense one.
            tp_comm_buffer_name: accepted for interface compatibility; unused here.
            **kwargs: accepted for interface compatibility; unused here.

        Raises:
            AssertionError: if the sharded size is not divisible by the tensor-parallel size.
        """
        super().__init__()
        self.num_gemms = num_gemms
        self.config = config
        self.te_return_bias = skip_bias_add and bias
        self.is_first_microbatch = True
        self.use_bias = bias
        self.output_size = output_size
        self.input_size = input_size
        self.partition_dim = _get_partition_dim(parallel_mode)
        self.parallel_mode = parallel_mode
        if is_expert:
            tp_group = get_expert_tensor_parallel_group(check_initialized=False)
            tp_size = get_expert_tensor_parallel_world_size()
        else:
            tp_group = get_tensor_model_parallel_group(check_initialized=False)
            tp_size = get_tensor_model_parallel_world_size()
        self.expert_parallel = self.config.expert_model_parallel_size > 1
        # Expert layers split over TP and/or EP store only their local shard: the
        # sharded dimension is pre-divided by tp_size below.
        self.explicit_expert_comm = is_expert and (tp_size > 1 or self.expert_parallel)
        if self.explicit_expert_comm:
            if parallel_mode == "column":
                if output_size % tp_size != 0:
                    raise AssertionError("{} is not divisible by {}".format(output_size, tp_size))
                self.output_size = output_size // tp_size
                self.input_size = input_size
            elif parallel_mode == "row":
                if input_size % tp_size != 0:
                    raise AssertionError("{} is not divisible by {}".format(input_size, tp_size))
                self.output_size = output_size
                self.input_size = input_size // tp_size
            self.tp_size = 1
            self.tp_group = None
        # Plain-list view of the registered weight parameters, used by forward().
        self.total_weight = []
        for i in range(self.num_gemms):
            expert_weight = Parameter(
                torch.empty(
                    self.output_size,
                    self.input_size,
                    device=torch.device('cpu') if self.config.use_cpu_initialization else torch.npu.current_device(),
                    dtype=config.params_dtype,
                )
            )
            self.register_parameter('weight{}'.format(i), expert_weight)
            if self.config.perform_initialization:
                if self.config.use_cpu_initialization:
                    _initialize_affine_weight_cpu(
                        expert_weight,
                        output_size,
                        input_size,
                        self.output_size if parallel_mode == "column" else self.input_size,
                        partition_dim=self.partition_dim,
                        init_method=init_method,
                        stride=1,
                        rank=torch.distributed.get_rank(tp_group),
                        world_size=tp_size,
                    )
                else:
                    _initialize_affine_weight_gpu(
                        expert_weight, init_method, partition_dim=self.partition_dim, stride=1, is_expert=is_expert
                    )
            # Overwrite the tensor-parallel attributes set during initialization.
            if self.explicit_expert_comm and parallel_mode in ("column", "row"):
                _set_explicit_expert_comm_attrs(expert_weight, self.partition_dim)
            self.total_weight.append(expert_weight)
        # Expert params under expert parallelism are flagged allreduce=False
        # (presumably reduced over the expert-data-parallel group instead — confirm).
        for param in self.parameters():
            setattr(param, 'allreduce', not (is_expert and self.expert_parallel))

    def forward(self, x, m_splits):
        """Apply each expert weight to its group of tokens in ``x``.

        Args:
            x: input tokens, assumed grouped contiguously per gemm (confirm with caller).
            m_splits: tokens per gemm, as a list or tensor.

        Returns:
            ``(output, None)``. No bias is ever returned.

        The MXFP8 autograd functions are used only when ``args.fp8`` is truthy and
        ``args.fp8_recipe`` is ``mxfp8`` or ``mxfp8_32x32``. Otherwise the plain
        ``npu_grouped_matmul`` path is used.
        """
        args = get_args()
        fp8_recipe = getattr(args, "fp8_recipe", False)
        use_mxfp8 = getattr(args, "fp8", False) and fp8_recipe in (
            Fp8Recipe.mxfp8,
            Fp8Recipe.mxfp8_32x32,
        )
        # The transposed copies are kept local (not stored on self) so they can be
        # freed once the autograd function has consumed them, instead of holding an
        # extra expert-weight-sized buffer per layer until the next forward.
        weights_t = [w.T.contiguous() for w in self.total_weight]
        if not use_mxfp8:
            group_list_type = 1
            output = MindSpeedTEGroupedLinearGMM.apply(
                x, m_splits, group_list_type, self.total_weight, *weights_t
            )
        elif fp8_recipe == Fp8Recipe.mxfp8:
            # NOTE(review): the 3rd argument binds to the function's ``reuse_identity``
            # parameter; confirm passing the weight list there is intended.
            output = MindSpeedTEGroupedLinearMXFP8GMM.apply(
                x, m_splits, self.total_weight, *weights_t
            )
        else:
            output = MindSpeedTEGroupedLinearMXFP832x32GMM.apply(
                x, m_splits, self.total_weight, *weights_t
            )
        return output, None

    def _sharded_state_dict_grouped(self, tp_axis_map, prefix='', sharded_offsets=(), metadata=None):
        """Build a sharded state dict whose keys match the sequential (per-expert) layout.

        ``prefix`` should be the module name, so that keys are identical to the
        sequential ones. Each local gemm ``i`` is placed at global expert index
        ``ep_rank * num_gemms + i`` on a new expert axis appended after
        ``sharded_offsets``.

        Args:
            tp_axis_map: maps ``'{i}.weight'`` / ``'{i}.bias'`` to their TP-sharded axis.
            prefix: key prefix for the produced entries.
            sharded_offsets: offsets inherited from parent modules.
            metadata: unused here; kept for interface compatibility.

        Raises:
            AssertionError: if a produced replica_id is not a 3-tuple (PP, TP, DP).
        """
        sharded_state_dict = {}
        full_state_dict = self.state_dict(prefix='', keep_vars=True)
        num_global_experts = get_expert_model_parallel_world_size() * self.num_gemms
        local_expert_indices_offset = get_expert_model_parallel_rank() * self.num_gemms
        ep_axis = len(sharded_offsets)
        for gemm_idx in range(self.num_gemms):
            state_dict = {
                f'{gemm_idx}.weight': full_state_dict[f'weight{gemm_idx}'],
            }
            # __init__ never registers bias parameters, so only export a bias when one
            # actually exists; otherwise use_bias=True would raise KeyError here.
            bias_key = f'bias{gemm_idx}'
            has_bias = self.use_bias and bias_key in full_state_dict
            if has_bias:
                state_dict[f'{gemm_idx}.bias'] = full_state_dict[bias_key]
            sub_sd = make_sharded_tensors_for_checkpoint(
                state_dict,
                '',
                tp_axis_map,
                (
                    *sharded_offsets,
                    (ep_axis, local_expert_indices_offset + gemm_idx, num_global_experts),
                ),
            )
            replace_prefix_for_sharding(sub_sd, f'{gemm_idx}.', prefix)
            sharded_state_dict[f'{prefix}weight{gemm_idx}'] = sub_sd[f'{gemm_idx}.weight']
            if has_bias:
                sharded_state_dict[f'{prefix}bias{gemm_idx}'] = sub_sd[f'{gemm_idx}.bias']
        # Replace the DP component of each replica_id with the expert-data-parallel
        # rank (or 0 for tensors flagged as fully sharded over data parallel).
        for k, sh_ten in sharded_state_dict.items():
            replica_id = sh_ten.replica_id
            if len(replica_id) != 3:
                raise AssertionError(
                    f'Expected replica_id for {k} to be in (PP, TP, DP) format, got: {replica_id}'
                )
            if getattr(sh_ten, "is_data_parallel_fully_shard", False):
                edp_replica_id = 0
            else:
                edp_replica_id = get_expert_data_parallel_rank()
            sh_ten.replica_id = (*replica_id[:2], edp_replica_id)
        return sharded_state_dict
@expert_dist_ckpt_decorator
def sharded_state_dict(self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None):
    """Map local experts onto global experts for distributed checkpointing.

    The resulting sharded state dict is interchangeable with SequentialMLP's.

    ``self`` is expected to be an MLP-like module. This function reads ``self._modules``,
    ``self.config.gated_linear_unit`` and ``self.num_local_experts``; it is presumably
    bound as a grouped-MLP ``sharded_state_dict`` (confirm at the patch site).

    For a gated ``linear_fc1``, each local expert's weight/bias entry is wrapped with
    ``apply_swiglu_sharded_factory`` using that expert's global offset. Every child's
    entries are then re-prefixed to ``{prefix}experts.{child}.``.
    """
    merged = {}
    for child_name, child in self._modules.items():
        child_sd = sharded_state_dict_default(child, f'{child_name}.', sharded_offsets, metadata)
        if child_name == 'linear_fc1' and self.config.gated_linear_unit:
            ep_world = parallel_state.get_expert_model_parallel_world_size()
            ep_rank = parallel_state.get_expert_model_parallel_rank()
            n_local = self.num_local_experts
            total_experts = ep_world * n_local
            first_global_idx = ep_rank * n_local
            expert_axis = len(sharded_offsets)
            for local_idx in range(n_local):
                expert_offsets = (
                    *sharded_offsets,
                    (expert_axis, first_global_idx + local_idx, total_experts),
                )
                for key in (f'{child_name}.weight{local_idx}', f'{child_name}.bias{local_idx}'):
                    if key in child_sd:
                        child_sd[key] = apply_swiglu_sharded_factory(child_sd[key], expert_offsets)
        replace_prefix_for_sharding(child_sd, f'{child_name}.', f'{prefix}experts.{child_name}.')
        merged.update((f"{prefix}{key}", value) for key, value in child_sd.items())
    return merged
class MindSpeedTEColumnParallelGroupedLinear(MindSpeedTEGroupedLinear):
    """Column-parallel variant of ``MindSpeedTEGroupedLinear``.

    Mirrors Transformer-Engine's column-parallel ``GroupedLinear`` interface: the base
    class is built with ``parallel_mode="column"``, and the init method is first passed
    through ``condition_init_method``.
    """

    def __init__(
        self,
        num_gemms: int,
        input_size: int,
        output_size: int,
        *,
        config,
        init_method: Callable,
        bias: bool,
        skip_bias_add: bool,
        is_expert: bool,
        tp_comm_buffer_name: Optional[str] = None,
    ):
        conditioned_init = condition_init_method(config, init_method)
        super().__init__(
            num_gemms=num_gemms,
            input_size=input_size,
            output_size=output_size,
            parallel_mode="column",
            config=config,
            init_method=conditioned_init,
            bias=bias,
            skip_bias_add=skip_bias_add,
            is_expert=is_expert,
            tp_comm_buffer_name=tp_comm_buffer_name,
        )

    def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
        """Shard every gemm's weight and bias along axis 0.

        Assumes ``sharded_offsets[-1]`` is the expert-parallel offset.
        """
        tp_axis_map = {
            f'{gemm_idx}.{param_name}': 0
            for gemm_idx in range(self.num_gemms)
            for param_name in ('weight', 'bias')
        }
        return super()._sharded_state_dict_grouped(tp_axis_map, prefix, sharded_offsets, metadata)
class MindSpeedTERowParallelGroupedLinear(MindSpeedTEGroupedLinear):
    """Row-parallel variant of ``MindSpeedTEGroupedLinear``.

    Mirrors Transformer-Engine's row-parallel ``GroupedLinear`` interface: the base
    class is built with ``parallel_mode="row"``, and the init method is first passed
    through ``condition_init_method``.
    """

    def __init__(
        self,
        num_gemms: int,
        input_size: int,
        output_size: int,
        *,
        config,
        init_method: Callable,
        bias: bool,
        skip_bias_add: bool,
        is_expert: bool,
        tp_comm_buffer_name: Optional[str] = None,
    ):
        conditioned_init = condition_init_method(config, init_method)
        super().__init__(
            num_gemms=num_gemms,
            input_size=input_size,
            output_size=output_size,
            parallel_mode="row",
            config=config,
            init_method=conditioned_init,
            bias=bias,
            skip_bias_add=skip_bias_add,
            is_expert=is_expert,
            tp_comm_buffer_name=tp_comm_buffer_name,
        )

    def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
        """Shard every gemm's weight along axis 1; biases get no TP axis.

        Assumes ``sharded_offsets[-1]`` is the expert-parallel offset.
        """
        weight_keys = (f'{gemm_idx}.weight' for gemm_idx in range(self.num_gemms))
        tp_axis_map = dict.fromkeys(weight_keys, 1)
        return super()._sharded_state_dict_grouped(tp_axis_map, prefix, sharded_offsets, metadata)
def mindspeed_groupedmlp_weighted_bias_swiglu_impl(x, bias, probs, fp8_input_store=False):
    """MindSpeed patch for TEGroupedMLP's weighted SwiGLU activation.

    Uses Ascend's ``fused_swiglu`` in place of ``weighted_bias_swiglu_impl`` for better
    performance. The activation is scaled by ``probs`` and cast back to ``x``'s original
    dtype.

    Args:
        x: activation input.
        bias: must be ``None``; a bias is not supported.
        probs: per-token routing weights, multiplied into the activation output.
        fp8_input_store: accepted for signature compatibility; unused here.

    Raises:
        NotImplementedError: if ``bias`` is not ``None``.
    """
    from mindspeed.core.fusions.fused_bias_swiglu import fused_swiglu

    if bias is not None:
        raise NotImplementedError("Bias is not support for weighted swiglu fusion.")
    original_dtype = x.dtype
    return (fused_swiglu(x) * probs).to(original_dtype)