#include "mlir/Dialect/Async/Passes.h"
#include "PassDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Async/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"
#include <utility>
namespace mlir {
#define GEN_PASS_DEF_ASYNCPARALLELFOR
#include "mlir/Dialect/Async/Passes.h.inc"
}
using namespace mlir;
using namespace mlir::async;
#define DEBUG_TYPE "async-parallel-for"
namespace {
/// Pass driver for the async-parallel-for rewrite. The option members
/// (`asyncDispatch`, `numWorkerThreads`, `minTaskSize`) live in the
/// `impl::AsyncParallelForBase` base class.
struct AsyncParallelForPass
    : public impl::AsyncParallelForBase<AsyncParallelForPass> {
  /// Leaves every pass option at whatever value the base class gives it.
  AsyncParallelForPass() = default;

  /// Overrides all three pass options with the supplied values.
  AsyncParallelForPass(bool enableAsyncDispatch, int32_t workerThreadCount,
                       int32_t taskSizeLowerBound) {
    asyncDispatch = enableAsyncDispatch;
    numWorkerThreads = workerThreadCount;
    minTaskSize = taskSizeLowerBound;
  }

  /// Applies the rewrite patterns to the operation this pass runs on.
  void runOnOperation() override;
};
/// Rewrite pattern anchored on `scf::ParallelOp`. It stores the dispatch
/// strategy flag, the worker thread count and the callback used to compute
/// the minimum task size for a given parallel operation.
class AsyncParallelForRewrite : public OpRewritePattern<scf::ParallelOp> {
public:
  AsyncParallelForRewrite(MLIRContext *ctx, bool enableAsyncDispatch,
                          int32_t workerThreadCount,
                          AsyncMinTaskSizeComputationFunction minTaskSizeFn)
      : OpRewritePattern(ctx), asyncDispatch(enableAsyncDispatch),
        numWorkerThreads(workerThreadCount),
        computeMinTaskSize(std::move(minTaskSizeFn)) {}

  /// Fails on parallel ops with reductions; otherwise replaces `op` with
  /// dispatch code that calls an outlined compute function.
  LogicalResult matchAndRewrite(scf::ParallelOp op,
                                PatternRewriter &rewriter) const override;

private:
  /// Selects `doAsyncDispatch` (true) or `doSequentialDispatch` (false).
  bool asyncDispatch;
  /// Non-negative values are materialized as a constant; negative values make
  /// the rewrite query the worker thread count through a runtime op.
  int32_t numWorkerThreads;
  /// Callback producing the minimum task size value for a parallel op.
  AsyncMinTaskSizeComputationFunction computeMinTaskSize;
};
// Signature of an outlined parallel compute function together with the values
// it captures, as produced by getParallelComputeFunctionType.
struct ParallelComputeFunctionType {
// Function type with no results; its inputs are `2 + 4 * numLoops` index
// types followed by one type per captured value.
FunctionType type;
// Values used inside the scf.parallel region but defined above it, in the
// same order as their types appear at the tail of `type`'s inputs.
SmallVector<Value> captures;
};
// Typed view over the block arguments of a parallel compute function. The
// accessors slice `args` according to this layout:
//   [blockIndex, blockSize,
//    tripCounts x numLoops, lowerBounds x numLoops,
//    upperBounds x numLoops, steps x numLoops,
//    captures...]
// Member order matters: the struct is brace-initialized as {numLoops, args}.
struct ParallelComputeFunctionArgs {
// Argument 0.
BlockArgument blockIndex();
// Argument 1.
BlockArgument blockSize();
// `numLoops` arguments starting at offset 2.
ArrayRef<BlockArgument> tripCounts();
// `numLoops` arguments starting at offset 2 + numLoops.
ArrayRef<BlockArgument> lowerBounds();
// `numLoops` arguments starting at offset 2 + 2 * numLoops.
ArrayRef<BlockArgument> upperBounds();
// `numLoops` arguments starting at offset 2 + 3 * numLoops.
ArrayRef<BlockArgument> steps();
// Every argument from offset 2 + 4 * numLoops onwards.
ArrayRef<BlockArgument> captures();
// Number of loops of the scf.parallel operation the function was built for.
unsigned numLoops;
// Non-owning reference to the function arguments being sliced.
ArrayRef<BlockArgument> args;
};
// Statically known loop-nest parameters, one entry per loop. Each vector is
// filled by integerConstants, so an entry is a null IntegerAttr when the
// corresponding value did not match a constant.
// Member order matters: the struct is brace-initialized positionally.
struct ParallelComputeFunctionBounds {
// Constant trip counts (null where not constant).
SmallVector<IntegerAttr> tripCounts;
// Constant lower bounds (null where not constant).
SmallVector<IntegerAttr> lowerBounds;
// Constant upper bounds (null where not constant).
SmallVector<IntegerAttr> upperBounds;
// Constant steps (null where not constant).
SmallVector<IntegerAttr> steps;
};
// Result of createParallelComputeFunction: the outlined function plus the
// information callers need to build its operand list.
// Member order matters: the struct is brace-initialized positionally.
struct ParallelComputeFunction {
// Number of loops of the scf.parallel operation that was outlined.
unsigned numLoops;
// The outlined private function.
func::FuncOp func;
// Values captured by the parallel operation body; callers append them as the
// trailing call operands.
llvm::SmallVector<Value> captures;
};
}
/// Returns the leading function argument, which carries the block index.
BlockArgument ParallelComputeFunctionArgs::blockIndex() {
  return args.front();
}
/// Returns the second function argument, which carries the block size.
BlockArgument ParallelComputeFunctionArgs::blockSize() {
  constexpr unsigned blockSizeArgPosition = 1;
  return args[blockSizeArgPosition];
}
/// Returns the `numLoops` arguments that follow the block index and size.
ArrayRef<BlockArgument> ParallelComputeFunctionArgs::tripCounts() {
  ArrayRef<BlockArgument> pastBlockArgs = args.drop_front(2);
  return pastBlockArgs.take_front(numLoops);
}
/// Returns the `numLoops` arguments located right after the trip counts.
ArrayRef<BlockArgument> ParallelComputeFunctionArgs::lowerBounds() {
  const unsigned firstLowerBound = 2 + 1 * numLoops;
  ArrayRef<BlockArgument> tail = args.drop_front(firstLowerBound);
  return tail.take_front(numLoops);
}
/// Returns the `numLoops` arguments located right after the lower bounds.
ArrayRef<BlockArgument> ParallelComputeFunctionArgs::upperBounds() {
  const unsigned firstUpperBound = 2 + 2 * numLoops;
  ArrayRef<BlockArgument> tail = args.drop_front(firstUpperBound);
  return tail.take_front(numLoops);
}
/// Returns the `numLoops` arguments located right after the upper bounds.
ArrayRef<BlockArgument> ParallelComputeFunctionArgs::steps() {
  const unsigned firstStep = 2 + 3 * numLoops;
  ArrayRef<BlockArgument> tail = args.drop_front(firstStep);
  return tail.take_front(numLoops);
}
/// Returns every argument that follows the four per-loop argument groups.
ArrayRef<BlockArgument> ParallelComputeFunctionArgs::captures() {
  const unsigned firstCapture = 2 + 4 * numLoops;
  return args.drop_front(firstCapture);
}
/// Maps each value in `values` to the IntegerAttr bound by `m_Constant`, or to
/// a null attribute when the value does not match. The result has exactly one
/// entry per input value, in the same order.
template <typename ValueRange>
static SmallVector<IntegerAttr> integerConstants(ValueRange values) {
  SmallVector<IntegerAttr> constants;
  constants.reserve(values.size());
  for (Value candidate : values) {
    // Stays null unless the matcher binds a constant.
    IntegerAttr bound;
    matchPattern(candidate, m_Constant(&bound));
    constants.push_back(bound);
  }
  return constants;
}
/// Emits ops that split the linear `index` into one coordinate per entry of
/// `tripCounts`. Dimensions are processed from last to first: each coordinate
/// is the signed remainder by that dimension's trip count, and the running
/// index is then divided by the same trip count.
static SmallVector<Value> delinearize(ImplicitLocOpBuilder &b, Value index,
                                      ArrayRef<Value> tripCounts) {
  assert(!tripCounts.empty() && "tripCounts must be not empty");

  const size_t rank = tripCounts.size();
  SmallVector<Value> coords(rank);

  Value remaining = index;
  for (size_t dim = rank; dim-- > 0;) {
    Value extent = tripCounts[dim];
    coords[dim] = b.create<arith::RemSIOp>(remaining, extent);
    remaining = b.create<arith::DivSIOp>(remaining, extent);
  }
  return coords;
}
/// Builds the signature of the compute function outlined from `op`.
///
/// The inputs are `2 + 4 * numLoops` index values (block index, block size,
/// then four index values per loop) followed by one input per value that the
/// body of `op` uses but that is defined above it. The function has no
/// results. The captured values are returned alongside the type.
static ParallelComputeFunctionType
getParallelComputeFunctionType(scf::ParallelOp op, PatternRewriter &rewriter) {
  llvm::SetVector<Value> captured;
  getUsedValuesDefinedAbove(op.getRegion(), op.getRegion(), captured);

  const unsigned numLoops = op.getNumLoops();
  const unsigned numIndexInputs = 2 + 4 * numLoops;

  SmallVector<Type> inputTypes;
  inputTypes.reserve(numIndexInputs + captured.size());
  inputTypes.append(numIndexInputs, rewriter.getIndexType());
  for (Value value : captured)
    inputTypes.push_back(value.getType());

  ParallelComputeFunctionType signature;
  signature.type = rewriter.getFunctionType(inputTypes, TypeRange());
  signature.captures.assign(captured.begin(), captured.end());
  return signature;
}
// Outlines the body of `op` into a new private function inserted into the
// enclosing module, and returns it together with the captured values.
//
// The function takes a block index and block size and executes the iterations
// of the flattened iteration space in
//   [blockIndex * blockSize, min(blockIndex * blockSize + blockSize, tripCount))
// as a nest of scf.for loops, one per loop of `op`.
//
// `bounds` holds statically known trip counts / lower bounds / steps; known
// entries are re-materialized as constants inside the function instead of
// reading the matching function argument.
//
// `numBlockAlignedInnerLoops` is the number of innermost loops that always
// iterate over their full range [0, tripCount) rather than block-dependent
// bounds. It also selects the function name.
static ParallelComputeFunction createParallelComputeFunction(
scf::ParallelOp op, const ParallelComputeFunctionBounds &bounds,
unsigned numBlockAlignedInnerLoops, PatternRewriter &rewriter) {
// Guard the rewriter's insertion point for the duration of this function.
OpBuilder::InsertionGuard guard(rewriter);
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
ModuleOp module = op->getParentOfType<ModuleOp>();
ParallelComputeFunctionType computeFuncType =
getParallelComputeFunctionType(op, rewriter);
FunctionType type = computeFuncType.type;
func::FuncOp func = func::FuncOp::create(
op.getLoc(),
numBlockAlignedInnerLoops > 0 ? "parallel_compute_fn_with_aligned_loops"
: "parallel_compute_fn",
type);
func.setPrivate();
// Register the function in the module symbol table and tell the rewriter's
// listener about the newly inserted operation.
SymbolTable symbolTable(module);
symbolTable.insert(func);
rewriter.getListener()->notifyOperationInserted(func, {});
// Create the entry block with one argument per function input.
Block *block =
b.createBlock(&func.getBody(), func.begin(), type.getInputs(),
SmallVector<Location>(type.getNumInputs(), op.getLoc()));
b.setInsertionPointToEnd(block);
ParallelComputeFunctionArgs args = {op.getNumLoops(), func.getArguments()};
BlockArgument blockIndex = args.blockIndex();
BlockArgument blockSize = args.blockSize();
Value c0 = b.create<arith::ConstantIndexOp>(0);
Value c1 = b.create<arith::ConstantIndexOp>(1);
// For each (argument, attribute) pair: emit a constant op when the attribute
// is non-null, otherwise use the function argument itself.
auto values = [&](ArrayRef<BlockArgument> args, ArrayRef<IntegerAttr> attrs) {
return llvm::to_vector(
llvm::map_range(llvm::zip(args, attrs), [&](auto tuple) -> Value {
if (IntegerAttr attr = std::get<1>(tuple))
return b.create<arith::ConstantOp>(attr);
return std::get<0>(tuple);
}));
};
auto tripCounts = values(args.tripCounts(), bounds.tripCounts);
auto lowerBounds = values(args.lowerBounds(), bounds.lowerBounds);
auto steps = values(args.steps(), bounds.steps);
ArrayRef<BlockArgument> captures = args.captures();
// Product of all trip counts: size of the flattened iteration space.
Value tripCount = tripCounts[0];
for (unsigned i = 1; i < tripCounts.size(); ++i)
tripCount = b.create<arith::MulIOp>(tripCount, tripCounts[i]);
// One-dimensional bounds of this block:
//   blockFirstIndex = blockIndex * blockSize
//   blockLastIndex  = min(blockFirstIndex + blockSize, tripCount) - 1
Value blockFirstIndex = b.create<arith::MulIOp>(blockIndex, blockSize);
Value blockEnd0 = b.create<arith::AddIOp>(blockFirstIndex, blockSize);
Value blockEnd1 = b.create<arith::MinSIOp>(blockEnd0, tripCount);
Value blockLastIndex = b.create<arith::SubIOp>(blockEnd1, c1);
// Convert both one-dimensional indices to per-loop coordinates.
auto blockFirstCoord = delinearize(b, blockFirstIndex, tripCounts);
auto blockLastCoord = delinearize(b, blockLastIndex, tripCounts);
// Exclusive loop upper bounds: blockEndCoord[i] = blockLastCoord[i] + 1.
SmallVector<Value> blockEndCoord(op.getNumLoops());
for (size_t i = 0; i < blockLastCoord.size(); ++i)
blockEndCoord[i] = b.create<arith::AddIOp>(blockLastCoord[i], c1);
using LoopBodyBuilder =
std::function<void(OpBuilder &, Location, Value, ValueRange)>;
using LoopNestBuilder = std::function<LoopBodyBuilder(size_t loopIdx)>;
// Per-loop state filled in while the loop nest is built from the outermost
// loop inwards.
SmallVector<Value> computeBlockInductionVars(op.getNumLoops());
SmallVector<Value> isBlockFirstCoord(op.getNumLoops());
SmallVector<Value> isBlockLastCoord(op.getNumLoops());
// Returns the body builder for the loop at `loopIdx`. The body builder either
// creates the next inner loop (recursively), or for the innermost loop clones
// the body of `op`.
LoopNestBuilder workLoopBuilder = [&](size_t loopIdx) -> LoopBodyBuilder {
return [&, loopIdx](OpBuilder &nestedBuilder, Location loc, Value iv,
ValueRange args) {
ImplicitLocOpBuilder b(loc, nestedBuilder);
// Original induction variable: lowerBound + iv * step.
computeBlockInductionVars[loopIdx] = b.create<arith::AddIOp>(
lowerBounds[loopIdx], b.create<arith::MulIOp>(iv, steps[loopIdx]));
// Is this loop at the block's first / last coordinate?
isBlockFirstCoord[loopIdx] = b.create<arith::CmpIOp>(
arith::CmpIPredicate::eq, iv, blockFirstCoord[loopIdx]);
isBlockLastCoord[loopIdx] = b.create<arith::CmpIOp>(
arith::CmpIPredicate::eq, iv, blockLastCoord[loopIdx]);
// Combine with the enclosing loop's flags so that a flag is set only when
// every outer loop is also at its first / last coordinate.
if (loopIdx > 0) {
isBlockFirstCoord[loopIdx] = b.create<arith::AndIOp>(
isBlockFirstCoord[loopIdx], isBlockFirstCoord[loopIdx - 1]);
isBlockLastCoord[loopIdx] = b.create<arith::AndIOp>(
isBlockLastCoord[loopIdx], isBlockLastCoord[loopIdx - 1]);
}
// Not the innermost loop yet: create the next loop of the nest.
if (loopIdx < op.getNumLoops() - 1) {
if (loopIdx + 1 >= op.getNumLoops() - numBlockAlignedInnerLoops) {
// Block-aligned inner loop: always iterate over [0, tripCount).
b.create<scf::ForOp>(c0, tripCounts[loopIdx + 1], c1, ValueRange(),
workLoopBuilder(loopIdx + 1));
} else {
// Start at the block's first coordinate only on the first outer
// iteration, and stop at the block's end coordinate only on the last
// outer iteration; otherwise cover the whole [0, tripCount) range.
auto lb = b.create<arith::SelectOp>(isBlockFirstCoord[loopIdx],
blockFirstCoord[loopIdx + 1], c0);
auto ub = b.create<arith::SelectOp>(isBlockLastCoord[loopIdx],
blockEndCoord[loopIdx + 1],
tripCounts[loopIdx + 1]);
b.create<scf::ForOp>(lb, ub, c1, ValueRange(),
workLoopBuilder(loopIdx + 1));
}
b.create<scf::YieldOp>(loc);
return;
}
// Innermost loop: clone the parallel op body (without its terminator),
// remapping induction variables and captured values to their equivalents
// inside the outlined function.
IRMapping mapping;
mapping.map(op.getInductionVars(), computeBlockInductionVars);
mapping.map(computeFuncType.captures, captures);
for (auto &bodyOp : op.getRegion().front().without_terminator())
b.clone(bodyOp, mapping);
b.create<scf::YieldOp>(loc);
};
};
// The outermost loop always runs between the block's first and end
// coordinates.
b.create<scf::ForOp>(blockFirstCoord[0], blockEndCoord[0], c1, ValueRange(),
workLoopBuilder(0));
b.create<func::ReturnOp>(ValueRange());
return {op.getNumLoops(), func, std::move(computeFuncType.captures)};
}
// Creates the private `async_dispatch_fn` function that dispatches a range of
// blocks [blockStart, blockEnd) by recursive halving.
//
// Inputs: an !async.group, `blockStart`, and then all inputs of the compute
// function. The compute function's first input (its block index) plays the
// role of `blockEnd` here.
//
// While the range holds more than one block, the upper half [mid, end) is
// handed to an async.execute region that calls this dispatch function again,
// and the loop continues with the lower half [start, mid). After the loop the
// compute function is called directly with `blockStart` as its block index.
static func::FuncOp
createAsyncDispatchFunction(ParallelComputeFunction &computeFunc,
PatternRewriter &rewriter) {
// Guard the rewriter's insertion point for the duration of this function.
OpBuilder::InsertionGuard guard(rewriter);
Location loc = computeFunc.func.getLoc();
ImplicitLocOpBuilder b(loc, rewriter);
ModuleOp module = computeFunc.func->getParentOfType<ModuleOp>();
ArrayRef<Type> computeFuncInputTypes =
computeFunc.func.getFunctionType().getInputs();
// Prepend the async group and the `blockStart` index to the compute function
// inputs.
SmallVector<Type> inputTypes;
inputTypes.push_back(async::GroupType::get(rewriter.getContext()));
inputTypes.push_back(rewriter.getIndexType());
inputTypes.append(computeFuncInputTypes.begin(), computeFuncInputTypes.end());
FunctionType type = rewriter.getFunctionType(inputTypes, TypeRange());
func::FuncOp func = func::FuncOp::create(loc, "async_dispatch_fn", type);
func.setPrivate();
// Register the function in the module symbol table and tell the rewriter's
// listener about the newly inserted operation.
SymbolTable symbolTable(module);
symbolTable.insert(func);
rewriter.getListener()->notifyOperationInserted(func, {});
// Create the entry block with one argument per function input.
Block *block = b.createBlock(&func.getBody(), func.begin(), type.getInputs(),
SmallVector<Location>(type.getNumInputs(), loc));
b.setInsertionPointToEnd(block);
Type indexTy = b.getIndexType();
Value c1 = b.create<arith::ConstantIndexOp>(1);
Value c2 = b.create<arith::ConstantIndexOp>(2);
Value group = block->getArgument(0);
Value blockStart = block->getArgument(1);
Value blockEnd = block->getArgument(2);
// The while loop carries the current (start, end) block range.
SmallVector<Type> types = {indexTy, indexTy};
SmallVector<Value> operands = {blockStart, blockEnd};
SmallVector<Location> locations = {loc, loc};
scf::WhileOp whileOp = b.create<scf::WhileOp>(types, operands);
Block *before = b.createBlock(&whileOp.getBefore(), {}, types, locations);
Block *after = b.createBlock(&whileOp.getAfter(), {}, types, locations);
// Condition region: keep splitting while `end - start > 1`.
{
b.setInsertionPointToEnd(before);
Value start = before->getArgument(0);
Value end = before->getArgument(1);
Value distance = b.create<arith::SubIOp>(end, start);
Value dispatch =
b.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, distance, c1);
b.create<scf::ConditionOp>(dispatch, before->getArguments());
}
// Body region: dispatch [mid, end) asynchronously, continue with [start, mid).
{
b.setInsertionPointToEnd(after);
Value start = after->getArgument(0);
Value end = after->getArgument(1);
Value distance = b.create<arith::SubIOp>(end, start);
Value halfDistance = b.create<arith::DivSIOp>(distance, c2);
Value midIndex = b.create<arith::AddIOp>(start, halfDistance);
// Inside async.execute: call this dispatch function recursively with the
// original arguments, except that the block range is replaced by
// [midIndex, end).
auto executeBodyBuilder = [&](OpBuilder &executeBuilder,
Location executeLoc, ValueRange executeArgs) {
SmallVector<Value> operands{block->getArguments().begin(),
block->getArguments().end()};
operands[1] = midIndex;
operands[2] = end;
executeBuilder.create<func::CallOp>(executeLoc, func.getSymName(),
func.getResultTypes(), operands);
executeBuilder.create<async::YieldOp>(executeLoc, ValueRange());
};
auto execute = b.create<ExecuteOp>(TypeRange(), ValueRange(), ValueRange(),
executeBodyBuilder);
// Track the token of the launched region in the async group.
b.create<AddToGroupOp>(indexTy, execute.getToken(), group);
b.create<scf::YieldOp>(ValueRange({start, midIndex}));
}
b.setInsertionPointAfter(whileOp);
// Drop the group, blockStart and blockEnd arguments; the remaining arguments
// are forwarded to the compute function after `blockStart`, which is passed
// as the block index.
auto forwardedInputs = block->getArguments().drop_front(3);
SmallVector<Value> computeFuncOperands = {blockStart};
computeFuncOperands.append(forwardedInputs.begin(), forwardedInputs.end());
b.create<func::CallOp>(computeFunc.func.getSymName(),
computeFunc.func.getResultTypes(),
computeFuncOperands);
b.create<func::ReturnOp>(ValueRange());
return func;
}
static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
ParallelComputeFunction ¶llelComputeFunction,
scf::ParallelOp op, Value blockSize,
Value blockCount,
const SmallVector<Value> &tripCounts) {
MLIRContext *ctx = op->getContext();
func::FuncOp asyncDispatchFunction =
createAsyncDispatchFunction(parallelComputeFunction, rewriter);
Value c0 = b.create<arith::ConstantIndexOp>(0);
Value c1 = b.create<arith::ConstantIndexOp>(1);
auto appendBlockComputeOperands = [&](SmallVector<Value> &operands) {
operands.append(tripCounts);
operands.append(op.getLowerBound().begin(), op.getLowerBound().end());
operands.append(op.getUpperBound().begin(), op.getUpperBound().end());
operands.append(op.getStep().begin(), op.getStep().end());
operands.append(parallelComputeFunction.captures);
};
Value isSingleBlock =
b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, blockCount, c1);
auto syncDispatch = [&](OpBuilder &nestedBuilder, Location loc) {
ImplicitLocOpBuilder b(loc, nestedBuilder);
SmallVector<Value> operands = {c0, blockSize};
appendBlockComputeOperands(operands);
b.create<func::CallOp>(parallelComputeFunction.func.getSymName(),
parallelComputeFunction.func.getResultTypes(),
operands);
b.create<scf::YieldOp>();
};
auto asyncDispatch = [&](OpBuilder &nestedBuilder, Location loc) {
ImplicitLocOpBuilder b(loc, nestedBuilder);
Value groupSize = b.create<arith::SubIOp>(blockCount, c1);
Value group = b.create<CreateGroupOp>(GroupType::get(ctx), groupSize);
SmallVector<Value> operands = {group, c0, blockCount, blockSize};
appendBlockComputeOperands(operands);
b.create<func::CallOp>(asyncDispatchFunction.getSymName(),
asyncDispatchFunction.getResultTypes(), operands);
b.create<AwaitAllOp>(group);
b.create<scf::YieldOp>();
};
b.create<scf::IfOp>(isSingleBlock, syncDispatch, asyncDispatch);
}
static void
doSequentialDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
ParallelComputeFunction ¶llelComputeFunction,
scf::ParallelOp op, Value blockSize, Value blockCount,
const SmallVector<Value> &tripCounts) {
MLIRContext *ctx = op->getContext();
func::FuncOp compute = parallelComputeFunction.func;
Value c0 = b.create<arith::ConstantIndexOp>(0);
Value c1 = b.create<arith::ConstantIndexOp>(1);
Value groupSize = b.create<arith::SubIOp>(blockCount, c1);
Value group = b.create<CreateGroupOp>(GroupType::get(ctx), groupSize);
using LoopBodyBuilder =
std::function<void(OpBuilder &, Location, Value, ValueRange)>;
auto computeFuncOperands = [&](Value blockIndex) -> SmallVector<Value> {
SmallVector<Value> computeFuncOperands = {blockIndex, blockSize};
computeFuncOperands.append(tripCounts);
computeFuncOperands.append(op.getLowerBound().begin(),
op.getLowerBound().end());
computeFuncOperands.append(op.getUpperBound().begin(),
op.getUpperBound().end());
computeFuncOperands.append(op.getStep().begin(), op.getStep().end());
computeFuncOperands.append(parallelComputeFunction.captures);
return computeFuncOperands;
};
LoopBodyBuilder loopBuilder = [&](OpBuilder &loopBuilder, Location loc,
Value iv, ValueRange args) {
ImplicitLocOpBuilder b(loc, loopBuilder);
auto executeBodyBuilder = [&](OpBuilder &executeBuilder,
Location executeLoc, ValueRange executeArgs) {
executeBuilder.create<func::CallOp>(executeLoc, compute.getSymName(),
compute.getResultTypes(),
computeFuncOperands(iv));
executeBuilder.create<async::YieldOp>(executeLoc, ValueRange());
};
auto execute = b.create<ExecuteOp>(TypeRange(), ValueRange(), ValueRange(),
executeBodyBuilder);
b.create<AddToGroupOp>(rewriter.getIndexType(), execute.getToken(), group);
b.create<scf::YieldOp>();
};
b.create<scf::ForOp>(c1, blockCount, c1, ValueRange(), loopBuilder);
b.create<func::CallOp>(compute.getSymName(), compute.getResultTypes(),
computeFuncOperands(c0));
b.create<AwaitAllOp>(group);
}
// Replaces `op` with code that splits its flattened iteration space into
// blocks and dispatches an outlined compute function for each block.
//
// Returns failure (leaving the IR untouched) when `op` has reductions.
// Otherwise emits an scf.if that does nothing when the total trip count is
// zero and runs the dispatch code otherwise, erases `op`, and returns success.
LogicalResult
AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
PatternRewriter &rewriter) const {
// Parallel ops with reductions are not handled by this rewrite.
if (op.getNumReductions() != 0)
return failure();
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
// The min task size is computed before the op's region is modified below.
Value minTaskSize = computeMinTaskSize(b, op);
cloneConstantsIntoTheRegion(op.getRegion(), rewriter);
// Per-loop trip count: ceil_div(upperBound - lowerBound, step).
SmallVector<Value> tripCounts(op.getNumLoops());
for (size_t i = 0; i < op.getNumLoops(); ++i) {
auto lb = op.getLowerBound()[i];
auto ub = op.getUpperBound()[i];
auto step = op.getStep()[i];
auto range = b.createOrFold<arith::SubIOp>(ub, lb);
tripCounts[i] = b.createOrFold<arith::CeilDivSIOp>(range, step);
}
// Product of all trip counts: size of the flattened iteration space.
Value tripCount = tripCounts[0];
for (size_t i = 1; i < tripCounts.size(); ++i)
tripCount = b.create<arith::MulIOp>(tripCount, tripCounts[i]);
Value c0 = b.create<arith::ConstantIndexOp>(0);
Value isZeroIterations =
b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, tripCount, c0);
// "then" branch: nothing to do for an empty iteration space.
auto noOp = [&](OpBuilder &nestedBuilder, Location loc) {
nestedBuilder.create<scf::YieldOp>(loc);
};
// "else" branch: compute the block size and dispatch the compute blocks.
auto dispatch = [&](OpBuilder &nestedBuilder, Location loc) {
ImplicitLocOpBuilder b(loc, nestedBuilder);
// Statically known loop parameters; entries are null when not constant.
ParallelComputeFunctionBounds staticBounds = {
integerConstants(tripCounts),
integerConstants(op.getLowerBound()),
integerConstants(op.getUpperBound()),
integerConstants(op.getStep()),
};
// Upper limit on the static iteration count of inner loops that are treated
// as block-aligned ("unrollable").
static constexpr int64_t maxUnrollableIterations = 512;
int numUnrollableLoops = 0;
// Unknown (null) trip counts are treated as 0.
auto getInt = [](IntegerAttr attr) { return attr ? attr.getInt() : 0; };
// numIterations[i] is the product of the static trip counts of loops i..last;
// it is 0 as soon as any of them is unknown.
SmallVector<int64_t> numIterations(op.getNumLoops());
numIterations.back() = getInt(staticBounds.tripCounts.back());
for (int i = op.getNumLoops() - 2; i >= 0; --i) {
int64_t tripCount = getInt(staticBounds.tripCounts[i]);
int64_t innerIterations = numIterations[i + 1];
numIterations[i] = tripCount * innerIterations;
// Count inner loop suffixes with a known, small enough iteration count.
if (innerIterations > 0 && innerIterations <= maxUnrollableIterations)
numUnrollableLoops++;
}
// Worker thread count: a constant when configured (>= 0), otherwise queried
// through a runtime op.
Value numWorkerThreadsVal;
if (numWorkerThreads >= 0)
numWorkerThreadsVal = b.create<arith::ConstantIndexOp>(numWorkerThreads);
else
numWorkerThreadsVal = b.create<async::RuntimeNumWorkerThreadsOp>();
// Scaling factor applied to the worker count, selected by a chain of
// compare+select ops equivalent to:
//   numWorkerThreads > 64 ? 0.6 : > 32 ? 0.8 : > 16 ? 1.0
//                    : > 8 ? 2.0 : > 4 ? 4.0 : 8.0
// Each pair is (exclusive lower end of the bracket, factor).
const SmallVector<std::pair<int, float>> overshardingBrackets = {
{4, 4.0f}, {8, 2.0f}, {16, 1.0f}, {32, 0.8f}, {64, 0.6f}};
const float initialOvershardingFactor = 8.0f;
Value scalingFactor = b.create<arith::ConstantFloatOp>(
llvm::APFloat(initialOvershardingFactor), b.getF32Type());
for (const std::pair<int, float> &p : overshardingBrackets) {
Value bracketBegin = b.create<arith::ConstantIndexOp>(p.first);
Value inBracket = b.create<arith::CmpIOp>(
arith::CmpIPredicate::sgt, numWorkerThreadsVal, bracketBegin);
Value bracketScalingFactor = b.create<arith::ConstantFloatOp>(
llvm::APFloat(p.second), b.getF32Type());
scalingFactor = b.create<arith::SelectOp>(inBracket, bracketScalingFactor,
scalingFactor);
}
// scaledWorkers = index(i32(scalingFactor * f32(i32(numWorkerThreads)))).
Value numWorkersIndex =
b.create<arith::IndexCastOp>(b.getI32Type(), numWorkerThreadsVal);
Value numWorkersFloat =
b.create<arith::SIToFPOp>(b.getF32Type(), numWorkersIndex);
Value scaledNumWorkers =
b.create<arith::MulFOp>(scalingFactor, numWorkersFloat);
Value scaledNumInt =
b.create<arith::FPToSIOp>(b.getI32Type(), scaledNumWorkers);
Value scaledWorkers =
b.create<arith::IndexCastOp>(b.getIndexType(), scaledNumInt);
// At least one compute block.
Value maxComputeBlocks = b.create<arith::MaxSIOp>(
b.create<arith::ConstantIndexOp>(1), scaledWorkers);
// blockSize = min(tripCount,
//                 max(ceil_div(tripCount, maxComputeBlocks), minTaskSize))
Value bs0 = b.create<arith::CeilDivSIOp>(tripCount, maxComputeBlocks);
Value bs1 = b.create<arith::MaxSIOp>(bs0, minTaskSize);
Value blockSize = b.create<arith::MinSIOp>(tripCount, bs1);
// Pick the dispatch strategy configured on the pattern.
auto doDispatch = asyncDispatch ? doAsyncDispatch : doSequentialDispatch;
// blockCount = ceil_div(tripCount, blockSize).
Value blockCount = b.create<arith::CeilDivSIOp>(tripCount, blockSize);
// Dispatch a compute function with no block-aligned inner loops.
auto dispatchDefault = [&](OpBuilder &nestedBuilder, Location loc) {
ParallelComputeFunction compute =
createParallelComputeFunction(op, staticBounds, 0, rewriter);
ImplicitLocOpBuilder b(loc, nestedBuilder);
doDispatch(b, rewriter, compute, op, blockSize, blockCount, tripCounts);
b.create<scf::YieldOp>();
};
// Dispatch a compute function whose `numUnrollableLoops` innermost loops are
// block-aligned; the block size is rounded up to a multiple of the static
// iteration count of those loops. Note that `blockCount` is still the value
// computed from the unaligned block size.
auto dispatchBlockAligned = [&](OpBuilder &nestedBuilder, Location loc) {
ParallelComputeFunction compute = createParallelComputeFunction(
op, staticBounds, numUnrollableLoops, rewriter);
ImplicitLocOpBuilder b(loc, nestedBuilder);
Value numIters = b.create<arith::ConstantIndexOp>(
numIterations[op.getNumLoops() - numUnrollableLoops]);
Value alignedBlockSize = b.create<arith::MulIOp>(
b.create<arith::CeilDivSIOp>(blockSize, numIters), numIters);
doDispatch(b, rewriter, compute, op, alignedBlockSize, blockCount,
tripCounts);
b.create<scf::YieldOp>();
};
// Use the block-aligned variant only when the block size is at least the
// iteration count of the unrollable inner loops (checked at runtime).
if (numUnrollableLoops > 0) {
Value numIters = b.create<arith::ConstantIndexOp>(
numIterations[op.getNumLoops() - numUnrollableLoops]);
Value useBlockAlignedComputeFn = b.create<arith::CmpIOp>(
arith::CmpIPredicate::sge, blockSize, numIters);
b.create<scf::IfOp>(useBlockAlignedComputeFn, dispatchBlockAligned,
dispatchDefault);
// Terminator of the enclosing "else" branch; in the other case below
// dispatchDefault emits the terminator itself.
b.create<scf::YieldOp>();
} else {
dispatchDefault(b, loc);
}
};
// Emit the zero-iterations check around the dispatch code, then remove `op`.
b.create<scf::IfOp>(isZeroIterations, noOp, dispatch);
rewriter.eraseOp(op);
return success();
}
void AsyncParallelForPass::runOnOperation() {
MLIRContext *ctx = &getContext();
RewritePatternSet patterns(ctx);
populateAsyncParallelForPatterns(
patterns, asyncDispatch, numWorkerThreads,
[&](ImplicitLocOpBuilder builder, scf::ParallelOp op) {
return builder.create<arith::ConstantIndexOp>(minTaskSize);
});
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
signalPassFailure();
}
/// Creates the pass with default-constructed options.
std::unique_ptr<Pass> mlir::createAsyncParallelForPass() {
  auto pass = std::make_unique<AsyncParallelForPass>();
  return pass;
}
/// Creates the pass with all three options set explicitly.
std::unique_ptr<Pass> mlir::createAsyncParallelForPass(bool asyncDispatch,
                                                       int32_t numWorkerThreads,
                                                       int32_t minTaskSize) {
  auto pass = std::make_unique<AsyncParallelForPass>(
      asyncDispatch, numWorkerThreads, minTaskSize);
  return pass;
}
/// Adds the `AsyncParallelForRewrite` pattern to `patterns`, configured with
/// the given dispatch mode, worker thread count and min-task-size callback.
void mlir::async::populateAsyncParallelForPatterns(
    RewritePatternSet &patterns, bool asyncDispatch, int32_t numWorkerThreads,
    const AsyncMinTaskSizeComputationFunction &computeMinTaskSize) {
  patterns.add<AsyncParallelForRewrite>(patterns.getContext(), asyncDispatch,
                                        numWorkerThreads, computeMinTaskSize);
}