* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef CANN_GRAPH_ENGINE_PYTHON_PASS_ADAPTER_H
#define CANN_GRAPH_ENGINE_PYTHON_PASS_ADAPTER_H
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "ge/ge_api_types.h"
#include "ge/fusion/pass/decompose_pass.h"
#include "ge/fusion/pass/fusion_base_pass.h"
#include "ge/fusion/pass/pattern_fusion_pass.h"
#include "pass_registry.h"
namespace ge {
namespace fusion {
// Function-pointer signatures for the callbacks that back a Python-defined pass.
// Every callback except the create function takes an opaque `holder` pointer as its first argument.
// NOTE(review): the exact Status values and the meaning of the bool results are defined by the
// implementations, which are not visible in this header.

// Takes a pointer to a pass descriptor and returns an opaque holder pointer.
using PythonFusionBasePassHolderCreateFn = void *(*)(const PythonPassDescriptor *pass_desc);
// Takes a (mutable) holder pointer and returns nothing; by name, the counterpart of the create callback.
using PythonFusionBasePassHolderDestroyFn = void (*)(void *holder);
// Same parameter list as FusionBasePass::Run, preceded by the holder.
using PythonFusionBasePassRunFn = Status (*)(const void *holder, GraphPtr &graph, CustomPassContext &pass_context);
// Hands a matcher configuration back through the `matcher_config` out-parameter.
using PythonFusionPassGetMatcherConfigFn =
Status (*)(const void *holder, std::unique_ptr<PatternMatcherConfig> &matcher_config);
// Hands a list of patterns back through the `patterns` out-parameter.
using PythonFusionPassPatternsFn = Status (*)(const void *holder, std::vector<PatternUniqPtr> &patterns);
// Predicate over a match result (pattern-fusion flavour).
using PythonFusionPassMeetRequirementsFn = bool (*)(const void *holder,
const std::unique_ptr<MatchResult> &match_result);
// Hands a replacement graph for a match result back through `replacement_graph`.
using PythonFusionPassReplacementFn =
Status (*)(const void *holder, const std::unique_ptr<MatchResult> &match_result, GraphUniqPtr &replacement_graph);
// Predicate over a single matched node (decompose flavour).
using PythonDecomposePassMeetRequirementsFn = bool (*)(const void *holder, const GNode &matched_node);
// Hands a replacement graph for a single matched node back through `replacement_graph`.
using PythonDecomposePassReplacementFn =
Status (*)(const void *holder, const GNode &matched_node, GraphUniqPtr &replacement_graph);
// Table of callbacks backing a Python-defined pass. Every entry defaults to nullptr;
// which entries must be populated depends on the pass kind (see IsValid).
struct PythonFusionPassCallbacks {
  // Holder lifecycle callbacks: mandatory for every pass kind.
  PythonFusionBasePassHolderCreateFn create = nullptr;
  PythonFusionBasePassHolderDestroyFn destroy = nullptr;
  // Required when the kind is kFusionBase.
  PythonFusionBasePassRunFn run = nullptr;
  // Pattern-fusion callbacks: `patterns` and `replacement` are required for kPatternFusion,
  // while `get_matcher_config` and `meet_requirements` are not checked by IsValid.
  PythonFusionPassGetMatcherConfigFn get_matcher_config = nullptr;
  PythonFusionPassPatternsFn patterns = nullptr;
  PythonFusionPassMeetRequirementsFn meet_requirements = nullptr;
  PythonFusionPassReplacementFn replacement = nullptr;
  // Decompose callbacks: only `decompose_replacement` is required for kDecompose.
  PythonDecomposePassMeetRequirementsFn decompose_meet_requirements = nullptr;
  PythonDecomposePassReplacementFn decompose_replacement = nullptr;

  // Returns true when the callbacks that `kind` depends on are all non-null.
  // The create/destroy pair is always required; an unrecognised kind yields false.
  bool IsValid(PythonPassKind kind) const {
    const bool has_lifecycle = (create != nullptr) && (destroy != nullptr);
    if (!has_lifecycle) {
      return false;
    }
    if (kind == PythonPassKind::kFusionBase) {
      return run != nullptr;
    }
    if (kind == PythonPassKind::kPatternFusion) {
      return (patterns != nullptr) && (replacement != nullptr);
    }
    if (kind == PythonPassKind::kDecompose) {
      return decompose_replacement != nullptr;
    }
    return false;
  }
};
// Singleton-style registry that associates pass descriptors with their callback tables.
// All operations are exposed as static functions; the instance itself can only be reached
// through GetInstance(), and it can be neither copied nor moved.
// NOTE(review): thread-safety and the exact Acquire/Release semantics are defined in the
// implementation file, which is not visible from this header.
class PythonFusionPassRuntimeRegistry {
 public:
  // Returns a reference to the registry instance.
  static PythonFusionPassRuntimeRegistry &GetInstance();
  // Associates `callbacks` with `pass_desc`; the bool result reports success (presumably
  // false on invalid callbacks or a duplicate entry; confirm in the implementation).
  static bool Register(const PythonPassDescriptor &pass_desc, const PythonFusionPassCallbacks &callbacks);
  // Removes the entry identified by `descriptor_key`; the bool result reports success.
  static bool Unregister(const std::string &descriptor_key);
  // Looks up the callbacks for `pass_desc` and writes them to `callbacks`; the bool result
  // reports success. By name, pairs with Release().
  static bool Acquire(const PythonPassDescriptor &pass_desc, PythonFusionPassCallbacks &callbacks);
  // Counterpart of Acquire() for the same descriptor.
  static void Release(const PythonPassDescriptor &pass_desc);
  // Drops the registry content.
  static void Clear();

  // The private constructor/destructor alone do not stop `new T(GetInstance())`, so the copy
  // and move operations are removed explicitly to keep the registry a true singleton.
  PythonFusionPassRuntimeRegistry(const PythonFusionPassRuntimeRegistry &) = delete;
  PythonFusionPassRuntimeRegistry &operator=(const PythonFusionPassRuntimeRegistry &) = delete;
  PythonFusionPassRuntimeRegistry(PythonFusionPassRuntimeRegistry &&) = delete;
  PythonFusionPassRuntimeRegistry &operator=(PythonFusionPassRuntimeRegistry &&) = delete;

 private:
  PythonFusionPassRuntimeRegistry() = default;
  ~PythonFusionPassRuntimeRegistry() = default;
};
// Bundles what is needed to drive one Python-backed pass: a pass descriptor (stored by value),
// a callback table, an opaque holder pointer, and a validity flag. Copying is disabled.
// NOTE(review): the constructor and destructor are defined elsewhere; presumably they pair the
// create/destroy callbacks (and registry Acquire/Release); confirm in the implementation.
class PythonPassHolder {
public:
// Builds a holder for the pass described by `pass_desc`.
explicit PythonPassHolder(const PythonPassDescriptor &pass_desc);
~PythonPassHolder();
// Non-copyable: the object carries a raw opaque pointer.
PythonPassHolder(const PythonPassHolder &) = delete;
PythonPassHolder &operator=(const PythonPassHolder &) = delete;
// Reports whether the holder is usable (presumably mirrors valid_; confirm).
bool IsValid() const;
// Accessors for the opaque holder pointer, the callback table and the descriptor.
void *GetHolder() const;
const PythonFusionPassCallbacks &GetCallbacks() const;
const PythonPassDescriptor &GetPassDescriptor() const;
private:
PythonPassDescriptor pass_desc_;
PythonFusionPassCallbacks callbacks_;
// Opaque pointer handed to the callbacks; null until set by the implementation.
void *holder_{nullptr};
// Starts out false; set by the implementation.
bool valid_{false};
};
// FusionBasePass subclass constructed from a Python pass descriptor.
// It owns a PythonPassHolder and overrides Run().
// NOTE(review): Run() presumably forwards to the `run` callback; confirm in the implementation.
class PythonFusionBasePassAdapter : public FusionBasePass {
public:
explicit PythonFusionBasePassAdapter(const PythonPassDescriptor &pass_desc);
~PythonFusionBasePassAdapter() override;
// Runs the pass on `graph` with the given context and returns a Status.
Status Run(GraphPtr &graph, CustomPassContext &pass_context) override;
// Reports whether the adapter is usable (criteria defined in the implementation).
bool IsValid() const;
private:
// Exclusively owned holder for the Python-side state.
std::unique_ptr<PythonPassHolder> holder_;
};
// PatternFusionPass subclass constructed from a Python pass descriptor.
// It owns a PythonPassHolder and overrides the Patterns / MeetRequirements / Replacement hooks.
// NOTE(review): the hooks presumably forward to the matching callbacks; confirm in the implementation.
class PythonPatternFusionPassAdapter : public PatternFusionPass {
public:
explicit PythonPatternFusionPassAdapter(const PythonPassDescriptor &pass_desc);
~PythonPatternFusionPassAdapter() override;
// Reports whether the adapter is usable (criteria defined in the implementation).
bool IsValid() const;
protected:
// Returns the patterns for this pass.
std::vector<PatternUniqPtr> Patterns() override;
// Predicate over a match result.
bool MeetRequirements(const std::unique_ptr<MatchResult> &match_result) override;
// Returns the replacement graph for a match result.
GraphUniqPtr Replacement(const std::unique_ptr<MatchResult> &match_result) override;
private:
// Private constructor taking an already-built holder together with a matcher configuration.
// NOTE(review): presumably the delegation target of the public constructor, so the matcher
// configuration can reach the base class; confirm in the implementation.
explicit PythonPatternFusionPassAdapter(
std::pair<std::unique_ptr<PythonPassHolder>, std::unique_ptr<PatternMatcherConfig>> init);
// Exclusively owned holder for the Python-side state.
std::unique_ptr<PythonPassHolder> holder_;
};
// DecomposePass subclass constructed from a Python pass descriptor.
// It owns a PythonPassHolder and overrides the MeetRequirements / Replacement hooks.
// NOTE(review): the hooks presumably forward to the decompose_* callbacks; confirm in the implementation.
class PythonDecomposePassAdapter : public DecomposePass {
public:
explicit PythonDecomposePassAdapter(const PythonPassDescriptor &pass_desc);
~PythonDecomposePassAdapter() override;
// Reports whether the adapter is usable (criteria defined in the implementation).
bool IsValid() const;
protected:
// Predicate over a single matched node.
bool MeetRequirements(const GNode &matched_node) override;
// Returns the replacement graph for a matched node.
GraphUniqPtr Replacement(const GNode &matched_node) override;
private:
// Exclusively owned holder for the Python-side state.
std::unique_ptr<PythonPassHolder> holder_;
};
// Returns a raw pointer to a FusionBasePass.
// NOTE(review): ownership of the returned pointer and the choice of concrete adapter are not
// visible here; confirm against the implementation and the pass-registry creator contract.
FusionBasePass *CreatePythonPassAdapter();
// Registers a Python pass described by `pass_desc` with its callback table; the bool result
// reports success.
bool RegisterPythonPass(const PythonPassDescriptor &pass_desc,
const PythonFusionPassCallbacks &callbacks);
// By name, clears the Python pass runtime registry (presumably via
// PythonFusionPassRuntimeRegistry::Clear(); confirm).
void ClearPythonPassRuntimeRegistry();
}
}
#endif