#ifndef OP_PLUGIN_UTILS_SEARCHSORTED_VALIDATE_UTIL_H_
#define OP_PLUGIN_UTILS_SEARCHSORTED_VALIDATE_UTIL_H_
#include <climits>
#include <string>
#include <tuple>
#include <ATen/core/Tensor.h>
#include <c10/core/Scalar.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <c10/util/string_view.h>
namespace op_plugin {
inline bool searchsorted_dims_matched_before_last_dim(const at::Tensor &boundaries, const at::Tensor &input) {
if (boundaries.dim() != input.dim()) {
return false;
}
const auto &dims_bd = boundaries.sizes();
const auto &dims_in = input.sizes();
for (int64_t dim = 0; dim + 1 < boundaries.dim(); ++dim) {
if (dims_bd[dim] != dims_in[dim]) {
return false;
}
}
return true;
}
inline void searchsorted_pre_check_npu(const at::Tensor &boundaries, const at::Tensor &input, const at::Tensor *output,
bool out_int32, bool right, const c10::optional<c10::string_view> &side_opt,
const c10::optional<at::Tensor> &sorter_opt) {
if (side_opt.has_value()) {
c10::string_view side = *side_opt;
TORCH_CHECK(side == "left" || side == "right",
"torch.searchsorted(): side can only be 'left' or 'right' but got ", std::string(side.data(), side.size()));
TORCH_CHECK(!right || side == "right",
"torch.searchsorted(): side and right can't be set to opposites, got side of ",
std::string(side.data(), side.size()), " while right was True");
}
TORCH_CHECK(boundaries.device() == input.device(),
"torch.searchsorted(): boundaries and input value tensors should have same device type, but got "
"boundaries tensor device type ",
boundaries.device(), " and input value tensor device type ", input.device());
if (sorter_opt.has_value()) {
const at::Tensor &sorter = *sorter_opt;
TORCH_CHECK(
sorter.defined(), "torch.searchsorted(): optional sorter was set but the sorter tensor is undefined");
TORCH_CHECK(sorter.device() == boundaries.device(),
"torch.searchsorted(): sorter and boundary tensors should have same device type, but got sorter tensor "
"device type ",
sorter.device(), " and input value tensor device type ", boundaries.device());
TORCH_CHECK(sorter.sizes() == boundaries.sizes(),
"torch.searchsorted(): boundary and sorter must have the same size, but got boundary tensor ",
boundaries.sizes(), "and got sorter tensor ", sorter.sizes());
TORCH_CHECK(sorter.scalar_type() == at::ScalarType::Long,
"torch.searchsorted(): sorter must be a tensor of long dtype but got dtype ", sorter.scalar_type());
if (sorter.numel() > 0) {
auto minmax = sorter.aminmax();
int64_t vmin = std::get<0>(minmax).item().toLong();
int64_t vmax = std::get<1>(minmax).item().toLong();
TORCH_CHECK(vmin >= 0 && vmax < sorter.sizes().back(), "torch.searchsorted(): sorter index out of range");
}
}
TORCH_CHECK(input.dim() > 0 || (input.dim() == 0 && input.numel() == 1 && boundaries.dim() == 1),
"torch.searchsorted(): input value can be a scalar only when boundaries tensor dimension is 1, but we "
"got boundaries tensor dim(",
boundaries.dim(), ") and input value's dim(", input.dim(), ") numel(", input.numel(), ")");
if (boundaries.dim() != 1 && input.numel() == 1 && input.dim() != boundaries.dim()) {
TORCH_CHECK(false,
"torch.searchsorted(): input value can be a scalar only when boundaries tensor dimension is 1, but we "
"got boundaries tensor dim(",
boundaries.dim(), ") and input value's dim(", input.dim(), ") numel(", input.numel(), ")");
}
TORCH_CHECK(boundaries.dim() != 0,
"torch.searchsorted(): boundaries tensor should have positive dimension, but got 0 dimension");
TORCH_CHECK(boundaries.dim() == 1 || searchsorted_dims_matched_before_last_dim(boundaries, input),
"torch.searchsorted(): boundaries tensor should be 1 dimension or the first N-1 dimensions of boundaries "
"tensor and input value tensor must match, but we got boundaries tensor ",
boundaries.sizes(), " and input value tensor ", input.sizes());
if (output != nullptr) {
at::ScalarType output_dtype = output->scalar_type();
TORCH_CHECK(
(output_dtype == at::ScalarType::Long && !out_int32) || (output_dtype == at::ScalarType::Int && out_int32),
"torch.searchsorted(): output tensor's dtype is wrong, it can only be Int(int32) or Long(int64) "
"depending on whether out_int32 flag is True, but we got output tensor's dtype ",
output_dtype, " and out_int32 flag is ", (out_int32 ? "True" : "False"));
}
if (out_int32) {
TORCH_CHECK(boundaries.sizes().back() < INT_MAX,
"torch.searchsorted(): the size of boundaries' last dimension should be less than ", INT_MAX,
", but we got ", boundaries.sizes().back());
}
}
inline int searchsorted_validate_core_no_output(const at::Tensor &sorted_sequence, const at::Tensor &self,
bool out_int32, bool right, const c10::optional<c10::string_view> &side_opt,
const c10::optional<at::Tensor> &sorter_opt) {
searchsorted_pre_check_npu(sorted_sequence, self, nullptr, out_int32, right, side_opt, sorter_opt);
return 0;
}
inline int searchsorted_validate_tensor_out_op(const at::Tensor &sorted_sequence, const at::Tensor &self,
const at::Tensor &out, bool out_int32, bool right, const c10::optional<c10::string_view> &side_opt,
const c10::optional<at::Tensor> &sorter_opt) {
searchsorted_pre_check_npu(sorted_sequence, self, &out, out_int32, right, side_opt, sorter_opt);
return 0;
}
inline int searchsorted_validate_scalar_op(const at::Tensor &sorted_sequence, const c10::Scalar &self, bool out_int32,
bool right, const c10::optional<c10::string_view> &side_opt, const c10::optional<at::Tensor> &sorter_opt) {
at::Tensor input_t = at::empty({}, sorted_sequence.options());
input_t.fill_(self);
input_t.unsafeGetTensorImpl()->set_wrapped_number(true);
searchsorted_pre_check_npu(sorted_sequence, input_t, nullptr, out_int32, right, side_opt, sorter_opt);
return 0;
}
}
#endif