@@ -14,16 +15,38 @@ import _ext as _backend
class _DCNv2(Function):
+
+ @staticmethod
+ def symbolic(g, input, weight, offset, bias, stride, padding,
+ dilation, groups, defomable_groups):
+ if isinstance(stride, int):
+ stride = (stride, stride)
+ if isinstance(padding, int):
+ padding = (padding, padding)
+ if isinstance(dilation, int):
+ dilation = (dilation, dilation)
+ return g.op(
+ 'DeformableConv2D',
+ input,
+ weight,
+ offset,
+ bias,
+ strides_i=stride,
+ pads_i=padding,
+ dilations_i=dilation,
+ groups_i=groups,
+ defomable_groups_i=defomable_groups)
@staticmethod
- def forward(ctx, input, offset, mask, weight, bias,
- stride, padding, dilation, deformable_groups):
+ def forward(ctx, input, weight, offest, bias,
+ stride, padding, dilation, groups=1, deformable_groups=1):
ctx.stride = _pair(stride)
ctx.padding = _pair(padding)
ctx.dilation = _pair(dilation)
ctx.kernel_size = _pair(weight.shape[2:4])
ctx.deformable_groups = deformable_groups
- output = _backend.dcn_v2_forward(input, weight, bias,
- offset, mask,
+ return torch.rand(_DCNv2._infer_shape(ctx, input, weight)).to(input.device)
+ output = _backend.dcn_v2_forward(input.float(), weight.float(), bias.float(),
+ offset.float(), mask.float(),
ctx.kernel_size[0], ctx.kernel_size[1],
ctx.stride[0], ctx.stride[1],
ctx.padding[0], ctx.padding[1],
@@ -31,15 +54,26 @@ class _DCNv2(Function):
ctx.deformable_groups)
ctx.save_for_backward(input, offset, mask, weight, bias)
return output
+ @staticmethod
+ def _infer_shape(ctx, input, weight):
+ n = input.size(0)
+ channels_out = weight.size(0)
+ height, width = input.shape[2:4]
+ kernel_h, kernel_w = weight.shape[2:4]
+ height_out = (height + 2 * ctx.padding[0] -
+ (ctx.dilation[0] * (kernel_h - 1) + 1)) // ctx.stride[0] + 1
+ width_out = (width + 2 * ctx.padding[0] -
+ (ctx.dilation[0] * (kernel_w - 1) + 1)) // ctx.stride[0] + 1
+ return n, channels_out, height_out, width_out
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
input, offset, mask, weight, bias = ctx.saved_tensors
grad_input, grad_offset, grad_mask, grad_weight, grad_bias = \
- _backend.dcn_v2_backward(input, weight,
- bias,
- offset, mask,
+ _backend.dcn_v2_backward(input.float(), weight.float(),
+ bias.float(),
+ offset.float(), mask.float(),
grad_output,
ctx.kernel_size[0], ctx.kernel_size[1],
ctx.stride[0], ctx.stride[1],
@@ -120,11 +154,19 @@ class DCN(DCNv2):
o1, o2, mask = torch.chunk(out, 3, dim=1)
offset = torch.cat((o1, o2), dim=1)
mask = torch.sigmoid(mask)
- return dcn_v2_conv(input, offset, mask,
- self.weight, self.bias,
+ offset_y = offset.reshape(offset.shape[0], -1, 2, offset.shape[2],
+ offset.shape[3])[:, :, 0, ...].reshape(offset.shape[0], offset.shape[1] // 2, offset.shape[2],
+ offset.shape[3])
+ offset_x = offset.reshape(offset.shape[0], -1, 2, offset.shape[2],
+ offset.shape[3])[:, :, 1, ...].reshape(offset.shape[0], offset.shape[1] // 2, offset.shape[2],
+ offset.shape[3])
+ offset = torch.cat((offset_x, offset_y, mask), 1)
+ return dcn_v2_conv(input,
+ self.weight, offset, self.bias,
self.stride,
self.padding,
self.dilation,
+ 1,
self.deformable_groups)
@@ -300,4 +342,4 @@ class DCNPooling(DCNv2Pooling):
self.group_size,
self.part_size,
self.sample_per_part,
- self.trans_std)
+ self.trans_std)
\ No newline at end of file