import torch
from ..utils import register_tensor_cast_op
@register_tensor_cast_op("swiglu")
def _(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    """Produce the output placeholder for the op registered as "swiglu".

    No arithmetic is performed on the inputs: only their shapes and the
    dtype of ``gate`` are inspected.

    Args:
        gate: First input tensor; supplies the output shape and dtype.
        up: Second input tensor; must have exactly the same shape as ``gate``.

    Returns:
        An uninitialized tensor on the "meta" device with ``gate``'s shape
        and dtype.

    Raises:
        RuntimeError: If ``gate`` and ``up`` have different shapes.
    """
    if gate.shape == up.shape:
        # The result mirrors the inputs' shape; allocate it on "meta" so no
        # real storage is created.
        return torch.empty(list(gate.shape), dtype=gate.dtype, device="meta")
    raise RuntimeError(
        f"Shape mismatch in swiglu: gate {gate.shape} vs up {up.shape}"
    )