import argparse

import torch
from openmind import AutoTokenizer, pipeline, is_torch_npu_available
from openmind_hub import snapshot_download
from transformers import AutoModelForTokenClassification


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        help="Path to model",
        default=None,
    )
    args = parser.parse_args()

    return args


def main():
    args = parse_args()
    if args.model_name_or_path:
        model_path = args.model_name_or_path
    else:
        model_path = snapshot_download("PyTorch-NPU/camembert_ner", revision="main", resume_download=True,
                                       ignore_patterns=["*.h5", "*.ot", "*.msgpack"])
    if is_torch_npu_available():
        device = "npu:0"
    else:
        device = "cpu"

    # 推理
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForTokenClassification.from_pretrained(model_path)

    nlp = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="simple", device=device)
    output = nlp(
        "Apple est créée le 1er avril 1976 dans le garage de la maison d'enfance de Steve Jobs à Los Altos en Californie par Steve Jobs, Steve Wozniak et Ronald Wayne14, puis constituée sous forme de société le 3 janvier 1977 à l'origine sous le nom d'Apple Computer, mais pour ses 30 ans et pour refléter la diversification de ses produits, le mot « computer » est retiré le 9 janvier 2015.")
    print(f'>>>output={output}')


if __name__ == "__main__":
    main()