#include "torch_npu/csrc/framework/contiguous/ContiguousOpt.h"
#include "torch_npu/csrc/aten/CustomFunctions.h"
namespace at_npu {
namespace native {
class SliceContiguousOpt : public ContiguousOpt {
public:
bool Optimizer(at::Tensor &self, const at::Tensor &src,
const ContiguousTensorDesc &src_desc) override {
c10::SmallVector<int64_t, MAX_DIM> offsets;
c10::SmallVector<int64_t, MAX_DIM> size;
if (can_use_slice(src_desc, offsets, size)) {
RECORD_FUNCTION("contiguous_d_Slice", std::vector<c10::IValue>({src}));
slice_to_contiguous(self, src, offsets, size, src_desc);
return true;
}
return false;
}
bool CanOptimizer(const ContiguousTensorDesc &src_desc) override {
c10::SmallVector<int64_t, MAX_DIM> offsets;
c10::SmallVector<int64_t, MAX_DIM> size;
return can_use_slice(src_desc, offsets, size);
}
bool CachedOptimizer(at::Tensor &self, const at::Tensor &src,
const ContiguousTensorDesc &src_desc) override
{
if (src_desc.cached_contiguous) {
RECORD_FUNCTION("cached_contiguous_d_Slice", std::vector<c10::IValue>({src}));
CachedContiguousOpt cachedContiguousOpt = TransContiguous::cached_contiguous_opt[src_desc.hash_src_desc];
c10::SmallVector<int64_t, MAX_DIM> size = cachedContiguousOpt.cached_opt_parameters.pop_back_val();
c10::SmallVector<int64_t, MAX_DIM> offsets = cachedContiguousOpt.cached_opt_parameters.pop_back_val();
slice_to_contiguous(self, src, offsets, size, src_desc);
return true;
}
c10::SmallVector<int64_t, MAX_DIM> offsets;
c10::SmallVector<int64_t, MAX_DIM> size;
if (can_use_slice(src_desc, offsets, size)) {
RECORD_FUNCTION("contiguous_d_Slice", std::vector<c10::IValue>({src}));
CachedContiguousOpt cached_opt = CachedContiguousOpt{
"slice"
};
cached_opt.cached_opt_parameters.emplace_back(offsets);
cached_opt.cached_opt_parameters.emplace_back(size);
cached_opt.contiguous_tensor_desc = src_desc;
TransContiguous::cached_contiguous_opt[src_desc.hash_src_desc] = cached_opt;
slice_to_contiguous(self, src, offsets, size, src_desc);
return true;
}
return false;
}
private:
bool can_use_slice(const ContiguousTensorDesc &src_desc,
c10::SmallVector<int64_t, MAX_DIM> &offsets,
c10::SmallVector<int64_t, MAX_DIM> &size) {
const auto &base_sizes = src_desc.base_sizes_;
const auto &base_strides = src_desc.base_strides_;
auto view_sizes = src_desc.sizes_;
auto view_strides = src_desc.strides_;
if (view_strides[view_strides.size() - 1] != 1 &&
FormatHelper::IsBaseFormatType(src_desc.npu_format_) &&
view_strides.size() < base_strides.size() &&
c10::multiply_integers(view_sizes) <
c10::multiply_integers(base_sizes) / base_sizes[base_sizes.size() - 1]) {
view_sizes.emplace_back(1);
view_strides.emplace_back(1);
}
if (view_strides != base_strides) {
return false;
}
c10::SmallVector<int64_t, MAX_DIM> narrow_dims;
if (view_sizes.size() != base_sizes.size()) {
return false;
}
for (const auto i : c10::irange(view_sizes.size())) {
if (view_sizes[i] == base_sizes[i]) {
narrow_dims.emplace_back(0);
} else if (view_sizes[i] < base_sizes[i]) {
narrow_dims.emplace_back(1);
} else {
return false;
}
}
size = view_sizes;
offsets.clear();
int64_t storage_offsets = src_desc.offset_;
for (const auto i : c10::irange(view_strides.size())) {
offsets.emplace_back(storage_offsets / view_strides[i]);
storage_offsets = storage_offsets % view_strides[i];
}
if (storage_offsets != 0) {
return false;
}
for (const auto i : c10::irange(offsets.size())) {
if ((offsets[i] + view_sizes[i]) > base_sizes[i]) {
return false;
}
if (offsets[i] != 0 && narrow_dims[i] == 0) {
return false;
}
}
return true;
}
void slice_to_contiguous(at::Tensor &self, const at::Tensor &src,
const c10::SmallVector<int64_t, MAX_DIM> &offsets,
const c10::SmallVector<int64_t, MAX_DIM> &size,
const ContiguousTensorDesc &src_desc) {
const auto &temp_tensor_size = src_desc.base_sizes_;
at::Tensor temp_src = TransContiguous::view_tensor(src, src_desc.base_offset_, temp_tensor_size, src_desc.base_strides_);
custom_ops::npu_slice_out(temp_src, offsets, size, self);
return;
}
};
REGISTER_COPY_OPT(slice, SliceContiguousOpt)
}
}