import argparse
import os
import torch
if torch.__version__ >= "1.8":
import torch_npu
from torch import nn
import torch.multiprocessing as mp
from torch.utils.data.dataloader import DataLoader
import torchvision.transforms as transforms
from tqdm import tqdm
from model import RCAN
from dataset import Dataset,Dataset_test_label
from utils import AverageMeter,device_id_to_process_device_map,information_print,timer
import shutil
import matplotlib.pyplot as plt
import numpy as np
import time
import PIL.Image as pil_image
import skimage.io as io
from skimage.metrics import peak_signal_noise_ratio,structural_similarity
from apex import amp
import apex
import sys
# Command-line interface of the training script.  The parser is built at
# module level so that ``main()`` can simply call ``parser.parse_args()``.
parser = argparse.ArgumentParser()

# --- model / dataset / output locations ---------------------------------
parser.add_argument('--arch', type=str, default='RCAN')
parser.add_argument('--train_dataset_dir', type=str, required=True)
parser.add_argument('--test_dataset_dir', type=str, required=True)
parser.add_argument('--outputs_dir', type=str, required=True)

# --- optimisation hyper-parameters ---------------------------------------
parser.add_argument('--patch_size', type=int, default=48)
parser.add_argument('--batch_size', type=int, default=160)
parser.add_argument('--num_epochs', type=int, default=300)
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--workers', type=int, default=8)
parser.add_argument('--seed', type=int, default=123)

# --- network shape --------------------------------------------------------
parser.add_argument('--scale', type=int, default=2)
parser.add_argument('--num_features', type=int, default=64)
parser.add_argument('--num_rg', type=int, default=10)
parser.add_argument('--num_rcab', type=int, default=20)
parser.add_argument('--reduction', type=int, default=16)

# --- resuming / fine-tuning ----------------------------------------------
parser.add_argument('--ifcontinue', action='store_true', default=False,
                    help='if continue to train the model')
parser.add_argument('--checkpoint_path', type=str, default="model_best_amp.pth",
                    help='the path of checkpoint to load')
parser.add_argument('--iffinetuning', action='store_true', default=False,
                    help='if iffinetuning to train the model')
parser.add_argument('--finetuning_checkpoint_path', type=str, default="",
                    help='the path of checkpoint to load')

# --- apex mixed precision -------------------------------------------------
parser.add_argument('--amp', default=False, action='store_true',
                    help='if use amp to train the model')
parser.add_argument('--loss_scale', default=128.0, type=float,
                    help='amp setting: loss scale, default 128.0')
parser.add_argument('--opt_level', default='O2', type=str,
                    help='amp setting: opt level, default O2')

# --- device selection / distributed training -----------------------------
# Only "npu" and "gpu" are handled by the rest of the script; restricting
# the choices here turns any other value into an immediate usage error
# instead of an unbound-name failure deep inside the training code.
parser.add_argument('--device', type=str, default="gpu", choices=('npu', 'gpu'),
                    help='npu or gpu')
parser.add_argument('--device_list', type=str, default='0,1,2,3,4,5,6,7')
parser.add_argument('--device_id', type=int, default=None,
                    help='index of gpu/npu to use')
parser.add_argument('--world_size', type=int, default=1,
                    help='number of nodes for distributed training')
parser.add_argument('--multiprocessing_distributed', action='store_true',
                    help='Use multi-processing distributed training to launch '
                         'N processes per node, which has N GPUs. This is the '
                         'fastest way to use PyTorch for either single node or '
                         'multi node data parallel training')
def flush_print(func):
    """Return a wrapper around *func* that flushes ``sys.stdout`` after each call.

    The wrapper forwards all positional and keyword arguments to *func*,
    flushes standard output once *func* has returned, and then hands back
    whatever *func* returned.  If *func* raises, no flush is performed and
    the exception propagates unchanged.
    """
    # Local import: the module's import block does not provide functools.
    import functools

    @functools.wraps(func)
    def new_print(*args, **kwargs):
        result = func(*args, **kwargs)
        sys.stdout.flush()
        return result

    return new_print


# Shadow the builtin for this module so that every ``print`` below flushes
# immediately (keeps redirected / multi-process logs in order).
print = flush_print(print)
def save_checkpoint(opt,epoch,checkpoint_performance,checkpoint_time,model,optimizer):
    """Persist the latest training state, refresh the PSNR plot and track the best model.

    Args:
        opt: parsed options; reads ``outputs_dir``, ``arch``, ``amp`` and
            ``best`` and updates ``opt.best`` when a new best PSNR is reached.
        epoch: index of the epoch that has just finished (stored and logged).
        checkpoint_performance: non-empty list of ``[psnr, ssim]`` pairs, one
            per evaluated epoch; the last entry is the current epoch.
        checkpoint_time: per-epoch timing history, stored verbatim.
        model: network whose ``state_dict()`` is saved.
        optimizer: optimizer whose ``state_dict()`` is saved.

    Side effects:
        Writes ``<arch>_epoch_last[_amp].pth`` and ``result.png`` into
        ``opt.outputs_dir``; copies the checkpoint to ``model_best[_amp].pth``
        when the latest PSNR exceeds ``opt.best``.
    """
    amp_suffix = "_amp" if opt.amp else ""
    # Column 0 of the history is PSNR; build the array once and reuse it.
    psnr_history = np.array(checkpoint_performance)[:, 0]
    latest_psnr = checkpoint_performance[-1][0]

    checkpoint_path = os.path.join(opt.outputs_dir, '{}_epoch_last{}.pth'.format(opt.arch, amp_suffix))
    checkpoint = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
        'amp': amp.state_dict() if opt.amp else None,
        'best': psnr_history.max(),
        'ck_performance': checkpoint_performance,
        'ck_time': checkpoint_time,
    }
    torch.save(checkpoint, checkpoint_path)

    # Re-draw the PSNR-per-epoch curve from the full history.
    plt.figure()
    plt.title('PSNR')
    plt.plot(range(len(psnr_history)), psnr_history)
    plt.grid(True)
    plt.savefig(os.path.join(opt.outputs_dir, "result.png"))
    plt.close()

    if latest_psnr > opt.best:
        print("\t", "This Epoch {} PSNR: {} (BEST)".format(epoch, latest_psnr))
        opt.best = latest_psnr
        shutil.copyfile(checkpoint_path, os.path.join(opt.outputs_dir, 'model_best{}.pth'.format(amp_suffix)))
    else:
        # Report the history position(s) holding the best PSNR as a plain
        # list of ints rather than the raw tuple-of-arrays from np.where.
        best_epochs = np.flatnonzero(psnr_history == opt.best).tolist()
        print("\t", "This Epoch {} PSNR: {},".format(epoch, latest_psnr),
              "The Best PSNR is {} at Epoch {}".format(opt.best, best_epochs))
def _luma_cropped(image, border):
    """Return the border-cropped luma plane of a channel-last RGB image.

    *image* holds 0..255 values with the colour channels on the last axis.
    The result is scaled to the 0..1 range, keeps a trailing channel axis of
    size one, and has *border* pixels removed from every spatial edge.
    """
    rgb = image / 255.0
    # Weighted channel sum plus offset 16 (BT.601-style luma), then rescaled.
    luma = 65.481 * rgb[:, :, 0] + 128.553 * rgb[:, :, 1] + 24.966 * rgb[:, :, 2] + 16
    luma = np.expand_dims(luma / 255.0, axis=2)
    return luma[border:-border, border:-border, :]


def test_eval(model,dataloader_test,opt):
    """Evaluate *model* on the test loader and return ``(mean PSNR, mean SSIM)``.

    Args:
        model: network called as ``model(lr)`` under ``torch.no_grad()``.
        dataloader_test: yields ``(hr, lr, bicubic, filename)`` batches; only
            the first sample of each batch is evaluated.
        opt: parsed options; reads ``device``, ``scale``, ``outputs_dir``,
            ``best``, ``process_id`` and the distributed settings.

    Side effects:
        In single-process mode, or on the first process of a node, every
        prediction is written to ``opt.outputs_dir`` as a PNG.  If the mean
        PSNR beats ``opt.best`` those PNGs are additionally copied to
        ``*_best.png`` files.
    """
    psnr_list = []
    ssim_list = []
    filename_list = []
    # Only one process per node touches the file system.
    writes_images = (not opt.multiprocessing_distributed
                     or opt.process_id % opt.ndevices_per_node == 0)
    border = opt.scale + 6

    # ``bicubic`` is delivered by the loader but is not needed for the metrics.
    for hr, lr, _bicubic, filename in dataloader_test:
        with torch.no_grad():
            if opt.device == "npu":
                lr = lr.npu()
            elif opt.device == "gpu":
                lr = lr.cuda()
            pred = model(lr)

        hr = hr[0].numpy().astype('uint8')
        # Network output is channel-first; convert to a channel-last uint8 image.
        pred = np.transpose((pred[0] * 255.0).clamp_(0.0, 255.0).byte().cpu().numpy(), axes=(1, 2, 0))

        if writes_images:
            image_pred = pil_image.fromarray(pred, mode='RGB')
            image_pred.save(os.path.join(opt.outputs_dir, '{}_x{}_output{}.png'.format(filename[0], opt.scale, opt.process_id)))

        image_src = _luma_cropped(hr, border)
        image_rcan = _luma_cropped(pred, border)

        filename_list.append(filename)
        psnr_list.append(peak_signal_noise_ratio(image_src, image_rcan))
        ssim_list.append(structural_similarity(image_src, image_rcan, win_size=11, gaussian_weights=True,
                                               multichannel=True, data_range=1.0, K1=0.01, K2=0.03, sigma=1.5))

    psnr_mean = np.mean(psnr_list)
    if psnr_mean > opt.best and writes_images:
        print("\t", "Saving the pictures")
        for filename in filename_list:
            image_path = os.path.join(opt.outputs_dir, '{}_x{}_output{}.png'.format(filename[0], opt.scale, opt.process_id))
            stem, extension = os.path.splitext(image_path)
            shutil.copyfile(image_path, stem + "_best" + extension)
    return psnr_mean, np.mean(ssim_list)
def train_prepare(opt):
    """Validate the device options and prepare the output directory.

    Mutates *opt* in place:
      * ``process_device_map``: result of ``device_id_to_process_device_map``
        applied to ``opt.device_list``.
      * ``ndevices_per_node``: size of that map in multi-process mode, 1 otherwise.
      * ``world_size``: multiplied by ``ndevices_per_node`` in multi-process mode.
      * ``outputs_dir``: extended with an ``X<scale>/`` sub-directory.
      * ``checkpoint_path``: made relative to ``outputs_dir`` when resuming.

    A ``Para_<timestamp>.txt`` file listing every option is written into the
    output directory.

    Raises:
        ValueError: if ``--device_id`` is combined with multi-processing, if
            neither mode is selected, or if the directory / checkpoint needed
            for resuming does not exist.
    """
    opt.process_device_map = device_id_to_process_device_map(opt.device_list)
    information_print(opt.process_device_map, mode=0)

    if opt.multiprocessing_distributed:
        if opt.device_id is not None:
            raise ValueError("when you choose multi processing, you don't need to select one npu or gpu")
        opt.ndevices_per_node = len(opt.process_device_map)
        opt.world_size = opt.ndevices_per_node * opt.world_size
        information_print("...multi processing...\nChoose to use {} {}s from device list...".format(opt.ndevices_per_node, opt.device), mode=0)
    elif opt.device_id is not None:
        opt.ndevices_per_node = 1
        information_print("...single processing...\nChoose to use {}ID:{} from device list...".format(opt.device, opt.device_id), mode=0)
    else:
        raise ValueError("The process_para set Wrong")

    # Each scale factor gets its own sub-directory.
    opt.outputs_dir = opt.outputs_dir + "/X" + str(opt.scale) + "/"

    if opt.ifcontinue:
        if not os.path.exists(opt.outputs_dir):
            raise ValueError("Do not find this dir for continuing train, please check")
        if not os.path.exists(os.path.join(opt.outputs_dir, opt.checkpoint_path)):
            raise ValueError("Do not find this file for continuing train, please check")
        print("Continue in {}, train from {}".format(opt.outputs_dir, opt.checkpoint_path))
        opt.checkpoint_path = os.path.join(opt.outputs_dir, opt.checkpoint_path)
    elif not os.path.exists(opt.outputs_dir):
        os.makedirs(opt.outputs_dir)
    else:
        information_print("The dir is existing, if continue, Retraining from and Replacing...", mode=0)

    # Dump the effective options for reproducibility; the ``with`` block
    # closes the file, so no explicit close() is needed.
    para_name = 'Para_{}.txt'.format(time.strftime("%Y_%m_%d_%H_%M", time.localtime()))
    with open(os.path.join(opt.outputs_dir, para_name), "w") as f:
        for key, value in vars(opt).items():
            f.write(key + ":" + str(value) + '\n')
def main():
    """Parse the command line, prepare the run and launch the worker(s).

    In multi-process mode one ``main_worker`` is spawned per entry of the
    device list; otherwise ``main_worker`` runs in the current process with
    rank 0.

    Raises:
        ValueError: if ``--device`` is neither ``"npu"`` nor ``"gpu"``, or
            whatever ``train_prepare`` raises for inconsistent options.
    """
    print("===============main() start=================")
    opt = parser.parse_args()
    information_print(opt, mode=0)

    if opt.device not in ("npu", "gpu"):
        raise ValueError("Unsupported device '{}': expected 'npu' or 'gpu'".format(opt.device))

    # A function-scope ``import torch.npu`` makes ``torch`` a local name for
    # the whole function, so bind it unconditionally before it is used below.
    import torch
    if opt.device == "npu":
        import torch.npu

    torch.manual_seed(opt.seed)
    # Rendezvous address used by torch.distributed (single-node setup).
    os.environ['MASTER_ADDR'] = "127.0.0.1"
    os.environ['MASTER_PORT'] = '29688'
    train_prepare(opt)
    print("===============main() end=================")

    if opt.multiprocessing_distributed:
        mp.spawn(main_worker, nprocs=opt.ndevices_per_node, args=(opt,))
    else:
        main_worker(0, opt)
def _batch_to_device(inputs, labels, opt, device):
    """Move one training batch to the accelerator selected by ``opt.device``."""
    if opt.device == "npu":
        return inputs.to(device), labels.to(device)
    if opt.device == "gpu":
        return inputs.cuda(), labels.cuda()
    return inputs, labels


def _backward_and_step(loss, optimizer, opt):
    """Zero the gradients, back-propagate *loss* (via apex amp if enabled) and step."""
    optimizer.zero_grad()
    if opt.amp:
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()
    optimizer.step()


def main_worker(process_id,opt):
    """Run the full training loop in one process.

    Args:
        process_id: rank of this process (0 in single-process mode; supplied
            by ``mp.spawn`` otherwise).
        opt: parsed and prepared options (see ``train_prepare``).  Several
            attributes are added or overwritten here: ``sub_device_id``,
            ``process_id``, ``start_epoch``, ``best`` and, in distributed
            mode, the per-process ``workers`` and ``batch_size``.

    Raises:
        ValueError: if fine-tuning or resuming is requested but the
            checkpoint file does not exist.
    """
    # ------------------------------------------------------------------
    # Process group and device selection
    # ------------------------------------------------------------------
    if opt.multiprocessing_distributed:
        os.environ['MASTER_ADDR'] = "127.0.0.1"
        os.environ['MASTER_PORT'] = '29688'
        opt.sub_device_id = opt.process_device_map[process_id]
        if opt.device == 'npu':
            torch.distributed.init_process_group(backend='hccl', world_size=opt.world_size, rank=process_id)
        elif opt.device == 'gpu':
            torch.distributed.init_process_group(backend='nccl', init_method="env://", world_size=opt.world_size, rank=process_id)
    else:
        opt.sub_device_id = opt.device_id
    opt.process_id = process_id

    if opt.device == 'npu':
        loc = 'npu:{}'.format(opt.sub_device_id)
        device = torch.device(loc)
        torch.npu.set_device(device)
    elif opt.device == 'gpu':
        loc = 'cuda:{}'.format(opt.sub_device_id)
        device = torch.device(loc)
        torch.cuda.set_device(opt.sub_device_id)

    current_index = torch.cuda.current_device() if opt.device == "gpu" else torch.npu.current_device()
    print("====" * 5)
    print("SUB PROCESSING INFORMATION")
    print("Number of Mutil-process {}".format(opt.ndevices_per_node))
    print("rank ID {}".format(process_id))
    print("Wanted device {}ID:{}".format(opt.device, opt.sub_device_id))
    # The original format string had a single placeholder and silently
    # dropped the device index that was passed as second argument.
    print("Chosen device {}ID:{}".format(opt.device, current_index))
    print("====" * 5)

    # ------------------------------------------------------------------
    # Model, loss, optimizer, mixed precision and DDP wrapping
    # ------------------------------------------------------------------
    criterion = nn.L1Loss()
    if opt.device == "npu":
        model = RCAN(opt).to(device)
        optimizer = apex.optimizers.NpuFusedAdam(model.parameters(), lr=opt.lr)
        if opt.amp:
            # NOTE(review): the NPU path hard-codes opt_level/loss_scale and
            # ignores --opt_level / --loss_scale; confirm this is intended.
            model, optimizer = amp.initialize(model, optimizer, opt_level='O2', loss_scale=32.0, combine_grad=True)
        if opt.multiprocessing_distributed:
            print("start ddp")
            t0 = time.time()
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device], broadcast_buffers=False)
            print("end ddp")
            print("time", time.time() - t0)
    elif opt.device == "gpu":
        model = RCAN(opt).cuda()
        optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)
        if opt.amp:
            model, optimizer = amp.initialize(model, optimizer, opt_level=opt.opt_level, loss_scale=opt.loss_scale)
        if opt.multiprocessing_distributed:
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device], broadcast_buffers=False)

    # ------------------------------------------------------------------
    # Starting point: fine-tune, resume, or train from scratch
    # ------------------------------------------------------------------
    if opt.iffinetuning:
        if not os.path.isfile(opt.finetuning_checkpoint_path):
            raise ValueError("=> no checkpoint found at '{}'".format(opt.finetuning_checkpoint_path))
        print("loading checkpoint: {} and finetuning from it".format(opt.finetuning_checkpoint_path))
        checkpoint = torch.load(opt.finetuning_checkpoint_path, map_location=loc)
        opt.start_epoch = 0
        opt.best = -1
        # Non-strict load: keys missing from or unknown to the model are ignored.
        model.load_state_dict(checkpoint['model'], False)
        checkpoint_performance = []
        checkpoint_time = []
        print("start finetuning")
    elif opt.ifcontinue:
        if not os.path.isfile(opt.checkpoint_path):
            raise ValueError("=> no checkpoint found at '{}'".format(opt.checkpoint_path))
        print("loading checkpoint: {}".format(opt.checkpoint_path))
        checkpoint = torch.load(opt.checkpoint_path, map_location=loc)
        opt.start_epoch = len(checkpoint['ck_performance'])
        opt.best = checkpoint['best']
        model.load_state_dict(checkpoint['model'])
        checkpoint_performance = checkpoint['ck_performance']
        checkpoint_time = checkpoint['ck_time']
        # NOTE(review): the saved 'optimizer' state is not restored here, so
        # the optimizer restarts from fresh state when resuming; confirm.
        if opt.amp:
            amp.load_state_dict(checkpoint['amp'])
        print("loaded checkpoint: {} (epoch {})".format(opt.checkpoint_path, checkpoint['epoch']))
    else:
        opt.start_epoch = 0
        opt.best = -1
        checkpoint_performance = []
        checkpoint_time = []

    # ------------------------------------------------------------------
    # Data pipeline
    # ------------------------------------------------------------------
    torch.backends.cudnn.benchmark = True
    transform_my = transforms.Compose([transforms.RandomHorizontalFlip(), transforms.RandomVerticalFlip()])
    dataset_train = Dataset(opt.train_dataset_dir, opt.patch_size, opt.scale, transform=transform_my)
    if opt.multiprocessing_distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset_train)
        # Split the requested workers / batch across the processes of a node.
        opt.workers = int((opt.workers + opt.ndevices_per_node - 1) / opt.ndevices_per_node)
        opt.batch_size = int(opt.batch_size / opt.ndevices_per_node)
    else:
        train_sampler = None
    dataloader = DataLoader(dataset=dataset_train, batch_size=opt.batch_size, num_workers=opt.workers,
                            pin_memory=True, drop_last=True, sampler=train_sampler)
    dataset_test = Dataset_test_label(opt.test_dataset_dir, opt.scale)
    dataloader_test = DataLoader(dataset=dataset_test, batch_size=1, num_workers=1, pin_memory=True, drop_last=True)

    # Loop-invariant quantities used for progress reporting.
    samples_per_step = opt.batch_size * opt.ndevices_per_node
    iteration_all = len(dataset_train) // samples_per_step
    # Only this process logs, evaluates and writes checkpoints.
    is_reporting_process = (not opt.multiprocessing_distributed
                            or opt.process_id % opt.ndevices_per_node == 0)
    last_epoch = opt.start_epoch + opt.num_epochs

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------
    for epoch in range(opt.start_epoch, last_epoch):
        if opt.multiprocessing_distributed:
            train_sampler.set_epoch(epoch)
        model.train()

        if not is_reporting_process:
            # Other ranks only train; no metering, evaluation or saving.
            for inputs, labels, _ in dataloader:
                inputs, labels = _batch_to_device(inputs, labels, opt, device)
                preds = model(inputs)
                loss = criterion(preds, labels)
                _backward_and_step(loss, optimizer, opt)
            continue

        epoch_losses = AverageMeter()
        epoch_timer = AverageMeter()
        timer_iter = timer()
        timer_epoch = timer()
        print('E:{}/{}(RankID:{})'.format(epoch, last_epoch, opt.process_id))
        timer_epoch.tic()
        for index, (inputs, labels, _) in enumerate(dataloader):
            inputs, labels = _batch_to_device(inputs, labels, opt, device)
            # The first six iterations are excluded from the timing average.
            timed = index > 5
            if timed:
                timer_iter.tic()
            preds = model(inputs)
            loss = criterion(preds, labels)
            epoch_losses.update(loss.item(), len(inputs))
            _backward_and_step(loss, optimizer, opt)
            if timed:
                timer_iter.hold()
                epoch_timer.update(timer_iter.release() / samples_per_step)
            if timed and index % 10 == 0:
                information1 = 'Iteration:{}/{}->{:.2f}%'.format(index, iteration_all, index / iteration_all * 100)
                information2 = 'Time_avg_per_img:{}->FPS:{}'.format(epoch_timer.avg, 1 / epoch_timer.avg)
                information3 = 'Loss:{:.6f}'.format(epoch_losses.avg)
                print("\t", information1, information2, information3)
        timer_epoch.hold()
        print("\tThis Epoch's whole time: {}".format(timer_epoch.acc))

        model.eval()
        psnr_avg, ssim_avg = test_eval(model, dataloader_test, opt)
        checkpoint_time.append(epoch_timer.avg)
        checkpoint_performance.append([psnr_avg, ssim_avg])
        save_checkpoint(opt, epoch, checkpoint_performance, checkpoint_time, model, optimizer)
# Script entry point: run the training driver only when executed directly,
# not when this module is imported (e.g. by spawned worker processes).
if __name__ == "__main__":
    main()