* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "graph/manager/graph_manager_utils.h"
#include <set>
#include "common/checker.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/ge_context.h"
#include "ge/ge_api_types.h"
#include "base/err_msg.h"
#include "common/memory/tensor_trans_utils.h"
namespace ge {
namespace {
const char_t *kFrozenInputIndexes = "ge.exec.frozenInputIndexes";
constexpr size_t kIndexOfFrozenDataIndex = 0UL;
constexpr size_t kIndexOfFrozenDataAddr = 1UL;
constexpr size_t kIndexOfFrozenDataLen = 2UL;
Status IsDigitStrings(const std::vector<std::string> &input_vec) {
for (const auto &input : input_vec) {
for (const char ch : input) {
if (ch != ' ' && (static_cast<bool>(isdigit(static_cast<unsigned char>(ch))) == false)) {
(void)REPORT_PREDEFINED_ERR_MSG("E10001", std::vector<const char_t *>({"parameter", "value", "reason"}),
std::vector<const char_t *>({kFrozenInputIndexes, input.c_str(),
"The frozen input index is not a digit."}));
return GE_GRAPH_OPTIONS_INVALID;
}
}
}
return SUCCESS;
}
}
const char_t *GetRunGraphModeStr(RunGraphMode mode) {
static constexpr const char_t *run_graph_mode_str[] = {"RunGraph", "RunGraphAsync", "RunGraphWithStreamAsync",
"InitRunGraphMode"};
static_assert(sizeof(run_graph_mode_str) / sizeof(run_graph_mode_str[0]) ==
(static_cast<size_t>(RunGraphMode::kRunGraphModeEnd) + 1U), "RunGraphMode enum and string array size mismatch");
const auto index = static_cast<size_t>(mode);
if (index >= static_cast<size_t>(RunGraphMode::kRunGraphModeEnd)) {
return "UnknownRunGraphMode";
}
return run_graph_mode_str[index];
}
GraphNode::GraphNode(const GraphId graph_id) : graph_id_(graph_id), sem_(1U), const_mem_(), feature_mem_() {
std::string opt;
(void)GetContext().GetOption(GRAPH_MAX_PARALLEL_MODEL_NUM, opt);
int32_t max_num = 0;
try {
max_num = std::stoi(opt);
if (max_num > 0) {
max_load_record_ = max_num;
} else {
GELOGW("Option %s param failed:%s", GRAPH_MAX_PARALLEL_MODEL_NUM.c_str(), opt.c_str());
}
} catch (...) {
GELOGW("Option %s param failed:%s", GRAPH_MAX_PARALLEL_MODEL_NUM.c_str(), opt.c_str());
}
GELOGD("[GraphManager] graphMaxParallelModelNum is %u", max_load_record_);
}
GraphNode::~GraphNode() = default;
void GraphNode::Lock() {
(void)sem_.Push(0U);
}
void GraphNode::Unlock() {
uint8_t unused;
(void)sem_.Pop(unused);
(void)unused;
}
void GraphNode::IncreaseLoadCount() {
const std::unique_lock<std::mutex> lock(load_count_mu_);
if (load_record_ == max_load_record_) {
GELOGW("Reach the maximum of load_count:%u", kMaxLoadNum);
return;
}
++load_count_;
}
Status GraphNode::ParseFrozenInputIndex() {
std::string frozen_input;
(void)ge::GetContext().GetOption(kFrozenInputIndexes, frozen_input);
if (frozen_input.empty()) {
return SUCCESS;
}
std::vector<std::string> frozen_input_vec = StringUtils::Split(frozen_input, ';');
for (auto &frozen_info : frozen_input_vec) {
std::vector<std::string> frozen_info_vec = StringUtils::Split(frozen_info, ',');
GE_ASSERT_TRUE((frozen_info_vec.size() == 1UL) || (frozen_info_vec.size() == 3UL),
"frozen info vector size should be 1 or 3 parsed by [%s]", frozen_info.c_str());
GE_ASSERT_SUCCESS(IsDigitStrings(frozen_info_vec),
"There are some invalid characters in frozen option value[%s].", frozen_input.c_str());
int32_t frozen_input_index = -1;
GE_ASSERT_SUCCESS(ConvertToInt32(frozen_info_vec[kIndexOfFrozenDataIndex], frozen_input_index));
GE_ASSERT_TRUE((frozen_input_index >= 0), "Frozen_input_index must be greater than zero: %u", frozen_input_index);
(void)frozen_input_indexes_.insert(static_cast<uint32_t>(frozen_input_index));
if (frozen_info_vec.size() == 1UL) {
GELOGD("Parse frozen input index[%d] success.", frozen_input_index);
continue;
}
uint64_t addr = 0UL;
GE_ASSERT_SUCCESS(ConvertToUint64(frozen_info_vec[kIndexOfFrozenDataAddr], addr));
GE_ASSERT_TRUE(addr != 0UL, "Frozen input addr cannot be nullptr.");
uint64_t len = 0UL;
GE_ASSERT_SUCCESS(ConvertToUint64(frozen_info_vec[kIndexOfFrozenDataLen], len));
GE_ASSERT_TRUE(len != 0UL, "Frozen input length cannot be zero.");
frozen_index_to_node_info_[static_cast<uint32_t>(frozen_input_index)] = std::make_pair(addr, len);
GELOGI("Parse and set frozen addr[%lx] length[%lu] for input index[%d] success.", addr, len, frozen_input_index);
}
return SUCCESS;
}
void GraphNode::SetLoaded() {
--load_count_;
++load_record_;
load_flag_ = true;
}
std::shared_ptr<GraphNode> GraphNode::Fork(uint32_t forked_graph_id) {
GraphNodePtr graph_node = MakeShared<GraphNode>(forked_graph_id);
GE_ASSERT_NOTNULL(graph_node, "[Fork][GraphNode] fail, graph_id:%u", forked_graph_id);
graph_node->origin_graph_id_ = this->graph_id_;
graph_node->options_ = this->options_;
graph_node->context_ = this->context_;
graph_node->graph_ = this->graph_;
graph_node->compute_graph_ = this->compute_graph_;
graph_node->compiled_flag_ = this->compiled_flag_;
graph_node->build_flag_ = this->GetBuildFlag();
if (this->GetBuildFlag()) {
graph_node->compiled_flag_ = true;
}
graph_node->async_ = this->async_;
graph_node->is_specific_stream_ = this->is_specific_stream_;
GE_ASSERT_NOTNULL(ge_root_model_);
auto forked_root_model = ge_root_model_->Fork();
GE_ASSERT_NOTNULL(forked_root_model);
graph_node->ge_root_model_ = forked_root_model;
graph_node->is_feature_base_refreshable_ = this->is_feature_base_refreshable_;
graph_node->const_mem_ = std::make_pair(nullptr, this->const_mem_.second);
graph_node->feature_mem_ = std::make_pair(nullptr, this->feature_mem_.second);
graph_node->refreshable_feature_mem_ = std::make_pair(nullptr, this->refreshable_feature_mem_.second);
graph_node->compiled_summary_ = this->compiled_summary_;
graph_node->net_output_node_ = this->net_output_node_;
graph_node->tensor_sizes_ = this->tensor_sizes_;
graph_node->is_saved_net_output_tensor_info_flag_ = this->is_saved_net_output_tensor_info_flag_;
graph_node->ge_tensor_descs_ = this->ge_tensor_descs_;
graph_node->group_2_communication_nodes_ = this->group_2_communication_nodes_;
return graph_node;
}
SubGraphInfo::SubGraphInfo() : subgraph_ptr_(nullptr), ge_model_ptr_(nullptr) {}
SubGraphInfo::~SubGraphInfo() = default;
GraphModelListener::GraphModelListener() : ModelListener() {}
Status GraphModelListener::OnComputeDone(const uint32_t model_id, const uint32_t data_index, const uint32_t result_code,
std::vector<gert::Tensor> &outputs) {
(void)outputs;
GELOGI(
"[GraphManager] graph compute call back, model_id:%u, task_id:%u, "
"resultCode:%u.",
model_id, data_index, result_code);
const std::lock_guard<std::mutex> lock(mutex_);
result_code_ = result_code;
is_finished_ = true;
condition_.notify_all();
return SUCCESS;
}
uint32_t GraphModelListener::GetResultCode() {
std::unique_lock<std::mutex> lock(mutex_);
if (!is_finished_) {
GELOGI("[GetResultCode] wait model execute finished.");
condition_.wait(lock);
}
if (!is_finished_) {
REPORT_INNER_ERR_MSG("E19999", "Model not run finish");
GELOGE(INTERNAL_ERROR, "[Check][Param] model not run finish.");
return INTERNAL_ERROR;
}
return result_code_;
}
Status GraphModelListener::ResetResult() {
const std::lock_guard<std::mutex> lock(mutex_);
result_code_ = 0U;
is_finished_ = false;
return SUCCESS;
}
void RunAsyncListener::SetCallback(const RunAsyncCallbackV2 &callback) {
(void)sem_v2_.Push(callback);
}
Status RunAsyncListener::OnComputeDone(const uint32_t model_id, const uint32_t data_index, const uint32_t result_code,
std::vector<gert::Tensor> &outputs) {
GELOGI("[GraphManager] run graph async call back, modelId:%u, taskId:%u, resultCode:%u.",
model_id, data_index, result_code);
RunAsyncCallbackV2 callback;
(void)sem_v2_.Pop(callback, 0U);
GE_CHECK_NOTNULL(callback);
callback(result_code, outputs);
return SUCCESS;
}
}