"""
-------------------------------------------------------------------------
This file is part of the MindStudio project.
Copyright (c) 2025 Huawei Technologies Co.,Ltd.
MindStudio is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
-------------------------------------------------------------------------
"""
import os
import sys
import functools
import torch
current_directory = os.path.dirname(os.path.abspath(__file__))
parent_directory = os.path.abspath(os.path.join(current_directory, '..', ".."))
sys.path.append(parent_directory)
from example.common.security.path import get_valid_write_path
from example.common.utils import SafeGenerator, ArgumentParser, StringArgumentValidator, MAX_KEY_LENGTH, \
MAX_JSON_LENGTH, cmd_bool, parse_tokenizer_args
from msmodelslim.pytorch.llm_ptq.anti_outlier import AntiOutlier, AntiOutlierConfig
from msmodelslim.pytorch.llm_ptq.llm_ptq_tools import Calibrator, QuantConfig
from example.common.copy_config_files import copy_config_files, modify_config_json
# Device identifiers: the accepted values of --device_type, also compared
# against when choosing dtype / device placement below.
CPU, NPU = "cpu", "npu"
def cmd_bool(cmd_arg):
    """Convert the exact strings "True" / "False" into the matching bool.

    Used as the argparse ``type=`` for this script's boolean options. This
    definition shadows the ``cmd_bool`` imported from ``example.common.utils``.
    Matching is case-sensitive.

    Raises:
        ValueError: if ``cmd_arg`` is neither "True" nor "False".
    """
    for literal, value in (("True", True), ("False", False)):
        if cmd_arg == literal:
            return value
    raise ValueError(f"{cmd_arg} should be True or False")
def get_disable_names(num_layers: int) -> list:
    """List the ``model.layers.<i>.mlp.gate.wg`` module name for every layer.

    These names are passed to ``QuantConfig(disable_names=...)`` so the modules
    are kept out of quantization (presumably the MoE gate/router projection —
    confirm against the model definition).

    Args:
        num_layers: number of decoder layers; a value <= 0 yields an empty list.
    """
    names = []
    for layer_idx in range(num_layers):
        names.append(f"model.layers.{layer_idx}.mlp.gate.wg")
    return names
def custom_hook(model_config):
    """Record this script's quantization types in a ``config.json`` mapping.

    Sets ``quantize`` to "w8a8" and ``moe_quantize`` to "w8a8_dynamic",
    mutating ``model_config`` in place (returns None). Passed to
    ``modify_config_json`` when ``config.json`` is copied to the output.
    """
    for key, value in (("quantize", "w8a8"), ("moe_quantize", "w8a8_dynamic")):
        model_config[key] = value
def parse_arguments():
    """Define and parse the command-line options of this quantization script.

    ``--model_path`` and ``--save_directory`` are required: the script reads
    the model config from the first and writes every output into the second,
    so running without them can only fail later with a less helpful error.
    Boolean options are parsed with ``cmd_bool`` and must be the literal
    strings "True" or "False".

    Returns:
        argparse.Namespace with one attribute per option below.
    """
    parser = ArgumentParser()
    parser.add_argument('--model_path', type=str, required=True,
                        help="model and tokenizer path")
    parser.add_argument('--save_directory', type=str, required=True,
                        help="output directory for the quantized weights and copied config/tokenizer files")
    parser.add_argument('--part_file_size', type=int, default=5)
    parser.add_argument('--w_bit', type=int, default=8)
    parser.add_argument('--a_bit', type=int, default=8)
    parser.add_argument('--disable_names', type=str, nargs='+', default=None)
    parser.add_argument('--device_type', type=str, choices=[CPU, NPU], default=NPU)
    parser.add_argument('--fraction', type=float, default=0.01)
    parser.add_argument("--act_method", type=int, choices=[1, 2, 3], default=1,
                        help=" 1: MinMax, 2: Histogram, 3: Auto")
    parser.add_argument('--co_sparse', type=cmd_bool, default=False)
    parser.add_argument('--anti_method', type=str, default='')
    parser.add_argument('--disable_level', type=str, default='L0')
    parser.add_argument('--do_smooth', type=cmd_bool, default=False)
    parser.add_argument('--use_sigma', type=cmd_bool, default=False)
    parser.add_argument('--use_reduce_quant', type=cmd_bool, default=False)
    parser.add_argument('--sigma_factor', type=float, default=3.0)
    parser.add_argument('--is_lowbit', type=cmd_bool, default=False)
    parser.add_argument('--mm_tensor', type=cmd_bool, default=True)
    parser.add_argument('--w_sym', type=cmd_bool, default=True)
    parser.add_argument('--use_kvcache_quant', type=cmd_bool, default=False)
    parser.add_argument('--use_fa_quant', type=cmd_bool, default=False)
    parser.add_argument('--fa_amp', type=int, default=0)
    parser.add_argument('--open_outlier', type=cmd_bool, default=True)
    parser.add_argument('--group_size', type=int, default=64)
    parser.add_argument('--is_dynamic', type=cmd_bool, default=False)
    # Key names used to pull tensors out of the tokenizer output.
    parser.add_argument('--input_ids_name', type=str, default='input_ids',
                        validator=StringArgumentValidator(min_length=1, max_length=MAX_KEY_LENGTH))
    parser.add_argument('--attention_mask_name', type=str, default='attention_mask',
                        validator=StringArgumentValidator(min_length=1, max_length=MAX_KEY_LENGTH))
    # JSON string of extra tokenizer kwargs (at least "{}").
    parser.add_argument('--tokenizer_args', type=str, default='{}',
                        validator=StringArgumentValidator(min_length=2, max_length=MAX_JSON_LENGTH))
    parser.add_argument('--disable_last_linear', type=cmd_bool, default=True)
    parser.add_argument('--model_name', type=str, default=None,
                        validator=StringArgumentValidator(min_length=1, max_length=MAX_KEY_LENGTH, allow_none=True))
    parser.add_argument('--trust_remote_code', type=cmd_bool, default=False)
    # Implicit concatenation instead of a backslash inside the literal, which
    # would otherwise embed the continuation line's indentation in the text.
    parser.add_argument('--mindie_format', action="store_true",
                        help="Compatible with quantization formats "
                             "supported by before 2.1.RC1 version of MindIE")
    return parser.parse_args()
class Quantifier:
    """Load a pretrained model + tokenizer and quantize it with msmodelslim.

    Typical use: construct, call ``get_tokenized_data`` on calibration texts,
    then ``convert`` to run (optional) anti-outlier processing, calibration and
    saving.
    """

    def __init__(self, model_path_or_name, quant_config=None,
                 anti_outlier_config=None, device_type='cpu', trust_remote_code=False, **kwargs):
        """Load config, model and tokenizer through ``SafeGenerator``.

        Args:
            model_path_or_name: path/name passed to every SafeGenerator loader.
            quant_config: QuantConfig handed to ``Calibrator`` in ``convert``.
            anti_outlier_config: optional AntiOutlierConfig; when not None,
                ``convert`` runs ``AntiOutlier`` before calibration.
            device_type: ``CPU`` ("cpu") or ``NPU`` ("npu").
            trust_remote_code: forwarded to every SafeGenerator loader.
            **kwargs: ``tokenizer_args`` (dict of extra tokenizer kwargs,
                default ``{}``) and ``model_name`` (stored, default None).
        """
        safe_generator = SafeGenerator()
        self.device_type = device_type
        self.trust_remote_code = trust_remote_code
        self.quant_config = quant_config
        self.anti_outlier_config = anti_outlier_config
        self.model_path_or_name = model_path_or_name
        self.config = safe_generator.get_config_from_pretrained(
            self.model_path_or_name,
            trust_remote_code=self.trust_remote_code
        )
        # NPU keeps the checkpoint dtype; CPU calibration runs in float32.
        self.dtype = self.config.torch_dtype if self.device_type == NPU else torch.float32
        if self.device_type == CPU:
            # Keep every module on CPU: mapping modules to device 0 needs an
            # accelerator, and get_tokenized_data() places inputs on CPU too.
            device_map = {"": CPU}
        else:
            # Embeddings and LM head on accelerator 0, decoder layers and the
            # final norm on CPU (presumably to bound device memory — confirm).
            device_map = {
                "model.embed_tokens": 0,
                "model.layers": "cpu",
                "model.norm": "cpu",
                "lm_head": 0,
            }
        self.model = safe_generator.get_model_from_pretrained(
            self.model_path_or_name,
            low_cpu_mem_usage=True,
            torch_dtype=self.dtype,
            trust_remote_code=self.trust_remote_code,
            device_map=device_map
        )
        tokenizer_args = kwargs.get("tokenizer_args", {})
        self.tokenizer = safe_generator.get_tokenizer_from_pretrained(
            self.model_path_or_name,
            use_fast=True,
            trust_remote_code=self.trust_remote_code,
            add_eos_token=True,
            **tokenizer_args
        )
        self.model_name = kwargs.get("model_name", None)

    def get_tokenized_data(self, input_texts,
                           input_ids_name='input_ids',
                           attention_mask_name='attention_mask'):
        """Tokenize each text separately for calibration.

        Args:
            input_texts: iterable of strings, tokenized one at a time.
            input_ids_name: key of the token ids in the tokenizer output.
            attention_mask_name: key of the attention mask in the tokenizer output.

        Returns:
            list with one ``[input_ids, attention_mask]`` pair per text, each
            tensor moved to ``self.device_type``.
        """
        tokenized_data = []
        for input_text in input_texts:
            inputs = self.tokenizer(input_text, return_tensors='pt', padding=True).to(self.device_type)
            tokenized_data.append(
                [inputs.data[input_ids_name], inputs.data[attention_mask_name]])
        return tokenized_data

    def convert(self, tokenized_data, save_path, disable_level, part_file_size=None, mindie_format=None):
        """Run anti-outlier processing (if configured), calibrate, and save.

        Args:
            tokenized_data: calibration batches, e.g. from ``get_tokenized_data``.
            save_path: output directory passed to ``Calibrator.save``.
            disable_level: forwarded to ``Calibrator``.
            part_file_size: forwarded to ``Calibrator.save``.
            mindie_format: True saves as "safe_tensor", False as "ascendV1".
                None (default) falls back to ``args.mindie_format`` from the
                module-level ``args`` set when run as a script, else False.
        """
        if mindie_format is None:
            # Backward compatibility: this used to read the script-global
            # `args` directly, raising NameError when imported elsewhere.
            mindie_format = getattr(globals().get("args"), "mindie_format", False)
        if self.device_type == NPU:
            torch.npu.set_compile_mode(jit_compile=False)
        if self.anti_outlier_config is not None:
            anti_outlier = AntiOutlier(self.model, calib_data=tokenized_data, cfg=self.anti_outlier_config)
            anti_outlier.process()
        # Modules matching "*.experts.*" get dynamic W8A8; everything else static W8A8.
        mix_cfg = {
            "*.experts.*": "w8a8_dynamic",
            "*": "w8a8"
        }
        calibrator = Calibrator(
            self.model,
            self.quant_config,
            calib_data=tokenized_data,
            disable_level=disable_level,
            mix_cfg=mix_cfg
        )
        calibrator.run()
        save_type = "safe_tensor" if mindie_format else "ascendV1"
        calibrator.save(save_path, save_type=[save_type], part_file_size=part_file_size)
if __name__ == '__main__':
    # `args` stays a module-level global on purpose: Quantifier.convert may
    # read args.mindie_format from it.
    args = parse_arguments()
    checker = SafeGenerator()
    # Device id for QuantConfig / AntiOutlierConfig, taken from $RANK (default 0).
    rank: int = int(os.getenv("RANK", "0"))
    model_path = args.model_path
    save_directory = args.save_directory

    num_layers = checker.get_config_from_pretrained(
        model_path,
        trust_remote_code=args.trust_remote_code
    ).num_hidden_layers
    # Without an explicit --disable_names, keep every layer's mlp.gate.wg unquantized.
    disable_names = args.disable_names or get_disable_names(num_layers)

    quant_conf = QuantConfig(
        w_bit=args.w_bit,
        a_bit=args.a_bit,
        disable_names=disable_names,
        dev_type=args.device_type,
        dev_id=rank,
        act_method=args.act_method,
        w_sym=args.w_sym,
        # NOTE(review): --mm_tensor is parsed but not forwarded; mm_tensor is
        # pinned to False here — confirm this is intentional.
        mm_tensor=False,
        co_sparse=args.co_sparse,
        fraction=args.fraction,
        sigma_factor=args.sigma_factor,
        use_sigma=args.use_sigma,
        is_lowbit=args.is_lowbit,
        do_smooth=args.do_smooth,
        open_outlier=args.open_outlier,
        group_size=args.group_size,
        use_kvcache_quant=args.use_kvcache_quant,
        is_dynamic=args.is_dynamic,
        disable_last_linear=args.disable_last_linear,
    )
    if args.use_fa_quant:
        quant_conf = quant_conf.fa_quant(fa_amp=args.fa_amp)

    # "m3" additionally takes bit widths and weight symmetry; any other
    # non-empty method only needs the device; empty disables anti-outlier.
    anti_outlier_config_val = None
    if args.anti_method == 'm3':
        anti_outlier_config_val = AntiOutlierConfig(a_bit=args.a_bit, w_bit=args.w_bit,
                                                    anti_method=args.anti_method, w_sym=args.w_sym,
                                                    dev_type=args.device_type, dev_id=rank)
    elif args.anti_method:
        anti_outlier_config_val = AntiOutlierConfig(anti_method=args.anti_method,
                                                    dev_type=args.device_type, dev_id=rank)

    tokenizer_args = parse_tokenizer_args(
        args.tokenizer_args,
        default={}
    )
    quantifier = Quantifier(
        model_path, quant_conf, anti_outlier_config_val,
        device_type=args.device_type, tokenizer_args=tokenizer_args,
        model_name=args.model_name, trust_remote_code=args.trust_remote_code
    )

    # Fixed calibration prompts fed through the model during anti-outlier
    # processing and calibration.
    calib_texts = [
        "Where is the capital of China?",
        "Please make a poem:",
        "I want to learn python, how should I learn it?",
        "Please help me write a job report on large model inference optimization:",
        "What are the most worth visiting scenic spots in China?"
    ]
    tokenized_calib_data = quantifier.get_tokenized_data(
        calib_texts,
        input_ids_name=args.input_ids_name,
        attention_mask_name=args.attention_mask_name
    )

    if not os.path.exists(save_directory):
        os.makedirs(save_directory, mode=0o750, exist_ok=True)
    save_directory = get_valid_write_path(save_directory, is_dir=True)
    quantifier.convert(tokenized_calib_data, save_directory, args.disable_level, part_file_size=args.part_file_size)

    # Copy the original config files alongside the weights, stamping the
    # quantization types into config.json via custom_hook.
    custom_hooks = {
        'config.json': functools.partial(modify_config_json, custom_hook=custom_hook)
    }
    copy_config_files(
        input_path=model_path,
        output_path=save_directory,
        quant_config=quant_conf,
        mindie_format=args.mindie_format,
        custom_hooks=custom_hooks
    )
    checker.copy_tokenizer_files(model_path, save_directory)