#include <mutex>
#include <unordered_map>
#include <utility>
#include "torch_npu/csrc/npu/NPUPluggableAllocator.h"
#include "torch_npu/csrc/core/npu/NPUCachingAllocator.h"
#include "torch_npu/csrc/core/npu/NPUGuard.h"
#ifndef BUILD_LIBTORCH
#include "torch_npu/csrc/sanitizer/NPUTrace.h"
#endif
namespace torch::npu::NPUPluggableAllocator {
// Value that createCustomAllocator() forwards to NPUPluggableAllocator::init().
int device_count = 0;
// Deleter returned by NPUPluggableAllocator::raw_deleter(); defined further down in this namespace.
void custom_raw_deleter(void* ptr);
_AllocationMetadata::_AllocationMetadata()
: size(0), device_idx(-1), stream{} {}
_AllocationMetadata::_AllocationMetadata(
size_t size,
int device_idx,
aclrtStream stream)
: size(size), device_idx(device_idx), stream(stream) {}
NPUPluggableAllocator::NPUPluggableAllocator(
std::function<void*(size_t, int, aclrtStream)> alloc_fn,
std::function<void(void*, size_t, int, aclrtStream)> free_fn)
: alloc_fn_(std::move(alloc_fn)), free_fn_(std::move(free_fn)) {}
NPUPluggableAllocator::NPUPluggableAllocator(NPUPluggableAllocator& other)
: alloc_fn_(other.alloc_fn_),
free_fn_(other.free_fn_),
init_fn_(other.init_fn_),
reset_fn_(other.reset_fn_),
memory_fraction_fn_(other.memory_fraction_fn_),
base_alloc_fn_(other.base_alloc_fn_),
record_stream_fn_(other.record_stream_fn_),
erase_stream_fn_(other.erase_stream_fn_),
get_device_stats_fn_(other.get_device_stats_fn_),
reset_peak_status_fn_(other.reset_peak_status_fn_),
begin_allocate_to_pool_fn_(other.begin_allocate_to_pool_fn_),
end_allocate_to_pool_fn_(other.end_allocate_to_pool_fn_),
release_pool_fn_(other.release_pool_fn_) {}
// Installs the optional callback that init() invokes with the device count.
void NPUPluggableAllocator::set_init_fn(std::function<void(int)> init_fn)
{
    init_fn_ = std::move(init_fn);
}
// Installs the optional callback that emptyCache() forwards to.
void NPUPluggableAllocator::set_reset_fn(std::function<void(bool)> reset_fn)
{
    reset_fn_ = std::move(reset_fn);
}
// Installs the optional callback that setMemoryFraction() forwards to.
void NPUPluggableAllocator::set_memory_fraction_fn(
    std::function<void(double, int)> memory_fraction_fn)
{
    memory_fraction_fn_ = std::move(memory_fraction_fn);
}
// Installs the optional callback that getBaseAllocation() forwards to.
void NPUPluggableAllocator::set_base_alloc_fn(
    std::function<void*(void*, size_t*)> base_alloc_fn)
{
    base_alloc_fn_ = std::move(base_alloc_fn);
}
// Installs the optional callback that recordStream() forwards to.
void NPUPluggableAllocator::set_record_stream_fn(
    std::function<void(void* ptr, c10_npu::NPUStream stream)> record_stream_fn)
{
    record_stream_fn_ = std::move(record_stream_fn);
}
// Installs the optional callback that eraseStream() forwards to.
void NPUPluggableAllocator::set_erase_stream_fn(
    std::function<void(void* ptr, c10_npu::NPUStream stream)> erase_stream_fn)
{
    erase_stream_fn_ = std::move(erase_stream_fn);
}
// Installs the optional callback that getDeviceStats() forwards to.
void NPUPluggableAllocator::set_get_device_stats_fn(
    std::function<c10_npu::NPUCachingAllocator::DeviceStats(int)> get_device_stats_fn)
{
    get_device_stats_fn_ = std::move(get_device_stats_fn);
}
// Installs the optional callback that resetPeakStats() forwards to.
void NPUPluggableAllocator::set_reset_peak_status_fn(
    std::function<void(int)> reset_peak_status_fn)
{
    reset_peak_status_fn_ = std::move(reset_peak_status_fn);
}
// Installs the optional callback that beginAllocateToPool() forwards to.
void NPUPluggableAllocator::set_begin_allocate_to_pool(
    std::function<void(int, c10_npu::MempoolId_t, std::function<bool(aclrtStream)>)> capture_begin_fn)
{
    begin_allocate_to_pool_fn_ = std::move(capture_begin_fn);
}
// Installs the optional callback that endAllocateToPool() forwards to.
void NPUPluggableAllocator::set_end_allocate_to_pool_fn(
    std::function<void(int, c10_npu::MempoolId_t)> capture_about_to_end_fn)
{
    end_allocate_to_pool_fn_ = std::move(capture_about_to_end_fn);
}
// Installs the optional callback that releasePool() forwards to.
void NPUPluggableAllocator::set_release_pool(
    std::function<void(int, c10_npu::MempoolId_t)> capture_destroy_fn)
{
    release_pool_fn_ = std::move(capture_destroy_fn);
}
void* NPUPluggableAllocator::malloc(
size_t size,
int device,
aclrtStream stream)
{
void* r = alloc_fn_(size, device, stream);
{
const std::lock_guard<std::mutex> lock(allocator_mutex_);
allocation_metadata_.emplace(r, _AllocationMetadata(size, device, stream));
}
#ifndef BUILD_LIBTORCH
if (r) {
const c10_npu::impl::PyCallbackTrigger *trigger = c10_npu::impl::NPUTrace::getTrace();
if (C10_UNLIKELY(trigger)) {
trigger->traceNpuMemoryAllocation(reinterpret_cast<uintptr_t>(r));
}
}
#endif
return r;
}
// Allocates `size` bytes on the current device using that device's current
// stream and wraps the result in a DataPtr whose deleter is raw_deleter().
c10::DataPtr NPUPluggableAllocator::allocate(size_t size)
{
    int device = -1;
    NPU_CHECK_ERROR(c10_npu::GetDevice(&device));
    aclrtStream stream = c10_npu::getCurrentNPUStreamNoWait(device);

    void* const block = this->malloc(size, device, stream);
    const c10::Device owner(
        c10::DeviceType::PrivateUse1, static_cast<c10::DeviceIndex>(device));

    // The same pointer serves as both the data pointer and the deleter context.
    return c10::DataPtr(block, block, raw_deleter(), owner);
}
// Aligned allocation is not supported by the pluggable allocator: the check
// below always fails with a NOT_SUPPORT error.
c10::DataPtr NPUPluggableAllocator::allocate_with_aligned(size_t size, size_t base_addr_aligned_kb) const
{
    TORCH_CHECK(
        false,
        "NPUPluggableAllocator does't has allocate_with_aligned",
        PTA_ERROR(ErrCode::NOT_SUPPORT));
    // Kept so the function has a return statement on every path.
    return c10::DataPtr();
}
// Returns the free function used as the deleter for DataPtrs built by allocate().
c10::DeleterFnPtr NPUPluggableAllocator::raw_deleter() const
{
    c10::DeleterFnPtr deleter = &custom_raw_deleter;
    return deleter;
}
void* NPUPluggableAllocator::raw_alloc(size_t nbytes)
{
int device = -1;
NPU_CHECK_ERROR(c10_npu::GetDevice(&device));
aclrtStream stream = c10_npu::getCurrentNPUStreamNoWait(device);
return malloc(nbytes, device, stream);
}
void* NPUPluggableAllocator::raw_alloc_with_stream(
size_t nbytes,
aclrtStream stream)
{
int device = -1;
NPU_CHECK_ERROR(c10_npu::GetDevice(&device));
return malloc(nbytes, device, stream);
}
// Releases a pointer previously returned by malloc().
// Looks up the metadata recorded at allocation time, removes it from the
// bookkeeping map, notifies the sanitizer trace (when built with it) and
// finally calls the user-supplied free callback with the recorded size, device
// and stream. Fails with a PTR error when `ptr` is not tracked by this allocator.
void NPUPluggableAllocator::raw_delete(void* ptr)
{
    aclrtStream stream{};
    int device_idx = -1;
    size_t size = 0;
    {
        const std::lock_guard<std::mutex> lock(allocator_mutex_);
        // A single find() replaces the previous count() + operator[] + erase(key)
        // sequence, which hashed the key three times while holding the lock.
        const auto entry = allocation_metadata_.find(ptr);
        TORCH_CHECK(
            entry != allocation_metadata_.end(),
            "Trying to free a pointer not allocated here", PTA_ERROR(ErrCode::PTR));
        // Copy the fields out before erasing: the entry is invalidated by erase().
        size = entry->second.size;
        device_idx = entry->second.device_idx;
        stream = entry->second.stream;
        allocation_metadata_.erase(entry);
    }
#ifndef BUILD_LIBTORCH
    const c10_npu::impl::PyCallbackTrigger *trigger = c10_npu::impl::NPUTrace::getTrace();
    if (C10_UNLIKELY(trigger)) {
        trigger->traceNpuMemoryDeallocation(reinterpret_cast<uintptr_t>(ptr));
    }
#endif
    // The user callback runs outside the lock.
    free_fn_(ptr, size, device_idx, stream);
}
// Runs the optional init callback with `device_count`, then marks the
// allocator as initialized (the flag is set whether or not a callback exists).
void NPUPluggableAllocator::init(int device_count)
{
    const bool has_init_hook = static_cast<bool>(init_fn_);
    if (has_init_hook) {
        init_fn_(device_count);
    }
    initialized_ = true;
}
// Reports whether init() has already been called on this allocator.
bool NPUPluggableAllocator::initialized()
{
    const bool ready = initialized_;
    return ready;
}
// Not supported: the check below always fails, so no value is ever returned.
double NPUPluggableAllocator::getMemoryFraction(int device)
{
    TORCH_CHECK(
        false,
        "NPUPluggableAllocator does not yet support getMemoryFraction. "
        "If you need it, please file an issue describing your use case.");
}
// Forwards to the user-supplied memory-fraction callback; a no-op when none is set.
void NPUPluggableAllocator::setMemoryFraction(double fraction, int device)
{
    if (!memory_fraction_fn_) {
        return;
    }
    memory_fraction_fn_(fraction, device);
}
// Unsupported hook: only emits a warning; both arguments are ignored.
void NPUPluggableAllocator::emptyCacheImpl(bool check_error, bool free_physical)
{
    TORCH_NPU_WARN(
        "NPUPluggableAllocator does not yet support emptyCacheImpl. "
        "If you need it, please file an issue describing your use case.");
}
// Forwards to the user-supplied reset callback; a no-op when none is set.
void NPUPluggableAllocator::emptyCache(bool check_error)
{
    if (!reset_fn_) {
        return;
    }
    reset_fn_(check_error);
}
// Unsupported hook: only emits a warning; `check_error` is ignored.
void NPUPluggableAllocator::emptyVirtAddrCache(bool check_error)
{
    TORCH_NPU_WARN(
        "NPUPluggableAllocator does not yet support emptyVirtAddrCache. "
        "If you need it, please file an issue describing your use case.");
}
// Unsupported hook: only emits a warning; the output pointers are never written.
void NPUPluggableAllocator::cacheInfo(int dev_id, size_t* cachedAndFree, size_t* largestBlock)
{
    TORCH_NPU_WARN(
        "NPUPluggableAllocator does not yet support cacheInfo. "
        "If you need it, please file an issue describing your use case.");
}
// Returns the result of the user-supplied base-allocation callback, or `ptr`
// itself (leaving `*size` untouched) when no callback has been registered.
void* NPUPluggableAllocator::getBaseAllocation(void* ptr, size_t* size)
{
    return base_alloc_fn_ ? base_alloc_fn_(ptr, size) : ptr;
}
// Notifies the sanitizer trace (when built with it and the pointer is
// non-null) and then forwards to the user-supplied record-stream callback,
// if one is registered.
void NPUPluggableAllocator::recordStream(
    const c10::DataPtr& ptr,
    streamType stream)
{
#ifndef BUILD_LIBTORCH
    if (ptr.get() != nullptr) {
        const c10_npu::impl::PyCallbackTrigger* tracer = c10_npu::impl::NPUTrace::getTrace();
        if (C10_UNLIKELY(tracer)) {
            const auto data_addr = reinterpret_cast<uintptr_t>(ptr.get());
            const auto stream_addr = reinterpret_cast<uintptr_t>(stream.stream(false));
            tracer->traceNpuRecordStream(data_addr, stream_addr);
        }
    }
#endif

    if (!record_stream_fn_) {
        return;
    }
    record_stream_fn_(ptr.get(), stream);
}
// Forwards to the user-supplied erase-stream callback (if registered) and
// afterwards notifies the sanitizer trace when built with it and the pointer
// is non-null.
void NPUPluggableAllocator::eraseStream(
    const c10::DataPtr& ptr,
    streamType stream)
{
    const bool has_erase_hook = static_cast<bool>(erase_stream_fn_);
    if (has_erase_hook) {
        erase_stream_fn_(ptr.get(), stream);
    }

#ifndef BUILD_LIBTORCH
    if (ptr.get() != nullptr) {
        const c10_npu::impl::PyCallbackTrigger* tracer = c10_npu::impl::NPUTrace::getTrace();
        if (C10_UNLIKELY(tracer)) {
            const auto data_addr = reinterpret_cast<uintptr_t>(ptr.get());
            const auto stream_addr = reinterpret_cast<uintptr_t>(stream.stream(false));
            tracer->traceNpuEraseStream(data_addr, stream_addr);
        }
    }
#endif
}
// Unsupported hook: only emits a warning; all arguments are ignored.
void NPUPluggableAllocator::eraseStreamWithBlockPtr(void* block_ptr, c10_npu::NPUStream stream, void* work_ptr)
{
    TORCH_NPU_WARN(
        "NPUPluggableAllocator does not yet support eraseStreamWithBlockPtr. "
        "If you need it, please file an issue describing your use case.");
}
// Unsupported hook: emits a warning and always returns nullptr.
void* NPUPluggableAllocator::getBlockPtr(const c10::DataPtr& ptr)
{
    TORCH_NPU_WARN(
        "NPUPluggableAllocator does not yet support getBlockPtr. "
        "If you need it, please file an issue describing your use case.");
    void* const no_block = nullptr;
    return no_block;
}
// Unsupported hook: only emits a warning; both pointers are ignored.
void NPUPluggableAllocator::recordHcclWorkForBlock(void* block_ptr, void* work_ptr)
{
    TORCH_NPU_WARN(
        "NPUPluggableAllocator does not yet support recordHcclWorkForBlock. "
        "If you need it, please file an issue describing your use case.");
}
// Returns memory statistics for `device` from the user-supplied callback.
// When no callback has been registered a warning is emitted and
// value-initialized statistics are returned.
c10_npu::NPUCachingAllocator::DeviceStats NPUPluggableAllocator::getDeviceStats(int device)
{
    if (get_device_stats_fn_) {
        return get_device_stats_fn_(device);
    }
    TORCH_NPU_WARN("get_device_stats_fn_ is not define, please set by set_get_device_stats_fn");
    // Flowing off the end of a non-void function is undefined behavior, so an
    // explicit (empty) result is returned on the fallback path.
    return c10_npu::NPUCachingAllocator::DeviceStats{};
}
// Unsupported hook: only emits a warning; `device` is ignored.
void NPUPluggableAllocator::resetAccumulatedStats(int device)
{
    TORCH_NPU_WARN(
        "NPUPluggableAllocator does not yet support resetAccumulatedStats. "
        "If you need it, please file an issue describing your use case.");
}
// Forwards to the user-supplied reset-peak callback; emits a warning instead
// when no callback has been registered.
void NPUPluggableAllocator::resetPeakStats(int device)
{
    if (!reset_peak_status_fn_) {
        TORCH_NPU_WARN("reset_peak_status_fn_ is not define, please set by set_reset_peak_status_fn");
        return;
    }
    reset_peak_status_fn_(device);
}
// Unsupported hook: emits a warning and returns a value-initialized SnapshotInfo.
c10_npu::NPUCachingAllocator::SnapshotInfo NPUPluggableAllocator::snapshot()
{
    TORCH_NPU_WARN("NPUPluggableAllocator does not yet support snapshot. "
        "If you need it, please file an issue describing your use case.");
    // The function previously had no return statement, which is undefined
    // behavior for a non-void function; return an empty snapshot instead.
    return c10_npu::NPUCachingAllocator::SnapshotInfo{};
}
// Forwards to the user-supplied begin-allocate-to-pool callback, handing over
// ownership of `filter`; a no-op when no callback is registered.
void NPUPluggableAllocator::beginAllocateToPool(
    c10::DeviceIndex device,
    c10_npu::MempoolId_t mempool_id,
    std::function<bool(aclrtStream)> filter)
{
    if (!begin_allocate_to_pool_fn_) {
        return;
    }
    begin_allocate_to_pool_fn_(device, mempool_id, std::move(filter));
}
// Forwards to the user-supplied end-allocate-to-pool callback; a no-op when
// no callback is registered.
void NPUPluggableAllocator::endAllocateToPool(
    c10::DeviceIndex device,
    c10_npu::MempoolId_t mempool_id)
{
    if (!end_allocate_to_pool_fn_) {
        return;
    }
    end_allocate_to_pool_fn_(device, mempool_id);
}
// Forwards to the user-supplied release-pool callback; a no-op when no
// callback is registered.
void NPUPluggableAllocator::releasePool(
    c10::DeviceIndex device,
    c10_npu::MempoolId_t mempool_id)
{
    if (!release_pool_fn_) {
        return;
    }
    release_pool_fn_(device, mempool_id);
}
// Unsupported hook: only emits a warning; `device` is ignored.
void NPUPluggableAllocator::FreeDeviceCachedMemory(int device)
{
    TORCH_NPU_WARN(
        "NPUPluggableAllocator does not yet support FreeDeviceCachedMemory. "
        "If you need it, please file an issue describing your use case.");
}
// Identifier string of this allocator backend.
std::string NPUPluggableAllocator::name()
{
    return std::string("pluggable");
}
// Copies `count` bytes from `src` to `dest` with a device-to-device
// aclrtMemcpy; the destination capacity passed to ACL equals `count`.
void NPUPluggableAllocator::copy_data(void* dest, const void* src, std::size_t count) const
{
    NPU_CHECK_ERROR(
        aclrtMemcpy(dest, count, src, count, ACL_MEMCPY_DEVICE_TO_DEVICE));
}
// Unsupported hook: emits a warning and returns an empty shared_ptr.
std::shared_ptr<void> NPUPluggableAllocator::getIpcDevPtr(std::string handle)
{
    TORCH_NPU_WARN(
        "NPUPluggableAllocator does not yet support getIpcDevPtr. "
        "If you need it, please file an issue describing your use case.");
    return std::shared_ptr<void>();
}
// Unsupported hook: emits a warning and returns a handle built from {0, nullptr}.
c10_npu::NPUCachingAllocator::ShareableHandle NPUPluggableAllocator::shareIpcHandle(void* ptr)
{
    TORCH_NPU_WARN(
        "NPUPluggableAllocator does not yet support shareIPcHandle. "
        "If you need it, please file an issue describing your use case.");
    c10_npu::NPUCachingAllocator::ShareableHandle empty_handle{0, nullptr};
    return empty_handle;
}
// Unsupported hook: only emits a warning; all arguments are ignored.
void NPUPluggableAllocator::recordHistory(
    bool enabled,
    c10_npu::NPUCachingAllocator::CreateContextFn context_recorder,
    size_t alloc_trace_max_entries,
    c10_npu::NPUCachingAllocator::RecordContext when)
{
    TORCH_NPU_WARN(
        "NPUPluggableAllocator does not yet support recordHistory. "
        "If you need it, please file an issue describing your use case.");
}
// Unsupported hook: only emits a warning; the observer is not stored.
void NPUPluggableAllocator::attachOutOfMemoryObserver(
    c10_npu::NPUCachingAllocator::OutOfMemoryObserver observer)
{
    TORCH_NPU_WARN(
        "NPUPluggableAllocator does not yet support attachOutOfMemoryObserver. "
        "If you need it, please file an issue describing your use case.");
}
// Unsupported hook: emits a warning and always reports false.
bool NPUPluggableAllocator::checkUceInMemPool(int device)
{
    TORCH_NPU_WARN(
        "NPUPluggableAllocator does not yet support checkUceInMemPool. "
        "If you need it, please file an issue describing your use case.");
    constexpr bool kResult = false;
    return kResult;
}
// Unsupported hook: emits a warning and always reports false.
bool NPUPluggableAllocator::checkBlockIsSafe(const c10::DataPtr& ptr)
{
    TORCH_NPU_WARN(
        "NPUPluggableAllocator does not yet support checkBlockIsSafe. "
        "If you need it, please file an issue describing your use case.");
    constexpr bool kResult = false;
    return kResult;
}
// Unsupported hook: only emits a warning; `device` is ignored.
void NPUPluggableAllocator::markAllBlockUnsafe(int device)
{
    TORCH_NPU_WARN(
        "NPUPluggableAllocator does not yet support markAllBlockUnsafe. "
        "If you need it, please file an issue describing your use case.");
}
// Unsupported hook: only emits a warning; `ptr` is ignored.
void NPUPluggableAllocator::updateBlockToSafe(const c10::DataPtr& ptr)
{
    TORCH_NPU_WARN(
        "NPUPluggableAllocator does not yet support updateBlockToSafe. "
        "If you need it, please file an issue describing your use case.");
}
// Unsupported hook: only emits a warning.
void NPUPluggableAllocator::cleanEvent()
{
    TORCH_NPU_WARN(
        "NPUPluggableAllocator does not yet support cleanEvent. "
        "If you need it, please file an issue describing your use case.");
}
// Unsupported hook: emits a warning and returns a null shared_ptr.
std::shared_ptr<c10_npu::NPUCachingAllocator::AllocatorState> NPUPluggableAllocator::getCheckpointState(c10::DeviceIndex device, c10_npu::MempoolId_t id)
{
    TORCH_NPU_WARN(
        "NPUPluggableAllocator does not yet support getCheckpointState. "
        "If you need it, please file an issue describing your use case.");
    // The function previously had no return statement, which is undefined
    // behavior for a non-void function; callers now receive a null state.
    return nullptr;
}
// Unsupported hook: emits a warning and returns a value-initialized CheckpointDelta.
c10_npu::NPUCachingAllocator::CheckpointDelta NPUPluggableAllocator::setCheckpointPoolState(c10::DeviceIndex device, std::shared_ptr<c10_npu::NPUCachingAllocator::AllocatorState> pps)
{
    TORCH_NPU_WARN(
        "NPUPluggableAllocator does not yet support setCheckpointPoolState. "
        "If you need it, please file an issue describing your use case.");
    // The function previously had no return statement, which is undefined
    // behavior for a non-void function; return an empty delta instead.
    return c10_npu::NPUCachingAllocator::CheckpointDelta{};
}
// Allocator most recently installed through changeCurrentAllocator(); empty
// until then. custom_raw_deleter() routes deallocations through it.
std::shared_ptr<c10_npu::NPUCachingAllocator::NPUAllocator>
current_custom_allocator;
// Returns a shared handle to the allocator stored in current_custom_allocator
// (empty when none has been installed).
std::shared_ptr<c10_npu::NPUCachingAllocator::NPUAllocator> getCurrentAllocator()
{
    return current_custom_allocator;
}
// Creates a pluggable allocator around the supplied alloc/free callbacks,
// initializes it with the number of devices and returns it. The new allocator
// is NOT installed as the current one; use changeCurrentAllocator() for that.
std::shared_ptr<c10_npu::NPUCachingAllocator::NPUAllocator> createCustomAllocator(
    std::function<void*(size_t, int, aclrtStream)> alloc_fn,
    std::function<void(void*, size_t, int, aclrtStream)> free_fn)
{
    std::shared_ptr<NPUPluggableAllocator> allocator(
        new NPUPluggableAllocator(std::move(alloc_fn), std::move(free_fn)));
    // c10_npu::GetDevice() yields the *current device index* (it is used that
    // way in allocate()), so writing its result into device_count passed an
    // index where a count is expected. Query the device count instead.
    // NOTE(review): relies on c10_npu::device_count() being visible through the
    // existing includes (presumably declared next to GetDevice) - confirm.
    device_count = static_cast<int>(c10_npu::device_count());
    allocator->init(device_count);
    return allocator;
}
// Installs `allocator` as the active NPU allocator.
// Fails when the allocator currently registered in
// c10_npu::NPUCachingAllocator::allocator reports itself as initialized.
// The shared_ptr is also stored in current_custom_allocator, which keeps the
// object alive after its raw pointer has been published.
void changeCurrentAllocator(
    const std::shared_ptr<c10_npu::NPUCachingAllocator::NPUAllocator>&
        allocator)
{
    const bool already_initialized =
        c10_npu::NPUCachingAllocator::allocator.load()->initialized();
    TORCH_CHECK(
        !already_initialized,
        "Can't swap an already initialized allocator", PTA_ERROR(ErrCode::PTR));

    c10_npu::NPUCachingAllocator::allocator.store(allocator.get());
    current_custom_allocator = allocator;
}
// Deleter used by DataPtrs created in NPUPluggableAllocator::allocate():
// routes the release through whichever allocator is stored in
// current_custom_allocator (which is dereferenced without a null check).
void custom_raw_deleter(void* ptr)
{
    auto* const active_allocator = current_custom_allocator.get();
    active_allocator->raw_delete(ptr);
}
}