import torch
from ...model_config import MoEFieldNames
from ...utils import exact_division
from ..custom_model_registry import (
ModelProfile,
register_model_profile,
resolve_visual_config,
)
# Visual-model settings for the Qwen3.5 profile, i.e. whatever
# resolve_visual_config produces when given an empty override mapping.
# The resulting entries are splatted into ModelProfile at registration time.
QWEN3_5_VISUAL_CONFIG = resolve_visual_config(dict())
def _set_qwen3_5_linear_attn_tp_size(model):
    """Attach the tensor-parallel world size to every linear-attention module.

    A submodule of ``model._inner`` counts as linear attention when it has
    both ``num_k_heads`` and ``num_v_heads``; each such module receives a
    ``tensor_cast_tp_size`` attribute equal to the TP world size. When the
    world size is 1 or less nothing is attached.

    Raises:
        ValueError: if the TP size does not evenly divide a module's
            ``num_k_heads`` or ``num_v_heads``. Modules visited before the
            failing one keep the attribute already written to them.
    """
    world_size = model.parallel_group_manager.tp_group.world_size
    if world_size <= 1:
        return

    for submodule in model._inner.modules():
        is_linear_attn = hasattr(submodule, "num_k_heads") and hasattr(submodule, "num_v_heads")
        if not is_linear_attn:
            continue

        k_heads = submodule.num_k_heads
        if k_heads % world_size != 0:
            raise ValueError(
                "Qwen3.5 linear attention requires tp_size to divide "
                f"num_k_heads exactly, but got num_k_heads={k_heads} "
                f"and tp_size={world_size}."
            )

        v_heads = submodule.num_v_heads
        if v_heads % world_size != 0:
            raise ValueError(
                "Qwen3.5 linear attention requires tp_size to divide "
                f"num_v_heads exactly, but got num_v_heads={v_heads} "
                f"and tp_size={world_size}."
            )

        submodule.tensor_cast_tp_size = world_size
def patch_method_for_qwen3_5(model):
from transformers.models.qwen3_5.modeling_qwen3_5 import (
Qwen3_5GatedDeltaNet,
Qwen3_5Model,
Qwen3_5TextModel,
)
from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import (
Qwen3_5MoeGatedDeltaNet,
Qwen3_5MoeModel,
Qwen3_5MoeTextModel,
)
def _get_local_linear_attn_heads(self):
tp_size = getattr(self, "tensor_cast_tp_size", 1)
return (
exact_division(self.num_k_heads, tp_size),
exact_division(self.num_v_heads, tp_size),
)
def _patched_update_linear_attn_mask(self, attention_mask, cache_position):
if torch.compiler.is_compiling():
return None
if cache_position is not None and cache_position.device.type == "meta":
return attention_mask
linear_attn_mask = attention_mask
is_meta_tensor = (hasattr(cache_position, "is_meta") and cache_position.is_meta) or (
attention_mask is not None and hasattr(attention_mask, "is_meta") and attention_mask.is_meta
)
if is_meta_tensor:
return None
try:
if cache_position is None:
cache_condition = False
else:
cache_condition = cache_position[0] > 0 if cache_position.numel() > 0 else False
mask_condition = (
torch.all(attention_mask == 1).item()
if attention_mask is not None and attention_mask.numel() > 0
else False
)
if cache_condition or mask_condition:
linear_attn_mask = None
except RuntimeError:
return None
return linear_attn_mask
def _patched_linear_attn_forward(
self,
hidden_states: torch.Tensor,
cache_params=None,
cache_position=None,
attention_mask=None,
):
del cache_params
local_num_k_heads, local_num_v_heads = _get_local_linear_attn_heads(self)
return torch.ops.tensor_cast.linear_attention(
hidden_states,
attention_mask,
cache_position,
local_num_k_heads,
local_num_v_heads,
self.head_k_dim,
self.head_v_dim,
self.conv_kernel_size,
)
target_classes = [Qwen3_5Model, Qwen3_5MoeModel]
original_methods = {cls: cls.get_placeholder_mask for cls in target_classes}
def _patched_get_placeholder_mask(self, *args, **kwargs):
kwargs["image_features"] = None
return original_methods[type(self)](self, *args, **kwargs)
Qwen3_5TextModel._update_linear_attn_mask = _patched_update_linear_attn_mask
Qwen3_5MoeTextModel._update_linear_attn_mask = _patched_update_linear_attn_mask
Qwen3_5GatedDeltaNet.forward = _patched_linear_attn_forward
Qwen3_5MoeGatedDeltaNet.forward = _patched_linear_attn_forward
for cls in target_classes:
cls.get_placeholder_mask = _patched_get_placeholder_mask
_set_qwen3_5_linear_attn_tp_size(model)
# Register the Qwen3.5-MoE profile with the custom model registry. It groups:
# - identity and patch hook (model_type / model_family / patch_method);
# - the MTP block module name;
# - the MoE layout (sparse block class, expert-count config key path,
#   gate-output mode, shared-expert field names);
# - the visual-config entries resolved at module import.
register_model_profile(
    ModelProfile(
        model_type="qwen3_5_moe",
        model_family="qwen3_5",
        patch_method=patch_method_for_qwen3_5,
        mtp_block_module_name="Qwen3_5MoeDecoderLayer",
        moe_module_name="Qwen3_5MoeSparseMoeBlock",
        moe_num_experts_key=["text_config", "num_experts"],
        moe_gate_returns_raw_logits=False,
        moe_field_names_override=MoEFieldNames(
            shared_experts="shared_expert",
            shared_experts_gate="shared_expert_gate",
        ),
        **QWEN3_5_VISUAL_CONFIG,
    )
)