import cv2
import os
import argparse
import glob
import numpy as np
import torch
import torch.nn as nn
from urllib.request import urlretrieve
from skimage.measure import compare_psnr
def batch_PSNR(img, imclean, data_range):
    """Return the mean PSNR between corresponding items of two batches.

    Both tensors are moved to the CPU and cast to float32 NumPy arrays. Then
    ``compare_psnr`` is evaluated item by item along the first axis, with
    ``imclean`` as the reference image. Items are indexed as ``[i, :, :, :]``,
    so 4-D input is required (presumably N x C x H x W -- TODO confirm).

    Args:
        img: batch tensor being evaluated.
        imclean: reference (clean) batch tensor with the same shape.
        data_range: value range of the images, forwarded to ``compare_psnr``.

    Returns:
        The arithmetic mean of the per-item PSNR values.

    Raises:
        ZeroDivisionError: if the batch is empty.
    """
    test_np = img.data.cpu().numpy().astype(np.float32)
    ref_np = imclean.data.cpu().numpy().astype(np.float32)
    batch_size = test_np.shape[0]
    scores = [
        compare_psnr(ref_np[idx, :, :, :], test_np[idx, :, :, :], data_range=data_range)
        for idx in range(batch_size)
    ]
    return sum(scores) / batch_size
class DnCNN(nn.Module):
    """DnCNN network: a plain stack of bias-free 3x3 convolutions.

    Layout of ``self.dncnn`` (padding 1 preserves the spatial size):

    * Conv(channels -> 64) + ReLU
    * (num_of_layers - 2) x [Conv(64 -> 64) + BatchNorm2d(64) + ReLU]
    * Conv(64 -> channels)

    The output has as many channels as the input. The position of each module
    inside the ``Sequential`` determines its ``state_dict`` key
    (``dncnn.<index>.*``), so the layer order must stay fixed for saved
    weights to load.
    """

    def __init__(self, channels, num_of_layers=17):
        super().__init__()
        width = 64

        def conv(c_in, c_out):
            # Every convolution shares kernel 3, padding 1 and no bias.
            return nn.Conv2d(in_channels=c_in, out_channels=c_out,
                             kernel_size=3, padding=1, bias=False)

        stack = [conv(channels, width), nn.ReLU(inplace=True)]
        for _ in range(num_of_layers - 2):
            stack.extend((conv(width, width),
                          nn.BatchNorm2d(width),
                          nn.ReLU(inplace=True)))
        stack.append(conv(width, channels))
        self.dncnn = nn.Sequential(*stack)

    def forward(self, x):
        """Pass ``x`` through the convolution stack and return the result."""
        return self.dncnn(x)
def _read_image_url(config_path="url.ini"):
    """Return the value of the first ``img_url=`` entry in *config_path*.

    The value runs from just after ``img_url=`` to the end of that line.
    Surrounding whitespace, including a Windows ``\\r``, is stripped.

    Raises:
        ValueError: if the key is missing or its value is empty.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        content = f.read()
    _, found, rest = content.partition("img_url=")
    lines = rest.splitlines()
    url = lines[0].strip() if lines else ""
    if not found or not url:
        raise ValueError(f"no 'img_url=' entry found in {config_path}")
    return url


def _load_model(weights_path, device):
    """Build a 1-channel, 17-layer DnCNN and load its weights from *weights_path*.

    ``map_location=device`` lets a checkpoint written from a GPU run be opened
    on a CPU-only machine. Checkpoints saved from an ``nn.DataParallel``
    wrapper prefix every key with ``module.``. That prefix is stripped so the
    plain network can be used on any host without a DataParallel wrapper.
    """
    net = DnCNN(channels=1, num_of_layers=17)
    # NOTE: torch.load unpickles its input; only load trusted checkpoints.
    state = torch.load(weights_path, map_location=device)
    prefix = "module."
    state = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state.items()
    }
    net.load_state_dict(state)
    net.to(device)
    net.eval()
    return net


def _denoise_bgr(img, model):
    """Denoise the three channels of *img* independently; return a uint8 image.

    *img* is expected to be an H x W x 3 array with 0-255 values, as returned
    by ``cv2.imread``. Each channel is scaled to [0, 1] before it is passed to
    the single-channel model. The model output is treated as a residual in
    those normalized units, so it is multiplied by 255 before being subtracted
    from the 0-255 channel. The result is rounded, clipped to [0, 255] and
    converted to uint8.
    """
    planes = []
    for c in range(3):
        plane = np.float32(img[:, :, c])
        residual = modelOneChannle(plane / 255, model)
        # Restore the H x W layout in case the helper dropped singleton axes.
        residual = np.reshape(residual, plane.shape)
        planes.append(plane - 255.0 * residual)
    out = np.stack(planes, axis=-1)
    return np.clip(np.rint(out), 0, 255).astype(np.uint8)


def main():
    """Denoise the image referenced by ``url.ini`` and save it as ``clear.png``.

    Steps:

    1. Load the DnCNN weights from ``net.pth``.
    2. Read ``img_url=`` from ``url.ini`` and download that URL to ``tem.png``.
    3. Denoise each BGR channel with the single-channel model.
    4. Write the result to ``clear.png``.

    Raises:
        ValueError: if ``url.ini`` has no usable ``img_url=`` entry.
        OSError: if the downloaded file cannot be decoded as an image.
    """
    # modelOneChannle reads this module-level name to place its input tensor.
    global deviceType
    deviceType = 'cpu'
    device = torch.device(deviceType)

    model = _load_model("net.pth", device)
    print("model get")

    urlretrieve(_read_image_url("url.ini"), "tem.png")
    img = cv2.imread("tem.png")
    if img is None:
        raise OSError("could not read downloaded image tem.png")

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        clean = _denoise_bgr(img, model)
    cv2.imwrite("clear.png", clean)
def modelOneChannle(imgTmp, model, device=None):
    """Run *model* on one 2-D image plane and return its output plane.

    Args:
        imgTmp: 2-D array (H x W) holding one image channel.
        model: callable that receives a float32 tensor of shape (1, 1, H, W).
        device: where to place the input tensor. If omitted, the module-level
            ``deviceType`` is used when it has been set, otherwise ``"cpu"``.

    Returns:
        The model output for the single batch item and channel, as a NumPy
        array. Both spatial axes are kept, even when H or W is 1.
    """
    if device is None:
        # main() publishes its chosen device as a module global. Fall back to
        # CPU so this helper also works when called on its own.
        device = globals().get("deviceType", "cpu")
    # (H, W) -> (1, 1, H, W); torch.tensor always copies and casts to float32.
    batch = torch.tensor(np.asarray(imgTmp)[np.newaxis, np.newaxis],
                         dtype=torch.float32).to(device)
    with torch.no_grad():
        out = model(batch)
    # Index rather than np.squeeze so 1-pixel-high/wide planes stay 2-D.
    return out.detach().cpu().numpy()[0, 0]
if __name__ == "__main__":
    # Run the download-and-denoise pipeline only when executed as a script.
    main()