* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "graph/manager/rdma_pool_allocator.h"
#include "graph/types.h"
#include "framework/common/debug/log.h"
#include "framework/common/debug/ge_log.h"
#include "graph/def_types.h"
#include "graph/ge_context.h"
#include "runtime/dev.h"
#include "graph/manager/mem_manager.h"
#include "common/math/math_util.h"
#include "graph_metadef/common/ge_common/util.h"
#include "acl/acl_rt.h"
namespace {
// Granularity (bytes) to which every pool request is rounded up.
constexpr size_t kAlignedSize = 512U;
// A free block is split only when the request uses at most this fraction of the block.
constexpr ge::float32_t kSplitBlockThreshold = 0.5F;

// Rounds `size` up to a multiple of kAlignedSize.
// A zero-byte request still reserves one aligned unit. If `size + kAlignedSize` would overflow size_t,
// SIZE_MAX is returned instead of a wrapped-around value.
inline size_t GetAlignedBlockSize(const size_t size) {
  if (size == 0U) {
    return kAlignedSize;
  }
  const bool would_overflow = (ge::CheckSizeTAddOverflow(size, kAlignedSize) != ge::SUCCESS);
  if (would_overflow) {
    return SIZE_MAX;
  }
  const size_t unit_count = (size + (kAlignedSize - 1U)) / kAlignedSize;
  return unit_count * kAlignedSize;
}

// True when a request of `size` bytes is small enough (<= kSplitBlockThreshold of block.size) that the
// remainder of `block` should be carved off into its own free block.
inline bool ShouldSplit(const ge::Block &block, const size_t size) {
  const ge::float64_t split_limit = static_cast<ge::float64_t>(block.size) * kSplitBlockThreshold;
  return static_cast<ge::float64_t>(size) <= split_limit;
}

// Only blocks that are not currently handed out may be coalesced with a neighbour.
inline bool CanMergeBlock(const ge::Block &block) { return block.allocated == false; }

// Strict-weak ordering for the free-block bin: ascending size, ties broken by ascending address.
// With this order, lower_bound on a key of {size, nullptr} yields the smallest block that fits.
bool BlockComp(const ge::Block *const left, const ge::Block *const right) {
  if (left->size == right->size) {
    return ge::PtrToValue(left->ptr) < ge::PtrToValue(right->ptr);
  }
  return left->size < right->size;
}
}
namespace ge {
// Records which memory type the pool draws from and orders the free-block bin with BlockComp
// (size first, then address). No memory is acquired here; see Initialize() and InitMemory().
RdmaPoolAllocator::RdmaPoolAllocator(const rtMemType_t memory_type)
    : memory_type_(memory_type), block_bin_(&BlockComp) {}
// Binds the pool to the MemManager allocator instance for this pool's memory type.
// The actual pool memory is obtained later by InitMemory().
// @return always SUCCESS.
Status RdmaPoolAllocator::Initialize() {
  MemoryAllocator &allocator = MemManager::Instance().MemInstance(memory_type_);
  memory_allocator_ = &allocator;
  return SUCCESS;
}
void RdmaPoolAllocator::Finalize() {
GELOGD("Rdma pool finalize start.");
auto it_block = allocated_blocks_.begin();
while (it_block != allocated_blocks_.end()) {
const auto block = it_block->second;
it_block = allocated_blocks_.erase(it_block);
delete block;
}
auto it_bin = block_bin_.begin();
while (it_bin != block_bin_.end()) {
const auto block = *it_bin;
it_bin = block_bin_.erase(it_bin);
delete block;
}
if (rdma_base_addr_ != nullptr) {
GELOGD("Start to free rdma pool memory.");
if ((memory_allocator_ == nullptr) || (memory_allocator_->FreeMemory(rdma_base_addr_) != SUCCESS)) {
GELOGW("Free rdma pool memory failed");
}
rdma_base_addr_ = nullptr;
}
}
// Allocates the pool's backing memory on the device from the current GE context. It then seeds the
// free-block bin with one Block spanning the whole region.
// @param mem_size  total pool size in bytes.
// @return SUCCESS on success.
//         GE_MULTI_INIT if the pool already holds backing memory.
//         GE_GRAPH_MALLOC_FAILED if the device memory or the initial Block cannot be allocated.
//         Otherwise, whatever the runtime / null-check macros propagate.
// On failure the pool is left uninitialised, so the call may be retried.
Status RdmaPoolAllocator::InitMemory(const size_t mem_size) {
  const auto device_id = GetContext().DeviceId();
  GELOGD("Init Rdma Memory with size [%zu] for devid:[%u].", mem_size, device_id);
  // Lock before the "already initialised" check so two concurrent callers cannot both pass it and both
  // allocate a pool (the check previously ran outside the lock).
  const std::lock_guard<std::recursive_mutex> lock(mutex_);
  if (rdma_base_addr_ != nullptr) {
    REPORT_INNER_ERR_MSG("E19999", "Param rdma_base_addr_ is not nullptr, devid:%u, check invalid", device_id);
    GELOGE(GE_MULTI_INIT, "[Check][Param] Rdma pool has been malloced, devid:%u", device_id);
    return GE_MULTI_INIT;
  }
  const std::string purpose = "Memory for rdma pool";
  const int32_t dev_id = static_cast<int32_t>(device_id);
  GE_CHK_RT_RET(aclrtSetDevice(dev_id));
  // Scope guard: resets the device set above when this function returns.
  GE_MAKE_GUARD(not_used_var, [&dev_id]() { GE_CHK_RT(aclrtResetDevice(dev_id)); });
  GE_CHECK_NOTNULL(memory_allocator_);
  rdma_base_addr_ = memory_allocator_->MallocMemory(purpose, mem_size, device_id);
  if (rdma_base_addr_ == nullptr) {
    GELOGE(GE_GRAPH_MALLOC_FAILED, "[Malloc][Memory] failed, size:%zu, device_id:%u", mem_size, device_id);
    return GE_GRAPH_MALLOC_FAILED;
  }
  auto *const base_block = new (std::nothrow) Block(device_id, mem_size, rdma_base_addr_);
  if (base_block == nullptr) {
    REPORT_INNER_ERR_MSG("E19999", "New Block failed, size:%zu, device_id:%u", mem_size, device_id);
    GELOGE(GE_GRAPH_MALLOC_FAILED, "[New][Block] failed, size:%zu, device_id:%u", mem_size, device_id);
    // Roll back the device allocation. Otherwise the pool would hold memory no block describes, and every
    // later InitMemory() would be rejected with GE_MULTI_INIT.
    if (memory_allocator_->FreeMemory(rdma_base_addr_) != SUCCESS) {
      GELOGW("Free rdma pool memory failed");
    }
    rdma_base_addr_ = nullptr;
    return GE_GRAPH_MALLOC_FAILED;
  }
  // Record the size only once the pool is fully set up.
  rdma_mem_size_ = mem_size;
  (void)block_bin_.insert(base_block);
  return SUCCESS;
}
uint8_t *RdmaPoolAllocator::Malloc(const size_t size, const uint32_t device_id) {
GELOGI("start to malloc rdma memory size:%zu, device id = %u.", size, device_id);
const auto aligned_size = GetAlignedBlockSize(size);
Block key(device_id, aligned_size, nullptr);
const std::lock_guard<std::recursive_mutex> lock(mutex_);
const auto it = block_bin_.lower_bound(&key);
if (it != block_bin_.end()) {
Block *block = *it;
(void)block_bin_.erase(it);
block->allocated = true;
if (block->ptr == nullptr) {
REPORT_INNER_ERR_MSG("E19999", "Rdmapool memory address is nullptr, device_id:%u, check invalid",
device_id);
GELOGE(INTERNAL_ERROR, "[Check][Param] Rdmapool memory address is nullptr, device_id:%u", device_id);
return nullptr;
}
(void)allocated_blocks_.emplace(block->ptr, block);
if (ShouldSplit(*block, aligned_size)) {
GELOGD("Block will be splited block size = %zu, aligned_size:%zu.", block->size, aligned_size);
auto *const new_block = new (std::nothrow) Block(device_id, block->size - aligned_size, nullptr,
PtrAdd(block->ptr, block->size, aligned_size));
if (new_block == nullptr) {
GELOGW("Block split failed");
return block->ptr;
}
new_block->next = block->next;
if (block->next != nullptr) {
block->next->prev = new_block;
}
new_block->prev = block;
block->next = new_block;
block->size = aligned_size;
(void)block_bin_.insert(new_block);
}
GELOGD("Find block size = %zu", block->size);
return block->ptr;
}
GELOGW("Memory block not founded.");
return nullptr;
}
// Returns a block previously handed out by Malloc() to the pool.
// The block is coalesced with its free predecessor and/or successor before re-entering the free bin.
// @param memory_addr  exact address returned by Malloc().
// @param device_id    used only in log / error messages.
// @return SUCCESS.
//         GE_GRAPH_FREE_FAILED for a null address.
//         PARAM_INVALID if the address is not currently allocated from this pool.
Status RdmaPoolAllocator::Free(uint8_t *const memory_addr, const uint32_t device_id) {
  GELOGI("Free rdma memory, device id = %u.", device_id);
  if (memory_addr == nullptr) {
    REPORT_INNER_ERR_MSG("E19999", "Param memory_addr is nullptr, device_id:%u, check invalid", device_id);
    GELOGE(GE_GRAPH_FREE_FAILED, "[Check][Param] Invalid memory pointer, device id:%u", device_id);
    return GE_GRAPH_FREE_FAILED;
  }
  const std::lock_guard<std::recursive_mutex> lock(mutex_);
  const auto entry = allocated_blocks_.find(memory_addr);
  if (entry == allocated_blocks_.end()) {
    REPORT_INNER_ERR_MSG("E19999", "Param memory_addr is not allocated before, device_id:%u, "
                         "check invalid", device_id);
    GELOGE(PARAM_INVALID, "[Check][Param] Invalid memory pointer, device id:%u", device_id);
    return PARAM_INVALID;
  }
  Block *const released = entry->second;
  (void)allocated_blocks_.erase(entry);
  released->allocated = false;
  // Snapshot both neighbours up front: merging with the predecessor rewrites released->prev.
  Block *const neighbours[] = {released->prev, released->next};
  for (Block *const neighbour : neighbours) {
    if (neighbour == nullptr) {
      continue;
    }
    MergeBlocks(*released, *neighbour);
  }
  (void)block_bin_.insert(released);
  return SUCCESS;
}
// Absorbs the adjacent block `src` into `dst` when both are free; otherwise does nothing.
// `src` must be dst's direct predecessor or successor in the address-ordered block list. It is removed
// from block_bin_ and deleted; `dst` takes over its address range and list links.
// If the combined size would overflow size_t, dst.size is clamped to SIZE_MAX.
void RdmaPoolAllocator::MergeBlocks(Block &dst, Block &src) {
  const bool both_free = CanMergeBlock(dst) && CanMergeBlock(src);
  if (!both_free) {
    return;
  }
  const bool src_is_predecessor = (dst.prev == &src);
  if (src_is_predecessor) {
    // dst now starts where its predecessor started and inherits the predecessor's back link.
    dst.ptr = src.ptr;
    dst.prev = src.prev;
    if (dst.prev != nullptr) {
      dst.prev->next = &dst;
    }
  } else {
    // src follows dst: dst inherits src's forward link.
    dst.next = src.next;
    if (dst.next != nullptr) {
      dst.next->prev = &dst;
    }
  }
  const bool size_fits = (CheckSizeTAddOverflow(dst.size, src.size) == SUCCESS);
  dst.size = size_fits ? (dst.size + src.size) : SIZE_MAX;
  // src still carries its original size/ptr here, so the keyed erase finds it in the bin.
  (void)block_bin_.erase(&src);
  delete &src;
}
// Reports the pool's backing region as an integer address plus its size in bytes.
// @param base_addr  receives the start address of the pool memory.
// @param mem_size   receives the size recorded by InitMemory().
// @return SUCCESS, or INTERNAL_ERROR (outputs untouched) if the pool has no backing memory.
// Note: reads the members without taking mutex_.
Status RdmaPoolAllocator::GetBaseAddr(uint64_t &base_addr, uint64_t &mem_size) const {
  if (rdma_base_addr_ != nullptr) {
    base_addr = PtrToValue(rdma_base_addr_);
    mem_size = rdma_mem_size_;
    return SUCCESS;
  }
  REPORT_INNER_ERR_MSG("E19999", "Param rdma_base_addr_ is nullptr, check invalid");
  GELOGE(INTERNAL_ERROR, "[Check][Param] Rdma base addr is nullptr.");
  return INTERNAL_ERROR;
}
}