* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef AIR_CXX_L2_MEM_POOL_H
#define AIR_CXX_L2_MEM_POOL_H
#include "ge/ge_allocator.h"
#include "kernel/memory/allocator/scalable_allocator.h"
#include "caching_mem_allocator.h"
#include "utils/utils.h"
namespace gert {
namespace memory {
// Forward declaration: MultiStreamL1Allocator (declared next) stores a
// TypedContinuousVector<L2MemPool *> pointer before L2MemPool itself is defined.
class L2MemPool;
/**
 * DeviceMemAllocator implementation that remembers a level-1 ge::Allocator, a stream handle and an
 * optional list of L2MemPool pointers.
 *
 * All three are stored as plain pointers/handles; this class declares no destructor, so nothing is
 * released when an instance goes away. Alloc/Free/GetDeviceId are only declared in this header.
 */
class MultiStreamL1Allocator : public DeviceMemAllocator {
 public:
  /**
   * @param l1_allocator    level-1 allocator to remember (stored as-is, not checked for nullptr here)
   * @param l2_stream       stream handle to remember
   * @param all_l2_mem_pool optional list of L2 pools; defaults to nullptr
   *
   * The multi-thread flag is sampled exactly once, at construction, from IsEnableRmLaunchFreeEdge().
   */
  MultiStreamL1Allocator(ge::Allocator *l1_allocator, aclrtStream l2_stream,
                         TypedContinuousVector<L2MemPool *> *all_l2_mem_pool = nullptr)
      : l1_allocator_(l1_allocator),
        l2_stream_(l2_stream),
        all_l2_mem_pool_(all_l2_mem_pool),
        is_rt2_multi_thread_(IsEnableRmLaunchFreeEdge()) {}

  /// Replaces the remembered level-1 allocator.
  void SetL1Allocator(ge::Allocator *allocator) { l1_allocator_ = allocator; }

  /// Replaces the remembered stream handle.
  void SetStream(aclrtStream stream) { l2_stream_ = stream; }

  // DeviceMemAllocator overrides; defined outside this header.
  BlockAddr Alloc(const MemSize size) override;
  bool Free(ge::MemBlock *const addr) override;
  DeviceId GetDeviceId() const override;

 private:
  ge::Allocator *l1_allocator_;                          // level-1 allocator handed in by the owner
  aclrtStream l2_stream_;                                // stream handle handed in by the owner
  TypedContinuousVector<L2MemPool *> *all_l2_mem_pool_;  // may be nullptr (constructor default)
  bool is_rt2_multi_thread_;                             // result of IsEnableRmLaunchFreeEdge() at construction
};
/**
 * Allocator that derives from ge::Allocator and MemSynchronizer and owns a MultiStreamL1Allocator,
 * a SpanAllocatorImp, a MemoryPool (through std::unique_ptr) and a stream handle.
 *
 * Only SetL1Allocator is defined inline; every other member function is declared here and
 * defined outside this header.
 */
class L2MemPool : public ge::Allocator, public MemSynchronizer {
 public:
  /**
   * @param allocator       level-1 allocator
   * @param stream          stream handle
   * @param all_l2_mem_pool optional list of L2 pools; defaults to nullptr
   * NOTE(review): how the arguments are wired into the members is decided by the out-of-line
   * definition — confirm there.
   */
  explicit L2MemPool(ge::Allocator *allocator, aclrtStream stream,
                     TypedContinuousVector<L2MemPool *> *all_l2_mem_pool = nullptr);
  ~L2MemPool() override;

  // Overrides of base-class virtuals; defined outside this header.
  ge::MemBlock *Malloc(size_t size) override;
  void Free(ge::MemBlock *block) override;
  void Recycle() override;
  ge::Status Synchronize() const override;

  /// @param no_log defaults to false; presumably suppresses logging during finalization — confirm
  ///               against the definition.
  ge::Status Finalize(bool no_log = false);

  /**
   * Forwards the new level-1 allocator to first_level_pool_ and, when it is a CachingMemAllocator
   * that exposes a scalable allocator, logs the ids of both allocators.
   *
   * @param allocator new level-1 allocator; a nullptr simply skips the log because dynamic_cast
   *                  of a null pointer yields nullptr.
   */
  void SetL1Allocator(ge::Allocator *allocator) {
    first_level_pool_.SetL1Allocator(allocator);
    const auto caching_mem_allocator = dynamic_cast<CachingMemAllocator *>(allocator);
    // memory_pool_ is a unique_ptr that may be empty; dereferencing it only to build a log line
    // must not be able to crash the caller, so it is checked together with the other pointers.
    if ((caching_mem_allocator != nullptr) && (caching_mem_allocator->GetScalableAllocator() != nullptr) &&
        (memory_pool_ != nullptr)) {
      GELOGI("set l1 allocator %s for l2 allocator %s",
             caching_mem_allocator->GetScalableAllocator()->GetId().c_str(), memory_pool_->GetId().c_str());
    }
  }

  aclrtStream GetStream();
  void SetStream(aclrtStream stream);

  /// Declared only; judging by the name it hands a block from this pool to the level-1
  /// allocator — TODO confirm against the definition.
  ge::MemBlock *MoveL2ToL1(ge::MemBlock *block);

 private:
  MultiStreamL1Allocator first_level_pool_;  // receives every SetL1Allocator() call
  SpanAllocatorImp span_allocator_;
  std::unique_ptr<MemoryPool> memory_pool_;  // owned; may be empty, see SetL1Allocator()
  aclrtStream stream_;
};
}
}
#endif