import dataclasses
import re
from typing import Any, Dict, List, Optional

from .ai_task import AiAssistanceTask


@dataclasses.dataclass(frozen=True)
class PatchDiscoveryFinding:
    category: str
    message: str
    confidence: str
    evidence: str
    suggested_action: str

    def to_dict(self) -> Dict[str, Any]:
        return dataclasses.asdict(self)


@dataclasses.dataclass(frozen=True)
class PatchDiscoveryReport:
    model_type: Optional[str]
    suggested_patch_method_name: Optional[str]
    findings: List[PatchDiscoveryFinding]
    prompt_template: str
    ai_tasks: List[AiAssistanceTask]

    @property
    def requires_patch(self) -> bool:
        return bool(self.findings)

    def to_dict(self) -> Dict[str, Any]:
        return {
            "model_type": self.model_type,
            "suggested_patch_method_name": self.suggested_patch_method_name,
            "requires_patch": self.requires_patch,
            "findings": [finding.to_dict() for finding in self.findings],
            "prompt_template": self.prompt_template,
            "ai_tasks": [task.to_dict() for task in self.ai_tasks],
        }


def classify_patch_failure(
    failure_text: str,
    model_type: Optional[str] = None,
    failed_command: Optional[str] = None,
) -> PatchDiscoveryReport:
    text = failure_text or ""
    lowered = text.lower()
    findings: List[PatchDiscoveryFinding] = []

    def add(category: str, message: str, confidence: str, evidence: str, action: str) -> None:
        if any(item.category == category and item.evidence == evidence for item in findings):
            return
        findings.append(
            PatchDiscoveryFinding(
                category=category,
                message=message,
                confidence=confidence,
                evidence=evidence,
                suggested_action=action,
            )
        )

    if "get_placeholder_mask" in text or "placeholder" in lowered and "image" in lowered:
        add(
            "PLACEHOLDER_STRICT_CHECK",
            "A strict multimodal placeholder/token validation path appears in the failure.",
            "high",
            _evidence_snippet(text, "get_placeholder_mask") or _evidence_snippet(text, "placeholder"),
            "Patch the simulation path to skip value-dependent placeholder validation while preserving tensor shapes.",
        )
    if "nonzero" in lowered or "boolean mask" in lowered or re.search(r"\[[^\]]*mask[^\]]*\]", lowered):
        add(
            "DYNAMIC_SHAPE_OP",
            "A data-dependent boolean mask or nonzero path appears in the failure.",
            "high",
            _evidence_snippet(text, "nonzero") or _evidence_snippet(text, "mask"),
            "Replace the meta-mode path with a shape-stable branch or bypass the value-dependent indexing.",
        )
    if ".item()" in lowered or "tensor.item" in lowered or "cannot be converted to scalar" in lowered:
        add(
            "META_TENSOR_VALUE_READ",
            "The failure suggests a tensor value read that is unsafe in meta mode.",
            "medium",
            _evidence_snippet(text, "item"),
            "Move the branch to shape/config metadata or guard it behind a simulation-safe path.",
        )
    if "graph break" in lowered or "torch.compile" in lowered or "dynamo" in lowered:
        add(
            "COMPILE_GRAPH_BREAK",
            "The failure mentions compile or Dynamo graph break behavior.",
            "medium",
            _evidence_snippet(text, "graph break") or _evidence_snippet(text, "dynamo"),
            "Patch Python control flow so compile mode sees a stable graph.",
        )
    if "unexpected keyword" in lowered or "positional argument" in lowered or "signature" in lowered:
        add(
            "SIGNATURE_MISMATCH",
            "The failure suggests wrapper and source method signatures diverge.",
            "medium",
            _evidence_snippet(text, "unexpected keyword") or _evidence_snippet(text, "signature"),
            "Filter unsupported kwargs or mirror the installed transformers method signature.",
        )
    if "unsupported" in lowered and ("op" in lowered or "operator" in lowered):
        add(
            "UNSUPPORTED_OP_ROUTING",
            "The failure mentions an unsupported operator path.",
            "medium",
            _evidence_snippet(text, "unsupported"),
            "Route the model source path to an existing TensorCast op or add explicit unsupported-semantics work.",
        )

    method_name = _suggest_patch_method_name(model_type) if findings else None
    suspected_locations = _extract_traceback_locations(text)
    prompt_template = build_patch_discovery_prompt(
        failure_text=text,
        model_type=model_type,
        failed_command=failed_command,
        suggested_patch_method_name=method_name,
        findings=findings,
        suspected_locations=suspected_locations,
    )
    ai_tasks = []
    if findings:
        ai_tasks.append(
            _build_patch_authoring_task(
                failure_text=text,
                model_type=model_type,
                failed_command=failed_command,
                suggested_patch_method_name=method_name,
                findings=findings,
                suspected_locations=suspected_locations,
                prompt_text=prompt_template,
            )
        )
    return PatchDiscoveryReport(
        model_type=model_type,
        suggested_patch_method_name=method_name,
        findings=findings,
        prompt_template=prompt_template,
        ai_tasks=ai_tasks,
    )


def build_patch_discovery_prompt(
    failure_text: str,
    model_type: Optional[str],
    failed_command: Optional[str],
    suggested_patch_method_name: Optional[str],
    findings: List[PatchDiscoveryFinding],
    suspected_locations: Optional[List[Dict[str, Any]]] = None,
) -> str:
    finding_lines = "\n".join(
        f"- {item.category}: {item.message} Suggested action: {item.suggested_action}" for item in findings
    )
    location_lines = "\n".join(_render_location(item) for item in suspected_locations or [])
    method_name = suggested_patch_method_name or "patch_method_for_<model_type>"
    return (
        "You are adapting a TensorCast built-in model profile.\n"
        "Author a patch_method draft only from the stacktrace, installed transformers source, "
        "and the simulation goal. Do not rely on any existing built-in profile for the same model. "
        "Do not assume TensorCast doctor has generated correct patch code; doctor only produced "
        "deterministic evidence and constraints.\n\n"
        f"model_type: {model_type or '<unknown>'}\n"
        f"failed_command: {failed_command or '<not provided>'}\n"
        f"suggested_patch_method_name: {method_name}\n\n"
        "Findings:\n"
        f"{finding_lines or '- No deterministic patch category was recognized.'}\n\n"
        "Suspected traceback locations:\n"
        f"{location_lines or '- No traceback frame was parsed. Inspect the full failure text.'}\n\n"
        "Constraints:\n"
        "- Patch only TensorCast simulation compatibility, not real model semantics.\n"
        "- Preserve tensor shapes, module outputs, and downstream call signatures required by TensorCast.\n"
        "- Explain any real-model checks intentionally bypassed in simulation mode.\n"
        "- Keep the patch scoped to the built-in model adapter and register it through ModelProfile.patch_method.\n"
        "- Do not copy an existing built-in profile for the same model as the answer.\n\n"
        "Required output:\n"
        "- class and method names to patch\n"
        "- original failure reason\n"
        "- simulation semantics preserved by the patch\n"
        "- real-model semantics intentionally bypassed, if any\n"
        "- code diff for the built-in model adapter\n"
        "- verification commands: doctor dry-run, smoke, evidence verifier\n\n"
        "Failure text:\n"
        f"{failure_text.strip()}\n"
    )


def _suggest_patch_method_name(model_type: Optional[str]) -> Optional[str]:
    if not model_type:
        return None
    safe = re.sub(r"[^0-9a-zA-Z_]+", "_", model_type).strip("_").lower()
    return f"patch_method_for_{safe}" if safe else None


def _build_patch_authoring_task(
    failure_text: str,
    model_type: Optional[str],
    failed_command: Optional[str],
    suggested_patch_method_name: Optional[str],
    findings: List[PatchDiscoveryFinding],
    suspected_locations: List[Dict[str, Any]],
    prompt_text: str,
) -> AiAssistanceTask:
    return AiAssistanceTask(
        task_type="PATCH_METHOD_AUTHORING",
        title="Author TensorCast model adapter patch_method",
        summary=(
            "A runtime failure suggests the installed model source needs a TensorCast "
            "simulation-only patch_method. Doctor produced deterministic evidence and "
            "a prompt for an AI assistant; it did not generate patch code."
        ),
        model_type=model_type,
        evidence={
            "failed_command": failed_command,
            "failure_text": failure_text.strip(),
            "findings": [finding.to_dict() for finding in findings],
            "suggested_patch_method_name": suggested_patch_method_name,
        },
        suspected_locations=suspected_locations,
        constraints=[
            "Use installed transformers source and the failure stacktrace as the source of truth.",
            "Patch only TensorCast simulation compatibility paths.",
            "Preserve tensor shapes, output structure, and downstream call signatures.",
            "Document real-model checks or value-dependent paths intentionally bypassed in simulation.",
            "Register the reviewed patch through ModelProfile.patch_method.",
            "Do not copy an existing built-in profile for the same model as the answer.",
        ],
        required_output=[
            "Class and method names to patch.",
            "Original failure reason.",
            "Patch method code diff for the built-in model adapter.",
            "Simulation semantics preserved by the patch.",
            "Real-model semantics intentionally bypassed, if any.",
            "Verification commands to rerun.",
        ],
        verification_commands=[
            "python -m cli.inference.model_adapter doctor --from-command-file <command.txt> "
            "--patch-failure-file <failure.log>",
            "python -m cli.inference.text_generate <model_id> <original simulation options>",
            "python -m cli.inference.model_adapter verify --evidence-file <evidence.yaml>",
        ],
        prompt_text=prompt_text,
    )


def _extract_traceback_locations(text: str) -> List[Dict[str, Any]]:
    locations = []
    seen = set()
    frame_pattern = re.compile(r'File "(?P<file>[^"]+)", line (?P<line>\d+), in (?P<function>[^\n]+)')
    for match in frame_pattern.finditer(text or ""):
        location = {
            "file": match.group("file"),
            "line": int(match.group("line")),
            "function": match.group("function").strip(),
        }
        key = (location["file"], location["line"], location["function"])
        if key in seen:
            continue
        seen.add(key)
        locations.append(location)
    return locations


def _render_location(location: Dict[str, Any]) -> str:
    file_name = location.get("file", "<unknown>")
    line = location.get("line", "<unknown>")
    function = location.get("function", "<unknown>")
    return f"- {file_name}:{line} in {function}"


def _evidence_snippet(text: str, keyword: str, radius: int = 120) -> str:
    if not text or not keyword:
        return ""
    index = text.lower().find(keyword.lower())
    if index < 0:
        return ""
    start = max(0, index - radius)
    end = min(len(text), index + len(keyword) + radius)
    return text[start:end].strip()