#include "TestDialect.h"
#include "TestOps.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/Dialect/Arith/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Pass/Pass.h"
#define PASS_NAME "test-affine-reify-value-bounds"
using namespace mlir;
using namespace mlir::affine;
using mlir::presburger::BoundType;
namespace {
struct TestReifyValueBounds
: public PassWrapper<TestReifyValueBounds, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReifyValueBounds)
StringRef getArgument() const final { return PASS_NAME; }
StringRef getDescription() const final {
return "Tests ValueBoundsOpInterface with affine dialect reification";
}
TestReifyValueBounds() = default;
TestReifyValueBounds(const TestReifyValueBounds &pass) : PassWrapper(pass){};
void getDependentDialects(DialectRegistry ®istry) const override {
registry.insert<affine::AffineDialect, tensor::TensorDialect,
memref::MemRefDialect>();
}
void runOnOperation() override;
private:
Option<bool> reifyToFuncArgs{
*this, "reify-to-func-args",
llvm::cl::desc("Reify in terms of function args"), llvm::cl::init(false)};
Option<bool> useArithOps{*this, "use-arith-ops",
llvm::cl::desc("Reify with arith dialect ops"),
llvm::cl::init(false)};
};
}
static ValueBoundsConstraintSet::ComparisonOperator
invertComparisonOperator(ValueBoundsConstraintSet::ComparisonOperator cmp) {
if (cmp == ValueBoundsConstraintSet::ComparisonOperator::LT)
return ValueBoundsConstraintSet::ComparisonOperator::GE;
if (cmp == ValueBoundsConstraintSet::ComparisonOperator::LE)
return ValueBoundsConstraintSet::ComparisonOperator::GT;
if (cmp == ValueBoundsConstraintSet::ComparisonOperator::GT)
return ValueBoundsConstraintSet::ComparisonOperator::LE;
if (cmp == ValueBoundsConstraintSet::ComparisonOperator::GE)
return ValueBoundsConstraintSet::ComparisonOperator::LT;
llvm_unreachable("unsupported comparison operator");
}
static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
bool reifyToFuncArgs,
bool useArithOps) {
IRRewriter rewriter(funcOp.getContext());
WalkResult result = funcOp.walk([&](test::ReifyBoundOp op) {
auto boundType = op.getBoundType();
Value value = op.getVar();
std::optional<int64_t> dim = op.getDim();
bool constant = op.getConstant();
bool scalable = op.getScalable();
std::function<bool(Value, std::optional<int64_t>,
ValueBoundsConstraintSet & cstr)>
stopCondition = [&](Value v, std::optional<int64_t> d,
ValueBoundsConstraintSet &cstr) {
return v != value;
};
if (reifyToFuncArgs) {
stopCondition = [](Value v, std::optional<int64_t> d,
ValueBoundsConstraintSet &cstr) {
auto bbArg = dyn_cast<BlockArgument>(v);
if (!bbArg)
return false;
return isa<FunctionOpInterface>(bbArg.getParentBlock()->getParentOp());
};
}
rewriter.setInsertionPointAfter(op);
FailureOr<OpFoldResult> reified = failure();
if (constant) {
auto reifiedConst = ValueBoundsConstraintSet::computeConstantBound(
boundType, {value, dim}, nullptr);
if (succeeded(reifiedConst))
reified = FailureOr<OpFoldResult>(rewriter.getIndexAttr(*reifiedConst));
} else if (scalable) {
auto loc = op->getLoc();
auto reifiedScalable =
vector::ScalableValueBoundsConstraintSet::computeScalableBound(
value, dim, *op.getVscaleMin(), *op.getVscaleMax(), boundType);
if (succeeded(reifiedScalable)) {
SmallVector<std::pair<Value, std::optional<int64_t>>, 1> vscaleOperand;
if (reifiedScalable->map.getNumInputs() == 1) {
vscaleOperand.push_back(std::make_pair(
rewriter.create<vector::VectorScaleOp>(loc), std::nullopt));
}
reified = affine::materializeComputedBound(
rewriter, loc, reifiedScalable->map, vscaleOperand);
}
} else {
if (useArithOps) {
reified = arith::reifyValueBound(rewriter, op->getLoc(), boundType,
op.getVariable(), stopCondition);
} else {
reified = reifyValueBound(rewriter, op->getLoc(), boundType,
op.getVariable(), stopCondition);
}
}
if (failed(reified)) {
op->emitOpError("could not reify bound");
return WalkResult::interrupt();
}
if (auto val = llvm::dyn_cast_if_present<Value>(*reified)) {
rewriter.replaceOp(op, val);
return WalkResult::skip();
}
Value constOp = rewriter.create<arith::ConstantIndexOp>(
op->getLoc(), cast<IntegerAttr>(reified->get<Attribute>()).getInt());
rewriter.replaceOp(op, constOp);
return WalkResult::skip();
});
return failure(result.wasInterrupted());
}
static LogicalResult testEquality(func::FuncOp funcOp) {
IRRewriter rewriter(funcOp.getContext());
WalkResult result = funcOp.walk([&](test::CompareOp op) {
auto cmpType = op.getComparisonOperator();
if (op.getCompose()) {
if (cmpType != ValueBoundsConstraintSet::EQ) {
op->emitOpError(
"comparison operator must be EQ when 'composed' is specified");
return WalkResult::interrupt();
}
FailureOr<int64_t> delta = affine::fullyComposeAndComputeConstantDelta(
op->getOperand(0), op->getOperand(1));
if (failed(delta)) {
op->emitError("could not determine equality");
} else if (*delta == 0) {
op->emitRemark("equal");
} else {
op->emitRemark("different");
}
return WalkResult::advance();
}
auto compare = [&](ValueBoundsConstraintSet::ComparisonOperator cmp) {
return ValueBoundsConstraintSet::compare(op.getLhs(), cmp, op.getRhs());
};
if (compare(cmpType)) {
op->emitRemark("true");
} else if (cmpType != ValueBoundsConstraintSet::EQ &&
compare(invertComparisonOperator(cmpType))) {
op->emitRemark("false");
} else if (cmpType == ValueBoundsConstraintSet::EQ &&
(compare(ValueBoundsConstraintSet::ComparisonOperator::LT) ||
compare(ValueBoundsConstraintSet::ComparisonOperator::GT))) {
op->emitRemark("false");
} else {
op->emitError("unknown");
}
return WalkResult::advance();
});
return failure(result.wasInterrupted());
}
void TestReifyValueBounds::runOnOperation() {
if (failed(
testReifyValueBounds(getOperation(), reifyToFuncArgs, useArithOps)))
signalPassFailure();
if (failed(testEquality(getOperation())))
signalPassFailure();
}
namespace mlir {
void registerTestAffineReifyValueBoundsPass() {
PassRegistration<TestReifyValueBounds>();
}
}