#include "mlir/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
using namespace mlir;
namespace mlir {
namespace scf {
namespace {
/// ValueBoundsOpInterface external model for scf.for.
///
/// Supplies bounds for the induction variable and for loop-carried values
/// (region iter_args and the corresponding loop results).
struct ForOpInterface
    : public ValueBoundsOpInterface::ExternalModel<ForOpInterface, ForOp> {
  /// Bound a loop-carried `value` (a region iter_arg or a loop result), or its
  /// dimension `dim` when set.
  ///
  /// If one iteration provably yields a value (dimension size) equal to the
  /// region iter_arg it received, the value never changes across iterations.
  /// It is then constrained to equal the corresponding init_arg (or that
  /// init_arg's dimension `dim`). Otherwise nothing is added. For example, a
  /// body that only inserts into its tensor iter_arg keeps the dimension
  /// sizes of the tensor passed as init_arg.
  static void populateIterArgBounds(scf::ForOp forOp, Value value,
                                    std::optional<int64_t> dim,
                                    ValueBoundsConstraintSet &cstr) {
    // Position of `value` among the loop-carried values. Block arguments are
    // offset by the induction variable(s), which precede the iter_args.
    auto blockArg = llvm::dyn_cast<BlockArgument>(value);
    const int64_t idx =
        blockArg ? blockArg.getArgNumber() - forOp.getNumInductionVars()
                 : llvm::cast<OpResult>(value).getResultNumber();

    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    Value yielded = yieldOp.getOperand(idx);
    Value regionArg = forOp.getRegionIterArg(idx);
    Value init = forOp.getInitArgs()[idx];

    // The value is loop-invariant only if what is yielded equals what came in.
    const bool unchangedPerIteration = cstr.populateAndCompare(
        {yielded, dim}, ValueBoundsConstraintSet::ComparisonOperator::EQ,
        {regionArg, dim});
    if (!unchangedPerIteration)
      return;

    if (!dim.has_value()) {
      cstr.bound(value) == cstr.getExpr(init);
      return;
    }
    cstr.bound(value)[*dim] == cstr.getExpr(init, dim);
  }

  /// Index-typed values: the induction variable gets lb <= iv < ub (the step
  /// is not taken into account). Any other value is treated as loop-carried.
  void populateBoundsForIndexValue(Operation *op, Value value,
                                   ValueBoundsConstraintSet &cstr) const {
    auto forOp = cast<ForOp>(op);
    if (value != forOp.getInductionVar()) {
      populateIterArgBounds(forOp, value, /*dim=*/std::nullopt, cstr);
      return;
    }
    cstr.bound(value) >= forOp.getLowerBound();
    cstr.bound(value) < forOp.getUpperBound();
  }

  /// Shaped values: presumably always loop-carried here, since the induction
  /// variable is index-typed (NOTE(review): confirm).
  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
                                       ValueBoundsConstraintSet &cstr) const {
    populateIterArgBounds(cast<ForOp>(op), value, dim, cstr);
  }
};
/// ValueBoundsOpInterface external model for scf.if.
///
/// Bounds each result by the two values the branches yield for it.
struct IfOpInterface
    : public ValueBoundsOpInterface::ExternalModel<IfOpInterface, IfOp> {
  /// Constrain result `value` of `ifOp`, or its dimension `dim` when set.
  ///
  /// - If the then-yielded value is provably <= the else-yielded one, the
  ///   result lies between them with the then-value as the lower bound.
  /// - If the else-yielded value is provably <= the then-yielded one, the
  ///   else-value is the lower bound instead.
  /// - If neither ordering can be proven, no constraint is added.
  /// - If both orderings hold, both sets of constraints are added.
  static void populateBounds(scf::IfOp ifOp, Value value,
                             std::optional<int64_t> dim,
                             ValueBoundsConstraintSet &cstr) {
    const unsigned int resultIdx = cast<OpResult>(value).getResultNumber();
    Value fromThen = ifOp.thenYield().getResults()[resultIdx];
    Value fromElse = ifOp.elseYield().getResults()[resultIdx];

    // NOTE(review): this builder never receives a constraint. Selecting the
    // dimension appears to serve only as a validity check of (`value`, `dim`)
    // -- confirm before removing.
    auto dimCheck = cstr.bound(value);
    if (dim.has_value())
      dimCheck[*dim];

    // True if `lhs` (or its dimension `dim`) is provably <= `rhs`.
    auto provablyLE = [&](Value lhs, Value rhs) -> bool {
      return cstr.populateAndCompare(
          {lhs, dim}, ValueBoundsConstraintSet::ComparisonOperator::LE,
          {rhs, dim});
    };

    // Add lower <= value <= upper (or the dimension-wise equivalent).
    auto clampBetween = [&](Value lower, Value upper) {
      if (!dim.has_value()) {
        cstr.bound(value) >= lower;
        cstr.bound(value) <= upper;
        return;
      }
      cstr.bound(value)[*dim] >= cstr.getExpr(lower, dim);
      cstr.bound(value)[*dim] <= cstr.getExpr(upper, dim);
    };

    if (provablyLE(fromThen, fromElse))
      clampBetween(fromThen, fromElse);
    if (provablyLE(fromElse, fromThen))
      clampBetween(fromElse, fromThen);
  }

  /// Index-typed results: bound the value itself.
  void populateBoundsForIndexValue(Operation *op, Value value,
                                   ValueBoundsConstraintSet &cstr) const {
    auto ifOp = cast<IfOp>(op);
    populateBounds(ifOp, value, /*dim=*/std::nullopt, cstr);
  }

  /// Shaped results: bound dimension `dim` of the value.
  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
                                       ValueBoundsConstraintSet &cstr) const {
    auto ifOp = cast<IfOp>(op);
    populateBounds(ifOp, value, dim, cstr);
  }
};
}
}
}
void mlir::scf::registerValueBoundsOpInterfaceExternalModels(
DialectRegistry ®istry) {
registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
scf::ForOp::attachInterface<scf::ForOpInterface>(*ctx);
scf::IfOp::attachInterface<scf::IfOpInterface>(*ctx);
});
}