import torch
import numpy as np
from reid import datasets
from reid import models
import os.path as osp
from reid.utils.data import transforms as T
from reid.feature_extraction import extract_cnn_feature
from reid.evaluators import extract_features
from reid.utils.serialization import load_checkpoint, save_checkpoint
import os
from torch.utils.data import DataLoader
from reid.utils.data.preprocessor import Preprocessor
from reid.utils import to_numpy
import argparse
# Start every run with an "inference" output directory that contains no
# results left over from a previous run.
# exist_ok=True avoids the check-then-create race of exists() + makedirs().
os.makedirs("inference", exist_ok=True)
for _stale_name in os.listdir("inference"):
    # Same selection as the shell glob `inference/*`: dot-files are skipped.
    if _stale_name.startswith("."):
        continue
    _stale_path = os.path.join("inference", _stale_name)
    # Like `rm -f` without `-r`: only files and symlinks are removed,
    # sub-directories are left in place.
    if os.path.isfile(_stale_path) or os.path.islink(_stale_path):
        os.remove(_stale_path)

# Command-line options:
#   -d/--data_path  root directory of the dataset
#   --device        'npu' or 'gpu'; any other value leaves the model where
#                   models.create() put it (see build_model)
#   --checkpoint    checkpoint file passed to load_checkpoint()
parser = argparse.ArgumentParser(description="Softmax loss classification")
parser.add_argument('-d', '--data_path', type=str, default='../data/Market-1501')
parser.add_argument('--device', type=str, default='npu')
parser.add_argument('--checkpoint', type=str, default='logs/market-1501/PCB/checkpoint.pth.tar')
args = parser.parse_args()

# The device choice is published through the process environment;
# build_model() reads it back from os.environ['device'].
os.environ['device'] = args.device
def extract_features_single_img(model, img):
    """Return whatever ``extract_cnn_feature`` produces for ``img`` using ``model``.

    This is a direct pass-through: both arguments are forwarded unchanged and
    the result is returned unchanged.
    """
    return extract_cnn_feature(model, img)
def build_model():
    """Create the network, load its weights and put it in eval mode.

    The architecture is built through ``models.create`` with the fixed
    settings below, and the weights come from the ``'state_dict'`` entry of
    the checkpoint at ``args.checkpoint``.  The target device is taken from
    ``os.environ['device']``: ``'npu'`` moves the model to ``"npu:0"`` and
    ``'gpu'`` to ``"cuda:0"``; any other value leaves the model where it was
    created.

    Returns:
        The loaded model, in evaluation mode.
    """
    net = models.create(
        "resnet50",
        num_features=256,
        dropout=0.5,
        num_classes=751,
        cut_at_pooling=False,
        FCN=True,
    )
    weights = load_checkpoint(args.checkpoint)['state_dict']
    net.load_state_dict(weights)

    # Map the user-facing device name to the torch device string.
    device_strings = {'npu': "npu:0", 'gpu': "cuda:0"}
    requested = os.environ['device']
    if requested in device_strings:
        net = net.to(device_strings[requested])

    net.eval()
    return net
def get_raw_data(query_index=44):
    """Load the dataset and one query image from it.

    Args:
        query_index: Position in ``dataset.query`` of the image to load.
            Defaults to 44, the index this script has always used.

    Returns:
        A tuple ``(img, dataset, fpath)`` where ``img`` is the query image
        converted to RGB, ``dataset`` is the object returned by
        ``datasets.create("market", args.data_path)`` and ``fpath`` is the
        path the image was read from (``<data_path>/query/<fname>``).
    """
    from PIL import Image

    name, root = "market", args.data_path
    dataset = datasets.create(name, root)
    # Only the file name is needed here; the other two fields are unused.
    fname, _pid, _camid = dataset.query[query_index]
    fpath = osp.join(root, "query", fname)
    # convert() loads the pixel data and returns a new image, so the file
    # handle opened by Image.open() can be released as soon as it returns.
    with Image.open(fpath) as raw:
        img = raw.convert('RGB')
    return img, dataset, fpath
def pairwise_img_gallery_distance(img_feature, gallery_features, gallery):
    """Squared Euclidean distance from one query feature to each gallery feature.

    Args:
        img_feature: Feature tensor of the query image.  It is treated as a
            single sample and flattened to one row.
        gallery_features: Mapping from gallery file name to feature tensor.
        gallery: Iterable of 3-tuples whose first element is the file name
            used to look up ``gallery_features``.  The columns of the result
            follow the order of this iterable.

    Returns:
        A tensor of shape ``(1, n)`` where ``n`` is the number of gallery
        entries; entry ``[0, j]`` is ``||x - y_j||^2``.
    """
    x = img_feature.unsqueeze(0)
    y = torch.cat([gallery_features[f].unsqueeze(0) for f, _, _ in gallery], 0)
    m, n = x.size(0), y.size(0)
    # reshape (rather than view) also works for non-contiguous features.
    x = x.reshape(m, -1)
    y = y.reshape(n, -1)
    # ||x - y||^2 = ||x||^2 + ||y||^2 - 2 * x.y
    dist = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) + \
        torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
    # Keyword beta/alpha: the positional form addmm_(1, -2, x, y.t()) is a
    # deprecated overload that current PyTorch releases reject.
    dist.addmm_(x, y.t(), beta=1, alpha=-2)
    return dist
def get_feature(model, img, dataset):
    """Extract features for one query image and for the whole gallery.

    The same preprocessing pipeline (``RectScale(384, 128)``, ``ToTensor``,
    then mean/std normalisation) is applied to the query image and to every
    gallery image.

    Args:
        model: Network passed on to the feature-extraction helpers.
        img: The query image, as accepted by the transform pipeline.
        dataset: Object providing ``gallery``, ``images_dir`` and
            ``gallery_path``.

    Returns:
        A tuple ``(img_feature, gallery_features)``: the output of
        ``extract_features_single_img`` for the query, and the first value
        returned by ``extract_features`` for the gallery loader.
    """
    channel_norm = T.Normalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225])
    test_transformer = T.Compose([T.RectScale(384, 128), T.ToTensor(), channel_norm])

    # Query side: transform, then add a leading axis of size 1 so the single
    # image is presented as a batch.
    query_batch = np.expand_dims(test_transformer(img), axis=0)
    print("extract single img feature...")
    img_feature = extract_features_single_img(model, query_batch)

    # Gallery side: iterate the gallery in its stored order (no shuffling).
    print("extract gallery features...")
    gallery_root = osp.join(dataset.images_dir, dataset.gallery_path)
    gallery_set = Preprocessor(dataset.gallery, root=gallery_root,
                               transform=test_transformer)
    gallery_loader = DataLoader(gallery_set, batch_size=64, num_workers=8,
                                shuffle=False)
    gallery_features, _ = extract_features(model, gallery_loader)

    return img_feature, gallery_features
if __name__ == '__main__':
    # Script entry point: find the gallery image closest to one query image
    # and copy both images into ./inference/.

    # Imported here because only this entry point needs it.
    import shutil

    data_path = args.data_path
    img, dataset, img_fpath = get_raw_data()
    model = build_model()
    img_feature, gallery_features = get_feature(model, img, dataset)

    print("pairwise distance")
    distmat = pairwise_img_gallery_distance(img_feature, gallery_features, dataset.gallery)
    distmat = to_numpy(distmat)
    # Row 0 is the single query; its first sorted index is the nearest
    # gallery entry.
    indices = np.argsort(distmat, axis=1)
    o_name, _o_id, _o_cid = dataset.gallery[indices[0][0]]

    img_name = osp.basename(img_fpath)
    img_fpath = osp.join(data_path, "query", img_name)
    o_path = osp.join(data_path, "bounding_box_test", o_name)

    # shutil.copy instead of a shell `cp` string: paths containing spaces or
    # shell metacharacters are handled correctly, and a failed copy raises
    # instead of being followed by a misleading "saved" message.
    shutil.copy(img_fpath, osp.join("inference", img_name))
    print("input img (query) saved to ./inference/%s" % img_name)
    shutil.copy(o_path, osp.join("inference", o_name))
    print("predict img (gallery) saved to ./inference/%s" % o_name)