from functools import wraps
import math
import torch
import torch.nn.functional as F
from megatron.training import get_args
from megatron.core import parallel_state
from megatron.core.transformer.spec_utils import build_module
from megatron.core.tensor_parallel.layers import _initialize_affine_weight_gpu
from megatron.core.transformer.mlp import MLP
from megatron.core.transformer.transformer_config import TransformerConfig
from mindspeed.core.tensor_parallel.tp_2d.group_api_2d import TPXCollectiveComm, TPXOverlapCollectiveComm, \
TPYCollectiveComm, TPYOverlapCollectiveComm
from mindspeed.core.tensor_parallel.tp_2d.parallel_linear_2d import ParallelLinear2D
def should_recompute_activation(self):
    """Return whether this layer should recompute its activation function.

    The decision is driven by the global args returned by ``get_args()``:

    * ``False`` when ``recompute_activation_function`` is off or
      ``self.layer_number`` is ``None``.
    * Otherwise a per-layer priority is derived from the layer's position
      inside its chunk (and, when ``enable_recompute_layers_per_pp_rank`` is
      set, from the virtual pipeline rank/size) and compared against
      ``recompute_num_layers`` and
      ``recompute_activation_function_num_layers``.

    Returns:
        bool: ``True`` if activation recomputation applies to this layer.
    """
    cfg = get_args()
    if not cfg.recompute_activation_function or self.layer_number is None:
        return False

    act_layers = cfg.recompute_activation_function_num_layers
    vpp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank()
    vpp_size = cfg.virtual_pipeline_model_parallel_size
    pp_size = cfg.transformer_pipeline_model_parallel_size

    # Number of layers in the chunk this layer belongs to.
    if vpp_size is not None:
        chunk_layers = cfg.num_layers_per_virtual_pipeline_stage
    elif pp_size is None:
        chunk_layers = cfg.num_layers
    else:
        chunk_layers = cfg.num_layers // pp_size

    # Virtual-pipeline information only contributes to the priority when it
    # exists and per-pp-rank recompute is enabled; otherwise use rank 0, size 1.
    if not (vpp_rank is not None and cfg.enable_recompute_layers_per_pp_rank):
        vpp_rank = 0
    if not (vpp_size is not None and cfg.enable_recompute_layers_per_pp_rank):
        vpp_size = 1

    index_in_chunk = (self.layer_number - 1) % chunk_layers
    priority = index_in_chunk * vpp_size + vpp_rank
    full_layers = cfg.recompute_num_layers

    if full_layers:
        # Priorities below recompute_num_layers are excluded here.
        if priority < full_layers:
            return False
        if act_layers is None:
            return True
        # The activation-recompute window starts right after the excluded range.
        return priority < full_layers + act_layers

    if act_layers is None:
        return True
    return priority < act_layers
def core_mlp_init(self, config, submodules, is_expert=False, input_size=None, shared_expert=False):
    """Initializer for ``MLP`` that builds ``linear_fc1`` and ``linear_fc2``.

    Args:
        config: Transformer configuration. Stored on ``self.config`` and
            mutated in place when ``geglu`` or ``gelu_tanh`` is set in the
            global args.
        submodules: Supplies the ``linear_fc1`` / ``linear_fc2`` specs handed
            to ``build_module``.
        is_expert: Forwarded to ``build_module`` for both projections.
        input_size: Input width of ``linear_fc1``; any falsy value falls back
            to ``config.hidden_size``.
        shared_expert: Stored on ``self.shared_expert``. When truthy it is
            also forwarded to ``build_module`` as a keyword argument.
    """
    super(MLP, self).__init__(config=config)
    self.config: TransformerConfig = config
    cfg = self.config
    self.input_size = input_size or cfg.hidden_size

    _args = get_args()
    if _args.geglu:
        cfg.gated_linear_unit = True
        cfg.activation_func = F.gelu
        cfg.bias_gelu_fusion = False

    if _args.gelu_tanh:
        def gelu_tanh_approximation(x):
            inner = math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))
            return 0.5 * x * (1 + torch.tanh(inner))

        cfg.gated_linear_unit = True
        cfg.activation_func = gelu_tanh_approximation
        cfg.bias_gelu_fusion = False

    # With a gated linear unit, fc1 is built with twice the ffn width.
    fc1_out = cfg.ffn_hidden_size
    if cfg.gated_linear_unit:
        fc1_out *= 2

    # The shared_expert keyword is only passed along when it is truthy, so
    # linear specs that do not accept it keep working in the default case.
    shared_kwargs = {'shared_expert': shared_expert} if shared_expert else {}

    self.linear_fc1 = build_module(
        submodules.linear_fc1,
        self.input_size,
        fc1_out,
        config=cfg,
        init_method=cfg.init_method,
        gather_output=False,
        bias=cfg.add_bias_linear,
        skip_bias_add=True,
        is_expert=is_expert,
        tp_comm_buffer_name='fc1',
        **shared_kwargs,
    )

    self.activation_func = cfg.activation_func

    self.linear_fc2 = build_module(
        submodules.linear_fc2,
        cfg.ffn_hidden_size,
        cfg.hidden_size,
        config=cfg,
        init_method=cfg.output_layer_init_method,
        bias=cfg.add_bias_linear,
        input_is_parallel=True,
        skip_bias_add=True,
        is_expert=is_expert,
        tp_comm_buffer_name='fc2',
        **shared_kwargs,
    )

    self.shared_expert = shared_expert

    if not _args.tp_2d:
        return

    # 2D tensor parallelism: the projections built above are replaced by
    # ParallelLinear2D layers.
    # NOTE(review): fc1 here uses config.hidden_size rather than
    # self.input_size, and both layers pass is_expert=False -- confirm this
    # is intended.
    if cfg.gated_linear_unit:
        fc1_out_2d = cfg.ffn_hidden_size * 2
    else:
        fc1_out_2d = cfg.ffn_hidden_size

    self.linear_fc1 = ParallelLinear2D(
        cfg.hidden_size,
        fc1_out_2d,
        config=cfg,
        init_method=cfg.init_method,
        add_bias=cfg.add_bias_linear,
        skip_bias_add=True,
        is_expert=False,
        ag_comm_intf=TPXCollectiveComm,
        ag_sd_rcv_overlap_comm_intf=TPXOverlapCollectiveComm,
        rs_comm_intf=TPYCollectiveComm,
        rs_sd_rcv_overlap_comm_intf=TPYOverlapCollectiveComm,
        enable_overlap_ag_with_matmul=False,
        enable_overlap_matmul_with_rs=_args.enable_overlap_matmul_with_rs,
        partition_dim=0,
        enable_backward_overlap_ag_with_matmul=_args.enable_backward_overlap_ag_with_matmul,
        _initialize_affine_weight_gpu=_initialize_affine_weight_gpu,
    )
    # fc2 swaps the X/Y communication interfaces relative to fc1.
    self.linear_fc2 = ParallelLinear2D(
        cfg.ffn_hidden_size,
        cfg.hidden_size,
        config=cfg,
        init_method=cfg.output_layer_init_method,
        add_bias=cfg.add_bias_linear,
        skip_bias_add=True,
        is_expert=False,
        ag_comm_intf=TPYCollectiveComm,
        ag_sd_rcv_overlap_comm_intf=TPYOverlapCollectiveComm,
        rs_comm_intf=TPXCollectiveComm,
        rs_sd_rcv_overlap_comm_intf=TPXOverlapCollectiveComm,
        enable_overlap_ag_with_matmul=_args.enable_overlap_ag_with_matmul,
        enable_overlap_matmul_with_rs=False,
        partition_dim=1,
        enable_backward_overlap_ag_with_matmul=False,
        _initialize_affine_weight_gpu=_initialize_affine_weight_gpu,
    )
def core_mlp_init_wrapper(fn):
    """Decorate an MLP initializer so global args adjust its config first.

    The wrapper locates the ``config`` argument (keyword ``config`` or first
    positional argument), mutates it in place according to the global args,
    and then calls ``fn`` with the caller's original arguments.

    Config adjustments, applied in this order:

    * ``geglu``: gated linear unit on, ``F.gelu`` activation, bias-gelu
      fusion off.
    * ``gelu_tanh``: gated linear unit on, tanh-approximated GELU activation,
      bias-gelu fusion off.
    * ``fc_type == 'up_down'`` found on the global args, on ``args.llama`` or
      on the config itself: gated linear unit off.

    Args:
        fn: The initializer to wrap; called as ``fn(self, *args, **kwargs)``.

    Returns:
        The wrapped initializer.

    Raises:
        ValueError: If no ``config`` argument was supplied to the wrapper.
    """
    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        if 'config' in kwargs:
            _config = kwargs['config']
        elif args:
            _config = args[0]
        else:
            raise ValueError("Missing the required argument 'config' when initializing MLP.")

        _args = get_args()
        if _args.geglu:
            _config.gated_linear_unit = True
            _config.activation_func = F.gelu
            _config.bias_gelu_fusion = False

        if _args.gelu_tanh:
            def gelu_tanh_approximation(x):
                return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

            _config.gated_linear_unit = True
            _config.activation_func = gelu_tanh_approximation
            _config.bias_gelu_fusion = False

        # An 'up_down' fc_type from any source switches gating off again.
        # Every lookup uses a getattr default so that a missing attribute --
        # including a `llama` entry that is None or carries no `fc_type` --
        # simply counts as "not requested" instead of raising AttributeError.
        up_down_requested = (
            getattr(_args, 'fc_type', None) == 'up_down'
            or getattr(getattr(_args, 'llama', None), 'fc_type', None) == 'up_down'
            or getattr(_config, 'fc_type', None) == 'up_down'
        )
        if up_down_requested:
            _config.gated_linear_unit = False

        # The config object was mutated in place, so the caller's original
        # args/kwargs already carry the adjusted config.
        fn(self, *args, **kwargs)

    return wrapper