import re
from pathlib import Path
from typing import Any, Dict, Iterable, Optional, Union


def render_builtin_profile_draft(
    profile: Dict[str, Any],
    patch_method_name: Optional[str] = None,
    header: Optional[Iterable[str]] = None,
) -> str:
    """Render a Python module that registers *profile* as a built-in model profile.

    The generated module contains, in order: the header lines, import lines
    for dotted-path callable arguments (see ``_normalize_callable_args``), the
    ``ModelProfile``/``register_model_profile`` import, an optional placeholder
    patch function, and a ``register_model_profile(ModelProfile(...))`` call
    whose keyword arguments are the profile entries sorted by key. Entries
    whose value is ``None``, ``{}`` or ``[]`` are omitted.

    Args:
        profile: Keyword arguments for ``ModelProfile``. Not mutated.
        patch_method_name: If given, a placeholder function with this name is
            emitted and passed as ``patch_method``. It overrides any
            ``patch_method`` entry in *profile*.
        header: Leading lines, normally comments. When falsy (``None`` or an
            empty sequence) the default doctor header is used.

    Returns:
        The module source text, terminated by a newline.

    Raises:
        ValueError: if *patch_method_name* is not a valid Python identifier.
    """
    import keyword

    # The name is interpolated into a ``def`` statement; reject anything that
    # would make the generated module unparseable.
    if patch_method_name and (
        not patch_method_name.isidentifier() or keyword.iskeyword(patch_method_name)
    ):
        raise ValueError(
            f"patch_method_name must be a valid Python identifier, got {patch_method_name!r}"
        )

    args = dict(profile)
    if patch_method_name:
        # The placeholder replaces any configured patch_method. Drop it before
        # normalization so no unused import is generated for the old value.
        args.pop("patch_method", None)
    imports, args = _normalize_callable_args(args)

    lines = list(header or _default_header())
    # Imports go directly after the header, whatever its length. A fixed
    # insertion index only matched the two-line default header.
    if imports:
        lines.extend(imports)
        lines.append("")
    lines.extend(
        [
            "from tensor_cast.transformers.custom_model_registry import ModelProfile, register_model_profile",
            "",
            "",
        ]
    )

    if patch_method_name:
        lines.extend(
            [
                f"def {patch_method_name}(_model):",
                '    """Simulation-only runtime patch generated by adapter doctor.',
                "",
                "    Replace this placeholder with the reviewed patch discovery output, then",
                "    rerun doctor dry-run, smoke, and evidence verification.",
                '    """',
                '    raise NotImplementedError("Patch discovery output has not been implemented yet.")',
                "",
                "",
            ]
        )
        # Reference the generated function by name, unquoted.
        args["patch_method"] = _RawPython(patch_method_name)

    rendered_args = [
        f"        {key}={_render_value(value)}"
        for key, value in sorted(args.items())
        if value is not None and value != {} and value != []
    ]
    lines.append("register_model_profile(")
    if rendered_args:
        lines.extend(["    ModelProfile(", ",\n".join(rendered_args), "    )"])
    else:
        # Avoid emitting an empty line inside the call when nothing survives.
        lines.append("    ModelProfile()")
    lines.extend([")", ""])
    return "\n".join(lines)


def write_builtin_profile_draft(
    profile: Dict[str, Any],
    output_path: Union[str, Path],
    patch_method_name: Optional[str] = None,
    header: Optional[Iterable[str]] = None,
) -> Path:
    """Render a built-in profile draft and write it to *output_path* as UTF-8.

    Missing parent directories are created.

    Args:
        profile: Keyword arguments for ``ModelProfile``.
        output_path: Destination file path.
        patch_method_name: Forwarded to ``render_builtin_profile_draft``.
        header: Forwarded to ``render_builtin_profile_draft``. ``None`` keeps
            the default header, as before this parameter existed.

    Returns:
        The destination as a ``Path``.
    """
    path = Path(output_path)
    # Render first so a rendering error does not leave freshly created
    # directories behind.
    content = render_builtin_profile_draft(
        profile,
        patch_method_name=patch_method_name,
        header=header,
    )
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(content, encoding="utf-8")
    return path


def default_builtin_profile_path(model_type: str) -> str:
    """Return the repository-relative module path for *model_type*'s built-in profile.

    The path is derived from *model_type* in three steps:

    - Each run of characters other than ASCII letters, digits and underscores
      becomes a single underscore.
    - Leading and trailing underscores are stripped.
    - The result is lower-cased.

    For example, ``"Qwen2-MoE"`` maps to
    ``"tensor_cast/transformers/builtin_model/qwen2_moe.py"``.

    Raises:
        ValueError: if nothing usable remains after sanitizing. Otherwise the
            result would be the hidden, unimportable file ``".py"``.
    """
    safe = re.sub(r"[^0-9a-zA-Z_]+", "_", model_type).strip("_").lower()
    if not safe:
        raise ValueError(f"Cannot derive a module name from model_type {model_type!r}")
    return f"tensor_cast/transformers/builtin_model/{safe}.py"


class _RawPython(str):
    pass


def _render_value(value: Any) -> str:
    """Render *value* as Python source text for a ``ModelProfile`` keyword argument.

    ``_RawPython`` instances are emitted as-is, e.g. imported symbol names.
    Every other value is rendered with ``repr``.
    """
    return str(value) if isinstance(value, _RawPython) else repr(value)


def _normalize_callable_args(args: Dict[str, Any]) -> tuple[list[str], Dict[str, Any]]:
    """Turn dotted-path callable arguments into imports plus bare symbol references.

    Each known callable-valued key may hold a string such as
    ``"pkg.module.Symbol"``. For such a key, an import line is generated and
    the value is replaced with a ``_RawPython`` name so it renders unquoted.
    If two different paths end in the same symbol name, the later one is
    imported under the alias ``f"{symbol}_{key}"``. The same dotted path used
    by several keys is imported only once.

    Args:
        args: Profile keyword arguments. Not mutated.

    Returns:
        ``(imports, normalized)``: the sorted import lines, and a copy of
        *args* with dotted callable paths replaced.
    """
    # A fixed order, not a set: iterating a set of str depends on hash
    # randomization. That made the choice of which key receives the plain name
    # versus the suffixed alias differ between interpreter runs.
    callable_keys = (
        "custom_expert_module_type",
        "hf_config_patch_method",
        "mla_module_class_type",
        "patch_method",
    )
    imports = []
    normalized = dict(args)
    used_symbols = set()
    alias_by_path = {}  # dotted path -> local name already imported for it
    for key in callable_keys:
        value = normalized.get(key)
        if not isinstance(value, str) or "." not in value:
            continue
        module_name, symbol = value.rsplit(".", maxsplit=1)
        if not module_name or not symbol:
            continue
        alias = alias_by_path.get(value)
        if alias is None:
            alias = symbol
            if alias in used_symbols:
                alias = f"{symbol}_{key}"
                imports.append(f"from {module_name} import {symbol} as {alias}")
            else:
                imports.append(f"from {module_name} import {symbol}")
            used_symbols.add(alias)
            alias_by_path[value] = alias
        normalized[key] = _RawPython(alias)
    return sorted(imports), normalized


def _default_header() -> Iterable[str]:
    return (
        "# Generated by TensorCast adapter doctor.",
        "# Review before enabling in production.",
    )