import sys
sys.path.append("./centroids-reid")
import argparse
from pathlib import Path
import glob
import os.path as osp
import re
import numpy as np
import torch
from datasets import init_dataset
from train_ctl_model import CTLModel
def get_imagedata_info(data):
pids, cams = [], []
for _, pid, camid, *_ in data:
pids += [pid]
cams += [camid]
pids = set(pids)
cams = set(cams)
num_pids = len(pids)
num_cams = len(cams)
num_imgs = len(data)
return num_pids, num_imgs, num_cams
def process_dir(dir_path, relabel=False):
img_paths = glob.glob(osp.join(dir_path, '*.txt'))
pattern = re.compile(r'([-\d]+)_c(\d)')
pid_container = set()
for img_path in img_paths:
pid, _ = map(int, pattern.search(img_path).groups())
pid_container.add(pid)
pid2label = {pid: label for label, pid in enumerate(pid_container)}
dataset = []
for idx, img_path in enumerate(img_paths):
pid, camid = map(int, pattern.search(img_path).groups())
assert 1 <= camid <= 8
camid -= 1
if relabel: pid = pid2label[pid]
dataset.append((img_path, pid, camid, idx))
return dataset
def prepare_embedding(path):
data = np.loadtxt(path, dtype=np.float32)
matrix = data.reshape([1,2048])
output = torch.from_numpy(matrix)
return output
def prepare_tensor(item):
b=np.array(item)
output = np.expand_dims(b, 0)
output=torch.from_numpy(output)
return output
def run_postprocess(self,args):
query_dir = osp.join(args.dataset_dir, args.query_path)
gallery_dir = osp.join(args.dataset_dir, args.gallery_path)
query = process_dir(query_dir, relabel=False)
gallery = process_dir(gallery_dir, relabel=False)
outputs = []
for x in enumerate(query+gallery):
emb, class_labels, camid, idx=x[1]
emb=prepare_embedding(emb)
class_labels=prepare_tensor(class_labels)
camid=prepare_tensor(camid)
idx=prepare_tensor(idx)
outputs.append(
{"emb": emb, "labels": class_labels, "camid": camid, "idx": idx}
)
embeddings = torch.cat([x["emb"] for x in outputs]).detach().cpu()
labels = torch.cat([x["labels"] for x in outputs]).detach().cpu().numpy()
camids = torch.cat([x["camid"] for x in outputs]).cpu().detach().numpy()
idx = torch.cat([x["idx"] for x in outputs]).cpu().detach().numpy()
print(embeddings.shape,labels.shape,camids.shape)
embeddings, labels, camids = self.validation_create_centroids(
embeddings, labels, camids,respect_camids=self.hparams.MODEL.KEEP_CAMID_CENTROIDS,
)
self.get_val_metrics(embeddings, labels, camids)
del embeddings, labels, camids
def main(args):
model = CTLModel
model = model.load_from_checkpoint(
'./models/dukemtmcreid_resnet50_256_128_epoch_120.ckpt',
)
run_postprocess(model,args)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="CLT ACL")
parser.add_argument("--dataset_dir", default="/home/wxq/DukeMTMC-reID/result/", help="embedding root path", type=str)
parser.add_argument("--query_path", default="", help="query/2022528_15_44_24_128520", type=str)
parser.add_argument("--gallery_path", default="", help="gallery/2022528_15_43_9_215073", type=str)
args = parser.parse_args()
main(args)