#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
namespace op_api {
using npu_preparation = at_npu::native::OpPreparation;

/**
 * @brief In-place attention-softmax backward step executed through the aclnn op API.
 *
 * Multiplies `grad_output` by `values` with its last two dimensions swapped
 * (aclnnMatmul), then feeds that product together with `self` into
 * aclnnSoftmaxBackward along dimension -1, writing the result back into `self`.
 *
 * @param self        Passed to aclnnSoftmaxBackward both as an input and as the
 *                    destination tensor; overwritten by the call.
 *                    NOTE(review): presumably holds the forward softmax output — confirm with callers.
 * @param grad_output Left-hand operand of the matmul; its tensor options are
 *                    reused for the intermediate product.
 * @param values      Right-hand operand of the matmul after transpose(-2, -1).
 * @return Reference to `self`.
 */
at::Tensor &npu_attn_softmax_backward_(at::Tensor &self, const at::Tensor &grad_output, const at::Tensor &values)
{
    // Baseline cube math type, selected from the matmul HF32 switch.
    int8_t hf32_math_type = npu_preparation::get_cube_math_type(at_npu::native::env::IsAllowMatmulHF32());
    // The zero-argument query takes precedence whenever it yields a non-negative value.
    int8_t override_math_type = npu_preparation::get_cube_math_type();
    int8_t matmul_math_type = (override_math_type >= 0) ? override_math_type : hf32_math_type;

    // Swap the last two dimensions of `values` before multiplying.
    at::Tensor values_transposed = values.transpose(-2, -1);

    // Allocate the matmul destination with the inferred shape and grad_output's options.
    auto product_shape = op_infer::matmul_output_size(grad_output, values_transposed);
    auto product = npu_preparation::apply_tensor_without_format(product_shape, grad_output.options());
    EXEC_NPU_CMD(aclnnMatmul, grad_output, values_transposed, product, matmul_math_type);

    // Softmax backward over the last dimension; `self` is reused as the destination.
    int64_t softmax_dim = -1;
    EXEC_NPU_CMD(aclnnSoftmaxBackward, product, self, softmax_dim, self);
    return self;
}
} // namespace op_api