* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "kernel/memory/device/device_allocator.h"
#include "framework/common/debug/ge_log.h"
#include "framework/runtime/device_memory_recorder.h"
#include "subscriber/profiler/cann_memory_profiler.h"
#include "kernel/memory/rts_caching_mem_allocator.h"
#include "graph/manager/mem_manager.h"
#include "common/aclrt_malloc_helper.h"
namespace gert {
DeviceAllocator::DeviceAllocator(DeviceMemAllocator &mem_allocator)
: mem_allocator_(mem_allocator) {
}
/**
 * Requests a block of `size` bytes from the backend allocator and keeps the usage statistics up to date.
 *
 * Every call increments alloc_count_, whether or not the backend succeeds. On success the requested size is
 * added to occupied_size_. On failure an error is logged only when log_error_if_alloc_failed_ is set.
 *
 * @param size number of bytes requested.
 * @return the block returned by the backend, or a null block when the backend allocation failed.
 */
BlockAddr DeviceAllocator::Alloc(const MemSize size) {
  ++alloc_count_;
  const auto block = mem_allocator_.Alloc(size);
  if (block == nullptr) {
    // Failure reporting is optional so that callers can silence it (see SetLogErrorIfAllocFailed).
    if (log_error_if_alloc_failed_) {
      GELOGE(ge::FAILED, "DeviceAllocator::Alloc failed device_id:%u, size:%llu", mem_allocator_.GetDeviceId(), size);
    }
    return block;
  }

  GELOGI("DeviceAllocator::Alloc device_id:%u, size:%llu, block_addr:%p", mem_allocator_.GetDeviceId(), size, block);
  occupied_size_ += size;
  return block;
}
/**
 * Returns a block to the backend allocator and updates the usage statistics.
 *
 * Every call increments free_count_, including calls with a null block, which are otherwise ignored.
 * occupied_size_ is reduced only when the backend reports that the block was freed.
 *
 * @param addr block to release; may be null.
 */
void DeviceAllocator::Free(ge::MemBlock *const addr) {
  ++free_count_;
  if (addr == nullptr) {
    return;
  }

  // Read the size first: a backend such as RtMemAllocator hands the block object back to its block allocator
  // on success, so the block must not be queried afterwards.
  const auto block_size = addr->GetSize();
  if (!mem_allocator_.Free(addr)) {
    return;
  }

  GELOGI("DeviceAllocator::free device_id:%u ,size:%llu, block_addr:%p", mem_allocator_.GetDeviceId(), block_size,
         addr);
  // Guard against underflow of the running total.
  if (occupied_size_ >= block_size) {
    occupied_size_ -= block_size;
  }
}
void DeviceAllocator::SetLogErrorIfAllocFailed(bool log_error_if_alloc_failed) {
log_error_if_alloc_failed_ = log_error_if_alloc_failed;
}
/**
 * Creates an allocator that obtains device memory through ge::AclrtMalloc.
 *
 * @param allocator allocator passed to every ge::MemBlock this object constructs.
 * @param device_id device identifier used in logs and error reports.
 * @param mem_type  memory type passed to ge::AclrtMalloc on each allocation.
 */
RtMemAllocator::RtMemAllocator(ge::Allocator &allocator, const DeviceId device_id, const uint32_t mem_type)
    : allocator_{allocator},
      device_id_{device_id},
      mem_type_{mem_type} {}
/**
 * Allocates `size` bytes of device memory with ge::AclrtMalloc and wraps it in a ge::MemBlock.
 *
 * The ge::MemBlock is constructed in storage obtained from block_allocator_. If that storage cannot be
 * obtained, the device memory is released again so that nothing leaks, and the reserve-memory accounting is
 * left untouched.
 *
 * @param size number of bytes requested.
 * @return the new block, or nullptr when either the device allocation or the block storage allocation failed.
 */
BlockAddr RtMemAllocator::Alloc(const MemSize size) {
  void *ptr = nullptr;
  const auto rt_ret = ge::AclrtMalloc(&ptr, size, mem_type_, GE_MODULE_NAME_U16);
  if (rt_ret != ACL_SUCCESS) {
    REPORT_INNER_ERR_MSG("E19999", "Call aclrtMalloc fail, purpose: page caching, type = %u, size:%llu, device_id:%u",
                         mem_type_, size, device_id_);
    GELOGE(ge::INTERNAL_ERROR, "[Malloc][Memory] failed, rt_ret:%d, device_id:%u, size:%llu", rt_ret, device_id_,
           size);
    return nullptr;
  }

  // Placement-new on a null pointer is undefined behaviour, so validate the storage before constructing.
  void *const block_storage = block_allocator_.Alloc();
  if (block_storage == nullptr) {
    // Without a MemBlock nobody would own `ptr`; give the device memory back before reporting the failure.
    const auto free_ret = aclrtFree(ptr);
    GELOGE(ge::INTERNAL_ERROR,
           "[Alloc][MemBlock] failed, device memory released with rt_ret:%d, device_id:%u, size:%llu", free_ret,
           device_id_, size);
    return nullptr;
  }

  // Account for the memory only once the allocation is certain to be handed to the caller.
  DeviceMemoryRecorder::AddTotalReserveMemory(static_cast<uint64_t>(size));
  GE_PRINT_DYNAMIC_MEMORY(aclrtMalloc, ge::ToMallocMemInfo("page caching", ptr, device_id_, GE_MODULE_NAME_U16).c_str(),
                          size);
  return new (block_storage) ge::MemBlock{allocator_, ptr, static_cast<size_t>(size)};
}
/**
 * Releases the device memory owned by `addr` with aclrtFree and, on success, returns the block object to
 * block_allocator_ and reduces the reserve-memory accounting by the block's size.
 *
 * When aclrtFree fails, the block is left untouched and the failure is reported.
 *
 * @param addr block to release; a null block is rejected without touching the device.
 * @return true when the device memory was freed, false otherwise.
 */
bool RtMemAllocator::Free(ge::MemBlock *const addr) {
  if (addr == nullptr) {
    GELOGE(ge::FAILED, "[Check][Param] RtMemAllocator::free got null block, device_id:%u", device_id_);
    return false;
  }

  // Capture both values up front: the block object is given back to block_allocator_ on the success path.
  const auto size = addr->GetSize();
  auto *const mem_addr = addr->GetAddr();
  GELOGI("RtMemAllocator::free device_id:%u ,size:%llu, mem_addr:%p", GetDeviceId(), size, mem_addr);

  const auto rt_ret = aclrtFree(mem_addr);
  if (rt_ret != ACL_SUCCESS) {
    REPORT_INNER_ERR_MSG("E19999", "Call aclrtFree fail, device_id:%u", device_id_);
    GELOGE(ge::FAILED, "[Call][RtFree] failed, rt_ret:%d, device_id:%u, addr:%p, size:%llu",
           rt_ret, device_id_, mem_addr, size);
    return false;
  }

  // `addr` already has type ge::MemBlock *, so no cast is needed to obtain the reference.
  block_allocator_.Free(*addr);
  DeviceMemoryRecorder::ReduceTotalReserveMemory(static_cast<uint64_t>(size));
  return true;
}
}