import os
import shutil
import argparse
from dataclasses import dataclass
from typing import List, Optional, Any
from example.common.security.path import json_safe_load, json_safe_dump
from example.common.security.path import get_valid_read_path, get_valid_write_path
from example.common.utils import SafeGenerator
@dataclass
class ModifyConfigParams:
    """Argument bundle consumed by ``VlmSafeGenerator.modify_config``.

    Attributes:
        model_dir: Directory holding the source ``config.json``.
        dest_dir: Directory that receives the rewritten ``config.json``.
        torch_dtype: Value whose ``str()`` form is split on ``'.'`` to obtain
            the dtype name stored in the config.
        quantize_type: Quantization type label written into the config.
        args: Optional object carrying quantization options as attributes.
        quantize_config_parts: Optional names of config sub-sections that
            also receive the quantization type; ``None`` becomes ``[]``.
    """

    model_dir: str
    dest_dir: str
    torch_dtype: Any
    quantize_type: str
    args: Optional[Any] = None
    quantize_config_parts: Optional[List[str]] = None

    def __post_init__(self):
        """Normalize ``quantize_config_parts`` so it is always a list."""
        if self.quantize_config_parts is not None:
            return
        # A fresh list per instance; nothing is shared between objects.
        self.quantize_config_parts = []
@dataclass
class CopyTokenizerParams:
    """Argument bundle consumed by ``VlmSafeGenerator.copy_tokenizer_files``.

    Attributes:
        model_dir: Directory the files are copied from.
        dest_dir: Directory the files are copied into.
    """

    model_dir: str
    dest_dir: str
class VlmSafeGenerator(SafeGenerator):
    """Static helpers that finalize the output directory of a quantized model.

    ``modify_config`` rewrites ``config.json`` (and the quant description
    file) with quantization metadata; ``copy_tokenizer_files`` copies the
    remaining ``.json``/``.py`` files from the source model directory.
    """

    # Fallback value for each key when the ``args`` object lacks that attribute.
    DEFAULT_QUANTIZATION_CONFIG = {
        'group_size': 0,
        'act_method': 2,
        'anti_method': 'm2',
        'is_lowbit': False,
        'mm_tensor': False,
        'w_sym': True,
        'open_outlier': True,
        'is_dynamic': False,
    }
    # Only files with these extensions are copied by ``copy_tokenizer_files``.
    SUPPORTED_EXTENSIONS = {'.json', '.py'}
    # Never copied, even though their extension is supported.
    EXCLUDED_FILES = {'config.json', 'model.safetensors.index.json'}
    # Upper bound on the number of entries allowed in the source directory.
    MAX_FILE_NUM = 1024

    @staticmethod
    def modify_config(params: ModifyConfigParams):
        """Write the quantization-aware ``config.json`` into ``params.dest_dir``.

        Loads ``config.json`` from ``params.model_dir`` and the quant
        description file already present in ``params.dest_dir``, records the
        dtype name and quantization settings, then saves ``config.json`` to
        the destination. When a quantization config is built from
        ``params.args`` it is merged into the quant description file, which
        is then saved back in place.

        Args:
            params: Source/destination directories, dtype, quantization type
                and the optional ``args`` object with quantization options.

        Raises:
            KeyError: If a name in ``params.quantize_config_parts`` is not a
                key of the loaded config (only reached in ``mindie_format``
                mode).
        """
        model_dir = get_valid_read_path(params.model_dir, is_dir=True, check_user_stat=True)
        dest_dir = get_valid_write_path(params.dest_dir, is_dir=True)
        args = params.args
        # Evaluate the flag once; ``getattr`` on ``None`` yields the default.
        mindie_format = bool(getattr(args, 'mindie_format', False))

        data = json_safe_load(os.path.join(model_dir, 'config.json'), check_user_stat=True)

        quant_description_filepath = get_valid_write_path(
            VlmSafeGenerator._get_quantization_filename(dest_dir, params.quantize_type, mindie_format),
            is_dir=False,
        )
        quant_description_data = json_safe_load(quant_description_filepath, check_user_stat=True)

        # "torch.float16" -> "float16"; a string without a dot is kept as is
        # instead of raising IndexError.
        dtype_parts = str(params.torch_dtype).split('.')
        data['torch_dtype'] = dtype_parts[1] if len(dtype_parts) > 1 else dtype_parts[0]

        if mindie_format:
            data['quantize'] = params.quantize_type
            for config_part in params.quantize_config_parts or ():
                data[config_part]['quantize'] = params.quantize_type

        # Empty only when ``args`` is None.
        quantization_config = VlmSafeGenerator._build_quantization_config(args)
        if quantization_config:
            quant_description_data.update(quantization_config)
            if mindie_format:
                data['quantization_config'] = quantization_config

        json_safe_dump(data, os.path.join(dest_dir, 'config.json'), 4)
        if quantization_config:
            # Persist the merged description; previously the update above was
            # computed and then silently discarded.
            json_safe_dump(quant_description_data, quant_description_filepath, 4)

    @staticmethod
    def copy_tokenizer_files(params: CopyTokenizerParams):
        """Copy the auxiliary ``.json``/``.py`` files to ``params.dest_dir``.

        The destination directory is created (mode ``0o750``) when missing.
        Every entry of ``params.model_dir`` whose extension is in
        ``SUPPORTED_EXTENSIONS`` and whose name is not in ``EXCLUDED_FILES``
        is copied, and the copy's permissions are set to ``0o600``.

        Args:
            params: Source and destination directories.

        Raises:
            argparse.ArgumentTypeError: If the source directory holds more
                than ``MAX_FILE_NUM`` entries.
        """
        model_dir = get_valid_read_path(params.model_dir, is_dir=True, check_user_stat=True)
        if not os.path.exists(params.dest_dir):
            os.makedirs(params.dest_dir, mode=0o750, exist_ok=True)
        dest_dir = get_valid_write_path(params.dest_dir, is_dir=True)

        filenames = os.listdir(model_dir)
        if len(filenames) > VlmSafeGenerator.MAX_FILE_NUM:
            raise argparse.ArgumentTypeError(
                f"The file num in dir is {len(filenames)}, "
                f"which exceeds the limit {VlmSafeGenerator.MAX_FILE_NUM}."
            )

        for filename in filenames:
            if filename in VlmSafeGenerator.EXCLUDED_FILES:
                continue
            if os.path.splitext(filename)[1] not in VlmSafeGenerator.SUPPORTED_EXTENSIONS:
                continue
            dest_filepath = os.path.join(dest_dir, filename)
            shutil.copyfile(os.path.join(model_dir, filename), dest_filepath)
            # Restrict the copy to owner read/write.
            os.chmod(dest_filepath, 0o600)

    @staticmethod
    def _get_quantization_filename(dest_dir, quantize_type, mindie_format):
        """Return the path of the quant description file inside ``dest_dir``.

        In ``mindie_format`` mode the lower-cased quantization type is
        appended to the file name; otherwise the plain name is used.
        """
        if mindie_format:
            filename = f"quant_model_description_{quantize_type.lower()}.json"
        else:
            filename = "quant_model_description.json"
        return os.path.join(dest_dir, filename)

    @staticmethod
    def _build_quantization_config(args):
        """Build the quantization config dict from the attributes of ``args``.

        Args:
            args: Object exposing quantization options as attributes, or
                ``None``.

        Returns:
            ``{}`` when ``args`` is ``None``. Otherwise a dict with every key
            of ``DEFAULT_QUANTIZATION_CONFIG`` (taken from ``args`` when
            present, else the default), plus ``w_bit``, ``a_bit`` and
            ``dev_type`` (from ``args.device_type``) when ``args`` defines
            them. ``group_size`` is taken from ``args`` only when
            ``is_lowbit`` is truthy and ``open_outlier`` is falsy; in every
            other case it keeps the default.
        """
        if args is None:
            return {}

        defaults = VlmSafeGenerator.DEFAULT_QUANTIZATION_CONFIG
        config = {key: getattr(args, key, default) for key, default in defaults.items()}

        # The group size only applies to low-bit quantization without outlier
        # handling. The original copied it unconditionally, which made its
        # own ``is_lowbit and not open_outlier`` branch a no-op.
        if not (config['is_lowbit'] and not config['open_outlier']):
            config['group_size'] = defaults['group_size']

        # Optional attributes, stored under the given config key when present.
        optional_attrs = (('w_bit', 'w_bit'), ('a_bit', 'a_bit'), ('device_type', 'dev_type'))
        for attr, config_key in optional_attrs:
            if hasattr(args, attr):
                config[config_key] = getattr(args, attr)
        return config