import torch
from mindspeed_mm.fsdp.utils.device import IS_NPU_AVAILABLE

if IS_NPU_AVAILABLE:
    import torch_npu


def eager_swiglu(inputs, dim=-1):
    """Compute SwiGLU using plain PyTorch operations.

    ``inputs`` is split into two chunks along ``dim``; the first chunk is
    passed through SiLU and multiplied element-wise by the second chunk.

    Args:
        inputs: Tensor to split along ``dim``.
        dim: Axis to split on. Negative values count back from the last axis.

    Returns:
        ``silu(first_chunk) * second_chunk``.
    """
    # Turn a negative axis into its non-negative equivalent before chunking.
    split_axis = dim if dim >= 0 else dim + inputs.dim()
    gate, value = inputs.chunk(2, dim=split_axis)
    activated = torch.nn.functional.silu(gate)
    return activated * value


def fused_swiglu(inputs, dim=-1):
    """Compute SwiGLU by delegating to ``torch_npu.npu_swiglu``.

    ``torch_npu`` is only imported at module load when ``IS_NPU_AVAILABLE``
    is truthy, so calling this function without it raises ``NameError``.

    Args:
        inputs: Tensor forwarded unchanged to ``torch_npu.npu_swiglu``.
        dim: Axis forwarded as the ``dim`` keyword argument.

    Returns:
        Whatever ``torch_npu.npu_swiglu`` returns (presumably the fused
        equivalent of ``eager_swiglu`` -- confirm against the torch_npu docs).
    """
    fused_result = torch_npu.npu_swiglu(inputs, dim=dim)
    return fused_result


def swiglu(inputs, dim=-1, fused=True):
    """Apply SwiGLU, choosing between the fused and eager implementations.

    Args:
        inputs: Tensor passed to the selected implementation.
        dim: Axis along which the selected implementation splits ``inputs``.
        fused: Request the fused kernel. It is only used when
            ``IS_NPU_AVAILABLE`` is also truthy; otherwise the eager
            implementation runs.

    Returns:
        The result of ``fused_swiglu`` or ``eager_swiglu``.
    """
    use_fused_kernel = fused and IS_NPU_AVAILABLE
    implementation = fused_swiglu if use_fused_kernel else eager_swiglu
    return implementation(inputs, dim=dim)

def clamp_swiglu(inputs, dim=-1, fused=True, limit=0.0):
    """Apply SwiGLU after optionally clamping the two halves of ``inputs``.

    When ``limit`` is positive, ``inputs`` is split into two chunks along
    ``dim``: the first chunk is clamped from above at ``limit`` and the
    second chunk is clamped to ``[-limit, limit]``. The clamped chunks are
    concatenated back along ``dim`` and passed to ``swiglu``. When ``limit``
    is not positive, ``inputs`` goes to ``swiglu`` unmodified.

    Args:
        inputs: Tensor to activate; it is split in two along ``dim``.
        dim: Axis used for the clamp split and for the SwiGLU split.
        fused: Forwarded to ``swiglu`` to request the fused kernel.
        limit: Clamp bound. Values ``<= 0`` disable clamping.

    Returns:
        The result of ``swiglu`` on the (optionally clamped) tensor.
    """
    if limit > 0:
        # Split along ``dim`` (not a fixed last axis) so the clamped halves
        # are exactly the halves that ``swiglu`` splits out afterwards.
        gate, value = torch.chunk(inputs, 2, dim=dim)
        gate = gate.clamp(max=limit)
        value = value.clamp(min=-limit, max=limit)
        inputs = torch.cat([gate, value], dim=dim)
    # Without a positive limit the chunk/cat round-trip would only copy the
    # tensor, so it is skipped.
    return swiglu(inputs, dim=dim, fused=fused)