"""hf fp8 safetensors cast to hf bf16 safetensors"""
import os
import json
from argparse import ArgumentParser
from glob import glob
from safetensors.torch import load_file, save_file
from tqdm import tqdm
import torch
from mindformers.tools.utils import set_safe_mode_for_file_or_dir
def weight_dequant(weight: torch.Tensor, scale: torch.Tensor, block_size: int = 128) -> torch.Tensor:
"""
Dequantizes the given weight tensor using the provided scale tensor.
"""
row, col = weight.shape
scale_m, scale_n = scale.shape
assert scale_m == (row + block_size - 1) // block_size, "Mismatch in scale rows and weight rows."
assert scale_n == (col + block_size - 1) // block_size, "Mismatch in scale columns and weight columns."
weight = weight.to(torch.float32)
scale_expanded = scale.repeat_interleave(block_size, dim=0).repeat_interleave(block_size, dim=1)
scale_expanded = scale_expanded[:row, :col]
dequantized_weight = weight * scale_expanded
return dequantized_weight.to(torch.get_default_dtype())
def load_model_index(fp8_path):
"""Load the model index JSON file."""
model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
with open(model_index_file, "r") as f:
model_index = json.load(f)
return model_index["weight_map"]
def get_tensor_loader(fp8_path, weight_map):
"""Closure-based loader to fetch tensors from the correct safetensor file."""
loaded_files = {}
def get_tensor(tensor_name):
file_name = weight_map[tensor_name]
if file_name not in loaded_files:
file_path = os.path.join(fp8_path, file_name)
loaded_files[file_name] = load_file(file_path, device="cpu")
return loaded_files[file_name][tensor_name]
return get_tensor, loaded_files
def convert_single_file(
safetensor_file,
bf16_path,
get_tensor_func,
loaded_files,
fp8_weight_names,
progress_bar=None
):
"""Convert a single .safetensors file and save the result."""
file_name = os.path.basename(safetensor_file)
current_state_dict = load_file(safetensor_file, device="cpu")
loaded_files[file_name] = current_state_dict
new_state_dict = {}
for weight_name, weight in current_state_dict.items():
if weight_name.endswith("_scale_inv"):
continue
elif weight.element_size() == 1:
scale_inv_name = f"{weight_name}_scale_inv"
try:
scale_inv = get_tensor_func(scale_inv_name)
fp8_weight_names.append(weight_name)
new_state_dict[weight_name] = weight_dequant(weight, scale_inv)
except KeyError:
print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
new_state_dict[weight_name] = weight
else:
new_state_dict[weight_name] = weight
new_safetensor_file = os.path.join(bf16_path, file_name)
save_file(new_state_dict, new_safetensor_file)
if len(loaded_files) > 2:
oldest_file = next(iter(loaded_files))
del loaded_files[oldest_file]
if progress_bar:
progress_bar.update(1)
def update_model_index(bf16_path, weight_map, fp8_weight_names):
"""Update and save the model index by removing _scale_inv entries."""
for weight_name in fp8_weight_names:
scale_inv_name = f"{weight_name}_scale_inv"
if scale_inv_name in weight_map:
weight_map.pop(scale_inv_name)
new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
with open(new_model_index_file, "w") as f:
json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)
set_safe_mode_for_file_or_dir(new_model_index_file)
def main(fp8_path, bf16_path):
"""Main function that orchestrates the entire conversion process."""
torch.set_default_dtype(torch.bfloat16)
os.makedirs(bf16_path, exist_ok=True)
weight_map = load_model_index(fp8_path)
get_tensor_func, loaded_files = get_tensor_loader(fp8_path, weight_map)
safetensor_files = sorted(glob(os.path.join(fp8_path, "*.safetensors")))
fp8_weight_names = []
with tqdm(total=len(safetensor_files), desc="Converting Files") as pbar:
for safetensor_file in safetensor_files:
convert_single_file(
safetensor_file=safetensor_file,
bf16_path=bf16_path,
get_tensor_func=get_tensor_func,
loaded_files=loaded_files,
fp8_weight_names=fp8_weight_names,
progress_bar=pbar
)
update_model_index(bf16_path, weight_map, fp8_weight_names)
if __name__ == "__main__":
parser = ArgumentParser(description="Convert FP8 weights to BF16.")
parser.add_argument("--input-fp8-hf-path", type=str, required=True,
help="Path to directory containing FP8 weights.")
parser.add_argument("--output-bf16-hf-path", type=str, required=True,
help="Path to output directory for BF16 weights.")
args = parser.parse_args()
main(args.input_fp8_hf_path, args.output_bf16_hf_path)