import cv2
import os
import argparse
import glob
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from models import ADNet
from utils import *
from collections import OrderedDict
import torch.distributed as dist
if torch.__version__ >= '1.8':
import torch_npu
# Command-line options for the ADNet test script. They are parsed at import
# time into the module-level ``opt`` namespace, which main() reads.
parser = argparse.ArgumentParser(description="ADNet_Test")
parser.add_argument("--num_of_layers", type=int, default=17, help="Number of total layers")
parser.add_argument("--logdir", type=str, default="", help='path of log files')
# The value names a folder under data/; the default is BSD68, so the help text
# lists that name rather than the non-default "Set68".
parser.add_argument("--test_data", type=str, default='BSD68', help='test set folder under data/, e.g. Set12 or BSD68')
parser.add_argument("--test_noiseL", type=float, default=25, help='noise level used on test set')
parser.add_argument("--DeviceID", type=int, default=0, help='choose a device id to use')
# NOTE(review): --is_distributed, --world_size, --local_rank and --num_gpus are
# accepted but not read by main() in this file.
parser.add_argument("--is_distributed", type=int, default=0, help='choose ddp or not')
parser.add_argument('--world_size', default=-1, type=int, help='number of nodes for distributed training')
parser.add_argument('--local_rank', type=int, default=0)
parser.add_argument("--num_gpus", default=1, type=int)
opt = parser.parse_args()
def normalize(data):
    """Divide by 255 so 8-bit pixel intensities land in the unit range.

    Accepts anything that supports true division by a float, such as a
    NumPy array or a plain number, and returns the scaled result.
    """
    max_intensity = 255.
    return data / max_intensity
def main():
    """Evaluate a trained ADNet checkpoint on a folder of PNG test images.

    Settings come from the module-level ``opt`` namespace. The function:

    * loads ``best_model.pth`` from ``opt.logdir`` onto NPU ``opt.DeviceID``;
    * takes the first channel of every ``data/<opt.test_data>/*.png`` image
      and adds Gaussian noise with std ``opt.test_noiseL / 255``;
    * denoises the image and prints the per-image PSNR followed by the
      mean PSNR.

    Raises:
        FileNotFoundError: if no PNG files match the test-set pattern.
        OSError: if OpenCV cannot decode one of the image files.
    """
    local_device = torch.device(f'npu:{opt.DeviceID}')
    torch.npu.set_device(local_device)
    print("using npu :{}".format(opt.DeviceID))

    print('Loading model ...\n')
    # Use --num_of_layers (default 17) instead of a hard-coded depth so the
    # command-line flag actually takes effect.
    model = ADNet(channels=1, num_of_layers=opt.num_of_layers)
    checkpoint = torch.load(os.path.join(opt.logdir, 'best_model.pth'), map_location=local_device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.npu()
    model.eval()

    print('Loading data info ...\n')
    pattern = os.path.join('data', opt.test_data, '*.png')
    files_source = sorted(glob.glob(pattern))
    if not files_source:
        # Without this guard the final average would divide by zero.
        raise FileNotFoundError("no test images found matching %s" % pattern)

    psnr_test = 0
    for f in files_source:
        Img = cv2.imread(f)
        if Img is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise OSError("failed to read image %s" % f)
        # Keep only channel 0 and scale it to [0, 1].
        Img = normalize(np.float32(Img[:, :, 0]))
        # Add batch and channel axes: (H, W) -> (1, 1, H, W).
        ISource = torch.Tensor(Img[np.newaxis, np.newaxis, ...])

        # Reseed before every image so each image's noise depends only on its
        # size, not on how many images were processed before it.
        torch.manual_seed(0)
        noise = torch.FloatTensor(ISource.size()).normal_(mean=0, std=opt.test_noiseL/255.)
        INoisy = ISource + noise

        # Variable wrappers are unnecessary: plain tensors carry autograd
        # state in modern PyTorch.
        ISource = ISource.npu()
        INoisy = INoisy.npu()
        with torch.no_grad():
            Out = torch.clamp(model(INoisy), 0., 1.)
        psnr = batch_PSNR(Out, ISource, 1.)
        psnr_test += psnr
        print("%s PSNR %f" % (f, psnr))

    psnr_test /= len(files_source)
    print("\nPSNR on test data %f" % psnr_test)
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()