#include <ATen/Context.h>
#include <torch/csrc/profiler/combined_traceback.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include "torch_npu/csrc/utils/LazyInit.h"
#include "torch_npu/csrc/core/npu/NPUCachingAllocator.h"
#include "torch_npu/csrc/core/npu/NPUWorkspaceAllocator.h"
#include "torch_npu/csrc/profiler/combined_traceback.h"
#include "torch_npu/csrc/npu/memory_snapshot.h"
using torch::jit::Pickler;
using c10_npu::NPUCachingAllocator::BlockInfo;
using c10_npu::NPUCachingAllocator::SegmentInfo;
namespace torch_npu {
std::shared_ptr<c10::GatheredContext> gather()
{
#if defined(__x86_64__)
return torch::CapturedTraceback::gather(true, true, false);
#else
return torch_npu::CapturedTraceback::gather(true, true, false);
#endif
}
std::shared_ptr<c10::GatheredContext> gather_with_cpp()
{
#if defined(__x86_64__)
return torch::CapturedTraceback::gather(true, true, true);
#else
return torch_npu::CapturedTraceback::gather(true, true, true);
#endif
}
static void checkOptionIn(const std::string& option,
std::initializer_list<std::string> valid,
const char* error)
{
TORCH_CHECK(valid.end() != std::find(valid.begin(), valid.end(), option),
error, PTA_ERROR(ErrCode::NOT_FOUND));
}
/// Validates the history-recording options and forwards them to the NPU
/// caching allocator and the NPU workspace allocator.
///
/// @param enabled     "state", "all", or nullopt. Recording is switched on
///                    exactly when a value is present.
/// @param context     "state", "alloc", "all", or nullopt; mapped to the
///                    matching RecordContext value (NEVER when absent).
/// @param stacks      "python" or "all". With "all" (and recording enabled)
///                    gather_with_cpp is used as the context callback instead
///                    of gather.
/// @param max_entries forwarded to the caching allocator only when
///                    `enabled == "all"`; otherwise 1 is passed.
///
/// Any value outside the accepted sets is rejected by checkOptionIn().
void _record_memory_history(c10::optional<std::string> enabled,
                            c10::optional<std::string> context,
                            std::string stacks, size_t max_entries)
{
    namespace allocator = c10_npu::NPUCachingAllocator;

    // Validation happens in argument order: enabled, context, stacks.
    if (enabled.has_value()) {
        checkOptionIn(*enabled, {"state", "all"},
                      "expected state to be 'state', 'all', or None");
    }
    if (context.has_value()) {
        checkOptionIn(
            *context, {"state", "alloc", "all"},
            "expected context to be 'state', 'alloc', 'all', or None");
    }
    checkOptionIn(stacks, {"python", "all"},
                  "expected stacks to be 'python', or 'all'");

    const bool record = enabled.has_value();
    const bool full_stacks = record && (stacks == "all");

    allocator::CreateContextFn recorder = gather;
    if (full_stacks) {
        recorder = gather_with_cpp;
        // The result is discarded.
        // NOTE(review): presumably this primes the unwinder before the first
        // real capture — confirm.
#if defined(__x86_64__)
        torch::unwind::unwind();
#else
        torch_npu::unwind::unwind();
#endif
    }

    // Only the "all" mode keeps the caller's entry limit.
    const bool keep_full_trace = record && (*enabled == "all");
    const size_t entry_limit = keep_full_trace ? max_entries : 1;

    auto when = allocator::RecordContext::NEVER;
    if (context.has_value()) {
        const std::string& mode = *context;
        if (mode == "all") {
            when = allocator::RecordContext::ALL;
        } else if (mode == "alloc") {
            when = allocator::RecordContext::ALLOC;
        } else if (mode == "state") {
            when = allocator::RecordContext::STATE;
        }
    }

    allocator::recordHistory(record, recorder, entry_limit, when);
    c10_npu::NPUWorkspaceAllocator::recordHistory(record, recorder, when);
}
/// Serializes `v` with torch::jit::Pickler and returns the produced bytes.
///
/// The pickler output is appended straight into the returned string, so the
/// payload is not staged in a temporary buffer and copied a second time.
///
/// @param v value to serialize
/// @return the bytes emitted by the pickler
std::string write_pickle(const c10::IValue& v)
{
    std::string pickled;
    {
        auto writer = [&pickled](const char* data, size_t size) {
            pickled.append(data, size);
        };
        // Scoped so the pickler is destroyed before `pickled` is returned.
        Pickler pickler(writer, nullptr, nullptr, nullptr, nullptr, false);
        pickler.protocol();
        pickler.pushIValue(v);
        pickler.stop();
    }
    return pickled;
}
/// Creates an empty dictionary whose key and value types are both
/// c10::AnyType.
c10::Dict<c10::IValue, c10::IValue> new_dict()
{
    using AnyDict = c10::Dict<c10::IValue, c10::IValue>;
    const auto any_type = c10::AnyType::get();
    return AnyDict(any_type, any_type);
}
/// Creates an empty list whose element type is c10::AnyType.
c10::List<c10::IValue> new_list()
{
    using AnyList = c10::List<c10::IValue>;
    return AnyList(c10::AnyType::get());
}
/// Symbolizes the given tracebacks and converts them to IValues.
///
/// Each distinct traceback pointer is passed to symbolize() once; inputs that
/// repeat the same pointer receive the same list value.
///
/// @param to_symbolize tracebacks to convert (may contain duplicates)
/// @return one IValue per input, in input order; each is a list of dicts with
///         the keys "name", "filename" and "line"
std::vector<c10::IValue> ivalue_symbolize(std::vector<torch::CapturedTraceback*>& to_symbolize)
{
    // Map every distinct traceback to its position in `distinct`.
    std::unordered_map<torch::CapturedTraceback*, uint64_t> slot_of;
    std::vector<torch::CapturedTraceback*> distinct;
    for (torch::CapturedTraceback* traceback : to_symbolize) {
        if (slot_of.count(traceback) == 0) {
            slot_of.insert({traceback, distinct.size()});
            distinct.push_back(traceback);
        }
    }

    auto symbolized = symbolize(distinct);

    c10::IValue line_key = "line";
    c10::IValue name_key = "name";
    c10::IValue filename_key = "filename";

    // One dict per symbolized frame, indexed as in symbolized.all_frames.
    std::vector<c10::IValue> frame_values;
    for (const auto& frame : symbolized.all_frames) {
        auto frame_dict = new_dict();
        frame_dict.insert(name_key, frame.funcname);
        frame_dict.insert(filename_key, frame.filename);
        frame_dict.insert(line_key, int64_t(frame.lineno));
        frame_values.emplace_back(std::move(frame_dict));
    }

    // One list per distinct traceback, built from the frame indices it holds.
    std::vector<c10::IValue> traceback_values;
    for (const auto& frame_indices : symbolized.tracebacks) {
        auto frame_list = new_list();
        for (const auto& index : frame_indices) {
            frame_list.push_back(frame_values.at(index));
        }
        traceback_values.emplace_back(std::move(frame_list));
    }

    // Expand back to one value per input.
    std::vector<c10::IValue> result;
    for (torch::CapturedTraceback* traceback : to_symbolize) {
        const uint64_t slot = slot_of.at(traceback);
        result.push_back(traceback_values.at(slot));
    }
    return result;
}
/// Downcasts a gathered context to torch::CapturedTraceback.
///
/// @param x context to inspect
/// @return the object held by `x`, viewed as torch::CapturedTraceback; no
///         ownership is transferred
/// @throws via TORCH_CHECK when the held object is not a
///         torch::CapturedTraceback (including when `x` is empty)
torch::CapturedTraceback* getFromContext(const std::shared_ptr<c10::GatheredContext>& x)
{
    auto* traceback = dynamic_cast<torch::CapturedTraceback*>(x.get());
    if (traceback == nullptr) {
        TORCH_CHECK(
            false,
            "attempting to gather stack context from the wrong StackContext type.", PTA_ERROR(ErrCode::NOT_FOUND));
    }
    return traceback;
}
/// Takes a snapshot of the NPU caching allocator and returns it serialized by
/// write_pickle().
///
/// The pickled value is a dict with two entries:
///   - "segments": one dict per snapshot segment, holding the segment fields
///     and a "blocks" list with one dict per block.
///   - "device_traces": one list per element of `snapshot.device_traces`,
///     holding one dict per trace entry.
///
/// "frames" handling: every dict whose source record carries a context is
/// remembered and gets its "frames" entry after all contexts have been
/// symbolized in one ivalue_symbolize() call. Segment and block dicts without
/// a context get an empty list immediately; trace-entry dicts without a
/// context get no "frames" key.
///
/// @return the pickled snapshot
/// @throws via TORCH_CHECK (in getFromContext) when a context is not a
///         torch::CapturedTraceback, and via AT_ERROR for a trace action not
///         handled by the switch below
std::string _memory_snapshot_pickled()
{
    // Dictionary keys and enumerated string values, built once and reused.
    c10::IValue device_s = "device";
    c10::IValue address_s = "address";
    c10::IValue total_size_s = "total_size";
    c10::IValue allocated_size_s = "allocated_size";
    c10::IValue active_size_s = "active_size";
    c10::IValue requested_size_s = "requested_size";
    c10::IValue stream_s = "stream";
    c10::IValue segment_type_s = "segment_type";
    c10::IValue segment_pool_id_s = "segment_pool_id";
    c10::IValue large_s = "large";
    c10::IValue small_s = "small";
    c10::IValue size_s = "size";
    c10::IValue state_s = "state";
    c10::IValue active_allocated_s = "active_allocated";
    c10::IValue active_pending_free_s = "active_pending_free";
    c10::IValue inactive_s = "inactive";
    c10::IValue addr_s = "addr";
    c10::IValue frames_s = "frames";
    c10::IValue blocks_s = "blocks";
    c10::IValue is_expandable_s = "is_expandable";

    auto empty_frames = new_list();

    // Parallel vectors: frame_dict[i] receives the symbolized form of
    // frame_tracebacks[i] at the end of this function.
    std::vector<torch::CapturedTraceback*> frame_tracebacks;
    std::vector<c10::Dict<c10::IValue, c10::IValue> > frame_dict;

    // Defers the "frames" entry when a context exists; otherwise stores the
    // shared empty list right away.
    auto add_frame_key = [&](const c10::Dict<c10::IValue, c10::IValue>& d,
                             const std::shared_ptr<c10::GatheredContext>& ctx) {
        if (ctx) {
            frame_tracebacks.push_back(getFromContext(ctx));
            frame_dict.push_back(d);
        } else {
            d.insert(frames_s, empty_frames);
        }
    };

    const auto segmentInfoToDict = [&](const SegmentInfo& segmentInfo) {
        auto segmentDict = new_dict();
        segmentDict.insert(device_s, segmentInfo.device);
        segmentDict.insert(address_s, segmentInfo.address);
        segmentDict.insert(total_size_s, segmentInfo.total_size);
        segmentDict.insert(allocated_size_s, segmentInfo.allocated_size);
        segmentDict.insert(active_size_s, segmentInfo.active_size);
        segmentDict.insert(requested_size_s, segmentInfo.requested_size);
        // Function-style cast kept on purpose.
        // NOTE(review): the stream may be a handle type rather than an
        // integer — confirm before switching to static_cast.
        segmentDict.insert(stream_s, int64_t(segmentInfo.stream));
        segmentDict.insert(segment_type_s,
                           (segmentInfo.is_large ? large_s : small_s));
        segmentDict.insert(segment_pool_id_s,
                           std::tuple<int64_t, int64_t>(segmentInfo.owner_private_pool_id));
        segmentDict.insert(is_expandable_s, segmentInfo.is_expandable);
        add_frame_key(segmentDict, segmentInfo.context_when_allocated);

        // Block addresses are derived: start at the segment address and
        // advance by each block's size.
        auto address = segmentInfo.address;
        auto blocks = new_list();
        for (const auto& blockInfo : segmentInfo.blocks) {
            auto blockDict = new_dict();
            blockDict.insert(address_s, address);
            blockDict.insert(size_s, blockInfo.size);
            blockDict.insert(requested_size_s, blockInfo.requested_size);
            blockDict.insert(state_s,
                             (blockInfo.allocated
                                  ? active_allocated_s
                                  : (blockInfo.active ? active_pending_free_s
                                                      : inactive_s)));
            add_frame_key(blockDict, blockInfo.context_when_allocated);
            address += blockInfo.size;
            blocks.push_back(blockDict);
        }
        segmentDict.insert(blocks_s, blocks);
        return segmentDict;
    };

    auto snapshot = c10_npu::NPUCachingAllocator::snapshot();

    auto segments = new_list();
    for (const auto& segmentInfo : snapshot.segments) {
        segments.push_back(segmentInfoToDict(segmentInfo));
    }

    auto traces = new_list();
    c10::IValue action_s = "action";
    c10::IValue alloc_s = "alloc";
    c10::IValue free_requested_s = "free_requested";
    c10::IValue free_completed_s = "free_completed";
    c10::IValue segment_alloc_s = "segment_alloc";
    c10::IValue segment_free_s = "segment_free";
    c10::IValue segment_map_s = "segment_map";
    c10::IValue segment_unmap_s = "segment_unmap";
    c10::IValue snapshot_s = "snapshot";
    c10::IValue oom_s = "oom";
    c10::IValue device_free_s = "device_free";

    using namespace c10_npu::NPUCachingAllocator;

    // Maps a trace action to the string stored under "action".
    auto action_to_str = [&](TraceEntry::Action action) {
        switch (action) {
            case TraceEntry::ALLOC:
                return alloc_s;
            case TraceEntry::FREE_REQUESTED:
                return free_requested_s;
            case TraceEntry::FREE_COMPLETED:
                return free_completed_s;
            case TraceEntry::SEGMENT_ALLOC:
                return segment_alloc_s;
            case TraceEntry::SEGMENT_FREE:
                return segment_free_s;
            case TraceEntry::OOM:
                return oom_s;
            case TraceEntry::SNAPSHOT:
                return snapshot_s;
            case TraceEntry::SEGMENT_UNMAP:
                return segment_unmap_s;
            case TraceEntry::SEGMENT_MAP:
                return segment_map_s;
            default:
                AT_ERROR("invalid TraceEntry action");
        }
        throw std::runtime_error("unreachable");
    };

    for (const auto& traceInfo : snapshot.device_traces) {
        auto trace = new_list();
        for (const auto& te : traceInfo) {
            auto trace_entry = new_dict();
            trace_entry.insert(action_s, action_to_str(te.action_));
            // For OOM entries the addr_ field is stored under "device_free"
            // instead of "addr".
            trace_entry.insert(te.action_ == TraceEntry::OOM ? device_free_s
                                                             : addr_s,
                               te.addr_);
            trace_entry.insert(size_s, static_cast<int64_t>(te.size_));
            trace_entry.insert(stream_s, int64_t(te.stream_));
            if (te.context_) {
                auto sc = getFromContext(te.context_);
                frame_tracebacks.push_back(sc);
                frame_dict.push_back(trace_entry);
            }
            trace.push_back(trace_entry);
        }
        traces.push_back(trace);
    }

    auto result = new_dict();
    result.insert("segments", segments);
    result.insert("device_traces", traces);

    // Symbolize every collected context in one batch, then attach the frames
    // to the dicts registered alongside them.
    auto frames = ivalue_symbolize(frame_tracebacks);
    for (auto i : c10::irange(frames.size())) {
        frame_dict.at(i).insert(frames_s, frames.at(i));
    }

    return write_pickle(result);
}
}