@@ -5,6 +5,7 @@ from torch.autograd import Variable
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks
+import time
class Pix2PixHDModel(BaseModel):
def name(self):
@@ -213,7 +214,13 @@ class Pix2PixHDModel(BaseModel):
with torch.no_grad():
fake_image = self.netG.forward(input_concat)
else:
- fake_image = self.netG.forward(input_concat)
+ with torch.no_grad():
+ # torch.cuda.synchronize()
+ start = time.time()
+ fake_image = self.netG.forward(input_concat)
+ # torch.cuda.synchronize()
+ end = time.time()
+ print("本次生成图像耗时{}s".format(end - start))
return fake_image
def sample_features(self, inst):
@@ -64,4 +64,4 @@ for i, data in enumerate(dataset):
print('process image... %s' % img_path)
visualizer.save_images(webpage, visuals, img_path)
-webpage.save()
+webpage.save()
\ No newline at end of file