05360171创建于 2022年3月18日历史提交
diff --git a/aligned_reid/dataset/Prefetcher.py b/aligned_reid/dataset/Prefetcher.py
index 4324f1b..89d7d51 100644
--- a/aligned_reid/dataset/Prefetcher.py
+++ b/aligned_reid/dataset/Prefetcher.py
@@ -1,5 +1,5 @@
 import threading
-import Queue
+import queue
 import time
 
 
@@ -46,7 +46,7 @@ class Enqueuer(object):
     assert num_threads > 0
     self.num_threads = num_threads
     self.queue_size = queue_size
-    self.queue = Queue.Queue(maxsize=queue_size)
+    self.queue = queue.Queue(maxsize=queue_size)
     # The pointer shared by threads.
     self.ptr = Counter(max_val=num_elements)
     # The event to wake up threads, it's set at the beginning of an epoch.
diff --git a/aligned_reid/dataset/__init__.py b/aligned_reid/dataset/__init__.py
index 3b89cd3..a056843 100644
--- a/aligned_reid/dataset/__init__.py
+++ b/aligned_reid/dataset/__init__.py
@@ -1,4 +1,5 @@
 import numpy as np
+import os
 import os.path as osp
 ospj = osp.join
 ospeu = osp.expanduser
@@ -24,8 +25,8 @@ def create_dataset(
   ########################################
 
   if name == 'market1501':
-    im_dir = ospeu('~/Dataset/market1501/images')
-    partition_file = ospeu('~/Dataset/market1501/partitions.pkl')
+    im_dir = ospeu('./market1501/images')
+    partition_file = ospeu('./market1501/partitions.pkl')
 
   elif name == 'cuhk03':
     im_type = ['detected', 'labeled'][0]
@@ -50,7 +51,6 @@ def create_dataset(
   cmc_kwargs = dict(separate_camera_set=False,
                     single_gallery_shot=False,
                     first_match_break=True)
-
   partitions = load_pickle(partition_file)
   im_names = partitions['{}_im_names'.format(part)]
 
diff --git a/aligned_reid/model/Model.py b/aligned_reid/model/Model.py
index 20e50eb..dd2e979 100644
--- a/aligned_reid/model/Model.py
+++ b/aligned_reid/model/Model.py
@@ -9,7 +9,7 @@ from .resnet import resnet50
 class Model(nn.Module):
   def __init__(self, local_conv_out_channels=128, num_classes=None):
     super(Model, self).__init__()
-    self.base = resnet50(pretrained=True)
+    self.base = resnet50(pretrained=False)
     planes = 2048
     self.local_conv = nn.Conv2d(planes, local_conv_out_channels, 1)
     self.local_bn = nn.BatchNorm2d(local_conv_out_channels)
@@ -28,7 +28,8 @@ class Model(nn.Module):
     """
     # shape [N, C, H, W]
     feat = self.base(x)
-    global_feat = F.avg_pool2d(feat, feat.size()[2:])
+    feat_size = [int(s) for s in feat.size()[2:]]
+    global_feat = F.avg_pool2d(feat, feat_size)
     # shape [N, C]
     global_feat = global_feat.view(global_feat.size(0), -1)
     # shape [N, C, H, 1]
diff --git a/aligned_reid/model/resnet.py b/aligned_reid/model/resnet.py
index 0af1192..5ea8c58 100644
--- a/aligned_reid/model/resnet.py
+++ b/aligned_reid/model/resnet.py
@@ -147,7 +147,7 @@ class ResNet(nn.Module):
 
 def remove_fc(state_dict):
   """Remove the fc layer parameters from state_dict."""
-  for key, value in state_dict.items():
+  for key in list(state_dict.keys()):
     if key.startswith('fc.'):
       del state_dict[key]
   return state_dict
diff --git a/aligned_reid/utils/utils.py b/aligned_reid/utils/utils.py
index d25fb26..0169cb6 100644
--- a/aligned_reid/utils/utils.py
+++ b/aligned_reid/utils/utils.py
@@ -1,7 +1,7 @@
 from __future__ import print_function
 import os
 import os.path as osp
-import cPickle as pickle
+import pickle as pickle
 from scipy import io
 import datetime
 import time
@@ -288,7 +288,7 @@ def load_state_dict(model, src_state_dict):
       param = param.data
     try:
       dest_state_dict[name].copy_(param)
-    except Exception, msg:
+    except Exception as msg:
       print("Warning: Error occurs when copying '{}': {}"
             .format(name, str(msg)))