// Copyright (c) 2024 Huawei Technologies Co., Ltd
// All rights reserved.
//
// Licensed under the BSD 3-Clause License  (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <ATen/NamedTensorUtils.h>

#include "torch_npu/csrc/aten/CustomFunctions.h"
#include "torch_npu/csrc/framework/utils/OpPreparation.h"
#include "torch_npu/csrc/aten/mirror/NPUTypeProperties.h"
#include "torch_npu/csrc/core/npu/GetCANNInfo.h"
#include "torch_npu/csrc/custom_dtype/Init.h"
#include "torch_npu/csrc/core/npu/NpuVariables.h"
#include "op_plugin/utils/OpUtils.h"

namespace op_plugin {
namespace utils {
// Upper bound (inclusive) accepted for each group-size component in check_and_get_group_size.
static const uint64_t GROUP_MAX = 65535UL;
// Exact number of elements check_and_get_group_size requires in a non-empty group-size list.
static const size_t GROUP_DIM = 3;
// Left shift applied to the first group-size component when the three components are packed into one integer.
static const size_t OFFSET_32_BITS = 32;
// Left shift applied to the second group-size component when the three components are packed into one integer.
static const size_t OFFSET_16_BITS = 16;

// Map a reduction code to its lowercase name.
// Returns "none" when the code equals at::Reduction::None, "mean" when it equals
// at::Reduction::Mean, and "sum" for every other value.
std::string get_reduction_str(int64_t reduction)
{
    if (reduction == at::Reduction::None) {
        return "none";
    }
    if (reduction == at::Reduction::Mean) {
        return "mean";
    }
    // Every remaining code is reported as "sum".
    return "sum";
}

// Render a vector of integers as "[a, b, c]"; an empty vector yields "[]".
std::string get_vector_str(const std::vector<int64_t> &vec)
{
    std::string text = "[";
    // The separator is empty before the first element and ", " before every later one.
    const char *separator = "";
    for (const int64_t item : vec) {
        text += separator;
        text += std::to_string(item);
        separator = ", ";
    }
    text += "]";
    return text;
}

// Translate a rotary mode name into its integer code:
//   "half" -> 0, "interleave" -> 1, "quarter" -> 2, "interleave-half" -> 3.
// Any other name fails a TORCH_CHECK with a parameter error, so the function
// never reaches its end without returning a value.
int64_t get_rotary_mode(c10::string_view mode)
{
    if (mode == "half") {
        // ROTATE_HALF mode corresponds to input value 0
        return 0;
    }
    if (mode == "interleave") {
        // ROTATE_INTERLEAVED mode corresponds to input value 1
        return 1;
    }
    if (mode == "quarter") {
        return 2;
    }
    if (mode == "interleave-half") {
        return 3;
    }
    TORCH_CHECK(false, "Unsupported rotary mode '", std::string(mode.data(), mode.size()),
        "', expected one of 'half', 'interleave', 'quarter', 'interleave-half'.", OPS_ERROR(ErrCode::PARAM));
    // Not reached: the check above always throws. Kept so every path returns a value.
    return -1;
}

// Wrap a possibly negative dimension index.
// dim_post_expr is the extent used for wrapping; values <= 0 are treated as 1,
// which makes the wrappable range [-1, 0]. Non-negative dim is returned unchanged;
// negative dim is shifted up by the extent. No range validation is performed.
int64_t make_warp_dim(int64_t dim, int64_t dim_post_expr)
{
    const int64_t extent = (dim_post_expr > 0) ? dim_post_expr : 1;
    return (dim < 0) ? dim + extent : dim;
}

// Report whether a 2-D or 3-D tensor looks like a view with its last two dimensions swapped.
// Returns true only when all of the following hold:
//   - the recorded base sizes have the same rank as the tensor,
//   - the second-to-last dim has stride 1 and the last dim's stride equals the second-to-last size,
//   - the last two sizes are the swapped last two base sizes,
//   - the tensor's element count equals the value returned by get_storage_size.
// Tensors of any other rank return false.
bool is_transpose_last_two_dims(const at::Tensor &tensor)
{
    if (tensor.dim() < 2 || tensor.dim() > 3) {
        return false;
    }
    int64_t numel = at_npu::native::NPUNativeFunctions::get_storage_size(tensor);
    int64_t dim1 = tensor.dim() - 1;
    int64_t dim2 = tensor.dim() - 2;

    c10::SmallVector<int64_t, 5> tensor_base_size = at_npu::native::OpPreparation::get_tensor_desc_base_sizes(tensor);
    // Check the rank first so that indexing tensor_base_size with dim1/dim2 below stays in bounds.
    if (tensor_base_size.size() != static_cast<uint64_t>(tensor.dim())) {
        return false;
    }
    return tensor.stride(dim2) == 1 && tensor.stride(dim1) == tensor.size(dim2) &&
           tensor.size(dim1) == tensor_base_size[dim2] &&
           tensor.size(dim2) == tensor_base_size[dim1] &&
           tensor.numel() == numel;
}

// True when the NPU storage descriptor of mat2 records ACL_FORMAT_FRACTAL_NZ as its format.
bool is_nz_format(const at::Tensor &mat2)
{
    return torch_npu::NPUBridge::GetNpuStorageImpl(mat2)->npu_desc_.npu_format_ == ACL_FORMAT_FRACTAL_NZ;
}

// True when FormatHelper::IsOpInputBaseFormat accepts both tensors.
// mat2 is only queried if self passes.
bool is_two_tensor_base_format(const at::Tensor &self, const at::Tensor &mat2)
{
    if (!at_npu::native::FormatHelper::IsOpInputBaseFormat(self)) {
        return false;
    }
    return at_npu::native::FormatHelper::IsOpInputBaseFormat(mat2);
}

// Decide whether (self, mat2) is a supported "self not NZ, mat2 NZ" pairing.
// Requires mat2 to be in NZ format and self not to be. Given that, returns true for
// 2-D x 2-D inputs, or for 3-D x 3-D inputs when the SoC version compares greater than
// Ascend910_9362. Every other combination returns false.
bool is_nd_nz_format(const at::Tensor &self, const at::Tensor &mat2)
{
    constexpr int kRank2 = 2;
    constexpr int kRank3 = 3;

    const auto self_rank = self.dim();
    const auto mat2_rank = mat2.dim();

    // mat2 must be NZ; self is only inspected once mat2 has qualified.
    if (!is_nz_format(mat2)) {
        return false;
    }
    if (is_nz_format(self)) {
        return false;
    }

    // Case 1: 2D ND * 2D NZ
    const bool both_2d = (self_rank == kRank2) && (mat2_rank == kRank2);
    if (both_2d) {
        return true;
    }

    // Case 2: 3D ND * 3D NZ, only support Ascend950
    const bool both_3d = (self_rank == kRank3) && (mat2_rank == kRank3);
    return both_3d && c10_npu::GetSocVersion() > c10_npu::SocVersion::Ascend910_9362;
}

// Evaluate the size-based predicate on the last two dimensions of self and mat2.
// Both tensors must have at least two dimensions, otherwise the result is false.
// For each tensor the "inner" axis is its last dimension, or its second-to-last one when
// is_transpose_last_two_dims reports the tensor as transposed; the "outer" axis is the other.
// Returns true when both inner*outer products are at most 65535, or when each inner axis is
// either within [128, 65535] or is below 128 with its low four bits clear.
bool is_nd_to_nz_on_fly(const at::Tensor &self, const at::Tensor &mat2)
{
    const static int64_t kInnerAxisMinLimit = 128;
    const static int64_t kInnerAxisMaxLimit = 65535;
    if (self.dim() < 2 || mat2.dim() < 2) {
        return false;
    }

    const int64_t self_last = self.size(self.dim() - 1);
    const int64_t self_second_last = self.size(self.dim() - 2);
    const int64_t mat2_last = mat2.size(mat2.dim() - 1);
    const int64_t mat2_second_last = mat2.size(mat2.dim() - 2);

    // get inner axis of input after transpose.
    const bool self_transposed = is_transpose_last_two_dims(self);
    const bool mat2_transposed = is_transpose_last_two_dims(mat2);
    const int64_t self_inner_axis = self_transposed ? self_second_last : self_last;
    const int64_t self_outer_axis = self_transposed ? self_last : self_second_last;
    const int64_t mat2_inner_axis = mat2_transposed ? mat2_second_last : mat2_last;
    const int64_t mat2_outer_axis = mat2_transposed ? mat2_last : mat2_second_last;

    // too small tensor size
    const bool self_is_small = self_inner_axis * self_outer_axis <= kInnerAxisMaxLimit;
    const bool mat2_is_small = mat2_inner_axis * mat2_outer_axis <= kInnerAxisMaxLimit;
    if (self_is_small && mat2_is_small) {
        return true;
    }

    // An inner axis qualifies when it lies in [128, 65535], or is below 128 and a multiple of 16.
    auto inner_axis_qualifies = [](int64_t axis) {
        if (axis >= kInnerAxisMinLimit) {
            return axis <= kInnerAxisMaxLimit;
        }
        return (static_cast<uint64_t>(axis) & 0xF) == 0;
    };
    return inner_axis_qualifies(self_inner_axis) && inner_axis_qualifies(mat2_inner_axis);
}

// Report whether a scalar represents the value one.
// Integral scalars (bool excluded) are compared exactly; floating-point scalars are compared
// with an absolute tolerance of 1e-6; every other scalar kind yields false.
// The value is read through the widest accessors (toLong / toDouble) so that values outside
// the int / float range are simply reported as "not one" rather than tripping the
// range-checked narrowing accessors.
bool is_scalar_one(const c10::Scalar &scalar)
{
    if (scalar.isIntegral(false)) {
        return scalar.toLong() == 1;
    }
    if (scalar.isFloatingPoint()) {
        return fabs(scalar.toDouble() - 1.0) < 1e-6;
    }
    return false;
}

// Read a scalar as float: floating-point scalars go through toFloat(),
// every other scalar kind is read with toInt() and then converted to float.
float get_scalar_float_value(const c10::Scalar &scalar)
{
    return scalar.isFloatingPoint() ? scalar.toFloat() : static_cast<float>(scalar.toInt());
}

// Copy the elements of an IntArrayRef, in order, into an owning SmallVector.
c10::SmallVector<int64_t, N> convert_array_to_vector(c10::IntArrayRef intArray)
{
    return c10::SmallVector<int64_t, N>(intArray.begin(), intArray.end());
}

// Build the list of dimension indices [0, 1, ..., self.dim() - 1] for a tensor.
// A zero-dimensional tensor yields an empty list.
c10::SmallVector<int64_t, N> get_dimlist_for_tensor(const at::Tensor &self)
{
    const int64_t rank = self.dim();
    c10::SmallVector<int64_t, N> axes;
    for (int64_t axis = 0; axis < rank; ++axis) {
        axes.push_back(axis);
    }
    return axes;
}

// Compute the extra padding that makes (s_size + 2 * p_size - k_size) a multiple of stride.
// Returns 0 when the remainder is already zero, otherwise stride minus the remainder.
// Fails a TORCH_CHECK with a value error when stride is zero.
int64_t complete_pad(int64_t s_size, int64_t p_size, int64_t k_size, int64_t stride)
{
    TORCH_CHECK(stride != 0, "CompletePad stride is zero!", OPS_ERROR(ErrCode::VALUE));
    const int64_t padded_size = s_size + p_size * 2;
    const int64_t remainder = (padded_size - k_size) % stride;
    return (remainder == 0) ? 0 : stride - remainder;
}

// Fetch scales[idx] from an optional list of scale factors.
// Returns c10::nullopt when no list is supplied. Fails a TORCH_CHECK with a value error
// when idx is not smaller than the list size; idx is compared as an unsigned value.
c10::optional<double> get_scale_value(c10::optional<c10::ArrayRef<double>> scales, int idx)
{
    if (!scales.has_value()) {
        return c10::nullopt;
    }
    const c10::ArrayRef<double> &values = *scales;
    TORCH_CHECK(values.size() > static_cast<size_t>(idx), "idx", idx, "is overrange scales->at(idx) ", values.size(),
        OPS_ERROR(ErrCode::VALUE));
    return values.at(idx);
}

// Result dtype for a division of self by other: the type returned by at::native::result_type,
// replaced by float when that type is integral (bool included).
at::ScalarType get_divide_result_type(const at::Tensor& self, const at::Tensor& other)
{
    const at::ScalarType promoted = at::native::result_type(self, other);
    return isIntegralType(promoted, true) ? at::kFloat : promoted;
}

// Dtype in which a division of self by other is computed: the type returned by
// at_npu::native::result_type for the two scalar types, replaced by float when that
// type is integral (bool included) or double.
at::ScalarType get_divide_calculate_type(const at::Tensor &self, const at::Tensor &other)
{
    const at::ScalarType promoted = at_npu::native::result_type(self.scalar_type(), other.scalar_type());
    const bool needs_float = isIntegralType(promoted, true) || promoted == at::kDouble;
    return needs_float ? at::kFloat : promoted;
}

// Return self converted to calculate_type and passed through CastBackToOriFormat.
// The dtype cast is skipped when self already has the requested dtype.
at::Tensor get_cast_input(const at::Tensor& self, at::ScalarType calculate_type)
{
    at::Tensor converted;
    if (self.dtype() == calculate_type) {
        converted = self;
    } else {
        converted = at_npu::native::custom_ops::_npu_dtype_cast(self, calculate_type);
    }
    // The const-reference cast is kept from the original call so the same overload is chosen.
    converted = at_npu::native::OpPreparation::CastBackToOriFormat(static_cast<const at::Tensor&>(converted));
    return converted;
}

// Compute the unified dimension names for a list of tensors.
// Returns an empty NameVector when no tensor in the list has names. Otherwise the names are
// folded left to right: while the accumulated result is empty it is replaced by the current
// tensor's names, and afterwards it is unified with them via at::unify_from_right.
// Tensors are visited by const reference to avoid a per-iteration tensor handle copy.
NameVector compute_names_npu(std::vector<at::Tensor> tensor_list)
{
    NameVector names;
    bool has_names = false;

    for (const auto &tensor : tensor_list) {
        if (tensor.has_names()) {
            has_names = true;
            break;
        }
    }

    if (!has_names) {
        return names;
    }

    for (const auto &tensor : tensor_list) {
        if (names.empty()) {
            names = tensor.names();
        } else {
            names = NameVector(at::unify_from_right(names, tensor.names()));
        }
    }
    return names;
}

// Compute the scale factor from input size to output size.
// A positive explicit scale wins and its reciprocal is returned. Otherwise the result is
// input_size / output_size, or 0 when output_size is zero.
double compute_scale(int64_t input_size, int64_t output_size, double scale)
{
    if (scale > 0.0) {
        return 1.0 / scale;
    }
    if (output_size == 0) {
        return 0;
    }
    return static_cast<double>(input_size) / output_size;
}

// Validate the dtypes of a foreach operation.
// First the tensor dtype is checked against the requested support category. If neither a
// scalar dtype nor a mapping type is supplied, that result is final. Supplying exactly one of
// the two raises a parameter error. With both supplied, the tensor/scalar pairing is checked
// according to inputType: scalar, scalar list, or tensor (always accepted). Any other
// inputType raises a parameter error.
bool check_dtype_foreach(at::ScalarType tensorDtype, ForeachTensorDtypeSupport tensorDtypeCategory, ForeachInputType inputType,
                         c10::optional<at::ScalarType> scalarDtype, c10::optional<ForeachMappingType> mapping)
{
    // Tensor dtype must belong to the requested category.
    if (!check_foreach_tensor_dtype_spport(tensorDtype, tensorDtypeCategory)) {
        return false;
    }

    // The scalar dtype and the mapping type must be given together or not at all.
    const bool hasScalarDtype = scalarDtype != c10::nullopt;
    const bool hasMapping = mapping != c10::nullopt;
    if (!hasScalarDtype && !hasMapping) {
        return true;
    }
    if (hasScalarDtype != hasMapping) {
        TORCH_CHECK(false, "Invalid  scalarType Parm or ForeachMappingType Parm!", OPS_ERROR(ErrCode::PARAM));
    }
    const at::ScalarType dtype = scalarDtype.value();
    const ForeachMappingType mappingType = mapping.value();

    // Dispatch the tensor/scalar mapping check on the kind of the second input.
    switch (inputType) {
        case ForeachInputType::TYPE_SCALAR:
            return check_mapping_between_tensor_and_scalar(tensorDtype, dtype, mappingType);
        case ForeachInputType::TYPE_SCALARLIST:
            return check_mapping_between_tensor_and_scalar_list(tensorDtype, dtype, mappingType);
        case ForeachInputType::TYPE_TENSOR:
            return true;
        default:
            TORCH_CHECK(false, "Invalid inputType Parm!", OPS_ERROR(ErrCode::PARAM));
    }
}

// Check a tensor dtype against a foreach support category:
//   BASE_DTYPE - the base dtype set,
//   TO_INT32   - the base dtype set plus Int,
//   TO_INT     - the base-and-int dtype set.
// Any other category raises a parameter error.
bool check_foreach_tensor_dtype_spport(at::ScalarType tensorDtype, ForeachTensorDtypeSupport tensorDtypeCategory)
{
    if (tensorDtypeCategory == ForeachTensorDtypeSupport::BASE_DTYPE) {
        return check_foreach_tensor_dtype_spport_base(tensorDtype);
    }
    if (tensorDtypeCategory == ForeachTensorDtypeSupport::TO_INT32) {
        return check_foreach_tensor_dtype_spport_base(tensorDtype) || (tensorDtype == at::ScalarType::Int);
    }
    if (tensorDtypeCategory == ForeachTensorDtypeSupport::TO_INT) {
        return check_foreach_tensor_dtype_spport_base_and_int(tensorDtype);
    }
    TORCH_CHECK(false, "Invalid  ForeachTensorDtypeSupport Parm", OPS_ERROR(ErrCode::PARAM));
}

// True for the base foreach dtype set: Half, Float and BFloat16.
bool check_foreach_tensor_dtype_spport_base(at::ScalarType tensorDtype)
{
    switch (tensorDtype) {
        case at::ScalarType::Half:
        case at::ScalarType::Float:
        case at::ScalarType::BFloat16:
            return true;
        default:
            return false;
    }
}

// True for the extended foreach dtype set: Half, Float, BFloat16, Int, Char and Long.
bool check_foreach_tensor_dtype_spport_base_and_int(at::ScalarType tensorDtype)
{
    switch (tensorDtype) {
        case at::ScalarType::Half:
        case at::ScalarType::Float:
        case at::ScalarType::BFloat16:
        case at::ScalarType::Int:
        case at::ScalarType::Char:
        case at::ScalarType::Long:
            return true;
        default:
            return false;
    }
}

// True when the scalar dtype is an integral type (bool excluded) or a floating-point type.
// The includeBool argument is spelled out instead of relying on the deprecated
// single-argument isIntegralType overload.
bool check_foreach_scalar_dtype_spport(at::ScalarType scalarDtype)
{
    return at::isIntegralType(scalarDtype, false) || at::isFloatingType(scalarDtype);
}

// Check that a tensor dtype and a scalar-list dtype may be combined under the given mapping.
// Returns false when the scalar dtype itself is unsupported. For MAP_SCALARLIST_DEFAULT the
// two dtypes must both be integral (bool excluded) or both be floating point. Any other
// mapping raises a parameter error.
// The includeBool argument is spelled out instead of relying on the deprecated
// single-argument isIntegralType overload.
bool check_mapping_between_tensor_and_scalar_list(at::ScalarType tensorDtype, at::ScalarType scalarDtype, ForeachMappingType mapping)
{
    if (!check_foreach_scalar_dtype_spport(scalarDtype)) {
        return false;
    }

    switch (mapping) {
        case ForeachMappingType::MAP_SCALARLIST_DEFAULT:
            return (at::isIntegralType(scalarDtype, false) && at::isIntegralType(tensorDtype, false)) ||
                   (at::isFloatingType(scalarDtype) && at::isFloatingType(tensorDtype));
        default:
            TORCH_CHECK(false, "Invalid  ForeachMappingType Parm Between Tensor And ScalarList", OPS_ERROR(ErrCode::PARAM));
    }
}

// Check that a tensor dtype and a scalar dtype may be combined under the given mapping.
// Returns false when the scalar dtype itself is unsupported. For MAP_SCALAR_DEFAULT the tensor
// dtype must not be integral (bool excluded) and the scalar dtype must be floating point.
// MAP_POW_SCALAR_AND_TENSOR accepts every remaining combination. Any other mapping raises a
// parameter error.
// The includeBool argument is spelled out instead of relying on the deprecated
// single-argument isIntegralType overload.
bool check_mapping_between_tensor_and_scalar(at::ScalarType tensorDtype, at::ScalarType scalarDtype, ForeachMappingType mapping)
{
    if (!check_foreach_scalar_dtype_spport(scalarDtype)) {
        return false;
    }

    switch (mapping) {
        case ForeachMappingType::MAP_SCALAR_DEFAULT:
            return !at::isIntegralType(tensorDtype, false) && at::isFloatingType(scalarDtype);
        case ForeachMappingType::MAP_POW_SCALAR_AND_TENSOR:
            return true;
        default:
            TORCH_CHECK(false, "Invalid ForeachMappingType Parm Between Tensor And Scalar!", OPS_ERROR(ErrCode::PARAM));
    }
}

// Require that weight, and bias when it is defined, have the same tensor type as input
// (compared with TensorOptions::type_equal). Raises a type error naming both types otherwise.
void check_input_same_type_as_parameters(
    const at::Tensor& input,
    const at::Tensor& weight,
    const at::Tensor& bias)
{
    const bool weight_matches = input.options().type_equal(weight.options());
    TORCH_CHECK(weight_matches,
        "Input type (", input.toString(), ") and weight type (", weight.toString(),
        ") should be the same" + OPS_ERROR(ErrCode::TYPE));

    // An undefined bias is accepted without comparing its options.
    const bool bias_matches = !bias.defined() || input.options().type_equal(bias.options());
    TORCH_CHECK(bias_matches,
        "Input type (", input.toString(), ") and bias type (", bias.toString(),
        ") should be the same" + OPS_ERROR(ErrCode::TYPE));
}

// Two-argument form: checks only input against weight by forwarding to the
// three-argument overload with an undefined bias tensor.
void check_input_same_type_as_parameters(
    const at::Tensor& input,
    const at::Tensor& weight)
{
    const at::Tensor undefined_bias;
    check_input_same_type_as_parameters(input, weight, undefined_bias);
}

// Returns IsGteCANNVersion("8.1.RC1", "CANN"); the query runs on the first call only and
// the result is cached in a function-local static.
bool is_gte_cann_version_810rc1()
{
    static const bool kCachedResult = IsGteCANNVersion("8.1.RC1", "CANN");
    return kCachedResult;
}

// Returns IsGteCANNVersion("8.2.RC1", "CANN"); the query runs on the first call only and
// the result is cached in a function-local static.
bool is_gte_cann_version_820rc1()
{
    static const bool kCachedResult = []() -> bool {
        return IsGteCANNVersion("8.2.RC1", "CANN");
    }();
    return kCachedResult;
}

// Returns IsGteCANNVersion("8.3.RC1", "CANN"); the query runs on the first call only and
// the result is cached in a function-local static.
bool is_gte_cann_version_830rc1()
{
    static const bool kCachedResult = []() -> bool {
        return IsGteCANNVersion("8.3.RC1", "CANN");
    }();
    return kCachedResult;
}

// Returns IsGteCANNVersion("8.5.0", "CANN"); the query runs on the first call only and
// the result is cached in a function-local static.
bool is_gte_cann_version_850()
{
    static const bool kCachedResult = []() -> bool {
        return IsGteCANNVersion("8.5.0", "CANN");
    }();
    return kCachedResult;
}

// Returns IsGteCANNVersion("8.5.1", "CANN"); the query runs on the first call only and
// the result is cached in a function-local static.
bool is_gte_cann_version_851()
{
    static const bool kCachedResult = []() -> bool {
        return IsGteCANNVersion("8.5.1", "CANN");
    }();
    return kCachedResult;
}

// Returns IsGteCANNVersion("8.5.0.alpha003", "CANN"); the query runs on the first call only
// and the result is cached in a function-local static.
bool is_gte_cann_version_850alpha003()
{
    static const bool kCachedResult = []() -> bool {
        return IsGteCANNVersion("8.5.0.alpha003", "CANN");
    }();
    return kCachedResult;
}

// Returns IsGteCANNVersion("9.0.0", "CANN"); the query runs on the first call only and
// the result is cached in a function-local static.
bool is_gte_cann_version_900()
{
    static const bool kCachedResult = []() -> bool {
        return IsGteCANNVersion("9.0.0", "CANN");
    }();
    return kCachedResult;
}

// Produce a printable name for a dtype code.
// Codes recognized by c10_npu::IsCustomDType are named by CustomDataTypeToString; every other
// code is interpreted as an at::ScalarType and named by c10::toString.
const std::string DTypeToString(int64_t input_type)
{
    if (c10_npu::IsCustomDType(input_type)) {
        return c10_npu::CustomDataTypeToString(input_type);
    }
    return c10::toString(static_cast<at::ScalarType>(input_type));
}

// Select the ACL dtype for dynamic scales.
//   QUANT_MODE_NO_QUANT: ACL_FLOAT when x is BFloat16 or Half, or when no scales tensor is
//                        given; otherwise the ACL dtype converted from the scales tensor's dtype.
//   QUANT_MODE_MX:       the ACL dtype for c10_npu::DType::FLOAT8_E8M0.
//   any other mode:      ACL_FLOAT.
aclDataType get_dynamic_scales_dtype(const at::Tensor &x, const c10::optional<at::Tensor> &scales, int64_t quant_mode)
{
    if (quant_mode == QuantMode::QUANT_MODE_NO_QUANT) {
        const bool x_is_16bit_float = x.scalar_type() == at::kBFloat16 || x.scalar_type() == at::kHalf;
        if (x_is_16bit_float || !scales.has_value()) {
            return ACL_FLOAT;
        }
        return at_npu::native::OpPreparation::convert_to_acl_data_type(scales.value().scalar_type());
    }
    if (quant_mode == QuantMode::QUANT_MODE_MX) {
        return c10_npu::GetAclDataType(c10_npu::DType::FLOAT8_E8M0);
    }
    return ACL_FLOAT;
}

// Compute the shape of the dynamic-scale output from the quant mode.
//   QUANT_MODE_NO_QUANT with scales: {a * scales.size(1)}; scales must be at least 2-D.
//   QUANT_MODE_PERTOKEN:             {a}
//   QUANT_MODE_PERGROUP:             {a, ceil(h / 128)}
//   QUANT_MODE_MX:                   {a, ceil(h / 32) rounded up to an even number}
//   anything else:                   {a}
std::vector<int64_t> get_dynamic_shape(const c10::optional<at::Tensor> &scales, int64_t quant_mode, int64_t a, int64_t h)
{
    const static int PER_GROUP_SIZE = 128;
    const static int MX_QUANT_SIZE = 32;
    const static int DIM_TWO = 2;

    if (quant_mode == QuantMode::QUANT_MODE_NO_QUANT && scales.has_value()) {
        TORCH_CHECK(scales.value().dim() >= DIM_TWO, "Expected scales to be at least 2D.", OPS_ERROR(ErrCode::PARAM));
        return std::vector<int64_t>{a * scales.value().sizes()[1]};
    }
    if (quant_mode == QuantMode::QUANT_MODE_PERTOKEN) {
        return std::vector<int64_t>{a};
    }
    if (quant_mode == QuantMode::QUANT_MODE_PERGROUP) {
        const int64_t group_count = (h + PER_GROUP_SIZE - 1) / PER_GROUP_SIZE;
        return std::vector<int64_t>{a, group_count};
    }
    if (quant_mode == QuantMode::QUANT_MODE_MX) {
        // ensure the ceiling of h divided by MX_QUANT_SIZE is even
        const int64_t block_count = (h + MX_QUANT_SIZE - 1) / MX_QUANT_SIZE;
        return std::vector<int64_t>{a, (block_count + 1) / 2 * 2};
    }
    return std::vector<int64_t>{a};
}

// Validate a group-size list and pack it into a single integer.
// An empty list yields 0. Otherwise the list must hold exactly GROUP_DIM elements (m, n, k),
// each within [0, GROUP_MAX]; violations raise a parameter or value error respectively.
// The result is (m << 32) + (n << 16) + k.
int64_t check_and_get_group_size(at::IntArrayRef group_size_list)
{
    if (group_size_list.empty()) {
        return 0;
    }
    const size_t group_dim = group_size_list.size();
    TORCH_CHECK(group_dim == GROUP_DIM, "group_sizes only support input with three elements, but got ", group_dim,
                OPS_ERROR(ErrCode::PARAM));

    const int64_t group_m = static_cast<int64_t>(group_size_list[0]);
    const int64_t group_n = static_cast<int64_t>(group_size_list[1]);
    const int64_t group_k = static_cast<int64_t>(group_size_list[2]);

    // Each component must be non-negative and no larger than GROUP_MAX.
    auto within_limit = [](int64_t value) {
        return value >= 0 && static_cast<uint64_t>(value) <= GROUP_MAX;
    };
    const bool all_in_range = within_limit(group_m) && within_limit(group_n) && within_limit(group_k);
    TORCH_CHECK(all_in_range, "group param value must conform to range [0, 65535]", OPS_ERROR(ErrCode::VALUE));

    // Pack the three components into one integer: m above bit 32, n above bit 16, k in the low bits.
    const uint64_t packed = (static_cast<uint64_t>(group_m) << OFFSET_32_BITS) +
                            (static_cast<uint64_t>(group_n) << OFFSET_16_BITS) + static_cast<uint64_t>(group_k);
    return static_cast<int64_t>(packed);
}

// Pick the cube math type: the value from the argument-less get_cube_math_type() when it is
// non-negative, otherwise the value computed from the IsAllowMatmulHF32() setting.
// Both queries are always made, in the same order as listed below.
int8_t get_cube_math_type_with_passthrough()
{
    const int8_t from_hf32_setting =
        at_npu::native::OpPreparation::get_cube_math_type(at_npu::native::env::IsAllowMatmulHF32());
    const int8_t passthrough = at_npu::native::OpPreparation::get_cube_math_type();
    return (passthrough >= 0) ? passthrough : from_hf32_setting;
}

}  // namespace utils
}  // namespace op_plugin