import os
import argparse
import numpy as np
import torch
if torch.__version__ >= '1.8':
import torch_npu
import torch.nn as nn
import torch.npu
import torch.optim as optim
import torchvision.utils as utils
from torch.autograd import Variable
from torch.utils.data import DataLoader
from models import DnCNN
from dataset import Traindata, Evldata
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from utils import batch_PSNR, weights_init_kaiming, AverageMeter
from apex import amp
import apex
import torch.multiprocessing as mp
import time
parser = argparse.ArgumentParser(description="DnCNN")
parser.add_argument("--preprocess", type=str, default="F", help="write T or Ture means create h5py dataset")
parser.add_argument('--data_path', type=str, help='path of dataset')
parser.add_argument("--batchSize", type=int, default=512, help="Training batch size")
parser.add_argument("--num_of_layers", type=int, default=17, help="Number of total layers")
parser.add_argument("--epochs", type=int, default=50, help="Number of training epochs")
parser.add_argument("--milestone", type=int, default=30, help="When to decay learning rate; should be less than epochs")
parser.add_argument("--lr", type=float, default=1e-3, help="Initial learning rate")
parser.add_argument("--outf", type=str, default=".", help='path of log files')
parser.add_argument("--mode", type=str, default="S", help='with known noise level (S) or blind training (B)')
parser.add_argument("--noiseL", type=float, default=25, help='noise level; ignored when mode=B')
parser.add_argument("--val_noiseL", type=float, default=25, help='noise level used on validation set')
parser.add_argument('--gpu_use_num', default=8, type=int, help='how many gpus you want use')
parser.add_argument('--device-list', default='0,1,2,3,4,5,6,7', type=str, help='device id list')
parser.add_argument('--worldsize', default=8, type=int, help='number of nodes for distributed training')
opt = parser.parse_args()
def getDiv(epo):
""" lr div value get """
if epo < 15:
return 1.
if epo < 25:
return 2.
if epo < 30:
return 3.
if epo < 60:
return 4.
return 1.
def device_id_to_process_device_map(device_list):
""" find device """
devices = device_list.split(",")
devices = [int(x) for x in devices]
devices.sort()
process_device_map = dict()
for process_id, device_id in enumerate(devices):
process_device_map[process_id] = device_id
return process_device_map
def main():
""" 8p in"""
opt.process_device_map = device_id_to_process_device_map(opt.device_list)
ngpus_per_node = torch.cuda.device_count()
mp.spawn(main_worker, nprocs=opt.gpu_use_num, args=(ngpus_per_node, opt))
def main_worker(gpu, gpu_nums, opt):
""" spawn thread """
opt.gpu = opt.process_device_map[gpu]
print ("gpu -> ", gpu)
print ("all :", gpu_nums)
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '29688'
os.environ['WORLD_SIZE'] = '8'
dist.init_process_group(backend="hccl", init_method='env://', world_size=opt.worldsize, rank=gpu)
dataset_train = Traindata(data_path=opt.data_path, getDataSet='train')
dataset_val = Evldata(data_path=opt.data_path, getDataSet='Set68')
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset_train)
loader_train = DataLoader(dataset=dataset_train, batch_size=opt.batchSize, \
pin_memory=True, shuffle=(train_sampler is None), sampler=train_sampler, drop_last=True)
print("# of training samples: %d" % int(len(dataset_train)))
print("use gpu num", opt.gpu)
loc = 'npu:{}'.format(opt.gpu)
torch.npu.set_device(loc)
net = DnCNN(channels = 1, num_of_layers=opt.num_of_layers)
net.apply(weights_init_kaiming)
net=net.to(loc)
optimizer = apex.optimizers.NpuFusedAdam(net.parameters(), lr=opt.lr)
criterion = nn.MSELoss(reduction='sum')
net, optimizer = amp.initialize(net, optimizer, opt_level = "O1", loss_scale = "dynamic", combine_grad=True)
model = DDP(net, device_ids=[gpu])
criterion.to(loc)
step_time = AverageMeter('Time', ':6.3f')
step_start=time.time()
data_time = AverageMeter('Time', ':6.3f')
data_start=time.time()
maxPsnr=0
for epoch in range(opt.epochs):
train_sampler.set_epoch(epoch)
current_lr = opt.lr / getDiv(epoch)
for param_group in optimizer.param_groups:
param_group["lr"] = current_lr
print('learning rate %f' % current_lr)
model.train()
for i, dataLoad in enumerate(loader_train, 0):
data_difTime=time.time() - data_start
data_time.update(data_difTime)
step_start=time.time()
imgn_train, img_train, noise = dataLoad
img_train, imgn_train = img_train.to(loc), imgn_train.to(loc)
noise = noise.to(loc)
out_train = model(imgn_train)
loss = criterion(out_train, noise) / (imgn_train.size()[0] * 2)
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
optimizer.step()
optimizer.zero_grad()
torch.npu.synchronize()
step_difTime=time.time() - step_start
step_time.update(step_difTime)
data_start = time.time()
if i % 40 == 2 and int(time.time()) % 8 == opt.gpu:
print('epoch[{}] load[{}/{}] loss: {:.3f} '.format(epoch, i + 1, len(loader_train), loss.item()))
print('dataTime[{:.3f}]/[{:.3f}] stepTime[{:.3f}/{:.3f}] FPS: {:.3f}'.format(data_difTime, \
data_time.avg, step_difTime, step_time.avg, opt.batchSize * opt.worldsize / (data_difTime + step_difTime)))
model.eval()
psnr_val = 0
for k in range(len(dataset_val)):
noiseData, clearData, noise = dataset_val[k]
img_val = torch.unsqueeze(clearData, 0)
imgn_val = torch.unsqueeze(noiseData, 0)
img_val, imgn_val = img_val.to(loc), imgn_val.to(loc)
out_val = torch.clamp(imgn_val - model(imgn_val), 0., 1.)
psnr_val += batch_PSNR(out_val, img_val, 1.)
psnr_val /= len(dataset_val)
print("\n[epoch %d] [gpu %d] PSNR_val: %.4f" % (epoch, opt.gpu, psnr_val))
if opt.gpu == 0 and psnr_val > maxPsnr:
maxPsnr = psnr_val
print("save and maxPsnr = ", maxPsnr)
torch.save(model.state_dict(), os.path.join(opt.outf, 'net8p.pth'))
if opt.gpu == 0:
print("gpuNum: ", opt.gpu, "gpu num finnal Psnr:", maxPsnr)
if __name__ == "__main__":
main()