import math
from typing import Optional, Tuple
import torch
import torchvision
from torch import nn, Tensor
from torch.nn import init
from torch.nn.parameter import Parameter
from torch.nn.modules.utils import _pair
from torchvision.extension import _assert_has_ops
import torch_npu
def deform_conv2d(
input: Tensor,
offset: Tensor,
weight: Tensor,
bias: Optional[Tensor] = None,
stride: Tuple[int, int] = (1, 1),
padding: Tuple[int, int] = (0, 0),
dilation: Tuple[int, int] = (1, 1),
mask: Optional[Tensor] = None,
) -> Tensor:
_assert_has_ops()
out_channels = weight.shape[0]
use_mask = mask is not None
if mask is None:
mask = torch.zeros((input.shape[0], 0), device=input.device, dtype=input.dtype)
if bias is None:
bias = torch.zeros(out_channels, device=input.device, dtype=input.dtype)
stride_h, stride_w = _pair(stride)
pad_h, pad_w = _pair(padding)
dil_h, dil_w = _pair(dilation)
weights_h, weights_w = weight.shape[-2:]
_, n_in_channels, in_h, in_w = input.shape
n_offset_grps = offset.shape[1] // (2 * weights_h * weights_w)
n_weight_grps = n_in_channels // weight.shape[1]
if n_offset_grps == 0:
raise RuntimeError(
"the shape of the offset tensor at dimension 1 is not valid. It should "
"be a multiple of 2 * weight.size[2] * weight.size[3].\n"
"Got offset.shape[1]={}, while 2 * weight.size[2] * weight.size[3]={}".format(
offset.shape[1], 2 * weights_h * weights_w))
if not input.is_npu:
return torch.ops.torchvision.deform_conv2d(
input,
weight,
offset,
mask,
bias,
stride_h, stride_w,
pad_h, pad_w,
dil_h, dil_w,
n_weight_grps,
n_offset_grps,
use_mask,)
else:
return npu_deform_conv2d(
input,
offset,
weight,
bias,
(stride_h, stride_w),
(pad_h, pad_w),
(dil_h, dil_w),
mask if use_mask else None,
n_weight_grps,
n_offset_grps
)
def npu_deform_conv2d(
conv_input: Tensor,
offset: Tensor,
weight: Tensor,
bias: Optional[Tensor] = None,
stride: Tuple[int, int] = (1, 1),
padding: Tuple[int, int] = (0, 0),
dilation: Tuple[int, int] = (1, 1),
mask: Optional[Tensor] = None,
groups: Optional[int] = 1,
deform_groups: Optional[int] = 1):
_, _, kernel_h, kernel_w = weight.shape
conv2d_bias = bias
sort_index_fp, sort_index_bp = _calculate_sort_index(
kernel_h, kernel_w, deform_groups)
select_offset = offset.index_select(1, sort_index_fp)
if mask is None:
mask_shape, _ = torch.chunk(offset, 2, dim=1)
mask = torch.ones_like(mask_shape).to(conv_input.device)
offset_all = torch.cat([select_offset, mask], dim=1)
output, offset_out = torch_npu.npu_deformable_conv2d(
conv_input,
weight,
offset_all,
conv2d_bias,
kernel_size=[kernel_h, kernel_w],
stride=[1, 1, stride[0], stride[1]],
padding=[padding[0], padding[0], padding[1], padding[1]],
dilation=[1, 1, dilation[0], dilation[1]],
groups=groups,
deformable_groups=deform_groups,
modulated=True)
return output
def _calculate_sort_index(kernel_h, kernel_w, deformable_group):
split_num = deformable_group * 2 * kernel_h * kernel_w
sort_index = list(range(split_num))
sort_index_fp = (sort_index[1::2] + sort_index[::2])
sort_index_bp_dict = {i: idx for idx, i in enumerate(sort_index_fp)}
sort_index_bp = [sort_index_bp_dict[i] for i in sort_index]
sort_index_fp = torch.IntTensor(sort_index_fp)
sort_index_bp = torch.IntTensor(sort_index_bp)
sort_index_fp = sort_index_fp.npu()
sort_index_bp = sort_index_bp.npu()
return sort_index_fp, sort_index_bp
def patch_deform_conv():
torchvision.ops.deform_conv2d = deform_conv2d
torchvision.ops.deform_conv.deform_conv2d = deform_conv2d