import contextlib
import logging
import os
import torch
from typing import List, Optional, Tuple
from transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
from transformers.quantizers.auto import AutoQuantizationConfig
from transformers.utils.quantization_config import (
CompressedTensorsConfig,
FineGrainedFP8Config,
QuantizationConfigMixin,
)
from .custom_model_registry import get_model_profile
from ..layers.mla import MultiheadLatentAttentionBase
from ..model_config import AttentionQuantConfig, ModelConfig, RemoteSource
logger = logging.getLogger(__name__)
_MODELSCOPE_WEIGHT_IGNORE_PATTERNS = [
"*.safetensors",
"*.safetensors.index.json",
"*.bin",
"*.pt",
"*.pth",
"*.ckpt",
"*.h5",
"*.npz",
"*.onnx",
"*.gguf",
"*.zip",
"*.tar",
"*.tar.gz",
]
def _modelscope_snapshot_config_only(model_id: str) -> str:
"""
Materialize a local Hub directory with config and code files only (no weight tensors).
ModelScope ``AutoConfig.from_pretrained`` may otherwise sync the full repository.
"""
from modelscope import snapshot_download
try:
return snapshot_download(model_id, ignore_patterns=_MODELSCOPE_WEIGHT_IGNORE_PATTERNS)
except TypeError as e:
if "ignore_patterns" not in str(e):
raise
return snapshot_download(model_id, ignore_file_pattern=_MODELSCOPE_WEIGHT_IGNORE_PATTERNS)
def replace_module(model, name: str, new_module: torch.nn.Module):
path = name.split(".")
parent_name = ".".join(path[:-1])
child_name = path[-1]
parent_module = model
if parent_name:
parent_module = model.get_submodule(parent_name)
setattr(parent_module, child_name, new_module)
def strip_module_name(name: str) -> str:
"""Strip `_inner` module name from the given module path name"""
stripped = name.removeprefix("_inner.")
stripped_before = name
while stripped != stripped_before:
stripped_before = stripped
stripped = stripped_before.removeprefix("_inner.")
stripped = stripped.replace("._inner.", ".")
stripped_before = stripped
stripped = stripped_before.removesuffix("._inner")
while stripped != stripped_before:
stripped_before = stripped
stripped = stripped_before.removesuffix("._inner")
return stripped
def get_attention_quant_config(model, layer_idx) -> Optional[AttentionQuantConfig]:
if model.model_config.mla_config is not None:
for _, module in model._inner.named_modules():
if (
isinstance(module, MultiheadLatentAttentionBase)
and hasattr(module, "layer_idx")
and module.layer_idx == layer_idx
and (attn_quant_config := module.quant_config) is not None
):
return attn_quant_config
if hasattr(model, "attention_by_layers") and layer_idx in model.attention_by_layers:
return model.attention_by_layers[layer_idx].quant_config
return None
@contextlib.contextmanager
def init_on_device_without_buffers(device: torch.device):
"""
A context manager under which models are initialized with all
parameters on the specified device. However, buffers are not
initialized on specified device.
Args:
device (`torch.device`):
Device to initialize all parameters on.
"""
old_register_parameter = torch.nn.Module.register_parameter
def register_empty_parameter(module, name, param):
old_register_parameter(module, name, param)
if param is not None:
param_cls = type(module._parameters[name])
kwargs = module._parameters[name].__dict__
kwargs["requires_grad"] = param.requires_grad
module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
tensor_constructors_to_patch = [
"empty",
"zeros",
"ones",
"arange",
"randn",
"rand",
"randint",
]
old_tensor_constructors = {}
def patch_tensor_constructor(fn):
def wrapper(*args, **kwargs):
kwargs["device"] = device
return fn(*args, **kwargs)
return wrapper
try:
torch.nn.Module.register_parameter = register_empty_parameter
for torch_function_name in tensor_constructors_to_patch:
old_tensor_constructors[torch_function_name] = getattr(torch, torch_function_name)
setattr(
torch,
torch_function_name,
patch_tensor_constructor(getattr(torch, torch_function_name)),
)
yield
finally:
torch.nn.Module.register_parameter = old_register_parameter
for torch_function_name, old_torch_function in old_tensor_constructors.items():
setattr(torch, torch_function_name, old_torch_function)
@contextlib.contextmanager
def patch_find_packed_sequence_indices_for_meta():
"""
This function tells the model which tokens belong to the same sentence
when multiple sentences are packed into one batch.
But during performance modeling (e.g., estimating memory or compute),
we don’t care about how sequences are packed—we only need the model’s structure (like top_k=2, num_experts=64).
Returning None simply means “assume no packing,” which is a safe and reasonable default for modeling.
Even if real inference uses packing, it doesn’t change the model’s architecture, parameters,
or compute graph—so performance estimates remain accurate.
"""
from transformers import masking_utils
original_func = masking_utils.find_packed_sequence_indices
def safe_find_packed_sequence_indices(position_ids: torch.Tensor):
if position_ids.device.type == "meta":
return None
return original_func(position_ids)
masking_utils.find_packed_sequence_indices = safe_find_packed_sequence_indices
try:
yield
finally:
masking_utils.find_packed_sequence_indices = original_func
class AutoModelConfigLoader:
modules_to_not_convert_map = {
"fp8": "modules_to_not_convert",
"fp_quant": "modules_to_not_convert",
"compressed-tensors": "ignore",
}
def __init__(self):
self.is_transformers_natively_supported: bool = False
@staticmethod
def is_model_type_different(config: PretrainedConfig) -> Tuple[bool, str]:
"""
Check whether the model type has changed.
for example: kimi_k2's real model_type is deepseek_v3
Args:
config: hf_config.
Returns:
tuple: (is_different, type)
- (False, original_type) if the types are the same
- (True, current_type) if the types are different
"""
maybe_real_type = config.to_dict()["model_type"]
if maybe_real_type and config.model_type != maybe_real_type:
return True, maybe_real_type
return False, config.model_type
@staticmethod
def check_model_path(path):
"""
Check whether a config.json file and Python files starting with 'configuration' exist in the specified path.
Args:
path (str): The directory path to check.
Returns:
dict: A dictionary containing the check results:
- has_config_json (bool): Whether config.json exists.
- has_configuration_py (bool): Whether any Python file starting with 'configuration' exists.
- configuration_py_files (list[str]): List of Python files starting with 'configuration'.
"""
result = {
"has_config_json": False,
"has_configuration_py": False,
"configuration_py_files": [],
}
if not os.path.exists(path) or not os.path.isdir(path):
return result
for file in os.listdir(path):
if file == "config.json":
result["has_config_json"] = True
elif file.startswith("configuration") and file.endswith(".py"):
result["has_configuration_py"] = True
result["configuration_py_files"].append(file)
return result
def load_config(self, model_id: str, remote_source: str = RemoteSource.huggingface) -> Optional[PretrainedConfig]:
"""
load config
"""
if remote_source == RemoteSource.modelscope:
from modelscope import AutoConfig
else:
from transformers import AutoConfig
if remote_source == RemoteSource.modelscope and not os.path.isdir(model_id):
resolved = _modelscope_snapshot_config_only(model_id)
logger.info(
"ModelScope Hub id %s resolved to config-only snapshot at %s",
model_id,
resolved,
)
model_id = resolved
check_model_path_res = self.check_model_path(model_id)
if check_model_path_res["has_config_json"] and not check_model_path_res["has_configuration_py"]:
model_id = os.path.join(
model_id, "config.json"
)
try:
hf_config = AutoConfig.from_pretrained(model_id, trust_remote_code=False)
self.is_transformers_natively_supported = True
except Exception:
hf_config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
is_diff, real_type = self.is_model_type_different(hf_config)
if is_diff:
logger.warning("Using a model of type %s to instantiate again.", real_type)
hf_config = AutoConfig.for_model(real_type).from_dict(hf_config.to_dict())
self.is_transformers_natively_supported = True
logger.info(
"is_transformers_natively_supported = %s",
self.is_transformers_natively_supported,
)
return hf_config
def _apply_hf_config_patches(self, hf_config: PretrainedConfig, model_id: str):
model_type = getattr(hf_config, "model_type", None)
if model_type is None:
return
profile = get_model_profile(model_type)
if profile is not None and profile.hf_config_patch_method is not None:
try:
profile.hf_config_patch_method(hf_config, model_id)
except Exception as e:
logger.warning(f"Failed to apply HF config patches for {model_type}: {e}")
def load_model(
self,
hf_config: PretrainedConfig,
dtype: torch.dtype,
remote_source: str = RemoteSource.huggingface,
**kwargs,
) -> Optional[PreTrainedModel]:
trust_remote_code = not self.is_transformers_natively_supported
if "trust_remote_code" in kwargs:
trust_remote_code = kwargs.pop("trust_remote_code")
return self.try_to_load_model(
hf_config,
dtype=dtype,
trust_remote_code=trust_remote_code,
remote_source=remote_source,
)
@staticmethod
def load_quant_config(hf_config: PretrainedConfig) -> QuantizationConfigMixin:
quant_config = AutoQuantizationConfig.from_dict(hf_config.quantization_config)
return quant_config
@staticmethod
def get_modules_to_not_convert(quant_config) -> List[Optional[str]]:
modules_to_not_convert = []
if isinstance(quant_config, FineGrainedFP8Config):
modules_to_not_convert = quant_config.modules_to_not_convert
elif isinstance(quant_config, CompressedTensorsConfig):
modules_to_not_convert = quant_config.quantization_config.ignore
return modules_to_not_convert
def auto_load_model_and_config(
self, model_id: str, model_config: ModelConfig
) -> Tuple[PretrainedConfig, PreTrainedModel]:
"""
Load the model and config using model_id and model_config.
"""
hf_config = self.load_config(model_id, remote_source=model_config.remote_source)
if model_config.num_hidden_layers_override:
hf_config.num_hidden_layers = model_config.num_hidden_layers_override
self._apply_hf_config_patches(hf_config, model_id)
hf_model = self.load_model(hf_config, model_config.dtype, remote_source=model_config.remote_source)
return hf_config, hf_model
@staticmethod
def try_to_load_model(*args, remote_source: str = RemoteSource.huggingface, **kwarg):
if remote_source == RemoteSource.modelscope:
from modelscope import AutoModel
else:
from transformers import AutoModel
try:
hf_model = AutoModel.from_config(*args, **kwarg)
except Exception:
hf_model = AutoModelForCausalLM.from_config(*args, **kwarg)
return hf_model