import math
import torch
import torch.nn as nn
import numpy as np
from skimage.measure import compare_psnr
class AverageMeter(object):
    """Track the latest value and the running weighted average of a metric.

    Args:
        name: label printed by ``__str__``.
        fmt: format spec (including the leading ``:``) applied to both the
            latest value and the average in ``__str__``, e.g. ``':.4f'``.
    """

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        # reset() initialises val / avg / sum / count.
        self.reset()

    def reset(self):
        """Clear all accumulated statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` samples.

        ``val`` contributes ``val * n`` to the running sum and ``n`` to the
        count. The average is only recomputed while the count is non-zero,
        so ``update(x, n=0)`` on an empty meter does not divide by zero.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        if self.count:
            self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)
def weights_init_kaiming(m):
""" init layers """
classname = m.__class__.__name__
if classname.find('Conv') != -1:
nn.init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
elif classname.find('Linear') != -1:
nn.init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
elif classname.find('BatchNorm') != -1:
m.weight.data.normal_(mean=0, std=math.sqrt(2. / 9. / 64.)).clamp_(-0.025, 0.025)
nn.init.constant(m.bias.data, 0.0)
def batch_PSNR(img, imclean, data_range):
    """Return the mean PSNR (dB) over a batch of images.

    Args:
        img: reconstructed batch; a torch tensor (any device) or array-like,
            with the batch on the first axis.
        imclean: reference batch with the same shape as ``img``.
        data_range: dynamic range of the pixel values, e.g. ``1.0`` for
            [0, 1] data or ``255`` for 8-bit data.

    Returns:
        The per-image PSNR ``10 * log10(data_range**2 / MSE)`` averaged over
        the batch. Inputs are cast to float32 and each MSE is accumulated in
        float64. An image identical to its reference contributes ``inf``.

    Raises:
        ValueError: if the two inputs differ in shape or the batch is empty.
    """
    def _to_numpy(x):
        # Drop autograd history and move to host memory before converting.
        if torch.is_tensor(x):
            x = x.detach().cpu().numpy()
        return np.asarray(x, dtype=np.float32)

    test = _to_numpy(img)
    clean = _to_numpy(imclean)
    if test.shape != clean.shape:
        raise ValueError('batch_PSNR: shape mismatch %s vs %s' % (test.shape, clean.shape))
    if test.ndim == 0 or test.shape[0] == 0:
        raise ValueError('batch_PSNR: empty batch')
    num = test.shape[0]
    diff = (clean - test).reshape(num, -1)
    mse = np.mean(np.square(diff), axis=1, dtype=np.float64)
    with np.errstate(divide='ignore'):
        psnr = 10 * np.log10((data_range ** 2) / mse)
    return psnr.mean()
def data_augmentation(image, mode):
    """Return one of eight rotation/flip variants of a 3-D array.

    The transform acts on axes 1 and 2 of ``image``; axis 0 is carried along
    unchanged (presumably a channel-first (C, H, W) image — confirm with
    callers). Internally the array is viewed with axis 0 moved last, so that
    ``np.rot90`` / ``np.flipud`` operate on the other two axes.

    ``mode`` in 0..7 decomposes as ``divmod(mode, 2)``: the quotient is the
    number of counter-clockwise quarter turns (``np.rot90`` default
    direction), and an odd mode additionally flips vertically after rotating.
    Any other ``mode`` returns the input layout unchanged.
    """
    moved = np.transpose(image, (1, 2, 0))
    if mode in range(8):
        quarter_turns, mirrored = divmod(mode, 2)
        if quarter_turns:
            moved = np.rot90(moved, k=quarter_turns)
        if mirrored:
            moved = np.flipud(moved)
    return np.transpose(moved, (2, 0, 1))
def changePose(out, mode):
    """Apply one of eight rotation/flip variants to a batch-of-one array.

    Axis 0 of ``out`` is squeezed away (``np.squeeze`` raises if its length
    is not 1), the transform is applied to the first two remaining axes, and
    axis 0 is re-inserted before returning.

    ``mode`` in 0..7 decomposes as ``divmod(mode, 2)``: the quotient is the
    number of quarter turns using ``axes=(1, 0)`` (i.e. the opposite
    direction to ``np.rot90``'s default), and an odd mode additionally flips
    vertically after rotating. Any other ``mode`` leaves the data unchanged.

    NOTE(review): because the flip is applied *after* the rotation, modes 3
    and 7 do not undo ``data_augmentation``'s modes 3 and 7 when both act on
    the same spatial axes (the composition is a 180-degree rotation); undoing
    them would need the flip first, then the rotation. Confirm which forward
    transform callers pair this with.
    """
    squeezed = np.squeeze(out, 0)
    if mode in range(8):
        quarter_turns, mirrored = divmod(mode, 2)
        if quarter_turns:
            squeezed = np.rot90(squeezed, k=quarter_turns, axes=(1, 0))
        if mirrored:
            squeezed = np.flipud(squeezed)
    return np.expand_dims(squeezed, axis=0)