* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "control_kernel.h"
#include <cmath>
#include "common/ge_inner_attrs.h"
#include "common/checker.h"
#include "graph/ge_error_codes.h"
#include "graph/types.h"
#include "graph/utils/type_utils.h"
#include "exe_graph/runtime/storage_shape.h"
#include "exe_graph/runtime/tensor.h"
#include "register/kernel_registry.h"
#include "runtime/mem.h"
#include "exe_graph/runtime/gert_tensor_data.h"
#include "graph_metadef/common/ge_common/util.h"
namespace gert {
namespace kernel {
namespace {
// Run function for kernels that have nothing to execute: ignores the context and reports success.
ge::graphStatus EmptyFunc(KernelContext *context) {
  static_cast<void>(context);  // the context is deliberately unused
  const auto status = ge::GRAPH_SUCCESS;
  return status;
}
// Interprets the memory behind `data` as a single value of type T and reports whether it is non-zero.
// `data` must point to a readable T; no null check is performed here.
template <typename T>
bool IntegerToBool(const void *data) {
  const T *typed = static_cast<const T *>(data);
  const T zero = T(0);
  return !(*typed == zero);
}
// bool specialization: the stored value already is the answer, so it is returned unchanged.
template <>
bool IntegerToBool<bool>(const void *data) {
  const bool *flag = static_cast<const bool *>(data);
  return *flag;
}
// Interprets the memory behind `data` as a single floating-point value of type T.
// The value counts as true only when its magnitude is strictly greater than
// std::numeric_limits<T>::epsilon(); zero and values within epsilon of zero are false.
template <typename T>
bool FloatToBool(const void *data) {
  const T value = *static_cast<const T *>(data);
  const T threshold = std::numeric_limits<T>::epsilon();
  return std::fabs(value) > threshold;
}
// Lookup table indexed by the numeric value of ge::DataType. Each populated slot converts one
// element of that type to bool; slots for data types not listed below stay empty.
const std::vector<std::function<bool(const void *)>> kBoolConverter = []() noexcept {
  using Converter = std::function<bool(const void *)>;
  // One default-constructed (empty) converter per data type value below DT_MAX.
  std::vector<Converter> table(static_cast<size_t>(ge::DT_MAX));

  // Floating-point types: true when the magnitude exceeds epsilon.
  table[ge::DT_FLOAT] = &FloatToBool<ge::float32_t>;
  table[ge::DT_DOUBLE] = &FloatToBool<ge::float64_t>;

  // Signed integer types: true when non-zero.
  table[ge::DT_INT8] = &IntegerToBool<int8_t>;
  table[ge::DT_INT16] = &IntegerToBool<int16_t>;
  table[ge::DT_INT32] = &IntegerToBool<int32_t>;
  table[ge::DT_INT64] = &IntegerToBool<int64_t>;

  // Unsigned integer types: true when non-zero.
  table[ge::DT_UINT8] = &IntegerToBool<uint8_t>;
  table[ge::DT_UINT16] = &IntegerToBool<uint16_t>;
  table[ge::DT_UINT32] = &IntegerToBool<uint32_t>;
  table[ge::DT_UINT64] = &IntegerToBool<uint64_t>;

  // bool is passed through unchanged.
  table[ge::DT_BOOL] = &IntegerToBool<bool>;
  return table;
}();
/**
 * Derives a boolean condition from a condition tensor.
 *
 * - Scalar origin shape: the single element at `tensor_addr` is converted to bool with the
 *   converter registered for `data_type` in kBoolConverter.
 * - Any other origin shape: the condition is true unless at least one dimension is 0;
 *   the tensor data is not read in this case.
 *
 * @param tensor_addr address of the tensor data; must be non-null for scalar tensors
 * @param shape       storage shape whose origin shape selects the evaluation mode
 * @param data_type   element type of the tensor; only consulted for scalar tensors
 * @param cond_state  receives the computed condition on success
 * @return ge::GRAPH_SUCCESS on success; a failure status (via GE_ASSERT_*) when the data type
 *         has no converter or the scalar data pointer is null
 */
ge::graphStatus GetCondState(const void *tensor_addr, const StorageShape &shape, const ge::DataType &data_type,
                             bool &cond_state) {
  const auto &origin_shape = shape.GetOriginShape();
  if (origin_shape.IsScalar()) {
    const auto numeral_dtype = static_cast<size_t>(data_type);
    GE_ASSERT_TRUE(numeral_dtype < kBoolConverter.size(), "Unsupported cond data type %s(%zu)",
                   ge::TypeUtils::DataTypeToSerialString(data_type).c_str(), numeral_dtype);
    // The table is sized for every data type but only populated for a subset. Calling an empty
    // std::function throws std::bad_function_call, so reject unpopulated slots explicitly.
    const auto &converter = kBoolConverter[numeral_dtype];
    GE_ASSERT_TRUE(static_cast<bool>(converter), "Unsupported cond data type %s(%zu)",
                   ge::TypeUtils::DataTypeToSerialString(data_type).c_str(), numeral_dtype);
    GE_ASSERT_NOTNULL(tensor_addr, "Failed to generate index as tensor data is nullptr");
    cond_state = converter(tensor_addr);
    return ge::GRAPH_SUCCESS;
  }

  // Non-scalar tensor: the condition is false when any dimension is 0.
  cond_state = true;
  const size_t dim_num = origin_shape.GetDimNum();
  for (size_t i = 0U; i < dim_num; ++i) {
    if (origin_shape.GetDim(i) == 0) {
      cond_state = false;
      break;
    }
  }
  return ge::GRAPH_SUCCESS;
}
}
/**
 * Kernel that turns the condition tensor of an If into a branch index.
 *
 * Input 0 is the condition tensor; output 0 receives 0 when the condition evaluates to true
 * and 1 otherwise.
 *
 * @return ge::GRAPH_SUCCESS on success, ge::GE_GRAPH_PARAM_NULLPTR when the tensor or the
 *         output slot is missing, or a failure status when the condition cannot be evaluated
 */
ge::graphStatus GenIndexForIf(KernelContext *context) {
  const auto cond_tensor = context->GetInputPointer<Tensor>(0UL);
  const auto branch_slot = context->GetOutputPointer<uint32_t>(0UL);
  const bool io_ready = (cond_tensor != nullptr) && (branch_slot != nullptr);
  if (!io_ready) {
    GELOGE(ge::GE_GRAPH_PARAM_NULLPTR, "Failed to generate index for if, no tensor or index input");
    return ge::GE_GRAPH_PARAM_NULLPTR;
  }

  bool is_true = false;
  GE_ASSERT_SUCCESS(
      GetCondState(cond_tensor->GetAddr(), cond_tensor->GetShape(), cond_tensor->GetDataType(), is_true),
      "Failed to get cond state for If");

  // Index 0 is selected for a true condition, index 1 for a false one.
  if (is_true) {
    *branch_slot = 0U;
  } else {
    *branch_slot = 1U;
  }
  return ge::GRAPH_SUCCESS;
}
/**
 * Kernel that maps the branch-index tensor of a Case onto a valid branch number.
 *
 * Input 0 is a tensor whose first int32 element is the requested branch, input 1 is the number
 * of branches, and output 0 receives the selected branch. A requested index that is negative or
 * not smaller than the branch count selects the last branch (the default branch).
 *
 * @return ge::GRAPH_SUCCESS on success, ge::GE_GRAPH_PARAM_NULLPTR when the tensor, its data or
 *         the output slot is missing, ge::PARAM_INVALID when the branch count is 0
 */
ge::graphStatus GenIndexForCase(KernelContext *context) {
  const auto index_tensor = context->GetInputPointer<Tensor>(0UL);
  const auto branch_count = context->GetInputValue<uint32_t>(1UL);
  const auto selected = context->GetOutputPointer<uint32_t>(0UL);

  const int32_t *requested_ptr = (index_tensor != nullptr) ? index_tensor->GetData<int32_t>() : nullptr;
  if ((requested_ptr == nullptr) || (selected == nullptr)) {
    GELOGE(ge::GE_GRAPH_PARAM_NULLPTR, "Failed to generate index for case, tensor or index is nullptr");
    return ge::GE_GRAPH_PARAM_NULLPTR;
  }
  if (branch_count == 0U) {
    GELOGE(ge::PARAM_INVALID, "Failed to generate index for case, at least one default branch");
    return ge::PARAM_INVALID;
  }

  const int32_t requested = *requested_ptr;
  const bool in_range = (requested >= 0) && (static_cast<uint32_t>(requested) < branch_count);
  if (in_range) {
    *selected = static_cast<uint32_t>(requested);
    return ge::GRAPH_SUCCESS;
  }

  // Out-of-range requests fall back to the last branch.
  *selected = branch_count - 1U;
  GELOGD("The branch index %d is less than 0 or reach the max num of branch num(%u), select the default branch %u",
         requested, branch_count, *selected);
  return ge::GRAPH_SUCCESS;
}
/**
 * Kernel that evaluates the loop condition of a While.
 *
 * Input 0 is the tensor data, input 1 its storage shape, input 2 its data type; output 0
 * receives the resulting boolean. The output is only written when evaluation succeeds.
 *
 * @return ge::GRAPH_SUCCESS on success, otherwise a failure status raised by the GE_ASSERT_* checks
 */
ge::graphStatus GenCondForWhile(KernelContext *context) {
  const auto cond_data = context->GetInputPointer<GertTensorData>(0U);
  GE_ASSERT_NOTNULL(cond_data);
  const auto cond_shape = context->GetInputPointer<StorageShape>(1U);
  GE_ASSERT_NOTNULL(cond_shape);
  const auto cond_dtype = context->GetInputValue<ge::DataType>(2U);
  const auto result = context->GetOutputPointer<bool>(0U);
  GE_ASSERT_NOTNULL(result);

  bool keep_looping = false;
  GE_ASSERT_SUCCESS(GetCondState(cond_data->GetAddr(), *cond_shape, cond_dtype, keep_looping),
                    "Failed to get cond state for While");
  *result = keep_looping;
  return ge::GRAPH_SUCCESS;
}
/**
 * Kernel that selects the receiver to notify from a branch index.
 *
 * Input 0 is the branch index; output 0 is the NotifySignal whose target_receiver is set to
 * receivers[branch index].
 *
 * @return ge::GRAPH_SUCCESS on success, ge::GE_GRAPH_PARAM_NULLPTR when the signal is missing,
 *         ge::PARAM_INVALID when the index is not smaller than receiver_num
 */
ge::graphStatus SwitchNotifyKernel(KernelContext *context) {
  const auto selected_branch = context->GetInputValue<uint32_t>(0UL);
  const auto notify = context->GetOutputPointer<NotifySignal>(0UL);
  if (notify == nullptr) {
    GELOGE(ge::GE_GRAPH_PARAM_NULLPTR, "Failed to execute SwitchNotifyKernel, signal is nullptr");
    return ge::GE_GRAPH_PARAM_NULLPTR;
  }

  if (selected_branch < notify->receiver_num) {
    notify->target_receiver = notify->receivers[selected_branch];
    return ge::GRAPH_SUCCESS;
  }

  GELOGE(ge::PARAM_INVALID, "Invalid branch index %u, receiver num %zu", selected_branch, notify->receiver_num);
  return ge::PARAM_INVALID;
}
/**
 * Outputs creator for the SwitchNotify / CondSwitchNotify kernels.
 *
 * Reads the Watcher address stored in the node attribute ge::kWatcherAddress, reorders the
 * watcher's node_ids in place and stores a NotifySignal built on that array in output 0.
 *
 * Layout of watcher->node_ids (N = number of out data edges, W = total watch_num):
 *   before: [ id0, id1 .. id(N-1), idN .. id(W-1) ]
 *   after:  [ id0, idN .. id(W-1), id1 .. id(N-1) ]
 * watch_num is shrunk to cover only the leading part (id0 plus the former tail), and the trailing
 * N-1 entries are handed to the NotifySignal as its receivers.
 * NOTE(review): the former tail presumably holds the control-edge watchers - confirm.
 *
 * @param node    node that carries the watcher address attribute
 * @param context kernel context whose output 0 receives the NotifySignal
 * @return ge::GRAPH_SUCCESS on success, ge::GRAPH_FAILED or a GE_ASSERT_* failure status otherwise
 */
ge::graphStatus CreateSwitchNotifyOutputs(const ge::FastNode *node, KernelContext *context) {
  Watcher *watcher = nullptr;
  int64_t address_buffer = 0;
  if (!ge::AttrUtils::GetInt(node->GetOpDescBarePtr(), ge::kWatcherAddress, address_buffer)) {
    GELOGE(ge::FAILED, "Failed get attr '%s' from node %s", ge::kWatcherAddress, node->GetName().c_str());
    return ge::GRAPH_FAILED;
  }
  // The attribute stores the pointer value as an integer; copy the bits back into a pointer.
  if (memcpy_s(&watcher, sizeof(Watcher *), &address_buffer, sizeof(Watcher *)) != EOK) {
    GELOGE(ge::FAILED, "Failed copy watcher address from src buffer");
    return ge::GRAPH_FAILED;
  }
  // An attribute value of 0 yields a null watcher; fail instead of dereferencing it.
  GE_ASSERT_NOTNULL(watcher, "Failed to create switch notify outputs as watcher address is nullptr");
  GE_ASSERT_EQ(watcher->watch_num, node->GetAllOutEdgesSize());

  // Fetch the output slot before touching the watcher so a failure here leaves it unmodified.
  auto chain = context->GetOutput(0U);
  GE_ASSERT_NOTNULL(chain);

  auto watchers_num = watcher->watch_num;
  auto outputs_num = node->GetAllOutDataEdgesSize();
  // node_ids[0] is always kept and `outputs_num - 1U` receivers are exposed below, so at least one
  // data output is required; without it the subtraction wraps around and node_ids is overrun.
  GE_ASSERT_TRUE(static_cast<size_t>(outputs_num) >= 1U,
                 "Node %s has no output data edge, at least one is required", node->GetName().c_str());
  // The reorder below reads node_ids[1 .. outputs_num - 1]; they must lie inside the watched range.
  GE_ASSERT_TRUE(static_cast<size_t>(outputs_num) <= static_cast<size_t>(watchers_num),
                 "Output data edges of node %s exceed its watchers", node->GetName().c_str());

  std::vector<NodeIdentity> optimized_watcher;
  optimized_watcher.reserve(watchers_num);
  // Leading part: the first entry followed by every entry past the data outputs.
  optimized_watcher.push_back(watcher->node_ids[0]);
  for (size_t i = outputs_num; i < watchers_num; i++) {
    optimized_watcher.push_back(watcher->node_ids[i]);
  }
  // Only the leading part remains visible through watch_num.
  watcher->watch_num = optimized_watcher.size();
  // The remaining data-output entries will be stored right after the leading part.
  auto data_watchers = &watcher->node_ids[optimized_watcher.size()];
  for (size_t i = 1U; i < outputs_num; i++) {
    optimized_watcher.push_back(watcher->node_ids[i]);
  }
  GE_ASSERT_EOK(memcpy_s(watcher->node_ids, watchers_num * sizeof(NodeIdentity), optimized_watcher.data(),
                         optimized_watcher.size() * sizeof(NodeIdentity)));

  auto signal = new (std::nothrow) NotifySignal(*watcher->node_ids, data_watchers, outputs_num - 1U);
  if (signal == nullptr) {
    GELOGE(ge::FAILED, "Failed create outputs for node %s", node->GetName().c_str());
    return ge::GRAPH_FAILED;
  }
  // The chain takes ownership of the signal and releases it with the default deleter.
  chain->SetWithDefaultDeleter(signal);
  return ge::GRAPH_SUCCESS;
}
/**
 * Kernel that selects the receiver to notify from a boolean condition.
 *
 * Input 0 is the condition; a true condition selects receivers[0], a false one receivers[1].
 * Output 0 is the NotifySignal whose target_receiver is updated.
 *
 * @return ge::GRAPH_SUCCESS on success, otherwise a failure status raised by the GE_ASSERT_* checks
 */
ge::graphStatus CondSwitchNotifyKernel(KernelContext *context) {
  const auto notify = context->GetOutputPointer<NotifySignal>(0);
  GE_ASSERT_NOTNULL(notify);

  const bool cond = context->GetInputValue<bool>(0);
  size_t receiver_index = 1U;
  if (cond) {
    receiver_index = 0U;
  }
  GE_ASSERT_TRUE(notify->receiver_num > receiver_index);

  notify->target_receiver = notify->receivers[receiver_index];
  return ge::GRAPH_SUCCESS;
}
// Kernel registrations for the control-flow kernels defined in this file.
// Kernels registered with EmptyFunc perform no work when run.
REGISTER_KERNEL(BranchPivot).RunFunc(EmptyFunc);
REGISTER_KERNEL(BranchDone).RunFunc(EmptyFunc);
REGISTER_KERNEL(WatcherPlaceholder).RunFunc(EmptyFunc);
REGISTER_KERNEL(WaitAnyone).RunFunc(EmptyFunc);
// Kernels that compute a branch index or a loop condition from a tensor.
REGISTER_KERNEL(GenIndexForCase).RunFunc(GenIndexForCase);
REGISTER_KERNEL(GenIndexForIf).RunFunc(GenIndexForIf);
REGISTER_KERNEL(GenCondForWhile).RunFunc(GenCondForWhile);
REGISTER_KERNEL(Enter).RunFunc(EmptyFunc);
REGISTER_KERNEL(Exit).RunFunc(EmptyFunc);
REGISTER_KERNEL(NoOp).RunFunc(EmptyFunc);
REGISTER_KERNEL(UnfedData).RunFunc(EmptyFunc);
// Both notify kernels share CreateSwitchNotifyOutputs to build their NotifySignal output.
REGISTER_KERNEL(SwitchNotify).RunFunc(SwitchNotifyKernel).OutputsCreator(CreateSwitchNotifyOutputs);
REGISTER_KERNEL(CondSwitchNotify).RunFunc(CondSwitchNotifyKernel).OutputsCreator(CreateSwitchNotifyOutputs);
}
}