/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
 * MindIE is licensed under Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *          http://license.coscl.org.cn/MulanPSL2
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 */

#include <torch/library.h>
#include <iostream>
#include <vector>
#include <cmath>
#include <algorithm>
#include "torch_npu/csrc/framework/utils/OpAdapter.h"
#include "torch_npu/csrc/core/npu/NPUFormat.h"
#include "pytorch_npu_helper.h"
#include "sparse_block_estimate.h"

using namespace at;
namespace {
constexpr int DIM_INDEX_0 = 0;
constexpr int DIM_INDEX_1 = 1;
constexpr int DIM_INDEX_2 = 2;
constexpr int DIM_INDEX_3 = 3;
constexpr std::string_view SPARSEBLOCKESTIMATE_NAME = "aclnnSparseBlockEstimate";
}
std::tuple<at::Tensor, at::Tensor> sparse_block_estimate_mindie_sd_impl_npu(
    const at::Tensor &query, const at::Tensor &key,
    c10::OptionalIntArrayRef actual_seq_lengths,
    c10::OptionalIntArrayRef actual_seq_lengths_kv,
    std::string input_layout, int64_t stride, int64_t sparse_size,
    int64_t num_heads, int64_t num_key_value_heads, double scale_value,
    double threshold, bool causal, bool keep_sink, bool keep_recent, double row_sparse)
{
    TORCH_CHECK(num_heads != 0, "num_heads must be nonzero.");
    TORCH_CHECK(sparse_size != 0, "sparse_size must be nonzero.");

    auto actSeqLen = actual_seq_lengths.value_or(at::IntArrayRef{});
    auto actSeqLenKv = actual_seq_lengths_kv.value_or(at::IntArrayRef{});
    const char* inputLayoutPtr = input_layout.c_str();

    int64_t b;
    int64_t nq;
    int64_t s;
    int64_t d;

    if (input_layout == "BNSD") {
        b = query.size(DIM_INDEX_0);
        nq = query.size(DIM_INDEX_1);
        s = query.size(DIM_INDEX_2);
        d = query.size(DIM_INDEX_3);
    } else if (input_layout == "BSND") {
        b = query.size(DIM_INDEX_0);
        nq = query.size(DIM_INDEX_2);
        s = query.size(DIM_INDEX_1);
        d = query.size(DIM_INDEX_3);
    } else if (input_layout == "BSH") {
        b = query.size(DIM_INDEX_0);
        s = query.size(DIM_INDEX_1);
        d = query.size(DIM_INDEX_2) / num_heads;
        nq = num_heads;
    } else {
        std::cerr << "Error: input_layout only support BNSD, BSND, BSH!!!" << std::endl;
    }
    int64_t seqlenSparse = (s + sparse_size - 1) / sparse_size;
    int64_t seqlenSparseAlign32 = (seqlenSparse + 31) / 32 * 32;

    at::Tensor sparse_mask =
        at_npu::native::empty_with_format({b, nq, seqlenSparse, seqlenSparseAlign32},
        query.options().dtype(c10::ScalarType::Char), at_npu::native::get_npu_format(query));

    at::Tensor sparse_count_table =
        at_npu::native::empty_with_format({b, nq, seqlenSparse}, query.options().dtype(c10::ScalarType::Int),
        at_npu::native::get_npu_format(query));

    EXEC_NPU_CMD<SPARSEBLOCKESTIMATE_NAME>(query, key, actSeqLen,
        actSeqLenKv, inputLayoutPtr, stride, sparse_size,
        num_heads, num_key_value_heads, scale_value,
        threshold, causal, keep_sink, keep_recent, row_sparse, sparse_mask, sparse_count_table);

    return std::tuple<at::Tensor, at::Tensor>(sparse_mask, sparse_count_table);
}