import importlib
from mindspeed_mm.models.transformers.custom_model_registry import get_model_class_global
class ModelHub:
@staticmethod
def build(config, transformer_config):
"""
Constructs and returns the corresponding model class based on the provided configuration.
This method determines the model class in the following priority order:
1. First checks the `model_id` field in the `config` object (typically from the "model_id"
field in model.json). If present, it looks up the corresponding model class in
ModelHub.MODEL_MAPPINGS.
2. If `model_id` is not available, it attempts to use the `architectures` field from
`transformer_config` (usually from the config.json file in the Hugging Face model
directory) to dynamically load the model class from the transformers library.
3. If `architectures` is not available, it falls back to the `model_type` field in
`transformer_config`, and looks up the corresponding model class in
ModelHub.MODEL_MAPPINGS.
Args:
config (object): Configuration object expected to contain a `model_id` attribute.
transformer_config (transformers.PretrainedConfig): Model configuration object
from the transformers library, containing fields such as `architectures`
and `model_type`.
Returns:
type: The corresponding model class if found.
Raises:
ValueError: If no model class can be determined after all lookup attempts.
Example:
model_class = ModelBuilder.build(config, transformer_config)
model = model_class()
"""
architectures = getattr(transformer_config, "architectures", [])
model_type = getattr(transformer_config, "model_type", None)
model_cls = None
model_id = getattr(config, "model_id", None)
if model_id:
model_cls = get_model_class_global(model_id)
return model_cls
if architectures:
transformers_module = importlib.import_module("transformers")
model_cls = getattr(transformers_module, architectures[0], None)
if model_cls is not None:
return model_cls
if model_type:
model_cls = get_model_class_global(model_type)
if model_cls is not None:
return model_cls
if model_cls is None:
raise ValueError("load model from config failed")
return model_cls