from typing import Any, List
import torch.nn as nn
from mindspeed.fsdp.quantization.converter.model_converter import register_model_converter
from mindspeed.fsdp.quantization.converter.utils import module_filter_fn, moe_filter_fn
from mindspeed.fsdp.parallel_engine_config import QuantizeConfig
class MXLinearConverter:
"""Converts the linear layers of `model` to `MXLinear`."""
filter_fqns: List[str]
mx_config: Any
def __init__(self, config: QuantizeConfig):
self.config = config
def convert(self, model: nn.Module):
"""
Converts the linear layers of `model` to `MXLinear`.
Note that today, only MXFP8 (the default) is supported.
This will mutate the model inplace.
"""
from mindspeed.fsdp.quantization.converter.utils import convert_model
from mindspeed.fsdp.quantization.module.linear_mxfp8 import MXLinear
convert_model(
model,
config=self.config,
convert_fn=MXLinear.from_float,
filter_fn=module_filter_fn,
device=model.device,
)
class MXMoeConverter:
"""Converts the linear layers of `model` to `MXLinear`."""
filter_fqns: List[str]
mx_config: Any
def __init__(self, config: QuantizeConfig):
self.config = config
def convert(self, model: nn.Module):
"""Converts the linear layers of `model` to `MXLinear`.
Note that today, only MXFP8 (the default) is supported.
This will mutate the model inplace.
"""
from mindspeed.fsdp.quantization.converter.utils import convert_model
from mindspeed.fsdp.quantization.module.gmm_mxfp8 import MXFP8GMM
convert_model(
model,
config=self.config,
convert_fn=MXFP8GMM.from_float,
filter_fn=moe_filter_fn,
device=model.device,
)
register_model_converter(MXLinearConverter, "quantize.linear.mx")
register_model_converter(MXMoeConverter, "quantize.moe.mx")