#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/Support/MathExtras.h"
#include <algorithm>
#include <limits>
#include <utility>
namespace mlir {
/// Returns true iff `v` is non-null and folds to the constant integer 0
/// (either an integer attribute or a value produced by a constant op).
bool isZeroIndex(OpFoldResult v) {
  if (v) {
    if (std::optional<int64_t> cst = getConstantIntValue(v))
      return *cst == 0;
  }
  return false;
}
/// Splits `ranges` into three parallel vectors holding, respectively, the
/// offset, the size and the stride of every range, in input order.
std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
           SmallVector<OpFoldResult>>
getOffsetsSizesAndStrides(ArrayRef<Range> ranges) {
  SmallVector<OpFoldResult> offsets, sizes, strides;
  offsets.reserve(ranges.size());
  sizes.reserve(ranges.size());
  strides.reserve(ranges.size());
  for (const auto &[offset, size, stride] : ranges) {
    offsets.push_back(offset);
    sizes.push_back(size);
    strides.push_back(stride);
  }
  // The locals are lvalues: without std::move, make_tuple would copy all
  // three vectors instead of transferring their buffers.
  return std::make_tuple(std::move(offsets), std::move(sizes),
                         std::move(strides));
}
/// Appends `ofr` to the static/dynamic split representation.
/// An SSA value is pushed onto `dynamicVec`, and ShapedType::kDynamic is
/// recorded in `staticVec`. Otherwise `ofr` is expected to hold an
/// IntegerAttr, whose sign-extended value is recorded in `staticVec`.
void dispatchIndexOpFoldResult(OpFoldResult ofr,
                               SmallVectorImpl<Value> &dynamicVec,
                               SmallVectorImpl<int64_t> &staticVec) {
  if (auto dynamicValue = llvm::dyn_cast_if_present<Value>(ofr)) {
    dynamicVec.push_back(dynamicValue);
    staticVec.push_back(ShapedType::kDynamic);
    return;
  }
  // Non-value entries must be integer attributes (cast asserts otherwise).
  auto intAttr = cast<IntegerAttr>(ofr.get<Attribute>());
  staticVec.push_back(intAttr.getValue().getSExtValue());
}
/// Applies dispatchIndexOpFoldResult to every entry of `ofrs`, in order,
/// accumulating into `dynamicVec` and `staticVec`.
void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
                                SmallVectorImpl<Value> &dynamicVec,
                                SmallVectorImpl<int64_t> &staticVec) {
  llvm::for_each(ofrs, [&](OpFoldResult entry) {
    dispatchIndexOpFoldResult(entry, dynamicVec, staticVec);
  });
}
/// Converts `val` into an OpFoldResult. If `val` is produced by a constant-like
/// op, the result holds the constant attribute; otherwise it holds `val`.
/// A null `val` yields a default-constructed (null) OpFoldResult.
OpFoldResult getAsOpFoldResult(Value val) {
  if (!val)
    return OpFoldResult();
  Attribute constAttr;
  if (!matchPattern(val, m_Constant(&constAttr)))
    return val;
  return constAttr;
}
/// Element-wise getAsOpFoldResult over `values`, preserving order.
SmallVector<OpFoldResult> getAsOpFoldResult(ValueRange values) {
  SmallVector<OpFoldResult> folded;
  folded.reserve(values.size());
  for (Value v : values)
    folded.push_back(getAsOpFoldResult(v));
  return folded;
}
/// Wraps every attribute of `arrayAttr` in an OpFoldResult, preserving order.
SmallVector<OpFoldResult> getAsOpFoldResult(ArrayAttr arrayAttr) {
  return llvm::to_vector(llvm::map_range(
      arrayAttr, [](Attribute element) -> OpFoldResult { return element; }));
}
/// Returns an OpFoldResult holding `val` as an `index`-typed IntegerAttr.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val) {
  Type indexType = IndexType::get(ctx);
  return IntegerAttr::get(indexType, val);
}
/// Element-wise getAsIndexOpFoldResult: one `index` IntegerAttr per entry of
/// `values`, preserving order.
SmallVector<OpFoldResult> getAsIndexOpFoldResult(MLIRContext *ctx,
                                                 ArrayRef<int64_t> values) {
  SmallVector<OpFoldResult> attrs;
  attrs.reserve(values.size());
  for (int64_t v : values)
    attrs.push_back(getAsIndexOpFoldResult(ctx, v));
  return attrs;
}
/// Returns the sign-extended integer held by `ofr`, if it is statically known.
/// This covers an IntegerAttr, or a Value produced by a constant-integer op.
/// Any other attribute, a non-constant value, or a null `ofr` gives nullopt.
std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
  // Attribute case: only IntegerAttr carries an integer payload.
  if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
    if (auto intAttr = dyn_cast<IntegerAttr>(attr))
      return intAttr.getValue().getSExtValue();
    return std::nullopt;
  }
  // Value case: look through a defining constant-integer op.
  if (auto val = llvm::dyn_cast_if_present<Value>(ofr)) {
    APSInt intVal;
    if (matchPattern(val, m_ConstantInt(&intVal)))
      return intVal.getSExtValue();
  }
  return std::nullopt;
}
/// Returns the constant integer of every entry of `ofrs`, in order. Returns
/// nullopt as soon as one entry is not a statically known integer.
std::optional<SmallVector<int64_t>>
getConstantIntValues(ArrayRef<OpFoldResult> ofrs) {
  SmallVector<int64_t> constants;
  constants.reserve(ofrs.size());
  for (OpFoldResult ofr : ofrs) {
    std::optional<int64_t> cst = getConstantIntValue(ofr);
    if (!cst)
      return std::nullopt;
    constants.push_back(*cst);
  }
  return constants;
}
/// Returns true iff `ofr` is a statically known integer equal to `value`.
bool isConstantIntValue(OpFoldResult ofr, int64_t value) {
  if (std::optional<int64_t> cst = getConstantIntValue(ofr))
    return *cst == value;
  return false;
}
/// Returns true if `ofr1` and `ofr2` are provably equal. That holds when both
/// are constant integers with the same value, or when both hold the very same
/// non-null SSA value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2) {
  std::optional<int64_t> lhsCst = getConstantIntValue(ofr1);
  std::optional<int64_t> rhsCst = getConstantIntValue(ofr2);
  if (lhsCst && rhsCst && *lhsCst == *rhsCst)
    return true;
  // Fall back to SSA identity.
  Value lhsVal = llvm::dyn_cast_if_present<Value>(ofr1);
  if (!lhsVal)
    return false;
  return lhsVal == llvm::dyn_cast_if_present<Value>(ofr2);
}
/// Returns true if `ofrs1` and `ofrs2` have the same length and every pair of
/// corresponding entries satisfies isEqualConstantIntOrValue.
bool isEqualConstantIntOrValueArray(ArrayRef<OpFoldResult> ofrs1,
                                    ArrayRef<OpFoldResult> ofrs2) {
  if (ofrs1.size() != ofrs2.size())
    return false;
  return llvm::all_of(llvm::seq<size_t>(0, ofrs1.size()), [&](size_t i) {
    return isEqualConstantIntOrValue(ofrs1[i], ofrs2[i]);
  });
}
/// Rebuilds a mixed static/dynamic list from its split representation.
/// Each ShapedType::kDynamic entry of `staticValues` is replaced by the next
/// value of `dynamicValues`. Every other entry becomes an i64 IntegerAttr
/// built with `b`.
SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
                                         ValueRange dynamicValues, Builder &b) {
  SmallVector<OpFoldResult> mixed;
  mixed.reserve(staticValues.size());
  // Position of the next dynamic value to consume.
  unsigned nextDynamic = 0;
  for (int64_t staticValue : staticValues) {
    if (ShapedType::isDynamic(staticValue))
      mixed.push_back(OpFoldResult{dynamicValues[nextDynamic++]});
    else
      mixed.push_back(OpFoldResult{b.getI64IntegerAttr(staticValue)});
  }
  return mixed;
}
/// Splits `mixedValues` into its static/dynamic representation.
/// An attribute entry (expected to be an IntegerAttr) contributes its integer
/// to the static vector. An SSA value contributes ShapedType::kDynamic to the
/// static vector and the value itself to the dynamic vector.
std::pair<SmallVector<int64_t>, SmallVector<Value>>
decomposeMixedValues(const SmallVectorImpl<OpFoldResult> &mixedValues) {
  SmallVector<int64_t> staticValues;
  SmallVector<Value> dynamicValues;
  // Every entry produces exactly one static slot.
  staticValues.reserve(mixedValues.size());
  for (const OpFoldResult &it : mixedValues) {
    if (it.is<Attribute>()) {
      staticValues.push_back(cast<IntegerAttr>(it.get<Attribute>()).getInt());
    } else {
      staticValues.push_back(ShapedType::kDynamic);
      dynamicValues.push_back(it.get<Value>());
    }
  }
  // Move the locals into the result; `return {a, b}` with lvalues copies.
  return {std::move(staticValues), std::move(dynamicValues)};
}
/// Returns a copy of `values` permuted so that the paired `keys` appear in
/// the order given by `compare`, which must be a strict weak ordering.
///
/// If `keys` is empty, `values` is returned unchanged. Otherwise `keys` and
/// `values` must have the same length. The sort is stable: values whose keys
/// compare equivalent keep their relative input order, so the result is fully
/// determined by the inputs.
template <typename K, typename V>
static SmallVector<V>
getValuesSortedByKeyImpl(ArrayRef<K> keys, ArrayRef<V> values,
                         llvm::function_ref<bool(K, K)> compare) {
  if (keys.empty())
    return SmallVector<V>{values};
  assert(keys.size() == values.size() && "unexpected mismatching sizes");
  auto indices = llvm::to_vector(llvm::seq<int64_t>(0, values.size()));
  // std::sort is not stable: with tied keys it would leave the relative order
  // of the corresponding values unspecified.
  std::stable_sort(indices.begin(), indices.end(), [&](int64_t i, int64_t j) {
    return compare(keys[i], keys[j]);
  });
  SmallVector<V> res;
  res.reserve(values.size());
  for (int64_t idx : indices)
    res.push_back(values[idx]);
  return res;
}
/// Returns `values` permuted so that the paired `keys` are ordered by
/// `compare`; `values` is returned as-is when `keys` is empty.
SmallVector<Value>
getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<Value> values,
                     llvm::function_ref<bool(Attribute, Attribute)> compare) {
  return getValuesSortedByKeyImpl<Attribute, Value>(keys, values, compare);
}
/// Returns `values` permuted so that the paired `keys` are ordered by
/// `compare`; `values` is returned as-is when `keys` is empty.
SmallVector<OpFoldResult>
getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<OpFoldResult> values,
                     llvm::function_ref<bool(Attribute, Attribute)> compare) {
  return getValuesSortedByKeyImpl<Attribute, OpFoldResult>(keys, values,
                                                           compare);
}
/// Returns `values` permuted so that the paired `keys` are ordered by
/// `compare`; `values` is returned as-is when `keys` is empty.
SmallVector<int64_t>
getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values,
                     llvm::function_ref<bool(Attribute, Attribute)> compare) {
  return getValuesSortedByKeyImpl<Attribute, int64_t>(keys, values, compare);
}
/// Returns the statically known number of iterations of a loop running from
/// `lb` towards `ub` in increments of `step`. The count is
/// ceil((ub - lb) / step), clamped at zero.
///
/// Returns 0 when `lb` and `ub` are the same fold result (constant or not),
/// and when the step points away from `ub`, since such a loop never executes.
/// Returns std::nullopt in these cases:
/// - a bound or the step is not a constant integer;
/// - the step is zero;
/// - the computation would overflow int64_t.
std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
                                         OpFoldResult step) {
  if (lb == ub)
    return 0;

  std::optional<int64_t> lbConstant = getConstantIntValue(lb);
  if (!lbConstant)
    return std::nullopt;
  std::optional<int64_t> ubConstant = getConstantIntValue(ub);
  if (!ubConstant)
    return std::nullopt;
  std::optional<int64_t> stepConstant = getConstantIntValue(step);
  // A zero step would make divideCeilSigned divide by zero.
  if (!stepConstant || *stepConstant == 0)
    return std::nullopt;

  // `ub - lb` may overflow for extreme bounds; signed overflow is UB.
  int64_t distance = 0;
  if (llvm::SubOverflow(*ubConstant, *lbConstant, distance))
    return std::nullopt;

  // Nothing to do, or the step moves away from `ub`: zero iterations rather
  // than a negative count.
  if (distance == 0 || (distance > 0) != (*stepConstant > 0))
    return 0;

  // With matching signs, the only remaining overflow is INT64_MIN / -1.
  if (distance == std::numeric_limits<int64_t>::min() && *stepConstant == -1)
    return std::nullopt;

  return llvm::divideCeilSigned(distance, *stepConstant);
}
/// Returns true if no entry of `sizesOrOffsets` is a static negative number.
/// Dynamic entries (ShapedType::kDynamic) are always accepted.
bool hasValidSizesOffsets(SmallVector<int64_t> sizesOrOffsets) {
  return llvm::all_of(sizesOrOffsets, [](int64_t entry) {
    return ShapedType::isDynamic(entry) || entry >= 0;
  });
}
/// Returns true if no entry of `strides` is a static zero.
/// Dynamic entries (ShapedType::kDynamic) are always accepted.
bool hasValidStrides(SmallVector<int64_t> strides) {
  return llvm::all_of(strides, [](int64_t entry) {
    return ShapedType::isDynamic(entry) || entry != 0;
  });
}
/// Replaces every SSA value in `ofrs` that is defined by an integer constant
/// with that constant's IntegerAttr. Entries that are already attributes are
/// left alone.
///
/// If `onlyNonNegative` is set, negative constants stay as SSA values. If
/// `onlyNonZero` is set, zero constants stay as SSA values. Constants whose
/// folded attribute is not an IntegerAttr are never folded; static entries
/// are later `cast<IntegerAttr>`ed (see dispatchIndexOpFoldResult).
///
/// Returns success if at least one entry was folded, failure otherwise.
LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
                                   bool onlyNonNegative, bool onlyNonZero) {
  bool valuesChanged = false;
  for (OpFoldResult &ofr : ofrs) {
    if (ofr.is<Attribute>())
      continue;
    // Binding to IntegerAttr makes m_Constant reject non-integer constants.
    // Binding to a plain Attribute would let `*getConstantIntValue(attr)`
    // dereference an empty optional for such constants.
    IntegerAttr intAttr;
    if (!matchPattern(ofr.get<Value>(), m_Constant(&intAttr)))
      continue;
    int64_t value = intAttr.getValue().getSExtValue();
    if (onlyNonNegative && value < 0)
      continue;
    if (onlyNonZero && value == 0)
      continue;
    ofr = intAttr;
    valuesChanged = true;
  }
  return success(valuesChanged);
}
/// Folds constant SSA entries of an offset or size list into attributes via
/// foldDynamicIndexList; negative constants are not folded.
LogicalResult
foldDynamicOffsetSizeList(SmallVectorImpl<OpFoldResult> &offsetsOrSizes) {
  return foldDynamicIndexList(offsetsOrSizes, /*onlyNonNegative=*/true,
                              /*onlyNonZero=*/false);
}
/// Folds constant SSA entries of a stride list into attributes via
/// foldDynamicIndexList; zero constants are not folded.
LogicalResult foldDynamicStrideList(SmallVectorImpl<OpFoldResult> &strides) {
  return foldDynamicIndexList(strides, /*onlyNonNegative=*/false,
                              /*onlyNonZero=*/true);
}
}