#include <cstring>
#include <string>
#include "torch_npu/csrc/framework/utils/RandomOpAdapter.h"
#include "torch_npu/csrc/aten/CustomFunctions.h"
#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
namespace op_api {
// Make the at_npu::native names visible to the definitions in this namespace.
using namespace at_npu::native;

// Short alias for the OpPreparation helper; used below through
// npu_preparation::apply_tensor_without_format.
using npu_preparation = at_npu::native::OpPreparation;

// Largest value accepted for the size of query's last dimension (enforced in check_params).
constexpr int64_t MAX_HEAD_DIM = 128;
// Validates the attention inputs before the kernel is launched.
//
// Fails through TORCH_CHECK (tagged with ErrCode::PARAM) when:
//   - query, key and value do not all report the same scalar type, or
//   - the size of query's last dimension is greater than MAX_HEAD_DIM.
static void check_params(const at::Tensor &query,
                         const at::Tensor &key,
                         const at::Tensor &value)
{
    // Read each dtype once; the same values feed both the comparison and the message.
    const auto query_dtype = query.scalar_type();
    const auto key_dtype = key.scalar_type();
    const auto value_dtype = value.scalar_type();
    const bool dtypes_match = (query_dtype == key_dtype) && (key_dtype == value_dtype);
    TORCH_CHECK(dtypes_match,
                "query, key, value must have the same dtype, got query=", query_dtype,
                ", key=", key_dtype, ", value=", value_dtype, OPS_ERROR(ErrCode::PARAM));

    // The head dimension is taken from the trailing axis of query only.
    const int64_t head_dim = query.size(-1);
    TORCH_CHECK(head_dim <= MAX_HEAD_DIM,
                "head_dim must be <= ", MAX_HEAD_DIM, ", but got ", head_dim, OPS_ERROR(ErrCode::PARAM));
}
/**
 * Backward entry point for block-sparse attention on NPU.
 *
 * Validates query/key/value with check_params, allocates the three gradient tensors
 * from query, key and value via npu_preparation::apply_tensor_without_format, and
 * dispatches aclnnBlockSparseAttentionGrad through EXEC_NPU_NO_FORMAT_CHECK_CMD.
 *
 * @param d_out                 Forwarded to the kernel as its first input
 *                              (presumably the gradient of the attention output - confirm).
 * @param query                 Query tensor; also the template for the d_query allocation.
 * @param key                   Key tensor; also the template for the d_key allocation.
 * @param value                 Value tensor; also the template for the d_value allocation.
 * @param attention_out         Forwarded unchanged to the kernel.
 * @param softmax_lse           Forwarded unchanged to the kernel.
 * @param block_sparse_mask     Forwarded unchanged to the kernel.
 * @param block_shape           Optional block shape. When absent, or when it holds fewer
 *                              than two elements, {128, 128} is used instead.
 * @param actual_seq_lengths    Optional; forwarded unchanged to the kernel.
 * @param actual_seq_lengths_kv Optional; forwarded unchanged to the kernel.
 * @param q_input_layout        Layout name for query, handed to the kernel as a C string.
 * @param kv_input_layout       Layout name for key/value, handed to the kernel as a C string.
 * @param num_key_value_heads   Forwarded unchanged to the kernel.
 * @param scale_value           Forwarded unchanged to the kernel.
 * @return Tuple (d_query, d_key, d_value).
 *
 * Fails through TORCH_CHECK if check_params rejects the inputs.
 */
std::tuple<at::Tensor, at::Tensor, at::Tensor> npu_block_sparse_attention_backward(
    const at::Tensor &d_out,
    const at::Tensor &query,
    const at::Tensor &key,
    const at::Tensor &value,
    const at::Tensor &attention_out,
    const at::Tensor &softmax_lse,
    const at::Tensor &block_sparse_mask,
    const c10::OptionalIntArrayRef block_shape,
    const c10::OptionalIntArrayRef actual_seq_lengths,
    const c10::OptionalIntArrayRef actual_seq_lengths_kv,
    c10::string_view q_input_layout,
    c10::string_view kv_input_layout,
    int64_t num_key_value_heads,
    double scale_value)
{
    check_params(query, key, value);

    // Output gradients, one per differentiable input.
    at::Tensor d_query = npu_preparation::apply_tensor_without_format(query);
    at::Tensor d_key = npu_preparation::apply_tensor_without_format(key);
    at::Tensor d_value = npu_preparation::apply_tensor_without_format(value);

    // Static storage so the IntArrayRef built over it stays valid for the whole call.
    static const int64_t kDefaultBlockShape[2] = {128, 128};
    // NOTE(review): a caller-supplied block_shape with fewer than two elements is
    // silently replaced by the default rather than rejected - confirm this is intended.
    const at::IntArrayRef block_shape_value = (block_shape.has_value() && block_shape->size() >= 2)
                                                  ? *block_shape
                                                  : at::IntArrayRef(kDefaultBlockShape, 2);

    // This entry point exposes no attention-mask argument; fixed values are passed instead.
    const at::Tensor atten_mask{nullptr};
    const int64_t mask_type = 0;
    // 2147483647 == INT32_MAX; presumably means "no token window limit" - confirm
    // against the aclnnBlockSparseAttentionGrad documentation.
    const int64_t pre_tokens = 2147483647;
    const int64_t next_tokens = 2147483647;

    // c10::string_view::data() is not guaranteed to be NUL-terminated, and the layouts are
    // passed to the kernel as bare char pointers with no length. Copy each view into an
    // owning std::string so that c_str() yields a properly terminated buffer. The strings
    // are locals of this function and therefore outlive the macro invocation below.
    const std::string q_input_layout_str(q_input_layout.data(), q_input_layout.size());
    const std::string kv_input_layout_str(kv_input_layout.data(), kv_input_layout.size());
    // The const_cast only adapts to the char * the macro is given; the buffers are not
    // written through these pointers in this function.
    char *q_input_layout_ptr = const_cast<char *>(q_input_layout_str.c_str());
    char *kv_input_layout_ptr = const_cast<char *>(kv_input_layout_str.c_str());

    EXEC_NPU_NO_FORMAT_CHECK_CMD(
        aclnnBlockSparseAttentionGrad,
        d_out, query, key, value,
        attention_out, softmax_lse,
        block_sparse_mask, atten_mask, block_shape_value,
        actual_seq_lengths, actual_seq_lengths_kv,
        q_input_layout_ptr, kv_input_layout_ptr,
        num_key_value_heads, mask_type, scale_value,
        pre_tokens, next_tokens,
        d_query, d_key, d_value);

    return std::make_tuple(d_query, d_key, d_value);
}
}