* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef __SK_TASK_BUILDER_H__
#define __SK_TASK_BUILDER_H__
#include "sk_node.h"
#include "sk_types.h"
#include "sk_options_manager.h"
#include <vector>
#include <map>
#include <string>
#include <unordered_map>
#include <unordered_set>
class SuperKernelGraph;
/**
 * Queue category a task is assigned to.
 *
 * Enumerator values are fixed explicitly (0..4) so that to_string() below can index a name
 * table with them. MIX_1_1 / MIX_1_2 presumably describe mixed cube/vector ratios —
 * confirm against the builder implementation.
 */
enum class SkQueueType : uint8_t {
    AIC = 0,
    AIV = 1,
    MIX_1_1 = 2,
    MIX_1_2 = 3,
    UNKNOWN = 4,
};

/**
 * Returns the static, NUL-terminated name of @p type.
 *
 * Any value outside the declared enumerators (possible via static_cast from uint8_t)
 * yields "UNKNOWN", the same string used for SkQueueType::UNKNOWN.
 */
inline const char* to_string(SkQueueType type)
{
    // Order must follow the enumerator values declared above.
    static constexpr const char* kQueueTypeNames[] = {
        "AIC",
        "AIV",
        "MIX_1_1",
        "MIX_1_2",
        "UNKNOWN",
    };
    constexpr size_t kNameCount = sizeof(kQueueTypeNames) / sizeof(kQueueTypeNames[0]);
    const size_t slot = static_cast<size_t>(type);
    return (slot < kNameCount) ? kQueueTypeNames[slot] : "UNKNOWN";
}
/**
 * Direction of a synchronisation between two tasks.
 *
 * NONE is 0 and the remaining enumerators follow consecutively (0..6); to_string() below
 * relies on that ordering. CUB / VEC presumably stand for cube / vector units — confirm.
 */
enum class SyncDirection : uint8_t {
    NONE = 0,
    CUB_TO_CUB = 1,
    VEC_TO_VEC = 2,
    CUB_TO_VEC = 3,
    VEC_TO_CUB = 4,
    MIX_TO_MIX = 5,
    ALL_SYNC = 6,
};

/**
 * Returns the static, NUL-terminated name of @p dir, or "UNKNOWN" for any value that is
 * not one of the declared enumerators.
 */
inline const char* to_string(SyncDirection dir)
{
    // Indexed by the enumerator value; keep in declaration order.
    static constexpr const char* kDirectionNames[] = {
        "NONE",
        "CUB_TO_CUB",
        "VEC_TO_VEC",
        "CUB_TO_VEC",
        "VEC_TO_CUB",
        "MIX_TO_MIX",
        "ALL_SYNC",
    };
    const size_t index = static_cast<size_t>(dir);
    if (index >= sizeof(kDirectionNames) / sizeof(kDirectionNames[0])) {
        return "UNKNOWN";
    }
    return kDirectionNames[index];
}
/**
 * Early-start state recorded for one task.
 *
 * relatedNode, nextAicRelatedNode and nextAivRelatedNode are raw pointers that default to
 * nullptr; this struct never deletes them. Which nodes they point at is decided by the code
 * that fills them in.
 *
 * funcEarlyStartConfig and syncEarlyStartConfig are two independent bit sets, both starting
 * at 0. Apply*Mask ORs the bits of an SkEarlyStartMask into the matching set. Check*Mask
 * reports whether any bit of the given mask is already set there.
 */
struct EarlyStartInfo {
    SuperKernelBaseNode* relatedNode = nullptr;
    uint32_t funcEarlyStartConfig = 0U;
    SuperKernelBaseNode* nextAicRelatedNode = nullptr;
    SuperKernelBaseNode* nextAivRelatedNode = nullptr;
    uint32_t syncEarlyStartConfig = 0U;

    /// ORs the bits of @p mask into the function-task config.
    void ApplyFuncMask(SkEarlyStartMask mask)
    {
        SetMaskBits(funcEarlyStartConfig, mask);
    }

    /// Returns true if any bit of @p mask is set in the function-task config.
    bool CheckFuncMask(SkEarlyStartMask mask) const
    {
        return HasAnyMaskBit(funcEarlyStartConfig, mask);
    }

    /// ORs the bits of @p mask into the sync-task config.
    void ApplySyncMask(SkEarlyStartMask mask)
    {
        SetMaskBits(syncEarlyStartConfig, mask);
    }

    /// Returns true if any bit of @p mask is set in the sync-task config.
    bool CheckSyncMask(SkEarlyStartMask mask) const
    {
        return HasAnyMaskBit(syncEarlyStartConfig, mask);
    }

private:
    // Shared bit helpers; static, so they add no per-object state and leave the layout unchanged.
    static void SetMaskBits(uint32_t& config, SkEarlyStartMask mask)
    {
        const uint32_t bits = static_cast<uint32_t>(mask);
        config = config | bits;
    }

    static bool HasAnyMaskBit(uint32_t config, SkEarlyStartMask mask)
    {
        const uint32_t bits = static_cast<uint32_t>(mask);
        return (config & bits) != 0U;
    }
};
/**
 * Per-task synchronisation plan.
 *
 * queueType starts as SkQueueType::UNKNOWN. Each std::map associates a size_t key
 * (presumably the index of the peer task in the builder's task list — confirm in the
 * implementation) with the SyncDirection of that relation. All maps start empty, and
 * earlyStartInfo is value-initialised.
 */
struct TaskSyncInfo {
    SkQueueType queueType = SkQueueType::UNKNOWN;
    std::map<size_t, SyncDirection> cubSendInfo;
    std::map<size_t, SyncDirection> cubRecvInfo;
    std::map<size_t, SyncDirection> vecSendInfo;
    std::map<size_t, SyncDirection> vecRecvInfo;
    std::map<size_t, SyncDirection> crossSyncInfo;
    EarlyStartInfo earlyStartInfo{};

    // Every member gets its starting value from the initializers above.
    TaskSyncInfo() = default;
};
// Value returned by SkTaskBuilder::Build(): the launch description plus a JSON document
// (presumably describing the generated task queues — confirm against Build()).
// NOTE(review): neither member has a default member initializer. If SkLaunchInfo is a plain
// aggregate, a default-constructed SkBuildResult leaves its fields indeterminate — confirm.
struct SkBuildResult {
SkLaunchInfo launchInfo;
Json taskQueueJson;
};
/**
 * Builds an SkBuildResult (launch info + task-queue JSON) for a list of super-kernel tasks.
 *
 * Lifetime: the builder stores references to the SuperKernelOptionsManager and the
 * SuperKernelGraph passed to its constructor. Both objects must outlive the builder.
 * Those reference members also make the implicit copy-assignment operator deleted.
 *
 * Build() and the private helpers are non-const members, so they may mutate the member
 * state declared below. NOTE(review): confirm whether Build() resets that state between
 * calls, and do not share one instance across threads without external locking.
 */
class SkTaskBuilder {
public:
// The constructor parameter `opts` shadows the member of the same name. In the mem-initializer
// `opts(opts)`, the outer name is the member and the inner one the parameter, so the member
// is bound to the caller's object as intended.
SkTaskBuilder(SuperKernelOptionsManager& opts, const SuperKernelGraph& graph) :
opts(opts), graph_(graph)
{}
// Builds the launch description for `tasks` (and `customTasks`) under the function name
// `skFuncName`. NOTE(review): the exact role of customTasks and scopeId is defined in the
// implementation file and is not visible here.
SkBuildResult Build(std::string skFuncName, const std::vector<SuperKernelBaseNode*>& tasks,
const std::vector<SuperKernelBaseNode*>& customTasks, uint16_t scopeId);
private:
// Non-owning references supplied at construction.
// NOTE(review): `opts` lacks the trailing underscore used by every other data member.
SuperKernelOptionsManager& opts;
const SuperKernelGraph& graph_;
// One sync plan per task; presumably indexed by the task's position in Build()'s `tasks`.
std::vector<TaskSyncInfo> taskSyncInfos_;
// Two-way mapping between node ids (uint64_t) and task indices (size_t).
std::unordered_map<uint64_t, size_t> nodeIdToIndex_;
std::unordered_map<size_t, uint64_t> indexToNodeId_;
// Availability flags for the AIC / AIV sides; both start as false.
bool aicAvailable_ = false;
bool aivAvailable_ = false;
// Returns a pair of prefetch counts. NOTE(review): which element is which is not visible here.
std::pair<int, int> GetPreFetchCnt(const ResolvedFunctionInfo& resolved);
// ---- Task emission helpers: append entries to an SkTask ----
// NOTE(review): earlyStartConfig is uint8_t, while EarlyStartInfo stores its configs as
// uint32_t. Confirm that callers never pass bits above 0xFF (they would be truncated).
bool AddSyncTask(SkTask& skTask, size_t nodeIndex, SkCoreSyncType syncType,
uint8_t earlyStartConfig = 0U, uint32_t skipCoreCount = 0U,
SkKernelType relatedType = SkKernelType::DEFAULT);
bool AddEventTask(SkTask& skTask, SuperKernelBaseNode* node, size_t nodeIndex, SkTaskType taskType);
bool AddFuncTask(SkTask& skTask, SuperKernelBaseNode* node, SkDfxInfo* dfxInfo, size_t nodeIndex, int addrIndex,
int binCount, SkTaskType taskType, uint32_t numBlocks);
// ---- Dispatch helpers ----
// These take both the cube and vector task lists plus a queueType. They presumably route
// work to one or both lists based on queueType.
bool DispatchFuncTask(SkTask& skTaskCube, SkTask& skTaskVec, SuperKernelBaseNode* node, SkDfxInfo* dfxInfo,
size_t nodeIndex, int binCount, SkTaskType taskType, SkQueueType queueType);
bool DispatchEventTask(SkTask& skTaskCube, SkTask& skTaskVec, SuperKernelBaseNode* node, size_t nodeIndex,
SkTaskType taskType, SkQueueType queueType);
// Overload driven by a per-direction sync map (e.g. one of TaskSyncInfo's maps).
bool DispatchSyncTasks(SkTask& skTaskCube, SkTask& skTaskVec, size_t nodeIndex,
const std::map<size_t, SyncDirection>& syncInfo, bool isSend, SkQueueType queueType);
// Overload driven by the task's EarlyStartInfo.
bool DispatchSyncTasks(SkTask& skTaskCube, SkTask& skTaskVec, size_t nodeIndex, const EarlyStartInfo& earlyStartInfo,
bool isSend, SkQueueType queueType);
// ---- Sync-relation computation ----
// Judging by their names, these build the per-task sync plans (taskSyncInfos_) from the graph
// and/or from mix-kernel groups. See the implementation for the exact pass order.
bool InitTaskSyncInfos(const std::vector<SuperKernelBaseNode*>& tasks);
bool PrecomputeSyncRelationsFromGraph(const std::vector<SuperKernelBaseNode*>& tasks);
bool PrecomputeSyncRelationsByMixGroups(const std::vector<SuperKernelBaseNode*>& tasks);
bool SplitTasksByMixGroups(const std::vector<SuperKernelBaseNode*>& tasks,
std::vector<std::vector<SuperKernelBaseNode*>>& splitTasks,
bool& hasMixKernel) const;
bool InitSyncInfoSnapshotForMixGroups(const std::vector<SuperKernelBaseNode*>& tasks,
std::vector<TaskSyncInfo>& taskSyncInfosOrigin);
bool ProcessSyncRelationSplitGroup(const std::vector<SuperKernelBaseNode*>& curSplitTasks,
size_t groupIndex,
size_t groupOffset,
bool hasNextGroup,
const std::vector<TaskSyncInfo>& taskSyncInfosOrigin,
std::vector<TaskSyncInfo>& mergedTaskSyncInfos);
// Presumably shifts the size_t keys of syncInfo's maps by `offset` — confirm in the .cpp.
bool RebaseTaskSyncInfo(TaskSyncInfo& syncInfo, size_t offset) const;
void AddBoundaryAllSync(const std::vector<SuperKernelBaseNode*>& curSplitTasks,
size_t groupIndex,
size_t groupOffset);
bool IsMixKernelTask(const SuperKernelBaseNode* task) const;
void ExtractIntraStreamSync(const std::vector<SuperKernelBaseNode*>& tasks);
bool ExtractInterStreamSync(const std::vector<SuperKernelBaseNode*>& tasks);
void InsertSyncEvent(size_t preIdx, size_t currIdx);
// ---- Sync-relation optimisation passes (prune redundant sync edges, by name) ----
void OptimizeSyncRelations(const std::vector<SuperKernelBaseNode*>& tasks);
void RemoveCrossedLineSync();
void RemoveMultiSendSync();
void RemoveMultiRecvSync();
void RemoveRedundantCrossSync(const std::vector<SuperKernelBaseNode*>& tasks);
bool ApplyEarlyStartSyncPass(const std::vector<SuperKernelBaseNode*>& tasks);
bool ApplyPerOpMaxCoreNum(const std::vector<SuperKernelBaseNode*>& tasks, SkTask& aicTask, SkTask& aivTask);
bool JudgeRemoveCrossSync(size_t sendIdx, size_t recvIdx, bool isCubToVec);
void RemoveSyncInfo(size_t sendIdx, size_t recvIdx, bool isRemoveRecv, SyncDirection dirToRemove);
// Debug dump of the current sync plans, tagged with `stage`.
void PrintSyncInfo(const char* stage);
// ---- Entry / argument generation ----
SkHostEntryInfo GenEntryInfo(SkTask& skTaskCube, SkTask& skTaskVec);
// dfxInfos presumably points to `dfxCount` elements; eventConfig is optional (defaults to nullptr).
DeviceArgsPtr GenEntryArgs(const SkTask& skTaskCube, const SkTask& skTaskVec, const SkDfxInfo* dfxInfos,
uint32_t dfxCount, const SkEventConfig *eventConfig = nullptr);
// ---- DFX (diagnostics) helpers ----
bool UpdateDfxInfo(SkDfxInfo* dfxInfo, const KernelInfos& kernelInfo, const ResolvedFunctionInfo& resolved,
int binIndex, int addrIndex);
bool ProcessCoreFuncSize(SkDfxInfo* dfxInfo, aclrtBinHandle binHdl, const void* binHostAddr, uint32_t binHostSize,
const ResolvedFunctionInfo& resolved, int coreIndex, int binIndex,
const char* coreName);
};
#endif