#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
#include "op_plugin/utils/OpAdapter.h"
namespace op_api {
// Softmax of `self` along the integer dimension `dim`.
//
// The DO_COMPATIBILITY macro is defined elsewhere. Presumably it routes to
// acl_op::softmax when the aclnnSoftmax kernel is not available; confirm
// against the macro's definition.
//
// When `dtype` is set and differs from self's scalar type, `self` is first
// cast to it with _npu_dtype_cast. The softmax itself is delegated to
// at::_softmax with half_to_float = false. It runs with named-tensor handling
// suspended, and the names of `self` are propagated onto the result afterwards.
at::Tensor softmax(const at::Tensor& self, int64_t dim, c10::optional<at::ScalarType> dtype)
{
    DO_COMPATIBILITY(aclnnSoftmax, acl_op::softmax(self, dim, dtype));
    at::Tensor output;
    {
        // This scope bounds the guard. It is released before name propagation
        // below, which must see the names of `self`.
        at::NoNamesGuard no_names;
        const bool needs_cast = dtype.has_value() && dtype.value() != self.scalar_type();
        const at::Tensor input =
            needs_cast ? at_npu::native::custom_ops::_npu_dtype_cast(self, dtype.value()) : self;
        output = at::_softmax(input, dim, false);
    }
    at::namedinference::propagate_names(output, self);
    return output;
}

// Named-dimension overload of softmax. It resolves `dim` to its positional
// index in `self` and forwards to the integer-dimension overload above.
at::Tensor softmax(const at::Tensor& self, at::Dimname dim, c10::optional<at::ScalarType> dtype)
{
    const int64_t position = dimname_to_position(self, dim);
    return op_api::softmax(self, position, dtype);
}
}