* Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
*
* MindIE is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#include <functional>
#include <numeric>
#include <string_view>
#include <tuple>
#include <utility>
#include <torch/library.h>
#include "torch_npu/csrc/framework/utils/OpAdapter.h"
#include "torch_npu/csrc/core/npu/NPUFormat.h"
#include "pytorch_npu_helper.h"
#include "layernorm.h"
using namespace at;
// Name of the operator that layernorm_mindie_sd_impl_npu dispatches through EXEC_NPU_CMD.
constexpr std::string_view LAYER_NORM_OP_NAME = "aclnnLayerNormWithImplMode";
/**
 * @brief Runs layer normalization on an NPU tensor through the operator named by LAYER_NORM_OP_NAME.
 *
 * Allocates the normalized output with the sizes, options and NPU format of `input`, allocates the
 * two statistics tensors, and forwards everything to EXEC_NPU_CMD.
 *
 * Shape of the statistics tensors, with m being the product of the leading
 * `input.dim() - normalized_shape.size()` sizes of `input` (m is 1 when there are no leading sizes):
 *  - m > 0  : the leading sizes of `input` followed by a 1 for every remaining dimension;
 *  - m <= 0 : the 1-D shape {m} (presumably only reachable when a leading size is zero).
 *
 * @param input            Tensor to normalize.
 * @param normalized_shape Shape forwarded to the operator; its rank decides how many trailing
 *                         dimensions of `input` are collapsed to 1 in the statistics tensors.
 * @param weight           Optional tensor; an undefined tensor is forwarded when it is absent.
 * @param bias             Optional tensor; an undefined tensor is forwarded when it is absent.
 * @param eps              Forwarded unchanged to the operator.
 * @param impl_mode        Forwarded unchanged to the operator; accepted values are defined by the
 *                         operator, not validated here.
 * @return Tuple (output, mean_out, rstd_out).
 * @throws c10::Error (raised by TORCH_CHECK) when normalized_shape has more dimensions than input.
 */
std::tuple<at::Tensor, at::Tensor, at::Tensor> layernorm_mindie_sd_impl_npu(const at::Tensor &input,
    at::IntArrayRef normalized_shape, const c10::optional<at::Tensor> &weight, const c10::optional<at::Tensor> &bias,
    double eps, int64_t impl_mode) {
    // Absent optionals are forwarded to the operator as undefined tensors.
    const at::Tensor &weight_op = c10::value_or_else(weight, [] { return at::Tensor(); });
    const at::Tensor &bias_op = c10::value_or_else(bias, [] { return at::Tensor(); });

    // Tensor::dim() is signed; convert once so every rank computation below is unsigned.
    const size_t normNdim = normalized_shape.size();
    const size_t inputNdim = static_cast<size_t>(input.dim());

    // Validate before allocating anything so a bad call fails without touching device memory.
    TORCH_CHECK(normNdim <= inputNdim,
        "normalized_shape must fit within input dimensions, but got normalized_shape.size()=", normNdim,
        " and input.dim()=", inputNdim);

    const size_t beginAxis = inputNdim - normNdim;
    const auto input_shape = input.sizes();

    // Product of the leading (non-normalized) sizes; 1 when there are none.
    const int64_t m = std::accumulate(input_shape.cbegin(), input_shape.cbegin() + beginAxis,
        static_cast<int64_t>(1), std::multiplies<int64_t>());

    // Shape shared by the mean and rstd tensors.
    at::SmallVector<int64_t, 8> stat_shape;
    if (m <= 0) {
        stat_shape.emplace_back(m);
    } else {
        for (size_t index = 0; index < beginAxis; ++index) {
            stat_shape.emplace_back(input_shape[index]);
        }
        for (size_t index = beginAxis; index < inputNdim; ++index) {
            stat_shape.emplace_back(1);
        }
    }

    // All three result tensors share the options and NPU format of the input.
    const auto options = input.options();
    const auto npu_format = at_npu::native::get_npu_format(input);
    at::Tensor output = at_npu::native::empty_with_format(input_shape, options, npu_format);
    at::Tensor mean_out = at_npu::native::empty_with_format(stat_shape, options, npu_format);
    at::Tensor rstd_out = at_npu::native::empty_with_format(stat_shape, options, npu_format);

    EXEC_NPU_CMD<LAYER_NORM_OP_NAME>(
        input, normalized_shape, weight_op, bias_op, eps, output, mean_out, rstd_out, impl_mode);

    // Move the local handles into the result instead of copying them.
    return std::make_tuple(std::move(output), std::move(mean_out), std::move(rstd_out));
}