#include "mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Index/IR/IndexAttrs.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
namespace mlir {
namespace linalg {
namespace {
/// External model implementing RuntimeVerifiableOpInterface for the
/// structured (Linalg) op type `T`.
template <typename T>
struct StructuredOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<
StructuredOpInterface<T>, T> {
/// Emits runtime checks for `op` at `loc`, inserting ops through `builder`.
///
/// For every operand and every dimension of that operand, the loop starts and
/// the inclusive loop ends are pushed through the operand's indexing map. Two
/// `cf.assert` ops are then generated:
///   1. min(startIndex, endIndex) >= 0, i.e. no negative index is produced;
///   2. max(startIndex, endIndex) + 1 == actual dimension size when the map
///      result is a plain dimension, or <= it for any other affine expression.
/// `op` must be a LinalgOp; it is converted with an unchecked llvm::cast.
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
Location loc) const {
auto linalgOp = llvm::cast<LinalgOp>(op);
// Materialize the op's loop ranges and split them into per-loop starts and
// ends. The third component (strides) is not used.
SmallVector<Range> loopRanges = linalgOp.createLoopRanges(builder, loc);
auto [starts, ends, _] = getOffsetsSizesAndStrides(loopRanges);
// Index constants shared by all checks emitted below.
auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
auto one = builder.create<arith::ConstantIndexOp>(loc, 1);
// Replace each loop end with end - 1, so it can be composed with the
// indexing maps as an inclusive bound. The "+ 1" below undoes this when
// the inferred dimension size is computed.
// NOTE(review): `ends` presumably holds the loop sizes (second result of
// getOffsetsSizesAndStrides). For a loop of size 0 this yields -1, and the
// "unexpected negative result" assertion below would fire even though no
// element is accessed. Confirm whether empty iteration spaces need a guard.
transform(ends, ends.begin(), [&](OpFoldResult end) {
auto endValue = getValueOrCreateConstantIndexOp(builder, loc, end);
return builder.createOrFold<index::SubOp>(loc, endValue, one);
});
for (OpOperand &opOperand : linalgOp->getOpOperands()) {
// Apply this operand's indexing map to the loop starts and to the
// inclusive loop ends. This gives one start/end index per operand dim.
AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
auto startIndices = affine::makeComposedFoldedMultiResultAffineApply(
builder, loc, indexingMap, starts);
auto endIndices = affine::makeComposedFoldedMultiResultAffineApply(
builder, loc, indexingMap, ends);
for (auto dim : llvm::seq(linalgOp.getRank(&opOperand))) {
auto startIndex =
getValueOrCreateConstantIndexOp(builder, loc, startIndices[dim]);
auto endIndex =
getValueOrCreateConstantIndexOp(builder, loc, endIndices[dim]);
// Check 1: the smaller of the two mapped indices must be >= 0. min/max
// are used because the code does not assume startIndex <= endIndex
// (for example, a map that reverses the iteration order).
auto min =
builder.createOrFold<index::MinSOp>(loc, startIndex, endIndex);
auto cmpOp = builder.createOrFold<index::CmpOp>(
loc, index::IndexCmpPredicate::SGE, min, zero);
auto msg = RuntimeVerifiableOpInterface::generateErrorMessage(
linalgOp, "unexpected negative result on dimension #" +
std::to_string(dim) + " of input/output operand #" +
std::to_string(opOperand.getOperandNumber()));
builder.createOrFold<cf::AssertOp>(loc, cmpOp, msg);
// Check 2: the largest mapped index + 1 is the size this operand
// dimension is inferred to need. It is compared against the operand's
// actual (possibly dynamic) dimension size.
auto max =
builder.createOrFold<index::MaxSOp>(loc, startIndex, endIndex);
auto inferredDimSize =
builder.createOrFold<index::AddOp>(loc, max, one);
auto actualDimSize =
createOrFoldDimOp(builder, loc, opOperand.get(), dim);
// A plain dimension result (AffineDimExpr) requires an exact match.
// Any other affine expression is only checked as an upper bound.
auto predicate = isa<AffineDimExpr>(indexingMap.getResult(dim))
? index::IndexCmpPredicate::EQ
: index::IndexCmpPredicate::SLE;
cmpOp = builder.createOrFold<index::CmpOp>(
loc, predicate, inferredDimSize, actualDimSize);
msg = RuntimeVerifiableOpInterface::generateErrorMessage(
linalgOp, "dimension #" + std::to_string(dim) +
" of input/output operand #" +
std::to_string(opOperand.getOperandNumber()) +
" is incompatible with inferred dimension size");
builder.createOrFold<cf::AssertOp>(loc, cmpOp, msg);
}
}
}
};
/// Attaches the `StructuredOpInterface` external model to every op type in
/// the pack `OpTs`, one `attachInterface` call per type, in pack order.
template <typename... OpTs>
void attachInterface(MLIRContext *ctx) {
  MLIRContext &context = *ctx;
  // Left comma fold: evaluates the calls first-to-last, same as the pack order.
  (..., OpTs::template attachInterface<StructuredOpInterface<OpTs>>(context));
}
}
}
}
/// Registers the runtime-verification external model for all Linalg
/// structured ops with `registry`.
///
/// The registered extension callback receives the context and the Linalg
/// dialect. It attaches `StructuredOpInterface` to every op listed in
/// LinalgStructuredOps.cpp.inc (expanded through GET_OP_LIST). It then loads
/// the dialects whose ops the generated verification code may create.
void mlir::linalg::registerRuntimeVerifiableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, LinalgDialect *) {
    attachInterface<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
        >(ctx);

    // Load the dialects of ops that generateRuntimeVerification may create:
    // affine applies, arith constants, index arithmetic/compares, cf.assert,
    // and dim ops (e.g. tensor.dim).
    ctx->loadDialect<affine::AffineDialect, arith::ArithDialect,
                     cf::ControlFlowDialect, index::IndexDialect,
                     tensor::TensorDialect>();
  });
}