* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <map>
#include "common/plugin/ge_make_unique_util.h"
#include "base/err_msg.h"
#include "caching_mem_allocator.h"
#include "runtime/rt.h"
#include "framework/common/debug/ge_log.h"
#include "common/checker.h"
#include "utils/utils.h"
#include "multi_stream_mem_block_helper.h"
#include "graph/load/model_manager/model_utils.h"
#include "acl/acl_rt.h"
namespace gert {
namespace memory {
// Per-thread list of allocators; the registering constructors in this file append `this` to the list of the
// thread they run on.
thread_local std::vector<CachingMemAllocator *> CachingMemAllocator::same_thread_allocators_ = {};
// List shared by all threads; the registering constructors append `this` to it while holding mutex_.
std::vector<CachingMemAllocator *> CachingMemAllocator::all_caching_mem_allocators_ = {};
// Locked by the registering constructors and by AllocateWithTryRecycle while it walks same_thread_allocators_.
std::mutex CachingMemAllocator::mutex_;
// Allocates `size` bytes from this allocator's memory pool, reclaiming cached memory when the pool is exhausted.
//
// Escalation order:
//   1. plain allocation from memory_pool_;
//   2. synchronize this allocator's stream, recycle its pool and retry;
//   3. under mutex_, for every other allocator in same_thread_allocators_: synchronize it, recycle its pool and
//      retry on this allocator's pool, stopping at the first success.
//
// Returns the allocated block, or nullptr when every attempt failed. A failed GE_ASSERT_SUCCESS leaves the
// function early through the macro's own return path.
ge::MemBlock *CachingMemAllocator::AllocateWithTryRecycle(size_t size) {
  auto block = memory_pool_->Alloc(*this, size);
  if (block != nullptr) {
    return block;
  }

  // First attempt failed: report it, dump the pool state, then release what this allocator is caching.
  GELOGE(ge::MEMALLOC_FAILED,
         "%s Failed to apply for memory. We will try to free memory from memory pool, the above or this error log can "
         "be ignored. Try to free cached memory...",
         memory_pool_->GetId().c_str());
  memory_pool_->PrintDetails(DLOG_INFO);
  GELOGI("will synchronize on stream %p", stream_);
  GE_ASSERT_SUCCESS(Synchronize());
  Recycle();

  block = memory_pool_->Alloc(*this, size);
  if (block != nullptr) {
    return block;
  }

  // Still nothing: reclaim memory cached by the other allocators registered on this thread, one at a time.
  GELOGI("addr is nullptr, try to free other allocator memory and malloc again");
  const std::lock_guard<std::mutex> lock(mutex_);
  for (size_t i = 0U; i < same_thread_allocators_.size(); ++i) {
    if (same_thread_allocators_[i] == this) {
      continue;  // this allocator's own pool was already recycled above
    }
    // The asserted expression is kept verbatim in case the macro embeds its text in the failure log.
    GE_ASSERT_SUCCESS(same_thread_allocators_[i]->Synchronize());
    CachingMemAllocator *const peer = same_thread_allocators_[i];
    peer->memory_pool_->Recycle();

    block = memory_pool_->Alloc(*this, size);
    if (block != nullptr) {
      break;
    }
  }
  return block;
}
// Allocates `size` bytes via AllocateWithTryRecycle and, on success, adds the block's size to the
// DeviceMemoryRecorder total. Returns nullptr when no block could be obtained.
ge::MemBlock *CachingMemAllocator::Malloc(size_t size) {
  GELOGI("Malloc size:%zu.", size);
  ge::MemBlock *const block = AllocateWithTryRecycle(size);
  if (block == nullptr) {
    return nullptr;
  }
  // Only successful allocations are counted, using the size reported by the block itself.
  DeviceMemoryRecorder::AddTotalAllocateMemory(static_cast<uint64_t>(block->GetSize()));
  return block;
}
// Builds a caching allocator for `device_id` on top of the RT_MEMORY_HBM RTS allocator, with a
// default-constructed ScalableConfig for its memory pool.
// NOTE(review): unlike the (graph_name, device_id, memory_type) and (device_id, memory_type, config)
// constructors in this file, this one does not append `this` to same_thread_allocators_ or
// all_caching_mem_allocators_ - confirm that is intentional.
// NOTE(review): the pointer returned by RtsCachingMemAllocator::GetAllocator is dereferenced without a null
// check - confirm it can never be nullptr here.
CachingMemAllocator::CachingMemAllocator(const uint32_t device_id)
: rts_mem_allocator_(*RtsCachingMemAllocator::GetAllocator(device_id, RT_MEMORY_HBM), device_id, "rt2 memory pool"),
memory_pool_(new ScalableAllocator(span_allocator_, rts_mem_allocator_, ScalableConfig())) {
GELOGI("create caching memory allocator, %s", memory_pool_->GetId().c_str());
}
// Delegates to the (graph_name, device_id, memory_type) constructor with an empty graph name.
CachingMemAllocator::CachingMemAllocator(const uint32_t device_id, const rtMemType_t memory_type)
: CachingMemAllocator("", device_id, memory_type) {}
// Builds a caching allocator for `device_id` on top of the RTS allocator for `memory_type`; `graph_name` is
// forwarded to the ScalableAllocator together with a default-constructed ScalableConfig.
// The new instance is appended to both the per-thread and the shared allocator lists.
// NOTE(review): the pointer returned by RtsCachingMemAllocator::GetAllocator is dereferenced without a null
// check - confirm it can never be nullptr here.
CachingMemAllocator::CachingMemAllocator(const std::string &graph_name, const uint32_t device_id,
const rtMemType_t memory_type)
: rts_mem_allocator_(*RtsCachingMemAllocator::GetAllocator(device_id, memory_type), device_id, "rt2 memory pool"),
memory_pool_(new ScalableAllocator(span_allocator_, rts_mem_allocator_, ScalableConfig(), graph_name)) {
// Registration is done while holding mutex_; the lock stays held until the end of the constructor body.
const std::lock_guard<std::mutex> lock(mutex_);
same_thread_allocators_.emplace_back(this);
all_caching_mem_allocators_.emplace_back(this);
GELOGI("create caching memory allocator, %s", memory_pool_->GetId().c_str());
}
// Builds a caching allocator for `device_id` on top of the RTS allocator for `memory_type`, passing the
// caller-supplied `config` to the ScalableAllocator.
// The new instance is appended to both the per-thread and the shared allocator lists.
// NOTE(review): the pointer returned by RtsCachingMemAllocator::GetAllocator is dereferenced without a null
// check - confirm it can never be nullptr here.
CachingMemAllocator::CachingMemAllocator(const uint32_t device_id, const rtMemType_t memory_type,
ScalableConfig &config)
: rts_mem_allocator_(*RtsCachingMemAllocator::GetAllocator(device_id, memory_type), device_id, "rt2 memory pool"),
memory_pool_(new ScalableAllocator(span_allocator_, rts_mem_allocator_, config)) {
// Registration is done while holding mutex_; the lock stays held until the end of the constructor body.
const std::lock_guard<std::mutex> lock(mutex_);
same_thread_allocators_.emplace_back(this);
all_caching_mem_allocators_.emplace_back(this);
GELOGI("create caching memory allocator, %s", memory_pool_->GetId().c_str());
}
// Convenience overload: creates an allocator for `device_id` with an empty graph name and RT_MEMORY_HBM
// as the memory type. Returns whatever the three-argument overload returns.
std::unique_ptr<CachingMemAllocator> CachingMemAllocator::GetAllocator(const uint32_t device_id) {
  const std::string unnamed_graph("");
  return GetAllocator(unnamed_graph, device_id, RT_MEMORY_HBM);
}
// Creates a CachingMemAllocator for (`graph_name`, `device_id`, `rt_mem_type`).
// When creation succeeded and ge::ModelUtils::IsGeUseExtendSizeMemory(true) holds, the allocator's scalable
// allocator (if it has one) additionally gets InitExpandableAllocator called on it.
// Returns the result of ge::MakeUnique unchanged, which may be null.
std::unique_ptr<CachingMemAllocator> CachingMemAllocator::GetAllocator(const std::string &graph_name,
                                                                       const uint32_t device_id,
                                                                       const rtMemType_t rt_mem_type) {
  auto created = ge::MakeUnique<CachingMemAllocator>(graph_name, device_id, rt_mem_type);
  if (created == nullptr) {
    return created;
  }
  // The extend-size query is only made once an allocator actually exists.
  if (!ge::ModelUtils::IsGeUseExtendSizeMemory(true)) {
    return created;
  }
  auto scalable = created->GetScalableAllocator();
  if (scalable != nullptr) {
    scalable->InitExpandableAllocator(*created, rt_mem_type);
  }
  return created;
}
// Creates an allocator for the device reported by aclrtGetDevice.
// Returns nullptr, after logging and reporting the error, when the device query fails.
std::unique_ptr<CachingMemAllocator> CachingMemAllocator::GetAllocator() {
  int32_t device_id = 0;
  const auto rt_result = aclrtGetDevice(&device_id);
  if (rt_result == ACL_SUCCESS) {
    return GetAllocator(device_id);
  }
  // Device query failed: emit both the log entry and the inner error report before giving up.
  GELOGE(ge::RT_FAILED, "[Get][aclrtGetDevice] Failed, result:%d.", rt_result);
  REPORT_INNER_ERR_MSG("E19999", "aclrtGetDevice failed, result:%d.", rt_result);
  return nullptr;
}
// Finalizes the underlying memory pool, forwarding `no_log` unchanged, and returns the pool's status.
ge::Status CachingMemAllocator::Finalize(bool no_log) {
  const ge::Status pool_status = memory_pool_->Finalize(no_log);
  return pool_status;
}
// Synchronizes this allocator's stream (stream_) through DoRtStreamSyncWithTimeout.
// Returns ge::SUCCESS when the assertion below passes; on failure the function leaves through the
// GE_ASSERT_SUCCESS macro's own return path.
ge::Status CachingMemAllocator::Synchronize() const {
GE_ASSERT_SUCCESS(DoRtStreamSyncWithTimeout(stream_));
return ge::SUCCESS;
}
// Asks this allocator's memory pool to recycle; this is a plain forward to memory_pool_->Recycle().
void CachingMemAllocator::Recycle() {
  auto &pool = *memory_pool_;
  pool.Recycle();
}
}
}