/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file slice_infershape.cpp
* \brief Shape inference (InferShape) implementation for the Slice operator.
*/
#include "register/op_impl_registry.h"
#include "log/log.h"
#include "util/const_util.h"
#include "util/shape_util.h"
#include "op_util.h"
using namespace ge;
namespace ops {
// Input/output indices of the Slice operator, used with GetInputShape()/GetOutputShape()
// and GetConstIntToShape() further down in this file.
// Index of the tensor being sliced.
static constexpr size_t SLICE_IN_X = 0;
// Index of the `offset` input (subtracted from the input dim when a size entry is -1).
static constexpr size_t SLICE_IN_OFFSET = 1;
// Index of the `size` input (copied to the output shape when it is a compile-time constant).
static constexpr size_t SLICE_IN_SIZE = 2;
// Index of the single output tensor.
static constexpr size_t SLICE_OUT_Y = 0;
/**
 * Everything shape inference knows about the `offset` and `size` inputs of Slice.
 *
 * For each of the two inputs, either its constant value is available (stored as a
 * gert::Shape whose dims are the values, with the matching is_*_const flag true),
 * or only its element count is known (stored in *_num, with the flag false).
 *
 * All scalar members carry default initializers so that a default-constructed object
 * is already in the same state init() produces; init() is kept for existing callers
 * and for resetting a reused instance.
 */
struct SliceConstParams {
    gert::Shape offset;            // constant offset values, one dim per entry
    gert::Shape size;              // constant size values, one dim per entry
    int64_t offset_num = 0;        // element count of `offset` when it is not constant
    int64_t size_num = 0;          // element count of `size` when it is not constant
    bool is_offset_const = true;   // true when `offset` holds real constant values
    bool is_size_const = true;     // true when `size` holds real constant values

    /** Renders the constant values and element counts for debug logging. */
    std::string to_string() const
    {
        std::string result = "offset_const:" + Ops::Base::ToString(offset);
        result += " size_const:" + Ops::Base::ToString(size);
        result += " offset_num:" + std::to_string(offset_num);
        result += " size_num:" + std::to_string(size_num);
        return result;
    }

    /** Resets to the initial state: empty constant shapes, zero counts, both inputs assumed constant. */
    void init()
    {
        offset.SetDimNum(0);
        size.SetDimNum(0);
        offset_num = 0;
        size_num = 0;
        is_size_const = true;
        is_offset_const = true;
    }
};
/**
 * Picks the best available estimate of the output rank.
 *
 * Three candidates are considered and the largest wins:
 *   - the element count of `offset` (dim count of the constant, or offset_num otherwise),
 *   - the element count of `size` (dim count of the constant, or size_num otherwise),
 *   - the rank of x, or -1 when x has unknown rank.
 *
 * @param context          unused; kept for signature compatibility
 * @param x_shape          shape of the input tensor
 * @param slice_infer_info collected offset/size information
 * @return the largest candidate; the caller treats a value <= 0 as "rank unknown"
 */
static int64_t SliceInferRankFn(
    gert::InferShapeContext* context, const gert::Shape* x_shape, const SliceConstParams& slice_infer_info)
{
    (void)context;

    int64_t offset_count = slice_infer_info.offset_num;
    if (slice_infer_info.is_offset_const) {
        offset_count = static_cast<int64_t>(slice_infer_info.offset.GetDimNum());
    }

    int64_t size_count = slice_infer_info.size_num;
    if (slice_infer_info.is_size_const) {
        size_count = static_cast<int64_t>(slice_infer_info.size.GetDimNum());
    }

    int64_t input_rank = -1;
    if (!Ops::Base::IsUnknownRank(*x_shape)) {
        input_rank = static_cast<int64_t>(x_shape->GetDimNum());
    }

    return std::max(std::max(offset_count, size_count), input_rank);
}
/**
 * Validates that constant `offset` / `size` values have one entry per input dimension.
 *
 * Nothing is checked when x has unknown rank, and an input that is not constant is
 * skipped (only its is_*_const flag is consulted).
 *
 * @param x_shape          shape of the input tensor
 * @param slice_infer_info collected offset/size information
 * @return true when every applicable check passes, false (after logging an error) otherwise
 */
static bool CheckSliceInfo(const gert::Shape* x_shape, const SliceConstParams& slice_infer_info)
{
    const bool is_unknown_rank_x = Ops::Base::IsUnknownRank(*x_shape);
    if (is_unknown_rank_x) {
        OP_LOGD("CheckSliceInfo", "input is unknown rank, no need to check.");
        return true;
    }

    const size_t input_dim = x_shape->GetDimNum();
    if (slice_infer_info.is_offset_const) {
        OP_LOGD("CheckSliceInfo", "will check offset const value.");
        OP_CHECK_IF(
            slice_infer_info.offset.GetDimNum() != input_dim,
            OP_LOGE(
                "CheckSliceInfo", "%s",
                ops::ConcatString(
                    "offset num and input rank must be the same, but offset_value is ",
                    Ops::Base::ToString(slice_infer_info.offset).c_str(), ", input shape is ",
                    Ops::Base::ToString(*x_shape).c_str()).c_str()),
            return false);
    }

    if (slice_infer_info.is_size_const) {
        OP_LOGD("CheckSliceInfo", "will check size const value.");
        // The message names size_value: it is the `size` input that is being reported here.
        OP_CHECK_IF(
            slice_infer_info.size.GetDimNum() != input_dim,
            OP_LOGE(
                "CheckSliceInfo", "%s",
                ops::ConcatString(
                    "size num and input rank must be the same, but size_value is ",
                    Ops::Base::ToString(slice_infer_info.size).c_str(), ", input shape is ",
                    Ops::Base::ToString(*x_shape).c_str()).c_str()),
            return false);
    }
    return true;
}
/**
 * Infers the output shape when the `size` input is a known constant.
 *
 * The output starts as a copy of the constant `size` values. Each entry is then
 * validated (anything below -1 is rejected). An entry equal to -1 is replaced by
 * `x_dim - offset` for that axis when x has a known rank, `offset` is constant and
 * the corresponding x dim is not -1; otherwise the -1 is left in place.
 *
 * @param context          infer-shape context, used here only for the node name in logs
 * @param x_shape          shape of the input tensor
 * @param slice_infer_info collected offset/size information (size must be constant)
 * @param out_shape        receives the inferred output shape
 * @return ge::GRAPH_SUCCESS, or ge::GRAPH_FAILED when a size entry is below -1
 */
static graphStatus SliceInferShapeFnWithConstSize(
    gert::InferShapeContext* context, const gert::Shape* x_shape, const SliceConstParams& slice_infer_info,
    gert::Shape* out_shape)
{
    OP_LOGD(context->GetNodeName(), "start to do SliceInferShapeFnWithConstSize");
    OP_LOGD(context->GetNodeName(), "slice input shape is %s", Ops::Base::ToString(*x_shape).c_str());
    OP_LOGD(context->GetNodeName(), "slice const info is %s", slice_infer_info.to_string().c_str());

    *out_shape = slice_infer_info.size;

    // A -1 entry can only be resolved when both the input dims and the offsets are available.
    const bool can_fill_from_input = !Ops::Base::IsUnknownRank(*x_shape) && slice_infer_info.is_offset_const;
    const size_t rank = out_shape->GetDimNum();

    for (size_t axis = 0; axis < rank; ++axis) {
        const int64_t requested = out_shape->GetDim(axis);
        OP_CHECK_IF(
            requested < -1,
            OP_LOGE(
                context->GetNodeName(), "%s",
                ops::ConcatString(
                    "the value of size can not < -1, but is ",
                    Ops::Base::ToString(slice_infer_info.size).c_str()).c_str()),
            return ge::GRAPH_FAILED);

        if (!can_fill_from_input || requested != -1) {
            continue;
        }
        const int64_t input_extent = x_shape->GetDim(axis);
        if (input_extent == -1) {
            continue;  // input dim itself is unknown, keep -1
        }
        out_shape->SetDim(axis, input_extent - slice_infer_info.offset.GetDim(axis));
    }

    OP_LOGD(context->GetNodeName(), "output y = %s", Ops::Base::ToString(*out_shape).c_str());
    OP_LOGD(context->GetNodeName(), "End to do SliceInferShapeFn with const size");
    return ge::GRAPH_SUCCESS;
}
/**
 * Top-level shape inference once offset/size information has been collected.
 *
 * Steps:
 *   1. Validate constant offset/size against the input rank (CheckSliceInfo).
 *   2. If `size` is constant, delegate to SliceInferShapeFnWithConstSize.
 *   3. Otherwise only the rank can be estimated: a positive rank yields an output
 *      with every dim unknown; a non-positive rank yields an unknown-rank output.
 *
 * @param context          infer-shape context, used for the node name in logs
 * @param x_shape          shape of the input tensor
 * @param slice_infer_info collected offset/size information
 * @param out_shape        receives the inferred output shape
 * @return ge::GRAPH_SUCCESS, or ge::GRAPH_FAILED when validation fails
 */
static graphStatus SliceInferShapeFn(
    gert::InferShapeContext* context, const gert::Shape* x_shape, const SliceConstParams& slice_infer_info,
    gert::Shape* out_shape)
{
    OP_LOGD(context->GetNodeName(), "start to do SliceInferShapeFn");
    OP_LOGD(context->GetNodeName(), "slice input shape is %s", Ops::Base::ToString(*x_shape).c_str());
    OP_LOGD(context->GetNodeName(), "slice const info is %s", slice_infer_info.to_string().c_str());

    const bool info_is_valid = CheckSliceInfo(x_shape, slice_infer_info);
    OP_CHECK_IF(
        !info_is_valid, OP_LOGE(context->GetNodeName(), "check input info failed."), return ge::GRAPH_FAILED);

    if (slice_infer_info.is_size_const) {
        return SliceInferShapeFnWithConstSize(context, x_shape, slice_infer_info, out_shape);
    }

    OP_LOGD(context->GetNodeName(), "do slice info shape with tensor size.");
    const int64_t inferred_rank = SliceInferRankFn(context, x_shape, slice_infer_info);
    OP_LOGD(context->GetNodeName(), "try to get the rank of output = %ld.", inferred_rank);

    if (inferred_rank > 0) {
        // Rank is known but no dim value is: mark every dim as unknown.
        OP_LOGD(context->GetNodeName(), "set the output shape all -1.");
        Ops::Base::SetUnknownShape(inferred_rank, *out_shape);
    } else {
        // Not even the rank could be determined.
        OP_LOGD(context->GetNodeName(), "set the output shape = -2.");
        Ops::Base::SetUnknownRank(*out_shape);
    }
    return ge::GRAPH_SUCCESS;
}
/**
 * InferShape entry point registered for the Slice operator.
 *
 * Fetches the input/output shapes, tries to read `offset` and `size` as constants,
 * and falls back to their element counts when a constant is not available: 1 for a
 * zero-dim shape, otherwise the first dim of the input's shape. The collected
 * information is then handed to SliceInferShapeFn.
 *
 * @param context infer-shape context supplied by the framework
 * @return the status from SliceInferShapeFn; a null shape pointer is handled by
 *         OP_CHECK_NULL_WITH_CONTEXT
 */
static graphStatus InferShape4Slice(gert::InferShapeContext* context)
{
    const gert::Shape* x_shape = context->GetInputShape(SLICE_IN_X);
    OP_CHECK_NULL_WITH_CONTEXT(context, x_shape);
    gert::Shape* out_shape = context->GetOutputShape(SLICE_OUT_Y);
    OP_CHECK_NULL_WITH_CONTEXT(context, out_shape);

    SliceConstParams slice_infer_info;
    slice_infer_info.init();

    if (!Ops::Base::GetConstIntToShape(context, SLICE_IN_OFFSET, slice_infer_info.offset)) {
        slice_infer_info.is_offset_const = false;
        OP_LOGD(context->GetNodeName(), "Get const of offset unsuccessful, will do dynamic infershape.");
        // Discard anything a failed read may have left behind.
        slice_infer_info.offset.SetDimNum(0);
        const gert::Shape* offset_shape = context->GetInputShape(SLICE_IN_OFFSET);
        OP_CHECK_NULL_WITH_CONTEXT(context, offset_shape);
        slice_infer_info.offset_num = offset_shape->GetDimNum() == 0 ? 1 : offset_shape->GetDim(0);
    }

    if (!Ops::Base::GetConstIntToShape(context, SLICE_IN_SIZE, slice_infer_info.size)) {
        slice_infer_info.is_size_const = false;
        OP_LOGD(context->GetNodeName(), "Get const of size unsuccessful, will do dynamic infershape.");
        // Discard anything a failed read may have left behind.
        slice_infer_info.size.SetDimNum(0);
        const gert::Shape* size_shape = context->GetInputShape(SLICE_IN_SIZE);
        OP_CHECK_NULL_WITH_CONTEXT(context, size_shape);
        slice_infer_info.size_num = size_shape->GetDimNum() == 0 ? 1 : size_shape->GetDim(0);
    }

    return SliceInferShapeFn(context, x_shape, slice_infer_info, out_shape);
}
// Registers InferShape4Slice as the shape-inference function of the Slice operator and
// declares the offset and size inputs as data dependencies.
// NOTE(review): presumably this is what makes their values readable via GetConstIntToShape
// during inference — confirm against the op registry documentation.
IMPL_OP_INFERSHAPE(Slice).InferShape(InferShape4Slice).InputsDataDependency({SLICE_IN_OFFSET, SLICE_IN_SIZE});
}