* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "ascend/include/DynamicCVPipeline/AddControlFlowCondition/UpdateLoopIterTimes.h"
#include "ascend/include/DynamicCVPipeline/AddControlFlowCondition.h"
#include "bishengir/Dialect/HIVM/IR/HIVM.h"
#include "bishengir/Dialect/Scope/IR/Scope.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
// Name of this pass as it appears in the debug-log prefix; it is also the
// DEBUG_TYPE identifier that LLVM_DEBUG refers to inside LDBG below.
static constexpr const char *DEBUG_TYPE = "UpdateLoopIterTimes";
// Debug stream already prefixed with "[<DEBUG_TYPE>] ".
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
// Streams the given `<<`-chained arguments, followed by a newline, through
// LLVM_DEBUG using the prefixed stream from DBGS().
#define LDBG(...) LLVM_DEBUG(DBGS() << __VA_ARGS__ << "\n")
using namespace llvm;
using namespace mlir;
using namespace triton;
using namespace hivm;
/// Looks up the index of the first if-op in `ifOps` that is an ancestor of
/// `op` (as reported by Operation::isAncestor).
///
/// @param op         Operation whose enclosing if-block is searched for.
/// @param ifOps      Candidate if-ops, scanned front to back.
/// @param ifOpIndex  Map from if-op operation pointer to its index.
/// @return The index stored in `ifOpIndex` for the first matching if-op, or
///         -1 when no entry of `ifOps` is an ancestor of `op`.
static int findIfOpIndexInList(Operation *op, SmallVector<scf::IfOp> &ifOps,
                               llvm::DenseMap<Operation *, int> &ifOpIndex)
{
    auto enclosing = llvm::find_if(ifOps, [&](scf::IfOp candidate) {
        return candidate->isAncestor(op);
    });
    if (enclosing == ifOps.end()) {
        return -1;
    }
    return ifOpIndex[enclosing->getOperation()];
}
/// Derives the loop-extension factor {requiredBuffers, x} for one main loop.
///
/// Every op inside `forOp` that carries the "ssbuffer.if" attribute is
/// numbered, starting at 1, in the order the walk visits it. For each entry of
/// info->intraCoreDependentMap[forOp]:
///   - m is the number of the if-block enclosing the consumer's defining op;
///   - n is the number of the if-block enclosing the first
///     MaterializeInDestinationOp / hivm::CopyOp user found among the
///     producer buffers;
///   - x is the number of producer buffers.
/// The entry requires (m - n + 1) buffers, and the entry with the largest
/// (m - n + 1) / x ratio determines the result.
///
/// @param forOp  Main loop to analyse.
/// @return {requiredBuffers, x} of the dominating dependency; {1, 1} when the
///         loop has no recorded dependencies; {-1, -1} on any error (attribute
///         on a non-if op, no if-blocks, unresolved consumer/producer, or a
///         producer that is not numbered strictly before its consumer).
std::pair<int, int> UpdateLoopIterTimesPass::calculateFactor(scf::ForOp forOp)
{
    SmallVector<scf::IfOp> ifOps;
    DenseMap<Operation *, int> ifOpIndex;
    int nextIndex = 1;
    // Set when the attribute is found on something that is not an scf.if.
    // The original code stored 1 here but tested for -1 afterwards, so this
    // error was silently ignored.
    bool attrOnNonIfOp = false;
    forOp.walk([&](Operation *op) {
        if (!op->hasAttr("ssbuffer.if")) {
            return WalkResult::advance();
        }
        auto ifOp = dyn_cast<scf::IfOp>(op);
        if (!ifOp) {
            attrOnNonIfOp = true;
            LDBG("ssbuffer.if attribute is not allocated on ifOp!");
            return WalkResult::interrupt();
        }
        ifOps.push_back(ifOp);
        ifOpIndex[ifOp.getOperation()] = nextIndex++;
        return WalkResult::advance();
    });
    if (attrOnNonIfOp) {
        return {-1, -1};
    }
    if (ifOps.empty()) {
        LDBG("mainloop do not contains ifblocks!");
        return {-1, -1};
    }

    // No dependency information for this loop: no extra iterations needed.
    if (!info->intraCoreDependentMap.count(forOp)) {
        return {1, 1};
    }
    auto &deps = info->intraCoreDependentMap[forOp];
    if (deps.empty()) {
        return {1, 1};
    }

    int maxRequiredBuffers = 1;
    int maxX = 1;
    for (auto &entry : deps) {
        Value consumerResult = entry.first;
        // Bind by reference: the buffers are only read here.
        const auto &producerBuffers = entry.second;
        const int x = static_cast<int>(producerBuffers.size());

        Operation *consumerDefOp = consumerResult.getDefiningOp();
        if (!consumerDefOp) {
            LDBG("consumerResult do not have the defOp!");
            return {-1, -1};
        }
        const int m = findIfOpIndexInList(consumerDefOp, ifOps, ifOpIndex);
        if (m == -1) {
            LDBG("Can not find the consumerDefOp in any ifOps!\nconsumerDefOp: " << *consumerDefOp);
            return {-1, -1};
        }
        if (producerBuffers.empty()) {
            LDBG("consumer do not have the producerBuffers!");
            return {-1, -1};
        }

        // The first materialize/copy user of any producer buffer identifies
        // the producing if-block.
        int producerIfOpIndex = -1;
        for (Value buffer : producerBuffers) {
            for (Operation *user : buffer.getUsers()) {
                if (!isa<mlir::bufferization::MaterializeInDestinationOp, hivm::CopyOp>(user)) {
                    continue;
                }
                producerIfOpIndex = findIfOpIndexInList(user, ifOps, ifOpIndex);
                if (producerIfOpIndex == -1) {
                    LDBG("Can not find the producerBuffers in any ifOps\nproducerBuffers: " << *user);
                    return {-1, -1};
                }
                break;
            }
            if (producerIfOpIndex != -1) {
                break;
            }
        }
        if (producerIfOpIndex == -1) {
            LDBG("All producerBuffers are not found in any ifOps");
            return {-1, -1};
        }

        const int n = producerIfOpIndex;
        if (m <= n) {
            LDBG("producer is after the comsumer!");
            return {-1, -1};
        }
        const int requiredBuffers = m - n + 1;
        // Cross-multiplied form of requiredBuffers / x > maxRequiredBuffers / maxX,
        // which avoids integer division.
        if (requiredBuffers * maxX > maxRequiredBuffers * x) {
            maxRequiredBuffers = requiredBuffers;
            maxX = x;
        }
    }
    return {maxRequiredBuffers, maxX};
}
/// Emits the arithmetic that computes the extended upper bound of `forOp`.
///
/// With lb/ub/step being the loop's current bounds, the generated IR computes
///   lb + step * (ceil(ceil((ub - lb) / step) * requiredBuffers / x) + ifCount)
/// in the type of the loop step.
///
/// @param builder          Builder used to create the arith ops.
/// @param loc              Location attached to every created op.
/// @param forOp            Loop whose bounds are read.
/// @param ifCount          Constant added to the scaled iteration count.
/// @param requiredBuffers  Multiplier applied to the iteration count.
/// @param x                Divisor applied (rounding up) after the multiply.
/// @return The SSA value holding the new upper bound.
Value UpdateLoopIterTimesPass::computeNewLoopUpperBound(
    OpBuilder &builder, Location loc, scf::ForOp forOp,
    int ifCount, int requiredBuffers, int x)
{
    Value lowerBound = forOp.getLowerBound();
    Value upperBound = forOp.getUpperBound();
    Value step = forOp.getStep();
    Type ubType = step.getType();

    // Materialises `val` as a constant of the loop-bound type.
    auto makeConst = [&](int val) -> Value {
        if (ubType.isIndex()) {
            return builder.create<arith::ConstantIndexOp>(loc, val);
        }
        if (auto intType = dyn_cast<IntegerType>(ubType)) {
            return builder.create<arith::ConstantIntOp>(loc, intType, val);
        }
        // Any other type: build an index constant and cast it.
        auto asIndex = builder.create<arith::ConstantIndexOp>(loc, val);
        return builder.create<arith::IndexCastOp>(loc, ubType, asIndex);
    };

    // Rounding-up division; only an explicitly signed integer type selects the
    // signed variant, everything else uses the unsigned one.
    auto ceilDiv = [&](Value lhs, Value rhs) -> Value {
        auto intType = dyn_cast<IntegerType>(ubType);
        if (intType && intType.isSigned()) {
            return builder.create<arith::CeilDivSIOp>(loc, lhs, rhs);
        }
        return builder.create<arith::CeilDivUIOp>(loc, lhs, rhs);
    };

    Value ifCountConst = makeConst(ifCount);
    Value buffersConst = makeConst(requiredBuffers);
    Value xConst = makeConst(x);

    Value span = builder.create<arith::SubIOp>(loc, upperBound, lowerBound);
    Value tripCount = ceilDiv(span, step);
    Value scaledTrips = builder.create<arith::MulIOp>(loc, tripCount, buffersConst);
    Value scaledTripsCeil = ceilDiv(scaledTrips, xConst);
    Value extendedTrips = builder.create<arith::AddIOp>(loc, scaledTripsCeil, ifCountConst);
    Value extent = builder.create<arith::MulIOp>(loc, step, extendedTrips);
    return builder.create<arith::AddIOp>(loc, lowerBound, extent);
}
/// Creates a copy of `oldForOp` that iterates up to `newUpperBound`.
///
/// The new loop reuses the old lower bound, step, init args and attributes.
/// The old body (without its terminator) is cloned into the new loop, a fresh
/// scf.yield is built from the mapped old yield operands, and all uses of the
/// old loop's results are redirected to the new loop's results. The old loop
/// itself is NOT erased here.
///
/// @param builder        Builder positioned where the new loop is inserted.
/// @param loc            Location for the new loop and its yield.
/// @param oldForOp       Loop to copy.
/// @param newUpperBound  Upper bound of the new loop.
/// @param mapper         Receives the old->new mappings produced while cloning.
/// @return The newly created loop.
scf::ForOp UpdateLoopIterTimesPass::cloneForOpWithNewUpperBound(
    OpBuilder &builder, Location loc, scf::ForOp oldForOp, Value newUpperBound,
    IRMapping &mapper)
{
    Value originalLowerBound = oldForOp.getLowerBound();
    Value originalStep = oldForOp.getStep();
    SmallVector<Value> newInitArgs(oldForOp.getInitArgs().begin(), oldForOp.getInitArgs().end());
    auto newForOp = builder.create<scf::ForOp>(loc, originalLowerBound, newUpperBound, originalStep, newInitArgs);
    for (auto &attr : oldForOp->getAttrs()) {
        newForOp->setAttr(attr.getName(), attr.getValue());
    }

    Block *oldBlock = oldForOp.getBody();
    Block *newBlock = newForOp.getBody();
    // When a loop has no iter_args, the scf.for builder already places an
    // empty scf.yield in the body. A yield is created explicitly below, so the
    // implicit one must be removed; otherwise the block would contain two
    // terminators and fail verification.
    if (!newBlock->empty()) {
        newBlock->back().erase();
    }

    mapper.map(oldForOp.getInductionVar(), newForOp.getInductionVar());
    for (auto [oldArg, newArg] : llvm::zip(oldForOp.getRegionIterArgs(), newForOp.getRegionIterArgs())) {
        mapper.map(oldArg, newArg);
    }

    // Redirect every use of the old block arguments to the new loop's
    // arguments before cloning, so the clones reference the new arguments.
    const unsigned totalArgs = oldBlock->getNumArguments();
    for (unsigned i = 0; i < totalArgs; ++i) {
        BlockArgument oldArg = oldBlock->getArgument(i);
        BlockArgument newArg = newBlock->getArgument(i);
        oldArg.replaceAllUsesWith(newArg);
    }

    Operation *oldTerminator = oldBlock->getTerminator();
    builder.setInsertionPointToStart(newBlock);
    for (Operation &op : llvm::make_early_inc_range(oldBlock->without_terminator())) {
        builder.clone(op, mapper);
    }

    // Rebuild the terminator from the mapped yield operands.
    auto oldYield = cast<scf::YieldOp>(oldTerminator);
    SmallVector<Value> newYieldOperands;
    for (unsigned i = 0; i < oldYield.getNumOperands(); ++i) {
        newYieldOperands.push_back(mapper.lookupOrDefault(oldYield.getOperand(i)));
    }
    builder.setInsertionPointToEnd(newBlock);
    builder.create<scf::YieldOp>(loc, newYieldOperands);

    // Users of the old loop results now consume the new loop results.
    const unsigned numOriginalResults = oldForOp.getNumResults();
    if (numOriginalResults > 0) {
        SmallVector<Value> originalResults;
        for (unsigned i = 0; i < numOriginalResults; ++i) {
            originalResults.push_back(newForOp.getResult(i));
        }
        oldForOp.replaceAllUsesWith(originalResults);
    }
    return newForOp;
}
/// Re-keys info->cntArgs from the if-ops of the old loop to their
/// counterparts looked up through `mapper`.
///
/// For each if-op in `ifOpsInThisFor` whose lookup yields an scf.if, the old
/// entry is erased and a new entry (mapped if-op -> mapped counter value) is
/// stored. If-ops whose lookup does not yield an scf.if are skipped.
///
/// @param oldForOp        Not used by this function.
/// @param mapper          Old->new mapping filled while cloning the loop.
/// @param ifOpsInThisFor  If-ops that belonged to the old loop.
/// @return Always 0.
int UpdateLoopIterTimesPass::updateCntArgsAfterClone(
    scf::ForOp oldForOp, IRMapping &mapper,
    SmallVector<scf::IfOp> &ifOpsInThisFor)
{
    for (scf::IfOp staleIf : ifOpsInThisFor) {
        scf::IfOp clonedIf = dyn_cast<scf::IfOp>(mapper.lookupOrDefault(staleIf));
        if (clonedIf) {
            Value staleCounter = info->cntArgs[staleIf];
            Value clonedCounter = mapper.lookupOrDefault(staleCounter);
            LDBG("updating CntArgs After Clone...");
            LDBG("oldCntVal: " << staleCounter);
            LDBG("newCntVal: " << clonedCounter);
            info->cntArgs.erase(staleIf);
            info->cntArgs[clonedIf] = clonedCounter;
        }
    }
    return 0;
}
/// Builds a replacement for `oldForOp` with an extended upper bound and
/// updates the counter bookkeeping to point at the replacement.
///
/// Steps: compute the new upper bound, clone the loop with that bound, then
/// re-key info->cntArgs for the if-ops listed in `ifOpsInThisFor`.
///
/// @param oldForOp         Loop to replace (left in place, not erased here).
/// @param ifCount          Forwarded to computeNewLoopUpperBound.
/// @param requiredBuffers  Forwarded to computeNewLoopUpperBound.
/// @param x                Forwarded to computeNewLoopUpperBound.
/// @param mapper           Receives the old->new mappings from cloning.
/// @param ifOpsInThisFor   If-ops of the old loop that have counter entries.
/// @return The new loop, or a null op when any step fails.
scf::ForOp UpdateLoopIterTimesPass::extendForOpIterationCount(
    scf::ForOp oldForOp, int ifCount, int requiredBuffers, int x,
    IRMapping &mapper, SmallVector<scf::IfOp> &ifOpsInThisFor)
{
    // New ops are inserted right before the loop being replaced.
    OpBuilder builder(oldForOp);
    Location loc = oldForOp.getLoc();

    Value extendedBound = computeNewLoopUpperBound(builder, loc, oldForOp, ifCount, requiredBuffers, x);
    if (!extendedBound) {
        LDBG("computeNewLoopUpperBound failed!");
        return nullptr;
    }

    scf::ForOp replacement = cloneForOpWithNewUpperBound(builder, loc, oldForOp, extendedBound, mapper);
    if (!replacement) {
        LDBG("cloneForOpWithNewUpperBound failed!");
        return nullptr;
    }
    LDBG("cloneForOpWithNewUpperBound Success!");
    LDBG("new ForOp: " << replacement);
    LDBG("old ForOp: " << oldForOp);

    const int status = updateCntArgsAfterClone(oldForOp, mapper, ifOpsInThisFor);
    if (status != 0) {
        LDBG("updateCntArgsAfterClone failed!");
        return nullptr;
    }
    return replacement;
}
int UpdateLoopIterTimesPass::replaceForOpCounterInIfOps()
{
int ret = 0;
getOperation().walk([&](Operation* op) {
if (op->hasAttr("ssbuffer.main_loop")) {
auto forOp = dyn_cast<scf::ForOp>(op);
if (!forOp) {
ret = -1;
LDBG("Do not support other mainloop except forOp!");
return WalkResult::interrupt();
}
Value indVar = forOp.getInductionVar();
forOp.walk([&](scf::IfOp ifOp) {
if (ifOp->hasAttr("ssbuffer.if")) {
if (!info->cntArgs.count(ifOp)) {
LDBG("ifblock has no counter in cntArgs");
ret = -1;
return WalkResult::interrupt();
}
Value cntVal = info->cntArgs[ifOp];
ifOp.walk([&](Operation *op) {
for (OpOperand &operand : op->getOpOperands()) {
if (operand.get() == indVar) {
operand.set(cntVal);
}
}
});
}
return mlir::WalkResult::advance();
});
}
return mlir::WalkResult::advance();
});
return ret;
}
/// Groups the main loops of `module` by their "ssbuffer.main_loop" id,
/// separately for cube scopes and vector scopes.
///
/// Each scope::ScopeOp must carry a "hivm.tcore_type" attribute equal to the
/// CUBE or VECTOR core type. Ops inside the scope that have the
/// "ssbuffer.main_loop" attribute must be scf.for; when the attribute is an
/// IntegerAttr its value is the id under which the loop is recorded.
///
/// @param module  Module to scan.
/// @param cmap    Cleared, then filled with id -> loops found in cube scopes.
/// @param vmap    Cleared, then filled with id -> loops found in vector scopes.
/// @return 0 on success; -1 when a scope is neither cube nor vector, or when
///         a main loop is not an scf.for.
int UpdateLoopIterTimesPass::GetMainLoopIdToLoopOpMap(
    ModuleOp module, DenseMap<int, SmallVector<Operation *>> &cmap,
    DenseMap<int, SmallVector<Operation *>> &vmap)
{
    cmap.clear();
    vmap.clear();
    int status = 0;
    module.walk([&](scope::ScopeOp scopeOp) {
        bool isCube = false;
        bool isVector = false;
        if (auto coreAttr = scopeOp->getAttr("hivm.tcore_type")) {
            auto cubeAttr = hivm::TCoreTypeAttr::get(scopeOp->getContext(), hivm::TCoreType::CUBE);
            auto vectorAttr = hivm::TCoreTypeAttr::get(scopeOp->getContext(), hivm::TCoreType::VECTOR);
            if (coreAttr == cubeAttr) {
                isCube = true;
            } else if (coreAttr == vectorAttr) {
                isVector = true;
            }
        }
        if (!isCube && !isVector) {
            status = -1;
            LDBG("mlir do not processed by split mix kernel!");
            return mlir::WalkResult::interrupt();
        }

        // Exactly one of isCube / isVector is set at this point.
        auto &loopsById = isCube ? cmap : vmap;
        scopeOp.walk([&](Operation *candidate) {
            if (!candidate->hasAttr("ssbuffer.main_loop")) {
                return mlir::WalkResult::advance();
            }
            auto mainLoop = dyn_cast<scf::ForOp>(candidate);
            if (!mainLoop) {
                status = -1;
                LDBG("do not surpport other loop op temprarily!");
                return mlir::WalkResult::interrupt();
            }
            if (auto idAttr = mainLoop->getAttrOfType<IntegerAttr>("ssbuffer.main_loop")) {
                int id = idAttr.getInt();
                loopsById[id].push_back(mainLoop.getOperation());
            }
            return mlir::WalkResult::advance();
        });
        return mlir::WalkResult::advance();
    });
    return status;
}
/// Fills `infoMap` with an IterationTimesInfo for every loop in `loopMap`.
///
/// For each loop, all entries of info->cntArgs are validated (the if-op must
/// have "ssbuffer.if" and its parent must be an scf.for with
/// "ssbuffer.main_loop"); those whose parent is the current loop are counted
/// and collected. The buffer factor comes from calculateFactor.
///
/// @param loopMap  id -> main loops to process.
/// @param infoMap  Receives one entry per processed loop.
/// @return 0 on success; -1 when a loop is not an scf.for, info->cntArgs is
///         empty, an entry fails validation, or calculateFactor fails.
int UpdateLoopIterTimesPass::ComputeMainLoopTimes(
    DenseMap<int, SmallVector<Operation *>> &loopMap,
    DenseMap<Operation *, IterationTimesInfo> &infoMap)
{
    for (auto &idAndLoops : loopMap) {
        for (Operation *loopOp : idAndLoops.second) {
            scf::ForOp forOp = dyn_cast<scf::ForOp>(loopOp);
            if (!forOp) {
                LDBG("currently only support forOp!");
                return -1;
            }
            if (info->cntArgs.empty()) {
                LDBG("cntArgs is empty, no ifblock is contained!");
                return -1;
            }

            IterationTimesInfo iterInfo;
            for (auto &[ifOp, cntVal] : info->cntArgs) {
                if (!ifOp->hasAttr("ssbuffer.if")) {
                    LDBG("ifOp in cntArgs does not contain ssbuffer.if Attribute\n ifOp: " << ifOp);
                    return -1;
                }
                auto parentOp = ifOp->getParentOp();
                const bool parentIsMainLoop =
                    parentOp->hasAttr("ssbuffer.main_loop") && isa<scf::ForOp>(parentOp);
                if (!parentIsMainLoop) {
                    LDBG("Get wrong ifblock structure: ifblock's parentOp is not mainloop!");
                    LDBG("ifblock Op: " << ifOp);
                    LDBG("Parent Op: " << *parentOp);
                    return -1;
                }
                // Only if-blocks directly owned by the current loop are counted.
                if (parentOp == forOp.getOperation()) {
                    iterInfo.ifCount++;
                    iterInfo.ifOpsInThisFor.push_back(ifOp);
                }
            }

            auto [requiredBuffers, x] = calculateFactor(forOp);
            if (requiredBuffers == -1 || x == -1) {
                LDBG("calculateFactor failed!");
                return -1;
            }
            iterInfo.requiredBuffers = requiredBuffers;
            iterInfo.x = x;
            infoMap[loopOp] = iterInfo;
        }
    }
    return 0;
}
/// Appends the loops registered under `id` in `map` to `allForOps` and folds
/// their IterationTimesInfo into the running maxima.
///
/// `maxIfCount` becomes the largest ifCount seen. The pair
/// (`maxRequiredBuffers`, `maxX`) is replaced whenever a loop has a strictly
/// larger requiredBuffers / x ratio (compared by cross-multiplication).
///
/// @param map                 id -> loops.
/// @param id                  Main-loop id to collect.
/// @param allForOps           Receives the loops found under `id`.
/// @param maxIfCount          In/out running maximum of ifCount.
/// @param maxRequiredBuffers  In/out numerator of the running maximum ratio.
/// @param maxX                In/out denominator of the running maximum ratio.
/// @param infoMap             Per-loop info; must contain every collected loop.
/// @return 0 on success (including when `id` is absent from `map`); -1 when a
///         loop has no entry in `infoMap`.
int UpdateLoopIterTimesPass::collectForOpsAndUpdateMax(
    DenseMap<int, SmallVector<Operation *>> &map,
    int id,
    SmallVector<Operation *> &allForOps,
    int &maxIfCount,
    int &maxRequiredBuffers,
    int &maxX,
    DenseMap<Operation *, IterationTimesInfo> &infoMap)
{
    if (!map.count(id)) {
        return 0;
    }
    for (Operation *loopOp : map[id]) {
        allForOps.push_back(loopOp);
        auto found = infoMap.find(loopOp);
        if (found == infoMap.end()) {
            LDBG("mainloop Op: " << *loopOp << " do not include in infoMap!");
            return -1;
        }
        IterationTimesInfo &iterInfo = found->second;
        if (iterInfo.ifCount > maxIfCount) {
            maxIfCount = iterInfo.ifCount;
        }
        const bool largerRatio =
            iterInfo.requiredBuffers * maxX > maxRequiredBuffers * iterInfo.x;
        if (largerRatio) {
            maxRequiredBuffers = iterInfo.requiredBuffers;
            maxX = iterInfo.x;
        }
    }
    return 0;
}
/// Replaces every main loop with a clone whose iteration count is extended.
///
/// Loops sharing the same id (across `cmap` and `vmap`) are extended with the
/// same parameters: the maximum ifCount and the maximum requiredBuffers / x
/// ratio among them. The old loops are erased only after all ids have been
/// processed successfully.
///
/// @param cmap     id -> main loops in cube scopes.
/// @param vmap     id -> main loops in vector scopes.
/// @param infoMap  Per-loop iteration info.
/// @return 0 on success; -1 when collecting fails, an id has no if-blocks, a
///         loop is not an scf.for, extending a loop fails, or a null loop
///         pointer is encountered while erasing.
int UpdateLoopIterTimesPass::UpdateForLoopIteration(
    DenseMap<int, SmallVector<Operation *>> &cmap,
    DenseMap<int, SmallVector<Operation *>> &vmap,
    DenseMap<Operation *, IterationTimesInfo> &infoMap)
{
    DenseSet<int> allIds;
    for (auto &entry : cmap) {
        allIds.insert(entry.first);
    }
    for (auto &entry : vmap) {
        allIds.insert(entry.first);
    }

    // Old loops are queued here and erased once every id has been handled.
    SmallVector<Operation *> replacedLoops;
    for (int id : allIds) {
        int maxIfCount = 0;
        int maxRequiredBuffers = 1;
        int maxX = 1;
        SmallVector<Operation *> sameIdForOps;
        if (collectForOpsAndUpdateMax(cmap, id, sameIdForOps, maxIfCount, maxRequiredBuffers, maxX, infoMap) != 0 ||
            collectForOpsAndUpdateMax(vmap, id, sameIdForOps, maxIfCount, maxRequiredBuffers, maxX, infoMap) != 0) {
            return -1;
        }
        if (maxIfCount == 0) {
            LDBG("no ifblock in mainloop!");
            return -1;
        }

        for (Operation *loopOp : sameIdForOps) {
            scf::ForOp oldForOp = dyn_cast<scf::ForOp>(loopOp);
            if (!oldForOp) {
                LDBG("do not surpport other loop op except forOp!");
                return -1;
            }
            IterationTimesInfo &iterInfo = infoMap[loopOp];
            IRMapping mapper;
            scf::ForOp newForOp = extendForOpIterationCount(
                oldForOp, maxIfCount, maxRequiredBuffers, maxX, mapper, iterInfo.ifOpsInThisFor);
            if (!newForOp) {
                LDBG("extendForOpIterationCount failed!");
                return -1;
            }
            replacedLoops.push_back(loopOp);
        }
    }

    for (Operation *staleLoop : replacedLoops) {
        if (!staleLoop) {
            LDBG("erasing error: loopOp is nullptr, there are nested mainloop!");
            return -1;
        }
        staleLoop->erase();
    }
    return 0;
}
/// Pass entry point.
///
/// Runs, in order: main-loop collection (cube and vector maps), per-loop
/// iteration info computation for both maps, loop extension, and replacement
/// of the loop induction variable inside the if-blocks. The pass is marked as
/// failed and stops at the first step that reports an error.
void UpdateLoopIterTimesPass::runOnOperation()
{
    ModuleOp module = getOperation();
    LDBG("before updateloopitertimes:\n" << module);
    LDBG("\nEnter UpdateLoopIterTimesPass!");

    DenseMap<int, SmallVector<Operation *>> cmap;
    DenseMap<int, SmallVector<Operation *>> vmap;
    if (GetMainLoopIdToLoopOpMap(module, cmap, vmap) != 0) {
        LDBG("\nGetMainLoopIdToLoopOpMap Failed!");
        signalPassFailure();
        return;
    }

    DenseMap<Operation *, IterationTimesInfo> infoMap;
    if (ComputeMainLoopTimes(cmap, infoMap) != 0) {
        LDBG("\nComputeMainLoopTimes from cube Failed!");
        signalPassFailure();
        return;
    }
    if (ComputeMainLoopTimes(vmap, infoMap) != 0) {
        LDBG("\nComputeMainLoopTimes from vector Failed!");
        signalPassFailure();
        return;
    }

    if (UpdateForLoopIteration(cmap, vmap, infoMap) != 0) {
        LDBG("Update ForLoop Iteration Failed!\n");
        signalPassFailure();
        // Stop here, like the earlier steps do: a failed extension can leave
        // loops only partially rewritten, and the counter replacement below
        // must not run on that state.
        return;
    }
    LDBG("after UpdateForLoopIteration:\n" << module);

    if (replaceForOpCounterInIfOps() != 0) {
        LDBG("replaceForOpCounterInIfOps Failed!\n");
        signalPassFailure();
    }
    LDBG("after updateloopitertimes:\n" << module);
    LDBG("\nExit UpdateLoopIterTimes pass.");
}
namespace mlir::triton {
/// Factory for the pass: returns a newly allocated UpdateLoopIterTimesPass
/// as a module-level operation pass.
std::unique_ptr<OperationPass<ModuleOp>> createUpdateLoopIterTimesPass()
{
    auto pass = std::make_unique<UpdateLoopIterTimesPass>();
    return pass;
}
} // namespace mlir::triton