/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
/*!
 * \file strided_slice_v3.cc
 * \brief Shape inference for the StridedSliceV3 operator.
 */
#include <numeric>
#include <cmath>
#include "register/op_impl_registry.h"
#include "../../strided_slice/op_host/strided_slice_util.h"
#include "log/log.h"
#include "util/const_util.h"
using namespace std;
namespace {
// Operator name passed as the first argument to the logging macros in this file.
const std::string OP_NAME = "StridedSliceV3";

// Input indices used when querying the infer-shape context.
constexpr int INDEX_X = 0;        // tensor being sliced
constexpr int INDEX_BEGIN = 1;    // per-axis begin positions
constexpr int INDEX_END = 2;      // per-axis end positions
constexpr int INDEX_AXES = 3;     // optional: axes the slice applies to
constexpr int INDEX_STRIDES = 4;  // optional: per-axis strides

// Output index used when querying the infer-shape context.
constexpr int INDEX_Y = 0;
}  // namespace
namespace ops {
/**
 * @brief Reads element `idx` of an index tensor and returns it widened to int64_t.
 *
 * A tensor whose data type is DT_INT32 is read as int32_t; every other data type is read as int64_t.
 * No bounds or null checks are performed on the tensor data.
 *
 * @param tensor Tensor holding the values; must not be null.
 * @param idx    Position of the element to read.
 * @return The element value as int64_t.
 */
static int64_t GetConstIndexValue(const gert::Tensor* tensor, size_t idx) {
    const bool holds_int32 = (tensor->GetDataType() == ge::DT_INT32);
    const int64_t element = holds_int32 ? static_cast<int64_t>(tensor->GetData<int32_t>()[idx])
                                        : tensor->GetData<int64_t>()[idx];
    OP_LOGD(OP_NAME.c_str(), "const tensor[%ld] is %ld.", idx, element);
    return element;
}
/**
 * @brief Reads element `idx` of an index tensor, offsets negative values, and clips the result.
 *
 * A tensor whose data type is DT_INT32 is read as int32_t; every other data type is read as int64_t.
 * A negative value has `input_size` added to it once. The result is then raised to `clip_lower` if it is
 * below it; otherwise it is lowered to `clip_upper` if it is above it. The lower bound is tested first, so
 * it wins whenever both tests would apply.
 *
 * @param tensor     Tensor holding the values; must not be null.
 * @param idx        Position of the element to read.
 * @param input_size Amount added to a negative value.
 * @param clip_lower Smallest value returned for inputs below it.
 * @param clip_upper Largest value returned for inputs above it (when not already below `clip_lower`).
 * @return The adjusted and clipped value.
 */
static int64_t GetConstIndexValue(const gert::Tensor* tensor, size_t idx, int64_t input_size, int64_t clip_lower,
                                  int64_t clip_upper) {
    const int64_t raw = (tensor->GetDataType() == ge::DT_INT32)
                            ? static_cast<int64_t>(tensor->GetData<int32_t>()[idx])
                            : tensor->GetData<int64_t>()[idx];

    // Negative positions are offset by the given size.
    const int64_t wrapped = (raw < 0) ? (raw + input_size) : raw;

    // Lower bound first, then upper bound; intentionally not std::clamp, which requires lower <= upper.
    const int64_t clipped =
        (wrapped < clip_lower) ? clip_lower : ((wrapped > clip_upper) ? clip_upper : wrapped);

    OP_LOGD(OP_NAME.c_str(), "const tensor[%ld] is %ld.", idx, clipped);
    return clipped;
}
/**
 * @brief Appends the in-range axes stored in `axis_tensor` to `new_axis` as non-negative indices.
 *
 * Each element is read as type T. Values in [0, input_dims) are appended unchanged; values in
 * [-input_dims, 0) are appended after adding `input_dims`. Values outside [-input_dims, input_dims)
 * are skipped without any diagnostic, so `new_axis` may receive fewer entries than the tensor holds.
 *
 * @tparam T         Element type used to read the tensor data.
 * @param input_dims Number of dimensions the axes refer to.
 * @param axis_tensor Tensor holding the axis values; must not be null.
 * @param new_axis   Output vector the normalized axes are appended to.
 */
template <typename T>
static void PositiveAxisImpl(int32_t input_dims, const gert::Tensor* axis_tensor, vector<int32_t>& new_axis) {
    const int64_t element_count = axis_tensor->GetShapeSize();
    const T* axes = axis_tensor->GetData<T>();
    for (int pos = 0; pos < element_count; ++pos) {
        const int64_t axis = static_cast<int64_t>(axes[pos]);
        const bool in_range = (axis >= -input_dims) && (axis < input_dims);
        if (!in_range) {
            continue;  // out-of-range axes are ignored
        }
        const int64_t normalized = (axis < 0) ? (axis + input_dims) : axis;
        new_axis.push_back(static_cast<int32_t>(normalized));
    }
}
/**
 * @brief Builds the list of axes to process.
 *
 * When `axis_tensor` is null or has a shape size of 0, the result is 0, 1, ..., input_dims - 1.
 * Otherwise the tensor is read as int32_t (DT_INT32) or int64_t (DT_INT64) and its in-range values are
 * normalized to non-negative indices via PositiveAxisImpl. For any other data type the result is empty.
 *
 * @param input_dims  Number of dimensions of the input.
 * @param axis_tensor Optional tensor holding the axes; may be null.
 * @return The non-negative axes, possibly empty.
 */
static std::vector<int32_t> ConstructValidAxis(int32_t input_dims, const gert::Tensor* axis_tensor) {
    std::vector<int32_t> axes;

    const bool axes_provided = (axis_tensor != nullptr) && (axis_tensor->GetShapeSize() != 0);
    if (!axes_provided) {
        // No explicit axes: use every dimension in order.
        axes.resize(input_dims);
        std::iota(axes.begin(), axes.end(), 0);
        return axes;
    }

    switch (axis_tensor->GetDataType()) {
        case ge::DT_INT32:
            PositiveAxisImpl<int32_t>(input_dims, axis_tensor, axes);
            break;
        case ge::DT_INT64:
            PositiveAxisImpl<int64_t>(input_dims, axis_tensor, axes);
            break;
        default:
            // Other data types leave the list empty.
            break;
    }
    return axes;
}
/**
 * @brief Returns ceil(numerator / denominator) computed with exact 64-bit integer arithmetic.
 *
 * Replaces a float-based ceil: float carries a 24-bit significand, so extents above 2^24 could be
 * rounded to a wrong size.
 *
 * @param numerator   Dividend.
 * @param denominator Divisor; must be non-zero.
 * @return The quotient rounded toward positive infinity.
 */
static int64_t StridedSliceV3CeilDiv(int64_t numerator, int64_t denominator) {
    const int64_t quotient = numerator / denominator;
    const int64_t remainder = numerator % denominator;
    // Integer division truncates toward zero. That already is the ceiling when the exact quotient is
    // negative; when it is positive (remainder and divisor share a sign) one must be added.
    if (remainder != 0 && ((remainder < 0) == (denominator < 0))) {
        return quotient + 1;
    }
    return quotient;
}

/**
 * @brief Infers the output shape of StridedSliceV3.
 *
 * The output shape starts as a copy of the input shape. For the i-th axis returned by
 * ConstructValidAxis, the i-th stride, begin and end entries are read; missing entries default to
 * stride = 1, begin = 0 and end = the axis size. Negative begin/end values are offset by the axis size
 * and clipped (begin to [0, size], or [0, size - 1] for a negative stride; end to [0, size], or
 * [-1, size] for a negative stride). The axis extent becomes max(0, ceil((end - begin) / stride)).
 *
 * @param context Infer-shape context providing input shapes/tensors and the output shape.
 * @return ge::GRAPH_SUCCESS on success; ge::GRAPH_FAILED when the context, a required input or the output
 *         is null, when no valid axis is available, or when a stride is 0.
 */
static ge::graphStatus StridedSliceV3InferShape(gert::InferShapeContext* context) {
    if (context == nullptr) {
        OP_LOGE(OP_NAME.c_str(), "context is null. Please check.");
        return ge::GRAPH_FAILED;
    }
    const gert::Shape* x_shape = context->GetInputShape(INDEX_X);
    gert::Shape* out_shape = context->GetOutputShape(INDEX_Y);
    const gert::Tensor* begin_tensor = context->GetInputTensor(INDEX_BEGIN);
    const gert::Tensor* end_tensor = context->GetInputTensor(INDEX_END);
    if (x_shape == nullptr || out_shape == nullptr || begin_tensor == nullptr || end_tensor == nullptr) {
        OP_LOGE(OP_NAME.c_str(), "input tensor or output tensor is null. Please check.");
        return ge::GRAPH_FAILED;
    }

    // Axes that are not sliced keep their input extent.
    *out_shape = *x_shape;

    int32_t input_dim_num = static_cast<int32_t>(x_shape->GetDimNum());
    std::vector<int32_t> new_axis = ConstructValidAxis(input_dim_num, context->GetOptionalInputTensor(INDEX_AXES));
    const gert::Tensor* strides_tensor = context->GetOptionalInputTensor(INDEX_STRIDES);
    const int32_t strides_size = (strides_tensor) ? static_cast<int32_t>(strides_tensor->GetShapeSize()) : 0;
    const int32_t begins_size = static_cast<int32_t>(begin_tensor->GetShapeSize());
    const int32_t ends_size = static_cast<int32_t>(end_tensor->GetShapeSize());
    const int32_t axis_size = static_cast<int32_t>(new_axis.size());
    if (axis_size == 0) {
        OP_LOGE(OP_NAME.c_str(), "axis_size is 0. Please check.");
        return ge::GRAPH_FAILED;
    }

    for (int32_t i = 0; i < axis_size; i++) {
        const int32_t axis_value = new_axis[i];

        int64_t step_value = 1;
        if (i < strides_size) {
            step_value = GetConstIndexValue(strides_tensor, i);
        }
        // A zero stride has no defined slice length and would divide by zero below.
        if (step_value == 0) {
            OP_LOGE(OP_NAME.c_str(), "strides[%d] is 0, which is invalid. Please check.", i);
            return ge::GRAPH_FAILED;
        }

        int64_t cur_axis_input_size = x_shape->GetDim(axis_value);

        int64_t begin_value = 0;
        if (i < begins_size) {
            // Walking backwards, the first element taken must be an existing index, hence size - 1.
            int64_t clip_upper = cur_axis_input_size;
            if (step_value < 0) {
                clip_upper -= 1;
            }
            begin_value = GetConstIndexValue(begin_tensor, i, cur_axis_input_size, 0, clip_upper);
        }

        int64_t end_value = cur_axis_input_size;
        if (i < ends_size) {
            // Walking backwards, -1 is the exclusive end that still includes index 0.
            int64_t clip_lower = 0;
            if (step_value < 0) {
                clip_lower = -1;
            }
            end_value = GetConstIndexValue(end_tensor, i, cur_axis_input_size, clip_lower, cur_axis_input_size);
        }

        // Exact integer ceiling; an empty or inverted range yields an extent of 0.
        int64_t cur_out_size = StridedSliceV3CeilDiv(end_value - begin_value, step_value);
        if (cur_out_size < 0) {
            cur_out_size = 0;
        }
        out_shape->SetDim(axis_value, cur_out_size);
    }
    OP_LOGD(OP_NAME.c_str(), "out_shape: %s", Ops::Base::ToString(*out_shape).c_str());
    return ge::GRAPH_SUCCESS;
}
// Registers StridedSliceV3InferShape as the infer-shape function of StridedSliceV3. Inputs 1-4 (begin, end,
// axes and strides per the INDEX_* constants in this file) are declared data-dependent; presumably this makes
// their values, not only their shapes, available during shape inference -- confirm against the registry docs.
IMPL_OP_INFERSHAPE(StridedSliceV3).InferShape(StridedSliceV3InferShape).InputsDataDependency({1, 2, 3, 4});
}