#pragma once
#include <sys/types.h>
#include <iostream>
#include <torch/extension.h>
#include "torch_npu/csrc/core/npu/NPUCachingAllocator.h"
#include <c10/util/UniqueVoidPtr.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/irange.h>
#include "acl_base.h"
#include "acl_rt.h"
#include "torch_npu/csrc/core/npu/NPUBlockHandle.h"
#include "torch_npu/csrc/core/npu/NPUEvent.h"
#include "torch_npu/csrc/core/npu/NPUGuard.h"
#include "torch_npu/csrc/core/npu/interface/AsyncTaskQueueInterface.h"
#include "torch_npu/csrc/core/npu/sys_ctrl/npu_sys_ctrl.h"
#include <torch/csrc/python_headers.h>
#include <atomic>
#include <mutex>
using c10_npu::NPUCachingAllocator::BlockInfo;
using c10_npu::NPUCachingAllocator::DeviceStats;
using c10_npu::NPUCachingAllocator::RecordContext;
using c10_npu::NPUCachingAllocator::SegmentInfo;
using c10_npu::NPUCachingAllocator::Stat;
using c10_npu::NPUCachingAllocator::StatArray;
using c10_npu::NPUCachingAllocator::StatType;
using c10_npu::NPUCachingAllocator::TraceEntry;
using stream_set = ska::flat_hash_set<c10_npu::NPUStream>;
// Checks an ACL status code.
//  - ACL_ERROR_NONE: nothing happens.
//  - ACL_ERROR_RT_FEATURE_NOT_SUPPORT: prints a warning once per expansion site (the message
//    suggests a driver/firmware mismatch) and continues.
//  - any other status: raises through TORCH_CHECK with the current NPU error message.
// Arguments after err_code are accepted but ignored.
#define NPU_CHECK_SUPPORT_OR_ERROR(err_code, ...) \
do { \
    auto Error = err_code; \
    if ((Error) != ACL_ERROR_NONE) { \
        if ((Error) == ACL_ERROR_RT_FEATURE_NOT_SUPPORT) { \
            /* __FUNCTION__ inside the lambda would name the lambda's operator(), */ \
            /* so take the enclosing function's name before entering it. */ \
            const char *npu_check_caller = __FUNCTION__; \
            /* A function-local static is initialised once, so the warning prints once. */ \
            static auto feature_not_support_warn_once = [npu_check_caller]() { \
                printf("[WARN]%s,%s:%d:%s\n", \
                    npu_check_caller, __FILENAME__, __LINE__, \
                    "Feature is not supported and the possible cause is" \
                    " that driver and firmware packages do not match."); \
                return true; \
            }(); \
            (void)feature_not_support_warn_once; \
        } else { \
            TORCH_CHECK( \
                false, \
                __func__, \
                ":", \
                __FILE__, \
                ":", \
                __LINE__, \
                "\n", c10_npu::c10_npu_get_error_message()); \
        } \
    } \
} while (0)
// Factory that captures the current call context for allocation-history recording.
using CreateContextFn = std::shared_ptr<c10::GatheredContext> (*)();
// Allocator sizing constants, all in bytes.
constexpr size_t kUnitMB = 1024 * 1024;          // 1 MiB
constexpr size_t kMinBlockSize = 512;            // block sizes are rounded to multiples of this
constexpr size_t kSmallSize = 1 * kUnitMB;       // requests up to this size use the small pool
constexpr size_t kSmallBuffer = 2 * kUnitMB;     // segment size allocated for small-pool requests
constexpr size_t kLargeBuffer = 20 * kUnitMB;    // segment size for large requests below kMinLargeAlloc
constexpr size_t kMinLargeAlloc = 10 * kUnitMB;  // requests at/above this are rounded up to kRoundLarge
constexpr size_t kRoundLarge = 2 * kUnitMB;      // rounding granularity for big allocations
// One flag per StatType bucket; a set flag selects that bucket for a stat update.
using StatTypes = std::array<bool, static_cast<std::size_t>(StatType::NUM_TYPES)>;
// Applies a signed delta to `stat`. `current` moves by `amount` and `peak` tracks the running
// maximum of `current`. Positive deltas accumulate into `allocated`; negative deltas accumulate
// (as a magnitude) into `freed`.
// Declared `inline` because this header may be included from several translation units; a
// plain definition here would produce multiple-definition link errors.
inline void update_stat(Stat &stat, int64_t amount) {
    stat.current += amount;
    stat.peak = std::max(stat.current, stat.peak);
    if (amount > 0) {
        stat.allocated += amount;
    }
    if (amount < 0) {
        stat.freed += -amount;
    }
}
// Clears the cumulative counters of `stat` (`allocated`, `freed`); `current` and `peak` are kept.
// `inline` avoids multiple definitions when this header is included from several TUs.
inline void reset_accumulated_stat(Stat &stat) {
    stat.allocated = 0;
    stat.freed = 0;
}
// Restarts peak tracking from the present value of `stat`.
// `inline` avoids multiple definitions when this header is included from several TUs.
inline void reset_peak_stat(Stat &stat) { stat.peak = stat.current; }
template <typename Func>
void for_each_selected_stat_type(const StatTypes &stat_types, Func f) {
for (const auto stat_type : c10::irange(stat_types.size())) {
if (stat_types[stat_type]) {
f(stat_type);
}
}
}
// Applies update_stat(…, amount) to every bucket of `stat_array` that is selected in `stat_types`.
// `inline` avoids multiple definitions when this header is included from several TUs.
inline void update_stat_array(StatArray &stat_array, int64_t amount, const StatTypes &stat_types) {
    for_each_selected_stat_type(stat_types,
        [&stat_array, amount](size_t stat_type) { update_stat(stat_array[stat_type], amount); });
}
struct Block;
using Comparison = bool (*)(const Block *, const Block *);
static bool BlockComparatorSize(const Block *a, const Block *b);
static bool BlockComparatorAddress(const Block *a, const Block *b);
// Free blocks of one size class (small or large).
// `blocks` holds mapped free blocks ordered by (stream, size, address) for best-fit lookup;
// `unmapped` holds expandable-segment ranges without physical memory, ordered by (stream, address).
struct BlockPool {
    BlockPool(bool small) : blocks(BlockComparatorSize), unmapped(BlockComparatorAddress), is_small(small) {}
    std::set<Block *, Comparison> blocks;
    std::set<Block *, Comparison> unmapped;
    const bool is_small;
};
struct ExpandableSegment;
// A contiguous piece of device memory managed by the caching allocator. Pieces carved out
// of the same segment are linked through prev/next; free pieces sit in their BlockPool.
struct Block {
    int device;                                        // device the memory belongs to
    aclrtStream stream;                                // stream the block was allocated for
    stream_set stream_uses;                            // other streams recorded as using the block
    size_t size;                                       // block size in bytes
    size_t requested_size = 0;                         // caller's original (unrounded) request
    BlockPool *pool = nullptr;                         // pool the block returns to when freed
    void *ptr = nullptr;                               // start address
    bool allocated = false;                            // currently handed out to a caller
    bool mapped = true;                                // backed by physical memory
    Block *prev = nullptr;                             // neighbouring block of the same segment
    Block *next = nullptr;                             // neighbouring block of the same segment
    int event_count = 0;                               // pending events recorded for this block
    int gc_count = 0;                                  // age counter used by garbage collection
    ExpandableSegment *expandable_segment_ = nullptr;  // owning expandable segment, if any
    std::shared_ptr<c10::GatheredContext> context_when_allocated;
    std::shared_ptr<c10::GatheredContext> context_when_segment_allocated;

    Block(int device, aclrtStream stream, size_t size, BlockPool *pool, void *ptr)
        : device(device), stream(stream), stream_uses(), size(size), pool(pool), ptr(ptr) {}

    // Lookup-key form: no owning pool and no address.
    Block(int device, aclrtStream stream, size_t size) : Block(device, stream, size, nullptr, nullptr) {}

    // True when this block shares its segment with at least one neighbour.
    bool is_split() const { return prev != nullptr || next != nullptr; }

    // Inserts this block between `before` and `after`, which must currently be adjacent.
    void splice(Block *before, Block *after) {
        if (before != nullptr) {
            TORCH_INTERNAL_ASSERT(before->next == after);
            before->next = this;
        }
        prev = before;
        if (after != nullptr) {
            TORCH_INTERNAL_ASSERT(after->prev == before);
            after->prev = this;
        }
        next = after;
    }
};
// A byte range [ptr, ptr + size) inside an expandable segment's reserved address space.
struct SegmentRange {
    SegmentRange(void *p, size_t s) : ptr{static_cast<char *>(p)}, size{s} {}
    char *ptr;
    size_t size;
};
// A segment backed by reserved virtual address space that grows on demand.
// The constructor reserves room for the whole device memory plus 1/8 headroom. map()/unmap()
// attach or detach physical memory in units of segment_size_ (one physical handle each).
struct ExpandableSegment {
    ExpandableSegment(int device, aclrtStream stream, size_t size)
        : device_(device),
          stream_(stream),
          max_handles_(0),
          segment_size_(size) {
        size_t device_free;
        size_t device_total;
        TORCH_CHECK(aclrtGetMemInfo(ACL_HBM_MEM, &device_free, &device_total) == ACL_ERROR_NONE, "aclrtGetMemInfo failed.");
        TORCH_INTERNAL_ASSERT(device_free <= device_total);
        // Reserve address space for the whole device plus 1/8 extra.
        constexpr size_t extra_space_factor = 8;
        max_handles_ = numSegments(device_total + device_total / extra_space_factor);
        TORCH_CHECK(aclrtReserveMemAddress(&ptr_, segment_size_ * max_handles_, 0, nullptr, 1) == ACL_ERROR_NONE,
            "Error, failed to reserve memory address");
    }

    // Backs [range.ptr, range.ptr + range.size) with physical memory. The end is rounded up to a
    // whole handle, and range.ptr must be handle-aligned. Returns the mapped range. An empty range
    // means physical memory ran out; handles created by this call are released again in that case.
    // Any other physical-allocation or mapping failure raises.
    SegmentRange map(SegmentRange range) {
        auto begin = segmentLeft(range.ptr);
        auto end = segmentRight(range.ptr + range.size);
        TORCH_INTERNAL_ASSERT(ptr() + begin * segment_size_ == range.ptr);
        if (begin == end) {
            return rangeFromHandles(begin, end);
        }
        while (end > handles_.size()) {
            handles_.emplace_back(c10::nullopt);
        }
        for (auto i : c10::irange(begin, end)) {
            TORCH_INTERNAL_ASSERT(!handles_.at(i));
            aclrtDrvMemHandle handle = nullptr;
            aclrtPhysicalMemProp prop = {};
            prop.handleType = ACL_MEM_HANDLE_TYPE_NONE;
            prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED;
            prop.memAttr = ACL_HBM_MEM_HUGE;
            prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE;
            prop.location.id = device_;
            prop.reserve = 0;
            auto status = aclrtMallocPhysical(&handle, segment_size_, &prop, 0);
            if (status != ACL_ERROR_NONE) {
                // Roll back the handles obtained earlier in this call before reporting failure.
                for (auto j : c10::irange(begin, i)) {
                    auto h = handles_.at(j).value();
                    handles_.at(j) = c10::nullopt;
                    TORCH_CHECK(aclrtFreePhysical(h) == ACL_ERROR_NONE, "aclrtFreePhysical failed.");
                }
                trimHandles();
                if (status == ACL_ERROR_RT_MEMORY_ALLOCATION) {
                    // Out of memory is recoverable: the caller may free cached blocks and retry.
                    return rangeFromHandles(begin, begin);
                }
                // Any other error would otherwise leave a null handle to be mapped below.
                TORCH_CHECK(false, "aclrtMallocPhysical failed with error code ", status, ".");
            }
            handles_.at(i) = handle;
        }
        for (auto i : c10::irange(begin, end)) {
            // ptr() is a char*, so the offset arithmetic is well-defined (unlike on void*).
            TORCH_CHECK(aclrtMapMem(ptr() + i * segment_size_, segment_size_, 0, handles_.at(i).value(), 0) == ACL_ERROR_NONE,
                "Error, failed to map memory");
        }
        return rangeFromHandles(begin, end);
    }

    // Releases the physical memory of every handle that lies entirely inside `range`.
    // Returns the released range, which may be empty.
    SegmentRange unmap(SegmentRange range) {
        auto begin = segmentRight(range.ptr);
        auto end = segmentLeft(range.ptr + range.size);
        if (begin >= end) {
            return SegmentRange{range.ptr, 0};
        }
        unmapHandles(begin, end);
        return rangeFromHandles(begin, end);
    }

    char *ptr() const { return static_cast<char *>(ptr_); }
    size_t size() const { return max_handles_ * segment_size_; }

    ~ExpandableSegment() {
        forEachAllocatedRange([&](size_t begin, size_t end) { unmapHandles(begin, end); });
        TORCH_CHECK(aclrtReleaseMemAddress(ptr_) == ACL_ERROR_NONE, "aclrtReleaseMemAddress failed.");
    }

private:
    // Synchronizes stream_, then unmaps and frees handles [begin, end).
    void unmapHandles(size_t begin, size_t end) {
        TORCH_CHECK(aclrtSynchronizeStream(stream_) == ACL_ERROR_NONE, "aclrtSynchronizeStream failed.");
        for (auto i : c10::irange(begin, end)) {
            aclrtDrvMemHandle h = handles_.at(i).value();
            handles_.at(i) = c10::nullopt;
            TORCH_CHECK(aclrtUnmapMem(ptr() + segment_size_ * i) == ACL_ERROR_NONE, "aclrtUnmapMem failed.");
            TORCH_CHECK(aclrtFreePhysical(h) == ACL_ERROR_NONE, "aclrtFreePhysical failed.");
        }
        trimHandles();
    }

    // Drops trailing empty handle slots.
    void trimHandles() {
        while (!handles_.empty() && !handles_.back()) {
            handles_.pop_back();
        }
    }

    // Calls fn(begin, end) for each maximal run of populated handle slots.
    void forEachAllocatedRange(std::function<void(size_t, size_t)> fn) {
        size_t start = 0;
        for (auto i : c10::irange(handles_.size())) {
            if (handles_.at(i) && (i == 0 || !handles_.at(i - 1))) {
                start = i;
            }
            if (handles_.at(i) && (i + 1 == handles_.size() || !handles_.at(i + 1))) {
                fn(start, i + 1);
            }
        }
    }

    // Number of handles needed to cover `size` bytes (rounded up).
    size_t numSegments(size_t size) { return (size + segment_size_ - 1) / segment_size_; }

    // Index of the handle containing p (rounded down).
    size_t segmentLeft(char *p) {
        auto size = p - ptr();
        return size / segment_size_;
    }

    // Index of the first handle boundary at or after p (rounded up).
    size_t segmentRight(char *p) {
        auto size = p - ptr();
        return numSegments(size);
    }

    SegmentRange rangeFromHandles(size_t begin, size_t end) {
        TORCH_INTERNAL_ASSERT(end >= begin);
        return SegmentRange(ptr() + segment_size_ * begin, segment_size_ * (end - begin));
    }

    int device_;
    aclrtStream stream_;
    void *ptr_{};
    size_t max_handles_;
    size_t segment_size_;
    std::vector<c10::optional<aclrtDrvMemHandle>> handles_;
};
// Strict weak ordering by (stream, size, address); used for best-fit lookup in BlockPool::blocks.
static bool BlockComparatorSize(const Block *a, const Block *b) {
    const auto lhs_stream = reinterpret_cast<uintptr_t>(a->stream);
    const auto rhs_stream = reinterpret_cast<uintptr_t>(b->stream);
    if (lhs_stream == rhs_stream) {
        if (a->size == b->size) {
            return reinterpret_cast<uintptr_t>(a->ptr) < reinterpret_cast<uintptr_t>(b->ptr);
        }
        return a->size < b->size;
    }
    return lhs_stream < rhs_stream;
}
// Strict weak ordering by (stream, address); used for BlockPool::unmapped.
static bool BlockComparatorAddress(const Block *a, const Block *b) {
    const auto lhs_stream = reinterpret_cast<uintptr_t>(a->stream);
    const auto rhs_stream = reinterpret_cast<uintptr_t>(b->stream);
    if (lhs_stream == rhs_stream) {
        return reinterpret_cast<uintptr_t>(a->ptr) < reinterpret_cast<uintptr_t>(b->ptr);
    }
    return lhs_stream < rhs_stream;
}
// Renders a byte count for log and error messages. Values up to 1 KiB print as plain bytes;
// larger ones print in KiB / MiB / GiB with two decimals. Each upper bound is inclusive,
// e.g. 1024 prints as "1024 bytes" and 1048576 as "1024.00 KiB".
inline std::string format_size(uint64_t size) {
    constexpr uint64_t kKiB = 1024;
    constexpr uint64_t kMiB = kKiB * 1024;
    constexpr uint64_t kGiB = kMiB * 1024;
    std::ostringstream out;
    out << std::fixed;
    out.precision(2);
    if (size > kGiB) {
        out << static_cast<double>(size) / static_cast<double>(kGiB) << " GiB";
    } else if (size > kMiB) {
        out << static_cast<double>(size) / static_cast<double>(kMiB) << " MiB";
    } else if (size > kKiB) {
        out << static_cast<double>(size) / static_cast<double>(kKiB) << " KiB";
    } else {
        out << size << " bytes";
    }
    return out.str();
}
// State for a single malloc() request: the lookup key (device, stream, rounded size), the target
// pool and segment size, plus the outcome (chosen block and ACL error code).
struct AllocParams {
    // The DeviceStats argument is accepted for interface compatibility; it is not used here.
    AllocParams(int device, size_t size, aclrtStream stream, BlockPool *pool, size_t alloc_size,
        DeviceStats & /* stats */)
        : search_key(device, stream, size), pool(pool), alloc_size(alloc_size) {}

    int device() const { return search_key.device; }
    aclrtStream stream() const { return search_key.stream; }
    size_t size() const { return search_key.size; }

    Block search_key;
    BlockPool *pool;
    size_t alloc_size;
    Block *block = nullptr;
    StatTypes stat_types = {false};
    aclError err = ACL_ERROR_NONE;
};
// Per-device free lists of NPUEvent objects. Handles returned by get() put their event back
// into the owning device's list when destroyed, instead of deleting it.
class EventPool {
public:
    using Event = std::unique_ptr<c10_npu::NPUEvent, std::function<void(c10_npu::NPUEvent *)>>;

    EventPool() : pools_(c10_npu::device_count()) {}

    // Returns an event for `device`, reusing a pooled one when available.
    Event get(int device) {
        TORCH_INTERNAL_ASSERT(0 <= device);
        TORCH_INTERNAL_ASSERT(device < static_cast<int>(pools_.size()));
        PerDevicePool &slot = pools_[device];
        auto recycle = [&slot](c10_npu::NPUEvent *event) {
            std::lock_guard<std::mutex> guard(slot.mutex_);
            slot.event_pool_.push_back(std::unique_ptr<c10_npu::NPUEvent>(event));
        };
        c10_npu::NPUEvent *reused = nullptr;
        {
            std::lock_guard<std::mutex> guard(slot.mutex_);
            if (!slot.event_pool_.empty()) {
                reused = slot.event_pool_.back().release();
                slot.event_pool_.pop_back();
            }
        }
        if (reused != nullptr) {
            return Event(reused, recycle);
        }
        return Event(new c10_npu::NPUEvent(ACL_EVENT_CAPTURE_STREAM_PROGRESS), recycle);
    }

    // Destroys every pooled (currently unused) event on all devices.
    void empty_cache() {
        for (PerDevicePool &slot : pools_) {
            std::lock_guard<std::mutex> guard(slot.mutex_);
            slot.event_pool_.clear();
        }
    }

private:
    struct PerDevicePool {
        alignas(64) std::mutex mutex_;
        std::vector<std::unique_ptr<c10_npu::NPUEvent>> event_pool_;
    };
    std::vector<PerDevicePool> pools_;
};
// Process-wide allocator settings parsed from PYTORCH_NPU_ALLOC_CONF on first use.
class CachingAllocatorConfig {
public:
    static size_t max_split_size() { return instance().m_max_split_size; }
    static double garbage_collection_threshold() { return instance().m_garbage_collection_threshold; }
    static bool expandable_segments() { return instance().m_expandable_segments; }

    // The singleton is built once and intentionally never destroyed.
    static CachingAllocatorConfig &instance() {
        static CachingAllocatorConfig *s_instance = []() {
            auto *config = new CachingAllocatorConfig();
            config->parseArgs(getenv("PYTORCH_NPU_ALLOC_CONF"));
            return config;
        }();
        return *s_instance;
    }

    void parseArgs(const char *env);

private:
    size_t m_max_split_size;
    double m_garbage_collection_threshold;
    bool m_expandable_segments;
    bool set_expandable_segments_flag = false;

    // Expandable segments default to on only when a virtual-address reservation succeeds.
    CachingAllocatorConfig()
        : m_max_split_size(std::numeric_limits<size_t>::max()),
          m_garbage_collection_threshold(0),
          m_expandable_segments(virtual_memory_reservation_supported()) {}

    // Probes support by reserving, then releasing, a small virtual address range.
    static bool virtual_memory_reservation_supported() {
        void *probe = nullptr;
        constexpr size_t virtual_mem_size = 512;
        if (aclrtReserveMemAddress(&probe, virtual_mem_size, 0, nullptr, 1) != ACL_ERROR_NONE) {
            return false;
        }
        TORCH_CHECK(aclrtReleaseMemAddress(probe) == ACL_ERROR_NONE, "aclrtReleaseMemAddress failed.");
        return true;
    }

    void lexArgs(const char *env, std::vector<std::string> &config);
    void consumeToken(const std::vector<std::string> &config, size_t i, const char c);
    size_t parseMaxSplitSize(const std::vector<std::string> &config, size_t i);
    size_t parseGarbageCollectionThreshold(const std::vector<std::string> &config, size_t i);
    size_t parseExpandableSegments(const std::vector<std::string> &config, size_t i);
};
class DeviceCachingAllocator {
private:
mutable std::recursive_mutex mutex;
DeviceStats stats;
BlockPool large_blocks;
BlockPool small_blocks;
ska::flat_hash_set<Block *> active_blocks;
ska::flat_hash_map<c10_npu::NPUStream, std::deque<std::pair<EventPool::Event, Block *>>> npu_events;
size_t total_allocated_memory = 0;
size_t allowed_memory_maximum = 0;
std::vector<ExpandableSegment *> expandable_segments_;
bool set_fraction = false;
bool record_history = false;
std::atomic<CreateContextFn> context_recorder_;
size_t alloc_trace_next = 0;
RecordContext record_context_ = RecordContext::NEVER;
size_t alloc_trace_max_entries_ = 1;
std::vector<TraceEntry> *alloc_trace;
public:
// Sets up the small and large pools and an empty trace buffer, with no context recorder installed.
DeviceCachingAllocator()
    : large_blocks(/* small = */ false), small_blocks(/* small = */ true), alloc_trace(new std::vector<TraceEntry>())
{
    context_recorder_.store(nullptr);
    stats.max_split_size = static_cast<int64_t>(CachingAllocatorConfig::max_split_size());
}
// Captures the current call context when history recording is at least `level`.
// Returns nullptr when recording is below `level` or when no recorder has been installed.
std::shared_ptr<c10::GatheredContext> maybeGatherContext(RecordContext level)
{
    if (record_context_ < level) {
        return nullptr;
    }
    // Load the atomic once. A recording level may be set while the recorder is still null,
    // and calling a null function pointer would crash.
    CreateContextFn recorder = context_recorder_.load();
    if (recorder == nullptr) {
        return nullptr;
    }
    return recorder();
}
Block *malloc(int device, size_t orig_size, aclrtStream stream) {
auto context = maybeGatherContext(RecordContext::STATE);
std::unique_lock<std::recursive_mutex> lock(mutex);
if (device == -1) {
TORCH_CHECK(c10_npu::GetDevice(&device) == ACL_ERROR_NONE, "GetDevice failed.");
}
process_events(context);
auto size = round_size(orig_size);
auto &pool = get_pool(size);
const size_t alloc_size = get_allocation_size(size);
AllocParams params(device, size, stream, &pool, alloc_size, stats);
params.stat_types = get_stat_types_for_pool(pool);
bool block_found = false;
while (!block_found) {
block_found =
get_free_block(params) ||
(trigger_free_memory_callbacks(params) && get_free_block(params));
if (!block_found) {
if (C10_UNLIKELY(set_fraction && CachingAllocatorConfig::garbage_collection_threshold() > 0.0)) {
garbage_collect_cached_blocks(context);
}
block_found = alloc_block(params, false, context, lock) ||
(release_available_cached_blocks(params, context) && alloc_block(params, false, context, lock));
}
if (!block_found) {
ASCEND_LOGE(
"Get a block from the existing pool failed. Try to free cached blocks and reallocate. This error log "
"can be ignored.");
block_found = (release_cached_blocks(true, context) && alloc_block(params, true, context, lock));
}
if (!block_found) {
if (params.err == ACL_ERROR_NONE) {
break;
}
PyGILState_STATE state = PyGILState_Ensure();
PyObject *pModule = PyImport_ImportModule("mindspeed.core.memory.common");
if (!pModule) {
PyGILState_Release(state);
std::cout << "No MindSpeed Module" << std::endl;
break;
}
PyObject *pFunc = PyObject_GetAttrString(pModule, "swap_out_by_size");
PyObject *pArgs = PyTuple_New(1);
TORCH_CHECK(PyTuple_SetItem(pArgs, 0, PyLong_FromLong(size)) == 0, "PyTuple_SetItem failed.");
PyObject *pResult = PyObject_CallObject(pFunc, pArgs);
bool ret = false;
TORCH_CHECK(PyArg_Parse(pResult, "p", &ret), "PyArg_Parse failed.");
PyGILState_Release(state);
if (!ret) {
std::cout << "SWAP Failed" << std::endl;
break;
}
params.err = ACL_ERROR_NONE;
}
}
if (!block_found) {
if (params.err == ACL_ERROR_RT_MEMORY_ALLOCATION) {
size_t device_free;
size_t device_total;
TORCH_CHECK(aclrtGetMemInfo(ACL_HBM_MEM, &device_free, &device_total) == ACL_ERROR_NONE, "aclrtGetMemInfo failed.");
TORCH_INTERNAL_ASSERT(device_free <= device_total);
std::string allowed_info;
if (set_fraction) {
allowed_info = format_size(allowed_memory_maximum) + " allowed; ";
}
stats.num_ooms += 1;
AT_ERROR("NPU out of memory. Tried to allocate ", format_size(alloc_size), " (NPU ", device, "; ",
format_size(device_total), " total capacity; ",
format_size(stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)].current),
" already allocated; ",
format_size(stats.active_bytes[static_cast<size_t>(StatType::AGGREGATE)].current), " current active; ",
format_size(device_free), " free; ", allowed_info,
format_size(stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)].current),
" reserved in total by PyTorch)",
" If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.");
} else {
params.err;
}
}
bool split_remainder = should_split(params.block, params.size());
return alloc_found_block(std::move(params), orig_size, std::move(context), split_remainder);
}
// Turns the block chosen by malloc() into a live allocation.
// When split_remainder is true, a new `size`-byte block is carved from the front of
// params.block, and the tail (the original Block object) goes back into the pool.
// The block is then marked allocated, a TraceEntry::ALLOC is recorded, the block joins
// active_blocks, and the allocation/active/requested statistics are updated.
// Requires params.err == ACL_ERROR_NONE and a non-null params.block with a non-null ptr.
// Returns the allocated block.
Block *alloc_found_block(AllocParams params, size_t orig_size, std::shared_ptr<c10::GatheredContext> context,
bool split_remainder)
{
auto size = params.size();
auto device = params.device();
auto pool = params.pool;
auto stream = params.stream();
TORCH_INTERNAL_ASSERT(params.err == ACL_ERROR_NONE && params.block != nullptr && params.block->ptr != nullptr);
Block *block = params.block;
Block *remaining = nullptr;
const bool already_split = block->is_split();
if (split_remainder) {
// The found block keeps the tail; a new block takes the first `size` bytes
// and is linked in front of it.
remaining = block;
block = new Block(device, stream, size, pool, block->ptr);
block->expandable_segment_ = remaining->expandable_segment_;
block->prev = remaining->prev;
if (block->prev) {
block->prev->next = block;
}
block->next = remaining;
remaining->prev = block;
remaining->ptr = static_cast<char *>(remaining->ptr) + size;
remaining->size -= size;
pool->blocks.insert(remaining);
// inactive_split statistics are not maintained for expandable-segment blocks.
if (already_split && !block->expandable_segment_) {
// An already-split inactive block shrinks by `size` bytes.
update_stat_array(stats.inactive_split_bytes, -static_cast<std::int64_t>(block->size), params.stat_types);
} else if (!block->expandable_segment_) {
// A previously unsplit block now leaves a new inactive split remainder.
for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) {
update_stat(stats.inactive_split_bytes[stat_type], static_cast<std::int64_t>(remaining->size));
update_stat(stats.inactive_split[stat_type], 1);
});
}
} else if (already_split && !block->expandable_segment_) {
// A whole inactive split block becomes active.
for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) {
update_stat(stats.inactive_split_bytes[stat_type], -static_cast<std::int64_t>(block->size));
update_stat(stats.inactive_split[stat_type], -1);
});
}
block->allocated = true;
block->requested_size = orig_size;
block->context_when_allocated = std::move(context);
record_trace(TraceEntry::ALLOC, int64_t(block->ptr), orig_size, block->stream, block->device,
block->context_when_allocated);
active_blocks.insert(block);
for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) {
update_stat(stats.allocation[stat_type], 1);
update_stat(stats.allocated_bytes[stat_type], static_cast<std::int64_t>(block->size));
update_stat(stats.active[stat_type], 1);
update_stat(stats.active_bytes[stat_type], static_cast<std::int64_t>(block->size));
update_stat(stats.requested_bytes[stat_type], static_cast<std::int64_t>(block->requested_size));
});
if (block->size >= CachingAllocatorConfig::max_split_size()) update_stat(stats.oversize_allocations, 1);
ASCEND_LOGD("PTA CachingAllocator malloc: malloc = %zu, cached = %lu, allocated = %lu", block->size,
stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)].current,
stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)].current);
return block;
}
// Returns `block` to the allocator. Allocation statistics are updated immediately. If the block
// was used on other streams (and the NPU system is initialised), events are inserted and the
// actual free is deferred; otherwise it goes straight back to its pool via free_block().
void free(Block *block) {
    std::shared_ptr<c10::GatheredContext> context = maybeGatherContext(RecordContext::ALL);
    std::lock_guard<std::recursive_mutex> lock(mutex);
    block->allocated = false;
    // Captured before free_block(), which may merge neighbours into this block and change its size.
    auto orig_block_size = block->size;
    StatTypes stat_types = get_stat_types_for_pool(*(block->pool));
    for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
        update_stat(stats.allocation[stat_type], -1);
        // Negate as a signed value; `-block->size` would wrap around in size_t first.
        update_stat(stats.allocated_bytes[stat_type], -static_cast<int64_t>(block->size));
    });
    record_trace(TraceEntry::FREE_REQUESTED, int64_t(block->ptr), block->requested_size, block->stream, block->device,
        context ? context : block->context_when_allocated);
    if (block->size >= CachingAllocatorConfig::max_split_size()) update_stat(stats.oversize_allocations, -1);
    if (!block->stream_uses.empty() && c10_npu::NpuSysCtrl::GetInstance().GetInitFlag()) {
        insert_events(block);
    } else {
        free_block(block, context);
    }
    ASCEND_LOGD("PTA CachingAllocator free: free = %zu, cached = %lu, allocated = %lu", orig_block_size,
        stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)].current,
        stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)].current);
}
// Releases every cached (unused) block back to the device; raises if the release reports failure.
void emptyCache(bool check_error) {
    auto context = maybeGatherContext(RecordContext::ALL);
    std::lock_guard<std::recursive_mutex> guard(mutex);
    const bool released = release_cached_blocks(check_error, context);
    TORCH_CHECK(released, "release_cached_blocks failed.");
}
// Returns a snapshot of the statistics, copied while the allocator lock is held.
DeviceStats getStats() {
    std::lock_guard<std::recursive_mutex> guard(mutex);
    DeviceStats snapshot = stats;
    return snapshot;
}
void resetAccumulatedStats() {
std::lock_guard<std::recursive_mutex> lock(mutex);
for (size_t statType = 0; statType < static_cast<size_t>(StatType::NUM_TYPES); ++statType) {
reset_accumulated_stat(stats.allocation[statType]);
reset_accumulated_stat(stats.segment[statType]);
reset_accumulated_stat(stats.active[statType]);
reset_accumulated_stat(stats.inactive_split[statType]);
reset_accumulated_stat(stats.allocated_bytes[statType]);
reset_accumulated_stat(stats.reserved_bytes[statType]);
reset_accumulated_stat(stats.active_bytes[statType]);
reset_accumulated_stat(stats.inactive_split_bytes[statType]);
reset_accumulated_stat(stats.requested_bytes[statType]);
}
stats.num_alloc_retries = 0;
stats.num_ooms = 0;
reset_accumulated_stat(stats.oversize_allocations);
reset_accumulated_stat(stats.oversize_segments);
}
void resetPeakStats() {
std::lock_guard<std::recursive_mutex> lock(mutex);
for (size_t statType = 0; statType < static_cast<size_t>(StatType::NUM_TYPES); ++statType) {
reset_peak_stat(stats.allocation[statType]);
reset_peak_stat(stats.segment[statType]);
reset_peak_stat(stats.active[statType]);
reset_peak_stat(stats.inactive_split[statType]);
reset_peak_stat(stats.allocated_bytes[statType]);
reset_peak_stat(stats.reserved_bytes[statType]);
reset_peak_stat(stats.active_bytes[statType]);
reset_peak_stat(stats.inactive_split_bytes[statType]);
reset_peak_stat(stats.requested_bytes[statType]);
}
reset_peak_stat(stats.oversize_allocations);
reset_peak_stat(stats.oversize_segments);
}
std::vector<TraceEntry> trace()
{
std::lock_guard<std::recursive_mutex> lock(mutex);
std::vector<TraceEntry> result;
result.reserve(alloc_trace->size());
result.insert(result.end(), alloc_trace->begin() + alloc_trace_next, alloc_trace->end());
result.insert(result.end(), alloc_trace->begin(), alloc_trace->begin() + alloc_trace_next);
return result;
}
// Pads a request by 32 bytes, then rounds it up to a multiple of kMinBlockSize
// (never less than kMinBlockSize).
static size_t round_size(size_t size) {
    constexpr size_t align_size = 32;
    const size_t padded = size + align_size;
    if (padded >= kMinBlockSize) {
        return kMinBlockSize * ((padded + kMinBlockSize - 1) / kMinBlockSize);
    }
    return kMinBlockSize;
}
private:
std::vector<const Block *> get_all_blocks() const {
std::vector<const Block *> blocks;
blocks.insert(blocks.end(), small_blocks.blocks.begin(), small_blocks.blocks.end());
blocks.insert(blocks.end(), large_blocks.blocks.begin(), large_blocks.blocks.end());
blocks.insert(blocks.end(), active_blocks.begin(), active_blocks.end());
return blocks;
}
// Finds an unmapped block on `stream` from which at least `size` contiguous bytes can be built,
// possibly by chaining through following free blocks. A free predecessor of the unmapped block
// may serve as the starting point. If no candidate exists, a new ExpandableSegment is created
// (sized kSmallBuffer/kLargeBuffer per handle) and its whole address range becomes one unmapped
// block. The returned block may still be unmapped.
Block *find_expandable_block(int device, aclrtStream stream, BlockPool *pool, size_t size) {
// Search key with size 0 so lower_bound starts at the first unmapped block of this stream.
Block key(device, stream, 0);
// A block can be absorbed if it is free, has no pending events and no extra stream uses.
auto allocatable = [](Block *b) { return b && !b->allocated && b->event_count == 0 && b->stream_uses.empty(); };
// Walks forward from b summing allocatable blocks until `size` bytes are covered.
auto has_available_address_space = [&](Block *b) {
size_t bytes = 0;
while (bytes < size && allocatable(b)) {
bytes += b->size;
b = b->next;
}
return bytes >= size;
};
for (auto it = pool->unmapped.lower_bound(&key); it != pool->unmapped.end() && (*it)->stream == stream; ++it) {
Block *c = *it;
if (allocatable(c->prev)) {
c = c->prev;
}
if (has_available_address_space(c)) {
return c;
}
}
// Nothing suitable: reserve a fresh segment; expandable_segments_ keeps ownership of it.
auto segment_size = pool->is_small ? kSmallBuffer : kLargeBuffer;
expandable_segments_.emplace_back(new ExpandableSegment(device, stream, segment_size));
ExpandableSegment *es = expandable_segments_.back();
Block *candidate = new Block(device, stream, es->size(), pool, es->ptr());
candidate->mapped = false;
candidate->expandable_segment_ = es;
pool->unmapped.insert(candidate);
return candidate;
}
// Backs the first `size` bytes of the unmapped block `to_map` with physical memory. The mapped
// part moves from pool.unmapped to pool.blocks, merged with free mapped neighbours. Any unmapped
// tail becomes its own unmapped block. Reserved-memory statistics and the trace are updated.
// Returns false when nothing could be mapped.
bool map_block(Block *to_map, size_t size, const std::shared_ptr<c10::GatheredContext> &ctx)
{
    TORCH_INTERNAL_ASSERT(!to_map->mapped && size <= to_map->size);
    auto mapped_range = to_map->expandable_segment_->map(SegmentRange{to_map->ptr, size});
    // An empty range means no physical memory was attached (e.g. physical allocation failed).
    if (mapped_range.size == 0) {
        return false;
    }
    TORCH_INTERNAL_ASSERT(mapped_range.ptr == to_map->ptr && mapped_range.size >= size);
    BlockPool &pool = *to_map->pool;
    pool.unmapped.erase(to_map);
    to_map->mapped = true;
    if (mapped_range.size < to_map->size) {
        // Split off the still-unmapped tail and link it right after to_map.
        Block *remaining = new Block(to_map->device, to_map->stream, to_map->size - mapped_range.size, &pool,
            static_cast<char *>(to_map->ptr) + mapped_range.size);
        remaining->mapped = false;
        remaining->expandable_segment_ = to_map->expandable_segment_;
        remaining->splice(to_map, to_map->next);
        pool.unmapped.insert(remaining);
        to_map->size = mapped_range.size;
    }
    // Coalesce with free mapped neighbours. try_merge_blocks returns the number of bytes it
    // absorbed (an unsigned count, 0 when no merge happened), so the old `>= 0` check could
    // never fail; the result is simply not needed here.
    try_merge_blocks(to_map, to_map->prev, pool);
    try_merge_blocks(to_map, to_map->next, pool);
    pool.blocks.insert(to_map);
    total_allocated_memory += mapped_range.size;
    StatTypes stat_types = get_stat_types_for_pool(*to_map->pool);
    for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
        update_stat(stats.reserved_bytes[stat_type], static_cast<int64_t>(mapped_range.size));
    });
    record_trace(TraceEntry::SEGMENT_MAP, int64_t(mapped_range.ptr), mapped_range.size, to_map->stream, to_map->device,
        ctx);
    if (!to_map->prev && !to_map->context_when_segment_allocated) {
        to_map->context_when_segment_allocated = ctx;
    }
    return true;
}
// Produces a mapped block of at least `size` bytes from an expandable segment. It maps the
// candidate from find_expandable_block() and then as many following unmapped blocks as needed.
// Each mapping merges into the growing block. The result is removed from the pool's free set.
// Returns nullptr if mapping fails at any step.
Block *try_allocate_expandable_block(int device, aclrtStream stream, BlockPool *pool, size_t size,
    const std::shared_ptr<c10::GatheredContext> &ctx)
{
    Block *head = find_expandable_block(device, stream, pool, size);
    if (!head->mapped) {
        if (!map_block(head, std::min(head->size, size), ctx)) {
            return nullptr;
        }
    }
    TORCH_INTERNAL_ASSERT(head->mapped);
    while (head->size < size) {
        Block *following = head->next;
        const auto still_needed = size - head->size;
        if (!map_block(following, std::min(still_needed, following->size), ctx)) {
            return nullptr;
        }
        // map_block merged the previous head into `following`.
        head = following;
    }
    pool->blocks.erase(head);
    return head;
}
// Moves a no-longer-used block (not allocated, no pending events) from active_blocks back into
// its pool. It is first merged with free neighbours, then the inactive-split (non-expandable
// blocks only), active and requested statistics are updated.
void free_block(Block *block, const std::shared_ptr<c10::GatheredContext> &context)
{
    AT_ASSERT(!block->allocated && block->event_count == 0);
    record_trace(TraceEntry::FREE_COMPLETED, int64_t(block->ptr), block->requested_size, block->stream, block->device,
        context ? context : block->context_when_allocated);
    block->context_when_allocated = nullptr;
    // Sizes before merging: merges below grow block->size.
    size_t original_block_size = block->size;
    size_t requested_size = block->requested_size;
    auto &pool = *block->pool;
    int64_t net_change_inactive_split_blocks = 0;
    int64_t net_change_inactive_split_size = 0;
    const std::array<Block *, 2> merge_candidates = {block->prev, block->next};
    for (Block *merge_candidate : merge_candidates) {
        const int64_t subsumed_size = static_cast<int64_t>(try_merge_blocks(block, merge_candidate, pool));
        if (subsumed_size > 0) {
            net_change_inactive_split_blocks -= 1;
            net_change_inactive_split_size -= subsumed_size;
        }
    }
    active_blocks.erase(block);
    pool.blocks.insert(block);
    if (block->is_split()) {
        net_change_inactive_split_blocks += 1;
        net_change_inactive_split_size += static_cast<int64_t>(block->size);
    }
    StatTypes stat_types = get_stat_types_for_pool(pool);
    for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
        if (!block->expandable_segment_) {
            update_stat(stats.inactive_split[stat_type], net_change_inactive_split_blocks);
            update_stat(stats.inactive_split_bytes[stat_type], net_change_inactive_split_size);
        }
        update_stat(stats.active[stat_type], -1);
        // Negate as a signed value; `-original_block_size` would wrap around in size_t first.
        update_stat(stats.active_bytes[stat_type], -static_cast<int64_t>(original_block_size));
        update_stat(stats.requested_bytes[stat_type], -static_cast<std::int64_t>(requested_size));
    });
}
// Absorbs the neighbouring block `src` into `dst` when src is free, has no pending events or
// extra stream uses, and has the same mapped state. On success src is removed from its pool set
// and deleted. Returns the number of bytes absorbed, or 0 when no merge happened.
size_t try_merge_blocks(Block *dst, Block *src, BlockPool &pool) {
    if (!src || src->allocated || src->event_count > 0 || !src->stream_uses.empty() || dst->mapped != src->mapped) {
        return 0;
    }
    AT_ASSERT(dst->is_split() && src->is_split());
    if (dst->prev == src) {
        // src precedes dst: dst takes over src's start address and predecessor link.
        dst->ptr = src->ptr;
        dst->prev = src->prev;
        if (dst->prev) {
            dst->prev->next = dst;
        }
    } else {
        dst->next = src->next;
        if (dst->next) {
            dst->next->prev = dst;
        }
    }
    const size_t subsumed_size = src->size;
    dst->size += subsumed_size;
    // src's own fields are unchanged, so the set comparators still locate it.
    const auto erased = src->mapped ? pool.blocks.erase(src) : pool.unmapped.erase(src);
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(erased == 1);
    delete src;
    return subsumed_size;
}
// Requests of at most kSmallSize bytes use the small pool; everything else uses the large pool.
BlockPool &get_pool(size_t size) {
    return (size > kSmallSize) ? large_blocks : small_blocks;
}
// Selects the AGGREGATE bucket plus the bucket matching the pool's size class.
StatTypes get_stat_types_for_pool(const BlockPool &pool) {
    const StatType pool_type = pool.is_small ? StatType::SMALL_POOL : StatType::LARGE_POOL;
    StatTypes selected = {false};
    selected[static_cast<size_t>(StatType::AGGREGATE)] = true;
    selected[static_cast<size_t>(pool_type)] = true;
    return selected;
}
// Decides whether the unused tail of `block` (after taking `size` bytes) is split off.
// Small-pool and expandable-segment mode: split whenever at least kMinBlockSize bytes remain.
// Otherwise: split only requests below max_split_size that leave more than kSmallSize bytes.
bool should_split(const Block *block, size_t size) {
    TORCH_INTERNAL_ASSERT(block->size >= size);
    const size_t leftover = block->size - size;
    const bool small_or_expandable = block->pool->is_small || CachingAllocatorConfig::expandable_segments();
    if (!small_or_expandable) {
        return (size < CachingAllocatorConfig::max_split_size()) && (leftover > kSmallSize);
    }
    return leftover >= kMinBlockSize;
}
// Segment size to request from the device for a rounded request of `size` bytes:
// kSmallBuffer for small requests, kLargeBuffer below kMinLargeAlloc, and otherwise the size
// rounded up to a multiple of kRoundLarge.
static size_t get_allocation_size(size_t size) {
    if (size > kSmallSize) {
        if (size >= kMinLargeAlloc) {
            return kRoundLarge * ((size + kRoundLarge - 1) / kRoundLarge);
        }
        return kLargeBuffer;
    }
    return kSmallBuffer;
}
// Best-fit search of the request's pool for a cached block on the same stream.
// On success the block is removed from the pool, its gc_count is reset, and it is stored in
// p.block. Returns false when no acceptable cached block exists.
bool get_free_block(AllocParams &p) {
BlockPool &pool = *p.pool;
// When a memory fraction and GC threshold are set, every lookup ages all cached blocks.
if (C10_UNLIKELY(set_fraction && CachingAllocatorConfig::garbage_collection_threshold() > 0.0)) {
for (auto &b : pool.blocks) {
++b->gc_count;
}
}
// Smallest block on this stream whose size is >= the request (blocks are ordered by stream, size, address).
auto it = pool.blocks.lower_bound(&p.search_key);
if (it == pool.blocks.end() || (*it)->stream != p.stream()) {
return false;
}
if ((*it)->expandable_segment_) {
if (CachingAllocatorConfig::expandable_segments()) {
// For best fit, an expandable block counts as its size plus an unmapped successor it could
// grow into, so keep advancing while a later block has a smaller expandable size.
auto expandable_size = [](Block *b) { return b->size + (b->next && !b->next->mapped ? b->next->size : 0); };
auto next = it;
next++;
while ((*it)->expandable_segment_ && next != pool.blocks.end() && (*next)->stream == p.stream() &&
expandable_size(*next) < expandable_size(*it)) {
it = next++;
}
} else {
// Expandable segments are now disabled: skip expandable blocks and use only regular ones.
do {
it++;
} while (it != pool.blocks.end() && (*it)->expandable_segment_ && (*it)->stream == p.stream());
if (it == pool.blocks.end() || (*it)->stream != p.stream()) {
return false;
}
}
}
// Do not hand an oversize block (>= max_split_size) to a request below max_split_size.
if ((p.size() < CachingAllocatorConfig::max_split_size()) &&
((*it)->size >= CachingAllocatorConfig::max_split_size())) {
return false;
}
// For oversize requests, reject blocks that would waste kLargeBuffer bytes or more.
if ((p.size() >= CachingAllocatorConfig::max_split_size()) && ((*it)->size >= p.size() + kLargeBuffer)) {
return false;
}
p.block = *it;
(*it)->gc_count = 0;
pool.blocks.erase(it);
return true;
}
// Hook for running free-memory callbacks before a new allocation. No callbacks are invoked
// here, so it always reports that nothing was freed.
bool trigger_free_memory_callbacks(AllocParams &p) {
    static_cast<void>(p);
    return false;
}
// Releases cached, unsplit large-pool blocks until total_allocated_memory has dropped by enough
// to reach garbage_collection_threshold * allowed_memory_maximum. Each pass frees blocks whose
// gc_count is at least the current average age; passes repeat until the target is met or a pass
// frees nothing. The device is synchronized before anything is released.
void garbage_collect_cached_blocks(const std::shared_ptr<c10::GatheredContext> &ctx)
{
    size_t gc_threshold =
        static_cast<size_t>(CachingAllocatorConfig::garbage_collection_threshold() * allowed_memory_maximum);
    if (total_allocated_memory <= gc_threshold) {
        return;
    }
    const auto target_size = total_allocated_memory - gc_threshold;
    size_t gc_reclaimed = 0;
    double total_age = 0.0;
    int freeable_block_count = 0;
    for (auto &b : large_blocks.blocks) {
        if (!b->is_split()) {
            total_age += b->gc_count;
            ++freeable_block_count;
        }
    }
    if (freeable_block_count == 0) {
        return;
    }
    TORCH_CHECK(c10_npu::npuSynchronizeDevice(true), "npuSynchronizeDevice failed.");
    bool block_freed = true;
    while (gc_reclaimed < target_size && block_freed == true && freeable_block_count > 0) {
        double age_threshold = total_age / freeable_block_count;
        block_freed = false;
        auto it = large_blocks.blocks.begin();
        while (it != large_blocks.blocks.end()) {
            Block *block = *it;
            // Advance first: releasing the block removes it from the set.
            ++it;
            if (!block->is_split() && block->gc_count >= age_threshold) {
                // Read what is needed before release_block(), which may delete the block;
                // the log line previously dereferenced `block` after the release.
                const size_t freed_size = block->size;
                block_freed = true;
                gc_reclaimed += freed_size;
                total_age -= block->gc_count;
                freeable_block_count--;
                release_block(block, ctx);
                ASCEND_LOGD("PTA CachingAllocator gc: free = %zu, cached = %lu, allocated = %lu", freed_size,
                    stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)].current,
                    stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)].current);
            }
        }
    }
}
bool alloc_block(AllocParams &p, bool isRetry, const std::shared_ptr<c10::GatheredContext> &ctx,
    std::unique_lock<std::recursive_mutex> &lock)
{
    // Obtain a brand-new segment for the request in `p`.
    //  - If a memory fraction is set and this segment would exceed allowed_memory_maximum, fail with
    //    p.err = ACL_ERROR_RT_MEMORY_ALLOCATION.
    //  - With expandable segments enabled, delegate to try_allocate_expandable_block().
    //  - Otherwise allocate p.alloc_size bytes with aclrtMallocAlign32, wrap them in a new Block,
    //    update the segment/reserved stats and record a SEGMENT_ALLOC trace entry.
    // Returns true and sets p.block on success. `lock` is currently unused.
    const size_t size = p.alloc_size;
    if (isRetry) {
        stats.num_alloc_retries += 1;
    }

    const bool exceeds_fraction = set_fraction && total_allocated_memory + size > allowed_memory_maximum;
    if (exceeds_fraction) {
        p.err = ACL_ERROR_RT_MEMORY_ALLOCATION;
        return false;
    }

    if (CachingAllocatorConfig::expandable_segments()) {
        p.block = try_allocate_expandable_block(p.device(), p.stream(), p.pool, p.size(), ctx);
        const bool got_block = (p.block != nullptr);
        p.err = got_block ? ACL_ERROR_NONE : ACL_ERROR_RT_MEMORY_ALLOCATION;
        return got_block;
    }

    void *raw_ptr = nullptr;
    p.err = aclrtMallocAlign32(&raw_ptr, size, aclrtMemMallocPolicy::ACL_MEM_MALLOC_HUGE_FIRST);
    if (p.err != ACL_ERROR_NONE) {
        // Normalize any device failure to the generic out-of-memory code.
        p.err = ACL_ERROR_RT_MEMORY_ALLOCATION;
        return false;
    }

    total_allocated_memory += size;
    p.block = new Block(p.device(), p.stream(), size, p.pool, static_cast<char *>(raw_ptr));
    for_each_selected_stat_type(p.stat_types, [&](size_t stat_type) {
        update_stat(stats.segment[stat_type], 1);
        update_stat(stats.reserved_bytes[stat_type], size);
    });
    if (size >= CachingAllocatorConfig::max_split_size()) {
        update_stat(stats.oversize_segments, 1);
    }
    ASCEND_LOGD("pta_memory acl_malloc: malloc = %zu, ret = %d", size, p.err);
    TORCH_INTERNAL_ASSERT(p.block != nullptr && p.block->ptr != nullptr);
    record_trace(TraceEntry::SEGMENT_ALLOC, int64_t(p.block->ptr), p.block->size, p.stream(), p.device(), ctx);
    p.block->context_when_segment_allocated = ctx;
    return true;
}
bool release_available_cached_blocks(const AllocParams &p, const std::shared_ptr<c10::GatheredContext> &ctx)
{
    // Out-of-memory fallback used when max_split_size is configured: free cached oversized
    // (>= max_split_size) segments on p.stream() so that a retried allocation can succeed.
    // The request size is first rounded up to max_split_size. If lower_bound() finds a suitable block,
    // that single block is released. Otherwise oversized blocks are released walking down from the
    // largest, until at least the rounded size has been freed.
    // Returns true if a block was released or enough bytes were freed, and false otherwise.
    //
    // Blocks that belong to expandable segments are never passed to release_block(), which asserts
    // against them. Such blocks can sit in pool.blocks because unmap_block() re-inserts them.
    if (CachingAllocatorConfig::max_split_size() == std::numeric_limits<size_t>::max()) {
        return false;
    }
    BlockPool &pool = *p.pool;
    Block key = p.search_key;
    key.size =
        (key.size < CachingAllocatorConfig::max_split_size()) ? CachingAllocatorConfig::max_split_size() : key.size;
    auto it = pool.blocks.lower_bound(&key);
    TORCH_CHECK(c10_npu::npuSynchronizeDevice(true), "npuSynchronizeDevice failed.");
    if (it == pool.blocks.end() || (*it)->stream != p.stream() || (*it)->expandable_segment_) {
        // No single releasable block is large enough: free several oversized blocks, largest first.
        if (it == pool.blocks.begin()) {
            return false;
        }
        size_t totalReleased = 0;
        --it;
        while ((totalReleased < key.size) && ((*it)->size >= CachingAllocatorConfig::max_split_size()) &&
            ((*it)->stream == p.stream())) {
            auto cur = it;
            const bool is_first = (cur == pool.blocks.begin());
            // Step off `cur` before release_block() erases it from pool.blocks.
            if (!is_first) {
                --it;
            }
            if (!(*cur)->expandable_segment_) {
                // Read the size before release_block() deletes the block.
                totalReleased += (*cur)->size;
                release_block(*cur, ctx);
            }
            if (is_first) {
                break;
            }
        }
        if (totalReleased < key.size) {
            return false;
        }
    } else {
        release_block(*it, ctx);
    }
    return true;
}
bool release_cached_blocks(bool check_error, const std::shared_ptr<c10::GatheredContext> &context)
{
TORCH_CHECK(c10_npu::npuSynchronizeDevice(check_error), "npuSynchronizeDevice failed.");
synchronize_and_free_events(check_error, context);
release_blocks(large_blocks, context);
release_blocks(small_blocks, context);
return true;
}
// Destroy an expandable segment that has been fully unmapped, together with the single Block that spans it.
// The block must cover the whole segment and must not be mapped. It is removed from its pool's
// `unmapped` set, and both the segment and the block are deleted.
void release_expandable_segment(Block *block) {
    TORCH_INTERNAL_ASSERT(block->size == block->expandable_segment_->size(), "block disagrees with segment");
    TORCH_INTERNAL_ASSERT(!block->mapped);
    auto *segment = block->expandable_segment_;
    auto registered = std::find(expandable_segments_.begin(), expandable_segments_.end(), segment);
    TORCH_INTERNAL_ASSERT(registered != expandable_segments_.end());
    expandable_segments_.erase(registered);
    block->pool->unmapped.erase(block);
    block->expandable_segment_ = nullptr;
    delete segment;
    delete block;
}
void release_block(Block *block, const std::shared_ptr<c10::GatheredContext> &context)
{
    // Return a whole, non-expandable cached segment to the device and delete its Block.
    // Records a SEGMENT_FREE trace entry, frees the memory with aclrtFree, updates the segment/reserved
    // stats and erases the block from its pool's free set. Callers must not touch `block` afterwards.
    TORCH_INTERNAL_ASSERT(!block->expandable_segment_);
    record_trace(TraceEntry::SEGMENT_FREE, int64_t(block->ptr), block->size, block->stream, block->device,
        context ? context : block->context_when_segment_allocated);
    TORCH_CHECK(aclrtFree((void *)block->ptr) == ACL_ERROR_NONE, "aclrtFree failed.");
    const size_t freed_size = block->size;
    total_allocated_memory -= freed_size;
    auto *pool = block->pool;
    StatTypes stat_types = get_stat_types_for_pool(*pool);
    // Negate in a signed type. Negating a size_t wraps modulo 2^N and only becomes a negative delta
    // through an implementation-defined unsigned-to-signed conversion (before C++20).
    const int64_t reserved_delta = -static_cast<int64_t>(freed_size);
    for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
        update_stat(stats.segment[stat_type], -1);
        update_stat(stats.reserved_bytes[stat_type], reserved_delta);
    });
    if (freed_size >= CachingAllocatorConfig::max_split_size()) {
        update_stat(stats.oversize_segments, -1);
    }
    ASCEND_LOGD("pta_memory acl_free: free_size = %zu", freed_size);
    pool->blocks.erase(block);
    delete block;
}
void unmap_block(Block *block, const std::shared_ptr<c10::GatheredContext> &context)
{
    // Unmap the device pages backing a free block of an expandable segment.
    // expandable_segment_->unmap() returns the range it actually unmapped, which may be a sub-range of the
    // block (presumably page-aligned; TODO confirm). Any still-mapped prefix or suffix is split off into its own
    // free Block and re-inserted into pool->blocks. The unmapped part stays as `block`, is merged with unmapped
    // neighbours and moved to pool->unmapped. Reserved-byte stats and total_allocated_memory shrink by the
    // unmapped size.
    auto unmapped = block->expandable_segment_->unmap(SegmentRange{block->ptr, block->size});
    if (unmapped.size == 0) {
        // Nothing could be unmapped; leave the block where it is.
        return;
    }
    block->pool->blocks.erase(block);
    ptrdiff_t before_size = static_cast<char *>(unmapped.ptr) - static_cast<char *>(block->ptr);
    if (before_size > 0) {
        // Still-mapped prefix becomes a separate free block in front of `block`.
        Block *before_free = new Block(block->device, block->stream, before_size, block->pool, block->ptr);
        before_free->expandable_segment_ = block->expandable_segment_;
        before_free->splice(block->prev, block);
        block->pool->blocks.insert(before_free);
    }
    TORCH_CHECK(block->size >= before_size + unmapped.size, "after size should be greater than or equal to 0");
    auto after_size = block->size - (before_size + unmapped.size);
    if (after_size > 0) {
        // Still-mapped suffix becomes a separate free block after `block`.
        Block *after_free = new Block(block->device, block->stream, after_size, block->pool,
            static_cast<char *>(unmapped.ptr) + unmapped.size);
        after_free->expandable_segment_ = block->expandable_segment_;
        after_free->splice(block, block->next);
        block->pool->blocks.insert(after_free);
    }
    block->ptr = unmapped.ptr;
    block->size = unmapped.size;
    block->mapped = false;
    TORCH_CHECK(try_merge_blocks(block, block->prev, *block->pool) >= 0, "try_merge_blocks failed.");
    TORCH_CHECK(try_merge_blocks(block, block->next, *block->pool) >= 0, "try_merge_blocks failed.");
    block->pool->unmapped.insert(block);
    total_allocated_memory -= unmapped.size;
    StatTypes stat_types = get_stat_types_for_pool(*block->pool);
    // Negate in a signed type instead of relying on size_t wrap-around plus an
    // implementation-defined (pre-C++20) conversion to a negative delta.
    const int64_t reserved_delta = -static_cast<int64_t>(unmapped.size);
    for_each_selected_stat_type(
        stat_types, [&](size_t stat_type) { update_stat(stats.reserved_bytes[stat_type], reserved_delta); });
    record_trace(TraceEntry::SEGMENT_UNMAP, int64_t(unmapped.ptr), unmapped.size, block->stream, block->device,
        context ? context : block->context_when_segment_allocated);
}
void release_blocks(BlockPool &pool, const std::shared_ptr<c10::GatheredContext> &context)
{
    // Release everything in `pool` that can be given back:
    //  - regular blocks with no prev/next neighbour (a whole segment) go through release_block();
    //  - expandable-segment blocks are unmapped after the scan. If a block has no neighbours afterwards,
    //    its segment is destroyed via release_expandable_segment().
    std::vector<Block *> expandable_candidates;
    for (auto cursor = pool.blocks.begin(); cursor != pool.blocks.end();) {
        Block *candidate = *cursor;
        // Advance before release_block() can erase `candidate` from pool.blocks.
        ++cursor;
        if (candidate->expandable_segment_) {
            expandable_candidates.push_back(candidate);
            continue;
        }
        const bool whole_segment = (candidate->prev == nullptr) && (candidate->next == nullptr);
        if (whole_segment) {
            release_block(candidate, context);
        }
    }
    for (Block *candidate : expandable_candidates) {
        unmap_block(candidate, context);
        const bool whole_segment = (candidate->prev == nullptr) && (candidate->next == nullptr);
        if (whole_segment) {
            release_expandable_segment(candidate);
        }
    }
}
// Fetch a pooled event for device `idx`.
// A single EventPool is shared by every caller of this function across all allocator instances. It is
// heap-allocated and never deleted, presumably to avoid static-destruction-order issues at process exit
// (TODO confirm).
EventPool::Event create_event_internal(int idx) {
    static EventPool *const shared_event_pool = new EventPool();
    return shared_event_pool->get(idx);
}
void synchronize_and_free_events(bool check_error, const std::shared_ptr<c10::GatheredContext> &context)
{
for (auto &st : npu_events) {
for (auto &e : st.second) {
EventPool::Event event = std::move(e.first);
Block *block = e.second;
if (check_error) {
TORCH_CHECK(aclrtSynchronizeEvent(*event) == ACL_ERROR_NONE, "aclrtSynchronizeEvent failed.");
} else {
TORCH_CHECK(aclrtSynchronizeEvent(*event) == ACL_ERROR_NONE, "aclrtSynchronizeEvent failed");
}
ASCEND_LOGI("Event: aclrtSynchronizeEvent is successfully executed");
block->event_count--;
if (block->event_count == 0) {
free_block(block, context);
}
}
}
npu_events.clear();
}
// Record one event on every stream recorded in block->stream_uses (that set is moved out and left empty).
// Each event is queued in npu_events for its stream and bumps block->event_count. process_events() and
// synchronize_and_free_events() free the block once that count drops back to zero.
// The current ACL context is read first and restored at the end if it was read successfully, presumably
// because SetDevice() may switch it (TODO confirm).
void insert_events(Block *block) {
    aclrtContext saved_context = aclrtContext();
    const aclError get_context_ret = aclrtGetCurrentContext(&saved_context);

    stream_set pending_streams(std::move(block->stream_uses));
    AT_ASSERT(block->stream_uses.empty());
    for (auto &use_stream : pending_streams) {
        const auto device_index = use_stream.device_index();
        TORCH_CHECK(c10_npu::SetDevice(device_index) == ACL_ERROR_NONE, "SetDevice failed.");
        EventPool::Event recorded = create_event_internal(device_index);
        recorded->record(use_stream);
        ASCEND_LOGI("Event: record DeviceAllocator is successfully executed");
        ++block->event_count;
        npu_events[use_stream].emplace_back(std::move(recorded), block);
    }

    if (get_context_ret == ACL_ERROR_NONE) {
        TORCH_CHECK(aclrtSetCurrentContext(saved_context) == ACL_ERROR_NONE, "aclrtSetCurrentContext failed.");
    }
}
// Poll outstanding events: for each stream, retire events from the front of its queue while query()
// reports them complete, stopping at the first one still pending (query() is presumably non-blocking;
// TODO confirm). A block whose event_count reaches zero is passed to free_block(). Streams whose queues
// have fully drained are erased from npu_events.
void process_events(const std::shared_ptr<c10::GatheredContext> &context)
{
for (auto it = npu_events.begin(); it != npu_events.end();) {
while (!it->second.empty()) {
auto &e = it->second.front();
EventPool::Event event = std::move(e.first);
Block *block = e.second;
if (!event->query()) {
// Not finished: move the event back into its slot and leave the rest of this queue for a later call.
e.first = std::move(event);
break;
}
block->event_count--;
if (block->event_count == 0) {
free_block(block, context);
}
it->second.pop_front();
}
// Erase returns the next iterator; only advance manually when the entry is kept.
if (it->second.empty()) {
it = npu_events.erase(it);
} else {
it++;
}
}
}
void record_trace(TraceEntry::Action action, int64_t addr, size_t size, aclrtStream stream, int device,
    std::shared_ptr<c10::GatheredContext> context)
{
    // Append an allocator event to the bounded alloc_trace history when record_history is enabled.
    // Entries are appended until alloc_trace holds alloc_trace_max_entries_. After that the buffer is reused
    // as a ring, overwriting the slot at alloc_trace_next and wrapping the index back to 0.
    // The captured context is attached only when record_context_ is at least RecordContext::ALLOC.
    if (!record_history) {
        return;
    }
    auto te = TraceEntry(action, device, addr, size, stream,
        record_context_ >= RecordContext::ALLOC ? std::move(context) : nullptr);
    // `te` is a local temporary, so move it into the buffer rather than copying it (and its context pointer).
    if (alloc_trace->size() < alloc_trace_max_entries_) {
        alloc_trace->emplace_back(std::move(te));
        return;
    }
    (*alloc_trace)[alloc_trace_next++] = std::move(te);
    if (alloc_trace_next == alloc_trace_max_entries_) {
        alloc_trace_next = 0;
    }
}
};
// Declared here, defined elsewhere. NOTE(review): presumably the deleter used for pointers handed out by
// this allocator; confirm against its definition.
void local_raw_delete(void *ptr);
// Front-end over the per-device DeviceCachingAllocator instances. It keeps a map from live device pointers
// to their Block. Only add_allocated_block() is defined inline; every other member is defined out of view.
class NpuCachingCustomAllocator {
private:
// Locked by add_allocated_block() when updating allocated_blocks. NOTE(review): confirm the out-of-line
// members that touch allocated_blocks take it too.
std::mutex mutex;
// Live allocations: device pointer -> owning Block.
ska::flat_hash_map<void *, Block *> allocated_blocks;
// Register `block` under its pointer (overwrites any existing entry for the same pointer).
void add_allocated_block(Block *block) {
std::lock_guard<std::mutex> lock(mutex);
allocated_blocks[block->ptr] = block;
}
public:
// One allocator per device; presumably indexed by device id (TODO confirm in init()).
std::vector<std::unique_ptr<DeviceCachingAllocator>> device_allocator;
std::mutex *getFreeMutex() const;
// NOTE(review): by its signature, looks up the Block for `ptr` and optionally removes the entry; confirm.
Block *get_allocated_block(void *ptr, bool remove = false);
void setMemoryFraction(double fraction, int device);
void init(int device_count);
bool initialized();
void emptyCache(bool check_error);
DeviceStats getDeviceStats(int device);
void resetPeakStats(int device);
std::string name();
void *malloc(int device, size_t size, aclrtStream stream);
void free(void *ptr);
void assertValidDevice(int device);
};
// Global instance used by the extern "C" entry points; defined elsewhere.
extern NpuCachingCustomAllocator my_allocator;
// Plain C entry points that forward to the global `my_allocator`.
// NOTE(review): these are definitions inside a file marked `#pragma once`. If it is included from more
// than one translation unit, the linker will see duplicate symbols; confirm it is compiled exactly once.
extern "C" {
// Allocate `size` bytes on `device` for use on `stream`. A zero-byte request returns nullptr without
// touching the allocator.
void *my_malloc(size_t size, int device, aclrtStream stream) {
    if (size == 0) {
        return nullptr;
    }
    return my_allocator.malloc(device, size, stream);
}
// Only the pointer is forwarded; `size`, `device` and `stream` are part of the exported signature but unused.
void my_free(void *ptr, size_t size, int device, aclrtStream stream) { my_allocator.free(ptr); }
void my_init(int device_count) { my_allocator.init(device_count); }
// Forward the caller's flag. The previous version always passed `true`, so a caller requesting non-fatal
// cleanup (check_error == false) still got errors raised.
void my_empty_cache(bool check_error) { my_allocator.emptyCache(check_error); }
DeviceStats my_get_device_stats(int device) { return my_allocator.getDeviceStats(device); }
void my_reset_peak_stats(int device) { my_allocator.resetPeakStats(device); }
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}