#!/usr/bin/env python3

"""Evaluate Canary-1B from already prepared JSONL manifests.



Keep this script intentionally small: data preparation is handled by

prepare_eval_data.py; inference uses the same NeMo model.transcribe() mechanism

as Canary-1B/infer.py; this file only adds metric calculation and result dumps.

"""



from __future__ import annotations



import argparse

import csv

import json

import os

import platform

import sys

import time

from pathlib import Path

from typing import Any



from whisper.normalizers import EnglishTextNormalizer



import nemo

import torch

from jiwer import wer

from sacrebleu import corpus_bleu

from utils import extract_text, load_canary_model, synchronize_device



# Manifest paths evaluated when --manifest is not passed on the command line.

# Paths are relative; they resolve against the current working directory.

# Judging by the file names, the first four are ASR sets (en/de/es/fr) and the

# remaining six are AST language pairs to and from English.

DEFAULT_MANIFESTS = [

    "Canary-1B/eval_data/librispeech_test_clean/manifest_asr_en.jsonl",

    "Canary-1B/eval_data/mls_test_german/manifest_asr_de.jsonl",

    "Canary-1B/eval_data/mls_test_spanish/manifest_asr_es.jsonl",

    "Canary-1B/eval_data/mls_test_french/manifest_asr_fr.jsonl",

    "Canary-1B/eval_data/fleurs/en-de/manifest_ast_en_de.jsonl",

    "Canary-1B/eval_data/fleurs/en-es/manifest_ast_en_es.jsonl",

    "Canary-1B/eval_data/fleurs/en-fr/manifest_ast_en_fr.jsonl",

    "Canary-1B/eval_data/fleurs/de-en/manifest_ast_de_en.jsonl",

    "Canary-1B/eval_data/fleurs/es-en/manifest_ast_es_en.jsonl",

    "Canary-1B/eval_data/fleurs/fr-en/manifest_ast_fr_en.jsonl",

]





def _int_at_least(minimum: int):
    """Return an argparse ``type`` callable that accepts only integers >= ``minimum``."""

    def convert(raw: str) -> int:
        try:
            value = int(raw)
        except ValueError:
            raise argparse.ArgumentTypeError(f"expected an integer, got {raw!r}") from None
        if value < minimum:
            raise argparse.ArgumentTypeError(f"must be >= {minimum}, got {value}")
        return value

    return convert


def parse_args() -> argparse.Namespace:
    """Parse the command line (``sys.argv``) for the evaluation script.

    Numeric options are range-checked while parsing so that an invalid value
    (for example ``--batch_size 0``) is reported as a usage error immediately
    instead of failing later, after the model has already been loaded.

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        SystemExit: raised by argparse on invalid or missing arguments.
    """
    parser = argparse.ArgumentParser(description="Evaluate Canary-1B using prepared manifests")
    parser.add_argument("--model", required=True, help="Local .nemo file, model directory, or HF model id")
    parser.add_argument("--manifest", nargs="+", help="Prepared JSONL manifest(s). Defaults to all standard manifests.")
    parser.add_argument("--device", default="npu", choices=["npu", "cpu", "cuda"])
    parser.add_argument("--output_dir", default="Canary-1B/eval_results")
    parser.add_argument("--batch_size", type=_int_at_least(1), default=1)
    parser.add_argument("--beam_size", type=_int_at_least(1), default=5)
    parser.add_argument(
        "--decoding_strategy",
        default="auto",
        choices=["auto", "beam", "greedy", "greedy_batch"],
        help=(
            "NeMo AED decoding strategy. auto keeps the checkpoint strategy except that "
            "--performance_mode with --beam_size 1 uses greedy_batch for faster throughput. "
            "Use --decoding_strategy beam to force exact beam decoder behavior."
        ),
    )
    parser.add_argument(
        "--performance_mode",
        action="store_true",
        help=(
            "Run a Hugging Face Open ASR Leaderboard-style timed path: sort samples by duration, "
            "warm up before timing, transcribe an audio filepath list instead of timing manifest setup, "
            "and report RTFx. Intended for ASR throughput comparisons."
        ),
    )
    parser.add_argument(
        "--compute_dtype",
        default="auto",
        choices=["auto", "float32", "float16", "bfloat16"],
        help="Model compute dtype. auto uses bfloat16 in performance_mode on NPU/CUDA, otherwise keeps float32.",
    )
    parser.add_argument(
        "--warmup_batches",
        # Zero is allowed: it disables the warmup pass entirely.
        type=_int_at_least(0),
        default=4,
        help="Number of initial batches to run before timed inference when --performance_mode is set.",
    )
    parser.add_argument(
        "--num_workers",
        type=_int_at_least(0),
        default=1,
        help="DataLoader workers passed to NeMo transcribe in --performance_mode.",
    )
    parser.add_argument(
        "--pnc",
        default="nopnc",
        choices=["pnc", "nopnc", "yes", "no"],
        help="Canary prompt PnC value used by the audio-list performance path.",
    )
    parser.add_argument("--source_lang", default="en", choices=["en", "de", "es", "fr"])
    parser.add_argument("--target_lang", default="en", choices=["en", "de", "es", "fr"])
    parser.add_argument("--task", default="asr", choices=["asr", "ast", "s2t_translation"])
    return parser.parse_args()





def read_manifest(path: Path) -> list[dict[str, Any]]:
    """Load a JSONL manifest into a list of per-sample dictionaries.

    Each non-blank line must hold one JSON object. Lines that are empty or
    contain only whitespace are skipped.

    Args:
        path: Location of the JSONL manifest file (read as UTF-8).

    Returns:
        The parsed entries, in file order.

    Raises:
        ValueError: if a line is not valid JSON, if a line decodes to something
            other than a JSON object, or if the file contains no entries. The
            message names the manifest and, where applicable, the line number.
    """
    items: list[dict[str, Any]] = []
    with path.open("r", encoding="utf-8") as f:
        for line_no, raw in enumerate(f, start=1):
            text = raw.strip()
            if not text:
                continue
            try:
                entry = json.loads(text)
            except json.JSONDecodeError as exc:
                # JSONDecodeError is itself a ValueError, so callers that caught
                # the original exception type still catch this one.
                raise ValueError(f"invalid JSON in manifest {path} at line {line_no}: {exc}") from exc
            if not isinstance(entry, dict):
                # Downstream code calls .get() on every entry.
                raise ValueError(
                    f"manifest {path} line {line_no} is not a JSON object: {type(entry).__name__}"
                )
            items.append(entry)
    if not items:
        raise ValueError(f"empty manifest: {path}")
    return items





def tag_from_manifest(path: Path, items: list[dict[str, Any]]) -> str:
    """Derive a short result tag from a manifest path and its first entry.

    Only ``items[0]`` is inspected. ASR manifests are tagged with the name of
    the manifest's parent directory (hyphens turned into underscores) plus the
    source language; every other task is tagged with the language pair.
    """
    head = items[0]
    dataset = path.parent.name.replace("-", "_")
    source = head.get("source_lang", "src")
    target = head.get("target_lang", "tgt")
    is_asr = head.get("taskname", "task") == "asr"
    return f"asr_{dataset}_{source}" if is_asr else f"ast_{source}_{target}"





# Single shared normalizer instance, created once when the module is imported.
# NOTE(review): compute_metrics applies this English normalizer to every "asr"
# manifest, including the German/Spanish/French defaults - confirm that is intended.
_WER_NORMALIZER = EnglishTextNormalizer()


def normalize_for_wer(text: str) -> str:
    """Return ``text`` after running it through Whisper's EnglishTextNormalizer."""
    normalized = _WER_NORMALIZER(text)
    return normalized





def compute_metrics(taskname: str, references: list[str], hypotheses: list[str]) -> dict[str, float]:
    """Score hypotheses against references for one manifest.

    Args:
        taskname: ``"asr"`` selects word error rate; any other value selects BLEU.
        references: Ground-truth texts, one per sample.
        hypotheses: Model outputs, aligned index-by-index with ``references``.

    Returns:
        ``{"wer": ..., "wer_percent": ...}`` for ASR, computed on text passed
        through ``normalize_for_wer``; otherwise ``{"bleu": ...}`` computed by
        ``corpus_bleu`` on the unmodified strings.

    Raises:
        ValueError: if the two lists differ in length.
    """
    # Fail early with a clear message: a misaligned pair list would otherwise
    # surface as a library-specific error after normalisation has already run.
    if len(references) != len(hypotheses):
        raise ValueError(
            f"reference/hypothesis count mismatch: {len(references)} != {len(hypotheses)}"
        )

    if taskname == "asr":
        refs = [normalize_for_wer(x) for x in references]
        hyps = [normalize_for_wer(x) for x in hypotheses]
        value = float(wer(refs, hyps))
        return {"wer": value, "wer_percent": value * 100.0}

    # sacrebleu takes the hypotheses first, then a list of reference streams.
    return {"bleu": float(corpus_bleu(hypotheses, [references]).score)}





def env_report(args: argparse.Namespace) -> dict[str, Any]:
    """Collect interpreter, platform, CLI and library-version details for this run.

    The returned dictionary holds the Python version, platform string, raw
    ``sys.argv``, the parsed option values listed below, the
    ``ASCEND_RT_VISIBLE_DEVICES`` environment variable (``None`` when unset),
    and the ``torch`` and ``nemo`` version strings.
    """
    # Parsed options copied verbatim into the report, in output order.
    cli_fields = (
        "model",
        "device",
        "batch_size",
        "beam_size",
        "decoding_strategy",
        "performance_mode",
        "compute_dtype",
        "warmup_batches",
        "num_workers",
        "pnc",
        "source_lang",
        "target_lang",
        "task",
    )

    snapshot: dict[str, Any] = {
        "python": sys.version,
        "platform": platform.platform(),
        "argv": sys.argv,
    }
    snapshot.update((field, getattr(args, field)) for field in cli_fields)
    snapshot["ascend_rt_visible_devices"] = os.environ.get("ASCEND_RT_VISIBLE_DEVICES")
    snapshot["torch"] = torch.__version__
    snapshot["nemo"] = nemo.__version__
    return snapshot





def transcribe_audio_list(model: Any, audio_paths: list[str], args: argparse.Namespace) -> list[Any]:
    """Transcribe a list of audio files with the prompt taken from CLI options.

    Args:
        model: Object exposing ``transcribe(audio_paths, **kwargs)``.
        audio_paths: Audio file paths to transcribe, in output order.
        args: Parsed CLI options; ``batch_size``, ``pnc``, ``source_lang``,
            ``target_lang``, ``task`` and ``num_workers`` are forwarded.

    Returns:
        One output per input path, in input order. An empty input list yields
        an empty list without calling the model.

    Raises:
        RuntimeError: if ``model.transcribe`` raises, returns ``None``, returns
            a tuple that is not of length two, or returns a different number
            of outputs than inputs.
    """
    if not audio_paths:
        return []

    try:
        # inference_mode already disables gradient tracking, so an extra
        # no_grad() context would add nothing.
        with torch.inference_mode():
            outputs = model.transcribe(
                audio_paths,
                batch_size=args.batch_size,
                verbose=False,
                pnc=args.pnc,
                source_lang=args.source_lang,
                target_lang=args.target_lang,
                taskname=args.task,
                num_workers=args.num_workers,
            )
    except Exception as exc:
        raise RuntimeError(f"Canary transcription failed for {len(audio_paths)} audio file(s)") from exc

    if outputs is None:
        raise RuntimeError("Canary transcription returned no outputs")
    if isinstance(outputs, tuple):
        # A two-element tuple is treated as (primary outputs, extra); only the
        # first element is kept.
        if len(outputs) != 2:
            raise RuntimeError(f"Unexpected Canary transcription tuple length: {len(outputs)}")
        outputs = outputs[0]

    transcriptions = list(outputs)
    if len(transcriptions) != len(audio_paths):
        raise RuntimeError(
            "Canary transcription output count does not match input count: "
            f"{len(transcriptions)} != {len(audio_paths)}"
        )
    return transcriptions





def evaluate_manifest_performance(model: Any, manifest: Path, args: argparse.Namespace) -> dict[str, Any]:
    """Run the timed, audio-list evaluation path for one manifest.

    Samples are sorted by descending ``duration``, an optional warmup pass is
    run on the first ``batch_size * warmup_batches`` samples, and then the full
    list is transcribed between two device synchronisations. Metrics, timing
    and a per-sample TSV are written to ``args.output_dir`` under
    ``<tag>_perf``; the metrics are also printed to stdout.

    The decoding prompt comes from the CLI options (via
    ``transcribe_audio_list``), while the metric type comes from the manifest's
    ``taskname``. A warning is printed to stderr when the two disagree.

    Args:
        model: Loaded model passed through to ``transcribe_audio_list``.
        manifest: Path to a prepared JSONL manifest.
        args: Parsed CLI options.

    Returns:
        The metrics dictionary that was written to ``<tag>_perf.metrics.json``.
    """
    items = read_manifest(manifest)
    tag = tag_from_manifest(manifest, items) + "_perf"
    sorted_items = sorted(items, key=lambda x: float(x.get("duration", 0.0)), reverse=True)
    audio_paths = [str(x["audio_filepath"]) for x in sorted_items]
    references = [str(x.get("answer", "")) for x in sorted_items]
    first = sorted_items[0]
    taskname = str(first.get("taskname", "asr"))

    # The prompt is not read from the manifest on this path, so flag any
    # disagreement instead of silently scoring output produced with the wrong prompt.
    cli_prompt = {"taskname": args.task, "source_lang": args.source_lang, "target_lang": args.target_lang}
    conflicts = [
        f"{key}: manifest={first[key]!r} cli={value!r}"
        for key, value in cli_prompt.items()
        if key in first and str(first[key]) != value
    ]
    if conflicts:
        print(
            f"warning: {manifest} disagrees with the CLI prompt used in performance mode ("
            + "; ".join(conflicts)
            + ")",
            file=sys.stderr,
        )

    # Clamp at zero so a negative --warmup_batches cannot leak into the report.
    warmup_samples = max(0, min(len(audio_paths), args.batch_size * args.warmup_batches))
    warmup_elapsed = 0.0
    if warmup_samples > 0:
        synchronize_device(args.device)
        warmup_started = time.perf_counter()
        # transcribe_audio_list already enforces one output per input.
        transcribe_audio_list(model, audio_paths[:warmup_samples], args)
        synchronize_device(args.device)
        warmup_elapsed = time.perf_counter() - warmup_started

    # perf_counter is monotonic, so the interval cannot be skewed by wall-clock
    # adjustments during a long run.
    synchronize_device(args.device)
    started = time.perf_counter()
    outputs = transcribe_audio_list(model, audio_paths, args)
    synchronize_device(args.device)
    elapsed = time.perf_counter() - started

    hypotheses = [extract_text(x) for x in outputs]
    metrics: dict[str, Any] = compute_metrics(taskname, references, hypotheses)
    total_audio = sum(float(x.get("duration", 0.0)) for x in sorted_items)
    metrics.update(
        {
            "tag": tag,
            "manifest": str(manifest),
            "mode": "performance",
            "num_samples": len(sorted_items),
            "audio_seconds": total_audio,
            "elapsed_seconds": elapsed,
            "rtf": elapsed / total_audio if total_audio > 0 else None,
            "rtfx": total_audio / elapsed if elapsed > 0 else None,
            "warmup_samples": warmup_samples,
            "warmup_elapsed_seconds": warmup_elapsed,
            "sorted_by_duration": "descending",
            "timed_input": "audio_filepath_list",
        }
    )

    output_dir = Path(args.output_dir)
    # Harmless when the caller already created it; needed when called directly.
    output_dir.mkdir(parents=True, exist_ok=True)
    pred_path = output_dir / f"{tag}.tsv"
    with pred_path.open("w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f, delimiter="\t")
        writer.writerow(["sample_id", "audio_path", "duration", "reference", "hypothesis"])
        writer.writerows(
            [
                item.get("sample_id", ""),
                item.get("audio_filepath", ""),
                item.get("duration", ""),
                item.get("answer", ""),
                hyp,
            ]
            for item, hyp in zip(sorted_items, hypotheses)
        )

    with (output_dir / f"{tag}.metrics.json").open("w", encoding="utf-8") as f:
        json.dump(metrics, f, ensure_ascii=False, indent=2)
    print(json.dumps(metrics, ensure_ascii=False, indent=2))
    return metrics





def evaluate_manifest(model: Any, manifest: Path, args: argparse.Namespace) -> dict[str, Any]:
    """Evaluate one manifest by handing the manifest path itself to the model.

    ``model.transcribe`` receives the manifest path and ``args.batch_size``;
    the call is timed between two device synchronisations. Metrics and a
    per-sample TSV are written to ``args.output_dir`` under ``<tag>``, and the
    metrics are also printed to stdout.

    Args:
        model: Loaded model exposing ``transcribe(audio=..., batch_size=...)``.
        manifest: Path to a prepared JSONL manifest.
        args: Parsed CLI options.

    Returns:
        The metrics dictionary that was written to ``<tag>.metrics.json``.

    Raises:
        RuntimeError: if the model returns no outputs or a different number of
            outputs than the manifest has entries.
    """
    items = read_manifest(manifest)
    tag = tag_from_manifest(manifest, items)

    # Synchronise on both sides of the timed call, as the performance path
    # does, and use the monotonic perf_counter for the interval.
    synchronize_device(args.device)
    started = time.perf_counter()
    outputs = model.transcribe(audio=str(manifest), batch_size=args.batch_size)
    synchronize_device(args.device)
    elapsed = time.perf_counter() - started

    if outputs is None:
        raise RuntimeError("Canary transcription returned no outputs")
    # Accept a (primary outputs, extra) pair like transcribe_audio_list does;
    # a pair whose first element is not a sequence is left untouched.
    if isinstance(outputs, tuple) and len(outputs) == 2 and isinstance(outputs[0], (list, tuple)):
        outputs = outputs[0]
    outputs = list(outputs)
    if len(outputs) != len(items):
        # Without this check zip() below would silently drop trailing samples.
        raise RuntimeError(
            "Canary transcription output count does not match manifest entry count: "
            f"{len(outputs)} != {len(items)}"
        )

    hypotheses = [extract_text(x) for x in outputs]
    references = [str(x.get("answer", "")) for x in items]
    taskname = str(items[0].get("taskname", "asr"))
    metrics: dict[str, Any] = compute_metrics(taskname, references, hypotheses)
    total_audio = sum(float(x.get("duration", 0.0)) for x in items)
    metrics.update(
        {
            "tag": tag,
            "manifest": str(manifest),
            "num_samples": len(items),
            "audio_seconds": total_audio,
            "elapsed_seconds": elapsed,
            "rtf": elapsed / total_audio if total_audio > 0 else None,
            "rtfx": total_audio / elapsed if elapsed > 0 else None,
        }
    )

    output_dir = Path(args.output_dir)
    # Harmless when the caller already created it; needed when called directly.
    output_dir.mkdir(parents=True, exist_ok=True)
    pred_path = output_dir / f"{tag}.tsv"
    with pred_path.open("w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f, delimiter="\t")
        writer.writerow(["sample_id", "audio_path", "duration", "reference", "hypothesis"])
        writer.writerows(
            [
                item.get("sample_id", ""),
                item.get("audio_filepath", ""),
                item.get("duration", ""),
                item.get("answer", ""),
                hyp,
            ]
            for item, hyp in zip(items, hypotheses)
        )

    with (output_dir / f"{tag}.metrics.json").open("w", encoding="utf-8") as f:
        json.dump(metrics, f, ensure_ascii=False, indent=2)
    print(json.dumps(metrics, ensure_ascii=False, indent=2))
    return metrics





def main() -> None:
    """Script entry point: check manifests, load the model once, evaluate each manifest.

    Writes ``run_env.json`` before the model is loaded and
    ``summary.metrics.json`` after every manifest has been evaluated, both
    inside ``--output_dir``.

    Raises:
        FileNotFoundError: if any requested manifest path is not an existing file.
    """
    args = parse_args()

    manifest_paths = [Path(entry) for entry in (args.manifest or DEFAULT_MANIFESTS)]
    absent = [str(candidate) for candidate in manifest_paths if not candidate.is_file()]
    if absent:
        raise FileNotFoundError(
            "Missing manifest(s). Run prepare_eval_data.py first or pass --manifest. Missing: "
            + ", ".join(absent)
        )

    results_dir = Path(args.output_dir)
    results_dir.mkdir(parents=True, exist_ok=True)

    env_path = results_dir / "run_env.json"
    with env_path.open("w", encoding="utf-8") as env_file:
        json.dump(env_report(args), env_file, ensure_ascii=False, indent=2)

    model = load_canary_model(
        args.model,
        device_name=args.device,
        compute_dtype=args.compute_dtype,
        auto_bf16=args.performance_mode,
        beam_size=args.beam_size,
        decoding_strategy=args.decoding_strategy,
        performance_mode=args.performance_mode,
    )

    # Pick the evaluation routine once; every manifest goes through the same one.
    if args.performance_mode:
        evaluate_fn = evaluate_manifest_performance
    else:
        evaluate_fn = evaluate_manifest

    all_metrics = []
    for manifest_path in manifest_paths:
        all_metrics.append(evaluate_fn(model, manifest_path, args))

    summary_path = results_dir / "summary.metrics.json"
    with summary_path.open("w", encoding="utf-8") as summary_file:
        json.dump(all_metrics, summary_file, ensure_ascii=False, indent=2)


if __name__ == "__main__":
    main()