* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file sk_graph.h
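* \brief Declares SuperKernelGraph (the node/stream/event graph built for an aclmdlRI model) and SuperKernelNodeFactory.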
*/
#ifndef __SK_GRAPH_H__
#define __SK_GRAPH_H__
#include <algorithm>
#include <bitset>
#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include <nlohmann/json.hpp>

#include "sk_log.h"
#include "sk_types.h"
#include "sk_node.h"
#include "sk_scope_info.h"
#include "acl/acl.h"
// Forward declaration: this header only stores a pointer to the options manager (SuperKernelGraph::opts_),
// so the full definition is not needed here.
class SuperKernelOptionsManager;
// JSON type used by SuperKernelGraph::ToJson(); nlohmann::ordered_json keeps object keys in insertion order.
using Json = nlohmann::ordered_json;
/**
 * \brief Factory that turns one captured runtime task into a SuperKernelBaseNode.
 *
 * Exposes a single static entry point; the implementation lives outside this header.
 */
class SuperKernelNodeFactory {
public:
    /**
     * \brief Creates the graph node for one captured task.
     * \param task             Captured task; ownership moves into the factory call.
     * \param taskType         Type of \p task as reported by the runtime.
     * \param nodeIdx          Presumably the id assigned to the new node (TODO confirm).
     * \param streamIdxInGraph Presumably the index of the owning stream within the graph.
     * \param streamId         Runtime id of the owning stream.
     * \param preNodeId        Presumably the id of the preceding node on the same stream.
     * \return Owning pointer to the new node (whether it can be null on failure: confirm in the implementation).
     */
    static std::unique_ptr<SuperKernelBaseNode> CreateNode(
        std::unique_ptr<aclmdlRITask> task,
        aclmdlRITaskType taskType,
        uint64_t nodeIdx,
        uint64_t streamIdxInGraph,
        int32_t streamId,
        uint64_t preNodeId);
};
/**
 * \brief Node ids associated with a single event id.
 *
 * notifyNodeId defaults to INVALID_TASK_ID (presumably "no notifying node registered yet");
 * the two sets collect the ids of nodes that reset the event and that wait on it.
 */
struct EventInfos {
    uint64_t notifyNodeId = INVALID_TASK_ID;         // notifying node, or INVALID_TASK_ID
    std::unordered_set<uint64_t> resetNodeIdList{};  // ids of nodes resetting the event
    std::unordered_set<uint64_t> waitNodeIdList{};   // ids of nodes waiting on the event
};
/**
 * \brief Node ids associated with a single memory-synchronization key.
 *
 * Both sets start empty; writers and waiters are recorded separately.
 */
struct MemoryInfos {
    std::unordered_set<uint64_t> writeNodeIdList{};  // ids of nodes writing the memory value
    std::unordered_set<uint64_t> waitNodeIdList{};   // ids of nodes waiting on the memory value
};
/**
 * \brief Graph built from a captured model runtime instance (aclmdlRI) for super-kernel processing.
 *
 * Owns its nodes (graphMap), the per-event and per-memory node associations, the stream list and any
 * shape-info buffers handed over through AddShapeInfoPtr(). Copying is disabled; moving is allowed.
 */
class SuperKernelGraph {
public:
    aclError Update();

    SuperKernelGraph() = default;
    ~SuperKernelGraph() = default;
    SuperKernelGraph(const SuperKernelGraph&) = delete;
    SuperKernelGraph& operator=(const SuperKernelGraph&) = delete;
    SuperKernelGraph(SuperKernelGraph&&) = default;
    SuperKernelGraph& operator=(SuperKernelGraph&&) = default;

    // NOTE(review): implicit single-argument constructor. If aclmdlRI is a pointer handle (TODO confirm),
    // any pointer convertible to it also converts implicitly to SuperKernelGraph; consider `explicit`
    // once all call sites have been audited.
    SuperKernelGraph(aclmdlRI modelRI) : modelRI(modelRI) {}

    // Only the address of `opts` is stored (opts_), so the options manager must outlive this graph for as
    // long as opts_ may be dereferenced. Never pass a temporary here.
    SuperKernelGraph(aclmdlRI modelRI, const SuperKernelOptionsManager& opts) : modelRI(modelRI), opts_(&opts) {}

    bool InitSKGraph();
    void CaptureCurrentModelContext();

    // Returns the node registered under nodeId (behaviour for unknown ids: see the implementation).
    SuperKernelBaseNode* GetNodeById(uint64_t nodeId) const;

    aclmdlRI GetModelRI() const { return modelRI; }
    // NOTE(review): despite its name, this returns modelId.
    const std::string& GetModelIdCallCount() const { return modelId; }
    const std::string& GetModelLabel() const { return modelLabel; }
    const std::vector<uint64_t>& GetHeadNodes() const { return headNodes; }
    const std::vector<uint64_t>& GetNodeSizeInStream() const { return nodeSizeInStream; }
    const std::vector<aclrtStream>& GetStreams() const { return streams; }
    aclrtStream GetStreamByIndex(uint32_t streamIdx) const;
    std::vector<uint64_t> GetSortedNodeIds() const;

    // Returns the associations recorded for eventId, or nullptr if the event is unknown. The pointer refers
    // into eventToNodes and stays valid until that entry is erased or the graph is destroyed
    // (unordered_map does not relocate elements on rehash).
    const EventInfos* GetEventInfo(uint64_t eventId) const
    {
        auto it = eventToNodes.find(eventId);
        return it != eventToNodes.end() ? &it->second : nullptr;
    }

    // Looks up the scope name registered for scopeIdx. On success copies it into scopeName and returns true;
    // otherwise leaves scopeName untouched and returns false.
    bool GetScopeNameByIdx(uint32_t scopeIdx, std::string& scopeName) const
    {
        auto it = scopeIdxToName.find(scopeIdx);
        if (it != scopeIdxToName.end()) {
            scopeName = it->second;
            return true;
        }
        return false;
    }

    // Takes ownership of a shape-info buffer; it is freed by ClearShapeInfoPtrList() or with the graph.
    void AddShapeInfoPtr(std::unique_ptr<uint8_t[]> ptr)
    {
        shapeInfoPtrList.emplace_back(std::move(ptr));
    }

    // Releases every buffer previously handed over through AddShapeInfoPtr().
    void ClearShapeInfoPtrList()
    {
        shapeInfoPtrList.clear();
    }

    std::string BitsetToString(const std::bitset<MAX_SCOPE_NUM>& bitset) const;
    bool ExpandUpdateNodes(std::vector<SuperKernelBaseNode*>& needUpdateNodes);

    void SetUpdateFlag(bool flag)
    {
        needUpdate = flag;
    }

    void ParseOriginalScopes();
    const std::vector<OriginalScopeInfo>& GetOriginalScopeInfos() const { return originalScopeInfos_; }
    void DumpFusionFailReasons(const std::vector<SuperKernelScopeInfo>& processedScopeInfos);
    Json ToJson() const;
    bool InitFromModelRI();

    // Aggregated fusion-failure diagnostics returned by CollectFusionFailStats().
    struct FusionFailStats {
        std::unordered_map<std::string, size_t> reasonStats;  // presumably failure reason -> occurrence count
        size_t fusibleCount = 0;
        size_t unfusibleCount = 0;
        std::vector<std::string> nodeLogEntries;
        std::vector<std::string> unfusibleNodeLogEntries;
    };

private:
    FusionFailStats CollectFusionFailStats();
    bool AddNode(std::unique_ptr<SuperKernelBaseNode> node);
    bool AddEventAssociateNotify(uint64_t eventId, SuperKernelBaseNode* node);
    bool AddEventAssociateWait(uint64_t eventId, SuperKernelBaseNode* node);
    bool AddEventAssociateReset(uint64_t eventId, SuperKernelBaseNode* node);
    bool AddMemoryAssociateWrite(uint64_t eventId, SuperKernelBaseNode* node);
    bool AddMemoryAssociateWait(uint64_t eventId, SuperKernelBaseNode* node);
    bool AddEventAssociate();
    void BuildEventNodeAssociations();
    uint32_t GetValueBreakerBypass() const;
    bool PostProcessMemoryNode();
    bool ProcessMemoryWriteNodes(const uint64_t eventId, const MemoryInfos& memoryInfo,
        const uint64_t memoryWaitValue, const uint32_t waitFlag);
    void UpdateNodeScopeBitFlags();
    bool InitStreamsFromModelRI();
    bool ProcessAllStreamsAndTasks();
    bool ProcessSingleTask(aclmdlRITask& task, uint32_t streamIdx, uint32_t taskIdx, uint64_t& preNodeId);
    void RegisterFusibleScope(const std::unique_ptr<SuperKernelBaseNode>& node);
    void UpdateNodeRelations(uint64_t nodeId, uint32_t streamIdx, uint32_t taskIdx, uint64_t& preNodeId);

    std::unordered_map<uint64_t, std::unique_ptr<SuperKernelBaseNode>> graphMap;  // owned nodes, presumably keyed by node id
    std::unordered_map<uint64_t, EventInfos> eventToNodes;    // event id -> nodes associated with that event
    std::unordered_map<uint64_t, MemoryInfos> memoryToNodes;
    std::vector<uint64_t> headNodes;
    std::vector<uint64_t> nodeSizeInStream;
    std::vector<aclrtStream> streams;
    // Value-initialized so that a default-constructed graph yields a zero/null handle from GetModelRI()
    // instead of an indeterminate value.
    aclmdlRI modelRI{};
    std::string modelId;
    std::string modelLabel;
    friend class SuperKernelOptimizer;
    std::unordered_map<std::string, uint32_t> scopeNameToIdx;
    std::unordered_map<uint32_t, std::string> scopeIdxToName;  // scope index -> scope name
    std::vector<OriginalScopeInfo> originalScopeInfos_;
    std::vector<std::unique_ptr<uint8_t[]>> shapeInfoPtrList;  // buffers owned via AddShapeInfoPtr()
    std::unordered_set<SuperKernelBaseNode*> needUpdateNodes;
    bool needUpdate = false;
    const SuperKernelOptionsManager* opts_ = nullptr;           // non-owning; set only by the two-argument constructor
};
#endif