#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/OpApiInterface.h"
#include "op_plugin/utils/op_api_common.h"
namespace {
#if VERSION_BETWEEN(V2R1, VERSION_NEWEST)
/**
 * Sort-based isin: flags each entry of `elements` that also occurs in
 * `test_elements` (or that does NOT occur, when `invert` is true) and copies
 * the boolean flags into `out`.
 *
 * Outline:
 *   1. Flatten both inputs; unless `assume_unique`, deduplicate them with
 *      at::_unique and keep the inverse index for `elements`.
 *   2. Concatenate (elements first, test elements second) and stable-sort.
 *   3. Flag each sorted slot whose successor compares equal (ne when inverting).
 *   4. Scatter the flags back into pre-sort (concatenation) order.
 *   5. Map the flags back onto the original layout of `elements`.
 *
 * @param elements       Values whose membership is tested.
 * @param test_elements  Values tested against.
 * @param assume_unique  If true, the _unique calls are skipped; the caller
 *                       guarantees neither input repeats a value, otherwise
 *                       the adjacent-duplicate test can misreport.
 * @param invert         If true, report non-membership instead of membership.
 * @param out            Bool destination, written with copy_.
 */
static void isin_sorting(const at::Tensor& elements,
    const at::Tensor& test_elements,
    bool assume_unique,
    bool invert,
    const at::Tensor& out)
{
    // Step 1: flattened (and, unless assume_unique, deduplicated) inputs.
    at::Tensor lhs_flat;
    at::Tensor rhs_flat;
    at::Tensor inverse_index;  // maps each original element to its unique slot; unused when assume_unique
    if (!assume_unique) {
        std::tie(lhs_flat, inverse_index) = at::_unique(elements, false, true);
        std::tie(rhs_flat, std::ignore) = at::_unique(test_elements, false);
    } else {
        lhs_flat = elements.ravel();
        rhs_flat = test_elements.ravel();
    }

    // Step 2: the sort is stable (first argument), so an `elements` value stays
    // ahead of an equal `test_elements` value, since it comes first in the concatenation.
    at::Tensor combined = at::cat({std::move(lhs_flat), std::move(rhs_flat)});
    at::Tensor sorted_vals;
    at::Tensor sort_perm;
    std::tie(sorted_vals, sort_perm) = combined.sort(true, 0, false);

    // Step 3: compare every sorted value with its right-hand neighbour.
    at::Tensor sorted_mask = at::empty_like(sorted_vals, at::TensorOptions(at::ScalarType::Bool));
    at::Tensor next_vals = sorted_vals.slice(0, 1, at::indexing::None);
    at::Tensor curr_vals = sorted_vals.slice(0, 0, -1);
    at::Tensor neighbour_flags = invert ? next_vals.ne(curr_vals) : next_vals.eq(curr_vals);
    sorted_mask.slice(0, 0, -1).copy_(neighbour_flags);
    // The last sorted slot has no successor; it gets the "not found" answer.
    sorted_mask.index_put_({-1}, invert);

    // Step 4: undo the sort permutation so flags line up with `combined`.
    at::Tensor unsorted_mask = at::empty_like(sorted_mask);
    unsorted_mask.index_copy_(0, sort_perm, sorted_mask);

    // Step 5: project the flags back onto the layout of `elements`.
    if (!assume_unique) {
        out.copy_(at::index(unsorted_mask, {c10::optional<at::Tensor>(inverse_index)}));
        return;
    }
    // Without deduplication, the first numel() slots are exactly `elements`.
    out.copy_(unsorted_mask.slice(0, 0, elements.numel()).view_as(out));
}
/**
 * Brute-force isin used when `test_elements` is small: broadcast-compares
 * every entry of `elements` against every entry of `test_elements`.
 *
 * `elements` gets a trailing size-1 axis, and `test_elements` is flattened to
 * shape [1, ..., 1, -1] (one leading 1 per dim of `elements`). The comparison
 * therefore materializes an [*elements.sizes(), test_elements.numel()] bool
 * tensor, which is reduced over its last axis (any / all).
 *
 * @param elements       Values whose membership is tested.
 * @param test_elements  Values tested against; any memory layout is accepted.
 * @param invert         If true, an entry is true when it matches none of the
 *                       test values; otherwise true when it matches any.
 * @param out            Bool destination; the reduced result (shape of
 *                       `elements`) is copied over it with copy_.
 */
void isin_default_kernel_npu(const at::Tensor& elements,
    const at::Tensor& test_elements,
    bool invert,
    const at::Tensor& out)
{
    std::vector<int64_t> bc_shape(elements.dim(), 1);
    bc_shape.push_back(-1);
    // reshape, not view: view cannot flatten a non-contiguous tensor (e.g. a
    // transposed or strided-slice `test_elements`) and would throw. reshape
    // returns the same view when possible and copies only when it must.
    const at::Tensor test_bc = test_elements.reshape(bc_shape);
    const at::Tensor elements_bc = elements.unsqueeze(-1);
    out.copy_(invert ? elements_bc.ne(test_bc).all(-1)
                     : elements_bc.eq(test_bc).any(-1));
}
/**
 * Shared implementation behind isin and isin_out: writes into `out` whether
 * each entry of `elements` occurs in `test_elements` (negated when `invert`).
 *
 * Chooses between two algorithms:
 *  - a brute-force broadcast comparison when `test_elements` has fewer than
 *    floor(10 * numel(elements)^0.145) entries;
 *  - the sort-based algorithm otherwise.
 * NOTE(review): the 10 * n^0.145 cut-over appears to mirror the heuristic
 * used by upstream ATen / NumPy's in1d; confirm before retuning for NPU.
 *
 * @param elements       Values whose membership is tested.
 * @param test_elements  Values tested against.
 * @param assume_unique  Forwarded to the sort-based path only.
 * @param invert         If true, report non-membership instead of membership.
 * @param out            Bool output; left untouched when `elements` is empty.
 */
void isin_Tensor_Tensor_out_impl(const at::Tensor& elements,
    const at::Tensor& test_elements,
    bool assume_unique,
    bool invert,
    const at::Tensor& out)
{
    // Nothing to compute for an empty input.
    if (elements.numel() == 0) {
        return;
    }
    const auto brute_force_limit = static_cast<int64_t>(
        10.0 * std::pow(static_cast<double>(elements.numel()), 0.145));
    if (test_elements.numel() < brute_force_limit) {
        // No fill_(invert) beforehand: the brute-force kernel copy_'s a complete
        // result over every element of `out`, so a pre-fill would be an extra,
        // immediately overwritten device kernel launch.
        isin_default_kernel_npu(elements, test_elements, invert, out);
    } else {
        isin_sorting(elements, test_elements, assume_unique, invert, out);
    }
}
#endif
}
namespace op_api {
using npu_preparation = at_npu::native::OpPreparation;
#if VERSION_BETWEEN(V2R1, VERSION_NEWEST)
/**
 * isin.Tensor_Tensor_out: tests membership of each entry of `elements` in
 * `test_elements` and writes the flags into the caller-supplied `result`.
 *
 * `result` is first passed to OpPreparation::check_tensor with dtype kBool and
 * the sizes of `elements` (presumably validating/resizing it; see check_tensor).
 *
 * @return `result`, for call chaining.
 */
at::Tensor& isin_out(const at::Tensor& elements, const at::Tensor &test_elements,
    bool assume_unique, bool invert, at::Tensor& result)
{
    const auto expected_sizes = elements.sizes();
    npu_preparation::check_tensor({elements, test_elements}, result, at::kBool, expected_sizes);
    isin_Tensor_Tensor_out_impl(elements, test_elements, assume_unique, invert, result);
    return result;
}
/**
 * isin.Tensor_Tensor: returns a new Bool tensor shaped like `elements` whose
 * entries say whether the matching entry of `elements` occurs in
 * `test_elements` (negated when `invert`).
 *
 * The output is allocated with apply_tensor_without_format using the options
 * of `elements` with dtype switched to kBool.
 */
at::Tensor isin(const at::Tensor& elements, const at::Tensor &test_elements,
    bool assume_unique, bool invert)
{
    const auto bool_options = elements.options().dtype(at::kBool);
    at::Tensor flags = npu_preparation::apply_tensor_without_format(elements.sizes(), bool_options);
    isin_Tensor_Tensor_out_impl(elements, test_elements, assume_unique, invert, flags);
    return flags;
}
#endif
}