import argparse
import os
import time
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
from utils import accuracy, ProgressMeter, AverageMeter
from repvgg import get_RepVGG_func_by_name, RepVGGBlock
from utils import load_checkpoint, get_ImageNet_train_dataset, get_default_train_trans
# Command-line interface: three required positional paths followed by
# optional settings for the architecture, the data loader and the number of
# batches used to collect statistics.
parser = argparse.ArgumentParser(
    description=(
        'Get the mean and std on every conv3x3 (before the bias-adding) on the train set. '
        'Then use such data to initialize BN layers and insert them after conv3x3.'
    ),
)

# Required positional arguments.
parser.add_argument('data', metavar='DIR', help='path to dataset')
parser.add_argument('weights', metavar='WEIGHTS', help='path to the weights file')
parser.add_argument('save', metavar='SAVE', help='path to save the model with BN')

# Optional arguments.
parser.add_argument('-a', '--arch', default='RepVGG-A0', metavar='ARCH')
parser.add_argument(
    '-j', '--workers',
    type=int,
    default=4,
    metavar='N',
    help='number of data loading workers (default: 4)',
)
parser.add_argument(
    '-b', '--batch-size',
    type=int,
    default=100,
    metavar='N',
    help='mini-batch size (default: 100) for test',
)
parser.add_argument(
    '-n', '--num-batches',
    type=int,
    default=500,
    metavar='N',
    help='number of batches (default: 500) to record the mean and std on the train set',
)
parser.add_argument(
    '-r', '--resolution',
    type=int,
    default=224,
    metavar='R',
    help='resolution (default: 224) for test',
)
def update_running_mean_var(x, running_mean, running_var, momentum=0.9, is_first_batch=False):
    """Fold the statistics of one batch into running mean/variance estimates.

    The mean and the biased (population) variance of ``x`` are reduced over
    dimensions 0, 2 and 3 with ``keepdim=True``, i.e. one value is kept per
    index of dimension 1.

    Args:
        x: Tensor with at least four dimensions.
        running_mean: Previous running mean; ignored when ``is_first_batch``.
        running_var: Previous running variance; ignored when ``is_first_batch``.
        momentum: Weight given to the *previous* running value; the current
            batch is weighted by ``1 - momentum``.
        is_first_batch: When true, the batch statistics are returned as-is
            instead of being blended with the previous values.

    Returns:
        Tuple ``(running_mean, running_var)`` holding the updated estimates.
    """
    reduce_dims = (0, 2, 3)
    batch_mean = x.mean(dim=reduce_dims, keepdim=True)
    centered = x - batch_mean
    batch_var = (centered ** 2).mean(dim=reduce_dims, keepdim=True)

    # Nothing to blend with yet: the first batch seeds the running values.
    if is_first_batch:
        return batch_mean, batch_var

    blended_mean = momentum * running_mean + (1.0 - momentum) * batch_mean
    blended_var = momentum * running_var + (1.0 - momentum) * batch_var
    return blended_mean, blended_var
class BNStatistics(nn.Module):
    """Pass-through layer that records running statistics of its input.

    The input is returned unchanged; as a side effect the ``running_mean``
    and ``running_var`` buffers (shape ``(1, num_features, 1, 1)``) are
    refreshed on every forward call via ``update_running_mean_var``.
    """

    def __init__(self, num_features):
        super().__init__()
        stat_shape = (1, num_features, 1, 1)
        self.register_buffer('running_mean', torch.zeros(stat_shape))
        self.register_buffer('running_var', torch.zeros(stat_shape))
        # Plain attribute (not a buffer): the first batch seeds the running
        # values instead of being blended with the zero-initialized buffers.
        self.is_first_batch = True

    def forward(self, x):
        """Update the running statistics from ``x`` and return ``x`` as-is."""
        target_device = x.device
        # Move the buffers next to the input if they live on another device.
        if self.running_mean.device != target_device:
            self.running_mean = self.running_mean.to(target_device)
            self.running_var = self.running_var.to(target_device)

        new_mean, new_var = update_running_mean_var(
            x,
            self.running_mean,
            self.running_var,
            momentum=0.9,
            is_first_batch=self.is_first_batch,
        )
        self.running_mean = new_mean
        self.running_var = new_var
        self.is_first_batch = False
        return x
class BiasAdd(nn.Module):
    """Add a learnable per-channel bias to a 4-D input.

    The bias has ``num_features`` elements and is broadcast over every
    dimension except dimension 1 of the input.
    """

    def __init__(self, num_features):
        super(BiasAdd, self).__init__()
        # Start from zeros rather than ``torch.Tensor(num_features)``, which
        # allocates uninitialized memory and could expose arbitrary values
        # (including NaN/inf) if the bias is used before being overwritten.
        self.bias = torch.nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        """Return ``x`` plus the bias reshaped to ``(1, num_features, 1, 1)``."""
        return x + self.bias.view(1, -1, 1, 1)
def switch_repvggblock_to_bnstat(model):
    """Replace every RepVGGBlock's ``rbr_reparam`` with a statistics pipeline.

    For each ``RepVGGBlock`` in ``model`` the existing ``rbr_reparam`` (read
    here through its conv-style attributes ``in_channels``, ``weight``,
    ``bias``, ...) is swapped in place for an ``nn.Sequential`` of:

    * ``conv``    - bias-free convolution sharing the original weight data,
    * ``bnstat``  - ``BNStatistics`` recording the conv output statistics,
    * ``biasadd`` - ``BiasAdd`` carrying the original bias data.

    Args:
        model: Module tree to modify in place; every ``RepVGGBlock`` must
            already have an ``rbr_reparam`` attribute.
    """
    for name, module in model.named_modules():
        if not isinstance(module, RepVGGBlock):
            continue

        print('switch to BN Statistics: ', name)
        assert hasattr(module, 'rbr_reparam')
        fused = module.rbr_reparam

        pipeline = nn.Sequential()
        pipeline.add_module(
            'conv',
            nn.Conv2d(
                fused.in_channels,
                fused.out_channels,
                kernel_size=fused.kernel_size,
                stride=fused.stride,
                padding=fused.padding,
                dilation=fused.dilation,
                groups=fused.groups,
                bias=False,
            ),
        )
        pipeline.add_module('bnstat', BNStatistics(fused.out_channels))
        pipeline.add_module('biasadd', BiasAdd(fused.out_channels))

        # Reuse the trained tensors: the weight goes to the bias-free conv,
        # the bias is applied after the statistics are recorded.
        pipeline.conv.weight.data = fused.weight.data
        pipeline.biasadd.bias.data = fused.bias.data

        del module.rbr_reparam
        module.rbr_reparam = pipeline
def switch_bnstat_to_convbn(model):
    """Turn every statistics pipeline into an equivalent conv + BN pair.

    Each ``RepVGGBlock`` is expected to hold an ``rbr_reparam`` with the
    sub-modules ``conv``, ``bnstat`` and ``biasadd`` (as produced by
    ``switch_repvggblock_to_bnstat``). It is replaced in place by an
    ``nn.Sequential`` with sub-modules ``conv`` and ``bn``.

    The BN is initialized from the recorded statistics so that, when it
    normalizes with its running statistics, it reproduces the original
    "conv output + bias":

        weight = sqrt(running_var + eps)
        bias   = original_bias + running_mean

    which makes ``(x - mean) / sqrt(var + eps) * weight + bias`` equal to
    ``x + original_bias``.

    Args:
        model: Module tree to modify in place.

    Raises:
        AssertionError: If a ``RepVGGBlock`` lacks ``rbr_reparam`` or its
            ``rbr_reparam`` lacks ``bnstat``.
    """
    for n, block in model.named_modules():
        if not isinstance(block, RepVGGBlock):
            continue

        assert hasattr(block, 'rbr_reparam')
        assert hasattr(block.rbr_reparam, 'bnstat')
        print('switch to ConvBN: ', n)

        stat = block.rbr_reparam
        src_conv = stat.conv

        conv = nn.Conv2d(src_conv.in_channels, src_conv.out_channels,
                         src_conv.kernel_size,
                         src_conv.stride, src_conv.padding,
                         src_conv.dilation,
                         src_conv.groups, bias=False)
        bn = nn.BatchNorm2d(src_conv.out_channels)

        # The recorded buffers have shape (1, C, 1, 1). flatten() always
        # yields shape (C,), whereas squeeze() would also drop the channel
        # dimension when C == 1 and leave a 0-dim tensor that does not match
        # the (C,) shape BatchNorm2d uses for its buffers and parameters.
        bn.running_mean = stat.bnstat.running_mean.flatten()
        bn.running_var = stat.bnstat.running_var.flatten()
        std = (bn.running_var + bn.eps).sqrt()

        conv.weight.data = src_conv.weight.data
        bn.weight.data = std
        bn.bias.data = stat.biasadd.bias.data + bn.running_mean

        convbn = nn.Sequential()
        convbn.add_module('conv', conv)
        convbn.add_module('bn', bn)

        block.__delattr__('rbr_reparam')
        block.rbr_reparam = convbn
def directly_insert_bn_without_init(model):
    """Swap each RepVGGBlock's ``rbr_reparam`` for a fresh conv-BN-ReLU stack.

    The new ``nn.Sequential`` has sub-modules ``conv`` (bias-free, same
    geometry as the old ``rbr_reparam``), ``bn`` and ``relu``. No weights or
    statistics are copied from the old module: conv and BN keep their default
    initialization. The block's ``nonlinearity`` is replaced by
    ``nn.Identity`` because the ReLU now sits inside ``rbr_reparam``.

    Args:
        model: Module tree to modify in place; every ``RepVGGBlock`` must
            already have an ``rbr_reparam`` attribute.
    """
    for name, module in model.named_modules():
        if not isinstance(module, RepVGGBlock):
            continue

        print('directly insert a BN with no initialization: ', name)
        assert hasattr(module, 'rbr_reparam')
        old = module.rbr_reparam

        stack = nn.Sequential()
        stack.add_module(
            'conv',
            nn.Conv2d(
                old.in_channels,
                old.out_channels,
                kernel_size=old.kernel_size,
                stride=old.stride,
                padding=old.padding,
                dilation=old.dilation,
                groups=old.groups,
                bias=False,
            ),
        )
        stack.add_module('bn', nn.BatchNorm2d(old.out_channels))
        stack.add_module('relu', nn.ReLU())

        # The activation moved into the stack above, so neutralize the
        # block-level one to avoid applying it a second time.
        module.nonlinearity = nn.Identity()

        del module.rbr_reparam
        module.rbr_reparam = stack
def insert_bn():
    """Collect conv statistics on the train set and save a model with BN.

    Steps:
      1. Parse the command line, build the deploy-mode network on the GPU
         and load the weights.
      2. Replace each block's ``rbr_reparam`` with a statistics-recording
         pipeline (``switch_repvggblock_to_bnstat``).
      3. Run up to ``args.num_batches`` training batches without gradients,
         logging loss and top-1/top-5 accuracy along the way.
      4. Convert the recorded statistics into conv + BN
         (``switch_bnstat_to_convbn``) and save the state dict to
         ``args.save``.
    """
    args = parser.parse_args()

    # --- model -----------------------------------------------------------
    build_model = get_RepVGG_func_by_name(args.arch)
    model = build_model(deploy=True).cuda()
    load_checkpoint(model, args.weights)
    switch_repvggblock_to_bnstat(model)
    cudnn.benchmark = True

    # --- data ------------------------------------------------------------
    transform = get_default_train_trans(args)
    print('data aug: ', transform)
    dataset = get_ImageNet_train_dataset(args, transform)
    # NOTE(review): shuffle=False means batches come in dataset order, and
    # the statistics are a moving average that favours recent batches. If the
    # dataset is ordered by class, the result may reflect only a few
    # classes -- confirm this is intended.
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
    )

    # --- bookkeeping -----------------------------------------------------
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    total_steps = min(len(loader), args.num_batches)
    progress = ProgressMeter(
        total_steps,
        [batch_time, losses, top1, top5],
        prefix='BN stat: ',
    )
    criterion = nn.CrossEntropyLoss().cuda()

    # --- forward passes (statistics are recorded as a side effect) --------
    with torch.no_grad():
        last_tick = time.time()
        for step, (images, target) in enumerate(loader):
            if step >= args.num_batches:
                break

            images = images.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

            output = model(images)
            loss = criterion(output, target)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))

            n_samples = images.size(0)
            losses.update(loss.item(), n_samples)
            top1.update(acc1[0], n_samples)
            top5.update(acc5[0], n_samples)

            now = time.time()
            batch_time.update(now - last_tick)
            last_tick = time.time()

            if step % 10 == 0:
                progress.display(step)

        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))

    # --- convert and save --------------------------------------------------
    switch_bnstat_to_convbn(model)
    torch.save(model.state_dict(), args.save)
# Script entry point: run the BN insertion only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    insert_bn()