#include "op_plugin/OpApiInterface.h"
#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/utils/op_api_common.h"
namespace op_api {
using npu_preparation = at_npu::native::OpPreparation;

// DeepNorm forward on the NPU via the aclnnDeepNorm kernel.
//
// Parameters:
//   x, gx       - input tensors forwarded unchanged to aclnnDeepNorm.
//   beta, gamma - affine parameters forwarded to the kernel; gamma's rank also
//                 decides how many trailing dims of x the statistics reduce over.
//   alpha       - scalar forwarded to the kernel.
//   epsilon     - scalar forwarded to the kernel.
//
// Returns (mean, rstd, y):
//   mean, rstd  - float32 tensors. They keep the leading (x.dim() - gamma.dim())
//                 sizes of x and have size 1 in every remaining trailing dim.
//   y           - allocated from x via apply_tensor(x) (presumably the same
//                 shape/options as x — TODO confirm) and written by the kernel.
//
// DO_COMPATIBILITY presumably routes to acl_op::npu_deep_norm when aclnnDeepNorm
// is unavailable — verify against the macro definition.
std::tuple<at::Tensor, at::Tensor, at::Tensor> npu_deep_norm(const at::Tensor& x,
                                                            const at::Tensor& gx,
                                                            const at::Tensor& beta,
                                                            const at::Tensor& gamma,
                                                            double alpha,
                                                            double epsilon)
{
    DO_COMPATIBILITY(aclnnDeepNorm, acl_op::npu_deep_norm(x, gx, beta, gamma, alpha, epsilon));

    // Dims of x not covered by gamma are kept; dims covered by gamma collapse to 1.
    // NOTE(review): if gamma.dim() > x.dim(), kept_dims is negative and every
    // entry becomes 1; no check is made here.
    const int64_t x_rank = x.dim();
    const int64_t kept_dims = x_rank - gamma.dim();
    at::SmallVector<int64_t, SIZE> stat_shape;
    for (int64_t axis = 0; axis < x_rank; ++axis) {
        stat_shape.emplace_back(axis < kept_dims ? x.size(axis) : 1);
    }

    // The statistics are always float32, regardless of x's dtype.
    const auto stat_options = x.options().dtype(at::kFloat);
    at::Tensor y = npu_preparation::apply_tensor(x);
    at::Tensor mean = npu_preparation::apply_tensor(stat_shape, stat_options, x);
    at::Tensor rstd = npu_preparation::apply_tensor(stat_shape, stat_options, x);

    EXEC_NPU_CMD(aclnnDeepNorm, x, gx, beta, gamma, alpha, epsilon, mean, rstd, y);
    return std::make_tuple(mean, rstd, y);
}
}