#include "csrc/OpApiCommon.h"
#include "csrc/functions.h"
#include "csrc/utils.h"
namespace {
// Axis of the src / softmax_out tensors (and the only axis of index) that
// enumerates edges.
constexpr size_t EDGE_NUM_DIM = 0;
// Axis of the src / softmax_out tensors that holds the per-edge features.
constexpr size_t SRC_FEATURE_DIM = 1;
// Feature width the graph softmax kernels accept along SRC_FEATURE_DIM.
constexpr size_t FEATURE_NUM = 8;
}  // namespace
/**
 * Forward graph softmax, executed on the NPU by the aclnnGraphSoftmax kernel.
 *
 * @param src   2D NPU tensor of shape [edge_num, FEATURE_NUM] (one row per edge).
 * @param index 1D NPU tensor of length edge_num. Presumably it holds the node id
 *              that groups each edge for the softmax; TODO confirm against the kernel.
 * @param N     Node count handed to the kernel; must be positive.
 * @return      Tensor with the same shape and options as src, filled in by the kernel.
 * @throws      c10::Error (via TORCH_CHECK) when any of the checks below fails.
 */
at::Tensor graph_softmax(const at::Tensor &src, const at::Tensor &index, int N) {
    TORCH_CHECK_NPU(src);
    TORCH_CHECK_NPU(index);
    TORCH_CHECK(src.dim() == 2, "src must be a 2D Tensor, but got: ", src.dim());
    TORCH_CHECK(index.dim() == 1, "index must be a 1D Tensor, but got: ", index.dim());
    TORCH_CHECK(index.size(EDGE_NUM_DIM) == src.size(EDGE_NUM_DIM),
        "The first dimension of index and src must be of equal size.");
    // Read the feature width straight from the shape. Going through src[0]
    // would throw an indexing error for a valid src with zero edges.
    const int64_t feature_num = src.size(SRC_FEATURE_DIM);
    TORCH_CHECK(feature_num == static_cast<int64_t>(FEATURE_NUM),
        "The second dimension of src must be 8, but got: ", feature_num);
    TORCH_CHECK(N > 0, "N must be positive, but got: ", N);
    // Output buffer the kernel writes into; zero-initialised, same shape/options as src.
    at::Tensor softmax_result = at::zeros_like(src);
    EXEC_NPU_CMD(aclnnGraphSoftmax, src, index, N, softmax_result);
    return softmax_result;
}
/**
 * Backward of graph_softmax, executed on the NPU by the aclnnGraphSoftmaxGrad kernel.
 *
 * @param index       1D int32 NPU tensor of length edge_num.
 * @param softmax_out 2D float32 NPU tensor of shape [edge_num, FEATURE_NUM]. Presumably
 *                    this is the forward output; TODO confirm with the autograd caller.
 * @param grad_output float32 NPU tensor with exactly the shape of softmax_out.
 * @param node_num    Node count; must be positive. It sizes the [node_num, FEATURE_NUM]
 *                    zero-initialised workspace passed to the kernel.
 * @return            float32 tensor of shape [edge_num, FEATURE_NUM] written by the kernel
 *                    (presumably the gradient w.r.t. the forward's src).
 * @throws            c10::Error (via TORCH_CHECK) when any of the checks below fails.
 */
at::Tensor graph_softmax_grad(
    const at::Tensor &index, const at::Tensor &softmax_out, const at::Tensor &grad_output, int32_t node_num) {
    // Same device requirement as the forward pass.
    TORCH_CHECK_NPU(index);
    TORCH_CHECK_NPU(softmax_out);
    TORCH_CHECK_NPU(grad_output);
    TORCH_CHECK(index.scalar_type() == at::kInt,
        "index: int32 tensor expected but got a tensor with dtype: ", index.scalar_type());
    TORCH_CHECK(softmax_out.scalar_type() == at::kFloat,
        "softmax_out: float32 tensor expected but got a tensor with dtype: ", softmax_out.scalar_type());
    TORCH_CHECK(grad_output.scalar_type() == at::kFloat,
        "grad_output: float32 tensor expected but got a tensor with dtype: ", grad_output.scalar_type());
    // Check ranks before reading per-axis sizes, so malformed inputs get a clear error.
    TORCH_CHECK(index.dim() == 1, "index must be a 1D Tensor, but got: ", index.dim());
    TORCH_CHECK(softmax_out.dim() == 2, "softmax_out must be a 2D Tensor, but got: ", softmax_out.dim());

    const int64_t edge_num = softmax_out.size(EDGE_NUM_DIM);
    const int64_t feature_num = softmax_out.size(SRC_FEATURE_DIM);
    const int64_t index_edge = index.size(EDGE_NUM_DIM);
    TORCH_CHECK(feature_num == static_cast<int64_t>(FEATURE_NUM),
        "The second dimension of softmax_out must be ", FEATURE_NUM, ", but got: ", feature_num);
    TORCH_CHECK(edge_num == index_edge, "softmax_out tensor and index tensor should have same Edge num.");
    // The kernel consumes grad_output together with softmax_out, so their shapes must match.
    TORCH_CHECK(grad_output.sizes() == softmax_out.sizes(),
        "grad_output and softmax_out must have the same shape, but got: ",
        grad_output.sizes(), " and ", softmax_out.sizes());
    TORCH_CHECK(node_num > 0, "node_num must be positive, but got: ", node_num);

    at::Tensor grad_src = at::zeros({edge_num, feature_num}, softmax_out.options().dtype(at::kFloat));
    // Zero-initialised per-node workspace handed to the kernel alongside the output.
    at::Tensor reduce_sum = at::zeros({node_num, feature_num}, softmax_out.options().dtype(at::kFloat));
    EXEC_NPU_CMD(aclnnGraphSoftmaxGrad, index, softmax_out, grad_output, reduce_sum, node_num, grad_src);
    return grad_src;
}