* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef CANN_GRAPH_ENGINE_PASS_REGISTRY_H
#define CANN_GRAPH_ENGINE_PASS_REGISTRY_H
#include <cstdint>
#include <map>
#include <string>
#include <vector>
#include "ge/fusion/pass/fusion_pass_reg.h"
namespace ge {
namespace fusion {
/// Kind tag attached to a Python pass; stored as a 32-bit unsigned value.
/// NOTE(review): the meaning of each kind is inferred from its name only - confirm against the Python binding.
enum class PythonPassKind : uint32_t {
  kFusionBase = 0U,     ///< Zero value; used as the in-class default by the structs in this header.
  kPatternFusion = 1U,
  kDecompose = 2U,
};
/// Plain aggregate carrying a descriptor key, a pass name and a pass kind.
/// Member order is significant: the struct is brace-initialized positionally elsewhere in this header.
struct PythonPassCreateContext {
  std::string descriptor_key;
  std::string pass_name;
  PythonPassKind kind{PythonPassKind::kFusionBase};

  /// @return true only when both descriptor_key and pass_name are non-empty; kind is not checked.
  bool IsValid() const {
    const bool has_missing_field = descriptor_key.empty() || pass_name.empty();
    return !has_missing_field;
  }
};
/// Descriptor record for a Python pass.
/// NOTE(review): field meanings are inferred from their names (module/class presumably locate the Python
/// implementation, op_types presumably lists the operator types of interest) - confirm with the registration code.
struct PythonPassDescriptor {
  std::string descriptor_key;
  std::string pass_name;
  std::string module_name;
  std::string class_name;
  CustomPassStage stage{CustomPassStage::kBeforeInferShape};
  PythonPassKind kind{PythonPassKind::kFusionBase};
  std::vector<std::string> op_types;

  /// @return true only when both descriptor_key and pass_name are non-empty; no other field is checked.
  bool IsValid() const {
    if (descriptor_key.empty()) {
      return false;
    }
    return !pass_name.empty();
  }

  /// Builds a create context holding copies of this descriptor's key, pass name and kind.
  PythonPassCreateContext ToCreateContext() const {
    PythonPassCreateContext create_context;
    create_context.descriptor_key = descriptor_key;
    create_context.pass_name = pass_name;
    create_context.kind = kind;
    return create_context;
  }
};
/// Non-copyable scope object constructed from a PythonPassCreateContext.
/// NOTE(review): the members suggest it remembers a previously current context so the destructor can
/// restore it; the constructor/destructor bodies are not visible in this header - confirm.
class PythonPassCreateScope {
 public:
  // Copying is disabled (and therefore no implicit move operations are generated either).
  PythonPassCreateScope(const PythonPassCreateScope &) = delete;
  PythonPassCreateScope &operator=(const PythonPassCreateScope &) = delete;

  explicit PythonPassCreateScope(const PythonPassCreateContext &create_context);
  ~PythonPassCreateScope();

 private:
  bool has_previous_{false};                  // defaults to false until the constructor decides otherwise
  PythonPassCreateContext previous_context_;  // default-constructed (empty strings, kFusionBase)
};
// Free functions operating on the "current" Python pass create context.
// NOTE(review): where that context is stored (global, thread-local, ...) is decided by the
// implementation file and cannot be determined from this header - confirm before relying on it.
//
// Presumably makes create_context the current context - confirm.
void SetCurrentPythonPassCreateContext(const PythonPassCreateContext &create_context);
// Takes an output reference and returns bool; presumably copies the current context into
// create_context and reports whether one was available - confirm.
bool GetCurrentPythonPassCreateContext(PythonPassCreateContext &create_context);
// Presumably resets the current context so that none is set - confirm.
void ClearCurrentPythonPassCreateContext();
/// Converts a CustomPassStage value to its display name.
/// @param stage stage to convert.
/// @return the name from the lookup table; an empty string when stage compares greater than
///         CustomPassStage::kInvalid or when stage has no entry in the table.
inline std::string CustomPassStageToString(CustomPassStage stage) {
  static const std::map<CustomPassStage, std::string> kCustomPassStageToStringMap = {
      {CustomPassStage::kBeforeInferShape, "BeforeInferShape"},
      {CustomPassStage::kAfterInferShape, "AfterInferShape"},
      {CustomPassStage::kAfterAssignLogicStream, "AfterAssignLogicStream"},
      {CustomPassStage::kAfterBuiltinFusionPass, "AfterBuiltinFusionPass"},
      {CustomPassStage::kAfterOriginGraphOptimize, "AfterOriginGraphOptimize"},
      {CustomPassStage::kCompatibleInherited, "CompatibleInherited"},
      {CustomPassStage::kInvalid, "InvalidStage"}
  };
  if (stage > CustomPassStage::kInvalid) {
    return "";
  }
  // The range check above does not guarantee an entry exists (e.g. an enumerator added later
  // without a table row, or a value cast from an integer). find() then yields end(), and
  // dereferencing end() is undefined behavior, so fall back to the empty string instead.
  const auto iter = kCustomPassStageToStringMap.find(stage);
  if (iter == kCustomPassStageToStringMap.end()) {
    return "";
  }
  return iter->second;
}
/// Registry holding fusion pass registration data and Python pass descriptors.
/// The constructor is private; callers obtain the registry through GetInstance().
/// NOTE(review): the per-method notes below are inferred from names and signatures only - the
/// implementation is not part of this header, so confirm details (ownership, locking, failure
/// conditions) against the source file.
class PassRegistry {
 public:
  ~PassRegistry();

  /// Returns a reference to the registry instance.
  static PassRegistry &GetInstance();

  /// Registers a fusion pass; takes the registration data by non-const reference.
  void RegisterFusionPass(FusionPassRegistrationData &reg_data);

  /// Registers a Python pass descriptor together with its creation function.
  /// Returns bool - presumably false when registration is rejected; confirm.
  bool RegisterPythonPass(const PythonPassDescriptor &pass_desc, const CreateFusionPassFn &create_fusion_pass_fn);

  /// Returns, by value, the registration data associated with the given stage.
  std::vector<FusionPassRegistrationData> GetFusionPassRegDataByStage(CustomPassStage stage) const;

  /// Looks up a Python pass descriptor by descriptor key; the result is written to pass_desc.
  /// Returns bool - presumably whether the key was found; confirm.
  bool GetPythonPassDescriptor(const std::string &descriptor_key, PythonPassDescriptor &pass_desc) const;

  /// Presumably resolves the descriptor that matches the current Python pass create context; confirm.
  bool ResolveCurrentPythonPassDescriptor(PythonPassDescriptor &pass_desc) const;

  /// Creates a pass from registration data and returns a raw pointer.
  /// NOTE(review): ownership of the returned pointer is not expressed by the type - confirm who deletes it.
  FusionBasePass *CreatePass(const FusionPassRegistrationData &reg_data) const;

  /// Presumably removes all registered Python passes; confirm.
  void ClearPythonPasses();

 private:
  PassRegistry();

  /// Looks up the create context stored for pass_name; the result is written to create_context.
  bool GetPythonPassCreateContext(const std::string &pass_name, PythonPassCreateContext &create_context) const;

  std::map<std::string, FusionPassRegistrationData> name_2_fusion_pass_regs_;
  std::map<std::string, PythonPassDescriptor> descriptor_key_2_python_pass_descs_;
  std::map<std::string, PythonPassCreateContext> pass_name_2_python_pass_create_contexts_;
};
}
}
#endif