#include "mlir/Analysis/Liveness.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/IntervalMap.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/ErrorHandling.h"
#include <algorithm>
namespace mlir::arm_sme {
#define GEN_PASS_DEF_TESTTILEALLOCATION
#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
}
using namespace mlir;
using namespace mlir::arm_sme;
namespace {
/// Bitmask describing a set of SME tiles.
///
/// The mask is 16 bits wide. Each kZA<n>Q enumerator owns exactly one bit and
/// every other enumerator is a union of those bits, so two tiles share storage
/// exactly when the bitwise AND of their masks is not kNone (this is the test
/// TileAllocator uses).
enum class TileMask : unsigned {
  // clang-format off
  // Single byte-element tile: covers every bit.
  kZA0B  = 0b1111'1111'1111'1111,

  // Half-word tiles.
  kZA0H  = 0b1010'1010'1010'1010,
  kZA1H  = 0b0101'0101'0101'0101,

  // Word tiles.
  kZA0S  = 0b1000'1000'1000'1000,
  kZA1S  = 0b0100'0100'0100'0100,
  kZA2S  = 0b0010'0010'0010'0010,
  kZA3S  = 0b0001'0001'0001'0001,

  // Double-word tiles.
  kZA0D  = 0b1000'0000'1000'0000,
  kZA1D  = 0b0100'0000'0100'0000,
  kZA2D  = 0b0010'0000'0010'0000,
  kZA3D  = 0b0001'0000'0001'0000,
  kZA4D  = 0b0000'1000'0000'1000,
  kZA5D  = 0b0000'0100'0000'0100,
  kZA6D  = 0b0000'0010'0000'0010,
  kZA7D  = 0b0000'0001'0000'0001,

  // Quad-word tiles: one bit each.
  kZA0Q  = 0b1000'0000'0000'0000,
  kZA1Q  = 0b0100'0000'0000'0000,
  kZA2Q  = 0b0010'0000'0000'0000,
  kZA3Q  = 0b0001'0000'0000'0000,
  kZA4Q  = 0b0000'1000'0000'0000,
  kZA5Q  = 0b0000'0100'0000'0000,
  kZA6Q  = 0b0000'0010'0000'0000,
  kZA7Q  = 0b0000'0001'0000'0000,
  kZA8Q  = 0b0000'0000'1000'0000,
  kZA9Q  = 0b0000'0000'0100'0000,
  kZA10Q = 0b0000'0000'0010'0000,
  kZA11Q = 0b0000'0000'0001'0000,
  kZA12Q = 0b0000'0000'0000'1000,
  kZA13Q = 0b0000'0000'0000'0100,
  kZA14Q = 0b0000'0000'0000'0010,
  kZA15Q = 0b0000'0000'0000'0001,

  // No tile.
  kNone  = 0b0000'0000'0000'0000,
  // clang-format on

  // Marks this enum as a bitmask enum; kZA0B is passed as the largest value.
  LLVM_MARK_AS_BITMASK_ENUM(kZA0B)
};
/// Returns the masks of all tiles of the given `type`.
///
/// The returned array is indexed by tile ID: element `i` is the mask of tile
/// number `i` of that type (e.g. ZAS yields {kZA0S, kZA1S, kZA2S, kZA3S}).
/// The arrays have static storage duration, so the returned ArrayRef stays
/// valid for the lifetime of the program.
static ArrayRef<TileMask> getMasks(ArmSMETileType type) {
  static constexpr std::array ZA_B_MASKS = {TileMask::kZA0B};
  static constexpr std::array ZA_H_MASKS = {TileMask::kZA0H, TileMask::kZA1H};
  static constexpr std::array ZA_S_MASKS = {TileMask::kZA0S, TileMask::kZA1S,
                                            TileMask::kZA2S, TileMask::kZA3S};
  static constexpr std::array ZA_D_MASKS = {
      TileMask::kZA0D, TileMask::kZA1D, TileMask::kZA2D, TileMask::kZA3D,
      TileMask::kZA4D, TileMask::kZA5D, TileMask::kZA6D, TileMask::kZA7D};
  static constexpr std::array ZA_Q_MASKS = {
      TileMask::kZA0Q,  TileMask::kZA1Q,  TileMask::kZA2Q,  TileMask::kZA3Q,
      TileMask::kZA4Q,  TileMask::kZA5Q,  TileMask::kZA6Q,  TileMask::kZA7Q,
      TileMask::kZA8Q,  TileMask::kZA9Q,  TileMask::kZA10Q, TileMask::kZA11Q,
      TileMask::kZA12Q, TileMask::kZA13Q, TileMask::kZA14Q, TileMask::kZA15Q};
  switch (type) {
  case ArmSMETileType::ZAB:
    return ZA_B_MASKS;
  case ArmSMETileType::ZAH:
    return ZA_H_MASKS;
  case ArmSMETileType::ZAS:
    return ZA_S_MASKS;
  case ArmSMETileType::ZAD:
    return ZA_D_MASKS;
  case ArmSMETileType::ZAQ:
    return ZA_Q_MASKS;
  }
  // Every enumerator returns above. Without this, control would fall off the
  // end of a non-void function (undefined behaviour) if `type` ever held an
  // out-of-range value, and GCC reports -Wreturn-type for the function.
  llvm_unreachable("unknown type in getMasks");
}
class TileAllocator {
public:
FailureOr<unsigned> allocateTileId(ArmSMETileType tileType) {
auto masks = getMasks(tileType);
for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
if ((tilesInUse & tileMask) == TileMask::kNone) {
tilesInUse |= tileMask;
return tileId;
}
}
return failure();
}
void acquireTileId(ArmSMETileType tileType, unsigned tileId) {
TileMask tileMask = getMasks(tileType)[tileId];
assert((tilesInUse & tileMask) == TileMask::kNone &&
"cannot acquire allocated tile!");
tilesInUse |= tileMask;
}
void releaseTileId(ArmSMETileType tileType, unsigned tileId) {
TileMask tileMask = getMasks(tileType)[tileId];
assert((tilesInUse & tileMask) == tileMask &&
"cannot release unallocated tile!");
tilesInUse ^= tileMask;
}
unsigned allocateInMemoryTileId() {
return nextInMemoryTileId++;
}
private:
TileMask tilesInUse = TileMask::kNone;
unsigned nextInMemoryTileId = kInMemoryTileIdBase;
};
/// Rewrites every `cf.cond_br` in `function` that has at least one SME tile
/// vector operand so that it no longer forwards any operands itself.
///
/// For each such branch two new, initially empty blocks are split off the end
/// of its parent block. Each new block receives an unconditional `cf.br` to
/// the original true/false destination carrying the original operands, and the
/// conditional branch is retargeted to the two new blocks without operands.
void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
  // Gather the branches up front; the rewrite below creates new blocks.
  SmallVector<cf::CondBranchOp> pending;
  function.walk([&](cf::CondBranchOp op) {
    for (Value operand : op->getOperands()) {
      if (isValidSMETileVectorType(operand.getType())) {
        pending.push_back(op);
        return;
      }
    }
  });

  // Appends `cf.br to(forwarded...)` to the end of `from`.
  auto emitBranch = [&rewriter](Location loc, Block *from, Block *to,
                                ValueRange forwarded) {
    rewriter.setInsertionPointToEnd(from);
    rewriter.create<cf::BranchOp>(loc, to, forwarded);
  };

  for (cf::CondBranchOp op : pending) {
    Location loc = op.getLoc();
    Block *parent = op->getBlock();
    // Splitting at end() moves nothing, so both new blocks start out empty.
    Block *trueEdge = rewriter.splitBlock(parent, parent->end());
    Block *falseEdge = rewriter.splitBlock(parent, parent->end());
    // Create the forwarding branches before the operands are cleared below.
    emitBranch(loc, trueEdge, op.getTrueDest(), op.getTrueDestOperands());
    emitBranch(loc, falseEdge, op.getFalseDest(), op.getFalseDestOperands());
    rewriter.modifyOpInPlace(op, [&] {
      op.getFalseDestOperandsMutable().clear();
      op.getTrueDestOperandsMutable().clear();
      op.setSuccessor(trueEdge, 0);
      op.setSuccessor(falseEdge, 1);
    });
  }
}
/// For every block of `function` that ends in an unconditional `cf.br`,
/// inserts an `arm_sme` CopyTileOp immediately before the branch for each SME
/// tile vector operand and makes the branch forward the copy instead.
/// Blocks with any other terminator are left unchanged.
void insertCopiesAtBranches(IRRewriter &rewriter,
                            FunctionOpInterface function) {
  for (Block &block : function.getBlocks()) {
    auto branch = dyn_cast<cf::BranchOp>(block.getTerminator());
    if (!branch)
      continue;
    rewriter.setInsertionPoint(branch);
    for (OpOperand &use : branch->getOpOperands()) {
      Value forwarded = use.get();
      if (!isValidSMETileVectorType(forwarded.getType()))
        continue;
      auto copy = rewriter.create<CopyTileOp>(branch->getLoc(), forwarded);
      rewriter.modifyOpInPlace(branch, [&] { use.assign(copy); });
    }
  }
}
/// Prepares `function` for tile allocation by running, in this order,
/// splitCondBranches and then insertCopiesAtBranches on it. Both steps mutate
/// the IR through `rewriter`.
void preprocessForTileAllocation(IRRewriter &rewriter,
FunctionOpInterface function) {
// Runs first: it creates the `cf.br` terminators that the copy insertion
// below then processes.
splitCondBranches(rewriter, function);
insertCopiesAtBranches(rewriter, function);
}
struct LiveRange {
using RangeSet = llvm::IntervalMap<uint64_t, uint8_t, 16,
llvm::IntervalMapHalfOpenInfo<unsigned>>;
using Allocator = RangeSet::Allocator;
static constexpr uint8_t kValidLiveRange = 0xff;
LiveRange(Allocator &allocator)
: ranges(std::make_unique<RangeSet>(allocator)) {}
bool overlaps(LiveRange const &otherRange) const {
return llvm::IntervalMapOverlaps<RangeSet, RangeSet>(*ranges,
*otherRange.ranges)
.valid();
}
bool overlaps(uint64_t point) const {
return ranges->lookup(point) == kValidLiveRange;
}
void unionWith(LiveRange const &otherRange) {
for (auto it = otherRange.ranges->begin(); it != otherRange.ranges->end();
++it)
ranges->insert(it.start(), it.stop(), kValidLiveRange);
values.set_union(otherRange.values);
}
void insert(Value value, unsigned start, unsigned end) {
values.insert(value);
if (start != end)
ranges->insert(start, end, kValidLiveRange);
}
bool empty() const { return ranges->empty(); }
unsigned start() const { return ranges->start(); }
unsigned end() const { return ranges->stop(); }
bool operator<(LiveRange const &other) const {
return start() < other.start();
}
ArmSMETileType getTileType() const {
return *getSMETileType(cast<VectorType>(values[0].getType()));
}
SetVector<Value> values;
std::unique_ptr<RangeSet> ranges;
std::optional<unsigned> tileId;
};
/// Assigns an increasing index to every top-level operation of `function`,
/// visiting blocks in the order returned by getBlocksSortedByDominance.
///
/// One index is skipped ahead of each block's first operation, so the first
/// operation of a block never has index 0 and never directly follows the last
/// index of the previous block. In assert-enabled builds this also checks that
/// no ArmSME tile operation is nested inside another operation's region.
DenseMap<Operation *, unsigned>
generateOperationNumbering(FunctionOpInterface function) {
  DenseMap<Operation *, unsigned> numbering;
  unsigned nextIndex = 0;
  for (Block *block : getBlocksSortedByDominance(function.getFunctionBody())) {
    // Leave one unused index in front of the block's operations.
    ++nextIndex;
    for (Operation &op : *block) {
#ifndef NDEBUG
      op.walk([&](ArmSMETileOpInterface nestedOp) {
        assert(&op == nestedOp.getOperation() &&
               "ArmSME tile allocation does not support nested regions");
      });
#endif
      numbering.try_emplace(&op, nextIndex);
      ++nextIndex;
    }
  }
  return numbering;
}
/// Builds one LiveRange per SME tile value in `function`.
///
/// For every block, an interval is recorded for each tile-typed block
/// argument, live-in value and operation result. The interval ends at the
/// index of the operation reported by LivenessBlockInfo::getEndOperation.
/// It starts at the defining operation's index for results; for block
/// arguments and live-ins it starts one index before the block's first
/// operation. Values that are not SME tile vectors are ignored.
DenseMap<Value, LiveRange>
gatherTileLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
                     LiveRange::Allocator &liveRangeAllocator,
                     Liveness &liveness, FunctionOpInterface function) {
  assert(!operationToIndexMap.empty() && "expected operation numbering");
  DenseMap<Value, LiveRange> liveRanges;

  // Adds the part of `value`'s live range that lies in one block.
  auto recordInterval = [&](Value value, Operation *firstUseOrDef,
                            LivenessBlockInfo const &blockLiveness,
                            bool liveAtBlockEntry = false) {
    if (!isValidSMETileVectorType(value.getType()))
      return;
    // Look up the value's live range, creating it on first sight.
    LiveRange &range =
        liveRanges.try_emplace(value, liveRangeAllocator).first->second;
    Operation *lastUseInBlock =
        blockLiveness.getEndOperation(value, firstUseOrDef);
    unsigned startOpIdx = operationToIndexMap.at(firstUseOrDef);
    // Values live on entry start at the index just before the first op.
    if (liveAtBlockEntry)
      --startOpIdx;
    unsigned endOpIdx = operationToIndexMap.at(lastUseInBlock);
    range.insert(value, startOpIdx, endOpIdx);
  };

  for (Block &block : function.getBlocks()) {
    LivenessBlockInfo const *blockLiveness = liveness.getLiveness(&block);
    // Block arguments.
    for (Value argument : block.getArguments())
      recordInterval(argument, &block.front(), *blockLiveness,
                     /*liveAtBlockEntry=*/true);
    // Values live into the block.
    for (Value liveIn : blockLiveness->in())
      recordInterval(liveIn, &block.front(), *blockLiveness,
                     /*liveAtBlockEntry=*/true);
    // Values defined inside the block.
    for (Operation &op : block)
      for (Value result : op.getResults())
        recordInterval(result, &op, *blockLiveness);
  }
  return liveRanges;
}
/// Invokes `callback` with each value that predecessors of `blockArg`'s block
/// pass for `blockArg`.
///
/// Only `cf.br` and `cf.cond_br` terminators are inspected; predecessors ending
/// in any other terminator are skipped. A `cf.cond_br` contributes one value
/// per successor edge (false edge first) that targets the block.
static void forEachPredecessorTileValue(BlockArgument blockArg,
                                        function_ref<void(Value)> callback) {
  Block *owner = blockArg.getOwner();
  unsigned argIndex = blockArg.getArgNumber();
  for (Block *pred : owner->getPredecessors()) {
    Operation *terminator = pred->getTerminator();
    if (auto branch = dyn_cast<cf::BranchOp>(terminator)) {
      callback(branch.getDestOperands()[argIndex]);
      continue;
    }
    auto condBranch = dyn_cast<cf::CondBranchOp>(terminator);
    if (!condBranch)
      continue;
    if (condBranch.getFalseDest() == owner)
      callback(condBranch.getFalseDestOperands()[argIndex]);
    if (condBranch.getTrueDest() == owner)
      callback(condBranch.getTrueDestOperands()[argIndex]);
  }
}
/// Merges non-overlapping live ranges of related values and returns the
/// resulting distinct, non-empty live ranges sorted by start point.
///
/// Two merge rules are applied, each over all values of `initialLiveRanges`:
///  1. a block argument is merged with the values its predecessors pass;
///  2. the result of an ArmSME tile op is merged with its tile operands.
/// A merge only happens if the two live ranges do not overlap. The returned
/// pointers refer to entries of `initialLiveRanges`, which is modified in
/// place (merged ranges absorb the intervals and values of their partner).
SmallVector<LiveRange *>
coalesceTileLiveRanges(DenseMap<Value, LiveRange> &initialLiveRanges) {
  // Maps each value to the live range it currently belongs to.
  DenseMap<Value, LiveRange *> rangeOf;
  for (auto &entry : initialLiveRanges)
    rangeOf.insert({entry.first, &entry.second});

  // Folds b's live range into a's when they are distinct and disjoint, and
  // repoints every value of b's range at a's range.
  auto tryMerge = [&](Value a, Value b) {
    LiveRange *into = rangeOf.at(a);
    LiveRange *from = rangeOf.at(b);
    if (into == from || into->overlaps(*from))
      return;
    into->unionWith(*from);
    for (Value moved : from->values)
      rangeOf[moved] = into;
  };

  // Rule 1: block arguments with the values passed by predecessors.
  auto unifyBlockArgumentWithPredecessors = [&](Value value) {
    auto blockArg = dyn_cast<BlockArgument>(value);
    if (!blockArg)
      return;
    forEachPredecessorTileValue(blockArg, [&](Value predecessorTile) {
      tryMerge(blockArg, predecessorTile);
    });
  };

  // Rule 2: ArmSME tile op results with their tile-typed operands.
  auto unifyDefinitionWithOperands = [&](Value value) {
    auto definingTileOp = value.getDefiningOp<ArmSMETileOpInterface>();
    if (!definingTileOp)
      return;
    for (Value operand : definingTileOp->getOperands()) {
      if (!isValidSMETileVectorType(operand.getType()))
        continue;
      tryMerge(value, operand);
    }
  };

  // Rule 1 runs to completion over all values before rule 2 starts.
  for (Value value : llvm::make_first_range(initialLiveRanges))
    unifyBlockArgumentWithPredecessors(value);
  for (Value value : llvm::make_first_range(initialLiveRanges))
    unifyDefinitionWithOperands(value);

  // Several values may now share one range: de-duplicate, drop empty ranges.
  SetVector<LiveRange *> uniqueLiveRanges;
  for (auto &entry : rangeOf) {
    LiveRange *range = entry.second;
    if (range->empty())
      continue;
    uniqueLiveRanges.insert(range);
  }

  auto coalescedLiveRanges = uniqueLiveRanges.takeVector();
  std::sort(coalescedLiveRanges.begin(), coalescedLiveRanges.end(),
            [](LiveRange *lhs, LiveRange *rhs) {
              return lhs->start() < rhs->start();
            });
  return std::move(coalescedLiveRanges);
}
/// Picks which live range should give up its tile so that `newRange` can be
/// handled: either one of `overlappingRanges` or `newRange` itself.
///
/// Preference order:
///  1. `newRange`, if it is a "trivial spill" (see below);
///  2. the first trivial spill among `overlappingRanges`;
///  3. the maximum of `overlappingRanges` under the "smaller tile type or ends
///     earlier" comparison, unless that comparison ranks it below `newRange`,
///     in which case `newRange` is returned.
/// `overlappingRanges` must be non-empty when step 3 is reached (the result of
/// std::max_element is dereferenced unconditionally).
template <typename OverlappingRangesIterator>
LiveRange *
chooseSpillUsingHeuristics(OverlappingRangesIterator overlappingRanges,
                           LiveRange *newRange) {
  // A range is a trivial spill if its tile type is >= newRange's, it holds a
  // single value, and that value's defining op is trivially cloneable.
  auto isTrivialSpill = [&](LiveRange &candidate) -> bool {
    if (!isTileTypeGreaterOrEqual(candidate.getTileType(),
                                  newRange->getTileType()))
      return false;
    if (candidate.values.size() != 1)
      return false;
    auto definingOp =
        candidate.values[0].getDefiningOp<ArmSMETileOpInterface>();
    return isTriviallyCloneableTileOp(definingOp);
  };

  if (isTrivialSpill(*newRange))
    return newRange;

  auto trivialSpill = llvm::find_if(overlappingRanges, isTrivialSpill);
  if (trivialSpill != overlappingRanges.end())
    return &*trivialSpill;

  // "Less than" for spill preference: a is a worse spill candidate than b.
  auto isSmallerTileTypeOrEndsEarlier = [](LiveRange &a, LiveRange &b) {
    if (!isTileTypeGreaterOrEqual(a.getTileType(), b.getTileType()))
      return true;
    return a.end() < b.end();
  };

  LiveRange &latestEndingLiveRange =
      *std::max_element(overlappingRanges.begin(), overlappingRanges.end(),
                        isSmallerTileTypeOrEndsEarlier);
  return isSmallerTileTypeOrEndsEarlier(latestEndingLiveRange, *newRange)
             ? newRange
             : &latestEndingLiveRange;
}
/// Assigns a tile ID to live ranges, visiting them in the order given by
/// `liveRangesSortedByStartPoint`.
///
/// Each visited range receives an ID from TileAllocator::allocateTileId if
/// one is free. Otherwise chooseSpillUsingHeuristics selects a range (the new
/// one or one it overlaps), and that range receives an ID from
/// allocateInMemoryTileId instead (such IDs start at kInMemoryTileIdBase).
/// Results are written to LiveRange::tileId.
void allocateTilesToLiveRanges(
ArrayRef<LiveRange *> liveRangesSortedByStartPoint) {
TileAllocator tileAllocator;
// Ranges that hold their tile in `tileAllocator` at the current point.
SetVector<LiveRange *> activeRanges;
// Ranges that have not ended yet but are not live at the current point;
// their tile has been released in `tileAllocator` while they wait here.
SetVector<LiveRange *> inactiveRanges;
for (LiveRange *nextRange : liveRangesSortedByStartPoint) {
auto currentPoint = nextRange->start();
// 1. Update the active set for `currentPoint`.
activeRanges.remove_if([&](LiveRange *activeRange) {
// Ended: release the tile and forget the range.
if (activeRange->end() <= currentPoint) {
tileAllocator.releaseTileId(activeRange->getTileType(),
*activeRange->tileId);
return true;
}
// Not live at this point: release the tile and park the range as inactive.
if (!activeRange->overlaps(currentPoint)) {
tileAllocator.releaseTileId(activeRange->getTileType(),
*activeRange->tileId);
inactiveRanges.insert(activeRange);
return true;
}
return false;
});
// 2. Update the inactive set for `currentPoint`.
inactiveRanges.remove_if([&](LiveRange *inactiveRange) {
// Ended: nothing to release (the tile was released when it was parked).
if (inactiveRange->end() <= currentPoint) {
return true;
}
// Live again at this point: re-acquire its tile and make it active.
if (inactiveRange->overlaps(currentPoint)) {
tileAllocator.acquireTileId(inactiveRange->getTileType(),
*inactiveRange->tileId);
activeRanges.insert(inactiveRange);
return true;
}
return false;
});
// 3. Temporarily re-acquire the tiles of inactive ranges that overlap
// `nextRange` anywhere (not just at `currentPoint`), so the allocation below
// cannot hand one of those tiles to `nextRange`.
SmallVector<LiveRange *> overlappingInactiveRanges;
for (LiveRange *inactiveRange : inactiveRanges) {
if (inactiveRange->overlaps(*nextRange)) {
tileAllocator.acquireTileId(inactiveRange->getTileType(),
*inactiveRange->tileId);
overlappingInactiveRanges.push_back(inactiveRange);
}
}
// 4. Try to allocate a tile for `nextRange`.
auto rangeTileType = nextRange->getTileType();
auto tileId = tileAllocator.allocateTileId(rangeTileType);
if (succeeded(tileId)) {
nextRange->tileId = *tileId;
} else {
// No free tile: choose among the active ranges, the overlapping inactive
// ranges and `nextRange` itself.
auto allOverlappingRanges = llvm::concat<LiveRange>(
llvm::make_pointee_range(activeRanges.getArrayRef()),
llvm::make_pointee_range(overlappingInactiveRanges));
LiveRange *rangeToSpill =
chooseSpillUsingHeuristics(allOverlappingRanges, nextRange);
if (rangeToSpill != nextRange) {
// Another range is spilled: free its tile and allocate for `nextRange`.
tileAllocator.releaseTileId(rangeToSpill->getTileType(),
*rangeToSpill->tileId);
// NOTE(review): the result is dereferenced without a success check, i.e.
// allocation is assumed to succeed after the release above.
nextRange->tileId = *tileAllocator.allocateTileId(rangeTileType);
// The spilled range no longer takes part in allocation.
if (!activeRanges.remove(rangeToSpill)) {
bool removed = inactiveRanges.remove(rangeToSpill);
assert(removed && "expected a range to be removed!");
(void)removed;
}
}
// The spilled range (possibly `nextRange` itself) gets an in-memory tile ID.
rangeToSpill->tileId = tileAllocator.allocateInMemoryTileId();
}
// 5. `nextRange` becomes active unless it was the one that got spilled.
if (nextRange->tileId < kInMemoryTileIdBase)
activeRanges.insert(nextRange);
// 6. Undo the temporary reservations from step 3, skipping any range that
// was spilled in step 4 (its tile was already released there).
for (LiveRange *range : overlappingInactiveRanges) {
if (*range->tileId < kInMemoryTileIdBase)
tileAllocator.releaseTileId(range->getTileType(), *range->tileId);
}
}
}
/// Sets `tileIdAttr` as the tile ID of the ArmSME tile op that defines `value`
/// (if any) and of every ArmSME tile op that uses `value` but has no tile
/// result of its own. Users that do produce a tile result are left untouched.
void assignTileIdToValue(IRRewriter &rewriter, Value value,
                         IntegerAttr tileIdAttr) {
  auto setTileId = [&](ArmSMETileOpInterface op) {
    rewriter.modifyOpInPlace(op, [&] { op.setTileId(tileIdAttr); });
  };

  if (auto definingTileOp = value.getDefiningOp<ArmSMETileOpInterface>())
    setTileId(definingTileOp);

  for (Operation *user : value.getUsers()) {
    auto userTileOp = dyn_cast<ArmSMETileOpInterface>(user);
    if (!userTileOp)
      continue;
    if (hasTileResult(userTileOp))
      continue;
    setTileId(userTileOp);
  }
}
/// Writes the tile ID of each live range in `allocatedLiveRanges` onto the
/// operations of its values, then fixes up or rejects remaining mismatches.
///
/// Per value, in order: (1) assign the tile ID, (2) fold the value away if it
/// is a copy of a tile already on the same tile ID, (3) check that block
/// arguments agree with their predecessors, (4) if the op's tile operand is on
/// a different tile, clone the operand's defining op when that is trivially
/// possible, otherwise emit an error.
///
/// Returns failure (after emitting a diagnostic) on the first block-argument
/// mismatch or non-trivial operand conflict. `function` is not referenced by
/// the current implementation.
LogicalResult assignTileIdsAndResolveTrivialConflicts(
IRRewriter &rewriter, FunctionOpInterface function,
ArrayRef<LiveRange *> allocatedLiveRanges) {
for (LiveRange const *liveRange : allocatedLiveRanges) {
auto tileIdAttr = rewriter.getI32IntegerAttr(*liveRange->tileId);
// True if `value` is defined by an ArmSME tile op whose tile ID attribute
// equals `tileIdAttr`, or if `value` is a member of this live range.
auto isAllocatedToSameTile = [&](Value value) {
if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>();
tileOp && tileOp.getTileId() == tileIdAttr)
return true;
return liveRange->values.contains(value);
};
// If `value` is a CopyTileOp result whose source is on the same tile,
// replaces all uses of the copy with its source and returns success;
// otherwise returns failure without changing anything.
auto foldRedundantCopies = [&](Value value) -> LogicalResult {
auto copyOp = value.getDefiningOp<CopyTileOp>();
if (!copyOp || !isAllocatedToSameTile(copyOp.getTile()))
return failure();
rewriter.replaceAllUsesWith(copyOp, copyOp.getTile());
return success();
};
// For a block argument, checks that every value passed by a predecessor is
// on the same tile. Emits one error on the block's parent op for the first
// mismatch. Non-block-argument values always pass.
auto validateBlockArguments = [&](Value value) {
auto blockArg = dyn_cast<BlockArgument>(value);
if (!blockArg) {
return success();
}
bool tileMismatch = false;
forEachPredecessorTileValue(blockArg, [&](Value predecessorTile) {
// Report only the first mismatch.
if (tileMismatch)
return;
if (!isAllocatedToSameTile(predecessorTile)) {
blockArg.getOwner()->getParentOp()->emitOpError(
"block argument not allocated to the same SME virtial tile as "
"predecessors");
tileMismatch = true;
}
});
return success(!tileMismatch);
};
// Handles a tile operand that ended up on a different tile than `value`.
auto resolveTrivialTileConflicts = [&](Value value) -> LogicalResult {
// NOTE(review): `tileOp` is null when `value` is a block argument;
// getTileOpOperand presumably returns null for a null op -- confirm.
auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>();
OpOperand *tileOperand = getTileOpOperand(tileOp);
// No tile operand, or it is already on the same tile: nothing to resolve.
if (!tileOperand || isAllocatedToSameTile(tileOperand->get())) {
return success();
}
auto operandTileOp =
tileOperand->get().getDefiningOp<ArmSMETileOpInterface>();
// The operand's producer cannot simply be re-created: report an error.
if (!isTriviallyCloneableTileOp(operandTileOp)) {
auto error =
tileOp.emitOpError("tile operand allocated to different SME "
"virtial tile (move required)");
error.attachNote(tileOperand->get().getLoc())
<< "tile operand is: " << tileOperand->get();
return error;
}
// Clone the operand's producer right before `tileOp`, giving the clone
// `tileOp`'s tile ID.
rewriter.setInsertionPoint(tileOp);
auto clonedOp = operandTileOp.clone();
rewriter.modifyOpInPlace(clonedOp,
[&] { clonedOp.setTileId(tileOp.getTileId()); });
rewriter.insert(clonedOp);
if (isa<CopyTileOp>(tileOp)) {
// A copy of the clone would be redundant: use the clone's result directly.
rewriter.replaceAllUsesWith(tileOp->getResult(0),
clonedOp->getResult(0));
} else {
// Otherwise feed the clone's result into `tileOp`'s tile operand.
rewriter.modifyOpInPlace(
tileOp, [&] { tileOperand->assign(clonedOp->getResult(0)); });
}
return success();
};
for (Value value : liveRange->values) {
// 1. Assign the tile ID.
assignTileIdToValue(rewriter, value, tileIdAttr);
// 2. A folded copy needs no further checks.
if (succeeded(foldRedundantCopies(value)))
continue;
// 3. Block arguments must match their predecessors.
if (failed(validateBlockArguments(value)))
return failure();
// 4. Resolve (or reject) a tile operand on a different tile.
if (failed(resolveTrivialTileConflicts(value)))
return failure();
}
}
return success();
}
/// Prints a textual chart of `liveRanges` for `function` to llvm::errs().
///
/// One line is printed per operation, with one column per live range:
/// 'S' where an interval of the range starts, 'E' where one stops, '|' where
/// the operation lies inside an interval (or where one interval stops and
/// another starts), and ' ' otherwise. The operation name follows the columns.
void dumpLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
                    ArrayRef<LiveRange const *> liveRanges,
                    FunctionOpInterface function) {
  // Computes the column character of `range` at operation index `opIndex`.
  auto markerFor = [](LiveRange const &range, unsigned opIndex) {
    char marker = ' ';
    for (auto it = range.ranges->begin(); it != range.ranges->end(); ++it) {
      if (it.start() == opIndex)
        marker = (marker == 'E') ? '|' : 'S';
      else if (it.stop() == opIndex)
        marker = (marker == 'S') ? '|' : 'E';
      else if (it.start() <= opIndex && opIndex < it.stop())
        marker = '|';
    }
    return marker;
  };

  llvm::errs() << "SME Tile Liveness: @" << function.getName()
               << "\nKey:\nS - Start\nE - End\n| - Live\n";
  for (auto [blockIdx, block] : llvm::enumerate(function.getBlocks())) {
    llvm::errs() << "^bb" << blockIdx << ":\n";
    for (Operation &op : block) {
      unsigned operationIndex = operationToIndexMap.at(&op);
      for (LiveRange const *range : liveRanges)
        llvm::errs() << markerFor(*range, operationIndex);
      llvm::errs() << ' ' << op.getName() << '\n';
    }
  }
  llvm::errs() << "==========\n";
}
/// Test pass driving SME tile allocation on a function.
///
/// With the `preprocessOnly` option set, only the IR preprocessing step runs.
/// Otherwise full allocation runs via arm_sme::allocateSMETiles, optionally
/// dumping live ranges (`dumpTileLiveRanges`), and the pass is marked as
/// failed if allocation fails.
struct TestTileAllocationPass
    : public arm_sme::impl::TestTileAllocationBase<TestTileAllocationPass> {
  using TestTileAllocationBase::TestTileAllocationBase;

  void runOnOperation() override {
    FunctionOpInterface function = getOperation();

    if (preprocessOnly) {
      IRRewriter rewriter(function);
      preprocessForTileAllocation(rewriter, function);
      return;
    }

    LogicalResult result =
        arm_sme::allocateSMETiles(function, dumpTileLiveRanges);
    if (failed(result))
      signalPassFailure();
  }
};
}
/// Allocates SME tiles for all ArmSME tile values in `function`.
///
/// Steps: preprocess the IR, number the operations, gather per-value live
/// ranges, coalesce them, allocate tile IDs, write the IDs back to the IR
/// (resolving trivial conflicts), and finally erase trivially dead tile ops.
/// If `dumpRanges` is set, the initial and coalesced live ranges are printed
/// to llvm::errs().
///
/// Returns success for an empty function or one without tile live ranges
/// (note that preprocessing has already run in the latter case), and failure
/// if writing back the tile IDs fails.
LogicalResult mlir::arm_sme::allocateSMETiles(FunctionOpInterface function,
                                              bool dumpRanges) {
  if (function.empty())
    return success();

  LiveRange::Allocator liveRangeAllocator;
  IRRewriter rewriter(function.getContext());

  // Preprocess first: liveness and numbering must see the rewritten IR.
  preprocessForTileAllocation(rewriter, function);

  Liveness liveness(function);
  auto operationToIndexMap = generateOperationNumbering(function);
  auto initialLiveRanges = gatherTileLiveRanges(
      operationToIndexMap, liveRangeAllocator, liveness, function);
  if (initialLiveRanges.empty())
    return success();

  if (dumpRanges) {
    // Collect the non-empty ranges and order them by start point for display.
    SmallVector<LiveRange const *> initialRanges;
    for (LiveRange const &liveRange :
         llvm::make_second_range(initialLiveRanges)) {
      if (!liveRange.empty())
        initialRanges.push_back(&liveRange);
    }
    std::sort(initialRanges.begin(), initialRanges.end(),
              [](LiveRange const *a, LiveRange const *b) { return *a < *b; });
    llvm::errs() << "\n========== Initial Live Ranges:\n";
    dumpLiveRanges(operationToIndexMap, initialRanges, function);
  }

  auto coalescedLiveRanges = coalesceTileLiveRanges(initialLiveRanges);
  if (dumpRanges) {
    llvm::errs() << "\n========== Coalesced Live Ranges:\n";
    dumpLiveRanges(operationToIndexMap, coalescedLiveRanges, function);
  }

  allocateTilesToLiveRanges(coalescedLiveRanges);

  LogicalResult assigned = assignTileIdsAndResolveTrivialConflicts(
      rewriter, function, coalescedLiveRanges);
  if (failed(assigned))
    return failure();

  eraseTriviallyDeadTileOps(rewriter, function);
  return success();
}