* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef GE_GRAPH_PARTITION_BASE_CLUSTER_H_
#define GE_GRAPH_PARTITION_BASE_CLUSTER_H_
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <mutex>
#include "framework/common/ge_inner_error_codes.h"
#include "graph/compute_graph.h"
#include "common/checker.h"
namespace ge {
class BasePartitioner;
using SetAttrFn = bool (*)(AttrUtils::AttrHolderAdapter &&obj, const std::string &attr_name, const AnyValue &value);
using GetAttrFn = bool (*)(AttrUtils::ConstAttrHolderAdapter &&obj, const std::string &attr_name, AnyValue &val);
class PartitionNodeAttrName {
public:
PartitionNodeAttrName(const std::string &attr_name, const bool is_support_tensor, const bool is_need_copy,
const SetAttrFn set_fn, const GetAttrFn get_fn)
: attr_name_(attr_name),
is_support_tensor_(is_support_tensor),
is_need_copy_(is_need_copy),
set_fn_(set_fn),
get_fn_(get_fn) {}
~PartitionNodeAttrName() = default;
bool IsSupportTensor() const { return is_support_tensor_; }
bool IsNeedCopy() const { return is_need_copy_; }
const std::string &GetAttrName() const { return attr_name_; }
SetAttrFn SetAttrFunction() const { return set_fn_; }
GetAttrFn GetAttrFunction() const { return get_fn_; }
private:
std::string attr_name_;
bool is_support_tensor_;
bool is_need_copy_;
SetAttrFn set_fn_;
GetAttrFn get_fn_;
};
class PartitionNodeAttrNameManager {
public:
PartitionNodeAttrNameManager(const PartitionNodeAttrNameManager &attr_name) = delete;
PartitionNodeAttrNameManager &operator=(const PartitionNodeAttrNameManager &attr_name) = delete;
static PartitionNodeAttrNameManager &GetInstance();
void RegisterNodeAttrName(const PartitionNodeAttrName &attr_name);
const std::vector<PartitionNodeAttrName> &GetAllPartitionNodeAttrNames() const { return attr_array_; }
private:
PartitionNodeAttrNameManager() = default;
~PartitionNodeAttrNameManager() = default;
std::vector<PartitionNodeAttrName> attr_array_;
mutable std::mutex mutex_;
};
class PartitionNodeAttrRegister {
public:
PartitionNodeAttrRegister(const std::string &attr_name, bool is_support_tensor,
const bool is_need_copy, const SetAttrFn set_fn,
const GetAttrFn get_fn) noexcept;
~PartitionNodeAttrRegister() = default;
PartitionNodeAttrRegister(const PartitionNodeAttrRegister &other) = delete;
PartitionNodeAttrRegister &operator=(const PartitionNodeAttrRegister &other) = delete;
};
class BaseCluster : public std::enable_shared_from_this<BaseCluster> {
public:
BaseCluster(size_t rank, int32_t type_index, NodePtr node, BasePartitioner *partitioner)
: type_index_(type_index),
id_(rank),
min_(rank),
max_(rank),
partitioner_(partitioner) {
nodes_.push_back(node);
}
virtual ~BaseCluster() = default;
std::string DebugString() const;
size_t Id() const;
size_t MinId() const;
size_t UniqueId() const;
void UpdateRank(size_t rank);
bool IsData() const;
bool IsIndependent() const;
bool IsNetOutput() const;
std::vector<BaseCluster *> Inputs() const;
std::vector<BaseCluster *> Outputs() const;
bool IsInputNode() const;
const std::vector<NodePtr> &Nodes() const;
bool IsRefVariable() const;
void AddInput(std::shared_ptr<BaseCluster> in);
void RemoveInput(std::shared_ptr<BaseCluster> in);
void AddOutput(std::shared_ptr<BaseCluster> out);
void RemoveOutput(std::shared_ptr<BaseCluster> out);
virtual void Merge(std::shared_ptr<BaseCluster> other);
bool TryMerge(std::shared_ptr<BaseCluster> other);
std::vector<std::shared_ptr<BaseCluster>> MergeAllPathFrom(const std::shared_ptr<BaseCluster> &other);
bool AddFrameInput(InDataAnchorPtr anchor);
void AddFrameOutput(OutDataAnchorPtr anchor);
InDataAnchorPtr GetFrameInDataAnchor(InDataAnchorPtr anchor);
OutDataAnchorPtr GetFrameOutDataAnchor(OutDataAnchorPtr anchor);
InControlAnchorPtr GetFrameInControlAnchor();
OutControlAnchorPtr GetFrameOutControlAnchor();
Status BuildFrame();
Status AddPartitionedCallNodeOutput(const NodePtr &node, const OpDescPtr &partition_op);
Status BuildPartitionNodes(const OpDescPtr &partition_op);
Status RemoveNodeFromRoot(const ge::ComputeGraphPtr &graph);
Status CombinePartitionFrame();
virtual Status BuildPartitionFrame();
virtual Status BuildPartitionSubgraph();
virtual Status SetAttrToNetoutput(const OpDescPtr &op) {
(void)op;
return SUCCESS;
}
static Status CopyOpAttr(const OpDescPtr &from, OpDescPtr &to);
static Status UpdateFrameTensor(const OpDescPtr &op_desc, GeTensorDescPtr &tensor);
static Status CopyAttr(AttrUtils::ConstAttrHolderAdapter &&from, AttrUtils::AttrHolderAdapter &&to,
const PartitionNodeAttrName &item);
void MergeAllPath(const std::vector<std::shared_ptr<BaseCluster>> &path_clusters);
void Clear();
int32_t GetTypeIndex() const {
return type_index_;
}
void SetTypeIndex(int32_t type_index) {
type_index_ = type_index;
}
void SetMergeInputs(bool merge_inputs);
std::string GetPartitionedCallName() {
if (partition_node_ != nullptr) {
return partition_node_->GetName();
}
return "";
}
Status SetGraphId(const uint32_t graph_id) const;
protected:
std::vector<NodePtr> &GetMutableDirectNode();
size_t GetSubgraphNodesSize() const { return nodes_.size(); }
const std::unordered_map<InDataAnchorPtr, size_t> &GetInputIndexs() const { return inputs_index_; }
const std::unordered_map<OutDataAnchorPtr, size_t> &GetOutputIndexs() const { return outputs_index_; }
ComputeGraphPtr &GetMutableSubgraph() { return subgraph_; }
const ComputeGraphPtr &GetSubgraph() const { return subgraph_; }
const NodePtr &GetPartitionNode() const { return partition_node_; }
const BasePartitioner *GetPartitioner() const { return partitioner_; };
void UpdateInOutClusters(const std::unordered_set<BaseCluster *> &path_cluster_set);
void MergeInOutClusters(const std::unordered_set<BaseCluster *> &path_cluster_set,
const std::vector<BaseCluster *> &ordered_path_clusters);
private:
static thread_local size_t unique_id_;
int32_t type_index_;
size_t id_;
size_t min_;
size_t max_;
std::vector<BaseCluster *> in_clusters_;
std::vector<BaseCluster *> out_clusters_;
std::vector<InDataAnchorPtr> inputs_;
std::map<size_t, std::vector<InDataAnchorPtr>> data_to_node_inputs_;
std::map<std::string, size_t> src_key_to_frame_input_;
std::vector<OutDataAnchorPtr> outputs_;
std::unordered_set<std::shared_ptr<BaseCluster>> control_inputs_;
std::unordered_set<OutControlAnchorPtr> control_outputs_;
std::vector<NodePtr> nodes_;
std::unordered_map<InDataAnchorPtr, size_t> inputs_index_;
std::unordered_map<OutDataAnchorPtr, size_t> outputs_index_;
BasePartitioner *partitioner_{nullptr};
ComputeGraphPtr subgraph_{nullptr};
NodePtr partition_node_{nullptr};
bool merge_inputs_{false};
};
}
#define REGISTER_PARTITION_ATTR_NAME(attr_name, is_support_tensor, is_need_copy, set_attr_fn, get_attr_fn) \
static ge::PartitionNodeAttrRegister g_##attr_name##_register(attr_name, is_support_tensor, is_need_copy, \
set_attr_fn, get_attr_fn)
#define SET_ATTR_UTIL_IMP(ArgType, AttrUtilsFun) \
static bool Set##AttrUtilsFun##Attr(ge::AttrUtils::AttrHolderAdapter &&obj, const std::string &attr_name, \
const ge::AnyValue &value) { \
ArgType val; \
(void)value.GetValue(val); \
return ge::AttrUtils::Set##AttrUtilsFun(obj.get(), attr_name, val); \
}
#define GET_ATTR_UTIL_IMP(ArgType, AttrUtilsFun) \
static bool Get##AttrUtilsFun##Attr(AttrUtils::ConstAttrHolderAdapter &&obj, const std::string &attr, \
AnyValue &value) { \
ArgType val; \
const auto res = AttrUtils::Get##AttrUtilsFun(obj.get(), attr, val); \
value.SetValue(val); \
return res; \
}
#endif