/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef GE_GRAPH_BUILD_STREAM_LOGICAL_STREAM_ALLOCATOR_H_
#define GE_GRAPH_BUILD_STREAM_LOGICAL_STREAM_ALLOCATOR_H_
#include <cstdint>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "engines/manager/engine_manager/dnnengine_manager.h"
#include "graph/manager/graph_manager_utils.h"
#include "stream_utils.h"
namespace ge {
// Boilerplate special members for a LogicalStreamPass subclass:
//   - a default constructor that hands the stringified class name to the base constructor,
//   - deleted copy construction and copy assignment,
//   - a defaulted (overriding) destructor.
// The expansion deliberately ends without a ';' so the use site supplies it:
//   STREAM_PASS_DEFAULT_FUNC(MyPass);
#define STREAM_PASS_DEFAULT_FUNC(CLASS)          \
  CLASS() : LogicalStreamPass(#CLASS) {}         \
  CLASS(const CLASS &) = delete;                 \
  CLASS &operator=(const CLASS &) = delete;      \
  ~CLASS() override = default
// Boilerplate special members for an OptimizeByTopoPass subclass (same shape as
// STREAM_PASS_DEFAULT_FUNC, but forwarding the stringified class name to OptimizeByTopoPass):
// default constructor, deleted copy operations and a defaulted overriding destructor.
// The expansion ends without a ';' so the use site supplies it.
#define OPTIMIZE_BY_STRUCTURE_PASS_DEFAULT_FUNC(CLASS)  \
  CLASS() : OptimizeByTopoPass(#CLASS) {}               \
  CLASS(const CLASS &) = delete;                        \
  CLASS &operator=(const CLASS &) = delete;             \
  ~CLASS() override = default
// Abstract base class for one step of logical stream assignment. Concrete passes implement
// Run() and receive the graph, its subgraphs and a Context shared by every pass.
class LogicalStreamPass {
 public:
  // State handed to every pass by non-const reference, so a pass may read and update it.
  struct Context {
    // Starts as kInvalidStream (declared outside this header, presumably in stream_utils.h).
    int64_t default_stream = kInvalidStream;
    // Starts at 0. NOTE(review): presumably the next logical stream id to allocate — confirm.
    int64_t next_stream = 0;
    // NOTE(review): presumably set through LogicalStreamAllocator::EnableSingleStream — confirm.
    bool enable_single_stream = false;
    // NOTE(review): presumably set through LogicalStreamAllocator::EnableHcomParallel — confirm.
    bool enable_hcom_parallel = false;
  };
  // name: identifier of the pass (STREAM_PASS_DEFAULT_FUNC passes the class name);
  // presumably stored in name_ and returned by GetName().
  explicit LogicalStreamPass(const std::string &name);
  // Passes are neither copyable nor copy-assignable.
  LogicalStreamPass(const LogicalStreamPass &) = delete;
  LogicalStreamPass &operator=(const LogicalStreamPass &) = delete;
  virtual ~LogicalStreamPass() = default;
  // Returns the pass name by const reference.
  const std::string &GetName() const;
  // Executes the pass.
  //   graph:     graph being processed (shared pointer taken by value).
  //   subgraphs: subgraphs of the graph that the pass works on.
  //   context:   shared allocation state; may be modified by the pass.
  // Returns a Status describing success or failure.
  virtual Status Run(ComputeGraphPtr graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) = 0;
 private:
  std::string name_;
  // The allocator is granted access to private members (e.g. name_).
  friend class LogicalStreamAllocator;
};
// Shared-ownership handle for a pass.
using LogicalStreamPassPtr = std::shared_ptr<LogicalStreamPass>;
// Logical stream pass. NOTE(review): by its name, assigns streams from stream labels set on
// nodes/subgraphs — confirm against the .cc implementation.
class AssignByLabelPass : public LogicalStreamPass {
 public:
  STREAM_PASS_DEFAULT_FUNC(AssignByLabelPass);
  // See LogicalStreamPass::Run.
  Status Run(ComputeGraphPtr graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override;
};
// Logical stream pass. NOTE(review): by its name, gives certain subgraphs their own
// (independent) streams — confirm against the .cc implementation.
class IndependentStreamPass : public LogicalStreamPass {
 public:
  STREAM_PASS_DEFAULT_FUNC(IndependentStreamPass);
  // See LogicalStreamPass::Run.
  Status Run(ComputeGraphPtr graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override;
};
// Logical stream pass that assigns streams based on subgraph dependencies. Judging from its
// helpers, a subgraph presumably either reuses the stream of a predecessor subgraph
// (GetReusableSubgraph / CouldReuse) or receives a new one (AssignNewStream) — confirm in the .cc.
class AssignByDependencyPass : public LogicalStreamPass {
 public:
  STREAM_PASS_DEFAULT_FUNC(AssignByDependencyPass);
  // See LogicalStreamPass::Run.
  Status Run(ComputeGraphPtr graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override;
 private:
  // Looks up a subgraph whose stream `subgraph` may reuse. The two maps are keyed by node.
  // NOTE(review): presumably End-node -> subgraph and Placeholder-node -> subgraph; the
  // "no reusable subgraph" return value is not visible here (likely nullptr) — confirm.
  SubgraphPtr GetReusableSubgraph(const SubgraphPtr &subgraph,
                                  const std::unordered_map<NodePtr, SubgraphPtr> &end_subgraph_map,
                                  const std::unordered_map<NodePtr, SubgraphPtr> &pld_subgraph_map) const;
  // Gives `subgraph` a new stream and returns its id (int64_t).
  int64_t AssignNewStream(SubgraphPtr subgraph);
  // NOTE(review): presumably commits assigned_subgraphs_ into `context` — confirm.
  void UpdateAssignedSubgraphs(Context &context);
  // NOTE(review): presumably propagates streams along reused_subgraphs_ — confirm.
  void UpdateReusedSubgraphs();
  bool IsForceAttach(const SubgraphPtr &subgraph) const;
  // Reuse checks for `subgraph` against its predecessor `pred_subgraph`.
  bool CouldReuse(const SubgraphPtr &subgraph, const SubgraphPtr &pred_subgraph,
                  const std::unordered_map<NodePtr, SubgraphPtr> &pld_subgraph_map) const;
  bool SubGraphCouldReuse(const SubgraphPtr &subgraph, const SubgraphPtr &pred_subgraph,
                          const std::unordered_map<NodePtr, SubgraphPtr> &pld_subgraph_map) const;
  // NOTE(review): presumably returns is_memory_priority_ — confirm.
  bool IsMemoryPriority() const;
  // Per-engine stream bookkeeping, keyed by string (presumably the engine name).
  std::map<std::string, int64_t> engine_next_streams_;
  std::map<std::string, int64_t> engine_stream_num_;
  // Subgraphs that received a stream during this pass.
  std::vector<SubgraphPtr> assigned_subgraphs_;
  // Pairs of (subgraph, subgraph) recorded for stream reuse.
  std::vector<std::pair<SubgraphPtr, SubgraphPtr>> reused_subgraphs_;
  // mutable so that the const helper methods above can record visited subgraphs.
  mutable std::unordered_set<SubgraphPtr> visited_subgraphs_;
  bool is_memory_priority_ {false};
};
// Logical stream pass. NOTE(review): by its name, relates to single-stream mode
// (Context::enable_single_stream) — confirm against the .cc implementation.
class SingleStreamPass : public LogicalStreamPass {
 public:
  STREAM_PASS_DEFAULT_FUNC(SingleStreamPass);
  // See LogicalStreamPass::Run.
  Status Run(ComputeGraphPtr graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override;
};
// Logical stream pass. NOTE(review): by its name, writes the streams chosen for subgraphs
// back onto their nodes — confirm against the .cc implementation.
class NodeStreamUpdatePass : public LogicalStreamPass {
 public:
  STREAM_PASS_DEFAULT_FUNC(NodeStreamUpdatePass);
  // See LogicalStreamPass::Run.
  Status Run(ComputeGraphPtr graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override;
};
// Logical stream pass that adjusts stream ids for parallel groups.
class UpdateForParallelGroupPass : public LogicalStreamPass {
 public:
  STREAM_PASS_DEFAULT_FUNC(UpdateForParallelGroupPass);
  // See LogicalStreamPass::Run.
  Status Run(ComputeGraphPtr graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override;
 private:
  // Updates cur_node's stream id from a preceding node. total_pld_to_end maps node -> node;
  // NOTE(review): presumably Placeholder -> matching End node across all subgraphs — confirm.
  Status UpdateStreamIdFromPreNode(const NodePtr &cur_node,
                                   const std::unordered_map<ge::NodePtr, ge::NodePtr> &total_pld_to_end) const;
};
// Logical stream pass. NOTE(review): adjusts streams for "Mde" groups; the meaning of
// "Mde" is not visible in this header — confirm against the .cc implementation.
class UpdateForMdeGroupPass : public LogicalStreamPass {
 public:
  STREAM_PASS_DEFAULT_FUNC(UpdateForMdeGroupPass);
  // See LogicalStreamPass::Run.
  Status Run(ComputeGraphPtr graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override;
};
class UpdateForSkippedEnginePass : public LogicalStreamPass {
public:
STREAM_PASS_DEFAULT_FUNC(UpdateForSkippedEnginePass);
NodeA(stream1) -> Const(stream2) -> NodeB(stream1)
To case:
NodeA(stream1) -> Const(stream1) -> NodeB(stream1)
Which could reduce event number (Const could be other type which belong to skipped engine subgraph)
*/
Status Run(ComputeGraphPtr graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override;
private:
int64_t GetSingleInoutStream(const NodePtr &node) const;
bool AreAllPredStreamsInvalid(const NodePtr &node) const;
};
// Logical stream pass for all-reduce parallelism. NOTE(review): presumably tied to
// Context::enable_hcom_parallel — confirm against the .cc implementation.
class AllReduceParallelPass : public LogicalStreamPass {
 public:
  STREAM_PASS_DEFAULT_FUNC(AllReduceParallelPass);
  // See LogicalStreamPass::Run.
  Status Run(ComputeGraphPtr graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override;
 private:
  // Decides from the op type string whether a node is an HCOM node.
  bool IsHcomNode(const std::string& node_type) const;
  // NOTE(review): presumably reads the node's fusion attribute — confirm.
  int64_t GetFusion(const NodePtr &node) const;
};
// Abstract base class for optimizations that run on the whole graph after logical stream
// assignment. Unlike LogicalStreamPass, Run() takes only the graph, with no subgraphs or Context.
class OptimizeByTopoPass {
 public:
  // name: identifier of the pass (OPTIMIZE_BY_STRUCTURE_PASS_DEFAULT_FUNC passes the class name).
  explicit OptimizeByTopoPass(const std::string &name);
  OptimizeByTopoPass(const OptimizeByTopoPass &) = delete;
  OptimizeByTopoPass &operator=(const OptimizeByTopoPass &) = delete;
  virtual ~OptimizeByTopoPass() = default;
  // Executes the optimization on `graph` and returns a Status.
  virtual Status Run(const ComputeGraphPtr &graph) = 0;
  // Returns the pass name by const reference.
  const std::string &GetName() const;
 private:
  std::string name_;
};
// Shared-ownership handle for a topology-based optimization pass.
using OptimizeByTopoPassPtr = std::shared_ptr<OptimizeByTopoPass>;
// Topology-based optimization. NOTE(review): by its name, removes multi-stream splits that
// bring no parallelism benefit — confirm against the .cc implementation.
class OptimizeIneffectiveMultiStreamPass : public OptimizeByTopoPass {
 public:
  OPTIMIZE_BY_STRUCTURE_PASS_DEFAULT_FUNC(OptimizeIneffectiveMultiStreamPass);
  // See OptimizeByTopoPass::Run.
  Status Run(const ComputeGraphPtr &graph) override;
};
// Drives logical stream assignment for a root graph and its subgraphs using the passes
// declared above, then reports the resulting stream counts.
class LogicalStreamAllocator {
  // Private alias for the state shared with the passes.
  using Context = LogicalStreamPass::Context;
 public:
  // max_parallel_num: map from string key (presumably engine name) to int32_t limit.
  // NOTE(review): the map is stored by reference in max_parallel_num_, so the caller's map
  // must outlive this allocator; passing a temporary would leave a dangling reference.
  explicit LogicalStreamAllocator(const std::map<std::string, int32_t> &max_parallel_num);
  LogicalStreamAllocator(const LogicalStreamAllocator &) = delete;
  LogicalStreamAllocator &operator=(const LogicalStreamAllocator &) = delete;
  ~LogicalStreamAllocator() = default;
  // NOTE(review): presumably set context_.enable_single_stream / enable_hcom_parallel — confirm.
  void EnableSingleStream(bool enable);
  void EnableHcomParallel(bool enable);
  // Assigns logical streams for root_graph using subgraph_map. The totals are returned through
  // the out-parameters total_stream_num and main_stream_num; the Status reports success.
  Status Assign(const ComputeGraphPtr &root_graph, const Graph2SubGraphInfoList &subgraph_map,
                int64_t &total_stream_num, int64_t &main_stream_num);
 private:
  // Assignment for a single graph, given the engine configurations.
  Status DoAssign(const ComputeGraphPtr &graph, const Graph2SubGraphInfoList &subgraph_map,
                  const std::map<std::string, EngineConfPtr> &engine_confs);
  // Runs the LogicalStreamPass sequence over graph/subgraphs.
  Status RunPasses(const ComputeGraphPtr &graph, const std::vector<SubgraphPtr> &subgraphs);
  // Runs the OptimizeByTopoPass sequence over graph.
  Status RunOptimizeByTopoPasses(const ComputeGraphPtr &graph);
  // NOTE(review): presumably renumbers stream ids so they are contiguous — confirm.
  void RefreshContinuousStreams(const ComputeGraphPtr &graph);
  // Reference to the caller-owned map passed to the constructor (see lifetime note above).
  const std::map<std::string, int32_t> &max_parallel_num_;
  // State shared with every pass during assignment.
  Context context_;
};
}  // namespace ge
#endif  // GE_GRAPH_BUILD_STREAM_LOGICAL_STREAM_ALLOCATOR_H_