@@ -1,6 +1,17 @@
import torch
+def triu_(x, diagonal=0):
+ t = x.shape[0]
+ base = torch.arange(t, device=x.device)
+ mask = base.expand(t, t)
+ base = base.unsqueeze(-1)
+ if diagonal:
+ base = base + diagonal
+ mask = mask >= base
+ return mask * x
+
+
def matrix_nms(seg_masks, cate_labels, cate_scores, kernel='gaussian', sigma=2.0, sum_masks=None):
"""Matrix NMS for multi-class masks.
@@ -26,10 +37,12 @@ def matrix_nms(seg_masks, cate_labels, cate_scores, kernel='gaussian', sigma=2.0
# union.
sum_masks_x = sum_masks.expand(n_samples, n_samples)
# iou.
- iou_matrix = (inter_matrix / (sum_masks_x + sum_masks_x.transpose(1, 0) - inter_matrix)).triu(diagonal=1)
+ iou_matrix = inter_matrix / (sum_masks_x + sum_masks_x.transpose(1, 0) - inter_matrix)
+ iou_matrix = triu_(iou_matrix, diagonal=1)
# label_specific matrix.
cate_labels_x = cate_labels.expand(n_samples, n_samples)
- label_matrix = (cate_labels_x == cate_labels_x.transpose(1, 0)).float().triu(diagonal=1)
+ label_matrix = (cate_labels_x == cate_labels_x.transpose(1, 0)).float()
+ label_matrix = triu_(label_matrix, diagonal=1)
# IoU compensation
compensate_iou, _ = (iou_matrix * label_matrix).max(0)
@@ -22,12 +22,14 @@ def center_of_mass(bitmasks):
center_y = m01 / m00
return center_x, center_y
+
def points_nms(heat, kernel=2):
# kernel must be 2
hmax = nn.functional.max_pool2d(
heat, (kernel, kernel), stride=1, padding=1)
- keep = (hmax[:, :, :-1, :-1] == heat).float()
- return heat * keep
+ keep = torch.abs(hmax[:, :, :-1, :-1]-heat) < 1e-3
+ return keep.int()
+
def dice_loss(input, target):
input = input.contiguous().view(input.size()[0], -1)
@@ -145,8 +147,13 @@ class SOLOHead(nn.Module):
cate_feat = x
# ins branch
# concat coord
- x_range = torch.linspace(-1, 1, ins_feat.shape[-1], device=ins_feat.device)
- y_range = torch.linspace(-1, 1, ins_feat.shape[-2], device=ins_feat.device)
+ feat_h, feat_w = ins_feat.shape[-2], ins_feat.shape[-1]
+ feat_h, feat_w = int(feat_h.cpu().numpy() if isinstance(feat_h, torch.Tensor) else feat_h), \
+ int(feat_w.cpu().numpy() if isinstance(feat_w, torch.Tensor) else feat_w)
+ step_x = 2. / (feat_w - 1)
+ step_y = 2. / (feat_h - 1)
+ x_range = torch.arange(-1, 1.00147, step_x, device=ins_feat.device)
+ y_range = torch.arange(-1, 1.00147, step_y, device=ins_feat.device)
y, x = torch.meshgrid(y_range, x_range)
y = y.expand([ins_feat.shape[0], 1, -1, -1])
x = x.expand([ins_feat.shape[0], 1, -1, -1])
@@ -169,7 +176,9 @@ class SOLOHead(nn.Module):
cate_pred = self.solo_cate(cate_feat)
if eval:
ins_pred = F.interpolate(ins_pred.sigmoid(), size=upsampled_size, mode='bilinear')
- cate_pred = points_nms(cate_pred.sigmoid(), kernel=2).permute(0, 2, 3, 1)
+ cate_mark = points_nms(cate_pred, kernel=2)
+ cate_pred = cate_pred.sigmoid()
+ cate_pred = (cate_pred * cate_mark).permute(0, 2, 3, 1)
return ins_pred, cate_pred
def loss(self,
@@ -313,28 +322,25 @@ class SOLOHead(nn.Module):
ins_ind_label_list.append(ins_ind_label)
return ins_label_list, cate_label_list, ins_ind_label_list
- def get_seg(self, seg_preds, cate_preds, img_metas, cfg, rescale=None):
+ def get_seg(self, seg_preds, cate_preds, cfg, rescale=None):
assert len(seg_preds) == len(cate_preds)
num_levels = len(cate_preds)
featmap_size = seg_preds[0].size()[-2:]
+ img_num = 1
result_list = []
- for img_id in range(len(img_metas)):
+ for img_id in range(img_num):
cate_pred_list = [
cate_preds[i][img_id].view(-1, self.cate_out_channels).detach() for i in range(num_levels)
]
seg_pred_list = [
seg_preds[i][img_id].detach() for i in range(num_levels)
]
- img_shape = img_metas[img_id]['img_shape']
- scale_factor = img_metas[img_id]['scale_factor']
- ori_shape = img_metas[img_id]['ori_shape']
cate_pred_list = torch.cat(cate_pred_list, dim=0)
seg_pred_list = torch.cat(seg_pred_list, dim=0)
- result = self.get_seg_single(cate_pred_list, seg_pred_list,
- featmap_size, img_shape, ori_shape, scale_factor, cfg, rescale)
+ result = self.get_seg_single(cate_pred_list, seg_pred_list, featmap_size, cfg, rescale)
result_list.append(result)
return result_list
@@ -342,26 +348,16 @@ class SOLOHead(nn.Module):
cate_preds,
seg_preds,
featmap_size,
- img_shape,
- ori_shape,
- scale_factor,
cfg,
rescale=False, debug=False):
assert len(cate_preds) == len(seg_preds)
- # overall info.
- h, w, _ = img_shape
- upsampled_size_out = (featmap_size[0] * 4, featmap_size[1] * 4)
-
# process.
- inds = (cate_preds > cfg.score_thr)
- # category scores.
- cate_scores = cate_preds[inds]
- if len(cate_scores) == 0:
- return None
- # category labels.
- inds = inds.nonzero()
- cate_labels = inds[:, 1]
+ cate_scores, cate_preds = torch.max(cate_preds, dim=-1)
+ cate_scores, inds = torch.topk(cate_scores, k=200)
+
+ cate_labels = cate_preds[inds]
+ cate_labels = cate_labels.int()
# strides.
size_trans = cate_labels.new_tensor(self.seg_num_grids).pow(2).cumsum(0)
@@ -370,30 +366,30 @@ class SOLOHead(nn.Module):
strides[:size_trans[0]] *= self.strides[0]
for ind_ in range(1, n_stage):
strides[size_trans[ind_ - 1]:size_trans[ind_]] *= self.strides[ind_]
- strides = strides[inds[:, 0]]
+ strides = strides[inds]
# masks.
- seg_preds = seg_preds[inds[:, 0]]
+ seg_preds = seg_preds[inds]
seg_masks = seg_preds > cfg.mask_thr
- sum_masks = seg_masks.sum((1, 2)).float()
+ sum_masks = seg_masks.int().sum((1, 2)).float()
# filter.
keep = sum_masks > strides
- if keep.sum() == 0:
- return None
-
- seg_masks = seg_masks[keep, ...]
- seg_preds = seg_preds[keep, ...]
- sum_masks = sum_masks[keep]
- cate_scores = cate_scores[keep]
- cate_labels = cate_labels[keep]
+ keep_int = keep.int()
+ keep_mask = keep_int.reshape(-1, 1, 1)
+ keep_mask = keep_mask.expand(-1, seg_masks.shape[1], seg_masks.shape[2]).int()
+ seg_masks = torch.mul(seg_masks, keep_mask)
+ seg_preds = torch.mul(seg_preds, keep_mask)
+ cate_scores = torch.mul(cate_scores, keep_int)
+ sum_masks = torch.mul(sum_masks, keep_int)
+ sum_masks += 0.1
# maskness.
seg_scores = (seg_preds * seg_masks.float()).sum((1, 2)) / sum_masks
cate_scores *= seg_scores
# sort and keep top nms_pre
- sort_inds = torch.argsort(cate_scores, descending=True)
+ _, sort_inds = torch.sort(cate_scores, descending=True)
if len(sort_inds) > cfg.nms_pre:
sort_inds = sort_inds[:cfg.nms_pre]
seg_masks = seg_masks[sort_inds, :, :]
@@ -406,27 +402,12 @@ class SOLOHead(nn.Module):
cate_scores = matrix_nms(seg_masks, cate_labels, cate_scores,
kernel=cfg.kernel, sigma=cfg.sigma, sum_masks=sum_masks)
- # filter.
- keep = cate_scores >= cfg.update_thr
- if keep.sum() == 0:
- return None
- seg_preds = seg_preds[keep, :, :]
- cate_scores = cate_scores[keep]
- cate_labels = cate_labels[keep]
-
# sort and keep top_k
- sort_inds = torch.argsort(cate_scores, descending=True)
+ _, sort_inds = torch.sort(cate_scores, descending=True)
if len(sort_inds) > cfg.max_per_img:
sort_inds = sort_inds[:cfg.max_per_img]
seg_preds = seg_preds[sort_inds, :, :]
cate_scores = cate_scores[sort_inds]
cate_labels = cate_labels[sort_inds]
- seg_preds = F.interpolate(seg_preds.unsqueeze(0),
- size=upsampled_size_out,
- mode='bilinear')[:, :, :h, :w]
- seg_masks = F.interpolate(seg_preds,
- size=ori_shape[:2],
- mode='bilinear').squeeze(0)
- seg_masks = seg_masks > cfg.mask_thr
- return seg_masks, cate_labels, cate_scores
+ return seg_preds, cate_labels, cate_scores
@@ -124,7 +124,7 @@ class BaseDetector(nn.Module, metaclass=ABCMeta):
assert imgs_per_gpu == 1
if num_augs == 1:
- return self.simple_test(imgs[0], img_metas[0], **kwargs)
+ return self.simple_test(imgs[0], **kwargs)
else:
return self.aug_test(imgs, img_metas, **kwargs)
@@ -78,7 +78,7 @@ class SingleStageInsDetector(BaseDetector):
*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
return losses
- def simple_test(self, img, img_meta, rescale=False):
+ def simple_test(self, img, img_meta=None, rescale=False):
x = self.extract_feat(img)
outs = self.bbox_head(x, eval=True)
@@ -88,7 +88,7 @@ class SingleStageInsDetector(BaseDetector):
start_level:self.mask_feat_head.end_level + 1])
seg_inputs = outs + (mask_feat_pred, img_meta, self.test_cfg, rescale)
else:
- seg_inputs = outs + (img_meta, self.test_cfg, rescale)
+ seg_inputs = outs + (self.test_cfg, rescale)
seg_result = self.bbox_head.get_seg(*seg_inputs)
return seg_result