import argparse
import os
from math import log10
import time
import torch
if torch.__version__>= "1.8":
print("import torch_npu")
import torch_npu
import torch.optim as optim
import torch.utils.data
import torchvision.utils as utils
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from torch.utils.data import DataLoader
import pytorch_ssim
from data_utils import TrainDatasetFromFolder, ValDatasetFromFolder, display_transform
from loss import GeneratorLoss
from model import Generator, Discriminator
import config
from config import AverageMeter
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool('False')`` is ``True``
    because every non-empty string is truthy, so a flag could never be turned
    off from the command line.

    Args:
        value: The raw token from the command line (or an actual bool).

    Returns:
        ``True`` or ``False``.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised spelling
            of a boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    # The empty string is kept falsy, matching what bool('') used to return.
    if token in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


# Command-line interface of the training script.
parser = argparse.ArgumentParser(description='Train Super Resolution Models')

# Data and schedule.
parser.add_argument('--crop_size', default=88, type=int, help='training images crop size')
parser.add_argument('--upscale_factor', default=2, type=int, choices=[2, 4, 8],
                    help='super resolution upscale factor')
parser.add_argument('--num_epochs', default=100, type=int, help='train epoch number')
parser.add_argument('--batch_size', default=64, type=int, help='set training batch size')

# Mixed precision (apex amp).
parser.add_argument('--amp', default=False, type=_str2bool,
                    help='use amp to train the model')
parser.add_argument('--amp_level', default='O1',
                    help='set amp level.')

# Dataset locations.
parser.add_argument('--train_data_path', default='./data/VOC2012/train', type=str,
                    help='source data folder for training')
parser.add_argument('--val_data_path', default='./data/VOC2012/val', type=str,
                    help='source data folder for validation')

# Device selection.
parser.add_argument('--device', default='npu',
                    help='device id (i.e. npu:1 or 1,2 or cpu)')
parser.add_argument('--use_npu', default=False, type=_str2bool,
                    help='If use npu for training.')
parser.add_argument('--use_gpu', default=False, type=_str2bool,
                    help='If use gpu for training.')

# Checkpointing and profiling.
parser.add_argument('--only_keep_best', default=True, type=_str2bool,
                    help='Only keep best training result.')
parser.add_argument('--save_prof', default=False, type=_str2bool,
                    help='If save training prof.')

# Loss scales are deliberately left untyped: they are forwarded as-is to
# amp.initialize and may be either a number or the string 'dynamic'.
parser.add_argument('--loss_scale_g', default=128.0, help='netG amp loss_scale: dynamic, 128.0')
parser.add_argument('--loss_scale_d', default=128.0, help='netD amp loss_scale: dynamic, 128.0')

parser.add_argument('--nproc', default=8, type=int, help='workers for dataset_loader')
parser.add_argument('--save_val_img', default=False, type=_str2bool,
                    help='save val_images')
# When this flag is set the script skips the validation pass after each epoch.
parser.add_argument('--performance', default=False, type=_str2bool,
                    help='Performance mode: skip the validation process.')
# Destination for everything the run writes; the default is whatever root
# path the project config currently reports.
parser.add_argument(
    '--output_dir',
    type=str,
    default=config.get_root_path(),
    help='Path to save running results.',
)
def seed_everything(seed):
    """Seed the random number generators used by this script.

    Sets ``PYTHONHASHSEED``, seeds torch, seeds the NPU generators when the
    ``torch.npu`` namespace exists, and switches cuDNN to deterministic,
    non-benchmarking mode.

    Args:
        seed: Seed value applied to every generator.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)

    # torch.npu is not part of a stock PyTorch build (this file imports
    # torch_npu conditionally to provide it). Guard the calls so GPU/CPU runs
    # do not fail with AttributeError before training even starts.
    npu = getattr(torch, "npu", None)
    if npu is not None:
        npu.manual_seed(seed)
        npu.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def _train_step(data, target, device, netG, netD, optimizerG, optimizerD,
                generator_criterion, amp=None):
    """Run one discriminator update followed by one generator update.

    Args:
        data: Batch fed to the generator (first item yielded by the loader).
        target: Batch the discriminator scores as the real sample (second item).
        device: ``torch.device`` both batches are moved to.
        netG: Generator network.
        netD: Discriminator network.
        optimizerG: Optimizer stepping ``netG``.
        optimizerD: Optimizer stepping ``netD``.
        generator_criterion: Loss called as
            ``generator_criterion(fake_out, fake_img, real_img)``.
        amp: The ``apex.amp`` module when mixed precision is enabled, else
            ``None`` for plain backward passes.

    Returns:
        Tuple ``(d_loss, g_loss, real_out, fake_out)``. ``real_out`` and
        ``fake_out`` are the batch means of the discriminator outputs;
        ``fake_out`` is the value computed during the generator update.
    """
    real_img = target.to(device)
    z = data.to(device)

    # Discriminator update: minimise 1 - D(x) + D(G(z)).
    fake_img = netG(z)
    netD.zero_grad()
    real_out = netD(real_img).mean()
    fake_out = netD(fake_img).mean()
    d_loss = 1 - real_out + fake_out
    if amp is not None:
        with amp.scale_loss(d_loss, optimizerD) as scaled_d_loss:
            scaled_d_loss.backward(retain_graph=True)
    else:
        d_loss.backward(retain_graph=True)
    optimizerD.step()

    # Generator update: a fresh forward pass through the updated discriminator.
    netG.zero_grad()
    fake_img = netG(z)
    fake_out = netD(fake_img).mean()
    g_loss = generator_criterion(fake_out.float(), fake_img.float(), real_img.float())
    if amp is not None:
        with amp.scale_loss(g_loss, optimizerG) as scaled_g_loss:
            scaled_g_loss.backward()
    else:
        g_loss.backward()
    optimizerG.step()

    return d_loss, g_loss, real_out, fake_out


if __name__ == '__main__':
    opt = parser.parse_args()
    print(opt)
    seed_everything(5)
    avt = AverageMeter(opt.performance)

    # Resolve the training device. A requested accelerator that is missing is
    # reported explicitly instead of surfacing later as an unbound `device`.
    if opt.use_npu:
        import torch.npu
        if not torch.npu.is_available():
            raise RuntimeError('--use_npu was set but no NPU device is available.')
        device = torch.device(opt.device)
        prof_kwargs = {'use_npu': True}
    elif opt.use_gpu:
        if not torch.cuda.is_available():
            raise RuntimeError('--use_gpu was set but no CUDA device is available.')
        device = torch.device('cuda')
        prof_kwargs = {'use_cuda': True}
    else:
        device = torch.device('cpu')
        prof_kwargs = {}
    print(f'使用 {device} 进行训练.')

    CROP_SIZE = opt.crop_size
    UPSCALE_FACTOR = opt.upscale_factor
    NUM_EPOCHS = opt.num_epochs

    # Output locations; the root path is concatenated directly, as elsewhere
    # in this script.
    config.set_root_path(opt.output_dir)
    out_path = config.get_root_path() + 'epochs'
    os.makedirs(out_path, exist_ok=True)
    if opt.save_val_img:
        val_path = config.get_root_path() + 'val_results_img'
        os.makedirs(val_path, exist_ok=True)

    train_set = TrainDatasetFromFolder(opt.train_data_path, crop_size=CROP_SIZE, upscale_factor=UPSCALE_FACTOR)
    val_set = ValDatasetFromFolder(opt.val_data_path, upscale_factor=UPSCALE_FACTOR)
    train_loader = DataLoader(dataset=train_set, num_workers=opt.nproc,
                              batch_size=opt.batch_size, shuffle=True, drop_last=True)
    val_loader = DataLoader(dataset=val_set, num_workers=opt.nproc, batch_size=1, shuffle=False)

    netG = Generator(UPSCALE_FACTOR)
    print('# generator parameters:', sum(param.numel() for param in netG.parameters()))
    netD = Discriminator()
    print('# discriminator parameters:', sum(param.numel() for param in netD.parameters()))
    generator_criterion = GeneratorLoss()

    # `device` is always bound here (CPU included), so one code path suffices.
    netG = netG.to(device)
    netD = netD.to(device)
    generator_criterion = generator_criterion.to(device)

    optimizerG = optim.Adam(netG.parameters())
    optimizerD = optim.Adam(netD.parameters())

    amp = None
    if opt.amp:
        from apex import amp
        netG, optimizerG = amp.initialize(netG, optimizerG, opt_level=opt.amp_level, loss_scale=opt.loss_scale_g)
        netD, optimizerD = amp.initialize(netD, optimizerD, opt_level=opt.amp_level, loss_scale=opt.loss_scale_d)

    results = {'d_loss': [], 'g_loss': [], 'd_score': [], 'g_score': [], 'psnr': [], 'ssim': [], 'train_fps': []}
    best_results = 0
    # NOTE(review): this re-enables cuDNN benchmarking, overriding the
    # `benchmark = False` set by seed_everything() above. Kept as in the
    # original; confirm which of the two is intended.
    cudnn.benchmark = True

    for epoch in range(1, NUM_EPOCHS + 1):
        avt.t_start('epoch')
        running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0, 'd_score': 0, 'g_score': 0, 'train_fps': 0}

        netG.train()
        netD.train()
        step = 0
        fps_number = 0
        fps_count_start = False
        avt.t_start('training')
        for data, target in train_loader:
            step += 1
            fps_start_time = time.time()
            # The first 10 steps of every epoch are excluded from the FPS average.
            if step == 11:
                fps_count_start = True
            batch_size = data.size(0)
            running_results['batch_sizes'] += batch_size

            if opt.save_prof and step == 10 and epoch == 1:
                # Profile exactly one step (step 10 of the first epoch).
                with torch.autograd.profiler.profile(**prof_kwargs) as prof:
                    d_loss, g_loss, real_out, fake_out = _train_step(
                        data, target, device, netG, netD, optimizerG, optimizerD,
                        generator_criterion, amp)
                prof.export_chrome_trace(f'{config.get_root_path()}srgan_prof_{device}_{step}.prof')
            else:
                d_loss, g_loss, real_out, fake_out = _train_step(
                    data, target, device, netG, netD, optimizerG, optimizerD,
                    generator_criterion, amp)

            if fps_count_start:
                fps = round(batch_size / (time.time() - fps_start_time), 2)
                fps_number += 1
                running_results['train_fps'] += fps
            else:
                fps = 0

            # Running sums are weighted by batch size so the printed values are
            # per-sample averages over the epoch so far.
            running_results['g_loss'] += g_loss.item() * batch_size
            running_results['d_loss'] += d_loss.item() * batch_size
            running_results['d_score'] += real_out.item() * batch_size
            running_results['g_score'] += fake_out.item() * batch_size
            Loss_D = running_results['d_loss'] / running_results['batch_sizes']
            Loss_G = running_results['g_loss'] / running_results['batch_sizes']
            score_D = running_results['d_score'] / running_results['batch_sizes']
            score_G = running_results['g_score'] / running_results['batch_sizes']
            print(f'[{epoch}/{NUM_EPOCHS}] step:{step} Loss_D: {Loss_D:.4f} Loss_G: {Loss_G:.4f} '
                  f'D(x): {score_D:.4f} D(G(z)): {score_G:.4f} Fps: {fps:.4f}')
            avt.step_update()
        avt.print_time('training')

        # An epoch with fewer than 11 steps never starts FPS counting; report 0
        # instead of dividing by zero.
        running_results['train_fps'] = running_results['train_fps'] / fps_number if fps_number else 0
        # Profiling is only ever attempted during the first epoch.
        opt.save_prof = False

        if not opt.performance:
            avt.t_start('val')
            netG.eval()

            with torch.no_grad():
                valing_results = {'mse': 0, 'ssims': 0, 'psnr': 0, 'ssim': 0, 'batch_sizes': 0}
                val_images = []
                print_step = 0
                for val_lr, val_hr_restore, val_hr in val_loader:
                    batch_size = val_lr.size(0)
                    valing_results['batch_sizes'] += batch_size
                    lr = val_lr.to(device)
                    hr = val_hr.to(device)
                    sr = netG(lr)

                    batch_mse = ((sr - hr) ** 2).data.mean()
                    valing_results['mse'] += batch_mse * batch_size
                    batch_ssim = pytorch_ssim.ssim(sr, hr).item()
                    valing_results['ssims'] += batch_ssim * batch_size
                    # PSNR/SSIM are recomputed from the running sums, so after
                    # the loop they hold the averages over the whole set.
                    valing_results['psnr'] = 10 * log10(
                        (hr.max() ** 2) / (valing_results['mse'] / valing_results['batch_sizes']))
                    valing_results['ssim'] = valing_results['ssims'] / valing_results['batch_sizes']
                    if print_step % 10 == 0:
                        psnr = valing_results['psnr']
                        ssim = valing_results['ssim']
                        print(f'[converting LR images to SR images] PSNR: {psnr:4f} dB SSIM: {ssim:4f}')
                    print_step += 1

                    if opt.save_val_img:
                        val_images.append(display_transform()(val_hr_restore.squeeze(0)))
                        val_images.append(display_transform()(hr.data.cpu().squeeze(0)))
                        val_images.append(display_transform()(sr.data.cpu().squeeze(0)))

                if opt.save_val_img and val_images:
                    val_images = torch.stack(val_images)
                    # torch.chunk rejects 0 chunks, which is what `// 15` gives
                    # when fewer than 15 images were collected.
                    val_images = torch.chunk(val_images, max(1, val_images.size(0) // 15))
                    print('[saving training results]:\n')
                    for index, image in enumerate(val_images, start=1):
                        image = utils.make_grid(image, nrow=3, padding=5)
                        # Join with a separator so the files land inside the
                        # directory created above rather than next to it.
                        image_file = os.path.join(val_path, 'epoch_%d_index_%d.png' % (epoch, index))
                        utils.save_image(image, image_file, padding=5)
            avt.print_time('val')
        else:
            print('Skip the verification process!')

        # Per-epoch history: per-sample training averages plus validation metrics.
        results['d_loss'].append(running_results['d_loss'] / running_results['batch_sizes'])
        results['g_loss'].append(running_results['g_loss'] / running_results['batch_sizes'])
        results['d_score'].append(running_results['d_score'] / running_results['batch_sizes'])
        results['g_score'].append(running_results['g_score'] / running_results['batch_sizes'])
        results['train_fps'].append(running_results['train_fps'])
        if not opt.performance:
            results['psnr'].append(valing_results['psnr'])
            results['ssim'].append(valing_results['ssim'])
        else:
            results['psnr'].append(0)
            results['ssim'].append(0)

        if opt.only_keep_best and not opt.performance:
            # Single score combining both metrics; checkpoints are overwritten
            # only when it improves.
            current_results = valing_results['psnr'] / 10 + valing_results['ssim']
            if current_results > best_results:
                print(f'current best validation results -> psnr: {valing_results["psnr"]}, '
                      f'ssim: {valing_results["ssim"]}')
                torch.save(netG.state_dict(), out_path + '/netG_best.pth')
                torch.save(netD.state_dict(), out_path + '/netD_best.pth')
                best_results = current_results
        else:
            # Without validation-based selection, keep epoch 1 and every 5th epoch.
            if epoch == 1 or epoch % 5 == 0:
                torch.save(netG.state_dict(), out_path + '/netG_epoch_%d.pth' % (epoch))
                torch.save(netD.state_dict(), out_path + '/netD_epoch_%d.pth' % (epoch))

        # The log file is rewritten in full after every epoch.
        with open(config.get_root_path() + 'epoch_log_1p.txt', 'w', encoding='utf-8') as f:
            title = 'epoch \t d_loss \t g_loss \t d_score \t g_score \t psnr \t ssim \t train_fps \n'
            f.write(title)
            print('saving to file')
            for i in range(epoch):
                line = f"{i + 1} \t {results['d_loss'][i]:.4f} \t {results['g_loss'][i]:.4f} \t " \
                       f"{results['d_score'][i]:.4f} \t {results['g_score'][i]:.4f} \t " \
                       f"{results['psnr'][i]:.4f} \t " \
                       f"{results['ssim'][i]:.4f} \t {results['train_fps'][i]:.4f} \n"
                f.write(line)
            print('write results successfully!')
        avt.print_time('epoch')
    avt.print_time('end')