from typing import Optional, Callable
import torch
from torch.nn import Parameter
from megatron.core import parallel_state
from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding
from megatron.core.extensions.transformer_engine import condition_init_method
from megatron.core.parallel_state import (
get_expert_model_parallel_world_size,
get_expert_model_parallel_rank,
get_expert_data_parallel_rank,
get_expert_tensor_parallel_group,
get_expert_tensor_parallel_world_size,
get_tensor_model_parallel_group,
get_tensor_model_parallel_world_size,
)
from megatron.core.tensor_parallel.layers import _initialize_affine_weight_cpu, _initialize_affine_weight_gpu
from megatron.core.transformer.mlp import apply_swiglu_sharded_factory
from megatron.core.transformer.moe.experts import expert_dist_ckpt_decorator
from megatron.core.transformer.utils import sharded_state_dict_default, make_sharded_tensors_for_checkpoint
def _get_partition_dim(parallel_mode):
if parallel_mode == "column":
return 0
if parallel_mode == "row":
return 1
return -1
def _set_explicit_expert_comm_attrs(param, partition_dim):
setattr(param, "tensor_model_parallel", False)
setattr(param, "partition_dim", partition_dim)
setattr(param, "partition_stride", 1)
class MindSpeedTEPerformanceGroupedLinear(torch.nn.Module):
    """Grouped linear layer that runs all local GEMMs through one grouped-matmul call.

    The weights of all ``num_gemms`` GEMMs live in a single 3-D parameter
    ``self.weight`` of shape ``(num_gemms, self.output_size, self.input_size)``.
    When expert communication is explicit (an expert layer with tensor-parallel
    size > 1 or expert parallelism enabled), the partitioned dimension is divided
    by the tensor-parallel size. No bias parameter is ever registered;
    ``forward`` always returns ``None`` as its bias output.
    """

    def __init__(
        self,
        num_gemms: int,
        input_size: int,
        output_size: int,
        *,
        parallel_mode: Optional[str],
        config,
        init_method: Callable,
        bias: bool,
        skip_bias_add: bool,
        is_expert: bool = False,
        tp_comm_buffer_name: Optional[str] = None,
        **kwargs,
    ):
        """Create (and optionally initialize) the grouped weight.

        Args:
            num_gemms: number of GEMMs (local experts) held by this layer.
            input_size: unpartitioned input size of each GEMM.
            output_size: unpartitioned output size of each GEMM.
            parallel_mode: ``"column"``, ``"row"`` or ``None``; selects which
                size is divided when expert communication is explicit.
            config: model config; this method reads
                ``expert_model_parallel_size``, ``use_cpu_initialization``,
                ``params_dtype`` and ``perform_initialization``.
            init_method: callable used to initialize the weight.
            bias: stored as ``use_bias``. No bias parameter is created.
            skip_bias_add: combined with ``bias`` into ``te_return_bias``.
            is_expert: use the expert tensor-parallel group/size instead of
                the dense tensor-parallel ones.
            tp_comm_buffer_name: accepted for interface compatibility; unused.
            **kwargs: accepted for interface compatibility; unused.

        Raises:
            AssertionError: if, with explicit expert communication, the
                partitioned size is not divisible by the tensor-parallel size.
        """
        super().__init__()
        self.num_gemms = num_gemms
        self.config = config
        self.te_return_bias = skip_bias_add and bias
        self.is_first_microbatch = True
        self.use_bias = bias
        self.output_size = output_size
        self.input_size = input_size
        self.partition_dim = _get_partition_dim(parallel_mode)
        self.parallel_mode = parallel_mode

        if is_expert:
            tp_group = get_expert_tensor_parallel_group(check_initialized=False)
            tp_size = get_expert_tensor_parallel_world_size()
        else:
            tp_group = get_tensor_model_parallel_group(check_initialized=False)
            tp_size = get_tensor_model_parallel_world_size()

        self.expert_parallel = self.config.expert_model_parallel_size > 1
        self.explicit_expert_comm = is_expert and (tp_size > 1 or self.expert_parallel)
        if self.explicit_expert_comm:
            # The caller handles TP/EP communication, so this layer only keeps
            # its own partition of the partitioned dimension.
            if parallel_mode == "column":
                if output_size % tp_size != 0:
                    raise AssertionError("{} is not divisible by {}".format(output_size, tp_size))
                self.output_size = output_size // tp_size
                self.input_size = input_size
            elif parallel_mode == "row":
                if input_size % tp_size != 0:
                    raise AssertionError("{} is not divisible by {}".format(input_size, tp_size))
                self.output_size = output_size
                self.input_size = input_size // tp_size
        # This layer never performs tensor-parallel communication itself, so
        # expose a neutral TP size/group. Previously these attributes were only
        # defined on the explicit-comm path, and reading them otherwise raised
        # AttributeError.
        self.tp_size = 1
        self.tp_group = None

        self.weight = Parameter(
            torch.empty(
                self.num_gemms,
                self.output_size,
                self.input_size,
                device=torch.device('cpu') if self.config.use_cpu_initialization else torch.npu.current_device(),
                dtype=config.params_dtype,
            )
        )
        if self.config.perform_initialization:
            if self.config.use_cpu_initialization:
                # Full (unpartitioned) sizes describe the master weight; the
                # local slice is presumably selected via rank/world_size of
                # tp_group (note: the local tp_size, not self.tp_size).
                _initialize_affine_weight_cpu(
                    self.weight,
                    output_size,
                    input_size,
                    self.output_size if parallel_mode == "column" else self.input_size,
                    partition_dim=self.partition_dim,
                    init_method=init_method,
                    stride=self.num_gemms,
                    rank=torch.distributed.get_rank(tp_group),
                    world_size=tp_size,
                )
            else:
                _initialize_affine_weight_gpu(
                    self.weight,
                    init_method,
                    partition_dim=self.partition_dim,
                    stride=self.num_gemms,
                    is_expert=is_expert,
                )
        if self.explicit_expert_comm and parallel_mode in ("column", "row"):
            # Mark the weight as not tensor-model-parallel because expert
            # communication is handled explicitly outside this layer.
            _set_explicit_expert_comm_attrs(self.weight, self.partition_dim)
        for param in self.parameters():
            # Expert weights under expert parallelism are excluded from the
            # regular all-reduce.
            setattr(param, 'allreduce', not (is_expert and self.expert_parallel))

    def forward(self, x, m_splits):
        """Apply all GEMMs to ``x`` with a single grouped-matmul op.

        Args:
            x: stacked input rows for all GEMMs (presumably grouped by
                expert, in ``m_splits`` order; confirm with caller).
            m_splits: per-GEMM split sizes, passed unchanged to the op.

        Returns:
            ``(output, None)``. The ``None`` occupies the bias slot because
            this layer has no bias.
        """
        # NOTE(review): view() reinterprets the stored
        # (num_gemms, output_size, input_size) buffer as
        # (num_gemms, hidden_size, -1) or (num_gemms, -1, hidden_size) without
        # transposing. Presumably this is the layout the gmm op expects.
        # Confirm that checkpoint interop is intended, because
        # _sharded_state_dict_grouped shards weight[i] as (output, input).
        if self.parallel_mode == 'column':
            weight = self.weight.view(self.num_gemms, self.config.hidden_size, -1)
        else:
            weight = self.weight.view(self.num_gemms, -1, self.config.hidden_size)
        # Imported lazily, presumably to avoid an import cycle at module load.
        from mindspeed.core.transformer.moe.grouped_matmul_util import get_gmm_op_cls
        output = get_gmm_op_cls().gmm_apply(x, weight, None, m_splits, self.weight)
        return output, None

    def _sharded_state_dict_grouped(self, tp_axis_map, prefix='', sharded_offsets=(), metadata=None):
        """
        Build one sharded tensor per GEMM, mapped to its global expert index.

        prefix should be module_name to make keys identical to sequential ones.

        Args:
            tp_axis_map: maps ``'{i}.weight'`` / ``'{i}.bias'`` to the axis
                that is tensor-parallel sharded.
            prefix: sharding-key prefix replacing ``'{i}.'``.
            sharded_offsets: offsets of enclosing modules; the expert axis is
                appended after them.
            metadata: unused; kept for interface compatibility.

        Returns:
            dict ``{f'{prefix}weight{i}': ..., f'{prefix}bias{i}': ...}``.
            Bias entries appear only when such parameters exist.

        Raises:
            ValueError: if a produced replica_id is not a 3-tuple (PP, TP, DP).
        """
        sharded_state_dict = {}
        full_state_dict = self.state_dict(prefix='', keep_vars=True)
        num_global_experts = get_expert_model_parallel_world_size() * self.num_gemms
        local_expert_indices_offset = get_expert_model_parallel_rank() * self.num_gemms
        ep_axis = len(sharded_offsets)
        for gemm_idx in range(self.num_gemms):
            state_dict = {
                f'{gemm_idx}.weight': full_state_dict['weight'][gemm_idx],
            }
            # This layer never registers ``bias{i}`` parameters, so indexing
            # them unconditionally raised KeyError whenever use_bias was True.
            # Only shard biases that actually exist.
            bias = full_state_dict.get(f'bias{gemm_idx}') if self.use_bias else None
            if bias is not None:
                state_dict[f'{gemm_idx}.bias'] = bias
            sub_sd = make_sharded_tensors_for_checkpoint(
                state_dict,
                '',
                tp_axis_map,
                (
                    *sharded_offsets,
                    (ep_axis, local_expert_indices_offset + gemm_idx, num_global_experts),
                ),
            )
            replace_prefix_for_sharding(sub_sd, f'{gemm_idx}.', prefix)
            sharded_state_dict[f'{prefix}weight{gemm_idx}'] = sub_sd[f'{gemm_idx}.weight']
            if f'{gemm_idx}.bias' in sub_sd:
                sharded_state_dict[f'{prefix}bias{gemm_idx}'] = sub_sd[f'{gemm_idx}.bias']
        # Replace the DP component of each (PP, TP, DP) replica_id with the
        # expert-data-parallel rank, or 0 for fully sharded tensors.
        for k, sh_ten in sharded_state_dict.items():
            replica_id = sh_ten.replica_id
            if len(replica_id) != 3:
                raise ValueError(f'Expected replica_id for {k} to be in (PP, TP, DP) format, got: {replica_id}')
            if getattr(sh_ten, "is_data_parallel_fully_shard", False):
                edp_replica_id = 0
            else:
                edp_replica_id = get_expert_data_parallel_rank()
            sh_ten.replica_id = (*replica_id[:2], edp_replica_id)
        return sharded_state_dict
@expert_dist_ckpt_decorator
def sharded_state_dict(self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None):
    """
    Map this module's local experts onto global expert indices.

    The result is interchangeable with SequentialMLP's sharded state dict.
    The function reads ``self._modules``, ``self.num_local_experts`` and
    ``self.config.gated_linear_unit``, so it presumably serves as a
    grouped-MLP ``sharded_state_dict``. Confirm where it is installed.

    Args:
        prefix: prefix prepended to every produced key.
        sharded_offsets: offsets of enclosing modules; for gated ``linear_fc1``
            the expert axis is appended after them.
        metadata: forwarded unchanged to ``sharded_state_dict_default``.

    Returns:
        dict keyed by ``f"{prefix}{key}"``, whose sharding keys are re-prefixed
        from ``'{name}.'`` to ``'{prefix}experts.{name}.'``.
    """
    merged = {}
    for module_name, submodule in self._modules.items():
        module_sd = sharded_state_dict_default(submodule, f'{module_name}.', sharded_offsets, metadata)
        is_gated_fc1 = module_name == 'linear_fc1' and self.config.gated_linear_unit
        if is_gated_fc1:
            # Gated (SwiGLU) fc1 weights need a per-expert factory so that both
            # halves are sharded consistently with the global expert index.
            expert_count = self.num_local_experts
            total_experts = parallel_state.get_expert_model_parallel_world_size() * expert_count
            first_global_expert = parallel_state.get_expert_model_parallel_rank() * expert_count
            expert_axis = len(sharded_offsets)
            for local_idx in range(expert_count):
                expert_offsets = (
                    *sharded_offsets,
                    (expert_axis, first_global_expert + local_idx, total_experts),
                )
                candidate_keys = (f'{module_name}.weight{local_idx}', f'{module_name}.bias{local_idx}')
                for key in candidate_keys:
                    if key not in module_sd:
                        continue
                    module_sd[key] = apply_swiglu_sharded_factory(module_sd[key], expert_offsets)
        # Re-prefix so keys line up with the sequential-experts layout.
        replace_prefix_for_sharding(module_sd, f'{module_name}.', f'{prefix}experts.{module_name}.')
        for key, value in module_sd.items():
            merged[f"{prefix}{key}"] = value
    return merged
class MindSpeedTEPerformanceColumnParallelGroupedLinear(MindSpeedTEPerformanceGroupedLinear):
    """
    Column-parallel flavour of ``MindSpeedTEPerformanceGroupedLinear``.

    Fixes ``parallel_mode="column"`` and wraps ``init_method`` with
    ``condition_init_method`` before delegating to the base class. The base
    class stores one grouped weight and runs a grouped-matmul op; it does not
    wrap Transformer-Engine's ``GroupedLinear``.
    """

    def __init__(
        self,
        num_gemms: int,
        input_size: int,
        output_size: int,
        *,
        config,
        init_method: Callable,
        bias: bool,
        skip_bias_add: bool,
        is_expert: bool,
        tp_comm_buffer_name: Optional[str] = None,
    ):
        wrapped_init_method = condition_init_method(config, init_method)
        super().__init__(
            num_gemms=num_gemms,
            input_size=input_size,
            output_size=output_size,
            parallel_mode="column",
            config=config,
            init_method=wrapped_init_method,
            bias=bias,
            skip_bias_add=skip_bias_add,
            is_expert=is_expert,
            tp_comm_buffer_name=tp_comm_buffer_name,
        )

    def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
        """
        Shard every GEMM's weight and bias along axis 0.

        Assumes ``sharded_offsets[-1]`` is the expert-parallel offset.
        """
        column_axes = {
            key: 0
            for gemm_idx in range(self.num_gemms)
            for key in (f'{gemm_idx}.weight', f'{gemm_idx}.bias')
        }
        return super()._sharded_state_dict_grouped(column_axes, prefix, sharded_offsets, metadata)
class MindSpeedTEPerformanceRowParallelGroupedLinear(MindSpeedTEPerformanceGroupedLinear):
    """
    Row-parallel flavour of ``MindSpeedTEPerformanceGroupedLinear``.

    Fixes ``parallel_mode="row"`` and wraps ``init_method`` with
    ``condition_init_method`` before delegating to the base class. The base
    class stores one grouped weight and runs a grouped-matmul op; it does not
    wrap Transformer-Engine's ``GroupedLinear``.
    """

    def __init__(
        self,
        num_gemms: int,
        input_size: int,
        output_size: int,
        *,
        config,
        init_method: Callable,
        bias: bool,
        skip_bias_add: bool,
        is_expert: bool,
        tp_comm_buffer_name: Optional[str] = None,
    ):
        wrapped_init_method = condition_init_method(config, init_method)
        super().__init__(
            num_gemms=num_gemms,
            input_size=input_size,
            output_size=output_size,
            parallel_mode="row",
            config=config,
            init_method=wrapped_init_method,
            bias=bias,
            skip_bias_add=skip_bias_add,
            is_expert=is_expert,
            tp_comm_buffer_name=tp_comm_buffer_name,
        )

    def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
        """
        Shard every GEMM's weight along axis 1; biases are left unsharded.

        Assumes ``sharded_offsets[-1]`` is the expert-parallel offset.
        """
        weight_keys = (f'{gemm_idx}.weight' for gemm_idx in range(self.num_gemms))
        row_axes = dict.fromkeys(weight_keys, 1)
        return super()._sharded_state_dict_grouped(row_axes, prefix, sharded_offsets, metadata)