* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <numeric>
#include "common/util/mem_utils.h"
#include "common/util.h"
#include "common/checker.h"
#include "framework/common/framework_types_internal.h"
#include "graph/optimize/symbolic/symbol_compute_context.h"
#include "graph/optimize/symbolic/infer_symbolic_shape/symbolic_infer_util.h"
#include "graph/optimize/symbolic/symbolic_kernel_factory.h"
namespace ge {
namespace {
// Node I/O positions used by this kernel: input 0 is the tensor being sliced, inputs 1..3 carry begin/end/strides
// (StridedSlice only; StridedSliceD stores them as attributes), and the result is output 0.
constexpr size_t kXInputIndex = 0UL;
constexpr size_t kStartInputIndex = 1UL;
constexpr size_t kEndInputIndex = 2UL;
constexpr size_t kStridesInputIndex = 3UL;
constexpr size_t kOutputIndex = 0UL;
// Attribute positions for StridedSlice: begin/end/strides are node inputs, so the attributes are only the five masks.
// The enumerator order is the attribute position and must not change.
enum class StridedSliceAttrIndex {
kAttrBeginMaskIndex,
kAttrEndMaskIndex,
kAttrEllipsisMaskIndex,
kAttrNewAxisMaskIndex,
kAttrShrinkAxisMaskIndex,
kEnd
};
// Attribute positions for StridedSliceD: begin/end/strides are stored as the first three (list) attributes, followed
// by the same five masks. The enumerator order is the attribute position and must not change.
enum class StridedSliceDAttrIndex {
kAttrStartInputIndex,
kAttrEndInputIndex,
kAttrStridesInputIndex,
kAttrBeginMaskIndex,
kAttrEndMaskIndex,
kAttrEllipsisMaskIndex,
kAttrNewAxisMaskIndex,
kAttrShrinkAxisMaskIndex,
kEnd
};
// The five mask attributes of a StridedSlice/StridedSliceD node. Each is tested bit by bit, one bit per slice-spec
// position.
struct StridedSliceAttr {
int64_t begin_mask{0};
int64_t end_mask{0};
int64_t ellipsis_mask{0};
int64_t new_axis_mask{0};
int64_t shrink_axis_mask{0};
};
// Working state of the slice computation. start/end/strides initially hold the values read from the node (one per
// slice-spec entry) and are later rebuilt to one entry per dimension of the input shape with new axes inserted;
// is_new_axis has one flag per dimension of that shape telling whether it is an inserted size-1 axis.
// (The type name keeps its historical spelling.)
struct StrdedSliceIndexInputs {
std::vector<int64_t> start_indexes;
std::vector<int64_t> end_indexes;
std::vector<int64_t> strides_indexes;
std::vector<bool> is_new_axis;
};
// Reads the symbolic value of input `index` and appends each element to `dims` as a concrete int64 value.
// Returns UNSUPPORTED (with a warning) when the input has no symbolic value or when an element is not a known
// constant; a null symbol tensor is handled by GE_UNSUPPORTED_IF_NULL. Elements converted before a failing one stay
// appended to `dims`.
graphStatus GetValueFromInputConstData(const gert::InferSymbolComputeContext *context, const size_t index,
                                       std::vector<int64_t> &dims) {
  auto symbol_tensor = context->GetInputSymbolTensor(index);
  GE_UNSUPPORTED_IF_NULL(symbol_tensor);
  auto symbols = symbol_tensor->GetSymbolicValue();
  if (symbols == nullptr) {
    GELOGW("SymbolicKernel compute unsupported, reason: get %zu input symbolic value failed, node %s[%s].", index,
           context->GetNodeName(), context->GetNodeType());
    return UNSUPPORTED;
  }
  size_t i = 0UL;  // element position, reported in the warning below
  for (auto &symbol : *symbols) {
    int64_t dim = 0L;
    if (!symbol.GetConstValue(dim)) {
      GELOGW("SymbolicKernel compute unsupported, reason: get %zu input const value failed, node %s[%s].", i,
             context->GetNodeName(), context->GetNodeType());
      return UNSUPPORTED;
    }
    dims.emplace_back(dim);
    ++i;
  }
  return SUCCESS;
}
// Turns negative (count-from-the-end) entries of `input_indexes` into absolute positions, adding input_dims[i] to
// entry i. Fails when there are more indexes than dimensions or when an entry is still negative afterwards.
Status NormalizeInput(std::vector<int64_t> &input_indexes, const std::vector<int64_t> &input_dims) {
  GE_ASSERT_TRUE(input_indexes.size() <= input_dims.size(),
                 "input indexes size: %zu should not more than input shape size: %zu", input_indexes.size(),
                 input_dims.size());
  const size_t index_count = input_indexes.size();
  for (size_t i = 0UL; i != index_count; ++i) {
    int64_t &entry = input_indexes[i];
    if (entry < 0) {
      entry += input_dims[i];
    }
    GE_ASSERT_TRUE(input_indexes[i] >= 0, "input_indexes[%zu]=%lld", i, input_indexes[i]);
  }
  return SUCCESS;
}
// Resolves overlapping mask bits by precedence: an ellipsis bit clears the new_axis and shrink_axis bits at the same
// position, and a (remaining) new_axis bit clears the shrink_axis bit at the same position.
void HandleMaskConflict(StridedSliceAttr &strided_slice_attr) {
  const uint64_t ellipsis_bits = static_cast<uint64_t>(strided_slice_attr.ellipsis_mask);
  uint64_t new_axis_bits = static_cast<uint64_t>(strided_slice_attr.new_axis_mask);
  uint64_t shrink_bits = static_cast<uint64_t>(strided_slice_attr.shrink_axis_mask);
  new_axis_bits &= ~ellipsis_bits;
  shrink_bits &= ~ellipsis_bits;
  shrink_bits &= ~new_axis_bits;  // uses the new_axis bits already stripped of ellipsis positions
  strided_slice_attr.new_axis_mask = static_cast<int64_t>(new_axis_bits);
  strided_slice_attr.shrink_axis_mask = static_cast<int64_t>(shrink_bits);
  GELOGI("handle mask conflict, new_axis_mask: %lld, shrink_axis_mask: %lld", strided_slice_attr.new_axis_mask,
         strided_slice_attr.shrink_axis_mask);
}
// Returns the number of set bits in `num`; zero and negative values yield 0.
int64_t CountBitNum(const int64_t num) {
  if (num <= 0) {
    return 0L;
  }
  uint64_t remaining = static_cast<uint64_t>(num);
  int64_t set_bits = 0L;
  while (remaining != 0U) {
    remaining &= (remaining - 1U);  // clears the lowest set bit
    ++set_bits;
  }
  return set_bits;
}
// True when `pos` lies in the half-open range [first, second). The "no ellipsis" range (-1, -1) is empty, so it
// never matches.
bool IsInEllipsisMaskRange(const std::pair<int64_t, int64_t> &ellipsis_mask_range, const int64_t pos) {
  const int64_t range_begin = ellipsis_mask_range.first;
  const int64_t range_end = ellipsis_mask_range.second;
  if (pos < range_begin) {
    return false;
  }
  return pos < range_end;
}
// Builds the input shape with new axes inserted (`input_append_axis_shape`) and records, for every resulting
// dimension, whether it is an inserted axis (`index_input.is_new_axis`).
// The walk over `input_dims` tracks a slice-spec mask position. A set new_axis bit inserts a size-1 axis, unless the
// current output position lies inside the ellipsis range; in that case the next input dim is copied and the mask
// position is not advanced. New-axis bits (counted over the first start_indexes.size() positions) not consumed by
// the walk are appended as trailing size-1 axes. Always returns SUCCESS.
Status AppendNewAxis(const std::pair<int64_t, int64_t> &ellipsis_mask_range, const int64_t new_axis_mask,
                     const std::vector<int64_t> &input_dims, std::vector<int64_t> &input_append_axis_shape,
                     StrdedSliceIndexInputs &index_input) {
  const uint64_t new_axis_bits = static_cast<uint64_t>(new_axis_mask);
  // Test bits on the 64-bit value itself: the previous `1 << pos` was an int shift, undefined for pos >= 31.
  const auto is_new_axis_bit = [new_axis_bits](const int64_t pos) -> bool {
    return (pos >= 0L) && (pos < 64L) && (((new_axis_bits >> static_cast<uint64_t>(pos)) & 1U) != 0U);
  };
  const size_t begin_len = index_input.start_indexes.size();
  int64_t new_axis_num = 0L;
  for (size_t i = 0UL; i < begin_len; ++i) {
    if (is_new_axis_bit(static_cast<int64_t>(i))) {
      new_axis_num++;
    }
  }
  int64_t mask_pos = 0L;
  for (size_t i = 0UL; i < input_dims.size();) {
    if (is_new_axis_bit(mask_pos)) {
      if (IsInEllipsisMaskRange(ellipsis_mask_range, static_cast<int64_t>(input_append_axis_shape.size()))) {
        // Inside the ellipsis: keep copying input dims; mask_pos deliberately stays put.
        input_append_axis_shape.emplace_back(input_dims[i++]);
        index_input.is_new_axis.emplace_back(false);
      } else {
        new_axis_num--;
        input_append_axis_shape.emplace_back(1L);
        index_input.is_new_axis.emplace_back(true);
        mask_pos++;
      }
    } else {
      input_append_axis_shape.emplace_back(input_dims[i++]);
      index_input.is_new_axis.emplace_back(false);
      mask_pos++;
    }
  }
  // New axes requested after the last input dimension.
  while (new_axis_num-- > 0) {
    input_append_axis_shape.emplace_back(1L);
    index_input.is_new_axis.emplace_back(true);
  }
  GELOGI("Input shape after insert new axis: %s", SymbolicInferUtil::VectorToStr(input_append_axis_shape).c_str());
  return SUCCESS;
}
// Finds the ellipsis slot and returns the half-open range of dimensions it expands to. The slot is the lowest set
// ellipsis_mask bit among the first `slice_dim_num` slice positions. The range is
// [pos, pos + input_size + #new_axis_bits - slice_dim_num + 1).
// Returns (-1, -1) when no ellipsis bit is set within the slice spec.
std::pair<int64_t, int64_t> GetEllipsisMaskRange(const StridedSliceAttr &strided_slice_attr,
                                                 const int64_t slice_dim_num, const int64_t input_size) {
  const int64_t bit_count = CountBitNum(strided_slice_attr.new_axis_mask);
  const int64_t ellipsis_mask_num = input_size + bit_count - slice_dim_num + 1;
  const uint64_t ellipsis_bits = static_cast<uint64_t>(strided_slice_attr.ellipsis_mask);
  // Only 64 mask bits exist. Shifting the 64-bit mask (instead of the int literal 1, which is undefined for
  // pos >= 31) keeps every tested bit well defined.
  const int64_t searchable = std::min<int64_t>(slice_dim_num, 64L);
  for (int64_t pos = 0L; pos < searchable; pos++) {
    if (((ellipsis_bits >> static_cast<uint64_t>(pos)) & 1U) != 0U) {
      GELOGI("ellipsis_mask_range: [%lld, %lld)", pos, pos + ellipsis_mask_num);
      return std::make_pair(pos, pos + ellipsis_mask_num);
    }
  }
  return std::make_pair(-1, -1);
}
// Replaces the slice-spec entry at `ellipsis_mask_index` (the ellipsis slot; negative when there is no ellipsis)
// with the non-negative placeholder [0, input_dims[slot]) and stride 1. The begin/end/strides vectors must have the
// same length.
Status HandleEllipsisMask(const int64_t ellipsis_mask_index, const std::vector<int64_t> &input_dims,
                          StrdedSliceIndexInputs &index_input) {
  GE_ASSERT_TRUE(index_input.start_indexes.size() == index_input.end_indexes.size(),
                 "start_index size: %zu should equal to end_index size:%zu", index_input.start_indexes.size(),
                 index_input.end_indexes.size());
  GE_ASSERT_TRUE(index_input.start_indexes.size() == index_input.strides_indexes.size(),
                 "start_index size: %zu should equal to strides_index size:%zu", index_input.start_indexes.size(),
                 index_input.strides_indexes.size());
  const auto slice_num = static_cast<int64_t>(index_input.start_indexes.size());
  if ((ellipsis_mask_index >= 0) && (ellipsis_mask_index < slice_num)) {
    // The slot is a slice-spec position. It was previously used to index the shape unchecked, which could read out
    // of bounds when the slice spec is longer than the shape.
    GE_ASSERT_TRUE(ellipsis_mask_index < static_cast<int64_t>(input_dims.size()),
                   "ellipsis index: %lld should be less than input shape size: %zu", ellipsis_mask_index,
                   input_dims.size());
    const auto slot = static_cast<size_t>(ellipsis_mask_index);
    index_input.start_indexes[slot] = 0;
    index_input.end_indexes[slot] = input_dims[slot];
    index_input.strides_indexes[slot] = 1;
  }
  GELOGD("start index after insert handle ellipsis_mask: %s",
         SymbolicInferUtil::VectorToStr(index_input.start_indexes).c_str());
  GELOGD("end index after insert handle ellipsis_mask: %s",
         SymbolicInferUtil::VectorToStr(index_input.end_indexes).c_str());
  GELOGD("strides index after insert handle ellipsis_mask: %s",
         SymbolicInferUtil::VectorToStr(index_input.strides_indexes).c_str());
  return SUCCESS;
}
// Collects the dimensions (of the shape with new axes inserted) whose shrink_axis bit is set.
// The whole ellipsis range corresponds to a single slot of the slice spec. The mask bit position therefore advances
// once at the start of the ellipsis and not for the other dimensions it covers. This matches the slot-to-dimension
// mapping used by HandleBeginEndMask. Advancing once per dimension mis-mapped every shrink bit after a multi-dim
// ellipsis; for example, x[..., 0] on a rank-3 input never shrank its last dimension. Dimensions inside the ellipsis
// are never shrunk.
void GetShrinkAxisIndex(const int64_t shrink_axis_mask, const std::pair<int64_t, int64_t> &ellipsis_mask_range,
                        const int64_t index_size, std::set<int64_t> &shrink_axis_indexes) {
  const uint64_t shrink_bits = static_cast<uint64_t>(shrink_axis_mask);
  int64_t bit_pos = 0L;
  for (int64_t i = 0L; i < index_size; i++) {
    if (IsInEllipsisMaskRange(ellipsis_mask_range, i)) {
      if (i == ellipsis_mask_range.first) {
        bit_pos++;  // the ellipsis itself occupies exactly one slot
      }
      continue;
    }
    // Shift the 64-bit mask: `1 << bit_pos` was an int shift, undefined for positions >= 31.
    if ((bit_pos < 64L) && (((shrink_bits >> static_cast<uint64_t>(bit_pos)) & 1U) != 0U)) {
      shrink_axis_indexes.insert(i);
    }
    bit_pos++;
  }
}
// Collapses each listed dimension to the single element at its start index (end = start + 1, stride = 1).
// Fails when a dimension id is outside [0, start_indexes.size()).
Status HandleShrinkAxisShape(const std::set<int64_t> &shrink_axis_indexes, StrdedSliceIndexInputs &index_input) {
  for (const int64_t shrink_axis_id : shrink_axis_indexes) {
    GE_ASSERT_TRUE((shrink_axis_id < static_cast<int64_t>(index_input.start_indexes.size())) && (shrink_axis_id >= 0));
    const int64_t start_value = index_input.start_indexes[shrink_axis_id];
    int64_t &end_value = index_input.end_indexes[shrink_axis_id];
    GELOGI("Change strideslice index to [%lld, %lld, 1] of dim[%lld]", start_value, end_value, shrink_axis_id);
    end_value = start_value + 1;
    index_input.strides_indexes[shrink_axis_id] = 1;
  }
  return SUCCESS;
}
// Rebuilds begin/end/strides so that there is exactly one entry per dimension of `input_dims` (the input shape with
// new axes inserted):
//  - dimensions inside the ellipsis range are taken in full ([0, dim), stride 1);
//  - every other dimension takes the next slice-spec entry; the ellipsis consumes exactly one slot;
//  - dimensions beyond the end of the slice spec are taken in full.
// Negative indexes are then normalized against the dimension they actually address. They used to be normalized
// against input_dims[slot], which is a different dimension once an ellipsis spans more than one dimension; for
// example, x[..., -1] on (2, 3, 4) resolved -1 against 3 instead of 4. Finally, starts are clamped to dim - 1 and ends
// to dim outside the ellipsis.
Status FillMissionIndex(const std::pair<int64_t, int64_t> &ellipsis_mask_range, const std::vector<int64_t> &input_dims,
                        StrdedSliceIndexInputs &index_input) {
  const std::vector<int64_t> origin_start_indexes = index_input.start_indexes;
  const std::vector<int64_t> origin_end_indexes = index_input.end_indexes;
  const std::vector<int64_t> origin_strides_indexes = index_input.strides_indexes;
  GELOGD("origin_start_indexes before insert fill missing: %s",
         SymbolicInferUtil::VectorToStr(origin_start_indexes).c_str());
  GELOGD("origin_end_indexes before insert fill missing: %s",
         SymbolicInferUtil::VectorToStr(origin_end_indexes).c_str());
  GELOGD("origin_strides_indexes before insert fill missing: %s",
         SymbolicInferUtil::VectorToStr(origin_strides_indexes).c_str());
  const size_t slice_num = origin_start_indexes.size();
  GE_ASSERT_TRUE(slice_num <= input_dims.size(),
                 "input indexes size: %zu should not more than input shape size: %zu", slice_num, input_dims.size());
  GE_ASSERT_TRUE((origin_end_indexes.size() == slice_num) && (origin_strides_indexes.size() == slice_num),
                 "start index size: %zu, end index size: %zu and strides index size: %zu should be equal", slice_num,
                 origin_end_indexes.size(), origin_strides_indexes.size());

  std::vector<int64_t> start_indexes;
  std::vector<int64_t> end_indexes;
  std::vector<int64_t> strides_indexes;
  start_indexes.reserve(input_dims.size());
  end_indexes.reserve(input_dims.size());
  strides_indexes.reserve(input_dims.size());
  size_t slice_pos = 0UL;
  for (size_t i = 0UL; i < input_dims.size(); i++) {
    if (IsInEllipsisMaskRange(ellipsis_mask_range, static_cast<int64_t>(i))) {
      start_indexes.emplace_back(0L);
      end_indexes.emplace_back(input_dims[i]);
      strides_indexes.emplace_back(1L);
      if (static_cast<int64_t>(i) == ellipsis_mask_range.first) {
        slice_pos++;  // skip the ellipsis slot of the slice spec
      }
      continue;
    }
    if (slice_pos < slice_num) {
      start_indexes.emplace_back(origin_start_indexes[slice_pos]);
      end_indexes.emplace_back(origin_end_indexes[slice_pos]);
      strides_indexes.emplace_back(origin_strides_indexes[slice_pos]);
    } else {
      // Trailing dimensions not mentioned by the slice spec are taken in full.
      start_indexes.emplace_back(0L);
      end_indexes.emplace_back(input_dims[i]);
      strides_indexes.emplace_back(1L);
    }
    slice_pos++;
  }
  // Entry i now refers to input dimension i, so normalizing against input_dims[i] is correct.
  GE_ASSERT_SUCCESS(NormalizeInput(start_indexes, input_dims));
  GE_ASSERT_SUCCESS(NormalizeInput(end_indexes, input_dims));
  for (size_t i = 0UL; i < input_dims.size(); i++) {
    if (IsInEllipsisMaskRange(ellipsis_mask_range, static_cast<int64_t>(i))) {
      continue;
    }
    start_indexes[i] = std::min(start_indexes[i], input_dims[i] - 1);
    end_indexes[i] = std::min(end_indexes[i], input_dims[i]);
  }
  index_input.start_indexes = std::move(start_indexes);
  index_input.end_indexes = std::move(end_indexes);
  index_input.strides_indexes = std::move(strides_indexes);
  GELOGD("start index after insert fill missing: %s",
         SymbolicInferUtil::VectorToStr(index_input.start_indexes).c_str());
  GELOGD("end index after insert handle fill missing: %s",
         SymbolicInferUtil::VectorToStr(index_input.end_indexes).c_str());
  GELOGD("strides index after insert handle fill missing: %s",
         SymbolicInferUtil::VectorToStr(index_input.strides_indexes).c_str());
  return SUCCESS;
}
// Applies begin_mask / end_mask to the per-dimension indexes. A masked begin/end is replaced by the fullest range in
// the direction of the stride. The mask bit position advances once per non-ellipsis dimension and once for the
// whole ellipsis range (its single slice-spec slot).
//  - begin: 0 for positive stride, dim - 1 for negative stride;
//  - end:   dim for positive stride, -1 for negative stride. -1 is one step before index 0, so a backward walk
//           includes element 0. The previous value 0 dropped it, e.g. x[::-1] lost the first element.
void HandleBeginEndMask(const StridedSliceAttr &strided_slice_attr, const std::vector<int64_t> &input_dims,
                        const std::pair<int64_t, int64_t> &ellipsis_mask_range, StrdedSliceIndexInputs &index_input) {
  const uint64_t begin_bits = static_cast<uint64_t>(strided_slice_attr.begin_mask);
  const uint64_t end_bits = static_cast<uint64_t>(strided_slice_attr.end_mask);
  // Shift the 64-bit mask: `1 << pos` was an int shift, undefined for positions >= 31.
  const auto is_bit_set = [](const uint64_t bits, const int64_t pos) -> bool {
    return (pos >= 0L) && (pos < 64L) && (((bits >> static_cast<uint64_t>(pos)) & 1U) != 0U);
  };
  int64_t mask_pos = 0L;
  for (size_t i = 0UL; i < index_input.start_indexes.size(); i++) {
    if (IsInEllipsisMaskRange(ellipsis_mask_range, static_cast<int64_t>(i))) {
      if (static_cast<int64_t>(i) == ellipsis_mask_range.first) {
        mask_pos++;
      }
      continue;
    }
    const bool forward = index_input.strides_indexes[i] > 0;
    if (is_bit_set(begin_bits, mask_pos)) {
      index_input.start_indexes[i] = forward ? 0 : input_dims[i] - 1;
    }
    if (is_bit_set(end_bits, mask_pos)) {
      index_input.end_indexes[i] = forward ? input_dims[i] : -1;
    }
    mask_pos++;
  }
  GELOGI("start index after insert handle begin end mask: %s",
         SymbolicInferUtil::VectorToStr(index_input.start_indexes).c_str());
  GELOGI("end index after insert handle begin end mask: %s",
         SymbolicInferUtil::VectorToStr(index_input.end_indexes).c_str());
  GELOGI("strides index after insert handle begin end mask: %s",
         SymbolicInferUtil::VectorToStr(index_input.strides_indexes).c_str());
}
// ceil(numerator / denominator) for a non-zero denominator, computed exactly in integers. C++ division truncates
// toward zero, so the quotient is bumped by one when the exact result is positive and the division is inexact.
int64_t StridedSliceCeilDiv(const int64_t numerator, const int64_t denominator) {
  const int64_t quotient = numerator / denominator;
  const int64_t remainder = numerator % denominator;
  return ((remainder != 0) && ((remainder > 0) == (denominator > 0))) ? quotient + 1 : quotient;
}
// Computes the output shape. Shrunk dimensions are first collapsed in `index_input` (end = start + 1, stride = 1)
// and then omitted from the shape. Inserted new axes contribute 1. Every other dimension contributes
// max(0, ceil((end - start) / stride)). The division is exact in integers; it previously went through float, whose
// 24-bit mantissa gives wrong sizes for large dimensions.
Status CalcOutputShape(const int64_t shrink_axis_mask, const std::pair<int64_t, int64_t> &ellipsis_mask_range,
                       StrdedSliceIndexInputs &index_input, std::vector<Expression> &output_symbols_shape) {
  std::set<int64_t> shrink_axis_indexes;
  GetShrinkAxisIndex(shrink_axis_mask, ellipsis_mask_range, static_cast<int64_t>(index_input.start_indexes.size()),
                     shrink_axis_indexes);
  GE_ASSERT_SUCCESS(HandleShrinkAxisShape(shrink_axis_indexes, index_input));
  const size_t dim_num = index_input.start_indexes.size();
  GE_ASSERT_TRUE((index_input.end_indexes.size() == dim_num) && (index_input.strides_indexes.size() == dim_num) &&
                     (index_input.is_new_axis.size() >= dim_num),
                 "start index size: %zu, end index size: %zu, strides index size: %zu, new axis flag size: %zu mismatch",
                 dim_num, index_input.end_indexes.size(), index_input.strides_indexes.size(),
                 index_input.is_new_axis.size());
  for (size_t i = 0UL; i < dim_num; i++) {
    if (shrink_axis_indexes.count(static_cast<int64_t>(i)) > 0UL) {
      continue;
    }
    GE_ASSERT_TRUE(index_input.strides_indexes[i] != 0L, "index_input.strides_indexes[%zu]=%lld", i,
                   index_input.strides_indexes[i]);
    const int64_t result_dim = std::max<int64_t>(
        0L, StridedSliceCeilDiv(index_input.end_indexes[i] - index_input.start_indexes[i],
                                index_input.strides_indexes[i]));
    const int64_t output_dim = index_input.is_new_axis[i] ? 1L : result_dim;
    output_symbols_shape.emplace_back(Symbol(output_dim));
  }
  return SUCCESS;
}
// True once `current_index` has reached or passed `end_index` in the direction of the stride. A zero stride never
// terminates by this test (returns false).
bool IsNeedEndStridedSlice(const int64_t strides_index, const int64_t end_index, const int64_t current_index) {
  if (strides_index > 0) {
    return current_index >= end_index;
  }
  if (strides_index < 0) {
    return current_index <= end_index;
  }
  return false;
}
// Gathers the output symbolic values by slicing one dimension at a time over the row-major element list.
// For dimension i, the current values are viewed as block_num x input_dims[i] x block_size, where block_size is the
// product of the trailing dims. The blocks at start, start + stride, ... (until `end` is reached) are kept.
// An empty slice or a zero-sized dimension yields an empty output. With no dimensions, the output is a copy of the
// input.
// Fixes compared with the previous version:
//  - block_size accumulates in int64_t; the int literal initial value truncated large products;
//  - zero-sized dimensions no longer divide by zero;
//  - the emptiness test compares signs instead of multiplying, so it cannot overflow;
//  - the index/shape sizes are validated.
Status StridedSliceOutputSymbolsValue(const std::vector<Expression> &input_x_symbols,
                                      const std::vector<int64_t> &input_dims, const StrdedSliceIndexInputs &index_input,
                                      std::vector<Expression> &output_symbols_value) {
  const size_t dim_num = index_input.start_indexes.size();
  GE_ASSERT_TRUE((dim_num <= input_dims.size()) && (index_input.end_indexes.size() >= dim_num) &&
                     (index_input.strides_indexes.size() >= dim_num),
                 "start index size: %zu, input shape size: %zu, end index size: %zu, strides index size: %zu mismatch",
                 dim_num, input_dims.size(), index_input.end_indexes.size(), index_input.strides_indexes.size());
  output_symbols_value = input_x_symbols;
  for (size_t i = 0UL; i < dim_num; i++) {
    // The previous round's result becomes this round's source; the output starts empty.
    std::vector<Expression> last_output_symbols;
    last_output_symbols.swap(output_symbols_value);
    const int64_t start_index = index_input.start_indexes[i];
    const int64_t end_index = index_input.end_indexes[i];
    const int64_t stride = index_input.strides_indexes[i];
    const int64_t span = end_index - start_index;
    if ((span == 0L) || (stride == 0L) || ((span > 0L) != (stride > 0L))) {
      GELOGW("value will become empty if end indexes: %lld less than start_indexes: %lld", end_index, start_index);
      return SUCCESS;
    }
    const int64_t block_size =
        std::accumulate(input_dims.begin() + i + 1, input_dims.end(), int64_t{1}, std::multiplies<int64_t>());
    if ((block_size == 0L) || (input_dims[i] == 0L)) {
      GELOGW("value will become empty because input shape has a zero dim, input_dims[%zu]: %lld, block size: %lld", i,
             input_dims[i], block_size);
      return SUCCESS;
    }
    GE_ASSERT_TRUE((block_size > 0L) && (input_dims[i] > 0L), "input_dims[%zu]=%lld, block size=%lld", i,
                   input_dims[i], block_size);
    const int64_t block_num = static_cast<int64_t>(last_output_symbols.size()) / block_size / input_dims[i];
    GELOGI("block num: %lld, input_dims: %lld block size[%zu] : %lld", block_num, input_dims[i], i, block_size);
    for (int64_t j = 0L; j < block_num; j++) {
      int64_t index = start_index;
      while (!IsNeedEndStridedSlice(stride, end_index, index)) {
        GE_ASSERT_TRUE((index < input_dims[i]) && (index >= 0L), "index=%lld, input_dims[%zu]=%lld", index, i,
                       input_dims[i]);
        const auto begin_iter = last_output_symbols.begin() + (j * input_dims[i] + index) * block_size;
        output_symbols_value.insert(output_symbols_value.end(), begin_iter, begin_iter + block_size);
        index += stride;
      }
    }
  }
  return SUCCESS;
}
// StridedSlice variant: reads begin, end and strides (in that order) from the constant symbolic values of inputs
// kStartInputIndex, kEndInputIndex and kStridesInputIndex. The first non-SUCCESS status is returned unchanged.
Status GetStridedSliceIndexInput(gert::InferSymbolComputeContext *context, StrdedSliceIndexInputs &index_input) {
  const std::pair<size_t, std::vector<int64_t> *> sources[] = {
      {kStartInputIndex, &index_input.start_indexes},
      {kEndInputIndex, &index_input.end_indexes},
      {kStridesInputIndex, &index_input.strides_indexes},
  };
  for (const auto &source : sources) {
    const auto status = GetValueFromInputConstData(context, source.first, *source.second);
    if (status != SUCCESS) {
      return status;
    }
  }
  return SUCCESS;
}
// StridedSliceD variant: begin, end and strides are the first three list attributes of the node. Each list is
// appended to the matching vector of index_input. Fails when the attributes or any of the three lists are missing.
Status GetStridedSliceDIndexInput(const gert::InferSymbolComputeContext *context, StrdedSliceIndexInputs &index_input) {
  auto attrs = context->GetAttrs();
  GE_ASSERT_NOTNULL(attrs);
  auto attr_start_input = attrs->GetListInt(static_cast<size_t>(StridedSliceDAttrIndex::kAttrStartInputIndex));
  GE_ASSERT_NOTNULL(attr_start_input);
  auto attr_end_input = attrs->GetListInt(static_cast<size_t>(StridedSliceDAttrIndex::kAttrEndInputIndex));
  GE_ASSERT_NOTNULL(attr_end_input);
  auto attr_strides_input = attrs->GetListInt(static_cast<size_t>(StridedSliceDAttrIndex::kAttrStridesInputIndex));
  GE_ASSERT_NOTNULL(attr_strides_input);
  index_input.start_indexes.insert(index_input.start_indexes.end(), attr_start_input->GetData(),
                                   attr_start_input->GetData() + attr_start_input->GetSize());
  index_input.end_indexes.insert(index_input.end_indexes.end(), attr_end_input->GetData(),
                                 attr_end_input->GetData() + attr_end_input->GetSize());
  index_input.strides_indexes.insert(index_input.strides_indexes.end(), attr_strides_input->GetData(),
                                     attr_strides_input->GetData() + attr_strides_input->GetSize());
  return SUCCESS;
}
// Reads the five mask attributes into `strided_slice_attr`. Their attribute positions depend on the node type:
// StridedSlice uses StridedSliceAttrIndex; any other type (StridedSliceD) uses StridedSliceDAttrIndex, where the masks
// follow the three begin/end/strides list attributes. Fails when the attributes or any mask are missing.
Status GetStridedSliceMaskAttr(const gert::InferSymbolComputeContext *context, StridedSliceAttr &strided_slice_attr) {
  auto attrs = context->GetAttrs();
  GE_ASSERT_NOTNULL(attrs);
  const bool is_strided_slice = (strcmp(context->GetNodeType(), STRIDEDSLICE) == 0);
  const auto mask_attr_index = [is_strided_slice](const StridedSliceAttrIndex strided_slice_index,
                                                  const StridedSliceDAttrIndex strided_slice_d_index) -> size_t {
    return is_strided_slice ? static_cast<size_t>(strided_slice_index) : static_cast<size_t>(strided_slice_d_index);
  };
  auto begin_mask_ptr = attrs->GetInt(
      mask_attr_index(StridedSliceAttrIndex::kAttrBeginMaskIndex, StridedSliceDAttrIndex::kAttrBeginMaskIndex));
  auto end_mask_ptr = attrs->GetInt(
      mask_attr_index(StridedSliceAttrIndex::kAttrEndMaskIndex, StridedSliceDAttrIndex::kAttrEndMaskIndex));
  auto ellipsis_mask_ptr = attrs->GetInt(
      mask_attr_index(StridedSliceAttrIndex::kAttrEllipsisMaskIndex, StridedSliceDAttrIndex::kAttrEllipsisMaskIndex));
  auto new_axis_mask_ptr = attrs->GetInt(
      mask_attr_index(StridedSliceAttrIndex::kAttrNewAxisMaskIndex, StridedSliceDAttrIndex::kAttrNewAxisMaskIndex));
  auto shrink_axis_mask_ptr = attrs->GetInt(mask_attr_index(StridedSliceAttrIndex::kAttrShrinkAxisMaskIndex,
                                                            StridedSliceDAttrIndex::kAttrShrinkAxisMaskIndex));
  GE_ASSERT_NOTNULL(begin_mask_ptr);
  GE_ASSERT_NOTNULL(end_mask_ptr);
  GE_ASSERT_NOTNULL(ellipsis_mask_ptr);
  GE_ASSERT_NOTNULL(new_axis_mask_ptr);
  GE_ASSERT_NOTNULL(shrink_axis_mask_ptr);
  strided_slice_attr.begin_mask = *begin_mask_ptr;
  strided_slice_attr.end_mask = *end_mask_ptr;
  strided_slice_attr.ellipsis_mask = *ellipsis_mask_ptr;
  strided_slice_attr.new_axis_mask = *new_axis_mask_ptr;
  strided_slice_attr.shrink_axis_mask = *shrink_axis_mask_ptr;
  return SUCCESS;
}
// Turns the raw slice spec in `index_input` into one [start, end, stride) triple per dimension of
// `input_append_axis_shape` (the input shape with new axes inserted). The steps are:
//  1. overwrite the ellipsis slot with a placeholder full range;
//  2. map slice-spec entries onto dimensions, pad missing ones, normalize negative indexes and clamp;
//  3. apply begin_mask / end_mask.
// The order matters: masks are applied last, so they override the normalized/clamped values.
Status HandleMaskAttr(const std::pair<int64_t, int64_t> &ellipsis_mask_range,
const std::vector<int64_t> &input_append_axis_shape, const StridedSliceAttr &strided_slice_attr,
StrdedSliceIndexInputs &index_input) {
GE_ASSERT_SUCCESS(HandleEllipsisMask(ellipsis_mask_range.first, input_append_axis_shape, index_input));
GE_ASSERT_SUCCESS(FillMissionIndex(ellipsis_mask_range, input_append_axis_shape, index_input));
HandleBeginEndMask(strided_slice_attr, input_append_axis_shape, ellipsis_mask_range, index_input);
return SUCCESS;
}
}
// Symbolic kernel shared by StridedSlice (begin/end/strides read from inputs 1..3) and StridedSliceD (read from
// list attributes). It needs the constant dims and the symbolic value of input x. It fills the output's origin
// symbol shape and its element-wise symbolic value. Returns UNSUPPORTED when a required value is not statically
// available, and PARAM_INVALID for any other node type.
static graphStatus StridedSliceSymbolicKernelCompute(gert::InferSymbolComputeContext *context) {
  GE_ASSERT_NOTNULL(context);
  GELOGD("StridedSlice Symbolic Kernel in, node %s[%s].", context->GetNodeName(), context->GetNodeType());
  StrdedSliceIndexInputs index_input;
  Status ret = PARAM_INVALID;
  if (strcmp(context->GetNodeType(), STRIDEDSLICE) == 0) {
    ret = GetStridedSliceIndexInput(context, index_input);
  } else if (strcmp(context->GetNodeType(), STRIDEDSLICED) == 0) {
    ret = GetStridedSliceDIndexInput(context, index_input);
  } else {
    GELOGW("Node type: %s is not StridedSlice or StridedSliceD.", context->GetNodeType());
  }
  if (ret != SUCCESS) {
    return ret;
  }
  std::vector<int64_t> input_x_dims;
  if (!context->GetConstInputDims(kXInputIndex, input_x_dims)) {
    return UNSUPPORTED;
  }
  // The symbol tensor itself may be missing; it used to be dereferenced unchecked.
  auto input_x_symbol_tensor = context->GetInputSymbolTensor(kXInputIndex);
  if (input_x_symbol_tensor == nullptr) {
    GELOGW("SymbolicKernel compute unsupported, reason: get input symbol tensor failed, node %s[%s].",
           context->GetNodeName(), context->GetNodeType());
    return UNSUPPORTED;
  }
  auto input_x_symbols = input_x_symbol_tensor->GetSymbolicValue();
  if (input_x_symbols == nullptr) {
    GELOGW("SymbolicKernel compute unsupported, reason: get input symbolic value failed, node %s[%s].",
           context->GetNodeName(), context->GetNodeType());
    return UNSUPPORTED;
  }
  StridedSliceAttr strided_slice_attr;
  // A missing mask attribute must abort instead of silently slicing with all-zero masks.
  GE_ASSERT_SUCCESS(GetStridedSliceMaskAttr(context, strided_slice_attr));
  HandleMaskConflict(strided_slice_attr);
  std::pair<int64_t, int64_t> ellipsis_mask_range =
      GetEllipsisMaskRange(strided_slice_attr, static_cast<int64_t>(index_input.start_indexes.size()),
                           static_cast<int64_t>(input_x_dims.size()));
  std::vector<int64_t> input_append_axis_shape;
  GE_ASSERT_SUCCESS(AppendNewAxis(ellipsis_mask_range, strided_slice_attr.new_axis_mask, input_x_dims,
                                  input_append_axis_shape, index_input));
  GE_ASSERT_SUCCESS(HandleMaskAttr(ellipsis_mask_range, input_append_axis_shape, strided_slice_attr, index_input));
  auto out_symbols_tensor = context->GetOutputSymbolTensor(kOutputIndex);
  GE_ASSERT_NOTNULL(out_symbols_tensor);
  std::vector<Expression> output_symbols_shape;
  GE_ASSERT_SUCCESS(
      CalcOutputShape(strided_slice_attr.shrink_axis_mask, ellipsis_mask_range, index_input, output_symbols_shape));
  auto output_symbols_value = out_symbols_tensor->MutableSymbolicValue();
  GE_ASSERT_NOTNULL(output_symbols_value);
  out_symbols_tensor->MutableOriginSymbolShape().MutableDims() = output_symbols_shape;
  GE_ASSERT_SUCCESS(
      StridedSliceOutputSymbolsValue(*input_x_symbols, input_append_axis_shape, index_input, *output_symbols_value));
  GELOGD("%s[%s] kernel success, %s", context->GetNodeName(), context->GetNodeType(),
         SymbolicInferUtil::DumpSymbolTensor(*out_symbols_tensor).c_str());
  return SUCCESS;
}
// Both node types share one kernel; it branches on the node type to read begin/end/strides from inputs or attributes.
REGISTER_SYMBOLIC_KERNEL(StridedSlice, StridedSliceSymbolicKernelCompute);
REGISTER_SYMBOLIC_KERNEL(StridedSliceD, StridedSliceSymbolicKernelCompute);
}