#include "op_plugin/utils/OpUtils.h"
#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
namespace op_api {
constexpr int64_t DIM_0 = 0;
constexpr int64_t DIM_1 = 1;
constexpr int64_t BSND_DIMS = 4;
constexpr int64_t TND_DIMS = 3;
constexpr int64_t NUM_THREE = 3;
constexpr int64_t NUM_TWO = 2;
constexpr int64_t NUM_ONE = 1;
inline void check_mhc_pre_backward_supported()
{
static const bool is_cann_ready = op_plugin::utils::is_gte_cann_version_900();
static const bool is_aclnn_kernel_available = check_aclnn_kernel_available("aclnnMhcPre") && check_aclnn_kernel_available("aclnnMhcPreBackward");
TORCH_CHECK(
is_cann_ready && is_aclnn_kernel_available,
"torch_npu.npu_mhc_pre_backward requires CANN >= 9.0.0, aclnnMhcPre and aclnnMhcPreBackward support. "
"Please upgrade CANN.",
OPS_ERROR(ErrCode::NOT_SUPPORT));
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_mhc_pre_backward(
const at::Tensor &x, const at::Tensor &phi, const at::Tensor &alpha,
const at::Tensor &grad_h_in, const at::Tensor &grad_h_post, const at::Tensor &grad_h_res,
const at::Tensor &inv_rms, const at::Tensor &h_mix, const at::Tensor &h_pre, const at::Tensor &h_post,
const c10::optional<at::Tensor> &gamma, double hc_eps, const c10::optional<at::Tensor> &grad_x_post)
{
TORCH_CHECK(x.numel() > 0, "Tensor x is empty.");
TORCH_CHECK(phi.numel() > 0, "Tensor phi is empty.");
TORCH_CHECK(alpha.numel() > 0, "Tensor alpha is empty.");
TORCH_CHECK(grad_h_in.numel() > 0, "Tensor grad_h_in is empty.");
TORCH_CHECK(grad_h_post.numel() > 0, "Tensor grad_h_post is empty.");
TORCH_CHECK(grad_h_res.numel() > 0, "Tensor grad_h_res is empty.");
TORCH_CHECK(inv_rms.numel() > 0, "Tensor inv_rms is empty.");
TORCH_CHECK(h_mix.numel() > 0, "Tensor h_mix is empty.");
TORCH_CHECK(h_pre.numel() > 0, "Tensor h_pre is empty.");
TORCH_CHECK(h_post.numel() > 0, "Tensor h_post is empty.");
TORCH_CHECK(x.dim() == TND_DIMS || x.dim() == BSND_DIMS, "Input x must be 3D or 4D, but got ", x.dim(), "D.");
TORCH_CHECK(grad_h_in.dim() == 2 || grad_h_in.dim() == 3,
"Input grad_h_in must be 2D or 3D, but got ", grad_h_in.dim(), "D.", OPS_ERROR(ErrCode::PARAM));
TORCH_CHECK(grad_h_post.dim() == 2 || grad_h_post.dim() == 3,
"Input grad_h_post must be 2D or 3D, but got ", grad_h_post.dim(), "D.", OPS_ERROR(ErrCode::PARAM));
check_mhc_pre_backward_supported();
c10::TensorOptions xGradOptions = x.options().dtype(x.dtype());
c10::TensorOptions fp32Options = grad_h_in.options().dtype(at::kFloat);
at::Tensor xGrad;
at::Tensor phiGrad;
at::Tensor alphaGrad;
at::Tensor biasPostGrad;
at::Tensor gammaGrad;
auto dimen = grad_h_in.size(grad_h_in.dim() - 1);
auto numsResidual = grad_h_post.size(grad_h_post.dim() - 1);
if (grad_h_in.dim() == NUM_THREE) {
auto batch = grad_h_in.size(DIM_0);
auto sequence = grad_h_in.size(DIM_1);
c10::SmallVector<int64_t, BSND_DIMS> xGradSize = {batch, sequence, numsResidual, dimen};
xGrad = at::empty(xGradSize, xGradOptions);
} else {
auto t = grad_h_in.size(DIM_0);
c10::SmallVector<int64_t, TND_DIMS> xGradSize = {t, numsResidual, dimen};
xGrad = at::empty(xGradSize, xGradOptions);
}
c10::SmallVector<int64_t, NUM_TWO> phiGradSize = {(NUM_TWO * numsResidual) + (numsResidual * numsResidual), numsResidual * dimen};
c10::SmallVector<int64_t, NUM_ONE> alphaGradSize = {NUM_THREE};
c10::SmallVector<int64_t, NUM_ONE> biasPostGradSize = {(NUM_TWO * numsResidual) + (numsResidual * numsResidual)};
c10::SmallVector<int64_t, NUM_TWO> gammaGradSize = {numsResidual, dimen};
phiGrad = at::empty(phiGradSize, fp32Options);
alphaGrad = at::empty(alphaGradSize, fp32Options);
biasPostGrad = at::empty(biasPostGradSize, fp32Options);
gammaGrad = at::empty(gammaGradSize, fp32Options);
float eps_f = static_cast<float>(hc_eps);
EXEC_NPU_CMD(aclnnMhcPreBackward, x, phi, alpha, grad_h_in, grad_h_post,
grad_h_res, inv_rms, h_mix, h_pre, h_post, gamma, grad_x_post, eps_f, xGrad, phiGrad, alphaGrad, biasPostGrad, gammaGrad);
return std::make_tuple(xGrad, phiGrad, alphaGrad, biasPostGrad, gammaGrad);
}
}