import os
import sys
import multiprocessing as mp

# Make the in-repo ``cosyvoice`` package importable no matter which working
# directory the script is launched from: find the directory that *contains*
# the package and put it at the front of sys.path.
repo_root = os.path.dirname(os.path.abspath(__file__))
pkg_root = repo_root
_script_is_inside_pkg = os.path.basename(repo_root) == "cosyvoice" and os.path.isfile(
    os.path.join(repo_root, "__init__.py")
)
if _script_is_inside_pkg:
    # The script sits inside the package itself; its parent is the import root.
    pkg_root = os.path.dirname(repo_root)
elif not os.path.isdir(os.path.join(pkg_root, "cosyvoice")):
    # Fall back to a "CosyVoice" checkout located next to the script.
    nested_root = os.path.join(repo_root, "CosyVoice")
    if os.path.isdir(os.path.join(nested_root, "cosyvoice")):
        pkg_root = nested_root
if pkg_root not in sys.path:
    sys.path.insert(0, pkg_root)

# The NPU backend needs the "spawn" start method once vLLM starts worker
# processes; force=True replaces any method that was already selected.
_current_start_method = mp.get_start_method(allow_none=True)
if _current_start_method != "spawn":
    mp.set_start_method("spawn", force=True)

# --- NPU / vLLM environment configuration ---
# These variables are set before vLLM is imported further down in this file,
# presumably because vLLM reads them while initialising (confirm per vLLM version).
# Worker processes default to "spawn", but a value already exported by the
# caller is respected (setdefault).
os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
# Force the vLLM V1 engine, overriding any value already in the environment.
os.environ["VLLM_USE_V1"] = "1"
import torch
import torch_npu  # needed before touching torch.npu below (assumed to register the NPU backend — confirm)
# NOTE(review): presumably disables NPU-private internal tensor layouts so tensors
# stay in standard formats — confirm against the torch_npu documentation.
torch.npu.config.allow_internal_format = False
# ----------------------------------------

# Matcha-TTS is vendored under third_party/. Resolve it against pkg_root so the
# import works from any working directory; the original CWD-relative entry is
# kept as a fallback for layouts where third_party/ lives in the CWD instead.
for _matcha_dir in (os.path.join(pkg_root, "third_party", "Matcha-TTS"), "third_party/Matcha-TTS"):
    if _matcha_dir not in sys.path:
        sys.path.append(_matcha_dir)
from cosyvoice.vllm.cosyvoice2 import CosyVoice2ForCausalLM

# Modules that have exposed ``ModelRegistry`` across vLLM releases, tried in order.
_MODEL_REGISTRY_MODULES = (
    "vllm",
    "vllm.model_executor.models",
    "vllm.model_executor.model_registry",
)


def _import_model_registry():
    """Return vLLM's ``ModelRegistry`` class from whichever known module exposes it.

    Only "not found" failures (ImportError / missing attribute) move on to the
    next candidate; any other error raised while importing vLLM propagates
    unchanged so its real cause stays visible.

    Raises:
        ImportError: if none of ``_MODEL_REGISTRY_MODULES`` provides ``ModelRegistry``.
    """
    import importlib

    last_exc = None
    for module_name in _MODEL_REGISTRY_MODULES:
        try:
            return getattr(importlib.import_module(module_name), "ModelRegistry")
        except (ImportError, AttributeError) as exc:
            last_exc = exc
    raise ImportError(
        "Cannot import ModelRegistry from vllm; check vllm version or API change."
    ) from last_exc


# Runs at import time, so spawned vLLM workers that re-import this module
# register the custom architecture as well.
_ModelRegistry = _import_model_registry()
_ModelRegistry.register_model("CosyVoice2ForCausalLM", CosyVoice2ForCausalLM)

from cosyvoice.cli.cosyvoice import AutoModel
from cosyvoice.utils.common import set_all_random_seed
from tqdm import tqdm
import numpy as np
import soundfile as sf


def _to_audio_array(audio):
    if isinstance(audio, torch.Tensor):
        audio = audio.detach().float().cpu().numpy()
    audio = np.asarray(audio).squeeze()
    if audio.ndim != 1:
        audio = audio.reshape(-1)
    return audio


def cosyvoice3_example(model_dir='/opt/atomgit/Fun-CosyVoice3-0.5B-2512', num_outputs=2,
                       output_dir=None, prompt_wav='./asset/zero_shot_prompt.wav'):
    """Run CosyVoice3 zero-shot synthesis through vLLM and save each result as a WAV file.

    Args:
        model_dir: Local directory containing the CosyVoice3 model.
        num_outputs: Number of utterances to generate; utterance ``i`` is seeded
            with ``set_all_random_seed(i)``.
        output_dir: Directory for the WAV files. ``None`` (default) writes
            ``output_XXX.wav`` into the current working directory; any other
            value is created if it does not exist.
        prompt_wav: Reference audio for the zero-shot speaker prompt. Relative
            paths are resolved against the current working directory.

    Returns:
        list[dict]: One entry per saved file with keys ``"path"``, ``"audio"``
        (1-D float32 array clipped to [-1, 1]) and ``"sample_rate"``. Iterations
        that yield no audio are skipped, so the list may be shorter than
        ``num_outputs``.
    """
    print(f"正在从本地目录加载模型: {model_dir}")
    # NOTE(review): load_trt enables TensorRT, which targets NVIDIA GPUs — confirm it
    # is intended/supported on this NPU setup.
    cosyvoice = AutoModel(model_dir=model_dir, load_trt=True, load_vllm=True, fp16=False)
    sample_rate = cosyvoice.sample_rate
    if output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)

    outputs = []
    for i in tqdm(range(num_outputs), desc="生成音频批次"):
        # Re-seed per utterance so every file is reproducible on its own.
        set_all_random_seed(i)
        audio_segments = [
            _to_audio_array(result['tts_speech'])
            for result in cosyvoice.inference_zero_shot(
                '收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。',
                'You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。',
                prompt_wav,
                stream=False)
        ]

        if not audio_segments:
            # Report the skip instead of silently dropping this utterance.
            print(f"警告: 第 {i} 个音频没有生成任何片段,已跳过")
            continue

        audio = np.concatenate(audio_segments, axis=0).astype(np.float32, copy=False)
        # PCM_16 output expects samples in [-1, 1]; libsndfile does not clip
        # float->int conversions unless clipping is enabled, so out-of-range
        # samples could wrap into loud artifacts.
        audio = np.clip(audio, -1.0, 1.0)
        name = f"output_{i:03d}.wav"
        filename = name if output_dir is None else os.path.join(output_dir, name)
        sf.write(filename, audio, samplerate=sample_rate, format="WAV", subtype="PCM_16")
        outputs.append({"path": filename, "audio": audio, "sample_rate": sample_rate})
        print(f"已保存 {filename} | len={audio.shape[0]} sr={sample_rate}")

    return outputs


def main():
    """Entry point: generate the example audio files and report how many were written."""
    generated = cosyvoice3_example()
    print(f"完成,生成 {len(generated)} 个音频文件")


if __name__ == '__main__':
    # Required with the "spawn" start method: spawned workers re-import this
    # module and must not run the example again.
    main()