* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "flow_attr_util.h"
#include <cinttypes>
#include "common/checker.h"
#include "graph_metadef/graph/debug/ge_util.h"
#include "data_flow_attr_define.h"
namespace ge {
namespace dflow {
// Dispatch table from data flow input attr type to the static handler that writes that attr onto a tensor desc.
// Attr types that have no entry here are rejected by CheckAttrsIsSupport (non-batch types) and by
// SetAttrsToTensorDesc (lookup assertion).
const std::map<DataFlowAttrType, FlowAttrUtil::SetAttrFunc> FlowAttrUtil::set_attr_funcs_{
    {DataFlowAttrType::TIME_BATCH, &FlowAttrUtil::SetTimeBatchAttr},
    {DataFlowAttrType::COUNT_BATCH, &FlowAttrUtil::SetCountBatchAttr}};
bool FlowAttrUtil::CheckAttrsIsSupport(const std::vector<DataFlowInputAttr> &attrs) {
bool count_batch = false;
bool time_batch = false;
if (attrs.empty()) {
return true;
}
for (size_t i = 0U; i < attrs.size(); ++i) {
if (attrs[i].attr_type == DataFlowAttrType::COUNT_BATCH) {
count_batch = true;
if (time_batch) {
GELOGE(ge::FAILED, "[Check]COUNT_BATCH attr and TIME_BATCH attr cannot be config at the same time.");
return false;
}
} else if (attrs[i].attr_type == DataFlowAttrType::TIME_BATCH) {
time_batch = true;
if (count_batch) {
GELOGE(ge::FAILED, "[Check]COUNT_BATCH attr and TIME_BATCH attr cannot be config at the same time.");
return false;
}
} else {
if (set_attr_funcs_.find(attrs[i].attr_type) == set_attr_funcs_.cend()) {
GELOGE(ge::FAILED, "[Check]Attr type(%u) is not supported.", static_cast<uint32_t>(attrs[i].attr_type));
return false;
}
}
}
return true;
}
/**
 * Validates a CountBatch configuration and stores its fields on tensor_desc as ATTR_NAME_COUNT_BATCH_* attrs.
 * @param attr_value  type-erased config, interpreted as `const CountBatch *`. NOTE(review): the caller must pass
 *                    a CountBatch for COUNT_BATCH attrs; this cannot be checked here.
 * @param tensor_desc tensor desc that receives the count batch attrs.
 * @return GRAPH_SUCCESS when every attr is written. GRAPH_FAILED when batch_size <= 0 or slide_stride lies
 *         outside [0, batch_size]. A null attr_value or a failed attr write bails out through GE_ASSERT_*.
 */
graphStatus FlowAttrUtil::SetCountBatchAttr(const void *const attr_value, GeTensorDescPtr &tensor_desc) {
  GE_ASSERT_NOTNULL(attr_value);
  const CountBatch *const count_batch = static_cast<const CountBatch *>(attr_value);
  // Integer fields are cast to int64_t and printed with PRId64, so each specifier matches its argument
  // regardless of whether int64_t is `long` or `long long` on the target.
  if (count_batch->batch_size <= 0) {
    GELOGE(FAILED, "CountBatch.batch_size should be larger than zero, but got %" PRId64,
           static_cast<int64_t>(count_batch->batch_size));
    return ge::GRAPH_FAILED;
  }
  // A stride of 0 is accepted, and the stride can never exceed the batch size.
  if ((count_batch->slide_stride < 0) || (count_batch->slide_stride > count_batch->batch_size)) {
    GELOGE(FAILED, "CountBatch.slide_stride should be in [0, %" PRId64 "], but got %" PRId64,
           static_cast<int64_t>(count_batch->batch_size), static_cast<int64_t>(count_batch->slide_stride));
    return ge::GRAPH_FAILED;
  }
  GE_ASSERT_TRUE(ge::AttrUtils::SetInt(tensor_desc, ATTR_NAME_COUNT_BATCH_BATCH_SIZE, count_batch->batch_size));
  GE_ASSERT_TRUE(ge::AttrUtils::SetInt(tensor_desc, ATTR_NAME_COUNT_BATCH_SLIDE_STRIDE, count_batch->slide_stride));
  GE_ASSERT_TRUE(ge::AttrUtils::SetInt(tensor_desc, ATTR_NAME_COUNT_BATCH_TIMEOUT, count_batch->timeout));
  GE_ASSERT_TRUE(ge::AttrUtils::SetInt(tensor_desc, ATTR_NAME_COUNT_BATCH_BATCH_DIM, count_batch->batch_dim));
  GE_ASSERT_TRUE(ge::AttrUtils::SetInt(tensor_desc, ATTR_NAME_COUNT_BATCH_FLAG, count_batch->flag));
  GE_ASSERT_TRUE(ge::AttrUtils::SetBool(tensor_desc, ATTR_NAME_COUNT_BATCH_PADDING, count_batch->padding));
  GE_ASSERT_TRUE(ge::AttrUtils::SetBool(tensor_desc, ATTR_NAME_COUNT_BATCH_DROP_REMAINDER,
                                        count_batch->drop_remainder));
  GELOGI("set count batch attr: batch_size(%" PRId64 "), slide_stride(%" PRId64 "), timeout(%" PRId64 "), "
         "batch_dim(%" PRId64 "), flag(%d), padding(%d), drop_remainder(%d)",
         static_cast<int64_t>(count_batch->batch_size), static_cast<int64_t>(count_batch->slide_stride),
         static_cast<int64_t>(count_batch->timeout), static_cast<int64_t>(count_batch->batch_dim), count_batch->flag,
         count_batch->padding, count_batch->drop_remainder);
  return ge::GRAPH_SUCCESS;
}
/**
 * Validates a TimeBatch configuration and stores its fields on tensor_desc as ATTR_NAME_TIME_BATCH_* attrs.
 * @param attr_value  type-erased config, interpreted as `const TimeBatch *`. NOTE(review): the caller must pass
 *                    a TimeBatch for TIME_BATCH attrs; this cannot be checked here.
 * @param tensor_desc tensor desc that receives the time batch attrs.
 * @return GRAPH_SUCCESS when every attr is written. GRAPH_FAILED when batch_dim < -1. A null attr_value or a
 *         failed attr write bails out through GE_ASSERT_*.
 */
graphStatus FlowAttrUtil::SetTimeBatchAttr(const void *const attr_value, GeTensorDescPtr &tensor_desc) {
  GE_ASSERT_NOTNULL(attr_value);
  const TimeBatch *const time_batch = static_cast<const TimeBatch *>(attr_value);
  // -1 is accepted by this check, so the message must say ">= -1" rather than "larger than -1".
  if (time_batch->batch_dim < -1) {
    GELOGE(FAILED, "TimeBatch.batch_dim should be greater than or equal to -1, but got %" PRId64,
           static_cast<int64_t>(time_batch->batch_dim));
    return ge::GRAPH_FAILED;
  }
  GE_ASSERT_TRUE(ge::AttrUtils::SetInt(tensor_desc, ATTR_NAME_TIME_BATCH_TIME_WINDOW, time_batch->time_window));
  GE_ASSERT_TRUE(ge::AttrUtils::SetInt(tensor_desc, ATTR_NAME_TIME_BATCH_TIME_INTERVAL, time_batch->time_interval));
  GE_ASSERT_TRUE(ge::AttrUtils::SetInt(tensor_desc, ATTR_NAME_TIME_BATCH_TIMEOUT, time_batch->timeout));
  GE_ASSERT_TRUE(ge::AttrUtils::SetInt(tensor_desc, ATTR_NAME_TIME_BATCH_BATCH_DIM, time_batch->batch_dim));
  GE_ASSERT_TRUE(ge::AttrUtils::SetInt(tensor_desc, ATTR_NAME_TIME_BATCH_FLAG, time_batch->flag));
  GE_ASSERT_TRUE(ge::AttrUtils::SetBool(tensor_desc, ATTR_NAME_TIME_BATCH_PADDING, time_batch->padding));
  GE_ASSERT_TRUE(ge::AttrUtils::SetBool(tensor_desc, ATTR_NAME_TIME_BATCH_DROP_REMAINDER, time_batch->drop_remainder));
  // Integer fields are cast to int64_t and printed with PRId64 so the specifiers are portable.
  GELOGI("set time batch attr: time_window(%" PRId64 "), time_interval(%" PRId64 "), timeout(%" PRId64 "), "
         "batch_dim(%" PRId64 "), flag(%d), padding(%d), drop_remainder(%d)",
         static_cast<int64_t>(time_batch->time_window), static_cast<int64_t>(time_batch->time_interval),
         static_cast<int64_t>(time_batch->timeout), static_cast<int64_t>(time_batch->batch_dim), time_batch->flag,
         time_batch->padding, time_batch->drop_remainder);
  return ge::GRAPH_SUCCESS;
}
/**
 * Applies every data flow input attr in `attrs` to tensor_desc by dispatching to its handler in set_attr_funcs_.
 * @param attrs       attrs to apply; they are validated with CheckAttrsIsSupport before any handler runs.
 * @param tensor_desc tensor desc each handler writes into.
 * @return GRAPH_SUCCESS when all handlers succeed. Validation failures, missing handlers and handler errors
 *         bail out through GE_ASSERT_*.
 */
graphStatus FlowAttrUtil::SetAttrsToTensorDesc(const std::vector<DataFlowInputAttr> &attrs,
                                               GeTensorDescPtr &tensor_desc) {
  // Reject conflicting or unknown attr types before anything is written to tensor_desc.
  GE_ASSERT_TRUE(CheckAttrsIsSupport(attrs));
  for (const DataFlowInputAttr &input_attr : attrs) {
    const auto handler_it = set_attr_funcs_.find(input_attr.attr_type);
    GE_ASSERT_TRUE(handler_it != set_attr_funcs_.cend(), "Data flow input attr type(%u) does not has process function..",
                   static_cast<uint32_t>(input_attr.attr_type));
    const auto &set_attr_handler = handler_it->second;
    GE_ASSERT_SUCCESS(set_attr_handler(input_attr.attr_value, tensor_desc));
  }
  return ge::GRAPH_SUCCESS;
}
}
}