import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.common_types import _size_2_t
from torch.nn.parameter import Parameter
from torch.nn.modules.utils import _pair
import torch_npu
from torch_npu.utils._error_code import ErrCode, ops_error
# Public API of this module: only the QuantConv2d layer is exported.
__all__ = ['QuantConv2d']
class QuantConv2d(nn.Module):
    r"""Applies Conv2d + Dequant: :math:`output = (fmap * weight + bias) * scale`

    All parameters of this module are allocated with ``torch.empty`` (i.e. they
    are uninitialised) and with ``requires_grad=False``; quantized values are
    expected to be loaded into them before the module is used (see the example
    below).

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolution kernel
        output_dtype (int): Dtype of the output
            support torch.float16, torch.bfloat16, torch.float32, torch_npu.hifloat8
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Padding added to all four sides of
            the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Must be positive and divide both
            ``in_channels`` and ``out_channels``. Default: 1
        bias (bool, optional): If ``True``, allocates a bias parameter that is
            passed to the quantized convolution. Default: ``True``
        offset (bool, optional): Must be ``False``; passing ``True`` raises
            ``ValueError``. Default: ``False``
        offset_x (int, optional): Actual padding value. Default: 0
        round_mode (str, optional): Requant calculation parameter. Default: "rint"
        input_dtype (int, optional): Dtype of the fmap.
            Only need to be set torch_npu.hifloat8 when the dtype of input is hifloat8. Default: None
        weight_dtype (int, optional): Dtype of the weight.
            Only need to be set torch_npu.hifloat8 when the dtype of weight is hifloat8. Default: None
        device (optional): Device on which the parameters are allocated.
            Default: None (the current default device)
        dtype (optional): Accepted for signature compatibility; currently
            unused by this module. Default: None

    Raises:
        ValueError: if ``groups`` is not positive, if ``in_channels`` or
            ``out_channels`` is not divisible by ``groups``, or if ``offset``
            is ``True``.

    Shape:
        - Input: :math:`(N, C_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C_{out}, H_{out}, W_{out})`, where

          .. math::
              H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] - \text{dilation}[0]
                        \times (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] - \text{dilation}[1]
                        \times (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor

    Attributes:
        weight (Tensor): the weights of the module of shape
            :math:`(\text{out\_channels}, \frac{\text{in\_channels}}{\text{groups}},`
            :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]})`.
            Allocated uninitialised with ``requires_grad=False``.
        bias (Tensor): the bias of the module of shape (out_channels), or
            ``None`` when :attr:`bias` is ``False``. Allocated uninitialised
            with ``requires_grad=False``.
        scale (Tensor): Dequant calculation parameter of shape (out_channels),
            allocated as ``torch.int64``. A ``torch.float32`` scale assigned
            later is converted with ``torch_npu.npu_trans_quant_param`` in
            :meth:`forward`.
        offset (Tensor): Requant calculation parameter of shape (out_channels);
            always registered as ``None`` by this module.

    Examples::

        >>> in_channels, out_channels, k_size = 1, 1, 3
        >>> output_dtype = torch.float16
        >>> quant_conv2d_input = torch.randint(-1, 1, (1, 1, 4, 4), dtype=torch.int8)
        >>> weight = torch.randint(-1, 1, (1, 1, 3, 3), dtype=torch.int8)
        >>> scale = torch.randint(-1, 1, (1,), dtype=torch.int64)
        >>> bias = torch.randint(-1, 1, (1,), dtype=torch.int32)
        >>> model = QuantConv2d(in_channels, out_channels, k_size, output_dtype)
        >>> model = model.npu()
        >>> model.weight.data = weight
        >>> model.scale.data = scale
        >>> model.bias.data = bias
        >>> config = CompilerConfig()
        >>> npu_backend = tng.get_npu_backend(compiler_config=config)
        >>> static_graph_model = torch.compile(model, backend=npu_backend, dynamic=False)
        >>> output = static_graph_model(quant_conv2d_input)
        >>> print(output.size())
        torch.Size([1, 1, 2, 2])
    """
    in_channels: int
    out_channels: int
    kernel_size: tuple
    output_dtype: int

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: _size_2_t,
                 output_dtype: int,
                 stride: _size_2_t = 1,
                 padding: _size_2_t = 0,
                 dilation: _size_2_t = 1,
                 groups: int = 1,
                 bias: bool = True,
                 offset: bool = False,
                 offset_x: int = 0,
                 round_mode: str = "rint",
                 input_dtype: int = None,
                 weight_dtype: int = None,
                 device=None,
                 dtype=None) -> None:
        super(QuantConv2d, self).__init__()

        # Reject unsupported / inconsistent configurations before allocating
        # anything. Without these checks ``in_channels // groups`` would
        # silently floor to a wrong weight shape (or raise ZeroDivisionError).
        if offset:
            raise ValueError("offset must be False" + ops_error(ErrCode.VALUE))
        if groups <= 0:
            raise ValueError("groups must be a positive integer" + ops_error(ErrCode.VALUE))
        if in_channels % groups != 0:
            raise ValueError("in_channels must be divisible by groups" + ops_error(ErrCode.VALUE))
        if out_channels % groups != 0:
            raise ValueError("out_channels must be divisible by groups" + ops_error(ErrCode.VALUE))

        self.in_channels = in_channels
        self.out_channels = out_channels
        # Scalars are expanded to 2-element (height, width) form.
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        self.groups = groups
        self.offset_x = offset_x
        self.round_mode = round_mode
        self.input_dtype = input_dtype
        self.weight_dtype = weight_dtype
        self.output_dtype = output_dtype

        # All parameters are non-trainable, uninitialised placeholders created
        # on the requested device; real values are loaded afterwards.
        weight_shape = (self.out_channels, self.in_channels // self.groups, *self.kernel_size)
        self.weight = Parameter(torch.empty(weight_shape, device=device), False)
        self.scale = Parameter(torch.empty(self.out_channels, dtype=torch.int64, device=device), False)
        if bias:
            self.bias = Parameter(torch.empty(self.out_channels, device=device), False)
        else:
            self.register_parameter('bias', None)
        # offset=True was rejected above, so the parameter is always absent.
        self.register_parameter('offset', None)

    def forward(self, quant_conv2d_input: Tensor) -> Tensor:
        """Run the quantized convolution on ``quant_conv2d_input``.

        A ``torch.float32`` scale is first converted with
        ``torch_npu.npu_trans_quant_param``; any other scale dtype is passed
        through unchanged. Returns the result of
        ``torch_npu.npu_quant_conv2d``.
        """
        scale_ = self.scale
        if self.scale.dtype == torch.float32:
            scale_ = torch_npu.npu_trans_quant_param(self.scale)
        return torch_npu.npu_quant_conv2d(quant_conv2d_input, self.weight, scale_, self.stride, self.padding,
                                          self.dilation, self.groups, self.offset_x, self.round_mode,
                                          self.output_dtype, self.bias, self.offset, self.input_dtype,
                                          self.weight_dtype)