#include "op_plugin/OpApiInterface.h"
#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/utils/op_api_common.h"
namespace op_api {
using npu_preparation = at_npu::native::OpPreparation;
namespace {
// Divisor applied to the last dimension of the output shape when the output
// scalar type is Int (presumably eight 4-bit values per 32-bit element).
constexpr int64_t INT4_NUMS_IN_INT32_SPACE = 8;

// Lookup table: integer `dtype` argument -> at::ScalarType used to allocate
// the output tensor. A `dtype` that is not a key here is rejected by the
// caller that consults this table.
std::map<int64_t, at::ScalarType> QUANTIZE_SUPPORT_MAP = {
    // Quantized dtypes are mapped to the plain Byte / Char / Int types.
    {static_cast<int64_t>(at::kQUInt8), at::ScalarType::Byte},
    {static_cast<int64_t>(at::kQInt8), at::ScalarType::Char},
    {static_cast<int64_t>(at::kQInt32), at::ScalarType::Int},
    // Plain integer dtypes map to themselves.
    {static_cast<int64_t>(at::kByte), at::ScalarType::Byte},
    {static_cast<int64_t>(at::kChar), at::ScalarType::Char},
    {static_cast<int64_t>(at::kInt), at::ScalarType::Int},
    // 8-bit float dtypes map to themselves.
    {static_cast<int64_t>(at::kFloat8_e4m3fn), at::ScalarType::Float8_e4m3fn},
    {static_cast<int64_t>(at::kFloat8_e5m2), at::ScalarType::Float8_e5m2},
    // HiFloat8 is keyed by its c10_npu::DType code and mapped to Byte.
    {static_cast<int64_t>(c10_npu::DType::HIFLOAT8), at::ScalarType::Byte},
};
} // namespace
// Quantizes `self` by launching aclnnQuantize.
//
// Parameters:
//   self            - tensor to quantize; the output has the same sizes.
//   scales          - forwarded unchanged to aclnnQuantize.
//   zero_points_opt - optional tensor, forwarded unchanged to aclnnQuantize.
//   dtype           - integer dtype code; must be a key of QUANTIZE_SUPPORT_MAP.
//   axis            - forwarded unchanged to aclnnQuantize.
//
// Returns a newly allocated tensor holding the quantized result.
// A TORCH_CHECK failure is raised when `dtype` is not in QUANTIZE_SUPPORT_MAP.
at::Tensor npu_quantize_by_kernel(
    const at::Tensor& self,
    const at::Tensor& scales,
    const c10::optional<at::Tensor>& zero_points_opt,
    int64_t dtype,
    int64_t axis)
{
    // acl_op::npu_quantize is the alternative handed to DO_COMPATIBILITY for
    // aclnnQuantize (presumably used when the aclnn kernel is unavailable;
    // the macro is defined elsewhere -- confirm).
    DO_COMPATIBILITY(aclnnQuantize,
                     acl_op::npu_quantize(self, scales, zero_points_opt, dtype, axis));

    // Single lookup: validate `dtype` and read the mapped scalar type from the
    // same iterator, so the table is never touched through operator[].
    const auto dtype_entry = QUANTIZE_SUPPORT_MAP.find(dtype);
    TORCH_CHECK(dtype_entry != QUANTIZE_SUPPORT_MAP.end(),
                "Param (dtype) must be Int8, UInt8, Int32, HiFloat8, Float8_e4m3fn, Float8_e5m2" + OPS_ERROR(ErrCode::TYPE));
    const at::ScalarType scalarDtype = dtype_entry->second;

    // HiFloat8 is allocated as Byte (see the table), so its ACL type has to be
    // set explicitly rather than derived from the scalar type.
    aclDataType yAclType = npu_preparation::convert_to_acl_data_type(scalarDtype);
    if (dtype == static_cast<int64_t>(c10_npu::DType::HIFLOAT8)) {
        yAclType = ACL_HIFLOAT8;
    }

    auto output_shape = op_infer::array_to_small_vector(self.sizes());
    at::Tensor y = npu_preparation::apply_tensor_without_format(output_shape, self.options().dtype(scalarDtype));
    // Pair the output tensor with the ACL type chosen above before handing it
    // to the kernel.
    TensorWrapper y_wrapper = {y, yAclType};
    EXEC_NPU_CMD(aclnnQuantize, self, scales, zero_points_opt, yAclType, axis, y_wrapper);
    return y;
}
// Quantizes `self` by launching aclnnAscendQuantV3 when that kernel is
// available, otherwise aclnnAscendQuant.
//
// Parameters:
//   self            - tensor to quantize.
//   scales          - forwarded unchanged to the kernel.
//   zero_points_opt - optional tensor, forwarded unchanged to the kernel.
//   dtype           - integer dtype code. torch.qint8 selects ACL_INT8,
//                     torch.quint4x2 selects ACL_INT32, anything else is
//                     resolved through c10_npu::GetAclDataType.
//   axis            - only used by aclnnAscendQuantV3, after the adjustment
//                     described at the call site.
//
// Returns a newly allocated tensor holding the quantized result. When the
// output scalar type is Int, the last dimension of the result is the last
// dimension of `self` divided by INT4_NUMS_IN_INT32_SPACE.
// A TORCH_CHECK failure is raised on the Int path when `self` has no
// dimensions or its last dimension is not a multiple of
// INT4_NUMS_IN_INT32_SPACE.
at::Tensor npu_quantize_by_ascend_quant(
    const at::Tensor& self,
    const at::Tensor& scales,
    const c10::optional<at::Tensor>& zero_points_opt,
    int64_t dtype,
    int64_t axis)
{
    at::ScalarType scalarDtype = at::ScalarType::Undefined;
    aclDataType yAclType = ACL_INT8;
    at::Tensor result;

    // Resolve the ACL output type and the scalar type used for allocation.
    if (dtype == static_cast<int64_t>(at::kQInt8)) {
        ASCEND_LOGI("[npu_quantize]: Parameter(dtype) is torch.qint8, setting aclTensor out dtype to: %s",
                    at_npu::native::AclDataTypeToString(aclDataType::ACL_INT8).c_str());
        yAclType = ACL_INT8;
        scalarDtype = at::ScalarType::Char;
    } else if (dtype == static_cast<int64_t>(at::ScalarType::QUInt4x2)) {
        ASCEND_LOGI("[npu_quantize]: Parameter(dtype) is torch.quint4x2, setting aclTensor out dtype to: %s",
                    at_npu::native::AclDataTypeToString(aclDataType::ACL_INT32).c_str());
        yAclType = ACL_INT32;
        scalarDtype = at::ScalarType::Int;
    } else {
        ASCEND_LOGI("[npu_quantize]: Getting aclTensor out dtype by Parameter(dtype): %ld", dtype);
        yAclType = c10_npu::GetAclDataType(dtype);
        ASCEND_LOGI("[npu_quantize]: Setting aclTensor out to: %s", at_npu::native::AclDataTypeToString(yAclType).c_str());
        scalarDtype = npu_preparation::convert_to_scalar_type(yAclType);
    }

    if (scalarDtype == at::ScalarType::Int) {
        // Int output: the last dimension shrinks by INT4_NUMS_IN_INT32_SPACE
        // (presumably eight int4 values are packed into each int32 element).
        auto output_shape = op_infer::array_to_small_vector(self.sizes());
        const auto x_dim_num = self.dim();
        // A 0-dim input has no last dimension; without this guard the index
        // below would be -1 and read outside the shape vector.
        TORCH_CHECK(x_dim_num > 0,
                    "Input must have at least one dimension" + OPS_ERROR(ErrCode::PARAM));
        const auto last_dim = x_dim_num - 1;
        TORCH_CHECK(output_shape[last_dim] % INT4_NUMS_IN_INT32_SPACE == 0,
                    "Input shape last dim must be divded by 8" + OPS_ERROR(ErrCode::PARAM));
        output_shape[last_dim] /= INT4_NUMS_IN_INT32_SPACE;

        // Keep FRACTAL_NZ when the input already uses it; otherwise allocate
        // without an explicit format.
        int64_t npu_format = at_npu::native::custom_ops::get_npu_format(self);
        if (npu_format == ACL_FORMAT_FRACTAL_NZ) {
            result = npu_preparation::apply_tensor_with_format(
                output_shape, self.options().dtype(scalarDtype), ACL_FORMAT_FRACTAL_NZ, true);
        } else {
            result = npu_preparation::apply_tensor_without_format(output_shape, self.options().dtype(scalarDtype));
        }
    } else {
        result = npu_preparation::apply_tensor(self, self.options().dtype(scalarDtype));
    }

    // Pair the output tensor with the ACL type chosen above before handing it
    // to the kernel.
    TensorWrapper y_wrapper = {result, yAclType};
    const bool sqrt_mode = false;
    // Function-local static: kernel availability is queried once and reused.
    static const bool is_ascend_quant_V3_available = check_aclnn_kernel_available("aclnnAscendQuantV3");
    if (!is_ascend_quant_V3_available) {
        EXEC_NPU_CMD(aclnnAscendQuant, self, scales, zero_points_opt, sqrt_mode, "round", yAclType, y_wrapper);
    } else {
        // Values below -1 are passed through unchanged; every other value is
        // replaced by -1.
        // NOTE(review): presumably aclnnAscendQuantV3 only accepts a limited
        // set of negative axes -- confirm against the kernel documentation.
        axis = axis < -1 ? axis : -1;
        EXEC_NPU_CMD(aclnnAscendQuantV3, self, scales, zero_points_opt, sqrt_mode, "round", yAclType, axis, y_wrapper);
    }
    return result;
}
// Entry point for quantization: picks one of the two implementations in this
// file based on `div_mode` and forwards all other arguments unchanged.
//
//   div_mode == true  -> npu_quantize_by_kernel
//   div_mode == false -> npu_quantize_by_ascend_quant
//
// Returns whatever the selected implementation returns.
at::Tensor npu_quantize(
    const at::Tensor& self,
    const at::Tensor& scales,
    const c10::optional<at::Tensor>& zero_points_opt,
    int64_t dtype,
    int64_t axis,
    bool div_mode)
{
    // Only the selected branch of the conditional is evaluated.
    return div_mode
        ? npu_quantize_by_kernel(self, scales, zero_points_opt, dtype, axis)
        : npu_quantize_by_ascend_quant(self, scales, zero_points_opt, dtype, axis);
}
}