import math
from typing import Callable, Any, Dict, List, Optional, Union
import torch
import torch_npu
import torch.nn.functional as F
from torch.nn import Parameter
from module.quantization import QuantizationMethods, QuantizeMethodBase, QuantizationConfig
from module.linear import LinearBase, LinearMethodBase, UnquantizedLinearMethod
from module.quantization.utils.quant_utils import is_layer_skipped
from module.fuse_moe_gmm import FusedMoEGMM, FusedMoeWeightScaleSupported
from module.utils import set_weight_attrs
# Activation quantization schemes accepted by Fp8Config.
ACTIVATION_SCHEMES = ["static", "dynamic"]

# Layer initialisation markers. Fp8LinearMethod.create_weights stores
# BEFORE_INIT in ``layer.init_state``; AFTER_INIT is not referenced in the
# visible part of this file.
BEFORE_INIT, AFTER_INIT = 0, 1
class Fp8Config(QuantizationConfig):
    """Quantization settings for FP8 models.

    Records whether the checkpoint is already FP8-serialized, which
    activation scheme is used, which layers are excluded from quantization
    and the optional block size for block-wise weight quantization.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: Optional[list[str]] = None,
        weight_block_size: Optional[list[int]] = None,
    ) -> None:
        """Validate and store the FP8 settings.

        Raises:
            ValueError: if ``activation_scheme`` is not one of
                ``ACTIVATION_SCHEMES``, or if ``weight_block_size`` is given
                together with an unsupported combination of settings.
        """
        super().__init__()
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(
                f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers if ignored_layers else []

        # Block-wise settings are only checked when a block size is supplied.
        if weight_block_size is not None:
            self._check_block_quant_args(
                weight_block_size,
                is_checkpoint_fp8_serialized,
                activation_scheme,
            )
        self.weight_block_size = weight_block_size

    @staticmethod
    def _check_block_quant_args(
        weight_block_size: list[int],
        is_checkpoint_fp8_serialized: bool,
        activation_scheme: str,
    ) -> None:
        """Raise ``ValueError`` when block-wise quantization is misconfigured."""
        if not is_checkpoint_fp8_serialized:
            raise ValueError(
                "The block-wise quantization only supports fp8-serialized "
                "checkpoint for now.")
        num_dims = len(weight_block_size)
        if num_dims != 2:
            raise ValueError(
                "The quantization block size of weight must have 2 "
                f"dimensions, but got {num_dims} dimensions")
        if activation_scheme != "dynamic":
            raise ValueError("The block-wise quantization only supports "
                             "dynamic activation scheme for now, but got "
                             f"{activation_scheme} activation scheme.")

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the identifier of this quantization method."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes this config declares as supported."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """Return the extra config file names to look for (none)."""
        return []

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
        """Build an ``Fp8Config`` from a quantization config dictionary.

        The checkpoint is treated as FP8-serialized when ``"fp8"`` occurs in
        the ``quant_method`` entry. ``kv_cache_quant_mode`` is only assigned
        when the dictionary has a non-None ``kv_cache_scheme`` entry.
        """
        method_name = cls.get_from_keys(config, ["quant_method"])
        fp8_serialized = "fp8" in method_name
        scheme = cls.get_from_keys(config, ["activation_scheme"])
        skipped_layers = cls.get_from_keys_or(
            config, ["ignored_layers", "modules_to_not_convert"], None)
        block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)

        instance = cls(
            is_checkpoint_fp8_serialized=fp8_serialized,
            activation_scheme=scheme,
            ignored_layers=skipped_layers,
            weight_block_size=block_size,
        )
        # Dense and grouped matmuls use the same quantization mode tag.
        instance.mm_quant_mode = instance.gmm_quant_mode = "w8a8float8"

        kv_scheme = config.get("kv_cache_scheme")
        if kv_scheme is not None:
            # Mode name is "<type><num_bits>", e.g. "float8" with the defaults.
            kv_type = kv_scheme.get("type", "float")
            kv_bits = kv_scheme.get("num_bits", 8)
            instance.kv_cache_quant_mode = kv_type + str(kv_bits)
        return instance

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Pick the quantization method object for ``layer``.

        Linear layers get ``UnquantizedLinearMethod`` when ``is_layer_skipped``
        reports them as skipped and ``Fp8LinearMethod`` otherwise;
        ``FusedMoEGMM`` layers get ``Fp8MoEGMMMethod``; anything else gets
        ``None``.
        """
        if isinstance(layer, LinearBase):
            skipped = is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            )
            return UnquantizedLinearMethod() if skipped else Fp8LinearMethod(self)
        if isinstance(layer, FusedMoEGMM):
            return Fp8MoEGMMMethod(self)
        return None
class Fp8LinearMethod(LinearMethodBase):
    """Linear-layer method with block-wise FP8 (``float8_e4m3fn``) weights.

    Only the block-wise path is implemented: both ``create_weights`` and
    ``apply`` read ``weight_block_size`` from the quantization config.
    """

    _kernel_backends_being_used: set[str] = set()

    def __init__(self, quant_config):
        self.out_dtype = torch.get_default_dtype()
        self.weight_block_size = quant_config.weight_block_size
        self.block_quant = self.weight_block_size is not None

    def _block_shape(self) -> tuple[int, int]:
        """Return ``(block_n, block_k)`` or raise if no block size is configured."""
        if self.weight_block_size is None:
            raise ValueError(
                "Fp8LinearMethod only supports block-wise quantization; "
                "`weight_block_size` must be set in the FP8 quantization config.")
        return self.weight_block_size[0], self.weight_block_size[1]

    def create_weights(self,
                       layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int],
                       input_size: int,
                       output_size: int,
                       params_dtype: torch.dtype,
                       weight_loader: Callable,
                       **kwargs):
        """Register the FP8 ``weight`` and per-block ``scale`` on ``layer``.

        ``weight`` has shape ``(sum(output_partition_sizes),
        input_size_per_partition)``; ``scale`` is float32 with one entry per
        ``block_n x block_k`` tile of the weight (rounded up).

        Raises:
            ValueError: if the config carries no ``weight_block_size``.
        """
        # Validate first so a misconfigured layer is not left half-initialised.
        block_n, block_k = self._block_shape()
        out_features = sum(output_partition_sizes)

        weight = Parameter(
            torch.empty(out_features,
                        input_size_per_partition,
                        dtype=torch.float8_e4m3fn),
            requires_grad=False)
        set_weight_attrs(
            weight,
            {"input_dim": 1, "output_dim": 0, "weight_loader": weight_loader})
        layer.register_parameter("weight", weight)

        scale = Parameter(
            torch.empty(
                (math.ceil(out_features / block_n),
                 math.ceil(input_size_per_partition / block_k)),
                dtype=torch.float32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            scale,
            {"input_dim": 1, "output_dim": 0, "is_per_block_scale": True,
             "weight_loader": weight_loader})
        layer.register_parameter("scale", scale)

        setattr(layer, "init_state", BEFORE_INIT)

    def apply(self, layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor],
              dynamic_scale: Optional[torch.Tensor] = None,
              out_dtype: Optional[torch.dtype] = torch.bfloat16,
              ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Run the quantized matmul of ``x`` against ``layer.weight``.

        When ``dynamic_scale`` is given it is used as the per-token scale and
        ``x`` is passed to the matmul unchanged; otherwise ``x`` is flattened
        to 2-D and quantized with ``npu_dynamic_block_quant``.

        Returns:
            The matmul result reshaped to the leading dimensions of the input,
            or ``(result, x_scale)`` when ``out_dtype == torch.int32``. The
            matmul itself always requests ``torch.bfloat16`` output.
        """
        leading_shape = x.size()[:-1]
        if dynamic_scale is not None:
            x_scale = dynamic_scale
        else:
            x = x.view(-1, x.size(-1))
            x, x_scale = torch_npu.npu_dynamic_block_quant(
                x,
                dst_type=torch.float8_e4m3fn,
                row_block_size=1,
                col_block_size=self.weight_block_size[1],
            )
        # NOTE(review): the `bias` argument is not used; the bias is always
        # read from `layer.bias` — confirm this is intended for all callers.
        x = torch_npu.npu_quant_matmul(
            x, layer.weight,
            layer.scale,
            pertoken_scale=x_scale,
            bias=layer.bias,
            output_dtype=torch.bfloat16,
            group_sizes=[1, self.weight_block_size[1], self.weight_block_size[0]])
        x = x.view(leading_shape + (-1, ))
        if out_dtype == torch.int32:
            return x, x_scale
        return x

    apply_weights = apply

    def process_weights_after_loading(self, layer, is_transpose=True, is_nz=True, scales_dtype=None):
        """Rearrange the loaded weight and scale into their runtime layout.

        Args:
            layer: module holding the ``weight`` and ``scale`` parameters.
            is_transpose: swap the last two weight dims and the two scale dims.
            is_nz: cast the weight with ``npu_format_cast(..., 29)``
                (presumably the NZ format id, going by the flag name —
                confirm against torch_npu).
            scales_dtype: optional mapping; when it has a ``'scale_dtype'``
                entry the scale is converted to that dtype. ``None`` (the
                default) leaves the scale dtype untouched.
        """
        weight = layer.weight
        scale = layer.scale

        scale_2d = scale.data.view(scale.size(0), scale.size(1))
        if is_transpose:
            weight.data = weight.data.transpose(-2, -1)
            scale.data = scale_2d.transpose(0, 1)
        else:
            scale.data = scale_2d

        if is_nz:
            weight.data = torch_npu.npu_format_cast(weight.data.contiguous(), 29)
        # The scale is made contiguous regardless of `is_nz`.
        scale.data = scale.data.contiguous()

        # `scales_dtype` defaults to None; guard before the membership test.
        if scales_dtype is not None and 'scale_dtype' in scales_dtype:
            scale.data = scale.data.to(scales_dtype.get('scale_dtype'))

        layer.weight = Parameter(weight, requires_grad=False)
        layer.scale = Parameter(scale, requires_grad=False)
class Fp8MoEGMMMethod(QuantizeMethodBase):
    """Fused-MoE method running block-wise FP8 grouped matmuls.

    Only the block-wise path is implemented: ``create_weights`` and ``apply``
    both read ``weight_block_size`` from the quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        super().__init__()
        self.quant_config = quant_config
        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant = self.weight_block_size is not None

    def _block_shape(self) -> tuple[int, int]:
        """Return ``(block_n, block_k)`` or raise if no block size is configured."""
        if self.weight_block_size is None:
            raise ValueError(
                "Fp8MoEGMMMethod only supports block-wise quantization; "
                "`weight_block_size` must be set in the FP8 quantization config.")
        return self.weight_block_size[0], self.weight_block_size[1]

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        """Register expert weights, block scales and smooth scales on ``layer``.

        Registers ``w13_weight`` / ``w2_weight`` (FP8 when the checkpoint is
        FP8-serialized, otherwise ``params_dtype``), float32 ``w13_scale`` /
        ``w2_scale`` with one entry per weight block, and float32
        ``smooth_scale_1`` / ``smooth_scale_2`` initialised to ones.

        Raises:
            ValueError: if the config carries no ``weight_block_size``.
        """
        # Validate first so a misconfigured layer is not left half-initialised.
        block_n, block_k = self._block_shape()

        if self.quant_config.is_checkpoint_fp8_serialized:
            weight_type = torch.float8_e4m3fn
        else:
            weight_type = params_dtype
        scale_dtype = torch.float32

        w13_weight = torch.nn.Parameter(
            torch.empty(num_experts,
                        2 * intermediate_size_per_partition,
                        hidden_size,
                        dtype=weight_type),
            requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(num_experts,
                        hidden_size,
                        intermediate_size_per_partition,
                        dtype=weight_type),
            requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # Only the scale parameters below carry the "quant_method" attribute;
        # the weights above were tagged before this update.
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value})

        w13_scale = torch.nn.Parameter(
            torch.ones(num_experts,
                       2 * math.ceil(intermediate_size_per_partition / block_n),
                       math.ceil(hidden_size / block_k),
                       dtype=scale_dtype),
            requires_grad=False)
        layer.register_parameter("w13_scale", w13_scale)
        set_weight_attrs(w13_scale, extra_weight_attrs)

        w2_scale = torch.nn.Parameter(
            torch.ones(num_experts,
                       math.ceil(hidden_size / block_n),
                       math.ceil(intermediate_size_per_partition / block_k),
                       dtype=scale_dtype),
            requires_grad=False)
        layer.register_parameter("w2_scale", w2_scale)
        set_weight_attrs(w2_scale, extra_weight_attrs)

        smooth_scale_1 = Parameter(
            torch.ones((num_experts, hidden_size), dtype=scale_dtype),
            requires_grad=False)
        smooth_scale_2 = Parameter(
            torch.ones((num_experts, intermediate_size_per_partition),
                       dtype=scale_dtype),
            requires_grad=False)
        layer.register_parameter("smooth_scale_1", smooth_scale_1)
        layer.register_parameter("smooth_scale_2", smooth_scale_2)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        expert_tokens: torch.Tensor,
        group_list_type: int,
        pertoken_scale: Optional[torch.Tensor] = None,
        final_output_dtype: torch.dtype = torch.bfloat16,
        **kwargs,
    ):
        """Run the expert MLP: grouped matmul, SwiGLU, re-quantize, grouped matmul.

        Args:
            layer: module holding ``w13_weight``/``w13_scale`` and
                ``w2_weight``/``w2_scale``.
            x: input activations; quantized here when ``pertoken_scale`` is
                ``None``, otherwise passed to the first matmul unchanged.
            expert_tokens: forwarded as ``group_list`` to both grouped matmuls.
            group_list_type: forwarded to both grouped matmuls.
            pertoken_scale: per-token scale of an already quantized ``x``.
            final_output_dtype: output dtype of the second grouped matmul.
            **kwargs: ``enable_custom_ops`` (default ``False``) selects the
                fused ``npu_swiglu_group_quant`` op; ``swiglu_limit`` (default
                ``None``) is passed to it as ``clamp_limit``.

        Returns:
            The output of the second grouped matmul.
        """
        block_k = self.weight_block_size[1]
        if pertoken_scale is None:
            x, pertoken_scale = torch_npu.npu_dynamic_block_quant(
                x,
                dst_type=torch.float8_e4m3fn,
                row_block_size=1,
                col_block_size=block_k,
            )

        # The first projection always produces bfloat16.
        mm1_mm3 = torch_npu.npu_grouped_matmul(
            [x], [layer.w13_weight],
            scale=[layer.w13_scale],
            per_token_scale=[pertoken_scale],
            group_list=expert_tokens, split_item=3,
            output_dtype=torch.bfloat16, group_type=0,
            group_list_type=group_list_type,
            tuning_config=[0]
        )[0]

        swiglu_limit = kwargs.get("swiglu_limit", None)
        enable_custom_ops = kwargs.get("enable_custom_ops", False)
        if enable_custom_ops:
            # Fused activation + quantization; the third output is unused.
            intermediate_h, pertoken_scale, _ = torch.ops.custom.npu_swiglu_group_quant(
                mm1_mm3,
                dst_type=torch.float8_e4m3fn,
                quant_mode=0,
                clamp_limit=swiglu_limit,
                group_index=expert_tokens)
        else:
            mm1_mm3 = torch_npu.npu_swiglu(mm1_mm3)
            intermediate_h, pertoken_scale = torch_npu.npu_dynamic_block_quant(
                mm1_mm3,
                dst_type=torch.float8_e4m3fn,
                row_block_size=1,
                col_block_size=block_k,
            )

        out_hidden = torch_npu.npu_grouped_matmul(
            [intermediate_h], [layer.w2_weight], bias=None,
            scale=[layer.w2_scale], per_token_scale=[pertoken_scale],
            group_list=expert_tokens, split_item=3,
            output_dtype=final_output_dtype, group_type=0,
            group_list_type=group_list_type,
            tuning_config=[0]
        )[0]
        return out_hidden

    def process_weights_after_loading(
        self, layer: torch.nn.Module,
        is_transpose: bool = True,
        is_nz: bool = True,
        **kwargs,
    ) -> None:
        """Rearrange the loaded expert weights and scales into runtime layout.

        Args:
            layer: module holding the four expert parameters.
            is_transpose: swap dims 1 and 2 of both weights and both scales.
            is_nz: cast both weights with ``npu_format_cast(..., 29)``
                (presumably the NZ format id, going by the flag name —
                confirm against torch_npu).
            **kwargs: accepted and ignored.
        """
        w13_weight = layer.w13_weight
        w2_weight = layer.w2_weight
        w13_scale = layer.w13_scale
        w2_scale = layer.w2_scale

        if is_transpose:
            for param in (w13_weight, w2_weight, w13_scale, w2_scale):
                param.data = param.data.transpose(1, 2)

        if is_nz:
            w13_weight.data = torch_npu.npu_format_cast(w13_weight.data.contiguous(), 29)
            w2_weight.data = torch_npu.npu_format_cast(w2_weight.data.contiguous(), 29)

        # Scales are made contiguous regardless of `is_nz`, and both are
        # normalised to float32 so the two grouped matmuls see the same dtype.
        for scale in (w13_scale, w2_scale):
            scale.data = scale.data.contiguous().to(torch.float)

        layer.w13_weight = Parameter(w13_weight, requires_grad=False)
        layer.w2_weight = Parameter(w2_weight, requires_grad=False)