#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
using namespace mlir;
using namespace mlir::affine;
/// Computes a bound of kind `type` for `var` and materializes it with the
/// given builder.
///
/// The bound is first computed as an affine map plus a list of map operands
/// by `ValueBoundsConstraintSet::computeBound`, using `stopCondition` and
/// `closedUB` as passed in by the caller. On success, the map and operands are
/// handed to `materializeComputedBound`, whose result is returned. If the
/// bound cannot be computed, failure is returned and the builder is not used.
FailureOr<OpFoldResult> mlir::affine::reifyValueBound(
    OpBuilder &b, Location loc, presburger::BoundType type,
    const ValueBoundsConstraintSet::Variable &var,
    ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
  // Both are filled in by computeBound().
  ValueDimList boundOperands;
  AffineMap computedMap;

  LogicalResult status = ValueBoundsConstraintSet::computeBound(
      computedMap, boundOperands, type, var, stopCondition, closedUB);
  if (succeeded(status))
    return affine::materializeComputedBound(b, loc, computedMap,
                                            boundOperands);
  return failure();
}
/// Turns a computed bound (`boundMap` applied to `mapOperands`) into an
/// `OpFoldResult`.
///
/// Each entry of `mapOperands` is a (value, optional dim) pair:
///  * without a dim, the value itself is used as a map operand (asserted to
///    be of index type);
///  * with a dim, a `tensor.dim` or `memref.dim` op is created at `loc` for
///    that dimension (asserted to be a dynamic dimension) and its result is
///    used as the map operand.
///
/// After canonicalizing the map and operands, the result is:
///  * an index attribute if the map is a single constant,
///  * the corresponding operand if the first map result is a plain dim or
///    symbol expression,
///  * otherwise the result of a newly created `affine.apply` op.
OpFoldResult affine::materializeComputedBound(
    OpBuilder &b, Location loc, AffineMap boundMap,
    ArrayRef<std::pair<Value, std::optional<int64_t>>> mapOperands) {
  // Build the SSA operand list, creating dim ops where a dimension is named.
  SmallVector<Value> operands;
  for (const auto &entry : mapOperands) {
    Value source = entry.first;
    Type sourceType = source.getType();

    if (!entry.second) {
      // No dimension given: the value is used directly.
      assert(sourceType.isIndex() && "expected index type");
      operands.push_back(source);
      continue;
    }

    int64_t dimIdx = *entry.second;
    assert(cast<ShapedType>(sourceType).isDynamicDim(dimIdx) &&
           "expected dynamic dim");

    Value dimSize;
    if (isa<RankedTensorType>(sourceType))
      dimSize = b.create<tensor::DimOp>(loc, source, dimIdx).getResult();
    else if (isa<MemRefType>(sourceType))
      dimSize = b.create<memref::DimOp>(loc, source, dimIdx).getResult();
    else
      llvm_unreachable("cannot generate DimOp for unsupported shaped type");
    operands.push_back(dimSize);
  }

  // Canonicalize the map together with its operands before inspecting it.
  affine::canonicalizeMapAndOperands(&boundMap, &operands);

  // Constant bound: no op is needed, return an index attribute.
  if (boundMap.isSingleConstant())
    return static_cast<OpFoldResult>(
        b.getIndexAttr(boundMap.getSingleConstantResult()));

  // Bound is exactly one of the operands: return that operand directly.
  // Symbol operands are stored after all dim operands.
  AffineExpr resultExpr = boundMap.getResult(0);
  if (auto dimExpr = dyn_cast<AffineDimExpr>(resultExpr))
    return static_cast<OpFoldResult>(operands[dimExpr.getPosition()]);
  if (auto symbolExpr = dyn_cast<AffineSymbolExpr>(resultExpr))
    return static_cast<OpFoldResult>(
        operands[boundMap.getNumDims() + symbolExpr.getPosition()]);

  // General case: materialize the bound with an affine.apply op.
  auto applyOp = b.create<affine::AffineApplyOp>(loc, boundMap, operands);
  return static_cast<OpFoldResult>(applyOp.getResult());
}
/// Reifies a bound of kind `type` for dimension `dim` of the shaped `value`.
///
/// This forwards to `reifyValueBound` with the variable `{value, dim}`. When
/// the caller supplies no `stopCondition`, a default one is used that returns
/// true for every SSA value other than `value` itself.
FailureOr<OpFoldResult> mlir::affine::reifyShapedValueDimBound(
    OpBuilder &b, Location loc, presburger::BoundType type, Value value,
    int64_t dim, ValueBoundsConstraintSet::StopConditionFn stopCondition,
    bool closedUB) {
  // Default stop condition: true for any value that is not `value`.
  auto stopAtOtherValues = [value](Value candidate,
                                   std::optional<int64_t> /*candidateDim*/,
                                   ValueBoundsConstraintSet & /*cstr*/) {
    return !(candidate == value);
  };

  // Fall back to the default only if the caller did not provide a condition.
  if (!stopCondition)
    stopCondition = stopAtOtherValues;
  return reifyValueBound(b, loc, type, {value, dim}, stopCondition, closedUB);
}
/// Reifies a bound of kind `type` for `value`, passed to `reifyValueBound`
/// without a dimension.
///
/// When the caller supplies no `stopCondition`, a default one is used that
/// returns true for every SSA value other than `value` itself.
FailureOr<OpFoldResult> mlir::affine::reifyIndexValueBound(
    OpBuilder &b, Location loc, presburger::BoundType type, Value value,
    ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
  // Default stop condition: true for any value that is not `value`.
  auto stopAtOtherValues = [value](Value candidate,
                                   std::optional<int64_t> /*candidateDim*/,
                                   ValueBoundsConstraintSet & /*cstr*/) {
    return !(candidate == value);
  };

  // Fall back to the default only if the caller did not provide a condition.
  if (!stopCondition)
    stopCondition = stopAtOtherValues;
  return reifyValueBound(b, loc, type, value, stopCondition, closedUB);
}