import torch
from tensor_cast.layers import COLWISE_LINEAR, ROWWISE_LINEAR
from ..custom_model_registry import (
ModelProfile,
register_model_profile,
resolve_visual_config,
)
# Linear-layer layouts for the GLM4V vision tower, covering the patch merger
# and the MLP of every vision block. In both groups the gate/up projections
# are marked column-wise and the down projection row-wise (presumably the
# tensor-parallel split used by tensor_cast -- confirm in tensor_cast.layers).
# The "*" in the block pattern is presumably a wildcard over block indices.
GLM4V_VISUAL_CONFIG = resolve_visual_config(
    dict(
        visual_merger_linear_mapping={
            "visual.merger.gate_proj": COLWISE_LINEAR,
            "visual.merger.up_proj": COLWISE_LINEAR,
            "visual.merger.down_proj": ROWWISE_LINEAR,
        },
        visual_mlp_linear_mapping={
            "visual.blocks.*.mlp.gate_proj": COLWISE_LINEAR,
            "visual.blocks.*.mlp.up_proj": COLWISE_LINEAR,
            "visual.blocks.*.mlp.down_proj": ROWWISE_LINEAR,
        },
    )
)
def patch_method_for_glm4_vl(_model):
    """
    Patch the GLM4V-MoE model classes to fix simulation issues in meta mode.

    Problem background:
    1. ``Glm4vMoeVisionEmbeddings.forward`` converts ``lengths`` given as a
       list into a meta tensor, while subsequent computations require actual
       values (implicitly calling ``item``), which causes errors.
    2. ``get_placeholder_mask`` uses boolean-mask-based tensor indexing
       operations, which fail or cause dimension mismatches in meta mode.

    Solution:
    * Convert list-based ``lengths`` to a CPU tensor before entering
      ``forward``, so no meta tensor is created for it. The device is pinned
      explicitly so that an ambient default-device context (e.g. ``meta``)
      cannot turn the converted value into a meta tensor again.
    * Force ``image_features=None`` to skip image-related checks in
      ``get_placeholder_mask``.

    The patches replace methods on the classes themselves, so they affect
    every instance, not only ``_model``. Installed wrappers are tagged, and
    calling this function again leaves already-patched methods untouched
    instead of wrapping them a second time.

    Args:
        _model: Not used by this function; presumably accepted to match the
            ``ModelProfile.patch_method`` calling convention.
    """
    import functools

    from transformers.models.glm4v_moe import Glm4vMoeModel
    from transformers.models.glm4v_moe.modeling_glm4v_moe import (
        Glm4vMoeVisionEmbeddings,
    )

    # Attribute set on installed wrappers so repeated calls don't stack them.
    patched_flag = "_tensor_cast_meta_patched"

    def to_cpu_lengths(values):
        # Explicit CPU device: a bare torch.tensor() would follow any active
        # default-device context and could still yield a meta tensor.
        return torch.tensor(values, dtype=torch.long, device="cpu")

    original_get_placeholder_mask = Glm4vMoeModel.get_placeholder_mask
    if not getattr(original_get_placeholder_mask, patched_flag, False):

        @functools.wraps(original_get_placeholder_mask)
        def patched_get_placeholder_mask(self, *args, **kwargs):
            # Drop image features so the boolean-mask image checks are skipped.
            kwargs["image_features"] = None
            return original_get_placeholder_mask(self, *args, **kwargs)

        setattr(patched_get_placeholder_mask, patched_flag, True)
        Glm4vMoeModel.get_placeholder_mask = patched_get_placeholder_mask

    original_forward = Glm4vMoeVisionEmbeddings.forward
    if not getattr(original_forward, patched_flag, False):

        @functools.wraps(original_forward)
        def patched_forward(self, *args, **kwargs):
            # lengths is the second argument after self when passed
            # positionally; the keyword branch assumes the parameter is named
            # "lengths" -- confirm against the installed transformers version.
            if len(args) > 1 and isinstance(args[1], list):
                args = (args[0], to_cpu_lengths(args[1])) + args[2:]
            elif isinstance(kwargs.get("lengths"), list):
                kwargs["lengths"] = to_cpu_lengths(kwargs["lengths"])
            return original_forward(self, *args, **kwargs)

        setattr(patched_forward, patched_flag, True)
        Glm4vMoeVisionEmbeddings.forward = patched_forward
# Register the GLM4V-MoE profile under model_type "glm4v_moe".
register_model_profile(
    ModelProfile(
        # Registry identity.
        model_type="glm4v_moe",
        model_family="glm4v",
        # MoE settings; moe_num_experts_key is presumably a key path into the
        # HF config (text_config.n_routed_experts) -- confirm in ModelProfile.
        moe_module_name="Glm4vMoeTextMoE",
        moe_num_experts_key=["text_config", "n_routed_experts"],
        moe_gate_returns_raw_logits=True,
        # Meta-mode workarounds for the transformers GLM4V-MoE classes.
        patch_method=patch_method_for_glm4_vl,
        # Vision-tower linear mappings as produced by resolve_visual_config.
        **GLM4V_VISUAL_CONFIG,
    )
)