#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
#include "op_plugin/utils/OpAdapter.h"
namespace op_api {
using npu_preparation = at_npu::native::OpPreparation;
using tensor_list = std::tuple<at::Tensor, at::Tensor, at::Tensor>;
using namespace op_plugin::utils;
using namespace op_infer;

/**
 * @brief NPU adapter for npu_add_rms_norm_quant, dispatching to aclnnAddRmsNormQuantV2 when that
 *        kernel is available and to aclnnAddRmsNormQuant otherwise.
 *
 * NOTE(review): judging by the name, the kernel fuses x1 + x2, an RMSNorm weighted by gamma
 * (optionally shifted by beta) and an int8 quantization driven by scales1/zero_points1. The
 * actual math lives in the CANN kernel and is not visible here; confirm against the CANN docs.
 *
 * @param x1            First input; every output is allocated with x1's sizes.
 * @param x2            Second input, forwarded to the kernel.
 * @param gamma         Forwarded to the kernel (presumably the RMSNorm weight).
 * @param scales1       Forwarded to the kernel (presumably the quantization scales for y1).
 * @param zero_points1  Optional, forwarded to the kernel.
 * @param beta          Optional; accepted only when aclnnAddRmsNormQuantV2 is available.
 * @param scales2       Must be None.
 * @param zero_points2  Must be None.
 * @param axis          Must be -1.
 * @param epsilon       Forwarded to the kernel.
 * @param div_mode      Must be true.
 * @return (y1, y2, x_out). y1 and y2 are int8 (at::kChar) tensors. x_out has x1's dtype.
 *         All three have x1's sizes.
 * @throws c10::Error (via TORCH_CHECK) when an unsupported argument value is supplied.
 */
tensor_list npu_add_rms_norm_quant(const at::Tensor &x1, const at::Tensor &x2, const at::Tensor &gamma,
    const at::Tensor &scales1, const c10::optional<at::Tensor> &zero_points1,
    const c10::optional<at::Tensor> &beta, const c10::optional<at::Tensor> &scales2,
    const c10::optional<at::Tensor> &zero_points2, int64_t axis, double epsilon, bool div_mode)
{
    // Reject unsupported configurations before any device memory is allocated.
    TORCH_CHECK(!scales2.has_value(), "scales2 only support None.", OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(!zero_points2.has_value(), "zero_points2 only support None.", OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(axis == -1, "axis only support -1.", OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(div_mode, "div_mode only support True.", OPS_ERROR(ErrCode::PARAM));

    // Choose the kernel first so the legacy kernel's beta restriction also fails fast,
    // instead of after the three output tensors have been allocated.
    const bool use_v2 = check_aclnn_kernel_available("aclnnAddRmsNormQuantV2");
    if (!use_v2) {
        TORCH_CHECK(!beta.has_value(), "In the current CANN version, aclnnAddRmsNormQuant does not support the parameter beta input. It is recommended to upgrade the CANN package. Or please remove the beta input parameter.", OPS_ERROR(ErrCode::PARAM));
    }

    auto output_size_0 = x1.sizes();
    auto output_dtype_0 = at::kChar;           // quantized outputs y1 / y2 are int8
    auto output_dtype_1 = x1.scalar_type();    // x_out keeps the input dtype
    at::Tensor y1 = npu_preparation::apply_tensor_without_format(output_size_0, x1.options().dtype(output_dtype_0));
    at::Tensor y2 = npu_preparation::apply_tensor_without_format(output_size_0, x1.options().dtype(output_dtype_0));
    at::Tensor x_out = npu_preparation::apply_tensor_without_format(output_size_0, x1.options().dtype(output_dtype_1));

    if (use_v2) {
        // Undefined tensor for the V2 kernel's rmsnorm output, which this op does not return.
        // NOTE(review): presumably EXEC_NPU_CMD maps an undefined tensor to a null aclTensor; confirm.
        at::Tensor rmsnorm_out;
        EXEC_NPU_CMD(aclnnAddRmsNormQuantV2, x1, x2, gamma, scales1, scales2, zero_points1, zero_points2, beta, axis, epsilon, div_mode, y1, y2, x_out, rmsnorm_out);
    } else {
        EXEC_NPU_CMD(aclnnAddRmsNormQuant, x1, x2, gamma, scales1, scales2, zero_points1, zero_points2, axis, epsilon, div_mode, y1, y2, x_out);
    }
    return std::make_tuple(std::move(y1), std::move(y2), std::move(x_out));
}
}