#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
namespace mlir::arm_sme {
/// Returns `MinStreamingVectorLengthInBits` divided by the bit width of
/// `type`, i.e. the minimum number of `type` elements in one SME tile slice.
///
/// `type` must satisfy `isValidSMETileElementType` (checked by assertion).
unsigned getSMETileSliceMinNumElts(Type type) {
  assert(isValidSMETileElementType(type) && "invalid tile type!");
  const unsigned elementBitWidth = type.getIntOrFloatBitWidth();
  return MinStreamingVectorLengthInBits / elementBitWidth;
}
/// Returns true if `type` may be used as the element type of an SME tile:
/// an integer of width 8, 16, 32, 64 or 128, or one of f16, bf16, f32, f64,
/// f128.
bool isValidSMETileElementType(Type type) {
  constexpr unsigned supportedIntegerWidths[] = {8, 16, 32, 64, 128};
  for (unsigned width : supportedIntegerWidths) {
    if (type.isInteger(width))
      return true;
  }
  if (type.isF16() || type.isBF16())
    return true;
  if (type.isF32() || type.isF64())
    return true;
  return type.isF128();
}
/// Returns true if `vType` is exactly an SME tile vector type: rank 2, every
/// dimension scalable, a valid tile element type, and both dimension sizes
/// equal to the minimum tile-slice element count for that element type.
bool isValidSMETileVectorType(VectorType vType) {
  const bool isScalable2D = vType.getRank() == 2 && vType.allDimsScalable();
  if (!isScalable2D)
    return false;

  Type elementType = vType.getElementType();
  if (!isValidSMETileElementType(elementType))
    return false;

  // A tile is square: both dimensions must match the per-slice minimum.
  const unsigned tileDim = getSMETileSliceMinNumElts(elementType);
  return vType.getShape() == ArrayRef<int64_t>({tileDim, tileDim});
}
/// Maps an SME tile vector type to its `ArmSMETileType`, selected by element
/// bit width (8 -> ZAB, 16 -> ZAH, 32 -> ZAS, 64 -> ZAD, 128 -> ZAQ).
///
/// Returns an empty optional when `type` is not a valid SME tile vector type.
std::optional<ArmSMETileType> getSMETileType(VectorType type) {
  if (!isValidSMETileVectorType(type))
    return std::nullopt;

  const unsigned elementBits = type.getElementTypeBitWidth();
  if (elementBits == 8)
    return ArmSMETileType::ZAB;
  if (elementBits == 16)
    return ArmSMETileType::ZAH;
  if (elementBits == 32)
    return ArmSMETileType::ZAS;
  if (elementBits == 64)
    return ArmSMETileType::ZAD;
  if (elementBits == 128)
    return ArmSMETileType::ZAQ;
  llvm_unreachable("unknown SME tile type");
}
/// Verifies the tile ID of `op`, if it has one.
///
/// Succeeds when `op` does not implement `ArmSMETileOpInterface`, or when it
/// has no tile ID set. Otherwise the tile ID must be a 32-bit signless
/// integer; if it is not, an op error is emitted and returned.
LogicalResult verifyOperationHasValidTileId(Operation *op) {
  if (auto tileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op)) {
    if (auto tileId = tileOp.getTileId()) {
      const bool isI32 = tileId.getType().isSignlessInteger(32);
      if (!isI32)
        return tileOp.emitOpError(
            "tile ID should be a 32-bit signless integer");
    }
  }
  return success();
}
/// Builds an `scf.for` loop from 0 to `dim0(initTile) * vscale` with step 1,
/// carrying `initTile` as its single iteration argument.
///
/// `makeLoopBody` is invoked once, with the insertion point at the start of
/// the loop body, and receives (builder, loc, induction variable, current
/// iteration argument). The value it returns is yielded from the loop.
/// `initTile` must have a `VectorType`. The rewriter's insertion point is
/// protected by an `InsertionGuard` for the duration of this function.
scf::ForOp createLoopOverTileSlices(
    PatternRewriter &rewriter, Location loc, Value initTile,
    std::function<Value(OpBuilder &, Location, Value, Value)> makeLoopBody) {
  OpBuilder::InsertionGuard guard(rewriter);

  // Note: ops are created in a fixed order so the emitted IR is stable.
  auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  auto tileType = llvm::cast<VectorType>(initTile.getType());
  auto minSliceCount =
      rewriter.create<arith::ConstantIndexOp>(loc, tileType.getDimSize(0));
  Type indexType = rewriter.getIndexType();
  auto vectorScale = rewriter.create<vector::VectorScaleOp>(loc, indexType);
  auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  auto sliceCount =
      rewriter.create<arith::MulIOp>(loc, minSliceCount, vectorScale);

  auto loop = rewriter.create<scf::ForOp>(loc, zero, sliceCount, one,
                                          ValueRange{initTile});

  // Populate the loop body and yield the tile produced by the callback.
  rewriter.setInsertionPointToStart(loop.getBody());
  Value tileSliceIndex = loop.getInductionVar();
  Value currentTile = loop.getRegionIterArg(0);
  Value updatedTile = makeLoopBody(rewriter, loc, tileSliceIndex, currentTile);
  rewriter.create<scf::YieldOp>(loc, updatedTile);
  return loop;
}
/// Returns true if `vType` is a rank-2, fully scalable vector with a valid
/// SME tile element type whose dimensions are both multiples of the tile
/// dimension and at least one of which is strictly larger than it (so a
/// vector that is exactly one tile does not qualify).
bool isMultipleOfSMETileVectorType(VectorType vType) {
  const bool isScalable2D = vType.getRank() == 2 && vType.allDimsScalable();
  if (!isScalable2D)
    return false;

  Type elemType = vType.getElementType();
  if (!isValidSMETileElementType(elemType))
    return false;

  const int64_t tileDim = getSMETileSliceMinNumElts(elemType);
  const int64_t rows = vType.getDimSize(0);
  const int64_t cols = vType.getDimSize(1);

  const bool exceedsSingleTile = rows > tileDim || cols > tileDim;
  if (!exceedsSingleTile)
    return false;
  return rows % tileDim == 0 && cols % tileDim == 0;
}
/// Builds the square, fully scalable 2-D vector type whose two dimensions
/// both equal `getSMETileSliceMinNumElts(elementType)`.
VectorType getSMETileTypeForElement(Type elementType) {
  const unsigned tileDim = getSMETileSliceMinNumElts(elementType);
  const int64_t tileShape[] = {tileDim, tileDim};
  const bool scalableDims[] = {true, true};
  return VectorType::get(tileShape, elementType, scalableDims);
}
/// Erases every trivially dead ArmSME tile op in `function`, then keeps
/// erasing tile-op producers that become trivially dead as a consequence.
///
/// Invariant: an op is queued at most once. The initial walk visits each op
/// once and only queues ops that are already dead (hence have no users that
/// could queue them again). Afterwards a producer is queued only at the moment
/// it becomes dead, i.e. right after its last user was erased, and producers
/// are de-duplicated per erased op. Without this, a producer feeding several
/// operand slots (or queued by two different ops) could be erased on its first
/// pop and then dereferenced as a dangling pointer on its second pop.
void eraseTriviallyDeadTileOps(IRRewriter &rewriter,
                               FunctionOpInterface function) {
  SmallVector<Operation *> worklist;
  function->walk([&](Operation *op) {
    auto armSMEOp = dyn_cast<arm_sme::ArmSMETileOpInterface>(op);
    if (armSMEOp && isOpTriviallyDead(armSMEOp))
      worklist.push_back(armSMEOp);
  });
  while (!worklist.empty()) {
    Operation *op = worklist.pop_back_val();
    // Defensive: every queued op is expected to be dead already.
    if (!isOpTriviallyDead(op))
      continue;

    // Collect the distinct tile-op producers of `op` before erasing it; the
    // operand list is not accessible afterwards.
    SmallVector<Operation *> producers;
    for (Value value : op->getOperands()) {
      auto armSMEOp = value.getDefiningOp<arm_sme::ArmSMETileOpInterface>();
      if (armSMEOp && !llvm::is_contained(producers, armSMEOp.getOperation()))
        producers.push_back(armSMEOp);
    }

    rewriter.eraseOp(op);

    // Only queue producers that the erase above has just made dead.
    for (Operation *producer : producers) {
      if (isOpTriviallyDead(producer))
        worklist.push_back(producer);
    }
  }
}
/// Returns true if `tileOp` is non-null, has exactly one result, takes no
/// operands, and is pure.
bool isTriviallyCloneableTileOp(arm_sme::ArmSMETileOpInterface tileOp) {
  if (!tileOp)
    return false;
  if (tileOp->getNumResults() != 1)
    return false;
  if (tileOp->getNumOperands() != 0)
    return false;
  return isPure(tileOp);
}
/// Returns true if any result of `tileOp` has a valid SME tile vector type.
///
/// A null `tileOp` yields false, matching the null handling of
/// `isTriviallyCloneableTileOp` and `getTileOpOperand`.
bool hasTileResult(arm_sme::ArmSMETileOpInterface tileOp) {
  if (!tileOp)
    return false;
  for (Value result : tileOp->getResults()) {
    if (arm_sme::isValidSMETileVectorType(result.getType()))
      return true;
  }
  return false;
}
/// Returns the operand of `tileOp` whose type is a valid SME tile vector
/// type, or nullptr if `tileOp` is null or has no such operand.
///
/// Asserts that at most one operand has a tile type; when assertions are
/// disabled the first matching operand is returned.
OpOperand *getTileOpOperand(arm_sme::ArmSMETileOpInterface tileOp) {
  if (!tileOp)
    return nullptr;

  auto hasTileType = [](OpOperand &operand) {
    return arm_sme::isValidSMETileVectorType(operand.get().getType());
  };
  assert(llvm::count_if(tileOp->getOpOperands(), hasTileType) <= 1 &&
         "expected at most one tile operand");

  for (OpOperand &operand : tileOp->getOpOperands()) {
    if (hasTileType(operand))
      return &operand;
  }
  return nullptr;
}
/// Returns true when the underlying enum value of `typeA` is less than or
/// equal to that of `typeB`.
///
/// NOTE(review): the comparison direction implies that "greater" tile types
/// carry smaller enum values — confirm against the `ArmSMETileType`
/// definition.
bool isTileTypeGreaterOrEqual(ArmSMETileType typeA, ArmSMETileType typeB) {
  const auto valueA = static_cast<unsigned>(typeA);
  const auto valueB = static_cast<unsigned>(typeB);
  return !(valueA > valueB);
}
}