#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
#include "op_plugin/utils/OpUtils.h"
namespace op_api {
using npu_preparation = at_npu::native::OpPreparation;
/// Dispatches the aclnnQuantGroupedMatmulInplaceAdd kernel and returns `self`.
///
/// Every one of x1 / x2 / x1_scale / x2_scale is handed to the kernel as a
/// TensorWrapper, i.e. the tensor plus an ACL data type. The ACL data type is
/// taken from the matching `*_dtype` argument (via c10_npu::GetAclDataType) when
/// that argument is set; otherwise it is derived from the tensor's scalar type.
/// When x1_scale is absent, an undefined tensor tagged ACL_FLOAT is passed.
///
/// NOTE(review): judging by the kernel name, the grouped-matmul result is
/// presumably accumulated into `self` in place -- confirm against the CANN docs.
///
/// @param self             tensor passed to the kernel after group_list; returned to the caller.
/// @param x1               first matmul operand.
/// @param x2               second matmul operand.
/// @param x2_scale         scale tensor associated with x2.
/// @param group_list       group description tensor forwarded unchanged.
/// @param x1_scale         optional scale tensor associated with x1.
/// @param group_list_type  forwarded to the kernel; 0 is used when not provided.
/// @param group_sizes      must be unset; a value is rejected.
/// @param x1_dtype         optional dtype override for x1.
/// @param x2_dtype         optional dtype override for x2.
/// @param x1_scale_dtype   optional dtype override for x1_scale.
/// @param x2_scale_dtype   optional dtype override for x2_scale.
/// @return reference to `self`.
/// @throws (through TORCH_CHECK) if the aclnn kernel is unavailable or group_sizes is set.
at::Tensor &npu_add_quant_gmm_(at::Tensor &self, const at::Tensor &x1, const at::Tensor &x2,
                               const at::Tensor &x2_scale, const at::Tensor &group_list,
                               const c10::optional<at::Tensor> &x1_scale, c10::optional<int64_t> group_list_type,
                               c10::OptionalIntArrayRef group_sizes, c10::optional<int64_t> x1_dtype,
                               c10::optional<int64_t> x2_dtype, c10::optional<int64_t> x1_scale_dtype,
                               c10::optional<int64_t> x2_scale_dtype)
{
    // Function-local static: the kernel lookup is evaluated on the first call only.
    static const bool kernel_found = check_aclnn_kernel_available("aclnnQuantGroupedMatmulInplaceAdd");
    TORCH_CHECK(kernel_found,
                "Get aclnnQuantGroupedMatmulInplaceAdd or aclnnQuantGroupedMatmulInplaceAddGetWorkspaceSize failed, "
                "please upgrade CANN.",
                OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(!group_sizes.has_value(), "group_sizes is not supported currently. ", OPS_ERROR(ErrCode::VALUE));

    // x1: explicit dtype override wins, otherwise fall back to the tensor's own scalar type.
    const aclDataType x1_acl_type = x1_dtype.has_value()
                                        ? c10_npu::GetAclDataType(x1_dtype.value())
                                        : npu_preparation::convert_to_acl_data_type(x1.scalar_type());
    TensorWrapper x1_typed = {x1, x1_acl_type};

    // x2: same selection rule as x1.
    const aclDataType x2_acl_type = x2_dtype.has_value()
                                        ? c10_npu::GetAclDataType(x2_dtype.value())
                                        : npu_preparation::convert_to_acl_data_type(x2.scalar_type());
    TensorWrapper x2_typed = {x2, x2_acl_type};

    // x2_scale: same selection rule as x1.
    const aclDataType x2_scale_acl_type =
        x2_scale_dtype.has_value() ? c10_npu::GetAclDataType(x2_scale_dtype.value())
                                   : npu_preparation::convert_to_acl_data_type(x2_scale.scalar_type());
    TensorWrapper x2_scale_typed = {x2_scale, x2_scale_acl_type};

    // x1_scale is optional: an undefined tensor stands in when it is missing, and
    // ACL_FLOAT is the data type used if neither an override nor a tensor is given.
    const at::Tensor x1_scale_tensor = x1_scale.value_or(at::Tensor());
    aclDataType x1_scale_acl_type = aclDataType::ACL_FLOAT;
    if (x1_scale_dtype.has_value()) {
        x1_scale_acl_type = c10_npu::GetAclDataType(x1_scale_dtype.value());
    } else if (x1_scale.has_value()) {
        x1_scale_acl_type = npu_preparation::convert_to_acl_data_type(x1_scale_tensor.scalar_type());
    }
    TensorWrapper x1_scale_typed = {x1_scale_tensor, x1_scale_acl_type};

    // Scalar kernel arguments: group_list_type defaults to 0; the group size is always passed as 0.
    int64_t group_list_kind = group_list_type.value_or(0);
    int64_t fixed_group_size = 0;

    EXEC_NPU_CMD(aclnnQuantGroupedMatmulInplaceAdd, x1_typed, x2_typed, x1_scale_typed, x2_scale_typed, group_list,
                 self, group_list_kind, fixed_group_size);
    return self;
}
}