"""
BSD 3-Clause License
Copyright (c) Soumith Chintala 2016,
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
Copyright 2020 Huawei Technologies Co., Ltd
Licensed under the BSD 3-Clause License (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://spdx.org/licenses/BSD-3-Clause.html
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import torch
import torch.nn as nn
from timm.models.helpers import load_pretrained
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_
from timm.models.resnet import resnet26d, resnet50d, resnet101d
import numpy as np
from .layers import *
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic',
'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225),
'classifier': 'head',
**kwargs
}
# Default configs per variant, looked up by the factory functions below.
# The Medium and Large variants use crop_pct=1.0 instead of _cfg's 0.9 default.
default_cfgs = dict(
    LV_ViT_Tiny=_cfg(),
    LV_ViT=_cfg(),
    LV_ViT_Medium=_cfg(crop_pct=1.0),
    LV_ViT_Large=_cfg(crop_pct=1.0),
)
def get_block(block_type, **kargs):
    """Instantiate one transformer sub-block from its short type name.

    Args:
        block_type: 'mha' -> MHABlock, 'ffn' -> FFNBlock, 'tr' -> Block.
        **kargs: forwarded unchanged to the selected block's constructor.

    Returns:
        The constructed block module.

    Raises:
        ValueError: if ``block_type`` is not one of the names above.
    """
    if block_type == 'mha':
        return MHABlock(**kargs)
    if block_type == 'ffn':
        return FFNBlock(**kargs)
    if block_type == 'tr':
        return Block(**kargs)
    # Fail loudly here instead of returning None: a None entry would otherwise
    # end up in the model's block list and only break when forward calls it.
    raise ValueError(
        f"unknown block_type {block_type!r}; expected 'mha', 'ffn' or 'tr'")
def rand_bbox(size, lam):
    """Sample a random box over the last two dimensions of ``size``.

    The box side lengths are ``sqrt(1 - lam)`` times the corresponding
    dimension, centred at a uniformly random point and clipped to the bounds.

    NOTE: the local names follow the original CutMix code. ``W``/``x`` index
    ``size[2]`` and ``H``/``y`` index ``size[3]``. LV_ViT.forward treats dim 2
    as patch_h, so callers must slice dim 2 with (bbx1, bbx2) and dim 3 with
    (bby1, bby2), as LV_ViT.forward does.

    Args:
        size: a shape with at least 4 entries, e.g. ``tensor.size()``.
        lam: mixing ratio in [0, 1]; 1.0 yields an empty box.

    Returns:
        (bbx1, bby1, bbx2, bby2) clipped box corners.
    """
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    # np.int was only an alias of the builtin int and was removed in NumPy 1.24.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)
    cx = np.random.randint(W)
    cy = np.random.randint(H)
    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2
def get_dpr(drop_path_rate, depth, drop_path_decay='linear'):
    """Build the per-block drop-path (stochastic depth) rates.

    Args:
        drop_path_rate: for 'linear', the rate of the last block; for 'fix',
            the rate of every block; for any other ``drop_path_decay``, a
            sequence holding one explicit rate per block.
        depth: number of blocks.
        drop_path_decay: 'linear' ramps evenly from 0 to ``drop_path_rate``;
            'fix' repeats ``drop_path_rate``; anything else means
            ``drop_path_rate`` is already the per-block sequence.

    Returns:
        A list of ``depth`` rates.

    Raises:
        ValueError: if an explicit per-block sequence does not have ``depth``
            entries. This is an explicit raise rather than an assert, so the
            check survives ``python -O``.
    """
    if drop_path_decay == 'linear':
        return torch.linspace(0, drop_path_rate, depth).tolist()
    if drop_path_decay == 'fix':
        return [drop_path_rate] * depth
    if len(drop_path_rate) != depth:
        raise ValueError(
            f"expected {depth} per-block drop path rates, got {len(drop_path_rate)}")
    # copy so later mutation of the caller's sequence cannot leak into the model
    return list(drop_path_rate)
class LV_ViT(nn.Module):
    """ Vision Transformer with tricks
    Arguments:
        img_size, patch_size, in_chans, embed_dim: forwarded to the patch-embedding module
        num_classes: output size of the classification head(s); 0 replaces them with nn.Identity
        depth: number of transformer blocks (ignored when ``order`` is given)
        num_heads, head_dim, mlp_ratio, qkv_bias, qk_scale, attn_drop_rate, norm_layer:
            forwarded to every block
        drop_rate: dropout applied after adding the position embedding; also passed to every block as ``drop``
        drop_path_rate, drop_path_decay: per-block drop path schedule, built by ``get_dpr``
        hybrid_backbone: if given, a HybridEmbed over this backbone is used as the patch embedding
        p_emb: patch-embedding variant: '4_2' -> PatchEmbed4_2 (default), '4_2_128' -> PatchEmbed4_2_128,
            anything else -> PatchEmbedNaive
        skip_lam: residual scalar for skip connection, forwarded to every block (default: 1.0)
        order: list of block type names ('mha', 'ffn', 'tr'), one per block (default: None, will override depth if given)
        mix_token: during training, replace a random box of the patch-embedding map with the same box
            from the batch-reversed samples; requires return_dense (default: False)
        return_dense: add an aux_head applied to every patch token (default: False)
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., drop_path_decay='linear', hybrid_backbone=None, norm_layer=nn.LayerNorm, p_emb='4_2', head_dim=None,
                 skip_lam=1.0, order=None, mix_token=False, return_dense=False):
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim
        self.output_dim = embed_dim if num_classes == 0 else num_classes
        if hybrid_backbone is not None:
            self.patch_embed = HybridEmbed(
                hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
        else:
            if p_emb == '4_2':
                patch_embed_fn = PatchEmbed4_2
            elif p_emb == '4_2_128':
                patch_embed_fn = PatchEmbed4_2_128
            else:
                patch_embed_fn = PatchEmbedNaive
            self.patch_embed = patch_embed_fn(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
        # Learnable class token plus one position-embedding slot per patch and one for the class token.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)
        if order is None:
            dpr = get_dpr(drop_path_rate, depth, drop_path_decay)
            self.blocks = nn.ModuleList([
                Block(
                    dim=embed_dim, num_heads=num_heads, head_dim=head_dim, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                    drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, skip_lam=skip_lam)
                for i in range(depth)])
        else:
            dpr = get_dpr(drop_path_rate, len(order), drop_path_decay)
            self.blocks = nn.ModuleList([
                get_block(block_type,
                          dim=embed_dim, num_heads=num_heads, head_dim=head_dim, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                          drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, skip_lam=skip_lam)
                for i, block_type in enumerate(order)])
        self.norm = norm_layer(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        self.return_dense = return_dense
        self.mix_token = mix_token
        if return_dense:
            self.aux_head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        if mix_token:
            self.beta = 1.0
            # Explicit raise (not assert) so the check survives `python -O`:
            # mixing tokens without the dense outputs would silently train on mixed inputs.
            if not return_dense:
                raise ValueError("always return all features when mixtoken is enabled")
        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise Linear / GroupLinear / LayerNorm submodules (applied via self.apply)."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, GroupLinear):
            trunc_normal_(m.group_weight, std=.02)
            if m.group_bias is not None:
                nn.init.constant_(m.group_bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Names of parameters to exclude from weight decay."""
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        """Replace the classification head(s) for ``num_classes`` outputs.

        ``global_pool`` is accepted for interface compatibility and ignored.
        When the model was built with ``return_dense``, ``aux_head`` is rebuilt
        too; otherwise its output size would no longer match ``head`` and the
        eval-mode sum in ``forward`` would fail.
        """
        self.num_classes = num_classes
        self.output_dim = self.embed_dim if num_classes == 0 else num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        if self.return_dense:
            self.aux_head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_embeddings(self, x):
        """Apply the patch embedding to the input images."""
        x = self.patch_embed(x)
        return x

    def forward_tokens(self, x):
        """Prepend the class token, add position embedding, run all blocks and the final norm.

        ``x`` is a token sequence indexed as (B, num_tokens, embed_dim).
        """
        B = x.shape[0]
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x

    def forward_features(self, x):
        """Return normalised features for the class token followed by all patch tokens."""
        x = self.forward_embeddings(x)
        # (B, C, h, w) embedding map -> (B, h*w, C) token sequence
        x = x.flatten(2).transpose(1, 2)
        x = self.forward_tokens(x)
        return x

    def forward(self, x):
        """Classify a batch of images.

        Returns:
            * without ``return_dense``: ``head`` logits of the class token;
            * with ``return_dense`` in eval mode: class-token logits plus 0.5 times the
              per-class max of ``aux_head`` over all patch tokens;
            * with ``return_dense`` in training mode: ``(x_cls, x_aux, (bbx1, bby1, bbx2, bby2))``
              where the box is all zeros when mix_token is off.
        """
        x = self.forward_embeddings(x)
        if self.mix_token and self.training:
            lam = np.random.beta(self.beta, self.beta)
            patch_h, patch_w = x.shape[2], x.shape[3]
            bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam)
            temp_x = x.clone()
            # Paste the box from the batch-reversed sample into every sample.
            temp_x[:, :, bbx1:bbx2, bby1:bby2] = x.flip(0)[:, :, bbx1:bbx2, bby1:bby2]
            x = temp_x
        else:
            bbx1, bby1, bbx2, bby2 = 0, 0, 0, 0
        x = x.flatten(2).transpose(1, 2)
        x = self.forward_tokens(x)
        x_cls = self.head(x[:, 0])
        if self.return_dense:
            x_aux = self.aux_head(x[:, 1:])
            if not self.training:
                return x_cls + 0.5 * x_aux.max(1)[0]
            if self.mix_token:
                # Swap the same box back on the per-token outputs, so that each token's
                # prediction sits in the sample its (pre-mix) content came from.
                x_aux = x_aux.reshape(x_aux.shape[0], patch_h, patch_w, x_aux.shape[-1])
                temp_x = x_aux.clone()
                temp_x[:, bbx1:bbx2, bby1:bby2, :] = x_aux.flip(0)[:, bbx1:bbx2, bby1:bby2, :]
                x_aux = temp_x
                x_aux = x_aux.reshape(x_aux.shape[0], patch_h * patch_w, x_aux.shape[-1])
            return x_cls, x_aux, (bbx1, bby1, bbx2, bby2)
        return x_cls
@register_model
def vit(pretrained=False, **kwargs):
    """LV-ViT with the naive patch embedding (p_emb=1) and default skip_lam.

    ``pretrained`` is accepted but never read here: no weights are loaded.
    """
    arch = dict(patch_size=16, embed_dim=384, depth=16, num_heads=6, mlp_ratio=3., p_emb=1)
    net = LV_ViT(**arch, **kwargs)
    net.default_cfg = default_cfgs['LV_ViT']
    return net
@register_model
def lvvit(pretrained=False, **kwargs):
    """LV-ViT (384-dim, 16 blocks) with the '4_2' patch embedding and skip_lam=2.

    Dense outputs and mix-token are left at their defaults (off).
    ``pretrained`` is accepted but never read here: no weights are loaded.
    """
    arch = dict(patch_size=16, embed_dim=384, depth=16, num_heads=6, mlp_ratio=3.,
                p_emb='4_2', skip_lam=2.)
    net = LV_ViT(**arch, **kwargs)
    net.default_cfg = default_cfgs['LV_ViT']
    return net
@register_model
def lvvit_s(pretrained=False, **kwargs):
    """LV-ViT-S: 384-dim, 16 blocks, 6 heads, dense aux head and mix-token enabled.

    ``pretrained`` is accepted but never read here: no weights are loaded.
    """
    arch = dict(
        patch_size=16, embed_dim=384, depth=16, num_heads=6, mlp_ratio=3.,
        p_emb='4_2', skip_lam=2.,
        return_dense=True, mix_token=True,
    )
    net = LV_ViT(**arch, **kwargs)
    net.default_cfg = default_cfgs['LV_ViT']
    return net
@register_model
def lvvit_m(pretrained=False, **kwargs):
    """LV-ViT-M: 512-dim, 20 blocks, 8 heads, dense aux head and mix-token enabled.

    ``pretrained`` is accepted but never read here: no weights are loaded.
    """
    arch = dict(
        patch_size=16, embed_dim=512, depth=20, num_heads=8, mlp_ratio=3.,
        p_emb='4_2', skip_lam=2.,
        return_dense=True, mix_token=True,
    )
    net = LV_ViT(**arch, **kwargs)
    net.default_cfg = default_cfgs['LV_ViT_Medium']
    return net
@register_model
def lvvit_l(pretrained=False, **kwargs):
    """LV-ViT-L: 768-dim, 24 'tr' blocks given via ``order``, '4_2_128' patch embedding.

    Uses skip_lam=3 and enables the dense aux head and mix-token.
    ``pretrained`` is accepted but never read here: no weights are loaded.
    """
    n_blocks = 24
    arch = dict(
        patch_size=16, embed_dim=768, depth=n_blocks, num_heads=12, mlp_ratio=3.,
        p_emb='4_2_128', skip_lam=3.,
        return_dense=True, mix_token=True,
        order=['tr'] * n_blocks,
    )
    net = LV_ViT(**arch, **kwargs)
    net.default_cfg = default_cfgs['LV_ViT_Large']
    return net