import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.parameter import Parameter
import torch_npu
from torch_npu.utils._error_code import ErrCode, ops_error

__all__ = ["LinearQuant"]


class LinearQuant(nn.Module):
    r"""Applies a linear transformation to the incoming data: :math:`y = xA + b`

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``True``
        offset: If set to ``True``, the layer will learn an additive offset.
            Default: ``False``
    Shape:
        - Input: :math:`(*, H_{in})` where :math:`*` means any number of
          dimensions including none and :math:`H_{in} = \text{in\_features}`.
        - Output: :math:`(*, H_{out})` where all but the last dimension
          are the same shape as the input and :math:`H_{out} = \text{out\_features}`.

    Attributes:
        weight: the learnable weights of the module of shape
            :math:`(\text{out\_features}, \text{in\_features})`. The values are
            initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
            :math:`k = \frac{1}{\text{in\_features}}`
        scale: quant matmul calculation parameter
        offset: quant matmul calculation parameter
        pertoken_scale: inverse quant matmul calculation parameter
        bias:   the learnable bias of the module of shape :math:`(\text{out\_features})`.
                If :attr:`bias` is ``True``, the values are initialized from
                :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
                :math:`k = \frac{1}{\text{in\_features}}`

    A4W4 Examples::
        >>> x1 = torch.randint(-1, 1, (1, 2), dtype=torch.int32).npu()
        >>> x2 = torch.randint(-1, 1, (128, 2), dtype=torch.int32).npu()
        >>> scale = torch.randn(1, dtype=torch.float32).npu()
        >>> model = LinearQuant(in_features, out_features, False)
        >>> model = model.npu()
        >>> model.weight.data = x2
        >>> model.scale.data = scale
        >>> output = model(x1)
        >>> print(output.size())
        torch.Size(1, 128)

    A8W8 Examples::
        >>> x1 = torch.randint(-1, 1, (1, 5), dtype=torch.int8).npu()
        >>> x2 = torch.randint(-1, 1, (127, 5), dtype=torch.int8).npu()
        >>> scale = torch.randn(1, dtype=torch.float32).npu()
        >>> model = LinearQuant(in_features, out_features, False)
        >>> model = model.npu()
        >>> model.weight.data = x2
        >>> model.scale.data = scale
        >>> output = model(x1)
        >>> print(output.size())
        torch.Size(1, 127)
    """
    in_features: int
    out_features: int
    weight: Tensor
    scale: Tensor
    offset: Tensor
    pertoken_scale: Tensor
    bias: Tensor

    def __init__(self, in_features: int, out_features: int, *, bias: bool = True, offset: bool = False,
                 pertoken_scale: bool = False, device=None, dtype=None, output_dtype=None) -> None:

        super(LinearQuant, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.empty((out_features, in_features)), False)
        self.scale = Parameter(torch.empty(out_features), False)
        self.output_dtype = output_dtype
        if offset:
            self.offset = Parameter(torch.empty(out_features, dtype=torch.float32), False)
        else:
            self.register_parameter('offset', None)

        if pertoken_scale:
            self.pertoken_scale = Parameter(torch.empty(out_features, dtype=torch.float32), False)
        else:
            self.register_parameter('pertoken_scale', None)

        if bias:
            self.bias = Parameter(torch.empty(out_features, dtype=torch.int32), False)
        else:
            self.register_parameter('bias', None)

    def forward(self, linear_quant_input: Tensor) -> Tensor:
        scale_quant = self.scale
        first_last_dim = self.weight.dim() - 1
        second_last_dim = self.weight.dim() - 2
        if not ((linear_quant_input.dtype == torch.int32 and self.weight.dtype == torch.int32) or
                (linear_quant_input.dtype == torch.int8 and self.weight.dtype == torch.int8)):
            raise ValueError("input and weight should be both torch.int32 or both torch.int8 datatype, "
                             f"but now input is {linear_quant_input.dtype}, weight is {self.weight.dtype}." + ops_error(ErrCode.TYPE))

        is_check_dtype_ok = (self.scale.dtype == torch.float32 and 
                             self.output_dtype not in [torch.bfloat16, torch.int32])
        if self.pertoken_scale is None and is_check_dtype_ok:
            scale_quant = torch_npu.npu_trans_quant_param(self.scale, self.offset)

        return torch_npu.npu_quant_matmul(linear_quant_input, self.weight.transpose(second_last_dim, first_last_dim),
                                          scale_quant, offset=self.offset, pertoken_scale=self.pertoken_scale, bias=self.bias,
                                          output_dtype=self.output_dtype)