#include <set>
#include <string>
#include <op_plugin/OpApiInterface.h>
#include <torch_npu/csrc/framework/utils/InternalFormatOpAdapter.h>
#include "op_plugin/utils/op_api_common.h"
#include "op_plugin/utils/OpUtils.h"
#include "op_plugin/utils/KernelNpuOutputSize.h"
namespace op_api {
using npu_preparation = at_npu::native::OpPreparation;

// Tensor ranks accepted by the argument checks in npu_quant_all_reduce.
static const int DIM_TWO = 2;
static const int DIM_THREE = 3;
static const int DIM_FOUR = 4;

// NUM_64 is not referenced in the visible part of this file.
static const int NUM_64 = 64;
// The H axis (last axis) of x must be a multiple of NUM_128 ...
static const int NUM_128 = 128;
// ... and lie inside [H_LOWER_LIMIT, H_UPPER_LIMIT].
static const int H_LOWER_LIMIT = 1024;
static const int H_UPPER_LIMIT = 8192;

// Supported communication group sizes. The element type is int64_t on purpose:
// world_size arrives as int64_t, and looking it up in a std::set<int> would
// silently narrow the key, letting out-of-range values (e.g. 2^32 + 2) alias a
// supported entry.
const std::set<int64_t> SUPPORT_WORLD_SIZE_LIST{2, 4, 8};

// Values accepted for the optional x_dtype override.
const std::set<int64_t> SUPPORT_X_DTYPE_LIST{
    static_cast<int64_t>(c10_npu::DType::INT8),
    static_cast<int64_t>(c10_npu::DType::HIFLOAT8),
    static_cast<int64_t>(c10_npu::DType::FLOAT8_E5M2),
    static_cast<int64_t>(c10_npu::DType::FLOAT8_E4M3FN),
    static_cast<int64_t>(c10_npu::DType::FLOAT4_E1M2),
    static_cast<int64_t>(c10_npu::DType::FLOAT4_E2M1)
};

// Values accepted for the optional scales_dtype override.
const std::set<int64_t> SUPPORT_SCALES_DTYPE_LIST{
    static_cast<int64_t>(c10_npu::DType::FLOAT),
    static_cast<int64_t>(c10_npu::DType::FLOAT8_E8M0)
};
/**
 * @brief Validates the arguments and launches aclnnQuantAllReduce.
 *
 * @param x             Quantized input; must be defined, 2 or 3 dimensional and non-empty.
 *                      Its last axis (H) must be in [1024, 8192] and divisible by 128, and the
 *                      product of the leading axes must be divisible by world_size.
 * @param scales        Scales tensor; must be defined, 2, 3 or 4 dimensional and non-empty.
 * @param hcom          Communication group name forwarded to the kernel.
 * @param world_size    Group size; must be one of SUPPORT_WORLD_SIZE_LIST.
 * @param reduce_op     Reduction name forwarded to the kernel; "sum" when absent.
 * @param output_dtype  Output dtype code; BFloat16 when absent.
 * @param x_dtype       Optional dtype override for x; must be in SUPPORT_X_DTYPE_LIST.
 * @param scales_dtype  Optional dtype override for scales; must be in SUPPORT_SCALES_DTYPE_LIST.
 * @return A newly allocated tensor with the same sizes as x and the resolved output dtype.
 *
 * Every violated precondition is reported through TORCH_CHECK.
 */
at::Tensor npu_quant_all_reduce(const at::Tensor &x, const at::Tensor &scales, c10::string_view hcom,
                                int64_t world_size, c10::optional<c10::string_view> reduce_op,
                                c10::optional<int64_t> output_dtype, c10::optional<int64_t> x_dtype,
                                c10::optional<int64_t> scales_dtype)
{
    TORCH_CHECK(x.defined(), "The input tensor x can not be None.", OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(scales.defined(), "The input tensor scales can not be None.", OPS_ERROR(ErrCode::PARAM));

    // ---- x: rank, emptiness, dtype override ----
    const int64_t x_dim = x.dim();
    TORCH_CHECK(x_dim == DIM_TWO || x_dim == DIM_THREE,
                "The input x tensor shape is required to be 2 or 3 dim, but the actual input shape is ",
                x_dim, OPS_ERROR(ErrCode::PARAM));
    // numel() is zero exactly when at least one axis has size zero.
    TORCH_CHECK(x.numel() != 0, "The input ", x_dim, " dim tensor x can not be empty tensor",
                OPS_ERROR(ErrCode::PARAM));
    if (x_dtype.has_value()) {
        TORCH_CHECK(SUPPORT_X_DTYPE_LIST.find(x_dtype.value()) != SUPPORT_X_DTYPE_LIST.end(),
                    "The optional parameter x_dtype only supports "
                    "int8/hifloat8/float8_e4m3fn/float8_e5m2/float4_e1m2/float4_e2m1, but now is ",
                    op_plugin::utils::DTypeToString(x_dtype.value()), ".", OPS_ERROR(ErrCode::VALUE));
    }

    // ---- world_size and x axis constraints ----
    // world_size is validated before it is used as a divisor below.
    TORCH_CHECK(SUPPORT_WORLD_SIZE_LIST.find(world_size) != SUPPORT_WORLD_SIZE_LIST.end(),
                "The world_size should be in ", c10::Join(", ", SUPPORT_WORLD_SIZE_LIST),
                ", but the actual value is ", world_size, OPS_ERROR(ErrCode::VALUE));
    // The BS axis is every axis in front of H, flattened.
    const int64_t axis_bs = (x_dim == DIM_THREE) ? x.size(0) * x.size(1) : x.size(0);
    TORCH_CHECK(axis_bs % world_size == 0, "The x BS-axis should be divisible by world_size",
                OPS_ERROR(ErrCode::PARAM));
    // H is the last axis for both accepted ranks.
    const int64_t axis_h_size = x.size(-1);
    TORCH_CHECK(axis_h_size >= H_LOWER_LIMIT && axis_h_size <= H_UPPER_LIMIT && axis_h_size % NUM_128 == 0,
                "The x H-axis should be in [1024, 8192] and divisible by 128", OPS_ERROR(ErrCode::PARAM));

    // ---- scales: rank, emptiness, dtype override ----
    // NOTE(review): the message ties the scales rank to the rank of x, but only the 2/3/4 range is
    // enforced here; presumably the kernel rejects mismatched combinations - confirm.
    const int64_t scales_dim = scales.dim();
    TORCH_CHECK(scales_dim == DIM_TWO || scales_dim == DIM_THREE || scales_dim == DIM_FOUR,
                "The input scales tensor shape is required to be equal to x in TG QuantMode, "
                "or be equal to x plus 1 in MX QuantMode, but the actual input scales shape is ",
                scales_dim, OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(scales.numel() != 0, "The input ", scales_dim, " dim tensor scales can not be empty tensor",
                OPS_ERROR(ErrCode::PARAM));
    if (scales_dtype.has_value()) {
        TORCH_CHECK(SUPPORT_SCALES_DTYPE_LIST.find(scales_dtype.value()) != SUPPORT_SCALES_DTYPE_LIST.end(),
                    "The optional parameter scales_dtype only supports float/float8_e8m0, but now is ",
                    op_plugin::utils::DTypeToString(scales_dtype.value()), ".", OPS_ERROR(ErrCode::VALUE));
    }

    // ---- output allocation: same sizes as x, dtype from output_dtype or BFloat16 ----
    const int64_t output_dtype_code = output_dtype.value_or(static_cast<int64_t>(at::ScalarType::BFloat16));
    const aclDataType output_acl_type = c10_npu::GetAclDataType(output_dtype_code);
    const at::ScalarType output_scalar_type = npu_preparation::convert_to_scalar_type(output_acl_type);
    at::Tensor output_tensor = npu_preparation::apply_tensor_without_format(
        op_infer::array_to_small_vector(x.sizes()), c10::dtype(output_scalar_type));

    // A string_view is not guaranteed to be NUL-terminated, so copy both names into std::string
    // before handing them over as C strings.
    std::string group_name = std::string(hcom);
    std::string reduce_op_name = std::string(reduce_op.value_or("sum"));
    char *group_ptr = const_cast<char *>(group_name.c_str());
    char *reduce_op_ptr = const_cast<char *>(reduce_op_name.c_str());

    TensorWrapper x_wrapper = make_wrapper(x, x_dtype);
    TensorWrapper scales_wrapper = make_wrapper(scales, scales_dtype);
    EXEC_NPU_CMD(aclnnQuantAllReduce, x_wrapper, scales_wrapper, group_ptr, reduce_op_ptr, output_tensor);
    return output_tensor;
}
}