#include "mlir/Pass/Pass.h"
#include "triton/Dialect/Triton/Transforms/LoopPeeling.h"
using namespace mlir;
namespace {
/// Returns true when `forOp` carries the `__test_peel_epilogue` attribute,
/// which this test uses to select the loops it hands to the peeling utility.
bool getPeelEpilogue(scf::ForOp forOp) {
  const StringRef peelMarker = "__test_peel_epilogue";
  return forOp->hasAttr(peelMarker);
}
/// Test-only module pass that drives `mlir::triton::peelLoopEpilogue`.
///
/// Every `scf.for` in the module for which `getPeelEpilogue` returns true is
/// passed to `peelLoopEpilogue`; all other loops are left untouched.
struct TestLoopPeelingPass
    : public PassWrapper<TestLoopPeelingPass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLoopPeelingPass);

  /// Argument string identifying this pass.
  StringRef getArgument() const final { return "triton-test-loop-peeling"; }

  /// Short human-readable description of the pass.
  StringRef getDescription() const final {
    return "test the loop peeling pass";
  }

  /// Walks the module and peels the epilogue of each selected `scf.for`.
  void runOnOperation() override {
    // NOTE(review): peelLoopEpilogue is declared outside this file; it
    // presumably rewrites IR around the loop while the walk is still in
    // progress -- confirm this is safe for the walk order used here.
    getOperation().walk([](scf::ForOp forOp) {
      if (getPeelEpilogue(forOp))
        mlir::triton::peelLoopEpilogue(forOp);
    });
  }
};
}
namespace mlir {
namespace test {
void registerTestLoopPeelingPass() { PassRegistration<TestLoopPeelingPass>(); }
}
}