import json
import os
from functools import cached_property
from pathlib import Path
from typing import Generic, Type, TypeVar, Union
import transformers
from megatron.core.transformer.module import MegatronModule
from transformers.configuration_utils import PretrainedConfig
from bridge.models.conversion import model_bridge
from bridge.models.conversion.model_bridge import MegatronModelBridge
from bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM
# Unconstrained type variable; not referenced in the visible code
# (presumably used elsewhere in the package — verify before removing).
DataclassT = TypeVar("DataclassT")
# Type parameter of AutoBridge: the concrete Megatron model type it produces,
# restricted to MegatronModule subclasses.
MegatronModelT = TypeVar("MegatronModelT", bound=MegatronModule)
# Suffixes accepted for entries of a Hugging Face config's ``architectures``
# list; the first entry ending with any of these selects the bridge.
SUPPORTED_HF_ARCHITECTURES: tuple[str, ...] = (
    "ForCausalLM",
    "ForConditionalGeneration",
    "NemotronH_Nano_VL_V2",
)

# Maps a ``_class_name`` found in a local ``config.json`` to the
# (module, attribute) pair that is imported to obtain the architecture class.
CLASS_MODULE_MAPPING: dict[str, tuple[str, str]] = {
    "WanTransformer3DModel": ("bridge.models", "WanTransformer3DModel"),
    "HunyuanVideo_1_5_DiffusionTransformer": (
        "bridge.models",
        "HunyuanVideo_1_5_DiffusionTransformer",
    ),
}

# Quoted, " or "-separated rendering of the suffixes for error messages.
SUPPORTED_HF_ARCHITECTURES_DISPLAY = " or ".join(
    map("'{}'".format, SUPPORTED_HF_ARCHITECTURES)
)
class AutoBridge(Generic[MegatronModelT]):
    """Select the model bridge for a Hugging Face model/config and load its weights.

    The architecture is determined either from a ``_class_name`` key in a local
    ``config.json`` (resolved through ``CLASS_MODULE_MAPPING``) or from the first
    ``config.architectures`` entry ending with one of
    ``SUPPORTED_HF_ARCHITECTURES`` (resolved as a ``transformers`` attribute).
    """

    def __init__(self, hf_pretrained: PreTrainedCausalLM | PretrainedConfig):
        """Wrap an already-constructed HF model wrapper or config.

        Args:
            hf_pretrained: A ``PreTrainedCausalLM`` or a ``PretrainedConfig``.

        Raises:
            ValueError: If ``hf_pretrained`` is neither of those types.
        """
        if not isinstance(hf_pretrained, (PreTrainedCausalLM, PretrainedConfig)):
            raise ValueError("hf_pretrained must be a PreTrainedCausalLM or PretrainedConfig instance")
        self.hf_pretrained: PreTrainedCausalLM | PretrainedConfig = hf_pretrained

    @classmethod
    def from_hf_pretrained(cls, path: Union[str, Path], **kwargs) -> "AutoBridge":
        """Load ``path`` with ``PreTrainedCausalLM.from_pretrained`` and wrap it.

        Args:
            path: Checkpoint location passed to ``PreTrainedCausalLM.from_pretrained``.
            **kwargs: Forwarded unchanged to ``PreTrainedCausalLM.from_pretrained``.

        Raises:
            ValueError: Wrapping any exception raised while loading (chained).
        """
        try:
            return cls(PreTrainedCausalLM.from_pretrained(path, **kwargs))
        except Exception as e:
            # Deliberately broad: callers get a single exception type for load failures.
            raise ValueError(f"Failed to load model with AutoBridge: {e}") from e

    def load_hf_weights(
        self,
        model: list[MegatronModelT],
        hf_path: str | Path | None = None,
        allowed_mismatched_params: list[str] | None = None,
    ) -> list[MegatronModelT]:
        """Copy HF weights into ``model`` through the resolved model bridge.

        Args:
            model: Megatron model chunks handed to
                ``load_weights_hf_to_megatron`` (presumably updated in place — confirm
                against the bridge implementation).
            hf_path: Checkpoint to read weights from. When ``None``, the wrapped
                ``PreTrainedCausalLM`` is used.
            allowed_mismatched_params: Forwarded to the bridge unchanged.

        Returns:
            The same ``model`` list that was passed in.

        Raises:
            ValueError: If ``hf_path`` is ``None`` and this bridge only wraps a config.
        """
        if hf_path is None:
            if not isinstance(self.hf_pretrained, PreTrainedCausalLM):
                raise ValueError("hf_path is required when hf_pretrained is not a PreTrainedCausalLM instance")
            pre_trained = self.hf_pretrained
        else:
            # Reuse the remote-code setting of the wrapped object when loading a new path.
            trust_remote_code = getattr(self.hf_pretrained, "trust_remote_code", False)
            pre_trained = PreTrainedCausalLM.from_pretrained(hf_path, trust_remote_code=trust_remote_code)
        self._model_bridge.load_weights_hf_to_megatron(
            pre_trained, model, allowed_mismatched_params=allowed_mismatched_params
        )
        return model

    @property
    def _model_bridge(self) -> "MegatronModelBridge":
        """Bridge registered for this model's architecture (looked up on each access)."""
        return model_bridge.get_model_bridge(self._model_architecture)

    @cached_property
    def _model_architecture(self):
        """Class identifying this model's architecture; computed once per instance.

        A ``_class_name`` key in a local ``config.json`` takes precedence; otherwise
        the HF config's ``architectures`` list is consulted.

        Raises:
            KeyError: ``_class_name`` has no entry in ``CLASS_MODULE_MAPPING``.
            ImportError: The mapped module/attribute cannot be imported.
            ValueError: No usable entry in ``architectures`` or class missing from
                ``transformers``.
        """
        local_config = self._read_local_config_json()
        if isinstance(local_config, dict) and "_class_name" in local_config:
            return self._resolve_generation_model_architecture(local_config["_class_name"], local_config)
        return self._resolve_causal_lm_architecture()

    def _read_local_config_json(self):
        """Return the parsed ``config.json`` in the model directory, or ``{}`` if absent.

        ``PreTrainedCausalLM`` exposes ``model_name_or_path``; a bare
        ``PretrainedConfig`` does not (it provides ``name_or_path``), so fall back to
        that rather than raising ``AttributeError``. An empty or missing path yields
        ``{}`` instead of resolving ``config.json`` against the working directory.
        """
        source = getattr(self.hf_pretrained, "model_name_or_path", None)
        if source is None:
            source = getattr(self.hf_pretrained, "name_or_path", None)
        if not source:
            return {}
        config_path = os.path.join(source, "config.json")
        if not os.path.exists(config_path):
            # e.g. a hub id rather than a local directory.
            return {}
        with open(config_path, "r", encoding="utf-8") as f:
            return json.load(f)

    def _resolve_causal_lm_architecture(self) -> Type:
        """Map the first supported ``config.architectures`` entry to its ``transformers`` class.

        Raises:
            ValueError: If ``architectures`` is empty/missing, no entry ends with a
                supported suffix, or the class is not an attribute of ``transformers``.
        """
        if isinstance(self.hf_pretrained, PreTrainedCausalLM):
            hf_config = self.hf_pretrained.config
        else:
            hf_config = self.hf_pretrained
        # ``architectures`` may be absent or None on a config; treat both as empty.
        architectures = getattr(hf_config, "architectures", None) or []
        if not architectures:
            raise ValueError(
                "\n✗ No architectures found in model config\n\n"
                "The model configuration does not specify any architectures.\n"
                "This is required for determining the model type."
            )
        causal_lm_arch = next(
            (name for name in architectures if name.endswith(SUPPORTED_HF_ARCHITECTURES)),
            None,
        )
        if not causal_lm_arch:
            raise ValueError(
                f"\n✗ No CausalLM architecture found\n\n"
                f"Model architectures: {architectures}\n\n"
                f"None of the architectures end with {SUPPORTED_HF_ARCHITECTURES_DISPLAY}.\n"
                f"This bridge only supports causal language models.\n"
                f"For other model types, use a different bridge class."
            )
        try:
            return getattr(transformers, causal_lm_arch)
        except AttributeError as e:
            raise ValueError(
                f"\n✗ Architecture class '{causal_lm_arch}' not found in transformers\n\n"
                f"This could mean:\n"
                f"1. The model requires a newer version of transformers\n"
                f"2. The model uses a custom modeling file not in the standard library\n"
                f"3. There's a typo in the architecture name\n\n"
                f"Please verify your transformers installation and the model requirements."
            ) from e

    def _resolve_generation_model_architecture(self, class_name: str, config) -> Type:
        """Import the class registered for ``class_name`` in ``CLASS_MODULE_MAPPING``.

        Args:
            class_name: The ``_class_name`` value read from ``config.json``.
            config: The parsed ``config.json``; not used here, kept for callers.

        Raises:
            KeyError: If ``class_name`` has no mapping.
            ImportError: If the mapped module or attribute cannot be imported.
        """
        import importlib

        if class_name not in CLASS_MODULE_MAPPING:
            raise KeyError(f"No mapping found for: {class_name}")
        module_name, actual_class_name = CLASS_MODULE_MAPPING[class_name]
        try:
            module = importlib.import_module(module_name)
            return getattr(module, actual_class_name)
        except (ImportError, AttributeError) as e:
            raise ImportError(f"Unable to import {module_name}.{actual_class_name}: {e}") from e

    def _get_model_instance(self, model: list[MegatronModelT]) -> MegatronModelT:
        """Return ``model[0]`` with any chain of ``.module`` wrappers peeled off."""
        model_instance = model[0]
        while hasattr(model_instance, "module"):
            model_instance = model_instance.module
        return model_instance

    @staticmethod
    def _repr_field(label: str, value: object) -> list[str]:
        """Render ``repr(value)`` as ``(label): <first line>`` plus indented continuation lines."""
        value_lines = repr(value).splitlines()
        if not value_lines:
            return [f" ({label}): "]
        return [f" ({label}): {value_lines[0]}", *(f" {line}" for line in value_lines[1:])]

    def __repr__(self) -> str:
        """Multi-line repr showing the wrapped HF object and the resolved model bridge."""
        body = self._repr_field("hf_pretrained", self.hf_pretrained)
        body += self._repr_field("model_bridge", self._model_bridge)
        return f"{self.__class__.__name__}(\n" + "\n".join(body) + "\n)"