#include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h"
#include "mlir/Dialect/PDL/IR/PDLOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Rewrite/PatternApplicator.h"
#include "llvm/ADT/ScopeExit.h"
using namespace mlir;
// Explicit TypeID definition for PDLMatchHooks (presumably declared with the
// matching MLIR_DECLARE_EXPLICIT_TYPE_ID in the header -- confirm there).
MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::transform::PDLMatchHooks)
// Pull in the TableGen-generated op class definitions for this extension.
#define GET_OP_CLASSES
#include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp.inc"
namespace {
/// Transform-state extension that resolves `pdl.pattern` ops by symbol name
/// inside a pattern container op, compiles them on first use, and caches the
/// compiled pattern sets for later match requests.
class PatternApplicatorExtension : public transform::TransformState::Extension {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension)

  /// Attaches to `state`; `patternContainer` is the op whose symbol table is
  /// searched for `pdl.pattern` definitions.
  explicit PatternApplicatorExtension(transform::TransformState &state,
                                      Operation *patternContainer)
      : Extension(state), patterns(patternContainer) {}

  /// Appends to `results` every op visited by `root->walk` (so `root` itself
  /// and everything nested in it) that the pattern named `patternName`
  /// matches. Returns failure if no `pdl.pattern` with that name exists in the
  /// container.
  LogicalResult findAllMatches(StringRef patternName, Operation *root,
                               SmallVectorImpl<Operation *> &results);

private:
  /// Cache of compiled pattern sets, keyed by pattern symbol name.
  llvm::StringMap<FrozenRewritePatternSet> compiledPatterns;

  /// Symbol table of the container holding the `pdl.pattern` ops.
  SymbolTable patterns;
};

LogicalResult PatternApplicatorExtension::findAllMatches(
    StringRef patternName, Operation *root,
    SmallVectorImpl<Operation *> &results) {
  auto cached = compiledPatterns.find(patternName);
  if (cached == compiledPatterns.end()) {
    // First request for this pattern: find its definition in the container.
    auto pattern = patterns.lookup<pdl::PatternOp>(patternName);
    if (!pattern)
      return failure();

    // Wrap a clone of the pattern in a standalone module so it can be turned
    // into its own PDLPatternModule.
    OwningOpRef<ModuleOp> wrapper = ModuleOp::create(pattern.getLoc());
    OpBuilder::atBlockEnd(wrapper->getBody()).clone(*pattern);
    PDLPatternModule compiled(std::move(wrapper));

    // Expose the constraint hooks stored as extra data on the transform
    // dialect to the pattern being compiled.
    const llvm::StringMap<PDLConstraintFunction> &hooks =
        root->getContext()
            ->getLoadedDialect<transform::TransformDialect>()
            ->getExtraData<transform::PDLMatchHooks>()
            .getPDLConstraintHooks();
    for (const auto &entry : hooks)
      compiled.registerConstraintFunction(entry.getKey(), entry.getValue());

    // Register a rewrite function under "transform.dialect" whose body does
    // nothing: this extension only uses patterns for matching.
    compiled.registerRewriteFunction(
        "transform.dialect", [](PatternRewriter &, Operation *) {});

    cached = compiledPatterns
                 .try_emplace(pattern.getName(), std::move(compiled))
                 .first;
  }

  PatternApplicator applicator(cached->second);
  applicator.applyDefaultCostModel();

  // The applicator API takes a PatternRewriter; a minimal concrete subclass
  // is enough here.
  struct MatchOnlyRewriter : public PatternRewriter {
    explicit MatchOnlyRewriter(MLIRContext *context)
        : PatternRewriter(context) {}
  };
  MatchOnlyRewriter rewriter(root->getContext());

  root->walk([&](Operation *candidate) {
    if (succeeded(applicator.matchAndRewrite(candidate, rewriter)))
      results.push_back(candidate);
  });
  return success();
}
} // namespace
/// Registers every constraint function from `constraintFns` into this hook
/// set via PDLPatternModule::registerConstraintFunction, moving each function
/// out of the consumed map.
void transform::PDLMatchHooks::mergeInPDLMatchHooks(
    llvm::StringMap<PDLConstraintFunction> &&constraintFns) {
  for (auto &entry : constraintFns) {
    StringRef hookName = entry.getKey();
    pdlMatchHooks.registerConstraintFunction(hookName,
                                             std::move(entry.getValue()));
  }
}
/// Returns the constraint functions currently held by the underlying
/// PDLPatternModule, keyed by constraint name.
const llvm::StringMap<PDLConstraintFunction> &
transform::PDLMatchHooks::getPDLConstraintHooks() const {
  return this->pdlMatchHooks.getConstraintFunctions();
}
/// Runs the referenced PDL pattern over every payload op associated with the
/// `root` handle. Every matching op (each root included, since matching walks
/// the root itself) is collected and bound to this op's result handle.
///
/// Returns a definite failure, without setting the result, if the referenced
/// pattern cannot be found by the PatternApplicatorExtension that the
/// enclosing op is expected to have attached.
DiagnosedSilenceableFailure
transform::PDLMatchOp::apply(transform::TransformRewriter &rewriter,
                             transform::TransformResults &results,
                             transform::TransformState &state) {
  auto *extension = state.getExtension<PatternApplicatorExtension>();
  assert(extension &&
         "expected PatternApplicatorExtension to be attached by the parent op");

  // The referenced pattern is the same for every payload root.
  StringRef patternName = getPatternName().getLeafReference().getValue();

  SmallVector<Operation *> targets;
  for (Operation *root : state.getPayloadOps(getRoot())) {
    if (failed(extension->findAllMatches(patternName, root, targets))) {
      // Return the diagnosed failure instead of discarding it: emitting an
      // error while still reporting success (and setting the result handle)
      // would hide the failure from the caller.
      return emitDefiniteFailure()
             << "could not find pattern '" << getPatternName() << "'";
    }
  }
  results.set(llvm::cast<OpResult>(getResult()), targets);
  return DiagnosedSilenceableFailure::success();
}
/// Declares the memory effects of `pdl_match`. It reads the `root` handle,
/// produces its result handles, and only reads (never modifies) the payload.
void transform::PDLMatchOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  Operation *self = getOperation();
  onlyReadsHandle(getRootMutable(), effects);
  producesHandle(self->getOpResults(), effects);
  onlyReadsPayload(effects);
}
/// Makes the `pdl.pattern` ops in this op's body available (through a
/// PatternApplicatorExtension on `state`) while the single non-pattern
/// transform op in the body is applied. The extension is removed on every
/// exit path.
DiagnosedSilenceableFailure
transform::WithPDLPatternsOp::apply(transform::TransformRewriter &rewriter,
                                    transform::TransformResults &results,
                                    transform::TransformState &state) {
  // Locate the first op that is not a pattern; verify() only admits one.
  TransformOpInterface entryTransform = nullptr;
  for (Operation &candidate : getBody().front()) {
    if (isa<pdl::PatternOp>(candidate))
      continue;
    entryTransform = cast<TransformOpInterface>(candidate);
    break;
  }

  // Register the pattern container; the scope guard detaches it again when
  // this function returns (after the region scope below is torn down).
  state.addExtension<PatternApplicatorExtension>(getOperation());
  auto detachExtension = llvm::make_scope_exit(
      [&]() { state.removeExtension<PatternApplicatorExtension>(); });

  auto regionScope = state.make_region_scope(getBody());
  if (failed(mapBlockArguments(state)))
    return DiagnosedSilenceableFailure::definiteFailure();
  return state.applyTransform(entryTransform);
}
/// Reports the effects this op may have when used as a top-level transform,
/// delegating entirely to getPotentialTopLevelEffects.
void transform::WithPDLPatternsOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  this->getPotentialTopLevelEffects(effects);
}
/// Checks the body and placement of `with_pdl_patterns`:
///  - the body holds only `pdl.pattern` ops plus ops with the
///    PossibleTopLevelTransformOpTrait;
///  - at most one such non-pattern op is present;
///  - this op is not nested inside another WithPDLPatternsOp;
///  - at least one non-pattern op is present.
/// Checks run in that order, and the first violation is reported.
LogicalResult transform::WithPDLPatternsOp::verify() {
  Operation *firstTopLevel = nullptr;
  for (Operation &nested : getBodyBlock()->getOperations()) {
    if (isa<pdl::PatternOp>(nested))
      continue;

    bool isTopLevelCandidate =
        nested.hasTrait<::mlir::transform::PossibleTopLevelTransformOpTrait>();
    if (!isTopLevelCandidate) {
      InFlightDiagnostic diag =
          emitOpError()
          << "expects only pattern and top-level transform ops in its body";
      diag.attachNote(nested.getLoc()) << "offending op";
      return diag;
    }

    if (firstTopLevel) {
      InFlightDiagnostic diag =
          emitOpError() << "expects only one non-pattern op in its body";
      diag.attachNote(firstTopLevel->getLoc()) << "first non-pattern op";
      diag.attachNote(nested.getLoc()) << "second non-pattern op";
      return diag;
    }
    firstTopLevel = &nested;
  }

  if (auto enclosing = getOperation()->getParentOfType<WithPDLPatternsOp>()) {
    InFlightDiagnostic diag = emitOpError() << "cannot be nested";
    diag.attachNote(enclosing.getLoc()) << "parent operation";
    return diag;
  }

  if (!firstTopLevel)
    return emitOpError() << "expects at least one non-pattern op";
  return success();
}