#pragma once
#include "common.h"
#include "EventPool.h"
#include "CachingAllocatorConfig.h"
/**
 * Per-device caching allocator: owns the cached block pools, bookkeeping and
 * statistics for one device and exposes malloc/free-style entry points.
 *
 * Only the declaration is in this header; method bodies live elsewhere, so the
 * comments below describe what the declarations show and hedge anything that
 * depends on the implementation.
 *
 * NOTE(review): member declaration order determines both construction order
 * and object layout. The constructor's initializer list (large_blocks,
 * free_fused_blocks, small_blocks, alloc_trace) follows declaration order;
 * keep it that way if members are added or moved.
 *
 * The class is implicitly non-copyable: std::recursive_mutex and std::atomic
 * members delete the copy operations.
 */
class DeviceCachingAllocator {
private:
// Recursive mutex. It is `mutable`, presumably so const member functions such
// as get_all_blocks() can lock it.
// NOTE(review): presumably guards all state below — confirm in the implementation.
mutable std::recursive_mutex mutex;
// Allocator statistics; stats.max_split_size is seeded from
// CachingAllocatorConfig::max_split_size() in the constructor.
DeviceStats stats;
// Cached block pools, all ordered by BlockComparator. large_blocks and
// free_fused_blocks are constructed with `false`, small_blocks with `true`
// (presumably an "is small pool" flag — confirm in BlockPool).
BlockPool large_blocks;
// NOTE(review): by name and stitch_block(), presumably free fused (stitched) blocks — confirm.
BlockPool free_fused_blocks;
// Per-stream BlockEventOrderPool maps. By name, these presumably track free
// fused blocks in release order and fragmented free fused blocks — confirm.
std::unordered_map<aclrtStream, BlockEventOrderPool> free_fused_blocks_in_release_order;
std::unordered_map<aclrtStream, BlockEventOrderPool> fragmented_free_fused_blocks;
BlockPool small_blocks;
// Sets of block pointers. Presumably: blocks currently handed out (active_blocks),
// fused blocks in use (active_fused_blocks), and fused blocks awaiting garbage
// collection (active_fused_blocks_to_gc) — confirm against the implementation.
ska::flat_hash_set<Block *> active_blocks;
ska::flat_hash_set<Block *> active_fused_blocks;
ska::flat_hash_set<Block *> active_fused_blocks_to_gc;
// Number of captures currently in progress, and blocks whose event insertion
// is deferred until none are (see insert_events_deferred_until_no_capture()).
int captures_underway = 0;
std::vector<Block *> needs_events_deferred_until_no_capture;
// Per-stream queue of (event, block) pairs.
// NOTE(review): presumably outstanding events polled by process_events() /
// drained by synchronize_and_free_events() — confirm.
ska::flat_hash_map<c10_npu::NPUStream, std::deque<std::pair<EventPool::Event, Block *>>> npu_events;
// Byte counters and limits, all zero-initialized. allowed_memory_maximum and
// set_fraction are presumably written by setMemoryFraction() — confirm.
size_t total_allocated_memory = 0;
size_t total_fuse_size = 0;
size_t allowed_memory_maximum = 0;
bool set_fraction = false;
// Context-recording callback (type CreateContextFn); cleared to nullptr in the
// constructor. Stored atomically, presumably so it can be read without the mutex.
std::atomic<CreateContextFn> context_recorder_;
// Allocation-trace state: next write index, context-recording flag and policy,
// and the maximum number of retained entries (presumably a bounded/circular
// trace — confirm in the implementation).
size_t alloc_trace_next = 0;
bool alloc_trace_record_context_ = false;
RecordContext record_context_ = RecordContext::NEVER;
size_t alloc_trace_max_entries_ = 1;
// Heap-allocated with `new` in the constructor. No destructor is declared in
// this class, so the vector is never deleted here.
// NOTE(review): if the leak is deliberate (e.g. so the trace outlives
// process/runtime teardown), document the reason; otherwise prefer std::unique_ptr.
std::vector<TraceEntry> *alloc_trace;
// Out-of-memory observers (no registration method is declared in this class).
std::vector<OutOfMemoryObserver> oom_observers_;
public:
// Builds the three pools with BlockComparator, allocates the trace buffer,
// seeds stats.max_split_size from configuration and clears the context recorder.
DeviceCachingAllocator()
: large_blocks(BlockComparator, false),
free_fused_blocks(BlockComparator, false),
small_blocks(BlockComparator, true),
alloc_trace(new std::vector<TraceEntry>())
{
stats.max_split_size = CachingAllocatorConfig::max_split_size();
context_recorder_.store(nullptr);
}
// Returns a Block serving an `orig_size`-byte request on `device` for `stream`.
Block *malloc(int device, size_t orig_size, aclrtStream stream);
// Finalizes an allocation from a found block; `split_remainder` presumably
// controls whether the unused tail is split off — confirm.
Block *alloc_found_block(AllocParams params, size_t orig_size, bool split_remainder);
// Returns `block` to the allocator.
void free(Block *block);
void update_block(Block *block);
// Returns the base pointer of the allocation containing `block`; the size is
// written through `outSize` (presumably optional/nullable — confirm).
void *getBaseAllocation(Block *block, size_t *outSize);
// Add/remove `stream` as a user of `block`.
void recordStream(Block *block, c10_npu::NPUStream stream);
void eraseStream(Block *block, c10_npu::NPUStream stream);
// Caps usable device memory to `fraction` of the device (presumably in [0, 1] — confirm).
void setMemoryFraction(double fraction);
// Releases cached memory; `check_error` presumably controls error checking — confirm.
void emptyCache(bool check_error);
// Reports total cached bytes and the largest cached block through the out-parameters.
void cacheInfo(size_t *total, size_t *largest);
DeviceStats getStats();
void resetAccumulatedStats();
void resetPeakStats();
// Returns a description of the allocator's segments.
std::vector<SegmentInfo> snapshot();
// Presumably rounds a request size up to the allocator's granularity — confirm.
static size_t round_size(size_t size);
private:
std::vector<const Block *> get_all_blocks() const;
// Block release / merging helpers.
void free_block(Block *block, bool flag);
bool need_merge(Block *dst, Block *src);
size_t try_merge_blocks(Block *dst, Block *src, BlockPool &pool);
// Pool selection and stat classification.
BlockPool &get_pool(size_t size);
StatType get_stat_type_for_pool(const BlockPool &pool);
StatTypes get_stat_types_for_pool(const BlockPool &pool);
bool should_split(const Block *block, size_t size);
static size_t get_allocation_size(size_t size);
// Allocation search / recovery path: find a cached block, run free-memory
// callbacks, garbage-collect, allocate fresh memory, or release cached blocks.
bool get_free_block(AllocParams &p);
bool trigger_free_memory_callbacks(AllocParams &p);
void garbage_collect_cached_blocks();
bool realloc_block(AllocParams &p, bool isRetry);
bool release_available_cached_blocks(const AllocParams &p);
bool release_cached_blocks();
void release_block(Block *block);
void release_blocks(BlockPool &pool);
// Event bookkeeping for cross-stream usage and deferred (capture-time) insertion.
EventPool::Event create_event_internal(int idx);
void synchronize_and_free_events();
void insert_events(Block *block);
void insert_free_event_into_alloc_stream(Block *block);
void insert_events_deferred_until_no_capture();
void process_events();
void cache_info_aux(BlockPool &blocks, size_t *total, size_t *largest);
// Fused ("stitched") block management.
bool get_fused_fragmented_blocks(AllocParams &p, int time);
bool release_swapout_blocks();
Block *stitch_block(std::vector<Block *> &blocks2fuse, AllocParams &p);
Block *split_large_block(Block *block, size_t request_size);
void release_large_block(Block *block);
void activate_large_block(Block *block);
void deactivate_large_block(Block *block);
size_t garbage_collect_fused_blocks(int time, size_t require_size = 0);
};