* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "common/tbe_handle_store/tbe_handle_store.h"
#include <limits>
#include "framework/common/ge_inner_error_codes.h"
#include "framework/common/debug/log.h"
#include "runtime/kernel.h"
#include "common/plugin/ge_make_unique_util.h"
#include "common/checker.h"
#include "base/err_msg.h"
namespace ge {
// Unregisters the device binary referenced by handle_ (when one is held) through
// rtDevBinaryUnRegister and then clears the pointer.
TbeHandleInfo::~TbeHandleInfo() {
  if (handle_ == nullptr) {
    // Nothing was registered, so there is nothing to release.
    return;
  }
  GE_CHK_RT(rtDevBinaryUnRegister(handle_));
  handle_ = nullptr;
}
// Adds `num` to the usage counter.
// If the addition would go past the uint32_t maximum, an error is reported and logged
// and the counter is left unchanged.
void TbeHandleInfo::used_inc(const uint32_t num) {
  // Largest counter value that can still absorb `num` without wrapping around.
  const uint32_t highest_safe = std::numeric_limits<uint32_t>::max() - num;
  if (used_ <= highest_safe) {
    used_ += num;
    return;
  }
  REPORT_INNER_ERR_MSG("E19999", "Used:%u reach numeric max", used_);
  GELOGE(INTERNAL_ERROR, "[Check][Param] Used[%u] reach numeric max.", used_);
}
// Subtracts `num` from the usage counter.
// If the counter is smaller than `num` (the subtraction would drop below the uint32_t
// minimum), an error is reported and logged and the counter is left unchanged.
void TbeHandleInfo::used_dec(const uint32_t num) {
  // Smallest counter value from which `num` can be removed without underflow.
  const uint32_t lowest_safe = std::numeric_limits<uint32_t>::min() + num;
  if (used_ >= lowest_safe) {
    used_ -= num;
    return;
  }
  REPORT_INNER_ERR_MSG("E19999", "Used:%u reach numeric min", used_);
  GELOGE(INTERNAL_ERROR, "[Check][Param] Used[%u] reach numeric min.", used_);
}
// Returns the current value of the usage counter (used_), as maintained by used_inc/used_dec.
uint32_t TbeHandleInfo::used_num() const {
return used_;
}
// Returns the stored binary handle (handle_). The pointer is handed out as-is; this object's
// destructor still unregisters it, so the caller does not take over that responsibility.
void *TbeHandleInfo::handle() const {
return handle_;
}
// Out-of-class definition of TBEHandleStore's static recursive mutex. The member functions in this
// file lock it before touching bin_key_to_handle_ or handle_to_kernel_to_unique_id_.
std::recursive_mutex TBEHandleStore::mutex_;
// Returns a reference to the single TBEHandleStore object, held as a function-local static
// that is constructed the first time this function is called.
TBEHandleStore &TBEHandleStore::GetInstance() {
static TBEHandleStore instance;
return instance;
}
// Looks up the entry stored under `name`.
// On a hit, writes the stored binary handle into `handle` and returns true.
// On a miss, returns false and leaves `handle` untouched.
bool TBEHandleStore::FindTBEHandle(const std::string &name, void *&handle) {
  const std::lock_guard<std::recursive_mutex> lock(mutex_);
  const auto entry = bin_key_to_handle_.find(name);
  const bool found = (entry != bin_key_to_handle_.end());
  if (found) {
    handle = entry->second->handle();
  }
  return found;
}
// Records a binary handle under `name`, or adds one more use to an entry that already exists.
//  - Existing entry: only its usage counter is incremented; the `handle` and `kernel` passed in
//    are not stored.
//  - New entry: a TbeHandleInfo is created from `handle` and `kernel`, its usage counter is
//    incremented once, and it is inserted into bin_key_to_handle_.
// Returns SUCCESS on both paths. GE_ASSERT_NOTNULL is defined elsewhere; presumably it makes the
// function return a failure status when the allocation yields nullptr - confirm against the macro.
Status TBEHandleStore::StoreTBEHandle(const std::string &name, void *handle, const OpKernelBinPtr &kernel) {
  const std::lock_guard<std::recursive_mutex> lock(mutex_);
  const auto existing = bin_key_to_handle_.find(name);
  if (existing != bin_key_to_handle_.end()) {
    existing->second->used_inc();
    return SUCCESS;
  }

  std::unique_ptr<TbeHandleInfo> handle_info = ge::MakeUnique<TbeHandleInfo>(handle, kernel);
  GE_ASSERT_NOTNULL(handle_info);
  handle_info->used_inc();
  (void)bin_key_to_handle_.emplace(name, std::move(handle_info));
  return SUCCESS;
}
// Returns the address of the byte associated with the (handle, kernel) pair.
// The per-handle map is created on demand by operator[]. If no entry exists for `kernel`, one is
// added holding a single-element list whose value is 0. `inserted` is set to true when this call
// created the entry and to false when it already existed.
// The returned pointer refers to the last element of that entry's list.
void *TBEHandleStore::GetUniqueIdPtr(void *const handle, const std::string &kernel, bool &inserted) {
  const std::lock_guard<std::recursive_mutex> lock(mutex_);
  auto &kernel_to_ids = handle_to_kernel_to_unique_id_[handle];
  const auto outcome = kernel_to_ids.emplace(kernel, std::list<uint8_t>{0});
  inserted = outcome.second;
  std::list<uint8_t> &ids = outcome.first->second;
  return &(ids.back());
}
// Adds one use to the entry stored under `name`.
// When no such entry exists, an error is reported and logged and nothing else happens.
void TBEHandleStore::ReferTBEHandle(const std::string &name) {
  const std::lock_guard<std::recursive_mutex> lock(mutex_);
  const auto entry = bin_key_to_handle_.find(name);
  if (entry != bin_key_to_handle_.end()) {
    entry->second->used_inc();
    return;
  }
  REPORT_INNER_ERR_MSG("E19999", "Kernel:%s not found in stored check invalid", name.c_str());
  GELOGE(INTERNAL_ERROR, "[Check][Param] Kernel[%s] not found in stored.", name.c_str());
}
// Releases uses for a batch of kernels. `names` maps each kernel name to the number of uses to drop.
// For every name:
//  - not stored: an error is reported and logged, and processing moves on to the next name;
//  - stored with a usage counter greater than the requested count: the counter is reduced;
//  - otherwise: the entry is removed, together with the unique-id bookkeeping kept for its handle.
void TBEHandleStore::EraseTBEHandle(const std::map<std::string, uint32_t> &names) {
  const std::lock_guard<std::recursive_mutex> lock(mutex_);
  for (const auto &name_and_count : names) {
    const std::string &kernel_name = name_and_count.first;
    const uint32_t release_count = name_and_count.second;

    const auto entry = bin_key_to_handle_.find(kernel_name);
    if (entry == bin_key_to_handle_.end()) {
      REPORT_INNER_ERR_MSG("E19999", "Kernel:%s not found in stored check invalid", kernel_name.c_str());
      GELOGE(INTERNAL_ERROR, "[Check][Param] Kernel[%s] not found in stored.", kernel_name.c_str());
      continue;
    }

    auto &info = *(entry->second);
    if (info.used_num() <= release_count) {
      // Drop the unique-id records first: `info` is owned by the map entry erased on the next line.
      (void)handle_to_kernel_to_unique_id_.erase(info.handle());
      (void)bin_key_to_handle_.erase(entry);
      continue;
    }
    info.used_dec(release_count);
  }
}
// Stores the stub function pointer and a shared reference to the kernel binary.
// bin_handle_ is not set here although the destructor reads it.
// NOTE(review): presumably it has a default initializer in the class declaration - confirm.
KernelHolder::KernelHolder(const char_t *const stub_func,
const std::shared_ptr<ge::OpKernelBin> &kernel_bin)
: stub_func_(stub_func), kernel_bin_(kernel_bin) {}
// Unregisters the device binary referenced by bin_handle_ through rtDevBinaryUnRegister,
// provided a handle is held.
KernelHolder::~KernelHolder() {
  if (bin_handle_ == nullptr) {
    return;
  }
  GE_CHK_RT(rtDevBinaryUnRegister(bin_handle_));
}
// Stores the given binary handle in bin_handle_; the destructor unregisters it when it is non-null.
HandleHolder::HandleHolder(void *const bin_handle)
: bin_handle_(bin_handle) {}
// Unregisters the device binary referenced by bin_handle_ through rtDevBinaryUnRegister,
// provided a handle is held.
HandleHolder::~HandleHolder() {
  if (bin_handle_ == nullptr) {
    return;
  }
  GE_CHK_RT(rtDevBinaryUnRegister(bin_handle_));
}
// Returns the C string of the copy of `stub_func` kept in unique_stubs_, inserting that copy
// first when the name has not been seen before. Repeated calls with equal names therefore
// return the string owned by the same stored element.
const char_t *KernelBinRegistry::GetUnique(const std::string &stub_func) {
  const std::lock_guard<std::mutex> lock(mutex_);
  auto pos = unique_stubs_.find(stub_func);
  if (pos == unique_stubs_.end()) {
    pos = unique_stubs_.insert(unique_stubs_.cend(), stub_func);
  }
  return pos->c_str();
}
// Returns the stub function pointer of the kernel registered under `stub_name`,
// or nullptr when no kernel is registered under that name.
const char_t *KernelBinRegistry::GetStubFunc(const std::string &stub_name) {
  const std::lock_guard<std::mutex> lock(mutex_);
  const char_t *stub = nullptr;
  const auto entry = registered_bins_.find(stub_name);
  if (entry != registered_bins_.end()) {
    stub = entry->second->stub_func_;
  }
  return stub;
}
// Registers `holder` under `stub_name`.
// Returns true when a new entry was inserted, and false when the name was already registered
// (in which case the existing entry is kept).
bool KernelBinRegistry::AddKernel(const std::string &stub_name, std::unique_ptr<KernelHolder> &&holder) {
  const std::lock_guard<std::mutex> lock(mutex_);
  return registered_bins_.emplace(stub_name, std::move(holder)).second;
}
// Adds `holder` to registered_handles_ and returns whether an element was actually inserted.
// NOTE(review): unlike KernelBinRegistry::AddKernel, no lock is taken here - confirm that callers
// serialize access to this registry.
bool HandleRegistry::AddHandle(std::unique_ptr<HandleHolder> &&holder) {
  return registered_handles_.emplace(std::move(holder)).second;
}
// Returns a reference to the single KernelBinRegistry object, held as a function-local static
// that is constructed the first time this function is called.
KernelBinRegistry &KernelBinRegistry::GetInstance() {
static KernelBinRegistry instance;
return instance;
}
// Returns a reference to the single HandleRegistry object, held as a function-local static
// that is constructed the first time this function is called.
HandleRegistry &HandleRegistry::GetInstance() {
static HandleRegistry instance;
return instance;
}
}