"""
-------------------------------------------------------------------------
This file is part of the MindStudio project.
Copyright (c) 2025 Huawei Technologies Co.,Ltd.
MindStudio is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
-------------------------------------------------------------------------
"""
import os
import json
import sys
import torch
import torch.nn.functional as F
# Append the directory two levels above this file to sys.path; presumably this
# is what lets the `example.*` imports that follow resolve when the script is
# run directly.
current_directory = os.path.dirname(os.path.abspath(__file__))
parent_directory = os.path.abspath(os.path.join(current_directory, '..', ".."))
sys.path.append(parent_directory)
from example.common.security.path import get_valid_write_path, get_valid_read_path, get_write_directory
from example.common.utils import SafeGenerator, ArgumentParser, StringArgumentValidator, MAX_KEY_LENGTH, \
MAX_JSON_LENGTH, cmd_bool, parse_tokenizer_args
from msmodelslim.pytorch.llm_ptq.anti_outlier import AntiOutlier, AntiOutlierConfig
from msmodelslim.pytorch.llm_ptq.llm_ptq_tools import Calibrator, QuantConfig
from msmodelslim.pytorch.llm_ptq.llm_ptq_tools.layer_select import LayerSelector
# Device-type identifiers; these are also the accepted values of --device_type.
CPU = "cpu"
NPU = "npu"
def get_down_proj_disable_names(num_layers: int) -> list:
    """Build a disable-name list: ``lm_head`` followed by every MLP ``down_proj``.

    Args:
        num_layers: Number of layers; one entry is produced for each index in
            ``range(num_layers)``.

    Returns:
        ``["lm_head", "model.layers.0.mlp.down_proj", ...]``.
    """
    per_layer = [f"model.layers.{idx}.mlp.down_proj" for idx in range(num_layers)]
    return ["lm_head", *per_layer]
def get_c_proj_disable_names(num_layers: int) -> list:
    """Build a disable-name list: ``lm_head`` followed by every MLP ``c_proj``.

    Uses the ``transformer.h.{i}`` layer naming scheme.

    Args:
        num_layers: Number of layers; one entry is produced for each index in
            ``range(num_layers)``.

    Returns:
        ``["lm_head", "transformer.h.0.mlp.c_proj", ...]``.
    """
    per_layer = [f"transformer.h.{idx}.mlp.c_proj" for idx in range(num_layers)]
    return ["lm_head", *per_layer]
def get_padding_data(tokenizer, calib_list, device_type):
    """Tokenize texts, right-pad the id tensors with 0 to one length, and stack them.

    Args:
        tokenizer: Callable invoked as
            ``tokenizer(text, return_tensors='pt', add_special_tokens=False)``;
            its result must expose ``.data['input_ids']``.
        calib_list: Iterable of texts to tokenize.
        device_type: Target passed to ``Tensor.to`` for each id tensor.

    Returns:
        A one-element list holding the padded id tensors concatenated along
        dimension 0.
    """
    token_ids = [
        tokenizer(text, return_tensors='pt', add_special_tokens=False).data['input_ids']
        for text in calib_list
    ]
    # Longest sequence (dimension 1) decides the common padded width.
    longest = max((ids.size(1) for ids in token_ids), default=0)
    padded = [
        F.pad(ids.to(device_type), (0, longest - ids.size(1)), value=0)
        for ids in token_ids
    ]
    return [torch.cat(padded)]
def get_batch_tokenized_data(tokenizer, input_texts, device_type, batch_size=4):
    """Split texts into consecutive batches and tokenize/pad each batch.

    Args:
        tokenizer: Tokenizer forwarded to ``get_padding_data``.
        input_texts: Sliceable sequence of texts.
        device_type: Device forwarded to ``get_padding_data``.
        batch_size: Maximum number of texts per batch.

    Returns:
        A list with one ``get_padding_data`` result per batch, in input order.
    """
    return [
        get_padding_data(tokenizer, input_texts[start:start + batch_size], device_type)
        for start in range(0, len(input_texts), batch_size)
    ]
def auto_layer_select(model, disable_names, disable_threshold, select_layer_data):
    """Run ``LayerSelector`` over candidate layers and return its threshold selection.

    Args:
        model: Model handed to ``LayerSelector``.
        disable_names: Candidate layer names, passed as ``layer_names``.
        disable_threshold: Threshold given to ``select_layers_by_threshold``.
        select_layer_data: Data passed to ``LayerSelector.run``.

    Returns:
        Whatever ``select_layers_by_threshold(disable_threshold)`` returns.
    """
    selector = LayerSelector(model=model, layer_names=disable_names)
    selector.run(select_layer_data)
    selected = selector.select_layers_by_threshold(disable_threshold)
    return selected
def get_select_anti_dataset(tokenizer, mixed_dataset, device="npu"):
    """Build the calibration set used for outlier suppression.

    Args:
        tokenizer: Callable invoked as ``tokenizer(text, return_tensors='pt')``;
            its result must expose ``.data['input_ids']``.
        mixed_dataset: Iterable of mappings, each with a ``"prompt"`` entry.
        device: Target passed to ``Tensor.to`` for each id tensor.

    Returns:
        A list with one single-element list per dataset entry; the element is
        the entry's input-id tensor.
    """
    anti_dataset = []
    for entry in mixed_dataset:
        prompts = [entry["prompt"]]
        token_ids = [
            tokenizer(text, return_tensors='pt').data['input_ids'].to(device)
            for text in prompts
        ]
        # Pad to the longest sequence of this entry (a no-op for one prompt).
        longest = max((ids.size(1) for ids in token_ids), default=0)
        padded = [F.pad(ids, (0, longest - ids.size(1)), value=0) for ids in token_ids]
        anti_dataset.append([torch.cat(padded)])
    return anti_dataset
def parse_arguments():
    """Declare the command-line options of this script and parse them.

    Returns:
        The namespace returned by ``parser.parse_args()``.
    """
    parser = ArgumentParser()
    add = parser.add_argument

    def key_validator(**extra):
        # A fresh validator per option, bounded by MAX_KEY_LENGTH.
        return StringArgumentValidator(min_length=1, max_length=MAX_KEY_LENGTH, **extra)

    default_calib_file = os.path.join(
        os.path.dirname(os.path.dirname(__file__)), 'common', 'teacher_qualification.jsonl'
    )

    # Model location and output.
    add('--model_path', type=str, help="model and tokenizer path")
    add('--save_directory', type=str)
    add('--part_file_size', type=int, default=None)

    # Calibration inputs.
    add('--calib_texts', type=str, nargs='+', default=None)
    add('--calib_file', type=str, help='A jsonl file contains calibration data.',
        default=default_calib_file)

    # Quantization settings.
    add('--w_bit', type=int, default=8)
    add('--a_bit', type=int, default=8)
    add('--disable_names', type=str, nargs='+', default=None)
    add('--device_type', type=str, choices=[CPU, NPU], default=CPU)
    add('--fraction', type=float, default=0.01)
    add("--act_method", type=int, choices=[1, 2, 3], default=1,
        help=" 1: MinMax, 2: Histogram, 3: Auto")
    add('--co_sparse', type=cmd_bool, default=False)
    add('--anti_method', type=str, default='')
    add('--disable_level', type=str, default='L0')
    add('--do_smooth', type=cmd_bool, default=False)
    add('--use_sigma', type=cmd_bool, default=False)
    add('--use_reduce_quant', type=cmd_bool, default=False)
    add('--tp_size', type=int, default=1)
    add('--sigma_factor', type=float, default=3.0)
    add('--is_lowbit', type=cmd_bool, default=False)
    add('--w_sym', type=cmd_bool, default=True)
    add('--use_kvcache_quant', type=cmd_bool, default=False)
    add('--use_fa_quant', type=cmd_bool, default=False)
    add('--fa_amp', type=int, default=0)
    add('--open_outlier', type=cmd_bool, default=True)
    add('--group_size', type=int, default=64)
    add('--is_dynamic', type=cmd_bool, default=False)

    # Tokenizer output keys and extra tokenizer arguments.
    add('--input_ids_name', type=str, default='input_ids', validator=key_validator())
    add('--attention_mask_name', type=str, default='attention_mask', validator=key_validator())
    add('--tokenizer_args', type=str, default='{}',
        validator=StringArgumentValidator(min_length=2, max_length=MAX_JSON_LENGTH))

    add('--disable_last_linear', type=cmd_bool, default=True)
    add('--model_name', type=str, default=None, validator=key_validator(allow_none=True))
    add('--model_type', type=str, default='qwen2',
        choices=['qwen1', 'qwen1.5', 'qwen2', 'qwen2.5', 'qwen3'],
        help='Specify the type of qwen model (choices: qwen1, qwen1.5, qwen2, qwen2.5, qwen3)')
    add('--anti_calib_file', type=str, default=None,
        help='Path to anti-calibration data file (.json or .jsonl)')
    add('--disable_threshold', type=float, default=0,
        help='Disable threshold when auto select disable names')
    add('--pdmix', type=cmd_bool, default=False, help='use pdmix quantization type')
    add('--trust_remote_code', type=cmd_bool, default=False)
    add('--layer_count', type=int, default=0)
    add('--mindie_format', action="store_true",
        help="Compatible with quantization formats "
             "supported by before B050 version of MindIE")
    add('--w_method', type=str, default='MinMax',
        choices=['MinMax', 'GPTQ', 'HQQ', 'NF'],
        help='Specify the type of weight quantization method (choices: MinMax, GPTQ, HQQ, NF)')

    return parser.parse_args()
class Quantifier:
    """Load a model and tokenizer, then drive anti-outlier processing and calibration.

    All options are read from the ``args`` namespace stored on the instance, so
    the class does not depend on module-level names created by the script
    entry point.
    """

    def __init__(self, model_path_or_name, args,
                 anti_outlier_config=None, device_type='cpu', **kwargs):
        """Load config, model and tokenizer through ``SafeGenerator``.

        Args:
            model_path_or_name: Location passed to the ``*_from_pretrained`` helpers.
            args: Parsed options namespace; kept as ``self.args``.
            anti_outlier_config: Optional config; when not ``None``, ``convert``
                runs ``AntiOutlier`` before calibration.
            device_type: ``CPU`` or ``NPU``.
            **kwargs: Optional ``layer_count`` (int, default 0),
                ``tokenizer_args`` (dict, default ``{}``) and ``model_name``.
        """
        self.args = args
        safe_generator = SafeGenerator()
        self.device_type = device_type
        device_map = CPU if self.device_type == CPU else "auto"
        self.anti_outlier_config = anti_outlier_config
        self.model_path_or_name = model_path_or_name
        self.trust_remote_code = self.args.trust_remote_code

        self.config = safe_generator.get_config_from_pretrained(
            self.model_path_or_name,
            trust_remote_code=self.trust_remote_code
        )
        # A positive layer_count overrides the layer count stored in the config.
        self.layer_count = kwargs.get("layer_count", 0)
        if self.layer_count > 0:
            self.config.num_hidden_layers = self.layer_count
        # On NPU the dtype from the config is used; otherwise float32.
        self.dtype = self.config.torch_dtype if self.device_type == NPU else torch.float32

        self.model = safe_generator.get_model_from_pretrained(
            self.model_path_or_name,
            low_cpu_mem_usage=True, torch_dtype=self.dtype,
            device_map=device_map,
            trust_remote_code=self.trust_remote_code
        )

        tokenizer_args = kwargs.get("tokenizer_args", {})
        self.tokenizer = safe_generator.get_tokenizer_from_pretrained(
            self.model_path_or_name,
            use_fast=False,
            trust_remote_code=self.trust_remote_code,
            legacy=False,
            **tokenizer_args
        )

        self.model_name = kwargs.get("model_name", None)
        self.quant_config = None

    def create_quant_config(self, num_layers, select_layer_data=None, ):
        """Build the ``QuantConfig`` from ``self.args`` and store it in ``self.quant_config``.

        Args:
            num_layers: Layer count used to generate default disable names.
            select_layer_data: Data for automatic layer selection; only used
                when ``args.disable_threshold > 0``.
        """
        args = self.args
        # Same expression the script entry point uses for its rank; read here so
        # the method works when the class is imported from another module.
        rank = int(os.getenv("RANK", "0"))

        disable_names = args.disable_names
        if not disable_names and args.a_bit in [8, 16]:
            if args.disable_threshold > 0:
                disable_names = get_down_proj_disable_names(num_layers)
            elif args.model_type == 'qwen1':
                disable_names = get_c_proj_disable_names(num_layers)
            else:
                disable_names = get_down_proj_disable_names(num_layers)

        # NOTE(review): selection is applied to whatever disable_names holds at
        # this point, including user-supplied names — confirm this nesting
        # against the intended behavior.
        if args.disable_threshold > 0:
            disable_names = auto_layer_select(self.model, disable_names, args.disable_threshold, select_layer_data)

        quant_config = QuantConfig(
            w_bit=args.w_bit,
            a_bit=args.a_bit,
            disable_names=disable_names,
            dev_type=args.device_type,
            dev_id=rank,
            act_method=args.act_method,
            w_sym=args.w_sym,
            mm_tensor=False,
            co_sparse=args.co_sparse,
            fraction=args.fraction,
            sigma_factor=args.sigma_factor,
            use_sigma=args.use_sigma,
            is_lowbit=args.is_lowbit,
            do_smooth=args.do_smooth,
            open_outlier=args.open_outlier,
            group_size=args.group_size,
            use_kvcache_quant=args.use_kvcache_quant,
            is_dynamic=args.is_dynamic,
            disable_last_linear=args.disable_last_linear,
            w_method=args.w_method,
            pdmix=args.pdmix,
        )

        if args.use_fa_quant:
            quant_config = quant_config.fa_quant(fa_amp=args.fa_amp)

        self.quant_config = quant_config

    def get_batch_tokenized_data(self, input_texts, batch_size=4):
        """Tokenize ``input_texts`` in padded batches using this instance's tokenizer and device."""
        return get_batch_tokenized_data(self.tokenizer, input_texts, self.device_type, batch_size)

    def get_tokenized_data(self, input_texts,
                           input_ids_name='input_ids',
                           attention_mask_name='attention_mask'):
        """Tokenize each text separately into a calibration sample.

        Args:
            input_texts: Iterable of texts.
            input_ids_name: Key of the input ids in the tokenizer output.
            attention_mask_name: Key of the attention mask in the tokenizer output.

        Returns:
            One list per text: ``[input_ids, attention_mask]``, or
            ``[input_ids, None, attention_mask]`` when ``args.model_type`` is
            ``'qwen1'``.
        """
        # Read the model type from the instance's args, not a module global.
        is_qwen1 = self.args.model_type == 'qwen1'
        tokenized_data = []
        for input_text in input_texts:
            inputs = self.tokenizer(input_text, return_tensors='pt', padding=True).to(self.device_type)
            input_ids = inputs.data[input_ids_name]
            attention_mask = inputs.data[attention_mask_name]
            if is_qwen1:
                tokenized_data.append([input_ids, None, attention_mask])
            else:
                tokenized_data.append([input_ids, attention_mask])
        return tokenized_data

    def convert(self, tokenized_data, save_path, disable_level, part_file_size=None, tokenized_ant_calib_data=None):
        """Run optional anti-outlier processing, calibrate, and save the result.

        Args:
            tokenized_data: Calibration data passed to ``Calibrator``.
            save_path: Output directory; created with mode ``0o750`` if missing.
            disable_level: Passed through to ``Calibrator``.
            part_file_size: Passed through to ``Calibrator.save``.
            tokenized_ant_calib_data: Data for ``AntiOutlier``; defaults to
                ``tokenized_data`` when ``None``.
        """
        if self.device_type == NPU:
            torch.npu.set_compile_mode(jit_compile=False)

        if tokenized_ant_calib_data is None:
            tokenized_ant_calib_data = tokenized_data

        if self.anti_outlier_config is not None:
            if self.model_name == "baichuan":
                anti_outlier = AntiOutlier(self.model, calib_data=tokenized_ant_calib_data,
                                           cfg=self.anti_outlier_config, norm_class_name="RMSNorm")
            else:
                anti_outlier = AntiOutlier(self.model, calib_data=tokenized_ant_calib_data,
                                           cfg=self.anti_outlier_config)
            anti_outlier.process()

        if not os.path.exists(save_path):
            os.mkdir(save_path, mode=0o750)

        calibrator = Calibrator(self.model, self.quant_config, calib_data=tokenized_data, disable_level=disable_level)
        calibrator.run()
        # Read the output format flag from the instance's args, not a module global.
        save_type = "safe_tensor" if self.args.mindie_format else "ascendV1"
        calibrator.save(save_path, save_type=[save_type], part_file_size=part_file_size)
def get_calib_dataset(tokenizer, mixed_dataset, device='npu'):
    """Build the calibration set used for quantization.

    Args:
        tokenizer: Callable invoked as ``tokenizer(text, return_tensors='pt')``;
            its result must support ``.to(device)`` and expose
            ``.data['input_ids']``.
        mixed_dataset: Iterable of mappings, each with a ``"prompt"`` entry.
        device: Target passed to ``.to``.

    Returns:
        A list with one ``[input_ids]`` sample per dataset entry.
    """
    dataset_calib = []
    for prpt_ans in mixed_dataset:
        inputs = tokenizer(prpt_ans["prompt"], return_tensors='pt').to(device)
        dataset_calib.append([inputs.data['input_ids']])
    return dataset_calib


if __name__ == '__main__':
    # `args` and `rank` are deliberately kept as module-level names; other code
    # in this file may reference them.
    args = parse_arguments()
    checker = SafeGenerator()
    rank: int = int(os.getenv("RANK", "0"))

    model_path = get_valid_read_path(args.model_path, is_dir=True, check_user_stat=True)
    save_directory = get_write_directory(args.save_directory, write_mode=0o750)

    num_layers = checker.get_config_from_pretrained(
        model_path,
        trust_remote_code=args.trust_remote_code
    ).num_hidden_layers
    # A positive --layer_count overrides the layer count from the model config.
    num_layers = args.layer_count if args.layer_count > 0 else num_layers

    # Anti-outlier configuration; stays None when --anti_method is empty.
    anti_outlier_config_val = None
    if args.anti_method == 'm3':
        anti_outlier_config_val = AntiOutlierConfig(a_bit=args.a_bit, w_bit=args.w_bit,
                                                    anti_method=args.anti_method, w_sym=args.w_sym,
                                                    dev_type=args.device_type, dev_id=rank)
    elif args.anti_method == 'm6':
        anti_disable_names = ["model.layers.{}.self_attn.o_proj".format(i) for i in range(num_layers)]
        if args.model_type == 'qwen3':
            anti_outlier_config_val = AntiOutlierConfig(
                a_bit=args.a_bit,
                w_bit=args.w_bit,
                w_sym=args.w_sym,
                anti_method=args.anti_method,
                dev_type=args.device_type,
                disable_anti_names=anti_disable_names,
                flex_config={'alpha': 0.4, 'beta': 0.325}
            )
        else:
            anti_outlier_config_val = AntiOutlierConfig(
                anti_method=args.anti_method,
                dev_type=args.device_type,
                disable_anti_names=anti_disable_names,
                flex_config={'alpha': 0.6, 'beta': 0.3}
            )
    elif args.anti_method:
        anti_outlier_config_val = AntiOutlierConfig(anti_method=args.anti_method,
                                                    dev_type=args.device_type)

    tokenizer_args = parse_tokenizer_args(
        args.tokenizer_args,
        default={}
    )
    # qwen1 gets these tokenizer defaults when the user supplied none.
    if tokenizer_args == {} and args.model_type == 'qwen1':
        tokenizer_args = {
            "padding_side": "left",
            "pad_token": "<|extra_0|>",
            "eos_token": "<|endoftext|>"
        }

    quantifier = Quantifier(
        model_path, args, anti_outlier_config_val,
        device_type=args.device_type, tokenizer_args=tokenizer_args,
        model_name=args.model_name, layer_count=args.layer_count
    )

    # Calibration data: a .jsonl/.json file, or --calib_texts when the file is
    # disabled with the literal value "none".
    tokenized_calib_data = []
    if args.calib_file.lower() == 'none':
        args.calib_file = None
    calib_file = args.calib_file
    calib_texts = None
    if calib_file:
        calib_file = get_valid_read_path(calib_file)
        if calib_file.endswith('.jsonl'):
            calib_texts = checker.load_jsonl(calib_file)
        elif calib_file.endswith('.json'):
            with open(calib_file, 'r', encoding='utf-8') as f:
                calib_promt = json.load(f)
            tokenized_calib_data = get_calib_dataset(quantifier.tokenizer, calib_promt, args.device_type)
        else:
            raise ValueError("Unsupported calibration file format: {}".format(calib_file))
    else:
        calib_texts = args.calib_texts

    # Texts from either source are tokenized here; previously --calib_texts was
    # read but never tokenized, leaving the calibration set empty.
    if calib_texts is not None:
        tokenized_calib_data = quantifier.get_tokenized_data(
            calib_texts,
            input_ids_name=args.input_ids_name,
            attention_mask_name=args.attention_mask_name
        )

    # Anti-outlier calibration data defaults to the quantization calibration data.
    tokenized_ant_calib_data = tokenized_calib_data
    if args.anti_calib_file:
        args.anti_calib_file = get_valid_read_path(args.anti_calib_file)
        if args.model_type == "qwen3":
            with open(args.anti_calib_file, 'r', encoding='utf-8') as f:
                anti_promt = json.load(f)
            tokenized_ant_calib_data = get_select_anti_dataset(quantifier.tokenizer, anti_promt, args.device_type)
        else:
            ant_calib_texts = checker.load_jsonl(args.anti_calib_file)
            if ant_calib_texts is not None:
                tokenized_ant_calib_data = quantifier.get_batch_tokenized_data(ant_calib_texts)

    if isinstance(args.disable_threshold, float) and args.disable_threshold > 0:
        quantifier.create_quant_config(num_layers, tokenized_ant_calib_data)
    elif args.disable_threshold == 0:
        quantifier.create_quant_config(num_layers)
    else:
        raise ValueError("disable_threshold should be a float number >= 0")

    quantifier.convert(tokenized_calib_data, save_directory, args.disable_level,
                       part_file_size=args.part_file_size,
                       tokenized_ant_calib_data=tokenized_ant_calib_data)

    # Write the updated config and tokenizer files next to the quantized weights.
    quant_type = quantifier.quant_config.model_quant_type.lower()
    auto_config = checker.get_config_from_pretrained(model_path, trust_remote_code=args.trust_remote_code)
    if args.model_type == 'qwen1':
        auto_config.torch_dtype = 'torch.bfloat16'
    checker.modify_config(model_path, save_directory, auto_config.torch_dtype,
                          quant_type, args)
    checker.copy_tokenizer_files(model_path, save_directory)