#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/utils/OpAdapter.h"
namespace acl_op {
// Forward of dropout for the NPU backend.
//
// Returns (output, mask). Several cases are resolved on the host so that the
// NPU dropout kernel is only launched when elements may actually be dropped:
//   * empty input          -> clone of the (empty) input, mask allocated like input
//   * p == 0 or eval mode  -> clone of input, all-true bool mask
//   * p == 1               -> all-zero output, all-false bool mask
//   * otherwise            -> acl_op::_npu_dropout(input, p)
//
// @param input  tensor to apply dropout to.
// @param p      drop probability. NOTE(review): not range-checked here;
//               presumably validated by the caller to lie in [0, 1].
// @param train  training flag; an absent value is treated as training.
// @return tuple of (output, mask). The mask from _npu_dropout may use a different
//         representation than the bool masks of the shortcut paths — TODO confirm
//         against the _npu_dropout kernel.
std::tuple<at::Tensor, at::Tensor> native_dropout(
    const at::Tensor& input,
    double p,
    c10::optional<bool> train)
{
    if (input.numel() == 0) {
        // Clone (cheap for an empty tensor) so the output never aliases the input,
        // matching the p == 0 / eval path below.
        // NOTE(review): this mask takes input's dtype rather than kBool.
        return std::make_tuple(input.clone(), at::empty_like(input, input.options()));
    }
    const bool dropout_train = train.value_or(true);
    const at::TensorOptions options = input.options();
    if (p == 0.0 || !dropout_train) {
        // Nothing is dropped: identity output and an all-true mask.
        at::Tensor mask = acl_op::ones(
            input.sizes(),
            at::kBool,
            options.layout(),
            options.device(),
            options.pinned_memory());
        return std::make_tuple(input.clone(), mask);
    }
    if (p == 1.0) {
        // Everything is dropped: zero output and an all-false mask.
        at::Tensor output = at::zeros(input.sizes(), options);
        at::Tensor mask = at::zeros(input.sizes(), options.dtype(at::kBool));
        return std::make_tuple(output, mask);
    }
    return acl_op::_npu_dropout(input, p);
}

// Backward of dropout for the NPU backend.
//
// @param grad_output  gradient flowing into the dropout output.
// @param mask         mask produced by native_dropout.
// @param scale        forward scaling factor; the drop probability is recovered
//                     as p = 1 - 1 / scale, with scale == 0 treated as p == 1.
// @return gradient w.r.t. the dropout input; always a tensor distinct from
//         grad_output (never an alias of it).
at::Tensor native_dropout_backward(
    const at::Tensor& grad_output,
    const at::Tensor& mask,
    double scale)
{
    // scale == 0 cannot be inverted; it is interpreted as "everything dropped".
    const double p = (scale == 0.0) ? 1.0 : (1.0 - 1.0 / scale);
    const at::TensorOptions options = grad_output.options();
    if (mask.numel() == 0) {
        // Nothing to compute: return an (uninitialized) tensor shaped like the mask.
        return at_npu::native::OpPreparation::apply_tensor_without_format(mask.sizes(), options);
    }
    if (p == 0.0) {
        // The mask is not consulted here; this assumes it is all-true, as
        // native_dropout produces for p == 0 / eval mode. Clone instead of
        // returning grad_output itself: the op's result must not alias its input.
        return grad_output.clone();
    }
    if (p == 1.0) {
        return at::zeros(grad_output.sizes(), options);
    }
    return acl_op::npu_dropout_backward(grad_output, mask, p);
}
}