/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file dev_stitch_dependency_enhanced.h
* \brief Enhanced Stitch dependency processing with unified functions
*
* Design principles:
* - Multi-concurrent types: Atomic-Write, Read (same type non-mutex, support multiple ops)
* - Exclusive type: Normal-Write (all types mutex, only 1 op)
* - Unified functions with opType parameter to reduce code duplication
*/
#pragma once
#include "machine/utils/dynamic/dev_encode_types.h"
#include "machine/utils/dynamic/dev_cell_match_mem_layout.h"
#include "machine/utils/dynamic/dev_encode_function_stitch.h"
#include "tilefwk/aicpu_common.h"
#include "tilefwk/aikernel_data.h"
namespace npu::tile_fwk::dynamic {
/**
 * \brief Test whether a producer/consumer stitch is already recorded in the cache, recording it if not.
 *
 * The cache is addressed as a flat 2-D table: row = consumerOpIdx, column = producerOpIdx,
 * row stride = rootFuncMaxCallOpsize. Each entry stores one key packed as
 * (devTaskId << 32) | (producerFuncIdx << 16) | consumerFuncIdx.
 *
 * \param stitchCache            Flat cache table. When nullptr the check is disabled.
 * \param rootFuncMaxCallOpsize  Row stride of the table. When 0 the check is disabled.
 * \param producerFuncIdx        Producer function index, packed into the key starting at bit 16.
 * \param producerOpIdx          Column of the entry to inspect.
 * \param consumerFuncIdx        Consumer function index, packed into the low bits of the key.
 * \param consumerOpIdx          Row of the entry to inspect.
 * \param devTaskId              Device task id, packed into the key starting at bit 32.
 * \return true when the entry already holds exactly this key; false when the check is disabled or
 *         when the entry held a different value (the entry is then overwritten with this key).
 *
 * NOTE(review): no bound is checked here; the caller must guarantee that producerOpIdx is below
 * rootFuncMaxCallOpsize and that the row exists, otherwise the access aliases another entry or
 * leaves the table.
 * NOTE(review): the key fields overlap when a function index needs more than 16 bits, and the upper
 * 32 bits of devTaskId are shifted out - presumably both are range-limited upstream; confirm.
 * NOTE(review): a key of all zeros compares equal to a zero-filled entry, so the very first lookup
 * with (devTaskId, producerFuncIdx, consumerFuncIdx) == (0, 0, 0) reports a duplicate if the table
 * is zero-initialized - confirm how the table is initialized.
 */
inline bool CheckStitchCacheDuplicate(
    uint64_t* stitchCache, uint32_t rootFuncMaxCallOpsize, uint32_t producerFuncIdx, uint32_t producerOpIdx,
    uint32_t consumerFuncIdx, uint32_t consumerOpIdx, uint64_t devTaskId)
{
    constexpr uint32_t kDevTaskIdShift = 32;
    constexpr uint32_t kProducerFuncIdxShift = 16;

    if (stitchCache == nullptr || rootFuncMaxCallOpsize == 0) {
        return false;
    }

    // Widen before multiplying: a 32-bit product would wrap silently for large tables and
    // select an unrelated entry.
    const uint64_t entryIdx =
        static_cast<uint64_t>(consumerOpIdx) * rootFuncMaxCallOpsize + static_cast<uint64_t>(producerOpIdx);
    uint64_t& entry = stitchCache[entryIdx];

    const uint64_t expected = (devTaskId << kDevTaskIdShift) |
                              (static_cast<uint64_t>(producerFuncIdx) << kProducerFuncIdxShift) | consumerFuncIdx;
    if (entry == expected) {
        return true;
    }

    // First time this pair is seen for this key: remember it so the next identical query is a hit.
    entry = expected;
    return false;
}
inline void EstablishDependenciesWithType(
uint64_t cellMemBase, uint64_t* cellMatchTableData, uint32_t opType, uint32_t opCount,
DevAscendFunctionDupped* stitchingList, int stitchingSize, DevAscendFunctionDupped* currDup, uint64_t tagId,
size_t devCurrIdx, uint32_t operationIdx, DeviceWorkspaceAllocator* workspace, uint64_t* matchCount, int slotIdx,
const DevCellMatchTableDesc& desc)
{
(void)stitchingSize;
for (uint32_t i = 0; i < opCount; i++) {
uint64_t opId = CellMatchGetOpId(cellMatchTableData, cellMemBase, opType, i, desc);
if (opId != AICORE_TASK_INIT) {
auto funcIdx = FuncID(static_cast<uint32_t>(opId));
auto opIdx = TaskID(static_cast<uint32_t>(opId));
if (funcIdx >= static_cast<uint32_t>(stitchingSize)) {
DEV_ERROR(CtrlErr::CELL_MATCH_OP_ID_INVALID,
"Cell match cache opid is invalid maybe dirty data, opid=(%u!%u), stitchListSize=%d.",
funcIdx, opIdx, stitchingSize);
continue;
}
(*matchCount)++;
DeviceStitchContext::HandleOneStitch(
stitchingList[funcIdx], *currDup, funcIdx, opIdx, devCurrIdx, operationIdx, workspace,
DeviceStitchContext::StitchKind::StitchDefault, slotIdx, static_cast<uint64_t>(CellMatchGetDevTaskIdFromTagId(tagId)));
DEV_VERBOSE_DEBUG(
"Stitch dependency: (%u!%u) -> (%zu!%u), cellopType=%u, cellopCount=%u, "
"slotIdx=%d, opTagid=%x, curOpTagid=%lx",
funcIdx, opIdx, devCurrIdx, operationIdx, opType, opCount, slotIdx,
static_cast<uint32_t>(opId >> CELL_MATCH_META_TAGID_SHIFT32), tagId);
}
}
}
/**
 * \brief Establish stitch dependencies between the current operation and the ops recorded in one cell.
 *
 * Decision order, all driven by the cell's meta word:
 *  1. No active op type recorded (CELL_MATCH_OP_TYPE_NONE): nothing to do.
 *  2. The tag stored in the cell differs from tagId: nothing to do.
 *  3. myOpType is mutually exclusive with the active op type (CellMatchIsMutexOp): depend on the
 *     currently active ops.
 *  4. Otherwise: depend on the previously recorded mutex ops, provided a previous mutex type exists
 *     and its count is neither 0 nor CELL_MATCH_INVALID_OP_COUNT.
 *
 * \param cellIndex           Index of the cell to inspect.
 * \param cellMatchTableData  Cell match table storage.
 * \param myOpType            Op type of the current operation.
 * \param stitchingList       Forwarded to EstablishDependenciesWithType.
 * \param stitchingSize       Forwarded to EstablishDependenciesWithType.
 * \param currDup             Forwarded to EstablishDependenciesWithType.
 * \param tagId               Tag of the current operation; must equal the tag stored in the cell.
 * \param devCurrIdx          Function index of the current operation.
 * \param operationIdx        Operation index of the current operation.
 * \param workspace           Forwarded to EstablishDependenciesWithType.
 * \param matchCount          Forwarded to EstablishDependenciesWithType.
 * \param slotIdx             Forwarded to EstablishDependenciesWithType.
 * \param desc                Cell match table descriptor.
 */
inline void CellMatchStitchHandle(
    int cellIndex, uint64_t* cellMatchTableData, uint32_t myOpType, DevAscendFunctionDupped* stitchingList,
    int stitchingSize, DevAscendFunctionDupped* currDup, uint64_t tagId, size_t devCurrIdx, uint32_t operationIdx,
    DeviceWorkspaceAllocator* workspace, uint64_t* matchCount, int slotIdx, const DevCellMatchTableDesc& desc)
{
    DEV_VERBOSE_DEBUG(
        "CellMatchStitchHandle: cell[%d], cellMatchTableData=%p, myOpType=%u, mytagId=%lx, slotIdx=%d, myTaskid=%x", cellIndex,
        cellMatchTableData, myOpType, tagId, slotIdx, MakeTaskID(devCurrIdx, operationIdx));

    // The first word at the cell's base position is its meta word.
    const uint64_t cellMemBase = CellMatchCellIndexToMemBase(static_cast<uint64_t>(cellIndex), desc);
    const uint64_t meta = cellMatchTableData[cellMemBase];

    const uint32_t activeType = CellMatchGetCurrentOpType(meta);
    if (activeType == CELL_MATCH_OP_TYPE_NONE) {
        DEV_VERBOSE_DEBUG("CellMatchStitchHandle: cell[%d] early return, curActiveOpType=NONE", cellIndex);
        return;
    }

    const auto storedTagId = CellMatchGetTagId(meta);
    if (storedTagId != tagId) {
        DEV_VERBOSE_DEBUG(
            "CellMatchStitchHandle: cell[%d] early return, storedTagId=%lx != tagId=%lx", cellIndex,
            storedTagId, tagId);
        return;
    }

    // Pick which recorded group the current op has to wait for.
    uint32_t dependType = activeType;
    uint32_t dependCount = CellMatchGetCurrentOpCount(meta);
    if (CellMatchIsMutexOp(myOpType, activeType)) {
        DEV_VERBOSE_DEBUG(
            "CellMatchStitchHandle: cell[%d] mutex op, establish deps for curActiveOpType=%u curActiveOpCnt=%u", cellIndex,
            dependType, dependCount);
    } else {
        dependType = CellMatchGetPrevMutexOpType(meta);
        dependCount = CellMatchGetPrevMutexOpCount(meta);
        if (dependType == CELL_MATCH_OP_TYPE_NONE || dependCount == 0 ||
            dependCount == CELL_MATCH_INVALID_OP_COUNT) {
            // No usable previous mutex group: the current op joins the active group without dependencies.
            return;
        }
        DEV_VERBOSE_DEBUG(
            "CellMatchStitchHandle: cell[%d] prev mutex, establish deps for prevMutexType=%u, prevMutexCount=%u",
            cellIndex, dependType, dependCount);
    }

    EstablishDependenciesWithType(
        cellMemBase, cellMatchTableData, dependType, dependCount, stitchingList, stitchingSize, currDup,
        tagId, devCurrIdx, operationIdx, workspace, matchCount, slotIdx, desc);
}
/**
 * \brief Run CellMatchStitchHandle through CellMatchHandle for the region given by offset / shape.
 *
 * The trailing pack must hold exactly 10 values, in this order:
 *   0: uint64_t* cell match table data          5: size_t current function index
 *   1: DevAscendFunctionDupped* stitching list  6: DeviceWorkspaceAllocator* workspace
 *   2: int stitching list size                  7: uint32_t current operation index
 *   3: DevAscendFunctionDupped* current dup     8: int slot index
 *   4: uint64_t tag id                          9: uint64_t* match counter
 * With any other pack size the function compiles to an empty body and does nothing.
 *
 * \param offset              Region offset, forwarded to CellMatchHandle.
 * \param shape               Region shape, forwarded to CellMatchHandle.
 * \param cellMatchTableDesc  Table descriptor, forwarded to CellMatchHandle and to the per-cell handler.
 * \param opType              Op type of the current operation.
 * \param args                The 10 values listed above.
 *
 * NOTE(review): CellMatchHandle is defined elsewhere; it presumably calls HandleStitch::Process with a
 * cell index followed by the forwarded arguments - confirm against its definition.
 */
template <typename... TyArgs>
static void CellMatchStitchEnhance(
    const uint64_t offset[DEV_SHAPE_DIM_MAX], const uint64_t shape[DEV_SHAPE_DIM_MAX],
    const DevCellMatchTableDesc& cellMatchTableDesc, uint32_t opType, TyArgs... args)
{
    constexpr size_t kStitchArgCount = 10;
    if constexpr (sizeof...(args) == kStitchArgCount) {
        // Adapter that reorders the forwarded arguments into CellMatchStitchHandle's parameter order.
        struct HandleStitch {
            static uint32_t Process(
                int cellIndex, uint64_t* tableData, DevAscendFunctionDupped* dupList, int dupListSize,
                DevAscendFunctionDupped* currentDup, uint64_t cellTagId, size_t currFuncIdx, uint32_t currOpIdx,
                DeviceWorkspaceAllocator* allocator, uint64_t* matchCounter, uint32_t currOpType, int slot,
                const DevCellMatchTableDesc& tableDesc)
            {
                CellMatchStitchHandle(
                    cellIndex, tableData, currOpType, dupList, dupListSize, currentDup, cellTagId, currFuncIdx,
                    currOpIdx, allocator, matchCounter, slot, tableDesc);
                return 0;
            }
        };

        // Unpack the variadic arguments into typed locals.
        auto packed = std::make_tuple(args...);
        uint64_t* tableData = std::get<0>(packed);
        DevAscendFunctionDupped* dupList = std::get<1>(packed);
        int dupListSize = std::get<2>(packed);
        DevAscendFunctionDupped* currentDup = std::get<3>(packed);
        uint64_t cellTagId = std::get<4>(packed);
        size_t currFuncIdx = std::get<5>(packed);
        DeviceWorkspaceAllocator* allocator = std::get<6>(packed);
        uint32_t currOpIdx = std::get<7>(packed);
        int slot = std::get<8>(packed);
        uint64_t* matchCounter = std::get<9>(packed);

        // Note the order differs from the pack: operation index precedes workspace, and the match
        // counter precedes opType / slot.
        CellMatchHandle<HandleStitch>(
            offset, shape, cellMatchTableDesc, tableData, dupList, dupListSize, currentDup, cellTagId,
            currFuncIdx, currOpIdx, allocator, matchCounter, opType, slot, cellMatchTableDesc);
    }
}
}