"""
DEPRECATED MODULE
This module is deprecated and will be removed in future releases.
Normalization
"""
import mindspore.common.dtype as mstype
import mindspore.ops.operations as P
from mindspore import Parameter, nn
from mindspore.common.initializer import initializer
from mindformers.version_control import check_rmsnorm_big_kernel_valid
class RMSNorm(nn.Cell):
    r"""
    RMSNorm layer: scales the input by the reciprocal root-mean-square of its
    last axis and by a learnable weight.

    When ``check_rmsnorm_big_kernel_valid()`` is true, the ``P.RmsNorm``
    primitive is used; otherwise the same computation is composed from
    Square / ReduceMean / Add / Rsqrt / Mul primitives.

    Args:
        dim (int): Size of the normalized axis. It is used as the single
            element of the weight shape ``(dim,)``.
        eps (float): The epsilon value added to the mean square before the
            reciprocal square root. Default: 1e-6.
        compute_type: The dtype of the weight and of the internal computation.
            Default: ``mstype.float32``.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(batch, seq\_length, hidden\_size)`.

    Outputs:
        Tensor of shape :math:`(batch, seq\_length, hidden\_size)`, in the dtype of ``x``.
    """

    def __init__(self, dim, eps=1e-6, compute_type=mstype.float32):
        super().__init__()
        self.eps = eps
        self.compute_type = compute_type
        # Learnable scale, initialised to ones and stored in compute_type;
        # parallel_optimizer is disabled for this parameter.
        self.weight = Parameter(initializer('ones', (dim,), dtype=self.compute_type), parallel_optimizer=False)
        # Both implementations need a Cast primitive, so create it once.
        self.cast = P.Cast()
        if check_rmsnorm_big_kernel_valid():
            # Path using the RmsNorm primitive.
            self.norm = P.RmsNorm(eps)
            self.rms_norm = self._rms_norm
            self.self_define = False
            self.rcast = P.Cast()
        else:
            # Path composed from elementary primitives.
            self.mul = P.Mul()
            self.mul2 = P.Mul()
            self.square = P.Square()
            self.mean = P.ReduceMean(keep_dims=True)
            self.add = P.Add()
            self.rsqrt = P.Rsqrt()
            self.rms_norm = self._self_norm
            self.self_define = True

    def _self_norm(self, x):
        """RMSNorm composed from elementary primitives.

        The scaling factor ``rsqrt(mean(x^2, -1) + eps)`` is computed in
        ``compute_type``; it and the weight are then cast back to the dtype
        of ``x`` so the two multiplications run in the input dtype.
        """
        original_type = x.dtype
        norm_factor = self.square(self.cast(x, self.compute_type))
        # Mean over the last axis; keep_dims=True keeps it broadcastable against x.
        norm_factor = self.mean(norm_factor, -1)
        norm_factor = self.add(norm_factor, self.eps)
        norm_factor = self.rsqrt(norm_factor)
        output = self.mul(x, self.cast(norm_factor, original_type))
        output = self.mul2(output, self.cast(self.weight, original_type))
        return output

    def _rms_norm(self, x):
        """RMSNorm through the ``P.RmsNorm`` primitive.

        The input is cast to ``compute_type`` for the primitive and the
        result is cast back to the dtype of ``x``.
        """
        original_type = x.dtype
        # Only the first output of the primitive is used.
        output = self.norm(self.cast(x, self.compute_type), self.weight)[0]
        return self.rcast(output, original_type)

    def construct(self, x):
        """Forward of RMSNorm."""
        # self.rms_norm was bound in __init__ to one of the two implementations.
        return self.rms_norm(x)

    def sharded_state_dict(self):
        """Provide the sharded state dict of this layer.

        Returns:
            dict: Maps the weight's parameter name to a dict holding its
            ``'shape'`` and its ``'shard'`` tuple, which is ``(1,)`` here.
        """
        w_shard = (1,)
        state_dict = {}
        state_dict[self.weight.name] = {'shape': self.weight.shape,
                                        'shard': w_shard}
        return state_dict