import os
import torch
import torch.utils.data
import torch_npu
from transformers import AutoTokenizer, AutoModelForCausalLM
# Number of new tokens generated in the post-quantization smoke test.
SEQ_LEN_OUT = 12

# Project root; the script requires PROJECT_PATH to be set in the environment
# (indexing os.environ raises KeyError otherwise).
_project_root = os.environ['PROJECT_PATH']
# Local directory holding the LLaMA-2 7B checkpoint and tokenizer files.
load_path = f"{_project_root}/resource/llm_ptq/llama2_7b/"

# Both loaders read only from the local directory and trust custom code in it.
_local_load_kwargs = {"trust_remote_code": True, "local_files_only": True}

# Slow (non-"fast") tokenizer loaded from the checkpoint directory.
tokenizer = AutoTokenizer.from_pretrained(
    load_path,
    use_fast=False,
    **_local_load_kwargs,
)

# FP16 model weights, moved to the NPU (torch_npu) immediately after loading.
model = AutoModelForCausalLM.from_pretrained(
    load_path,
    torch_dtype=torch.float16,
    **_local_load_kwargs,
).npu()

# Prompts used as calibration samples for post-training quantization.
calib_list = [
    "Where is the capital of China?",
    "Please make a poem:",
    "I want to learn python, how should I learn it?",
    "Please help me write a job report on large model inference optimization:",
    "What are the most worth visiting scenic spots in China?",
]
def get_calib_dataset(tokenizers, calib_lists, device=None):
    """Tokenize calibration prompts into the sample format fed to the Calibrator.

    Args:
        tokenizers: A callable tokenizer (despite the plural name, a single
            tokenizer such as a Hugging Face ``AutoTokenizer`` instance). It is
            called as ``tokenizers([text], return_tensors='pt')`` and must
            return a mapping with ``'input_ids'`` and ``'attention_mask'``.
        calib_lists: Iterable of prompt strings.
        device: Optional target device. When ``None`` (the default) tensors are
            moved with ``.npu()``, exactly as before; otherwise they are moved
            with ``.to(device)`` (e.g. ``"cpu"``), which makes the helper usable
            without an NPU.

    Returns:
        A list with one ``[input_ids, attention_mask]`` pair per prompt, in the
        same order as ``calib_lists``. An empty input yields an empty list.
    """
    def _move(tensor):
        # Default path keeps the original NPU-only behaviour (needs torch_npu).
        return tensor.npu() if device is None else tensor.to(device)

    calib_dataset = []
    for calib_data in calib_lists:
        # Each prompt is tokenized as its own batch of one, so samples keep
        # their individual lengths rather than being padded to a common one.
        inputs = tokenizers([calib_data], return_tensors='pt')
        print(inputs)  # debug dump of each tokenized sample (kept from the original)
        calib_dataset.append([_move(inputs['input_ids']), _move(inputs['attention_mask'])])
    return calib_dataset
# Build the [input_ids, attention_mask] calibration samples from the prompts.
dataset_calib = get_calib_dataset(tokenizer, calib_list)

# msModelSlim is only imported at this point in the script.
from modelslim.pytorch.llm_ptq.llm_ptq_tools import QuantConfig, Calibrator

# 4-bit weights, low-bit mode, running on the NPU, with smoothing disabled.
_quant_options = {
    "w_bit": 4,
    "do_smooth": False,
    "dev_type": 'npu',
    "is_lowbit": True,
    "use_sigma": True,
}
quant_config = QuantConfig(**_quant_options)

# NOTE(review): the msModelSlim docs describe disable_level as an automatic
# rollback level (how many sensitive layers stay un-quantized) — verify.
calibrator = Calibrator(
    model,
    quant_config,
    calib_data=dataset_calib,
    disable_level='L2',
)
# Run calibration over the prepared samples.
calibrator.run()
# Quick sanity check: generate a short continuation with the calibrated model
# and print the decoded text.
print("testing quantized weights...")
test_prompt = "Common sense questions and answers\n\nQuestion: How to learn a new language\nFactual answer:"
test_input = tokenizer(test_prompt, return_tensors="pt")
print("model is inferring...")
# NOTE(review): the model was already moved to the NPU when it was loaded; this
# is kept in case calibration relocated it — TODO confirm whether it is needed.
model = model.npu()
model.eval()
generate_ids = model.generate(
    test_input.input_ids.npu(),
    attention_mask=test_input.attention_mask.npu(),
    max_new_tokens=SEQ_LEN_OUT,
)
res = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
print(res)
# Print each decoded sequence on its own line (no index is needed).
for item in res:
    print(item)
# Export the quantized weights in the formats named by save_type.
_save_dir = f"{os.environ['PROJECT_PATH']}/output/llm_ptq_lowbit"
calibrator.save(_save_dir, save_type=['numpy', 'safe_tensor'])
print('Save quant weight success!')