/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"

namespace {
// Expected ranks / list lengths of npu_moe_init_routing_v2 inputs.
constexpr int64_t DIM_X{2};
constexpr int64_t DIM_EXPERT_IDX{2};
constexpr int64_t LENGTH_ACTIVE_EXPERT_RANGE{2};

// expert_tokens_num_type values: per-expert count, and key/value layout.
constexpr int64_t EXPERT_TOKENS_COUNT{1};
constexpr int64_t EXPERT_TOKENS_KEY_VALUE{2};

// quant_mode values handled by npu_moe_init_routing_v2.
constexpr int64_t QUANT_MODE_UNQUANT{-1};
constexpr int64_t QUANT_MODE_STATIC{0};
constexpr int64_t QUANT_MODE_DYNAMIC{1};
constexpr int64_t QUANT_MODE_MXFP8_E5M2{2};
constexpr int64_t QUANT_MODE_MXFP8_E4M3FN{3};
constexpr int64_t QUANT_MODE_HIF8_CAST{6};
constexpr int64_t QUANT_MODE_HIF8_PERTENSOR{7};
constexpr int64_t QUANT_MODE_HIF8_PER_TOKEN_DIM{8};
constexpr int64_t QUANT_MODE_MXFP4_E2M1{9};
constexpr int64_t QUANT_MODE_FP8_PERBLOCK_E5M2{11};
constexpr int64_t QUANT_MODE_FP8_PERBLOCK_E4M3FN{12};
constexpr int64_t QUANT_MODE_INT4_DYNAMIC{13};

// Block sizes and packing factors used to size quantized outputs and their scales.
constexpr int64_t MXQUANT_BLOCK_SIZE{32};
constexpr int64_t FP8_QUANT_BLOCK_SIZE{128};
constexpr int64_t PAD_TO_EVEN_FACTOR{2};
constexpr int64_t INT4_NUMS_IN_INT8{2};

// Parameters of the configuration that CheckV2Case treats as better suited to the V2 kernel.
constexpr int64_t EXPERT_NUM_V2{128};
constexpr int64_t EXPERT_NUM_MIN_V2{0};
constexpr int64_t EXPERT_NUM_MAX_V2{128};
constexpr int64_t HIDDEN_DIM_VAL_V2{2048};
}  // namespace

/// @brief Whether quantMode selects MXFP4 (E2M1) quantization.
inline bool IsQuantModeMXFP4(int64_t quantMode) {
    switch (quantMode) {
        case QUANT_MODE_MXFP4_E2M1:
            return true;
        default:
            return false;
    }
}

/// @brief Whether quantMode selects one of the MXFP8 variants (E5M2 or E4M3FN).
inline bool IsQuantModeMXFP8(int64_t quantMode) {
    switch (quantMode) {
        case QUANT_MODE_MXFP8_E5M2:
        case QUANT_MODE_MXFP8_E4M3FN:
            return true;
        default:
            return false;
    }
}

/// @brief Whether quantMode selects per-block FP8 quantization (E5M2 or E4M3FN).
inline bool IsQuantModeFP8(int64_t quantMode) {
    switch (quantMode) {
        case QUANT_MODE_FP8_PERBLOCK_E5M2:
        case QUANT_MODE_FP8_PERBLOCK_E4M3FN:
            return true;
        default:
            return false;
    }
}

/// @brief Whether quantMode selects any HiFloat8 mode (cast, per-tensor or per-token-dim).
inline bool IsQuantModeHIF8(int64_t quantMode) {
    switch (quantMode) {
        case QUANT_MODE_HIF8_CAST:
        case QUANT_MODE_HIF8_PERTENSOR:
        case QUANT_MODE_HIF8_PER_TOKEN_DIM:
            return true;
        default:
            return false;
    }
}

/// @brief Whether the optional x_dtype override is present and equals torch_npu's INT4 dtype.
inline bool IsInt4OutputDType(c10::optional<int64_t> xDtype) {
    if (!xDtype.has_value()) {
        return false;
    }
    const int64_t int4Code = static_cast<int64_t>(c10_npu::DType::INT4);
    return *xDtype == int4Code;
}

/// @brief Whether the call is INT4 dynamic quantization: quant_mode 13 with x_dtype absent or INT4.
inline bool IsDynamicQuantInt4Output(int64_t quantMode, c10::optional<int64_t> xDtype) {
    if (quantMode != QUANT_MODE_INT4_DYNAMIC) {
        return false;
    }
    // With quant_mode 13, an omitted x_dtype is treated the same as an explicit INT4.
    if (!xDtype.has_value()) {
        return true;
    }
    return IsInt4OutputDType(xDtype);
}

namespace op_api {
// Result of npu_moe_init_routing_v2: four output tensors.
using tensor_list = std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>;
// Short names for the at_npu helpers used in this file.
using npu_utils = at_npu::native::NpuUtils;
using npu_preparation = at_npu::native::OpPreparation;

/**
 * Resolve the active expert range used by npu_moe_init_routing_v2.
 *
 * @param active_expert_range caller-supplied expert range; may be empty.
 * @param expert_num          number of experts; upper bound of the default range.
 * @return active_expert_range itself when non-empty; otherwise a view of {0, expert_num}.
 *
 * The default range lives in thread-local storage that is re-assigned on every call. The
 * returned view is therefore valid only until the next call from the same thread.
 *
 * A function-local `static` initializer would run only once. It would freeze the first
 * expert_num ever seen and return a stale range for every later call with a different
 * expert_num. Re-assigning the storage on each call avoids that.
 */
at::IntArrayRef init_new_active_expert_range(at::IntArrayRef &active_expert_range, int64_t expert_num) {
    if (!active_expert_range.empty()) {
        return active_expert_range;
    }
    thread_local std::vector<int64_t> default_active_expert_range;
    default_active_expert_range = {0, expert_num};
    return at::IntArrayRef(default_active_expert_range);
}

static bool CheckV2Case(int hidden_dim, int64_t expert_num, at::IntArrayRef active_expert_range,
    int64_t expert_tokens_num_type, int64_t quant_mode) {
    if (expert_num == EXPERT_NUM_V2 && active_expert_range[0] == EXPERT_NUM_MIN_V2 &&
        active_expert_range[1] == EXPERT_NUM_MAX_V2 && hidden_dim == HIDDEN_DIM_VAL_V2) {
        if (quant_mode == -1 && expert_tokens_num_type == 1) {
            return true;
        }
    }
    return false;
}

/**
 * MoE init-routing: expand tokens by their expert assignments, optionally quantizing the result.
 *
 * Dispatches to aclnnMoeInitRoutingV2 on 310P devices. It also uses V2 for the configuration
 * accepted by CheckV2Case when is_gte_cann_version_850alpha003() is false. Every other call
 * goes to aclnnMoeInitRoutingV3.
 *
 * @param x                      2-D input; x.size(1) is the hidden size (doubled when x's ACL dtype is FP4).
 * @param expert_idx             2-D expert assignments; expert_idx.size(1) is k.
 * @param scale, offset          optional quantization inputs (offset is rejected for quant_mode 13).
 * @param active_num             row cap in dropless mode; <= 0 means bs * k rows.
 * @param expert_capacity        per-expert capacity when drop_pad_mode == 1.
 * @param expert_num             number of experts.
 * @param drop_pad_mode          1 = drop/pad layout [expert_num, expert_capacity, h]; otherwise dropless.
 * @param expert_tokens_num_type < 2: statistics tensor of shape [end - start];
 *                               2: key/value tensor of shape [expert_num, 2].
 * @param expert_tokens_num_flag forwarded to aclnnMoeInitRoutingV3.
 * @param quant_mode             one of the QUANT_MODE_* values.
 * @param active_expert_range    two-element expert range; empty resolves to {0, expert_num}.
 * @param row_idx_type           forwarded to aclnnMoeInitRoutingV3.
 * @param x_dtype                optional dtype override for x (and for expanded_x when unquantized).
 * @return (expanded_x, expanded_row_idx, expert_tokens_count_or_cumsum, expanded_scale).
 *         On the V2 path the last element is expert_tokens_before_capacity instead.
 */
tensor_list npu_moe_init_routing_v2(const at::Tensor &x, const at::Tensor &expert_idx,
    const c10::optional<at::Tensor> &scale, const c10::optional<at::Tensor> &offset, int64_t active_num,
    int64_t expert_capacity, int64_t expert_num, int64_t drop_pad_mode, int64_t expert_tokens_num_type,
    bool expert_tokens_num_flag, int64_t quant_mode, at::IntArrayRef active_expert_range, int64_t row_idx_type,
    c10::optional<int64_t> x_dtype) {
#if !VERSION_BETWEEN(V2R7, VERSION_NEWEST)
    // Torch older than 2.7 lacks the float8_e8m0fnu dtype needed for MXFP8 scales.
    TORCH_CHECK(!IsQuantModeMXFP8(quant_mode),
        "Unsupported quant_mode:",
        quant_mode,
        " on this version of torch with torch_npu. Please upgrade to at least v2.7.",
        OPS_ERROR(ErrCode::PARAM));
#endif

    int64_t x_dim = x.dim();
    TORCH_CHECK(x_dim == DIM_X,
        "The x should be ",
        DIM_X,
        "-Dimension, current is ",
        x_dim,
        "-Dimension.",
        OPS_ERROR(ErrCode::PARAM));

    int64_t expert_idx_dim = expert_idx.dim();
    TORCH_CHECK(expert_idx_dim == DIM_EXPERT_IDX,
        "The expert_idx should be ",
        DIM_EXPERT_IDX,
        "-Dimension, current is ",
        expert_idx_dim,
        "-Dimension.",
        OPS_ERROR(ErrCode::PARAM));

    // Resolve an empty active_expert_range to {0, expert_num} before anything indexes it.
    at::IntArrayRef current_active_expert_range = init_new_active_expert_range(active_expert_range, expert_num);
    int64_t active_expert_range_length = current_active_expert_range.size();
    TORCH_CHECK(active_expert_range_length == LENGTH_ACTIVE_EXPERT_RANGE,
        "The length of list active_expert_range should be ",
        LENGTH_ACTIVE_EXPERT_RANGE,
        ", current is ",
        active_expert_range_length,
        ".",
        OPS_ERROR(ErrCode::PARAM));

    // Shape-derived values stay int64_t: narrowing to int can overflow products such as bs * k.
    int64_t expert_length = current_active_expert_range[1] - current_active_expert_range[0];
    auto x_size = x.sizes();
    auto expert_idx_size = expert_idx.sizes();
    const at::Tensor &p_scale = c10::value_or_else(scale, [] { return at::Tensor(); });
    const at::Tensor &p_offset = c10::value_or_else(offset, [] { return at::Tensor(); });

    int64_t bs = x_size[0];
    int64_t h = x_size[1];
    aclDataType x_acl_type = c10_npu::GetAclDataType(x_dtype.value_or(static_cast<int64_t>(x.scalar_type())));
    if (x_acl_type == aclDataType::ACL_FLOAT4_E2M1) {
        // FP4 x is packed, so the logical hidden size is twice the stored second dim.
        h = h * 2;
    }
    int64_t k = expert_idx_size[1];
    // Configurations better served by the V2 kernel. Use the resolved range: the raw
    // active_expert_range may be empty, and CheckV2Case indexes elements 0 and 1.
    bool using_v2 = CheckV2Case(h, expert_num, current_active_expert_range, expert_tokens_num_type, quant_mode);

    TORCH_CHECK(!(quant_mode == QUANT_MODE_DYNAMIC && IsInt4OutputDType(x_dtype)),
        "INT4 dynamic quantization uses quant_mode=13. quant_mode=1 only supports INT8 dynamic quantization.",
        OPS_ERROR(ErrCode::PARAM));
    if (quant_mode == QUANT_MODE_INT4_DYNAMIC) {
        TORCH_CHECK(drop_pad_mode == 0,
            "INT4 dynamic quantization only supports drop_pad_mode=0.",
            OPS_ERROR(ErrCode::PARAM));
        TORCH_CHECK(x.scalar_type() == at::kFloat || x.scalar_type() == at::kBFloat16,
            "INT4 dynamic quantization only supports float32 or bfloat16 x.",
            OPS_ERROR(ErrCode::TYPE));
        TORCH_CHECK(!p_offset.defined(),
            "INT4 dynamic quantization does not support offset.",
            OPS_ERROR(ErrCode::PARAM));
        TORCH_CHECK(h % INT4_NUMS_IN_INT8 == 0,
            "INT4 dynamic quantization requires the hidden size of x to be even.",
            OPS_ERROR(ErrCode::PARAM));
        TORCH_CHECK(!x_dtype.has_value() || IsInt4OutputDType(x_dtype),
            "The optional parameter x_dtype must be torch_npu.int4 or None when quant_mode=13.",
            OPS_ERROR(ErrCode::PARAM));
        if (p_scale.defined()) {
            TORCH_CHECK(p_scale.scalar_type() == at::kFloat,
                "The scale dtype must be float32 in INT4 dynamic quantization.",
                OPS_ERROR(ErrCode::TYPE));
            TORCH_CHECK(p_scale.dim() == DIM_X,
                "The scale shape supports only 2D in INT4 dynamic quantization.",
                OPS_ERROR(ErrCode::PARAM));
            TORCH_CHECK(p_scale.size(0) == 1,
                "The first dim of scale must be 1 in INT4 dynamic quantization.",
                OPS_ERROR(ErrCode::PARAM));
            TORCH_CHECK(p_scale.size(1) == x_size[1],
                "The second dim of scale should be the same as the second dim of x in INT4 dynamic quantization.",
                OPS_ERROR(ErrCode::PARAM));
        }
    }

    int64_t expanded_scale_len = 0;
    at::Tensor expanded_x;
    if (drop_pad_mode == 1) {  // Drop/Pad
        if (quant_mode == QUANT_MODE_UNQUANT) {
            expanded_x = npu_preparation::apply_tensor_without_format(x, {expert_num, expert_capacity, h});
        } else {
            expanded_x = npu_preparation::apply_tensor_without_format(
                {expert_num, expert_capacity, h}, x.options().dtype(at::kChar));
        }
        expanded_scale_len = expert_num * expert_capacity;
    } else {  // Dropless
        expanded_scale_len = (active_num <= 0) ? bs * k : std::min<int64_t>(active_num, bs * k);
        switch (quant_mode) {
#if VERSION_BETWEEN(V2R7, VERSION_NEWEST)
            case QUANT_MODE_MXFP8_E5M2:
                expanded_x = npu_preparation::apply_tensor_without_format(
                    {expanded_scale_len, h}, x.options().dtype(at::kFloat8_e5m2));
                break;
            case QUANT_MODE_MXFP8_E4M3FN:
                expanded_x = npu_preparation::apply_tensor_without_format(
                    {expanded_scale_len, h}, x.options().dtype(at::kFloat8_e4m3fn));
                break;
            case QUANT_MODE_MXFP4_E2M1:
                expanded_x = npu_preparation::apply_tensor_without_format(
                    {expanded_scale_len, h / 2}, x.options().dtype(at::kByte));
                break;
#endif
            case QUANT_MODE_STATIC:
            case QUANT_MODE_DYNAMIC:
                expanded_x = npu_preparation::apply_tensor_without_format(
                    {expanded_scale_len, h}, x.options().dtype(at::kChar));
                break;
            case QUANT_MODE_INT4_DYNAMIC:
                expanded_x = npu_preparation::apply_tensor_without_format(
                    {expanded_scale_len, h / INT4_NUMS_IN_INT8}, x.options().dtype(at::kByte));
                break;
            case QUANT_MODE_HIF8_CAST:
            case QUANT_MODE_HIF8_PERTENSOR:
            case QUANT_MODE_HIF8_PER_TOKEN_DIM: {
                expanded_x =
                    npu_preparation::apply_tensor_without_format({expanded_scale_len, h}, x.options().dtype(at::kByte));
                break;
            }
            case QUANT_MODE_FP8_PERBLOCK_E5M2: {
                expanded_x = npu_preparation::apply_tensor_without_format(
                    {expanded_scale_len, h}, x.options().dtype(at::kFloat8_e5m2));
                break;
            }
            case QUANT_MODE_FP8_PERBLOCK_E4M3FN: {
                expanded_x = npu_preparation::apply_tensor_without_format(
                    {expanded_scale_len, h}, x.options().dtype(at::kFloat8_e4m3fn));
                break;
            }
            default:  // quant_mode == QUANT_MODE_UNQUANT
                // Same storage layout as x, so use the stored (not logical) hidden size.
                expanded_x = npu_preparation::apply_tensor_without_format(x, {expanded_scale_len, x_size[1]});
        }
    }

    const bool is_310p = Is310PBoolCheck();
    at::Tensor expanded_row_idx = npu_preparation::apply_tensor_without_format(expert_idx, {bs * k});
    at::Tensor expert_tokens_count_or_cumsum;
    if (is_310p) {
        expert_tokens_count_or_cumsum =
            npu_preparation::apply_tensor_without_format({expert_length}, x.options().dtype(at::kInt));
    } else if (expert_tokens_num_type < EXPERT_TOKENS_KEY_VALUE) {
        // expert_tokens_count_or_cumsum in [end-start, ]
        expert_tokens_count_or_cumsum =
            npu_preparation::apply_tensor_without_format({expert_length}, x.options().dtype(at::kLong));
    } else if (expert_tokens_num_type == EXPERT_TOKENS_KEY_VALUE) {
        // key_value in [2, end-start]
        expert_tokens_count_or_cumsum =
            npu_preparation::apply_tensor_without_format({expert_num, 2}, x.options().dtype(at::kLong));
    }
    // NOTE(review): expert_tokens_num_type > 2 leaves expert_tokens_count_or_cumsum undefined here.

    if ((using_v2 && !op_plugin::utils::is_gte_cann_version_850alpha003()) || is_310p) {
        at::Tensor expert_tokens_before_capacity =
            npu_preparation::apply_tensor_without_format({expert_num}, x.options().dtype(at::kInt));
        expert_capacity = 0;
        drop_pad_mode = 0;
        int64_t expert_tokens_count_or_cumsum_flag = is_310p ? 1 : 2;
        bool expert_tokens_before_capacity_flag = false;
        if (bs == 0) {
            // return when using empty tensor
            expert_tokens_count_or_cumsum.zero_();
            return std::tie(expanded_x, expanded_row_idx, expert_tokens_count_or_cumsum, expert_tokens_before_capacity);
        }
        EXEC_NPU_CMD(aclnnMoeInitRoutingV2,
            x,
            expert_idx,
            active_num,
            expert_capacity,
            expert_num,
            drop_pad_mode,
            expert_tokens_count_or_cumsum_flag,
            expert_tokens_before_capacity_flag,
            expanded_x,
            expanded_row_idx,
            expert_tokens_count_or_cumsum,
            expert_tokens_before_capacity);
        return std::tie(expanded_x, expanded_row_idx, expert_tokens_count_or_cumsum, expert_tokens_before_capacity);
    }

    // Unquantized call with FP8/FP4 x: scale and expanded_scale are passed to the kernel as FLOAT8_E8M0.
    const bool is_unquant_low_bit_x = quant_mode == QUANT_MODE_UNQUANT &&
        (x.scalar_type() == at::kFloat8_e5m2 || x.scalar_type() == at::kFloat8_e4m3fn ||
         x_acl_type == aclDataType::ACL_FLOAT4_E2M1);

#if VERSION_BETWEEN(V2R7, VERSION_NEWEST)
    at::Tensor expanded_scale;
    if (IsQuantModeMXFP8(quant_mode)) {
        // scale_cols = ceil(h / 32), rounded up to an even number.
        int64_t scale_cols = (h + MXQUANT_BLOCK_SIZE - 1) / MXQUANT_BLOCK_SIZE;
        scale_cols = (scale_cols + PAD_TO_EVEN_FACTOR - 1) / PAD_TO_EVEN_FACTOR * PAD_TO_EVEN_FACTOR;
        expanded_scale = npu_preparation::apply_tensor_without_format(
            {expanded_scale_len, scale_cols}, x.options().dtype(at::kFloat8_e8m0fnu));
    } else if (IsQuantModeFP8(quant_mode)) {
        // ceil(h / 128) scale columns rounded up to an even count, laid out as [ceil(h / 256), 2].
        int64_t block_size = FP8_QUANT_BLOCK_SIZE * 2;
        expanded_scale = npu_preparation::apply_tensor_without_format(
            {expanded_scale_len, op_infer::CeilDiv(h, block_size), 2}, x.options().dtype(at::kFloat));
    } else if ((is_unquant_low_bit_x && scale.has_value()) || IsQuantModeMXFP4(quant_mode)) {
        // ceil(h / 32) byte scales rounded up to an even count, laid out as [ceil(h / 64), 2].
        expanded_scale = npu_preparation::apply_tensor_without_format(
            {expanded_scale_len, op_infer::CeilDiv(h, MXQUANT_BLOCK_SIZE * PAD_TO_EVEN_FACTOR), 2},
            x.options().dtype(at::kByte));
    } else {
        expanded_scale =
            npu_preparation::apply_tensor_without_format({expanded_scale_len}, x.options().dtype(at::kFloat));
    }
#else
    at::Tensor expanded_scale =
        npu_preparation::apply_tensor_without_format({expanded_scale_len}, x.options().dtype(at::kFloat));
#endif
    auto scale_scalar_dtype = p_scale.defined() ? p_scale.scalar_type() : at::kFloat;
    auto expanded_scale_scalar_dtype = expanded_scale.defined() ? expanded_scale.scalar_type() : at::kFloat;
    TensorWrapper scale_wrapper = {
        p_scale,
        is_unquant_low_bit_x ? aclDataType::ACL_FLOAT8_E8M0 :
            npu_preparation::convert_to_acl_data_type(scale_scalar_dtype)
    };
    TensorWrapper expanded_scale_wrapper = {
        expanded_scale,
        is_unquant_low_bit_x ? aclDataType::ACL_FLOAT8_E8M0 :
            npu_preparation::convert_to_acl_data_type(expanded_scale_scalar_dtype)
    };

    // In unquantized mode an explicit x_dtype overrides the ACL dtype of both x and expanded_x.
    const bool use_x_dtype_override = quant_mode == QUANT_MODE_UNQUANT && x_dtype.has_value();
    TensorWrapper x_wrapper = {x, use_x_dtype_override ?
        c10_npu::GetAclDataType(x_dtype.value()) :
        npu_preparation::convert_to_acl_data_type(x.scalar_type())};
    TensorWrapper expanded_x_wrapper = {expanded_x, npu_preparation::convert_to_acl_data_type(expanded_x.scalar_type())};
    if (use_x_dtype_override) {
        expanded_x_wrapper.dtype = c10_npu::GetAclDataType(x_dtype.value());
    } else if (IsDynamicQuantInt4Output(quant_mode, x_dtype)) {
        expanded_x_wrapper.dtype = aclDataType::ACL_INT4;
    } else if (IsQuantModeHIF8(quant_mode)) {
        expanded_x_wrapper.dtype = aclDataType::ACL_HIFLOAT8;
    } else if (IsQuantModeMXFP4(quant_mode)) {
        expanded_x_wrapper.dtype = aclDataType::ACL_FLOAT4_E2M1;
        expanded_scale_wrapper.dtype = aclDataType::ACL_FLOAT8_E8M0;
    }
    if (bs == 0) {
        // return when using empty tensor
        expert_tokens_count_or_cumsum.zero_();
        return std::tie(expanded_x, expanded_row_idx, expert_tokens_count_or_cumsum, expanded_scale);
    }
    // The caller's (possibly empty) active_expert_range is forwarded unchanged;
    // presumably the kernel applies its own default for an empty range — TODO confirm.
    EXEC_NPU_CMD(aclnnMoeInitRoutingV3,
        x_wrapper,
        expert_idx,
        scale_wrapper,
        p_offset,
        active_num,
        expert_capacity,
        expert_num,
        drop_pad_mode,
        expert_tokens_num_type,
        expert_tokens_num_flag,
        quant_mode,
        active_expert_range,
        row_idx_type,
        expanded_x_wrapper,
        expanded_row_idx,
        expert_tokens_count_or_cumsum,
        expanded_scale_wrapper);
    return std::tie(expanded_x, expanded_row_idx, expert_tokens_count_or_cumsum, expanded_scale);
}
} // namespace op_api