import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn import init
import layers
class GBlock(nn.Module):
    """Conditional bottleneck residual block used by the generator.

    Main path: bn1 -> act -> 1x1 conv down to ``in_channels // channel_ratio``
    -> bn2 -> act -> [upsample] -> conv2 -> bn3 -> act -> conv3 -> bn4 -> act
    -> 1x1 conv up to ``out_channels``. conv2/conv3 use ``which_conv``'s own
    default kernel settings. Every BN is called as ``bn(h, y)`` with the
    conditioning input ``y``.

    Skip path: when the channel count changes, only the leading
    ``out_channels`` channels of the input are kept (so this assumes
    ``out_channels <= in_channels``); it is upsampled together with the main
    path.
    """

    def __init__(self, in_channels, out_channels,
                 which_conv=nn.Conv2d, which_bn=layers.bn, activation=None,
                 upsample=None, channel_ratio=4):
        super(GBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = in_channels // channel_ratio
        self.which_conv = which_conv
        self.which_bn = which_bn
        self.activation = activation
        width = self.hidden_channels
        # Attribute assignment order fixes submodule/parameter order; keep
        # convs first, then BNs, then the upsampler.
        self.conv1 = which_conv(in_channels, width, kernel_size=1, padding=0)
        self.conv2 = which_conv(width, width)
        self.conv3 = which_conv(width, width)
        self.conv4 = which_conv(width, out_channels, kernel_size=1, padding=0)
        self.bn1 = which_bn(in_channels)
        self.bn2 = which_bn(width)
        self.bn3 = which_bn(width)
        self.bn4 = which_bn(width)
        self.upsample = upsample

    def forward(self, x, y):
        """Apply the block to features ``x`` conditioned on ``y``."""
        act = self.activation
        # Project down to the bottleneck width, then BN-activate again.
        h = act(self.bn2(self.conv1(act(self.bn1(x, y))), y))
        # Skip connection: drop trailing channels if the width shrinks.
        skip = x if self.in_channels == self.out_channels else x[:, :self.out_channels]
        if self.upsample:
            h, skip = self.upsample(h), self.upsample(skip)
        h = self.conv2(h)
        h = self.conv3(act(self.bn3(h, y)))
        # Final 1x1 conv back up to out_channels.
        h = self.conv4(act(self.bn4(h, y)))
        return h + skip
def G_arch(ch=64, attention='64', ksize='333333', dilation='111111'):
    """Return the generator stage layouts, keyed by output resolution.

    Args:
        ch: Base channel multiplier; every channel count is a multiple of it.
        attention: Underscore-separated resolutions at which self-attention
            is requested, e.g. ``'64'`` or ``'32_64'``.
        ksize: Unused; kept for signature compatibility.
        dilation: Unused; kept for signature compatibility.

    Returns:
        Dict mapping 256, 128, 64 and 32 to a layout dict with keys
        ``'in_channels'`` and ``'out_channels'`` (per-stage channel counts),
        ``'upsample'`` (per-stage flags, all True), ``'resolution'``
        (per-stage resolutions 8, 16, 32, ...) and ``'attention'`` (dict
        mapping each of those resolutions to a bool).

    Raises:
        ValueError: if an underscore-separated entry of ``attention`` is not
            an integer literal.
    """
    # Parse the attention spec once instead of once per candidate resolution.
    attn_resolutions = {int(item) for item in attention.split('_')}

    def layout(in_mults, out_mults):
        # Stage i runs at resolution 2 ** (i + 3): 8, 16, 32, ...
        num_stages = len(in_mults)
        return {'in_channels': [ch * item for item in in_mults],
                'out_channels': [ch * item for item in out_mults],
                'upsample': [True] * num_stages,
                'resolution': [2 ** (i + 3) for i in range(num_stages)],
                'attention': {2 ** i: 2 ** i in attn_resolutions
                              for i in range(3, 3 + num_stages)}}

    return {256: layout([16, 16, 8, 8, 4, 2], [16, 8, 8, 4, 2, 1]),
            128: layout([16, 16, 8, 4, 2], [16, 8, 4, 2, 1]),
            64: layout([16, 16, 8, 4], [16, 8, 4, 2]),
            32: layout([4, 4, 4], [4, 4, 4])}
class Generator(nn.Module):
    """BigGAN-deep style conditional generator.

    The forward pass maps a latent ``z`` and a conditioning vector ``y`` to an
    image via: ``self.linear`` reshaped to ``bottom_width x bottom_width``,
    the GBlock stages of ``G_arch(G_ch, G_attn)[resolution]`` (``G_depth``
    GBlocks per stage, optionally followed by a self-attention layer), and a
    BN -> activation -> conv-to-3-channels head followed by ``tanh``.

    Unless ``no_optim`` is set, the constructor also creates ``self.optim``
    (``optim.Adam``, or ``utils.Adam16`` when ``G_mixed_precision``).
    """

    def __init__(self, G_ch=64, G_depth=2, dim_z=128, bottom_width=4, resolution=128,
                 G_kernel_size=3, G_attn='64', n_classes=1000,
                 num_G_SVs=1, num_G_SV_itrs=1,
                 G_shared=True, shared_dim=0, hier=False,
                 cross_replica=False, mybn=False,
                 G_activation=nn.ReLU(inplace=False),
                 G_lr=5e-5, G_B1=0.0, G_B2=0.999, adam_eps=1e-8,
                 BN_eps=1e-5, SN_eps=1e-12, G_mixed_precision=False, G_fp16=False,
                 G_init='ortho', skip_init=False, no_optim=False,
                 G_param='SN', norm_style='bn',
                 **kwargs):
        super(Generator, self).__init__()
        self.ch = G_ch
        self.G_depth = G_depth
        self.dim_z = dim_z
        self.bottom_width = bottom_width
        self.resolution = resolution
        # Stored only: the conv factories below hard-code kernel_size=3.
        self.kernel_size = G_kernel_size
        self.attention = G_attn
        self.n_classes = n_classes
        self.G_shared = G_shared
        # A non-positive shared_dim means "same size as z".
        self.shared_dim = shared_dim if shared_dim > 0 else dim_z
        self.hier = hier
        self.cross_replica = cross_replica
        self.mybn = mybn
        self.activation = G_activation
        self.init = G_init
        self.G_param = G_param
        self.norm_style = norm_style
        self.BN_eps = BN_eps
        self.SN_eps = SN_eps
        self.fp16 = G_fp16
        self.arch = G_arch(self.ch, self.attention)[resolution]

        # Layer factories: spectrally normalised layers for G_param == 'SN',
        # plain torch layers otherwise.
        if self.G_param == 'SN':
            self.which_conv = functools.partial(layers.SNConv2d,
                                                kernel_size=3, padding=1,
                                                num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                                                eps=self.SN_eps)
            self.which_linear = functools.partial(layers.SNLinear,
                                                  num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                                                  eps=self.SN_eps)
        else:
            self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
            self.which_linear = nn.Linear
        # The class embedding is a plain nn.Embedding regardless of G_param.
        self.which_embedding = nn.Embedding
        # Conditional BN projects its conditioning input with a bias-free
        # linear layer when G_shared, otherwise with an embedding lookup.
        bn_linear = (functools.partial(self.which_linear, bias=False) if self.G_shared
                     else self.which_embedding)
        self.which_bn = functools.partial(layers.ccbn,
                                          which_linear=bn_linear,
                                          cross_replica=self.cross_replica,
                                          mybn=self.mybn,
                                          input_size=(self.shared_dim + self.dim_z if self.G_shared
                                                      else self.n_classes),
                                          norm_style=self.norm_style,
                                          eps=self.BN_eps)

        # Without G_shared, self.shared is layers.identity() instead of an embedding.
        self.shared = (self.which_embedding(n_classes, self.shared_dim) if G_shared
                       else layers.identity())
        self.linear = self.which_linear(self.dim_z + self.shared_dim,
                                        self.arch['in_channels'][0] * (self.bottom_width ** 2))

        # self.blocks is a list of block lists. Each stage contributes G_depth
        # single-GBlock lists. Every GBlock in a stage except the last keeps the
        # stage's input width. The last one switches to the stage's output
        # width and upsamples (if requested), so consecutive blocks always agree
        # on channel count for any G_depth >= 1 (identical to the old plan for
        # G_depth == 2). A requested attention layer is appended to the
        # stage's last list.
        last = self.G_depth - 1
        self.blocks = []
        for index in range(len(self.arch['out_channels'])):
            stage_in = self.arch['in_channels'][index]
            stage_out = self.arch['out_channels'][index]
            self.blocks += [[GBlock(in_channels=stage_in,
                                    out_channels=stage_out if g_index == last else stage_in,
                                    which_conv=self.which_conv,
                                    which_bn=self.which_bn,
                                    activation=self.activation,
                                    upsample=(functools.partial(F.interpolate, scale_factor=2)
                                              if self.arch['upsample'][index] and g_index == last
                                              else None))]
                            for g_index in range(self.G_depth)]
            if self.arch['attention'][self.arch['resolution'][index]]:
                print('Adding attention layer in G at resolution %d' % self.arch['resolution'][index])
                self.blocks[-1] += [layers.Attention(stage_out, self.which_conv)]
        # Wrap in ModuleLists so every submodule is registered.
        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

        # Output head: BN -> activation -> conv to 3 channels.
        self.output_layer = nn.Sequential(layers.bn(self.arch['out_channels'][-1],
                                                    cross_replica=self.cross_replica,
                                                    mybn=self.mybn),
                                          self.activation,
                                          self.which_conv(self.arch['out_channels'][-1], 3))

        if not skip_init:
            self.init_weights()

        # With no_optim, no optimizer is attached.
        if no_optim:
            return
        self.lr, self.B1, self.B2, self.adam_eps = G_lr, G_B1, G_B2, adam_eps
        if G_mixed_precision:
            print('Using fp16 adam in G...')
            import utils
            self.optim = utils.Adam16(params=self.parameters(), lr=self.lr,
                                      betas=(self.B1, self.B2), weight_decay=0,
                                      eps=self.adam_eps)
        else:
            self.optim = optim.Adam(params=self.parameters(), lr=self.lr,
                                    betas=(self.B1, self.B2), weight_decay=0,
                                    eps=self.adam_eps)

    def init_weights(self):
        """Initialise Conv2d/Linear/Embedding weights according to ``self.init``.

        'ortho' -> orthogonal, 'N02' -> normal(0, 0.02), 'glorot'/'xavier' ->
        Xavier uniform. Any other value prints a warning and leaves the weight
        untouched. Biases are not re-initialised. Sets ``self.param_count`` to
        the total number of parameters in those modules and prints it.
        """
        self.param_count = 0
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear, nn.Embedding)):
                if self.init == 'ortho':
                    init.orthogonal_(module.weight)
                elif self.init == 'N02':
                    init.normal_(module.weight, 0, 0.02)
                elif self.init in ['glorot', 'xavier']:
                    init.xavier_uniform_(module.weight)
                else:
                    print('Init style not recognized...')
                self.param_count += sum(p.data.nelement() for p in module.parameters())
        # Double quotes keep the apostrophe; 'G''s' concatenated to "Gs".
        print("Param count for G's initialized parameters: %d" % self.param_count)

    def forward(self, z, y):
        """Generate a batch of images in [-1, 1] from ``z`` and ``y``.

        With ``hier`` set, ``torch.cat([y, z], 1)`` is fed to ``self.linear``
        and is also the conditioning vector passed to every block.
        NOTE(review): with ``hier=False``, ``z`` goes to ``self.linear``
        unchanged although that layer expects ``dim_z + shared_dim`` inputs,
        so ``hier=True`` looks effectively required. Confirm before relying on it.
        """
        if self.hier:
            z = torch.cat([y, z], 1)
            y = z
        h = self.linear(z)
        h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width)
        for blocklist in self.blocks:
            for block in blocklist:
                h = block(h, y)
        return torch.tanh(self.output_layer(h))
class DBlock(nn.Module):
    """Bottleneck residual block used by the discriminator.

    Main path: ReLU -> 1x1 conv down to ``out_channels // channel_ratio`` ->
    act -> conv2 -> act -> conv3 -> act -> [downsample] -> 1x1 conv up to
    ``out_channels``. conv2/conv3 use ``which_conv``'s own default kernel
    settings.

    Skip path: [downsample], then, if the channel count changes, concatenate
    a learned 1x1 projection that supplies the ``out_channels - in_channels``
    extra channels.

    ``wide`` is accepted for interface compatibility but unused.
    ``preactivation`` is stored but does not affect the computation.
    """

    def __init__(self, in_channels, out_channels, which_conv=layers.SNConv2d, wide=True,
                 preactivation=True, activation=None, downsample=None,
                 channel_ratio=4):
        super(DBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Bottleneck width is derived from the output width.
        self.hidden_channels = out_channels // channel_ratio
        self.which_conv = which_conv
        self.preactivation = preactivation
        self.activation = activation
        self.downsample = downsample
        mid = self.hidden_channels
        self.conv1 = which_conv(in_channels, mid, kernel_size=1, padding=0)
        self.conv2 = which_conv(mid, mid)
        self.conv3 = which_conv(mid, mid)
        self.conv4 = which_conv(mid, out_channels, kernel_size=1, padding=0)
        # A learned shortcut is only needed when the channel count changes.
        self.learnable_sc = bool(in_channels != out_channels)
        if self.learnable_sc:
            # Produces only the extra channels; they are concatenated onto x.
            self.conv_sc = which_conv(in_channels, out_channels - in_channels,
                                      kernel_size=1, padding=0)

    def shortcut(self, x):
        """Return the skip-path tensor for input ``x``."""
        pooled = self.downsample(x) if self.downsample else x
        if not self.learnable_sc:
            return pooled
        return torch.cat([pooled, self.conv_sc(pooled)], 1)

    def forward(self, x):
        """Apply the block to features ``x``."""
        act = self.activation
        # F.relu (not self.activation) on the input, presumably so that an
        # in-place activation cannot overwrite x, which shortcut() still reads.
        out = self.conv1(F.relu(x))
        out = self.conv3(act(self.conv2(act(out))))
        # Activate before pooling, then the final 1x1 conv.
        out = act(out)
        if self.downsample:
            out = self.downsample(out)
        return self.conv4(out) + self.shortcut(x)
def D_arch(ch=64, attention='64', ksize='333333', dilation='111111'):
    """Return the discriminator stage layouts, keyed by input resolution.

    Args:
        ch: Base channel multiplier; every channel count is a multiple of it.
        attention: Underscore-separated resolutions at which self-attention
            is requested, e.g. ``'64'`` or ``'32_64'``.
        ksize: Unused; kept for signature compatibility.
        dilation: Unused; kept for signature compatibility.

    Returns:
        Dict mapping 256, 128, 64 and 32 to a layout dict with keys
        ``'in_channels'``, ``'out_channels'``, ``'downsample'``,
        ``'resolution'`` and ``'attention'`` (dict mapping 4, 8, ... to a
        bool). The ``'downsample'`` and ``'resolution'`` lists have one more
        entry than the channel lists. That extra trailing entry is kept as-is.

    Raises:
        ValueError: if an underscore-separated entry of ``attention`` is not
            an integer literal.
    """
    # Parse the attention spec once instead of once per candidate resolution.
    attn_resolutions = {int(item) for item in attention.split('_')}

    def layout(in_mults, out_mults, downsample, resolution, attn_stop):
        return {'in_channels': [item * ch for item in in_mults],
                'out_channels': [item * ch for item in out_mults],
                'downsample': downsample,
                'resolution': resolution,
                'attention': {2 ** i: 2 ** i in attn_resolutions
                              for i in range(2, attn_stop)}}

    return {256: layout([1, 2, 4, 8, 8, 16], [2, 4, 8, 8, 16, 16],
                        [True] * 6 + [False], [128, 64, 32, 16, 8, 4, 4], 8),
            128: layout([1, 2, 4, 8, 16], [2, 4, 8, 16, 16],
                        [True] * 5 + [False], [64, 32, 16, 8, 4, 4], 8),
            64: layout([1, 2, 4, 8], [2, 4, 8, 16],
                       [True] * 4 + [False], [32, 16, 8, 4, 4], 7),
            32: layout([4, 4, 4], [4, 4, 4],
                       [True, True, False, False], [16, 16, 16, 16], 6)}
class Discriminator(nn.Module):
    """BigGAN-deep style projection discriminator.

    A stem conv maps the 3-channel input to ``in_channels[0]`` channels. Then
    come the DBlock stages of ``D_arch(D_ch, D_attn)[resolution]``
    (``D_depth`` DBlocks per stage, optionally followed by self-attention).
    The features are activated and summed over the two spatial dims. The
    output is ``self.linear(h)`` plus, when labels ``y`` are given, the
    projection term ``sum(self.embed(y) * h)``.

    The constructor also creates ``self.optim`` (``optim.Adam``, or
    ``utils.Adam16`` when ``D_mixed_precision``).
    """

    def __init__(self, D_ch=64, D_wide=True, D_depth=2, resolution=128,
                 D_kernel_size=3, D_attn='64', n_classes=1000,
                 num_D_SVs=1, num_D_SV_itrs=1, D_activation=nn.ReLU(inplace=False),
                 D_lr=2e-4, D_B1=0.0, D_B2=0.999, adam_eps=1e-8,
                 SN_eps=1e-12, output_dim=1, D_mixed_precision=False, D_fp16=False,
                 D_init='ortho', skip_init=False, D_param='SN', **kwargs):
        super(Discriminator, self).__init__()
        self.ch = D_ch
        # Passed to DBlock as `wide`.
        self.D_wide = D_wide
        self.D_depth = D_depth
        self.resolution = resolution
        # Stored only: the conv factories below hard-code kernel_size=3.
        self.kernel_size = D_kernel_size
        self.attention = D_attn
        self.n_classes = n_classes
        self.activation = D_activation
        self.init = D_init
        self.D_param = D_param
        self.SN_eps = SN_eps
        self.fp16 = D_fp16
        self.arch = D_arch(self.ch, self.attention)[resolution]

        # Layer factories: spectrally normalised layers for D_param == 'SN',
        # plain torch layers otherwise (mirrors Generator's non-SN branch;
        # without it these attributes were undefined and __init__ crashed).
        if self.D_param == 'SN':
            self.which_conv = functools.partial(layers.SNConv2d,
                                                kernel_size=3, padding=1,
                                                num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                                                eps=self.SN_eps)
            self.which_linear = functools.partial(layers.SNLinear,
                                                  num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                                                  eps=self.SN_eps)
            self.which_embedding = functools.partial(layers.SNEmbedding,
                                                     num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                                                     eps=self.SN_eps)
        else:
            self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
            self.which_linear = nn.Linear
            self.which_embedding = nn.Embedding

        # Stem conv from 3 input channels.
        self.input_conv = self.which_conv(3, self.arch['in_channels'][0])
        # One list per stage. The first DBlock changes width and downsamples;
        # the rest map out_channels -> out_channels. A requested attention
        # layer is appended to the stage's list.
        self.blocks = []
        for index in range(len(self.arch['out_channels'])):
            self.blocks += [[DBlock(
                in_channels=self.arch['in_channels'][index] if d_index == 0 else self.arch['out_channels'][index],
                out_channels=self.arch['out_channels'][index],
                which_conv=self.which_conv,
                wide=self.D_wide,
                activation=self.activation,
                preactivation=True,
                downsample=(nn.AvgPool2d(2) if self.arch['downsample'][index] and d_index == 0 else None))
                for d_index in range(self.D_depth)]]
            if self.arch['attention'][self.arch['resolution'][index]]:
                print('Adding attention layer in D at resolution %d' % self.arch['resolution'][index])
                self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index],
                                                     self.which_conv)]
        # Wrap in ModuleLists so every submodule is registered.
        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
        # Unconditional score head and class embedding for the projection term.
        self.linear = self.which_linear(self.arch['out_channels'][-1], output_dim)
        self.embed = self.which_embedding(self.n_classes, self.arch['out_channels'][-1])

        if not skip_init:
            self.init_weights()

        self.lr, self.B1, self.B2, self.adam_eps = D_lr, D_B1, D_B2, adam_eps
        if D_mixed_precision:
            print('Using fp16 adam in D...')
            import utils
            self.optim = utils.Adam16(params=self.parameters(), lr=self.lr,
                                      betas=(self.B1, self.B2), weight_decay=0, eps=self.adam_eps)
        else:
            self.optim = optim.Adam(params=self.parameters(), lr=self.lr,
                                    betas=(self.B1, self.B2), weight_decay=0, eps=self.adam_eps)

    def init_weights(self):
        """Initialise Conv2d/Linear/Embedding weights according to ``self.init``.

        'ortho' -> orthogonal, 'N02' -> normal(0, 0.02), 'glorot'/'xavier' ->
        Xavier uniform. Any other value prints a warning and leaves the weight
        untouched. Biases are not re-initialised. Sets ``self.param_count`` to
        the total number of parameters in those modules and prints it.
        """
        self.param_count = 0
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear, nn.Embedding)):
                if self.init == 'ortho':
                    init.orthogonal_(module.weight)
                elif self.init == 'N02':
                    init.normal_(module.weight, 0, 0.02)
                elif self.init in ['glorot', 'xavier']:
                    init.xavier_uniform_(module.weight)
                else:
                    print('Init style not recognized...')
                self.param_count += sum(p.data.nelement() for p in module.parameters())
        # Double quotes keep the apostrophe; 'D''s' concatenated to "Ds".
        print("Param count for D's initialized parameters: %d" % self.param_count)

    def forward(self, x, y=None):
        """Score images ``x``, optionally conditioned on class labels ``y``.

        Args:
            x: Image batch. The stem conv takes 3 input channels and pooling
                sums dims 2 and 3, so presumably (N, 3, H, W).
            y: Class indices for ``self.embed``, or None to return only the
                unconditional ``self.linear`` score.

        Returns:
            ``self.linear(h)``, plus the projection term when ``y`` is given.
        """
        h = self.input_conv(x)
        for blocklist in self.blocks:
            for block in blocklist:
                h = block(h)
        # Global sum pooling over the spatial dims after a final activation.
        h = torch.sum(self.activation(h), [2, 3])
        out = self.linear(h)
        if y is not None:
            out = out + torch.sum(self.embed(y) * h, 1, keepdim=True)
        return out
class G_D(nn.Module):
    """Wrap a generator and a discriminator so one call runs G, then D.

    ``G`` must expose ``shared`` (applied to the generator labels) and an
    ``fp16`` flag. ``D`` must expose an ``fp16`` flag.
    """

    def __init__(self, G, D):
        super(G_D, self).__init__()
        self.G = G
        self.D = D

    def forward(self, z, gy, x=None, dy=None, train_G=False, return_G_z=False,
                split_D=False):
        """Generate from ``(z, gy)`` and score the result (and real ``x``) with D.

        G runs with autograd enabled only when ``train_G``. Its output is cast
        to float/half when exactly one of ``G.fp16`` / ``D.fp16`` is set.

        With ``split_D``, D is called separately on fakes and reals. Otherwise
        fakes and reals are concatenated along dim 0, scored in one D call,
        and split back apart.

        Returns:
            ``(D_fake, D_real)`` when ``x`` is given. Otherwise ``D_fake``, or
            ``(D_fake, G_z)`` when ``return_G_z``.

        NOTE(review): in the concatenated path, labels are concatenated only
        when ``dy`` is given, so passing ``x`` without ``dy`` pairs a larger
        image batch with only ``gy``. Confirm callers always pass both.
        """
        with torch.set_grad_enabled(train_G):
            G_z = self.G(z, self.G.shared(gy))
            g_half, d_half = self.G.fp16, self.D.fp16
            if g_half and not d_half:
                G_z = G_z.float()
            elif d_half and not g_half:
                G_z = G_z.half()

        if split_D:
            D_fake = self.D(G_z, gy)
            if x is not None:
                return D_fake, self.D(x, dy)
            return (D_fake, G_z) if return_G_z else D_fake

        D_input = G_z if x is None else torch.cat([G_z, x], 0)
        D_class = gy if dy is None else torch.cat([gy, dy], 0)
        D_out = self.D(D_input, D_class)
        if x is not None:
            # (D_fake, D_real)
            return torch.split(D_out, [G_z.shape[0], x.shape[0]])
        return (D_out, G_z) if return_G_z else D_out