import json
import os
from safetensors.torch import load_file
from modelslim.pytorch.weight_compression import CompressConfig, Compressor
# Example script: load a W8A8S quantized model (safetensors weights plus a JSON
# description), run modelslim's weight Compressor on it, and export the result.

# Read the project root once instead of re-querying the environment for every
# path; a missing PROJECT_PATH variable still raises KeyError, as before.
PROJECT_PATH = os.environ["PROJECT_PATH"]
LOAD_PATH = os.path.join(PROJECT_PATH, "resource", "weight_compression_safetensor")
# Single output directory shared by the debug-record root and the export call,
# so the two cannot silently drift apart.
OUTPUT_PATH = os.path.join(PROJECT_PATH, "output", "weight_compression_safetensor")

weight_path = os.path.join(LOAD_PATH, "quant_model_weight_w8a8s.safetensors")
json_path = os.path.join(LOAD_PATH, "quant_model_description_w8a8s.json")

compress_config = CompressConfig(
    do_pseudo_sparse=False,
    sparse_ratio=1,
    is_debug=True,
    # NOTE(review): presumably where per-run debug details are written when
    # is_debug=True — confirm against the modelslim documentation.
    record_detail_root=OUTPUT_PATH,
    multiprocess_num=8,
)

sparse_weight = load_file(weight_path)
# Explicit encoding so parsing the description does not depend on the
# platform's locale default.
with open(json_path, "r", encoding="utf-8") as f:
    quant_model_description = json.load(f)

compressor = Compressor(
    compress_config,
    weight=sparse_weight,
    quant_model_description=quant_model_description,
)
# run() returns three results; they are kept as module-level names even though
# this script does not use them further.
compress_weight, compress_index, compress_info = compressor.run()
# NOTE(review): passing None for both names presumably selects the library's
# default output file names — confirm.
compressor.export_safetensors(OUTPUT_PATH, safetensors_name=None, json_name=None)