#include "mlir/TableGen/GenInfo.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/CodeGenTypes/MachineValueType.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/Signals.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Main.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
// Category under which all command-line options of this generator are listed.
static llvm::cl::OptionCategory intrinsicGenCat("Intrinsics Generator Options");
// Substring filter: when non-empty, only records whose name contains this
// string are emitted.
static llvm::cl::opt<std::string>
nameFilter("llvmir-intrinsics-filter",
llvm::cl::desc("Only keep the intrinsics with the specified "
"substring in their record name"),
llvm::cl::cat(intrinsicGenCat));
// Name of the ODS base class each generated op definition derives from;
// defaults to "LLVM_IntrOp".
static llvm::cl::opt<std::string>
opBaseClass("dialect-opclass-base",
llvm::cl::desc("The base class for the ops in the dialect we "
"are planning to emit"),
llvm::cl::init("LLVM_IntrOp"), llvm::cl::cat(intrinsicGenCat));
// Regular expression matched against record names; matching intrinsics get an
// additional optional access-group attribute argument.
static llvm::cl::opt<std::string> accessGroupRegexp(
"llvmir-intrinsics-access-group-regexp",
llvm::cl::desc("Mark intrinsics that match the specified "
"regexp as taking an access group metadata"),
llvm::cl::cat(intrinsicGenCat));
// Regular expression matched against record names; matching intrinsics get
// additional optional alias-scope, noalias-scope and TBAA attribute arguments.
static llvm::cl::opt<std::string> aliasAnalysisRegexp(
"llvmir-intrinsics-alias-analysis-regexp",
llvm::cl::desc("Mark intrinsics that match the specified "
"regexp as taking alias.scopes, noalias, and tbaa metadata"),
llvm::cl::cat(intrinsicGenCat));
// Bit vector type used to mark which operand/result positions are
// overloadable.
using IndicesTy = llvm::SmallBitVector;
/// Looks up the record referenced by the "VT" field of `rec`, reads its
/// integer "Value" field and converts that integer to an MVT simple value
/// type.
static llvm::MVT::SimpleValueType getValueType(const llvm::Record *rec) {
  const llvm::Record *vtDef = rec->getValueAsDef("VT");
  const auto rawValue = vtDef->getValueAsInt("Value");
  return static_cast<llvm::MVT::SimpleValueType>(rawValue);
}
/// Builds a bit vector with one bit per entry of the record list named
/// `listName` in `record`. A bit is set when the value type of the
/// corresponding entry is one of the "any" kinds (iAny, fAny, Any, iPTRAny,
/// vAny); all other bits stay cleared.
static IndicesTy getOverloadableTypeIdxs(const llvm::Record &record,
                                         const char *listName) {
  const auto typeDefs = record.getValueAsListOfDefs(listName);
  IndicesTy overloadable(typeDefs.size());

  unsigned position = 0;
  for (const llvm::Record *typeDef : typeDefs) {
    const llvm::MVT::SimpleValueType vt = getValueType(typeDef);
    const bool isAnyKind = vt == llvm::MVT::iAny || vt == llvm::MVT::fAny ||
                           vt == llvm::MVT::Any ||
                           vt == llvm::MVT::iPTRAny || vt == llvm::MVT::vAny;
    if (isAnyKind)
      overloadable.set(position);
    ++position;
  }
  return overloadable;
}
namespace {
/// Read-only wrapper around a TableGen record describing an LLVM intrinsic.
/// It exposes the pieces of information needed to emit the corresponding
/// operation definition. Only a reference to the record is stored, so the
/// record must outlive the wrapper.
class LLVMIntrinsic {
public:
  explicit LLVMIntrinsic(const llvm::Record &record) : record(record) {}

  /// Returns the name to use for the operation. If the "LLVMName" field is
  /// non-empty it is returned verbatim. Otherwise the name is derived from the
  /// record name: the leading "int_" is dropped, the remainder is split on
  /// '_', the first chunk is dropped when the record has a non-empty
  /// "TargetPrefix", and the remaining chunks are joined with '.'.
  std::string getOperationName() const {
    llvm::StringRef name = record.getValueAsString(fieldName);
    if (!name.empty())
      return name.str();

    name = record.getName();
    assert(name.starts_with("int_") &&
           "LLVM intrinsic names are expected to start with 'int_'");
    name = name.drop_front(4);
    llvm::SmallVector<llvm::StringRef, 8> chunks;
    llvm::StringRef targetPrefix = record.getValueAsString("TargetPrefix");
    name.split(chunks, '_');
    auto *chunksBegin = chunks.begin();
    // Target-specific intrinsics repeat the target prefix as the first chunk
    // of the record name; skip it.
    if (!targetPrefix.empty()) {
      assert(targetPrefix == *chunksBegin &&
             "Intrinsic has TargetPrefix, but "
             "record name doesn't begin with it");
      assert(chunks.size() >= 2 &&
             "Intrinsic has TargetPrefix, but "
             "chunks has only one element meaning the intrinsic name is empty");
      ++chunksBegin;
    }
    return llvm::join(chunksBegin, chunks.end(), ".");
  }

  /// Returns the record name with the leading "int_" removed.
  llvm::StringRef getProperRecordName() const {
    llvm::StringRef name = record.getName();
    assert(name.starts_with("int_") &&
           "LLVM intrinsic names are expected to start with 'int_'");
    return name.drop_front(4);
  }

  /// Returns the number of entries in the "ParamTypes" list of the record.
  unsigned getNumOperands() const {
    auto operands = record.getValueAsListOfDefs(fieldOperands);
    assert(llvm::all_of(operands,
                        [](const llvm::Record *r) {
                          return r->isSubClassOf("LLVMType");
                        }) &&
           "expected operands to be of LLVM type");
    return operands.size();
  }

  /// Returns the number of entries in the "RetTypes" list of the record.
  unsigned getNumResults() const {
    auto results = record.getValueAsListOfDefs(fieldResults);
    assert(llvm::all_of(results,
                        [](const llvm::Record *r) {
                          return r->isSubClassOf("LLVMType");
                        }) &&
           "expected results to be of LLVM type");
    return results.size();
  }

  /// Returns true unless one of the "IntrProperties" entries of the record is
  /// named "IntrNoMem".
  bool hasSideEffects() const {
    return llvm::none_of(
        record.getValueAsListOfDefs(fieldTraits),
        [](const llvm::Record *r) { return r->getName() == "IntrNoMem"; });
  }

  /// Returns true if one of the "IntrProperties" entries of the record is
  /// named "Commutative".
  bool isCommutative() const {
    return llvm::any_of(
        record.getValueAsListOfDefs(fieldTraits),
        [](const llvm::Record *r) { return r->getName() == "Commutative"; });
  }

  /// Returns the positions of the overloadable entries in "ParamTypes".
  IndicesTy getOverloadableOperandsIdxs() const {
    return getOverloadableTypeIdxs(record, fieldOperands);
  }

  /// Returns the positions of the overloadable entries in "RetTypes".
  IndicesTy getOverloadableResultsIdxs() const {
    return getOverloadableTypeIdxs(record, fieldResults);
  }

private:
  /// Names of the record fields queried by this wrapper. They are shared by
  /// all instances, hence static.
  static constexpr const char *fieldName = "LLVMName";
  static constexpr const char *fieldOperands = "ParamTypes";
  static constexpr const char *fieldResults = "RetTypes";
  static constexpr const char *fieldTraits = "IntrProperties";

  /// The wrapped intrinsic record (not owned).
  const llvm::Record &record;
};
} // namespace
/// Writes the elements of `range` to `os` as a list enclosed in square
/// brackets, with consecutive elements separated by ", ". An empty range is
/// printed as "[]".
template <typename Range>
void printBracketedRange(const Range &range, llvm::raw_ostream &os) {
  os << '[';
  bool needsSeparator = false;
  for (const auto &element : range) {
    if (needsSeparator)
      os << ", ";
    os << element;
    needsSeparator = true;
  }
  os << ']';
}
static bool emitIntrinsic(const llvm::Record &record, llvm::raw_ostream &os) {
LLVMIntrinsic intr(record);
llvm::Regex accessGroupMatcher(accessGroupRegexp);
bool requiresAccessGroup =
!accessGroupRegexp.empty() && accessGroupMatcher.match(record.getName());
llvm::Regex aliasAnalysisMatcher(aliasAnalysisRegexp);
bool requiresAliasAnalysis = !aliasAnalysisRegexp.empty() &&
aliasAnalysisMatcher.match(record.getName());
llvm::SmallVector<llvm::StringRef, 2> traits;
if (intr.isCommutative())
traits.push_back("Commutative");
if (!intr.hasSideEffects())
traits.push_back("NoMemoryEffect");
llvm::SmallVector<llvm::StringRef, 8> operands(intr.getNumOperands(),
"LLVM_Type");
if (requiresAccessGroup)
operands.push_back(
"OptionalAttr<LLVM_AccessGroupArrayAttr>:$access_groups");
if (requiresAliasAnalysis) {
operands.push_back("OptionalAttr<LLVM_AliasScopeArrayAttr>:$alias_scopes");
operands.push_back(
"OptionalAttr<LLVM_AliasScopeArrayAttr>:$noalias_scopes");
operands.push_back("OptionalAttr<LLVM_TBAATagArrayAttr>:$tbaa");
}
os << "def LLVM_" << intr.getProperRecordName() << " : " << opBaseClass
<< "<\"" << intr.getOperationName() << "\", ";
printBracketedRange(intr.getOverloadableResultsIdxs().set_bits(), os);
os << ", ";
printBracketedRange(intr.getOverloadableOperandsIdxs().set_bits(), os);
os << ", ";
printBracketedRange(traits, os);
os << ", " << intr.getNumResults() << ", "
<< (requiresAccessGroup ? "1" : "0") << ", "
<< (requiresAliasAnalysis ? "1" : "0") << ">, Arguments<(ins"
<< (operands.empty() ? "" : " ");
llvm::interleaveComma(operands, os);
os << ")>;\n\n";
return false;
}
/// Generator entry point. Writes the source-file header and the two include
/// lines to `os`, then emits one definition per record derived from
/// "Intrinsic". When the name filter option is non-empty, records whose name
/// does not contain it are skipped. Stops and returns true as soon as emitting
/// a record reports an error; returns false otherwise.
static bool emitIntrinsics(const llvm::RecordKeeper &records,
                           llvm::raw_ostream &os) {
  llvm::emitSourceFileHeader("Operations for LLVM intrinsics", os, records);
  os << "include \"mlir/Dialect/LLVMIR/LLVMOpBase.td\"\n";
  os << "include \"mlir/Interfaces/SideEffectInterfaces.td\"\n\n";

  const bool keepEverything = nameFilter.empty();
  auto intrinsicDefs = records.getAllDerivedDefinitions("Intrinsic");
  for (const llvm::Record *def : intrinsicDefs) {
    const bool selected =
        keepEverything || def->getName().contains(nameFilter);
    if (selected && emitIntrinsic(*def, os))
      return true;
  }
  return false;
}
// Static registration object tying the "gen-llvmir-intrinsics" generator name
// and its description to the emitIntrinsics entry point.
static mlir::GenRegistration genLLVMIRIntrinsics("gen-llvmir-intrinsics",
"Generate LLVM IR intrinsics",
emitIntrinsics);