#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
#include "torch_npu/csrc/custom_dtype/Init.h"
#include "torch_npu/csrc/framework/utils/InternalFormatOpAdapter.h"
namespace op_api {
// Shorthand for the OpPreparation helpers (dtype conversion etc.) used by the functions below.
using npu_preparation = at_npu::native::OpPreparation;
// Picks the dtype of the operator's result tensor.
//   - a bfloat16 dequant_scale forces a bfloat16 result;
//   - otherwise an int8 (kChar) x1 yields Half;
//   - otherwise the result keeps x1's dtype.
at::ScalarType get_output_dtype(const at::Tensor& x1, const c10::optional<at::Tensor>& dequant_scale)
{
    const at::ScalarType x1_type = x1.scalar_type();

    // The dequant scale, when it is bfloat16, overrides whatever x1 implies.
    if (dequant_scale.has_value() && dequant_scale.value().scalar_type() == at::kBFloat16) {
        return at::kBFloat16;
    }
    if (x1_type == at::kChar) {
        return at::ScalarType::Half;
    }
    return x1_type;
}
// Validates the tensor arguments of the matmul + all-reduce operator.
//
// Checks performed (each failure raises through TORCH_CHECK):
//   - x1 has at least 2 dimensions, x2 is exactly 2D, and their K dimensions agree;
//   - the (x1, x2) dtype combination selects a mode, and the optional scale/offset
//     tensors that do not belong to that mode are absent:
//       * both integral      -> both must be int8; antiquant_scale/offset must be absent
//       * only x2 integral   -> x2 must be int8; dequant_scale must be absent
//       * neither integral   -> antiquant_scale/offset and dequant_scale must be absent
//       * only x1 integral   -> rejected
//   - x3, when given, has the output shape and the output dtype;
//   - pertoken_scale, when given, requires dequant_scale, is 1D, Float, and its length
//     equals the product of all x1 dimensions except the last;
//   - comm_quant_scale_1/2 are given together or not at all; when given they are both
//     [n] or both [1, n] (n = x2.size(1)) and have the output dtype.
void check_params(const at::Tensor& x1, const at::Tensor& x2,
                  const c10::optional<at::Tensor>& antiquant_scale,
                  const c10::optional<at::Tensor>& antiquant_offset,
                  const c10::optional<at::Tensor>& x3,
                  const c10::optional<at::Tensor>& dequant_scale,
                  const c10::optional<at::Tensor>& pertoken_scale,
                  const c10::optional<at::Tensor>& comm_quant_scale_1,
                  const c10::optional<at::Tensor>& comm_quant_scale_2)
{
    // Reject low-rank x1 up front so that the x1.dim() - 1 indexing below is always valid
    // and the caller sees a shape diagnostic rather than a generic index error.
    TORCH_CHECK(x1.dim() > 1, "x1 must have at least 2 dimensions, but got: ", x1.dim(), " dimensions.",
                OPS_ERROR(ErrCode::VALUE));
    TORCH_CHECK(x2.dim() == 2, "x2 needs to be 2D, but got: ", x2.dim(), "D", OPS_ERROR(ErrCode::VALUE));
    TORCH_CHECK(x1.size(x1.dim() - 1) == x2.size(0), "K of x1 and x2 should be same, but they are x1_k: ",
                x1.size(x1.dim() - 1), ", x2_k: ", x2.size(0), OPS_ERROR(ErrCode::VALUE));

    const bool x1_integral = isIntegralType(x1.scalar_type());
    const bool x2_integral = isIntegralType(x2.scalar_type());
    if (x1_integral && x2_integral) {
        TORCH_CHECK(x1.scalar_type() == at::kChar, "x1 must be an int8 tensor for quant.", OPS_ERROR(ErrCode::TYPE));
        TORCH_CHECK(x2.scalar_type() == at::kChar, "x2 must be an int8 tensor for quant.", OPS_ERROR(ErrCode::TYPE));
        TORCH_CHECK((!antiquant_scale.has_value() && !antiquant_offset.has_value()),
                    "when both dtype of x1 and dtype of x2 are equal to int8, "
                    "antiquantScale, antiquantOffset should both be null", OPS_ERROR(ErrCode::TYPE));
    } else if (!x1_integral && x2_integral) {
        TORCH_CHECK(x2.scalar_type() == at::kChar, "x2 must be an int8 tensor for weight quant.",
                    OPS_ERROR(ErrCode::TYPE));
        TORCH_CHECK((!dequant_scale.has_value()),
                    "when only dtype of x2 is equal to int8, dequantScale should be null", OPS_ERROR(ErrCode::TYPE));
    } else if (!x1_integral && !x2_integral) {
        TORCH_CHECK((!antiquant_scale.has_value() && !antiquant_offset.has_value() && !dequant_scale.has_value()),
                    "when neither dtype of x1 or dtype of x2 is equal to int8, "
                    "antiquantScale, antiquantOffset and dequantScale should all be null", OPS_ERROR(ErrCode::TYPE));
    } else {
        // Integral x1 combined with a non-integral x2 is not a supported mode.
        TORCH_CHECK(false, "when neither dtype of x1 or dtype of x2 is valid. ", OPS_ERROR(ErrCode::TYPE));
    }

    if (x3.has_value()) {
        // x3 must look exactly like the result: x1's shape with the last dim replaced by N.
        auto output_size = op_infer::array_to_small_vector(x1.sizes());
        output_size[x1.dim() - 1] = x2.size(1);
        auto output_dtype = get_output_dtype(x1, dequant_scale);
        const at::Tensor& x3_real = x3.value();
        TORCH_CHECK(x3_real.sizes().equals(output_size), "x3 with shape ", x3_real.sizes(),
                    " doesn't match the output shape ", output_size, OPS_ERROR(ErrCode::PARAM));
        TORCH_CHECK(x3_real.scalar_type() == output_dtype, "x3 with dtype ", x3_real.scalar_type(),
                    " doesn't match the output dtype ", output_dtype, OPS_ERROR(ErrCode::PARAM));
    }

    if (pertoken_scale.has_value()) {
        TORCH_CHECK((dequant_scale.has_value()),
                    "when has pertoken_scale, dequantScale shoulden't be null", OPS_ERROR(ErrCode::TYPE));
        const at::Tensor& pertoken_scale_real = pertoken_scale.value();
        TORCH_CHECK(pertoken_scale_real.dim() == 1, "pertoken_scale with shape ", pertoken_scale_real.sizes(),
                    " pertoken_scale dim should be 1.", OPS_ERROR(ErrCode::PARAM));
        // M is the product of every x1 dimension except the trailing K dimension.
        auto x1_size = op_infer::array_to_small_vector(x1.sizes());
        int64_t x1_m = 1;
        for (int64_t dim = 0; dim < x1.dim() - 1; ++dim) {
            x1_m *= x1_size[dim];
        }
        TORCH_CHECK(x1_m == pertoken_scale_real.size(0), "pertoken_scale with shape ", pertoken_scale_real.sizes(),
                    " doesn't match the input shape ", x1_size, OPS_ERROR(ErrCode::PARAM));
        TORCH_CHECK(pertoken_scale_real.scalar_type() == at::ScalarType::Float,
                    "pertoken_scale with dtype ", pertoken_scale_real.scalar_type(),
                    " doesn't match the output dtype ", at::ScalarType::Float, OPS_ERROR(ErrCode::PARAM));
    }

    const bool has_comm_scale_1 = comm_quant_scale_1.has_value();
    const bool has_comm_scale_2 = comm_quant_scale_2.has_value();
    TORCH_CHECK(has_comm_scale_1 == has_comm_scale_2,
                "comm_quant_scale_1 and comm_quant_scale_2 should both be null or not null", OPS_ERROR(ErrCode::TYPE));
    if (has_comm_scale_1 && has_comm_scale_2) {
        const at::Tensor& comm_quant_scale_1_real = comm_quant_scale_1.value();
        const at::Tensor& comm_quant_scale_2_real = comm_quant_scale_2.value();
        const int64_t n = x2.size(1);

        const bool both_2d = comm_quant_scale_1_real.dim() == 2 && comm_quant_scale_2_real.dim() == 2;
        const bool both_1d = comm_quant_scale_1_real.dim() == 1 && comm_quant_scale_2_real.dim() == 1;
        // A separator follows each tensor name so the rank is not glued onto the trailing
        // digit of the name in the rendered message.
        TORCH_CHECK(both_2d || both_1d,
                    "comm_quant_scale_1 and comm_quant_scale_2 both need to be 1D or 2D, but got: comm_quant_scale_1 ",
                    comm_quant_scale_1_real.dim(), "D, comm_quant_scale_2 ", comm_quant_scale_2_real.dim(), "D",
                    OPS_ERROR(ErrCode::VALUE));

        // Ranks are known to be equal here, so testing the rank of scale_1 is sufficient.
        const bool shape_is_1_by_n = comm_quant_scale_1_real.dim() == 2 &&
                                     comm_quant_scale_1_real.size(0) == 1 && comm_quant_scale_2_real.size(0) == 1 &&
                                     comm_quant_scale_1_real.size(1) == n && comm_quant_scale_2_real.size(1) == n;
        const bool shape_is_n = comm_quant_scale_1_real.dim() == 1 &&
                                comm_quant_scale_1_real.size(0) == n && comm_quant_scale_2_real.size(0) == n;
        TORCH_CHECK(shape_is_1_by_n || shape_is_n,
                    "comm_quant_scale_1 and comm_quant_scale_2 shape do not match [1,n] or [n], n=", n,
                    ", comm_quant_scale_1 shape: ", comm_quant_scale_1_real.sizes(),
                    ", comm_quant_scale_2 shape: ", comm_quant_scale_2_real.sizes(), OPS_ERROR(ErrCode::PARAM));

        auto output_dtype = get_output_dtype(x1, dequant_scale);
        TORCH_CHECK(comm_quant_scale_1_real.scalar_type() == output_dtype &&
                    comm_quant_scale_2_real.scalar_type() == output_dtype,
                    "comm_quant_scale_1 with dtype ", comm_quant_scale_1_real.scalar_type(),
                    ", comm_quant_scale_2 with dtype ", comm_quant_scale_2_real.scalar_type(),
                    " doesn't match the output dtype ", output_dtype, OPS_ERROR(ErrCode::PARAM));
    }
}
// Dtype validation used on the aclnn-only path.
//
// is_a8/is_a4 flag x1 as fp8/fp4 and is_w8/is_w4 flag x2 likewise (the caller computes them).
// The function raises through TORCH_CHECK when:
//   - x1 is fp8 (resp. fp4) but x2 is not;
//   - both tensors are integral but either is not one of kChar / kInt / kByte;
//   - only x2 is integral (or only x2 is fp8) and x2 is neither kChar, kInt nor fp8;
//   - x1 is integral while x2 is not.
// Any other combination is accepted without further checks.
void check_a5_required_inputs(const at::Tensor& x1, const at::Tensor& x2, bool is_a8, bool is_w8, bool is_a4, bool is_w4)
{
    // A low-precision float activation needs a weight of the matching family.
    if ((is_a8 && !is_w8) || (is_a4 && !is_w4)) {
        TORCH_CHECK(false, "when dtype of x1 is fp8 or fp4, dtype of x2 also should be fp8 or fp4",
                    OPS_ERROR(ErrCode::TYPE));
        return;
    }

    const at::ScalarType activation_type = x1.scalar_type();
    const at::ScalarType weight_type = x2.scalar_type();
    const bool activation_integral = isIntegralType(activation_type);
    const bool weight_integral = isIntegralType(weight_type);

    // Full quantization: both operands integral.
    if (activation_integral && weight_integral) {
        TORCH_CHECK((activation_type == at::kChar) || (activation_type == at::kInt) || (activation_type == at::kByte),
                    "x1 must be an int8 or uint8 tensor for quant, but x1 actual type is ",
                    c10::toString(activation_type),
                    OPS_ERROR(ErrCode::TYPE));
        TORCH_CHECK((weight_type == at::kChar) || (weight_type == at::kInt) || (weight_type == at::kByte),
                    "x2 must be an int8 or uint8 tensor for quant, but x2 actual type is ",
                    c10::toString(weight_type),
                    OPS_ERROR(ErrCode::TYPE));
        return;
    }

    // Weight-only quantization: integral weight with float activation, or fp8 weight only.
    if ((!activation_integral && weight_integral) || (!is_a8 && is_w8)) {
        TORCH_CHECK((weight_type == at::kChar) || (weight_type == at::kInt) || is_w8,
                    "x2 must be an int8 or uint8 tensor for weight quant, but x2 actual type is ",
                    c10::toString(weight_type),
                    OPS_ERROR(ErrCode::TYPE));
        return;
    }

    // Integral activation paired with a non-integral weight is never valid.
    if (activation_integral && !weight_integral) {
        TORCH_CHECK(false, "when neither dtype of x1 or dtype of x2 is vaild, current x1 dtype is ",
                    c10::toString(activation_type), " and x2 dtype is ", c10::toString(weight_type),
                    OPS_ERROR(ErrCode::TYPE));
    }
}
// Fused matmul + all-reduce entry point.
//
// Validates the inputs, allocates the result tensor (x1's shape with the last dimension
// replaced by x2's last dimension) and launches exactly one aclnn kernel chosen from the
// operand dtypes:
//   1. x1 is fp8/fp4                              -> aclnnQuantMatmulAllReduceV4 / V5
//   2. neither operand integral, no fp8/fp4       -> aclnnMatmulAllReduce / V2 / V3
//   3. both operands integral, no fp8/fp4         -> aclnnQuantMatmulAllReduce / V2 / V3 / V5
//   4. only x2 integral, or only x2 is fp8        -> aclnnWeightQuantMatmulAllReduce / V2
// The variants taking comm_mode are used whenever comm_mode is supplied and require the
// corresponding kernel to be available in the installed CANN version.
//
// x1_dtype / x2_dtype / y_dtype / dequant_scale_dtype / pertoken_scale_dtype, when given,
// override the dtype that is reported to the kernel for the respective tensor.
// y_dtype is mandatory when x1 is fp8 or fp4.
//
// NOTE(review): hcom, reduce_op and comm_mode are forwarded as raw char* obtained from
// string_view::data(), which is not guaranteed to be null-terminated — presumably callers
// always pass views over null-terminated strings; confirm.
at::Tensor npu_mm_all_reduce_base(const at::Tensor& x1, const at::Tensor& x2, c10::string_view hcom,
                                  c10::string_view reduce_op, const c10::optional<at::Tensor>& bias,
                                  const c10::optional<at::Tensor>& antiquant_scale,
                                  const c10::optional<at::Tensor>& antiquant_offset,
                                  const c10::optional<at::Tensor>& x3, const c10::optional<at::Tensor>& dequant_scale,
                                  const c10::optional<at::Tensor>& pertoken_scale,
                                  const c10::optional<at::Tensor>& comm_quant_scale_1,
                                  const c10::optional<at::Tensor>& comm_quant_scale_2,
                                  int64_t antiquant_group_size, int64_t comm_turn,
                                  c10::OptionalIntArrayRef group_sizes, c10::optional<int64_t> y_dtype,
                                  c10::optional<int64_t> x1_dtype, c10::optional<int64_t> x2_dtype,
                                  c10::optional<int64_t> dequant_scale_dtype,
                                  c10::optional<int64_t> pertoken_scale_dtype,
                                  int64_t comm_quant_mode, c10::optional<c10::string_view> comm_mode)
{
    bool is_comm_mode = comm_mode.has_value();
    c10::string_view comm_mode_value = comm_mode.value_or("");
    char *comm_mode_ptr = const_cast<char *>(comm_mode_value.data());

    // An operand counts as fp8 when it is tagged HIFLOAT8 through the dtype override or
    // when its tensor dtype is one of the native float8 types.
    bool is_a8 = (x1_dtype.has_value() && x1_dtype.value() == static_cast<int>(c10_npu::DType::HIFLOAT8)) ||
                 (x1.scalar_type() == at::kFloat8_e4m3fn) || (x1.scalar_type() == at::kFloat8_e5m2);
    bool is_w8 = (x2_dtype.has_value() && x2_dtype.value() == static_cast<int>(c10_npu::DType::HIFLOAT8)) ||
                 (x2.scalar_type() == at::kFloat8_e4m3fn) || (x2.scalar_type() == at::kFloat8_e5m2);
    // fp4 is only recognised through the dtype override.
    bool is_a4 = x1_dtype.has_value() &&
                 ((x1_dtype.value() == static_cast<int>(c10_npu::DType::FLOAT4_E1M2)) ||
                  (x1_dtype.value() == static_cast<int>(c10_npu::DType::FLOAT4_E2M1)));
    bool is_w4 = x2_dtype.has_value() &&
                 ((x2_dtype.value() == static_cast<int>(c10_npu::DType::FLOAT4_E1M2)) ||
                  (x2_dtype.value() == static_cast<int>(c10_npu::DType::FLOAT4_E2M1)));

    if (!c10_npu::IsAclnnOnly()) {
        check_params(x1, x2, antiquant_scale, antiquant_offset, x3, dequant_scale, pertoken_scale, comm_quant_scale_1,
                     comm_quant_scale_2);
    } else {
        check_a5_required_inputs(x1, x2, is_a8, is_w8, is_a4, is_w4);
    }

    at::IntArrayRef group_size_list = group_sizes.value_or(at::IntArrayRef{});
    int64_t group_size = op_plugin::utils::check_and_get_group_size(group_size_list);
    TORCH_CHECK(group_size != -1, "Invalid group_sizes.", OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(x1.dim() > 1, "x1 must have at least 2 dimensions, but got: ", x1.dim(), " dimensions.",
                OPS_ERROR(ErrCode::VALUE));
    TORCH_CHECK(x2.dim() > 1, "x2 must have at least 2 dimensions, but got: ", x2.dim(), " dimensions.",
                OPS_ERROR(ErrCode::VALUE));

    // Result shape: x1's shape with the trailing K dimension replaced by x2's trailing N.
    auto output_size = op_infer::array_to_small_vector(x1.sizes());
    output_size[x1.dim() - 1] = x2.size(x2.dim() - 1);
    auto output_dtype = get_output_dtype(x1, dequant_scale);
    if (is_a8 || is_a4) {
        // For fp8/fp4 activations the result dtype cannot be derived from x1; the caller must name it.
        TORCH_CHECK(y_dtype.has_value(), "input dtype is fp8 or fp4, but no input y_dtype",
                    OPS_ERROR(ErrCode::PARAM));
        auto output_acltype = c10_npu::GetAclDataType(y_dtype.value());
        output_dtype = npu_preparation::convert_to_scalar_type(output_acltype);
    }
    bool is_fp8 = is_a8 || is_w8;
    auto result = npu_preparation::apply_tensor_without_format(output_size, x1.options().dtype(output_dtype));

    char* reduce_op_ptr = const_cast<char*>(reduce_op.data());
    char* hcom_ptr = const_cast<char*>(hcom.data());
    // Absent optionals are forwarded to the kernels as undefined tensors.
    const at::Tensor& bias_real = bias.value_or(at::Tensor());
    const at::Tensor& antiquant_scale_real = antiquant_scale.value_or(at::Tensor());
    const at::Tensor& antiquant_offset_real = antiquant_offset.value_or(at::Tensor());
    const at::Tensor& x3_real = x3.value_or(at::Tensor());
    const at::Tensor& dequant_scale_real = dequant_scale.value_or(at::Tensor());
    const at::Tensor& pertoken_scale_real = pertoken_scale.value_or(at::Tensor());
    const at::Tensor& comm_quant_scale_1_real = comm_quant_scale_1.value_or(at::Tensor());
    const at::Tensor& comm_quant_scale_2_real = comm_quant_scale_2.value_or(at::Tensor());
    int64_t stream_mode = ACL_STOP_ON_FAILURE;

    // Wrappers pair a tensor with the acl dtype to report for it: the explicit override
    // when present, otherwise the tensor's own dtype (ACL_FLOAT for absent scale tensors).
    TensorWrapper x1_wrapper = {x1, (x1_dtype.has_value()) ?
                                c10_npu::GetAclDataType(x1_dtype.value()) :
                                npu_preparation::convert_to_acl_data_type(x1.scalar_type())};
    TensorWrapper x2_wrapper = {x2, (x2_dtype.has_value()) ?
                                c10_npu::GetAclDataType(x2_dtype.value()) :
                                npu_preparation::convert_to_acl_data_type(x2.scalar_type())};
    auto pertoken_scale_tensor_dtype = pertoken_scale.has_value() ?
                                       npu_preparation::convert_to_acl_data_type(pertoken_scale_real.scalar_type()) :
                                       aclDataType::ACL_FLOAT;
    TensorWrapper pertoken_scale_wrapper = {pertoken_scale_real, (pertoken_scale_dtype.has_value()) ?
                                            c10_npu::GetAclDataType(pertoken_scale_dtype.value()) :
                                            pertoken_scale_tensor_dtype};
    auto dequant_scale_tensor_dtype = dequant_scale.has_value() ?
                                      npu_preparation::convert_to_acl_data_type(dequant_scale_real.scalar_type()) :
                                      aclDataType::ACL_FLOAT;
    TensorWrapper dequant_scale_wrapper = {dequant_scale_real, (dequant_scale_dtype.has_value()) ?
                                           c10_npu::GetAclDataType(dequant_scale_dtype.value()) :
                                           dequant_scale_tensor_dtype};

    const bool x1_integral = isIntegralType(x1.scalar_type());
    const bool x2_integral = isIntegralType(x2.scalar_type());

    // The branches form a single if / else-if chain so that one call launches one kernel.
    // As independent `if`s, the weight-quant condition could also match an fp8/fp4
    // activation stored in a non-integral dtype next to an integral-storage weight, which
    // launched a second kernel writing into the same result.
    if (is_a8 || is_a4) {
        TORCH_CHECK(dequant_scale.has_value(), "a8w8 or a4w4 scense should input dequant_scale",
                    OPS_ERROR(ErrCode::PARAM));
        if (is_comm_mode) {
            TORCH_CHECK(check_aclnn_kernel_available("aclnnQuantMatmulAllReduceV5"),
                        "In this CANN Version, comm_mode is not supported, please upgrade CANN Version.",
                        OPS_ERROR(ErrCode::PARAM));
            EXEC_NPU_CMD(aclnnQuantMatmulAllReduceV5, x1_wrapper, x2_wrapper, bias_real, x3_real, pertoken_scale_wrapper,
                         dequant_scale_wrapper, comm_quant_scale_1_real, comm_quant_scale_2_real, hcom_ptr, reduce_op_ptr,
                         comm_mode_ptr, comm_turn, stream_mode, group_size, comm_quant_mode, result);
        } else {
            EXEC_NPU_CMD(aclnnQuantMatmulAllReduceV4, x1_wrapper, x2_wrapper, bias_real, x3_real, pertoken_scale_wrapper,
                         dequant_scale_wrapper, comm_quant_scale_1_real, comm_quant_scale_2_real, hcom_ptr, reduce_op_ptr,
                         comm_turn, stream_mode, group_size, comm_quant_mode, result);
        }
    } else if (!x1_integral && !x2_integral && !is_fp8 && !is_a4) {
        // Plain (non-quantized) matmul.
        if (is_comm_mode) {
            TORCH_CHECK(check_aclnn_kernel_available("aclnnMatmulAllReduceV3"),
                        "In this CANN Version, comm_mode is not supported, please upgrade CANN Version.",
                        OPS_ERROR(ErrCode::PARAM));
            EXEC_NPU_CMD(aclnnMatmulAllReduceV3, x1, x2, bias_real, x3_real, hcom_ptr, reduce_op_ptr,
                         comm_mode_ptr, comm_turn, stream_mode, result);
        } else {
            if (x3.has_value()) {
                EXEC_NPU_CMD(aclnnMatmulAllReduceV2, x1, x2, bias_real, x3_real, hcom_ptr, reduce_op_ptr, comm_turn,
                             stream_mode, result);
            } else {
                EXEC_NPU_CMD(aclnnMatmulAllReduce, x1, x2, bias_real, hcom_ptr, reduce_op_ptr, comm_turn, stream_mode,
                             result);
            }
        }
    } else if (x1_integral && x2_integral && !is_fp8 && !is_a4) {
        // Full integer quantization.
        if (is_comm_mode) {
            TORCH_CHECK(check_aclnn_kernel_available("aclnnQuantMatmulAllReduceV5"),
                        "In this CANN Version, comm_mode is not supported, please upgrade CANN Version.",
                        OPS_ERROR(ErrCode::PARAM));
            EXEC_NPU_CMD(aclnnQuantMatmulAllReduceV5, x1, x2, bias_real, x3_real, pertoken_scale_real,
                         dequant_scale_real, comm_quant_scale_1_real, comm_quant_scale_2_real, hcom_ptr,
                         reduce_op_ptr, comm_mode_ptr, comm_turn, stream_mode, group_size, comm_quant_mode, result);
        } else {
            if (comm_quant_scale_1.has_value() && comm_quant_scale_2.has_value()) {
                EXEC_NPU_CMD(aclnnQuantMatmulAllReduceV3, x1, x2, bias_real, x3_real, dequant_scale_real,
                             pertoken_scale_real, comm_quant_scale_1_real, comm_quant_scale_2_real, hcom_ptr,
                             reduce_op_ptr, comm_turn, stream_mode, result);
            } else if (comm_quant_scale_1.has_value() || comm_quant_scale_2.has_value()) {
                // Reached only when exactly one of the two scales was supplied.
                TORCH_CHECK(false,
                            "comm_quant_scale_1 and comm_quant_scale_2 should both be null or not null",
                            OPS_ERROR(ErrCode::TYPE));
            } else if (pertoken_scale.has_value()) {
                EXEC_NPU_CMD(aclnnQuantMatmulAllReduceV2, x1, x2, bias_real, x3_real, dequant_scale_real,
                             pertoken_scale_real, hcom_ptr, reduce_op_ptr, comm_turn, stream_mode, result);
            } else {
                EXEC_NPU_CMD(aclnnQuantMatmulAllReduce, x1, x2, bias_real, x3_real, dequant_scale_real, hcom_ptr,
                             reduce_op_ptr, comm_turn, stream_mode, result);
            }
        }
    } else if ((!x1_integral && x2_integral) || (!is_a8 && is_w8)) {
        // Weight-only quantization.
        if (is_comm_mode) {
            TORCH_CHECK(check_aclnn_kernel_available("aclnnWeightQuantMatmulAllReduceV2"),
                        "In this CANN Version, comm_mode is not supported, please upgrade CANN Version.",
                        OPS_ERROR(ErrCode::PARAM));
            EXEC_NPU_CMD(aclnnWeightQuantMatmulAllReduceV2, x1_wrapper, x2_wrapper, bias_real, antiquant_scale_real,
                         antiquant_offset_real, x3_real, hcom_ptr, reduce_op_ptr, comm_mode_ptr, comm_turn, stream_mode,
                         antiquant_group_size, result);
        } else {
            EXEC_NPU_CMD(aclnnWeightQuantMatmulAllReduce, x1_wrapper, x2_wrapper, bias_real, antiquant_scale_real,
                         antiquant_offset_real, x3_real, hcom_ptr, reduce_op_ptr, comm_turn, stream_mode,
                         antiquant_group_size, result);
        }
    }

    FLOP_COUNT(FlopCounter::mm_flop, x1, x2);
    return result;
}
}