#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/SubElementInterfaces.h"
#include "mlir/IR/Verifier.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Endian.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/SaveAndRestore.h"
#include "llvm/Support/Threading.h"
#include <tuple>
using namespace mlir;
using namespace mlir::detail;
#define DEBUG_TYPE "mlir-asm-printer"
void OperationName::print(raw_ostream &os) const { os << getStringRef(); }
void OperationName::dump() const { print(llvm::errs()); }
// Out-of-line destructor definitions for the parser/printer interfaces.
AsmParser::~AsmParser() = default;
DialectAsmParser::~DialectAsmParser() = default;
OpAsmParser::~OpAsmParser() = default;

/// Return the context owning the builder used by this parser.
MLIRContext *AsmParser::getContext() const {
  Builder &builder = getBuilder();
  return builder.getContext();
}

DialectAsmPrinter::~DialectAsmPrinter() = default;
OpAsmPrinter::~OpAsmPrinter() = default;
/// Print the type signature of `op` in functional form:
/// `(operand types) -> result types`. The result list is parenthesized unless
/// there is exactly one result whose type is not itself a FunctionType. Null
/// operands/results are printed as a null Type.
void OpAsmPrinter::printFunctionalType(Operation *op) {
  raw_ostream &stream = getStream();
  auto printTypeOf = [&](Value value) {
    *this << (value ? value.getType() : Type());
  };

  stream << '(';
  llvm::interleaveComma(op->getOperands(), stream, printTypeOf);
  stream << ") -> ";

  // A single function-typed result must still be wrapped, otherwise its own
  // `->` would be ambiguous with the outer one.
  bool needsParens = true;
  if (op->getNumResults() == 1) {
    Type soleResultType = op->getResult(0).getType();
    needsParens = soleResultType && soleResultType.isa<FunctionType>();
  }

  if (needsParens)
    stream << '(';
  llvm::interleaveComma(op->getResults(), stream, printTypeOf);
  if (needsParens)
    stream << ')';
}
#include "mlir/IR/OpAsmInterface.cpp.inc"
/// Default handling of a dialect resource entry: no resource keys are
/// recognized, so an error naming the key and the dialect is emitted.
LogicalResult
OpAsmDialectInterface::parseResource(AsmParsedResourceEntry &entry) const {
  StringRef dialectNamespace = getDialect()->getNamespace();
  return entry.emitError() << "unknown 'resource' key '" << entry.getKey()
                           << "' for dialect '" << dialectNamespace << "'";
}
namespace {
/// Command line options that provide the defaults for OpPrintingFlags. They are
/// only consulted once the `clOptions` storage has been constructed (see
/// registerAsmPrinterCLOptions).
struct AsmPrinterOptions {
  /// Print DenseElementsAttrs with more elements than this as a hex string;
  /// -1 disables hex printing.
  llvm::cl::opt<int64_t> printElementsAttrWithHexIfLarger{
      "mlir-print-elementsattrs-with-hex-if-larger",
      llvm::cl::desc(
          "Print DenseElementsAttrs with a hex string that have "
          "more elements than the given upper limit (use -1 to disable)")};

  /// Elide ElementsAttrs with more elements than this limit.
  llvm::cl::opt<unsigned> elideElementsAttrIfLarger{
      "mlir-elide-elementsattrs-if-larger",
      llvm::cl::desc("Elide ElementsAttrs with \"...\" that have "
                     "more elements than the given upper limit")};

  /// Print location (debug) information.
  llvm::cl::opt<bool> printDebugInfoOpt{
      "mlir-print-debuginfo", llvm::cl::init(false),
      llvm::cl::desc("Print debug info in MLIR output")};

  /// Print debug information in the pretty form.
  llvm::cl::opt<bool> printPrettyDebugInfoOpt{
      "mlir-pretty-debuginfo", llvm::cl::init(false),
      llvm::cl::desc("Print pretty debug info in MLIR output")};

  /// Always print operations in the generic form.
  llvm::cl::opt<bool> printGenericOpFormOpt{
      "mlir-print-op-generic", llvm::cl::init(false),
      llvm::cl::desc("Print the generic op form"), llvm::cl::Hidden};

  /// Skip verification before using custom printers.
  llvm::cl::opt<bool> assumeVerifiedOpt{
      "mlir-print-assume-verified", llvm::cl::init(false),
      llvm::cl::desc("Skip op verification when using custom printers"),
      llvm::cl::Hidden};

  /// Print with local scope (no aliases).
  // The help text previously lacked its closing parenthesis.
  llvm::cl::opt<bool> printLocalScopeOpt{
      "mlir-print-local-scope", llvm::cl::init(false),
      llvm::cl::desc("Print with local scope and inline information (eliding "
                     "aliases for attributes, types, and locations)")};

  /// Print the users of values as a trailing comment.
  llvm::cl::opt<bool> printValueUsers{
      "mlir-print-value-users", llvm::cl::init(false),
      llvm::cl::desc(
          "Print users of operation results and block arguments as a comment")};
};
} // namespace
/// Lazily constructed storage for the printer command line options. Code in
/// this file only reads the options once this has been constructed.
static llvm::ManagedStatic<AsmPrinterOptions> clOptions;

/// Register the printer command line options by forcing construction of
/// `clOptions` (constructing the llvm::cl::opt members registers them).
void mlir::registerAsmPrinterCLOptions() {
  // Dereferencing the ManagedStatic constructs it; the value itself is unused.
  (void)*clOptions;
}
/// Initialize the flags to their defaults, overridden by the command line
/// options when those have been registered.
OpPrintingFlags::OpPrintingFlags()
    : printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false),
      printGenericOpFormFlag(false), assumeVerifiedFlag(false),
      printLocalScope(false), printValueUsersFlag(false) {
  // Without registered command line options, keep the defaults above.
  if (!clOptions.isConstructed())
    return;
  AsmPrinterOptions &options = *clOptions;

  // The element limit is only set when the flag was explicitly provided.
  if (options.elideElementsAttrIfLarger.getNumOccurrences() != 0)
    elementsAttrElementLimit = options.elideElementsAttrIfLarger;

  printDebugInfoFlag = options.printDebugInfoOpt;
  printDebugInfoPrettyFormFlag = options.printPrettyDebugInfoOpt;
  printGenericOpFormFlag = options.printGenericOpFormOpt;
  assumeVerifiedFlag = options.assumeVerifiedOpt;
  printLocalScope = options.printLocalScopeOpt;
  printValueUsersFlag = options.printValueUsers;
}
/// Elide ElementsAttrs that have more than `largeElementLimit` elements.
OpPrintingFlags &
OpPrintingFlags::elideLargeElementsAttrs(int64_t largeElementLimit) {
  this->elementsAttrElementLimit = largeElementLimit;
  return *this;
}

/// Enable printing of debug information, optionally in the pretty form.
OpPrintingFlags &OpPrintingFlags::enableDebugInfo(bool prettyForm) {
  this->printDebugInfoPrettyFormFlag = prettyForm;
  this->printDebugInfoFlag = true;
  return *this;
}

/// Always print operations in the generic form.
OpPrintingFlags &OpPrintingFlags::printGenericOpForm() {
  this->printGenericOpFormFlag = true;
  return *this;
}

/// Do not verify the operation before printing it.
OpPrintingFlags &OpPrintingFlags::assumeVerified() {
  this->assumeVerifiedFlag = true;
  return *this;
}

/// Print with local scope (no aliases).
OpPrintingFlags &OpPrintingFlags::useLocalScope() {
  this->printLocalScope = true;
  return *this;
}

/// Print the users of values as comments.
OpPrintingFlags &OpPrintingFlags::printValueUsers() {
  this->printValueUsersFlag = true;
  return *this;
}
bool OpPrintingFlags::shouldElideElementsAttr(ElementsAttr attr) const {
return elementsAttrElementLimit &&
*elementsAttrElementLimit < int64_t(attr.getNumElements()) &&
!attr.isa<SplatElementsAttr>();
}
Optional<int64_t> OpPrintingFlags::getLargeElementsAttrLimit() const {
return elementsAttrElementLimit;
}
bool OpPrintingFlags::shouldPrintDebugInfo() const {
return printDebugInfoFlag;
}
bool OpPrintingFlags::shouldPrintDebugInfoPrettyForm() const {
return printDebugInfoPrettyFormFlag;
}
bool OpPrintingFlags::shouldPrintGenericOpForm() const {
return printGenericOpFormFlag;
}
bool OpPrintingFlags::shouldAssumeVerified() const {
return assumeVerifiedFlag;
}
bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; }
bool OpPrintingFlags::shouldPrintValueUsers() const {
return printValueUsersFlag;
}
/// Return true if an elements attribute with `numElements` elements should be
/// printed as a hex string. Uses the command line limit when it was explicitly
/// given (-1 disables hex printing), and 100 otherwise.
static bool shouldPrintElementsAttrWithHex(int64_t numElements) {
  const bool hasUserLimit =
      clOptions.isConstructed() &&
      clOptions->printElementsAttrWithHexIfLarger.getNumOccurrences();
  if (!hasUserLimit)
    return numElements > 100;

  if (clOptions->printElementsAttrWithHexIfLarger == -1)
    return false;
  return numElements > clOptions->printElementsAttrWithHexIfLarger;
}
namespace {
/// Tracks the current output line. Streaming an instance emits a newline and
/// advances the line count.
struct NewLineCounter {
  unsigned curLine = 1;
};

static raw_ostream &operator<<(raw_ostream &os, NewLineCounter &newLine) {
  newLine.curLine += 1;
  os << '\n';
  return os;
}
} // namespace
namespace {
class SymbolAlias {
public:
SymbolAlias(StringRef name, bool isDeferrable)
: name(name), suffixIndex(0), hasSuffixIndex(false),
isDeferrable(isDeferrable) {}
SymbolAlias(StringRef name, uint32_t suffixIndex, bool isDeferrable)
: name(name), suffixIndex(suffixIndex), hasSuffixIndex(true),
isDeferrable(isDeferrable) {}
void print(raw_ostream &os) const {
os << name;
if (hasSuffixIndex)
os << suffixIndex;
}
bool canBeDeferred() const { return isDeferrable; }
private:
StringRef name;
uint32_t suffixIndex : 30;
bool hasSuffixIndex : 1;
bool isDeferrable : 1;
};
class AliasInitializer {
public:
AliasInitializer(
DialectInterfaceCollection<OpAsmDialectInterface> &interfaces,
llvm::BumpPtrAllocator &aliasAllocator)
: interfaces(interfaces), aliasAllocator(aliasAllocator),
aliasOS(aliasBuffer) {}
void initialize(Operation *op, const OpPrintingFlags &printerFlags,
llvm::MapVector<Attribute, SymbolAlias> &attrToAlias,
llvm::MapVector<Type, SymbolAlias> &typeToAlias);
void visit(Attribute attr, bool canBeDeferred = false);
void visit(Type type);
private:
template <typename T>
LogicalResult
generateAlias(T symbol,
llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol);
DialectInterfaceCollection<OpAsmDialectInterface> &interfaces;
llvm::MapVector<StringRef, std::vector<Attribute>> aliasToAttr;
llvm::MapVector<StringRef, std::vector<Type>> aliasToType;
llvm::BumpPtrAllocator &aliasAllocator;
DenseSet<Attribute> visitedAttributes;
DenseSet<Attribute> deferrableAttributes;
DenseSet<Type> visitedTypes;
SmallString<32> aliasBuffer;
llvm::raw_svector_ostream aliasOS;
};
class DummyAliasOperationPrinter : private OpAsmPrinter {
public:
explicit DummyAliasOperationPrinter(const OpPrintingFlags &printerFlags,
AliasInitializer &initializer)
: printerFlags(printerFlags), initializer(initializer) {}
void print(Operation *op) {
if (printerFlags.shouldPrintDebugInfo())
initializer.visit(op->getLoc(), true);
if (!printerFlags.shouldPrintGenericOpForm()) {
if (auto opInfo = op->getRegisteredInfo()) {
opInfo->printAssembly(op, *this, "");
return;
}
}
printGenericOp(op);
}
private:
void printGenericOp(Operation *op, bool printOpName = true) override {
if (op->getNumRegions() != 0) {
for (Region ®ion : op->getRegions())
printRegion(region, true,
true);
}
for (Type type : op->getOperandTypes())
printType(type);
for (Type type : op->getResultTypes())
printType(type);
for (const NamedAttribute &attr : op->getAttrs())
printAttribute(attr.getValue());
}
void print(Block *block, bool printBlockArgs = true,
bool printBlockTerminator = true) {
if (printBlockArgs) {
for (BlockArgument arg : block->getArguments()) {
printType(arg.getType());
if (printerFlags.shouldPrintDebugInfo())
initializer.visit(arg.getLoc(), false);
}
}
bool hasTerminator =
!block->empty() && block->back().hasTrait<OpTrait::IsTerminator>();
auto range = llvm::make_range(
block->begin(),
std::prev(block->end(),
(!hasTerminator || printBlockTerminator) ? 0 : 1));
for (Operation &op : range)
print(&op);
}
void printRegion(Region ®ion, bool printEntryBlockArgs,
bool printBlockTerminators,
bool printEmptyBlock = false) override {
if (region.empty())
return;
auto *entryBlock = ®ion.front();
print(entryBlock, printEntryBlockArgs, printBlockTerminators);
for (Block &b : llvm::drop_begin(region, 1))
print(&b);
}
void printRegionArgument(BlockArgument arg, ArrayRef<NamedAttribute> argAttrs,
bool omitType) override {
printType(arg.getType());
if (printerFlags.shouldPrintDebugInfo())
initializer.visit(arg.getLoc(), false);
}
void printType(Type type) override { initializer.visit(type); }
void printAttribute(Attribute attr) override { initializer.visit(attr); }
void printAttributeWithoutType(Attribute attr) override {
printAttribute(attr);
}
LogicalResult printAlias(Attribute attr) override {
initializer.visit(attr);
return success();
}
LogicalResult printAlias(Type type) override {
initializer.visit(type);
return success();
}
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
ArrayRef<StringRef> elidedAttrs = {}) override {
if (attrs.empty())
return;
if (elidedAttrs.empty()) {
for (const NamedAttribute &attr : attrs)
printAttribute(attr.getValue());
return;
}
llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(),
elidedAttrs.end());
for (const NamedAttribute &attr : attrs)
if (!elidedAttrsSet.contains(attr.getName().strref()))
printAttribute(attr.getValue());
}
void printOptionalAttrDictWithKeyword(
ArrayRef<NamedAttribute> attrs,
ArrayRef<StringRef> elidedAttrs = {}) override {
printOptionalAttrDict(attrs, elidedAttrs);
}
raw_ostream &getStream() const override { return os; }
void printFloat(const APFloat &value) override {}
void printAffineMapOfSSAIds(AffineMapAttr, ValueRange) override {}
void printAffineExprOfSSAIds(AffineExpr, ValueRange, ValueRange) override {}
void printNewline() override {}
void printOperand(Value) override {}
void printOperand(Value, raw_ostream &os) override {
os << "%";
}
void printKeywordOrString(StringRef) override {}
void printSymbolName(StringRef) override {}
void printSuccessor(Block *) override {}
void printSuccessorAndUseList(Block *, ValueRange) override {}
void shadowRegionArgs(Region &, ValueRange) override {}
const OpPrintingFlags &printerFlags;
AliasInitializer &initializer;
mutable llvm::raw_null_ostream os;
};
}
/// Sanitize `name` so it can be used as an identifier in the printed IR.
///
/// Characters that are neither alphanumeric nor in `allowedPunctChars` are
/// replaced. A space becomes '_'; any other byte becomes its hex encoding.
/// A leading digit gets a '_' prefix so the name cannot be confused with an
/// autogenerated numeric ID. If `allowTrailingDigit` is false, a '_' is
/// appended whenever the result would otherwise end in a digit. This prevents
/// collisions with numeric suffixes appended later.
///
/// Returns `name` itself when it is already valid. Otherwise the sanitized
/// text is appended to `buffer` (after any existing contents) and a reference
/// to `buffer` is returned.
static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
                                    StringRef allowedPunctChars = "$._-",
                                    bool allowTrailingDigit = true) {
  assert(!name.empty() && "Shouldn't have an empty name here");

  auto isValidChar = [&](char ch) {
    return llvm::isAlnum(ch) || allowedPunctChars.contains(ch);
  };

  // llvm::isDigit is used instead of ::isdigit. The latter has undefined
  // behavior for negative `char` values (non-ASCII bytes where char is signed).
  bool startsWithDigit = llvm::isDigit(name.front());
  bool endsWithDigit = !allowTrailingDigit && llvm::isDigit(name.back());

  // Fast path: the name can be used verbatim.
  if (!startsWithDigit && !endsWithDigit && llvm::all_of(name, isValidChar))
    return name;

  if (startsWithDigit)
    buffer.push_back('_');
  for (char ch : name) {
    if (isValidChar(ch))
      buffer.push_back(ch);
    else if (ch == ' ')
      buffer.push_back('_');
    else
      buffer.append(llvm::utohexstr(static_cast<unsigned char>(ch)));
  }

  // Check the sanitized result rather than the input. A leading-digit name and
  // a hex-escaped final character can both end in a digit too.
  if (!allowTrailingDigit && llvm::isDigit(buffer.back()))
    buffer.push_back('_');
  return buffer;
}
/// Build `symbolToAlias` from the symbols grouped by alias name in
/// `aliasToSymbol`. `aliasToSymbol` is consumed.
///
/// Names are processed in lexicographic order. A name used by a single symbol
/// is emitted as-is. A name shared by several symbols gets a numeric suffix
/// (0, 1, ...) per symbol, in the order the symbols were recorded. Symbols
/// contained in `deferrableAliases` (when provided) are marked deferrable.
template <typename T>
static void
initializeAliases(llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol,
                  llvm::MapVector<T, SymbolAlias> &symbolToAlias,
                  DenseSet<T> *deferrableAliases = nullptr) {
  std::vector<std::pair<StringRef, std::vector<T>>> aliases =
      aliasToSymbol.takeVector();

  // llvm::array_pod_sort must not be used here. It sorts with qsort, which
  // relocates elements bytewise and requires trivially copyable elements, but
  // each element owns a std::vector. Names are unique (they were MapVector
  // keys), so an unstable sort still yields a deterministic order.
  llvm::sort(aliases, [](const auto &lhs, const auto &rhs) {
    return lhs.first < rhs.first;
  });

  auto isDeferrable = [&](T symbol) -> bool {
    return deferrableAliases && deferrableAliases->count(symbol);
  };

  for (auto &it : aliases) {
    // A unique name needs no disambiguating suffix.
    if (it.second.size() == 1) {
      T symbol = it.second.front();
      symbolToAlias.insert({symbol, SymbolAlias(it.first, isDeferrable(symbol))});
      continue;
    }

    for (int i = 0, e = it.second.size(); i < e; ++i) {
      T symbol = it.second[i];
      symbolToAlias.insert(
          {symbol, SymbolAlias(it.first, i, isDeferrable(symbol))});
    }
  }
}
/// Collect the aliases for everything `op` prints, then assign final alias
/// names into `attrToAlias` and `typeToAlias`.
void AliasInitializer::initialize(
    Operation *op, const OpPrintingFlags &printerFlags,
    llvm::MapVector<Attribute, SymbolAlias> &attrToAlias,
    llvm::MapVector<Type, SymbolAlias> &typeToAlias) {
  // Walk `op` with a printer that emits nothing but calls back into visit().
  {
    DummyAliasOperationPrinter collector(printerFlags, *this);
    collector.print(op);
  }

  // Only attribute aliases carry deferrability information.
  initializeAliases(aliasToAttr, attrToAlias, &deferrableAttributes);
  initializeAliases(aliasToType, typeToAlias);
}
/// Record an alias for `attr`, or visit its sub-elements if no alias is
/// generated. An alias stays deferrable only while every use of the
/// attribute is deferrable.
void AliasInitializer::visit(Attribute attr, bool canBeDeferred) {
  const bool firstVisit = visitedAttributes.insert(attr).second;
  if (!firstVisit) {
    // A non-deferrable use revokes any deferrability recorded earlier.
    if (!canBeDeferred)
      deferrableAttributes.erase(attr);
    return;
  }

  // Sub-elements are only walked when the attribute itself has no alias.
  if (succeeded(generateAlias(attr, aliasToAttr))) {
    if (canBeDeferred)
      deferrableAttributes.insert(attr);
    return;
  }

  auto subElements = attr.dyn_cast<SubElementAttrInterface>();
  if (!subElements)
    return;
  subElements.walkSubElements([&](Attribute subAttr) { visit(subAttr); },
                              [&](Type subType) { visit(subType); });
}
/// Record an alias for `type`, or visit its sub-elements if no alias is
/// generated. Each type is processed at most once.
void AliasInitializer::visit(Type type) {
  const bool firstVisit = visitedTypes.insert(type).second;
  if (!firstVisit || succeeded(generateAlias(type, aliasToType)))
    return;

  auto subElements = type.dyn_cast<SubElementTypeInterface>();
  if (!subElements)
    return;
  subElements.walkSubElements([&](Attribute subAttr) { visit(subAttr); },
                              [&](Type subType) { visit(subType); });
}
/// Query each dialect interface for an alias name for `symbol`.
///
/// A later interface may override an OverridableAlias from an earlier one; a
/// FinalAlias stops the search. On success the sanitized name (copied into
/// `aliasAllocator`) is recorded in `aliasToSymbol`. Fails if no interface
/// provided an alias.
template <typename T>
LogicalResult AliasInitializer::generateAlias(
    T symbol, llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol) {
  SmallString<32> nameBuffer;
  for (const auto &interface : interfaces) {
    OpAsmDialectInterface::AliasResult result =
        interface.getAlias(symbol, aliasOS);
    if (result == OpAsmDialectInterface::AliasResult::NoAlias) {
      // Discard anything an interface wrote without providing an alias.
      // Otherwise it would be prepended to the next name produced through the
      // shared `aliasBuffer`, possibly for a different symbol.
      aliasBuffer.clear();
      continue;
    }

    nameBuffer = std::move(aliasBuffer);
    // Make sure the shared buffer starts out empty for the next query.
    aliasBuffer.clear();
    assert(!nameBuffer.empty() && "expected valid alias name");
    if (result == OpAsmDialectInterface::AliasResult::FinalAlias)
      break;
  }

  if (nameBuffer.empty())
    return failure();

  // Alias names may not end in a digit, since a numeric suffix may be appended
  // to disambiguate symbols sharing the same name.
  SmallString<16> tempBuffer;
  StringRef name =
      sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-",
                         /*allowTrailingDigit=*/false);
  name = name.copy(aliasAllocator);
  aliasToSymbol[name].push_back(symbol);
  return success();
}
namespace {
/// Holds the aliases computed for the attributes and types used within an
/// operation, and prints references to and definitions of those aliases.
class AliasState {
public:
  /// Compute the aliases for everything `op` prints, using the given dialect
  /// interfaces to name them.
  void
  initialize(Operation *op, const OpPrintingFlags &printerFlags,
             DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);

  /// Print `#<alias>` for `attr` to `os`, or fail if `attr` has no alias.
  LogicalResult getAlias(Attribute attr, raw_ostream &os) const;

  /// Print `!<alias>` for `ty` to `os`, or fail if `ty` has no alias.
  LogicalResult getAlias(Type ty, raw_ostream &os) const;

  /// Print the definitions of all aliases that cannot be deferred.
  void printNonDeferredAliases(raw_ostream &os, NewLineCounter &newLine) const {
    printAliases(os, newLine, /*isDeferred=*/false);
  }

  /// Print the definitions of all deferrable aliases.
  void printDeferredAliases(raw_ostream &os, NewLineCounter &newLine) const {
    printAliases(os, newLine, /*isDeferred=*/true);
  }

private:
  /// Print the definitions of the aliases whose deferrability equals
  /// `isDeferred`; attribute aliases first, then type aliases.
  void printAliases(raw_ostream &os, NewLineCounter &newLine,
                    bool isDeferred) const;

  /// Mapping from each aliased attribute/type to its alias.
  llvm::MapVector<Attribute, SymbolAlias> attrToAlias;
  llvm::MapVector<Type, SymbolAlias> typeToAlias;

  /// Owns the storage of the alias names referenced by the maps above.
  llvm::BumpPtrAllocator aliasAllocator;
};
} // namespace
/// Compute the aliases for `op`. The alias names are allocated in
/// `aliasAllocator`, so they remain valid for the lifetime of this state.
void AliasState::initialize(
    Operation *op, const OpPrintingFlags &printerFlags,
    DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
  AliasInitializer(interfaces, aliasAllocator)
      .initialize(op, printerFlags, attrToAlias, typeToAlias);
}
/// Print `#` followed by the alias of `attr`, if it has one.
LogicalResult AliasState::getAlias(Attribute attr, raw_ostream &os) const {
  const auto aliasIt = attrToAlias.find(attr);
  if (aliasIt != attrToAlias.end()) {
    os << '#';
    aliasIt->second.print(os);
    return success();
  }
  return failure();
}

/// Print `!` followed by the alias of `ty`, if it has one.
LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const {
  const auto aliasIt = typeToAlias.find(ty);
  if (aliasIt != typeToAlias.end()) {
    os << '!';
    aliasIt->second.print(os);
    return success();
  }
  return failure();
}
/// Print `#alias = <attr>` and then `!alias = <type>` definitions, one per
/// line, for every alias whose deferrability matches `isDeferred`.
void AliasState::printAliases(raw_ostream &os, NewLineCounter &newLine,
                              bool isDeferred) const {
  for (const auto &entry : attrToAlias) {
    if (entry.second.canBeDeferred() != isDeferred)
      continue;
    os << '#';
    entry.second.print(os);
    os << " = " << entry.first << newLine;
  }

  for (const auto &entry : typeToAlias) {
    if (entry.second.canBeDeferred() != isDeferred)
      continue;
    os << '!';
    entry.second.print(os);
    os << " = " << entry.first << newLine;
  }
}
namespace {
/// The ordering and printed name of a block.
struct BlockInfo {
  /// Position of the block within its region, or -1 if not yet assigned.
  int ordering;
  /// The printed name of the block, including its leading '^'.
  StringRef name;
};

/// Assigns names and numeric IDs to SSA values and blocks nested under an
/// operation. It also assigns IDs to result-less operations when value users
/// are printed.
class SSANameState {
public:
  /// Stored in `valueIDs` for values whose name lives in `valueNames`
  /// instead of being a number.
  enum : unsigned { NameSentinel = ~0U };

  SSANameState(Operation *op, const OpPrintingFlags &printerFlags);

  /// Print the SSA identifier of `value` (with a leading '%') to `stream`.
  /// If `printResultNo` is set, a result inside a multi-result group gets a
  /// `#N` suffix.
  void printValueID(Value value, bool printResultNo, raw_ostream &stream) const;

  /// Print the ID assigned to `op` (with a leading '%') to `stream`.
  void printOperationID(Operation *op, raw_ostream &stream) const;

  /// Return the sorted result numbers that start each result group of `op`,
  /// or an empty array if `op` was not split into several groups.
  ArrayRef<int> getOpResultGroups(Operation *op);

  /// Return the info for `block`, or an "INVALIDBLOCK" entry if it is unknown.
  BlockInfo getBlockInfo(Block *block);

  /// Make the arguments of `region` print with the names of the values in
  /// `namesToUse`; null entries are skipped.
  void shadowRegionArgs(Region &region, ValueRange namesToUse);

private:
  /// Number the block arguments, blocks and operations within `region`.
  void numberValuesInRegion(Region &region);
  /// Number the arguments and operations of `block`.
  void numberValuesInBlock(Block &block);
  /// Number the results of `op` and record custom names of its blocks.
  void numberValuesInOp(Operation &op);

  /// Map `result` to the value that owns its ID and, if needed, to the result
  /// number printed after it.
  void getResultIDAndNumber(OpResult result, Value &lookupValue,
                            Optional<int> &lookupResultNo) const;

  /// Give `value` the (uniqued) `name`, or a fresh numeric ID if it is empty.
  void setValueName(Value value, StringRef name);

  /// Return a sanitized, unique copy of `name` and mark it as used.
  StringRef uniqueValueName(StringRef name);

  /// Numeric ID of each value, or NameSentinel for named values.
  DenseMap<Value, unsigned> valueIDs;
  /// Names of the values whose ID is NameSentinel.
  DenseMap<Value, StringRef> valueNames;
  /// IDs assigned to operations.
  DenseMap<Operation *, unsigned> operationIDs;
  /// Result numbers starting a new result group, per operation.
  DenseMap<Operation *, SmallVector<int, 1>> opResultGroups;
  /// Ordering and name of each block.
  DenseMap<Block *, BlockInfo> blockNames;

  /// Names in use, scoped by region nesting.
  llvm::ScopedHashTable<StringRef, char> usedNames;
  /// Owns the storage of the names handed out above.
  llvm::BumpPtrAllocator usedNameAllocator;

  /// Counters used while numbering.
  unsigned nextValueID = 0;
  unsigned nextArgumentID = 0;
  unsigned nextConflictID = 0;

  /// Flags controlling whether custom names are requested.
  OpPrintingFlags printerFlags;
};
} // namespace
/// Number all values and blocks nested under `op`, walking regions with an
/// explicit worklist. Each region is numbered with the counters and name scope
/// that were current when its parent region finished numbering. Sibling
/// regions therefore start from the same state.
SSANameState::SSANameState(Operation *op, const OpPrintingFlags &printerFlags)
    : printerFlags(printerFlags) {
  // The counters are used as scratch state during the walk and are restored to
  // their initial values when the constructor returns.
  llvm::SaveAndRestore<unsigned> valueIDSaver(nextValueID);
  llvm::SaveAndRestore<unsigned> argumentIDSaver(nextArgumentID);
  llvm::SaveAndRestore<unsigned> conflictIDSaver(nextConflictID);

  // A pending region plus the counter values and the parent name scope to
  // resume from.
  using UsedNamesScopeTy = llvm::ScopedHashTable<StringRef, char>::ScopeTy;
  using NamingContext =
      std::tuple<Region *, unsigned, unsigned, unsigned, UsedNamesScopeTy *>;

  // Scopes are placement-new'ed here and destroyed explicitly, because their
  // lifetimes follow the worklist rather than C++ lexical scoping.
  llvm::BumpPtrAllocator allocator;

  auto *topLevelNamesScope =
      new (allocator.Allocate<UsedNamesScopeTy>()) UsedNamesScopeTy(usedNames);

  SmallVector<NamingContext, 8> nameContext;
  for (Region &topRegion : op->getRegions())
    nameContext.push_back(std::make_tuple(&topRegion, nextValueID,
                                          nextArgumentID, nextConflictID,
                                          topLevelNamesScope));

  numberValuesInOp(*op);

  while (!nameContext.empty()) {
    Region *region;
    UsedNamesScopeTy *parentScope;
    std::tie(region, nextValueID, nextArgumentID, nextConflictID, parentScope) =
        nameContext.pop_back_val();

    // Destroy scopes until the parent's scope is innermost again; destroying a
    // scope makes the enclosing one current.
    while (usedNames.getCurScope() != parentScope) {
      usedNames.getCurScope()->~UsedNamesScopeTy();
      assert((usedNames.getCurScope() != nullptr || parentScope == nullptr) &&
             "top level parentScope must be a nullptr");
    }

    auto *curNamesScope = new (allocator.Allocate<UsedNamesScopeTy>())
        UsedNamesScopeTy(usedNames);

    numberValuesInRegion(*region);

    // Queue the regions of the operations nested in this region. Distinct
    // names avoid shadowing `op` and `region` above.
    for (Operation &nestedOp : region->getOps())
      for (Region &nestedRegion : nestedOp.getRegions())
        nameContext.push_back(std::make_tuple(&nestedRegion, nextValueID,
                                              nextArgumentID, nextConflictID,
                                              curNamesScope));
  }

  // Tear down any scopes that are still open.
  while (usedNames.getCurScope() != nullptr)
    usedNames.getCurScope()->~UsedNamesScopeTy();
}
/// Print `%` followed by the name or number of `value`. Results that belong to
/// a multi-result group print relative to the group's first result. Their
/// `#N` suffix is emitted only when `printResultNo` is set.
void SSANameState::printValueID(Value value, bool printResultNo,
                                raw_ostream &stream) const {
  if (!value) {
    stream << "<<NULL VALUE>>";
    return;
  }

  Value lookupValue = value;
  Optional<int> resultNo;
  if (auto result = value.dyn_cast<OpResult>())
    getResultIDAndNumber(result, lookupValue, resultNo);

  auto idIt = valueIDs.find(lookupValue);
  if (idIt == valueIDs.end()) {
    stream << "<<UNKNOWN SSA VALUE>>";
    return;
  }

  stream << '%';
  if (idIt->second == NameSentinel) {
    auto nameIt = valueNames.find(lookupValue);
    assert(nameIt != valueNames.end() && "Didn't have a name entry?");
    stream << nameIt->second;
  } else {
    stream << idIt->second;
  }

  if (printResultNo && resultNo)
    stream << '#' << *resultNo;
}
/// Print `%` followed by the ID assigned to `op`, or a placeholder if `op` was
/// never assigned one.
void SSANameState::printOperationID(Operation *op, raw_ostream &stream) const {
  auto it = operationIDs.find(op);
  if (it == operationIDs.end()) {
    // Spelled consistently with the "<<UNKNOWN SSA VALUE>>" placeholder (this
    // previously read "UNKOWN").
    stream << "<<UNKNOWN OPERATION>>";
  } else {
    stream << '%' << it->second;
  }
}
/// Return the group-start result numbers of `op`, or an empty array if none
/// were recorded.
ArrayRef<int> SSANameState::getOpResultGroups(Operation *op) {
  auto groupIt = opResultGroups.find(op);
  if (groupIt == opResultGroups.end())
    return {};
  return groupIt->second;
}

/// Return the info recorded for `block`, or an invalid placeholder entry.
BlockInfo SSANameState::getBlockInfo(Block *block) {
  auto blockIt = blockNames.find(block);
  if (blockIt != blockNames.end())
    return blockIt->second;
  return BlockInfo{-1, "INVALIDBLOCK"};
}
void SSANameState::shadowRegionArgs(Region ®ion, ValueRange namesToUse) {
assert(!region.empty() && "cannot shadow arguments of an empty region");
assert(region.getNumArguments() == namesToUse.size() &&
"incorrect number of names passed in");
assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
"only KnownIsolatedFromAbove ops can shadow names");
SmallVector<char, 16> nameStr;
for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) {
auto nameToUse = namesToUse[i];
if (nameToUse == nullptr)
continue;
auto nameToReplace = region.getArgument(i);
nameStr.clear();
llvm::raw_svector_ostream nameStream(nameStr);
printValueID(nameToUse, true, nameStream);
assert(valueIDs[nameToReplace] == NameSentinel);
auto name = StringRef(nameStream.str()).drop_front();
valueNames[nameToReplace] = name.copy(usedNameAllocator);
}
}
void SSANameState::numberValuesInRegion(Region ®ion) {
auto setBlockArgNameFn = [&](Value arg, StringRef name) {
assert(!valueIDs.count(arg) && "arg numbered multiple times");
assert(arg.cast<BlockArgument>().getOwner()->getParent() == ®ion &&
"arg not defined in current region");
setValueName(arg, name);
};
if (!printerFlags.shouldPrintGenericOpForm()) {
if (Operation *op = region.getParentOp()) {
if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op))
asmInterface.getAsmBlockArgumentNames(region, setBlockArgNameFn);
}
}
unsigned nextBlockID = 0;
for (auto &block : region) {
auto blockInfoIt = blockNames.insert({&block, {-1, ""}});
if (blockInfoIt.second) {
std::string name;
llvm::raw_string_ostream(name) << "^bb" << nextBlockID;
blockInfoIt.first->second.name = StringRef(name).copy(usedNameAllocator);
}
blockInfoIt.first->second.ordering = nextBlockID++;
numberValuesInBlock(block);
}
}
/// Number the arguments of `block` that were not named yet, then number its
/// operations. Entry-block arguments are named `argN`. Other block arguments
/// get plain numeric IDs, requested by passing an empty name.
void SSANameState::numberValuesInBlock(Block &block) {
  const bool isEntryBlock = block.isEntryBlock();
  SmallString<32> argName;
  for (BlockArgument arg : block.getArguments()) {
    if (valueIDs.count(arg))
      continue;
    argName.clear();
    if (isEntryBlock)
      llvm::raw_svector_ostream(argName) << "arg" << nextArgumentID++;
    setValueName(arg, argName.str());
  }

  for (Operation &op : block)
    numberValuesInOp(op);
}
/// Assign names/IDs to the results of `op` and record custom names for the
/// blocks directly nested under it.
void SSANameState::numberValuesInOp(Operation &op) {
  // Result numbers that start a new result group. Result #0 always starts the
  // first group; each other result named through setResultNameFn starts its
  // own group.
  SmallVector<int, 2> resultGroups(1, 0);
  auto setResultNameFn = [&](Value result, StringRef name) {
    assert(!valueIDs.count(result) && "result numbered multiple times");
    assert(result.getDefiningOp() == &op && "result not defined by 'op'");
    setValueName(result, name);

    // Result #0 is already recorded as a group start.
    if (int resultNo = result.cast<OpResult>().getResultNumber())
      resultGroups.push_back(resultNo);
  };
  auto setBlockNameFn = [&](Block *block, StringRef name) {
    assert(block->getParentOp() == &op &&
           "getAsmBlockArgumentNames callback invoked on a block not directly "
           "nested under the current operation");
    assert(!blockNames.count(block) && "block numbered multiple times");
    // Block names are stored with a leading '^'.
    SmallString<16> tmpBuffer{"^"};
    name = sanitizeIdentifier(name, tmpBuffer);
    // sanitizeIdentifier returns `name` itself (not the buffer) when it is
    // already valid; in that case append it after the '^' prefix here.
    if (name.data() != tmpBuffer.data()) {
      tmpBuffer.append(name);
      name = tmpBuffer.str();
    }
    name = name.copy(usedNameAllocator);
    // The ordering is assigned later, when the enclosing region is numbered.
    blockNames[block] = {-1, name};
  };

  // Custom names are only requested when printing the custom op form.
  if (!printerFlags.shouldPrintGenericOpForm()) {
    if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op)) {
      asmInterface.getAsmBlockNames(setBlockNameFn);
      asmInterface.getAsmResultNames(setResultNameFn);
    }
  }

  unsigned numResults = op.getNumResults();
  if (numResults == 0) {
    // Result-less operations only get an ID when value users are printed.
    // NOTE(review): presumably so users can be referenced by this ID.
    if (printerFlags.shouldPrintValueUsers()) {
      if (operationIDs.try_emplace(&op, nextValueID).second)
        ++nextValueID;
    }
    return;
  }

  // Give result #0 a numeric ID unless it was already named above. Other
  // results are looked up relative to the first result of their group (see
  // getResultIDAndNumber).
  Value resultBegin = op.getResult(0);
  if (valueIDs.try_emplace(resultBegin, nextValueID).second)
    ++nextValueID;

  // Only record the groups if the results were actually split.
  if (resultGroups.size() != 1) {
    llvm::array_pod_sort(resultGroups.begin(), resultGroups.end());
    opResultGroups.try_emplace(&op, std::move(resultGroups));
  }
}
/// Find the value that owns the ID printed for `result`, and the result number
/// to print after that ID.
///
/// On return, `lookupValue` is the first result of the group containing
/// `result`. `lookupResultNo` is the offset of `result` within that group. It
/// is left unset for single-result operations and single-element groups.
/// `lookupValue` is left untouched for single-result operations.
void SSANameState::getResultIDAndNumber(OpResult result, Value &lookupValue,
                                        Optional<int> &lookupResultNo) const {
  Operation *owner = result.getOwner();
  if (owner->getNumResults() == 1)
    return;
  int resultNo = result.getResultNumber();

  // Without recorded groups, all results form one group headed by result #0.
  auto resultGroupIt = opResultGroups.find(owner);
  if (resultGroupIt == opResultGroups.end()) {
    lookupResultNo = resultNo;
    lookupValue = owner->getResult(0);
    return;
  }

  // The groups are sorted and start with 0, so the element before the upper
  // bound is the start of the group containing `resultNo`.
  ArrayRef<int> resultGroups = resultGroupIt->second;
  const auto *it = llvm::upper_bound(resultGroups, resultNo);
  int groupResultNo = 0, groupSize = 0;

  // The last group extends to the final result of the operation.
  if (it == resultGroups.end()) {
    groupResultNo = resultGroups.back();
    groupSize = static_cast<int>(owner->getNumResults()) - resultGroups.back();
  } else {
    groupResultNo = *std::prev(it);
    groupSize = *it - groupResultNo;
  }

  // A single-element group is printed without a `#N` suffix.
  if (groupSize != 1)
    lookupResultNo = resultNo - groupResultNo;
  lookupValue = owner->getResult(groupResultNo);
}
/// Name `value`. A non-empty `name` is uniqued and stored, with NameSentinel
/// as its ID; an empty name assigns the next numeric ID instead.
void SSANameState::setValueName(Value value, StringRef name) {
  if (!name.empty()) {
    valueIDs[value] = NameSentinel;
    valueNames[value] = uniqueValueName(name);
    return;
  }
  valueIDs[value] = nextValueID++;
}
/// Sanitize `name` and make it unique among the names in scope. On a conflict,
/// `_N` is appended using the shared conflict counter. The returned name is
/// allocated in `usedNameAllocator` and recorded as used.
StringRef SSANameState::uniqueValueName(StringRef name) {
  SmallString<16> sanitizedStorage;
  StringRef base = sanitizeIdentifier(name, sanitizedStorage);

  StringRef unique;
  if (usedNames.count(base)) {
    SmallString<64> candidate(base);
    candidate.push_back('_');
    const size_t prefixLength = candidate.size();
    do {
      candidate.resize(prefixLength);
      candidate += llvm::utostr(nextConflictID++);
    } while (usedNames.count(candidate));
    unique = candidate.str().copy(usedNameAllocator);
  } else {
    unique = base.copy(usedNameAllocator);
  }

  usedNames.insert(unique, char());
  return unique;
}
// Out-of-line destructor definitions for the resource parsing/printing
// classes. NOTE(review): presumably these anchor the classes' vtables in this
// file; confirm against the declarations in AsmState.h.
AsmParsedResourceEntry::~AsmParsedResourceEntry() = default;
AsmResourceBuilder::~AsmResourceBuilder() = default;
AsmResourceParser::~AsmResourceParser() = default;
AsmResourcePrinter::~AsmResourcePrinter() = default;
namespace mlir {
namespace detail {
/// Shared state used while printing an operation. It holds the dialect asm
/// interfaces, the computed aliases and SSA names, the printing flags, the
/// attached resource printers, and an optional map recording where each
/// operation was printed.
class AsmStateImpl {
public:
  // Note: `nameState` is built from the `printerFlags` constructor argument,
  // not from the member of the same name. Members are initialized in
  // declaration order.
  explicit AsmStateImpl(Operation *op, const OpPrintingFlags &printerFlags,
                        AsmState::LocationMap *locationMap)
      : interfaces(op->getContext()), nameState(op, printerFlags),
        printerFlags(printerFlags), locationMap(locationMap) {}

  /// Compute the aliases for everything printed within `op`.
  void initializeAliases(Operation *op) {
    aliasState.initialize(op, printerFlags, interfaces);
  }

  /// Return the alias state.
  AliasState &getAliasState() { return aliasState; }

  /// Return the SSA name state.
  SSANameState &getSSANameState() { return nameState; }

  /// Return the dialect asm interfaces of the context.
  DialectInterfaceCollection<OpAsmDialectInterface> &getDialectInterfaces() {
    return interfaces;
  }

  /// Return a range over the attached external resource printers.
  auto getResourcePrinters() {
    return llvm::make_pointee_range(externalResourcePrinters);
  }

  /// Return the printing flags in use.
  const OpPrintingFlags &getPrinterFlags() const { return printerFlags; }

  /// Record that `op` was printed at `line`/`col`, if a location map was
  /// provided.
  void registerOperationLocation(Operation *op, unsigned line, unsigned col) {
    if (locationMap)
      (*locationMap)[op] = std::make_pair(line, col);
  }

private:
  /// The OpAsmDialectInterfaces of the context.
  DialectInterfaceCollection<OpAsmDialectInterface> interfaces;

  /// Resource printers attached through AsmState::attachResourcePrinter.
  SmallVector<std::unique_ptr<AsmResourcePrinter>> externalResourcePrinters;

  /// Aliases for attributes and types.
  AliasState aliasState;

  /// Names and IDs of values, blocks and operations.
  SSANameState nameState;

  /// The flags used for printing.
  OpPrintingFlags printerFlags;

  /// Optional (possibly null) map receiving operation print locations.
  AsmState::LocationMap *locationMap;

  friend AsmState;
};
} // namespace detail
} // namespace mlir
/// Return the flags to print `op` with. If the flags neither force the
/// generic op form nor ask to assume verified IR, `op` is verified first.
/// When verification fails, the generic op form is enabled on the returned
/// flags, so custom printers are not used on unverified IR.
static OpPrintingFlags verifyOpAndAdjustFlags(Operation *op,
                                              OpPrintingFlags printerFlags) {
  if (printerFlags.shouldPrintGenericOpForm() ||
      printerFlags.shouldAssumeVerified())
    return printerFlags;

  LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << ": Verifying operation: "
                          << op->getName() << "\n");

  // Consume the verifier's diagnostics emitted on this thread (only dumped to
  // dbgs() in debug builds). Diagnostics from other threads are not handled
  // here (failure() is returned), so they are not swallowed.
  auto parentThreadId = llvm::get_threadid();
  ScopedDiagnosticHandler diagHandler(op->getContext(), [&](Diagnostic &diag) {
    if (parentThreadId == llvm::get_threadid()) {
      LLVM_DEBUG({
        diag.print(llvm::dbgs());
        llvm::dbgs() << "\n";
      });
      return success();
    }
    return failure();
  });

  if (failed(verify(op))) {
    LLVM_DEBUG(llvm::dbgs()
               << DEBUG_TYPE << ": '" << op->getName()
               << "' failed to verify and will be printed in generic form\n");
    printerFlags.printGenericOpForm();
  }

  return printerFlags;
}
/// Create the printing state for `op`, verifying it first unless the flags
/// say otherwise (see verifyOpAndAdjustFlags).
AsmState::AsmState(Operation *op, const OpPrintingFlags &printerFlags,
                   LocationMap *locationMap)
    : impl(std::make_unique<AsmStateImpl>(
          op, verifyOpAndAdjustFlags(op, printerFlags), locationMap)) {}

AsmState::~AsmState() = default;

/// Return the (possibly adjusted) flags used for printing.
const OpPrintingFlags &AsmState::getPrinterFlags() const {
  const AsmStateImpl &state = *impl;
  return state.getPrinterFlags();
}

/// Take ownership of `printer` and add it to the external resource printers.
void AsmState::attachResourcePrinter(
    std::unique_ptr<AsmResourcePrinter> printer) {
  impl->externalResourcePrinters.push_back(std::move(printer));
}
namespace mlir {
/// Implementation of the attribute, type, location and affine printing used
/// by the AsmPrinter classes. It writes to `os` using optional shared
/// AsmStateImpl state.
class AsmPrinter::Impl {
public:
  Impl(raw_ostream &os, OpPrintingFlags flags = llvm::None,
       AsmStateImpl *state = nullptr)
      : os(os), printerFlags(flags), state(state) {}
  /// Create a printer sharing the stream, flags and state of `other`.
  explicit Impl(Impl &other)
      : Impl(other.os, other.printerFlags, other.state) {}

  /// Return the output stream.
  raw_ostream &getStream() { return os; }

  /// Call `eachFn` on each element of `c`, separated by ", " on the stream.
  template <typename Container, typename UnaryFunctor>
  inline void interleaveComma(const Container &c, UnaryFunctor eachFn) const {
    llvm::interleaveComma(c, os, eachFn);
  }

  /// Controls eliding the type of a printed attribute. NOTE(review): the
  /// exact semantics of May/Must are defined by printAttribute (not visible
  /// here).
  enum class AttrTypeElision {
    Never,
    May,
    Must
  };

  /// Print `attr`, honoring `typeElision`.
  void printAttribute(Attribute attr,
                      AttrTypeElision typeElision = AttrTypeElision::Never);

  /// Print the alias of `attr`, failing if it has none.
  LogicalResult printAlias(Attribute attr);

  /// Print `type`.
  void printType(Type type);

  /// Print the alias of `type`, failing if it has none.
  LogicalResult printAlias(Type type);

  /// Print `loc`, optionally via its alias.
  void printLocation(LocationAttr loc, bool allowAlias = false);

  /// Print the key of `resource` (as provided by its dialect's
  /// OpAsmDialectInterface) and record the handle as referenced by its dialect.
  void printResourceHandle(const AsmDialectResourceHandle &resource) {
    auto *interface = cast<OpAsmDialectInterface>(resource.getDialect());
    os << interface->getResourceKey(resource);
    dialectResources[resource.getDialect()].insert(resource);
  }

  /// Affine map / expression / integer set printing.
  void printAffineMap(AffineMap map);
  void
  printAffineExpr(AffineExpr expr,
                  function_ref<void(unsigned, bool)> printValueName = nullptr);
  void printAffineConstraint(AffineExpr expr, bool isEq);
  void printIntegerSet(IntegerSet set);

protected:
  /// Print `attrs` except those named in `elidedAttrs`, optionally preceded by
  /// a keyword.
  void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
                             ArrayRef<StringRef> elidedAttrs = {},
                             bool withKeyword = false);

  /// Print a single named attribute.
  void printNamedAttribute(NamedAttribute attr);

  /// Print " " followed by `loc`, only when debug info printing is enabled.
  void printTrailingLocation(Location loc, bool allowAlias = true);

  /// Print `loc` without considering aliases, optionally in the pretty form.
  void printLocationInternal(LocationAttr loc, bool pretty = false);

  /// Dense elements attribute printing; `allowHex` permits hex encoding.
  void printDenseElementsAttr(DenseElementsAttr attr, bool allowHex);
  void printDenseStringElementsAttr(DenseStringElementsAttr attr);
  void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
                                     bool allowHex);

  /// Print an attribute/type defined by a dialect.
  void printDialectAttribute(Attribute attr);
  void printDialectType(Type type);

  /// Print an escaped string.
  void printEscapedString(StringRef str);

  /// Print data as a hex string.
  void printHexString(StringRef str);
  void printHexString(ArrayRef<char> data);

  /// How tightly an enclosing affine expression binds its operands, used to
  /// decide on parenthesization.
  enum class BindingStrength {
    Weak,   // + and -
    Strong, // All other binary operators.
  };
  void printAffineExprInternal(
      AffineExpr expr, BindingStrength enclosingTightness,
      function_ref<void(unsigned, bool)> printValueName = nullptr);

  /// The output stream.
  raw_ostream &os;

  /// The printing flags.
  OpPrintingFlags printerFlags;

  /// Optional shared printing state; may be null.
  AsmStateImpl *state;

  /// Counts the lines emitted through `newLine`.
  NewLineCounter newLine;

  /// Dialect resource handles referenced while printing, per dialect.
  DenseMap<Dialect *, SetVector<AsmDialectResourceHandle>> dialectResources;
};
} // namespace mlir
/// Print ` <location>` after an operation or argument. This only happens when
/// debug info printing is enabled in the printer flags.
void AsmPrinter::Impl::printTrailingLocation(Location loc, bool allowAlias) {
  if (printerFlags.shouldPrintDebugInfo()) {
    os << ' ';
    printLocation(loc, allowAlias);
  }
}
/// Print `loc` without the surrounding `loc(...)`. With `pretty` set, a more
/// human-oriented form is used: no `callsite(`/`fused` keywords, unquoted file
/// names, `[unknown]`, and call sites split across lines.
void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty) {
  TypeSwitch<LocationAttr>(loc)
      .Case<OpaqueLoc>([&](OpaqueLoc loc) {
        // Opaque locations are printed through their fallback location.
        printLocationInternal(loc.getFallbackLocation(), pretty);
      })
      .Case<UnknownLoc>([&](UnknownLoc loc) {
        os << (pretty ? "[unknown]" : "unknown");
      })
      .Case<FileLineColLoc>([&](FileLineColLoc loc) {
        if (pretty)
          os << loc.getFilename().getValue();
        else
          printEscapedString(loc.getFilename());
        os << ':' << loc.getLine() << ':' << loc.getColumn();
      })
      .Case<NameLoc>([&](NameLoc loc) {
        printEscapedString(loc.getName());
        // Print the child location in parentheses unless it is unknown.
        auto childLoc = loc.getChildLoc();
        if (!childLoc.isa<UnknownLoc>()) {
          os << '(';
          printLocationInternal(childLoc, pretty);
          os << ')';
        }
      })
      .Case<CallSiteLoc>([&](CallSiteLoc loc) {
        Location caller = loc.getCaller();
        Location callee = loc.getCallee();
        if (!pretty)
          os << "callsite(";
        printLocationInternal(callee, pretty);
        // In the pretty form, `name at file:line:col` stays on one line. Any
        // other callee/caller combination starts the caller on a new line.
        if (pretty && !(callee.isa<NameLoc>() && caller.isa<FileLineColLoc>()))
          os << newLine;
        os << " at ";
        printLocationInternal(caller, pretty);
        if (!pretty)
          os << ")";
      })
      .Case<FusedLoc>([&](FusedLoc loc) {
        if (!pretty)
          os << "fused";
        // Print the metadata through this printer, not the global
        // `operator<<`, so that aliases, elision flags and dialect resource
        // tracking apply to it like any other nested attribute.
        if (Attribute metadata = loc.getMetadata()) {
          os << '<';
          printAttribute(metadata);
          os << '>';
        }
        os << '[';
        interleave(
            loc.getLocations(),
            [&](Location loc) { printLocationInternal(loc, pretty); },
            [&]() { os << ", "; });
        os << ']';
      });
}
/// Print a floating point value so that it can be parsed back bit-exactly.
/// Three forms are tried in order:
///  1. a short decimal form, used if it round-trips to the same bits;
///  2. APFloat's default decimal form, used if it contains a '.';
///  3. otherwise (and always for inf/NaN), the raw bit pattern in hex.
static void printFloatValue(const APFloat &apValue, raw_ostream &os) {
bool isInf = apValue.isInfinity();
bool isNaN = apValue.isNaN();
if (!isInf && !isNaN) {
SmallString<128> strValue;
// Arguments after the buffer: FormatPrecision=6, FormatMaxPadding=0,
// TruncateZero=false.
apValue.toString(strValue, 6, 0,
false);
// Sanity check: the short form must begin with an optional sign and a digit.
assert(((strValue[0] >= '0' && strValue[0] <= '9') ||
((strValue[0] == '-' || strValue[0] == '+') &&
(strValue[1] >= '0' && strValue[1] <= '9'))) &&
"[-+]?[0-9] regex does not match!");
// Use the short form only if re-parsing it yields exactly the same bits.
if (APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(apValue)) {
os << strValue;
return;
}
// Fall back to APFloat's default formatting.
strValue.clear();
apValue.toString(strValue);
// NOTE(review): presumably this requires a '.' so the text can't be
// re-read as an integer literal; confirm against the parser.
if (strValue.str().contains('.')) {
os << strValue;
return;
}
}
// Last resort: print the bit pattern as a hex literal.
// Arguments after the buffer: Radix=16, Signed=false, formatAsCLiteral=true.
SmallVector<char, 16> str;
APInt apInt = apValue.bitcastToAPInt();
apInt.toString(str, 16, false,
true);
os << str;
}
/// Print a location. The pretty debug form prints it bare. Otherwise it is
/// wrapped in `loc(...)` and may use an alias when `allowAlias` is set and an
/// AsmState is available.
void AsmPrinter::Impl::printLocation(LocationAttr loc, bool allowAlias) {
  if (printerFlags.shouldPrintDebugInfoPrettyForm()) {
    printLocationInternal(loc, /*pretty=*/true);
    return;
  }

  os << "loc(";
  bool usedAlias = allowAlias && state &&
                   succeeded(state->getAliasState().getAlias(loc, os));
  if (!usedAlias)
    printLocationInternal(loc);
  os << ')';
}
/// Return true if `symName` can be printed in the pretty dialect-symbol form
/// `<prefix><dialect>.<symName>`. That requires an ASCII letter first, then
/// ASCII alphanumerics, '.' or '_'. If anything remains after those, it must
/// start with '<' and end with '>'.
static bool isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName) {
  // Use the locale-independent ASCII classifier. Passing a plain (possibly
  // negative) `char` to `isalpha` is undefined behavior for non-ASCII bytes.
  if (symName.empty() || !llvm::isAlpha(symName.front()))
    return false;
  symName = symName.drop_while(
      [](char c) { return llvm::isAlnum(c) || c == '.' || c == '_'; });
  if (symName.empty())
    return true;
  // The remainder must look like a `<...>` body.
  return symName.front() == '<' && symName.back() == '>';
}
/// Print a dialect attribute/type symbol as `<prefix><dialect>.<sym>` when the
/// body is simple enough, and as `<prefix><dialect><<sym>>` otherwise.
static void printDialectSymbol(raw_ostream &os, StringRef symPrefix,
                               StringRef dialectName, StringRef symString) {
  os << symPrefix << dialectName;
  if (!isDialectSymbolSimpleEnoughForPrettyForm(symString)) {
    // Wrap anything the pretty form can't represent in angle brackets.
    os << '<' << symString << '>';
    return;
  }
  os << '.' << symString;
}
/// Return true if `name` can be printed without quotes. That means it starts
/// with an ASCII letter or '_', followed by ASCII alphanumerics, '_', '$' or '.'.
static bool isBareIdentifier(StringRef name) {
  // ASCII-only, locale-independent checks. `isalpha(name[0])` on a plain char
  // is undefined for non-ASCII bytes, and locale-aware `isalnum` could accept
  // bytes that the textual format does not treat as identifier characters.
  if (name.empty() || (!llvm::isAlpha(name.front()) && name.front() != '_'))
    return false;
  return llvm::all_of(name.drop_front(), [](char c) {
    return llvm::isAlnum(c) || c == '_' || c == '$' || c == '.';
  });
}
/// Print `keyword` bare when it is a valid bare identifier. Otherwise print it
/// as a double-quoted, escaped string.
static void printKeywordOrString(StringRef keyword, raw_ostream &os) {
  if (!isBareIdentifier(keyword)) {
    os << '"';
    printEscapedString(keyword, os);
    os << '"';
    return;
  }
  os << keyword;
}
/// Print a symbol reference: `@` followed by the name, either bare or quoted.
static void printSymbolReference(StringRef symbolRef, raw_ostream &os) {
  assert(!symbolRef.empty() && "expected valid symbol reference");
  constexpr char symbolSigil = '@';
  os << symbolSigil;
  printKeywordOrString(symbolRef, os);
}
/// Print the placeholder used in place of an elided elements attribute.
static void printElidedElementsAttr(raw_ostream &os) {
  static constexpr const char *kElidedPlaceholder =
      "opaque<\"elided_large_const\", \"0xDEADBEEF\">";
  os << kElidedPlaceholder;
}
/// Try to print an alias for `attr`. Fails when there is no AsmState or no
/// alias was registered for the attribute.
LogicalResult AsmPrinter::Impl::printAlias(Attribute attr) {
  if (!state)
    return failure();
  return state->getAliasState().getAlias(attr, os);
}
/// Try to print an alias for `type`. Fails when there is no AsmState or no
/// alias was registered for the type.
LogicalResult AsmPrinter::Impl::printAlias(Type type) {
  if (!state)
    return failure();
  return state->getAliasState().getAlias(type, os);
}
/// Print `attr`. An alias is used when one exists. Builtin attributes use their
/// builtin syntax, and other dialects' attributes go through
/// printDialectAttribute. `typeElision` controls whether the trailing
/// ` : type` is printed. Branches that `return` early never print it.
void AsmPrinter::Impl::printAttribute(Attribute attr,
AttrTypeElision typeElision) {
if (!attr) {
os << "<<NULL ATTRIBUTE>>";
return;
}
// An alias, when available, replaces the whole attribute including its type.
if (succeeded(printAlias(attr)))
return;
auto attrType = attr.getType();
if (!isa<BuiltinDialect>(attr.getDialect())) {
printDialectAttribute(attr);
} else if (auto opaqueAttr = attr.dyn_cast<OpaqueAttr>()) {
printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(),
opaqueAttr.getAttrData());
} else if (attr.isa<UnitAttr>()) {
os << "unit";
return;
} else if (auto dictAttr = attr.dyn_cast<DictionaryAttr>()) {
os << '{';
interleaveComma(dictAttr.getValue(),
[&](NamedAttribute attr) { printNamedAttribute(attr); });
os << '}';
} else if (auto intAttr = attr.dyn_cast<IntegerAttr>()) {
// Signless i1 prints as `true`/`false` and never carries a type suffix.
if (attrType.isSignlessInteger(1)) {
os << (intAttr.getValue().getBoolValue() ? "true" : "false");
return;
}
// Only explicitly unsigned integers print unsigned. NOTE(review): the
// signless-i1 disjunct is unreachable because that case returned above.
bool isUnsigned =
attrType.isUnsignedInteger() || attrType.isSignlessInteger(1);
intAttr.getValue().print(os, !isUnsigned);
// i64 is the implied integer type when elision is allowed.
if (typeElision == AttrTypeElision::May && attrType.isSignlessInteger(64))
return;
} else if (auto floatAttr = attr.dyn_cast<FloatAttr>()) {
printFloatValue(floatAttr.getValue(), os);
// f64 is the implied float type when elision is allowed.
if (typeElision == AttrTypeElision::May && attrType.isF64())
return;
} else if (auto strAttr = attr.dyn_cast<StringAttr>()) {
printEscapedString(strAttr.getValue());
} else if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
os << '[';
interleaveComma(arrayAttr.getValue(), [&](Attribute attr) {
printAttribute(attr, AttrTypeElision::May);
});
os << ']';
} else if (auto affineMapAttr = attr.dyn_cast<AffineMapAttr>()) {
os << "affine_map<";
affineMapAttr.getValue().print(os);
os << '>';
return;
} else if (auto integerSetAttr = attr.dyn_cast<IntegerSetAttr>()) {
os << "affine_set<";
integerSetAttr.getValue().print(os);
os << '>';
return;
} else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) {
printType(typeAttr.getValue());
} else if (auto refAttr = attr.dyn_cast<SymbolRefAttr>()) {
// `@root::@nested::@leaf`
printSymbolReference(refAttr.getRootReference().getValue(), os);
for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) {
os << "::";
printSymbolReference(nestedRef.getValue(), os);
}
} else if (auto opaqueAttr = attr.dyn_cast<OpaqueElementsAttr>()) {
if (printerFlags.shouldElideElementsAttr(opaqueAttr)) {
printElidedElementsAttr(os);
} else {
os << "opaque<" << opaqueAttr.getDialect() << ", ";
printHexString(opaqueAttr.getValue());
os << ">";
}
} else if (auto intOrFpEltAttr = attr.dyn_cast<DenseIntOrFPElementsAttr>()) {
if (printerFlags.shouldElideElementsAttr(intOrFpEltAttr)) {
printElidedElementsAttr(os);
} else {
os << "dense<";
printDenseIntOrFPElementsAttr(intOrFpEltAttr, true);
os << '>';
}
} else if (auto strEltAttr = attr.dyn_cast<DenseStringElementsAttr>()) {
if (printerFlags.shouldElideElementsAttr(strEltAttr)) {
printElidedElementsAttr(os);
} else {
os << "dense<";
printDenseStringElementsAttr(strEltAttr);
os << '>';
}
} else if (auto sparseEltAttr = attr.dyn_cast<SparseElementsAttr>()) {
if (printerFlags.shouldElideElementsAttr(sparseEltAttr.getIndices()) ||
printerFlags.shouldElideElementsAttr(sparseEltAttr.getValues())) {
printElidedElementsAttr(os);
} else {
// `sparse<indices, values>`, or `sparse<>` when there are no indices.
// Indices are never printed in hex.
os << "sparse<";
DenseIntElementsAttr indices = sparseEltAttr.getIndices();
if (indices.getNumElements() != 0) {
printDenseIntOrFPElementsAttr(indices, false);
os << ", ";
printDenseElementsAttr(sparseEltAttr.getValues(), true);
}
os << '>';
}
} else if (auto denseArrayAttr = attr.dyn_cast<DenseArrayBaseAttr>()) {
// The element type is spelled inline (`[:i32 1, 2]`), so the trailing type
// is always suppressed.
typeElision = AttrTypeElision::Must;
switch (denseArrayAttr.getElementType()) {
case DenseArrayBaseAttr::EltType::I8:
os << "[:i8";
break;
case DenseArrayBaseAttr::EltType::I16:
os << "[:i16";
break;
case DenseArrayBaseAttr::EltType::I32:
os << "[:i32";
break;
case DenseArrayBaseAttr::EltType::I64:
os << "[:i64";
break;
case DenseArrayBaseAttr::EltType::F32:
os << "[:f32";
break;
case DenseArrayBaseAttr::EltType::F64:
os << "[:f64";
break;
}
if (denseArrayAttr.getType().cast<ShapedType>().getRank())
os << " ";
denseArrayAttr.printWithoutBraces(os);
os << "]";
} else if (auto locAttr = attr.dyn_cast<LocationAttr>()) {
printLocation(locAttr);
} else {
llvm::report_fatal_error("Unknown builtin attribute");
}
// Trailing ` : type`, unless elision is mandatory or the attribute is untyped
// (NoneType).
if (typeElision != AttrTypeElision::Must && !attrType.isa<NoneType>()) {
os << " : ";
printType(attrType);
}
}
/// Print one integer element of a dense attribute. 1-bit values print as
/// `true`/`false`; wider values print as signed or unsigned decimal.
static void printDenseIntElement(const APInt &value, raw_ostream &os,
                                 bool isSigned) {
  if (value.getBitWidth() != 1) {
    value.print(os, isSigned);
    return;
  }
  os << (value.getBoolValue() ? "true" : "false");
}
/// Print the elements of a dense attribute of shape `type` as nested,
/// comma-separated lists with one bracket level per dimension. Each element is
/// printed by calling `printEltFn` with its linear (row-major) index.
static void
printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os,
                           function_ref<void(unsigned)> printEltFn) {
  // A splat stores one value and prints it bare. A 0-d shape holds exactly one
  // element and has no brackets. Handling it here keeps the counter logic
  // below from indexing into an empty `counter` vector
  // (`counter[rank - 1]` with rank == 0).
  if (isSplat || type.getRank() == 0)
    return printEltFn(0);

  // Degenerate shapes (some dimension is 0) print nothing.
  auto numElements = type.getNumElements();
  if (numElements == 0)
    return;

  // Walk the shape with a mixed-radix counter whose radices are the dims.
  // Carrying into a more significant digit closes a bracket; brackets are
  // re-opened before the next element.
  int64_t rank = type.getRank();
  SmallVector<unsigned, 4> counter(rank, 0);
  unsigned openBrackets = 0;
  auto shape = type.getShape();
  auto bumpCounter = [&] {
    ++counter[rank - 1];
    for (unsigned i = rank - 1; i > 0; --i)
      if (counter[i] >= shape[i]) {
        counter[i] = 0;
        ++counter[i - 1];
        --openBrackets;
        os << ']';
      }
  };
  for (unsigned idx = 0, e = numElements; idx != e; ++idx) {
    if (idx != 0)
      os << ", ";
    while (openBrackets++ < rank)
      os << '[';
    openBrackets = rank;
    printEltFn(idx);
    bumpCounter();
  }
  while (openBrackets-- > 0)
    os << ']';
}
/// Print the body of a dense elements attribute (without `dense<...>`).
/// String elements use their own printer; everything else is treated as an
/// int/float (or complex) dense attribute.
void AsmPrinter::Impl::printDenseElementsAttr(DenseElementsAttr attr,
                                              bool allowHex) {
  auto stringAttr = attr.dyn_cast<DenseStringElementsAttr>();
  if (!stringAttr) {
    printDenseIntOrFPElementsAttr(attr.cast<DenseIntOrFPElementsAttr>(),
                                  allowHex);
    return;
  }
  printDenseStringElementsAttr(stringAttr);
}
/// Print the body of a dense int/float/complex elements attribute. Non-splat
/// attributes may be printed as a hex blob of their raw data when `allowHex`
/// is set. Otherwise elements print individually as nested lists.
void AsmPrinter::Impl::printDenseIntOrFPElementsAttr(
DenseIntOrFPElementsAttr attr, bool allowHex) {
auto type = attr.getType();
auto elementType = type.getElementType();
auto numElements = type.getNumElements();
// Hex path. NOTE(review): shouldPrintElementsAttrWithHex is defined elsewhere;
// presumably it is a size threshold on the element count.
if (!attr.isSplat() && allowHex &&
shouldPrintElementsAttrWithHex(numElements)) {
ArrayRef<char> rawData = attr.getRawData();
// On big-endian hosts the raw data is byte-swapped into a temporary buffer
// (convertEndianOfArrayRefForBEmachine) before being hex-printed.
if (llvm::support::endian::system_endianness() ==
llvm::support::endianness::big) {
SmallVector<char, 64> outDataVec(rawData.size());
MutableArrayRef<char> convRawData(outDataVec);
DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
rawData, convRawData, type);
printHexString(convRawData);
} else {
printHexString(rawData);
}
return;
}
// Complex elements print as `(real,imag)` pairs.
if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) {
Type complexElementType = complexTy.getElementType();
if (complexElementType.isa<IntegerType>()) {
bool isSigned = !complexElementType.isUnsignedInteger();
auto valueIt = attr.value_begin<std::complex<APInt>>();
printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
auto complexValue = *(valueIt + index);
os << "(";
printDenseIntElement(complexValue.real(), os, isSigned);
os << ",";
printDenseIntElement(complexValue.imag(), os, isSigned);
os << ")";
});
} else {
auto valueIt = attr.value_begin<std::complex<APFloat>>();
printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
auto complexValue = *(valueIt + index);
os << "(";
printFloatValue(complexValue.real(), os);
os << ",";
printFloatValue(complexValue.imag(), os);
os << ")";
});
}
} else if (elementType.isIntOrIndex()) {
// Integers and index values print signed unless explicitly unsigned.
bool isSigned = !elementType.isUnsignedInteger();
auto valueIt = attr.value_begin<APInt>();
printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
printDenseIntElement(*(valueIt + index), os, isSigned);
});
} else {
assert(elementType.isa<FloatType>() && "unexpected element type");
auto valueIt = attr.value_begin<APFloat>();
printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
printFloatValue(*(valueIt + index), os);
});
}
}
/// Print the body of a dense string elements attribute. Each element is a
/// quoted, escaped string.
void AsmPrinter::Impl::printDenseStringElementsAttr(
    DenseStringElementsAttr attr) {
  ArrayRef<StringRef> strings = attr.getRawStringData();
  printDenseElementsAttrImpl(
      attr.isSplat(), attr.getType(), os,
      [&](unsigned index) { printEscapedString(strings[index]); });
}
/// Print `type`. An alias is used when one exists. Builtin types use their
/// builtin spelling, and other dialects' types go through printDialectType.
void AsmPrinter::Impl::printType(Type type) {
if (!type) {
os << "<<NULL TYPE>>";
return;
}
if (state && succeeded(state->getAliasState().getAlias(type, os)))
return;
TypeSwitch<Type>(type)
.Case<OpaqueType>([&](OpaqueType opaqueTy) {
printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(),
opaqueTy.getTypeData());
})
.Case<IndexType>([&](Type) { os << "index"; })
.Case<BFloat16Type>([&](Type) { os << "bf16"; })
.Case<Float16Type>([&](Type) { os << "f16"; })
.Case<Float32Type>([&](Type) { os << "f32"; })
.Case<Float64Type>([&](Type) { os << "f64"; })
.Case<Float80Type>([&](Type) { os << "f80"; })
.Case<Float128Type>([&](Type) { os << "f128"; })
.Case<IntegerType>([&](IntegerType integerTy) {
// `iN` for signless, `siN` / `uiN` for signed / unsigned.
if (integerTy.isSigned())
os << 's';
else if (integerTy.isUnsigned())
os << 'u';
os << 'i' << integerTy.getWidth();
})
.Case<FunctionType>([&](FunctionType funcTy) {
os << '(';
interleaveComma(funcTy.getInputs(), [&](Type ty) { printType(ty); });
os << ") -> ";
// A single non-function result is printed without parentheses.
ArrayRef<Type> results = funcTy.getResults();
if (results.size() == 1 && !results[0].isa<FunctionType>()) {
printType(results[0]);
} else {
os << '(';
interleaveComma(results, [&](Type ty) { printType(ty); });
os << ')';
}
})
.Case<VectorType>([&](VectorType vectorTy) {
// Fixed dims print as `NxMx`. The trailing scalable dims are grouped in
// brackets, e.g. `vector<2x[4]xf32>`.
os << "vector<";
auto vShape = vectorTy.getShape();
unsigned lastDim = vShape.size();
unsigned lastFixedDim = lastDim - vectorTy.getNumScalableDims();
unsigned dimIdx = 0;
for (dimIdx = 0; dimIdx < lastFixedDim; dimIdx++)
os << vShape[dimIdx] << 'x';
if (vectorTy.isScalable()) {
os << '[';
unsigned secondToLastDim = lastDim - 1;
for (; dimIdx < secondToLastDim; dimIdx++)
os << vShape[dimIdx] << 'x';
os << vShape[dimIdx] << "]x";
}
printType(vectorTy.getElementType());
os << '>';
})
.Case<RankedTensorType>([&](RankedTensorType tensorTy) {
// Dynamic dimensions print as `?`.
os << "tensor<";
for (int64_t dim : tensorTy.getShape()) {
if (ShapedType::isDynamic(dim))
os << '?';
else
os << dim;
os << 'x';
}
printType(tensorTy.getElementType());
if (tensorTy.getEncoding()) {
os << ", ";
printAttribute(tensorTy.getEncoding());
}
os << '>';
})
.Case<UnrankedTensorType>([&](UnrankedTensorType tensorTy) {
os << "tensor<*x";
printType(tensorTy.getElementType());
os << '>';
})
.Case<MemRefType>([&](MemRefType memrefTy) {
os << "memref<";
for (int64_t dim : memrefTy.getShape()) {
if (ShapedType::isDynamic(dim))
os << '?';
else
os << dim;
os << 'x';
}
printType(memrefTy.getElementType());
// Identity layouts and a null memory space are omitted.
if (!memrefTy.getLayout().isIdentity()) {
os << ", ";
printAttribute(memrefTy.getLayout(), AttrTypeElision::May);
}
if (memrefTy.getMemorySpace()) {
os << ", ";
printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May);
}
os << '>';
})
.Case<UnrankedMemRefType>([&](UnrankedMemRefType memrefTy) {
os << "memref<*x";
printType(memrefTy.getElementType());
if (memrefTy.getMemorySpace()) {
os << ", ";
printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May);
}
os << '>';
})
.Case<ComplexType>([&](ComplexType complexTy) {
os << "complex<";
printType(complexTy.getElementType());
os << '>';
})
.Case<TupleType>([&](TupleType tupleTy) {
os << "tuple<";
interleaveComma(tupleTy.getTypes(),
[&](Type type) { printType(type); });
os << '>';
})
.Case<NoneType>([&](Type) { os << "none"; })
// Anything else belongs to a non-builtin dialect.
.Default([&](Type type) { return printDialectType(type); });
}
/// Print ` [attributes] {name = value, ...}` for `attrs`, skipping any named
/// in `elidedAttrs`. Nothing is printed when no attributes remain.
void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
                                             ArrayRef<StringRef> elidedAttrs,
                                             bool withKeyword) {
  if (attrs.empty())
    return;

  // Emits the dictionary for whatever range of attributes it is given.
  auto emitDictionary = [&](auto range) {
    if (withKeyword)
      os << " attributes";
    os << " {";
    interleaveComma(range,
                    [&](NamedAttribute attr) { printNamedAttribute(attr); });
    os << '}';
  };

  // Fast path: nothing needs filtering.
  if (elidedAttrs.empty()) {
    emitDictionary(attrs);
    return;
  }

  llvm::SmallDenseSet<StringRef> skipped(elidedAttrs.begin(),
                                         elidedAttrs.end());
  auto kept = llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
    return !skipped.contains(attr.getName().strref());
  });
  // Avoid printing `{}` when every attribute was elided.
  if (!kept.empty())
    emitDictionary(kept);
}
/// Print `name = value`. Unit attributes print as just `name`.
void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
  ::printKeywordOrString(attr.getName().strref(), os);
  Attribute value = attr.getValue();
  if (!value.isa<UnitAttr>()) {
    os << " = ";
    printAttribute(value);
  }
}
/// Print a non-builtin attribute as `#dialect.body` or `#dialect<body>`. The
/// body comes from the owning dialect's own printer.
void AsmPrinter::Impl::printDialectAttribute(Attribute attr) {
  Dialect &dialect = attr.getDialect();

  // Render the dialect syntax into a buffer first so it can be wrapped. The
  // scope ends the string stream before `body` is used.
  std::string body;
  {
    llvm::raw_string_ostream bodyStream(body);
    Impl nestedPrinter(bodyStream, printerFlags, state);
    DialectAsmPrinter dialectPrinter(nestedPrinter);
    dialect.printAttribute(attr, dialectPrinter);

    // Keep any resources the dialect referenced while printing.
    for (auto &entry : nestedPrinter.dialectResources) {
      for (const auto &handle : entry.second)
        dialectResources[entry.first].insert(handle);
    }
  }
  printDialectSymbol(os, "#", dialect.getNamespace(), body);
}
/// Print a non-builtin type as `!dialect.body` or `!dialect<body>`. The body
/// comes from the owning dialect's own printer.
void AsmPrinter::Impl::printDialectType(Type type) {
  Dialect &dialect = type.getDialect();

  // Render the dialect syntax into a buffer first so it can be wrapped. The
  // scope ends the string stream before `body` is used.
  std::string body;
  {
    llvm::raw_string_ostream bodyStream(body);
    Impl nestedPrinter(bodyStream, printerFlags, state);
    DialectAsmPrinter dialectPrinter(nestedPrinter);
    dialect.printType(type, dialectPrinter);

    // Keep any resources the dialect referenced while printing.
    for (auto &entry : nestedPrinter.dialectResources) {
      for (const auto &handle : entry.second)
        dialectResources[entry.first].insert(handle);
    }
  }
  printDialectSymbol(os, "!", dialect.getNamespace(), body);
}
/// Print `str` in double quotes, escaped via llvm::printEscapedString.
void AsmPrinter::Impl::printEscapedString(StringRef str) {
  constexpr char quote = '"';
  os << quote;
  llvm::printEscapedString(str, os);
  os << quote;
}
/// Print the bytes of `str` as a quoted hex literal: "0x<hex digits>".
void AsmPrinter::Impl::printHexString(StringRef str) {
  std::string hexDigits = llvm::toHex(str);
  os << "\"0x" << hexDigits << '"';
}
/// Convenience overload: print raw bytes as a quoted hex literal.
void AsmPrinter::Impl::printHexString(ArrayRef<char> data) {
  StringRef bytes(data.data(), data.size());
  printHexString(bytes);
}
// Out-of-line default destructor, defined here where AsmPrinter::Impl is a
// complete type.
AsmPrinter::~AsmPrinter() = default;
/// Return the underlying stream of the printer implementation.
raw_ostream &AsmPrinter::getStream() const {
  assert(impl && "expected AsmPrinter::getStream to be overriden");
  Impl &printer = *impl;
  return printer.getStream();
}
void AsmPrinter::printFloat(const APFloat &value) {
assert(impl && "expected AsmPrinter::printFloat to be overriden");
printFloatValue(value, impl->getStream());
}
/// Forward type printing to the implementation.
void AsmPrinter::printType(Type type) {
  assert(impl && "expected AsmPrinter::printType to be overriden");
  Impl &printer = *impl;
  printer.printType(type);
}
/// Forward attribute printing (with the default type elision) to the
/// implementation.
void AsmPrinter::printAttribute(Attribute attr) {
  assert(impl && "expected AsmPrinter::printAttribute to be overriden");
  Impl &printer = *impl;
  printer.printAttribute(attr);
}
/// Try to print an alias for `attr` through the implementation.
LogicalResult AsmPrinter::printAlias(Attribute attr) {
  assert(impl && "expected AsmPrinter::printAlias to be overriden");
  Impl &printer = *impl;
  return printer.printAlias(attr);
}
/// Try to print an alias for `type` through the implementation.
LogicalResult AsmPrinter::printAlias(Type type) {
  assert(impl && "expected AsmPrinter::printAlias to be overriden");
  Impl &printer = *impl;
  return printer.printAlias(type);
}
/// Print `attr` with mandatory type elision, so no ` : type` suffix appears.
void AsmPrinter::printAttributeWithoutType(Attribute attr) {
  assert(impl &&
         "expected AsmPrinter::printAttributeWithoutType to be overriden");
  constexpr Impl::AttrTypeElision elision = Impl::AttrTypeElision::Must;
  impl->printAttribute(attr, elision);
}
/// Print `keyword` bare if possible, otherwise as a quoted string.
void AsmPrinter::printKeywordOrString(StringRef keyword) {
  assert(impl && "expected AsmPrinter::printKeywordOrString to be overriden");
  raw_ostream &stream = impl->getStream();
  ::printKeywordOrString(keyword, stream);
}
/// Print `@symbolRef`, quoting the name when necessary.
void AsmPrinter::printSymbolName(StringRef symbolRef) {
  assert(impl && "expected AsmPrinter::printSymbolName to be overriden");
  raw_ostream &stream = impl->getStream();
  ::printSymbolReference(symbolRef, stream);
}
void AsmPrinter::printResourceHandle(const AsmDialectResourceHandle &resource) {
assert(impl && "expected AsmPrinter::printResourceHandle to be overriden");
impl->printResourceHandle(resource);
}
/// Print an affine expression at top level, where the weak binding strength
/// means the outermost operator gets no parentheses.
void AsmPrinter::Impl::printAffineExpr(
AffineExpr expr, function_ref<void(unsigned, bool)> printValueName) {
printAffineExprInternal(expr, BindingStrength::Weak, printValueName);
}
/// Recursive affine-expression printer. `enclosingTightness` == Strong wraps
/// binary expressions in parentheses. Leaves print as `sN`/`dN`, or through
/// `printValueName(pos, isSymbol)` when it is provided. Additions of negated
/// terms are rendered as subtractions.
void AsmPrinter::Impl::printAffineExprInternal(
AffineExpr expr, BindingStrength enclosingTightness,
function_ref<void(unsigned, bool)> printValueName) {
const char *binopSpelling = nullptr;
switch (expr.getKind()) {
case AffineExprKind::SymbolId: {
unsigned pos = expr.cast<AffineSymbolExpr>().getPosition();
if (printValueName)
printValueName(pos, true);
else
os << 's' << pos;
return;
}
case AffineExprKind::DimId: {
unsigned pos = expr.cast<AffineDimExpr>().getPosition();
if (printValueName)
printValueName(pos, false);
else
os << 'd' << pos;
return;
}
case AffineExprKind::Constant:
os << expr.cast<AffineConstantExpr>().getValue();
return;
case AffineExprKind::Add:
binopSpelling = " + ";
break;
case AffineExprKind::Mul:
binopSpelling = " * ";
break;
case AffineExprKind::FloorDiv:
binopSpelling = " floordiv ";
break;
case AffineExprKind::CeilDiv:
binopSpelling = " ceildiv ";
break;
case AffineExprKind::Mod:
binopSpelling = " mod ";
break;
}
auto binOp = expr.cast<AffineBinaryOpExpr>();
AffineExpr lhsExpr = binOp.getLHS();
AffineExpr rhsExpr = binOp.getRHS();
// Multiplicative operators (mul/floordiv/ceildiv/mod): both operands bind
// strongly.
if (binOp.getKind() != AffineExprKind::Add) {
if (enclosingTightness == BindingStrength::Strong)
os << '(';
// `x * -1` prints as `-x`.
auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>();
if (rhsConst && binOp.getKind() == AffineExprKind::Mul &&
rhsConst.getValue() == -1) {
os << "-";
printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);
if (enclosingTightness == BindingStrength::Strong)
os << ')';
return;
}
printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);
os << binopSpelling;
printAffineExprInternal(rhsExpr, BindingStrength::Strong, printValueName);
if (enclosingTightness == BindingStrength::Strong)
os << ')';
return;
}
// Addition from here on.
if (enclosingTightness == BindingStrength::Strong)
os << '(';
// Right operand of the form `y * c` with a negative constant c.
if (auto rhs = rhsExpr.dyn_cast<AffineBinaryOpExpr>()) {
if (rhs.getKind() == AffineExprKind::Mul) {
AffineExpr rrhsExpr = rhs.getRHS();
if (auto rrhs = rrhsExpr.dyn_cast<AffineConstantExpr>()) {
// `x + y * -1` prints as `x - y`. A sum `y` is parenthesized.
if (rrhs.getValue() == -1) {
printAffineExprInternal(lhsExpr, BindingStrength::Weak,
printValueName);
os << " - ";
if (rhs.getLHS().getKind() == AffineExprKind::Add) {
printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
printValueName);
} else {
printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak,
printValueName);
}
if (enclosingTightness == BindingStrength::Strong)
os << ')';
return;
}
// `x + y * -c` (c > 1) prints as `x - y * c`.
if (rrhs.getValue() < -1) {
printAffineExprInternal(lhsExpr, BindingStrength::Weak,
printValueName);
os << " - ";
printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
printValueName);
os << " * " << -rrhs.getValue();
if (enclosingTightness == BindingStrength::Strong)
os << ')';
return;
}
}
}
}
// `x + -c` prints as `x - c`.
if (auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>()) {
if (rhsConst.getValue() < 0) {
printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);
os << " - " << -rhsConst.getValue();
if (enclosingTightness == BindingStrength::Strong)
os << ')';
return;
}
}
// Plain `x + y`.
printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);
os << " + ";
printAffineExprInternal(rhsExpr, BindingStrength::Weak, printValueName);
if (enclosingTightness == BindingStrength::Strong)
os << ')';
}
/// Print an integer-set constraint: `expr == 0` when `isEq`, else `expr >= 0`.
void AsmPrinter::Impl::printAffineConstraint(AffineExpr expr, bool isEq) {
  printAffineExprInternal(expr, BindingStrength::Weak);
  const char *relation = isEq ? " == 0" : " >= 0";
  os << relation;
}
/// Print an affine map as `(d0, ..., dN)[s0, ..., sM] -> (results)`. The
/// symbol list is omitted when the map has no symbols.
void AsmPrinter::Impl::printAffineMap(AffineMap map) {
  // Dimension identifiers.
  os << '(';
  interleaveComma(llvm::seq<unsigned>(0, map.getNumDims()),
                  [&](unsigned i) { os << 'd' << i; });
  os << ')';

  // Symbol identifiers, only when present.
  if (map.getNumSymbols() != 0) {
    os << '[';
    interleaveComma(llvm::seq<unsigned>(0, map.getNumSymbols()),
                    [&](unsigned i) { os << 's' << i; });
    os << ']';
  }

  // Result expressions.
  os << " -> (";
  interleaveComma(map.getResults(),
                  [&](AffineExpr expr) { printAffineExpr(expr); });
  os << ')';
}
/// Print an integer set as `(d0, ..., dN)[s0, ..., sM] : (constraints)`. The
/// symbol list is omitted when the set has no symbols. Each constraint is
/// `expr == 0` (equality) or `expr >= 0` (inequality).
void AsmPrinter::Impl::printIntegerSet(IntegerSet set) {
  // Dimension identifiers.
  os << '(';
  interleaveComma(llvm::seq<unsigned>(0, set.getNumDims()),
                  [&](unsigned i) { os << 'd' << i; });
  os << ')';

  // Symbol identifiers, only when present.
  if (set.getNumSymbols() != 0) {
    os << '[';
    interleaveComma(llvm::seq<unsigned>(0, set.getNumSymbols()),
                    [&](unsigned i) { os << 's' << i; });
    os << ']';
  }

  // Constraints.
  os << " : (";
  interleaveComma(llvm::seq<unsigned>(0, set.getNumConstraints()),
                  [&](unsigned i) {
                    printAffineConstraint(set.getConstraint(i), set.isEq(i));
                  });
  os << ')';
}
namespace {
/// Prints operations, blocks and regions. It is an AsmPrinter::Impl (for
/// attributes, types and locations) and is also, privately, the OpAsmPrinter
/// handed to custom operation printers.
class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {
public:
  using Impl = AsmPrinter::Impl;
  using Impl::printType;

  explicit OperationPrinter(raw_ostream &os, AsmStateImpl &state)
      : Impl(os, state.getPrinterFlags(), &state),
        OpAsmPrinter(static_cast<Impl &>(*this)) {}

  /// Print a top-level operation together with alias definitions and the
  /// trailing file metadata dictionary.
  void printTopLevelOperation(Operation *op);

  /// Print an operation with indentation and trailing location.
  void print(Operation *op);
  /// Print the operation itself (results, then custom or generic form).
  void printOperation(Operation *op);
  /// Print an operation in the generic form.
  void printGenericOp(Operation *op, bool printOpName) override;

  /// Print the name of the given block.
  void printBlockName(Block *block);
  /// Print a block, optionally with its arguments and terminator.
  void print(Block *block, bool printBlockArgs = true,
             bool printBlockTerminator = true);

  /// Print the SSA identifier of a value (or of an operation) to the printer's
  /// stream, or to `streamOverride` when one is given.
  void printValueID(Value value, bool printResultNo = true,
                    raw_ostream *streamOverride = nullptr) const;
  void printOperationID(Operation *op,
                        raw_ostream *streamOverride = nullptr) const;

  //===--------------------------------------------------------------------===//
  // OpAsmPrinter methods
  //===--------------------------------------------------------------------===//

  /// Start a new line at the current indentation.
  void printNewline() override {
    os << newLine;
    os.indent(currentIndent);
  }
  void printRegionArgument(BlockArgument arg,
                           ArrayRef<NamedAttribute> argAttrs = {},
                           bool omitType = false) override;
  void printOperand(Value value) override { printValueID(value); }
  void printOperand(Value value, raw_ostream &os) override {
    printValueID(value, /*printResultNo=*/true, &os);
  }
  void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
                             ArrayRef<StringRef> elidedAttrs = {}) override {
    Impl::printOptionalAttrDict(attrs, elidedAttrs);
  }
  void printOptionalAttrDictWithKeyword(
      ArrayRef<NamedAttribute> attrs,
      ArrayRef<StringRef> elidedAttrs = {}) override {
    Impl::printOptionalAttrDict(attrs, elidedAttrs,
                                /*withKeyword=*/true);
  }
  void printSuccessor(Block *successor) override;
  void printSuccessorAndUseList(Block *successor,
                                ValueRange succOperands) override;
  void printRegion(Region &region, bool printEntryBlockArgs,
                   bool printBlockTerminators, bool printEmptyBlock) override;
  void shadowRegionArgs(Region &region, ValueRange namesToUse) override {
    state->getSSANameState().shadowRegionArgs(region, namesToUse);
  }
  void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
                              ValueRange operands) override;
  void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands,
                               ValueRange symOperands) override;

  // Helpers for the value-users comments (printer flag
  // shouldPrintValueUsers).
  void printUsersComment(Operation *op);
  void printUsersComment(BlockArgument arg);
  void printValueUsers(Value value);
  void printUserIDs(Operation *user, bool prefixComma = false);

private:
  /// Resource builder that forwards each `key: value` entry to `printFn`.
  /// Each value writer emits its value to the stream `printFn` passes in.
  class ResourceBuilder : public AsmResourceBuilder {
  public:
    using ValueFn = function_ref<void(raw_ostream &)>;
    using PrintFn = function_ref<void(StringRef, ValueFn)>;

    ResourceBuilder(OperationPrinter &p, PrintFn printFn)
        : p(p), printFn(printFn) {}
    ~ResourceBuilder() override = default;

    void buildBool(StringRef key, bool data) final {
      // Write to the stream handed to the value callback, like buildBlob.
      printFn(key, [&](raw_ostream &os) { os << (data ? "true" : "false"); });
    }

    void buildString(StringRef key, StringRef data) final {
      printFn(key, [&](raw_ostream &os) { p.printEscapedString(data); });
    }

    void buildBlob(StringRef key, ArrayRef<char> data,
                   uint32_t dataAlignment) final {
      printFn(key, [&](raw_ostream &os) {
        // Hex-encode a little-endian 32-bit alignment prefix followed by the
        // blob bytes, as one quoted "0x..." string.
        llvm::support::ulittle32_t dataAlignmentLE(dataAlignment);
        os << "\"0x"
           << llvm::toHex(StringRef(reinterpret_cast<char *>(&dataAlignmentLE),
                                    sizeof(dataAlignment)))
           << llvm::toHex(StringRef(data.data(), data.size())) << "\"";
      });
    }

  private:
    OperationPrinter &p;
    PrintFn printFn;
  };

  /// Print the `{-# ... #-}` metadata dictionary, if there is anything in it.
  void printFileMetadataDictionary(Operation *op);

  /// Print the dialect and external resource sections of the metadata
  /// dictionary. `checkAddMetadataDict` is invoked before each entry.
  void printResourceFileMetadata(function_ref<void()> checkAddMetadataDict,
                                 Operation *op);

  // Dialect namespaces that may be omitted from op names. The back entry is
  // the innermost active default.
  SmallVector<StringRef> defaultDialectStack{"builtin"};

  // Indentation step and current indentation, in spaces.
  const static unsigned indentWidth = 2;
  unsigned currentIndent = 0;
};
}
/// Print a top-level operation in file order: non-deferred alias definitions,
/// the operation, deferred alias definitions, then the file metadata
/// dictionary.
void OperationPrinter::printTopLevelOperation(Operation *op) {
state->getAliasState().printNonDeferredAliases(os, newLine);
print(op);
os << newLine;
// Deferred aliases go after the operation; printLocation (allowAlias) can
// emit alias references for locations.
state->getAliasState().printDeferredAliases(os, newLine);
printFileMetadataDictionary(op);
}
/// Print the trailing `{-# ... #-}` metadata dictionary. It is opened lazily
/// on the first entry, so files without metadata get no dictionary at all.
void OperationPrinter::printFileMetadataDictionary(Operation *op) {
  bool dictOpened = false;
  auto openDictOnce = [&] {
    if (dictOpened)
      return;
    dictOpened = true;
    os << newLine << "{-#" << newLine;
  };

  printResourceFileMetadata(openDictOnce, op);

  if (dictOpened)
    os << newLine << "#-}" << newLine;
}
/// Print the `dialect_resources` and `external_resources` sections of the file
/// metadata. Sections and per-provider groups are opened lazily on their first
/// entry, so empty providers print nothing.
void OperationPrinter::printResourceFileMetadata(
function_ref<void()> checkAddMetadataDict, Operation *op) {
// Whether the current `<dictName>_resources: {` section has been opened.
bool hadResource = false;
// Runs one resource provider, printing every entry it builds under
// `<dictName>_resources: { <name>: { key: value, ... } }`.
auto processProvider = [&](StringRef dictName, StringRef name, auto &provider,
auto &&...providerArgs) {
bool hadEntry = false;
auto printFn = [&](StringRef key, ResourceBuilder::ValueFn valueFn) {
checkAddMetadataDict();
if (!std::exchange(hadResource, true))
os << " " << dictName << "_resources: {" << newLine;
if (!std::exchange(hadEntry, true))
os << " " << name << ": {" << newLine;
else
os << "," << newLine;
os << " " << key << ": ";
valueFn(os);
};
ResourceBuilder entryBuilder(*this, printFn);
provider.buildResources(op, providerArgs..., entryBuilder);
if (hadEntry)
os << newLine << " }";
};
// Dialect resources: each dialect interface receives the handles collected in
// `dialectResources` (or an empty set if it had none).
for (const OpAsmDialectInterface &interface : state->getDialectInterfaces()) {
StringRef name = interface.getDialect()->getNamespace();
auto it = dialectResources.find(interface.getDialect());
if (it != dialectResources.end())
processProvider("dialect", name, interface, it->second);
else
processProvider("dialect", name, interface,
SetVector<AsmDialectResourceHandle>());
}
if (hadResource)
os << newLine << " }";
// External resources, from the printers registered on the AsmState.
hadResource = false;
for (const auto &printer : state->getResourcePrinters())
processProvider("external", printer.getName(), printer);
if (hadResource)
os << newLine << " }";
}
/// Print a region argument as `%name[: type] [{attrs}] [loc]`. The location is
/// never printed as an alias.
void OperationPrinter::printRegionArgument(BlockArgument arg,
                                           ArrayRef<NamedAttribute> argAttrs,
                                           bool omitType) {
  printOperand(arg);

  if (!omitType) {
    Type argType = arg.getType();
    os << ": ";
    printType(argType);
  }

  printOptionalAttrDict(argAttrs);
  printTrailingLocation(arg.getLoc(), /*allowAlias=*/false);
}
/// Print one operation line: indentation, the operation, its trailing
/// location, and (if enabled) a comment describing its users.
void OperationPrinter::print(Operation *op) {
// Record where this op starts in the output (current line and indent).
state->registerOperationLocation(op, newLine.curLine, currentIndent);
os.indent(currentIndent);
printOperation(op);
printTrailingLocation(op->getLoc());
if (printerFlags.shouldPrintValueUsers())
printUsersComment(op);
}
/// Print `results = ` (if any) followed by the operation. The custom form is
/// used when available and allowed by the flags; otherwise the generic form.
void OperationPrinter::printOperation(Operation *op) {
if (size_t numResults = op->getNumResults()) {
// Print a result group as `%name`, or `%name:count` when it spans several
// results.
auto printResultGroup = [&](size_t resultNo, size_t resultCount) {
printValueID(op->getResult(resultNo), false);
if (resultCount > 1)
os << ':' << resultCount;
};
// resultGroups holds the start index of each named group. When it is empty,
// all results form one group.
ArrayRef<int> resultGroups = state->getSSANameState().getOpResultGroups(op);
if (!resultGroups.empty()) {
interleaveComma(llvm::seq<int>(0, resultGroups.size() - 1), [&](int i) {
printResultGroup(resultGroups[i],
resultGroups[i + 1] - resultGroups[i]);
});
os << ", ";
printResultGroup(resultGroups.back(), numResults - resultGroups.back());
} else {
printResultGroup(0, numResults);
}
os << " = ";
}
if (!printerFlags.shouldPrintGenericOpForm()) {
// Registered operations print through their registered assembly printer.
if (auto opInfo = op->getRegisteredInfo()) {
opInfo->printAssembly(op, *this, defaultDialectStack.back());
return;
}
// Unregistered operations may still have a printer provided by their
// dialect. The default-dialect prefix is dropped from a simple
// `dialect.op` name before printing.
if (Dialect *dialect = op->getDialect()) {
if (auto opPrinter = dialect->getOperationPrinter(op)) {
StringRef name = op->getName().getStringRef();
if (name.count('.') == 1)
name.consume_front((defaultDialectStack.back() + ".").str());
os << name;
opPrinter(op, *this);
return;
}
}
}
printGenericOp(op, true);
}
/// Append a ` // ...` comment describing the op's users:
///  - no results but has operands: the op's own ID (`// id: ...`);
///  - results but no uses: `// unused`;
///  - otherwise: the users of each result (parenthesized per result when the
///    op has several results).
void OperationPrinter::printUsersComment(Operation *op) {
unsigned numResults = op->getNumResults();
if (!numResults && op->getNumOperands()) {
os << " // id: ";
printOperationID(op);
} else if (numResults && op->use_empty()) {
os << " // unused";
} else if (numResults && !op->use_empty()) {
// Count distinct user ops and their total number of results, to choose
// between "user" and "users".
unsigned usedInNResults = 0;
unsigned usedInNOperations = 0;
SmallPtrSet<Operation *, 1> userSet;
for (Operation *user : op->getUsers()) {
if (userSet.insert(user).second) {
++usedInNOperations;
usedInNResults += user->getNumResults();
}
}
bool exactlyOneUniqueUse =
usedInNResults <= 1 && usedInNOperations <= 1 && numResults == 1;
os << " // " << (exactlyOneUniqueUse ? "user" : "users") << ": ";
bool shouldPrintBrackets = numResults > 1;
auto printOpResult = [&](OpResult opResult) {
if (shouldPrintBrackets)
os << "(";
printValueUsers(opResult);
if (shouldPrintBrackets)
os << ")";
};
interleaveComma(op->getResults(), printOpResult);
}
}
/// Emit a full comment line `// <arg> is unused` or
/// `// <arg> is used by <users>` for a block argument, ending with a newline.
void OperationPrinter::printUsersComment(BlockArgument arg) {
  os << "// ";
  printValueID(arg);
  if (!arg.use_empty()) {
    os << " is used by ";
    printValueUsers(arg);
  } else {
    os << " is unused";
  }
  os << newLine;
}
/// Print the IDs of the operations that use `value`, comma separated, or
/// "unused" when there are none. A user that consumes `value` through several
/// operands is printed only once.
void OperationPrinter::printValueUsers(Value value) {
  if (value.use_empty())
    os << "unused";

  SmallPtrSet<Operation *, 1> printedUsers;
  size_t position = 0;
  for (Operation *user : value.getUsers()) {
    // Every user after the very first use position gets a ", " separator.
    if (printedUsers.insert(user).second)
      printUserIDs(user, position != 0);
    ++position;
  }
}
/// Identify `user` in a users list: by its result names when it has results,
/// otherwise by its operation ID. Optionally preceded by ", ".
void OperationPrinter::printUserIDs(Operation *user, bool prefixComma) {
  if (prefixComma)
    os << ", ";

  if (user->getNumResults() != 0) {
    interleaveComma(user->getResults(),
                    [this](Value result) { printValueID(result); });
    return;
  }
  printOperationID(user);
}
/// Print `op` in the generic form:
///   "name"(operands)[successors] (regions) {attrs} : (types) -> types
/// The quoted name is emitted only when `printOpName` is set; the successor
/// and region lists are emitted only when non-empty.
void OperationPrinter::printGenericOp(Operation *op, bool printOpName) {
  if (printOpName)
    printEscapedString(op->getName().getStringRef());

  // The operand list is always printed, even when empty.
  os << '(';
  interleaveComma(op->getOperands(),
                  [&](Value operand) { printValueID(operand); });
  os << ')';

  if (op->getNumSuccessors()) {
    os << '[';
    interleaveComma(op->getSuccessors(),
                    [&](Block *succ) { printBlockName(succ); });
    os << ']';
  }

  if (op->getNumRegions()) {
    os << " (";
    interleaveComma(op->getRegions(), [&](Region &region) {
      printRegion(region, /*printEntryBlockArgs=*/true,
                  /*printBlockTerminators=*/true, /*printEmptyBlock=*/true);
    });
    os << ')';
  }

  printOptionalAttrDict(op->getAttrs());
  os << " : ";
  printFunctionalType(op);
}
/// Print the label assigned to `block` by the SSA name state.
void OperationPrinter::printBlockName(Block *block) {
  BlockInfo info = state->getSSANameState().getBlockInfo(block);
  os << info.name;
}
/// Print `block`. When `printBlockArgs` is set, first print a header line:
/// the block label, its typed arguments (with trailing locations), a ':' and
/// a comment describing its predecessors. Then print each operation on its
/// own line, one indentation level deeper. When `printBlockTerminator` is
/// false, a trailing terminator operation is not printed.
void OperationPrinter::print(Block *block, bool printBlockArgs,
bool printBlockTerminator) {
if (printBlockArgs) {
os.indent(currentIndent);
printBlockName(block);
if (!block->args_empty()) {
os << '(';
interleaveComma(block->getArguments(), [&](BlockArgument arg) {
printValueID(arg);
os << ": ";
printType(arg.getType());
printTrailingLocation(arg.getLoc(), false);
});
os << ')';
}
os << ':';
// Annotate the header with the block's predecessors.
if (!block->getParent()) {
os << " // block is not in a region!";
} else if (block->hasNoPredecessors()) {
// Only non-entry blocks are flagged as having no predecessors.
if (!block->isEntryBlock())
os << " // no predecessors";
} else if (auto *pred = block->getSinglePredecessor()) {
os << " // pred: ";
printBlockName(pred);
} else {
// Several predecessor entries: sort them by block ordering so the
// comment does not depend on the order getPredecessors() yields them.
SmallVector<BlockInfo, 4> predIDs;
for (auto *pred : block->getPredecessors())
predIDs.push_back(state->getSSANameState().getBlockInfo(pred));
llvm::sort(predIDs, [](BlockInfo lhs, BlockInfo rhs) {
return lhs.ordering < rhs.ordering;
});
os << " // " << predIDs.size() << " preds: ";
interleaveComma(predIDs, [&](BlockInfo pred) { os << pred.name; });
}
os << newLine;
}
// Block contents are nested one level deeper than the header.
currentIndent += indentWidth;
// When printing value users, emit one comment line per block argument
// before the operations.
if (printerFlags.shouldPrintValueUsers()) {
for (BlockArgument arg : block->getArguments()) {
os.indent(currentIndent);
printUsersComment(arg);
}
}
// Drop the final operation from the range only if it is a terminator and
// terminators were not requested.
bool hasTerminator =
!block->empty() && block->back().hasTrait<OpTrait::IsTerminator>();
auto range = llvm::make_range(
block->begin(),
std::prev(block->end(),
(!hasTerminator || printBlockTerminator) ? 0 : 1));
for (auto &op : range) {
print(&op);
os << newLine;
}
currentIndent -= indentWidth;
}
/// Print the SSA name of `value` to `streamOverride` if given, otherwise to
/// this printer's own stream.
void OperationPrinter::printValueID(Value value, bool printResultNo,
                                    raw_ostream *streamOverride) const {
  raw_ostream &target = streamOverride ? *streamOverride : os;
  state->getSSANameState().printValueID(value, printResultNo, target);
}
/// Print the ID assigned to `op` to `streamOverride` if given, otherwise to
/// this printer's own stream.
void OperationPrinter::printOperationID(Operation *op,
                                        raw_ostream *streamOverride) const {
  raw_ostream &target = streamOverride ? *streamOverride : os;
  state->getSSANameState().printOperationID(op, target);
}
/// Print a successor reference, which is simply the target block's label.
void OperationPrinter::printSuccessor(Block *successor) {
  printBlockName(successor);
}
/// Print a successor label followed, when there are forwarded operands, by
/// `(%a, %b : type_a, type_b)`.
void OperationPrinter::printSuccessorAndUseList(Block *successor,
                                                ValueRange succOperands) {
  printBlockName(successor);
  if (!succOperands.empty()) {
    os << '(';
    interleaveComma(succOperands,
                    [this](Value operand) { printValueID(operand); });
    os << " : ";
    interleaveComma(succOperands,
                    [this](Value operand) { printType(operand.getType()); });
    os << ')';
  }
}
void OperationPrinter::printRegion(Region ®ion, bool printEntryBlockArgs,
bool printBlockTerminators,
bool printEmptyBlock) {
os << "{" << newLine;
if (!region.empty()) {
auto restoreDefaultDialect =
llvm::make_scope_exit([&]() { defaultDialectStack.pop_back(); });
if (auto iface = dyn_cast<OpAsmOpInterface>(region.getParentOp()))
defaultDialectStack.push_back(iface.getDefaultDialect());
else
defaultDialectStack.push_back("");
auto *entryBlock = ®ion.front();
bool shouldAlwaysPrintBlockHeader =
(printEmptyBlock && entryBlock->empty()) ||
(printEntryBlockArgs && entryBlock->getNumArguments() != 0);
print(entryBlock, shouldAlwaysPrintBlockHeader, printBlockTerminators);
for (auto &b : llvm::drop_begin(region.getBlocks(), 1))
print(&b);
}
os.indent(currentIndent) << "}";
}
/// Print the results of `mapAttr`'s affine map, comma separated, replacing
/// every dimension/symbol identifier with the SSA name of the corresponding
/// entry of `operands`. Dimension operands come first, then symbol operands;
/// symbol references are wrapped as `symbol(%name)`.
void OperationPrinter::printAffineMapOfSSAIds(AffineMapAttr mapAttr,
                                              ValueRange operands) {
  AffineMap map = mapAttr.getValue();
  const unsigned dimCount = map.getNumDims();

  auto printOperandName = [&](unsigned pos, bool isSymbol) {
    // Symbol operands are stored after all dimension operands.
    const unsigned operandIdx = isSymbol ? dimCount + pos : pos;
    assert(operandIdx < operands.size());
    Value operand = operands[operandIdx];
    if (!isSymbol) {
      printValueID(operand);
      return;
    }
    os << "symbol(";
    printValueID(operand);
    os << ')';
  };

  interleaveComma(map.getResults(), [&](AffineExpr result) {
    printAffineExpr(result, printOperandName);
  });
}
/// Print `expr`, substituting dimension `dN` with the SSA name of
/// `dimOperands[N]` and symbol `sN` with `symbol(<name of symOperands[N]>)`.
void OperationPrinter::printAffineExprOfSSAIds(AffineExpr expr,
                                               ValueRange dimOperands,
                                               ValueRange symOperands) {
  auto printOperandName = [&](unsigned pos, bool isSymbol) {
    if (isSymbol) {
      os << "symbol(";
      printValueID(symOperands[pos]);
      os << ')';
    } else {
      printValueID(dimOperands[pos]);
    }
  };
  printAffineExpr(expr, printOperandName);
}
/// Print this attribute to `os` using a fresh, stateless printer.
void Attribute::print(raw_ostream &os) const {
  AsmPrinter::Impl printer(os);
  printer.printAttribute(*this);
}
void Attribute::dump() const {
print(llvm::errs());
llvm::errs() << "\n";
}
/// Print this type to `os` using a fresh, stateless printer.
void Type::print(raw_ostream &os) const {
  AsmPrinter::Impl printer(os);
  printer.printType(*this);
}
/// Debugging helper: print this type to stderr followed by a newline.
/// The trailing newline matches every other `dump()` helper in this file
/// (Attribute, AffineMap, IntegerSet, AffineExpr, Value, Operation); without
/// it, consecutive dumps run together on one line.
void Type::dump() const {
  print(llvm::errs());
  llvm::errs() << "\n";
}
void AffineMap::dump() const {
print(llvm::errs());
llvm::errs() << "\n";
}
void IntegerSet::dump() const {
print(llvm::errs());
llvm::errs() << "\n";
}
/// Print this affine expression to `os`, or a placeholder if it is null.
void AffineExpr::print(raw_ostream &os) const {
  if (expr) {
    AsmPrinter::Impl(os).printAffineExpr(*this);
    return;
  }
  os << "<<NULL AFFINE EXPR>>";
}
void AffineExpr::dump() const {
print(llvm::errs());
llvm::errs() << "\n";
}
/// Print this affine map to `os`, or a placeholder if it is null.
void AffineMap::print(raw_ostream &os) const {
  if (map) {
    AsmPrinter::Impl(os).printAffineMap(*this);
    return;
  }
  os << "<<NULL AFFINE MAP>>";
}
/// Print this integer set to `os`, or a placeholder if it is null.
/// A null set is guarded here, as in AffineMap::print and AffineExpr::print,
/// because the printer queries the set's dimensions and constraints and a
/// default-constructed set has no storage behind it.
void IntegerSet::print(raw_ostream &os) const {
  if (!*this) {
    os << "<<NULL INTEGER SET>>";
    return;
  }
  AsmPrinter::Impl(os).printIntegerSet(*this);
}
/// Print this value to `os` using default printing flags.
void Value::print(raw_ostream &os) {
  OpPrintingFlags defaultFlags;
  print(os, defaultFlags);
}
/// Print this value. An operation result is printed as its whole defining
/// operation (honoring `flags`). A block argument is summarized by its type
/// and argument index. A null value prints a placeholder.
void Value::print(raw_ostream &os, const OpPrintingFlags &flags) {
  if (!impl) {
    os << "<<NULL VALUE>>";
    return;
  }

  Operation *definingOp = getDefiningOp();
  if (!definingOp) {
    BlockArgument arg = this->cast<BlockArgument>();
    os << "<block argument> of type '" << arg.getType()
       << "' at index: " << arg.getArgNumber();
    return;
  }
  definingOp->print(os, flags);
}
/// Print this value using an existing AsmState. An operation result is
/// printed as its whole defining operation. A block argument is summarized by
/// its type and argument index. A null value prints a placeholder.
void Value::print(raw_ostream &os, AsmState &state) {
  if (!impl) {
    os << "<<NULL VALUE>>";
    return;
  }

  Operation *definingOp = getDefiningOp();
  if (!definingOp) {
    BlockArgument arg = this->cast<BlockArgument>();
    os << "<block argument> of type '" << arg.getType()
       << "' at index: " << arg.getArgNumber();
    return;
  }
  definingOp->print(os, state);
}
void Value::dump() {
print(llvm::errs());
llvm::errs() << "\n";
}
/// Print only the SSA name of this value (including any result number), as
/// it would appear when used as an operand, using names from `state`.
void Value::printAsOperand(raw_ostream &os, AsmState &state) {
  state.getImpl().getSSANameState().printValueID(
      *this, /*printResultNo=*/true, os);
}
/// Print this operation with the given flags, building a fresh AsmState.
/// A detached operation printed without local scope is the root of numbering
/// and also has its aliases initialized. Otherwise numbering starts from an
/// ancestor: the outermost one, or, with local scope, the nearest ancestor
/// (possibly this op) that is isolated from above.
void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) {
  bool localScope = printerFlags.shouldUseLocalScope();

  if (!localScope && !getParent()) {
    AsmState state(this, printerFlags);
    state.getImpl().initializeAliases(this);
    print(os, state);
    return;
  }

  // Walk up the parent chain to find the op to number from.
  Operation *scopeOp = this;
  while (!(localScope && scopeOp->hasTrait<OpTrait::IsIsolatedFromAbove>())) {
    Operation *parent = scopeOp->getParentOp();
    if (!parent)
      break;
    scopeOp = parent;
  }

  AsmState state(scopeOp, printerFlags);
  print(os, state);
}
/// Print this operation using an existing AsmState. A detached operation
/// printed without local scope goes through the top-level entry point;
/// anything else is printed as a nested operation.
void Operation::print(raw_ostream &os, AsmState &state) {
  OperationPrinter printer(os, state.getImpl());
  bool printAsTopLevel =
      !getParent() && !state.getPrinterFlags().shouldUseLocalScope();
  if (!printAsTopLevel) {
    printer.print(this);
    return;
  }
  printer.printTopLevelOperation(this);
}
void Operation::dump() {
print(llvm::errs(), OpPrintingFlags().useLocalScope());
llvm::errs() << "\n";
}
/// Print this block, numbering names relative to the outermost enclosing
/// operation. A block that is not attached to any operation prints a
/// placeholder instead.
void Block::print(raw_ostream &os) {
  Operation *rootOp = getParentOp();
  if (!rootOp) {
    os << "<<UNLINKED BLOCK>>\n";
    return;
  }

  for (Operation *next = rootOp->getParentOp(); next;
       next = next->getParentOp())
    rootOp = next;

  AsmState state(rootOp);
  print(os, state);
}
/// Print this block (header and operations) using an existing AsmState.
void Block::print(raw_ostream &os, AsmState &state) {
  OperationPrinter printer(os, state.getImpl());
  printer.print(this);
}
void Block::dump() { print(llvm::errs()); }
/// Print this block's label as it appears when referenced (e.g. as a branch
/// target), numbered relative to its parent operation. A block that is not
/// attached to any operation prints a placeholder instead.
/// NOTE: `printType` is currently ignored; only the label is printed.
void Block::printAsOperand(raw_ostream &os, bool printType) {
  if (Operation *parentOp = getParentOp()) {
    AsmState state(parentOp);
    printAsOperand(os, state);
    return;
  }
  os << "<<UNLINKED BLOCK>>\n";
}
/// Print this block's label using names from an existing AsmState.
void Block::printAsOperand(raw_ostream &os, AsmState &state) {
  OperationPrinter(os, state.getImpl()).printBlockName(this);
}