import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.parameter import Parameter
import torch_npu

__all__ = ["LinearWeightQuant"]


class LinearWeightQuant(nn.Module):
    r"""Applies a linear transformation to the incoming data: :math:`y = xA + b`

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``True``
        offset: If set to ``True``, the layer will learn an additive offset.
            Default: ``False``
    Shape:
        - Input: :math:`(*, H_{in})` where :math:`*` means any number of
          dimensions including none and :math:`H_{in} = \text{in\_features}`.
        - Output: :math:`(*, H_{out})` where all but the last dimension
          are the same shape as the input and :math:`H_{out} = \text{out\_features}`.

    Attributes:
        weight: the learnable weights of the module of shape
        antiquant_scale: weight quant batchmatmul calculation parameter
        antiquant_offset: weight quant batchmatmul calculation parameter
        quant_scale: weight quant batchmatmul calculation parameter
        quant_offset: weight quant batchmatmul calculation parameter
        bias: the learnable bias of the module of shape
        antiquant_group_size: size of group in antiquant calculation
        weight_dtype: dtype of weight

    Examples::
        >>> x = torch.randn((16, 32), dtype=torch.float16).npu()
        >>> weight = torch.randint(-3, 3, (128, 32), dtype=torch.int8).npu()
        >>> antiquant_scale = torch.randn((128), dtype=torch.float16).npu()
        >>> model = LinearWeightQuant(32, 128, False)
        >>> model = model.npu()
        >>> model.weight.data = weight
        >>> model.antiquant_scale.data = antiquant_scale
        >>> output = model(x)
        >>> print(output.size())
        torch.Size(16, 128)
    """
    in_features: int
    out_features: int
    weight: Tensor
    antiquant_scale: Tensor
    antiquant_offset: Tensor
    quant_scale: Tensor
    quant_offset: Tensor

    def __init__(
        self,
        in_features,
        out_features,
        bias: bool = True,
        device=None,
        dtype=None,
        antiquant_offset: bool = False,
        quant_scale: bool = False,
        quant_offset: bool = False,
        antiquant_group_size: int = 0,
        inner_precise: int = 0,
        weight_dtype=None
    ) -> None:
        super(LinearWeightQuant, self).__init__()
        self.weight = Parameter(torch.empty((out_features, in_features), device=device), False)
        self.antiquant_scale = Parameter(torch.empty(out_features, device=device), False)

        if antiquant_offset:
            self.antiquant_offset = Parameter(torch.empty(out_features, device=device), False)
        else:
            self.register_parameter('antiquant_offset', None)

        if quant_scale:
            self.quant_scale = Parameter(torch.empty(out_features, device=device), False)
        else:
            self.register_parameter('quant_scale', None)

        if quant_offset:
            self.quant_offset = Parameter(torch.empty(out_features, device=device), False)
        else:
            self.register_parameter('quant_offset', None)

        if bias:
            self.bias = Parameter(torch.empty(out_features, device=device), False)
        else:
            self.register_parameter('bias', None)

        self.antiquant_group_size = antiquant_group_size
        self.inner_precise = inner_precise
        self.weight_dtype = weight_dtype

    def forward(self, x: Tensor) -> Tensor:
        antiquant_scale = self.antiquant_scale
        antiquant_offset = self.antiquant_offset
        if self.antiquant_scale.dim() == 2:
            antiquant_scale = self.antiquant_scale.transpose(-1, -2)
        if self.antiquant_offset is not None:
            if self.antiquant_offset.dim() == 2:
                antiquant_offset = self.antiquant_offset.transpose(-1, -2)
        return torch_npu.npu_weight_quant_batchmatmul(x, self.weight.transpose(-1, -2), antiquant_scale,
                                                      antiquant_offset, self.quant_scale, self.quant_offset,
                                                      self.bias, self.antiquant_group_size, self.inner_precise,
                                                      weight_dtype=self.weight_dtype)