#include "mlir/Bytecode/BytecodeReader.h"
#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Bytecode/BytecodeImplementation.h"
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Bytecode/Encoding.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Verifier.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Endian.h"
#include "llvm/Support/MemoryBufferRef.h"
#include "llvm/Support/SourceMgr.h"
#include <cstddef>
#include <list>
#include <memory>
#include <numeric>
#include <optional>
#define DEBUG_TYPE "mlir-bytecode-reader"
using namespace mlir;
/// Returns a printable label for a bytecode section ID, of the form
/// "<Name> (<numeric id>)". IDs that have no dedicated label are rendered as
/// "Unknown (<numeric id>)".
static std::string toString(bytecode::Section::ID sectionID) {
  // Pick the fixed label for every known section; unknown IDs leave it null.
  const char *label = nullptr;
  switch (sectionID) {
  case bytecode::Section::kString:
    label = "String (0)";
    break;
  case bytecode::Section::kDialect:
    label = "Dialect (1)";
    break;
  case bytecode::Section::kAttrType:
    label = "AttrType (2)";
    break;
  case bytecode::Section::kAttrTypeOffset:
    label = "AttrTypeOffset (3)";
    break;
  case bytecode::Section::kIR:
    label = "IR (4)";
    break;
  case bytecode::Section::kResource:
    label = "Resource (5)";
    break;
  case bytecode::Section::kResourceOffset:
    label = "ResourceOffset (6)";
    break;
  case bytecode::Section::kDialectVersions:
    label = "DialectVersions (7)";
    break;
  case bytecode::Section::kProperties:
    label = "Properties (8)";
    break;
  default:
    break;
  }
  if (label)
    return label;

  // Fall back to printing the raw numeric value.
  return ("Unknown (" + Twine(static_cast<unsigned>(sectionID)) + ")").str();
}
/// Returns true if the section identified by `sectionID` is allowed to be
/// missing from a bytecode file with the given `version`.
static bool isSectionOptional(bytecode::Section::ID sectionID, int version) {
  switch (sectionID) {
  // The properties section is optional only for versions that predate
  // `kNativePropertiesEncoding`.
  case bytecode::Section::kProperties:
    return version < bytecode::kNativePropertiesEncoding;

  // Sections that may always be omitted.
  case bytecode::Section::kResource:
  case bytecode::Section::kResourceOffset:
  case bytecode::Section::kDialectVersions:
    return true;

  // Sections that must always be present.
  case bytecode::Section::kString:
  case bytecode::Section::kDialect:
  case bytecode::Section::kAttrType:
  case bytecode::Section::kAttrTypeOffset:
  case bytecode::Section::kIR:
    return false;

  default:
    llvm_unreachable("unknown section ID");
  }
}
namespace {
/// A forward-only cursor over a contiguous byte buffer that decodes the
/// primitive encodings used by the bytecode (bytes, varints, blobs, sections,
/// null-terminated strings). Every parse method emits a diagnostic at
/// `fileLoc` and returns failure when the buffer does not hold the requested
/// data; the reader does not own the buffer.
class EncodingReader {
public:
  explicit EncodingReader(ArrayRef<uint8_t> contents, Location fileLoc)
      : buffer(contents), dataIt(buffer.begin()), fileLoc(fileLoc) {}
  explicit EncodingReader(StringRef contents, Location fileLoc)
      : EncodingReader({reinterpret_cast<const uint8_t *>(contents.data()),
                        contents.size()},
                       fileLoc) {}

  /// Returns true if the cursor has reached the end of the buffer.
  bool empty() const { return dataIt == buffer.end(); }

  /// Returns the number of bytes remaining after the cursor.
  size_t size() const { return buffer.end() - dataIt; }

  /// Advances the cursor until its address is a multiple of `alignment`,
  /// which must be a power of two. Every skipped byte must be the alignment
  /// padding byte (`bytecode::kAlignmentByte`).
  LogicalResult alignTo(unsigned alignment) {
    if (!llvm::isPowerOf2_32(alignment))
      return emitError("expected alignment to be a power-of-two");

    auto isUnaligned = [&](const uint8_t *ptr) {
      return ((uintptr_t)ptr & (alignment - 1)) != 0;
    };

    // Consume padding bytes up to the next alignment boundary.
    while (isUnaligned(dataIt)) {
      uint8_t padding;
      if (failed(parseByte(padding)))
        return failure();
      if (padding != bytecode::kAlignmentByte) {
        return emitError("expected alignment byte (0xCB), but got: '0x" +
                         llvm::utohexstr(padding) + "'");
      }
    }

    // Defensive re-check; the loop above only exits once aligned.
    if (LLVM_UNLIKELY(isUnaligned(dataIt))) {
      return emitError("expected data iterator aligned to ", alignment,
                       ", but got pointer: '0x" +
                           llvm::utohexstr((uintptr_t)dataIt) + "'");
    }

    return success();
  }

  /// Emits an error at the file location, appending `args` to the message.
  template <typename... Args>
  InFlightDiagnostic emitError(Args &&...args) const {
    return ::emitError(fileLoc).append(std::forward<Args>(args)...);
  }
  InFlightDiagnostic emitError() const { return ::emitError(fileLoc); }

  /// Reads one byte and converts it to `T` with a static_cast.
  template <typename T>
  LogicalResult parseByte(T &value) {
    if (empty())
      return emitError("attempting to parse a byte at the end of the bytecode");
    value = static_cast<T>(*dataIt++);
    return success();
  }

  /// Reads `length` bytes; `result` becomes a view into the underlying buffer
  /// (no copy is made).
  LogicalResult parseBytes(size_t length, ArrayRef<uint8_t> &result) {
    if (length > size()) {
      return emitError("attempting to parse ", length, " bytes when only ",
                       size(), " remain");
    }
    result = {dataIt, length};
    dataIt += length;
    return success();
  }

  /// Reads `length` bytes and copies them into `result`, which must have room
  /// for at least `length` bytes.
  LogicalResult parseBytes(size_t length, uint8_t *result) {
    if (length > size()) {
      return emitError("attempting to parse ", length, " bytes when only ",
                       size(), " remain");
    }
    memcpy(result, dataIt, length);
    dataIt += length;
    return success();
  }

  /// Reads a blob laid out as: varint alignment, varint size, padding up to
  /// `alignment`, then `size` bytes of data. `data` views the buffer.
  LogicalResult parseBlobAndAlignment(ArrayRef<uint8_t> &data,
                                      uint64_t &alignment) {
    uint64_t dataSize;
    if (failed(parseVarInt(alignment)) || failed(parseVarInt(dataSize)) ||
        failed(alignTo(alignment)))
      return failure();
    return parseBytes(dataSize, data);
  }

  /// Reads a variable-width unsigned integer. The low bits of the first byte
  /// select the encoding width.
  LogicalResult parseVarInt(uint64_t &result) {
    // The first byte carries the width marker.
    if (failed(parseByte(result)))
      return failure();

    // Low bit set: the value fits in the remaining 7 bits of this byte.
    if (LLVM_LIKELY(result & 1)) {
      result >>= 1;
      return success();
    }

    // Marker byte of all zeros: the value is the following 8 bytes, stored
    // little-endian.
    if (LLVM_UNLIKELY(result == 0)) {
      llvm::support::ulittle64_t resultLE;
      if (failed(parseBytes(sizeof(resultLE),
                            reinterpret_cast<uint8_t *>(&resultLE))))
        return failure();
      result = resultLE;
      return success();
    }
    return parseMultiByteVarInt(result);
  }

  /// Reads a varint and decodes it as a signed value stored with the sign in
  /// the low bit. The result is returned as its two's-complement bit pattern.
  LogicalResult parseSignedVarInt(uint64_t &result) {
    if (failed(parseVarInt(result)))
      return failure();
    // Unsigned form of: (x >> 1) ^ -(x & 1).
    result = (result >> 1) ^ (~(result & 1) + 1);
    return success();
  }

  /// Reads a varint whose low bit is a flag and whose remaining bits are the
  /// value.
  LogicalResult parseVarIntWithFlag(uint64_t &result, bool &flag) {
    if (failed(parseVarInt(result)))
      return failure();
    flag = result & 1;
    result >>= 1;
    return success();
  }

  /// Advances the cursor by `length` bytes without reading them.
  LogicalResult skipBytes(size_t length) {
    if (length > size()) {
      return emitError("attempting to skip ", length, " bytes when only ",
                       size(), " remain");
    }
    dataIt += length;
    return success();
  }

  /// Reads a null-terminated string. `result` views the buffer and excludes
  /// the terminator; the cursor ends up just past the terminator.
  LogicalResult parseNullTerminatedString(StringRef &result) {
    const char *startIt = (const char *)dataIt;
    const char *nulIt = (const char *)memchr(startIt, 0, size());
    if (!nulIt)
      return emitError(
          "malformed null-terminated string, no null character found");

    result = StringRef(startIt, nulIt - startIt);
    dataIt = (const uint8_t *)nulIt + 1;
    return success();
  }

  /// Reads a section header and its payload. On success `sectionID` holds the
  /// section's ID and `sectionData` views the payload bytes.
  LogicalResult parseSection(bytecode::Section::ID &sectionID,
                             ArrayRef<uint8_t> &sectionData) {
    uint8_t sectionIDAndHasAlignment;
    uint64_t length;
    if (failed(parseByte(sectionIDAndHasAlignment)) ||
        failed(parseVarInt(length)))
      return failure();

    // The low 7 bits are the ID; the high bit flags an explicit alignment.
    sectionID = static_cast<bytecode::Section::ID>(sectionIDAndHasAlignment &
                                                   0b01111111);
    bool hasAlignment = sectionIDAndHasAlignment & 0b10000000;

    // Validate the ID before touching the payload.
    if (sectionID >= bytecode::Section::kNumSections)
      return emitError("invalid section ID: ", unsigned(sectionID));

    // An aligned section stores its alignment followed by padding.
    if (hasAlignment) {
      uint64_t alignment;
      if (failed(parseVarInt(alignment)) || failed(alignTo(alignment)))
        return failure();
    }

    return parseBytes(static_cast<size_t>(length), sectionData);
  }

  /// Returns the location used for diagnostics.
  Location getLoc() const { return fileLoc; }

private:
  /// Slow path of `parseVarInt` for values stored in 2 to 8 bytes. `result`
  /// holds the already-read first byte on entry. Kept out of line so the
  /// single-byte fast path stays small.
  LLVM_ATTRIBUTE_NOINLINE LogicalResult parseMultiByteVarInt(uint64_t &result) {
    // The number of trailing zero bits in the first byte is the number of
    // additional bytes that belong to the value.
    uint32_t numBytes = llvm::countr_zero<uint32_t>(result);
    assert(numBytes > 0 && numBytes <= 7 &&
           "unexpected number of trailing zeros in varint encoding");

    // Read the remaining bytes in after the first byte, little-endian.
    llvm::support::ulittle64_t resultLE(result);
    if (failed(
            parseBytes(numBytes, reinterpret_cast<uint8_t *>(&resultLE) + 1)))
      return failure();

    // Drop the marker bits (the trailing zeros plus the terminating one bit).
    result = resultLE >> (numBytes + 1);
    return success();
  }

  /// The complete buffer being read.
  ArrayRef<uint8_t> buffer;

  /// The current read position within `buffer`.
  const uint8_t *dataIt;

  /// Location attached to every diagnostic.
  Location fileLoc;
};
} // namespace
/// Looks up `index` in `entries` and stores the outcome in `entry`. If the
/// element type of the range converts to `T`, the element itself is assigned;
/// otherwise the address of the element is assigned. An out-of-range `index`
/// produces an error through `reader` that mentions `entryStr`.
template <typename RangeT, typename T>
static LogicalResult resolveEntry(EncodingReader &reader, RangeT &entries,
                                  uint64_t index, T &entry,
                                  StringRef entryStr) {
  const bool inBounds = index < entries.size();
  if (!inBounds)
    return reader.emitError("invalid ", entryStr, " index: ", index);

  constexpr bool assignByValue =
      std::is_convertible_v<llvm::detail::ValueOfRange<RangeT>, T>;
  if constexpr (assignByValue) {
    entry = entries[index];
  } else {
    // `T` is expected to be a pointer to the element type here.
    entry = &entries[index];
  }
  return success();
}
/// Reads a varint index from `reader` and resolves it against `entries` via
/// `resolveEntry`, storing the outcome in `entry`.
template <typename RangeT, typename T>
static LogicalResult parseEntry(EncodingReader &reader, RangeT &entries,
                                T &entry, StringRef entryStr) {
  uint64_t index;
  if (succeeded(reader.parseVarInt(index)))
    return resolveEntry(reader, entries, index, entry, entryStr);
  return failure();
}
namespace {
/// Holds the table of strings decoded from the string section and resolves
/// string indices read from other sections against it.
class StringSectionReader {
public:
  /// Populates the string table from `sectionData`.
  LogicalResult initialize(Location fileLoc, ArrayRef<uint8_t> sectionData);

  /// Reads a varint string index from `reader` and resolves it to a string.
  LogicalResult parseString(EncodingReader &reader, StringRef &result) const {
    uint64_t stringIdx;
    if (failed(reader.parseVarInt(stringIdx)))
      return failure();
    return parseStringAtIndex(reader, stringIdx, result);
  }

  /// Reads a varint that packs a string index together with a one-bit flag,
  /// then resolves the index to a string.
  LogicalResult parseStringWithFlag(EncodingReader &reader, StringRef &result,
                                    bool &flag) const {
    uint64_t stringIdx;
    if (succeeded(reader.parseVarIntWithFlag(stringIdx, flag)))
      return parseStringAtIndex(reader, stringIdx, result);
    return failure();
  }

  /// Resolves an already-decoded `index` to a string; errors are reported
  /// through `reader`.
  LogicalResult parseStringAtIndex(EncodingReader &reader, uint64_t index,
                                   StringRef &result) const {
    return resolveEntry(reader, strings, index, result, "string");
  }

private:
  /// The decoded string table; entries view the section data.
  SmallVector<StringRef> strings;
};
} // namespace
/// Builds the string table from the raw string section.
///
/// The section starts with a varint string count, followed by one varint size
/// per string, followed by the string data. Sizes are stored in reverse order
/// relative to the table and each size counts the string's trailing null
/// character, which is dropped from the stored `StringRef`.
///
/// Returns failure (with a diagnostic) if the section is truncated, a size is
/// zero or exceeds the remaining data, or bytes are left over between the
/// size list and the string data.
LogicalResult StringSectionReader::initialize(Location fileLoc,
                                              ArrayRef<uint8_t> sectionData) {
  EncodingReader stringReader(sectionData, fileLoc);

  // Read the number of strings in the table.
  uint64_t numStrings;
  if (failed(stringReader.parseVarInt(numStrings)))
    return failure();

  // Each string needs at least one byte for its size, so a count larger than
  // the remaining data is malformed. Rejecting it here avoids sizing the table
  // from an untrusted, arbitrarily large value.
  if (numStrings > stringReader.size()) {
    return stringReader.emitError(
        "string count exceeds the available data size");
  }
  strings.resize(numStrings);

  // String data is laid out back-to-front from the end of the section, so walk
  // the table in reverse while consuming sizes in forward order.
  size_t stringDataEndOffset = sectionData.size();
  for (StringRef &string : llvm::reverse(strings)) {
    uint64_t stringSize;
    if (failed(stringReader.parseVarInt(stringSize)))
      return failure();
    if (stringDataEndOffset < stringSize) {
      return stringReader.emitError(
          "string size exceeds the available data size");
    }

    // The size includes the null terminator, so zero is never valid; without
    // this check `stringSize - 1` below would wrap around.
    if (stringSize == 0) {
      return stringReader.emitError(
          "string size of zero is missing the null terminator");
    }

    // Extract the string, dropping the trailing null character.
    size_t stringOffset = stringDataEndOffset - stringSize;
    string = StringRef(
        reinterpret_cast<const char *>(sectionData.data() + stringOffset),
        stringSize - 1);
    stringDataEndOffset = stringOffset;
  }

  // The reader must now sit exactly at the start of the first string's data.
  if ((sectionData.size() - stringReader.size()) != stringDataEndOffset) {
    return stringReader.emitError("unexpected trailing data between the "
                                  "offsets for strings and their data");
  }
  return success();
}
namespace {
class DialectReader;
/// Bookkeeping for a single dialect referenced by the bytecode being read.
struct BytecodeDialect {
/// Loads this dialect; defined out of line. Callers in this file invoke it
/// before `getLoadedDialect` and before using `interface`/`loadedVersion`.
LogicalResult load(const DialectReader &reader, MLIRContext *ctx);
/// Returns the stored dialect pointer. Asserts that `dialect` holds a value
/// (i.e. `load` ran first); the pointer itself may still be null, and callers
/// in this file check for that.
Dialect *getLoadedDialect() const {
assert(dialect &&
"expected `load` to be invoked before `getLoadedDialect`");
return *dialect;
}
/// Empty until `load` has been invoked (per the assertion above).
std::optional<Dialect *> dialect;
/// Bytecode interface of the dialect; null when unavailable. It is checked
/// for null before a custom attribute/type encoding is read.
const BytecodeDialectInterface *interface = nullptr;
/// The dialect name, used in diagnostics and passed to reader callbacks.
StringRef name;
/// Raw bytes of the dialect's version entry.
/// NOTE(review): presumably filled from the dialect versions section --
/// the code that populates it is not visible here.
ArrayRef<uint8_t> versionBuffer;
/// The decoded dialect version; null when none was loaded.
std::unique_ptr<DialectVersion> loadedVersion;
};
/// Bookkeeping for a single operation name referenced by the bytecode.
struct BytecodeOperationName {
BytecodeOperationName(BytecodeDialect *dialect, StringRef name,
std::optional<bool> wasRegistered)
: dialect(dialect), name(name), wasRegistered(wasRegistered) {}
/// The resolved operation name; left empty by the constructor.
std::optional<OperationName> opName;
/// The dialect entry this operation name belongs to.
BytecodeDialect *dialect;
/// The name string as supplied to the constructor.
StringRef name;
/// Registration status supplied to the constructor; may be absent.
/// NOTE(review): presumably whether the op was registered when the bytecode
/// was written -- confirm against the dialect section parser.
std::optional<bool> wasRegistered;
};
}
/// Reads one "dialect grouping": a dialect index followed by a varint entry
/// count. `entryCallback` is invoked once per entry with the grouping's
/// dialect and is responsible for consuming that entry from `reader`. Stops
/// at the first failure.
static LogicalResult parseDialectGrouping(
    EncodingReader &reader,
    MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
    function_ref<LogicalResult(BytecodeDialect *)> entryCallback) {
  // Resolve the dialect shared by all entries of this grouping.
  std::unique_ptr<BytecodeDialect> *groupDialect;
  if (failed(parseEntry(reader, dialects, groupDialect, "dialect")))
    return failure();

  // Read how many entries follow, then hand each one to the callback.
  uint64_t remaining;
  if (failed(reader.parseVarInt(remaining)))
    return failure();
  for (; remaining != 0; --remaining) {
    if (failed(entryCallback(groupDialect->get())))
      return failure();
  }
  return success();
}
namespace {
/// Reads the resource sections and resolves resource handle indices against
/// the dialect resources declared while doing so.
class ResourceSectionReader {
public:
  /// Parses the resource section (`sectionData`) together with its offset
  /// section (`offsetSectionData`).
  LogicalResult
  initialize(Location fileLoc, const ParserConfig &config,
             MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
             StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
             ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader,
             const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef);

  /// Reads a varint index from `reader` and resolves it to one of the
  /// recorded dialect resource handles.
  LogicalResult parseResourceHandle(EncodingReader &reader,
                                    AsmDialectResourceHandle &result) const {
    uint64_t handleIdx;
    if (failed(reader.parseVarInt(handleIdx)))
      return failure();
    return resolveEntry(reader, dialectResources, handleIdx, result,
                        "resource handle");
  }

private:
  /// Dialect resource handles, in the order they were declared.
  SmallVector<AsmDialectResourceHandle> dialectResources;

  /// Maps a resource key as written in the bytecode to the key of the handle
  /// that was declared for it.
  llvm::StringMap<std::string> dialectResourceHandleRenamingMap;
};
/// An `AsmParsedResourceEntry` backed by bytecode: the value of the entry is
/// decoded on demand from `reader` when one of the `parseAs*` methods is
/// called. The entry stores references to the reader, the string section
/// reader and the buffer owner, so all of them must outlive it.
class ParsedResourceEntry : public AsmParsedResourceEntry {
public:
  ParsedResourceEntry(StringRef key, AsmResourceEntryKind kind,
                      EncodingReader &reader, StringSectionReader &stringReader,
                      const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef)
      : key(key), kind(kind), reader(reader), stringReader(stringReader),
        bufferOwnerRef(bufferOwnerRef) {}
  ~ParsedResourceEntry() override = default;

  StringRef getKey() const final { return key; }

  InFlightDiagnostic emitError() const final { return reader.emitError(); }

  AsmResourceEntryKind getKind() const final { return kind; }

  /// Decodes the entry as a single-byte boolean.
  FailureOr<bool> parseAsBool() const final {
    if (kind == AsmResourceEntryKind::Bool) {
      bool parsed;
      if (failed(reader.parseByte(parsed)))
        return failure();
      return parsed;
    }
    return emitError() << "expected a bool resource entry, but found a "
                       << toString(kind) << " entry instead";
  }

  /// Decodes the entry as an index into the string section and returns a copy
  /// of the referenced string.
  FailureOr<std::string> parseAsString() const final {
    if (kind == AsmResourceEntryKind::String) {
      StringRef parsed;
      if (failed(stringReader.parseString(reader, parsed)))
        return failure();
      return parsed.str();
    }
    return emitError() << "expected a string resource entry, but found a "
                       << toString(kind) << " entry instead";
  }

  /// Decodes the entry as an aligned blob. When a buffer owner is available
  /// the blob references the bytecode buffer directly; otherwise storage is
  /// obtained from `allocator` and the data is copied into it.
  FailureOr<AsmResourceBlob>
  parseAsBlob(BlobAllocatorFn allocator) const final {
    if (kind != AsmResourceEntryKind::Blob)
      return emitError() << "expected a blob resource entry, but found a "
                         << toString(kind) << " entry instead";

    ArrayRef<uint8_t> data;
    uint64_t alignment;
    if (failed(reader.parseBlobAndAlignment(data, alignment)))
      return failure();

    if (!bufferOwnerRef) {
      // No owner to keep the buffer alive: copy into allocator-provided
      // storage.
      AsmResourceBlob blob = allocator(data.size(), alignment);
      assert(llvm::isAddrAligned(llvm::Align(alignment), blob.getData().data()) &&
             blob.isMutable() &&
             "blob allocator did not return a properly aligned address");
      memcpy(blob.getMutableData().data(), data.data(), data.size());
      return blob;
    }

    // The deleter does nothing except hold a copy of the owner reference,
    // which keeps the underlying buffer alive for as long as the blob exists.
    ArrayRef<char> charData(reinterpret_cast<const char *>(data.data()),
                            data.size());
    return UnmanagedAsmResourceBlob::allocateWithAlign(
        charData, alignment,
        [bufferOwnerRef = bufferOwnerRef](void *, size_t, size_t) {});
  }

private:
  /// Key of this resource entry.
  StringRef key;
  /// The kind of value this entry was declared to hold.
  AsmResourceEntryKind kind;
  /// Reader positioned over the bytes of this entry.
  EncodingReader &reader;
  /// String table used to resolve string-valued entries.
  StringSectionReader &stringReader;
  /// Optional owner of the bytecode buffer; may hold a null pointer.
  const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef;
};
}
/// Parses one group of resources.
///
/// The group's header is read from `offsetReader`: a varint resource count,
/// then for each resource a string key, a varint byte length and a one-byte
/// kind. The corresponding bytes of each resource are taken from
/// `resourceReader` and handed to `handler->parseResource`.
///
/// \param allowEmpty    When true, resources with no data are skipped instead
///                      of being passed to the handler.
/// \param handler       Receives each parsed entry; when null, entries are
///                      consumed but otherwise ignored.
/// \param remapKey      Optional; when provided, maps the key read from the
///                      bytecode to the key exposed to the handler.
/// \param processKeyFn  Optional; when provided, is invoked with every key
///                      (before remapping) and may fail the parse.
///
/// Returns failure if any read fails, a callback fails, or a handler leaves
/// unread bytes in an entry.
template <typename T>
static LogicalResult
parseResourceGroup(Location fileLoc, bool allowEmpty,
                   EncodingReader &offsetReader, EncodingReader &resourceReader,
                   StringSectionReader &stringReader, T *handler,
                   const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef,
                   function_ref<StringRef(StringRef)> remapKey = {},
                   function_ref<LogicalResult(StringRef)> processKeyFn = {}) {
  uint64_t numResources;
  if (failed(offsetReader.parseVarInt(numResources)))
    return failure();

  for (uint64_t i = 0; i < numResources; ++i) {
    StringRef key;
    AsmResourceEntryKind kind;
    uint64_t resourceOffset;
    ArrayRef<uint8_t> data;
    if (failed(stringReader.parseString(offsetReader, key)) ||
        failed(offsetReader.parseVarInt(resourceOffset)) ||
        failed(offsetReader.parseByte(kind)) ||
        failed(resourceReader.parseBytes(resourceOffset, data)))
      return failure();

    // Give the caller a chance to act on (or reject) the key.
    if ((processKeyFn && failed(processKeyFn(key))))
      return failure();

    // Empty entries are skipped when the caller allows them.
    if (allowEmpty && data.empty())
      continue;

    // Without a handler there is nothing to deliver the entry to.
    if (!handler)
      continue;

    // `remapKey` defaults to an empty function_ref, so only invoke it when the
    // caller actually supplied one.
    EncodingReader entryReader(data, fileLoc);
    if (remapKey)
      key = remapKey(key);
    ParsedResourceEntry entry(key, kind, entryReader, stringReader,
                              bufferOwnerRef);
    if (failed(handler->parseResource(entry)))
      return failure();

    // The handler must consume the whole entry.
    if (!entryReader.empty()) {
      return entryReader.emitError(
          "unexpected trailing bytes in resource entry '", key, "'");
    }
  }
  return success();
}
/// Parses the resource section using the layout described by the resource
/// offset section. External resource groups come first (their count is the
/// first varint of the offset section); every group after that belongs to a
/// dialect and continues until the offset section is exhausted.
///
/// Side effects visible here: each dialect resource key is declared on the
/// dialect's `OpAsmDialectInterface`, the resulting handle is appended to
/// `dialectResources`, and the key it was given is recorded in
/// `dialectResourceHandleRenamingMap`.
LogicalResult ResourceSectionReader::initialize(
Location fileLoc, const ParserConfig &config,
MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader,
const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
EncodingReader resourceReader(sectionData, fileLoc);
EncodingReader offsetReader(offsetSectionData, fileLoc);
// Number of external (non-dialect) resource groups.
uint64_t numExternalResourceGroups;
if (failed(offsetReader.parseVarInt(numExternalResourceGroups)))
return failure();
// Shared driver for both kinds of group; keys are translated through the
// renaming map (keys without a mapping are passed through unchanged).
auto parseGroup = [&](auto *handler, bool allowEmpty = false,
function_ref<LogicalResult(StringRef)> keyFn = {}) {
auto resolveKey = [&](StringRef key) -> StringRef {
auto it = dialectResourceHandleRenamingMap.find(key);
if (it == dialectResourceHandleRenamingMap.end())
return key;
return it->second;
};
return parseResourceGroup(fileLoc, allowEmpty, offsetReader, resourceReader,
stringReader, handler, bufferOwnerRef, resolveKey,
keyFn);
};
// External groups: each is prefixed by the key of the parser that owns it.
for (uint64_t i = 0; i < numExternalResourceGroups; ++i) {
StringRef key;
if (failed(stringReader.parseString(offsetReader, key)))
return failure();
// A missing parser is only a warning; the group is still parsed (with a null
// handler) so that both readers stay in sync.
AsmResourceParser *handler = config.getResourceParser(key);
if (!handler) {
emitWarning(fileLoc) << "ignoring unknown external resources for '" << key
<< "'";
}
if (failed(parseGroup(handler)))
return failure();
}
// Dialect groups: everything left in the offset section.
MLIRContext *ctx = fileLoc->getContext();
while (!offsetReader.empty()) {
std::unique_ptr<BytecodeDialect> *dialect;
if (failed(parseEntry(offsetReader, dialects, dialect, "dialect")) ||
failed((*dialect)->load(dialectReader, ctx)))
return failure();
// `load` succeeding does not guarantee a non-null dialect pointer.
Dialect *loadedDialect = (*dialect)->getLoadedDialect();
if (!loadedDialect) {
return resourceReader.emitError()
<< "dialect '" << (*dialect)->name << "' is unknown";
}
// The dialect must implement the asm interface to own resources.
const auto *handler = dyn_cast<OpAsmDialectInterface>(loadedDialect);
if (!handler) {
return resourceReader.emitError()
<< "unexpected resources for dialect '" << (*dialect)->name << "'";
}
// Declare each key with the dialect and remember the key its handle was
// actually given, which may differ from the key stored in the bytecode.
auto processResourceKeyFn = [&](StringRef key) -> LogicalResult {
FailureOr<AsmDialectResourceHandle> handle =
handler->declareResource(key);
if (failed(handle)) {
return resourceReader.emitError()
<< "unknown 'resource' key '" << key << "' for dialect '"
<< (*dialect)->name << "'";
}
dialectResourceHandleRenamingMap[key] = handler->getResourceKey(*handle);
dialectResources.push_back(*handle);
return success();
};
// Dialect resources may be declared without data, hence allowEmpty = true.
if (failed(parseGroup(handler, true, processResourceKeyFn)))
return failure();
}
return success();
}
namespace {
/// Owns the tables of attribute and type entries and resolves indices into
/// them. An entry is decoded the first time it is resolved and the decoded
/// value is cached in its table slot.
class AttrTypeReader {
  /// One slot of the attribute or type table.
  template <typename T>
  struct Entry {
    /// The decoded value; null until the entry has been resolved.
    T entry = {};
    /// The dialect that owns this entry.
    BytecodeDialect *dialect = nullptr;
    /// True if the entry uses a custom encoding rather than the textual
    /// assembly format.
    bool hasCustomEncoding = false;
    /// The raw bytes of the entry.
    ArrayRef<uint8_t> data;
  };
  using AttrEntry = Entry<Attribute>;
  using TypeEntry = Entry<Type>;

public:
  AttrTypeReader(const StringSectionReader &stringReader,
                 const ResourceSectionReader &resourceReader,
                 const llvm::StringMap<BytecodeDialect *> &dialectsMap,
                 uint64_t &bytecodeVersion, Location fileLoc,
                 const ParserConfig &config)
      : stringReader(stringReader), resourceReader(resourceReader),
        dialectsMap(dialectsMap), fileLoc(fileLoc),
        bytecodeVersion(bytecodeVersion), parserConfig(config) {}

  /// Sets up the attribute and type tables from the attr/type section and its
  /// offset section.
  LogicalResult
  initialize(MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
             ArrayRef<uint8_t> sectionData,
             ArrayRef<uint8_t> offsetSectionData);

  /// Returns the attribute at `index`, or null after emitting an error.
  Attribute resolveAttribute(size_t index) {
    return resolveEntry(attributes, index, "Attribute");
  }

  /// Returns the type at `index`, or null after emitting an error.
  Type resolveType(size_t index) { return resolveEntry(types, index, "Type"); }

  /// Reads a varint attribute index from `reader` and resolves it.
  LogicalResult parseAttribute(EncodingReader &reader, Attribute &result) {
    uint64_t attrIdx;
    if (failed(reader.parseVarInt(attrIdx)))
      return failure();
    result = resolveAttribute(attrIdx);
    return result ? success() : failure();
  }

  /// Reads a varint that packs an attribute index with a presence flag. When
  /// the flag is clear, `result` is left untouched and success is returned.
  LogicalResult parseOptionalAttribute(EncodingReader &reader,
                                       Attribute &result) {
    uint64_t attrIdx;
    bool isPresent;
    if (failed(reader.parseVarIntWithFlag(attrIdx, isPresent)))
      return failure();
    if (isPresent) {
      result = resolveAttribute(attrIdx);
      return result ? success() : failure();
    }
    return success();
  }

  /// Reads a varint type index from `reader` and resolves it.
  LogicalResult parseType(EncodingReader &reader, Type &result) {
    uint64_t typeIdx;
    if (failed(reader.parseVarInt(typeIdx)))
      return failure();
    result = resolveType(typeIdx);
    return result ? success() : failure();
  }

  /// Reads an attribute and checks that it is of the attribute class `T`,
  /// emitting an error through `reader` when it is not.
  template <typename T>
  LogicalResult parseAttribute(EncodingReader &reader, T &result) {
    Attribute genericAttr;
    if (failed(parseAttribute(reader, genericAttr)))
      return failure();
    result = dyn_cast<T>(genericAttr);
    if (result)
      return success();
    return reader.emitError("expected attribute of type: ",
                            llvm::getTypeName<T>(), ", but got: ", genericAttr);
  }

private:
  /// Resolves (decoding if necessary) the entry at `index` of `entries`.
  template <typename T>
  T resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index,
                 StringRef entryType);

  /// Decodes an entry stored in the textual assembly format.
  template <typename T>
  LogicalResult parseAsmEntry(T &result, EncodingReader &reader,
                              StringRef entryType);

  /// Decodes an entry stored with a custom encoding.
  template <typename T>
  LogicalResult parseCustomEntry(Entry<T> &entry, EncodingReader &reader,
                                 StringRef entryType);

  /// Readers for the sections that attribute/type entries may refer to.
  const StringSectionReader &stringReader;
  const ResourceSectionReader &resourceReader;

  /// Lookup of the bytecode dialects by name.
  const llvm::StringMap<BytecodeDialect *> &dialectsMap;

  /// The attribute and type tables.
  SmallVector<AttrEntry> attributes;
  SmallVector<TypeEntry> types;

  /// Location used for diagnostics.
  Location fileLoc;

  /// Reference to the bytecode version owned by the creator of this reader.
  uint64_t &bytecodeVersion;

  /// Parser configuration, the source of the reader callbacks.
  const ParserConfig &parserConfig;
};
/// The `DialectBytecodeReader` handed to dialects and callbacks. It is a thin
/// adaptor that forwards every request to the section readers and to the
/// `EncodingReader` it was constructed with; it holds references only.
class DialectReader : public DialectBytecodeReader {
public:
  DialectReader(AttrTypeReader &attrTypeReader,
                const StringSectionReader &stringReader,
                const ResourceSectionReader &resourceReader,
                const llvm::StringMap<BytecodeDialect *> &dialectsMap,
                EncodingReader &reader, uint64_t &bytecodeVersion)
      : attrTypeReader(attrTypeReader), stringReader(stringReader),
        resourceReader(resourceReader), dialectsMap(dialectsMap),
        reader(reader), bytecodeVersion(bytecodeVersion) {}

  InFlightDiagnostic emitError(const Twine &msg) const override {
    return reader.emitError(msg);
  }

  /// Returns the version loaded for the dialect named `dialectName`. Fails if
  /// the dialect is not part of the bytecode, fails to load, or has no loaded
  /// version.
  FailureOr<const DialectVersion *>
  getDialectVersion(StringRef dialectName) const override {
    auto it = dialectsMap.find(dialectName);
    if (it == dialectsMap.end())
      return failure();

    BytecodeDialect *bytecodeDialect = it->getValue();
    if (failed(bytecodeDialect->load(*this, getLoc().getContext())))
      return failure();
    if (bytecodeDialect->loadedVersion == nullptr)
      return failure();
    return bytecodeDialect->loadedVersion.get();
  }

  MLIRContext *getContext() const override { return getLoc().getContext(); }

  uint64_t getBytecodeVersion() const override { return bytecodeVersion; }

  /// Returns a reader that shares all state with this one but decodes from
  /// `encReader` instead.
  DialectReader withEncodingReader(EncodingReader &encReader) const {
    return DialectReader(attrTypeReader, stringReader, resourceReader,
                         dialectsMap, encReader, bytecodeVersion);
  }

  Location getLoc() const { return reader.getLoc(); }

  //===--------------------------------------------------------------------===//
  // IR
  //===--------------------------------------------------------------------===//

  LogicalResult readAttribute(Attribute &result) override {
    return attrTypeReader.parseAttribute(reader, result);
  }

  LogicalResult readOptionalAttribute(Attribute &result) override {
    return attrTypeReader.parseOptionalAttribute(reader, result);
  }

  LogicalResult readType(Type &result) override {
    return attrTypeReader.parseType(reader, result);
  }

  FailureOr<AsmDialectResourceHandle> readResourceHandle() override {
    AsmDialectResourceHandle handle;
    if (succeeded(resourceReader.parseResourceHandle(reader, handle)))
      return handle;
    return failure();
  }

  //===--------------------------------------------------------------------===//
  // Primitives
  //===--------------------------------------------------------------------===//

  LogicalResult readVarInt(uint64_t &result) override {
    return reader.parseVarInt(result);
  }

  LogicalResult readSignedVarInt(int64_t &result) override {
    uint64_t rawBits;
    if (failed(reader.parseSignedVarInt(rawBits)))
      return failure();
    result = static_cast<int64_t>(rawBits);
    return success();
  }

  /// Reads an integer of the given bit width. Widths up to 8 occupy a single
  /// byte, widths up to 64 a single signed varint, and wider values a varint
  /// word count followed by that many signed varints.
  FailureOr<APInt> readAPIntWithKnownWidth(unsigned bitWidth) override {
    if (bitWidth <= 8) {
      uint8_t byteValue;
      if (failed(reader.parseByte(byteValue)))
        return failure();
      return APInt(bitWidth, byteValue);
    }

    if (bitWidth <= 64) {
      uint64_t wordValue;
      if (failed(reader.parseSignedVarInt(wordValue)))
        return failure();
      return APInt(bitWidth, wordValue);
    }

    uint64_t numActiveWords;
    if (failed(reader.parseVarInt(numActiveWords)))
      return failure();
    SmallVector<uint64_t, 4> words(numActiveWords);
    for (uint64_t &word : words) {
      if (failed(reader.parseSignedVarInt(word)))
        return failure();
    }
    return APInt(bitWidth, words);
  }

  /// Reads a float as an integer of the bit width of `semantics` and
  /// reinterprets those bits with `semantics`.
  FailureOr<APFloat>
  readAPFloatWithKnownSemantics(const llvm::fltSemantics &semantics) override {
    const unsigned bitWidth = APFloat::getSizeInBits(semantics);
    FailureOr<APInt> bits = readAPIntWithKnownWidth(bitWidth);
    if (failed(bits))
      return failure();
    return APFloat(semantics, *bits);
  }

  LogicalResult readString(StringRef &result) override {
    return stringReader.parseString(reader, result);
  }

  /// Reads a varint size followed by that many bytes; `result` views the
  /// underlying buffer.
  LogicalResult readBlob(ArrayRef<char> &result) override {
    uint64_t numBytes;
    if (failed(reader.parseVarInt(numBytes)))
      return failure();
    ArrayRef<uint8_t> bytes;
    if (failed(reader.parseBytes(numBytes, bytes)))
      return failure();

    result = llvm::ArrayRef(reinterpret_cast<const char *>(bytes.data()),
                            bytes.size());
    return success();
  }

  LogicalResult readBool(bool &result) override {
    return reader.parseByte(result);
  }

private:
  AttrTypeReader &attrTypeReader;
  const StringSectionReader &stringReader;
  const ResourceSectionReader &resourceReader;
  const llvm::StringMap<BytecodeDialect *> &dialectsMap;
  /// The source that all primitive reads decode from.
  EncodingReader &reader;
  uint64_t &bytecodeVersion;
};
/// Reads the properties section and decodes the properties of individual
/// operations on request.
///
/// The section holds a varint entry count followed by the entries, each of
/// which is a varint size plus that many bytes. `initialize` records the byte
/// offset of every entry so that `read` can jump straight to the entry an
/// operation refers to by index.
class PropertiesSectionReader {
public:
  /// Indexes the properties section. An empty section is valid and leaves the
  /// offset table empty. Fails if an entry is truncated or bytes remain after
  /// the last entry.
  LogicalResult initialize(Location fileLoc, ArrayRef<uint8_t> sectionData) {
    if (sectionData.empty())
      return success();
    EncodingReader propReader(sectionData, fileLoc);
    uint64_t count;
    if (failed(propReader.parseVarInt(count)))
      return failure();

    // Everything after the count is entry data.
    if (failed(propReader.parseBytes(propReader.size(), propertiesBuffers)))
      return failure();

    // Walk the entries once, recording where each one starts.
    EncodingReader offsetsReader(propertiesBuffers, fileLoc);
    offsetTable.reserve(count);
    for (auto idx : llvm::seq<int64_t>(0, count)) {
      (void)idx;
      offsetTable.push_back(propertiesBuffers.size() - offsetsReader.size());
      ArrayRef<uint8_t> rawProperties;
      uint64_t dataSize;
      if (failed(offsetsReader.parseVarInt(dataSize)) ||
          failed(offsetsReader.parseBytes(dataSize, rawProperties)))
        return failure();
    }
    if (!offsetsReader.empty())
      return offsetsReader.emitError()
             << "Broken properties section: didn't exhaust the offsets table";
    return success();
  }

  /// Reads a properties index from `dialectReader` and decodes the referenced
  /// entry into `opState`.
  ///
  /// The entry is decoded by the op's `BytecodeOpInterface` when available.
  /// A registered op without that interface is an error; for any other op the
  /// entry is read as a single attribute into `opState.propertiesAttr`.
  LogicalResult read(Location fileLoc, DialectReader &dialectReader,
                     OperationName *opName, OperationState &opState) const {
    uint64_t propertiesIdx;
    if (failed(dialectReader.readVarInt(propertiesIdx)))
      return failure();
    if (propertiesIdx >= offsetTable.size())
      return dialectReader.emitError("Properties idx out-of-bound for ")
             << opName->getStringRef();

    // Validate the byte offset (not the index) against the buffer size before
    // slicing the buffer with it.
    size_t propertiesOffset = offsetTable[propertiesIdx];
    if (propertiesOffset >= propertiesBuffers.size())
      return dialectReader.emitError("Properties offset out-of-bound for ")
             << opName->getStringRef();

    // Extract the raw bytes of the entry (size prefix followed by data).
    ArrayRef<char> rawProperties;
    {
      EncodingReader reader(propertiesBuffers.drop_front(propertiesOffset),
                            fileLoc);
      if (failed(
              dialectReader.withEncodingReader(reader).readBlob(rawProperties)))
        return failure();
    }

    // Decode the entry with a reader restricted to its own bytes.
    EncodingReader reader(
        StringRef(rawProperties.begin(), rawProperties.size()), fileLoc);
    DialectReader propReader = dialectReader.withEncodingReader(reader);

    auto *iface = opName->getInterface<BytecodeOpInterface>();
    if (iface)
      return iface->readProperties(propReader, opState);
    if (opName->isRegistered())
      return propReader.emitError(
                 "has properties but missing BytecodeOpInterface for ")
             << opName->getStringRef();

    // Unregistered operations store their properties as an attribute.
    return propReader.readAttribute(opState.propertiesAttr);
  }

private:
  /// The bytes of all property entries (the section minus its count prefix).
  ArrayRef<uint8_t> propertiesBuffers;

  /// Byte offset of each entry within `propertiesBuffers`, by entry index.
  SmallVector<int64_t> offsetTable;
};
}
/// Sets up the attribute and type tables.
///
/// The offset section starts with the number of attributes and the number of
/// types. It is followed by dialect groupings: first those covering all
/// attributes, then those covering all types. Each grouped entry stores a
/// varint that packs the entry's byte size with a flag telling whether the
/// entry uses a custom encoding. Entry data is laid out back-to-back in
/// `sectionData`, in table order.
///
/// Entries are only sliced here; decoding happens lazily when an entry is
/// first resolved. Returns failure if the offset section is malformed, an
/// entry would extend past `sectionData`, a grouping declares more entries
/// than the tables hold, or data is left over in the offset section.
LogicalResult AttrTypeReader::initialize(
    MutableArrayRef<std::unique_ptr<BytecodeDialect>> dialects,
    ArrayRef<uint8_t> sectionData, ArrayRef<uint8_t> offsetSectionData) {
  EncodingReader offsetReader(offsetSectionData, fileLoc);

  // Read the table sizes.
  uint64_t numAttributes, numTypes;
  if (failed(offsetReader.parseVarInt(numAttributes)) ||
      failed(offsetReader.parseVarInt(numTypes)))
    return failure();
  attributes.resize(numAttributes);
  types.resize(numTypes);

  // Running offset into `sectionData`; never exceeds `sectionData.size()`.
  uint64_t currentOffset = 0;
  auto parseEntries = [&](auto &&range) {
    size_t currentIndex = 0, endIndex = range.size();

    // Fills in the next table slot for an entry owned by `dialect`.
    auto parseEntryFn = [&](BytecodeDialect *dialect) -> LogicalResult {
      // The grouping's entry count comes from the file; make sure it does not
      // run past the end of the table.
      if (currentIndex >= endIndex) {
        return offsetReader.emitError(
            "Attribute or Type dialect grouping has more entries than the "
            "section declares");
      }
      auto &entry = range[currentIndex++];

      uint64_t entrySize;
      if (failed(offsetReader.parseVarIntWithFlag(entrySize,
                                                  entry.hasCustomEncoding)))
        return failure();

      // Compare against the remaining space so that a huge `entrySize` cannot
      // wrap the addition and slip past the check.
      if (entrySize > sectionData.size() - currentOffset) {
        return offsetReader.emitError(
            "Attribute or Type entry offset points past the end of section");
      }

      entry.data = sectionData.slice(currentOffset, entrySize);
      entry.dialect = dialect;
      currentOffset += entrySize;
      return success();
    };

    // Keep reading groupings until every slot of the table is filled.
    while (currentIndex != endIndex)
      if (failed(parseDialectGrouping(offsetReader, dialects, parseEntryFn)))
        return failure();
    return success();
  };

  // Attributes come first, then types.
  if (failed(parseEntries(attributes)) || failed(parseEntries(types)))
    return failure();

  // The offset section must be fully consumed.
  if (!offsetReader.empty()) {
    return offsetReader.emitError(
        "unexpected trailing data in the Attribute/Type offset section");
  }
  return success();
}
/// Returns the decoded value of the entry at `index`, decoding it on first
/// use. A null value is returned after an error has been emitted when the
/// index is out of range, decoding fails, or the entry has bytes left over
/// after decoding.
template <typename T>
T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index,
                               StringRef entryType) {
  if (index >= entries.size()) {
    emitError(fileLoc) << "invalid " << entryType << " index: " << index;
    return {};
  }

  // Serve the cached value if this entry was decoded before.
  Entry<T> &slot = entries[index];
  if (slot.entry)
    return slot.entry;

  // Decode using whichever encoding the entry was written with.
  EncodingReader reader(slot.data, fileLoc);
  LogicalResult decoded = slot.hasCustomEncoding
                              ? parseCustomEntry(slot, reader, entryType)
                              : parseAsmEntry(slot.entry, reader, entryType);
  if (failed(decoded))
    return T();

  // The decoder must have consumed the entire entry.
  if (!reader.empty()) {
    reader.emitError("unexpected trailing bytes after " + entryType + " entry");
    return T();
  }
  return slot.entry;
}
/// Decodes an entry that is stored as a null-terminated string in the textual
/// assembly format. The whole string must be consumed by the parser.
template <typename T>
LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader,
                                            StringRef entryType) {
  StringRef asmText;
  if (failed(reader.parseNullTerminatedString(asmText)))
    return failure();

  // The text was read via parseNullTerminatedString, so a NUL byte directly
  // follows it in the buffer.
  // NOTE(review): the trailing `true` presumably tells the parser about that
  // terminator -- confirm against the AsmParser declarations.
  MLIRContext *ctx = fileLoc->getContext();
  size_t consumed = 0;
  if constexpr (std::is_same_v<T, Type>) {
    result = ::parseType(asmText, ctx, &consumed, true);
  } else {
    result = ::parseAttribute(asmText, ctx, Type(), &consumed, true);
  }
  if (!result)
    return failure();

  // Reject anything left over after the attribute/type itself.
  if (consumed == asmText.size())
    return success();
  return reader.emitError("trailing characters found after ", entryType,
                          " assembly format: ", asmText.drop_front(consumed));
}
/// Decodes an entry that uses a custom encoding.
///
/// After loading the owning dialect, the reader callbacks registered on the
/// parser configuration are tried in order; the first one that produces a
/// value wins. If none does, the entry is decoded through the dialect's
/// bytecode interface, which must exist.
template <typename T>
LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
                                               EncodingReader &reader,
                                               StringRef entryType) {
  DialectReader dialectReader(*this, stringReader, resourceReader, dialectsMap,
                              reader, bytecodeVersion);
  if (failed(entry.dialect->load(dialectReader, fileLoc.getContext())))
    return failure();

  // Runs every callback in `callbacks` until one fills in `entry.entry`.
  // Yields true if a callback produced the value, false if none did, and
  // failure if a callback reported an error.
  auto tryCallbacks = [&](auto &&callbacks) -> FailureOr<bool> {
    for (const auto &callback : callbacks) {
      if (failed(
              callback->read(dialectReader, entry.dialect->name, entry.entry)))
        return failure();
      if (entry.entry)
        return true;

      // The callback declined: rewind to the start of the entry so the next
      // reader sees the untouched data.
      reader = EncodingReader(entry.data, reader.getLoc());
    }
    return false;
  };

  FailureOr<bool> handledByCallback = failure();
  if constexpr (std::is_same_v<T, Type>) {
    handledByCallback =
        tryCallbacks(parserConfig.getBytecodeReaderConfig().getTypeCallbacks());
  } else {
    handledByCallback = tryCallbacks(
        parserConfig.getBytecodeReaderConfig().getAttributeCallbacks());
  }
  if (failed(handledByCallback))
    return failure();
  if (*handledByCallback)
    return success();

  // Fall back to the dialect's own bytecode interface.
  if (!entry.dialect->interface) {
    return reader.emitError("dialect '", entry.dialect->name,
                            "' does not implement the bytecode interface");
  }
  if constexpr (std::is_same_v<T, Type>) {
    entry.entry = entry.dialect->interface->readType(dialectReader);
  } else {
    entry.entry = entry.dialect->interface->readAttribute(dialectReader);
  }
  return success(!!entry.entry);
}
class mlir::BytecodeReader::Impl {
struct RegionReadState;
using LazyLoadableOpsInfo =
std::list<std::pair<Operation *, RegionReadState>>;
using LazyLoadableOpsMap =
DenseMap<Operation *, LazyLoadableOpsInfo::iterator>;
public:
/// Creates a reader for `buffer`. Nothing is parsed here; the constructor
/// only stores its arguments and wires `attrTypeReader` to the other section
/// readers, the dialect map and the version field of this object.
/// `forwardRefOpState` is set up as an operation state named
/// "builtin.unrealized_conversion_cast" at an unknown location with no
/// operands and a single `NoneType` result.
/// NOTE(review): `config`, `buffer` and `bufferOwnerRef` are stored in members
/// declared outside this view; if held by reference they must outlive this
/// object -- confirm against the member declarations.
Impl(Location fileLoc, const ParserConfig &config, bool lazyLoading,
llvm::MemoryBufferRef buffer,
const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef)
: config(config), fileLoc(fileLoc), lazyLoading(lazyLoading),
attrTypeReader(stringReader, resourceReader, dialectsMap, version,
fileLoc, config),
forwardRefOpState(UnknownLoc::get(config.getContext()),
"builtin.unrealized_conversion_cast", ValueRange(),
NoneType::get(config.getContext())),
buffer(buffer), bufferOwnerRef(bufferOwnerRef) {}
/// Reads the bytecode into `block`; defined out of line.
/// NOTE(review): `lazyOps` presumably selects which operations are deferred
/// when lazy loading is enabled -- confirm against the definition.
LogicalResult read(Block *block,
llvm::function_ref<bool(Operation *)> lazyOps);
/// Returns the number of operations currently registered as lazily loadable.
int64_t getNumOpsToMaterialize() const { return lazyLoadableOpsMap.size(); }
/// Returns true if `op` is registered as lazily loadable.
bool isMaterializable(Operation *op) { return lazyLoadableOpsMap.count(op); }
/// Materializes the regions of `op`, which must be registered as lazily
/// loadable (asserted). `lazyOpsCallback` is installed for the duration of the
/// call and cleared again on every exit path.
LogicalResult
materialize(Operation *op,
            llvm::function_ref<bool(Operation *)> lazyOpsCallback) {
  // Install the callback and make sure it is reset when leaving this scope.
  this->lazyOpsCallback = lazyOpsCallback;
  auto clearCallbackOnExit =
      llvm::make_scope_exit([&] { this->lazyOpsCallback = nullptr; });

  auto opIt = lazyLoadableOpsMap.find(op);
  assert(opIt != lazyLoadableOpsMap.end() &&
         "materialize called on non-materializable op");
  return materialize(opIt);
}
/// Materializes lazily loadable operations until none are left, stopping at
/// the first failure.
LogicalResult materializeAll() {
  for (;;) {
    if (lazyLoadableOpsMap.empty())
      return success();
    // Each successful call removes the materialized op from the map.
    if (failed(materialize(lazyLoadableOpsMap.begin())))
      return failure();
  }
}
/// Disposes of every remaining lazily loadable operation: those for which
/// `shouldMaterialize` returns true are materialized, all others are erased
/// from the IR. Stops at the first materialization failure.
LogicalResult finalize(function_ref<bool(Operation *)> shouldMaterialize) {
  while (!lazyLoadableOps.empty()) {
    Operation *pendingOp = lazyLoadableOps.front().first;

    if (!shouldMaterialize(pendingOp)) {
      // Not wanted: drop the op and forget about it.
      pendingOp->dropAllReferences();
      pendingOp->erase();
      lazyLoadableOps.pop_front();
      lazyLoadableOpsMap.erase(pendingOp);
      continue;
    }

    // Materializing removes the op from both the list and the map.
    if (failed(materialize(lazyLoadableOpsMap.find(pendingOp))))
      return failure();
  }
  return success();
}
private:
/// Materializes the operation referenced by `it`, which must be a valid
/// iterator into `lazyLoadableOpsMap` (asserted). The op is removed from the
/// lazy-loading bookkeeping before its regions are parsed.
LogicalResult materialize(LazyLoadableOpsMap::iterator it) {
assert(it != lazyLoadableOpsMap.end() &&
"materialize called on non-materializable op");
// Open a fresh value scope for the regions about to be parsed.
valueScopes.emplace_back();
std::vector<RegionReadState> regionStack;
// Take the saved read state first: the two erasures below invalidate both the
// list node it lives in and `it` itself.
regionStack.push_back(std::move(it->getSecond()->second));
lazyLoadableOps.erase(it->getSecond());
lazyLoadableOpsMap.erase(it);
// Parse until the stack of pending regions has been drained.
while (!regionStack.empty())
if (failed(parseRegions(regionStack, regionStack.back())))
return failure();
return success();
}
MLIRContext *getContext() const { return config.getContext(); }
LogicalResult parseVersion(EncodingReader &reader);
LogicalResult parseDialectSection(ArrayRef<uint8_t> sectionData);
FailureOr<OperationName> parseOpName(EncodingReader &reader,
std::optional<bool> &wasRegistered);
/// Parse an attribute reference from `reader` into `result` by delegating to
/// the attribute/type reader.
template <typename T>
LogicalResult parseAttribute(EncodingReader &reader, T &result) {
  LogicalResult status = attrTypeReader.parseAttribute(reader, result);
  return status;
}
/// Parse a type reference from `reader` into `result` by delegating to the
/// attribute/type reader.
LogicalResult parseType(EncodingReader &reader, Type &result) {
  LogicalResult status = attrTypeReader.parseType(reader, result);
  return status;
}
/// Initialize the resource reader from the resource section and its offset
/// section. Both sections are optional, but must be present or absent
/// together.
LogicalResult
parseResourceSection(EncodingReader &reader,
std::optional<ArrayRef<uint8_t>> resourceData,
std::optional<ArrayRef<uint8_t>> resourceOffsetData);
/// The state of reading the regions of one operation. Instances are kept on a
/// stack so that parsing can be suspended when a nested operation with regions
/// is found and resumed once that operation has been processed.
struct RegionReadState {
/// Build the state for all regions of `op`.
RegionReadState(Operation *op, EncodingReader *reader,
bool isIsolatedFromAbove)
: RegionReadState(op->getRegions(), reader, isIsolatedFromAbove) {}
/// Build the state for an explicit range of regions.
RegionReadState(MutableArrayRef<Region> regions, EncodingReader *reader,
bool isIsolatedFromAbove)
: curRegion(regions.begin()), endRegion(regions.end()), reader(reader),
isIsolatedFromAbove(isIsolatedFromAbove) {}
/// The region currently being read and the end of the region range.
MutableArrayRef<Region>::iterator curRegion, endRegion;
/// The reader to parse these regions from. It is either a reader supplied by
/// the creator of this state or `owningReader` when one has been set up.
EncodingReader *reader;
/// Optional reader owned by this state (used when the regions are encoded in
/// their own section).
std::unique_ptr<EncodingReader> owningReader;
/// The number of values defined directly within the current region.
unsigned numValues = 0;
/// The blocks of the current region, used to resolve successor indices.
SmallVector<Block *> curBlocks;
/// The block currently being read; default-constructed when no block of the
/// current region has been set up yet.
Region::iterator curBlock = {};
/// The number of operations still to be read from the current block.
uint64_t numOpsRemaining = 0;
/// Whether the regions being read are isolated from above.
bool isIsolatedFromAbove = false;
};
/// Parse the top-level IR section and move the parsed operations into
/// `block`.
LogicalResult parseIRSection(ArrayRef<uint8_t> sectionData, Block *block);
/// Parse the regions described by `readState` (the top of `regionStack`).
/// Returns early after pushing a new state when a nested operation with
/// regions is found; pops `readState` once all its regions have been read.
LogicalResult parseRegions(std::vector<RegionReadState> &regionStack,
                           RegionReadState &readState);
/// Parse a single operation without its regions and append it to the current
/// block. `isIsolatedFromAbove` is set from the encoded region header, if any.
FailureOr<Operation *> parseOpWithoutRegions(EncodingReader &reader,
                                             RegionReadState &readState,
                                             bool &isIsolatedFromAbove);
/// Parse the header of the current region and create its blocks.
LogicalResult parseRegion(RegionReadState &readState);
/// Parse the header of the current block: operation count, arguments and
/// argument use-list orders.
LogicalResult parseBlockHeader(EncodingReader &reader,
                               RegionReadState &readState);
/// Parse the arguments of `block` and define the corresponding values.
LogicalResult parseBlockArguments(EncodingReader &reader, Block *block);
/// Parse an operand reference, creating a forward reference if the value has
/// not been defined yet. Returns a null value on failure.
Value parseOperand(EncodingReader &reader);
/// Assign the next value IDs of the current scope to `values`, resolving any
/// forward references to them.
LogicalResult defineValues(EncodingReader &reader, ValueRange values);
/// Create (or recycle) a placeholder value for a not-yet-defined operand.
Value createForwardRef();
/// The use-list order information read from the bytecode for a single value.
struct UseListOrderStorage {
  UseListOrderStorage(bool isIndexPairEncoding,
                      SmallVector<unsigned, 4> &&indices)
      : indices(std::move(indices)),
        isIndexPairEncoding(isIndexPairEncoding) {}

  /// The raw indices as stored in the bytecode.
  SmallVector<unsigned, 4> indices;

  /// Whether `indices` is a flat list of (source, destination) pairs rather
  /// than a full shuffle vector.
  bool isIndexPairEncoding;
};
/// Maps the index of a value within a range (op results or block arguments)
/// to the use-list order read for it.
using UseListMapT = DenseMap<unsigned, UseListOrderStorage>;
/// Parse the use-list orders recorded for a range of `rangeSize` values.
FailureOr<UseListMapT> parseUseListOrderForRange(EncodingReader &reader,
uint64_t rangeSize);
/// Reorder the use-list of `value`, applying the order read from the bytecode
/// when one was recorded for it.
LogicalResult sortUseListOrder(Value value);
/// Apply use-list ordering to every block argument and operation result
/// nested under `topLevelOp`.
LogicalResult processUseLists(Operation *topLevelOp);
/// A flat table of the values visible in one scope, together with the next
/// value ID to assign for each region currently open in that scope.
struct ValueScope {
  /// Open a region: remember where its value IDs start and reserve one slot
  /// per value it defines.
  void push(RegionReadState &readState) {
    const size_t firstValueID = values.size();
    nextValueIDs.push_back(firstValueID);
    values.resize(firstValueID + readState.numValues);
  }

  /// Close a region: release the slots reserved for it.
  void pop(RegionReadState &readState) {
    const size_t remaining = values.size() - readState.numValues;
    values.resize(remaining);
    nextValueIDs.pop_back();
  }

  /// The values defined so far, indexed by value ID.
  std::vector<Value> values;

  /// For each open region, the ID the next defined value will receive.
  SmallVector<unsigned, 4> nextValueIDs;
};
/// The parser configuration; provides the context and the verification flag.
const ParserConfig &config;
/// The location used for readers and diagnostics of this file.
Location fileLoc;
/// Whether lazy loading is enabled; cleared when the bytecode version
/// predates lazy loading.
bool lazyLoading;
/// The operations whose regions have not been parsed yet, with the state
/// needed to parse them later, and a map from each such operation to its list
/// entry.
LazyLoadableOpsInfo lazyLoadableOps;
LazyLoadableOpsMap lazyLoadableOpsMap;
/// Callback consulted while parsing to decide whether an operation is loaded
/// immediately; only set while `read` or `materialize` is running.
llvm::function_ref<bool(Operation *)> lazyOpsCallback;
/// The reader for attributes and types.
AttrTypeReader attrTypeReader;
/// The bytecode version of the file being read.
uint64_t version = 0;
/// The producer string read from the file header.
StringRef producer;
/// The dialects referenced by the file, and a lookup by dialect name.
SmallVector<std::unique_ptr<BytecodeDialect>> dialects;
llvm::StringMap<BytecodeDialect *> dialectsMap;
/// The table of operation names referenced by the file.
SmallVector<BytecodeOperationName> opNames;
/// The reader for the resource section.
ResourceSectionReader resourceReader;
/// Use-list orders read from the file, keyed by the opaque pointer of the
/// value they apply to.
DenseMap<void *, UseListOrderStorage> valueToUseListMap;
/// The reader for the string section.
StringSectionReader stringReader;
/// The reader for the properties section.
PropertiesSectionReader propertiesReader;
/// The stack of value scopes; a new scope is pushed for regions that are
/// isolated from above.
std::vector<ValueScope> valueScopes;
/// Operation IDs assigned by a pre-order walk, used for use-list sorting.
DenseMap<Operation *, unsigned> operationIDs;
/// Placeholder operations for forward references that are still unresolved
/// (`forwardRefOps`) and resolved ones kept for reuse (`openForwardRefOps`).
Block forwardRefOps;
Block openForwardRefOps;
/// The operation state used to create forward reference placeholders.
OperationState forwardRefOpState;
/// The buffer containing the bytecode.
llvm::MemoryBufferRef buffer;
/// Reference to the shared owner of the buffer, forwarded to the resource
/// reader.
const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef;
};
/// Read the whole bytecode buffer: header, top-level section table, then each
/// section in dependency order, ending with the IR section which is parsed
/// into `block`.
LogicalResult BytecodeReader::Impl::read(
    Block *block, llvm::function_ref<bool(Operation *)> lazyOpsCallback) {
  EncodingReader reader(buffer.getBuffer(), fileLoc);

  // The callback is only valid for the duration of this call.
  this->lazyOpsCallback = lazyOpsCallback;
  auto callbackGuard =
      llvm::make_scope_exit([this] { this->lazyOpsCallback = nullptr; });

  // Step over the magic number, then read the version and the producer.
  if (failed(reader.skipBytes(StringRef("ML\xefR").size())))
    return failure();
  if (failed(parseVersion(reader)))
    return failure();
  if (failed(reader.parseNullTerminatedString(producer)))
    return failure();

  // From here on, every diagnostic gets a note naming version and producer.
  ScopedDiagnosticHandler diagHandler(getContext(), [&](Diagnostic &diag) {
    diag.attachNote() << "in bytecode version " << version
                      << " produced by: " << producer;
    return failure();
  });

  // Collect the raw payload of each top-level section, rejecting duplicates.
  std::optional<ArrayRef<uint8_t>>
      sectionDatas[bytecode::Section::kNumSections];
  while (!reader.empty()) {
    bytecode::Section::ID id;
    ArrayRef<uint8_t> payload;
    if (failed(reader.parseSection(id, payload)))
      return failure();

    std::optional<ArrayRef<uint8_t>> &slot = sectionDatas[id];
    if (slot) {
      return reader.emitError("duplicate top-level section: ",
                              ::toString(id));
    }
    slot = payload;
  }

  // Every section that is mandatory for this version must have been seen.
  for (int i = 0; i < bytecode::Section::kNumSections; ++i) {
    if (sectionDatas[i])
      continue;
    bytecode::Section::ID id = static_cast<bytecode::Section::ID>(i);
    if (!isSectionOptional(id, version)) {
      return reader.emitError("missing data for top-level section: ",
                              ::toString(id));
    }
  }

  // Strings first: the other sections refer to them.
  if (failed(stringReader.initialize(
          fileLoc, *sectionDatas[bytecode::Section::kString])))
    return failure();

  // Properties, when present.
  if (sectionDatas[bytecode::Section::kProperties]) {
    if (failed(propertiesReader.initialize(
            fileLoc, *sectionDatas[bytecode::Section::kProperties])))
      return failure();
  }

  // Dialects and operation names.
  if (failed(parseDialectSection(*sectionDatas[bytecode::Section::kDialect])))
    return failure();

  // Resources, when present.
  if (failed(parseResourceSection(
          reader, sectionDatas[bytecode::Section::kResource],
          sectionDatas[bytecode::Section::kResourceOffset])))
    return failure();

  // Attributes and types.
  if (failed(attrTypeReader.initialize(
          dialects, *sectionDatas[bytecode::Section::kAttrType],
          *sectionDatas[bytecode::Section::kAttrTypeOffset])))
    return failure();

  // Finally the IR itself.
  return parseIRSection(*sectionDatas[bytecode::Section::kIR], block);
}
/// Read the bytecode version and reject versions outside of
/// [kMinSupportedVersion, kVersion]. Lazy loading is turned off for versions
/// that predate it.
LogicalResult BytecodeReader::Impl::parseVersion(EncodingReader &reader) {
  if (failed(reader.parseVarInt(version)))
    return failure();

  const uint64_t newestVersion = bytecode::kVersion;
  const uint64_t oldestVersion = bytecode::kMinSupportedVersion;

  if (version < oldestVersion) {
    return reader.emitError("bytecode version ", version,
                            " is older than the current version of ",
                            newestVersion, ", and upgrade is not supported");
  }
  if (version > newestVersion) {
    return reader.emitError("bytecode version ", version,
                            " is newer than the current version ",
                            newestVersion);
  }

  // A request to lazy-load cannot be honoured by older encodings.
  lazyLoading = lazyLoading && !(version < bytecode::kLazyLoading);
  return success();
}
/// Load this dialect into `ctx` (once), pick up its bytecode interface if it
/// has one, and read the dialect version when a version buffer was recorded.
LogicalResult BytecodeDialect::load(const DialectReader &reader,
                                    MLIRContext *ctx) {
  // Nothing to do when loading was already attempted.
  if (dialect)
    return success();

  Dialect *resolved = ctx->getOrLoadDialect(name);
  if (!resolved) {
    if (!ctx->allowsUnregisteredDialects()) {
      return reader.emitError("dialect '")
             << name
             << "' is unknown. If this is intended, please call "
                "allowUnregisteredDialects() on the MLIRContext, or use "
                "-allow-unregistered-dialect with the MLIR tool used.";
    }
  }
  dialect = resolved;

  // Only a loaded dialect can expose a bytecode interface.
  if (resolved)
    interface = dyn_cast<BytecodeDialectInterface>(resolved);

  // Without a version entry there is nothing more to read.
  if (versionBuffer.empty())
    return success();

  if (!interface) {
    return reader.emitError("dialect '")
           << name
           << "' does not implement the bytecode interface, "
              "but found a version entry";
  }

  EncodingReader encReader(versionBuffer, reader.getLoc());
  DialectReader versionReader = reader.withEncodingReader(encReader);
  loadedVersion = interface->readVersion(versionReader);
  return loadedVersion ? success() : failure();
}
/// Parse the dialect section: first the table of dialects (with optional
/// version payloads), then the operation names grouped by dialect.
LogicalResult
BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
  EncodingReader sectionReader(sectionData, fileLoc);

  // Dialect table.
  uint64_t numDialects;
  if (failed(sectionReader.parseVarInt(numDialects)))
    return failure();
  dialects.resize(numDialects);

  const bool hasDialectVersioning = !(version < bytecode::kDialectVersioning);
  for (uint64_t i = 0; i < numDialects; ++i) {
    dialects[i] = std::make_unique<BytecodeDialect>();
    BytecodeDialect &entry = *dialects[i];

    // Older encodings only store the dialect name.
    if (!hasDialectVersioning) {
      if (failed(stringReader.parseString(sectionReader, entry.name)))
        return failure();
      continue;
    }

    // Name index, with a flag telling whether a version payload follows.
    uint64_t dialectNameIdx;
    bool versionAvailable;
    if (failed(sectionReader.parseVarIntWithFlag(dialectNameIdx,
                                                 versionAvailable)))
      return failure();
    if (failed(stringReader.parseStringAtIndex(sectionReader, dialectNameIdx,
                                               entry.name)))
      return failure();

    if (versionAvailable) {
      bytecode::Section::ID sectionID;
      if (failed(sectionReader.parseSection(sectionID, entry.versionBuffer)))
        return failure();
      if (sectionID != bytecode::Section::kDialectVersions) {
        emitError(fileLoc, "expected dialect version section");
        return failure();
      }
    }
    dialectsMap[entry.name] = &entry;
  }

  // Reads one operation name belonging to `dialect`; the registration flag is
  // only encoded from kNativePropertiesEncoding onwards.
  auto readOpName = [&](BytecodeDialect *dialect) {
    StringRef opName;
    std::optional<bool> wasRegistered;
    if (version >= bytecode::kNativePropertiesEncoding) {
      bool wasRegisteredFlag;
      if (failed(stringReader.parseStringWithFlag(sectionReader, opName,
                                                  wasRegisteredFlag)))
        return failure();
      wasRegistered = wasRegisteredFlag;
    } else {
      if (failed(stringReader.parseString(sectionReader, opName)))
        return failure();
    }
    opNames.emplace_back(dialect, opName, wasRegistered);
    return success();
  };

  // Newer encodings store the number of operation names up front.
  if (version >= bytecode::kElideUnknownBlockArgLocation) {
    uint64_t numOps;
    if (failed(sectionReader.parseVarInt(numOps)))
      return failure();
    opNames.reserve(numOps);
  }

  // The rest of the section is a sequence of per-dialect groups.
  while (!sectionReader.empty()) {
    if (failed(parseDialectGrouping(sectionReader, dialects, readOpName)))
      return failure();
  }
  return success();
}
/// Parse a reference into the operation name table and return the resolved
/// `OperationName`, building and caching it on first use. `wasRegistered`
/// receives the flag stored with the table entry.
FailureOr<OperationName>
BytecodeReader::Impl::parseOpName(EncodingReader &reader,
                                  std::optional<bool> &wasRegistered) {
  BytecodeOperationName *entry = nullptr;
  if (failed(parseEntry(reader, opNames, entry, "operation name")))
    return failure();
  wasRegistered = entry->wasRegistered;

  // Already resolved on an earlier reference.
  if (entry->opName)
    return *entry->opName;

  // An empty op name means the dialect name alone is the operation name.
  if (entry->name.empty()) {
    entry->opName.emplace(entry->dialect->name, getContext());
    return *entry->opName;
  }

  // Otherwise load the dialect, then build `<dialect>.<name>`.
  DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
                              dialectsMap, reader, version);
  if (failed(entry->dialect->load(dialectReader, getContext())))
    return failure();
  entry->opName.emplace((entry->dialect->name + "." + entry->name).str(),
                        getContext());
  return *entry->opName;
}
/// Initialize the resource reader from the (optional) resource and resource
/// offset sections. It is an error for only one of the two to be present.
LogicalResult BytecodeReader::Impl::parseResourceSection(
    EncodingReader &reader, std::optional<ArrayRef<uint8_t>> resourceData,
    std::optional<ArrayRef<uint8_t>> resourceOffsetData) {
  const bool hasData = resourceData.has_value();
  const bool hasOffsets = resourceOffsetData.has_value();

  if (hasOffsets && !hasData)
    return emitError(fileLoc, "unexpected resource offset section when "
                              "resource section is not present");
  if (hasData && !hasOffsets)
    return emitError(
        fileLoc,
        "expected resource offset section when resource section is present");

  // Neither section present: nothing to do.
  if (!hasData)
    return success();

  DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
                              dialectsMap, reader, version);
  return resourceReader.initialize(fileLoc, config, dialects, stringReader,
                                   *resourceData, *resourceOffsetData,
                                   dialectReader, bufferOwnerRef);
}
/// Parse the use-list orders stored for a range of `numResults` values. When
/// the range has more than one value, the number of entries and the index of
/// each entry are encoded explicitly; otherwise a single entry for index 0 is
/// read.
FailureOr<BytecodeReader::Impl::UseListMapT>
BytecodeReader::Impl::parseUseListOrderForRange(EncodingReader &reader,
                                                uint64_t numResults) {
  const bool hasExplicitIndices = numResults > 1;
  BytecodeReader::Impl::UseListMapT orders;

  uint64_t numEntries = 1;
  if (hasExplicitIndices && failed(reader.parseVarInt(numEntries)))
    return failure();

  for (size_t entry = 0; entry < numEntries; entry++) {
    uint64_t resultIdx = 0;
    if (hasExplicitIndices && failed(reader.parseVarInt(resultIdx)))
      return failure();

    // Number of indices, with a flag selecting the index-pair encoding.
    uint64_t numIndices;
    bool indexPairEncoding;
    if (failed(reader.parseVarIntWithFlag(numIndices, indexPairEncoding)))
      return failure();

    SmallVector<unsigned, 4> indices;
    for (size_t i = 0; i < numIndices; i++) {
      uint64_t index;
      if (failed(reader.parseVarInt(index)))
        return failure();
      indices.push_back(index);
    }

    orders.try_emplace(resultIdx, UseListOrderStorage(indexPairEncoding,
                                                      std::move(indices)));
  }
  return orders;
}
/// Reorder the use-list of `value`.
///
/// The uses are first put in descending use-ID order (use IDs are derived from
/// `operationIDs`). If the bytecode recorded a custom order for `value`, that
/// order is validated as a permutation of the uses and applied on top.
/// Returns failure if the recorded order is malformed.
LogicalResult BytecodeReader::Impl::sortUseListOrder(Value value) {
  // Use-lists with zero or one use have nothing to reorder.
  if (value.use_empty() || value.hasOneUse())
    return success();

  bool hasIncomingOrder =
      valueToUseListMap.contains(value.getAsOpaquePointer());

  // Record (current position, use ID) for every use and detect whether the
  // list is already in strictly descending use-ID order.
  bool alreadySorted = true;
  auto &firstUse = *value.use_begin();
  uint64_t prevID =
      bytecode::getUseID(firstUse, operationIDs.at(firstUse.getOwner()));
  llvm::SmallVector<std::pair<unsigned, uint64_t>> currentOrder = {{0, prevID}};
  for (auto item : llvm::drop_begin(llvm::enumerate(value.getUses()))) {
    uint64_t currentID = bytecode::getUseID(
        item.value(), operationIDs.at(item.value().getOwner()));
    alreadySorted &= prevID > currentID;
    currentOrder.push_back({item.index(), currentID});
    prevID = currentID;
  }

  // Already sorted and no custom order to apply: done.
  if (alreadySorted && !hasIncomingOrder)
    return success();

  // Sort the positions by descending use ID.
  if (!alreadySorted)
    std::sort(
        currentOrder.begin(), currentOrder.end(),
        [](auto elem1, auto elem2) { return elem1.second > elem2.second; });

  if (!hasIncomingOrder) {
    // No custom order: the descending use-ID order is the final order.
    SmallVector<unsigned> shuffle = SmallVector<unsigned>(
        llvm::map_range(currentOrder, [&](auto item) { return item.first; }));
    value.shuffleUseList(shuffle);
    return success();
  }

  // Fetch the custom order recorded for this value.
  UseListOrderStorage customOrder =
      valueToUseListMap.at(value.getAsOpaquePointer());
  SmallVector<unsigned, 4> shuffle = std::move(customOrder.indices);
  uint64_t numUses =
      std::distance(value.getUses().begin(), value.getUses().end());

  // Index-pair encoding stores (source, destination) pairs; expand them into
  // a full shuffle vector starting from the identity.
  if (customOrder.isIndexPairEncoding) {
    // An odd number of indices cannot form pairs.
    if (shuffle.size() & 1)
      return failure();

    SmallVector<unsigned, 4> newShuffle(numUses);
    size_t idx = 0;
    std::iota(newShuffle.begin(), newShuffle.end(), idx);
    for (idx = 0; idx < shuffle.size(); idx += 2) {
      // The source index comes straight from the file: reject it rather than
      // writing outside of `newShuffle`.
      if (shuffle[idx] >= numUses)
        return failure();
      newShuffle[shuffle[idx]] = shuffle[idx + 1];
    }

    shuffle = std::move(newShuffle);
  }

  // The shuffle must be a permutation of [0, numUses): no duplicates, the
  // right length, and a sum of (numUses - 1) * numUses / 2.
  DenseSet<unsigned> seen;
  uint64_t accumulator = 0;
  for (const auto &elem : shuffle) {
    if (!seen.insert(elem).second)
      return failure();
    accumulator += elem;
  }
  if (numUses != shuffle.size() ||
      accumulator != (((numUses - 1) * numUses) >> 1))
    return failure();

  // Compose the custom order with the sorted positions and apply it.
  shuffle = SmallVector<unsigned, 4>(llvm::map_range(
      currentOrder, [&](auto item) { return shuffle[item.first]; }));
  value.shuffleUseList(shuffle);
  return success();
}
/// Number every operation under `topLevelOp` in pre-order, then sort the
/// use-lists of all block arguments and all operation results. Fails if any
/// use-list could not be sorted.
LogicalResult BytecodeReader::Impl::processUseLists(Operation *topLevelOp) {
  // Operation IDs feed the use IDs computed in `sortUseListOrder`.
  unsigned nextOperationID = 0;
  topLevelOp->walk<mlir::WalkOrder::PreOrder>([&](Operation *op) {
    operationIDs.try_emplace(op, nextOperationID++);
  });

  WalkResult argumentWalk = topLevelOp->walk([this](Block *block) {
    for (auto argument : block->getArguments()) {
      if (failed(sortUseListOrder(argument)))
        return WalkResult::interrupt();
    }
    return WalkResult::advance();
  });

  // The result walk runs regardless of the outcome of the argument walk.
  WalkResult resultWalk = topLevelOp->walk([this](Operation *op) {
    for (auto opResult : op->getResults()) {
      if (failed(sortUseListOrder(opResult)))
        return WalkResult::interrupt();
    }
    return WalkResult::advance();
  });

  const bool anyInterrupted =
      argumentWalk.wasInterrupted() || resultWalk.wasInterrupted();
  return failure(anyInterrupted);
}
/// Parse the top-level IR section into a temporary module, post-process it
/// (use-lists, dialect upgrades, optional verification), and move the parsed
/// operations to the end of `block`.
LogicalResult
BytecodeReader::Impl::parseIRSection(ArrayRef<uint8_t> sectionData,
                                     Block *block) {
  EncodingReader reader(sectionData, fileLoc);

  // The regions currently being read, innermost last.
  std::vector<RegionReadState> regionStack;

  // The top-level block is parsed as the body of a temporary module.
  OwningOpRef<ModuleOp> moduleOp = ModuleOp::create(fileLoc);
  regionStack.emplace_back(*moduleOp, &reader, /*isIsolatedFromAbove=*/true);
  RegionReadState &rootState = regionStack.back();
  rootState.curBlocks.push_back(moduleOp->getBody());
  rootState.curBlock = rootState.curRegion->begin();
  if (failed(parseBlockHeader(reader, rootState)))
    return failure();
  valueScopes.emplace_back();
  valueScopes.back().push(rootState);

  // Keep parsing until every region state has been consumed.
  while (!regionStack.empty()) {
    if (failed(parseRegions(regionStack, regionStack.back())))
      return failure();
  }

  // Every forward reference must have been resolved by now.
  if (!forwardRefOps.empty()) {
    return reader.emitError(
        "not all forward unresolved forward operand references");
  }

  // Apply the use-list orders.
  if (failed(processUseLists(*moduleOp)))
    return reader.emitError(
        "parsed use-list orders were invalid and could not be applied");

  // Let each dialect that read a version upgrade the parsed IR.
  for (const std::unique_ptr<BytecodeDialect> &parsedDialect : dialects) {
    if (!parsedDialect->loadedVersion || !parsedDialect->interface)
      continue;
    if (failed(parsedDialect->interface->upgradeFromVersion(
            *moduleOp, *parsedDialect->loadedVersion)))
      return failure();
  }

  // Optionally verify the result.
  if (config.shouldVerifyAfterParse()) {
    if (failed(verify(*moduleOp)))
      return failure();
  }

  // Hand the operations over to the caller's block.
  auto &parsedOps = moduleOp->getBody()->getOperations();
  auto &destOps = block->getOperations();
  destOps.splice(destOps.end(), parsedOps, parsedOps.begin(), parsedOps.end());
  return success();
}
/// Parse regions, blocks and operations for `readState` until either all of
/// its regions are done or an operation with regions is encountered.
///
/// In the latter case a new state is pushed on `regionStack` and the function
/// returns success; the caller is expected to call again with the new top of
/// the stack, and parsing of `readState` resumes later. When all regions of
/// `readState` have been read, it is popped from `regionStack`.
LogicalResult
BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> &regionStack,
                                   RegionReadState &readState) {
  for (; readState.curRegion != readState.endRegion; ++readState.curRegion) {
    // A default-constructed `curBlock` means this region has not been started
    // yet; otherwise we are resuming after a nested operation.
    if (readState.curBlock == Region::iterator()) {
      if (failed(parseRegion(readState)))
        return failure();

      // Nothing more to do for an empty region.
      if (readState.curRegion->empty())
        continue;
    }

    // Parse the blocks of the region.
    EncodingReader &reader = *readState.reader;
    do {
      while (readState.numOpsRemaining--) {
        // Read the next operation; its regions are handled below.
        bool isIsolatedFromAbove = false;
        FailureOr<Operation *> op =
            parseOpWithoutRegions(reader, readState, isIsolatedFromAbove);
        if (failed(op))
          return failure();

        if ((*op)->getNumRegions()) {
          RegionReadState childState(*op, &reader, isIsolatedFromAbove);

          // From kLazyLoading on, isolated regions live in their own section
          // and get a dedicated reader.
          if (version >= bytecode::kLazyLoading && isIsolatedFromAbove) {
            bytecode::Section::ID sectionID;
            ArrayRef<uint8_t> sectionData;
            if (failed(reader.parseSection(sectionID, sectionData)))
              return failure();
            if (sectionID != bytecode::Section::kIR)
              return emitError(fileLoc, "expected IR section for region");
            childState.owningReader =
                std::make_unique<EncodingReader>(sectionData, fileLoc);
            childState.reader = childState.owningReader.get();

            // When lazy loading, defer the operation unless the callback asks
            // for it to be loaded now.
            if (lazyLoading && (!lazyOpsCallback || !lazyOpsCallback(*op))) {
              lazyLoadableOps.emplace_back(*op, std::move(childState));
              lazyLoadableOpsMap.try_emplace(*op,
                                             std::prev(lazyLoadableOps.end()));
              continue;
            }
          }

          // Suspend this state and let the caller process the nested one.
          // `readState` must not be touched after the push: it may refer to
          // an element of `regionStack`.
          regionStack.push_back(std::move(childState));
          if (isIsolatedFromAbove)
            valueScopes.emplace_back();
          return success();
        }
      }

      // Advance to the next block of the region, if any.
      if (++readState.curBlock == readState.curRegion->end())
        break;
      if (failed(parseBlockHeader(reader, readState)))
        return failure();
    } while (true);

    // Region done: reset the block and release the values reserved for it.
    readState.curBlock = {};
    valueScopes.back().pop(readState);
  }

  // All regions done: drop the value scope of isolated regions and pop the
  // state.
  if (readState.isIsolatedFromAbove) {
    assert(!valueScopes.empty() && "Expect a valueScope after reading region");
    valueScopes.pop_back();
  }
  assert(!regionStack.empty() && "Expect a regionStack after reading region");
  regionStack.pop_back();
  return success();
}
/// Parse one operation from `reader`, excluding the contents of its regions,
/// and append it to the current block of `readState`.
///
/// The encoded operation mask selects which components follow (attributes,
/// properties, results, operands, successors, use-list orders, regions).
/// `isIsolatedFromAbove` is only written when the operation has inline
/// regions. Returns the created operation, or failure on malformed input.
FailureOr<Operation *>
BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader,
                                            RegionReadState &readState,
                                            bool &isIsolatedFromAbove) {
  // Operation name.
  std::optional<bool> wasRegistered;
  FailureOr<OperationName> opName = parseOpName(reader, wasRegistered);
  if (failed(opName))
    return failure();

  // Mask describing which components are present.
  uint8_t opMask;
  if (failed(reader.parseByte(opMask)))
    return failure();

  // Location.
  LocationAttr opLoc;
  if (failed(parseAttribute(reader, opLoc)))
    return failure();

  OperationState opState(opLoc, *opName);

  // Attributes.
  if (opMask & bytecode::OpEncodingMask::kHasAttrs) {
    DictionaryAttr dictAttr;
    if (failed(parseAttribute(reader, dictAttr)))
      return failure();
    opState.attributes = dictAttr;
  }

  // Properties.
  if (opMask & bytecode::OpEncodingMask::kHasProperties) {
    // Properties require the registration flag to have been encoded.
    if (!wasRegistered)
      return emitError(fileLoc,
                       "Unexpected missing `wasRegistered` opname flag at "
                       "bytecode version ")
             << version << " with properties.";
    // NOTE(review): `wasRegistered` is a std::optional<bool>, so this tests
    // that the flag is present (always true after the check above), not the
    // flag's value; the else branch is therefore not reached. Confirm this is
    // intended before changing it.
    if (wasRegistered) {
      DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
                                  dialectsMap, reader, version);
      if (failed(
              propertiesReader.read(fileLoc, dialectReader, &*opName, opState)))
        return failure();
    } else {
      if (failed(parseAttribute(reader, opState.propertiesAttr)))
        return failure();
    }
  }

  // Result types. Counts come from the file as 64-bit values, so the loops
  // below index with the same width instead of narrowing to `int`.
  if (opMask & bytecode::OpEncodingMask::kHasResults) {
    uint64_t numResults;
    if (failed(reader.parseVarInt(numResults)))
      return failure();
    opState.types.resize(numResults);
    for (uint64_t i = 0; i < numResults; ++i)
      if (failed(parseType(reader, opState.types[i])))
        return failure();
  }

  // Operands.
  if (opMask & bytecode::OpEncodingMask::kHasOperands) {
    uint64_t numOperands;
    if (failed(reader.parseVarInt(numOperands)))
      return failure();
    opState.operands.resize(numOperands);
    for (uint64_t i = 0; i < numOperands; ++i)
      if (!(opState.operands[i] = parseOperand(reader)))
        return failure();
  }

  // Successors, resolved against the blocks of the current region.
  if (opMask & bytecode::OpEncodingMask::kHasSuccessors) {
    uint64_t numSuccs;
    if (failed(reader.parseVarInt(numSuccs)))
      return failure();
    opState.successors.resize(numSuccs);
    for (uint64_t i = 0; i < numSuccs; ++i) {
      if (failed(parseEntry(reader, readState.curBlocks, opState.successors[i],
                            "successor")))
        return failure();
    }
  }

  // Use-list orders of the results (only from kUseListOrdering onwards).
  std::optional<UseListMapT> resultIdxToUseListMap = std::nullopt;
  if (version >= bytecode::kUseListOrdering &&
      (opMask & bytecode::OpEncodingMask::kHasUseListOrders)) {
    size_t numResults = opState.types.size();
    auto parseResult = parseUseListOrderForRange(reader, numResults);
    if (failed(parseResult))
      return failure();
    resultIdxToUseListMap = std::move(*parseResult);
  }

  // Regions: only the count and the isolation flag are read here; empty
  // regions are created so the operation has the right shape.
  if (opMask & bytecode::OpEncodingMask::kHasInlineRegions) {
    uint64_t numRegions;
    if (failed(reader.parseVarIntWithFlag(numRegions, isIsolatedFromAbove)))
      return failure();
    opState.regions.reserve(numRegions);
    for (uint64_t i = 0; i < numRegions; ++i)
      opState.regions.push_back(std::make_unique<Region>());
  }

  // Create the operation at the end of the current block.
  Operation *op = Operation::create(opState);
  readState.curBlock->push_back(op);

  // Register the results as values, unless the region defines no values.
  if (readState.numValues && op->getNumResults() &&
      failed(defineValues(reader, op->getResults())))
    return failure();

  // Remember the custom use-list order of each result that has one.
  if (resultIdxToUseListMap.has_value()) {
    for (size_t idx = 0; idx < op->getNumResults(); idx++) {
      if (resultIdxToUseListMap->contains(idx)) {
        valueToUseListMap.try_emplace(op->getResult(idx).getAsOpaquePointer(),
                                      resultIdxToUseListMap->at(idx));
      }
    }
  }
  return op;
}
/// Parse the header of the current region of `readState`: the block count and
/// the value count. All blocks are created up front, the value scope is
/// extended, and the header of the entry block is parsed.
LogicalResult BytecodeReader::Impl::parseRegion(RegionReadState &readState) {
  EncodingReader &reader = *readState.reader;

  uint64_t numBlocks;
  if (failed(reader.parseVarInt(numBlocks)))
    return failure();

  // An empty region has no further encoding.
  if (!numBlocks)
    return success();

  uint64_t numValues;
  if (failed(reader.parseVarInt(numValues)))
    return failure();
  readState.numValues = numValues;

  // Create every block before parsing so that operations can refer to any of
  // them as a successor.
  readState.curBlocks.clear();
  readState.curBlocks.reserve(numBlocks);
  for (uint64_t blockIdx = 0; blockIdx != numBlocks; ++blockIdx) {
    Block *newBlock = new Block();
    readState.curBlocks.push_back(newBlock);
    readState.curRegion->push_back(newBlock);
  }

  // Reserve the values of this region in the current scope.
  valueScopes.back().push(readState);

  // Start with the entry block.
  readState.curBlock = readState.curRegion->begin();
  return parseBlockHeader(reader, readState);
}
/// Parse the header of the current block of `readState`: the number of
/// operations, the block arguments (if flagged), and — from kUseListOrdering
/// onwards — the use-list orders of the arguments. The operations themselves
/// are not parsed here.
LogicalResult
BytecodeReader::Impl::parseBlockHeader(EncodingReader &reader,
                                       RegionReadState &readState) {
  bool hasArgs;
  if (failed(reader.parseVarIntWithFlag(readState.numOpsRemaining, hasArgs)))
    return failure();

  if (hasArgs) {
    if (failed(parseBlockArguments(reader, &*readState.curBlock)))
      return failure();
  }

  // Older encodings carry no use-list information.
  if (version < bytecode::kUseListOrdering)
    return success();

  // The use-list marker byte is only present when the block has arguments.
  uint8_t hasUseListOrders = 0;
  if (hasArgs) {
    if (failed(reader.parseByte(hasUseListOrders)))
      return failure();
  }
  if (hasUseListOrders == 0)
    return success();

  Block &currentBlock = *readState.curBlock;
  const size_t numArguments = currentBlock.getNumArguments();
  auto argumentOrders = parseUseListOrderForRange(reader, numArguments);
  if (failed(argumentOrders) || argumentOrders->empty())
    return failure();

  // Remember the custom order of each argument that has one.
  for (size_t argIdx = 0; argIdx < currentBlock.getNumArguments(); argIdx++) {
    if (!argumentOrders->contains(argIdx))
      continue;
    valueToUseListMap.try_emplace(
        currentBlock.getArgument(argIdx).getAsOpaquePointer(),
        argumentOrders->at(argIdx));
  }
  return success();
}
/// Parse the arguments of `block` (type and location of each), add them to
/// the block, and define the corresponding values in the current scope.
LogicalResult BytecodeReader::Impl::parseBlockArguments(EncodingReader &reader,
                                                        Block *block) {
  uint64_t numArgs;
  if (failed(reader.parseVarInt(numArgs)))
    return failure();

  SmallVector<Type> types;
  SmallVector<Location> locations;
  types.reserve(numArgs);
  locations.reserve(numArgs);

  // From kElideUnknownBlockArgLocation on, a location is only encoded when
  // flagged; otherwise the unknown location is used.
  const bool locationMayBeElided =
      version >= bytecode::kElideUnknownBlockArgLocation;
  Location unknownLoc = UnknownLoc::get(config.getContext());

  for (uint64_t argIdx = 0; argIdx < numArgs; ++argIdx) {
    Type argType;
    LocationAttr argLoc = unknownLoc;

    if (locationMayBeElided) {
      uint64_t typeIdx;
      bool hasLoc;
      if (failed(reader.parseVarIntWithFlag(typeIdx, hasLoc)))
        return failure();
      argType = attrTypeReader.resolveType(typeIdx);
      if (!argType)
        return failure();
      if (hasLoc && failed(parseAttribute(reader, argLoc)))
        return failure();
    } else {
      // Both the type and the location are always encoded.
      if (failed(parseType(reader, argType)))
        return failure();
      if (failed(parseAttribute(reader, argLoc)))
        return failure();
    }

    types.push_back(argType);
    locations.push_back(argLoc);
  }

  block->addArguments(types, locations);
  return defineValues(reader, block->getArguments());
}
/// Parse an operand as an index into the values of the current scope. If the
/// referenced value is not defined yet, a forward reference is created in its
/// slot. Returns a null value on failure.
Value BytecodeReader::Impl::parseOperand(EncodingReader &reader) {
  std::vector<Value> &scopeValues = valueScopes.back().values;

  Value *slot = nullptr;
  if (failed(parseEntry(reader, scopeValues, slot, "value")))
    return Value();

  Value &referenced = *slot;
  if (!referenced)
    referenced = createForwardRef();
  return referenced;
}
/// Assign the next value IDs of the current scope to `newValues`. Any slot
/// that already holds a forward reference has its uses redirected to the new
/// value, and the placeholder operation is made available for reuse.
LogicalResult BytecodeReader::Impl::defineValues(EncodingReader &reader,
                                                 ValueRange newValues) {
  ValueScope &scope = valueScopes.back();
  std::vector<Value> &slots = scope.values;

  // `nextID` is a reference: advancing it updates the scope.
  unsigned &nextID = scope.nextValueIDs.back();
  unsigned valueIDEnd = nextID + newValues.size();
  if (valueIDEnd > slots.size()) {
    return reader.emitError(
        "value index range was outside of the expected range for "
        "the parent region, got [",
        nextID, ", ", valueIDEnd, "), but the maximum index was ",
        slots.size() - 1);
  }

  for (Value newValue : newValues) {
    Value oldValue = std::exchange(slots[nextID], newValue);
    ++nextID;

    // Nothing referenced this value before its definition.
    if (!oldValue)
      continue;

    // The only way a slot can be filled before its definition is through a
    // forward reference placeholder.
    Operation *forwardRefOp = oldValue.getDefiningOp();
    assert(forwardRefOp && forwardRefOp->getBlock() == &forwardRefOps &&
           "value index was already defined?");

    oldValue.replaceAllUsesWith(newValue);
    forwardRefOp->moveBefore(&openForwardRefOps, openForwardRefOps.end());
  }
  return success();
}
/// Return a placeholder value for an operand that has not been defined yet.
/// A previously resolved placeholder operation is recycled when available;
/// otherwise a new one is created.
Value BytecodeReader::Impl::createForwardRef() {
  if (openForwardRefOps.empty()) {
    forwardRefOps.push_back(Operation::create(forwardRefOpState));
  } else {
    Operation *recycled = &openForwardRefOps.back();
    recycled->moveBefore(&forwardRefOps, forwardRefOps.end());
  }
  // Either way, the placeholder to use is now the last one in the block.
  return forwardRefOps.back().getResult(0);
}
/// Asserts that no operation is left waiting to be materialized when the
/// reader is destroyed.
BytecodeReader::~BytecodeReader() {
  assert(getNumOpsToMaterialize() == 0);
}
/// Create a reader for `buffer`. Diagnostics are attached to a file location
/// built from the buffer identifier, at line 0 and column 0.
BytecodeReader::BytecodeReader(
    llvm::MemoryBufferRef buffer, const ParserConfig &config, bool lazyLoading,
    const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
  MLIRContext *context = config.getContext();
  Location sourceFileLoc = FileLineColLoc::get(
      context, buffer.getBufferIdentifier(), /*line=*/0, /*column=*/0);
  impl = std::make_unique<Impl>(sourceFileLoc, config, lazyLoading, buffer,
                                bufferOwnerRef);
}
/// Read the bytecode into `block`; forwards to the implementation.
LogicalResult BytecodeReader::readTopLevel(
    Block *block, llvm::function_ref<bool(Operation *)> lazyOpsCallback) {
  LogicalResult status = impl->read(block, lazyOpsCallback);
  return status;
}
/// Number of operations still waiting to be materialized; forwards to the
/// implementation.
int64_t BytecodeReader::getNumOpsToMaterialize() const {
  const int64_t numPending = impl->getNumOpsToMaterialize();
  return numPending;
}
/// Whether `op` is waiting to be materialized; forwards to the
/// implementation.
bool BytecodeReader::isMaterializable(Operation *op) {
  const bool isPending = impl->isMaterializable(op);
  return isPending;
}
/// Materialize the regions of `op`; forwards to the implementation.
LogicalResult BytecodeReader::materialize(
    Operation *op, llvm::function_ref<bool(Operation *)> lazyOpsCallback) {
  LogicalResult status = impl->materialize(op, lazyOpsCallback);
  return status;
}
/// Materialize or erase every pending operation as decided by
/// `shouldMaterialize`; forwards to the implementation.
LogicalResult
BytecodeReader::finalize(function_ref<bool(Operation *)> shouldMaterialize) {
  LogicalResult status = impl->finalize(shouldMaterialize);
  return status;
}
/// Whether `buffer` starts with the MLIR bytecode magic number.
bool mlir::isBytecode(llvm::MemoryBufferRef buffer) {
  StringRef contents = buffer.getBuffer();
  return contents.starts_with("ML\xefR");
}
/// Shared implementation of `readBytecodeFile`: check the magic number, then
/// read the whole buffer into `block` without lazy loading.
static LogicalResult
readBytecodeFileImpl(llvm::MemoryBufferRef buffer, Block *block,
                     const ParserConfig &config,
                     const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
  MLIRContext *context = config.getContext();
  Location sourceFileLoc = FileLineColLoc::get(
      context, buffer.getBufferIdentifier(), /*line=*/0, /*column=*/0);

  const bool looksLikeBytecode = isBytecode(buffer);
  if (!looksLikeBytecode) {
    return emitError(sourceFileLoc,
                     "input buffer is not an MLIR bytecode file");
  }

  BytecodeReader::Impl reader(sourceFileLoc, config, /*lazyLoading=*/false,
                              buffer, bufferOwnerRef);
  return reader.read(block, /*lazyOps=*/nullptr);
}
/// Read the bytecode in `buffer` into `block`. No buffer owner is supplied.
LogicalResult mlir::readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block,
                                     const ParserConfig &config) {
  return readBytecodeFileImpl(buffer, block, config,
                              /*bufferOwnerRef=*/{});
}
/// Read the bytecode in the main file buffer of `sourceMgr` into `block`. The
/// source manager is passed along as the owner of the buffer.
LogicalResult
mlir::readBytecodeFile(const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
                       Block *block, const ParserConfig &config) {
  const unsigned mainFileID = sourceMgr->getMainFileID();
  return readBytecodeFileImpl(*sourceMgr->getMemoryBuffer(mainFileID), block,
                              config, sourceMgr);
}