#include <set>
#include <cstring>
#include <op_plugin/OpApiInterface.h>
#include <torch_npu/csrc/framework/utils/InternalFormatOpAdapter.h>
#include "op_plugin/utils/op_api_common.h"
#include "op_plugin/utils/OpUtils.h"
namespace op_api {
// Shorthand for the NPU OpPreparation helpers; used in this file for ACL-dtype -> ScalarType
// conversion and for allocating the output tensor.
using npu_preparation = at_npu::native::OpPreparation;
// Rank that both x1 and x2 are required to have.
static const int TWO_DIMS = 2;
// Value used for x1_quant_mode when the caller leaves it unset.
static const int64_t PERTOKEN_QUANT_MODE = 3;
// Value used for x2_quant_mode when the caller leaves it unset.
static const int64_t PERCHANNEL_QUANT_MODE = 2;
// Value used for common_quant_mode when the caller leaves it unset.
static const int64_t NON_QUANT = 0;
// Sentinel for "no dtype given": default for comm_quant_dtype, and a y_dtype equal to it is ignored.
static const int64_t ACL_UNDEFINED = -1;
// Accepted world_size values. The key type is int64_t so that looking up the int64_t world_size
// argument never narrows it (with an int key, e.g. 2^32 + 2 would wrap to 2 and be accepted).
const std::set<int64_t> SUPPORT_WORLD_SIZE_LIST{2, 4, 8, 16};
/**
 * @brief Host-side entry point that launches the quant-matmul + all-to-all aclnn kernel.
 *
 * Validates the inputs, allocates the result tensor and dispatches to
 * aclnnQuantMatmulAlltoAllV2 when comm_mode is supplied, otherwise to aclnnQuantMatmulAlltoAll.
 *
 * @param x1                First input tensor; must be 2-D.
 * @param x2                Second input tensor; must be 2-D and its second dim must be a
 *                          multiple of world_size.
 * @param hcom              Communication group name; its data pointer is handed to the kernel as char*.
 * @param world_size        Must be one of SUPPORT_WORLD_SIZE_LIST.
 * @param bias, common_scale, x1_offset, x2_offset, all2all_axes
 *                          Forwarded to the kernel unchanged.
 * @param x1_scale, x2_scale
 *                          Optional scales; an undefined tensor is substituted when absent.
 * @param x1_quant_mode     Defaults to PERTOKEN_QUANT_MODE when unset.
 * @param x2_quant_mode     Defaults to PERCHANNEL_QUANT_MODE when unset.
 * @param common_quant_mode Defaults to NON_QUANT when unset.
 * @param group_sizes       Reduced to one value by op_plugin::utils::check_and_get_group_size.
 * @param comm_quant_dtype  Defaults to ACL_UNDEFINED when unset.
 * @param x1_dtype, x2_dtype, x1_scale_dtype, x2_scale_dtype
 *                          Optional dtype codes given to make_wrapper together with the matching tensor.
 * @param output_scale_dtype, comm_scale_dtype
 *                          Not referenced by this implementation.
 * @param y_dtype           Output dtype code; Float is used when unset or equal to ACL_UNDEFINED.
 * @param comm_mode         When set, selects the V2 kernel and is forwarded to it as char*.
 * @return Newly allocated tensor of shape (x1.size(0) * world_size, x2.size(1) / world_size).
 *
 * NOTE(review): hcom and comm_mode are passed on via string_view::data(); c10::string_view does
 * not guarantee null termination — confirm every caller provides a terminated string.
 */
at::Tensor npu_quant_matmul_all_to_all(const at::Tensor &x1, const at::Tensor &x2, c10::string_view hcom,
    int64_t world_size, const c10::optional<at::Tensor>& bias, const c10::optional<at::Tensor>& x1_scale,
    const c10::optional<at::Tensor>& x2_scale, const c10::optional<at::Tensor>& common_scale,
    const c10::optional<at::Tensor>& x1_offset, const c10::optional<at::Tensor>& x2_offset,
    c10::optional<int64_t> x1_quant_mode, c10::optional<int64_t> x2_quant_mode,
    c10::optional<int64_t> common_quant_mode,
    c10::OptionalIntArrayRef group_sizes, c10::OptionalIntArrayRef all2all_axes,
    c10::optional<int64_t> comm_quant_dtype, c10::optional<int64_t> x1_dtype, c10::optional<int64_t> x2_dtype,
    c10::optional<int64_t> x1_scale_dtype, c10::optional<int64_t> x2_scale_dtype,
    c10::optional<int64_t> output_scale_dtype, c10::optional<int64_t> comm_scale_dtype,
    c10::optional<int64_t> y_dtype,
    c10::optional<c10::string_view> comm_mode)
{
    TORCH_CHECK(x1.dim() == TWO_DIMS, "The x1 input of quantmatmulalltoall is required to be 2D, but the actual x1 input is ", x1.dim(), "D." + OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(x2.dim() == TWO_DIMS, "The x2 input of quantmatmulalltoall is required to be 2D, but the actual x2 input is ", x2.dim(), "D." + OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(SUPPORT_WORLD_SIZE_LIST.find(world_size) != SUPPORT_WORLD_SIZE_LIST.end(),
        "The world_size should be in [2, 4, 8, 16], but the actual value is ", world_size, "." + OPS_ERROR(ErrCode::VALUE));

    at::IntArrayRef group_size_list = group_sizes.value_or(at::IntArrayRef{});
    int64_t group_size = op_plugin::utils::check_and_get_group_size(group_size_list);

    // Resolve the output dtype: Float unless the caller gave a concrete y_dtype.
    int64_t output_dtype_code = static_cast<int64_t>(at::ScalarType::Float);
    if (y_dtype.has_value() && y_dtype.value() != ACL_UNDEFINED) {
        output_dtype_code = y_dtype.value();
    }
    aclDataType output_acl_type = c10_npu::GetAclDataType(output_dtype_code);
    at::ScalarType output_scalar_type = npu_preparation::convert_to_scalar_type(output_acl_type);

    // The output keeps only x2.size(1) / world_size columns; reject sizes that would be
    // silently truncated by the integer division below.
    TORCH_CHECK(x2.size(1) % world_size == 0,
        "The second dim of x2 should be divisible by world_size, but the actual second dim of x2 is ", x2.size(1),
        " and world_size is ", world_size, "." + OPS_ERROR(ErrCode::PARAM));
    int64_t out_m = x1.size(0) * world_size;
    int64_t out_n = x2.size(1) / world_size;
    auto output_size = {out_m, out_n};
    at::Tensor output_tensor =
        npu_preparation::apply_tensor_without_format(output_size, c10::dtype(output_scalar_type));

    // The kernel interface takes a mutable char* for the group name.
    char *group_ptr = const_cast<char *>(hcom.data());

    // Fill in defaults for the optional quantization settings.
    int64_t x1_mode = x1_quant_mode.value_or(PERTOKEN_QUANT_MODE);
    int64_t x2_mode = x2_quant_mode.value_or(PERCHANNEL_QUANT_MODE);
    int64_t common_mode = common_quant_mode.value_or(NON_QUANT);
    int64_t comm_dtype = comm_quant_dtype.value_or(ACL_UNDEFINED);

    // Both kernels are always invoked with non-transposed operands.
    bool transpose_x1 = false;
    bool transpose_x2 = false;

    TensorWrapper x1_wrapper = make_wrapper(x1, x1_dtype);
    TensorWrapper x2_wrapper = make_wrapper(x2, x2_dtype);
    // value_or returns a temporary; binding it to a const reference keeps it alive for this scope.
    const at::Tensor &x1_scale_real = x1_scale.value_or(at::Tensor());
    const at::Tensor &x2_scale_real = x2_scale.value_or(at::Tensor());
    TensorWrapper x1_scale_wrapper = make_wrapper(x1_scale_real, x1_scale_dtype);
    TensorWrapper x2_scale_wrapper = make_wrapper(x2_scale_real, x2_scale_dtype);

    if (comm_mode.has_value()) {
        // comm_mode is only understood by the V2 kernel, so require it to be present.
        TORCH_CHECK(check_aclnn_kernel_available("aclnnQuantMatmulAlltoAllV2"),
            "Too old ops-transformer package, please update. Or use comm_mode = None." + OPS_ERROR(ErrCode::PARAM));
        char *comm_mode_ptr = const_cast<char *>(comm_mode.value().data());
        EXEC_NPU_CMD(aclnnQuantMatmulAlltoAllV2, x1_wrapper, x2_wrapper, bias, x1_scale_wrapper, x2_scale_wrapper,
            common_scale, x1_offset, x2_offset, all2all_axes, group_ptr, comm_mode_ptr, x1_mode, x2_mode,
            common_mode, comm_dtype, group_size, transpose_x1, transpose_x2, output_tensor);
    } else {
        EXEC_NPU_CMD(aclnnQuantMatmulAlltoAll, x1_wrapper, x2_wrapper, bias, x1_scale_wrapper, x2_scale_wrapper,
            common_scale, x1_offset, x2_offset, all2all_axes, group_ptr, x1_mode, x2_mode, common_mode,
            comm_dtype, group_size, transpose_x1, transpose_x2, output_tensor);
    }
    return output_tensor;
}
}