#include "torch_npu/csrc/framework/contiguous/ContiguousOpt.h"
#include "torch_npu/csrc/aten/CustomFunctions.h"
namespace at_npu {
namespace native {
class IndexingContiguousOpt : public ContiguousOpt {
public:
bool Optimizer(at::Tensor &self, const at::Tensor &src,
const ContiguousTensorDesc &src_desc) override {
c10::SmallVector<int64_t, MAX_DIM> start;
c10::SmallVector<int64_t, MAX_DIM> end;
c10::SmallVector<int64_t, MAX_DIM> step;
if (can_use_indexing(src_desc, start, end, step)) {
RECORD_FUNCTION("contiguous_d_StridedSlice", std::vector<c10::IValue>({src}));
indexing_to_contiguous(self, src, start, end, step, src_desc);
return true;
}
return false;
}
private:
bool can_use_indexing(const ContiguousTensorDesc &src_desc,
c10::SmallVector<int64_t, MAX_DIM> &start,
c10::SmallVector<int64_t, MAX_DIM> &end,
c10::SmallVector<int64_t, MAX_DIM> &step) {
if (c10::multiply_integers(src_desc.sizes_) >=
c10::multiply_integers(src_desc.base_sizes_)) {
return false;
}
if (src_desc.sizes_.size() != src_desc.base_sizes_.size()) {
return false;
}
if (src_desc.strides_.size() != src_desc.base_strides_.size()) {
return false;
}
const auto &base_size = src_desc.base_sizes_;
const auto &base_stride = src_desc.base_strides_;
const auto &indexing_size = src_desc.sizes_;
const auto &indexing_stride = src_desc.strides_;
for (const auto i : c10::irange(indexing_size.size())) {
if ((base_stride[i] == 0) ||
(indexing_stride[i] < base_stride[i]) ||
((indexing_stride[i] % base_stride[i]) != 0)) {
return false;
}
}
for (const auto i : c10::irange(indexing_size.size())) {
step.emplace_back(indexing_stride[i] / base_stride[i]);
}
int64_t src_offset = src_desc.offset_;
for (const auto i : c10::irange(indexing_size.size())) {
start.emplace_back(src_offset / base_stride[i]);
src_offset = src_offset % base_stride[i];
}
for (const auto i : c10::irange(indexing_size.size())) {
int64_t calculate_end = start[i] + indexing_size[i] * step[i];
if (calculate_end - step[i] > src_desc.base_sizes_[i]) {
return false;
}
end.emplace_back(calculate_end);
}
if (c10::multiply_integers(step) == 1 || step[step.size() - 1] != 1) {
return false;
}
for (const auto i : c10::irange(step.size())) {
if (step[i] == 1 && indexing_size[i] != base_size[i]) {
return false;
}
}
for (const auto i : c10::irange(step.size() - 1)) {
if (step[i] != 1) {
if (indexing_size[i] == 1) {
return false;
}
if (step[i + 1] == 1 &&
(indexing_stride[i] !=
indexing_size[i + 1] * indexing_stride[i + 1] * step[i])) {
return false;
}
}
}
return true;
}
void indexing_to_contiguous(at::Tensor &self, const at::Tensor &src,
c10::SmallVector<int64_t, MAX_DIM> &start,
c10::SmallVector<int64_t, MAX_DIM> &end,
c10::SmallVector<int64_t, MAX_DIM> &step,
const ContiguousTensorDesc &src_desc) {
const auto &base_size = src_desc.base_sizes_;
at::Tensor temp_src = TransContiguous::view_tensor(src, src_desc.base_offset_, base_size, src_desc.base_strides_);
custom_ops::npu_indexing_out(temp_src, start, end, step, 0, 0, 0, 0, 0, self);
return;
}
};
// Registers IndexingContiguousOpt with the copy-optimization registry under the
// key `indexing`. NOTE(review): REGISTER_COPY_OPT is defined outside this file
// (presumably in ContiguousOpt.h); how and in what order registered optimizers
// are tried is not visible here.
REGISTER_COPY_OPT(indexing, IndexingContiguousOpt)
}
}