#include "mlir/Dialect/Affine/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Pass/Pass.h"
#define PASS_NAME "test-loop-permutation"
using namespace mlir;
using namespace mlir::affine;
namespace {
/// Test pass exercising affine::permuteLoops: the actual work lives in
/// runOnOperation(), which is defined out of line.
struct TestLoopPermutation
    : public PassWrapper<TestLoopPermutation, OperationPass<>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLoopPermutation)

  TestLoopPermutation() = default;

  // Copy constructor (presumably what PassWrapper uses to clone the pass).
  // Only the base is copied explicitly; `permList` is rebuilt from its default
  // member initializer, which registers it against the new `*this`.
  TestLoopPermutation(const TestLoopPermutation &pass) : PassWrapper(pass) {}

  /// Name under which the pass is exposed (the PASS_NAME macro).
  StringRef getArgument() const final { return PASS_NAME; }

  /// Human-readable summary of the pass.
  StringRef getDescription() const final {
    return "Tests affine loop permutation utility";
  }

  void runOnOperation() override;

private:
  // Values of the `permutation-map` option, consumed by runOnOperation() and
  // handed to affine::permuteLoops. NOTE(review): per the permuteLoops
  // contract in LoopUtils.h, entry i is the new position of loop i of the nest
  // (outermost = 0) — confirm against that header.
  ListOption<unsigned> permList{*this, "permutation-map",
                                llvm::cl::desc("Specify the loop permutation"),
                                llvm::cl::OneOrMore};
};
} // namespace
void TestLoopPermutation::runOnOperation() {
SmallVector<unsigned, 4> permMap(permList.begin(), permList.end());
SmallVector<AffineForOp, 2> forOps;
getOperation()->walk([&](AffineForOp forOp) { forOps.push_back(forOp); });
for (auto forOp : forOps) {
SmallVector<AffineForOp, 6> nest;
getPerfectlyNestedLoops(nest, forOp);
if (nest.size() >= 2 && nest.size() == permMap.size()) {
permuteLoops(nest, permMap);
}
}
}
namespace mlir {
void registerTestLoopPermutationPass() {
PassRegistration<TestLoopPermutation>();
}
}