import functools
import re
from typing import Optional, Callable, Dict, Type
import torch
import torch.nn as nn
import torch_npu
from mindspeed.fsdp.quantization.mxfp8_config import QuantBaseConfig
from mindspeed.fsdp.utils.str_match import module_name_match
# Registry of quantization handlers, keyed by config class.
# Each value is a callable ``(module, config) -> module`` that applies that
# config type to a module. Entries are added by
# ``register_quantize_module_handler``.
_QUANTIZE_CONFIG_HANDLER: Dict[
    Type[QuantBaseConfig],
    Callable[[torch.nn.Module, QuantBaseConfig], torch.nn.Module],
] = dict()
def register_quantize_module_handler(config_type):
    """Return a decorator that registers a module handler for ``config_type``.

    The decorated function is stored in ``_QUANTIZE_CONFIG_HANDLER`` under the
    key ``config_type``. Any handler previously registered for that type is
    replaced. The function itself is returned unchanged, so it remains
    directly callable.

    Example::

        @register_quantize_module_handler(SomeQuantConfig)
        def _apply(module, config):
            ...
            return module

    Args:
        config_type: The config class the handler should be registered for.

    Returns:
        A decorator that takes the handler function and returns it unchanged.
    """
    # No ``functools.wraps(config_type)`` here: ``decorator`` does not wrap
    # ``config_type``. Using it copied the class's name, doc and __dict__ onto
    # ``decorator`` and set ``__wrapped__`` to the class. Introspection
    # (``__name__``, ``inspect.signature``, ``help``) then described the config
    # class instead of this decorator.
    def decorator(func):
        _QUANTIZE_CONFIG_HANDLER[config_type] = func
        return func

    return decorator
class TensorWithTranspose:
    """Hold a tensor/scale pair together with a second, transposed-layout pair.

    Both pairs are stored side by side so a caller can choose one orientation
    through :meth:`get_by_trans`. The pairs are:

    * ``data`` / ``scale``: the pair returned when no transpose is requested.
    * ``data_t`` / ``scale_t``: the pair returned when a transpose is requested.

    ``fp8_dtype`` is stored as given. Presumably it is the low-precision dtype
    of ``data`` / ``data_t``; this is not checked here (NOTE(review): confirm).
    The ``dtype`` argument is stored as ``ori_dtype``. It presumably records
    the original dtype before quantization (NOTE(review): confirm against
    callers).
    """

    def __init__(
        self,
        fp8_dtype: torch.dtype,
        data: torch.Tensor,
        scale: torch.Tensor,
        data_t: torch.Tensor,
        scale_t: torch.Tensor,
        dtype: torch.dtype = torch.float32,
    ):
        # The constructor argument ``dtype`` is exposed under the name ``ori_dtype``.
        self.ori_dtype = dtype
        self.fp8_dtype = fp8_dtype
        # Pair used for the non-transposed orientation.
        self.data, self.scale = data, scale
        # Pair used for the transposed orientation.
        self.data_t, self.scale_t = data_t, scale_t

    def get_by_trans(self, transpose=False):
        """Return the ``(data, scale)`` pair for the requested orientation.

        Args:
            transpose: If truthy, return ``(data_t, scale_t)``. Otherwise
                return ``(data, scale)``.

        Returns:
            A 2-tuple of the stored objects themselves. No copy is made.
        """
        return (self.data_t, self.scale_t) if transpose else (self.data, self.scale)
def module_filter_fn(mod: nn.Module, fqn: str, config: QuantBaseConfig) -> bool:
    """Decide whether a submodule should be converted for MXFP8.

    A module is selected only when all of the following hold:

    * it is an ``nn.Linear``;
    * ``fqn`` does not match any pattern in ``config.mxfp8_ignored_modules``;
    * ``fqn`` contains a ``layers.<digits>`` component, and the leading part of
      ``fqn`` up to and including the first such component (for example
      ``"model.layers.12"`` from ``"model.layers.12.mlp.up_proj"``) matches at
      least one pattern in ``config.mxfp8_apply_modules``.

    Matching is delegated to ``module_name_match(pattern, name)``.

    Args:
        mod: Candidate module.
        fqn: Dotted name of ``mod``. Presumably this is its qualified name
            inside the model being converted (NOTE(review): confirm against
            the caller).
        config: Config providing the ``mxfp8_ignored_modules`` and
            ``mxfp8_apply_modules`` pattern lists.

    Returns:
        True if ``mod`` should be quantized, False otherwise.
    """
    if not isinstance(mod, nn.Linear):
        return False

    # Explicit ignores take precedence over the apply list. The full name is
    # matched here, not the layer prefix.
    if any(module_name_match(pattern, fqn) for pattern in config.mxfp8_ignored_modules):
        return False

    # The layer prefix depends only on ``fqn``, so extract it once rather than
    # re-running the regex for every apply pattern. ``.*?`` stops at the first
    # "layers.<digits>" occurrence; ``\d+`` consumes the whole layer index.
    layer_match = re.match(r"(.*?layers\.\d+)", fqn)
    if layer_match is None:
        return False
    layer_prefix = layer_match.group(1)

    return any(
        module_name_match(pattern, layer_prefix)
        for pattern in config.mxfp8_apply_modules
    )