#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include "triton/Dialect/TritonGPU/Transforms/PipelineExpander.h"
#include "triton/Dialect/TritonGPU/IR/Types.h"
#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h"
#define DEBUG_TYPE "triton-loop-pipelining"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
using namespace mlir;
using namespace mlir::scf;
using namespace mlir::triton;
namespace {
/// State and helpers used to software-pipeline a single `scf.for` loop
/// according to the stage assignment and op order returned by
/// `PipeliningOption::getScheduleFn`. The loop is rewritten into a prologue,
/// a kernel loop and, optionally, a peeled epilogue.
struct LoopPipelinerInternal {
  /// Live range (in stages) of a value that is consumed in a later stage than
  /// the one defining it; filled by `analyzeCrossStageValues`.
  struct LiverangeInfo {
    unsigned lastUseStage = 0;
    unsigned defStage = 0;
  };

protected:
  /// The loop being pipelined.
  ForOp forOp;
  /// Largest stage index found in the schedule.
  unsigned maxStage = 0;
  /// Stage assigned to each scheduled operation.
  DenseMap<Operation *, unsigned> stages;
  /// Scheduled operations, in the order they are emitted.
  std::vector<Operation *> opOrder;
  /// Upper bound, lower bound and step of `forOp`.
  Value ub;
  Value lb;
  Value step;
  /// True unless constant bounds prove the loop runs at least `maxStage`
  /// iterations; such loops get predicated prologue/epilogue clones.
  /// Assigned by initializeLoopInfo; default-initialized so it is never read
  /// indeterminate.
  bool dynamicLoop = false;
  triton::PipeliningOption::AnnotationlFnType annotateFn = nullptr;
  /// Whether the epilogue is emitted after the kernel loop instead of
  /// predicating the kernel's early stages.
  bool peelEpilogue = false;
  triton::PipeliningOption::PredicateOpFnType predicateFn = nullptr;
  triton::PipeliningOption::EmitPredicateStageFnType emitPredicateStageFn =
      nullptr;
  /// For each value of the original loop, the value to use for each version
  /// (entries have `maxStage + 1` slots, see setValueMapping).
  DenseMap<Value, llvm::SmallVector<Value>> valueMapping;
  void setValueMapping(Value key, Value el, int64_t idx);
  std::pair<Operation *, int64_t> getDefiningOpAndDistance(Value value);
  bool verifySchedule();

public:
  bool initializeLoopInfo(ForOp op, const triton::PipeliningOption &options);
  LogicalResult emitPrologue(RewriterBase &rewriter);
  llvm::MapVector<Value, LiverangeInfo> analyzeCrossStageValues();
  scf::ForOp createKernelLoop(
      const llvm::MapVector<Value, LiverangeInfo> &crossStageValues,
      RewriterBase &rewriter,
      llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap);
  LogicalResult createKernel(
      scf::ForOp newForOp,
      const llvm::MapVector<Value, LiverangeInfo> &crossStageValues,
      const llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap,
      RewriterBase &rewriter);
  LogicalResult emitEpilogue(RewriterBase &rewriter,
                             llvm::SmallVector<Value> &returnValues);
};
/// Collects, without duplicates and in walk order, every operand used by `op`
/// or by any operation nested inside its regions.
static SetVector<Value> getNestedOperands(Operation *op) {
  SetVector<Value> collected;
  op->walk([&collected](Operation *inner) {
    auto innerOperands = inner->getOperands();
    collected.insert(innerOperands.begin(), innerOperands.end());
  });
  return collected;
}
/// Initializes the pipeliner state from `op` and `options` and checks that
/// the requested schedule can be pipelined.
///
/// Returns false without rewriting IR when pipelining is not possible: empty
/// schedule, dynamic (or too short) loop without dynamic-loop support, missing
/// predicate callback when predication is required, an unscheduled body op,
/// an invalid schedule, a staged terminator or op outside the body block, or
/// a loop-carried dependency with a distance greater than one.
bool LoopPipelinerInternal::initializeLoopInfo(
    ForOp op, const triton::PipeliningOption &options) {
  LDBG("Start initializeLoopInfo");
  forOp = op;
  ub = forOp.getUpperBound();
  lb = forOp.getLowerBound();
  step = forOp.getStep();

  std::vector<std::pair<Operation *, unsigned>> schedule;
  options.getScheduleFn(forOp, schedule);
  if (schedule.empty()) {
    LDBG("--empty schedule -> BAIL");
    return false;
  }

  opOrder.reserve(schedule.size());
  for (auto &opSchedule : schedule) {
    maxStage = std::max(maxStage, opSchedule.second);
    stages[opSchedule.first] = opSchedule.second;
    opOrder.push_back(opSchedule.first);
  }

  // Treat the loop as dynamic unless constant (index-typed) bounds prove that
  // it runs at least `maxStage` iterations.
  dynamicLoop = true;
  auto upperBoundCst = ub.getDefiningOp<arith::ConstantIndexOp>();
  auto lowerBoundCst = lb.getDefiningOp<arith::ConstantIndexOp>();
  auto stepCst = step.getDefiningOp<arith::ConstantIndexOp>();
  if (!upperBoundCst || !lowerBoundCst || !stepCst) {
    if (!options.supportDynamicLoops) {
      LDBG("--dynamic loop not supported -> BAIL");
      return false;
    }
  } else {
    int64_t ubImm = upperBoundCst.value();
    int64_t lbImm = lowerBoundCst.value();
    int64_t stepImm = stepCst.value();
    int64_t numIteration = llvm::divideCeilSigned(ubImm - lbImm, stepImm);
    if (numIteration >= maxStage) {
      dynamicLoop = false;
    } else if (!options.supportDynamicLoops) {
      LDBG("--fewer loop iterations than pipeline stages -> BAIL");
      return false;
    }
  }
  peelEpilogue = options.peelEpilogue;
  predicateFn = options.predicateFn;
  // Without a peeled epilogue, or with an unknown trip count, some clones must
  // be predicated, which requires a predicate callback.
  if ((!peelEpilogue || dynamicLoop) && predicateFn == nullptr) {
    LDBG("--no epilogue or predicate set -> BAIL");
    return false;
  }
  emitPredicateStageFn = options.emitPredicateStageFn;
  if (emitPredicateStageFn == nullptr) {
    emitPredicateStageFn = mlir::triton::emitPredicateForStage;
  }

  Block *body = forOp.getBody();

  // Every non-terminator operation of the body must have a stage.
  for (Operation &bodyOp : body->without_terminator()) {
    if (!stages.contains(&bodyOp)) {
      bodyOp.emitOpError("not assigned a pipeline stage");
      LDBG("--op not assigned a pipeline stage: " << bodyOp << " -> BAIL");
      return false;
    }
  }

  if (!verifySchedule()) {
    LDBG("--invalid schedule: " << op << " -> BAIL");
    return false;
  }

  // Only operations that live directly in the loop body block (and are not
  // its terminator) may be assigned a stage. `opOrder` holds exactly the keys
  // of `stages`; walking it instead of the hash map keeps the reported
  // diagnostic deterministic when several operations are invalid.
  Operation *terminator = body->getTerminator();
  for (Operation *stagedOp : opOrder) {
    if (stagedOp == terminator) {
      stagedOp->emitError("terminator should not be assigned a stage");
      LDBG("--terminator should not be assigned stage: " << *stagedOp
                                                          << " -> BAIL");
      return false;
    }
    if (stagedOp->getBlock() != body) {
      stagedOp->emitOpError(
          "the owning Block of all operations assigned a stage "
          "should be the loop body block");
      LDBG("--the owning Block of all operations assigned a stage "
           "should be the loop body block: "
           << *stagedOp << " -> BAIL");
      return false;
    }
  }

  // Only loop-carried dependencies with a distance of at most one iteration
  // (or values defined outside the loop) are supported.
  for (Operation &bodyOp : body->without_terminator()) {
    for (Value operand : getNestedOperands(&bodyOp)) {
      auto [def, distance] = getDefiningOpAndDistance(operand);
      if (!def)
        continue;
      if (distance > 1) {
        LDBG("--only support loop carried dependency with a distance of 1 or "
             "defined outside of the loop -> BAIL");
        return false;
      }
    }
  }
  annotateFn = options.annotateFn;
  return true;
}
bool LoopPipelinerInternal::verifySchedule() {
int64_t numCylesPerIter = opOrder.size();
DenseMap<Operation *, int64_t> unrolledCyles;
for (int64_t cycle = 0; cycle < numCylesPerIter; cycle++) {
Operation *def = opOrder[cycle];
auto it = stages.find(def);
assert(it != stages.end());
int64_t stage = it->second;
unrolledCyles[def] = cycle + stage * numCylesPerIter;
}
for (Operation *consumer : opOrder) {
int64_t consumerCycle = unrolledCyles[consumer];
for (Value operand : getNestedOperands(consumer)) {
auto [producer, distance] = getDefiningOpAndDistance(operand);
if (!producer)
continue;
auto it = unrolledCyles.find(producer);
if (it == unrolledCyles.end())
continue;
int64_t producerCycle = it->second;
if (consumerCycle < producerCycle - numCylesPerIter * distance) {
InFlightDiagnostic diag =
consumer->emitWarning("operation scheduled before its operands. "
"Pipelining will be disabled.");
diag.attachNote(producer->getLoc())
.append("operand defined here: ")
.appendOp(*producer, OpPrintingFlags().printGenericOpForm());
return false;
}
}
}
return true;
}
/// Clones `op` at the rewriter's insertion point, then calls `callback` on
/// every operand inside the clone that refers to something outside it: a
/// result of an operation not nested in the clone, or any block argument.
/// The callback may rewrite the operand in place. Returns the clone.
static Operation *
cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op,
                       function_ref<void(OpOperand *newOperand)> callback) {
  Operation *copy = rewriter.clone(*op);
  copy->walk<WalkOrder::PreOrder>([&](Operation *inner) {
    for (OpOperand &use : inner->getOpOperands()) {
      Value used = use.get();
      Operation *producer = used.getDefiningOp();
      const bool producedOutside = producer && !copy->isAncestor(producer);
      const bool isBlockArg = isa<BlockArgument>(used);
      if (producedOutside || isBlockArg)
        callback(&use);
    }
  });
  return copy;
}
/// Emits the pipeline prologue at the current insertion point.
///
/// For each prologue step `i` in [0, maxStage), clones (in `opOrder`) every
/// op whose stage is <= i. The clone of an op with stage `s` belongs to
/// original iteration `i - s`, reads operands at version `i - s`, and its
/// results are recorded in `valueMapping` under version `i - s`. Yielded
/// results also become version `i - s + 1` of their region iter arg. For
/// dynamic loops each clone is guarded through `predicateFn` with the
/// predicate `iv < ub` of its original iteration. Returns failure if
/// `predicateFn` returns nullptr.
LogicalResult LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
  // Version 0 of every region iter arg is the loop's init value.
  for (auto [arg, operand] :
       llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) {
    setValueMapping(arg, operand.get(), 0);
  }
  auto yield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
  // An iter arg whose yielded value is not defined inside the loop body
  // region takes that same value for versions 1..maxStage-1.
  for (auto [arg, operand] :
       llvm::zip(forOp.getRegionIterArgs(), yield->getOpOperands())) {
    if (forOp.getBodyRegion().isAncestor(operand.get().getParentRegion()))
      continue;
    for (int64_t i = 1; i < maxStage; ++i)
      setValueMapping(arg, operand.get(), i);
  }
  Location loc = forOp.getLoc();
  // predicates[k] guards original iteration k; only set for dynamic loops.
  SmallVector<Value> predicates(maxStage);
  for (int64_t i = 0; i < maxStage; i++) {
    // Induction variable of original iteration i: lb + step * i.
    Type t = lb.getType();
    Value iv = rewriter.create<arith::AddIOp>(
        loc, lb,
        rewriter.create<arith::MulIOp>(
            loc, step,
            rewriter.create<arith::ConstantOp>(loc,
                                               rewriter.getIntegerAttr(t, i))));
    setValueMapping(forOp.getInductionVar(), iv, i);
    if (dynamicLoop) {
      // Iteration i only runs if it is still below the upper bound.
      predicates[i] = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::slt, iv, ub);
    }
    for (Operation *op : opOrder) {
      // Stage s first executes at prologue step s.
      if (stages[op] > i)
        continue;
      Operation *newOp =
          cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) {
            // Read the version produced for the same original iteration.
            auto it = valueMapping.find(newOperand->get());
            if (it != valueMapping.end()) {
              Value replacement = it->second[i - stages[op]];
              newOperand->set(replacement);
            }
          });
      int predicateIdx = i - stages[op];
      if (predicates[predicateIdx]) {
        OpBuilder::InsertionGuard insertGuard(rewriter);
        newOp = predicateFn(rewriter, newOp, predicates[predicateIdx]);
        if (newOp == nullptr)
          return failure();
      }
      if (annotateFn)
        annotateFn(newOp, triton::PipeliningOption::PipelinerPart::Prologue, i);
      for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
        Value source = newOp->getResult(destId);
        // A yielded result becomes the next version of its iter arg.
        for (OpOperand &operand : yield->getOpOperands()) {
          if (operand.get() != op->getResult(destId))
            continue;
          // If the clone is predicated and the loop result is used, keep the
          // previous iter-arg value when the predicate is false.
          if (predicates[predicateIdx] &&
              !forOp.getResult(operand.getOperandNumber()).use_empty()) {
            Value prevValue = valueMapping
                [forOp.getRegionIterArgs()[operand.getOperandNumber()]]
                [i - stages[op]];
            source = rewriter.create<arith::SelectOp>(
                loc, predicates[predicateIdx], source, prevValue);
          }
          setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
                          source, i - stages[op] + 1);
        }
        setValueMapping(op->getResult(destId), newOp->getResult(destId),
                        i - stages[op]);
      }
    }
  }
  return success();
}
/// Finds every value that is consumed in a later stage than the one defining
/// it (including uses inside nested regions) and records, per value, its
/// defining stage and the latest stage using it. Values used in their own
/// stage, or loop-carried from exactly `distance` stages later, need no extra
/// versions and are skipped. Entries are kept in first-discovery order.
llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
LoopPipelinerInternal::analyzeCrossStageValues() {
  llvm::MapVector<Value, LiverangeInfo> liveRanges;
  for (Operation *user : opOrder) {
    const unsigned useStage = stages[user];
    auto recordUse = [&](OpOperand &use) {
      Value used = use.get();
      auto [producer, distance] = getDefiningOpAndDistance(used);
      if (!producer)
        return;
      auto producerIt = stages.find(producer);
      if (producerIt == stages.end())
        return;
      const unsigned producerStage = producerIt->second;
      if (producerStage == useStage || producerStage == useStage + distance)
        return;
      assert(useStage > producerStage);
      LiverangeInfo &range = liveRanges[used];
      range.defStage = producerStage;
      range.lastUseStage = std::max(range.lastUseStage, useStage);
    };
    for (OpOperand &use : user->getOpOperands())
      recordUse(use);
    visitUsedValuesDefinedAbove(user->getRegions(),
                                [&](OpOperand *use) { recordUse(*use); });
  }
  return liveRanges;
}
/// Convenience wrapper around the shared Triton helper, bound to the loop
/// being pipelined. Presumably returns the operation defining `value` and the
/// number of iterations separating definition and use (0 within the same
/// iteration) — see PipeliningUtility to confirm.
std::pair<Operation *, int64_t>
LoopPipelinerInternal::getDefiningOpAndDistance(Value value) {
  auto result = mlir::triton::getDefiningOpAndDistance(forOp, value);
  return result;
}
/// Creates the (still empty) kernel `scf.for` loop.
///
/// Its iter args are, in order:
///  * one per original iter arg: the version produced by the prologue for the
///    yielding op's stage, or the original init value when the yielded value
///    is not produced by a scheduled op;
///  * for every cross-stage value, `lastUseStage - defStage` rotating versions.
///    `loopArgMap` maps (value, stage distance) to the iter-arg index.
///
/// With a peeled epilogue the upper bound is reduced by `maxStage * step`.
/// The original loop's attributes are copied, and the auto-created terminator
/// is erased so the kernel body can be filled in later.
scf::ForOp LoopPipelinerInternal::createKernelLoop(
    const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
        &crossStageValues,
    RewriterBase &rewriter,
    llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap) {
  Operation *oldYield = forOp.getBody()->getTerminator();
  llvm::SmallVector<Value> iterInits;
  for (auto [idx, yielded] : llvm::enumerate(oldYield->getOperands())) {
    auto stageIt = stages.find(yielded.getDefiningOp());
    if (stageIt == stages.end()) {
      iterInits.push_back(forOp.getInitArgs()[idx]);
      continue;
    }
    Value version =
        valueMapping[forOp.getRegionIterArgs()[idx]][maxStage - stageIt->second];
    assert(version);
    iterInits.push_back(version);
  }

  for (const auto &[value, range] : crossStageValues) {
    const unsigned numVersions = range.lastUseStage - range.defStage;
    for (unsigned k = 0; k < numVersions; k++) {
      Value version = valueMapping[value][maxStage - range.lastUseStage + k];
      assert(version);
      iterInits.push_back(version);
      loopArgMap[std::make_pair(value, numVersions - k)] =
          iterInits.size() - 1;
    }
  }

  Value kernelUb = forOp.getUpperBound();
  if (peelEpilogue) {
    Location loc = forOp.getLoc();
    Value stageCount = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(ub.getType(), maxStage));
    Value peeledSpan = rewriter.create<arith::MulIOp>(loc, step, stageCount);
    kernelUb = rewriter.create<arith::SubIOp>(loc, ub, peeledSpan);
  }

  auto kernelLoop =
      rewriter.create<scf::ForOp>(forOp.getLoc(), forOp.getLowerBound(),
                                  kernelUb, forOp.getStep(), iterInits);
  kernelLoop->setAttrs(forOp->getAttrs());
  Block *kernelBody = kernelLoop.getBody();
  if (!kernelBody->empty())
    rewriter.eraseOp(kernelBody->getTerminator());
  return kernelLoop;
}
/// Fills the body of `newForOp` (the kernel loop) with one clone of every
/// scheduled op, in `opOrder`, and builds its yield.
///
/// Inside the kernel, an op of stage `s` has its uses of the induction
/// variable rewritten to `iv + (maxStage - s) * step`. Uses of region iter
/// args are resolved to the value the original loop yields for them.
/// Operands defined in an earlier stage are redirected to the kernel iter args
/// selected through `loopArgMap`. When the epilogue is not peeled, ops of
/// stages < maxStage are predicated with the predicate built by
/// `emitPredicateStageFn`. `valueMapping` is cleared and refilled with the
/// kernel loop results so the epilogue can use them.
/// Returns failure if `predicateFn` returns nullptr.
LogicalResult LoopPipelinerInternal::createKernel(
    scf::ForOp newForOp,
    const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
        &crossStageValues,
    const llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap,
    RewriterBase &rewriter) {
  valueMapping.clear();
  rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
  IRMapping mapping;
  mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
  for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) {
    mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
  }
  // One predicate per stage except the last; only built when the epilogue is
  // not peeled (predicates[maxStage] stays null).
  SmallVector<Value> predicates(maxStage + 1, nullptr);
  if (!peelEpilogue) {
    for (unsigned i = 0; i < maxStage; i++) {
      predicates[i] = emitPredicateStageFn(rewriter, newForOp.getInductionVar(),
                                           ub, step, maxStage, i);
    }
  }
  for (Operation *op : opOrder) {
    int64_t useStage = stages[op];
    auto *newOp = rewriter.clone(*op, mapping);
    // Gather operands of the original op and of all ops nested in it; each is
    // patched on its counterpart in the clone.
    SmallVector<OpOperand *> operands;
    op->walk([&operands](Operation *nestedOp) {
      for (OpOperand &operand : nestedOp->getOpOperands()) {
        operands.push_back(&operand);
      }
    });
    for (OpOperand *operand : operands) {
      Operation *nestedNewOp = mapping.lookup(operand->getOwner());
      if (operand->get() == forOp.getInductionVar()) {
        // Offset the IV by the number of iterations this stage runs ahead.
        rewriter.setInsertionPoint(newOp);
        Type t = step.getType();
        Value offset = rewriter.create<arith::MulIOp>(
            forOp.getLoc(), step,
            rewriter.create<arith::ConstantOp>(
                forOp.getLoc(),
                rewriter.getIntegerAttr(t, maxStage - stages[op])));
        Value iv = rewriter.create<arith::AddIOp>(
            forOp.getLoc(), newForOp.getInductionVar(), offset);
        nestedNewOp->setOperand(operand->getOperandNumber(), iv);
        rewriter.setInsertionPointAfter(newOp);
        continue;
      }
      Value source = operand->get();
      auto arg = dyn_cast<BlockArgument>(source);
      if (arg && arg.getOwner() == forOp.getBody()) {
        // Region iter arg: look at the value the original loop yields for it.
        Value ret = forOp.getBody()->getTerminator()->getOperand(
            arg.getArgNumber() - 1);
        if (forOp.isDefinedOutsideOfLoop(ret)) {
          if (useStage != maxStage) {
            nestedNewOp->setOperand(operand->getOperandNumber(), ret);
          }
          continue;
        }
        Operation *dep = ret.getDefiningOp();
        if (!dep)
          continue;
        auto stageDep = stages.find(dep);
        if (stageDep == stages.end() || stageDep->second == useStage)
          continue;
        // Produced by the next stage: it was already emitted earlier in this
        // kernel iteration, so use it directly.
        if (stageDep->second == useStage + 1) {
          nestedNewOp->setOperand(operand->getOperandNumber(),
                                  mapping.lookupOrDefault(ret));
          continue;
        }
        source = ret;
      }
      Operation *def = source.getDefiningOp();
      if (!def)
        continue;
      auto stageDef = stages.find(def);
      if (stageDef == stages.end() || stageDef->second == useStage)
        continue;
      // Cross-stage value: read the kernel iter arg holding the right version.
      auto remap = loopArgMap.find(
          std::make_pair(operand->get(), useStage - stageDef->second));
      assert(remap != loopArgMap.end());
      nestedNewOp->setOperand(operand->getOperandNumber(),
                              newForOp.getRegionIterArgs()[remap->second]);
    }
    if (predicates[useStage]) {
      OpBuilder::InsertionGuard insertGuard(rewriter);
      newOp = predicateFn(rewriter, newOp, predicates[useStage]);
      if (!newOp)
        return failure();
      // The predicated op may differ from the clone; remap its results.
      for (auto values : llvm::zip(op->getResults(), newOp->getResults()))
        mapping.map(std::get<0>(values), std::get<1>(values));
    }
    if (annotateFn)
      annotateFn(newOp, triton::PipeliningOption::PipelinerPart::Kernel, 0);
  }
  // Yield the kernel versions of the original yielded values first.
  llvm::SmallVector<Value> yieldOperands;
  for (OpOperand &yieldOperand :
       forOp.getBody()->getTerminator()->getOpOperands()) {
    Value source = mapping.lookupOrDefault(yieldOperand.get());
    // Without a peeled epilogue, keep the incoming iter-arg value whenever
    // the producing stage was predicated off (only if the result is used).
    if (!peelEpilogue &&
        !forOp.getResult(yieldOperand.getOperandNumber()).use_empty()) {
      Operation *def = getDefiningOpAndDistance(yieldOperand.get()).first;
      if (def) {
        auto defStage = stages.find(def);
        if (defStage != stages.end() && defStage->second < maxStage) {
          Value pred = predicates[defStage->second];
          source = rewriter.create<arith::SelectOp>(
              pred.getLoc(), pred, source,
              newForOp.getBody()
                  ->getArguments()[yieldOperand.getOperandNumber() + 1]);
        }
      }
    }
    yieldOperands.push_back(source);
  }
  // Then rotate the cross-stage versions: each slot receives the next slot's
  // incoming value, and the last slot receives the newly computed value.
  for (auto &it : crossStageValues) {
    int64_t version = maxStage - it.second.lastUseStage + 1;
    unsigned numVersionReturned = it.second.lastUseStage - it.second.defStage;
    for (unsigned i = 1; i < numVersionReturned; i++) {
      setValueMapping(it.first, newForOp->getResult(yieldOperands.size()),
                      version++);
      yieldOperands.push_back(
          newForOp.getBody()->getArguments()[yieldOperands.size() + 1 +
                                             newForOp.getNumInductionVars()]);
    }
    setValueMapping(it.first, newForOp->getResult(yieldOperands.size()),
                    version++);
    yieldOperands.push_back(mapping.lookupOrDefault(it.first));
  }
  // Record the kernel results as versions of the original iter args for use
  // by the epilogue.
  for (const auto &retVal :
       llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) {
    Operation *def = retVal.value().getDefiningOp();
    auto defStage = stages.find(def);
    if (defStage == stages.end()) {
      for (unsigned int stage = 1; stage <= maxStage; stage++)
        setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
                        retVal.value(), stage);
    } else if (defStage->second > 0) {
      setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
                      newForOp->getResult(retVal.index()),
                      maxStage - defStage->second + 1);
    }
  }
  rewriter.create<scf::YieldOp>(forOp.getLoc(), yieldOperands);
  return success();
}
/// Emits the peeled epilogue after the kernel loop, draining the pipeline.
///
/// Epilogue step `i` in [1, maxStage] clones (in `opOrder`) every op whose
/// stage is >= i, reading operands at version `maxStage - stages[op] + i`.
/// Results that the original loop yields overwrite the matching entries of
/// `returnValues` (pipelineForLoop passes in the kernel loop results) and
/// become the next version of their iter arg. For dynamic loops, clones are
/// guarded through `predicateFn` with `totalIterations >= version`, and each
/// updated return value is selected against the iter arg's previous version.
/// Returns failure if `predicateFn` returns nullptr.
LogicalResult
LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
                                    llvm::SmallVector<Value> &returnValues) {
  Location loc = forOp.getLoc();
  Type t = lb.getType();
  // Creates an integer constant of the lower bound's type.
  auto createConst = [&](int v) {
    return rewriter.create<arith::ConstantOp>(loc,
                                              rewriter.getIntegerAttr(t, v));
  };
  Value zero = createConst(0);
  Value one = createConst(1);
  // Trip count of the original loop, ceildiv(ub - lb, step), computed as
  // (ub - lb + step + (step < 0 ? 1 : -1)) / step.
  Value stepLessZero = rewriter.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::slt, step, zero);
  Value stepDecr =
      rewriter.create<arith::SelectOp>(loc, stepLessZero, one, createConst(-1));
  Value rangeDiff = rewriter.create<arith::SubIOp>(loc, ub, lb);
  Value rangeIncrStep = rewriter.create<arith::AddIOp>(loc, rangeDiff, step);
  Value rangeDecr =
      rewriter.create<arith::AddIOp>(loc, rangeIncrStep, stepDecr);
  Value totalIterations = rewriter.create<arith::DivSIOp>(loc, rangeDecr, step);
  // Index of the first iteration handled by the epilogue:
  // max(0, totalIterations - maxStage).
  Value iterI = rewriter.create<arith::SubIOp>(loc, totalIterations,
                                               createConst(maxStage));
  iterI = rewriter.create<arith::MaxSIOp>(loc, zero, iterI);
  // predicates[k] (dynamic loops only): the loop has at least k iterations.
  SmallVector<Value> predicates(maxStage + 1);
  for (int64_t i = 1; i <= maxStage; i++) {
    // IV version i = lb + step * (first epilogue iteration + i - 1).
    Value newlastIter = rewriter.create<arith::AddIOp>(
        loc, lb, rewriter.create<arith::MulIOp>(loc, step, iterI));
    setValueMapping(forOp.getInductionVar(), newlastIter, i);
    iterI = rewriter.create<arith::AddIOp>(loc, iterI, one);
    if (dynamicLoop) {
      predicates[i] = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::sge, totalIterations, createConst(i));
    }
  }
  // Epilogue step i runs stages [i, maxStage].
  for (int64_t i = 1; i <= maxStage; i++) {
    // Per yield operand: the iter arg updated in this step and the version
    // that was current when it was updated.
    SmallVector<std::pair<Value, unsigned>> returnMap(returnValues.size());
    for (Operation *op : opOrder) {
      if (stages[op] < i)
        continue;
      unsigned currentVersion = maxStage - stages[op] + i;
      unsigned nextVersion = currentVersion + 1;
      Operation *newOp =
          cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) {
            auto it = valueMapping.find(newOperand->get());
            if (it != valueMapping.end()) {
              Value replacement = it->second[currentVersion];
              newOperand->set(replacement);
            }
          });
      if (dynamicLoop) {
        OpBuilder::InsertionGuard insertGuard(rewriter);
        newOp = predicateFn(rewriter, newOp, predicates[currentVersion]);
        if (!newOp)
          return failure();
      }
      if (annotateFn)
        annotateFn(newOp, triton::PipeliningOption::PipelinerPart::Epilogue,
                   i - 1);
      for (auto [opRes, newRes] :
           llvm::zip(op->getResults(), newOp->getResults())) {
        setValueMapping(opRes, newRes, currentVersion);
        // A yielded result replaces the loop's return value and becomes the
        // next version of the corresponding iter arg.
        for (OpOperand &operand :
             forOp.getBody()->getTerminator()->getOpOperands()) {
          if (operand.get() != opRes)
            continue;
          unsigned ri = operand.getOperandNumber();
          returnValues[ri] = newRes;
          Value mapVal = forOp.getRegionIterArgs()[ri];
          returnMap[ri] = std::make_pair(mapVal, currentVersion);
          if (nextVersion <= maxStage)
            setValueMapping(mapVal, newRes, nextVersion);
        }
      }
    }
    if (dynamicLoop) {
      // Keep the previous value for return values whose producing iteration
      // did not run.
      for (auto pair : llvm::enumerate(returnValues)) {
        unsigned ri = pair.index();
        auto [mapVal, currentVersion] = returnMap[ri];
        if (mapVal) {
          unsigned nextVersion = currentVersion + 1;
          Value pred = predicates[currentVersion];
          Value prevValue = valueMapping[mapVal][currentVersion];
          auto selOp = rewriter.create<arith::SelectOp>(loc, pred, pair.value(),
                                                        prevValue);
          returnValues[ri] = selOp;
          if (nextVersion <= maxStage)
            setValueMapping(mapVal, selOp, nextVersion);
        }
      }
    }
  }
  return success();
}
/// Records `el` as version `idx` of `key`. The first time a key is seen it
/// gets `maxStage + 1` null slots. `idx` is not bounds-checked; callers must
/// pass a value in [0, maxStage].
void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) {
  llvm::SmallVector<Value> &versions =
      valueMapping.try_emplace(key, maxStage + 1).first->second;
  versions[idx] = el;
}
}
/// Software-pipelines `forOp` according to `options`.
///
/// On success, returns the new kernel loop; the original loop is replaced by
/// the pipelined results (or erased when it has no results). If `modifiedIR`
/// is non-null it is set to false when the loop was rejected before any
/// rewriting, and to true once IR generation has started (so a later failure
/// still reports the IR as modified).
FailureOr<ForOp>
mlir::triton::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
                              const triton::PipeliningOption &options,
                              bool *modifiedIR) {
  auto reportModified = [modifiedIR](bool changed) {
    if (modifiedIR)
      *modifiedIR = changed;
  };
  reportModified(false);

  LoopPipelinerInternal pipeliner;
  if (!pipeliner.initializeLoopInfo(forOp, options))
    return failure();

  // From here on, IR is created.
  reportModified(true);
  if (failed(pipeliner.emitPrologue(rewriter)))
    return failure();

  auto crossStageValues = pipeliner.analyzeCrossStageValues();
  llvm::DenseMap<std::pair<Value, unsigned>, unsigned> loopArgMap;
  ForOp newForOp =
      pipeliner.createKernelLoop(crossStageValues, rewriter, loopArgMap);
  if (failed(pipeliner.createKernel(newForOp, crossStageValues, loopArgMap,
                                    rewriter)))
    return failure();

  llvm::SmallVector<Value> returnValues =
      newForOp.getResults().take_front(forOp->getNumResults());
  if (options.peelEpilogue) {
    rewriter.setInsertionPointAfter(newForOp);
    if (failed(pipeliner.emitEpilogue(rewriter, returnValues)))
      return failure();
  }

  if (forOp->getNumResults() == 0)
    rewriter.eraseOp(forOp);
  else
    rewriter.replaceOp(forOp, returnValues);
  return newForOp;
}
/// Builds `inductionVar < upperBound - (maxStage - stage) * step`.
///
/// Assuming a positive step, this is true when the original iteration that
/// `stage` processes in the current kernel iteration (which runs
/// `maxStage - stage` iterations ahead of the kernel IV) is still in bounds.
Value mlir::triton::emitPredicateForStage(RewriterBase &rewriter,
                                          Value inductionVar, Value upperBound,
                                          Value step, uint64_t maxStage,
                                          uint64_t stage) {
  Location loc = inductionVar.getLoc();
  Type ivType = inductionVar.getType();
  Value stagesAhead = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getIntegerAttr(ivType, maxStage - stage));
  Value lookahead = rewriter.create<arith::MulIOp>(loc, step, stagesAhead);
  Value threshold = rewriter.create<arith::SubIOp>(loc, upperBound, lookahead);
  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                        inductionVar, threshold);
}