"""Lazy registry mapping model name to (ForCausalLM, [ModelMTP,] Config) classes."""
import importlib
_specs: dict[str, list[tuple[str, str]]] = {
"deepseek_r1": [
("models.deepseek_r1.models.modeling_deepseek", "DeepseekV3ForCausalLM"),
("models.deepseek_r1.models.modeling_deepseek", "DeepseekV3ModelMTP"),
("models.deepseek_r1.models.configuration_deepseek", "DeepseekV3Config"),
],
"deepseek_v3_2_exp": [
("models.deepseek_v3_2_exp.models.modeling_deepseek", "DeepseekV3ForCausalLM"),
("models.deepseek_v3_2_exp.models.modeling_deepseek", "DeepseekV3ModelMTP"),
("models.deepseek_v3_2_exp.models.configuration_deepseek", "DeepseekV3IndexConfig"),
],
"deepseek_v2_lite": [
("models.deepseek_r1.models.modeling_deepseek", "DeepseekV3ForCausalLM"),
("models.deepseek_r1.models.configuration_deepseek", "DeepseekV3Config"),
],
"gemma_4": [
("models.gemma_4.models.modeling_gemma4", "Gemma4ForCausalLM"),
("models.gemma_4.models.configuration_gemma4", "Gemma4TextConfig"),
],
"gpt_oss": [
("models.gpt_oss.models.modeling_gpt_oss", "GptOssForCausalLM"),
("models.gpt_oss.models.configuration_gpt_oss", "GptOssConfig"),
],
"hy3_preview": [
("models.hy3_preview.models.modeling_hy_v3", "HYV3ForCausalLM"),
("models.hy3_preview.models.configuration_hy_v3", "HYV3Config"),
],
"kimi_k2": [
("models.deepseek_r1.models.modeling_deepseek", "DeepseekV3ForCausalLM"),
("models.deepseek_r1.models.configuration_deepseek", "DeepseekV3Config"),
],
"kimi_k2_thinking": [
("models.kimi_k2_thinking.models.modeling_deepseek", "DeepseekV3ForCausalLM"),
("models.kimi_k2_thinking.models.modeling_deepseek", "DeepseekV3ModelMTP"),
("models.kimi_k2_thinking.models.configuration_deepseek", "DeepseekV3Config"),
],
"longcat_flash_lite": [
("models.longcat_flash_lite.models.modeling_longcat_flash_lite", "LongcatFlashNgramForCausalLM"),
("models.longcat_flash_lite.models.configuration_longcat_flash_lite", "LongcatFlashNgramConfig"),
],
"qwen25_7b_instruct": [
("models.qwen.models.modeling_qwen", "QwenForCausalLM"),
("models.qwen.models.configuration_qwen", "Qwen2Config"),
],
"qwen3_8b": [
("models.qwen.models.modeling_qwen", "QwenForCausalLM"),
("models.qwen.models.configuration_qwen", "Qwen3Config"),
],
"qwen3_moe": [
("models.qwen3_moe.models.modeling_qwen3_moe", "Qwen3MoeForCausalLM"),
("models.qwen3_moe.models.configuration_qwen3_moe", "Qwen3MoeConfig"),
],
"step3p7_flash": [
("models.step3p7_flash.models.modeling_step3p7", "Step3p5ForCausalLM"),
("models.step3p7_flash.models.configuration_step3p7", "Step3p7TextConfig"),
],
"qwen3.5": [
("models.qwen3_5.models.modeling_qwen3_5_moe", "Qwen3_5MoeForCausalLM"),
("models.qwen3_5.models.configuration_qwen3_5_moe", "Qwen3_5MoeTextConfig"),
],
"longcat_flash": [
("models.longcat_flash.models.modeling_longcat_flash", "LongcatFlashForCausalLM"),
("models.longcat_flash.models.modeling_longcat_flash", "LongcatFlashModelMTP"),
("models.longcat_flash.models.configuration_longcat_flash", "LongcatFlashConfig"),
],
"longcat_flash_ffn": [
("models.longcat_flash.models.ffn", "FFNForCausalLM"),
("models.longcat_flash.models.configuration_longcat_flash", "LongcatFlashConfig"),
],
}
def load_model_classes(name: str) -> tuple:
if name not in _specs:
raise ValueError(f"Unsupported model: {name}")
try:
return tuple(getattr(importlib.import_module(m), a) for m, a in _specs[name])
except Exception as e:
raise ImportError(f"failed to load model '{name}': {e}") from e