import importlib
from dataclasses import dataclass
from typing import List, Tuple, Dict, Optional, Set, Counter
from collections import defaultdict, Counter
import torch
import torch.distributed as dist
from mindspeed_llm.fsdp2.distributed.context_parallel.context_parallel_mappings import COMMON_CP_MAPPINGS, MODEL_CP_CONFIG
from mindspeed_llm.fsdp2.distributed.parallel_engine_config import CPPlanConfig
@dataclass
class CPTypeConfig:
cp_type: str
cp_specific_mappings: Optional[List[Tuple[str, str]]] = None
def __post_init__(self):
self.cp_specific_mappings = self.cp_specific_mappings or []
class ContextParallelMappingManager:
COMMON_CP_MAPPINGS = COMMON_CP_MAPPINGS
MODEL_CP_CONFIG = MODEL_CP_CONFIG
def __init__(self):
self.model_id_to_att_type: Dict[str, str] = self._generate_model_id_mapping()
def _generate_model_id_mapping(self) -> Dict[str, str]:
mapping = {}
for att_type, config in self.MODEL_CP_CONFIG.items():
for model_id in config["models"]:
mapping[model_id] = att_type
return mapping
def detect_duplicate_mappings(self, mappings: List[Tuple[str, str]], model_id: str) -> None:
if not mappings:
return
tuple_counter = Counter(mappings)
duplicate_tuples = {item: cnt for item, cnt in tuple_counter.items() if cnt > 1}
if duplicate_tuples:
duplicate_info = "\n - ".join([f"{item} (appeared {cnt} times)" for item, cnt in duplicate_tuples.items()])
warning_msg = f"[Model {model_id}] Detected duplicate mapping tuples:\n - {duplicate_info}"
raise ValueError(f"WARNING: {warning_msg}")
target_counter = defaultdict(list)
for idx, (target, patch) in enumerate(mappings):
target_counter[target].append((idx, patch))
duplicate_targets = {t: patches for t, patches in target_counter.items() if len(patches) > 1}
if duplicate_targets:
duplicate_info = []
for target, patches in duplicate_targets.items():
patch_info = ", ".join([f"Mapping #{idx + 1}: {patch}" for idx, patch in patches])
duplicate_info.append(f"{target} -> [{patch_info}]")
warning_msg = f"[Model {model_id}] Detected duplicate target function (same target replaced multiple times):\n - " + "\n - ".join(
duplicate_info)
raise ValueError(f"ERROR: {warning_msg}")
def _get_model_attention_config(self, model_id: str) -> Tuple[str, List[Tuple[str, str]]]:
if model_id not in self.model_id_to_att_type:
supported_models = sorted(list(self.model_id_to_att_type.keys()))
raise ValueError(
f"Unsupported model type: '{model_id}'\n"
f"Context Parallel only supports the following models: {', '.join(supported_models)}"
)
att_type = self.model_id_to_att_type[model_id]
model_specific_mappings = self.MODEL_CP_CONFIG[att_type]["model_specific"].get(model_id, [])
return att_type, model_specific_mappings
def get_model_cp_mappings(self, model_id: str, cp_type_config: CPTypeConfig) -> List[Tuple[str, str]]:
"""
Core method: Generate mappings with override logic.
Priority: model_specific > common templates
"""
mapping_dict: Dict[str, str] = {}
cp_type = cp_type_config.cp_type
att_type, model_specific_mappings = self._get_model_attention_config(model_id)
cp_att_templates = self.MODEL_CP_CONFIG[att_type]["cp_to_att_template"]
for template_name in cp_att_templates:
template = self.COMMON_CP_MAPPINGS[template_name]
mappings_to_add = template if isinstance(template, list) else [template]
for target, patch in mappings_to_add:
target_filled = target.format(model_id=model_id) if "{model_id}" in target else target
mapping_dict[target_filled] = patch
for target, patch in model_specific_mappings:
target_filled = target.format(model_id=model_id) if "{model_id}" in target else target
if target_filled in mapping_dict:
old_patch = mapping_dict[target_filled]
print(f"[Model {model_id}] OVERRIDE CONFLICT")
print(f" Target: {target_filled}")
print(f" Old patch: {old_patch}")
print(f" New patch: {patch}")
print(f" Reason: model_specific takes precedence over common templates\n")
mapping_dict[target_filled] = patch
final_mappings = [(target, patch) for target, patch in mapping_dict.items()]
final_mappings.extend(cp_type_config.cp_specific_mappings)
self.detect_duplicate_mappings(final_mappings, model_id)
return final_mappings
def apply_transformers_modules(self, modules: torch.nn.Module, cp_type_config: CPTypeConfig) -> None:
model_id = modules.config.model_type
model_patch_list = self.get_model_cp_mappings(model_id, cp_type_config)
for target_name, func_patch_name in model_patch_list:
parts = target_name.rsplit(".", 1)
if len(parts) == 2:
obj_path, method_name = parts
else:
obj_path = target_name
method_name = None
obj_parts = obj_path.rsplit(".", 1)
if len(obj_parts) == 2:
module_path, obj_name = obj_parts
else:
module_path = obj_path
obj_name = None
try:
target_module = importlib.import_module(module_path)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
f"Failed to import module '{module_path}' for target '{target_name}'. "
f"Original error: {str(e)}"
)
if obj_name:
target_obj = getattr(target_module, obj_name)
patch_parts = func_patch_name.rsplit(".", 1)
patch_module_path = patch_parts[0]
patch_func_name = patch_parts[1]
patch_module = importlib.import_module(patch_module_path)
patch_func = getattr(patch_module, patch_func_name)
setattr(target_obj, method_name, patch_func)
if not dist.is_initialized() or dist.get_rank() == 0:
print(f"CP Applied patch: {target_name} -> {func_patch_name}")
else:
target_func_name = obj_path.split(".")[-1]
patch_parts = func_patch_name.rsplit(".", 1)
patch_module_path = patch_parts[0]
patch_func_name = patch_parts[1]
patch_module = importlib.import_module(patch_module_path)
target_module.__dict__[target_func_name] = patch_module.__dict__[patch_func_name]
if not dist.is_initialized() or dist.get_rank() == 0:
print(f"CP Applied patch: {target_name} -> {func_patch_name}")
@classmethod
def cp_parallelize_modules(cls, modules: torch.nn.Module, plan: CPPlanConfig) -> None:
if plan.context_parallel_type not in ["ulysses", "ring"]:
raise ValueError(f"Unsupported CP type: {plan.context_parallel_type}, only 'ulysses'/'ring' are supported")
manager = cls()
cp_type_config = CPTypeConfig(cp_type=plan.context_parallel_type)
manager.apply_transformers_modules(modules, cp_type_config)
def apply_context_parallelize_modules(modules: torch.nn.Module, plan: CPPlanConfig):
"""Compatibility wrapper for original ulysses function"""
ContextParallelMappingManager.cp_parallelize_modules(modules, plan)