# 59e460bd - created on 2023-02-10 (commit history)
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the BSD 3-Clause License  (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
if torch.__version__ >= "1.8":
    import torch_npu
import os
from options.test_options import TestOptions
from data import create_dataset
from models import create_model
import numpy as np
from models import networks
"""!!!!!!!!!!!!!!!修改的地方!!!!!!!!!!!!!!!!!!1"""
# Distributed-launch settings. Each value is read from the environment and
# falls back to a single-process default (rank 0, world size 1) when unset.
# Reference: https://pytorch.org/docs/stable/elastic/run.html
LOCAL_RANK = int(os.getenv('LOCAL_RANK', 0))
RANK = int(os.getenv('RANK', 0))
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))

# Rendezvous endpoint (changed for the NPU port).
# setdefault keeps an address/port that is already exported in the
# environment (e.g. by the launcher that also provides RANK / WORLD_SIZE) and
# only supplies the local defaults when nothing is set, so a standalone run
# behaves exactly as before.
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29128')


# python pix2pix_pth2onnx.py --dataroot ./datasets/facades/ --direction BtoA --model pix2pix --name facades_label2photo_pretrained --save_onnx True
def main():
    """Load the model described by the test options and pass a dummy input to
    ``model.save_onnx``.

    Steps, in order:
      1. Parse the test options.
      2. Initialise the ``torch.distributed`` process group with the HCCL
         backend, using the module-level ``RANK`` / ``WORLD_SIZE``.
      3. Force the fixed test-time settings on the options object.
      4. Create and set up the model, switch it to eval mode.
      5. Call ``model.save_onnx`` with a random 1x3x256x256 tensor.

    The process group is destroyed on exit, whether or not the export
    succeeded, so the communication backend is not left initialised.

    NOTE(review): ``save_onnx`` is defined on the project's model class, not
    in this file; it presumably performs the ONNX export - confirm there.
    """
    opt = TestOptions().parse()  # get test options

    # NPU port: the HCCL backend is used here in place of the NCCL backend
    # of the original (GPU) version of this script.
    torch.distributed.init_process_group(backend="hccl", rank=RANK, world_size=WORLD_SIZE)
    try:
        print(f"[init] == local rank: {LOCAL_RANK}, global rank: {RANK} , world size: {WORLD_SIZE}")

        # hard-code some parameters for test
        opt.num_threads = 0        # test code only supports num_threads = 0
        opt.batch_size = 1         # test code only supports batch_size = 1
        opt.serial_batches = True  # disable data shuffling
        opt.no_flip = True         # no flip
        opt.display_id = -1        # no visdom display

        # No dataset is created: the model only receives the dummy tensor below.
        model = create_model(opt)  # create a model given opt.model and other options
        model.setup(opt)           # regular setup: load and print networks; create schedulers

        # Eval mode only affects layers like batchnorm and dropout.
        # For [pix2pix]: the original pix2pix uses batchnorm and dropout, so
        # results can differ with and without eval() mode.
        # For [CycleGAN]: it should not matter, as CycleGAN uses instancenorm
        # without dropout.
        model.eval()

        # Random input of shape (batch, channels, height, width) = (1, 3, 256, 256).
        batch, channels, height, width = 1, 3, 256, 256
        dummy_input = torch.randn(batch, channels, height, width, requires_grad=True)
        model.save_onnx(dummy_input)
    finally:
        # Always release the process group, including when model creation or
        # the export raises.
        torch.distributed.destroy_process_group()


if __name__ == '__main__':
    main()