#pragma once
#include <limits>
#include <vector>
#include <memory>
#include <c10/util/flat_hash_map.h>
#include <third_party/acl/inc/acl/acl_base.h>
#include <third_party/acl/inc/acl/acl_rt.h>
#include <torch_npu/csrc/core/npu/NPUStream.h>
#include <torch_npu/csrc/core/npu/NPUCachingAllocator.h>
#include "NPUVmmApi.h"
using c10_npu::NPUCachingAllocator::DeviceStats;
using c10_npu::NPUCachingAllocator::RecordContext;
using c10_npu::NPUCachingAllocator::SegmentInfo;
using c10_npu::NPUCachingAllocator::Stat;
using c10_npu::NPUCachingAllocator::StatArray;
using c10_npu::NPUCachingAllocator::StatType;
using c10_npu::NPUCachingAllocator::TraceEntry;
using OutOfMemoryObserver =
std::function<void(int64_t device, int64_t allocated, int64_t device_total, int64_t device_free)>;
struct History {
void *addr;
size_t real_size;
std::shared_ptr<c10::GatheredContext> context;
};
struct BlockInfo {
int64_t size = 0;
int64_t requested_size = 0;
int32_t gc_counter = 0;
bool allocated = false;
bool active = false;
std::shared_ptr<c10::GatheredContext> context_when_allocated;
std::vector<History> history;
};
using stream_set = ska::flat_hash_set<c10_npu::NPUStream>;
using CreateContextFn = std::shared_ptr<c10::GatheredContext> (*)(void);
constexpr size_t kMinBlockSize = 512;
constexpr size_t kSmallSize = 1048576;
constexpr size_t kSmallBuffer = 2097152;
constexpr size_t kLargeBuffer = 20971520;
constexpr size_t kMinLargeAlloc = 10485760;
constexpr size_t kRoundLarge = 2097152;
constexpr size_t kGranularity = 2097152;
using StatTypes = std::array<bool, static_cast<size_t>(StatType::NUM_TYPES)>;
void update_stat(Stat &stat, int64_t amount);
void reset_accumulated_stat(Stat &stat);
void reset_peak_stat(Stat &stat);
template <typename Func> void for_each_selected_stat_type(const StatTypes &stat_types, Func f)
{
for (const auto stat_type : c10::irange(stat_types.size())) {
if (stat_types[stat_type]) {
f(stat_type);
}
}
}
void update_stat_array(StatArray &stat_array, int64_t amount, const StatTypes &stat_types);
struct Block;
using Comparison = bool (*)(const Block *, const Block *);
struct BlockPool {
BlockPool(Comparison comparator, bool small) : blocks(comparator), is_small(small) {}
std::set<Block *, Comparison> blocks;
std::unordered_set<size_t> hash;
const bool is_small;
};
struct HistoryChain {
History h;
std::unique_ptr<HistoryChain> next;
};
struct Block {
int device;
aclrtStream stream;
stream_set stream_uses;
size_t size;
size_t requested_size;
size_t actual_size;
BlockPool *pool{ nullptr };
void *ptr{ nullptr };
bool allocated{ false };
Block *prev{ nullptr };
Block *next{ nullptr };
int event_count{ 0 };
int gc_count{ 0 };
std::unique_ptr<HistoryChain> history{ nullptr };
HistoryChain *history_last{ nullptr };
std::shared_ptr<VmmSegment> vmm_segment;
size_t ptr_hash;
Block(int device, aclrtStream stream, size_t size, BlockPool *pool, void *ptr)
: device(device),
stream(stream),
stream_uses(),
size(size),
actual_size(0),
requested_size(0),
pool(pool),
ptr(ptr)
{
ptr_hash = reinterpret_cast<size_t>(ptr);
}
Block(int device, aclrtStream stream, size_t size)
: device(device),
stream(stream),
stream_uses(),
size(size),
actual_size(0),
requested_size(0)
{
ptr_hash = 0;
}
bool is_split() const
{
return (prev != nullptr) || (next != nullptr);
}
void splice(Block *before, Block *after)
{
if (before) {
before->next = this;
}
prev = before;
if (after) {
after->prev = this;
}
next = after;
}
};
struct BlockHash {
size_t operator () (const Block *b) const
{
return b->ptr_hash;
}
};
bool BlockComparator(const Block *a, const Block *b);
using EventOrderedBlockSet = std::unordered_set<Block *, BlockHash>;
using SetIterator = EventOrderedBlockSet::iterator;
struct BlockEventOrderPool {
BlockEventOrderPool() : pool_size(0) {}
void insert(Block *block)
{
if (blocks.count(block) == 0) {
blocks.insert(block);
pool_size += block->size;
}
}
bool erase(Block *block)
{
if (blocks.count(block)) {
blocks.erase(block);
pool_size -= block->size;
return true;
} else {
return false;
}
}
SetIterator erase(SetIterator it)
{
if (blocks.count(*it)) {
pool_size -= (*it)->size;
return blocks.erase(it);
} else {
return blocks.end();
}
}
EventOrderedBlockSet blocks;
size_t pool_size;
};
inline std::string format_size(uint64_t size)
{
std::ostringstream os;
os.precision(2);
os << std::fixed;
if (size <= 1024) {
os << size << " bytes";
} else if (size <= 1048576) {
os << (size / 1024.0);
os << " KiB";
} else if (size <= 1073741824ULL) {
os << (size / 1048576.0);
os << " MiB";
} else {
os << (size / 1073741824.0);
os << " GiB";
}
return os.str();
}
struct AllocParams {
AllocParams(int device, size_t size, aclrtStream stream, BlockPool *pool, size_t alloc_size, DeviceStats &stats)
: search_key(device, stream, size), pool(pool), alloc_size(alloc_size), block(nullptr), err(ACL_ERROR_NONE)
{}
int device() const
{
return search_key.device;
}
aclrtStream stream() const
{
return search_key.stream;
}
size_t size() const
{
return search_key.size;
}
Block search_key;
BlockPool *pool;
size_t alloc_size;
Block *block;
StatTypes stat_types = { false };
aclError err;
};