* Copyright 2023-2026 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "akg/Analysis/TilingStrategy.h"
#include <algorithm>
#include <climits>
#include <map>
#include <numeric>
#include <optional>
#include <utility>
#include "akg/Dialect/Affine/Analysis/GpuTemplateTilingSolver.h"
#include "akg/Utils/AKGGlobalVars.hpp"
#include "akg/Utils/AnalysisCommon.hpp"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/MathExtras.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinTypes.h"
namespace mlir {
namespace autotiling {
// Attribute name string "not_inner_dimension_broadcast"; it is not referenced in the
// visible part of this file (presumably a loop attribute, per the identifier — TODO confirm).
static constexpr const char *kNotInnerDimensionBroadcastLoopAttr = "not_inner_dimension_broadcast";
// Short names for the helpers used by the strategies below.
using akg::utils::GpuInfo;
using akg::utils::StrategyHelper;
using llvm::SmallVector;
using mlir::akg::autotiling::GpuTemplateSolver;
// Tells whether axis `a` concerns this strategy: it either carries the
// strategy's `workForAxisLabel`, or it appears in the loop nest of a graph
// node whose op type is listed in `workForOps`.
bool TilingStrategy::IsRelevant(const AxisPtr &a, const InitGraphPtr graph) {
  const bool carriesWorkLabel = a->axisType.find(workForAxisLabel) != a->axisType.end();
  return carriesWorkLabel || IsRelevant(a, graph, this->workForOps);
}
bool TilingStrategy::IsRelevant(const AxisPtr &a, const InitGraphPtr graph, std::unordered_set<std::string> ops) {
if (ops.empty()) {
return false;
}
return std::any_of(graph->nodes().begin(), graph->nodes().end(), [&](const NodePtr &node) {
if (ops.find(node->opType) == ops.end()) {
return false;
}
const auto &loopNest = node->loopNest();
return std::find(loopNest.begin(), loopNest.end(), a) != loopNest.end();
});
}
// Copies the tile sizes recorded in the global GpuScheduleTool onto the axes of
// `initGraph`. For every axis, each recorded entry with a positive tile level
// and a positive constant size is written into the tile config at
// `tileLevel - 1`; an entry at `extraTileLevel` first requests an extra tile.
void RepositoryStrategy::AddConstraint(ModelGraphPtr initGraph) {
  auto &scheduleTool = akgglobal::GpuScheduleTool::getInstance();
  auto applyRecordedTiles = [this, &scheduleTool](const AxisPtr axis) {
    for (auto recorded : scheduleTool.get(axis->name)) {
      const bool usable = recorded.tileLevel > 0 && recorded.constSize > 0;
      if (!usable) {
        continue;
      }
      if (recorded.tileLevel == extraTileLevel) {
        axis->doExtraTile();
      }
      auto cfg = axis->tryGetConfig(recorded.tileLevel - 1);
      if (cfg != nullptr) {
        cfg->value = recorded.constSize;
      }
    }
  };
  initGraph->rootAxis->forEachAxisTopDown(applyRecordedTiles);
}
// Classifies a shape list, where -1 marks a dynamic dimension, into a Sketch:
//  * no dynamic dimension            -> kAllStatic
//  * last dimension is dynamic       -> the count of trailing dynamic dimensions,
//                                       capped at kMoreDynamicInner
//  * last dimension is static        -> kLargeStaticInner when the product of the
//                                       trailing static dimensions reaches
//                                       `largeShapeLimit`, else kSmallStaticInner
Sketch DynamicShapeStrategy::SketchAnalysis(std::vector<int64_t> shapes) {
  const bool hasDynamicDim = std::find(shapes.begin(), shapes.end(), -1) != shapes.end();
  if (!hasDynamicDim) {
    return Sketch::kAllStatic;
  }

  if (shapes.back() == -1) {
    // Count the run of dynamic dimensions at the innermost end.
    int trailingDynamic = 0;
    for (auto it = shapes.rbegin(); it != shapes.rend() && *it == -1; ++it) {
      ++trailingDynamic;
    }
    return Sketch(std::min<int>(trailingDynamic, static_cast<int>(Sketch::kMoreDynamicInner)));
  }

  // Multiply the static dimensions found before the first dynamic one (walking from the inside).
  int64_t innerStaticProduct = 1;
  for (auto it = shapes.rbegin(); it != shapes.rend() && *it != -1; ++it) {
    innerStaticProduct *= *it;
  }
  if (innerStaticProduct >= largeShapeLimit) {
    return Sketch::kLargeStaticInner;
  }
  return Sketch::kSmallStaticInner;
}
// Constrains the first tile level of every axis labelled kDynamic to `microTileSize`.
void DynamicShapeStrategy::AddCpuConstraint(CpuModelGraphPtr initGraph) {
  auto pinDynamicAxis = [this](const AxisPtr axis) {
    const bool isDynamic =
      axis->axisType.find(mlir::autotiling::Axis::AxisLabel::kDynamic) != axis->axisType.end();
    if (!isDynamic) {
      return;
    }
    auto microTile = Constraint({microTileSize});
    axis->tryAddConstraint(0, microTile);
  };
  initGraph->rootAxis->forEachAxisTopDown(pinDynamicAxis);
}
// Assigns constant tile sizes to every axis, walking bottom-up.
// Dynamic axes take sizes from a per-sketch list (consumed from its back); once
// that list is empty they receive `totalBlocks / staticTiles` for the
// kSmallStaticInner sketch while the running product is below the block count,
// otherwise 1. Static axes receive min(totalBlocks, extent) until the running
// product `staticTiles` reaches the block count, and 1 afterwards.
void DynamicShapeStrategy::DoConstTile(const GpuModelGraphPtr initGraph, Sketch sketch) {
  std::vector<int> dynamicTiles;
  auto totalBlocks = GpuInfo::getInstance(initGraph->hardware).getTotalAvailableBlocks();
  switch (sketch) {
    case Sketch::kOneDynamicInner:
      dynamicTiles = {1024};
      break;
    case Sketch::kTwoDynamicInner:
      dynamicTiles = {8, 32};
      break;
    case Sketch::kMoreDynamicInner:
      dynamicTiles = {4, 8, 32};
      break;
    default:
      break;
  }

  int staticTiles = 1;
  initGraph->rootAxis->forEachAxisBottomUp([&](const AxisPtr a) {
    const bool isDynamic = a->axisType.find(mlir::autotiling::Axis::AxisLabel::kDynamic) != a->axisType.end();

    if (!isDynamic) {
      if (!(staticTiles < totalBlocks)) {
        a->tryAddConstraint(0, Constraint({1}));
        return;
      }
      int tile = std::min<int64_t>(totalBlocks, a->range.second);
      a->tryAddConstraint(0, Constraint({tile}));
      staticTiles *= tile;
      return;
    }

    if (!dynamicTiles.empty()) {
      a->tryAddConstraint(0, Constraint({dynamicTiles.back()}));
      staticTiles *= dynamicTiles.back();
      dynamicTiles.pop_back();
      return;
    }

    if (sketch == Sketch::kSmallStaticInner && staticTiles < totalBlocks) {
      a->tryAddConstraint(0, Constraint({totalBlocks / staticTiles}));
    } else {
      a->tryAddConstraint(0, Constraint({1}));
    }
  });
}
// Tiles dynamic axes with runtime-variable tile sizes for the kTwoDynamicInner sketch.
// Any other sketch is delegated to DoConstTile.
// Two runtime arguments are registered in GpuScheduleTool, each identified by a prime
// taken from PrimeNumTool; the prime is used as the placeholder tile value on the axis.
void DynamicShapeStrategy::DoVariableTile(const GpuModelGraphPtr initGraph, Sketch sketch) {
auto &gpuTool = akgglobal::GpuScheduleTool::getInstance();
auto &tool = akgglobal::PrimeNumTool::getInstance();
if (sketch != Sketch::kTwoDynamicInner) {
DoConstTile(initGraph, sketch);
return;
}
auto prime1 = tool.getOnePrimeWithIdxUpdate();
auto prime2 = tool.getOnePrimeWithIdxUpdate();
auto arg1 = gpuTool.addRuntimeArgument(static_cast<int64_t>(prime1));
auto arg2 = gpuTool.addRuntimeArgument(static_cast<int64_t>(prime2));
// Consumed from the back: arg2 goes to the first dynamic axis visited, arg1 to the second.
std::vector<akgglobal::RuntimeVar> dynamicTiles{arg1, arg2};
size_t totalAxis = 0;
// Pass 1: count every axis. Static axes get tile size 1 and, when the grid accepts
// their extent, a grid mapping with consecutive mapDim values.
initGraph->rootAxis->forEachAxisBottomUp([&, currMapDim = size_t{0}](const AxisPtr a) mutable {
++totalAxis;
if (a->axisType.find(mlir::autotiling::Axis::AxisLabel::kDynamic) != a->axisType.end()) {
return;
}
if (initGraph->gpuGrid.canApply(a->range.second)) {
auto gridCfg = initGraph->gpuGrid.alloc(a, a->range.second);
gridCfg->mapDim = static_cast<int>(currMapDim);
currMapDim++;
}
a->tryAddConstraint(0, Constraint({1}));
});
size_t currDynAxis = 0;
// Pass 2: dynamic axes only. While runtime arguments remain, the axis is tiled by the
// argument's prime and mapped to block (and possibly grid); afterwards the tile is 1.
initGraph->rootAxis->forEachAxisBottomUp([&](const AxisPtr a) {
if (a->axisType.find(mlir::autotiling::Axis::AxisLabel::kDynamic) == a->axisType.end()) {
return;
}
if (dynamicTiles.empty()) {
a->tryAddConstraint(0, Constraint({1}));
} else {
// `arg` is a copy; changes to it reach the tool only through updateRuntimeArgument below.
auto arg = dynamicTiles.back();
auto prime = static_cast<int>(arg.prime);
a->tryAddConstraint(0, Constraint({prime}));
if (initGraph->gpuBlock.canApply(-1)) {
// The alloc() result is discarded; a separate "DynBlock" config holding the prime is attached instead.
(void)initGraph->gpuBlock.alloc(a, -1);
auto blockCfg = std::make_shared<GpuBlock>("DynBlock");
blockCfg->index = ConfigPos::kInner;
blockCfg->value = prime;
a->configs[blockCfg->type].push_back(blockCfg);
}
// Grid mapping is added when the axis count fits the grid dimensions, or for the second dynamic axis.
if ((totalAxis <= initGraph->gpuGrid.availbleSize.size() || currDynAxis == 1) &&
initGraph->gpuGrid.canApply(-1)) {
(void)initGraph->gpuGrid.alloc(a, -1);
auto gridCfg = std::make_shared<GpuGrid>("DynGrid");
gridCfg->index = ConfigPos::kOuter;
gridCfg->value = 1;
a->configs[gridCfg->type].push_back(gridCfg);
}
// Attach the expression string the runtime argument is resolved with.
// NOTE(review): the meaning of "mod32 OR -1" / "mod8 OR -1" is defined by the consumer of RuntimeVar::expr — confirm there.
if (currDynAxis == 0) {
arg.expr = "mod32 OR -1";
gpuTool.updateRuntimeArgument(arg);
} else if (currDynAxis == 1 && totalAxis >= initGraph->gpuGrid.availbleSize.size()) {
arg.expr = "mod8 OR -1";
gpuTool.updateRuntimeArgument(arg);
}
dynamicTiles.pop_back();
}
currDynAxis++;
});
}
// Builds the shape list of the graph (-1 for dynamic axes, the extent for static
// axes larger than 1), classifies it with SketchAnalysis and applies the
// matching tiling. Fully static graphs are left untouched. Variable tiling is
// used only when the tiling mode is not "static" and the sketch is
// kTwoDynamicInner; every other case uses constant tiling.
void DynamicShapeStrategy::AddGpuConstraint(GpuModelGraphPtr initGraph) {
  std::vector<int64_t> shapes;
  initGraph->rootAxis->forEachAxisTopDown([&](const AxisPtr axis) {
    const bool isDynamic =
      axis->axisType.find(mlir::autotiling::Axis::AxisLabel::kDynamic) != axis->axisType.end();
    if (isDynamic) {
      shapes.push_back(-1);
      return;
    }
    if (axis->range.second > 1) {
      shapes.push_back(axis->range.second);
    }
  });

  auto sketch = SketchAnalysis(shapes);
  if (sketch == Sketch::kAllStatic) {
    return;
  }

  const bool staticMode = (initGraph->tilingMode == "static");
  const bool useVariableTile = !staticMode && sketch == Sketch::kTwoDynamicInner;
  if (useVariableTile) {
    DoVariableTile(initGraph, sketch);
  } else {
    DoConstTile(initGraph, sketch);
  }
}
// Returns the indices of `arr` ordered by descending axis extent (`range.second`).
// The extent is kept as a 64-bit key so that an extent wider than `int` is not
// truncated before the comparison. The relative order of axes with equal
// extents is unspecified (std::sort is not stable), as before.
std::vector<int> rankArray(const std::vector<AxisPtr> &arr) {
  std::vector<std::pair<int64_t, int>> lengthAndIndex;
  lengthAndIndex.reserve(arr.size());
  for (size_t i = 0; i < arr.size(); ++i) {
    (void)lengthAndIndex.emplace_back(static_cast<int64_t>(arr[i]->range.second), static_cast<int>(i));
  }
  std::sort(lengthAndIndex.begin(), lengthAndIndex.end(),
            [](const std::pair<int64_t, int> &lhs, const std::pair<int64_t, int> &rhs) {
              return lhs.first > rhs.first;
            });
  std::vector<int> ranks;
  ranks.reserve(lengthAndIndex.size());
  for (const auto &entry : lengthAndIndex) {
    (void)ranks.emplace_back(entry.second);
  }
  return ranks;
}
// Maps the innermost write axis and the innermost read axis of a transpose onto GPU
// block/grid/sequential resources, and reorders the graph's axes by priority.
// Nothing is done unless a Load and a Store node exist and the chosen Store has at least
// `minRankForTranspose` loops.
void TransposeStrategy::AddGpuConstraint(GpuModelGraphPtr gpuGraph) {
std::deque<AxisPtr> sortByLoadAxes;
NodePtr transposeRead;
NodePtr transposeWrite;
// Pick the Load and the Store with the deepest loop nest (the first one wins on ties).
for (auto node : gpuGraph->nodes_) {
auto isHigherRank = [&](const NodePtr target) -> bool {
return (target == nullptr || target->loopNest_.size() < node->loopNest_.size());
};
if (node->opType == "Load" && isHigherRank(transposeRead)) {
transposeRead = node;
}
if (node->opType == "Store" && isHigherRank(transposeWrite)) {
transposeWrite = node;
}
}
if (!transposeRead || !transposeWrite || transposeWrite->loopNest_.size() < minRankForTranspose) {
return;
}
// Start from the Store's loop nest in reverse order, then sort by descending priority.
sortByLoadAxes.assign(transposeWrite->loopNest_.rbegin(), transposeWrite->loopNest_.rend());
std::sort(sortByLoadAxes.begin(), sortByLoadAxes.end(),
[](const AxisPtr a1, const AxisPtr a2) { return a1->priority > a2->priority; });
auto warpSize = GpuInfo::getInstance(gpuGraph->hardware).getWarpSizes();
// NOTE(review): back() assumes the Load's loop nest is non-empty; only the Store's rank is checked above.
auto innerMostReadAxis = transposeRead->loopNest_.back();
auto innerMostWriteAxis = transposeWrite->loopNest_.back();
// Maps `axis` to min(warpSize, extent) threads and distributes the remaining
// ceil(extent / blockSize) iterations over sequential and/or grid resources.
auto allocResource = [&, this](const AxisPtr axis) {
auto blockSize = std::min<int64_t>(warpSize, axis->range.second);
if (!gpuGraph->gpuBlock.canApply(blockSize)) {
return;
}
auto blockcfg = gpuGraph->gpuBlock.alloc(axis, blockSize);
blockcfg->index = ConfigPos::kInner;
blockcfg->mapDim = static_cast<int>(gpuGraph->gpuBlock.currAllocDim()) - 1;
auto outerSize = (axis->range.second - 1) / blockSize + 1;
if (outerSize > maxExpectSeqPerAxis && outerSize % maxExpectSeqPerAxis == 0 &&
gpuGraph->gpuGrid.canApply(outerSize / maxExpectSeqPerAxis)) {
// Three-level split: grid (outer) x sequential (middle) x block (inner).
auto innerTile = axis->tryGetConfig(1);
if (!innerTile) {
axis->doExtraTile();
innerTile = axis->tryGetConfig(1);
}
// NOTE(review): innerTile/outerTile are dereferenced without a null check — confirm tryGetConfig cannot fail here.
innerTile->value = static_cast<int>(blockSize);
auto outerTile = axis->tryGetConfig(0);
outerTile->value = innerTile->value * maxExpectSeqPerAxis;
outerSize = outerSize / maxExpectSeqPerAxis;
auto seqCfg = std::make_shared<GpuSeq>(maxExpectSeqPerAxis);
seqCfg->index = ConfigPos::kMiddle;
axis->configs[seqCfg->type].push_back(seqCfg);
auto gridcfg = gpuGraph->gpuGrid.alloc(axis, outerSize);
gridcfg->index = ConfigPos::kOuter;
} else {
// Two-level split: the outer part is sequential when small enough, otherwise mapped to the grid.
auto tile = axis->tryGetConfig(0);
tile->value = static_cast<int>(blockSize);
if (outerSize <= maxExpectSeq) {
auto seqCfg = std::make_shared<GpuSeq>(outerSize);
seqCfg->index = ConfigPos::kOuter;
axis->configs[seqCfg->type].push_back(seqCfg);
} else {
// NOTE(review): this grid allocation is not guarded by gpuGrid.canApply — confirm alloc tolerates that.
auto gridcfg = gpuGraph->gpuGrid.alloc(axis, outerSize);
gridcfg->index = ConfigPos::kOuter;
}
}
};
// NOTE(review): if both innermost axes are the same object, it is processed twice — confirm this is intended.
allocResource(innerMostWriteAxis);
allocResource(innerMostReadAxis);
gpuGraph->sortedAxes = sortByLoadAxes;
gpuGraph->updateMaxRankTensor(transposeWrite);
}
// Looks for the smallest sequential factor in [minExpectSeq, maxExpectSeq] that
// divides the extent of `a`, leaves at least a warp of threads and fits the
// remaining block resource. On success the axis is split into
// sequential (outer) x block (inner) and true is returned.
bool BroadcastStrategy::searchForSmallShape(const GpuModelGraphPtr gpuGraph, const AxisPtr a) {
  auto warpSize = GpuInfo::getInstance(gpuGraph->hardware).getWarpSizes();
  for (auto expectSeq = minExpectSeq; expectSeq <= maxExpectSeq; ++expectSeq) {
    const bool dividesEvenly = (a->range.second % expectSeq == 0);
    if (!dividesEvenly || a->range.second / expectSeq < warpSize) {
      continue;
    }
    auto innerSize = a->range.second / expectSeq;
    if (!gpuGraph->gpuBlock.canApply(innerSize)) {
      continue;
    }

    // Inner part goes to threads.
    auto blockcfg = gpuGraph->gpuBlock.alloc(a, innerSize);
    blockcfg->index = ConfigPos::kInner;
    auto tile = a->tryGetConfig(0);
    tile->value = static_cast<int>(innerSize);

    // Outer part is executed sequentially.
    auto seqCfg = std::make_shared<GpuSeq>(expectSeq);
    seqCfg->index = kOuter;
    a->configs[seqCfg->type].push_back(seqCfg);
    return true;
  }
  return false;
}
// Looks for the smallest sequential factor in [minExpectSeq, maxExpectSeq] such
// that `warpSize * factor` divides the extent of `a` and the quotient fits the
// grid. On success the axis is split into
// grid (outer) x sequential (middle) x one warp of threads (inner) and true is returned.
bool BroadcastStrategy::searchForLargeShape(const GpuModelGraphPtr gpuGraph, const AxisPtr a) {
  auto warpSize = GpuInfo::getInstance(gpuGraph->hardware).getWarpSizes();
  auto innerSize = warpSize;
  for (auto expectSeq = minExpectSeq; expectSeq <= maxExpectSeq; ++expectSeq) {
    auto middleSize = warpSize * expectSeq;
    if (a->range.second % middleSize != 0) {
      continue;
    }
    auto outerSize = a->range.second / middleSize;
    if (!gpuGraph->gpuGrid.canApply(outerSize)) {
      continue;
    }

    // Innermost warp-sized part goes to threads.
    auto blockcfg = gpuGraph->gpuBlock.alloc(a, innerSize);
    blockcfg->index = ConfigPos::kInner;

    // Two tile levels: level 1 is the warp, level 0 is warp * sequential factor.
    a->doExtraTile();
    auto warpTile = a->tryGetConfig(1);
    warpTile->value = innerSize;
    auto seqTile = a->tryGetConfig(0);
    seqTile->value = middleSize;

    auto seqCfg = std::make_shared<GpuSeq>(expectSeq);
    seqCfg->index = kMiddle;
    a->configs[seqCfg->type].push_back(seqCfg);

    // Whatever is left is distributed over the grid.
    auto gridcfg = gpuGraph->gpuGrid.alloc(a, outerSize);
    gridcfg->index = ConfigPos::kOuter;
    return true;
  }
  return false;
}
// Returns the Load/Store node with the shallowest non-empty loop nest, or
// nullptr when the graph has no such node. The first node wins on ties.
NodePtr BroadcastStrategy::findMinRankNode(const GpuModelGraphPtr gpuGraph) {
  NodePtr shallowest = nullptr;
  for (auto node : gpuGraph->nodes()) {
    const bool isMemoryAccess = (node->opType == "Load" || node->opType == "Store");
    if (node->loopNest_.empty() || !isMemoryAccess) {
      continue;
    }
    const bool isBetter = (shallowest == nullptr) || (shallowest->loopNest_.size() > node->loopNest_.size());
    if (isBetter) {
      shallowest = node;
    }
  }
  return shallowest;
}
// Estimates how many sequential iterations the innermost read axis should run.
// The largest divisor of `extent / maxExpectSeq` that fits the total block
// budget is taken as the thread count; the result is the ceiling of
// `threads * maxExpectSeq * blockWasteCoef / extent`, capped at maxExpectSeq.
int BroadcastStrategy::computeExpectedSeq(const GpuModelGraphPtr gpuGraph, const AxisPtr innerMostReadAxis) {
  auto blockBudget = gpuGraph->gpuBlock.totalAvailableSize;
  assert(maxExpectSeq > 0 && innerMostReadAxis->range.second > 0);
  auto threadsUsed = StrategyHelper::getLargestDivisor(blockBudget, innerMostReadAxis->range.second / maxExpectSeq);
  auto seqAfterWaste = (threadsUsed * maxExpectSeq * blockWasteCoef - 1) / innerMostReadAxis->range.second + 1;
  return std::min<int>(maxExpectSeq, seqAfterWaste);
}
// Walks the loop nest of `maxRankNode` from the innermost axis outwards and
// tries to place a sequential split on one axis. The walk stops once a split
// has been placed (expectSeq drops to 1) or the accumulated extent reaches the
// proposed block size or the warp size.
void BroadcastStrategy::searchSeqAxisFromInnerToOuter(const GpuModelGraphPtr gpuGraph, const NodePtr maxRankNode,
                                                      int expectSeq) {
  auto warpSize = GpuInfo::getInstance(gpuGraph->hardware).getWarpSizes();
  int currBlocks = 1;
  for (auto it = maxRankNode->loopNest_.rbegin(); it != maxRankNode->loopNest_.rend(); ++it) {
    auto a = *it;
    if (expectSeq <= 1 || currBlocks >= proposedBlock || currBlocks >= warpSize) {
      break;
    }
    // Grid size that would remain if this axis and the inner ones were mapped to threads.
    auto currGrids = gpuGraph->problemSize / currBlocks / a->range.second;
    bool placed = false;
    if ((currGrids * gridWasteCoef >= proposedGrid) && gpuGraph->gpuBlock.canApply(a->range.second / maxExpectSeq)) {
      placed = searchForSmallShape(gpuGraph, a);
    } else if (!gpuGraph->isDynamicShape && a->range.second % warpSize == 0 && gpuGraph->gpuBlock.canApply(warpSize)) {
      placed = searchForLargeShape(gpuGraph, a);
    }
    if (placed) {
      expectSeq = 1;
    }
    currBlocks *= static_cast<int>(a->range.second);
  }
}
// Adds a sequential split for broadcast graphs. The strategy applies only when
// the max-rank tensor has more than one loop and some Load/Store node has a
// different (smaller) rank, i.e. the ranks of the accesses differ.
void BroadcastStrategy::AddGpuConstraint(GpuModelGraphPtr gpuGraph) {
  NodePtr maxRankNode = gpuGraph->getMaxRankTensor();
  const bool hasMultiLoopTensor = maxRankNode && maxRankNode->loopNest_.size() > 1;
  if (!hasMultiLoopTensor) {
    return;
  }
  NodePtr minRankNode = findMinRankNode(gpuGraph);
  const bool ranksDiffer = minRankNode && minRankNode->loopNest_.size() != maxRankNode->loopNest_.size();
  if (!ranksDiffer) {
    return;
  }
  int expectSeq = computeExpectedSeq(gpuGraph, maxRankNode->loopNest_.back());
  std::tie(proposedGrid, proposedBlock) =
    StrategyHelper::getProposalParallelSize(gpuGraph->problemSize, gpuGraph->hardware);
  searchSeqAxisFromInnerToOuter(gpuGraph, maxRankNode, expectSeq);
}
// Stores the grid/block sizes proposed by StrategyHelper for the graph's
// problem size and hardware into `proposedGrid` and `proposedBlock`.
void ParallelStrategy::InitProposalResource(const GpuModelGraphPtr gpuGraph) {
  auto proposal = StrategyHelper::getProposalParallelSize(gpuGraph->problemSize, gpuGraph->hardware);
  std::tie(proposedGrid, proposedBlock) = proposal;
}
// Tries to map `axis` onto GPU threads (block resource).
// Returns true when the axis was not mapped to a block before and the chosen size could
// be allocated. May set `currHasMinMax` when a size that does not divide the extent is chosen.
bool ParallelStrategy::tryMapBlock(const GpuModelGraphPtr gpuGraph, const AxisPtr axis) {
auto loopExtent = axis->getRestExtent();
// Upper bound for this axis: what is left of the proposal, limited by the free block resource.
int64_t blockLimit = std::min<int64_t>(proposedBlock / gpuGraph->gpuBlock.currSize, gpuGraph->gpuBlock.rest());
int64_t largestForBlock = StrategyHelper::getLargestDivisor(blockLimit, loopExtent);
auto warpSize = GpuInfo::getInstance(gpuGraph->hardware).getWarpSizes();
// The branches below are mutually exclusive and evaluated in order:
//  1. reduction axis with no usable divisor: take all remaining threads;
//  2. axis already partially tiled (rest extent < full extent): take the rest extent;
//  3. divisor wastes too many threads (blockWasteCoef): use min(limit, extent) instead;
//  4. innermost axis whose divisor is not a warp multiple (blockLimitCoef): same replacement.
if (largestForBlock == 1 && axis->axisType.find(Axis::AxisLabel::kReduction) != axis->axisType.end()) {
largestForBlock = gpuGraph->gpuBlock.totalAvailableSize / gpuGraph->gpuBlock.currSize;
} else if (loopExtent < axis->range.second) {
largestForBlock = loopExtent;
} else if (largestForBlock * blockWasteCoef <= blockLimit &&
gpuGraph->gpuBlock.canApply(std::min<int64_t>(blockLimit, loopExtent))) {
largestForBlock = std::min<int64_t>(blockLimit, loopExtent);
currHasMinMax = true;
} else if (axis->isInnerMost && largestForBlock % warpSize != 0 && largestForBlock * blockLimitCoef <= blockLimit &&
gpuGraph->gpuBlock.canApply(std::min<int64_t>(blockLimit, loopExtent))) {
largestForBlock = std::min<int64_t>(blockLimit, loopExtent);
currHasMinMax = true;
}
if (!gpuGraph->gpuBlock.seen(axis) && gpuGraph->gpuBlock.canApply(largestForBlock)) {
(void)gpuGraph->gpuBlock.alloc(axis, largestForBlock);
if (currHasMinMax) {
// Fix the tile value directly instead of adding a range constraint.
auto tile = axis->tryGetConfig(0);
tile->value = static_cast<int>(largestForBlock);
} else {
auto consTile = Constraint(1, static_cast<int>(largestForBlock), 1);
axis->tryAddConstraint(0, consTile);
}
return true;
}
return false;
}
// Tries to map `axis` onto GPU blocks (grid resource).
// Returns true when the axis was not mapped to the grid before, the chosen size could be
// allocated and grid mapping is not disabled for this axis.
bool ParallelStrategy::tryMapGrid(const GpuModelGraphPtr gpuGraph, const AxisPtr axis) {
// Grid mapping is disabled only for reduction axes when the global config
// kEnableAtomicAdd exists and is false.
// NOTE(review): the dyn_cast result is used without a null check — confirm the config is always a BoolAttr.
auto DisableGridMapping = [&gpuGraph](const AxisPtr &axis) -> bool {
if (gpuGraph->globalConfigs.find(akg::utils::kEnableAtomicAdd) == gpuGraph->globalConfigs.end()) {
return false;
}
if (axis->axisType.find(mlir::autotiling::Axis::AxisLabel::kReduction) == axis->axisType.end()) {
return false;
}
return !dyn_cast<BoolAttr>(gpuGraph->globalConfigs[akg::utils::kEnableAtomicAdd]).getValue();
};
auto loopExtent = axis->getRestExtent();
// Upper bound for this axis: what is left of the proposal, limited by the free grid resource.
int64_t gridLimit = std::min<int64_t>(proposedGrid / gpuGraph->gpuGrid.currSize, gpuGraph->gpuGrid.rest());
int64_t largestForGrid = StrategyHelper::getLargestDivisor(gridLimit, loopExtent);
// An axis that is already partially tiled hands its whole rest extent to the grid.
if (loopExtent < axis->range.second) {
largestForGrid = loopExtent;
}
// If the chosen size wastes too much of the grid (gridWasteCoef) and no min/max was
// introduced for this axis yet, use min(free grid, extent) instead.
if (!currHasMinMax && largestForGrid * gridWasteCoef < gpuGraph->gpuGrid.rest() &&
gpuGraph->gpuGrid.canApply(std::min<int64_t>(gpuGraph->gpuGrid.rest(), loopExtent))) {
largestForGrid = std::min<int64_t>(gpuGraph->gpuGrid.rest(), loopExtent);
currHasMinMax = true;
}
if (!gpuGraph->gpuGrid.seen(axis) && gpuGraph->gpuGrid.canApply(largestForGrid) && !DisableGridMapping(axis)) {
(void)gpuGraph->gpuGrid.alloc(axis, largestForGrid);
if (largestForGrid < loopExtent && loopExtent < axis->range.second) {
axis->doExtraTile();
auto consTile = Constraint({static_cast<int>(largestForGrid)});
axis->tryAddConstraint(1, consTile);
} else {
// NOTE(review): divides by largestForGrid — confirm getLargestDivisor never returns 0.
auto consTile = Constraint(axis->range.second / largestForGrid, axis->range.second, 1);
axis->tryAddConstraint(0, consTile);
}
return true;
}
return false;
}
// Maps every sorted axis with a rest extent other than 1 to block and then
// grid resources. For dynamic-shape graphs an axis that was mapped to a block
// is not mapped to the grid as well (and its min/max flag is not merged into
// the graph in that case).
void ParallelStrategy::AddGpuConstraint(GpuModelGraphPtr gpuGraph) {
  InitProposalResource(gpuGraph);
  for (auto axis : gpuGraph->sortedAxes) {
    currHasMinMax = false;
    if (axis->getRestExtent() == 1) {
      continue;
    }
    const bool blockMapped = tryMapBlock(gpuGraph, axis);
    const bool skipGrid = blockMapped && gpuGraph->isDynamicShape;
    if (!skipGrid) {
      (void)tryMapGrid(gpuGraph, axis);
      gpuGraph->hasMinMax |= currHasMinMax;
    }
  }
}
// Collects the axes relevant to this strategy, lets GpuTemplateSolver compute a
// schedule for reduction ops and records the resulting switches
// (parallel reduction, atomic add, reorder pass) in the graph's global configs.
void ReduceStrategy::AddGpuConstraint(GpuModelGraphPtr gpuGraph) {
  OpBuilder builder(gpuGraph->funcOp);
  std::vector<AxisPtr> axes;
  gpuGraph->rootAxis->forEachAxisTopDown([this, &axes, &gpuGraph](const AxisPtr a) {
    if (this->IsRelevant(a, gpuGraph)) {
      axes.push_back(a);
    }
  });
  gpuGraph->hasMinMax = true;
  bool enableParallelReduction = true;
  // Atomic reduction is requested only by a boolean attribute set to true.
  // getAttrOfType yields a null attribute when the attribute is missing or is
  // not a BoolAttr, so a wrongly typed attribute is treated as "disabled"
  // instead of dereferencing a failed dyn_cast.
  bool enableAtomicReduction = false;
  if (auto atomicAttr = gpuGraph->funcOp->getAttrOfType<BoolAttr>(akg::utils::kEnableAtomicAdd)) {
    enableAtomicReduction = atomicAttr.getValue();
  }
  bool applyReorderPass = true;
  // The solver receives the three flags and may update them.
  GpuTemplateSolver::SolveScheduleForReductionOps(axes, enableParallelReduction, enableAtomicReduction,
                                                  applyReorderPass);
  gpuGraph->globalConfigs[akg::utils::kEnableParallelReduce] = builder.getBoolAttr(enableParallelReduction);
  gpuGraph->globalConfigs[akg::utils::kEnableAtomicAdd] = builder.getBoolAttr(enableAtomicReduction);
  gpuGraph->globalConfigs[akg::utils::kApplyReorderPass] = builder.getBoolAttr(applyReorderPass);
}
// Scans the Load/Store nodes that have a non-empty loop nest and
//  * updates `minRankNode` to the node with the shallowest loop nest, and
//  * returns the widest element bit width found, divided by `cpuGraph->bitUnit`.
// Nodes without a data type still take part in the rank search.
int64_t BroadcastStrategy::getMaxByte(const CpuModelGraphPtr cpuGraph, NodePtr &minRankNode) {
  int64_t widestBits = 0;
  for (auto node : cpuGraph->nodes()) {
    const bool isMemoryAccess = (node->opType == "Load" || node->opType == "Store");
    if (node->loopNest_.empty() || !isMemoryAccess) {
      continue;
    }
    if (minRankNode == nullptr || minRankNode->loopNest_.size() > node->loopNest_.size()) {
      minRankNode = node;
    }
    if (node->dataType) {
      const auto nodeBits = static_cast<int64_t>(node->dataType.getIntOrFloatBitWidth());
      widestBits = std::max(widestBits, nodeBits);
    }
  }
  widestBits /= cpuGraph->bitUnit;
  return widestBits;
}
// Estimates how many elements of the innermost axis fit into the L1 cache for a broadcast
// graph: l1Cache / (bytes touched per element by broadcast and non-broadcast accesses + 1).
// Returns -1 when no Load/Store node has a rank different from `maxloopNest`.
// Side effect: when the function already carries kOperatorTypeStr, it is reset to "Broadcast".
int BroadcastStrategy::getMaxBroadcastAxes(const CpuModelGraphPtr cpuGraph, std::vector<AxisPtr> maxloopNest) {
NodePtr minRankNode = nullptr;
int64_t maxByte = getMaxByte(cpuGraph, minRankNode);
if (!minRankNode || minRankNode->loopNest_.size() == maxloopNest.size()) {
return -1;
}
if (cpuGraph->funcOp->hasAttr(kOperatorTypeStr)) {
OpBuilder builder(cpuGraph->funcOp);
(void)cpuGraph->funcOp->removeAttr(kOperatorTypeStr);
Attribute opType = builder.getStringAttr("Broadcast");
cpuGraph->funcOp->setAttr(kOperatorTypeStr, opType);
}
int broadcastOpNum = 0;
int nonBroadcastOpNum = 0;
int estimatedDifferenceNum = 0;
std::vector<AxisPtr> minloopNest = minRankNode->loopNest_;
// Count accesses at full rank (broadcast side) and at minimum rank (non-broadcast side).
for (auto node : cpuGraph->nodes()) {
if (node->loopNest_.empty() || (node->opType != "Load" && node->opType != "Store")) {
continue;
}
if (maxloopNest.size() == node->loopNest_.size()) {
++broadcastOpNum;
}
if (minloopNest.size() == node->loopNest_.size()) {
if (node->opType == "Store") {
++estimatedDifferenceNum;
}
++nonBroadcastOpNum;
}
}
// Stores at minimum rank are moved from the non-broadcast count to the broadcast count.
nonBroadcastOpNum -= estimatedDifferenceNum;
broadcastOpNum += estimatedDifferenceNum;
// Product of the extents of the axes that exist only in the max-rank loop nest.
int64_t broadcastSize = 1;
for (int i = static_cast<int>(maxloopNest.size()) - 1; i >= 0; --i) {
auto axis = maxloopNest[static_cast<unsigned>(i)];
int64_t axisSize = axis->range.second;
auto iter = std::find(minloopNest.begin(), minloopNest.end(), axis);
if (iter == minloopNest.end()) {
broadcastSize *= axisSize;
}
}
int64_t nonBroadcastBytesSize = nonBroadcastOpNum * maxByte;
int64_t broadcastBytesSize = broadcastSize * broadcastOpNum * maxByte;
// The +1 keeps the divisor non-zero when both byte sizes are 0.
return static_cast<int>(cpuGraph->l1Cache / (broadcastBytesSize + nonBroadcastBytesSize + 1));
}
// Adds a tile level for CPU broadcast graphs: every axis of the max-rank tensor gets tile
// size 1, except the first innermost-flagged axis met while walking the nest backwards,
// which gets the largest divisor of its extent bounded by the cache estimate.
// Does nothing when the max-rank tensor has at most one loop or when its last axis is
// already flagged innermost.
void BroadcastStrategy::AddCpuConstraint(CpuModelGraphPtr cpuGraph) {
NodePtr maxRankNode = cpuGraph->getMaxRankTensor();
if (!maxRankNode || maxRankNode->loopNest_.size() <= 1) {
return;
}
std::vector<AxisPtr> maxloopNest = maxRankNode->loopNest_;
int maxSize = static_cast<int>(maxloopNest.size());
if (maxloopNest[maxSize - 1]->isInnerMost) {
return;
}
// NOTE(review): getMaxBroadcastAxes returns -1 for non-broadcast graphs and that value is
// passed on to getLargestDivisor below — confirm this case is handled there.
int maxBroadcastAxes = getMaxBroadcastAxes(cpuGraph, maxloopNest);
bool isNewBand = true;
int pos = cpuGraph->tileNum;
for (int i = static_cast<int>(maxloopNest.size()) - 1; i >= 0; --i) {
auto axis = maxloopNest[i];
// A previous strategy already created a tile level, so this one needs an extra level.
if (pos) {
axis->doExtraTile();
}
int64_t tileSize = 1;
int64_t axisSize = axis->range.second;
if (axis->isInnerMost && isNewBand) {
tileSize = StrategyHelper::getLargestDivisor(maxBroadcastAxes, axisSize);
isNewBand = false;
}
tileSize = std::min(axisSize, tileSize);
// NOTE(review): the constraint is always added at level 0 even when `pos` is non-zero — confirm intended.
axis->tryAddConstraint(0, Constraint({static_cast<int>(tileSize)}), kTileCfg);
}
++cpuGraph->tileNum;
}
// Adds the vectorization tile level for CPU graphs.
// The vector width (in elements) is derived from the instruction set of
// `cpuGraph->feature` and the widest element type in the graph. In every band
// the first innermost axis is labelled kVectorization and tiled as follows:
//  * dynamic axis: the vector width;
//  * static axis divisible by the vector width: 1;
//  * static axis longer than the vector width: the largest multiple of the
//    vector width not exceeding the extent;
//  * otherwise: min(extent, vector width).
// All other axes get tile size 1 at this level.
void UnrollStrategy::AddCpuConstraint(CpuModelGraphPtr cpuGraph) {
  auto iter = cpuInstructionSetMap.find(cpuGraph->feature);
  if (iter == cpuInstructionSetMap.end()) {
    llvm::errs() << "The instruction set supported by the cpu only includes "
                    "sse, avx, avx2, avx512 and neon."
                 << "\n";
    return;
  }
  int64_t instructionSetBit = iter->second;
  int64_t vectorSize = kVectorize128Bit;
  for (auto node : cpuGraph->nodes_) {
    if (!node->dataType) {
      continue;
    }
    int64_t elementBit = static_cast<int64_t>(node->dataType.getIntOrFloatBitWidth());
    vectorSize = std::min(vectorSize, instructionSetBit / elementBit);
  }
  // An element wider than the vector register would give a width of 0 and a
  // division by zero in the modulo below; fall back to scalar width.
  vectorSize = std::max<int64_t>(vectorSize, 1);

  int pos = cpuGraph->tileNum;
  for (auto bandRoot : cpuGraph->rootAxis->children) {
    std::deque<AxisPtr> q = {bandRoot};
    bandRoot->forEachAxisTopDown([&q](const AxisPtr axis) { q.push_front(axis); });
    bool isNewBand = true;
    for (auto axis : q) {
      // A previous strategy already created a tile level, so this one needs an extra level.
      if (pos) {
        axis->doExtraTile();
      }
      if (!axis->isInnerMost || !isNewBand) {
        axis->tryAddConstraint(pos, Constraint({1}), kTileCfg);
        continue;
      }
      if (axis->axisType.find(Axis::AxisLabel::kDynamic) != axis->axisType.end()) {
        axis->tryAddConstraint(pos, Constraint({static_cast<int32_t>(vectorSize)}), kTileCfg);
      } else {
        int64_t axisSize = axis->range.second;
        // Per-axis tile: `vectorSize` itself stays untouched so that the
        // result for one band does not change the width used for later bands.
        int64_t axisVectorTile = vectorSize;
        if (axisSize % vectorSize == 0) {
          axisVectorTile = 1;
        } else if (axisSize > vectorSize) {
          axisVectorTile = axisSize - axisSize % vectorSize;
        }
        axisVectorTile = std::min(axisSize, axisVectorTile);
        axis->tryAddConstraint(pos, Constraint({static_cast<int>(axisVectorTile)}), kTileCfg);
      }
      (void)axis->axisType.insert(mlir::autotiling::Axis::AxisLabel::kVectorization);
      isNewBand = false;
    }
  }
  ++cpuGraph->tileNum;
}
// Chooses the multi-core tile of the outermost axis (axisIdx == 0) of every band.
// The number of parallel chunks depends on the band's total data size relative to
// MIN_EXEC_NUM_PER_THREAD and BEST_PARALLEL_NUM; level 0 receives the parallel tile and
// level 1 the unroll/vector tile.
void ParallelStrategy::AddCpuConstraint(CpuModelGraphPtr cpuGraph) {
for (auto bandRoot : cpuGraph->rootAxis->children) {
// dataSize starts from the band root's extent and is multiplied by the extent of every axis visited.
// NOTE(review): if forEachAxisTopDown also visits bandRoot itself, its extent is counted twice — confirm.
int64_t dataSize = bandRoot->range.second;
std::deque<AxisPtr> q = {bandRoot};
bandRoot->forEachAxisTopDown([&q, &dataSize](const AxisPtr axis) {
q.push_front(axis);
dataSize *= axis->range.second;
});
for (auto axis : q) {
if (axis->axisIdx != 0) {
continue;
}
(void)axis->axisType.insert(mlir::autotiling::Axis::AxisLabel::kMultiCore);
int64_t axisSize = axis->range.second;
int parallelNum = BEST_PARALLEL_NUM;
int unrollTileValue = 1;
bool isUnroll = axis->axisType.count(mlir::autotiling::Axis::AxisLabel::kVectorization);
// When the axis is also vectorized, the divisibility search below runs on the
// vector tile value instead of the full extent.
// NOTE(review): unrollConfig and getValidCandidates()[0] are used unchecked — confirm they cannot be null/empty.
if (isUnroll) {
auto unrollConfig = axis->tryGetConfig(1, kTileCfg);
unrollConfig->mergeConstraints();
unrollTileValue = unrollConfig->getValidCandidates()[0];
axisSize = unrollTileValue;
}
int evaluateNum = static_cast<int>(dataSize) / MIN_EXEC_NUM_PER_THREAD;
if (evaluateNum >= BEST_PARALLEL_NUM) {
// Enough work: use the best parallel number, limited by the axis extent.
parallelNum = std::min(axis->range.second, static_cast<int64_t>(BEST_PARALLEL_NUM));
} else if (evaluateNum > 1) {
// Moderate work: step down until the parallel number divides axisSize.
while (parallelNum > 0 && axisSize % parallelNum != 0) {
parallelNum -= PARALLEL_DECREASE_VALUE;
}
} else {
parallelNum = 1;
}
if (parallelNum <= 0) {
parallelNum = evaluateNum;
}
// NOTE(review): parallelNum is used as a divisor; it is 0 when range.second is 0 in the first branch — confirm extents are positive.
int paralleltileValue = static_cast<int>(axis->range.second) / parallelNum;
// A parallel number that does not divide the extent disables the split (tile = whole extent).
if (axis->range.second % parallelNum != 0) {
paralleltileValue = static_cast<int>(axis->range.second);
}
if (paralleltileValue < MIN_UNROLL_NUM && isUnroll) {
paralleltileValue = std::min(static_cast<int>(axis->range.second), MIN_UNROLL_NUM);
unrollTileValue = paralleltileValue;
}
// The parallel tile is never smaller than the unroll tile nested inside it.
paralleltileValue = std::max(paralleltileValue, unrollTileValue);
axis->tryAddConstraint(0, Constraint({paralleltileValue}));
axis->tryAddConstraint(1, Constraint({unrollTileValue}));
}
}
}
// Returns the outer tile size for `axis` when its extent is spread over
// `blockNumber` blocks: ceil(extent / blockNumber), raised to 512 when the
// extent itself is at least 512. An axis that is null or has no constant bounds
// yields 512. A `blockNumber` of 0 is treated as 1 so the division is always defined.
unsigned getOuterTileSize(const AxisPtr axis, unsigned blockNumber) {
  constexpr unsigned kMinTileSize = 512;
  if (!axis || !axis->hasConstantBounds()) {
    return kMinTileSize;
  }
  const int64_t upperBound = axis->getConstantUpperBound();
  const int64_t lowerBound = axis->getConstantLowerBound();
  const unsigned extent = static_cast<unsigned>(upperBound - lowerBound);
  const unsigned blocks = (blockNumber == 0) ? 1u : blockNumber;
  // 64-bit arithmetic: `extent + blocks - 1` cannot wrap around, and the
  // quotient never exceeds `extent`, so narrowing back to unsigned is lossless.
  unsigned tileSizePerBlock = static_cast<unsigned>((static_cast<uint64_t>(extent) + blocks - 1) / blocks);
  if (tileSizePerBlock < kMinTileSize && extent >= kMinTileSize) {
    tileSizePerBlock = kMinTileSize;
  }
  return tileSizePerBlock;
}
namespace {
// Tuning constants for the NPU tiling helpers below. The descriptions follow the
// identifiers and the uses visible in this chunk; uses further down the file are not shown.
// Default value of NpuBandContext::targetBlocks.
constexpr unsigned kNpuTargetBlocks = 48;
// UB size in bytes used by computeUbCapacity when the graph reports no positive ubSize.
constexpr int64_t kNpuFallbackUbSizeInBytes = 192 * 1024;
constexpr unsigned kNpuMinInnerTileSize = 512;
// Default element width in bits for the structs below.
constexpr int64_t kDefaultTypeBits = 32;
constexpr int64_t kBitsPerByte = 8;
// UB alignment granularity, in bytes and in bits (used by alignUbElems).
constexpr int64_t kUbAlignBytes = 32;
constexpr int64_t kUbAlignBits = kUbAlignBytes * kBitsPerByte;
// Bytes subtracted from the UB size before capacities are computed (see computeUbCapacity).
constexpr int64_t kUbGuardReserveBytes = 64;
// NOTE(review): the constants below are not used in the visible chunk; their meaning is presumed from the names.
constexpr int64_t kHivmAutoMultiBufferFactor = 2;
constexpr int64_t kGenericHivmWorkspaceBuffers = 1;
constexpr int64_t kCompareHivmWorkspaceBuffers = 3;
constexpr int64_t kHeavyHivmWorkspaceBuffers = 3;
constexpr int64_t kTransposeOrBroadcastFootprintBuffers = 4;
constexpr int64_t kSuffixPreservePointRows = 2;
// Describes one buffer reserved alongside a transpose: an axis order, per-dimension
// alignment values and the element width in bits (defaults to kDefaultTypeBits).
// NOTE(review): field semantics are presumed from the names; the users are outside this chunk.
struct TransposeCoLiveReserve {
SmallVector<size_t, 6> axisOrder;
SmallVector<int32_t, 6> alignDims;
int64_t elementBits{kDefaultTypeBits};
};
// Describes one live vector buffer: an axis order, per-dimension alignment values and
// the element width in bits (defaults to kDefaultTypeBits).
// NOTE(review): field semantics are presumed from the names; the users are outside this chunk.
struct VectorLiveReserve {
SmallVector<size_t, 6> axisOrder;
SmallVector<int32_t, 6> alignDims;
int64_t elementBits{kDefaultTypeBits};
};
// A set of VectorLiveReserve buffers together with a fixed byte amount (defaults to 0).
// NOTE(review): presumably the buffers live at one peak point — confirm against the producer.
struct VectorPeakReserveSet {
SmallVector<VectorLiveReserve, 8> liveBuffers;
int64_t fixedBytes{0};
};
// Source and target axis orders of one transpose plus its element width in bits and a
// flag telling whether the element type is f32 (defaults: kDefaultTypeBits, true).
struct TransposeVectorInfo {
SmallVector<size_t, 6> sourceAxisOrder;
SmallVector<size_t, 6> targetAxisOrder;
int64_t elementBits{kDefaultTypeBits};
bool isF32{true};
};
// Per-band working data for the NPU tiling helpers. Every field has a default, so a
// default-constructed context is fully initialised.
// NOTE(review): the field groups below are described from their names; producers and
// consumers are outside this chunk.
struct NpuBandContext {
// Band identity: index, axes and their extents, and the graph template.
size_t bandIdx{0};
SmallVector<AxisPtr, 4> axes;
SmallVector<int64_t, 4> extents;
GraphTemplate graphTemplate{GraphTemplate::DEFAULT};
// UB capacities in elements; ubCapacityElems and vectorUbCapacityElems are read by
// ubCapacityForPointTile / ubCapacityForVectorTile.
int64_t rawUbElems{1};
int64_t ubCapacityElems{1};
int64_t vectorUbCapacityElems{1};
int64_t vectorUbBitsPerElem{kDefaultTypeBits};
int64_t vectorElementBits{kDefaultTypeBits};
// Transpose-related element type and axis orders.
int64_t transposeElementBits{kDefaultTypeBits};
bool transposeElementIsF32{true};
SmallVector<int64_t, 6> axesAlignUnits;
SmallVector<size_t, 6> transposeSourceAxisOrder;
SmallVector<size_t, 6> transposeTargetAxisOrder;
SmallVector<TransposeVectorInfo, 6> transposeInfos;
// Buffer reservations.
SmallVector<SmallVector<TransposeCoLiveReserve, 4>, 4> transposeCoLiveReserveSets;
SmallVector<VectorPeakReserveSet, 32> vectorPeakReserveSets;
int64_t fixedReserveBytes{0};
// Targets and band properties.
int64_t targetBlocks{kNpuTargetBlocks};
int64_t mathComplexityScore{0};
int64_t smallestTypeBits{kDefaultTypeBits};
bool hasDynamicAxis{false};
bool hasReduction{false};
bool lastAxisIsReduction{false};
};
// Outer and inner tile sizes of one band (presumably one entry per axis — TODO confirm).
struct BandTilePlan {
SmallVector<unsigned, 4> outerTiles;
SmallVector<unsigned, 4> innerTiles;
};
// Forward declarations; the definitions are not part of this chunk.
int64_t computeAxisAlignSize(size_t axisIdx, const NpuBandContext &ctx);
void applyF32LastDimDoubleAlign(const NpuBandContext &ctx, SmallVectorImpl<int64_t> &axisAlignUnits);
// `maxLevelToTile` is taken by non-const reference, so the callee can update it;
// `isFullTileAxis` defaults to false.
static void applyFallbackAxisTiling(const AxisPtr axis, const SmallVector<unsigned, 4> &tileSizes,
unsigned innerTileSize, unsigned blockNumber, size_t &maxLevelToTile,
bool isFullTileAxis = false);
// True when `axis` is a reduction axis: it either carries the kReduction label
// or its loop operation has the kReductionLoopAttr attribute. A null axis is not.
inline bool isReductionAxis(const AxisPtr &axis) {
  if (!axis) {
    return false;
  }
  const bool hasReductionLabel = axis->axisType.count(mlir::autotiling::Axis::AxisLabel::kReduction) > 0;
  if (hasReductionLabel) {
    return true;
  }
  if (Operation *loopOp = axis->getLoopOperation()) {
    return loopOp->hasAttr(kReductionLoopAttr);
  }
  return false;
}
// True when `axis` is non-null and carries the kDynamic label.
inline bool isDynamicAxis(const AxisPtr &axis) {
  if (!axis) {
    return false;
  }
  return axis->axisType.count(mlir::autotiling::Axis::AxisLabel::kDynamic) != 0;
}
// True when `axis` has a loop whose operation carries the kTransposeLoopAttr
// attribute. A null axis or an axis without a loop is not a transpose axis.
inline bool isTransposeAxis(const AxisPtr &axis) {
  const bool hasLoop = axis && axis->loop;
  if (!hasLoop) {
    return false;
  }
  Operation *loopOp = axis->getLoopOperation();
  if (loopOp == nullptr) {
    return false;
  }
  return loopOp->hasAttr(kTransposeLoopAttr);
}
// Ceiling division of `lhs` by `rhs`.
// A non-positive divisor means "do not divide": `lhs` is returned unchanged.
// The result is built from quotient and remainder, so no intermediate sum can
// overflow (the former `lhs + rhs - 1` overflowed for large `lhs`), and the
// rounding is a true ceiling for negative `lhs` as well.
inline int64_t ceilDivInt64(int64_t lhs, int64_t rhs) {
  if (rhs <= 0) {
    return lhs;
  }
  const int64_t quotient = lhs / rhs;
  const int64_t remainder = lhs % rhs;
  // Truncation already rounds towards +inf for negative values; only a
  // positive remainder needs the extra step up.
  return (remainder > 0) ? quotient + 1 : quotient;
}
// Multiplies two positive values, saturating at LLONG_MAX instead of
// overflowing. Returns 0 when either operand is zero or negative.
inline int64_t multiplyAndCap(int64_t lhs, int64_t rhs) {
  const bool bothPositive = (lhs > 0) && (rhs > 0);
  if (!bothPositive) {
    return 0;
  }
  const bool wouldOverflow = lhs > LLONG_MAX / rhs;
  if (wouldOverflow) {
    return LLONG_MAX;
  }
  return lhs * rhs;
}
// Rounds `value` up to the next multiple of `alignment`.
// An alignment of 1 or less leaves the value unchanged.
inline int64_t alignUpInt64(int64_t value, int64_t alignment) {
  const bool needsAlignment = alignment > 1;
  return needsAlignment ? ceilDivInt64(value, alignment) * alignment : value;
}
// Converts a 64-bit value to a tile size, clamping it into [1, UINT_MAX].
inline unsigned saturateToTileValue(int64_t value) {
  constexpr int64_t kSmallestTile = 1;
  constexpr int64_t kLargestTile = UINT_MAX;
  if (value < kSmallestTile) {
    return static_cast<unsigned>(kSmallestTile);
  }
  if (value > kLargestTile) {
    return static_cast<unsigned>(kLargestTile);
  }
  return static_cast<unsigned>(value);
}
int64_t getNodeMathComplexity(const NodePtr &node) {
if (!node) return 0;
if (node->opType == "HeavyElem") return 3;
if (node->opType == "Reduce") return 1;
Operation *op = node->op();
if (!op) return 0;
llvm::StringRef name = op->getName().getStringRef();
if (name.contains("pow") || name.contains("rsqrt") || name.contains("exp") || name.contains("log") ||
name.contains("tanh"))
return 3;
if (name.contains("sqrt") || name.contains("div")) return 2;
return 0;
}
// Result of computeUbCapacity: UB capacity expressed in elements of the smallest data type.
struct UbCapacity {
  // Aligned element count of the whole usable UB.
  int64_t rawUbElems = 1;
  // Aligned element count after dividing the usable UB by the maximum buffer count.
  int64_t ubCapacityElems = 1;
  // Bit width of the smallest element type, at least 1 once computed.
  int64_t smallestTypeBits = kDefaultTypeBits;
};
// Rounds an element count down so that its total size is a multiple of kUbAlignBits.
// Both inputs are clamped to at least 1 and the result is never smaller than 1.
int64_t alignUbElems(int64_t elems, int64_t typeBits) {
  const int64_t bitsPerElem = std::max<int64_t>(typeBits, 1);
  const int64_t elemCount = std::max<int64_t>(elems, 1);
  const int64_t totalBits = elemCount * bitsPerElem;
  // Drop the bits that do not fill a whole alignment unit, then convert back to elements.
  const int64_t alignedBits = (totalBits / kUbAlignBits) * kUbAlignBits;
  const int64_t alignedElems = alignedBits / bitsPerElem;
  return std::max<int64_t>(alignedElems, 1);
}
// Derives the UB capacity figures for a graph.
// The UB size falls back to kNpuFallbackUbSizeInBytes when the graph reports none; a guard reserve
// is subtracted, with kUbAlignBytes as the floor. A null graph yields a default-constructed result.
UbCapacity computeUbCapacity(const NpuModelGraphPtr &npuGraph) {
  UbCapacity capacity;
  if (!npuGraph) {
    return capacity;
  }

  int64_t usableBytes = (npuGraph->ubSize > 0) ? npuGraph->ubSize : kNpuFallbackUbSizeInBytes;
  usableBytes = std::max<int64_t>(usableBytes - kUbGuardReserveBytes, kUbAlignBytes);

  capacity.smallestTypeBits = std::max<int64_t>(static_cast<int64_t>(npuGraph->smallestTypeBits), 1);
  const int64_t bufferCount = std::max<int64_t>(npuGraph->maxBufferCnt, 1);

  // Express the usable UB in elements of the smallest type, then align the whole-UB figure and the
  // per-buffer figure separately.
  const int64_t unalignedElems = usableBytes * kBitsPerByte / capacity.smallestTypeBits;
  capacity.rawUbElems = alignUbElems(unalignedElems, capacity.smallestTypeBits);
  capacity.ubCapacityElems = alignUbElems(unalignedElems / bufferCount, capacity.smallestTypeBits);
  return capacity;
}
// UB capacity (in elements) used for point tiles; never less than 1.
inline int64_t ubCapacityForPointTile(const NpuBandContext &ctx) {
  const int64_t capacity = ctx.ubCapacityElems;
  return capacity < 1 ? int64_t{1} : capacity;
}
// UB capacity (in elements) used for vector tiles; never less than 1.
inline int64_t ubCapacityForVectorTile(const NpuBandContext &ctx) {
  const int64_t capacity = ctx.vectorUbCapacityElems;
  return capacity < 1 ? int64_t{1} : capacity;
}
// Buckets the non-null axes by their band index and orders every bucket by ascending axis index.
std::map<size_t, SmallVector<AxisPtr, 4>> groupAxesByBand(const SmallVector<AxisPtr> &axes) {
  std::map<size_t, SmallVector<AxisPtr, 4>> axesPerBand;
  for (const AxisPtr &candidate : axes) {
    if (!candidate) {
      continue;
    }
    axesPerBand[candidate->bandIdx].push_back(candidate);
  }

  const auto byAxisIdx = [](const AxisPtr &lhs, const AxisPtr &rhs) { return lhs->axisIdx < rhs->axisIdx; };
  for (auto &bandEntry : axesPerBand) {
    auto &bandAxes = bandEntry.second;
    std::sort(bandAxes.begin(), bandAxes.end(), byAxisIdx);
  }
  return axesPerBand;
}
// Scalar facts gathered over the nodes of one band (filled in by collectBandNodeInfo).
struct BandNodeFacts {
  // Sum of the per-node math complexity scores, capped by the collector.
  int64_t mathComplexityScore = 0;
  // Widest element width, in bits, seen on a memref load that counts as a vector load.
  int64_t maxVectorLoadBits = 0;
  // True when at least one such load produces an f32 value.
  bool hasF32VectorLoad = false;
};
// UB footprint summary produced by computeVectorUbFootprint for one band.
struct VectorUbFootprint {
  // Largest number of UB bits needed per element at any point of the band.
  int64_t peakBitsPerElem = kDefaultTypeBits;
  // Widest element width, in bits, among the values touched by the band.
  int64_t vectorElementBits = kDefaultTypeBits;
  // Largest fixed (tile-independent) byte reserve recorded for any node.
  int64_t fixedReserveBytes = 0;
  // Snapshots of the buffers that are live when a transpose-source load is reached.
  SmallVector<SmallVector<TransposeCoLiveReserve, 4>, 4> transposeCoLiveReserveSets;
  // One reserve set per recorded peak state.
  SmallVector<VectorPeakReserveSet, 32> vectorPeakReserveSets;
};
// Mutable liveness bookkeeping used while the nodes of a band are walked in order.
struct VectorUbLiveState {
  // Number of operand uses of each value that have not been visited yet.
  llvm::DenseMap<Value, int64_t> remainingUses;
  // Values that are charged with extra multi-buffer copies.
  llvm::DenseMap<Value, bool> multiBufferedValues;
  // Bits per element of every currently live value.
  llvm::DenseMap<Value, int64_t> liveBits;
  // Axis order recorded for every live value.
  llvm::DenseMap<Value, SmallVector<size_t, 6>> liveAxisOrders;
  // Stride-alignment dimensions recorded for every live value.
  llvm::DenseMap<Value, SmallVector<int32_t, 6>> liveAlignDims;
  // Running sum of liveBits.
  int64_t liveBitsSum = 0;
  // Running sum of the bits of live values that are also multi-buffered.
  int64_t liveMultiBufferExtraBitsSum = 0;
};
// Everything collectBandNodeInfo gathers for one band.
struct BandNodeInfo {
  // Aggregated scalar facts.
  BandNodeFacts facts;
  // Nodes of the band that have an operation attached.
  SmallVector<NodePtr, 32> nodes;
  // Transpose patterns detected in the band.
  SmallVector<TransposeVectorInfo, 6> transposeInfos;
};
// Coarse classification of an operation for UB workspace accounting (see classifyVectorOp).
enum class VectorOpKind {
  None,     // not treated as a vector operation
  Generic,  // ordinary arithmetic / math operation
  Heavy,    // division and selected math operations
  Compare,  // arith.cmpf / arith.cmpi
  Select    // arith.select
};
// Forward declaration: element bit width of `value` (0 when it cannot be determined).
// Defined further down; declared here because findTransposeVectorInfos uses it first.
int64_t getValueElementBits(Value value);
// Returns true when the outermost axis of the node's loop nest belongs to band `bandIdx`.
// Null nodes, empty loop nests and a null outermost axis all yield false.
inline bool isNodeInBand(const NodePtr &node, size_t bandIdx) {
  if (!node || node->loopNest_.empty()) {
    return false;
  }
  const auto &outermostAxis = node->loopNest_.front();
  if (!outermostAxis) {
    return false;
  }
  return outermostAxis->bandIdx == bandIdx;
}
// Finds the position in `axes` of the first axis that wraps the same loop operation as `target`.
// Returns std::nullopt when `target` is null, has no loop, or no axis matches.
std::optional<size_t> findAxisIndex(ArrayRef<AxisPtr> axes, const AxisPtr &target) {
  if (!target || !target->loop) {
    return std::nullopt;
  }
  Operation *const wantedLoop = target->getLoopOperation();
  size_t position = 0;
  for (const AxisPtr &candidate : axes) {
    if (candidate && candidate->getLoopOperation() == wantedLoop) {
      return position;
    }
    ++position;
  }
  return std::nullopt;
}
// Maps the node's loop nest onto positions in `axes`, keeping loop-nest order.
// Loop-nest entries that are not found in `axes` are dropped; a null node gives an empty result.
SmallVector<size_t, 6> getNodeAxisOrder(const NodePtr &node, ArrayRef<AxisPtr> axes) {
  SmallVector<size_t, 6> positions;
  if (!node) {
    return positions;
  }
  for (const AxisPtr &loopAxis : node->loopNest_) {
    const std::optional<size_t> found = findAxisIndex(axes, loopAxis);
    if (found.has_value()) {
      positions.push_back(found.value());
    }
  }
  return positions;
}
// Default stride-alignment dimensions for a buffer with the given axis order: the second-to-last
// dimension when the rank is at least 2, otherwise none.
SmallVector<int32_t, 6> getDefaultStrideAlignDims(ArrayRef<size_t> axisOrder) {
  SmallVector<int32_t, 6> alignDims;
  const size_t rank = axisOrder.size();
  if (rank > 1) {
    alignDims.push_back(static_cast<int32_t>(rank - 2));
  }
  return alignDims;
}
// Returns the entries of `lhs` that also occur in `rhs`, in `lhs` order (duplicates in `lhs` are kept).
SmallVector<size_t, 6> intersectAxisOrder(ArrayRef<size_t> lhs, ArrayRef<size_t> rhs) {
  SmallVector<size_t, 6> shared;
  for (size_t axisIdx : lhs) {
    const bool alsoInRhs = llvm::is_contained(rhs, axisIdx);
    if (!alsoInRhs) {
      continue;
    }
    shared.push_back(axisIdx);
  }
  return shared;
}
// Two transpose descriptions are the same when both axis orders, the element width and the f32 flag agree.
bool isSameTransposeInfo(const TransposeVectorInfo &lhs, const TransposeVectorInfo &rhs) {
  if (!(lhs.sourceAxisOrder == rhs.sourceAxisOrder)) return false;
  if (!(lhs.targetAxisOrder == rhs.targetAxisOrder)) return false;
  if (!(lhs.elementBits == rhs.elementBits)) return false;
  return lhs.isF32 == rhs.isF32;
}
// Translates the index operands of a memref access into positions of `ctx.axes`.
//
// For every index, integer/index casts are stripped first. An index equal to the induction variable
// of a band axis contributes that axis' position. Otherwise, an index that is the block argument of
// an scf.for nested directly under the parent operation of the last band axis' loop is attributed to
// the last band axis. Indices matching neither rule are skipped.
SmallVector<size_t, 6> getMemrefAxisOrder(ValueRange indices, const NpuBandContext &ctx) {
  SmallVector<size_t, 6> order;

  // Parent operation of the innermost band loop; null when the band has no usable last axis.
  Operation *leafParent = nullptr;
  if (!ctx.axes.empty() && ctx.axes.back() && ctx.axes.back()->loop) {
    leafParent = ctx.axes.back()->getLoopOperation()->getParentOp();
  }

  for (Value idx : indices) {
    // Look through arith.index_cast / arith.index_castui chains.
    Value strippedIdx = idx;
    while (true) {
      if (auto castOp = strippedIdx.getDefiningOp<arith::IndexCastOp>()) {
        strippedIdx = castOp.getIn();
        continue;
      }
      if (auto castUIOp = strippedIdx.getDefiningOp<arith::IndexCastUIOp>()) {
        strippedIdx = castUIOp.getIn();
        continue;
      }
      break;
    }

    bool matched = false;
    for (size_t i = 0; i < ctx.axes.size(); ++i) {
      if (ctx.axes[i] && strippedIdx == ctx.axes[i]->getInductionVar()) {
        order.push_back(i);
        matched = true;
        break;
      }
    }
    if (matched) {
      continue;
    }

    auto arg = dyn_cast<BlockArgument>(strippedIdx);
    if (!arg || !leafParent) {
      continue;
    }
    // Block::getParentOp() returns null for a block that is not attached to an operation, and a
    // plain dyn_cast must not be handed a null pointer, hence dyn_cast_or_null.
    auto loop = dyn_cast_or_null<scf::ForOp>(arg.getOwner()->getParentOp());
    if (loop && loop->getParentOp() == leafParent) {
      order.push_back(ctx.axes.size() - 1);
    }
  }
  return order;
}
// Recursive worker for isValueDerivedFrom.
//
// Returns true when `source` is reachable from `value` by following the operands of defining
// operations, or the operands of the terminators of the blocks nested in those operations.
// `visited` holds the values that have already been expanded, so each value is expanded once.
bool isValueDerivedFromImpl(Value value, Value source, llvm::SmallPtrSetImpl<Value> &visited) {
  if (value == source) return true;
  if (!visited.insert(value).second) return false;
  Operation *defOp = value.getDefiningOp();
  if (!defOp) return false;

  const auto derivesFromSource = [&](Value candidate) {
    return isValueDerivedFromImpl(candidate, source, visited);
  };
  if (llvm::any_of(defOp->getOperands(), derivesFromSource)) return true;

  // Results of region-holding operations come from the values handed to the region terminators.
  for (Region &region : defOp->getRegions()) {
    for (Block &block : region) {
      // getTerminator() must not be called on an empty block.
      if (block.empty()) continue;
      Operation *terminator = block.getTerminator();
      if (terminator && llvm::any_of(terminator->getOperands(), derivesFromSource)) return true;
    }
  }
  return false;
}
// Returns true when `value` is `source` itself or is computed, directly or transitively, from it.
bool isValueDerivedFrom(Value value, Value source) {
  llvm::SmallPtrSet<Value, 16> alreadyExpanded;
  const bool derived = isValueDerivedFromImpl(value, source, alreadyExpanded);
  return derived;
}
// Appends `info` to `infos` unless an equivalent entry (per isSameTransposeInfo) is already there.
void pushUniqueTransposeInfo(SmallVectorImpl<TransposeVectorInfo> &infos, TransposeVectorInfo info) {
  for (const TransposeVectorInfo &existing : infos) {
    if (isSameTransposeInfo(existing, info)) {
      return;
    }
  }
  infos.push_back(std::move(info));
}
// Detects transpose patterns in the band described by `ctx`.
//
// All memref loads and stores below the first transpose-marked band loop are paired up. A pair is
// reported when the stored value is derived from the loaded value and the band axes the two accesses
// share appear in a different order, with at least two shared axes and at least one of them marked
// as a transpose axis. The result holds no duplicates and is ordered by descending rank, then by
// descending element width, then lexicographically by source axis order.
SmallVector<TransposeVectorInfo, 6> findTransposeVectorInfos(const NpuBandContext &ctx) {
  SmallVector<TransposeVectorInfo, 6> infos;
  auto tryAddInfo = [&](Operation *loadOp, ArrayRef<size_t> loadOrder, Operation *storeOp,
                        ArrayRef<size_t> storeOrder) {
    SmallVector<size_t, 6> loadCommon = intersectAxisOrder(loadOrder, storeOrder);
    SmallVector<size_t, 6> storeCommon = intersectAxisOrder(storeOrder, loadOrder);
    // Cheap structural checks run first; the data-flow walk below is the expensive part.
    if (loadCommon.size() < 2 || loadCommon.size() != storeCommon.size() || loadCommon == storeCommon ||
        !std::is_permutation(loadCommon.begin(), loadCommon.end(), storeCommon.begin()) ||
        !llvm::any_of(loadCommon, [&](size_t idx) { return isTransposeAxis(ctx.axes[idx]); }))
      return;
    if (!isValueDerivedFrom(storeOp->getOperand(0), loadOp->getResult(0))) return;
    pushUniqueTransposeInfo(
      infos, TransposeVectorInfo{loadCommon, storeCommon, getValueElementBits(loadOp->getResult(0)),
                                 loadOp->getResult(0).getType().isF32()});
  };

  // The scan starts at the first band loop that carries the transpose attribute.
  Operation *scanRoot = nullptr;
  auto transposeAxisIt = llvm::find_if(ctx.axes, [](const AxisPtr &axis) { return isTransposeAxis(axis); });
  if (transposeAxisIt != ctx.axes.end()) {
    scanRoot = (*transposeAxisIt)->getLoopOperation();
  }

  SmallVector<memref::LoadOp, 8> loads;
  SmallVector<memref::StoreOp, 8> stores;
  if (scanRoot) {
    scanRoot->walk([&](Operation *op) {
      if (auto load = dyn_cast<memref::LoadOp>(op)) {
        loads.push_back(load);
      } else if (auto store = dyn_cast<memref::StoreOp>(op)) {
        stores.push_back(store);
      }
    });
  }

  // A store's axis order does not depend on the load it is paired with, so compute it once.
  SmallVector<SmallVector<size_t, 6>, 8> storeOrders;
  storeOrders.reserve(stores.size());
  for (memref::StoreOp store : stores) {
    storeOrders.push_back(getMemrefAxisOrder(store.getIndices(), ctx));
  }

  for (memref::LoadOp load : loads) {
    SmallVector<size_t, 6> loadOrder = getMemrefAxisOrder(load.getIndices(), ctx);
    for (size_t i = 0; i < stores.size(); ++i) {
      tryAddInfo(load, loadOrder, stores[i], storeOrders[i]);
    }
  }

  std::stable_sort(infos.begin(), infos.end(), [](const TransposeVectorInfo &lhs, const TransposeVectorInfo &rhs) {
    if (lhs.sourceAxisOrder.size() != rhs.sourceAxisOrder.size())
      return lhs.sourceAxisOrder.size() > rhs.sourceAxisOrder.size();
    if (lhs.elementBits != rhs.elementBits) return lhs.elementBits > rhs.elementBits;
    return std::lexicographical_compare(lhs.sourceAxisOrder.begin(), lhs.sourceAxisOrder.end(),
                                        rhs.sourceAxisOrder.begin(), rhs.sourceAxisOrder.end());
  });
  return infos;
}
// Bit width of the value's element type: shaped types are unwrapped to their element type first.
// Returns 0 for a null value and for element types that are neither integer nor float.
int64_t getValueElementBits(Value value) {
  if (!value) {
    return 0;
  }
  Type elementType = value.getType();
  if (auto shaped = dyn_cast<ShapedType>(elementType)) {
    elementType = shaped.getElementType();
  }
  return elementType.isIntOrFloat() ? static_cast<int64_t>(elementType.getIntOrFloatBitWidth()) : 0;
}
// Bits per element charged against the UB for `value`; values narrower than one byte are charged 0.
int64_t getValueUbBitsPerElem(Value value) {
  const int64_t bits = getValueElementBits(value);
  if (bits < kBitsPerByte) {
    return 0;
  }
  return bits;
}
// Returns true when at least one operand of `op` has an entry in `liveBits`. A null `op` gives false.
bool hasTrackedVectorOperand(Operation *op, const llvm::DenseMap<Value, int64_t> &liveBits) {
  if (op == nullptr) {
    return false;
  }
  return llvm::any_of(op->getOperands(), [&liveBits](Value operand) { return liveBits.count(operand) != 0; });
}
// Decides whether the result of a Generic/Heavy operation inherits multi-buffering from its input.
// That is the case only when exactly one operand occurrence is a live vector value (at least one
// byte per element) and that value is marked multi-buffered.
bool inheritsMultiBufferFromOperand(Operation *op, VectorOpKind kind, const VectorUbLiveState &state) {
  if (!op) {
    return false;
  }
  const bool eligibleKind = kind == VectorOpKind::Generic || kind == VectorOpKind::Heavy;
  if (!eligibleKind) {
    return false;
  }

  Value soleVectorOperand;
  for (Value operand : op->getOperands()) {
    auto tracked = state.liveBits.find(operand);
    const bool isLiveVector = tracked != state.liveBits.end() && tracked->second >= kBitsPerByte;
    if (!isLiveVector) {
      continue;
    }
    // A second live vector operand occurrence rules out inheritance.
    if (soleVectorOperand) {
      return false;
    }
    soleVectorOperand = operand;
  }

  if (!soleVectorOperand) {
    return false;
  }
  return state.multiBufferedValues.lookup(soleVectorOperand);
}
// Classifies `op` by its operation name.
//  - Compare: single-result arith.cmpf / arith.cmpi
//  - Select:  arith.select
//  - Heavy:   name contains arith.div, math.exp, math.log or math.erf
//  - Generic: name contains one of the listed arith patterns, or starts with "math."
//  - None:    null op, op without results, or anything else
VectorOpKind classifyVectorOp(Operation *op) {
  if (op == nullptr) {
    return VectorOpKind::None;
  }
  const llvm::StringRef opName = op->getName().getStringRef();
  const unsigned numResults = op->getNumResults();

  if (numResults == 1 && (opName == "arith.cmpf" || opName == "arith.cmpi")) {
    return VectorOpKind::Compare;
  }
  if (opName == "arith.select") {
    return VectorOpKind::Select;
  }
  if (numResults == 0) {
    return VectorOpKind::None;
  }

  const auto nameContains = [&opName](llvm::StringRef pattern) { return opName.contains(pattern); };

  // Heavy patterns win over the generic "math." prefix rule below.
  const llvm::StringRef heavyPatterns[] = {"arith.div", "math.exp", "math.log", "math.erf"};
  if (llvm::any_of(heavyPatterns, nameContains)) {
    return VectorOpKind::Heavy;
  }

  const llvm::StringRef genericPatterns[] = {"arith.add", "arith.sub", "arith.mul", "arith.max",
                                             "arith.min", "arith.and", "arith.or",  "arith.xor"};
  const bool isGeneric = llvm::any_of(genericPatterns, nameContains) || opName.starts_with("math.");
  return isGeneric ? VectorOpKind::Generic : VectorOpKind::None;
}
// Sums the UB bits per element of the untracked (not live) operands of a compare or select that has
// at least one tracked operand. For a select the first operand (the condition) is not counted.
// Every other situation yields 0.
int64_t getScalarVbrcBitsPerElem(Operation *op, VectorOpKind kind, const llvm::DenseMap<Value, int64_t> &liveBits) {
  if (!hasTrackedVectorOperand(op, liveBits)) {
    return 0;
  }

  unsigned firstCandidate = 0;
  switch (kind) {
    case VectorOpKind::Compare:
      firstCandidate = 0;
      break;
    case VectorOpKind::Select:
      firstCandidate = 1;
      break;
    default:
      return 0;
  }

  int64_t scalarBits = 0;
  const unsigned numOperands = op->getNumOperands();
  for (unsigned operandIdx = firstCandidate; operandIdx < numOperands; ++operandIdx) {
    Value operand = op->getOperand(operandIdx);
    if (liveBits.count(operand) == 0) {
      scalarBits += getValueUbBitsPerElem(operand);
    }
  }
  return scalarBits;
}
// Widest per-element bit count involved in `op`: the maximum of the result bits, the scalar bits and
// every operand's bits. A live operand uses its recorded width, any other operand its UB width.
int64_t getOpDataBitsPerElem(Operation *op, const llvm::DenseMap<Value, int64_t> &liveBits, int64_t resultBits,
                             int64_t scalarVbrcBits) {
  int64_t widest = std::max(resultBits, scalarVbrcBits);
  for (Value operand : op->getOperands()) {
    auto tracked = liveBits.find(operand);
    const int64_t operandBits = (tracked != liveBits.end()) ? tracked->second : getValueUbBitsPerElem(operand);
    if (operandBits > widest) {
      widest = operandBits;
    }
  }
  return widest;
}
// Workspace bits per element for an operation of the given kind: the kind's workspace buffer count
// multiplied by `dataBits`. Non-positive `dataBits` and kind None yield 0.
int64_t getHivmWorkspaceBitsPerElem(VectorOpKind kind, int64_t dataBits) {
  if (dataBits <= 0) {
    return 0;
  }
  switch (kind) {
    case VectorOpKind::None:
      return 0;
    case VectorOpKind::Heavy:
      return kHeavyHivmWorkspaceBuffers * dataBits;
    case VectorOpKind::Compare:
      return kCompareHivmWorkspaceBuffers * dataBits;
    default:
      // Generic and Select share the generic buffer count.
      return kGenericHivmWorkspaceBuffers * dataBits;
  }
}
// Number of workspace buffers charged for an operation of the given kind (0 for kind None).
int64_t getHivmWorkspaceBufferCount(VectorOpKind kind) {
  switch (kind) {
    case VectorOpKind::None:
      return 0;
    case VectorOpKind::Heavy:
      return kHeavyHivmWorkspaceBuffers;
    case VectorOpKind::Compare:
      return kCompareHivmWorkspaceBuffers;
    default:
      // Generic and Select share the generic buffer count.
      return kGenericHivmWorkspaceBuffers;
  }
}
// Adds one live buffer with the given axis order, alignment dimensions and element width to `set`.
// Nothing is added when the axis order is empty or the element is narrower than one byte.
void appendVectorReserve(VectorPeakReserveSet &set, ArrayRef<size_t> axisOrder, ArrayRef<int32_t> alignDims,
                         int64_t elementBits) {
  const bool nothingToReserve = axisOrder.empty() || elementBits < kBitsPerByte;
  if (nothingToReserve) {
    return;
  }
  VectorLiveReserve buffer;
  buffer.elementBits = elementBits;
  buffer.axisOrder.assign(axisOrder.begin(), axisOrder.end());
  buffer.alignDims.assign(alignDims.begin(), alignDims.end());
  set.liveBuffers.push_back(std::move(buffer));
}
// Fixed temporary bytes charged for a select: kUbAlignBytes when the select has a vector input and a
// result of at least one byte per element, otherwise 0.
int64_t getSelectTempBytes(VectorOpKind kind, bool hasVectorInput, int64_t resultBits) {
  const bool needsTemp = kind == VectorOpKind::Select && hasVectorInput && resultBits >= kBitsPerByte;
  if (!needsTemp) {
    return 0;
  }
  return kUbAlignBytes;
}
// UB bits per element charged for the results of `op`: the widest UB width among its results.
// Returns 0 for operations without results, for compares, and for non-load nodes that have no
// tracked vector input.
int64_t getNodeResultUbBitsPerElem(const NodePtr &node, Operation *op, VectorOpKind kind, bool hasVectorInput) {
  const bool chargesNothing = op->getNumResults() == 0 || kind == VectorOpKind::Compare;
  if (chargesNothing) {
    return 0;
  }
  const bool isLoad = node->opType == "Load";
  if (!isLoad && !hasVectorInput) {
    return 0;
  }
  int64_t widest = 0;
  for (Value result : op->getResults()) {
    widest = std::max(widest, getValueUbBitsPerElem(result));
  }
  return widest;
}
bool hasBroadcastVectorAxis(const NpuBandContext &ctx) {
for (const auto &axis : ctx.axes) {
Operation *loopOp = axis ? axis->getLoopOperation() : nullptr;
if (loopOp && (loopOp->hasAttr(kBroadcastLoopAttr) || loopOp->hasAttr(kNotInnerDimensionBroadcastLoopAttr)))
return true;
}
return false;
}
// Records the UB demand at the point where `node` executes, using the live state as it is before
// that node's own results are added to it (computeVectorUbFootprint calls this ahead of
// updateVectorUbLiveState).
//
// Three things are updated in `footprint`:
//  - peakBitsPerElem: raised to the per-element demand of this node;
//  - vectorPeakReserveSets / fixedReserveBytes: one reserve set listing every buffer needed here;
//  - transposeCoLiveReserveSets: for loads whose axis order matches a known transpose source, a
//    snapshot of the buffers that are live at that moment.
void recordVectorUbPeakState(VectorUbFootprint &footprint, const NpuBandContext &ctx, const VectorUbLiveState &state,
const NodePtr &node, ArrayRef<size_t> nodeAxisOrder, int64_t scalarVbrcBits,
int64_t resultBits, int64_t resultBufferCopies, int64_t opDataBits, int64_t workspaceBits,
VectorOpKind kind, bool hasVectorInput) {
// Per-element demand: live values, the extra copy of multi-buffered live values, untracked scalar
// operands, all copies of the result, and the operation workspace.
footprint.peakBitsPerElem =
std::max(footprint.peakBitsPerElem, state.liveBitsSum + state.liveMultiBufferExtraBitsSum + scalarVbrcBits +
resultBits * resultBufferCopies + workspaceBits);
VectorPeakReserveSet set;
set.fixedBytes = getSelectTempBytes(kind, hasVectorInput, resultBits);
SmallVector<int32_t, 6> nodeAlignDims = getDefaultStrideAlignDims(nodeAxisOrder);
// One buffer per live value; multi-buffered values are listed twice.
for (const auto &entry : state.liveBits) {
auto axisIt = state.liveAxisOrders.find(entry.first);
if (axisIt == state.liveAxisOrders.end()) continue;
SmallVector<int32_t, 6> alignDims = state.liveAlignDims.lookup(entry.first);
appendVectorReserve(set, axisIt->second, alignDims, entry.second);
if (state.multiBufferedValues.lookup(entry.first))
appendVectorReserve(set, axisIt->second, alignDims, entry.second);
}
// Buffers introduced by this node itself, all shaped by the node's own axis order: the scalar
// operand buffer, the result copies, and the workspace buffers. appendVectorReserve ignores
// entries narrower than one byte or without axes.
appendVectorReserve(set, nodeAxisOrder, nodeAlignDims, scalarVbrcBits);
for (int64_t i = 0; i < resultBufferCopies; ++i) appendVectorReserve(set, nodeAxisOrder, nodeAlignDims, resultBits);
for (int64_t i = 0, e = getHivmWorkspaceBufferCount(kind); i < e; ++i) {
appendVectorReserve(set, nodeAxisOrder, nodeAlignDims, opDataBits);
}
// Empty sets carry no information and are not stored.
if (!set.liveBuffers.empty() || set.fixedBytes > 0) {
footprint.fixedReserveBytes = std::max(footprint.fixedReserveBytes, set.fixedBytes);
footprint.vectorPeakReserveSets.push_back(std::move(set));
}
// The rest only applies to loads that read with a known transpose source axis order.
if (node->opType != "Load") return;
SmallVector<size_t, 6> currentOrder = getNodeAxisOrder(node, ctx.axes);
bool isKnownTransposeLoad = llvm::any_of(
ctx.transposeInfos, [&](const TransposeVectorInfo &info) { return currentOrder == info.sourceAxisOrder; });
if (!isKnownTransposeLoad && (ctx.transposeSourceAxisOrder.empty() || currentOrder != ctx.transposeSourceAxisOrder))
return;
// Snapshot every live value that is at least one byte wide and has a non-empty axis order.
SmallVector<TransposeCoLiveReserve, 4> liveSet;
for (const auto &entry : state.liveBits) {
if (entry.second < kBitsPerByte) continue;
auto axisIt = state.liveAxisOrders.find(entry.first);
if (axisIt == state.liveAxisOrders.end() || axisIt->second.empty()) continue;
liveSet.push_back(TransposeCoLiveReserve{axisIt->second, state.liveAlignDims.lookup(entry.first), entry.second});
}
if (!liveSet.empty()) footprint.transposeCoLiveReserveSets.push_back(std::move(liveSet));
}
// Advances the live state past `node`.
//
// Step 1 makes the node's results live (loads always, other non-store nodes only when they have a
// tracked vector input) and raises footprint.peakBitsPerElem if the live total grows past it.
// Step 2 counts down the remaining uses of the operands and drops every value whose last use this
// was, together with results that have no recorded use at all.
void updateVectorUbLiveState(VectorUbFootprint &footprint, VectorUbLiveState &state, const NodePtr &node, Operation *op,
ArrayRef<size_t> nodeAxisOrder, int64_t resultBits, VectorOpKind kind, bool hasVectorInput,
bool inheritsMultiBuffer) {
bool addLoadResult = node->opType == "Load";
bool addVectorResult = node->opType != "Store" && hasVectorInput;
if (addLoadResult || addVectorResult) {
// A load contributes at most its first result; other operations contribute all results.
unsigned resultCount = addLoadResult ? std::min<unsigned>(op->getNumResults(), 1) : op->getNumResults();
for (unsigned i = 0; i < resultCount; ++i) {
Value result = op->getResult(i);
// Compare results are tracked with a width of 0 bits.
int64_t bits = addLoadResult ? resultBits : (kind == VectorOpKind::Compare ? 0 : getValueUbBitsPerElem(result));
auto liveIt = state.liveBits.find(result);
// If the value is already live, take its old contribution out of the running sums first.
if (liveIt != state.liveBits.end()) {
state.liveBitsSum -= liveIt->second;
state.liveMultiBufferExtraBitsSum -= state.multiBufferedValues.lookup(result) * liveIt->second;
}
if (inheritsMultiBuffer && bits >= kBitsPerByte) state.multiBufferedValues[result] = true;
state.liveBits[result] = bits;
state.liveAxisOrders[result] = SmallVector<size_t, 6>(nodeAxisOrder.begin(), nodeAxisOrder.end());
state.liveAlignDims[result] = getDefaultStrideAlignDims(nodeAxisOrder);
state.liveBitsSum += bits;
// The bool lookup acts as a 0/1 factor: only multi-buffered values add to the extra sum.
state.liveMultiBufferExtraBitsSum += state.multiBufferedValues.lookup(result) * bits;
footprint.peakBitsPerElem =
std::max(footprint.peakBitsPerElem, state.liveBitsSum + state.liveMultiBufferExtraBitsSum);
}
}
SmallVector<Value, 8> valuesToErase;
// Each operand occurrence consumes one recorded use; a value dies when its counter reaches zero.
for (Value operand : op->getOperands()) {
auto useIt = state.remainingUses.find(operand);
if (useIt == state.remainingUses.end()) continue;
if (--useIt->second != 0) continue;
state.remainingUses.erase(useIt);
valuesToErase.push_back(operand);
}
// Results without any recorded use are released right away.
for (Value result : op->getResults())
if (state.remainingUses.lookup(result) == 0) valuesToErase.push_back(result);
for (Value value : valuesToErase) {
auto liveIt = state.liveBits.find(value);
if (liveIt == state.liveBits.end()) continue;
state.liveBitsSum -= liveIt->second;
state.liveMultiBufferExtraBitsSum -= state.multiBufferedValues.lookup(value) * liveIt->second;
state.liveBits.erase(liveIt);
state.liveAxisOrders.erase(value);
state.liveAlignDims.erase(value);
}
}
// Computes the UB footprint of a band by walking its nodes in the given order.
//
// Pass 1 counts how often every value is used as an operand, tracks the widest element type, and
// marks the first result of each load and the first operand of each store as multi-buffered.
// Pass 2 replays the nodes: for each one the peak state is recorded and then the live state is
// advanced. Finally, bands involving a transpose or a broadcast get a minimum footprint of
// kTransposeOrBroadcastFootprintBuffers element-sized buffers.
//
// Every node in `bandNodes` is dereferenced together with its operation, so both must be non-null
// (collectBandNodeInfo only stores nodes that have an operation).
VectorUbFootprint computeVectorUbFootprint(ArrayRef<NodePtr> bandNodes, const NpuBandContext &ctx) {
VectorUbFootprint footprint;
footprint.vectorElementBits = std::max<int64_t>(ctx.smallestTypeBits, 1);
if (bandNodes.empty()) return footprint;
VectorUbLiveState state;
// Pass 1: use counts, widest element width, multi-buffer seeds.
for (const auto &node : bandNodes) {
Operation *op = node->op();
for (Value result : op->getResults()) {
footprint.vectorElementBits = std::max(footprint.vectorElementBits, getValueElementBits(result));
}
for (Value operand : op->getOperands()) {
footprint.vectorElementBits = std::max(footprint.vectorElementBits, getValueElementBits(operand));
++state.remainingUses[operand];
}
if (node->opType == "Load" && op->getNumResults() > 0) state.multiBufferedValues[op->getResult(0)] = true;
if (node->opType == "Store" && op->getNumOperands() > 0) state.multiBufferedValues[op->getOperand(0)] = true;
}
// Pass 2: replay the nodes against the evolving live state.
for (const auto &node : bandNodes) {
Operation *op = node->op();
VectorOpKind kind = classifyVectorOp(op);
bool hasVectorInput = hasTrackedVectorOperand(op, state.liveBits);
int64_t resultBits = getNodeResultUbBitsPerElem(node, op, kind, hasVectorInput);
int64_t scalarVbrcBits = getScalarVbrcBitsPerElem(op, kind, state.liveBits);
int64_t opDataBits = getOpDataBitsPerElem(op, state.liveBits, resultBits, scalarVbrcBits);
int64_t workspaceBits = getHivmWorkspaceBitsPerElem(kind, opDataBits);
SmallVector<size_t, 6> nodeAxisOrder = getNodeAxisOrder(node, ctx.axes);
bool inheritsMultiBuffer = resultBits >= kBitsPerByte && inheritsMultiBufferFromOperand(op, kind, state);
// The result is charged kHivmAutoMultiBufferFactor times when it inherits multi-buffering or is
// itself marked multi-buffered, otherwise once.
int64_t resultBufferCopies =
(inheritsMultiBuffer ||
llvm::any_of(op->getResults(), [&](Value result) { return state.multiBufferedValues.lookup(result); }))
? kHivmAutoMultiBufferFactor
: 1;
// Record first, update second: the peak must see the state before this node's results go live.
recordVectorUbPeakState(footprint, ctx, state, node, nodeAxisOrder, scalarVbrcBits, resultBits, resultBufferCopies,
opDataBits, workspaceBits, kind, hasVectorInput);
updateVectorUbLiveState(footprint, state, node, op, nodeAxisOrder, resultBits, kind, hasVectorInput,
inheritsMultiBuffer);
}
// Transpose and broadcast bands share the same minimum footprint; both branches apply the same floor.
if (ctx.graphTemplate == GraphTemplate::TRANSPOSE_OP ||
llvm::any_of(ctx.axes, [](const AxisPtr &axis) { return isTransposeAxis(axis); })) {
footprint.peakBitsPerElem =
std::max(footprint.peakBitsPerElem, kTransposeOrBroadcastFootprintBuffers * footprint.vectorElementBits);
} else if (ctx.graphTemplate == GraphTemplate::BROADCAST_OP || hasBroadcastVectorAxis(ctx)) {
footprint.peakBitsPerElem =
std::max(footprint.peakBitsPerElem, kTransposeOrBroadcastFootprintBuffers * footprint.vectorElementBits);
}
return footprint;
}
// Gathers per-band information from the graph: the nodes of band `ctx.bandIdx` that have an
// operation, the accumulated math complexity (capped at 64), facts about vector loads, and the
// transpose patterns found in the band. A null graph yields an empty result.
BandNodeInfo collectBandNodeInfo(const NpuModelGraphPtr &npuGraph, const NpuBandContext &ctx) {
  BandNodeInfo info;
  if (!npuGraph) {
    return info;
  }

  BandNodeFacts &facts = info.facts;
  for (const auto &node : npuGraph->nodes()) {
    if (!isNodeInBand(node, ctx.bandIdx)) {
      continue;
    }
    facts.mathComplexityScore = std::min<int64_t>(facts.mathComplexityScore + getNodeMathComplexity(node), 64);

    Operation *op = node->op();
    if (!op) {
      continue;
    }
    info.nodes.push_back(node);

    auto loadOp = dyn_cast<memref::LoadOp>(op);
    if (!loadOp) {
      continue;
    }
    // Only loads between one byte and the default type width count as vector loads.
    const int64_t loadBits = getValueElementBits(loadOp.getResult());
    const bool isVectorWidth = loadBits >= kBitsPerByte && loadBits <= kDefaultTypeBits;
    if (!isVectorWidth) {
      continue;
    }
    facts.maxVectorLoadBits = std::max(facts.maxVectorLoadBits, loadBits);
    facts.hasF32VectorLoad |= loadOp.getResult().getType().isF32();
  }

  info.transposeInfos = findTransposeVectorInfos(ctx);
  return info;
}
// Builds the tiling context for one band.
//
// The context collects, in this order: the band axes and graph-level parameters (template, target
// block count, UB capacities), the per-axis extents and axis-kind flags, node-derived facts and
// transpose information, the per-axis alignment units, and finally the vector UB footprint from which
// the vector UB capacity is derived. With a null graph only the band index and axes are filled in.
//
// The order matters: collectBandNodeInfo, computeAxisAlignSize and computeVectorUbFootprint all read
// fields of `ctx` that are set earlier in this function.
NpuBandContext buildNpuBandContext(const NpuModelGraphPtr &npuGraph, size_t bandIdx,
                                   const SmallVector<AxisPtr, 4> &bandAxes) {
  NpuBandContext ctx;
  ctx.bandIdx = bandIdx;
  ctx.axes = bandAxes;
  if (!npuGraph) return ctx;

  ctx.graphTemplate = npuGraph->graphTemplate;
  ctx.targetBlocks =
    (npuGraph->coreNum > 0) ? static_cast<int64_t>(npuGraph->coreNum) : static_cast<int64_t>(kNpuTargetBlocks);
  UbCapacity ub = computeUbCapacity(npuGraph);
  ctx.rawUbElems = ub.rawUbElems;
  ctx.ubCapacityElems = ub.ubCapacityElems;
  // Provisional value; replaced below once the vector footprint is known.
  ctx.vectorUbCapacityElems = ub.ubCapacityElems;
  ctx.smallestTypeBits = ub.smallestTypeBits;

  // One extent per band axis; missing or non-positive upper bounds are treated as extent 1.
  for (const auto &axis : bandAxes) {
    ctx.extents.push_back((axis && axis->range.second > 0) ? axis->range.second : 1);
    ctx.hasDynamicAxis |= isDynamicAxis(axis);
    ctx.hasReduction |= isReductionAxis(axis);
  }
  ctx.lastAxisIsReduction = !bandAxes.empty() && isReductionAxis(bandAxes.back());

  BandNodeInfo nodeInfo = collectBandNodeInfo(npuGraph, ctx);
  const BandNodeFacts &nodeFacts = nodeInfo.facts;
  ctx.mathComplexityScore = nodeFacts.mathComplexityScore;
  ctx.transposeElementBits =
    nodeFacts.maxVectorLoadBits > 0 ? nodeFacts.maxVectorLoadBits : std::max<int64_t>(ctx.smallestTypeBits, 1);
  ctx.transposeElementIsF32 = nodeFacts.hasF32VectorLoad;
  if (!nodeInfo.transposeInfos.empty()) {
    // nodeInfo.transposeInfos is not read again, so hand the storage over instead of copying it.
    ctx.transposeInfos = std::move(nodeInfo.transposeInfos);
    // The first (highest-priority) detected transpose overrides the load-based defaults above.
    const TransposeVectorInfo &primary = ctx.transposeInfos.front();
    ctx.transposeSourceAxisOrder = primary.sourceAxisOrder;
    ctx.transposeTargetAxisOrder = primary.targetAxisOrder;
    ctx.transposeElementBits = std::max<int64_t>(primary.elementBits, 1);
    ctx.transposeElementIsF32 = primary.isF32;
  }

  ctx.axesAlignUnits.assign(ctx.axes.size(), 1);
  for (size_t i = 0; i < ctx.axes.size(); ++i) {
    ctx.axesAlignUnits[i] = std::max<int64_t>(computeAxisAlignSize(i, ctx), 1);
  }
  applyF32LastDimDoubleAlign(ctx, ctx.axesAlignUnits);

  VectorUbFootprint footprint = computeVectorUbFootprint(nodeInfo.nodes, ctx);
  ctx.vectorElementBits = std::max<int64_t>(footprint.vectorElementBits, ctx.smallestTypeBits);
  ctx.vectorUbBitsPerElem = std::max<int64_t>(footprint.peakBitsPerElem, ctx.vectorElementBits);
  // `footprint` is a local that is only read for scalars below, so both containers can be moved.
  ctx.transposeCoLiveReserveSets = std::move(footprint.transposeCoLiveReserveSets);
  ctx.vectorPeakReserveSets = std::move(footprint.vectorPeakReserveSets);
  ctx.fixedReserveBytes = footprint.fixedReserveBytes;

  // Vector capacity: the raw UB minus the fixed reserve, divided by the peak bits needed per element.
  // If the reserve consumes the whole UB, one alignment unit is used as the budget instead.
  int64_t rawUbBits = multiplyAndCap(ctx.rawUbElems, ctx.smallestTypeBits);
  int64_t fixedReserveBits = multiplyAndCap(footprint.fixedReserveBytes, kBitsPerByte);
  int64_t vectorUbBits = rawUbBits > fixedReserveBits ? rawUbBits - fixedReserveBits : kUbAlignBits;
  ctx.vectorUbCapacityElems = alignUbElems(vectorUbBits / ctx.vectorUbBitsPerElem, ctx.vectorElementBits);
  return ctx;
}
// Saturating product of extents[from, to), with `to` clipped to the array size and every extent
// treated as at least 1. An empty range gives 1.
int64_t getSliceProduct(ArrayRef<int64_t> extents, size_t from, size_t to) {
  const size_t stop = std::min(to, extents.size());
  int64_t product = 1;
  for (size_t pos = from; pos < stop; ++pos) {
    const int64_t factor = std::max<int64_t>(extents[pos], 1);
    product = multiplyAndCap(product, factor);
  }
  return product;
}
// An axis counts as "small" when its extent is greater than 1 and does not exceed
// max(8, targetBlocks / 3).
inline bool isSmallSemanticAxis(const NpuBandContext &ctx, int64_t extent) {
  constexpr int64_t kMinSmallSemanticAxisLimit = 8;
  constexpr int64_t kSmallSemanticTargetBlockDivisor = 3;
  if (extent <= 1) {
    return false;
  }
  const int64_t upperLimit =
    std::max<int64_t>(kMinSmallSemanticAxisLimit, ctx.targetBlocks / kSmallSemanticTargetBlockDivisor);
  return extent <= upperLimit;
}
// Finds the first axis of the innermost suffix that is kept together for a reduction band.
//
// Starting from the last axis, the suffix is grown outwards for as long as the next outer axis is
// "small" (see isSmallSemanticAxis) and the suffix element count still fits the point-tile UB
// capacity. Returns the index of the first axis of that suffix, or 0 for a band without axes.
//
// `ctx.extents` is expected to hold one entry per axis, as buildNpuBandContext produces it.
size_t computeReductionSuffixStart(const NpuBandContext &ctx) {
  // Without axes `size() - 1` would wrap around and `extents.back()` would be undefined.
  if (ctx.axes.empty() || ctx.extents.empty()) return 0;

  size_t suffixStart = ctx.axes.size() - 1;
  int64_t suffixElems = ctx.extents.back();
  const int64_t capacity = ubCapacityForPointTile(ctx);
  while (suffixStart > 0) {
    const int64_t candidate = ctx.extents[suffixStart - 1];
    // A small axis has an extent greater than 1, so the division below is well defined.
    if (!isSmallSemanticAxis(ctx, candidate)) break;
    // Equivalent to suffixElems * candidate > capacity, written so that it cannot overflow.
    if (suffixElems > capacity / candidate) break;
    --suffixStart;
    suffixElems = multiplyAndCap(suffixElems, candidate);
  }
  return suffixStart;
}
// Finds the first axis of the innermost suffix that is kept together for a broadcast band.
//
// Bands with at most two axes keep only the last axis (index size - 1); a band without axes yields 0.
// For longer bands, suffixes of rank 2 and 3 are examined from the inside out. A suffix is accepted
// when its element count fits into tileBudget / kSuffixPreservePointRows, the innermost extent is at
// least max(32, targetBlocks / 2), and one of its non-innermost axes is "small" (see
// isSmallSemanticAxis). The outermost accepted start wins; the scan stops at the first suffix that no
// longer fits the budget.
size_t computeBroadcastSuffixStart(const NpuBandContext &ctx, int64_t tileBudget) {
  constexpr size_t kMinStructuredSuffixRank = 2;
  constexpr size_t kMaxStructuredSuffixRank = 3;
  constexpr int64_t kMinWideInnermostExtent = 32;
  constexpr int64_t kWideInnermostTargetBlockDivisor = 2;

  // Without axes `size() - 1` would wrap around to the largest size_t value.
  if (ctx.axes.empty()) return 0;
  if (ctx.axes.size() <= kMinStructuredSuffixRank) return ctx.axes.size() - 1;

  const int64_t wideThresh =
    std::max<int64_t>(kMinWideInnermostExtent, ctx.targetBlocks / kWideInnermostTargetBlockDivisor);
  const int64_t innerExtent = ctx.extents.back();
  const int64_t ubPerPoint = tileBudget / kSuffixPreservePointRows;

  size_t bestStart = ctx.axes.size() - 1;
  for (size_t start = ctx.axes.size() - 1; start > 0; --start) {
    const size_t candidateStart = start - 1;
    const size_t suffixRank = ctx.axes.size() - candidateStart;
    if (suffixRank < kMinStructuredSuffixRank || suffixRank > kMaxStructuredSuffixRank) continue;

    const int64_t suffixElems = getSliceProduct(ctx.extents, candidateStart, ctx.extents.size());
    // Larger suffixes only get bigger, so nothing further out can fit either.
    if (suffixElems > ubPerPoint) break;
    if (innerExtent < wideThresh) continue;

    // Look for a small axis among the suffix axes other than the innermost one.
    bool hasSmall = false;
    for (size_t i = candidateStart; i + 1 < ctx.extents.size(); ++i) {
      if (isSmallSemanticAxis(ctx, ctx.extents[i])) {
        hasSmall = true;
        break;
      }
    }
    if (!hasSmall) continue;
    bestStart = candidateStart;
  }
  return bestStart;
}
// Picks a tile size that splits `extent` into exactly `tileCount` tiles.
// Every size in [low, high] produces that tile count; the result lies two thirds of the way from
// `low` to `high` and is clamped to [1, extent]. For extent <= 1 or tileCount <= 1 the extent itself
// is returned.
int64_t chooseBiasedTileSizeForTileCount(int64_t extent, int64_t tileCount) {
  constexpr int64_t kTileSizeBiasNumerator = 2;
  constexpr int64_t kTileSizeBiasDenominator = 3;
  constexpr int64_t kTwoTileCount = 2;

  const bool nothingToSplit = extent <= 1 || tileCount <= 1;
  if (nothingToSplit) {
    return extent;
  }

  // Smallest tile size that still needs no more than tileCount tiles.
  const int64_t low = ceilDivInt64(extent, tileCount);
  // Largest tile size that still needs at least tileCount tiles.
  int64_t high = 0;
  if (tileCount == kTwoTileCount) {
    high = extent - 1;
  } else {
    high = ceilDivInt64(extent, tileCount - 1) - 1;
  }
  if (high < low) {
    high = low;
  }

  const int64_t biased = low + (high - low) * kTileSizeBiasNumerator / kTileSizeBiasDenominator;
  return std::clamp(biased, int64_t{1}, extent);
}
// Chooses an outer tile size for every prefix axis so that the product of the resulting tile counts
// approaches `targetBlocks`.
// Axes of extent 1 get tile 1. Every other axis is asked for as many tiles as are still missing
// (bounded by its extent); axes after the first whose extent does not exceed `targetBlocks` are
// limited to two tiles.
SmallVector<unsigned, 4> assignPrefixOuterTiles(ArrayRef<int64_t> prefixExtents, int64_t targetBlocks) {
  constexpr int64_t kMaxSmallPrefixTileCount = 2;
  SmallVector<unsigned, 4> outerTiles;
  outerTiles.reserve(prefixExtents.size());

  int64_t tilesSoFar = 1;
  bool isFirstAxis = true;
  for (int64_t extent : prefixExtents) {
    const bool first = isFirstAxis;
    isFirstAxis = false;

    if (extent == 1) {
      outerTiles.push_back(1);
      continue;
    }

    const int64_t stillNeeded = ceilDivInt64(targetBlocks, tilesSoFar);
    int64_t wantedTiles = std::min<int64_t>(extent, stillNeeded);
    const bool isSmallInnerPrefix = !first && extent <= targetBlocks;
    if (isSmallInnerPrefix && wantedTiles > kMaxSmallPrefixTileCount) {
      wantedTiles = kMaxSmallPrefixTileCount;
    }

    const int64_t tileSize = chooseBiasedTileSizeForTileCount(extent, wantedTiles);
    outerTiles.push_back(saturateToTileValue(tileSize));
    tilesSoFar = multiplyAndCap(tilesSoFar, ceilDivInt64(extent, tileSize));
  }
  return outerTiles;
}
// Resets `plan` so that every axis is covered by a single tile: outer and inner tile sizes both equal
// the axis extent (clamped into the tile value range).
void initWholeBandPlan(const NpuBandContext &ctx, BandTilePlan &plan) {
  const size_t rank = ctx.extents.size();
  plan.outerTiles.clear();
  plan.outerTiles.reserve(rank);
  plan.innerTiles.clear();
  plan.innerTiles.reserve(rank);
  for (int64_t extent : ctx.extents) {
    const unsigned wholeAxisTile = saturateToTileValue(extent);
    plan.outerTiles.push_back(wholeAxisTile);
    plan.innerTiles.push_back(wholeAxisTile);
  }
}
// UB bytes available to vector buffers: the raw UB size minus the fixed reserve, but at least
// kUbAlignBytes.
int64_t getVectorUbBytes(const NpuBandContext &ctx) {
  const int64_t rawBytes = multiplyAndCap(ctx.rawUbElems, ctx.smallestTypeBits) / kBitsPerByte;
  const int64_t afterReserve = rawBytes - ctx.fixedReserveBytes;
  return std::max<int64_t>(afterReserve, kUbAlignBytes);
}
// Computes per-dimension padding units for a buffer of the given shape.
//
// The result has shape.size() + 1 entries, all 1 by default. For a dimension `d` listed in
// `alignDims`, the unit is written to slot d + 1; computeReserveBytes pairs slot i with shape[i], so
// the unit of `d` pads the dimension right inside it.
// NOTE(review): presumably this mirrors the stride alignment applied by the Bisheng compiler -
// confirm against that compiler's rules.
//
// `staticDims` is indexed with the same indices as `shape` and must be at least as long.
SmallVector<int64_t, 6> collectBishengStrideAlignUnits(ArrayRef<int64_t> shape, ArrayRef<char> staticDims,
ArrayRef<int32_t> alignDims, int64_t elementBits) {
SmallVector<int64_t, 6> alignUnits(shape.size() + 1, 1);
if (shape.empty() || alignDims.empty()) return alignUnits;
int64_t bitWidth = std::max<int64_t>(elementBits, 1);
// Per-dimension alignment target in elements: kUbAlignBits / bitWidth for every requested dimension.
SmallVector<int64_t, 6> alignTargets(shape.size(), 1);
for (int32_t dim : alignDims) {
// Out-of-range dimensions are ignored.
if (dim < 0 || static_cast<size_t>(dim) >= shape.size()) continue;
// Elements that are a whole multiple of the alignment are always aligned.
if (bitWidth % kUbAlignBits == 0) continue;
// Element widths that do not divide the alignment are left without a target.
if (kUbAlignBits % bitWidth != 0) continue;
size_t idx = static_cast<size_t>(dim);
alignTargets[idx] = std::lcm(alignTargets[idx], kUbAlignBits / bitWidth);
}
// Walk from the innermost dimension outwards, accumulating the alignment already guaranteed.
int64_t innerAlignedUnits = 1;
int64_t shapeAccumulation = 1;
for (int64_t dim = static_cast<int64_t>(shape.size()) - 1; dim >= 0; --dim) {
size_t idx = static_cast<size_t>(dim);
int64_t newAlignedUnits = std::lcm(innerAlignedUnits, alignTargets[idx]);
// No padding is needed when the accumulated inner size already meets the required alignment.
alignUnits[idx + 1] = (shapeAccumulation % newAlignedUnits == 0) ? 1 : newAlignedUnits / innerAlignedUnits;
innerAlignedUnits = newAlignedUnits;
// Only dimensions flagged in staticDims contribute to the accumulated size.
if (staticDims[idx]) {
shapeAccumulation = multiplyAndCap(shapeAccumulation, std::lcm(shape[idx], alignUnits[idx + 1]));
}
}
return alignUnits;
}
int64_t computeReserveBytes(const NpuBandContext &ctx, const VectorLiveReserve &reserve, ArrayRef<int64_t> axisTiles) {
SmallVector<int64_t, 6> shape;
SmallVector<char, 6> staticDims;
shape.reserve(reserve.axisOrder.size());
staticDims.reserve(reserve.axisOrder.size());
for (size_t axisIdx : reserve.axisOrder) {
int64_t tile = axisIdx < axisTiles.size() ? axisTiles[axisIdx] : 1;
int64_t extent = axisIdx < ctx.extents.size() ? ctx.extents[axisIdx] : tile;
shape.push_back(std::max<int64_t>(tile, 1));
staticDims.push_back(tile >= extent);
}
SmallVector<int64_t, 6> alignUnits =
collectBishengStrideAlignUnits(shape, staticDims, reserve.alignDims, reserve.elementBits);
int64_t elems = 1;
for (size_t i = 0; i < alignUnits.size(); ++i) {
int64_t dim = i < shape.size() ? shape[i] : 1;
elems = multiplyAndCap(elems, alignUpInt64(std::max<int64_t>(dim, 1), alignUnits[i]));
}
return alignUpInt64(ceilDivInt64(multiplyAndCap(elems, reserve.elementBits), kBitsPerByte), kUbAlignBytes);
}
// Bytes occupied by a densely packed buffer of the given shape (each dimension at least 1),
// rounded up to kUbAlignBytes.
int64_t computeLogicalShapeBytes(ArrayRef<int64_t> shape, int64_t elementBits) {
  int64_t elemCount = 1;
  for (int64_t dim : shape) {
    const int64_t extent = std::max<int64_t>(dim, 1);
    elemCount = multiplyAndCap(elemCount, extent);
  }
  const int64_t totalBits = multiplyAndCap(elemCount, elementBits);
  const int64_t totalBytes = ceilDivInt64(totalBits, kBitsPerByte);
  return alignUpInt64(totalBytes, kUbAlignBytes);
}
int64_t computeLogicalReserveBytes(const VectorLiveReserve &reserve, ArrayRef<int64_t> axisTiles) {
SmallVector<int64_t, 6> shape;
shape.reserve(reserve.axisOrder.size());
for (size_t axisIdx : reserve.axisOrder) {
int64_t tile = axisIdx < axisTiles.size() ? axisTiles[axisIdx] : 1;
shape.push_back(std::max<int64_t>(tile, 1));
}
return computeLogicalShapeBytes(shape, reserve.elementBits);
}
// True when the stride-aligned size of the buffer does not exceed its unpadded size,
// i.e. applying the alignment units to these tiles adds no extra bytes.
bool satisfiesStrideAlignWithoutExpansion(const NpuBandContext &ctx, const VectorLiveReserve &reserve,
                                          ArrayRef<int64_t> axisTiles) {
  const int64_t paddedBytes = computeReserveBytes(ctx, reserve, axisTiles);
  const int64_t logicalBytes = computeLogicalReserveBytes(reserve, axisTiles);
  return paddedBytes <= logicalBytes;
}
// Returns the largest total reserve, in bytes, over all sets in ctx.vectorPeakReserveSets.
// Within a set the per-buffer sizes from computeReserveBytes are summed with saturation at LLONG_MAX.
int64_t computeVectorPeakReserveBytes(const NpuBandContext &ctx, ArrayRef<int64_t> axisTiles) {
  int64_t worstBytes = 0;
  for (const VectorPeakReserveSet &liveSet : ctx.vectorPeakReserveSets) {
    int64_t setBytes = 0;
    for (const VectorLiveReserve &buffer : liveSet.liveBuffers) {
      const int64_t bufferBytes = computeReserveBytes(ctx, buffer, axisTiles);
      // Saturating add: clamp instead of overflowing.
      if (setBytes > LLONG_MAX - bufferBytes) {
        setBytes = LLONG_MAX;
      } else {
        setBytes += bufferBytes;
      }
    }
    if (setBytes > worstBytes) worstBytes = setBytes;
  }
  return worstBytes;
}
// Returns the byte size of a buffer whose dimensions are padded to per-dimension byte alignments.
//   shape      - dimension sizes (each clamped to at least 1).
//   alignBytes - byte alignment per dimension; values <= 0 mean "no padding" for that dimension.
//                Indexed in lockstep with shape, so it must be at least as long as shape.
//   elemBytes  - element size in bytes; used as a divisor, so callers must pass a positive value.
// The total is rounded up to kUbAlignBytes.
int64_t computeAlignedShapeBytes(ArrayRef<int64_t> shape, ArrayRef<int64_t> alignBytes, int64_t elemBytes) {
  int64_t paddedElems = 1;
  size_t pos = 0;
  for (int64_t dim : shape) {
    // Convert the byte alignment into an element count, never below one element.
    int64_t unitElems = 1;
    if (alignBytes[pos] > 0) unitElems = std::max<int64_t>(alignBytes[pos] / elemBytes, 1);
    const int64_t paddedDim = alignUpInt64(std::max<int64_t>(dim, 1), unitElems);
    paddedElems = multiplyAndCap(paddedElems, paddedDim);
    ++pos;
  }
  return alignUpInt64(multiplyAndCap(paddedElems, elemBytes), kUbAlignBytes);
}
// True when padding the shape to the requested per-dimension byte alignments does not make the
// buffer larger than its unpadded size (both sizes are already rounded up to kUbAlignBytes).
bool satisfiesAlignedShapeWithoutExpansion(ArrayRef<int64_t> shape, ArrayRef<int64_t> alignBytes, int64_t elemBytes) {
  const int64_t paddedBytes = computeAlignedShapeBytes(shape, alignBytes, elemBytes);
  const int64_t elementBits = multiplyAndCap(elemBytes, kBitsPerByte);
  const int64_t logicalBytes = computeLogicalShapeBytes(shape, elementBits);
  return paddedBytes <= logicalBytes;
}
// Derives the destination shape and the per-dimension byte alignments for a transpose that swaps
// dim0 and dim1 of srcShape.
//   dstShape - out: srcShape with dim0 and dim1 exchanged.
//   srcAlign - out: byte alignment per source dimension (0 = none).
//   dstAlign - out: byte alignment per destination dimension (0 = none).
// A swapped dimension gets kUbAlignBytes when its row size in bytes is not already a multiple of
// kUbAlignBytes. For f32 transposes where neither swapped source dimension rounds up to a multiple
// of 2 * kUbAlignBytes, the alignments of both swapped dimensions are overridden so that one side
// of each buffer is aligned to 2 * kUbAlignBytes.
void computeLastDimTransposeAlignBytes(ArrayRef<int64_t> srcShape, size_t dim0, size_t dim1, int64_t elemBytes,
                                       bool isF32Transpose, SmallVectorImpl<int64_t> &dstShape,
                                       SmallVectorImpl<int64_t> &srcAlign,
                                       SmallVectorImpl<int64_t> &dstAlign) {
  dstShape.assign(srcShape.begin(), srcShape.end());
  std::swap(dstShape[dim0], dstShape[dim1]);
  srcAlign.assign(srcShape.size(), 0);
  dstAlign.assign(dstShape.size(), 0);

  auto rowBytes = [elemBytes](int64_t extent) { return multiplyAndCap(extent, elemBytes); };
  auto isUnaligned = [&](int64_t extent) { return rowBytes(extent) % kUbAlignBytes != 0; };
  auto roundsToDoubleUnit = [&](int64_t extent) {
    return alignUpInt64(rowBytes(extent), kUbAlignBytes) % (kUbAlignBytes * 2) == 0;
  };

  for (size_t dim : {dim0, dim1}) {
    if (isUnaligned(srcShape[dim])) srcAlign[dim] = kUbAlignBytes;
  }
  for (size_t dim : {dim0, dim1}) {
    if (isUnaligned(dstShape[dim])) dstAlign[dim] = kUbAlignBytes;
  }

  const bool hasDoubleAlignedDim = roundsToDoubleUnit(srcShape[dim0]) || roundsToDoubleUnit(srcShape[dim1]);
  if (!isF32Transpose || hasDoubleAlignedDim) return;

  // f32 special case: force a double-unit alignment on dim1 of the source and dim0 of the destination.
  srcAlign[dim0] = kUbAlignBytes;
  srcAlign[dim1] = kUbAlignBytes * 2;
  dstAlign[dim0] = kUbAlignBytes * 2;
  dstAlign[dim1] = kUbAlignBytes;
}
// Returns the bytes reserved for the source and destination buffers of a transpose swapping
// dim0 and dim1. Each aligned buffer size is multiplied by kHivmAutoMultiBufferFactor and the
// two results are added with saturation at LLONG_MAX.
int64_t computeLastDimTransposeAllocBytes(ArrayRef<int64_t> srcShape, size_t dim0, size_t dim1, int64_t elemBytes,
                                          bool isF32Transpose) {
  SmallVector<int64_t, 6> dstShape;
  SmallVector<int64_t, 6> srcAlign;
  SmallVector<int64_t, 6> dstAlign;
  computeLastDimTransposeAlignBytes(srcShape, dim0, dim1, elemBytes, isF32Transpose, dstShape, srcAlign, dstAlign);

  const int64_t srcBytes =
    multiplyAndCap(kHivmAutoMultiBufferFactor, computeAlignedShapeBytes(srcShape, srcAlign, elemBytes));
  const int64_t dstBytes =
    multiplyAndCap(kHivmAutoMultiBufferFactor, computeAlignedShapeBytes(dstShape, dstAlign, elemBytes));
  if (srcBytes > LLONG_MAX - dstBytes) return LLONG_MAX;
  return srcBytes + dstBytes;
}
// Returns the peak allocation, in bytes, needed by the last-dimension swaps of a transpose from
// sourceOrder to targetOrder.
//  - Mismatched ranks yield 0.
//  - If exactly two positions differ and one of them is the last position, the cost of that single
//    swap is returned; any other permutation with at most two differing positions yields 0.
//  - Otherwise the permutation is replayed as adjacent swaps that move each target axis into place,
//    and the maximum cost over the swaps that touch the last position is returned.
int64_t computeTransposeAllocReserveBytes(ArrayRef<int64_t> srcShape, ArrayRef<size_t> sourceOrder,
                                          ArrayRef<size_t> targetOrder, int64_t elemBytes, bool isF32Transpose) {
  const size_t rank = sourceOrder.size();
  if (srcShape.size() != rank || rank != targetOrder.size()) return 0;

  SmallVector<size_t, 6> movedDims;
  for (size_t pos = 0; pos < rank; ++pos) {
    if (sourceOrder[pos] != targetOrder[pos]) movedDims.push_back(pos);
  }
  if (movedDims.size() <= 2) {
    const bool swapsLastDim = movedDims.size() == 2 && llvm::is_contained(movedDims, rank - 1);
    if (!swapsLastDim) return 0;
    return computeLastDimTransposeAllocBytes(srcShape, movedDims[0], movedDims[1], elemBytes, isF32Transpose);
  }

  SmallVector<size_t, 6> order(sourceOrder.begin(), sourceOrder.end());
  SmallVector<int64_t, 6> shape(srcShape.begin(), srcShape.end());
  const size_t lastPos = order.size() - 1;
  int64_t worstBytes = 0;
  for (size_t targetPos = 0; targetPos < rank; ++targetPos) {
    auto found = llvm::find(order, targetOrder[targetPos]);
    if (found == order.end()) return worstBytes;
    size_t cursor = static_cast<size_t>(found - order.begin());
    // Bubble the axis towards targetPos one adjacent swap at a time.
    while (cursor > targetPos) {
      if (cursor == lastPos) {
        const int64_t swapBytes =
          computeLastDimTransposeAllocBytes(shape, cursor - 1, cursor, elemBytes, isF32Transpose);
        worstBytes = std::max(worstBytes, swapBytes);
      }
      std::swap(order[cursor - 1], order[cursor]);
      std::swap(shape[cursor - 1], shape[cursor]);
      --cursor;
    }
  }
  return worstBytes;
}
// True when neither the source nor the destination buffer of a dim0/dim1 swap grows once the
// alignments from computeLastDimTransposeAlignBytes are applied.
bool satisfiesLastDimTransposeAllocAlign(ArrayRef<int64_t> srcShape, size_t dim0, size_t dim1, int64_t elemBytes,
                                         bool isF32Transpose) {
  SmallVector<int64_t, 6> dstShape;
  SmallVector<int64_t, 6> srcAlign;
  SmallVector<int64_t, 6> dstAlign;
  computeLastDimTransposeAlignBytes(srcShape, dim0, dim1, elemBytes, isF32Transpose, dstShape, srcAlign, dstAlign);
  if (!satisfiesAlignedShapeWithoutExpansion(srcShape, srcAlign, elemBytes)) return false;
  return satisfiesAlignedShapeWithoutExpansion(dstShape, dstAlign, elemBytes);
}
// Checks that every last-dimension swap needed to go from sourceOrder to targetOrder can be
// allocated without alignment padding.
//  - Mismatched ranks yield false.
//  - If exactly two positions differ and one is the last position, that single swap is checked;
//    any other permutation with at most two differing positions yields true.
//  - Otherwise the permutation is replayed as adjacent swaps and every swap touching the last
//    position must pass satisfiesLastDimTransposeAllocAlign. A target axis that cannot be found
//    in the current order yields false.
bool satisfiesTransposeAllocAlign(ArrayRef<int64_t> srcShape, ArrayRef<size_t> sourceOrder,
                                  ArrayRef<size_t> targetOrder, int64_t elemBytes, bool isF32Transpose) {
  const size_t rank = sourceOrder.size();
  if (srcShape.size() != rank || rank != targetOrder.size()) return false;

  SmallVector<size_t, 6> movedDims;
  for (size_t pos = 0; pos < rank; ++pos) {
    if (sourceOrder[pos] != targetOrder[pos]) movedDims.push_back(pos);
  }
  if (movedDims.size() <= 2) {
    const bool swapsLastDim = movedDims.size() == 2 && llvm::is_contained(movedDims, rank - 1);
    if (!swapsLastDim) return true;
    return satisfiesLastDimTransposeAllocAlign(srcShape, movedDims[0], movedDims[1], elemBytes, isF32Transpose);
  }

  SmallVector<size_t, 6> order(sourceOrder.begin(), sourceOrder.end());
  SmallVector<int64_t, 6> shape(srcShape.begin(), srcShape.end());
  const size_t lastPos = order.size() - 1;
  for (size_t targetPos = 0; targetPos < rank; ++targetPos) {
    auto found = llvm::find(order, targetOrder[targetPos]);
    if (found == order.end()) return false;
    size_t cursor = static_cast<size_t>(found - order.begin());
    while (cursor > targetPos) {
      const bool touchesLast = cursor == lastPos;
      if (touchesLast && !satisfiesLastDimTransposeAllocAlign(shape, cursor - 1, cursor, elemBytes, isF32Transpose)) {
        return false;
      }
      std::swap(order[cursor - 1], order[cursor]);
      std::swap(shape[cursor - 1], shape[cursor]);
      --cursor;
    }
  }
  return true;
}
// Returns the peak stride-aligned buffer size, in bytes, over the adjacent swaps that touch the
// last position while permuting sourceOrder into targetOrder.
// Shapes of rank <= 2 or mismatched ranks yield 0. For each relevant swap, the two swapped axes
// are located in sourceOrder, their dimensions are aligned to kUbAlignBytes, and the padded size
// of the (unpermuted) shape is computed with computeAlignedShapeBytes.
int64_t computeStrideAlignedTransposeBytes(ArrayRef<int64_t> shape, ArrayRef<size_t> sourceOrder,
                                           ArrayRef<size_t> targetOrder, int64_t elemBytes) {
  const size_t rank = shape.size();
  if (rank <= 2 || rank != sourceOrder.size() || sourceOrder.size() != targetOrder.size()) return 0;

  auto indexOf = [](ArrayRef<size_t> order, size_t axis) -> std::optional<size_t> {
    for (size_t pos = 0; pos < order.size(); ++pos) {
      if (order[pos] == axis) return pos;
    }
    return std::nullopt;
  };
  auto pairBytes = [&](size_t firstAxis, size_t secondAxis) -> int64_t {
    const std::optional<size_t> firstDim = indexOf(sourceOrder, firstAxis);
    const std::optional<size_t> secondDim = indexOf(sourceOrder, secondAxis);
    if (!firstDim.has_value() || !secondDim.has_value()) return 0;
    SmallVector<int64_t, 6> alignBytes(shape.size(), 0);
    alignBytes[*firstDim] = kUbAlignBytes;
    alignBytes[*secondDim] = kUbAlignBytes;
    return computeAlignedShapeBytes(shape, alignBytes, elemBytes);
  };

  SmallVector<size_t, 6> order(sourceOrder.begin(), sourceOrder.end());
  const size_t lastPos = order.size() - 1;
  int64_t worstBytes = 0;
  for (size_t targetPos = 0; targetPos < targetOrder.size(); ++targetPos) {
    const std::optional<size_t> start = indexOf(order, targetOrder[targetPos]);
    if (!start.has_value()) return worstBytes;
    size_t cursor = *start;
    while (cursor > targetPos) {
      if (cursor == lastPos) worstBytes = std::max(worstBytes, pairBytes(order[cursor - 1], order[cursor]));
      std::swap(order[cursor - 1], order[cursor]);
      --cursor;
    }
  }
  return worstBytes;
}
// Checks that every adjacent swap touching the last position, while permuting sourceOrder into
// targetOrder, keeps the buffer size unchanged when the two swapped dimensions are aligned to
// kUbAlignBytes. Shapes of rank <= 2 or mismatched ranks are accepted (true). An axis that cannot
// be located yields false.
bool satisfiesStrideAlignedTranspose(ArrayRef<int64_t> shape, ArrayRef<size_t> sourceOrder,
                                     ArrayRef<size_t> targetOrder, int64_t elemBytes) {
  const size_t rank = shape.size();
  if (rank <= 2 || rank != sourceOrder.size() || sourceOrder.size() != targetOrder.size()) return true;

  auto indexOf = [](ArrayRef<size_t> order, size_t axis) -> std::optional<size_t> {
    for (size_t pos = 0; pos < order.size(); ++pos) {
      if (order[pos] == axis) return pos;
    }
    return std::nullopt;
  };
  auto pairFits = [&](size_t firstAxis, size_t secondAxis) -> bool {
    const std::optional<size_t> firstDim = indexOf(sourceOrder, firstAxis);
    const std::optional<size_t> secondDim = indexOf(sourceOrder, secondAxis);
    if (!firstDim.has_value() || !secondDim.has_value()) return false;
    SmallVector<int64_t, 6> alignBytes(shape.size(), 0);
    alignBytes[*firstDim] = kUbAlignBytes;
    alignBytes[*secondDim] = kUbAlignBytes;
    return satisfiesAlignedShapeWithoutExpansion(shape, alignBytes, elemBytes);
  };

  SmallVector<size_t, 6> order(sourceOrder.begin(), sourceOrder.end());
  const size_t lastPos = order.size() - 1;
  for (size_t targetPos = 0; targetPos < targetOrder.size(); ++targetPos) {
    const std::optional<size_t> start = indexOf(order, targetOrder[targetPos]);
    if (!start.has_value()) return false;
    size_t cursor = *start;
    while (cursor > targetPos) {
      if (cursor == lastPos && !pairFits(order[cursor - 1], order[cursor])) return false;
      std::swap(order[cursor - 1], order[cursor]);
      --cursor;
    }
  }
  return true;
}
// True when the peak vector reserve for the given tiles does not exceed getVectorUbBytes(ctx).
bool fitsVectorUb(const NpuBandContext &ctx, ArrayRef<int64_t> axisTiles) {
  const int64_t peakBytes = computeVectorPeakReserveBytes(ctx, axisTiles);
  return peakBytes <= getVectorUbBytes(ctx);
}
// Binary-searches the largest tile in [1, max(upper, 1)] for axis `axisIdx` such that
// fitsVectorUb(ctx, axisTiles) holds with the other entries of axisTiles left untouched.
//   axisTiles - in/out: axisTiles[axisIdx] is overwritten during the search and holds the
//               returned value on exit. axisIdx must be a valid index into axisTiles.
//   upper     - largest tile to consider; values below 1 are treated as 1.
// Returns the chosen tile. If no probed value fits, 1 is returned (and stored) even though it
// was not verified to fit, so callers that need a guarantee must re-check fitsVectorUb.
// NOTE(review): the search assumes that fitting is monotonic in the tile size - confirm.
int64_t findMaxFittingTile(const NpuBandContext &ctx, SmallVectorImpl<int64_t> &axisTiles, size_t axisIdx,
                           int64_t upper) {
  int64_t best = 1;
  int64_t low = 1;
  int64_t high = std::max<int64_t>(upper, 1);
  while (low <= high) {
    // 1 <= low <= high here, so this midpoint cannot overflow the way (low + high) / 2 can.
    const int64_t mid = low + (high - low) / 2;
    axisTiles[axisIdx] = mid;
    if (fitsVectorUb(ctx, axisTiles)) {
      best = mid;
      // Nothing larger is left to probe; stopping here also avoids computing mid + 1 at INT64_MAX.
      if (mid == high) break;
      low = mid + 1;
    } else {
      high = mid - 1;
    }
  }
  axisTiles[axisIdx] = best;
  return best;
}
void fillPrefixSuffixTiles(const NpuBandContext &ctx, ArrayRef<unsigned> prefixOuter, int64_t requestedPointRows,
BandTilePlan &plan) {
for (size_t i = 0; i < prefixOuter.size(); ++i) {
plan.outerTiles[i] = prefixOuter[i];
plan.innerTiles[i] = prefixOuter[i];
}
size_t innermost = ctx.extents.size() - 1;
SmallVector<int64_t, 4> axisTiles(plan.innerTiles.begin(), plan.innerTiles.end());
int64_t maxPointRows = std::min<int64_t>(requestedPointRows, prefixOuter.front());
for (int64_t pointRows = std::max<int64_t>(maxPointRows, 1); pointRows >= 1; --pointRows) {
axisTiles.front() = pointRows;
axisTiles[innermost] = ctx.extents[innermost];
int64_t innerTile = findMaxFittingTile(ctx, axisTiles, innermost, ctx.extents[innermost]);
if (fitsVectorUb(ctx, axisTiles)) {
plan.innerTiles.front() = saturateToTileValue(pointRows);
plan.innerTiles[innermost] = saturateToTileValue(innerTile);
return;
}
}
plan.innerTiles.front() = 1;
plan.innerTiles[innermost] = 1;
}
int64_t computeTransposeCoLiveReserveBytes(const NpuBandContext &ctx, ArrayRef<int64_t> axisTiles) {
if (ctx.transposeCoLiveReserveSets.empty()) return 0;
int64_t peakBytes = 0;
for (const auto &liveSet : ctx.transposeCoLiveReserveSets) {
int64_t liveBytes = 0;
for (const TransposeCoLiveReserve &reserve : liveSet) {
VectorLiveReserve vectorReserve{reserve.axisOrder, reserve.alignDims, reserve.elementBits};
int64_t bytes = computeReserveBytes(ctx, vectorReserve, axisTiles);
liveBytes = (liveBytes > LLONG_MAX - bytes) ? LLONG_MAX : liveBytes + bytes;
}
peakBytes = std::max(peakBytes, liveBytes);
}
return peakBytes;
}
// Returns a target axis order with the same length as the source order.
// When the given target already has that length it is returned as is; otherwise the source order
// with its last two entries exchanged is used (unchanged source order when it has fewer than two).
SmallVector<size_t, 6> normalizeTransposeTargetAxisOrder(ArrayRef<size_t> sourceAxisOrder,
                                                         ArrayRef<size_t> targetAxisOrder) {
  if (targetAxisOrder.size() == sourceAxisOrder.size()) {
    return SmallVector<size_t, 6>(targetAxisOrder.begin(), targetAxisOrder.end());
  }
  SmallVector<size_t, 6> fallback(sourceAxisOrder.begin(), sourceAxisOrder.end());
  const size_t rank = fallback.size();
  if (rank >= 2) std::swap(fallback[rank - 2], fallback[rank - 1]);
  return fallback;
}
// Returns the temporary reserve, in bytes, attributed to one transpose for the given tiles.
//   sourceAxisOrder / targetAxisOrder - band axes in source and target layout order; the target
//       is normalized with normalizeTransposeTargetAxisOrder. An empty source yields 0.
//   elementBits / elementIsF32        - element width and whether the f32 alignment rule applies.
// The result is the maximum of three estimates:
//   1. kHivmAutoMultiBufferFactor * (stride-aligned source bytes + stride-aligned target bytes),
//   2. the allocation peak from computeTransposeAllocReserveBytes,
//   3. (kHivmAutoMultiBufferFactor + 1) * computeStrideAlignedTransposeBytes.
int64_t computeTransposeTempReserveBytes(const NpuBandContext &ctx, ArrayRef<int64_t> axisTiles,
                                         ArrayRef<size_t> sourceAxisOrder, ArrayRef<size_t> targetAxisOrder,
                                         int64_t elementBits, bool elementIsF32) {
  if (sourceAxisOrder.empty()) return 0;
  SmallVector<size_t, 6> normalizedTargetAxisOrder =
    normalizeTransposeTargetAxisOrder(sourceAxisOrder, targetAxisOrder);
  SmallVector<size_t, 6> sourceOrder(sourceAxisOrder.begin(), sourceAxisOrder.end());
  VectorLiveReserve source{sourceOrder, getDefaultStrideAlignDims(sourceAxisOrder), elementBits};
  VectorLiveReserve target{normalizedTargetAxisOrder, getDefaultStrideAlignDims(normalizedTargetAxisOrder),
                           elementBits};
  int64_t sourceBytes = computeReserveBytes(ctx, source, axisTiles);
  int64_t targetBytes = computeReserveBytes(ctx, target, axisTiles);
  // Saturating add before applying the multi-buffer factor.
  int64_t strideReserveBytes = multiplyAndCap(
    kHivmAutoMultiBufferFactor, sourceBytes > LLONG_MAX - targetBytes ? LLONG_MAX : sourceBytes + targetBytes);
  int64_t elemBytes = std::max<int64_t>(ceilDivInt64(std::max<int64_t>(elementBits, 1), kBitsPerByte), 1);

  // Same tile lookup as computeReserveBytes: an axis without an entry in axisTiles counts as
  // tile 1 instead of being read out of bounds.
  SmallVector<int64_t, 6> sourceShape;
  sourceShape.reserve(sourceOrder.size());
  for (size_t axisIdx : sourceOrder) {
    const int64_t tile = axisIdx < axisTiles.size() ? axisTiles[axisIdx] : 1;
    sourceShape.push_back(std::max<int64_t>(tile, 1));
  }

  int64_t allocReserveBytes =
    computeTransposeAllocReserveBytes(sourceShape, sourceOrder, normalizedTargetAxisOrder, elemBytes, elementIsF32);
  int64_t adjacentStrideBytes =
    multiplyAndCap(kHivmAutoMultiBufferFactor + 1,
                   computeStrideAlignedTransposeBytes(sourceShape, sourceOrder, normalizedTargetAxisOrder, elemBytes));
  return std::max(strideReserveBytes, std::max(allocReserveBytes, adjacentStrideBytes));
}
bool satisfiesVectorReserveAlignConstraints(const NpuBandContext &ctx, ArrayRef<int64_t> axisTiles) {
if (std::any_of(
ctx.vectorPeakReserveSets.begin(), ctx.vectorPeakReserveSets.end(), [&](const VectorPeakReserveSet &set) {
return std::any_of(set.liveBuffers.begin(), set.liveBuffers.end(), [&](const VectorLiveReserve &reserve) {
return !satisfiesStrideAlignWithoutExpansion(ctx, reserve, axisTiles);
});
}))
return false;
if (std::any_of(ctx.transposeCoLiveReserveSets.begin(), ctx.transposeCoLiveReserveSets.end(),
[&](const SmallVector<TransposeCoLiveReserve, 4> &liveSet) {
return std::any_of(liveSet.begin(), liveSet.end(), [&](const TransposeCoLiveReserve &reserve) {
VectorLiveReserve vectorReserve{reserve.axisOrder, reserve.alignDims, reserve.elementBits};
return !satisfiesStrideAlignWithoutExpansion(ctx, vectorReserve, axisTiles);
});
}))
return false;
return true;
}
// Checks the alignment constraints of one transpose for the given tiles.
//   sourceAxisOrder / targetAxisOrder - band axes in source and target layout order; the target
//       is normalized with normalizeTransposeTargetAxisOrder. An empty source yields false.
// Returns true only when both satisfiesTransposeAllocAlign and satisfiesStrideAlignedTranspose
// accept the tile shape taken in source order.
bool satisfiesSingleTransposeAlignConstraints(ArrayRef<int64_t> axisTiles, ArrayRef<size_t> sourceAxisOrder,
                                              ArrayRef<size_t> targetAxisOrder, int64_t elementBits,
                                              bool elementIsF32) {
  if (sourceAxisOrder.empty()) return false;
  SmallVector<size_t, 6> normalizedTargetAxisOrder =
    normalizeTransposeTargetAxisOrder(sourceAxisOrder, targetAxisOrder);
  int64_t elemBytes = std::max<int64_t>(ceilDivInt64(std::max<int64_t>(elementBits, 1), kBitsPerByte), 1);

  // Same tile lookup as computeReserveBytes: an axis without an entry in axisTiles counts as
  // tile 1 instead of being read out of bounds.
  SmallVector<int64_t, 6> sourceShape;
  sourceShape.reserve(sourceAxisOrder.size());
  for (size_t axisIdx : sourceAxisOrder) {
    const int64_t tile = axisIdx < axisTiles.size() ? axisTiles[axisIdx] : 1;
    sourceShape.push_back(std::max<int64_t>(tile, 1));
  }

  return satisfiesTransposeAllocAlign(sourceShape, sourceAxisOrder, normalizedTargetAxisOrder, elemBytes,
                                      elementIsF32) &&
         satisfiesStrideAlignedTranspose(sourceShape, sourceAxisOrder, normalizedTargetAxisOrder, elemBytes);
}
bool satisfiesTransposeAlignConstraints(const NpuBandContext &ctx, ArrayRef<int64_t> axisTiles,
ArrayRef<size_t> fallbackSourceAxisOrder) {
if (!satisfiesVectorReserveAlignConstraints(ctx, axisTiles)) return false;
if (ctx.transposeInfos.empty())
return satisfiesSingleTransposeAlignConstraints(axisTiles, fallbackSourceAxisOrder, ctx.transposeTargetAxisOrder,
ctx.transposeElementBits, ctx.transposeElementIsF32);
return llvm::all_of(ctx.transposeInfos, [&](const TransposeVectorInfo &info) {
return satisfiesSingleTransposeAlignConstraints(axisTiles, info.sourceAxisOrder, info.targetAxisOrder,
info.elementBits, info.isF32);
});
}
int64_t getAnotherLastDimTransposeSizeAlignDim(const TransposeVectorInfo &info, size_t dim, bool &isInnerDim) {
isInnerDim = false;
if (info.sourceAxisOrder.size() != info.targetAxisOrder.size() || dim >= info.sourceAxisOrder.size()) return -1;
SmallVector<size_t, 6> transposeDims;
for (size_t i = 0; i < info.sourceAxisOrder.size(); ++i)
if (info.sourceAxisOrder[i] != info.targetAxisOrder[i]) transposeDims.push_back(i);
if (transposeDims.size() == 2) {
if (!llvm::is_contained(transposeDims, info.sourceAxisOrder.size() - 1) || !llvm::is_contained(transposeDims, dim))
return -1;
isInnerDim = dim == info.sourceAxisOrder.size() - 1;
return static_cast<int64_t>(transposeDims[dim == transposeDims[0] ? 1 : 0]);
}
size_t sourceAxis = info.sourceAxisOrder[dim];
SmallVector<size_t, 6> currentOrder(info.sourceAxisOrder.begin(), info.sourceAxisOrder.end());
for (size_t targetPos = 0; targetPos < info.targetAxisOrder.size(); ++targetPos) {
auto it = llvm::find(currentOrder, info.targetAxisOrder[targetPos]);
if (it == currentOrder.end()) return -1;
for (size_t i = static_cast<size_t>(it - currentOrder.begin()); i > targetPos; --i) {
if (i == currentOrder.size() - 1 && (currentOrder[i - 1] == sourceAxis || currentOrder[i] == sourceAxis)) {
isInnerDim = currentOrder[i] == sourceAxis;
size_t anotherAxis = isInnerDim ? currentOrder[i - 1] : currentOrder[i];
auto anotherIt = llvm::find(info.sourceAxisOrder, anotherAxis);
return anotherIt == info.sourceAxisOrder.end()
? -1
: static_cast<int64_t>(anotherIt - info.sourceAxisOrder.begin());
}
std::swap(currentOrder[i - 1], currentOrder[i]);
}
}
return -1;
}
// Scans positions startDim, startDim - 1, ..., 0 and returns the first one at which the source
// axis or the target axis has an extent other than 1. An axis outside ctx.extents counts as
// extent 1. Returns -1 when no such position exists (including startDim < 0).
// startDim must be a valid index into both orders when it is non-negative.
int64_t getLastNonUnitDim(ArrayRef<size_t> sourceOrder, ArrayRef<size_t> targetOrder, const NpuBandContext &ctx,
                          int64_t startDim) {
  auto extentOf = [&ctx](size_t axis) -> int64_t { return axis < ctx.extents.size() ? ctx.extents[axis] : 1; };
  int64_t dim = startDim;
  while (dim >= 0) {
    const size_t pos = static_cast<size_t>(dim);
    const int64_t sourceExtent = extentOf(sourceOrder[pos]);
    const int64_t targetExtent = extentOf(targetOrder[pos]);
    const bool bothUnit = sourceExtent == 1 && targetExtent == 1;
    if (!bothUnit) return dim;
    --dim;
  }
  return -1;
}
// For a transpose that leaves the last position untouched, returns the position found by
// getLastNonUnitDim when scanning down from the highest differing position.
// Returns -1 when the rank is below 2, the ranks mismatch, nothing differs, or the last position
// takes part in the transpose.
int64_t getNonLastTransposeStrideAlignDim(ArrayRef<size_t> sourceOrder, ArrayRef<size_t> targetOrder,
                                          const NpuBandContext &ctx) {
  const size_t rank = sourceOrder.size();
  if (rank < 2 || rank != targetOrder.size()) return -1;
  SmallVector<size_t, 6> movedDims;
  for (size_t pos = 0; pos < rank; ++pos) {
    if (sourceOrder[pos] != targetOrder[pos]) movedDims.push_back(pos);
  }
  if (movedDims.empty()) return -1;
  if (llvm::is_contained(movedDims, rank - 1)) return -1;
  const int64_t highestMoved = static_cast<int64_t>(movedDims.back());
  return getLastNonUnitDim(sourceOrder, targetOrder, ctx, highestMoved);
}
bool isNonLastTransposeStrideAlignAxis(const TransposeVectorInfo &info, size_t axisIdx, const NpuBandContext &ctx) {
if (info.sourceAxisOrder.size() < 2 || info.sourceAxisOrder.size() != info.targetAxisOrder.size()) return false;
SmallVector<size_t, 6> transposeDims;
for (size_t i = 0; i < info.sourceAxisOrder.size(); ++i)
if (info.sourceAxisOrder[i] != info.targetAxisOrder[i]) transposeDims.push_back(i);
auto matchesAxis = [&](ArrayRef<size_t> sourceOrder, ArrayRef<size_t> targetOrder, int64_t dim) {
return dim >= 0 && (sourceOrder[static_cast<size_t>(dim)] == axisIdx ||
targetOrder[static_cast<size_t>(dim)] == axisIdx);
};
if (transposeDims.size() == 2)
return matchesAxis(info.sourceAxisOrder, info.targetAxisOrder,
getNonLastTransposeStrideAlignDim(info.sourceAxisOrder, info.targetAxisOrder, ctx));
if (transposeDims.size() <= 2) return false;
SmallVector<size_t, 6> currentOrder(info.sourceAxisOrder.begin(), info.sourceAxisOrder.end());
for (size_t targetPos = 0; targetPos < info.targetAxisOrder.size(); ++targetPos) {
auto it = llvm::find(currentOrder, info.targetAxisOrder[targetPos]);
if (it == currentOrder.end()) return false;
for (size_t i = static_cast<size_t>(it - currentOrder.begin()); i > targetPos; --i) {
SmallVector<size_t, 6> nextOrder(currentOrder.begin(), currentOrder.end());
std::swap(nextOrder[i - 1], nextOrder[i]);
if (matchesAxis(currentOrder, nextOrder, getNonLastTransposeStrideAlignDim(currentOrder, nextOrder, ctx)))
return true;
currentOrder = std::move(nextOrder);
}
}
return false;
}
// Returns the alignment unit, in elements, required for band axis `axisIdx` across all
// transposes in ctx.transposeInfos (at least 1).
// A transpose contributes max(kUbAlignBytes / elemBytes, 1) when the axis either takes part in a
// last-position swap or is the stride-align axis of a non-last transpose step; contributions are
// combined with std::lcm. Transposes of rank below 2 or not containing the axis are ignored.
int64_t computeAxisAlignSize(size_t axisIdx, const NpuBandContext &ctx) {
  int64_t combinedUnit = 1;
  for (const TransposeVectorInfo &info : ctx.transposeInfos) {
    if (info.sourceAxisOrder.size() < 2) continue;
    auto found = llvm::find(info.sourceAxisOrder, axisIdx);
    if (found == info.sourceAxisOrder.end()) continue;
    const size_t dim = static_cast<size_t>(found - info.sourceAxisOrder.begin());

    bool isInnerDim = false;  // required by the callee, not used here
    const bool inLastDimSwap = getAnotherLastDimTransposeSizeAlignDim(info, dim, isInnerDim) >= 0;
    const bool isStrideAxis = isNonLastTransposeStrideAlignAxis(info, axisIdx, ctx);
    if (!inLastDimSwap && !isStrideAxis) continue;

    // Both reasons ask for the same unit, and lcm is idempotent, so one update suffices.
    const int64_t elemBytes =
      std::max<int64_t>(ceilDivInt64(std::max<int64_t>(info.elementBits, 1), kBitsPerByte), 1);
    const int64_t unitElems = std::max<int64_t>(kUbAlignBytes / elemBytes, 1);
    combinedUnit = std::lcm(combinedUnit, unitElems);
  }
  return std::max<int64_t>(combinedUnit, 1);
}
void applyF32LastDimDoubleAlign(const NpuBandContext &ctx, SmallVectorImpl<int64_t> &axisAlignUnits) {
for (const TransposeVectorInfo &info : ctx.transposeInfos) {
if (!info.isF32) continue;
int64_t elemBytes =
std::max<int64_t>(ceilDivInt64(std::max<int64_t>(info.elementBits, 1), kBitsPerByte), 1);
int64_t doubleUnit = std::max<int64_t>((kUbAlignBytes * 2) / elemBytes, 1);
for (size_t dim = 0; dim < info.sourceAxisOrder.size(); ++dim) {
bool isInnerDim = false;
int64_t anotherDim = getAnotherLastDimTransposeSizeAlignDim(info, dim, isInnerDim);
if (anotherDim < 0 || !isInnerDim) continue;
size_t innerAxis = info.sourceAxisOrder[dim];
size_t anotherAxis = info.sourceAxisOrder[static_cast<size_t>(anotherDim)];
if (innerAxis >= axisAlignUnits.size() || anotherAxis >= axisAlignUnits.size()) continue;
if (axisAlignUnits[innerAxis] % doubleUnit == 0 || axisAlignUnits[anotherAxis] % doubleUnit == 0) continue;
axisAlignUnits[innerAxis] = std::lcm(axisAlignUnits[innerAxis], doubleUnit);
}
}
}
// Builds the per-axis tile vector used to evaluate one transpose search candidate.
//   plan            - supplies the planned inner tile per axis (clamped to at least 1); an axis
//                     without a plan entry uses its saturated full extent.
//   searchAxisOrder - axes being searched; their tiles are taken from searchShape (clamped to 1).
//   searchShape     - candidate tile per search axis, in searchAxisOrder order.
// A non-search axis whose planned tile covers its whole extent is evaluated with tile 1.
// Every entry of searchAxisOrder that has a searchShape value must index into ctx.extents.
static SmallVector<int64_t, 6> buildTransposeSearchAxisTiles(const NpuBandContext &ctx, const BandTilePlan &plan,
                                                             ArrayRef<size_t> searchAxisOrder,
                                                             ArrayRef<int64_t> searchShape) {
  const size_t axisCount = ctx.extents.size();
  SmallVector<int64_t, 6> axisTiles;
  axisTiles.reserve(axisCount);
  for (size_t axis = 0; axis < axisCount; ++axis) {
    const int64_t fullExtent = std::max<int64_t>(ctx.extents[axis], 1);
    int64_t planned;
    if (axis < plan.innerTiles.size()) {
      planned = std::max<int64_t>(static_cast<int64_t>(plan.innerTiles[axis]), 1);
    } else {
      planned = saturateToTileValue(fullExtent);
    }
    const bool isSearchAxis = llvm::is_contained(searchAxisOrder, axis);
    const bool collapse = !isSearchAxis && planned >= fullExtent;
    axisTiles.push_back(collapse ? int64_t{1} : planned);
  }

  const size_t candidateCount = std::min(searchAxisOrder.size(), searchShape.size());
  for (size_t pos = 0; pos < candidateCount; ++pos) {
    axisTiles[searchAxisOrder[pos]] = std::max<int64_t>(searchShape[pos], 1);
  }
  return axisTiles;
}
static int64_t computeTransposeSearchCandidatePeak(const NpuBandContext &ctx, const BandTilePlan &plan,
ArrayRef<size_t> searchAxisOrder, ArrayRef<int64_t> searchShape) {
SmallVector<int64_t, 6> axisTiles = buildTransposeSearchAxisTiles(ctx, plan, searchAxisOrder, searchShape);
if (!satisfiesTransposeAlignConstraints(ctx, axisTiles, searchAxisOrder)) return static_cast<int64_t>(LLONG_MAX);
int64_t vectorBytes = computeVectorPeakReserveBytes(ctx, axisTiles);
int64_t transposeBytes = 0;
if (ctx.transposeInfos.empty()) {
transposeBytes = computeTransposeTempReserveBytes(ctx, axisTiles, searchAxisOrder, ctx.transposeTargetAxisOrder,
ctx.transposeElementBits, ctx.transposeElementIsF32);
} else {
for (const TransposeVectorInfo &info : ctx.transposeInfos) {
transposeBytes =
std::max(transposeBytes, computeTransposeTempReserveBytes(ctx, axisTiles, info.sourceAxisOrder,
info.targetAxisOrder, info.elementBits, info.isF32));
}
}
int64_t coLiveBytes = computeTransposeCoLiveReserveBytes(ctx, axisTiles);
transposeBytes = (transposeBytes > static_cast<int64_t>(LLONG_MAX) - coLiveBytes) ? static_cast<int64_t>(LLONG_MAX)
: transposeBytes + coLiveBytes;
return std::max<int64_t>(vectorBytes, transposeBytes);
}
// Greedily picks aligned tiles for the axes in searchAxisOrder and writes them to both
// plan.outerTiles and plan.innerTiles.
//   ctx             - band context; extents and axesAlignUnits are read per search axis.
//   plan            - in/out tile plan; entries of the search axes are overwritten.
//   searchAxisOrder - band axes to search, outermost first; the last entry is treated as the
//                     innermost search axis. Every entry must index into ctx.extents and the plan.
//   ubLimitBytes    - budget that the candidate peak must not exceed.
// Always returns true: when even the smallest candidate exceeds the budget, a diagnostic is
// written to llvm::errs() and the smallest candidate is still stored in the plan.
bool chooseAlignedTransposeTiles(const NpuBandContext &ctx, BandTilePlan &plan, ArrayRef<size_t> searchAxisOrder,
int64_t ubLimitBytes) {
if (searchAxisOrder.empty()) return true;
const size_t n = searchAxisOrder.size();
SmallVector<int64_t, 6> axisAlignSize(n, 1);
SmallVector<int64_t, 6> axisMaxTileSize(n, 1);
// Alignment unit per search axis; axes without an entry in ctx.axesAlignUnits use 1.
for (size_t i = 0; i < n; ++i) {
size_t axisIdx = searchAxisOrder[i];
int64_t alignSize = axisIdx < ctx.axesAlignUnits.size() ? ctx.axesAlignUnits[axisIdx] : 1;
axisAlignSize[i] = std::max<int64_t>(alignSize, 1);
}
// Largest candidate per search axis: the extent rounded up to the alignment unit.
for (size_t i = 0; i < n; ++i) {
size_t axisIdx = searchAxisOrder[i];
int64_t extent = std::max<int64_t>(ctx.extents[axisIdx], 1);
axisMaxTileSize[i] = alignUpInt64(extent, axisAlignSize[i]);
}
// Reset the search axes in the plan before candidates are evaluated against it.
for (size_t i = 0; i < n; ++i) {
plan.outerTiles[searchAxisOrder[i]] = 1;
plan.innerTiles[searchAxisOrder[i]] = 1;
}
// Initial candidate: one alignment unit on every axis except the last, which starts at its maximum.
SmallVector<int64_t, 6> tiles(n, 1);
for (size_t i = 0; i + 1 < n; ++i) tiles[i] = axisAlignSize[i];
tiles[n - 1] = axisMaxTileSize[n - 1];
auto peakOf = [&](ArrayRef<int64_t> shape) {
return computeTransposeSearchCandidatePeak(ctx, plan, searchAxisOrder, shape);
};
int64_t peak = peakOf(tiles);
if (peak > ubLimitBytes) {
// Over budget: shrink the last axis one alignment unit at a time until it fits or reaches one unit.
while (peak > ubLimitBytes && tiles[n - 1] > axisAlignSize[n - 1]) {
tiles[n - 1] -= axisAlignSize[n - 1];
peak = peakOf(tiles);
}
if (peak > ubLimitBytes) {
// Still over budget at the smallest candidate: report it, but keep going with that candidate.
llvm::errs() << "[TilingStrategy] chooseAlignedTransposeTiles: boundary-1 "
"(min tile exceeds UB)\n";
llvm::errs() << " ubLimitBytes = " << ubLimitBytes << "\n";
llvm::errs() << " peak(minTile) = " << peak << "\n";
for (size_t i = 0; i < n; ++i) {
llvm::errs() << " axis[" << searchAxisOrder[i]
<< "] extent=" << ctx.extents[searchAxisOrder[i]]
<< " axisAlignSize=" << axisAlignSize[i]
<< " axisMaxTileSize=" << axisMaxTileSize[i]
<< " finalTile=" << tiles[i] << "\n";
}
}
} else {
// Within budget: walk the remaining axes from index n - 2 down to 0 and try to grow each one
// to its maximum, keeping the growth only when the candidate still fits.
for (size_t i = n - 1; i-- > 0;) {
int64_t saved = tiles[i];
tiles[i] = axisMaxTileSize[i];
int64_t newPeak = peakOf(tiles);
if (newPeak > ubLimitBytes) {
tiles[i] = saved;
} else {
peak = newPeak;
}
}
}
// Commit the chosen tiles; outer and inner tiles of a search axis are set to the same value.
for (size_t i = 0; i < n; ++i) {
unsigned tile = saturateToTileValue(tiles[i]);
plan.outerTiles[searchAxisOrder[i]] = tile;
plan.innerTiles[searchAxisOrder[i]] = tile;
}
return true;
}
// Chooses the number of point rows for a broadcast suffix plan.
// Returns 3 when the innermost extent is at least 5 * ctx.targetBlocks and the suffix element
// count (from getSliceProduct over [suffixStart, end)) is at least max(tileBudget / 4, 1);
// otherwise returns kSuffixPreservePointRows. ctx.extents must be non-empty.
int64_t getBroadcastSuffixPointRows(const NpuBandContext &ctx, size_t suffixStart, int64_t tileBudget) {
  constexpr int64_t kBroadcastSuffixPointRows = 3;
  constexpr int64_t kBroadcastCloneInnerBlocksFactor = 5;
  constexpr int64_t kBroadcastCloneUbDivisor = 4;
  const int64_t suffixElems = getSliceProduct(ctx.extents, suffixStart, ctx.extents.size());
  const bool wideInnermost = ctx.extents.back() >= ctx.targetBlocks * kBroadcastCloneInnerBlocksFactor;
  const bool broadcastCloneLike =
    wideInnermost && suffixElems >= std::max<int64_t>(tileBudget / kBroadcastCloneUbDivisor, 1);
  if (broadcastCloneLike) return kBroadcastSuffixPointRows;
  return kSuffixPreservePointRows;
}
// Returns the first axis of the innermost run of axes whose combined extent fits in tileBudget.
// Bands with at most one axis yield 0. Otherwise the suffix starts at the last axis and is
// extended outwards while the product stays within the budget; it never grows past axis 1, so
// axis 0 always stays in the prefix.
// Each extent is used as a divisor, so the extents of the inspected axes must be non-zero.
size_t computePureElemSuffixStart(const NpuBandContext &ctx, int64_t tileBudget) {
  const size_t axisCount = ctx.axes.size();
  if (axisCount <= 1) return 0;
  size_t suffixStart = axisCount - 1;
  int64_t suffixElems = ctx.extents.back();
  for (; suffixStart > 1; --suffixStart) {
    const int64_t outerExtent = ctx.extents[suffixStart - 1];
    // Division form of "suffixElems * outerExtent > tileBudget" that avoids the multiplication.
    if (suffixElems > tileBudget / outerExtent) break;
    suffixElems = multiplyAndCap(suffixElems, outerExtent);
  }
  return suffixStart;
}
// Fills `plan` for a pure element-wise band.
//   ctx         - band context (axes, extents, targetBlocks).
//   suffixStart - first axis of the innermost suffix; axes before it form the prefix.
//                 For bands with more than one axis it must be at least 1, because the first
//                 prefix tile is read unconditionally.
//   tileBudget  - element budget for one tile.
//   plan        - out: initialised by initWholeBandPlan, then selected entries are overwritten.
// NOTE(review): suffixElems is used as a divisor, so a zero-sized suffix would divide by zero -
// confirm callers exclude it.
void buildPureElemPlan(const NpuBandContext &ctx, size_t suffixStart, int64_t tileBudget, BandTilePlan &plan) {
initWholeBandPlan(ctx, plan);
if (ctx.axes.size() == 1) {
// Single axis: cap the inner tile by what fits the vector UB, then split the extent over at
// most ctx.targetBlocks outer blocks.
int64_t extent = ctx.extents.front();
SmallVector<int64_t, 4> axisTiles{extent};
int64_t innerCap = findMaxFittingTile(ctx, axisTiles, 0, extent);
int64_t parallelBlocks = std::min<int64_t>(ctx.targetBlocks, ceilDivInt64(extent, innerCap));
int64_t outerTile = ceilDivInt64(extent, parallelBlocks);
plan.outerTiles.front() = saturateToTileValue(outerTile);
plan.innerTiles.front() = saturateToTileValue(std::min<int64_t>(outerTile, innerCap));
return;
}
// Multi-axis: getSliceProduct is applied to the suffix range and to the prefix range
// (presumably the product of the extents in each range - verify against its definition).
int64_t suffixElems = getSliceProduct(ctx.extents, suffixStart, ctx.extents.size());
int64_t prefixRows = getSliceProduct(ctx.extents, 0, suffixStart);
int64_t pointRowsCap = std::max<int64_t>(tileBudget / suffixElems, 1);
int64_t parallelBlocks = std::min<int64_t>(ctx.targetBlocks, ceilDivInt64(prefixRows, pointRowsCap));
SmallVector<unsigned, 4> prefixOuterTiles =
assignPrefixOuterTiles(ArrayRef<int64_t>(ctx.extents).take_front(suffixStart), parallelBlocks);
// Copy the prefix tiles into the plan; the product of all prefix tiles except the first is
// kept to scale down the first axis' inner tile below.
int64_t preservedPrefixInner = 1;
for (size_t i = 0; i < prefixOuterTiles.size(); ++i) {
plan.outerTiles[i] = prefixOuterTiles[i];
plan.innerTiles[i] = prefixOuterTiles[i];
if (i > 0) preservedPrefixInner = multiplyAndCap(preservedPrefixInner, prefixOuterTiles[i]);
}
int64_t firstInner = std::max<int64_t>(tileBudget / suffixElems, 1);
firstInner = std::max<int64_t>(firstInner / preservedPrefixInner, 1);
// Snapshot of the inner tiles taken before the first axis' inner tile is chosen.
SmallVector<int64_t, 4> axisTiles(plan.innerTiles.begin(), plan.innerTiles.end());
firstInner = std::min<int64_t>(firstInner, prefixOuterTiles.front());
plan.innerTiles.front() = saturateToTileValue(findMaxFittingTile(ctx, axisTiles, 0, firstInner));
// findMaxFittingTile can return 1 without it fitting; in that case pin the first axis to 1 and
// shrink the innermost axis' inner tile instead.
if (!fitsVectorUb(ctx, axisTiles)) {
axisTiles.front() = 1;
plan.innerTiles.front() = 1;
size_t innermost = ctx.extents.size() - 1;
plan.innerTiles[innermost] = saturateToTileValue(findMaxFittingTile(ctx, axisTiles, innermost, ctx.extents.back()));
}
}
bool tryBuildReductionSuffixPlan(const NpuBandContext &ctx, BandTilePlan &plan) {
if (ctx.hasDynamicAxis || !ctx.hasReduction || !ctx.lastAxisIsReduction || ctx.axes.size() < 2 ||
ctx.graphTemplate == GraphTemplate::TRANSPOSE_OP) {
return false;
}
size_t suffixStart = computeReductionSuffixStart(ctx);
if (suffixStart == 0) return false;
initWholeBandPlan(ctx, plan);
SmallVector<unsigned, 4> prefixOuterTiles =
assignPrefixOuterTiles(ArrayRef<int64_t>(ctx.extents).take_front(suffixStart), ctx.targetBlocks);
for (size_t i = 0; i < prefixOuterTiles.size(); ++i) {
plan.outerTiles[i] = prefixOuterTiles[i];
plan.innerTiles[i] = prefixOuterTiles[i];
}
size_t innermost = ctx.extents.size() - 1;
SmallVector<int64_t, 4> axisTiles(plan.innerTiles.begin(), plan.innerTiles.end());
for (size_t i = 0; i < suffixStart; ++i) axisTiles[i] = 1;
unsigned innerTile =
saturateToTileValue(findMaxFittingTile(ctx, axisTiles, innermost, ctx.extents[innermost]));
plan.outerTiles[innermost] = innerTile;
plan.innerTiles[innermost] = innerTile;
return true;
}
// Tries to build a plan for a static band that contains at least one transpose axis.
// Returns false (plan untouched) when the band has a dynamic axis, fewer than two axes, or no
// axis for which isTransposeAxis holds.
// Otherwise:
//  - the plan is initialised with initWholeBandPlan,
//  - the first axis is tiled to ceil(extent / ctx.targetBlocks),
//  - all axes from the outermost transpose axis to the last axis are tiled by
//    chooseAlignedTransposeTiles against a copy of ctx in which the innermost axis' alignment
//    unit additionally covers kUbAlignBytes worth of transpose elements.
bool tryBuildTransposePlan(const NpuBandContext &ctx, BandTilePlan &plan) {
  if (ctx.hasDynamicAxis || ctx.axes.size() < 2) {
    return false;
  }
  int64_t outermostTransposeIdx = -1;
  for (size_t i = 0; i < ctx.axes.size(); ++i) {
    if (isTransposeAxis(ctx.axes[i])) {
      outermostTransposeIdx = static_cast<int64_t>(i);
      break;
    }
  }
  if (outermostTransposeIdx < 0) {
    return false;
  }
  NpuBandContext transposeCtx = ctx;
  int64_t elemBytes = std::max<int64_t>(ceilDivInt64(std::max<int64_t>(transposeCtx.transposeElementBits, 1),
                                                     kBitsPerByte),
                                        1);
  // chooseAlignedTransposeTiles looks alignment units up by axis index and tolerates a short
  // vector. Pad with the neutral unit 1 so the innermost axis always has an entry, instead of
  // calling back() on a possibly empty (or differently sized) vector.
  const size_t innermost = ctx.axes.size() - 1;
  if (transposeCtx.axesAlignUnits.size() <= innermost) transposeCtx.axesAlignUnits.resize(innermost + 1, 1);
  transposeCtx.axesAlignUnits[innermost] =
    std::lcm(transposeCtx.axesAlignUnits[innermost], std::max<int64_t>(kUbAlignBytes / elemBytes, 1));
  initWholeBandPlan(ctx, plan);
  unsigned firstAxisTile = saturateToTileValue(ceilDivInt64(ctx.extents.front(), ctx.targetBlocks));
  plan.outerTiles.front() = firstAxisTile;
  plan.innerTiles.front() = firstAxisTile;
  int64_t ubLimitBytes = getVectorUbBytes(ctx);
  SmallVector<size_t, 6> searchAxisOrder;
  for (size_t i = static_cast<size_t>(outermostTransposeIdx); i < ctx.axes.size(); ++i)
    searchAxisOrder.push_back(i);
  // The helper currently always reports success, so its result is not inspected.
  chooseAlignedTransposeTiles(transposeCtx, plan, searchAxisOrder, ubLimitBytes);
  return true;
}
bool tryBuildBroadcastSuffixPlan(const NpuBandContext &ctx, BandTilePlan &plan) {
if (ctx.hasDynamicAxis || ctx.hasReduction || ctx.graphTemplate != GraphTemplate::BROADCAST_OP ||
ctx.axes.size() < 2) {
return false;
}
int64_t broadcastTileBudget = ubCapacityForVectorTile(ctx);
size_t suffixStart = computeBroadcastSuffixStart(ctx, broadcastTileBudget);
if (suffixStart == 0 || suffixStart >= ctx.axes.size() - 1) return false;
initWholeBandPlan(ctx, plan);
ArrayRef<int64_t> prefixExtents(ctx.extents);
SmallVector<unsigned, 4> prefixOuterTiles =
assignPrefixOuterTiles(prefixExtents.take_front(suffixStart), ctx.targetBlocks);
fillPrefixSuffixTiles(ctx, prefixOuterTiles, getBroadcastSuffixPointRows(ctx, suffixStart, broadcastTileBudget),
plan);
SmallVector<int64_t, 4> axisTiles(plan.innerTiles.begin(), plan.innerTiles.end());
for (int64_t i = static_cast<int64_t>(ctx.axes.size()) - 1; i >= 0; --i) {
size_t idx = static_cast<size_t>(i);
Operation *loopOp = ctx.axes[idx] ? ctx.axes[idx]->getLoopOperation() : nullptr;
if (!loopOp || (!loopOp->hasAttr(kBroadcastLoopAttr) && !loopOp->hasAttr(kNotInnerDimensionBroadcastLoopAttr))) {
break;
}
int64_t extent = ctx.extents[idx];
unsigned tile = saturateToTileValue(findMaxFittingTile(ctx, axisTiles, idx, extent));
plan.outerTiles[idx] = tile;
plan.innerTiles[idx] = tile;
}
return true;
}
// Builds the tile plan for a static, reduction-free PURE_ELEM band.
//
// Returns false without touching `plan` when the band has a dynamic axis, has a reduction, is empty, or is not
// PURE_ELEM. Otherwise the budget from ubCapacityForVectorTile is rounded down to a multiple of the alignment
// element count (only when it is at least that large) and handed to computePureElemSuffixStart and
// buildPureElemPlan; true is returned.
bool tryBuildElementwisePlan(const NpuBandContext &ctx, BandTilePlan &plan) {
  if (ctx.hasDynamicAxis || ctx.hasReduction || ctx.axes.empty() || ctx.graphTemplate != GraphTemplate::PURE_ELEM) {
    return false;
  }
  int64_t tileBudget = ubCapacityForVectorTile(ctx);
  // Guard the divisor: a zero bit width would otherwise divide by zero. The transpose path clamps its element
  // bits to at least 1 in the same way.
  const int64_t typeBits = std::max<int64_t>(static_cast<int64_t>(ctx.smallestTypeBits), 1);
  int64_t alignElems = std::max<int64_t>(kUbAlignBits / typeBits, 1);
  if (tileBudget >= alignElems) {
    tileBudget = (tileBudget / alignElems) * alignElems;
  }
  size_t suffixStart = computePureElemSuffixStart(ctx, tileBudget);
  buildPureElemPlan(ctx, suffixStart, tileBudget, plan);
  return true;
}
// Tells whether the axis at `axisIdx` may be picked by adjustParallelAxisOuterTile.
//
// An axis is rejected when the index is out of range, when it is the last axis of a band with more than one
// axis, when it is null, a reduction axis, a dynamic axis, lacks constant bounds, or has an extent of at most 1,
// and when its loop operation carries kNotInnerDimensionBroadcastLoopAttr.
bool isParallelCandidateAxis(const NpuBandContext &ctx, size_t axisIdx) {
  const size_t numAxes = ctx.axes.size();
  if (axisIdx >= numAxes) {
    return false;
  }
  const bool isLastOfSeveral = numAxes > 1 && axisIdx == numAxes - 1;
  if (isLastOfSeveral) {
    return false;
  }
  const AxisPtr &axis = ctx.axes[axisIdx];
  if (!axis) {
    return false;
  }
  if (isReductionAxis(axis) || isDynamicAxis(axis)) {
    return false;
  }
  if (!axis->hasConstantBounds() || ctx.extents[axisIdx] <= 1) {
    return false;
  }
  Operation *loopOp = axis->getLoopOperation();
  if (loopOp == nullptr) {
    return true;
  }
  return !loopOp->hasAttr(kNotInnerDimensionBroadcastLoopAttr);
}
// Returns ceil(extent / targetBlocks), clamped to at least 1 and rounded up to the axis alignment unit.
// The extent is clamped to at least 1; a missing or non-positive alignment unit counts as 1.
int64_t getParallelOuterTile(const NpuBandContext &ctx, size_t axisIdx) {
  int64_t alignUnit = 1;
  if (axisIdx < ctx.axesAlignUnits.size()) {
    alignUnit = std::max<int64_t>(ctx.axesAlignUnits[axisIdx], 1);
  }
  const int64_t clampedExtent = std::max<int64_t>(ctx.extents[axisIdx], 1);
  const int64_t perBlock = std::max<int64_t>(ceilDivInt64(clampedExtent, ctx.targetBlocks), 1);
  return alignUpInt64(perBlock, alignUnit);
}
void adjustParallelAxisOuterTile(const NpuBandContext &ctx, BandTilePlan &plan) {
if (ctx.hasDynamicAxis || plan.outerTiles.size() < ctx.axes.size() || plan.innerTiles.size() < ctx.axes.size())
return;
int bestAxis = -1;
int64_t bestWork = 0;
int64_t bestDistance = LLONG_MAX;
SmallVector<int64_t, 6> axisTiles(ctx.axes.size(), 1);
SmallVector<int64_t, 6> axisWorks(ctx.axes.size(), 0);
for (size_t i = 0; i < ctx.axes.size(); ++i) {
int64_t alignUnit = i < ctx.axesAlignUnits.size() ? std::max<int64_t>(ctx.axesAlignUnits[i], 1) : 1;
if (alignUnit > 1) {
plan.outerTiles[i] =
saturateToTileValue(alignUpInt64(std::max<int64_t>(plan.outerTiles[i], 1), alignUnit));
plan.innerTiles[i] =
saturateToTileValue(alignUpInt64(std::max<int64_t>(plan.innerTiles[i], 1), alignUnit));
}
if (!isParallelCandidateAxis(ctx, i)) continue;
axisTiles[i] = getParallelOuterTile(ctx, i);
axisWorks[i] = ceilDivInt64(std::max<int64_t>(ctx.extents[i], 1), axisTiles[i]);
}
for (size_t i = 0; i < ctx.axes.size(); ++i) {
if (axisWorks[i] <= 0) continue;
int64_t distance =
axisWorks[i] > ctx.targetBlocks ? axisWorks[i] - ctx.targetBlocks : ctx.targetBlocks - axisWorks[i];
if (distance < bestDistance || (distance == bestDistance && axisWorks[i] > bestWork)) {
bestAxis = static_cast<int>(i);
bestWork = axisWorks[i];
bestDistance = distance;
}
}
if (bestAxis < 0) return;
size_t axisIdx = static_cast<size_t>(bestAxis);
int64_t outerTile = axisTiles[axisIdx];
int64_t alignUnit =
axisIdx < ctx.axesAlignUnits.size() ? std::max<int64_t>(ctx.axesAlignUnits[axisIdx], 1) : 1;
int64_t innerTile = std::min<int64_t>(std::max<int64_t>(plan.innerTiles[axisIdx], 1), outerTile);
innerTile = std::min<int64_t>(alignUpInt64(innerTile, alignUnit), outerTile);
plan.outerTiles[axisIdx] = saturateToTileValue(outerTile);
plan.innerTiles[axisIdx] = saturateToTileValue(innerTile);
}
void applyBandTilePlan(const SmallVector<AxisPtr, 4> &bandAxes, const BandTilePlan &plan, size_t &maxLevelToTile) {
constexpr size_t kTileLevels = 2;
for (size_t i = 0; i < bandAxes.size(); ++i) {
const auto &axis = bandAxes[i];
while (axis->configs[kTileCfg].size() < kTileLevels) {
axis->doExtraTile();
}
if (axis->configs[kTileCfg].size() > kTileLevels) {
axis->configs[kTileCfg].resize(kTileLevels);
}
for (size_t level = 0; level < kTileLevels; ++level) {
auto tileConfig = axis->tryGetConfig(static_cast<int>(level), kTileCfg);
if (tileConfig == nullptr) {
continue;
}
tileConfig->constraints.clear();
unsigned tileValue = (level == 0) ? plan.outerTiles[i] : plan.innerTiles[i];
if (isReductionAxis(axis) && !isTransposeAxis(axis) && axis->hasConstantBounds()) {
int64_t fullExtent = axis->getConstantUpperBound() - axis->getConstantLowerBound();
if (static_cast<int64_t>(tileValue) >= fullExtent) {
tileValue = saturateToTileValue(std::max<int64_t>(fullExtent, 1));
}
}
tileConfig->value = static_cast<int>(tileValue);
axis->tryAddConstraint(static_cast<int>(level), Constraint({static_cast<int>(tileValue)}));
}
}
maxLevelToTile = std::max(maxLevelToTile, kTileLevels);
}
void applyDynamicFallbackAxisTiling(const AxisPtr axis, bool isFullTileAxis) {
for (size_t i = axis->configs[kTileCfg].size(); i < 2; ++i) {
axis->doExtraTile();
}
auto tileConfig0 = axis->tryGetConfig(0, kTileCfg);
auto tileConfig1 = axis->tryGetConfig(1, kTileCfg);
if (isFullTileAxis) {
if (tileConfig0 != nullptr) {
tileConfig0->value = 1;
axis->tryAddConstraint(0, Constraint({1}));
}
if (tileConfig1 != nullptr) {
tileConfig1->value = 1;
axis->tryAddConstraint(1, Constraint({1}));
}
return;
}
if (tileConfig0 != nullptr) {
int value = static_cast<int>(UINT_MAX);
tileConfig0->value = value;
axis->tryAddConstraint(0, Constraint({value}));
}
if (tileConfig1 != nullptr) {
tileConfig1->value = static_cast<int>(kNpuMinInnerTileSize);
axis->tryAddConstraint(1, Constraint({static_cast<int>(kNpuMinInnerTileSize)}));
}
}
// Fallback per-axis tiling used when no band plan applies or when explicit tile sizes were configured.
//
// axis            axis whose kTileCfg levels are written.
// tileSizes       configured tile sizes, outermost level first; may be empty.
// innerTileSize   inner tile used when `tileSizes` is empty.
// blockNumber     forwarded to getOuterTileSize when `tileSizes` is empty.
// maxLevelToTile  in/out: raised to the number of levels written for this axis.
// isFullTileAxis  when set, both levels use the full extent (or 1 for a non-positive extent).
//
// Axes without constant bounds are delegated to applyDynamicFallbackAxisTiling. For static axes each level keeps
// its requested tile only if the size left over from the previous level is large enough (the last level needs
// twice its tile); otherwise the level takes the remaining size.
static void applyFallbackAxisTiling(const AxisPtr axis, const SmallVector<unsigned, 4> &tileSizes,
                                    unsigned innerTileSize, unsigned blockNumber, size_t &maxLevelToTile,
                                    bool isFullTileAxis) {
  bool hasStaticBounds = axis->hasConstantBounds();
  bool hasDynamicUpperBound = !axis->hasConstantUpperBound();
  if (hasDynamicUpperBound || !hasStaticBounds) {
    applyDynamicFallbackAxisTiling(axis, isFullTileAxis);
    maxLevelToTile = std::max(maxLevelToTile, static_cast<size_t>(2));
    return;
  }

  int64_t lowerBound = axis->getConstantLowerBound();
  int64_t upperBound = axis->getConstantUpperBound();
  int64_t extent = upperBound - lowerBound;

  SmallVector<unsigned, 4> currentTileSizes = tileSizes;
  if (isFullTileAxis) {
    unsigned fullTile = 1;
    if (extent > 0) {
      fullTile = extent > static_cast<int64_t>(UINT_MAX) ? UINT_MAX : static_cast<unsigned>(extent);
    }
    currentTileSizes = {fullTile, fullTile};
  }
  if (currentTileSizes.empty()) {
    currentTileSizes = {std::max(getOuterTileSize(axis, blockNumber), innerTileSize), innerTileSize};
  }

  SmallVector<unsigned, 4> usedTileSizes;
  // A non-positive extent would otherwise be cast to unsigned below and yield a zero or wrapped-around tile.
  // Treat it as 1, as the full-tile branch above already does.
  int64_t currentSize = std::max<int64_t>(extent, 1);
  for (size_t i = 0; i < currentTileSizes.size(); ++i) {
    // A configured tile of 0 is not a usable tile size; use 1 instead.
    int64_t tileSize = std::max<int64_t>(static_cast<int64_t>(currentTileSizes[i]), 1);
    int64_t minRequired = (i == currentTileSizes.size() - 1) ? tileSize * 2 : tileSize;
    if (currentSize >= minRequired) {
      usedTileSizes.push_back(static_cast<unsigned>(tileSize));
      currentSize = tileSize;
    } else {
      usedTileSizes.push_back(static_cast<unsigned>(currentSize));
    }
  }

  // Make the axis carry exactly numLevels tile configs: drop surplus levels, then add the missing ones.
  size_t numLevels = usedTileSizes.size();
  size_t currentTileLevel = axis->configs[kTileCfg].size();
  if (currentTileLevel > numLevels) {
    auto &cfgs = axis->configs[kTileCfg];
    cfgs.resize(numLevels);
    currentTileLevel = numLevels;
  }
  size_t levelsToAdd = (numLevels > currentTileLevel) ? (numLevels - currentTileLevel) : 0;
  for (size_t level = 0; level < levelsToAdd; ++level) {
    axis->doExtraTile();
  }
  maxLevelToTile = std::max(maxLevelToTile, numLevels);

  for (size_t level = 0; level < numLevels; ++level) {
    auto tileConfig = axis->tryGetConfig(static_cast<int>(level), kTileCfg);
    if (tileConfig != nullptr) {
      tileConfig->value = static_cast<int>(usedTileSizes[level]);
      axis->tryAddConstraint(static_cast<int>(level), Constraint({static_cast<int>(usedTileSizes[level])}));
    }
  }
}
}
// Returns every non-null axis reported by rootAxis->forEachAxisTopDown, in visiting order.
// The caller must pass a graph with a non-null rootAxis.
SmallVector<AxisPtr> NpuDefaultTileStrategy::collectAxes(const NpuModelGraphPtr npuGraph) {
  SmallVector<AxisPtr> collected;
  auto keepIfPresent = [&collected](const AxisPtr axis) {
    if (!axis) {
      return;
    }
    collected.push_back(axis);
  };
  npuGraph->rootAxis->forEachAxisTopDown(keepIfPresent);
  return collected;
}
// Reads the "npu.multiTileSizes" integer array, preferring the function attribute over the graph's global
// configs (the latter is consulted only when the former produced no value). Non-integer array entries are
// ignored. The returned sizes are made non-increasing: each entry is capped by its predecessor.
SmallVector<unsigned, 4> NpuDefaultTileStrategy::parseTileSizesConfig(const NpuModelGraphPtr npuGraph) {
  constexpr const char *kMultiTileSizesKey = "npu.multiTileSizes";
  SmallVector<unsigned, 4> tileSizes;

  auto appendTileSizes = [&tileSizes](Attribute source) {
    auto arrayAttr = dyn_cast_or_null<ArrayAttr>(source);
    if (!arrayAttr) {
      return;
    }
    for (auto element : arrayAttr) {
      auto intAttr = dyn_cast<IntegerAttr>(element);
      if (!intAttr) {
        continue;
      }
      tileSizes.push_back(static_cast<unsigned>(intAttr.getInt()));
    }
  };

  if (npuGraph->funcOp && npuGraph->funcOp->hasAttr(kMultiTileSizesKey)) {
    appendTileSizes(npuGraph->funcOp->getAttr(kMultiTileSizesKey));
  }
  if (tileSizes.empty()) {
    auto configIt = npuGraph->globalConfigs.find(kMultiTileSizesKey);
    if (configIt != npuGraph->globalConfigs.end()) {
      appendTileSizes(configIt->second);
    }
  }

  for (size_t level = 1; level < tileSizes.size(); ++level) {
    tileSizes[level] = std::min(tileSizes[level], tileSizes[level - 1]);
  }
  return tileSizes;
}
void NpuDefaultTileStrategy::applyTilingToAxes(const NpuModelGraphPtr npuGraph, const SmallVector<AxisPtr> &axes,
const SmallVector<unsigned, 4> &tileSizes) {
size_t maxLevelToTile = 1;
constexpr unsigned innerTileSize = kNpuMinInnerTileSize;
const unsigned blockNumber =
(npuGraph && npuGraph->coreNum > 0) ? static_cast<unsigned>(npuGraph->coreNum) : kNpuTargetBlocks;
if (!tileSizes.empty()) {
for (const auto &axis : axes) {
applyFallbackAxisTiling(axis, tileSizes, innerTileSize, blockNumber, maxLevelToTile, isReductionAxis(axis));
}
npuGraph->levelToTile = std::max(npuGraph->levelToTile, maxLevelToTile);
return;
}
for (const auto &[bandIdx, bandAxes] : groupAxesByBand(axes)) {
NpuBandContext bandCtx = buildNpuBandContext(npuGraph, bandIdx, bandAxes);
BandTilePlan bandPlan;
bool matched = tryBuildTransposePlan(bandCtx, bandPlan) || tryBuildReductionSuffixPlan(bandCtx, bandPlan) ||
tryBuildBroadcastSuffixPlan(bandCtx, bandPlan) || tryBuildElementwisePlan(bandCtx, bandPlan);
if (matched) {
adjustParallelAxisOuterTile(bandCtx, bandPlan);
applyBandTilePlan(bandAxes, bandPlan, maxLevelToTile);
continue;
}
for (const auto &axis : bandAxes) {
applyFallbackAxisTiling(axis, tileSizes, innerTileSize, blockNumber, maxLevelToTile, isReductionAxis(axis));
}
}
npuGraph->levelToTile = std::max(npuGraph->levelToTile, maxLevelToTile);
}
// Entry point of the default NPU tiling strategy: collects the axes, reads the optional tile-size override,
// applies the tiling, and finally removes the "npu.multiTileSizes" attribute from the function if it is present.
// A null graph or a graph without a root axis is ignored.
void NpuDefaultTileStrategy::AddNpuConstraint(NpuModelGraphPtr npuGraph) {
  if (!npuGraph || !npuGraph->rootAxis) {
    return;
  }
  const SmallVector<AxisPtr> allAxes = collectAxes(npuGraph);
  const SmallVector<unsigned, 4> configuredTiles = parseTileSizesConfig(npuGraph);
  applyTilingToAxes(npuGraph, allAxes, configuredTiles);

  const bool hasOverrideAttr = npuGraph->funcOp && npuGraph->funcOp->hasAttr("npu.multiTileSizes");
  if (hasOverrideAttr) {
    (void)npuGraph->funcOp->removeAttr("npu.multiTileSizes");
  }
}
void TilingStrategyManager::processOn(const ModelGraphPtr modelGraph) {
modelGraph->name = modelGraph->name + "_AfterStrategy";
for (auto strategy : this->strategies_) {
strategy->AddConstraint(modelGraph);
}
}
void TilingStrategyManager::processOn(const GpuModelGraphPtr gpuGraph) {
gpuGraph->name = gpuGraph->name + "_AfterStrategy";
for (auto strategy : this->strategies_) {
strategy->AddGpuConstraint(gpuGraph);
}
}
void TilingStrategyManager::processOn(const CpuModelGraphPtr cpuGraph) {
cpuGraph->name = cpuGraph->name + "_AfterStrategy";
for (auto strategy : this->strategies_) {
strategy->AddCpuConstraint(cpuGraph);
}
}
void TilingStrategyManager::processOn(const NpuModelGraphPtr npuGraph) {
npuGraph->name = npuGraph->name + "_AfterStrategy";
for (auto strategy : this->strategies_) {
strategy->AddNpuConstraint(npuGraph);
}
}
int64_t VectorizationStrategy::findConsecutiveStaticEnd(const SmallVector<AxisPtr> &axes) {
int64_t numAxes = static_cast<int64_t>(axes.size());
for (int64_t dimIdx = numAxes - 1; dimIdx >= 0; --dimIdx) {
auto axis = axes[static_cast<size_t>(dimIdx)];
if (!axis) {
return dimIdx;
}
if (axis->axisType.find(mlir::autotiling::Axis::AxisLabel::kDynamic) != axis->axisType.end()) {
return dimIdx;
}
}
return -1;
}
// Tiling of a dynamic axis at level `pos`. With a remaining budget above 1 the level gets the range constraint
// Constraint(1, ubRemainingNum, 1) and its value is set to -1; otherwise the value is set to 1.
// `ubRemainingNum` is only read here.
void VectorizationStrategy::handleDynamicAxisForVectorization(const AxisPtr &axis, int pos, int64_t &ubRemainingNum) {
  const bool hasBudget = ubRemainingNum > 1;
  if (hasBudget) {
    axis->tryAddConstraint(pos, Constraint(1, static_cast<int>(ubRemainingNum), 1));
  }
  auto tileConfig = axis->tryGetConfig(pos, kTileCfg);
  if (tileConfig == nullptr) {
    return;
  }
  tileConfig->value = hasBudget ? -1 : 1;
}
// Tiling of a static axis inside the trailing static run. The tile is min(dimSize, ubRemainingNum) when the
// budget exceeds 1 and 1 otherwise, never below 1. The tile is stored in the level-`pos` config, an axis flagged
// isInnerMost is labelled kVectorization, and the budget is divided by the tile (never dropping below 1).
void VectorizationStrategy::handleConsecutiveStaticAxis(const AxisPtr &axis, int pos, int64_t dimSize,
                                                        int64_t &ubRemainingNum) {
  const int64_t wantedTile = ubRemainingNum > 1 ? std::min(dimSize, ubRemainingNum) : 1;
  const int64_t tileSize = std::max<int64_t>(wantedTile, 1);

  if (auto tileConfig = axis->tryGetConfig(pos, kTileCfg); tileConfig != nullptr) {
    tileConfig->value = static_cast<int>(tileSize);
  }
  if (axis->isInnerMost) {
    (void)axis->axisType.insert(mlir::autotiling::Axis::AxisLabel::kVectorization);
  }
  // tileSize is at least 1 here, so the division is always well defined.
  ubRemainingNum = std::max<int64_t>(ubRemainingNum / tileSize, 1);
}
// Tiling of a static axis outside the trailing static run. When both the budget and the axis size exceed 1, the
// level gets the range constraint Constraint(1, min(budget, dimSize), 1), its value is set to -1, and the budget
// is divided by that upper bound (never dropping below 1). Otherwise the value is set to 1.
void VectorizationStrategy::handleNonConsecutiveStaticAxis(const AxisPtr &axis, int pos, int64_t dimSize,
                                                           int64_t &ubRemainingNum) {
  if (ubRemainingNum <= 1 || dimSize <= 1) {
    if (auto tileConfig = axis->tryGetConfig(pos, kTileCfg); tileConfig != nullptr) {
      tileConfig->value = 1;
    }
    return;
  }

  const int64_t tileUpperBound = std::min(ubRemainingNum, dimSize);
  axis->tryAddConstraint(pos, Constraint(1, static_cast<int>(tileUpperBound), 1));
  if (auto tileConfig = axis->tryGetConfig(pos, kTileCfg); tileConfig != nullptr) {
    tileConfig->value = -1;
  }
  ubRemainingNum = std::max<int64_t>(ubRemainingNum / tileUpperBound, 1);
}
void VectorizationStrategy::applyVectorizationTiling(const SmallVector<AxisPtr> &axes, int64_t ubAvailableNum,
int pos) {
int64_t numAxes = static_cast<int64_t>(axes.size());
int64_t ubRemainingNum = std::max<int64_t>(ubAvailableNum, 1);
int64_t consecutiveStaticEnd = findConsecutiveStaticEnd(axes);
for (int64_t dimIdx = numAxes - 1; dimIdx >= 0; --dimIdx) {
auto axis = axes[static_cast<size_t>(dimIdx)];
if (!axis) {
continue;
}
if (pos) {
axis->doExtraTile();
}
bool isDynamic = axis->axisType.find(mlir::autotiling::Axis::AxisLabel::kDynamic) != axis->axisType.end();
if (isDynamic) {
handleDynamicAxisForVectorization(axis, pos, ubRemainingNum);
continue;
}
int64_t dimSize = axis->range.second > 0 ? axis->range.second : 1;
bool isInConsecutiveStatic = (consecutiveStaticEnd == -1) || (dimIdx >= consecutiveStaticEnd);
if (isInConsecutiveStatic) {
handleConsecutiveStaticAxis(axis, pos, dimSize, ubRemainingNum);
} else {
handleNonConsecutiveStaticAxis(axis, pos, dimSize, ubRemainingNum);
}
}
}
// Adds the vectorization tile level (level 1) to every axis of the graph and increments npuGraph->tileNum.
//
// The per-buffer budget is ubSize (in bits) divided by the smallest element width and by maxBufferCnt, rounded
// down to a whole number of kUBAlignSizeInBytes blocks and expressed in elements of the smallest type.
// A null graph, a graph without a root axis, or a graph without axes is left untouched.
void VectorizationStrategy::AddNpuConstraint(NpuModelGraphPtr npuGraph) {
  // Same guard as NpuDefaultTileStrategy::AddNpuConstraint: avoid dereferencing a missing graph or root axis.
  if (npuGraph == nullptr || npuGraph->rootAxis == nullptr) {
    return;
  }
  SmallVector<AxisPtr> axes;
  // Null axes are kept on purpose: applyVectorizationTiling and findConsecutiveStaticEnd handle them.
  npuGraph->rootAxis->forEachAxisTopDown([&axes](const AxisPtr a) { axes.push_back(a); });
  if (axes.empty()) {
    return;
  }

  // Both values are divisors below; clamp them to 1 so an unset (zero) field cannot cause a division by zero.
  const int64_t maxBufferCnt = std::max<int64_t>(static_cast<int64_t>(npuGraph->maxBufferCnt), 1);
  const int64_t smallestTypeBits = std::max<int64_t>(static_cast<int64_t>(npuGraph->smallestTypeBits), 1);
  const int64_t ubSizeInBytes = npuGraph->ubSize;
  const int64_t ubMaxSizeInBits = ubSizeInBytes * kNumBitsInByte;
  int64_t ubAvailableNumInSmallestType = ubMaxSizeInBits / smallestTypeBits / maxBufferCnt;
  // Round the per-buffer size down to a multiple of the alignment block, then convert back to elements.
  const int64_t alignedBufferSizeInBits = (ubAvailableNumInSmallestType * smallestTypeBits) /
                                          (kUBAlignSizeInBytes * kNumBitsInByte) *
                                          (kUBAlignSizeInBytes * kNumBitsInByte);
  ubAvailableNumInSmallestType = alignedBufferSizeInBits / smallestTypeBits;

  applyVectorizationTiling(axes, ubAvailableNumInSmallestType, 1);
  ++npuGraph->tileNum;
}
// Recomputes totalParallelSize, totalReduceSize and isParallelAxis from `axes`.
//
// isParallelAxis gets one entry per input axis (null axes count as parallel and contribute nothing else). For
// every other axis the contribution is its size (range.second, or 1 when non-positive) divided, rounding up, by
// the positive value of the config returned by tryGetConfig(pos), if any. Reduction axes multiply into
// totalReduceSize, all others into totalParallelSize.
void ParallelStrategy::collectAxesInfo(const SmallVector<AxisPtr> &axes, int pos) {
  isParallelAxis.clear();
  totalReduceSize = 1;
  totalParallelSize = 1;
  for (const auto &axis : axes) {
    if (!axis) {
      isParallelAxis.push_back(true);
      continue;
    }
    const bool isReduction =
      axis->axisType.find(mlir::autotiling::Axis::AxisLabel::kReduction) != axis->axisType.end();
    isParallelAxis.push_back(!isReduction);

    const int64_t axisSize = axis->range.second > 0 ? axis->range.second : 1;
    int64_t blockCount = axisSize;
    auto tileConfig = axis->tryGetConfig(pos);
    if (tileConfig && tileConfig->value > 0) {
      blockCount = (axisSize + tileConfig->value - 1) / tileConfig->value;
    }

    if (!isReduction) {
      totalParallelSize *= blockCount;
    } else {
      totalReduceSize *= blockCount;
    }
  }
}
// Splits `totalCores` into {cores for parallel axes, cores for reduction axes}.
// When the parallel work covers all cores, the result is {totalCores, 1}. Otherwise the parallel axes get
// totalParallelSize cores and the reduction axes get min(totalCores / totalParallelSize, totalReduceSize),
// but at least 1.
std::pair<int64_t, int64_t> ParallelStrategy::allocateCoresForAxes(int64_t totalCores) {
  if (totalParallelSize >= totalCores) {
    return {totalCores, 1};
  }
  const int64_t coresForParallel = totalParallelSize;
  const int64_t leftoverCores = totalCores / coresForParallel;
  const int64_t coresForReduce = std::max<int64_t>(std::min(leftoverCores, totalReduceSize), 1);
  return {coresForParallel, coresForReduce};
}
bool ParallelStrategy::hasDynamicAxis(const SmallVector<AxisPtr> &axes) {
for (const auto &axis : axes) {
if (axis && axis->axisType.find(mlir::autotiling::Axis::AxisLabel::kDynamic) != axis->axisType.end()) {
return true;
}
}
return false;
}
void ParallelStrategy::applyDynamicParallelTiling(const SmallVector<AxisPtr> &axes, int64_t coreNum, int pos) {
for (const auto &axis : axes) {
if (!axis) {
continue;
}
if (pos) {
axis->doExtraTile();
}
bool isReduction = axis->axisType.find(mlir::autotiling::Axis::AxisLabel::kReduction) != axis->axisType.end();
if (isReduction) {
auto parallelTileConfig = axis->tryGetConfig(pos, kTileCfg);
if (parallelTileConfig != nullptr) {
parallelTileConfig->value = 1;
}
continue;
}
auto vectorizationTileConfig = axis->tryGetConfig(pos + 1);
int64_t vectorizationUpperBound = 1;
if (vectorizationTileConfig) {
if (vectorizationTileConfig->value > 0) {
vectorizationUpperBound = vectorizationTileConfig->value;
} else {
if (!vectorizationTileConfig->constraints.empty()) {
int64_t minMax = INT_MAX;
for (const auto &cons : vectorizationTileConfig->constraints) {
if (cons.max > 0 && cons.max < minMax) {
minMax = cons.max;
}
}
if (minMax != INT_MAX) {
vectorizationUpperBound = minMax;
}
}
}
}
int64_t dimSize = axis->range.second > 0 ? axis->range.second : 1;
bool shouldSetParallelValue = (vectorizationUpperBound == dimSize);
int64_t parallelUpperBound = vectorizationUpperBound * coreNum;
parallelUpperBound = std::max<int64_t>(parallelUpperBound, 1);
auto tileConfig = axis->tryGetConfig(pos, kTileCfg);
if (tileConfig != nullptr) {
if (shouldSetParallelValue) {
tileConfig->value = vectorizationUpperBound;
} else {
axis->tryAddConstraint(pos, Constraint(1, static_cast<int>(parallelUpperBound), 1));
tileConfig->value = -1;
}
}
}
}
// Picks the tile of a parallel axis with `outerSize` iterations and updates the remaining core budget.
//
// With at most one core or one iteration the axis is returned unsplit and the budget is unchanged. Otherwise the
// tile is ceil(outerSize / cores); if that is 1, the core count is lowered to ceil(outerSize / 2) when this is
// smaller and the tile is recomputed. Afterwards the budget becomes ceil(budget / tile) when the axis was
// actually split, and 1 otherwise.
int64_t ParallelStrategy::calculateParallelTileSize(int64_t outerSize, int64_t &remainingParallelCores) {
  if (outerSize <= 1 || remainingParallelCores <= 1) {
    return outerSize;
  }
  auto tileFor = [outerSize](int64_t cores) {
    return std::max<int64_t>((outerSize + cores - 1) / cores, 1);
  };

  int64_t tile = tileFor(remainingParallelCores);
  if (tile == 1) {
    // outerSize > 1 is guaranteed by the guard above.
    const int64_t reducedCores = std::max<int64_t>((outerSize + 1) / 2, 1);
    if (reducedCores < remainingParallelCores) {
      remainingParallelCores = reducedCores;
      tile = tileFor(reducedCores);
    }
  }

  remainingParallelCores = tile < outerSize ? (remainingParallelCores + tile - 1) / tile : 1;
  return tile;
}
// Picks the tile of a reduction axis with `outerSize` iterations and updates the remaining core budget.
//
// With at most one core or one iteration the axis is returned unsplit and the budget is unchanged. Otherwise the
// tile is ceil(outerSize / cores), at least 1; the budget becomes ceil(budget / tile) when the axis was actually
// split, and 1 otherwise.
int64_t ParallelStrategy::calculateReduceTileSize(int64_t outerSize, int64_t &remainingReduceCores) {
  if (outerSize <= 1 || remainingReduceCores <= 1) {
    return outerSize;
  }
  const int64_t perCore = (outerSize + remainingReduceCores - 1) / remainingReduceCores;
  const int64_t tile = std::max<int64_t>(perCore, 1);
  remainingReduceCores = tile < outerSize ? (remainingReduceCores + tile - 1) / tile : 1;
  return tile;
}
// Enlarges `tileSize` when it would split the axis into more than `coreNum` blocks.
//
// Returns `tileSize` unchanged if ceil(axisSize / tileSize) <= coreNum, otherwise ceil(axisSize / coreNum)
// (at least 1). Non-positive `tileSize` or `coreNum` values are treated as 1: both are used as divisors, and
// the core count of a graph may be unset (applyTilingToAxes handles coreNum <= 0 explicitly as well).
int64_t ParallelStrategy::adjustTileSizeForCoreLimit(int64_t axisSize, int64_t tileSize, int64_t coreNum) {
  const int64_t safeTileSize = std::max<int64_t>(tileSize, 1);
  const int64_t safeCoreNum = std::max<int64_t>(coreNum, 1);
  const int64_t numBlocks = (axisSize + safeTileSize - 1) / safeTileSize;
  if (numBlocks <= safeCoreNum) {
    return safeTileSize;
  }
  return std::max<int64_t>((axisSize + safeCoreNum - 1) / safeCoreNum, 1);
}
void ParallelStrategy::applyStaticParallelTiling(const SmallVector<AxisPtr> &axes, int64_t coresForParallel,
int64_t coresForReduce, int64_t coreNum, int pos) {
int64_t remainingParallelCores = coresForParallel;
int64_t remainingReduceCores = coresForReduce;
for (size_t i = 0; i < axes.size(); ++i) {
auto axis = axes[i];
if (!axis) {
continue;
}
if (pos) {
axis->doExtraTile();
}
auto vectorizationTileConfig = axis->tryGetConfig(pos + 1);
int64_t innerTileSize = 1;
if (vectorizationTileConfig && vectorizationTileConfig->value > 0) {
innerTileSize = vectorizationTileConfig->value;
}
int64_t axisSize = axis->range.second > 0 ? axis->range.second : 1;
int64_t outerSize = (axisSize + innerTileSize - 1) / innerTileSize;
int64_t tileSize = outerSize;
if (isParallelAxis[i] && remainingParallelCores > 1 && outerSize > 1) {
tileSize = std::min(tileSize, calculateParallelTileSize(outerSize, remainingParallelCores));
} else if (!isParallelAxis[i] && remainingReduceCores > 1 && outerSize > 1) {
tileSize = std::min(tileSize, calculateReduceTileSize(outerSize, remainingReduceCores));
}
int64_t finalTileSize = tileSize * innerTileSize;
finalTileSize = std::min(finalTileSize, axisSize);
finalTileSize = std::max<int64_t>(finalTileSize, 1);
finalTileSize = adjustTileSizeForCoreLimit(axisSize, finalTileSize, coreNum);
auto parallelTileConfig = axis->tryGetConfig(pos, kTileCfg);
if (parallelTileConfig != nullptr) {
parallelTileConfig->value = static_cast<int>(finalTileSize);
}
}
}
// Adds the parallel (multi-core) tile level at level 0 to the graph's axes and increments npuGraph->tileNum.
//
// Graphs with a dynamic axis use applyDynamicParallelTiling; static graphs first gather the per-axis block
// counts of level 1, split the cores between parallel and reduction axes, and then use
// applyStaticParallelTiling. Finally the first non-reduction axis is labelled kMultiCore.
// A null graph, a graph without a root axis, or a graph without axes is left untouched.
void ParallelStrategy::AddNpuConstraint(NpuModelGraphPtr npuGraph) {
  // Same guard as NpuDefaultTileStrategy::AddNpuConstraint: avoid dereferencing a missing graph or root axis.
  if (npuGraph == nullptr || npuGraph->rootAxis == nullptr) {
    return;
  }
  // The core count is used as a divisor further down (adjustTileSizeForCoreLimit); an unset or non-positive
  // value is treated as a single core instead of dividing by zero.
  const int64_t totalCores = std::max<int64_t>(static_cast<int64_t>(npuGraph->coreNum), 1);

  SmallVector<AxisPtr> axes;
  npuGraph->rootAxis->forEachAxisTopDown([&axes](const AxisPtr axis) {
    if (axis) {
      axes.push_back(axis);
    }
  });
  if (axes.empty()) {
    return;
  }

  if (hasDynamicAxis(axes)) {
    applyDynamicParallelTiling(axes, totalCores, 0);
  } else {
    collectAxesInfo(axes, 1);
    auto [coresForParallel, coresForReduce] = allocateCoresForAxes(totalCores);
    applyStaticParallelTiling(axes, coresForParallel, coresForReduce, totalCores, 0);
  }

  // Only the first non-reduction axis receives the multi-core label.
  for (const auto &axis : axes) {
    if (axis && axis->axisType.find(mlir::autotiling::Axis::AxisLabel::kReduction) == axis->axisType.end()) {
      (void)axis->axisType.insert(mlir::autotiling::Axis::AxisLabel::kMultiCore);
      break;
    }
  }
  ++npuGraph->tileNum;
}
}
}