import argparse
import os
import torch
if torch.__version__ >= "1.8":
import torch_npu
from torch import nn
import torch.multiprocessing as mp
from torch.utils.data.dataloader import DataLoader
import torchvision.transforms as transforms
from tqdm import tqdm
from model import RCAN
from dataset import Dataset,Dataset_test_label
from utils import AverageMeter,device_id_to_process_device_map,information_print,timer
import shutil
import numpy as np
import time
import PIL.Image as pil_image
import skimage.io as io
from skimage.metrics import peak_signal_noise_ratio,structural_similarity
from apex import amp
# Command-line interface of the evaluation script, declared as a table of
# (flag, add_argument keyword arguments) and registered in this order.
_CLI_OPTIONS = (
    # Model and data options.
    ('--arch', dict(type=str, default='RCAN')),
    ('--test_dataset_dir', dict(type=str, required=True)),
    ('--outputs_dir', dict(type=str, required=True)),
    ('--workers', dict(type=int, default=8)),
    ('--scale', dict(type=int, required=True)),
    ('--num_features', dict(type=int, default=64)),
    ('--num_rg', dict(type=int, default=10)),
    ('--num_rcab', dict(type=int, default=20)),
    ('--reduction', dict(type=int, default=16)),
    ('--checkpoint_path', dict(type=str, help='the path of checkpoint to load')),
    # Mixed-precision (apex amp) options.
    ('--amp', dict(default=False, action='store_true',
                   help='if use amp to train the model')),
    ('--loss_scale', dict(default=128.0, type=float,
                          help='amp setting: loss scale, default 128.0')),
    ('--opt_level', dict(default='O2', type=str,
                         help='amp setting: opt level, default O2')),
    # Device selection and multi-process options.
    ('--device', dict(type=str, default="gpu", help='npu or gpu')),
    ('--device_list', dict(type=str, default='0,1,2,3,4,5,6,7')),
    ('--device_id', dict(type=int, default=None, help='index of gpu/npu to use')),
    ('--world_size', dict(type=int, default=1,
                          help='number of nodes for distributed training')),
    ('--from_multiprocessing_distributed', dict(
        action='store_true',
        help='if the checkpoint trained from mutil P')),
    ('--multiprocessing_distributed', dict(
        action='store_true',
        help='Use multi-processing distributed training to launch '
             'N processes per node, which has N GPUs. This is the '
             'fastest way to use PyTorch for either single node or '
             'multi node data parallel training')),
)

parser = argparse.ArgumentParser()
for _flag, _settings in _CLI_OPTIONS:
    parser.add_argument(_flag, **_settings)
del _flag, _settings
def _luma_cropped(image, border):
    """Reduce an H x W x 3 image to one cropped, [0, 1]-scaled luma plane.

    The input is divided by 255, combined channel-wise with the weights
    65.481 / 128.553 / 24.966 plus an offset of 16, divided by 255 again,
    given a trailing singleton channel axis, and finally trimmed by
    ``border`` pixels on every side.
    """
    unit = image / 255.0
    luma = 65.481 * unit[:, :, 0] + 128.553 * unit[:, :, 1] + 24.966 * unit[:, :, 2] + 16
    luma = luma / 255.0
    luma = np.expand_dims(luma, axis=2)
    return luma[border:-border, border:-border, :]


def test_eval(model, dataloader_test, opt):
    """Run the model over the test loader and average PSNR / SSIM.

    Each batch yields ``(hr, lr, bicubic, filename)``; only the first sample
    of a batch is scored. Both the network output and the bicubic baseline
    are compared with the high-resolution reference on the cropped luma
    plane. The predicted image is also written to ``opt.outputs_dir`` when
    not running multi-process, or when ``opt.process_id`` is a multiple of
    ``opt.ndevices_per_node``.

    Returns:
        Tuple ``(mean PSNR, mean SSIM, mean bicubic PSNR, mean bicubic SSIM)``.
    """
    margin = opt.scale + 6
    ssim_options = dict(win_size=11, gaussian_weights=True, multichannel=True,
                        data_range=1.0, K1=0.01, K2=0.03, sigma=1.5)
    model_psnr, model_ssim = [], []
    baseline_psnr, baseline_ssim = [], []

    for hr, lr, bicubic, filename in dataloader_test:
        # Forward pass without autograd bookkeeping.
        with torch.no_grad():
            if opt.device == "npu":
                lr = lr.npu()
            elif opt.device == "gpu":
                lr = lr.cuda()
            pred = model(lr)

        reference = hr[0].numpy().astype('uint8')
        baseline = bicubic[0].numpy().astype('uint8')
        # Scale the first prediction to 0..255 bytes and move channels last.
        output = (pred[0] * 255.0).clamp_(0.0, 255.0).byte().cpu().numpy()
        output = np.transpose(output, axes=(1, 2, 0))

        if not opt.multiprocessing_distributed or opt.process_id % opt.ndevices_per_node == 0:
            target = os.path.join(
                opt.outputs_dir,
                '{}_x{}_output{}.png'.format(filename[0], opt.scale, opt.process_id))
            pil_image.fromarray(output, mode='RGB').save(target)

        reference_y = _luma_cropped(reference, margin)
        baseline_y = _luma_cropped(baseline, margin)
        output_y = _luma_cropped(output, margin)

        model_psnr.append(peak_signal_noise_ratio(reference_y, output_y))
        baseline_psnr.append(peak_signal_noise_ratio(reference_y, baseline_y))
        model_ssim.append(structural_similarity(reference_y, output_y, **ssim_options))
        baseline_ssim.append(structural_similarity(reference_y, baseline_y, **ssim_options))

    return (np.mean(model_psnr), np.mean(model_ssim),
            np.mean(baseline_psnr), np.mean(baseline_ssim))
def test_prepare(opt):
    """Validate the device selection and prepare the output directory.

    Mutates ``opt`` in place:

    * ``opt.process_device_map`` is set from ``opt.device_list`` via
      ``device_id_to_process_device_map``.
    * ``opt.ndevices_per_node`` is the size of that map in multi-process
      mode, otherwise 1.
    * ``opt.world_size`` is multiplied by ``opt.ndevices_per_node`` in
      multi-process mode.
    * ``opt.outputs_dir`` gets ``"/X<scale>/"`` appended and is created if
      missing.

    A ``TEST_Para_<timestamp>.txt`` file listing every attribute of ``opt``
    is written into the output directory.

    Raises:
        ValueError: if ``--device_id`` is combined with multi-process mode,
            or if neither multi-process mode nor ``--device_id`` is given.
    """
    opt.process_device_map = device_id_to_process_device_map(opt.device_list)
    information_print(opt.process_device_map, mode=0)

    if opt.multiprocessing_distributed:
        if opt.device_id is not None:
            raise ValueError("when you choose multi processing, you don't need to select one npu or gpu")
        opt.ndevices_per_node = len(opt.process_device_map)
        opt.world_size = opt.ndevices_per_node * opt.world_size
        information_print(
            "...multi processing...\nChoose to use {} {}s from device list...".format(
                opt.ndevices_per_node, opt.device),
            mode=0)
    elif opt.device_id is not None:
        opt.ndevices_per_node = 1
        information_print(
            "...single processing...\nChoose to use {}ID:{} from device list...".format(
                opt.device, opt.device_id),
            mode=0)
    else:
        raise ValueError("The process_para set Wrong")

    opt.outputs_dir = opt.outputs_dir + "/X{}/".format(str(opt.scale))
    if os.path.exists(opt.outputs_dir):
        information_print("The dir is existing, if continue, Retraining from and Replacing...", mode=0)
    # exist_ok avoids failing if the directory appears between check and create.
    os.makedirs(opt.outputs_dir, exist_ok=True)

    # Record the effective options next to the results; the file name carries
    # a minute-resolution timestamp.
    stamp = time.strftime("%Y_%m_%d_%H_%M", time.localtime())
    with open(os.path.join(opt.outputs_dir, 'TEST_Para_{}.txt'.format(stamp)), "w") as f:
        for name, value in vars(opt).items():
            f.write(name + ":" + str(value) + '\n')
def main():
    """Parse the command line, prepare the run and start the worker(s).

    Sets the ``MASTER_ADDR`` / ``MASTER_PORT`` environment variables, calls
    ``test_prepare`` and then either spawns ``opt.ndevices_per_node``
    processes running ``main_worker`` or calls ``main_worker`` directly with
    process id 0. The "main() end" banner is printed before the workers run.
    """
    print("===============main() start=================")
    opt = parser.parse_args()
    information_print(opt, mode=0)

    os.environ.update({'MASTER_ADDR': "127.0.0.1", 'MASTER_PORT': '29688'})
    test_prepare(opt)
    print("===============main() end=================")

    # Function-local imports kept from the original; the NPU path
    # additionally imports torch.npu.
    if opt.device in ("npu", "gpu"):
        import torch  # noqa: F401
    if opt.device == "npu":
        import torch.npu  # noqa: F401

    if not opt.multiprocessing_distributed:
        main_worker(0, opt)
        return
    mp.spawn(main_worker, nprocs=opt.ndevices_per_node, args=[opt])
def _remap_state_dict_keys(state_dict, saved_distributed, load_distributed):
    """Return ``state_dict`` with keys matching the target model's wrapping.

    When the checkpoint and the target model disagree on multi-process
    wrapping, the ``module.`` key prefix is added (loading into a wrapped
    model) or stripped (loading into a plain model). Keys that do not carry
    the prefix are left untouched when stripping.
    """
    prefix = "module."
    if bool(saved_distributed) == bool(load_distributed):
        return state_dict
    if load_distributed:
        return {prefix + key: value for key, value in state_dict.items()}
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def main_worker(process_id, opt):
    """Evaluate one checkpoint on one device and print the averaged metrics.

    Args:
        process_id: index of this process; stored as ``opt.process_id`` and,
            in multi-process mode, used as the rank and as the index into
            ``opt.process_device_map``.
        opt: parsed options, mutated in place (``process_id``, and
            ``device_id`` in multi-process mode).

    Raises:
        ValueError: if ``opt.device`` is neither ``'npu'`` nor ``'gpu'``, or
            if ``opt.checkpoint_path`` is unset or does not exist.
    """
    opt.process_id = process_id

    # Fail early with a clear message instead of a NameError further down.
    if opt.device not in ('npu', 'gpu'):
        raise ValueError("unsupported device '{}': expected 'npu' or 'gpu'".format(opt.device))

    if opt.multiprocessing_distributed:
        os.environ['MASTER_ADDR'] = "127.0.0.1"
        os.environ['MASTER_PORT'] = '29688'
        opt.device_id = opt.process_device_map[opt.process_id]
        if opt.device == 'npu':
            torch.distributed.init_process_group(
                backend='hccl', world_size=opt.world_size, rank=opt.process_id)
        else:
            torch.distributed.init_process_group(
                backend='nccl', init_method="env://",
                world_size=opt.world_size, rank=opt.process_id)

    # This is queried before set_device() below, so it reports the device
    # that is current at this point rather than the one selected afterwards.
    if opt.device == "gpu":
        current_device = torch.cuda.current_device()
    else:
        current_device = torch.npu.current_device()

    print("====" * 5)
    print("SUB PROCESSING INFORMATION")
    print("Number of Mutil-process {}".format(opt.ndevices_per_node))
    print("rank ID {}".format(opt.process_id))
    print("Wanted device {}ID:{}".format(opt.device, opt.device_id))
    # The original format string had a single placeholder, so the device
    # index was computed but never shown.
    print("Chosen device {}ID:{}".format(opt.device, current_device))
    print("====" * 5)

    # Select the device and build the model on it; ``loc`` is reused as the
    # map_location when loading the checkpoint.
    if opt.device == 'npu':
        loc = 'npu:{}'.format(opt.device_id)
        device = torch.device(loc)
        torch.npu.set_device(device)
        model = RCAN(opt).to(device)
    else:
        loc = 'cuda:{}'.format(opt.device_id)
        torch.cuda.set_device(opt.device_id)
        model = RCAN(opt).cuda()

    if opt.amp:
        model = amp.initialize(model, opt_level=opt.opt_level, loss_scale=opt.loss_scale)
    if opt.multiprocessing_distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[opt.device_id], broadcast_buffers=False)

    # ``--checkpoint_path`` has no default, so guard against None before
    # handing it to os.path.exists().
    if not (opt.checkpoint_path and os.path.exists(opt.checkpoint_path)):
        raise ValueError("=> no checkpoint found at '{}'".format(opt.checkpoint_path))

    print("loading checkpoint {}".format(opt.checkpoint_path))
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(opt.checkpoint_path, map_location=loc)
    model.load_state_dict(_remap_state_dict_keys(
        checkpoint['model'],
        opt.from_multiprocessing_distributed,
        opt.multiprocessing_distributed))
    checkpoint_time = checkpoint['ck_time']
    if opt.amp:
        amp.load_state_dict(checkpoint['amp'])
    print("loaded checkpoint '{}' (epoch {})".format(
        opt.checkpoint_path, len(checkpoint['ck_performance'])))

    torch.backends.cudnn.benchmark = True
    dataset_test = Dataset_test_label(opt.test_dataset_dir, opt.scale)
    # No sampler is passed to the loader, so every process iterates over the
    # whole test set.
    dataloader_test = DataLoader(dataset=dataset_test, batch_size=1, num_workers=1,
                                 pin_memory=True, drop_last=True)

    model.eval()
    psnr_avg, ssim_avg, psnr_avg_bic, ssim_avg_bic = test_eval(model, dataloader_test, opt)
    print("PSNR_bic:", psnr_avg_bic)
    print("SSIM_bic:", ssim_avg_bic)
    print("PSNR_RCAN:", psnr_avg)
    print("SSIM_RCAN:", ssim_avg)
    print("TIME:", np.mean(checkpoint_time))
# Script entry point: run the evaluation only when executed directly.
if __name__ == "__main__":
    main()