"""
Global Common Mapping Templates
===============================
Purpose: Reusable mappings for all models. DO NOT modify unless adding new templates.
Each template is identified by a unique name and maps original functions to their
Context Parallel replacements.
Fields:
- Key: Template name (referenced in MODEL_CP_CONFIG -> cp_to_att_template)
- Value: Either a single tuple or a list of tuples
- Single tuple (target_path, patch_path): One original function to replace
- List of tuples: Multiple original functions to replace
Example:
"template_name": (
"original.module.function_name", # Function to replace
"patch.module.new_function" # Replacement function
)
Available templates:
fixed_cross_entropy - Replaces cross entropy loss with CP version
cp_attention_unified - Replaces attention forward with unified CP version
dsa_attention - Replaces attention forward and transformers' sdpa_mask for DSA models
"""
from typing import Dict, List, Tuple, Union
# Registry of reusable Context Parallel patch templates, keyed by template name.
# Each value is either one (target_path, patch_path) pair or a list of such pairs.
# `Union[...]` is spelled out instead of `X | Y`: this annotation is evaluated when
# the module is imported, and `|` between `typing` generics raises TypeError before
# Python 3.10.
# NOTE(review): `{model_id}` in a target path looks like a placeholder that is
# filled in by whatever applies these patches -- confirm against the consumer.
COMMON_CP_MAPPINGS: Dict[str, Union[Tuple[str, str], List[Tuple[str, str]]]] = {
    # Single pair: swaps the transformers cross-entropy helper for the CP variant.
    "fixed_cross_entropy": (
        "transformers.loss.loss_utils.fixed_cross_entropy",
        "mindspeed_llm.fsdp2.distributed.context_parallel.context_parallel_functions.fixed_cross_entropy_with_cp",
    ),
    # Single pair: swaps the per-model eager attention forward for the CP forward.
    "cp_attention_unified": (
        "transformers.models.{model_id}.modeling_{model_id}.eager_attention_forward",
        "mindspeed_llm.fsdp2.distributed.context_parallel.context_parallel_functions.context_parallel_attention_forward",
    ),
    # List of pairs: both the attention forward and the SDPA mask are replaced.
    "dsa_attention": [
        (
            "transformers.models.{model_id}.modeling_{model_id}.eager_attention_forward",
            "mindspeed_llm.fsdp2.distributed.context_parallel.ulysses_context_parallel.dsa_attention.flash_attention_forward_fa_dsa",
        ),
        (
            "transformers.masking_utils.sdpa_mask",
            "mindspeed_llm.fsdp2.distributed.context_parallel.ulysses_context_parallel.dsa_attention.sdpa_mask",
        ),
    ],
}
"""
Model Group Configuration
=========================
Purpose: Group models by type, bind common templates, and define model-specific overrides.
Each model group contains:
models: Set of supported model names that use this group configuration
cp_to_att_template: Template names from COMMON_CP_MAPPINGS to apply
model_specific: Per-model overrides that take highest priority over common templates
Configuration fields explained:
group_name:
models - Model identifiers (e.g., "gpt_oss", "qwen2")
cp_to_att_template - Set of template names to apply
model_specific:
model_name - List of (target, patch) tuples for this specific model only
Priority order (highest to lowest):
1. model_specific (per-model overrides)
2. cp_to_att_template (common templates)
"""
# Per-group Context Parallel patch configuration, keyed by group name.
# Every group carries three entries:
#   "models"             - set of model names served by the group
#   "cp_to_att_template" - set of template names defined in COMMON_CP_MAPPINGS
#   "model_specific"     - model name -> list of (target_path, patch_path) pairs
# NOTE(review): `{model_id}` in a target path looks like a placeholder resolved
# by the code that consumes this table -- confirm there.
MODEL_CP_CONFIG: Dict[str, Dict] = {
    # Group bound to the unified CP attention template.
    "cp_attention_common": {
        "models": {"gpt_oss", "qwen3_moe", "qwen2", "deepseek_v3", "qwen3"},
        "cp_to_att_template": {"fixed_cross_entropy", "cp_attention_unified"},
        "model_specific": {
            # gpt_oss-only pair: targets the mindspeed_llm gpt_oss flash attention forward.
            "gpt_oss": [
                (
                    "mindspeed_llm.fsdp2.models.gpt_oss.modeling_gpt_oss.flash_attention_forward",
                    "mindspeed_llm.fsdp2.distributed.context_parallel.context_parallel_functions.context_parallel_attention_forward",
                ),
            ],
        },
    },
    # Group bound to the DSA attention template.
    "dsa": {
        "models": {"glm_moe_dsa"},
        "cp_to_att_template": {"fixed_cross_entropy", "dsa_attention"},
        "model_specific": {
            # glm_moe_dsa-only pair: replaces GlmMoeDsaAttention.forward.
            "glm_moe_dsa": [
                (
                    "transformers.models.{model_id}.modeling_{model_id}.GlmMoeDsaAttention.forward",
                    "mindspeed_llm.fsdp2.distributed.context_parallel.ulysses_context_parallel.dsa_attention.dsa_forward",
                ),
            ],
        },
    },
    # Group for qwen3_next: only the cross-entropy template is shared; both
    # attention-related patches are listed per model.
    "qwen3_next_gdn": {
        "models": {"qwen3_next"},
        "cp_to_att_template": {"fixed_cross_entropy"},
        "model_specific": {
            "qwen3_next": [
                # Gated delta net forward -> CP-aware GDN forward.
                (
                    "mindspeed_llm.fsdp2.models.qwen3_next.modeling_qwen3_next.Qwen3NextGatedDeltaNet.forward",
                    "mindspeed_llm.fsdp2.distributed.context_parallel.gdn_context_parallel.gdn_forward_with_cp",
                ),
                # Flash attention forward -> CP attention forward.
                (
                    "mindspeed_llm.fsdp2.models.qwen3_next.modeling_qwen3_next.flash_attention_forward",
                    "mindspeed_llm.fsdp2.distributed.context_parallel.context_parallel_functions.context_parallel_attention_forward",
                ),
            ],
        },
    },
}