import os
import sys
import shutil
import argparse
import torch
import torch_npu
from safetensors import safe_open
from safetensors.torch import save_file
from msmodelslim import logger
# Put the parent of this script's directory on sys.path before the
# `example.common...` import that follows (presumably so that package
# resolves when this file is run directly as a script -- confirm layout).
current_directory = os.path.dirname(os.path.abspath(__file__))
parent_directory = os.path.abspath(os.path.join(current_directory, ".."))
sys.path.append(parent_directory)
from example.common.security.path import get_valid_read_path, get_write_directory, get_valid_write_path
# File extensions that copy_config_files() copies from the model directory
# into the save directory; every other entry is skipped.
SUPPORTED_EXTENSIONS = {'.json', '.py'}
# Maximum number of directory entries copy_config_files() accepts before
# raising ValueError.
MAX_FILE_NUM = 1024
def parse_args():
    """
    Parse the command-line arguments of this script.

    Both options are mandatory: the script cannot do anything useful without
    a source and a destination, so argparse reports a usage error (exit
    status 2) instead of letting ``None`` reach the path validation helpers.

    Returns:
        argparse.Namespace: Namespace with the string attributes
            ``model_path`` and ``save_directory``.
    """
    parser = argparse.ArgumentParser(description="Creating quant weights")
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        # The entry point validates this path as a directory, not a file.
        help="Directory containing the quantized safetensors files",
    )
    parser.add_argument(
        "--save_directory",
        type=str,
        required=True,
        help="The path to save processed quant weights",
    )
    return parser.parse_args()
def cast_deq_scale_to_int64(tensor: torch.Tensor) -> torch.Tensor:
    """
    Run a dequantization-scale tensor through ``torch_npu.npu_trans_quant_param``.

    The tensor is moved to the NPU, transformed there, and the result is
    moved back to host memory.

    Args:
        tensor (torch.Tensor): Scale tensor to convert. The caller in this
            file only passes float32 tensors whose key contains "deq_scale".

    Returns:
        torch.Tensor: The transformed tensor, located on the CPU. Per the
            function name the result is expected to be int64 -- the dtype is
            decided by ``npu_trans_quant_param``, not by this wrapper.
    """
    on_device = tensor.npu()
    transformed = torch_npu.npu_trans_quant_param(on_device)
    return transformed.cpu()
def process_safetensors_file(file_path: str, save_path: str):
    """
    Rewrite one safetensors file with its float32 ``deq_scale`` tensors converted.

    Every tensor in ``file_path`` is loaded on the CPU. Tensors whose key
    contains "deq_scale" and whose dtype is float32 are passed through
    ``cast_deq_scale_to_int64``; all other tensors are kept as loaded. The
    resulting tensors, together with the file's metadata, are written to
    ``save_path``.

    Args:
        file_path (str): Path of the safetensors file to read.
        save_path (str): Path of the safetensors file to write.

    Raises:
        RuntimeError: If reading, converting or saving fails; the original
            exception is chained as the cause.
    """
    converted = {}
    file_metadata = {}
    try:
        with safe_open(file_path, framework="pt", device="cpu") as handle:
            names = handle.keys()
            if hasattr(handle, 'metadata'):
                file_metadata = handle.metadata()
            for name in names:
                value = handle.get_tensor(name)
                # Only float32 dequant scales are converted; everything else
                # is copied through untouched.
                is_float_deq_scale = "deq_scale" in name and value.dtype == torch.float32
                converted[name] = cast_deq_scale_to_int64(value) if is_float_deq_scale else value
        save_file(converted, save_path, metadata=file_metadata)
    except Exception as e:
        raise RuntimeError(f"Error processing {file_path}: {e}") from e
def copy_config_files(model_path: str, save_directory: str):
    """
    Copy the configuration files of a model directory into the save directory.

    Only entries whose extension is listed in ``SUPPORTED_EXTENSIONS`` are
    copied. Source and destination paths go through the project's
    ``get_valid_read_path`` / ``get_valid_write_path`` helpers, and each copied
    file has its mode set to 0o600.

    Args:
        model_path (str): Directory whose entries are scanned.
        save_directory (str): Directory that receives the copies.

    Raises:
        ValueError: If ``model_path`` holds more than ``MAX_FILE_NUM`` entries.
    """
    entries = os.listdir(model_path)
    if len(entries) > MAX_FILE_NUM:
        raise ValueError(
            f"The file num in dir is {len(entries)}, which exceeds the limit {MAX_FILE_NUM}."
        )

    for entry in entries:
        name = os.path.basename(entry)
        suffix = os.path.splitext(name)[1]
        if suffix in SUPPORTED_EXTENSIONS:
            source = get_valid_read_path(os.path.join(model_path, name))
            target = get_valid_write_path(os.path.join(save_directory, name))
            shutil.copyfile(source, target)
            # Restrict the copy to owner read/write.
            os.chmod(target, 0o600)
def process_safetensors(model_path: str, save_directory: str):
    """
    Convert every ``.safetensors`` file of a model directory and copy its configs.

    Each entry of ``model_path`` whose name ends with ".safetensors" is handed
    to ``process_safetensors_file`` and written under the same file name in
    ``save_directory``. After all weight files are processed,
    ``copy_config_files`` copies the configuration files.

    Args:
        model_path (str): Directory containing the ``.safetensors`` files.
        save_directory (str): Directory receiving the processed files and configs.

    Raises:
        RuntimeError: If ``model_path`` contains no ``.safetensors`` entry.
    """
    suffix = ".safetensors"
    names = [os.path.basename(entry) for entry in os.listdir(model_path)]
    # (input path, output path) pairs, in directory-listing order.
    jobs = [
        (os.path.join(model_path, name), os.path.join(save_directory, name))
        for name in names
        if name.endswith(suffix)
    ]
    if not jobs:
        raise RuntimeError(f"No safetensors files found in: {model_path}")

    logger.info(f"Found {len(jobs)} safetensors files to process")
    for source, target in jobs:
        logger.info(f"Processing: {source}")
        process_safetensors_file(source, target)
    copy_config_files(model_path, save_directory)
if __name__ == "__main__":
    # Script entry point: validate both paths, then convert the weights.
    args = parse_args()
    # The source must be an existing readable directory; the destination is
    # resolved through the project's write-directory helper with mode 0o750.
    args.model_path = get_valid_read_path(args.model_path, is_dir=True, check_user_stat=True)
    args.save_directory = get_write_directory(args.save_directory, write_mode=0o750)
    try:
        process_safetensors(args.model_path, args.save_directory)
    except Exception as e:
        logger.error(f"Process weights failed. Error detail: {e}")
        # Exit non-zero so calling shells and pipelines can detect the failure
        # instead of seeing a successful status after a logged error.
        sys.exit(1)
    else:
        logger.info("Processed weights saved successfully.")