05360171创建于 2022年3月18日历史提交
diff --git a/mmdet/core/post_processing/matrix_nms.py b/mmdet/core/post_processing/matrix_nms.py
index cbbe420..764d9cb 100644
--- a/mmdet/core/post_processing/matrix_nms.py
+++ b/mmdet/core/post_processing/matrix_nms.py
@@ -1,6 +1,17 @@
 import torch
 
 
+def triu_(x, diagonal=0):
+    t = x.shape[0]
+    base = torch.arange(t, device=x.device)
+    mask = base.expand(t, t)
+    base = base.unsqueeze(-1)
+    if diagonal:
+        base = base + diagonal
+    mask = mask >= base
+    return mask * x
+
+
 def matrix_nms(seg_masks, cate_labels, cate_scores, kernel='gaussian', sigma=2.0, sum_masks=None):
     """Matrix NMS for multi-class masks.
 
@@ -26,10 +37,12 @@ def matrix_nms(seg_masks, cate_labels, cate_scores, kernel='gaussian', sigma=2.0
     # union.
     sum_masks_x = sum_masks.expand(n_samples, n_samples)
     # iou.
-    iou_matrix = (inter_matrix / (sum_masks_x + sum_masks_x.transpose(1, 0) - inter_matrix)).triu(diagonal=1)
+    iou_matrix = inter_matrix / (sum_masks_x + sum_masks_x.transpose(1, 0) - inter_matrix)
+    iou_matrix = triu_(iou_matrix, diagonal=1)
     # label_specific matrix.
     cate_labels_x = cate_labels.expand(n_samples, n_samples)
-    label_matrix = (cate_labels_x == cate_labels_x.transpose(1, 0)).float().triu(diagonal=1)
+    label_matrix = (cate_labels_x == cate_labels_x.transpose(1, 0)).float()
+    label_matrix = triu_(label_matrix, diagonal=1)
 
     # IoU compensation
     compensate_iou, _ = (iou_matrix * label_matrix).max(0)
diff --git a/mmdet/models/anchor_heads/solo_head.py b/mmdet/models/anchor_heads/solo_head.py
index 2ecd75f..3ac6b06 100644
--- a/mmdet/models/anchor_heads/solo_head.py
+++ b/mmdet/models/anchor_heads/solo_head.py
@@ -22,12 +22,14 @@ def center_of_mass(bitmasks):
     center_y = m01 / m00
     return center_x, center_y
 
+
 def points_nms(heat, kernel=2):
     # kernel must be 2
     hmax = nn.functional.max_pool2d(
         heat, (kernel, kernel), stride=1, padding=1)
-    keep = (hmax[:, :, :-1, :-1] == heat).float()
-    return heat * keep
+    keep = torch.abs(hmax[:, :, :-1, :-1]-heat) < 1e-3
+    return keep.int()
+
 
 def dice_loss(input, target):
     input = input.contiguous().view(input.size()[0], -1)
@@ -145,8 +147,13 @@ class SOLOHead(nn.Module):
         cate_feat = x
         # ins branch
         # concat coord
-        x_range = torch.linspace(-1, 1, ins_feat.shape[-1], device=ins_feat.device)
-        y_range = torch.linspace(-1, 1, ins_feat.shape[-2], device=ins_feat.device)
+        feat_h, feat_w = ins_feat.shape[-2], ins_feat.shape[-1]
+        feat_h, feat_w = int(feat_h.cpu().numpy() if isinstance(feat_h, torch.Tensor) else feat_h), \
+                         int(feat_w.cpu().numpy() if isinstance(feat_w, torch.Tensor) else feat_w)
+        step_x = 2. / (feat_w - 1)
+        step_y = 2. / (feat_h - 1)
+        x_range = torch.arange(-1, 1.00147, step_x, device=ins_feat.device)
+        y_range = torch.arange(-1, 1.00147, step_y, device=ins_feat.device)
         y, x = torch.meshgrid(y_range, x_range)
         y = y.expand([ins_feat.shape[0], 1, -1, -1])
         x = x.expand([ins_feat.shape[0], 1, -1, -1])
@@ -169,7 +176,9 @@ class SOLOHead(nn.Module):
         cate_pred = self.solo_cate(cate_feat)
         if eval:
             ins_pred = F.interpolate(ins_pred.sigmoid(), size=upsampled_size, mode='bilinear')
-            cate_pred = points_nms(cate_pred.sigmoid(), kernel=2).permute(0, 2, 3, 1)
+            cate_mark = points_nms(cate_pred, kernel=2)
+            cate_pred = cate_pred.sigmoid()
+            cate_pred = (cate_pred * cate_mark).permute(0, 2, 3, 1)
         return ins_pred, cate_pred
 
     def loss(self,
@@ -313,28 +322,25 @@ class SOLOHead(nn.Module):
             ins_ind_label_list.append(ins_ind_label)
         return ins_label_list, cate_label_list, ins_ind_label_list
 
-    def get_seg(self, seg_preds, cate_preds, img_metas, cfg, rescale=None):
+    def get_seg(self, seg_preds, cate_preds, cfg, rescale=None):
         assert len(seg_preds) == len(cate_preds)
         num_levels = len(cate_preds)
         featmap_size = seg_preds[0].size()[-2:]
+        img_num = 1
 
         result_list = []
-        for img_id in range(len(img_metas)):
+        for img_id in range(img_num):
             cate_pred_list = [
                 cate_preds[i][img_id].view(-1, self.cate_out_channels).detach() for i in range(num_levels)
             ]
             seg_pred_list = [
                 seg_preds[i][img_id].detach() for i in range(num_levels)
             ]
-            img_shape = img_metas[img_id]['img_shape']
-            scale_factor = img_metas[img_id]['scale_factor']
-            ori_shape = img_metas[img_id]['ori_shape']
 
             cate_pred_list = torch.cat(cate_pred_list, dim=0)
             seg_pred_list = torch.cat(seg_pred_list, dim=0)
 
-            result = self.get_seg_single(cate_pred_list, seg_pred_list,
-                                         featmap_size, img_shape, ori_shape, scale_factor, cfg, rescale)
+            result = self.get_seg_single(cate_pred_list, seg_pred_list, featmap_size, cfg, rescale)
             result_list.append(result)
         return result_list
 
@@ -342,26 +348,16 @@ class SOLOHead(nn.Module):
                        cate_preds,
                        seg_preds,
                        featmap_size,
-                       img_shape,
-                       ori_shape,
-                       scale_factor,
                        cfg,
                        rescale=False, debug=False):
         assert len(cate_preds) == len(seg_preds)
 
-        # overall info.
-        h, w, _ = img_shape
-        upsampled_size_out = (featmap_size[0] * 4, featmap_size[1] * 4)
-
         # process.
-        inds = (cate_preds > cfg.score_thr)
-        # category scores.
-        cate_scores = cate_preds[inds]
-        if len(cate_scores) == 0:
-            return None
-        # category labels.
-        inds = inds.nonzero()
-        cate_labels = inds[:, 1]
+        cate_scores, cate_preds = torch.max(cate_preds, dim=-1)
+        cate_scores, inds = torch.topk(cate_scores, k=200)
+
+        cate_labels = cate_preds[inds]
+        cate_labels = cate_labels.int()
 
         # strides.
         size_trans = cate_labels.new_tensor(self.seg_num_grids).pow(2).cumsum(0)
@@ -370,30 +366,30 @@ class SOLOHead(nn.Module):
         strides[:size_trans[0]] *= self.strides[0]
         for ind_ in range(1, n_stage):
             strides[size_trans[ind_ - 1]:size_trans[ind_]] *= self.strides[ind_]
-        strides = strides[inds[:, 0]]
+        strides = strides[inds]
 
         # masks.
-        seg_preds = seg_preds[inds[:, 0]]
+        seg_preds = seg_preds[inds]
         seg_masks = seg_preds > cfg.mask_thr
-        sum_masks = seg_masks.sum((1, 2)).float()
+        sum_masks = seg_masks.int().sum((1, 2)).float()
 
         # filter.
         keep = sum_masks > strides
-        if keep.sum() == 0:
-            return None
-
-        seg_masks = seg_masks[keep, ...]
-        seg_preds = seg_preds[keep, ...]
-        sum_masks = sum_masks[keep]
-        cate_scores = cate_scores[keep]
-        cate_labels = cate_labels[keep]
+        keep_int = keep.int()
+        keep_mask = keep_int.reshape(-1, 1, 1)
+        keep_mask = keep_mask.expand(-1, seg_masks.shape[1], seg_masks.shape[2]).int()
+        seg_masks = torch.mul(seg_masks, keep_mask)
+        seg_preds = torch.mul(seg_preds, keep_mask)
+        cate_scores = torch.mul(cate_scores, keep_int)
+        sum_masks = torch.mul(sum_masks, keep_int)
+        sum_masks += 0.1
 
         # maskness.
         seg_scores = (seg_preds * seg_masks.float()).sum((1, 2)) / sum_masks
         cate_scores *= seg_scores
 
         # sort and keep top nms_pre
-        sort_inds = torch.argsort(cate_scores, descending=True)
+        _, sort_inds = torch.sort(cate_scores, descending=True)
         if len(sort_inds) > cfg.nms_pre:
             sort_inds = sort_inds[:cfg.nms_pre]
         seg_masks = seg_masks[sort_inds, :, :]
@@ -406,27 +402,12 @@ class SOLOHead(nn.Module):
         cate_scores = matrix_nms(seg_masks, cate_labels, cate_scores,
                                  kernel=cfg.kernel, sigma=cfg.sigma, sum_masks=sum_masks)
 
-        # filter.
-        keep = cate_scores >= cfg.update_thr
-        if keep.sum() == 0:
-            return None
-        seg_preds = seg_preds[keep, :, :]
-        cate_scores = cate_scores[keep]
-        cate_labels = cate_labels[keep]
-
         # sort and keep top_k
-        sort_inds = torch.argsort(cate_scores, descending=True)
+        _, sort_inds = torch.sort(cate_scores, descending=True)
         if len(sort_inds) > cfg.max_per_img:
             sort_inds = sort_inds[:cfg.max_per_img]
         seg_preds = seg_preds[sort_inds, :, :]
         cate_scores = cate_scores[sort_inds]
         cate_labels = cate_labels[sort_inds]
 
-        seg_preds = F.interpolate(seg_preds.unsqueeze(0),
-                                  size=upsampled_size_out,
-                                  mode='bilinear')[:, :, :h, :w]
-        seg_masks = F.interpolate(seg_preds,
-                                  size=ori_shape[:2],
-                                  mode='bilinear').squeeze(0)
-        seg_masks = seg_masks > cfg.mask_thr
-        return seg_masks, cate_labels, cate_scores
+        return seg_preds, cate_labels, cate_scores
diff --git a/mmdet/models/detectors/base.py b/mmdet/models/detectors/base.py
index 82f91bd..4a93a27 100644
--- a/mmdet/models/detectors/base.py
+++ b/mmdet/models/detectors/base.py
@@ -124,7 +124,7 @@ class BaseDetector(nn.Module, metaclass=ABCMeta):
         assert imgs_per_gpu == 1
 
         if num_augs == 1:
-            return self.simple_test(imgs[0], img_metas[0], **kwargs)
+            return self.simple_test(imgs[0], **kwargs)
         else:
             return self.aug_test(imgs, img_metas, **kwargs)
 
diff --git a/mmdet/models/detectors/single_stage_ins.py b/mmdet/models/detectors/single_stage_ins.py
index 773d5d2..2d19f88 100644
--- a/mmdet/models/detectors/single_stage_ins.py
+++ b/mmdet/models/detectors/single_stage_ins.py
@@ -78,7 +78,7 @@ class SingleStageInsDetector(BaseDetector):
             *loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
         return losses
 
-    def simple_test(self, img, img_meta, rescale=False):
+    def simple_test(self, img, img_meta=None, rescale=False):
         x = self.extract_feat(img)
         outs = self.bbox_head(x, eval=True)
 
@@ -88,7 +88,7 @@ class SingleStageInsDetector(BaseDetector):
                   start_level:self.mask_feat_head.end_level + 1])
             seg_inputs = outs + (mask_feat_pred, img_meta, self.test_cfg, rescale)
         else:
-            seg_inputs = outs + (img_meta, self.test_cfg, rescale)
+            seg_inputs = outs + (self.test_cfg, rescale)
         seg_result = self.bbox_head.get_seg(*seg_inputs)
         return seg_result