import sys
from pathlib import Path
from typing import Any, Dict, Generic, List, Optional, TypeVar, Union
import torch
from transformers import (
AutoConfig,
AutoImageProcessor,
AutoModelForCausalLM,
AutoProcessor,
AutoTokenizer,
GenerationConfig,
PreTrainedTokenizer,
ProcessorMixin,
)
from transformers.generation.utils import GenerateOutput
from bridge.models.hf_pretrained.base import PreTrainedBase
from bridge.models.hf_pretrained.safe_config_loader import safe_load_config_with_retry
if sys.version_info >= (3, 12):
from typing import TypedDict, Unpack
else:
from typing_extensions import TypedDict, Unpack
CausalLMType = TypeVar("CausalLMType", bound=AutoModelForCausalLM)
class PreTrainedCausalLM(PreTrainedBase, Generic[CausalLMType]):
"""
A generic class for Pretrained Causal Language Models with lazy loading.
"""
ARTIFACTS = ["tokenizer"]
OPTIONAL_ARTIFACTS = ["generation_config", "processor", "image_processor"]
def __init__(
self,
model_name_or_path: Optional[Union[str, Path]] = None,
device: Optional[Union[str, torch.device]] = None,
torch_dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
**kwargs,
):
"""
Initialize a Pretrained Causal LM with lazy loading.
Args:
model_name_or_path: HuggingFace model identifier or local path
device: Device to load model on (e.g., 'cuda', 'cpu')
torch_dtype: Data type to load model in (e.g., torch.float16)
trust_remote_code: Whether to trust remote code when loading
**kwargs: Additional arguments passed to from_pretrained methods
"""
self._model_name_or_path = model_name_or_path
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
self.torch_dtype = torch_dtype
self.trust_remote_code = trust_remote_code
super().__init__(**kwargs)
if model_name_or_path and trust_remote_code:
self._original_source_path = model_name_or_path
def _load_model(self) -> CausalLMType:
"""Load the model."""
if self.model_name_or_path is None:
raise ValueError("model_name_or_path must be provided to load model")
model_kwargs = {
"trust_remote_code": self.trust_remote_code,
**self.init_kwargs,
}
if self.torch_dtype is not None:
model_kwargs["torch_dtype"] = self.torch_dtype
config = getattr(self, "_config", None)
if config is not None:
model_kwargs["config"] = config
model = AutoModelForCausalLM.from_pretrained(self.model_name_or_path, **model_kwargs)
model = model.to(self.device)
generation_config = getattr(self, "_generation_config", None)
if generation_config is not None and hasattr(model, "generation_config"):
model.generation_config = generation_config
return model
def _load_config(self) -> AutoConfig:
"""Load the model config with thread-safety protection."""
if self.model_name_or_path is None:
raise ValueError("model_name_or_path must be provided to load config")
return safe_load_config_with_retry(
self.model_name_or_path,
trust_remote_code=self.trust_remote_code,
**self.init_kwargs,
)
@property
def auto_map_model_class(self) -> Optional[str]:
"""Get the AutoModelForCausalLM class from the config."""
config = self.config
auto_map = getattr(config, "auto_map", None)
if auto_map and "AutoModelForCausalLM" in auto_map:
auto_map_class = auto_map["AutoModelForCausalLM"]
return str(auto_map_class)
return None
@property
def model_name_or_path(self) -> Optional[Union[str, Path]]:
"""Return the model name or path."""
return self._model_name_or_path
@classmethod
def from_pretrained(
cls,
model_name_or_path: Union[str, Path],
device: Optional[Union[str, torch.device]] = None,
torch_dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
**kwargs,
) -> "PreTrainedCausalLM[CausalLMType]":
"""
Create a PreTrainedCausalLM instance for lazy loading.
Args:
model_name_or_path: HuggingFace model identifier or local path
device: Device to load model on
torch_dtype: Data type to load model in
trust_remote_code: Whether to trust remote code
**kwargs: Additional arguments for from_pretrained methods
Returns:
PreTrainedCausalLM instance configured for lazy loading
"""
return cls(
model_name_or_path=model_name_or_path,
device=device,
torch_dtype=torch_dtype,
trust_remote_code=trust_remote_code,
**kwargs,
)
@property
def dtype(self) -> Optional[torch.dtype]:
"""Get model's dtype if loaded."""
if self.has_model:
try:
return next(self.model.parameters()).dtype
except StopIteration:
return None
return None
@property
def num_parameters(self) -> Optional[int]:
"""Get total number of parameters if model is loaded."""
if self.has_model:
return sum(p.numel() for p in self.model.parameters())
return None