Only in ./origin: .git
Only in ./origin: .gitignore
Only in ./origin: .replit
Only in ./origin: CycleGAN.ipynb
Only in ./origin: LICENSE
diff -u -r '--exclude=*.md' ./origin/data/__init__.py ./fix/data/__init__.py
--- ./origin/data/__init__.py	2023-06-28 08:07:10.643967613 +0000
+++ ./fix/data/__init__.py	2023-06-28 08:25:08.411980818 +0000
@@ -1,3 +1,4 @@
+
 """This package includes all the modules related to data loading and preprocessing
 
  To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
@@ -11,9 +12,11 @@
 See our template dataset class 'template_dataset.py' for more details.
 """
 import importlib
+""""""
 import torch.utils.data
 from data.base_dataset import BaseDataset
-
+import os
+WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
 
 def find_dataset_using_name(dataset_name):
     """Import the module "data/[dataset_name]_dataset.py".
@@ -72,11 +75,30 @@
         dataset_class = find_dataset_using_name(opt.dataset_mode)
         self.dataset = dataset_class(opt)
         print("dataset [%s] was created" % type(self.dataset).__name__)
+
+        """!!!!!!!!!!!!!!!修改的地方!!!!!!!!!!!!!!!!!!1"""
+        """!!!!!!!!!!!!!!!npu修改的地方!!!!!!!!!!!!!!!!!!1"""
+        # 这里需要设置shuffle=False,然后在每个epoch前,通过调用train_sampler.set_epoch(epoch)来达到shuffle的效果.
+        self.train_sampler = torch.utils.data.distributed.DistributedSampler(self.dataset)
         self.dataloader = torch.utils.data.DataLoader(
             self.dataset,
-            batch_size=opt.batch_size,
-            shuffle=not opt.serial_batches,
-            num_workers=int(opt.num_threads))
+            batch_size=opt.batch_size,  # 因为使用了多GPU所以一定要除WORLD_SIZE
+            # 修改的内容 如果设置了数据采集器,则shuffle参数为False
+            shuffle=(self.train_sampler is None),
+            # shuffle=not opt.serial_batches,
+
+            num_workers=int(opt.num_threads),
+            # 添加的地方 ,下面这三行
+            pin_memory=False,
+            sampler=self.train_sampler if opt.isTrain else None,
+            drop_last=True
+        )
+
+        # self.dataloader = torch.utils.data.DataLoader(
+        #     self.dataset,
+        #     batch_size=opt.batch_size,
+        #     shuffle=not opt.serial_batches,
+        #     num_workers=int(opt.num_threads))
 
     def load_data(self):
         return self
diff -u -r '--exclude=*.md' ./origin/data/aligned_dataset.py ./fix/data/aligned_dataset.py
--- ./origin/data/aligned_dataset.py	2023-06-28 08:07:10.643967613 +0000
+++ ./fix/data/aligned_dataset.py	2023-06-28 08:25:08.439980818 +0000
@@ -1,3 +1,4 @@
+
 import os
 from data.base_dataset import BaseDataset, get_params, get_transform
 from data.image_folder import make_dataset
diff -u -r '--exclude=*.md' ./origin/data/base_dataset.py ./fix/data/base_dataset.py
--- ./origin/data/base_dataset.py	2023-06-28 08:07:10.643967613 +0000
+++ ./fix/data/base_dataset.py	2023-06-28 08:25:08.423980818 +0000
@@ -1,3 +1,4 @@
+
 """This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
 
 It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
diff -u -r '--exclude=*.md' ./origin/data/colorization_dataset.py ./fix/data/colorization_dataset.py
--- ./origin/data/colorization_dataset.py	2023-06-28 08:07:10.643967613 +0000
+++ ./fix/data/colorization_dataset.py	2023-06-28 08:25:08.427980818 +0000
@@ -1,3 +1,4 @@
+
 import os
 from data.base_dataset import BaseDataset, get_transform
 from data.image_folder import make_dataset
diff -u -r '--exclude=*.md' ./origin/data/image_folder.py ./fix/data/image_folder.py
--- ./origin/data/image_folder.py	2023-06-28 08:07:10.643967613 +0000
+++ ./fix/data/image_folder.py	2023-06-28 08:25:08.419980818 +0000
@@ -1,3 +1,4 @@
+
 """A modified image folder class
 
 We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
diff -u -r '--exclude=*.md' ./origin/data/single_dataset.py ./fix/data/single_dataset.py
--- ./origin/data/single_dataset.py	2023-06-28 08:07:10.643967613 +0000
+++ ./fix/data/single_dataset.py	2023-06-28 08:25:08.443980818 +0000
@@ -1,3 +1,4 @@
+
 from data.base_dataset import BaseDataset, get_transform
 from data.image_folder import make_dataset
 from PIL import Image
diff -u -r '--exclude=*.md' ./origin/data/template_dataset.py ./fix/data/template_dataset.py
--- ./origin/data/template_dataset.py	2023-06-28 08:07:10.647967613 +0000
+++ ./fix/data/template_dataset.py	2023-06-28 08:25:08.435980818 +0000
@@ -1,3 +1,4 @@
+
 """Dataset class template
 
 This module provides a template for users to implement custom datasets.
diff -u -r '--exclude=*.md' ./origin/data/unaligned_dataset.py ./fix/data/unaligned_dataset.py
--- ./origin/data/unaligned_dataset.py	2023-06-28 08:07:10.647967613 +0000
+++ ./fix/data/unaligned_dataset.py	2023-06-28 08:25:08.407980818 +0000
@@ -1,3 +1,4 @@
+
 import os
 from data.base_dataset import BaseDataset, get_transform
 from data.image_folder import make_dataset
Only in ./origin: docs
Only in ./origin: imgs
diff -u -r '--exclude=*.md' ./origin/models/base_model.py ./fix/models/base_model.py
--- ./origin/models/base_model.py	2023-06-28 08:07:10.719967614 +0000
+++ ./fix/models/base_model.py	2023-06-28 08:25:08.471980819 +0000
@@ -1,8 +1,10 @@
 import os
-import torch
+"""!!!!!!!!!!!!!!!npu修改的地方!!!!!!!!!!!!!!!!!!1"""
+import torch.npu
 from collections import OrderedDict
 from abc import ABC, abstractmethod
 from . import networks
+LOCAL_RANK = int(os.getenv('LOCAL_RANK', 0))  # https://pytorch.org/docs/stable/elastic/run.html
 
 
 class BaseModel(ABC):
@@ -32,7 +34,16 @@
         self.opt = opt
         self.gpu_ids = opt.gpu_ids
         self.isTrain = opt.isTrain
-        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')  # get device name: CPU or GPU
+
+        """!!!!!!!!!!!!!!!修改的地方!!!!!!!!!!!!!!!!!!1"""
+        # self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')  # get device name: CPU or GPU
+        # self.device = torch.device('cuda:{}'.format(LOCAL_RANK)) if self.gpu_ids else torch.device('cpu')
+        """!!!!!!!!!!!!!!!npu修改的地方!!!!!!!!!!!!!!!!!!1"""
+        self.device = torch.device('npu:{}'.format(LOCAL_RANK)) 
+
+
+
+
         self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)  # save all the checkpoints to save_dir
         if opt.preprocess != 'scale_width':  # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
             torch.backends.cudnn.benchmark = True
@@ -95,13 +106,6 @@
                 net = getattr(self, 'net' + name)
                 net.eval()
 
-    def train(self):
-        for name in self.model_names:
-            if isinstance(name, str):
-                net = getattr(self, 'net' + name)
-                net.train()
-
-
     def test(self):
         """Forward function used in test time.
 
@@ -162,10 +166,16 @@
 
                 if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                     torch.save(net.module.cpu().state_dict(), save_path)
-                    net.cuda(self.gpu_ids[0])
+                    """!!!!!!!!!!!!!!!修改的地方!!!!!!!!!!!!!!!!!!1"""
+                    net.cuda(self.device)
+                    # net.cuda(self.gpu_ids[0])
+                elif len(self.gpu_ids) > 0 and torch.npu.is_available():
+                    """!!!!!!!!!!!!!!!npu修改的地方!!!!!!!!!!!!!!!!!!1"""
+                    torch.save(net.module.cpu().state_dict(), save_path)
+                    net.to(self.device)     
                 else:
                     torch.save(net.cpu().state_dict(), save_path)
-
+ 
     def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
         """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
         key = keys[i]
@@ -191,8 +201,12 @@
                 load_filename = '%s_net_%s.pth' % (epoch, name)
                 load_path = os.path.join(self.save_dir, load_filename)
                 net = getattr(self, 'net' + name)
-                if isinstance(net, torch.nn.DataParallel):
+                """!!!!!!!!!!!!!!!修改的地方!!!!!!!!!!!!!!!!!!1"""
+                if isinstance(net, torch.nn.parallel.DistributedDataParallel):
                     net = net.module
+
+                # if isinstance(net, torch.nn.DataParallel):
+                #     net = net.module
                 print('loading the model from %s' % load_path)
                 # if you are using PyTorch newer than 0.4 (e.g., built from
                 # GitHub source), you can remove str() on self.device
@@ -223,6 +237,28 @@
                 print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
         print('-----------------------------------------------')
 
+    """!!!!!!!!!!!!!!!修改的地方!!!!!!!!!!!!!!!!!!1"""
+    def save_onnx(self, dummy_input):
+          # 导出onnx模型
+        input_names = ['inputs']  # 输入名字
+        output_names = ['outputs']  # 输出名字
+        dummy_input = torch.randn_like(dummy_input)
+        """!!!!!!!!!!!!!!!npu修改的地方!!!!!!!!!!!!!!!!!!1"""
+        dynamic_axes = {'inputs': {0: '-1'}, 'outputs': {0: '-1'}}
+
+        for name in self.model_names:  # 只有netG
+            if isinstance(name, str):
+                load_filename = 'net' + name + '_onnx.onnx'
+                save_path = os.path.join(self.save_dir, load_filename)
+                net = getattr(self, 'net' + name)
+                if isinstance(net, torch.nn.parallel.DistributedDataParallel):
+                    net = net.module
+                """!!!!!!!!!!!!!!!npu修改的地方!!!!!!!!!!!!!!!!!!1"""
+                torch.onnx.export(net.cpu(), dummy_input, save_path,
+                                  input_names=input_names,dynamic_axes = dynamic_axes, output_names=output_names,
+                                  opset_version=11, verbose=True)
+        print('save onnx model on ', save_path)
+
     def set_requires_grad(self, nets, requires_grad=False):
         """Set requies_grad=Fasle for all the networks to avoid unnecessary computations
         Parameters:
diff -u -r '--exclude=*.md' ./origin/models/colorization_model.py ./fix/models/colorization_model.py
--- ./origin/models/colorization_model.py	2023-06-28 08:07:10.719967614 +0000
+++ ./fix/models/colorization_model.py	2023-06-28 08:25:08.451980818 +0000
@@ -1,5 +1,6 @@
 from .pix2pix_model import Pix2PixModel
-import torch
+"""!!!!!!!!!!!!!!!npu修改的地方!!!!!!!!!!!!!!!!!!1"""
+import torch.npu
 from skimage import color  # used for lab2rgb
 import numpy as np
 
diff -u -r '--exclude=*.md' ./origin/models/cycle_gan_model.py ./fix/models/cycle_gan_model.py
--- ./origin/models/cycle_gan_model.py	2023-06-28 08:07:10.719967614 +0000
+++ ./fix/models/cycle_gan_model.py	2023-06-28 08:25:08.487980819 +0000
@@ -1,4 +1,5 @@
-import torch
+"""!!!!!!!!!!!!!!!npu修改的地方!!!!!!!!!!!!!!!!!!1"""
+import torch.npu
 import itertools
 from util.image_pool import ImagePool
 from .base_model import BaseModel
diff -u -r '--exclude=*.md' ./origin/models/networks.py ./fix/models/networks.py
--- ./origin/models/networks.py	2023-06-28 08:07:10.719967614 +0000
+++ ./fix/models/networks.py	2023-06-28 08:25:08.479980819 +0000
@@ -1,9 +1,11 @@
-import torch
+"""!!!!!!!!!!!!!!!npu修改的地方!!!!!!!!!!!!!!!!!!1"""
+import torch.npu
 import torch.nn as nn
 from torch.nn import init
 import functools
 from torch.optim import lr_scheduler
-
+import os
+LOCAL_RANK = int(os.getenv('LOCAL_RANK', 0))  # https://pytorch.org/docs/stable/elastic/run.html
 
 ###############################################################################
 # Helper Functions
@@ -59,6 +61,13 @@
         scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
     elif opt.lr_policy == 'cosine':
         scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
+    elif opt.lr_policy == 'warm_up':
+        """!!!!!!!!!!!!!!!npu修改的地方!!!!!!!!!!!!!!!!!!1"""
+        import math
+        warm_up_epochs = 5
+        warm_up_with_cosine_lr = lambda epoch: (epoch + 1) / warm_up_epochs if epoch < warm_up_epochs \
+            else 0.5 * (math.cos((epoch - warm_up_epochs) / (opt.n_epochs+opt.n_epochs_decay- warm_up_epochs) * math.pi) + 1)
+        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warm_up_with_cosine_lr)
     else:
         return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy)
     return scheduler
@@ -109,9 +118,15 @@
     Return an initialized network.
     """
     if len(gpu_ids) > 0:
-        assert(torch.cuda.is_available())
-        net.to(gpu_ids[0])
-        net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs
+        """!!!!!!!!!!!!!!!npu修改的地方!!!!!!!!!!!!!!!!!!1"""
+        # assert(torch.cuda.is_available())
+        """!!!!!!!!!!!!!!!修改的地方!!!!!!!!!!!!!!!!!!1"""
+        # device = torch.device("cuda", LOCAL_RANK)
+        """!!!!!!!!!!!!!!!npu修改的地方!!!!!!!!!!!!!!!!!!1"""
+        device = torch.device("npu", LOCAL_RANK)
+        net.to(device)
+        # net.to(gpu_ids[0])
+        # net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs
     init_weights(net, init_type, init_gain=init_gain)
     return net
 
diff -u -r '--exclude=*.md' ./origin/models/pix2pix_model.py ./fix/models/pix2pix_model.py
--- ./origin/models/pix2pix_model.py	2023-06-28 08:07:10.719967614 +0000
+++ ./fix/models/pix2pix_model.py	2023-06-28 08:25:08.491980819 +0000
@@ -1,7 +1,12 @@
-import torch
+"""!!!!!!!!!!!!!!!npu修改的地方!!!!!!!!!!!!!!!!!!1"""
+import torch.npu
 from .base_model import BaseModel
 from . import networks
-
+#from apex import amp
+#import apex
+import os
+from torch.nn.parallel import DistributedDataParallel as DDP
+LOCAL_RANK = int(os.getenv('LOCAL_RANK', 0))  # https://pytorch.org/docs/stable/elastic/run.html
 
 class Pix2PixModel(BaseModel):
     """ This class implements the pix2pix model, for learning a mapping from input images to output images given paired data.
@@ -53,9 +58,12 @@
         else:  # during test time, only load G
             self.model_names = ['G']
         # define networks (both generator and discriminator)
+  
         self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
                                       not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
+      
 
+        # print("!!!!!!!!!!!!!!!!!!!!!!!!!!")
         if self.isTrain:  # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
             self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
                                           opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
@@ -65,11 +73,28 @@
             self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
             self.criterionL1 = torch.nn.L1Loss()
             # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
-            self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
-            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
+            """!!!!!!!!!性能!!!!!!!!"""
+            # self.optimizer_G = apex.optimizers.NpuFusedAdam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))   
+            # self.optimizer_D = apex.optimizers.NpuFusedAdam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))   
+            # self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
+            # self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
+
             self.optimizers.append(self.optimizer_G)
             self.optimizers.append(self.optimizer_D)
-
+            """!!!!!!!!!!!!!!!修改的地方!!!!!!!!!!!!!!!!!!1"""
+            """!!!!!!!!!性能!!!!!!!!"""
+            # self.netG, self.optimizer_G = amp.initialize(self.netG, self.optimizer_G, opt_level='O2', loss_scale=32.0, combine_grad=True)
+            # self.netD, self.optimizer_D = amp.initialize(self.netD, self.optimizer_D, opt_level='O2', loss_scale=32.0, combine_grad=True)
+            # self.netG, self.optimizer_G = amp.initialize(self.netG, self.optimizer_G, opt_level="O2",loss_scale=128.0)
+            # self.netD, self.optimizer_D = amp.initialize(self.netD, self.optimizer_D, opt_level="O2",loss_scale=128.0)
+
+            """!!!!!!!!!!!!!!!修改的地方!!!!!!!!!!!!!!!!!!1"""
+            """!!!!!!!!!!!!!!!npu修改的地方!!!!!!!!!!!!!!!!!!1"""
+        #     self.netD = DDP(self.netD, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK,broadcast_buffers=False)
+        # self.netG = DDP(self.netG, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK,broadcast_buffers=False)
+        #     # self.netD = torch.nn.DataParallel(self.netD,self.gpu_ids)  # multi-GPUs 改成distribute那种,再model zoo 上面参考https://gitcode.com/ascend/modelzoo/tree/master/contrib/PyTorch/Official/cv/image_classification
+        # 解决Warning->GPU多卡(prof文件性能文件 pytorch里面有性能测试工具,log文件)->NPU 1P(跑通)->NPU 8P(也按200epoch对齐GPU loss)
+        # self.netG = torch.nn.DataParallel(self.netG, self.gpu_ids)  # multi-GPUs
     def set_input(self, input):
         """Unpack input data from the dataloader and perform necessary pre-processing steps.
 
@@ -99,6 +124,8 @@
         self.loss_D_real = self.criterionGAN(pred_real, True)
         # combine loss and calculate gradients
         self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
+        #with amp.scale_loss(self.loss_D, self.optimizer_D) as scaled_loss:
+        #    scaled_loss.backward()
         self.loss_D.backward()
 
     def backward_G(self):
@@ -111,6 +138,8 @@
         self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1
         # combine loss and calculate gradients
         self.loss_G = self.loss_G_GAN + self.loss_G_L1
+        #with amp.scale_loss(self.loss_G, self.optimizer_G) as scaled_loss:
+        #    scaled_loss.backward()
         self.loss_G.backward()
 
     def optimize_parameters(self):
diff -u -r '--exclude=*.md' ./origin/models/template_model.py ./fix/models/template_model.py
--- ./origin/models/template_model.py	2023-06-28 08:07:10.719967614 +0000
+++ ./fix/models/template_model.py	2023-06-28 08:25:08.455980818 +0000
@@ -1,3 +1,4 @@
+
 """Model class template
 
 This module provides a template for users to implement custom models.
@@ -15,7 +16,8 @@
     <forward>: Run forward pass. This will be called by both <optimize_parameters> and <test>.
     <optimize_parameters>: Update network weights; it will be called in every training iteration.
 """
-import torch
+"""!!!!!!!!!!!!!!!npu修改的地方!!!!!!!!!!!!!!!!!!1"""
+import torch.npu
 from .base_model import BaseModel
 from . import networks
 
diff -u -r '--exclude=*.md' ./origin/options/base_options.py ./fix/options/base_options.py
--- ./origin/options/base_options.py	2023-06-28 08:07:10.719967614 +0000
+++ ./fix/options/base_options.py	2023-06-28 08:25:08.247980816 +0000
@@ -1,9 +1,12 @@
 import argparse
 import os
 from util import util
-import torch
+"""!!!!!!!!!!!!!!!npu修改的地方!!!!!!!!!!!!!!!!!!1"""
+import torch.npu
 import models
 import data
+"""!!!!!!!!!!!!!!!npu8p修改的地方!!!!!!!!!!!!!!!!!!1"""
+LOCAL_RANK = int(os.getenv('LOCAL_RANK', 0)) 
 
 
 class BaseOptions():
@@ -20,7 +23,8 @@
     def initialize(self, parser):
         """Define the common options that are used in both training and test."""
         # basic parameters
-        parser.add_argument('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
+        """推理提交修改"""
+        parser.add_argument('--dataroot',  help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
         parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
         parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0  0,1,2, 0,2. use -1 for CPU')
         parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
@@ -52,7 +56,7 @@
         # additional parameters
         parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
         parser.add_argument('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')
-        parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
+        parser.add_argument('--verbose', action='store_true', help='if specified, print more information')
         parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
         self.initialized = True
         return parser
@@ -114,7 +118,8 @@
         """Parse our options, create checkpoints directory suffix, and set up gpu device."""
         opt = self.gather_options()
         opt.isTrain = self.isTrain   # train or test
-
+        """!!!!!!!!!!!!!!!npu修改的地方!!!!!!!!!!!!!!!!!!1"""
+        CALCULATE_DEVICE = "npu:0"
         # process opt.suffix
         if opt.suffix:
             suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
@@ -130,7 +135,11 @@
             if id >= 0:
                 opt.gpu_ids.append(id)
         if len(opt.gpu_ids) > 0:
-            torch.cuda.set_device(opt.gpu_ids[0])
-
+            """!!!!!!!!!!!!!!!npu修改的地方!!!!!!!!!!!!!!!!!!1"""
+            # torch.cuda.set_device(opt.gpu_ids[0])
+            """!!!!!!!!!!!!!!!npu8p修改的地方!!!!!!!!!!!!!!!!!!1"""
+            loc = 'npu:{}'.format(LOCAL_RANK)
+            torch.npu.set_device(loc)
+            # torch.npu.set_device(CALCULATE_DEVICE)
         self.opt = opt
         return self.opt
diff -u -r '--exclude=*.md' ./origin/options/test_options.py ./fix/options/test_options.py
--- ./origin/options/test_options.py	2023-06-28 08:07:10.719967614 +0000
+++ ./fix/options/test_options.py	2023-06-28 08:25:08.239980816 +0000
@@ -19,5 +19,8 @@
         parser.set_defaults(model='test')
         # To avoid cropping, the load_size should be the same as crop_size
         parser.set_defaults(load_size=parser.get_default('crop_size'))
+        """"!!!!!!!!添加代码 加入保存onnx保存选项!!!!!"""
+        parser.add_argument('--save_onnx', type=bool, default=False, help='save onnx model')
+
         self.isTrain = False
         return parser
Only in ./origin/options: train_options.py
Only in ./origin: pix2pix.ipynb
Only in ./origin: requirements.txt
Only in ./origin/scripts: conda_deps.sh
Only in ./origin/scripts/edges: PostprocessHED.m
diff -u -r '--exclude=*.md' ./origin/scripts/edges/batch_hed.py ./fix/scripts/edges/batch_hed.py
--- ./origin/scripts/edges/batch_hed.py	2023-06-28 08:07:10.719967614 +0000
+++ ./fix/scripts/edges/batch_hed.py	2023-06-28 08:25:08.367980817 +0000
@@ -1,3 +1,4 @@
+
 # HED batch processing script; modified from https://github.com/s9xie/hed/blob/master/examples/hed/HED-tutorial.ipynb
 # Step 1: download the hed repo: https://github.com/s9xie/hed
 # Step 2: download the models and protoxt, and put them under {caffe_root}/examples/hed/
diff -u -r '--exclude=*.md' ./origin/scripts/eval_cityscapes/cityscapes.py ./fix/scripts/eval_cityscapes/cityscapes.py
--- ./origin/scripts/eval_cityscapes/cityscapes.py	2023-06-28 08:07:10.719967614 +0000
+++ ./fix/scripts/eval_cityscapes/cityscapes.py	2023-06-28 08:25:08.343980817 +0000
@@ -1,3 +1,4 @@
+
 # The following code is modified from https://github.com/shelhamer/clockwork-fcn
 import sys
 import os
diff -u -r '--exclude=*.md' ./origin/scripts/eval_cityscapes/util.py ./fix/scripts/eval_cityscapes/util.py
--- ./origin/scripts/eval_cityscapes/util.py	2023-06-28 08:07:10.719967614 +0000
+++ ./fix/scripts/eval_cityscapes/util.py	2023-06-28 08:25:08.339980817 +0000
@@ -1,3 +1,4 @@
+
 # The following code is modified from https://github.com/shelhamer/clockwork-fcn
 import numpy as np
 
Only in ./origin/scripts: install_deps.sh
Only in ./origin/scripts: test_before_push.py
Only in ./origin/scripts: train_colorization.sh
Only in ./origin/scripts: train_cyclegan.sh
Only in ./origin/scripts: train_pix2pix.sh
diff -u -r '--exclude=*.md' ./origin/test.py ./fix/test.py
--- ./origin/test.py	2023-06-28 08:07:10.719967614 +0000
+++ ./fix/test.py	2023-06-28 08:25:08.303980816 +0000
@@ -1,3 +1,4 @@
+
 """General-purpose test script for image-to-image translation.
 
 Once you have trained your model with train.py, you can use this script to test the model.
@@ -32,10 +33,38 @@
 from models import create_model
 from util.visualizer import save_images
 from util import html
-
+"""!!!!!!!!!!!!!!!npu修改的地方!!!!!!!!!!!!!!!!!!1"""
+import torch.npu
+# python test.py --dataroot ./datasets/facades/ --direction BtoA --model pix2pix --name facades_label2photo_pretrained --norm instance 
+# python test.py --dataroot ./datasets/facades/ --direction BtoA --model pix2pix --name facades_label2photo_pretrained --save_onnx True  
+# python test.py --dataroot ./datasets/facades/ --direction BtoA --model pix2pix --name facades_pix2pix_1p_bs1_lr0002_ep200 --norm instance 
+# python test.py --dataroot ./datasets/facades/ --direction BtoA --model pix2pix --name facades_pix2pix_2p_bs1_lr0002_ep200 --norm instance 
+# python test.py --dataroot ./datasets/facades/ --direction BtoA --model pix2pix --name facades_pix2pix_8p_bs1_lr0002_ep200 --norm instance 
+# python test.py --dataroot ./datasets/facades/ --direction BtoA --model pix2pix --name facades_pix2pix_8p_bs1_lr0002_ep1200 --norm instance 
+# python test.py --dataroot ./datasets/facades/ --direction BtoA --model pix2pix --name facades_pix2pix_8p_bs1_lr0016_ep200 --norm instance 
+# python test.py --dataroot ./datasets/facades/ --direction BtoA --model pix2pix --name facades_pix2pix_8p_bs1_lr0002_ep200 --norm instance 
+# python test.py --dataroot ./datasets/facades/ --direction BtoA --model pix2pix --name facades_pix2pix_8p_bs1_lr0002_ep200_NpuFusedAdam --norm instance 
+# python test.py --dataroot ./datasets/facades/ --direction BtoA --model pix2pix --name facades_pix2pix_8p_bs1_lr0002_ep200_Onlycombine_grad --norm instance 
+# python test.py --dataroot ./datasets/facades/ --direction BtoA --model pix2pix --name facades_pix2pix_8p_bs1_lr0002_ep200_OnlyNpuFusedAdam --norm instance 
+
+"""!!!!!!!!!!!!!!!修改的地方!!!!!!!!!!!!!!!!!!1"""
+LOCAL_RANK = int(os.getenv('LOCAL_RANK', 0))  # https://pytorch.org/docs/stable/elastic/run.html
+RANK = int(os.getenv('RANK', 0))
+WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
+"""!!!!!!!!!!!!!!!npu修改的地方!!!!!!!!!!!!!!!!!!1"""
+# os.environ['MASTER_ADDR'] = 'localhost'
+# os.environ['MASTER_PORT'] = '5678'
+os.environ['MASTER_ADDR'] = '127.0.0.1'
+os.environ['MASTER_PORT'] = '29128'
 
 if __name__ == '__main__':
     opt = TestOptions().parse()  # get test options
+    """!!!!!!!!!!!!!!!修改的地方!!!!!!!!!!!!!!!!!!1"""
+    """!!!!!!!!!!!!!!!npu修改的地方!!!!!!!!!!!!!!!!!!1"""
+    # torch.distributed.init_process_group(backend="nccl", rank=RANK, world_size=WORLD_SIZE)
+    torch.distributed.init_process_group(backend="hccl", rank=RANK, world_size=WORLD_SIZE)
+    print(f"[init] == local rank: {LOCAL_RANK}, global rank: {RANK} , world size: {WORLD_SIZE}")
+
     # hard-code some parameters for test
     opt.num_threads = 0   # test code only supports num_threads = 0
     opt.batch_size = 1    # test code only supports batch_size = 1
@@ -67,3 +96,7 @@
             print('processing (%04d)-th image... %s' % (i, img_path))
         save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize)
     webpage.save()  # save the HTML
+    if opt.save_onnx:
+        for i, data in enumerate(dataset):
+            model.save_onnx(data['A'])
+            break
Only in ./origin: train.py
diff -u -r '--exclude=*.md' ./origin/util/__init__.py ./fix/util/__init__.py
--- ./origin/util/__init__.py	2023-06-28 08:07:10.719967614 +0000
+++ ./fix/util/__init__.py	2023-06-28 08:25:08.379980817 +0000
@@ -1 +1,2 @@
+
 """This package includes a miscellaneous collection of useful helper functions."""
diff -u -r '--exclude=*.md' ./origin/util/image_pool.py ./fix/util/image_pool.py
--- ./origin/util/image_pool.py	2023-06-28 08:07:10.719967614 +0000
+++ ./fix/util/image_pool.py	2023-06-28 08:25:08.391980818 +0000
@@ -1,5 +1,6 @@
 import random
-import torch
+"""!!!!!!!!!!!!!!!npu修改的地方!!!!!!!!!!!!!!!!!!1"""
+import torch.npu
 
 
 class ImagePool():
diff -u -r '--exclude=*.md' ./origin/util/util.py ./fix/util/util.py
--- ./origin/util/util.py	2023-06-28 08:07:10.719967614 +0000
+++ ./fix/util/util.py	2023-06-28 08:25:08.403980818 +0000
@@ -1,6 +1,7 @@
 """This module contains simple helper functions """
 from __future__ import print_function
-import torch
+"""!!!!!!!!!!!!!!!npu修改的地方!!!!!!!!!!!!!!!!!!1"""
+import torch.npu
 import numpy as np
 from PIL import Image
 import os
diff -u -r '--exclude=*.md' ./origin/util/visualizer.py ./fix/util/visualizer.py
--- ./origin/util/visualizer.py	2023-06-28 08:07:10.719967614 +0000
+++ ./fix/util/visualizer.py	2023-06-28 08:25:08.383980817 +0000
@@ -80,14 +80,9 @@
             util.mkdirs([self.web_dir, self.img_dir])
         # create a logging file to store training losses
         self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
-        self.fid_log_name = os.path.join(opt.checkpoints_dir, opt.name, 'fid_log.txt')
-
         with open(self.log_name, "a") as log_file:
             now = time.strftime("%c")
             log_file.write('================ Training Loss (%s) ================\n' % now)
-        with open(self.fid_log_name, "a") as log_file:
-            now = time.strftime("%c")
-            log_file.write('================ Validation FID (%s) ================\n' % now)
 
     def reset(self):
         """Reset the self.saved status"""
@@ -181,29 +176,6 @@
                 webpage.add_images(ims, txts, links, width=self.win_size)
             webpage.save()
 
-    def plot_current_fid(self, epoch, fid):
-        """display the current fid on visdom display
-
-        Parameters:
-            epoch (int)  -- current epoch
-            fid (float)  -- validation fid
-        """
-        if not hasattr(self, 'fid_plot_data'):
-            self.fid_plot_data = {'X': [], 'Y': []}
-        self.fid_plot_data['X'].append(epoch)
-        self.fid_plot_data['Y'].append(fid)
-        try:
-            self.vis.line(
-                X=np.array(self.fid_plot_data['X']),
-                Y=np.array(self.fid_plot_data['Y']),
-                opts={
-                    'title': self.name + ' fid over time',
-                    'xlabel': 'epoch',
-                    'ylabel': 'fid'},
-                win=self.display_id + 4)
-        except VisdomExceptionBase:
-            self.create_visdom_connections()
-
     def plot_current_losses(self, epoch, counter_ratio, losses):
         """display the current losses on visdom display: dictionary of error labels and values
 
@@ -230,7 +202,7 @@
             self.create_visdom_connections()
 
     # losses: same format as |losses| of plot_current_losses
-    def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
+    def print_current_losses(self, epoch, iters, losses, t_comp, t_data, FPS):
         """print current losses on console; also save the losses to the disk
 
         Parameters:
@@ -240,23 +212,10 @@
             t_comp (float) -- computational time per data point (normalized by batch_size)
             t_data (float) -- data loading time per data point (normalized by batch_size)
         """
-        message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
+        message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f, FPS %d) ' % (epoch, iters, t_comp, t_data,FPS)
         for k, v in losses.items():
             message += '%s: %.3f ' % (k, v)
 
         print(message)  # print the message
         with open(self.log_name, "a") as log_file:
             log_file.write('%s\n' % message)  # save the message
-
-    def print_current_fid(self, epoch, fid):
-        """print current fid on console; also save the fid to the disk
-
-        Parameters:
-            epoch (int) -- current epoch
-            fid (float) - fid metric
-        """
-        message = '(epoch: %d, fid: %.3f) ' % (epoch, fid)
-
-        print(message)  # print the message
-        with open(self.fid_log_name, "a") as log_file:
-            log_file.write('%s\n' % message)  # save the message