import re
from pathlib import Path
from typing import Any, Dict, Iterable, Optional, Union
def render_builtin_profile_draft(
    profile: Dict[str, Any],
    patch_method_name: Optional[str] = None,
    header: Optional[Iterable[str]] = None,
) -> str:
    """Render Python source that registers *profile* as a ``ModelProfile``.

    The generated module consists of the header lines, any imports needed for
    dotted-path callable values, the ``ModelProfile``/``register_model_profile``
    import, an optional placeholder patch function, and a single
    ``register_model_profile(ModelProfile(...))`` call.

    Args:
        profile: Keyword arguments for ``ModelProfile``. Entries whose value is
            ``None``, ``{}`` or ``[]`` are omitted; the rest are emitted sorted
            by key. Dotted-path strings for callable-valued keys are turned
            into imports (see ``_normalize_callable_args``).
        patch_method_name: If truthy, a placeholder function with this name
            (raising ``NotImplementedError``) is emitted and passed as
            ``patch_method``, overriding any ``patch_method`` in *profile*.
        header: Leading lines of the module. A falsy value (``None`` or an
            empty sequence) falls back to ``_default_header()``.

    Returns:
        The module source text, terminated by a newline.
    """
    args = dict(profile)
    if patch_method_name:
        # The generated placeholder replaces any profile-supplied patch_method;
        # drop it before import normalization so no unused import is emitted.
        args.pop("patch_method", None)
    imports, args = _normalize_callable_args(args)

    lines = list(header or _default_header())
    # Callable imports go right after the header, whatever its length
    # (previously a hard-coded index assumed a two-line header).
    if imports:
        lines.extend(imports)
        lines.append("")
    lines.extend(
        [
            "from tensor_cast.transformers.custom_model_registry import ModelProfile, register_model_profile",
            "",
            "",
        ]
    )

    if patch_method_name:
        lines.extend(
            [
                f"def {patch_method_name}(_model):",
                ' """Simulation-only runtime patch generated by adapter doctor.',
                "",
                " Replace this placeholder with the reviewed patch discovery output, then",
                " rerun doctor dry-run, smoke, and evidence verification.",
                ' """',
                ' raise NotImplementedError("Patch discovery output has not been implemented yet.")',
                "",
                "",
            ]
        )
        # Reference the generated function by name, not as a string literal.
        args["patch_method"] = _RawPython(patch_method_name)

    rendered_args = ",\n".join(
        f" {key}={_render_value(value)}"
        for key, value in sorted(args.items())
        if value is not None and value != {} and value != []
    )
    lines.extend(
        [
            "register_model_profile(",
            " ModelProfile(",
            rendered_args,
            " )",
            ")",
            "",
        ]
    )
    return "\n".join(lines)
def write_builtin_profile_draft(
    profile: Dict[str, Any],
    output_path: Union[str, Path],
    patch_method_name: Optional[str] = None,
    header: Optional[Iterable[str]] = None,
) -> Path:
    """Render *profile* with ``render_builtin_profile_draft`` and write it to disk.

    Missing parent directories of *output_path* are created. An existing file
    at *output_path* is overwritten. The file is written as UTF-8.

    Args:
        profile: Keyword arguments for ``ModelProfile``.
        output_path: Destination file path.
        patch_method_name: Forwarded to ``render_builtin_profile_draft``.
        header: Forwarded to ``render_builtin_profile_draft``; ``None`` keeps
            the default header.

    Returns:
        The destination as a ``Path``.
    """
    path = Path(output_path)
    # Render first so a rendering failure leaves no empty directories behind.
    source = render_builtin_profile_draft(
        profile,
        patch_method_name=patch_method_name,
        header=header,
    )
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(source, encoding="utf-8")
    return path
def default_builtin_profile_path(model_type: str) -> str:
    """Return the default builtin-profile module path for *model_type*.

    Each run of characters outside ``[0-9a-zA-Z_]`` becomes a single ``_``.
    Leading and trailing underscores are stripped, and the result is
    lower-cased.

    Raises:
        ValueError: If no usable characters remain (e.g. ``""`` or ``"--"``),
            which would otherwise yield the nameless file ``.py``.
    """
    safe = re.sub(r"[^0-9a-zA-Z_]+", "_", model_type).strip("_").lower()
    if not safe:
        raise ValueError(
            f"cannot derive a module name from model_type {model_type!r}"
        )
    return f"tensor_cast/transformers/builtin_model/{safe}.py"
class _RawPython(str):
pass
def _render_value(value: Any) -> str:
    """Return the Python source text used to emit *value*.

    ``_RawPython`` instances are written verbatim (as a plain ``str``); any
    other value is written as its ``repr``.
    """
    return str(value) if isinstance(value, _RawPython) else repr(value)
def _normalize_callable_args(args: Dict[str, Any]) -> tuple[list[str], Dict[str, Any]]:
callable_keys = {
"custom_expert_module_type",
"hf_config_patch_method",
"mla_module_class_type",
"patch_method",
}
imports = []
normalized = dict(args)
used_symbols = set()
for key in callable_keys:
value = normalized.get(key)
if not isinstance(value, str) or "." not in value:
continue
module_name, symbol = value.rsplit(".", maxsplit=1)
if not module_name or not symbol:
continue
alias = symbol
if alias in used_symbols:
alias = f"{symbol}_{key}"
imports.append(f"from {module_name} import {symbol} as {alias}")
else:
imports.append(f"from {module_name} import {symbol}")
used_symbols.add(alias)
normalized[key] = _RawPython(alias)
return sorted(imports), normalized
def _default_header() -> Iterable[str]:
return (
"# Generated by TensorCast adapter doctor.",
"# Review before enabling in production.",
)