#include <torch/extension.h>
#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
#include "inc/aclnn_common.h"
// Tolerance below which the dropout keep-probability (1 - p) is treated as zero,
// so that p == 1 yields a scale of 0 instead of a division by zero.
static constexpr double EPSILON = 1e-11;
/**
 * @brief Validates the arguments of npu_dropout_add_layer_norm.
 *
 * Every violation is reported through TORCH_CHECK (i.e. a c10::Error is thrown
 * with a descriptive message); the first failing check wins.
 *
 * @param x0             Input tensor; must live on an NPU device and have at least one
 *                       dimension. Its last dimension is the normalized one and must
 *                       equal weight.numel().
 * @param weight         Norm scale. weight.numel() must be a multiple of 8 and <= 8192.
 * @param residual_opt   Optional residual; must have exactly x0's shape.
 * @param bias_opt       Optional norm bias; same dtype and shape as weight.
 * @param rowscale_opt   Optional per-row scale; shape equals x0's shape without its
 *                       last dim, dtype equals x0's dtype.
 * @param layerscale_opt Optional per-channel scale; its first dim equals x0's last dim,
 *                       dtype equals weight's dtype.
 * @param p              Dropout probability, must lie in [0, 1].
 * @param eps            Norm epsilon, must be >= 0.
 *
 * Optionals that hold an undefined tensor are treated the same as empty optionals,
 * matching how the forward function interprets them.
 */
inline void npu_dropout_add_layer_norm_check(
    const at::Tensor &x0,
    const at::Tensor &weight,
    const c10::optional<at::Tensor> &residual_opt,
    const c10::optional<at::Tensor> &bias_opt,
    const c10::optional<at::Tensor> &rowscale_opt,
    const c10::optional<at::Tensor> &layerscale_opt,
    double p,
    double eps)
{
    // An optional counts as "provided" only when it holds a defined tensor.
    const auto is_provided = [](const c10::optional<at::Tensor> &opt) {
        return opt.has_value() && opt->defined();
    };

    TORCH_CHECK(
        torch_npu::utils::is_npu(x0),
        "npu_dropout_add_layer_norm only supports device for NPU!");
    // Guards every x0.size(-1) / sizes().slice(...) below.
    TORCH_CHECK(
        x0.dim() >= 1,
        "npu_dropout_add_layer_norm: x0 must have at least 1 dimension, but got a 0-dim tensor");

    const auto itype = x0.scalar_type();
    const auto wtype = weight.scalar_type();
    TORCH_CHECK(
        !(itype == at::kBFloat16 && wtype == at::kHalf),
        "weight_dtype == torch.float16 and input_dtype == torch.bfloat16 was not supported");

    // Normalization runs over x0's last dimension, which must match the weight.
    const int64_t hidden_size = weight.numel();
    const int64_t last_dim = x0.size(-1);
    TORCH_CHECK(
        last_dim == hidden_size,
        "npu_dropout_add_layer_norm: last dim of x0 (", last_dim,
        ") must equal weight.numel() (", hidden_size, ")");

    if (is_provided(bias_opt)) {
        const at::Tensor &bias = bias_opt.value();
        TORCH_CHECK(
            bias.dtype() == wtype,
            "npu_dropout_add_layer_norm: bias dtype must match weight dtype");
        TORCH_CHECK(
            bias.sizes() == weight.sizes(),
            "npu_dropout_add_layer_norm: bias shape ", bias.sizes(),
            " must match weight shape ", weight.sizes());
    }
    if (is_provided(residual_opt)) {
        const at::Tensor &residual = residual_opt.value();
        TORCH_CHECK(
            residual.sizes() == x0.sizes(),
            "npu_dropout_add_layer_norm: residual shape ", residual.sizes(),
            " must match x0 shape ", x0.sizes());
    }
    if (is_provided(rowscale_opt)) {
        const at::Tensor &rowscale = rowscale_opt.value();
        // One scale per row: x0's shape with the last (normalized) dim removed.
        const auto row_shape = x0.sizes().slice(0, static_cast<size_t>(x0.dim() - 1));
        TORCH_CHECK(
            rowscale.dim() == x0.dim() - 1 && rowscale.sizes() == row_shape,
            "npu_dropout_add_layer_norm: rowscale shape ", rowscale.sizes(),
            " must equal x0 shape without its last dim ", row_shape);
        TORCH_CHECK(
            rowscale.dtype() == itype,
            "npu_dropout_add_layer_norm: rowscale dtype must match x0 dtype");
    }
    if (is_provided(layerscale_opt)) {
        const at::Tensor &layerscale = layerscale_opt.value();
        // dim() is checked first so that size(0) is never taken on a 0-dim tensor.
        TORCH_CHECK(
            layerscale.dim() >= 1 && layerscale.size(0) == last_dim,
            "npu_dropout_add_layer_norm: layerscale must have size ", last_dim,
            " in its first dim, but got shape ", layerscale.sizes());
        TORCH_CHECK(
            layerscale.dtype() == wtype,
            "npu_dropout_add_layer_norm: layerscale dtype must match weight dtype");
    }

    TORCH_CHECK(
        p >= 0 && p <= 1,
        "dropout probability has to be between 0 and 1, but got ", p);
    TORCH_CHECK(eps >= 0.f, "npu_dropout_add_layer_norm: eps must be >= 0, but got ", eps);
    TORCH_CHECK(
        (hidden_size % 8 == 0) && (hidden_size <= 8192),
        "npu_dropout_add_layer_norm: hidden size (weight.numel()) must be a multiple of 8 "
        "and at most 8192, but got ", hidden_size);
}
// Returns `t` converted to `dtype`, or `t` itself (no copy) when it already has that dtype.
static at::Tensor npu_dropout_add_layer_norm_cast(const at::Tensor &t, at::ScalarType dtype)
{
    return (t.scalar_type() == dtype) ? t : t.to(dtype);
}

/**
 * @brief Dropout + residual add + LayerNorm / RMSNorm, composed from ATen ops.
 *
 * All intermediate arithmetic is done in float32:
 *   scaled  = x0 [* rowscale   (one value per row, broadcast over the last dim)]
 *                [* layerscale (one value per element of the last dim)]
 *   dropped = dropout(scaled, p)        -- skipped entirely when p == 0
 *   pre     = dropped [+ residual]
 *   out     = LayerNorm over the last dim with (weight, bias, eps)      if !is_rms_norm
 *           = pre * (mean(pre^2 over last dim) + eps)^-0.5 * weight    if  is_rms_norm
 *
 * x0 may have any rank >= 1; its last dim is the normalized one (e.g. (rows, hidden)
 * or (batch, seq, hidden)).
 *
 * @param x0                  Input tensor.
 * @param weight              Norm scale, one value per element of x0's last dim.
 * @param residual_opt        Optional tensor added after dropout.
 * @param bias_opt            Optional LayerNorm bias (not used in the RMS-norm path).
 * @param rowscale_opt        Optional per-row scale applied to x0 before dropout.
 * @param layerscale_opt      Optional per-channel scale applied to x0 before dropout.
 * @param p                   Dropout probability in [0, 1]; 0 disables dropout.
 * @param eps                 Norm epsilon.
 * @param prenorm             Also return the pre-norm sum `pre`.
 * @param residual_in_fp32    Without a residual, return `pre` in float32 instead of x0's dtype.
 * @param is_rms_norm         Use RMSNorm instead of LayerNorm.
 * @param return_dropout_mask Also return the boolean keep-mask (true = element kept).
 * @return (out cast to x0's dtype,
 *          `pre` cast to the residual dtype if prenorm, otherwise an undefined tensor,
 *          keep-mask if return_dropout_mask, otherwise an undefined tensor).
 *          The residual dtype is the residual's own dtype when one is given, else float32
 *          when residual_in_fp32, else x0's dtype.
 */
std::tuple<at::Tensor, at::Tensor, at::Tensor> npu_dropout_add_layer_norm(
    const at::Tensor &x0,
    const at::Tensor &weight,
    const c10::optional<at::Tensor> &residual_opt,
    const c10::optional<at::Tensor> &bias_opt,
    const c10::optional<at::Tensor> &rowscale_opt,
    const c10::optional<at::Tensor> &layerscale_opt,
    double p,
    double eps,
    bool prenorm,
    bool residual_in_fp32,
    bool is_rms_norm,
    bool return_dropout_mask)
{
    npu_dropout_add_layer_norm_check(
        x0, weight, residual_opt, bias_opt, rowscale_opt, layerscale_opt, p, eps);

    // Absent optionals become undefined tensors; `.defined()` is tested below.
    at::Tensor residual = c10::value_or_else(residual_opt, [] { return at::Tensor(); });
    at::Tensor bias = c10::value_or_else(bias_opt, [] { return at::Tensor(); });
    const at::Tensor rowscale = c10::value_or_else(rowscale_opt, [] { return at::Tensor(); });
    const at::Tensor layerscale = c10::value_or_else(layerscale_opt, [] { return at::Tensor(); });

    const at::ScalarType x0_dtype = x0.scalar_type();
    // Must be computed before `residual` is promoted to float32 below.
    const at::ScalarType residual_dtype = residual.defined() ?
        residual.scalar_type() :
        (residual_in_fp32 ? at::kFloat : x0_dtype);

    const at::Tensor weight_fp32 = npu_dropout_add_layer_norm_cast(weight, at::kFloat);
    if (residual.defined()) {
        residual = npu_dropout_add_layer_norm_cast(residual, at::kFloat);
    }
    if (bias.defined()) {
        bias = npu_dropout_add_layer_norm_cast(bias, at::kFloat);
    }

    // Size of the normalized (last) dim. Indexing with -1 instead of a fixed axis
    // makes the op work for any input rank, not only 3-D inputs.
    const int64_t last_dim = x0.size(-1);

    at::Tensor scaled_x0 = npu_dropout_add_layer_norm_cast(x0, at::kFloat);
    if (rowscale.defined()) {
        // rowscale has x0's leading dims; a trailing unit dim broadcasts it over last_dim.
        scaled_x0 = scaled_x0.mul(
            npu_dropout_add_layer_norm_cast(rowscale, at::kFloat).unsqueeze(-1));
    }
    if (layerscale.defined()) {
        // One scale per channel, broadcast over all leading dims.
        scaled_x0 = scaled_x0.mul(
            npu_dropout_add_layer_norm_cast(layerscale, at::kFloat).view({last_dim}));
    }

    const auto mask_options = scaled_x0.options().dtype(at::kBool);
    at::Tensor dropout_result = scaled_x0;
    at::Tensor mask;
    if (p != 0.0) {
        const double keep_prob = 1. - p;
        // Inverted dropout: kept values are rescaled by 1 / keep_prob. With p == 1 nothing
        // is kept, so the scale is 0 rather than a division by zero.
        const double scale = std::abs(keep_prob) < (0 + EPSILON) ? 0. : 1. / keep_prob;
        mask = at::empty_like(scaled_x0, mask_options);
        mask.bernoulli_(keep_prob);
        dropout_result = scaled_x0.mul(mask).mul_(scale);
    } else if (return_dropout_mask) {
        // No dropout: every element is kept. The all-true mask is only built if returned.
        mask = at::ones_like(scaled_x0, mask_options);
    }

    const at::Tensor pre_norm = residual.defined() ? dropout_result.add(residual) : dropout_result;
    const int64_t hidden_size = weight.numel();
    at::Tensor norm_result;
    if (!is_rms_norm) {
        // `hidden_size` is an int64_t lvalue, so the IntArrayRef refers to it directly.
        norm_result = std::get<0>(
            at::native_layer_norm(pre_norm, hidden_size, weight_fp32, bias, eps));
    } else {
        const float inverse_cols = 1.f / static_cast<float>(hidden_size);
        const at::Tensor mean_square_eps =
            pre_norm.mul(pre_norm).sum(-1, /*keepdim=*/true).mul(inverse_cols).add(eps);
        // NOTE(review): bias is not applied in the RMS-norm path; confirm this is intended.
        norm_result = pre_norm.mul(mean_square_eps.pow(-0.5)).mul(weight_fp32.view({last_dim}));
    }

    norm_result = npu_dropout_add_layer_norm_cast(norm_result, x0_dtype);
    const at::Tensor pre_norm_result =
        prenorm ? npu_dropout_add_layer_norm_cast(pre_norm, residual_dtype) : at::Tensor();
    const at::Tensor mask_result = return_dropout_mask ? mask : at::Tensor();
    return std::make_tuple(norm_result, pre_norm_result, mask_result);
}
// Python bindings for this extension: exposes the forward op under the same name.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def(
        "npu_dropout_add_layer_norm",
        &npu_dropout_add_layer_norm,
        "npu_dropout_add_layer_norm forward");
}