#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/utils/OpAdapter.h"
#include "op_plugin/utils/custom_functions/aclops/inner_compute.h"
namespace acl_op {
// Short aliases for the NPU helper classes referenced by the kernels in this file.
using npu_preparation = at_npu::native::OpPreparation;
using npu_utils = at_npu::native::NpuUtils;
namespace {
// Computes the shape of the index tensor produced for `condition`:
// {number of nonzero elements, condition.dim()}.
//
// The count is obtained by casting to Bool (so every nonzero value becomes
// exactly 1), then to Int, and summing. Reading it back with item() yields a
// host-side scalar.
// NOTE(review): the sum is accumulated and read as a 32-bit Int, so the count
// is limited to that range.
at::SmallVector<int64_t, SIZE> where_npu_output_size(const at::Tensor& condition) {
    const at::Tensor as_bool =
        at_npu::native::custom_ops::_npu_dtype_cast(condition, at::ScalarType::Bool);
    const at::Tensor as_int =
        at_npu::native::custom_ops::_npu_dtype_cast(as_bool, at::ScalarType::Int);
    const int64_t nonzero_count = at::sum(as_int, at::ScalarType::Int).item().toInt();
    return {nonzero_count, condition.dim()};
}
}
// Returns the indices of the nonzero elements of `condition` as one 1-D Long
// tensor per dimension of `condition`.
//
// Steps:
//   1. Bring `condition` to ND format, and promote Half to Float, before it is
//      handed to the "NonZero" operator.
//   2. Allocate a Long tensor of shape {nonzero_count, dim} and run "NonZero".
//   3. Transpose to {dim, nonzero_count} and split it into `dim` rows.
// NOTE(review): for a 0-dim `condition` the transposed tensor has size(0) == 0,
// so chunk() is asked for 0 chunks — confirm whether that input can reach here.
std::vector<at::Tensor> where(const at::Tensor& condition) {
    at::Tensor nd_condition = condition;
    if (npu_preparation::get_tensor_npu_format(condition) != ACL_FORMAT_ND) {
        nd_condition = at_npu::native::custom_ops::npu_format_cast(nd_condition, ACL_FORMAT_ND);
    }
    if (condition.scalar_type() == at::ScalarType::Half) {
        nd_condition = at_npu::native::custom_ops::_npu_dtype_cast(nd_condition, at::ScalarType::Float);
    }

    // Output shape is {nonzero_count, dim}; computing it needs a device-side reduction.
    const auto index_shape = where_npu_output_size(nd_condition);
    at::Tensor indices = npu_preparation::apply_tensor_with_format(
        index_shape, nd_condition.options().dtype(at::kLong), ACL_FORMAT_ND);

    at_npu::native::OpCommand nonzero_cmd;
    nonzero_cmd.Name("NonZero")
        .Input(nd_condition)
        .Output(indices)
        .Run();

    // Put the dimension axis first so that each row holds the indices of one dimension.
    indices = indices.transpose(1, 0);
    const std::vector<at::Tensor> rows = indices.chunk(indices.size(0), 0);

    std::vector<at::Tensor> per_dim_indices;
    for (const at::Tensor& row : rows) {
        // Each chunk has a leading axis of length 1; drop it to get a 1-D tensor.
        per_dim_indices.push_back(row.squeeze(0));
    }
    return per_dim_indices;
}
// Element-wise select written into a caller-provided tensor: the selection
// itself is delegated to where_out_nocheck(out, condition, self, other).
//
// Parameters:
//   condition, self, other - inputs; they are expanded together with
//                            npu_expand_outplace before the computation.
//   out                    - destination tensor, checked against the expanded
//                            `self` via OpPreparation::CheckOut.
// Returns: `out`.
//
// The expanded tensors (b_condition / b_self / b_other) are the ones passed to
// where_out_nocheck, matching the out-of-place where(condition, self, other)
// overload and the tensor `out` was checked against.
at::Tensor& where_out(
    const at::Tensor& condition,
    const at::Tensor& self,
    const at::Tensor& other,
    at::Tensor& out) {
    at::Tensor b_condition;
    at::Tensor b_self;
    at::Tensor b_other;
    std::tie(b_condition, b_self, b_other) = npu_expand_outplace(condition, self, other, "where_npu");
    npu_preparation::CheckOut(
        {condition, self, other},
        out,
        b_self);
    if (!npu_utils::check_match(&out)) {
        // `out` cannot be written directly: compute into a contiguous copy and
        // refresh `out` from it afterwards.
        at::Tensor contiguous_out = npu_utils::format_contiguous(out);
        where_out_nocheck(contiguous_out, b_condition, b_self, b_other);
        npu_utils::format_fresh_view(out, contiguous_out);
    } else {
        where_out_nocheck(out, b_condition, b_self, b_other);
    }
    return out;
}
// Out-of-place element-wise select: expands `condition`, `self` and `other`
// together, allocates a result tensor based on the expanded `self`, and fills
// it through where_out_nocheck.
at::Tensor where(
    const at::Tensor& condition,
    const at::Tensor& self,
    const at::Tensor& other) {
    const auto expanded = npu_expand_outplace(condition, self, other, "where_npu");
    const at::Tensor& expanded_condition = std::get<0>(expanded);
    const at::Tensor& expanded_self = std::get<1>(expanded);
    const at::Tensor& expanded_other = std::get<2>(expanded);

    at::Tensor output = npu_preparation::apply_tensor(expanded_self);
    where_out_nocheck(output, expanded_condition, expanded_self, expanded_other);
    return output;
}
}