'''
- resize_pos_embed: resize position embedding
- load_for_transfer_learning: load pretrained parameters to model in transfer learning
- get_mean_and_std: calculate the mean and std value of dataset.
'''
import os
import sys
import time
import torch
import math
import torch.nn as nn
import torch.nn.init as init
import logging
import os
from collections import OrderedDict
import torch.nn.functional as F
# Module-level logger, named after this module; used by the helpers below.
_logger = logging.getLogger(__name__)
def resize_pos_embed(posemb, posemb_new):
    '''Resize a position embedding that carries one leading extra token.

    The first token along dim 1 (presumably the class token - confirm with
    the model definition) is kept untouched.  The remaining tokens of batch
    entry 0 are laid out as a square grid, resampled with bicubic
    interpolation to the grid size implied by ``posemb_new`` and flattened
    back into a token sequence.

    Args:
        posemb: source embedding, indexed as ``(batch, 1 + tokens, dim)``.
        posemb_new: target embedding; only ``shape[1]`` is read.

    Returns:
        The leading token concatenated (along dim 1) with the resized grid
        tokens.
    '''
    leading_tok = posemb[:, :1]
    grid_tokens = posemb[0, 1:]

    # Side lengths of the old and new square grids; the target length does
    # not count the leading token.
    side_old = int(math.sqrt(len(grid_tokens)))
    side_new = int(math.sqrt(posemb_new.shape[1] - 1))
    _logger.info('Position embedding grid-size from %s to %s', side_old, side_new)

    # (tokens, dim) -> (1, dim, side, side), the layout F.interpolate expects.
    grid = grid_tokens.reshape(1, side_old, side_old, -1)
    grid = grid.permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=(side_new, side_new), mode='bicubic')
    # Back to a flat token sequence: (1, side * side, dim).
    grid = grid.permute(0, 2, 3, 1)
    grid = grid.reshape(1, side_new * side_new, -1)

    return torch.cat([leading_tok, grid], dim=1)
def resize_pos_embed_without_cls(posemb, posemb_new):
    '''Resize a position embedding that has no leading extra token.

    All tokens of batch entry 0 are treated as a square grid, resampled with
    bicubic interpolation to the grid size implied by ``posemb_new`` and
    flattened back into a token sequence.

    Args:
        posemb: source embedding, indexed as ``(batch, tokens, dim)``.
        posemb_new: target embedding; only ``shape[1]`` is read.

    Returns:
        Tensor of shape ``(1, side_new * side_new, dim)``.
    '''
    grid_tokens = posemb[0]

    side_old = int(math.sqrt(len(grid_tokens)))
    side_new = int(math.sqrt(posemb_new.shape[1]))
    _logger.info('Position embedding grid-size from %s to %s', side_old, side_new)

    # (tokens, dim) -> (1, dim, side, side), the layout F.interpolate expects.
    grid = grid_tokens.reshape(1, side_old, side_old, -1)
    grid = grid.permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=(side_new, side_new), mode='bicubic')
    # Back to a flat token sequence.
    grid = grid.permute(0, 2, 3, 1)
    return grid.reshape(1, side_new * side_new, -1)
def resize_pos_embed_4d(posemb, posemb_new):
    '''Resize a 4-D, channels-last position embedding.

    The embedding is moved to channels-first, resampled with bicubic
    interpolation and moved back to channels-last.

    NOTE(review): the target grid is taken as square - only
    ``posemb_new.shape[1]`` is read and used for both spatial sizes.

    Args:
        posemb: source embedding, permuted as ``(batch, h, w, dim)``.
        posemb_new: target embedding; only ``shape[1]`` is read.

    Returns:
        Tensor laid out as ``(batch, side_new, side_new, dim)``.
    '''
    side_old, side_new = posemb.shape[1], posemb_new.shape[1]
    _logger.info('Position embedding grid-size from %s to %s', side_old, side_new)

    channels_first = posemb.permute(0, 3, 1, 2)
    resized = F.interpolate(channels_first, size=(side_new, side_new), mode='bicubic')
    return resized.permute(0, 2, 3, 1)
def load_state_dict(checkpoint_path,model, use_ema=False, num_classes=1000):
    '''Load a checkpoint file and adapt its weights to ``model``.

    Steps performed:
      1. Load the file on CPU and pick the ``'state_dict'`` entry (or
         ``'state_dict_ema'`` when ``use_ema`` is set and that entry exists);
         if neither key is present the loaded object itself is used.
      2. Strip a leading ``'module.'`` prefix from parameter names.
      3. Drop the classifier (``head.*`` and ``aux_head.*``) when the number
         of classes stored in ``head.bias`` differs from ``num_classes``.
      4. Resize ``pos_embed`` when its shape differs from
         ``model.pos_embed``.  A 3-D embedding whose token count is a perfect
         square is treated as having no extra leading token.

    Args:
        checkpoint_path: path of the checkpoint file.
        model: model the weights are meant for; only ``model.pos_embed`` is
            read.
        use_ema: prefer the ``'state_dict_ema'`` entry when available.
        num_classes: number of classes of ``model``'s classifier.

    Returns:
        The adapted state dict.

    Raises:
        FileNotFoundError: ``checkpoint_path`` is empty or not a file.
        ValueError: ``pos_embed`` needs resizing but is neither 3-D nor 4-D.
    '''
    if not (checkpoint_path and os.path.isfile(checkpoint_path)):
        message = "No checkpoint found at '{}'".format(checkpoint_path)
        _logger.error(message)
        raise FileNotFoundError(message)

    # NOTE(security): torch.load unpickles the file - only load checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    state_dict_key = 'state_dict'
    if use_ema and isinstance(checkpoint, dict) and 'state_dict_ema' in checkpoint:
        state_dict_key = 'state_dict_ema'

    if state_dict_key in checkpoint:
        prefix = 'module.'
        state_dict = OrderedDict()
        for k, v in checkpoint[state_dict_key].items():
            # Only strip the exact 'module.' prefix, so keys that merely
            # begin with the word "module" are left intact.
            name = k[len(prefix):] if k.startswith(prefix) else k
            state_dict[name] = v
    else:
        state_dict = checkpoint
    _logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path))

    # Discard the classifier when it was trained for another number of
    # classes; checkpoints without a head are left as they are.
    head_bias = state_dict.get('head.bias')
    if head_bias is not None and num_classes != head_bias.shape[0]:
        for key in ('head.weight', 'head.bias', 'aux_head.weight', 'aux_head.bias'):
            state_dict.pop(key, None)

    # Interpolate the position embedding if the model uses another grid size.
    old_posemb = state_dict.get('pos_embed')
    model_posemb = getattr(model, 'pos_embed', None)
    if (old_posemb is not None and model_posemb is not None
            and model_posemb.shape != old_posemb.shape):
        ndim = len(old_posemb.shape)
        if ndim == 3:
            num_tokens = old_posemb.shape[1]
            if int(math.sqrt(num_tokens)) ** 2 == num_tokens:
                new_posemb = resize_pos_embed_without_cls(old_posemb, model_posemb)
            else:
                new_posemb = resize_pos_embed(old_posemb, model_posemb)
        elif ndim == 4:
            new_posemb = resize_pos_embed_4d(old_posemb, model_posemb)
        else:
            raise ValueError(
                "Cannot resize 'pos_embed' with {} dimensions "
                "(expected 3 or 4)".format(ndim))
        state_dict['pos_embed'] = new_posemb

    return state_dict
def load_pretrained_weights(model, checkpoint_path, use_ema=False, strict=True, num_classes=1000):
    '''Load the checkpoint at ``checkpoint_path`` into ``model`` in place.

    The weights are first prepared by ``load_state_dict`` and then handed to
    ``model.load_state_dict`` with the given ``strict`` flag.
    '''
    weights = load_state_dict(
        checkpoint_path, model, use_ema=use_ema, num_classes=num_classes)
    model.load_state_dict(weights, strict=strict)
def get_mean_and_std(dataset, num_workers=2):
    '''Compute the per-channel mean and std value of dataset.

    Samples are read one at a time; the statistics of each sample are
    computed per channel and then averaged over the dataset.  The returned
    std is therefore the mean of the per-sample standard deviations, not the
    standard deviation of all pixels pooled together.

    Args:
        dataset: map-style dataset yielding ``(input, target)`` pairs whose
            inputs are channel-first tensors, e.g. ``(C, H, W)``.
        num_workers: number of loader worker processes (0 loads in the
            calling process).

    Returns:
        ``(mean, std)``: two 1-D tensors with one entry per channel.

    Raises:
        ValueError: if the dataset is empty.
    '''
    num_samples = len(dataset)
    if num_samples == 0:
        raise ValueError('Cannot compute mean and std of an empty dataset.')

    # The order of the samples does not matter for an average, so do not
    # shuffle: this keeps the floating-point accumulation deterministic.
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=False, num_workers=num_workers)

    mean = None
    std = None
    print('==> Computing mean and std..')
    for inputs, _ in dataloader:
        num_channels = inputs.shape[1]
        if mean is None:
            # The channel count is taken from the data instead of assuming RGB.
            mean = torch.zeros(num_channels)
            std = torch.zeros(num_channels)
        # (1, C, ...) -> (C, values): one row of values per channel.
        per_channel = inputs.transpose(0, 1).reshape(num_channels, -1)
        mean += per_channel.mean(dim=1)
        std += per_channel.std(dim=1)
    mean.div_(num_samples)
    std.div_(num_samples)
    return mean, std
def init_params(net):
    '''Init layer parameters.

    Walks over every sub-module of ``net`` and re-initialises it in place:

    - ``nn.Conv2d``: Kaiming-normal weights (``mode='fan_out'``), zero bias.
    - ``nn.BatchNorm2d``: weight set to 1, bias set to 0.
    - ``nn.Linear``: normal weights with ``std=1e-3``, zero bias.

    Layers created without a bias (or batch-norm layers created with
    ``affine=False``) have ``None`` parameters, which are skipped.
    '''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode='fan_out')
            # Compare against None: the truth value of a multi-element
            # tensor is ambiguous and raises a RuntimeError.
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            if m.weight is not None:
                init.constant_(m.weight, 1)
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant_(m.bias, 0)