import os
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import apex.amp as amp
import apex
from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist
import time
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.nn.modules.loss import _Loss
from models import ADNet
from dataset import prepare_data, Dataset
from collections import OrderedDict
from utils import *
# Load the NPU backend on PyTorch 1.8 or newer. The check compares the numeric
# (major, minor) pair: comparing the raw version strings is lexicographic, so
# for example '1.11.0' >= '1.8' evaluates to False and the import was skipped.
if tuple(int(part) for part in torch.__version__.split('.')[:2]) >= (1, 8):
    import torch_npu
def _str2bool(value):
    """Convert a command-line token such as ``"False"`` or ``"1"`` to a bool.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so ``--resume False`` would
    silently enable resuming.

    Args:
        value: the raw command-line token (or an actual bool).

    Returns:
        bool: the parsed flag value.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))


parser = argparse.ArgumentParser(description="DnCNN")

# Boolean switches: accept an explicit value ("--resume False") or a bare flag
# ("--resume", which means True).
parser.add_argument("--preprocess", type=_str2bool, nargs='?', const=True, default=False, help='run prepare_data or not')
parser.add_argument("--batchSize", type=int, default=128, help="Training batch size")
parser.add_argument("--resume", type=_str2bool, nargs='?', const=True, default=False, help="resume training from .pth")

# Model and optimisation settings.
parser.add_argument("--num_of_layers", type=int, default=17, help="Number of total layers")
parser.add_argument("--epochs", type=int, default=70, help="Number of training epochs")
parser.add_argument("--logdir", type=str, default="", help='path of log files')
parser.add_argument("--milestone", type=int, default=30, help="When to decay learning rate; should be less than epochs")
parser.add_argument("--lr", type=float, default=1e-3, help="Initial learning rate")
parser.add_argument("--outf", type=str, default="logs", help='path of log files?path to .pth')

# Noise model.
parser.add_argument("--mode", type=str, default="S", help='with known noise level (S) or blind training (B)')
parser.add_argument("--noiseL", type=float, default=15, help='noise level; ignored when mode=B')
parser.add_argument("--val_noiseL", type=float, default=15, help='noise level used on validation set')

# Device / distributed settings.
parser.add_argument("--is_distributed", type=int, default=0, help='choose ddp or not')
parser.add_argument('--local_rank', type=int, default=0)
parser.add_argument('--DeviceID', type=str, default="0")
parser.add_argument("--num_gpus", default=1, type=int)
parser.add_argument("--world_size", default=-1, type=int)
parser.add_argument("--loss_scale", default=128, type=int)
# Parse the command line once, at import time. main() and the __main__ guard
# read their settings from this module-level namespace.
opt = parser.parse_args()
def proc_nodes_module(checkpoint):
    """Return a copy of a state dict with ``module`` wrapper scopes removed from its keys.

    Wrappers such as ``DistributedDataParallel`` store the wrapped network under
    a ``module`` attribute, so saved keys look like ``module.conv.weight``.
    This strips those scopes so the weights load into an unwrapped model.

    Only whole dotted path components equal to ``module`` are dropped. A plain
    substring replacement would also rewrite unrelated names that merely end in
    "module" (``submodule.weight`` would become ``subweight``).

    Args:
        checkpoint: mapping of parameter name to value (a state dict).

    Returns:
        OrderedDict: the same values in the same order under the cleaned names.
    """
    new_state_dict = OrderedDict()
    for key, value in checkpoint.items():
        # The last component is the parameter name itself and is always kept.
        *scopes, leaf = key.split(".")
        kept = [scope for scope in scopes if scope != "module"]
        new_state_dict[".".join(kept + [leaf])] = value
    return new_state_dict
class sum_squared_error(_Loss):
    """Half of the summed squared error between a prediction and its target.

    Equivalent to ``0.5 * nn.MSELoss(reduction='sum')``, which makes the
    gradient with respect to the input equal to ``input - target``.
    """

    def __init__(self, size_average=None, reduce=None, reduction='sum'):
        super().__init__(size_average, reduce, reduction)

    def forward(self, input, target):
        # The sum reduction is fixed here; the constructor arguments are only
        # forwarded to the base class and are not consulted in this method.
        summed = torch.nn.functional.mse_loss(input, target, reduction='sum')
        return summed.div_(2)
def _select_device():
    """Bind this process to its NPU and, in distributed mode, join the HCCL group.

    In distributed mode ``opt.local_rank`` is overwritten with the value of the
    ``RANK_ID`` environment variable.

    Returns:
        torch.device: the device the model and batches are moved to.

    Raises:
        KeyError: in distributed mode when ``RANK_ID`` is not set.
    """
    if opt.is_distributed == 0:
        local_device = torch.device(f'npu:{opt.DeviceID}')
        torch.npu.set_device(local_device)
        print("using npu :{}".format(opt.DeviceID))
        return local_device

    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29688'
    opt.local_rank = int(os.environ['RANK_ID'])
    local_device = torch.device(f'npu:{opt.local_rank}')
    torch.npu.set_device(local_device)
    if opt.local_rank == 0:
        print("using npu :{}".format(opt.DeviceID))
    dist.init_process_group(backend='hccl', world_size=opt.world_size, rank=opt.local_rank)
    return local_device


def _lr_for_epoch(epoch, previous_lr):
    """Return the learning rate for 0-based ``epoch``, or ``previous_lr`` if no rule matches."""
    lr = previous_lr
    if epoch <= opt.milestone:
        lr = opt.lr
    # NOTE(review): dividing by 1. leaves the rate unchanged for epochs 31-60;
    # a real decay factor may have been intended -- confirm before changing.
    if 30 < epoch <= 60:
        lr = opt.lr / 1.
    if 60 < epoch <= 90:
        lr = opt.lr / 100.
    if 90 < epoch <= 120:
        lr = opt.lr / 1000.
    return lr


def _sample_train_noise(clean, blind_range):
    """Draw zero-mean Gaussian noise shaped like ``clean`` for the configured training mode.

    Mode ``'S'`` uses the single level ``opt.noiseL``; mode ``'B'`` draws one
    level per sample uniformly from ``blind_range``. Levels are divided by 255
    before being used as the standard deviation.

    Raises:
        ValueError: if ``opt.mode`` is neither ``'S'`` nor ``'B'``.
    """
    if opt.mode == 'S':
        return torch.FloatTensor(clean.size()).normal_(mean=0, std=opt.noiseL / 255.)
    if opt.mode == 'B':
        # Imported here so this code does not depend on a star import
        # happening to provide ``np``.
        import numpy as np
        noise = torch.zeros(clean.size())
        per_sample_level = np.random.uniform(blind_range[0], blind_range[1], size=noise.size()[0])
        sample_shape = noise.size()[1:]
        for n, level in enumerate(per_sample_level):
            noise[n, :, :, :] = torch.FloatTensor(sample_shape).normal_(mean=0, std=level / 255.)
        return noise
    raise ValueError("--mode must be 'S' or 'B', got %r" % (opt.mode,))


def _validate(model, dataset_val):
    """Return the mean PSNR of ``model`` over ``dataset_val`` at noise level ``opt.val_noiseL``.

    Every image receives noise drawn from a private generator reseeded to 0, so
    the validation noise is reproducible without touching the global RNG.
    Reseeding the global RNG here would make each later epoch repeat the same
    training noise and shuffle order.
    """
    model.eval()
    noise_rng = torch.Generator()
    psnr_total = 0
    with torch.no_grad():  # evaluation only: no autograd graph is needed
        for k in range(len(dataset_val)):
            img_val = torch.unsqueeze(dataset_val[k], 0)
            noise_rng.manual_seed(0)
            noisy = torch.FloatTensor(img_val.size()).normal_(
                mean=0, std=opt.val_noiseL / 255., generator=noise_rng)
            imgn_val = img_val + noisy
            img_val, imgn_val = img_val.npu(), imgn_val.npu()
            out_val = torch.clamp(model(imgn_val), 0., 1.)
            torch.npu.current_stream().synchronize()
            psnr_total += batch_PSNR(out_val, img_val, 1.)
            torch.npu.current_stream().synchronize()
    return psnr_total / len(dataset_val)


def main():
    """Train ADNet on an NPU, validating and checkpointing after every epoch.

    All settings come from the module-level ``opt`` namespace. Per epoch, rank 0
    writes ``model_<resume>_<epoch>.pth`` and refreshes ``best_model.pth`` when
    the validation PSNR improves on the best value seen in this run. After the
    last epoch rank 0 writes the per-epoch PSNR values to ``psnr.txt``.
    """
    t1 = time.time()
    local_device = _select_device()

    save_dir = opt.outf + 'sigma' + str(opt.noiseL) + '/'
    # makedirs(exist_ok=True) tolerates several ranks creating the directory at
    # once and also creates missing parent directories.
    os.makedirs(save_dir, exist_ok=True)

    if opt.local_rank == 0:
        print('Loading dataset ...\n')
    dataset_train = Dataset(train=True)
    dataset_val = Dataset(train=False)
    if opt.is_distributed == 0:
        loader_train = DataLoader(dataset=dataset_train, num_workers=16, batch_size=opt.batchSize,
                                  shuffle=True, drop_last=True)
    else:
        train_sampler = DistributedSampler(dataset_train)
        loader_train = DataLoader(dataset=dataset_train, sampler=train_sampler, num_workers=16,
                                  batch_size=opt.batchSize, pin_memory=False, drop_last=True)
    if opt.local_rank == 0:
        print("# of training samples: %d\n" % int(len(dataset_train)))

    net = ADNet(channels=1, num_of_layers=opt.num_of_layers)
    model = net.to(local_device)
    optimizer = apex.optimizers.NpuFusedSGD(model.parameters(), lr=opt.lr, momentum=0.9, nesterov=True)
    # amp must be initialised on the resume path too: amp.scale_loss() below
    # refuses to run otherwise. State is loaded after initialisation.
    model, optimizer = amp.initialize(model, optimizer, opt_level="O2", loss_scale=opt.loss_scale)
    start_epoch = 0
    if opt.resume:
        path_checkpoint = os.path.join(save_dir, 'best_model.pth')
        checkpoint = torch.load(path_checkpoint)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dic'])
        start_epoch = checkpoint['epoch'] + 1

    criterion = nn.MSELoss(reduction='sum')
    criterion.cpu()
    if opt.is_distributed:
        model = nn.parallel.DistributedDataParallel(model, device_ids=[opt.local_rank], broadcast_buffers=False)

    noiseL_B = [0, 55]  # range of noise levels sampled in blind ('B') mode
    psnr_list = []
    # Best validation PSNR of this run. It has to live outside the epoch loop:
    # resetting it per epoch makes every epoch overwrite best_model.pth.
    best_psnr = 0
    # Starting value so the rate is defined even when no schedule rule matches.
    current_lr = opt.lr
    num_batches = len(loader_train)

    for epoch in range(start_epoch, opt.epochs):
        if opt.is_distributed == 1:
            train_sampler.set_epoch(epoch)

        current_lr = _lr_for_epoch(epoch, current_lr)
        for param_group in optimizer.param_groups:
            param_group["lr"] = current_lr
        if opt.local_rank == 0:
            print('learning rate %f' % current_lr)

        fps_s = 0
        for i, data in enumerate(loader_train, 0):
            start_time = time.time()
            model.train()
            img_train = data
            imgn_train = img_train + _sample_train_noise(img_train, noiseL_B)
            img_train, imgn_train = img_train.npu(), imgn_train.npu()

            out_train = model(imgn_train)
            # NOTE(review): the loss is evaluated on the host and moved back to
            # the NPU; presumably a device workaround -- confirm before changing.
            loss = (criterion(out_train.cpu(), img_train.cpu()) / (imgn_train.size()[0] * 2)).npu()
            optimizer.zero_grad()
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            optimizer.step()
            steptime = time.time() - start_time

            # Training PSNR is reporting only, so skip autograd bookkeeping.
            model.eval()
            with torch.no_grad():
                out_train = torch.clamp(model(imgn_train), 0., 1.)
            psnr_train = batch_PSNR(out_train, img_train, 1.)

            if opt.local_rank == 0:
                # The first 10 steps are excluded from the throughput average.
                if i >= 10:
                    fps_s = opt.num_gpus * opt.batchSize / steptime + fps_s
                if (i + 1) == num_batches:
                    time_all = time.time() - start_time
                    # max(..., 1) avoids dividing by zero (or a negative count)
                    # when an epoch has 10 batches or fewer.
                    timed_steps = max(num_batches - 10, 1)
                    print("[epoch %d][%d/%d] fps: %.4f time: %.4f" %
                          (epoch + 1, i + 1, num_batches, fps_s / timed_steps, time_all))
                print("[epoch %d][%d/%d] loss: %.4f PSNR_train: %.4f scaled_loss: %.4f " %
                      (epoch + 1, i + 1, num_batches, loss.item(), psnr_train, scaled_loss.item()))

        psnr_val = _validate(model, dataset_val)
        psnr_list.append(str(psnr_val))
        if opt.local_rank == 0:
            print("\n[epoch %d] PSNR_val: %.4f" % (epoch + 1, psnr_val))

        model_name = 'model' + '_' + str(opt.resume) + '_' + str(epoch + 1) + '.pth'
        if opt.local_rank == 0:
            checkpoint = {"model_state_dict": net.state_dict(),
                          "optimizer_state_dic": optimizer.state_dict(),
                          "loss": loss,
                          "epoch": epoch}
            torch.save(checkpoint, os.path.join(save_dir, model_name))
            if best_psnr < psnr_val:
                best_psnr = psnr_val
                torch.save(checkpoint, os.path.join(save_dir, 'best_model.pth'))

    # Only rank 0 writes the summary, matching how checkpoints are saved.
    if opt.local_rank == 0:
        with open(save_dir + 'psnr.txt', 'w') as f:
            f.writelines(line + '\n' for line in psnr_list)

    t2 = time.time()
    t = t2 - t1
    if opt.local_rank == 0:
        print("total training used time:", t)
if __name__ == "__main__":
    if opt.preprocess:
        # Patch-extraction settings per training mode; any other mode skips
        # preprocessing entirely.
        prep_settings = {
            'S': {'stride': 40, 'aug_times': 1},
            'B': {'stride': 10, 'aug_times': 2},
        }
        chosen = prep_settings.get(opt.mode)
        if chosen is not None:
            prepare_data(data_path='data', patch_size=50, **chosen)
    main()