#include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
using namespace mlir;
namespace mlir {
namespace arith {
namespace {
/// ValueBoundsOpInterface external model for `arith.addi`.
struct AddIOpInterface
    : public ValueBoundsOpInterface::ExternalModel<AddIOpInterface, AddIOp> {
  /// Constrains the index-typed result of the addition to be exactly the sum
  /// of the affine expressions of its two operands.
  void populateBoundsForIndexValue(Operation *op, Value value,
                                   ValueBoundsConstraintSet &cstr) const {
    auto add = cast<AddIOp>(op);
    // The op has a single result; being asked about any other value is a bug.
    assert(value == add.getResult() && "invalid value");

    // Query the operands in lhs-then-rhs order.
    const AffineExpr lhsExpr = cstr.getExpr(add.getLhs());
    const AffineExpr rhsExpr = cstr.getExpr(add.getRhs());
    const AffineExpr sum = lhsExpr + rhsExpr;
    cstr.bound(value) == sum;
  }
};
/// ValueBoundsOpInterface external model for `arith.constant`.
struct ConstantOpInterface
    : public ValueBoundsOpInterface::ExternalModel<ConstantOpInterface,
                                                   ConstantOp> {
  /// Pins the index-typed result to its literal value when the constant's
  /// attribute is an IntegerAttr. For any other attribute kind, no bound is
  /// added.
  void populateBoundsForIndexValue(Operation *op, Value value,
                                   ValueBoundsConstraintSet &cstr) const {
    auto constant = cast<ConstantOp>(op);
    assert(value == constant.getResult() && "invalid value");

    auto intAttr = llvm::dyn_cast<IntegerAttr>(constant.getValue());
    if (!intAttr)
      return;
    cstr.bound(value) == intAttr.getInt();
  }
};
/// ValueBoundsOpInterface external model for `arith.subi`.
struct SubIOpInterface
    : public ValueBoundsOpInterface::ExternalModel<SubIOpInterface, SubIOp> {
  /// Constrains the index-typed result of the subtraction to be exactly
  /// `lhs - rhs`, expressed over the operands' affine expressions.
  void populateBoundsForIndexValue(Operation *op, Value value,
                                   ValueBoundsConstraintSet &cstr) const {
    auto sub = cast<SubIOp>(op);
    // The op has a single result; being asked about any other value is a bug.
    assert(value == sub.getResult() && "invalid value");

    // Query the operands in lhs-then-rhs order.
    const AffineExpr minuend = cstr.getExpr(sub.getLhs());
    const AffineExpr subtrahend = cstr.getExpr(sub.getRhs());
    const AffineExpr difference = minuend - subtrahend;
    cstr.bound(value) == difference;
  }
};
/// ValueBoundsOpInterface external model for `arith.muli`.
struct MulIOpInterface
    : public ValueBoundsOpInterface::ExternalModel<MulIOpInterface, MulIOp> {
  /// Constrains the index-typed result of the multiplication to be exactly the
  /// product of the operands' affine expressions.
  void populateBoundsForIndexValue(Operation *op, Value value,
                                   ValueBoundsConstraintSet &cstr) const {
    auto mul = cast<MulIOp>(op);
    // The op has a single result; being asked about any other value is a bug.
    assert(value == mul.getResult() && "invalid value");

    // Query the operands in lhs-then-rhs order.
    const AffineExpr lhsExpr = cstr.getExpr(mul.getLhs());
    const AffineExpr rhsExpr = cstr.getExpr(mul.getRhs());
    const AffineExpr product = lhsExpr * rhsExpr;
    cstr.bound(value) == product;
  }
};
/// ValueBoundsOpInterface external model for `arith.select`. It bounds either
/// the index-typed result itself or one dimension of a shaped result.
struct SelectOpInterface
: public ValueBoundsOpInterface::ExternalModel<SelectOpInterface,
SelectOp> {
/// Shared implementation of both interface hooks below.
///
/// `dim` is std::nullopt when bounding an index-typed result
/// (populateBoundsForIndexValue). It holds a dimension index when bounding
/// that dimension of a shaped result (populateBoundsForShapedValueDim).
/// If neither ordering of the true/false values can be proven, no bound is
/// added for the result.
static void populateBounds(SelectOp selectOp, std::optional<int64_t> dim,
ValueBoundsConstraintSet &cstr) {
Value value = selectOp.getResult();
Value condition = selectOp.getCondition();
Value trueValue = selectOp.getTrueValue();
Value falseValue = selectOp.getFalseValue();
// Shaped condition: the select is element-wise (per arith.select semantics),
// so the result's size along `dim` is equated with that of the true value,
// the false value and the condition. Afterwards, return without comparing
// values.
// NOTE(review): `*dim` is dereferenced unconditionally. This assumes the
// branch is only reached with a shaped result, i.e. via the shaped-dim hook,
// so `dim` is set. Confirm.
if (isa<ShapedType>(condition.getType())) {
cstr.bound(value)[*dim] == cstr.getExpr(trueValue, dim);
cstr.bound(value)[*dim] == cstr.getExpr(falseValue, dim);
cstr.bound(value)[*dim] == cstr.getExpr(condition, dim);
return;
}
// Scalar condition: the result is one of the two values. First populate
// constraints for both candidates so the comparisons below have something
// to work with.
cstr.populateConstraints(trueValue, dim);
cstr.populateConstraints(falseValue, dim);
// NOTE(review): the builder and the result of operator[] are discarded.
// Presumably this is kept for side effects or assertions inside
// BoundBuilder (e.g. validating `dim` against `value`). Confirm before
// removing.
auto boundsBuilder = cstr.bound(value);
if (dim)
boundsBuilder[*dim];
// If trueValue <= falseValue can be proven, then
// trueValue <= result <= falseValue.
if (cstr.compare({trueValue, dim},
ValueBoundsConstraintSet::ComparisonOperator::LE,
{falseValue, dim})) {
if (dim) {
cstr.bound(value)[*dim] >= cstr.getExpr(trueValue, dim);
cstr.bound(value)[*dim] <= cstr.getExpr(falseValue, dim);
} else {
cstr.bound(value) >= trueValue;
cstr.bound(value) <= falseValue;
}
}
// Symmetric case: if falseValue <= trueValue can be proven, then
// falseValue <= result <= trueValue. If both comparisons hold, both sets of
// bounds are added.
if (cstr.compare({falseValue, dim},
ValueBoundsConstraintSet::ComparisonOperator::LE,
{trueValue, dim})) {
if (dim) {
cstr.bound(value)[*dim] >= cstr.getExpr(falseValue, dim);
cstr.bound(value)[*dim] <= cstr.getExpr(trueValue, dim);
} else {
cstr.bound(value) >= falseValue;
cstr.bound(value) <= trueValue;
}
}
}
/// Interface hook for an index-typed select result: no dimension.
void populateBoundsForIndexValue(Operation *op, Value value,
ValueBoundsConstraintSet &cstr) const {
populateBounds(cast<SelectOp>(op), std::nullopt, cstr);
}
/// Interface hook for dimension `dim` of a shaped select result.
void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
ValueBoundsConstraintSet &cstr) const {
populateBounds(cast<SelectOp>(op), dim, cstr);
}
};
}
}
}
void mlir::arith::registerValueBoundsOpInterfaceExternalModels(
DialectRegistry ®istry) {
registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) {
arith::AddIOp::attachInterface<arith::AddIOpInterface>(*ctx);
arith::ConstantOp::attachInterface<arith::ConstantOpInterface>(*ctx);
arith::SubIOp::attachInterface<arith::SubIOpInterface>(*ctx);
arith::MulIOp::attachInterface<arith::MulIOpInterface>(*ctx);
arith::SelectOp::attachInterface<arith::SelectOpInterface>(*ctx);
});
}