05360171创建于 2022年3月18日历史提交
diff --git a/models/pix2pixHD_model.py b/models/pix2pixHD_model.py
index fafdec0..9c96a9f 100755
--- a/models/pix2pixHD_model.py
+++ b/models/pix2pixHD_model.py
@@ -5,6 +5,7 @@ from torch.autograd import Variable
 from util.image_pool import ImagePool
 from .base_model import BaseModel
 from . import networks
+import time
 
 class Pix2PixHDModel(BaseModel):
     def name(self):
@@ -213,7 +214,13 @@ class Pix2PixHDModel(BaseModel):
             with torch.no_grad():
                 fake_image = self.netG.forward(input_concat)
         else:
-            fake_image = self.netG.forward(input_concat)
+            with torch.no_grad():
+                # torch.cuda.synchronize()
+                start = time.time()
+                fake_image = self.netG.forward(input_concat)
+                # torch.cuda.synchronize()
+                end = time.time()
+                print("本次生成图像耗时{}s".format(end - start))
         return fake_image
 
     def sample_features(self, inst): 
diff --git a/test.py b/test.py
index e0b1ec3..83d13fc 100755
--- a/test.py
+++ b/test.py
@@ -64,4 +64,4 @@ for i, data in enumerate(dataset):
     print('process image... %s' % img_path)
     visualizer.save_images(webpage, visuals, img_path)
 
-webpage.save()
+webpage.save()
\ No newline at end of file