@@ -1,5 +1,4 @@
import numpy as np
-import math
import functools
import torch
@@ -7,16 +6,14 @@ import torch.nn as nn
from torch.nn import init
import torch.optim as optim
import torch.nn.functional as F
-from torch.nn import Parameter as P
import layers
-from sync_batchnorm import SynchronizedBatchNorm2d as SyncBatchNorm2d
# Architectures for G
# Attention is passed in in the format '32_64' to mean applying an attention
# block at both resolution 32x32 and 64x64. Just '64' will apply at 64x64.
-def G_arch(ch=64, attention='64', ksize='333333', dilation='111111'):
+def G_arch(ch=64, attention='64'):
arch = {}
arch[512] = {'in_channels' : [ch * item for item in [16, 16, 8, 8, 4, 2, 1]],
'out_channels' : [ch * item for item in [16, 8, 8, 4, 2, 1, 1]],
@@ -128,7 +125,7 @@ class Generator(nn.Module):
else:
self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
self.which_linear = nn.Linear
-
+
# We use a non-spectral-normed embedding here regardless;
# For some reason applying SN to G's embedding seems to randomly cripple G
self.which_embedding = nn.Embedding
@@ -146,7 +143,7 @@ class Generator(nn.Module):
# Prepare model
# If not using shared embeddings, self.shared is just a passthrough
- self.shared = (self.which_embedding(n_classes, self.shared_dim) if G_shared
+ self.shared = (self.which_embedding(n_classes, self.shared_dim) if G_shared
else layers.identity())
# First linear layer
self.linear = self.which_linear(self.dim_z // self.num_slots,
@@ -210,7 +207,7 @@ class Generator(nn.Module):
self.param_count = 0
for module in self.modules():
if (isinstance(module, nn.Conv2d)
- or isinstance(module, nn.Linear)
+ or isinstance(module, nn.Linear)
or isinstance(module, nn.Embedding)):
if self.init == 'ortho':
init.orthogonal_(module.weight)
@@ -227,224 +224,24 @@ class Generator(nn.Module):
# already been passed through G.shared to enable easy class-wise
# interpolation later. If we passed in the one-hot and then ran it through
# G.shared in this forward function, it would be harder to handle.
- def forward(self, z, y):
+ def forward(self, z, ys):
# If hierarchical, concatenate zs and ys
- if self.hier:
- zs = torch.split(z, self.z_chunk_size, 1)
- z = zs[0]
- ys = [torch.cat([y, item], 1) for item in zs[1:]]
- else:
- ys = [y] * len(self.blocks)
-
+ if not self.hier:
+ print("仅支持BigGAN分层模型的转换,请设置模型入参 hier=True")
+
+ z = z.view(-1, 20)
+ ys = ys.view(-1, 5, 148)
# First linear layer
h = self.linear(z)
# Reshape
h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width)
-
+
# Loop over blocks
for index, blocklist in enumerate(self.blocks):
# Second inner loop in case block has multiple layers
for block in blocklist:
- h = block(h, ys[index])
-
+ y = ys[:,index,:]
+ h = block(h, y)
+
# Apply batchnorm-relu-conv-tanh at output
return torch.tanh(self.output_layer(h))
-
-
-# Discriminator architecture, same paradigm as G's above
-def D_arch(ch=64, attention='64',ksize='333333', dilation='111111'):
- arch = {}
- arch[256] = {'in_channels' : [3] + [ch*item for item in [1, 2, 4, 8, 8, 16]],
- 'out_channels' : [item * ch for item in [1, 2, 4, 8, 8, 16, 16]],
- 'downsample' : [True] * 6 + [False],
- 'resolution' : [128, 64, 32, 16, 8, 4, 4 ],
- 'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')]
- for i in range(2,8)}}
- arch[128] = {'in_channels' : [3] + [ch*item for item in [1, 2, 4, 8, 16]],
- 'out_channels' : [item * ch for item in [1, 2, 4, 8, 16, 16]],
- 'downsample' : [True] * 5 + [False],
- 'resolution' : [64, 32, 16, 8, 4, 4],
- 'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')]
- for i in range(2,8)}}
- arch[64] = {'in_channels' : [3] + [ch*item for item in [1, 2, 4, 8]],
- 'out_channels' : [item * ch for item in [1, 2, 4, 8, 16]],
- 'downsample' : [True] * 4 + [False],
- 'resolution' : [32, 16, 8, 4, 4],
- 'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')]
- for i in range(2,7)}}
- arch[32] = {'in_channels' : [3] + [item * ch for item in [4, 4, 4]],
- 'out_channels' : [item * ch for item in [4, 4, 4, 4]],
- 'downsample' : [True, True, False, False],
- 'resolution' : [16, 16, 16, 16],
- 'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')]
- for i in range(2,6)}}
- return arch
-
-class Discriminator(nn.Module):
-
- def __init__(self, D_ch=64, D_wide=True, resolution=128,
- D_kernel_size=3, D_attn='64', n_classes=1000,
- num_D_SVs=1, num_D_SV_itrs=1, D_activation=nn.ReLU(inplace=False),
- D_lr=2e-4, D_B1=0.0, D_B2=0.999, adam_eps=1e-8,
- SN_eps=1e-12, output_dim=1, D_mixed_precision=False, D_fp16=False,
- D_init='ortho', skip_init=False, D_param='SN', **kwargs):
- super(Discriminator, self).__init__()
- # Width multiplier
- self.ch = D_ch
- # Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN?
- self.D_wide = D_wide
- # Resolution
- self.resolution = resolution
- # Kernel size
- self.kernel_size = D_kernel_size
- # Attention?
- self.attention = D_attn
- # Number of classes
- self.n_classes = n_classes
- # Activation
- self.activation = D_activation
- # Initialization style
- self.init = D_init
- # Parameterization style
- self.D_param = D_param
- # Epsilon for Spectral Norm?
- self.SN_eps = SN_eps
- # Fp16?
- self.fp16 = D_fp16
- # Architecture
- self.arch = D_arch(self.ch, self.attention)[resolution]
-
- # Which convs, batchnorms, and linear layers to use
- # No option to turn off SN in D right now
- if self.D_param == 'SN':
- self.which_conv = functools.partial(layers.SNConv2d,
- kernel_size=3, padding=1,
- num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
- eps=self.SN_eps)
- self.which_linear = functools.partial(layers.SNLinear,
- num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
- eps=self.SN_eps)
- self.which_embedding = functools.partial(layers.SNEmbedding,
- num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
- eps=self.SN_eps)
- # Prepare model
- # self.blocks is a doubly-nested list of modules, the outer loop intended
- # to be over blocks at a given resolution (resblocks and/or self-attention)
- self.blocks = []
- for index in range(len(self.arch['out_channels'])):
- self.blocks += [[layers.DBlock(in_channels=self.arch['in_channels'][index],
- out_channels=self.arch['out_channels'][index],
- which_conv=self.which_conv,
- wide=self.D_wide,
- activation=self.activation,
- preactivation=(index > 0),
- downsample=(nn.AvgPool2d(2) if self.arch['downsample'][index] else None))]]
- # If attention on this block, attach it to the end
- if self.arch['attention'][self.arch['resolution'][index]]:
- print('Adding attention layer in D at resolution %d' % self.arch['resolution'][index])
- self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index],
- self.which_conv)]
- # Turn self.blocks into a ModuleList so that it's all properly registered.
- self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
- # Linear output layer. The output dimension is typically 1, but may be
- # larger if we're e.g. turning this into a VAE with an inference output
- self.linear = self.which_linear(self.arch['out_channels'][-1], output_dim)
- # Embedding for projection discrimination
- self.embed = self.which_embedding(self.n_classes, self.arch['out_channels'][-1])
-
- # Initialize weights
- if not skip_init:
- self.init_weights()
-
- # Set up optimizer
- self.lr, self.B1, self.B2, self.adam_eps = D_lr, D_B1, D_B2, adam_eps
- if D_mixed_precision:
- print('Using fp16 adam in D...')
- import utils
- self.optim = utils.Adam16(params=self.parameters(), lr=self.lr,
- betas=(self.B1, self.B2), weight_decay=0, eps=self.adam_eps)
- else:
- self.optim = optim.Adam(params=self.parameters(), lr=self.lr,
- betas=(self.B1, self.B2), weight_decay=0, eps=self.adam_eps)
- # LR scheduling, left here for forward compatibility
- # self.lr_sched = {'itr' : 0}# if self.progressive else {}
- # self.j = 0
-
- # Initialize
- def init_weights(self):
- self.param_count = 0
- for module in self.modules():
- if (isinstance(module, nn.Conv2d)
- or isinstance(module, nn.Linear)
- or isinstance(module, nn.Embedding)):
- if self.init == 'ortho':
- init.orthogonal_(module.weight)
- elif self.init == 'N02':
- init.normal_(module.weight, 0, 0.02)
- elif self.init in ['glorot', 'xavier']:
- init.xavier_uniform_(module.weight)
- else:
- print('Init style not recognized...')
- self.param_count += sum([p.data.nelement() for p in module.parameters()])
- print('Param count for D''s initialized parameters: %d' % self.param_count)
-
- def forward(self, x, y=None):
- # Stick x into h for cleaner for loops without flow control
- h = x
- # Loop over blocks
- for index, blocklist in enumerate(self.blocks):
- for block in blocklist:
- h = block(h)
- # Apply global sum pooling as in SN-GAN
- h = torch.sum(self.activation(h), [2, 3])
- # Get initial class-unconditional output
- out = self.linear(h)
- # Get projection of final featureset onto class vectors and add to evidence
- out = out + torch.sum(self.embed(y) * h, 1, keepdim=True)
- return out
-
-# Parallelized G_D to minimize cross-gpu communication
-# Without this, Generator outputs would get all-gathered and then rebroadcast.
-class G_D(nn.Module):
- def __init__(self, G, D):
- super(G_D, self).__init__()
- self.G = G
- self.D = D
-
- def forward(self, z, gy, x=None, dy=None, train_G=False, return_G_z=False,
- split_D=False):
- # If training G, enable grad tape
- with torch.set_grad_enabled(train_G):
- # Get Generator output given noise
- G_z = self.G(z, self.G.shared(gy))
- # Cast as necessary
- if self.G.fp16 and not self.D.fp16:
- G_z = G_z.float()
- if self.D.fp16 and not self.G.fp16:
- G_z = G_z.half()
- # Split_D means to run D once with real data and once with fake,
- # rather than concatenating along the batch dimension.
- if split_D:
- D_fake = self.D(G_z, gy)
- if x is not None:
- D_real = self.D(x, dy)
- return D_fake, D_real
- else:
- if return_G_z:
- return D_fake, G_z
- else:
- return D_fake
- # If real data is provided, concatenate it with the Generator's output
- # along the batch dimension for improved efficiency.
- else:
- D_input = torch.cat([G_z, x], 0) if x is not None else G_z
- D_class = torch.cat([gy, dy], 0) if dy is not None else gy
- # Get Discriminator output
- D_out = self.D(D_input, D_class)
- if x is not None:
- return torch.split(D_out, [G_z.shape[0], x.shape[0]]) # D_fake, D_real
- else:
- if return_G_z:
- return D_out, G_z
- else:
- return D_out
@@ -1,26 +1,28 @@
''' Inception utilities
This file contains methods for calculating IS and FID, using either
- the original numpy code or an accelerated fully-pytorch version that
+ the original numpy code or an accelerated fully-pytorch version that
uses a fast newton-schulz approximation for the matrix sqrt. There are also
methods for acquiring a desired number of samples from the Generator,
and parallelizing the inbuilt PyTorch inception network.
-
- NOTE that Inception Scores and FIDs calculated using these methods will
+
+ NOTE that Inception Scores and FIDs calculated using these methods will
*not* be directly comparable to values calculated using the original TF
IS/FID code. You *must* use the TF model if you wish to report and compare
numbers. This code tends to produce IS values that are 5-10% lower than
- those obtained through TF.
-'''
+ those obtained through TF.
+'''
import numpy as np
from scipy import linalg # For numpy FID
-import time
+import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter as P
+from biggan_preprocess import proc_nodes_module
from torchvision.models.inception import inception_v3
+warnings.filterwarnings("ignore")
# Module that wraps the inception network to enable use with dataparallel and
# returning pool features and logits.
@@ -119,26 +121,6 @@ def torch_cov(m, rowvar=False):
return fact * m.matmul(mt).squeeze()
-# Pytorch implementation of matrix sqrt, from Tsung-Yu Lin, and Subhransu Maji
-# https://github.com/msubhransu/matrix-sqrt
-def sqrt_newton_schulz(A, numIters, dtype=None):
- with torch.no_grad():
- if dtype is None:
- dtype = A.type()
- batchSize = A.shape[0]
- dim = A.shape[1]
- normA = A.mul(A).sum(dim=1).sum(dim=1).sqrt()
- Y = A.div(normA.view(batchSize, 1, 1).expand_as(A));
- I = torch.eye(dim,dim).view(1, dim, dim).repeat(batchSize,1,1).type(dtype)
- Z = torch.eye(dim,dim).view(1, dim, dim).repeat(batchSize,1,1).type(dtype)
- for i in range(numIters):
- T = 0.5*(3.0*I - Z.bmm(Y))
- Y = Y.bmm(T)
- Z = T.bmm(Z)
- sA = Y*torch.sqrt(normA).view(batchSize, 1, 1).expand_as(A)
- return sA
-
-
# FID calculator from TTUR--consider replacing this with GPU-accelerated cov
# calculations using torch?
def numpy_calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
@@ -152,10 +134,10 @@ def numpy_calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
-- mu1 : Numpy array containing the activations of a layer of the
inception net (like returned by the function 'get_predictions')
for generated samples.
- -- mu2 : The sample mean over activations, precalculated on an
+ -- mu2 : The sample mean over activations, precalculated on an
representive data set.
-- sigma1: The covariance matrix over activations for generated samples.
- -- sigma2: The covariance matrix over activations, precalculated on an
+ -- sigma2: The covariance matrix over activations, precalculated on an
representive data set.
Returns:
-- : The Frechet Distance.
@@ -189,48 +171,14 @@ def numpy_calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
m = np.max(np.abs(covmean.imag))
raise ValueError('Imaginary component {}'.format(m))
- covmean = covmean.real
+ covmean = covmean.real
- tr_covmean = np.trace(covmean)
+ tr_covmean = np.trace(covmean)
out = diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
return out
-def torch_calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
- """Pytorch implementation of the Frechet Distance.
- Taken from https://github.com/bioinf-jku/TTUR
- The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
- and X_2 ~ N(mu_2, C_2) is
- d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
- Stable version by Dougal J. Sutherland.
- Params:
- -- mu1 : Numpy array containing the activations of a layer of the
- inception net (like returned by the function 'get_predictions')
- for generated samples.
- -- mu2 : The sample mean over activations, precalculated on an
- representive data set.
- -- sigma1: The covariance matrix over activations for generated samples.
- -- sigma2: The covariance matrix over activations, precalculated on an
- representive data set.
- Returns:
- -- : The Frechet Distance.
- """
-
-
- assert mu1.shape == mu2.shape, \
- 'Training and test mean vectors have different lengths'
- assert sigma1.shape == sigma2.shape, \
- 'Training and test covariances have different dimensions'
-
- diff = mu1 - mu2
- # Run 50 itrs of newton-schulz to get the matrix sqrt of sigma1 dot sigma2
- covmean = sqrt_newton_schulz(sigma1.mm(sigma2).unsqueeze(0), 50).squeeze()
- out = (diff.dot(diff) + torch.trace(sigma1) + torch.trace(sigma2)
- - 2 * torch.trace(covmean))
- return out
-
-
# Calculate Inception Score mean + std given softmax'd logits and number of splits
def calculate_inception_score(pred, num_splits=10):
scores = []
@@ -243,24 +191,45 @@ def calculate_inception_score(pred, num_splits=10):
# Loop and run the sampler and the net until it accumulates num_inception_images
-# activations. Return the pool, the logits, and the labels (if one wants
+# activations. Return the pool, the logits, and the labels (if one wants
# Inception Accuracy the labels of the generated class will be needed)
-def accumulate_inception_activations(sample, net, num_inception_images=50000):
+def accumulate_inception_activations(img, label, net, num_inception_images=50000):
pool, logits, labels = [], [], []
+ count = 1
+ img = torch.from_numpy(img)
+ if num_inception_images < 1000:
+ batch = 10
+ else:
+ batch = 100
while (torch.cat(logits, 0).shape[0] if len(logits) else 0) < num_inception_images:
with torch.no_grad():
- images, labels_val = sample()
+ if count * batch > num_inception_images:
+ images = img[(count-1)*batch : num_inception_images]
+ labels_val = np.array([label[(count-1)*batch : num_inception_images]])
+ else:
+ images = img[(count-1)*batch : count*batch]
+ labels_val = np.array([label[(count-1)*batch : count*batch]])
+ labels_val = torch.from_numpy(labels_val[: np.newaxis])
pool_val, logits_val = net(images.float())
pool += [pool_val]
logits += [F.softmax(logits_val, 1)]
labels += [labels_val]
- return torch.cat(pool, 0), torch.cat(logits, 0), torch.cat(labels, 0)
+ count += 1
+ if (count-1) % 10 == 0:
+ print("Has counted {} samples".format((count-1)*batch))
+ pool_concatenate = torch.cat(pool, 0)
+ logits_concatenate = torch.cat(logits, 0)
+ labels_concatenate = torch.cat(labels, 0)
+ return pool_concatenate, logits_concatenate, labels_concatenate
# Load and wrap the Inception model
def load_inception_net(parallel=False):
- inception_model = inception_v3(pretrained=True, transform_input=False)
- inception_model = WrapInception(inception_model.eval()).cuda()
+ inception_model = inception_v3(pretrained=False, transform_input=False)
+ model_checkpoint = torch.load('./inception_v3_google.pth', map_location=torch.device('cpu'))
+ model_checkpoint = proc_nodes_module(model_checkpoint)
+ inception_model.load_state_dict(model_checkpoint)
+ inception_model = WrapInception(inception_model.eval()).to()
if parallel:
print('Parallelizing Inception module...')
inception_model = nn.DataParallel(inception_model)
@@ -271,21 +240,21 @@ def load_inception_net(parallel=False):
# and iterates until it accumulates config['num_inception_images'] images.
# The iterator can return samples with a different batch size than used in
# training, using the setting confg['inception_batchsize']
-def prepare_inception_metrics(dataset, parallel, no_fid=False):
+def prepare_inception_metrics(dataset, parallel=False, no_fid=False):
# Load metrics; this is intentionally not in a try-except loop so that
# the script will crash here if it cannot find the Inception moments.
# By default, remove the "hdf5" from dataset
- dataset = dataset.strip('_hdf5')
data_mu = np.load(dataset+'_inception_moments.npz')['mu']
data_sigma = np.load(dataset+'_inception_moments.npz')['sigma']
# Load network
net = load_inception_net(parallel)
- def get_inception_metrics(sample, num_inception_images, num_splits=10,
- prints=True, use_torch=True):
+ #################################################################
+ def get_inception_metrics(img, label, num_inception_images,
+ num_splits=10, prints=True):
if prints:
print('Gathering activations...')
- pool, logits, labels = accumulate_inception_activations(sample, net, num_inception_images)
- if prints:
+ pool, logits, labels = accumulate_inception_activations(img, label, net, num_inception_images)
+ if prints:
print('Calculating Inception Score...')
IS_mean, IS_std = calculate_inception_score(logits.cpu().numpy(), num_splits)
if no_fid:
@@ -293,18 +262,12 @@ def prepare_inception_metrics(dataset, parallel, no_fid=False):
else:
if prints:
print('Calculating means and covariances...')
- if use_torch:
- mu, sigma = torch.mean(pool, 0), torch_cov(pool, rowvar=False)
- else:
- mu, sigma = np.mean(pool.cpu().numpy(), axis=0), np.cov(pool.cpu().numpy(), rowvar=False)
+ mu, sigma = np.mean(pool.cpu().numpy(), axis=0), np.cov(pool.cpu().numpy(), rowvar=False)
if prints:
print('Covariances calculated, getting FID...')
- if use_torch:
- FID = torch_calculate_frechet_distance(mu, sigma, torch.tensor(data_mu).float().cuda(), torch.tensor(data_sigma).float().cuda())
- FID = float(FID.cpu().numpy())
- else:
- FID = numpy_calculate_frechet_distance(mu.cpu().numpy(), sigma.cpu().numpy(), data_mu, data_sigma)
+ FID = numpy_calculate_frechet_distance(mu, sigma, data_mu, data_sigma)
# Delete mu, sigma, pool, logits, and labels, just in case
del mu, sigma, pool, logits, labels
return IS_mean, IS_std, FID
- return get_inception_metrics
\ No newline at end of file
+ #################################################################
+ return get_inception_metrics
@@ -1,16 +1,12 @@
''' Layers
This file contains various layers for the BigGAN models.
'''
-import numpy as np
+
import torch
import torch.nn as nn
-from torch.nn import init
-import torch.optim as optim
import torch.nn.functional as F
from torch.nn import Parameter as P
-from sync_batchnorm import SynchronizedBatchNorm2d as SyncBN2d
-
# Projection of x onto y
def proj(x, y):
@@ -54,9 +50,9 @@ def power_iteration(W, u_, update=True, eps=1e-12):
class identity(nn.Module):
def forward(self, input):
return input
-
-# Spectral normalization base class
+
+# Spectral normalization base class
class SN(object):
def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12):
# Number of power iterations per step
@@ -71,18 +67,18 @@ class SN(object):
for i in range(self.num_svs):
self.register_buffer('u%d' % i, torch.randn(1, num_outputs))
self.register_buffer('sv%d' % i, torch.ones(1))
-
+
# Singular vectors (u side)
@property
def u(self):
return [getattr(self, 'u%d' % i) for i in range(self.num_svs)]
- # Singular values;
- # note that these buffers are just for logging and are not used in training.
+ # Singular values;
+ # note that these buffers are just for logging and are not used in training.
@property
def sv(self):
return [getattr(self, 'sv%d' % i) for i in range(self.num_svs)]
-
+
# Compute the spectrally-normalized weight
def W_(self):
W_mat = self.weight.view(self.weight.size(0), -1)
@@ -90,25 +86,25 @@ class SN(object):
W_mat = W_mat.t()
# Apply num_itrs power iterations
for _ in range(self.num_itrs):
- svs, us, vs = power_iteration(W_mat, self.u, update=self.training, eps=self.eps)
+ svs, us, vs = power_iteration(W_mat, self.u, update=self.training, eps=self.eps)
# Update the svs
if self.training:
with torch.no_grad(): # Make sure to do this in a no_grad() context or you'll get memory leaks!
for i, sv in enumerate(svs):
- self.sv[i][:] = sv
+ self.sv[i][:] = sv
return self.weight / svs[0]
# 2D Conv layer with spectral norm
class SNConv2d(nn.Conv2d, SN):
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
- padding=0, dilation=1, groups=1, bias=True,
+ padding=0, dilation=1, groups=1, bias=True,
num_svs=1, num_itrs=1, eps=1e-12):
- nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, stride,
+ nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, stride,
padding, dilation, groups, bias)
- SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps)
+ SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps)
def forward(self, x):
- return F.conv2d(x, self.W_(), self.bias, self.stride,
+ return F.conv2d(x, self.W_(), self.bias, self.stride,
self.padding, self.dilation, self.groups)
@@ -126,12 +122,12 @@ class SNLinear(nn.Linear, SN):
# We use num_embeddings as the dim instead of embedding_dim here
# for convenience sake
class SNEmbedding(nn.Embedding, SN):
- def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
+ def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
max_norm=None, norm_type=2, scale_grad_by_freq=False,
sparse=False, _weight=None,
num_svs=1, num_itrs=1, eps=1e-12):
nn.Embedding.__init__(self, num_embeddings, embedding_dim, padding_idx,
- max_norm, norm_type, scale_grad_by_freq,
+ max_norm, norm_type, scale_grad_by_freq,
sparse, _weight)
SN.__init__(self, num_svs, num_itrs, num_embeddings, eps=eps)
def forward(self, x):
@@ -157,7 +153,7 @@ class Attention(nn.Module):
# Apply convs
theta = self.theta(x)
phi = F.max_pool2d(self.phi(x), [2,2])
- g = F.max_pool2d(self.g(x), [2,2])
+ g = F.max_pool2d(self.g(x), [2,2])
# Perform reshapes
theta = theta.view(-1, self. ch // 8, x.shape[2] * x.shape[3])
phi = phi.view(-1, self. ch // 8, x.shape[2] * x.shape[3] // 4)
@@ -191,7 +187,7 @@ def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5):
def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5):
# Cast x to float32 if necessary
float_x = x.float()
- # Calculate expected value of x (m) and expected value of x**2 (m2)
+ # Calculate expected value of x (m) and expected value of x**2 (m2)
# Mean of x
m = torch.mean(float_x, [0, 2, 3], keepdim=True)
# Mean of x squared
@@ -201,14 +197,14 @@ def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5):
# Cast back to float 16 if necessary
var = var.type(x.type())
m = m.type(x.type())
- # Return mean and variance for updating stored mean/var if requested
+ # Return mean and variance for updating stored mean/var if requested
if return_mean_var:
return fused_bn(x, m, var, gain, bias, eps), m.squeeze(), var.squeeze()
else:
return fused_bn(x, m, var, gain, bias, eps)
-# My batchnorm, supports standing stats
+# My batchnorm, supports standing stats
class myBN(nn.Module):
def __init__(self, num_channels, eps=1e-5, momentum=0.1):
super(myBN, self).__init__()
@@ -224,13 +220,13 @@ class myBN(nn.Module):
self.register_buffer('accumulation_counter', torch.zeros(1))
# Accumulate running means and vars
self.accumulate_standing = False
-
+
# reset standing stats
def reset_stats(self):
self.stored_mean[:] = 0
self.stored_var[:] = 0
self.accumulation_counter[:] = 0
-
+
def forward(self, x, gain, bias):
if self.training:
out, mean, var = manual_bn(x, gain, bias, return_mean_var=True, eps=self.eps)
@@ -245,17 +241,17 @@ class myBN(nn.Module):
self.stored_var[:] = self.stored_var * (1 - self.momentum) + var * self.momentum
return out
# If not in training mode, use the stored statistics
- else:
+ else:
mean = self.stored_mean.view(1, -1, 1, 1)
var = self.stored_var.view(1, -1, 1, 1)
- # If using standing stats, divide them by the accumulation counter
+ # If using standing stats, divide them by the accumulation counter
if self.accumulate_standing:
mean = mean / self.accumulation_counter
var = var / self.accumulation_counter
return fused_bn(x, mean, var, gain, bias, self.eps)
-# Simple function to handle groupnorm norm stylization
+# Simple function to handle groupnorm norm stylization
def groupnorm(x, norm_style):
# If number of channels specified in norm_style:
if 'ch' in norm_style:
@@ -274,7 +270,7 @@ def groupnorm(x, norm_style):
# output size is the number of channels, input size is for the linear layers
# Andy's Note: this class feels messy but I'm not really sure how to clean it up
# Suggestions welcome! (By which I mean, refactor this and make a pull request
-# if you want to make this more readable/usable).
+# if you want to make this more readable/usable).
class ccbn(nn.Module):
def __init__(self, output_size, input_size, which_linear, eps=1e-5, momentum=0.1,
cross_replica=False, mybn=False, norm_style='bn',):
@@ -287,28 +283,24 @@ class ccbn(nn.Module):
self.eps = eps
# Momentum
self.momentum = momentum
- # Use cross-replica batchnorm?
- self.cross_replica = cross_replica
# Use my batchnorm?
self.mybn = mybn
# Norm style?
self.norm_style = norm_style
-
- if self.cross_replica:
- self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False)
- elif self.mybn:
+
+ if self.mybn:
self.bn = myBN(output_size, self.eps, self.momentum)
elif self.norm_style in ['bn', 'in']:
self.register_buffer('stored_mean', torch.zeros(output_size))
- self.register_buffer('stored_var', torch.ones(output_size))
-
-
+ self.register_buffer('stored_var', torch.ones(output_size))
+
+
def forward(self, x, y):
# Calculate class-conditional gains and biases
gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
bias = self.bias(y).view(y.size(0), -1, 1, 1)
# If using my batchnorm
- if self.mybn or self.cross_replica:
+ if self.mybn:
return self.bn(x, gain=gain, bias=bias)
# else:
else:
@@ -342,22 +334,18 @@ class bn(nn.Module):
self.eps = eps
# Momentum
self.momentum = momentum
- # Use cross-replica batchnorm?
- self.cross_replica = cross_replica
# Use my batchnorm?
self.mybn = mybn
-
- if self.cross_replica:
- self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False)
- elif mybn:
+
+ if mybn:
self.bn = myBN(output_size, self.eps, self.momentum)
# Register buffers if neither of the above
- else:
+ else:
self.register_buffer('stored_mean', torch.zeros(output_size))
self.register_buffer('stored_var', torch.ones(output_size))
-
+
def forward(self, x, y=None):
- if self.cross_replica or self.mybn:
+ if self.mybn:
gain = self.gain.view(1,-1,1,1)
bias = self.bias.view(1,-1,1,1)
return self.bn(x, gain=gain, bias=bias)
@@ -365,19 +353,19 @@ class bn(nn.Module):
return F.batch_norm(x, self.stored_mean, self.stored_var, self.gain,
self.bias, self.training, self.momentum, self.eps)
-
+
# Generator blocks
# Note that this class assumes the kernel size and padding (and any other
# settings) have been selected in the main generator module and passed in
# through the which_conv arg. Similar rules apply with which_bn (the input
-# size [which is actually the number of channels of the conditional info] must
+# size [which is actually the number of channels of the conditional info] must
# be preselected)
class GBlock(nn.Module):
def __init__(self, in_channels, out_channels,
- which_conv=nn.Conv2d, which_bn=bn, activation=None,
+ which_conv=nn.Conv2d, which_bn=bn, activation=None,
upsample=None):
super(GBlock, self).__init__()
-
+
self.in_channels, self.out_channels = in_channels, out_channels
self.which_conv, self.which_bn = which_conv, which_bn
self.activation = activation
@@ -387,7 +375,7 @@ class GBlock(nn.Module):
self.conv2 = self.which_conv(self.out_channels, self.out_channels)
self.learnable_sc = in_channels != out_channels or upsample
if self.learnable_sc:
- self.conv_sc = self.which_conv(in_channels, out_channels,
+ self.conv_sc = self.which_conv(in_channels, out_channels,
kernel_size=1, padding=0)
# Batchnorm layers
self.bn1 = self.which_bn(in_channels)
@@ -403,57 +391,6 @@ class GBlock(nn.Module):
h = self.conv1(h)
h = self.activation(self.bn2(h, y))
h = self.conv2(h)
- if self.learnable_sc:
+ if self.learnable_sc:
x = self.conv_sc(x)
return h + x
-
-
-# Residual block for the discriminator
-class DBlock(nn.Module):
- def __init__(self, in_channels, out_channels, which_conv=SNConv2d, wide=True,
- preactivation=False, activation=None, downsample=None,):
- super(DBlock, self).__init__()
- self.in_channels, self.out_channels = in_channels, out_channels
- # If using wide D (as in SA-GAN and BigGAN), change the channel pattern
- self.hidden_channels = self.out_channels if wide else self.in_channels
- self.which_conv = which_conv
- self.preactivation = preactivation
- self.activation = activation
- self.downsample = downsample
-
- # Conv layers
- self.conv1 = self.which_conv(self.in_channels, self.hidden_channels)
- self.conv2 = self.which_conv(self.hidden_channels, self.out_channels)
- self.learnable_sc = True if (in_channels != out_channels) or downsample else False
- if self.learnable_sc:
- self.conv_sc = self.which_conv(in_channels, out_channels,
- kernel_size=1, padding=0)
- def shortcut(self, x):
- if self.preactivation:
- if self.learnable_sc:
- x = self.conv_sc(x)
- if self.downsample:
- x = self.downsample(x)
- else:
- if self.downsample:
- x = self.downsample(x)
- if self.learnable_sc:
- x = self.conv_sc(x)
- return x
-
- def forward(self, x):
- if self.preactivation:
- # h = self.activation(x) # NOT TODAY SATAN
- # Andy's note: This line *must* be an out-of-place ReLU or it
- # will negatively affect the shortcut connection.
- h = F.relu(x)
- else:
- h = x
- h = self.conv1(h)
- h = self.conv2(self.activation(h))
- if self.downsample:
- h = self.downsample(h)
-
- return h + self.shortcut(x)
-
-# dogball
\ No newline at end of file