#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Mutex.h"
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/Signals.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
using namespace mlir;
using namespace mlir::detail;
/// Wrap an attribute. Its opaque pointer is kept as an integer so that every
/// argument kind can share the same `opaqueVal` storage.
DiagnosticArgument::DiagnosticArgument(Attribute attr)
    : kind(DiagnosticArgumentKind::Attribute),
      opaqueVal(reinterpret_cast<intptr_t>(attr.getAsOpaquePointer())) {}

/// Wrap a type, stored through its opaque pointer like attributes are.
DiagnosticArgument::DiagnosticArgument(Type ty)
    : kind(DiagnosticArgumentKind::Type),
      opaqueVal(reinterpret_cast<intptr_t>(ty.getAsOpaquePointer())) {}

/// Recover the attribute held by this argument; only valid for the Attribute
/// kind.
Attribute DiagnosticArgument::getAsAttribute() const {
  assert(getKind() == DiagnosticArgumentKind::Attribute);
  const void *rawPtr = reinterpret_cast<const void *>(opaqueVal);
  return Attribute::getFromOpaquePointer(rawPtr);
}

/// Recover the type held by this argument; only valid for the Type kind.
Type DiagnosticArgument::getAsType() const {
  assert(getKind() == DiagnosticArgumentKind::Type);
  const void *rawPtr = reinterpret_cast<const void *>(opaqueVal);
  return Type::getFromOpaquePointer(rawPtr);
}
/// Render this argument to `os`. Types are surrounded by single quotes; every
/// other kind is streamed as-is.
void DiagnosticArgument::print(raw_ostream &os) const {
  switch (kind) {
  case DiagnosticArgumentKind::Attribute:
    os << getAsAttribute();
    return;
  case DiagnosticArgumentKind::Type:
    os << '\'' << getAsType() << '\'';
    return;
  case DiagnosticArgumentKind::String:
    os << getAsString();
    return;
  case DiagnosticArgumentKind::Integer:
    os << getAsInteger();
    return;
  case DiagnosticArgumentKind::Unsigned:
    os << getAsUnsigned();
    return;
  case DiagnosticArgumentKind::Double:
    os << getAsDouble();
    return;
  }
}
/// Render `val` into a character buffer owned by `strings` and return a
/// StringRef to that buffer. Diagnostic arguments only hold StringRefs, so the
/// characters must be copied out of the (temporary) Twine. An empty rendering
/// is returned directly without allocating.
static StringRef twineToStrRef(const Twine &val,
                               std::vector<std::unique_ptr<char[]>> &strings) {
  SmallString<64> scratch;
  StringRef rendered = val.toStringRef(scratch);
  if (rendered.empty())
    return rendered;

  const size_t length = rendered.size();
  std::unique_ptr<char[]> owned(new char[length]);
  memcpy(owned.get(), rendered.data(), length);
  char *storage = owned.get();
  strings.push_back(std::move(owned));
  return StringRef(storage, length);
}
/// Append a single character as a string argument.
Diagnostic &Diagnostic::operator<<(char val) {
  return *this << Twine(val);
}

/// Append a Twine; its rendered text is copied into storage owned by this
/// diagnostic.
Diagnostic &Diagnostic::operator<<(const Twine &val) {
  arguments.emplace_back(twineToStrRef(val, strings));
  return *this;
}

/// Rvalue Twine overload; behaves exactly like the const-reference one.
Diagnostic &Diagnostic::operator<<(Twine &&val) {
  arguments.emplace_back(twineToStrRef(val, strings));
  return *this;
}

/// Append a string attribute as an argument.
Diagnostic &Diagnostic::operator<<(StringAttr val) {
  arguments.push_back(DiagnosticArgument(val));
  return *this;
}

/// Append the name of an operation as a string argument.
Diagnostic &Diagnostic::operator<<(OperationName val) {
  StringRef opName = val.getStringRef();
  arguments.emplace_back(opName);
  return *this;
}
/// Adapt `flags` for printing inside a diagnostic. The result always uses a
/// local scope and elides large elements attributes. It additionally forces
/// the generic op form when `severity` is Error.
static OpPrintingFlags adjustPrintingFlags(OpPrintingFlags flags,
                                           DiagnosticSeverity severity) {
  flags.useLocalScope().elideLargeElementsAttrs();
  const bool isError = severity == DiagnosticSeverity::Error;
  if (isError)
    flags.printGenericOpForm();
  return flags;
}
/// Append an operation printed with default (diagnostic-adjusted) flags.
Diagnostic &Diagnostic::operator<<(Operation &op) {
  return appendOp(op, OpPrintingFlags());
}

/// Print `op` using `flags` adjusted for this diagnostic's severity and append
/// the text. If the printed form spans multiple lines, a newline is appended
/// first so the operation starts on its own line.
Diagnostic &Diagnostic::appendOp(Operation &op, const OpPrintingFlags &flags) {
  std::string printed;
  llvm::raw_string_ostream printStream(printed);
  op.print(printStream, adjustPrintingFlags(flags, severity));
  const bool isMultiLine = printed.find('\n') != std::string::npos;
  if (isMultiLine)
    *this << '\n';
  return *this << printStream.str();
}
/// Append a value printed with default flags adjusted for this diagnostic's
/// severity.
Diagnostic &Diagnostic::operator<<(Value val) {
  OpPrintingFlags flags = adjustPrintingFlags(OpPrintingFlags(), severity);
  std::string printed;
  llvm::raw_string_ostream printStream(printed);
  val.print(printStream, flags);
  return *this << printStream.str();
}
void Diagnostic::print(raw_ostream &os) const {
for (auto &arg : getArguments())
arg.print(os);
}
std::string Diagnostic::str() const {
std::string str;
llvm::raw_string_ostream os(str);
print(os);
return os.str();
}
/// Attach a new note to this diagnostic and return it for further streaming.
/// If `noteLoc` is not provided, the note reuses this diagnostic's location.
/// A note cannot itself receive notes.
Diagnostic &Diagnostic::attachNote(std::optional<Location> noteLoc) {
  assert(severity != DiagnosticSeverity::Note &&
         "cannot attach a note to a note");
  Location where = noteLoc.value_or(loc);
  notes.push_back(
      std::make_unique<Diagnostic>(where, DiagnosticSeverity::Note));
  return *notes.back();
}
/// A plain Diagnostic always converts to failure.
Diagnostic::operator LogicalResult() const { return failure(); }

/// An in-flight diagnostic converts to failure only while isActive() holds.
InFlightDiagnostic::operator LogicalResult() const {
  return failure(isActive());
}

/// If the diagnostic is still in flight, hand it to its owning engine and drop
/// the owner. The held diagnostic is released in every case.
void InFlightDiagnostic::report() {
  if (!isInFlight()) {
    impl.reset();
    return;
  }
  owner->emit(std::move(*impl));
  owner = nullptr;
  impl.reset();
}

/// Forget the owner so that the diagnostic is not reported.
void InFlightDiagnostic::abandon() { owner = nullptr; }
namespace mlir::detail {
/// Private state of DiagnosticEngine: the registered handlers and the mutex
/// serializing access to them.
struct DiagnosticEngineImpl {
  /// Dispatch `diag` to the registered handlers.
  void emit(Diagnostic &&diag);

  /// Taken by emit, registerHandler and eraseHandler.
  llvm::sys::SmartMutex<true> mutex;

  /// Registered handlers, kept in registration order and keyed by their id.
  llvm::SmallMapVector<DiagnosticEngine::HandlerID, DiagnosticEngine::HandlerTy,
                       2>
      handlers;

  /// The next id to hand out. It starts at 1, presumably so that 0 can mean
  /// "no handler" (handler ids are tested for truthiness elsewhere).
  DiagnosticEngine::HandlerID uniqueHandlerId = 1;
};
} // namespace mlir::detail
/// Offer `diag` to the registered handlers, newest first, stopping at the
/// first one that succeeds. If no handler accepts it, errors are written to
/// stderr as "<loc>: error: <msg>", with the location omitted when unknown.
/// Unhandled diagnostics of other severities are dropped.
void DiagnosticEngineImpl::emit(Diagnostic &&diag) {
  llvm::sys::SmartScopedLock<true> lock(mutex);

  for (auto &entry : llvm::reverse(handlers)) {
    auto &handler = entry.second;
    if (succeeded(handler(diag)))
      return;
  }

  if (diag.getSeverity() == DiagnosticSeverity::Error) {
    raw_ostream &errStream = llvm::errs();
    if (!llvm::isa<UnknownLoc>(diag.getLocation()))
      errStream << diag.getLocation() << ": ";
    errStream << "error: " << diag << '\n';
    errStream.flush();
  }
}
DiagnosticEngine::DiagnosticEngine()
    : impl(std::make_unique<DiagnosticEngineImpl>()) {}
DiagnosticEngine::~DiagnosticEngine() = default;

/// Register `handler` under a fresh id and return that id. The id can be
/// passed to eraseHandler later.
auto DiagnosticEngine::registerHandler(HandlerTy handler) -> HandlerID {
  llvm::sys::SmartScopedLock<true> lock(impl->mutex);
  const HandlerID newId = impl->uniqueHandlerId++;
  impl->handlers.insert({newId, std::move(handler)});
  return newId;
}

/// Remove the handler registered under `handlerID`.
void DiagnosticEngine::eraseHandler(HandlerID handlerID) {
  llvm::sys::SmartScopedLock<true> lock(impl->mutex);
  impl->handlers.erase(handlerID);
}

/// Route `diag` through the registered handlers. Notes must be attached to a
/// parent diagnostic rather than emitted directly.
void DiagnosticEngine::emit(Diagnostic &&diag) {
  assert(diag.getSeverity() != DiagnosticSeverity::Note &&
         "notes should not be emitted directly");
  impl->emit(std::move(diag));
}
/// Start an in-flight diagnostic at `location` with `severity`, seeded with
/// `message` when it is non-empty. If the context requests it, a note holding
/// the current stack trace is attached.
static InFlightDiagnostic
emitDiag(Location location, DiagnosticSeverity severity, const Twine &message) {
  MLIRContext *ctx = location->getContext();
  InFlightDiagnostic diag = ctx->getDiagEngine().emit(location, severity);
  if (!message.isTriviallyEmpty())
    diag << message;

  if (!ctx->shouldPrintStackTraceOnDiagnostic())
    return diag;

  std::string backtrace;
  {
    // Scope the stream so that it is finished before `backtrace` is read.
    llvm::raw_string_ostream traceStream(backtrace);
    llvm::sys::PrintStackTrace(traceStream);
  }
  if (!backtrace.empty())
    diag.attachNote() << "diagnostic emitted with trace:\n" << backtrace;
  return diag;
}
/// Emit an empty error at `loc`; arguments may be streamed in afterwards.
InFlightDiagnostic mlir::emitError(Location loc) {
  return emitDiag(loc, DiagnosticSeverity::Error, Twine());
}
/// Emit an error at `loc` starting with `message`.
InFlightDiagnostic mlir::emitError(Location loc, const Twine &message) {
  return emitDiag(loc, DiagnosticSeverity::Error, message);
}

/// Emit an empty warning at `loc`.
InFlightDiagnostic mlir::emitWarning(Location loc) {
  return emitDiag(loc, DiagnosticSeverity::Warning, Twine());
}
/// Emit a warning at `loc` starting with `message`.
InFlightDiagnostic mlir::emitWarning(Location loc, const Twine &message) {
  return emitDiag(loc, DiagnosticSeverity::Warning, message);
}

/// Emit an empty remark at `loc`.
InFlightDiagnostic mlir::emitRemark(Location loc) {
  return emitDiag(loc, DiagnosticSeverity::Remark, Twine());
}
/// Emit a remark at `loc` starting with `message`.
InFlightDiagnostic mlir::emitRemark(Location loc, const Twine &message) {
  return emitDiag(loc, DiagnosticSeverity::Remark, message);
}
/// Unregister this handler from the context's engine if one was installed.
ScopedDiagnosticHandler::~ScopedDiagnosticHandler() {
  if (!handlerID)
    return;
  ctx->getDiagEngine().eraseHandler(handlerID);
}
namespace mlir::detail {
/// Private state of SourceMgrDiagnosticHandler: a cache mapping filenames to
/// SourceMgr buffer ids.
struct SourceMgrDiagnosticHandlerImpl {
  /// Return the buffer id for `filename`. The lookup tries, in order:
  ///   1. the cache,
  ///   2. buffers already in `mgr` whose identifier equals `filename`,
  ///   3. loading the file via `mgr.AddIncludeFile`.
  /// The id obtained by loading is cached even when it is 0, which callers of
  /// this method treat as "no buffer".
  unsigned getSourceMgrBufferIDForFile(llvm::SourceMgr &mgr,
                                       StringRef filename) {
    auto cached = filenameToBufId.find(filename);
    if (cached != filenameToBufId.end())
      return cached->second;

    // SourceMgr buffer ids are 1-based.
    const unsigned numBuffers = mgr.getNumBuffers();
    for (unsigned id = 1; id <= numBuffers; ++id) {
      if (mgr.getMemoryBuffer(id)->getBufferIdentifier() == filename)
        return filenameToBufId[filename] = id;
    }

    std::string ignoredIncludePath;
    const unsigned loadedId =
        mgr.AddIncludeFile(std::string(filename), SMLoc(), ignoredIncludePath);
    filenameToBufId[filename] = loadedId;
    return loadedId;
  }

  /// Cache from filename to SourceMgr buffer id.
  llvm::StringMap<unsigned> filenameToBufId;
};
} // namespace mlir::detail
/// Find a call-site location within `loc`. The search looks through the child
/// of a NameLoc and through each sub-location of a FusedLoc, where the first
/// match wins. Returns std::nullopt if none is found.
static std::optional<CallSiteLoc> getCallSiteLoc(Location loc) {
  if (auto callLoc = dyn_cast<CallSiteLoc>(loc))
    return callLoc;
  if (auto nameLoc = dyn_cast<NameLoc>(loc))
    return getCallSiteLoc(nameLoc.getChildLoc());
  if (auto fusedLoc = dyn_cast<FusedLoc>(loc)) {
    for (Location subLoc : fusedLoc.getLocations())
      if (std::optional<CallSiteLoc> found = getCallSiteLoc(subLoc))
        return found;
  }
  return std::nullopt;
}
/// Map an MLIR diagnostic severity onto the matching SourceMgr diagnostic
/// kind.
static llvm::SourceMgr::DiagKind getDiagKind(DiagnosticSeverity kind) {
  switch (kind) {
  case DiagnosticSeverity::Error:
    return llvm::SourceMgr::DK_Error;
  case DiagnosticSeverity::Warning:
    return llvm::SourceMgr::DK_Warning;
  case DiagnosticSeverity::Remark:
    return llvm::SourceMgr::DK_Remark;
  case DiagnosticSeverity::Note:
    return llvm::SourceMgr::DK_Note;
  }
  llvm_unreachable("Unknown DiagnosticSeverity");
}
/// Build a handler that prints diagnostics through `mgr` to `os`.
/// `shouldShowLocFn` optionally filters which locations are displayed.
SourceMgrDiagnosticHandler::SourceMgrDiagnosticHandler(
    llvm::SourceMgr &mgr, MLIRContext *ctx, raw_ostream &os,
    ShouldShowLocFn &&shouldShowLocFn)
    : ScopedDiagnosticHandler(ctx), mgr(mgr), os(os),
      shouldShowLocFn(std::move(shouldShowLocFn)),
      impl(std::make_unique<SourceMgrDiagnosticHandlerImpl>()) {
  // Every diagnostic reaching this handler is printed via emitDiagnostic.
  setHandler([this](Diagnostic &diag) { emitDiagnostic(diag); });
}

/// Convenience constructor that prints to stderr.
SourceMgrDiagnosticHandler::SourceMgrDiagnosticHandler(
    llvm::SourceMgr &mgr, MLIRContext *ctx, ShouldShowLocFn &&shouldShowLocFn)
    : SourceMgrDiagnosticHandler(mgr, ctx, llvm::errs(),
                                 std::move(shouldShowLocFn)) {}

SourceMgrDiagnosticHandler::~SourceMgrDiagnosticHandler() = default;
/// Print a single message at `loc` with severity `kind`. The output form
/// depends on the location:
///   - No FileLineColLoc inside `loc`: print "<loc>: <message>" with no source
///     position; the "<loc>: " prefix is omitted for unknown locations.
///   - `displaySourceLine` is set and the position resolves to a buffer in
///     `mgr`: let SourceMgr print the message with the source line.
///   - Otherwise: print the message prefixed by "file:line:col", without the
///     source line.
void SourceMgrDiagnosticHandler::emitDiagnostic(Location loc, Twine message,
                                                DiagnosticSeverity kind,
                                                bool displaySourceLine) {
  const llvm::SourceMgr::DiagKind diagKind = getDiagKind(kind);
  FileLineColLoc fileLoc = loc->findInstanceOf<FileLineColLoc>();

  if (!fileLoc) {
    std::string text;
    llvm::raw_string_ostream textStream(text);
    if (!llvm::isa<UnknownLoc>(loc))
      textStream << loc << ": ";
    textStream << message;
    mgr.PrintMessage(os, SMLoc(), diagKind, textStream.str());
    return;
  }

  if (displaySourceLine) {
    SMLoc smloc = convertLocToSMLoc(fileLoc);
    if (smloc.isValid()) {
      mgr.PrintMessage(os, smloc, diagKind, message);
      return;
    }
  }

  // Build the "file:line:col" prefix by hand rather than going through the
  // SMDiagnostic constructor that takes a location.
  std::string position;
  llvm::raw_string_ostream positionStream(position);
  positionStream << fileLoc.getFilename().getValue() << ":" << fileLoc.getLine()
                 << ":" << fileLoc.getColumn();
  llvm::SMDiagnostic smDiag(positionStream.str(), diagKind, message.str());
  smDiag.print(nullptr, os);
}
/// Print `diag`, its call stack and its notes through the source manager.
///
/// The main message goes at the first showable location (see findLocToShow).
/// If the location contains a call site, up to `callStackLimit` caller frames
/// are printed as "called from" notes. Attached notes follow. A note shows its
/// source line only when its location differs from the previously used one.
void SourceMgrDiagnosticHandler::emitDiagnostic(Diagnostic &diag) {
  SmallVector<std::pair<Location, StringRef>> locationStack;
  auto addLocToStack = [&](Location loc, StringRef locContext) {
    if (std::optional<Location> showableLoc = findLocToShow(loc))
      locationStack.emplace_back(*showableLoc, locContext);
  };

  // Start with the diagnostic's own location.
  Location loc = diag.getLocation();
  addLocToStack(loc, /*locContext=*/{});

  // Walk the caller chain of a call-site location, bounded by callStackLimit.
  if (auto callLoc = getCallSiteLoc(loc)) {
    loc = callLoc->getCaller();
    for (unsigned curDepth = 0; curDepth < callStackLimit; ++curDepth) {
      addLocToStack(loc, "called from");
      if ((callLoc = getCallSiteLoc(loc)))
        loc = callLoc->getCaller();
      else
        break;
    }
  }

  if (locationStack.empty()) {
    // Nothing was deemed showable: fall back to the diagnostic's location.
    emitDiagnostic(diag.getLocation(), diag.str(), diag.getSeverity());
  } else {
    emitDiagnostic(locationStack.front().first, diag.str(), diag.getSeverity());
    for (auto &it : llvm::drop_begin(locationStack))
      emitDiagnostic(it.first, it.second, DiagnosticSeverity::Note);
  }

  // Emit each attached note. The source line is displayed only when the note's
  // location differs from the previously used location.
  for (auto &note : diag.getNotes()) {
    emitDiagnostic(note.getLocation(), note.str(), note.getSeverity(),
                   /*displaySourceLine=*/loc != note.getLocation());
    loc = note.getLocation();
  }
}
/// Return the source manager's buffer for `filename`, loading the file if
/// needed. Returns nullptr when no buffer id could be obtained.
const llvm::MemoryBuffer *
SourceMgrDiagnosticHandler::getBufferForFile(StringRef filename) {
  const unsigned id = impl->getSourceMgrBufferIDForFile(mgr, filename);
  return id ? mgr.getMemoryBuffer(id) : nullptr;
}
/// Choose the location to display for `loc`, or return std::nullopt if nothing
/// should be shown.
///
/// Without a `shouldShowLocFn` filter, `loc` is returned unchanged. Otherwise
/// the filter is applied to `loc` itself. If it passes, the location kinds
/// handled below are then searched recursively.
std::optional<Location>
SourceMgrDiagnosticHandler::findLocToShow(Location loc) {
  // No filter installed: every location is shown as-is.
  if (!shouldShowLocFn)
    return loc;
  // The filter rejected this location, so nothing nested in it is considered.
  if (!shouldShowLocFn(loc))
    return std::nullopt;

  // NOTE(review): this TypeSwitch has no Default case, so a location kind not
  // listed below is not handled here; presumably TypeSwitch asserts in that
  // case — confirm before adding new location kinds.
  return TypeSwitch<LocationAttr, std::optional<Location>>(loc)
      .Case([&](CallSiteLoc callLoc) -> std::optional<Location> {
        // Only the callee is searched; the callers are printed separately as
        // "called from" notes by emitDiagnostic(Diagnostic &).
        return findLocToShow(callLoc.getCallee());
      })
      .Case([&](FileLineColLoc) -> std::optional<Location> { return loc; })
      .Case([&](FusedLoc fusedLoc) -> std::optional<Location> {
        // Show the first sub-location that is itself showable, rather than
        // the fused location as a whole.
        for (Location childLoc : fusedLoc.getLocations())
          if (std::optional<Location> showableLoc = findLocToShow(childLoc))
            return showableLoc;
        return std::nullopt;
      })
      .Case([&](NameLoc nameLoc) -> std::optional<Location> {
        return findLocToShow(nameLoc.getChildLoc());
      })
      .Case([&](OpaqueLoc opaqueLoc) -> std::optional<Location> {
        // Opaque locations are shown via their fallback location.
        return findLocToShow(opaqueLoc.getFallbackLocation());
      })
      .Case([](UnknownLoc) -> std::optional<Location> {
        // Unknown locations are never shown.
        return std::nullopt;
      });
}
/// Translate a file/line/column location into an SMLoc inside `mgr`. An
/// invalid SMLoc is returned in two cases: when the line or column is 0
/// (unknown), or when the file cannot be mapped to a buffer.
SMLoc SourceMgrDiagnosticHandler::convertLocToSMLoc(FileLineColLoc loc) {
  const unsigned line = loc.getLine();
  const unsigned column = loc.getColumn();
  if (line == 0 || column == 0)
    return SMLoc();

  const unsigned bufferId =
      impl->getSourceMgrBufferIDForFile(mgr, loc.getFilename());
  return bufferId ? mgr.FindLocForLineAndColumn(bufferId, line, column)
                  : SMLoc();
}
namespace mlir::detail {
/// One `expected-*` directive found in a source buffer.
struct ExpectedDiag {
  ExpectedDiag(DiagnosticSeverity kind, unsigned lineNo, SMLoc fileLoc,
               StringRef substring)
      : kind(kind), lineNo(lineNo), fileLoc(fileLoc), substring(substring) {}

  /// Print `msg` as an error at this directive, highlighting the expected
  /// text. Always returns failure so the result can be propagated directly.
  LogicalResult emitError(raw_ostream &os, llvm::SourceMgr &mgr,
                          const Twine &msg) {
    SMLoc highlightEnd =
        SMLoc::getFromPointer(fileLoc.getPointer() + substring.size());
    SMRange highlight(fileLoc, highlightEnd);
    mgr.PrintMessage(os, fileLoc, llvm::SourceMgr::DK_Error, msg, highlight);
    return failure();
  }

  /// Check `str` against this directive. A compiled regex (`-re` form) is used
  /// when present; otherwise a plain substring search is done.
  bool match(StringRef str) const {
    return substringRegex ? substringRegex->match(str)
                          : str.contains(substring);
  }

  /// Compile `substring` into a regex. `{{...}}` blocks are taken verbatim as
  /// capture groups; everything else is escaped. Malformed input is reported
  /// through `os`/`mgr`.
  LogicalResult computeRegex(raw_ostream &os, llvm::SourceMgr &mgr) {
    std::string pattern;
    llvm::raw_string_ostream patternOS(pattern);
    StringRef remaining = substring;
    while (!remaining.empty()) {
      const size_t openPos = remaining.find("{{");
      if (openPos == StringRef::npos) {
        patternOS << llvm::Regex::escape(remaining);
        break;
      }
      patternOS << llvm::Regex::escape(remaining.take_front(openPos));
      remaining = remaining.drop_front(openPos + 2);

      const size_t closePos = remaining.find("}}");
      if (closePos == StringRef::npos)
        return emitError(os, mgr, "found start of regex with no end '}}'");
      StringRef regexBody = remaining.take_front(closePos);

      std::string regexError;
      if (!llvm::Regex(regexBody).isValid(regexError))
        return emitError(os, mgr, "invalid regex: " + regexError);

      patternOS << '(' << regexBody << ')';
      remaining = remaining.drop_front(closePos + 2);
    }
    substringRegex = llvm::Regex(patternOS.str());
    return success();
  }

  /// Expected severity.
  DiagnosticSeverity kind;
  /// 1-based line on which the diagnostic is expected.
  unsigned lineNo;
  /// Position of the `expected-*` text in the buffer, used for error reporting.
  SMLoc fileLoc;
  /// Set once a matching diagnostic has been seen.
  bool matched = false;
  /// Text the diagnostic message must contain (or the raw `-re` pattern).
  StringRef substring;
  /// Compiled regex for `-re` directives.
  std::optional<llvm::Regex> substringRegex;
};

/// Private state of SourceMgrDiagnosticVerifierHandler.
struct SourceMgrDiagnosticVerifierHandlerImpl {
  /// Return the previously computed directives for `bufName`, if any.
  std::optional<MutableArrayRef<ExpectedDiag>>
  getExpectedDiags(StringRef bufName);

  /// Parse the `expected-*` directives of `buf` and record them.
  MutableArrayRef<ExpectedDiag>
  computeExpectedDiags(raw_ostream &os, llvm::SourceMgr &mgr,
                       const llvm::MemoryBuffer *buf);

  /// Overall verification result; becomes failure on any mismatch.
  LogicalResult status = success();

  /// Expected diagnostics keyed by buffer identifier.
  llvm::StringMap<SmallVector<ExpectedDiag, 2>> expectedDiagsPerFile;

  /// Matches a directive. Groups: 1 = severity, 2 = "-re", 3 = "@designator",
  /// 4 = designator body, 5 = expected text.
  llvm::Regex expected =
      llvm::Regex("expected-(error|note|remark|warning)(-re)? "
                  "*(@([+-][0-9]+|above|below))? *{{(.*)}}$");
};
} // namespace mlir::detail
/// Return the lowercase name used for `kind` in `expected-*` directives and in
/// verifier messages.
static StringRef getDiagKindStr(DiagnosticSeverity kind) {
  switch (kind) {
  case DiagnosticSeverity::Error:
    return "error";
  case DiagnosticSeverity::Warning:
    return "warning";
  case DiagnosticSeverity::Remark:
    return "remark";
  case DiagnosticSeverity::Note:
    return "note";
  }
  llvm_unreachable("Unknown DiagnosticSeverity");
}
/// Look up the directives already computed for `bufName`. Returns
/// std::nullopt if that buffer has not been scanned yet.
std::optional<MutableArrayRef<ExpectedDiag>>
SourceMgrDiagnosticVerifierHandlerImpl::getExpectedDiags(StringRef bufName) {
  auto found = expectedDiagsPerFile.find(bufName);
  if (found == expectedDiagsPerFile.end())
    return std::nullopt;
  return MutableArrayRef<ExpectedDiag>(found->second);
}
/// Scan `buf` for `expected-*` directives and record them in
/// `expectedDiagsPerFile`, keyed by the buffer identifier. Returns the
/// directives recorded for this buffer, or an empty array when `buf` is null.
///
/// Line designators:
///   `@+N` / `@-N` : expected N lines below / above the directive.
///   `@above`      : the closest preceding line that is not a directive.
///   `@below`      : the next line that is not a directive. If there is none,
///                   the directive falls back to the buffer's line count.
/// Malformed directives are reported to `os` and set `status` to failure.
MutableArrayRef<ExpectedDiag>
SourceMgrDiagnosticVerifierHandlerImpl::computeExpectedDiags(
    raw_ostream &os, llvm::SourceMgr &mgr, const llvm::MemoryBuffer *buf) {
  if (!buf)
    return {};
  auto &expectedDiags = expectedDiagsPerFile[buf->getBufferIdentifier()];

  // 0-based index of the last line that did not contain a directive.
  unsigned lastNonDesignatorLine = 0;

  // Indices into `expectedDiags` of `@below` directives waiting for the next
  // non-directive line.
  SmallVector<unsigned, 1> designatorsForNextLine;

  SmallVector<StringRef, 100> lines;
  buf->getBuffer().split(lines, '\n');
  for (unsigned lineNo = 0, e = lines.size(); lineNo < e; ++lineNo) {
    SmallVector<StringRef, 4> matches;
    if (!expected.match(lines[lineNo].rtrim(), &matches)) {
      // A plain line: resolve any pending `@below` directives to it.
      for (unsigned diagIndex : designatorsForNextLine)
        expectedDiags[diagIndex].lineNo = lineNo + 1;
      designatorsForNextLine.clear();
      lastNonDesignatorLine = lineNo;
      continue;
    }

    // Point at the start of the `expected-*` text.
    SMLoc expectedStart = SMLoc::getFromPointer(matches[0].data());

    DiagnosticSeverity kind;
    if (matches[1] == "error")
      kind = DiagnosticSeverity::Error;
    else if (matches[1] == "warning")
      kind = DiagnosticSeverity::Warning;
    else if (matches[1] == "remark")
      kind = DiagnosticSeverity::Remark;
    else {
      assert(matches[1] == "note");
      kind = DiagnosticSeverity::Note;
    }
    ExpectedDiag record(kind, lineNo + 1, expectedStart, matches[5]);

    // `expected-*-re` directives carry `{{regex}}` blocks.
    if (!matches[2].empty() && failed(record.computeRegex(os, mgr))) {
      status = failure();
      continue;
    }

    StringRef offsetMatch = matches[3];
    if (!offsetMatch.empty()) {
      // Drop the leading '@'; the regex guarantees one of the three forms.
      offsetMatch = offsetMatch.drop_front(1);

      if (offsetMatch[0] == '+' || offsetMatch[0] == '-') {
        // getAsInteger returns true on failure (e.g. digits overflowing an
        // int) and then leaves `offset` untouched, so it must be checked
        // rather than reading an indeterminate value.
        int offset = 0;
        if (offsetMatch.drop_front().getAsInteger(0, offset)) {
          (void)record.emitError(os, mgr,
                                 "invalid line offset '" + offsetMatch + "'");
          status = failure();
          continue;
        }
        if (offsetMatch.front() == '+')
          record.lineNo += offset;
        else
          record.lineNo -= offset;
      } else if (offsetMatch == "above") {
        record.lineNo = lastNonDesignatorLine + 1;
      } else {
        // Keep the check side-effect free so that it can compile away under
        // NDEBUG without changing behavior.
        assert(offsetMatch == "below" && "unexpected line designator");
        designatorsForNextLine.push_back(expectedDiags.size());
        // Fallback used if no non-directive line follows.
        record.lineNo = e;
      }
    }
    expectedDiags.emplace_back(std::move(record));
  }
  return expectedDiags;
}
/// Build a verifier that checks emitted diagnostics against the `expected-*`
/// directives in the buffers of `srcMgr`. Problems are reported to `out`.
SourceMgrDiagnosticVerifierHandler::SourceMgrDiagnosticVerifierHandler(
    llvm::SourceMgr &srcMgr, MLIRContext *ctx, raw_ostream &out)
    : SourceMgrDiagnosticHandler(srcMgr, ctx, out),
      impl(new SourceMgrDiagnosticVerifierHandlerImpl()) {
  // Pre-compute the directives of every buffer already in the source manager
  // (SourceMgr buffer ids are 1-based).
  for (unsigned i = 0, e = mgr.getNumBuffers(); i != e; ++i)
    (void)impl->computeExpectedDiags(out, mgr, mgr.getMemoryBuffer(i + 1));

  // Check the diagnostic and each attached note. The handler outlives this
  // constructor, so only `this` is captured.
  setHandler([this](Diagnostic &diag) {
    process(diag);
    for (auto &note : diag.getNotes())
      process(note);
  });
}

/// Convenience constructor that reports problems to stderr.
SourceMgrDiagnosticVerifierHandler::SourceMgrDiagnosticVerifierHandler(
    llvm::SourceMgr &srcMgr, MLIRContext *ctx)
    : SourceMgrDiagnosticVerifierHandler(srcMgr, ctx, llvm::errs()) {}

/// Report any expected diagnostics that were never produced. The result is
/// discarded here; call verify() explicitly to obtain it.
SourceMgrDiagnosticVerifierHandler::~SourceMgrDiagnosticVerifierHandler() {
  (void)verify();
}
/// Report every expected diagnostic that was never matched, clear the
/// recorded expectations, and return the overall verification status.
LogicalResult SourceMgrDiagnosticVerifierHandler::verify() {
  for (auto &fileEntry : impl->expectedDiagsPerFile) {
    for (ExpectedDiag &expectedDiag : fileEntry.second) {
      if (expectedDiag.matched)
        continue;
      impl->status = expectedDiag.emitError(
          os, mgr,
          "expected " + getDiagKindStr(expectedDiag.kind) + " \"" +
              expectedDiag.substring + "\" was not produced");
    }
  }
  impl->expectedDiagsPerFile.clear();
  return impl->status;
}
/// Verify one diagnostic. A diagnostic with a file/line/column location is
/// matched against the directives. Any other diagnostic is always reported as
/// unexpected and marks verification as failed.
void SourceMgrDiagnosticVerifierHandler::process(Diagnostic &diag) {
  const DiagnosticSeverity kind = diag.getSeverity();
  Location loc = diag.getLocation();
  FileLineColLoc fileLoc = loc->findInstanceOf<FileLineColLoc>();
  if (fileLoc) {
    process(fileLoc, diag.str(), kind);
    return;
  }

  emitDiagnostic(loc, "unexpected " + getDiagKindStr(kind) + ": " + diag.str(),
                 DiagnosticSeverity::Error);
  impl->status = failure();
}
/// Match a diagnostic emitted at `loc` with text `msg` and severity `kind`
/// against the directives for its file. The file's directives are computed
/// on demand if needed.
///
/// - A directive on the same line whose text matches and whose kind agrees
///   is marked as matched.
/// - If only the kind differs, a "near miss" error is printed at the
///   directive.
/// - Otherwise the diagnostic is reported as unexpected.
/// Both failure paths set the status to failure.
void SourceMgrDiagnosticVerifierHandler::process(FileLineColLoc loc,
                                                 StringRef msg,
                                                 DiagnosticSeverity kind) {
  std::optional<MutableArrayRef<ExpectedDiag>> diags =
      impl->getExpectedDiags(loc.getFilename());
  if (!diags)
    diags = impl->computeExpectedDiags(os, mgr,
                                       getBufferForFile(loc.getFilename()));

  const unsigned line = loc.getLine();
  ExpectedDiag *nearMiss = nullptr;
  for (ExpectedDiag &candidate : *diags) {
    if (line != candidate.lineNo || !candidate.match(msg))
      continue;
    if (candidate.kind == kind) {
      candidate.matched = true;
      return;
    }
    // Same line and text but a different severity.
    nearMiss = &candidate;
  }

  if (!nearMiss) {
    emitDiagnostic(loc, "unexpected " + getDiagKindStr(kind) + ": " + msg,
                   DiagnosticSeverity::Error);
  } else {
    mgr.PrintMessage(os, nearMiss->fileLoc, llvm::SourceMgr::DK_Error,
                     "'" + getDiagKindStr(kind) +
                         "' diagnostic emitted when expecting a '" +
                         getDiagKindStr(nearMiss->kind) + "'");
  }
  impl->status = failure();
}
namespace mlir {
namespace detail {
/// Captures diagnostics emitted from threads that registered an order id. On
/// destruction they are re-emitted to the context's engine, sorted by that id,
/// so the final order does not depend on thread scheduling. As a
/// PrettyStackTraceEntry it can also print the diagnostics it currently holds.
struct ParallelDiagnosticHandlerImpl : public llvm::PrettyStackTraceEntry {
  /// A captured diagnostic tagged with the order id of its emitting thread.
  struct ThreadDiagnostic {
    ThreadDiagnostic(size_t id, Diagnostic diag)
        : id(id), diag(std::move(diag)) {}
    /// Order by id only; stable sorting keeps the capture order within an id.
    bool operator<(const ThreadDiagnostic &rhs) const { return id < rhs.id; }
    /// Order id that the emitting thread was registered with.
    size_t id;
    /// The captured diagnostic.
    Diagnostic diag;
  };

  /// Register a handler on `ctx`'s engine. It captures diagnostics only from
  /// threads that currently have an order id.
  ParallelDiagnosticHandlerImpl(MLIRContext *ctx) : context(ctx) {
    handlerID = ctx->getDiagEngine().registerHandler([this](Diagnostic &diag) {
      uint64_t tid = llvm::get_threadid();
      llvm::sys::SmartScopedLock<true> lock(mutex);
      // Untracked thread: decline so that an earlier-registered handler (or
      // the engine's default behavior) processes the diagnostic instead.
      if (!threadToOrderID.count(tid))
        return failure();
      diagnostics.emplace_back(threadToOrderID[tid], std::move(diag));
      return success();
    });
  }

  /// Unregister the handler, then re-emit any captured diagnostics in order.
  ~ParallelDiagnosticHandlerImpl() override {
    // Erase first so that the re-emitted diagnostics are not captured again.
    context->getDiagEngine().eraseHandler(handlerID);
    if (diagnostics.empty())
      return;
    emitDiagnostics([&](Diagnostic &diag) {
      return context->getDiagEngine().emit(std::move(diag));
    });
  }

  /// Stable-sort the captured diagnostics by order id and pass each one to
  /// `emitFn`. The method is const but sorts the `mutable` vector in place.
  void emitDiagnostics(llvm::function_ref<void(Diagnostic &)> emitFn) const {
    std::stable_sort(diagnostics.begin(), diagnostics.end());
    for (ThreadDiagnostic &diag : diagnostics)
      emitFn(diag.diag);
  }

  /// Associate the calling thread with `orderID`; its later diagnostics are
  /// captured under that id.
  void setOrderIDForThread(size_t orderID) {
    uint64_t tid = llvm::get_threadid();
    llvm::sys::SmartScopedLock<true> lock(mutex);
    threadToOrderID[tid] = orderID;
  }
  /// Stop capturing diagnostics from the calling thread.
  void eraseOrderIDForThread() {
    uint64_t tid = llvm::get_threadid();
    llvm::sys::SmartScopedLock<true> lock(mutex);
    threadToOrderID.erase(tid);
  }

  /// PrettyStackTraceEntry hook: print the held diagnostics, sorted, one per
  /// line as "    <loc>: <severity>: <msg>".
  /// NOTE(review): this reads (and sorts) `diagnostics` without taking
  /// `mutex`; presumably it is only invoked when no other thread is emitting
  /// (e.g. during crash reporting) — confirm.
  void print(raw_ostream &os) const override {
    if (diagnostics.empty())
      return;
    os << "In-Flight Diagnostics:\n";
    emitDiagnostics([&](const Diagnostic &diag) {
      os.indent(4);
      // The location prefix is omitted for unknown locations.
      if (!llvm::isa<UnknownLoc>(diag.getLocation()))
        os << diag.getLocation() << ": ";
      switch (diag.getSeverity()) {
      case DiagnosticSeverity::Error:
        os << "error: ";
        break;
      case DiagnosticSeverity::Warning:
        os << "warning: ";
        break;
      case DiagnosticSeverity::Note:
        os << "note: ";
        break;
      case DiagnosticSeverity::Remark:
        os << "remark: ";
        break;
      }
      os << diag << '\n';
    });
  }

  /// Guards `threadToOrderID` and the appends to `diagnostics` made by the
  /// handler.
  llvm::sys::SmartMutex<true> mutex;
  /// Order id of each participating thread, keyed by thread id.
  DenseMap<uint64_t, size_t> threadToOrderID;
  /// Captured diagnostics; `mutable` so that const emitDiagnostics can sort.
  mutable std::vector<ThreadDiagnostic> diagnostics;
  /// Id of the handler registered on the context's engine.
  DiagnosticEngine::HandlerID handlerID = 0;
  /// Context whose diagnostic engine this handler is attached to.
  MLIRContext *context;
};
}
}
/// Public wrapper around ParallelDiagnosticHandlerImpl; all work is delegated
/// to the implementation object.
ParallelDiagnosticHandler::ParallelDiagnosticHandler(MLIRContext *ctx)
    : impl(std::make_unique<ParallelDiagnosticHandlerImpl>(ctx)) {}

ParallelDiagnosticHandler::~ParallelDiagnosticHandler() = default;

/// Capture diagnostics from the calling thread under `orderID`.
void ParallelDiagnosticHandler::setOrderIDForThread(size_t orderID) {
  impl->setOrderIDForThread(orderID);
}

/// Stop capturing diagnostics from the calling thread.
void ParallelDiagnosticHandler::eraseOrderIDForThread() {
  impl->eraseOrderIDForThread();
}