@@ -2,7 +2,6 @@ import torch
import torch.nn as nn
from torch.nn import init
import functools
-from torch.optim import lr_scheduler
###############################################################################
@@ -35,35 +34,6 @@ def get_norm_layer(norm_type='instance'):
return norm_layer
-def get_scheduler(optimizer, opt):
- """Return a learning rate scheduler
-
- Parameters:
- optimizer -- the optimizer of the network
- opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
- opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
-
- For 'linear', we keep the same learning rate for the first <opt.n_epochs> epochs
- and linearly decay the rate to zero over the next <opt.n_epochs_decay> epochs.
- For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
- See https://pytorch.org/docs/stable/optim.html for more details.
- """
- if opt.lr_policy == 'linear':
- def lambda_rule(epoch):
- lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1)
- return lr_l
- scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
- elif opt.lr_policy == 'step':
- scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
- elif opt.lr_policy == 'plateau':
- scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
- elif opt.lr_policy == 'cosine':
- scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
- else:
- return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy)
- return scheduler
-
-
def init_weights(net, init_type='normal', init_gain=0.02):
"""Initialize network weights.
@@ -108,10 +78,7 @@ def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
Return an initialized network.
"""
- if len(gpu_ids) > 0:
- assert(torch.cuda.is_available())
- net.to(gpu_ids[0])
- net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs
+ # NOTE: device placement (net.to / DataParallel wrapping) was removed from this body, so gpu_ids is not used here.
init_weights(net, init_type, init_gain=init_gain)
return net
@@ -203,122 +170,14 @@ def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal'
return init_net(net, init_type, init_gain, gpu_ids)
-##############################################################################
-# Classes
-##############################################################################
-class GANLoss(nn.Module):
- """Define different GAN objectives.
-
- The GANLoss class abstracts away the need to create the target label tensor
- that has the same size as the input.
- """
-
- def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
- """ Initialize the GANLoss class.
-
- Parameters:
- gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
- target_real_label (bool) - - label for a real image
- target_fake_label (bool) - - label of a fake image
-
- Note: Do not use sigmoid as the last layer of Discriminator.
- LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
- """
- super(GANLoss, self).__init__()
- self.register_buffer('real_label', torch.tensor(target_real_label))
- self.register_buffer('fake_label', torch.tensor(target_fake_label))
- self.gan_mode = gan_mode
- if gan_mode == 'lsgan':
- self.loss = nn.MSELoss()
- elif gan_mode == 'vanilla':
- self.loss = nn.BCEWithLogitsLoss()
- elif gan_mode in ['wgangp']:
- self.loss = None
- else:
- raise NotImplementedError('gan mode %s not implemented' % gan_mode)
-
- def get_target_tensor(self, prediction, target_is_real):
- """Create label tensors with the same size as the input.
-
- Parameters:
- prediction (tensor) - - tpyically the prediction from a discriminator
- target_is_real (bool) - - if the ground truth label is for real images or fake images
-
- Returns:
- A label tensor filled with ground truth label, and with the size of the input
- """
-
- if target_is_real:
- target_tensor = self.real_label
- else:
- target_tensor = self.fake_label
- return target_tensor.expand_as(prediction)
-
- def __call__(self, prediction, target_is_real):
- """Calculate loss given Discriminator's output and grount truth labels.
-
- Parameters:
- prediction (tensor) - - tpyically the prediction output from a discriminator
- target_is_real (bool) - - if the ground truth label is for real images or fake images
-
- Returns:
- the calculated loss.
- """
- if self.gan_mode in ['lsgan', 'vanilla']:
- target_tensor = self.get_target_tensor(prediction, target_is_real)
- loss = self.loss(prediction, target_tensor)
- elif self.gan_mode == 'wgangp':
- if target_is_real:
- loss = -prediction.mean()
- else:
- loss = prediction.mean()
- return loss
-
-
-def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
- """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028
-
- Arguments:
- netD (network) -- discriminator network
- real_data (tensor array) -- real images
- fake_data (tensor array) -- generated images from the generator
- device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
- type (str) -- if we mix real and fake data or not [real | fake | mixed].
- constant (float) -- the constant used in formula ( ||gradient||_2 - constant)^2
- lambda_gp (float) -- weight for this loss
-
- Returns the gradient penalty loss
- """
- if lambda_gp > 0.0:
- if type == 'real': # either use real images, fake images, or a linear interpolation of two.
- interpolatesv = real_data
- elif type == 'fake':
- interpolatesv = fake_data
- elif type == 'mixed':
- alpha = torch.rand(real_data.shape[0], 1, device=device)
- alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape)
- interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
- else:
- raise NotImplementedError('{} not implemented'.format(type))
- interpolatesv.requires_grad_(True)
- disc_interpolates = netD(interpolatesv)
- gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv,
- grad_outputs=torch.ones(disc_interpolates.size()).to(device),
- create_graph=True, retain_graph=True, only_inputs=True)
- gradients = gradients[0].view(real_data.size(0), -1) # flat the data
- gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp # added eps
- return gradient_penalty, gradients
- else:
- return 0.0, None
-
-
class ResnetGenerator(nn.Module):
"""Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
"""
- def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
+ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6,
+ padding_type='zero'):
"""Construct a Resnet-based generator
Parameters:
@@ -336,8 +195,8 @@ class ResnetGenerator(nn.Module):
use_bias = norm_layer.func == nn.InstanceNorm2d
else:
use_bias = norm_layer == nn.InstanceNorm2d
-
- model = [nn.ReflectionPad2d(3),
+ # ReflectionPad2d(3) is replaced with nn.ZeroPad2d(3) here
+ model = [nn.ZeroPad2d(3),
nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
norm_layer(ngf),
nn.ReLU(True)]
@@ -362,7 +221,7 @@ class ResnetGenerator(nn.Module):
bias=use_bias),
norm_layer(int(ngf * mult / 2)),
nn.ReLU(True)]
- model += [nn.ReflectionPad2d(3)]
+ model += [nn.ZeroPad2d(3)]
model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
model += [nn.Tanh()]
@@ -402,7 +261,9 @@ class ResnetBlock(nn.Module):
conv_block = []
p = 0
if padding_type == 'reflect':
- conv_block += [nn.ReflectionPad2d(1)]
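+ # NOTE: the 'reflect' branch now applies nn.ZeroPad2d(1) in place of nn.ReflectionPad2d(1),
+ # so padding_type='reflect' no longer produces reflection padding.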
+ conv_block += [nn.ZeroPad2d(1)]
elif padding_type == 'replicate':
conv_block += [nn.ReplicationPad2d(1)]
elif padding_type == 'zero':
@@ -416,7 +277,7 @@ class ResnetBlock(nn.Module):
p = 0
if padding_type == 'reflect':
- conv_block += [nn.ReflectionPad2d(1)]
+ conv_block += [nn.ZeroPad2d(1)]
elif padding_type == 'replicate':
conv_block += [nn.ReplicationPad2d(1)]
elif padding_type == 'zero':