#include <memory>

#include "op_plugin/OpInterface.h"
#include "op_plugin/utils/OpAdapter.h"
#include "op_plugin/utils/op_api_common.h"
#include "op_plugin/utils/AdvancedIndex.h"
namespace op_plugin {
namespace {
// Expands every defined tensor in `to_expand` to the shape obtained by folding
// at::infer_size over the shapes of all defined entries. Undefined entries are
// left as default-constructed (undefined) tensors in the returned vector.
// A tensor that needs expanding and has dtype at::kLong is converted to
// at::kInt before the expand; tensors already at the common shape are returned
// untouched (no dtype conversion).
std::vector<at::Tensor> npu_expand_outplace(at::TensorList to_expand)
{
    // Pass 1: accumulate the common broadcast shape over the defined tensors.
    std::vector<int64_t> common_shape;
    bool have_shape = false;
    for (const at::Tensor &tensor : to_expand) {
        if (!tensor.defined()) {
            continue;
        }
        if (have_shape) {
            common_shape = at::infer_size(common_shape, tensor.sizes());
        } else {
            common_shape = tensor.sizes().vec();
            have_shape = true;
        }
    }

    // Pass 2: materialise the expanded view for each defined tensor.
    const size_t count = to_expand.size();
    std::vector<at::Tensor> expanded(count);
    for (size_t idx = 0; idx < count; ++idx) {
        const at::Tensor &tensor = to_expand[idx];
        if (!tensor.defined()) {
            continue;
        }
        if (tensor.sizes().equals(common_shape)) {
            expanded[idx] = tensor;
            continue;
        }
        const bool is_long = tensor.dtype() == at::kLong;
        expanded[idx] = is_long ? tensor.to(at::kInt).expand(common_shape, true)
                                : tensor.expand(common_shape, true);
    }
    return expanded;
}
// Runs the "NonZero" operator on `self` through OpCommand.
// The result is allocated as a {self.dim(), self.numel()} tensor of dtype
// at::kLong, output index 0 is passed to Sync(), and the "transpose" attribute
// is set to true.
at::Tensor npu_nonzero_aclop(const at::Tensor &self)
{
    c10::SmallVector<int64_t, SIZE> max_shape = {self.dim(), self.numel()};
    const auto long_options = self.options().dtype(at::kLong);
    at::Tensor indices = at_npu::native::OpPreparation::apply_tensor(max_shape, long_options, self);

    c10::SmallVector<int64_t, N> sync_outputs = {0};
    at_npu::native::OpCommand cmd;
    cmd.Sync(sync_outputs)
        .Name("NonZero")
        .Input(self)
        .Output(indices)
        .Attr("transpose", true)
        .Run();
    return indices;
}
// Runs aclnnNonzeroV2 on `self` and returns the index tensor resized to the
// view shape reported by aclGetViewShape.
//
// The output is first allocated as a {self.dim(), self.numel()} at::kLong
// tensor; after the synchronous launch the actual view shape of the output
// aclTensor is queried and the tensor is resized to it.
// NOTE(review): DO_COMPATIBILITY presumably returns npu_nonzero_aclop(self)
// when aclnnNonzeroV2 is not available — confirm against the macro definition.
//
// Throws (via TORCH_CHECK) when the aclGetViewShape symbol cannot be resolved
// or when the call reports a non-zero status.
at::Tensor npu_nonzero_aclnn(const at::Tensor &self)
{
    DO_COMPATIBILITY(aclnnNonzeroV2, npu_nonzero_aclop(self));
    c10::SmallVector<int64_t, SIZE> out_size = {self.dim(), self.numel()};
    at::Tensor out =
        at_npu::native::OpPreparation::apply_tensor_without_format(out_size, self.options().dtype(at::kLong));

    // Resolve the symbol once; later calls reuse the cached address.
    static auto aclGetViewShapeAddr = []() {
        auto ret = GetOpApiFuncAddr("aclGetViewShape");
        TORCH_CHECK(ret != nullptr);
        return ret;
    }();
    using aclGetViewShapeFuncLocal = int (*)(const aclTensor* tensor, int64_t** view_dims, uint64_t* view_dims_num);
    auto aclGetViewShape = reinterpret_cast<aclGetViewShapeFuncLocal>(aclGetViewShapeAddr);

    OP_EXEC_LOG(aclnnNonzeroV2, "EXEC_NPU_CMD_SYNC", self, out);
    auto npuAclParams = EXEC_NPU_CMD_SYNC(aclnnNonzeroV2, self, out);

    int64_t* raw_view_dims = nullptr;
    uint64_t view_dim_num = 0;
    auto ret = aclGetViewShape(npuAclParams.Get<1>(), &raw_view_dims, &view_dim_num);
    // Take ownership immediately: the buffer must be released with delete[]
    // on every exit path, including when the checks or resize_ below throw.
    std::unique_ptr<int64_t[]> view_dims(raw_view_dims);
    TORCH_CHECK(ret == 0, "aclGetViewShape failed.");

    c10::SmallVector<int64_t, op_infer::SIZE> output_size(view_dims.get(), view_dims.get() + view_dim_num);
    out = out.resize_(output_size);
    return out;
}
}
// Builds the indexing description for `src` from a per-dimension list of
// index tensors (undefined entries mean "dimension is not indexed").
//
// Undefined entries seen before the first defined index are counted into
// dims_before, those seen afterwards into dims_after. For each defined entry
// the matching size and stride of `src` are recorded. The shape of the last
// defined index is used as the replacement shape handed to restride_src.
//
// Raises an index error when an indexed dimension of `src` has size 0 while
// the replacement shape contains no zero-sized dimension.
AdvancedIndex::AdvancedIndex(const at::Tensor &src, at::TensorList list_indices)
{
    int64_t leading_unindexed = 0;
    int64_t trailing_unindexed = 0;
    int64_t num_indexed = 0;
    at::IntArrayRef index_shape;

    const size_t index_count = list_indices.size();
    for (size_t dim = 0; dim < index_count; ++dim) {
        const at::Tensor &index = list_indices[dim];
        if (index.defined()) {
            ++num_indexed;
            index_shape = index.sizes();
            indexed_sizes.push_back(src.size(dim));
            indexed_strides.push_back(src.stride(dim));
            continue;
        }
        if (num_indexed == 0) {
            ++leading_unindexed;
        } else {
            ++trailing_unindexed;
        }
    }

    const bool index_has_zero = std::find(index_shape.begin(), index_shape.end(), 0) != index_shape.end();
    const bool src_has_zero = std::find(indexed_sizes.begin(), indexed_sizes.end(), 0) != indexed_sizes.end();
    if (!index_has_zero && src_has_zero) {
        TORCH_CHECK_INDEX(false, "index is out of bounds for dimension with size 0.", OPS_ERROR(ErrCode::PARAM));
    }

    this->dims_before = leading_unindexed;
    this->dims_after = trailing_unindexed;
    this->src = AdvanceIndex::restride_src(src, leading_unindexed, num_indexed, index_shape);

    // Keep only the defined indices, each padded with singleton dimensions.
    for (size_t pos = 0; pos < index_count; ++pos) {
        const at::Tensor &index = list_indices[pos];
        if (!index.defined()) {
            continue;
        }
        indices.push_back(AdvanceIndex::reshape_indexer(index, leading_unindexed, trailing_unindexed));
    }
}
// Returns true when every tensor in `tensor_list` has the same strides as the
// first one. Fails the TORCH_CHECK when the list is empty.
bool AdvanceIndex::all_strides_match(at::TensorList tensor_list)
{
    TORCH_CHECK(tensor_list.size() >= 1, OPS_ERROR(ErrCode::PARAM));
    const auto reference = tensor_list[0].strides();
    return std::all_of(tensor_list.begin() + 1, tensor_list.end(),
                       [&reference](const at::Tensor &candidate) {
                           return reference.equals(candidate.strides());
                       });
}
// Reshapes `index` to [1 x dims_before] + index.sizes() + [1 x dims_after]
// and returns it as an at::kLong tensor (converted only when the dtype is not
// already at::kLong).
at::Tensor AdvanceIndex::reshape_indexer(const at::Tensor &index, int64_t dims_before, int64_t dims_after)
{
    const auto index_shape = index.sizes();
    at::DimVector padded_shape;
    padded_shape.append(dims_before, 1);
    padded_shape.append(index_shape.begin(), index_shape.end());
    padded_shape.append(dims_after, 1);

    const bool already_long = index.dtype() == at::kLong;
    at::Tensor reshaped = index.reshape(padded_shape);
    return already_long ? reshaped : reshaped.to(at::kLong);
}
// Returns a strided view of `src` in which the `dims_indexed` dimensions
// starting at position `before_dims` are replaced by `replacement_shape`,
// each replacement dimension getting stride 0.
//
// Parameters:
//   src               tensor to re-stride.
//   before_dims       number of leading dimensions kept as they are.
//   dims_indexed      number of dimensions removed after `before_dims`.
//   replacement_shape shape inserted in place of the removed dimensions.
//
// Fails the TORCH_CHECK when before_dims + dims_indexed exceeds the number of
// dimensions of `src`.
at::Tensor AdvanceIndex::restride_src(const at::Tensor &src, int64_t before_dims, int64_t dims_indexed,
                                      at::IntArrayRef replacement_shape)
{
    auto shape = at::DimVector(src.sizes());
    auto strides = at::DimVector(src.strides());
    const int64_t end = before_dims + dims_indexed;
    // The comparison is done in size_t on purpose: a negative `end` converts
    // to a huge value and is rejected, matching the previous implicit
    // conversion while making it explicit.
    const size_t end_pos = static_cast<size_t>(end);

    TORCH_CHECK(shape.size() >= end_pos, "end ", end, " is overrange shape.size() ", shape.size(),
                OPS_ERROR(ErrCode::VALUE));
    shape.erase(shape.begin() + before_dims, shape.begin() + end);

    TORCH_CHECK(strides.size() >= end_pos, "end ", end, " is overrange strides.size() ", strides.size(),
                OPS_ERROR(ErrCode::VALUE));
    strides.erase(strides.begin() + before_dims, strides.begin() + end);

    shape.insert(shape.begin() + before_dims, replacement_shape.begin(), replacement_shape.end());
    // Stride 0 lets the inserted dimensions be iterated without moving in src.
    strides.insert(strides.begin() + before_dims, replacement_shape.size(), 0);
    return src.as_strided(shape, strides);
}
// Formats the shapes of all defined tensors in `tensors` as a single string,
// separated by ", ". Undefined tensors are skipped.
std::string AdvanceIndex::shapes_as_str(at::TensorList tensors)
{
    std::ostringstream joined;
    const char *separator = "";
    for (const at::Tensor &tensor : tensors) {
        if (!tensor.defined()) {
            continue;
        }
        joined << separator << tensor.sizes();
        separator = ", ";
    }
    return joined.str();
}
// Validates the dtypes of the defined index tensors and reports whether they
// are heterogeneous.
//
// Raises an index error if any defined index is not long, int, byte or bool.
// Returns true when at least one defined index has a dtype different from the
// first defined index, false otherwise.
bool AdvanceIndex::checkIndexTensorTypes(const torch::List<c10::optional<at::Tensor>> &indices)
{
    c10::optional<at::ScalarType> first_dtype;
    bool mixed_dtypes = false;
    for (c10::optional<at::Tensor> entry : indices) {
        if (!entry.has_value() || !entry->defined()) {
            continue;
        }
        const auto dtype = entry->scalar_type();
        const bool supported =
            dtype == at::kLong || dtype == at::kByte || dtype == at::kBool || dtype == at::kInt;
        TORCH_CHECK_INDEX(supported, "tensors used as indices must be long, int, byte, or bool tensors",
                          OPS_ERROR(ErrCode::TYPE));

        if (!first_dtype.has_value()) {
            first_dtype = dtype;
            continue;
        }
        if (first_dtype.value() != dtype) {
            mixed_dtypes = true;
        }
    }
    return mixed_dtypes;
}
// Prepares `self` and the user-supplied index list `orig` and wraps them in an
// AdvancedIndex.
//
// Steps, in order: validate index dtypes, expand the indices with
// at::native::expandTensors, broadcast them to a common shape, pad the list
// with undefined tensors up to self.dim(), transpose indexed dimensions to the
// front when at::native::hasContiguousSubspace reports false, and move every
// defined index to the device of `self`.
//
// Raises an index error with the offending shapes when broadcasting the
// indices throws.
AdvancedIndex AdvanceIndex::make_info(at::Tensor self, const torch::List<c10::optional<at::Tensor>> &orig)
{
    // Called for its validation side effect; the returned flag is not used here.
    AdvanceIndex::checkIndexTensorTypes(orig);

    auto indices = at::native::expandTensors(self, orig);
    try {
        indices = npu_expand_outplace(indices);
    } catch (const std::exception &) {
        TORCH_CHECK_INDEX(false,
                          "shape mismatch: indexing tensors could not be broadcast"
                          " together with shapes ",
                          shapes_as_str(indices),
                          OPS_ERROR(ErrCode::VALUE));
    }

    // Pad with undefined tensors so there is one entry per dimension of self.
    const size_t self_rank = static_cast<size_t>(self.dim());
    if (indices.size() < self_rank) {
        indices.resize(self_rank);
    }

    if (!at::native::hasContiguousSubspace(indices)) {
        std::tie(self, indices) = at::native::transposeToFront(self, indices);
    }

    for (at::Tensor &index : indices) {
        if (!index.defined()) {
            continue;
        }
        if (index.device() != self.device()) {
            index = index.to(self.device());
        }
    }
    return AdvancedIndex(self, indices);
}
// Converts the optional index list into a flat vector of index tensors.
//
// - An empty optional contributes one undefined tensor.
// - A byte/bool index is treated as a mask: its shape is checked against the
//   dimensions of `self` starting at the current result position, then it is
//   replaced by one index tensor per mask dimension, taken from the rows of
//   its nonzero result (aclnn or aclop variant chosen by `flag_aclnn`).
// - Any other index is appended as is.
// Defined indices living on a different device are moved to self.device()
// first. When `needCast` is true, every defined at::kInt entry of the result
// is converted to at::kLong.
std::vector<at::Tensor> AdvanceIndex::npu_expand_tensors(const at::Tensor &self,
                                                         const torch::List<c10::optional<at::Tensor>> &indices,
                                                         bool needCast,
                                                         bool flag_aclnn)
{
    std::vector<at::Tensor> expanded;
    for (c10::optional<at::Tensor> maybe_index : indices) {
        if (!maybe_index.has_value()) {
            expanded.emplace_back();
            continue;
        }

        at::Tensor index = std::move(*maybe_index);
        if (index.defined() && index.device() != self.device()) {
            index = index.to(self.device());
        }

        const bool is_byte_mask = index.scalar_type() == at::kByte;
        const bool is_bool_mask = index.scalar_type() == at::kBool;
        if (!is_byte_mask && !is_bool_mask) {
            expanded.emplace_back(std::move(index));
            continue;
        }

        if (is_byte_mask) {
            TORCH_WARN("indexing with dtype torch.uint8 is now deprecated,"
                       " please use a dtype torch.bool instead.");
        }

        // Each mask dimension must match the dimension of self it covers.
        const uint64_t mask_rank = static_cast<uint64_t>(index.dim());
        for (uint64_t j = 0; j < mask_rank; j++) {
            uint64_t srcIdx = expanded.size() + j;
            TORCH_CHECK_INDEX(index.size(j) == self.size(srcIdx), "The shape of the mask ", index.sizes(),
                              " at index ", j, " does not match the shape of the indexed tensor ", self.sizes(),
                              " at index ", srcIdx, OPS_ERROR(ErrCode::VALUE));
        }

        at::Tensor nonzero = flag_aclnn ? npu_nonzero_aclnn(index) : npu_nonzero_aclop(index);
        const int64_t coord_rows = index.dim();
        for (int64_t row = 0; row < coord_rows; row++) {
            expanded.emplace_back(nonzero.select(0, row));
        }
    }

    if (needCast) {
        for (at::Tensor &entry : expanded) {
            if (entry.defined() && entry.dtype() == at::kInt) {
                entry = entry.to(at::kLong);
            }
        }
    }
    return expanded;
}
// Broadcasts every defined tensor in `to_broadcast` to the shape obtained by
// folding at::infer_size over all defined entries. Tensors already at that
// shape are returned unchanged, the others go through op_plugin::npu_broadcast.
// Undefined entries stay undefined in the returned vector.
std::vector<at::Tensor> AdvanceIndex::npu_broadcast_tensors(std::vector<at::Tensor> to_broadcast)
{
    // Pass 1: derive the common broadcast shape.
    std::vector<int64_t> common_shape;
    bool have_shape = false;
    for (const at::Tensor &tensor : to_broadcast) {
        if (!tensor.defined()) {
            continue;
        }
        if (have_shape) {
            common_shape = at::infer_size(common_shape, tensor.sizes());
        } else {
            common_shape = tensor.sizes().vec();
            have_shape = true;
        }
    }

    // Pass 2: broadcast only the tensors whose shape differs.
    const uint64_t count = to_broadcast.size();
    std::vector<at::Tensor> broadcasted(count);
    for (uint64_t idx = 0; idx < count; ++idx) {
        const at::Tensor &tensor = to_broadcast[idx];
        if (!tensor.defined()) {
            continue;
        }
        if (tensor.sizes().equals(common_shape)) {
            broadcasted[idx] = tensor;
        } else {
            broadcasted[idx] = op_plugin::npu_broadcast(tensor, common_shape);
        }
    }
    return broadcasted;
}
// Returns true when `shape` can be expanded to `desired`: `shape` must not
// have more dimensions than `desired`, and, aligning both from the trailing
// dimension, each size in `shape` must be 1 or equal to the matching size in
// `desired`.
bool AdvanceIndex::is_expandable_to(c10::IntArrayRef shape, c10::IntArrayRef desired)
{
    const size_t src_rank = shape.size();
    const size_t dst_rank = desired.size();
    if (src_rank > dst_rank) {
        return false;
    }

    // Offset that aligns the trailing dimensions of both shapes.
    const size_t offset = dst_rank - src_rank;
    for (size_t pos = src_rank; pos-- > 0;) {
        const int64_t have = shape[pos];
        const int64_t want = desired[offset + pos];
        if (have != 1 && have != want) {
            return false;
        }
    }
    return true;
}
}