* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef GE_GRAPH_PARTITION_BASE_PARTITION_H_
#define GE_GRAPH_PARTITION_BASE_PARTITION_H_
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "framework/common/ge_inner_error_codes.h"
#include "graph/compute_graph.h"
#include "graph/partition/base_cluster.h"
namespace ge {
const char_t *const kClusterData = "DATA";
const char_t *const kClusterInputNode = "INPUT_NODE";
const char_t *const kClusterNetoutput = "NETOUTPUT";
const char_t *const kClusterStage = "STAGE";
constexpr int32_t kDataTypeIndex = 0;
constexpr int32_t kNetOutputTypeIndex = 1;
constexpr int32_t kInputNodeTypeIndex = 2;
class BasePartitioner {
public:
friend class BaseCluster;
friend class HeterogeneousCluster;
friend class FlowCluster;
friend class DynamicShapeCluster;
using ClusterFilter = std::function<bool(const std::shared_ptr<BaseCluster> &cluster)>;
using ClusterCompareFunc =
std::function<bool(const std::shared_ptr<BaseCluster> &a, const std::shared_ptr<BaseCluster> &b)>;
explicit BasePartitioner(ge::ComputeGraphPtr graph) : root_graph_(std::move(graph)) {}
virtual ~BasePartitioner() = default;
virtual void LogDebugString(bool is_failed) const;
virtual Status Partition() = 0;
virtual std::string GetSubgraphName(const BaseCluster &cluster) const = 0;
virtual bool IsNeedBuildPartitionFrame(const BaseCluster &cluster) const = 0;
virtual void ClearResource();
void DumpGraph(const std::string &suffix) const;
Status TopologicalSortClusters(const ClusterFilter &cluster_filter);
std::vector<std::shared_ptr<BaseCluster>> GetUniqueClusters(const ClusterFilter &cluster_filter);
Status PartitionImpl();
virtual Status BuildPartitionFrame();
virtual Status CombinePartitionFrame() const;
virtual Status BuildPartitionSubgraph();
virtual Status InitClusters() = 0;
virtual Status MergeClusters() = 0;
virtual Status ProcessUniqueClusters() = 0;
virtual std::string GetPartitionName() const = 0;
static Status CheckWritableVarNode(const ComputeGraphPtr &root_graph);
protected:
const ComputeGraphPtr &GetRootGraph() const { return root_graph_; }
void SetRootGraph(const ComputeGraphPtr &root_graph) { root_graph_ = root_graph; }
const std::shared_ptr<BaseCluster> &GetCluster(const NodePtr &node) { return node_2_cluster_[node]; }
void SetCluster(const NodePtr &node, const shared_ptr<BaseCluster> &cluster) {
node_2_cluster_[node] = cluster;
}
void RecordClusters(const shared_ptr<BaseCluster> &cluster) {
clusters_.emplace(cluster);
}
void ClearClusters() {
clusters_.clear();
node_2_cluster_.clear();
ordered_cluster_.clear();
unique_clusters_.clear();
sorted_unique_clusters_.clear();
type_index_to_type_.clear();
}
Status SortUniqueClusters(const ClusterCompareFunc &compare_func);
const std::vector<std::shared_ptr<BaseCluster>> &GetOrderedClusters() const {
return ordered_cluster_;
}
const std::unordered_set<std::shared_ptr<BaseCluster>> &GetUniqueClusters() const {
return unique_clusters_;
}
const std::vector<std::shared_ptr<BaseCluster>> &GetSortedUniqueClusters() const {
return sorted_unique_clusters_;
}
void InsertIndexType(const int32_t index, const std::string &type);
std::string GetTypeByIndex(const int32_t index) const {
const auto &it = type_index_to_type_.find(index);
if (it == type_index_to_type_.end()) {
return "";
}
return it->second;
}
std::vector<std::shared_ptr<BaseCluster>> sorted_unique_clusters_;
private:
ge::ComputeGraphPtr root_graph_;
std::unordered_map<NodePtr, std::shared_ptr<BaseCluster>> node_2_cluster_;
std::unordered_set<std::shared_ptr<BaseCluster>> clusters_;
std::vector<std::shared_ptr<BaseCluster>> ordered_cluster_;
std::unordered_set<std::shared_ptr<BaseCluster>> unique_clusters_;
std::map<int32_t, std::string> type_index_to_type_;
};
}
#endif