#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include <numeric>
#include <optional>
using namespace mlir;
/// Computes the reassociation indices relating `sourceType` and `targetType`
/// by treating whichever type has the higher rank as the expanded shape and
/// the other as the collapsed shape.
///
/// Returns std::nullopt when both types have the same rank, or when
/// getReassociationIndicesForCollapse cannot find a valid grouping.
std::optional<SmallVector<ReassociationIndices>>
mlir::getReassociationIndicesForReshape(ShapedType sourceType,
                                        ShapedType targetType) {
  const auto sourceRank = sourceType.getRank();
  const auto targetRank = targetType.getRank();

  // Equal ranks: there is nothing to collapse in either direction.
  if (sourceRank == targetRank)
    return std::nullopt;

  // Always collapse from the higher-rank shape onto the lower-rank one.
  const bool sourceIsExpanded = sourceRank > targetRank;
  ArrayRef<int64_t> expandedShape =
      sourceIsExpanded ? sourceType.getShape() : targetType.getShape();
  ArrayRef<int64_t> collapsedShape =
      sourceIsExpanded ? targetType.getShape() : sourceType.getShape();
  return getReassociationIndicesForCollapse(expandedShape, collapsedShape);
}
/// Computes, for every dimension of `targetShape`, the contiguous group of
/// `sourceShape` dimensions that collapse into it.
///
/// Source dimensions are consumed left to right. For each target dimension,
/// static source dimensions are accumulated while their running product stays
/// strictly below the target extent; the next source dimension must then
/// complete the product exactly. A dynamic source dimension can only map,
/// alone, onto a dynamic target dimension. Source dimensions left over after
/// all target dimensions are matched must be 1 or dynamic and are appended to
/// the last group.
///
/// Returns std::nullopt if `sourceShape` does not have strictly more
/// dimensions than `targetShape`, or if no such grouping exists.
std::optional<SmallVector<ReassociationIndices>>
mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
                                         ArrayRef<int64_t> targetShape) {
  // Note: this also guarantees `sourceShape` is non-empty below.
  if (sourceShape.size() <= targetShape.size())
    return std::nullopt;
  unsigned sourceDim = 0;
  SmallVector<ReassociationIndices> reassociationMap;
  reassociationMap.reserve(targetShape.size());

  ReassociationIndices currIndices;
  int64_t prodOfCollapsedDims = 1;
  while (sourceDim < sourceShape.size()) {
    unsigned targetDim = reassociationMap.size();
    // Once every target dimension is mapped, stop; the remaining tail of
    // source dimensions is validated separately below.
    if (targetDim == targetShape.size())
      break;

    int64_t currTargetShape = targetShape[targetDim];
    // Accumulate static source dims while the product is still too small.
    // The loop is bounded by `size() - 1` so that `sourceDim` always remains a
    // valid index for the checks that follow; without this bound, a source
    // shape whose total product is smaller than the target extent (e.g.
    // [2, 3] -> [10]) would run `sourceDim` past the end and read out of
    // bounds. With the bound, such inputs fail the exact-product check.
    while (sourceDim < (sourceShape.size() - 1) &&
           sourceShape[sourceDim] != ShapedType::kDynamic &&
           prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) {
      prodOfCollapsedDims *= sourceShape[sourceDim];
      currIndices.push_back(sourceDim++);
    }

    // A dynamic source dim requires a dynamic target dim, and no static dims
    // with product other than 1 may have been accumulated before it.
    if (sourceShape[sourceDim] == ShapedType::kDynamic &&
        (currTargetShape != ShapedType::kDynamic || prodOfCollapsedDims != 1))
      return std::nullopt;

    // A dynamic target dim requires the current source dim to be dynamic.
    if (currTargetShape == ShapedType::kDynamic &&
        sourceShape[sourceDim] != ShapedType::kDynamic)
      return std::nullopt;

    // The accumulated product, completed by the current source dim, must match
    // the target extent exactly.
    if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
      return std::nullopt;

    // Close the current group and start a fresh one.
    currIndices.push_back(sourceDim++);
    reassociationMap.emplace_back(ReassociationIndices{});
    std::swap(reassociationMap.back(), currIndices);
    prodOfCollapsedDims = 1;
  }
  // Every target dimension must have received a group.
  if (reassociationMap.size() != targetShape.size())
    return std::nullopt;
  // Any remaining source dimensions must be 1 or dynamic; fold them into the
  // last group.
  for (; sourceDim < sourceShape.size(); sourceDim++) {
    if (sourceShape[sourceDim] != ShapedType::kDynamic &&
        sourceShape[sourceDim] != 1)
      return std::nullopt;
    // The map is empty when `targetShape` has no dimensions.
    if (!reassociationMap.empty())
      reassociationMap.back().push_back(sourceDim);
  }
  return reassociationMap;
}
/// Composes two reassociation lists into a single one.
///
/// The list with more groups is treated as the "finer" one and the other as
/// the "coarser" one; each coarse group is replaced by the concatenation of
/// the fine groups it indexes.
///
/// Returns std::nullopt when both lists have the same number of groups, or
/// when the number of fine groups differs from the total number of indices in
/// the coarse list. Returns an empty list when the coarse list is empty.
/// `context` is not used by this function.
std::optional<SmallVector<ReassociationIndices>>
mlir::composeReassociationIndices(
    ArrayRef<ReassociationIndices> producerReassociations,
    ArrayRef<ReassociationIndices> consumerReassociations,
    MLIRContext *context) {
  // Lists of equal size cannot be composed.
  if (producerReassociations.size() == consumerReassociations.size())
    return std::nullopt;

  // Name the two lists by granularity instead of by producer/consumer role.
  ArrayRef<ReassociationIndices> finer = producerReassociations;
  ArrayRef<ReassociationIndices> coarser = consumerReassociations;
  if (finer.size() < coarser.size())
    std::swap(finer, coarser);

  SmallVector<ReassociationIndices> composed;
  // Corner case: nothing to compose onto, return an empty reassociation.
  if (coarser.empty())
    return composed;

  // The coarse list must reference exactly as many entries as the fine list
  // has groups.
  size_t numCoarseIndices = std::accumulate(
      coarser.begin(), coarser.end(), 0,
      [](size_t total, ReassociationIndicesRef group) {
        return total + group.size();
      });
  if (finer.size() != numCoarseIndices)
    return std::nullopt;

  for (ReassociationIndicesRef coarseGroup : coarser) {
    ReassociationIndices merged;
    for (int64_t fineGroupIdx : coarseGroup)
      llvm::append_range(merged, finer[fineGroupIdx]);
    composed.push_back(std::move(merged));
  }
  return composed;
}
/// Converts each group of dimension indices into a group of affine dimension
/// expressions created in `context`, preserving group and element order.
SmallVector<SmallVector<AffineExpr, 2>, 2>
mlir::convertReassociationIndicesToExprs(
    MLIRContext *context, ArrayRef<ReassociationIndices> reassociationIndices) {
  SmallVector<SmallVector<AffineExpr, 2>, 2> exprGroups;
  for (const ReassociationIndices &group : reassociationIndices) {
    SmallVector<AffineExpr, 2> dimExprs;
    dimExprs.reserve(group.size());
    for (int64_t dim : group) {
      AffineExpr dimExpr = mlir::getAffineDimExpr(dim, context);
      dimExprs.push_back(dimExpr);
    }
    exprGroups.push_back(std::move(dimExprs));
  }
  return exprGroups;
}
/// Walks every expression in `exprArrays` and returns the largest position
/// found among sub-expressions of kind `AffineExprTy`. Returns 0 when no such
/// sub-expression exists.
template <typename AffineExprTy>
unsigned getMaxPosOfType(ArrayRef<ReassociationExprs> exprArrays) {
  unsigned maxPos = 0;
  for (const auto &group : exprArrays) {
    for (auto expr : group) {
      expr.walk([&maxPos](AffineExpr subExpr) {
        auto typed = dyn_cast<AffineExprTy>(subExpr);
        if (!typed)
          return;
        // Keep the running maximum.
        if (maxPos < typed.getPosition())
          maxPos = typed.getPosition();
      });
    }
  }
  return maxPos;
}
/// Builds an array attribute holding one i64-array attribute per
/// reassociation group, in order, using builder `b`.
ArrayAttr mlir::getReassociationIndicesAttribute(
    OpBuilder &b, ArrayRef<ReassociationIndices> reassociation) {
  SmallVector<Attribute, 4> groupAttrs;
  groupAttrs.reserve(reassociation.size());
  for (const ReassociationIndices &group : reassociation) {
    Attribute groupAttr = cast<Attribute>(b.getI64ArrayAttr(group));
    groupAttrs.push_back(groupAttr);
  }
  return b.getArrayAttr(groupAttrs);
}
/// Converts each group of affine expressions into the group of their
/// dimension positions. Every expression is cast to AffineDimExpr, so all
/// inputs are expected to be dimension expressions.
SmallVector<ReassociationIndices, 2> mlir::convertReassociationMapsToIndices(
    ArrayRef<ReassociationExprs> reassociationExprs) {
  SmallVector<ReassociationIndices, 2> indexGroups;
  for (const ReassociationExprs &exprGroup : reassociationExprs) {
    ReassociationIndices positions;
    positions.reserve(exprGroup.size());
    for (const AffineExpr &expr : exprGroup) {
      auto dimExpr = cast<AffineDimExpr>(expr);
      positions.push_back(dimExpr.getPosition());
    }
    indexGroups.push_back(positions);
  }
  return indexGroups;
}
/// Builds one affine map per reassociation group. All maps share the same
/// number of dimensions (largest dimension position found in `reassociation`
/// plus one) and have no symbols. Each group must be non-empty, and no symbol
/// with a non-zero position may appear (both checked by assertions).
SmallVector<AffineMap, 4>
mlir::getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation) {
  // Dimension count common to every produced map.
  const unsigned numDims = getMaxPosOfType<AffineDimExpr>(reassociation) + 1;
  assert(getMaxPosOfType<AffineSymbolExpr>(reassociation) == 0 &&
         "Expected symbol-less expressions");

  SmallVector<AffineMap, 4> result;
  result.reserve(reassociation.size());
  for (const auto &exprs : reassociation) {
    assert(!exprs.empty());
    auto *ctx = exprs[0].getContext();
    AffineMap map = AffineMap::get(numDims, /*symbolCount=*/0, exprs, ctx);
    result.push_back(map);
  }
  return result;
}
/// Checks that `reassociation` is a list of symbol-less maps over a common
/// number of dimensions whose results, read in order across all maps, are
/// exactly the dimension expressions d0, d1, ..., d(n-1).
///
/// An empty list is valid. On failure, if `invalidIndex` is non-null it
/// receives the index of the offending map (or of the last map when the
/// results do not cover all dimensions).
bool mlir::isReassociationValid(ArrayRef<AffineMap> reassociation,
                                int *invalidIndex) {
  if (reassociation.empty())
    return true;

  // Records the offending map index (when requested) and reports failure.
  auto reject = [invalidIndex](size_t mapIdx) {
    if (invalidIndex)
      *invalidIndex = mapIdx;
    return false;
  };

  const unsigned numDims = reassociation[0].getNumDims();
  unsigned expectedPos = 0;
  for (size_t mapIdx = 0, numMaps = reassociation.size(); mapIdx != numMaps;
       ++mapIdx) {
    AffineMap map = reassociation[mapIdx];
    // Every map must agree on the dimension count and use no symbols.
    if (map.getNumDims() != numDims || map.getNumSymbols() != 0)
      return reject(mapIdx);

    // Results must be consecutive dimension expressions.
    for (AffineExpr result : map.getResults()) {
      auto dimExpr = dyn_cast<AffineDimExpr>(result);
      if (!dimExpr || dimExpr.getPosition() != expectedPos)
        return reject(mapIdx);
      ++expectedPos;
    }
  }

  // All dimensions must have been covered.
  if (expectedPos != numDims)
    return reject(reassociation.size() - 1);
  return true;
}
/// Verifies that `collapsedShape` is consistent with `expandedShape` under the
/// grouping given by `reassociationMaps`.
///
/// Groups are laid out back to back over `expandedShape`; only the size of
/// each group is used. For each group: if any expanded extent is dynamic the
/// collapsed extent must be dynamic; otherwise the collapsed extent must equal
/// the product of the expanded extents. The first violation is reported
/// through `emitError` and its result returned. `isExpandingReshape` is not
/// used by this function.
LogicalResult mlir::reshapeLikeShapesAreCompatible(
    function_ref<LogicalResult(const Twine &)> emitError,
    ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
    ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape) {
  unsigned groupStart = 0;
  for (size_t groupIdx = 0, numGroups = reassociationMaps.size();
       groupIdx != numGroups; ++groupIdx) {
    const size_t groupSize = reassociationMaps[groupIdx].size();

    // Summarize the expanded extents that belong to this group.
    bool hasDynamicExtent = false;
    int64_t staticProduct = 1;
    for (int64_t extent : expandedShape.slice(groupStart, groupSize)) {
      if (ShapedType::isDynamic(extent)) {
        hasDynamicExtent = true;
        continue;
      }
      staticProduct *= extent;
    }

    const int64_t collapsedExtent = collapsedShape[groupIdx];
    if (hasDynamicExtent && !ShapedType::isDynamic(collapsedExtent)) {
      return emitError(
          "expected dimension " + Twine(groupIdx) +
          " of collapsed type to be dynamic since one or more of the "
          "corresponding dimensions in the expanded type is dynamic");
    }
    if (!hasDynamicExtent && collapsedExtent != staticProduct) {
      return emitError("expected dimension " + Twine(groupIdx) +
                       " of collapsed type to be static value of " +
                       Twine(staticProduct));
    }

    groupStart += groupSize;
  }
  return success();
}
/// Returns true iff `type` is a MemRefType whose layout is not the identity.
/// Any other type yields false.
bool mlir::hasNonIdentityLayout(Type type) {
  auto memrefType = dyn_cast<MemRefType>(type);
  if (!memrefType)
    return false;
  return !memrefType.getLayout().isIdentity();
}
/// Returns a bit mask with one bit per dimension. A bit is set when the
/// corresponding slice range is not provably the full dimension, i.e. unless
/// its size equals the input extent, its stride is the constant 1, and its
/// offset is the constant 0.
///
/// `sliceParams` and `sliceInputShape` must have the same length (asserted).
llvm::SmallBitVector
mlir::getSlicedDimensions(ArrayRef<OpFoldResult> sliceInputShape,
                          ArrayRef<Range> sliceParams) {
  assert(sliceParams.size() == sliceInputShape.size() &&
         "only supports non rank-reducing case");
  llvm::SmallBitVector mask(sliceInputShape.size());
  unsigned dim = 0;
  for (const Range &range : sliceParams) {
    std::optional<int64_t> offsetConst = getConstantIntValue(range.offset);
    std::optional<int64_t> strideConst = getConstantIntValue(range.stride);

    const bool coversFullExtent =
        isEqualConstantIntOrValue(range.size, sliceInputShape[dim]);
    const bool hasUnitStride = strideConst && *strideConst == 1;
    const bool hasZeroOffset = offsetConst && *offsetConst == 0;

    // Anything short of "full extent, unit stride, zero offset" is a slice.
    mask[dim] = !(coversFullExtent && hasUnitStride && hasZeroOffset);
    ++dim;
  }
  return mask;
}
/// Returns a bit mask with one bit per reassociation group; a bit is set when
/// its group contains more than one index.
llvm::SmallBitVector mlir::getLinearizedDimensions(
    ArrayRef<ReassociationIndices> reassociationIndices) {
  const size_t numGroups = reassociationIndices.size();
  llvm::SmallBitVector linearized(numGroups);
  for (size_t groupIdx = 0; groupIdx != numGroups; ++groupIdx) {
    const bool mergesSeveralDims = reassociationIndices[groupIdx].size() > 1;
    linearized[groupIdx] = mergesSeveralDims;
  }
  return linearized;
}
/// Builds the (offset, size, stride) ranges for an extract-slice, walking the
/// reassociation groups in order:
///  - group is linearized and sliced: one range per value of the next entry
///    of `multiIndices`, with that value as offset and size/stride of 1;
///  - group is linearized but not sliced: one full range (offset 0, stride 1,
///    size taken from `collapseShapeInputShape`) per index of the group;
///  - otherwise: the group's entry of `sliceParams` is used as is.
/// Index attributes are created in `ctx`.
SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams(
    MLIRContext *ctx, ArrayRef<ValueRange> multiIndices) {
  auto oneAttr = IntegerAttr::get(IndexType::get(ctx), 1);
  auto zeroAttr = IntegerAttr::get(IndexType::get(ctx), 0);

  SmallVector<Range> ranges;
  ranges.reserve(collapseShapeInputShape.size());

  // Position of the next unused entry of `multiIndices`.
  unsigned nextMultiIndex = 0;
  for (const auto &group : llvm::enumerate(reassociationIndices)) {
    const auto groupIdx = group.index();
    const bool isSliced = slicedDimensions[groupIdx];
    const bool isLinearized = linearizedDimensions[groupIdx];

    if (isSliced && isLinearized) {
      // Unit-sized ranges positioned at the supplied multi-index values.
      for (Value v : multiIndices[nextMultiIndex++])
        ranges.push_back(Range{getAsOpFoldResult(v), oneAttr, oneAttr});
    } else if (isLinearized) {
      // Take the full extent of every dimension in the group.
      for (int64_t idx : group.value())
        ranges.push_back(Range{zeroAttr, collapseShapeInputShape[idx], oneAttr});
    } else {
      // Reuse the existing slice parameters for this group.
      ranges.push_back(sliceParams[groupIdx]);
    }
  }
  return ranges;
}
/// Builds the (offset, size, stride) ranges for an insert-slice, one per
/// entry of `linearizedDimensions`:
///  - dimension is linearized and sliced: offset is the next value of
///    `tileIndices`, size and stride are 1;
///  - otherwise: offset 0, stride 1, and the size from `sliceParams`.
/// Index attributes are created in `ctx`.
SmallVector<Range>
SliceFromCollapseHelper::getInsertSliceParams(MLIRContext *ctx,
                                              ValueRange tileIndices) {
  auto one = IntegerAttr::get(IndexType::get(ctx), 1);
  auto zero = IntegerAttr::get(IndexType::get(ctx), 0);

  SmallVector<Range> insertParams;
  insertParams.reserve(linearizedDimensions.size());

  // Position of the next unused entry of `tileIndices`.
  unsigned nextTileIndex = 0;
  for (unsigned dim = 0; dim < linearizedDimensions.size(); ++dim) {
    const bool usesTileIndex = linearizedDimensions[dim] && slicedDimensions[dim];
    if (usesTileIndex)
      insertParams.push_back(Range{tileIndices[nextTileIndex++], one, one});
    else
      insertParams.push_back(Range{zero, sliceParams[dim].size, one});
  }
  return insertParams;
}
/// Returns the single entry of `indices` whose extent in `shape` is not 1.
/// Returns std::nullopt if `indices` has fewer than two entries, if more than
/// one entry has a non-unit extent, or if every entry has extent 1.
static std::optional<int64_t> getUniqueNonUnitDim(ArrayRef<int64_t> indices,
                                                  ArrayRef<int64_t> shape) {
  // Groups with fewer than two entries are never reported.
  if (indices.size() < 2)
    return std::nullopt;

  std::optional<int64_t> nonUnitDim;
  for (int64_t dim : indices) {
    if (shape[dim] == 1)
      continue;
    // A second non-unit dimension means there is no unique one.
    if (nonUnitDim.has_value())
      return std::nullopt;
    nonUnitDim = dim;
  }
  return nonUnitDim;
}
/// For every reassociation group, records the result of getUniqueNonUnitDim
/// evaluated against the shape of `sourceType`. The output has one entry per
/// group, in order.
static SmallVector<std::optional<int64_t>> getCollapseShapeTrivialSegments(
    RankedTensorType sourceType,
    ArrayRef<ReassociationIndices> reassociationIndices) {
  SmallVector<std::optional<int64_t>> segments;
  for (const ReassociationIndices &group : reassociationIndices) {
    std::optional<int64_t> nonUnitDim =
        getUniqueNonUnitDim(group, sourceType.getShape());
    segments.push_back(nonUnitDim);
  }
  return segments;
}
/// Computes the per-group trivial segments for `sourceType` under
/// `reassociationIndices` and returns them if at least one group has a
/// value; returns failure when no group does.
static FailureOr<SmallVector<std::optional<int64_t>>>
canCollapseShapeBeSimplifiedByRankReducingSlice(
    RankedTensorType sourceType,
    ArrayRef<ReassociationIndices> reassociationIndices) {
  SmallVector<std::optional<int64_t>> segments =
      getCollapseShapeTrivialSegments(sourceType, reassociationIndices);

  // Look for at least one group that has a unique non-unit dimension.
  bool foundTrivialSegment = false;
  for (const std::optional<int64_t> &segment : segments) {
    if (segment.has_value()) {
      foundTrivialSegment = true;
      break;
    }
  }

  if (!foundTrivialSegment)
    return failure();
  return segments;
}
/// Computes the information needed to replace part of a collapse described by
/// `reassociationIndices` on `sourceType` with a rank-reducing slice.
///
/// Fails when no group has a unique non-unit dimension. Otherwise the result
/// holds the slice type, whose shape keeps only the non-unit dimension of
/// each such group and all dimensions of the other groups. It also holds the
/// new reassociation indices over the slice type, or std::nullopt when the
/// slice shape already has exactly one dimension per group.
FailureOr<CollapseShapeRankReducingSliceSimplificationInfo>
mlir::getSimplifyCollapseShapeWithRankReducingSliceInfo(
    RankedTensorType sourceType,
    ArrayRef<ReassociationIndices> reassociationIndices) {
  FailureOr<SmallVector<std::optional<int64_t>>> trivialSegments =
      canCollapseShapeBeSimplifiedByRankReducingSlice(sourceType,
                                                      reassociationIndices);
  if (failed(trivialSegments))
    return failure();

  // Shape of the rank-reducing slice.
  SmallVector<int64_t> sliceShape;
  for (const auto &[nonUnitDim, group] :
       llvm::zip(*trivialSegments, reassociationIndices)) {
    if (!nonUnitDim) {
      // Group is not simplified: keep every one of its dimensions.
      for (int64_t dim : group)
        sliceShape.push_back(sourceType.getDimSize(dim));
      continue;
    }
    // Group is simplified: keep only its non-unit dimension.
    sliceShape.push_back(sourceType.getDimSize(*nonUnitDim));
  }
  auto sliceType =
      RankedTensorType::get(sliceShape, sourceType.getElementType());

  // One slice dimension per group: no further collapse is required.
  if (sliceShape.size() == reassociationIndices.size())
    return CollapseShapeRankReducingSliceSimplificationInfo{sliceType,
                                                            std::nullopt};

  // Otherwise regroup the slice dimensions. A group closes after a single
  // dimension if it was simplified, or once it has as many dimensions as the
  // original group.
  SmallVector<ReassociationIndices> newReassociationIndices;
  SmallVector<int64_t, 2> pendingGroup;
  int64_t groupIdx = 0;
  for (int64_t dimIdx = 0; dimIdx < sliceType.getRank(); ++dimIdx) {
    pendingGroup.push_back(dimIdx);
    const bool groupComplete =
        (*trivialSegments)[groupIdx] ||
        pendingGroup.size() == reassociationIndices[groupIdx].size();
    if (!groupComplete)
      continue;
    newReassociationIndices.push_back(pendingGroup);
    pendingGroup.clear();
    ++groupIdx;
  }

  return CollapseShapeRankReducingSliceSimplificationInfo{
      sliceType, newReassociationIndices};
}
/// Computes packing metadata for a packed shape of rank `packedRank` with
/// inner dimension positions `innerDimPos`.
///
/// - `insertPositions`: for each `pos` in `innerDimPos`, `pos` plus the number
///   of entries of `innerDimPos` strictly smaller than `pos`, plus 1.
/// - `reassociations` / `outerPositions`: positions 1..packedRank are scanned;
///   at position `i`, `i - 1` is recorded as an outer position. If `i` is an
///   insert position the group is {i - 1, i} and the scan skips ahead by two;
///   otherwise the group is {i - 1} and the scan advances by one.
PackingMetadata mlir::computePackingMetadata(int64_t packedRank,
                                             ArrayRef<int64_t> innerDimPos) {
  PackingMetadata res;
  res.insertPositions.reserve(innerDimPos.size());

  // Constant shift applied to every insert position.
  const int64_t offset = 1;
  for (int64_t pos : innerDimPos) {
    int64_t numSmaller = 0;
    for (int64_t other : innerDimPos)
      if (other < pos)
        ++numSmaller;
    res.insertPositions.push_back(pos + numSmaller + offset);
  }

  DenseSet<int64_t> insertPositionSet(res.insertPositions.begin(),
                                      res.insertPositions.end());
  res.reassociations.reserve(packedRank);
  int64_t i = 1;
  while (i <= packedRank) {
    res.outerPositions.push_back(i - 1);
    if (insertPositionSet.contains(i)) {
      // Position `i` is paired with the one before it.
      res.reassociations.push_back(ReassociationIndices{i - 1, i});
      i += 2;
    } else {
      res.reassociations.push_back(ReassociationIndices{i - 1});
      i += 1;
    }
  }
  return res;
}