"""GroupedLinear API for TransformerEngine NPU PyTorch
This module implements grouped linear layers for Mixture of Experts (MoE) models,
integrating MindSpeed's NPU-optimized grouped matmul operations.
Reference: TransformerEngine/transformer_engine/pytorch/module/grouped_linear.py
"""
import warnings
import weakref
from typing import Any, Callable, List, Optional, Tuple, Union
import torch
from ..constants import GemmParallelModes, dist_group_type
from ..distributed import get_distributed_world_size, set_tensor_model_parallel_attributes
from ..ops.gemm import general_grouped_gemm
from ..quantization import FP8GlobalStateManager
from ..cpu_offload import is_cpu_offload_enabled, mark_not_offload, start_offload
from ..quantized_tensor import (
QuantizedTensorStorage,
)
from ..tensor.grouped_tensor import GroupedTensor
from ..utils import (
cast_if_needed,
divide,
init_method_constant,
requires_grad,
)
from ._common import WeightGradStore
from .base import (
TransformerEngineBaseModule,
_check_fp8_reduce_and_update,
get_dummy_wgrad,
quantize_weight,
)
from .performance_grouped_linear_impl import GroupedLinearArgs, _PerformanceGroupedLinear
class _GroupedLinear(torch.autograd.Function):
    """GroupedLinear autograd function with FP8 support.

    This implements grouped matrix multiplication with support for:

    - NPU optimized grouped matmul (via ``general_grouped_gemm``)
    - FP8 quantization (placeholder for NPU)
    - Gradient computation (dgrad and wgrad)

    ``forward`` takes the input tensor, a ``GroupedLinearArgs`` bundle of
    non-tensor options, then ``len(args.m_splits)`` weights followed by the
    bias tensors. It returns ``(output, new_weight_workspace)``; the workspace
    is ``None`` outside the FP8 path.
    """

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        args: GroupedLinearArgs,
        *weights_and_biases,
    ) -> Tuple[torch.Tensor, Any]:
        """Run the grouped GEMM forward pass.

        Parameters
        ----------
        ctx
            Autograd context. Only touched when ``args.is_grad_enabled`` is
            True, so it may be ``None`` when called directly without grad.
        inp : torch.Tensor
            Input whose last dimension must equal the weights' last dimension.
        args : GroupedLinearArgs
            Non-tensor options and quantizers for this call.
        *weights_and_biases : torch.Tensor
            ``len(args.m_splits)`` weights followed by the bias tensors.

        Returns
        -------
        Tuple[torch.Tensor, Any]
            The GEMM output viewed as ``(-1, *inp.shape[1:-1], out.shape[-1])``
            and the new weight workspace (``None`` unless the FP8 path ran).

        Raises
        ------
        ValueError
            If ``inp``'s last dimension does not match the weights.
        """
        assert not args.fuse_wgrad_accumulation, (
            "Grouped Linear not support fuse_wgrad_accumulation yet"
        )
        input_quantizer = args.input_quantizer
        weight_quantizer = args.weight_quantizer
        output_quantizer = args.output_quantizer

        num_gemms = len(args.m_splits)
        weights = list(weights_and_biases[:num_gemms])
        biases = list(weights_and_biases[num_gemms:])
        weight_requires_grad = weights[0].requires_grad

        # Configure quantizer usages for this step.
        if input_quantizer is not None:
            # Column-wise usage is only requested when a weight gradient will be computed.
            columnwise = args.is_grad_enabled and weight_requires_grad
            if columnwise:
                input_quantizer.columnwise_use_group_quant = True
            input_quantizer.set_usage(rowwise=True, columnwise=columnwise)
        if weight_quantizer is not None:
            weight_quantizer.set_usage(
                rowwise=True,
                columnwise=args.is_grad_enabled and inp.requires_grad,
            )
        if output_quantizer is not None:
            output_quantizer.set_usage(rowwise=True, columnwise=False)
        if args.grad_output_quantizer is not None:
            args.grad_output_quantizer.columnwise_use_group_quant = True

        in_features = weights[0].size(-1)
        if inp.size(-1) != in_features:
            raise ValueError(
                f"Input tensor (shape={tuple(inp.size())}) is not compatible with "
                f"weight tensor (shape={tuple(weights[0].size())})"
            )
        inp_view = inp.reshape(-1, in_features)

        cast_biases = biases
        if args.use_bias:
            # FP32 activations are paired with BF16 biases; otherwise match activations.
            bias_dtype = (
                args.activation_dtype if args.activation_dtype != torch.float32 else torch.bfloat16
            )
            cast_biases = [cast_if_needed(bias, bias_dtype) for bias in biases]

        new_workspace = None
        weight_stacked = None
        if args.fp8 and not args.debug:
            inputmats = input_quantizer.grouped_quantize(inp_view, args.group_list)
            weight_stacked = torch.stack(weights, dim=0)
            # Refresh the cached FP8 weight unless this is a non-first microbatch.
            update_ws = args.is_first_microbatch is None or args.is_first_microbatch
            weights_fp8, new_workspace = quantize_weight(
                tensor=weight_stacked,
                quantizer=weight_quantizer,
                workspace=args.weight_workspace,
                update_workspace=update_ws,
                skip_update_flag=args.skip_fp8_weight_update,
                workspace_dtype=args.activation_dtype,
                cache=args.cache_weight,
                group_list=args.group_list,
            )
        else:
            inputmats = cast_if_needed(inp_view, args.activation_dtype)
            weights_fp8 = [cast_if_needed(weight, args.activation_dtype) for weight in weights]

        out = general_grouped_gemm(
            weights_fp8,
            inputmats,
            args.group_split,
            use_bias=args.use_bias,
            biases=cast_biases,
            out_dtype=args.activation_dtype,
        )

        if args.fp8_calibration:
            input_quantizer.calibrate(inp_view)
            # Reuse the stacked weights from the FP8 branch when they exist.
            weight_quantizer.calibrate(
                weight_stacked if weight_stacked is not None else torch.stack(weights, dim=0)
            )

        if args.cpu_offloading:
            start_offload(inputmats)
            mark_weights_fp8 = (
                [weights_fp8] if isinstance(weights_fp8, torch.Tensor) else weights_fp8
            )
            mark_not_offload(*mark_weights_fp8, *weights)

        if args.is_grad_enabled:
            ctx.args = args
            ctx.inputmats = inputmats
            ctx.weights_fp8 = weights_fp8
            ctx.biases = biases
            ctx.num_gemms = num_gemms
            ctx.inp_shape = inp.shape
            ctx.requires_dgrad = inp.requires_grad
            ctx.requires_wgrad = weight_requires_grad
            if args.fuse_wgrad_accumulation and ctx.requires_wgrad:
                # Unreachable while the assert at the top rejects
                # fuse_wgrad_accumulation; kept for when it is supported.
                ctx.origin_weight_refs = [weakref.ref(w) for w in weights]
                ctx.origin_weights_overwrite_main_grad = getattr(
                    weights[0], "overwrite_main_grad", False
                )
                if hasattr(weights[0], "__fsdp_param__"):
                    ctx.main_grad_funcs = [weights[i].get_main_grad for i in range(num_gemms)]
                else:
                    ctx.main_grad_funcs = [
                        lambda j=i: weights[j].main_grad for i in range(num_gemms)
                    ]
            # ``biases`` may be empty, so only inspect the first bias if present.
            if args.fp8 and requires_grad(inp, weights[0], *biases[:1]):
                ctx.reduce_and_update_bwd_fp8_tensors = _check_fp8_reduce_and_update()
            else:
                ctx.reduce_and_update_bwd_fp8_tensors = False

        return out.view(-1, *inp.shape[1:-1], out.shape[-1]), new_workspace

    @staticmethod
    def backward(
        ctx, grad_output: torch.Tensor, _grad_workspaces
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Compute gradients for every ``forward`` input.

        Returns ``dgrad`` (or ``None``), ``None`` for ``args``, one weight
        gradient per GEMM, then one gradient per saved bias tensor. The
        gradient of the workspace output (``_grad_workspaces``) is ignored.
        """
        args: GroupedLinearArgs = ctx.args
        num_gemms = ctx.num_gemms
        main_grads = [None] * num_gemms
        origin_weights = [None] * num_gemms
        if args.fuse_wgrad_accumulation and ctx.requires_wgrad:
            origin_weight_refs = getattr(ctx, "origin_weight_refs", None)
            if origin_weight_refs is not None:
                # Clear the stored weak refs once they have been resolved.
                ctx.origin_weight_refs = None
                origin_weights = [ref() if ref is not None else None for ref in origin_weight_refs]
                assert all(w is not None for w in origin_weights), (
                    "weight was removed while fuse_wgrad_accumulation=True"
                )
            main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs]
            for origin_weight, main_grad in zip(origin_weights, main_grads):
                if main_grad is not None:
                    origin_weight.main_grad = main_grad

        grad_output_view = grad_output.view(-1, grad_output.shape[-1])

        if args.use_bias:
            # Bias gradients are reduced directly from grad_output, one per split.
            grad_output_split = torch.split(grad_output_view, args.m_splits)
            grad_biases = [grad_output_split[i].sum(dim=0) for i in range(num_gemms)]
        else:
            # One ``None`` per bias input so the gradient count matches forward.
            grad_biases = [None] * len(ctx.biases)

        dgrad = None
        wgrad_list = [None] * num_gemms
        if args.fp8 and not args.debug:
            grad_output_quantizer = args.grad_output_quantizer
            # The quantizer object is reused across steps (it comes from the
            # module's quantizer table), so set both usages explicitly rather
            # than only clearing them; otherwise a usage disabled in an earlier
            # backward pass would stay disabled.
            grad_output_quantizer.set_usage(
                rowwise=ctx.requires_dgrad,
                columnwise=ctx.requires_wgrad,
            )
            grad_output_mats: GroupedTensor = grad_output_quantizer.grouped_quantize(
                grad_output_view, args.group_list
            )
        else:
            grad_output_mats = cast_if_needed(grad_output_view, args.activation_dtype)

        if ctx.requires_dgrad:
            dgrad = general_grouped_gemm(
                ctx.weights_fp8,
                grad_output_mats,
                args.group_split,
                layout="NN",
                out_dtype=args.activation_dtype,
            )

        if ctx.requires_wgrad:
            # Bias gradients (when enabled) were already computed above, so the
            # wgrad GEMM is not asked to produce them.
            wgrad_use_bias = None if args.use_bias else args.use_bias
            wgrad = general_grouped_gemm(
                ctx.inputmats,
                grad_output_mats,
                args.group_split,
                layout="NT",
                use_bias=wgrad_use_bias,
                biases=ctx.biases,
                group_type=2,
                out_dtype=args.activation_dtype,
            )
            wgrad_list = [wgrad[i] for i in range(num_gemms)]

        def handle_custom_ddp_from_mcore(weight, _main_grad, _wgrad):
            # With fused wgrad accumulation, hand back a dummy (or no) wgrad when
            # the gradient was accumulated into ``main_grad`` instead.
            if ctx.requires_wgrad:
                if args.fuse_wgrad_accumulation and hasattr(weight, "grad_added_to_main_grad"):
                    weight.grad_added_to_main_grad = True
                    _wgrad = get_dummy_wgrad(
                        list(_main_grad.shape),
                        weight.dtype,
                        zero=getattr(weight, "zero_out_wgrad", False),
                    )
                elif args.fuse_wgrad_accumulation:
                    _wgrad = None
            else:
                _wgrad = None
            return _wgrad

        wgrad_list = [
            handle_custom_ddp_from_mcore(weight, main_grad, wgrad)
            for weight, main_grad, wgrad in zip(origin_weights, main_grads, wgrad_list)
        ]

        if ctx.reduce_and_update_bwd_fp8_tensors:
            FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)

        return (
            dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
            None,
            *wgrad_list,
            *grad_biases,
        )
class GroupedLinear(TransformerEngineBaseModule):
    """Grouped Linear layer with FP8 support for NPU.

    This layer implements grouped linear transformations, commonly used in
    Mixture of Experts (MoE) models. It supports both column-parallel and
    row-parallel modes for distributed training.

    Parameters
    ----------
    num_gemms : int
        Number of groups (experts).
    in_features : int
        Input feature dimension (divided by the TP size in "row" mode).
    out_features : int
        Output feature dimension (divided by the TP size in "column" mode).
    sequence_parallel : bool, default False
        Whether to use sequence parallelism; only enabled when TP size > 1.
    fuse_wgrad_accumulation : bool, default False
        Whether to fuse weight gradient accumulation. The default grouped
        autograd path asserts this is False.
    tp_group : ProcessGroup, optional
        Tensor parallel process group. When given, the TP size is read from it.
    tp_size : int, default 1
        Tensor parallel world size, used only when ``tp_group`` is None.
    get_rng_state_tracker : callable, optional
        RNG state tracker for initialization.
    rng_tracker_name : str, optional
        Stored on the module as ``rng_tracker_name``.
    init_method : callable, optional
        Weight initialization method.
    bias : bool, default True
        Whether to use bias. Must be False when the TP size is > 1.
    return_bias : bool, default False
        Whether to return bias separately instead of applying it in the GEMM.
    params_dtype : torch.dtype, optional
        Parameter data type; defaults to ``torch.get_default_dtype()``.
    parallel_mode : str, optional
        Parallelization mode: "column" or "row".
    device : torch.device or str, default "npu"
        Device for parameters; the string ``"npu"`` resolves to the current
        NPU device.
    ub_overlap_rs, ub_overlap_ag : bool, default False
        Userbuffer overlap flags; must be False (not supported).
    ub_name : str, optional
        Stored on the module; not otherwise used in this class.
    delay_wgrad_compute : bool, default False
        Passed to ``WeightGradStore``. When enabled, the per-GEMM
        ``weight{i}``/``bias{i}`` parameters get ``skip_backward_post_hook``.
    save_original_input : bool, default False
        Stored on the module; ``forward`` currently always passes
        ``save_original_input=False`` to the autograd function.
    single_grouped_weight : bool, default False
        Pack the per-GEMM weights into one ``weight`` parameter of shape
        ``(num_gemms, out_features, in_features)`` (skipped for delayed or
        current-scaling FP8 recipes).
    single_grouped_bias : bool, default False
        Pack the per-GEMM biases into one ``bias`` parameter (only when
        ``bias`` is True).
    name : str, optional
        Module name forwarded to ``TransformerEngineBaseModule``.

    Raises
    ------
    ValueError
        If ``bias`` is True while the TP size is greater than 1.
    """

    def __init__(
        self,
        num_gemms: int,
        in_features: int,
        out_features: int,
        sequence_parallel: bool = False,
        fuse_wgrad_accumulation: bool = False,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
        rng_tracker_name: Optional[str] = None,
        init_method: Optional[Callable] = None,
        bias: bool = True,
        return_bias: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        parallel_mode: Optional[str] = None,
        device: Union[torch.device, str] = "npu",
        ub_overlap_rs: bool = False,
        ub_overlap_ag: bool = False,
        ub_name: Optional[str] = None,
        delay_wgrad_compute: bool = False,
        save_original_input: bool = False,
        single_grouped_weight: bool = False,
        single_grouped_bias: bool = False,
        name: Optional[str] = None,
    ) -> None:
        super().__init__(name)
        self.params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
        self.num_gemms = num_gemms
        self.in_features = in_features
        self.out_features = out_features
        self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
        self.use_bias = bias
        self.return_bias = return_bias
        # The bias is applied inside the GEMM only when it is not returned separately.
        self.apply_bias = bias and not return_bias
        self.ub_overlap_rs = ub_overlap_rs
        self.ub_overlap_ag = ub_overlap_ag
        self.ub_name = ub_name
        self.save_original_input = save_original_input
        self.single_grouped_weight = single_grouped_weight
        self.single_grouped_bias = single_grouped_bias
        assert not ub_overlap_rs and not ub_overlap_ag, (
            "GroupedLinear doesn't support Userbuffer overlap."
        )
        self.init_method = init_method
        self.get_rng_state_tracker = get_rng_state_tracker
        self.rng_tracker_name = rng_tracker_name
        self.wgrad_store = WeightGradStore(delay_wgrad_compute)

        # Slot offsets (within one GEMM's group of FP8 tensors) into the
        # "scaling_fwd" / "scaling_bwd" quantizer tables.
        self._offsets = {
            "input": 0,
            "weight": 1,
            "output": 2,
            "grad_output": 0,
            "grad_input": 1,
        }
        self._num_fp8_tensors_per_gemm = {
            "fwd": 3,
            "bwd": 2,
        }

        if tp_group is None:
            self.tp_size = tp_size
            if tp_size == 1:
                self.set_tensor_parallel_group(tp_group)
        else:
            self.tp_size = get_distributed_world_size(tp_group)
            self.set_tensor_parallel_group(tp_group)
        self.set_nccl_overlap_warning_if_tp()

        if self.tp_size > 1 and bias:
            raise ValueError(
                "GroupedLinear doesn't support bias when TP > 1. "
                "Because the TP communication is handled outside of this module."
            )

        self.parallel_mode = parallel_mode
        assert self.parallel_mode in GemmParallelModes, (
            f"parallel_mode {parallel_mode} not supported"
        )
        if self.parallel_mode == "column":
            self.out_features = divide(self.out_features, self.tp_size)
        elif self.parallel_mode == "row":
            self.in_features = divide(self.in_features, self.tp_size)
        self.sequence_parallel = (self.tp_size > 1) and sequence_parallel

        if isinstance(device, str):
            if device == "npu":
                device = torch.device(torch.npu.current_device())
            else:
                device = torch.device(device)
        self.device = device

        for i in range(self.num_gemms):
            self.register_parameter(
                f"weight{i}",
                torch.nn.Parameter(
                    torch.empty(
                        self.out_features,
                        self.in_features,
                        device=device,
                        dtype=self.params_dtype,
                    ),
                ),
                init_fn=init_method,
                get_rng_state_tracker=get_rng_state_tracker,
                fp8_meta_index=self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"],
            )
            if self.use_bias:
                self.register_parameter(
                    f"bias{i}",
                    torch.nn.Parameter(
                        torch.empty(
                            self.out_features,
                            device=device,
                            dtype=self.params_dtype,
                        ),
                    ),
                    init_fn=init_method_constant(0.0),
                )
            else:
                # Plain (non-parameter) empty placeholder so ``bias{i}`` always exists.
                empty_bias = torch.Tensor().to(dtype=self.params_dtype, device=device)
                setattr(self, f"bias{i}", empty_bias)

        if self.primary_weights_in_fp8:
            self.init_fp8_metadata(num_gemms=self.num_gemms)

        self.reset_parameters()

        if self.wgrad_store.delay_wgrad_compute():
            # NOTE(review): after packing (single_grouped_weight/bias) the
            # parameters are named ``weight``/``bias`` and are not matched
            # here — confirm whether they should also skip the hook.
            per_gemm_names = {
                f"{prefix}{i}" for prefix in ("weight", "bias") for i in range(self.num_gemms)
            }
            for pname, param in self.named_parameters():
                if pname in per_gemm_names:
                    param.skip_backward_post_hook = True

    def make_grouped_weights(self, defer_init=False) -> None:
        """Pack the per-GEMM weights into one 3D ``weight`` parameter.

        The packed tensor has shape ``(num_gemms, out_features, in_features)``
        and ``weight0..weight{N-1}`` are re-registered as ``None``. Packing is
        skipped when the weight quantizer's recipe is delayed or current
        scaling; only the TP attributes are set then. Calling this again after
        packing does not re-pack; it only refreshes the grouped bias and the
        TP attributes.

        Raises
        ------
        RuntimeError
            If the first weight quantizer is marked ``internal``.
        """
        if defer_init:
            return

        if self._has_packed_grouped_weight() and all(
            getattr(self, f"weight{i}", None) is None for i in range(self.num_gemms)
        ):
            # Already packed (e.g. ``reset_parameters()`` called again): the
            # per-GEMM weights no longer exist, so there is nothing to stack.
            if self.use_bias and self.single_grouped_bias:
                self._make_grouped_biases()
            self.set_tensor_parallel_attributes(defer_init=defer_init)
            return

        weight_quantizers = self._get_weight_quantizers()
        first_quantizer = weight_quantizers[0] if weight_quantizers else None
        recipe = first_quantizer._get_compatible_recipe() if first_quantizer is not None else None
        if recipe is not None and (recipe.delayed() or recipe.float8_current_scaling()):
            self.set_tensor_parallel_attributes(defer_init=defer_init)
            return

        if first_quantizer is not None and first_quantizer.internal:
            raise RuntimeError("Found internal quantizer with `single_grouped_weight=True`.")

        weights = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
        grouped_weights = torch.empty(
            self.num_gemms,
            self.out_features,
            self.in_features,
            dtype=weights[0].dtype,
            device=weights[0].device,
        )
        with torch.no_grad():
            for slot, weight in zip(grouped_weights, weights):
                slot.copy_(weight)

        self.register_parameter(
            "weight",
            torch.nn.Parameter(grouped_weights),
            init_fn=self.init_method,
            get_rng_state_tracker=self.get_rng_state_tracker,
            fp8_meta_index=self._offsets["weight"],
        )
        for i in range(self.num_gemms):
            self.register_parameter(f"weight{i}", None)

        if self.use_bias and self.single_grouped_bias:
            self._make_grouped_biases()
        self.set_tensor_parallel_attributes(defer_init=defer_init)

    def _make_grouped_biases(self) -> None:
        """Pack per-GEMM biases into one ``GroupedTensor`` (``single_grouped_bias``).

        No-op when a grouped ``bias`` already exists and every ``bias{i}`` is
        ``None``; otherwise the stacked biases become the ``bias`` parameter
        and ``bias0..bias{N-1}`` are re-registered as ``None``.
        """
        grouped_bias = getattr(self, "bias", None)
        if grouped_bias is not None and all(
            getattr(self, f"bias{i}", None) is None for i in range(self.num_gemms)
        ):
            return
        biases = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
        packed = torch.stack([b.detach().clone() for b in biases], dim=0).contiguous()
        grouped_bias = GroupedTensor.make_grouped_tensor_from_rowwise_data(
            num_tensors=self.num_gemms,
            tensor_shape=(self.out_features,),
            rowwise_data=packed,
            dtype=packed.dtype,
        )
        grouped_bias.requires_grad_(True)
        self.register_parameter("bias", torch.nn.Parameter(grouped_bias))
        for i in range(self.num_gemms):
            self.register_parameter(f"bias{i}", None)

    def reset_parameters(self, defer_init=False):
        """Initialize parameters, then apply the requested weight/bias packing.

        Bias packing is only attempted when the layer actually has biases;
        with ``bias=False`` the ``bias{i}`` attributes are plain placeholder
        tensors that must not be re-registered as parameters.
        """
        super().reset_parameters(defer_init=defer_init)
        if self.single_grouped_weight:
            self.make_grouped_weights(defer_init=defer_init)
        elif self.use_bias and self.single_grouped_bias:
            self._make_grouped_biases()

    def set_tensor_parallel_attributes(self, defer_init=False) -> None:
        """Set attributes needed for TP on the weight and bias parameters."""
        if defer_init:
            return

        grouped_weight = getattr(self, "weight", None)
        if grouped_weight is not None:
            # Packed weight is (num_gemms, out, in): partition dims shift by one.
            set_tensor_model_parallel_attributes(
                tensor=grouped_weight,
                is_parallel=True,
                dim=2 if self.parallel_mode == "row" else 1,
                stride=1,
            )
        else:
            for i in range(self.num_gemms):
                set_tensor_model_parallel_attributes(
                    tensor=getattr(self, f"weight{i}"),
                    is_parallel=True,
                    dim=1 if self.parallel_mode == "row" else 0,
                    stride=1,
                )

        if self.use_bias:
            grouped_bias = getattr(self, "bias", None)
            if grouped_bias is not None:
                if self.parallel_mode == "row":
                    setattr(grouped_bias, "sequence_parallel", self.sequence_parallel)
                elif self.parallel_mode == "column":
                    # NOTE(review): if the packed bias is 2D (num_gemms, out),
                    # the column-parallel partition dim would be 1, not 0 —
                    # confirm against GroupedTensor's logical shape.
                    set_tensor_model_parallel_attributes(grouped_bias, True, 0, 1)
            else:
                for i in range(self.num_gemms):
                    if self.parallel_mode == "row":
                        setattr(
                            getattr(self, f"bias{i}"),
                            "sequence_parallel",
                            self.sequence_parallel,
                        )
                    elif self.parallel_mode == "column":
                        set_tensor_model_parallel_attributes(getattr(self, f"bias{i}"), True, 0, 1)

    def _has_packed_grouped_weight(self) -> bool:
        """Return whether the module currently owns a packed 3D grouped weight."""
        grouped_weight = getattr(self, "weight", None)
        return (
            grouped_weight is not None
            and isinstance(grouped_weight, torch.Tensor)
            and grouped_weight.dim() == 3
            and grouped_weight.size(0) == self.num_gemms
        )

    def _use_performance_grouped_linear(self) -> bool:
        """Performance path is valid only after grouped-weight packing succeeded."""
        return bool(self.single_grouped_weight) and self._has_packed_grouped_weight()

    def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]:
        """Return the weight tensors to pass to the autograd function.

        With ``single_grouped_weight`` this is ``[packed_weight]`` when packing
        succeeded, else the per-GEMM weights (returned as-is). Otherwise a
        grouped ``weight`` attribute, if present, is split into per-GEMM
        tensors, and quantized weights are dequantized (with a warning) when
        FP8 compute is disabled.
        """
        grouped_weight = getattr(self, "weight", None)
        if self.single_grouped_weight:
            if self._has_packed_grouped_weight():
                return [grouped_weight]
            return [getattr(self, f"weight{i}") for i in range(self.num_gemms)]

        if grouped_weight is not None:
            weight_tensors = grouped_weight.quantized_tensors
            if weight_tensors is None:
                weight_tensors = grouped_weight.split_into_quantized_tensors()
        else:
            weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]

        if not self.fp8 and any(isinstance(w, QuantizedTensorStorage) for w in weight_tensors):
            warnings.warn(
                "You are using quantized weights without quantized compute. "
                "Please make sure this is intentional."
            )
            weight_tensors = [
                w.dequantize() if isinstance(w, QuantizedTensorStorage) else w
                for w in weight_tensors
            ]
        return weight_tensors

    def _get_bias_tensors(self, *, for_linear: bool = False) -> List[torch.Tensor]:
        """Bias tensors, optionally shaped for the selected linear autograd path.

        With ``single_grouped_bias`` and the performance path active this is
        ``[grouped_bias]`` regardless of ``for_linear``. Otherwise biases are
        returned per GEMM, except that ``for_linear=True`` on the performance
        path yields a single-element list: the stacked biases when
        ``apply_bias``, else just the first bias.
        """
        grouped_bias = getattr(self, "bias", None)
        if self.single_grouped_bias:
            if grouped_bias is None:
                return [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
            if self._use_performance_grouped_linear():
                return [grouped_bias]
            if isinstance(grouped_bias, GroupedTensor):
                parts = grouped_bias.quantized_tensors
                if parts is None:
                    parts = grouped_bias.split_into_quantized_tensors()
                return [p.reshape(-1) for p in parts]
            assert isinstance(grouped_bias, torch.Tensor), "Expected grouped bias to be a tensor"
            assert grouped_bias.size(0) == self.num_gemms, "Grouped bias size mismatch"
            return [b.reshape(-1) for b in grouped_bias.unbind(dim=0)]

        if grouped_bias is not None:
            if not isinstance(grouped_bias, GroupedTensor):
                assert grouped_bias.size(0) == self.num_gemms, "Grouped bias size mismatch"
                bias_tensors = [b.reshape(-1) for b in grouped_bias.unbind(dim=0)]
            else:
                parts = grouped_bias.quantized_tensors
                if parts is None:
                    parts = grouped_bias.split_into_quantized_tensors()
                bias_tensors = [p.reshape(-1) for p in parts]
        else:
            bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]

        if for_linear and self._use_performance_grouped_linear():
            if self.apply_bias:
                return [torch.stack(bias_tensors, dim=0).contiguous()]
            return bias_tensors[:1]
        return bias_tensors

    def forward(
        self,
        inp: torch.Tensor,
        m_splits: List[int],
        is_first_microbatch: Optional[bool] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """Forward pass of GroupedLinear.

        Parameters
        ----------
        inp : torch.Tensor
            Input tensor (must not be a quantized tensor).
        m_splits : List[int]
            List of integers representing the split of the input tensor;
            its length must equal ``num_gemms``.
        is_first_microbatch : {True, False, None}, default None
            Flag for microbatch handling during gradient accumulation. When not
            None, the FP8 weight workspace is cached between calls.

        Returns
        -------
        Union[torch.Tensor, Tuple[torch.Tensor, ...]]
            Output tensor, or ``(output, biases)`` if ``return_bias=True``,
            where ``biases`` is a list cast to the activation dtype.
        """
        debug = False
        assert not isinstance(inp, QuantizedTensorStorage), (
            "GroupedLinear doesn't support input tensor in FP8."
        )
        assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."

        is_grad_enabled = torch.is_grad_enabled()
        inp = self.prepare_forward(inp, num_gemms=self.num_gemms)
        try:
            weight_tensors = self._get_weight_tensors()
            bias_tensors = self._get_bias_tensors(for_linear=True)
            (
                input_quantizers,
                weight_quantizers,
                output_quantizers,
                grad_input_quantizers,
                grad_weight_quantizers,
                grad_output_quantizers,
            ) = self._get_quantizers()

            use_performance_grouped_linear = self._use_performance_grouped_linear()
            _linear = (
                _PerformanceGroupedLinear if use_performance_grouped_linear else _GroupedLinear
            )
            if is_grad_enabled:
                linear_fn = _linear.apply
                autograd_ctx = []
            else:
                # Call forward directly with ``ctx=None`` to skip autograd bookkeeping.
                linear_fn = _linear.forward
                autograd_ctx = [None]

            cache_weight = is_first_microbatch is not None
            cache_name = None if (is_first_microbatch is None or self.is_fsdp2) else "weight"
            weight_workspaces = self._get_weight_workspace(cache_weight, cache_name)

            non_tensor_args = GroupedLinearArgs(
                m_splits=m_splits,
                use_bias=self.apply_bias,
                is_first_microbatch=is_first_microbatch,
                fp8=self.fp8,
                fp8_calibration=self.fp8_calibration,
                wgrad_store=self.wgrad_store,
                input_quantizers=input_quantizers,
                weight_quantizers=weight_quantizers,
                output_quantizers=output_quantizers,
                grad_input_quantizers=grad_input_quantizers,
                grad_weight_quantizers=grad_weight_quantizers,
                grad_output_quantizers=grad_output_quantizers,
                fuse_wgrad_accumulation=self.fuse_wgrad_accumulation,
                cpu_offloading=is_cpu_offload_enabled(),
                sequence_parallel=self.sequence_parallel,
                activation_dtype=self.activation_dtype,
                is_grad_enabled=is_grad_enabled,
                weight_workspaces=weight_workspaces,
                cache_weight=cache_weight,
                skip_fp8_weight_update=None,
                # NOTE(review): ``self.save_original_input`` is ignored here —
                # confirm this is an intentional NPU limitation.
                save_original_input=False,
                debug=debug,
            )

            if use_performance_grouped_linear:
                weight = (
                    weight_tensors[0]
                    if isinstance(weight_tensors, (list, tuple))
                    else weight_tensors
                )
                bias = bias_tensors[0] if isinstance(bias_tensors, (list, tuple)) else bias_tensors
                out, new_workspaces = linear_fn(*autograd_ctx, inp, non_tensor_args, weight, bias)
            else:
                out, new_workspaces = linear_fn(
                    *autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors
                )

            if new_workspaces is not None:
                if isinstance(new_workspaces, torch.Tensor):
                    new_workspaces = new_workspaces.detach()
                if cache_name is not None:
                    self._fp8_workspaces[cache_name] = new_workspaces
        finally:
            self.end_forward()

        if self.return_bias:
            return_bias_tensors = self._get_bias_tensors()
            return out, [cast_if_needed(b, self.activation_dtype) for b in return_bias_tensors]
        return out

    def _get_weight_quantizers(self) -> List[Any]:
        """Get the weight quantizers of the module.

        Returns ``[None] * num_gemms`` unless FP8 compute, FP8 calibration or
        FP8 primary weights are active. Each quantizer is marked ``internal``
        unless the primary weights are kept in FP8.
        """
        if not self.fp8 and not self.fp8_calibration and not self.primary_weights_in_fp8:
            return [None] * self.num_gemms
        weight_quantizers = [
            self.quantizers["scaling_fwd"][
                self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"]
            ]
            for i in range(self.num_gemms)
        ]
        for quantizer in weight_quantizers:
            if quantizer is not None:
                quantizer.internal = not self.primary_weights_in_fp8
        return weight_quantizers

    def _get_quantizers(self) -> Tuple:
        """Return the per-GEMM quantizer lists used by ``forward``.

        Returns a 6-tuple ``(input, weight, output, grad_input, grad_weight,
        grad_output)`` of lists with ``num_gemms`` entries each. Besides the
        weight quantizers, only the input quantizers (when FP8 is enabled) and
        the grad-output quantizers (FP8 and grad enabled) are populated.
        """
        weight_quantizers = self._get_weight_quantizers()
        input_quantizers = [None] * self.num_gemms
        output_quantizers = [None] * self.num_gemms
        grad_input_quantizers = [None] * self.num_gemms
        grad_weight_quantizers = [None] * self.num_gemms
        grad_output_quantizers = [None] * self.num_gemms
        # NOTE(review): with ``fp8_calibration`` but not ``fp8`` the input
        # quantizers stay ``None`` although ``_get_weight_quantizers`` returns
        # real ones — confirm calibration-only runs are supported.
        if self.fp8:
            input_quantizers = [
                self.quantizers["scaling_fwd"][
                    self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["fwd"]
                ]
                for i in range(self.num_gemms)
            ]
            for quantizer in input_quantizers:
                if quantizer is not None:
                    quantizer.internal = True
                    quantizer.optimize_for_gemm = True
            if torch.is_grad_enabled():
                grad_output_quantizers = [
                    self.quantizers["scaling_bwd"][
                        self._offsets["grad_output"] + i * self._num_fp8_tensors_per_gemm["bwd"]
                    ]
                    for i in range(self.num_gemms)
                ]
                for quantizer in grad_output_quantizers:
                    if quantizer is not None:
                        quantizer.internal = True
                        quantizer.optimize_for_gemm = True
        return (
            input_quantizers,
            weight_quantizers,
            output_quantizers,
            grad_input_quantizers,
            grad_weight_quantizers,
            grad_output_quantizers,
        )

    def _get_weight_workspace(self, cache_weight, cache_name):
        """Return ``[cached_workspace]`` when caching, else ``[None]``."""
        if not cache_weight:
            return [None]
        return [self._fp8_workspaces.get(cache_name)]
# Public API of this module: only the GroupedLinear layer is exported; the
# autograd function ``_GroupedLinear`` is internal.
__all__ = [
    "GroupedLinear",
]