import os
from enum import Enum, auto
from functools import lru_cache
import torch
from safetensors.torch import load_file
from tqdm import tqdm
from ascend_utils.common.security import json_safe_load, get_valid_read_path, MAX_READ_FILE_SIZE_32G
from msmodelslim import logger as msmodelslim_logger
from msmodelslim.pytorch.llm_ptq.accelerate_adapter import replace_device_align_hook_if_needed
from msmodelslim.pytorch.llm_ptq.accelerate_adapter.hook_adapter import get_offloaded_weights_loader_if_have
from msmodelslim.pytorch.llm_ptq.accelerate_adapter.utils import judge_model_with_accelerate
# Suffix appended to a module name to form the key of its inverse-scale tensor
# in the safetensors index `weight_map`.
WEIGHT_SCALE_INV = '.weight_scale_inv'
# Name of the module attribute that this file reads the hook object from
# (accessed as `getattr(module, HF_HOOK).old_hook`).
HF_HOOK = '_hf_hook'
class OpsType(Enum):
    """Selects how the checkpoint's weight format is treated: FP8, BF16 or AUTO."""

    FP8 = auto()
    BF16 = auto()
    AUTO = auto()

    @staticmethod
    def get_ops_type(is_bf16: bool, is_fp8: bool):
        """
        Map the two format flags to an `OpsType` member.

        Args:
            is_bf16 (bool): The bf16 label is set.
            is_fp8 (bool): The fp8 label is set.

        Returns:
            OpsType: `FP8` if only `is_fp8` is set, `BF16` if only `is_bf16` is set,
            `AUTO` when neither is set.

        Raises:
            ValueError: If both flags are set at the same time.
        """
        if is_bf16 and is_fp8:
            raise ValueError(f'Using both label fp8 and label bf16.')
        if is_fp8:
            return OpsType.FP8
        if is_bf16:
            return OpsType.BF16
        return OpsType.AUTO
def weight_dequant(weight: torch.Tensor, scale: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    """
    Dequantize a block-quantized 2-D weight by multiplying every block with its scale.

    Each entry of `scale` is repeated `block_size` times along both dimensions and the
    result is cropped to the shape of `weight`, so `weight` does not have to be a
    multiple of `block_size` in either dimension.

    The input `weight` is never modified; a new tensor is returned.

    Args:
        weight (torch.Tensor): The quantized weight tensor of shape (M, N).
        scale (torch.Tensor): The per-block scale tensor. It needs at least
            ceil(M / block_size) rows and ceil(N / block_size) columns to cover `weight`.
        block_size (int, optional): Edge length of one quantization block. Defaults to 128.

    Returns:
        torch.Tensor: The dequantized weight with the same shape as `weight`, in
        `torch.bfloat16`. The multiplication itself is carried out in `torch.float32`.

    Raises:
        ValueError: If `weight` is not 2-D (shape unpacking fails).
    """
    m, n = weight.shape
    # `Tensor.to` hands back the very same tensor when the dtype already matches, so
    # force a copy: otherwise the in-place multiply below would overwrite a caller's
    # float32 weight.
    dequantized = weight.to(torch.float32, copy=True)
    scale_expanded = scale.repeat_interleave(block_size, dim=0).repeat_interleave(block_size, dim=1)
    # Crop the overhang produced when M or N is not a multiple of block_size.
    scale_expanded = scale_expanded[:m, :n]
    dequantized *= scale_expanded
    return dequantized.to(torch.bfloat16)
@lru_cache(maxsize=3)
def get_tensor_file(file_dir, file_name):
    """
    Load one safetensors file onto the CPU.

    The path is checked with `get_valid_read_path` (expected extension 'safetensors',
    size limit `MAX_READ_FILE_SIZE_32G`) before loading. Results are memoized for the
    three most recently used `(file_dir, file_name)` pairs.

    Args:
        file_dir: Directory containing the file.
        file_name: Name of the safetensors file inside `file_dir`.

    Returns:
        The object returned by `safetensors.torch.load_file` for that file.
    """
    checked_path = get_valid_read_path(
        os.path.join(file_dir, file_name),
        'safetensors',
        size_max=MAX_READ_FILE_SIZE_32G,
    )
    return load_file(checked_path, device='cpu')
def get_tensor(tensor_name, fp8_path, weight_map):
    """
    Fetch a single tensor by name from the checkpoint directory.

    Args:
        tensor_name: Key of the tensor; must be present in `weight_map`.
        fp8_path: Directory that holds the safetensors files.
        weight_map: Mapping from tensor name to the file name that stores it.

    Returns:
        The entry stored under `tensor_name` in the loaded file.

    Raises:
        KeyError: If `tensor_name` is missing from `weight_map` or from the loaded file.
    """
    shard_name = weight_map[tensor_name]
    return get_tensor_file(fp8_path, shard_name)[tensor_name]
def get_module_by_name(model, submodule_key=None):
    """
    Resolve a dotted attribute path starting at `model`.

    Args:
        model: Object to start the lookup from.
        submodule_key: Dot-separated attribute path such as 'a.b.c', or None.

    Returns:
        The object found at the end of the path. None when `submodule_key` is None;
        a missing attribute resolves to None and the walk continues from that value.
    """
    if submodule_key is None:
        return None
    current = model
    for attr_name in submodule_key.split('.'):
        # A missing attribute yields None instead of raising AttributeError.
        current = getattr(current, attr_name, None)
    return current
def auto_convert_model_fp8_to_bf16(model, model_path, ops_type: OpsType = OpsType.AUTO):
    """
    Convert the model's fp8 weights to bf16 in place, depending on `ops_type`.

    - `OpsType.BF16`: nothing is done.
    - `OpsType.FP8`: conversion is mandatory; any error propagates to the caller.
    - anything else (AUTO): conversion runs only when inverse-scale entries are found,
      and a `KeyError` during conversion is logged as a warning and swallowed.

    Args:
        model: Model whose modules are converted.
        model_path: Directory containing the safetensors files and their index.
        ops_type (OpsType): Requested handling. Defaults to `OpsType.AUTO`.

    Raises:
        ValueError: In FP8 mode when no inverse-scale entry matches any module.
    """
    if ops_type is OpsType.BF16:
        return

    scale_list, weight_map = get_weight_map_and_scale_list_from_index(model, model_path)
    forced_fp8 = ops_type is OpsType.FP8

    if not scale_list:
        if forced_fp8:
            raise ValueError('Can not find any fp8 inv scale, please check whether model is of fp8.')
        # AUTO mode with nothing to convert.
        return

    if forced_fp8:
        convert_model_fp8_to_bf16(model, model_path, scale_list, weight_map)
        return

    # AUTO mode: a KeyError here means the index and the files disagree, so skip.
    try:
        convert_model_fp8_to_bf16(model, model_path, scale_list, weight_map)
    except KeyError:
        msmodelslim_logger.warning(f'Safetensors files not match index.json, please check whether model is of bf16.')
        msmodelslim_logger.warning(f'Skip fp8 to bf16.')
    except Exception as e:
        msmodelslim_logger.error(f'Unexpected error occurred: {e}.')
        raise
def get_weight_map_and_scale_list_from_index(model, model_path):
    """
    Read `model.safetensors.index.json` and find the modules that have an inverse scale.

    Args:
        model: Object providing `named_modules()`.
        model_path: Directory that contains `model.safetensors.index.json`.

    Returns:
        tuple: `(scale_list, weight_map)` where `scale_list` holds, in
        `named_modules()` order, the names of modules for which the index has a key
        containing `WEIGHT_SCALE_INV`, and `weight_map` is the index's 'weight_map' entry.
    """
    index_file = os.path.join(model_path, "model.safetensors.index.json")
    weight_map = json_safe_load(index_file)['weight_map']
    # Strip the scale suffix from each matching key to recover the owning module name.
    scaled_modules = {
        key.replace(WEIGHT_SCALE_INV, '')
        for key in weight_map.keys()
        if WEIGHT_SCALE_INV in key
    }
    scale_list = [name for name, _ in model.named_modules() if name in scaled_modules]
    return scale_list, weight_map
def convert_model_fp8_to_bf16(model, model_path, scale_list, weight_map):
    """
    Dequantize the weight of every module named in `scale_list`, writing results in place.

    For each module the inverse scale is loaded from the checkpoint and the weight is
    overwritten with `weight_dequant(weight, scale)`. When the module has an offloaded
    weights loader and its hook reports `offload`, the loader's entry is rewritten;
    otherwise `module.weight` is rewritten after moving the scale to the hook's
    execution device.

    Args:
        model: Model to convert; must satisfy `judge_model_with_accelerate`.
        model_path: Directory containing the safetensors files.
        scale_list: Names of the modules to convert.
        weight_map: Mapping from tensor name to the safetensors file storing it.

    Raises:
        ValueError: If `judge_model_with_accelerate(model)` is falsy.
    """
    if not judge_model_with_accelerate(model):
        raise ValueError(f'Deepseek V3/R1 only support npu limited cpu unlimited quantization for now')
    replace_device_align_hook_if_needed(model)

    with torch.no_grad():
        for layer_name in tqdm(scale_list, desc='fp8 to bf16'):
            layer = get_module_by_name(model, layer_name)
            inv_scale = get_tensor(layer_name + WEIGHT_SCALE_INV, model_path, weight_map)
            offload_store = get_offloaded_weights_loader_if_have(layer)

            if offload_store and getattr(layer, HF_HOOK).old_hook.offload:
                # Offloaded weight: update the loader's copy instead of module.weight.
                # The loader is indexed separately for the read and the write, as before.
                offload_key = layer_name + '.weight'
                offload_store[offload_key][:] = weight_dequant(offload_store[offload_key], inv_scale)
            else:
                target = getattr(layer, HF_HOOK).old_hook.execution_device
                # Any execution device other than 'cpu' is addressed as an npu device.
                if target != 'cpu':
                    target = f'npu:{target}'
                inv_scale = inv_scale.to(target)
                layer.weight[:] = weight_dequant(layer.weight, inv_scale)