import torch
import numpy as np
from thop import profile
from thop import clever_format
if torch.__version__ >= "1.8":
import torch_npu
def clip_gradient(optimizer, grad_clip):
    """
    Clamp, in place, every existing gradient of the parameters managed by
    ``optimizer`` into the closed interval ``[-grad_clip, grad_clip]``.

    Parameters whose ``grad`` is None (e.g. unused in the last backward pass)
    are left untouched.

    :param optimizer: optimizer whose ``param_groups`` list the parameters.
    :param grad_clip: clipping bound; gradients are clamped to +/- this value.
    :return: None
    """
    lower, upper = -grad_clip, grad_clip
    all_params = (p for group in optimizer.param_groups for p in group['params'])
    for p in all_params:
        if p.grad is None:
            continue
        p.grad.data.clamp_(lower, upper)
def adjust_lr(optimizer, init_lr, epoch, decay_rate=0.1, decay_epoch=30):
    """
    Apply a step-decay learning-rate schedule to every parameter group.

    The learning rate is set to ``init_lr * decay_rate ** (epoch // decay_epoch)``.
    In other words, it is multiplied by ``decay_rate`` once every
    ``decay_epoch`` epochs. The result depends only on ``epoch``, so calling
    this every epoch (or several times per epoch) is safe.

    NOTE: every parameter group receives the same rate, so per-group base
    rates that differ from ``init_lr`` are overwritten.

    :param optimizer: optimizer whose ``param_groups`` are updated in place.
    :param init_lr: learning rate before any decay step.
    :param epoch: current epoch index.
    :param decay_rate: multiplicative factor applied at each decay step.
    :param decay_epoch: number of epochs between decay steps.
    :return: None
    """
    decay = decay_rate ** (epoch // decay_epoch)
    new_lr = init_lr * decay
    for param_group in optimizer.param_groups:
        # Derive the rate from init_lr instead of scaling the current value.
        # Scaling in place (``lr *= decay``) compounds the decay on every call,
        # e.g. each epoch in 30..59 would multiply the rate by another 0.1.
        param_group['lr'] = new_lr
class AvgMeter(object):
    """
    Track a running n-weighted average of logged values, plus a windowed mean.

    ``avg`` is the n-weighted mean of everything passed to ``update`` since the
    last ``reset``. ``show`` returns the plain (unweighted) mean of the most
    recent ``num`` recorded values.
    """

    def __init__(self, num=40):
        # Length of the trailing window averaged by show().
        self.num = num
        self.reset()

    def reset(self):
        """Clear the running statistics and the recorded history."""
        self.val = self.avg = self.sum = self.count = 0
        # Every value ever passed to update(); never trimmed, so it grows
        # by one entry per call until the next reset().
        self.losses = []

    def update(self, val, n=1):
        """
        Record ``val`` as the mean of ``n`` samples.

        :param val: value to record. ``show`` stacks these with
            ``torch.stack``, so they must be tensors if ``show`` is used.
        :param n: weight of ``val`` in the running average.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        self.losses.append(val)

    def show(self):
        """
        Return the mean of the last ``num`` recorded values as a tensor.

        ``torch.stack`` raises if nothing has been recorded yet.
        """
        window_start = max(len(self.losses) - self.num, 0)
        recent = self.losses[window_start:]
        return torch.mean(torch.stack(recent))
def CalParams(model, input_tensor):
    """
    Profile ``model`` on ``input_tensor`` with THOP and print its FLOPs and
    parameter count.

    The counts are formatted with ``clever_format`` using ``"%.3f"``.
    Requires ``profile`` and ``clever_format`` from ``thop``
    (https://github.com/Lyken17/pytorch-OpCounter).

    NOTE(review): THOP's ``profile`` is documented as returning MACs; the
    value printed under the "FLOPs" label may differ from true FLOPs by ~2x.
    Confirm before quoting it.

    :param model: module to profile.
    :param input_tensor: the single positional input passed to ``model``.
    :return: None; the statistics are printed to stdout.
    """
    raw_flops, raw_params = profile(model, inputs=(input_tensor,))
    pretty_flops, pretty_params = clever_format([raw_flops, raw_params], "%.3f")
    report = '[Statistics Information]\nFLOPs: {}\nParams: {}'.format(pretty_flops, pretty_params)
    print(report)
class AverageMeter(object):
    """
    Computes and stores the average and current value, skipping a warm-up.

    ``val`` always holds the most recent value. ``sum`` and ``avg`` ignore the
    first ``start_count_index * N`` samples. ``N`` is the ``n`` passed to the
    first ``update`` made while ``count`` is 0, i.e. the first update after a
    reset. Until that warm-up has passed, ``avg`` stays at 0.
    """

    def __init__(self, name, fmt=':f', start_count_index=2):
        self.name = name
        # Format spec appended to the val/avg fields in __str__, e.g. ':.3f'.
        self.fmt = fmt
        self.reset()
        # Number of leading updates (of size N) excluded from sum/avg.
        self.start_count_index = start_count_index

    def reset(self):
        """Zero the current value, running sum, average and sample count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """
        Record ``val`` as the value for ``n`` samples.

        :param val: latest value (stored in ``val`` unconditionally).
        :param n: number of samples ``val`` represents.
        """
        if self.count == 0:
            # First update since reset: remember its size for the warm-up length.
            self.N = n
        self.val = val
        self.count += n
        skipped = self.start_count_index * self.N
        if self.count <= skipped:
            # Still inside the warm-up window; leave sum/avg untouched.
            return
        self.sum += val * n
        self.avg = self.sum / (self.count - skipped)

    def __str__(self):
        spec = self.fmt
        template = '{name} {val' + spec + '} ({avg' + spec + '})'
        return template.format(**self.__dict__)
class Conv2dForAvgPool2d(torch.nn.Conv2d):
    """
    Depthwise ``Conv2d`` with frozen uniform weights that emulates average pooling.

    Each channel is convolved with its own ``kh x kw`` kernel filled with
    ``1 / (kh * kw)``, with no bias. The output is therefore the mean over each
    window, with zero padding counted in the divisor. This matches
    ``torch.nn.functional.avg_pool2d(..., count_include_pad=True)``.
    NOTE(review): presumably a stand-in for ``AvgPool2d`` on Ascend NPU;
    confirm the motivation.

    :param channels: number of input (and output) channels.
    :param kernel_size: pooling window, an int or a ``(kh, kw)`` pair.
    :param stride: convolution stride (int or pair).
    :param padding: zero padding (int or pair).
    :param use_npu: when True (the default, preserving the previous behaviour)
        the module is moved to the NPU with ``self.npu()``, which is provided
        by ``torch_npu``. Pass False to leave it on the CPU.
    """

    def __init__(self, channels, kernel_size, stride, padding, use_npu=True):
        super(Conv2dForAvgPool2d, self).__init__(channels, channels,
                                                 kernel_size, stride,
                                                 padding, groups=channels,
                                                 bias=False)
        # Conv2d normalizes kernel_size to a (kh, kw) tuple, so non-square
        # windows work too (``kernel_size * kernel_size`` failed for tuples).
        kh, kw = self.kernel_size
        with torch.no_grad():
            self.weight.copy_(torch.ones_like(self.weight) / (kh * kw))
        # Fixed averaging kernel: exclude it from gradient computation.
        self.requires_grad_(False)
        if use_npu:
            self.npu()

    def forward(self, x):
        """Average-pool ``x`` by applying the frozen depthwise convolution."""
        return super().forward(x)