#include "AttrOrTypeFormatGen.h"
#include "mlir/TableGen/AttrOrTypeDef.h"
#include "mlir/TableGen/Class.h"
#include "mlir/TableGen/CodeGenHelpers.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Interfaces.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/TableGenBackend.h"
#define DEBUG_TYPE "mlir-tblgen-attrortypedefgen"
using namespace mlir;
using namespace mlir::tblgen;
static void collectAllDefs(StringRef selectedDialect,
std::vector<llvm::Record *> records,
SmallVectorImpl<AttrOrTypeDef> &resultDefs) {
if (records.empty())
return;
auto defs = llvm::map_range(
records, [&](const llvm::Record *rec) { return AttrOrTypeDef(rec); });
if (selectedDialect.empty()) {
if (!llvm::all_equal(llvm::map_range(
defs, [](const auto &def) { return def.getDialect(); }))) {
llvm::PrintFatalError("defs belonging to more than one dialect. Must "
"select one via '--(attr|type)defs-dialect'");
}
resultDefs.assign(defs.begin(), defs.end());
} else {
auto dialectDefs = llvm::make_filter_range(defs, [&](const auto &def) {
return def.getDialect().getName() == selectedDialect;
});
resultDefs.assign(dialectDefs.begin(), dialectDefs.end());
}
}
namespace {
/// Builds the C++ class model for one attribute or type definition and writes
/// out its declaration and definition. All generation work happens in the
/// constructor; `emitDecl`/`emitDef` only print the result.
class DefGen {
public:
/// Populate the def class (and its storage class, when the def has
/// parameters) from `def`.
DefGen(const AttrOrTypeDef &def);
/// Write the class declaration to `os`. When a storage class is being
/// generated, it is forward-declared first inside the storage namespace.
void emitDecl(raw_ostream &os) const {
if (storageCls && def.genStorageClass()) {
NamespaceEmitter ns(os, def.getStorageNamespace());
os << "struct " << def.getStorageClassName() << ";\n";
}
defCls.writeDeclTo(os);
}
/// Write the out-of-line definitions to `os`. When a storage class is being
/// generated, its full declaration is written first inside the storage
/// namespace.
void emitDef(raw_ostream &os) const {
if (storageCls && def.genStorageClass()) {
NamespaceEmitter ns(os, def.getStorageNamespace());
storageCls->writeDeclTo(os);
}
defCls.writeDefTo(os);
}
private:
/// Add the `...Base<...>` parent class, including one template argument per
/// trait.
void createParentWithTraits();
/// Add the public visibility marker, `using Base::Base`, and any extra class
/// declarations/definitions.
void emitTopLevelDeclarations();
/// Add the static `name` string constant.
void emitName();
/// Add the static `dialectName` string constant.
void emitDialectName();
/// Add default and custom builders (plus checked variants when a verifier is
/// declared).
void emitBuilders();
/// Declare the static `verify` method and expose `Base::getChecked`.
void emitVerifier();
/// Add `getMnemonic` and, when a format is used, `parse`/`print`.
void emitParserPrinter();
/// Add one accessor per parameter.
void emitAccessors();
/// Declare the methods of interface traits that request declaration.
void emitInterfaceMethods();
/// Add the default `get` builder.
void emitDefaultBuilder();
/// Add the default `getChecked` builder.
void emitCheckedBuilder();
/// Add a `get` method for a user-specified builder.
void emitCustomBuilder(const AttrOrTypeBuilder &builder);
/// Add a `getChecked` method for a user-specified builder.
void emitCheckedCustomBuilder(const AttrOrTypeBuilder &builder);
/// Declare the methods of one interface trait.
void emitTraitMethods(const InterfaceTrait &trait);
/// Declare a single interface method.
void emitTraitMethod(const InterfaceMethod &method);
/// Populate the storage class: constructor, key type, comparison, hashing,
/// construction, and fields.
void emitStorageClass();
/// Add the storage class constructor.
void emitStorageConstructor();
/// Add the `KeyTy` alias and the `getAsKey` method.
void emitKeyType();
/// Add `operator==` against a `KeyTy`.
void emitEquals();
/// Add the static `hashKey` method.
void emitHashKey();
/// Add the static `construct` method.
void emitConstruct();
/// Return `prefix` followed by one method parameter per def parameter.
SmallVector<MethodParameter>
getBuilderParams(std::initializer_list<MethodParameter> prefix) const;
/// The definition being generated.
const AttrOrTypeDef &def;
/// The definition's parameters.
ArrayRef<AttrOrTypeParameter> params;
/// The generated attribute/type class.
Class defCls;
/// The generated storage class; engaged only when the def has parameters.
std::optional<Class> storageCls;
/// "Attribute" for attribute defs, "Type" for type defs.
StringRef valueType;
/// "Attr" for attribute defs, "Type" for type defs.
StringRef defType;
};
}
/// Builds the full class model for `def`.
///
/// Reports a fatal error if any parameter is anonymous. The emit* calls run in
/// a fixed sequence; each one appends members to the class being built.
DefGen::DefGen(const AttrOrTypeDef &def)
    : def(def), params(def.getParameters()), defCls(def.getCppClassName()),
      valueType(isa<AttrDef>(def) ? "Attribute" : "Type"),
      defType(isa<AttrDef>(def) ? "Attr" : "Type") {
  // Every parameter name is used in generated code, so all must be named.
  for (const AttrOrTypeParameter &param : def.getParameters())
    if (param.isAnonymous())
      llvm::PrintFatalError("all parameters must have a name");

  // A storage class is only modelled for defs that carry parameters.
  // NOTE(review): the boolean flag presumably selects `struct` over `class`
  // (the forward declaration in emitDecl uses `struct`) - confirm in Class.
  if (def.getNumParameters() > 0)
    storageCls.emplace(def.getStorageClassName(), true);

  // Parent class with traits, then using-declarations and extra declarations.
  createParentWithTraits();
  emitTopLevelDeclarations();

  // Builders are only generated for defs with parameters.
  if (storageCls)
    emitBuilders();

  // Static name constants.
  emitName();
  emitDialectName();

  // Verifier declaration, when requested and there are parameters to verify.
  if (storageCls && def.genVerifyDecl())
    emitVerifier();

  // Mnemonic getter plus any parser/printer.
  if (def.getMnemonic())
    emitParserPrinter();

  // Parameter accessors.
  if (def.genAccessors())
    emitAccessors();

  // Interface method declarations.
  emitInterfaceMethods();
  defCls.finalize();

  // Finally populate the storage class if one should be generated.
  if (storageCls && def.genStorageClass())
    emitStorageClass();
}
void DefGen::createParentWithTraits() {
ParentClass defParent(strfmt("::mlir::{0}::{1}Base", valueType, defType));
defParent.addTemplateParam(def.getCppClassName());
defParent.addTemplateParam(def.getCppBaseClassName());
defParent.addTemplateParam(storageCls
? strfmt("{0}::{1}", def.getStorageNamespace(),
def.getStorageClassName())
: strfmt("::mlir::{0}Storage", valueType));
for (auto &trait : def.getTraits()) {
defParent.addTemplateParam(
isa<NativeTrait>(&trait)
? cast<NativeTrait>(&trait)->getFullyQualifiedTraitName()
: cast<InterfaceTrait>(&trait)->getFullyQualifiedTraitName());
}
defCls.addParent(std::move(defParent));
}
static std::string formatExtraDeclarations(const AttrOrTypeDef &def) {
SmallVector<StringRef> extraDeclarations;
for (const auto &trait : def.getTraits()) {
if (auto *attrOrTypeTrait = dyn_cast<tblgen::NativeTrait>(&trait)) {
StringRef value = attrOrTypeTrait->getExtraConcreteClassDeclaration();
if (value.empty())
continue;
extraDeclarations.push_back(value);
}
}
if (std::optional<StringRef> extraDecl = def.getExtraDecls()) {
extraDeclarations.push_back(*extraDecl);
}
return llvm::join(extraDeclarations, "\n");
}
static std::string formatExtraDefinitions(const AttrOrTypeDef &def) {
SmallVector<StringRef> extraDefinitions;
for (const auto &trait : def.getTraits()) {
if (auto *attrOrTypeTrait = dyn_cast<tblgen::NativeTrait>(&trait)) {
StringRef value = attrOrTypeTrait->getExtraConcreteClassDefinition();
if (value.empty())
continue;
extraDefinitions.push_back(value);
}
}
if (std::optional<StringRef> extraDef = def.getExtraDefs()) {
extraDefinitions.push_back(*extraDef);
}
FmtContext ctx = FmtContext().addSubst("cppClass", def.getCppClassName());
return tgfmt(llvm::join(extraDefinitions, "\n"), &ctx).str();
}
/// Opens the public section, inherits the base constructors, and attaches
/// the extra class declarations together with their definitions.
void DefGen::emitTopLevelDeclarations() {
  defCls.declare<VisibilityDeclaration>(Visibility::Public);
  defCls.declare<UsingDeclaration>("Base::Base");

  std::string declarations = formatExtraDeclarations(def);
  std::string definitions = formatExtraDefinitions(def);
  defCls.declare<ExtraClassDeclaration>(std::move(declarations),
                                        std::move(definitions));
}
/// Declares the static `name` string constant, holding the attribute name for
/// attribute defs and the type name for type defs.
void DefGen::emitName() {
  StringRef name = isa<AttrDef>(def) ? cast<AttrDef>(&def)->getAttrName()
                                     : cast<TypeDef>(&def)->getTypeName();
  defCls.declare<ExtraClassDeclaration>(strfmt(
      "static constexpr ::llvm::StringLiteral name = \"{0}\";\n", name));
}
void DefGen::emitDialectName() {
std::string decl =
strfmt("static constexpr ::llvm::StringLiteral dialectName = \"{0}\";\n",
def.getDialect().getName());
defCls.declare<ExtraClassDeclaration>(std::move(decl));
}
/// Emits the default builder (unless suppressed) and every custom builder.
/// Each builder also gets a checked variant when the def declares a verifier.
void DefGen::emitBuilders() {
  const bool withChecked = def.genVerifyDecl();

  if (!def.skipDefaultBuilders()) {
    emitDefaultBuilder();
    if (withChecked)
      emitCheckedBuilder();
  }

  for (auto &customBuilder : def.getBuilders()) {
    emitCustomBuilder(customBuilder);
    if (withChecked)
      emitCheckedCustomBuilder(customBuilder);
  }
}
void DefGen::emitVerifier() {
defCls.declare<UsingDeclaration>("Base::getChecked");
defCls.declareStaticMethod(
"::llvm::LogicalResult", "verify",
getBuilderParams({{"::llvm::function_ref<::mlir::InFlightDiagnostic()>",
"emitError"}}));
}
void DefGen::emitParserPrinter() {
auto *mnemonic = defCls.addStaticMethod<Method::Constexpr>(
"::llvm::StringLiteral", "getMnemonic");
mnemonic->body().indent() << strfmt("return {\"{0}\"};", *def.getMnemonic());
bool hasAssemblyFormat = def.getAssemblyFormat().has_value();
if (!def.hasCustomAssemblyFormat() && !hasAssemblyFormat)
return;
SmallVector<MethodParameter> parserParams;
parserParams.emplace_back("::mlir::AsmParser &", "odsParser");
if (isa<AttrDef>(&def))
parserParams.emplace_back("::mlir::Type", "odsType");
auto *parser = defCls.addMethod(strfmt("::mlir::{0}", valueType), "parse",
hasAssemblyFormat ? Method::Static
: Method::StaticDeclaration,
std::move(parserParams));
auto props = hasAssemblyFormat ? Method::Const : Method::ConstDeclaration;
Method *printer =
defCls.addMethod("void", "print", props,
MethodParameter("::mlir::AsmPrinter &", "odsPrinter"));
if (hasAssemblyFormat)
return generateAttrOrTypeFormat(def, parser->body(), printer->body());
}
void DefGen::emitAccessors() {
for (auto ¶m : params) {
Method *m = defCls.addMethod(
param.getCppAccessorType(), param.getAccessorName(),
def.genStorageClass() ? Method::Const : Method::ConstDeclaration);
if (!def.genStorageClass())
continue;
m->body().indent() << "return getImpl()->" << param.getName() << ";";
}
}
void DefGen::emitInterfaceMethods() {
for (auto &traitDef : def.getTraits())
if (auto *trait = dyn_cast<InterfaceTrait>(&traitDef))
if (trait->shouldDeclareMethods())
emitTraitMethods(*trait);
}
/// Returns the parameter list of a builder-like method: the given `prefix`
/// parameters followed by one (C++ type, name) entry per def parameter.
SmallVector<MethodParameter>
DefGen::getBuilderParams(std::initializer_list<MethodParameter> prefix) const {
  SmallVector<MethodParameter> builderParams;
  builderParams.append(prefix.begin(), prefix.end());
  for (auto &param : params)
    builderParams.emplace_back(param.getCppType(), param.getName());
  return builderParams;
}
void DefGen::emitDefaultBuilder() {
Method *m = defCls.addStaticMethod(
def.getCppClassName(), "get",
getBuilderParams({{"::mlir::MLIRContext *", "context"}}));
MethodBody &body = m->body().indent();
auto scope = body.scope("return Base::get(context", ");");
for (const auto ¶m : params)
body << ", std::move(" << param.getName() << ")";
}
void DefGen::emitCheckedBuilder() {
Method *m = defCls.addStaticMethod(
def.getCppClassName(), "getChecked",
getBuilderParams(
{{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"},
{"::mlir::MLIRContext *", "context"}}));
MethodBody &body = m->body().indent();
auto scope = body.scope("return Base::getChecked(emitError, context", ");");
for (const auto ¶m : params)
body << ", " << param.getName();
}
/// Returns the parameter list for a custom builder method: `prefix`, then a
/// `context` parameter unless the builder infers its context, then the
/// builder's own parameters (type, name, and default value).
static SmallVector<MethodParameter>
getCustomBuilderParams(std::initializer_list<MethodParameter> prefix,
                       const AttrOrTypeBuilder &builder) {
  auto params = builder.getParameters();
  SmallVector<MethodParameter> builderParams;
  builderParams.append(prefix.begin(), prefix.end());
  if (!builder.hasInferredContextParameter())
    builderParams.emplace_back("::mlir::MLIRContext *", "context");
  for (auto &param : params) {
    // The name is dereferenced unconditionally; builder parameters are
    // expected to be named.
    builderParams.emplace_back(param.getCppType(), *param.getName(),
                               param.getDefaultValue());
  }
  return builderParams;
}
void DefGen::emitCustomBuilder(const AttrOrTypeBuilder &builder) {
auto props = builder.getBody() ? Method::Static : Method::StaticDeclaration;
StringRef returnType = def.getCppClassName();
if (std::optional<StringRef> builderReturnType = builder.getReturnType())
returnType = *builderReturnType;
Method *m = defCls.addMethod(returnType, "get", props,
getCustomBuilderParams({}, builder));
if (!builder.getBody())
return;
FmtContext ctx;
ctx.addSubst("_get", "Base::get");
if (!builder.hasInferredContextParameter())
ctx.addSubst("_ctxt", "context");
std::string bodyStr = tgfmt(*builder.getBody(), &ctx);
m->body().indent().getStream().printReindented(bodyStr);
}
/// Returns `str` with every non-overlapping occurrence of `from` replaced by
/// `to`, scanning left to right.
///
/// The search resumes after the inserted text, so a replacement that itself
/// contains `from` is not rescanned (the previous version looped forever in
/// that case). An empty `from` matches nothing and `str` is returned as is.
static std::string replaceInStr(std::string str, StringRef from, StringRef to) {
  if (from.empty())
    return str;
  size_t pos = 0;
  while ((pos = str.find(from.data(), pos, from.size())) != std::string::npos) {
    str.replace(pos, from.size(), to.data(), to.size());
    // Skip the freshly inserted text so it is never matched again.
    pos += to.size();
  }
  return str;
}
void DefGen::emitCheckedCustomBuilder(const AttrOrTypeBuilder &builder) {
auto props = builder.getBody() ? Method::Static : Method::StaticDeclaration;
StringRef returnType = def.getCppClassName();
if (std::optional<StringRef> builderReturnType = builder.getReturnType())
returnType = *builderReturnType;
Method *m = defCls.addMethod(
returnType, "getChecked", props,
getCustomBuilderParams(
{{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}},
builder));
if (!builder.getBody())
return;
FmtContext ctx;
if (!builder.hasInferredContextParameter())
ctx.addSubst("_ctxt", "context");
std::string bodyStr = replaceInStr(builder.getBody()->str(), "$_get(",
"Base::getChecked(emitError, ");
bodyStr = tgfmt(bodyStr, &ctx);
m->body().indent().getStream().printReindented(bodyStr);
}
void DefGen::emitTraitMethods(const InterfaceTrait &trait) {
auto alwaysDeclaredMethods = trait.getAlwaysDeclaredMethods();
StringSet<> alwaysDeclared;
alwaysDeclared.insert(alwaysDeclaredMethods.begin(),
alwaysDeclaredMethods.end());
Interface iface = trait.getInterface();
for (auto &method : iface.getMethods()) {
if (method.getBody() || (method.getDefaultImplementation() &&
!alwaysDeclared.count(method.getName())))
continue;
emitTraitMethod(method);
}
}
/// Declares a single interface method on the def class: a static declaration
/// for static interface methods, a const declaration otherwise.
void DefGen::emitTraitMethod(const InterfaceMethod &method) {
  auto props =
      method.isStatic() ? Method::StaticDeclaration : Method::ConstDeclaration;
  // Note: this local list shadows the `params` member on purpose; it holds
  // the interface method's arguments, not the def's parameters.
  SmallVector<MethodParameter> params;
  for (auto &param : method.getArguments())
    params.emplace_back(param.type, param.name);
  defCls.addMethod(method.getReturnType(), method.getName(), props,
                   std::move(params));
}
void DefGen::emitStorageConstructor() {
Constructor *ctor =
storageCls->addConstructor<Method::Inline>(getBuilderParams({}));
for (auto ¶m : params) {
std::string movedValue = ("std::move(" + param.getName() + ")").str();
ctor->addMemberInitializer(param.getName(), movedValue);
}
}
/// Declares `KeyTy` as a `std::tuple` of the parameters' C++ types and adds
/// the inline const `getAsKey` method that returns the parameters as a
/// `KeyTy`.
void DefGen::emitKeyType() {
  // Build "std::tuple<T0, T1, ...>".
  std::string keyType("std::tuple<");
  llvm::raw_string_ostream os(keyType);
  llvm::interleaveComma(params, os,
                        [&](auto &param) { os << param.getCppType(); });
  os << '>';
  storageCls->declare<UsingDeclaration>("KeyTy", std::move(os.str()));

  // Emit "return KeyTy(p0, p1, ...);".
  Method *m = storageCls->addConstMethod<Method::Inline>("KeyTy", "getAsKey");
  m->body().indent() << "return KeyTy(";
  llvm::interleaveComma(params, m->body().indent(),
                        [&](auto &param) { m->body() << param.getName(); });
  m->body() << ");";
}
void DefGen::emitEquals() {
Method *eq = storageCls->addConstMethod<Method::Inline>(
"bool", "operator==", MethodParameter("const KeyTy &", "tblgenKey"));
auto &body = eq->body().indent();
auto scope = body.scope("return (", ");");
const auto eachFn = [&](auto it) {
FmtContext ctx({{"_lhs", it.value().getName()},
{"_rhs", strfmt("std::get<{0}>(tblgenKey)", it.index())}});
body << tgfmt(it.value().getComparator(), &ctx);
};
llvm::interleave(llvm::enumerate(params), body, eachFn, ") && (");
}
void DefGen::emitHashKey() {
Method *hash = storageCls->addStaticInlineMethod(
"::llvm::hash_code", "hashKey",
MethodParameter("const KeyTy &", "tblgenKey"));
auto &body = hash->body().indent();
auto scope = body.scope("return ::llvm::hash_combine(", ");");
llvm::interleaveComma(llvm::enumerate(params), body, [&](auto it) {
body << llvm::formatv("std::get<{0}>(tblgenKey)", it.index());
});
}
void DefGen::emitConstruct() {
Method *construct = storageCls->addMethod<Method::Inline>(
strfmt("{0} *", def.getStorageClassName()), "construct",
def.hasStorageCustomConstructor() ? Method::StaticDeclaration
: Method::Static,
MethodParameter(strfmt("::mlir::{0}StorageAllocator &", valueType),
"allocator"),
MethodParameter("KeyTy &&", "tblgenKey"));
if (!def.hasStorageCustomConstructor()) {
auto &body = construct->body().indent();
for (const auto &it : llvm::enumerate(params)) {
body << formatv("auto {0} = std::move(std::get<{1}>(tblgenKey));\n",
it.value().getName(), it.index());
}
FmtContext ctx = FmtContext().addSubst("_allocator", "allocator");
for (auto ¶m : params) {
if (std::optional<StringRef> allocCode = param.getAllocator()) {
ctx.withSelf(param.getName()).addSubst("_dst", param.getName());
body << tgfmt(*allocCode, &ctx) << '\n';
}
}
auto scope =
body.scope(strfmt("return new (allocator.allocate<{0}>()) {0}(",
def.getStorageClassName()),
");");
llvm::interleaveComma(params, body, [&](auto ¶m) {
body << "std::move(" << param.getName() << ")";
});
}
}
void DefGen::emitStorageClass() {
storageCls->addParent(strfmt("::mlir::{0}Storage", valueType));
emitStorageConstructor();
emitKeyType();
emitEquals();
emitHashKey();
emitConstruct();
storageCls->finalize();
for (auto ¶m : params)
storageCls->declare<Field>(param.getCppType(), param.getName());
}
namespace {
class DefGenerator {
public:
bool emitDecls(StringRef selectedDialect);
bool emitDefs(StringRef selectedDialect);
protected:
DefGenerator(std::vector<llvm::Record *> &&defs, raw_ostream &os,
StringRef defType, StringRef valueType, bool isAttrGenerator)
: defRecords(std::move(defs)), os(os), defType(defType),
valueType(valueType), isAttrGenerator(isAttrGenerator) {
llvm::sort(defRecords, [](llvm::Record *lhs, llvm::Record *rhs) {
return lhs->getID() < rhs->getID();
});
}
void emitTypeDefList(ArrayRef<AttrOrTypeDef> defs);
void emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs);
std::vector<llvm::Record *> defRecords;
raw_ostream &os;
StringRef defType;
StringRef valueType;
bool isAttrGenerator;
};
/// Generator specialised for attributes: processes every record derived from
/// `AttrDef`, if that class is defined.
struct AttrDefGenerator : DefGenerator {
  AttrDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
      : DefGenerator(records.getAllDerivedDefinitionsIfDefined("AttrDef"), os,
                     /*defType=*/"Attr", /*valueType=*/"Attribute",
                     /*isAttrGenerator=*/true) {}
};
/// Generator specialised for types: processes every record derived from
/// `TypeDef`, if that class is defined.
struct TypeDefGenerator : DefGenerator {
  TypeDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
      : DefGenerator(records.getAllDerivedDefinitionsIfDefined("TypeDef"), os,
                     /*defType=*/"Type", /*valueType=*/"Type",
                     /*isAttrGenerator=*/false) {}
};
}
/// Text written near the top of every generated declarations file (see
/// emitDecls): forward declarations of the MLIR assembly parser and printer.
static const char *const typeDefDeclHeader = R"(
namespace mlir {
class AsmParser;
class AsmPrinter;
} // namespace mlir
)";
/// Writes the declarations file: file header, the `GET_<KIND>DEF_CLASSES`
/// guard, the common forward declarations, then (inside the dialect
/// namespace) a forward declaration and a full declaration of every def, and
/// finally the explicit TypeID declarations. Always returns false.
bool DefGenerator::emitDecls(StringRef selectedDialect) {
  emitSourceFileHeader((defType + "Def Declarations").str(), os);
  IfDefScope scope("GET_" + defType.upper() + "DEF_CLASSES", os);

  os << typeDefDeclHeader;

  SmallVector<AttrOrTypeDef, 16> selected;
  collectAllDefs(selectedDialect, defRecords, selected);
  if (selected.empty())
    return false;

  {
    NamespaceEmitter nsEmitter(os, selected.front().getDialect());

    // Forward-declare everything first so the classes may refer to each other.
    for (const AttrOrTypeDef &def : selected)
      os << "class " << def.getCppClassName() << ";\n";

    for (const AttrOrTypeDef &def : selected)
      DefGen(def).emitDecl(os);
  }

  // TypeID declarations are only emitted for dialects with a C++ namespace.
  for (const AttrOrTypeDef &def : selected) {
    if (def.getDialect().getCppNamespace().empty())
      continue;
    os << "MLIR_DECLARE_EXPLICIT_TYPE_ID("
       << def.getDialect().getCppNamespace() << "::" << def.getCppClassName()
       << ")\n";
  }
  return false;
}
/// Writes, under the `GET_<KIND>DEF_LIST` guard, the namespace-qualified
/// class names of all defs separated by ",\n" and followed by a newline.
void DefGenerator::emitTypeDefList(ArrayRef<AttrOrTypeDef> defs) {
  IfDefScope scope("GET_" + defType.upper() + "DEF_LIST", os);
  llvm::interleave(
      defs, os,
      [&](const AttrOrTypeDef &entry) {
        os << entry.getDialect().getCppNamespace()
           << "::" << entry.getCppClassName();
      },
      ",\n");
  os << "\n";
}
/// formatv template for a dialect's default `parseAttribute` and
/// `printAttribute` definitions (used by emitDefs).
///   {0}: the dialect's C++ class name.
///   {1}: extra parser code tried after the generated parser, or empty.
///   {2}: extra printer code tried after the generated printer, or empty.
static const char *const dialectDefaultAttrPrinterParserDispatch = R"(
/// Parse an attribute registered to this dialect.
::mlir::Attribute {0}::parseAttribute(::mlir::DialectAsmParser &parser,
::mlir::Type type) const {{
::llvm::SMLoc typeLoc = parser.getCurrentLocation();
::llvm::StringRef attrTag;
{{
::mlir::Attribute attr;
auto parseResult = generatedAttributeParser(parser, &attrTag, type, attr);
if (parseResult.has_value())
return attr;
}
{1}
parser.emitError(typeLoc) << "unknown attribute `"
<< attrTag << "` in dialect `" << getNamespace() << "`";
return {{};
}
/// Print an attribute registered to this dialect.
void {0}::printAttribute(::mlir::Attribute attr,
::mlir::DialectAsmPrinter &printer) const {{
if (::mlir::succeeded(generatedAttributePrinter(attr, printer)))
return;
{2}
}
)";
/// Parser fallback for extensible dialects, spliced into slot {1} of
/// dialectDefaultAttrPrinterParserDispatch. It tries to parse a dynamic
/// attribute and returns a null attribute when that parse fails.
///
/// The null attribute is spelled `::mlir::Attribute()` so the generated code
/// does not rely on `Attribute` being visible from the dialect's namespace;
/// this matches the fully qualified `::mlir::Type()` of the type variant.
static const char *const dialectDynamicAttrParserDispatch = R"(
{
::mlir::Attribute genAttr;
auto parseResult = parseOptionalDynamicAttr(attrTag, parser, genAttr);
if (parseResult.has_value()) {
if (::mlir::succeeded(parseResult.value()))
return genAttr;
return ::mlir::Attribute();
}
}
)";
/// Printer fallback for extensible dialects, spliced into slot {2} of
/// dialectDefaultAttrPrinterParserDispatch: returns early if the attribute
/// was printed as a dynamic attribute.
static const char *const dialectDynamicAttrPrinterDispatch = R"(
if (::mlir::succeeded(printIfDynamicAttr(attr, printer)))
return;
)";
/// formatv template for a dialect's default `parseType` and `printType`
/// definitions (used by emitDefs).
///   {0}: the dialect's C++ class name.
///   {1}: extra parser code tried after the generated parser, or empty.
///   {2}: extra printer code tried after the generated printer, or empty.
static const char *const dialectDefaultTypePrinterParserDispatch = R"(
/// Parse a type registered to this dialect.
::mlir::Type {0}::parseType(::mlir::DialectAsmParser &parser) const {{
::llvm::SMLoc typeLoc = parser.getCurrentLocation();
::llvm::StringRef mnemonic;
::mlir::Type genType;
auto parseResult = generatedTypeParser(parser, &mnemonic, genType);
if (parseResult.has_value())
return genType;
{1}
parser.emitError(typeLoc) << "unknown type `"
<< mnemonic << "` in dialect `" << getNamespace() << "`";
return {{};
}
/// Print a type registered to this dialect.
void {0}::printType(::mlir::Type type,
::mlir::DialectAsmPrinter &printer) const {{
if (::mlir::succeeded(generatedTypePrinter(type, printer)))
return;
{2}
}
)";
/// Parser fallback for extensible dialects, spliced into slot {1} of
/// dialectDefaultTypePrinterParserDispatch. It tries to parse a dynamic type
/// and returns a null type when that parse fails.
static const char *const dialectDynamicTypeParserDispatch = R"(
{
auto parseResult = parseOptionalDynamicType(mnemonic, parser, genType);
if (parseResult.has_value()) {
if (::mlir::succeeded(parseResult.value()))
return genType;
return ::mlir::Type();
}
}
)";
/// Printer fallback for extensible dialects, spliced into slot {2} of
/// dialectDefaultTypePrinterParserDispatch: returns early if the type was
/// printed as a dynamic type.
static const char *const dialectDynamicTypePrinterDispatch = R"(
if (::mlir::succeeded(printIfDynamicType(type, printer)))
return;
)";
/// Emits the `generated<Value>Parser` and `generated<Value>Printer` helper
/// functions, which dispatch on the mnemonic (parser) or on the concrete
/// class (printer). Nothing is emitted when no def has a mnemonic.
///
/// A def with a custom or declarative assembly format is parsed through its
/// own `parse` and printed through its own `print`; any other def with a
/// mnemonic is built with `get(parser.getContext())` and only its mnemonic is
/// printed.
void DefGenerator::emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs) {
  const bool anyMnemonic = llvm::any_of(defs, [](const AttrOrTypeDef &def) {
    return def.getMnemonic().has_value();
  });
  if (!anyMnemonic)
    return;

  // Parser signature; attributes additionally receive the expected type.
  SmallVector<MethodParameter> parserArgs = {
      {"::mlir::AsmParser &", "parser"}, {"::llvm::StringRef *", "mnemonic"}};
  if (isAttrGenerator)
    parserArgs.emplace_back("::mlir::Type", "type");
  parserArgs.emplace_back(strfmt("::mlir::{0} &", valueType), "value");

  Method parse("::mlir::OptionalParseResult",
               strfmt("generated{0}Parser", valueType), Method::StaticInline,
               std::move(parserArgs));
  Method printer("::llvm::LogicalResult",
                 strfmt("generated{0}Printer", valueType), Method::StaticInline,
                 {{strfmt("::mlir::{0}", valueType), "def"},
                  {"::mlir::AsmPrinter &", "printer"}});

  // Openers of the two switch expressions.
  parse.body() << " return ::mlir::AsmParser::KeywordSwitch<::mlir::"
                  "OptionalParseResult>(parser)\n";
  printer.body() << " return ::llvm::TypeSwitch<::mlir::" << valueType
                 << ", ::llvm::LogicalResult>(def)";

  // Per-def case templates: {0} is the qualified class name; {1} is the
  // construction expression (parser) or the optional print call (printer).
  const char *const parseCaseFmt =
      R"( .Case({0}::getMnemonic(), [&](llvm::StringRef, llvm::SMLoc) {{
value = {0}::{1};
return ::mlir::success(!!value);
})
)";
  const char *const printCaseFmt = R"( .Case<{0}>([&](auto t) {{
printer << {0}::getMnemonic();{1}
return ::mlir::success();
})
)";

  for (auto &def : defs) {
    if (!def.getMnemonic())
      continue;

    const bool usesOwnParsePrint =
        def.hasCustomAssemblyFormat() || def.getAssemblyFormat();
    std::string qualifiedClass = strfmt(
        "{0}::{1}", def.getDialect().getCppNamespace(), def.getCppClassName());

    std::string construction = "get(parser.getContext())";
    StringRef printCall = "";
    if (usesOwnParsePrint) {
      construction =
          strfmt("parse(parser{0})", isAttrGenerator ? ", type" : "");
      printCall = "\nt.print(printer);";
    }

    parse.body() << llvm::formatv(parseCaseFmt, qualifiedClass, construction);
    printer.body() << llvm::formatv(printCaseFmt, qualifiedClass, printCall);
  }

  // Unknown keyword: report it back through `mnemonic` with no parse result.
  parse.body() << " .Default([&](llvm::StringRef keyword, llvm::SMLoc) {\n"
                  " *mnemonic = keyword;\n"
                  " return std::nullopt;\n"
                  " });";
  printer.body() << " .Default([](auto) { return ::mlir::failure(); });";

  raw_indented_ostream indentedOs(os);
  parse.writeDeclTo(indentedOs);
  printer.writeDeclTo(indentedOs);
}
/// Writes the definitions file: file header, the def list, then (under the
/// `GET_<KIND>DEF_CLASSES` guard) the parse/print dispatch helpers, each
/// def's definitions with its explicit TypeID definition, and the dialect's
/// default parser/printer hooks when the dialect asks for them. Always
/// returns false.
bool DefGenerator::emitDefs(StringRef selectedDialect) {
  emitSourceFileHeader((defType + "Def Definitions").str(), os);

  SmallVector<AttrOrTypeDef, 16> selected;
  collectAllDefs(selectedDialect, defRecords, selected);
  if (selected.empty())
    return false;

  emitTypeDefList(selected);

  IfDefScope scope("GET_" + defType.upper() + "DEF_CLASSES", os);
  emitParsePrintDispatch(selected);

  for (const AttrOrTypeDef &def : selected) {
    {
      NamespaceEmitter ns(os, def.getDialect());
      DefGen(def).emitDef(os);
    }
    // TypeID definitions are only emitted for dialects with a C++ namespace.
    if (def.getDialect().getCppNamespace().empty())
      continue;
    os << "MLIR_DEFINE_EXPLICIT_TYPE_ID("
       << def.getDialect().getCppNamespace() << "::" << def.getCppClassName()
       << ")\n";
  }

  Dialect firstDialect = selected.front().getDialect();

  // Instantiates one of the default dispatch templates inside the dialect
  // namespace; the dynamic fallbacks are only spliced in for extensible
  // dialects.
  auto emitDefaultDispatch = [&](const char *dispatchFmt,
                                 const char *dynamicParser,
                                 const char *dynamicPrinter) {
    NamespaceEmitter nsEmitter(os, firstDialect);
    const bool extensible = firstDialect.isExtensible();
    os << llvm::formatv(dispatchFmt, firstDialect.getCppClassName(),
                        extensible ? dynamicParser : "",
                        extensible ? dynamicPrinter : "");
  };

  if (isAttrGenerator && firstDialect.useDefaultAttributePrinterParser())
    emitDefaultDispatch(dialectDefaultAttrPrinterParserDispatch,
                        dialectDynamicAttrParserDispatch,
                        dialectDynamicAttrPrinterDispatch);

  if (!isAttrGenerator && firstDialect.useDefaultTypePrinterParser())
    emitDefaultDispatch(dialectDefaultTypePrinterParserDispatch,
                        dialectDynamicTypeParserDispatch,
                        dialectDynamicTypePrinterDispatch);

  return false;
}
/// Command-line category for the options of the attribute def generators.
static llvm::cl::OptionCategory attrdefGenCat("Options for -gen-attrdef-*");
/// `attrdefs-dialect` option: the dialect to generate attributes for. Its
/// value is passed as `selectedDialect` to the generators registered below.
static llvm::cl::opt<std::string>
attrDialect("attrdefs-dialect",
llvm::cl::desc("Generate attributes for this dialect"),
llvm::cl::cat(attrdefGenCat), llvm::cl::CommaSeparated);
/// Registers the `gen-attrdef-defs` generator, which emits AttrDef
/// definitions.
static mlir::GenRegistration
genAttrDefs("gen-attrdef-defs", "Generate AttrDef definitions",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
AttrDefGenerator generator(records, os);
return generator.emitDefs(attrDialect);
});
/// Registers the `gen-attrdef-decls` generator, which emits AttrDef
/// declarations.
static mlir::GenRegistration
genAttrDecls("gen-attrdef-decls", "Generate AttrDef declarations",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
AttrDefGenerator generator(records, os);
return generator.emitDecls(attrDialect);
});
/// Command-line category for the options of the type def generators.
static llvm::cl::OptionCategory typedefGenCat("Options for -gen-typedef-*");
/// `typedefs-dialect` option: the dialect to generate types for. Its value is
/// passed as `selectedDialect` to the generators registered below.
static llvm::cl::opt<std::string>
typeDialect("typedefs-dialect",
llvm::cl::desc("Generate types for this dialect"),
llvm::cl::cat(typedefGenCat), llvm::cl::CommaSeparated);
/// Registers the `gen-typedef-defs` generator, which emits TypeDef
/// definitions.
static mlir::GenRegistration
genTypeDefs("gen-typedef-defs", "Generate TypeDef definitions",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
TypeDefGenerator generator(records, os);
return generator.emitDefs(typeDialect);
});
/// Registers the `gen-typedef-decls` generator, which emits TypeDef
/// declarations.
static mlir::GenRegistration
genTypeDecls("gen-typedef-decls", "Generate TypeDef declarations",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
TypeDefGenerator generator(records, os);
return generator.emitDecls(typeDialect);
});