"""
-------------------------------------------------------------------------
This file is part of the MindStudio project.
Copyright (c) 2025 Huawei Technologies Co.,Ltd.
MindStudio is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
-------------------------------------------------------------------------
"""
import os
import json
import sys
import torch
current_directory = os.path.dirname(os.path.abspath(__file__))
parent_directory = os.path.abspath(os.path.join(current_directory, '..', ".."))
sys.path.append(parent_directory)
from example.common.security.path import get_valid_write_path, get_valid_read_path
from example.common.utils import SafeGenerator, ArgumentParser, StringArgumentValidator, MAX_KEY_LENGTH, \
MAX_JSON_LENGTH, cmd_bool, parse_tokenizer_args
from msmodelslim.pytorch.llm_ptq.anti_outlier import AntiOutlier, AntiOutlierConfig
from msmodelslim.pytorch.llm_ptq.llm_ptq_tools import Calibrator, QuantConfig
from msmodelslim import logger
# Device-type identifiers: the allowed values of --device_type, also compared
# against Quantifier.device_type to pick the device map, dtype and NPU setup.
CPU = "cpu"
NPU = "npu"
def cmd_bool(cmd_arg):
    """Convert the command-line literal ``"True"`` / ``"False"`` to a bool.

    Used as the ``type=`` converter for the boolean options of this script.
    Note that this definition rebinds the name ``cmd_bool`` that is also
    imported from ``example.common.utils`` at the top of the file.

    Args:
        cmd_arg: Raw option value; only the exact strings ``"True"`` and
            ``"False"`` are accepted (case-sensitive).

    Returns:
        ``True`` for ``"True"``, ``False`` for ``"False"``.

    Raises:
        ValueError: If ``cmd_arg`` is any other value.
    """
    # Compare with == (rather than a dict lookup) so unhashable or exotic
    # inputs still end up in the ValueError below.
    for literal, flag in (("True", True), ("False", False)):
        if cmd_arg == literal:
            return flag
    raise ValueError(f"{cmd_arg} should be True or False")
def get_down_proj_disable_names(num_layers: int) -> list:
    """Return the ``feed_forward.w2`` module name of every decoder layer.

    The script uses this list as the default ``disable_names`` (layers kept
    out of quantization) when none are given and activations are 8-bit.

    Args:
        num_layers: Number of layers; names are generated for indices
            ``0 .. num_layers - 1``.

    Returns:
        A list of ``"model.layers.<i>.feed_forward.w2"`` strings in layer order.
    """
    return [f"model.layers.{layer_idx}.feed_forward.w2" for layer_idx in range(num_layers)]
def parse_arguments():
    """Declare every command-line option of this script and parse them.

    Boolean options expect the literal strings ``True`` / ``False`` (converted
    by ``cmd_bool``); only ``--mindie_format`` is a plain ``store_true`` switch.
    String options that carry a ``validator`` are length-checked by
    ``StringArgumentValidator``.

    Returns:
        The object produced by ``ArgumentParser.parse_args()``.
    """
    # Default calibration set: "common/teacher_qualification.jsonl" located
    # next to the directory that contains this script.
    examples_dir = os.path.dirname(os.path.dirname(__file__))
    default_calib_file = os.path.join(examples_dir, 'common', 'teacher_qualification.jsonl')

    cli = ArgumentParser()

    def add_bool_option(flag, default):
        # Shared shape of all "True"/"False" style switches.
        cli.add_argument(flag, type=cmd_bool, default=default)

    # --- Model location and output -------------------------------------
    cli.add_argument('--model_path', type=str, help="model and tokenizer path")
    cli.add_argument('--save_directory', type=str)
    cli.add_argument('--part_file_size', type=int, default=None)

    # --- Calibration / anti-outlier data -------------------------------
    cli.add_argument('--calib_texts', type=str, nargs='+', default=None)
    cli.add_argument('--calib_file', type=str,
                     help='A jsonl file contains calibration data.',
                     default=default_calib_file)
    cli.add_argument('--anti_file', type=str, default=None)

    # --- Quantization settings -----------------------------------------
    cli.add_argument('--w_bit', type=int, default=8)
    cli.add_argument('--a_bit', type=int, default=8)
    cli.add_argument('--disable_names', type=str, nargs='+', default=None)
    cli.add_argument('--device_type', type=str, choices=[CPU, NPU], default=CPU)
    cli.add_argument('--fraction', type=float, default=0.01)
    cli.add_argument("--act_method", type=int, choices=[1, 2, 3], default=1,
                     help=" 1: MinMax, 2: Histogram, 3: Auto")
    add_bool_option('--co_sparse', False)
    cli.add_argument('--anti_method', type=str, default='')
    cli.add_argument('--disable_level', type=str, default='L0')
    add_bool_option('--do_smooth', False)
    add_bool_option('--use_sigma', False)
    add_bool_option('--use_reduce_quant', False)
    cli.add_argument('--tp_size', type=int, default=1)
    cli.add_argument('--sigma_factor', type=float, default=3.0)
    add_bool_option('--is_lowbit', False)
    add_bool_option('--mm_tensor', True)
    add_bool_option('--w_sym', True)
    add_bool_option('--use_kvcache_quant', False)
    add_bool_option('--use_fa_quant', False)
    cli.add_argument('--fa_amp', type=int, default=0)
    add_bool_option('--open_outlier', True)
    cli.add_argument('--group_size', type=int, default=64)
    add_bool_option('--is_dynamic', False)

    # --- Tokenizer related options (length-validated strings) -----------
    cli.add_argument(
        '--input_ids_name', type=str, default='input_ids',
        validator=StringArgumentValidator(min_length=1, max_length=MAX_KEY_LENGTH),
    )
    cli.add_argument(
        '--attention_mask_name', type=str, default='attention_mask',
        validator=StringArgumentValidator(min_length=1, max_length=MAX_KEY_LENGTH),
    )
    cli.add_argument(
        '--tokenizer_args', type=str, default='{"padding_side":"left","pad_token":"<unk>"}',
        validator=StringArgumentValidator(min_length=2, max_length=MAX_JSON_LENGTH),
    )

    # --- Miscellaneous -------------------------------------------------
    add_bool_option('--disable_last_linear', True)
    cli.add_argument(
        '--model_name', type=str, default=None,
        validator=StringArgumentValidator(min_length=1, max_length=MAX_KEY_LENGTH, allow_none=True),
    )
    add_bool_option('--trust_remote_code', False)
    cli.add_argument(
        '--mindie_format', action="store_true",
        help="Compatible with quantization formats "
             "supported by before 2.1.RC1 version of MindIE",
    )

    return cli.parse_args()
class Quantifier:
    """Load a pretrained model and tokenizer, then quantize and save the model.

    The workflow is: tokenize calibration texts with ``get_tokenized_data``,
    then call ``convert`` which optionally runs ``AntiOutlier`` and always runs
    ``Calibrator`` before saving the quantized weights.
    """

    def __init__(self, model_path_or_name, quant_config=None,
                 anti_outlier_config=None, device_type='cpu', trust_remote_code=False, **kwargs):
        """Load config, model and tokenizer through ``SafeGenerator``.

        Args:
            model_path_or_name: Location handed to every ``*_from_pretrained`` call.
            quant_config: Configuration later passed to ``Calibrator``.
            anti_outlier_config: Configuration for ``AntiOutlier``; ``None``
                skips the anti-outlier step in ``convert``.
            device_type: ``CPU`` or ``NPU``. On CPU the model is loaded with
                ``device_map="cpu"`` in ``torch.float32``; otherwise with
                ``device_map="auto"`` in the dtype recorded in the model config.
            trust_remote_code: Forwarded to the config/model/tokenizer loaders.
            **kwargs: Optional ``tokenizer_args`` (dict of extra keyword
                arguments for the tokenizer loader) and ``model_name``.
        """
        safe_generator = SafeGenerator()

        self.device_type = device_type
        device_map = CPU if self.device_type == CPU else "auto"

        self.trust_remote_code = trust_remote_code
        self.quant_config = quant_config
        self.anti_outlier_config = anti_outlier_config
        self.model_path_or_name = model_path_or_name

        self.config = safe_generator.get_config_from_pretrained(
            self.model_path_or_name,
            trust_remote_code=self.trust_remote_code
        )
        # Only the NPU path keeps the checkpoint's own dtype; CPU always uses float32.
        self.dtype = self.config.torch_dtype if self.device_type == NPU else torch.float32

        self.model = safe_generator.get_model_from_pretrained(
            self.model_path_or_name,
            low_cpu_mem_usage=True,
            torch_dtype=self.dtype,
            device_map=device_map,
            trust_remote_code=self.trust_remote_code
        )

        tokenizer_args = kwargs.get("tokenizer_args", {})
        self.tokenizer = safe_generator.get_tokenizer_from_pretrained(
            self.model_path_or_name,
            use_fast=False,
            trust_remote_code=self.trust_remote_code,
            legacy=False,
            **tokenizer_args
        )

        self.model_name = kwargs.get("model_name", None)

    @staticmethod
    def _resolve_save_type(mindie_format=None):
        """Map the MindIE-compatibility flag to the save type used by ``convert``.

        Args:
            mindie_format: Explicit flag. When ``None``, fall back to the
                ``mindie_format`` attribute of a module-level ``args`` object
                (set when this file runs as a script); if there is no such
                object the flag is treated as ``False``.

        Returns:
            ``"safe_tensor"`` when the flag is truthy, otherwise ``"ascendV1"``.
        """
        if mindie_format is None:
            # Previously ``convert`` read the global ``args`` unconditionally and
            # raised NameError whenever the class was used outside the script.
            script_args = globals().get("args")
            mindie_format = getattr(script_args, "mindie_format", False)
        return "safe_tensor" if mindie_format else "ascendV1"

    def get_tokenized_data(self, input_texts,
                           input_ids_name='input_ids',
                           attention_mask_name='attention_mask'):
        """Tokenize each text separately and collect the model inputs.

        Args:
            input_texts: Iterable of texts; each one is tokenized on its own
                with ``return_tensors='pt'`` and ``padding=True`` and moved to
                ``self.device_type``.
            input_ids_name: Key of the token ids in the tokenizer output.
            attention_mask_name: Key of the attention mask in the tokenizer output.

        Returns:
            A list with one ``[input_ids, attention_mask]`` pair per input text.
        """
        tokenized_data = []
        for text in input_texts:
            encoded = self.tokenizer(text, return_tensors='pt', padding=True).to(self.device_type)
            fields = encoded.data
            tokenized_data.append([fields[input_ids_name], fields[attention_mask_name]])
        return tokenized_data

    def convert(self, tokenized_calib_data, tokenized_anti_data, save_path, disable_level,
                part_file_size=None, mindie_format=None):
        """Run anti-outlier processing (if configured), calibrate, and save.

        Args:
            tokenized_calib_data: Calibration data for ``Calibrator``.
            tokenized_anti_data: Calibration data for ``AntiOutlier``; only
                used when ``anti_outlier_config`` was given.
            save_path: Destination passed to ``Calibrator.save``.
            disable_level: Forwarded to ``Calibrator``.
            part_file_size: Forwarded to ``Calibrator.save``.
            mindie_format: Truthy selects the ``"safe_tensor"`` save type,
                falsy ``"ascendV1"``. ``None`` (default) keeps the historical
                behaviour of reading the script's parsed ``args`` when present.
        """
        if self.device_type == NPU:
            torch.npu.set_compile_mode(jit_compile=False)

        if self.anti_outlier_config is not None:
            anti_outlier = AntiOutlier(self.model, calib_data=tokenized_anti_data, cfg=self.anti_outlier_config)
            anti_outlier.process()

        calibrator = Calibrator(self.model,
                                self.quant_config,
                                calib_data=tokenized_calib_data,
                                disable_level=disable_level)
        calibrator.run()

        save_type = self._resolve_save_type(mindie_format)
        calibrator.save(save_path, save_type=[save_type], part_file_size=part_file_size)
if __name__ == '__main__':
    # Script entry point: parse options, build the quantization configs, load
    # the model, tokenize calibration data, quantize, and write the results.
    #
    # Keep `args` at module scope: Quantifier.convert looks up the
    # `mindie_format` flag from this module-level name.
    args = parse_arguments()
    checker = SafeGenerator()

    # Device id comes from the RANK environment variable (0 when unset or invalid).
    try:
        rank: int = int(os.getenv("RANK", "0"))
    except ValueError as e:
        logger.warning(f"Error converting 'RANK' environment variable to integer: {e}")
        logger.info("Defaulting to 0.")
        rank: int = 0

    model_path = args.model_path
    save_directory = args.save_directory

    # Load the model config once; it supplies both the layer count used for the
    # default disable list and the dtype written to the output config below.
    model_config = checker.get_config_from_pretrained(
        model_path,
        trust_remote_code=args.trust_remote_code
    )
    num_layers = model_config.num_hidden_layers

    # With 8-bit activations and no explicit list, exclude every layer's
    # feed_forward.w2 from quantization.
    disable_names = args.disable_names
    if not disable_names and args.a_bit == 8:
        disable_names = get_down_proj_disable_names(num_layers)

    # NOTE(review): --mm_tensor is parsed but not forwarded; mm_tensor is fixed
    # to False here. Confirm this is intended for this model.
    quant_conf = QuantConfig(
        w_bit=args.w_bit,
        a_bit=args.a_bit,
        disable_names=disable_names,
        dev_type=args.device_type,
        dev_id=rank,
        act_method=args.act_method,
        w_sym=args.w_sym,
        mm_tensor=False,
        co_sparse=args.co_sparse,
        fraction=args.fraction,
        sigma_factor=args.sigma_factor,
        use_sigma=args.use_sigma,
        is_lowbit=args.is_lowbit,
        do_smooth=args.do_smooth,
        open_outlier=args.open_outlier,
        group_size=args.group_size,
        use_kvcache_quant=args.use_kvcache_quant,
        is_dynamic=args.is_dynamic,
        disable_last_linear=args.disable_last_linear,
    )
    if args.use_fa_quant:
        quant_conf = quant_conf.fa_quant(fa_amp=args.fa_amp)

    # Anti-outlier is optional; method 'm3' additionally receives the bit
    # widths, weight symmetry and device id.
    anti_outlier_config_val = None
    if args.anti_method == 'm3':
        anti_outlier_config_val = AntiOutlierConfig(a_bit=args.a_bit, w_bit=args.w_bit,
                                                    anti_method=args.anti_method, w_sym=args.w_sym,
                                                    dev_type=args.device_type, dev_id=rank)
    elif args.anti_method:
        anti_outlier_config_val = AntiOutlierConfig(anti_method=args.anti_method,
                                                    dev_type=args.device_type)

    tokenizer_args = parse_tokenizer_args(
        args.tokenizer_args,
        default={"padding_side": "left", "pad_token": "<unk>"}
    )
    quantifier = Quantifier(
        model_path, quant_conf, anti_outlier_config_val,
        device_type=args.device_type, tokenizer_args=tokenizer_args,
        model_name=args.model_name, trust_remote_code=args.trust_remote_code
    )

    # Both datasets default to None so the convert() call below never hits an
    # unbound name when no calibration / anti-outlier texts are available
    # (previously tokenized_anti_data was only assigned inside the `if`).
    tokenized_calib_data = None
    tokenized_anti_data = None

    calib_file = args.calib_file
    anti_file = args.anti_file
    # A calibration file takes precedence over inline --calib_texts; the
    # anti-outlier data falls back to the calibration texts.
    calib_texts = checker.load_jsonl(calib_file) if calib_file else args.calib_texts
    anti_texts = checker.load_jsonl(anti_file) if anti_file else calib_texts

    if calib_texts is not None:
        tokenized_calib_data = quantifier.get_tokenized_data(
            calib_texts,
            input_ids_name=args.input_ids_name,
            attention_mask_name=args.attention_mask_name
        )
    if anti_texts is not None:
        tokenized_anti_data = quantifier.get_tokenized_data(
            anti_texts,
            input_ids_name=args.input_ids_name,
            attention_mask_name=args.attention_mask_name
        )

    # Create the output directory if needed, then validate it for writing.
    if not os.path.exists(save_directory):
        os.makedirs(save_directory, mode=0o750, exist_ok=True)
    save_directory = get_valid_write_path(save_directory, is_dir=True)

    quantifier.convert(tokenized_calib_data,
                       tokenized_anti_data, save_directory,
                       args.disable_level,
                       part_file_size=args.part_file_size)

    # Derive the quantization label recorded in the saved config.
    quant_type = f"w{args.w_bit}a{args.a_bit}"
    is_sparseCompress = args.w_bit == 4 and args.a_bit == 8 and (args.co_sparse or args.is_lowbit)
    if is_sparseCompress:
        quant_type = "w8a8s"
    is_w8a8_dynamic = args.w_bit == 8 and args.a_bit == 8 and args.is_dynamic
    if is_w8a8_dynamic:
        quant_type = "w8a8_dynamic"

    checker.modify_config(model_path, save_directory, model_config.torch_dtype,
                          quant_type, args)
    checker.copy_tokenizer_files(model_path, save_directory)