* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*
* The code snippet comes from Tensorflow project.
*
* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
* \file strided_slice.h
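 * \brief Shape inference helpers for the StridedSlice operator (sparse/dense spec expansion, processing shape and final output shape).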
*/
#ifndef OPS_BUILT_IN_OP_PROTO_RUNTIME_STRIDED_SLICE_H_
#define OPS_BUILT_IN_OP_PROTO_RUNTIME_STRIDED_SLICE_H_
#include <algorithm>
#include <array>
#include <cstdint>
#include <string>
#include "log/log.h"
namespace ops {
// Container type used for the begin/end/strides vectors; it is the same type as a shape.
using QuickVector = gert::Shape;
// Marker stored in final_shape_gather_indices for an axis inserted through new_axis_mask.
static constexpr int32_t kStridedSliceNewAxis = -2;
// Tag passed to every OP_LOG* call in this header.
static const std::string OP_NAME("StridedSlice");
/**
 * Inputs of StridedSlice shape inference.
 *
 * InferShape rewrites begin/end/strides/input_shape in place with the final per-dimension slice
 * parameters. Every scalar member carries a default so that a default-initialized instance never
 * passes indeterminate masks or flags into InferShape or to_string().
 */
struct StridedSliceParams {
    gert::Shape input_shape;         // input shape; replaced by the rebuilt per-dimension shape in InferShape
    QuickVector begin;               // begin value per slice-spec entry
    QuickVector end;                 // end value per slice-spec entry
    QuickVector strides;             // its dim count is the number of sparse slice-spec entries
    uint64_t begin_mask = 0;         // bit i: ignore begin of entry i
    uint64_t end_mask = 0;           // bit i: ignore end of entry i
    uint64_t ellipsis_mask = 0;      // at most one bit may be set (checked by InferShape)
    uint64_t new_axis_mask = 0;      // bit i: entry i inserts a size-1 axis
    uint64_t shrink_axis_mask = 0;   // bit i: entry i removes its dimension from the output
    bool begin_valid = false;        // begin values may be used to compute concrete sizes
    bool end_valid = false;          // end values may be used to compute concrete sizes
    bool stride_valid = false;       // only reported by to_string() within this header
    bool real_begin_valid = true;    // false: begin is not a constant value (see BuildProcessingData)
    bool real_end_valid = true;      // false: end is not a constant value (see BuildProcessingData)
    QuickVector dy_shape;            // not read by the code in this header

    /// Human-readable dump of all members, used for debug/info logging.
    std::string to_string() const
    {
        // Boolean flags are printed as 0/1 integers.
        const auto flag = [](bool value) { return std::to_string(static_cast<int32_t>(value)); };
        std::string result = "input_shape:" + Ops::Base::ToString(input_shape);
        result += " begin:" + Ops::Base::ToString(begin);
        result += " end:" + Ops::Base::ToString(end);
        result += " strides:" + Ops::Base::ToString(strides);
        result += " begin_mask:" + std::to_string(begin_mask);
        result += " end_mask:" + std::to_string(end_mask);
        result += " ellipsis_mask:" + std::to_string(ellipsis_mask);
        result += " new_axis_mask:" + std::to_string(new_axis_mask);
        result += " shrink_axis_mask:" + std::to_string(shrink_axis_mask);
        result += " begin_valid:" + flag(begin_valid);
        result += " end_valid:" + flag(end_valid);
        result += " stride_valid:" + flag(stride_valid);
        result += " real_begin_valid:" + flag(real_begin_valid);
        result += " real_end_valid:" + flag(real_end_valid);
        return result;
    }
};
// Per-dimension intermediate results produced by BuildProcessingData (one entry per input dimension
// in every member) and consumed by BuildFinalShape.
struct ProcessingData {
// Number of elements selected in each input dimension; -1 when it cannot be determined.
gert::Shape processing_shape;
// Begin index recorded per input dimension (canonicalized when the dimension size is known and
// begin/end are valid).
QuickVector processing_begin;
// End index recorded per input dimension (same canonicalization rule as processing_begin).
QuickVector processing_end;
// Stride recorded per input dimension; BuildProcessingData may record 1 instead of the real stride
// (e.g. for shrink axes of unknown size, or when begin/end are not both valid).
QuickVector processing_strides;
};
// Slice parameters of a single dimension, bundled by BuildProcessingData for BuildProcessingShape.
// Member order matters: the struct is brace-initialized positionally.
struct InputParamUnit {
// Begin index of the slice in this dimension.
int64_t begin;
// End index of the slice (exclusive: the interval length is end - begin).
int64_t end;
// Step between selected elements; BuildProcessingData rejects a zero stride before this is built.
int64_t stride;
// Size of the input dimension; BuildProcessingShape only relies on it when it is >= 0.
int64_t dim;
// True when shrink_axis_mask removes this dimension (its processing size is then 1).
bool shrink;
};
// Slice specification as supplied by the caller (one entry per begin/end/strides element), before
// ellipsis and new-axis expansion. Completed by BuildSparseSpec.
struct StridedSliceSparseSpec {
// Number of sparse entries: strides.GetDimNum(), plus one when BuildSparseSpec appends an implicit
// ellipsis because none was given.
int64_t dims;
// Count of new-axis entries located after the ellipsis entry.
int32_t num_add_axis_after_ellipsis;
// Copies of the caller's begin/end/strides values, indexed by sparse entry.
const QuickVector begin;
const QuickVector end;
const QuickVector strides;
// Masks indexed by sparse entry (bit i refers to entry i).
const uint64_t begin_mask;
const uint64_t end_mask;
// Not const: BuildSparseSpec sets an extra trailing bit here when the caller supplied no ellipsis.
uint64_t ellipsis_mask;
const uint64_t new_axis_mask;
const uint64_t shrink_axis_mask;
};
// Slice specification expanded to exactly one entry per input dimension. Filled by BuildDenseSpec.
struct StridedSliceDenseSpec {
// Input rank; BuildDenseSpec resizes begin/end/strides to this many entries.
const int64_t dims;
// Bit d set: begin/end of dense dimension d is masked (dimensions covered by the ellipsis are always
// masked on both sides).
uint64_t begin_mask;
uint64_t end_mask;
// Whether begin/end values may be used to compute concrete sizes (copied from StridedSliceParams).
bool begin_valid;
bool end_valid;
// Slice parameters per dense dimension.
QuickVector begin;
QuickVector end;
QuickVector strides;
// One entry per output-spec element: >= 0 is a dense dimension index, kStridedSliceNewAxis (-2) marks
// an inserted axis, and -1 marks a dimension removed by shrink_axis_mask.
gert::Shape final_shape_gather_indices;
// Bit d set: dense dimension d is removed from the output shape.
uint64_t shrink_axis_mask;
};
/**
 * Returns a 64-bit mask with only bit i set; used to test and set bits of the slice masks.
 *
 * @param i bit position.
 * @return 1 << i for 0 <= i < 64, and 0 for any other position. Shifting a 64-bit value by a negative
 *         count or by 64 or more is undefined behavior, so such positions map to "no bit" instead.
 */
static inline uint64_t bit1value(int i)
{
    constexpr int kMaskBitCount = 64;
    if (i < 0 || i >= kMaskBitCount) {
        return 0;
    }
    return static_cast<uint64_t>(1) << static_cast<uint64_t>(i);
}
/**
 * Tells whether a forward index falls outside the half-open range [lower, upper).
 *
 * @return true when fwd < lower or fwd >= upper.
 */
static bool FwdOutOfBound(int64_t fwd, int64_t lower, int64_t upper)
{
    const bool inside = (lower <= fwd) && (fwd < upper);
    return !inside;
}
/**
 * Step 1 of shape inference: completes the sparse spec from the caller's parameters.
 *
 * Sets sparse_spec.dims to the number of stride entries and counts the new-axis entries that follow
 * the ellipsis. When no ellipsis bit is present among those entries, an implicit ellipsis is appended
 * as one extra trailing entry.
 */
static void BuildSparseSpec(StridedSliceParams& params, StridedSliceSparseSpec& sparse_spec)
{
    const int64_t entry_count = static_cast<int64_t>(params.strides.GetDimNum());
    sparse_spec.dims = entry_count;
    bool has_ellipsis = false;
    int32_t idx = 0;
    while (idx < entry_count) {
        const uint64_t mask_bit = bit1value(idx);
        // Only new axes located after the ellipsis are counted.
        if (has_ellipsis && ((params.new_axis_mask & mask_bit) != 0)) {
            ++sparse_spec.num_add_axis_after_ellipsis;
        }
        has_ellipsis = has_ellipsis || ((params.ellipsis_mask & mask_bit) != 0);
        ++idx;
    }
    if (has_ellipsis) {
        return;
    }
    // No explicit ellipsis: behave as if one trailed the spec, covering all remaining dimensions.
    sparse_spec.ellipsis_mask |= bit1value(sparse_spec.dims);
    ++sparse_spec.dims;
}
static bool BuildDenseSpec(const StridedSliceSparseSpec& sparse, StridedSliceDenseSpec* dense)
{
constexpr int32_t kShrinkAxis = -1;
dense->begin.SetDimNum(dense->dims);
dense->end.SetDimNum(dense->dims);
dense->strides.SetDimNum(dense->dims);
dense->begin_mask = 0;
dense->end_mask = 0;
dense->shrink_axis_mask = 0;
int full_index = 0;
for (int i = 0; i < sparse.dims; i++) {
const uint64_t bit_i = bit1value(i);
if ((bit_i & sparse.ellipsis_mask) != 0) {
int32_t next_index =
std::min(dense->dims - (sparse.dims - i) + 1 + sparse.num_add_axis_after_ellipsis, dense->dims);
for (; full_index < next_index; full_index++) {
dense->begin[full_index] = dense->end[full_index] = 0;
dense->strides[full_index] = 1;
dense->begin_mask |= bit1value(full_index);
dense->end_mask |= bit1value(full_index);
dense->final_shape_gather_indices.AppendDim(full_index);
}
} else if ((bit_i & sparse.new_axis_mask) != 0) {
dense->final_shape_gather_indices.AppendDim(kStridedSliceNewAxis);
} else {
if (static_cast<size_t>(full_index) == dense->begin.GetDimNum()) {
OP_LOGE(
OP_NAME, "Index out of range using input dim %d; input has only %ld dims.", full_index,
dense->dims);
return false;
}
dense->begin[full_index] = sparse.begin[i];
dense->end[full_index] = sparse.end[i];
dense->strides[full_index] = sparse.strides[i];
if ((sparse.begin_mask & bit_i) != 0) {
dense->begin_mask |= bit1value(full_index);
}
if ((sparse.end_mask & bit_i) != 0) {
dense->end_mask |= bit1value(full_index);
}
if ((sparse.shrink_axis_mask & bit_i) != 0) {
dense->final_shape_gather_indices.AppendDim(kShrinkAxis);
dense->shrink_axis_mask |= bit1value(full_index);
} else {
dense->final_shape_gather_indices.AppendDim(full_index);
}
full_index++;
}
}
return true;
}
/**
 * Appends the number of elements one dimension selects to processing_shape (or -1 when unknown).
 *
 * The interval length is known when begin and end are both valid (end - begin), when the dimension is
 * shrunk (1), or when both bounds are masked and the dimension size is non-negative (+/- dim, signed
 * like the stride). A known interval yields ceil(|interval| / |stride|) when its sign matches the
 * stride, and 0 otherwise.
 */
static void BuildProcessingShape(
    StridedSliceDenseSpec& dense_spec, InputParamUnit& input_param_unit, const bool begin_and_end_masked,
    gert::Shape& processing_shape)
{
    const int64_t stride = input_param_unit.stride;
    int64_t interval_length = 0;
    bool known_interval = true;
    if (dense_spec.begin_valid && dense_spec.end_valid) {
        interval_length = input_param_unit.end - input_param_unit.begin;
    } else if (input_param_unit.shrink) {
        // Kept as size 1 here; the final shape drops the dimension.
        interval_length = 1;
    } else if (begin_and_end_masked && input_param_unit.dim >= 0) {
        interval_length = (stride < 0) ? -input_param_unit.dim : input_param_unit.dim;
    } else {
        known_interval = false;
    }

    if (!known_interval) {
        processing_shape.AppendDim(-1);
        return;
    }

    const bool degenerate = (interval_length == 0) || ((interval_length < 0) != (stride < 0));
    const int64_t size_i =
        degenerate ? 0 : interval_length / stride + ((interval_length % stride != 0) ? 1 : 0);
    processing_shape.AppendDim(size_i);
}
/**
 * Step 3 of shape inference: canonicalizes every dense dimension against the input shape.
 *
 * For each input dimension i this validates the stride and shrink constraints, and rewrites
 * params.begin[i] / params.end[i] in place into forward indices clamped to the dimension. It then
 * appends the per-dimension begin/end/stride/size to processing_data. A dimension of size -1 is passed
 * through mostly unchanged.
 *
 * @param dense_spec      dense spec from BuildDenseSpec; its masks are indexed by dense dimension.
 * @param params          params whose begin/end/strides already hold the dense values; begin/end are
 *                        updated in place.
 * @param processing_data receives exactly one entry per input dimension in each member.
 * @return false (after logging) on a zero stride, a non-positive stride on a shrink axis, an
 *         out-of-range shrink index, or an unmasked bound whose value is not constant.
 */
static bool BuildProcessingData(
    StridedSliceDenseSpec& dense_spec, StridedSliceParams& params, ProcessingData& processing_data)
{
    // Value canonical() yields for an unmasked bound whose real value is not constant.
    constexpr int64_t kUnknownBound = -3;
    // Placeholder begin for a shrink axis whose real begin value is not constant.
    constexpr int64_t kUnknownShrinkBegin = -2;

    const int dim_count = static_cast<int>(params.input_shape.GetDimNum());
    for (int i = 0; i < dim_count; ++i) {
        auto& begin_i = params.begin[i];
        auto& end_i = params.end[i];
        auto& stride_i = params.strides[i];
        const auto dim_i = params.input_shape.GetDim(i);
        if (stride_i == 0) {
            OP_LOGE(OP_NAME, "strides[%d] must be non-zero", i);
            return false;
        }
        const uint64_t bit_i = bit1value(i);
        const bool shrink_i = (dense_spec.shrink_axis_mask & bit_i) != 0;
        // masks[0]: begin masked, masks[1]: end masked (non-zero means masked).
        const std::array<uint64_t, 2> masks = {{dense_spec.begin_mask & bit_i, dense_spec.end_mask & bit_i}};
        if (dim_i == -1) {
            // Unknown dimension size: nothing can be clamped; a shrink axis still selects one element.
            processing_data.processing_shape.AppendDim(shrink_i ? 1 : -1);
            processing_data.processing_begin.AppendDim(begin_i);
            processing_data.processing_end.AppendDim(shrink_i ? (begin_i + 1) : end_i);
            processing_data.processing_strides.AppendDim(shrink_i ? 1 : stride_i);
            continue;
        }
        const std::array<bool, 2> real_valid = {{params.real_begin_valid, params.real_end_valid}};
        // Clamp range for forward indices: [0, dim] for positive strides, [-1, dim - 1] for negative ones.
        const std::array<int64_t, 2> valid_range = {{stride_i > 0 ? 0 : -1, stride_i > 0 ? dim_i : dim_i - 1}};
        // c == 0 canonicalizes a begin value, c == 1 an end value. The explicit return type keeps every
        // branch int64_t (a bare `-3L` literal only matches where int64_t is `long`).
        auto canonical = [stride_i, dim_i, masks, valid_range, real_valid](int64_t x, int c) -> int64_t {
            if (masks[c] != 0) {
                // A masked bound takes the extreme of the range; the ends swap for negative strides.
                return stride_i > 0 ? valid_range[c] : valid_range[static_cast<size_t>(1 - c)];
            }
            if (!real_valid[c]) {
                return kUnknownBound;
            }
            const int64_t x_fwd = x < 0 ? dim_i + x : x;
            return x_fwd < valid_range[0] ? valid_range[0] : std::min(x_fwd, valid_range[1]);
        };
        if (shrink_i && stride_i <= 0) {
            OP_LOGE(OP_NAME, "only stride 1 allowed on non-range indexing.");
            return false;
        }
        const bool begin_and_end_masked =
            ((dense_spec.begin_mask & bit_i) != 0) && ((dense_spec.end_mask & bit_i) != 0);
        if (dense_spec.begin_valid && dense_spec.end_valid) {
            if (shrink_i) {
                if (real_valid[0]) {
                    // Rebuild end as begin + 1: canonical() would turn foo[-1] into a degenerate interval.
                    const int64_t x_fwd = begin_i < 0 ? dim_i + begin_i : begin_i;
                    begin_i = x_fwd;
                    end_i = begin_i + 1;
                    if (FwdOutOfBound(x_fwd, 0, dim_i)) {
                        OP_LOGE(OP_NAME, "slice index %ld of dimension %d out of bounds.", begin_i, i);
                        return false;
                    }
                } else {
                    begin_i = kUnknownShrinkBegin;
                    end_i = begin_i + 1;
                }
            } else {
                begin_i = canonical(begin_i, 0);
                end_i = canonical(end_i, 1);
            }
            if ((!real_valid[0] || !real_valid[1]) && (begin_i == kUnknownBound || end_i == kUnknownBound)) {
                OP_LOGE(
                    OP_NAME, "begin_i:%ld end_i:%ld is invalid while unconst begin or end, shrink_i:%d masks:%lu %lu",
                    begin_i, end_i, shrink_i, masks[0], masks[1]);
                return false;
            }
            processing_data.processing_begin.AppendDim(begin_i);
            processing_data.processing_end.AppendDim(end_i);
            processing_data.processing_strides.AppendDim(stride_i);
        } else {
            processing_data.processing_begin.AppendDim(begin_i);
            processing_data.processing_end.AppendDim(end_i);
            processing_data.processing_strides.AppendDim(1);
        }
        InputParamUnit input_param_unit = {begin_i, end_i, stride_i, dim_i, shrink_i};
        BuildProcessingShape(dense_spec, input_param_unit, begin_and_end_masked, processing_data.processing_shape);
    }
    return true;
}
/**
 * Step 4 of shape inference: builds the output shape and the final per-dimension slice parameters.
 *
 * Walks dense_spec.final_shape_gather_indices in order:
 * - a dense index contributes its processing size to out_shape;
 * - a new axis contributes 1 to out_shape, and a synthetic [0, 1) slice only when the input is a scalar;
 * - a shrunk dimension contributes nothing to out_shape but keeps a one-element slice.
 * params.begin/end/strides/input_shape are rebuilt to match that walk.
 */
static void BuildFinalShape(
    ProcessingData& processing_data, StridedSliceDenseSpec& dense_spec, StridedSliceParams& params,
    gert::Shape* out_shape)
{
    params.begin.SetDimNum(0);
    params.end.SetDimNum(0);
    params.strides.SetDimNum(0);
    out_shape->SetDimNum(0);

    gert::Shape rebuilt_input;
    // Dense dimension a shrink marker refers to: the one right after the last dense dimension seen.
    int next_dense_index = 0;
    const size_t entry_count = dense_spec.final_shape_gather_indices.GetDimNum();
    for (size_t idx = 0; idx < entry_count; ++idx) {
        const auto gather_index = dense_spec.final_shape_gather_indices.GetDim(idx);

        if (gather_index == kStridedSliceNewAxis) {
            out_shape->AppendDim(1);
            if (!params.input_shape.IsScalar()) {
                continue;
            }
            rebuilt_input.AppendDim(1);
            params.begin.AppendDim(0);
            params.end.AppendDim(1);
            params.strides.AppendDim(1);
            continue;
        }

        if (gather_index < 0) {
            // Shrunk dimension: select exactly one element starting at its processing begin.
            const auto shrink_begin = processing_data.processing_begin[next_dense_index];
            rebuilt_input.AppendDim(params.input_shape.GetDim(next_dense_index));
            params.begin.AppendDim(shrink_begin);
            params.end.AppendDim(shrink_begin + 1);
            params.strides.AppendDim(1);
            ++next_dense_index;
            continue;
        }

        out_shape->AppendDim(processing_data.processing_shape[gather_index]);
        rebuilt_input.AppendDim(params.input_shape.GetDim(gather_index));
        params.begin.AppendDim(processing_data.processing_begin[gather_index]);
        params.end.AppendDim(processing_data.processing_end[gather_index]);
        params.strides.AppendDim(processing_data.processing_strides[gather_index]);
        next_dense_index = gather_index + 1;
    }
    params.input_shape = rebuilt_input;
}
/**
 * Infers the StridedSlice output shape. The algorithm follows TensorFlow's StridedSlice validation,
 * which the file header credits as its source.
 *
 * Pipeline: sparse spec (BuildSparseSpec), then dense spec (BuildDenseSpec), then per-dimension
 * processing data (BuildProcessingData), then final shape (BuildFinalShape).
 *
 * @param params    slice description; on success begin/end/strides/input_shape are replaced by the
 *                  final per-dimension slice parameters.
 * @param out_shape receives the inferred output shape (-1 for sizes that cannot be determined).
 * @return false (after logging) when the slice specification is invalid.
 */
inline bool InferShape(StridedSliceParams& params, gert::Shape* out_shape)
{
    OP_LOGD(OP_NAME, "input params:%s.", params.to_string().c_str());

    // At most one ellipsis bit may be set: x & (x - 1) clears the lowest set bit.
    const uint64_t ellipsis_bits = params.ellipsis_mask;
    const bool has_several_ellipses = (ellipsis_bits != 0) && ((ellipsis_bits & (ellipsis_bits - 1)) != 0);
    if (has_several_ellipses) {
        OP_LOGE(OP_NAME, "Multiple ellipses in slice spec not allowed.");
        return false;
    }

    // Step 1: sparse spec taken from the caller's description.
    StridedSliceSparseSpec sparse_spec = {
        0,
        0,
        params.begin,
        params.end,
        params.strides,
        params.begin_mask,
        params.end_mask,
        params.ellipsis_mask,
        params.new_axis_mask,
        params.shrink_axis_mask};
    BuildSparseSpec(params, sparse_spec);

    // Step 2: dense spec with one entry per input dimension.
    StridedSliceDenseSpec dense_spec = {
        static_cast<int64_t>(params.input_shape.GetDimNum()),
        0,
        0,
        params.begin_valid,
        params.end_valid,
        params.begin,
        params.end,
        params.strides,
        gert::Shape(),
        0};
    if (!BuildDenseSpec(sparse_spec, &dense_spec)) {
        return false;
    }
    OP_LOGD(
        OP_NAME, "DenseSpec: begin_mask:%lu end_mask:%lu begin:%s end:%s strides:%s indices:%s shrink_axis_mask:%lu",
        dense_spec.begin_mask, dense_spec.end_mask, Ops::Base::ToString(dense_spec.begin).c_str(),
        Ops::Base::ToString(dense_spec.end).c_str(), Ops::Base::ToString(dense_spec.strides).c_str(),
        Ops::Base::ToString(dense_spec.final_shape_gather_indices).c_str(), dense_spec.shrink_axis_mask);

    // Step 3: canonicalize against the input shape; params now carries the dense values.
    params.begin = dense_spec.begin;
    params.end = dense_spec.end;
    params.strides = dense_spec.strides;
    ProcessingData processing_data;
    if (!BuildProcessingData(dense_spec, params, processing_data)) {
        return false;
    }
    OP_LOGD(
        OP_NAME, "ProcessingData: shape:%s begin:%s end:%s strides:%s.",
        Ops::Base::ToString(processing_data.processing_shape).c_str(),
        Ops::Base::ToString(processing_data.processing_begin).c_str(),
        Ops::Base::ToString(processing_data.processing_end).c_str(),
        Ops::Base::ToString(processing_data.processing_strides).c_str());

    // Step 4: output shape plus final slice parameters.
    BuildFinalShape(processing_data, dense_spec, params, out_shape);
    OP_LOGI(
        OP_NAME, "after infershape params:%s, output_shape:%s.", params.to_string().c_str(),
        Ops::Base::ToString(*out_shape).c_str());
    return true;
}
}
#endif