#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
namespace op_api {
using npu_preparation = at_npu::native::OpPreparation;
/**
 * Dispatches the aclnnRmsNormGradQuant kernel and returns its two outputs.
 *
 * The inputs dy, x, rstd, gamma, scale_x, offset_x, quant_mode and div_mode
 * are forwarded to the kernel unchanged; this wrapper only validates ranks
 * and allocates the output tensors.
 *
 * @param dy         Must have at least one dimension; its sizes and options
 *                   are used for the first output.
 * @param x          Must have at least one dimension.
 * @param rstd       Forwarded to the kernel.
 * @param gamma      Its sizes and options are used for the second output.
 * @param scale_x    Forwarded to the kernel.
 * @param offset_x   Optional tensor, forwarded to the kernel.
 * @param div_mode   Optional flag, forwarded to the kernel as-is.
 * @param quant_mode Passed to the kernel as a NUL-terminated C string.
 * @param dst_type   Optional type code for the first output; when absent the
 *                   output uses the scalar type matching ACL_INT8.
 * @return (dx, dgamma): dx has dy's sizes with the resolved output dtype,
 *         dgamma has gamma's sizes with Float dtype; both use ACL_FORMAT_ND.
 */
std::tuple<at::Tensor, at::Tensor> _npu_rms_norm_backward_quant(
    const at::Tensor &dy,
    const at::Tensor &x,
    const at::Tensor &rstd,
    const at::Tensor &gamma,
    const at::Tensor &scale_x,
    const c10::optional<at::Tensor> &offset_x,
    c10::optional<bool> div_mode,
    c10::string_view quant_mode,
    c10::optional<int64_t> dst_type)
{
    // Rank validation for the two primary inputs.
    TORCH_CHECK(dy.dim() >= 1, "dy must have at least 1 dimension, but got ", dy.dim(),
                OPS_ERROR(ErrCode::PARAM));
    TORCH_CHECK(x.dim() >= 1, "x must have at least 1 dimension, but got ", x.dim(),
                OPS_ERROR(ErrCode::PARAM));

    // The string_view is copied into an owning std::string so that a
    // NUL-terminated buffer stays alive for the duration of the kernel call.
    std::string mode_buffer(quant_mode);
    char *mode_cstr = const_cast<char *>(mode_buffer.c_str());

    // Second output: shaped like gamma, always Float.
    const auto grad_gamma_options = gamma.options().dtype(at::ScalarType::Float);
    at::Tensor grad_gamma = npu_preparation::apply_tensor_with_format(
        gamma.sizes(), grad_gamma_options, ACL_FORMAT_ND);

    // Resolve the ACL type of the first output: caller's choice if given,
    // otherwise ACL_INT8.
    aclDataType grad_x_acl_type = dst_type.has_value()
        ? c10_npu::GetAclDataType(dst_type.value())
        : aclDataType::ACL_INT8;
    const at::ScalarType grad_x_dtype = npu_preparation::convert_to_scalar_type(grad_x_acl_type);

    // First output: shaped like dy, with the resolved dtype.
    at::Tensor grad_x = npu_preparation::apply_tensor_with_format(
        dy.sizes(), dy.options().dtype(grad_x_dtype), ACL_FORMAT_ND);

    // Pair the output tensor with its explicit ACL type for the kernel call.
    TensorWrapper grad_x_wrapper = {grad_x, grad_x_acl_type};

    EXEC_NPU_CMD(aclnnRmsNormGradQuant,
                 dy, x, rstd, gamma, scale_x, offset_x,
                 mode_cstr, div_mode,
                 grad_x_wrapper, grad_gamma);

    return std::tuple<at::Tensor, at::Tensor>(std::move(grad_x), std::move(grad_gamma));
}
}