* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "user_graphs_manager.h"
#include "common/checker.h"
#include "graph/utils/graph_utils_ex.h"
#include "common/memory/tensor_trans_utils.h"
#include "api/aclgrph/option_utils.h"
namespace ge {
/**
 * Decides whether calls for a user graph are driven through the slice-schedule path.
 * @param user_graph_id id the graph was registered with in AddGraph
 * @return false when slice scheduling is globally disabled (EnableSliceSchedule()) or when the id was recorded in
 *         slice_schedule_unsupported_set_ (AddGraph does this for graphs that cannot be slice scheduled); true otherwise
 */
bool UserGraphsManager::ShouldUseSliceSchedule(uint32_t user_graph_id) const {
  if (EnableSliceSchedule()) {
    std::lock_guard<std::mutex> guard(user_graph_ctrl_mutex_);
    const bool marked_unsupported = (slice_schedule_unsupported_set_.count(user_graph_id) != 0U);
    return !marked_unsupported;
  }
  return false;
}
/**
 * Registers a user graph with the manager.
 * - Slice scheduling disabled: the graph is handed to graph_manager_ unchanged.
 * - Graph not slice-schedulable (IsGraphSupportSliceSchedule): the id is recorded in slice_schedule_unsupported_set_
 *   so later calls are routed to graph_manager_, and the graph is added there.
 * - Otherwise: the local OMG context and graph options are set, then a UserGraphControl for the id is created on first
 *   add (or reused on later adds) and AddGraphInstance() is invoked on it.
 * NOTE(review): in the fallback path the id is inserted into slice_schedule_unsupported_set_ before
 * graph_manager_.AddGraph runs, so it stays recorded even if that call fails.
 * @param user_graph_id id chosen by the caller for this graph
 * @param graph user graph to register
 * @param options graph options forwarded to graph_manager_ or the UserGraphControl
 * @return SUCCESS, the status of graph_manager_.AddGraph, or an error from the GE_ASSERT_* checks
 */
Status UserGraphsManager::AddGraph(uint32_t user_graph_id, const Graph &graph,
                                   const std::map<std::string, std::string> &options) {
  if (!EnableSliceSchedule()) {
    return graph_manager_.AddGraph(user_graph_id, graph, options, domi::GetContext());
  }
  auto compute_graph = GraphUtilsEx::GetComputeGraph(graph);
  GE_ASSERT_NOTNULL(compute_graph);
  if (!IsGraphSupportSliceSchedule(compute_graph, options)) {
    // Remember the fallback so ShouldUseSliceSchedule() routes every later call for this id to graph_manager_.
    std::lock_guard<std::mutex> guard(user_graph_ctrl_mutex_);
    (void)slice_schedule_unsupported_set_.insert(user_graph_id);
    GELOGI("Graph[%u] does not support slice schedule, fallback to traditional mode.", user_graph_id);
    return graph_manager_.AddGraph(user_graph_id, graph, options, domi::GetContext());
  }

  SetLocalOmgContext(domi::GetContext());
  GetThreadLocalContext().SetGraphOption(options);

  std::lock_guard<std::mutex> guard(user_graph_ctrl_mutex_);
  auto iter = ids_to_user_graph_ctrl_.find(user_graph_id);
  if (iter != ids_to_user_graph_ctrl_.end()) {
    // Id already has a controller: only register one more graph instance on it.
    GE_ASSERT_SUCCESS(iter->second->AddGraphInstance());
    return SUCCESS;
  }

  // First add for this id: build the controller and publish it only after its first instance was added.
  auto user_graph_ctrl =
      MakeUnique<UserGraphControl>(user_graph_id, compute_graph, compile_context_, graph_manager_, options);
  GE_ASSERT_NOTNULL(user_graph_ctrl);
  GE_ASSERT_SUCCESS(user_graph_ctrl->AddGraphInstance());
  ids_to_user_graph_ctrl_[user_graph_id] = std::move(user_graph_ctrl);
  return SUCCESS;
}
/**
 * Builds a user graph ahead of execution.
 * Graphs that are not on the slice-schedule path are built by graph_manager_; the root model it produces is not
 * used here. For slice-scheduled graphs this method does nothing and reports SUCCESS.
 * @param user_graph_id id the graph was registered with
 * @param inputs input tensors forwarded to graph_manager_.BuildGraph
 * @param session_id session id forwarded to graph_manager_.BuildGraph
 * @return status of graph_manager_.BuildGraph, or SUCCESS on the slice-schedule path
 */
Status UserGraphsManager::BuildGraph(uint32_t user_graph_id, const std::vector<GeTensor> &inputs,
                                     uint64_t session_id) const {
  if (ShouldUseSliceSchedule(user_graph_id)) {
    // No build step is performed here for slice-scheduled graphs.
    return SUCCESS;
  }
  GeRootModelPtr ge_root_model;
  return graph_manager_.BuildGraph(user_graph_id, inputs, ge_root_model, session_id, true);
}
/**
 * Submits an asynchronous execution of a user graph.
 * Graphs that are not on the slice-schedule path are forwarded to graph_manager_.RunGraphAsync. For slice-scheduled
 * graphs a UserGraphExecution task is created (the inputs are moved into it) and handed to the graph's
 * UserGraphControl.
 * @param user_graph_id id the graph was registered with
 * @param inputs input tensors; moved into graph_manager_ or into the execution task
 * @param session_id session id forwarded to graph_manager_ or to the execution task
 * @param callback callback forwarded to graph_manager_ or to the execution task
 * @return SUCCESS once the task has been handed over, the status of graph_manager_.RunGraphAsync, or an error from the
 *         GE_ASSERT_* checks (unknown graph id, task allocation failure)
 */
Status UserGraphsManager::RunGraphAsync(uint32_t user_graph_id, std::vector<gert::Tensor> &&inputs,
                                        uint64_t session_id, const RunAsyncCallbackV2 &callback) {
  if (!ShouldUseSliceSchedule(user_graph_id)) {
    return graph_manager_.RunGraphAsync(user_graph_id, std::move(inputs), session_id, callback);
  }
  // Same locked lookup (and miss logging) used by the other slice-schedule entry points.
  UserGraphControl *user_graph_control = GetUserGraphControl(user_graph_id);
  GE_ASSERT_NOTNULL(user_graph_control, "Failed to find user graph ctrl of graph[%u], session[%lu]", user_graph_id,
                    session_id);
  auto exe_task = MakeUnique<UserGraphExecution>(user_graph_id, std::move(inputs), callback, session_id);
  GE_ASSERT_NOTNULL(exe_task);
  // NOTE(review): any result of UserGraphControl::RunGraphAsync is not inspected here — confirm this is intended.
  user_graph_control->RunGraphAsync(exe_task);
  return SUCCESS;
}
/**
 * Looks up the UserGraphControl registered for a user graph id while holding user_graph_ctrl_mutex_.
 * @param user_graph_id id the graph was registered with
 * @return non-owning pointer to the controller (the map keeps ownership); when the id is unknown the failure is logged
 *         by GE_ASSERT_TRUE and its error result is returned, which callers check with GE_ASSERT_NOTNULL
 * NOTE(review): callers use the pointer after the lock is released, so a concurrent RemoveGraph/Finalize could
 * invalidate it — confirm these operations are serialized by the callers.
 */
UserGraphControl* UserGraphsManager::GetUserGraphControl(uint32_t user_graph_id) {
  std::lock_guard<std::mutex> guard(user_graph_ctrl_mutex_);
  auto iter = ids_to_user_graph_ctrl_.find(user_graph_id);
  GE_ASSERT_TRUE(iter != ids_to_user_graph_ctrl_.end(), "Failed to find user graph ctrl of graph[%u]", user_graph_id);
  return iter->second.get();
}
/**
 * Compiles a user graph.
 * Graphs that are not on the slice-schedule path are compiled by graph_manager_ (with the given inputs). For
 * slice-scheduled graphs compilation is delegated to the graph's UserGraphControl; the inputs are not used there.
 * @param user_graph_id id the graph was registered with
 * @param session_id session id forwarded to the compiler
 * @param inputs input tensors, used only on the graph_manager_ path
 * @return status of graph_manager_.CompileGraph, SUCCESS, or an error from the GE_ASSERT_* checks
 */
Status UserGraphsManager::CompileGraph(uint32_t user_graph_id, uint64_t session_id, const vector<ge::Tensor> &inputs) {
  if (!ShouldUseSliceSchedule(user_graph_id)) {
    return graph_manager_.CompileGraph(user_graph_id, session_id, inputs);
  }
  UserGraphControl *user_graph_control = GetUserGraphControl(user_graph_id);
  // The message previously contained an empty "session[]" placeholder; the session id is now actually formatted.
  GE_ASSERT_NOTNULL(user_graph_control, "Failed to find user graph ctrl of graph[%u], session[%lu]", user_graph_id,
                    session_id);
  GE_ASSERT_SUCCESS(user_graph_control->CompileGraph(session_id));
  return SUCCESS;
}
/**
 * Retrieves the compiled-graph summary of a user graph.
 * Graphs that are not on the slice-schedule path are answered by graph_manager_; otherwise the summary comes from the
 * graph's UserGraphControl.
 * @param user_graph_id id the graph was registered with
 * @param summary output: receives the summary
 * @return status of graph_manager_.GetCompiledGraphSummary, SUCCESS, or an error when no controller exists
 */
Status UserGraphsManager::GetCompiledGraphSummary(uint32_t user_graph_id, CompiledGraphSummaryPtr &summary) {
  if (!ShouldUseSliceSchedule(user_graph_id)) {
    return graph_manager_.GetCompiledGraphSummary(user_graph_id, summary);
  }
  UserGraphControl *user_graph_control = GetUserGraphControl(user_graph_id);
  // No session is involved in this call, so the former empty "session[]" placeholder is dropped from the message.
  GE_ASSERT_NOTNULL(user_graph_control, "Failed to find user graph ctrl of graph[%u]", user_graph_id);
  summary = user_graph_control->GetCompiledGraphSummary();
  return SUCCESS;
}
/**
 * Loads a compiled user graph for execution on a stream.
 * Graphs that are not on the slice-schedule path are loaded by graph_manager_; otherwise loading is delegated to the
 * graph's UserGraphControl.
 * @param user_graph_id id the graph was registered with
 * @param options load options forwarded unchanged
 * @param stream stream handle forwarded unchanged
 * @return status of graph_manager_.LoadGraph, SUCCESS, or an error from the GE_ASSERT_* checks
 */
Status UserGraphsManager::LoadGraph(const uint32_t user_graph_id, const std::map<AscendString, AscendString> &options,
                                    void *stream) {
  if (!ShouldUseSliceSchedule(user_graph_id)) {
    return graph_manager_.LoadGraph(user_graph_id, options, stream);
  }
  UserGraphControl *user_graph_control = GetUserGraphControl(user_graph_id);
  // No session is involved in this call, so the former empty "session[]" placeholder is dropped from the message.
  GE_ASSERT_NOTNULL(user_graph_control, "Failed to find user graph ctrl of graph[%u]", user_graph_id);
  GE_ASSERT_SUCCESS(user_graph_control->LoadGraph(options, stream));
  return SUCCESS;
}
/**
 * Executes a loaded user graph asynchronously on the given stream.
 * Graphs that are not on the slice-schedule path are executed by graph_manager_ (session_id is not forwarded on that
 * path). For slice-scheduled graphs a UserGraphExecution task is built with the stream, session id, output vector
 * address and the controller's load options, then passed to the graph's UserGraphControl.
 * @param user_graph_id id the graph was registered with
 * @param stream stream handle to execute on
 * @param inputs input tensors
 * @param outputs output tensors; on the slice path its address is stored in the task (rt_outputs), so it must stay
 *        alive as long as the controller may write to it — NOTE(review): confirm the required lifetime
 * @param session_id session id stored in the execution task
 * @return status of the delegated execution, or an error from the GE_ASSERT_* checks
 */
Status UserGraphsManager::ExecuteGraphWithStreamAsync(uint32_t user_graph_id, void *stream,
                                                      const std::vector<gert::Tensor> &inputs,
                                                      std::vector<gert::Tensor> &outputs, uint64_t session_id) {
  if (!ShouldUseSliceSchedule(user_graph_id)) {
    return graph_manager_.ExecuteGraphWithStreamAsync(user_graph_id, stream, inputs, outputs);
  }
  UserGraphControl *user_graph_control = GetUserGraphControl(user_graph_id);
  // The message previously contained an empty "session[]" placeholder; the session id is now actually formatted.
  GE_ASSERT_NOTNULL(user_graph_control, "Failed to find user graph ctrl of graph[%u], session[%lu]", user_graph_id,
                    session_id);
  auto exe_task = MakeUnique<UserGraphExecution>(user_graph_id, inputs, nullptr, session_id);
  GE_ASSERT_NOTNULL(exe_task);
  exe_task->stream = stream;
  // NOTE(review): session_id is also passed to the constructor above; kept in case that overload does not store it.
  exe_task->session_id = session_id;
  exe_task->rt_outputs = &outputs;
  exe_task->load_options = user_graph_control->GetLoadOptions();
  GE_ASSERT_SUCCESS(user_graph_control->ExecuteGraphWithStreamAsync(std::move(exe_task)));
  return SUCCESS;
}
/**
 * Drops every slice-schedule controller and forgets which graph ids fell back to the traditional path.
 * Both containers are cleared while holding user_graph_ctrl_mutex_.
 * NOTE(review): unlike RemoveGraph, UserGraphControl::Finalize() is not called explicitly here — confirm the
 * controller's destructor performs the equivalent cleanup.
 * @return always SUCCESS
 */
Status UserGraphsManager::Finalize() {
  {
    std::lock_guard<std::mutex> guard(user_graph_ctrl_mutex_);
    ids_to_user_graph_ctrl_.clear();
    slice_schedule_unsupported_set_.clear();
  }
  return SUCCESS;
}
/**
 * Unregisters a user graph.
 * Graphs that are not on the slice-schedule path have their fallback marker erased and are removed from
 * graph_manager_. Slice-scheduled graphs have their UserGraphControl finalized and then dropped from the map.
 * Both branches run while holding user_graph_ctrl_mutex_.
 * NOTE(review): a missing controller is logged as PARAM_INVALID but FAILED is returned.
 * @param user_graph_id id the graph was registered with
 * @return status of graph_manager_.RemoveGraph, SUCCESS, FAILED for an unknown id, or an error from GE_ASSERT_SUCCESS
 */
Status UserGraphsManager::RemoveGraph(uint32_t user_graph_id) {
  const bool on_slice_path = ShouldUseSliceSchedule(user_graph_id);
  std::lock_guard<std::mutex> guard(user_graph_ctrl_mutex_);
  if (!on_slice_path) {
    (void)slice_schedule_unsupported_set_.erase(user_graph_id);
    return graph_manager_.RemoveGraph(user_graph_id);
  }

  auto iter = ids_to_user_graph_ctrl_.find(user_graph_id);
  if (iter == ids_to_user_graph_ctrl_.end()) {
    GELOGE(PARAM_INVALID, "Failed to remove graph %u which does not exist.", user_graph_id);
    return FAILED;
  }
  // On Finalize() failure the controller is intentionally left registered.
  GE_ASSERT_SUCCESS(iter->second->Finalize());
  (void)ids_to_user_graph_ctrl_.erase(iter);
  GELOGI("Remove graph %u success.", user_graph_id);
  return SUCCESS;
}
/**
 * Reports whether a user graph must be rebuilt before it can run.
 * Graphs that are not on the slice-schedule path are answered by graph_manager_; otherwise the graph's
 * UserGraphControl is asked. An unknown slice-schedule id is reported as an error and answered with true.
 * @param user_graph_id id the graph was registered with
 * @return true if a rebuild is needed (or the graph is unknown on the slice path)
 */
bool UserGraphsManager::IsGraphNeedRebuild(uint32_t user_graph_id) {
  if (ShouldUseSliceSchedule(user_graph_id)) {
    std::lock_guard<std::mutex> guard(user_graph_ctrl_mutex_);
    auto found = ids_to_user_graph_ctrl_.find(user_graph_id);
    if (found != ids_to_user_graph_ctrl_.end()) {
      return found->second->IsUserGraphNeedRebuild();
    }
    REPORT_INNER_ERR_MSG("E19999", "Graph:%u does not exist, check rebuild invalid", user_graph_id);
    GELOGE(PARAM_INVALID, "Graph %u need rebuild when does not exist.", user_graph_id);
    return true;
  }
  return graph_manager_.IsGraphNeedRebuild(user_graph_id);
}
/**
 * Reads the "compiled" flag of a user graph.
 * Slice-scheduled graphs read it from their UserGraphControl; all other graphs ask graph_manager_.
 * @param user_graph_id id the graph was registered with
 * @param flag output: receives the compiled flag
 * @return SUCCESS, the status of graph_manager_.GetCompiledFlag, or an error when no controller exists
 */
Status UserGraphsManager::GetCompiledFlag(uint32_t user_graph_id, bool &flag) {
  if (ShouldUseSliceSchedule(user_graph_id)) {
    const UserGraphControl *user_graph_control = GetUserGraphControl(user_graph_id);
    GE_ASSERT_NOTNULL(user_graph_control);
    flag = user_graph_control->GetCompiledFlag();
    return SUCCESS;
  }
  return graph_manager_.GetCompiledFlag(user_graph_id, flag);
}
/**
 * Produces the debug JSON dump of a graph by always delegating to graph_manager_, whatever the scheduling mode.
 * NOTE(review): slice-scheduled graphs are routed here too — confirm graph_manager_ can resolve their ids.
 * @param user_graph_id id the graph was registered with
 * @param flags dump flags forwarded unchanged
 * @param json_result output: receives the dump
 * @return status of graph_manager_.DumpDebugJSONPrint
 */
Status UserGraphsManager::DumpDebugJSONPrint(uint32_t user_graph_id, uint32_t flags, AscendString &json_result) {
  const Status dump_ret = graph_manager_.DumpDebugJSONPrint(user_graph_id, flags, json_result);
  return dump_ret;
}
/**
 * Sets the "compiled" flag of a user graph.
 * Slice-scheduled graphs store it in their UserGraphControl; all other graphs forward it to graph_manager_.
 * @param user_graph_id id the graph was registered with
 * @param flag new value of the compiled flag
 * @return SUCCESS, the status of graph_manager_.SetCompiledFlag, or an error when no controller exists
 */
Status UserGraphsManager::SetCompiledFlag(uint32_t user_graph_id, bool flag) {
  if (ShouldUseSliceSchedule(user_graph_id)) {
    UserGraphControl *user_graph_control = GetUserGraphControl(user_graph_id);
    GE_ASSERT_NOTNULL(user_graph_control);
    user_graph_control->SetCompiledFlag(flag);
    return SUCCESS;
  }
  return graph_manager_.SetCompiledFlag(user_graph_id, flag);
}
/**
 * Fetches the OME context of a graph from graph_manager_; no slice-schedule routing is applied.
 * @param graph_id id of the graph to query
 * @param ome_context output: filled by graph_manager_.GetOmeContextByGraphId
 * @return SUCCESS; any non-SUCCESS result from graph_manager_ is caught by GE_ASSERT_SUCCESS, which presumably logs it
 *         and returns its generic error result instead of the original code — NOTE(review): confirm callers do not
 *         need the specific error code
 */
Status UserGraphsManager::GetOmeContextByGraphId(const GraphId &graph_id, OmeContext &ome_context) const {
  GE_ASSERT_SUCCESS(graph_manager_.GetOmeContextByGraphId(graph_id, ome_context));
  return SUCCESS;
}
}