import torch.nn.functional as F
from megatron.core import ModelParallelConfig
from megatron.core.transformer.transformer_config import TransformerConfig, MLATransformerConfig
from megatron.training import get_args
from mindspeed_mm.configs.config import ConfigReader
from mindspeed_mm.utils.utils import get_dtype, quick_gelu, gelu_tanh
from mindspeed_mm.utils.transformer_model_config import get_class_variables, MindSpeedArgsRequired
def get_model_config(config):
    """Build a transformer-config-backed ``ConfigReader`` from ``config``.

    The dict produced by ``config.to_dict()`` is patched with a few
    model-specific overrides, filtered down to the fields of
    ``TransformerConfig`` (or ``MLATransformerConfig`` when the global args
    enable ``multi_latent_attention``), used to instantiate that config
    class, and finally merged back together with the attributes listed by
    ``MindSpeedArgsRequired``.

    Args:
        config: Object exposing ``to_dict()``; the result is used as a
            mutable mapping (presumably a plain ``dict`` -- confirm
            against ``ConfigReader``).

    Returns:
        A new ``ConfigReader`` wrapping the merged dictionary.
    """
    global_args = get_args()
    config_dict = config.to_dict()

    # Model-specific overrides keyed on the optional "model_id" entry.
    model_id = config_dict.get("model_id")
    if model_id == "InternVLMLP":
        config_dict.update(
            {
                "params_dtype": "bf16",
                "hidden_size": 4096,
                "num_attention_heads": 1,
                "num_layers": 1,
            }
        )
    if model_id == "Qwen2.5llm":
        config_dict["use_repeat_kv"] = True

    # Map the MoE key names used by the incoming config onto the names
    # expected further down.
    if "moe_intermediate_size" in config_dict:
        config_dict["moe_ffn_hidden_size"] = config_dict["moe_intermediate_size"]
    if "n_shared_experts" in config_dict:
        # Reads "moe_ffn_hidden_size" directly, so that key must already be
        # present (either supplied or derived just above).
        shared_size = config_dict["n_shared_experts"] * config_dict["moe_ffn_hidden_size"]
        config_dict["moe_shared_expert_intermediate_size"] = shared_size

    # Resolve the target config class once; it drives both the field list
    # and the final construction.
    use_mla = getattr(global_args, "multi_latent_attention", False)
    config_cls = MLATransformerConfig if use_mla else TransformerConfig
    tfc_variables = get_class_variables(config_cls)
    mpc_variables = get_class_variables(ModelParallelConfig)

    # Values from the model config win; model-parallel fields fall back to
    # the global args when the model config does not provide them.
    t_config = {}
    for field in tfc_variables:
        if field in config_dict:
            t_config[field] = config_dict[field]
        elif field in mpc_variables and hasattr(global_args, field):
            t_config[field] = getattr(global_args, field)

    t_config["params_dtype"] = get_dtype(t_config.get("params_dtype"))

    # Swap the activation name for its callable; anything unrecognised
    # (including a missing entry) falls back to F.gelu.
    requested_activation = t_config.get("activation_func")
    known_activations = (
        ("silu", F.silu),
        ("quick_gelu", quick_gelu),
        ("gelu_pytorch_tanh", gelu_tanh),
    )
    t_config["activation_func"] = next(
        (fn for label, fn in known_activations if requested_activation == label),
        F.gelu,
    )

    # Fill in derived defaults for fields that were not supplied.
    hidden_size = t_config.get("hidden_size")
    num_heads = t_config.get("num_attention_heads")
    if t_config.get("kv_channels") is None and hidden_size and num_heads:
        t_config["kv_channels"] = hidden_size // num_heads
    if t_config.get("ffn_hidden_size") is None and hidden_size:
        t_config["ffn_hidden_size"] = 4 * hidden_size
    if t_config.get("num_query_groups") is None and num_heads != 1:
        t_config["num_query_groups"] = num_heads

    if use_mla:
        t_config["rope_type"] = "rope"
    trans_config = config_cls(**t_config)

    # Write the values held by the constructed config back into the dict.
    for field in tfc_variables:
        config_dict[field] = getattr(trans_config, field)

    # Attributes listed by MindSpeedArgsRequired are taken from the global
    # args unless the config already carries them.
    for field in get_class_variables(MindSpeedArgsRequired):
        if field not in config_dict:
            config_dict[field] = getattr(global_args, field)

    return ConfigReader(config_dict)