/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file cell_match_dynamic.cpp
* \brief Per-launch evaluation and refresh of dynamic CellMatchTableDesc.stride on host.
*/
#include "machine/runtime/launcher/cell_match_dynamic.h"
#include <vector>
#include "interface/machine/device/tilefwk/aikernel_data.h"
#include "machine/runtime/launcher/device_launcher_binding.h"
#include "machine/utils/dynamic/dev_encode_function_stitch.h"
#include "tilefwk/error_code.h"
namespace npu::tile_fwk::dynamic {
namespace {
/**
 * \brief Derive the stride shape of one dynamic cell-match descriptor from its launch metadata.
 *
 * The cell shape from launchMeta is installed into patchedDesc first. Each candidate raw-dim
 * list is then evaluated, and every dimension is turned into a tile count by ceil-dividing the
 * evaluated value by the cell extent (clamped to at least 1). The stride shape is written only
 * when all candidates produce the same tile count in every dimension.
 *
 * \param launchMeta  Provides cellShape, candidateRawDims and slotIndex (slotIndex is used in
 *                    the assertion message only).
 * \param eval        Evaluator used to resolve each raw-dim expression.
 * \param patchedDesc Receives the cell shape unconditionally and the stride shape on success.
 * \return true when the stride shape was written; false when there are no candidates, when the
 *         descriptor reports more than DEV_SHAPE_DIM_MAX dimensions, or when candidates disagree.
 */
bool TryBuildDynamicCellMatchDesc(
    const DyndevFunctionAttribute::DynamicCellMatchLaunchMeta& launchMeta, Evaluator& eval,
    DevCellMatchTableDesc& patchedDesc)
{
    patchedDesc.SetCellShape(launchMeta.cellShape);
    const int rank = patchedDesc.GetDimensionSize();
    if (launchMeta.candidateRawDims.empty()) {
        return false;
    }
    if (rank > DEV_SHAPE_DIM_MAX) {
        // The fixed-size scratch arrays below cannot hold more than DEV_SHAPE_DIM_MAX entries.
        return false;
    }

    int64_t expected[DEV_SHAPE_DIM_MAX]{0};
    const size_t candidateNum = launchMeta.candidateRawDims.size();
    for (size_t idx = 0; idx < candidateNum; ++idx) {
        int64_t tiles[DEV_SHAPE_DIM_MAX]{0};
        // Dimensions are evaluated from the last one down to the first.
        // NOTE(review): candidateRawDims[idx] is indexed up to rank - 1 without a visible length
        // check — presumably every candidate carries at least `rank` entries; confirm at the producer.
        for (int axis = rank - 1; axis >= 0; --axis) {
            auto rawDimExpr = launchMeta.candidateRawDims[idx][axis];
            const int64_t extent = eval.Evaluate(rawDimExpr);
            const int64_t cellExtent = std::max<int64_t>(patchedDesc.GetCellShape(axis), 1);
            // Ceil division: number of cells needed to cover the evaluated extent.
            const int64_t tile = (extent + cellExtent - 1) / cellExtent;
            ASSERT(tile > 0) << "Invalid tile for dynamic cell match slot=" << launchMeta.slotIndex << ", dim=" << axis;
            tiles[axis] = tile;
        }

        if (idx != 0) {
            // Later candidates must reproduce the first candidate's tile counts exactly.
            for (int axis = 0; axis < rank; ++axis) {
                if (tiles[axis] != expected[axis]) {
                    return false;
                }
            }
            continue;
        }

        // The first candidate becomes the reference.
        for (int axis = 0; axis < rank; ++axis) {
            expected[axis] = tiles[axis];
        }
    }

    std::vector<int> strideShape(rank);
    for (int axis = 0; axis < rank; ++axis) {
        strideShape[axis] = static_cast<int>(expected[axis]);
    }
    patchedDesc.SetStrideShape(strideShape);
    return true;
}
}
std::vector<DevDynamicCellMatchStridePatch> PrepareDynamicCellMatchDescPatches(
const DyndevFunctionAttribute& dynAttr, Evaluator& eval)
{
std::vector<DevDynamicCellMatchStridePatch> patches;
if (dynAttr.dynamicCellMatchLaunchMetaList.empty()) {
return patches;
}
for (const auto& launchMeta : dynAttr.dynamicCellMatchLaunchMetaList) {
DevCellMatchTableDesc patchedDesc;
bool ready = TryBuildDynamicCellMatchDesc(launchMeta, eval, patchedDesc);
if (!ready) {
ASSERT(false) << "dynamic cell match launch prepare failed, slot=" << launchMeta.slotIndex;
}
DevDynamicCellMatchStridePatch patch;
patch.descOffset = launchMeta.descOffset;
patch.stride = patchedDesc.stride;
patches.push_back(patch);
}
return patches;
}
/**
 * \brief Copy each patch's stride into the descriptor found at its byte offset inside the
 *        host program image.
 *
 * \param hostDevProg Start of the host program image; descOffset is applied relative to this
 *                    address. Nothing is done when it is null.
 * \param patches     Patches to apply; an empty list is a no-op.
 */
void PatchHostDynamicCellMatchTableDesc(
    DevAscendProgram* hostDevProg, const std::vector<DevDynamicCellMatchStridePatch>& patches)
{
    if (hostDevProg == nullptr) {
        return;
    }

    uint8_t* const imageBase = reinterpret_cast<uint8_t*>(hostDevProg);
    for (size_t i = 0; i < patches.size(); ++i) {
        const DevDynamicCellMatchStridePatch& entry = patches[i];
        // NOTE(review): descOffset is trusted as-is; no bounds check against the image size is visible here.
        DevCellMatchTableDesc* target = reinterpret_cast<DevCellMatchTableDesc*>(imageBase + entry.descOffset);
        target->stride = entry.stride;
    }
}
void WriteDynamicCellMatchStridePatchesToLaunchArgs(
int64_t* launchInputs, const std::vector<DevDynamicCellMatchStridePatch>& patches)
{
if (launchInputs == nullptr) {
return;
}
const uint64_t inputCount = static_cast<uint64_t>(launchInputs[0]);
const uint64_t outputCount = static_cast<uint64_t>(launchInputs[1]);
auto* patchCountPtr = reinterpret_cast<uint64_t*>(
reinterpret_cast<DevTensorData*>(launchInputs + TENSOR_INFO_OFFSET) + inputCount + outputCount);
*patchCountPtr = patches.size();
auto* patchArr = reinterpret_cast<DevDynamicCellMatchStridePatch*>(patchCountPtr + 1);
for (size_t i = 0; i < patches.size(); ++i) {
patchArr[i] = patches[i];
}
}
void ValidateDynamicCellMatchTableMemBudget(
const DyndevFunctionAttribute& dynAttr, DevAscendProgram* hostDevProg)
{
if (hostDevProg == nullptr || dynAttr.constructAssembleNeedAllocRuntimeSlots.empty()) {
return;
}
const auto* cfgBytes = reinterpret_cast<const uint8_t*>(hostDevProg);
for (const auto& launchMeta : dynAttr.dynamicCellMatchLaunchMetaList) {
if (dynAttr.constructAssembleNeedAllocRuntimeSlots.count(launchMeta.slotIndex) == 0) {
continue;
}
const auto* desc = reinterpret_cast<const DevCellMatchTableDesc*>(cfgBytes + launchMeta.descOffset);
const uint64_t cellMatchStride0 = desc->stride.dimStride[0];
ASSERT(DevCommonErr::PARAM_CHECK_FAILED, cellMatchStride0 < static_cast<uint64_t>(MAX_CELLMATCHSSTRIDE))
<< " Dynamic cell match slot=" << launchMeta.slotIndex
<< " stitch results in excessive memory consumption,"
<< "Please appropriately configure the view shape and tile shape, and ensure aligned with the input shape.";
}
}
/**
 * \brief Re-evaluate the dynamic memory budgets and store them in the host program image.
 *
 * When dynAttr.maxDynamicAssembleOutcastMem is valid, its evaluated value is written to
 * memBudget.tensor.maxDynamicAssembleOutcastMem. When dynAttr.maxDynamicCellMatchTableMem is
 * valid, its evaluated value is written to memBudget.metadata.maxDynamicCellMatchTableMem and
 * memBudget.metadata.dynamicCellMatch is set to dynamicCellMatchSlotNum times that value.
 *
 * \param hostDevProg Program image to update; nothing is done when it is null.
 * \param dynAttr     Holds the two budget expressions.
 * \param eval        Evaluator used to resolve the expressions.
 */
void RefillDynamicMemBudgets(
    DevAscendProgram* hostDevProg, DyndevFunctionAttribute& dynAttr, Evaluator& eval)
{
    if (hostDevProg == nullptr) {
        return;
    }

    // The outcast budget is evaluated first, independently of the cell-match budget.
    if (dynAttr.maxDynamicAssembleOutcastMem.IsValid()) {
        hostDevProg->memBudget.tensor.maxDynamicAssembleOutcastMem =
            eval.Evaluate(dynAttr.maxDynamicAssembleOutcastMem);
    }

    if (dynAttr.maxDynamicCellMatchTableMem.IsValid()) {
        hostDevProg->memBudget.metadata.maxDynamicCellMatchTableMem =
            eval.Evaluate(dynAttr.maxDynamicCellMatchTableMem);

        // Total budget = slot count * per-table budget that was just stored.
        const uint64_t slotNum = hostDevProg->memBudget.metadata.dynamicCellMatchSlotNum;
        hostDevProg->memBudget.metadata.dynamicCellMatch =
            slotNum * hostDevProg->memBudget.metadata.maxDynamicCellMatchTableMem;
    }
}
/**
 * \brief Run the per-launch host-side refresh of dynamic cell-match data.
 *
 * Steps, in order: compute the stride patches, apply them to the host program image, refill
 * the dynamic memory budgets, and — only when dynAttr.maxDynamicCellMatchTableMem is valid —
 * validate the cell-match table budget.
 *
 * \param dynAttr     Dynamic function attributes driving every step.
 * \param eval        Evaluator used for patch computation and budget refill.
 * \param hostDevProg Host program image that is patched, refilled and validated.
 * \return The computed stride patches.
 */
std::vector<DevDynamicCellMatchStridePatch> PrepareHostDynamicCellMatchForLaunch(
    DyndevFunctionAttribute& dynAttr, Evaluator& eval, DevAscendProgram* hostDevProg)
{
    std::vector<DevDynamicCellMatchStridePatch> stridePatches = PrepareDynamicCellMatchDescPatches(dynAttr, eval);
    PatchHostDynamicCellMatchTableDesc(hostDevProg, stridePatches);
    RefillDynamicMemBudgets(hostDevProg, dynAttr, eval);

    const bool hasTableBudget = dynAttr.maxDynamicCellMatchTableMem.IsValid();
    if (!hasTableBudget) {
        return stridePatches;
    }

    ValidateDynamicCellMatchTableMemBudget(dynAttr, hostDevProg);
    return stridePatches;
}
}