/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
/*!
 * \file sk_node.h
 * \brief Declarations of the SuperKernel graph node classes and their kernel/sync info records.
 */
#ifndef __SK_NODE_H__
#define __SK_NODE_H__
#include <array>
#include <bitset>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include <nlohmann/json.hpp>

#include "sk_log.h"
#include "sk_types.h"
#include "acl/acl.h"
// Forward declarations. This header only names these types through pointers, references
// or (for the enums) by-value parameters, so their full definitions are not required here.
class SuperKernelGraph;
class SuperKernelOptionsManager;
struct SkLaunchInfo;
// Opaque enum declarations with a fixed uint8_t underlying type; the enumerators are defined elsewhere.
enum class ScopeFailReason : uint8_t;
enum class DeadlockFailReason : uint8_t;
/**
 * @brief Value type stored in SkBindMap: one capability word and four function slots.
 *
 * Every field is zero when default-constructed.
 */
struct SkBindInfo {
    uint64_t cap{0};
    std::array<uint64_t, 4> sknlFuncs{{0, 0, 0, 0}};
};
// Lookup table from a 64-bit key to its SkBindInfo.
// NOTE(review): the key is presumably a kernel function address — confirm against GetSkBindMap().
using SkBindMap = std::unordered_map<uint64_t, SkBindInfo>;
/**
 * @brief Primary verdict on whether a node can be fused.
 *
 * CAN_FUSE is the only non-failure value. The numeric values are spelled out so that
 * they stay stable if enumerators are ever reordered; new values must be appended.
 */
enum class FusionFailReason {
    CAN_FUSE = 0,
    BINDMAP_IS_EMPTY = 1,
    TASK_GROUP_NOT_EMPTY = 2,
    NOT_IN_SCOPE = 3,
    IN_UNFUSIBLE_SCOPE = 4,
    EXCEED_DEVICE_MAX = 5,
    RESET_TYPE_NODE = 6,
    ISOLATED_EVENT = 7,
    EXIST_DEADLOCK = 8,
    SCOPE_FUSE_PART = 9,
    EXTERNAL_DEPEND = 10,
    UNSUPPORT_EVENT_TYPE = 11,
    NOTIFY_NO_WAIT_NODE = 12,
    MEMORY_WAIT_NODE_ONLY = 13,
    MEMORY_WRITE_NODE_ONLY = 14,
    DEFAULT_NODE = 15,
    SIMT_OP_NOT_SUPPORTED = 16,
    KERNEL_ATTR_GET_FAILED = 17,
    EXCEED_SCOPE_MAX = 18,
};
/**
 * @brief Detail code describing why a bind map could not be produced.
 *
 * Stored as one byte. NONE means no bind-map failure was recorded.
 */
enum class BindmapFailReason : uint8_t {
    NONE = 0,
    BINDMAP_INIT_EMPTY = 1,
    BINHDL_NULL = 2,
    FUNCHDL_NULL = 3,
    FUNC_NOT_FOUND = 4,
    BIN_DEV_ADDR_GET_FAILED = 5,
    FUNC_ADDR_GET_FAILED = 6,
    BINDMAP_ENTRY_CONFLICT = 7,
    BINDMAP_CAP_INCONSISTENT = 8,
    BIN_HOST_ADDR_GET_FAILED = 9,
};
struct FusionFailReasonInfo {
FusionFailReason primary = FusionFailReason::CAN_FUSE;
uint8_t scopeDetailValue = 0;
uint8_t deadlockDetailValue = 0;
uint8_t bindmapDetailValue = 0;
FusionFailReasonInfo() = default;
explicit FusionFailReasonInfo(FusionFailReason reason) : primary(reason) {}
FusionFailReasonInfo(FusionFailReason reason, ScopeFailReason scopeReason);
FusionFailReasonInfo(FusionFailReason reason, DeadlockFailReason deadlockReason);
FusionFailReasonInfo(FusionFailReason reason, BindmapFailReason bindmapReason);
ScopeFailReason GetScopeDetail() const;
void SetScopeDetail(ScopeFailReason scopeReason);
DeadlockFailReason GetDeadlockDetail() const;
void SetDeadlockDetail(DeadlockFailReason deadlockReason);
BindmapFailReason GetBindmapDetail() const;
void SetBindmapDetail(BindmapFailReason bindmapReason);
bool operator==(FusionFailReason reason) const { return primary == reason; }
bool operator!=(FusionFailReason reason) const { return primary != reason; }
};
// Returns a C-string description for a BindmapFailReason (defined out of line).
const char* BindmapFailReasonToStr(BindmapFailReason reason);
// NOTE(review): implementation is not visible here; judging by the name it rounds `value` up to an
// alignment and clamps it, with `coreIdx` selecting the limit — confirm against the definition.
size_t AlignUpAndClamp(size_t value, size_t coreIdx);
/**
 * @brief Map a primary fusion verdict to a human-readable message.
 * @param reason Verdict to describe.
 * @return Pointer to a string literal (static storage, never null); "UNKNOWN_REASON" for a value
 *         that matches no enumerator.
 *
 * The texts for NOT_IN_SCOPE and IN_UNFUSIBLE_SCOPE were previously attached to each other's
 * enumerator; each message now matches the name of the enumerator it is returned for.
 */
inline const char* FusionFailReasonToStr(FusionFailReason reason) {
    switch (reason) {
        case FusionFailReason::CAN_FUSE:
            return "node can fuse";
        case FusionFailReason::BINDMAP_IS_EMPTY:
            return "The operator does not support the operation of fusing SuperKernel";
        case FusionFailReason::TASK_GROUP_NOT_EMPTY:
            return "The operator will refresh task information at runtime, but SK does not support fusing dynamically changing tasks";
        case FusionFailReason::NOT_IN_SCOPE:
            // The node lies outside every user-marked fusion range.
            return "This operator is not within the fusion range marked by the user";
        case FusionFailReason::IN_UNFUSIBLE_SCOPE:
            // The node lies inside a range the user marked as not to be fused.
            return "The user actively marked that this operator is not fused";
        case FusionFailReason::EXCEED_DEVICE_MAX:
            return "The number of kernels required by the operator exceeds the maximum number of kernels that the device can provide";
        case FusionFailReason::RESET_TYPE_NODE:
            return "reset type node in end";
        case FusionFailReason::ISOLATED_EVENT:
            return "There is no kernel node on the stream where the current node is located, and this stream is within the scope";
        case FusionFailReason::EXIST_DEADLOCK:
            return "exist deadlock";
        case FusionFailReason::SCOPE_FUSE_PART:
            return "scope fuse failed";
        case FusionFailReason::EXTERNAL_DEPEND:
            return "event node has external dependency";
        case FusionFailReason::UNSUPPORT_EVENT_TYPE:
            return "unsupport event type";
        case FusionFailReason::NOTIFY_NO_WAIT_NODE:
            return "notify node has no wait node in modelRI, mark as unfusible";
        case FusionFailReason::MEMORY_WAIT_NODE_ONLY:
            return "No memory write exists, meaning the memory write is outside modelRI. Therefore change all waits to event semantics, but they cannot be fused.";
        case FusionFailReason::MEMORY_WRITE_NODE_ONLY:
            return "only exists memory write nodes, mask it as unfusible";
        case FusionFailReason::DEFAULT_NODE:
            return "default node uses aicpu resources, mask it as unfusible";
        case FusionFailReason::SIMT_OP_NOT_SUPPORTED:
            return "SIMT operator is not supported for SuperKernel fusion";
        case FusionFailReason::KERNEL_ATTR_GET_FAILED:
            return "Failed to get kernel attribute for SuperKernel fusion";
        case FusionFailReason::EXCEED_SCOPE_MAX:
            return "Exceeded maximum scope number limit for SuperKernel fusion";
        default:
            // Reached only for a value cast from an integer outside the enumerator list.
            return "UNKNOWN_REASON";
    }
}
// Returns a C-string name for a kernel type given its integer code and a two-element task ratio
// (the array parameter decays to `const uint32_t*`; callers must pass at least two elements).
const char* GetKernelTypeString(uint32_t kernelType, const uint32_t taskRatio[2]);
// Overload of FusionFailReasonToStr taking the full info record; returns an owning std::string.
// NOTE(review): presumably appends the detail reason to the primary message — confirm in the definition.
std::string FusionFailReasonToStr(const FusionFailReasonInfo& info);
/**
 * @brief Optional inputs handed to SuperKernelBaseNode::Update().
 *
 * Both pointers are non-owning and default to null, so a default-constructed
 * context carries no extra information.
 */
struct UpdateContext {
    SkLaunchInfo* launchInfo{nullptr};
    aclmdlRITaskParams* customParams{nullptr};
};
// Plain record holding a capability word, one "global" function pointer and four "sknl" function pointers.
// The members have no default initializers, so a default-initialized object holds indeterminate values.
// NOTE(review): the layout mirrors SkBindInfo (cap + 4 functions); presumably this is the raw form the
// bind map is built from — confirm before changing field order or adding initializers.
struct SknlMapInfo {
    uint64_t cap;
    void* globalFunc;
    void* sknlFunc[4];
};
/**
 * @brief Resolution result for a function, stored as two parallel slots per field.
 *
 * Numeric slots start at zero and the symbol-binding strings start empty.
 */
struct ResolvedFunctionInfo {
    uint64_t funcAddr[2]{};
    uint64_t prefetchCnt[2]{};
    uint64_t funcOffset[2]{};
    std::string symbolBind[2]{};
};
// Upper bound on the number of split binaries.
// NOTE(review): several arrays in this header use a literal 4 (sknlFuncs, sknlFunc, resolvedFuncs);
// presumably they are sized by this same limit — confirm before replacing the literals.
constexpr size_t K_MAX_SPLIT_BIN_COUNT = 4;
/**
 * @brief Bit positions inside a kernel capability word.
 *
 * Each enumerator is a bit index (not a mask); the values are consecutive starting at 0.
 */
enum class KernelCapBitOffset : uint8_t {
    EARLY_START_WAIT_FLAG = 0,  // bit 0
    EARLY_START_SET_FLAG = 1,   // bit 1
    DCCI = 2,                   // bit 2
    DISABLE_SCHEMODE = 3,       // bit 3
};
/**
 * @brief Decoded boolean view of a kernel capability word.
 *
 * All flags are false until set, which matches an all-zero capability word.
 */
struct KernelCapBits {
    bool earlyStartWaitFlag{false};
    bool earlyStartSetFlag{false};
    bool disableDcci{false};
    bool disableScheMode{false};
};
// Decodes a 64-bit capability word into a KernelCapBits record (defined out of line).
// NOTE(review): presumably tests the bit positions listed in KernelCapBitOffset — confirm in the definition.
KernelCapBits ParseKernelCapBits(uint64_t cap);
/**
 * @brief Everything recorded about a kernel node: its type, launch geometry, handles and
 *        resolved function data.
 *
 * Every scalar field has a neutral default (zero, null, false); funcName starts as "Unknown".
 * Pointer and handle members are stored as-is; this struct declares no destructor and frees nothing.
 */
struct KernelInfos {
    // Kernel classification.
    SkKernelType kernelType{SkKernelType::DEFAULT};
    uint32_t kernelTypeInt{0};
    uint32_t taskRatio[2]{0, 0};
    uint32_t resolvedNum{0};

    // Capability word and its decoded form.
    uint64_t cap{0};
    KernelCapBits capBits{};

    // Launch geometry.
    uint32_t numBlocks{0};
    uint32_t vecNum{0};
    uint32_t cubeNum{0};

    // Arguments and operator info blob.
    const void *devArgs{nullptr};
    void* opInfoPtr{nullptr};
    size_t opInfoSize{0};

    // Identification and runtime handles.
    std::string funcName{"Unknown"};
    aclrtBinHandle binHdl{nullptr};
    aclrtFuncHandle funcHdl{nullptr};
    aclrtLaunchKernelCfg* launchKernelCfg{nullptr};

    // Behaviour flags.
    bool isScheModeOn{false};
    bool needMixKernelSplit{false};
    bool isSimtOp{false};

    ResolvedFunctionInfo resolvedFuncs[4];
    BindmapFailReason bindmapFailReason{BindmapFailReason::NONE};

    // Renders this record as text (defined out of line).
    std::string Format() const;
};
/**
 * @brief Synchronisation data recorded for a node: its event id and the ids of related nodes.
 *
 * Ids default to INVALID_TASK_ID; the memory/event values default to the maximum
 * representable value of their type, and the id lists start empty.
 */
struct SyncInfos {
    // Kept as copy-initialisation: INVALID_TASK_ID is defined elsewhere and its exact type is not visible here.
    uint64_t eventId = INVALID_TASK_ID;
    void* addrValue{nullptr};
    uint64_t correspondingNotifyNodeId = INVALID_TASK_ID;

    std::vector<uint64_t> correspondingWaitNodeIds;
    std::vector<uint64_t> correspondingResetNodeIds;

    uint64_t memoryValue{std::numeric_limits<uint64_t>::max()};
    uint32_t memoryWaitFlag{std::numeric_limits<uint32_t>::max()};
    uint64_t eventFlag{std::numeric_limits<uint64_t>::max()};
};
// Aggregate carried by every node: kernel-related data and synchronisation-related data.
// Both parts are always present; which one is meaningful depends on the node class that fills it.
struct NodeInfos {
    KernelInfos kernelInfos;
    SyncInfos syncInfos;
};
class SuperKernelBaseNode {
public:
SuperKernelBaseNode(std::unique_ptr<aclmdlRITask> inputOriginTask, aclmdlRITaskType inputRtNodeType,
uint64_t inputNodeIdxInStream, uint64_t inputStreamIdxInGraph, int32_t inputStreamId, uint64_t inputPreNodeId)
: originTask(std::move(inputOriginTask)),
taskParams({}),
rtNodeType(inputRtNodeType),
notifyExpandVecNum(0),
notifyExpandCubeNum(0),
streamIdxInGraph(inputStreamIdxInGraph),
streamId(inputStreamId),
nodeIdxInStream(inputNodeIdxInStream),
nodeId(INVALID_TASK_ID),
preNodeId(inputPreNodeId),
nextNodeId(INVALID_TASK_ID),
nodeType(SkNodeType::NODE_DEFAULT),
isVisited(false),
isFusible(false),
isScopeNode(false),
isUpdate(false) { }
virtual ~SuperKernelBaseNode() = default;
virtual bool InitNode(const SuperKernelOptionsManager* opts = nullptr);
* @brief Format complete node information for logging
* @return Formatted string with nodeId, streamIdxInGraph, nodeIdxInStream, and node-specific info
*
* Format: [nodeId:lu, streamIdxInGraph:u, nodeIdxInStream:lu] - {node-specific-info}
* Examples:
* Kernel: [nodeId:123, streamIdxInGraph:0, nodeIdxInStream:5] - Kernel:func_name
* Notify: [nodeId:124, streamIdxInGraph:0, nodeIdxInStream:6] - Event:Notify(eventId:0x7ff8a0)
* Wait: [nodeId:125, streamIdxInGraph:0, nodeIdxInStream:7] - Event:Wait(eventId:0x7ff8a0)
* Default: [nodeId:126, streamIdxInGraph:0, nodeIdxInStream:8] - Default
*/
virtual std::string Format() const = 0;
uint32_t GetStreamIdxInGraph() const
{
return streamIdxInGraph;
}
int32_t GetStreamId() const
{
return streamId;
}
uint64_t GetNodeIdxInStream() const
{
return nodeIdxInStream;
}
uint64_t GetNodeId() const
{
return nodeId;
}
bool IsFusible() const
{
return isFusible;
}
void SetIsFusible(bool fusible)
{
isFusible = fusible;
}
void SetNodeId(uint64_t inputNodeId)
{
nodeId = inputNodeId;
}
void SetPreNodeId(uint64_t inputPreNodeId)
{
preNodeId = inputPreNodeId;
}
void SetNextNodeId(uint64_t inputNextNodeId)
{
nextNodeId = inputNextNodeId;
}
uint64_t GetPreNodeId() const
{
return preNodeId;
}
uint64_t GetNextNodeId() const
{
return nextNodeId;
}
virtual uint32_t GetNumBlocks() const { return 0; }
virtual SkKernelType GetKernelType() const { return SkKernelType::DEFAULT; }
virtual uint32_t GetVecNum() const { return 0; }
virtual uint32_t GetCubeNum() const { return 0; }
virtual bool IsScheModeOn() const { return false; }
virtual bool GetScheMode() const { return false; }
virtual uint64_t GetEventId() const
{
return INVALID_TASK_ID;
}
virtual std::vector<uint64_t> GetCorrespondingWaitNodeIds() const
{
return std::vector<uint64_t>();
}
virtual uint64_t GetCorrespondingNotifyNodeId() const
{
return INVALID_TASK_ID;
}
virtual void SetCorrespondingWaitNodeIds(const std::vector<uint64_t>& waitIds) {}
virtual void SetCorrespondingNotifyNodeId(uint64_t notifyId) {}
virtual const NodeInfos& GetNodeInfos() const
{
return nodeInfos;
}
virtual bool Update(const UpdateContext& ctx = {});
virtual aclError InValidateNode();
SkNodeType GetNodeType() const
{
return nodeType;
}
void SetNodeType(SkNodeType inputNodeType)
{
nodeType = inputNodeType;
}
bool IsVisited() const
{
return isVisited;
}
void SetVisited(bool inputIsVisited)
{
isVisited = inputIsVisited;
}
virtual const std::string GetScopeName() const { return ""; }
virtual bool IsScopeBegin() const { return false; }
virtual bool IsScopeEnd() const { return false; }
virtual bool IsScopePlaceholder() const { return false; }
const std::bitset<MAX_SCOPE_NUM>& GetScopeBitFlags() const
{
return scopeBitFlags;
}
void SetScopeBitFlags(const std::bitset<MAX_SCOPE_NUM>& flags)
{
scopeBitFlags = flags;
}
void SetIsScopeNode(bool isScope) { isScopeNode = isScope; }
bool IsScopeNode() const { return isScopeNode; }
void ClearScopeBitFlags() { scopeBitFlags.reset(); }
void MarkEventNodeToScope(SuperKernelBaseNode* node);
void SetNotifyExpandVecNum(uint32_t vecNum) { notifyExpandVecNum = vecNum; }
void SetNotifyExpandCubeNum(uint32_t cubeNum) { notifyExpandCubeNum = cubeNum; }
void SetScopeStreamIds(const std::unordered_set<uint32_t>& streamIds) { scopeStreamIds = streamIds; }
const std::unordered_set<uint32_t>& GetScopeStreamIds() const { return scopeStreamIds; }
bool IsUpdated() const { return isUpdate; }
void SetUpdate(bool update) { isUpdate = update; }
void SetFusionFailReason(FusionFailReason reason, ScopeFailReason scopeReason = static_cast<ScopeFailReason>(0)) {
fusionFailReason_.primary = reason;
fusionFailReason_.SetScopeDetail(scopeReason);
}
void SetFusionFailReason(FusionFailReason reason, DeadlockFailReason deadlockReason) {
fusionFailReason_.primary = reason;
fusionFailReason_.SetDeadlockDetail(deadlockReason);
}
void SetFusionFailReason(FusionFailReason reason, BindmapFailReason bindmapReason) {
fusionFailReason_.primary = reason;
fusionFailReason_.SetBindmapDetail(bindmapReason);
}
void SetFusionFailReason(const FusionFailReasonInfo& info) { fusionFailReason_ = info; }
FusionFailReason GetFusionFailReason() const { return fusionFailReason_.primary; }
const FusionFailReasonInfo& GetFusionFailReasonInfo() const { return fusionFailReason_; }
const aclmdlRITaskParams& GetTaskParams() const { return taskParams; }
bool IsInvalidated() const { return isInvalidated; }
void SetInvalidated(bool invalidated) { isInvalidated = invalidated; }
SkBindMap InitSuperKernelBindMap(aclrtBinHandle binHdl);
public:
NodeInfos nodeInfos;
std::unique_ptr<aclmdlRITask> originTask;
std::unordered_set<uint64_t> sendToNodeId;
std::unordered_set<uint64_t> receiveNodeId;
FusionFailReasonInfo fusionFailReason_;
protected:
aclmdlRITaskParams taskParams;
void LogNodeUpdateResult(const aclmdlRITaskParams* taskParams) const;
const char* GetUpdateTargetTypeName(aclmdlRITaskType type) const;
uint32_t notifyExpandVecNum;
uint32_t notifyExpandCubeNum;
uint32_t streamIdxInGraph;
int32_t streamId;
uint64_t nodeIdxInStream;
uint64_t nodeId;
uint64_t preNodeId;
uint64_t nextNodeId;
SkNodeType nodeType;
aclmdlRITaskType rtNodeType;
bool isVisited;
bool isFusible;
bool isScopeNode;
bool isUpdate;
bool isInvalidated = false;
std::unordered_set<uint32_t> scopeStreamIds;
std::bitset<MAX_SCOPE_NUM> scopeBitFlags;
};
/**
 * @brief Node representing a kernel launch.
 *
 * Reports launch geometry and kernel type from nodeInfos.kernelInfos and carries the
 * scope markers (begin / end / placeholder and scope name) of the node.
 */
class SuperKernelKernelNode : public SuperKernelBaseNode {
public:
    using SuperKernelBaseNode::SuperKernelBaseNode;

    bool InitNode(const SuperKernelOptionsManager* opts = nullptr) override;

    // Kernel data is read straight from the shared kernelInfos record.
    uint32_t GetNumBlocks() const override { return nodeInfos.kernelInfos.numBlocks; }
    SkKernelType GetKernelType() const override { return nodeInfos.kernelInfos.kernelType; }
    uint32_t GetVecNum() const override { return nodeInfos.kernelInfos.vecNum; }
    uint32_t GetCubeNum() const override { return nodeInfos.kernelInfos.cubeNum; }
    bool GetScheMode() const override;

    std::string Format() const override;

    // The default argument is repeated from the base declaration (as InitNode already does) so that
    // Update() can be called without arguments through this type as well as through the base.
    bool Update(const UpdateContext& ctx = {}) override;

    const std::string GetScopeName() const override
    {
        return scopeName;
    }
    bool IsScopeBegin() const override { return isScopeBegin; }
    bool IsScopeEnd() const override { return isScopeEnd; }
    bool IsScopePlaceholder() const override { return isPlaceholder; }
    bool IsScheModeOn() const override { return nodeInfos.kernelInfos.isScheModeOn; }

private:
    void IdentifyAndHandleSimtKernel(const SuperKernelOptionsManager* opts);

    bool isScopeBegin = false;
    bool isScopeEnd = false;
    bool isPlaceholder = false;
    std::string scopeName;
};
/**
 * @brief Node whose event id and notify/wait relations are stored in nodeInfos.syncInfos.
 *
 * Its vector/cube counts are the "notify expand" values set on the base class.
 */
class SuperKernelMemoryNode : public SuperKernelBaseNode {
public:
    using SuperKernelBaseNode::SuperKernelBaseNode;

    uint64_t GetEventId() const override
    {
        return nodeInfos.syncInfos.eventId;
    }

    // Returns a copy of the stored wait-node id list.
    std::vector<uint64_t> GetCorrespondingWaitNodeIds() const override
    {
        return nodeInfos.syncInfos.correspondingWaitNodeIds;
    }
    // Replaces the stored list. Copy-assignment is used instead of assign(begin, end) so that
    // passing the node's own list back in is well defined.
    void SetCorrespondingWaitNodeIds(const std::vector<uint64_t>& waitIds) override
    {
        nodeInfos.syncInfos.correspondingWaitNodeIds = waitIds;
    }

    uint64_t GetCorrespondingNotifyNodeId() const override
    {
        return nodeInfos.syncInfos.correspondingNotifyNodeId;
    }
    void SetCorrespondingNotifyNodeId(uint64_t notifyId) override
    {
        nodeInfos.syncInfos.correspondingNotifyNodeId = notifyId;
    }

    std::string Format() const override;
    bool InitNode(const SuperKernelOptionsManager* opts = nullptr) override;

    // The default argument is repeated from the base declaration (as InitNode already does) so that
    // Update() can be called without arguments through this type as well as through the base.
    bool Update(const UpdateContext& ctx = {}) override;

    uint32_t GetVecNum() const override { return notifyExpandVecNum; }
    uint32_t GetCubeNum() const override { return notifyExpandCubeNum; }
};
/**
 * @brief Node kind that overrides only initialisation, invalidation and formatting;
 *        every other query falls back to the neutral defaults of SuperKernelBaseNode.
 */
class SuperKernelDefaultNode : public SuperKernelBaseNode {
public:
    // Reuse the base-class constructor unchanged.
    using SuperKernelBaseNode::SuperKernelBaseNode;

    std::string Format() const override;
    bool InitNode(const SuperKernelOptionsManager* opts = nullptr) override;
    aclError InValidateNode() override;
};
// Repeated forward declaration (the same one appears near the top of this header); harmless.
class SuperKernelGraph;
// Dumps the kernel binaries of `graph` using `binPath` as destination; returns a success flag.
// NOTE(review): whether binPath is a file or a directory is not visible here — confirm in the definition.
bool DumpKernelBinaries(const SuperKernelGraph& graph, const std::string& binPath);
// Returns a reference to the bind map associated with a binary handle.
// NOTE(review): the reference presumably points into a cache owned by the implementation — confirm its lifetime.
const SkBindMap& GetSkBindMap(aclrtBinHandle binHdl);
// JSON serialisation helpers. ordered_json keeps object keys in insertion order.
using Json = nlohmann::ordered_json;
// Converters for the plain info records; the sync/node variants also take the node type.
Json KernelInfosToJson(const KernelInfos& kernelInfos);
Json SyncInfosToJson(const SyncInfos& syncInfos, SkNodeType nodeType);
Json NodeInfosToJson(const NodeInfos& nodeInfos, SkNodeType nodeType);
// Converters for node objects, one per node class; each takes a non-owning pointer.
// NOTE(review): null handling is not visible here — confirm before passing a null node.
Json SuperKernelBaseNodeToJson(const SuperKernelBaseNode* node);
Json SuperKernelKernelNodeToJson(const SuperKernelKernelNode* node);
Json SuperKernelMemoryNodeToJson(const SuperKernelMemoryNode* node);
Json SuperKernelDefaultNodeToJson(const SuperKernelDefaultNode* node);
#endif