#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
namespace op_api {
using npu_preparation = at_npu::native::OpPreparation;

// Required rank of x1, scale and bias. Despite the "MINIMUM" in its name, every use below is an
// exact-equality check (dim() == 2); the name is kept unchanged for compatibility.
constexpr int64_t MINIMUM_SHAPE_SIZE = 2;

/**
 * @brief Matmul with a compressed right-hand operand followed by dequantization, dispatched to
 *        the aclnnMatmulCompressDequant kernel.
 *
 * @param x1             Left operand; must be 2-D. Its dims are read as (M, k).
 * @param x2             Compressed right operand; forwarded to the kernel unchanged.
 * @param compress_index Compression index tensor; forwarded to the kernel unchanged.
 * @param bias           Must be 2-D; bias.size(1) is taken as the output column count n.
 * @param scale          Dequantization scale; must be 2-D and scale.size(1) must equal n.
 * @param offsetW        Must be empty. Only None is currently supported.
 * @param offsetX        Must be absent or 0. Only 0 is currently supported.
 * @return A newly allocated (M, n) tensor of dtype at::kHalf in ACL_FORMAT_ND.
 *
 * Fails via TORCH_CHECK if the aclnn kernel is not available in the installed CANN, or if any
 * argument check below fails.
 */
at::Tensor npu_matmul_compress_dequant(const at::Tensor &x1, const at::Tensor &x2,
                                       const at::Tensor &compress_index, const at::Tensor &bias,
                                       const at::Tensor &scale,
                                       const c10::optional<at::Tensor> &offsetW,
                                       c10::optional<int64_t> offsetX)
{
    // Probe the kernel once per process; the result cannot change at runtime.
    static const bool is_aclnn_available =
        check_aclnn_kernel_available("aclnnMatmulCompressDequant");
    TORCH_CHECK(is_aclnn_available,
                "aclnnMatmulCompressDequant or aclnnMatmulCompressDequantGetWorkspaceSize not found, "
                "please upgrade CANN.",
                OPS_ERROR(ErrCode::PARAM));

    // Reject the optional quantization offsets that are not supported yet.
    TORCH_CHECK(!offsetW.has_value(),
                "offsetW currently only supports null/None, please do not pass offsetW.",
                OPS_ERROR(ErrCode::PARAM));
    const int64_t offset_x_val = offsetX.value_or(0);
    TORCH_CHECK(offset_x_val == 0,
                "offsetX currently only supports 0, but got ", offset_x_val,
                OPS_ERROR(ErrCode::PARAM));

    // Rank checks: all shape reads below index dims 0 and 1 only.
    TORCH_CHECK(x1.dim() == MINIMUM_SHAPE_SIZE,
                "x1 must have 2 dimensions, but got ", x1.dim(),
                OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(scale.dim() == MINIMUM_SHAPE_SIZE,
                "scale must have 2 dimensions, but got ", scale.dim(),
                OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(bias.dim() == MINIMUM_SHAPE_SIZE,
                "bias must have 2 dimensions, but got ", bias.dim(),
                OPS_ERROR(ErrCode::PARAM));

    const int64_t m = x1.size(0);
    const int64_t k = x1.size(1);
    // The output column count (allocated below) and the n reported to the kernel in
    // compress_info must be the same value. Previously they were read independently from bias
    // and scale, so a mismatch would silently give the kernel a shape that disagrees with the
    // output buffer.
    const int64_t n = bias.size(1);
    TORCH_CHECK(scale.size(1) == n,
                "scale.size(1) must equal bias.size(1), but got ", scale.size(1), " and ", n,
                OPS_ERROR(ErrCode::PARAM));

    c10::SmallVector<int64_t, op_infer::SIZE> output_size = {m, n};
    at::Tensor result = npu_preparation::apply_tensor_with_format(
        output_size, x1.options().dtype(at::kHalf), ACL_FORMAT_ND);

    // NOTE(review): the constants 8, 8 and 1 are presumably the compression block (tile) sizes
    // and a traversal-direction flag expected by aclnnMatmulCompressDequant. Confirm them against
    // the CANN API reference before changing.
    std::vector<int64_t> compress_info_vec = {8, 8, k, n, 1};
    at::IntArrayRef compress_info_ref(compress_info_vec);

    // The checks above guarantee offsetW is empty and offsetX is 0, so the kernel always receives
    // the "no offset" values.
    c10::optional<at::Tensor> offset_w_for_api = c10::nullopt;
    int offset_x_for_api = 0;
    EXEC_NPU_CMD(aclnnMatmulCompressDequant, x1, x2, compress_index, bias, scale,
                 offset_w_for_api, offset_x_for_api, compress_info_ref, result);
    return result;
}
}