#ifndef OP_PLUGIN_UTILS_SEARCHSORTED_WARN_UTIL_H_
#define OP_PLUGIN_UTILS_SEARCHSORTED_WARN_UTIL_H_
#include <ATen/core/Tensor.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
namespace op_plugin {

/// Returns true when `t` is laid out in C (row-major) contiguous order.
///
/// The dimensions are walked from innermost to outermost. Every dimension
/// whose extent is not 1 must have a stride equal to the product of the
/// extents of all dimensions inside it.
///
/// These cases are reported as contiguous:
/// - undefined tensors,
/// - empty tensors (numel() == 0),
/// - 0-dim tensors.
///
/// For non-strided layouts the answer comes from Tensor::is_contiguous().
///
/// @param t  Tensor to inspect (may be undefined).
/// @return   true if `t` is already row-major contiguous, false otherwise.
inline bool searchsorted_tensor_is_row_major_contiguous(const at::Tensor &t) {
    if (!t.defined() || t.numel() == 0) {
        return true;
    }
    if (t.layout() != c10::Layout::Strided) {
        return t.is_contiguous();
    }
    // numel() != 0 here, so every extent is >= 1. A 0-dim tensor skips the loop
    // and is reported contiguous.
    int64_t expected_stride = 1;
    for (int64_t d = t.dim() - 1; d >= 0; --d) {
        const int64_t extent = t.size(d);
        // An extent-1 dimension is never stepped over, so its stride is irrelevant
        // and it does not change the expected stride of outer dimensions.
        if (extent == 1) {
            continue;
        }
        if (t.stride(d) != expected_stride) {
            return false;
        }
        expected_stride *= extent;
    }
    return true;
}

/// Emits performance warnings for the scalar-value overload of
/// torch.searchsorted().
///
/// A warning is issued when the boundary tensor, or the sorter tensor if one
/// is present and defined, is not row-major contiguous.
///
/// The tensor-value overload below reuses this function. That way each
/// boundary/sorter warning has a single TORCH_WARN_ONCE call site.
///
/// @param sorted_sequence  The boundary tensor.
/// @param sorter_opt       Optional sorter tensor; skipped when empty or undefined.
/// @return Always 0 (the int return type is kept for existing callers).
inline int warn_if_searchsorted_scalar_inputs_noncontiguous(
    const at::Tensor &sorted_sequence, const c10::optional<at::Tensor> &sorter_opt) {
    if (!searchsorted_tensor_is_row_major_contiguous(sorted_sequence)) {
        TORCH_WARN_ONCE(
            "torch.searchsorted(): boundary tensor is non-contiguous, this will lower the performance due "
            "to extra data copy when converting non-contiguous tensor to contiguous, please use contiguous boundary "
            "tensor if possible. This message will only appear once per program.");
    }
    if (sorter_opt.has_value()) {
        const at::Tensor &st = *sorter_opt;
        if (st.defined() && !searchsorted_tensor_is_row_major_contiguous(st)) {
            TORCH_WARN_ONCE(
                "torch.searchsorted(): sorter tensor is non-contiguous, this will lower the performance due "
                "to extra data copy when converting non-contiguous tensor to contiguous, please use contiguous sorter "
                "tensor if possible. This message will only appear once per program.");
        }
    }
    return 0;
}

/// Emits performance warnings for the tensor-value overload of
/// torch.searchsorted().
///
/// A warning is issued when any of these tensors is not row-major contiguous:
/// - the input value tensor (`self`),
/// - the boundary tensor,
/// - the optional sorter tensor.
///
/// Warnings appear in that order.
///
/// @param sorted_sequence  The boundary tensor.
/// @param self             The input value tensor.
/// @param sorter_opt       Optional sorter tensor; skipped when empty or undefined.
/// @return Always 0 (the int return type is kept for existing callers).
inline int warn_if_searchsorted_inputs_noncontiguous(
    const at::Tensor &sorted_sequence, const at::Tensor &self, const c10::optional<at::Tensor> &sorter_opt) {
    if (!searchsorted_tensor_is_row_major_contiguous(self)) {
        TORCH_WARN_ONCE(
            "torch.searchsorted(): input value tensor is non-contiguous, this will lower the performance due "
            "to extra data copy when converting non-contiguous tensor to contiguous, please use contiguous input value "
            "tensor if possible. This message will only appear once per program.");
    }
    // TORCH_WARN_ONCE tracks "already warned" per expansion site. Delegating
    // (instead of duplicating the boundary/sorter checks) shares those sites with
    // the scalar overload, so each message really appears at most once per program.
    return warn_if_searchsorted_scalar_inputs_noncontiguous(sorted_sequence, sorter_opt);
}

}  // namespace op_plugin
#endif