"""Shared Canary-1B inference/evaluation helpers.
This module is intentionally scoped to model execution helpers used by
``infer.py`` and ``eval_canary.py``. It imports the required NeMo / torch / NPU
runtime dependencies at module import time so missing execution dependencies fail
promptly instead of silently falling back.
"""
from __future__ import annotations
from pathlib import Path
from typing import Any
import torch
import torch_npu
from nemo.collections.asr.models import EncDecMultiTaskModel
def resolve_device(device_name: str) -> torch.device:
    """Return a torch.device without hard-coding card indices.

    Args:
        device_name: Bare backend name; one of ``"npu"``, ``"cpu"`` or ``"cuda"``.

    Returns:
        ``torch.device(device_name)`` with no explicit device index.

    Raises:
        ValueError: If ``device_name`` is not one of the supported names.
        RuntimeError: If an accelerator backend is requested but its runtime
            reports it as unavailable.
    """
    if device_name not in {"npu", "cpu", "cuda"}:
        raise ValueError("--device must be one of: npu, cpu, cuda")
    if device_name == "npu" and not torch_npu.npu.is_available():
        raise RuntimeError("NPU device requested but torch_npu reports NPU unavailable")
    # Mirror the NPU check so an unusable CUDA request fails here rather than
    # later, in the middle of model loading.
    if device_name == "cuda" and not torch.cuda.is_available():
        raise RuntimeError("CUDA device requested but torch reports CUDA unavailable")
    return torch.device(device_name)
def resolve_compute_dtype(
    dtype_name: str,
    device: torch.device | None = None,
    auto_bf16: bool = False,
) -> torch.dtype | None:
    """Map a ``--compute_dtype`` choice to a torch dtype.

    Args:
        dtype_name: One of ``"auto"``, ``"float32"``, ``"float16"``, ``"bfloat16"``.
        device: Target device; only consulted for ``"auto"`` with ``auto_bf16=True``.
        auto_bf16: When true, ``"auto"`` selects bfloat16 on ``npu``/``cuda`` devices.

    Returns:
        The explicit dtype for a named choice. For ``"auto"``: ``torch.bfloat16``
        when ``auto_bf16`` is set and the device type is ``npu`` or ``cuda``,
        otherwise ``None`` (the caller leaves the model dtype untouched).

    Raises:
        ValueError: For an unknown ``dtype_name``, or when ``"auto"`` is used
            with ``auto_bf16=True`` but no ``device`` was supplied.
    """
    named_dtypes = (
        ("float32", torch.float32),
        ("float16", torch.float16),
        ("bfloat16", torch.bfloat16),
    )
    for label, torch_dtype in named_dtypes:
        if dtype_name == label:
            return torch_dtype

    if dtype_name != "auto":
        raise ValueError("--compute_dtype must be one of: auto, float32, float16, bfloat16")

    # "auto" without the bf16 opt-in never overrides the checkpoint dtype.
    if not auto_bf16:
        return None
    if device is None:
        raise ValueError("device is required when auto_bf16=True")
    return torch.bfloat16 if device.type in {"npu", "cuda"} else None
def load_canary_model(
    model: str,
    device_name: str,
    compute_dtype: str = "auto",
    auto_bf16: bool = False,
    beam_size: int | None = None,
    decoding_strategy: str | None = None,
    performance_mode: bool = False,
) -> Any:
    """Load Canary-1B from a .nemo file, local directory, or HF model id.

    Args:
        model: Path to a ``.nemo`` file, a directory that contains
            ``canary-1b.nemo``, or any other string, which is handed to
            ``EncDecMultiTaskModel.from_pretrained`` unchanged.
        device_name: Backend name, validated by ``resolve_device``.
        compute_dtype: Dtype choice passed to ``resolve_compute_dtype``.
        auto_bf16: Passed to ``resolve_compute_dtype`` for the ``"auto"`` choice.
        beam_size: When not ``None``, decoding is reconfigured through
            ``configure_decoding``. When ``None``, ``decoding_strategy`` and
            ``performance_mode`` are not applied at all.
        decoding_strategy: Forwarded to ``configure_decoding``.
        performance_mode: Forwarded to ``configure_decoding``.

    Returns:
        The loaded model, in eval mode and moved to the resolved device (and
        cast to the resolved dtype when one was selected).
    """
    target_device = resolve_device(device_name)

    # Work out which local checkpoint, if any, the argument points at.
    candidate = Path(model).expanduser()
    if candidate.is_file() and candidate.suffix == ".nemo":
        checkpoint = candidate
    elif candidate.is_dir() and (candidate / "canary-1b.nemo").is_file():
        checkpoint = candidate / "canary-1b.nemo"
    else:
        checkpoint = None

    if checkpoint is None:
        # Not a recognised local checkpoint: let NeMo resolve the identifier.
        asr_model = EncDecMultiTaskModel.from_pretrained(model, map_location=target_device)
    else:
        asr_model = EncDecMultiTaskModel.restore_from(str(checkpoint), map_location=target_device)

    asr_model.eval()
    asr_model.to(target_device)

    target_dtype = resolve_compute_dtype(compute_dtype, device=target_device, auto_bf16=auto_bf16)
    if target_dtype is not None:
        asr_model.to(target_dtype)

    if beam_size is not None:
        configure_decoding(
            asr_model,
            beam_size=beam_size,
            decoding_strategy=decoding_strategy,
            performance_mode=performance_mode,
        )
    return asr_model
def configure_decoding(
    model: Any,
    beam_size: int,
    decoding_strategy: str | None = None,
    performance_mode: bool = False,
) -> None:
    """Apply Canary AED decoding settings in-place.

    Args:
        model: Object exposing ``cfg.decoding`` and ``change_decoding_strategy``.
        beam_size: Value written to ``cfg.decoding.beam.beam_size``.
        decoding_strategy: An explicit strategy name overrides
            ``cfg.decoding.strategy``. ``"auto"`` selects ``"greedy_batch"``
            only when ``performance_mode`` is set and ``beam_size == 1``.
            ``None`` never touches the configured strategy.
        performance_mode: Enables the ``"auto"`` fast path described above.
    """
    settings = model.cfg.decoding
    settings.beam.beam_size = beam_size

    # Decide on a strategy override; None means "keep what the config has".
    chosen = None
    if decoding_strategy == "auto":
        if performance_mode and beam_size == 1:
            chosen = "greedy_batch"
    elif decoding_strategy is not None:
        chosen = decoding_strategy

    if chosen is not None:
        settings.strategy = chosen
    model.change_decoding_strategy(settings)
def synchronize_device(device_name: str) -> None:
    """Block until queued accelerator work has finished, for use around timed sections.

    Only the exact names ``"cuda"`` and ``"npu"`` trigger a synchronize call;
    any other value returns without doing anything.
    """
    if device_name == "cuda":
        torch.cuda.synchronize()
        return
    if device_name == "npu":
        torch.npu.synchronize()
        return
def extract_text(item: Any) -> str:
    """Return ``item.text`` converted to ``str``.

    The attribute access is deliberately strict: an item without a ``text``
    attribute raises ``AttributeError`` rather than being stringified.
    """
    transcript = item.text
    return str(transcript)