from model import Generator
from model import Discriminator
from torch.autograd import Variable
from torchvision.utils import save_image
import torch
import torch.nn.functional as F
import numpy as np
import os
import time
import datetime
class Solver(object):
    """Restore a trained StarGAN generator and run it for inference.

    Only the generator is built and restored. ``test`` translates images,
    saves side-by-side result images, and also dumps every
    (input image batch, target label batch) pair fed to the generator as
    raw ``.bin`` files.
    """

    def __init__(self, celeba_loader, rafd_loader, config):
        """Initialize configurations.

        Args:
            celeba_loader: iterable yielding ``(images, labels)`` batches for CelebA.
            rafd_loader: iterable yielding ``(images, labels)`` batches for RaFD;
                only used by ``test`` when ``config.dataset == 'RaFD'``.
            config: namespace whose attributes are copied onto this solver.
        """
        # Data loaders.
        self.celeba_loader = celeba_loader
        self.rafd_loader = rafd_loader

        # Model configurations.
        self.c_dim = config.c_dim
        self.c2_dim = config.c2_dim
        self.image_size = config.image_size
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.g_repeat_num = config.g_repeat_num
        self.d_repeat_num = config.d_repeat_num
        self.lambda_cls = config.lambda_cls
        self.lambda_rec = config.lambda_rec
        self.lambda_gp = config.lambda_gp

        # Training configurations.
        self.dataset = config.dataset
        self.batch_size = config.batch_size
        self.num_iters = config.num_iters
        self.num_iters_decay = config.num_iters_decay
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.n_critic = config.n_critic
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.resume_iters = config.resume_iters
        self.selected_attrs = config.selected_attrs

        # Test configurations.
        self.test_iters = config.test_iters

        # Miscellaneous.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Directories.
        self.log_dir = config.log_dir
        self.sample_dir = config.sample_dir
        self.model_save_dir = config.model_save_dir
        self.result_dir = config.result_dir

        # Step sizes.
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.lr_update_step = config.lr_update_step

        self.build_model()

    def build_model(self):
        """Create the generator and its Adam optimizer, then move the generator to ``self.device``.

        No discriminator is built here; ``test`` only uses ``self.G``.
        """
        self.G = Generator(self.g_conv_dim, self.c_dim, self.g_repeat_num)
        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
        self.G.to(self.device)

    def restore_model(self, resume_iters):
        """Load the generator weights saved at step ``resume_iters``.

        Reads ``<model_save_dir>/<resume_iters>-G.pth``. Stored tensors are
        mapped to CPU memory on load; ``load_state_dict`` then copies them
        into ``self.G``'s parameters, wherever those live.
        """
        print('Loading the trained models from step {}...'.format(resume_iters))
        G_path = os.path.join(self.model_save_dir, '{}-G.pth'.format(resume_iters))
        self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))

    def denorm(self, x):
        """Convert the range from [-1, 1] to [0, 1], clamping values outside it."""
        out = (x + 1) / 2
        # In-place clamp is safe: ``out`` is a fresh tensor, not the caller's ``x``.
        return out.clamp_(0, 1)

    def label2onehot(self, labels, dim):
        """Convert a 1-D tensor of label indices to a ``(batch, dim)`` one-hot CPU tensor."""
        batch_size = labels.size(0)
        out = torch.zeros(batch_size, dim)
        out[np.arange(batch_size), labels.long()] = 1
        return out

    def create_labels(self, c_org, c_dim=5, dataset='CelebA', selected_attrs=None):
        """Generate one target-domain label batch per domain for testing.

        Args:
            c_org: original label batch. For CelebA it is cloned and indexed as
                ``c_org[:, i]``; for RaFD only its batch size is used.
            c_dim: number of target domains (label columns).
            dataset: ``'CelebA'`` or ``'RaFD'``.
            selected_attrs: CelebA attribute names, one per label column;
                required when ``dataset == 'CelebA'``.

        Returns:
            A list of ``c_dim`` label tensors moved to ``self.device``.
            For CelebA, entry ``i`` is ``c_org`` with attribute ``i`` toggled,
            except for hair-color attributes: the selected hair color is set
            to 1 and every other hair color is cleared. For RaFD, entry ``i``
            is the one-hot vector of class ``i`` repeated for every sample.

        Raises:
            ValueError: if ``dataset`` is neither ``'CelebA'`` nor ``'RaFD'``
                (previously this surfaced as an UnboundLocalError).
        """
        if dataset not in ('CelebA', 'RaFD'):
            raise ValueError('Unsupported dataset: {!r}'.format(dataset))

        hair_color_indices = []
        if dataset == 'CelebA':
            for i, attr_name in enumerate(selected_attrs):
                if attr_name in ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair']:
                    hair_color_indices.append(i)

        c_trg_list = []
        for i in range(c_dim):
            if dataset == 'CelebA':
                c_trg = c_org.clone()
                if i in hair_color_indices:
                    # Treat hair colors as mutually exclusive: set this one, clear the others.
                    c_trg[:, i] = 1
                    for j in hair_color_indices:
                        if j != i:
                            c_trg[:, j] = 0
                else:
                    # Toggle the attribute: 0 -> 1, anything else -> 0.
                    c_trg[:, i] = (c_trg[:, i] == 0)
            else:  # 'RaFD'
                c_trg = self.label2onehot(torch.ones(c_org.size(0)) * i, c_dim)
            c_trg_list.append(c_trg.to(self.device))
        return c_trg_list

    def test(self, max_pairs=319):
        """Translate images with the restored generator and dump its inputs as raw binaries.

        For every batch and every target label from ``create_labels``, the
        input image batch and the target label batch are written with
        ``ndarray.tofile`` to ``./bin/img/<cnt>.bin`` and
        ``./bin/attr/<cnt>.bin``, where ``cnt`` is a running pair index.
        All translations of a batch are concatenated along dim 3 and saved as
        ``<result_dir>/<batch_index + 1>-images.jpg``.

        Args:
            max_pairs: stop after the first batch that brings the number of
                written (image, label) pairs to at least this value.
                ``None`` processes the whole loader. Defaults to 319, the
                previously hard-coded limit.
        """
        img_dir = './bin/img'
        attr_dir = './bin/attr'
        for directory in (attr_dir, img_dir, self.result_dir):
            os.makedirs(directory, exist_ok=True)

        self.restore_model(self.test_iters)
        # Use the loader that matches the labels create_labels will build.
        data_loader = self.rafd_loader if self.dataset == 'RaFD' else self.celeba_loader

        with torch.no_grad():
            cnt = 0
            for i, (x_real, c_org) in enumerate(data_loader):
                x_real = x_real.to(self.device)
                # ``.numpy()`` only works on CPU tensors; copy to host once per batch
                # since the same image batch is dumped for every target label.
                x_real_host = x_real.cpu().numpy()
                c_trg_list = self.create_labels(c_org, self.c_dim, self.dataset, self.selected_attrs)
                x_fake_list = []
                for c_trg in c_trg_list:
                    x_fake_list.append(self.G(x_real, c_trg))
                    img_path = os.path.join(img_dir, '%d.bin' % cnt)
                    attr_path = os.path.join(attr_dir, '%d.bin' % cnt)
                    x_real_host.tofile(img_path)
                    # c_trg lives on self.device (see create_labels); bring it to host first.
                    c_trg.cpu().numpy().tofile(attr_path)
                    print('Saved bin into %s...' % img_path)
                    print('Saved bin into %s...' % attr_path)
                    cnt += 1
                x_concat = torch.cat(x_fake_list, dim=3)
                result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(i + 1))
                save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0)
                print('Saved real and fake images into {}...'.format(result_path))
                if max_pairs is not None and cnt >= max_pairs:
                    break