// Copyright (c) 2025 Huawei Technologies Co., Ltd
// All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <initializer_list>
#include <tuple>
#include <utility>
#include <vector>

#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
#include "torch_npu/csrc/core/npu/GetCANNInfo.h"

namespace op_api {
using npu_preparation = at_npu::native::OpPreparation;

const int DIM_2 = 2;
const int DIM_3 = 3;
const int MODE_1 = 1;
const int MODE_2 = 2;
const int MODE_3 = 3;
const int MODE_4 = 4;
const int MODE_5 = 5;
const int FP8_E4M3_BLOCK_SIZE = 32;
const char* const REQUIRED_CANN_VERSION = "8.5.0.alpha003";
const char* const CANN_PRODUCT = "CANN";
bool is_hifloat8_dtype = false;
bool is_cann_version_gte_required = IsGteCANNVersion(REQUIRED_CANN_VERSION, CANN_PRODUCT); // whether cann version >= 8.5.0.alpha003

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_mla_prolog_v3(
    const at::Tensor& token_x, const at::Tensor& weight_dq, const at::Tensor& weight_uq_qr,
    const at::Tensor& weight_uk, const at::Tensor& weight_dkv_kr, const at::Tensor& rmsnorm_gamma_cq,
    const at::Tensor& rmsnorm_gamma_ckv, const at::Tensor& rope_sin, const at::Tensor& rope_cos,
    at::Tensor& kv_cache, at::Tensor& kr_cache, const c10::optional<at::Tensor>& cache_index,
    const c10::optional<at::Tensor>& dequant_scale_x, const c10::optional<at::Tensor>& dequant_scale_w_dq,
    const c10::optional<at::Tensor>& dequant_scale_w_uq_qr, const c10::optional<at::Tensor>& dequant_scale_w_dkv_kr, const c10::optional<at::Tensor>& quant_scale_ckv,
    const c10::optional<at::Tensor>& quant_scale_ckr, const c10::optional<at::Tensor>& smooth_scales_cq,
    const c10::optional<at::Tensor>& actual_seq_len, const c10::optional<at::Tensor>& k_nope_clip_alpha,
    double rmsnorm_epsilon_cq, double rmsnorm_epsilon_ckv, c10::string_view cache_mode, bool query_norm_flag, int64_t weight_quant_mode, int64_t kv_cache_quant_mode,
    int64_t query_quant_mode, int64_t ckvkr_repo_mode, int64_t quant_scale_repo_mode, int64_t tile_size, double qc_qr_scale, double kc_scale,
    c10::optional<int64_t> token_x_dtype, c10::optional<int64_t> weight_dq_dtype, c10::optional<int64_t> weight_uq_qr_dtype,
    c10::optional<int64_t> weight_dkv_kr_dtype, c10::optional<int64_t> kv_cache_dtype)
{
    // construct the output tensor
    if (weight_quant_mode == MODE_3) {
        TORCH_CHECK(c10_npu::IsAclnnOnly(), "When weight_quant_mode is 3, not support on this soc version.", OPS_ERROR(ErrCode::NOT_SUPPORT));

        auto dequant_scale_x_dtype = dequant_scale_x.value().dtype();
        auto dequant_scale_w_dq_dtype = dequant_scale_w_dq.value().dtype();
        auto dequant_scale_w_uq_qr_dtype = dequant_scale_w_uq_qr.value().dtype();
        auto dequant_scale_w_dkv_kr_dtype = dequant_scale_w_dkv_kr.value().dtype();

#if VERSION_BETWEEN(V2R7, VERSION_NEWEST)
        TORCH_CHECK(dequant_scale_x_dtype == at::kFloat8_e8m0fnu && dequant_scale_w_dq_dtype == at::kFloat8_e8m0fnu &&
            dequant_scale_w_uq_qr_dtype == at::kFloat8_e8m0fnu && dequant_scale_w_dkv_kr_dtype == at::kFloat8_e8m0fnu,
            "torch_npu supports the float8_e8m0 only in version later than v2.7., dequant_scale_x, dequant_scale_w_dq, dequant_scale_w_uq_qr, dequant_scale_w_dkv_kr Dtype in weight_quant_mode=3 must be torch.float8_e8m0fnu",
            OPS_ERROR(ErrCode::PARAM));
#endif
#if VERSION_BETWEEN(V2R1, V2R6)
        TORCH_CHECK(false, "torch_npu supports the float8_e8m0 only in version later than v2.7.", OPS_ERROR(ErrCode::PARAM));
#endif
    }

    if (weight_quant_mode == MODE_5 && (token_x.dtype() == at::kByte || weight_dq.dtype() == at::kByte ||
        weight_uq_qr.dtype() == at::kByte || weight_dkv_kr.dtype() == at::kByte)) {
        TORCH_CHECK(token_x_dtype.has_value() && weight_dq_dtype.has_value() && weight_uq_qr_dtype.has_value() && weight_dkv_kr_dtype.has_value(),
            "when weight_quant_mode is 5 and input dtype is hifloat8, token_x_dtype,weight_dq_dtype,weight_uq_qr_dtype,weight_dkv_kr_dtype cannot be null.", OPS_ERROR(ErrCode::PARAM));

        TORCH_CHECK(c10_npu::GetAclDataType(token_x_dtype.value()) == aclDataType::ACL_HIFLOAT8 && c10_npu::GetAclDataType(weight_dq_dtype.value()) == aclDataType::ACL_HIFLOAT8 &&
            c10_npu::GetAclDataType(weight_uq_qr_dtype.value()) == aclDataType::ACL_HIFLOAT8 && c10_npu::GetAclDataType(weight_dkv_kr_dtype.value()) == aclDataType::ACL_HIFLOAT8,
            "when weight_quant_mode is 5 and input dtype is hifloat8, token_x_dtype, weight_dq_dtype, weight_uq_qr_dtype, weight_dkv_kr_dtype value must be torch_npu.hifloat8", OPS_ERROR(ErrCode::PARAM));

        if (kv_cache_quant_mode == 1 || kv_cache_quant_mode == 3) {
            TORCH_CHECK(kv_cache_dtype.has_value(),
                "when weight_quant_mode is 5 and kv_cache_quant_mode is 1 or 3 and input dtype is hifloat8, kv_cache_dtype cannot be null.", OPS_ERROR(ErrCode::PARAM));

            TORCH_CHECK(c10_npu::GetAclDataType(kv_cache_dtype.value()) == aclDataType::ACL_HIFLOAT8,
                "when weight_quant_mode is 5 and input dtype is hifloat8, kv_cache_dtype value must be torch_npu.hifloat8", OPS_ERROR(ErrCode::PARAM));
        }
        is_hifloat8_dtype = true;
    }

    auto token_x_dim = token_x.dim();
    TORCH_CHECK(token_x_dim == DIM_2 || token_x_dim == DIM_3, "token_x dim num should be 2 or 3, but the actual value is ", token_x_dim, OPS_ERROR(ErrCode::PARAM));

    auto weight_uk_dim = weight_uk.dim();
    TORCH_CHECK(weight_uk_dim == DIM_3, "weight_uk dim num should be 3, but the actual value is ", weight_uk_dim, OPS_ERROR(ErrCode::PARAM));

    auto rope_sin_dim = rope_sin.dim();

    at::Tensor query;
    at::Tensor query_rope;
    at::Tensor dequant_scale_q_nope {nullptr};
    at::Tensor query_norm {nullptr};
    at::Tensor dequant_scale_q_norm {nullptr};

    if (token_x_dim == DIM_3) {
        TORCH_CHECK(rope_sin_dim == DIM_3, "when token_x dim num is 3, rope_sin dim num should be 3, but the actual value is ", rope_sin_dim, OPS_ERROR(ErrCode::PARAM));
        if ((weight_quant_mode == MODE_2 || weight_quant_mode == MODE_3 || weight_quant_mode == MODE_4 || weight_quant_mode == MODE_5) && kv_cache_quant_mode == MODE_1) {
            // weight_quant_mode=2,4,5且kv_cache_quant_mode=1时为全量化kv量化场景(int8,fp8,hif8)
            // weight_quant_mode=3且kv_cache_quant_mode=1时为mxfp8全量化kv量化场景
            if (is_hifloat8_dtype) {
                query = npu_preparation::apply_tensor_without_format({token_x.size(0), token_x.size(1), weight_uk.size(0), weight_uk.size(2)}, token_x.options().dtype(at::kByte));
            } else {
                query = npu_preparation::apply_tensor_without_format({token_x.size(0), token_x.size(1), weight_uk.size(0), weight_uk.size(2)}, token_x.options().dtype(token_x.dtype()));
            }
            dequant_scale_q_nope = npu_preparation::apply_tensor_without_format({token_x.size(0) * token_x.size(1), weight_uk.size(0), 1}, at::kFloat);
        } else {
            query = npu_preparation::apply_tensor_without_format({token_x.size(0), token_x.size(1), weight_uk.size(0), weight_uk.size(2)}, token_x.options().dtype(at::kBFloat16));
        }
        query_rope = npu_preparation::apply_tensor_without_format({token_x.size(0), token_x.size(1), weight_uk.size(0), rope_sin.size(2)}, at::kBFloat16);
        if (query_norm_flag) {
            if (is_hifloat8_dtype) {
                query_norm = npu_preparation::apply_tensor_without_format({token_x.size(0), token_x.size(1), weight_dq.size(1)}, token_x.options().dtype(at::kByte));
            } else {
                query_norm = npu_preparation::apply_tensor_without_format({token_x.size(0), token_x.size(1), weight_dq.size(1)}, token_x.options().dtype(weight_uq_qr.dtype()));
            }
            if (weight_quant_mode == MODE_1 || weight_quant_mode == MODE_2 || weight_quant_mode == MODE_4 || weight_quant_mode == MODE_5) {
                // weight_quant_mode=1 半量化场景, weight_quant_mode=2,4,5 全量化场景(int8,fp8,hif8)
                dequant_scale_q_norm = npu_preparation::apply_tensor_without_format({token_x.size(0) * token_x.size(1), 1}, at::kFloat);
            } else if (weight_quant_mode == MODE_3) {
                dequant_scale_q_norm = npu_preparation::apply_tensor_without_format({token_x.size(0) * token_x.size(1), weight_dq.size(1) / FP8_E4M3_BLOCK_SIZE}, dequant_scale_x.value().options().dtype(dequant_scale_x.value().dtype()));
            }
        }
    } else {
        TORCH_CHECK(rope_sin_dim == DIM_2, "when token_x dim num is 2, rope_sin dim num should be 2, but the actual value is ", rope_sin_dim, OPS_ERROR(ErrCode::PARAM));
        if ((weight_quant_mode == MODE_2 || weight_quant_mode == MODE_3 || weight_quant_mode == MODE_4 || weight_quant_mode == MODE_5) && kv_cache_quant_mode == MODE_1) {
            if (is_hifloat8_dtype) {
                query = npu_preparation::apply_tensor_without_format({token_x.size(0), weight_uk.size(0), weight_uk.size(2)}, token_x.options().dtype(at::kByte));
            } else {
                query = npu_preparation::apply_tensor_without_format({token_x.size(0), weight_uk.size(0), weight_uk.size(2)}, token_x.options().dtype(token_x.dtype()));
            }
            dequant_scale_q_nope = npu_preparation::apply_tensor_without_format({token_x.size(0), weight_uk.size(0), 1}, at::kFloat);
        } else {
            query = npu_preparation::apply_tensor_without_format({token_x.size(0), weight_uk.size(0), weight_uk.size(2)}, token_x.options().dtype(at::kBFloat16));
        }
        query_rope = npu_preparation::apply_tensor_without_format({token_x.size(0), weight_uk.size(0), rope_sin.size(1)}, at::kBFloat16);
        if (query_norm_flag) {
            if (is_hifloat8_dtype) {
                query_norm = npu_preparation::apply_tensor_without_format({token_x.size(0), weight_dq.size(1)}, token_x.options().dtype(at::kByte));
            } else {
                query_norm = npu_preparation::apply_tensor_without_format({token_x.size(0), weight_dq.size(1)}, token_x.options().dtype(weight_uq_qr.dtype()));
            }
            if (weight_quant_mode == MODE_1 || weight_quant_mode == MODE_2 || weight_quant_mode == MODE_4 || weight_quant_mode == MODE_5) {
                dequant_scale_q_norm = npu_preparation::apply_tensor_without_format({token_x.size(0), 1}, at::kFloat);
            } else if (weight_quant_mode == MODE_3) {
                dequant_scale_q_norm = npu_preparation::apply_tensor_without_format({token_x.size(0), weight_dq.size(1) / FP8_E4M3_BLOCK_SIZE}, dequant_scale_x.value().options().dtype(dequant_scale_x.value().dtype()));
            }
        }
    }

    char *cache_mode_ptr = const_cast<char *>(cache_mode.data());

    if (is_cann_version_gte_required) {
        if (is_hifloat8_dtype) {
            TensorWrapper token_x_wrapper = make_wrapper(token_x, token_x_dtype);
            TensorWrapper weight_dq_wrapper = make_wrapper(weight_dq, weight_dq_dtype);
            TensorWrapper weight_uq_qr_wrapper = make_wrapper(weight_uq_qr, weight_uq_qr_dtype);
            TensorWrapper weight_dkv_kr_wrapper = make_wrapper(weight_dkv_kr, weight_dkv_kr_dtype);
            if (kv_cache_quant_mode == MODE_1) {
                TensorWrapper kv_cache_wrapper = make_wrapper(kv_cache, kv_cache_dtype);
                TensorWrapper query_wrapper = make_wrapper(query, token_x_dtype);
                TensorWrapper query_norm_wrapper = make_wrapper(query_norm, token_x_dtype);
                EXEC_NPU_CMD(aclnnMlaPrologV3WeightNz, token_x_wrapper, weight_dq_wrapper, weight_uq_qr_wrapper, weight_uk, weight_dkv_kr_wrapper, rmsnorm_gamma_cq,
                    rmsnorm_gamma_ckv, rope_sin, rope_cos, kv_cache_wrapper, kr_cache, cache_index, dequant_scale_x, dequant_scale_w_dq,
                    dequant_scale_w_uq_qr, dequant_scale_w_dkv_kr, quant_scale_ckv, quant_scale_ckr, smooth_scales_cq, actual_seq_len,
                    k_nope_clip_alpha, rmsnorm_epsilon_cq, rmsnorm_epsilon_ckv, cache_mode_ptr, weight_quant_mode, kv_cache_quant_mode, query_quant_mode,
                    ckvkr_repo_mode, quant_scale_repo_mode, tile_size, qc_qr_scale, kc_scale, query_wrapper, query_rope,
                    dequant_scale_q_nope, query_norm_wrapper, dequant_scale_q_norm);
            } else if (kv_cache_quant_mode == MODE_3) {
                TensorWrapper kv_cache_wrapper = make_wrapper(kv_cache, kv_cache_dtype);
                TensorWrapper query_norm_wrapper = make_wrapper(query_norm, token_x_dtype);
                EXEC_NPU_CMD(aclnnMlaPrologV3WeightNz, token_x_wrapper, weight_dq_wrapper, weight_uq_qr_wrapper, weight_uk, weight_dkv_kr_wrapper, rmsnorm_gamma_cq,
                    rmsnorm_gamma_ckv, rope_sin, rope_cos, kv_cache_wrapper, kr_cache, cache_index, dequant_scale_x, dequant_scale_w_dq,
                    dequant_scale_w_uq_qr, dequant_scale_w_dkv_kr, quant_scale_ckv, quant_scale_ckr, smooth_scales_cq, actual_seq_len,
                    k_nope_clip_alpha, rmsnorm_epsilon_cq, rmsnorm_epsilon_ckv, cache_mode_ptr, weight_quant_mode, kv_cache_quant_mode, query_quant_mode,
                    ckvkr_repo_mode, quant_scale_repo_mode, tile_size, qc_qr_scale, kc_scale, query, query_rope,
                    dequant_scale_q_nope, query_norm_wrapper, dequant_scale_q_norm);
            } else {
                TensorWrapper query_norm_wrapper = make_wrapper(query_norm, token_x_dtype);
                EXEC_NPU_CMD(aclnnMlaPrologV3WeightNz, token_x_wrapper, weight_dq_wrapper, weight_uq_qr_wrapper, weight_uk, weight_dkv_kr_wrapper, rmsnorm_gamma_cq,
                    rmsnorm_gamma_ckv, rope_sin, rope_cos, kv_cache, kr_cache, cache_index, dequant_scale_x, dequant_scale_w_dq,
                    dequant_scale_w_uq_qr, dequant_scale_w_dkv_kr, quant_scale_ckv, quant_scale_ckr, smooth_scales_cq, actual_seq_len,
                    k_nope_clip_alpha, rmsnorm_epsilon_cq, rmsnorm_epsilon_ckv, cache_mode_ptr, weight_quant_mode, kv_cache_quant_mode, query_quant_mode,
                    ckvkr_repo_mode, quant_scale_repo_mode, tile_size, qc_qr_scale, kc_scale, query, query_rope,
                    dequant_scale_q_nope, query_norm_wrapper, dequant_scale_q_norm);
            }
        } else {
            EXEC_NPU_CMD(aclnnMlaPrologV3WeightNz, token_x, weight_dq, weight_uq_qr, weight_uk, weight_dkv_kr, rmsnorm_gamma_cq,
                rmsnorm_gamma_ckv, rope_sin, rope_cos, kv_cache, kr_cache, cache_index, dequant_scale_x, dequant_scale_w_dq,
                dequant_scale_w_uq_qr, dequant_scale_w_dkv_kr, quant_scale_ckv, quant_scale_ckr, smooth_scales_cq, actual_seq_len,
                k_nope_clip_alpha, rmsnorm_epsilon_cq, rmsnorm_epsilon_ckv, cache_mode_ptr, weight_quant_mode, kv_cache_quant_mode, query_quant_mode,
                ckvkr_repo_mode, quant_scale_repo_mode, tile_size, qc_qr_scale, kc_scale, query, query_rope,
                dequant_scale_q_nope, query_norm, dequant_scale_q_norm);
        }
    } else {
        EXEC_NPU_CMD(aclnnMlaPrologV3WeightNz, token_x, weight_dq, weight_uq_qr, weight_uk, weight_dkv_kr, rmsnorm_gamma_cq,
            rmsnorm_gamma_ckv, rope_sin, rope_cos, kv_cache, kr_cache, cache_index, dequant_scale_x, dequant_scale_w_dq,
            dequant_scale_w_uq_qr, dequant_scale_w_dkv_kr, quant_scale_ckv, quant_scale_ckr, smooth_scales_cq, actual_seq_len,
            k_nope_clip_alpha, rmsnorm_epsilon_cq, rmsnorm_epsilon_ckv, cache_mode_ptr, query_norm_flag, weight_quant_mode, kv_cache_quant_mode, query_quant_mode,
            ckvkr_repo_mode, quant_scale_repo_mode, tile_size, qc_qr_scale, kc_scale, query, query_rope,
            dequant_scale_q_nope, query_norm, dequant_scale_q_norm);
    }

    if (!query_norm.defined()) {
        query_norm = npu_preparation::apply_tensor_without_format({0}, token_x.options().dtype(weight_uq_qr.dtype()));
    }
    if (!dequant_scale_q_nope.defined()) {
        dequant_scale_q_nope = npu_preparation::apply_tensor_without_format({0}, at::kFloat);
    }
    if (!dequant_scale_q_norm.defined()) {
        if (weight_quant_mode == MODE_3) {
            dequant_scale_q_norm = npu_preparation::apply_tensor_without_format({0}, dequant_scale_x.value().options().dtype(dequant_scale_x.value().dtype()));
        } else {
            dequant_scale_q_norm = npu_preparation::apply_tensor_without_format({0}, at::kFloat);
        }
    }

    return std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>(query, query_rope, dequant_scale_q_nope, query_norm, dequant_scale_q_norm);
}

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_mla_prolog_v3_functional(
    const at::Tensor& token_x, const at::Tensor& weight_dq, const at::Tensor& weight_uq_qr,
    const at::Tensor& weight_uk, const at::Tensor& weight_dkv_kr, const at::Tensor& rmsnorm_gamma_cq,
    const at::Tensor& rmsnorm_gamma_ckv, const at::Tensor& rope_sin, const at::Tensor& rope_cos,
    const at::Tensor& kv_cache, const at::Tensor& kr_cache, const c10::optional<at::Tensor>& cache_index,
    const c10::optional<at::Tensor>& dequant_scale_x, const c10::optional<at::Tensor>& dequant_scale_w_dq,
    const c10::optional<at::Tensor>& dequant_scale_w_uq_qr, const c10::optional<at::Tensor>& dequant_scale_w_dkv_kr, const c10::optional<at::Tensor>& quant_scale_ckv,
    const c10::optional<at::Tensor>& quant_scale_ckr, const c10::optional<at::Tensor>& smooth_scales_cq,
    const c10::optional<at::Tensor>& actual_seq_len, const c10::optional<at::Tensor>& k_nope_clip_alpha,
    double rmsnorm_epsilon_cq, double rmsnorm_epsilon_ckv, c10::string_view cache_mode, bool query_norm_flag, int64_t weight_quant_mode, int64_t kv_cache_quant_mode,
    int64_t query_quant_mode, int64_t ckvkr_repo_mode, int64_t quant_scale_repo_mode, int64_t tile_size, double qc_qr_scale, double kc_scale,
    c10::optional<int64_t> token_x_dtype, c10::optional<int64_t> weight_dq_dtype, c10::optional<int64_t> weight_uq_qr_dtype,
    c10::optional<int64_t> weight_dkv_kr_dtype, c10::optional<int64_t> kv_cache_dtype)
{
    // construct the output tensor
    if (weight_quant_mode == MODE_3) {
        TORCH_CHECK(c10_npu::IsAclnnOnly(), "When weight_quant_mode is 3, not support on this soc version.", OPS_ERROR(ErrCode::NOT_SUPPORT));

        auto dequant_scale_x_dtype = dequant_scale_x.value().dtype();
        auto dequant_scale_w_dq_dtype = dequant_scale_w_dq.value().dtype();
        auto dequant_scale_w_uq_qr_dtype = dequant_scale_w_uq_qr.value().dtype();
        auto dequant_scale_w_dkv_kr_dtype = dequant_scale_w_dkv_kr.value().dtype();

#if VERSION_BETWEEN(V2R7, VERSION_NEWEST)
        TORCH_CHECK(dequant_scale_x_dtype == at::kFloat8_e8m0fnu && dequant_scale_w_dq_dtype == at::kFloat8_e8m0fnu &&
            dequant_scale_w_uq_qr_dtype == at::kFloat8_e8m0fnu && dequant_scale_w_dkv_kr_dtype == at::kFloat8_e8m0fnu,
            "torch_npu supports the float8_e8m0 only in version later than v2.7., dequant_scale_x, dequant_scale_w_dq, dequant_scale_w_uq_qr, dequant_scale_w_dkv_kr Dtype in weight_quant_mode=3 must be torch.float8_e8m0fnu",
            OPS_ERROR(ErrCode::PARAM));
#endif
#if VERSION_BETWEEN(V2R1, V2R6)
        TORCH_CHECK(false, "torch_npu supports the float8_e8m0 only in version later than v2.7.", OPS_ERROR(ErrCode::PARAM));
#endif
    }

    if (weight_quant_mode == MODE_5 && (token_x.dtype() == at::kByte || weight_dq.dtype() == at::kByte ||
        weight_uq_qr.dtype() == at::kByte || weight_dkv_kr.dtype() == at::kByte)) {
        TORCH_CHECK(token_x_dtype.has_value() && weight_dq_dtype.has_value() && weight_uq_qr_dtype.has_value() && weight_dkv_kr_dtype.has_value(),
            "when weight_quant_mode is 5 and input dtype is hifloat8, token_x_dtype,weight_dq_dtype,weight_uq_qr_dtype,weight_dkv_kr_dtype cannot be null.", OPS_ERROR(ErrCode::PARAM));

        TORCH_CHECK(c10_npu::GetAclDataType(token_x_dtype.value()) == aclDataType::ACL_HIFLOAT8 && c10_npu::GetAclDataType(weight_dq_dtype.value()) == aclDataType::ACL_HIFLOAT8 &&
            c10_npu::GetAclDataType(weight_uq_qr_dtype.value()) == aclDataType::ACL_HIFLOAT8 && c10_npu::GetAclDataType(weight_dkv_kr_dtype.value()) == aclDataType::ACL_HIFLOAT8,
            "when weight_quant_mode is 5 and input dtype is hifloat8, token_x_dtype, weight_dq_dtype, weight_uq_qr_dtype, weight_dkv_kr_dtype value must be torch_npu.hifloat8", OPS_ERROR(ErrCode::PARAM));

        if (kv_cache_quant_mode == 1 or kv_cache_quant_mode == 3) {
            TORCH_CHECK(kv_cache_dtype.has_value(),
                "when weight_quant_mode is 5 and kv_cache_quant_mode is 1 or 3 and input dtype is hifloat8, kv_cache_dtype cannot be null.", OPS_ERROR(ErrCode::PARAM));

            TORCH_CHECK(c10_npu::GetAclDataType(kv_cache_dtype.value()) == aclDataType::ACL_HIFLOAT8,
                "when weight_quant_mode is 5 and input dtype is hifloat8, kv_cache_dtype value must be torch_npu.hifloat8", OPS_ERROR(ErrCode::PARAM));
        }
        is_hifloat8_dtype = true;
    }

    auto token_x_dim = token_x.dim();
    TORCH_CHECK(token_x_dim == DIM_2 || token_x_dim == DIM_3, "token_x dim num should be 2 or 3, but the actual value is ", token_x_dim, OPS_ERROR(ErrCode::PARAM));

    auto weight_uk_dim = weight_uk.dim();
    TORCH_CHECK(weight_uk_dim == DIM_3, "weight_uk dim num should be 3, but the actual value is ", weight_uk_dim, OPS_ERROR(ErrCode::PARAM));

    auto rope_sin_dim = rope_sin.dim();
    
    at::Tensor query;
    at::Tensor query_rope;
    at::Tensor dequant_scale_q_nope {nullptr};
    at::Tensor query_norm {nullptr};
    at::Tensor dequant_scale_q_norm {nullptr};

    if (token_x_dim == DIM_3) {
        TORCH_CHECK(rope_sin_dim == DIM_3, "when token_x dim num is 3, rope_sin dim num should be 3, but the actual value is ", rope_sin_dim, OPS_ERROR(ErrCode::PARAM));
        if ((weight_quant_mode == MODE_2 || weight_quant_mode == MODE_3 || weight_quant_mode == MODE_4 || weight_quant_mode == MODE_5) && kv_cache_quant_mode == MODE_1) {
            // weight_quant_mode=2,4,5且kv_cache_quant_mode=1时为全量化kv量化场景(int8,fp8,hif8)
            // weight_quant_mode=3且kv_cache_quant_mode=1时为mxfp8全量化kv量化场景
            if (is_hifloat8_dtype) {
                query = npu_preparation::apply_tensor_without_format({token_x.size(0), token_x.size(1), weight_uk.size(0), weight_uk.size(2)}, token_x.options().dtype(at::kByte));
            } else {
                query = npu_preparation::apply_tensor_without_format({token_x.size(0), token_x.size(1), weight_uk.size(0), weight_uk.size(2)}, token_x.options().dtype(token_x.dtype()));
            }
            dequant_scale_q_nope = npu_preparation::apply_tensor_without_format({token_x.size(0) * token_x.size(1), weight_uk.size(0), 1}, at::kFloat);
        } else {
            query = npu_preparation::apply_tensor_without_format({token_x.size(0), token_x.size(1), weight_uk.size(0), weight_uk.size(2)}, token_x.options().dtype(at::kBFloat16));
        }
        query_rope = npu_preparation::apply_tensor_without_format({token_x.size(0), token_x.size(1), weight_uk.size(0), rope_sin.size(2)}, at::kBFloat16);
        if (query_norm_flag) {
            if (is_hifloat8_dtype) {
                query_norm = npu_preparation::apply_tensor_without_format({token_x.size(0), token_x.size(1), weight_dq.size(1)}, token_x.options().dtype(at::kByte));
            } else {
                query_norm = npu_preparation::apply_tensor_without_format({token_x.size(0), token_x.size(1), weight_dq.size(1)}, token_x.options().dtype(weight_uq_qr.dtype()));
            }
            if (weight_quant_mode == MODE_1 || weight_quant_mode == MODE_2 || weight_quant_mode == MODE_4 || weight_quant_mode == MODE_5) {
                // weight_quant_mode=1 半量化场景,weight_quant_mode=2,4,5 全量化场景(int8,fp8,hif8)
                dequant_scale_q_norm = npu_preparation::apply_tensor_without_format({token_x.size(0) * token_x.size(1), 1}, at::kFloat);
            } else if (weight_quant_mode == MODE_3) {
                dequant_scale_q_norm = npu_preparation::apply_tensor_without_format({token_x.size(0) * token_x.size(1), weight_dq.size(1) / FP8_E4M3_BLOCK_SIZE}, dequant_scale_x.value().options().dtype(dequant_scale_x.value().dtype()));
            }
        }
    } else {
        TORCH_CHECK(rope_sin_dim == DIM_2, "when token_x dim num is 2, rope_sin dim num should be 2, but the actual value is ", rope_sin_dim, OPS_ERROR(ErrCode::PARAM));
        if ((weight_quant_mode == MODE_2 || weight_quant_mode == MODE_3 || weight_quant_mode == MODE_4 || weight_quant_mode == MODE_5) && kv_cache_quant_mode == MODE_1) {
            if (is_hifloat8_dtype) {
                query = npu_preparation::apply_tensor_without_format({token_x.size(0), weight_uk.size(0), weight_uk.size(2)}, token_x.options().dtype(at::kByte));
            } else {
                query = npu_preparation::apply_tensor_without_format({token_x.size(0), weight_uk.size(0), weight_uk.size(2)}, token_x.options().dtype(token_x.dtype()));
            }
            dequant_scale_q_nope = npu_preparation::apply_tensor_without_format({token_x.size(0), weight_uk.size(0), 1}, at::kFloat);
        } else {
            query = npu_preparation::apply_tensor_without_format({token_x.size(0), weight_uk.size(0), weight_uk.size(2)}, token_x.options().dtype(at::kBFloat16));
        }
        query_rope = npu_preparation::apply_tensor_without_format({token_x.size(0), weight_uk.size(0), rope_sin.size(1)}, at::kBFloat16);
        if (query_norm_flag) {
            if (is_hifloat8_dtype) {
                query_norm = npu_preparation::apply_tensor_without_format({token_x.size(0), weight_dq.size(1)}, token_x.options().dtype(at::kByte));
            } else {
                query_norm = npu_preparation::apply_tensor_without_format({token_x.size(0), weight_dq.size(1)}, token_x.options().dtype(weight_uq_qr.dtype()));
            }
            if (weight_quant_mode == MODE_1 || weight_quant_mode == MODE_2 || weight_quant_mode == MODE_4 || weight_quant_mode == MODE_5) {
                dequant_scale_q_norm = npu_preparation::apply_tensor_without_format({token_x.size(0), 1}, at::kFloat);
            } else if (weight_quant_mode == MODE_3) {
                dequant_scale_q_norm = npu_preparation::apply_tensor_without_format({token_x.size(0), weight_dq.size(1) / FP8_E4M3_BLOCK_SIZE}, dequant_scale_x.value().options().dtype(dequant_scale_x.value().dtype()));
            }
        }
    }

    char *cache_mode_ptr = const_cast<char *>(cache_mode.data());
    at::Tensor kv_cache_inplace = kv_cache.clone();
    at::Tensor kr_cache_inplace = kr_cache.clone();

    if (is_cann_version_gte_required) {
        if (is_hifloat8_dtype) {
            TensorWrapper token_x_wrapper = make_wrapper(token_x, token_x_dtype);
            TensorWrapper weight_dq_wrapper = make_wrapper(weight_dq, weight_dq_dtype);
            TensorWrapper weight_uq_qr_wrapper = make_wrapper(weight_uq_qr, weight_uq_qr_dtype);
            TensorWrapper weight_dkv_kr_wrapper = make_wrapper(weight_dkv_kr, weight_dkv_kr_dtype);
            if (kv_cache_quant_mode == MODE_1) {
                TensorWrapper kv_cache_wrapper = make_wrapper(kv_cache_inplace, kv_cache_dtype);
                TensorWrapper query_wrapper = make_wrapper(query, token_x_dtype);
                TensorWrapper query_norm_wrapper = make_wrapper(query_norm, token_x_dtype);
                EXEC_NPU_CMD(aclnnMlaPrologV3WeightNz, token_x_wrapper, weight_dq_wrapper, weight_uq_qr_wrapper, weight_uk, weight_dkv_kr_wrapper, rmsnorm_gamma_cq,
                    rmsnorm_gamma_ckv, rope_sin, rope_cos, kv_cache_wrapper, kr_cache_inplace, cache_index, dequant_scale_x, dequant_scale_w_dq,
                    dequant_scale_w_uq_qr, dequant_scale_w_dkv_kr, quant_scale_ckv, quant_scale_ckr, smooth_scales_cq, actual_seq_len,
                    k_nope_clip_alpha, rmsnorm_epsilon_cq, rmsnorm_epsilon_ckv, cache_mode_ptr, weight_quant_mode, kv_cache_quant_mode, query_quant_mode,
                    ckvkr_repo_mode, quant_scale_repo_mode, tile_size, qc_qr_scale, kc_scale, query_wrapper, query_rope,
                    dequant_scale_q_nope, query_norm_wrapper, dequant_scale_q_norm);
            } else if (kv_cache_quant_mode == MODE_3) {
                TensorWrapper kv_cache_wrapper = make_wrapper(kv_cache_inplace, kv_cache_dtype);
                TensorWrapper query_norm_wrapper = make_wrapper(query_norm, token_x_dtype);
                EXEC_NPU_CMD(aclnnMlaPrologV3WeightNz, token_x_wrapper, weight_dq_wrapper, weight_uq_qr_wrapper, weight_uk, weight_dkv_kr_wrapper, rmsnorm_gamma_cq,
                    rmsnorm_gamma_ckv, rope_sin, rope_cos, kv_cache_wrapper, kr_cache_inplace, cache_index, dequant_scale_x, dequant_scale_w_dq,
                    dequant_scale_w_uq_qr, dequant_scale_w_dkv_kr, quant_scale_ckv, quant_scale_ckr, smooth_scales_cq, actual_seq_len,
                    k_nope_clip_alpha, rmsnorm_epsilon_cq, rmsnorm_epsilon_ckv, cache_mode_ptr, weight_quant_mode, kv_cache_quant_mode, query_quant_mode,
                    ckvkr_repo_mode, quant_scale_repo_mode, tile_size, qc_qr_scale, kc_scale, query, query_rope,
                    dequant_scale_q_nope, query_norm_wrapper, dequant_scale_q_norm);
            } else {
                TensorWrapper query_norm_wrapper = make_wrapper(query_norm, token_x_dtype);
                EXEC_NPU_CMD(aclnnMlaPrologV3WeightNz, token_x_wrapper, weight_dq_wrapper, weight_uq_qr_wrapper, weight_uk, weight_dkv_kr_wrapper, rmsnorm_gamma_cq,
                    rmsnorm_gamma_ckv, rope_sin, rope_cos, kv_cache_inplace, kr_cache_inplace, cache_index, dequant_scale_x, dequant_scale_w_dq,
                    dequant_scale_w_uq_qr, dequant_scale_w_dkv_kr, quant_scale_ckv, quant_scale_ckr, smooth_scales_cq, actual_seq_len,
                    k_nope_clip_alpha, rmsnorm_epsilon_cq, rmsnorm_epsilon_ckv, cache_mode_ptr, weight_quant_mode, kv_cache_quant_mode, query_quant_mode,
                    ckvkr_repo_mode, quant_scale_repo_mode, tile_size, qc_qr_scale, kc_scale, query, query_rope,
                    dequant_scale_q_nope, query_norm_wrapper, dequant_scale_q_norm);
            }
        } else {
            EXEC_NPU_CMD(aclnnMlaPrologV3WeightNz, token_x, weight_dq, weight_uq_qr, weight_uk, weight_dkv_kr, rmsnorm_gamma_cq,
                rmsnorm_gamma_ckv, rope_sin, rope_cos, kv_cache_inplace, kr_cache_inplace, cache_index, dequant_scale_x, dequant_scale_w_dq,
                dequant_scale_w_uq_qr, dequant_scale_w_dkv_kr, quant_scale_ckv, quant_scale_ckr, smooth_scales_cq, actual_seq_len,
                k_nope_clip_alpha, rmsnorm_epsilon_cq, rmsnorm_epsilon_ckv, cache_mode_ptr, weight_quant_mode, kv_cache_quant_mode, query_quant_mode,
                ckvkr_repo_mode, quant_scale_repo_mode, tile_size, qc_qr_scale, kc_scale, query, query_rope,
                dequant_scale_q_nope, query_norm, dequant_scale_q_norm);
        }
    } else {
        EXEC_NPU_CMD(aclnnMlaPrologV3WeightNz, token_x, weight_dq, weight_uq_qr, weight_uk, weight_dkv_kr, rmsnorm_gamma_cq,
            rmsnorm_gamma_ckv, rope_sin, rope_cos, kv_cache_inplace, kr_cache_inplace, cache_index, dequant_scale_x, dequant_scale_w_dq,
            dequant_scale_w_uq_qr, dequant_scale_w_dkv_kr, quant_scale_ckv, quant_scale_ckr, smooth_scales_cq, actual_seq_len,
            k_nope_clip_alpha, rmsnorm_epsilon_cq, rmsnorm_epsilon_ckv, cache_mode_ptr, query_norm_flag, weight_quant_mode, kv_cache_quant_mode, query_quant_mode,
            ckvkr_repo_mode, quant_scale_repo_mode, tile_size, qc_qr_scale, kc_scale, query, query_rope,
            dequant_scale_q_nope, query_norm, dequant_scale_q_norm);
    }

    if (!query_norm.defined()) {
        query_norm = npu_preparation::apply_tensor_without_format({0}, token_x.options().dtype(weight_uq_qr.dtype()));
    }
    if (!dequant_scale_q_nope.defined()) {
        dequant_scale_q_nope = npu_preparation::apply_tensor_without_format({0}, at::kFloat);
    }
    if (!dequant_scale_q_norm.defined()) {
        if (weight_quant_mode == MODE_3) {
            dequant_scale_q_norm = npu_preparation::apply_tensor_without_format({0}, dequant_scale_x.value().options().dtype(dequant_scale_x.value().dtype()));
        } else {
            dequant_scale_q_norm = npu_preparation::apply_tensor_without_format({0}, at::kFloat);
        }
    }

    return std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>(query, query_rope, dequant_scale_q_nope, query_norm, dequant_scale_q_norm, kv_cache_inplace, kr_cache_inplace);
}

} // namespace op_api