import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn import init
import layers
def G_arch(ch=64, attention='64', ksize='333333', dilation='111111'):
    """Build the generator architecture table for every supported resolution.

    Args:
        ch: base channel width; every channel count is ``ch`` times a multiplier.
        attention: underscore-separated resolutions (e.g. ``'64'`` or
            ``'32_64'``) whose block should be followed by an attention layer.
        ksize, dilation: accepted for interface compatibility; not used.

    Returns:
        A dict keyed by output resolution (512, 256, 128, 64, 32). Each value
        holds per-block lists ``in_channels``, ``out_channels``, ``upsample``
        and ``resolution``, plus an ``attention`` dict mapping each block
        resolution to whether attention is enabled there.
    """
    wanted = [int(token) for token in attention.split('_')]

    def _spec(in_mults, out_mults, last_exp):
        # Blocks run at 8, 16, ..., 2**last_exp; every block upsamples.
        sizes = [2 ** e for e in range(3, last_exp + 1)]
        spec = {}
        spec['in_channels'] = [ch * m for m in in_mults]
        spec['out_channels'] = [ch * m for m in out_mults]
        spec['upsample'] = [True] * len(sizes)
        spec['resolution'] = sizes
        spec['attention'] = {size: (size in wanted) for size in sizes}
        return spec

    table = {}
    table[512] = _spec([16, 16, 8, 8, 4, 2, 1], [16, 8, 8, 4, 2, 1, 1], 9)
    table[256] = _spec([16, 16, 8, 8, 4, 2], [16, 8, 8, 4, 2, 1], 8)
    table[128] = _spec([16, 16, 8, 4, 2], [16, 8, 4, 2, 1], 7)
    table[64] = _spec([16, 16, 8, 4], [16, 8, 4, 2], 6)
    table[32] = _spec([4, 4, 4], [4, 4, 4], 5)
    return table
class Generator(nn.Module):
    """BigGAN-style conditional generator.

    The network consists of three parts:

    - a linear stem mapping (a chunk of) ``z`` to an
      ``arch['in_channels'][0]`` x ``bottom_width`` x ``bottom_width`` map;
    - a stack of ``layers.GBlock`` stages described by
      ``G_arch(G_ch, G_attn)[resolution]``, each optionally followed by a
      ``layers.Attention`` layer;
    - an activation + conv output layer with 3 output channels, which
      ``forward`` squashes with ``tanh``.

    Args (selected):
        G_ch: base channel width passed to ``G_arch``.
        dim_z: latent size. With ``hier`` it is rounded down to a multiple of
            the number of z chunks.
        bottom_width: spatial size of the stem's output feature map.
        resolution: key into ``G_arch``; one of 32, 64, 128, 256, 512.
        G_kernel_size: stored as ``self.kernel_size`` only; the convs are
            built with a fixed 3x3 kernel.
        G_attn: underscore-separated resolutions that get an attention layer.
        n_classes: number of label classes to embed.
        G_shared: if True, labels are embedded once by ``self.shared``
            (``nn.Embedding(n_classes, shared_dim)``), and each conditional BN
            maps that vector through a bias-free linear layer. Otherwise
            ``self.shared`` is ``layers.identity()`` and each BN embeds the raw
            labels with its own ``nn.Embedding``.
        shared_dim: size of the shared embedding; ``<= 0`` means ``dim_z``.
        hier: split ``z`` into ``len(blocks) + 1`` chunks. The first chunk
            feeds the stem; each remaining chunk is concatenated onto ``y``
            for one block.
        G_param: ``'SN'`` selects ``layers.SNConv2d``/``layers.SNLinear``;
            anything else selects plain ``nn.Conv2d``/``nn.Linear``.
        G_init: ``'ortho'``, ``'N02'``, ``'glorot'`` or ``'xavier'``.
        G_fp16: recorded as ``self.fp16``; this constructor does not convert
            any parameters to half precision.
        skip_init: skip ``init_weights``.
        no_optim: do not create ``self.optim``.
        G_mixed_precision: use ``utils.Adam16`` instead of ``optim.Adam``.
    """

    def __init__(self, G_ch=64, dim_z=128, bottom_width=4, resolution=128,
                 G_kernel_size=3, G_attn='64', n_classes=1000,
                 num_G_SVs=1, num_G_SV_itrs=1,
                 G_shared=True, shared_dim=0, hier=False,
                 cross_replica=False, mybn=False,
                 G_activation=nn.ReLU(inplace=False),
                 G_lr=5e-5, G_B1=0.0, G_B2=0.999, adam_eps=1e-8,
                 BN_eps=1e-5, SN_eps=1e-12, G_mixed_precision=False, G_fp16=False,
                 G_init='ortho', skip_init=False, no_optim=False,
                 G_param='SN', norm_style='bn',
                 **kwargs):
        super(Generator, self).__init__()
        self.ch = G_ch
        self.dim_z = dim_z
        self.bottom_width = bottom_width
        self.resolution = resolution
        # NOTE(review): stored but unused -- which_conv below is fixed at 3x3.
        self.kernel_size = G_kernel_size
        self.attention = G_attn
        self.n_classes = n_classes
        self.G_shared = G_shared
        self.shared_dim = shared_dim if shared_dim > 0 else dim_z
        self.hier = hier
        self.cross_replica = cross_replica
        self.mybn = mybn
        self.activation = G_activation
        self.init = G_init
        self.G_param = G_param
        self.norm_style = norm_style
        self.BN_eps = BN_eps
        self.SN_eps = SN_eps
        self.fp16 = G_fp16
        self.arch = G_arch(self.ch, self.attention)[resolution]

        if self.hier:
            # One z chunk for the stem plus one per block; round dim_z down so
            # torch.split in forward yields equally sized chunks.
            self.num_slots = len(self.arch['in_channels']) + 1
            self.z_chunk_size = (self.dim_z // self.num_slots)
            self.dim_z = self.z_chunk_size * self.num_slots
        else:
            self.num_slots = 1
            self.z_chunk_size = 0

        if self.G_param == 'SN':
            self.which_conv = functools.partial(layers.SNConv2d,
                                                kernel_size=3, padding=1,
                                                num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                                                eps=self.SN_eps)
            self.which_linear = functools.partial(layers.SNLinear,
                                                  num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                                                  eps=self.SN_eps)
        else:
            self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
            self.which_linear = nn.Linear
        self.which_embedding = nn.Embedding

        # With a shared embedding, each BN projects the (embedding [+ z chunk])
        # vector linearly; otherwise each BN embeds the class index directly.
        bn_linear = (functools.partial(self.which_linear, bias=False) if self.G_shared
                     else self.which_embedding)
        self.which_bn = functools.partial(layers.ccbn,
                                          which_linear=bn_linear,
                                          cross_replica=self.cross_replica,
                                          mybn=self.mybn,
                                          input_size=(self.shared_dim + self.z_chunk_size if self.G_shared
                                                      else self.n_classes),
                                          norm_style=self.norm_style,
                                          eps=self.BN_eps)

        self.shared = (self.which_embedding(n_classes, self.shared_dim) if G_shared
                       else layers.identity())
        # The stem consumes only the first z chunk (all of z when not hier).
        self.linear = self.which_linear(self.dim_z // self.num_slots,
                                        self.arch['in_channels'][0] * (self.bottom_width ** 2))

        self.blocks = []
        for index in range(len(self.arch['out_channels'])):
            self.blocks += [[layers.GBlock(in_channels=self.arch['in_channels'][index],
                                           out_channels=self.arch['out_channels'][index],
                                           which_conv=self.which_conv,
                                           which_bn=self.which_bn,
                                           activation=self.activation,
                                           upsample=(functools.partial(F.interpolate, scale_factor=2)
                                                     if self.arch['upsample'][index] else None))]]
            if self.arch['attention'][self.arch['resolution'][index]]:
                print('Adding attention layer in G at resolution %d' % self.arch['resolution'][index])
                self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index], self.which_conv)]
        # Nested ModuleLists so every block's parameters are registered.
        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

        # Final activation + conv to 3 output channels; tanh is applied in forward.
        self.output_layer = nn.Sequential(self.activation,
                                          self.which_conv(self.arch['out_channels'][-1], 3))

        if not skip_init:
            self.init_weights()

        if no_optim:
            return
        self.lr, self.B1, self.B2, self.adam_eps = G_lr, G_B1, G_B2, adam_eps
        if G_mixed_precision:
            print('Using fp16 adam in G...')
            import utils
            self.optim = utils.Adam16(params=self.parameters(), lr=self.lr,
                                      betas=(self.B1, self.B2), weight_decay=0,
                                      eps=self.adam_eps)
        else:
            self.optim = optim.Adam(params=self.parameters(), lr=self.lr,
                                    betas=(self.B1, self.B2), weight_decay=0,
                                    eps=self.adam_eps)

    def init_weights(self):
        """Re-initialize the weights of every Conv2d/Linear/Embedding submodule.

        The scheme follows ``self.init``. An unrecognized scheme is reported
        and that module's weights are left untouched. The number of parameters
        in the visited modules is stored in ``self.param_count`` and printed.
        """
        self.param_count = 0
        for module in self.modules():
            if not isinstance(module, (nn.Conv2d, nn.Linear, nn.Embedding)):
                continue
            if self.init == 'ortho':
                init.orthogonal_(module.weight)
            elif self.init == 'N02':
                init.normal_(module.weight, 0, 0.02)
            elif self.init in ['glorot', 'xavier']:
                init.xavier_uniform_(module.weight)
            else:
                print('Init style not recognized...')
            self.param_count += sum(p.numel() for p in module.parameters())
        # Double quotes: the old 'G''s' literal concatenated to "Gs".
        print("Param count for G's initialized parameters: %d" % self.param_count)

    def forward(self, z, y):
        """Generate images from latents ``z`` and conditioning ``y``.

        Args:
            z: latent batch. With ``hier`` it is split along dim 1 into
                ``z_chunk_size``-wide chunks, and the first chunk feeds the stem.
            y: conditioning vector passed to every block. With ``hier``, each
                block instead receives ``y`` concatenated (dim 1) with its own
                z chunk.

        Returns:
            ``tanh`` of the output layer, so values lie in [-1, 1].
        """
        if self.hier:
            zs = torch.split(z, self.z_chunk_size, 1)
            z = zs[0]
            ys = [torch.cat([y, item], 1) for item in zs[1:]]
        else:
            ys = [y] * len(self.blocks)

        h = self.linear(z)
        h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width)
        for index, blocklist in enumerate(self.blocks):
            for block in blocklist:
                h = block(h, ys[index])
        return torch.tanh(self.output_layer(h))
def D_arch(ch=64, attention='64', ksize='333333', dilation='111111'):
    """Build the discriminator architecture table for every supported resolution.

    Args:
        ch: base channel width; channel counts are ``ch`` times a multiplier.
        attention: underscore-separated resolutions (e.g. ``'64'``) at which an
            attention layer is requested.
        ksize, dilation: accepted for interface compatibility; not used.

    Returns:
        A dict keyed by input resolution (256, 128, 64, 32). Each value holds
        per-block lists ``in_channels``, ``out_channels``, ``downsample`` and
        ``resolution``, plus an ``attention`` dict mapping 4, 8, ... up to a
        per-resolution maximum to whether attention is enabled there.
    """
    wanted = [int(token) for token in attention.split('_')]

    def _spec(out_mults, downsample, sizes, max_exp):
        # Each block consumes the previous block's output; the first block
        # consumes the 3-channel input image.
        widths = [ch * m for m in out_mults]
        return {
            'in_channels': [3] + widths[:-1],
            'out_channels': widths,
            'downsample': downsample,
            'resolution': sizes,
            'attention': {2 ** e: (2 ** e in wanted) for e in range(2, max_exp + 1)},
        }

    return {
        256: _spec([1, 2, 4, 8, 8, 16, 16], [True] * 6 + [False],
                   [128, 64, 32, 16, 8, 4, 4], 7),
        128: _spec([1, 2, 4, 8, 16, 16], [True] * 5 + [False],
                   [64, 32, 16, 8, 4, 4], 7),
        64: _spec([1, 2, 4, 8, 16], [True] * 4 + [False],
                  [32, 16, 8, 4, 4], 6),
        32: _spec([4, 4, 4, 4], [True, True, False, False],
                  [16, 16, 16, 16], 5),
    }
class Discriminator(nn.Module):
    """BigGAN-style projection discriminator.

    The network is a stack of ``layers.DBlock`` stages (each optionally
    followed by ``layers.Attention``) described by
    ``D_arch(D_ch, D_attn)[resolution]``. After the blocks come a global sum of
    the activated features, a linear head, and, when labels are given, a
    class-projection term computed with ``self.embed``.

    Args (selected):
        D_ch: base channel width passed to ``D_arch``.
        D_wide: forwarded to each ``layers.DBlock`` as ``wide``.
        resolution: key into ``D_arch``; one of 32, 64, 128, 256.
        D_kernel_size: stored as ``self.kernel_size`` only; convs are 3x3.
        D_attn: underscore-separated resolutions that get an attention layer.
        n_classes: number of rows in the class-projection embedding.
        D_param: ``'SN'`` selects spectrally-normalized ``layers.SN*``
            modules; anything else selects plain ``nn.Conv2d``/``nn.Linear``/
            ``nn.Embedding`` (mirroring the Generator's non-SN option).
        output_dim: output width of the linear head.
        D_init: ``'ortho'``, ``'N02'``, ``'glorot'`` or ``'xavier'``.
        D_fp16: recorded as ``self.fp16``; this constructor does not convert
            any parameters to half precision.
        skip_init: skip ``init_weights``.
        D_mixed_precision: use ``utils.Adam16`` instead of ``optim.Adam``.
    """

    def __init__(self, D_ch=64, D_wide=True, resolution=128,
                 D_kernel_size=3, D_attn='64', n_classes=1000,
                 num_D_SVs=1, num_D_SV_itrs=1, D_activation=nn.ReLU(inplace=False),
                 D_lr=2e-4, D_B1=0.0, D_B2=0.999, adam_eps=1e-8,
                 SN_eps=1e-12, output_dim=1, D_mixed_precision=False, D_fp16=False,
                 D_init='ortho', skip_init=False, D_param='SN', **kwargs):
        super(Discriminator, self).__init__()
        self.ch = D_ch
        self.D_wide = D_wide
        self.resolution = resolution
        # NOTE(review): stored but unused -- which_conv below is fixed at 3x3.
        self.kernel_size = D_kernel_size
        self.attention = D_attn
        self.n_classes = n_classes
        self.activation = D_activation
        self.init = D_init
        self.D_param = D_param
        self.SN_eps = SN_eps
        self.fp16 = D_fp16
        self.arch = D_arch(self.ch, self.attention)[resolution]

        if self.D_param == 'SN':
            self.which_conv = functools.partial(layers.SNConv2d,
                                                kernel_size=3, padding=1,
                                                num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                                                eps=self.SN_eps)
            self.which_linear = functools.partial(layers.SNLinear,
                                                  num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                                                  eps=self.SN_eps)
            self.which_embedding = functools.partial(layers.SNEmbedding,
                                                     num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                                                     eps=self.SN_eps)
        else:
            # Previously no layer factories were defined for non-SN D_param,
            # so construction failed with AttributeError below.
            self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
            self.which_linear = nn.Linear
            self.which_embedding = nn.Embedding

        self.blocks = []
        for index in range(len(self.arch['out_channels'])):
            self.blocks += [[layers.DBlock(in_channels=self.arch['in_channels'][index],
                                           out_channels=self.arch['out_channels'][index],
                                           which_conv=self.which_conv,
                                           wide=self.D_wide,
                                           activation=self.activation,
                                           # The first block sees the raw input image.
                                           preactivation=(index > 0),
                                           downsample=(nn.AvgPool2d(2) if self.arch['downsample'][index] else None))]]
            if self.arch['attention'][self.arch['resolution'][index]]:
                print('Adding attention layer in D at resolution %d' % self.arch['resolution'][index])
                self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index],
                                                     self.which_conv)]
        # Nested ModuleLists so every block's parameters are registered.
        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

        self.linear = self.which_linear(self.arch['out_channels'][-1], output_dim)
        self.embed = self.which_embedding(self.n_classes, self.arch['out_channels'][-1])

        if not skip_init:
            self.init_weights()

        self.lr, self.B1, self.B2, self.adam_eps = D_lr, D_B1, D_B2, adam_eps
        if D_mixed_precision:
            print('Using fp16 adam in D...')
            import utils
            self.optim = utils.Adam16(params=self.parameters(), lr=self.lr,
                                      betas=(self.B1, self.B2), weight_decay=0, eps=self.adam_eps)
        else:
            self.optim = optim.Adam(params=self.parameters(), lr=self.lr,
                                    betas=(self.B1, self.B2), weight_decay=0, eps=self.adam_eps)

    def init_weights(self):
        """Re-initialize the weights of every Conv2d/Linear/Embedding submodule.

        The scheme follows ``self.init``. An unrecognized scheme is reported
        and that module's weights are left untouched. The number of parameters
        in the visited modules is stored in ``self.param_count`` and printed.
        """
        self.param_count = 0
        for module in self.modules():
            if not isinstance(module, (nn.Conv2d, nn.Linear, nn.Embedding)):
                continue
            if self.init == 'ortho':
                init.orthogonal_(module.weight)
            elif self.init == 'N02':
                init.normal_(module.weight, 0, 0.02)
            elif self.init in ['glorot', 'xavier']:
                init.xavier_uniform_(module.weight)
            else:
                print('Init style not recognized...')
            self.param_count += sum(p.numel() for p in module.parameters())
        # Double quotes: the old 'D''s' literal concatenated to "Ds".
        print("Param count for D's initialized parameters: %d" % self.param_count)

    def forward(self, x, y=None):
        """Score a batch of inputs.

        Args:
            x: input batch; dims 2 and 3 are summed out after the blocks
                (presumably (N, 3, H, W) images -- confirm with callers).
            y: optional class indices for the projection term. When omitted,
                only the linear head's output is returned instead of crashing
                on ``self.embed(None)``.

        Returns:
            ``linear(h)`` plus, if ``y`` is given,
            ``sum(embed(y) * h, dim=1, keepdim=True)``, where ``h`` is the
            spatially summed, activated final feature map.
        """
        h = x
        for blocklist in self.blocks:
            for block in blocklist:
                h = block(h)
        h = torch.sum(self.activation(h), [2, 3])
        out = self.linear(h)
        if y is not None:
            out = out + torch.sum(self.embed(y) * h, 1, keepdim=True)
        return out
class G_D(nn.Module):
    """Bundle a generator ``G`` and discriminator ``D`` into one module.

    Running both in a single ``forward`` lets one wrapper (e.g. a data-parallel
    wrapper around this module) drive a full G-then-D pass. ``G`` is expected
    to expose ``shared`` (label embedding) and ``fp16``; ``D`` to expose
    ``fp16``.
    """

    def __init__(self, G, D):
        super(G_D, self).__init__()
        self.G = G
        self.D = D

    def forward(self, z, gy, x=None, dy=None, train_G=False, return_G_z=False,
                split_D=False):
        """Generate from ``(z, gy)`` and score the fakes (and optional reals) with D.

        Args:
            z: latent batch for G.
            gy: labels for the fake batch; embedded with ``G.shared`` for G and
                passed unchanged to D.
            x: optional real batch.
            dy: optional labels for ``x``.
            train_G: enable autograd for the G pass. D is always evaluated
                outside this scope, so it records gradients during D updates.
            return_G_z: also return the generated batch (only when ``x`` is None).
            split_D: run D separately on fakes and reals instead of on one
                batch concatenated along dim 0.

        Returns:
            ``(D_fake, D_real)`` when ``x`` is given. Otherwise ``D_fake``, or
            ``(D_fake, G_z)`` when ``return_G_z`` is set.
        """
        # Only the generator pass (and dtype matching) is gated by train_G.
        with torch.set_grad_enabled(train_G):
            G_z = self.G(z, self.G.shared(gy))
            # Cast G's output to the precision D expects.
            if self.G.fp16 and not self.D.fp16:
                G_z = G_z.float()
            if self.D.fp16 and not self.G.fp16:
                G_z = G_z.half()

        if split_D:
            D_fake = self.D(G_z, gy)
            if x is not None:
                D_real = self.D(x, dy)
                return D_fake, D_real
            if return_G_z:
                return D_fake, G_z
            return D_fake

        # Single D pass over [fakes; reals] concatenated along the batch dim.
        D_input = torch.cat([G_z, x], 0) if x is not None else G_z
        D_class = torch.cat([gy, dy], 0) if dy is not None else gy
        D_out = self.D(D_input, D_class)
        if x is not None:
            return torch.split(D_out, [G_z.shape[0], x.shape[0]])
        if return_G_z:
            return D_out, G_z
        return D_out