import logging
import torch
import torch.nn as nn
# Module-level logger; the Split* models' forward() methods log intermediate tensor shapes through it.
logger = logging.getLogger(__name__)
class SingleLinear(nn.Module):
    """args_shape: [(2, 16)]
    """
    # Three Linear + BatchNorm1d pairs (16 -> 64 -> 64 -> 48 features).
    # Input is (N, 16); output is (N, 48).
    #
    # BatchNorm1d, not BatchNorm2d: nn.Linear transforms the last dimension and
    # yields a (N, features) tensor here, which BatchNorm2d (4-D input only)
    # rejects. BatchNorm1d has the same parameters and buffers, so the
    # state_dict keys are unchanged.
    # The declared batch size is 2 because BatchNorm in training mode needs more
    # than one value per channel to compute batch statistics.

    def __init__(self):
        # super() must name this class; naming SingleConv raised TypeError
        # because a SingleLinear instance is not a SingleConv.
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Linear(in_features=16, out_features=64, bias=True),
            nn.BatchNorm1d(64))
        self.layer2 = nn.Sequential(
            nn.Linear(in_features=64, out_features=64),
            nn.BatchNorm1d(64))
        self.layer3 = nn.Sequential(
            nn.Linear(in_features=64, out_features=48),
            nn.BatchNorm1d(48))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map an (N, 16) tensor to an (N, 48) tensor through the three stages."""
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x
class SingleConv(nn.Module):
    """args_shape: [(1, 16, 28, 28)]
    """
    # Plain (groups=1) conv chain, 16 -> 64 -> 48 channels. Each stage is a
    # 3x3 Conv2d with bias followed by BatchNorm2d. No padding is set, so each
    # stage trims 2 pixels from height and width.

    def __init__(self):
        super().__init__()
        stage1_conv = nn.Conv2d(in_channels=16, out_channels=64, groups=1, kernel_size=3, bias=True)
        self.layer1 = nn.Sequential(stage1_conv, nn.BatchNorm2d(64))
        stage2_conv = nn.Conv2d(in_channels=64, out_channels=48, groups=1, kernel_size=3, bias=True)
        self.layer2 = nn.Sequential(stage2_conv, nn.BatchNorm2d(48))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run both conv + bn stages in order."""
        return self.layer2(self.layer1(x))
class SingleDepthwsieConv(nn.Module):
    """args_shape: [(1, 16, 28, 28)]
    """
    # Depthwise 3x3 conv (groups == channels == 16) + BatchNorm2d, followed by a
    # dense 3x3 conv 16 -> 8 + BatchNorm2d. The misspelled class name is kept
    # because other code may refer to it by this name.

    def __init__(self):
        super().__init__()
        depthwise = nn.Conv2d(in_channels=16, out_channels=16, groups=16, kernel_size=3, bias=True)
        self.layer1 = nn.Sequential(depthwise, nn.BatchNorm2d(16))
        dense = nn.Conv2d(in_channels=16, out_channels=8, groups=1, kernel_size=3, bias=True)
        self.layer2 = nn.Sequential(dense, nn.BatchNorm2d(8))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the depthwise stage, then the dense stage."""
        features = self.layer1(x)
        return self.layer2(features)
class GroupConv(nn.Module):
    """args_shape: [(1, 16, 28, 28)]
    """
    # Three 3x3 conv + BatchNorm2d stages (16 -> 64 -> 128 -> 8 channels); the
    # middle stage is a 2-group convolution, the outer two are dense.

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(16, 64, kernel_size=3, groups=1, bias=True),
            nn.BatchNorm2d(64),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, groups=2, bias=True),
            nn.BatchNorm2d(128),
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(128, 8, kernel_size=3, groups=1, bias=True),
            nn.BatchNorm2d(8),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the three stages in order."""
        out = self.layer1(x)
        out = self.layer2(out)
        return self.layer3(out)
class SingleGroupConv(nn.Module):
    """args_shape: [(1, 16, 28, 28)]
    """
    # A lone grouped Conv2d (16 -> 64 channels, 4 groups, 3x3, with bias) and
    # no BatchNorm; each group maps 4 input channels to 16 output channels.

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Conv2d(16, 64, kernel_size=3, groups=4, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the grouped convolution."""
        return self.layer1(x)
class SingleDeconv(nn.Module):
    """args_shape: [(1, 16, 28, 28)]
    """
    # A lone grouped transposed convolution (16 -> 16 channels, 4 groups, 3x3,
    # no bias). With stride 1 and no padding it grows each spatial side by 2.

    def __init__(self):
        super().__init__()
        self.layer1 = nn.ConvTranspose2d(16, 16, kernel_size=3, groups=4, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the transposed convolution."""
        return self.layer1(x)
class MultiDeconv(nn.Module):
    """args_shape: [(1, 16, 28, 28)]
    """
    # Three 3x3 ConvTranspose2d + BatchNorm2d stages (16 -> 64 -> 32 -> 8
    # channels). The middle stage uses 32 groups, so each group maps 2 input
    # channels to 1 output channel.

    def __init__(self):
        super().__init__()
        up1 = nn.ConvTranspose2d(in_channels=16, out_channels=64, groups=1, kernel_size=3, bias=True)
        self.layer1 = nn.Sequential(up1, nn.BatchNorm2d(64))
        up2 = nn.ConvTranspose2d(in_channels=64, out_channels=32, groups=32, kernel_size=3, bias=True)
        self.layer2 = nn.Sequential(up2, nn.BatchNorm2d(32))
        up3 = nn.ConvTranspose2d(in_channels=32, out_channels=8, groups=1, kernel_size=3, bias=True)
        self.layer3 = nn.Sequential(up3, nn.BatchNorm2d(8))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the three transposed-conv stages in order."""
        return self.layer3(self.layer2(self.layer1(x)))
class ConcatConv(nn.Module):
    """args_shape: [(1, 16, 28, 28)]
    """
    # Two parallel conv + bn branches on the same input (48 channels dense,
    # 128 channels with 2 groups), concatenated along channels (48 + 128 = 176)
    # and fed to a final 176 -> 8 conv + bn.

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(16, 48, kernel_size=3, groups=1, bias=True),
            nn.BatchNorm2d(48))
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 128, kernel_size=3, groups=2, bias=True),
            nn.BatchNorm2d(128))
        self.layer3 = nn.Sequential(
            nn.Conv2d(176, 8, kernel_size=3, groups=1, bias=True),
            nn.BatchNorm2d(8))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run both branches, concatenate on dim 1, then apply the head."""
        branch_a = self.layer1(x)
        branch_b = self.layer2(x)
        merged = torch.cat((branch_a, branch_b), dim=1)
        return self.layer3(merged)
class EltwiseConv(nn.Module):
    """args_shape: [(1, 16, 28, 28)]
    """
    # Four parallel 16 -> 160 conv + bn branches on the same input. Branches 1
    # and 2 are summed; that sum is added to branch 3 and, separately, to
    # branch 4. Each of those two sums goes through its own 160 -> 80 conv + bn
    # head, and both head outputs are returned as a tuple.

    @staticmethod
    def _conv_bn(in_ch, out_ch):
        """Build a 3x3, 2-group, biased Conv2d followed by BatchNorm2d."""
        return nn.Sequential(
            nn.Conv2d(in_channels=in_ch, out_channels=out_ch, groups=2, kernel_size=3, bias=True),
            nn.BatchNorm2d(out_ch))

    def __init__(self):
        super().__init__()
        # Creation order is kept layer1..layer6 so seeded initialisation and
        # state_dict ordering match.
        self.layer1 = self._conv_bn(16, 160)
        self.layer2 = self._conv_bn(16, 160)
        self.layer3 = self._conv_bn(16, 160)
        self.layer4 = self._conv_bn(16, 160)
        self.layer5 = self._conv_bn(160, 80)
        self.layer6 = self._conv_bn(160, 80)

    def forward(self, x):
        """Return (head5(b1 + b2 + b3), head6(b1 + b2 + b4))."""
        a = self.layer1(x)
        b = self.layer2(x)
        ab = a + b
        c = self.layer3(x)
        abc = ab + c
        d = self.layer4(x)
        abd = ab + d
        return self.layer5(abc), self.layer6(abd)
class SplitConv(nn.Module):
    """args_shape: [(1, 16, 28, 28)]
    """
    # A 2-group conv + bn to 160 channels, split channel-wise into two 80-channel
    # halves; each half goes through its own 2-group conv + bn, and both branch
    # outputs are returned as a tuple.

    def __init__(self):
        super(SplitConv, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=160, groups=2, kernel_size=3, bias=True),
            nn.BatchNorm2d(160))
        self.layer2_1 = nn.Sequential(
            nn.Conv2d(in_channels=80, out_channels=160, groups=2, kernel_size=3, bias=True),
            nn.BatchNorm2d(160))
        self.layer2_2 = nn.Sequential(
            nn.Conv2d(in_channels=80, out_channels=80, groups=2, kernel_size=3, bias=True),
            nn.BatchNorm2d(80))

    def forward(self, x):
        """Return (layer2_1(first half), layer2_2(second half)) of layer1's output."""
        x1 = self.layer1(x)
        logger.info('%s', x1.shape)
        halves = torch.split(x1, 80, dim=1)
        # Log the shape of each half; previously the first half was logged twice.
        logger.info('%s %s', halves[0].shape, halves[1].shape)
        x2_1 = self.layer2_1(halves[0])
        x2_2 = self.layer2_2(halves[1])
        return x2_1, x2_2
class SplitConcatGroupConv(nn.Module):
    """args_shape: [(1, 16, 28, 28)]
    """
    # Dense conv + bn to 160 channels, split into two 80-channel halves, each
    # half through its own 2-group conv + bn (-> 160 and -> 80 channels), the
    # results concatenated (240 channels) and fed to a 2-group 240 -> 16 head.

    def __init__(self):
        super(SplitConcatGroupConv, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=160, groups=1, kernel_size=3, bias=True),
            nn.BatchNorm2d(160))
        self.layer2_1 = nn.Sequential(
            nn.Conv2d(in_channels=80, out_channels=160, groups=2, kernel_size=3, bias=True),
            nn.BatchNorm2d(160))
        self.layer2_2 = nn.Sequential(
            nn.Conv2d(in_channels=80, out_channels=80, groups=2, kernel_size=3, bias=True),
            nn.BatchNorm2d(80))
        self.layer3 = nn.Sequential(
            nn.Conv2d(in_channels=240, out_channels=16, groups=2, kernel_size=3, bias=True),
            nn.BatchNorm2d(16))

    def forward(self, x):
        """Split, run both branches, concatenate on dim 1, and apply the head."""
        x1 = self.layer1(x)
        logger.info('%s', x1.shape)
        halves = torch.split(x1, 80, dim=1)
        # Log the shape of each half; previously the first half was logged twice.
        logger.info('%s %s', halves[0].shape, halves[1].shape)
        x2_1 = self.layer2_1(halves[0])
        x2_2 = self.layer2_2(halves[1])
        x2 = torch.cat([x2_1, x2_2], 1)
        x3 = self.layer3(x2)
        return x3
class SplitConcatConv(nn.Module):
    """args_shape: [(1, 16, 28, 28)]
    """
    # Dense conv + bn to 160 channels, split into two 80-channel halves, each
    # half through its own 2-group conv + bn (-> 160 and -> 80 channels), the
    # results concatenated (240 channels) and fed to a dense 240 -> 16 head.

    def __init__(self):
        super(SplitConcatConv, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=160, groups=1, kernel_size=3, bias=True),
            nn.BatchNorm2d(160))
        self.layer2_1 = nn.Sequential(
            nn.Conv2d(in_channels=80, out_channels=160, groups=2, kernel_size=3, bias=True),
            nn.BatchNorm2d(160))
        self.layer2_2 = nn.Sequential(
            nn.Conv2d(in_channels=80, out_channels=80, groups=2, kernel_size=3, bias=True),
            nn.BatchNorm2d(80))
        self.layer3 = nn.Sequential(
            nn.Conv2d(in_channels=240, out_channels=16, groups=1, kernel_size=3, bias=True),
            nn.BatchNorm2d(16))

    def forward(self, x):
        """Split, run both branches, concatenate on dim 1, and apply the head."""
        x1 = self.layer1(x)
        logger.info('%s', x1.shape)
        halves = torch.split(x1, 80, dim=1)
        # Log the shape of each half; previously the first half was logged twice.
        logger.info('%s %s', halves[0].shape, halves[1].shape)
        x2_1 = self.layer2_1(halves[0])
        x2_2 = self.layer2_2(halves[1])
        x2 = torch.cat([x2_1, x2_2], 1)
        x3 = self.layer3(x2)
        return x3
class AvgpoolFlatten(nn.Module):
    """args_shape: [(1, 16, 28, 28)]
    """
    # Conv + bn (16 -> 32 channels), global average pool to 1x1, flatten to
    # (N, 32), then a 32 -> 10 Linear classifier.

    def __init__(self):
        super().__init__()
        conv = nn.Conv2d(in_channels=16, out_channels=32, groups=1, kernel_size=3, bias=True)
        self.layer1 = nn.Sequential(conv, nn.BatchNorm2d(32))
        self.layer2 = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.layer3 = nn.Linear(in_features=32, out_features=10, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return (N, 10) logits-style output for an (N, 16, H, W) input."""
        feats = self.layer1(x)
        pooled = self.layer2(feats)
        # torch.flatten (the function, not the Tensor method) keeps the traced op identical.
        flat = torch.flatten(pooled, start_dim=1)
        return self.layer3(flat)
class NetTrainBranch(nn.Module):
    """ args_shape: [(1, 2, 28, 28)]
    conv + bn (layer1, always applied)
    training mode -> conv(with bias) + bn + relu (layer2)
    eval mode -> conv(with bias) + bn + relu (layer3)
    """
    # layer2 and layer3 have identical structure but separate parameters, so
    # the graph executed in training mode uses different weights from the one
    # executed in eval mode.

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(2, 16, kernel_size=3, bias=False),
            nn.BatchNorm2d(16))
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 16, kernel_size=3, bias=True),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True))
        self.layer3 = nn.Sequential(
            nn.Conv2d(16, 16, kernel_size=3, bias=True),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the stem, then layer2 when training or layer3 otherwise."""
        stem = self.layer1(x)
        if self.training:
            return self.layer2(stem)
        return self.layer3(stem)