05360171创建于 2022年3月18日历史提交
diff --git a/mmdet/apis/inference.py b/mmdet/apis/inference.py
index 9470b6e..8b23583 100644
--- a/mmdet/apis/inference.py
+++ b/mmdet/apis/inference.py
@@ -35,7 +35,7 @@ def init_detector(config, checkpoint=None, device='cuda:0'):
     config.model.pretrained = None
     model = build_detector(config.model, test_cfg=config.test_cfg)
     if checkpoint is not None:
-        checkpoint = load_checkpoint(model, checkpoint)
+        checkpoint = load_checkpoint(model, checkpoint, map_location=torch.device(device))
         if 'CLASSES' in checkpoint['meta']:
             model.CLASSES = checkpoint['meta']['CLASSES']
         else:
diff --git a/mmdet/apis/train.py b/mmdet/apis/train.py
index 97c0dc6..da46af0 100644
--- a/mmdet/apis/train.py
+++ b/mmdet/apis/train.py
@@ -12,7 +12,6 @@ from mmdet import datasets
 from mmdet.core import (CocoDistEvalmAPHook, CocoDistEvalRecallHook,
                         DistEvalmAPHook, DistOptimizerHook, Fp16OptimizerHook)
 from mmdet.datasets import DATASETS, build_dataloader
-from mmdet.models import RPN
 from mmdet.utils import get_root_logger
 
 
diff --git a/mmdet/core/post_processing/__init__.py b/mmdet/core/post_processing/__init__.py
index 73fb199..67849fe 100644
--- a/mmdet/core/post_processing/__init__.py
+++ b/mmdet/core/post_processing/__init__.py
@@ -1,9 +1,5 @@
-from .bbox_nms import multiclass_nms
 from .matrix_nms import matrix_nms
-from .merge_augs import (merge_aug_bboxes, merge_aug_masks,
-                         merge_aug_proposals, merge_aug_scores)
 
 __all__ = [
-    'multiclass_nms', 'merge_aug_proposals', 'merge_aug_bboxes',
-    'merge_aug_scores', 'merge_aug_masks', 'matrix_nms'
+    'matrix_nms'
 ]
diff --git a/mmdet/models/__init__.py b/mmdet/models/__init__.py
index 35f0a09..56f7896 100644
--- a/mmdet/models/__init__.py
+++ b/mmdet/models/__init__.py
@@ -1,6 +1,5 @@
 from .anchor_heads import *  # noqa: F401,F403
 from .backbones import *  # noqa: F401,F403
-from .bbox_heads import *  # noqa: F401,F403
 from .builder import (build_backbone, build_detector, build_head, build_loss,
                       build_neck, build_roi_extractor, build_shared_head)
 from .detectors import *  # noqa: F401,F403
@@ -9,8 +8,6 @@ from .mask_heads import *  # noqa: F401,F403
 from .necks import *  # noqa: F401,F403
 from .registry import (BACKBONES, DETECTORS, HEADS, LOSSES, NECKS,
                        ROI_EXTRACTORS, SHARED_HEADS)
-from .roi_extractors import *  # noqa: F401,F403
-from .shared_heads import *  # noqa: F401,F403
 
 __all__ = [
     'BACKBONES', 'NECKS', 'ROI_EXTRACTORS', 'SHARED_HEADS', 'HEADS', 'LOSSES',
diff --git a/mmdet/models/anchor_heads/__init__.py b/mmdet/models/anchor_heads/__init__.py
index de1d7ef..955440b 100644
--- a/mmdet/models/anchor_heads/__init__.py
+++ b/mmdet/models/anchor_heads/__init__.py
@@ -1,25 +1,7 @@
-from .anchor_head import AnchorHead
-from .atss_head import ATSSHead
-from .fcos_head import FCOSHead
-from .fovea_head import FoveaHead
-from .free_anchor_retina_head import FreeAnchorRetinaHead
-from .ga_retina_head import GARetinaHead
-from .ga_rpn_head import GARPNHead
-from .guided_anchor_head import FeatureAdaption, GuidedAnchorHead
-from .reppoints_head import RepPointsHead
-from .retina_head import RetinaHead
-from .retina_sepbn_head import RetinaSepBNHead
-from .rpn_head import RPNHead
-from .ssd_head import SSDHead
 from .solo_head import SOLOHead
 from .solov2_head import SOLOv2Head
-from .solov2_light_head import SOLOv2LightHead
-from .decoupled_solo_head import DecoupledSOLOHead
-from .decoupled_solo_light_head import DecoupledSOLOLightHead
+
 
 __all__ = [
-    'AnchorHead', 'GuidedAnchorHead', 'FeatureAdaption', 'RPNHead',
-    'GARPNHead', 'RetinaHead', 'RetinaSepBNHead', 'GARetinaHead', 'SSDHead',
-    'FCOSHead', 'RepPointsHead', 'FoveaHead', 'FreeAnchorRetinaHead',
-    'ATSSHead', 'SOLOHead', 'SOLOv2Head', 'SOLOv2LightHead', 'DecoupledSOLOHead', 'DecoupledSOLOLightHead'
+    'SOLOHead', 'SOLOv2Head',
 ]
diff --git a/mmdet/models/anchor_heads/solo_head.py b/mmdet/models/anchor_heads/solo_head.py
index e6c0607..2ecd75f 100644
--- a/mmdet/models/anchor_heads/solo_head.py
+++ b/mmdet/models/anchor_heads/solo_head.py
@@ -3,7 +3,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from mmcv.cnn import normal_init
-from mmdet.ops import DeformConv, roi_align
 from mmdet.core import multi_apply, bbox2roi, matrix_nms
 from ..builder import build_loss
 from ..registry import HEADS
diff --git a/mmdet/models/anchor_heads/solov2_head.py b/mmdet/models/anchor_heads/solov2_head.py
index 9616b99..2765eb2 100644
--- a/mmdet/models/anchor_heads/solov2_head.py
+++ b/mmdet/models/anchor_heads/solov2_head.py
@@ -3,7 +3,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from mmcv.cnn import normal_init
-from mmdet.ops import DeformConv, roi_align
 from mmdet.core import multi_apply, matrix_nms
 from ..builder import build_loss
 from ..registry import HEADS
diff --git a/mmdet/models/backbones/resnet.py b/mmdet/models/backbones/resnet.py
index ab6913e..2c66224 100644
--- a/mmdet/models/backbones/resnet.py
+++ b/mmdet/models/backbones/resnet.py
@@ -5,7 +5,6 @@ from mmcv.runner import load_checkpoint
 from torch.nn.modules.batchnorm import _BatchNorm
 
 from mmdet.models.plugins import GeneralizedAttention
-from mmdet.ops import ContextBlock
 from mmdet.utils import get_root_logger
 from ..registry import BACKBONES
 from ..utils import build_conv_layer, build_norm_layer
diff --git a/mmdet/models/detectors/__init__.py b/mmdet/models/detectors/__init__.py
index e7aad35..afb78de 100644
--- a/mmdet/models/detectors/__init__.py
+++ b/mmdet/models/detectors/__init__.py
@@ -1,27 +1,6 @@
-from .atss import ATSS
-from .base import BaseDetector
-from .cascade_rcnn import CascadeRCNN
-from .double_head_rcnn import DoubleHeadRCNN
-from .fast_rcnn import FastRCNN
-from .faster_rcnn import FasterRCNN
-from .fcos import FCOS
-from .fovea import FOVEA
-from .grid_rcnn import GridRCNN
-from .htc import HybridTaskCascade
-from .mask_rcnn import MaskRCNN
-from .mask_scoring_rcnn import MaskScoringRCNN
-from .reppoints_detector import RepPointsDetector
-from .retinanet import RetinaNet
-from .rpn import RPN
-from .single_stage import SingleStageDetector
-from .single_stage_ins import SingleStageInsDetector
-from .two_stage import TwoStageDetector
 from .solo import SOLO
 from .solov2 import SOLOv2
 
 __all__ = [
-    'ATSS', 'BaseDetector', 'SingleStageDetector', 'TwoStageDetector', 'RPN',
-    'FastRCNN', 'FasterRCNN', 'MaskRCNN', 'CascadeRCNN', 'HybridTaskCascade',
-    'DoubleHeadRCNN', 'RetinaNet', 'FCOS', 'GridRCNN', 'MaskScoringRCNN',
-    'RepPointsDetector', 'FOVEA', 'SingleStageInsDetector', 'SOLO', 'SOLOv2'
+    'SOLO', 'SOLOv2'
 ]
diff --git a/mmdet/models/detectors/test_mixins.py b/mmdet/models/detectors/test_mixins.py
index 84a96d1..31a4507 100644
--- a/mmdet/models/detectors/test_mixins.py
+++ b/mmdet/models/detectors/test_mixins.py
@@ -3,8 +3,7 @@ import sys
 
 import torch
 
-from mmdet.core import (bbox2roi, bbox_mapping, merge_aug_bboxes,
-                        merge_aug_masks, merge_aug_proposals, multiclass_nms)
+from mmdet.core import (bbox2roi, bbox_mapping)
 
 logger = logging.getLogger(__name__)
 
diff --git a/mmdet/models/losses/focal_loss.py b/mmdet/models/losses/focal_loss.py
index 6b28e12..f410479 100644
--- a/mmdet/models/losses/focal_loss.py
+++ b/mmdet/models/losses/focal_loss.py
@@ -1,8 +1,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from mmdet.ops import sigmoid_focal_loss as _sigmoid_focal_loss
-from ..registry import LOSSES
+from ..builder import LOSSES
 from .utils import weight_reduce_loss
 
 
diff --git a/mmdet/models/necks/__init__.py b/mmdet/models/necks/__init__.py
index fa57404..a2d7771 100644
--- a/mmdet/models/necks/__init__.py
+++ b/mmdet/models/necks/__init__.py
@@ -1,6 +1,5 @@
 from .bfp import BFP
 from .fpn import FPN
-from .hrfpn import HRFPN
 from .nas_fpn import NASFPN

-__all__ = ['FPN', 'BFP', 'HRFPN', 'NASFPN']
+__all__ = ['FPN', 'BFP', 'NASFPN']
diff --git a/mmdet/models/roi_extractors/single_level.py b/mmdet/models/roi_extractors/single_level.py
index 6620d1d..6901ecc 100644
--- a/mmdet/models/roi_extractors/single_level.py
+++ b/mmdet/models/roi_extractors/single_level.py
@@ -3,7 +3,6 @@ from __future__ import division
 import torch
 import torch.nn as nn
 
-from mmdet import ops
 from mmdet.core import force_fp32
 from ..registry import ROI_EXTRACTORS
 
diff --git a/mmdet/models/utils/conv_module.py b/mmdet/models/utils/conv_module.py
index 3be32c3..0675aae 100644
--- a/mmdet/models/utils/conv_module.py
+++ b/mmdet/models/utils/conv_module.py
@@ -3,15 +3,12 @@ import warnings
 import torch.nn as nn
 from mmcv.cnn import constant_init, kaiming_init
 
-from mmdet.ops import DeformConvPack, ModulatedDeformConvPack
 from .conv_ws import ConvWS2d
 from .norm import build_norm_layer
 
 conv_cfg = {
     'Conv': nn.Conv2d,
     'ConvWS': ConvWS2d,
-    'DCN': DeformConvPack,
-    'DCNv2': ModulatedDeformConvPack,
     # TODO: octave conv
 }
 
diff --git a/mmdet/ops/dcn/deform_conv.py b/mmdet/ops/dcn/deform_conv.py
index 5ba5a5e..c20d97d 100644
--- a/mmdet/ops/dcn/deform_conv.py
+++ b/mmdet/ops/dcn/deform_conv.py
@@ -1,122 +1,51 @@
-import math
+# Copyright (c) 2020, Huawei Technologies.All rights reserved.
+#
+# Licensed under the BSD 3-Clause License  (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
 import torch
 import torch.nn as nn
 from torch.autograd import Function
-from torch.autograd.function import once_differentiable
 from torch.nn.modules.utils import _pair, _single
+import math
 
-from mmdet.utils import print_log
-from . import deform_conv_cuda
-
-
-class DeformConvFunction(Function):
-
-    @staticmethod
-    def forward(ctx,
-                input,
-                offset,
-                weight,
-                stride=1,
-                padding=0,
-                dilation=1,
-                groups=1,
-                deformable_groups=1,
-                im2col_step=64):
-        if input is not None and input.dim() != 4:
-            raise ValueError(
-                'Expected 4D tensor as input, got {}D tensor instead.'.format(
-                    input.dim()))
-        ctx.stride = _pair(stride)
-        ctx.padding = _pair(padding)
-        ctx.dilation = _pair(dilation)
-        ctx.groups = groups
-        ctx.deformable_groups = deformable_groups
-        ctx.im2col_step = im2col_step
-
-        ctx.save_for_backward(input, offset, weight)
-
-        output = input.new_empty(
-            DeformConvFunction._output_size(input, weight, ctx.padding,
-                                            ctx.dilation, ctx.stride))
-
-        ctx.bufs_ = [input.new_empty(0), input.new_empty(0)]  # columns, ones
-
-        if not input.is_cuda:
-            raise NotImplementedError
-        else:
-            cur_im2col_step = min(ctx.im2col_step, input.shape[0])
-            assert (input.shape[0] %
-                    cur_im2col_step) == 0, 'im2col step must divide batchsize'
-            deform_conv_cuda.deform_conv_forward_cuda(
-                input, weight, offset, output, ctx.bufs_[0], ctx.bufs_[1],
-                weight.size(3), weight.size(2), ctx.stride[1], ctx.stride[0],
-                ctx.padding[1], ctx.padding[0], ctx.dilation[1],
-                ctx.dilation[0], ctx.groups, ctx.deformable_groups,
-                cur_im2col_step)
-        return output
-
-    @staticmethod
-    @once_differentiable
-    def backward(ctx, grad_output):
-        input, offset, weight = ctx.saved_tensors
-
-        grad_input = grad_offset = grad_weight = None
-
-        if not grad_output.is_cuda:
-            raise NotImplementedError
-        else:
-            cur_im2col_step = min(ctx.im2col_step, input.shape[0])
-            assert (input.shape[0] %
-                    cur_im2col_step) == 0, 'im2col step must divide batchsize'
-
-            if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
-                grad_input = torch.zeros_like(input)
-                grad_offset = torch.zeros_like(offset)
-                deform_conv_cuda.deform_conv_backward_input_cuda(
-                    input, offset, grad_output, grad_input,
-                    grad_offset, weight, ctx.bufs_[0], weight.size(3),
-                    weight.size(2), ctx.stride[1], ctx.stride[0],
-                    ctx.padding[1], ctx.padding[0], ctx.dilation[1],
-                    ctx.dilation[0], ctx.groups, ctx.deformable_groups,
-                    cur_im2col_step)
-
-            if ctx.needs_input_grad[2]:
-                grad_weight = torch.zeros_like(weight)
-                deform_conv_cuda.deform_conv_backward_parameters_cuda(
-                    input, offset, grad_output,
-                    grad_weight, ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
-                    weight.size(2), ctx.stride[1], ctx.stride[0],
-                    ctx.padding[1], ctx.padding[0], ctx.dilation[1],
-                    ctx.dilation[0], ctx.groups, ctx.deformable_groups, 1,
-                    cur_im2col_step)
 
-        return (grad_input, grad_offset, grad_weight, None, None, None, None,
-                None)
+class ModulatedDeformConv2dFunction(Function):
 
     @staticmethod
-    def _output_size(input, weight, padding, dilation, stride):
-        channels = weight.size(0)
-        output_size = (input.size(0), channels)
-        for d in range(input.dim() - 2):
-            in_size = input.size(d + 2)
-            pad = padding[d]
-            kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
-            stride_ = stride[d]
-            output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
-        if not all(map(lambda s: s > 0, output_size)):
-            raise ValueError(
-                'convolution input is too small (output would be {})'.format(
-                    'x'.join(map(str, output_size))))
-        return output_size
-
-
-class ModulatedDeformConvFunction(Function):
+    def symbolic(g, input, weight, offset, bias, stride, padding,
+                 dilation, groups, defomable_groups):
+        if isinstance(stride, int):
+            stride = (stride, stride)
+        if isinstance(padding, int):
+            padding = (padding, padding)
+        if isinstance(dilation, int):
+            dilation = (dilation, dilation)
+        return g.op(
+            'DeformableConv2D',
+            input,
+            weight,
+            offset,
+            bias_i=bias,
+            strides_i=stride,
+            pads_i=padding,
+            dilations_i=dilation,
+            groups_i=groups,
+            defomable_groups_i=defomable_groups)
 
     @staticmethod
     def forward(ctx,
                 input,
-                offset,
+                offset_ori,
                 mask,
                 weight,
                 bias=None,
@@ -124,7 +53,13 @@ class ModulatedDeformConvFunction(Function):
                 padding=0,
                 dilation=1,
                 groups=1,
-                deformable_groups=1):
+                deformable_groups=1,
+                ):
+
+        input = input.float()
+        offset_ori = offset_ori.float()
+        mask = mask.float()
+
         ctx.stride = stride
         ctx.padding = padding
         ctx.dilation = dilation
@@ -132,39 +67,43 @@ class ModulatedDeformConvFunction(Function):
         ctx.deformable_groups = deformable_groups
         ctx.with_bias = bias is not None
         if not ctx.with_bias:
-            bias = input.new_empty(1)  # fake tensor
-        if not input.is_cuda:
-            raise NotImplementedError
+            device = input.device
+            bias = torch.zeros(weight.shape[0], device=device)
+        ctx._bufs = [input.new_empty(0), input.new_empty(0)]
+
+        offset_x = offset_ori[:, ::2, :, :]
+        offset_y = offset_ori[:, 1::2, :, :]
+        offset = torch.cat([offset_y, offset_x], dim=1)
+        offset_all = torch.cat([offset, mask], dim=1)
+        output, offset_out = torch.npu_deformable_conv2d(
+            input, weight, offset_all, bias,
+            kernel_size=[weight.shape[3], weight.shape[2]],
+            stride=[1, 1, ctx.stride, ctx.stride],
+            padding=[ctx.padding, ctx.padding, ctx.padding, ctx.padding],
+            dilation=[1, 1, ctx.dilation, ctx.dilation],
+            groups=ctx.groups, deformable_groups=ctx.deformable_groups,
+            modulated=True)
         if weight.requires_grad or mask.requires_grad or offset.requires_grad \
                 or input.requires_grad:
-            ctx.save_for_backward(input, offset, mask, weight, bias)
-        output = input.new_empty(
-            ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
-        ctx._bufs = [input.new_empty(0), input.new_empty(0)]
-        deform_conv_cuda.modulated_deform_conv_cuda_forward(
-            input, weight, bias, ctx._bufs[0], offset, mask, output,
-            ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride,
-            ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
-            ctx.groups, ctx.deformable_groups, ctx.with_bias)
+            ctx.save_for_backward(input, offset, mask, weight, bias, offset_out)
         return output
 
     @staticmethod
-    @once_differentiable
     def backward(ctx, grad_output):
-        if not grad_output.is_cuda:
-            raise NotImplementedError
-        input, offset, mask, weight, bias = ctx.saved_tensors
-        grad_input = torch.zeros_like(input)
+        input, offset, mask, weight, bias, offset_out = ctx.saved_tensors
         grad_offset = torch.zeros_like(offset)
-        grad_mask = torch.zeros_like(mask)
-        grad_weight = torch.zeros_like(weight)
-        grad_bias = torch.zeros_like(bias)
-        deform_conv_cuda.modulated_deform_conv_cuda_backward(
-            input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1],
-            grad_input, grad_weight, grad_bias, grad_offset, grad_mask,
-            grad_output, weight.shape[2], weight.shape[3], ctx.stride,
-            ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
-            ctx.groups, ctx.deformable_groups, ctx.with_bias)
+        offset_all = torch.cat([offset, mask], dim=1)
+        grad_input, grad_weight, grad_offset_all, grad_bias = torch.npu_deformable_conv2dbk(
+            input, grad_output, offset_out, weight, offset_all,
+            kernel_size=[weight.shape[3], weight.shape[2]],
+            stride=[1, 1, ctx.stride, ctx.stride],
+            padding=[ctx.padding, ctx.padding, ctx.padding, ctx.padding],
+            dilation=[1, 1, ctx.dilation, ctx.dilation],
+            groups=ctx.groups, deformable_groups=ctx.deformable_groups, modulated=True)
+        kernel_hxw = weight.shape[2] * weight.shape[3]
+        grad_offset[:, 1::2, :, :] = grad_offset_all[:, :kernel_hxw, :, :]
+        grad_offset[:, ::2, :, :] = grad_offset_all[:, kernel_hxw:kernel_hxw * 2, :, :]
+        grad_mask = grad_offset_all[:, -kernel_hxw:, :, :]
         if not ctx.with_bias:
             grad_bias = None
 
@@ -184,131 +123,7 @@ class ModulatedDeformConvFunction(Function):
         return n, channels_out, height_out, width_out
 
 
-deform_conv = DeformConvFunction.apply
-modulated_deform_conv = ModulatedDeformConvFunction.apply
-
-
-class DeformConv(nn.Module):
-
-    def __init__(self,
-                 in_channels,
-                 out_channels,
-                 kernel_size,
-                 stride=1,
-                 padding=0,
-                 dilation=1,
-                 groups=1,
-                 deformable_groups=1,
-                 bias=False):
-        super(DeformConv, self).__init__()
-
-        assert not bias
-        assert in_channels % groups == 0, \
-            'in_channels {} cannot be divisible by groups {}'.format(
-                in_channels, groups)
-        assert out_channels % groups == 0, \
-            'out_channels {} cannot be divisible by groups {}'.format(
-                out_channels, groups)
-
-        self.in_channels = in_channels
-        self.out_channels = out_channels
-        self.kernel_size = _pair(kernel_size)
-        self.stride = _pair(stride)
-        self.padding = _pair(padding)
-        self.dilation = _pair(dilation)
-        self.groups = groups
-        self.deformable_groups = deformable_groups
-        # enable compatibility with nn.Conv2d
-        self.transposed = False
-        self.output_padding = _single(0)
-
-        self.weight = nn.Parameter(
-            torch.Tensor(out_channels, in_channels // self.groups,
-                         *self.kernel_size))
-
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        n = self.in_channels
-        for k in self.kernel_size:
-            n *= k
-        stdv = 1. / math.sqrt(n)
-        self.weight.data.uniform_(-stdv, stdv)
-
-    def forward(self, x, offset):
-        return deform_conv(x, offset, self.weight, self.stride, self.padding,
-                           self.dilation, self.groups, self.deformable_groups)
-
-
-class DeformConvPack(DeformConv):
-    """A Deformable Conv Encapsulation that acts as normal Conv layers.
-
-    Args:
-        in_channels (int): Same as nn.Conv2d.
-        out_channels (int): Same as nn.Conv2d.
-        kernel_size (int or tuple[int]): Same as nn.Conv2d.
-        stride (int or tuple[int]): Same as nn.Conv2d.
-        padding (int or tuple[int]): Same as nn.Conv2d.
-        dilation (int or tuple[int]): Same as nn.Conv2d.
-        groups (int): Same as nn.Conv2d.
-        bias (bool or str): If specified as `auto`, it will be decided by the
-            norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
-            False.
-    """
-
-    _version = 2
-
-    def __init__(self, *args, **kwargs):
-        super(DeformConvPack, self).__init__(*args, **kwargs)
-
-        self.conv_offset = nn.Conv2d(
-            self.in_channels,
-            self.deformable_groups * 2 * self.kernel_size[0] *
-            self.kernel_size[1],
-            kernel_size=self.kernel_size,
-            stride=_pair(self.stride),
-            padding=_pair(self.padding),
-            bias=True)
-        self.init_offset()
-
-    def init_offset(self):
-        self.conv_offset.weight.data.zero_()
-        self.conv_offset.bias.data.zero_()
-
-    def forward(self, x):
-        offset = self.conv_offset(x)
-        return deform_conv(x, offset, self.weight, self.stride, self.padding,
-                           self.dilation, self.groups, self.deformable_groups)
-
-    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
-                              missing_keys, unexpected_keys, error_msgs):
-        version = local_metadata.get('version', None)
-
-        if version is None or version < 2:
-            # the key is different in early versions
-            # In version < 2, DeformConvPack loads previous benchmark models.
-            if (prefix + 'conv_offset.weight' not in state_dict
-                    and prefix[:-1] + '_offset.weight' in state_dict):
-                state_dict[prefix + 'conv_offset.weight'] = state_dict.pop(
-                    prefix[:-1] + '_offset.weight')
-            if (prefix + 'conv_offset.bias' not in state_dict
-                    and prefix[:-1] + '_offset.bias' in state_dict):
-                state_dict[prefix +
-                           'conv_offset.bias'] = state_dict.pop(prefix[:-1] +
-                                                                '_offset.bias')
-
-        if version is not None and version > 1:
-            print_log(
-                'DeformConvPack {} is upgraded to version 2.'.format(
-                    prefix.rstrip('.')),
-                logger='root')
-
-        super()._load_from_state_dict(state_dict, prefix, local_metadata,
-                                      strict, missing_keys, unexpected_keys,
-                                      error_msgs)
-
-
-class ModulatedDeformConv(nn.Module):
+class ModulatedDeformConv2d(nn.Module):
 
     def __init__(self,
                  in_channels,
@@ -320,7 +135,7 @@ class ModulatedDeformConv(nn.Module):
                  groups=1,
                  deformable_groups=1,
                  bias=True):
-        super(ModulatedDeformConv, self).__init__()
+        super(ModulatedDeformConv2d, self).__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
         self.kernel_size = _pair(kernel_size)
@@ -330,9 +145,6 @@ class ModulatedDeformConv(nn.Module):
         self.groups = groups
         self.deformable_groups = deformable_groups
         self.with_bias = bias
-        # enable compatibility with nn.Conv2d
-        self.transposed = False
-        self.output_padding = _single(0)
 
         self.weight = nn.Parameter(
             torch.Tensor(out_channels, in_channels // groups,
@@ -341,9 +153,9 @@ class ModulatedDeformConv(nn.Module):
             self.bias = nn.Parameter(torch.Tensor(out_channels))
         else:
             self.register_parameter('bias', None)
-        self.reset_parameters()
+        self.init_weights()
 
-    def reset_parameters(self):
+    def init_weights(self):
         n = self.in_channels
         for k in self.kernel_size:
             n *= k
@@ -353,28 +165,12 @@ class ModulatedDeformConv(nn.Module):
             self.bias.data.zero_()
 
     def forward(self, x, offset, mask):
-        return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
-                                     self.stride, self.padding, self.dilation,
-                                     self.groups, self.deformable_groups)
+        return ModulatedDeformConv2dFunction.apply(x, offset, mask, self.weight, self.bias,
+                                                   self.stride, self.padding, self.dilation,
+                                                   self.groups, self.deformable_groups)
 
 
-class ModulatedDeformConvPack(ModulatedDeformConv):
-    """A ModulatedDeformable Conv Encapsulation that acts as normal Conv layers.
-
-    Args:
-        in_channels (int): Same as nn.Conv2d.
-        out_channels (int): Same as nn.Conv2d.
-        kernel_size (int or tuple[int]): Same as nn.Conv2d.
-        stride (int or tuple[int]): Same as nn.Conv2d.
-        padding (int or tuple[int]): Same as nn.Conv2d.
-        dilation (int or tuple[int]): Same as nn.Conv2d.
-        groups (int): Same as nn.Conv2d.
-        bias (bool or str): If specified as `auto`, it will be decided by the
-            norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
-            False.
-    """
-
-    _version = 2
+class ModulatedDeformConvPack(ModulatedDeformConv2d):
 
     def __init__(self, *args, **kwargs):
         super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)
@@ -390,6 +186,7 @@ class ModulatedDeformConvPack(ModulatedDeformConv):
         self.init_offset()
 
     def init_offset(self):
+        super(ModulatedDeformConvPack, self).init_weights()
         self.conv_offset.weight.data.zero_()
         self.conv_offset.bias.data.zero_()
 
@@ -398,34 +195,21 @@ class ModulatedDeformConvPack(ModulatedDeformConv):
         o1, o2, mask = torch.chunk(out, 3, dim=1)
         offset = torch.cat((o1, o2), dim=1)
         mask = torch.sigmoid(mask)
-        return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
-                                     self.stride, self.padding, self.dilation,
-                                     self.groups, self.deformable_groups)
+        return ModulatedDeformConv2dFunction.apply(x, offset, mask, self.weight, self.bias,
+                                                   self.stride, self.padding, self.dilation,
+                                                   self.groups, self.deformable_groups)
+
 
-    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
-                              missing_keys, unexpected_keys, error_msgs):
-        version = local_metadata.get('version', None)
+DCNv2 = ModulatedDeformConvPack
 
-        if version is None or version < 2:
-            # the key is different in early versions
-            # In version < 2, ModulatedDeformConvPack
-            # loads previous benchmark models.
-            if (prefix + 'conv_offset.weight' not in state_dict
-                    and prefix[:-1] + '_offset.weight' in state_dict):
-                state_dict[prefix + 'conv_offset.weight'] = state_dict.pop(
-                    prefix[:-1] + '_offset.weight')
-            if (prefix + 'conv_offset.bias' not in state_dict
-                    and prefix[:-1] + '_offset.bias' in state_dict):
-                state_dict[prefix +
-                           'conv_offset.bias'] = state_dict.pop(prefix[:-1] +
-                                                                '_offset.bias')
+if __name__ == "__main__":
+    x = torch.randn(2, 32, 4, 4)
+    model = DCNv2(32, 32, 1)
 
-        if version is not None and version > 1:
-            print_log(
-                'ModulatedDeformConvPack {} is upgraded to version 2.'.format(
-                    prefix.rstrip('.')),
-                logger='root')
+    torch.npu.set_device(0)
+    x = x.npu()
+    model = model.npu()
 
-        super()._load_from_state_dict(state_dict, prefix, local_metadata,
-                                      strict, missing_keys, unexpected_keys,
-                                      error_msgs)
+    o = model(x)
+    l = o.sum()
+    l.backward()
diff --git a/setup.py b/setup.py
index aee4ddb..9a05ef5 100755
--- a/setup.py
+++ b/setup.py
@@ -7,7 +7,8 @@ import time
 from setuptools import Extension, dist, find_packages, setup
 
 import torch
-from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+from torch.utils.cpp_extension import (BuildExtension, CppExtension,
+                                       CUDAExtension)
 
 dist.Distribution().fetch_build_eggs(['Cython', 'numpy>=1.11.1'])
 import numpy as np  # noqa: E402, isort:skip
@@ -96,24 +97,26 @@ def get_version():
 def make_cuda_ext(name, module, sources):
 
     define_macros = []
+    extra_compile_args = {'cxx': []}
 
     if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
-        define_macros += [("WITH_CUDA", None)]
+        define_macros += [('WITH_CUDA', None)]
+        extension = CUDAExtension
+        extra_compile_args['nvcc'] = [
+            '-D__CUDA_NO_HALF_OPERATORS__',
+            '-D__CUDA_NO_HALF_CONVERSIONS__',
+            '-D__CUDA_NO_HALF2_OPERATORS__',
+        ]
+        sources += sources_cuda
     else:
-        raise EnvironmentError('CUDA is required to compile MMDetection!')
+        print(f'Compiling {name} without CUDA')
+        extension = CppExtension
 
-    return CUDAExtension(
-        name='{}.{}'.format(module, name),
+    return extension(
+        name=f'{module}.{name}',
         sources=[os.path.join(*module.split('.'), p) for p in sources],
         define_macros=define_macros,
-        extra_compile_args={
-            'cxx': [],
-            'nvcc': [
-                '-D__CUDA_NO_HALF_OPERATORS__',
-                '-D__CUDA_NO_HALF_CONVERSIONS__',
-                '-D__CUDA_NO_HALF2_OPERATORS__',
-            ]
-        })
+        extra_compile_args=extra_compile_args)
 
 
 def make_cython_ext(name, module, sources):
@@ -245,57 +248,6 @@ if __name__ == '__main__':
             'optional': parse_requirements('requirements/optional.txt'),
         },
         ext_modules=[
-            make_cuda_ext(
-                name='compiling_info',
-                module='mmdet.ops.utils',
-                sources=['src/compiling_info.cpp']),
-            make_cython_ext(
-                name='soft_nms_cpu',
-                module='mmdet.ops.nms',
-                sources=['src/soft_nms_cpu.pyx']),
-            make_cuda_ext(
-                name='nms_cpu',
-                module='mmdet.ops.nms',
-                sources=['src/nms_cpu.cpp']),
-            make_cuda_ext(
-                name='nms_cuda',
-                module='mmdet.ops.nms',
-                sources=['src/nms_cuda.cpp', 'src/nms_kernel.cu']),
-            make_cuda_ext(
-                name='roi_align_cuda',
-                module='mmdet.ops.roi_align',
-                sources=['src/roi_align_cuda.cpp', 'src/roi_align_kernel.cu']),
-            make_cuda_ext(
-                name='roi_pool_cuda',
-                module='mmdet.ops.roi_pool',
-                sources=['src/roi_pool_cuda.cpp', 'src/roi_pool_kernel.cu']),
-            make_cuda_ext(
-                name='deform_conv_cuda',
-                module='mmdet.ops.dcn',
-                sources=[
-                    'src/deform_conv_cuda.cpp',
-                    'src/deform_conv_cuda_kernel.cu'
-                ]),
-            make_cuda_ext(
-                name='deform_pool_cuda',
-                module='mmdet.ops.dcn',
-                sources=[
-                    'src/deform_pool_cuda.cpp',
-                    'src/deform_pool_cuda_kernel.cu'
-                ]),
-            make_cuda_ext(
-                name='sigmoid_focal_loss_cuda',
-                module='mmdet.ops.sigmoid_focal_loss',
-                sources=[
-                    'src/sigmoid_focal_loss.cpp',
-                    'src/sigmoid_focal_loss_cuda.cu'
-                ]),
-            make_cuda_ext(
-                name='masked_conv2d_cuda',
-                module='mmdet.ops.masked_conv',
-                sources=[
-                    'src/masked_conv2d_cuda.cpp', 'src/masked_conv2d_kernel.cu'
-                ]),
         ],
         cmdclass={'build_ext': BuildExtension},
         zip_safe=False)