"""
Pure-MLX Voxtral Realtime ASR backend for WhisperLiveKit.

Provides ``VoxtralMLXASR`` (model holder) and ``VoxtralMLXOnlineProcessor``
(streaming processor) that plug into WhisperLiveKit's audio processing
pipeline via ``insert_audio_chunk`` / ``process_iter`` / ``get_buffer`` etc.

Unlike the HuggingFace backend, this runs the full inference loop in-process
(no dedicated background worker thread or queue) — MLX operations on Apple
Silicon are fast enough to run synchronously within each
``asyncio.to_thread(process_iter)`` call.
"""

import logging
import sys
import time
from typing import List, Optional, Tuple

import mlx.core as mx
import numpy as np
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy

from whisperlivekit.timed_objects import ASRToken, Transcript
from whisperlivekit.voxtral_mlx.loader import DEFAULT_MODEL_ID, load_voxtral_model
from whisperlivekit.voxtral_mlx.model import SlidingKVCache
from whisperlivekit.voxtral_mlx.spectrogram import (
    LEFT_PAD_TOKENS,
    RIGHT_PAD_TOKENS,
    SAMPLES_PER_TOKEN,
    compute_mel_streaming,
)

logger = logging.getLogger(__name__)

# Size of the sliding KV-cache window used by every decoder layer
# (the comment in the original source says it matches the model's training
# configuration — TODO confirm against the model config).
_DECODER_WINDOW = 8_192


def _prompt_tokens(tokenizer, n_left_pad=LEFT_PAD_TOKENS, n_delay=6):
    """Return ``(prompt_ids, n_delay)`` for a streaming decode.

    The prompt is one BOS token followed by ``n_left_pad + n_delay`` copies of
    the ``[STREAMING_PAD]`` special token.  *n_delay* is handed back unchanged
    so callers can condition the decoder on the same delay value.
    """
    streaming_pad = tokenizer.get_special_token("[STREAMING_PAD]")
    n_pad = n_left_pad + n_delay
    prompt = [tokenizer.bos_id, *([streaming_pad] * n_pad)]
    return prompt, n_delay


# ---------------------------------------------------------------------------
# Model holder
# ---------------------------------------------------------------------------


class VoxtralMLXASR:
    """Model holder for the MLX Voxtral backend.

    Loads the model, tokenizer and config once in ``__init__`` and keeps them
    on the instance for the lifetime of the server; the actual streaming work
    is done by ``VoxtralMLXOnlineProcessor``.
    """

    sep = " "
    SAMPLING_RATE = 16_000

    def __init__(self, logfile=sys.stderr, **kwargs):
        self.logfile = logfile
        self.transcribe_kargs = {}

        # "auto" means "let the model detect the language".
        language = kwargs.get("lan", "auto")
        self.original_language = language if language != "auto" else None

        # Resolution order: explicit model_dir / model_path, then a
        # model_size that looks like a path or repo id, then the default.
        model_path = kwargs.get("model_dir") or kwargs.get("model_path")
        if not model_path:
            size_hint = kwargs.get("model_size", "")
            looks_like_path = bool(size_hint) and (
                "/" in size_hint or size_hint.startswith(".")
            )
            model_path = size_hint if looks_like_path else DEFAULT_MODEL_ID

        started = time.time()
        logger.info("Loading Voxtral MLX model '%s' ...", model_path)
        self.model, self.tokenizer, self.config = load_voxtral_model(model_path)
        logger.info("Voxtral MLX model loaded in %.2fs", time.time() - started)

        self.backend_choice = "voxtral-mlx"

    def transcribe(self, audio):
        """No-op: all inference happens in the online processor."""
        return None


# ---------------------------------------------------------------------------
# Online processor
# ---------------------------------------------------------------------------


class VoxtralMLXOnlineProcessor:
    """Streaming processor that incrementally encodes audio and decodes text
    using the MLX Voxtral model.

    Lifecycle (called by ``AudioProcessor.transcription_processor``):

        insert_audio_chunk(pcm, time)  →  process_iter()  →  get_buffer()
                      ... repeat ...
        start_silence() / end_silence()
        finish()

    Word timestamps are derived from decoder positions (see
    ``_audio_pos_to_time``).  ``_time_offset`` shifts them by accumulated
    silence and, after an end-of-sequence token restarts the decoder, by the
    audio already consumed, so timestamps keep increasing across utterances.
    """

    SAMPLING_RATE = 16_000

    def __init__(self, asr: VoxtralMLXASR, logfile=sys.stderr):
        self.asr = asr
        self.logfile = logfile
        self.end = 0.0
        self.buffer: list = []
        self.audio_buffer = np.array([], dtype=np.float32)

        self._model = asr.model
        self._tokenizer = asr.tokenizer

        # Pre-compute prompt tokens and delay conditioning (constant across utterances).
        self._prompt_ids, self._n_delay = _prompt_tokens(self._tokenizer)
        self._prefix_len = len(self._prompt_ids)

        self._delay_cond = self._model.delay_embedding(
            mx.array([self._n_delay], dtype=mx.float32)
        )
        mx.eval(self._delay_cond)

        self._prompt_embeds = self._model.decoder.embed(
            mx.array([self._prompt_ids])
        )[0]  # [prefix_len, dim]
        mx.eval(self._prompt_embeds)

        self._eos_id = self._tokenizer.eos_id
        self._secs_per_token = SAMPLES_PER_TOKEN / self.SAMPLING_RATE
        # The streaming model has an inherent delay: text for audio at position P
        # is generated at decoder position P + n_delay. Compensate timestamps.
        self._delay_secs = self._n_delay * self._secs_per_token

        # Whether this processor has emitted any word yet.  Deliberately not
        # part of _reset_state(): after an EOS restart the first word of the
        # next utterance still needs a separating leading space.
        self._emitted_any_word = False

        self._reset_state()

    # -- state management --

    def _reset_state(self):
        """Reset all incremental state for a fresh utterance."""
        # Audio accumulation (list of chunks, concatenated on demand)
        self._pending_chunks: list[np.ndarray] = []
        self._pending_len = 0
        # Mel overlap
        self._mel_overlap: np.ndarray | None = None
        # Encoder incremental state
        self._conv_tail1 = None
        self._conv_tail2 = None
        self._enc_cache = None
        self._ds_remainder = None
        # Audio embeddings not yet decoded
        self._audio_embeds: mx.array | None = None
        # Decoder state
        self._dec_cache: list[SlidingKVCache] | None = None
        self._last_token: mx.array | None = None
        # True when _last_token's text was already appended to _full_text by a
        # flush; the next decode step must then not append it a second time.
        self._last_token_emitted = False
        # Bookkeeping
        self._samples_encoded = 0
        self._positions_decoded = 0
        self._prefilled = False
        self._first_chunk = True
        # Text state
        self._full_text = ""
        self._n_text_tokens = 0
        self._n_committed_words = 0
        self._time_offset = 0.0
        # Per-word audio position tracking: decoder position (relative to prefix)
        # where each word in _full_text started and ended
        self._word_audio_starts: list[int] = []   # audio pos where word i started
        self._word_audio_ends: list[int] = []     # audio pos where word i last produced a token
        self._current_word_pos: Optional[int] = None  # audio pos of current (incomplete) word's first token

    def _restart_after_eos(self):
        """Start a fresh utterance after the decoder emitted EOS.

        Decoder positions restart at zero after the reset, so the duration of
        audio consumed by the finished utterance (``_samples_encoded``, which
        is what its positions covered) is folded into ``_time_offset``.
        Otherwise the next utterance's words would be stamped from the start
        of the timeline again.  Raw audio not yet encoded is carried over.
        Embeddings already encoded past the EOS point belong to the old encoder
        state and cannot be reused, so they are dropped.
        """
        new_offset = self._time_offset + self._samples_encoded / self.SAMPLING_RATE
        pending_chunks, pending_len = self._pending_chunks, self._pending_len
        self._reset_state()
        self._time_offset = new_offset
        self._pending_chunks, self._pending_len = pending_chunks, pending_len

    # -- audio ingestion --

    def _get_pending(self) -> np.ndarray:
        """Flatten pending chunks into a single array."""
        if not self._pending_chunks:
            return np.zeros(0, dtype=np.float32)
        if len(self._pending_chunks) == 1:
            return self._pending_chunks[0]
        flat = np.concatenate(self._pending_chunks)
        self._pending_chunks = [flat]
        return flat

    def _set_pending(self, arr: np.ndarray):
        """Replace pending audio with a single array."""
        if len(arr) == 0:
            self._pending_chunks = []
            self._pending_len = 0
        else:
            self._pending_chunks = [arr]
            self._pending_len = len(arr)

    def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
        """Queue *audio* for encoding and record the stream end time."""
        self.end = audio_stream_end_time
        self._pending_chunks.append(audio)
        self._pending_len += len(audio)
        self.audio_buffer = audio  # diagnostic only

    # -- core processing --

    def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
        """Run one encode/decode step; errors are logged, not raised."""
        try:
            return self._step(is_last)
        except Exception as e:
            logger.warning("[voxtral-mlx] process_iter error: %s", e, exc_info=True)
            return [], self.end

    def _step(self, is_last: bool) -> Tuple[List[ASRToken], float]:
        """Encode pending audio, decode what is safe, return committed words."""
        # 1. Encode any new audio
        self._encode_pending()

        if self._audio_embeds is None:
            return [], self.end

        # 2. Compute how many positions we can safely decode: never past the
        #    audio actually encoded (plus the left padding).
        total_safe = LEFT_PAD_TOKENS + self._samples_encoded // SAMPLES_PER_TOKEN
        n_available = self._audio_embeds.shape[0]
        n_decodable = min(n_available, total_safe - self._positions_decoded)

        if n_decodable <= 0:
            return [], self.end

        # 3. Prefill if needed
        if not self._prefilled:
            if self._positions_decoded + n_available < self._prefix_len:
                return [], self.end
            self._do_prefill()
            # Re-check after consuming prefix embeddings
            n_available = self._audio_embeds.shape[0] if self._audio_embeds is not None else 0
            n_decodable = min(n_available, total_safe - self._positions_decoded)

        if n_decodable <= 0 or self._audio_embeds is None:
            return [], self.end

        # 4. Decode available positions
        hit_eos = self._decode_positions(n_decodable)

        if hit_eos:
            # Flush words, then start a fresh utterance on the same timeline.
            words = self._flush_all_words()
            logger.debug(
                "[voxtral-mlx] EOS hit during stream: flushed %d words, "
                "samples_encoded=%d (%.2fs), text='%s'",
                len(words), self._samples_encoded,
                self._samples_encoded / self.SAMPLING_RATE,
                self._full_text[-60:] if self._full_text else "",
            )
            self._restart_after_eos()
            return words, self.end

        # 5. Extract committed words (all but the last, which may still grow)
        return self._extract_committed_words(), self.end

    def _encode_pending(self):
        """Feed whole tokens' worth of pending audio through the incremental encoder.

        Audio is consumed in multiples of ``SAMPLES_PER_TOKEN``; any remainder
        stays pending.  The very first chunk of an utterance is prefixed with
        ``LEFT_PAD_TOKENS`` of silence (not counted in ``_samples_encoded``).
        """
        if self._pending_len < SAMPLES_PER_TOKEN:
            return

        pending = self._get_pending()
        n_take = (len(pending) // SAMPLES_PER_TOKEN) * SAMPLES_PER_TOKEN
        chunk = pending[:n_take]
        self._set_pending(pending[n_take:])
        self._samples_encoded += n_take

        if self._first_chunk:
            left_pad = np.zeros(LEFT_PAD_TOKENS * SAMPLES_PER_TOKEN, dtype=np.float32)
            chunk = np.concatenate([left_pad, chunk])
            self._first_chunk = False

        mel, self._mel_overlap = compute_mel_streaming(chunk, self._mel_overlap)

        embeds, self._conv_tail1, self._conv_tail2, self._enc_cache, self._ds_remainder = (
            self._model.encode_incremental(
                mel, self._conv_tail1, self._conv_tail2, self._enc_cache, self._ds_remainder
            )
        )

        if embeds is not None:
            mx.eval(embeds)
            if self._audio_embeds is not None:
                self._audio_embeds = mx.concatenate([self._audio_embeds, embeds])
                mx.eval(self._audio_embeds)
            else:
                self._audio_embeds = embeds

    def _do_prefill(self):
        """Run the decoder prefill pass over the prompt + first audio embeddings.

        Requires at least ``_prefix_len`` rows in ``_audio_embeds``.
        """
        n_dec_layers = len(self._model.decoder.blocks)
        self._dec_cache = [SlidingKVCache(_DECODER_WINDOW) for _ in range(n_dec_layers)]

        prefix_embeds = self._prompt_embeds + self._audio_embeds[: self._prefix_len]
        prefix_embeds = prefix_embeds[None, :, :]  # [1, prefix_len, dim]

        logits = self._model.decode(prefix_embeds, self._delay_cond, "causal", self._dec_cache)
        mx.eval(logits, *[x for c in self._dec_cache for x in (c.keys, c.values)])

        self._last_token = self._sample(logits)
        self._last_token_emitted = False
        mx.async_eval(self._last_token)

        # Remove consumed prefix embeddings
        self._audio_embeds = self._audio_embeds[self._prefix_len :]
        if self._audio_embeds.shape[0] == 0:
            self._audio_embeds = None
        self._positions_decoded = self._prefix_len
        self._prefilled = True

    def _register_text(self, text: str, audio_pos: int):
        """Append decoded *text* and update per-word start/end positions.

        A token starting with whitespace (or the very first text) opens a new
        word and closes the previous one at *audio_pos*.
        """
        if text.lstrip() != text or not self._full_text:
            if self._current_word_pos is not None:
                self._word_audio_ends.append(audio_pos)
            self._word_audio_starts.append(audio_pos)
            self._current_word_pos = audio_pos
        elif self._current_word_pos is None:
            # First token of a word that has no leading space
            self._word_audio_starts.append(audio_pos)
            self._current_word_pos = audio_pos
        self._full_text += text
        self._n_text_tokens += 1

    def _close_current_word(self, audio_pos: int):
        """Record *audio_pos* as the end of the word being built, if any."""
        if self._current_word_pos is not None:
            self._word_audio_ends.append(audio_pos)
            self._current_word_pos = None

    def _decode_positions(self, n: int) -> bool:
        """Autoregressively decode *n* positions.  Returns True on EOS.

        ``_last_token`` is always the token sampled at the previous position
        whose text has not been registered yet; each step feeds it back in and
        registers its text.
        """
        base_pos = self._positions_decoded  # absolute position before this batch
        skip_first_text = self._last_token_emitted
        for i in range(n):
            tok_embed = self._model.decoder.embed(self._last_token.reshape(1, 1))[0, 0]
            combined = (self._audio_embeds[i] + tok_embed)[None, None, :]
            logits = self._model.decode(combined, self._delay_cond, mask=None, cache=self._dec_cache)
            next_tok = self._sample(logits)
            mx.async_eval(next_tok)

            token_id = self._last_token.item()
            audio_pos = base_pos + i - self._prefix_len
            if token_id == self._eos_id:
                self._close_current_word(audio_pos)
                self._trim_embeds(i)
                self._positions_decoded += i
                return True

            # A flush (start_silence) may already have appended this token's text.
            if not (i == 0 and skip_first_text):
                text = self._tokenizer.decode(
                    [token_id], special_token_policy=SpecialTokenPolicy.IGNORE
                )
                if text:
                    self._register_text(text, audio_pos)

            if i > 0 and i % 256 == 0:
                mx.clear_cache()

            self._last_token = next_tok
            self._last_token_emitted = False

        self._positions_decoded += n
        self._trim_embeds(n)
        return False

    def _trim_embeds(self, n_consumed: int):
        """Drop the first *n_consumed* audio embeddings (None when exhausted)."""
        if self._audio_embeds is not None and self._audio_embeds.shape[0] > n_consumed:
            self._audio_embeds = self._audio_embeds[n_consumed:]
        else:
            self._audio_embeds = None

    def _sample(self, logits: mx.array) -> mx.array:
        """Greedy (argmax) sample of the last position's logits."""
        return mx.argmax(logits[0, -1:], axis=-1).squeeze()

    # -- word extraction --

    def _audio_pos_to_time(self, pos: int) -> float:
        """Convert an audio position (relative to prefix end) to seconds."""
        return max(0.0, pos * self._secs_per_token - self._delay_secs + self._time_offset)

    def _word_time_range(self, word_idx: int, n_words: int) -> Tuple[float, float]:
        """Compute (start, end) time for a word using tracked word positions."""
        starts = self._word_audio_starts
        ends = self._word_audio_ends

        if not starts:
            return self._time_offset, self._time_offset

        # Get start position for this word
        if word_idx < len(starts):
            t0 = self._audio_pos_to_time(starts[word_idx])
        else:
            # Fallback: estimate from last known position
            last_pos = ends[-1] if ends else starts[-1]
            t0 = self._audio_pos_to_time(last_pos + 1)

        # Get end position: use the start of the next word, or the end of this word
        if word_idx + 1 < len(starts):
            t1 = self._audio_pos_to_time(starts[word_idx + 1])
        elif word_idx < len(ends):
            t1 = self._audio_pos_to_time(ends[word_idx] + 1)
        else:
            # Last word, still being built: use last known position + 1 token
            last_pos = starts[word_idx] if word_idx < len(starts) else (ends[-1] if ends else 0)
            t1 = self._audio_pos_to_time(last_pos + 1)

        return t0, t1

    def _emit_words(self, keep_last: bool) -> List[ASRToken]:
        """Turn not-yet-committed words of ``_full_text`` into ASRTokens.

        Args:
            keep_last: Hold back the final word because it may still grow.

        Every word except the first one ever emitted gets a leading space, so
        consecutive utterances do not glue together.
        """
        if not self._full_text:
            return []
        words = self._full_text.split()
        n_total = max(len(words), 1)
        stop = len(words) - 1 if keep_last else len(words)
        tokens: List[ASRToken] = []

        while self._n_committed_words < stop:
            idx = self._n_committed_words
            t0, t1 = self._word_time_range(idx, n_total)
            label = " " + words[idx] if self._emitted_any_word else words[idx]
            tokens.append(ASRToken(start=t0, end=t1, text=label))
            self._emitted_any_word = True
            self._n_committed_words += 1

        return tokens

    def _extract_committed_words(self) -> List[ASRToken]:
        """Return complete words (all except the last which may still grow)."""
        return self._emit_words(keep_last=True)

    def _flush_all_words(self) -> List[ASRToken]:
        """Flush every word including the last partial one."""
        return self._emit_words(keep_last=False)

    def _flush_last_token(self):
        """Register the text of the pending ``_last_token`` (unless EOS).

        The token stays the decoder's next input, so it is flagged to keep
        ``_decode_positions`` from registering its text a second time.
        """
        if self._last_token is None or self._last_token_emitted:
            return
        tid = self._last_token.item()
        if tid == self._eos_id:
            return
        text = self._tokenizer.decode(
            [tid], special_token_policy=SpecialTokenPolicy.IGNORE
        )
        if text:
            self._register_text(text, self._positions_decoded - self._prefix_len)
        self._last_token_emitted = True

    def _flush_remaining(self, tag: str) -> List[ASRToken]:
        """Pad, encode and decode everything buffered, then flush every word.

        Shared by ``start_silence`` and ``finish``.  Adds alignment plus
        ``RIGHT_PAD_TOKENS`` of silence so the decoder has the future context
        it needs to emit the final words, prefills the decoder if streaming
        never got that far (short utterances), decodes all remaining
        positions, and closes the open word.  If EOS is emitted, the utterance
        state is restarted so later audio is not fed into a finished sequence.

        Args:
            tag: Label used in debug log messages.

        Returns:
            All words not yet committed, including the last partial one.
        """
        if self._pending_len == 0 and self._samples_encoded == 0:
            # Nothing received since the last (re)start.
            return []

        # Align pending audio to a SAMPLES_PER_TOKEN boundary so nothing is lost
        remainder = self._pending_len % SAMPLES_PER_TOKEN
        align_pad = SAMPLES_PER_TOKEN - remainder if remainder else 0

        # Add alignment + right-padding silence to provide future context
        total_pad = align_pad + RIGHT_PAD_TOKENS * SAMPLES_PER_TOKEN
        if total_pad > 0:
            self._pending_chunks.append(np.zeros(total_pad, dtype=np.float32))
            self._pending_len += total_pad

        # Encode remaining audio (including right-padding)
        self._encode_pending()

        logger.debug(
            "[voxtral-mlx] %s after encode: audio_embeds=%s, pending=%d",
            tag,
            self._audio_embeds.shape if self._audio_embeds is not None else None,
            self._pending_len,
        )

        # A short utterance may end before streaming reached the prefill
        # threshold; the right padding usually makes prefill possible now.
        if (
            not self._prefilled
            and self._audio_embeds is not None
            and self._audio_embeds.shape[0] >= self._prefix_len
        ):
            self._do_prefill()

        hit_eos = False

        # Decode everything that's left (including right-padding)
        if self._audio_embeds is not None and self._prefilled:
            hit_eos = self._decode_positions(self._audio_embeds.shape[0])
            logger.debug(
                "[voxtral-mlx] %s decode: hit_eos=%s, text='%s'",
                tag, hit_eos, self._full_text[-80:] if self._full_text else "",
            )

        # The last sampled token has not been turned into text yet
        self._flush_last_token()

        # Close the last word if still open
        self._close_current_word(self._positions_decoded - self._prefix_len)

        words = self._flush_all_words()
        if hit_eos:
            self._restart_after_eos()
        return words

    # -- interface methods --

    def get_buffer(self) -> Transcript:
        """Return the not-yet-committed words as a provisional transcript."""
        if not self._full_text:
            return Transcript(start=None, end=None, text="")
        words = self._full_text.split()
        remaining = words[self._n_committed_words :]
        if remaining:
            return Transcript(start=self.end, end=self.end, text=" ".join(remaining))
        return Transcript(start=None, end=None, text="")

    def start_silence(self) -> Tuple[List[ASRToken], float]:
        """Flush all pending words when silence starts.

        Adds right-padding silence and forces a full decode pass so the
        decoder emits tokens for the last words of speech. Without this,
        the model holds back the final tokens waiting for future context.
        """
        words = self._flush_remaining("start_silence")
        logger.info("[voxtral-mlx] start_silence: flushed %d words", len(words))
        return words, self.end

    def end_silence(self, silence_duration: float, offset: float):
        """Shift subsequent timestamps by the silence duration."""
        self._time_offset += silence_duration
        self.end += silence_duration

    def new_speaker(self, change_speaker):
        """Flush pending words on a speaker change.

        Returns the ``(words, end)`` tuple from ``start_silence``.  The original
        dropped it, so the flushed words were marked committed but never
        delivered.  Callers that ignore the return value are unaffected.
        """
        return self.start_silence()

    def warmup(self, audio, init_prompt=""):
        pass

    def finish(self) -> Tuple[List[ASRToken], float]:
        """Flush everything at end of stream and return all remaining words."""
        logger.debug(
            "[voxtral-mlx] finish: pending=%d samples, audio_embeds=%s, "
            "samples_encoded=%d, positions_decoded=%d, prefilled=%s, text so far='%s'",
            self._pending_len,
            self._audio_embeds.shape if self._audio_embeds is not None else None,
            self._samples_encoded,
            self._positions_decoded,
            self._prefilled,
            self._full_text[-80:] if self._full_text else "",
        )

        words = self._flush_remaining("finish")
        logger.info("[voxtral-mlx] finish: flushed %d words", len(words))
        return words, self.end