#include <tuple>
#include <utility>

#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
#include "op_plugin/utils/OpAdapter.h"
#include "op_plugin/utils/KernelNpuOutputSize.h"
namespace op_api {
using npu_preparation = at_npu::native::OpPreparation;
using namespace op_infer;

namespace {
// Allocates the separate (non-in-place) float32 output that
// aclnnInplaceAddRmsNorm writes as its last argument.
// Its shape is entry [1] of rms_norm_npu_output_size(x1, gamma).
// NOTE(review): this is presumably the reciprocal-std (rstd) output; confirm
// against the aclnnInplaceAddRmsNorm operator documentation.
// Both entry points below share this helper so they always allocate the same
// shape and dtype.
at::Tensor alloc_add_rms_norm_v2_out(const at::Tensor &x1, const at::Tensor &gamma)
{
    auto out_size = rms_norm_npu_output_size(x1, gamma)[1];
    return npu_preparation::apply_tensor_without_format(out_size, x1.options().dtype(at::kFloat));
}
} // namespace

// Fused add + RMSNorm dispatched to aclnnInplaceAddRmsNorm, operating directly
// on the caller's x1 and x2.
// x1, x2  : passed to the kernel as mutable arguments. Given the "Inplace" kernel
//           name, they are presumably overwritten with the kernel's results;
//           confirm against the aclnn docs.
// gamma   : scale tensor, forwarded unchanged.
// epsilon : forwarded unchanged to the kernel.
// Returns the freshly allocated float32 output tensor (see alloc_add_rms_norm_v2_out).
at::Tensor npu_add_rms_norm_v2(at::Tensor &x1, at::Tensor &x2, const at::Tensor &gamma, double epsilon)
{
    at::Tensor y = alloc_add_rms_norm_v2_out(x1, gamma);
    EXEC_NPU_CMD(aclnnInplaceAddRmsNorm, x1, x2, gamma, epsilon, y);
    return y;
}

// Functional (non-mutating) variant of npu_add_rms_norm_v2.
// It runs the same in-place kernel on clones of x1 and x2, so the caller's
// tensors are left untouched.
// Returns (out, x1 after the kernel, x2 after the kernel). Here out is the
// float32 tensor from alloc_add_rms_norm_v2_out, and the other two are the
// kernel-updated clones.
std::tuple<at::Tensor, at::Tensor, at::Tensor> npu_add_rms_norm_v2_functional(const at::Tensor &x1,
                                                                              const at::Tensor &x2,
                                                                              const at::Tensor &gamma,
                                                                              double epsilon)
{
    at::Tensor y = alloc_add_rms_norm_v2_out(x1, gamma);
    at::Tensor x1_inplace = x1.clone();
    at::Tensor x2_inplace = x2.clone();
    EXEC_NPU_CMD(aclnnInplaceAddRmsNorm, x1_inplace, x2_inplace, gamma, epsilon, y);
    // Move the locals into the result tuple. std::tie would copy each handle
    // (an atomic refcount bump per tensor) only to release the locals right after.
    return std::make_tuple(std::move(y), std::move(x1_inplace), std::move(x2_inplace));
}
} // namespace op_api