/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
#ifndef MATH_COMMON_OP_HOST_INPUT_UTIL_H
#define MATH_COMMON_OP_HOST_INPUT_UTIL_H
#include <array>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <tuple>
#include <type_traits>
#include <utility>
#include "log/log.h"
#include "runtime/continuous_vector.h"
#include "graph/types.h"
namespace Ops::Math {
// Axis indices of the N, C, H and W dimensions inside a 4-D shape stored in NCHW order.
static constexpr size_t NCHW_N_DIM = 0U;
static constexpr size_t NCHW_C_DIM = 1U;
static constexpr size_t NCHW_H_DIM = 2U;
static constexpr size_t NCHW_W_DIM = 3U;
// Axis indices of the N, H, W and C dimensions inside a 4-D shape stored in NHWC order.
static constexpr size_t NHWC_N_DIM = 0U;
static constexpr size_t NHWC_H_DIM = 1U;
static constexpr size_t NHWC_W_DIM = 2U;
static constexpr size_t NHWC_C_DIM = 3U;
// Rank required by GetImgDataDimsByNCHWOrder; shapes with any other dimension count are rejected.
static constexpr size_t NHWC_DIM_NUM = 4U;
// Unpack lengths for which CheckUnpackAdaptDimListIntAttr prints a dedicated size hint in its error log.
static constexpr size_t UNPACK_TWO_ATTRS = 2U;
static constexpr size_t UNPACK_FOUR_ATTRS = 4U;
namespace {
* @brief 检查固定长度的 ListInt 类型属性值
* @tparam unpackLen 解包长度
* @tparam T context类型
* @param [in] context infershape 或 tiling 上下文
* @param [in] attrName 属性名称
* @param [in] vec 属性vector
* @param [in] elementValidator 元素校验方法,对元素挨个进行校验
* @return 执行结果,false: 失败,true: 成功
*/
template <size_t unpackLen, typename T>
static inline bool CheckUnpackFixedDimListIntAttr(
T* context, const char* attrName, const gert::TypedContinuousVector<int64_t>*& vec,
std::function<bool(int64_t)> elementValidator)
{
static_assert(unpackLen > 0, "unpackLen should be positive");
OP_CHECK_IF(vec == nullptr, OP_LOGE(context, "attr %s is nullptr!", attrName), return false);
size_t packSize = vec->GetSize();
OP_CHECK_IF(
(packSize != unpackLen),
OP_LOGE(context, "The size of attr %s must be %lu, but got %lu", attrName, unpackLen, packSize), return false);
for (size_t i = 0; i < packSize; ++i) {
int64_t value = vec->GetData()[i];
if (!elementValidator(value)) {
OP_LOGE(context, "The %lu-th element of attr %s is invalid, value: %ld", i, attrName, value);
return false;
}
}
return true;
}
* @brief 检查自适应长度的 ListInt 类型属性值
* @tparam unpackLen 解包长度
* @tparam T context类型
* @param [in] context infershape 或 tiling 上下文
* @param [in] attrName 属性名称
* @param [in] vec 属性vector
* @param [in] elementValidator 元素校验方法,对元素挨个进行校验
* @return 执行结果,false: 失败,true: 成功
*/
template <size_t unpackLen, typename T>
static inline bool CheckUnpackAdaptDimListIntAttr(
T* context, const char* attrName, const gert::TypedContinuousVector<int64_t>*& vec,
std::function<bool(int64_t)> elementValidator)
{
static_assert(unpackLen > 0, "unpackLen should be positive");
OP_CHECK_IF(vec == nullptr, OP_LOGE(context, "attr %s is nullptr!", attrName), return false);
size_t packSize = vec->GetSize();
if (unlikely(packSize == 0 || packSize > unpackLen || unpackLen % packSize != 0)) {
if constexpr (unpackLen == UNPACK_TWO_ATTRS) {
OP_LOGE(context, "The size of attr %s must be 1 or 2, but got %lu", attrName, packSize);
} else if constexpr (unpackLen == UNPACK_FOUR_ATTRS) {
OP_LOGE(context, "The size of attr %s must be 1, 2 or 4, but got %lu", attrName, packSize);
} else {
OP_LOGE(context, "The size of attr %s must be divisor of %lu, but got %lu", attrName, unpackLen, packSize);
}
return false;
}
for (size_t i = 0; i < packSize; ++i) {
int64_t value = vec->GetData()[i];
if (!elementValidator(value)) {
OP_LOGE(context, "The %lu-th element of attr %s is invalid, value: %ld", i, attrName, value);
return false;
}
}
return true;
}
* @brief 不安全解包函数。解包 ListInt 类型属性值,并将输入的元素自动扩展到输出参数长度,依次赋值到输出参数中。
* 仅解包,不做参数检查,由调用者做检查。需满足约束: \n
* - 属性size 不为 0 \n
* - 属性size 必须为解包长度的约数
* @param [in] checkedVec 属性vector
* @param [in] Is... 索引序列,从0开始,长度等于args个数
* @param [out] args 解包结果
* @return 解包结果,元素个数为 unpackLen
*/
template <size_t... Is, typename... Args>
static inline void UnsafeUnpackListIntAttr(
const gert::TypedContinuousVector<int64_t>*& checkedVec, std::index_sequence<Is...>, Args&... args)
{
static_assert(sizeof...(Args) == sizeof...(Is), "Number of arguments must match length of sequence");
size_t partLen = sizeof...(Is) / checkedVec->GetSize();
((args = checkedVec->GetData()[Is / partLen]), ...);
}
}
* @brief 解包固定长度的 ListInt 类型属性
* @tparam unpackLen 解包长度
* @tparam T context类型
* @param [in] context infershape 或 tiling 上下文
* @param [in] attrName 属性名称
* @param [in] vec 属性vector
* @param [in] elementValidator 元素校验方法,对元素挨个进行校验
* @return
* - status: 执行结果,GRAPH_FAILED: 失败,GRAPH_SUCCESS: 成功
* - result: 解包结果,元素个数为 unpackLen
*/
template <size_t unpackLen, typename T>
static inline std::tuple<ge::graphStatus, std::array<int64_t, unpackLen>> UnpackFixedDimListIntAttr(
T* context, const char* attrName, const gert::TypedContinuousVector<int64_t>*& vec,
std::function<bool(int64_t)> elementValidator)
{
if (!CheckUnpackFixedDimListIntAttr<unpackLen>(context, attrName, vec, elementValidator)) {
return {ge::GRAPH_FAILED, {}};
}
std::array<int64_t, unpackLen> value{};
for (size_t i = 0; i < vec->GetSize(); ++i) {
value[i] = vec->GetData()[i];
}
return {ge::GRAPH_SUCCESS, value};
}
* @brief 解包自适应长度的 ListInt 类型属性,将输入的元素自动扩展到解包长度
* @tparam unpackLen 解包长度
* @tparam T context类型
* @param [in] context infershape 或 tiling 上下文
* @param [in] attrName 属性名称
* @param [in] vec 属性vector
* @param [in] elementValidator 元素校验方法,对元素挨个进行校验
* @return
- status: 执行结果,GRAPH_FAILED: 失败,GRAPH_SUCCESS: 成功
- result: 解包结果,元素个数为 unpackLen
*/
template <size_t unpackLen, typename T>
static inline std::tuple<ge::graphStatus, std::array<int64_t, unpackLen>> UnpackAdaptDimListIntAttr(
T* context, const char* attrName, const gert::TypedContinuousVector<int64_t>*& vec,
std::function<bool(int64_t)> elementValidator)
{
if (!CheckUnpackAdaptDimListIntAttr<unpackLen>(context, attrName, vec, elementValidator)) {
return {ge::GRAPH_FAILED, {}};
}
std::array<int64_t, unpackLen> value{};
size_t partLen = unpackLen / vec->GetSize();
for (size_t i = 0; i < unpackLen; ++i) {
value[i] = vec->GetData()[i / partLen];
}
return {ge::GRAPH_SUCCESS, value};
}
* @brief 解包固定长度的 ListInt 类型属性
* @tparam unpackLen 解包长度
* @tparam T context类型
* @param [in] context infershape 或 tiling 上下文
* @param [in] attrName 属性名称
* @param [in] vec 属性vector
* @param [in] elementValidator 元素校验方法,对元素挨个进行校验
* @param [out] args 解包结果,参数个数为 unpackLen
* @return 执行结果,GRAPH_FAILED: 失败,GRAPH_SUCCESS: 成功
*/
template <size_t unpackLen, typename T, typename... Args>
static inline ge::graphStatus UnpackFixedDimListIntAttr(
T* context, const char* attrName, const gert::TypedContinuousVector<int64_t>*& vec,
std::function<bool(int64_t)> elementValidator, Args&... args)
{
static_assert(unpackLen > 0, "unpackLen should be positive");
static_assert(sizeof...(Args) == unpackLen, "Number of arguments must match template paremeter unpackLen");
static_assert((std::is_same_v<Args, int64_t> && ...), "All arguments must be of type int64_t");
if (!CheckUnpackFixedDimListIntAttr<unpackLen>(context, attrName, vec, elementValidator)) {
return ge::GRAPH_FAILED;
}
UnsafeUnpackListIntAttr(vec, std::make_index_sequence<unpackLen>{}, args...);
return ge::GRAPH_SUCCESS;
}
* @brief 解包自适应长度的 ListInt 类型属性,将输入的元素自动扩展到解包长度
* @tparam unpackLen 解包长度
* @tparam T context类型
* @param [in] context infershape 或 tiling 上下文
* @param [in] attrName 属性名称
* @param [in] vec 属性vector
* @param [in] elementValidator 元素校验方法,对元素挨个进行校验
* @param [out] args 解包结果,参数个数为 unpackLen
* @return 执行结果,GRAPH_FAILED: 失败,GRAPH_SUCCESS: 成功
*/
template <size_t unpackLen, typename T, typename... Args>
static inline ge::graphStatus UnpackAdaptDimListIntAttr(
T* context, const char* attrName, const gert::TypedContinuousVector<int64_t>*& vec,
std::function<bool(int64_t)> elementValidator, Args&... args)
{
static_assert(unpackLen > 0, "unpackLen should be positive");
static_assert(sizeof...(Args) == unpackLen, "Number of arguments must match template paremeter unpackLen");
static_assert((std::is_same_v<Args, int64_t> && ...), "All arguments must be of type int64_t");
if (!CheckUnpackAdaptDimListIntAttr<unpackLen>(context, attrName, vec, elementValidator)) {
return ge::GRAPH_FAILED;
}
UnsafeUnpackListIntAttr(vec, std::make_index_sequence<unpackLen>{}, args...);
return ge::GRAPH_SUCCESS;
}
* @brief 按NCHW顺序获取图像数据shape的维度,仅支持 NCHW/NHWC 格式
* @tparam T context类型
* @param [in] context infershape 或 tiling 上下文
* @param [in] shape 输入形状
* @param [in] format 输入数据格式
* @return 结果
* - status: 执行结果,GRAPH_FAILED: 失败,GRAPH_SUCCESS: 成功
* - result: 按NCHW顺序的shape的维度
*/
template <typename T>
static inline std::tuple<ge::graphStatus, std::array<int64_t, NHWC_DIM_NUM>> GetImgDataDimsByNCHWOrder(
T* context, const gert::Shape& shape, const ge::Format& format)
{
std::array<int64_t, NHWC_DIM_NUM> dims{0};
int64_t dimNum = shape.GetDimNum();
if (unlikely(dimNum != NHWC_DIM_NUM)) {
OP_LOGE(context, "input shape dim num should be 4, but got %ld", dimNum);
return {ge::GRAPH_FAILED, dims};
}
if (format == ge::Format::FORMAT_NCHW) {
return {
ge::GRAPH_SUCCESS,
{
shape.GetDim(NCHW_N_DIM),
shape.GetDim(NCHW_C_DIM),
shape.GetDim(NCHW_H_DIM),
shape.GetDim(NCHW_W_DIM),
}};
}
if (format == ge::Format::FORMAT_NHWC) {
return {
ge::GRAPH_SUCCESS,
{
shape.GetDim(NHWC_N_DIM),
shape.GetDim(NHWC_C_DIM),
shape.GetDim(NHWC_H_DIM),
shape.GetDim(NHWC_W_DIM),
}};
}
OP_LOGE(context, "The input data format must be NCHW or NHWC, but got %s", ge::GetFormatName(format));
return {ge::GRAPH_FAILED, dims};
}
* @brief 按NCHW顺序获取图像数据shape的维度,仅支持 NCHW/NHWC 格式
* @tparam T context类型
* @param [in] context infershape 或 tiling 上下文
* @param [in] shape 输入形状
* @param [in] format 输入数据格式
* @param [out] n 输出N
* @param [out] c 输出C
* @param [out] h 输出H
* @param [out] w 输出W
* @return 执行结果,GRAPH_FAILED: 失败,GRAPH_SUCCESS: 成功
*/
template <typename T>
static inline ge::graphStatus GetImgDataDimsByNCHWOrder(
T* context, const gert::Shape& shape, const ge::Format& format, int64_t& n, int64_t& c, int64_t& h, int64_t& w)
{
int64_t dimNum = shape.GetDimNum();
if (unlikely(dimNum != NHWC_DIM_NUM)) {
OP_LOGE(context, "input shape dim num should be 4, but got %ld", dimNum);
return ge::GRAPH_FAILED;
}
if (format == ge::Format::FORMAT_NCHW) {
n = shape.GetDim(NCHW_N_DIM);
c = shape.GetDim(NCHW_C_DIM);
h = shape.GetDim(NCHW_H_DIM);
w = shape.GetDim(NCHW_W_DIM);
return ge::GRAPH_SUCCESS;
}
if (format == ge::Format::FORMAT_NHWC) {
n = shape.GetDim(NHWC_N_DIM);
c = shape.GetDim(NHWC_C_DIM);
h = shape.GetDim(NHWC_H_DIM);
w = shape.GetDim(NHWC_W_DIM);
return ge::GRAPH_SUCCESS;
}
OP_LOGE(context, "The input data format must be NCHW or NHWC, but got %s", ge::GetFormatName(format));
return ge::GRAPH_FAILED;
}
}
#endif