import os
import time
import argparse
import torch
from torch_npu.contrib import transfer_to_npu
from model import FunASRNano


def parse_args():
    """Parse the command-line options of the inference script.

    Returns:
        argparse.Namespace with the attributes ``model_dir`` and
        ``input_file`` (both required), ``loops`` (int, default 10),
        ``output_file`` (optional, ``None`` when omitted) and
        ``device_id`` (int, default 0).
    """
    # Flags are declared as data and registered in this exact order so the
    # generated ``--help`` output lists them in the same sequence.
    option_specs = (
        ("--model_dir", {"required": True, "help": "model path"}),
        ("--input_file", {"required": True, "help": "input file path"}),
        ("--loops", {"type": int, "default": 10, "help": "loop count"}),
        ("--output_file", {"help": "output file path"}),
        ("--device_id", {"type": int, "default": 0, "help": "device id"}),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()


def main():
    """Run FunASRNano inference on a single input file or on an ``.scp`` list.

    Single-file mode (``--input_file`` does not end with ``.scp``):
        the file is passed to ``model.inference`` ``--loops`` times; the text
        of the last run and the mean wall-clock time per call (in ms) are
        printed.

    List mode (``--input_file`` ends with ``.scp``):
        every non-empty line is split on its first whitespace run into
        ``<key> <path>``; ``<path>`` is passed to ``model.inference`` and
        ``<key><TAB><text>`` is written to ``--output_file``. Lines that do
        not contain a second field are skipped.

    Raises:
        SystemExit: if list mode is requested without ``--output_file``, or
            if ``--loops`` is smaller than 1 in single-file mode.
    """
    cfg = parse_args()
    input_file = cfg.input_file
    scp_mode = input_file.endswith(".scp")

    # Validate mode-specific options up front so a bad command line fails
    # with a clear message before the model is loaded.
    if scp_mode and not cfg.output_file:
        raise SystemExit(
            "--output_file is required when --input_file is an .scp list"
        )
    if not scp_mode and cfg.loops < 1:
        raise SystemExit("--loops must be >= 1")

    # NOTE(review): the file imports torch_npu's transfer_to_npu, which
    # presumably redirects the "cuda" device to an Ascend NPU — confirm.
    device = (
        f"cuda:{cfg.device_id}"
        if torch.cuda.is_available()
        else "mps"
        if torch.backends.mps.is_available()
        else "cpu"
    )
    print(device)
    model, kwargs = FunASRNano.from_pretrained(model=cfg.model_dir, device=device)
    model.eval()

    if not scp_mode:
        loops = cfg.loops
        total_time = 0
        text = None
        for _ in range(loops):
            start = time.perf_counter()
            res = model.inference(data_in=[input_file], **kwargs)
            total_time += time.perf_counter() - start
            text = res[0][0]["text"]
        print(f"infer result: {text}")
        print(f"avg time: {(total_time / loops)*1000:.2f} ms")
        return

    output_file = cfg.output_file
    output_dir = os.path.dirname(output_file)
    if output_dir:
        # exist_ok avoids the check-then-create race of a separate exists().
        os.makedirs(output_dir, exist_ok=True)

    with open(input_file, "r", encoding="utf-8") as f1, open(
        output_file, "w", encoding="utf-8"
    ) as f2:
        for line in f1:
            line = line.strip()
            if not line:
                continue
            parts = line.split(maxsplit=1)
            if len(parts) != 2:
                # No path field on this line; nothing to transcribe.
                continue
            res = model.inference(data_in=[parts[1]], **kwargs)
            text = res[0][0]["text"]
            f2.write(f"{parts[0]}\t{text}\n")

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()