import torch

import torch_npu
from torch_npu.utils._error_code import ErrCode, pta_error

# Empty export list: a star-import of this module brings in no names.
__all__ = []


class _NPULinearOP:
    """Wrapper around the NPU linear op with an ONNX-export fallback."""

    @staticmethod
    def forward(input_, weight, bias=None):
        """Run a linear op on ``input_`` with ``weight`` and optional ``bias``.

        Calls ``torch._C._nn.linear`` while an ONNX export is in progress and
        ``torch.ops.npu.npu_linear`` otherwise; both receive the same arguments.
        """
        exporting = torch.onnx.is_in_onnx_export()
        # The conditional expression keeps the NPU op lookup lazy, so the
        # export path never touches ``torch.ops.npu``.
        linear_fn = torch._C._nn.linear if exporting else torch.ops.npu.npu_linear
        return linear_fn(input_, weight, bias)


class _NPUTransposeOP:
    """Wrapper around the NPU transpose op with an ONNX-export fallback."""

    @staticmethod
    def forward(self, perm, require_contiguous=True, out=None):
        """Permute the dimensions of ``self`` according to ``perm``.

        During ONNX export this uses ``torch.permute`` and, when
        ``require_contiguous`` is true, calls ``.contiguous()`` on the result.
        Otherwise it calls ``torch.ops.npu.npu_transpose``.

        The incoming value of ``out`` is never read.
        """
        if not torch.onnx.is_in_onnx_export():
            return torch.ops.npu.npu_transpose(self, perm, require_contiguous)

        permuted = torch.permute(self, perm)
        return permuted.contiguous() if require_contiguous else permuted


class _NPUBroadcastOP:
    """Wrapper around the NPU broadcast op with an ONNX-export fallback."""

    @staticmethod
    def forward(self, size, out=None):
        """Broadcast ``self`` to ``size``.

        Uses ``self.expand(size)`` during ONNX export and
        ``torch.ops.npu.npu_broadcast`` otherwise. The incoming value of
        ``out`` is never read.
        """
        if not torch.onnx.is_in_onnx_export():
            return torch.ops.npu.npu_broadcast(self, size)
        return self.expand(size)


class _NPUConvTranspose2dOP:
    """Wrapper around the NPU 2-D transposed convolution with an ONNX fallback."""

    @staticmethod
    def forward(input_, weight, bias, padding, output_padding, stride, dilation, groups):
        """Run a 2-D transposed convolution.

        Calls ``torch.conv_transpose2d`` during ONNX export and
        ``torch.ops.npu.npu_conv_transpose2d`` otherwise.
        """
        if not torch.onnx.is_in_onnx_export():
            return torch.ops.npu.npu_conv_transpose2d(
                input_, weight, bias, padding, output_padding, stride, dilation, groups)

        # torch.conv_transpose2d is called with stride first and groups before
        # dilation, unlike this wrapper's own parameter order.
        return torch.conv_transpose2d(
            input_, weight, bias, stride, padding, output_padding, groups, dilation)


class _NPUConv2dOP:
    """Wrapper around the NPU 2-D convolution with an ONNX-export fallback."""

    @staticmethod
    def forward(input_, weight, bias, stride, padding, dilation, groups):
        """Run a 2-D convolution.

        Calls ``torch.conv2d`` during ONNX export and
        ``torch.ops.npu.npu_conv2d`` otherwise, with identical arguments.
        """
        conv_args = (input_, weight, bias, stride, padding, dilation, groups)
        if not torch.onnx.is_in_onnx_export():
            return torch.ops.npu.npu_conv2d(*conv_args)
        return torch.conv2d(*conv_args)


class _NPUConv3dOP:
    """Wrapper around the NPU 3-D convolution with an ONNX-export fallback."""

    @staticmethod
    def forward(input_, weight, bias, stride, padding, dilation, groups):
        """Run a 3-D convolution.

        Calls ``torch.conv3d`` during ONNX export and
        ``torch.ops.npu.npu_conv3d`` otherwise, with identical arguments.
        """
        conv_args = (input_, weight, bias, stride, padding, dilation, groups)
        if not torch.onnx.is_in_onnx_export():
            return torch.ops.npu.npu_conv3d(*conv_args)
        return torch.conv3d(*conv_args)


class _NPUStrideCopyOP:
    """Wrapper around the NPU stride-copy op with an ONNX-export fallback."""

    @staticmethod
    def forward(self, shape, stride, storage_offset, out=None):
        """Copy ``self`` viewed with the given ``shape`` and ``stride``.

        During ONNX export this clones ``torch.as_strided(self, shape, stride, 0)``;
        otherwise it calls ``torch.ops.npu.npu_stride_copy``. The incoming
        value of ``out`` is never read.
        """
        if not torch.onnx.is_in_onnx_export():
            return torch.ops.npu.npu_stride_copy(self, shape, stride, storage_offset)

        # NOTE(review): the export path uses a storage offset of 0 and does not
        # forward ``storage_offset`` - confirm this is intended.
        strided_view = torch.as_strided(self, shape, stride, 0)
        return strided_view.clone()


class _NPUSortV2OP:
    """Wrapper around the NPU sort-v2 op with an ONNX-export fallback."""

    @staticmethod
    def forward(self, dim=-1, descending=False, out=None):
        """Sort ``self`` along ``dim``.

        During ONNX export this returns only the first output of
        ``torch.sort`` and drops the indices; otherwise it returns the result
        of ``torch.ops.npu.npu_sort_v2``. The incoming value of ``out`` is
        never read.
        """
        if not torch.onnx.is_in_onnx_export():
            return torch.ops.npu.npu_sort_v2(self, dim, descending)

        sorted_values, _ = torch.sort(self, dim, descending)
        return sorted_values


class _NPULayerNormEvalOP:
    """Wrapper around the NPU eval-mode layer norm with an ONNX-export fallback."""

    @staticmethod
    def forward(input_, normalized_shape, weight=None, bias=None, eps=1e-05):
        """Apply layer normalisation to ``input_``.

        Calls ``torch.layer_norm`` during ONNX export and
        ``torch.ops.npu.npu_layer_norm_eval`` otherwise.
        """
        if not torch.onnx.is_in_onnx_export():
            return torch.ops.npu.npu_layer_norm_eval(
                input_, normalized_shape, weight, bias, eps)

        # Last positional argument of torch.layer_norm (cudnn_enable) is off.
        cudnn_enable = False
        return torch.layer_norm(input_, normalized_shape, weight, bias, eps, cudnn_enable)


class _NPUReshapeOP:
    """Wrapper around the NPU reshape op with an ONNX-export fallback."""

    @staticmethod
    def forward(self, shape, can_refresh=False, out=None):
        """Reshape ``self`` to ``shape``.

        During ONNX export this uses ``torch.reshape`` and clones the result
        when ``can_refresh`` is true; otherwise it calls
        ``torch.ops.npu.npu_reshape``. The incoming value of ``out`` is never
        read.
        """
        if not torch.onnx.is_in_onnx_export():
            return torch.ops.npu.npu_reshape(self, shape, can_refresh)

        reshaped = torch.reshape(self, shape)
        return reshaped.clone() if can_refresh else reshaped


class _NPUPadOP:
    """Wrapper around the NPU pad op with an ONNX-export fallback."""

    @staticmethod
    def forward(input_, paddings):
        """Pad ``input_`` according to ``paddings``.

        During ONNX export this calls ``torch.nn.functional.pad`` in
        ``"constant"`` mode with fill value 0; otherwise it calls
        ``torch.ops.npu.npu_pad``.
        """
        if not torch.onnx.is_in_onnx_export():
            return torch.ops.npu.npu_pad(input_, paddings)

        # Move the first two entries of ``paddings`` to the end before handing
        # them to torch.nn.functional.pad.
        reordered = paddings[2:] + paddings[:2]
        return torch.nn.functional.pad(input_, reordered, "constant", 0)


class _NPUConvolutionOP(object):
    """Wrapper around the NPU generic convolution with an ONNX-export fallback."""

    @staticmethod
    def forward(input_, weight, bias, stride, padding, dilation, groups):
        """Run a 2-D or 3-D convolution chosen by the rank of ``input_``.

        Outside ONNX export this calls ``torch.ops.npu.npu_convolution``.
        During export:

        * 4-D input uses ``torch.nn.functional.conv2d``;
        * 5-D input uses ``torch._C._nn.slow_conv3d`` when ``groups == 1`` and
          every dilation equals 1, and ``torch.nn.functional.conv3d`` otherwise.

        Raises:
            ValueError: during export, if ``input_.dim()`` is neither 4 nor 5.
        """
        if not torch.onnx.is_in_onnx_export():
            return torch.ops.npu.npu_convolution(input_, weight, bias,
                                                 stride, padding, dilation, groups)

        dim = input_.dim()
        if dim == 4:
            return torch.nn.functional.conv2d(input_, weight, bias, stride,
                                              padding, dilation, groups)
        if dim == 5:
            is_dilated = any(d != 1 for d in dilation)
            if groups == 1 and not is_dilated:
                # slow_conv3d takes the full (depth, height, width) kernel
                # extent; passing only weight.size()[2] would be broadcast to
                # all three axes and is wrong for non-cubic kernels.
                kernel_size = weight.size()[2:]
                return torch._C._nn.slow_conv3d(input_, weight, kernel_size,
                                                bias, stride, padding)
            return torch.nn.functional.conv3d(input_, weight, bias, stride,
                                              padding, dilation, groups)

        # Build one message string so str(exc) is readable rather than a tuple.
        raise ValueError(
            f"input dim must be 4 or 5, but got {dim}{pta_error(ErrCode.VALUE)}")


class _NPUConvolutionTransposeOP(object):
    """Wrapper around the NPU generic transposed convolution with an ONNX fallback."""

    @staticmethod
    def forward(input_, weight, bias, padding, output_padding, stride, dilation, groups):
        """Run a 2-D or 3-D transposed convolution chosen by the rank of ``input_``.

        Outside ONNX export this calls
        ``torch.ops.npu.npu_convolution_transpose``. During export, 4-D input
        uses ``torch.conv_transpose2d`` and 5-D input uses
        ``torch.conv_transpose3d``.

        Raises:
            ValueError: during export, if ``input_.dim()`` is neither 4 nor 5.
        """
        if not torch.onnx.is_in_onnx_export():
            return torch.ops.npu.npu_convolution_transpose(
                input_, weight, bias, padding, output_padding, stride, dilation, groups)

        dim = input_.dim()
        # The torch functions are called with stride first and groups before
        # dilation, unlike this wrapper's own parameter order.
        if dim == 4:
            return torch.conv_transpose2d(input_, weight, bias, stride,
                                          padding, output_padding, groups, dilation)
        if dim == 5:
            return torch.conv_transpose3d(input_, weight, bias, stride,
                                          padding, output_padding, groups, dilation)

        # Build one message string so str(exc) is readable rather than a tuple.
        raise ValueError(
            f"input dim must be 4 or 5, but got {dim}{pta_error(ErrCode.VALUE)}")


class _NPUConfusionTransposeOP:
    """Wrapper around the NPU confusion-transpose op with an ONNX-export fallback."""

    @staticmethod
    def forward(self, perm, shape, transpose_first):
        """Combine a permute and a view on ``self``.

        During ONNX export:

        * if ``transpose_first`` is true, permute by ``perm``, make the result
          contiguous, then view it as ``shape``;
        * otherwise view ``self`` as ``shape`` and then permute by ``perm``.

        Outside export this calls ``torch.ops.npu.npu_confusion_transpose``.
        """
        if not torch.onnx.is_in_onnx_export():
            return torch.ops.npu.npu_confusion_transpose(self, perm, shape, transpose_first)

        if transpose_first:
            transposed = self.permute(*perm)
            return transposed.contiguous().view(shape)

        reshaped = self.view(shape)
        return reshaped.permute(*perm)


class _NPUMaxOP:
    """Wrapper around the NPU max-along-dim op with an ONNX-export fallback."""

    @staticmethod
    def forward(self, dim, keepdim=False):
        """Return the maximum values and their indices along ``dim``.

        During ONNX export this uses ``torch.max`` and converts the indices to
        ``torch.int32``; otherwise it returns the result of
        ``torch.ops.npu.npu_max``.
        """
        if not torch.onnx.is_in_onnx_export():
            return torch.ops.npu.npu_max(self, dim, keepdim)

        max_values, max_indices = torch.max(self, dim, keepdim)
        return max_values, max_indices.to(torch.int32)


class _NPUBmmV2OP:
    """Wrapper around the NPU bmmV2 op with an ONNX-export fallback."""

    @staticmethod
    def forward(self, mat2, output_sizes):
        """Multiply ``self`` by ``mat2``.

        During ONNX export this calls ``torch.matmul`` and ``output_sizes`` is
        not used; otherwise it calls ``torch.ops.npu.npu_bmmV2``.
        """
        if not torch.onnx.is_in_onnx_export():
            return torch.ops.npu.npu_bmmV2(self, mat2, output_sizes)
        return torch.matmul(self, mat2)


class _NPUDtypeCastOP:
    """Wrapper around the NPU dtype-cast op with an ONNX-export fallback."""

    @staticmethod
    def forward(self, dtype):
        """Convert ``self`` to ``dtype``.

        Uses ``self.to(dtype)`` during ONNX export and
        ``torch.ops.npu.npu_dtype_cast`` otherwise.
        """
        if not torch.onnx.is_in_onnx_export():
            return torch.ops.npu.npu_dtype_cast(self, dtype)
        return self.to(dtype)


class _NPUSiluOP:
    """Wrapper around the NPU SiLU op with an ONNX-export fallback."""

    @staticmethod
    def forward(self):
        """Apply SiLU to ``self``.

        During ONNX export this computes ``self * sigmoid(self)``; otherwise
        it calls ``torch.ops.npu.npu_silu``.
        """
        if not torch.onnx.is_in_onnx_export():
            return torch.ops.npu.npu_silu(self)

        gate = torch.sigmoid(self)
        return self * gate


class _NPUMinOP:
    """Wrapper around the NPU min-along-dim op with an ONNX-export fallback."""

    @staticmethod
    def forward(self, dim, keepdim=False):
        """Return the minimum values and their indices along ``dim``.

        During ONNX export this uses ``torch.min`` and converts the indices to
        ``torch.int32``; otherwise it returns the result of
        ``torch.ops.npu.npu_min``.
        """
        if not torch.onnx.is_in_onnx_export():
            return torch.ops.npu.npu_min(self, dim, keepdim)

        min_values, min_indices = torch.min(self, dim, keepdim)
        return min_values, min_indices.to(torch.int32)


class _NPUFusedAttentionLayernormQkvFwdOP(object):
    """Wrapper around the NPU fused layernorm + Q/K/V projection op.

    During ONNX export the fused op is replaced by a layer norm, three linear
    projections and a confusion-transpose per projection.
    """

    @staticmethod
    def confusion_transpose(x, new_shape):
        """Reshape-and-permute helper for the export path.

        Calls ``torch_npu.npu_confusion_transpose`` with ``perm=(0, 2, 1, 3)``
        and ``transpose_first=False``, then makes the result contiguous.
        """
        perm = (0, 2, 1, 3)
        return torch_npu.npu_confusion_transpose(x, perm, new_shape, False).contiguous()

    @staticmethod
    def forward(x, kernel_query, kernel_key, kernel_value, gamma, beta,
                bias_query=None, bias_key=None, bias_value=None, seq_len=128, num_heads=12, eps=1e-05):
        """Layer-normalise ``x`` and project it to query, key and value layers.

        Outside ONNX export this calls
        ``torch.ops.npu.npu_fused_attention_layernorm_qkv_fwd`` with all
        arguments forwarded. During export it returns the list
        ``[norm, q_layer, k_layer, v_layer, mean, variance]`` built from
        ``torch.native_layer_norm`` and ``torch.nn.functional.linear``.

        ``eps`` is passed to the layer norm on both paths.
        """
        if torch.onnx.is_in_onnx_export():
            # The projection kernels are transposed before being used as
            # weights for torch.nn.functional.linear.
            kernel_query = kernel_query.t().contiguous()
            kernel_key = kernel_key.t().contiguous()
            kernel_value = kernel_value.t().contiguous()

            norm_shape = [x.shape[-1]]
            new_shape = (int(x.shape[0] / seq_len), seq_len, num_heads, int(x.shape[1] / num_heads))

            # Use the caller's eps so the export path matches the NPU path,
            # which forwards eps to the fused op.
            # NOTE(review): the third output of torch.native_layer_norm is
            # documented as rstd rather than variance - confirm consumers.
            norm, mean, variance = torch.native_layer_norm(x, norm_shape, gamma, beta, eps=eps)

            transpose = _NPUFusedAttentionLayernormQkvFwdOP.confusion_transpose
            q_layer = transpose(
                torch.nn.functional.linear(norm, kernel_query, bias_query), new_shape)
            k_layer = transpose(
                torch.nn.functional.linear(norm, kernel_key, bias_key), new_shape)
            v_layer = transpose(
                torch.nn.functional.linear(norm, kernel_value, bias_value), new_shape)

            return [norm, q_layer, k_layer, v_layer, mean, variance]

        return torch.ops.npu.npu_fused_attention_layernorm_qkv_fwd(
            x, kernel_query, kernel_key, kernel_value, gamma, beta,
            bias_query=bias_query, bias_key=bias_key, bias_value=bias_value,
            seq_len=seq_len, num_heads=num_heads, eps=eps)


def _add_ops_combined_for_onnx():
    """Rebind ``torch_npu.npu_*`` attributes to the ONNX-aware wrappers.

    Each listed attribute on the ``torch_npu`` module is set to the static
    ``forward`` of the matching wrapper class, in the order given below.
    """
    overrides = (
        ("npu_linear", _NPULinearOP),
        ("npu_transpose", _NPUTransposeOP),
        ("npu_broadcast", _NPUBroadcastOP),
        ("npu_conv_transpose2d", _NPUConvTranspose2dOP),
        ("npu_conv2d", _NPUConv2dOP),
        ("npu_conv3d", _NPUConv3dOP),
        ("npu_stride_copy", _NPUStrideCopyOP),
        ("npu_sort_v2", _NPUSortV2OP),
        ("npu_layer_norm_eval", _NPULayerNormEvalOP),
        ("npu_reshape", _NPUReshapeOP),
        ("npu_pad", _NPUPadOP),
        ("npu_convolution", _NPUConvolutionOP),
        ("npu_convolution_transpose", _NPUConvolutionTransposeOP),
        ("npu_confusion_transpose", _NPUConfusionTransposeOP),
        ("npu_max", _NPUMaxOP),
        ("npu_bmmV2", _NPUBmmV2OP),
        ("npu_dtype_cast", _NPUDtypeCastOP),
        ("npu_silu", _NPUSiluOP),
        ("npu_min", _NPUMinOP),
        ("npu_fused_attention_layernorm_qkv_fwd", _NPUFusedAttentionLayernormQkvFwdOP),
    )
    for op_name, wrapper_cls in overrides:
        setattr(torch_npu, op_name, wrapper_cls.forward)