#include "mlir/Rewrite/PatternApplicator.h"
#include "ByteCode.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "pattern-application"
using namespace mlir;
using namespace mlir::detail;
/// Constructs an applicator over `frozenPatternList`.
///
/// When the frozen set exposes PDL bytecode, a fresh
/// `PDLByteCodeMutableState` is allocated for this applicator and handed to
/// the bytecode so it can initialize it. Without bytecode,
/// `mutableByteCodeState` is left unset.
PatternApplicator::PatternApplicator(
    const FrozenRewritePatternSet &frozenPatternList)
    : frozenPatternList(frozenPatternList) {
  const PDLByteCode *pdlByteCode = frozenPatternList.getPDLByteCode();
  // Nothing else to set up when there are no bytecode patterns.
  if (!pdlByteCode)
    return;

  mutableByteCodeState = std::make_unique<PDLByteCodeMutableState>();
  pdlByteCode->initializeMutableState(*mutableByteCodeState);
}
/// Defaulted destructor; members are released by their own destructors.
// NOTE(review): presumably defined out-of-line in this file so that
// `PDLByteCodeMutableState` (from "ByteCode.h") is a complete type where the
// owning pointer is destroyed -- confirm against the class declaration.
PatternApplicator::~PatternApplicator() = default;
#ifndef NDEBUG
/// Writes a debug-stream note that `pattern` is being ignored, identifying it
/// by its root kind.
static void logImpossibleToMatch(const Pattern &pattern) {
  auto &os = llvm::dbgs();
  os << "Ignoring pattern '";
  os << pattern.getRootKind();
  os << "' because it is impossible to match or cannot lead "
        "to legal IR (by cost model)\n";
}
/// Picks the operation to dump for debugging: the nearest parent of `op`
/// carrying the `IsIsolatedFromAbove` trait when one exists, otherwise `op`
/// itself.
static Operation *getDumpRootOp(Operation *op) {
  if (Operation *isolatedAncestor =
          op->getParentWithTrait<mlir::OpTrait::IsIsolatedFromAbove>())
    return isolatedAncestor;
  return op;
}
/// Emits a banner on the debug stream, dumps `op`, then emits two trailing
/// newlines on the debug stream.
static void logSucessfulPatternApplication(Operation *op) {
  auto &os = llvm::dbgs();
  os << "// *** IR Dump After Pattern Application ***\n";
  op->dump();
  os << "\n\n";
}
#endif
/// Rebuilds the applicator's pattern lists according to `model`.
///
/// Steps, in order:
///  1. For every bytecode pattern (if bytecode is present), the benefit
///     computed by `model` is pushed into the mutable bytecode state.
///  2. `patterns` and `anyOpPatterns` are cleared and refilled from the frozen
///     set, skipping any pattern whose own `getBenefit()` already reports
///     impossible-to-match.
///  3. Each refilled list is stably sorted by descending `model` benefit, and
///     trailing entries whose `model` benefit is impossible-to-match are
///     removed.
void PatternApplicator::applyCostModel(CostModel model) {
  // Step 1: bytecode patterns.
  if (const PDLByteCode *pdlByteCode = frozenPatternList.getPDLByteCode()) {
    for (const auto &indexed : llvm::enumerate(pdlByteCode->getPatterns()))
      mutableByteCodeState->updatePatternBenefit(indexed.index(),
                                                 model(indexed.value()));
  }

  // Step 2a: operation-specific native patterns, keyed by the map key of the
  // frozen set.
  patterns.clear();
  for (const auto &opEntry : frozenPatternList.getOpSpecificNativePatterns()) {
    for (const RewritePattern *nativePattern : opEntry.second) {
      if (nativePattern->getBenefit().isImpossibleToMatch()) {
        LLVM_DEBUG(logImpossibleToMatch(*nativePattern));
        continue;
      }
      patterns[opEntry.first].push_back(nativePattern);
    }
  }

  // Step 2b: native patterns that may match any operation.
  anyOpPatterns.clear();
  for (const RewritePattern &anyPattern :
       frozenPatternList.getMatchAnyOpNativePatterns()) {
    if (anyPattern.getBenefit().isImpossibleToMatch()) {
      LLVM_DEBUG(logImpossibleToMatch(anyPattern));
      continue;
    }
    anyOpPatterns.push_back(&anyPattern);
  }

  // Step 3: order each list by the benefit the cost model assigns. The map is
  // shared scratch storage that is reset for every list.
  llvm::SmallDenseMap<const Pattern *, PatternBenefit> benefits;
  auto higherBenefitFirst = [&benefits](const Pattern *lhs,
                                        const Pattern *rhs) {
    return benefits[lhs] > benefits[rhs];
  };
  auto sortAndPrune = [&](SmallVectorImpl<const RewritePattern *> &list) {
    // A single-element list needs no sorting; it is only checked against the
    // model and emptied if the model rules it out.
    if (list.size() == 1) {
      const RewritePattern *only = list.front();
      if (model(*only).isImpossibleToMatch()) {
        LLVM_DEBUG(logImpossibleToMatch(*only));
        list.clear();
      }
      return;
    }

    // Evaluate the model once per pattern before sorting.
    benefits.clear();
    for (const Pattern *pat : list)
      benefits.try_emplace(pat, model(*pat));

    std::stable_sort(list.begin(), list.end(), higherBenefitFirst);

    // Drop entries from the tail for as long as they are impossible to match.
    while (!list.empty()) {
      const RewritePattern *last = list.back();
      if (!benefits[last].isImpossibleToMatch())
        break;
      LLVM_DEBUG(logImpossibleToMatch(*last));
      list.pop_back();
    }
  };

  for (auto &opEntry : patterns)
    sortAndPrune(opEntry.second);
  sortAndPrune(anyOpPatterns);
}
/// Invokes `walk` on every pattern held by the frozen pattern set: first the
/// operation-specific native patterns, then the match-any-op native patterns,
/// and finally the bytecode patterns when bytecode is present.
void PatternApplicator::walkAllPatterns(
    function_ref<void(const Pattern &)> walk) {
  for (const auto &opEntry : frozenPatternList.getOpSpecificNativePatterns()) {
    for (const auto &nativePattern : opEntry.second)
      walk(*nativePattern);
  }

  for (const Pattern &anyPattern :
       frozenPatternList.getMatchAnyOpNativePatterns())
    walk(anyPattern);

  const PDLByteCode *pdlByteCode = frozenPatternList.getPDLByteCode();
  if (!pdlByteCode)
    return;
  for (const Pattern &pdlPattern : pdlByteCode->getPatterns())
    walk(pdlPattern);
}
/// Tries candidate patterns on `op` one at a time, in benefit order, until one
/// is applied successfully or no candidates remain.
///
/// \param op        The operation to match and rewrite.
/// \param rewriter  Rewriter passed to the patterns; its insertion point is
///                  set to `op` before every attempt.
/// \param canApply  Optional filter; a candidate for which it returns false is
///                  skipped without being attempted.
/// \param onFailure Optional hook invoked with the pattern after an attempt
///                  that did not end in success.
/// \param onSuccess Optional hook invoked with the pattern after a successful
///                  rewrite; if it returns failure the attempt is treated as
///                  failed.
/// \returns success if some pattern was applied (and `onSuccess`, when given,
///          did not fail); failure otherwise.
LogicalResult PatternApplicator::matchAndRewrite(
Operation *op, PatternRewriter &rewriter,
function_ref<bool(const Pattern &)> canApply,
function_ref<void(const Pattern &)> onFailure,
function_ref<LogicalResult(const Pattern &)> onSuccess) {
// Run the bytecode matcher up front (when bytecode exists) and collect its
// match results; the rewrite for a chosen match is issued later in the loop.
SmallVector<PDLByteCode::MatchResult, 4> pdlMatches;
const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode();
if (bytecode)
bytecode->match(op, rewriter, pdlMatches, *mutableByteCodeState);
// Look up the native patterns registered for this operation's name; the
// range stays empty when there are none.
MutableArrayRef<const RewritePattern *> opPatterns;
auto patternIt = patterns.find(op->getName());
if (patternIt != patterns.end())
opPatterns = patternIt->second;
// One cursor/end pair per candidate source: op-specific patterns, any-op
// patterns, and PDL matches. The three sources are consumed interleaved.
unsigned opIt = 0, opE = opPatterns.size();
unsigned anyIt = 0, anyE = anyOpPatterns.size();
unsigned pdlIt = 0, pdlE = pdlMatches.size();
LogicalResult result = failure();
do {
// Select the next candidate. The op-specific head is the initial choice; the
// any-op head and then the PDL head replace it only when their benefit is
// strictly greater, so ties resolve in that source order.
// NOTE(review): native candidates are compared via `getBenefit()` here,
// while applyCostModel ordered the lists using `model` -- confirm the two
// are expected to agree.
const Pattern *bestPattern = nullptr;
unsigned *bestPatternIt = &opIt;
if (opIt < opE)
bestPattern = opPatterns[opIt];
if (anyIt < anyE &&
(!bestPattern ||
bestPattern->getBenefit() < anyOpPatterns[anyIt]->getBenefit())) {
bestPatternIt = &anyIt;
bestPattern = anyOpPatterns[anyIt];
}
// Non-null only when the selected candidate is a PDL match.
const PDLByteCode::MatchResult *pdlMatch = nullptr;
if (pdlIt < pdlE && (!bestPattern || bestPattern->getBenefit() <
pdlMatches[pdlIt].benefit)) {
bestPatternIt = &pdlIt;
pdlMatch = &pdlMatches[pdlIt];
bestPattern = pdlMatch->pattern;
}
// All three sources are exhausted.
if (!bestPattern)
break;
// Advance the winning source's cursor before the attempt so the same
// candidate is never selected again, whether it is skipped or fails.
++(*bestPatternIt);
if (canApply && !canApply(*bestPattern))
continue;
// Perform the attempt inside an ApplyPatternAction dispatched through the
// context; `matched` carries the outcome out of the callback.
bool matched = false;
op->getContext()->executeAction<ApplyPatternAction>(
[&]() {
rewriter.setInsertionPoint(op);
#ifndef NDEBUG
// Resolve the op to dump before rewriting.
// NOTE(review): presumably because `op` may no longer be valid after a
// successful rewrite -- confirm.
Operation *dumpRootOp = getDumpRootOp(op);
#endif
if (pdlMatch) {
// The PDL candidate was already matched above; only the rewrite remains.
result =
bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState);
} else {
LLVM_DEBUG(llvm::dbgs() << "Trying to match \""
<< bestPattern->getDebugName() << "\"\n");
// Non-PDL candidates come from `opPatterns` / `anyOpPatterns`, which hold
// `const RewritePattern *`, so this downcast matches how they were stored.
const auto *pattern =
static_cast<const RewritePattern *>(bestPattern);
result = pattern->matchAndRewrite(op, rewriter);
LLVM_DEBUG(llvm::dbgs()
<< "\"" << bestPattern->getDebugName() << "\" result "
<< succeeded(result) << "\n");
}
// A failing `onSuccess` hook downgrades a successful rewrite to failure.
if (succeeded(result) && onSuccess && failed(onSuccess(*bestPattern)))
result = failure();
if (succeeded(result)) {
LLVM_DEBUG(logSucessfulPatternApplication(dumpRootOp));
matched = true;
return;
}
// The attempt did not succeed; notify the caller if a hook was provided.
if (onFailure)
onFailure(*bestPattern);
},
{op}, *bestPattern);
// Stop at the first pattern that applied successfully.
if (matched)
break;
} while (true);
// Let the mutable bytecode state (if any) clean up after this call,
// regardless of the outcome.
if (mutableByteCodeState)
mutableByteCodeState->cleanupAfterMatchAndRewrite();
return result;
}