* This file is part of the MindStudio project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* ------------------------------------------------------------------------- */
#include "KernelContext.h"
#include <elf.h>
#include <fstream>
#include <mutex>
#include <iostream>
#include <functional>
#include <iterator>
#include <sstream>
#include <tuple>
#include <utility>
#include <thread>
#include "acl.h"
#include "acl_rt_impl/AscendclImplOrigin.h"
#include "runtime/inject_helpers/ArgsHandleContext.h"
#include "runtime/inject_helpers/KernelReplacement.h"
#include "runtime/inject_helpers/LaunchContext.h"
#include "utils/InjectLogger.h"
#include "utils/ElfLoader.h"
#include "utils/FileSystem.h"
#include "utils/PipeCall.h"
#include "utils/TypeTraits.h"
#include "utils/Ustring.h"
#include "runtime/RuntimeOrigin.h"
#include "MemoryDataCollect.h"
#include "DBITask.h"
#include "MsTx.h"
#include "core/FuncSelector.h"
#include "RegisterContext.h"
#include "LaunchManager.h"
#include "ArgsManager.h"
#include "DeviceContext.h"
using namespace std;
namespace {
// All-ones 32-bit mask; used to keep only the low 32 bits of a 64-bit word.
constexpr uint32_t U32_MASK = ~0U;
// Divisor that turns a byte offset into an index over uint64_t elements.
constexpr uint8_t U64_BYTE_SIZE = 8U;
// Kernel-name suffixes that are stripped/matched when resolving kernel names.
constexpr char const *MIX_AIC_TAIL = "_mix_aic";
constexpr char const *MIX_AIV_TAIL = "_mix_aiv";
// Used as a scope guard in this file: constructed with a null pointer plus a
// custom deleter that performs the cleanup.
using Defer = std::shared_ptr<void>;
// Tag values compared against the top byte of each entry in SetArgsSize.
enum class PtrType : uint64_t {
PRIMARY = 0U,
SECONDARY,
SECONDARY_SHAPE
};
// Reads the tiling key from the last '_'-separated field of a kernel name.
// A trailing "_mix_aic" / "_mix_aiv" suffix is removed before splitting.
// @param kernelName kernel symbol name (taken by value because it is trimmed locally)
// @param tilingKey  receives the value extracted from the last field
// @return false when the name splits into fewer than two fields, true otherwise
// NOTE(review): the result of the stream extraction is not checked, so a
// non-numeric last field still yields `true` — confirm callers rely on that.
bool GetTilingKeyFromName(std::string kernelName, uint64_t &tilingKey)
{
    const bool hasMixTail = EndsWith(kernelName, MIX_AIC_TAIL) || EndsWith(kernelName, MIX_AIV_TAIL);
    if (hasMixTail) {
        // Both mix suffixes are 8 characters long.
        kernelName.resize(kernelName.length() - 8UL);
    }

    std::vector<std::string> fields;
    SplitString(kernelName, '_', fields);
    if (fields.size() < 2UL) {
        return false;
    }

    std::stringstream keyStream(fields.back());
    keyStream >> tilingKey;
    return true;
}
// Finds the offset of the kernel whose name encodes the requested tiling key.
// @param kernelNames   kernel symbol names
// @param kernelOffsets offsets, parallel to kernelNames (same index = same kernel)
// @param tilingKey     tiling key to look for
// @param kernelOffset  receives the offset of the first matching kernel
// @return true when a matching kernel name was found, false otherwise
bool GetKernelOffsetByTilingKey(std::vector<std::string> const &kernelNames,
                                std::vector<uint64_t> const &kernelOffsets,
                                uint64_t tilingKey,
                                uint64_t &kernelOffset)
{
    // Only indices valid in BOTH vectors may be used; a shorter offsets vector
    // would otherwise be read out of bounds.
    const std::size_t usable =
        kernelNames.size() < kernelOffsets.size() ? kernelNames.size() : kernelOffsets.size();
    uint64_t tilingKeyFromName {};
    for (std::size_t idx = 0; idx < usable; ++idx) {
        if (!GetTilingKeyFromName(kernelNames[idx], tilingKeyFromName)) {
            continue;
        }
        if (tilingKeyFromName != tilingKey) {
            continue;
        }
        DEBUG_LOG("match kernel name with tiling key success. tilingKey: %lu, "
                  "kernelName: %s, kernelOffset: 0x%lx",
                  tilingKey, kernelNames[idx].c_str(), kernelOffsets[idx]);
        kernelOffset = kernelOffsets[idx];
        return true;
    }
    return false;
}
// Copies the tiling data referenced by the launch arguments and passes it to
// UpdateInputSizeWithTilingData.
// The tiling bytes are taken from inside the args buffer when tilingDataOffset
// is non-zero; otherwise from the address stored in the args buffer at
// tilingAddrOffset (when that offset is non-zero).
// @param argsInfo launch argument descriptor
// @param memInfo  memory info; its tilingDataSize gives the number of bytes copied
void CompleteInputData(const rtArgsEx_t &argsInfo, OpMemInfo &memInfo)
{
    // Without an args buffer there is nothing to read; the pointer arithmetic
    // and dereference below would be invalid.
    if (argsInfo.args == nullptr) {
        return;
    }
    auto *buff = static_cast<uint64_t *>(argsInfo.args);
    uint8_t *tilingDataOffsetBuff = (argsInfo.tilingDataOffset == 0) ?
        nullptr : reinterpret_cast<uint8_t *>(buff + ((argsInfo.tilingDataOffset) / U64_BYTE_SIZE));
    uint8_t *tilingAddrOffsetBuff = (argsInfo.tilingAddrOffset == 0) ?
        nullptr : reinterpret_cast<uint8_t *>(*(buff + (argsInfo.tilingAddrOffset / U64_BYTE_SIZE)));
    // In-buffer tiling data takes precedence over the address-based one.
    uint8_t *tilingBuffData = (tilingDataOffsetBuff != nullptr) ? tilingDataOffsetBuff : tilingAddrOffsetBuff;
    if (tilingBuffData == nullptr || !(memInfo.tilingDataSize > 0)) {
        return;
    }
    DEBUG_LOG("update input size with tiling data. tilingDataSize=%lu", memInfo.tilingDataSize);
    std::vector<uint8_t> tilingData(tilingBuffData, tilingBuffData + memInfo.tilingDataSize);
    UpdateInputSizeWithTilingData(tilingData, memInfo);
}
}
std::string KernelContext::GetNameByTilingKey(const KernelHandle *handle, uint64_t tilingKey)
{
auto it = hdlToRegId_.find(handle);
if (it == hdlToRegId_.end()) {
WARN_LOG("can not find names for handle.");
return "";
}
uint64_t regId = it->second;
RegisterEvent event = registerEvents_[regId];
const vector<string> &names = event.kernelNames;
if (names.size() == 1U) {
return names.back();
}
std::string nameSuffix = "_" + std::to_string(tilingKey);
std::vector<std::string> suffixNames = {nameSuffix, nameSuffix + MIX_AIC_TAIL, nameSuffix + MIX_AIV_TAIL};
for (const std::string &name : names) {
for (const std::string &suffix: suffixNames) {
if (EndsWith(name, suffix)) {
return name;
}
}
}
return "";
}
// Looks up the device binary registered for a kernel handle.
// @param hdl        wrapper around the binary handle used as lookup key
// @param bin        receives the original binary, or the stub binary when requested
// @param getStubBin true to return the stub binary instead of the original one
// @return false when the handle is unknown, or when the stub binary is
//         requested but none is recorded; true otherwise
bool KernelContext::GetDevBinary(KernelHandlePtr hdl, rtDevBinary_t &bin, bool getStubBin) const
{
    auto it = hdlToRegId_.find(hdl.value);
    if (it == hdlToRegId_.cend()) {
        ERROR_LOG("find register id by binary handle failed");
        return false;
    }
    RegisterEvent const &registerEvent = registerEvents_[it->second];
    if (!getStubBin) {
        bin = registerEvent.bin;
        return true;
    }
    if (registerEvent.stubBin == nullptr) {
        return false;
    }
    bin = *registerEvent.stubBin;
    return true;
}
// Stub-function flavour of GetDevBinary: maps the stub function to its binary
// handle and forwards to the handle-based overload.
// @return false when the stub function is unknown; otherwise the result of the
//         handle-based lookup
bool KernelContext::GetDevBinary(StubFuncPtr stubFunc, rtDevBinary_t &bin, bool getStubBin) const
{
    const auto stubIt = stubInfo_.find(stubFunc.value);
    if (stubIt != stubInfo_.cend()) {
        return GetDevBinary(KernelHandlePtr{stubIt->second.hdl}, bin, getStubBin);
    }
    ERROR_LOG("find binary handle by stub func failed");
    return false;
}
// Forwards to the current device's context to fetch its most recent launch event.
bool KernelContext::GetLastLaunchEvent(LaunchEvent &event) const
{
    const auto &deviceContext = GetDeviceContext();
    return deviceContext.GetLastLaunchEvent(event);
}
// Binds the per-device context to the KernelContext passed in.
KernelContext::DeviceContext::DeviceContext(KernelContext &kernelContext)
    : kernelContext_{kernelContext}
{
}
// Copies the most recently recorded launch event into `event`.
// @return false (leaving `event` untouched) when no launch has been recorded
bool KernelContext::DeviceContext::GetLastLaunchEvent(LaunchEvent &event) const
{
    const bool hasEvent = !launchEvents_.empty();
    if (hasEvent) {
        event = launchEvents_.back();
    }
    return hasEvent;
}
// Copies the registration info stored for a stub function into `stubFuncInfo`.
// @return false (after logging an error) when the stub function is unknown
bool KernelContext::GetStubFuncInfo(StubFuncPtr stubFunc, StubFuncInfo &stubFuncInfo) const
{
    const auto stubIt = stubInfo_.find(stubFunc.value);
    if (stubIt != stubInfo_.cend()) {
        stubFuncInfo = stubIt->second;
        return true;
    }
    ERROR_LOG("find stub func info by stub func failed");
    return false;
}
void KernelContext::SetArgsSize(const rtArgsSizeInfo_t *const sizeInfo)
{
this->GetDeviceContext().SetArgsSize(sizeInfo);
}
void KernelContext::ClearArgsInfo()
{
this->GetDeviceContext().ClearArgsInfo();
}
// Rebuilds memInfo_ from the meta section of `binary` for the kernel being launched.
// When memInfo_ was populated by SetArgsSize (isForSetException), the binary is
// not parsed: only the FFTS info is refreshed and the flag is reset.
// @param binary   device binary that is searched for the kernel's meta section
// @param argsInfo optional launch args; when non-null, used to complete input
//                 sizes from tiling data
void KernelContext::DeviceContext::ParseMetaDataFromBinary(rtDevBinary_t const &binary, const rtArgsEx_t *argsInfo)
{
    if (memInfo_.isForSetException) {
        DEBUG_LOG("rtSetExceptionExtInfo is called, not parse binary file");
        SetFftsInfo(memInfo_, binary.magic);
        memInfo_.isForSetException = false;
        return;
    }

    memInfo_.Clear();
    std::string kernelName = KernelContext::Instance().GetLaunchName();
    std::vector<uint8_t> metaData;
    if (GetMetaSection(binary, kernelName, metaData) == 0) {
        DEBUG_LOG("Get meta data failed");
        return;
    }

    auto parser = KernelMetaDataParser(metaData);
    if (!parser.Parse()) {
        return;
    }

    // The overflow-address fields present after Clear() are carried over into
    // the freshly parsed info.
    const auto keptHasOverflowAddr = memInfo_.hasOverflowAddr;
    const auto keptOverflowAddr = memInfo_.overflowAddr;
    memInfo_ = parser.GetOpMemInfo();
    memInfo_.hasOverflowAddr = keptHasOverflowAddr;
    memInfo_.overflowAddr = keptOverflowAddr;
    SetFftsInfo(memInfo_, binary.magic);
    if (argsInfo != nullptr) {
        CompleteInputData(*argsInfo, memInfo_);
    }
}
// Rebuilds memInfo_ from the size-info buffer and marks it with isForSetException.
// As read here: word 1 holds inputNum in its low 32 bits and skipNum in its
// high 32 bits; each following word carries a type tag in its top byte and a
// value in its low 56 bits. Only PRIMARY entries are recorded.
// A null `sizeInfo` or null `infoAddr` is ignored.
void KernelContext::DeviceContext::SetArgsSize(const rtArgsSizeInfo_t * const sizeInfo)
{
    if (sizeInfo == nullptr || sizeInfo->infoAddr == nullptr) {
        return;
    }
    memInfo_.Clear();
    memInfo_.isForSetException = true;

    const uint64_t *words = static_cast<uint64_t *>(sizeInfo->infoAddr);
    const uint64_t counts = words[1];
    memInfo_.inputNum = counts & U32_MASK;
    memInfo_.skipNum = (counts >> 32U) & U32_MASK;

    constexpr uint64_t typeShift = 56U;
    for (uint64_t i = 0U; i < memInfo_.inputNum; ++i) {
        const uint64_t entry = words[2U + i];
        uint64_t ptrType = (entry >> typeShift) & 0xff;
        uint64_t primaryPtrNum = entry & 0xffffffffffffff;
        if (PtrType(ptrType) != PtrType::PRIMARY) {
            continue;
        }
        DEBUG_LOG("Ptr value is %lu", primaryPtrNum);
        memInfo_.inputParamsAddrInfos.emplace_back(AddrInfo{0U, primaryPtrNum, MemInfoSrc::EXTRA,
                                                            MemInfoDesc::INPUT, i + 1});
    }
    memInfo_.isFFTS = memInfo_.skipNum != 0;
}
// Writes the binary registered under `regId` to `outputPath`.
// If the ELF header's e_flags equals FLAG_DEFAULT_VALUE and GetFlagByMagic
// returns a positive flag for the binary's magic, the header in the written
// copy is patched with that flag.
// @return true when the whole buffer was written; false for an invalid regId,
//         an unreadable ELF header, or a short write
bool KernelContext::DumpKernelObject(uint64_t regId, const string &outputPath)
{
    if (regId >= registerEvents_.size()) {
        return false;
    }
    const auto &regEvent = registerEvents_[regId];
    // Work on a private copy so the registered image itself is never modified.
    vector<char> buffer(regEvent.binData.begin(), regEvent.binData.end());
    Elf64_Ehdr header{};
    if (!ElfLoader::LoadHeader(buffer, header)) {
        ERROR_LOG("Read elf header failed when load elf from buffer");
        return false;
    }
    if (header.e_flags == FLAG_DEFAULT_VALUE) {
        auto correctFlag = GetFlagByMagic(regEvent.bin.magic);
        if (correctFlag > 0) {
            header.e_flags = correctFlag;
            // Overwrite the header bytes at the start of the copy.
            const char *base = static_cast<const char*>(static_cast<const void *>(&header));
            std::copy(base, base + sizeof(Elf64_Ehdr), buffer.data());
            DEBUG_LOG("default flag of op is forced set to %u", correctFlag);
        }
    }
    size_t written = WriteBinary(outputPath, buffer.data(), buffer.size());
    return written == buffer.size();
}
static bool UpdateSinkTaskArgsAddr(const uint64_t *args, OpMemInfo &memInfo)
{
aclError ret = ACL_ERROR_NONE;
for (size_t i = 0; i < memInfo.inputParamsAddrInfos.size(); i++) {
auto devAddr = args + memInfo.skipNum + i;
uint64_t argAddr;
ret = aclrtMemcpyImplOrigin(static_cast<void *>(&argAddr), sizeof(uint64_t),
static_cast<const void *>(devAddr), sizeof(uint64_t),
ACL_MEMCPY_DEVICE_TO_HOST);
if (ret != ACL_ERROR_NONE) {
DEBUG_LOG("copy arg from device addr failed, ret=%u", ret);
return false;
}
if (argAddr == 0 && memInfo.inputParamsAddrInfos[i].length > 0) {
DEBUG_LOG("get invalid arg value from device addr");
return false;
}
DEBUG_LOG("dev args[%lu] addr=%#lx, length=%lu", i, argAddr, memInfo.inputParamsAddrInfos[i].length);
memInfo.inputParamsAddrInfos[i].addr = argAddr;
}
return true;
}
// Fills in the address of every input parameter by reading the args array
// directly, starting at index skipNum. Always returns true.
static bool UpdateNormalTaskArgsAddr(const uint64_t *args, OpMemInfo &memInfo)
{
    size_t slot = 0;
    for (auto &addrInfo : memInfo.inputParamsAddrInfos) {
        addrInfo.addr = args[memInfo.skipNum + slot];
        ++slot;
    }
    return true;
}
// Dumps the tiling data referenced by the launch args to
// "<outputDir>/input_tiling.bin" and records path and size in `config`.
// The source is the in-buffer location (tilingDataOffset) when that offset is
// non-zero, otherwise the address stored at tilingAddrOffset.
// @return true when there is no tiling data or the dump succeeded; false when
//         argsSize does not exceed tilingDataOffset, or on allocation, copy or
//         write failure
static bool DumpTilingData(const rtArgsEx_t &argsInfo, const string &outputDir, KernelContext::ContextConfig &config)
{
    auto *argWords = static_cast<uint64_t *>(argsInfo.args);

    void *inBufferTiling = nullptr;
    if (argsInfo.tilingDataOffset != 0) {
        inBufferTiling = static_cast<void *>(argWords + ((argsInfo.tilingDataOffset) / sizeof(uint64_t)));
    }
    void *addressedTiling = nullptr;
    if (argsInfo.tilingAddrOffset != 0) {
        addressedTiling = reinterpret_cast<void *>(*(argWords + (argsInfo.tilingAddrOffset / sizeof(uint64_t))));
    }
    void *tilingSrc = (inBufferTiling != nullptr) ? inBufferTiling : addressedTiling;
    if (tilingSrc == nullptr) {
        DEBUG_LOG("no tiling data");
        return true;
    }
    if (argsInfo.argsSize <= argsInfo.tilingDataOffset) {
        return false;
    }

    uint32_t tilingDataSize = argsInfo.argsSize - argsInfo.tilingDataOffset;
    void *hostData;
    if (rtMallocHostOrigin(&hostData, tilingDataSize, 0) != RT_ERROR_NONE) {
        return false;
    }
    // Scope guard: releases the host buffer on every exit path below.
    Defer hostGuard(nullptr, [&hostData](std::nullptr_t&) {
        if (rtFreeHostOrigin(hostData) != RT_ERROR_NONE) {
            ERROR_LOG("rtFreeHost failed.");
        }
    });

    const auto copyRet = rtMemcpyOrigin(hostData, tilingDataSize, tilingSrc, tilingDataSize,
                                        RT_MEMCPY_DEVICE_TO_HOST);
    if (copyRet != RT_ERROR_NONE) {
        return false;
    }

    config.tilingDataPath = outputDir + "/input_tiling.bin";
    config.tilingDataSize = tilingDataSize;
    size_t written = WriteBinary(config.tilingDataPath, static_cast<const char *>(hostData), tilingDataSize);
    return written == tilingDataSize;
}
// Dumps the inputs (and, where applicable, the tiling data) of one launch.
// Sink launches without in-buffer tiling data resolve argument addresses via
// UpdateSinkTaskArgsAddr and skip the tiling dump; all other launches use
// UpdateNormalTaskArgsAddr and also dump the tiling data.
// @return false when launchId is out of range, the recorded args are null, or
//         any step fails
bool KernelContext::DeviceContext::DumpKernelArgs(const string &outputDir, uint64_t launchId,
                                                  KernelContext::ContextConfig &config)
{
    const bool idInRange = launchId < launchEvents_.size() && launchId < memInfoHistory_.size();
    if (!idInRange) {
        return false;
    }
    const auto &launchEvent = launchEvents_[launchId];
    const auto &argsInfo = launchEvent.argsInfoCopy;
    if (argsInfo.args == nullptr) {
        DEBUG_LOG("empty kernel args in argsInfo.");
        return false;
    }
    const auto *args = static_cast<const uint64_t *>(argsInfo.args);
    auto &memInfo = memInfoHistory_[launchId];

    const bool sinkWithoutTiling = launchEvent.isSink && argsInfo.tilingDataOffset == 0;
    const bool addrsReady = sinkWithoutTiling ? UpdateSinkTaskArgsAddr(args, memInfo)
                                              : UpdateNormalTaskArgsAddr(args, memInfo);
    if (!addrsReady || !DumpInputData(outputDir, memInfo.inputParamsAddrInfos, config)) {
        return false;
    }
    if (sinkWithoutTiling) {
        return true;
    }
    return DumpTilingData(argsInfo, outputDir, config);
}
bool KernelContext::Save(const string &outputDir, uint64_t launchId)
{
return this->GetDeviceContext().Save(outputDir, launchId);
}
// Saves everything recorded for one launch into `outputDir`: the kernel
// object, the kernel arguments and a config describing the launch.
// @param outputDir destination directory
// @param launchId  launch to save; UINT64_MAX selects the latest launch
// @return false as soon as any step fails or an id is out of range
bool KernelContext::DeviceContext::Save(const string &outputDir, uint64_t launchId)
{
    if (launchId == UINT64_MAX) {
        launchId = GetLaunchId();
    }
    DEBUG_LOG("Start save launch event %lu", launchId);
    if (!SaveAicore(outputDir, launchId)) {
        return false;
    }
    ContextConfig config{};
    if (!DumpKernelArgs(outputDir, launchId, config)) {
        ERROR_LOG("Kernel context save failed, dump kernel args failed");
        return false;
    }
    if (launchId >= launchEvents_.size()) {
        DEBUG_LOG("Kernel context save fail because launch Id out of range");
        return false;
    }
    const auto &launchEvent = launchEvents_[launchId];
    uint64_t regId = kernelContext_.GetRegisterId(launchEvent.hdl, true);
    if (regId >= this->kernelContext_.registerEvents_.size()) {
        DEBUG_LOG("Kernel context save fail because register id out of range");
        return false;
    }
    // Reference only: copying the RegisterEvent would duplicate the whole
    // registered binary just to read its magic.
    const auto &regEvent = this->kernelContext_.registerEvents_[regId];
    auto &devCtx = ::DeviceContext::Local();
    config.isFFTS = memInfoHistory_[launchId].isFFTS;
    config.magic = regEvent.bin.magic;
    config.blockDim = launchEvent.blockDim;
    config.kernelName = launchEvent.kernelName;
    config.tilingKey = launchEvent.tilingKey;
    config.binPath = JoinPath({outputDir, KERNEL_BIN_NAME});
    config.hasTilingKey = launchEvent.hasTilingKey;
    config.devId = devCtx.GetDeviceId();
    config.visibleDevId = devCtx.GetVisibleDeviceId();
    return DumpConfig(outputDir, config);
}
bool KernelContext::SaveAicore(const string &outputDir, uint64_t launchId)
{
return this->GetDeviceContext().SaveAicore(outputDir, launchId);
}
bool KernelContext::DeviceContext::SaveAicore(const string &outputDir, uint64_t launchId)
{
if (launchId == UINT64_MAX) {
launchId = GetLaunchId();
}
if (launchId >= launchEvents_.size()) {
ERROR_LOG("Kernel context save failed, invalid launch id");
return false;
}
auto &launchEvent = launchEvents_[launchId];
uint64_t regId = this->kernelContext_.GetRegisterId(launchEvent.hdl, true);
if (regId >= this->kernelContext_.registerEvents_.size()) {
ERROR_LOG("Kernel context save failed, regId is error");
return false;
}
string kernelObjPath = JoinPath({outputDir, KERNEL_BIN_NAME});
if (!this->kernelContext_.DumpKernelObject(regId, kernelObjPath)) {
ERROR_LOG("Kernel context save failed, dump kernel obj failed");
return false;
}
return true;
}
// Computes the PC start address for a stub function via the current device's
// context (no DBI handle is supplied).
bool KernelContext::GetPcStartAddr(StubFuncPtr stubFunc, uint64_t &pcStartAddr)
{
    const StubFuncArgs args{stubFunc.value, nullptr};
    return GetDeviceContext().GetPcStartAddr(args, pcStartAddr);
}
// Computes the PC start address for a kernel handle. The tiling key is taken
// from the first kernel name registered for the handle.
// @return false when the handle is unknown, it has no kernel names, the tiling
//         key cannot be extracted, or the device context lookup fails
bool KernelContext::GetPcStartAddr(KernelHandlePtr hdl, uint64_t &pcStartAddr)
{
    const auto regIt = as_const(hdlToRegId_).find(hdl.value);
    if (regIt == hdlToRegId_.cend()) {
        ERROR_LOG("find register id by handle failed");
        return false;
    }

    auto const &names = registerEvents_[regIt->second].kernelNames;
    if (names.empty()) {
        ERROR_LOG("get kernel name by register id failed");
        return false;
    }

    uint64_t tilingKey = 0UL;
    const bool keyFound = GetTilingKeyFromName(names.front(), tilingKey);
    if (!keyFound) {
        ERROR_LOG("get tiling key by kernel name failed");
        return false;
    }

    const KernelHandleArgs args{hdl.value, nullptr, tilingKey};
    return GetDeviceContext().GetPcStartAddr(args, pcStartAddr);
}
// Queries the runtime for the address of the kernel behind a stub function.
// The DBI handle is used when present, otherwise the original one.
// @param kernelAddr receives the address; set to 0 on failure
bool KernelContext::DeviceContext::GetKernelAddr(StubFuncArgs const &args, uint64_t &kernelAddr) const
{
    StubFunc const *stubFunc = args.originHandle;
    if (args.dbiHandle) {
        stubFunc = args.dbiHandle;
    }

    void *pcStart = nullptr;
    uint32_t prefetchCnt{};
    const rtError_t ret = rtKernelGetAddrAndPrefCntOrigin(nullptr, 0UL, stubFunc, 0U, &pcStart, &prefetchCnt);
    if (ret == RT_ERROR_NONE) {
        kernelAddr = reinterpret_cast<uint64_t>(pcStart);
        DEBUG_LOG("get kernel addr by stub func success. kernelAddr: 0x%lx", kernelAddr);
        return true;
    }
    ERROR_LOG("call rtKernelGetAddrAndPrefCntOrigin with stub func failed");
    kernelAddr = 0x00;
    return false;
}
// Queries the runtime for the address of the kernel selected by handle and
// tiling key. The DBI handle is used when present, otherwise the original one.
// @param kernelAddr receives the address; set to 0 on failure
bool KernelContext::DeviceContext::GetKernelAddr(KernelHandleArgs const &args, uint64_t &kernelAddr) const
{
    KernelHandle const *hdl = args.originHandle;
    if (args.dbiHandle) {
        hdl = args.dbiHandle;
    }

    void *pcStart = nullptr;
    uint32_t prefetchCnt{};
    const rtError_t ret = rtKernelGetAddrAndPrefCntOrigin(const_cast<void *>(hdl), args.tilingKey,
                                                          nullptr, 1U, &pcStart, &prefetchCnt);
    if (ret == RT_ERROR_NONE) {
        kernelAddr = reinterpret_cast<uint64_t>(pcStart);
        DEBUG_LOG("get kernel addr by tiling key success. tilingKey: %lu, kernelAddr: 0x%lx",
                  args.tilingKey, kernelAddr);
        return true;
    }
    ERROR_LOG("call rtKernelGetAddrAndPrefCntOrigin with tiling key failed");
    kernelAddr = 0x00;
    return false;
}
// Looks up the offset recorded for the kernel a stub function was registered
// with, by finding the stub's kernel name in its binary's name list.
// @param args         stub function arguments; only originHandle is used here
// @param kernelOffset receives the offset stored at the matching name's index
// @return false when the stub function, its register event or the kernel name
//         is unknown, or when no offset is recorded for the matching name
bool KernelContext::DeviceContext::GetKernelOffset(StubFuncArgs const &args, uint64_t &kernelOffset) const
{
    StubFuncInfo stubFuncInfo {};
    if (!kernelContext_.GetStubFuncInfo(StubFuncPtr{args.originHandle}, stubFuncInfo)) {
        return false;
    }
    uint64_t registerId = kernelContext_.GetRegisterId(stubFuncInfo.hdl, true);
    RegisterEvent registerEvent {};
    if (!kernelContext_.GetRegisterEvent(registerId, registerEvent)) {
        return false;
    }
    auto const &kernelNames = registerEvent.kernelNames;
    auto const &kernelOffsets = registerEvent.kernelOffsets;
    auto it = std::find(kernelNames.cbegin(), kernelNames.cend(), stubFuncInfo.kernelName);
    if (it == kernelNames.cend()) {
        return false;
    }
    // Names and offsets are parallel vectors; never index past the offsets.
    auto const nameIndex = static_cast<std::size_t>(it - kernelNames.cbegin());
    if (nameIndex >= kernelOffsets.size()) {
        ERROR_LOG("kernel offset missing for kernel name index %zu", nameIndex);
        return false;
    }
    kernelOffset = kernelOffsets[nameIndex];
    DEBUG_LOG("kernel offset from pc start addr: 0x%lx", kernelOffset);
    return true;
}
// Looks up the offset of the kernel matching args.tilingKey among the kernels
// registered for args.originHandle.
// @return false (after logging) when the register event or a matching kernel
//         cannot be found
bool KernelContext::DeviceContext::GetKernelOffset(KernelHandleArgs const &args, uint64_t &kernelOffset) const
{
    const uint64_t registerId = kernelContext_.GetRegisterId(args.originHandle, true);
    RegisterEvent registerEvent {};
    if (!kernelContext_.GetRegisterEvent(registerId, registerEvent)) {
        ERROR_LOG("get register event failed. registerId: %lu", registerId);
        return false;
    }

    const bool found = GetKernelOffsetByTilingKey(registerEvent.kernelNames, registerEvent.kernelOffsets,
                                                  args.tilingKey, kernelOffset);
    if (!found) {
        ERROR_LOG("get kernel offset by tiling key failed. tilingKey: %lu", args.tilingKey);
        return false;
    }
    DEBUG_LOG("kernel offset from pc start addr: 0x%lx", kernelOffset);
    return true;
}
// Computes the PC start address as (kernel address - kernel offset).
// ArgsT selects the GetKernelAddr / GetKernelOffset overloads
// (StubFuncArgs or KernelHandleArgs).
// @return false when either lookup fails or the address is smaller than the offset
template <typename ArgsT>
bool KernelContext::DeviceContext::GetPcStartAddr(ArgsT const &args, uint64_t &pcStartAddr) const
{
    uint64_t addr {};
    if (!GetKernelAddr(args, addr)) {
        ERROR_LOG("get kernel addr failed");
        return false;
    }

    uint64_t offset {};
    if (!GetKernelOffset(args, offset)) {
        ERROR_LOG("get kernel offset from pc start failed");
        return false;
    }

    // Guard against unsigned wrap-around in the subtraction below.
    if (addr < offset) {
        ERROR_LOG("kernelAddr should never smaller than kernelOffset. 0x%lx vs. 0x%lx.",
                  addr, offset);
        return false;
    }

    pcStartAddr = addr - offset;
    DEBUG_LOG("get pc start addr 0x%lx", pcStartAddr);
    return true;
}
// Dumps the binary registered for `hdl` to "<outputDir>/<savedFileName>",
// creating the output directory when it does not exist.
// @return false when the handle is unknown, the directory cannot be created,
//         or the dump itself fails
bool KernelContext::DumpKernelObject(const KernelHandle *hdl, const string &outputDir, const string &savedFileName)
{
    const auto regIt = hdlToRegId_.find(hdl);
    if (regIt == hdlToRegId_.end()) {
        ERROR_LOG("Cannot find object reg id");
        return false;
    }

    const bool dirReady = IsExist(outputDir) || MkdirRecusively(outputDir);
    if (!dirReady) {
        return false;
    }

    return DumpKernelObject(regIt->second, JoinPath({outputDir, savedFileName}));
}
std::string KernelContext::GetLaunchName(uint64_t launchId)
{
return this->GetDeviceContext().GetLaunchName(launchId);
}
// Returns the kernel name recorded for a launch.
// @param launchId launch index; UINT64_MAX selects the latest launch
// @return the kernel name, or "invalid_launch_id" when the index is out of range
std::string KernelContext::DeviceContext::GetLaunchName(uint64_t launchId)
{
    const uint64_t resolvedId = (launchId == UINT64_MAX) ? GetLaunchId() : launchId;
    if (resolvedId < launchEvents_.size()) {
        return launchEvents_[resolvedId].kernelName;
    }
    return "invalid_launch_id";
}
uint32_t KernelContext::GetBlockId()
{
return this->GetDeviceContext().GetBlockId();
}
uint32_t KernelContext::DeviceContext::GetBlockId()
{
uint64_t launchId = GetLaunchId();
if (launchId == UINT64_MAX) {
return UINT32_MAX;
}
return launchEvents_[launchId].blockDim;
}
bool KernelContext::GetLaunchEvent(uint64_t launchId, LaunchEvent &event) const
{
return this->GetDeviceContext().GetLaunchEvent(launchId, event);
}
// Copies the launch event stored at index `launchId` into `event`.
// @return false (leaving `event` untouched) when the index is out of range
bool KernelContext::DeviceContext::GetLaunchEvent(uint64_t launchId, LaunchEvent &event) const
{
    if (launchId >= launchEvents_.size()) {
        return false;
    }
    event = launchEvents_[launchId];
    return true;
}
uint64_t KernelContext::GetRegisterId(uint64_t launchId)
{
return this->GetDeviceContext().GetRegisterId(launchId);
}
uint64_t KernelContext::DeviceContext::GetRegisterId(uint64_t launchId)
{
std::lock_guard<std::mutex> guard(this->deviceMutex_);
if (launchId >= launchEvents_.size()) {
return UINT64_MAX;
}
auto event = launchEvents_[launchId];
return this->kernelContext_.GetRegisterId(event.hdl, true);
}
uint64_t KernelContext::GetRegisterId(const void *key, bool withHandle)
{
std::lock_guard<std::mutex> guard(this->mtx_);
if (!withHandle) {
auto it = stubInfo_.find(key);
if (it == stubInfo_.end()) {
return UINT64_MAX;
}
key = it->second.hdl;
}
auto it = hdlToRegId_.find(key);
if (it == hdlToRegId_.end()) {
return UINT64_MAX;
}
return it->second;
}
// Copies the register event stored at index `regId` into `event`.
// @return false (leaving `event` untouched) when the index is out of range
bool KernelContext::GetRegisterEvent(uint64_t regId, RegisterEvent &event)
{
    const bool inRange = regId < registerEvents_.size();
    if (inRange) {
        event = registerEvents_[regId];
    }
    return inRange;
}
uint64_t KernelContext::GetLaunchId() const
{
return this->GetDeviceContext().GetLaunchId();
}
// Returns the index of the most recently recorded launch, or UINT64_MAX when
// no launch has been recorded.
uint64_t KernelContext::DeviceContext::GetLaunchId() const
{
    return launchEvents_.empty() ? UINT64_MAX : launchEvents_.size() - 1;
}
// Maps a (possibly truncated) kernel name to one of the full names in `fullNames`.
// The first full name that equals the given name, or that contains the given
// name followed by "_mix_aic" or "_mix_aiv", is returned.
// @param name      C string; at most MstxAPI::MAX_KERNEL_NAME_LENGTH characters
//                  are considered (as computed by GetValidLength). May be null.
// @param fullNames candidate full kernel names
// @return the matching full name, or an empty string when `name` is null or
//         nothing matches
static string GetFullKernelName(const char *name, const vector<string> &fullNames)
{
    if (name == nullptr) {
        return "";
    }
    const uint64_t length = GetValidLength(name, MstxAPI::MAX_KERNEL_NAME_LENGTH);
    const string subName(name, length);
    const string aicName = subName + MIX_AIC_TAIL;
    const string aivName = subName + MIX_AIV_TAIL;

    for (const auto &fullName : fullNames) {
        const bool matched = fullName == subName ||
                             fullName.find(aicName) != string::npos ||
                             fullName.find(aivName) != string::npos;
        if (matched) {
            return fullName;
        }
    }
    return "";
}
void KernelContext::AddFuncRegisterEvent(const KernelHandle *binHdl, const StubFunc *stubFunc,
const char *stubName, const void *kernelInfoExt, uint32_t funcMode)
{
std::string stubNameStr("anonymous");
if (stubName) {
stubNameStr = stubName;
}
do {
auto it = hdlToRegId_.find(binHdl);
if (it == hdlToRegId_.end()) {
break;
}
RegisterEvent event;
if (!GetRegisterEvent(it->second, event)) {
break;
}
string tmpName = GetFullKernelName(stubName, event.kernelNames);
if (!tmpName.empty()) {
stubNameStr = tmpName;
break;
}
tmpName = GetFullKernelName(static_cast<const char *>(kernelInfoExt), event.kernelNames);
if (!tmpName.empty()) {
stubNameStr = tmpName;
break;
}
} while (false);
std::string kernelInfoExtStr{};
if (kernelInfoExt) {
kernelInfoExtStr = static_cast<const char *>(kernelInfoExt);
}
stubInfo_.insert({stubFunc, StubFuncInfo{binHdl, stubNameStr, kernelInfoExtStr, funcMode}});
}
void KernelContext::AddLaunchEvent(const StubFunc *stubFunc, uint32_t blockDim,
const rtArgsEx_t *argsInfo, rtStream_t stm)
{
auto it = stubInfo_.find(stubFunc);
if (it == stubInfo_.end()) {
return;
}
this->GetDeviceContext().AddLaunchEvent(stubFunc, it->second, blockDim, argsInfo, stm);
}
void KernelContext::DeviceContext::AddLaunchEvent(const StubFunc *stubFunc, StubFuncInfo const &info,
uint32_t blockDim, const rtArgsEx_t *argsInfo, rtStream_t stm)
{
uint64_t pcStartAddr {};
if (!GetPcStartAddr(StubFuncArgs{stubFunc, nullptr}, pcStartAddr)) {
WARN_LOG("get pc start addr by stub func failed. pc start addr will fallback to 0");
}
auto stmInfo = LaunchManager::GetOrCreateStreamInfo(stm);
rtArgsEx_t argsInfoCopy{};
if (argsInfo) {
argsInfoCopy = *argsInfo;
}
launchEvents_.push_back({false, stubFunc, info.hdl, false, 0ULL,
pcStartAddr, argsInfo, blockDim, info.kernelName,
argsInfoCopy, stmInfo.binded});
}
void KernelContext::AddLaunchEvent(const KernelHandle *hdl, uint64_t tilingKey, uint32_t blockDim,
const rtArgsEx_t *argsInfo, rtStream_t stm)
{
std::string kernelName = GetNameByTilingKey(hdl, tilingKey);
this->GetDeviceContext().AddLaunchEvent(hdl, tilingKey, blockDim, argsInfo, kernelName, stm);
}
void KernelContext::DeviceContext::AddLaunchEvent(const KernelHandle *hdl, uint64_t tilingKey, uint32_t blockDim,
const rtArgsEx_t *argsInfo, std::string const &kernelName, rtStream_t stm)
{
uint64_t pcStartAddr {};
if (!GetPcStartAddr(KernelHandleArgs{hdl, nullptr, tilingKey}, pcStartAddr)) {
WARN_LOG("get pc start addr by handle failed. pc start addr will fallback to 0");
}
auto stmInfo = LaunchManager::GetOrCreateStreamInfo(stm);
rtArgsEx_t argsInfoCopy{};
if (argsInfo) {
argsInfoCopy = *argsInfo;
}
launchEvents_.push_back({true, nullptr, hdl, true, tilingKey,
pcStartAddr, argsInfo, blockDim, kernelName,
argsInfoCopy, stmInfo.binded});
}
bool KernelContext::GetMc2CtxFlag() const
{
return this->GetDeviceContext().GetMc2CtxFlag();
}
// Collects the addresses behind a second-level pointer argument.
// For the args slot `index`, every host-input entry whose addrOffset (in
// uint64_t units) equals `index` is decoded and one AddrInfo per pointer is
// appended to opMemInfo.secondPtrAddrInfos[index].addrInfoVec.
// Nothing is done when `index` has no entry in secondPtrAddrInfos or when
// argsInfo.hostInputInfoPtr is null.
// NOTE(review): argsInfo.args is not null-checked and the values read from the
// buffer are not bounds-checked — confirm callers guarantee a well-formed buffer.
void KernelContext::ParseSecondPtrAddrs(
const rtArgsEx_t &argsInfo, OpMemInfo &opMemInfo, uint32_t index) const
{
auto infosIt = opMemInfo.secondPtrAddrInfos.find(index);
if (infosIt == opMemInfo.secondPtrAddrInfos.cend()) {
return;
}
if (argsInfo.hostInputInfoPtr == nullptr) {
WARN_LOG("Host input ptr is null");
return;
}
auto *buff = static_cast<uint64_t *>(argsInfo.args);
for (auto i = 0; i < argsInfo.hostInputInfoNum; ++i) {
auto curHostInfoPtr = argsInfo.hostInputInfoPtr[i];
// addrOffset is a byte offset; convert it to an args slot index.
uint64_t hostIndex = curHostInfoPtr.addrOffset / sizeof(uint64_t);
if (hostIndex != uint64_t(index)) {continue;}
auto &secondPtrInfo = infosIt->second;
// Start of this entry's data block inside the args buffer.
uint64_t *hostInputsPtr = buff + curHostInfoPtr.dataOffset / sizeof(uint64_t);
// First word of the block: byte offset (from the block start) of the pointer values.
uint64_t ptrOffset = *hostInputsPtr / sizeof(uint64_t);
DEBUG_LOG("hostInput index:%lu addrOffset:%u dataOffset:%u ptrOffset:%lu", hostIndex,
curHostInfoPtr.addrOffset, curHostInfoPtr.dataOffset, ptrOffset);
// Shape descriptors occupy the words between the first word and the pointer values.
uint64_t *dimCntPtr = hostInputsPtr + 1;
size_t dimIdx = 0;
size_t ptrIdx = 0;
while (ptrOffset > 0 && dimIdx < ptrOffset - 1) {
// Descriptor header word: low 32 bits = dim, high 32 bits = cnt.
auto dimCntVal = dimCntPtr[dimIdx];
std::ostringstream oss;
uint32_t dim = static_cast<uint32_t>(dimCntVal & 0xFFFFFFFFULL);
uint32_t cnt = static_cast<uint32_t>((dimCntVal >> 32) & 0xFFFFFFFFULL);
oss << "dim:" << dim << " cnt:" << cnt;
dimIdx++;
// The header is followed by `dim` values; their product is the element count.
uint64_t totalDim = 1;
oss << " shape:[";
for (size_t j = 0; j < dim; ++j) {
auto dimVal = dimCntPtr[dimIdx + j];
totalDim *= dimVal;
oss << dimVal << ",";
}
oss << "] dtypeBytes:" << secondPtrInfo.dtypeBytes << std::hex << " addr:0x" <<
hostInputsPtr[ptrOffset + ptrIdx] << std::dec;
DEBUG_LOG("hostDfxInfo %.2048s", oss.str().c_str());
// Recorded length = element count * bytes per element.
secondPtrInfo.addrInfoVec.push_back(AddrInfo{hostInputsPtr[ptrOffset + ptrIdx],
totalDim * secondPtrInfo.dtypeBytes, MemInfoSrc::EXTRA, MemInfoDesc::DOUBLE_PTR, ptrIdx + 1});
// Skip this descriptor's values and move on to the next pointer.
dimIdx += dim;
ptrIdx++;
}
}
}
std::vector<AddrInfo> KernelContext::ParseMc2CtxAddrs(uint64_t addr) const
{
void *hostData;
if (aclrtMallocHostImplOrigin(&hostData, sizeof(HcclCombinOpParam)) != ACL_ERROR_NONE) {
return {};
}
Defer defer0(nullptr, [&hostData](std::nullptr_t&) {
if (aclrtFreeHostImplOrigin(hostData) != ACL_ERROR_NONE) {
ERROR_LOG("rtFreeHost failed");
}
});
aclError error = aclrtMemcpyImplOrigin(hostData, sizeof(HcclCombinOpParam),
reinterpret_cast<void *>(addr), sizeof(HcclCombinOpParam), ACL_MEMCPY_DEVICE_TO_HOST);
if (error != ACL_ERROR_NONE) {
ERROR_LOG("ParseMc2CtxAddrs rtMemcpy error: %d", error);
return {};
}
auto opParamPtr = reinterpret_cast<HcclCombinOpParam *>(hostData);
if (opParamPtr->multiFlag != 0U) {
ERROR_LOG("multiFlag is :%u, not 0, currently unable to parse mc2 address",
static_cast<uint32_t>(opParamPtr->multiFlag));
return {};
}
std::vector<AddrInfo> mc2CtxAddrs;
bool isThroughAdump = ArgsManager::Instance().GetThroughAdumpFlag();
MemInfoSrc infoSrc = isThroughAdump ? MemInfoSrc::EXTRA : MemInfoSrc::BYPASS;
int32_t deviceId = 0;
aclrtGetDeviceImplOrigin(&deviceId);
for (size_t i = 0; i < opParamPtr->rankNum; ++i) {
if (!isThroughAdump && static_cast<size_t>(deviceId) == i) { continue; }
mc2CtxAddrs.push_back({opParamPtr->windowsIn[i], opParamPtr->winSize, infoSrc, MemInfoDesc::HCCL_MC2_CONTEXT});
mc2CtxAddrs.push_back({opParamPtr->windowsOut[i], opParamPtr->winSize, infoSrc, MemInfoDesc::HCCL_MC2_CONTEXT});
}
return mc2CtxAddrs;
}
// Records the registration of a device binary under `hdl`.
// The binary bytes are copied into the event and event.bin.data is pointed at
// that copy; kernel names/offsets and feature flags are derived from the binary.
// Calls with a null handle, a null binary, or a binary length of
// MAX_MEM_BYTE_SIZE or more are ignored.
// NOTE(review): hasSimt_ is overwritten (not accumulated) on every
// registration — confirm that is intended.
void KernelContext::AddHdlRegisterEvent(const KernelHandle *hdl, const rtDevBinary_t *bin,
                                        const rtDevBinary_t *stubBin)
{
    if (hdl == nullptr || bin == nullptr) {
        return;
    }
    if (bin->length >= MAX_MEM_BYTE_SIZE) {
        return;
    }

    RegisterEvent event;
    event.hdl = hdl;
    event.bin = *bin;
    event.stubBin = stubBin;
    const auto *binBytes = static_cast<const uint8_t *>(bin->data);
    if (binBytes != nullptr) {
        event.binData.assign(binBytes, binBytes + bin->length);
        event.bin.data = event.binData.data();
    }
    event.isKernelWithDBI = !HasStaticStub(*bin);
    event.hasDebugLine = HasDebugLine(*bin);

    const char *binChars = static_cast<const char*>(bin->data);
    GetSymInfoFromBinary(binChars, bin->length, event.kernelNames, event.kernelOffsets);
    if (!event.kernelNames.empty()) {
        DEBUG_LOG("Register first kernelName is %s", event.kernelNames.front().c_str());
    }
    hasSimt_ = GetSimtSymbolFromBinary(binChars, bin->length);

    // The new event's index in registerEvents_ is its register id.
    hdlToRegId_[hdl] = registerEvents_.size();
    registerEvents_.push_back(std::move(event));
}
// Returns the context of the device id reported by aclrtGetDeviceImplOrigin,
// creating it on first use. The map is accessed while holding mtx_.
KernelContext::DeviceContext const &KernelContext::GetDeviceContext() const
{
    int32_t deviceId {};
    aclrtGetDeviceImplOrigin(&deviceId);

    std::lock_guard<std::mutex> guard(mtx_);
    auto ctxIt = deviceContextMap_.find(deviceId);
    if (ctxIt == deviceContextMap_.cend()) {
        // First use on this device: construct the context in place.
        ctxIt = deviceContextMap_.emplace(std::piecewise_construct,
                                          std::forward_as_tuple(deviceId),
                                          std::forward_as_tuple(*const_cast<KernelContext*>(this))).first;
    }
    return ctxIt->second;
}
// Returns the context of the device id reported by aclrtGetDeviceImplOrigin,
// creating it on first use. The map is accessed while holding mtx_.
KernelContext::DeviceContext &KernelContext::GetDeviceContext()
{
    int32_t deviceId {};
    aclrtGetDeviceImplOrigin(&deviceId);

    std::lock_guard<std::mutex> guard(mtx_);
    auto ctxIt = deviceContextMap_.find(deviceId);
    if (ctxIt == deviceContextMap_.cend()) {
        // First use on this device: construct the context in place.
        ctxIt = deviceContextMap_.emplace(std::piecewise_construct,
                                          std::forward_as_tuple(deviceId),
                                          std::forward_as_tuple(*this)).first;
    }
    return ctxIt->second;
}
// Returns the memory info held by the current device's context.
OpMemInfo &KernelContext::GetOpMemInfo()
{
    auto &deviceContext = GetDeviceContext();
    return deviceContext.GetOpMemInfo();
}
// Gives mutable access to this device context's memory info.
OpMemInfo &KernelContext::DeviceContext::GetOpMemInfo()
{
    return memInfo_;
}
void KernelContext::SetMC2Flag()
{
AicpuLaunchArgs& aicpuLaunchArgs = KernelContext::GetAicpuLaunchArgs();
string aiCoreKernelName = "invalid_kernel_name";
auto lastContext = LaunchManager::Local().GetLastContext();
if (lastContext != nullptr && lastContext->GetFuncContext() != nullptr) {
aiCoreKernelName = lastContext->GetFuncContext()->GetKernelName();
}
string aiCpuKernelName = aicpuLaunchArgs.opName;
bool isMC2 = false;
if (EndsWith(aiCpuKernelName, MC2_AICPU_SUFFIX)) {
string kernelName = aiCpuKernelName.substr(0, aiCpuKernelName.rfind(MC2_AICPU_SUFFIX));
isMC2 = aicpuLaunchArgs.isValid && (aiCoreKernelName == kernelName ||
StartsWith(aiCoreKernelName, kernelName + "_"));
}
this->GetDeviceContext().SetMC2Flag(isMC2);
}
bool KernelContext::GetMC2Flag()
{
return this->GetDeviceContext().GetMC2Flag();
}
bool KernelContext::GetLcclFlag()
{
return this->GetDeviceContext().GetLcclFlag();
}
void KernelContext::SetLcclFlag(const bool &isLccl)
{
this->GetDeviceContext().SetLcclFlag(isLccl);
}
void KernelContext::ParseMetaDataFromBinary(const rtDevBinary_t &binary, const rtArgsEx_t *argsInfo)
{
return this->GetDeviceContext().ParseMetaDataFromBinary(binary, argsInfo);
}
void KernelContext::ArchiveMemInfo()
{
return this->GetDeviceContext().ArchiveMemInfo();
}
// Body of the listener thread: selects device `devId`, then keeps processing
// reports until stopListen_ becomes true. Each attempt tries
// aclrtProcessReportImplOrigin first and rtProcessReportOrigin only if that
// fails, using a timeout value of 1000.
void KernelContext::DeviceContext::ListenCallbackTask(int32_t devId)
{
    const auto ret = aclrtSetDeviceImplOrigin(devId);
    if (ret != ACL_ERROR_NONE) {
        ERROR_LOG("Set device to %d failed, ret=%u, stop listening",
                  devId, ret);
        return;
    }
    DEBUG_LOG("set device id to %d for listen callback", devId);

    constexpr int32_t timeout = 1000;
    while (!stopListen_) {
        const bool processed = aclrtProcessReportImplOrigin(timeout) == ACL_ERROR_NONE ||
                               rtProcessReportOrigin(timeout) == RT_ERROR_NONE;
        if (processed) {
            DEBUG_LOG("Finish one callback function");
        }
    }
}
// Subscribes `stm` to report processing by the listener thread, starting that
// thread if it is not running yet. Already subscribed streams return true
// immediately.
// The thread id passed to the subscribe call is obtained by printing
// std::thread::id to a stream and parsing the text as a number.
// @return false when the thread id cannot be parsed or both subscribe calls fail
bool KernelContext::DeviceContext::SubscribeReport(rtStream_t stm)
{
    auto &streamInfo = LaunchManager::GetOrCreateStreamInfo(stm);
    if (streamInfo.subscribed) {
        return true;
    }

    stopListen_ = false;
    if (!listenThread_.joinable()) {
        listenThread_ = thread(&KernelContext::DeviceContext::ListenCallbackTask, this,
                               ::DeviceContext::Local().GetDeviceId());
    }

    std::ostringstream idStream;
    idStream << listenThread_.get_id();
    uint64_t tid;
    if (!StringToNum(idStream.str(), tid)) {
        ERROR_LOG("get thread id number failed: %s", idStream.str().c_str());
        return false;
    }

    // Try the acl API first; fall back to the rt API only when it fails.
    const bool subscribed = aclrtSubscribeReportImplOrigin(tid, stm) == ACL_ERROR_NONE ||
                            rtSubscribeReportOrigin(tid, stm) == RT_ERROR_NONE;
    if (!subscribed) {
        DEBUG_LOG("subscribe report for stream failed");
        return false;
    }
    DEBUG_LOG("Subscribe report success.");
    streamInfo.subscribed = true;
    return true;
}
// Asks the listener loop to stop and waits for the listener thread, if one
// was started.
KernelContext::DeviceContext::~DeviceContext()
{
    stopListen_ = true;
    if (!listenThread_.joinable()) {
        return;
    }
    listenThread_.join();
}
// Attaches a DBI binary to an existing register event and rebuilds that
// event's kernel names/offsets from it, while holding mtx_.
// Only the ADDRESS of `devBinary` is stored, so the referenced object must
// outlive its use through the register event.
// @return false when `registerId` is out of range
bool KernelContext::SetDbiBinary(uint64_t registerId, rtDevBinary_t const &devBinary)
{
    std::lock_guard<std::mutex> guard(this->mtx_);
    if (registerId >= registerEvents_.size()) {
        ERROR_LOG("registerId exceeds size of registerEvents. %lu vs. %zu.", registerId, registerEvents_.size());
        return false;
    }
    RegisterEvent &registerEvent = registerEvents_[registerId];
    registerEvent.stubBin = &devBinary;
    // Symbol info now describes the DBI binary instead of the original one.
    registerEvent.kernelNames.clear();
    registerEvent.kernelOffsets.clear();
    GetSymInfoFromBinary(static_cast<const char*>(devBinary.data), devBinary.length,
                         registerEvent.kernelNames, registerEvent.kernelOffsets);
    return true;
}
// Explicit instantiations of UpdatePcStartAddrByDbi for the two argument
// types used in this file: StubFuncArgs and KernelHandleArgs.
template bool KernelContext::DeviceContext::UpdatePcStartAddrByDbi(uint64_t launchId, StubFuncArgs const &args);
template bool KernelContext::DeviceContext::UpdatePcStartAddrByDbi(uint64_t launchId, KernelHandleArgs const &args);
// Recomputes the PC start address for `args` and stores it in the launch
// event at index `launchId`.
// @return false (leaving the event untouched) when launchId is out of range
//         or the address cannot be resolved
template <typename ArgsT>
bool KernelContext::DeviceContext::UpdatePcStartAddrByDbi(uint64_t launchId, ArgsT const &args)
{
    if (launchId >= launchEvents_.size()) {
        ERROR_LOG("launchId exceeds size of launchEvents. %lu vs. %zu.", launchId, launchEvents_.size());
        return false;
    }

    uint64_t resolvedAddr {};
    if (!GetPcStartAddr(args, resolvedAddr)) {
        ERROR_LOG("update pc start addr by dbi failed. launchId: %lu", launchId);
        return false;
    }

    launchEvents_[launchId].pcStartAddr = resolvedAddr;
    DEBUG_LOG("pc start addr for kernel with dbi: 0x%lx", resolvedAddr);
    return true;
}