#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <optional>
using namespace mlir;
using namespace mlir::linalg;
namespace {
/// Base class for patterns that fold a linalg op whose DPS inputs are all
/// dense constants into a single `arith.constant` holding the result.
///
/// CRTP: `ConcreteType` must provide
///   - `bool matchIndexingMaps(LinalgOp) const`, an extra filter on the op's
///     indexing maps, and
///   - `RegionComputationFn getRegionComputeFn(LinalgOp) const`, returning the
///     per-element computation performed by the op's region, or a null
///     function when the region is not recognized.
///
/// Only ops with pure tensor semantics, exactly one statically-shaped init,
/// one shared int/float element type across all operands, and permutation
/// indexing maps are handled. The payload must not read the init value.
template <typename ConcreteType>
class FoldConstantBase : public OpInterfaceRewritePattern<LinalgOp> {
public:
  /// Result of one element computation. The member matching the folded
  /// element type (int or float) is the one read by matchAndRewrite.
  struct APIntOrFloat {
    std::optional<APInt> apInt;
    std::optional<APFloat> apFloat;
  };
  /// Per-element inputs, one entry per DPS input. Only the vector matching
  /// the folded element type is populated.
  struct APIntOrFloatArray {
    SmallVector<APInt> apInts;
    SmallVector<APFloat> apFloats;
  };
  /// Computes one output element from the corresponding input elements.
  using RegionComputationFn =
      std::function<APIntOrFloat(const APIntOrFloatArray &)>;

  /// `controlFn` is invoked on every DPS input operand once the op is known
  /// to be foldable; folding is abandoned if it returns false for any.
  FoldConstantBase(MLIRContext *context, const ControlFusionFn &controlFn,
                   PatternBenefit benefit = 1)
      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
        controlFn(controlFn) {}

  LogicalResult matchAndRewrite(LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    if (!linalgOp.hasPureTensorSemantics())
      return failure();
    // Only ops producing a single output are supported.
    if (linalgOp.getNumDpsInits() != 1)
      return failure();
    // A constant can only be materialized for a static shape.
    auto outputType = dyn_cast<ShapedType>(linalgOp->getResultTypes().front());
    if (!outputType || !outputType.hasStaticShape())
      return failure();
    if (!llvm::all_of(linalgOp.getDpsInputs(), [](Value input) {
          return isa<ShapedType>(input.getType());
        }))
      return failure();
    // All operands (inputs and init) must share one element type.
    auto getOperandElementType = [](Value value) {
      return cast<ShapedType>(value.getType()).getElementType();
    };
    if (!llvm::all_equal(
            llvm::map_range(linalgOp->getOperands(), getOperandElementType)))
      return failure();
    auto elementType = outputType.getElementType();
    if (!elementType.isIntOrFloat())
      return failure();

    // Query the indexing maps once and reuse them below.
    SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
    // Permutation maps let the element shuffling be done with plain integer
    // index arithmetic at compile time.
    if (!llvm::all_of(indexingMaps,
                      [](AffineMap map) { return map.isPermutation(); }))
      return failure();
    // The folded value must not depend on the (unknown) init contents.
    for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
      if (linalgOp.payloadUsesValueFromOperand(&operand))
        return failure();
    }
    if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(linalgOp))
      return failure();
    RegionComputationFn computeFn =
        static_cast<const ConcreteType *>(this)->getRegionComputeFn(linalgOp);
    if (!computeFn)
      return failure();

    // Every input must be a dense int/fp constant.
    int numInputs = linalgOp.getNumDpsInputs();
    SmallVector<DenseIntOrFPElementsAttr> inputValues(numInputs);
    for (const auto &en : llvm::enumerate(linalgOp.getDpsInputOperands())) {
      if (!matchPattern(en.value()->get(),
                        m_Constant(&inputValues[en.index()])))
        return failure();
    }
    // The op is foldable; let the caller-supplied policy veto it.
    for (OpOperand *operand : linalgOp.getDpsInputOperands()) {
      if (!controlFn(operand))
        return failure();
    }

    SmallVector<int64_t, 4> loopBounds = linalgOp.computeStaticLoopSizes();
    const int64_t numLoops = static_cast<int64_t>(loopBounds.size());
    int64_t numElements = outputType.getNumElements();

    // Accumulate results as APInt/APFloat instead of Attributes, which would
    // be uniqued in (and outlive us inside) the MLIRContext.
    SmallVector<APInt> intOutputValues;
    SmallVector<APFloat> fpOutputValues;
    if (isa<FloatType>(elementType))
      fpOutputValues.resize(numElements, APFloat(0.f));
    else
      intOutputValues.resize(numElements);

    // Loop-dimension position of each result of a permutation map.
    auto getDimPositions = [](AffineMap map) {
      SmallVector<unsigned> dims;
      dims.reserve(map.getNumResults());
      for (AffineExpr result : map.getResults())
        dims.push_back(cast<AffineDimExpr>(result).getPosition());
      return dims;
    };

    SmallVector<SmallVector<unsigned>> inputDims;
    for (int i = 0; i < numInputs; ++i)
      inputDims.push_back(getDimPositions(indexingMaps[i]));
    auto outputDims = getDimPositions(indexingMaps.back());
    auto outputShape = outputType.getShape();
    const int64_t outputRank = outputType.getRank();

    // Scratch buffers for index (de)linearization; every entry is rewritten
    // on each call to computeRemappedLinearIndex.
    SmallVector<uint64_t> indices(loopBounds.size(), 0);
    SmallVector<uint64_t> dstIndices(loopBounds.size(), 0);
    SmallVector<SmallVector<uint64_t>> srcIndices(
        numInputs, SmallVector<uint64_t>(loopBounds.size(), 0));
    SmallVector<uint64_t> srcLinearIndices(numInputs, 0);
    uint64_t dstLinearIndex = 0;

    APIntOrFloatArray computeFnInputs;

    auto inputShapes = llvm::to_vector<4>(
        llvm::map_range(linalgOp.getDpsInputs(), [](Value value) {
          return cast<ShapedType>(value.getType()).getShape();
        }));

    // Maps the row-major iteration number `linearIndex` to the row-major
    // element offsets of every input (`srcLinearIndices`) and of the output
    // (`dstLinearIndex`), updating the scratch buffers above in place.
    auto computeRemappedLinearIndex = [&](int64_t linearIndex) {
      // Delinearize into loop iteration indices.
      int64_t remaining = linearIndex;
      for (int64_t dim = numLoops - 1; dim >= 0; --dim) {
        indices[dim] = remaining % loopBounds[dim];
        remaining /= loopBounds[dim];
      }
      // Permute loop indices into per-operand tensor indices.
      for (int64_t dim = numLoops - 1; dim >= 0; --dim) {
        for (int i = 0; i < numInputs; ++i)
          srcIndices[i][dim] = indices[inputDims[i][dim]];
        dstIndices[dim] = indices[outputDims[dim]];
      }
      // Relinearize. Starting from 0 rather than from `.front()` keeps this
      // valid for rank-0 tensors (empty index vectors, single element at 0).
      dstLinearIndex = 0;
      for (int i = 0; i < numInputs; ++i)
        srcLinearIndices[i] = 0;
      for (int64_t dim = 0; dim < outputRank; ++dim) {
        dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
        for (int i = 0; i < numInputs; ++i)
          srcLinearIndices[i] =
              srcLinearIndices[i] * inputShapes[i][dim] + srcIndices[i][dim];
      }
    };

    bool isFloat = isa<FloatType>(elementType);
    if (isFloat) {
      SmallVector<DenseElementsAttr::iterator_range<APFloat>> inFpRanges;
      for (int i = 0; i < numInputs; ++i)
        inFpRanges.push_back(inputValues[i].getValues<APFloat>());
      computeFnInputs.apFloats.resize(numInputs, APFloat(0.f));
      // Rank is not known statically, so walk [0, numElements) and
      // delinearize each iteration number.
      for (int64_t linearIndex = 0; linearIndex < numElements; ++linearIndex) {
        computeRemappedLinearIndex(linearIndex);
        for (int i = 0; i < numInputs; ++i)
          computeFnInputs.apFloats[i] = inFpRanges[i][srcLinearIndices[i]];
        fpOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apFloat;
      }
    } else {
      SmallVector<DenseElementsAttr::iterator_range<APInt>> inIntRanges;
      for (int i = 0; i < numInputs; ++i)
        inIntRanges.push_back(inputValues[i].getValues<APInt>());
      computeFnInputs.apInts.resize(numInputs);
      for (int64_t linearIndex = 0; linearIndex < numElements; ++linearIndex) {
        computeRemappedLinearIndex(linearIndex);
        for (int i = 0; i < numInputs; ++i)
          computeFnInputs.apInts[i] = inIntRanges[i][srcLinearIndices[i]];
        intOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apInt;
      }
    }

    DenseElementsAttr outputAttr =
        isFloat ? DenseElementsAttr::get(outputType, fpOutputValues)
                : DenseElementsAttr::get(outputType, intOutputValues);
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(linalgOp, outputAttr);
    return success();
  }

private:
  /// Caller-supplied policy deciding whether folding may proceed.
  ControlFusionFn controlFn;
};
/// Folds a linalg op whose region just yields its single input element and
/// whose maps are permutations (i.e. a transpose/copy of a constant) into an
/// `arith.constant` holding the permuted values.
///
/// The folding policy lives in the base class; this struct intentionally has
/// no `controlFn` member of its own so the base's policy is never shadowed.
struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
  using FoldConstantBase::FoldConstantBase;

  /// Requires exactly two indexing maps: one input and one init.
  bool matchIndexingMaps(LinalgOp linalgOp) const {
    return linalgOp.getIndexingMapsArray().size() == 2;
  }

  /// Accepts only a body consisting of a single `linalg.yield` whose yielded
  /// values are all block argument 0 of that body; returns an identity
  /// computation in that case and null otherwise.
  RegionComputationFn getRegionComputeFn(LinalgOp linalgOp) const {
    Block &body = linalgOp->getRegion(0).front();
    if (!llvm::hasSingleElement(body))
      return nullptr;
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return nullptr;
    for (Value yieldVal : yieldOp.getValues()) {
      auto yieldArg = dyn_cast<BlockArgument>(yieldVal);
      if (!yieldArg || yieldArg.getOwner() != &body)
        return nullptr;
      if (yieldArg.getArgNumber() != 0)
        return nullptr;
    }
    // No arithmetic: forward the input element unchanged.
    return [](const APIntOrFloatArray &inputs) {
      if (inputs.apFloats.empty())
        return APIntOrFloat{inputs.apInts.front(), std::nullopt};
      return APIntOrFloat{std::nullopt, inputs.apFloats.front()};
    };
  }
};
}
/// Adds to `patterns` the patterns that constant-fold linalg ops whose inputs
/// are all constants (currently: the transpose/copy folder). `controlFn` is
/// consulted for each input operand and may veto folding.
void mlir::linalg::populateConstantFoldLinalgOperations(
    RewritePatternSet &patterns, const ControlFusionFn &controlFn) {
  patterns.add<FoldConstantTranspose>(patterns.getContext(), controlFn);
}