#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
namespace op_api {
/**
 * In-place grouped matmul-add.
 *
 * Launches aclnnGroupedMatmulAddV2 when that kernel is reported available and
 * c10_npu::IsAclnnOnly() returns true; otherwise launches the original
 * aclnnGroupedMatmulAdd kernel.
 *
 * @param self             Tensor handed to the kernel after group_list and returned
 *                         by reference (presumably the accumulate/output tensor --
 *                         confirm against the aclnn API reference).
 * @param x                Forwarded unchanged to the kernel.
 * @param weight           Forwarded unchanged to the kernel.
 * @param group_list       Forwarded unchanged to the kernel.
 * @param transpose_x      Forwarded unchanged to the kernel.
 * @param transpose_weight Forwarded unchanged to the kernel.
 * @param group_type       Forwarded unchanged to the kernel.
 * @param group_list_type  Only consumed by the V2 kernel; defaults to 0 when unset.
 *                         The fallback kernel never receives this value.
 * @return Reference to `self`.
 */
at::Tensor& npu_grouped_matmul_add_(at::Tensor& self, const at::Tensor& x, const at::Tensor& weight,
                                    const at::Tensor& group_list, bool transpose_x, bool transpose_weight,
                                    int64_t group_type, c10::optional<int64_t> group_list_type)
{
    // Both probes are function-local statics, so each is evaluated once on first call.
    static const bool v2_kernel_available = check_aclnn_kernel_available("aclnnGroupedMatmulAddV2");
    static const bool is_aclnn_only = c10_npu::IsAclnnOnly();

    if (v2_kernel_available && is_aclnn_only) {
        // Only the V2 kernel takes group_list_type, so unwrap the optional here.
        int64_t list_type = group_list_type.value_or(0);
        EXEC_NPU_CMD(aclnnGroupedMatmulAddV2, x, weight, group_list, self, transpose_x, transpose_weight,
                     group_type, list_type);
        return self;
    }

    // Fallback: original kernel, which has no group_list_type argument.
    EXEC_NPU_CMD(aclnnGroupedMatmulAdd, x, weight, group_list, self, transpose_x, transpose_weight,
                 group_type);
    return self;
}
/**
 * Out-of-place grouped matmul-add.
 *
 * Runs the same kernel selection as the in-place variant, but on a private copy
 * of `self`, so the caller's tensor is not modified. aclnnGroupedMatmulAddV2 is
 * used when that kernel is reported available and c10_npu::IsAclnnOnly() returns
 * true; otherwise the original aclnnGroupedMatmulAdd kernel is used.
 *
 * @param self             Input tensor; it is cloned and the clone is handed to the
 *                         kernel after group_list (presumably the accumulate/output
 *                         slot -- confirm against the aclnn API reference).
 * @param x                Forwarded unchanged to the kernel.
 * @param weight           Forwarded unchanged to the kernel.
 * @param group_list       Forwarded unchanged to the kernel.
 * @param transpose_x      Forwarded unchanged to the kernel.
 * @param transpose_weight Forwarded unchanged to the kernel.
 * @param group_type       Forwarded unchanged to the kernel.
 * @param group_list_type  Only consumed by the V2 kernel; defaults to 0 when unset.
 *                         The fallback kernel never receives this value.
 * @return A new tensor holding the kernel result; `self` is left untouched.
 */
at::Tensor npu_grouped_matmul_add(const at::Tensor& self, const at::Tensor& x, const at::Tensor& weight,
                                  const at::Tensor& group_list, bool transpose_x, bool transpose_weight,
                                  int64_t group_type, c10::optional<int64_t> group_list_type)
{
    // Both probes are function-local statics, so each is evaluated once on first call.
    static const bool has_grouped_matmul_add_v2 = check_aclnn_kernel_available("aclnnGroupedMatmulAddV2");
    static const bool is_aclnn_only = c10_npu::IsAclnnOnly();

    // Clone BEFORE launching the kernel. The kernel is given this tensor in the same
    // argument slot the in-place variant uses for its mutated `self`, so handing it
    // the caller's tensor and cloning afterwards would overwrite the const input.
    at::Tensor result = self.clone();

    if (has_grouped_matmul_add_v2 && is_aclnn_only) {
        // Only the V2 kernel takes group_list_type, so unwrap the optional here.
        int64_t group_list_type_value = group_list_type.value_or(0);
        EXEC_NPU_CMD(aclnnGroupedMatmulAddV2, x, weight, group_list, result, transpose_x, transpose_weight,
                     group_type, group_list_type_value);
    } else {
        // Fallback: original kernel, which has no group_list_type argument.
        EXEC_NPU_CMD(aclnnGroupedMatmulAdd, x, weight, group_list, result, transpose_x, transpose_weight,
                     group_type);
    }
    return result;
}
}