/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
#ifndef BASE_EXEC_RUNTIME_DEPLOY_DEPLOY_PLANNER_H_
#define BASE_EXEC_RUNTIME_DEPLOY_DEPLOY_PLANNER_H_
#include <atomic>
#include <map>
#include <set>
#include <vector>
#include "dflow/base/model/model_relation.h"
#include "common/plugin/ge_make_unique_util.h"
#include "dflow/inc/data_flow/model/pne_model.h"
namespace ge {
// Two-level lookup table: uint32_t key -> (string key -> bool).
// NOTE(review): going by the alias name, the outer key is a root model id and the inner key a
// submodel name; what the bool stands for is not visible in this header -- confirm against users.
// The inner map is keyed by `const std::string`, which is a distinct type from
// std::map<std::string, bool>; callers must spell the type through this alias.
using RootModelId2SubmodelName = std::map<uint32_t, std::map<const std::string, bool>>;
// Composite key made of a model name, an invoke key and an integer id.
// It is used as a std::map key elsewhere in this header, hence the ordering operator.
struct ModelQueueIndex {
  std::string model_name;
  std::string invoke_key;
  int32_t id;

  // Lexicographic order over (model_name, invoke_key, id): the first field that differs decides.
  bool operator < (const ModelQueueIndex &other) const {
    const int name_order = model_name.compare(other.model_name);
    if (name_order != 0) {
      return name_order < 0;
    }
    const int key_order = invoke_key.compare(other.invoke_key);
    if (key_order != 0) {
      return key_order < 0;
    }
    return id < other.id;
  }
};
* Deploy plan for GeRootModel
*/
class DeployPlan {
public:
class DeviceInfo {
public:
DeviceInfo() = default;
DeviceInfo(const int32_t type, const int32_t node_id, const int32_t device_id) noexcept;
DeviceInfo(const int32_t type, const int32_t node_id, const int32_t device_id,
const int32_t proxy_device_id) noexcept;
bool WithProxy() const;
DeviceInfo ProxyDevice() const;
int32_t GetType() const;
int32_t GetNodeId() const;
int32_t GetDeviceId() const;
int32_t GetProxyDeviceId() const;
const std::string &GetKey() const;
const std::string &GetDesc() const;
int32_t GetHcomDeviceId() const;
void SetHcomDeviceId(int32_t hcom_device_id);
int32_t GetOsId() const;
void SetOsId(int32_t os_id);
void SetSupportFlowgw(bool support_flowgw) {
support_flowgw_ = support_flowgw;
}
bool SupportFlowgw() const {
return support_flowgw_;
}
void SetProxyDeviceId(int32_t proxy_device_id) {
proxy_device_id_ = proxy_device_id;
}
bool operator<(const DeviceInfo& other) const {
if (node_id_ != other.node_id_) {
return node_id_ < other.node_id_;
}
if (device_id_ != other.device_id_) {
return device_id_ < other.device_id_;
}
return type_ < other.type_;
}
private:
std::string key_ = "1_0_0";
std::string desc_ = "1_0_0(-1)";
int32_t type_ = static_cast<int32_t>(CPU);
int32_t node_id_ = 0;
int32_t device_id_ = 0;
int32_t proxy_device_id_ = -1;
int32_t hcom_device_id_ = 0;
int32_t os_id_ = 0;
bool support_flowgw_ = true;
};
struct ExtendedIndexInfo {
DeviceInfo device_info;
std::string submodel_instance_name;
bool is_normal;
};
struct DynamicGroupRouteInfo {
int32_t entry_index;
int32_t endpoint_index;
ExtendedIndexInfo extended_info;
bool is_redundant;
};
struct DstGroupInfo {
uint32_t model_id;
std::vector<DynamicGroupRouteInfo> routes;
};
using DynamicSchedIndex = std::map<int32_t, std::map<int32_t, std::pair<ExtendedIndexInfo,
std::map<int32_t, DstGroupInfo>>>>;
using DeviceStateList = std::map<DeviceInfo, bool>;
using ModelDeployInfo = std::map<std::string, std::map<std::string, std::set<DeviceInfo>>>;
using AbnormalStatusCallback = std::function<Status(uint32_t, RootModelId2SubmodelName &)>;
struct AbnormalStatusCallbackInfo {
std::mutex mu;
std::map<uint32_t, AbnormalStatusCallback> callback_list;
};
enum class QueueAction {
kDefault,
kControl,
kStatus,
kSched
};
struct QueueInfo {
DeviceInfo device_info;
uint32_t depth = 2U;
int32_t ref_index = -1;
std::string name;
std::string model_instance_name;
std::string enqueue_policy;
bool owned = true;
QueueAction queue_action;
int32_t fusion_offset = 0;
uint32_t instance_num;
uint32_t instance_idx;
uint32_t model_id = 0U;
bool is_dummy = false;
int32_t process_id = 0;
};
struct InvokedModelQueueInfo {
std::vector<int32_t> feed_queue_indices;
std::vector<int32_t> fetch_queue_indices;
};
struct SubmodelInfo {
int32_t process_id = 0;
DeviceInfo device_info;
DeviceInfo queue_device_info;
PneModelPtr model;
std::vector<int32_t> input_queue_indices;
std::vector<int32_t> control_input_queue_indices;
std::vector<int32_t> output_queue_indices;
std::vector<int32_t> control_output_queue_indices;
std::map<std::string, std::string> attrs;
std::map<std::string, InvokedModelQueueInfo> invoked_model_queue_infos;
std::vector<int32_t> status_input_queue_indices;
std::vector<int32_t> status_output_queue_indices;
std::vector<int32_t> sched_input_queue_indices;
std::vector<int32_t> sched_output_queue_indices;
bool is_head = false;
bool is_redundant = false;
};
class DynamicSchedPlan {
public:
const std::vector<int32_t> &GetStatusOutputQueueIndices() const;
const std::vector<int32_t> &GetSchedOutputQueueIndices() const;
const std::vector<int32_t> &GetSchedInputQueueIndices() const;
const std::map<int32_t, int32_t> &GetDatagwRequestBindings() const;
const std::map<int32_t, int32_t> &GetEntryBindings() const;
const DynamicSchedIndex &GetModelIndexInfo() const;
const std::map<std::string, uint32_t> &GetModelInstanceNum() const;
private:
friend class DeployPlannerBase;
std::map<int32_t, int32_t> datagw_request_bindings_;
std::map<int32_t, int32_t> entry_to_dst_index_;
DynamicSchedIndex model_index_info_;
SubmodelInfo root_model_info_;
std::map<std::string, uint32_t> submodels_id_;
std::map<std::string, std::vector<int32_t>> src_endpoint_indices_;
std::map<int32_t, std::map<ModelQueueIndex, std::vector<DeployPlan::QueueInfo>>> endpoint_pairs_;
std::map<std::string, uint32_t> model_instances_num_;
};
Status GetQueueInfo(const int32_t queue_index, const DeployPlan::QueueInfo *&queue_info) const;
const std::vector<QueueInfo> &GetQueueInfoList() const;
const std::vector<QueueInfo> &GetGroupEntryInfoList() const;
const std::vector<std::pair<int32_t, int32_t>> &GetQueueBindings() const;
const std::vector<int32_t> &GetInputQueueIndices() const;
const std::vector<int32_t> &GetControlInputQueueIndices() const;
const std::vector<int32_t> &GetControlOutputQueueIndices() const;
const DeployPlan::DeviceInfo &GetRootModelQueueDeviceInfo() const;
std::vector<int32_t> GetAllInputQueueIndices() const;
const std::vector<int32_t> &GetOutputQueueIndices() const;
const std::map<std::string, SubmodelInfo> &GetSubmodels() const;
std::map<std::string, SubmodelInfo> &MutableSubmodels();
const std::map<int32_t, std::vector<int32_t>> &GetGroups() const;
bool IsGroupEndpoint(const int32_t queue_index) const;
const DynamicSchedPlan &GetDynamicSchedPlan() const;
void SetIsDynamicSched(const bool is_dynamic_sched);
const bool &GetIsDynamicSched() const;
DeployPlan::ModelDeployInfo &GetModelDeployInfos();
const std::map<std::string, std::vector<std::string>> &GetTrimmingEdgesModelInstances() const;
const std::vector<int32_t> GetBroadcastIndices(int32_t src_endpoint_index) const;
void SetEnableExceptionCatch(bool enable_exception_catch);
bool IsEnableExceptionCatch() const;
private:
friend class DeployPlannerBase;
std::string model_name_;
std::vector<QueueInfo> queues_;
std::vector<std::pair<int32_t, int32_t>> queue_bindings_;
std::map<int32_t, int32_t> dst_to_src_bindings_;
std::map<int32_t, std::map<ModelQueueIndex, std::vector<int32_t>>> src_to_dst_endpoints_;
SubmodelInfo root_model_info_;
std::map<std::string, SubmodelInfo> submodels_;
std::map<int32_t, std::vector<int32_t>> groups_;
std::map<std::string, int32_t> groups_key_to_idx_;
std::vector<QueueInfo> group_entries_;
DynamicSchedPlan dynamic_sched_plan_;
bool is_dynamic_sched_ = false;
bool enable_exception_catch_ = false;
ModelDeployInfo model_deploy_infos_;
std::map<std::string, std::vector<std::string>> trimming_edges_model_instance_names_;
};
// Base class for deploy planners.
// BuildPlan() is the only public operation; subclasses must implement PrepareModelsAndRelation()
// and may override SelectHeadAndTailDevice(). Apart from SelectHeadAndTailDevice(), every member
// function is only declared here and defined out of line, so the comments below describe the
// declarations and do not claim runtime behaviour.
class DeployPlannerBase {
public:
DeployPlannerBase() = default;
// Macro defined elsewhere; NOTE(review): by its name it deletes copy construction and assignment.
GE_DELETE_ASSIGN_AND_COPY(DeployPlannerBase);
virtual ~DeployPlannerBase() = default;
// Public entry point; the result is delivered through the deploy_plan out-parameter.
Status BuildPlan(DeployPlan &deploy_plan);
// Instance count / instance index pair stored per input group (see input_groups_attr_).
// Neither member has a default initializer.
struct InputGroupAttr {
uint32_t instance_num;
uint32_t instance_idx;
};
protected:
// Pure virtual: every concrete planner has to provide the model relation.
virtual Status PrepareModelsAndRelation(ModelRelation &model_relation) = 0;
// Default implementation resets device_info to a default-constructed DeviceInfo.
virtual void SelectHeadAndTailDevice(DeployPlan::DeviceInfo &device_info) {
device_info = DeployPlan::DeviceInfo();
}
DeployPlan::SubmodelInfo &MutableSubmodelInfo(const std::string &name);
static Status ValidateModelAndRelation(const std::map<std::string, PneModelPtr> &models,
const ModelRelation &model_relation);
bool IsHeadOrTail(const std::string &name) const;
// Endpoint / group creation helpers. The overloads with an int32_t& parameter report the created
// index through it.
Status CreateEndpointInfo(const DeployPlan::QueueInfo &queue_info);
Status CreateEndpointInfo(const DeployPlan::QueueInfo &queue_info, int32_t &queue_idx);
Status CreateGroupEntry(const DeployPlan::QueueInfo &queue_info, int32_t &entry_index);
Status CreateGroupRefEntry(const DeployPlan::QueueInfo &queue_info,
int32_t endpoint_index,
int32_t &entry_index);
Status CreateGroupQueueEntry(const DeployPlan::QueueInfo &queue_info,
int32_t &queue_index,
int32_t &entry_index);
Status CreateGroupInfo(const DeployPlan::QueueInfo &queue_info,
const std::vector<int32_t> &grouped_indices,
int32_t &group_index);
// skip_if_dst_exists defaults to true.
void AddEndpointBindings(int32_t src_index, int32_t dst_index, bool skip_if_dst_exists = true);
const bool &GetIsDynamicSched() const;
// Static, so there is one counter for all planner instances; atomic.
static std::atomic<int64_t> plan_id_gen_;
// Plan object owned by the planner. NOTE(review): presumably the working plan that BuildPlan()
// passes on to its caller -- confirm in the implementation file.
DeployPlan deploy_plan_;
private:
// Internal steps and helpers, all defined out of line.
Status Initialize();
Status ParseModelRelation();
void UpdateForInputControlIo();
void UpdateForOutputControlIo();
void UpdateRelationForControlIo();
Status AssignEnqueueQueues();
// Fusion-related helpers.
Status ResolveEnqueueFusion();
Status ResolveDequeueFusion(int32_t src_endpoint_idx, int32_t dst_endpoint_idx);
Status ResolveInputsPlacement(const std::string &model_instance_name,
const ModelRelation::ModelEndpointInfo &model_endpoint_info);
Status ResolveModelFusion(const std::string &model_instance_name,
const ModelRelation::ModelEndpointInfo &model_endpoint_info);
bool CanBeFused(const std::string &fusion_name, const std::string &endpoint_name);
void UpdateFusionOffset(int32_t src_index, int32_t dst_index);
Status ResolveInvokedFusion();
// Takes the model pointer by value; results are written to fusion_inputs.
static Status GetInvokedModelFusionInputs(const PneModelPtr model,
std::map<std::string, std::string> &fusion_inputs);
// The two parsers below turn a string into index lists via their out-parameters.
// NOTE(review): the accepted string format is defined by the implementation, not visible here.
static Status ParseInputIndexWithRange(const std::string &fusion_input_str,
std::vector<size_t> &fusion_input_index_list);
static Status ParseInvokedModelFusionInputs(const std::string &fusion_inputs_str,
std::vector<std::vector<size_t>> &fusion_inputs_list);
Status ResolveModelInvokedFusion(const std::string &model_instance_name,
const std::vector<std::string> &queue_names,
const std::string &invoke_key,
const std::string &fusion_inputs);
void MarkMultiDeployedModels();
// Device adjustment helpers. AdjustEnqueueDevice is const and modifies only its src_endpoint
// argument; AdjustDequeueDevice is non-const.
Status AdjustEnqueueDevices();
Status AdjustEnqueueDevice(DeployPlan::QueueInfo &src_endpoint,
const std::map<ModelQueueIndex, std::vector<DeployPlan::QueueInfo>> &dst_endpoints) const;
Status AdjustDequeueDevice(DeployPlan::QueueInfo &dst_endpoint, const std::vector<int32_t> &src_endpoint_indices);
// Data-flow resolution helpers.
Status ResolveDataFlows();
Status ResolveModelInputs(const std::string &model_instance_name,
const ModelRelation::ModelEndpointInfo &model_endpoint_info);
Status ResolveModelDynamicInputs(const std::string &model_instance_name,
const ModelRelation::ModelEndpointInfo &model_endpoint_info);
void LogDataFlow() const;
Status ResolveReusableQueues();
Status AssignDequeueQueues();
Status BindRemoteOutputGroupToInput();
Status BindOutputToRemoteInputs();
void UpdateDeployPlan();
// is_owned defaults to true.
Status CreateOutputQueueDefs(const std::string &model_instance_name,
const std::vector<std::string> &queue_names,
const bool is_owned = true);
Status CreateFeedEndpoints(const std::string &model_instance_name,
const std::vector<std::string> &queue_names,
const std::string &invoke_key);
// The resulting index is returned through endpoint_index.
Status GetOrCreateInputEndpoint(const ModelQueueIndex &model_queue_index,
const DeployPlan::QueueInfo &queue_info,
int32_t &endpoint_index);
void AddInputGroups(const int32_t dst_endpoint_idx,
const int32_t src_tag_idx,
const InputGroupAttr &input_group_attr);
// Description-string helpers; is_group_entry defaults to false in both.
std::vector<std::string> ToEndpointDescs(const std::vector<int32_t> &endpoint_indices,
const bool is_group_entry = false) const;
// Note: the parameter is named endpoint_indices but is a single int32_t.
std::string ToEndpointDesc(const int32_t endpoint_indices, const bool is_group_entry = false) const;
DeployPlan::QueueInfo BuildQueueInfo(const Endpoint &queue_def,
const std::string &model_instance_name);
std::string GenShortName(const std::string &name);
std::string GetEndpointFullName(const DeployPlan::QueueInfo &endpoint_info, const ModelQueueIndex &model_queue_index);
const std::string &GetSubmodelType(const std::string &name);
// suffix defaults to the empty string.
bool CheckAndAddRelation(const int32_t src_endpoint_idx,
const int32_t dst_endpoint_idx,
const std::string &suffix = "");
bool IsOutputMultiConnected(const int32_t src_endpoint_idx);
bool IsInputMultiConnected(const int32_t dst_endpoint_idx);
bool IsMultiDeployed(const std::string &model_instance_name) const;
bool CheckSkipBinding(const std::string &src_model_instance_name,
const std::string &dst_model_instance_name);
// Static predicates over a source/destination device pair.
static bool CanConnectWithQ(const DeployPlan::DeviceInfo &src_device_info,
const DeployPlan::DeviceInfo &dst_device_info);
static bool CanConnectWithLocalQ(const DeployPlan::DeviceInfo &src_device_info,
const DeployPlan::DeviceInfo &dst_device_info);
// Mapping helpers; use_balanced defaults to true. Results come back through tag_pair,
// entity_pair and mapping_idx respectively.
Status GetOrCreateMappingTagPairEntry(const int32_t endpoint_idx,
const DeployPlan::QueueInfo &mapping_queue_info,
std::pair<int32_t, int32_t> &tag_pair,
bool use_balanced = true);
void GenTagEntityPair(int32_t endpoint_idx,
const DeployPlan::QueueInfo &mapping_queue_info,
std::pair<DeployPlan::QueueInfo, DeployPlan::QueueInfo> &entity_pair);
Status GetOrCreateMappingEntry(const int32_t endpoint_idx,
const DeployPlan::QueueInfo &mapping_queue_info,
int32_t &mapping_idx);
// Four relation-preparation helpers sharing one parameter list.
// NOTE(review): presumably PrepareRelations/PrepareQueuesRelation dispatch to the same-node and
// different-node variants -- confirm in the implementation file.
Status PrepareDiffNodeRelation(const int32_t src_endpoint_idx,
const int32_t dst_endpoint_idx,
const ModelQueueIndex &model_queue_loc,
const DeployPlan::QueueInfo &queue_info,
const InputGroupAttr &input_group_attr);
Status PrepareSameNodeRelation(const int32_t src_endpoint_idx,
const int32_t dst_endpoint_idx,
const ModelQueueIndex &model_queue_loc,
const DeployPlan::QueueInfo &queue_info,
const InputGroupAttr &input_group_attr);
Status PrepareQueuesRelation(const int32_t src_endpoint_idx,
const int32_t dst_endpoint_idx,
const ModelQueueIndex &model_queue_loc,
const DeployPlan::QueueInfo &queue_info,
const InputGroupAttr &input_group_attr);
Status PrepareRelations(const int32_t src_endpoint_idx,
const int32_t dst_endpoint_idx,
const ModelQueueIndex &model_queue_loc,
const DeployPlan::QueueInfo &queue_info,
const InputGroupAttr &input_group_attr);
// skip_if_dst_exists defaults to true.
Status CreateAndBindGroup(const DeployPlan::QueueInfo &group_info,
const std::vector<int32_t> &group_entry_index,
const int32_t dst_endpoint_index,
const bool skip_if_dst_exists = true);
// Dynamic-schedule helpers.
Status AssignDynamicSchedDequeueQueues();
Status AssignDynamicSchedDequeueQueue(const DeployPlan::QueueInfo &queue_info,
const ModelQueueIndex &model_queue_loc,
const int32_t &src_endpoint_idx);
// is_owned defaults to true.
Status CreateDynamicSchedOutputQueueDefs(const std::string &model_instance_name,
const std::vector<std::string> &queue_names,
const bool is_owned = true);
Status CreateDynamicSchedTags(const int32_t src_endpoint_idx,
const int32_t dst_endpoint_idx,
const DeployPlan::QueueInfo &queue_info);
void GenerateDynamicSchedModelId();
Status AssignDynamicSchedEnqueueQueues();
void UpdateRelationForDynamicSched();
// Both binders return the group index through group_index.
Status DynamicSchedBindGroup2Queue(const int32_t src_idx,
const int32_t dst_idx,
int32_t &group_index);
Status DynamicSchedBindQueue2Group(const int32_t src_idx,
const int32_t dst_idx,
int32_t &group_index);
void UpdateDynamicSchedDeployPlan();
// src_queue_info is a raw pointer; NOTE(review): whether it may be null is not visible here.
void DynamicSchedGroupFormat(const int32_t &real_entry_index,
const int32_t &entry_index,
const DeployPlan::QueueInfo *src_queue_info,
const int32_t &src_q_idx,
const int32_t &dst_q_idx);
// device_infos is an in/out set.
void AddDependentDevice(std::set<DeployPlan::DeviceInfo> &device_infos,
const std::vector<int32_t> &queue_indexs);
void BuildModelDeployInfos();
Status BuildDynamicSchedInfo();
Status SetHeadNodeInfo();
void AddTrimmingEdgesModelInstance(const std::string &src_model_instance_name,
const std::string &dst_model_instance_name);
bool IsContainInvokedModel(const std::string &src_model_instance_name,
const std::string &dst_model_instance_name);
void BindDynamicSchedDevQueue(const int32_t src_endpoint_idx,
const int32_t dst_endpoint_idx);
// entry_info and both endpoint indices are non-const references, i.e. may be written.
Status BindDynamicSchedHostQueue(const DeployPlan::DeviceInfo &src_device_info,
const DeployPlan::DeviceInfo &dst_device_info,
DeployPlan::QueueInfo &entry_info,
int32_t &src_endpoint_idx,
int32_t &dst_endpoint_idx);
// Planner state. Member order is significant for construction/destruction order.
ModelRelation model_relation_;
std::unique_ptr<ModelRelationReader> relation_reader_;
// The next two members have the same names and types as members of DeployPlan::DynamicSchedPlan.
std::map<std::string, std::vector<int32_t>> src_endpoint_indices_;
std::map<int32_t, std::map<ModelQueueIndex, std::vector<DeployPlan::QueueInfo>>> endpoint_pairs_;
std::map<std::string, std::set<int32_t>> relation_dst_to_src_;
std::set<int32_t> reusable_queue_indices_;
std::map<std::tuple<ModelQueueIndex, std::string, int32_t>, int32_t> input_endpoint_indices_;
std::map<int32_t, std::map<std::string, int32_t>> dequeue_ref_indices_;
std::map<int32_t, std::map<ModelQueueIndex, std::map<std::string, int32_t>>> output_groups_;
std::map<std::string, std::set<std::string>> dequeue_placements_;
std::set<std::string> disable_fusion_queues_;
std::map<int32_t, std::vector<int32_t>> input_groups_;
std::map<int32_t, InputGroupAttr> input_groups_attr_;
ModelRelation::ModelEndpointInfo head_model_queue_info_;
ModelRelation::ModelEndpointInfo tail_model_queue_info_;
DeployPlan::SubmodelInfo head_model_info_;
DeployPlan::SubmodelInfo tail_model_info_;
// Static, so shared by all planner instances; atomic.
static std::atomic<int64_t> endpoint_name_id_gen_;
std::map<std::string, std::string> short_names_;
std::map<std::string, std::string> instance_to_model_name_;
std::map<std::string, std::vector<std::pair<std::string, DeployPlan::DeviceInfo>>> model_deploy_locations_;
std::map<std::pair<int32_t, std::string>, int32_t> endpoint_device_mapping_;
std::map<std::pair<int32_t, std::string>, std::pair<int32_t, int32_t>> endpoint_device_tags_mapping_;
std::set<std::string> relations_;
std::set<std::string> no_group_endpoint_names_;
std::map<std::string, int32_t> model_name_to_id_;
};
// Flattens the model relation of a root model. All member functions are defined out of line;
// only their declarations are visible in this header.
class ModelRelationFlattener {
public:
// Takes the root model pointer by value.
explicit ModelRelationFlattener(PneModelPtr root_model);
// Instance form: results are returned through both out-parameters.
Status Flatten(ModelRelation &flattened_model_relation, std::map<std::string, PneModelPtr> &name_to_models);
// Static form taking only the root model; it has no out-parameters.
// NOTE(review): presumably updates root_model in place -- confirm in the implementation file.
static Status Flatten(const PneModelPtr &root_model);
private:
// depth is passed explicitly. NOTE(review): presumably checked against max_depth_ -- confirm.
Status FlattenSubmodel(const ModelRelation::ModelEndpointInfo &parent_model_queue_info,
const PneModelPtr &pne_model,
const int32_t depth);
void MergeEndpoints(const std::map<std::string, std::string> &name_refs,
const std::vector<Endpoint> &endpoints);
// names is modified in place (non-const reference).
static void ReplaceQueueNames(const std::map<std::string, std::string> &name_refs, std::vector<std::string> &names);
static std::map<std::string, std::string> BuildNameRefs(
const ModelRelation::ModelEndpointInfo &parent_model_queue_info,
const ModelRelation::ModelEndpointInfo &root_model_queue_info);
static Status CheckConsistency(const ModelRelation::ModelEndpointInfo &parent_model_queue_info,
const ModelRelation::ModelEndpointInfo &root_model_queue_info);
static bool NeedFlatten(const PneModelPtr &root_model);
PneModelPtr root_model_;
ModelRelation flattened_model_relation_;
std::map<std::string, PneModelPtr> leaf_models_;
// Initialised to 16.
int32_t max_depth_ = 16;
};
// Concrete planner built from a root model. It overrides the pure virtual
// PrepareModelsAndRelation() of DeployPlannerBase; the override is defined out of line.
class DeployPlanner : public DeployPlannerBase {
public:
explicit DeployPlanner(const PneModelPtr &root_model);
~DeployPlanner() override = default;
protected:
Status PrepareModelsAndRelation(ModelRelation &model_relation) override;
private:
// Const member: the pointer itself cannot be reseated after construction.
const PneModelPtr root_model_;
};
}
#endif