import os
from functools import wraps
import torch
import torch_npu
from torch import nn
from megatron.legacy.model.rms_norm import RMSNorm
from megatron.training import get_args
from mindspeed.core.tensor_parallel.mapping import reduce_from_tensor_model_parallel_region_nd
def rms_norm_init_wrapper(fn):
    """Wrap an RMSNorm ``__init__`` to apply fused / nd-matmul options from global args.

    The original initializer ``fn`` runs first. Afterwards ``use_fused_rmsnorm``
    and ``use_nd_matmul`` are copied from ``get_args()`` onto the module. When
    nd-matmul is enabled, ``self.weight`` is replaced by a ones-initialized
    parameter holding only ``dim // tensor_model_parallel_size`` elements, and
    ``self.tensor_model_parallel_size`` is recorded for use by the norm wrapper.

    ``dim`` is taken from the first positional argument, or else the ``dim``
    keyword argument, of the wrapped initializer.

    Raises:
        RuntimeError: nd-matmul and fused RMSNorm are both enabled.
        ValueError: nd-matmul is enabled but ``dim`` was not supplied, or it is
            not divisible by ``tensor_model_parallel_size``. Without this check
            the per-rank weight would be silently truncated by floor division.
    """
    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        fn(self, *args, **kwargs)
        _args = get_args()
        self.use_fused_rmsnorm = _args.use_fused_rmsnorm
        self.use_nd_matmul = _args.use_nd_matmul
        if not self.use_nd_matmul:
            return
        if self.use_fused_rmsnorm:
            raise RuntimeError('nd_matmul does not support fused_rmsnorm temporarily')
        dim = args[0] if len(args) > 0 else kwargs.get('dim')
        tp_size = _args.tensor_model_parallel_size
        if dim is None:
            raise ValueError('nd_matmul RMSNorm requires `dim` to be passed to __init__')
        # The norm wrapper averages per-rank means by dividing by tp_size,
        # which is only exact when every rank holds an equal-sized shard.
        if dim % tp_size != 0:
            raise ValueError(
                f'RMSNorm dim ({dim}) must be divisible by '
                f'tensor_model_parallel_size ({tp_size}) when nd_matmul is enabled'
            )
        self.tensor_model_parallel_size = tp_size
        # Replace the full-size weight created by `fn` with this rank's shard.
        self.weight = torch.nn.Parameter(torch.ones(dim // tp_size))
    return wrapper
def rms_norm_forward_wrapper(fn):
    """Wrap an RMSNorm ``forward`` with the optional ASD hook and fused-kernel path.

    Behavior of the returned ``wrapper(self, x)``:
      * If the ``NPU_ASD_ENABLE`` env var parses (via ``int``) to a non-zero value,
        ``torch_npu.utils.register_asd_hook`` is called with ``x`` and
        ``self.weight`` before anything else.
      * If ``self.use_fused_rmsnorm`` is set, the result is element ``[0]`` of
        ``torch_npu.npu_rms_norm(x, self.weight, epsilon=self.eps)``. That op
        presumably returns a tuple such as ``(output, rstd)``; confirm against
        the torch_npu docs.
      * Otherwise the original ``fn(self, x)`` is returned unchanged.
    """
    @wraps(fn)
    def wrapper(self, x):
        asd_flag = int(os.getenv('NPU_ASD_ENABLE', '0'))
        if asd_flag:
            # Imported lazily so the hook module is only loaded when requested.
            from torch_npu.utils import register_asd_hook
            register_asd_hook(x, self.weight)
        if not self.use_fused_rmsnorm:
            return fn(self, x)
        fused_outputs = torch_npu.npu_rms_norm(x, self.weight, epsilon=self.eps)
        return fused_outputs[0]
    return wrapper
def rms_norm_norm_wrapper(fn):
    """Wrap an RMSNorm ``_norm`` so it normalizes over a hidden dim split across ranks.

    Without nd-matmul (``self.use_nd_matmul`` falsy) this defers to ``fn(self, x)``.

    With nd-matmul, the steps are:
      1. Compute the mean of squares over the last dim of the local shard.
      2. Combine it across ranks with
         ``reduce_from_tensor_model_parallel_region_nd``. This is presumably a
         sum all-reduce over the nd tensor-parallel group; confirm in mindspeed.
      3. Divide by ``self.tensor_model_parallel_size`` to get the mean over the
         full hidden dim (exact only for equal-sized shards).
      4. Return ``x * rsqrt(mean + self.eps)``.
    """
    @wraps(fn)
    def wrapper(self, x):
        if not self.use_nd_matmul:
            return fn(self, x)
        local_sq_mean = x.pow(2).mean(-1, keepdim=True)
        reduced_sq_mean = reduce_from_tensor_model_parallel_region_nd(local_sq_mean)
        global_sq_mean = torch.div(reduced_sq_mean, self.tensor_model_parallel_size)
        return x * torch.rsqrt(global_sq_mean + self.eps)
    return wrapper