# 05360171 created on 2022-03-18 (commit history)
# Copyright 2021 Huawei Technologies Co., Ltd

#

# Licensed under the Apache License, Version 2.0 (the "License");

# you may not use this file except in compliance with the License.

# You may obtain a copy of the License at

#

#     http://www.apache.org/licenses/LICENSE-2.0

#

# Unless required by applicable law or agreed to in writing, software

# distributed under the License is distributed on an "AS IS" BASIS,

# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

# See the License for the specific language governing permissions and

# limitations under the License.



# Activation functions



import torch

import torch.nn as nn

import torch.nn.functional as F



# Swish https://arxiv.org/pdf/1905.02244.pdf ---------------------------------------------------------------------------

class Swish(nn.Module):
    """Swish activation: element-wise ``x * sigmoid(x)``.

    ``forward`` is a static method, so it can also be called as ``Swish.forward(x)``.
    """

    @staticmethod
    def forward(x):
        gate = torch.sigmoid(x)
        return x * gate





class Hardswish(nn.Module):
    """Export-friendly hard-swish: ``x * clip(x + 3, 0, 6) / 6``.

    The clip is written with ``F.hardtanh`` so the graph exports to TorchScript,
    CoreML and ONNX (``x * F.hardsigmoid(x)`` would only cover TorchScript and CoreML).
    """

    @staticmethod
    def forward(x):
        relu6 = F.hardtanh(x + 3, 0., 6.)  # min(max(x + 3, 0), 6)
        return x * relu6 / 6.





class MemoryEfficientSwish(nn.Module):
    """Swish (``x * sigmoid(x)``) implemented as a custom autograd Function.

    Only the input tensor is saved for the backward pass; the sigmoid is
    recomputed there from that saved input.
    """

    class F(torch.autograd.Function):
        # NOTE: inside this class body the name ``F`` refers to this Function,
        # not to ``torch.nn.functional``.

        @staticmethod
        def forward(ctx, x):
            gate = torch.sigmoid(x)
            ctx.save_for_backward(x)
            return x * gate

        @staticmethod
        def backward(ctx, grad_output):
            # d/dx [x * s(x)] = s(x) * (1 + x * (1 - s(x))), with s = sigmoid.
            (inp,) = ctx.saved_tensors
            sig = torch.sigmoid(inp)
            local_grad = sig * (1 + inp * (1 - sig))
            return grad_output * local_grad

    def forward(self, x):
        return self.F.apply(x)





# Mish https://github.com/digantamisra98/Mish --------------------------------------------------------------------------

class Mish(nn.Module):
    """Mish activation: element-wise ``x * tanh(softplus(x))``.

    ``forward`` is a static method, so it can also be called as ``Mish.forward(x)``.
    """

    @staticmethod
    def forward(x):
        squashed = F.softplus(x).tanh()
        return x * squashed





class MemoryEfficientMish(nn.Module):
    """Mish (``x * tanh(softplus(x))``) implemented as a custom autograd Function.

    Only the input tensor is saved for the backward pass; softplus, tanh and
    sigmoid are recomputed there from that saved input.
    """

    class F(torch.autograd.Function):
        # NOTE: inside this class body ``F`` names this Function, but the method
        # bodies below resolve ``F`` to the module-level ``torch.nn.functional``
        # (a class scope does not enclose the functions defined in it).

        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)
            soft = F.softplus(x)  # ln(1 + exp(x))
            return x * torch.tanh(soft)

        @staticmethod
        def backward(ctx, grad_output):
            # d/dx [x * tanh(sp(x))] = tanh(sp) + x * sigmoid(x) * (1 - tanh(sp)^2),
            # using sp'(x) = sigmoid(x).
            (inp,) = ctx.saved_tensors
            sig = torch.sigmoid(inp)
            th = F.softplus(inp).tanh()
            local_grad = th + inp * sig * (1 - th * th)
            return grad_output * local_grad

    def forward(self, x):
        return self.F.apply(x)





# FReLU https://arxiv.org/abs/2007.11824 -------------------------------------------------------------------------------

class FReLU(nn.Module):
    """Funnel activation: ``max(x, T(x))``, where ``T`` is a depthwise conv followed by batch norm.

    Args:
        c1: number of input channels (also the number of output channels).
        k: kernel size of the depthwise convolution. Padding is ``k // 2``, so for
            odd ``k`` the spatial size of ``T(x)`` equals that of ``x``, as the
            element-wise ``torch.max`` requires. (Previously padding was fixed at 1,
            which only matched for ``k == 3``.)

    ``nn.BatchNorm2d`` expects input of shape (N, C, H, W).
    """

    def __init__(self, c1, k=3):  # ch_in, kernel
        super().__init__()
        # Depthwise (groups=c1) conv with 'same' padding for odd k; stride 1.
        self.conv = nn.Conv2d(c1, c1, k, 1, k // 2, groups=c1)
        self.bn = nn.BatchNorm2d(c1)

    def forward(self, x):
        return torch.max(x, self.bn(self.conv(x)))