import torch
import torch_npu
# Process-wide cache for the SwiGLU clamp limit. It stays None until the first
# call to fused_swiglu_with_limit, which fills it from get_swiglu_limit().
_SWIGLU_LIMIT = None
def get_swiglu_limit():
    """Return the configured SwiGLU clamp limit from the Megatron args.

    Reads ``swiglu_limit`` from ``megatron.training.global_vars.get_args()``.
    A missing attribute and an attribute explicitly set to ``None`` both yield
    ``0``. ``fused_swiglu_with_limit`` treats ``0`` as "no clamping", since it
    only clamps when ``limit > 0``.
    """
    # Imported inside the function, presumably to defer the Megatron
    # dependency until first use — TODO confirm.
    from megatron.training.global_vars import get_args
    limit = getattr(get_args(), 'swiglu_limit', None)
    # An unset argparse option is typically None, and `None > 0` raises
    # TypeError in the caller, so normalize it to "disabled".
    return limit if limit is not None else 0
def fused_swiglu_with_limit(x, limit=0, dim=-1):
    """Apply ``torch_npu.npu_swiglu`` to ``x``, optionally clamping its halves first.

    The clamp limit is read once via ``get_swiglu_limit()`` and cached in the
    module-level ``_SWIGLU_LIMIT``. The cached value is used on every call,
    including the first one.

    When the limit is positive, ``x`` is split into two halves along ``dim``.
    This is the same dimension ``npu_swiglu`` splits on. The first half is
    clamped from above to ``limit``, and the second half to
    ``[-limit, limit]``. Clamping is done in float32, and the result is cast
    back to ``x.dtype`` before the kernel runs.

    Args:
        x: Input tensor. Its size along ``dim`` is assumed to be even so that it
            splits into two equal halves — TODO confirm with callers.
        limit: Ignored; the configured ``swiglu_limit`` always takes
            precedence. Kept so existing call sites continue to work.
        dim: Dimension holding the two SwiGLU halves. Defaults to -1, the
            dimension passed to ``npu_swiglu``.

    Returns:
        The result of ``torch_npu.npu_swiglu(..., dim=dim)``.
    """
    global _SWIGLU_LIMIT
    if _SWIGLU_LIMIT is None:
        # `or 0` keeps a None config value from breaking `limit > 0` below and
        # from leaving the cache unset, which would re-read the args every call.
        _SWIGLU_LIMIT = get_swiglu_limit() or 0
    # Use the configured value on every call. Previously the first call skipped
    # this and ran with the caller's `limit` (default 0, i.e. unclamped).
    limit = _SWIGLU_LIMIT
    if limit > 0:
        orig_dtype = x.dtype
        # Split along the same dim npu_swiglu uses, so each clamp applies to the
        # half of the input it is meant for (not to a slice of rows).
        first, second = torch.chunk(x.float(), 2, dim=dim)
        first = first.clamp(max=limit)
        second = second.clamp(min=-limit, max=limit)
        x = torch.cat([first, second], dim=dim).to(orig_dtype)
    return torch_npu.npu_swiglu(x, dim=dim)