modulated_deform_conv2d(ModulatedDeformConv2dFunction.apply)
接口原型
mx_driving.modulated_deform_conv2d(Tensor x, Tensor offset, Tensor mask, Tensor weight, Tensor bias, Union[int, Tuple[int, ...]] stride, Union[int, Tuple[int, ...]] padding, Union[int, Tuple[int, ...]] dilation, int groups, int deformable_groups) -> Tensor
功能描述
在可变形卷积的基础之上加上了 modulation 机制,通过调控输出特征的幅度,提升可变形卷积的聚焦相关区域的能力。
参数说明
x(Tensor):输入特征,数据类型为float32,shape为(n, c_in, h_in, w_in),其中n为 batch size,c_in为输入特征的通道数量,h_in为输入特征图的高,w_in为输入特征图的宽。offset(Tensor):偏移量,数据类型为float32,shape 为(n, 2 * k * k, h_out, w_out),其中n为 batch size,k为卷积核大小,h_out为输出特征图高,w_out为输出特征图的宽。mask(Tensor):掩码,用于调控输出特征的幅度,数据类型为float32,shape 为(n, k * k, h_out, w_out),其中n为 batch size,k 为卷积核大小,h_out为输出特征图高,w_out为输出特征图的宽。weight(Tensor):卷积核权重,数据类型为float32,shape 为(c_out, c_in, k, k),其中c_out为输出的通道数,c_in为输入的通道数,k为卷积核大小。bias(Tensor):偏置,暂不支持bias,传入None即可。stride(Union):卷积步长。padding(Union):卷积的填充大小。dilation(Union):空洞卷积大小。groups(int):分组卷积大小,需要可以整除c_in,c_out。deformable_groups(int):将通道分成几组计算offsets,当前只支持1。
返回值
output(Tensor):输出张量,数据类型为float32,shape 为(n, c_out, h_out, w_out),其中n为 batch size,c_out为输出通道,h_out为输出特征图高,w_out为输出特征图的宽。
支持的型号
- Atlas A2 训练系列产品
约束说明
deformable_groups当前只支持1。h_in,w_in,h_out,w_out需满足
wout=(win+2∗padding−(dilation∗(k−1)+1))/stride+1hout=(hin+2∗padding−(dilation∗(k−1)+1))/stride+1w_{out}=(w_{in}+ 2 * padding - (dilation * (k - 1) + 1)) / stride + 1 \\ h_{out}=(h_{in}+ 2 * padding - (dilation * (k - 1) + 1)) / stride + 1
c_in需要为64的倍数。
调用示例
import torch
import torch_npu
from mx_driving import modulated_deform_conv2d, ModulatedDeformConv2dFunction
n, c_in, h_in, w_in = 16, 64, 100, 200
c_out, k, h_out, w_out = 64, 3, 50, 100
x = torch.randn((n, c_in, h_in, w_in)).npu()
offset = torch.randn((n, 2 * k * k, h_out, w_out)).npu()
mask = torch.randn((n, k * k, h_out, w_out)).npu()
weight = torch.randn((c_out, c_in, k, k)).npu()
bias = None
stride = 1
padding = 1
dilation = 1
groups = 1
deformable_groups = 1
output = modulated_deform_conv2d(x, offset, mask, weight, bias,
stride, padding, dilation, groups, deformable_groups)
output = ModulatedDeformConv2dFunction.apply(x, offset, mask, weight, bias,
stride, padding, dilation, groups, deformable_groups)