* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "range_kernel.h"
#include <cfloat>
#include <cmath>
#include <cstdint>
#include <limits>
#include <memory>
#include <new>
#include <set>
#include <type_traits>
#include "framework/common/debug/log.h"
#include "framework/common/framework_types_internal.h"
#include "framework/common/util.h"
#include "framework/common/debug/ge_log.h"
#include "framework/common/ge_inner_error_codes.h"
#include "graph/utils/type_utils.h"
#include "host_kernels/kernel_factory.h"
namespace ge {
namespace {
// Number of inputs RangeCheck requires: start, limit and delta.
constexpr size_t kRangeInputNum = 3;
// Dimension count each input's shape must report to pass RangeCheck.
constexpr uint32_t kRangeDimNum = 0;
// Positions of the three operands inside the input vector handed to Compute.
constexpr size_t kStartIndex = 0;
constexpr size_t kLimitIndex = 1;
constexpr size_t kDeltaIndex = 2;
// Element types RangeCheck accepts; Compute has one dispatch branch for each of them.
const std::set<DataType> kRangeSupportedType = {DT_INT32, DT_FLOAT, DT_DOUBLE, DT_INT64};
}  // namespace
Status RangeKernel::Compute(const OpDescPtr op_desc_ptr, const std::vector<ConstGeTensorPtr> &input,
std::vector<GeTensorPtr> &v_output) {
GELOGD("RangeKernel in");
if (op_desc_ptr == nullptr) {
GELOGE(PARAM_INVALID, "Parameter's invalid, input opDescPtr is nullptr.");
return PARAM_INVALID;
}
Status ret = RangeCheck(input);
if (ret != SUCCESS) {
return ret;
}
GeTensorPtr output_ptr = MakeShared<GeTensor>(op_desc_ptr->GetOutputDesc(0));
if (output_ptr == nullptr) {
GELOGE(MEMALLOC_FAILED, "Make shared failed");
return MEMALLOC_FAILED;
}
ConstGeTensorPtr start = input.at(kStartIndex);
ConstGeTensorPtr limit = input.at(kLimitIndex);
ConstGeTensorPtr delta = input.at(kDeltaIndex);
DataType data_type = delta->GetTensorDesc().GetDataType();
if (data_type == DT_FLOAT) {
if (GetRange(*reinterpret_cast<const float *>(start->GetData().data()),
*reinterpret_cast<const float *>(limit->GetData().data()),
*reinterpret_cast<const float *>(delta->GetData().data()), output_ptr) != SUCCESS) {
return PARAM_INVALID;
}
} else if (data_type == DT_INT32) {
if (GetRange(*reinterpret_cast<const int32_t *>(start->GetData().data()),
*reinterpret_cast<const int32_t *>(limit->GetData().data()),
*reinterpret_cast<const int32_t *>(delta->GetData().data()), output_ptr) != SUCCESS) {
return PARAM_INVALID;
}
} else if (data_type == DT_INT64) {
if (GetRange(*reinterpret_cast<const int64_t *>(start->GetData().data()),
*reinterpret_cast<const int64_t *>(limit->GetData().data()),
*reinterpret_cast<const int64_t *>(delta->GetData().data()), output_ptr) != SUCCESS) {
return PARAM_INVALID;
}
} else if (data_type == DT_DOUBLE) {
if (GetRange(*reinterpret_cast<const double *>(start->GetData().data()),
*reinterpret_cast<const double *>(limit->GetData().data()),
*reinterpret_cast<const double *>(delta->GetData().data()), output_ptr) != SUCCESS) {
return PARAM_INVALID;
}
}
output_ptr->MutableTensorDesc().SetDataType(data_type);
v_output.push_back(output_ptr);
return SUCCESS;
}
/// Decides whether the given inputs can be evaluated by this kernel.
///
/// The checks run in this order: input count, non-null tensors, non-empty data buffers,
/// identical data types, zero-dimensional shapes, supported data type.
///
/// @param input  Candidate start / limit / delta tensors.
/// @return SUCCESS when every check passes; NOT_CHANGED when a check other than the null check fails;
///         whatever GE_CHECK_NOTNULL returns when one of the tensors is null.
Status RangeKernel::RangeCheck(const std::vector<ConstGeTensorPtr> &input) const {
  if (kRangeInputNum != input.size()) {
    GELOGI("The number of input for Range must be %zu.", kRangeInputNum);
    return NOT_CHANGED;
  }

  // The size was verified above, so unchecked indexing is safe here.
  const ConstGeTensorPtr &start = input[kStartIndex];
  const ConstGeTensorPtr &limit = input[kLimitIndex];
  const ConstGeTensorPtr &delta = input[kDeltaIndex];
  GE_CHECK_NOTNULL(start);
  GE_CHECK_NOTNULL(limit);
  GE_CHECK_NOTNULL(delta);

  const auto start_bytes = start->GetData().size();
  const auto limit_bytes = limit->GetData().size();
  const auto delta_bytes = delta->GetData().size();
  const bool has_empty_operand = (start_bytes == 0) || (limit_bytes == 0) || (delta_bytes == 0);
  if (has_empty_operand) {
    GELOGI("Data size check skipped: empty input. start: %zu, limit: %zu, delta: %zu", start_bytes, limit_bytes,
           delta_bytes);
    return NOT_CHANGED;
  }

  const DataType type = start->GetTensorDesc().GetDataType();
  const bool same_type =
      (limit->GetTensorDesc().GetDataType() == type) && (delta->GetTensorDesc().GetDataType() == type);
  if (!same_type) {
    GELOGI("Data type of inputs for Range not matched.");
    return NOT_CHANGED;
  }

  const size_t expected_dims = static_cast<size_t>(kRangeDimNum);
  const auto has_expected_dims = [expected_dims](const ConstGeTensorPtr &tensor) {
    return tensor->GetTensorDesc().GetShape().GetDimNum() == expected_dims;
  };
  if (!(has_expected_dims(start) && has_expected_dims(limit) && has_expected_dims(delta))) {
    GELOGI("Inputs for Range are not all scalars.");
    return NOT_CHANGED;
  }

  if (kRangeSupportedType.count(type) == 0) {
    GELOGI("Range does not support this Data type: %s", TypeUtils::DataTypeToSerialString(type).c_str());
    return NOT_CHANGED;
  }
  return SUCCESS;
}
/// Generic zero test: true only when `element` compares equal to zero.
/// The floating-point types have their own specializations.
template <typename T>
bool RangeKernel::IsZero(const T &element) const {
  const bool is_zero = (element == 0);
  return is_zero;
}
/// Single-precision zero test: true when the magnitude of `element` is below FLT_EPSILON.
/// NOTE(review): FLT_EPSILON is used here as an absolute threshold, so any delta smaller than it is
/// treated as zero by GetRange - confirm that this is intended.
template <>
bool RangeKernel::IsZero<float32_t>(const float32_t &element) const {
  const auto magnitude = fabs(element);
  return magnitude < FLT_EPSILON;
}
/// Double-precision zero test: true when the magnitude of `element` is below DBL_EPSILON.
/// NOTE(review): DBL_EPSILON is used here as an absolute threshold, so any delta smaller than it is
/// treated as zero by GetRange - confirm that this is intended.
template <>
bool RangeKernel::IsZero<float64_t>(const float64_t &element) const {
  const auto magnitude = fabs(element);
  return magnitude < DBL_EPSILON;
}
namespace {
/// Element count of an integral range, i.e. ceil(|limit - start| / |delta|).
///
/// Preconditions (checked by the caller): delta != 0 and delta points from start towards limit.
/// All arithmetic is done in uint64_t: converting to unsigned and subtracting yields the exact distance
/// even where `limit - start` or `-delta` would overflow the signed type.
///
/// @return false when the count does not fit into int64_t, true otherwise (count written to `count`).
template <typename T>
bool CountRangeElements(const T start, const T limit, const T delta, int64_t &count,
                        std::true_type /* is_integral */) {
  const uint64_t first = static_cast<uint64_t>(start);
  const uint64_t last = static_cast<uint64_t>(limit);
  const uint64_t span = (limit >= start) ? (last - first) : (first - last);
  const uint64_t step =
      (delta > 0) ? static_cast<uint64_t>(delta) : (static_cast<uint64_t>(0) - static_cast<uint64_t>(delta));
  // Ceiling division without the "+ step - 1" term, which could itself wrap around.
  const uint64_t total = (span / step) + (((span % step) != 0U) ? 1U : 0U);
  if (total > static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) {
    return false;
  }
  count = static_cast<int64_t>(total);
  return true;
}

/// Element count of a floating-point range, i.e. ceil(|(limit - start) / delta|), computed in T.
///
/// @return false when the count is NaN, infinite or too large for int64_t (converting such a value
///         to an integer is undefined), true otherwise (count written to `count`).
template <typename T>
bool CountRangeElements(const T start, const T limit, const T delta, int64_t &count,
                        std::false_type /* is_integral */) {
  constexpr double kInt64Limit = 9223372036854775808.0;  // 2^63, first value that no longer fits
  const T total = std::ceil(std::abs((limit - start) / delta));
  if (!std::isfinite(total) || (static_cast<double>(total) >= kInt64Limit)) {
    return false;
  }
  count = static_cast<int64_t>(total);
  return true;
}
}  // namespace

/// Fills `output` with the sequence start, start + delta, ... that stops before reaching `limit`.
///
/// @param start   First value of the sequence.
/// @param limit   Exclusive bound of the sequence.
/// @param delta   Step between consecutive values; must be non-zero (per IsZero) and must point
///                from `start` towards `limit`.
/// @param output  Tensor that receives the data; its shape becomes {count}. For an empty sequence
///                no data is set and the shape is left as a default-constructed GeShape.
///                NOTE(review): an empty range therefore ends up with a dimensionless shape rather
///                than {0} - kept as-is, confirm that downstream code expects this.
/// @return SUCCESS on success;
///         PARAM_INVALID for a null `output`, a zero or wrongly-signed delta, or an element count
///         that cannot be represented;
///         MEMALLOC_FAILED when the temporary buffer cannot be allocated;
///         INTERNAL_ERROR when the data cannot be stored into `output`.
template <typename T>
Status RangeKernel::GetRange(const T start, const T limit, const T delta, GeTensorPtr &output) const {
  GE_CHECK_NOTNULL(output);
  if (IsZero(delta)) {
    GELOGE(PARAM_INVALID, "Requires delta != 0");
    return PARAM_INVALID;
  }
  if (start > limit && delta > 0) {
    GELOGE(PARAM_INVALID, "Requires start <= limit when delta > 0");
    return PARAM_INVALID;
  }
  if (start < limit && delta < 0) {
    GELOGE(PARAM_INVALID, "Requires start >= limit when delta < 0");
    return PARAM_INVALID;
  }

  int64_t size = 0;
  if (!CountRangeElements(start, limit, delta, size, std::is_integral<T>())) {
    GELOGE(PARAM_INVALID, "Element count of Range is not representable, check start, limit and delta.");
    return PARAM_INVALID;
  }

  output->MutableTensorDesc().SetShape(GeShape());
  if (size <= 0) {
    return SUCCESS;
  }

  // Refuse sizes whose byte length would overflow size_t before handing them to operator new[].
  if (static_cast<uint64_t>(size) > (std::numeric_limits<size_t>::max() / sizeof(T))) {
    GELOGE(MEMALLOC_FAILED, "New buf failed.");
    return MEMALLOC_FAILED;
  }
  const size_t elem_count = static_cast<size_t>(size);
  std::unique_ptr<T[]> buf(new (std::nothrow) T[elem_count]);
  if (buf == nullptr) {
    GELOGE(MEMALLOC_FAILED, "New buf failed.");
    return MEMALLOC_FAILED;
  }

  // Same running sum as "val += delta", but without the extra addition after the last element,
  // which could overflow for integral T.
  buf[0] = start;
  for (size_t i = 1; i < elem_count; ++i) {
    buf[i] = buf[i - 1] + delta;
  }

  if (output->SetData(reinterpret_cast<uint8_t *>(buf.get()), elem_count * sizeof(T)) != GRAPH_SUCCESS) {
    // Do not report success with a shaped tensor that carries no data.
    GELOGE(INTERNAL_ERROR, "GetRange: SetData failed");
    return INTERNAL_ERROR;
  }
  output->MutableTensorDesc().SetShape(GeShape({size}));
  return SUCCESS;
}
// Presumably registers RangeKernel as the host compute kernel for the RANGE op type; the macro is
// not defined in this file (likely host_kernels/kernel_factory.h) - confirm against its definition.
REGISTER_COMPUTE_NODE_KERNEL(RANGE, RangeKernel);
}