import torch
import numpy as np
import os.path as osp
import os
import argparse
import torchreid
from torchreid.data.datasets.image.market1501 import Market1501
from torchreid.utils import load_pretrained_weights
from torchreid import metrics
# Start each run from a clean ./inference output directory.
# exist_ok avoids the check-then-create race of a separate exists() test.
os.makedirs("inference", exist_ok=True)
for _stale_name in os.listdir("inference"):
    _stale_path = osp.join("inference", _stale_name)
    # Only visible files and symlinks are cleared; dotfiles and real
    # sub-directories are left in place.
    if _stale_name.startswith(".") or (
        osp.isdir(_stale_path) and not osp.islink(_stale_path)
    ):
        continue
    try:
        os.remove(_stale_path)
    except OSError as err:
        # Cleanup is best-effort: report the problem and carry on.
        print("could not remove %s: %s" % (_stale_path, err))

# Command-line options. Defaults are shown in --help through
# ArgumentDefaultsHelpFormatter.
parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
# Dataset root handed to the torchreid data manager.
parser.add_argument('-d', '--data_path', type=str, default='./reid-data/')
# 'gpu' and 'npu' move the model and inputs to that accelerator; any other
# value leaves everything where it was created.
parser.add_argument('--device', type=str, default='cpu')
# Checkpoint file whose weights are loaded into the model.
parser.add_argument('--checkpoint', type=str, default='log/osnet_x1_0_market1501_softmax/model/model.pth.tar-350')
# NOTE(review): no visible code in this file reads config_file - confirm it is still needed.
parser.add_argument('--config_file', type=str, default='configs/osnet_x1_0_trained_from_scratch.yaml')
args = parser.parse_args()

# The chosen device is published through the environment; the model and
# feature-extraction helpers read it back from os.environ['device'].
os.environ['device'] = args.device
def build_model():
    """Create the OSNet x1.0 network, load the checkpoint and switch to eval mode.

    The weights come from ``args.checkpoint``. The model is moved to
    ``npu:0`` or ``cuda:0`` when ``os.environ['device']`` is ``'npu'`` or
    ``'gpu'``; for any other value it is left on the device it was built on.

    Returns:
        The model in evaluation mode.
    """
    # Maps the script's device names to torch device strings.
    device_targets = {'npu': "npu:0", 'gpu': "cuda:0"}

    network = torchreid.models.build_model(
        name="osnet_x1_0",
        # presumably the number of training identities in Market-1501 - TODO confirm
        num_classes=751,
        loss="softmax",
        pretrained=False,
        use_gpu=False
    )
    load_pretrained_weights(network, args.checkpoint)

    target = device_targets.get(os.environ['device'])
    if target is not None:
        network = network.to(target)

    network.eval()
    return network
def get_raw_data():
    """Build the Market-1501 data manager and pick one query image.

    The query sample is the element at a fixed position in the first batch
    produced by the query loader.

    Returns:
        A tuple ``(datamanager, gallery_loader, img, fname, pid)`` where
        ``img`` is the chosen query image with a leading batch dimension of
        one, ``fname`` is its ``impath`` entry and ``pid`` its ``pid`` entry.

    Raises:
        IndexError: if the first query batch holds too few samples.
    """
    # Position of the demo query sample inside the first query batch.
    sample_index = 24

    settings = {
        "root": args.data_path,
        'sources': ['market1501'],
        'targets': ['market1501'],
        'height': 256,
        'width': 128,
        # NOTE(review): these look like training-time augmentations; confirm
        # whether they have any effect on the test loaders used below.
        'transforms': ['random_flip', 'random_crop', 'random_patch'],
        'k_tfm': 1,
        'norm_mean': [0.485, 0.456, 0.406],
        'norm_std': [0.229, 0.224, 0.225],
        'use_gpu': False,
        'split_id': 0,
        'combineall': False,
        'load_train_targets': False,
        'batch_size_train': 32,
        'batch_size_test': 32,
        'workers': 4,
        'num_instances': 4,
        'num_cams': 1,
        'num_datasets': 1,
        'train_sampler': "RandomSampler",
        'train_sampler_t': "RandomSampler",
        'cuhk03_labeled': False,
        'cuhk03_classic_split': False,
        'market1501_500k': False,
    }
    datamanager = torchreid.data.ImageDataManager(**settings)
    loaders = datamanager.test_loader['market1501']
    query_loader = loaders['query']
    gallery_loader = loaders['gallery']

    # Only the first query batch is needed to pick the demo sample.
    first_batch = next(iter(query_loader))
    batch_imgs = first_batch['img']
    if sample_index >= len(batch_imgs):
        raise IndexError(
            "first query batch has %d samples, need at least %d"
            % (len(batch_imgs), sample_index + 1)
        )

    # Add a batch dimension so the single image can be fed to the model.
    img = batch_imgs[sample_index].unsqueeze(0)
    fname = first_batch['impath'][sample_index]
    pid = first_batch['pid'][sample_index]
    return datamanager, gallery_loader, img, fname, pid
def parse_data_for_eval(data):
    """Unpack one batch mapping into ``(fnames, imgs, pids, camids)``.

    Args:
        data: mapping with the keys ``'impath'``, ``'img'``, ``'pid'`` and
            ``'camid'``.

    Returns:
        The four values in the order file names, images, person ids,
        camera ids.

    Raises:
        KeyError: if any of the four keys is missing.
    """
    return data['impath'], data['img'], data['pid'], data['camid']
def feature_extraction(model, imgs):
    """Run ``model`` on ``imgs`` and return the features as a CPU tensor.

    The input is moved to CUDA or NPU first when ``os.environ['device']`` is
    ``'gpu'`` or ``'npu'``; for any other value it is used as given.

    Args:
        model: callable network applied to the image batch.
        imgs: tensor of input images.

    Returns:
        A copy of the model output on the CPU, detached from autograd.
    """
    device = os.environ['device']
    if device == 'gpu':
        imgs = imgs.cuda()
    elif device == 'npu':
        imgs = imgs.npu()

    # Inference only: disabling autograd avoids building a graph and keeping
    # intermediate activations alive.
    with torch.no_grad():
        features = model(imgs)
    return features.cpu().clone()
def find_imgs_with_id(id, data_loader):
    """Collect every sample in ``data_loader`` whose person id equals ``id``.

    Args:
        id: person id to look for.
        data_loader: iterable of batches accepted by ``parse_data_for_eval``.

    Returns:
        A tuple ``(imgs, pids, camids, fnames)``: the matching images
        concatenated along dimension 0, followed by lists of their person
        ids, camera ids and file names in encounter order.
    """
    matched_imgs = []
    matched_pids = []
    matched_camids = []
    matched_fnames = []

    for batch in data_loader:
        fnames, imgs, pids, camids = parse_data_for_eval(batch)
        for sample in zip(fnames, imgs, pids, camids):
            fname, img, pid, camid = sample
            if pid == id:
                # Give each image a leading dimension so they can be
                # concatenated into one batch afterwards.
                matched_imgs.append(img.unsqueeze(0))
                matched_pids.append(pid)
                matched_camids.append(camid)
                matched_fnames.append(fname)

    gallery_batch = torch.cat(matched_imgs, dim=0)
    return gallery_batch, matched_pids, matched_camids, matched_fnames
def feature_extraction_single(model, imgs):
    """Run ``model`` on the query image batch and return CPU features.

    The input is moved to CUDA or NPU first when ``os.environ['device']`` is
    ``'gpu'`` or ``'npu'``; for any other value it is used as given.

    Args:
        model: callable network applied to the input.
        imgs: tensor holding the image(s) to embed.

    Returns:
        A copy of the model output on the CPU, detached from autograd.
    """
    device = os.environ['device']
    if device == 'gpu':
        imgs = imgs.cuda()
    elif device == 'npu':
        imgs = imgs.npu()

    # Inference only: no gradients are needed, so skip autograd tracking.
    with torch.no_grad():
        features = model(imgs)
    return features.cpu().clone()
def save_image(fname):
    """Copy the file at ``fname`` into ``./inference/`` under its base name.

    The copy is done in-process rather than through a shell command, so
    paths containing spaces or shell metacharacters are handled correctly.

    Args:
        fname: path of the image file to copy.

    Raises:
        OSError: if the source cannot be read or the copy cannot be written.
    """
    # Local import keeps this function independent of the file's import list.
    import shutil

    # Make sure the destination exists even when called on its own.
    os.makedirs("inference", exist_ok=True)
    shutil.copy(fname, osp.join("inference", osp.basename(fname)))
if __name__ == '__main__':
    # Demo flow: pick one query image, gather the gallery images that share
    # its person id, embed both and copy the query plus its nearest gallery
    # image into ./inference/.
    print("load dataset")
    _, gallery_loader, img, fname, pid = get_raw_data()

    print("find a img in gallery with id %d" % pid)
    imgs_gallery, _, _, f_names = find_imgs_with_id(pid.item(), gallery_loader)

    print("build model")
    model = build_model()

    print("extract img feature...")
    img_feature = feature_extraction_single(model, img)
    print("extract gallery feature...")
    gallery_feature = feature_extraction(model, imgs_gallery)

    dist_metric = "euclidean"
    print(
        'Computing distance matrix with metric={} ...'.format(dist_metric)
    )
    distmat = metrics.compute_distance_matrix(img_feature, gallery_feature, dist_metric)
    distmat = distmat.cpu().detach().numpy()

    # There is a single query row; only its closest gallery entry is needed,
    # so a full sort of the distances is unnecessary.
    index = int(np.argmin(distmat[0]))
    fname_gallery = f_names[index]

    save_image(fname)
    save_image(fname_gallery)
    print("query img saved to ./inference/%s" % osp.basename(fname))
    print("gallery img saved to ./inference/%s" % osp.basename(fname_gallery))