#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
namespace op_api {
using npu_preparation = at_npu::native::OpPreparation;
using npu_utils = at_npu::native::NpuUtils;
using tensor_list = std::tuple<at::Tensor, at::Tensor, at::Tensor>;
const int DIM_TWO = 2;
tensor_list npu_moe_init_routing(const at::Tensor &x, const at::Tensor &row_idx, const at::Tensor &expert_idx, int64_t active_num)
{
TORCH_NPU_WARN_ONCE("The oprator of MoeInitRouting will be removed from Pytorch and switch to AscendSpeed after 630.");
TORCH_CHECK(x.dim() == DIM_TWO, "The x should be 2D");
TORCH_CHECK(
x.scalar_type() == at::kHalf || x.scalar_type() == at::kFloat || x.scalar_type() == at::kBFloat16,
"float16、float32 or bfloat16 tensor expected but got a tensor with dtype: ",
x.scalar_type());
TORCH_CHECK(expert_idx.dim() == DIM_TWO, "The expert_idx should be 2D");
TORCH_CHECK(
expert_idx.scalar_type() == at::kInt,
"int32 tensor expected but got a tensor with dtype: ",
expert_idx.scalar_type());
TORCH_CHECK(row_idx.dim() == DIM_TWO, "The row_idx should be 2D");
TORCH_CHECK(
row_idx.scalar_type() == at::kInt,
"int32 tensor expected but got a tensor with dtype: ",
row_idx.scalar_type());
TORCH_CHECK(active_num >= 0, "The active_num must be a non-negative number");
auto x_size = x.sizes();
auto expert_idx_size = expert_idx.sizes();
auto row_idx_size = row_idx.sizes();
TORCH_CHECK(x_size[0] == expert_idx_size[0], "Input rows shoud be same.");
TORCH_CHECK(expert_idx_size == row_idx_size, "The shape of expert_idx and row_idx should be same.");
int n = x_size[0];
int cols = x_size[1];
int k = expert_idx_size[1];
active_num = n > active_num ? active_num : n;
at::Tensor expanded_x = npu_preparation::apply_tensor_without_format(x, {active_num * k, cols});
at::Tensor expanded_row_idx = npu_preparation::apply_tensor_without_format(row_idx, {n * k});
at::Tensor expanded_expert_idx = npu_preparation::apply_tensor_without_format(expert_idx, {n * k});
EXEC_NPU_CMD(aclnnMoeInitRouting, x, row_idx, expert_idx, active_num, expanded_x, expanded_row_idx, expanded_expert_idx);
return std::tie(expanded_x, expanded_row_idx, expanded_expert_idx);
}
}