* This file is part of the MindStudio project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* ------------------------------------------------------------------------- */
#include "DBITask.h"
#include <chrono>
#include <elf.h>
#include "utils/FileSystem.h"
#include "utils/ElfLoader.h"
#include "utils/Ustring.h"
#include "utils/PipeCall.h"
#include "utils/Future.h"
#include "runtime/inject_helpers/InstrReport.h"
#include "runtime/inject_helpers/KernelContext.h"
#include "runtime/inject_helpers/KernelReplacement.h"
#include "runtime/inject_helpers/MemoryDataCollect.h"
#include "runtime/inject_helpers/RegisterContext.h"
#include "runtime/inject_helpers/DevMemManager.h"
#include "runtime/RuntimeOrigin.h"
#include "core/BinaryInstrumentation.h"
#include "core/LocalProcess.h"
using namespace std;
/// Returns the DBITaskFactory object belonging to the calling thread.
DBITaskFactory &DBITaskFactory::Instance()
{
    // The object is declared thread_local, so each thread sees its own factory.
    static thread_local DBITaskFactory factory;
    return factory;
}
/**
 * Fetches the task cached under regId for the key {type, nullptr, tilingKey, pluginPath}.
 * When the slot is still empty, a task is created with MakeShared<DBITask>(type) and
 * stored in the pool before being returned.
 */
DBITaskSP DBITaskFactory::GetOrCreate(uint64_t regId, const std::string &tilingKey,
                                      BIType type, const std::string &pluginPath)
{
    auto &tasksOfRegister = taskPool_[regId];
    auto &slot = tasksOfRegister[{type, nullptr, tilingKey, pluginPath}];
    if (slot) {
        return slot;
    }
    slot = MakeShared<DBITask>(type);
    return slot;
}
/**
 * Fetches the task cached under regId for the key {type, stubFunc, "", pluginPath}.
 * When the slot is still empty, a task is created with MakeShared<DBITask>(type) and
 * stored in the pool before being returned.
 */
DBITaskSP DBITaskFactory::GetOrCreate(uint64_t regId, const void *stubFunc,
                                      BIType type, const std::string &pluginPath)
{
    auto &tasksOfRegister = taskPool_[regId];
    auto &slot = tasksOfRegister[{type, stubFunc, "", pluginPath}];
    if (slot) {
        return slot;
    }
    slot = MakeShared<DBITask>(type);
    return slot;
}
// Empties the task pool, dropping the factory's references to every cached task
// for every register id.
void DBITaskFactory::Reset()
{
taskPool_.clear();
}
/// Removes every task cached under regId; a regId that is not in the pool is ignored.
void DBITaskFactory::Destroy(uint64_t regId)
{
    // Erasing by key is a no-op when the key is absent, so no prior lookup is needed.
    taskPool_.erase(regId);
}
/**
 * Tells whether the kernel has to be (re-)converted.
 * @return true when neither replaceTask_ nor funcCtx_ holds a previous result, or when
 *         task outputs are kept but the output directory of launchId_ is not a directory.
 */
bool DBITask::NeedConvert() const
{
    const bool nothingCached = !replaceTask_ && !funcCtx_;
    if (nothingCached) {
        return true;
    }
    const auto &cfg = DBITaskConfig::Instance();
    // A cached result exists; it is only insufficient when kept outputs are missing.
    return !IsDir(cfg.GetOutputDir(launchId_)) && cfg.keepTaskOutputs_;
}
/**
 * Tries to reuse the result of an earlier handle-based conversion.
 * @param handle     receives replaceTask_->GetHandle() on success.
 * @param launchId   current launch; when task outputs are kept, CreateSymlink is called
 *                   with the cached launch's output directory and this launch's one.
 * @param registerId register id whose DBI binary is set from the cached task.
 * @return true when the cached result was reused, false when a conversion is required.
 */
bool DBITask::ReuseConverted(void **handle, uint64_t launchId, uint64_t registerId)
{
    if (NeedConvert()) {
        return false;
    }
    // NeedConvert() is also false when only funcCtx_ is set; without this guard the
    // dereferences below would hit a null replaceTask_.
    if (!replaceTask_) {
        return false;
    }
    *handle = replaceTask_->GetHandle();
    KernelContext::Instance().SetDbiBinary(registerId, replaceTask_->GetDevBinary());
    const auto &cfg = DBITaskConfig::Instance();
    if (cfg.keepTaskOutputs_) {
        CreateSymlink(cfg.GetOutputDir(launchId_), cfg.GetOutputDir(launchId));
    }
    return true;
}
/**
 * Checks whether a previous conversion can serve launchId.
 * When task outputs are kept, CreateSymlink is called with the cached launch's output
 * directory and the output directory of launchId.
 * @return false when NeedConvert() reports that a conversion is required, true otherwise.
 */
bool DBITask::ReuseConverted(uint64_t launchId) const
{
    if (NeedConvert()) {
        return false;
    }
    const auto &cfg = DBITaskConfig::Instance();
    if (cfg.keepTaskOutputs_) {
        CreateSymlink(cfg.GetOutputDir(launchId_), cfg.GetOutputDir(launchId));
    }
    return true;
}
/**
 * Produces an instrumented function context for the given launch.
 *
 * Steps: reuse a cached result if possible; otherwise create the launch output directory,
 * save the original kernel there, convert it, clone the register context from the new
 * kernel file and clone the function context on top of it.
 *
 * @param launchCtx launch to instrument.
 * @return the new function context (also stored in funcCtx_), or nullptr on any failure.
 */
FuncContextSP DBITask::Run(const LaunchContextSP &launchCtx)
{
    const uint64_t launchId = launchCtx->GetLaunchId();
    if (ReuseConverted(launchId)) {
        return funcCtx_;
    }

    auto srcFuncCtx = launchCtx->GetFuncContext();
    const auto &cfg = DBITaskConfig::Instance();
    string workDir = cfg.GetOutputDir(launchId);
    if (!MkdirRecusively(workDir)) {
        ERROR_LOG("Mkdir tmp launch dir failed, stop running dbi task.");
        return nullptr;
    }
    // Scope guard: on every exit path below, remove the directory unless outputs are kept.
    shared_ptr<void> cleanupGuard(nullptr, [&workDir](std::nullptr_t&) {
        if (DBITaskConfig::Instance().keepTaskOutputs_) {
            return;
        }
        RemoveAll(workDir);
    });

    BinaryInstrumentation::Config biConfig{cfg.pluginPath_, GetTargetArchName(srcFuncCtx), workDir,
                                           cfg.tuneLogPath_, cfg.argsSize_, cfg.extraCompilerArgs_};
    string originalKernel = JoinPath({workDir, cfg.oldKernelName_});
    string instrumentedKernel = JoinPath({workDir, cfg.newKernelName_});
    if (!srcFuncCtx->GetRegisterContext()->Save(originalKernel)) {
        DEBUG_LOG("Dump kernel object failed. stop running dbi task.");
        return nullptr;
    }

    // The tiling key string stays empty when the function context reports no tiling key.
    uint64_t rawTilingKey{};
    const string tilingKeyText = srcFuncCtx->GetTilingKey(rawTilingKey) ? to_string(rawTilingKey) : string{};
    if (!Convert(biConfig, originalKernel, instrumentedKernel, tilingKeyText)) {
        DEBUG_LOG("Run dbi task failed, will use original kernel to run.");
        return nullptr;
    }

    auto clonedRegCtx = srcFuncCtx->GetRegisterContext()->Clone(instrumentedKernel);
    if (!clonedRegCtx) {
        DEBUG_LOG("Replace handle failed: register binary failed.");
        return nullptr;
    }
    funcCtx_ = srcFuncCtx->Clone(clonedRegCtx);
    if (!funcCtx_) {
        DEBUG_LOG("Replace handle failed: register function failed.");
        return nullptr;
    }
    launchId_ = launchId;
    return funcCtx_;
}
bool DBITask::Run(void **handle, uint64_t launchId, bool withStubFunc)
{
uint64_t regId = KernelContext::Instance().GetRegisterId(launchId);
if (ReuseConverted(handle, launchId, regId)) {
return true;
}
KernelContext::LaunchEvent event;
if (!KernelContext::Instance().GetLaunchEvent(launchId, event)) {
ERROR_LOG("Get launch event failed.");
return false;
};
const auto &taskConfig = DBITaskConfig::Instance();
string tmpLaunchDir = taskConfig.GetOutputDir(launchId);
if (!MkdirRecusively(tmpLaunchDir)) {
ERROR_LOG("Mkdir tmp launch dir failed, stop running dbi task.");
return false;
}
shared_ptr<void> defer0(nullptr, [&tmpLaunchDir](std::nullptr_t&) {
if (!DBITaskConfig::Instance().keepTaskOutputs_) {
RemoveAll(tmpLaunchDir);
}
});
BinaryInstrumentation::Config config{taskConfig.pluginPath_, GetCurrentArchName(), tmpLaunchDir,
taskConfig.tuneLogPath_, taskConfig.argsSize_, taskConfig.extraCompilerArgs_};
string oldKernelPath = JoinPath({tmpLaunchDir, taskConfig.oldKernelName_});
string newKernelPath = JoinPath({tmpLaunchDir, taskConfig.newKernelName_});
if (!KernelContext::Instance().DumpKernelObject(regId, oldKernelPath)) {
DEBUG_LOG("Dump kernel object failed. stop running dbi task.");
return false;
}
string tilingKey{};
if (!event.stubFunc) { tilingKey = to_string(event.tilingKey);}
if (!Convert(config, oldKernelPath, newKernelPath, tilingKey)) {
DEBUG_LOG("Run dbi task failed, will use original kernel to run.");
return false;
}
replaceTask_.reset(new KernelReplaceTask(newKernelPath));
if (replaceTask_->Run(handle, regId, withStubFunc)) {
launchId_ = launchId;
KernelContext::Instance().SetDbiBinary(regId, replaceTask_->GetDevBinary());
return true;
}
DEBUG_LOG("Replace handle failed.");
replaceTask_.reset();
return false;
}
bool DBITask::Run(void **handle, const void **stubFunc, uint64_t launchId)
{
if (!DBITask::Run(handle, launchId, true)) {
return false;
}
if (registered_) {
*stubFunc = &this->stubFunc_;
return true;
}
KernelContext::LaunchEvent launchEvent;
if (!KernelContext::Instance().GetLaunchEvent(launchId, launchEvent)) {
ERROR_LOG("Get launch event failed.");
return false;
};
KernelContext::StubFuncInfo stubFuncInfo;
if (!KernelContext::Instance().GetStubFuncInfo(KernelContext::StubFuncPtr{launchEvent.stubFunc}, stubFuncInfo)) {
ERROR_LOG("Get stubFuncInfo failed");
return false;
}
rtError_t ret = rtFunctionRegisterOrigin(*handle, &this->stubFunc_,
stubFuncInfo.kernelName.c_str(),
stubFuncInfo.kernelInfoExt.c_str(), stubFuncInfo.funcMode);
if (ret != RT_ERROR_NONE) {
ERROR_LOG("register function for stubName %s failed. error code: %d", stubFuncInfo.kernelName.c_str(), ret);
return false;
}
registered_ = true;
*stubFunc = &this->stubFunc_;
return true;
}
/**
 * Applies config to the instrumentation backend and converts the kernel file.
 *
 * @param config        configuration handed to dbi_->SetConfig().
 * @param oldKernelPath path of the original kernel object.
 * @param newKernelPath path for the converted kernel object.
 * @param tilingKey     tiling key string forwarded to dbi_->Convert(); may be empty.
 * @return true when both SetConfig() and Convert() succeed.
 */
bool DBITask::Convert(const BinaryInstrumentation::Config &config, const string &oldKernelPath,
                      const string &newKernelPath, const string &tilingKey)
{
    if (!dbi_->SetConfig(config)) {
        return false;
    }
    // steady_clock is monotonic, so the measured interval is not affected by
    // wall-clock adjustments the way a system_clock difference would be.
    const auto start = std::chrono::steady_clock::now();
    if (!dbi_->Convert(newKernelPath, oldKernelPath, tilingKey)) {
        return false;
    }
    const std::chrono::duration<double> elapsedSeconds = std::chrono::steady_clock::now() - start;
    DEBUG_LOG("Run dbi task success. time duration: %fs", elapsedSeconds.count());
    return true;
}
/// Returns the DBITaskConfig object belonging to the calling thread.
DBITaskConfig &DBITaskConfig::Instance()
{
    // The object is declared thread_local, so each thread sees its own configuration.
    static thread_local DBITaskConfig config;
    return config;
}
/**
 * Enables DBI and stores the settings for plugin-based instrumentation.
 * A KernelMatcher is built from config. tmpPath is passed to SetTmpRootDir(), which only
 * takes effect the first time a root directory is chosen.
 */
void DBITaskConfig::Init(BIType type, const std::string &pluginPath, const KernelMatcher::Config &config,
    const std::string &tmpPath, const std::string &tuneLogPath, const std::vector<std::string> &extraCompilerArgs)
{
    type_ = type;
    pluginPath_ = pluginPath;
    tuneLogPath_ = tuneLogPath;
    extraCompilerArgs_ = extraCompilerArgs;
    matcher_ = MakeShared<KernelMatcher>(config);
    enabled_ = true;
    SetTmpRootDir(tmpPath);
}
/**
 * Enables DBI with an already constructed matcher and an empty plugin path.
 * tmpPath is passed to SetTmpRootDir(), which only takes effect the first time a root
 * directory is chosen.
 */
void DBITaskConfig::Init(BIType type, const std::shared_ptr<KernelMatcher> &matcher, const std::string &tmpPath)
{
    type_ = type;
    matcher_ = matcher;
    pluginPath_.clear();
    enabled_ = true;
    SetTmpRootDir(tmpPath);
}
/// Decides from the logger level whether per-task and root temporary outputs are kept.
DBITaskConfig::DBITaskConfig()
{
    // Both flags share one condition: the log level compares <= LogLv::DEBUG.
    const bool logLevelAtMostDebug = (InjectLogger::Instance().GetLogLv() <= LogLv::DEBUG);
    keepTaskOutputs_ = logLevelAtMostDebug;
    keepRootTmpDirOutputs_ = logLevelAtMostDebug;
}
/// Removes the temporary root directory on destruction unless its outputs are to be kept.
DBITaskConfig::~DBITaskConfig()
{
    if (keepRootTmpDirOutputs_) {
        return;
    }
    if (IsExist(tmpRootDir_)) {
        RemoveAll(tmpRootDir_);
    }
}
bool DBITaskConfig::IsEnabled(uint64_t launchId, const string &kernelName) const
{
if (!enabled_) {
DEBUG_LOG("not enable dbi task");
return false;
}
if (!matcher_ || !matcher_->Match(launchId, kernelName)) {
DEBUG_LOG("Can not match kernel name %s", kernelName.c_str());
return false;
}
return true;
}
/**
 * Chooses the temporary root directory name once; later calls leave it unchanged.
 * The name is "<parent>/tmp_<pid>_<timestamp>", where <parent> is path, or "." when path
 * is empty.
 */
void DBITaskConfig::SetTmpRootDir(const std::string &path)
{
    if (!tmpRootDir_.empty()) {
        return;
    }
    const std::string parent = path.empty() ? std::string(".") : path;
    tmpRootDir_ = parent + "/tmp_" + std::to_string(getpid()) + "_" + GenerateTimeStamp<std::chrono::nanoseconds>();
}
uint8_t* InitMemory(uint64_t memSize)
{
if (memSize == 0) {
return nullptr;
}
DEBUG_LOG("Dbi task init with memSize: %lu", memSize);
void *memInfo;
aclError ret = DevMemManager::Instance().MallocMemory(&memInfo, memSize);
if (ret != ACL_ERROR_NONE) {
ERROR_LOG("malloc memInfo error: %d", ret);
return nullptr;
}
aclError error = aclrtMemsetImplOrigin(memInfo, memSize, 0x00, memSize);
if (error != ACL_ERROR_NONE) {
ERROR_LOG("init rtMemset memInfo error: %d", error);
return nullptr;
}
return static_cast<uint8_t *>(memInfo);
}
bool RunDBITask(const StubFunc **stubFunc)
{
if (!IsPlatformSupportDBI()) {
DEBUG_LOG("Unsupported platform, exit DBI");
return false;
}
void const *lastStubFunc = *stubFunc;
rtDevBinary_t binary;
if (!KernelContext::Instance().GetDevBinary(KernelContext::StubFuncPtr{*stubFunc}, binary)) {
ERROR_LOG("get binary failed");
return false;
}
uint64_t launchId = KernelContext::Instance().GetLaunchId();
string kernelName = KernelContext::Instance().GetLaunchName();
const auto &config = DBITaskConfig::Instance();
if ((!config.IsEnabled(launchId, kernelName)) || HasStaticStub(binary)) {
return false;
}
uint64_t regId = KernelContext::Instance().GetRegisterId(launchId);
KernelContext::LaunchEvent launchEvent;
if (!KernelContext::Instance().GetLaunchEvent(launchId, launchEvent)) {
ERROR_LOG("get launch event by launchId %lu failed", launchId);
return false;
}
DEBUG_LOG("DBI with launch id = %lu, kernelName=%s", launchId, kernelName.c_str());
KernelHandle *hdl{nullptr};
auto task = DBITaskFactory::Instance().GetOrCreate(regId, *stubFunc, config.type_, config.pluginPath_);
if (!task || !task->Run(&hdl, stubFunc, launchId)) {
return false;
}
return KernelContext::Instance().GetDeviceContext()
.UpdatePcStartAddrByDbi(launchId, KernelContext::StubFuncArgs{lastStubFunc, *stubFunc});
}
bool RunDBITask(void **hdl, const uint64_t tilingKey)
{
if (!IsPlatformSupportDBI()) {
DEBUG_LOG("Unsupported platform, exit DBI");
return false;
}
void const *lastHdl = *hdl;
rtDevBinary_t binary;
if (!KernelContext::Instance().GetDevBinary(KernelContext::KernelHandlePtr{*hdl}, binary)) {
ERROR_LOG("get binary failed");
return false;
}
uint64_t launchId = KernelContext::Instance().GetLaunchId();
string kernelName = KernelContext::Instance().GetLaunchName();
const auto &config = DBITaskConfig::Instance();
if ((!config.IsEnabled(launchId, kernelName)) || HasStaticStub(binary)) {
return false;
}
uint64_t regId = KernelContext::Instance().GetRegisterId(launchId);
DEBUG_LOG("DBI with launch id = %lu, kernelName=%s", launchId, kernelName.c_str());
auto task = DBITaskFactory::Instance().GetOrCreate(regId, to_string(tilingKey),
config.type_, config.pluginPath_);
if (!task || !task->Run(hdl, launchId)) {
return false;
}
return KernelContext::Instance().GetDeviceContext()
.UpdatePcStartAddrByDbi(launchId, KernelContext::KernelHandleArgs{lastHdl, *hdl, tilingKey});
}
/**
 * Entry point for context-based launches.
 *
 * Skips the launch when the platform is unsupported, when any required context is
 * missing, when DBI is not enabled for it, or when its register context has a static
 * stub. Otherwise runs (or reuses) the DBITask cached under the register id and kernel
 * name.
 *
 * @param launchCtx launch to instrument; a null pointer is treated as "nothing to do".
 * @return the instrumented function context, or nullptr when DBI is skipped or fails.
 */
FuncContextSP RunDBITask(const LaunchContextSP &launchCtx)
{
    if (!IsPlatformSupportDBI()) {
        DEBUG_LOG("Unsupported platform, exit DBI");
        return nullptr;
    }
    // Guard the pointer chain that the rest of the function dereferences.
    if (!launchCtx) {
        DEBUG_LOG("Launch context is null, exit DBI");
        return nullptr;
    }
    auto funcCtx = launchCtx->GetFuncContext();
    if (!funcCtx) {
        DEBUG_LOG("Function context is null, exit DBI");
        return nullptr;
    }
    auto regCtx = funcCtx->GetRegisterContext();
    if (!regCtx) {
        DEBUG_LOG("Register context is null, exit DBI");
        return nullptr;
    }

    uint64_t launchId = launchCtx->GetLaunchId();
    string kernelName = funcCtx->GetKernelName();
    const auto &config = DBITaskConfig::Instance();
    if ((!config.IsEnabled(launchId, kernelName)) || regCtx->HasStaticStub()) {
        return nullptr;
    }
    DEBUG_LOG("DBI with launch id = %lu, kernelName=%s", launchId, kernelName.c_str());

    uint64_t regId = regCtx->GetRegisterId();
    // The kernel name is used as the string part of the task-pool key.
    auto task = DBITaskFactory::Instance().GetOrCreate(regId, kernelName, config.type_, config.pluginPath_);
    if (task == nullptr) {
        return nullptr;
    }
    return task->Run(launchCtx);
}