#include <cstring>
#include <limits>
#include <string>
#include "torch_npu/csrc/framework/utils/RandomOpAdapter.h"
#include "torch_npu/csrc/aten/CustomFunctions.h"
#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
namespace op_api {
using namespace at_npu::native;
const int DIMENSION_3D = 3;
const int DIMENSION_4D = 4;
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_dense_lightning_indexer_grad_kl_loss_symint(
const at::Tensor &query,
const at::Tensor &key,
const at::Tensor &query_index,
const at::Tensor &key_index,
const at::Tensor &weights,
const at::Tensor &softmax_max,
const at::Tensor &softmax_sum,
const at::Tensor &softmax_max_index,
const at::Tensor &softmax_sum_index,
double scale_value,
const c10::optional<at::Tensor> &query_rope,
const c10::optional<at::Tensor> &key_rope,
c10::OptionalArrayRef<c10::SymInt> actual_seq_qlen,
c10::OptionalArrayRef<c10::SymInt> actual_seq_klen,
c10::optional<c10::string_view> layout,
c10::optional<int64_t> sparse_mode,
c10::optional<int64_t> pre_tokens,
c10::optional<int64_t> next_tokens)
{
const at::Tensor &query_rope_const = query_rope.value_or(at::Tensor());
const at::Tensor &key_rope_const = key_rope.value_or(at::Tensor());
c10::string_view layout_str = layout.value_or("BSND");
char *layout_ptr = const_cast<char *>(layout_str.data());
int64_t sparse_mode_const = sparse_mode.value_or(3);
int64_t pre_tokens_const = pre_tokens.value_or(9223372036854775807);
int64_t next_tokens_const = next_tokens.value_or(9223372036854775807);
TORCH_CHECK(query.dim() == DIMENSION_3D || query.dim() == DIMENSION_4D,
"The shapes of the input query should be 3 or 4 dimensional, but got ",
query.dim(), "-dimensional", OPS_ERROR(ErrCode::PARAM));
if (query_rope_const.defined()) {
TORCH_CHECK(query_rope_const.dim() == DIMENSION_3D || query_rope_const.dim() == DIMENSION_4D,
"The shapes of the input query_rope should be 3 or 4 dimensional, but got ",
query_rope_const.dim(), "-dimensional", OPS_ERROR(ErrCode::PARAM));
}
TORCH_CHECK(key.dim() == DIMENSION_3D || key.dim() == DIMENSION_4D,
"The shapes of the input key should be 3 or 4 dimensional, but got ", key.dim(),
"-dimensional", OPS_ERROR(ErrCode::PARAM));
if (key_rope_const.defined()) {
TORCH_CHECK(key_rope_const.dim() == DIMENSION_3D || key_rope_const.dim() == DIMENSION_4D,
"The shapes of the input key_rope should be 3 or 4 dimensional, but got ",
key_rope_const.dim(), "-dimensional", OPS_ERROR(ErrCode::PARAM));
}
at::Tensor d_query_index = OpPreparation::apply_tensor_without_format(query_index);
at::Tensor d_key_index = OpPreparation::apply_tensor_without_format(key_index);
at::Tensor d_weights = OpPreparation::apply_tensor_without_format(weights);
at::Tensor loss = OpPreparation::apply_tensor_without_format({1}, query.options().dtype(at::kFloat));
EXEC_NPU_NO_FORMAT_CHECK_CMD(
aclnnDenseLightningIndexerGradKLLoss, query, key, query_index, key_index, weights,
softmax_max, softmax_sum, softmax_max_index, softmax_sum_index, query_rope_const, key_rope_const,
actual_seq_qlen, actual_seq_klen, scale_value, layout_ptr, sparse_mode_const, pre_tokens_const,
next_tokens_const, d_query_index, d_key_index, d_weights, loss);
return std::make_tuple(d_query_index, d_key_index, d_weights, loss);
}
}