import torch.nn as nn
from mindspeed_mm.models.common.activations import get_activation_layer, Sigmoid
from tests.ut.utils import judge_expression
class TestActivation:
    """Unit tests for ``get_activation_layer``.

    Checks that known activation names resolve to a zero-argument callable
    producing the expected layer class, and that an unrecognized name is
    rejected with ``ValueError``.
    """

    def test_activation_when_get_right_act_type(self):
        """Each supported name must build an instance of the expected class."""
        # (activation name, class the built layer must be an instance of)
        # NOTE(review): "swish" is asserted against the project's ``Sigmoid``
        # class, as in the original test -- confirm this mapping is intended.
        expected_by_name = (
            ("relu", nn.ReLU),
            ("gelu", nn.GELU),
            ("swish", Sigmoid),
        )
        for act_type, expected_cls in expected_by_name:
            layer_factory = get_activation_layer(act_type)
            # The returned object is called with no arguments to build the layer.
            judge_expression(isinstance(layer_factory(), expected_cls))

    def test_unknown_activation(self):
        """An unrecognized activation name must raise ``ValueError``."""
        # Only the call under test lives in the ``try`` body, so an error
        # raised by the assertion helper can never be mistaken for (and
        # swallowed as) the expected ``ValueError``.
        try:
            get_activation_layer("invalid_act")
        except ValueError:
            raised_value_error = True
        else:
            raised_value_error = False
        judge_expression(raised_value_error)