"""
-------------------------------------------------------------------------
This file is part of the MindStudio project.
Copyright (c) 2025 Huawei Technologies Co.,Ltd.
MindStudio is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
-------------------------------------------------------------------------
"""
import os
import shutil
import argparse
from dataclasses import dataclass
from typing import List, Optional, Any
from example.common.security.path import json_safe_load, json_safe_dump
from example.common.security.path import get_valid_read_path, get_valid_write_path
from example.common.utils import SafeGenerator
@dataclass
class ModifyConfigParams:
    """Bundle of arguments consumed by ``VlmSafeGenerator.modify_config``."""

    # Directory holding the source model; its ``config.json`` is read.
    model_dir: str
    # Directory that receives the rewritten ``config.json``.
    dest_dir: str
    # Dtype object whose string form is recorded as ``torch_dtype`` in the config.
    torch_dtype: Any
    # Quantization type label written into the config in mindie format.
    quantize_type: str
    # Optional namespace-like object; its attributes are read with ``getattr``.
    args: Optional[Any] = None
    # Top-level config keys whose sub-dict also receives a ``quantize`` entry.
    quantize_config_parts: Optional[List[str]] = None

    def __post_init__(self):
        """Normalize a missing ``quantize_config_parts`` to an empty list."""
        # Only ``None`` is replaced; any caller-supplied list is kept as the same object.
        parts = self.quantize_config_parts
        self.quantize_config_parts = [] if parts is None else parts
@dataclass
class CopyTokenizerParams:
    """Bundle of arguments consumed by ``VlmSafeGenerator.copy_tokenizer_files``."""

    # Directory the auxiliary files are copied from.
    model_dir: str
    # Directory the auxiliary files are copied into.
    dest_dir: str
class VlmSafeGenerator(SafeGenerator):
    """Helpers that write the quantized model's ``config.json`` and copy auxiliary files.

    All methods are static; the class only groups them together with the
    constants they share.
    """

    # Fallback value for each key when ``args`` does not carry the attribute.
    DEFAULT_QUANTIZATION_CONFIG = {
        'group_size': 0,
        'act_method': 2,
        'anti_method': 'm2',
        'is_lowbit': False,
        'mm_tensor': False,
        'w_sym': True,
        'open_outlier': True,
        'is_dynamic': False,
    }
    # Only files with these extensions are copied by copy_tokenizer_files.
    SUPPORTED_EXTENSIONS = {'.json', '.py'}
    # Files never copied, even though their extension is supported.
    EXCLUDED_FILES = {'config.json', 'model.safetensors.index.json'}
    # Upper bound on the number of entries allowed in the source directory.
    MAX_FILE_NUM = 1024

    @staticmethod
    def modify_config(params: ModifyConfigParams):
        """Rewrite the model's ``config.json`` into the destination directory.

        Reads ``config.json`` from ``params.model_dir``, loads the existing
        quantization description file from ``params.dest_dir``, updates the
        config fields and dumps the result to ``<dest_dir>/config.json``.

        Args:
            params: Source/destination directories, dtype, quantize type and
                optional ``args`` object (see ``ModifyConfigParams``).

        Raises:
            KeyError: If an entry of ``params.quantize_config_parts`` is not a
                key of the loaded config (mindie format only).
            Whatever the path-validation and JSON helpers raise on invalid
            paths or files.
        """
        model_dir = get_valid_read_path(params.model_dir, is_dir=True, check_user_stat=True)
        dest_dir = get_valid_write_path(params.dest_dir, is_dir=True)

        src_config_filepath = os.path.join(model_dir, 'config.json')
        data = json_safe_load(src_config_filepath, check_user_stat=True)

        # getattr on None simply yields the default, so no explicit None check is needed here.
        mindie_flag = getattr(params.args, 'mindie_format', False)
        # The config-mutating branches additionally require ``args`` itself to be truthy.
        use_mindie_format = bool(params.args and mindie_flag)

        dest_quant_description_filepath = VlmSafeGenerator._get_quantization_filename(
            dest_dir, params.quantize_type, mindie_flag
        )
        dest_quant_description_filepath = get_valid_write_path(dest_quant_description_filepath, is_dir=False)
        quant_description_data = json_safe_load(dest_quant_description_filepath, check_user_stat=True)

        # Take the last dotted component so that both "torch.float16" and a
        # plain "float16" work; indexing [1] raised IndexError for the latter.
        data['torch_dtype'] = str(params.torch_dtype).split('.')[-1]

        if use_mindie_format:
            data['quantize'] = params.quantize_type
            for config_part in params.quantize_config_parts:
                data[config_part]['quantize'] = params.quantize_type

        quantization_config = VlmSafeGenerator._build_quantization_config(params.args)
        if quantization_config:
            # NOTE(review): quant_description_data is updated here but never
            # written back to disk in this method - confirm whether a
            # json_safe_dump of the description file is missing.
            quant_description_data.update(quantization_config)
            if use_mindie_format:
                data['quantization_config'] = quantization_config

        dest_config_filepath = os.path.join(dest_dir, 'config.json')
        json_safe_dump(data, dest_config_filepath, 4)

    @staticmethod
    def copy_tokenizer_files(params: CopyTokenizerParams):
        """Copy ``.json`` / ``.py`` files from the model directory to the destination.

        Files listed in ``EXCLUDED_FILES`` and sub-directories are skipped.
        Each copied file is given mode ``0o600``. The destination directory is
        created with mode ``0o750`` when it does not exist.

        Args:
            params: Source and destination directories.

        Raises:
            argparse.ArgumentTypeError: If the source directory holds more than
                ``MAX_FILE_NUM`` entries.
        """
        model_dir = get_valid_read_path(params.model_dir, is_dir=True, check_user_stat=True)
        if not os.path.exists(params.dest_dir):
            os.makedirs(params.dest_dir, mode=0o750, exist_ok=True)
        dest_dir = get_valid_write_path(params.dest_dir, is_dir=True)

        filenames = os.listdir(model_dir)
        if len(filenames) > VlmSafeGenerator.MAX_FILE_NUM:
            raise argparse.ArgumentTypeError(
                f"The file num in dir is {len(filenames)}, "
                f"which exceeds the limit {VlmSafeGenerator.MAX_FILE_NUM}."
            )

        for filename in filenames:
            _, ext = os.path.splitext(filename)
            if ext not in VlmSafeGenerator.SUPPORTED_EXTENSIONS:
                continue
            if filename in VlmSafeGenerator.EXCLUDED_FILES:
                continue
            src_filepath = os.path.join(model_dir, filename)
            # os.listdir also returns directories; copyfile would raise
            # IsADirectoryError on one whose name ends in a supported extension.
            if os.path.isdir(src_filepath):
                continue
            dest_filepath = os.path.join(dest_dir, filename)
            shutil.copyfile(src_filepath, dest_filepath)
            os.chmod(dest_filepath, 0o600)

    @staticmethod
    def _get_quantization_filename(dest_dir, quantize_type, mindie_format):
        """Return the path of the quantization description file inside ``dest_dir``.

        In mindie format the lower-cased quantize type is appended to the file
        name; otherwise the fixed name ``quant_model_description.json`` is used.
        """
        if mindie_format:
            filename = f"quant_model_description_{quantize_type.lower()}.json"
        else:
            filename = "quant_model_description.json"
        return os.path.join(dest_dir, filename)

    @staticmethod
    def _build_quantization_config(args):
        """Build the quantization config dictionary from ``args``.

        Args:
            args: Namespace-like object, or ``None``.

        Returns:
            An empty dict when ``args`` is ``None``. Otherwise a dict holding
            every key of ``DEFAULT_QUANTIZATION_CONFIG`` (attribute value, or
            the default when the attribute is absent), followed by ``w_bit``,
            ``a_bit`` and ``dev_type`` (taken from ``args.device_type``) for
            whichever of those attributes exist.
        """
        if args is None:
            return {}

        config = {
            key: getattr(args, key, default_value)
            for key, default_value in VlmSafeGenerator.DEFAULT_QUANTIZATION_CONFIG.items()
        }

        # Attribute name on ``args`` -> key written to the config.
        optional_attrs = (('w_bit', 'w_bit'), ('a_bit', 'a_bit'), ('device_type', 'dev_type'))
        for attr, config_key in optional_attrs:
            if hasattr(args, attr):
                config[config_key] = getattr(args, attr)

        # NOTE(review): this re-assigns the value the comprehension above
        # already stored, so it has no effect as written. It may have been
        # meant to keep group_size at its default unless low-bit without
        # outlier handling is requested - confirm the intent.
        if (hasattr(args, 'group_size') and hasattr(args, 'is_lowbit') and
                hasattr(args, 'open_outlier')):
            if args.is_lowbit and not args.open_outlier:
                config['group_size'] = args.group_size

        return config