import torch
from mindspeed_mm.fsdp.utils.device import get_device_type
from mindspeed_mm.fsdp.ops.swiglu import swiglu
from tests.ut_fsdp.utils.utils import judge_expression
def test_eager_swiglu_basic():
    """Smoke-test the unfused (eager) ``swiglu`` path on a 3-D random tensor.

    A ``(2, 4, 8)`` input is expected to come back as ``(2, 4, 4)`` (last axis
    halved) and to contain no NaN or Inf entries.
    """
    device = torch.device(get_device_type())
    x = torch.randn(2, 4, 8, device=device)

    result = swiglu(x, fused=False)

    expected_shape = (2, 4, 4)
    judge_expression(result.shape == expected_shape)
    # Checks are issued one at a time, in the same order, so the first
    # failing judge_expression is the one reported.
    judge_expression(not result.isnan().any())
    judge_expression(not result.isinf().any())
def test_fused_swiglu_basic():
    """Smoke-test the fused ``swiglu`` path on a 3-D random tensor.

    A ``(4, 16, 8)`` input is expected to come back as ``(4, 16, 4)`` (last
    axis halved) and to contain no NaN or Inf entries.
    """
    device = torch.device(get_device_type())
    x = torch.randn(4, 16, 8, device=device)

    result = swiglu(x, fused=True)

    expected_shape = (4, 16, 4)
    judge_expression(result.shape == expected_shape)
    # Finite-value checks kept separate so NaN and Inf failures are
    # distinguishable.
    judge_expression(not result.isnan().any())
    judge_expression(not result.isinf().any())
def test_fused_vs_eager_consistency_for_swiglu():
    """Check that fused and eager ``swiglu`` agree on bfloat16 inputs of several ranks.

    For each shape, a random bfloat16 tensor is run through ``swiglu`` with
    ``fused=False`` and ``fused=True``. The two outputs must match under
    ``torch.testing.assert_close`` (its default bfloat16 tolerances), and the
    fused output must be free of NaN/Inf values.
    """
    # Every shape has an even last dimension; in the basic tests of this
    # module swiglu halves the last axis by default (e.g. (2, 4, 8) -> (2, 4, 4)).
    test_shapes = [
        (2, 8),
        (4, 16, 8),
        (1, 32, 16, 8),
        (3, 6, 12, 24),
    ]
    # Device is loop-invariant; resolve it once.
    device = torch.device(get_device_type())

    # NOTE(review): swiglu is called without a ``dim`` argument, so each shape
    # covers only the default split axis. Looping over candidate dims without
    # forwarding them would just repeat this same comparison. If swiglu exposes
    # a configurable split axis, add explicit per-axis cases — TODO confirm API.
    for shape in test_shapes:
        inputs = torch.randn(shape, device=device, dtype=torch.bfloat16)

        out_eager = swiglu(inputs, fused=False)
        out_fused = swiglu(inputs, fused=True)

        torch.testing.assert_close(out_fused, out_eager)
        judge_expression(not torch.isnan(out_fused).any())
        judge_expression(not torch.isinf(out_fused).any())