"""Fusible operation for SwiGLU and variants."""
from __future__ import annotations
from collections.abc import Iterable
from typing import Any, Optional
import torch
import torch_npu
from .npu_activation import (
swiglu_fwd, swiglu_bwd,
)
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...tensor import Float8CurrentScalingQuantizer, Quantizer
from ...utils import clear_tensor_data
from ..op import BasicOperation, OperationContext
from .._common import maybe_dequantize
__all__ = ["SwiGLU"]
class SwiGLU(BasicOperation):
    r"""Swish gated linear unit

    The input tensor is split into chunks :math:`a` and :math:`b`
    along the last dimension and the following is computed:

    .. math::

       \text{SwiGLU}(a,b) = \text{SiLU}(a) * b

    where

    .. math::

       \text{SiLU}(x) = x \sigma(x) = \frac{x}{1+\exp(-x)}

    .. warning::

       Transformer Engine's gated activations and PyTorch's GLU
       activation follow opposite conventions for :math:`a` and
       :math:`b`. Transformer Engine applies the gating function to
       the first half of the input tensor, while PyTorch applies it to
       the second half.

    The Sigmoid Linear Unit (SiLU) gating function is also known as
    the swish function. See
    `GLU Variants Improve Transformer <https://arxiv.org/abs/2002.05202>`__.

    Parameters
    ----------
    cache_quantized_input : bool, default = False
        Quantize input tensor when caching for use in the backward
        pass. This will typically reduce memory usage but require
        extra compute and increase numerical error. This feature is
        highly experimental.
    glu_interleave_size : int, optional
        When set, the GLU activations will use a block interleaved
        format. Instead of interpreting the input tensor as a
        concatenation of gates and linear units (e.g.
        :math:`[a_1, a_2, a_3, a_4, b_1, b_2, b_3, b_4]`
        in the above notation), it will be interpreted
        as alternating blocks of gates and linear units (e.g.
        :math:`[a_1, a_2, b_1, b_2, a_3, a_4, b_3, b_4]`
        when the interleave size is 2). The last dimension of the
        input must be a multiple of ``2 * glu_interleave_size``.
        This data format is highly experimental and is primarily
        intended to support some advanced fused kernels.

    """

    def __init__(
        self,
        *,
        cache_quantized_input: bool = False,
        glu_interleave_size: Optional[int] = None,
    ):
        super().__init__()
        self.cache_quantized_input: bool = cache_quantized_input
        self.glu_interleave_size: Optional[int] = glu_interleave_size

    def _swap_interleave_layout(
        self,
        tensor: torch.Tensor,
        *,
        to_interleaved: bool,
    ) -> torch.Tensor:
        """Convert the last dimension between interleaved and concatenated layouts.

        Parameters
        ----------
        tensor : torch.Tensor
            Tensor whose last dimension holds gate and linear units.
        to_interleaved : bool
            ``False`` regroups alternating blocks into the concatenated
            layout; ``True`` applies the inverse permutation.

        Returns
        -------
        torch.Tensor
            ``tensor`` itself when ``glu_interleave_size`` is ``None``,
            otherwise a permuted tensor with the same shape.

        Raises
        ------
        RuntimeError
            If ``glu_interleave_size`` is not positive or the last
            dimension is not a multiple of ``2 * glu_interleave_size``.

        """
        block = self.glu_interleave_size
        if block is None:
            return tensor

        shape = tensor.size()
        width = shape[-1]

        # Without this check a non-divisible width can still reshape
        # successfully (the -1 dimension absorbs the mismatch) and the
        # final view would silently return scrambled data.
        if block <= 0 or width % (2 * block) != 0:
            raise RuntimeError(
                f"Last tensor dimension ({width}) must be a multiple of "
                f"2 * glu_interleave_size (glu_interleave_size={block})"
            )

        num_blocks = width // (2 * block)
        if to_interleaved:
            # (rows, gate/linear, block index, block) -> swap axes 1 and 2
            blocked = tensor.reshape(-1, 2, num_blocks, block)
        else:
            # (rows, block index, gate/linear, block) -> swap axes 1 and 2
            blocked = tensor.reshape(-1, num_blocks, 2, block)
        return blocked.transpose(1, 2).contiguous().view(shape)

    def op_forward(
        self,
        ctx: OperationContext,
        input_: torch.Tensor,
        prev_op_grad_output_quantizer: Optional[Quantizer],
        next_op_input_quantizer: Optional[Quantizer],
    ) -> torch.Tensor:
        """Apply ``swiglu_fwd`` to the input and stash state for backward.

        Parameters
        ----------
        ctx : OperationContext
            Context used to pass state to ``op_backward``.
        input_ : torch.Tensor
            Input tensor; dequantized to the compute dtype if needed.
        prev_op_grad_output_quantizer : Quantizer, optional
            Stored on ``ctx`` when gradients are required.
        next_op_input_quantizer : Quantizer, optional
            Not used by this operation.

        Returns
        -------
        torch.Tensor
            Result of ``swiglu_fwd``.

        Raises
        ------
        RuntimeError
            If the compute dtype is not FP32, FP16 or BF16, or if the
            input is incompatible with ``glu_interleave_size``.

        """
        # Compute dtype
        # NOTE(review): is_autocast_enabled() is called without a device
        # type while the dtype is read for "npu" -- confirm that the
        # argument-less call reflects NPU autocast state in this setup.
        dtype: torch.dtype
        if torch.is_autocast_enabled():
            dtype = torch.get_autocast_dtype("npu")
        else:
            dtype = input_.dtype
        if dtype not in (torch.float32, torch.float16, torch.bfloat16):
            raise RuntimeError(f"Unsupported dtype ({dtype})")

        # Make sure the input is a contiguous tensor in the compute dtype
        input_ = maybe_dequantize(input_.contiguous(), dtype)

        # The kernel is fed the concatenated layout, so undo any block
        # interleaving first
        swiglu_in = self._swap_interleave_layout(input_, to_interleaved=False)
        out = swiglu_fwd(swiglu_in)

        # Save state for the backward pass
        if ctx.requires_grad:
            # Only pay for FP8 quantization when the tensor is actually
            # cached; the original (still interleaved) input is what
            # gets saved
            if self.cache_quantized_input:
                input_quantizer = Float8CurrentScalingQuantizer(
                    torch.float8_e4m3fn,
                    input_.device,
                )
                input_quantizer.set_usage(rowwise=True, columnwise=False)
                input_ = input_quantizer(input_)
            if is_cpu_offload_enabled():
                mark_activation_offload(input_)
            ctx.save_for_backward(input_)
            ctx.dtype = dtype
            ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer

        return out

    def op_backward(
        self,
        ctx: OperationContext,
        grad_output: torch.Tensor,
    ) -> tuple[torch.Tensor, tuple[()]]:
        """Compute the input gradient with ``swiglu_bwd``.

        Parameters
        ----------
        ctx : OperationContext
            Context populated by ``op_forward``.
        grad_output : torch.Tensor
            Gradient with respect to the forward output.

        Returns
        -------
        tuple
            The input gradient (in the same layout as the forward
            input) and an empty tuple, as this operation has no
            parameters.

        """
        # Saved tensors from the forward pass
        (input_,) = ctx.saved_tensors

        # Make sure tensors are contiguous and in the compute dtype
        x = maybe_dequantize(input_.contiguous(), ctx.dtype)
        dy = maybe_dequantize(grad_output.contiguous(), ctx.dtype)

        # Recreate the concatenated layout that the forward kernel saw
        swiglu_in = self._swap_interleave_layout(x, to_interleaved=False)
        dx = swiglu_bwd(swiglu_in, dy)

        # Return the gradient in the caller's (possibly interleaved) layout
        dx = self._swap_interleave_layout(dx, to_interleaved=True)

        # Release the cached input now that it is no longer needed
        clear_tensor_data(input_)

        return dx, ()