import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
# Names exported by `from <module> import *`.
# NOTE(review): no top-level name `resnet` is defined in the visible part of
# this file (only `resnet34` is), so a star-import would fail on this entry --
# confirm whether this should list 'resnet34' instead.
__all__ = ['resnet']
# URLs of pretrained ResNet checkpoints, keyed by architecture name.
# NOTE(review): not referenced by the visible code; `resnet34` reads its
# weights from a local file path rather than from this table.
model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}
class Layers_NCHW:
    """Provider of the layer types used to assemble the network.

    Exposes ``Conv2d`` and ``MaxPool`` (plain ``torch.nn`` layers) plus a
    batch-norm variant that can add a second tensor to its output and apply
    ReLU, both in place.

    Args:
        bn_group: Stored on the instance and forwarded to the batch-norm
            constructor, which currently ignores it. Defaults to 1 so the
            provider can be built without arguments.
        **kwargs: Accepted and ignored, so callers can forward one shared
            keyword dict to both this class and the model constructor.
    """

    Conv2d = nn.Conv2d
    MaxPool = nn.MaxPool2d
    # Placeholder; every instance installs its own class in __init__.
    BnAddRelu = None

    def __init__(self, bn_group=1, **kwargs):
        super().__init__()
        self.nhwc = False
        self.bn_group = bn_group

        class BnAddRelu_(nn.BatchNorm2d):
            """BatchNorm2d followed by an optional add and an optional ReLU."""

            def __init__(self, planes, fuse_relu=False, bn_group=1):
                # bn_group is accepted for signature compatibility only.
                super().__init__(planes)
                self.fuse_relu_flag = fuse_relu

            def forward(self, x, z=None):
                """Return bn(x), plus ``z`` if given, with ReLU if enabled."""
                out = super().forward(x)
                if z is not None:
                    # In-place add on the freshly produced BN output.
                    out = out.add_(z)
                if self.fuse_relu_flag:
                    out = out.relu_()
                return out

        self.BnAddRelu = BnAddRelu_

    def build_bn(self, planes, fuse_relu=False):
        """Create a batch-norm layer over ``planes`` channels.

        Args:
            planes: Number of channels passed to the batch-norm constructor.
            fuse_relu: When True the layer applies ReLU to its output.

        Returns:
            An instance of this provider's ``BnAddRelu`` class.
        """
        return self.BnAddRelu(planes, fuse_relu, self.bn_group)
def conv1x1(layer_types, in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 convolution from ``layer_types.Conv2d``."""
    conv_cls = layer_types.Conv2d
    return conv_cls(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
def conv3x3(layer_types, in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 convolution (padding 1) from ``layer_types.Conv2d``."""
    conv_cls = layer_types.Conv2d
    return conv_cls(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
class BasicBlock(nn.Module):
    """Residual block of two 3x3 convolutions, each followed by batch norm.

    The second batch-norm layer receives the shortcut tensor as its second
    argument. Both batch-norm layers are built with ``fuse_relu=True``.
    """

    expansion = 1

    def __init__(self, layerImpls, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # Only the first convolution carries the stride.
        self.conv1 = conv3x3(layerImpls, inplanes, planes, stride=stride)
        self.bn1 = layerImpls.build_bn(planes, fuse_relu=True)
        self.conv2 = conv3x3(layerImpls, planes, planes)
        self.bn2 = layerImpls.build_bn(planes, fuse_relu=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run the block on ``x`` and return the result of the second batch norm."""
        # Shortcut branch: identity unless a downsample module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        hidden = self.bn1(self.conv1(x))
        hidden = self.conv2(hidden)
        return self.bn2(hidden, shortcut)
class ResNet(nn.Module):
    """ResNet trunk: a stem (conv1, bn1, maxpool) followed by three stages.

    Only ``layer1``..``layer3`` are created; there is no ``layer4`` and no
    classifier head, and ``forward`` returns the output of ``layer3``.

    Args:
        layerImpls: Layer provider exposing ``Conv2d``, ``MaxPool`` and
            ``build_bn(planes, fuse_relu=...)``.
        block: Residual block class with an ``expansion`` attribute, called
            as ``block(layerImpls, inplanes, planes[, stride, downsample])``.
        layers: Number of blocks per stage; only the first three entries
            are read.
        num_classes: Accepted for interface compatibility; not used here.
        pad_input: When True the first convolution takes 4 input channels
            instead of 3.
        ssd_mods: Accepted for interface compatibility; not used here.
        use_nhwc: Accepted for interface compatibility; not used here.
        bn_group: Accepted for interface compatibility; not used here.
    """

    def __init__(self, layerImpls, block, layers, num_classes=1000,
                 pad_input=False, ssd_mods=False, use_nhwc=False,
                 bn_group=1):
        super().__init__()
        # Running channel count consumed and updated by _make_layer.
        self.inplanes = 64
        input_channels = 4 if pad_input else 3

        self.conv1 = layerImpls.Conv2d(input_channels, 64, kernel_size=7,
                                       stride=2, padding=3, bias=False)
        self.bn1 = layerImpls.build_bn(64, fuse_relu=True)
        self.maxpool = layerImpls.MaxPool(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(layerImpls, block, 64, layers[0])
        self.layer2 = self._make_layer(layerImpls, block, 128, layers[1], stride=2)
        # This stage keeps stride 1.
        self.layer3 = self._make_layer(layerImpls, block, 256, layers[2], stride=1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, layerImpls, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks as an ``nn.Sequential``.

        The first block receives ``stride`` and, when the stride is not 1 or
        the channel count changes, a 1x1-conv + batch-norm downsample module
        for its shortcut. ``self.inplanes`` is updated to the stage's output
        channel count.
        """
        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                layerImpls.Conv2d(self.inplanes, out_planes,
                                  kernel_size=1, stride=stride, bias=False),
                layerImpls.build_bn(out_planes, fuse_relu=False),
            )

        stage = [block(layerImpls, self.inplanes, planes, stride, downsample)]
        self.inplanes = out_planes
        for _ in range(1, blocks):
            stage.append(block(layerImpls, self.inplanes, planes))
        return nn.Sequential(*stage)

    def forward(self, x):
        """Run the stem and the three stages; return the ``layer3`` output.

        The constructor defines neither ``layer4`` nor ``classifier``, so the
        forward pass ends at ``layer3`` rather than referencing them.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x
def _transpose_state(state, pad_input=False):
for k in state.keys():
if len(state[k].shape) == 4:
if pad_input and "conv1.weight" in k and not 'layer' in k:
s = state[k].shape
state[k] = torch.cat([state[k], torch.zeros([s[0], 1, s[2], s[3]])], dim=1)
state[k] = state[k].permute(0, 2, 3, 1).contiguous()
return state
def resnet34(pretrained=False, nhwc=False, ssd_mods=False, **kwargs):
    """Build a ResNet-34 style trunk and return it as an ``nn.Sequential``.

    The returned module chains conv1, bn1, maxpool, layer1, layer2 and
    layer3 of the constructed ``ResNet``.

    Args:
        pretrained (bool): If True, load weights from the local file
            ``./resnet34-333f7ec4.pth`` (relative to the working directory),
            skipping every key that starts with ``layer4`` or ``fc``.
        nhwc (bool): Forwarded to ``ResNet`` as ``use_nhwc``; when loading
            pretrained weights it also routes them through
            ``_transpose_state``.
        ssd_mods (bool): Forwarded to ``ResNet``.
        **kwargs: Forwarded to both ``Layers_NCHW`` and ``ResNet``
            (e.g. ``bn_group``, ``pad_input``). ``bn_group`` defaults to 1
            when omitted.
    """
    # Layers_NCHW and ResNet share this dict; supply the default here so a
    # bare resnet34() call works regardless of Layers_NCHW's own signature.
    kwargs.setdefault('bn_group', 1)

    layerImpls = Layers_NCHW(**kwargs)
    block = BasicBlock
    layer_list = [3, 4, 6, 3]
    model = ResNet(layerImpls, block, layer_list, ssd_mods=ssd_mods, use_nhwc=nhwc, **kwargs)
    if pretrained:
        # torch.load unpickles the file: only point this at a trusted checkpoint.
        orig_state_dict = torch.load('./resnet34-333f7ec4.pth')
        # The model has no layer4 / fc modules, so drop their weights.
        state_dict = {
            k: v for k, v in orig_state_dict.items()
            if not k.startswith(('layer4', 'fc'))
        }
        pad_input = kwargs.get('pad_input', False)
        # NOTE(review): the model above is always built from Layers_NCHW, so
        # weights transposed for nhwc presumably no longer match its conv
        # shapes -- confirm nhwc=True with pretrained=True is supported.
        # NOTE(review): the extra input channel for pad_input is only added
        # inside _transpose_state, i.e. only when nhwc is True -- verify
        # pad_input=True with nhwc=False and pretrained=True.
        if nhwc:
            state_dict = _transpose_state(state_dict, pad_input)
        model.load_state_dict(state_dict)
    return nn.Sequential(model.conv1, model.bn1, model.maxpool, model.layer1, model.layer2, model.layer3)