#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
namespace op_api {
using npu_preparation = at_npu::native::OpPreparation;
// Rejects input dtypes that the aclnn round-with-decimals path does not accept.
//
// Accepted dtypes: any dtype for which isFloatingType() is true, plus Int and Long.
// Any other dtype fails the TORCH_CHECK with an ErrCode::TYPE error that names
// the offending dtype.
//
// The second argument (`decimals`) is not inspected here. It is left unnamed to
// avoid an unused-parameter warning.
static void round_decimals_check(const at::Tensor& self, int64_t /* decimals */)
{
    const auto dtype = self.scalar_type();
    const bool is_supported_integral =
        (dtype == at::ScalarType::Int) || (dtype == at::ScalarType::Long);
    TORCH_CHECK(isFloatingType(dtype) || is_supported_integral,
        "\"round_npu\" not implemented for '", toString(dtype), "'",
        OPS_ERROR(ErrCode::TYPE));
}
// Out-variant of round(self, decimals): writes into `result` and returns it.
//
// The DO_COMPATIBILITY line must stay first. It dispatches to
// acl_op::round_out when aclnnRoundDecimals cannot be used (the exact condition
// lives in the DO_COMPATIBILITY macro — confirm there).
at::Tensor& round_out(const at::Tensor& self, int64_t decimals, at::Tensor& result)
{
    DO_COMPATIBILITY(aclnnRoundDecimals, acl_op::round_out(self, decimals, result));

    round_decimals_check(self, decimals);
    // Validate/prepare `result` against `self` before launching the kernel
    // (presumably dtype + shape; see OpPreparation::check_tensor).
    npu_preparation::check_tensor({self}, result, self);

    EXEC_NPU_CMD(aclnnRoundDecimals, self, decimals, result);
    return result;
}
// Functional round(self, decimals): returns a freshly allocated tensor holding
// `self` rounded by the aclnnRoundDecimals kernel.
//
// The DO_COMPATIBILITY line must stay first. It dispatches to acl_op::round
// when aclnnRoundDecimals cannot be used (the condition lives in the macro).
at::Tensor round(const at::Tensor& self, int64_t decimals)
{
    DO_COMPATIBILITY(aclnnRoundDecimals, acl_op::round(self, decimals));

    round_decimals_check(self, decimals);
    // Output buffer derived from `self` (presumably same sizes/options, without
    // an NPU-private format — see OpPreparation::apply_tensor_without_format).
    at::Tensor out = npu_preparation::apply_tensor_without_format(self);

    EXEC_NPU_CMD(aclnnRoundDecimals, self, decimals, out);
    return out;
}
// In-place round_(self, decimals): rounds `self` through
// aclnnInplaceRoundDecimals and returns `self`.
//
// The DO_COMPATIBILITY line must stay first. It dispatches to acl_op::round_
// when aclnnInplaceRoundDecimals cannot be used (the condition lives in the macro).
at::Tensor& round_(at::Tensor& self, int64_t decimals)
{
    DO_COMPATIBILITY(aclnnInplaceRoundDecimals, acl_op::round_(self, decimals));

    round_decimals_check(self, decimals);
    EXEC_NPU_CMD(aclnnInplaceRoundDecimals, self, decimals);
    return self;
}
}