#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/FormatVariadic.h"
#include <numeric>
#include <utility>
using namespace mlir;
using namespace mlir::vector;
/// Builds the affine map describing which dimensions differ in size between
/// `sequentialType` and `distributedType`. The map has as many dims as the
/// rank of `sequentialType` and one result per differing dimension, in
/// increasing dimension order.
static AffineMap calculateImplicitMap(VectorType sequentialType,
                                      VectorType distributedType) {
  MLIRContext *ctx = distributedType.getContext();
  const unsigned rank = sequentialType.getRank();
  SmallVector<AffineExpr> distributedDims;
  distributedDims.reserve(1);
  for (unsigned dim = 0; dim != rank; ++dim) {
    // Dimensions of identical size contribute no result to the map.
    if (sequentialType.getDimSize(dim) == distributedType.getDimSize(dim))
      continue;
    distributedDims.push_back(getAffineDimExpr(dim, ctx));
  }
  return AffineMap::get(rank, 0, distributedDims, ctx);
}
namespace {
struct DistributedLoadStoreHelper {
DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal,
Value laneId, Value zero)
: sequentialVal(sequentialVal), distributedVal(distributedVal),
laneId(laneId), zero(zero) {
sequentialVectorType = dyn_cast<VectorType>(sequentialVal.getType());
distributedVectorType = dyn_cast<VectorType>(distributedVal.getType());
if (sequentialVectorType && distributedVectorType)
distributionMap =
calculateImplicitMap(sequentialVectorType, distributedVectorType);
}
Value buildDistributedOffset(RewriterBase &b, Location loc, int64_t index) {
int64_t distributedSize = distributedVectorType.getDimSize(index);
AffineExpr tid = getAffineSymbolExpr(0, b.getContext());
return b.createOrFold<affine::AffineApplyOp>(loc, tid * distributedSize,
ArrayRef<Value>{laneId});
}
Operation *buildStore(RewriterBase &b, Location loc, Value val,
Value buffer) {
assert((val == distributedVal || val == sequentialVal) &&
"Must store either the preregistered distributed or the "
"preregistered sequential value.");
if (!isa<VectorType>(val.getType()))
return b.create<memref::StoreOp>(loc, val, buffer, zero);
int64_t rank = sequentialVectorType.getRank();
SmallVector<Value> indices(rank, zero);
if (val == distributedVal) {
for (auto dimExpr : distributionMap.getResults()) {
int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
indices[index] = buildDistributedOffset(b, loc, index);
}
}
SmallVector<bool> inBounds(indices.size(), true);
return b.create<vector::TransferWriteOp>(
loc, val, buffer, indices,
ArrayRef<bool>(inBounds.begin(), inBounds.end()));
}
Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer) {
if (!isa<VectorType>(type))
return b.create<memref::LoadOp>(loc, buffer, zero);
assert((type == distributedVectorType || type == sequentialVectorType) &&
"Must store either the preregistered distributed or the "
"preregistered sequential type.");
SmallVector<Value> indices(sequentialVectorType.getRank(), zero);
if (type == distributedVectorType) {
for (auto dimExpr : distributionMap.getResults()) {
int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
indices[index] = buildDistributedOffset(b, loc, index);
}
}
SmallVector<bool> inBounds(indices.size(), true);
return b.create<vector::TransferReadOp>(
loc, cast<VectorType>(type), buffer, indices,
ArrayRef<bool>(inBounds.begin(), inBounds.end()));
}
Value sequentialVal, distributedVal, laneId, zero;
VectorType sequentialVectorType, distributedVectorType;
AffineMap distributionMap;
};
}
/// Creates a new WarpExecuteOnLane0Op right before `warpOp` with result types
/// `newReturnTypes`, moves the body of `warpOp` into it and rewrites the
/// terminator so that it yields `newYieldedValues`. `warpOp` itself is neither
/// erased nor replaced here. The rewriter's insertion point is restored on
/// return.
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes) {
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(warpOp);
  auto replacementOp = rewriter.create<WarpExecuteOnLane0Op>(
      warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());

  Region &sourceRegion = warpOp.getBodyRegion();
  Region &targetRegion = replacementOp.getBodyRegion();
  // Remember the block the builder created so it can be dropped once the
  // original body has been moved in front of it.
  Block *placeholderBlock = &targetRegion.front();
  rewriter.inlineRegionBefore(sourceRegion, targetRegion,
                              targetRegion.begin());
  rewriter.eraseBlock(placeholderBlock);
  assert(replacementOp.getWarpRegion().hasOneBlock() &&
         "expected WarpOp with single block");

  auto terminator =
      cast<vector::YieldOp>(targetRegion.front().getTerminator());
  rewriter.modifyOpInPlace(terminator, [&]() {
    terminator.getOperandsMutable().assign(newYieldedValues);
  });
  return replacementOp;
}
/// Creates a new WarpExecuteOnLane0Op that yields everything `warpOp` yields
/// plus `newYieldedValues` (with result types `newReturnTypes`), moves the
/// body into it and replaces `warpOp` with the leading results of the new op.
///
/// For every entry of `newYieldedValues`, the position of the corresponding
/// result of the new op is appended to `indices`. A value that is already
/// yielded is not yielded a second time; the index of its first occurrence is
/// reported instead.
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes,
    llvm::SmallVector<size_t> &indices) {
  assert(newYieldedValues.size() == newReturnTypes.size() &&
         "expected exactly one return type per newly yielded value");
  SmallVector<Type> types(warpOp.getResultTypes().begin(),
                          warpOp.getResultTypes().end());
  auto yield = cast<vector::YieldOp>(
      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
  // Keep the existing operands verbatim (including duplicates) so that the
  // yielded values stay aligned one-to-one with `types` and with the results
  // of `warpOp` that are replaced below.
  SmallVector<Value> yieldValues(yield.getOperands().begin(),
                                 yield.getOperands().end());
  for (auto [value, type] :
       llvm::zip_equal(newYieldedValues, newReturnTypes)) {
    auto *it = llvm::find(yieldValues, value);
    if (it != yieldValues.end()) {
      // The value already exits the region: reuse that result.
      indices.push_back(static_cast<size_t>(it - yieldValues.begin()));
      continue;
    }
    yieldValues.push_back(value);
    types.push_back(type);
    indices.push_back(yieldValues.size() - 1);
  }
  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
      rewriter, warpOp, yieldValues, types);
  rewriter.replaceOp(warpOp,
                     newWarpOp.getResults().take_front(warpOp.getNumResults()));
  return newWarpOp;
}
/// Returns true when every operand of `op` satisfies `definedOutside`, `op`
/// is memory-effect free and `op` has no regions.
static bool canBeHoisted(Operation *op,
                         function_ref<bool(Value)> definedOutside) {
  if (!llvm::all_of(op->getOperands(), definedOutside))
    return false;
  if (!isMemoryEffectFree(op))
    return false;
  return op->getNumRegions() == 0;
}
/// Returns the first operand of the terminator of `warpOp` whose defining
/// operation satisfies `fn` and whose matching warp result has at least one
/// use, or null if there is no such operand.
static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
                                const std::function<bool(Operation *)> &fn) {
  Operation *terminator = warpOp.getBodyRegion().front().getTerminator();
  auto yield = cast<vector::YieldOp>(terminator);
  for (OpOperand &candidate : yield->getOpOperands()) {
    Operation *producer = candidate.get().getDefiningOp();
    // Block arguments have no defining op and never match.
    if (!producer || !fn(producer))
      continue;
    if (warpOp.getResult(candidate.getOperandNumber()).use_empty())
      continue;
    return &candidate;
  }
  return nullptr;
}
/// Creates, at the rewriter's current insertion point, an operation with the
/// same name and attributes as `op` but with the given `operands` and
/// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
                                              Location loc, Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
  StringRef opName = op->getName().getStringRef();
  OperationState state(loc, opName, operands, resultTypes, op->getAttrs());
  return rewriter.create(state);
}
namespace {
/// Rewrites a WarpExecuteOnLane0Op into an scf.if whose condition is
/// `laneid == 0`. Values entering and leaving the region are passed through
/// buffers obtained from `options.warpAllocationFn`;
/// `options.warpSyncronizationFn` is invoked once before the scf.if when the
/// op has operands and once after it when the region yields values.
struct WarpOpToScfIfPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
WarpOpToScfIfPattern(MLIRContext *context,
const WarpExecuteOnLane0LoweringOptions &options,
PatternBenefit benefit = 1)
: OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
options(options) {}
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
assert(warpOp.getBodyRegion().hasOneBlock() &&
"expected WarpOp with single block");
Block *warpOpBody = &warpOp.getBodyRegion().front();
Location loc = warpOp.getLoc();
// The guard restores the insertion point that was active on entry.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(warpOp);
// Build the `laneid == 0` predicate right before the warp op.
Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value isLane0 = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
// NOTE(review): the `false` argument presumably requests an scf.if without
// an else region -- confirm against the scf::IfOp builders.
auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
false);
// Drop the terminator the builder put in the then-block; the warp body
// (with its own terminator) is merged into that block below.
rewriter.eraseOp(ifOp.thenBlock()->getTerminator());
// For every warp operand: store the distributed value to a buffer before
// the scf.if and load the sequential value at the start of the then-block.
// The loaded values replace the block arguments of the warp body.
SmallVector<Value> bbArgReplacements;
for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
Value sequentialVal = warpOpBody->getArgument(it.index());
Value distributedVal = it.value();
DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
warpOp.getLaneid(), c0);
rewriter.setInsertionPoint(ifOp);
Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
sequentialVal.getType());
helper.buildStore(rewriter, loc, distributedVal, buffer);
rewriter.setInsertionPointToStart(ifOp.thenBlock());
bbArgReplacements.push_back(
helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer));
}
// Synchronize before the scf.if, after the stores emitted above.
if (!warpOp.getArgs().empty()) {
rewriter.setInsertionPoint(ifOp);
options.warpSyncronizationFn(loc, rewriter, warpOp);
}
// Move the warp body to the end of the then-block.
rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);
// For every yielded value: store the sequential value just before the
// yield and load a value of the distributed result type after the scf.if.
SmallVector<Value> replacements;
auto yieldOp = cast<vector::YieldOp>(ifOp.thenBlock()->getTerminator());
Location yieldLoc = yieldOp.getLoc();
for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
Value sequentialVal = it.value();
Value distributedVal = warpOp->getResult(it.index());
DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
warpOp.getLaneid(), c0);
rewriter.setInsertionPoint(ifOp);
Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
sequentialVal.getType());
rewriter.setInsertionPoint(yieldOp);
helper.buildStore(rewriter, loc, sequentialVal, buffer);
rewriter.setInsertionPointAfter(ifOp);
replacements.push_back(
helper.buildLoad(rewriter, loc, distributedVal.getType(), buffer));
}
// The insertion point is reset to directly after the scf.if, so the
// synchronization lands ahead of the loads created in the loop above.
if (!yieldOp.getOperands().empty()) {
rewriter.setInsertionPointAfter(ifOp);
options.warpSyncronizationFn(loc, rewriter, warpOp);
}
// Replace the vector.yield by an operand-less scf.yield and substitute the
// warp results with the loaded values.
rewriter.eraseOp(yieldOp);
rewriter.setInsertionPointToEnd(ifOp.thenBlock());
rewriter.create<scf::YieldOp>(yieldLoc);
rewriter.replaceOp(warpOp, replacements);
return success();
}
private:
// Held by reference: the options object passed to the constructor must
// outlive this pattern.
const WarpExecuteOnLane0LoweringOptions &options;
};
/// Re-creates `writeOp` right after a new warp op that additionally yields the
/// written vector as `targetType` and, when `maybeMaskType` is non-null, the
/// mask as `maybeMaskType`. The original write is erased and the clone is
/// rewired to consume the new warp results. Returns the cloned write.
static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
                                            WarpExecuteOnLane0Op warpOp,
                                            vector::TransferWriteOp writeOp,
                                            VectorType targetType,
                                            VectorType maybeMaskType) {
  assert(writeOp->getParentOp() == warpOp &&
         "write must be nested immediately under warp");
  OpBuilder::InsertionGuard guard(rewriter);
  const bool yieldMask = static_cast<bool>(maybeMaskType);

  // Collect what has to leave the region: always the vector, optionally the
  // mask.
  SmallVector<Value> extraYields = {writeOp.getVector()};
  SmallVector<Type> extraTypes = {targetType};
  if (yieldMask) {
    extraYields.push_back(writeOp.getMask());
    extraTypes.push_back(maybeMaskType);
  }
  SmallVector<size_t> newRetIndices;
  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
      rewriter, warpOp, extraYields, extraTypes, newRetIndices);

  rewriter.setInsertionPointAfter(newWarpOp);
  Operation *cloned = rewriter.clone(*writeOp.getOperation());
  auto newWriteOp = cast<vector::TransferWriteOp>(cloned);
  rewriter.eraseOp(writeOp);
  newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
  if (yieldMask)
    newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
  return newWriteOp;
}
/// Computes the type obtained by spreading `originalType` over `warpSize`
/// lanes along the dimensions listed by the results of `map`, visited in
/// order. A dimension that is a multiple of the remaining lane count absorbs
/// all of it; a dimension that evenly divides the remaining lane count is
/// reduced to 1 and the remainder carries over to the next result. Returns a
/// null VectorType when the lanes cannot be consumed exactly.
static VectorType getDistributedType(VectorType originalType, AffineMap map,
                                     int64_t warpSize) {
  SmallVector<int64_t> targetShape(originalType.getShape().begin(),
                                   originalType.getShape().end());
  const unsigned numResults = map.getNumResults();
  for (unsigned resultIdx = 0; resultIdx < numResults; ++resultIdx) {
    int64_t &dimSize = targetShape[map.getDimPosition(resultIdx)];
    if (dimSize % warpSize == 0) {
      // This dimension takes all remaining lanes; nothing left to distribute.
      dimSize /= warpSize;
      warpSize = 1;
      break;
    }
    if (warpSize % dimSize != 0)
      return VectorType();
    // Fully distribute this dimension and keep going with what is left.
    warpSize /= dimSize;
    dimSize = 1;
  }
  if (warpSize != 1)
    return VectorType();
  return VectorType::get(targetShape, originalType.getElementType());
}
/// Moves a vector.transfer_write that directly precedes the terminator of a
/// warp region out of that region. The write is either distributed (the
/// vector is yielded with a distributed type and the write indices are offset
/// per lane) or, failing that, extracted into a separate result-less warp op
/// placed after the original one.
struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> {
WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
unsigned maxNumElementsToExtract, PatternBenefit b = 1)
: OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
distributionMapFn(std::move(fn)),
maxNumElementsToExtract(maxNumElementsToExtract) {}
/// Tries to distribute `writeOp`. Fails without modifying the IR for rank-0
/// vectors, when no distributed type can be computed, and for masked writes
/// whose permutation map is not a minor identity.
LogicalResult tryDistributeOp(RewriterBase &rewriter,
vector::TransferWriteOp writeOp,
WarpExecuteOnLane0Op warpOp) const {
VectorType writtenVectorType = writeOp.getVectorType();
if (writtenVectorType.getRank() == 0)
return failure();
// The distribution map for the written vector comes from the user-supplied
// callback.
AffineMap map = distributionMapFn(writeOp.getVector());
VectorType targetType =
getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
if (!targetType)
return failure();
VectorType maskType;
if (writeOp.getMask()) {
if (!writeOp.getPermutationMap().isMinorIdentity())
return failure();
maskType =
getDistributedType(writeOp.getMaskType(), map, warpOp.getWarpSize());
}
// From here on the IR is modified: the write is re-created after a new warp
// op that yields the vector (and the mask, if any).
vector::TransferWriteOp newWriteOp =
cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType);
auto newWarpOp =
newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
rewriter.setInsertionPoint(newWriteOp);
// Per-dimension ratio between the sequential and the distributed size.
SmallVector<OpFoldResult> delinearizedIdSizes;
for (auto [seqSize, distSize] :
llvm::zip_equal(writtenVectorType.getShape(), targetType.getShape())) {
assert(seqSize % distSize == 0 && "Invalid distributed vector shape");
delinearizedIdSizes.push_back(rewriter.getIndexAttr(seqSize / distSize));
}
// With more than one map result the lane id is split into one id per
// dimension; otherwise the lane id itself is used for every dimension.
SmallVector<Value> delinearized;
if (map.getNumResults() > 1) {
delinearized = rewriter
.create<mlir::affine::AffineDelinearizeIndexOp>(
newWarpOp.getLoc(), newWarpOp.getLaneid(),
delinearizedIdSizes)
.getResults();
} else {
delinearized.append(targetType.getRank(), newWarpOp.getLaneid());
}
AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
Location loc = newWriteOp.getLoc();
SmallVector<Value> indices(newWriteOp.getIndices().begin(),
newWriteOp.getIndices().end());
// Offset each index addressed by a distributed vector dimension by
// `laneId * distributedDimSize`. Composed results that are not plain dim
// expressions are left untouched.
for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
AffineExpr d0, d1;
bindDims(newWarpOp.getContext(), d0, d1);
auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
if (!indexExpr)
continue;
unsigned indexPos = indexExpr.getPosition();
unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
Value laneId = delinearized[vectorPos];
auto scale =
rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos));
indices[indexPos] = affine::makeComposedAffineApply(
rewriter, loc, d0 + scale * d1, {indices[indexPos], laneId});
}
newWriteOp.getIndicesMutable().assign(indices);
return success();
}
/// Tries to move `writeOp` into its own warp op placed after `warpOp`. Fails
/// when the vector has more than `maxNumElementsToExtract` elements or when
/// the warp op contains nothing but transfer_write and yield operations.
LogicalResult tryExtractOp(RewriterBase &rewriter,
vector::TransferWriteOp writeOp,
WarpExecuteOnLane0Op warpOp) const {
Location loc = writeOp.getLoc();
VectorType vecType = writeOp.getVectorType();
if (vecType.getNumElements() > maxNumElementsToExtract) {
return rewriter.notifyMatchFailure(
warpOp,
llvm::formatv(
"writes more elements ({0}) than allowed to extract ({1})",
vecType.getNumElements(), maxNumElementsToExtract));
}
// NOTE(review): presumably guards against re-matching the warp op this
// function itself creates -- confirm.
if (llvm::all_of(warpOp.getOps(),
llvm::IsaPred<vector::TransferWriteOp, vector::YieldOp>))
return failure();
// Yield the written vector, with its original type, from a new warp op.
SmallVector<Value> yieldValues = {writeOp.getVector()};
SmallVector<Type> retTypes = {vecType};
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, yieldValues, retTypes, newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
// Second warp op without results that only hosts the write.
auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
Block &body = secondWarpOp.getBodyRegion().front();
rewriter.setInsertionPointToStart(&body);
auto newWriteOp =
cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
rewriter.eraseOp(writeOp);
rewriter.create<vector::YieldOp>(newWarpOp.getLoc());
return success();
}
/// Matches a transfer_write immediately before the terminator whose operands,
/// other than the written vector and the mask, are all defined outside of the
/// warp region. Distribution is attempted first; extraction is only
/// attempted for unmasked writes.
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
auto yield = cast<vector::YieldOp>(
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
// `lastNode` is null when the terminator is the only operation.
Operation *lastNode = yield->getPrevNode();
auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
if (!writeOp)
return failure();
Value maybeMask = writeOp.getMask();
if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
return writeOp.getVector() == value ||
(maybeMask && maybeMask == value) ||
warpOp.isDefinedOutsideOfRegion(value);
}))
return failure();
if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
return success();
if (writeOp.getMask())
return failure();
if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
return success();
return failure();
}
private:
// Callback returning the distribution map for a written vector value.
DistributionMapFn distributionMapFn;
// Upper bound on the element count of a write handled by `tryExtractOp`.
unsigned maxNumElementsToExtract = 1;
};
/// Sinks an elementwise-mappable operation whose result is yielded by a warp
/// op out of the region: the operands of the operation are yielded instead
/// and the operation is re-created after the warp op on the distributed
/// values.
///
/// When the distributed result is a vector, vector operands are yielded with
/// the distributed shape and their own element type, while scalar operands
/// (e.g. the scalar condition of a select on vectors) are yielded unchanged.
struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
      return OpTrait::hasElementwiseMappableTraits(op);
    });
    if (!yieldOperand)
      return failure();

    Operation *elementWise = yieldOperand->get().getDefiningOp();
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);
    auto distributedVecType = dyn_cast<VectorType>(distributedVal.getType());
    SmallVector<Value> yieldValues;
    SmallVector<Type> retTypes;
    Location loc = warpOp.getLoc();
    for (OpOperand &operand : elementWise->getOpOperands()) {
      Type operandType = operand.get().getType();
      auto operandVecType = dyn_cast<VectorType>(operandType);
      // Scalars are yielded with their own type in every case.
      Type targetType = operandType;
      if (distributedVecType) {
        // Only vector operands take the distributed shape; using a checked
        // cast here would assert on mixed scalar/vector operand lists.
        if (operandVecType)
          targetType = VectorType::get(distributedVecType.getShape(),
                                       operandVecType.getElementType());
      } else {
        assert(!operandVecType &&
               "unexpected yield of vector from op with scalar result type");
      }
      retTypes.push_back(targetType);
      yieldValues.push_back(operand.get());
    }

    // Install the guard before the insertion point is moved so that the
    // caller's insertion point is what gets restored.
    OpBuilder::InsertionGuard g(rewriter);
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);

    SmallVector<Value> newOperands(elementWise->getOperands().begin(),
                                   elementWise->getOperands().end());
    for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands()))
      newOperands[i] = newWarpOp.getResult(newRetIndices[i]);

    Operation *newOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, elementWise, newOperands,
        {newWarpOp.getResult(operandIndex).getType()});
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex),
                                newOp->getResult(0));
    return success();
  }
};
/// Replaces the uses of a warp result that is produced by a splat
/// arith.constant inside the region with a new splat constant of the
/// result's (distributed) type created after the warp op. Non-splat constants
/// are not handled.
struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yielded =
        getWarpResult(warpOp, llvm::IsaPred<arith::ConstantOp>);
    if (!yielded)
      return failure();
    auto innerConstant = yielded->get().getDefiningOp<arith::ConstantOp>();
    auto splat = dyn_cast<SplatElementsAttr>(innerConstant.getValue());
    if (!splat)
      return failure();

    rewriter.startOpModification(warpOp);
    Value warpResult = warpOp.getResult(yielded->getOperandNumber());
    Attribute splatValue = splat.getSplatValue<Attribute>();
    // Same splat value, but shaped like the warp result.
    auto distributedAttr = DenseElementsAttr::get(
        cast<ShapedType>(warpResult.getType()), splatValue);
    rewriter.setInsertionPointAfter(warpOp);
    Value distConstant =
        rewriter.create<arith::ConstantOp>(warpOp.getLoc(), distributedAttr);
    rewriter.replaceAllUsesWith(warpResult, distConstant);
    rewriter.finalizeOpModification(warpOp);
    return success();
  }
};
/// Splits `laneId` into one id per vector dimension, with the last dimension
/// varying fastest.
///
/// The number of lanes along dimension `i` is
/// `originalShape[i] / distributedShape[i]`.
///  - If both shapes are equal, `delinearizedIds` is cleared and true is
///    returned (nothing is distributed).
///  - If a dimension is not evenly divisible, or the product of the per-dim
///    lane counts differs from `warpSize`, false is returned and no IR is
///    created.
///  - Otherwise `delinearizedIds` receives one value per dimension and true is
///    returned. Dimensions left over once all lanes are consumed get the
///    constant 0.
bool delinearizeLaneId(OpBuilder &builder, Location loc,
                       ArrayRef<int64_t> originalShape,
                       ArrayRef<int64_t> distributedShape, int64_t warpSize,
                       Value laneId, SmallVectorImpl<Value> &delinearizedIds) {
  if (originalShape == distributedShape) {
    delinearizedIds.clear();
    return true;
  }

  SmallVector<int64_t> sizes;
  // Accumulate in 64 bits so the comparison against `warpSize` cannot be
  // affected by narrowing.
  int64_t requiredLanes = 1;
  for (auto [large, small] : llvm::zip_equal(originalShape, distributedShape)) {
    if (large % small != 0)
      return false;
    sizes.push_back(large / small);
    requiredLanes *= sizes.back();
  }
  if (requiredLanes != warpSize)
    return false;

  AffineExpr s0 = builder.getAffineSymbolExpr(0);
  int64_t usedThreads = 1;
  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  delinearizedIds.assign(sizes.size(), zero);
  for (int i = sizes.size() - 1; i >= 0; --i) {
    usedThreads *= sizes[i];
    if (usedThreads == warpSize) {
      // All lanes are accounted for: what is left of the lane id is the id of
      // this dimension, and the remaining (outer) dimensions keep 0.
      delinearizedIds[i] = laneId;
      break;
    }
    delinearizedIds[i] =
        affine::makeComposedAffineApply(builder, loc, s0 % sizes[i], {laneId});
    // `laneId` has already been divided by the sizes of the inner dimensions,
    // so only the current dimension's size is stripped here. Dividing by the
    // cumulative `usedThreads` would divide by the inner sizes twice.
    laneId = affine::makeComposedAffineApply(
        builder, loc, s0.floorDiv(sizes[i]), {laneId});
  }
  return true;
}
/// Sinks a single-use vector.transfer_read whose result is yielded by a warp
/// op out of the region. The read indices, the padding value and (if present)
/// the mask are yielded instead, and a read of the distributed type is
/// created after the warp op with indices offset by the per-dimension lane id.
///
/// All preconditions are checked before any IR is created, so a failed match
/// leaves the IR untouched.
struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
      return isa<vector::TransferReadOp>(op) && op->hasOneUse();
    });
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a vector.transfer_read op");
    auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
    if (!warpOp.isDefinedOutsideOfRegion(read.getSource()))
      return rewriter.notifyMatchFailure(
          read, "source must be defined outside of the region");

    unsigned operandIndex = operand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);
    SmallVector<Value, 4> indices(read.getIndices().begin(),
                                  read.getIndices().end());
    auto sequentialType = cast<VectorType>(read.getResult().getType());
    auto distributedType = cast<VectorType>(distributedVal.getType());
    AffineMap map = calculateImplicitMap(sequentialType, distributedType);
    AffineMap indexMap = map.compose(read.getPermutationMap());

    // Validate the mask up front: `delinearizeLaneId` below creates
    // operations, and a pattern must not fail after changing the IR.
    const bool hasMask = static_cast<bool>(read.getMask());
    VectorType maskType;
    if (hasMask) {
      if (!mlir::compressUnusedDims(read.getPermutationMap()).isIdentity())
        return rewriter.notifyMatchFailure(
            read, "non-trivial permutation maps not supported");
      maskType =
          getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());
      // A null type cannot be used as a warp result type.
      if (!maskType)
        return rewriter.notifyMatchFailure(
            read, "cannot compute the distributed mask type");
    }

    SmallVector<Value> delinearizedIds;
    if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
                           distributedType.getShape(), warpOp.getWarpSize(),
                           warpOp.getLaneid(), delinearizedIds)) {
      return rewriter.notifyMatchFailure(
          read, "cannot delinearize lane ID for distribution");
    }
    assert(!delinearizedIds.empty() || map.getNumResults() == 0);

    OpBuilder::InsertionGuard g(rewriter);
    // Additional warp results, in order: all indices, the padding value and,
    // last, the mask when there is one.
    SmallVector<Value> additionalResults(indices.begin(), indices.end());
    SmallVector<Type> additionalResultTypes(indices.size(),
                                            rewriter.getIndexType());
    additionalResults.push_back(read.getPadding());
    additionalResultTypes.push_back(read.getPadding().getType());
    if (hasMask) {
      additionalResults.push_back(read.getMask());
      additionalResultTypes.push_back(maskType);
    }

    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, additionalResults, additionalResultTypes,
        newRetIndices);
    distributedVal = newWarpOp.getResult(operandIndex);

    SmallVector<Value> newIndices;
    for (int64_t i = 0, e = indices.size(); i < e; ++i)
      newIndices.push_back(newWarpOp.getResult(newRetIndices[i]));

    rewriter.setInsertionPointAfter(newWarpOp);
    // Offset every index addressed by a distributed vector dimension by
    // `laneIdForThatDim * distributedDimSize`. Composed results that are not
    // plain dim expressions are left untouched.
    for (auto it : llvm::zip_equal(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(read.getContext(), d0, d1);
      auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
      int64_t scale = distributedType.getDimSize(vectorPos);
      newIndices[indexPos] = affine::makeComposedAffineApply(
          rewriter, read.getLoc(), d0 + scale * d1,
          {newIndices[indexPos], delinearizedIds[vectorPos]});
    }

    Value newPadding = newWarpOp.getResult(newRetIndices[indices.size()]);
    Value newMask =
        hasMask ? newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1])
                : Value();
    auto newRead = rewriter.create<vector::TransferReadOp>(
        read.getLoc(), distributedVal.getType(), read.getSource(), newIndices,
        read.getPermutationMapAttr(), newPadding, newMask,
        read.getInBoundsAttr());
    rewriter.replaceAllUsesWith(distributedVal, newRead);
    return success();
  }
};
/// Removes warp results that have no uses and merges results that yield the
/// same value. Fails when there is nothing to remove. After the warp op has
/// been rebuilt, trivially dead operations inside its body are erased.
struct WarpOpDeadResult : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> newResultTypes;
    newResultTypes.reserve(warpOp->getNumResults());
    SmallVector<Value> newYieldValues;
    newYieldValues.reserve(warpOp->getNumResults());
    // Yielded value -> position of that value in the new yield.
    DenseMap<Value, int64_t> dedupYieldOperandPositionMap;
    // Used old result -> position of its replacement in the new warp op.
    DenseMap<OpResult, int64_t> dedupResultPositionMap;
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());

    for (OpResult result : warpOp.getResults()) {
      // Skip dead results before recording anything. Otherwise an unused
      // result would reserve a position for its yielded value without the
      // value being kept, and a later used result yielding the same value
      // would be mapped to that bogus position.
      if (result.use_empty())
        continue;
      Value yieldOperand = yield.getOperand(result.getResultNumber());
      auto it = dedupYieldOperandPositionMap.insert(
          std::make_pair(yieldOperand, newResultTypes.size()));
      dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
      // Value already kept by an earlier used result: share that result.
      if (!it.second)
        continue;
      newResultTypes.push_back(result.getType());
      newYieldValues.push_back(yieldOperand);
    }
    // Nothing was dropped or merged.
    if (yield.getNumOperands() == newYieldValues.size())
      return failure();

    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
        rewriter, warpOp, newYieldValues, newResultTypes);
    newWarpOp.getBody()->walk([&](Operation *op) {
      if (isOpTriviallyDead(op))
        rewriter.eraseOp(op);
    });

    SmallVector<Value> newValues;
    newValues.reserve(warpOp->getNumResults());
    for (OpResult result : warpOp.getResults()) {
      // Unused results need no replacement; a null value is passed for them.
      if (result.use_empty())
        newValues.push_back(Value());
      else
        newValues.push_back(
            newWarpOp.getResult(dedupResultPositionMap.lookup(result)));
    }
    rewriter.replaceOp(warpOp, newValues);
    return success();
  }
};
/// Bypasses the warp op for a used result that merely forwards a value
/// available outside of the region: either a value defined outside that is
/// yielded directly, or a block argument of the warp body, in which case the
/// matching warp operand is used. Only the first such result whose type
/// equals the type of the forwarded value is rewritten per application.
struct WarpOpForwardOperand : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    Operation *terminator = warpOp.getBodyRegion().front().getTerminator();
    auto yield = cast<vector::YieldOp>(terminator);
    Value valForwarded;
    unsigned resultIndex = 0;
    for (OpOperand &operand : yield->getOpOperands()) {
      const unsigned position = operand.getOperandNumber();
      Value result = warpOp.getResult(position);
      if (result.use_empty())
        continue;

      Value yielded = operand.get();
      // Case 1: the yielded value lives outside of the warp region.
      if (!warpOp.getBodyRegion().isAncestor(yielded.getParentRegion())) {
        if (result.getType() != yielded.getType())
          continue;
        valForwarded = yielded;
        resultIndex = position;
        break;
      }

      // Case 2: the yielded value is a block argument of the warp op itself.
      auto blockArg = dyn_cast<BlockArgument>(yielded);
      if (!blockArg ||
          blockArg.getOwner()->getParentOp() != warpOp.getOperation())
        continue;
      Value warpOperand = warpOp.getArgs()[blockArg.getArgNumber()];
      if (result.getType() != warpOperand.getType())
        continue;
      valForwarded = warpOperand;
      resultIndex = position;
      break;
    }
    if (!valForwarded)
      return failure();

    rewriter.startOpModification(warpOp);
    rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded);
    rewriter.finalizeOpModification(warpOp);
    return success();
  }
};
/// Sinks a vector.broadcast whose result is yielded by a warp op out of the
/// region: the broadcast source is yielded with its own type and a broadcast
/// to the warp result type is created after the warp op. Fails when the
/// source cannot be broadcast to the warp result type.
struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yielded =
        getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
    if (!yielded)
      return failure();

    const unsigned resultIdx = yielded->getOperandNumber();
    auto innerBroadcast = yielded->get().getDefiningOp<vector::BroadcastOp>();
    Location loc = innerBroadcast.getLoc();
    auto destVecType = cast<VectorType>(warpOp->getResultTypes()[resultIdx]);
    Value source = innerBroadcast.getSource();
    Type sourceType = source.getType();
    const bool canBroadcast =
        vector::isBroadcastableTo(sourceType, destVecType) ==
        vector::BroadcastableToResult::Success;
    if (!canBroadcast)
      return failure();

    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {source}, {sourceType}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value outerBroadcast = rewriter.create<vector::BroadcastOp>(
        loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
    rewriter.replaceAllUsesWith(newWarpOp->getResult(resultIdx),
                                outerBroadcast);
    return success();
  }
};
/// Sinks a vector.shape_cast whose result is yielded by a warp op out of the
/// region. The cast source is yielded with the warp result's shape; when the
/// source has a higher rank than the warp result, that shape is padded with
/// leading unit dimensions up to the source rank. A shape_cast back to the
/// warp result type is then created after the warp op.
struct WarpOpShapeCast : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yielded =
        getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
    if (!yielded)
      return failure();

    auto innerCast = yielded->get().getDefiningOp<vector::ShapeCastOp>();
    const unsigned resultIdx = yielded->getOperandNumber();
    auto resultType = cast<VectorType>(warpOp->getResultTypes()[resultIdx]);
    VectorType sourceType = innerCast.getSourceVectorType();

    // Type with which the cast source leaves the region.
    VectorType yieldedSourceType = resultType;
    const unsigned resultRank = resultType.getRank();
    const unsigned sourceRank = sourceType.getRank();
    if (resultRank < sourceRank) {
      SmallVector<int64_t> paddedShape(sourceRank - resultRank, 1);
      llvm::append_range(paddedShape, resultType.getShape());
      yieldedSourceType =
          VectorType::get(paddedShape, resultType.getElementType());
    }

    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {innerCast.getSource()}, {yieldedSourceType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value outerCast = rewriter.create<vector::ShapeCastOp>(
        innerCast.getLoc(), resultType,
        newWarpOp->getResult(newRetIndices[0]));
    rewriter.replaceAllUsesWith(newWarpOp->getResult(resultIdx), outerCast);
    return success();
  }
};
/// Sinks a vector.create_mask whose result is yielded by a warp op out of the
/// region, provided all mask operands are defined outside of the region. A
/// create_mask of the warp result type is built after the warp op; along each
/// dimension its bound is `originalBound - laneIdForThatDim * distributedSize`.
/// When the warp result has the same shape as the original mask (nothing is
/// distributed), the original bounds are reused unchanged.
struct WarpOpCreateMask : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<vector::CreateMaskOp>);
    if (!yieldOperand)
      return failure();
    auto mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>();
    // The bounds are reused outside of the region, so they must be available
    // there.
    if (!llvm::all_of(mask->getOperands(), [&](Value value) {
          return warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    Location loc = mask.getLoc();
    unsigned operandIndex = yieldOperand->getOperandNumber();
    auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
    VectorType seqType = mask.getVectorType();
    ArrayRef<int64_t> seqShape = seqType.getShape();
    ArrayRef<int64_t> distShape = distType.getShape();

    rewriter.setInsertionPointAfter(warpOp);
    SmallVector<Value> delinearizedIds;
    if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,
                           warpOp.getWarpSize(), warpOp.getLaneid(),
                           delinearizedIds))
      return rewriter.notifyMatchFailure(
          mask, "cannot delinearize lane ID for distribution");
    // `delinearizeLaneId` succeeds with an empty list when both shapes are
    // equal; indexing the list in that case would be out of bounds.
    const bool isDistributed = !delinearizedIds.empty();

    rewriter.startOpModification(warpOp);
    AffineExpr s0, s1;
    bindSymbols(rewriter.getContext(), s0, s1);
    SmallVector<Value> newOperands;
    for (int i = 0, e = distShape.size(); i < e; ++i) {
      if (!isDistributed) {
        newOperands.push_back(mask.getOperand(i));
        continue;
      }
      // Shift the bound by the number of elements owned by preceding lanes.
      Value maskDimIdx = affine::makeComposedAffineApply(
          rewriter, loc, s1 - s0 * distShape[i],
          {delinearizedIds[i], mask.getOperand(i)});
      newOperands.push_back(maskDimIdx);
    }
    auto newMask =
        rewriter.create<vector::CreateMaskOp>(loc, distType, newOperands);
    rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask);
    rewriter.finalizeOpModification(warpOp);
    return success();
  }
};
/// Sinks a `vector.extract` that produces a warp op result out of the region:
/// the extract's source vector is yielded from a new warp op and the extract
/// is re-created after it.
///
/// Extracts from 1-D sources are first rewritten to `vector.extractelement`
/// (static positions only). For higher-rank sources the source is yielded
/// either unchanged (result type equals the yielded type) or with its trailing
/// dimensions replaced by the distributed result's dimensions.
struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yielded =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
    if (!yielded)
      return failure();

    auto extractOp = yielded->get().getDefiningOp<vector::ExtractOp>();
    unsigned int resultIdx = yielded->getOperandNumber();
    VectorType srcType = extractOp.getSourceVectorType();
    Location loc = extractOp.getLoc();

    assert(srcType.getRank() > 0 &&
           "vector.extract does not support rank 0 sources");

    // An extract without indices is not handled by this pattern.
    if (extractOp.getNumIndices() == 0)
      return failure();

    // 1-D source: turn the op into vector.extractelement inside the region.
    if (srcType.getRank() == 1) {
      if (extractOp.hasDynamicPosition())
        return failure();
      assert(extractOp.getNumIndices() == 1 && "expected 1 index");
      int64_t staticPos = extractOp.getStaticPosition()[0];
      rewriter.setInsertionPoint(extractOp);
      Value posCst = rewriter.create<arith::ConstantIndexOp>(loc, staticPos);
      rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(
          extractOp, extractOp.getVector(), posCst);
      return success();
    }

    // Rank >= 2 from here on. Decide with which type the source is yielded.
    Type resultType = warpOp.getResult(resultIdx).getType();
    Type yieldedType = yielded->get().getType();
    VectorType yieldedSrcType = srcType;
    if (resultType != yieldedType) {
      auto distResultType = cast<VectorType>(resultType);
      auto seqResultType = cast<VectorType>(yieldedType);

      // Exactly one dimension is expected to differ between the two types.
      int64_t distributedDim = -1;
      for (int64_t dim = 0, rank = seqResultType.getRank(); dim < rank;
           ++dim) {
        if (distResultType.getDimSize(dim) == seqResultType.getDimSize(dim))
          continue;
        assert(distributedDim == -1 && "found multiple distributed dims");
        distributedDim = dim;
      }
      assert(distributedDim != -1 && "could not find distributed dimension");
      (void)distributedDim;

      // Keep the indexed leading dims and overwrite the trailing dims with
      // the sizes of the distributed result.
      ArrayRef<int64_t> srcShape = srcType.getShape();
      SmallVector<int64_t> distSrcShape(srcShape.begin(), srcShape.end());
      for (int dim = 0; dim < distResultType.getRank(); ++dim)
        distSrcShape[dim + extractOp.getNumIndices()] =
            distResultType.getDimSize(dim);
      yieldedSrcType =
          VectorType::get(distSrcShape, distResultType.getElementType());
    }

    // Yield the source and redo the extract on the warp op's new result.
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {extractOp.getVector()}, {yieldedSrcType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
    Value newExtract = rewriter.create<vector::ExtractOp>(
        loc, distributedVec, extractOp.getMixedPosition());
    rewriter.replaceAllUsesWith(newWarpOp->getResult(resultIdx), newExtract);
    return success();
  }
};
/// Sinks a `vector.extractelement` that produces a warp op result out of the
/// region.
///
/// The source vector (and the position operand, when present) is yielded from
/// a new warp op:
///  * Single-element sources are yielded with their original type and the
///    element is extracted directly after the warp op.
///  * Other 1-D sources are yielded with `size / warpSize` elements per lane.
///    An element is extracted at `pos mod elementsPerLane` and handed to
///    `warpShuffleFromIdxFn` together with the owning lane
///    `pos floordiv elementsPerLane`.
///
/// Only f32 and i32 element types are accepted.
struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
  WarpOpExtractElement(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
                       PatternBenefit b = 1)
      : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
        warpShuffleFromIdxFn(std::move(fn)) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractElementOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
    VectorType extractSrcType = extractOp.getSourceVectorType();
    Type elType = extractSrcType.getElementType();
    if (!elType.isF32() && !elType.isInteger(32))
      return rewriter.notifyMatchFailure(
          extractOp, "only f32/i32 element types are supported");

    // Single-element sources (0-D or vector<1x...>) are yielded unchanged.
    bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
    VectorType distributedVecType;
    if (!is0dOrVec1Extract) {
      assert(extractSrcType.getRank() == 1 &&
             "expected that extractelement src rank is 0 or 1");
      // The source must split evenly across the lanes.
      if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
        return failure();
      int64_t elementsPerLane =
          extractSrcType.getShape()[0] / warpOp.getWarpSize();
      distributedVecType = VectorType::get({elementsPerLane}, elType);
    } else {
      distributedVecType = extractSrcType;
    }

    // Yield the source vector and, if present, the position.
    SmallVector<Value> additionalResults{extractOp.getVector()};
    SmallVector<Type> additionalResultTypes{distributedVecType};
    if (static_cast<bool>(extractOp.getPosition())) {
      additionalResults.push_back(extractOp.getPosition());
      additionalResultTypes.push_back(extractOp.getPosition().getType());
    }
    Location loc = extractOp.getLoc();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, additionalResults, additionalResultTypes,
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);

    // Single-element source: extract directly from the yielded vector.
    if (is0dOrVec1Extract) {
      Value newExtract;
      if (extractSrcType.getRank() == 1) {
        newExtract = rewriter.create<vector::ExtractElementOp>(
            loc, distributedVec,
            rewriter.create<arith::ConstantIndexOp>(loc, 0));
      } else {
        newExtract =
            rewriter.create<vector::ExtractElementOp>(loc, distributedVec);
      }
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newExtract);
      return success();
    }

    // Distributed 1-D source: one lane owns the element; its value is passed
    // to the shuffle callback.
    int64_t elementsPerLane = distributedVecType.getShape()[0];
    Value yieldedPos = newWarpOp->getResult(newRetIndices[1]);
    AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
    // Owning lane: pos floordiv elementsPerLane. Lane `l` holds elements
    // [l * elementsPerLane, (l + 1) * elementsPerLane), so rounding up would
    // select the next lane whenever pos is not a multiple of elementsPerLane.
    Value broadcastFromTid = rewriter.create<affine::AffineApplyOp>(
        loc, sym0.floorDiv(elementsPerLane), yieldedPos);
    // Offset within the owning lane: pos mod elementsPerLane.
    Value pos =
        elementsPerLane == 1
            ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
            : rewriter
                  .create<affine::AffineApplyOp>(loc, sym0 % elementsPerLane,
                                                 yieldedPos)
                  .getResult();
    Value extracted =
        rewriter.create<vector::ExtractElementOp>(loc, distributedVec, pos);

    Value shuffled = warpShuffleFromIdxFn(
        loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize());
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), shuffled);
    return success();
  }

private:
  /// Callback that builds the shuffle of a value from a given source lane.
  WarpShuffleFromIdxFn warpShuffleFromIdxFn;
};
/// Sinks a `vector.insertelement` that produces a warp op result out of the
/// region.
///
/// The destination vector (with the warp result's type), the source scalar
/// and the position (when present) are yielded from a new warp op. If the
/// destination type equals the result type, the insert is simply re-created
/// after the warp op. Otherwise only lane `pos floordiv elementsPerLane`
/// inserts, at offset `pos mod elementsPerLane`, inside an `scf.if`; all other
/// lanes forward their destination slice unchanged.
struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::InsertElementOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>();
    VectorType vecType = insertOp.getDestVectorType();
    VectorType distrType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    bool hasPos = static_cast<bool>(insertOp.getPosition());

    // Yield destination vector, source scalar and (if present) the position.
    SmallVector<Value> additionalResults{insertOp.getDest(),
                                         insertOp.getSource()};
    SmallVector<Type> additionalResultTypes{distrType,
                                            insertOp.getSource().getType()};
    if (hasPos) {
      additionalResults.push_back(insertOp.getPosition());
      additionalResultTypes.push_back(insertOp.getPosition().getType());
    }
    Location loc = insertOp.getLoc();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, additionalResults, additionalResultTypes,
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
    Value newSource = newWarpOp->getResult(newRetIndices[1]);
    Value newPos = hasPos ? newWarpOp->getResult(newRetIndices[2]) : Value();

    // Same type on both sides: re-create the insert after the warp op.
    if (vecType == distrType) {
      Value newInsert = rewriter.create<vector::InsertElementOp>(
          loc, newSource, distributedVec, newPos);
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newInsert);
      return success();
    }

    // Distributed destination: exactly one lane performs the insert.
    int64_t elementsPerLane = distrType.getShape()[0];
    AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
    // Inserting lane: pos floordiv elementsPerLane. Lane `l` holds elements
    // [l * elementsPerLane, (l + 1) * elementsPerLane), so rounding up would
    // select the next lane whenever pos is not a multiple of elementsPerLane.
    Value insertingLane = rewriter.create<affine::AffineApplyOp>(
        loc, sym0.floorDiv(elementsPerLane), newPos);
    // Offset within the inserting lane: pos mod elementsPerLane.
    Value pos =
        elementsPerLane == 1
            ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
            : rewriter
                  .create<affine::AffineApplyOp>(loc, sym0 % elementsPerLane,
                                                 newPos)
                  .getResult();
    Value isInsertingLane = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
    Value newResult =
        rewriter
            .create<scf::IfOp>(
                loc, isInsertingLane,
                /*thenBuilder=*/
                [&](OpBuilder &builder, Location loc) {
                  Value newInsert = builder.create<vector::InsertElementOp>(
                      loc, newSource, distributedVec, pos);
                  builder.create<scf::YieldOp>(loc, newInsert);
                },
                /*elseBuilder=*/
                [&](OpBuilder &builder, Location loc) {
                  builder.create<scf::YieldOp>(loc, distributedVec);
                })
            .getResult(0);
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
    return success();
  }
};
/// Sinks a `vector.insert` that produces a warp op result out of the region.
///
/// Inserts into 1-D destinations are first rewritten to
/// `vector.insertelement` (static positions only). For higher-rank
/// destinations the source and destination are yielded from a new warp op and
/// the insert is re-created after it:
///  * result type equals yielded type: the insert is moved out as is;
///  * the distributed dimension also exists in the source: every lane inserts
///    its slice of the source;
///  * the distributed dimension is one of the indexed dimensions: a single
///    lane inserts the whole source under an `scf.if`. This case requires a
///    fully static position.
struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
    Location loc = insertOp.getLoc();

    // An insert without indices is not handled by this pattern.
    if (insertOp.getNumIndices() == 0)
      return failure();

    // 1-D destination: turn the op into vector.insertelement in the region.
    if (insertOp.getDestVectorType().getRank() == 1) {
      if (insertOp.hasDynamicPosition())
        return failure();
      assert(insertOp.getNumIndices() == 1 && "expected 1 index");
      int64_t pos = insertOp.getStaticPosition()[0];
      rewriter.setInsertionPoint(insertOp);
      rewriter.replaceOpWithNewOp<vector::InsertElementOp>(
          insertOp, insertOp.getSource(), insertOp.getDest(),
          rewriter.create<arith::ConstantIndexOp>(loc, pos));
      return success();
    }

    // Result type equals the yielded type: move the insert out unchanged.
    if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
      SmallVector<size_t> newRetIndices;
      WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()},
          {insertOp.getSourceType(), insertOp.getDestVectorType()},
          newRetIndices);
      rewriter.setInsertionPointAfter(newWarpOp);
      Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
      Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
      Value newResult = rewriter.create<vector::InsertOp>(
          loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newResult);
      return success();
    }

    // Find the single dimension in which result and yielded type differ.
    auto distrDestType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    auto yieldedType = cast<VectorType>(operand->get().getType());
    int64_t distrDestDim = -1;
    for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
      if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {
        assert(distrDestDim == -1 && "found multiple distributed dims");
        distrDestDim = i;
      }
    }
    assert(distrDestDim != -1 && "could not find distributed dimension");

    // Derive the type with which the source is yielded.
    // NOTE(review): the cast below asserts for scalar sources; confirm that
    // such inserts cannot reach this point.
    VectorType srcVecType = cast<VectorType>(insertOp.getSourceType());
    SmallVector<int64_t> distrSrcShape(srcVecType.getShape().begin(),
                                       srcVecType.getShape().end());
    // The destination dim maps onto source dim `destDim - numIndices`. A
    // negative value means the distributed dim is one of the indexed dims, so
    // the source itself is not distributed.
    int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices();

    // The single-lane case below converts the position to integers, which
    // requires it to be static. Bail out before changing any IR, mirroring
    // the 1-D path above.
    if (distrSrcDim < 0 && insertOp.hasDynamicPosition())
      return rewriter.notifyMatchFailure(
          insertOp,
          "dynamic position is not supported when a single lane inserts");

    if (distrSrcDim >= 0)
      distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);
    auto distrSrcType =
        VectorType::get(distrSrcShape, distrDestType.getElementType());

    // Yield source and destination from the new warp op.
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()},
        {distrSrcType, distrDestType}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
    Value distributedDest = newWarpOp->getResult(newRetIndices[1]);

    Value newResult;
    if (distrSrcDim >= 0) {
      // Every lane inserts its own slice of the source.
      newResult = rewriter.create<vector::InsertOp>(
          loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
    } else {
      // Only one lane inserts the whole source vector.
      int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
      SmallVector<OpFoldResult> pos = insertOp.getMixedPosition();
      SmallVector<int64_t> newPos = getAsIntegers(pos);
      // Inserting lane: pos / elementsPerLane.
      Value insertingLane = rewriter.create<arith::ConstantIndexOp>(
          loc, newPos[distrDestDim] / elementsPerLane);
      Value isInsertingLane = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
      // Position within the inserting lane: pos % elementsPerLane.
      newPos[distrDestDim] %= elementsPerLane;
      auto insertingBuilder = [&](OpBuilder &builder, Location loc) {
        Value newInsert = builder.create<vector::InsertOp>(
            loc, distributedSrc, distributedDest, newPos);
        builder.create<scf::YieldOp>(loc, newInsert);
      };
      auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) {
        builder.create<scf::YieldOp>(loc, distributedDest);
      };
      newResult = rewriter
                      .create<scf::IfOp>(loc, isInsertingLane,
                                         /*thenBuilder=*/insertingBuilder,
                                         /*elseBuilder=*/nonInsertingBuilder)
                      .getResult(0);
    }

    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
    return success();
  }
};
/// Moves an `scf.for` that is the last op before the warp op's terminator out
/// of the warp region. A new `scf.for` is created after the warp op and its
/// body holds a new inner warp op into which the original loop body is merged.
/// Values that the loop body uses from the enclosing warp region are yielded
/// from the outer warp op and passed to the inner warp op as extra arguments.
struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
/// `fn` supplies the distribution map used for vector values that escape from
/// the warp region into the loop body.
WarpOpScfForOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
: OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
distributionMapFn(std::move(fn)) {}
// NOTE(review): this also inherits the base constructors, which presumably
// leave `distributionMapFn` unset; confirm no caller relies on them.
using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
auto yield = cast<vector::YieldOp>(
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
// Only match when the op immediately before the terminator is an scf.for.
Operation *lastNode = yield->getPrevNode();
auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
if (!forOp)
return failure();
// Collect the values used inside the loop body but defined above it within
// the warp op. `inputTypes` records their original types, `distTypes` the
// types with which they are yielded from the warp op.
llvm::SmallSetVector<Value, 32> escapingValues;
SmallVector<Type> inputTypes;
SmallVector<Type> distTypes;
mlir::visitUsedValuesDefinedAbove(
forOp.getBodyRegion(), [&](OpOperand *operand) {
Operation *parent = operand->get().getParentRegion()->getParentOp();
if (warpOp->isAncestor(parent)) {
// Skip values that were already recorded.
if (!escapingValues.insert(operand->get()))
return;
Type distType = operand->get().getType();
// Vector values get a distributed type derived from the user-provided map;
// all other values keep their type.
if (auto vecType = dyn_cast<VectorType>(distType)) {
AffineMap map = distributionMapFn(operand->get());
distType = getDistributedType(vecType, map, warpOp.getWarpSize());
}
inputTypes.push_back(operand->get().getType());
distTypes.push_back(distType);
}
});
// Yield the escaping values from a new warp op.
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
newRetIndices);
yield = cast<vector::YieldOp>(
newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
// For each yielded loop result: use the matching warp result as an init
// operand of the new loop and make the warp op yield the loop's init value
// instead. `resultIdx` remembers which warp results are affected.
SmallVector<Value> newOperands;
SmallVector<unsigned> resultIdx;
for (OpOperand &yieldOperand : yield->getOpOperands()) {
if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
continue;
auto forResult = cast<OpResult>(yieldOperand.get());
newOperands.push_back(
newWarpOp.getResult(yieldOperand.getOperandNumber()));
yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
resultIdx.push_back(yieldOperand.getOperandNumber());
}
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointAfter(newWarpOp);
// Create the replacement loop after the warp op with the same bounds and
// step as the original loop.
auto newForOp = rewriter.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
forOp.getStep(), newOperands);
rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
// Inputs of the inner warp op: the loop's iter args first, followed by the
// escaping values now available as results of the outer warp op.
SmallVector<Value> warpInput(newForOp.getRegionIterArgs().begin(),
newForOp.getRegionIterArgs().end());
SmallVector<Type> warpInputType(forOp.getResultTypes().begin(),
forOp.getResultTypes().end());
// Maps each escaping value to the index of its inner warp op block argument.
llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
warpInput.push_back(newWarpOp.getResult(retIdx));
argIndexMapping[escapingValues[i]] = warpInputType.size();
warpInputType.push_back(inputTypes[i]);
}
auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
newWarpOp.getWarpSize(), warpInput, warpInputType);
// Replacement values for the original loop body's block arguments: the new
// induction variable followed by the inner warp op's block arguments,
// truncated to the number of arguments of the original body.
SmallVector<Value> argMapping;
argMapping.push_back(newForOp.getInductionVar());
for (Value args : innerWarp.getBody()->getArguments()) {
argMapping.push_back(args);
}
argMapping.resize(forOp.getBody()->getNumArguments());
// Save the original loop's yielded values, drop its terminator, and merge the
// loop body into the inner warp op, which then yields the saved values.
SmallVector<Value> yieldOperands;
for (Value operand : forOp.getBody()->getTerminator()->getOperands())
yieldOperands.push_back(operand);
rewriter.eraseOp(forOp.getBody()->getTerminator());
rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
rewriter.setInsertionPoint(innerWarp.getBody(), innerWarp.getBody()->end());
rewriter.create<vector::YieldOp>(innerWarp.getLoc(), yieldOperands);
rewriter.setInsertionPointAfter(innerWarp);
// The new loop yields the inner warp op's results, if there are any.
if (!innerWarp.getResults().empty())
rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
rewriter.eraseOp(forOp);
// Redirect users of the affected warp results to the new loop's results. The
// replacement also rewrites the new loop's own init operand, so that operand
// is set back to the warp result afterwards; the offset of 3 skips the lower
// bound, upper bound and step operands.
for (const auto &res : llvm::enumerate(resultIdx)) {
rewriter.replaceAllUsesWith(newWarpOp.getResult(res.value()),
newForOp.getResult(res.index()));
newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value()));
}
// Inside the new loop, replace remaining uses of the escaping values with the
// corresponding inner warp op block arguments.
newForOp.walk([&](Operation *op) {
for (OpOperand &operand : op->getOpOperands()) {
auto it = argIndexMapping.find(operand.get());
if (it == argIndexMapping.end())
continue;
operand.set(innerWarp.getBodyRegion().getArgument(it->second));
}
});
// Hoist hoistable ops without vector results out of the inner warp op.
mlir::vector::moveScalarUniformCode(innerWarp);
return success();
}
private:
/// Returns the distribution map for a vector value escaping into the loop.
DistributionMapFn distributionMapFn;
};
/// Sinks a `vector.reduction` that produces a warp op result out of the
/// region.
///
/// The reduced vector is yielded from a new warp op with
/// `size / warpSize` elements per lane, and `distributedReductionFn` is
/// invoked on that per-lane vector after the warp op. An accumulator, when
/// present, is yielded as well and combined with the callback's result.
///
/// Only rank-1 vectors whose size is a multiple of the warp size and whose
/// reduction result is an integer or float are supported.
struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> {
  WarpOpReduction(MLIRContext *context,
                  DistributedReductionFn distributedReductionFn,
                  PatternBenefit benefit = 1)
      : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
        distributedReductionFn(std::move(distributedReductionFn)) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ReductionOp>);
    if (!yieldOperand)
      return failure();

    auto reductionOp =
        cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
    auto vectorType = cast<VectorType>(reductionOp.getVector().getType());
    if (vectorType.getRank() != 1)
      return rewriter.notifyMatchFailure(
          warpOp, "Only rank 1 reductions can be distributed.");
    // The vector must split evenly across the lanes.
    if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
      return rewriter.notifyMatchFailure(
          warpOp,
          "Reduction vector dimension must be a multiple of the warp size.");
    if (!reductionOp.getType().isIntOrFloat())
      return rewriter.notifyMatchFailure(
          warpOp, "Reduction distribution currently only supports floats and "
                  "integer types.");

    int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
    unsigned operandIndex = yieldOperand->getOperandNumber();

    // Yield the vector to reduce (per-lane slice) and the optional accumulator.
    SmallVector<Value> yieldValues = {reductionOp.getVector()};
    SmallVector<Type> retTypes = {
        VectorType::get({numElements}, reductionOp.getType())};
    if (reductionOp.getAcc()) {
      yieldValues.push_back(reductionOp.getAcc());
      retTypes.push_back(reductionOp.getAcc().getType());
    }
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Per-lane data handed to the user-provided reduction callback.
    Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
    Value fullReduce =
        distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec,
                               reductionOp.getKind(), newWarpOp.getWarpSize());
    // Fold the accumulator into the reduced value.
    if (reductionOp.getAcc()) {
      fullReduce = vector::makeArithReduction(
          rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
          newWarpOp.getResult(newRetIndices[1]));
    }
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex), fullReduce);
    return success();
  }

private:
  /// Callback that builds the reduction of a per-lane vector across the warp.
  DistributedReductionFn distributedReductionFn;
};
}
/// Registers `WarpOpToScfIfPattern`, configured with `options`, into
/// `patterns` at the given `benefit`.
void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
    RewritePatternSet &patterns,
    const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit) {
  MLIRContext *context = patterns.getContext();
  patterns.add<WarpOpToScfIfPattern>(context, options, benefit);
}
/// Registers `WarpOpTransferWrite` into `patterns`, forwarding the
/// distribution map callback, `maxNumElementsToExtract` and `benefit` to the
/// pattern's constructor.
void mlir::vector::populateDistributeTransferWriteOpPatterns(
    RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
    unsigned maxNumElementsToExtract, PatternBenefit benefit) {
  MLIRContext *context = patterns.getContext();
  patterns.add<WarpOpTransferWrite>(context, distributionMapFn,
                                    maxNumElementsToExtract, benefit);
}
/// Registers the patterns that propagate distribution through a warp op.
///
/// `WarpOpTransferRead` is added with `readBenefit`; every other pattern is
/// added with `benefit`. `warpShuffleFromIdxFn` is handed to
/// `WarpOpExtractElement` and `distributionMapFn` to `WarpOpScfForOp`.
void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
    RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
    const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
    PatternBenefit readBenefit) {
  MLIRContext *context = patterns.getContext();

  // Transfer reads carry their own benefit.
  patterns.add<WarpOpTransferRead>(context, readBenefit);

  // Patterns that need nothing beyond the context and the benefit.
  patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
               WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
               WarpOpConstant, WarpOpInsertElement, WarpOpInsert,
               WarpOpCreateMask>(context, benefit);

  // Patterns parameterized by a user callback.
  patterns.add<WarpOpExtractElement>(context, warpShuffleFromIdxFn, benefit);
  patterns.add<WarpOpScfForOp>(context, distributionMapFn, benefit);
}
/// Registers `WarpOpReduction`, parameterized by `distributedReductionFn`,
/// into `patterns` at the given `benefit`.
void mlir::vector::populateDistributeReduction(
    RewritePatternSet &patterns,
    const DistributedReductionFn &distributedReductionFn,
    PatternBenefit benefit) {
  MLIRContext *context = patterns.getContext();
  patterns.add<WarpOpReduction>(context, distributedReductionFn, benefit);
}
/// Moves ops out of the body of `warpOp`, placing them right before it.
///
/// Only ops directly in the body block (terminator excluded) are considered.
/// An op is moved when none of its results has a vector type and
/// `canBeHoisted` accepts it, treating values as available when they are
/// defined outside the warp region or produced by an op already selected.
void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
  // Selected ops, kept in body order so they are moved in that order.
  llvm::SmallSetVector<Operation *, 8> hoistable;

  // True if `v` is, or after hoisting will be, defined outside the body.
  auto availableOutside = [&](Value v) {
    if (Operation *producer = v.getDefiningOp())
      if (hoistable.count(producer))
        return true;
    return warpOp.isDefinedOutsideOfRegion(v);
  };

  // Iterate over the top-level ops only; nested regions are not entered.
  Block *warpBody = warpOp.getBody();
  for (Operation &candidate : warpBody->without_terminator()) {
    bool producesVector =
        llvm::any_of(candidate.getResults(), [](Value result) {
          return isa<VectorType>(result.getType());
        });
    if (producesVector)
      continue;
    if (canBeHoisted(&candidate, availableOutside))
      hoistable.insert(&candidate);
  }

  for (Operation *selected : hoistable)
    selected->moveBefore(warpOp);
}