GGiteeupdate
cd100dc0创建于 2022年7月28日历史提交
--- loader.py	2022-07-28 16:59:02.740926771 +0800
+++ loader.py	2022-07-28 16:52:32.728921993 +0800
@@ -63,8 +63,8 @@
                  re_count=1,
                  re_num_splits=0):
         self.loader = loader
-        self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1)
-        self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1)
+        self.mean = torch.tensor([x * 255 for x in mean]).view(1, 3, 1, 1)
+        self.std = torch.tensor([x * 255 for x in std]).view(1, 3, 1, 1)
         self.fp16 = fp16
         if fp16:
             self.mean = self.mean.half()
@@ -76,26 +76,19 @@
             self.random_erasing = None
 
     def __iter__(self):
-        stream = torch.cuda.Stream()
         first = True
-
         for next_input, next_target in self.loader:
-            with torch.cuda.stream(stream):
-                next_input = next_input.cuda(non_blocking=True)
-                next_target = next_target.cuda(non_blocking=True)
-                if self.fp16:
-                    next_input = next_input.half().sub_(self.mean).div_(self.std)
-                else:
-                    next_input = next_input.float().sub_(self.mean).div_(self.std)
-                if self.random_erasing is not None:
-                    next_input = self.random_erasing(next_input)
+            if self.fp16:
+                next_input = next_input.half().sub_(self.mean).div_(self.std)
+            else:
+                next_input = next_input.float().sub_(self.mean).div_(self.std)
+            if self.random_erasing is not None:
+                next_input = self.random_erasing(next_input)
 
             if not first:
                 yield input, target
             else:
                 first = False
-
-            torch.cuda.current_stream().wait_stream(stream)
             input = next_input
             target = next_target