"""
Common evaluation and reward scoring for QMD query expansion models.

Shared by sft.py and grpo.py for post-training evaluation.
"""

import csv
import io
import re
from collections import Counter

import torch
from huggingface_hub import HfApi

# =============================================================================
# Reward function (single source of truth)
# =============================================================================

# Function words skipped when counting repeated words in hyde text
# (word_repetition_penalty).
STOPWORDS = frozenset(
    "the a an is are to for of in and or it this that be with as on by".split()
)

# Question/filler words that do not count as key terms or entity
# continuations (get_key_terms, extract_named_entities).
KEY_TERM_STOPWORDS = frozenset(
    "what is how to the a an in on for of and or with my your do does can "
    "i me we who where when why which find get show tell".split()
)

# Filler phrases that, with little else on the line, make a lex line
# "generic" (lex_is_generic).
GENERIC_LEX_PHRASES = frozenset((
    'find information about',
    'search for',
    'look up',
    'get information',
    'learn about',
    'information on',
    'details about',
    'find out about',
    'what is',
    'how to',
    'guide to',
    'help with',
))

# Chat-template markup; any occurrence in a cleaned completion makes
# score_expansion return 0.0.
CHAT_TEMPLATE_TOKENS = frozenset((
    '<|im_start|>',
    '<|im_end|>',
    '<|endoftext|>',
    '\nassistant\n',
    '\nuser\n',
))


def parse_expansion(text):
    """Split expansion text into lex / vec / hyde lines.

    Each non-blank line is stripped and filed under the first matching
    prefix (``lex:``, ``vec:``, ``hyde:``) with the prefix and surrounding
    whitespace removed. Lines matching no prefix go to ``"invalid"``
    unchanged (apart from stripping).

    Returns:
        dict with keys "lex", "vec", "hyde", "invalid", each a list of str.
    """
    buckets = {"lex": [], "vec": [], "hyde": [], "invalid": []}
    prefix_table = (("lex:", "lex"), ("vec:", "vec"), ("hyde:", "hyde"))
    for raw_line in text.strip().split("\n"):
        stripped = raw_line.strip()
        if not stripped:
            continue
        for prefix, key in prefix_table:
            if stripped.startswith(prefix):
                buckets[key].append(stripped[len(prefix):].strip())
                break
        else:
            buckets["invalid"].append(stripped)
    return buckets


def clean_model_output(text):
    """Strip ``<|im_end|>`` markers and any ``<think>...</think>`` blocks.

    Returns:
        (cleaned_text, used_thinking). ``used_thinking`` is True when both
        ``<think>`` and ``</think>`` appear; only then are the (non-greedy,
        multi-line) think blocks removed. An unclosed ``<think>`` is left
        in place.

    NOTE(review): used_thinking is True even for an empty ``<think></think>``
    pair. If the model emits one in /no_think mode, every completion loses
    the no-thinking bonus in score_expansion — confirm this is intended.
    """
    stripped = text.replace('<|im_end|>', '').strip()
    has_think_block = '<think>' in stripped and '</think>' in stripped
    if not has_think_block:
        return stripped, False
    without_think = re.sub(r'<think>.*?</think>', '', stripped, flags=re.DOTALL)
    return without_think.strip(), True


def extract_named_entities(query):
    """Heuristically pick out entity-like words from *query* (lowercased).

    A whitespace-separated word (with surrounding punctuation stripped) is
    treated as an entity if any of these hold:
      * it is all-caps and at least 2 chars (e.g. acronyms);
      * it is not the first word, starts uppercase, and is not a
        KEY_TERM_STOPWORDS word;
      * it contains one of ``. + - # @`` and is at least 2 chars;
      * it starts uppercase and has another uppercase letter later;
      * the previous word was an entity and this one is not a stopword
        (so multi-word names like "TDS motorsports" are kept together).
    A word that strips to nothing breaks an entity run.
    """
    found = set()
    previous_matched = False
    for position, token in enumerate(query.split()):
        word = token.strip('.,!?:;()[]"\'')
        if not word:
            previous_matched = False
            continue
        lowered = word.lower()
        matched = (
            (word.isupper() and len(word) >= 2)
            or (position > 0 and word[0].isupper() and lowered not in KEY_TERM_STOPWORDS)
            or (any(ch in word for ch in '.+-#@') and len(word) >= 2)
            or (len(word) > 1 and word[0].isupper() and any(ch.isupper() for ch in word[1:]))
            or (previous_matched and lowered not in KEY_TERM_STOPWORDS)
        )
        if matched:
            found.add(lowered)
        previous_matched = matched
    return found


def get_key_terms(query):
    """Return the set of lowercased query words that are not KEY_TERM_STOPWORDS."""
    return {word for word in query.lower().split() if word not in KEY_TERM_STOPWORDS}


def lex_preserves_key_terms(lex_line, query):
    """True if *lex_line* shares at least one whole word with the query's key terms.

    A query with no key terms (only stopwords) is trivially satisfied.
    """
    terms = get_key_terms(query)
    if not terms:
        return True
    return not terms.isdisjoint(lex_line.lower().split())


def lex_preserves_entities(line, entities):
    """True if any entity occurs (as a substring) in the lowercased *line*.

    An empty or None *entities* collection is trivially satisfied.
    """
    if entities:
        lowered = line.lower()
        return any(entity in lowered for entity in entities)
    return True


def lex_is_generic(lex_line, phrases=None):
    """Return True if *lex_line* is essentially just a generic filler phrase.

    For each phrase, the line is a candidate when it contains the phrase as a
    whole-word sequence, or when its first word equals the phrase's first
    word. The phrase's words are then removed from the line (first occurrence
    of each, matched as whole words). If fewer than 3 characters remain, the
    line is considered generic.

    Matching is done on words, not raw substrings, so specific lines such as
    "guide toml" or "search forms" are not mistaken for "guide to" /
    "search for".

    Args:
        lex_line: one lex expansion line.
        phrases: iterable of filler phrases; defaults to GENERIC_LEX_PHRASES.
    """
    if phrases is None:
        phrases = GENERIC_LEX_PHRASES
    tokens = lex_line.lower().split()
    # Space-padded so phrase containment only matches at word boundaries.
    padded_line = f" {' '.join(tokens)} "
    for phrase in phrases:
        phrase_words = phrase.lower().split()
        if not phrase_words:
            continue
        contains_phrase = f" {' '.join(phrase_words)} " in padded_line
        starts_with_lead = bool(tokens) and tokens[0] == phrase_words[0]
        if not (contains_phrase or starts_with_lead):
            continue
        remaining = list(tokens)
        for word in phrase_words:
            if word in remaining:
                remaining.remove(word)
        if len(" ".join(remaining)) < 3:
            return True
    return False


def word_set_distance(a, b):
    """Count words (case-insensitive) that appear in exactly one of *a* and *b*."""
    left_words = set(a.lower().split())
    right_words = set(b.lower().split())
    return len(left_words.symmetric_difference(right_words))


def is_diverse(a, b, min_distance=2):
    """True if two expansions differ enough to count as distinct.

    They are never diverse when, after lowercasing and stripping, they are
    equal or one is a substring of the other. Otherwise they must differ by
    at least *min_distance* words (word_set_distance).
    """
    first = a.lower().strip()
    second = b.lower().strip()
    overlapping = first == second or first in second or second in first
    return not overlapping and word_set_distance(first, second) >= min_distance


def echoes_query(expansion, query):
    """True if *expansion* is just the query, or the query plus < 10 extra chars.

    Comparison is case-insensitive and ignores surrounding whitespace. An
    empty/blank query cannot be echoed: without this guard ``"" in exp`` is
    always true, so every short expansion would be penalised when the query
    is unknown (e.g. the reward wrapper received no prompts).
    """
    exp = expansion.lower().strip()
    q = query.lower().strip()
    if not q:
        return False
    return exp == q or (q in exp and len(exp) < len(q) + 10)


def word_repetition_penalty(text):
    """Penalty for over-repeated content words in *text*.

    Each lowercased word longer than 2 chars that is not in STOPWORDS and
    occurs 3+ times adds ``2 * (count - 2)``.
    """
    penalty = 0
    word_counts = Counter(re.findall(r'\b\w+\b', text.lower()))
    for word, count in word_counts.items():
        if count < 3 or word in STOPWORDS or len(word) <= 2:
            continue
        penalty += 2 * (count - 2)
    return penalty


def score_expansion(query, expansion):
    """Score a query expansion as a float in [0.0, 1.0] for use as an RL reward.

    The expansion (after clean_model_output strips ``<|im_end|>`` and
    ``<think>...</think>``) must consist solely of ``lex:``, ``vec:`` and
    ``hyde:`` lines. Points from format (max 30), diversity (max 30),
    hyde (max 20), quality (max 20), named entities (max 20; 10 when the
    query has none) and a no-thinking bonus (20) are summed, divided by
    140 when a hyde line is present or 120 otherwise, then clamped.

    Args:
        query: the original search query (used for echo, key-term and
            entity checks).
        expansion: raw model output.

    Returns:
        0.0 if the output contains chat-template tokens, any non-blank line
        without a recognised prefix, or no lex/vec/hyde lines at all;
        otherwise the normalised score.
    """
    text, used_thinking = clean_model_output(expansion.strip())

    # Hard failures: leaked chat-template markup, or any free-form line.
    if any(tok in text for tok in CHAT_TEMPLATE_TOKENS):
        return 0.0
    for line in text.split("\n"):
        line = line.strip()
        if line and not line.startswith(("lex:", "vec:", "hyde:")):
            return 0.0

    parsed = parse_expansion(text)
    lex, vec, hyde = parsed["lex"], parsed["vec"], parsed["hyde"]

    # Nothing to score (empty completion, or only a <think> block). Without
    # this guard the unconditional base points below add up to as much as
    # 60/120 = 0.5, rewarding the model for producing nothing.
    if not (lex or vec or hyde):
        return 0.0

    # --- Format (max 30)
    format_score = 10
    if lex:
        format_score += 10
    if vec:
        format_score += 10

    # --- Diversity (max 30): both types present, enough lines, pairwise
    # distinct lex/vec lines, and no line that merely echoes the query.
    diversity_score = 0
    if lex and vec:
        diversity_score += 10
    if len(lex) + len(vec) >= 2:
        diversity_score += 5
    lex_div = 5
    for i, a in enumerate(lex):
        for b in lex[i + 1:]:
            if not is_diverse(a, b, 2):
                lex_div -= 2
    diversity_score += max(0, lex_div)
    vec_div = 5
    for i, a in enumerate(vec):
        for b in vec[i + 1:]:
            if not is_diverse(a, b, 3):
                vec_div -= 2
    diversity_score += max(0, vec_div)
    echo = 5
    for exp in lex + vec:
        if echoes_query(exp, query):
            echo -= 3
    diversity_score += max(0, echo)

    # --- Hyde (max 20): only the first hyde line is scored.
    hyde_score = 0
    if hyde:
        hyde_text = hyde[0]
        hyde_score += 5
        if 50 <= len(hyde_text) <= 200:
            hyde_score += 5
        elif len(hyde_text) < 50:
            hyde_score += 2
        # Always true: parse_expansion splits on "\n", so a parsed line
        # cannot contain one. Kept so the 140-point maximum still holds.
        if "\n" not in hyde_text:
            hyde_score += 5
        hyde_score += max(0, 5 - word_repetition_penalty(hyde_text))

    # --- Quality (max 20)
    quality_score = 5
    if lex and vec:
        avg_lex = sum(len(l) for l in lex) / len(lex)
        avg_vec = sum(len(v) for v in vec) / len(vec)
        # Keyword lines should be no longer, on average, than vec lines.
        if avg_lex <= avg_vec:
            quality_score += 5
    if vec:
        natural = sum(1 for v in vec if " " in v and len(v) > 15)
        quality_score += 5 if natural == len(vec) else 2
    if lex:
        with_terms = sum(1 for l in lex if lex_preserves_key_terms(l, query))
        if with_terms == len(lex):
            quality_score += 5
        elif with_terms > 0:
            quality_score += 2

    # --- Named entities (max 20; flat 10 when the query has none). Can go
    # strongly negative; the final clamp keeps the reward >= 0.
    entity_score = 0
    entities = extract_named_entities(query)
    if entities and lex:
        with_entities = sum(1 for l in lex if lex_preserves_entities(l, entities))
        if with_entities == len(lex):
            entity_score += 15
        elif with_entities > 0:
            entity_score += 5
        else:
            entity_score -= 30
        generic_count = sum(1 for l in lex if lex_is_generic(l))
        if generic_count:
            entity_score -= generic_count * 15
        if vec:
            vec_with = sum(1 for v in vec if lex_preserves_entities(v, entities))
            if vec_with > 0:
                entity_score += 5
    elif not entities:
        entity_score = 10

    think_bonus = 0 if used_thinking else 20
    total = format_score + diversity_score + hyde_score + quality_score + entity_score + think_bonus
    max_possible = 140 if hyde else 120
    return max(0.0, min(1.0, total / max_possible))


def extract_query_from_prompt(prompt):
    """Pull the search query out of a formatted prompt string.

    Takes the text after the last "Expand this search query:" marker and
    cuts it at the first ``<|im_end|>``. Prompts without the marker are
    returned whitespace-stripped.
    """
    marker = "Expand this search query:"
    if marker not in prompt:
        return prompt.strip()
    tail = prompt.rpartition(marker)[2].strip()
    return tail.partition("<|im_end|>")[0].strip()


class QMDRewardFunction:
    """Reward function wrapper for TRL's GRPOTrainer.

    Called as ``reward(completions, prompts=..., **kwargs)``. It returns one
    score_expansion float per completion. Samples may be plain strings or
    chat-style lists of ``{"role": ..., "content": ...}`` dicts. Chat
    samples are flattened to their message contents before scoring. When no
    prompt is available for a completion, the query is scored as "".
    """
    # Reward name; presumably read by TRL for logging — verify.
    __name__ = "qmd_scoring_reward"

    @staticmethod
    def _to_text(sample):
        """Flatten one prompt/completion sample (str or chat messages) to a string."""
        if isinstance(sample, str):
            return sample
        if sample is None:
            return ""
        if isinstance(sample, dict):
            return str(sample.get("content") or "")
        if isinstance(sample, (list, tuple)):
            parts = []
            for message in sample:
                if isinstance(message, dict):
                    parts.append(str(message.get("content") or ""))
                else:
                    parts.append(str(message))
            return "\n".join(parts)
        return str(sample)

    def __call__(self, completions, prompts=None, **kwargs):
        rewards = []
        for i, completion in enumerate(completions):
            query = ""
            if prompts and i < len(prompts):
                query = extract_query_from_prompt(self._to_text(prompts[i]))
            rewards.append(score_expansion(query, self._to_text(completion)))
        return rewards


# =============================================================================
# Evaluation
# =============================================================================

# Fixed benchmark queries used by run_eval, grouped by query style.
EVAL_QUERIES = (
    # Technical documentation
    [
        "how to configure authentication",
        "typescript async await",
        "docker compose networking",
        "git rebase vs merge",
        "react useEffect cleanup",
    ]
    # Short/ambiguous
    + ["auth", "config", "setup", "api"]
    # Named entities
    + [
        "who is TDS motorsports",
        "React hooks tutorial",
        "Docker container networking",
        "Kubernetes pod deployment",
        "AWS Lambda functions",
    ]
    # Personal notes / journals
    + [
        "meeting notes project kickoff",
        "ideas for new feature",
        "todo list app architecture",
    ]
    # Research / learning
    + [
        "what is dependency injection",
        "difference between sql and nosql",
        "kubernetes vs docker swarm",
    ]
    # Error/debugging
    + [
        "connection timeout error",
        "memory leak debugging",
        "cors error fix",
    ]
    # Temporal / recency
    + [
        "recent news about Shopify",
        "latest AI developments",
        "best laptops right now",
        "what changed in kubernetes latest version",
    ]
    # Complex
    + [
        "how to implement caching with redis in nodejs",
        "best practices for api rate limiting",
        "setting up ci cd pipeline with github actions",
    ]
)


def generate_expansion(model, tokenizer, query, max_new_tokens=200):
    """Sample one query expansion for *query* from *model*.

    Builds a single user turn "/no_think Expand this search query: {query}",
    applies the tokenizer's chat template with a generation prompt, and
    samples (temperature 0.7) up to *max_new_tokens* tokens without
    gradient tracking.

    Returns:
        The decoded completion only (special tokens skipped, whitespace
        stripped). Only the newly generated token ids are decoded. Splitting
        the decoded full sequence on "assistant\\n" can drop generated text
        that contains such a line, and a character offset of ``len(prompt)``
        is misaligned once special tokens are skipped.
    """
    messages = [{"role": "user", "content": f"/no_think Expand this search query: {query}"}]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs, max_new_tokens=max_new_tokens,
            temperature=0.7, do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    # Assumes a decoder-only model, whose generate() output starts with the
    # prompt ids — slice them off and decode just the continuation.
    prompt_len = inputs["input_ids"].shape[1]
    new_tokens = outputs[0][prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


def run_eval(model, tokenizer, label, upload_repo="tobil/qmd-query-expansion-evals"):
    """Evaluate a model on EVAL_QUERIES, print a report, and upload a CSV.

    The Hub repo is created (if missing) before any generation. Each query
    gets one sampled expansion scored with score_expansion. Scores are shown
    as percentages rounded to 0.1 and bucketed Excellent (>=80), Good (>=60),
    Acceptable (>=40), Poor (>=20) or Failed. The run prints per-query lines,
    the average, rating counts and the 5 lowest-scoring queries, then uploads
    ``eval_{label}.csv`` to *upload_repo*. Returns None.

    Args:
        model: model passed through to generate_expansion.
        tokenizer: tokenizer passed through to generate_expansion.
        label: run name; used in the report header, the CSV "model" column
            and the uploaded filename.
        upload_repo: Hub repo id (repo_type "model") that receives the CSV.
    """
    hub = HfApi()
    hub.create_repo(repo_id=upload_repo, repo_type="model", exist_ok=True)

    rule = "=" * 70
    print(f"\n{rule}")
    print(f"  EVALUATING: {label}")
    print(rule)

    total = len(EVAL_QUERIES)
    results = []
    for idx, query in enumerate(EVAL_QUERIES, 1):
        expansion = generate_expansion(model, tokenizer, query)
        pct = round(score_expansion(query, expansion) * 100, 1)
        if pct >= 80:
            rating = "Excellent"
        elif pct >= 60:
            rating = "Good"
        elif pct >= 40:
            rating = "Acceptable"
        elif pct >= 20:
            rating = "Poor"
        else:
            rating = "Failed"
        if pct >= 80:
            marker = "+"
        elif pct < 60:
            marker = "-"
        else:
            marker = "~"
        print(f"  [{marker}] {idx:2d}/{total} {pct:5.1f}% {rating:10s}  {query}")
        results.append({"query": query, "expansion": expansion, "score": pct, "rating": rating})

    scores = [row["score"] for row in results]
    avg = sum(scores) / len(scores)
    tally = Counter(row["rating"] for row in results)

    print(f"\n  {'─'*50}")
    print(f"  Average score: {avg:.1f}%")
    for name in ("Excellent", "Good", "Acceptable", "Poor", "Failed"):
        count = tally.get(name, 0)
        if count:
            print(f"    {name:10s}: {count:2d}  {'█' * count}")

    print("\n  Bottom 5:")
    for row in sorted(results, key=lambda row: row["score"])[:5]:
        print(f"    {row['score']:5.1f}%  {row['query']}")

    buf = io.StringIO()
    writer = csv.writer(buf)
    writer.writerow(["model", "query", "expansion", "score_pct", "rating"])
    writer.writerows(
        [label, row["query"], row["expansion"], row["score"], row["rating"]]
        for row in results
    )

    filename = f"eval_{label}.csv"
    print(f"\n  Uploading {filename} to {upload_repo}...")
    hub.upload_file(
        path_or_fileobj=buf.getvalue().encode("utf-8"),
        path_in_repo=filename,
        repo_id=upload_repo,
        repo_type="model",
    )
    print(f"  Done: https://huggingface.co/{upload_repo}/blob/main/{filename}")