import json
import math
import os
import sys
import re
import shutil
from argparse import ArgumentParser
from glob import glob
import numpy as np
import torch
import torch_npu
from safetensors.torch import load_file, save_file
from tqdm import tqdm
CUR_DIR = os.path.dirname(__file__)
ROOT_DIR = os.path.realpath(os.path.join(CUR_DIR, "../../../"))
sys.path.append(ROOT_DIR)
from mx_quantize import quantize_mx, pack_uint4, unpack_uint4, f4_unpacked_to_f32, f32_to_f4_unpacked
from convert_config import generate_quant_config, generate_ignore_item
NUM_BITS_4 = 4
NUM_BITS_8 = 8
DBL_MAX = 1.79769e308
def weight_dequant(weight: torch.Tensor, scale: torch.Tensor, block_size: int = 128, is_mx: bool=False) -> torch.Tensor:
"""
Dequantizes the given weight tensor using the provided scale tensor, efficiently handling cases where
`weight` is not a multiple of `block_size` by broadcasting `scale`.
Args:
weight (torch.Tensor): The quantized weight tensor of shape(M, N).
scale (torch.Tensor): The scale tensor of shape (M // block_size, N // block_size).
block_size (int, optional): The block size to use for dequantization. Defaults to 128.
Returns:
torch.Tensor: The dequantized weight tensor of the same shape as `weight`, converted to the default dtype.
Raises:
AssertionError: If `scale` dimensions do not align with `weight` shape after scaling.
"""
M, N = weight.shape
weight = weight.to(torch.float32)
scale = scale.to(torch.float32)
if is_mx:
scale_expanded = scale.repeat_interleave(block_size, dim=1)
else:
scale_m, scale_n = scale.shape
assert scale_m == (
M + block_size - 1) // block_size, "Mismatch in scale rows and weight rows."
assert scale_n == (
N + block_size - 1) // block_size, "Mismatch in scale columns and weight columns."
scale_expanded = scale.repeat_interleave(
block_size, dim=0).repeat_interleave(block_size, dim=1)
scale_expanded = scale_expanded[:M, :N]
dequantized_weight = weight * scale_expanded
dequantized_weight = dequantized_weight.to(torch.get_default_dtype())
return dequantized_weight
def per_channel_searching(x):
end_ratio = 1.3
start_ratio = 0.7
step = 0.01
min_val, max_val = x.min(1, keepdim=True)[0], x.max(1, keepdim=True)[0]
qmax = 2 ** (NUM_BITS_4 - 1) - 1
max_size = int((end_ratio - start_ratio) / step + 1)
smallest_loss = torch.zeros_like(min_val) + DBL_MAX
tmp_val = torch.zeros_like(max_val)
for i in range(max_size):
max_candidates = (start_ratio + i * step) * max_val
current_scale = max_candidates / qmax
qx = x.clamp(-max_candidates, max_candidates).mul(1.0 / current_scale).round().mul(current_scale)
current_loss = (x - qx).pow(2).sum(dim=1, keepdim=True)
mask = current_loss < smallest_loss
smallest_loss = torch.where(mask, current_loss, smallest_loss)
tmp_val = torch.where(mask, max_candidates, tmp_val)
scales = tmp_val / qmax
qx = x.clamp(-tmp_val, tmp_val).mul(1.0 / scales).round()
qx = torch.clamp(qx, -qmax, qmax)
return qx.cpu(), scales.cpu()
def unpack_mxfloat4_to_fp32(packed_tensor):
e2m1_values = torch.tensor([
0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
-0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0
], dtype=torch.float32, device=packed_tensor.device)
low_4bits = packed_tensor & 0x0F
high_4bits = (packed_tensor // 16) & 0x0F
unpacked = torch.stack([low_4bits, high_4bits], dim=-1)
fp32_tensor = e2m1_values[unpacked.long()]
new_shape = list(packed_tensor.shape)
new_shape[-1] = new_shape[-1] * 2
return fp32_tensor.view(*new_shape)
def int4_assistance_bias(weight, weight_scale):
"""
Calculate the int4 weight assistance bias matrix for W4A8 MoEGEMM.
"""
repeat_times = weight.shape[1] // weight_scale.shape[1]
expanded_scale = weight_scale.repeat_interleave(repeat_times, dim=1)
weight_assistant_matrix = (expanded_scale * weight * 8).sum(dim=1).float()
return weight_assistant_matrix
def scale_fp32_to_u64(weight_scale):
"""
Convert FP32 scale to UINT64 scale for W4A8 MoEGMM.
"""
k, n = weight_scale.shape
scale_np = weight_scale.float().cpu().numpy()
scale_uint32 = scale_np.astype(np.float32)
scale_uint32.dtype = np.uint32
scale_uint64 = np.zeros((k, n * 2), dtype=np.uint32)
scale_uint64[..., ::2] = scale_uint32
scale_uint64.dtype = np.uint64
scale_uint64 = torch.from_numpy(scale_uint64).to(torch.uint64)
return scale_uint64
def pack_4bit(x: torch.Tensor):
"""
Pack int4 weight for W4A8 MoEGMM. Each two int4 numbers are packed into one byte.
"""
assert x.dtype == torch.int8
x = x.T.contiguous()
shape = x.shape
x = x.view(-1, 2)
x1 = x[:, 0]
x2 = x[:, 1]
y_x2 = torch.bitwise_left_shift(x2, 4)
y_x1 = x1 & 0b00001111
y = torch.bitwise_or(y_x1, y_x2)
y = y.view(shape[0], shape[1] // 2)
return y.T.contiguous()
def int_weight_quant(tensor: torch.Tensor, bits=8, weight_clip_factor=None):
assert tensor.dim() == 2
if weight_clip_factor is not None:
cur_min, cur_max = tensor.min(1, keepdim=True)[0], tensor.max(1, keepdim=True)[0]
cur_min *= weight_clip_factor[0]
cur_max *= weight_clip_factor[1]
tensor = torch.clamp(tensor, min=cur_min, max=cur_max)
if bits == 4:
quantized, scale = per_channel_searching(tensor.npu())
quantized = quantized.to(torch.int8)
bias = int4_assistance_bias(quantized, scale)
quantized = pack_4bit(quantized)
scale = scale_fp32_to_u64(scale)
return quantized, scale, bias
qmax = 2 ** (bits - 1) - 1
abs_max = torch.abs(tensor).max(dim=1, keepdim=True)[0]
scale = abs_max / qmax
assert scale.shape == (tensor.shape[0], 1)
quantized = torch.round(tensor / scale)
quantized = torch.clamp(quantized, -qmax, qmax)
return quantized.to(torch.int8), scale.to(torch.float32), None
def generate_quant_layers(num_layers, num_experts, compress_ratios, w4a8=False, is_mx=False):
quant_layers = {}
moe_bit = NUM_BITS_4 if w4a8 else NUM_BITS_8
se_bit = NUM_BITS_8
attn_bit = NUM_BITS_8
mtp_bit = NUM_BITS_8
mlp_linears = ["w1", "w2", "w3"]
for i in range(num_layers):
ratio = compress_ratios[i]
for j in range(num_experts):
for n in mlp_linears:
quant_layers[f"layers.{i}.ffn.experts.{j}.{n}"] = moe_bit
for n in mlp_linears:
quant_layers[f"layers.{i}.ffn.shared_experts.{n}"] = se_bit
if is_mx:
quant_layers[f"layers.{i}.attn.wq_a"] = attn_bit
quant_layers[f"layers.{i}.attn.wkv"] = attn_bit
quant_layers[f"layers.{i}.attn.wo_a"] = attn_bit
quant_layers[f"layers.{i}.attn.wq_b"] = attn_bit
quant_layers[f"layers.{i}.attn.wo_b"] = attn_bit
if ratio == NUM_BITS_4:
quant_layers[f"layers.{i}.attn.indexer.wq_b"] = attn_bit
for j in range(num_experts):
for n in mlp_linears:
quant_layers[f"mtp.0.ffn.experts.{j}.{n}"] = moe_bit
for n in mlp_linears:
quant_layers[f"mtp.0.ffn.shared_experts.{n}"] = mtp_bit
if is_mx:
quant_layers[f"mtp.0.attn.wq_a"] = mtp_bit
quant_layers[f"mtp.0.attn.wkv"] = mtp_bit
quant_layers[f"mtp.0.attn.wo_a"] = mtp_bit
quant_layers[f"mtp.0.attn.wq_b"] = mtp_bit
quant_layers[f"mtp.0.attn.wo_b"] = mtp_bit
quant_layers[f"mtp.0.e_proj"] = mtp_bit
quant_layers[f"mtp.0.h_proj"] = mtp_bit
return quant_layers
def copy_py_json(src, target):
for root, _, files in os.walk(src):
for file in files:
if file.endswith(('.py', '.json', '.jinja')):
src_path = os.path.join(root, file)
rel_dir = os.path.relpath(root, src)
dst_dir = os.path.join(target, rel_dir)
os.makedirs(dst_dir, exist_ok=True)
dst_path = os.path.join(dst_dir, file)
shutil.copy2(src_path, dst_path)
def load_clip_params(num_layers, clip_param_path):
weight_clip_params = {}
for layer_idx in range(num_layers):
expected_file = os.path.join(clip_param_path, f"quant_params_{layer_idx}.pt")
if not os.path.exists(expected_file):
raise FileNotFoundError(expected_file)
else:
quant_params = torch.load(expected_file)
weight_clip_params.update(quant_params)
return weight_clip_params
def hif8_weight_quant(tensor: torch.Tensor, bits=8, weight_clip_factor=None):
tensor_npu = tensor.npu()
quantized_tensor, scale = torch_npu.npu_dynamic_quant(tensor_npu, dst_type=torch_npu.hifloat8)
quantized_tensor = quantized_tensor.cpu()
scale = scale.cpu()
scale = scale.unsqueeze(-1)
return quantized_tensor, scale
def main(fp8_path, output_path, quant_type, quant_param_path=None):
"""
Converts FP8 weights to BF16 and saves the converted weights.
This function reads FP8 weights from the specified directory, converts them to BF16,
and saves the converted weights to another specified directory. It also updates the
model index file to reflect the changes.
Args:
fp8_path (str): The path to the directory containing the FP8 weights and model index file.
output_path (str): The path to the directory where the converted BF16/INT8/MXFP4/8 weights will be saved.
quant_type (str): The type of quantization to apply. Supported values are "bfloat16",
"w8a8-int", "w8a8-mx", "w4a8-mx".
clip (bool, optional): Whether to apply clipping during quantization. Defaults to False.
quant_param_path (str, optional): The path to the directory containing quantization parameters.
w4a8 (bool): Quantize the MoE to W4A8.
Raises:
KeyError: If a required scale_inv tensor is missing for a weight.
Notes:
- The function assumes that the FP8 weights are stored in safetensor files.
- The function caches loaded safetensor files to optimize memory usage.
- The function updates the model index file to remove references to scale_inv tensors.
"""
torch.set_default_dtype(torch.bfloat16)
os.makedirs(output_path, exist_ok=True)
assert quant_type in [
"bfloat16", "w8a8-int", "w4a8-int", "w8a8-mx", "w4a8-mx", "w4a8-mx-hif"], f"Unsupported quant_type:{quant_type}"
model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
config_file = os.path.join(fp8_path, 'config.json')
with open(model_index_file, "r") as f:
model_index = json.load(f)
with open(config_file, "r") as f:
config = json.load(f)
weight_map = model_index["weight_map"]
new_weight_map = {}
num_layers = config['num_hidden_layers']
num_experts = config['n_routed_experts']
compress_ratios = config['compress_ratios']
quant_parts = quant_type.split('-')
w4a8 = quant_type.startswith("w4a8")
w8a8 = quant_type.startswith("w8a8")
mx = "mx" in quant_parts
hif = quant_type.endswith('hif')
if w8a8 or w4a8:
cache_scheme = {"kv_cache_scheme": {"num_bits": NUM_BITS_8, "type": "float"} if mx else None,
"li_cache_scheme": {
"type": "float" if mx else "int",
"num_bits": NUM_BITS_8,
}}
if mx and w8a8 and not hif:
config['quantization_config']["quant_method"] = "mxfp8"
config['quantization_config'].pop("weight_block_size")
config['quantization_config'].update(cache_scheme)
else:
if 'quantization_config' in config:
config.pop('quantization_config')
quant_ignore_layers = generate_ignore_item(num_layers, compress_ratios, is_fp=mx)
quantization_config = generate_quant_config(
cache_scheme, quant_ignore_layers, w4a8=w4a8, is_fp=mx)
config['quantization_config'] = quantization_config
quant_layers = generate_quant_layers(num_layers, num_experts, compress_ratios, w4a8=w4a8, is_mx=mx)
loaded_files = {}
def get_tensor(tensor_name):
"""
Retrieves a tensor from the cached safetensor files or loads it from disk if not cached.
Args:
tensor_name (str): The name of the tensor to retrieve.
Returns:
torch.Tensor: The retrieved tensor.
Raises:
KeyError: If the tensor does not exist in the safetensor file.
"""
file_name = weight_map[tensor_name]
if file_name not in loaded_files:
file_path = os.path.join(fp8_path, file_name)
loaded_files[file_name] = load_file(file_path, device="cpu")
return loaded_files[file_name][tensor_name]
is_clip = False
if quant_param_path is not None:
weight_clip_params = load_clip_params(num_layers, quant_param_path)
is_clip = True
safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
safetensor_files.sort()
for safetensor_file in tqdm(safetensor_files):
file_name = os.path.basename(safetensor_file)
current_state_dict = load_file(safetensor_file, device="cpu")
loaded_files[file_name] = current_state_dict
new_state_dict = {}
for weight_name, weight in current_state_dict.items():
if weight_name.endswith(".scale"):
continue
elif weight.element_size() == 1:
scale_inv_name = weight_name.replace('.weight', '.scale')
try:
scale_inv = get_tensor(scale_inv_name)
if weight.dtype == torch.int8:
weight = unpack_mxfloat4_to_fp32(weight.view(torch.uint8))
weight = weight_dequant(weight, scale_inv, block_size=32, is_mx=True)
else:
weight = weight_dequant(weight, scale_inv)
except KeyError:
print(
f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
new_state_dict[weight_name] = weight
new_weight_map[weight_name] = file_name
else:
new_state_dict[weight_name] = weight
new_weight_map[weight_name] = file_name
if w4a8 or w8a8:
new_weight_name = weight_name.rsplit(".", 1)[0]
if new_weight_name in list(quant_layers.keys()):
bit = quant_layers[new_weight_name]
if mx:
if hif and bit != NUM_BITS_4:
quant_weight, scale_inv = hif8_weight_quant(weight, bit)
else:
quant_weight, scale_inv = quantize_mx(weight, bit, real_quant=True)
if bit == NUM_BITS_4:
quant_weight = f32_to_f4_unpacked(quant_weight.float())
quant_weight = pack_uint4(quant_weight)
else:
if is_clip:
weight_clip_factor = [weight_clip_params.get(f"{new_weight_name}.clip_factor_min")]
weight_clip_factor.append(weight_clip_params.get(f"{new_weight_name}.clip_factor_max"))
weight_clip_factor = None if all(v is None for v in weight_clip_factor) \
else weight_clip_factor
else:
weight_clip_factor = None
quant_weight, scale_inv, bias = int_weight_quant(weight, bits=bit,
weight_clip_factor=weight_clip_factor)
if w4a8 and bias is not None:
bias_name = weight_name.replace(".weight", ".bias")
new_state_dict[bias_name] = bias
new_weight_map[bias_name] = file_name
new_scale_name = weight_name.replace('.weight', '.scale')
new_state_dict[weight_name] = quant_weight
new_state_dict[new_scale_name] = scale_inv
new_weight_map[weight_name] = file_name
new_weight_map[new_scale_name] = file_name
new_safetensor_file = os.path.join(output_path, file_name)
save_file(new_state_dict, new_safetensor_file,
metadata={'format': 'pt'})
if len(loaded_files) > 2:
oldest_file = next(iter(loaded_files))
del loaded_files[oldest_file]
copy_py_json(fp8_path, output_path)
new_model_index_file = os.path.join(
output_path, "model.safetensors.index.json")
new_config_file = os.path.join(output_path, "config.json")
with open(new_model_index_file, "w") as f:
json.dump({"metadata": {}, "weight_map": new_weight_map}, f, indent=2)
with open(new_config_file, "w") as f:
json.dump(config, f, indent=2)
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--input_fp8_hf_path", type=str, required=True)
parser.add_argument("--output_hf_path", type=str, required=True)
parser.add_argument("--quant_type", type=str, default="w8a8-int",
choices=["w8a8-int", "w4a8-int", "bfloat16", "w4a8-mx"])
parser.add_argument("--quant_param_path", type=str, default=None)
args = parser.parse_args()
main(args.input_fp8_hf_path, args.output_hf_path,
args.quant_type, args.quant_param_path)