05360171创建于 2022年3月18日历史提交
diff --git a/dcn_v2.py b/dcn_v2.py

index 982bef5..db33229 100644

--- a/dcn_v2.py

+++ b/dcn_v2.py



@@ -14,16 +15,38 @@ import _ext as _backend





 class _DCNv2(Function):

+

+    @staticmethod
+    def symbolic(g, input, weight, offset, bias, stride, padding,
+                 dilation, groups, defomable_groups):
+        """Build one ``DeformableConv2D`` node through ``g.op``.
+
+        ``input``, ``weight``, ``offset`` and ``bias`` are forwarded
+        positionally, in that order. ``stride``, ``padding`` and
+        ``dilation`` are expanded to ``(v, v)`` when given as a single
+        int and forwarded unchanged otherwise. ``groups`` and
+        ``defomable_groups`` are forwarded as-is.
+
+        Returns whatever ``g.op`` returns.
+        """
+        def as_pair(value):
+            # A lone int applies to both axes; anything else is passed on.
+            return (value, value) if isinstance(value, int) else value
+
+        # Keyword spelling (including 'defomable_groups_i') is kept
+        # byte-for-byte because these names are handed to g.op verbatim.
+        node_attrs = {
+            'strides_i': as_pair(stride),
+            'pads_i': as_pair(padding),
+            'dilations_i': as_pair(dilation),
+            'groups_i': groups,
+            'defomable_groups_i': defomable_groups,
+        }
+        return g.op('DeformableConv2D', input, weight, offset, bias,
+                    **node_attrs)

     @staticmethod

-    def forward(ctx, input, offset, mask, weight, bias,

-                stride, padding, dilation, deformable_groups):

+    def forward(ctx, input, weight, offest, bias,

+                stride, padding, dilation, groups=1, deformable_groups=1):

         ctx.stride = _pair(stride)

         ctx.padding = _pair(padding)

         ctx.dilation = _pair(dilation)

         ctx.kernel_size = _pair(weight.shape[2:4])

         ctx.deformable_groups = deformable_groups

-        output = _backend.dcn_v2_forward(input, weight, bias,

-                                         offset, mask,

+        return torch.rand(_DCNv2._infer_shape(ctx, input, weight)).to(input.device)

+        output = _backend.dcn_v2_forward(input.float(), weight.float(), bias.float(),

+                                         offset.float(), mask.float(),

                                          ctx.kernel_size[0], ctx.kernel_size[1],

                                          ctx.stride[0], ctx.stride[1],

                                          ctx.padding[0], ctx.padding[1],

@@ -31,15 +54,26 @@ class _DCNv2(Function):

                                          ctx.deformable_groups)

         ctx.save_for_backward(input, offset, mask, weight, bias)

         return output

+    @staticmethod
+    def _infer_shape(ctx, input, weight):
+        """Return the ``(n, channels_out, height_out, width_out)`` output shape.
+
+        Each spatial size follows the usual convolution formula
+        ``(size + 2 * pad - (dilation * (kernel - 1) + 1)) // stride + 1``.
+
+        Args:
+            ctx: object exposing ``padding``, ``dilation`` and ``stride``
+                as 2-element sequences; entry 0 is applied to the height
+                axis and entry 1 to the width axis.
+            input: tensor whose dim 0 is read as the batch size and whose
+                dims 2 and 3 are read as height and width.
+            weight: tensor whose dim 0 is read as the output-channel count
+                and whose dims 2 and 3 are read as kernel height and width.
+
+        Returns:
+            Tuple ``(n, channels_out, height_out, width_out)``.
+        """
+        n = input.size(0)
+        channels_out = weight.size(0)
+        height, width = input.shape[2:4]
+        kernel_h, kernel_w = weight.shape[2:4]
+        height_out = (height + 2 * ctx.padding[0] -
+                      (ctx.dilation[0] * (kernel_h - 1) + 1)) // ctx.stride[0] + 1
+        # Width must use the second entry of each pair; reusing index 0
+        # (the height entry) gives a wrong width whenever padding,
+        # dilation or stride differ between the two axes.
+        width_out = (width + 2 * ctx.padding[1] -
+                     (ctx.dilation[1] * (kernel_w - 1) + 1)) // ctx.stride[1] + 1
+        return n, channels_out, height_out, width_out



     @staticmethod

     @once_differentiable

     def backward(ctx, grad_output):

         input, offset, mask, weight, bias = ctx.saved_tensors

         grad_input, grad_offset, grad_mask, grad_weight, grad_bias = \

-            _backend.dcn_v2_backward(input, weight,

-                                     bias,

-                                     offset, mask,

+            _backend.dcn_v2_backward(input.float(), weight.float(),

+                                     bias.float(),

+                                     offset.float(), mask.float(),

                                      grad_output,

                                      ctx.kernel_size[0], ctx.kernel_size[1],

                                      ctx.stride[0], ctx.stride[1],

@@ -120,11 +154,19 @@ class DCN(DCNv2):

         o1, o2, mask = torch.chunk(out, 3, dim=1)

         offset = torch.cat((o1, o2), dim=1)

         mask = torch.sigmoid(mask)

-        return dcn_v2_conv(input, offset, mask,

-                           self.weight, self.bias,

+        offset_y = offset.reshape(offset.shape[0], -1, 2, offset.shape[2],

+                offset.shape[3])[:, :, 0, ...].reshape(offset.shape[0], offset.shape[1] // 2, offset.shape[2],

+                    offset.shape[3])

+        offset_x = offset.reshape(offset.shape[0], -1, 2, offset.shape[2],

+                offset.shape[3])[:, :, 1, ...].reshape(offset.shape[0], offset.shape[1] // 2, offset.shape[2],

+                    offset.shape[3])

+        offset = torch.cat((offset_x, offset_y, mask), 1)

+        return dcn_v2_conv(input,

+                           self.weight, offset, self.bias,

                            self.stride,

                            self.padding,

                            self.dilation,

+                           1,

                            self.deformable_groups)





@@ -300,4 +342,4 @@ class DCNPooling(DCNv2Pooling):

                               self.group_size,

                               self.part_size,

                               self.sample_per_part,

-                              self.trans_std)

+                              self.trans_std)

\ No newline at end of file