#include "mlir/Support/Timing.h"
#include "mlir/Support/ThreadLocalCache.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/RWMutex.h"
#include "llvm/Support/Threading.h"
#include "llvm/Support/raw_ostream.h"
#include <atomic>
#include <chrono>
#include <optional>
// Bring the MLIR namespaces into scope for the definitions in this file.
using namespace mlir;
using namespace detail;
// Shorthands for the enums nested in `DefaultTimingManager`.
using DisplayMode = DefaultTimingManager::DisplayMode;
using OutputFormat = DefaultTimingManager::OutputFormat;
// Title line of the text report; `OutputTextStrategy::printHeader` prints it
// between two separator lines, padded based on an 80-column width.
constexpr llvm::StringLiteral kTimingDescription =
"... Execution time report ...";
namespace mlir::detail {
/// Private state of a `TimingManager`: the storage used to unique timing
/// identifier strings.
class TimingManagerImpl {
public:
  /// Entry type of the identifier set.
  using IdentifierEntry = llvm::StringMapEntry<std::nullopt_t>;
  /// Per-thread map from identifier string to its entry in `identifiers`.
  using LocalIdentifierMap = llvm::StringMap<IdentifierEntry *>;

  /// Bind the identifier set to the allocator owned by this object.
  TimingManagerImpl() : identifiers(identifierAllocator) {}

  // NOTE: `identifierAllocator` must stay declared before `identifiers`, which
  // holds a reference to it.

  /// Allocator backing the entries of `identifiers`.
  llvm::BumpPtrAllocator identifierAllocator;

  /// The set of all identifier strings created through this manager.
  llvm::StringSet<llvm::BumpPtrAllocator &> identifiers;

  /// Reader/writer lock taken by `TimingIdentifier::get` around accesses to
  /// `identifiers`.
  llvm::sys::SmartRWMutex<true> identifierMutex;

  /// Thread-local lookup cache consulted by `TimingIdentifier::get` before the
  /// shared set.
  ThreadLocalCache<LocalIdentifierMap> localIdentifierCache;
};
} // namespace mlir::detail
/// Create a timing manager with a fresh `TimingManagerImpl`.
TimingManager::TimingManager() : impl(std::make_unique<TimingManagerImpl>()) {}
// Defaulted here, in the file where `TimingManagerImpl` is defined.
TimingManager::~TimingManager() = default;
/// Return a `Timer` wrapping the handle reported by `rootTimer()`, or a
/// default-constructed `Timer` when `rootTimer()` yields no handle.
Timer TimingManager::getRootTimer() {
  auto handle = rootTimer();
  if (!handle)
    return Timer();
  return Timer(*this, *handle);
}
/// Return a `TimingScope` constructed from the root timer.
TimingScope TimingManager::getRootScope() {
  Timer root = getRootTimer();
  // Hand the timer over as an rvalue, exactly like the temporary it replaces.
  return TimingScope(std::move(root));
}
/// Return the unique identifier for `str` within the timing manager `tm`,
/// creating it on first use.
///
/// Lookup proceeds in three stages: the thread-local cache, then the shared
/// identifier set under a reader lock, and finally insertion into the shared
/// set under a writer lock. The thread-local cache slot is filled in whenever
/// one of the later stages resolves the identifier.
TimingIdentifier TimingIdentifier::get(StringRef str, TimingManager &tm) {
  auto &state = *tm.impl;

  // Stage 1: thread-local cache. `operator[]` creates a null slot on a miss,
  // which is filled in below.
  auto *&cachedEntry = (*state.localIdentifierCache)[str];
  if (cachedEntry)
    return TimingIdentifier(cachedEntry);

  // Stage 2: the identifier may already exist in the shared set.
  {
    llvm::sys::SmartScopedReader<true> readGuard(state.identifierMutex);
    if (auto found = state.identifiers.find(str);
        found != state.identifiers.end()) {
      cachedEntry = &*found;
      return TimingIdentifier(cachedEntry);
    }
  }

  // Stage 3: insert under the writer lock. `insert` returns the existing entry
  // if another thread added `str` after the reader lock was released.
  llvm::sys::SmartScopedWriter<true> writeGuard(state.identifierMutex);
  cachedEntry = &*state.identifiers.insert(str).first;
  return TimingIdentifier(cachedEntry);
}
namespace {
/// Output strategy that renders the timing report as plain text.
class OutputTextStrategy : public OutputStrategy {
public:
  OutputTextStrategy(raw_ostream &os) : OutputStrategy(os) {}

  /// Print the banner, the total wall time, and the column headings. The user
  /// time column is only announced when the total user time differs from the
  /// total wall time.
  void printHeader(const TimeRecord &total) override {
    const std::string rule(73, '-');
    const unsigned leftPad = (80 - kTimingDescription.size()) / 2;

    os << "===" << rule << "===\n";
    os.indent(leftPad) << kTimingDescription << '\n';
    os << "===" << rule << "===\n";

    os << llvm::format(" Total Execution Time: %.4f seconds\n\n", total.wall);

    const bool showUserColumn = total.user != total.wall;
    if (showUserColumn)
      os << " ----User Time----";
    os << " ----Wall Time---- ----Name----\n";
  }

  /// Flush the stream; the text report has no closing text.
  void printFooter() override { os.flush(); }

  /// Print the time columns of one row: seconds plus the percentage relative
  /// to `total`. The user column is printed under the same condition as in
  /// `printHeader`.
  void printTime(const TimeRecord &time, const TimeRecord &total) override {
    if (total.user != total.wall) {
      const double userPercent = 100.0 * time.user / total.user;
      os << llvm::format(" %8.4f (%5.1f%%)", time.user, userPercent);
    }
    const double wallPercent = 100.0 * time.wall / total.wall;
    os << llvm::format(" %8.4f (%5.1f%%) ", time.wall, wallPercent);
  }

  /// Print one list row: the time columns followed by the name. `lastEntry`
  /// is not used by the text format.
  void printListEntry(StringRef name, const TimeRecord &time,
                      const TimeRecord &total, bool lastEntry) override {
    printTime(time, total);
    os << name << "\n";
  }

  /// Print one tree row: the time columns followed by the name, which is
  /// indented by `indent` spaces.
  void printTreeEntry(unsigned indent, StringRef name, const TimeRecord &time,
                      const TimeRecord &total) override {
    printTime(time, total);
    os.indent(indent) << name << "\n";
  }

  /// Tree entries need no closing text in the text format.
  void printTreeEntryEnd(unsigned indent, bool lastEntry) override {}
};
class OutputJsonStrategy : public OutputStrategy {
public:
OutputJsonStrategy(raw_ostream &os) : OutputStrategy(os) {}
void printHeader(const TimeRecord &total) override { os << "[" << "\n"; }
void printFooter() override {
os << "]" << "\n";
os.flush();
}
void printTime(const TimeRecord &time, const TimeRecord &total) override {
if (total.user != total.wall) {
os << "\"user\": {";
os << "\"duration\": " << llvm::format("%8.4f", time.user) << ", ";
os << "\"percentage\": "
<< llvm::format("%5.1f", 100.0 * time.user / total.user);
os << "}, ";
}
os << "\"wall\": {";
os << "\"duration\": " << llvm::format("%8.4f", time.wall) << ", ";
os << "\"percentage\": "
<< llvm::format("%5.1f", 100.0 * time.wall / total.wall);
os << "}";
}
void printListEntry(StringRef name, const TimeRecord &time,
const TimeRecord &total, bool lastEntry) override {
os << "{";
printTime(time, total);
os << ", \"name\": " << "\"" << name << "\"";
os << "}";
if (!lastEntry)
os << ",";
os << "\n";
}
void printTreeEntry(unsigned indent, StringRef name, const TimeRecord &time,
const TimeRecord &total) override {
os.indent(indent) << "{";
printTime(time, total);
os << ", \"name\": " << "\"" << name << "\"";
os << ", \"passes\": [" << "\n";
}
void printTreeEntryEnd(unsigned indent, bool lastEntry) override {
os.indent(indent) << "{}]";
os << "}";
if (!lastEntry)
os << ",";
os << "\n";
}
};
}
namespace {
/// A node in the tree of timers maintained by `DefaultTimingManager`.
///
/// Each timer accumulates a wall time and a user time over any number of
/// `start()`/`stop()` pairs. Children created from the thread that created the
/// timer live in `children`; children created from any other thread live in
/// `asyncChildren` until `finalize()` folds them into `children`.
class TimerImpl {
public:
  using ChildrenMap = llvm::MapVector<const void *, std::unique_ptr<TimerImpl>>;
  using AsyncChildrenMap = llvm::DenseMap<uint64_t, ChildrenMap>;

  /// Create a timer called `name` that is owned by the calling thread and
  /// prints through `output`. Only a reference to `output` is kept, so the
  /// referenced pointer must outlive the timer.
  TimerImpl(std::string &&name, std::unique_ptr<OutputStrategy> &output)
      : threadId(llvm::get_threadid()), name(std::move(name)), output(output) {}

  /// Start the timer.
  void start() { startTime = std::chrono::steady_clock::now(); }

  /// Stop the timer and add the time elapsed since the last `start()` to both
  /// the wall time and the user time.
  void stop() {
    auto newTime = std::chrono::steady_clock::now() - startTime;
    wallTime += newTime;
    userTime += newTime;
  }

  /// Return the child timer registered under `id`, creating it with the name
  /// produced by `nameBuilder` if it does not exist yet.
  ///
  /// When called from the thread that created this timer, the child is stored
  /// in `children` without locking. When called from any other thread, the
  /// child is stored in that thread's entry of `asyncChildren` while holding
  /// `asyncMutex`.
  TimerImpl *nest(const void *id, function_ref<std::string()> nameBuilder) {
    auto tid = llvm::get_threadid();
    if (tid == threadId)
      return nestTail(children[id], nameBuilder);
    std::unique_lock<std::mutex> lock(asyncMutex);
    return nestTail(asyncChildren[tid][id], nameBuilder);
  }

  /// Shared tail of `nest()`: lazily create the timer in slot `child` and
  /// return it.
  TimerImpl *nestTail(std::unique_ptr<TimerImpl> &child,
                      function_ref<std::string()> nameBuilder) {
    if (!child)
      child = std::make_unique<TimerImpl>(nameBuilder(), output);
    return child.get();
  }

  /// Prepare the timer tree for printing: first account for the user time of
  /// async children, then merge the async children into the regular children.
  /// No locking is done here.
  void finalize() {
    addAsyncUserTime();
    mergeAsyncChildren();
  }

  /// Recursively add the user time of all async children to the user time of
  /// their ancestors. Returns the amount that was added to this timer.
  std::chrono::nanoseconds addAsyncUserTime() {
    auto added = std::chrono::nanoseconds(0);
    for (auto &child : children)
      added += child.second->addAsyncUserTime();
    for (auto &thread : asyncChildren) {
      for (auto &child : thread.second) {
        // Update the async child first so that its full user time, including
        // its own async descendants, is counted.
        child.second->addAsyncUserTime();
        added += child.second->userTime;
      }
    }
    userTime += added;
    return added;
  }

  /// Recursively fold all async children of this timer and of its children
  /// into the regular `children` maps.
  void mergeAsyncChildren() {
    for (auto &child : children)
      child.second->mergeAsyncChildren();
    mergeChildren(std::move(asyncChildren));
    assert(asyncChildren.empty());
  }

  /// Merge the timers in `other` into `children`; `other` is left empty.
  /// Timers whose identifier already exists are combined via `mergeChild`.
  void mergeChildren(ChildrenMap &&other) {
    if (children.empty()) {
      // Nothing to combine with: adopt the whole map, then fold the adopted
      // timers' own async children.
      children = std::move(other);
      for (auto &child : children)
        child.second->mergeAsyncChildren();
    } else {
      for (auto &child : other)
        mergeChild(child.first, std::move(child.second));
      other.clear();
    }
  }

  /// Merge the per-thread child maps in `other` into `children`; `other` is
  /// left empty.
  void mergeChildren(AsyncChildrenMap &&other) {
    for (auto &thread : other) {
      mergeChildren(std::move(thread.second));
      assert(thread.second.empty());
    }
    other.clear();
  }

  /// Merge the timer `other` into the child registered under `id`.
  ///
  /// If no such child exists, `other` is adopted as-is. Otherwise the existing
  /// child keeps the larger of the two wall times, sums the user times, and
  /// takes over all regular and async children of `other`.
  void mergeChild(const void *id, std::unique_ptr<TimerImpl> &&other) {
    auto &into = children[id];
    if (!into) {
      into = std::move(other);
      into->mergeAsyncChildren();
    } else {
      into->wallTime = std::max(into->wallTime, other->wallTime);
      into->userTime += other->userTime;
      into->mergeChildren(std::move(other->children));
      into->mergeChildren(std::move(other->asyncChildren));
      other.reset();
    }
  }

  /// Write a debug representation of this timer and its regular and async
  /// children to `os`, one timer per line with its thread id and user/wall
  /// times. A timer is marked with `(*)` when its thread id differs from a
  /// non-zero `markThreadId`; children are marked relative to their parent's
  /// thread id.
  void dump(raw_ostream &os, unsigned indent = 0, uint64_t markThreadId = 0) {
    auto time = getTimeRecord();
    os << std::string(indent * 2, ' ') << name << " [" << threadId << "]"
       << llvm::format(" %7.4f / %7.4f", time.user, time.wall);
    if (threadId != markThreadId && markThreadId != 0)
      os << " (*)";
    os << "\n";
    for (auto &child : children)
      child.second->dump(os, indent + 1, threadId);
    for (auto &thread : asyncChildren)
      for (auto &child : thread.second)
        child.second->dump(os, indent + 1, threadId);
  }

  /// Return the accumulated wall and user time in seconds.
  TimeRecord getTimeRecord() {
    return TimeRecord(
        std::chrono::duration_cast<std::chrono::duration<double>>(wallTime)
            .count(),
        std::chrono::duration_cast<std::chrono::duration<double>>(userTime)
            .count());
  }

  /// Print this timer and all its descendants as a flat list: timers with the
  /// same name are summed into one entry, and the entries are sorted by
  /// decreasing wall time.
  void printAsList(TimeRecord total) {
    // Flatten the tree, merging timers by name.
    llvm::StringMap<TimeRecord> mergedTimers;
    std::function<void(TimerImpl *)> addTimer = [&](TimerImpl *timer) {
      mergedTimers[timer->name] += timer->getTimeRecord();
      for (auto &child : timer->children)
        addTimer(child.second.get());
    };
    addTimer(this);

    // Sort by wall time, largest first (the comparator swaps its operands).
    std::vector<std::pair<StringRef, TimeRecord>> timerNameAndTime;
    for (auto &it : mergedTimers)
      timerNameAndTime.emplace_back(it.first(), it.second);
    llvm::array_pod_sort(timerNameAndTime.begin(), timerNameAndTime.end(),
                         [](const std::pair<StringRef, TimeRecord> *lhs,
                            const std::pair<StringRef, TimeRecord> *rhs) {
                           return llvm::array_pod_sort_comparator<double>(
                               &rhs->second.wall, &lhs->second.wall);
                         });

    for (auto &timeData : timerNameAndTime)
      output->printListEntry(timeData.first, timeData.second, total);
  }

  /// Print this timer and its descendants as a tree. A hidden timer prints no
  /// entry of its own, and its children are printed at its indentation.
  void printAsTree(TimeRecord total, unsigned indent = 0) {
    unsigned childIndent = indent;
    if (!hidden) {
      output->printTreeEntry(indent, name, getTimeRecord(), total);
      childIndent += 2;
    }
    for (auto &child : children) {
      child.second->printAsTree(total, childIndent);
    }
    if (!hidden) {
      output->printTreeEntryEnd(indent);
    }
  }

  /// Print the full report for this timer: header, entries in the requested
  /// display mode, a "Rest" entry, a "Total" entry, and the footer.
  void print(DisplayMode displayMode) {
    auto total = getTimeRecord();
    output->printHeader(total);

    switch (displayMode) {
    case DisplayMode::List:
      printAsList(total);
      break;
    case DisplayMode::Tree:
      printAsTree(total);
      break;
    }

    // "Rest" is the part of this timer's time that is not covered by its
    // direct children.
    auto rest = total;
    for (auto &child : children)
      rest -= child.second->getTimeRecord();
    output->printListEntry("Rest", rest, total);
    output->printListEntry("Total", total, total, true);
    output->printFooter();
  }

  /// The time point recorded by the most recent `start()`.
  std::chrono::time_point<std::chrono::steady_clock> startTime;

  /// Accumulated wall time. When timers are merged, the larger wall time is
  /// kept.
  std::chrono::nanoseconds wallTime = std::chrono::nanoseconds(0);

  /// Accumulated user time. When timers are merged, the user times are summed.
  std::chrono::nanoseconds userTime = std::chrono::nanoseconds(0);

  /// Id of the thread that constructed this timer.
  uint64_t threadId;

  /// Name of this timer as shown in reports.
  std::string name;

  /// Whether the tree view omits this timer and shows only its children.
  bool hidden = false;

  /// Children created from the thread that owns this timer, keyed by
  /// identifier.
  ChildrenMap children;

  /// Children created from other threads, keyed by thread id and then by
  /// identifier.
  AsyncChildrenMap asyncChildren;

  /// Guards `asyncChildren` in `nest()`.
  std::mutex asyncMutex;

  /// Reference to the output strategy pointer used for printing.
  std::unique_ptr<OutputStrategy> &output;
};
}
namespace mlir::detail {
/// Private state of `DefaultTimingManager`.
class DefaultTimingManagerImpl {
public:
  /// Whether timing is enabled; disabled by default.
  bool enabled{false};

  /// How `print()` lays out the report; tree view by default.
  DisplayMode displayMode{DisplayMode::Tree};

  /// Root of the timer tree; (re)created by `DefaultTimingManager::clear()`.
  std::unique_ptr<TimerImpl> rootTimer;
};
} // namespace mlir::detail
/// Create a timing manager whose default output strategy writes text to
/// `llvm::errs()`. The call to `clear()` installs the initial hidden root
/// timer.
DefaultTimingManager::DefaultTimingManager()
: impl(std::make_unique<DefaultTimingManagerImpl>()),
out(std::make_unique<OutputTextStrategy>(llvm::errs())) {
clear();
}
/// Print the collected timings on destruction. `print()` only produces output
/// when timing is enabled.
DefaultTimingManager::~DefaultTimingManager() { print(); }
/// Enable or disable timing. While disabled, `rootTimer()` returns no handle
/// and `print()` produces no report.
void DefaultTimingManager::setEnabled(bool enabled) { impl->enabled = enabled; }
/// Return whether timing is currently enabled.
bool DefaultTimingManager::isEnabled() const { return impl->enabled; }
/// Select the layout (list or tree) used by `print()`.
void DefaultTimingManager::setDisplayMode(DisplayMode displayMode) {
impl->displayMode = displayMode;
}
/// Return the layout currently used by `print()`.
DefaultTimingManager::DisplayMode DefaultTimingManager::getDisplayMode() const {
return impl->displayMode;
}
/// Replace the output strategy. Timers are handed a reference to `out` (see
/// `clear()`), so timers that already exist print through the new strategy.
void DefaultTimingManager::setOutput(std::unique_ptr<OutputStrategy> output) {
out = std::move(output);
}
void DefaultTimingManager::print() {
if (impl->enabled) {
impl->rootTimer->finalize();
impl->rootTimer->print(impl->displayMode);
}
clear();
}
/// Discard all recorded timers and install a fresh root timer named "root".
/// The root is hidden, so the tree view shows only its children.
void DefaultTimingManager::clear() {
  auto freshRoot = std::make_unique<TimerImpl>("root", out);
  freshRoot->hidden = true;
  impl->rootTimer = std::move(freshRoot);
}
/// Write a debug dump of the current timer tree, including async children
/// that have not been merged yet, to `os`. Does not finalize the timers.
void DefaultTimingManager::dumpTimers(raw_ostream &os) {
impl->rootTimer->dump(os);
}
/// Finalize the timer tree and print it in list mode, regardless of the
/// configured display mode and of whether timing is enabled.
///
/// NOTE(review): `os` is not used; the report goes to the stream of the
/// current output strategy.
void DefaultTimingManager::dumpAsList(raw_ostream &os) {
  TimerImpl &root = *impl->rootTimer;
  root.finalize();
  root.print(DisplayMode::List);
}
/// Finalize the timer tree and print it in tree mode, regardless of the
/// configured display mode and of whether timing is enabled.
///
/// NOTE(review): `os` is not used; the report goes to the stream of the
/// current output strategy.
void DefaultTimingManager::dumpAsTree(raw_ostream &os) {
  TimerImpl &root = *impl->rootTimer;
  root.finalize();
  root.print(DisplayMode::Tree);
}
/// Return the opaque handle of the root timer, or `std::nullopt` while timing
/// is disabled.
std::optional<void *> DefaultTimingManager::rootTimer() {
  if (!impl->enabled)
    return std::nullopt;
  return impl->rootTimer.get();
}
/// Start the timer behind `handle`, which is treated as a `TimerImpl *`.
void DefaultTimingManager::startTimer(void *handle) {
  auto *timer = static_cast<TimerImpl *>(handle);
  timer->start();
}
/// Stop the timer behind `handle`, which is treated as a `TimerImpl *`.
void DefaultTimingManager::stopTimer(void *handle) {
  auto *timer = static_cast<TimerImpl *>(handle);
  timer->stop();
}
void *DefaultTimingManager::nestTimer(void *handle, const void *id,
function_ref<std::string()> nameBuilder) {
return static_cast<TimerImpl *>(handle)->nest(id, nameBuilder);
}
/// Mark the timer behind `handle` as hidden, so the tree view shows only its
/// children.
void DefaultTimingManager::hideTimer(void *handle) {
  auto *timer = static_cast<TimerImpl *>(handle);
  timer->hidden = true;
}
namespace {
/// Command-line options that configure a `DefaultTimingManager`. The values
/// are applied by `applyDefaultTimingManagerCLOptions`.
struct DefaultTimingManagerOptions {
  /// `-mlir-timing`: enable timing; off by default.
  llvm::cl::opt<bool> timing{"mlir-timing",
                             llvm::cl::desc("Display execution times"),
                             llvm::cl::init(false)};

  /// `-mlir-timing-display`: report layout (`list` or `tree`); tree by
  /// default.
  llvm::cl::opt<DisplayMode> displayMode{
      "mlir-timing-display", llvm::cl::desc("Display method for timing data"),
      llvm::cl::init(DisplayMode::Tree),
      llvm::cl::values(
          clEnumValN(DisplayMode::List, "list",
                     "display the results in a list sorted by total time"),
          clEnumValN(DisplayMode::Tree, "tree",
                     "display the results in a nested tree view"))};

  /// `-mlir-output-format`: report format (`text` or `json`); text by
  /// default.
  llvm::cl::opt<OutputFormat> outputFormat{
      "mlir-output-format", llvm::cl::desc("Output format for timing data"),
      llvm::cl::init(OutputFormat::Text),
      llvm::cl::values(clEnumValN(OutputFormat::Text, "text",
                                  "display the results in text format"),
                       clEnumValN(OutputFormat::Json, "json",
                                  "display the results in JSON format"))};
};
}
// Storage for the command-line options. `applyDefaultTimingManagerCLOptions`
// checks `isConstructed()` to tell whether the options were registered.
static llvm::ManagedStatic<DefaultTimingManagerOptions> options;
/// Register the timing manager command-line options.
void mlir::registerDefaultTimingManagerCLOptions() {
// Dereferencing the ManagedStatic is done only for its side effect of
// constructing the options object; the result is unused.
*options;
}
/// Configure `tm` from the command-line options: enabled flag, display mode,
/// and an output strategy that writes to `llvm::errs()` in the selected
/// format. Does nothing if the options were never registered.
void mlir::applyDefaultTimingManagerCLOptions(DefaultTimingManager &tm) {
  if (!options.isConstructed())
    return;

  tm.setEnabled(options->timing);
  tm.setDisplayMode(options->displayMode);

  const OutputFormat format = options->outputFormat;
  std::unique_ptr<OutputStrategy> printer;
  switch (format) {
  case OutputFormat::Text:
    printer = std::make_unique<OutputTextStrategy>(llvm::errs());
    break;
  case OutputFormat::Json:
    printer = std::make_unique<OutputJsonStrategy>(llvm::errs());
    break;
  }
  tm.setOutput(std::move(printer));
}