import warnings
import matplotlib.pyplot as plt
import mmcv
import numpy as np
import pycocotools.mask as maskUtils
import torch
from mmcv.parallel import collate, scatter
from mmcv.runner import load_checkpoint
from mmdet.core import get_classes
from mmdet.datasets.pipelines import Compose
from mmdet.models import build_detector
import cv2
from scipy import ndimage
def init_detector(config, checkpoint=None, device='cuda:0'):
"""Initialize a detector from config file.
Args:
config (str or :obj:`mmcv.Config`): Config file path or the config
object.
checkpoint (str, optional): Checkpoint path. If left as None, the model
will not load any weights.
Returns:
nn.Module: The constructed detector.
"""
if isinstance(config, str):
config = mmcv.Config.fromfile(config)
elif not isinstance(config, mmcv.Config):
raise TypeError('config must be a filename or Config object, '
'but got {}'.format(type(config)))
config.model.pretrained = None
model = build_detector(config.model, test_cfg=config.test_cfg)
if checkpoint is not None:
checkpoint = load_checkpoint(model, checkpoint)
if 'CLASSES' in checkpoint['meta']:
model.CLASSES = checkpoint['meta']['CLASSES']
else:
warnings.warn('Class names are not saved in the checkpoint\'s '
'meta data, use COCO classes by default.')
model.CLASSES = get_classes('coco')
model.cfg = config
model.to(device)
model.eval()
return model
class LoadImage(object):
def __call__(self, results):
if isinstance(results['img'], str):
results['filename'] = results['img']
else:
results['filename'] = None
img = mmcv.imread(results['img'])
results['img'] = img
results['img_shape'] = img.shape
results['ori_shape'] = img.shape
return results
def inference_detector(model, img):
"""Inference image(s) with the detector.
Args:
model (nn.Module): The loaded detector.
imgs (str/ndarray or list[str/ndarray]): Either image files or loaded
images.
Returns:
If imgs is a str, a generator will be returned, otherwise return the
detection results directly.
"""
cfg = model.cfg
device = next(model.parameters()).device
test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:]
test_pipeline = Compose(test_pipeline)
data = dict(img=img)
data = test_pipeline(data)
data = scatter(collate([data], samples_per_gpu=1), [device])[0]
with torch.no_grad():
result = model(return_loss=False, rescale=True, **data)
return result
async def async_inference_detector(model, img):
"""Async inference image(s) with the detector.
Args:
model (nn.Module): The loaded detector.
imgs (str/ndarray or list[str/ndarray]): Either image files or loaded
images.
Returns:
Awaitable detection results.
"""
cfg = model.cfg
device = next(model.parameters()).device
test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:]
test_pipeline = Compose(test_pipeline)
data = dict(img=img)
data = test_pipeline(data)
data = scatter(collate([data], samples_per_gpu=1), [device])[0]
torch.set_grad_enabled(False)
result = await model.aforward_test(rescale=True, **data)
return result
def show_result(img,
result,
class_names,
score_thr=0.3,
wait_time=0,
show=True,
out_file=None):
"""Visualize the detection results on the image.
Args:
img (str or np.ndarray): Image filename or loaded image.
result (tuple[list] or list): The detection result, can be either
(bbox, segm) or just bbox.
class_names (list[str] or tuple[str]): A list of class names.
score_thr (float): The threshold to visualize the bboxes and masks.
wait_time (int): Value of waitKey param.
show (bool, optional): Whether to show the image with opencv or not.
out_file (str, optional): If specified, the visualization result will
be written to the out file instead of shown in a window.
Returns:
np.ndarray or None: If neither `show` nor `out_file` is specified, the
visualized image is returned, otherwise None is returned.
"""
assert isinstance(class_names, (tuple, list))
img = mmcv.imread(img)
img = img.copy()
if isinstance(result, tuple):
bbox_result, segm_result = result
else:
bbox_result, segm_result = result, None
bboxes = np.vstack(bbox_result)
labels = [
np.full(bbox.shape[0], i, dtype=np.int32)
for i, bbox in enumerate(bbox_result)
]
labels = np.concatenate(labels)
if segm_result is not None:
segms = mmcv.concat_list(segm_result)
inds = np.where(bboxes[:, -1] > score_thr)[0]
np.random.seed(42)
color_masks = [
np.random.randint(0, 256, (1, 3), dtype=np.uint8)
for _ in range(max(labels) + 1)
]
for i in inds:
i = int(i)
color_mask = color_masks[labels[i]]
mask = maskUtils.decode(segms[i]).astype(np.bool)
img[mask] = img[mask] * 0.5 + color_mask * 0.5
mmcv.imshow_det_bboxes(
img,
bboxes,
labels,
class_names=class_names,
score_thr=score_thr,
show=show,
wait_time=wait_time,
out_file=out_file)
if not (show or out_file):
return img
def show_result_pyplot(img,
result,
class_names,
score_thr=0.3,
fig_size=(15, 10)):
"""Visualize the detection results on the image.
Args:
img (str or np.ndarray): Image filename or loaded image.
result (tuple[list] or list): The detection result, can be either
(bbox, segm) or just bbox.
class_names (list[str] or tuple[str]): A list of class names.
score_thr (float): The threshold to visualize the bboxes and masks.
fig_size (tuple): Figure size of the pyplot figure.
out_file (str, optional): If specified, the visualization result will
be written to the out file instead of shown in a window.
"""
img = show_result(
img, result, class_names, score_thr=score_thr, show=False)
plt.figure(figsize=fig_size)
plt.imshow(mmcv.bgr2rgb(img))
def show_result_ins(img,
result,
class_names,
score_thr=0.3,
sort_by_density=False,
out_file=None):
"""Visualize the instance segmentation results on the image.
Args:
img (str or np.ndarray): Image filename or loaded image.
result (tuple[list] or list): The instance segmentation result.
class_names (list[str] or tuple[str]): A list of class names.
score_thr (float): The threshold to visualize the masks.
sort_by_density (bool): sort the masks by their density.
out_file (str, optional): If specified, the visualization result will
be written to the out file instead of shown in a window.
Returns:
np.ndarray or None: If neither `show` nor `out_file` is specified, the
visualized image is returned, otherwise None is returned.
"""
assert isinstance(class_names, (tuple, list))
img = mmcv.imread(img)
img_show = img.copy()
h, w, _ = img.shape
if not result or result == [None]:
return img_show
cur_result = result[0]
seg_label = cur_result[0]
seg_label = seg_label.cpu().numpy().astype(np.uint8)
cate_label = cur_result[1]
cate_label = cate_label.cpu().numpy()
score = cur_result[2].cpu().numpy()
vis_inds = score > score_thr
seg_label = seg_label[vis_inds]
num_mask = seg_label.shape[0]
cate_label = cate_label[vis_inds]
cate_score = score[vis_inds]
if sort_by_density:
mask_density = []
for idx in range(num_mask):
cur_mask = seg_label[idx, :, :]
cur_mask = mmcv.imresize(cur_mask, (w, h))
cur_mask = (cur_mask > 0.5).astype(np.int32)
mask_density.append(cur_mask.sum())
orders = np.argsort(mask_density)
seg_label = seg_label[orders]
cate_label = cate_label[orders]
cate_score = cate_score[orders]
np.random.seed(42)
color_masks = [
np.random.randint(0, 256, (1, 3), dtype=np.uint8)
for _ in range(num_mask)
]
for idx in range(num_mask):
idx = -(idx+1)
cur_mask = seg_label[idx, :, :]
cur_mask = mmcv.imresize(cur_mask, (w, h))
cur_mask = (cur_mask > 0.5).astype(np.uint8)
if cur_mask.sum() == 0:
continue
color_mask = color_masks[idx]
cur_mask_bool = cur_mask.astype(np.bool)
img_show[cur_mask_bool] = img[cur_mask_bool] * 0.5 + color_mask * 0.5
cur_cate = cate_label[idx]
cur_score = cate_score[idx]
label_text = class_names[cur_cate]
center_y, center_x = ndimage.measurements.center_of_mass(cur_mask)
vis_pos = (max(int(center_x) - 10, 0), int(center_y))
cv2.putText(img_show, label_text, vis_pos,
cv2.FONT_HERSHEY_COMPLEX, 0.3, (255, 255, 255))
if out_file is None:
return img_show
else:
mmcv.imwrite(img_show, out_file)