import math
import numpy as np
import torch
from torch.nn import init
from torch.nn.init import calculate_gain
from torch.nn.parameter import Parameter
from ..ops import sparse_functional as Fsp
from .sparse_modules import SparseModule
from .sparse_structure import SparseConvTensor, IndiceData
def get_conv_output_size(input_size, kernel_size, stride, padding, dilation):
    """Return the per-axis output extent of a (non-transposed) convolution.

    Every argument is a per-axis sequence indexed in lockstep with
    ``input_size``. An axis whose kernel size is ``-1`` collapses to 1.
    """
    shape = []
    for axis, extent in enumerate(input_size):
        k = kernel_size[axis]
        # Standard convolution arithmetic; evaluated even for the collapsed case.
        extent_out = (extent + 2 * padding[axis] - dilation[axis] * (k - 1) - 1) // stride[axis] + 1
        shape.append(1 if k == -1 else extent_out)
    return shape
def get_inverse_conv_output_size(input_size, kernel_size, stride, padding, dilation, output_padding):
    """Return the per-axis output extent of an inverse (transposed) convolution.

    Uses ``(in - 1) * stride - 2 * padding + dilation * (k - 1) + output_padding + 1``
    per axis; an axis whose kernel size is ``-1`` collapses to 1.
    """
    shape = []
    for axis, extent in enumerate(input_size):
        k = kernel_size[axis]
        grown = (
            (extent - 1) * stride[axis]
            - 2 * padding[axis]
            + dilation[axis] * (k - 1)
            + output_padding[axis]
            + 1
        )
        shape.append(1 if k == -1 else grown)
    return shape
def get_deconv_output_size(input_size, kernel_size, stride, padding, dilation, output_padding):
    """Return the per-axis output extent of a transposed convolution (deconv).

    Per axis: ``(in - 1) * stride - 2 * padding + dilation * (k - 1)
    + output_padding + 1``. When ``dilation`` is 1 this reduces to the
    previous ``(in - 1) * stride - 2 * padding + k + output_padding``.
    Unlike the earlier implementation, ``dilation`` is now honoured.

    Raises:
        ValueError: if any kernel size is ``-1`` (collapsed axes are not
            supported for deconvolution).
    """
    output_size = []
    for i in range(len(input_size)):
        if kernel_size[i] == -1:
            raise ValueError("deconv don't support kernel_size < 0")
        size = (
            (input_size[i] - 1) * stride[i]
            - 2 * padding[i]
            + dilation[i] * (kernel_size[i] - 1)
            + output_padding[i]
            + 1
        )
        output_size.append(size)
    return output_size
def _calculate_fan_in_and_fan_out_hwio(tensor):
dimensions = tensor.ndimension()
if dimensions < 2:
raise ValueError("fan in and fan out can not be computed for tensor" "with fewer than 2 dimensions")
if dimensions == 2:
fan_in = tensor.size(-2)
fan_out = tensor.size(-1)
else:
num_input_fmaps = tensor.size(-2)
num_output_fmaps = tensor.size(-1)
receptive_field_size = 1
if tensor.dim() > 2:
receptive_field_size = tensor[..., 0, 0].numel()
fan_in = num_input_fmaps * receptive_field_size
fan_out = num_output_fmaps * receptive_field_size
return fan_in, fan_out
class SparseConvolution(SparseModule):
    """Sparse N-d convolution layer operating on ``SparseConvTensor`` inputs.

    ``forward`` dispatches to one of three ``Fsp`` kernels:

    * ``inverse=True``: ``Fsp.indice_inverse_conv``, using indice data looked
      up under ``indice_key`` on the input.
    * ``subm=False``: ``Fsp.indice_conv`` (regular sparse convolution). When
      ``indice_key`` is set, the resulting indice data is recorded on the
      input's ``indice_dict``.
    * ``subm=True``: ``Fsp.indice_subm_conv``. The output spatial shape equals
      the input's.

    The weight is stored as ``(*kernel_size, in_channels, out_channels)``.
    Only ``groups == 1`` is supported. Per axis, at most one of stride and
    dilation may differ from 1.
    """

    def __init__(
        self,
        ndim,
        in_channels,
        out_channels,
        kernel_size=3,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        subm=False,
        output_padding=0,
        transposed=False,
        inverse=False,
        indice_key=None,
        fused_bn=False,
        mode="mmcv",
    ):
        super().__init__()
        if groups != 1:
            raise RuntimeError("do not support groups != 1")
        # Broadcast scalar hyper-parameters to one value per spatial axis.
        if not isinstance(kernel_size, (list, tuple)):
            kernel_size = [kernel_size] * ndim
        if not isinstance(stride, (list, tuple)):
            stride = [stride] * ndim
        if not isinstance(padding, (list, tuple)):
            padding = [padding] * ndim
        if not isinstance(dilation, (list, tuple)):
            dilation = [dilation] * ndim
        if not isinstance(output_padding, (list, tuple)):
            output_padding = [output_padding] * ndim
        for d, s in zip(dilation, stride):
            if s != 1 and d != 1:
                raise RuntimeError("do not support s != 1 and d != 1 at the same time")
        self.ndim = ndim
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.conv1x1 = np.prod(kernel_size) == 1
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.transposed = transposed
        self.inverse = inverse
        self.output_padding = output_padding
        self.groups = groups
        self.subm = subm
        self.indice_key = indice_key
        self.fused_bn = fused_bn
        self.mode = mode
        self.weight = Parameter(torch.Tensor(*kernel_size, in_channels, out_channels))
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        """(Re)initialise ``weight`` and, if present, ``bias``.

        With ``mode == "mmcv"``, ``weight`` is initialised directly with
        ``init.kaiming_uniform_(a=sqrt(5))``. Any other mode uses
        ``_custom_kaiming_uniform_`` with fans taken from the
        ``(*kernel, in, out)`` layout. ``bias`` is drawn from
        ``U(-1/sqrt(fan_in), 1/sqrt(fan_in))``, or set to zero when
        ``fan_in`` is 0.
        """
        fan_in, fan_out = _calculate_fan_in_and_fan_out_hwio(self.weight)
        if self.mode == "mmcv":
            init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        else:
            self._custom_kaiming_uniform_(self.weight, a=math.sqrt(5), fan_in=fan_in, fan_out=fan_out)
        if self.bias is not None:
            bound = 0 if fan_in == 0 else 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def _custom_kaiming_uniform_(self, tensor, a=0, fan_in=0, fan_out=0, mode="fan_in", nonlinearity="leaky_relu"):
        """Kaiming-uniform fill of ``tensor`` using caller-supplied fans.

        Samples ``U(-bound, bound)`` with
        ``bound = sqrt(3) * calculate_gain(nonlinearity, a) / sqrt(fan)``.
        ``fan`` is ``fan_in`` or ``fan_out`` depending on ``mode``. The
        sampled values are then re-laid-out as described below. An unknown
        ``mode``, or a fan of 0, leads to a ZeroDivisionError.
        """
        fan = 0.0
        if mode == "fan_in":
            fan = float(fan_in)
        elif mode == "fan_out":
            fan = float(fan_out)
        gain = calculate_gain(nonlinearity, a)
        std = gain / math.sqrt(fan)
        bound = math.sqrt(3.0) * std
        with torch.no_grad():
            tensor.uniform_(-bound, bound)
            # View the samples as (out, K*in), transpose to (K*in, out), then
            # reshape back to (*kernel, in, out). NOTE(review): presumably this
            # reproduces the value placement of an (out, *kernel, in)-style
            # initialiser; confirm against the reference implementation.
            tensor.data = (
                tensor.data.reshape(self.out_channels, np.prod(self.kernel_size) * self.in_channels)
                .transpose(-1, -2)
                .contiguous()
            )
            tensor.data = tensor.data.reshape(*self.kernel_size, self.in_channels, self.out_channels)

    def forward(self, input_):
        """Apply the sparse convolution to ``input_``.

        Args:
            input_: a ``SparseConvTensor``.

        Returns:
            A new ``SparseConvTensor`` with the output features, indices and
            spatial shape. When ``input_`` has an ``indice_dict``, the output
            gets a shallow copy of it, including any entry this call added.

        Raises:
            RuntimeError: if ``input_`` is not a ``SparseConvTensor``.
        """
        if not isinstance(input_, SparseConvTensor):
            raise RuntimeError("input_ is not SparseConvTensor")
        if self.inverse:
            indice_data = input_.find_indice_pair(self.indice_key)
            out_features = Fsp.indice_inverse_conv(
                input_.features,
                self.weight,
                self.in_channels,
                self.out_channels,
                self.kernel_size,
                indice_data,
            )
            # The inverse conv restores the indices/shape recorded in indice_data.
            outidx = indice_data.origin_indices
            out_spatial_shape = indice_data.origin_spatial_shape
        elif not self.subm:
            out_spatial_shape = get_conv_output_size(
                input_.spatial_shape, self.kernel_size, self.stride, self.padding, self.dilation
            )
            # Normalise to a list of plain Python ints.
            out_spatial_shape = [int(i) for i in out_spatial_shape]
            out_features, outidx, unique_indices_offset, sorted_idx_to_former_indices, outidx_pair = Fsp.indice_conv(
                input_.features,
                input_.indices,
                self.weight.to(input_.features.dtype),
                out_spatial_shape,
                self.out_channels,
                input_.batch_size,
                self.kernel_size,
                self.stride,
                self.padding,
                self.dilation,
                self.groups,
                self.bias,
            )
            if self.indice_key is not None:
                # Record what an inverse conv sharing this key needs to undo this layer.
                indice_data = IndiceData(
                    input_.spatial_shape,
                    input_.indices,
                    unique_indices_offset,
                    sorted_idx_to_former_indices,
                    outidx_pair,
                    False,
                )
                if hasattr(input_, "indice_dict"):
                    input_.indice_dict[self.indice_key] = indice_data
        else:
            # Submanifold conv keeps the input's spatial shape.
            out_spatial_shape = [int(i) for i in input_.spatial_shape]
            indices_offset = input_.find_indice_pair(self.indice_key)
            out_features, outidx, outidx_offset = Fsp.indice_subm_conv(
                input_.features,
                input_.indices,
                self.weight.to(input_.features.dtype),
                indices_offset,
                out_spatial_shape,
                self.out_channels,
                input_.batch_size,
                self.kernel_size,
                self.stride,
                self.padding,
                self.dilation,
                self.groups,
                self.bias,
            )
            if indices_offset is None:
                # Nothing cached under this key yet: store the offsets the
                # kernel produced so later layers with the same key can reuse them.
                input_.indice_dict[self.indice_key] = outidx_offset
        if self.bias is not None:
            # NOTE(review): self.bias is also passed to Fsp.indice_conv /
            # Fsp.indice_subm_conv above; confirm those kernels do not add it
            # themselves, otherwise the bias is applied twice on those paths.
            out_features += self.bias
        out_tensor = SparseConvTensor(out_features, outidx, out_spatial_shape, input_.batch_size)
        if hasattr(input_, "indice_dict"):
            out_tensor.indice_dict = input_.indice_dict.copy()
        return out_tensor
class SparseConv3d(SparseConvolution):
    """Regular (non-submanifold) 3-D sparse convolution.

    A thin wrapper that fixes ``ndim=3`` and forwards everything else to
    ``SparseConvolution``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        indice_key=None,
        mode="mmcv",
    ):
        super().__init__(
            ndim=3,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            indice_key=indice_key,
            mode=mode,
        )
class SubMConv3d(SparseConvolution):
    """Submanifold 3-D sparse convolution.

    A thin wrapper that fixes ``ndim=3`` and ``subm=True`` and forwards the
    remaining arguments to ``SparseConvolution``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        indice_key=None,
        mode="mmcv",
    ):
        super().__init__(
            ndim=3,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            subm=True,
            indice_key=indice_key,
            mode=mode,
        )
class SparseInverseConv3d(SparseConvolution):
    """Inverse 3-D sparse convolution.

    A thin wrapper that fixes ``ndim=3`` and ``inverse=True``. Its forward
    pass looks up indice data stored under ``indice_key`` on the input tensor.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        indice_key=None,
        mode="mmcv",
    ):
        super().__init__(
            ndim=3,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            inverse=True,
            indice_key=indice_key,
            mode=mode,
        )