#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
namespace op_api {
using npu_preparation = at_npu::native::OpPreparation;
// Minimum rank accepted for x and weight; also the distance of the row dimension from the end of a shape.
constexpr int MINIMUM_SHAPE_SIZE = 2;
// Factor applied to the packed dimension of kInt/kFloat weights (eight int4 values per int32 element).
const int64_t INT4_NUMS_IN_INT32 = 8;

/**
 * @brief Weight-quantized batch matmul: validates shapes and parameters, allocates the output and dispatches
 *        to one of the aclnnWeightQuantBatchMatmul{Nz,V2,V3} kernels.
 *
 * Output shape: the batch dimensions come from x (equal ranks, after checking they match weight's) or from
 * the higher-rank operand (different ranks); the last two dimensions are [x.size(-2), N], where N is
 * weight.size(-1), multiplied by INT4_NUMS_IN_INT32 for kInt/kFloat weights that are not transposed.
 *
 * Kernel selection, in priority order:
 *   1. weight in FRACTAL_NZ / FRACTAL_NZ_C0_2 format on SoC >= Ascend950 -> aclnnWeightQuantBatchMatmulNz
 *   2. quant_scale present with dtype kFloat -> aclnnTransQuantParamV2 (to kLong), then ...V2
 *   3. inner_precise == 1 -> aclnnWeightQuantBatchMatmulV3
 *   4. otherwise -> aclnnWeightQuantBatchMatmulV2
 *
 * @param x                    Left operand, rank >= 2.
 * @param weight               Right operand, rank >= 2. If both ranks exceed 2 they must be equal.
 * @param antiquant_scale      Forwarded unchanged to the kernel.
 * @param antiquant_offset     Optional; an undefined tensor is forwarded when absent.
 * @param quant_scale          Optional; when present the result dtype is kChar.
 * @param quant_offset         Optional; only allowed together with quant_scale.
 * @param bias                 Optional; an undefined tensor is forwarded when absent.
 * @param antiquant_group_size 0, or a multiple of 32 in [32, weight_k_dim - 1].
 * @param inner_precise        0 or 1.
 * @param weight_dtype         Optional dtype override handed to make_wrapper for weight.
 * @return Newly allocated result tensor (kChar if quant_scale is given, otherwise x's dtype).
 * @throws c10::Error via TORCH_CHECK when any validation fails.
 */
at::Tensor npu_weight_quant_batchmatmul(const at::Tensor &x, const at::Tensor &weight,
                                        const at::Tensor &antiquant_scale,
                                        const c10::optional<at::Tensor> &antiquant_offset,
                                        const c10::optional<at::Tensor> &quant_scale,
                                        const c10::optional<at::Tensor> &quant_offset,
                                        const c10::optional<at::Tensor> &bias,
                                        int64_t antiquant_group_size,
                                        int64_t inner_precise, c10::optional<int64_t> weight_dtype)
{
    const bool trans_weight = op_plugin::utils::is_transpose_last_two_dims(weight);
    const auto x_dim_num = x.dim();
    const auto weight_dim_num = weight.dim();
    TORCH_CHECK(x_dim_num >= MINIMUM_SHAPE_SIZE, "x shape do not support dim num less than 2, but it is ", x_dim_num,
                OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(weight_dim_num >= MINIMUM_SHAPE_SIZE, "weight shape do not support dim num less than 2, but it is ",
                weight_dim_num, OPS_ERROR(ErrCode::PARAM));
    // Ranks may differ only when the lower-rank operand is a plain 2-D matrix.
    TORCH_CHECK(!(std::min(x_dim_num, weight_dim_num) > MINIMUM_SHAPE_SIZE && x_dim_num != weight_dim_num),
                "x dim is not the same as weight dim", OPS_ERROR(ErrCode::PARAM));

    // kInt/kFloat weights have one of their last two dimensions scaled by INT4_NUMS_IN_INT32:
    // the k dimension when the weight is transposed, the n dimension otherwise.
    // NOTE(review): presumably these dtypes carry packed low-bit values - confirm against the kernel docs.
    const bool is_packed_weight = weight.dtype() == at::kInt || weight.dtype() == at::kFloat;
    const int64_t weight_rows = weight.size(weight_dim_num - MINIMUM_SHAPE_SIZE);
    const int64_t weight_cols = weight.size(weight_dim_num - MINIMUM_SHAPE_SIZE + 1);

    const auto x_k_dim = x.size(x_dim_num - 1);
    const int64_t weight_k_dim = (is_packed_weight && trans_weight) ? weight_rows * INT4_NUMS_IN_INT32 : weight_rows;
    TORCH_CHECK(x_k_dim == weight_k_dim, "The k of x and weight should be equal. but x_k_dim is ", x_k_dim,
                ", weight_k_dim is ", weight_k_dim, OPS_ERROR(ErrCode::PARAM));

    const auto out_dim_num = std::max(x_dim_num, weight_dim_num);
    const int64_t batch_dim_num = out_dim_num - MINIMUM_SHAPE_SIZE;
    auto output_size = op_infer::array_to_small_vector(x.sizes());
    // x may have fewer dimensions than weight (2-D x with a batched weight). Grow the shape to the output
    // rank first, otherwise the indexed writes below would land past the end of the vector.
    output_size.resize(static_cast<size_t>(out_dim_num));
    output_size[batch_dim_num] = x.size(x_dim_num - MINIMUM_SHAPE_SIZE);
    output_size[batch_dim_num + 1] =
        (is_packed_weight && !trans_weight) ? weight_cols * INT4_NUMS_IN_INT32 : weight_cols;

    if (x_dim_num == weight_dim_num) {
        for (int64_t i = 0; i < batch_dim_num; i++) {
            TORCH_CHECK(x.size(i) == weight.size(i), "batch of x is diff from batch of weight",
                        OPS_ERROR(ErrCode::PARAM));
            output_size[i] = x.size(i);
        }
    } else {
        // Bind by reference: only the sizes are read, no need to copy the tensor handle.
        const at::Tensor &longer_tensor = x_dim_num > weight_dim_num ? x : weight;
        for (int64_t i = 0; i < batch_dim_num; i++) {
            output_size[i] = longer_tensor.size(i);
        }
    }

    // Absent optionals are forwarded to the kernels as undefined tensors.
    const at::Tensor &antiquant_offset_real = antiquant_offset.value_or(at::Tensor());
    const at::Tensor &quant_scale_real = quant_scale.value_or(at::Tensor());
    const at::Tensor &quant_offset_real = quant_offset.value_or(at::Tensor());
    const at::Tensor &bias_real = bias.value_or(at::Tensor());

    // The kernels receive the group size as int. Validate the original 64-bit value and require the
    // narrowing to round-trip, so an out-of-range input cannot wrap into an accepted value.
    int antiquant_group_size_real = static_cast<int>(antiquant_group_size);
    const bool group_size_fits_int = static_cast<int64_t>(antiquant_group_size_real) == antiquant_group_size;
    const bool is_group_size_valid = group_size_fits_int &&
        (antiquant_group_size == 0 ||
         (antiquant_group_size >= 32 && antiquant_group_size <= weight_k_dim - 1 && antiquant_group_size % 32 == 0));
    TORCH_CHECK(is_group_size_valid,
                "antiquant_group_size can be either 0 or a multiple of 32 within the range 32 to weight_k_dim - 1.",
                OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK((quant_scale.has_value() || !quant_offset.has_value()),
                "Quantization parameters are incorrectly set, quant_offset cannot exist in isolation from quant_scale",
                OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK((inner_precise == 0 || inner_precise == 1),
                "inner_precise only support 0 or 1. but is:", inner_precise,
                OPS_ERROR(ErrCode::PARAM));

    // Quantized output (quant_scale given) is kChar; otherwise the result keeps x's dtype.
    c10::TensorOptions options =
        quant_scale.has_value() ? x.options().dtype(at::kChar) : x.options().dtype(x.scalar_type());
    at::Tensor result = npu_preparation::apply_tensor_without_format(output_size, options);
    TensorWrapper weight_wrapper = make_wrapper(weight, weight_dtype);

    int64_t weight_format = at_npu::native::custom_ops::get_npu_format(weight);
    const bool is_weight_nz = (weight_format == ACL_FORMAT_FRACTAL_NZ) ||
                              (weight_format == ACL_FORMAT_FRACTAL_NZ_C0_2);
    if (is_weight_nz && c10_npu::GetSocVersion() >= c10_npu::SocVersion::Ascend950) {
        // Kernel availability is probed once and cached for the lifetime of the process.
        static const bool is_weight_quant_matmul_nz_available =
            check_aclnn_kernel_available("aclnnWeightQuantBatchMatmulNz");
        TORCH_CHECK(is_weight_quant_matmul_nz_available,
                    "Get aclnnWeightQuantBatchMatmulNz or aclnnWeightQuantBatchMatmulNzGetWorkspaceSize failed, only "
                    "aclnnWeightQuantBatchMatmulNz support weight's format is nz, please upgrade CANN.",
                    OPS_ERROR(ErrCode::PARAM));
        // NOTE(review): this path forwards quant_scale as-is (no aclnnTransQuantParamV2 conversion) and does
        // not pass inner_precise - confirm this matches the Nz kernel contract.
        EXEC_NPU_CMD(aclnnWeightQuantBatchMatmulNz, x, weight_wrapper, antiquant_scale, antiquant_offset_real,
                     quant_scale_real, quant_offset_real, bias_real, antiquant_group_size_real, result);
    } else if (quant_scale.has_value() && quant_scale_real.dtype() == at::kFloat) {
        // A float quant_scale is first converted into a kLong tensor of the same shape for the V2 kernel.
        auto quant_scale_output_size = op_infer::array_to_small_vector(quant_scale_real.sizes());
        c10::TensorOptions quant_scale_options = quant_scale_real.options().dtype(at::kLong);
        at::Tensor quant_scale_result = npu_preparation::apply_tensor_without_format(quant_scale_output_size,
                                                                                     quant_scale_options);
        EXEC_NPU_CMD(aclnnTransQuantParamV2, quant_scale_real, quant_offset_real, quant_scale_result);
        EXEC_NPU_CMD(aclnnWeightQuantBatchMatmulV2, x, weight_wrapper, antiquant_scale, antiquant_offset_real,
                     quant_scale_result, quant_offset_real, bias_real, antiquant_group_size_real, result);
    } else if (inner_precise == 1) {
        // Only the V3 kernel takes inner_precise.
        EXEC_NPU_CMD(aclnnWeightQuantBatchMatmulV3, x, weight_wrapper, antiquant_scale, antiquant_offset_real,
                     quant_scale_real, quant_offset_real, bias_real, antiquant_group_size_real, inner_precise,
                     result);
    } else {
        EXEC_NPU_CMD(aclnnWeightQuantBatchMatmulV2, x, weight_wrapper, antiquant_scale, antiquant_offset_real,
                     quant_scale_real, quant_offset_real, bias_real, antiquant_group_size_real, result);
    }
    return result;
}
}  // namespace op_api