from dataclasses import dataclass
from functools import wraps
from typing import Any
import torch.nn.functional as F
from megatron.core import ModelParallelConfig
from megatron.core.transformer.transformer_config import TransformerConfig, MLATransformerConfig
from megatron.training import get_args
from mindspeed.patch_utils import MindSpeedPatchesManager as pm
from mindspeed_mm.configs.config import ConfigReader
from .utils import get_dtype, quick_gelu, gelu_tanh
def get_class_variables(cls):
    """Return the attribute names of ``cls`` that do not start with ``__``.

    Names come from ``dir(cls)``, so they include inherited attributes and
    methods and keep the order ``dir`` yields them in. Names with a single
    leading underscore are retained.
    """
    return [name for name in dir(cls) if not name.startswith("__")]
def get_model_config(config):
    """Merge a model config with the global args into a ``ConfigReader``.

    The function builds a ``TransformerConfig`` (or ``MLATransformerConfig``
    when the global args have a truthy ``multi_latent_attention``) from the
    entries of ``config.to_dict()``, falling back to the global args for
    attributes that also exist on ``ModelParallelConfig``. The attributes of
    the constructed config are then written back into the dict, the names
    listed on ``MindSpeedArgsRequired`` are filled in from the global args
    when missing, and the result is wrapped in a ``ConfigReader``.

    Args:
        config: object exposing ``to_dict()``.

    Returns:
        A ``ConfigReader`` built from the merged dictionary.
    """
    global_args = get_args()
    config_dict = config.to_dict()

    # Mirror the MoE keys under the names used further down.
    if "moe_intermediate_size" in config_dict:
        config_dict["moe_ffn_hidden_size"] = config_dict["moe_intermediate_size"]
    if "n_shared_experts" in config_dict:
        config_dict["moe_shared_expert_intermediate_size"] = (
            config_dict["n_shared_experts"] * config_dict["moe_ffn_hidden_size"]
        )

    # Pick the config class once; it drives both the attribute scan and the
    # final construction.
    use_mla = getattr(global_args, "multi_latent_attention", False)
    config_cls = MLATransformerConfig if use_mla else TransformerConfig
    tfc_variables = get_class_variables(config_cls)
    mpc_variables = get_class_variables(ModelParallelConfig)

    # The model config wins; global args only supply model-parallel attributes.
    t_config = {}
    for key in tfc_variables:
        if key in config_dict:
            t_config[key] = config_dict[key]
        elif key in mpc_variables and hasattr(global_args, key):
            t_config[key] = getattr(global_args, key)

    # get_dtype receives None when no params_dtype was collected.
    t_config["params_dtype"] = get_dtype(t_config.get("params_dtype"))

    # Resolve the activation name to a callable; anything unmatched
    # (including a missing entry) falls back to F.gelu.
    requested_activation = t_config.get("activation_func")
    activation = F.gelu
    for name, func in (
        ("silu", F.silu),
        ("quick_gelu", quick_gelu),
        ("gelu_pytorch_tanh", gelu_tanh),
    ):
        if requested_activation == name:
            activation = func
            break
    t_config["activation_func"] = activation

    # Derived defaults. kv_channels is computed before num_attention_heads
    # is defaulted to 0, so a missing head count never triggers a division.
    hidden_size = t_config.get("hidden_size")
    num_heads = t_config.get("num_attention_heads")
    if t_config.get("kv_channels") is None and hidden_size and num_heads:
        t_config["kv_channels"] = hidden_size // num_heads
    if t_config.get("ffn_hidden_size") is None and hidden_size:
        t_config["ffn_hidden_size"] = 4 * hidden_size
    if num_heads is None:
        num_heads = 0
        t_config["num_attention_heads"] = 0
    if t_config.get("num_query_groups") is None and num_heads:
        t_config["num_query_groups"] = num_heads
    # Make sure cp_comm_type is always passed explicitly (as None if unset).
    t_config.setdefault("cp_comm_type", None)

    if use_mla:
        t_config["rope_type"] = "rope"
    trans_config = config_cls(**t_config)

    # Write the constructed config's attribute values back over the dict.
    config_dict.update((key, getattr(trans_config, key)) for key in tfc_variables)

    # Fill in MindSpeed-specific names from the global args when absent.
    for key in get_class_variables(MindSpeedArgsRequired):
        if key not in config_dict:
            config_dict[key] = getattr(global_args, key)

    return ConfigReader(config_dict)
@dataclass
class MindSpeedArgsRequired:
    """Base configuration for MindSpeed Core.

    Every field defaults to ``None``. Within this file the class is used as a
    list of names: ``get_model_config`` enumerates its attributes and, for
    each name missing from the model config, copies the value of the
    same-named attribute from the global args.
    """
    # NOTE(review): names are looked up verbatim with getattr on the global
    # args, so spellings such as `pre_tockens` / `next_tockens` presumably
    # mirror the upstream argument names — confirm before renaming any field.
    micro_batch_size: Any = None
    pre_tockens: Any = None
    next_tockens: Any = None
    sparse_mode: Any = None
    use_fusion_attn_v2: Any = None
    overlap_param_gather: Any = None
    overlap_grad_reduce: Any = None
    alibi_fusion_attn_type: Any = None
    tp_2d: Any = None
    context_parallel_algo: Any = None
    context_parallel_kv_cache_policy: Any = None
    context_parallel_cache_interval: Any = None
    use_ulysses_allgather_kv: Any = None
    megatron_cp_in_bnsd: Any = None
    attention_mask_type: Any = None
    use_cp_send_recv_overlap: Any = None
    variable_seq_lengths: Any = None
    reset_attention_mask: Any = None
    use_ascend_coc: Any = None
    coc_mode: Any = None
    coc_parallel_num: Any = None
    coc_fused_kernel: Any = None
def tranformer_config_post_init_wrapper(fn):
    """Wrap a ``__post_init__``-style callable with two pre-fill steps.

    Before delegating to ``fn``, the returned function sets ``kv_channels``
    to 0 when it is ``None`` and ``num_attention_heads`` equals 0, and copies
    ``params_dtype`` into ``pipeline_dtype`` when the latter is ``None``.
    The return value of ``fn`` is discarded.
    """

    @wraps(fn)
    def patched_post_init(self):
        """Fill in kv_channels / pipeline_dtype, then run the wrapped hook."""
        # A config with zero attention heads gets kv_channels pinned to 0.
        if self.kv_channels is None:
            if self.num_attention_heads == 0:
                self.kv_channels = 0

        # pipeline_dtype falls back to params_dtype when left unset.
        if self.pipeline_dtype is None:
            self.pipeline_dtype = self.params_dtype

        fn(self)

    return patched_post_init
# Import-time side effect: register tranformer_config_post_init_wrapper with
# the MindSpeed patch manager under the dotted path of
# TransformerConfig.__post_init__.
pm.register_patch("megatron.core.transformer.transformer_config.TransformerConfig.__post_init__", tranformer_config_post_init_wrapper)
# NOTE(review): presumably applies every patch registered with the manager so
# far, not only the one above — confirm against MindSpeedPatchesManager.
pm.apply_patches()