"""
model_builder: build a TransformerModel and optionally wrap it with torch.compile.
"""
import logging
from .. import config
from ..compilation import get_backend
from ..core.config_resolver import ConfigResolver
from ..core.user_config import UserInputConfig
from ..transformers.custom_model_registry import get_visual
from ..transformers.model import TransformerModel
logger = logging.getLogger(__name__)
def _prepare_vl_compile(model: TransformerModel) -> bool:
    """
    Keep the visual encoder of a vision-language model out of compilation.

    A warning is always logged. When ``get_visual(model)`` returns an object
    that exposes ``forward``, that attribute is replaced by a wrapper which
    routes every call through a ``torch._dynamo.disable``-decorated function.

    :param model: model whose visual sub-module should be skipped
    :return: always ``False``; ``build_model`` passes it on as ``fullgraph``
    """
    logger.warning(
        "Skipping compile for visual encoder: wrap visual.forward with torch._dynamo.disable "
        "(small share ~20%, limited fusion benefit, current compile errors; introduces graph break)."
    )

    encoder = get_visual(model)
    if encoder is None or not hasattr(encoder, "forward"):
        # Nothing to wrap; the caller still gets fullgraph=False.
        return False

    import torch._dynamo

    # Captured before patching, so the wrapper calls the original forward.
    eager_forward = encoder.forward

    def _forward_outside_dynamo(*args, **kwargs):
        # The disabled function is re-created on each invocation, as before.
        @torch._dynamo.disable
        def _invoke(*inner_args, **inner_kwargs):
            return eager_forward(*inner_args, **inner_kwargs)

        return _invoke(*args, **kwargs)

    encoder.forward = _forward_outside_dynamo
    return False
def build_model(user_input: UserInputConfig = None) -> TransformerModel:
    """
    Build a transformer model from the user configuration, optionally compiled.

    :param user_input: user configuration. Required in practice: the ``None``
        default is kept only for signature compatibility and is rejected.
    :return: the loaded TransformerModel, or the object returned by
        ``torch.compile`` around it when ``user_input.do_compile`` is set.
    :raises ValueError: if ``user_input`` is ``None``.
    """
    if user_input is None:
        # Every step below dereferences user_input; fail early with a clear
        # message instead of an AttributeError on NoneType further down.
        raise ValueError("build_model requires a UserInputConfig, got None")

    config_resolver = ConfigResolver(user_input=user_input)
    model_config = config_resolver.resolve()
    model = TransformerModel(user_input.model_id, model_config)

    use_full_graph = not user_input.allow_graph_break
    if not user_input.do_compile:
        return model

    if getattr(model, "is_vl_model", False):
        # VL models: the helper patches the visual encoder's forward and
        # returns the fullgraph flag to use (it overrides allow_graph_break).
        use_full_graph = _prepare_vl_compile(model)

    # Imported lazily so torch is only loaded when compilation is requested.
    import torch

    config.compilation.multistream.enable = bool(user_input.enable_multistream)
    return torch.compile(
        model,
        backend=get_backend(device_name=user_input.device),
        dynamic=False,
        fullgraph=use_full_graph,
    )