/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
#ifndef AIR_CXX_RUNTIME_V2_KERNEL_MEMORY_MULTI_STREAM_L_2_MEM_BLOCK_H_
#define AIR_CXX_RUNTIME_V2_KERNEL_MEMORY_MULTI_STREAM_L_2_MEM_BLOCK_H_
#include <cstdint>
#include <vector>
#include <list>
#include "exe_graph/runtime/gert_mem_block.h"
#include "exe_graph/runtime/gert_mem_allocator.h"
#include "exe_graph/runtime/continuous_vector.h"
#include "mif.h"
#include "block_alloc_type.h"
namespace gert {
namespace memory {
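/**
 * A memory block that supports multiple streams. For any given piece of memory there is exactly one
 * instance of this class. When the birth recycle of that memory happens, this instance can be
 * recycled at the same time.
 */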
class MultiStreamMemBlock : public GertMemBlock {
public:
MultiStreamMemBlock();
~MultiStreamMemBlock() override;
MultiStreamMemBlock(const MultiStreamMemBlock &) = delete;
MultiStreamMemBlock &operator=(const MultiStreamMemBlock &) = delete;
* 调用本接口前,需要保证GertMemBlock对应的内存已经完成了birth recycle
* @param birth_stream
* @param stream_ids_to_allocators
* @param block
* @return
*/
ge::graphStatus ReInit(GertAllocator *birth_allocator, ge::MemBlock *block, BlockAllocType alloc_type);
void Free(int64_t stream_id) override;
void *GetAddr() override {
if (SECUREC_UNLIKELY(mem_block_ == nullptr)) {
return nullptr;
}
return mem_block_->GetAddr();
}
size_t GetSize() const {
if (SECUREC_UNLIKELY(mem_block_ == nullptr)) {
return 0U;
}
return mem_block_->GetSize();
}
void SetSize(const size_t mem_size) {
if (SECUREC_LIKELY(mem_block_ != nullptr)) {
mem_block_->SetSize(mem_size);
}
}
size_t AddCount(int64_t stream_id) {
if (SECUREC_UNLIKELY(!IsStreamIdValid(stream_id))) {
return 0U;
}
return ++stream_ids_to_occupy_count_[stream_id];
}
size_t SubCount(int64_t stream_id) {
if (SECUREC_UNLIKELY(!IsStreamIdValid(stream_id))) {
return 0U;
}
return --stream_ids_to_occupy_count_[stream_id];
}
size_t GetCount(int64_t stream_id) const {
if (SECUREC_UNLIKELY(!IsStreamIdValid(stream_id))) {
return 0U;
}
return stream_ids_to_occupy_count_[stream_id];
}
const ge::MemBlock *GetMemBlock() const {
return mem_block_;
}
ge::MemBlock *GetMemBlock() {
return mem_block_;
}
void SetMemBlock(ge::MemBlock *block) {
mem_block_ = block;
}
int64_t GetBirthStreamId() const {
return birth_allocator_->GetStreamId();
}
int64_t GetVersion() const {
return version_;
}
ge::graphStatus NewAccessStream(int64_t src_stream, int64_t dst_stream);
* 从 src_stream向dst_stream同步回收状态
* @param src_stream
* @param dst_stream
* @return
*/
ge::graphStatus SyncLocalRecycleStatus(int64_t src_stream, int64_t dst_stream);
ge::graphStatus SyncAllLocalRecycleStatus(int64_t dst_stream);
const GertAllocator *GetBirthAllocator() const {
return birth_allocator_;
}
GertAllocator *GetBirthAllocator() {
return birth_allocator_;
}
const BlockAllocType &GetAllocType() const {
return alloc_type_;
}
private:
bool IsStreamIdValid(int64_t stream_id) const {
return static_cast<size_t>(stream_id) < stream_ids_to_occupy_count_.size();
}
int64_t PlusVersion() {
return ++version_;
}
* 调用本函数前,确认`birth_allocator_`已经正确初始化
*/
void ReInitMifs() {
occupy_mif_.ReInit(birth_allocator_->GetStreamNum());
stream_ids_to_occupy_count_.clear();
stream_ids_to_occupy_count_.resize(birth_allocator_->GetStreamNum(), 0U);
}
* 危险函数,确保本block对应的内存已经完成了birth回收,才可以调用该函数
*/
void ClearMemBlock() {
mem_block_ = nullptr;
}
private:
ge::MemBlock *mem_block_;
int64_t version_;
GertAllocator *birth_allocator_;
BlockAllocType alloc_type_;
Mif occupy_mif_;
ge::SmallVector<size_t, kDefaultMaxStreamNum> stream_ids_to_occupy_count_;
friend class MultiStreamMemBlockHelper;
};
}
}
#endif