05360171创建于 2022年3月18日历史提交
diff --git a/BigGAN.py b/BigGAN.py
index 036d48e..493c1e5 100644
--- a/BigGAN.py
+++ b/BigGAN.py
@@ -1,5 +1,4 @@
 import numpy as np
-import math
 import functools
 
 import torch
@@ -7,16 +6,14 @@ import torch.nn as nn
 from torch.nn import init
 import torch.optim as optim
 import torch.nn.functional as F
-from torch.nn import Parameter as P
 
 import layers
-from sync_batchnorm import SynchronizedBatchNorm2d as SyncBatchNorm2d
 
 
 # Architectures for G
 # Attention is passed in in the format '32_64' to mean applying an attention
 # block at both resolution 32x32 and 64x64. Just '64' will apply at 64x64.
-def G_arch(ch=64, attention='64', ksize='333333', dilation='111111'):
+def G_arch(ch=64, attention='64'):
   arch = {}
   arch[512] = {'in_channels' :  [ch * item for item in [16, 16, 8, 8, 4, 2, 1]],
                'out_channels' : [ch * item for item in [16,  8, 8, 4, 2, 1, 1]],
@@ -128,7 +125,7 @@ class Generator(nn.Module):
     else:
       self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
       self.which_linear = nn.Linear
-      
+
     # We use a non-spectral-normed embedding here regardless;
     # For some reason applying SN to G's embedding seems to randomly cripple G
     self.which_embedding = nn.Embedding
@@ -146,7 +143,7 @@ class Generator(nn.Module):
 
     # Prepare model
     # If not using shared embeddings, self.shared is just a passthrough
-    self.shared = (self.which_embedding(n_classes, self.shared_dim) if G_shared 
+    self.shared = (self.which_embedding(n_classes, self.shared_dim) if G_shared
                     else layers.identity())
     # First linear layer
     self.linear = self.which_linear(self.dim_z // self.num_slots,
@@ -210,7 +207,7 @@ class Generator(nn.Module):
     self.param_count = 0
     for module in self.modules():
       if (isinstance(module, nn.Conv2d) 
-          or isinstance(module, nn.Linear) 
+          or isinstance(module, nn.Linear)
           or isinstance(module, nn.Embedding)):
         if self.init == 'ortho':
           init.orthogonal_(module.weight)
@@ -227,224 +224,24 @@ class Generator(nn.Module):
   # already been passed through G.shared to enable easy class-wise
   # interpolation later. If we passed in the one-hot and then ran it through
   # G.shared in this forward function, it would be harder to handle.
-  def forward(self, z, y):
+  def forward(self, z, ys):
     # If hierarchical, concatenate zs and ys
-    if self.hier:
-      zs = torch.split(z, self.z_chunk_size, 1)
-      z = zs[0]
-      ys = [torch.cat([y, item], 1) for item in zs[1:]]
-    else:
-      ys = [y] * len(self.blocks)
-      
+    if not self.hier:
+      print("仅支持BigGAN分层模型的转换,请设置模型入参 hier=True")
+
+    z = z.view(-1, 20)
+    ys = ys.view(-1, 5, 148)
     # First linear layer
     h = self.linear(z)
     # Reshape
     h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width)
-    
+
     # Loop over blocks
     for index, blocklist in enumerate(self.blocks):
       # Second inner loop in case block has multiple layers
       for block in blocklist:
-        h = block(h, ys[index])
-        
+        y = ys[:,index,:]
+        h = block(h, y)
+
     # Apply batchnorm-relu-conv-tanh at output
     return torch.tanh(self.output_layer(h))
-
-
-# Discriminator architecture, same paradigm as G's above
-def D_arch(ch=64, attention='64',ksize='333333', dilation='111111'):
-  arch = {}
-  arch[256] = {'in_channels' :  [3] + [ch*item for item in [1, 2, 4, 8, 8, 16]],
-               'out_channels' : [item * ch for item in [1, 2, 4, 8, 8, 16, 16]],
-               'downsample' : [True] * 6 + [False],
-               'resolution' : [128, 64, 32, 16, 8, 4, 4 ],
-               'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')]
-                              for i in range(2,8)}}
-  arch[128] = {'in_channels' :  [3] + [ch*item for item in [1, 2, 4, 8, 16]],
-               'out_channels' : [item * ch for item in [1, 2, 4, 8, 16, 16]],
-               'downsample' : [True] * 5 + [False],
-               'resolution' : [64, 32, 16, 8, 4, 4],
-               'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')]
-                              for i in range(2,8)}}
-  arch[64]  = {'in_channels' :  [3] + [ch*item for item in [1, 2, 4, 8]],
-               'out_channels' : [item * ch for item in [1, 2, 4, 8, 16]],
-               'downsample' : [True] * 4 + [False],
-               'resolution' : [32, 16, 8, 4, 4],
-               'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')]
-                              for i in range(2,7)}}
-  arch[32]  = {'in_channels' :  [3] + [item * ch for item in [4, 4, 4]],
-               'out_channels' : [item * ch for item in [4, 4, 4, 4]],
-               'downsample' : [True, True, False, False],
-               'resolution' : [16, 16, 16, 16],
-               'attention' : {2**i: 2**i in [int(item) for item in attention.split('_')]
-                              for i in range(2,6)}}
-  return arch
-
-class Discriminator(nn.Module):
-
-  def __init__(self, D_ch=64, D_wide=True, resolution=128,
-               D_kernel_size=3, D_attn='64', n_classes=1000,
-               num_D_SVs=1, num_D_SV_itrs=1, D_activation=nn.ReLU(inplace=False),
-               D_lr=2e-4, D_B1=0.0, D_B2=0.999, adam_eps=1e-8,
-               SN_eps=1e-12, output_dim=1, D_mixed_precision=False, D_fp16=False,
-               D_init='ortho', skip_init=False, D_param='SN', **kwargs):
-    super(Discriminator, self).__init__()
-    # Width multiplier
-    self.ch = D_ch
-    # Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN?
-    self.D_wide = D_wide
-    # Resolution
-    self.resolution = resolution
-    # Kernel size
-    self.kernel_size = D_kernel_size
-    # Attention?
-    self.attention = D_attn
-    # Number of classes
-    self.n_classes = n_classes
-    # Activation
-    self.activation = D_activation
-    # Initialization style
-    self.init = D_init
-    # Parameterization style
-    self.D_param = D_param
-    # Epsilon for Spectral Norm?
-    self.SN_eps = SN_eps
-    # Fp16?
-    self.fp16 = D_fp16
-    # Architecture
-    self.arch = D_arch(self.ch, self.attention)[resolution]
-
-    # Which convs, batchnorms, and linear layers to use
-    # No option to turn off SN in D right now
-    if self.D_param == 'SN':
-      self.which_conv = functools.partial(layers.SNConv2d,
-                          kernel_size=3, padding=1,
-                          num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
-                          eps=self.SN_eps)
-      self.which_linear = functools.partial(layers.SNLinear,
-                          num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
-                          eps=self.SN_eps)
-      self.which_embedding = functools.partial(layers.SNEmbedding,
-                              num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
-                              eps=self.SN_eps)
-    # Prepare model
-    # self.blocks is a doubly-nested list of modules, the outer loop intended
-    # to be over blocks at a given resolution (resblocks and/or self-attention)
-    self.blocks = []
-    for index in range(len(self.arch['out_channels'])):
-      self.blocks += [[layers.DBlock(in_channels=self.arch['in_channels'][index],
-                       out_channels=self.arch['out_channels'][index],
-                       which_conv=self.which_conv,
-                       wide=self.D_wide,
-                       activation=self.activation,
-                       preactivation=(index > 0),
-                       downsample=(nn.AvgPool2d(2) if self.arch['downsample'][index] else None))]]
-      # If attention on this block, attach it to the end
-      if self.arch['attention'][self.arch['resolution'][index]]:
-        print('Adding attention layer in D at resolution %d' % self.arch['resolution'][index])
-        self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index],
-                                             self.which_conv)]
-    # Turn self.blocks into a ModuleList so that it's all properly registered.
-    self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
-    # Linear output layer. The output dimension is typically 1, but may be
-    # larger if we're e.g. turning this into a VAE with an inference output
-    self.linear = self.which_linear(self.arch['out_channels'][-1], output_dim)
-    # Embedding for projection discrimination
-    self.embed = self.which_embedding(self.n_classes, self.arch['out_channels'][-1])
-
-    # Initialize weights
-    if not skip_init:
-      self.init_weights()
-
-    # Set up optimizer
-    self.lr, self.B1, self.B2, self.adam_eps = D_lr, D_B1, D_B2, adam_eps
-    if D_mixed_precision:
-      print('Using fp16 adam in D...')
-      import utils
-      self.optim = utils.Adam16(params=self.parameters(), lr=self.lr,
-                             betas=(self.B1, self.B2), weight_decay=0, eps=self.adam_eps)
-    else:
-      self.optim = optim.Adam(params=self.parameters(), lr=self.lr,
-                             betas=(self.B1, self.B2), weight_decay=0, eps=self.adam_eps)
-    # LR scheduling, left here for forward compatibility
-    # self.lr_sched = {'itr' : 0}# if self.progressive else {}
-    # self.j = 0
-
-  # Initialize
-  def init_weights(self):
-    self.param_count = 0
-    for module in self.modules():
-      if (isinstance(module, nn.Conv2d)
-          or isinstance(module, nn.Linear)
-          or isinstance(module, nn.Embedding)):
-        if self.init == 'ortho':
-          init.orthogonal_(module.weight)
-        elif self.init == 'N02':
-          init.normal_(module.weight, 0, 0.02)
-        elif self.init in ['glorot', 'xavier']:
-          init.xavier_uniform_(module.weight)
-        else:
-          print('Init style not recognized...')
-        self.param_count += sum([p.data.nelement() for p in module.parameters()])
-    print('Param count for D''s initialized parameters: %d' % self.param_count)
-
-  def forward(self, x, y=None):
-    # Stick x into h for cleaner for loops without flow control
-    h = x
-    # Loop over blocks
-    for index, blocklist in enumerate(self.blocks):
-      for block in blocklist:
-        h = block(h)
-    # Apply global sum pooling as in SN-GAN
-    h = torch.sum(self.activation(h), [2, 3])
-    # Get initial class-unconditional output
-    out = self.linear(h)
-    # Get projection of final featureset onto class vectors and add to evidence
-    out = out + torch.sum(self.embed(y) * h, 1, keepdim=True)
-    return out
-
-# Parallelized G_D to minimize cross-gpu communication
-# Without this, Generator outputs would get all-gathered and then rebroadcast.
-class G_D(nn.Module):
-  def __init__(self, G, D):
-    super(G_D, self).__init__()
-    self.G = G
-    self.D = D
-
-  def forward(self, z, gy, x=None, dy=None, train_G=False, return_G_z=False,
-              split_D=False):              
-    # If training G, enable grad tape
-    with torch.set_grad_enabled(train_G):
-      # Get Generator output given noise
-      G_z = self.G(z, self.G.shared(gy))
-      # Cast as necessary
-      if self.G.fp16 and not self.D.fp16:
-        G_z = G_z.float()
-      if self.D.fp16 and not self.G.fp16:
-        G_z = G_z.half()
-    # Split_D means to run D once with real data and once with fake,
-    # rather than concatenating along the batch dimension.
-    if split_D:
-      D_fake = self.D(G_z, gy)
-      if x is not None:
-        D_real = self.D(x, dy)
-        return D_fake, D_real
-      else:
-        if return_G_z:
-          return D_fake, G_z
-        else:
-          return D_fake
-    # If real data is provided, concatenate it with the Generator's output
-    # along the batch dimension for improved efficiency.
-    else:
-      D_input = torch.cat([G_z, x], 0) if x is not None else G_z
-      D_class = torch.cat([gy, dy], 0) if dy is not None else gy
-      # Get Discriminator output
-      D_out = self.D(D_input, D_class)
-      if x is not None:
-        return torch.split(D_out, [G_z.shape[0], x.shape[0]]) # D_fake, D_real
-      else:
-        if return_G_z:
-          return D_out, G_z
-        else:
-          return D_out
diff --git a/inception_utils.py b/inception_utils.py
index 373d3cd..6dafc75 100644
--- a/inception_utils.py
+++ b/inception_utils.py
@@ -1,26 +1,28 @@
 ''' Inception utilities
     This file contains methods for calculating IS and FID, using either
-    the original numpy code or an accelerated fully-pytorch version that 
+    the original numpy code or an accelerated fully-pytorch version that
     uses a fast newton-schulz approximation for the matrix sqrt. There are also
     methods for acquiring a desired number of samples from the Generator,
     and parallelizing the inbuilt PyTorch inception network.
-    
-    NOTE that Inception Scores and FIDs calculated using these methods will 
+
+    NOTE that Inception Scores and FIDs calculated using these methods will
     *not* be directly comparable to values calculated using the original TF
     IS/FID code. You *must* use the TF model if you wish to report and compare
     numbers. This code tends to produce IS values that are 5-10% lower than
-    those obtained through TF. 
-'''    
+    those obtained through TF.
+'''
 import numpy as np
 from scipy import linalg # For numpy FID
-import time
+import warnings
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn import Parameter as P
+from biggan_preprocess import proc_nodes_module
 from torchvision.models.inception import inception_v3
 
+warnings.filterwarnings("ignore")
 
 # Module that wraps the inception network to enable use with dataparallel and
 # returning pool features and logits.
@@ -119,26 +121,6 @@ def torch_cov(m, rowvar=False):
     return fact * m.matmul(mt).squeeze()
 
 
-# Pytorch implementation of matrix sqrt, from Tsung-Yu Lin, and Subhransu Maji
-# https://github.com/msubhransu/matrix-sqrt 
-def sqrt_newton_schulz(A, numIters, dtype=None):
-  with torch.no_grad():
-    if dtype is None:
-      dtype = A.type()
-    batchSize = A.shape[0]
-    dim = A.shape[1]
-    normA = A.mul(A).sum(dim=1).sum(dim=1).sqrt()
-    Y = A.div(normA.view(batchSize, 1, 1).expand_as(A));
-    I = torch.eye(dim,dim).view(1, dim, dim).repeat(batchSize,1,1).type(dtype)
-    Z = torch.eye(dim,dim).view(1, dim, dim).repeat(batchSize,1,1).type(dtype)
-    for i in range(numIters):
-      T = 0.5*(3.0*I - Z.bmm(Y))
-      Y = Y.bmm(T)
-      Z = T.bmm(Z)
-    sA = Y*torch.sqrt(normA).view(batchSize, 1, 1).expand_as(A)
-  return sA
-
-
 # FID calculator from TTUR--consider replacing this with GPU-accelerated cov
 # calculations using torch?
 def numpy_calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
@@ -152,10 +134,10 @@ def numpy_calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
   -- mu1   : Numpy array containing the activations of a layer of the
              inception net (like returned by the function 'get_predictions')
              for generated samples.
-  -- mu2   : The sample mean over activations, precalculated on an 
+  -- mu2   : The sample mean over activations, precalculated on an
              representive data set.
   -- sigma1: The covariance matrix over activations for generated samples.
-  -- sigma2: The covariance matrix over activations, precalculated on an 
+  -- sigma2: The covariance matrix over activations, precalculated on an
              representive data set.
   Returns:
   --   : The Frechet Distance.
@@ -189,48 +171,14 @@ def numpy_calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
     if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
       m = np.max(np.abs(covmean.imag))
       raise ValueError('Imaginary component {}'.format(m))
-    covmean = covmean.real  
+    covmean = covmean.real
 
-  tr_covmean = np.trace(covmean) 
+  tr_covmean = np.trace(covmean)
 
   out = diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
   return out
 
 
-def torch_calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
-  """Pytorch implementation of the Frechet Distance.
-  Taken from https://github.com/bioinf-jku/TTUR
-  The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
-  and X_2 ~ N(mu_2, C_2) is
-          d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
-  Stable version by Dougal J. Sutherland.
-  Params:
-  -- mu1   : Numpy array containing the activations of a layer of the
-             inception net (like returned by the function 'get_predictions')
-             for generated samples.
-  -- mu2   : The sample mean over activations, precalculated on an 
-             representive data set.
-  -- sigma1: The covariance matrix over activations for generated samples.
-  -- sigma2: The covariance matrix over activations, precalculated on an 
-             representive data set.
-  Returns:
-  --   : The Frechet Distance.
-  """
-
-
-  assert mu1.shape == mu2.shape, \
-    'Training and test mean vectors have different lengths'
-  assert sigma1.shape == sigma2.shape, \
-    'Training and test covariances have different dimensions'
-
-  diff = mu1 - mu2
-  # Run 50 itrs of newton-schulz to get the matrix sqrt of sigma1 dot sigma2
-  covmean = sqrt_newton_schulz(sigma1.mm(sigma2).unsqueeze(0), 50).squeeze()  
-  out = (diff.dot(diff) +  torch.trace(sigma1) + torch.trace(sigma2)
-         - 2 * torch.trace(covmean))
-  return out
-
-
 # Calculate Inception Score mean + std given softmax'd logits and number of splits
 def calculate_inception_score(pred, num_splits=10):
   scores = []
@@ -243,24 +191,45 @@ def calculate_inception_score(pred, num_splits=10):
 
 
 # Loop and run the sampler and the net until it accumulates num_inception_images
-# activations. Return the pool, the logits, and the labels (if one wants 
+# activations. Return the pool, the logits, and the labels (if one wants
 # Inception Accuracy the labels of the generated class will be needed)
-def accumulate_inception_activations(sample, net, num_inception_images=50000):
+def accumulate_inception_activations(img, label, net, num_inception_images=50000):
   pool, logits, labels = [], [], []
+  count = 1
+  img = torch.from_numpy(img)
+  if num_inception_images < 1000:
+    batch = 10
+  else:
+    batch = 100
   while (torch.cat(logits, 0).shape[0] if len(logits) else 0) < num_inception_images:
     with torch.no_grad():
-      images, labels_val = sample()
+      if count * batch > num_inception_images:
+        images = img[(count-1)*batch : num_inception_images]
+        labels_val = np.array([label[(count-1)*batch : num_inception_images]])
+      else:
+        images = img[(count-1)*batch : count*batch]
+        labels_val = np.array([label[(count-1)*batch : count*batch]])
+      labels_val = torch.from_numpy(labels_val[: np.newaxis])
       pool_val, logits_val = net(images.float())
       pool += [pool_val]
       logits += [F.softmax(logits_val, 1)]
       labels += [labels_val]
-  return torch.cat(pool, 0), torch.cat(logits, 0), torch.cat(labels, 0)
+      count += 1
+      if (count-1) % 10 == 0:
+        print("Has counted {} samples".format((count-1)*batch))
+  pool_concatenate = torch.cat(pool, 0)
+  logits_concatenate = torch.cat(logits, 0)
+  labels_concatenate = torch.cat(labels, 0)
+  return pool_concatenate, logits_concatenate, labels_concatenate
 
 
 # Load and wrap the Inception model
 def load_inception_net(parallel=False):
-  inception_model = inception_v3(pretrained=True, transform_input=False)
-  inception_model = WrapInception(inception_model.eval()).cuda()
+  inception_model = inception_v3(pretrained=False, transform_input=False)
+  model_checkpoint = torch.load('./inception_v3_google.pth', map_location=torch.device('cpu'))
+  model_checkpoint = proc_nodes_module(model_checkpoint)
+  inception_model.load_state_dict(model_checkpoint)
+  inception_model = WrapInception(inception_model.eval()).to()
   if parallel:
     print('Parallelizing Inception module...')
     inception_model = nn.DataParallel(inception_model)
@@ -271,21 +240,21 @@ def load_inception_net(parallel=False):
 # and iterates until it accumulates config['num_inception_images'] images.
 # The iterator can return samples with a different batch size than used in
 # training, using the setting confg['inception_batchsize']
-def prepare_inception_metrics(dataset, parallel, no_fid=False):
+def prepare_inception_metrics(dataset, parallel=False, no_fid=False):
   # Load metrics; this is intentionally not in a try-except loop so that
   # the script will crash here if it cannot find the Inception moments.
   # By default, remove the "hdf5" from dataset
-  dataset = dataset.strip('_hdf5')
   data_mu = np.load(dataset+'_inception_moments.npz')['mu']
   data_sigma = np.load(dataset+'_inception_moments.npz')['sigma']
   # Load network
   net = load_inception_net(parallel)
-  def get_inception_metrics(sample, num_inception_images, num_splits=10, 
-                            prints=True, use_torch=True):
+  #################################################################
+  def get_inception_metrics(img, label, num_inception_images,
+                            num_splits=10, prints=True):
     if prints:
       print('Gathering activations...')
-    pool, logits, labels = accumulate_inception_activations(sample, net, num_inception_images)
-    if prints:  
+    pool, logits, labels = accumulate_inception_activations(img, label, net, num_inception_images)
+    if prints:
       print('Calculating Inception Score...')
     IS_mean, IS_std = calculate_inception_score(logits.cpu().numpy(), num_splits)
     if no_fid:
@@ -293,18 +262,12 @@ def prepare_inception_metrics(dataset, parallel, no_fid=False):
     else:
       if prints:
         print('Calculating means and covariances...')
-      if use_torch:
-        mu, sigma = torch.mean(pool, 0), torch_cov(pool, rowvar=False)
-      else:
-        mu, sigma = np.mean(pool.cpu().numpy(), axis=0), np.cov(pool.cpu().numpy(), rowvar=False)
+      mu, sigma = np.mean(pool.cpu().numpy(), axis=0), np.cov(pool.cpu().numpy(), rowvar=False)
       if prints:
         print('Covariances calculated, getting FID...')
-      if use_torch:
-        FID = torch_calculate_frechet_distance(mu, sigma, torch.tensor(data_mu).float().cuda(), torch.tensor(data_sigma).float().cuda())
-        FID = float(FID.cpu().numpy())
-      else:
-        FID = numpy_calculate_frechet_distance(mu.cpu().numpy(), sigma.cpu().numpy(), data_mu, data_sigma)
+      FID = numpy_calculate_frechet_distance(mu, sigma, data_mu, data_sigma)
     # Delete mu, sigma, pool, logits, and labels, just in case
     del mu, sigma, pool, logits, labels
     return IS_mean, IS_std, FID
-  return get_inception_metrics
\ No newline at end of file
+  #################################################################
+  return get_inception_metrics
diff --git a/layers.py b/layers.py
index 55aaab1..956edc8 100644
--- a/layers.py
+++ b/layers.py
@@ -1,16 +1,12 @@
 ''' Layers
     This file contains various layers for the BigGAN models.
 '''
-import numpy as np
+
 import torch
 import torch.nn as nn
-from torch.nn import init
-import torch.optim as optim
 import torch.nn.functional as F
 from torch.nn import Parameter as P
 
-from sync_batchnorm import SynchronizedBatchNorm2d as SyncBN2d
-
 
 # Projection of x onto y
 def proj(x, y):
@@ -54,9 +50,9 @@ def power_iteration(W, u_, update=True, eps=1e-12):
 class identity(nn.Module):
   def forward(self, input):
     return input
- 
 
-# Spectral normalization base class 
+
+# Spectral normalization base class
 class SN(object):
   def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12):
     # Number of power iterations per step
@@ -71,18 +67,18 @@ class SN(object):
     for i in range(self.num_svs):
       self.register_buffer('u%d' % i, torch.randn(1, num_outputs))
       self.register_buffer('sv%d' % i, torch.ones(1))
-  
+
   # Singular vectors (u side)
   @property
   def u(self):
     return [getattr(self, 'u%d' % i) for i in range(self.num_svs)]
 
-  # Singular values; 
-  # note that these buffers are just for logging and are not used in training. 
+  # Singular values;
+  # note that these buffers are just for logging and are not used in training.
   @property
   def sv(self):
    return [getattr(self, 'sv%d' % i) for i in range(self.num_svs)]
-   
+
   # Compute the spectrally-normalized weight
   def W_(self):
     W_mat = self.weight.view(self.weight.size(0), -1)
@@ -90,25 +86,25 @@ class SN(object):
       W_mat = W_mat.t()
     # Apply num_itrs power iterations
     for _ in range(self.num_itrs):
-      svs, us, vs = power_iteration(W_mat, self.u, update=self.training, eps=self.eps) 
+      svs, us, vs = power_iteration(W_mat, self.u, update=self.training, eps=self.eps)
     # Update the svs
     if self.training:
       with torch.no_grad(): # Make sure to do this in a no_grad() context or you'll get memory leaks!
         for i, sv in enumerate(svs):
-          self.sv[i][:] = sv     
+          self.sv[i][:] = sv
     return self.weight / svs[0]
 
 
 # 2D Conv layer with spectral norm
 class SNConv2d(nn.Conv2d, SN):
   def __init__(self, in_channels, out_channels, kernel_size, stride=1,
-             padding=0, dilation=1, groups=1, bias=True, 
+             padding=0, dilation=1, groups=1, bias=True,
              num_svs=1, num_itrs=1, eps=1e-12):
-    nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, stride, 
+    nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, stride,
                      padding, dilation, groups, bias)
-    SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps)    
+    SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps)
   def forward(self, x):
-    return F.conv2d(x, self.W_(), self.bias, self.stride, 
+    return F.conv2d(x, self.W_(), self.bias, self.stride,
                     self.padding, self.dilation, self.groups)
 
 
@@ -126,12 +122,12 @@ class SNLinear(nn.Linear, SN):
 # We use num_embeddings as the dim instead of embedding_dim here
 # for convenience sake
 class SNEmbedding(nn.Embedding, SN):
-  def __init__(self, num_embeddings, embedding_dim, padding_idx=None, 
+  def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
                max_norm=None, norm_type=2, scale_grad_by_freq=False,
                sparse=False, _weight=None,
                num_svs=1, num_itrs=1, eps=1e-12):
     nn.Embedding.__init__(self, num_embeddings, embedding_dim, padding_idx,
-                          max_norm, norm_type, scale_grad_by_freq, 
+                          max_norm, norm_type, scale_grad_by_freq,
                           sparse, _weight)
     SN.__init__(self, num_svs, num_itrs, num_embeddings, eps=eps)
   def forward(self, x):
@@ -157,7 +153,7 @@ class Attention(nn.Module):
     # Apply convs
     theta = self.theta(x)
     phi = F.max_pool2d(self.phi(x), [2,2])
-    g = F.max_pool2d(self.g(x), [2,2])    
+    g = F.max_pool2d(self.g(x), [2,2])
     # Perform reshapes
     theta = theta.view(-1, self. ch // 8, x.shape[2] * x.shape[3])
     phi = phi.view(-1, self. ch // 8, x.shape[2] * x.shape[3] // 4)
@@ -191,7 +187,7 @@ def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5):
 def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5):
   # Cast x to float32 if necessary
   float_x = x.float()
-  # Calculate expected value of x (m) and expected value of x**2 (m2)  
+  # Calculate expected value of x (m) and expected value of x**2 (m2)
   # Mean of x
   m = torch.mean(float_x, [0, 2, 3], keepdim=True)
   # Mean of x squared
@@ -201,14 +197,14 @@ def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5):
   # Cast back to float 16 if necessary
   var = var.type(x.type())
   m = m.type(x.type())
-  # Return mean and variance for updating stored mean/var if requested  
+  # Return mean and variance for updating stored mean/var if requested
   if return_mean_var:
     return fused_bn(x, m, var, gain, bias, eps), m.squeeze(), var.squeeze()
   else:
     return fused_bn(x, m, var, gain, bias, eps)
 
 
-# My batchnorm, supports standing stats    
+# My batchnorm, supports standing stats
 class myBN(nn.Module):
   def __init__(self, num_channels, eps=1e-5, momentum=0.1):
     super(myBN, self).__init__()
@@ -224,13 +220,13 @@ class myBN(nn.Module):
     self.register_buffer('accumulation_counter', torch.zeros(1))
     # Accumulate running means and vars
     self.accumulate_standing = False
-    
+
   # reset standing stats
   def reset_stats(self):
     self.stored_mean[:] = 0
     self.stored_var[:] = 0
     self.accumulation_counter[:] = 0
-    
+
   def forward(self, x, gain, bias):
     if self.training:
       out, mean, var = manual_bn(x, gain, bias, return_mean_var=True, eps=self.eps)
@@ -245,17 +241,17 @@ class myBN(nn.Module):
         self.stored_var[:] = self.stored_var * (1 - self.momentum) + var * self.momentum
       return out
     # If not in training mode, use the stored statistics
-    else:         
+    else:
       mean = self.stored_mean.view(1, -1, 1, 1)
       var = self.stored_var.view(1, -1, 1, 1)
-      # If using standing stats, divide them by the accumulation counter   
+      # If using standing stats, divide them by the accumulation counter
       if self.accumulate_standing:
         mean = mean / self.accumulation_counter
         var = var / self.accumulation_counter
       return fused_bn(x, mean, var, gain, bias, self.eps)
 
 
-# Simple function to handle groupnorm norm stylization                      
+# Simple function to handle groupnorm norm stylization
 def groupnorm(x, norm_style):
   # If number of channels specified in norm_style:
   if 'ch' in norm_style:
@@ -274,7 +270,7 @@ def groupnorm(x, norm_style):
 # output size is the number of channels, input size is for the linear layers
 # Andy's Note: this class feels messy but I'm not really sure how to clean it up
 # Suggestions welcome! (By which I mean, refactor this and make a pull request
-# if you want to make this more readable/usable). 
+# if you want to make this more readable/usable).
 class ccbn(nn.Module):
   def __init__(self, output_size, input_size, which_linear, eps=1e-5, momentum=0.1,
                cross_replica=False, mybn=False, norm_style='bn',):
@@ -287,28 +283,24 @@ class ccbn(nn.Module):
     self.eps = eps
     # Momentum
     self.momentum = momentum
-    # Use cross-replica batchnorm?
-    self.cross_replica = cross_replica
     # Use my batchnorm?
     self.mybn = mybn
     # Norm style?
     self.norm_style = norm_style
-    
-    if self.cross_replica:
-      self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False)
-    elif self.mybn:
+
+    if self.mybn:
       self.bn = myBN(output_size, self.eps, self.momentum)
     elif self.norm_style in ['bn', 'in']:
       self.register_buffer('stored_mean', torch.zeros(output_size))
-      self.register_buffer('stored_var',  torch.ones(output_size)) 
-    
-    
+      self.register_buffer('stored_var',  torch.ones(output_size))
+
+
   def forward(self, x, y):
     # Calculate class-conditional gains and biases
     gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
     bias = self.bias(y).view(y.size(0), -1, 1, 1)
     # If using my batchnorm
-    if self.mybn or self.cross_replica:
+    if self.mybn:
       return self.bn(x, gain=gain, bias=bias)
     # else:
     else:
@@ -342,22 +334,18 @@ class bn(nn.Module):
     self.eps = eps
     # Momentum
     self.momentum = momentum
-    # Use cross-replica batchnorm?
-    self.cross_replica = cross_replica
     # Use my batchnorm?
     self.mybn = mybn
-    
-    if self.cross_replica:
-      self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False)    
-    elif mybn:
+
+    if mybn:
       self.bn = myBN(output_size, self.eps, self.momentum)
      # Register buffers if neither of the above
-    else:     
+    else:
       self.register_buffer('stored_mean', torch.zeros(output_size))
       self.register_buffer('stored_var',  torch.ones(output_size))
-    
+
   def forward(self, x, y=None):
-    if self.cross_replica or self.mybn:
+    if self.mybn:
       gain = self.gain.view(1,-1,1,1)
       bias = self.bias.view(1,-1,1,1)
       return self.bn(x, gain=gain, bias=bias)
@@ -365,19 +353,19 @@ class bn(nn.Module):
       return F.batch_norm(x, self.stored_mean, self.stored_var, self.gain,
                           self.bias, self.training, self.momentum, self.eps)
 
-                          
+
 # Generator blocks
 # Note that this class assumes the kernel size and padding (and any other
 # settings) have been selected in the main generator module and passed in
 # through the which_conv arg. Similar rules apply with which_bn (the input
-# size [which is actually the number of channels of the conditional info] must 
+# size [which is actually the number of channels of the conditional info] must
 # be preselected)
 class GBlock(nn.Module):
   def __init__(self, in_channels, out_channels,
-               which_conv=nn.Conv2d, which_bn=bn, activation=None, 
+               which_conv=nn.Conv2d, which_bn=bn, activation=None,
                upsample=None):
     super(GBlock, self).__init__()
-    
+
     self.in_channels, self.out_channels = in_channels, out_channels
     self.which_conv, self.which_bn = which_conv, which_bn
     self.activation = activation
@@ -387,7 +375,7 @@ class GBlock(nn.Module):
     self.conv2 = self.which_conv(self.out_channels, self.out_channels)
     self.learnable_sc = in_channels != out_channels or upsample
     if self.learnable_sc:
-      self.conv_sc = self.which_conv(in_channels, out_channels, 
+      self.conv_sc = self.which_conv(in_channels, out_channels,
                                      kernel_size=1, padding=0)
     # Batchnorm layers
     self.bn1 = self.which_bn(in_channels)
@@ -403,57 +391,6 @@ class GBlock(nn.Module):
     h = self.conv1(h)
     h = self.activation(self.bn2(h, y))
     h = self.conv2(h)
-    if self.learnable_sc:       
+    if self.learnable_sc:
       x = self.conv_sc(x)
     return h + x
-    
-    
-# Residual block for the discriminator
-class DBlock(nn.Module):
-  def __init__(self, in_channels, out_channels, which_conv=SNConv2d, wide=True,
-               preactivation=False, activation=None, downsample=None,):
-    super(DBlock, self).__init__()
-    self.in_channels, self.out_channels = in_channels, out_channels
-    # If using wide D (as in SA-GAN and BigGAN), change the channel pattern
-    self.hidden_channels = self.out_channels if wide else self.in_channels
-    self.which_conv = which_conv
-    self.preactivation = preactivation
-    self.activation = activation
-    self.downsample = downsample
-        
-    # Conv layers
-    self.conv1 = self.which_conv(self.in_channels, self.hidden_channels)
-    self.conv2 = self.which_conv(self.hidden_channels, self.out_channels)
-    self.learnable_sc = True if (in_channels != out_channels) or downsample else False
-    if self.learnable_sc:
-      self.conv_sc = self.which_conv(in_channels, out_channels, 
-                                     kernel_size=1, padding=0)
-  def shortcut(self, x):
-    if self.preactivation:
-      if self.learnable_sc:
-        x = self.conv_sc(x)
-      if self.downsample:
-        x = self.downsample(x)
-    else:
-      if self.downsample:
-        x = self.downsample(x)
-      if self.learnable_sc:
-        x = self.conv_sc(x)
-    return x
-    
-  def forward(self, x):
-    if self.preactivation:
-      # h = self.activation(x) # NOT TODAY SATAN
-      # Andy's note: This line *must* be an out-of-place ReLU or it 
-      #              will negatively affect the shortcut connection.
-      h = F.relu(x)
-    else:
-      h = x    
-    h = self.conv1(h)
-    h = self.conv2(self.activation(h))
-    if self.downsample:
-      h = self.downsample(h)     
-        
-    return h + self.shortcut(x)
-    
-# dogball
\ No newline at end of file