05360171创建于 2022年3月18日历史提交
diff --git a/dcn_v2.py b/dcn_v2.py
index 982bef5..db33229 100644
--- a/dcn_v2.py
+++ b/dcn_v2.py

@@ -14,16 +15,38 @@ import _ext as _backend
 
 
 class _DCNv2(Function):
+
+    @staticmethod
+    def symbolic(g, input, weight, offset, bias, stride, padding,
+                 dilation, groups, defomable_groups):
+        if isinstance(stride, int):
+            stride = (stride, stride)
+        if isinstance(padding, int):
+            padding = (padding, padding)
+        if isinstance(dilation, int):
+            dilation = (dilation, dilation)
+        return g.op(
+            'DeformableConv2D',
+            input,
+            weight,
+            offset,
+            bias,
+            strides_i=stride,
+            pads_i=padding,
+            dilations_i=dilation,
+            groups_i=groups,
+            defomable_groups_i=defomable_groups)
     @staticmethod
-    def forward(ctx, input, offset, mask, weight, bias,
-                stride, padding, dilation, deformable_groups):
+    def forward(ctx, input, weight, offest, bias,
+                stride, padding, dilation, groups=1, deformable_groups=1):
         ctx.stride = _pair(stride)
         ctx.padding = _pair(padding)
         ctx.dilation = _pair(dilation)
         ctx.kernel_size = _pair(weight.shape[2:4])
         ctx.deformable_groups = deformable_groups
-        output = _backend.dcn_v2_forward(input, weight, bias,
-                                         offset, mask,
+        return torch.rand(_DCNv2._infer_shape(ctx, input, weight)).to(input.device)
+        output = _backend.dcn_v2_forward(input.float(), weight.float(), bias.float(),
+                                         offset.float(), mask.float(),
                                          ctx.kernel_size[0], ctx.kernel_size[1],
                                          ctx.stride[0], ctx.stride[1],
                                          ctx.padding[0], ctx.padding[1],
@@ -31,15 +54,26 @@ class _DCNv2(Function):
                                          ctx.deformable_groups)
         ctx.save_for_backward(input, offset, mask, weight, bias)
         return output
+    @staticmethod
+    def _infer_shape(ctx, input, weight):
+        n = input.size(0)
+        channels_out = weight.size(0)
+        height, width = input.shape[2:4]
+        kernel_h, kernel_w = weight.shape[2:4]
+        height_out = (height + 2 * ctx.padding[0] -
+                      (ctx.dilation[0] * (kernel_h - 1) + 1)) // ctx.stride[0] + 1
+        width_out = (width + 2 * ctx.padding[0] -
+                     (ctx.dilation[0] * (kernel_w - 1) + 1)) // ctx.stride[0] + 1
+        return n, channels_out, height_out, width_out
 
     @staticmethod
     @once_differentiable
     def backward(ctx, grad_output):
         input, offset, mask, weight, bias = ctx.saved_tensors
         grad_input, grad_offset, grad_mask, grad_weight, grad_bias = \
-            _backend.dcn_v2_backward(input, weight,
-                                     bias,
-                                     offset, mask,
+            _backend.dcn_v2_backward(input.float(), weight.float(),
+                                     bias.float(),
+                                     offset.float(), mask.float(),
                                      grad_output,
                                      ctx.kernel_size[0], ctx.kernel_size[1],
                                      ctx.stride[0], ctx.stride[1],
@@ -120,11 +154,19 @@ class DCN(DCNv2):
         o1, o2, mask = torch.chunk(out, 3, dim=1)
         offset = torch.cat((o1, o2), dim=1)
         mask = torch.sigmoid(mask)
-        return dcn_v2_conv(input, offset, mask,
-                           self.weight, self.bias,
+        offset_y = offset.reshape(offset.shape[0], -1, 2, offset.shape[2],
+                offset.shape[3])[:, :, 0, ...].reshape(offset.shape[0], offset.shape[1] // 2, offset.shape[2],
+                    offset.shape[3])
+        offset_x = offset.reshape(offset.shape[0], -1, 2, offset.shape[2],
+                offset.shape[3])[:, :, 1, ...].reshape(offset.shape[0], offset.shape[1] // 2, offset.shape[2],
+                    offset.shape[3])
+        offset = torch.cat((offset_x, offset_y, mask), 1)
+        return dcn_v2_conv(input,
+                           self.weight, offset, self.bias,
                            self.stride,
                            self.padding,
                            self.dilation,
+                           1,
                            self.deformable_groups)
 
 
@@ -300,4 +342,4 @@ class DCNPooling(DCNv2Pooling):
                               self.group_size,
                               self.part_size,
                               self.sample_per_part,
-                              self.trans_std)
+                              self.trans_std)
\ No newline at end of file