* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef GE_GRAPH_EXECUTE_MODEL_EXECUTOR_H
#define GE_GRAPH_EXECUTE_MODEL_EXECUTOR_H
#include <thread>
#include "common/model/executor.h"
#include "graph/execute/graph_executor.h"
namespace ge {
class ModelExecutor : public Executor {
public:
ModelExecutor() = default;
Status Initialize(const std::map<std::string, std::string> &options, const uint64_t session_id);
Status Finalize();
Status LoadGraph(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node,
const aclrtStream stream = nullptr) override;
Status UnloadGraph(const GeRootModelPtr &ge_root_model, const uint32_t graph_id) override;
Status PushRunArgs(const std::shared_ptr<RunArgs> &args) override;
Status RunGraph(const GraphNodePtr &graph_node, const GraphId graph_id,
const std::vector<gert::Tensor> &inputs, std::vector<gert::Tensor> &outputs) override;
Status RunGraphWithStream(const GraphNodePtr &graph_node, const GraphId graph_id, aclrtStream const stream,
const std::vector<GeTensor> &inputs, std::vector<GeTensor> &outputs) override;
Status ExecuteGraphWithStream(const GraphNodePtr &graph_node, const GraphId graph_id,
aclrtStream const stream, const std::vector<gert::Tensor> &inputs,
std::vector<gert::Tensor> &outputs) override;
Status DumpDebugJSONPrint(uint32_t model_id, uint32_t graph_id, uint32_t flags,
AscendString &json_result) override;
* @ingroup ge
* @brief Update feature memory base after load graph
* @param [in] graph_node: node of graph.
* @param [in] mem_base: graph feature memory base(without input and output).
* @param [in] size: memory size
* @return Status result of function
*/
Status UpdateFeatureMemoryBase(const GraphNodePtr &graph_node, const uintptr_t mem_base, const size_t size) override;
void StartRunThread();
Status PaRemapped(const GraphNodePtr &graph_node, const uint64_t va, const uint64_t new_pa,
const uint64_t len, std::vector<std::pair<uint64_t, uint64_t>> &cross_ranges) override;
private:
static Status GetDeviceMemorySize(size_t &free_mem, size_t &total_mem_size);
void AddGraphNode(const GraphId graph_id, const GraphNodePtr &graph_node);
void RemoveGraphNode(const GraphId graph_id);
Status ModelLoad(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node, const aclrtStream stream = nullptr);
Status MallocFixedFeatureMemoryIfNeed(const GraphNodePtr &graph_node, const GeRootModelPtr &ge_root_model,
const aclrtStream stream) const;
static Status MallocByDiffAllocator(const uint64_t session_id,
const aclrtStream stream,
const FeatureMemoryPtr &fixed_feature_mem,
const rtMemType_t rt_mem_type,
const GeRootModelPtr &ge_root_model);
static Status FreeFixedFeatureMemoryIfNeed(const GeRootModelPtr &ge_root_model);
Status GetStreamNum(const GeRootModelPtr &ge_root_model, uint32_t &stream_num, uint64_t &hccl_follow_stream) const;
Status GetEventNum(const GeRootModelPtr &ge_root_model, uint32_t &event_num) const;
static Status UnloadModel(const GeRootModelPtr &ge_root_model, const uint32_t graph_id);
static Status UnloadPneModel(const uint32_t model_id, const uint64_t session_id, const uint32_t graph_id);
bool ReleaseMemory(const GeRootModelPtr &ge_root_model, const GraphNodePtr &loaded_graph_node) const;
bool ReleaseModel(const GeRootModelPtr &ge_root_model, const GraphNodePtr &loaded_graph_node) const;
Status CheckAndReleaseMemory(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node);
Status CheckAndReleaseStream(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node);
Status CheckAndReleaseEvent(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node);
static Status GetMemoryInfo(size_t &free);
Status CheckFreeMemory(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node,
bool &is_enough, bool &release_all) const;
Status GetMemorySizeAfterReuse(const std::vector<GeModelPtr> &ge_models, const GraphNodePtr &graph_node,
int64_t &sum_size, bool &reuse) const;
void RunThread();
void StopQueue();
void ReturnError(const RunAsyncCallbackV2 &callback,
const Status ret, const std::string &log_info) const;
bool DoReleaseModel(const GeRootModelPtr &ge_root_model, const GraphNodePtr &loaded_graph_node) const;
bool init_flag_{false};
uint64_t session_id_{0U};
GraphExecutor graph_executor_;
std::mutex mutex_;
std::map<GraphId, GraphNodePtr> graph_nodes_;
std::thread run_thread_;
std::atomic_bool thread_run_flag_{false};
BlockingQueue<std::shared_ptr<RunArgs>> run_args_q_;
};
}
#endif