"""
-------------------------------------------------------------------------
This file is part of the MindStudio project.
Copyright (c) 2025 Huawei Technologies Co.,Ltd.
MindStudio is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
-------------------------------------------------------------------------
"""
import argparse
import functools
import json
import os
import sys
from unittest.mock import patch
import torch
import torch_npu
from tqdm import tqdm
current_directory = os.path.dirname(os.path.abspath(__file__))
parent_directory = os.path.abspath(os.path.join(current_directory, '..', ".."))
sys.path.append(parent_directory)
from convert_fp8_to_bf16 import auto_convert_model_fp8_to_bf16, OpsType
from add_safetensors import add_safetensors
from mtp_quant_module import warp_mtp_model, post_process_mtp_quant
from example.common.rot_utils.rot_ds import rot_model
from example.common.security.path import get_valid_read_path, get_write_directory
from example.common.security.path import json_safe_load, json_safe_dump, get_valid_write_path
from example.common.security.type import check_number
from example.common.utils import SafeGenerator, cmd_bool
from example.common.copy_config_files import copy_config_files, modify_config_json, modify_vllm_config_json
from msmodelslim.pytorch.llm_ptq.anti_outlier import AntiOutlierConfig, AntiOutlier
from msmodelslim.pytorch.llm_ptq.llm_ptq_tools import Calibrator, QuantConfig
from msmodelslim.utils.logging import set_logger_level
def parse_args():
    """
    Build the command-line parser for this quantization script and parse ``sys.argv``.

    Returns:
        argparse.Namespace: the parsed options. Attribute names are the long option names
        without the leading dashes (``model_path``, ``save_path``, ``layer_count``, ...).

    Raises:
        SystemExit: raised by argparse when a required option is missing or a value is rejected
        (for example a ``--quant_mtp`` / ``--anti_method`` value outside its ``choices``).
    """
    parser = argparse.ArgumentParser()
    # main() reads both paths unconditionally, so a missing value is reported here as a usage
    # error instead of surfacing later as a None path.
    parser.add_argument('--model_path', type=str, required=True, help="The path of float model and tokenizer")
    parser.add_argument('--save_path', type=str, required=True, help="The path to save quant model")
    parser.add_argument('--layer_count', type=int, default=0, help="Layer count when loading model")
    parser.add_argument('--anti_dataset', type=str, default="../common/deepseek_anti_prompt.json",
                        help="The calib data for anti outlier")
    parser.add_argument('--calib_dataset', type=str, default="../common/deepseek_calib_prompt.json",
                        help="The calib data for calibration")
    parser.add_argument('--batch_size', type=int, default=4, help="Batch size for anti and calibration")
    parser.add_argument('--from_fp8', action='store_true', help="Origin model is of fp8")
    parser.add_argument('--from_bf16', action='store_true', help="Origin model is of bf16")
    parser.add_argument('--fa_quant', action="store_true", help="Enable fa quant")
    # Adjacent string literals keep the help text independent of how the source is indented.
    parser.add_argument('--mindie_format', action='store_true',
                        help="Compatible with quantization formats "
                             "supported by before 2.1.RC1 version of MindIE")
    parser.add_argument('--dynamic', action='store_true', help="Dynamic Quantization")
    parser.add_argument('--disable_anti', action='store_true', help="No AntiOutlier")
    parser.add_argument('--quant_mtp', type=str, choices=['mix', 'float', 'none'], default='none',
                        help="Quantization mode: 'mix(w8a8 mix quant)' , or "
                             "'float(save float mtp weight)' (default: %(default)s)")
    parser.add_argument('--anti_method', type=str, choices=['m4', 'm6'], default='m4')
    parser.add_argument('--rot', action='store_true', help="rot model")
    parser.add_argument('--trust_remote_code', type=cmd_bool, default=False)
    return parser.parse_args()
def create_custom_hook(quant_mtp, mindie_format, fa_quant):
    """
    Build a callback that writes quantization settings into a model-config mapping.

    Args:
        quant_mtp: MTP quantization mode; only the value ``'mix'`` adds the ``mtp_quantize`` entry.
        mindie_format: when truthy, ``model_type`` is overwritten with ``"deepseekv2"``.
        fa_quant: when truthy, ``quantization_config['fa_quant_type']`` is set to ``'FAKQuant'``.

    Returns:
        A one-argument function that mutates the given config mapping in place and returns ``None``.
    """

    def custom_hook(model_config):
        # Collect the top-level entries first, in the order they must be written.
        settings = [("mla_quantize", "w8a8")]
        if quant_mtp == 'mix':
            settings.append(("mtp_quantize", "w8a8_dynamic"))
        settings.append(("quantize", "w8a8_dynamic"))
        if mindie_format:
            settings.append(("model_type", "deepseekv2"))
        for field, setting in settings:
            model_config[field] = setting

        if fa_quant:
            # Reuse an existing quantization_config section; create it only when absent.
            model_config.setdefault('quantization_config', {})['fa_quant_type'] = 'FAKQuant'

    return custom_hook
def get_calib_dataset_batch(model_tokenizer, calib_list, batch_size, device="npu"):
    """
    Tokenize prompts in consecutive batches and collect the tensors of each batch.

    Args:
        model_tokenizer: callable invoked as ``model_tokenizer(batch, return_tensors='pt', padding=True)``;
            its result must provide ``.to(device)`` and a ``.data`` mapping.
        calib_list: sliceable sequence of prompts.
        batch_size: number of prompts per tokenizer call; the final batch may be shorter.
        device: device handed to every ``.to`` call.

    Returns:
        list: one item per batch, each a list holding the ``torch.Tensor`` values of ``.data``
        after ``.to(device)``; non-tensor values are dropped.
    """
    batches = []
    for start in range(0, len(calib_list), batch_size):
        prompts = calib_list[start:start + batch_size]
        encoded = model_tokenizer(prompts, return_tensors='pt', padding=True).to(device)
        tensors = []
        for item in encoded.data.values():
            if isinstance(item, torch.Tensor):
                tensors.append(item.to(device))
        batches.append(tensors)
    return batches
def set_initialized_submodules(model, state_dict_keys):
    """
    Mark each submodule of ``model`` whose state-dict entries are all present in ``state_dict_keys``.

    A covered module gets ``_is_hf_initialized = True``. Module entries are compared by their
    fully qualified name (``"<module_name>.<key>"``; the root module uses its keys unprefixed).

    Returns:
        dict: module name -> module for every module that has at least one key missing from
        ``state_dict_keys``.
    """
    loaded = set(state_dict_keys)
    pending = {}
    for name, submodule in model.named_modules():
        # The root module is reported under the empty name and needs no prefix.
        prefix = "" if name == "" else f"{name}."
        expected = {prefix + key for key in submodule.state_dict()}
        if expected <= loaded:
            submodule._is_hf_initialized = True
            continue
        pending[name] = submodule
    return pending
def remove_module_entries(save_path, json_filename="quant_model_description_w8a8_dynamic.json"):
    """
    Rewrite a JSON description file without the entries whose key contains "norm.module.".

    Args:
        save_path (str): directory that holds the JSON file.
        json_filename (str): name of the JSON file inside ``save_path``.
    """
    target = os.path.join(save_path, json_filename)
    kept = {}
    for name, entry in json_safe_load(target).items():
        if "norm.module." in name:
            continue
        kept[name] = entry
    # The filtered mapping replaces the file it was loaded from.
    json_safe_dump(kept, target, indent=4)
def update_quant_model_description_content(save_path, json_filename):
    """
    Set the quantization type stored in a description JSON file to W8A8_DYNAMIC.

    When the file has a "version" entry and its "fa_quant_type" equals "FAQuant", that value is
    rewritten to "FAKQuant" in the same pass.

    Args:
        save_path (str): directory that holds the JSON file.
        json_filename (str): name of the JSON file inside ``save_path``.

    Returns:
        None
    """
    description_path = get_valid_write_path(os.path.join(save_path, json_filename), is_dir=False)
    description = json_safe_load(description_path, check_user_stat=True)

    description["model_quant_type"] = "W8A8_DYNAMIC"
    if "version" in description and "fa_quant_type" in description:
        if description["fa_quant_type"] == "FAQuant":
            description["fa_quant_type"] = "FAKQuant"

    json_safe_dump(description, description_path, indent=4)
# Layer indices whose "model.layers.<i>.self_attn" is added to disable_names when --fa_quant is
# set, keyed by a marker searched in the upper-cased model path. The first matching marker wins,
# so "V3.1" has to stay ahead of its prefix "V3".
_FA_DISABLED_ATTN_LAYERS = (
    ('R1', (0, 1, 2, 46, 47, 50, 54, 55, 56, 57, 58, 59, 60)),
    ('V3.1', (0, 1, 2, 50, 51, 53, 54, 55, 56, 57, 58, 59, 60)),
    ('V3', (0, 1, 2, 46, 50, 51, 53, 54, 55, 56, 57, 59, 60)),
)


def _read_json_prompts(checked_path):
    """Parse the JSON file at ``checked_path`` (already validated by the caller)."""
    # JSON text is UTF-8; an explicit encoding keeps the read independent of the locale default.
    with open(checked_path, "r", encoding="utf-8") as file:
        return json.load(file)


def _build_disable_names(layer_total, original_layer_total, model_path, fa_quant, quant_mtp):
    """Return the module names handed to QuantConfig as ``disable_names``."""
    names = [f"model.layers.{idx}.self_attn.kv_b_proj" for idx in range(layer_total)]
    if fa_quant:
        upper_path = model_path.upper()
        skipped_layers = next(
            (layers for marker, layers in _FA_DISABLED_ATTN_LAYERS if marker in upper_path), ())
        # Indices are bounded by the layer count read from the checkpoint config, not by the
        # possibly reduced/extended count used for loading.
        names.extend(f"model.layers.{idx}.self_attn" for idx in skipped_layers if idx < original_layer_total)
    if quant_mtp == "mix":
        names.extend(("lm_head", "mtp_decoder.self_attn.kv_b_proj", "mtp_layer.shared_head.head"))
    return names


def _load_model(args, safe_generator, model_path):
    """Load config, tokenizer and model; return ``(config, tokenizer, model, num_layer)``."""
    config = safe_generator.get_config_from_pretrained(model_path=model_path,
                                                       trust_remote_code=args.trust_remote_code)
    num_layer = config.num_hidden_layers
    if args.layer_count < 0 or args.layer_count > num_layer:
        raise ValueError(
            f"Invalid value for parameter layer_count: {args.layer_count}."
            f"Must be between 0 and {num_layer}."
        )
    # layer_count == 0 keeps the layer count found in the checkpoint config.
    config.num_hidden_layers = args.layer_count if args.layer_count != 0 else config.num_hidden_layers
    if args.quant_mtp == "mix":
        config.num_hidden_layers = config.num_hidden_layers + 1
    config.model_type = "deepseekv2"
    config.use_cache = False

    tokenizer = safe_generator.get_tokenizer_from_pretrained(model_path=model_path,
                                                             config=config,
                                                             trust_remote_code=args.trust_remote_code,
                                                             use_fast=True,
                                                             add_eos_token=True)
    model = safe_generator.get_model_from_pretrained(model_path=model_path,
                                                     config=config,
                                                     trust_remote_code=args.trust_remote_code,
                                                     device_map={
                                                         "model.embed_tokens": 0,
                                                         "model.layers": "cpu",
                                                         "model.norm": "cpu",
                                                         "lm_head": 0,
                                                     },
                                                     torch_dtype="auto",
                                                     attn_implementation='eager')
    return config, tokenizer, model, num_layer


def _save_outputs(args, calibrator, quant_config, model_path, save_path):
    """Save the calibrated weights and description, then copy the patched config files."""
    if args.mindie_format:
        description_name = "quant_model_description_w8a8_dynamic.json"
        save_type = "safe_tensor"
    else:
        description_name = "quant_model_description.json"
        save_type = "ascendV1"
    calibrator.save(save_path,
                    json_name=description_name,
                    safetensors_name="quant_model_weight_w8a8_dynamic.safetensors",
                    save_type=[save_type],
                    part_file_size=4)
    update_quant_model_description_content(save_path, description_name)
    if args.mindie_format:
        remove_module_entries(save_path)

    custom_hook_instance = create_custom_hook(args.quant_mtp, args.mindie_format, args.fa_quant)
    config_modifier = modify_config_json if args.mindie_format else modify_vllm_config_json
    custom_hooks = {'config.json': functools.partial(config_modifier, custom_hook=custom_hook_instance)}
    copy_config_files(input_path=model_path, output_path=save_path, quant_config=quant_config,
                      mindie_format=args.mindie_format, custom_hooks=custom_hooks)


def _quantize(args, pbar):
    """Run the pipeline stages, advancing ``pbar`` by one after each of the five stages."""
    model_path = args.model_path
    batch_size = args.batch_size
    save_path = get_write_directory(args.save_path, write_mode=0o750)
    check_number(batch_size, int, 1, 16, "batch_size")

    # Stage 1: load the model, then apply the fp8/bf16 conversion call, MTP wrapping and rotation.
    safe_generator = SafeGenerator()
    config, tokenizer, model, num_layer = _load_model(args, safe_generator, model_path)
    auto_convert_model_fp8_to_bf16(model, model_path, OpsType.get_ops_type(args.from_bf16, args.from_fp8))
    if args.quant_mtp == "mix":
        model = warp_mtp_model(config, model, model_path)
    if args.rot:
        if args.quant_mtp == "float":
            raise NotImplementedError("original float mtp is not supported rotation, "
                                      "please use mix mode and disable all MTP Linear to get float rotated MTP model!")
        rot_model(model)
    pbar.update(1)

    # Stage 2: tokenize both prompt sets and, unless disabled, run anti-outlier processing.
    anti_dataset_path = get_valid_read_path(args.anti_dataset, "json", is_dir=False)
    calib_dataset_path = get_valid_read_path(args.calib_dataset, "json", is_dir=False)
    anti_prompt = _read_json_prompts(anti_dataset_path)
    calib_prompt = _read_json_prompts(calib_dataset_path)
    anti_dataset = get_calib_dataset_batch(tokenizer, anti_prompt, batch_size, model.device)
    dataset_calib = get_calib_dataset_batch(tokenizer, calib_prompt, batch_size, model.device)
    if not args.disable_anti:
        with torch.no_grad():
            anti_config = AntiOutlierConfig(w_bit=8,
                                            a_bit=8,
                                            anti_method=args.anti_method,
                                            dev_type='npu',
                                            dev_id=model.device.index)
            anti_outlier = AntiOutlier(model, calib_data=anti_dataset, cfg=anti_config)
            anti_outlier.process()
    pbar.update(1)

    # Stage 3: calibrate with the selected modules disabled.
    disable_names = _build_disable_names(config.num_hidden_layers, num_layer, args.model_path,
                                         args.fa_quant, args.quant_mtp)
    quant_config = QuantConfig(
        a_bit=8,
        w_bit=8,
        disable_names=disable_names,
        dev_type='npu',
        dev_id=model.device.index,
        act_method=1,
        pr=1.0,
        w_sym=True,
        mm_tensor=False,
        is_dynamic=args.dynamic
    )
    if args.fa_quant:
        quant_config = quant_config.fa_quant()
    calibrator = Calibrator(model, quant_config, calib_data=dataset_calib, disable_level="L0")
    calibrator.run()
    pbar.update(1)

    # Stage 4: write weights, description and config files into save_path.
    _save_outputs(args, calibrator, quant_config, model_path, save_path)
    pbar.update(1)

    # Stage 5: MTP post-processing that depends on the chosen --quant_mtp mode.
    if args.quant_mtp == "float":
        add_safetensors(org_paths=model_path, target_dir=save_path, safetensors_prefix="mtp_float",
                        max_file_size_gb=5, prefix="model.layers.61.")
    if args.quant_mtp == "mix":
        post_process_mtp_quant(save_path)
    pbar.update(1)


def main():
    """
    Script entry point: parse the command line and run the W8A8 quantization pipeline.

    The pipeline loads the model named by ``--model_path``, optionally runs anti-outlier
    processing, calibrates, and writes the quantized weights, the description JSON and the
    adjusted config files to ``--save_path``. A five-step progress bar tracks the stages.

    Raises:
        ValueError: ``--layer_count`` is negative or larger than the layer count in the model config.
        NotImplementedError: ``--rot`` is combined with ``--quant_mtp float``.
    """
    args = parse_args()
    set_logger_level("info")
    # The context manager closes the progress bar even when a stage raises.
    with tqdm(total=5, position=0, desc="Total Process") as pbar:
        _quantize(args, pbar)
if __name__ == "__main__":
    torch_npu.npu.init()
    # Swap transformers' set_initialized_submodules for the implementation defined in this file.
    # The patcher is started and never stopped, so the replacement stays active for the whole run.
    submodule_patcher = patch("transformers.modeling_utils.set_initialized_submodules",
                              new=set_initialized_submodules)
    submodule_patcher.start()
    main()