#include "torch_npu/csrc/aten/CustomFunctions.h"
#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
namespace op_api {
/**
 * Out-of-place dropout entry point.
 *
 * @param self  Input tensor.
 * @param p     Dropout probability; the values 0 and 1 are handled here without
 *              calling the NPU dropout op.
 * @param train When false, the input is returned as-is.
 * @return `self` when there is nothing to do (p == 0, train == false or an
 *         empty tensor); `self` multiplied by a zeros tensor when p == 1;
 *         otherwise the first element of the tuple returned by `_npu_dropout`.
 */
at::Tensor dropout(const at::Tensor& self, double p, bool train)
{
    // Fast path: none of these cases needs any computation.
    const bool nothing_to_drop = !train || p == 0 || self.numel() == 0;
    if (nothing_to_drop) {
        return self;
    }

    if (p == 1) {
        // The result is produced by multiplying with zeros rather than by
        // returning a fresh zeros tensor directly.
        // NOTE(review): presumably this keeps NaN/Inf propagation and the
        // autograd graph intact -- confirm.
        const at::Tensor zeros = at::zeros(self.sizes(), self.options());
        return self.mul(zeros);
    }

    // Only element 0 of the returned tuple is used by this function.
    const auto dropout_outputs = at_npu::native::custom_ops::_npu_dropout(self, p);
    return std::get<0>(dropout_outputs);
}
/**
 * In-place dropout entry point.
 *
 * @param self  Tensor that is modified in place and returned.
 * @param p     Dropout probability; the values 0 and 1 are handled here without
 *              calling the NPU dropout op.
 * @param train When false, `self` is returned untouched.
 * @return Reference to `self`: untouched when p == 0, train == false or the
 *         tensor is empty; multiplied in place by a zeros tensor when p == 1;
 *         otherwise overwritten with the first element of the tuple returned
 *         by `_npu_dropout`.
 */
at::Tensor& dropout_(at::Tensor& self, double p, bool train)
{
    // Dropout only does work in training mode, with a non-zero probability,
    // on a tensor that actually holds elements.
    const bool dropout_active = train && p != 0 && self.numel() != 0;
    if (!dropout_active) {
        return self;
    }

    if (p == 1) {
        // In-place multiply by a zeros tensor built from self's sizes/options.
        const at::Tensor zeros = at::zeros(self.sizes(), self.options());
        return self.mul_(zeros);
    }

    // The op's first tuple element is copied back into `self` so the caller's
    // tensor is the one that gets updated.
    const auto dropout_outputs = at_npu::native::custom_ops::_npu_dropout(self, p);
    self.copy_(std::get<0>(dropout_outputs));
    return self;
}
}