# -*- coding: utf-8 -*-
# BSD 3-Clause License
#
# Copyright (c) 2017
# All rights reserved.
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# ==========================================================================
import numpy as np
import torch
import os
from torch.autograd import Variable
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks
import apex
class Pix2PixHDModel(BaseModel):
    """pix2pixHD model: generator, optional discriminator and optional feature encoder.

    All networks, losses and input tensors are placed on ``self.device``
    (expected to be provided by ``BaseModel``) instead of relying on
    hard-coded ``.cuda()`` calls.
    """

    def name(self):
        """Return the identifying name of this model."""
        return 'Pix2PixHDModel'

    def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss):
        """Build a function that keeps only the enabled loss terms.

        Args:
            use_gan_feat_loss: keep the discriminator feature-matching term.
            use_vgg_loss: keep the VGG term.

        Returns:
            A callable ``loss_filter(g_gan, g_gan_feat, g_vgg, d_real, d_fake)``
            returning a list with the enabled entries, in that order. The
            GAN and both discriminator terms are always kept.
        """
        flags = (True, use_gan_feat_loss, use_vgg_loss, True, True)

        def loss_filter(g_gan, g_gan_feat, g_vgg, d_real, d_fake):
            losses = (g_gan, g_gan_feat, g_vgg, d_real, d_fake)
            return [loss for loss, enabled in zip(losses, flags) if enabled]

        return loss_filter

    def initialize(self, opt):
        """Create the networks and, when training, the losses and optimizers.

        Args:
            opt: parsed options object; the attributes read here include
                ``isTrain``, ``label_nc``, ``input_nc``, ``output_nc``,
                ``no_instance``, ``feat_num``, ``pool_size``, ``lr`` and
                ``beta1`` among others.

        Raises:
            NotImplementedError: when training with ``opt.pool_size > 0``.
        """
        BaseModel.initialize(self, opt)
        # NOTE: torch.backends.cudnn.benchmark is intentionally not enabled
        # here; the original toggle was commented out in this port.
        self.isTrain = opt.isTrain
        self.use_features = opt.instance_feat or opt.label_feat
        self.gen_features = self.use_features and not self.opt.load_features
        input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc

        ##### define networks
        # Generator network: label channels (+1 edge channel, + feature channels).
        netG_input_nc = input_nc
        if not opt.no_instance:
            netG_input_nc += 1
        if self.use_features:
            netG_input_nc += opt.feat_num
        # The target device is handed to the network factory (see networks.py),
        # replacing the former ``gpu_ids`` argument.
        self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG,
                                      opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers,
                                      opt.n_blocks_local, opt.norm, device=self.device)

        # Discriminator network: sees the conditioning input concatenated with an image.
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            netD_input_nc = input_nc + opt.output_nc
            if not opt.no_instance:
                netD_input_nc += 1
            self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm, use_sigmoid,
                                          opt.num_D, not opt.no_ganFeat_loss, device=self.device)

        ### Encoder network
        if self.gen_features:
            self.netE = networks.define_G(opt.output_nc, opt.feat_num, opt.nef, 'encoder',
                                          opt.n_downsample_E, norm=opt.norm, device=self.device)
        if self.opt.verbose:
            print('---------- Networks initialized -------------')

        # load networks
        if not self.isTrain or opt.continue_train or opt.load_pretrain:
            pretrained_path = '' if not self.isTrain else opt.load_pretrain
            self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)
            if self.gen_features:
                self.load_network(self.netE, 'E', opt.which_epoch, pretrained_path)

        # set loss functions and optimizers
        if self.isTrain:
            # NOTE(review): the earlier (commented-out) condition also required
            # more than one GPU; this port rejects every non-zero pool size,
            # so the message text is kept only for compatibility.
            if opt.pool_size > 0:
                raise NotImplementedError("Fake Pool Not Implemented for MultiGPU")
            self.fake_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr

            # define loss functions
            self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss)
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor, device=self.device)
            self.criterionFeat = torch.nn.L1Loss().to(self.device)
            if not opt.no_vgg_loss:
                self.criterionVGG = networks.VGGLoss(self.device)

            # Names so we can breakout loss
            self.loss_names = self.loss_filter('G_GAN', 'G_GAN_Feat', 'G_VGG', 'D_real', 'D_fake')

            # initialize optimizers
            # optimizer G
            if opt.niter_fix_global > 0:
                # Only parameters whose name starts with 'model<n_local_enhancers>'
                # are optimized while the global generator is kept fixed.
                finetune_list = set()
                prefix = 'model' + str(opt.n_local_enhancers)
                params = []
                for key, value in dict(self.netG.named_parameters()).items():
                    if key.startswith(prefix):
                        params.append(value)
                        finetune_list.add(key.split('.')[0])
                print('------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global)
                print('The layers that are finetuned are ', sorted(finetune_list))
            else:
                params = list(self.netG.parameters())
            if self.gen_features:
                params += list(self.netE.parameters())
            self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))

            # optimizer D
            params = list(self.netD.parameters())
            self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))

    def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, device=None, infer=False):
        """Move the raw inputs to ``device`` and build the generator conditioning tensor.

        Args:
            label_map: label tensor; one-hot encoded along dim 1 when
                ``opt.label_nc != 0``, used as-is otherwise.
            inst_map: instance tensor; required unless ``opt.no_instance``.
            real_image: optional real image tensor.
            feat_map: precomputed feature tensor, used when ``opt.load_features``.
            device: target device; ``None`` falls back to ``self.device``.
            infer: kept for interface compatibility only. It used to drive the
                removed ``volatile`` flag; callers that need gradient-free
                execution wrap their network calls in ``torch.no_grad()``.

        Returns:
            Tuple ``(input_label, inst_map, real_image, feat_map)``.
        """
        if device is None:
            # Default to the device the networks were built on.
            device = self.device

        if self.opt.label_nc == 0:
            input_label = label_map.data.to(device)
        else:
            # create one-hot vector for label map
            size = label_map.size()
            one_hot_size = (size[0], self.opt.label_nc, size[2], size[3])
            input_label = torch.Tensor(torch.Size(one_hot_size)).zero_().to(device)
            input_label = input_label.scatter_(1, label_map.data.long().to(device), 1.0)
            if self.opt.data_type == 16:
                input_label = input_label.half()

        # get edges from instance map
        if not self.opt.no_instance:
            inst_map = inst_map.data.to(device)
            edge_map = self.get_edges(inst_map, device)
            input_label = torch.cat((input_label, edge_map), dim=1)

        # real images for training
        if real_image is not None:
            real_image = real_image.data.to(device)

        # instance map for feature encoding
        if self.use_features:
            # get precomputed feature maps
            if self.opt.load_features:
                feat_map = feat_map.data.to(device)
            if self.opt.label_feat:
                inst_map = label_map.to(device)

        return input_label, inst_map, real_image, feat_map

    def discriminate(self, input_label, test_image, use_pool=False):
        """Run the discriminator on ``input_label`` concatenated with a detached ``test_image``.

        When ``use_pool`` is true the concatenated input is first passed
        through ``self.fake_pool.query``.
        """
        input_concat = torch.cat((input_label, test_image.detach()), dim=1)
        if use_pool:
            input_concat = self.fake_pool.query(input_concat)
        return self.netD.forward(input_concat)

    def forward(self, label, inst, image, feat, device=None, infer=False):
        """Run one training forward pass and compute the loss terms.

        Args:
            label, inst, image, feat: raw inputs, see ``encode_input``.
            device: target device; ``None`` falls back to ``self.device``.
            infer: when true, the generated image is returned as well.

        Returns:
            ``[losses, fake_image_or_None]`` where ``losses`` is the list
            produced by ``self.loss_filter``.
        """
        # Encode Inputs
        input_label, inst_map, real_image, feat_map = self.encode_input(label, inst, image, feat, device)

        # Fake Generation
        if self.use_features:
            if not self.opt.load_features:
                feat_map = self.netE.forward(real_image, inst_map)
            input_concat = torch.cat((input_label, feat_map), dim=1)
        else:
            input_concat = input_label
        fake_image = self.netG.forward(input_concat)

        # Fake Detection and Loss
        pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True)
        loss_D_fake = self.criterionGAN(pred_fake_pool, False)

        # Real Detection and Loss
        pred_real = self.discriminate(input_label, real_image)
        loss_D_real = self.criterionGAN(pred_real, True)

        # GAN loss (Fake Passability Loss); fake_image is not detached here so
        # gradients reach the generator.
        pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1))
        loss_G_GAN = self.criterionGAN(pred_fake, True)

        # GAN feature matching loss
        loss_G_GAN_Feat = 0
        if not self.opt.no_ganFeat_loss:
            feat_weights = 4.0 / (self.opt.n_layers_D + 1)
            D_weights = 1.0 / self.opt.num_D
            for i in range(self.opt.num_D):
                # The last entry of each discriminator output is skipped.
                for j in range(len(pred_fake[i]) - 1):
                    loss_G_GAN_Feat += D_weights * feat_weights * \
                        self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat

        # VGG feature matching loss
        loss_G_VGG = 0
        if not self.opt.no_vgg_loss:
            fake_image = fake_image.float()
            loss_G_VGG = self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat

        # Only return the fake_B image if necessary to save BW
        return [self.loss_filter(loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real, loss_D_fake),
                None if not infer else fake_image]

    def inference(self, label, inst, image=None, device=None):
        """Generate an image for the given inputs without tracking gradients.

        Args:
            label: label tensor.
            inst: instance tensor.
            image: optional real image, encoded when ``opt.use_encoded_image``.
            device: target device; ``None`` falls back to ``self.device``.

        Returns:
            The generator output.
        """
        # Encode Inputs
        input_label, inst_map, real_image, _ = self.encode_input(label, inst, image, device=device, infer=True)

        # ``volatile`` no longer has any effect, so gradient tracking is
        # disabled explicitly for every supported PyTorch version.
        with torch.no_grad():
            # Fake Generation
            if self.use_features:
                if self.opt.use_encoded_image:
                    # encode the real image to get feature map
                    feat_map = self.netE.forward(real_image, inst_map)
                else:
                    # sample clusters from precomputed features
                    feat_map = self.sample_features(inst_map)
                input_concat = torch.cat((input_label, feat_map), dim=1)
            else:
                input_concat = input_label
            fake_image = self.netG.forward(input_concat)
        return fake_image

    def sample_features(self, inst):
        """Build a feature map by sampling one precomputed cluster per instance id.

        Args:
            inst: instance tensor indexed as (batch, channel, height, width).

        Returns:
            Tensor with ``opt.feat_num`` channels; positions whose label has no
            entry in the cluster file stay zero.
        """
        # read precomputed feature clusters
        cluster_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, self.opt.cluster_path)
        # SECURITY: the file stores a pickled dict, which newer NumPy refuses to
        # load unless allow_pickle=True. Unpickling can execute arbitrary code,
        # so only load cluster files from a trusted source.
        features_clustered = np.load(cluster_path, encoding='latin1', allow_pickle=True).item()

        # randomly sample from the feature clusters
        inst_np = inst.cpu().numpy().astype(int)
        # Zero the buffer so labels without clusters do not expose uninitialized memory.
        feat_map = self.Tensor(inst.size()[0], self.opt.feat_num, inst.size()[2], inst.size()[3]).zero_()
        for i in np.unique(inst_np):
            label = i if i < 1000 else i // 1000
            if label in features_clustered:
                feat = features_clustered[label]
                cluster_idx = np.random.randint(0, feat.shape[0])

                idx = (inst == int(i)).nonzero()
                for k in range(self.opt.feat_num):
                    feat_map[idx[:, 0], idx[:, 1] + k, idx[:, 2], idx[:, 3]] = feat[cluster_idx, k]
        if self.opt.data_type == 16:
            feat_map = feat_map.half()
        return feat_map

    def encode_features(self, image, inst):
        """Encode ``image`` and collect one feature row per instance id.

        Args:
            image: image tensor fed to the encoder.
            inst: instance tensor indexed as (batch, channel, height, width).

        Returns:
            Dict mapping label to an array with ``opt.feat_num + 1`` columns:
            the encoded feature values followed by the instance's pixel count
            divided by ``h * w // 32``.
        """
        feat_num = self.opt.feat_num
        h, w = inst.size()[2], inst.size()[3]
        block_num = 32
        # Use the model device instead of a hard-coded .cuda(), and skip autograd.
        with torch.no_grad():
            feat_map = self.netE.forward(image.to(self.device), inst.to(self.device))
        inst_np = inst.cpu().numpy().astype(int)
        feature = {i: np.zeros((0, feat_num + 1)) for i in range(self.opt.label_nc)}
        for i in np.unique(inst_np):
            label = i if i < 1000 else i // 1000
            idx = (inst == int(i)).nonzero()
            num = idx.size()[0]
            # Sample the feature at the middle entry of the instance's pixel list.
            idx = idx[num // 2, :]
            val = np.zeros((1, feat_num + 1))
            for k in range(feat_num):
                # .item() reads the 0-dim result; indexing it with [0] fails on current PyTorch.
                val[0, k] = feat_map[idx[0], idx[1] + k, idx[2], idx[3]].item()
            val[0, feat_num] = float(num) / (h * w // block_num)
            feature[label] = np.append(feature[label], val, axis=0)
        return feature

    def get_edges(self, t, device):
        """Mark pixels whose horizontal or vertical neighbour has a different value.

        Args:
            t: instance tensor indexed as (batch, channel, height, width).
            device: device the resulting edge map is moved to.

        Returns:
            Edge map on ``device``: half precision when ``opt.data_type == 16``,
            float otherwise.
        """
        # The comparison is carried out on the CPU; only the result is moved.
        t = t.cpu()
        edge = torch.ByteTensor(t.size()).zero_()
        horizontal = t[:, :, :, 1:] != t[:, :, :, :-1]
        vertical = t[:, :, 1:, :] != t[:, :, :-1, :]
        # Both pixels on either side of a boundary are flagged.
        edge[:, :, :, 1:] = edge[:, :, :, 1:] | horizontal
        edge[:, :, :, :-1] = edge[:, :, :, :-1] | horizontal
        edge[:, :, 1:, :] = edge[:, :, 1:, :] | vertical
        edge[:, :, :-1, :] = edge[:, :, :-1, :] | vertical
        edge = edge.half() if self.opt.data_type == 16 else edge.float()
        # Move to the target device in both precisions so the edge map can be
        # concatenated with the label tensor.
        return edge.to(device)

    def save(self, which_epoch):
        """Save the generator, the discriminator and, if used, the encoder for ``which_epoch``."""
        self.save_network(self.netG, 'G', which_epoch, self.device)
        self.save_network(self.netD, 'D', which_epoch, self.device)
        if self.gen_features:
            self.save_network(self.netE, 'E', which_epoch, self.device)

    def update_fixed_params(self):
        """Recreate the generator optimizer over all generator (and encoder) parameters."""
        # after fixing the global generator for a number of iterations, also start finetuning it
        params = list(self.netG.parameters())
        if self.gen_features:
            params += list(self.netE.parameters())
        self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
        if self.opt.verbose:
            print('------------ Now also finetuning global generator -----------')

    def update_learning_rate(self):
        """Lower the learning rate of both optimizers by ``opt.lr / opt.niter_decay``."""
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        for param_group in self.optimizer_D.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr
        if self.opt.verbose:
            print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr
class InferenceModel(Pix2PixHDModel):
    """Variant of the model whose ``forward`` runs inference on a single packed input."""

    def forward(self, inp):
        """Unpack ``inp`` into a (label, instance) pair and return ``self.inference`` on it."""
        label_map, instance_map = inp
        generated = self.inference(label_map, instance_map)
        return generated