from mindspeed_mm.models.common.module_spec.qwen2vl_layer_spec import get_qwen2vl_layer_spec, get_mlp_module_spec, \
    get_qwen2vl_llm_layer_spec, get_qwen2_5_vit_layer_spec
from mindspeed_mm.models.common.module_spec.qwen3vl_layer_spec import get_qwen3vl_layer_spec
from mindspeed_mm.models.common.module_spec.qwen2_5omni_layer_spec import get_qwen_omni_audio_layer_spec
from mindspeed_mm.models.common.module_spec.internvl_layer_spec import get_language_layer_spec, get_vit_layer_spec
from mindspeed_mm.models.common.module_spec.deepseekvl_layer_spec import get_deepseekvl_model_spec
from mindspeed_mm.models.common.module_spec.qwen3vl_layer_spec import get_qwen3vl_llm_layer_local_spec
from mindspeed_mm.models.common.module_spec.glm4v_layer_spec import get_glm4v_layer_spec, get_glm4v_vit_layer_spec
from mindspeed_mm.models.common.module_spec.videoalign_layer_spec import get_videoalign_layer_spec, get_videoalign_llm_layer_spec

audio_layer_specs = {'qwen_omni': get_qwen_omni_audio_layer_spec}

vit_layer_specs = {'qwen2vit': get_qwen2vl_layer_spec,
                  'qwen3vit': get_qwen3vl_layer_spec,
                   'qwen2_5_vit': get_qwen2_5_vit_layer_spec,
                   'InternViT': get_vit_layer_spec,
                   'glm4v_vit': get_glm4v_vit_layer_spec,
                   'videoalign_vit': get_videoalign_layer_spec}

llm_layer_specs = {'qwen2lm': get_qwen2vl_llm_layer_spec,
                   'qwen2_5_lm': get_qwen2vl_llm_layer_spec,
                   'qwen2_5_omni_thinker': get_qwen2vl_llm_layer_spec,
                   'internllm': get_language_layer_spec,
                   'deepseek': get_deepseekvl_model_spec,
                   "qwen3_lm": get_qwen3vl_llm_layer_local_spec,
                   "glm4v_lm": get_glm4v_layer_spec,
                   "videoalign_lm": get_videoalign_llm_layer_spec}

projector_layer_specs = {'lnmlp': get_mlp_module_spec, 'mlp': get_mlp_module_spec}


def get_vit_layer_spec(config, *args, **kwargs):
    if getattr(config, 'model_id', None) is not None:
        if config.model_id in vit_layer_specs:
            return vit_layer_specs[config.model_id](config, is_vit=True)
    return None


def get_audio_layer_spec(config, *args, **kwargs):
    if getattr(config, 'model_id', None) is not None:
        if config.model_id in audio_layer_specs:
            return audio_layer_specs[config.model_id](config, is_vit=True)
    return None


def get_llm_layer_spec(config, *args, **kwargs):
    if getattr(config, 'model_id', None) is not None:
        if config.model_id in llm_layer_specs:
            return llm_layer_specs[config.model_id](config, is_vit=False)
    return None


def get_projector_layer_spec(config, *args, **kwargs):
    if getattr(config, 'model_id', None) is not None:
        if config.model_id in projector_layer_specs:
            return projector_layer_specs[config.model_id](config, use_te=False).submodules
    return None