/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
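/*!
 * \file device_execute_context.cpp
 * \brief Implementation of DeviceExecuteContext: control-flow initialisation, root function
 *        allocation/stitching, device task submission and control-flow cache replay.
 */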
#include "machine/device/dynamic/context/device_execute_context.h"
#include "machine/device/distributed/shmem_wait_until.h"
#include "tileop/distributed/comm_context.h"
#include <cinttypes>
namespace npu::tile_fwk::dynamic {
/**
 * \brief Tells whether the next root is still covered by the activated control-flow cache.
 * \return false when the control-flow cache is not activated; otherwise true while
 *         duppedRootCount is below the anchor's rootTaskCount.
 */
bool DeviceExecuteContext::DuppedRootCached()
{
    // Short-circuit keeps the anchor untouched when the cache is not activated.
    return controlFlowCacheActivated && (duppedRootCount < devProg->ctrlFlowCacheAnchor->rootTaskCount);
}
/**
 * \brief Counts one more dupped root and reports whether the count just reached the cached total.
 * \return false (and no counter update) when the control-flow cache is not activated; otherwise
 *         true exactly when the incremented duppedRootCount equals the anchor's rootTaskCount.
 */
bool DeviceExecuteContext::DuppedRootUpdateAndCachedAllSubmitted()
{
    if (controlFlowCacheActivated) {
        ++duppedRootCount;
        return duppedRootCount == devProg->ctrlFlowCacheAnchor->rootTaskCount;
    }
    return false;
}
/**
 * \brief Prepares the context for one control-flow run.
 *
 * Stores the push callback and start arguments, initialises the parallel context, workspace,
 * slot/stitch/task contexts, fills the symbol table from the input symbols in \p startArgs and
 * finally sets up the stitch cache and the metadata slab allocator.
 *
 * \param startArgs  Start arguments; its devProg becomes the active program.
 * \param tPushTask  Callback later invoked by PushTask().
 * \return DEVICE_MACHINE_OK (no failure path exists in this function).
 */
int DeviceExecuteContext::RunInit(DevStartArgs* startArgs, PushTaskEntry tPushTask)
{
    PerfBegin(PERF_EVT_CONTROL_FLOW_INIT);
    this->pushTask = tPushTask;
    this->args = startArgs;
    this->devProg = startArgs->devProg;
    parallelCtx.InitParallel(devProg->GetParallelism());
    workspace.Init(startArgs);
    stitchTaskLoopNumThreshold = devProg->stitchMaxFunctionNum;
    slotContext.InitAllocator(workspace, devProg->slotSize);
    slotContext.FillInputOutputSlot(devProg, startArgs);
    stitchContext.Init(devProg, workspace);
    taskContext.InitAllocator(devProg, workspace, startArgs);

    // The symbol table is bound to the workspace first and sized afterwards.
    workspace.SetupVector(symbolTable);
    symbolTable.resize(devProg->symbolTable.size());
    for (int index = 0; index < startArgs->GetInputSymbolSize(); ++index) {
        DevInputSymbol& param = startArgs->GetInputSymbol(index);
        // Input symbols are scattered into the table via the program's index list.
        int inputSymbolIndex = this->devProg->startArgsInputSymbolIndexList[index];
        symbolTable[inputSymbolIndex] = param.value;
        DEV_INFO("Param %d Symbol Table %d = %ld.", index, inputSymbolIndex, param.value);
    }

    workspace.AllocateStitchCache();
    /* The remaining portion of AICPU workspace meta memory must support reclamation. */
    workspace.InitMetadataSlabAllocator();
    PerfEnd(PERF_EVT_CONTROL_FLOW_INIT);
    DEV_INFO("Image size is %lu.", devProg->GetSize());
    return DEVICE_MACHINE_OK;
}
/**
 * \brief Builds the context around the program referenced by \p startArgs.
 *
 * Dumps the program when verbose debug is enabled and, if a control-flow entry point is
 * supplied, wraps it in a DeviceExecuteProgram and asks the AOT code pool to map it executable.
 *
 * \param startArgs  Start arguments providing devProg and the optional controlFlowEntry.
 */
DeviceExecuteContext::DeviceExecuteContext(DevStartArgs* startArgs)
{
    PerfBegin(PERF_EVT_INIT);
    devProg = startArgs->devProg;
    DEV_IF_VERBOSE_DEBUG
    {
        const std::string programDump = devProg->Dump(0, true);
        DEV_VERBOSE_DEBUG_SPLIT("[DEVICE] %s.", programDump.c_str());
    }

    PerfBegin(PERF_EVT_CONTROL_FLOW_MAPEXE);
    if (startArgs->controlFlowEntry != nullptr) {
        // The entry is stored as an untyped pointer; convert it to the callable entry type.
        auto entryPoint = reinterpret_cast<AOTBinaryControlFlow::controlFlowEntry>(
            const_cast<void*>(startArgs->controlFlowEntry));
        execProg = DeviceExecuteProgram(devProg, entryPoint);
        AOTCodePool::GetCodePool().MapExec();
    }
    PerfEnd(PERF_EVT_CONTROL_FLOW_MAPEXE);
    PerfEnd(PERF_EVT_INIT);
}
/**
 * \brief Hands \p dynTask to the registered push callback and advances the task id.
 * \param dynTask  Device task forwarded to the callback together with this context.
 */
void DeviceExecuteContext::PushTask(DynDeviceTask* dynTask)
{
    pushTask(dynTask, this);
    // The id is bumped only after the callback has seen the task.
    ++taskId;
}
// Reports end-of-run statistics: task context stats first, then the workspace memory usage
// dump tagged "End ExecDyn".
void DeviceExecuteContext::ShowStats()
{
taskContext.ShowStats();
workspace.DumpMemoryUsage("End ExecDyn");
}
/**
 * \brief Replays every device task recorded in the control-flow cache.
 *
 * For each cached task the backed-up scheduling data is restored through the cache anchor,
 * the ready-task count and parallel info are refreshed, and the task is pushed again.
 *
 * \param startArgs  Start arguments; its devProg becomes the active program.
 * \param tPushTask  Callback invoked for every replayed task.
 */
void DeviceExecuteContext::GELaunchRunCached(DevStartArgs* startArgs, PushTaskEntry tPushTask)
{
    PerfBegin(PERF_EVT_CONTROL_FLOW_INIT);
    pushTask = tPushTask;
    args = startArgs;
    devProg = startArgs->devProg;
    PerfEnd(PERF_EVT_CONTROL_FLOW_INIT);
    PerfMtTrace(PERF_TRACE_INIT, CTRL_CPU_THREAD_IDX);

    PerfBegin(PERF_EVT_CONTROL_FLOW);
    auto& cacheAnchor = devProg->ctrlFlowCacheAnchor;
    for (size_t taskIdx = 0; taskIdx < cacheAnchor->deviceTaskCount; ++taskIdx) {
        DynDeviceTask* cachedTask =
            reinterpret_cast<DynDeviceTask*>(cacheAnchor->deviceTaskCacheList[taskIdx].dynTaskBase);

        // Bring the task's mutable scheduling state back to its recorded snapshot.
        cacheAnchor->PredCountDataRestore(cachedTask);
        cacheAnchor->ReadyQueueDataRestore(cachedTask);
        cacheAnchor->DieReadyQueueDataRestore(cachedTask);
        cacheAnchor->MixTaskDataRestore(cachedTask);
        taskContext.UpdateReadyTaskNum(cachedTask->readyQueueBackup->readyTaskNum);

        parallelCtx.info = cachedTask->parallelInfo;
        if (parallelCtx.info.forId != 0) {
            parallelCtx.isInParallelForScope = true;
        }

        PROF_STAGE_BEGIN(PERF_EVT_STAGE_PUSH_TASK, "push.before\n");
        DumpDeviceTask(taskId, cachedTask);
        PerfMtTrace(PERF_TRACE_DEV_TASK_BUILD, CTRL_CPU_THREAD_IDX);
        PushTask(cachedTask);
        PROF_STAGE_END(PERF_EVT_STAGE_PUSH_TASK, "push.after\n");
    }
    PerfEnd(PERF_EVT_CONTROL_FLOW);
}
/**
 * \brief Executes the control-flow binary with the runtime callback table.
 *
 * The error state is sampled before and after the call; a failure is reported only when the
 * state changed during the call and the new state is not DEVICE_MACHINE_OK.
 *
 * \param startArgs  Start arguments forwarded to the control-flow entry.
 * \return DEVICE_MACHINE_OK on success, otherwise the error state set during execution.
 */
int DeviceExecuteContext::RunControlFlow(DevStartArgs* startArgs)
{
    PerfBegin(PERF_EVT_CONTROL_FLOW);
    // Callback table handed to the control-flow binary; sized by RuntimeCallStage::T_RUNTIME_CALL_MAX.
    RuntimeCallEntryType runtimeCallList[static_cast<uint32_t>(RuntimeCallStage::T_RUNTIME_CALL_MAX)] = {
        DeviceExecuteRuntimeCallRootAlloc,
        DeviceExecuteRuntimeCallRootStitch,
        DeviceExecuteRuntimeCallLog,
        DeviceExecuteRuntimeCallShmemAllocator,
        DeviceExecuteRuntimeCallSlotMarkNeedAlloc,
        DeviceExecuteRuntimeCallGetLoopDieId,
        DeviceExecuteRuntimeCallSetLoopDieId,
    };
    int originalErrorState = this->GetErrorState();
    execProg.controlFlowBinary.CallControlFlow(this, symbolTable.data(), runtimeCallList, startArgs);
    int finalErrorState = this->GetErrorState();
    if (finalErrorState != originalErrorState && finalErrorState != DEVICE_MACHINE_OK) {
        DEV_ERROR(
            CtrlErr::CTRL_FLOW_EXEC_FAILED, "#ctrl.ctrlflow.leave: Control flow execution failed with error code: %d",
            finalErrorState);
        // Close the perf event on the failure path too, so PerfBegin/PerfEnd stay paired.
        PerfEnd(PERF_EVT_CONTROL_FLOW);
        return finalErrorState;
    }
    PerfEnd(PERF_EVT_CONTROL_FLOW);
    return DEVICE_MACHINE_OK;
}
/**
 * \brief Runs initialisation followed by the control flow (the non-cached path of full-cache launch).
 * \param startArgs  Start arguments for this launch.
 * \param tPushTask  Push callback stored by RunInit().
 * \return DEVICE_MACHINE_OK when both steps succeed; DEVICE_MACHINE_ERROR when either step
 *         reports anything else (the specific code is not forwarded).
 */
int DeviceExecuteContext::GELaunchFullCacheRunControlFlow(DevStartArgs* startArgs, PushTaskEntry tPushTask)
{
    const int initStatus = RunInit(startArgs, tPushTask);
    if (unlikely(initStatus != DEVICE_MACHINE_OK)) {
        return DEVICE_MACHINE_ERROR;
    }

    const int flowStatus = RunControlFlow(startArgs);
    if (unlikely(flowStatus != DEVICE_MACHINE_OK)) {
        return DEVICE_MACHINE_ERROR;
    }
    return flowStatus;
}
/**
 * \brief Full-cache launch: replays the cache when the anchor reports it as activated for
 *        \p startArgs, otherwise runs init + control flow.
 * \param startArgs  Start arguments for this launch.
 * \param tPushTask  Push callback.
 */
void DeviceExecuteContext::GELaunchFullCache(DevStartArgs* startArgs, PushTaskEntry tPushTask)
{
    if (!devProg->ctrlFlowCacheAnchor->IsActivatedFullCache(startArgs)) {
        DEV_TRACE_DEBUG(CtrlEvent(none(), ControlFlowCacheFullRunControl()));
        // NOTE(review): the status of the control-flow run is discarded because this function returns void.
        static_cast<void>(GELaunchFullCacheRunControlFlow(startArgs, tPushTask));
        return;
    }
    DEV_TRACE_DEBUG(CtrlEvent(none(), ControlFlowCacheFullRunCache()));
    GELaunchRunCached(startArgs, tPushTask);
}
/**
 * \brief Launch entry: lets a recording cache anchor capture the inputs/outputs, then runs the
 *        partial-cache launch path.
 * \param startArgs  Start arguments for this launch.
 * \param tPushTask  Push callback.
 * \return DEVICE_MACHINE_OK on success, DEVICE_MACHINE_ERROR for any failure of the launch path.
 */
int DeviceExecuteContext::GELaunch(DevStartArgs* startArgs, PushTaskEntry tPushTask)
{
    if (devProg->ctrlFlowCacheAnchor->IsRecording()) {
        devProg->ctrlFlowCacheAnchor->InitInputOutput(startArgs);
    }

    const int launchStatus = GELaunchPartialCache(startArgs, tPushTask);
    return unlikely(launchStatus != DEVICE_MACHINE_OK) ? DEVICE_MACHINE_ERROR : DEVICE_MACHINE_OK;
}
/**
 * \brief Partial-cache launch: optionally replays the cached device tasks, then always runs
 *        init + control flow.
 *
 * When the anchor reports the partial cache as activated, controlFlowCacheActivated is set so
 * that DuppedRootCached() can skip roots that were already replayed.
 *
 * \param startArgs  Start arguments for this launch.
 * \param tPushTask  Push callback.
 * \return DEVICE_MACHINE_OK on success, DEVICE_MACHINE_ERROR when init or control flow fails.
 */
int DeviceExecuteContext::GELaunchPartialCache(DevStartArgs* startArgs, PushTaskEntry tPushTask)
{
    DEV_TRACE_DEBUG(CtrlEvent(
        none(),
        Workspace(Range(
            startArgs->contextWorkspaceAddr, startArgs->contextWorkspaceAddr + startArgs->contextWorkspaceSize))));

    const bool replayCache = devProg->ctrlFlowCacheAnchor->IsActivatedPartialCache(startArgs);
    if (replayCache) {
        controlFlowCacheActivated = true;
        DEV_TRACE_DEBUG(CtrlEvent(
            none(), ControlFlowCachePartRunCache(
                        devProg->ctrlFlowCacheAnchor->deviceTaskCount, devProg->ctrlFlowCacheAnchor->rootTaskCount)));
        GELaunchRunCached(startArgs, tPushTask);
    }

    DEV_TRACE_DEBUG(CtrlEvent(none(), ControlFlowCacheFullRunControl()));
    if (unlikely(RunInit(startArgs, tPushTask) != DEVICE_MACHINE_OK)) {
        return DEVICE_MACHINE_ERROR;
    }
    const int flowStatus = RunControlFlow(startArgs);
    if (unlikely(flowStatus != DEVICE_MACHINE_OK)) {
        return DEVICE_MACHINE_ERROR;
    }
    return flowStatus;
}
// Always returns false, so the "AICore Free" early-submit branch in CallRootFunctionStitch()
// is never taken. NOTE(review): looks like a placeholder for a real idle check — confirm.
bool DeviceExecuteContext::AiCoreFree()
{
return false;
}
// Emits debug trace events describing every dupped function of a device task: its workspace,
// incast/outcast counts and ranges, and its expression list.
//   taskId     - id used in the RUid of each event (note: shadows the member of the same name)
//   deviceTask - task whose dynFuncDataCacheList entries are dumped
void DeviceExecuteContext::DumpDeviceTask(uint64_t taskId, DynDeviceTask* deviceTask)
{
// Bail out unless verbose debug is on; the macro is followed by an empty block so that the
// else-branch can return early (presumably DEV_IF_VERBOSE_DEBUG expands to an `if` — confirm).
DEV_IF_VERBOSE_DEBUG {}
else
{
return;
}
for (uint64_t dupIdx = 0; dupIdx < deviceTask->dynFuncDataCacheListSize; dupIdx++) {
DevAscendFunctionDuppedData* dupped = deviceTask->dynFuncDataCacheList[dupIdx].duppedData;
// Workspace of this dupped function.
DEV_TRACE_DEBUG(
REvent(RUid(taskId, dupIdx, dupped->GetSource()->GetRootIndex()), dupped->SchemaGetWorkspace()));
// Incast count followed by one event per incast range.
size_t incastSize = dupped->GetSource()->GetIncastSize();
DEV_TRACE_DEBUG(REvent(RUid(taskId, dupIdx, dupped->GetSource()->GetRootIndex()), RActIncastCount(incastSize)));
for (size_t i = 0; i < incastSize; ++i) {
DEV_TRACE_DEBUG(REvent(
RUid(taskId, dupIdx, dupped->GetSource()->GetRootIndex()),
RActIncast(i, dupped->SchemaGetIncastRange(i))));
}
// Outcast count followed by one event per outcast range.
size_t outcastSize = dupped->GetSource()->GetOutcastSize();
DEV_TRACE_DEBUG(
REvent(RUid(taskId, dupIdx, dupped->GetSource()->GetRootIndex()), RActOutcastCount(outcastSize)));
for (size_t i = 0; i < outcastSize; ++i) {
DEV_TRACE_DEBUG(REvent(
RUid(taskId, dupIdx, dupped->GetSource()->GetRootIndex()),
RActOutcast(i, dupped->SchemaGetOutcastRange(i))));
}
// Expression count, then the expression list itself via the _SPLIT trace variant.
DEV_TRACE_DEBUG(REvent(
RUid(taskId, dupIdx, dupped->GetSource()->GetRootIndex()),
RActExpressionCount(dupped->GetExpressionSize())));
DEV_TRACE_DEBUG_SPLIT(
REvent(RUid(taskId, dupIdx, dupped->GetSource()->GetRootIndex()), expr(dupped->SchemaGetExpressionList())));
}
}
/**
 * \brief Records \p dynTask into the control-flow cache while the anchor is recording.
 *
 * Nothing happens when the anchor is not recording. While recording has not been stopped the
 * task's scheduling data and addresses are backed up; the task is appended to the anchor's
 * device task list in either case.
 *
 * \param dynTask  Device task that is about to be pushed.
 */
void DeviceExecuteContext::ProcessControlFlowCacheRecord(DynDeviceTask* dynTask)
{
    auto& cacheAnchor = devProg->ctrlFlowCacheAnchor;
    if (!cacheAnchor->IsRecording()) {
        return;
    }

    if (!cacheAnchor->IsRecordingStopped()) {
        cacheAnchor->PredCountDataBackup(dynTask);
        cacheAnchor->ReadyQueueDataBackup(dynTask);
        cacheAnchor->DieReadyQueueDataBackup(dynTask);
        cacheAnchor->MixTaskDataBackup(dynTask);
        cacheAnchor->IncastOutcastAddrBackup(dynTask);
        cacheAnchor->TaskAddrBackupWorkspace(dynTask);
        cacheAnchor->RuntimeAddrBackup(
            slotContext.GetSlotList(), workspace.GetRuntimeOutcastTensorPool(), devProg->slotSize,
            devProg->runtimeOutcastPoolSize, workspace.GetTensorAllocator(), devProg->GetParallelism());
    }
    cacheAnchor->AppendDeviceTask(dynTask);
}
void DeviceExecuteContext::CalcControlMaxAicore()
{
currentMaxC_ = 0;
currentMaxV_ = 0;
const auto& stitchedList = stitchContext.GetStitchedList();
for (size_t i = 0; i < stitchedList.size(); i++) {
const DevAscendFunction* sourceFunc = stitchedList[i].GetSource();
if (sourceFunc != nullptr) {
currentMaxC_ += sourceFunc->GetMaxC();
currentMaxV_ += sourceFunc->GetMaxV();
}
}
if (currentMaxC_ == 0 && currentMaxV_ == 0) {
currentMaxC_ = devProg->devArgs.nrValidAic;
currentMaxV_ = currentMaxC_ * AIV_NUM_PER_AI_CORE;
return;
}
uint32_t oriAivNum = devProg->devArgs.nrValidAic * AIV_NUM_PER_AI_CORE;
currentMaxC_ = currentMaxC_ >= devProg->devArgs.nrValidAic ? devProg->devArgs.nrValidAic : currentMaxC_;
currentMaxV_ = currentMaxV_ >= oriAivNum ? oriAivNum : currentMaxV_;
if (devProg->devArgs.archInfo == ArchInfo::DAV_3510) {
if (currentMaxC_ * AIV_NUM_PER_AI_CORE >= currentMaxV_) {
currentMaxV_ = currentMaxC_ * AIV_NUM_PER_AI_CORE;
} else {
currentMaxV_ = (currentMaxV_ & 1) ? currentMaxV_ + 1 : currentMaxV_;
currentMaxC_ = currentMaxV_ / AIV_NUM_PER_AI_CORE;
}
}
DEV_INFO("[CalcControlMaxAicore] stitchedSize=%u, maxC=%u, maxV=%u",
static_cast<uint32_t>(stitchedList.size()), currentMaxC_, currentMaxV_);
}
/**
 * \brief Builds the ShmemWaitUntilCache for all AICPU operations of \p dynTask.
 *
 * A cache object is taken from the workspace slab allocator, every operation whose binary has
 * core type MachineType::AICPU is prepared into cache->taskArray, a hash table is built over the
 * prepared tasks and the cache is attached to dynTask->shmemWaitUntilCacheBackup.
 *
 * \param dynTask  Device task whose AICPU operations are scanned.
 * \return DEVICE_MACHINE_OK on success; DEVICE_MACHINE_ERROR when the cache cannot be obtained,
 *         the number of AICPU tasks exceeds AICPU_TASK_ARRAY_SIZE, or PrepareTask fails.
 */
int DeviceExecuteContext::PrepareShmemWaitUntilTasks(DynDeviceTask* dynTask)
{
    namespace dist = npu::tile_fwk::Distributed;

    auto funcDataList = reinterpret_cast<npu::tile_fwk::DynFuncData*>(&dynTask->GetDynFuncDataList()->At(0));
    auto hcclContextAddr = funcDataList->startArgs->commContexts;

    size_t cacheBytes = sizeof(dist::ShmemWaitUntilCache);
    WsAllocation slab = ControlFlowAllocateSlab(
        devProg, cacheBytes, workspace.SlabAlloc(cacheBytes, WsAicpuSlabMemType::SHMEM_WAIT_UNTIL_CACHE));
    auto cache = slab.As<dist::ShmemWaitUntilCache>();
    if (cache == nullptr) {
        DEV_ERROR(CtrlErr::CTRL_FLOW_EXEC_FAILED, "AllocateCache failed for ShmemWaitUntilCache");
        return DEVICE_MACHINE_ERROR;
    }

    uint32_t preparedCount = 0;
    for (uint64_t funcIdx = 0; funcIdx < dynTask->dynFuncDataCacheListSize; ++funcIdx) {
        auto&& funcEntry = dynTask->dynFuncDataCacheList[funcIdx];
        auto callees = funcEntry.calleeList;
        for (size_t opIdx = 0; opIdx < funcEntry.devFunc->GetOperationSize(); ++opIdx) {
            auto coreType = dynTask->cceBinary[callees[opIdx]].coreType;
            if (coreType != static_cast<int>(MachineType::AICPU)) {
                continue;
            }
            if (preparedCount >= dist::AICPU_TASK_ARRAY_SIZE) {
                DEV_ERROR(
                    CtrlErr::CTRL_FLOW_EXEC_FAILED,
                    "PrepareShmemWaitUntilTasks: taskCount=%u exceeds AICPU_TASK_ARRAY_SIZE=%lu", preparedCount,
                    dist::AICPU_TASK_ARRAY_SIZE);
                return DEVICE_MACHINE_ERROR;
            }

            // Task id packs the function index above TASKID_TASK_BITS and the operation index below.
            uint32_t aicpuTaskId =
                (static_cast<uint32_t>(funcIdx) << TASKID_TASK_BITS) | static_cast<uint32_t>(opIdx);
            auto& leafCode = dynTask->aicpuLeafBinary[callees[opIdx]].aicpuLeafCode;
            auto prepareRet = dist::ShmemWaitUntilImpl::PrepareTask(
                aicpuTaskId, leafCode, cache->taskArray, preparedCount, funcDataList, hcclContextAddr);
            if (prepareRet != DEVICE_MACHINE_OK) {
                DEV_ERROR(
                    CtrlErr::CTRL_FLOW_EXEC_FAILED, "PrepareTask failed: aicpuTaskId=%u ret=%d", aicpuTaskId,
                    prepareRet);
                return DEVICE_MACHINE_ERROR;
            }
            ++preparedCount;
        }
    }

    cache->taskCount = preparedCount;
    dist::ShmemWaitUntilImpl::BuildHashTable(cache, preparedCount);
    dynTask->shmemWaitUntilCacheBackup = cache;
    DEV_INFO("[ControlFlow] PrepareShmemWaitUntilTasks END: prepared %u tasks, cache=%p", preparedCount, cache);
    return DEVICE_MACHINE_OK;
}
/**
 * \brief Turns the currently stitched functions into one device task, pushes it and recycles
 *        the stitch state.
 *
 * Steps: decide slot and incast/outcast addresses, compute the core limits, build the device
 * task, optionally prepare the shmem wait-until cache, recycle tensor workspace and reset the
 * stitch context, record the task into the control-flow cache and finally push it.
 * An empty stitch context is a no-op that returns DEVICE_MACHINE_OK.
 *
 * \param withoutTail             Forwarded to taskContext.BuildDeviceTaskData().
 * \param isLastTask              Stored on the task via SetLastTask() (see condition below).
 * \param isParallelIterLastTask  Stored on the task via SetParallelSameIterLastDevTask().
 * \return DEVICE_MACHINE_OK on success; DEVICE_MACHINE_ERROR when address decision, task
 *         building or shmem wait-until preparation fails.
 */
int DeviceExecuteContext::SubmitToAicoreAndRecycleMemory(bool withoutTail, bool isLastTask, bool isParallelIterLastTask)
{
    DEV_VERBOSE_DEBUG("Submit stitch task");
    DEV_TRACE_DEBUG(DEvent(taskId, DActSubmit(stitchContext.Size())));
    AutoScopedPerf asp(PERF_EVT_SUBMIT_AICORE);
    if (stitchContext.Empty()) {
        DEV_INFO("Stitch context is empty.");
        return DEVICE_MACHINE_OK;
    }

    PROF_STAGE_BEGIN(PERF_EVT_DECIDE_SLOT_ADDRESS, "slotaddr.before\n");
    // NOTE(review): the result of DecideSlotAddress (if it returns one) was never captured, so the
    // status check that used to follow this stage could not fire and has been dropped.
    stitchContext.DecideSlotAddress(slotContext.GetSlotList(), slotContext.GetSlotSize());
    PROF_STAGE_END(PERF_EVT_DECIDE_SLOT_ADDRESS, "slotaddr.after\n");

    PROF_STAGE_BEGIN(PERF_EVT_DECIDE_INCAST_ADDRESS, "incastaddr.before\n");
    int ret = stitchContext.DecideIncastOutcast(taskId);
    PROF_STAGE_END(PERF_EVT_DECIDE_INCAST_ADDRESS, "incastaddr.after\n");
    if (unlikely(ret != DEVICE_MACHINE_OK)) {
        return DEVICE_MACHINE_ERROR;
    }

    DEV_IF_VERBOSE_DEBUG
    {
        stitchContext.DumpStitchInfo();
#if !DEBUG_INFINITE_LIFETIME
        stitchContext.VerifyStitchedListMemory(*args);
#endif
    }
#if DEBUG_MEM_DUMP_LEVEL >= DEBUG_MEM_DUMP_FULL
    workspace.MarkAsNewStitchWindow();
#endif

    CalcControlMaxAicore();
    PROF_STAGE_BEGIN(PERF_EVT_STAGE_BUILD_TASK, "BuildDeviceTaskData.before\n");
    DynDeviceTask* dynTask = taskContext.BuildDeviceTaskData(stitchContext, taskId, devProg, withoutTail);
    if (dynTask == nullptr) {
        DEV_ERROR(DevCommonErr::NULLPTR, "#ctrl.buildtask.leave: Build device task data failed.");
        return DEVICE_MACHINE_ERROR;
    }
    dynTask->SetMaxCV(currentMaxC_, currentMaxV_);
    if (parallelCtx.isInParallelForScope) {
        dynTask->SetParallelInfo(parallelCtx.info);
    }
    // The last-task flags are set unless the anchor is recording a shape other than the cache origin.
    if (!devProg->ctrlFlowCacheAnchor->IsRecording() || devProg->ctrlFlowCacheAnchor->IsCacheOriginShape()) {
        dynTask->SetLastTask(isLastTask);
        dynTask->SetParallelSameIterLastDevTask(isParallelIterLastTask);
    }
    DEV_INFO("devProg->ctrlFlowCacheAnchor->isActivated : %d", devProg->ctrlFlowCacheAnchor->isActivated);
    PROF_STAGE_END(PERF_EVT_STAGE_BUILD_TASK, "BuildDeviceTaskData.after\n");

    if (!devProg->ctrlFlowCacheAnchor->isActivated && devProg->devArgs.hasAicpuTask) {
        // A failed preparation leaves the task without its wait-until cache; do not push such a task.
        ret = PrepareShmemWaitUntilTasks(dynTask);
        if (unlikely(ret != DEVICE_MACHINE_OK)) {
            return DEVICE_MACHINE_ERROR;
        }
    }

    PROF_STAGE_BEGIN(PERF_EVT_DEALLOCATE_WORKSPACE, "RecycleTensorWorkspace.before\n");
    stitchContext.RecycleTensorWorkspace();
    stitchContext.Reset();
    slotContext.ClearDirty();
    PROF_STAGE_END(PERF_EVT_DEALLOCATE_WORKSPACE, "RecycleTensorWorkspace.after\n");

    ProcessControlFlowCacheRecord(dynTask);

    PROF_STAGE_BEGIN(PERF_EVT_STAGE_PUSH_TASK, "push.before\n");
    DumpDeviceTask(taskId, dynTask);
    PerfMtTrace(PERF_TRACE_DEV_TASK_BUILD, CTRL_CPU_THREAD_IDX);
    PushTask(dynTask);
    PROF_STAGE_END(PERF_EVT_STAGE_PUSH_TASK, "push.after\n");

    // Core limits are per-task; clear them once the task has been handed off.
    currentMaxC_ = 0;
    currentMaxV_ = 0;
    return ret;
}
/**
 * \brief Builds the trace identifier (task id, dup index, root key) for the current root.
 * \param rootKey      Root function key stored in the identifier.
 * \param afterAppend  When true the dup index is stitchContext.Size() - 1, otherwise Size().
 * \return The assembled schema::RUid.
 */
schema::RUid DeviceExecuteContext::GetRuid(uint64_t rootKey, bool afterAppend)
{
    int64_t dupIndex = stitchContext.Size();
    dupIndex -= afterAppend ? 1 : 0;
    return schema::RUid(taskId, dupIndex, rootKey);
}
/**
 * \brief Submits what is currently stitched and then stops control-flow cache recording.
 * \param rootKey  Root key that triggered the stop; only used in the log message.
 * \return DEVICE_MACHINE_OK on success; DEVICE_MACHINE_ERROR when the submit fails (recording
 *         is then left untouched).
 */
int DeviceExecuteContext::ControlFlowCacheStopCache(uint64_t rootKey)
{
    const int submitStatus = SubmitToAicoreAndRecycleMemory(false);
    if (unlikely(submitStatus != DEVICE_MACHINE_OK)) {
        return DEVICE_MACHINE_ERROR;
    }

    devProg->ctrlFlowCacheAnchor->StopRecording();
    DEV_INFO("[Stitch Finish] Stop recording ctrl flow cache. rootKey=%" PRIu64 ".", rootKey);
    return submitStatus;
}
void* DeviceExecuteContext::CallRootFunctionAlloc(uint64_t rootKey)
{
int ret = DEVICE_MACHINE_OK;
DevAscendFunction* devRoot = devProg->GetFunction(rootKey);
DEV_DEBUG("Slloc one func %lu %p %s.", rootKey, devRoot, devRoot->GetRawName());
uint16_t realStitchNumThreshold =
parallelCtx.isInParallelForScope ? MAX_STITCH_FUNC_NUM : stitchTaskLoopNumThreshold;
if ((stitchContext.Size() == realStitchNumThreshold) ||
stitchContext.stitchedCallOpSize() + devRoot->GetOperationSize() > devProg->stitchFunctionsize) {
DEV_INFO(
"[Stitch Finish] Stitch Limit Exceeded. numThreshold=%u rootKey=%lu, func=%s, "
"#task=%zu+1 (limit=%u), #callop=%u+%zu (limit=%u).",
realStitchNumThreshold, rootKey, devRoot->GetRawName(), stitchContext.Size(), stitchTaskLoopNumThreshold,
stitchContext.stitchedCallOpSize(), devRoot->GetOperationSize(), devProg->stitchFunctionsize);
ret = SubmitToAicoreAndRecycleMemory(false);
if (unlikely(ret != DEVICE_MACHINE_OK)) {
return RUNTIME_FUNCKEY_ERROR;
}
}
DEV_TRACE_DEBUG(REvent(GetRuid(rootKey), RActDup(devRoot->GetRawName())));
PROF_STAGE_BEGIN(PERF_EVT_STAGE_DUP_ROOT, "dup.before\n");
currDevRootDup = workspace.DuplicateRoot(devRoot);
PROF_STAGE_END(PERF_EVT_STAGE_DUP_ROOT, "dup.after\n");
return reinterpret_cast<void*>(&currDevRootDup.GetExpression(0));
}
/**
 * \brief Tells whether \p rootkey is one of the keys that force the stitched functions to be
 *        submitted (finish, loop barrier, parallel-for begin/end).
 * \param rootkey  Root key passed by the control flow.
 * \return true for the four keys listed above, false otherwise.
 */
bool DeviceExecuteContext::NeedSubmmitDevTask(uint64_t rootkey)
{
    if (rootkey == RUNTIME_FUNCKEY_FINISH) {
        return true;
    }
    if (rootkey == RUNTIME_FUNCKEY_LOOP_BARRIER) {
        return true;
    }
    if (rootkey == RUNTIME_FUNCKEY_PARALLEL_FOR_END) {
        return true;
    }
    return rootkey == RUNTIME_FUNCKEY_PARALLEL_FOR_BEGIN;
}
// Enters a parallel-for scope: starts it on the parallel context, then switches the workspace
// to the one selected by the wsId that the parallel context now holds.
void DeviceExecuteContext::ParallelForBegin()
{
parallelCtx.Begin();
// Must follow Begin(): reads parallelCtx.info.wsId after the scope has been opened.
workspace.SwitchWParallelWorkSpace(parallelCtx.info.wsId);
}
/**
 * \brief Handles one "stitch" call from the control flow.
 *
 * - RUNTIME_FUNCKEY_CACHESTOP: while recording, submits and stops recording
 *   (RUNTIME_FUNCRET_CACHESTOP_RETURN); otherwise RUNTIME_FUNCRET_CACHESTOP_CONTINUE.
 * - Keys accepted by NeedSubmmitDevTask(): submits the stitched functions, updates the
 *   parallel context for barrier / parallel-for keys and returns nullptr.
 * - Any other key: allocates memory for the current root duplicate (submitting as often as
 *   needed to free memory), stitches it and updates the slots.
 *
 * \param rootKey  Root key passed by the control flow.
 * \return nullptr on success, one of the CACHESTOP return markers, or RUNTIME_FUNCKEY_ERROR
 *         when any submit fails.
 */
void* DeviceExecuteContext::CallRootFunctionStitch(uint64_t rootKey)
{
    DEV_DEBUG("Root stitch %lu.", rootKey);

    if (rootKey == RUNTIME_FUNCKEY_CACHESTOP) {
        if (!devProg->ctrlFlowCacheAnchor->IsRecording()) {
            return RUNTIME_FUNCRET_CACHESTOP_CONTINUE;
        }
        if (unlikely(ControlFlowCacheStopCache(rootKey) != DEVICE_MACHINE_OK)) {
            return RUNTIME_FUNCKEY_ERROR;
        }
        return RUNTIME_FUNCRET_CACHESTOP_RETURN;
    }

    if (NeedSubmmitDevTask(rootKey)) {
        const bool isFinish = (rootKey == RUNTIME_FUNCKEY_FINISH);
        const bool isParallelForEnd = (rootKey == RUNTIME_FUNCKEY_PARALLEL_FOR_END);
        if (unlikely(SubmitToAicoreAndRecycleMemory(false, isFinish, isParallelForEnd) != DEVICE_MACHINE_OK)) {
            return RUNTIME_FUNCKEY_ERROR;
        }
        // Parallel bookkeeping happens only after the pending work has been submitted.
        if (rootKey == RUNTIME_FUNCKEY_LOOP_BARRIER) {
            parallelCtx.ChangeForId();
        } else if (rootKey == RUNTIME_FUNCKEY_PARALLEL_FOR_BEGIN) {
            ParallelForBegin();
        } else if (isParallelForEnd) {
            parallelCtx.End();
        }
        DEV_INFO("[Stitch Finish] Finish Signal or Barrier. rootKey=%" PRIu64 ".", rootKey);
        return nullptr;
    }

    DEV_TRACE_DEBUG(REvent(GetRuid(rootKey), currDevRootDup.SchemaGetExpressionTable()));

    // Keep flushing stitched work until the duplicate's memory can be allocated.
    while (!workspace.TryAllocateFunctionMemory(currDevRootDup, slotContext.GetSlotList())) {
        if (unlikely(SubmitToAicoreAndRecycleMemory(true) != DEVICE_MACHINE_OK)) {
            return RUNTIME_FUNCKEY_ERROR;
        }
        DEV_INFO("[Stitch Finish] Memory Limit Exceeded.");
    }

    if (AiCoreFree()) {
        if (unlikely(SubmitToAicoreAndRecycleMemory(false) != DEVICE_MACHINE_OK)) {
            return RUNTIME_FUNCKEY_ERROR;
        }
        DEV_INFO("[Stitch Finish] AICore Free.");
    }

    DEV_TRACE_DEBUG(DEvent(taskId, DActStitchStart(GetRuid(rootKey))));
    PROF_STAGE_BEGIN(PERF_EVT_STAGE_STITCH, "stitch.before\n");
    size_t stitchIndex = stitchContext.Size();
    stitchContext.Stitch(slotContext, currDevRootDup, taskId, stitchIndex);
    uint32_t updateErrCode = slotContext.UpdateSlots(currDevRootDup, taskId, stitchIndex);
    if (updateErrCode == static_cast<uint32_t>(CtrlErr::CELL_MATCH_FILL_OP_NOT_ENOUGH)) {
        DEV_INFO("UpdateSlots stitch cell failed with error code %u, force submit devtask", updateErrCode);
        if (unlikely(SubmitToAicoreAndRecycleMemory(false) != DEVICE_MACHINE_OK)) {
            return RUNTIME_FUNCKEY_ERROR;
        }
    }
    PROF_STAGE_END(PERF_EVT_STAGE_STITCH, "stitch.after\n");
    DEV_TRACE_DEBUG(DEvent(taskId, DActStitchFinish(GetRuid(rootKey, true))));
    return nullptr;
}
/**
 * \brief Flags the slot at \p slotIndex as an assemble slot that needs allocation.
 * \param slotIndex  Index into the slot list; asserted to lie in [0, slot count).
 */
void DeviceExecuteContext::MarkSlotNeedAlloc(int slotIndex)
{
    DEV_ASSERT_MSG(
        DevCommonErr::PARAM_INVALID, slotIndex >= 0 && slotIndex < static_cast<int>(slotContext.GetSlotSize()),
        "MarkSlotNeedAlloc: Invalid slot index %d.", slotIndex);
    auto&& targetSlot = slotContext.GetSlotList()[slotIndex];
    targetSlot.isAssembleSlotNeedAlloc = true;
}
/**
 * \brief Stores \p dieId on the current root duplicate's dyn-func data.
 *
 * Nothing is written while the root is still served from the control-flow cache.
 *
 * \param dieId  Die id written to loopDieId_ of the duplicate.
 */
void DeviceExecuteContext::SetLoopDieId(int8_t dieId)
{
    if (!DuppedRootCached()) {
        currDevRootDup.DupDataForDynFuncData()->loopDieId_ = dieId;
    }
}
void* DeviceExecuteContext::DeviceExecuteRuntimeCallRootAlloc(void* ctx_, uint64_t rootKey)
{
DeviceExecuteContext* ctx = (DeviceExecuteContext*)ctx_;
if (ctx == nullptr) {
DEV_ERROR(CtrlErr::ROOT_ALLOC_CTX_NULL, "#ctrl.ctrlflow.call.root_alloc: invalid ctx.");
return nullptr;
}
PerfBegin(PERF_EVT_ROOT_FUNC);
void* result = nullptr;
if (ctx->DuppedRootCached()) {
result = nullptr;
} else if (
ctx->devProg->ctrlFlowCacheAnchor->IsRecording() && ctx->devProg->ctrlFlowCacheAnchor->IsRecordingStopped()) {
result = nullptr;
} else {
result = ctx->CallRootFunctionAlloc(rootKey);
if (result == RUNTIME_FUNCKEY_ERROR) {
ctx->SetErrorState(DEVICE_MACHINE_ERROR);
}
}
PerfEnd(PERF_EVT_ROOT_FUNC);
return result;
}
/**
 * \brief Tells whether \p rootKey is a control key rather than a real root function:
 *        finish, cache-stop, loop barrier, or parallel-for begin/end.
 * \param rootKey  Root key to classify.
 * \return true for any of the five control keys, false otherwise.
 */
bool IsSpecialRootKey(uint64_t rootKey)
{
    return rootKey == RUNTIME_FUNCKEY_FINISH || rootKey == RUNTIME_FUNCKEY_CACHESTOP ||
           rootKey == RUNTIME_FUNCKEY_LOOP_BARRIER || rootKey == RUNTIME_FUNCKEY_PARALLEL_FOR_BEGIN ||
           rootKey == RUNTIME_FUNCKEY_PARALLEL_FOR_END;
}
void* DeviceExecuteContext::DeviceExecuteRuntimeCallRootStitch(void* ctx_, uint64_t rootKey)
{
DeviceExecuteContext* ctx = (DeviceExecuteContext*)ctx_;
if (ctx == nullptr) {
DEV_ERROR(CtrlErr::ROOT_STITCH_CTX_NULL, "#ctrl.ctrlflow.call.root_stitch: invalid ctx.");
return nullptr;
}
PerfBegin(PERF_EVT_ROOT_FUNC);
void* result = nullptr;
if (ctx->DuppedRootCached()) {
result = nullptr;
} else if (
ctx->devProg->ctrlFlowCacheAnchor->IsRecording() && ctx->devProg->ctrlFlowCacheAnchor->IsRecordingStopped()) {
result = nullptr;
} else {
result = ctx->CallRootFunctionStitch(rootKey);
if (result == RUNTIME_FUNCKEY_ERROR) {
ctx->SetErrorState(DEVICE_MACHINE_ERROR);
}
}
if (result == nullptr && IsSpecialRootKey(rootKey)) {
return result;
}
PerfEnd(PERF_EVT_ROOT_FUNC);
if (ctx->DuppedRootUpdateAndCachedAllSubmitted()) {
DEV_TRACE_DEBUG(CtrlEvent(none(), ControlFlowCachePartRunControlContinue()));
auto ctrlFlowCacheAnchor = ctx->devProg->ctrlFlowCacheAnchor;
ctrlFlowCacheAnchor->RuntimeAddrRestore(
ctx->slotContext.GetSlotList(), ctx->workspace.GetRuntimeOutcastTensorPool(), ctx->devProg->slotSize,
ctx->devProg->runtimeOutcastPoolSize, ctx->workspace.GetTensorAllocator(), ctx->devProg->GetParallelism());
ctrlFlowCacheAnchor->RuntimeAddrRelocWorkspace(
0, ctx->args->contextWorkspaceAddr, ctx->args, ctx->slotContext.GetSlotList(),
ctx->workspace.GetRuntimeOutcastTensorPoolBase(), ctx->devProg->GetParallelism());
}
return result;
}
/**
 * \brief Runtime callback: writes \p value to the debug log.
 * \param ctx_   Unused.
 * \param value  Value to log.
 * \return Always nullptr.
 */
void* DeviceExecuteContext::DeviceExecuteRuntimeCallLog(void* ctx_, uint64_t value)
{
    static_cast<void>(ctx_);
    DEV_DEBUG("DeviceExecuteRuntimeCallLog -> Value: %lu", value);
    return nullptr;
}
void* DeviceExecuteContext::DeviceExecuteRuntimeCallShmemAllocator(void* ctx_, uint64_t value)
{
uint64_t groupIndex = (reinterpret_cast<uint64_t*>(value))[0];
uint64_t memType = (reinterpret_cast<uint64_t*>(value))[1];
uint64_t size = (reinterpret_cast<uint64_t*>(value))[2];
uint64_t maxTileNum = (reinterpret_cast<uint64_t*>(value))[3];
constexpr uint64_t memTypeCount = 2;
DEV_ASSERT(DevCommonErr::PARAM_CHECK_FAILED, memType < memTypeCount);
DeviceExecuteContext* ctx = (DeviceExecuteContext*)ctx_;
DEV_ASSERT(DevCommonErr::PARAM_CHECK_FAILED, groupIndex < ctx->args->commGroupNum);
auto hcclOpParam = reinterpret_cast<TileOp::CommContext*>(ctx->args->commContexts[groupIndex]);
uint64_t winSize = memType == 0 ? hcclOpParam->winDataSize : hcclOpParam->winStatusSize;
uint64_t shmemAddrEndOffset = ctx->shmemAddrOffset[groupIndex][memType] + size;
if (shmemAddrEndOffset > winSize) {
ctx->shmemAddrOffset[groupIndex][memType] = 0UL;
DEV_ERROR(
DevCommonErr::PARAM_CHECK_FAILED, "#ctrl.unknown: Exceeds winSize limit. Maximum allowed: %lu, got: %lu",
winSize, shmemAddrEndOffset);
}
uint64_t vaddr =
TileOp::Distributed::EncodeShmemAddr(ctx->shmemAddrOffset[groupIndex][memType], maxTileNum, groupIndex, memType);
ctx->shmemAddrOffset[groupIndex][memType] += size;
return reinterpret_cast<void*>(vaddr);
}
void* DeviceExecuteContext::DeviceExecuteRuntimeCallSlotMarkNeedAlloc(void* ctx_, uint64_t slotIndex)
{
DeviceExecuteContext* ctx = (DeviceExecuteContext*)ctx_;
ctx->MarkSlotNeedAlloc(slotIndex);
return nullptr;
}
void* DeviceExecuteContext::DeviceExecuteRuntimeCallGetLoopDieId(void* ctx_, uint64_t rootKey)
{
(void)rootKey;
DeviceExecuteContext* ctx = (DeviceExecuteContext*)ctx_;
return static_cast<void*>(&ctx->loopDieId_);
}
void* DeviceExecuteContext::DeviceExecuteRuntimeCallSetLoopDieId(void* ctx_, uint64_t rootKey)
{
(void)rootKey;
DeviceExecuteContext* ctx = (DeviceExecuteContext*)ctx_;
ctx->SetLoopDieId(ctx->loopDieId_);
DEV_VERBOSE_DEBUG("Set loop die id:%d:", ctx->loopDieId_);
return nullptr;
}
}