#include "mlir/Tools/PDLL/AST/Context.h"
#include "mlir/Tools/PDLL/AST/Nodes.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/SaveAndRestore.h"
#include "llvm/Support/ScopedPrinter.h"
#include <optional>
using namespace mlir;
using namespace mlir::pdll::ast;
namespace {
class NodePrinter {
public:
NodePrinter(raw_ostream &os) : os(os) {}
void print(Type type);
void print(const Node *node);
private:
template <typename RangeT,
std::enable_if_t<!std::is_convertible<RangeT, const Node *>::value>
* = nullptr>
void printChildren(RangeT &&range) {
if (range.empty())
return;
auto it = std::begin(range);
for (unsigned i = 0, e = llvm::size(range) - 1; i < e; ++i, ++it)
print(*it);
elementIndentStack.back() = true;
print(*it);
}
template <typename RangeT, typename... OthersT,
std::enable_if_t<std::is_convertible<RangeT, const Node *>::value>
* = nullptr>
void printChildren(RangeT &&range, OthersT &&...others) {
printChildren(ArrayRef<const Node *>({range, others...}));
}
template <typename RangeT>
void printChildren(StringRef label, RangeT &&range) {
if (range.empty())
return;
elementIndentStack.reserve(elementIndentStack.size() + 1);
llvm::SaveAndRestore lastElement(elementIndentStack.back(), true);
printIndent();
os << label << "`\n";
elementIndentStack.push_back( false);
printChildren(std::forward<RangeT>(range));
elementIndentStack.pop_back();
}
void printImpl(const CompoundStmt *stmt);
void printImpl(const EraseStmt *stmt);
void printImpl(const LetStmt *stmt);
void printImpl(const ReplaceStmt *stmt);
void printImpl(const ReturnStmt *stmt);
void printImpl(const RewriteStmt *stmt);
void printImpl(const AttributeExpr *expr);
void printImpl(const CallExpr *expr);
void printImpl(const DeclRefExpr *expr);
void printImpl(const MemberAccessExpr *expr);
void printImpl(const OperationExpr *expr);
void printImpl(const RangeExpr *expr);
void printImpl(const TupleExpr *expr);
void printImpl(const TypeExpr *expr);
void printImpl(const AttrConstraintDecl *decl);
void printImpl(const OpConstraintDecl *decl);
void printImpl(const TypeConstraintDecl *decl);
void printImpl(const TypeRangeConstraintDecl *decl);
void printImpl(const UserConstraintDecl *decl);
void printImpl(const ValueConstraintDecl *decl);
void printImpl(const ValueRangeConstraintDecl *decl);
void printImpl(const NamedAttributeDecl *decl);
void printImpl(const OpNameDecl *decl);
void printImpl(const PatternDecl *decl);
void printImpl(const UserRewriteDecl *decl);
void printImpl(const VariableDecl *decl);
void printImpl(const Module *module);
void printIndent() {
if (elementIndentStack.empty())
return;
for (bool isLastElt : llvm::ArrayRef(elementIndentStack).drop_back())
os << (isLastElt ? " " : " |");
os << (elementIndentStack.back() ? " `" : " |");
}
raw_ostream &os;
SmallVector<bool> elementIndentStack;
};
}
void NodePrinter::print(Type type) {
if (!type) {
os << "Type<NULL>";
return;
}
TypeSwitch<Type>(type)
.Case([&](AttributeType) { os << "Attr"; })
.Case([&](ConstraintType) { os << "Constraint"; })
.Case([&](OperationType type) {
os << "Op";
if (std::optional<StringRef> name = type.getName())
os << "<" << *name << ">";
})
.Case([&](RangeType type) {
print(type.getElementType());
os << "Range";
})
.Case([&](RewriteType) { os << "Rewrite"; })
.Case([&](TupleType type) {
os << "Tuple<";
llvm::interleaveComma(
llvm::zip(type.getElementNames(), type.getElementTypes()), os,
[&](auto it) {
if (!std::get<0>(it).empty())
os << std::get<0>(it) << ": ";
this->print(std::get<1>(it));
});
os << ">";
})
.Case([&](TypeType) { os << "Type"; })
.Case([&](ValueType) { os << "Value"; })
.Default([](Type) { llvm_unreachable("unknown AST type"); });
}
void NodePrinter::print(const Node *node) {
printIndent();
os << "-";
elementIndentStack.push_back( false);
TypeSwitch<const Node *>(node)
.Case<
const CompoundStmt, const EraseStmt, const LetStmt, const ReplaceStmt,
const ReturnStmt, const RewriteStmt,
const AttributeExpr, const CallExpr, const DeclRefExpr,
const MemberAccessExpr, const OperationExpr, const RangeExpr,
const TupleExpr, const TypeExpr,
const AttrConstraintDecl, const OpConstraintDecl,
const TypeConstraintDecl, const TypeRangeConstraintDecl,
const UserConstraintDecl, const ValueConstraintDecl,
const ValueRangeConstraintDecl, const NamedAttributeDecl,
const OpNameDecl, const PatternDecl, const UserRewriteDecl,
const VariableDecl,
const Module>([&](auto derivedNode) { this->printImpl(derivedNode); })
.Default([](const Node *) { llvm_unreachable("unknown AST node"); });
elementIndentStack.pop_back();
}
void NodePrinter::printImpl(const CompoundStmt *stmt) {
os << "CompoundStmt " << stmt << "\n";
printChildren(stmt->getChildren());
}
void NodePrinter::printImpl(const EraseStmt *stmt) {
os << "EraseStmt " << stmt << "\n";
printChildren(stmt->getRootOpExpr());
}
void NodePrinter::printImpl(const LetStmt *stmt) {
os << "LetStmt " << stmt << "\n";
printChildren(stmt->getVarDecl());
}
void NodePrinter::printImpl(const ReplaceStmt *stmt) {
os << "ReplaceStmt " << stmt << "\n";
printChildren(stmt->getRootOpExpr());
printChildren("ReplValues", stmt->getReplExprs());
}
void NodePrinter::printImpl(const ReturnStmt *stmt) {
os << "ReturnStmt " << stmt << "\n";
printChildren(stmt->getResultExpr());
}
void NodePrinter::printImpl(const RewriteStmt *stmt) {
os << "RewriteStmt " << stmt << "\n";
printChildren(stmt->getRootOpExpr(), stmt->getRewriteBody());
}
void NodePrinter::printImpl(const AttributeExpr *expr) {
os << "AttributeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n";
}
void NodePrinter::printImpl(const CallExpr *expr) {
os << "CallExpr " << expr << " Type<";
print(expr->getType());
os << ">";
if (expr->getIsNegated())
os << " Negated";
os << "\n";
printChildren(expr->getCallableExpr());
printChildren("Arguments", expr->getArguments());
}
void NodePrinter::printImpl(const DeclRefExpr *expr) {
os << "DeclRefExpr " << expr << " Type<";
print(expr->getType());
os << ">\n";
printChildren(expr->getDecl());
}
void NodePrinter::printImpl(const MemberAccessExpr *expr) {
os << "MemberAccessExpr " << expr << " Member<" << expr->getMemberName()
<< "> Type<";
print(expr->getType());
os << ">\n";
printChildren(expr->getParentExpr());
}
void NodePrinter::printImpl(const OperationExpr *expr) {
os << "OperationExpr " << expr << " Type<";
print(expr->getType());
os << ">\n";
printChildren(expr->getNameDecl());
printChildren("Operands", expr->getOperands());
printChildren("Result Types", expr->getResultTypes());
printChildren("Attributes", expr->getAttributes());
}
void NodePrinter::printImpl(const RangeExpr *expr) {
os << "RangeExpr " << expr << " Type<";
print(expr->getType());
os << ">\n";
printChildren(expr->getElements());
}
void NodePrinter::printImpl(const TupleExpr *expr) {
os << "TupleExpr " << expr << " Type<";
print(expr->getType());
os << ">\n";
printChildren(expr->getElements());
}
void NodePrinter::printImpl(const TypeExpr *expr) {
os << "TypeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n";
}
void NodePrinter::printImpl(const AttrConstraintDecl *decl) {
os << "AttrConstraintDecl " << decl << "\n";
if (const auto *typeExpr = decl->getTypeExpr())
printChildren(typeExpr);
}
void NodePrinter::printImpl(const OpConstraintDecl *decl) {
os << "OpConstraintDecl " << decl << "\n";
printChildren(decl->getNameDecl());
}
void NodePrinter::printImpl(const TypeConstraintDecl *decl) {
os << "TypeConstraintDecl " << decl << "\n";
}
void NodePrinter::printImpl(const TypeRangeConstraintDecl *decl) {
os << "TypeRangeConstraintDecl " << decl << "\n";
}
void NodePrinter::printImpl(const UserConstraintDecl *decl) {
os << "UserConstraintDecl " << decl << " Name<" << decl->getName().getName()
<< "> ResultType<" << decl->getResultType() << ">";
if (std::optional<StringRef> codeBlock = decl->getCodeBlock()) {
os << " Code<";
llvm::printEscapedString(*codeBlock, os);
os << ">";
}
os << "\n";
printChildren("Inputs", decl->getInputs());
printChildren("Results", decl->getResults());
if (const CompoundStmt *body = decl->getBody())
printChildren(body);
}
void NodePrinter::printImpl(const ValueConstraintDecl *decl) {
os << "ValueConstraintDecl " << decl << "\n";
if (const auto *typeExpr = decl->getTypeExpr())
printChildren(typeExpr);
}
void NodePrinter::printImpl(const ValueRangeConstraintDecl *decl) {
os << "ValueRangeConstraintDecl " << decl << "\n";
if (const auto *typeExpr = decl->getTypeExpr())
printChildren(typeExpr);
}
void NodePrinter::printImpl(const NamedAttributeDecl *decl) {
os << "NamedAttributeDecl " << decl << " Name<" << decl->getName().getName()
<< ">\n";
printChildren(decl->getValue());
}
void NodePrinter::printImpl(const OpNameDecl *decl) {
os << "OpNameDecl " << decl;
if (std::optional<StringRef> name = decl->getName())
os << " Name<" << *name << ">";
os << "\n";
}
void NodePrinter::printImpl(const PatternDecl *decl) {
os << "PatternDecl " << decl;
if (const Name *name = decl->getName())
os << " Name<" << name->getName() << ">";
if (std::optional<uint16_t> benefit = decl->getBenefit())
os << " Benefit<" << *benefit << ">";
if (decl->hasBoundedRewriteRecursion())
os << " Recursion";
os << "\n";
printChildren(decl->getBody());
}
void NodePrinter::printImpl(const UserRewriteDecl *decl) {
os << "UserRewriteDecl " << decl << " Name<" << decl->getName().getName()
<< "> ResultType<" << decl->getResultType() << ">";
if (std::optional<StringRef> codeBlock = decl->getCodeBlock()) {
os << " Code<";
llvm::printEscapedString(*codeBlock, os);
os << ">";
}
os << "\n";
printChildren("Inputs", decl->getInputs());
printChildren("Results", decl->getResults());
if (const CompoundStmt *body = decl->getBody())
printChildren(body);
}
void NodePrinter::printImpl(const VariableDecl *decl) {
os << "VariableDecl " << decl << " Name<" << decl->getName().getName()
<< "> Type<";
print(decl->getType());
os << ">\n";
if (Expr *initExpr = decl->getInitExpr())
printChildren(initExpr);
auto constraints =
llvm::map_range(decl->getConstraints(),
[](const ConstraintRef &ref) { return ref.constraint; });
printChildren("Constraints", constraints);
}
void NodePrinter::printImpl(const Module *module) {
os << "Module " << module << "\n";
printChildren(module->getChildren());
}
void Node::print(raw_ostream &os) const { NodePrinter(os).print(this); }
void Type::print(raw_ostream &os) const { NodePrinter(os).print(*this); }