""" Layers
This file contains various layers for the BigGAN models.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter as P
from sync_batchnorm import SynchronizedBatchNorm2d as SyncBN2d
def proj(x, y):
    """Return the projection of ``x`` onto ``y``.

    Both arguments are passed to ``torch.mm`` and must therefore be 2-D
    tensors; ``x`` and ``y`` need matching column counts so that
    ``y @ x.T`` and ``y @ y.T`` are defined.
    """
    overlap = torch.mm(y, x.t())        # <y, x>
    y_norm_sq = torch.mm(y, y.t())      # <y, y>
    return overlap * y / y_norm_sq
def gram_schmidt(x, ys):
    """Remove from ``x`` its component along each vector in ``ys``.

    The vectors in ``ys`` are processed in order, each projection being
    taken from the already-reduced vector. With an empty ``ys`` the input
    is returned as is.
    """
    residual = x
    for basis in ys:
        residual = residual - proj(residual, basis)
    return residual
def power_iteration(W, u_, update=True, eps=1e-12):
    """Run one power-iteration step for every vector in ``u_``.

    Args:
        W: 2-D weight matrix.
        u_: sequence of current left-vector estimates; each one is
            overwritten in place with its refreshed value when ``update``
            is true.
        update: whether to write the new estimates back into ``u_``.
        eps: epsilon forwarded to ``F.normalize``.

    Returns:
        ``(svs, us, vs)``: lists of singular-value estimates, left vectors
        and right vectors, one entry per element of ``u_``.
    """
    left_vecs, right_vecs, sing_vals = [], [], []
    for idx, u in enumerate(u_):
        # The vector updates are bookkeeping only, so no graph is recorded.
        with torch.no_grad():
            # Right vector: orthogonalise against those already found.
            v = F.normalize(gram_schmidt(torch.matmul(u, W), right_vecs), eps=eps)
            right_vecs.append(v)
            # Left vector: same treatment against earlier left vectors.
            u = F.normalize(gram_schmidt(torch.matmul(v, W.t()), left_vecs), eps=eps)
            left_vecs.append(u)
            if update:
                u_[idx][:] = u
        # Computed outside the no-grad region so it stays differentiable
        # with respect to W.
        sigma = torch.matmul(torch.matmul(v, W.t()), u.t())
        sing_vals.append(torch.squeeze(sigma))
    return sing_vals, left_vecs, right_vecs
class identity(nn.Module):
    """Pass-through module: ``forward`` hands back its argument untouched."""

    def forward(self, input):
        return input
class SN(object):
    """Spectral-normalisation mixin.

    Meant to be combined with an ``nn.Module`` subclass: it relies on the
    host object providing ``register_buffer``, ``training`` and ``weight``.

    Args:
        num_svs: number of singular values/vectors to track.
        num_itrs: power-iteration steps per ``W_()`` call. ``W_`` needs at
            least one step; with zero steps no singular values exist to
            divide by.
        num_outputs: length of each stored ``u`` vector.
        transpose: if true, the flattened weight is transposed before the
            power iteration.
        eps: epsilon used when normalising vectors.
    """

    def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12):
        self.num_itrs = num_itrs
        self.num_svs = num_svs
        self.transpose = transpose
        self.eps = eps
        # One (u, sv) buffer pair per tracked singular value.
        for idx in range(self.num_svs):
            self.register_buffer('u%d' % idx, torch.randn(1, num_outputs))
            self.register_buffer('sv%d' % idx, torch.ones(1))

    @property
    def u(self):
        """List of the stored left singular-vector buffers."""
        return [getattr(self, 'u%d' % idx) for idx in range(self.num_svs)]

    @property
    def sv(self):
        """List of the stored singular-value buffers."""
        return [getattr(self, 'sv%d' % idx) for idx in range(self.num_svs)]

    def W_(self):
        """Return ``self.weight`` divided by its leading singular-value estimate."""
        flat = self.weight.view(self.weight.size(0), -1)
        W_mat = flat.t() if self.transpose else flat
        # Only the result of the last iteration is kept.
        for _ in range(self.num_itrs):
            svs, _us, _vs = power_iteration(W_mat, self.u, update=self.training, eps=self.eps)
        if self.training:
            # Record the estimates without tracking gradients.
            with torch.no_grad():
                for stored, estimate in zip(self.sv, svs):
                    stored[:] = estimate
        return self.weight / svs[0]
class SNConv2d(nn.Conv2d, SN):
    """2-D convolution that uses the spectrally normalised weight ``W_()``.

    Accepts the ``nn.Conv2d`` arguments listed in the signature plus the
    ``SN`` settings ``num_svs``, ``num_itrs`` and ``eps``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 num_svs=1, num_itrs=1, eps=1e-12):
        nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size,
                           stride=stride, padding=padding, dilation=dilation,
                           groups=groups, bias=bias)
        # The u vectors have one entry per output channel.
        SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps)

    def forward(self, x):
        normalized_weight = self.W_()
        return F.conv2d(x, normalized_weight, self.bias,
                        stride=self.stride, padding=self.padding,
                        dilation=self.dilation, groups=self.groups)
class SNLinear(nn.Linear, SN):
    """Linear layer that uses the spectrally normalised weight ``W_()``.

    Takes the ``nn.Linear`` arguments plus the ``SN`` settings ``num_svs``,
    ``num_itrs`` and ``eps``.
    """

    def __init__(self, in_features, out_features, bias=True,
                 num_svs=1, num_itrs=1, eps=1e-12):
        nn.Linear.__init__(self, in_features, out_features, bias=bias)
        # The u vectors have one entry per output feature.
        SN.__init__(self, num_svs, num_itrs, out_features, eps=eps)

    def forward(self, x):
        normalized_weight = self.W_()
        return F.linear(x, normalized_weight, self.bias)
class SNEmbedding(nn.Embedding, SN):
    """Embedding table that looks rows up in the spectrally normalised weight.

    The ``nn.Embedding`` options are forwarded to the base constructor.
    ``forward`` itself passes only the indices and ``W_()`` to
    ``F.embedding``, so ``padding_idx``, ``max_norm`` and the other lookup
    options are not applied there.
    """

    def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
                 max_norm=None, norm_type=2, scale_grad_by_freq=False,
                 sparse=False, _weight=None,
                 num_svs=1, num_itrs=1, eps=1e-12):
        nn.Embedding.__init__(self, num_embeddings, embedding_dim,
                              padding_idx=padding_idx, max_norm=max_norm,
                              norm_type=norm_type,
                              scale_grad_by_freq=scale_grad_by_freq,
                              sparse=sparse, _weight=_weight)
        # The u vectors have one entry per embedding row.
        SN.__init__(self, num_svs, num_itrs, num_embeddings, eps=eps)

    def forward(self, x):
        normalized_weight = self.W_()
        return F.embedding(x, normalized_weight)
class Attention(nn.Module):
    """Spatial self-attention block with a learned residual gate.

    Args:
        ch: number of input (and output) channels.
        which_conv: factory used to build the four 1x1 convolutions.
        name: accepted but not used by this class.

    The output is ``gamma * attention(x) + x``; ``gamma`` starts at zero,
    so the block initially returns its input unchanged.
    """

    def __init__(self, ch, which_conv=SNConv2d, name='attention'):
        super(Attention, self).__init__()
        self.ch = ch
        self.which_conv = which_conv
        # All projections are bias-free 1x1 convolutions.
        conv_1x1 = dict(kernel_size=1, padding=0, bias=False)
        self.theta = self.which_conv(self.ch, self.ch // 8, **conv_1x1)
        self.phi = self.which_conv(self.ch, self.ch // 8, **conv_1x1)
        self.g = self.which_conv(self.ch, self.ch // 2, **conv_1x1)
        self.o = self.which_conv(self.ch // 2, self.ch, **conv_1x1)
        # Learnable scalar gate on the attention branch.
        self.gamma = P(torch.tensor(0.), requires_grad=True)

    def forward(self, x, y=None):
        """Apply attention to ``x``; ``y`` is accepted but ignored."""
        height, width = x.shape[2], x.shape[3]
        n_pix = height * width

        query = self.theta(x)
        # phi and g are 2x2 max-pooled, leaving a quarter of the positions.
        key = F.max_pool2d(self.phi(x), [2, 2])
        value = F.max_pool2d(self.g(x), [2, 2])

        # Flatten the spatial dimensions.
        query = query.view(-1, self.ch // 8, n_pix)
        key = key.view(-1, self.ch // 8, n_pix // 4)
        value = value.view(-1, self.ch // 2, n_pix // 4)

        # Softmax over the pooled positions for every full-resolution position.
        beta = F.softmax(torch.bmm(query.transpose(1, 2), key), -1)
        attended = torch.bmm(value, beta.transpose(1, 2))
        o = self.o(attended.view(-1, self.ch // 2, height, width))
        return self.gamma * o + x
def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5):
    """Normalise ``x`` with the given statistics using one multiply and one subtract.

    Computes ``(x - mean) * rsqrt(var + eps) * gain + bias``, folded into
    ``x * scale - shift``. ``gain`` and ``bias`` are skipped when ``None``.
    """
    inv_std = torch.rsqrt(var + eps)
    multiplier = inv_std if gain is None else inv_std * gain
    # Fold the mean (and the optional bias) into a single offset.
    offset = mean * multiplier
    if bias is not None:
        offset = offset - bias
    return x * multiplier - offset
def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5):
    """Batch-normalise ``x`` using statistics computed from ``x`` itself.

    The mean and variance are taken over dimensions 0, 2 and 3 in float32
    and then cast back to the type of ``x``.

    Returns:
        The normalised tensor, or ``(normalised, mean, var)`` with the
        statistics squeezed when ``return_mean_var`` is true.
    """
    x32 = x.float()
    reduce_dims = [0, 2, 3]
    mean = torch.mean(x32, reduce_dims, keepdim=True)
    mean_sq = torch.mean(x32 ** 2, reduce_dims, keepdim=True)
    # Var[x] = E[x^2] - E[x]^2, cast back to the input's tensor type.
    var = (mean_sq - mean ** 2).type(x.type())
    mean = mean.type(x.type())
    normalized = fused_bn(x, mean, var, gain, bias, eps)
    if not return_mean_var:
        return normalized
    return normalized, mean.squeeze(), var.squeeze()
class myBN(nn.Module):
    """Batch norm with externally supplied gain/bias and optional standing stats.

    Args:
        num_channels: size of the stored mean/variance buffers.
        eps: epsilon added to the variance.
        momentum: weight of the new batch statistics in the running average.

    When ``accumulate_standing`` is set, batch statistics are summed into
    the buffers while training and divided by ``accumulation_counter`` at
    evaluation time; otherwise the buffers hold exponential running averages.
    """

    def __init__(self, num_channels, eps=1e-5, momentum=0.1):
        super(myBN, self).__init__()
        self.momentum = momentum
        self.eps = eps
        self.register_buffer('stored_mean', torch.zeros(num_channels))
        self.register_buffer('stored_var', torch.ones(num_channels))
        self.register_buffer('accumulation_counter', torch.zeros(1))
        # Toggled externally to switch to standing-statistics accumulation.
        self.accumulate_standing = False

    def reset_stats(self):
        """Zero all buffers (both statistics are sums in standing-stat mode)."""
        self.stored_mean[:] = 0
        self.stored_var[:] = 0
        self.accumulation_counter[:] = 0

    def forward(self, x, gain, bias):
        """Normalise ``x`` and apply ``gain``/``bias``.

        In training mode the batch statistics are used and folded into the
        buffers; in eval mode the stored statistics are used.
        """
        if self.training:
            out, mean, var = manual_bn(x, gain, bias, return_mean_var=True, eps=self.eps)
            # Buffer updates must not be recorded by autograd: writing a
            # graph-attached tensor into a buffer in place would turn the
            # buffer into a non-leaf tensor that requires grad and keeps
            # chaining onto earlier iterations.
            with torch.no_grad():
                if self.accumulate_standing:
                    self.stored_mean[:] = self.stored_mean + mean
                    self.stored_var[:] = self.stored_var + var
                    self.accumulation_counter += 1.0
                else:
                    self.stored_mean[:] = self.stored_mean * (1 - self.momentum) + mean * self.momentum
                    self.stored_var[:] = self.stored_var * (1 - self.momentum) + var * self.momentum
            return out
        else:
            mean = self.stored_mean.view(1, -1, 1, 1)
            var = self.stored_var.view(1, -1, 1, 1)
            if self.accumulate_standing:
                # Turn the accumulated sums into averages.
                mean = mean / self.accumulation_counter
                var = var / self.accumulation_counter
            return fused_bn(x, mean, var, gain, bias, self.eps)
def groupnorm(x, norm_style):
    """Apply group normalisation with a group count derived from ``norm_style``.

    * contains ``'ch'``: the number after the last ``_`` is the channels
      per group; the group count is ``channels // that``, at least 1.
    * contains ``'grp'``: the number after the last ``_`` is the group count.
    * anything else: 16 groups.
    """
    if 'ch' in norm_style:
        channels_per_group = int(norm_style.split('_')[-1])
        return F.group_norm(x, max(int(x.shape[1]) // channels_per_group, 1))
    if 'grp' in norm_style:
        return F.group_norm(x, int(norm_style.split('_')[-1]))
    return F.group_norm(x, 16)
class ccbn(nn.Module):
    """Class-conditional normalisation layer.

    The per-sample gain and bias are produced from the conditioning vector
    ``y`` by two ``which_linear`` layers and applied after normalising ``x``.

    Args:
        output_size: number of channels being normalised.
        input_size: size of the conditioning vector.
        which_linear: factory called as ``which_linear(input_size, output_size)``.
        eps: epsilon for the normalisation.
        momentum: running-statistics momentum.
        cross_replica: use ``SyncBN2d``.
        mybn: use ``myBN`` (only consulted when ``cross_replica`` is false).
        norm_style: ``'bn'``, ``'in'``, ``'gn'`` or ``'nonorm'``; used when
            neither ``cross_replica`` nor ``mybn`` is set.
    """

    def __init__(self, output_size, input_size, which_linear, eps=1e-5, momentum=0.1,
                 cross_replica=False, mybn=False, norm_style='bn', ):
        super(ccbn, self).__init__()
        self.output_size, self.input_size = output_size, input_size
        # Linear maps from the conditioning vector to gain and bias.
        self.gain = which_linear(input_size, output_size)
        self.bias = which_linear(input_size, output_size)
        self.eps = eps
        self.momentum = momentum
        self.cross_replica = cross_replica
        self.mybn = mybn
        self.norm_style = norm_style

        if self.cross_replica:
            self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False)
        elif self.mybn:
            self.bn = myBN(output_size, self.eps, self.momentum)
        elif self.norm_style in ['bn', 'in']:
            # Running statistics for the functional batch/instance norm calls.
            self.register_buffer('stored_mean', torch.zeros(output_size))
            self.register_buffer('stored_var', torch.ones(output_size))

    def forward(self, x, y):
        """Normalise ``x`` and modulate it with a gain/bias computed from ``y``.

        Raises:
            ValueError: if ``norm_style`` is not one of the supported values
                (only checked when neither ``mybn`` nor ``cross_replica`` is set).
        """
        # The gain is centred on 1 so a zero linear output leaves x unscaled.
        gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
        bias = self.bias(y).view(y.size(0), -1, 1, 1)
        if self.mybn or self.cross_replica:
            return self.bn(x, gain=gain, bias=bias)

        if self.norm_style == 'bn':
            # Use the configured momentum, as the plain `bn` layer does.
            out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None,
                               self.training, self.momentum, self.eps)
        elif self.norm_style == 'in':
            out = F.instance_norm(x, self.stored_mean, self.stored_var, None, None,
                                  self.training, self.momentum, self.eps)
        elif self.norm_style == 'gn':
            out = groupnorm(x, self.norm_style)
        elif self.norm_style == 'nonorm':
            out = x
        else:
            raise ValueError('Unsupported norm_style: %r' % (self.norm_style,))
        return out * gain + bias

    def extra_repr(self):
        s = 'out: {output_size}, in: {input_size},'
        s += ' cross_replica={cross_replica}'
        return s.format(**self.__dict__)
class bn(nn.Module):
    """Unconditional batch norm with a learnable per-channel gain and bias.

    Args:
        output_size: number of channels.
        eps: epsilon for the normalisation.
        momentum: running-statistics momentum.
        cross_replica: use ``SyncBN2d`` as the backend.
        mybn: use ``myBN`` as the backend (only when ``cross_replica`` is false).

    With neither flag set, ``F.batch_norm`` is used with buffers owned by
    this module.
    """

    def __init__(self, output_size, eps=1e-5, momentum=0.1,
                 cross_replica=False, mybn=False):
        super(bn, self).__init__()
        self.output_size = output_size
        # Affine parameters start as the identity transform.
        self.gain = P(torch.ones(output_size), requires_grad=True)
        self.bias = P(torch.zeros(output_size), requires_grad=True)
        self.eps, self.momentum = eps, momentum
        self.cross_replica, self.mybn = cross_replica, mybn

        if self.cross_replica:
            self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False)
        elif mybn:
            self.bn = myBN(output_size, self.eps, self.momentum)
        else:
            # Running statistics for the functional batch norm call.
            self.register_buffer('stored_mean', torch.zeros(output_size))
            self.register_buffer('stored_var', torch.ones(output_size))

    def forward(self, x, y=None):
        """Normalise ``x``; ``y`` is accepted but ignored."""
        uses_custom_backend = self.cross_replica or self.mybn
        if not uses_custom_backend:
            return F.batch_norm(x, self.stored_mean, self.stored_var, self.gain,
                                self.bias, self.training, self.momentum, self.eps)
        # The custom backends take broadcastable gain/bias tensors.
        gain = self.gain.view(1, -1, 1, 1)
        bias = self.bias.view(1, -1, 1, 1)
        return self.bn(x, gain=gain, bias=bias)
class GBlock(nn.Module):
    """Generator residual block: norm -> activation -> (upsample) -> conv, twice.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output.
        which_conv: convolution factory. It is called with only
            ``(in, out)`` for the main path, so its kernel size and padding
            must already be bound (e.g. via ``functools.partial``).
        which_bn: normalisation factory called with the channel count; the
            resulting layer is called as ``layer(x, y)``.
        activation: callable applied after each normalisation.
        upsample: optional callable applied to both the main path and the
            shortcut.
    """

    def __init__(self, in_channels, out_channels,
                 which_conv=nn.Conv2d, which_bn=bn, activation=None,
                 upsample=None):
        super(GBlock, self).__init__()

        self.in_channels, self.out_channels = in_channels, out_channels
        self.which_conv, self.which_bn = which_conv, which_bn
        self.activation = activation
        self.upsample = upsample
        # Main-path convolutions.
        self.conv1 = self.which_conv(self.in_channels, self.out_channels)
        self.conv2 = self.which_conv(self.out_channels, self.out_channels)
        # Stored as a strict bool (as in DBlock): keeping the raw `upsample`
        # object here would register it as a second submodule whenever it
        # is an nn.Module.
        self.learnable_sc = bool(in_channels != out_channels or upsample)
        if self.learnable_sc:
            self.conv_sc = self.which_conv(in_channels, out_channels,
                                           kernel_size=1, padding=0)
        # Normalisation layers.
        self.bn1 = self.which_bn(in_channels)
        self.bn2 = self.which_bn(out_channels)

    def forward(self, x, y):
        """Run the block on ``x`` with conditioning ``y`` (passed to both norms)."""
        h = self.activation(self.bn1(x, y))
        if self.upsample:
            # Upsample both branches so they can be summed at the end.
            h = self.upsample(h)
            x = self.upsample(x)
        h = self.conv1(h)
        h = self.activation(self.bn2(h, y))
        h = self.conv2(h)
        if self.learnable_sc:
            x = self.conv_sc(x)
        return h + x
class DBlock(nn.Module):
    """Discriminator residual block with an optional downsampling step.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output.
        which_conv: convolution factory. It is called with only
            ``(in, out)`` for the main path, so its kernel size must
            already be bound (e.g. via ``functools.partial``).
        wide: if true the hidden width equals ``out_channels``, otherwise
            ``in_channels``.
        preactivation: apply a ReLU to the input before the first conv and
            run the shortcut conv before downsampling.
        activation: callable applied between the two convolutions.
        downsample: optional callable applied to the main path and shortcut.
    """

    def __init__(self, in_channels, out_channels, which_conv=SNConv2d, wide=True,
                 preactivation=False, activation=None, downsample=None, ):
        super(DBlock, self).__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        if wide:
            self.hidden_channels = self.out_channels
        else:
            self.hidden_channels = self.in_channels
        self.which_conv = which_conv
        self.preactivation = preactivation
        self.activation = activation
        self.downsample = downsample

        # Main-path convolutions.
        self.conv1 = self.which_conv(self.in_channels, self.hidden_channels)
        self.conv2 = self.which_conv(self.hidden_channels, self.out_channels)
        # A 1x1 shortcut conv is used when the shape changes in any way.
        self.learnable_sc = bool((in_channels != out_channels) or downsample)
        if self.learnable_sc:
            self.conv_sc = self.which_conv(in_channels, out_channels,
                                           kernel_size=1, padding=0)

    def shortcut(self, x):
        """Apply the shortcut path to ``x``.

        With ``preactivation`` the 1x1 conv runs before the downsample;
        otherwise the downsample runs first.
        """
        steps = []
        if self.learnable_sc:
            steps.append(self.conv_sc)
        if self.downsample:
            steps.append(self.downsample)
        if not self.preactivation:
            steps.reverse()
        for step in steps:
            x = step(x)
        return x

    def forward(self, x):
        # F.relu is out-of-place, so `x` stays intact for the shortcut.
        h = F.relu(x) if self.preactivation else x
        h = self.conv1(h)
        h = self.activation(h)
        h = self.conv2(h)
        if self.downsample:
            h = self.downsample(h)
        return h + self.shortcut(x)