/* -------------------------------------------------------------------------
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is part of the MindStudio project.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
*/
#include "csrc/common/function_loader.h"
#include <dlfcn.h>
#include <iostream>
#include <set>
#include "csrc/common/utils.h"
namespace Mspti {
namespace Common {
// Creates a loader for the shared library identified by soName.
// The ".so" suffix is appended here, so soName is given without it.
// The library is not opened by the constructor.
FunctionLoader::FunctionLoader(const std::string& soName) : soName_(soName + ".so")
{
}
// Closes the library handle if one was opened; does nothing otherwise.
FunctionLoader::~FunctionLoader()
{
    if (handle_ == nullptr) {
        return;
    }
    dlclose(handle_);
}
// Registers funcName so that it can later be resolved through Get().
// A new name is stored with a null address. A name that is already present is
// left untouched, so an address cached by an earlier Get() is not discarded
// when the same function is registered again.
void FunctionLoader::Set(const std::string& funcName)
{
    // emplace() inserts only when the key is absent; it never overwrites.
    registry_.emplace(funcName, nullptr);
}
// Works out which path should be used to load this loader's library.
// Returns:
//   - ""       when soName_ is not one of the allowed library names below;
//   - soName_  when ASCEND_HOME_PATH is unset or empty, or when the file under
//              $ASCEND_HOME_PATH/lib64 does not exist or is not readable;
//   - otherwise the result of Utils::RealPath() for $ASCEND_HOME_PATH/lib64/<soName_>.
std::string FunctionLoader::CanonicalSoPath()
{
    // Only these libraries may be loaded through FunctionLoader.
    static const std::set<std::string> allowedLibs = {
        "libascend_hal.so",
        "libascendalog.so",
        "libascendcl.so",
        "libhccl.so",
        "libprofapi.so",
    };
    if (allowedLibs.count(soName_) == 0) {
        std::cout << soName_ << " was invalid." << std::endl;
        return "";
    }

    char *homeDir = std::getenv("ASCEND_HOME_PATH");
    const bool homeDirMissing = (homeDir == nullptr) || (homeDir[0] == '\0');
    if (homeDirMissing) {
        // No install root configured: hand back the bare library name.
        return soName_;
    }

    std::string candidate(homeDir);
    candidate += "/lib64/";
    candidate += soName_;
    auto resolved = Utils::RealPath(Utils::RelativeToAbsPath(candidate));
    if (Utils::FileExist(resolved) && Utils::FileReadable(resolved)) {
        return resolved;
    }
    return soName_;
}
// Returns the address of funcName in this loader's library, or nullptr on failure.
//
// The library is opened with dlopen() on the first call that finds no handle; the
// handle is kept for later calls. A resolved address is stored back into registry_
// and returned directly by later calls while that entry stays non-null.
//
// nullptr is returned when:
//   - CanonicalSoPath() rejects the library (empty path);
//   - dlopen() fails (the dlerror() text is printed to stdout);
//   - funcName was never registered through Set() (a message is printed to stdout);
//   - dlsym() does not find the symbol.
void *FunctionLoader::Get(const std::string& funcName)
{
    if (!handle_) {
        auto soPath = CanonicalSoPath();
        if (soPath.empty()) {
            return nullptr;
        }
        auto handle = dlopen(soPath.c_str(), RTLD_LAZY);
        if (handle == nullptr) {
            // dlerror() may return NULL, and streaming a null char pointer is
            // undefined behaviour, so fall back to a fixed message.
            const char *reason = dlerror();
            std::cout << (reason != nullptr ? reason : "unknown dlopen error") << std::endl;
            return nullptr;
        }
        handle_ = handle;
    }

    auto itr = registry_.find(funcName);
    if (itr == registry_.end()) {
        std::cout << "function(\"" << funcName << "\") is not registered." << std::endl;
        return nullptr;
    }
    if (itr->second) {
        // Already resolved by an earlier call.
        return itr->second;
    }

    auto func = dlsym(handle_, funcName.c_str());
    if (func == nullptr) {
        return nullptr;
    }
    // Reuse the iterator instead of looking the key up a second time.
    itr->second = func;
    return func;
}
// Returns a pointer to the single FunctionRegister object, which is a
// function-local static constructed the first time this is called.
FunctionRegister *FunctionRegister::GetInstance()
{
    static FunctionRegister registerInstance;
    return &registerInstance;
}
// Registers funcName with the loader for soName, creating that loader on first use.
// mu_ is held for the whole call. If the loader cannot be created, a message is
// printed to stdout and nothing is registered.
void FunctionRegister::RegisterFunction(const std::string& soName, const std::string& funcName)
{
    std::lock_guard<std::mutex> guard(mu_);

    auto entry = registry_.find(soName);
    if (entry != registry_.end()) {
        // A loader for this library already exists: just add the function to it.
        entry->second->Set(funcName);
        return;
    }

    std::unique_ptr<FunctionLoader> loader = nullptr;
    Mspti::Common::MsptiMakeUniquePtr(loader, soName);
    if (!loader) {
        std::cout << "Failed to init FunctionLoader." << std::endl;
        return;
    }
    loader->Set(funcName);
    registry_.emplace(soName, std::move(loader));
}
// Looks up funcName through the loader registered for soName.
// mu_ is held for the whole call. Returns nullptr when no loader exists for
// soName; otherwise returns whatever that loader's Get() yields.
FunctionHandle FunctionRegister::Get(const std::string &soName, const std::string &funcName)
{
    std::lock_guard<std::mutex> guard(mu_);

    auto entry = registry_.find(soName);
    if (entry == registry_.end()) {
        return nullptr;
    }
    return entry->second->Get(funcName);
}
void* RegisterFunction(const std::string& soName, const std::string& funcName)
{
FunctionRegister::GetInstance()->RegisterFunction(soName, funcName);
return FunctionRegister::GetInstance()->Get(soName, funcName);
}
}
}