#include "PredicateTree.h"
#include "RootOrdering.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <queue>
#define DEBUG_TYPE "pdl-predicate-tree"
using namespace mlir;
using namespace mlir::pdl_to_pdl_interp;
/// Forward declaration of the generic `getTreePredicates` entry point, which
/// records `val` in `inputs` and then dispatches on the dynamic kind of `pos`.
/// The position-specific overloads defined below recurse through it.
static void getTreePredicates(std::vector<PositionalPredicate> &predList,
                              Value val, PredicateBuilder &builder,
                              DenseMap<Value, Position *> &inputs,
                              Position *pos);
/// Strict ordering of positions by operation depth: returns true when `lhs`
/// has a smaller operation depth than `rhs`.
static bool comparePosDepth(Position *lhs, Position *rhs) {
  auto lhsDepth = lhs->getOperationDepth();
  auto rhsDepth = rhs->getOperationDepth();
  return rhsDepth > lhsDepth;
}
/// Return how many of `values` are not of `pdl::RangeType`.
static unsigned getNumNonRangeValues(ValueRange values) {
  unsigned count = 0;
  for (Type type : values.getTypes()) {
    if (!isa<pdl::RangeType>(type))
      ++count;
  }
  return count;
}
static void getTreePredicates(std::vector<PositionalPredicate> &predList,
Value val, PredicateBuilder &builder,
DenseMap<Value, Position *> &inputs,
AttributePosition *pos) {
assert(isa<pdl::AttributeType>(val.getType()) && "expected attribute type");
predList.emplace_back(pos, builder.getIsNotNull());
if (auto attr = dyn_cast<pdl::AttributeOp>(val.getDefiningOp())) {
if (Value type = attr.getValueType())
getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
else if (Attribute value = attr.getValueAttr())
predList.emplace_back(pos, builder.getAttributeConstraint(value));
}
}
/// Collect the predicates for the operand (or operand group) `val` located at
/// `pos`, recursing into the value's type and, for values produced by
/// `pdl.result`/`pdl.results`, into the operation that defines the operand.
static void getOperandTreePredicates(std::vector<PositionalPredicate> &predList,
                                     Value val, PredicateBuilder &builder,
                                     DenseMap<Value, Position *> &inputs,
                                     Position *pos) {
  Type valueType = val.getType();
  // A `pdl.range` type means `val` stands for a group of values.
  bool isVariadic = isa<pdl::RangeType>(valueType);

  TypeSwitch<Operation *>(val.getDefiningOp())
      .Case<pdl::OperandOp, pdl::OperandsOp>([&](auto op) {
        // Emit a non-null check always for `pdl.operand`, and for
        // `pdl.operands` only when the group position has a group number.
        if (std::is_same<pdl::OperandOp, decltype(op)>::value ||
            cast<OperandGroupPosition>(pos)->getOperandGroupNumber())
          predList.emplace_back(pos, builder.getIsNotNull());

        // If the operand is typed, collect predicates on its type.
        if (Value type = op.getValueType())
          getTreePredicates(predList, type, builder, inputs,
                            builder.getType(pos));
      })
      .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto op) {
        std::optional<unsigned> index = op.getIndex();

        // Only emit a non-null check when the result has an index.
        if (index)
          predList.emplace_back(pos, builder.getIsNotNull());

        // Fetch the operation defining this operand; it must exist.
        OperationPosition *parentPos = builder.getOperandDefiningOp(pos);
        predList.emplace_back(parentPos, builder.getIsNotNull());

        // Require the corresponding result(s) of that operation to equal the
        // operand value(s) at `pos`.
        Position *resultPos = nullptr;
        if (std::is_same<pdl::ResultOp, decltype(op)>::value)
          resultPos = builder.getResult(parentPos, *index);
        else
          resultPos = builder.getResultGroup(parentPos, index, isVariadic);
        predList.emplace_back(resultPos, builder.getEqualTo(pos));

        // Recurse into the parent operation through the generic `Position *`
        // overload, which records the value in `inputs` before dispatching.
        getTreePredicates(predList, op.getParent(), builder, inputs,
                          (Position *)parentPos);
      });
}
/// Collect the predicates for the `pdl.operation` value `val` located at the
/// operation position `pos`.
///
/// The following are emitted in order: a non-null check (for non-root
/// positions), an operation name check (if the op names one), operand and
/// result count checks, and then, recursively, the predicates of the
/// operation's attributes, operands, and result types. If `ignoreOperand` is
/// set, the operand with that index is not visited. `visitUpward` uses this
/// for the operand it has already compared.
static void
getTreePredicates(std::vector<PositionalPredicate> &predList, Value val,
                  PredicateBuilder &builder,
                  DenseMap<Value, Position *> &inputs, OperationPosition *pos,
                  std::optional<unsigned> ignoreOperand = std::nullopt) {
  assert(isa<pdl::OperationType>(val.getType()) && "expected operation");
  pdl::OperationOp op = cast<pdl::OperationOp>(val.getDefiningOp());
  OperationPosition *opPos = cast<OperationPosition>(pos);

  // Every operation other than the root must be non-null.
  if (!opPos->isRoot())
    predList.emplace_back(pos, builder.getIsNotNull());

  // Check the operation name, if the pattern specifies one.
  if (std::optional<StringRef> opName = op.getOpName())
    predList.emplace_back(pos, builder.getOperationName(*opName));

  // Check the operand count. If any operand is range-typed, only a lower
  // bound on the number of non-range operands is checked (and omitted when
  // zero). Otherwise the count must match exactly.
  OperandRange operands = op.getOperandValues();
  unsigned minOperands = getNumNonRangeValues(operands);
  if (minOperands != operands.size()) {
    if (minOperands)
      predList.emplace_back(pos, builder.getOperandCountAtLeast(minOperands));
  } else {
    predList.emplace_back(pos, builder.getOperandCount(minOperands));
  }

  // Check the result count using the same scheme.
  OperandRange types = op.getTypeValues();
  unsigned minResults = getNumNonRangeValues(types);
  if (minResults == types.size())
    predList.emplace_back(pos, builder.getResultCount(types.size()));
  else if (minResults)
    predList.emplace_back(pos, builder.getResultCountAtLeast(minResults));

  // Recurse into the attributes.
  for (auto [attrName, attr] :
       llvm::zip(op.getAttributeValueNames(), op.getAttributeValues())) {
    getTreePredicates(
        predList, attr, builder, inputs,
        builder.getAttribute(opPos, cast<StringAttr>(attrName).getValue()));
  }

  // Operands. A lone range-typed operand stands for all operands. It is only
  // visited from the root or from an operand's defining op. For other
  // positions (such as the passthrough ops created by `visitUpward`) the
  // operands are not visited here.
  if (operands.size() == 1 && isa<pdl::RangeType>(operands[0].getType())) {
    if (opPos->isRoot() || opPos->isOperandDefiningOp())
      getTreePredicates(predList, operands.front(), builder, inputs,
                        builder.getAllOperands(opPos));
  } else {
    // Before the first range-typed operand, concrete operand indices are
    // used. From then on, operand group positions are used.
    bool foundVariableLength = false;
    for (const auto &operandIt : llvm::enumerate(operands)) {
      bool isVariadic = isa<pdl::RangeType>(operandIt.value().getType());
      foundVariableLength |= isVariadic;

      // Skip the operand that an upward traversal already compared.
      if (ignoreOperand && *ignoreOperand == operandIt.index())
        continue;

      // NOTE: this local shadows the `pos` parameter.
      Position *pos =
          foundVariableLength
              ? builder.getOperandGroup(opPos, operandIt.index(), isVariadic)
              : builder.getOperand(opPos, operandIt.index());
      getTreePredicates(predList, operandIt.value(), builder, inputs, pos);
    }
  }

  // Results. A lone range-typed result type stands for all results.
  if (types.size() == 1 && isa<pdl::RangeType>(types[0].getType())) {
    getTreePredicates(predList, types.front(), builder, inputs,
                      builder.getType(builder.getAllResults(opPos)));
    return;
  }

  // As with operands, switch to result groups after the first range-typed
  // result. Each result must be non-null, and its type is matched.
  bool foundVariableLength = false;
  for (auto [idx, typeValue] : llvm::enumerate(types)) {
    bool isVariadic = isa<pdl::RangeType>(typeValue.getType());
    foundVariableLength |= isVariadic;

    auto *resultPos = foundVariableLength
                          ? builder.getResultGroup(pos, idx, isVariadic)
                          : builder.getResult(pos, idx);
    predList.emplace_back(resultPos, builder.getIsNotNull());
    getTreePredicates(predList, typeValue, builder, inputs,
                      builder.getType(resultPos));
  }
}
static void getTreePredicates(std::vector<PositionalPredicate> &predList,
Value val, PredicateBuilder &builder,
DenseMap<Value, Position *> &inputs,
TypePosition *pos) {
if (pdl::TypeOp typeOp = val.getDefiningOp<pdl::TypeOp>()) {
if (Attribute type = typeOp.getConstantTypeAttr())
predList.emplace_back(pos, builder.getTypeConstraint(type));
} else if (pdl::TypesOp typeOp = val.getDefiningOp<pdl::TypesOp>()) {
if (Attribute typeAttr = typeOp.getConstantTypesAttr())
predList.emplace_back(pos, builder.getTypeConstraint(typeAttr));
}
}
static void getTreePredicates(std::vector<PositionalPredicate> &predList,
Value val, PredicateBuilder &builder,
DenseMap<Value, Position *> &inputs,
Position *pos) {
auto it = inputs.try_emplace(val, pos);
if (!it.second) {
if (isa<pdl::AttributeOp, pdl::OperandOp, pdl::OperandsOp, pdl::OperationOp,
pdl::TypeOp>(val.getDefiningOp())) {
auto minMaxPositions =
std::minmax(pos, it.first->second, comparePosDepth);
predList.emplace_back(minMaxPositions.second,
builder.getEqualTo(minMaxPositions.first));
}
return;
}
TypeSwitch<Position *>(pos)
.Case<AttributePosition, OperationPosition, TypePosition>([&](auto *pos) {
getTreePredicates(predList, val, builder, inputs, pos);
})
.Case<OperandPosition, OperandGroupPosition>([&](auto *pos) {
getOperandTreePredicates(predList, val, builder, inputs, pos);
})
.Default([](auto *) { llvm_unreachable("unexpected position kind"); });
}
/// Give a `pdl.attribute` that was not reached by the tree traversal a
/// literal position built from its constant value. Does nothing if the
/// attribute already has a position.
static void getAttributePredicates(pdl::AttributeOp op,
                                   std::vector<PositionalPredicate> &predList,
                                   PredicateBuilder &builder,
                                   DenseMap<Value, Position *> &inputs) {
  Position *&slot = inputs[op];
  if (!slot) {
    Attribute literal = op.getValueAttr();
    assert(literal && "expected non-tree `pdl.attribute` to contain a value");
    slot = builder.getAttributeLiteral(literal);
  }
}
/// Collect the predicate for a `pdl.apply_native_constraint` op.
///
/// The constraint is attached to the deepest of its argument positions. Each
/// constraint result is given a `ConstraintPosition`. If a result already had
/// a position, an equality predicate is emitted on the deeper position (on a
/// tie, the previously recorded one), comparing against the other.
static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
                                    std::vector<PositionalPredicate> &predList,
                                    PredicateBuilder &builder,
                                    DenseMap<Value, Position *> &inputs) {
  OperandRange args = op.getArgs();

  // Gather the already-known positions of the arguments.
  std::vector<Position *> argPositions;
  argPositions.reserve(args.size());
  for (Value arg : args)
    argPositions.push_back(inputs.lookup(arg));

  Position *deepestArgPos = *llvm::max_element(argPositions, comparePosDepth);
  ResultRange results = op.getResults();
  PredicateBuilder::Predicate constraint = builder.getConstraint(
      op.getName(), argPositions, SmallVector<Type>(results.getTypes()),
      op.getIsNegated());

  // Register a position for every result of the constraint.
  for (auto [resultIdx, result] : llvm::enumerate(results)) {
    auto *question = cast<ConstraintQuestion>(constraint.first);
    ConstraintPosition *resultPos =
        builder.getConstraintPosition(question, resultIdx);
    auto insertion = inputs.try_emplace(result, resultPos);
    if (insertion.second)
      continue;

    // The result was already seen elsewhere, so both must be the same value.
    Position *prior = insertion.first->second;
    bool priorIsShallower = comparePosDepth(prior, resultPos);
    Position *shallower = priorIsShallower ? prior : resultPos;
    Position *deeper = priorIsShallower ? resultPos : prior;
    predList.emplace_back(deeper, builder.getEqualTo(shallower));
  }

  predList.emplace_back(deepestArgPos, constraint);
}
/// Give a `pdl.result` value that was not reached by the tree traversal the
/// position of the corresponding result of its parent operation, and require
/// that result to be non-null. The parent must already have an operation
/// position. Does nothing if `op` already has a position.
static void getResultPredicates(pdl::ResultOp op,
                                std::vector<PositionalPredicate> &predList,
                                PredicateBuilder &builder,
                                DenseMap<Value, Position *> &inputs) {
  Position *&slot = inputs[op];
  if (!slot) {
    Position *parent = inputs.lookup(op.getParent());
    slot = builder.getResult(cast<OperationPosition>(parent), op.getIndex());
    predList.emplace_back(slot, builder.getIsNotNull());
  }
}
/// Give a `pdl.results` value that was not reached by the tree traversal the
/// position of the corresponding result group of its parent operation. The
/// group is required to be non-null only when it has an index. Does nothing
/// if `op` already has a position.
static void getResultPredicates(pdl::ResultsOp op,
                                std::vector<PositionalPredicate> &predList,
                                PredicateBuilder &builder,
                                DenseMap<Value, Position *> &inputs) {
  Position *&slot = inputs[op];
  if (slot)
    return;

  std::optional<unsigned> groupIndex = op.getIndex();
  bool rangeTyped = isa<pdl::RangeType>(op.getType());
  auto *parentOpPos = cast<OperationPosition>(inputs.lookup(op.getParent()));
  slot = builder.getResultGroup(parentOpPos, groupIndex, rangeTyped);
  if (groupIndex)
    predList.emplace_back(slot, builder.getIsNotNull());
}
/// Give a `pdl.type`/`pdl.types` value that was not reached by the tree
/// traversal a literal position built from the constant type attribute that
/// `typeAttrFn` returns. Does nothing if `typeValue` already has a position.
static void getTypePredicates(Value typeValue,
                              function_ref<Attribute()> typeAttrFn,
                              PredicateBuilder &builder,
                              DenseMap<Value, Position *> &inputs) {
  Position *&slot = inputs[typeValue];
  if (!slot) {
    Attribute literal = typeAttrFn();
    assert(literal &&
           "expected non-tree `pdl.type`/`pdl.types` to contain a value");
    slot = builder.getTypeLiteral(literal);
  }
}
/// Walk the ops of `pattern` in order and collect predicates/positions for
/// values not covered by the tree traversal: attributes, native constraints,
/// results, and types.
static void getNonTreePredicates(pdl::PatternOp pattern,
                                 std::vector<PositionalPredicate> &predList,
                                 PredicateBuilder &builder,
                                 DenseMap<Value, Position *> &inputs) {
  for (Operation &bodyOp : pattern.getBodyRegion().getOps()) {
    Operation *opPtr = &bodyOp;
    if (auto attrOp = dyn_cast<pdl::AttributeOp>(opPtr)) {
      getAttributePredicates(attrOp, predList, builder, inputs);
    } else if (auto constraintOp =
                   dyn_cast<pdl::ApplyNativeConstraintOp>(opPtr)) {
      getConstraintPredicates(constraintOp, predList, builder, inputs);
    } else if (auto resultOp = dyn_cast<pdl::ResultOp>(opPtr)) {
      getResultPredicates(resultOp, predList, builder, inputs);
    } else if (auto resultsOp = dyn_cast<pdl::ResultsOp>(opPtr)) {
      getResultPredicates(resultsOp, predList, builder, inputs);
    } else if (auto typeOp = dyn_cast<pdl::TypeOp>(opPtr)) {
      getTypePredicates(
          typeOp, [&] { return typeOp.getConstantTypeAttr(); }, builder,
          inputs);
    } else if (auto typesOp = dyn_cast<pdl::TypesOp>(opPtr)) {
      getTypePredicates(
          typesOp, [&] { return typesOp.getConstantTypesAttr(); }, builder,
          inputs);
    }
  }
}
namespace {

/// Where a value was reached from during the BFS in `buildCostGraph`.
/// `parent` is the value from which the BFS first reached it. `index` is the
/// operand index (when `parent` is a `pdl.operation`) or the result index
/// (when `parent` is a `pdl.result`/`pdl.results` value). It is empty when
/// all operands of `parent` were reached as a single range.
struct OpIndex {
  Value parent;
  std::optional<unsigned> index;
};

/// For each candidate root, the map from every value reachable from that root
/// to the `OpIndex` through which the BFS first reached it.
using ParentMaps = DenseMap<Value, DenseMap<Value, OpIndex>>;

} // namespace
/// Return the candidate roots of `pattern`: every `pdl.operation` whose
/// results are not used (through `pdl.result`/`pdl.results`) as operands of
/// another `pdl.operation`. A root named by the pattern's rewriter is always
/// a candidate, even if it is used.
static SmallVector<Value> detectRoots(pdl::PatternOp pattern) {
  // Operations feeding another operation are not roots by default.
  DenseSet<Value> consumed;
  for (pdl::OperationOp user :
       pattern.getBodyRegion().getOps<pdl::OperationOp>()) {
    for (Value operand : user.getOperandValues()) {
      Operation *producer = operand.getDefiningOp();
      if (auto resultOp = dyn_cast<pdl::ResultOp>(producer))
        consumed.insert(resultOp.getParent());
      else if (auto resultsOp = dyn_cast<pdl::ResultsOp>(producer))
        consumed.insert(resultsOp.getParent());
    }
  }

  // An explicitly specified root may always be selected.
  if (Value explicitRoot = pattern.getRewriter().getRoot())
    consumed.erase(explicitRoot);

  SmallVector<Value> candidates;
  for (pdl::OperationOp operationOp :
       pattern.getBodyRegion().getOps<pdl::OperationOp>()) {
    Value opValue = operationOp;
    if (!consumed.contains(opValue))
      candidates.push_back(opValue);
  }
  return candidates;
}
/// Build the root ordering graph for the candidate `roots`. For each root,
/// also record in `parentMaps` the BFS parent of every value reachable from
/// it.
///
/// Starting from each root, values are traversed breadth-first. Traversal
/// goes through operation operands (depth + 1) and from `pdl.result` /
/// `pdl.results` to their parent operation (same depth). Attributes and types
/// are not traversed. A value reached from several roots is a candidate
/// connector. For each ordered pair (p, q) of such roots, an edge p -> q is
/// stored as `graph[q][p]`. Its primary cost is the smallest depth of a
/// connector below q. Its secondary cost is an ID assigned when the edge is
/// first created.
static void buildCostGraph(ArrayRef<Value> roots, RootOrderingGraph &graph,
                           ParentMaps &parentMaps) {
  // A BFS queue entry: the value, the value it was reached from, its index in
  // that parent, and its depth below the root.
  struct Entry {
    Entry(Value value, Value parent, std::optional<unsigned> index,
          unsigned depth)
        : value(value), parent(parent), index(index), depth(depth) {}

    Value value;
    Value parent;
    std::optional<unsigned> index;
    unsigned depth;
  };

  // A root from which a value is reachable, and the value's depth below it.
  struct RootDepth {
    Value root;
    unsigned depth = 0;
  };

  // For every value reached, the roots it was reached from and at what depth.
  // MapVector iterates in insertion order, so edge IDs are deterministic.
  llvm::MapVector<Value, SmallVector<RootDepth, 1>> connectorsRootsDepths;

  // Breadth-first traversal of the op DAG below each root.
  for (Value root : roots) {
    std::queue<Entry> toVisit;
    toVisit.emplace(root, Value(), 0, 0);

    // Parent links for the values reachable from this root.
    DenseMap<Value, OpIndex> &parentMap = parentMaps[root];

    while (!toVisit.empty()) {
      Entry entry = toVisit.front();
      toVisit.pop();
      // A value may be queued several times; only its first visit counts.
      if (!parentMap.insert({entry.value, {entry.parent, entry.index}}).second)
        continue;

      connectorsRootsDepths[entry.value].push_back({root, entry.depth});

      // Follow operation operands and result ops. Attributes and types are
      // not traversed.
      TypeSwitch<Operation *>(entry.value.getDefiningOp())
          .Case<pdl::OperationOp>([&](auto operationOp) {
            OperandRange operands = operationOp.getOperandValues();
            // A single range operand covering all operands gets no index.
            if (operands.size() == 1 &&
                isa<pdl::RangeType>(operands[0].getType())) {
              toVisit.emplace(operands[0], entry.value, std::nullopt,
                              entry.depth + 1);
              return;
            }

            // Otherwise visit every operand, one level deeper.
            for (const auto &p :
                 llvm::enumerate(operationOp.getOperandValues()))
              toVisit.emplace(p.value(), entry.value, p.index(),
                              entry.depth + 1);
          })
          .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {
            // Step from a result to its operation without increasing depth.
            toVisit.emplace(resultOp.getParent(), entry.value,
                            resultOp.getIndex(), entry.depth);
          });
    }
  }

  // Build the cost graph from the values reachable from more than one root.
  unsigned nextID = 0;
  for (const auto &connectorRootsDepths : connectorsRootsDepths) {
    Value value = connectorRootsDepths.first;
    ArrayRef<RootDepth> rootsDepths = connectorRootsDepths.second;
    // A value reached from a single root creates no edge.
    if (rootsDepths.size() == 1)
      continue;

    for (const RootDepth &p : rootsDepths) {
      for (const RootDepth &q : rootsDepths) {
        if (&p == &q)
          continue;
        // Edge from p.root to q.root, stored as graph[target][source]. Keep
        // the connector with the smallest depth below q.root. The secondary
        // cost is assigned only when the edge is first created.
        RootOrderingEntry &entry = graph[q.root][p.root];
        if (!entry.connector || entry.cost.first > q.depth) {
          if (!entry.connector)
            entry.cost.second = nextID++;
          entry.cost.first = q.depth;
          entry.connector = value;
        }
      }
    }
  }

  // With several roots, each one must be the target of some edge.
  assert((llvm::hasSingleElement(roots) || graph.size() == roots.size()) &&
         "the pattern contains a candidate root disconnected from the others");
}
/// Return true if operand `index` of `op` must be addressed through an
/// operand group, i.e. if any operand at or before `index` is range-typed.
static bool useOperandGroup(pdl::OperationOp op, unsigned index) {
  OperandRange operands = op.getOperandValues();
  assert(index < operands.size() && "operand index out of range");
  OperandRange prefix = operands.take_front(index + 1);
  return llvm::any_of(prefix.getTypes(),
                      [](Type type) { return isa<pdl::RangeType>(type); });
}
/// Perform one step of an upward traversal. The step moves from the value at
/// `pos` to its parent `opIndex.parent` (as recorded by `buildCostGraph`),
/// emits the predicates needed to reach it, and updates `pos` to the parent's
/// position.
///
/// * If the parent is a `pdl.operation`, the step iterates over the users of
///   the value at `pos` (a `foreach` identified by `rootID`). It requires the
///   user's operand at `opIndex.index` to equal the value at `pos`. That
///   operand is an operand group if a range-typed operand appears at or
///   before the index, or all operands when there is no index. It then
///   records the user's position and collects its tree predicates, skipping
///   the operand just compared.
/// * If the parent is a `pdl.result`/`pdl.results`, `pos` must be an
///   operation position. The step moves to the indexed result, the result
///   group, or all results.
static void visitUpward(std::vector<PositionalPredicate> &predList,
                        OpIndex opIndex, PredicateBuilder &builder,
                        DenseMap<Value, Position *> &valueToPosition,
                        Position *&pos, unsigned rootID) {
  Value value = opIndex.parent;
  TypeSwitch<Operation *>(value.getDefiningOp())
      .Case<pdl::OperationOp>([&](auto operationOp) {
        LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");

        // Iterate over the users of the current value.
        Position *usersPos = builder.getUsers(pos, true);
        Position *foreachPos = builder.getForEach(usersPos, rootID);
        OperationPosition *opPos = builder.getPassthroughOp(foreachPos);

        // Compare the user's operand(s) against the current value.
        Position *operandPos;
        if (!opIndex.index) {
          // No index: the parent takes all operands as a single range.
          operandPos = builder.getAllOperands(opPos);
        } else if (useOperandGroup(operationOp, *opIndex.index)) {
          // A range-typed operand at or before the index: use a group.
          Type type = operationOp.getOperandValues()[*opIndex.index].getType();
          bool variadic = isa<pdl::RangeType>(type);
          operandPos = builder.getOperandGroup(opPos, opIndex.index, variadic);
        } else {
          // A single operand with a concrete index.
          operandPos = builder.getOperand(opPos, *opIndex.index);
        }
        predList.emplace_back(operandPos, builder.getEqualTo(pos));

        // The parent must not have been visited before. The cast to void
        // keeps `inserted` used when assertions are compiled out.
        bool inserted = valueToPosition.try_emplace(value, opPos).second;
        (void)inserted;
        assert(inserted && "duplicate upward visit");

        // Collect the parent's tree predicates, skipping the compared operand.
        getTreePredicates(predList, value, builder, valueToPosition, opPos,
                          opIndex.index);

        // Continue the traversal from the parent.
        pos = opPos;
      })
      .Case<pdl::ResultOp>([&](auto resultOp) {
        // Move up to an individual result.
        auto *opPos = dyn_cast<OperationPosition>(pos);
        assert(opPos && "operations and results must be interleaved");
        pos = builder.getResult(opPos, *opIndex.index);

        // Record the result position unless it is already known.
        valueToPosition.try_emplace(value, pos);
      })
      .Case<pdl::ResultsOp>([&](auto resultOp) {
        // Move up to a result group, or to all results if there is no index.
        auto *opPos = dyn_cast<OperationPosition>(pos);
        assert(opPos && "operations and results must be interleaved");
        bool isVariadic = isa<pdl::RangeType>(value.getType());
        if (opIndex.index)
          pos = builder.getResultGroup(opPos, opIndex.index, isVariadic);
        else
          pos = builder.getAllResults(opPos);

        // Record the result position unless it is already known.
        valueToPosition.try_emplace(value, pos);
      });
}
/// Build the list of positional predicates needed to match `pattern`, and
/// return the root operation value the matcher is anchored at.
///
/// If the pattern's rewriter names a root, that root is used. Otherwise each
/// candidate root from `detectRoots` is tried, and the first one with the
/// lowest optimal-branching cost over the root ordering graph is chosen.
/// Predicates for the DAG below the chosen root are collected first. Every
/// other root is then reached by traversing upward from its connector value,
/// following the edges of the optimal branching in pre-order. Finally,
/// predicates for the remaining non-tree values are appended.
///
/// \param pattern          The pattern to process.
/// \param builder          Uniquer for positions and predicates.
/// \param predList         Output list that predicates are appended to.
/// \param valueToPosition  Map from pattern values to their positions. It is
///                         updated as values are visited.
/// \returns the selected root operation value.
static Value buildPredicateList(pdl::PatternOp pattern,
                                PredicateBuilder &builder,
                                std::vector<PositionalPredicate> &predList,
                                DenseMap<Value, Position *> &valueToPosition) {
  SmallVector<Value> roots = detectRoots(pattern);

  // Build the root ordering graph and compute the parent maps.
  RootOrderingGraph graph;
  ParentMaps parentMaps;
  buildCostGraph(roots, graph, parentMaps);
  LLVM_DEBUG({
    llvm::dbgs() << "Graph:\n";
    for (auto &target : graph) {
      llvm::dbgs() << " * " << target.first.getLoc() << " " << target.first
                   << "\n";
      for (auto &source : target.second) {
        RootOrderingEntry &entry = source.second;
        llvm::dbgs() << " <- " << source.first << ": " << entry.cost.first
                     << ":" << entry.cost.second << " via "
                     << entry.connector.getLoc() << "\n";
      }
    }
  });

  // Use the specified root, or pick the cheapest candidate root.
  Value bestRoot = pattern.getRewriter().getRoot();
  OptimalBranching::EdgeList bestEdges;
  if (!bestRoot) {
    unsigned bestCost = 0;
    LLVM_DEBUG(llvm::dbgs() << "Candidate roots:\n");
    for (Value root : roots) {
      OptimalBranching solver(graph, root);
      unsigned cost = solver.solve();
      LLVM_DEBUG(llvm::dbgs() << " * " << root << ": " << cost << "\n");
      if (!bestRoot || bestCost > cost) {
        bestCost = cost;
        bestRoot = root;
        bestEdges = solver.preOrderTraversal(roots);
      }
    }
  } else {
    OptimalBranching solver(graph, bestRoot);
    solver.solve();
    bestEdges = solver.preOrderTraversal(roots);
  }

  LLVM_DEBUG({
    llvm::dbgs() << "Best tree:\n";
    for (const std::pair<Value, Value> &edge : bestEdges) {
      llvm::dbgs() << " * " << edge.first;
      if (edge.second)
        llvm::dbgs() << " <- " << edge.second;
      llvm::dbgs() << "\n";
    }
  });

  LLVM_DEBUG(llvm::dbgs() << "Calling key getTreePredicates:\n");
  LLVM_DEBUG(llvm::dbgs() << " * Value: " << bestRoot << "\n");

  // Collect the tree predicates for the DAG rooted at the best root.
  getTreePredicates(predList, bestRoot, builder, valueToPosition,
                    builder.getRoot());

  // For each edge of the optimal branching, traverse upward from the
  // connector until the target root is reached.
  for (const auto &it : llvm::enumerate(bestEdges)) {
    Value target = it.value().first;
    Value source = it.value().second;

    // Skip roots that were already visited (the best root itself, or a root
    // contained in an already visited subtree).
    if (valueToPosition.count(target))
      continue;

    Value connector = graph[target][source].connector;
    assert(connector && "invalid edge");
    LLVM_DEBUG(llvm::dbgs() << " * Connector: " << connector.getLoc() << "\n");

    // Bind the parent map by reference. `parentMaps` is not modified in this
    // loop, and `lookup` would copy the whole DenseMap for every edge. For an
    // unknown target, `operator[]` yields an empty map, just as `lookup` did,
    // so the "missing parent" assertion below still fires.
    const DenseMap<Value, OpIndex> &parentMap = parentMaps[target];
    Position *pos = valueToPosition.lookup(connector);
    assert(pos && "connector has not been traversed yet");

    for (Value value = connector; value != target;) {
      OpIndex opIndex = parentMap.lookup(value);
      assert(opIndex.parent && "missing parent");
      visitUpward(predList, opIndex, builder, valueToPosition, pos, it.index());
      value = opIndex.parent;
    }
  }

  getNonTreePredicates(pattern, predList, builder, valueToPosition);
  return bestRoot;
}
namespace {

/// A predicate (a position/question pair) uniqued across all patterns,
/// together with the statistics used to order predicates when building the
/// matcher tree.
struct OrderedPredicate {
  OrderedPredicate(const std::pair<Position *, Qualifier *> &ip)
      : position(ip.first), question(ip.second) {}
  OrderedPredicate(const PositionalPredicate &ip)
      : position(ip.position), question(ip.question) {}

  /// The position this predicate is applied to.
  Position *position;

  /// The question asked at `position`.
  Qualifier *question;

  /// Number of times this predicate appears across all pattern predicate
  /// lists.
  unsigned primary = 0;

  /// Sum, over every pattern containing this predicate, of the sum of
  /// squared `primary` counts of that pattern's predicates.
  unsigned secondary = 0;

  /// Order of first insertion into the uniquing set; used as a tie-breaker.
  unsigned id = 0;

  /// Map from each pattern (as an operation) to the answer it expects for
  /// this predicate.
  DenseMap<Operation *, Qualifier *> patternToAnswer;

  /// Ordering for sorting, where "less" means "comes first". Predicates are
  /// ordered by higher `primary`, then higher `secondary`, then lower
  /// operation depth, then lower position kind, then lower question kind,
  /// then lower `id`. The swapped operands in the tuples implement the
  /// "lower first" criteria.
  bool operator<(const OrderedPredicate &rhs) const {
    auto *rhsPos = rhs.position;
    return std::make_tuple(primary, secondary, rhsPos->getOperationDepth(),
                           rhsPos->getKind(), rhs.question->getKind(), rhs.id) >
           std::make_tuple(rhs.primary, rhs.secondary,
                           position->getOperationDepth(), position->getKind(),
                           question->getKind(), id);
  }
};

/// DenseMapInfo for `OrderedPredicate`. Keys are compared and hashed by their
/// (position, question) pair only. The empty and tombstone keys come from the
/// pair's DenseMapInfo.
struct OrderedPredicateDenseInfo {
  using Base = DenseMapInfo<std::pair<Position *, Qualifier *>>;

  static OrderedPredicate getEmptyKey() { return Base::getEmptyKey(); }
  static OrderedPredicate getTombstoneKey() { return Base::getTombstoneKey(); }
  static bool isEqual(const OrderedPredicate &lhs,
                      const OrderedPredicate &rhs) {
    return lhs.position == rhs.position && lhs.question == rhs.question;
  }
  static unsigned getHashValue(const OrderedPredicate &p) {
    return llvm::hash_combine(p.position, p.question);
  }
};

/// The uniqued predicates used by one pattern, with the pattern and the root
/// chosen for it.
struct OrderedPredicateList {
  OrderedPredicateList(pdl::PatternOp pattern, Value root)
      : pattern(pattern), root(root) {}

  pdl::PatternOp pattern;
  Value root;
  DenseSet<OrderedPredicate *> predicates;
};

} // namespace
/// Return true if `node` tests the same position and question as `predicate`.
static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate) {
  if (node->getPosition() != predicate->position)
    return false;
  return node->getQuestion() == predicate->question;
}
/// Return the child slot of the switch `node` for the answer that `pattern`
/// expects to `predicate`, inserting an empty (null) child if there is none
/// yet. `node` must test the same position/question as `predicate`, and
/// `pattern` must have recorded an answer for that predicate.
///
/// Declared `static` like every other helper in this file. Its signature uses
/// the file-local `OrderedPredicate` type, so it cannot be used from another
/// translation unit.
static std::unique_ptr<MatcherNode> &
getOrCreateChild(SwitchNode *node, OrderedPredicate *predicate,
                 pdl::PatternOp pattern) {
  assert(isSamePredicate(node, predicate) &&
         "expected matcher to equal the given predicate");

  auto it = predicate->patternToAnswer.find(pattern);
  assert(it != predicate->patternToAnswer.end() &&
         "expected pattern to exist in predicate");
  return node->getChildren().insert({it->second, nullptr}).first->second;
}
/// Insert the pattern of `list` into the matcher tree rooted at `node`,
/// walking the globally ordered predicates in [current, end).
///
/// Predicates the pattern does not use are skipped. For a predicate the
/// pattern uses, insertion descends into the matching switch node's child for
/// the pattern's answer (creating the switch if `node` is empty). If `node`
/// tests a different predicate, insertion continues along its failure chain.
/// Once the predicates are exhausted, a success node for the pattern is
/// placed at `node`, with the previous contents of `node` as its failure node.
static void propagatePattern(std::unique_ptr<MatcherNode> &node,
                             OrderedPredicateList &list,
                             std::vector<OrderedPredicate *>::iterator current,
                             std::vector<OrderedPredicate *>::iterator end) {
  // All predicates handled: the pattern matches here.
  if (current == end) {
    node = std::make_unique<SuccessNode>(list.pattern, list.root,
                                         std::move(node));
    return;
  }

  OrderedPredicate *predicate = *current;
  auto next = std::next(current);

  // This pattern does not test the current predicate; move on.
  if (!list.predicates.contains(predicate)) {
    propagatePattern(node, list, next, end);
    return;
  }

  // An existing node for a different predicate: try its failure branch.
  if (node && !isSamePredicate(node.get(), predicate)) {
    propagatePattern(node->getFailureNode(), list, current, end);
    return;
  }

  // Either no node exists yet (create one for this predicate) or the node
  // already tests it. Descend into the child for this pattern's answer.
  if (!node)
    node = std::make_unique<SwitchNode>(predicate->position,
                                        predicate->question);
  propagatePattern(
      getOrCreateChild(cast<SwitchNode>(&*node), predicate, list.pattern),
      list, next, end);
}
/// Recursively simplify the matcher tree rooted at `node`. Every switch node
/// with exactly one child is replaced by a bool node for that child's answer.
/// The bool node takes over the child as its success node and the switch's
/// failure node as its own failure node.
static void foldSwitchToBool(std::unique_ptr<MatcherNode> &node) {
  if (!node)
    return;

  if (SwitchNode *switchNode = dyn_cast<SwitchNode>(&*node)) {
    SwitchNode::ChildMapT &children = switchNode->getChildren();
    // Simplify the children first.
    for (auto &it : children)
      foldSwitchToBool(it.second);

    if (children.size() == 1) {
      auto *childIt = children.begin();
      // The constructor arguments move the child and failure subtrees out of
      // the old switch node before the assignment destroys it.
      node = std::make_unique<BoolNode>(
          node->getPosition(), node->getQuestion(), childIt->first,
          std::move(childIt->second), std::move(node->getFailureNode()));
    }
  } else if (BoolNode *boolNode = dyn_cast<BoolNode>(&*node)) {
    foldSwitchToBool(boolNode->getSuccessNode());
  }

  // Continue down the failure chain of the (possibly replaced) node.
  foldSwitchToBool(node->getFailureNode());
}
static void insertExitNode(std::unique_ptr<MatcherNode> *root) {
while (*root)
root = &(*root)->getFailureNode();
*root = std::make_unique<ExitNode>();
}
template <typename Iterator, typename Compare>
static void stableTopologicalSort(Iterator begin, Iterator end, Compare cmp) {
while (begin != end) {
llvm::SmallPtrSet<typename Iterator::value_type, 16> sortBeforeOthers;
for (auto i = begin; i != end; ++i) {
if (std::none_of(begin, end, [&](auto const &b) { return cmp(b, *i); }))
sortBeforeOthers.insert(*i);
}
auto const next = std::stable_partition(begin, end, [&](auto const &a) {
return sortBeforeOthers.contains(a);
});
assert(next != begin && "not a partial ordering");
begin = next;
}
}
/// Return true if predicate `b` uses a result of the constraint asked by
/// predicate `a`, meaning `a` must be ordered before `b`.
///
/// This holds when `a` is a constraint question and either:
/// * `b` is a constraint with an argument at one of `a`'s result positions;
/// * `b` is an equality whose position or compared value is one of `a`'s
///   result positions; or
/// * `b`'s position is one of `a`'s result positions.
static bool dependsOn(OrderedPredicate *a, OrderedPredicate *b) {
  auto *producer = dyn_cast<ConstraintQuestion>(a->question);
  if (!producer)
    return false;

  auto isResultOfProducer = [&](Position *p) {
    auto *constraintPos = dyn_cast<ConstraintPosition>(p);
    return constraintPos && constraintPos->getQuestion() == producer;
  };

  Qualifier *consumer = b->question;
  if (auto *consumerConstraint = dyn_cast<ConstraintQuestion>(consumer))
    return llvm::any_of(consumerConstraint->getArgs(), isResultOfProducer);

  if (isResultOfProducer(b->position))
    return true;
  if (auto *equalTo = dyn_cast<EqualToQuestion>(consumer))
    return isResultOfProducer(equalTo->getValue());
  return false;
}
/// Build the matcher tree for all `pdl.pattern` operations in `module`.
///
/// The steps are:
/// 1. Compute each pattern's predicate list.
/// 2. Unique the predicates across patterns and rank them (by how often they
///    are shared, then depth, kinds, and insertion order).
/// 3. Reorder them so that a predicate using a constraint's results comes
///    after that constraint.
/// 4. Insert each pattern into the tree following that order.
/// 5. Fold single-child switches into bool nodes and append an exit node to
///    the root's failure chain.
///
/// `valueToPosition` receives the positions of the visited pattern values.
std::unique_ptr<MatcherNode>
MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
                                 DenseMap<Value, Position *> &valueToPosition) {
  // The predicates extracted from one pattern, with the root chosen for it.
  struct PatternPredicates {
    PatternPredicates(pdl::PatternOp pattern, Value root,
                      std::vector<PositionalPredicate> predicates)
        : pattern(pattern), root(root), predicates(std::move(predicates)) {}

    pdl::PatternOp pattern;
    Value root;
    std::vector<PositionalPredicate> predicates;
  };

  SmallVector<PatternPredicates, 16> patternsAndPredicates;
  for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
    std::vector<PositionalPredicate> predicateList;
    Value root =
        buildPredicateList(pattern, builder, predicateList, valueToPosition);
    patternsAndPredicates.emplace_back(pattern, root, std::move(predicateList));
  }

  // Unique the predicates across patterns. For each one, record the answer
  // every pattern expects and the order in which it was first seen. Mutating
  // the set elements is fine because hashing and equality only look at the
  // (position, question) pair.
  DenseSet<OrderedPredicate, OrderedPredicateDenseInfo> uniqued;
  for (auto &patternAndPredList : patternsAndPredicates) {
    for (auto &predicate : patternAndPredList.predicates) {
      auto it = uniqued.insert(predicate);
      it.first->patternToAnswer.try_emplace(patternAndPredList.pattern,
                                            predicate.answer);
      // Record the insertion order (0-based).
      if (it.second)
        it.first->id = uniqued.size() - 1;
    }
  }

  // Map each pattern to the set of its uniqued predicates, counting every
  // reference in the predicate's primary sum.
  std::vector<OrderedPredicateList> lists;
  lists.reserve(patternsAndPredicates.size());
  for (auto &patternAndPredList : patternsAndPredicates) {
    OrderedPredicateList list(patternAndPredList.pattern,
                              patternAndPredList.root);
    for (auto &predicate : patternAndPredList.predicates) {
      OrderedPredicate *orderedPredicate = &*uniqued.find(predicate);
      list.predicates.insert(orderedPredicate);
      ++orderedPredicate->primary;
    }
    lists.push_back(std::move(list));
  }

  // Add each pattern's sum of squared primary counts to the secondary sum of
  // every predicate in that pattern.
  for (auto &list : lists) {
    unsigned total = 0;
    for (auto *predicate : list.predicates)
      total += predicate->primary * predicate->primary;
    for (auto *predicate : list.predicates)
      predicate->secondary += total;
  }

  // Sort the predicates using the computed sums (see OrderedPredicate).
  std::vector<OrderedPredicate *> ordered;
  ordered.reserve(uniqued.size());
  for (auto &ip : uniqued)
    ordered.push_back(&ip);
  llvm::sort(ordered, [](OrderedPredicate *lhs, OrderedPredicate *rhs) {
    return *lhs < *rhs;
  });

  // Keep that order where possible, but ensure that a predicate using a
  // constraint's results comes after the constraint.
  stableTopologicalSort(ordered.begin(), ordered.end(), dependsOn);

  // Insert every pattern into the matcher tree.
  std::unique_ptr<MatcherNode> root;
  for (OrderedPredicateList &list : lists)
    propagatePattern(root, list, ordered.begin(), ordered.end());

  // Collapse single-child switches and terminate with an exit node.
  foldSwitchToBool(root);
  insertExitNode(&root);
  return root;
}
/// Construct a matcher node of kind `matcherTypeID` for question `q` at
/// position `p`, with `failureNode` as its failure successor.
MatcherNode::MatcherNode(TypeID matcherTypeID, Position *p, Qualifier *q,
                         std::unique_ptr<MatcherNode> failureNode)
    : position(p), question(q), failureNode(std::move(failureNode)),
      matcherTypeID(matcherTypeID) {}
/// Construct a bool node for `question` at `position` that stores the
/// expected `answer` and the `successNode` subtree, in addition to the
/// base-class fields.
BoolNode::BoolNode(Position *position, Qualifier *question, Qualifier *answer,
                   std::unique_ptr<MatcherNode> successNode,
                   std::unique_ptr<MatcherNode> failureNode)
    : MatcherNode(TypeID::get<BoolNode>(), position, question,
                  std::move(failureNode)),
      answer(answer), successNode(std::move(successNode)) {}
/// Construct a success node recording a match of `pattern` at `root`. It
/// carries no position or question (both null).
SuccessNode::SuccessNode(pdl::PatternOp pattern, Value root,
                         std::unique_ptr<MatcherNode> failureNode)
    : MatcherNode(TypeID::get<SuccessNode>(), /*position=*/nullptr,
                  /*question=*/nullptr, std::move(failureNode)),
      pattern(pattern), root(root) {}
/// Construct a switch node for `question` at `position`. No failure node is
/// passed, so the base constructor's default for that argument applies.
SwitchNode::SwitchNode(Position *position, Qualifier *question)
    : MatcherNode(TypeID::get<SwitchNode>(), position, question) {}