import dataclasses
from typing import Dict, List, Optional

from .actual import ActualSummary
from .evidence import EvidenceCase, ExpectedOp


@dataclasses.dataclass(frozen=True)
class VerificationIssue:
    """A single finding produced while verifying an evidence case.

    Instances are immutable; ``to_dict`` gives a plain-dict view suitable for
    serialisation (values are returned as-is, not copied).
    """

    # Machine-readable classification, e.g. "OP_COUNT_MISMATCH".
    category: str
    # Human-readable description of the finding.
    message: str
    # "error" findings make a report fail; "warning" findings do not.
    severity: str = "error"
    # Value declared by the evidence, if any.
    expected: Optional[object] = None
    # Value observed in the actual summary, if any.
    actual: Optional[object] = None
    # Dotted/indexed path into the evidence that the finding refers to.
    evidence_path: str = ""

    def to_dict(self) -> Dict[str, object]:
        """Return the issue's fields as a dict, keyed in declaration order."""
        exported_fields = (
            "category",
            "message",
            "severity",
            "expected",
            "actual",
            "evidence_path",
        )
        return {field_name: getattr(self, field_name) for field_name in exported_fields}


@dataclasses.dataclass(frozen=True)
class VerificationReport:
    """Aggregated outcome of verifying one evidence case."""

    # Name of the evidence case that was verified.
    case_name: str
    # Overall verdict for the case.
    passed: bool
    # All findings, in the order they were produced.
    issues: List[VerificationIssue]

    def issues_by_category(self) -> Dict[str, List[VerificationIssue]]:
        """Group issues by ``category``.

        Categories appear in first-seen order and each group keeps the
        original relative order of its issues.
        """
        buckets: Dict[str, List[VerificationIssue]] = {}
        for finding in self.issues:
            bucket = buckets.get(finding.category)
            if bucket is None:
                bucket = []
                buckets[finding.category] = bucket
            bucket.append(finding)
        return buckets

    def to_dict(self) -> Dict[str, object]:
        """Return a plain-dict view of the report with each issue converted via ``to_dict``."""
        serialised_issues = [finding.to_dict() for finding in self.issues]
        return {
            "case_name": self.case_name,
            "passed": self.passed,
            "issues": serialised_issues,
        }


def _format_expected_count(expected_op: ExpectedOp) -> object:
    """Describe the expected call count of *expected_op* for issue reporting.

    Returns the exact ``count`` when one is declared; otherwise a
    ``{"min": ..., "max": ...}`` dict built from the declared bounds.
    """
    exact_count = expected_op.count
    if exact_count is None:
        return {"min": expected_op.count_min, "max": expected_op.count_max}
    return exact_count


def _severity_for_expected_op(expected_op: ExpectedOp) -> str:
    """Choose the severity for issues raised against *expected_op*.

    Ops whose ``confidence`` is "low" or "medium" (case-insensitive) only
    produce warnings; every other confidence value produces errors.
    """
    confidence_level = expected_op.confidence.lower()
    if confidence_level == "low" or confidence_level == "medium":
        return "warning"
    return "error"


def _accepted_gap_matches(evidence_case: EvidenceCase, op_name: str) -> bool:
    """Return True if *op_name* is covered by one of the case's accepted gaps.

    A gap covers an op when either string contains the other. Empty (or
    otherwise falsy) gap entries are ignored: ``"" in op_name`` is always
    true, so a blank entry would otherwise mark every op as an accepted gap
    and silently suppress all undeclared-op warnings.
    """
    return any(
        bool(gap) and (op_name in gap or gap in op_name)
        for gap in evidence_case.accepted_gaps
    )


def _is_tensor_cast_op(op_name: str) -> bool:
    return op_name.startswith("tensor_cast.")


def _is_communication_op(op_name: str) -> bool:
    lowered = op_name.lower()
    return any(
        token in lowered
        for token in (
            "allreduce",
            "all_reduce",
            "allgather",
            "all_gather",
            "alltoall",
            "all_to_all",
            "broadcast",
            "reduce_scatter",
            "hcom",
            "hccl",
            "collective",
        )
    )


def _missing_expected_category(expected_op: ExpectedOp, actual: ActualSummary) -> str:
    """Classify why an expected major op is absent from the actual summary.

    Checks run in priority order:
    1. names starting with "profiling." -> FUSION_GAP_ACCEPTED_OR_NEEDS_REVIEW
    2. communication ops                 -> COMMUNICATION_GAP
    3. tensor_cast ops, when the actual summary has no tensor_cast op at all
                                         -> PATCH_SEMANTICS_MISSING
    4. anything else                     -> OP_MAPPING_MISSING
    """
    wanted = expected_op.name
    if wanted.startswith("profiling."):
        return "FUSION_GAP_ACCEPTED_OR_NEEDS_REVIEW"
    if _is_communication_op(wanted):
        return "COMMUNICATION_GAP"
    if _is_tensor_cast_op(wanted):
        # Only scan the actual ops when the expected op is itself a cast.
        any_cast_present = any(_is_tensor_cast_op(present) for present in actual.ops)
        if not any_cast_present:
            return "PATCH_SEMANTICS_MISSING"
    return "OP_MAPPING_MISSING"


def _coverage_issues(actual: ActualSummary) -> List[VerificationIssue]:
    """Emit a warning for each model whose m1 raw-op-count hit rate is below 1.0.

    Each entry of ``actual.coverage`` may nest its metrics under an "m1" key
    or hold them directly. Entries (or "m1" values) that are not dicts, and
    entries without an ``m1_raw_op_count_hr`` value, are skipped.
    """
    warnings: List[VerificationIssue] = []
    for model_name, coverage in actual.coverage.items():
        if not isinstance(coverage, dict):
            continue
        metrics = coverage.get("m1") or coverage
        if not isinstance(metrics, dict):
            continue
        hit_rate = metrics.get("m1_raw_op_count_hr")
        # Written as `not < 1.0` rather than `>= 1.0` so NaN is still skipped.
        if hit_rate is None or not hit_rate < 1.0:
            continue
        warnings.append(
            VerificationIssue(
                category="PROFILING_SHAPE_MISSING",
                message="Profiling coverage is incomplete for empirical performance model.",
                severity="warning",
                expected={"m1_raw_op_count_hr": 1.0},
                actual={"model": model_name, "m1_raw_op_count_hr": hit_rate},
                evidence_path="actual.coverage",
            )
        )
    return warnings


def _total_forward_issue(
    evidence_case: EvidenceCase,
    actual: ActualSummary,
) -> Optional[VerificationIssue]:
    """Return an error issue when the total forward time falls outside tolerance, else None."""
    declared_total = evidence_case.total_forward
    if declared_total is None or declared_total.matches(actual.total_forward_time_s):
        return None
    return VerificationIssue(
        category="LATENCY_MODEL_MISMATCH",
        message="Total forward time is outside tolerance.",
        expected=declared_total.time_s,
        actual=actual.total_forward_time_s,
        evidence_path=f"cases[{evidence_case.name}].expected.total_forward.time_s",
    )


def _major_op_issues(
    evidence_case: EvidenceCase,
    actual: ActualSummary,
) -> List[VerificationIssue]:
    """Check each declared major op for presence, call count and total time."""
    findings: List[VerificationIssue] = []
    for position, expected_op in enumerate(evidence_case.major_ops):
        observed = actual.get_op(expected_op.name)
        op_path = f"cases[{evidence_case.name}].expected.major_ops[{position}]"

        if observed is None:
            missing_category = _missing_expected_category(expected_op, actual)
            findings.append(
                VerificationIssue(
                    category=missing_category,
                    message=f"Expected major op {expected_op.name!r} is missing from actual summary.",
                    severity=_severity_for_expected_op(expected_op),
                    expected=expected_op.name,
                    actual=None,
                    evidence_path=f"{op_path}.name",
                )
            )
            continue

        if not expected_op.count_matches(observed.count):
            findings.append(
                VerificationIssue(
                    category="OP_COUNT_MISMATCH",
                    message=f"Op {expected_op.name!r} call count is outside expectation.",
                    severity=_severity_for_expected_op(expected_op),
                    expected=_format_expected_count(expected_op),
                    actual=observed.count,
                    evidence_path=f"{op_path}.count",
                )
            )

        declared_time = expected_op.total_time
        if declared_time is not None and not declared_time.matches(observed.total_time_s):
            findings.append(
                VerificationIssue(
                    category="LATENCY_MODEL_MISMATCH",
                    message=f"Op {expected_op.name!r} total time is outside tolerance.",
                    severity=_severity_for_expected_op(expected_op),
                    expected=declared_time.time_s,
                    actual=observed.total_time_s,
                    evidence_path=f"{op_path}.total_time_s",
                )
            )
    return findings


def _undeclared_high_time_op_issues(
    evidence_case: EvidenceCase,
    actual: ActualSummary,
    declared_names: set,
    extra_op_time_ratio: float,
    extra_op_min_time_s: float,
) -> List[VerificationIssue]:
    """Warn about expensive actual ops that the evidence neither declares nor accepts as gaps."""
    threshold = max(actual.total_forward_time_s * extra_op_time_ratio, extra_op_min_time_s)
    # `not > 0` (rather than `<= 0`) keeps a NaN threshold on the skip path.
    if not threshold > 0:
        return []
    findings: List[VerificationIssue] = []
    for op in actual.high_time_ops(threshold):
        if op.name in declared_names or _accepted_gap_matches(evidence_case, op.name):
            continue
        if _is_communication_op(op.name):
            gap_category = "COMMUNICATION_GAP"
        else:
            gap_category = "FUSION_GAP_ACCEPTED_OR_NEEDS_REVIEW"
        findings.append(
            VerificationIssue(
                category=gap_category,
                message=f"Actual high-time op {op.name!r} is not declared in evidence.",
                severity="warning",
                expected=None,
                actual={"count": op.count, "total_time_s": op.total_time_s},
                evidence_path=f"cases[{evidence_case.name}].expected.major_ops",
            )
        )
    return findings


def verify_evidence_case(
    evidence_case: EvidenceCase,
    actual: ActualSummary,
    extra_op_time_ratio: float = 0.05,
    extra_op_min_time_s: float = 0.0,
) -> VerificationReport:
    """Verify an actual summary against one evidence case.

    Issues are collected in this order: total forward time, declared major
    ops (presence, count, total time), undeclared high-time ops, then
    profiling-coverage warnings.

    Args:
        evidence_case: The declared expectations for the case.
        actual: The observed summary to check.
        extra_op_time_ratio: Fraction of the actual total forward time above
            which an undeclared op is reported.
        extra_op_min_time_s: Absolute lower bound for that threshold; the
            threshold is ``max(total * ratio, min_time)`` and the undeclared-op
            check is skipped when it is not positive.

    Returns:
        A ``VerificationReport`` whose ``passed`` is True only if no issue
        has severity "error".
    """
    issues: List[VerificationIssue] = []

    forward_issue = _total_forward_issue(evidence_case, actual)
    if forward_issue is not None:
        issues.append(forward_issue)

    declared_names = {op.name for op in evidence_case.major_ops}
    issues.extend(_major_op_issues(evidence_case, actual))
    issues.extend(
        _undeclared_high_time_op_issues(
            evidence_case,
            actual,
            declared_names,
            extra_op_time_ratio,
            extra_op_min_time_s,
        )
    )
    issues.extend(_coverage_issues(actual))

    has_blocking_issue = any(issue.severity == "error" for issue in issues)
    return VerificationReport(
        case_name=evidence_case.name,
        passed=not has_blocking_issue,
        issues=issues,
    )