import logging
import re
import sys
from typing import List, Optional
import numpy as np
from whisperlivekit.local_agreement.backends import ASRBase
from whisperlivekit.timed_objects import ASRToken
# Module-level logger: model loading, the timestamp fallback in transcribe and
# filtered garbage output are all reported through it.
logger = logging.getLogger(__name__)
def _default_rope_parameters(config=None, device=None, seq_len=None, **kwargs):
    """Return default (unscaled) RoPE inverse frequencies for ``config``.

    Stand-in for the ``"default"`` rope-init function, installed by
    :func:`_patch_transformers_compat` when the installed transformers /
    qwen_asr versions do not provide one.

    Args:
        config: Model config. Reads ``head_dim`` (optional; a missing or
            ``None`` value falls back to ``hidden_size // num_attention_heads``),
            ``partial_rotary_factor`` (optional, default 1.0) and ``rope_theta``.
        device: Device for the returned tensor; ``None`` keeps torch's default.
        seq_len: Unused; accepted for signature compatibility.
        **kwargs: Ignored; accepted for signature compatibility.

    Returns:
        ``(inv_freq, attention_factor)``: a 1-D float tensor of length
        ``ceil(dim / 2)`` and the constant ``1.0``.
    """
    import torch

    # Configs may declare ``head_dim`` but leave it as None, so a falsy value
    # is treated like a missing attribute. Only touch hidden_size /
    # num_attention_heads when they are actually needed.
    head_dim = getattr(config, "head_dim", None) or (
        config.hidden_size // config.num_attention_heads
    )
    dim = int(head_dim * getattr(config, "partial_rotary_factor", 1.0))
    exponents = (
        torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float)
        / dim
    )
    inv_freq = 1.0 / (config.rope_theta ** exponents)
    return inv_freq, 1.0


def _patch_transformers_compat():
    """Patch transformers for qwen_asr 0.0.6 + transformers >= 5.3 compatibility.

    Applies independent, best-effort shims. Each one is skipped (logged at
    debug level) when the module it targets cannot be imported. It leaves an
    already-present symbol untouched, so calling this twice is harmless:

    * ``transformers.utils.generic.check_model_inputs``: a no-op decorator
      usable both bare and as a factory.
    * ``ROPE_INIT_FUNCTIONS["default"]``: :func:`_default_rope_parameters`.
    * ``Qwen3ASRThinkerConfig.pad_token_id``: ``None``.
    * ``AutoProcessor.from_pretrained``: drops a ``fix_mistral_regex`` kwarg
      (presumably passed by qwen_asr but rejected by newer transformers;
      confirm).
    * ``Qwen3ASRThinkerTextRotaryEmbedding.compute_default_rope_parameters``:
      :func:`_default_rope_parameters`.
    """
    try:
        import transformers.utils.generic as generic_utils

        if not hasattr(generic_utils, "check_model_inputs"):
            def check_model_inputs(*args, **kwargs):
                # Bare ``@check_model_inputs`` passes the function itself;
                # hand it back unchanged instead of replacing it.
                if len(args) == 1 and not kwargs and callable(args[0]):
                    return args[0]

                # Factory form ``@check_model_inputs(...)``.
                def decorator(fn):
                    return fn

                return decorator

            generic_utils.check_model_inputs = check_model_inputs
    except ImportError as exc:
        logger.debug("check_model_inputs shim skipped: %s", exc)

    try:
        from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS

        ROPE_INIT_FUNCTIONS.setdefault("default", _default_rope_parameters)
    except ImportError as exc:
        logger.debug("ROPE_INIT_FUNCTIONS['default'] shim skipped: %s", exc)

    try:
        from qwen_asr.core.transformers_backend.configuration_qwen3_asr import (
            Qwen3ASRThinkerConfig,
        )

        if not hasattr(Qwen3ASRThinkerConfig, "pad_token_id"):
            Qwen3ASRThinkerConfig.pad_token_id = None
    except ImportError as exc:
        logger.debug("Qwen3ASRThinkerConfig.pad_token_id shim skipped: %s", exc)

    try:
        from transformers.models.auto import processing_auto

        auto_processor = processing_auto.AutoProcessor
        # Guard against wrapping the wrapper again on a repeated call.
        if not getattr(auto_processor.from_pretrained, "_wlk_drops_fix_mistral_regex", False):
            original_from_pretrained = auto_processor.from_pretrained.__func__

            def _from_pretrained(cls, *args, **kwargs):
                kwargs.pop("fix_mistral_regex", None)
                return original_from_pretrained(cls, *args, **kwargs)

            _from_pretrained._wlk_drops_fix_mistral_regex = True
            auto_processor.from_pretrained = classmethod(_from_pretrained)
    except Exception as exc:  # best effort: must never break module import
        logger.debug("AutoProcessor.from_pretrained shim skipped: %s", exc)

    try:
        from qwen_asr.core.transformers_backend.modeling_qwen3_asr import (
            Qwen3ASRThinkerTextRotaryEmbedding,
        )

        if not hasattr(Qwen3ASRThinkerTextRotaryEmbedding, "compute_default_rope_parameters"):
            Qwen3ASRThinkerTextRotaryEmbedding.compute_default_rope_parameters = staticmethod(
                _default_rope_parameters
            )
    except ImportError as exc:
        logger.debug("compute_default_rope_parameters shim skipped: %s", exc)
# Apply the transformers/qwen_asr compatibility shims once, at module import,
# so they are in place before Qwen3ASR.load_model imports qwen_asr.
_patch_transformers_compat()
# Whisper-style language codes -> the language names passed to
# Qwen3ASRModel.transcribe(language=...). Codes missing here are sent as
# ``None`` (auto-detect) by Qwen3ASR._qwen3_language.
WHISPER_TO_QWEN3_LANGUAGE = dict([
    ("zh", "Chinese"),
    ("en", "English"),
    ("yue", "Cantonese"),
    ("ar", "Arabic"),
    ("de", "German"),
    ("fr", "French"),
    ("es", "Spanish"),
    ("pt", "Portuguese"),
    ("id", "Indonesian"),
    ("it", "Italian"),
    ("ko", "Korean"),
    ("ru", "Russian"),
    ("th", "Thai"),
    ("vi", "Vietnamese"),
    ("ja", "Japanese"),
    ("tr", "Turkish"),
    ("hi", "Hindi"),
    ("ms", "Malay"),
    ("nl", "Dutch"),
    ("sv", "Swedish"),
    ("da", "Danish"),
    ("fi", "Finnish"),
    ("pl", "Polish"),
    ("cs", "Czech"),
    ("fa", "Persian"),
    ("el", "Greek"),
    ("hu", "Hungarian"),
    ("mk", "Macedonian"),
    ("ro", "Romanian"),
])
# Inverse mapping, used to turn a detected Qwen3 language name back into a code.
QWEN3_TO_WHISPER_LANGUAGE = {name: code for code, name in WHISPER_TO_QWEN3_LANGUAGE.items()}
# Model-size aliases (Qwen3 names and Whisper size names) -> Hugging Face repo
# ids. Lookups in Qwen3ASR.load_model are lower-cased; unknown names are used
# verbatim as model ids.
QWEN3_MODEL_MAPPING = {
    **dict.fromkeys(
        ("qwen3-asr-1.7b", "qwen3-1.7b", "large", "large-v3", "medium"),
        "Qwen/Qwen3-ASR-1.7B",
    ),
    **dict.fromkeys(
        ("qwen3-asr-0.6b", "qwen3-0.6b", "base", "small", "tiny"),
        "Qwen/Qwen3-ASR-0.6B",
    ),
}
# Characters that, at the end of an aligned item's text, mark a sentence
# boundary in Qwen3ASR.segments_end_ts. Covers ASCII . ! ? ; plus the CJK
# forms: U+3002 ideographic full stop and the full-width ! ? ;
# (U+FF01, U+FF1F, U+FF1B). Escapes keep the full-width forms from being
# confused with their ASCII look-alikes.
_PUNCTUATION_ENDS = frozenset(".!?;\u3002\uff01\uff1f\uff1b")
# A transcript consisting only of a bare language tag such as
# "language English" is treated as garbage by Qwen3ASR.ts_words.
_GARBAGE_RE = re.compile(r"^language\s+\S+$", re.IGNORECASE)
class Qwen3ASR(ASRBase):
    """Qwen3-ASR backend with ForcedAligner word-level timestamps.

    Implements the ``ASRBase`` interface on top of ``qwen_asr.Qwen3ASRModel``.
    :meth:`transcribe` returns the model's first result object, annotated
    with the chunk duration. :meth:`ts_words` and :meth:`segments_end_ts`
    turn that object into timed tokens and segment end times.
    """

    # ts_words puts any needed leading space into each token's own text,
    # so the token separator is empty.
    sep = ""
    SAMPLING_RATE = 16000

    def __init__(self, lan="auto", model_size=None, cache_dir=None,
                 model_dir=None, logfile=sys.stderr, **kwargs):
        """Create the backend and load the Qwen3-ASR model.

        Args:
            lan: Whisper-style language code, or ``"auto"`` to let the model
                detect the language.
            model_size: Alias from ``QWEN3_MODEL_MAPPING`` or a model id.
            cache_dir: Forwarded to :meth:`load_model`, which does not use it.
            model_dir: Local model path/id; takes precedence over ``model_size``.
            logfile: Stored as ``self.logfile``.
            **kwargs: Accepted and ignored.
        """
        self.logfile = logfile
        self.transcribe_kargs = {}
        self.original_language = None if lan == "auto" else lan
        self.model = self.load_model(model_size, cache_dir, model_dir)

    def load_model(self, model_size=None, cache_dir=None, model_dir=None):
        """Load Qwen3-ASR together with the Qwen3 ForcedAligner.

        The model id is ``model_dir`` if given. Otherwise it is
        ``model_size`` resolved through ``QWEN3_MODEL_MAPPING``
        (case-insensitive; unknown values are used verbatim). With neither,
        it is ``"Qwen/Qwen3-ASR-1.7B"``. Uses bfloat16 on CUDA when
        available, otherwise float32 on MPS or CPU.

        NOTE(review): ``cache_dir`` is accepted but not forwarded to
        ``from_pretrained``; confirm whether that is intentional.
        """
        import torch
        from qwen_asr import Qwen3ASRModel

        if model_dir:
            model_id = model_dir
        elif model_size:
            model_id = QWEN3_MODEL_MAPPING.get(model_size.lower(), model_size)
        else:
            model_id = "Qwen/Qwen3-ASR-1.7B"

        if torch.cuda.is_available():
            dtype, device = torch.bfloat16, "cuda:0"
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            dtype, device = torch.float32, "mps"
        else:
            dtype, device = torch.float32, "cpu"

        logger.info("Loading Qwen3-ASR: %s (%s, %s)", model_id, dtype, device)
        model = Qwen3ASRModel.from_pretrained(
            model_id,
            forced_aligner="Qwen/Qwen3-ForcedAligner-0.6B",
            forced_aligner_kwargs=dict(dtype=dtype, device_map=device),
            dtype=dtype,
            device_map=device,
        )
        logger.info("Qwen3-ASR loaded with ForcedAligner")
        return model

    def _qwen3_language(self) -> Optional[str]:
        """Map the configured Whisper language code to a Qwen3 language name.

        Returns ``None`` for auto-detection, or for codes missing from
        ``WHISPER_TO_QWEN3_LANGUAGE``.
        """
        if self.original_language is None:
            return None
        return WHISPER_TO_QWEN3_LANGUAGE.get(self.original_language)

    def transcribe(self, audio: np.ndarray, init_prompt: str = ""):
        """Transcribe one audio buffer.

        Args:
            audio: Samples at ``SAMPLING_RATE`` Hz (presumably a 1-D array;
                its length is used as the sample count).
            init_prompt: Optional context text passed to the model.

        Returns:
            The first Qwen3 result, with ``_audio_duration`` (seconds)
            attached for the timestamp-less fallbacks in ts_words and
            segments_end_ts.

        If the timestamped call raises, it is retried once without
        timestamps. An exception from that retry propagates.
        """
        sample_rate = self.SAMPLING_RATE
        request = dict(
            audio=(audio, sample_rate),
            language=self._qwen3_language(),
            context=init_prompt or "",
        )
        try:
            results = self.model.transcribe(return_time_stamps=True, **request)
        except Exception:
            logger.warning("Qwen3 timestamp alignment failed, falling back to no timestamps", exc_info=True)
            results = self.model.transcribe(return_time_stamps=False, **request)
        result = results[0]
        result._audio_duration = len(audio) / sample_rate
        logger.info(
            "Qwen3 result: language=%r text=%r ts=%s",
            result.language, result.text[:80] if result.text else "",
            bool(result.time_stamps),
        )
        return result

    @staticmethod
    def _detected_language(result) -> Optional[str]:
        """Extract Whisper-style language code from Qwen3 result.

        Uses the first entry of a comma-separated ``result.language``.
        Returns ``None`` when the field is missing, empty, or "none". Names
        found in ``QWEN3_TO_WHISPER_LANGUAGE`` (case-insensitively) map to
        their code; other names are returned lower-cased.
        """
        lang = getattr(result, 'language', None)
        if not lang or lang.lower() == "none":
            return None
        first = lang.split(",")[0].strip()
        if not first or first.lower() == "none":
            return None
        code = QWEN3_TO_WHISPER_LANGUAGE.get(first)
        if code is None:
            # Tolerate case variants such as "chinese" or "CHINESE".
            code = QWEN3_TO_WHISPER_LANGUAGE.get(first.capitalize(), first.lower())
        return code

    @staticmethod
    def _token_texts(full_text, pieces):
        """Prefix aligned pieces with a space only where the transcript has one.

        If the aligner emits one item per character for scripts written
        without spaces (assumed for e.g. Chinese, Japanese, Thai; confirm),
        unconditionally inserting a space between items would corrupt the
        text.

        Each piece is searched for in ``full_text`` from left to right. A
        piece after the first gets a leading space when the text skipped
        since the previous match contains whitespace. A piece that cannot be
        located verbatim, or that follows one that could not, falls back to
        the plain rule: every piece except the first is space-prefixed.
        """
        texts = []
        cursor = 0
        prev_located = True
        for index, piece in enumerate(pieces):
            pos = full_text.find(piece, cursor) if piece else -1
            if pos < 0:
                texts.append(piece if index == 0 else " " + piece)
                prev_located = False
                continue
            gap = full_text[cursor:pos]
            needs_space = index > 0 and (
                not prev_located or any(ch.isspace() for ch in gap)
            )
            texts.append(" " + piece if needs_space else piece)
            cursor = pos + len(piece)
            prev_located = True
        return texts

    def ts_words(self, result) -> List[ASRToken]:
        """Convert a Qwen3 result into timed tokens.

        Empty transcripts, and transcripts that are only a language tag
        (``_GARBAGE_RE``), yield no tokens. With aligner timestamps, each
        aligned item becomes one token, spaced according to the transcript
        (see :meth:`_token_texts`). Without timestamps, whitespace-separated
        words are spread evenly over ``_audio_duration`` (default 5.0).
        """
        text = (result.text or "").strip()
        if not text or _GARBAGE_RE.match(text):
            if text:
                logger.info("Filtered garbage Qwen3 output: %r", text)
            return []

        detected = self._detected_language(result)

        if result.time_stamps:
            items = list(result.time_stamps)
            pieces = self._token_texts(text, [item.text for item in items])
            return [
                ASRToken(
                    start=item.start_time, end=item.end_time, text=piece,
                    detected_language=detected,
                )
                for item, piece in zip(items, pieces)
            ]

        # No alignment available: distribute the words uniformly in time.
        words = text.split()
        duration = getattr(result, '_audio_duration', 5.0)
        step = duration / max(len(words), 1)
        return [
            ASRToken(
                start=round(i * step, 3), end=round((i + 1) * step, 3),
                text=w if i == 0 else " " + w,
                detected_language=detected,
            )
            for i, w in enumerate(words)
        ]

    def segments_end_ts(self, result) -> List[float]:
        """Return end times of sentence-like segments.

        With aligner timestamps, this is the end time of every item whose
        text ends in a ``_PUNCTUATION_ENDS`` character, plus the last item's
        end time if it is not already the final entry. Without timestamps,
        it is ``[_audio_duration]`` (default 5.0).
        """
        if not result.time_stamps:
            duration = getattr(result, '_audio_duration', 5.0)
            return [duration]
        ends = [
            item.end_time
            for item in result.time_stamps
            if item.text and item.text.rstrip()[-1:] in _PUNCTUATION_ENDS
        ]
        last_end = result.time_stamps[-1].end_time
        if not ends or ends[-1] != last_end:
            ends.append(last_end)
        return ends

    def use_vad(self):
        """Report whether VAD should be used with this backend (always False)."""
        return False