#include "PassDetail.h"
#include "mlir/Pass/PassManager.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Format.h"
using namespace mlir;
using namespace mlir::detail;
/// Banner text for the statistics report; printStatistics centers it within
/// an 80-column line between two separator rules.
constexpr StringLiteral kPassStatsDescription =
"... Pass statistics report ...";
namespace {
/// Plain snapshot of one pass statistic, used while formatting the report.
/// Member order matters: instances are brace-initialised as
/// `{name, desc, value}`.
struct Statistic {
  /// Statistic name (NUL-terminated, not owned).
  const char *name;
  /// Human-readable description (NUL-terminated, not owned).
  const char *desc;
  /// Counter value captured for display.
  uint64_t value;
};
} // namespace
/// Print one pass entry of the statistics report.
///
/// Writes `pass` on its own line at column `indent`. If `stats` is non-empty,
/// the statistics are sorted in place by name and printed one per line at
/// `indent + 2` as `(S) <value> <name> - <description>`, with the value column
/// right-aligned and the name column left-aligned to the widest entry.
///
/// \param os     Stream receiving the output.
/// \param indent Column at which the pass name starts.
/// \param pass   Display name of the pass (or pipeline header).
/// \param stats  Statistics to print; reordered by this function.
static void printPassEntry(raw_ostream &os, unsigned indent, StringRef pass,
                           MutableArrayRef<Statistic> stats = std::nullopt) {
  os.indent(indent) << pass << "\n";
  if (stats.empty())
    return;

  // Make sure to sort the statistics by name.
  llvm::array_pod_sort(
      stats.begin(), stats.end(), [](const auto *lhs, const auto *rhs) {
        return StringRef{lhs->name}.compare(StringRef{rhs->name});
      });

  // Collect the widest name and the widest printed value so that the columns
  // line up.
  size_t largestName = 0, largestValue = 0;
  for (const Statistic &stat : stats) {
    largestName = std::max(largestName, static_cast<size_t>(strlen(stat.name)));
    largestValue = std::max(
        largestValue, static_cast<size_t>(llvm::utostr(stat.value).size()));
  }

  // printf-style `*` widths must be `int`, and the 64-bit counter needs a
  // 64-bit conversion; passing `size_t`/`uint64_t` for `%*u` is undefined
  // behaviour and truncates values above 2^32.
  const int valueWidth = static_cast<int>(largestValue);
  const int nameWidth = static_cast<int>(largestName);
  for (const Statistic &stat : stats) {
    os.indent(indent + 2) << llvm::format(
        "(S) %*llu %-*s - %s\n", valueWidth,
        static_cast<unsigned long long>(stat.value), nameWidth, stat.name,
        stat.desc);
  }
}
/// Print the statistics as a flat list: statistics of passes that share a name
/// are summed together, and the resulting entries are printed sorted by pass
/// name.
static void printResultsAsList(raw_ostream &os, OpPassManager &pm) {
  llvm::StringMap<std::vector<Statistic>> statsByPassName;

  // Walk the pipeline recursively, descending through adaptor passes and
  // accumulating the statistics of every leaf pass.
  std::function<void(Pass *)> collect = [&](Pass *current) {
    if (auto *nestingAdaptor = dyn_cast<OpToOpPassAdaptor>(current)) {
      for (auto &nestedPm : nestingAdaptor->getPassManagers())
        for (Pass &nestedPass : nestedPm.getPasses())
          collect(&nestedPass);
      return;
    }

#if LLVM_ENABLE_STATS
    // Leaf pass: nothing to record if it declares no statistics.
    auto passStats = current->getStatistics();
    if (passStats.empty())
      return;

    auto &merged = statsByPassName[current->getName()];
    if (merged.empty()) {
      // First pass seen under this name: seed the entry.
      for (Pass::Statistic *stat : passStats)
        merged.push_back({stat->getName(), stat->getDesc(), stat->getValue()});
    } else {
      // Name already seen: add the values position by position.
      for (auto [index, stat] : llvm::enumerate(passStats))
        merged[index].value += stat->getValue();
    }
#endif
  };
  for (Pass &topLevel : pm.getPasses())
    collect(&topLevel);

  // Order the map entries by key before printing.
  auto sortedEntries =
      llvm::to_vector<16>(llvm::make_pointer_range(statsByPassName));
  using EntryPtr = decltype(sortedEntries)::value_type;
  llvm::array_pod_sort(sortedEntries.begin(), sortedEntries.end(),
                       [](const EntryPtr *lhs, const EntryPtr *rhs) {
                         return (*lhs)->getKey().compare((*rhs)->getKey());
                       });

  for (EntryPtr entry : sortedEntries)
    printPassEntry(os, 2, entry->first(), entry->second);
}
/// Print the statistics following the nesting structure of the pipeline. Each
/// nested pass manager gets a "'<anchor>' Pipeline" header and its passes are
/// indented two columns below it. Prints nothing unless LLVM_ENABLE_STATS is
/// set.
static void printResultsAsPipeline(raw_ostream &os, OpPassManager &pm) {
#if LLVM_ENABLE_STATS
  std::function<void(unsigned, Pass *)> emit = [&](unsigned depth,
                                                   Pass *current) {
    auto *adaptor = dyn_cast<OpToOpPassAdaptor>(current);

    // Leaf pass: snapshot its statistics and print them under its name.
    if (!adaptor) {
      std::vector<Statistic> collected;
      for (Pass::Statistic *entry : current->getStatistics())
        collected.push_back(
            {entry->getName(), entry->getDesc(), entry->getValue()});
      printPassEntry(os, depth, current->getName(), collected);
      return;
    }

    // Adaptor pass: only print the adaptor's own name when it holds more than
    // one nested pass manager, pushing its pipelines one level deeper.
    auto nestedManagers = adaptor->getPassManagers();
    if (nestedManagers.size() > 1) {
      printPassEntry(os, depth, adaptor->getAdaptorName());
      depth += 2;
    }

    // Print every nested pipeline followed by its passes.
    for (OpPassManager &nested : nestedManagers) {
      auto header = ("'" + nested.getOpAnchorName() + "' Pipeline").str();
      printPassEntry(os, depth, header);
      for (Pass &child : nested.getPasses())
        emit(depth + 2, &child);
    }
  };

  for (Pass &topLevel : pm.getPasses())
    emit(0, &topLevel);
#endif
}
/// Write the complete statistics report for `pm` to the stream returned by
/// llvm::CreateInfoOutputFile(): a centered banner between two separator
/// rules, the results in the requested `displayMode`, then a blank line.
static void printStatistics(OpPassManager &pm, PassDisplayMode displayMode) {
  auto os = llvm::CreateInfoOutputFile();

  // Header: rule, centered description, rule.
  const std::string rule = "===" + std::string(73, '-') + "===\n";
  *os << rule;
  const unsigned leftPad = (80 - kPassStatsDescription.size()) / 2;
  os->indent(leftPad) << kPassStatsDescription << '\n';
  *os << rule;

  // Body: dispatch on the requested layout.
  switch (displayMode) {
  case PassDisplayMode::List:
    printResultsAsList(*os, pm);
    break;
  case PassDisplayMode::Pipeline:
    printResultsAsPipeline(*os, pm);
    break;
  }

  *os << "\n";
  os->flush();
}
/// Construct a statistic called `name`, described by `description`, and attach
/// it to `owner` by appending it to the owner's `statistics` list.
/// The first field of the llvm::Statistic base is left as an empty string.
Pass::Statistic::Statistic(Pass *owner, const char *name,
const char *description)
: llvm::Statistic{"", name, description} {
#if LLVM_ENABLE_STATS
// Mark the statistic as already initialized.
// NOTE(review): presumably this keeps llvm::Statistic from registering itself
// in LLVM's global statistics registry on first use -- confirm against
// llvm/ADT/Statistic.h.
Initialized = true;
#endif
// Record this statistic on the owning pass.
owner->statistics.push_back(this);
}
auto Pass::Statistic::operator=(unsigned value) -> Statistic & {
llvm::Statistic::operator=(value);
return *this;
}
/// Merge the pass statistics of this pass manager into `other`. The two
/// managers are walked in lockstep: adaptor passes recurse into their nested
/// pass managers, and for every other pass each statistic is added to the
/// matching statistic in `other` and then reset to zero here.
void OpPassManager::mergeStatisticsInto(OpPassManager &other) {
  auto sourcePasses = getPasses(), targetPasses = other.getPasses();
  for (auto [source, target] : llvm::zip(sourcePasses, targetPasses)) {
    // Adaptors carry no statistics of their own; merge their nested managers
    // pairwise instead.
    if (auto *sourceAdaptor = dyn_cast<OpToOpPassAdaptor>(&source)) {
      auto *targetAdaptor = cast<OpToOpPassAdaptor>(&target);
      for (auto [nestedSource, nestedTarget] :
           llvm::zip(sourceAdaptor->getPassManagers(),
                     targetAdaptor->getPassManagers()))
        nestedSource.mergeStatisticsInto(nestedTarget);
      continue;
    }

    // Both passes are expected to declare the same statistics in the same
    // order.
    assert(source.statistics.size() == target.statistics.size());
    for (unsigned idx = 0, count = source.statistics.size(); idx != count;
         ++idx) {
      assert(source.statistics[idx]->getName() ==
             StringRef(target.statistics[idx]->getName()));
      // Transfer the count, then clear the source so it is not merged twice.
      *target.statistics[idx] += *source.statistics[idx];
      *source.statistics[idx] = 0;
    }
  }
}
/// Prepare the statistics of `pm` for display. For every adaptor pass, the
/// statistics held by each of its parallel pass managers are first prepared
/// recursively and then merged into the adaptor's pass manager at the same
/// index; afterwards the adaptor's own pass managers are prepared recursively.
static void prepareStatistics(OpPassManager &pm) {
  for (Pass &pass : pm.getPasses()) {
    // Only adaptor passes hold nested pass managers.
    if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(&pass)) {
      MutableArrayRef<OpPassManager> primaryManagers =
          adaptor->getPassManagers();

      // Fold each group of parallel pass managers into the primary ones.
      for (auto &parallelGroup : adaptor->getParallelPassManagers()) {
        for (unsigned idx = 0, count = parallelGroup.size(); idx != count;
             ++idx) {
          prepareStatistics(parallelGroup[idx]);
          parallelGroup[idx].mergeStatisticsInto(primaryManagers[idx]);
        }
      }

      // Then descend into the primary nested pass managers.
      for (OpPassManager &primary : primaryManagers)
        prepareStatistics(primary);
    }
  }
}
/// Prepare and print the statistics of the passes held by this pass manager,
/// using the display mode stored in `passStatisticsMode`.
/// NOTE(review): `passStatisticsMode` is dereferenced unconditionally, so this
/// presumably must only be called after enableStatistics() -- confirm against
/// the callers.
void PassManager::dumpStatistics() {
prepareStatistics(*this);
printStatistics(*this, *passStatisticsMode);
}
/// Record `displayMode` in `passStatisticsMode`; dumpStatistics() reads it to
/// choose the report layout.
void PassManager::enableStatistics(PassDisplayMode displayMode) {
passStatisticsMode = displayMode;
}