* Copyright (c) 2025-2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file range_infershape.cpp
 * \brief Shape inference (InferShape) implementation for the Range operator.
*/
#include <cmath>
#include <cstdint>
#include <cstring>
#include <type_traits>
#include <vector>
#include "log/log.h"
#include "op_host/util/fp16.h"
#include "register/op_impl_registry.h"
#include "util/math_util.h"
using namespace ge;
namespace ops {
// Left-shift applied to a 16-bit bfloat16 bit pattern to place it in the high half of a
// 32-bit float bit pattern (used by the DT_BF16 branch of RangeGetConstValue).
constexpr size_t INT16_BITS_NUM = 16;
/**
 * Reports whether a const input tensor has no readable data.
 *
 * For the supported dtypes (DT_INT32, DT_FLOAT, DT_FLOAT16, DT_BF16, DT_INT64, DT_DOUBLE)
 * the result is true exactly when the typed data pointer returned by GetData is nullptr.
 * Any other dtype logs a warning and is reported as "not null" (false).
 *
 * @param context infershape context, used only for the node name in the warning.
 * @param tensor  tensor to inspect; must not be nullptr.
 * @return true when the tensor's data pointer is nullptr for a supported dtype.
 */
static bool IsTensorNull(const gert::InferShapeContext* context, const gert::Tensor* tensor)
{
    switch (tensor->GetDataType()) {
        case ge::DT_INT32:
            return tensor->GetData<int32_t>() == nullptr;
        case ge::DT_FLOAT:
            return tensor->GetData<float>() == nullptr;
        case ge::DT_FLOAT16:
            return tensor->GetData<Ops::Base::fp16_t>() == nullptr;
        case ge::DT_BF16:
            // bf16 payload is inspected through its 16-bit storage type.
            return tensor->GetData<int16_t>() == nullptr;
        case ge::DT_INT64:
            return tensor->GetData<int64_t>() == nullptr;
        case ge::DT_DOUBLE:
            return tensor->GetData<double>() == nullptr;
        default:
            OP_LOGW(context->GetNodeName(), "aicore datatype not support.");
            return false;
    }
}
/**
 * Reads the first element of a const input tensor and appends it to @p value, converted to T.
 *
 * Supported tensor dtypes: DT_INT32, DT_FLOAT, DT_FLOAT16, DT_BF16, DT_INT64, DT_DOUBLE.
 * For any other dtype nothing is appended and GRAPH_SUCCESS is returned; CalculateOutputNum
 * interprets the resulting empty vector as "value not known".
 *
 * @param context infershape context, used for logging and null checks.
 * @param tensor  const input tensor; only the element at its data pointer is read.
 * @param value   output vector; receives at most one element.
 * @return GRAPH_SUCCESS on success; if the typed data pointer is nullptr the function returns
 *         early through OP_CHECK_NULL_WITH_CONTEXT (presumably with a failure status).
 */
template <typename T>
static ge::graphStatus RangeGetConstValue(
    gert::InferShapeContext* context, const gert::Tensor* tensor, std::vector<T>& value)
{
    const auto dataType = tensor->GetDataType();
    if (dataType == ge::DT_INT32) {
        const int32_t* constDataPtr = tensor->GetData<int32_t>();
        OP_CHECK_NULL_WITH_CONTEXT(context, constDataPtr);
        value.push_back(static_cast<T>(*constDataPtr));
        OP_LOGD(context->GetNodeName(), "range get const value:%d", *constDataPtr);
    } else if (dataType == ge::DT_FLOAT) {
        const float* constDataPtr = tensor->GetData<float>();
        OP_CHECK_NULL_WITH_CONTEXT(context, constDataPtr);
        value.push_back(static_cast<T>(*constDataPtr));
        OP_LOGD(context->GetNodeName(), "range get const value:%f", *constDataPtr);
    } else if (dataType == ge::DT_FLOAT16) {
        const Ops::Base::fp16_t* constDataPtr = tensor->GetData<Ops::Base::fp16_t>();
        OP_CHECK_NULL_WITH_CONTEXT(context, constDataPtr);
        const float f32 = static_cast<float>(*constDataPtr);
        value.push_back(static_cast<T>(f32));
        OP_LOGD(context->GetNodeName(), "range get const value:%f", f32);
    } else if (dataType == ge::DT_BF16) {
        const int16_t* constDataPtr = tensor->GetData<int16_t>();
        OP_CHECK_NULL_WITH_CONTEXT(context, constDataPtr);
        // bfloat16 is the upper half of an IEEE-754 float32 bit pattern. Widen through unsigned
        // types so the shift never operates on a negative value (UB before C++20), and copy the
        // bits with memcpy rather than reinterpret_cast to avoid a strict-aliasing violation.
        static_assert(sizeof(float) == sizeof(uint32_t), "bf16 widening assumes a 32-bit float");
        const uint32_t f32Bits = static_cast<uint32_t>(static_cast<uint16_t>(*constDataPtr)) << INT16_BITS_NUM;
        float f32 = 0.0F;
        std::memcpy(&f32, &f32Bits, sizeof(f32));
        value.push_back(static_cast<T>(f32));
        // Log the decoded value, not the raw bit pattern.
        OP_LOGD(context->GetNodeName(), "range get const value:%f", f32);
    } else if (dataType == ge::DT_INT64) {
        const int64_t* constDataPtr = tensor->GetData<int64_t>();
        OP_CHECK_NULL_WITH_CONTEXT(context, constDataPtr);
        value.push_back(static_cast<T>(*constDataPtr));
        OP_LOGD(context->GetNodeName(), "range get const value:%ld", *constDataPtr);
    } else if (dataType == ge::DT_DOUBLE) {
        const double* constDataPtr = tensor->GetData<double>();
        OP_CHECK_NULL_WITH_CONTEXT(context, constDataPtr);
        value.push_back(static_cast<T>(*constDataPtr));
        OP_LOGD(context->GetNodeName(), "range get const value:%lf", *constDataPtr);
    }
    // Unsupported dtypes deliberately leave `value` empty and still report success.
    return ge::GRAPH_SUCCESS;
}
/**
 * Validates the (start, limit, delta) triple of a Range op.
 *
 * Fails when delta is zero. For floating-point T, a NaN delta fails the same test because it
 * is neither greater nor less than zero. It also fails when delta points away from limit:
 * limit above start with a negative delta, or limit below start with a positive delta.
 *
 * @return ge::GRAPH_SUCCESS when the triple is usable, ge::GRAPH_FAILED (after logging) otherwise.
 */
template <typename T>
static ge::graphStatus CheckParam(gert::InferShapeContext* context, T start, T limit, T delta)
{
    const T zero = static_cast<T>(0);
    const bool deltaPositive = delta > zero;
    const bool deltaNegative = delta < zero;
    OP_CHECK_IF(
        !deltaPositive && !deltaNegative,
        OP_LOGE(context->GetNodeName(), "delta is zero."), return ge::GRAPH_FAILED);

    const bool goesUp = limit > start;
    const bool goesDown = limit < start;
    OP_CHECK_IF(
        goesUp && deltaNegative,
        OP_LOGE(context->GetNodeName(), "increate is positive, but delta is negative"), return ge::GRAPH_FAILED);
    OP_CHECK_IF(
        goesDown && deltaPositive,
        OP_LOGE(context->GetNodeName(), "increate is negative, but delta is positive"), return ge::GRAPH_FAILED);
    return ge::GRAPH_SUCCESS;
}
// Converts a floating-point element count to uint64_t. Returns false (leaving totalNum
// untouched) when the value is NaN, infinite, negative or >= 2^63: converting such values to
// an integer type is undefined behaviour, and 2^63 is the first count that cannot be
// represented as int64_t (the output dimension is assumed to be stored as int64_t — confirm
// against gert::Shape::SetDim).
static bool RangeFloatCountToUint64(double count, uint64_t& totalNum)
{
    constexpr double maxCountExclusive = 9223372036854775808.0; // 2^63
    // Written as a negated conjunction so that NaN, which fails every comparison, is rejected.
    if (!(count >= 0.0 && count < maxCountExclusive)) {
        return false;
    }
    totalNum = static_cast<uint64_t>(count);
    return true;
}

/**
 * Computes the number of elements produced by Range for the given const inputs.
 *
 * start/limit/delta are read as T via RangeGetConstValue and validated with CheckParam.
 * Attribute 0 (isClosed) selects the formula:
 *   - closed: (limit - start) / delta + 1, truncated toward zero;
 *   - open, T == int64_t: Ops::Base::CeilDiv(limit - start, delta);
 *   - open, other T: ceil((limit - start) / delta), with the inputs re-read as double.
 *
 * @param totalNum receives the element count. When any input dtype is unsupported
 *                 (RangeGetConstValue appends nothing), it is set to UINT64_MAX and
 *                 GRAPH_SUCCESS is returned; the caller writes it into a dimension, where it
 *                 presumably reads back as -1 (unknown).
 * @return GRAPH_FAILED when a value cannot be read, CheckParam rejects the triple, the
 *         isClosed attribute is missing, or a floating-point count is NaN/infinite/too large.
 */
template <typename T>
static ge::graphStatus CalculateOutputNum(
    gert::InferShapeContext* context, const gert::Tensor* tensorStart, const gert::Tensor* tensorLimit,
    const gert::Tensor* tensorDelta, uint64_t& totalNum)
{
    std::vector<T> startMultiples;
    std::vector<T> limitMultiples;
    std::vector<T> deltaMultiples;
    OP_CHECK_IF(
        RangeGetConstValue<T>(context, tensorStart, startMultiples) != ge::GRAPH_SUCCESS,
        OP_LOGE(context->GetNodeName(), "get start const value fail."), return ge::GRAPH_FAILED);
    OP_CHECK_IF(
        RangeGetConstValue<T>(context, tensorLimit, limitMultiples) != ge::GRAPH_SUCCESS,
        OP_LOGE(context->GetNodeName(), "get limit const value fail."), return ge::GRAPH_FAILED);
    OP_CHECK_IF(
        RangeGetConstValue<T>(context, tensorDelta, deltaMultiples) != ge::GRAPH_SUCCESS,
        OP_LOGE(context->GetNodeName(), "get delta const value fail."), return ge::GRAPH_FAILED);
    if (startMultiples.empty() || limitMultiples.empty() || deltaMultiples.empty()) {
        // Same bit pattern as the previous `totalNum = -1`, spelled explicitly.
        totalNum = UINT64_MAX;
        return ge::GRAPH_SUCCESS;
    }
    OP_CHECK_IF(
        CheckParam(context, startMultiples[0], limitMultiples[0], deltaMultiples[0]) != ge::GRAPH_SUCCESS,
        OP_LOGE(
            context->GetNodeName(), "CheckParam fail, start: %lf, limit: %lf, delta: %lf",
            static_cast<double>(startMultiples[0]), static_cast<double>(limitMultiples[0]),
            static_cast<double>(deltaMultiples[0])),
        return ge::GRAPH_FAILED);
    const gert::RuntimeAttrs* attrs = context->GetAttrs();
    OP_CHECK_NULL_WITH_CONTEXT(context, attrs);
    const bool* isClosed = attrs->GetAttrPointer<bool>(0);
    OP_CHECK_NULL_WITH_CONTEXT(context, isClosed);

    const bool isIntegral = std::is_same<T, int64_t>::value;
    if (*isClosed) {
        if (isIntegral) {
            // NOTE(review): limit - start may overflow for extreme int64 inputs; not guarded here.
            totalNum = static_cast<uint64_t>((limitMultiples[0] - startMultiples[0]) / deltaMultiples[0] + 1);
        } else {
            // Same arithmetic in T as before; widened to double only for the range check.
            const double closedCount =
                static_cast<double>((limitMultiples[0] - startMultiples[0]) / deltaMultiples[0] + 1);
            OP_CHECK_IF(
                !RangeFloatCountToUint64(closedCount, totalNum),
                OP_LOGE(
                    context->GetNodeName(), "range output size %lf is NaN, infinite or out of range.",
                    closedCount),
                return ge::GRAPH_FAILED);
        }
    } else if (isIntegral) {
        // NOTE(review): assumes Ops::Base::CeilDiv rounds toward +inf when both operands are
        // negative (descending ranges) — confirm against util/math_util.h.
        totalNum = static_cast<uint64_t>(Ops::Base::CeilDiv(
            static_cast<int64_t>(limitMultiples[0]) - static_cast<int64_t>(startMultiples[0]),
            static_cast<int64_t>(deltaMultiples[0])));
    } else {
        // Re-read the inputs as double so the open-interval count is computed at double precision.
        std::vector<double> startDouble;
        std::vector<double> limitDouble;
        std::vector<double> deltaDouble;
        OP_CHECK_IF(
            RangeGetConstValue<double>(context, tensorStart, startDouble) != ge::GRAPH_SUCCESS,
            OP_LOGE(context->GetNodeName(), "get start const value fail."), return ge::GRAPH_FAILED);
        OP_CHECK_IF(
            RangeGetConstValue<double>(context, tensorLimit, limitDouble) != ge::GRAPH_SUCCESS,
            OP_LOGE(context->GetNodeName(), "get limit const value fail."), return ge::GRAPH_FAILED);
        OP_CHECK_IF(
            RangeGetConstValue<double>(context, tensorDelta, deltaDouble) != ge::GRAPH_SUCCESS,
            OP_LOGE(context->GetNodeName(), "get delta const value fail."), return ge::GRAPH_FAILED);
        const double openCount = std::ceil((limitDouble[0] - startDouble[0]) / deltaDouble[0]);
        OP_CHECK_IF(
            !RangeFloatCountToUint64(openCount, totalNum),
            OP_LOGE(
                context->GetNodeName(), "range output size %lf is NaN, infinite or out of range.", openCount),
            return ge::GRAPH_FAILED);
    }
    OP_LOGD(
        context->GetNodeName(), "CalculateOutputNum: start: %lf, limit: %lf, delta: %lf, total_num: %lu",
        static_cast<double>(startMultiples[0]), static_cast<double>(limitMultiples[0]),
        static_cast<double>(deltaMultiples[0]), totalNum);
    return ge::GRAPH_SUCCESS;
}
/**
 * InferShape entry for Range. The output is always 1-D, and its length is derived from the
 * values of the start (input 0), limit (input 1) and delta (input 2) tensors.
 *
 * If any of the three inputs has no readable data, the length is set to -1.
 * The element type used for the length computation follows the output dtype:
 *   DT_FLOAT -> double, DT_INT32 / DT_INT64 -> int64_t, anything else -> float.
 */
static ge::graphStatus RangeInferShapeFunc(gert::InferShapeContext* context)
{
    OP_LOGD(context->GetNodeName(), "start running RangeInferShapeFunc...");

    // Variable names passed to OP_CHECK_NULL_WITH_CONTEXT are kept as-is in case the macro
    // stringifies them into its log message.
    auto startTensor = context->GetInputTensor(0);
    OP_CHECK_NULL_WITH_CONTEXT(context, startTensor);
    auto limitTensor = context->GetInputTensor(1);
    OP_CHECK_NULL_WITH_CONTEXT(context, limitTensor);
    auto deltaTensor = context->GetInputTensor(2);
    OP_CHECK_NULL_WITH_CONTEXT(context, deltaTensor);
    auto out_shape = context->GetOutputShape(0);
    OP_CHECK_NULL_WITH_CONTEXT(context, out_shape);

    const bool anyInputNull = IsTensorNull(context, startTensor) || IsTensorNull(context, limitTensor) ||
                              IsTensorNull(context, deltaTensor);
    if (anyInputNull) {
        out_shape->SetDimNum(1);
        out_shape->SetDim(0, -1);
        OP_LOGD(context->GetNodeName(), "tensor is null, return.");
        return GRAPH_SUCCESS;
    }

    auto outDesc = context->GetOutputDesc(0);
    OP_CHECK_NULL_WITH_CONTEXT(context, outDesc);
    const DataType outputType = outDesc->GetDataType();

    uint64_t elementCount = 0;
    ge::graphStatus calcStatus = ge::GRAPH_FAILED;
    if (outputType == ge::DT_FLOAT) {
        calcStatus = CalculateOutputNum<double>(context, startTensor, limitTensor, deltaTensor, elementCount);
    } else if (outputType == ge::DT_INT32 || outputType == ge::DT_INT64) {
        calcStatus = CalculateOutputNum<int64_t>(context, startTensor, limitTensor, deltaTensor, elementCount);
    } else {
        calcStatus = CalculateOutputNum<float>(context, startTensor, limitTensor, deltaTensor, elementCount);
    }
    OP_CHECK_IF(
        calcStatus != ge::GRAPH_SUCCESS,
        OP_LOGE(context->GetNodeName(), "calculate output_total_num value fail."), return ge::GRAPH_FAILED);

    out_shape->SetDimNum(1);
    out_shape->SetDim(0, elementCount);
    return GRAPH_SUCCESS;
}
// Register the Range infershape function. Inputs 0, 1 and 2 (start, limit, delta) are listed
// as data dependencies because RangeInferShapeFunc reads their values, not just their shapes,
// to compute the output length.
IMPL_OP_INFERSHAPE(Range).InputsDataDependency({0, 1, 2}).InferShape(RangeInferShapeFunc);
}