"""
-------------------------------------------------------------------------
This file is part of the MindStudio project.
Copyright (c) 2025 Huawei Technologies Co.,Ltd.
MindStudio is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
-------------------------------------------------------------------------
"""
import os
import sys
import argparse
import json
from typing import List
import numpy as np
import torch
from safetensors.torch import load_file, save_file
current_directory = os.path.dirname(os.path.abspath(__file__))
parent_directory = os.path.abspath(os.path.join(current_directory, '..'))
sys.path.append(parent_directory)
from example.common.security.path import get_valid_write_path, SafeWriteUmask, get_valid_read_path
# Target toolkits whose packed-weight layout this script can emit.
TOOL_AWQ = 'awq'
TOOL_GPTQ = 'gptq'

# Quantization types (values of the description JSON) whose tensors are
# repacked; tensors of any other type are copied through unchanged.
QUANT_TYPE = ['W4A16', 'W8A16']

# Width, in bits, of the integer word that several low-bit values share.
STORAGE_BITS = 32

# Per-group element orders applied before packing (8 entries = one group of
# 4-bit values in a 32-bit word).
ORDINAL_PACK_ORDER = list(range(8))          # identity order
AWQ_PACK_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]    # AWQ interleave: even slots, then odd
def awq_pack(iweight: torch.Tensor, w_bit: int, direction: str = "column"):
    """Pack groups of low-bit integers into int32 words (AWQ-style packing).

    Every group of ``pack_num = STORAGE_BITS // w_bit`` consecutive elements
    along the packing direction is folded into one 32-bit word; element ``k``
    of a group occupies bits ``[k * w_bit, (k + 1) * w_bit)``.  Any reordering
    inside a group (e.g. the AWQ interleave) must be applied beforehand, see
    ``apply_order``.

    Args:
        iweight: 2-D integer-valued tensor.  Only the low ``w_bit`` bits of
            each element are kept.
        w_bit: Bits per element; must divide ``STORAGE_BITS``.
        direction: ``"column"`` packs along dim 1 (the packed dim shrinks to
            ``shape[1] // pack_num``); ``"row"`` packs along dim 0.

    Returns:
        ``torch.int32`` tensor of packed words.

    Raises:
        ValueError: if ``w_bit`` does not divide ``STORAGE_BITS`` or
            ``direction`` is neither ``"column"`` nor ``"row"``.
    """
    if w_bit <= 0 or STORAGE_BITS % w_bit != 0:
        raise ValueError(f"w_bit must be a positive divisor of {STORAGE_BITS}, got {w_bit}")
    pack_num = STORAGE_BITS // w_bit
    shifts = torch.arange(0, STORAGE_BITS, w_bit, device=iweight.device)
    # Widen before masking: an int8 intermediate cannot hold 8-bit values above
    # 127, and a fixed 0x0F mask would discard the high nibble when w_bit == 8.
    mask = (1 << w_bit) - 1
    iweight = torch.bitwise_and(iweight.to(torch.int32), mask)
    if direction == "column":
        grouped = iweight.view(-1, iweight.shape[1] // pack_num, pack_num)
        qmatrix = torch.bitwise_left_shift(grouped, shifts[None, None, :]).sum(dim=-1)
    elif direction == "row":
        grouped = iweight.view(iweight.shape[0] // pack_num, pack_num, -1)
        qmatrix = torch.bitwise_left_shift(grouped, shifts[None, :, None]).sum(dim=1)
    else:
        raise ValueError(f"direction must be 'column' or 'row', got {direction!r}")
    # The shifted fields do not overlap, so the sum equals their bitwise OR.
    # It is computed in int64 (shifts are int64); narrowing to int32 keeps the
    # 32-bit pattern, so words with the top bit set come out negative.
    return qmatrix.to(torch.int32)
def apply_order(
    iweight: torch.Tensor,
    w_bit: int,
    direction: str = "column",
    order: List[int] = None,
):
    """Permute the elements inside every packing group before ``awq_pack``.

    A packing group is ``pack_num = STORAGE_BITS // w_bit`` consecutive
    elements along the packing direction, the same grouping ``awq_pack`` uses.
    Output slot ``k`` of each group receives source element ``order[k]`` of
    that group.

    Args:
        iweight: Tensor to reorder; its extent along the packing direction must
            be a multiple of ``pack_num``.
        w_bit: Bits per element, which fixes ``pack_num``.
        direction: ``"column"`` groups consecutive elements of the row-major
            flattened tensor (i.e. along the last dim); ``"row"`` groups
            consecutive rows (dim 0).
        order: Permutation of ``range(pack_num)``; ``None`` means no reordering.

    Returns:
        Tensor with the same shape as ``iweight``.

    Raises:
        IndexError: if ``order`` contains an index ``>= pack_num``.
        ValueError: if ``direction`` is neither ``"column"`` nor ``"row"``.
    """
    pack_num = STORAGE_BITS // w_bit
    if direction not in ("column", "row"):
        raise ValueError(f"direction must be 'column' or 'row', got {direction!r}")
    if order is None:
        return iweight
    try:
        if direction == "column":
            reordered = iweight.reshape(-1, pack_num)[:, order]
        else:
            # Group pack_num consecutive rows (matching awq_pack's row packing)
            # and permute inside each group, not across the whole tensor.
            reordered = iweight.reshape(-1, pack_num, *iweight.shape[1:])[:, order]
    except IndexError as ide:
        raise IndexError(f"Order index {order} out of range for pack_num {pack_num}. "
                         f"Order indices must be < {pack_num} for w_bit={w_bit}") from ide
    return reordered.reshape(iweight.shape)
def gptq_qweight_pack(iweight: torch.Tensor, w_bit: int):
    """Pack unsigned low-bit weights into the GPTQ int32 ``qweight`` layout.

    Every ``pack_num = STORAGE_BITS // w_bit`` consecutive rows are folded into
    one int32 row: row ``k`` of a group lands in bits
    ``[k * w_bit, (k + 1) * w_bit)`` of every column.

    Args:
        iweight: Integer-valued tensor with at least 2 dims whose values are
            expected in ``[0, 2**w_bit)``; packing runs along dim 0.  Only the
            low ``w_bit`` bits of each value are kept.
        w_bit: Bits per value, 4 or 8.

    Returns:
        ``torch.int32`` tensor of shape ``(rows // pack_num, *iweight.shape[1:])``.

    Raises:
        ValueError: for an unsupported ``w_bit``, fewer than 2 dims, or a row
            count that is not a multiple of ``pack_num``.
    """
    if w_bit not in (4, 8):
        # Checked up front: an unsupported width would otherwise never finish.
        raise ValueError(f"Unsupported w_bit {w_bit} for GPTQ packing; expected 4 or 8.")
    pack_num = STORAGE_BITS // w_bit
    values = iweight.detach().cpu().numpy()
    if values.ndim < 2:
        raise ValueError("Expected qweight to have at least 2 dimensions, but got shape: {}".format(values.shape))
    rows = values.shape[0]
    if rows % pack_num != 0:
        raise ValueError(f"qweight has {rows} rows, which is not a multiple of {pack_num} "
                         f"(required for w_bit={w_bit}).")
    # Mask so an out-of-range value cannot spill into its neighbour's bit field.
    values = values.astype(np.int64) & ((1 << w_bit) - 1)
    grouped = values.astype(np.uint32).reshape(rows // pack_num, pack_num, *values.shape[1:])
    shifts = (np.arange(pack_num, dtype=np.uint32) * w_bit).reshape(
        (1, pack_num) + (1,) * (values.ndim - 1)
    )
    packed = np.bitwise_or.reduce(grouped << shifts, axis=1)
    # uint32 -> int32 keeps the bit pattern (values >= 2**31 become negative).
    return torch.from_numpy(packed.astype(np.int32))
def gptq_qzeros_pack(zeros: torch.Tensor, w_bit: int):
    """Pack zero points into the GPTQ int32 ``qzeros`` layout.

    GPTQ stores ``zero - 1``.  After that adjustment, every
    ``pack_num = STORAGE_BITS // w_bit`` consecutive columns are folded into
    one int32 column: column ``k`` of a group lands in bits
    ``[k * w_bit, (k + 1) * w_bit)``.

    Args:
        zeros: Tensor with at least 2 dims holding unsigned zero points;
            packing runs along dim 1.  Not modified.
        w_bit: Bits per value, 4 or 8.

    Returns:
        ``torch.int32`` tensor whose dim 1 shrinks to ``cols // pack_num``.

    Raises:
        ValueError: for an unsupported ``w_bit``, fewer than 2 dims, or a
            column count that is not a multiple of ``pack_num``.
    """
    if w_bit not in (4, 8):
        # Checked up front: an unsupported width would otherwise never finish.
        raise ValueError(f"Unsupported w_bit {w_bit} for GPTQ packing; expected 4 or 8.")
    pack_num = STORAGE_BITS // w_bit
    if zeros.dtype == torch.bfloat16:
        # NumPy has no bfloat16; float16 holds every integer up to 2048 exactly.
        zeros = zeros.to(torch.float16)
    # Out-of-place so the caller's tensor is left untouched.
    zeros = zeros - 1
    values = zeros.detach().cpu().numpy()
    if values.ndim < 2:
        raise ValueError("Expected zeros to have at least 2 dimensions, but got shape: {}".format(values.shape))
    cols = values.shape[1]
    if cols % pack_num != 0:
        raise ValueError(f"zeros has {cols} columns, which is not a multiple of {pack_num} "
                         f"(required for w_bit={w_bit}).")
    # Convert through int64 (well defined for negative floats) and mask. A zero
    # point of 0 becomes -1 above; unmasked it would fill all 32 bits and
    # corrupt every neighbouring field in the word.
    values = values.astype(np.int64) & ((1 << w_bit) - 1)
    grouped = values.astype(np.uint32).reshape(
        values.shape[0], cols // pack_num, pack_num, *values.shape[2:]
    )
    shifts = (np.arange(pack_num, dtype=np.uint32) * w_bit).reshape(
        (pack_num,) + (1,) * (values.ndim - 2)
    )
    packed = np.bitwise_or.reduce(grouped << shifts, axis=2)
    # uint32 -> int32 keeps the bit pattern (values >= 2**31 become negative).
    return torch.from_numpy(packed.astype(np.int32))
def _shift_to_unsigned(tensor: torch.Tensor, w_bit: int) -> torch.Tensor:
    """Clamp to the signed ``w_bit`` range, then offset into ``[0, 2**w_bit)``."""
    half_range = 2 ** (w_bit - 1)
    tensor = torch.clamp(tensor, -half_range, half_range - 1)
    if w_bit == 8:
        # Widen so the +128 offset cannot overflow a narrow integer dtype.
        tensor = tensor.to(torch.int32)
    return tensor + half_range


def convert_ms_to_vllm(target_tool, w_bit, weight_dict, json_dict):
    """Convert quantized weights into the vLLM AWQ / GPTQ packed layout.

    Only names listed in ``json_dict`` that also exist in ``weight_dict`` are
    emitted.  For quantization types in ``QUANT_TYPE``:

    * ``*.weight`` -> ``*.qweight``: transposed, offset to unsigned, packed.
      Skipped when a matching ``*.module.weight`` tensor exists.
    * ``*.weight_scale`` -> ``*.scales``: transposed.
    * ``*.weight_offset`` -> ``*.qzeros``: transposed, offset to unsigned, packed.
    * ``*module.weight`` outside ``model.norm`` is renamed to ``*weight``.
    * ``model.norm.module.{weight,bias}`` are dropped.
    * anything else is copied unchanged.

    Tensors of any other quantization type are copied unchanged.

    Args:
        target_tool: ``TOOL_AWQ`` or ``TOOL_GPTQ``; selects the packing layout.
        w_bit: Quantized weight width in bits (4 or 8).
        weight_dict: Mapping of tensor name to tensor.
        json_dict: Mapping of tensor name to quantization type string.

    Returns:
        New dict of output tensor name to tensor.

    Raises:
        ValueError: if ``target_tool`` is not a supported tool.
    """
    if target_tool not in (TOOL_AWQ, TOOL_GPTQ):
        raise ValueError(
            f"Unsupported target_tool {target_tool!r}; expected {TOOL_AWQ!r} or {TOOL_GPTQ!r}."
        )
    order = AWQ_PACK_ORDER
    direction = 'column'
    vllm_weight_dict = {}
    for name, quant_type in json_dict.items():
        if name not in weight_dict:
            continue
        tensor = weight_dict[name]
        if quant_type not in QUANT_TYPE:
            vllm_weight_dict[name] = tensor
            continue

        if name.endswith('.weight') and '.module.weight' not in name:
            base_name = name.rsplit('.', 1)[0]
            # NOTE(review): when a "<base>.module.weight" tensor exists, it is
            # emitted (renamed to "<base>.weight" by the branch below) instead
            # of packing this one; confirm this is the intended precedence.
            if base_name + '.module.weight' in weight_dict:
                continue
            # Transposed before packing; presumably (out, in) -> (in, out) —
            # confirm against the source checkpoint layout.
            iweights = _shift_to_unsigned(tensor.t().contiguous(), w_bit)
            if target_tool == TOOL_AWQ:
                qweight = awq_pack(apply_order(iweights, w_bit, direction, order), w_bit)
            else:
                qweight = gptq_qweight_pack(iweights, w_bit)
            vllm_weight_dict[base_name + '.qweight'] = qweight
        elif name.endswith('.weight_scale'):
            vllm_weight_dict[name.rsplit('.', 1)[0] + '.scales'] = tensor.t().contiguous()
        elif name.endswith('.weight_offset'):
            izeros = _shift_to_unsigned(tensor.t().contiguous(), w_bit)
            if target_tool == TOOL_AWQ:
                qzeros = awq_pack(apply_order(izeros, w_bit, direction, order), w_bit, direction)
            else:
                qzeros = gptq_qzeros_pack(izeros, w_bit)
            vllm_weight_dict[name.rsplit('.', 1)[0] + '.qzeros'] = qzeros
        elif 'module.weight' in name and 'model.norm' not in name:
            vllm_weight_dict[name.replace('module.weight', 'weight')] = tensor
        elif 'model.norm.module.bias' in name or 'model.norm.module.weight' in name:
            # Deliberately omitted from the output (reason not visible here).
            pass
        else:
            vllm_weight_dict[name] = tensor
    return vllm_weight_dict
def load_json_info(json_file_path):
    """Read the quantization description JSON at *json_file_path* (UTF-8) and return the parsed object."""
    with open(json_file_path, encoding='utf-8') as json_file:
        return json.load(json_file)
def check_w_bit(value):
    """argparse ``type=`` callback for ``--w_bit``: accept only 4 or 8.

    Non-numeric input raises ``ValueError`` from ``int()``; any other integer
    raises ``argparse.ArgumentTypeError``.
    """
    bits = int(value)
    if bits in (4, 8):
        return bits
    raise argparse.ArgumentTypeError(f"Invalid w_bit value: {value}. Supported values are 4 and 8.")
def check_target_tool(value):
    """argparse ``type=`` callback for ``--target_tool``: accept exactly "awq" or "gptq"."""
    supported_tools = ("awq", "gptq")
    if value in supported_tools:
        return value
    raise argparse.ArgumentTypeError(f"Invalid target_tool value: {value}. Supported values are 'awq' and 'gptq'.")
def main():
    """Command-line entry point: convert a quantized safetensors file for vLLM."""
    parser = argparse.ArgumentParser(
        description="Convert quantized safetensors weights to the vLLM AWQ/GPTQ packed format."
    )
    parser.add_argument("--model", type=str, required=True, help="Quantized safetensors file path")
    parser.add_argument("--json", type=str, required=True, help="Quantized description file path")
    parser.add_argument("--save_path", type=str,
                        default='res.safetensors',
                        help="The path to save converted quant weights")
    parser.add_argument("--w_bit", type=check_w_bit, default=4, help="Quantized weight bits")
    parser.add_argument(
        "--target_tool", type=check_target_tool, default="awq", help="target tool, value include awq and gptq"
    )
    args = parser.parse_args()

    # Validate every path before loading the (potentially large) weight file,
    # so a bad --save_path fails fast instead of after the conversion.
    # NOTE(review): size_max=0 presumably lifts the size cap for large weight
    # files — confirm in get_valid_read_path.
    model_path = get_valid_read_path(args.model, size_max=0)
    json_path = get_valid_read_path(args.json)
    save_path = get_valid_write_path(args.save_path)

    tensor_info = load_file(model_path)
    json_info = load_json_info(json_path)
    vllm_weight = convert_ms_to_vllm(args.target_tool, args.w_bit, weight_dict=tensor_info, json_dict=json_info)
    # umask 0o377 presumably leaves the output owner-read-only — confirm in SafeWriteUmask.
    with SafeWriteUmask(umask=0o377):
        save_file(vllm_weight, save_path)


if __name__ == "__main__":
    main()