#include <set>
#include <op_plugin/OpApiInterface.h>
#include <torch_npu/csrc/framework/utils/InternalFormatOpAdapter.h>
#include "op_plugin/utils/op_api_common.h"
#include "op_plugin/utils/OpUtils.h"
namespace op_api {
using npu_preparation = at_npu::native::OpPreparation;

// Rank required of both matmul operands.
static const int TWO_DIMS = 2;

// Accepted world_size values. The key type is int64_t, matching the world_size argument, so the
// membership lookup compares the full 64-bit value. With an int key the argument would be
// narrowed before the lookup and e.g. 2^32 + 2 would be accepted as 2.
const std::set<int64_t> SUPPORT_WORLD_SIZE_LIST{2, 4, 8, 16};

/**
 * @brief Launch the fused all-to-all + matmul aclnn kernel.
 *
 * Validates the inputs, allocates the result tensor(s) and dispatches to
 * aclnnAlltoAllMatmulV2 when comm_mode is provided, otherwise to aclnnAlltoAllMatmul.
 *
 * @param x1               2-D tensor; x1.size(0) must be divisible by world_size.
 * @param x2               2-D tensor.
 * @param hcom             Communication group name; its data pointer is forwarded to the kernel.
 * @param world_size       Must be one of 2, 4, 8, 16.
 * @param bias             Optional bias, forwarded to the kernel unchanged.
 * @param all2all_axes     Optional axes, forwarded to the kernel unchanged.
 * @param all2all_out_flag When true, a second tensor of shape
 *                         (x1.size(0) / world_size, x1.size(1) * world_size) is allocated,
 *                         passed to the kernel and returned.
 * @param comm_mode        Optional communication mode; when present the V2 kernel is required.
 * @return (output, all2all_output). output has shape (x1.size(0) / world_size, x2.size(1)) and
 *         x1's dtype (round-tripped through the ACL data type). When all2all_out_flag is false
 *         the second element is constructed from nullptr (presumably an undefined tensor --
 *         TODO confirm against at::Tensor).
 * @throws Whatever TORCH_CHECK raises when a validation fails or when comm_mode is given but
 *         aclnnAlltoAllMatmulV2 is not available.
 */
std::tuple<at::Tensor, at::Tensor> npu_all_to_all_matmul(const at::Tensor &x1, const at::Tensor &x2, c10::string_view hcom, int64_t world_size,
                                                         const c10::optional<at::Tensor>& bias, c10::OptionalIntArrayRef all2all_axes, bool all2all_out_flag,
                                                         c10::optional<c10::string_view> comm_mode)
{
    TORCH_CHECK(x1.dim() == TWO_DIMS, "The x1 input of alltoallmatmul is required to be 2D, but the actual x1 input is ", x1.dim(), "D." + OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(x2.dim() == TWO_DIMS, "The x2 input of alltoallmatmul is required to be 2D, but the actual x2 input is ", x2.dim(), "D." + OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(SUPPORT_WORLD_SIZE_LIST.find(world_size) != SUPPORT_WORLD_SIZE_LIST.end(),
                "The world_size should be in [2, 4, 8, 16], but the actual value is ", world_size, "." + OPS_ERROR(ErrCode::VALUE));
    TORCH_CHECK(x1.size(0) % world_size == 0, "The x1 first-axis should be divisible by world_size.", OPS_ERROR(ErrCode::PARAM));

    // A provided comm_mode selects the V2 kernel. Verify it exists once, before any device
    // memory is allocated, so an unsupported CANN version fails fast.
    const bool use_v2 = comm_mode.has_value();
    if (use_v2) {
        TORCH_CHECK(check_aclnn_kernel_available("aclnnAlltoAllMatmulV2"),
                    "Current CANN version do not support this api. Please try to update the version of CANN." + OPS_ERROR(ErrCode::PARAM));
    }

    // Result dtype follows x1, round-tripped through the ACL data type.
    aclDataType output_acl_type = npu_preparation::convert_to_acl_data_type(x1.scalar_type());
    at::ScalarType output_scalar_type = npu_preparation::convert_to_scalar_type(output_acl_type);

    int64_t out_m = x1.size(0) / world_size;
    int64_t out_n = x2.size(1);
    auto output_size = {out_m, out_n};
    at::Tensor output_tensor = npu_preparation::apply_tensor_without_format(output_size, c10::dtype(output_scalar_type));

    // NOTE(review): string_view::data() is not guaranteed to be null-terminated; this assumes the
    // callers pass views over null-terminated strings -- confirm against the call sites.
    char *group_ptr = const_cast<char *>(hcom.data());
    char *comm_mode_ptr = use_v2 ? const_cast<char *>(comm_mode.value().data()) : nullptr;
    bool transpose_x1 = false;
    bool transpose_x2 = false;

    if (!all2all_out_flag) {
        // No all-to-all output requested: the kernel receives nullptr for that argument.
        const std::nullptr_t &alltoall_out_null = nullptr;
        if (use_v2) {
            EXEC_NPU_CMD(aclnnAlltoAllMatmulV2, x1, x2, bias, all2all_axes, group_ptr, comm_mode_ptr, transpose_x1, transpose_x2, output_tensor, alltoall_out_null);
        } else {
            EXEC_NPU_CMD(aclnnAlltoAllMatmul, x1, x2, bias, all2all_axes, group_ptr, transpose_x1, transpose_x2, output_tensor, alltoall_out_null);
        }
        return std::tuple<at::Tensor, at::Tensor>(output_tensor, alltoall_out_null);
    }

    // The all-to-all output shares x1's dtype, so the scalar type computed above is reused.
    auto alltoall_out_size = {out_m, x1.size(1) * world_size};
    at::Tensor alltoall_out_tensor = npu_preparation::apply_tensor_without_format(alltoall_out_size, c10::dtype(output_scalar_type));
    if (use_v2) {
        EXEC_NPU_CMD(aclnnAlltoAllMatmulV2, x1, x2, bias, all2all_axes, group_ptr, comm_mode_ptr, transpose_x1, transpose_x2, output_tensor, alltoall_out_tensor);
    } else {
        EXEC_NPU_CMD(aclnnAlltoAllMatmul, x1, x2, bias, all2all_axes, group_ptr, transpose_x1, transpose_x2, output_tensor, alltoall_out_tensor);
    }
    return std::tuple<at::Tensor, at::Tensor>(output_tensor, alltoall_out_tensor);
}
} // namespace op_api