# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the License);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# BSD 3-Clause License
#
# Copyright (c) 2017 xxxx
# All rights reserved.
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# ============================================================================
import os
import sys
import time
sys.path.append('./model')
import numpy as np
import torch
import pickle
import argparse
import PIL.Image
import functools
from tqdm import tqdm
def make_image_grid(img, drange, grid_size):
    """Convert a batch of (N, C, H, W) images into one (gh*H, gw*W, C) uint8 grid.

    Values are mapped linearly from ``drange`` to [0, 255], rounded to the
    nearest integer and clipped. Image ``k`` is placed at row ``k // gw``,
    column ``k % gw``.

    Args:
        img: array-like (numpy array or CPU tensor) of shape (N, C, H, W).
        drange: ``(lo, hi)`` pair; ``lo`` maps to 0 and ``hi`` maps to 255.
        grid_size: ``(gw, gh)`` = images per row and number of rows.

    Returns:
        numpy uint8 array of shape (gh * H, gw * W, C).

    Raises:
        ValueError: if ``lo == hi``, ``img`` is not 4-D, C is not 1 or 3,
            or N != gw * gh.
    """
    lo, hi = drange
    if hi == lo:
        raise ValueError(f'drange must span a non-empty interval, got {drange!r}')
    img = np.asarray(img, dtype=np.float32)
    if img.ndim != 4:
        raise ValueError(f'expected a 4-D (N, C, H, W) batch, got shape {img.shape}')
    n, c, h, w = img.shape
    gw, gh = grid_size
    # Validate up front: an `assert` would be stripped under `python -O`.
    if c not in (1, 3):
        raise ValueError(f'expected 1 or 3 channels, got {c}')
    if n != gw * gh:
        raise ValueError(f'grid {gw}x{gh} needs {gw * gh} images, got {n}')

    img = (img - lo) * (255 / (hi - lo))
    img = np.rint(img).clip(0, 255).astype(np.uint8)
    # (gh, gw, C, H, W) -> (gh, H, gw, W, C) so rows/cols interleave into a mosaic.
    img = img.reshape(gh, gw, c, h, w).transpose(0, 3, 1, 4, 2)
    return img.reshape(gh * h, gw * w, c)


def save_image_grid(img, fname, drange, grid_size):
    """Tile a batch of (N, C, H, W) images into a grid and save it to ``fname``.

    See ``make_image_grid`` for the meaning of ``drange`` and ``grid_size``
    and for the errors raised. Single-channel grids are saved as 8-bit
    grayscale ('L'); three-channel grids are saved as 'RGB'.
    """
    grid = make_image_grid(img, drange, grid_size)
    if grid.shape[2] == 1:
        PIL.Image.fromarray(grid[:, :, 0], 'L').save(fname)
    else:
        PIL.Image.fromarray(grid, 'RGB').save(fname)
def main(args):
    """Run the pickled generator on every latent file and save the images.

    Every regular file in ``args.input_path`` is read as raw float32 data and
    reshaped to ``(-1, G.z_dim)`` latent codes. The codes are passed through
    the ``'G_ema'`` object stored in ``args.pkl_file``. Results are written as
    ``gen_image_XXXX.png`` (XXXX = index in sorted file order) under
    ``<args.image_path>/pkl_img``. When one file holds several latent codes,
    their images are tiled side by side in a single PNG.

    Args:
        args: argparse namespace with ``pkl_file``, ``input_path``,
            ``image_path`` and ``batch_size`` attributes. The number of
            latent codes per file is taken from the file contents, so
            ``batch_size`` is accepted for CLI compatibility but not needed.

    Raises:
        ValueError: if an input file's float32 count is not a positive
            multiple of ``G.z_dim``.
    """
    pkl_file = args.pkl_file
    input_path = args.input_path
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    # Only regular files are inputs; sorting keeps the output numbering stable.
    input_files = sorted(
        name for name in os.listdir(input_path)
        if os.path.isfile(os.path.join(input_path, name))
    )
    image_path = os.path.join(args.image_path, 'pkl_img')
    os.makedirs(image_path, exist_ok=True)

    # load model
    # NOTE: pickle.load can execute arbitrary code -- only load trusted checkpoints.
    with open(pkl_file, 'rb') as f:
        G = pickle.load(f)['G_ema'].to(device)
    # Pass force_fp32=True on every forward call.
    G.forward = functools.partial(G.forward, force_fp32=True)

    # Start timing after the one-off model load so the figure is per-file inference.
    start = time.time()
    # Inference only: no autograd bookkeeping, and outputs convert cleanly to numpy.
    with torch.no_grad():
        for i, name in enumerate(tqdm(input_files)):
            file_path = os.path.join(input_path, name)
            latents = np.fromfile(file_path, dtype=np.float32)
            if latents.size == 0 or latents.size % G.z_dim != 0:
                raise ValueError(
                    f'{file_path}: {latents.size} float32 values is not a '
                    f'positive multiple of z_dim={G.z_dim}')
            z = torch.from_numpy(latents).reshape(-1, G.z_dim).to(device)
            # Empty (zero-width) conditioning tensor whose batch size matches z.
            c = torch.empty(z.shape[0], 0, device=device)
            image = G(z, c)
            # NOTE(review): assumes a 512x512 RGB generator -- TODO confirm for other checkpoints.
            image = image.reshape(-1, 3, 512, 512).cpu()
            # Map [-1, 1] to [0, 255]; one row with one tile per latent code (1x1 for a single code).
            save_image_grid(image, os.path.join(image_path, f'gen_image_{i:04d}.png'),
                            drange=[-1, 1], grid_size=(image.shape[0], 1))
    end = time.time()

    if input_files:
        print(f'Inference average time : {((end - start) * 1000 / len(input_files)):.2f} ms')
    else:
        print(f'No input files found in {input_path}')
if __name__ == '__main__':
    # Script entry point: declare the CLI options (flag, type, default), then run.
    parser = argparse.ArgumentParser()
    for flag, value_type, fallback in (
        ('--pkl_file', str, './G_ema_bs8_8p_kimg1000.pkl'),
        ('--input_path', str, './pre_data'),
        ('--image_path', str, './results'),
        ('--batch_size', int, 1),
    ):
        parser.add_argument(flag, type=value_type, default=fallback)
    args = parser.parse_args()
    main(args)