#include "systrace_manager.h"
#include <chrono>
#include <cstdlib>
#include <filesystem>
#include <fstream>
#include <string>
// Globals with external linkage, both starting at zero. They are not
// referenced anywhere else in the visible part of this file.
// NOTE(review): presumably read/written from other translation units via
// `extern` — confirm before renaming or removing.
int global_stage_id{0};
int global_stage_type{0};
#ifdef HAS_BTF_SUPPORT
// C-linkage entry points of the OS probe; their definitions live outside this
// file. run_osprobe() is used as the body of the os_probe_ thread, and
// cleanup_osprobe() is called right before that thread is joined.
extern "C" int run_osprobe();
extern "C" void cleanup_osprobe();
#endif
namespace systrace {
namespace {
// Modulus applied to the poller's loop counter: a trace trigger is attempted
// on every TRACE_INTERVAL-th iteration of SysTrace::eventPollerMain().
constexpr uint64_t TRACE_INTERVAL{100};
// Sleep duration between two iterations of the event poller loop.
constexpr auto POLL_INTERVAL = std::chrono::milliseconds{10};
} // namespace
/// Returns the process-wide PyTorchTrace object.
///
/// The first call allocates the object with `new` and runs initialize() on it;
/// std::call_once on init_flag_ ensures this happens only once. Nothing in the
/// visible part of this file deletes the instance.
PyTorchTrace &PyTorchTrace::getInstance() {
    const auto create_instance = [] {
        instance_ = new PyTorchTrace();
        instance_->initialize();
    };
    std::call_once(init_flag_, create_instance);
    return *instance_;
}
/// Constructs the tracer. In USE_JSON builds this starts the background thread
/// that runs writerLoop(); in other builds the body is empty.
PyTorchTrace::PyTorchTrace() {
#ifdef USE_JSON
    writer_thread_ = std::thread([this] { writerLoop(); });
#endif
}
/// Shuts the tracer down.
///
/// USE_JSON builds: sets stop_writer_, wakes the writer thread and joins it.
/// Other builds: writes the accumulated trace via writeTraceToFile().
/// In both cases the tracing library object is deleted afterwards.
PyTorchTrace::~PyTorchTrace() {
#ifdef USE_JSON
    // The flag is set while holding queue_mutex_, the same mutex the writer
    // holds when it evaluates its wait predicate.
    {
        std::lock_guard<std::mutex> guard(queue_mutex_);
        stop_writer_ = true;
    }
    queue_cv_.notify_one();
    if (writer_thread_.joinable()) {
        writer_thread_.join();
    }
#else
    // This guard lives until the end of the destructor body, so trace_mutex_
    // is also held while the library is deleted below.
    std::lock_guard<std::mutex> guard(trace_mutex_);
    writeTraceToFile();
#endif
    // `delete` on a null pointer is a no-op, so no explicit check is needed.
    delete pytorch_tracing_library_;
    pytorch_tracing_library_ = nullptr;
}
/// One-time setup invoked from getInstance().
///
/// In non-JSON builds the rank from GlobalConfig is stored in pytorch_trace_
/// and logged. Then the tracing library wrapper is created for
/// "libsysTrace.so" and the functions from the function list are registered.
void PyTorchTrace::initialize() {
#ifndef USE_JSON
    const auto &global_cfg = systrace::util::config::GlobalConfig::Instance();
    pytorch_trace_.set_rank(global_cfg.rank);
    LOG_MODULE(INFO, "PyTorchTrace") << "Rank set to: " << global_cfg.rank;
#endif
    pytorch_tracing_library_ =
        new pytorch_tracing::PyTorchTracingLibrary("libsysTrace.so");
    LOG_MODULE(INFO, "PyTorchTrace") << "Tracing library loaded";

    registerTracingFunctions();
}
/// Loads the list of functions to trace from the file at PyFuncListPath_ and
/// registers them with the tracing library.
///
/// Every non-empty line that does not start with '#' is appended to
/// pytorch_tracing_functions_. If the file cannot be opened an error is logged
/// and nothing is registered.
void PyTorchTrace::registerTracingFunctions() {
    std::ifstream list_stream(PyFuncListPath_);
    if (!list_stream.is_open()) {
        LOG_MODULE(ERROR, "PyTorchTrace") << "Failed to open PyFuncList file";
        return;
    }

    for (std::string entry; std::getline(list_stream, entry);) {
        const bool skip_line = entry.empty() || entry[0] == '#';
        if (skip_line) {
            continue;
        }
        pytorch_tracing_functions_.push_back(entry);
    }
    list_stream.close();

    if (!pytorch_tracing_library_) {
        return;
    }

    // NOTE(review): the loop below assumes Register() returns one status per
    // entry of pytorch_tracing_functions_ — confirm against the library API.
    auto errors =
        pytorch_tracing_library_->Register(pytorch_tracing_functions_);
    for (size_t idx = 0; idx < pytorch_tracing_functions_.size(); ++idx) {
        LOG_MODULE(INFO, "PyTorchTrace")
            << "Registered function: " << pytorch_tracing_functions_[idx]
            << ", status: " << errors[idx];
    }
}
/// Sets the trigger flag and reports its previous state.
/// @return the value has_trigger_trace_ held before this call, i.e. true only
///         if the flag had already been set.
bool PyTorchTrace::triggerTrace() {
    const bool was_already_set = has_trigger_trace_.exchange(true);
    return was_already_set;
}
/// Collects the tracing data of every registered function and, in non-JSON
/// builds, writes the accumulated trace to a file.
///
/// The collection step and the file write run under trace_mutex_.
void PyTorchTrace::dumpPyTorchTracing() {
    const std::string dump_path =
        constant::TorchTraceConstant::DEFAULT_TRACE_DUMP_PATH();

    // NOTE(review): a truthy result is treated as failure here, which implies
    // CreateDirectoryIfNotExists returns an error indicator — confirm.
    if (util::fs_utils::CreateDirectoryIfNotExists(dump_path)) {
        LOG_MODULE(ERROR, "PyTorchTrace")
            << "[PyTorchTrace] Failed to create dump directory";
        return;
    }

    std::lock_guard<std::mutex> guard(trace_mutex_);
#ifndef USE_JSON
    const auto &global_cfg = systrace::util::config::GlobalConfig::Instance();
    pytorch_trace_.set_rank(global_cfg.rank);
    pytorch_trace_.set_comm(global_cfg.job_name);
#endif
    for (size_t fn_idx = 0; fn_idx < pytorch_tracing_functions_.size();
         ++fn_idx) {
        processFunctionTracingData(fn_idx);
    }
#ifndef USE_JSON
    writeTraceToFile();
#endif
}
void PyTorchTrace::processFunctionTracingData(size_t function_index) {
if (!pytorch_tracing_library_)
return;
std::vector<PyTorchTracingDataArray *> data_holders;
if (auto data = pytorch_tracing_library_->RetrievePartialTracingData(
function_index)) {
data_holders.push_back(data);
}
while (auto data = pytorch_tracing_library_->RetrieveAllTracingData(
function_index)) {
data_holders.push_back(data);
}
for (auto data : data_holders) {
for (uint32_t i = 0; i < data->cur; ++i) {
if (data->data[i].start == 0)
continue;
#ifdef USE_JSON
json item;
item["start_us"] = data->data[i].start;
item["end_us"] = data->data[i].end;
item["stage_id"] = data->data[i].count;
item["rank"] =
systrace::util::config::GlobalConfig::Instance().rank;
if (data->data[i].type == PAYLOAD_GC) {
item["stage_type"] = "GC";
item["gc_debug"] = {
{"collected", data->data[i].payload.gc_debug[0]},
{"uncollectable", data->data[i].payload.gc_debug[1]}};
} else {
item["stage_type"] = pytorch_tracing_functions_[function_index];
}
if (data->data[i].stack_depth > 0) {
item["stack_frames"] = json::array();
for (int j = 0; j < data->data[i].stack_depth; ++j) {
if (data->data[i].stack_info[j][0] != '\0') {
item["stack_frames"].push_back(
data->data[i].stack_info[j]);
}
}
}
enqueueTraceEntry(std::move(item));
#else
auto trace = pytorch_trace_.add_pytorch_stages();
trace->set_start_us(data->data[i].start);
trace->set_end_us(data->data[i].end);
trace->set_stage_id(data->data[i].count);
if (data->data[i].type == PAYLOAD_GC) {
trace->set_stage_type("GC");
} else {
trace->set_stage_type(
pytorch_tracing_functions_[function_index]);
}
if (data->data[i].stack_depth > 0) {
trace->mutable_stack_frames()->Reserve(
data->data[i].stack_depth);
for (int j = 0; j < data->data[i].stack_depth; ++j) {
if (data->data[i].stack_info[j][0] != '\0') {
trace->add_stack_frames(data->data[i].stack_info[j]);
}
}
}
if (data->data[i].type == PAYLOAD_GC) {
auto gc_debug = trace->mutable_gc_debug();
gc_debug->set_collected(data->data[i].payload.gc_debug[0]);
gc_debug->set_uncollectable(data->data[i].payload.gc_debug[1]);
}
#endif
}
}
for (auto data : data_holders) {
pytorch_tracing_library_->ReleaseTracingData(
data, PY_TRACING_EMPTY_POOL, function_index);
}
}
#ifdef USE_JSON
/// Moves `entry` onto trace_queue_ under queue_mutex_ and then wakes one
/// waiter on queue_cv_. The notification is sent after the mutex is released.
void PyTorchTrace::enqueueTraceEntry(json &&entry) {
    std::unique_lock<std::mutex> guard(queue_mutex_);
    trace_queue_.push(std::move(entry));
    guard.unlock();
    queue_cv_.notify_one();
}
void PyTorchTrace::writerLoop() {
while (true) {
std::vector<json> batch;
{
std::unique_lock<std::mutex> lock(queue_mutex_);
queue_cv_.wait(
lock, [this] { return !trace_queue_.empty() || stop_writer_; });
if (stop_writer_ && trace_queue_.empty())
break;
while (!trace_queue_.empty()) {
batch.push_back(std::move(trace_queue_.front()));
trace_queue_.pop();
}
}
if (!batch.empty()) {
const std::string dump_path =
constant::TorchTraceConstant::DEFAULT_TRACE_DUMP_PATH();
std::string file_path =
dump_path + "/" +
util::fs_utils::GenerateClusterUniqueFilename(".json");
std::ofstream file(file_path, std::ios::app);
if (file.is_open()) {
for (const auto &entry : batch)
file << entry.dump() << "\n";
file.flush();
}
}
}
}
#else
/// Serializes pytorch_trace_ and writes the bytes to a ".timeline" file in the
/// trace dump directory.
///
/// Errors (serialization, open, write) are logged and make the function
/// return without throwing. Both call sites in this file hold trace_mutex_
/// when they call this function.
void PyTorchTrace::writeTraceToFile() {
    // Serialize first: opening with std::ios::out truncates the target, so a
    // failed serialization must not create or wipe the output file.
    std::string binary_data;
    if (!pytorch_trace_.SerializeToString(&binary_data)) {
        LOG_MODULE(ERROR, "PyTorchTrace") << "Failed to serialize trace data";
        return;
    }

    const std::string dump_path =
        constant::TorchTraceConstant::DEFAULT_TRACE_DUMP_PATH();
    const std::string file_path =
        dump_path + "/" +
        util::fs_utils::GenerateClusterUniqueFilename(".timeline");

    std::ofstream file(file_path, std::ios::binary | std::ios::out);
    if (!file) {
        LOG_MODULE(ERROR, "PyTorchTrace")
            << "Failed to open file: " << file_path;
        return;
    }

    file.write(binary_data.data(),
               static_cast<std::streamsize>(binary_data.size()));
    file.flush();
    // Report short or failed writes (e.g. disk full) instead of ignoring them.
    if (!file) {
        LOG_MODULE(ERROR, "PyTorchTrace")
            << "Failed to write trace data to file: " << file_path;
    }
}
#endif
/// Returns the process-wide SysTrace object.
///
/// The first call allocates the object, runs initializeSystem() on it and
/// registers cleanup() with std::atexit; std::call_once on init_flag_ ensures
/// this happens only once.
SysTrace &SysTrace::getInstance() {
    const auto bootstrap = [] {
        instance_ = new SysTrace();
        instance_->initializeSystem();
        std::atexit(cleanup);
    };
    std::call_once(init_flag_, bootstrap);
    return *instance_;
}
// Stops the event poller (which also stops the ControlManager) and, in
// HAS_BTF_SUPPORT builds, the OS probe thread.
SysTrace::~SysTrace() {
stopEventPoller();
#ifdef HAS_BTF_SUPPORT
stopOsProbePoller();
#endif
}
/// Reports whether the LD_PRELOAD environment variable mentions libmspti.so.
/// @return false when LD_PRELOAD is unset; otherwise true if its value
///         contains the substring "libmspti.so".
bool SysTrace::isMsptiLibraryLoaded() {
    const char *preload_env = std::getenv("LD_PRELOAD");
    if (preload_env == nullptr) {
        return false;
    }
    const std::string preload_list(preload_env);
    const bool mentions_mspti =
        preload_list.find("libmspti.so") != std::string::npos;
    return mentions_mspti;
}
// Registers the plugins with the ControlManager singleton, in the order
// written below. Some registrations depend on build flags or on LD_PRELOAD.
void SysTrace::registerPlugins() {
auto &cm = ControlManager::getInstance();
cm.register_plugin(std::make_shared<HbmPlugin>());
// MsptiPlugin is registered only when LD_PRELOAD mentions libmspti.so;
// otherwise the skip is logged at INFO level.
if (isMsptiLibraryLoaded()) {
cm.register_plugin(std::make_shared<MsptiPlugin>());
} else {
LOG_MODULE(INFO, "SysTrace")
<< "libmspti.so not found in LD_PRELOAD, skipping MsptiPlugin "
"registration";
}
cm.register_plugin(std::make_shared<IOPlugin>());
// Memory and CPU plugins are compiled in only with HAS_BTF_SUPPORT.
#ifdef HAS_BTF_SUPPORT
cm.register_plugin(std::make_shared<MemoryPlugin>());
cm.register_plugin(std::make_shared<CpuPlugin>());
#endif
cm.register_plugin(std::make_shared<CacheMissPlugin>());
cm.register_plugin(std::make_shared<FtracePlugin>());
cm.register_plugin(std::make_shared<TracePlugin>());
// NOTE(review): this guard tests HAS_BPF_SUPPORT while every other guard in
// this file tests HAS_BTF_SUPPORT — confirm it is a distinct macro and not a
// typo.
#ifdef HAS_BPF_SUPPORT
cm.register_plugin(std::make_shared<GILPlugin>());
cm.register_plugin(std::make_shared<PthreadPlugin>());
#endif
}
// One-time start-up sequence, invoked from getInstance().
// Returns immediately, without starting anything, when GlobalConfig's
// `enable` flag is false.
void SysTrace::initializeSystem() {
if (!systrace::util::config::GlobalConfig::Instance().enable)
return;
init_sys_trace_root_dir();
systrace::util::InitializeSystemUtilities();
// Plugins are registered before the ControlManager is started.
registerPlugins();
ControlManager::getInstance().start();
#ifdef ENABLE_PYTHON_TRACING
// Creates and initializes the PyTorchTrace singleton; the returned reference
// is not used here.
PyTorchTrace::getInstance();
#endif
#ifdef HAS_BTF_SUPPORT
// run_osprobe() becomes the body of the os_probe_ thread.
os_probe_ = std::thread(&run_osprobe);
#endif
#ifdef ENABLE_PYTHON_TRACING
// Started after the PyTorchTrace singleton exists, since the poller uses it.
startEventPoller();
#endif
}
void SysTrace::startEventPoller() {
should_run_ = true;
event_poller_ = std::thread(&SysTrace::eventPollerMain, this);
pthread_setname_np(event_poller_.native_handle(), "systrace_poller");
}
// Stops the poller thread (if one is running) and then the ControlManager.
void SysTrace::stopEventPoller() {
// Clearing the flag ends the while loop in eventPollerMain().
should_run_ = false;
if (event_poller_.joinable())
event_poller_.join();
// NOTE(review): the ControlManager is stopped here regardless of whether the
// poller thread was ever started.
ControlManager::getInstance().stop();
}
/// Body of the event poller thread (a no-op unless ENABLE_PYTHON_TRACING is
/// defined).
///
/// While should_run_ is set, it sleeps POLL_INTERVAL per iteration. On every
/// TRACE_INTERVAL-th iteration it calls triggerTrace() and dumps the PyTorch
/// tracing data when that call returns true. One final dump is performed
/// after the loop ends.
void SysTrace::eventPollerMain() {
#ifdef ENABLE_PYTHON_TRACING
    auto &tracer = PyTorchTrace::getInstance();
    while (should_run_) {
        const bool interval_elapsed = (loop_count_++ % TRACE_INTERVAL == 0);
        if (interval_elapsed && tracer.triggerTrace()) {
            tracer.dumpPyTorchTracing();
        }
        std::this_thread::sleep_for(POLL_INTERVAL);
    }
    tracer.dumpPyTorchTracing();
#endif
}
#ifdef HAS_BTF_SUPPORT
// Calls cleanup_osprobe() and then joins the os_probe_ thread if it is
// joinable.
// NOTE(review): presumably cleanup_osprobe() is what makes run_osprobe()
// return so that the join can complete — confirm against the probe code.
void SysTrace::stopOsProbePoller() {
cleanup_osprobe();
if (os_probe_.joinable())
os_probe_.join();
}
#endif
/// Exit handler registered with std::atexit from getInstance().
///
/// Performs a last PyTorch trace dump (ENABLE_PYTHON_TRACING builds) and stops
/// the OS probe thread (HAS_BTF_SUPPORT builds). Does nothing when no SysTrace
/// instance exists.
void SysTrace::cleanup() {
    // Guard every use of instance_: previously stopOsProbePoller() was called
    // through instance_ outside the null check.
    if (!instance_) {
        return;
    }
#ifdef ENABLE_PYTHON_TRACING
    PyTorchTrace::getInstance().dumpPyTorchTracing();
#endif
#ifdef HAS_BTF_SUPPORT
    instance_->stopOsProbePoller();
#endif
}
}