import argparse
import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu
from transformers import AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor
from janus.utils.io import load_pil_images
def parse_arguments(argv=None) -> argparse.Namespace:
    """Parse the command-line options of this NPU inference script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what a
            plain ``parse_arguments()`` call has always done.

    Returns:
        A namespace with three attributes:
        ``path`` (str) - location of the model weights,
        ``device_id`` (int) - NPU device id,
        ``type`` (str) - either ``"bf16"`` or ``"fp16"``.

    Raises:
        SystemExit: raised by argparse on unknown options or invalid values
            (for example a ``--type`` outside the allowed choices).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--path",
        type=str,
        default='./deepseek-ai/Janus-Pro',
        help="The path of model weights",
    )
    parser.add_argument(
        "--device_id",
        type=int,
        default=0,
        help="NPU device id",
    )
    parser.add_argument(
        "--type",
        type=str,
        default='bf16',
        choices=['bf16', 'fp16'],
        help="bf16 or fp16",
    )
    # Passing argv through lets callers and tests supply arguments explicitly;
    # argparse falls back to the process command line when it is None.
    return parser.parse_args(argv)
# --- Runtime setup ---------------------------------------------------------
# Read the CLI options and select the requested NPU for this process.
args = parse_arguments()
torch.npu.set_device(args.device_id)
# argparse limits --type to "bf16"/"fp16", so anything that is not fp16 is bf16.
dtype = torch.float16 if args.type == "fp16" else torch.bfloat16
# Run torch_npu with jit_compile switched off.
torch_npu.npu.set_compile_mode(jit_compile=False)

# --- Model and processor ---------------------------------------------------
model_path = args.path
vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
    model_path, trust_remote_code=True
)
# Cast the weights to the chosen precision, move them to the NPU and switch
# the model to evaluation mode.
vl_gpt = vl_gpt.to(dtype).to("npu").eval()

# --- Prompt ----------------------------------------------------------------
# One user message that references a single image, followed by an empty
# assistant message (presumably the slot the model is expected to fill).
conversation = [
    {
        "role": "User",
        "content": "<image_placeholder>\nConvert the formula into latex code.",
        "images": ["images/equation.png"],
    },
    {"role": "Assistant", "content": ""},
]
pil_images = load_pil_images(conversation)
# Build the batched processor output and place it on the model's device with
# the same dtype as the model.
prepare_inputs = vl_chat_processor(
    conversations=conversation, images=pil_images, force_batchify=True
).to(vl_gpt.device, dtype=dtype)

# --- Generation ------------------------------------------------------------
# This script only performs inference, so no gradients are needed: disable
# autograd so that preparing the input embeddings does not record a graph
# that would never be used.
with torch.no_grad():
    inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
    # Decoding with sampling disabled, capped at 512 new tokens; the EOS
    # token id is also used as the padding id.
    outputs = vl_gpt.language_model.generate(
        inputs_embeds=inputs_embeds,
        attention_mask=prepare_inputs.attention_mask,
        pad_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        max_new_tokens=512,
        do_sample=False,
        use_cache=True,
    )

# --- Output ----------------------------------------------------------------
# Decode the first (only) sequence of the batch and print it after the
# formatted prompt.
answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
print(f"{prepare_inputs['sft_format'][0]}", answer)