* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "graph/build/memory/dynamic_batch_mem_assigner.h"
#include <string>
#include "graph/utils/type_utils.h"
#include "common/checker.h"
namespace ge {
// Tells whether a block takes part in dynamic-batch reuse: the block must report
// IsSameBatchLabel() and additionally have either reuse_mem_ or
// need_same_offset_in_batch_ set.
bool DynamicBatchBlockReuse(MemoryBlock &block) {
  if (!block.IsSameBatchLabel()) {
    return false;
  }
  return ((block.reuse_mem_) || (block.need_same_offset_in_batch_));
}
// Ordering predicate used with std::sort: larger blocks (by Size()) come first.
//
// Null pointers are ordered after every non-null block and are equivalent to each
// other. The previous implementation returned false whenever either side was null,
// which made a null "equivalent" to every block while blocks of different sizes were
// not equivalent to each other; that is not a strict weak ordering and is undefined
// behaviour when handed to std::sort.
struct CompareSize {
  explicit CompareSize() {}
  // Returns true when `left` must be placed before `right`.
  bool operator() (const MemoryBlock *const left, const MemoryBlock *const right) const {
    if ((left == nullptr) || (right == nullptr)) {
      // Only a non-null block precedes a null one; null vs. null is equivalent.
      return (left != nullptr) && (right == nullptr);
    }
    return (left->Size() > right->Size());
  }
};
// Distributes `blocks` over the blocks in `max_batch_blocks`.
//
// Both vectors are sorted in place with CompareSize (largest Size() first). Then every
// entry of `max_batch_blocks` is Reset() and offered the entries of `blocks` in order
// through AddBatchChildBlock(); entries whose child_block_ flag is already set are
// skipped, each remaining entry is Resize()d before it is offered, and the first
// refusal ends the scan for that entry of `max_batch_blocks`.
void AddBlockToMaxBatchBlock(std::vector<MemoryBlock *> &max_batch_blocks, std::vector<MemoryBlock *> &blocks) {
  std::sort(max_batch_blocks.begin(), max_batch_blocks.end(), CompareSize());
  std::sort(blocks.begin(), blocks.end(), CompareSize());

  for (MemoryBlock *const parent : max_batch_blocks) {
    parent->Reset();
    for (MemoryBlock *const candidate : blocks) {
      if (candidate->child_block_) {
        continue;
      }
      candidate->Resize();
      const bool accepted = parent->AddBatchChildBlock(candidate);
      if (!accepted) {
        break;
      }
    }
  }
}
// Calls Resize() on every block in `blocks` and returns the sum of their Size()
// values as observed right after each Resize().
size_t GetBlocksSize(std::vector<MemoryBlock *> &blocks) {
  size_t total = 0U;
  for (MemoryBlock *const current : blocks) {
    current->Resize();
    total += current->Size();
  }
  return total;
}
// Partitions `blocks` (in order) into groups that are appended to `new_blocks`.
//
// Every block is Resize()d first. When `use_extend_size_memory` is true, a new group
// is started whenever adding the next block would push the running group size above
// kMaxSplitSizeForDynamicBatch; otherwise all blocks end up in a single group.
//
// A group is only flushed when it already holds at least one block. Previously a
// first block that alone exceeded the limit caused an empty leading group to be
// emitted. The final flush is unconditional, so an empty `blocks` still produces one
// (empty) group, as before.
void CreateBlocks(std::vector<MemoryBlock *> &blocks, std::vector<std::vector<MemoryBlock *>> &new_blocks,
                  bool use_extend_size_memory) {
  size_t block_size = 0U;
  std::vector<MemoryBlock *> split_blocks;
  for (auto block : blocks) {
    block->Resize();
    const bool exceeds_limit =
        use_extend_size_memory && ((block_size + block->Size()) > kMaxSplitSizeForDynamicBatch);
    if (exceeds_limit && (!split_blocks.empty())) {
      new_blocks.emplace_back(std::move(split_blocks));
      // Make the reuse of the moved-from vector explicit instead of relying on its state.
      split_blocks.clear();
      block_size = 0U;
    }
    block_size += block->Size();
    split_blocks.emplace_back(block);
  }
  new_blocks.emplace_back(std::move(split_blocks));
}
void DynamicBatchMemAssigner::ResizeDynamicBatchBlocks() {
use_extend_size_memory_ = ge::VarManager::IsGeUseExtendSizeMemory();
struct DynamicBatchBlocsk {
std::map<std::string, std::vector<MemoryBlock *>> dynamic_batch_blocks;
std::map<std::string, std::vector<MemoryBlock *>> need_same_offset_blocks;
std::map<std::string, std::vector<MemoryBlock *>> continuous_input_blocks;
};
std::map<uint64_t, DynamicBatchBlocsk> memory_type_to_batch_blocks;
for (auto block : memory_blocks_) {
if ((block == nullptr) || (block->child_block_) || (block->is_zero_copy_)) {
continue;
}
if (DynamicBatchBlockReuse(*block)) {
const auto memory_type = block->memory_type_;
if (block->need_same_offset_in_batch_) {
memory_type_to_batch_blocks[memory_type].need_same_offset_blocks[block->batch_label_].emplace_back(block);
} else if (block->continuous_block_) {
memory_type_to_batch_blocks[memory_type].continuous_input_blocks[block->batch_label_].emplace_back(block);
} else {
memory_type_to_batch_blocks[memory_type].dynamic_batch_blocks[block->batch_label_].emplace_back(block);
}
}
}
for (auto &memory_type_batch_blocks : memory_type_to_batch_blocks) {
GELOGI("Process memory type:%lu", memory_type_batch_blocks.first);
DoResizeDynamicBatchBlocks(memory_type_batch_blocks.second.dynamic_batch_blocks);
DoResizeDynamicBatchBlocks(memory_type_batch_blocks.second.need_same_offset_blocks, true);
DoResizeDynamicBatchBlocks(memory_type_batch_blocks.second.continuous_input_blocks, true);
}
}
// Builds one aggregate MemoryBlock that adopts every usable block of
// `continuous_blocks` as a child and returns it.
//
// A block is usable when it is non-null and its NodeTypeIndexList() is not empty.
// The first usable block whose front node is non-null acts as the anchor: its node,
// life begin, batch label, stream id and memory type are copied to the aggregate.
// The aggregate's size is the sum of the children's Size() (after Resize()) and its
// life end is the maximum GetLifeEnd() over the children.
//
// Returns nullptr (through the GE assert macros) when `continuous_blocks` is empty,
// when no anchor exists, or when allocation fails.
//
// The anchor is located before anything is allocated or registered. Previously the
// aggregate was pushed into blocks_store_/memory_blocks_ and children were attached
// first, so a failed parent-node check left a size-0 placeholder registered.
MemoryBlock *DynamicBatchMemAssigner::CreateContinuousBlock(std::vector<MemoryBlock *> &continuous_blocks) {
  GE_WARN_ASSERT(!continuous_blocks.empty());

  // Pass 1: find the anchor and snapshot its attributes without mutating anything.
  const ge::Node *parent_node = nullptr;
  size_t life_begin = 0U;
  std::string batch_label;
  int64_t parent_node_stream_id = 0;
  uint64_t memory_type = RT_MEMORY_HBM;
  for (auto block : continuous_blocks) {
    if ((block == nullptr) || (block->NodeTypeIndexList().empty())) {
      continue;
    }
    parent_node = block->NodeTypeIndexList().front().node_;
    if (parent_node != nullptr) {
      life_begin = block->GetLifeBegin();
      batch_label = block->batch_label_;
      parent_node_stream_id = block->stream_id_;
      memory_type = block->memory_type_;
      break;
    }
  }
  GE_ASSERT_TRUE(parent_node != nullptr);

  auto continuous_block = new (std::nothrow) MemoryBlock(reuse_strategy_, 0, -1, true);
  GE_ASSERT_TRUE(continuous_block != nullptr);
  blocks_store_.emplace_back(continuous_block);
  memory_blocks_.emplace_back(continuous_block);

  // Pass 2: adopt every usable block and accumulate size / life end.
  size_t continuous_block_size = 0U;
  size_t life_end = 0U;
  for (auto block : continuous_blocks) {
    if ((block == nullptr) || (block->NodeTypeIndexList().empty())) {
      continue;
    }
    block->Resize();
    continuous_block_size += block->Size();
    continuous_block->AddChildBlock(block);
    life_end = std::max(life_end, block->GetLifeEnd(block->stream_id_));
  }

  continuous_block->SetSize(continuous_block_size);
  continuous_block->stream_id_ = parent_node_stream_id;
  continuous_block->batch_label_ = batch_label;
  continuous_block->memory_type_ = memory_type;
  continuous_block->AddNodeTypeIndex({parent_node, kWorkspace, 0, false, life_begin,
                                      parent_node_stream_id, true},
                                     continuous_block_size, continuous_block_size, parent_node_stream_id);
  continuous_block->SetLifeTimeEnd(life_end, parent_node_stream_id);
  GELOGI("%s continuous block size:%zu memory_type:%lu", continuous_block->batch_label_.c_str(),
         continuous_block->Size(), memory_type);
  return continuous_block;
}
// Folds the blocks of all batch labels onto the label with the largest total size.
//
// The total of each label is obtained with GetBlocksSize() (which also Resize()s the
// blocks). The first label whose total is strictly greater than every earlier one
// wins; if no such label exists in the map, nothing is done.
//
// continuous == true : one aggregate block is created per label through
//                      CreateContinuousBlock(); each other label's aggregate is
//                      offered to the largest label's aggregate with AddBatchChildBlock().
// continuous == false: the largest label's blocks are split with CreateBlocks(), one
//                      aggregate is created per split group, and every other label's
//                      blocks are distributed with AddBlockToMaxBatchBlock().
void DynamicBatchMemAssigner::DoResizeDynamicBatchBlocks(
    std::map<std::string, std::vector<MemoryBlock *>> &dynamic_batch_blocks, bool continuous) {
  std::string largest_label;
  size_t largest_total = 0U;
  for (auto &label_and_blocks : dynamic_batch_blocks) {
    const auto total = GetBlocksSize(label_and_blocks.second);
    if (total > largest_total) {
      largest_total = total;
      largest_label = label_and_blocks.first;
    }
  }

  const auto largest_it = dynamic_batch_blocks.find(largest_label);
  if (largest_it == dynamic_batch_blocks.end()) {
    return;
  }

  if (!continuous) {
    std::vector<std::vector<MemoryBlock *>> split_groups;
    CreateBlocks(largest_it->second, split_groups, use_extend_size_memory_);

    std::vector<MemoryBlock *> parents;
    for (auto &group : split_groups) {
      MemoryBlock *const parent = CreateContinuousBlock(group);
      if (parent != nullptr) {
        parents.emplace_back(parent);
      }
    }

    for (auto &label_and_blocks : dynamic_batch_blocks) {
      if (label_and_blocks.first != largest_label) {
        AddBlockToMaxBatchBlock(parents, label_and_blocks.second);
      }
    }
    return;
  }

  MemoryBlock *const largest_parent = CreateContinuousBlock(largest_it->second);
  for (auto &label_and_blocks : dynamic_batch_blocks) {
    if (label_and_blocks.first == largest_label) {
      continue;
    }
    // The aggregate is created even when it cannot be attached afterwards.
    MemoryBlock *const other_parent = CreateContinuousBlock(label_and_blocks.second);
    if ((largest_parent == nullptr) || (other_parent == nullptr)) {
      continue;
    }
    largest_parent->Reset();
    (void) largest_parent->AddBatchChildBlock(other_parent);
  }
}
}