# 05360171 created on March 18, 2022 (commit history)
# BSD 3-Clause License
#
# Copyright (c) 2017 xxxx
# All rights reserved.
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# ============================================================================
import argparse
import os
import time
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
from utils import accuracy, ProgressMeter, AverageMeter
from repvgg import get_RepVGG_func_by_name, RepVGGBlock
from utils import load_checkpoint, get_ImageNet_train_dataset, get_default_train_trans

#   Insert BN into an inference-time RepVGG (e.g., for quantization-aware training).
#   Record the mean and std of every conv3x3 output (before the bias is added) on the train set, then use these statistics to initialize BN layers inserted after each conv3x3.
#   May 7, 2021

def _build_parser():
    """Create the command-line parser for the BN-insertion script.

    Positional arguments: the dataset root, the deploy-mode weights to load,
    and the output path for the BN-augmented weights. The optional integer
    arguments control data loading and how many batches feed the statistics.
    """
    cli = argparse.ArgumentParser(
        description='Get the mean and std on every conv3x3 (before the bias-adding) on the train set. '
                    'Then use such data to initialize BN layers and insert them after conv3x3.')
    cli.add_argument('data', metavar='DIR', help='path to dataset')
    cli.add_argument('weights', metavar='WEIGHTS', help='path to the weights file')
    cli.add_argument('save', metavar='SAVE', help='path to save the model with BN')
    cli.add_argument('-a', '--arch', metavar='ARCH', default='RepVGG-A0')

    # (flags, default, metavar, help) for every integer-valued option, in registration order.
    int_options = (
        (('-j', '--workers'), 4, 'N', 'number of data loading workers (default: 4)'),
        (('-b', '--batch-size'), 100, 'N', 'mini-batch size (default: 100) for test'),
        (('-n', '--num-batches'), 500, 'N',
         'number of batches (default: 500) to record the mean and std on the train set'),
        (('-r', '--resolution'), 224, 'R', 'resolution (default: 224) for test'),
    )
    for flags, default, metavar, help_text in int_options:
        cli.add_argument(*flags, default=default, type=int, metavar=metavar, help=help_text)
    return cli


parser = _build_parser()


def update_running_mean_var(x, running_mean, running_var, momentum=0.9, is_first_batch=False):
    """Blend the per-channel batch statistics of ``x`` into running estimates.

    The batch mean and the biased (population) variance are reduced over
    dims 0, 2 and 3 with ``keepdim=True``. For an (N, C, H, W) input, both
    have shape (1, C, 1, 1).

    Args:
        x: input tensor; statistics are taken over every dim except dim 1.
        running_mean, running_var: previous estimates. Ignored when
            ``is_first_batch`` is True.
        momentum: weight kept on the previous estimate
            (``new = momentum * old + (1 - momentum) * batch``).
        is_first_batch: if True, the batch statistics are returned as-is.

    Returns:
        A ``(running_mean, running_var)`` tuple of new tensors.
    """
    reduce_dims = (0, 2, 3)
    batch_mean = x.mean(dim=reduce_dims, keepdim=True)
    batch_var = (x - batch_mean).pow(2).mean(dim=reduce_dims, keepdim=True)

    if is_first_batch:
        # Nothing to blend with yet: seed the estimates with this batch.
        return batch_mean, batch_var

    keep = momentum
    blend = 1.0 - momentum
    return keep * running_mean + blend * batch_mean, keep * running_var + blend * batch_var

#   Record the per-channel mean and variance like a BN layer, but apply no normalization
class BNStatistics(nn.Module):
    """Track per-channel mean and variance like a BN layer, without normalizing.

    The forward pass is an identity. Each call only updates the
    ``running_mean`` / ``running_var`` buffers (shape ``(1, num_features, 1, 1)``)
    with an exponential moving average (momentum 0.9) of the batch statistics.
    The first batch seen initializes the buffers directly.
    """

    def __init__(self, num_features):
        super(BNStatistics, self).__init__()
        shape = (1, num_features, 1, 1)
        self.register_buffer('running_mean', torch.zeros(shape))
        self.register_buffer('running_var', torch.zeros(shape))
        # Plain attribute (not a buffer, so not saved in state_dict): True until the first forward.
        self.is_first_batch = True

    def forward(self, x):
        # The statistics are bookkeeping only. Keep them out of the autograd
        # graph: otherwise, when called with gradients enabled, each EMA step
        # would reference the previous one and keep every batch's graph alive.
        with torch.no_grad():
            if self.running_mean.device != x.device:
                # Buffers may still be on the device the layer was created on.
                self.running_mean = self.running_mean.to(x.device)
                self.running_var = self.running_var.to(x.device)
            self.running_mean, self.running_var = update_running_mean_var(
                x, self.running_mean, self.running_var,
                momentum=0.9, is_first_batch=self.is_first_batch)
        self.is_first_batch = False
        return x

#   Holds a conv's bias as a separate layer, so a BNStatistics layer can be inserted between a bias-free Conv2d and its bias
class BiasAdd(nn.Module):
    """Add a learnable per-channel bias to a channels-at-dim-1 tensor.

    Used to split a conv's bias off into its own layer, so that a
    BNStatistics layer can observe the conv output before the bias is added.
    """

    def __init__(self, num_features):
        super(BiasAdd, self).__init__()
        # Zero-initialized. torch.Tensor(num_features) would return uninitialized
        # memory (possibly NaN/inf) until a caller overwrites the values.
        self.bias = torch.nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        # Reshape the (C,) bias to (1, C, 1, 1) so it broadcasts over batch and spatial dims.
        return x + self.bias.view(1, -1, 1, 1)

def switch_repvggblock_to_bnstat(model):
    """Split each RepVGGBlock's fused conv into conv -> BNStatistics -> BiasAdd, in place.

    For every ``RepVGGBlock`` in ``model``, the ``rbr_reparam`` conv is replaced
    by an ``nn.Sequential`` containing:
      * ``conv``: a bias-free Conv2d sharing the original conv's weight tensor,
      * ``bnstat``: a BNStatistics layer recording mean/var of the bias-free output,
      * ``biasadd``: a BiasAdd holding the original conv's bias.
    The result computes the same function as the fused conv. Note that
    ``padding_mode`` is not copied, so the default is assumed. The
    statistics are filled in by running data through the model afterwards.

    Raises:
        AssertionError: if a RepVGGBlock has no ``rbr_reparam`` (i.e. is not in
            deploy form).
    """
    # Snapshot the traversal first: the loop replaces children of the visited
    # blocks, and named_modules() is a lazy generator walking those same child dicts.
    for n, block in list(model.named_modules()):
        if not isinstance(block, RepVGGBlock):
            continue
        print('switch to BN Statistics: ', n)
        assert hasattr(block, 'rbr_reparam')
        fused = block.rbr_reparam
        stat = nn.Sequential()
        stat.add_module('conv', nn.Conv2d(fused.in_channels, fused.out_channels,
                                          fused.kernel_size,
                                          fused.stride, fused.padding,
                                          fused.dilation,
                                          fused.groups, bias=False))  # Note bias=False
        stat.add_module('bnstat', BNStatistics(fused.out_channels))
        stat.add_module('biasadd', BiasAdd(fused.out_channels))  # Bias is here
        # Share the fused conv's tensors (this also moves them to its device).
        # NOTE(review): assumes the fused conv was built with bias=True (fused.bias is not None).
        stat.conv.weight.data = fused.weight.data
        stat.biasadd.bias.data = fused.bias.data
        block.__delattr__('rbr_reparam')
        block.rbr_reparam = stat

def switch_bnstat_to_convbn(model):
    """Replace each conv -> BNStatistics -> BiasAdd stack with conv -> BatchNorm2d, in place.

    Expects every RepVGGBlock's ``rbr_reparam`` to have ``conv``, ``bnstat`` and
    ``biasadd`` children (as built by ``switch_repvggblock_to_bnstat``). The
    BNStatistics buffers should already hold statistics recorded on data.

    The new BN is initialized so the block's output is unchanged in eval mode:
    running stats = recorded (mean, var), gamma = sqrt(var + eps) and
    beta = bias + mean. Hence gamma * (y - mean) / sqrt(var + eps) + beta == y + bias.

    Raises:
        AssertionError: if a RepVGGBlock lacks ``rbr_reparam`` or ``rbr_reparam.bnstat``.
    """
    # Snapshot the traversal: rbr_reparam children are replaced inside the loop.
    for n, block in list(model.named_modules()):
        if not isinstance(block, RepVGGBlock):
            continue
        assert hasattr(block, 'rbr_reparam')
        assert hasattr(block.rbr_reparam, 'bnstat')
        print('switch to ConvBN: ', n)
        old_conv = block.rbr_reparam.conv
        conv = nn.Conv2d(old_conv.in_channels, old_conv.out_channels,
                         old_conv.kernel_size,
                         old_conv.stride, old_conv.padding,
                         old_conv.dilation,
                         old_conv.groups, bias=False)
        bn = nn.BatchNorm2d(old_conv.out_channels)
        # Initialize the BN's mean and var with the recorded statistics, flattening
        # (1, C, 1, 1) to (C,). reshape(-1) rather than squeeze(): squeeze() would
        # collapse a single-channel (C == 1) buffer to a 0-d tensor that no longer
        # matches BatchNorm2d's (C,) parameter/buffer shapes.
        bn.running_mean = block.rbr_reparam.bnstat.running_mean.reshape(-1)
        bn.running_var = block.rbr_reparam.bnstat.running_var.reshape(-1)
        std = (bn.running_var + bn.eps).sqrt()
        conv.weight.data = old_conv.weight.data
        bn.weight.data = std
        bn.bias.data = block.rbr_reparam.biasadd.bias.data + bn.running_mean  # Initialize gamma = std and beta = bias + mean

        convbn = nn.Sequential()
        convbn.add_module('conv', conv)
        convbn.add_module('bn', bn)
        block.__delattr__('rbr_reparam')
        block.rbr_reparam = convbn


#   Insert a BN after conv3x3 (rbr_reparam). Without a reasonable initialization of the BN, the model may break down.
#   So you must afterwards load the weights obtained from the BN statistics (see the function "insert_bn" in this file).
def directly_insert_bn_without_init(model):
    """Give every RepVGGBlock a conv -> BN -> ReLU ``rbr_reparam`` with fresh parameters, in place.

    This only builds the module structure. The conv and BN keep their default
    initialization, so real values must be loaded afterwards, e.g. from the
    state_dict saved by ``insert_bn``. Each block's ``nonlinearity`` is
    replaced by ``nn.Identity`` because the ReLU now lives inside ``rbr_reparam``.

    Raises:
        AssertionError: if a RepVGGBlock has no ``rbr_reparam``.
    """
    # Snapshot the traversal: rbr_reparam and nonlinearity are replaced inside the loop.
    for n, block in list(model.named_modules()):
        if not isinstance(block, RepVGGBlock):
            continue
        print('directly insert a BN with no initialization: ', n)
        assert hasattr(block, 'rbr_reparam')
        fused = block.rbr_reparam
        convbn = nn.Sequential()
        convbn.add_module('conv', nn.Conv2d(fused.in_channels, fused.out_channels,
                                            fused.kernel_size,
                                            fused.stride, fused.padding,
                                            fused.dilation,
                                            fused.groups, bias=False))  # Note bias=False
        convbn.add_module('bn', nn.BatchNorm2d(fused.out_channels))
        # TODO: the ReLU is moved from "block.nonlinearity" into "rbr_reparam" (nn.Sequential).
        # This makes it more convenient to fuse operators with off-the-shelf APIs
        # (see RepVGGWholeQuant.fuse_model).
        # NOTE(review): this is only equivalent if nothing (e.g. an SE module) sits between
        # rbr_reparam and nonlinearity in RepVGGBlock.forward; confirm.
        convbn.add_module('relu', nn.ReLU())
        block.nonlinearity = nn.Identity()
        block.__delattr__('rbr_reparam')
        block.rbr_reparam = convbn


def insert_bn():
    """Script entry: record conv statistics on the train set and save a conv+BN model.

    Steps:
      1. Build the deploy-mode RepVGG named by ``--arch`` and load ``weights``.
      2. Split every fused conv into conv -> BNStatistics -> BiasAdd.
      3. Run up to ``--num-batches`` train batches (no grad) so the BNStatistics
         layers record per-channel mean/var of each bias-free conv output.
         Loss and top-1/top-5 accuracy on those batches are printed, presumably
         as a sanity check that the converted model is intact.
      4. Replace the statistics layers with initialized BatchNorm2d layers and
         save the resulting state_dict to ``save``.
    """
    args = parser.parse_args()

    repvgg_build_func = get_RepVGG_func_by_name(args.arch)
    model = repvgg_build_func(deploy=True).cuda()
    load_checkpoint(model, args.weights)

    switch_repvggblock_to_bnstat(model)

    cudnn.benchmark = True

    trans = get_default_train_trans(args)
    print('data aug: ', trans)
    train_dataset = get_ImageNet_train_dataset(args, trans)

    # Shuffle so the sampled batches span many classes. A class-ordered train set
    # (as torchvision's ImageFolder is; presumably also get_ImageNet_train_dataset,
    # confirm) would otherwise yield batches from only the first few classes. Because
    # of the 0.9 EMA momentum, the last batches (often a single class) would dominate
    # the recorded statistics.
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    num_batches = min(len(train_loader), args.num_batches)
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        num_batches,
        [batch_time, losses, top1, top5],
        prefix='BN stat: ')

    criterion = nn.CrossEntropyLoss().cuda()

    # Forward passes only feed the BNStatistics layers, so no gradients are needed.
    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(train_loader):
            if i >= args.num_batches:
                break
            images = images.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

            # compute output (also updates every BNStatistics layer)
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % 10 == 0:
                progress.display(i)

        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))

    # Turn the recorded statistics into initialized BN layers after each conv.
    switch_bnstat_to_convbn(model)

    torch.save(model.state_dict(), args.save)




if __name__ == '__main__':
    # Only run the statistics collection / BN insertion when executed as a script.
    insert_bn()