import logging
import re
import sys
from typing import List, Optional

import numpy as np

from whisperlivekit.local_agreement.backends import ASRBase
from whisperlivekit.timed_objects import ASRToken

logger = logging.getLogger(__name__)


def _default_rope_parameters(config=None, device=None, seq_len=None, **kwargs):
    """Compute the un-scaled ("default") rotary-embedding inverse frequencies.

    Shared by both RoPE shims installed by ``_patch_transformers_compat``.

    Args:
        config: Model config exposing ``rope_theta`` and either ``head_dim`` or
            ``hidden_size`` / ``num_attention_heads``. An optional
            ``partial_rotary_factor`` scales the rotary dimension.
        device: Device the returned tensor is moved to.
        seq_len: Ignored; accepted for call-signature compatibility.
        **kwargs: Ignored.

    Returns:
        ``(inv_freq, 1.0)`` — the inverse-frequency tensor and a constant
        second element of ``1.0``.
    """
    # Deferred so that installing the shims does not itself require torch.
    import torch

    head_dim = getattr(config, "head_dim", None)
    if head_dim is None:
        # Covers both a missing attribute and an explicit ``head_dim=None``;
        # the fallback is only evaluated when it is actually needed.
        head_dim = config.hidden_size // config.num_attention_heads
    partial = getattr(config, "partial_rotary_factor", 1.0)
    dim = int(head_dim * partial)
    base = config.rope_theta
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
    return inv_freq, 1.0


def _patch_transformers_compat():
    """Patch transformers for qwen_asr 0.0.6 + transformers >= 5.3 compatibility.

    Every step is best-effort: a missing package or module is skipped, and
    attributes that already exist are left untouched. Calling the function
    more than once is safe.
    """
    log = logging.getLogger(__name__)

    # 1. check_model_inputs was removed
    try:
        import transformers.utils.generic as _g
        if not hasattr(_g, "check_model_inputs"):
            def check_model_inputs(*args, **kwargs):
                # Bare ``@check_model_inputs``: the decorated function is the
                # sole positional argument, so return it unchanged.
                if len(args) == 1 and not kwargs and callable(args[0]):
                    return args[0]

                # Factory form ``@check_model_inputs(...)``: identity decorator.
                def decorator(fn):
                    return fn
                return decorator
            _g.check_model_inputs = check_model_inputs
    except ImportError:
        pass

    # 2. 'default' rope type was removed from ROPE_INIT_FUNCTIONS
    try:
        from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
        if "default" not in ROPE_INIT_FUNCTIONS:
            ROPE_INIT_FUNCTIONS["default"] = _default_rope_parameters
    except ImportError:
        pass

    # 3. pad_token_id missing on thinker config
    try:
        from qwen_asr.core.transformers_backend.configuration_qwen3_asr import (
            Qwen3ASRThinkerConfig,
        )
        if not hasattr(Qwen3ASRThinkerConfig, "pad_token_id"):
            Qwen3ASRThinkerConfig.pad_token_id = None
    except ImportError:
        pass

    # 4. fix_mistral_regex kwarg not accepted by newer transformers
    try:
        from transformers.models.auto import processing_auto
        auto_processor = processing_auto.AutoProcessor
        wrapped = auto_processor.from_pretrained.__func__
        # The marker keeps repeated calls from stacking wrapper on wrapper.
        if not getattr(wrapped, "_strips_fix_mistral_regex", False):
            def _from_pretrained(cls, *args, **kwargs):
                kwargs.pop("fix_mistral_regex", None)
                return wrapped(cls, *args, **kwargs)

            _from_pretrained._strips_fix_mistral_regex = True
            auto_processor.from_pretrained = classmethod(_from_pretrained)
    except Exception:
        # Still best-effort, but leave a trace instead of failing silently.
        log.debug("Could not patch AutoProcessor.from_pretrained", exc_info=True)

    # 5. compute_default_rope_parameters missing on RotaryEmbedding
    try:
        from qwen_asr.core.transformers_backend.modeling_qwen3_asr import (
            Qwen3ASRThinkerTextRotaryEmbedding,
        )
        if not hasattr(Qwen3ASRThinkerTextRotaryEmbedding, "compute_default_rope_parameters"):
            Qwen3ASRThinkerTextRotaryEmbedding.compute_default_rope_parameters = staticmethod(
                _default_rope_parameters
            )
    except ImportError:
        pass


# Install the compatibility shims once, at import time, before any model is loaded.
_patch_transformers_compat()

# Whisper language codes → Qwen3 canonical language names
WHISPER_TO_QWEN3_LANGUAGE = {
    "zh": "Chinese", "en": "English", "yue": "Cantonese",
    "ar": "Arabic", "de": "German", "fr": "French", "es": "Spanish",
    "pt": "Portuguese", "id": "Indonesian", "it": "Italian",
    "ko": "Korean", "ru": "Russian", "th": "Thai", "vi": "Vietnamese",
    "ja": "Japanese", "tr": "Turkish", "hi": "Hindi", "ms": "Malay",
    "nl": "Dutch", "sv": "Swedish", "da": "Danish", "fi": "Finnish",
    "pl": "Polish", "cs": "Czech", "fa": "Persian",
    "el": "Greek", "hu": "Hungarian", "mk": "Macedonian", "ro": "Romanian",
}

# Reverse mapping: Qwen3 canonical names → Whisper language codes
QWEN3_TO_WHISPER_LANGUAGE = {v: k for k, v in WHISPER_TO_QWEN3_LANGUAGE.items()}

# Short convenience names → HuggingFace model IDs
QWEN3_MODEL_MAPPING = {
    "qwen3-asr-1.7b": "Qwen/Qwen3-ASR-1.7B",
    "qwen3-asr-0.6b": "Qwen/Qwen3-ASR-0.6B",
    "qwen3-1.7b": "Qwen/Qwen3-ASR-1.7B",
    "qwen3-0.6b": "Qwen/Qwen3-ASR-0.6B",
    # Whisper-style size aliases (map to closest Qwen3 model)
    "large": "Qwen/Qwen3-ASR-1.7B",
    "large-v3": "Qwen/Qwen3-ASR-1.7B",
    "medium": "Qwen/Qwen3-ASR-1.7B",
    "base": "Qwen/Qwen3-ASR-0.6B",
    "small": "Qwen/Qwen3-ASR-0.6B",
    "tiny": "Qwen/Qwen3-ASR-0.6B",
}

# Characters that close a sentence: ASCII ".", "!", "?", ";", the CJK full
# stop (U+3002) and the full-width forms of "!", "?", ";" (U+FF01, U+FF1F,
# U+FF1B). Non-ASCII members are written as escapes so they cannot be
# silently folded to their ASCII look-alikes by editors or tooling.
_PUNCTUATION_ENDS = set(".!?;\u3002\uff01\uff1f\uff1b")
# Qwen3 raw output starts with "language <Name>" metadata before the <asr_text> tag.
# When the tag is missing (silence/noise), this metadata leaks as transcription text.
_GARBAGE_RE = re.compile(r"^language\s+\S+$", re.IGNORECASE)


class Qwen3ASR(ASRBase):
    """Qwen3-ASR backend with ForcedAligner word-level timestamps.

    Wraps ``qwen_asr.Qwen3ASRModel`` and converts its results into
    ``ASRToken`` lists and segment end times.
    """

    sep = ""  # tokens include leading spaces, like faster-whisper
    SAMPLING_RATE = 16000

    def __init__(self, lan="auto", model_size=None, cache_dir=None,
                 model_dir=None, logfile=sys.stderr, **kwargs):
        """Store the settings and load the model.

        Args:
            lan: Whisper-style language code, or ``"auto"`` for no language hint.
            model_size: Short alias from ``QWEN3_MODEL_MAPPING`` or a model ID.
            cache_dir: Forwarded to ``load_model``.
            model_dir: Explicit model path/ID; takes precedence over ``model_size``.
            logfile: Stored on the instance as ``self.logfile``.
            **kwargs: Accepted and ignored.
        """
        self.logfile = logfile
        self.transcribe_kargs = {}
        self.original_language = None if lan == "auto" else lan
        if (self.original_language is not None
                and self.original_language not in WHISPER_TO_QWEN3_LANGUAGE):
            # _qwen3_language() maps unknown codes to None, which is exactly what
            # lan="auto" produces; say so once instead of degrading silently.
            logger.warning(
                "Language %r has no Qwen3-ASR mapping; behaving as lan='auto'",
                self.original_language,
            )
        self.model = self.load_model(model_size, cache_dir, model_dir)

    def load_model(self, model_size=None, cache_dir=None, model_dir=None):
        """Instantiate ``Qwen3ASRModel`` together with its forced aligner.

        Args:
            model_size: Alias looked up case-insensitively in
                ``QWEN3_MODEL_MAPPING``; unknown values are used as-is.
            cache_dir: Accepted for interface compatibility; not used here.
            model_dir: If truthy, used verbatim as the model ID/path.

        Returns:
            The loaded model object.
        """
        import torch
        from qwen_asr import Qwen3ASRModel

        if model_dir:
            model_id = model_dir
        elif model_size:
            model_id = QWEN3_MODEL_MAPPING.get(model_size.lower(), model_size)
        else:
            model_id = "Qwen/Qwen3-ASR-1.7B"

        # Prefer CUDA (bfloat16), then Apple MPS, then CPU (both float32).
        if torch.cuda.is_available():
            dtype, device = torch.bfloat16, "cuda:0"
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            dtype, device = torch.float32, "mps"
        else:
            dtype, device = torch.float32, "cpu"

        logger.info("Loading Qwen3-ASR: %s (%s, %s)", model_id, dtype, device)
        model = Qwen3ASRModel.from_pretrained(
            model_id,
            forced_aligner="Qwen/Qwen3-ForcedAligner-0.6B",
            forced_aligner_kwargs=dict(dtype=dtype, device_map=device),
            dtype=dtype,
            device_map=device,
        )
        logger.info("Qwen3-ASR loaded with ForcedAligner")
        return model

    def _qwen3_language(self) -> Optional[str]:
        """Return the Qwen3 language name for the configured code.

        ``None`` is returned both for ``"auto"`` and for codes that are not
        in ``WHISPER_TO_QWEN3_LANGUAGE``.
        """
        if self.original_language is None:
            return None
        return WHISPER_TO_QWEN3_LANGUAGE.get(self.original_language)

    def transcribe(self, audio: np.ndarray, init_prompt: str = ""):
        """Transcribe ``audio``, retrying without timestamps if the first call fails.

        Args:
            audio: Samples; paired with ``SAMPLING_RATE`` when handed to the model.
            init_prompt: Passed to the model as ``context`` (empty string if falsy).

        Returns:
            The first element of the model's result list, with an extra
            ``_audio_duration`` attribute (seconds) attached.

        Raises:
            Whatever the fallback (no-timestamp) call raises.
        """
        request = dict(
            audio=(audio, self.SAMPLING_RATE),
            language=self._qwen3_language(),
            context=init_prompt or "",
        )
        try:
            results = self.model.transcribe(return_time_stamps=True, **request)
        except Exception:
            logger.warning("Qwen3 timestamp alignment failed, falling back to no timestamps", exc_info=True)
            results = self.model.transcribe(return_time_stamps=False, **request)
        result = results[0]
        # Stash audio length for timestamp estimation fallback
        result._audio_duration = len(audio) / self.SAMPLING_RATE
        logger.info(
            "Qwen3 result: language=%r text=%r ts=%s",
            result.language, result.text[:80] if result.text else "",
            bool(result.time_stamps),
        )
        return result

    @staticmethod
    def _detected_language(result) -> Optional[str]:
        """Extract Whisper-style language code from Qwen3 result.

        Returns ``None`` when the language is missing, empty or the literal
        ``"none"``. Names absent from ``QWEN3_TO_WHISPER_LANGUAGE`` are
        returned lower-cased.
        """
        lang = getattr(result, 'language', None)
        if not lang or lang.lower() == "none":
            return None
        # merge_languages may return comma-separated; take the first
        first = lang.split(",")[0].strip()
        if not first or first.lower() == "none":
            return None
        return QWEN3_TO_WHISPER_LANGUAGE.get(first, first.lower())

    def ts_words(self, result) -> List[ASRToken]:
        """Convert a transcription result into timestamped tokens.

        Empty text and bare ``language <Name>`` metadata yield no tokens.
        Aligner timestamps are used when present; otherwise the audio
        duration is divided evenly across the whitespace-separated words.
        Every token after the first carries a leading space.
        """
        # Filter garbage model output (e.g. "language None" for silence/noise)
        transcript = (result.text or "").strip()
        if not transcript:
            return []
        if _GARBAGE_RE.match(transcript):
            logger.info("Filtered garbage Qwen3 output: %r", transcript)
            return []

        detected = self._detected_language(result)
        if result.time_stamps:
            # Prepend space to match faster-whisper convention (tokens carry
            # their own whitespace so ''.join works in Segment.from_tokens)
            return [
                ASRToken(
                    start=item.start_time, end=item.end_time,
                    text=item.text if i == 0 else " " + item.text,
                    detected_language=detected,
                )
                for i, item in enumerate(result.time_stamps)
            ]

        # Fallback: estimate timestamps from word count. ``transcript`` is
        # non-empty here, so there is always something to split.
        words = transcript.split()
        duration = getattr(result, '_audio_duration', 5.0)
        step = duration / max(len(words), 1)
        return [
            ASRToken(
                start=round(i * step, 3), end=round((i + 1) * step, 3),
                text=w if i == 0 else " " + w,
                detected_language=detected,
            )
            for i, w in enumerate(words)
        ]

    def segments_end_ts(self, result) -> List[float]:
        """Return segment end times for ``result``.

        Without timestamps the single boundary is the stashed audio duration
        (5.0 if it was never stashed). Otherwise a boundary is placed after
        every item whose text ends in a ``_PUNCTUATION_ENDS`` character, and
        the last item's end time is always the final boundary.
        """
        if not result.time_stamps:
            duration = getattr(result, '_audio_duration', 5.0)
            return [duration]
        # Create segment boundaries at punctuation marks
        ends = [
            item.end_time
            for item in result.time_stamps
            if item.text and item.text.rstrip()[-1:] in _PUNCTUATION_ENDS
        ]
        last_end = result.time_stamps[-1].end_time
        if not ends or ends[-1] != last_end:
            ends.append(last_end)
        return ends

    def use_vad(self):
        """Report that this backend does not use VAD (always ``False``)."""
        return False