#include <cstring>
#include <string>
#include "torch_npu/csrc/framework/utils/RandomOpAdapter.h"
#include "torch_npu/csrc/aten/CustomFunctions.h"
#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
namespace op_api {
using namespace at_npu::native;
const int DIMENSION_4D = 4;
using npu_preparation = at_npu::native::OpPreparation;
static void check_params(const at::Tensor &query,
const at::Tensor &key,
const at::Tensor &value)
{
TORCH_CHECK(query.scalar_type() == key.scalar_type() && key.scalar_type() == value.scalar_type(),
"query, key, value must have the same dtype, got query=", query.scalar_type(),
", key=", key.scalar_type(), ", value=", value.scalar_type(), OPS_ERROR(ErrCode::PARAM));
}
std::tuple<at::Tensor, at::Tensor> npu_block_sparse_attention(
const at::Tensor &query,
const at::Tensor &key,
const at::Tensor &value,
const at::Tensor &block_sparse_mask,
const c10::IntArrayRef block_shape,
c10::string_view q_input_layout,
c10::string_view kv_input_layout,
int64_t num_key_value_heads,
double scale_value,
int64_t inner_precise,
const c10::OptionalIntArrayRef actual_seq_lengths,
const c10::OptionalIntArrayRef actual_seq_lengths_kv,
c10::optional<int64_t> softmax_lse_flag)
{
check_params(query, key, value);
at::Tensor attention_out = npu_preparation::apply_tensor_without_format(query.sizes(), query.options());
at::Tensor softmax_lse_out;
if (query.dim() == DIMENSION_4D) {
softmax_lse_out = npu_preparation::apply_tensor_without_format(
{query.size(0), query.size(1), query.size(2), 1}, query.options().dtype(at::kFloat));
} else {
softmax_lse_out = npu_preparation::apply_tensor_without_format(
{query.size(0), query.size(1), 1}, query.options().dtype(at::kFloat));
}
const at::IntArrayRef block_shape_value = block_shape;
const int64_t softmax_lse_flag_value = softmax_lse_flag.value_or(0);
const at::Tensor atten_mask{nullptr};
const at::Tensor block_table{nullptr};
const int64_t mask_type = 0;
const int64_t block_size = 0;
const int64_t pre_tokens = 2147483647;
const int64_t next_tokens = 2147483647;
char *q_input_layout_ptr = const_cast<char *>(q_input_layout.data());
char *kv_input_layout_ptr = const_cast<char *>(kv_input_layout.data());
EXEC_NPU_NO_FORMAT_CHECK_CMD(
aclnnBlockSparseAttention, query, key, value, block_sparse_mask, atten_mask,
block_shape_value, actual_seq_lengths, actual_seq_lengths_kv, block_table,
q_input_layout_ptr, kv_input_layout_ptr, num_key_value_heads, mask_type, scale_value,
inner_precise, block_size, pre_tokens, next_tokens, softmax_lse_flag_value,
attention_out, softmax_lse_out);
return std::make_tuple(attention_out, softmax_lse_out);
}
}