#include "mlir/TableGen/Predicate.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
using namespace mlir;
using namespace tblgen;
Pred::Pred(const llvm::Record *record) : def(record) {
assert(def->isSubClassOf("Pred") &&
"must be a subclass of TableGen 'Pred' class");
}
/// Builds a predicate from a TableGen initializer.
///
/// If `init` is a DefInit, the predicate wraps the record it refers to.
/// Otherwise (including when `init` is null) the predicate is left without a
/// record. `def` is explicitly initialized to null here so that this second
/// case never leaves the pointer indeterminate, independent of how the member
/// is declared in the header.
Pred::Pred(const llvm::Init *init) : def(nullptr) {
  if (const auto *defInit = dyn_cast_or_null<llvm::DefInit>(init))
    def = defInit->getDef();
}
std::string Pred::getCondition() const {
if (def->isSubClassOf("CombinedPred"))
return static_cast<const CombinedPred *>(this)->getConditionImpl();
if (def->isSubClassOf("CPred"))
return static_cast<const CPred *>(this)->getConditionImpl();
llvm_unreachable("Pred::getCondition must be overridden in subclasses");
}
/// Returns true if this predicate wraps a record deriving from the TableGen
/// class 'CombinedPred'. A predicate without a record is never combined.
bool Pred::isCombined() const {
  if (!def)
    return false;
  return def->isSubClassOf("CombinedPred");
}
/// Returns the source locations recorded on the wrapped TableGen record.
/// The record is dereferenced unconditionally, so the predicate must not be
/// null.
ArrayRef<SMLoc> Pred::getLoc() const {
  ArrayRef<SMLoc> locations = def->getLoc();
  return locations;
}
CPred::CPred(const llvm::Record *record) : Pred(record) {
assert(def->isSubClassOf("CPred") &&
"must be a subclass of Tablegen 'CPred' class");
}
CPred::CPred(const llvm::Init *init) : Pred(init) {
assert((!def || def->isSubClassOf("CPred")) &&
"must be a subclass of Tablegen 'CPred' class");
}
std::string CPred::getConditionImpl() const {
assert(!isNull() && "null predicate does not have a condition");
return std::string(def->getValueAsString("predExpr"));
}
CombinedPred::CombinedPred(const llvm::Record *record) : Pred(record) {
assert(def->isSubClassOf("CombinedPred") &&
"must be a subclass of Tablegen 'CombinedPred' class");
}
CombinedPred::CombinedPred(const llvm::Init *init) : Pred(init) {
assert((!def || def->isSubClassOf("CombinedPred")) &&
"must be a subclass of Tablegen 'CombinedPred' class");
}
/// Returns the record stored in this predicate's 'kind' field, i.e. the
/// definition identifying which combiner is applied to the children.
/// The field must exist (checked in assert-enabled builds).
const llvm::Record *CombinedPred::getCombinerDef() const {
  assert(def->getValue("kind") && "CombinedPred must have a value 'kind'");
  const llvm::Record *combiner = def->getValueAsDef("kind");
  return combiner;
}
/// Returns the records listed in this predicate's 'children' field, i.e. the
/// predicates being combined. The field must exist (checked in assert-enabled
/// builds).
std::vector<llvm::Record *> CombinedPred::getChildren() const {
  assert(def->getValue("children") &&
         "CombinedPred must have a value 'children'");
  std::vector<llvm::Record *> childRecords =
      def->getValueAsListOfDefs("children");
  return childRecords;
}
namespace {
/// Kinds of nodes in a predicate tree.
enum class PredCombinerKind {
  /// A predicate that is not combined; its condition is used verbatim (after
  /// substitutions).
  Leaf,
  And,
  Or,
  Not,
  SubstLeaves,
  Concat,
  // The two kinds below are assigned by propagateGroundTruth when a node is
  // known to evaluate to a constant.
  False,
  True
};

/// A node in a tree of predicates.
///
/// `kind` and `predicate` carry default member initializers so that a
/// default-constructed node never holds indeterminate values before the tree
/// builder fills it in.
struct PredNode {
  /// What this node does with its children (or Leaf/True/False).
  PredCombinerKind kind = PredCombinerKind::Leaf;
  /// Non-owning pointer to the predicate this node was built from.
  const Pred *predicate = nullptr;
  /// Sub-trees combined by this node; empty for leaves and constant nodes.
  SmallVector<PredNode *, 4> children;
  /// Condition text; filled in for leaf nodes only.
  std::string expr;

  /// Text placed before and after the child's condition; filled in for Concat
  /// nodes only.
  std::string prefix;
  std::string suffix;
};
} // namespace
/// Maps a predicate to the kind of tree node it produces.
///
/// Predicates that are not combined are leaves. For combined predicates the
/// kind is determined by the name of the record stored in their 'kind' field.
///
/// An unrecognized combiner name is reported as a fatal TableGen error at the
/// predicate's location. A StringSwitch without a default only guards this
/// case with an assertion, which gives no diagnostic in builds where
/// assertions are disabled.
static PredCombinerKind getPredCombinerKind(const Pred &pred) {
  if (!pred.isCombined())
    return PredCombinerKind::Leaf;

  const auto &combinedPred = static_cast<const CombinedPred &>(pred);
  const StringRef combinerName = combinedPred.getCombinerDef()->getName();
  if (combinerName == "PredCombinerAnd")
    return PredCombinerKind::And;
  if (combinerName == "PredCombinerOr")
    return PredCombinerKind::Or;
  if (combinerName == "PredCombinerNot")
    return PredCombinerKind::Not;
  if (combinerName == "PredCombinerSubstLeaves")
    return PredCombinerKind::SubstLeaves;
  if (combinerName == "PredCombinerConcat")
    return PredCombinerKind::Concat;

  // PrintFatalError does not return.
  llvm::PrintFatalError(pred.getLoc(),
                        "unsupported predicate combiner kind '" +
                            combinerName + "'");
}
namespace {
// A textual substitution: `first` is the pattern to search for and `second` is
// the text that replaces each occurrence. Both are non-owning string
// references.
using Subst = std::pair<StringRef, StringRef>;
}
/// Applies `substitutions` to `str` in place.
///
/// Substitutions are applied from the last entry of the list to the first.
/// For each one, every occurrence of the pattern is replaced left to right;
/// the search resumes after the inserted replacement, so text introduced by a
/// replacement is never re-matched by the same substitution.
///
/// Substitutions with an empty pattern are skipped: an empty pattern matches
/// at every position, so replacing its "occurrences" would never terminate.
static void performSubstitutions(std::string &str,
                                 ArrayRef<Subst> substitutions) {
  for (const auto &subst : llvm::reverse(substitutions)) {
    if (subst.first.empty())
      continue;

    // Convert once per substitution rather than on every search.
    const std::string pattern = subst.first.str();
    const std::string replacement = subst.second.str();

    for (auto pos = str.find(pattern); pos != std::string::npos;
         pos = str.find(pattern, pos + replacement.size())) {
      str.replace(pos, pattern.size(), replacement);
    }
  }
}
/// Builds the tree of PredNodes rooted at `root`.
///
/// Nodes are allocated from `allocator` and constructed in place.
/// `substitutions` are the leaf substitutions collected from enclosing
/// SubstLeaves predicates; they are applied to the condition of leaf
/// predicates and to the prefix/suffix of Concat predicates.
///
/// Returns the node created for `root`.
///
/// NOTE(review): each node stores the address of the Pred it was built from.
/// For child nodes that Pred is a temporary created in the loop below, so the
/// stored pointer should only be compared, not dereferenced, once this
/// function has returned -- confirm against the users of PredNode::predicate.
static PredNode *
buildPredicateTree(const Pred &root,
                   llvm::SpecificBumpPtrAllocator<PredNode> &allocator,
                   ArrayRef<Subst> substitutions) {
  auto *rootNode = allocator.Allocate();
  new (rootNode) PredNode;
  rootNode->kind = getPredCombinerKind(root);
  rootNode->predicate = &root;

  // Leaves carry their own condition with all pending substitutions applied.
  if (!root.isCombined()) {
    rootNode->expr = root.getCondition();
    performSubstitutions(rootNode->expr, substitutions);
    return rootNode;
  }

  // A SubstLeaves predicate contributes one more substitution for everything
  // below it.
  auto allSubstitutions = llvm::to_vector<4>(substitutions);
  if (rootNode->kind == PredCombinerKind::SubstLeaves) {
    const auto &substPred = static_cast<const SubstLeavesPred &>(root);
    allSubstitutions.push_back(
        {substPred.getPattern(), substPred.getReplacement()});
  } else if (rootNode->kind == PredCombinerKind::Concat) {
    // A Concat predicate records its prefix and suffix, rewritten by the
    // substitutions inherited from its ancestors.
    const auto &concatPred = static_cast<const ConcatPred &>(root);
    rootNode->prefix = std::string(concatPred.getPrefix());
    performSubstitutions(rootNode->prefix, substitutions);
    rootNode->suffix = std::string(concatPred.getSuffix());
    performSubstitutions(rootNode->suffix, substitutions);
  }

  // Build the child subtrees. Bind by reference: the predicate does not need
  // to be copied just to read its children.
  const auto &combined = static_cast<const CombinedPred &>(root);
  for (const auto *record : combined.getChildren()) {
    auto *childTree =
        buildPredicateTree(Pred(record), allocator, allSubstitutions);
    rootNode->children.push_back(childTree);
  }
  return rootNode;
}
/// Simplifies the tree rooted at `node` in place using predicates whose value
/// is already known, and returns `node`.
///
/// - A node whose predicate is in `knownTruePreds` (checked first) or in
///   `knownFalsePreds` becomes a True or False node with no children.
/// - SubstLeaves nodes are returned untouched; nothing below them is visited.
/// - For And/Or nodes, children are simplified one by one:
///     AND(..., False, ...) = False        OR(..., True, ...)  = True
///     AND(..., True, ...)  = AND(..., ...) OR(..., False, ...) = OR(..., ...)
///   Collapsing stops immediately; remaining children are not visited.
/// - For every other kind the children are simplified and kept as they are.
static PredNode *
propagateGroundTruth(PredNode *node,
                     const llvm::SmallPtrSetImpl<Pred *> &knownTruePreds,
                     const llvm::SmallPtrSetImpl<Pred *> &knownFalsePreds) {
  // Known predicates turn into constant nodes right away.
  const bool isKnownTrue = knownTruePreds.count(node->predicate) != 0;
  if (isKnownTrue || knownFalsePreds.count(node->predicate) != 0) {
    node->kind = isKnownTrue ? PredCombinerKind::True : PredCombinerKind::False;
    node->children.clear();
    return node;
  }

  // Do not descend below a substitution.
  if (node->kind == PredCombinerKind::SubstLeaves)
    return node;

  // Detach the children so they can be simplified and selectively re-added.
  llvm::SmallVector<PredNode *, 4> pending;
  std::swap(node->children, pending);

  const bool isAnd = node->kind == PredCombinerKind::And;
  const bool isOr = node->kind == PredCombinerKind::Or;

  // Nodes that are neither And nor Or keep every (simplified) child.
  if (!isAnd && !isOr) {
    for (PredNode *child : pending)
      node->children.push_back(
          propagateGroundTruth(child, knownTruePreds, knownFalsePreds));
    return node;
  }

  // The node kind cannot change before we return, so these choices are fixed
  // for the whole loop.
  const PredCombinerKind collapseKind =
      isAnd ? PredCombinerKind::False : PredCombinerKind::True;
  const PredCombinerKind eraseKind =
      isAnd ? PredCombinerKind::True : PredCombinerKind::False;
  const auto &collapseSet = isAnd ? knownFalsePreds : knownTruePreds;
  const auto &eraseSet = isAnd ? knownTruePreds : knownFalsePreds;

  for (PredNode *child : pending) {
    PredNode *simplified =
        propagateGroundTruth(child, knownTruePreds, knownFalsePreds);

    const bool collapses = simplified->kind == collapseKind ||
                           collapseSet.count(simplified->predicate) != 0;
    if (collapses) {
      node->kind = collapseKind;
      node->children.clear();
      return node;
    }

    const bool erasable = simplified->kind == eraseKind ||
                          eraseSet.count(simplified->predicate) != 0;
    if (!erasable)
      node->children.push_back(simplified);
  }
  return node;
}
/// Joins `children` with the binary operator `combiner`.
///
/// - No children: returns `init`.
/// - One child: returns it unchanged (no parentheses added).
/// - Otherwise: returns "(c0) <combiner> (c1) <combiner> (c2) ...".
static std::string combineBinary(ArrayRef<std::string> children,
                                 const std::string &combiner,
                                 std::string init) {
  if (children.empty())
    return init;
  if (children.size() == 1)
    return children.front();

  std::string joined = "(" + children.front() + ")";
  for (const std::string &operand : children.drop_front()) {
    joined += ' ';
    joined += combiner;
    joined += " (";
    joined += operand;
    joined += ')';
  }
  return joined;
}
/// Returns the negation "!(<child>)" of the single expression in `children`.
/// Exactly one child is expected (checked in assert-enabled builds).
static std::string combineNot(ArrayRef<std::string> children) {
  assert(children.size() == 1 && "expected exactly one child predicate of Neg");
  const std::string &operand = children.front();
  return "!(" + operand + ")";
}
static std::string getCombinedCondition(const PredNode &root) {
if (root.kind == PredCombinerKind::Leaf)
return root.expr;
if (root.kind == PredCombinerKind::True)
return "true";
if (root.kind == PredCombinerKind::False)
return "false";
llvm::SmallVector<std::string, 4> childExpressions;
childExpressions.reserve(root.children.size());
for (const auto &child : root.children)
childExpressions.push_back(getCombinedCondition(*child));
if (root.kind == PredCombinerKind::And)
return combineBinary(childExpressions, "&&", "true");
if (root.kind == PredCombinerKind::Or)
return combineBinary(childExpressions, "||", "false");
if (root.kind == PredCombinerKind::Not)
return combineNot(childExpressions);
if (root.kind == PredCombinerKind::Concat) {
assert(childExpressions.size() == 1 &&
"ConcatPred should only have one child");
return root.prefix + childExpressions.front() + root.suffix;
}
if (root.kind == PredCombinerKind::SubstLeaves) {
assert(childExpressions.size() == 1 &&
"substitution predicate must have one child");
return childExpressions[0];
}
llvm::PrintFatalError(root.predicate->getLoc(), "unsupported predicate kind");
}
std::string CombinedPred::getConditionImpl() const {
llvm::SpecificBumpPtrAllocator<PredNode> allocator;
auto *predicateTree = buildPredicateTree(*this, allocator, {});
predicateTree =
propagateGroundTruth(predicateTree,
llvm::SmallPtrSet<Pred *, 2>(),
llvm::SmallPtrSet<Pred *, 2>());
return getCombinedCondition(*predicateTree);
}
/// Returns the value of the record's 'pattern' field: the text to search for
/// in leaf predicates.
StringRef SubstLeavesPred::getPattern() const {
  const StringRef pattern = def->getValueAsString("pattern");
  return pattern;
}
/// Returns the value of the record's 'replacement' field: the text that
/// replaces each occurrence of the pattern.
StringRef SubstLeavesPred::getReplacement() const {
  const StringRef replacement = def->getValueAsString("replacement");
  return replacement;
}
/// Returns the value of the record's 'prefix' field: the text placed before
/// the child predicate's condition.
StringRef ConcatPred::getPrefix() const {
  const StringRef prefix = def->getValueAsString("prefix");
  return prefix;
}
/// Returns the value of the record's 'suffix' field: the text placed after
/// the child predicate's condition.
StringRef ConcatPred::getSuffix() const {
  const StringRef suffix = def->getValueAsString("suffix");
  return suffix;
}