from functools import wraps
import torch
import torch.nn.functional as F
from megatron.core import parallel_state, tensor_parallel
from megatron.core.transformer.moe import grouped_gemm_util as gg
from megatron.training import get_args
from mindspeed.core.fusions.fused_bias_swiglu import fused_swiglu
from mindspeed.core.tensor_parallel.random import CheckpointWithoutOutput
from mindspeed.core.transformer.moe.grouped_gemm_util import fused_alltoall_gather_bmm, fused_bmm_reducescatter_alltoall
from mindspeed.core.transformer.moe.grouped_matmul_util import get_gmm_quant_func
from mindspeed.core.transformer.moe.grouped_mlp_with_comp_and_comm_overlap_all2all import \
grouped_mlp_with_comp_and_comm_overlap_all2all
from mindspeed.core.transformer.moe.grouped_mlp_with_comp_and_comm_overlap_allgather import \
grouped_mlp_with_comp_and_comm_overlap_allgather
from mindspeed.model.transformer import should_recompute_activation
# Module-wide side stream used by sequential_mlp_forward for its all-gather.
# It stays None until that function first runs with a tensor-parallel world
# size greater than 1, at which point a torch.cuda.Stream is created and reused.
COMM_STREAM = None
def get_zeros_with_tp(input_):
    """Return a zero tensor like ``input_`` but with a widened last dimension.

    The last dimension is multiplied by the tensor-model-parallel world size
    reported by ``parallel_state``; every leading dimension, the dtype, the
    layout and the device are taken from ``input_`` unchanged.
    """
    tp_world_size = parallel_state.get_tensor_model_parallel_world_size()
    src_shape = input_.shape
    widened_shape = (*src_shape[:-1], src_shape[-1] * tp_world_size)
    tensor_options = {
        "dtype": input_.dtype,
        "layout": input_.layout,
        "device": input_.device,
    }
    return torch.zeros(widened_shape, **tensor_options)
def sequential_mlp_forward(self, permuted_local_hidden_states, tokens_per_expert, permuted_probs):
    """Run every local expert over its own contiguous slice of the permuted tokens.

    Args:
        permuted_local_hidden_states: Tokens laid out expert-by-expert along
            dim 0; expert ``i`` owns the rows delimited by the running sum of
            ``tokens_per_expert``.
        tokens_per_expert: Per-expert token counts (cumulatively summed along
            dim 0 to obtain slice boundaries).
        permuted_probs: Accepted for signature compatibility; not used here.

    Returns:
        Tuple ``(output_local, output_bias_local)``. ``output_bias_local`` is
        ``None`` unless ``self.add_bias`` is set.
    """
    output_local = get_zeros_with_tp(permuted_local_hidden_states)
    output_bias_local = None
    if self.add_bias:
        output_bias_local = get_zeros_with_tp(permuted_local_hidden_states)

    # Prepend a zero so that expert i spans rows [offsets[i], offsets[i + 1]).
    cumsum_num_tokens = torch.cumsum(tokens_per_expert, dim=0)
    zero_tensor = torch.zeros(1, dtype=torch.long, device=cumsum_num_tokens.device)
    cumsum_num_tokens = torch.cat((zero_tensor, cumsum_num_tokens))

    # Loop invariants: resolve them once instead of once (or several times)
    # per expert.
    use_tp_gather = parallel_state.get_tensor_model_parallel_world_size() > 1
    args = get_args()
    # NOTE(review): Hardshrink is applied around each expert only when both
    # `prof_file` and `num_experts` are set -- presumably a profiling-mode
    # behavior; confirm intent with the owners of those flags.
    shrink = torch.nn.Hardshrink() if (args.prof_file and args.num_experts) else None

    global COMM_STREAM
    if use_tp_gather:
        if COMM_STREAM is None:
            COMM_STREAM = torch.cuda.Stream()
        # Work already queued on the current stream must finish before the
        # side stream starts reading the hidden states.
        COMM_STREAM.wait_stream(torch.cuda.current_stream())

    for expert_num, expert in enumerate(self.local_experts):
        start = cumsum_num_tokens[expert_num]
        end = cumsum_num_tokens[expert_num + 1]
        hidden = permuted_local_hidden_states[start:end]

        if use_tp_gather:
            # The gather is issued on the side stream; the current stream then
            # waits for it before consuming the gathered tensor.
            with torch.cuda.stream(COMM_STREAM):
                hidden = tensor_parallel.all_gather_last_dim_from_tensor_parallel_region(hidden)
            torch.cuda.current_stream().wait_stream(COMM_STREAM)

        if shrink is not None:
            hidden = shrink(hidden)
        output, output_bias = expert(hidden)
        if shrink is not None:
            output = shrink(output)

        output_local[start:end] = output
        if self.add_bias:
            output_bias = output_bias.expand_as(output)
            output_bias_local[start:end, :] = output_bias

    return output_local, output_bias_local
def group_mlp_forward(self, permuted_local_hidden_states, tokens_per_expert, ctx=None):
    """Forward the grouped expert MLP through a comm/compute-overlap kernel.

    The expert weights are reshaped (per-expert 3-D views when there are
    tokens, plain 2-D views when the input is empty) and handed, together with
    the cumulative token counts, to either the all-to-all or the all-gather
    overlap implementation depending on ``moe_alltoall_overlap_comm``.

    Args:
        permuted_local_hidden_states: Input tokens for the local experts.
        tokens_per_expert: Per-expert token counts.
        ctx: Forwarded only to the all-to-all overlap implementation.

    Returns:
        Whatever the selected overlap implementation returns.
    """
    hidden_size = self.config.hidden_size
    if permuted_local_hidden_states.nelement() == 0:
        # Empty input: the weights are viewed without an expert dimension.
        fc1_view = (hidden_size, -1)
        fc2_view = (-1, hidden_size)
    else:
        expert_count = self.num_local_experts
        fc1_view = (expert_count, hidden_size, -1)
        fc2_view = (expert_count, -1, hidden_size)
    w1 = self.weight1.view(*fc1_view)
    w2 = self.weight2.view(*fc2_view)

    cumulative_tokens = torch.cumsum(tokens_per_expert, dim=0)
    overlap_all2all = get_args().moe_alltoall_overlap_comm
    extra_inputs = (
        self.weight1,
        self.weight2,
        self.activation_func,
        cumulative_tokens,
        self.layer_number,
    )

    if overlap_all2all:
        return grouped_mlp_with_comp_and_comm_overlap_all2all(
            permuted_local_hidden_states, w1, w2, extra_inputs, ctx=ctx
        )
    return grouped_mlp_with_comp_and_comm_overlap_allgather(
        permuted_local_hidden_states, w1, w2, extra_inputs
    )
def groupedmlp_init_wrapper(fn):
    """Decorate an initializer-style method of a grouped-MLP module.

    The returned wrapper:

    * when ``moe_tp_extend_ep`` is set, forces
      ``parallel_state._MPU_EXPERT_TENSOR_PARALLEL_WORLD_SIZE`` to 1 for the
      duration of ``fn`` and restores the previous value afterwards -- also
      when ``fn`` raises, so a failed construction cannot leave the global
      parallel state modified;
    * when ``self.config.gated_linear_unit`` is set, checks that the configured
      activation is ``F.silu`` and swaps ``self.activation_func`` for
      ``fused_swiglu``;
    * initializes ``self.layer_number`` to ``None`` and
      ``self.set_recompute_activation_func`` to ``False``.

    The wrapped call's return value is discarded (the wrapper returns ``None``).

    Raises:
        AssertionError: if a gated linear unit is configured with an
            activation other than ``F.silu``.
    """
    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        args_ = get_args()
        tp_size = parallel_state._MPU_EXPERT_TENSOR_PARALLEL_WORLD_SIZE
        # Read the flag once so the override and the restore are always paired.
        override_tp = args_.moe_tp_extend_ep
        # NOTE(review): presumably the size is pinned to 1 so that `fn` does
        # not shard the expert weights across TP ranks -- confirm.
        if override_tp:
            parallel_state._MPU_EXPERT_TENSOR_PARALLEL_WORLD_SIZE = 1
        try:
            fn(self, *args, **kwargs)
        finally:
            if override_tp:
                parallel_state._MPU_EXPERT_TENSOR_PARALLEL_WORLD_SIZE = tp_size

        if self.config.gated_linear_unit:
            assert (self.config.activation_func == F.silu
                    ), 'Activation function must be silu when using fused_swiglu.'
            self.activation_func = fused_swiglu
        self.layer_number = None
        self.set_recompute_activation_func = False

    return wrapper
def groupedmlp_forward(self, permuted_local_hidden_states, tokens_per_expert, permuted_probs):
    """Apply the grouped expert MLP (fc1 -> activation -> fc2) to permuted tokens.

    Three fc1/fc2 back-ends are supported when tokens are present, chosen in
    this order: the fused BMM path (``args.moe_bmm_mc2``), the quantized
    grouped matmul returned by ``get_gmm_quant_func()``, and ``gg.ops.gmm``.
    When the input is empty the weights are viewed as 2-D matrices and plain
    ``torch.matmul`` is used instead.

    If activation recomputation is enabled for this layer (and neither overlap
    mode is active), the activation runs through ``CheckpointWithoutOutput``:
    its output is discarded after fc2 and a backward hook on the fc2 output
    triggers the recomputation.

    Args:
        permuted_local_hidden_states: Input tokens for the local experts.
        tokens_per_expert: Per-expert token counts passed to the grouped matmuls.
        permuted_probs: Accepted for signature compatibility; not used here.

    Returns:
        Tuple ``(fc2_output, None)``.
    """
    args = get_args()
    recompute_act = should_recompute_activation(self.layer_number) and not (
        args.moe_alltoall_overlap_comm or args.moe_allgather_overlap_comm
    )
    fuse_grad_accum = args.gemm_gradient_accumulation_fusion
    tp_group = parallel_state.get_tensor_model_parallel_group()
    ep_group = parallel_state.get_expert_model_parallel_group()
    quant_gmm = get_gmm_quant_func()

    def run_activation(fc1_result):
        # Plain call, or a checkpointed call whose output can later be dropped
        # and rebuilt during backward.
        if not recompute_act:
            return self.activation_func(fc1_result)
        self.activation_checkpoint_manager = CheckpointWithoutOutput()
        return self.activation_checkpoint_manager.checkpoint(
            self.activation_func, False, fc1_result
        )

    tokens = permuted_local_hidden_states
    if tokens.nelement() != 0:
        # Per-expert 3-D views of the weights for the grouped kernels.
        fc1_weight = self.weight1.view(self.num_local_experts, self.config.hidden_size, -1)
        fc2_weight = self.weight2.view(self.num_local_experts, -1, self.config.hidden_size)
        use_fused_bmm = args.moe_bmm_mc2

        if use_fused_bmm:
            bmm_param = {'group_ep': ep_group, 'group_tp': tp_group, 'shard_type': 0,
                         'need_recompute': False}
            expert_total = self.config.num_moe_experts
            tokens = tokens.view(expert_total, tokens.shape[0] // expert_total, -1)
            fc1_output = fused_alltoall_gather_bmm(tokens, fc1_weight, None, bmm_param)
        elif quant_gmm:
            fc1_output = quant_gmm.gmm_apply(tokens, fc1_weight, None, tokens_per_expert, self.weight1)
        else:
            fc1_output = gg.ops.gmm(
                tokens,
                fc1_weight,
                tokens_per_expert,
                trans_b=False,
                gemm_fusion=fuse_grad_accum,
                original_weight=self.weight1,
            )

        intermediate_parallel = run_activation(fc1_output)

        if use_fused_bmm:
            fc2_output = fused_bmm_reducescatter_alltoall(intermediate_parallel, fc2_weight, None, bmm_param)
            fc2_output = fc2_output.view(-1, fc2_output.shape[2])
        elif quant_gmm:
            fc2_output = quant_gmm.gmm_apply(
                intermediate_parallel, fc2_weight, None, tokens_per_expert, self.weight2
            )
        else:
            fc2_output = gg.ops.gmm(
                intermediate_parallel,
                fc2_weight,
                tokens_per_expert,
                trans_b=False,
                gemm_fusion=fuse_grad_accum,
                original_weight=self.weight2,
            )
    else:
        # Empty input: no expert may claim any tokens.
        assert torch.count_nonzero(tokens_per_expert) == 0
        # NOTE(review): the weights are still pushed through matmul here,
        # presumably so they remain part of the autograd graph -- confirm.
        fc1_weight = self.weight1.view(self.config.hidden_size, -1)
        fc2_weight = self.weight2.view(-1, self.config.hidden_size)
        intermediate_parallel = run_activation(torch.matmul(tokens, fc1_weight))
        fc2_output = torch.matmul(intermediate_parallel, fc2_weight)

    if recompute_act:
        # Drop the stored activation output now; the hook rebuilds it when the
        # backward pass reaches the fc2 output.
        self.activation_checkpoint_manager.discard_output()
        if fc2_output.requires_grad:
            fc2_output.register_hook(self.activation_checkpoint_manager.recompute)

    return fc2_output, None