* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "python_pass_adapter.h"
#include <memory>
#include <mutex>
#include <utility>
#include <vector>
#include "graph_metadef/graph/debug/ge_util.h"
#include "framework/common/debug/ge_log.h"
namespace ge {
namespace fusion {
namespace {
struct PythonFusionPassRuntimeEntry {
explicit PythonFusionPassRuntimeEntry(PythonPassDescriptor desc, PythonFusionPassCallbacks cb)
: pass_desc(std::move(desc)), callbacks(cb) {}
PythonPassDescriptor pass_desc;
PythonFusionPassCallbacks callbacks;
size_t active_adapter_count{0U};
std::mutex mutex;
};
class PythonFusionPassRuntimeRegistryImpl {
public:
bool Register(const PythonPassDescriptor &pass_desc, const PythonFusionPassCallbacks &callbacks) {
if (!callbacks.IsValid(pass_desc.kind)) {
GELOGW("Register python pass runtime failed, descriptor key[%s], kind[%u].",
pass_desc.descriptor_key.c_str(), static_cast<uint32_t>(pass_desc.kind));
return false;
}
std::lock_guard<std::mutex> lock(mutex_);
if (descriptor_key_2_runtime_entry_.find(pass_desc.descriptor_key) != descriptor_key_2_runtime_entry_.cend()) {
GELOGI("Python pass runtime descriptor key [%s] has already registered.", pass_desc.descriptor_key.c_str());
return false;
}
descriptor_key_2_runtime_entry_.emplace(
pass_desc.descriptor_key, ge::ComGraphMakeShared<PythonFusionPassRuntimeEntry>(pass_desc, callbacks));
return true;
}
bool Unregister(const std::string &descriptor_key) {
const std::lock_guard<std::mutex> map_lock(mutex_);
const auto iter = descriptor_key_2_runtime_entry_.find(descriptor_key);
if (iter == descriptor_key_2_runtime_entry_.cend()) {
return false;
}
const auto &runtime_entry = iter->second;
std::lock_guard<std::mutex> runtime_lock(runtime_entry->mutex);
if (runtime_entry->active_adapter_count != 0U) {
GELOGW("Python pass runtime descriptor key [%s] is still in use.", descriptor_key.c_str());
return false;
}
descriptor_key_2_runtime_entry_.erase(iter);
return true;
}
bool Acquire(const PythonPassDescriptor &pass_desc, PythonFusionPassCallbacks &callbacks) {
auto runtime_entry = Get(pass_desc.descriptor_key);
if (runtime_entry == nullptr) {
GELOGW("Acquire python pass runtime failed because descriptor key [%s] is not registered.",
pass_desc.descriptor_key.c_str());
return false;
}
std::lock_guard<std::mutex> lock(runtime_entry->mutex);
++runtime_entry->active_adapter_count;
callbacks = runtime_entry->callbacks;
return true;
}
void Release(const PythonPassDescriptor &pass_desc) {
auto runtime_entry = Get(pass_desc.descriptor_key);
if (runtime_entry == nullptr) {
return;
}
std::lock_guard<std::mutex> lock(runtime_entry->mutex);
if (runtime_entry->active_adapter_count == 0U) {
return;
}
--runtime_entry->active_adapter_count;
}
void Clear() {
std::lock_guard<std::mutex> lock(mutex_);
descriptor_key_2_runtime_entry_.clear();
}
private:
std::shared_ptr<PythonFusionPassRuntimeEntry> Get(const std::string &descriptor_key) {
std::lock_guard<std::mutex> lock(mutex_);
const auto iter = descriptor_key_2_runtime_entry_.find(descriptor_key);
if (iter == descriptor_key_2_runtime_entry_.cend()) {
return nullptr;
}
return iter->second;
}
std::mutex mutex_;
std::map<std::string, std::shared_ptr<PythonFusionPassRuntimeEntry>> descriptor_key_2_runtime_entry_;
};
// Returns the single file-local registry implementation, constructed on first call as a
// function-local static.
PythonFusionPassRuntimeRegistryImpl &GetPythonFusionPassRuntimeRegistryImpl() {
  static PythonFusionPassRuntimeRegistryImpl impl_instance;
  return impl_instance;
}
}
// Copies the descriptor, acquires the registered callbacks for it and asks the `create`
// callback for a holder. `valid_` is set only when all three steps succeed; when acquisition
// succeeded but no holder could be created, the acquisition is released again.
PythonPassHolder::PythonPassHolder(const PythonPassDescriptor &pass_desc) : pass_desc_(pass_desc) {
  auto &runtime_registry = PythonFusionPassRuntimeRegistry::GetInstance();
  if (!runtime_registry.Acquire(pass_desc_, callbacks_)) {
    return;
  }
  if (callbacks_.create != nullptr) {
    holder_ = callbacks_.create(&pass_desc_);
    if (holder_ != nullptr) {
      valid_ = true;
      return;
    }
  }
  // Either there is no create callback or it produced no holder: give the acquisition back.
  runtime_registry.Release(pass_desc_);
}
// For a valid holder: hands the holder to the `destroy` callback (when both are non-null) and
// releases the registry acquisition taken in the constructor. An invalid holder does nothing.
PythonPassHolder::~PythonPassHolder() {
  if (!valid_) {
    return;
  }
  const bool can_destroy = (holder_ != nullptr) && (callbacks_.destroy != nullptr);
  if (can_destroy) {
    callbacks_.destroy(holder_);
    holder_ = nullptr;
  }
  PythonFusionPassRuntimeRegistry::GetInstance().Release(pass_desc_);
}
// Returns `valid_`, which the constructor sets once callbacks were acquired and a holder created.
bool PythonPassHolder::IsValid() const {
  return valid_;
}
// Returns the opaque handle stored in `holder_` (the value the `create` callback returned).
void *PythonPassHolder::GetHolder() const {
  return holder_;
}
// Returns a reference to the callback table this holder received from the runtime registry.
const PythonFusionPassCallbacks &PythonPassHolder::GetCallbacks() const {
  return callbacks_;
}
// Returns a reference to the descriptor copy taken at construction.
const PythonPassDescriptor &PythonPassHolder::GetPassDescriptor() const {
  return pass_desc_;
}
// Returns the facade singleton, a function-local static created on first use.
PythonFusionPassRuntimeRegistry &PythonFusionPassRuntimeRegistry::GetInstance() {
  static PythonFusionPassRuntimeRegistry instance;
  return instance;
}
// Forwards to the file-local registry implementation and returns its result unchanged.
bool PythonFusionPassRuntimeRegistry::Register(const PythonPassDescriptor &pass_desc,
                                               const PythonFusionPassCallbacks &callbacks) {
  auto &impl = GetPythonFusionPassRuntimeRegistryImpl();
  return impl.Register(pass_desc, callbacks);
}
// Forwards to the file-local registry implementation and returns its result unchanged.
bool PythonFusionPassRuntimeRegistry::Unregister(const std::string &descriptor_key) {
  auto &impl = GetPythonFusionPassRuntimeRegistryImpl();
  return impl.Unregister(descriptor_key);
}
// Forwards to the file-local registry implementation; `callbacks` is the output parameter
// filled by the implementation on success.
bool PythonFusionPassRuntimeRegistry::Acquire(const PythonPassDescriptor &pass_desc,
                                              PythonFusionPassCallbacks &callbacks) {
  auto &impl = GetPythonFusionPassRuntimeRegistryImpl();
  return impl.Acquire(pass_desc, callbacks);
}
// Forwards to the file-local registry implementation.
void PythonFusionPassRuntimeRegistry::Release(const PythonPassDescriptor &pass_desc) {
  auto &impl = GetPythonFusionPassRuntimeRegistryImpl();
  impl.Release(pass_desc);
}
// Forwards to the file-local registry implementation, which drops all registered entries.
void PythonFusionPassRuntimeRegistry::Clear() {
  auto &impl = GetPythonFusionPassRuntimeRegistryImpl();
  impl.Clear();
}
PythonFusionBasePassAdapter::PythonFusionBasePassAdapter(const PythonPassDescriptor &pass_desc)
: holder_(new (std::nothrow) PythonPassHolder(pass_desc)) {}
// Defaulted destructor: no explicit cleanup is performed here.
// NOTE(review): the constructor allocates `holder_` with `new`; this relies on `holder_` being an
// owning smart pointer declared in the header — confirm.
PythonFusionBasePassAdapter::~PythonFusionBasePassAdapter() = default;
// Runs the python pass on `graph` through the registered `run` callback.
// Returns FAILED (after recording an error message on `pass_context`) when the holder is missing
// or invalid, or when either the python-side handle or the `run` callback is null; otherwise
// returns whatever the callback returns.
Status PythonFusionBasePassAdapter::Run(GraphPtr &graph, CustomPassContext &pass_context) {
  const bool holder_usable = (holder_ != nullptr) && holder_->IsValid();
  if (!holder_usable) {
    pass_context.SetErrorMessage("python fusion base adapter is invalid");
    return FAILED;
  }

  const auto &run_callback = holder_->GetCallbacks().run;
  void *const python_pass = holder_->GetHolder();
  const bool callable = (python_pass != nullptr) && (run_callback != nullptr);
  if (!callable) {
    pass_context.SetErrorMessage("python fusion base adapter callback is invalid");
    return FAILED;
  }
  return run_callback(python_pass, graph, pass_context);
}
// True only when a holder was allocated and that holder reports itself valid.
bool PythonFusionBasePassAdapter::IsValid() const {
  if (holder_ == nullptr) {
    return false;
  }
  return holder_->IsValid();
}
namespace {
// Allocates an AdapterT for `pass_desc` (non-throwing) and returns it as a raw owning pointer
// only when allocation succeeded and the adapter reports IsValid(). Otherwise the adapter, if
// any, is destroyed and nullptr is returned. The unique_ptr owns the object until it is known
// to be valid, so no path can leak it.
template <typename AdapterT>
FusionBasePass *CreateValidatedPythonPassAdapter(const PythonPassDescriptor &pass_desc) {
  std::unique_ptr<AdapterT> adapter(new (std::nothrow) AdapterT(pass_desc));
  if ((adapter == nullptr) || (!adapter->IsValid())) {
    return nullptr;
  }
  return adapter.release();
}
}  // namespace

// Factory used as the creator function for python passes: resolves the descriptor of the python
// pass currently being created from PassRegistry and builds the adapter matching its kind.
// Returns a raw pointer owned by the caller, or nullptr when the descriptor cannot be resolved,
// the kind is unsupported, or the adapter could not be allocated or is invalid.
FusionBasePass *CreatePythonPassAdapter() {
  PythonPassDescriptor pass_desc;
  if (!PassRegistry::GetInstance().ResolveCurrentPythonPassDescriptor(pass_desc)) {
    GELOGW("Create python pass adapter failed because current python pass descriptor is missing.");
    return nullptr;
  }
  switch (pass_desc.kind) {
    case PythonPassKind::kFusionBase:
      return CreateValidatedPythonPassAdapter<PythonFusionBasePassAdapter>(pass_desc);
    case PythonPassKind::kPatternFusion:
      return CreateValidatedPythonPassAdapter<PythonPatternFusionPassAdapter>(pass_desc);
    case PythonPassKind::kDecompose:
      return CreateValidatedPythonPassAdapter<PythonDecomposePassAdapter>(pass_desc);
    default:
      GELOGW("Create python pass adapter failed because pass kind[%u] is unsupported.",
             static_cast<uint32_t>(pass_desc.kind));
      return nullptr;
  }
}
// Registers a python pass in two steps: first its callbacks in the runtime registry, then the
// pass itself (with CreatePythonPassAdapter as creator) in PassRegistry. If the second step
// fails, the runtime registration is rolled back on a best-effort basis (the Unregister result
// is deliberately ignored). Returns true only when both steps succeed.
bool RegisterPythonPass(const PythonPassDescriptor &pass_desc,
                        const PythonFusionPassCallbacks &callbacks) {
  auto &runtime_registry = PythonFusionPassRuntimeRegistry::GetInstance();
  if (!runtime_registry.Register(pass_desc, callbacks)) {
    return false;
  }
  if (!PassRegistry::GetInstance().RegisterPythonPass(pass_desc, CreatePythonPassAdapter)) {
    (void)runtime_registry.Unregister(pass_desc.descriptor_key);
    return false;
  }
  return true;
}
// Free-function entry point that empties the python pass runtime registry.
void ClearPythonPassRuntimeRegistry() {
  auto &runtime_registry = PythonFusionPassRuntimeRegistry::GetInstance();
  runtime_registry.Clear();
}
}
}