"""Single source-of-truth cache for deterministic test construction artifacts.

CACHE POLICY:

Only deterministic, pure-construction artifacts are cached:

* Hugging Face configs from ``AutoModelConfigLoader.load_config``.
* Freshly built ``build_model(...)`` results.

Entries are keyed by their full determining inputs (``model_id`` for configs,
``user_config_build_cache_key`` for models).

Configs are handed out as deepcopies (cheap, mutable dicts), so a caller can
never mutate shared config state. Built models are handed out SHARED, not
copied: freshly built models hold non-leaf / meta tensors that do not support
``deepcopy``, and tests treat the built model as read-only (forward runs
produce separate runtime/event objects). Callers must NOT mutate a returned
model.

DO NOT cache objects that have run forward/compile, anything a test mutates,
or per-test scratch data.

Both pytest fixtures (``cfg_registry``, ``get_session_*`` and friends) and
unittest ``TestCase`` code paths (which cannot consume fixtures) delegate here,
so each cache entry is built at most once per session regardless of entry point.
"""
from __future__ import annotations
import copy
from tensor_cast.core.model_builder import build_model
from tensor_cast.core.user_config import UserInputConfig
from tensor_cast.transformers.model import TransformerModel
from tensor_cast.transformers.utils import AutoModelConfigLoader
_HF_CONFIG_CACHE: dict[str, object] = {}
_BUILT_MODEL_CACHE: dict[tuple, TransformerModel] = {}
def user_config_build_cache_key(user_config: UserInputConfig) -> tuple:
    """Return the tuple used to key cached ``build_model`` results.

    The tuple holds, in a fixed order, the value of every ``UserInputConfig``
    field that affects ``ConfigResolver.resolve()`` / ``build_model()``. Fields
    not listed here are ignored, so two configs that differ only in such
    fields produce equal keys.

    NOTE: if a new field that influences model construction is added to
    ``UserInputConfig``, it must be appended here. Otherwise, a cached model
    built for different settings would be returned for the new configuration.
    """
    build_relevant_fields = (
        "model_id",
        "do_compile",
        "num_mtp_tokens",
        "num_hidden_layers_override",
        "quantize_linear_action",
        "quantize_attention_action",
        "remote_source",
        "allow_graph_break",
        "enable_multistream",
        "world_size",
        "tp_size",
        "mlp_tp_size",
        "lmhead_tp_size",
        "vision_tp_size",
        "ep_size",
        "moe_dp_size",
        "moe_tp_size",
        "enable_redundant_experts",
        "enable_external_shared_experts",
        "enable_shared_expert_tp",
        "host_external_shared_experts",
        "disable_repetition",
    )
    return tuple(getattr(user_config, field_name) for field_name in build_relevant_fields)
def get_hf_config(model_id: str):
    """Return a private copy of the Hugging Face config for ``model_id``.

    On the first request for a given ``model_id`` the config is loaded through
    ``AutoModelConfigLoader().load_config`` and memoized in
    ``_HF_CONFIG_CACHE``. Every call, including the first, returns a
    ``copy.deepcopy`` of the memoized object, so callers may mutate the result
    freely without affecting other callers.
    """
    if model_id in _HF_CONFIG_CACHE:
        return copy.deepcopy(_HF_CONFIG_CACHE[model_id])
    freshly_loaded = AutoModelConfigLoader().load_config(model_id)
    _HF_CONFIG_CACHE[model_id] = freshly_loaded
    # Hand out a copy even on a miss so the cached instance is never exposed.
    return copy.deepcopy(freshly_loaded)
def get_built_model(user_config: UserInputConfig) -> TransformerModel:
    """Return the session-cached ``build_model`` result for ``user_config``.

    The model is built once per distinct ``user_config_build_cache_key`` value
    and then reused. The returned object is SHARED, not deepcopied: built
    models contain non-leaf / meta tensors that do not support ``deepcopy``,
    and tests treat the model as read-only. Callers must NOT mutate the
    returned model.
    """
    cache_key = user_config_build_cache_key(user_config)
    # A private sentinel distinguishes "absent" from any value build_model
    # might return, so a cached entry is never rebuilt.
    absent = object()
    model = _BUILT_MODEL_CACHE.get(cache_key, absent)
    if model is absent:
        model = build_model(user_config)
        _BUILT_MODEL_CACHE[cache_key] = model
    return model