"""LayerNorm API for TransformerEngineNPU PyTorch"""
import warnings
from typing import Any, Iterable, Optional, Union
import torch
import torch.nn as nn
class LayerNorm(nn.LayerNorm):
    r"""Layer Normalization

    Applies Layer Normalization over a mini-batch of inputs as described in
    the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta

    :math:`\gamma` and :math:`\beta` are learnable affine transform
    parameters that match the inner-most dimensions of the input
    tensor.

    This class is a thin wrapper around :class:`torch.nn.LayerNorm` that
    accepts the Transformer Engine constructor signature (including its
    deprecated aliases) and forwards everything else to PyTorch.

    Parameters
    ----------
    normalized_shape : int or iterable of int
        Inner dimensions of input tensor
    eps : float, default = 1e-5
        A value added to the denominator of layer normalization for
        numerical stability
    sequence_parallel : bool, optional
        **Legacy parameter.** When not ``None``, a bool attr named
        ``sequence_parallel`` is set on the ``weight`` and ``bias``
        parameters (those that exist). This is custom logic for
        Megatron-LM integration.
    params_dtype : torch.dtype, optional
        **Deprecated** alias for ``dtype``. Must not be combined with
        ``dtype``.
    zero_centered_gamma : bool, default = False
        In Transformer Engine, ``True`` initializes :math:`\gamma` to zero
        and changes the calculation to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta

        This is **not supported** here: passing ``True`` raises
        ``NotImplementedError``.
    hidden_size : int, optional
        **Deprecated** alias for ``normalized_shape``. Exactly one of the
        two must be provided.
    device : torch.device, default = PyTorch's default device
        Tensor device (forwarded to :class:`torch.nn.LayerNorm`).
    dtype : torch.dtype, default = default dtype
        Tensor datatype (forwarded to :class:`torch.nn.LayerNorm`).
    sm_margin : int or dict, default = 0
        In Transformer Engine, the number of SMs to exclude when launching
        CUDA kernels. Accepted here only for API compatibility and ignored.
    **kwargs
        Any other keyword arguments (e.g. ``elementwise_affine``) are
        forwarded unchanged to :class:`torch.nn.LayerNorm`.

    Raises
    ------
    RuntimeError
        If neither or both of ``normalized_shape`` / ``hidden_size`` are
        given, or if both ``dtype`` and ``params_dtype`` are given.
    NotImplementedError
        If ``zero_centered_gamma`` is ``True``.
    """

    def __init__(
        self,
        normalized_shape: Union[Iterable[int], int, None] = None,
        eps: float = 1e-5,
        sequence_parallel: Optional[bool] = None,
        params_dtype: Optional[torch.dtype] = None,
        zero_centered_gamma: bool = False,
        hidden_size: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        # Resolve the deprecated ``hidden_size`` alias for ``normalized_shape``.
        if normalized_shape is None:
            if hidden_size is None:
                raise RuntimeError(
                    "Neither `normalized_shape` nor `hidden_size` (deprecated) args are provided"
                )
            warnings.warn(
                "`hidden_size` arg has been renamed to `normalized_shape` "
                "for compatibility with `torch.nn.LayerNorm`.",
                DeprecationWarning,
                stacklevel=2,
            )
            normalized_shape = hidden_size
        elif hidden_size is not None:
            raise RuntimeError(
                "Both `normalized_shape` and `hidden_size` (deprecated) args are provided"
            )

        # Resolve the deprecated ``params_dtype`` alias for ``dtype``.
        if params_dtype is not None:
            if "dtype" in kwargs:
                raise RuntimeError(
                    "Both `dtype` and `params_dtype` (deprecated) kwargs are provided"
                )
            kwargs["dtype"] = params_dtype

        # ``sm_margin`` only tunes CUDA kernel launches in Transformer Engine and
        # has no counterpart here; torch.nn.LayerNorm would reject it with a
        # TypeError, so drop it to keep TE-style call sites working.
        kwargs.pop("sm_margin", None)

        if zero_centered_gamma:
            raise NotImplementedError("Zero-centered gamma is not supported in this dummy implementation.")

        # Must be assigned before super().__init__(): torch.nn.LayerNorm's
        # constructor calls self.reset_parameters(), which reads this attribute.
        self.sequence_parallel: Optional[bool] = sequence_parallel

        super().__init__(
            normalized_shape,
            eps=eps,
            **kwargs,
        )

    def fast_setattr(self, name: str, value: Any) -> None:
        """Fast attribute set for non-parameter fields.

        Writes straight into ``self.__dict__``, bypassing
        ``nn.Module.__setattr__``; do not use it for parameters, buffers or
        submodules.
        """
        self.__dict__[name] = value

    def reset_layer_norm_parameters(self) -> None:
        """Init LN params (deprecated alias for :meth:`reset_parameters`)."""
        warnings.warn(
            "This method will be deprecated in an upcoming release. "
            "Update your code to use LayerNorm.reset_parameters() instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        self.reset_parameters()

    def reset_parameters(self, defer_init: Optional[bool] = None) -> None:
        """Init LayerNorm parameters.

        Parameters
        ----------
        defer_init : bool, optional
            **Deprecated.** When truthy, initialization is skipped entirely.
            Passing any non-``None`` value emits a ``DeprecationWarning``.
        """
        if defer_init is not None:
            warnings.warn(
                "defer_init argument to reset_parameters function is deprecated. Set device to"
                ' "meta" instead.',
                DeprecationWarning,
                stacklevel=2,
            )
            if defer_init:
                return
        super().reset_parameters()
        if self.sequence_parallel is not None:
            # ``weight``/``bias`` are None when the layer is built with
            # ``elementwise_affine=False`` (and ``bias`` also with ``bias=False``),
            # so only tag the parameters that actually exist.
            for param in (self.weight, self.bias):
                if param is not None:
                    param.sequence_parallel = self.sequence_parallel

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Forward pass (delegates to :meth:`torch.nn.LayerNorm.forward`)."""
        return super().forward(input)
# Names exported by ``from <module> import *``.
__all__ = [
    "LayerNorm",
]