import argparse
import os
from utils import get_qwen, build_enc, seed_everything, get_wikitext2
import amct_pytorch as amct
from amct_pytorch.experimental.flatquant.flat_quant_module.flat_utils import save_flat_matrices, load_flat_matrices
from amct_pytorch.experimental.flatquant.flat_quant_module.train_utils import cali_flat_quant
# AMCT quantization config for the W4A4 FlatQuant flow driven by this script
# (passed to amct.quantize below).
INT4_FLAT_QUANT_CFG = {
    'quant_cfg': {
        # 'inputs' presumably covers layer activations (the "A4" in W4A4):
        # symmetric int4, with 'token' granularity -- confirm against amct docs.
        'inputs': {
            'type': 'int4',
            'symmetric': True,
            'strategy': 'token',
        },
        # Weights: symmetric int4 with 'channel' granularity.
        'weights': {
            'type': 'int4',
            'symmetric': True,
            'strategy': 'channel',
        },
    },
    'algorithm': {
        # FlatQuant hyper-parameters, forwarded verbatim to amct; see the amct
        # FlatQuant documentation for the exact meaning of each field.
        'flatquant': {
            # K/V cache quantization is switched off; the *_bits / *_sym values
            # are presumably ignored while the matching use_* flag is False.
            'use_kcache_quant': False,
            'k_bits': 16,
            'k_sym': False,
            'use_vcache_quant': False,
            'v_bits': 16,
            'v_sym': False,
            'use_o_quant': False,
            'use_down_quant': False,
            'lac': True,
            'lwc': True,
            'diag_alpha': 0.8,
            # Calibration settings (by name: epochs, batch size, learning rate).
            'epochs': 15,
            'cali_bsz': 4,
            'flat_lr': 3e-3,
            'cali_trans': True,
        },
    },
    # Layer names excluded from quantization (by name of the key).
    'skip_layers': {'lm_head'},
}
def content_generate(
    model,
    tokenizer,
    *,
    prompt="Give me a short introduction to the Ascend Model Compression Toolkit(AMCT). /no_think",
    max_new_tokens=16384,
):
    """Run one single-turn chat generation and return only the newly generated text.

    Args:
        model: Causal LM exposing ``device`` and ``generate`` (e.g. a Hugging Face
            model -- assumed, not checked here).
        tokenizer: Tokenizer exposing ``apply_chat_template``, ``__call__`` and
            ``decode``.
        prompt: User message to send. Defaults to the original demo prompt.
        max_new_tokens: Upper bound on generated tokens, forwarded unchanged to
            ``model.generate``.

    Returns:
        The decoded continuation, with prompt tokens and special tokens removed.
    """
    messages = [{'role': 'user', 'content': prompt}]
    # Render the chat template as text (not token ids) and append the assistant
    # turn header so the model starts answering.
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([text], return_tensors='pt').to(model.device)
    generated_ids = model.generate(**model_inputs, max_new_tokens=max_new_tokens)
    # The slice assumes generate() echoes the prompt tokens before the reply
    # (decoder-only behaviour); drop them so only the reply is decoded.
    prompt_len = len(model_inputs.input_ids[0])
    output_ids = generated_ids[0][prompt_len:].tolist()
    return tokenizer.decode(output_ids, skip_special_tokens=True)
if __name__ == '__main__':
    # Demo flow: generate with the float model, quantize it to W4A4 with
    # FlatQuant (calibrating, or reusing saved flat matrices), convert it, and
    # generate again so both outputs can be compared.
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, required=True, help='model location')
    parser.add_argument('--device', type=str, default="npu:0", help='NPU device')
    parser.add_argument('--load_matrix', action='store_true', help="whether to load matrix")
    parser.add_argument('--flat_matrix_path', type=str,
                        default="./outputs/qwen/flat_matrices.pth", help='flat matrix location'
                        )
    parser.add_argument('--nsamples', type=int, default=128,
                        help='number of wikitext2 calibration samples')
    parser.add_argument('--seqlen', type=int, default=2048,
                        help='sequence length of each calibration sample')
    args = parser.parse_args()

    # Validate the matrix path before the (slow) model load and quantization.
    if args.load_matrix:
        if not os.path.isfile(args.flat_matrix_path):
            parser.error(f'--load_matrix given but {args.flat_matrix_path} does not exist')
    else:
        # os.path.dirname() is '' for a bare filename, and os.makedirs('') raises,
        # so only create a directory when the path actually contains one.
        out_dir = os.path.dirname(args.flat_matrix_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    seed_everything()
    model = get_qwen(args.model_path)
    model.to(args.device)
    tokenizer = build_enc(args.model_path)

    # Reference output from the unquantized model.
    ori_content = content_generate(model, tokenizer)

    amct.quantize(model, INT4_FLAT_QUANT_CFG)
    if args.load_matrix:
        model = load_flat_matrices(model, args.flat_matrix_path)
    else:
        calib_dataset = get_wikitext2(
            nsamples=args.nsamples, seed=0, seqlen=args.seqlen, tokenizer=tokenizer
        )
        cali_flat_quant(model, calib_dataset, args.device)
        save_flat_matrices(model, args.flat_matrix_path)
    # NOTE(review): re-placing on the device presumably guards against
    # calibration/loading moving modules elsewhere -- confirm.
    model.to(args.device)
    amct.convert(model)
    print('quantize model to W4A4 with FlatQuant success.')
    print(f'original model content: \n{ori_content}\n')
    quant_content = content_generate(model, tokenizer)
    print(f'quantized model content: \n{quant_content}\n')