@@ -1,5 +1,5 @@
import threading
-import Queue
+import queue
import time
@@ -46,7 +46,7 @@ class Enqueuer(object):
assert num_threads > 0
self.num_threads = num_threads
self.queue_size = queue_size
- self.queue = Queue.Queue(maxsize=queue_size)
+ self.queue = queue.Queue(maxsize=queue_size)
# The pointer shared by threads.
self.ptr = Counter(max_val=num_elements)
# The event to wake up threads, it's set at the beginning of an epoch.
@@ -1,4 +1,5 @@
import numpy as np
+import os
import os.path as osp
ospj = osp.join
ospeu = osp.expanduser
@@ -24,8 +25,8 @@ def create_dataset(
########################################
if name == 'market1501':
- im_dir = ospeu('~/Dataset/market1501/images')
- partition_file = ospeu('~/Dataset/market1501/partitions.pkl')
+ im_dir = ospeu('./market1501/images')
+ partition_file = ospeu('./market1501/partitions.pkl')
elif name == 'cuhk03':
im_type = ['detected', 'labeled'][0]
@@ -50,7 +51,6 @@ def create_dataset(
cmc_kwargs = dict(separate_camera_set=False,
single_gallery_shot=False,
first_match_break=True)
-
partitions = load_pickle(partition_file)
im_names = partitions['{}_im_names'.format(part)]
@@ -9,7 +9,7 @@ from .resnet import resnet50
class Model(nn.Module):
def __init__(self, local_conv_out_channels=128, num_classes=None):
super(Model, self).__init__()
- self.base = resnet50(pretrained=True)
+ self.base = resnet50(pretrained=False)
planes = 2048
self.local_conv = nn.Conv2d(planes, local_conv_out_channels, 1)
self.local_bn = nn.BatchNorm2d(local_conv_out_channels)
@@ -28,7 +28,8 @@ class Model(nn.Module):
"""
# shape [N, C, H, W]
feat = self.base(x)
- global_feat = F.avg_pool2d(feat, feat.size()[2:])
+ feat_size = [int(s) for s in feat.size()[2:]]
+ global_feat = F.avg_pool2d(feat, feat_size)
# shape [N, C]
global_feat = global_feat.view(global_feat.size(0), -1)
# shape [N, C, H, 1]
@@ -147,7 +147,7 @@ class ResNet(nn.Module):
def remove_fc(state_dict):
"""Remove the fc layer parameters from state_dict."""
- for key, value in state_dict.items():
+ for key in list(state_dict.keys()):
if key.startswith('fc.'):
del state_dict[key]
return state_dict
@@ -1,7 +1,7 @@
from __future__ import print_function
import os
import os.path as osp
-import cPickle as pickle
+import pickle as pickle
from scipy import io
import datetime
import time
@@ -288,7 +288,7 @@ def load_state_dict(model, src_state_dict):
param = param.data
try:
dest_state_dict[name].copy_(param)
- except Exception, msg:
+ except Exception as msg:
print("Warning: Error occurs when copying '{}': {}"
.format(name, str(msg)))