#include "mlir/Tools/PDLL/CodeGen/CPPGen.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDL/IR/PDLOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Tools/PDLL/AST/Nodes.h"
#include "mlir/Tools/PDLL/ODS/Operation.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include <optional>
using namespace mlir;
using namespace mlir::pdll;
namespace {
class CodeGen {
public:
CodeGen(raw_ostream &os) : os(os) {}
void generate(const ast::Module &astModule, ModuleOp module);
private:
void generate(pdl::PatternOp pattern, StringRef patternName,
StringSet<> &nativeFunctions);
void generateConstraintAndRewrites(const ast::Module &astModule,
ModuleOp module,
StringSet<> &nativeFunctions);
void generate(const ast::UserConstraintDecl *decl,
StringSet<> &nativeFunctions);
void generate(const ast::UserRewriteDecl *decl, StringSet<> &nativeFunctions);
void generateConstraintOrRewrite(const ast::CallableDecl *decl,
bool isConstraint,
StringSet<> &nativeFunctions);
StringRef getNativeTypeName(ast::Type type);
StringRef getNativeTypeName(ast::VariableDecl *decl);
raw_ostream &os;
};
}
/// Generate the C++ for the given PDLL module and its lowered PDL form.
///
/// The output consists of:
///   * the native constraint/rewrite functions used by the patterns,
///   * an anonymous namespace holding one `PDLPatternModule` struct per
///     `pdl.pattern`, and
///   * a `populateGeneratedPDLLPatterns` template that adds every generated
///     pattern to a `::mlir::RewritePatternSet`, in module order.
///
/// Patterns with a symbol name use it as their C++ struct name. Unnamed
/// patterns get a generated `GeneratedPDLLPattern<N>` name that never clashes
/// with any explicit pattern name in the module.
void CodeGen::generate(const ast::Module &astModule, ModuleOp module) {
  SetVector<std::string, SmallVector<std::string>, StringSet<>> patternNames;
  StringSet<> nativeFunctions;

  // Generate code for any native functions within the module first, so the
  // pattern structs below can register them.
  generateConstraintAndRewrites(astModule, module, nativeFunctions);

  // Reserve every explicit pattern name up front. Without this, an unnamed
  // pattern visited earlier could claim e.g. `GeneratedPDLLPattern0`. A later
  // pattern explicitly named `GeneratedPDLLPattern0` would then fail to insert,
  // and `patternNames.back()` would hand it the most recently inserted name
  // instead, emitting a duplicate struct definition.
  StringSet<> explicitNames;
  for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>())
    if (std::optional<StringRef> name = pattern.getSymName())
      explicitNames.insert(*name);

  os << "namespace {\n";
  std::string basePatternName = "GeneratedPDLLPattern";
  int patternIndex = 0;
  for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
    // If the pattern has a name, use that. Otherwise, generate a name that is
    // neither reserved by an explicit pattern nor already used.
    if (std::optional<StringRef> patternName = pattern.getSymName()) {
      patternNames.insert(patternName->str());
    } else {
      std::string name;
      do {
        name = (basePatternName + Twine(patternIndex++)).str();
      } while (explicitNames.contains(name) || !patternNames.insert(name));
    }

    generate(pattern, patternNames.back(), nativeFunctions);
  }
  os << "} // end namespace\n\n";

  // Emit a function that adds every generated pattern to a pattern list.
  os << "template <typename... ConfigsT>\n"
        "static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns("
        "::mlir::RewritePatternSet &patterns, ConfigsT &&...configs) {\n";
  for (const auto &name : patternNames)
    os << " patterns.add<" << name
       << ">(patterns.getContext(), configs...);\n";
  os << "}\n";
}
/// Emit a `::mlir::PDLPatternModule` subclass named `patternName` for the
/// given PDL pattern.
///
/// The generated constructor parses the pattern's textual IR, printed with
/// debug info and embedded as an `R"mlir(...)mlir"` raw string. It then
/// registers each native constraint/rewrite function that the pattern
/// references and that appears in `nativeFunctions`, at most once each.
void CodeGen::generate(pdl::PatternOp pattern, StringRef patternName,
                       StringSet<> &nativeFunctions) {
  // Struct header and constructor prologue. `{0}` is substituted with the
  // pattern name; `{{` is formatv's escape for a literal '{'.
  const char *patternClassStartStr = R"(
struct {0} : ::mlir::PDLPatternModule {{
template <typename... ConfigsT>
{0}(::mlir::MLIRContext *context, ConfigsT &&...configs)
: ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>(
)";
  os << llvm::formatv(patternClassStartStr, patternName);

  // Embed the pattern IR verbatim, then close the base-class initializer.
  os << "R\"mlir(";
  pattern->print(os, OpPrintingFlags().enableDebugInfo());
  os << "\n )mlir\", context), std::forward<ConfigsT>(configs)...) {\n";

  // Tracks what this constructor has already registered, to avoid duplicates.
  StringSet<> alreadyRegistered;
  auto emitRegistration = [&](StringRef fnName, StringRef kind) {
    // Only functions that were generated as native code get registered here.
    if (!nativeFunctions.count(fnName))
      return;
    if (!alreadyRegistered.insert(fnName).second)
      return;
    os << " register" << kind << "Function(\"" << fnName << "\", " << fnName
       << "PDLFn);\n";
  };

  pattern.walk([&](Operation *op) {
    TypeSwitch<Operation *>(op)
        .Case([&](pdl::ApplyNativeConstraintOp constraintOp) {
          emitRegistration(constraintOp.getName(), "Constraint");
        })
        .Case([&](pdl::ApplyNativeRewriteOp rewriteOp) {
          emitRegistration(rewriteOp.getName(), "Rewrite");
        });
  });
  os << " }\n};\n\n";
}
/// Emit C++ definitions for the user constraints and rewrites that:
///   1. are direct children of `astModule`,
///   2. have a native code block, and
///   3. are named by some `pdl.apply_native_constraint` or
///      `pdl.apply_native_rewrite` op inside `module`.
/// Every emitted function name is recorded in `nativeFunctions`, because the
/// emitters it calls insert it there.
void CodeGen::generateConstraintAndRewrites(const ast::Module &astModule,
                                            ModuleOp module,
                                            StringSet<> &nativeFunctions) {
  // Collect the names of every native function invoked by the PDL IR.
  StringSet<> referencedFns;
  module.walk([&](Operation *op) {
    if (auto constraintOp = dyn_cast<pdl::ApplyNativeConstraintOp>(op))
      referencedFns.insert(constraintOp.getName());
    else if (auto rewriteOp = dyn_cast<pdl::ApplyNativeRewriteOp>(op))
      referencedFns.insert(rewriteOp.getName());
  });

  // Emit a callable only if it has a code block and the IR references it.
  // `this->` dispatches to the constraint or rewrite overload of generate.
  auto emitIfReferenced = [&](const auto *callable) {
    if (!callable->getCodeBlock())
      return;
    if (!referencedFns.contains(callable->getName().getName()))
      return;
    this->generate(callable, nativeFunctions);
  };

  for (const ast::Decl *child : astModule.getChildren()) {
    TypeSwitch<const ast::Decl *>(child)
        .Case([&](const ast::UserConstraintDecl *constraint) {
          emitIfReferenced(constraint);
        })
        .Case([&](const ast::UserRewriteDecl *rewrite) {
          emitIfReferenced(rewrite);
        });
  }
}
/// Emit the native C++ function backing a user-defined constraint. The
/// generated function returns `::llvm::LogicalResult`.
void CodeGen::generate(const ast::UserConstraintDecl *decl,
                       StringSet<> &nativeFunctions) {
  const auto *callable = cast<ast::CallableDecl>(decl);
  generateConstraintOrRewrite(callable, /*isConstraint=*/true,
                              nativeFunctions);
}
/// Emit the native C++ function backing a user-defined rewrite. The return
/// type is derived from the rewrite's declared results.
void CodeGen::generate(const ast::UserRewriteDecl *decl,
                       StringSet<> &nativeFunctions) {
  const auto *callable = cast<ast::CallableDecl>(decl);
  generateConstraintOrRewrite(callable, /*isConstraint=*/false,
                              nativeFunctions);
}
/// Map a PDLL type to the C++ type used for it in generated native functions.
///
/// For operation types backed by an ODS operation, that operation's native
/// class name is used; other operation types map to `::mlir::Operation *`.
/// Only the six types listed below are handled. NOTE(review): presumably
/// callers never pass any other PDLL type here; TypeSwitch asserts if one
/// does.
StringRef CodeGen::getNativeTypeName(ast::Type type) {
  return llvm::TypeSwitch<ast::Type, StringRef>(type)
      .Case([](ast::AttributeType) -> StringRef {
        return "::mlir::Attribute";
      })
      .Case([](ast::OperationType opType) -> StringRef {
        const auto *odsOp = opType.getODSOperation();
        return odsOp ? odsOp->getNativeClassName()
                     : StringRef("::mlir::Operation *");
      })
      .Case([](ast::TypeType) -> StringRef { return "::mlir::Type"; })
      .Case([](ast::ValueType) -> StringRef { return "::mlir::Value"; })
      .Case([](ast::TypeRangeType) -> StringRef {
        return "::mlir::TypeRange";
      })
      .Case([](ast::ValueRangeType) -> StringRef {
        return "::mlir::ValueRange";
      });
}
/// Return the C++ type name for a variable.
///
/// The first user-defined constraint on the variable decides the name:
///   * if it provides a native type for its input 0, that type is used;
///   * otherwise, the native type of that constraint's first input is used
///     (resolved recursively).
/// If the variable has no user-defined constraint, its PDLL type is mapped
/// directly.
StringRef CodeGen::getNativeTypeName(ast::VariableDecl *decl) {
  for (ast::ConstraintRef &constraintRef : decl->getConstraints()) {
    auto *userConstraint =
        dyn_cast<ast::UserConstraintDecl>(constraintRef.constraint);
    if (!userConstraint)
      continue;

    std::optional<StringRef> nativeName =
        userConstraint->getNativeInputType(0);
    return nativeName ? *nativeName
                      : getNativeTypeName(userConstraint->getInputs()[0]);
  }

  // No user constraint supplied a name; fall back to the variable's own type.
  return getNativeTypeName(decl->getType());
}
/// Emit `static <ret> <name>PDLFn(::mlir::PatternRewriter &rewriter, ...)`
/// whose body is the callable's trimmed code block, and record `<name>` in
/// `nativeFunctions`.
///
/// The return type is chosen as follows:
///   * constraints return `::llvm::LogicalResult`;
///   * a rewrite with no results returns `void`;
///   * a rewrite with one result returns that result's native type;
///   * a rewrite with several results returns a `std::tuple` of their types.
///
/// The caller must ensure `decl` has a code block; it is dereferenced
/// unconditionally.
void CodeGen::generateConstraintOrRewrite(const ast::CallableDecl *decl,
                                          bool isConstraint,
                                          StringSet<> &nativeFunctions) {
  StringRef fnName = decl->getName()->getName();
  nativeFunctions.insert(fnName);

  os << "static ";
  if (isConstraint) {
    os << "::llvm::LogicalResult";
  } else {
    ArrayRef<ast::VariableDecl *> resultDecls = decl->getResults();
    switch (resultDecls.size()) {
    case 0:
      os << "void";
      break;
    case 1:
      os << getNativeTypeName(resultDecls.front());
      break;
    default:
      os << "std::tuple<";
      llvm::interleaveComma(resultDecls, os,
                            [&](ast::VariableDecl *resultDecl) {
                              os << getNativeTypeName(resultDecl);
                            });
      os << ">";
      break;
    }
  }

  // The rewriter is always the first parameter. Each user input follows,
  // prefixed by its own ", " separator.
  os << " " << fnName << "PDLFn(::mlir::PatternRewriter &rewriter";
  for (ast::VariableDecl *inputDecl : decl->getInputs())
    os << ", " << getNativeTypeName(inputDecl) << " "
       << inputDecl->getName().getName();
  os << ") {\n";

  os << " " << decl->getCodeBlock()->trim() << "\n}\n\n";
}
/// Public entry point: write the C++ generated from `astModule` and its PDL
/// lowering `module` to `os`.
void mlir::pdll::codegenPDLLToCPP(const ast::Module &astModule, ModuleOp module,
                                  raw_ostream &os) {
  CodeGen(os).generate(astModule, module);
}