import numpy as np
from time import time
import os
import random
import time
import io
import json
import torch
from PIL import Image
from torchvision import transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from argparse import ArgumentParser
from utils import DefaultBoxes, Encoder, COCODetection
from utils import SSDTransformer
from ssd300 import SSD300
from train import load_checkpoint, dboxes300_coco
def parse_args():
parser = ArgumentParser(description="Visualize models predictions on image")
parser.add_argument('--images', '-i', nargs='*', type=str,
help='path to jpg image')
parser.add_argument('--model', '-m', type=str, default='iter_240000.pt',
help='path to trained model')
parser.add_argument('--threshold', '-t', type=float, default=0.10,
help='threshold for predictions probabilities')
parser.add_argument('--annotations', '-a', type=str,
default='/share/home/litaotao/yzc/coco/annotations/instances_val2017.json',
help='path to json with annotations')
return parser.parse_args()
def print_image(image, model, encoder, inv_map, name_map, category_id_to_color, threshold):
im = Image.open(image)
W, H = im.size
tmp = im.copy()
tmp = tmp.resize((300, 300))
img = transforms.ToTensor()(tmp)
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
img = normalize(img).unsqueeze(dim = 0)
with torch.no_grad():
ploc, plabel = model(img)
ploc, plabel = ploc.float(), plabel.float()
ret = []
for idx in range(ploc.shape[0]):
ploc_i = ploc[idx, :, :].unsqueeze(0)
plabel_i = plabel[idx, :, :].unsqueeze(0)
try:
result = encoder.decode_batch(ploc_i, plabel_i, 0.50, 200)[0]
except:
print("No object detected in image {}".format(image))
continue
htot, wtot = (H, W)
loc, label, prob = [r.cpu().numpy() for r in result]
for loc_, label_, prob_ in zip(loc, label, prob):
ret.append([0, loc_[0]*wtot, \
loc_[1]*htot,
(loc_[2] - loc_[0])*wtot,
(loc_[3] - loc_[1])*htot,
prob_,
inv_map[label_]])
ret = np.array(ret).astype(np.float32)
bboxes = []
for re in ret:
if re[5] > threshold:
bboxes.append(re)
print("Bounding boxes detected in image {}:".format(image))
print(bboxes)
img = transforms.ToTensor()(im)
img = img.permute(1, 2, 0)
H = img.shape[0]
W = img.shape[1]
fig,ax = plt.subplots(1)
ax.imshow(img)
used = set()
for bbox in bboxes:
if (bbox[6] in used):
rect = patches.Rectangle((bbox[1], bbox[2]), bbox[3], bbox[4],
edgecolor=category_id_to_color[bbox[6]],
linewidth=2, facecolor='none')
else:
rect = patches.Rectangle((bbox[1], bbox[2]), bbox[3], bbox[4],
label = name_map[bbox[6]],
edgecolor=category_id_to_color[bbox[6]],
linewidth=2, facecolor='none')
used.add(bbox[6])
ax.add_patch(rect)
plt.legend(ncol=1, bbox_to_anchor=(1.04,1), loc="upper left")
plt.show()
def main():
args = parse_args()
with open(args.annotations,'r') as anno:
js = json.loads(anno.read())
coco_names = js['categories']
name_map = {}
for name in coco_names:
name_map[name['id']] = name['name']
deleted = [12, 26, 29, 30, 45, 66, 68, 69, 71, 83]
inv_map = {}
cnt = 0
for i in range(1, 81):
while i + cnt in deleted:
cnt += 1
inv_map[i] = i + cnt
category_id_to_color = dict([(cat_id, [random.uniform(0, 1) ,random.uniform(0, 1), random.uniform(0, 1)]) for cat_id in range(1, 91)])
plt.rcParams["figure.figsize"] = (12, 8)
ssd300 = SSD300(81, backbone="resnet34", model_path=None, dilation=None)
load_checkpoint(ssd300, args.model)
ssd300.eval()
dboxes = dboxes300_coco()
encoder = Encoder(dboxes)
for image in args.images:
print_image(image, ssd300, encoder, inv_map, name_map, category_id_to_color, args.threshold)
if __name__ == "__main__":
main()