* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* MindIE is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#ifndef OBJ_POOL_H
#define OBJ_POOL_H
#include <cstdint>
#include <functional>
#include <limits>
#include <memory>
#include <stack>
#include <stdexcept>
#include <utility>
#include "block_obj.h"
namespace mindie_llm {
/**
 * @brief Growable pool of reusable objects handed out as std::shared_ptr<T>.
 *
 * Objects are produced by a caller-supplied factory. Free objects are kept on a
 * stack, so the most recently freed object is the next one acquired. When the
 * pool runs dry it grows by a fixed multiplier.
 *
 * T must provide a `ResetBlockObj()` member; it is invoked on every object
 * returned through FreeObj(). The class performs no synchronization itself.
 *
 * @tparam T Type of the pooled objects.
 */
template <typename T>
class ObjPool {
public:
    /**
     * @brief Builds the pool and eagerly creates @p poolSize objects.
     *
     * @param poolSize   Number of objects created up front. May be 0, in which
     *                   case the first AcquireObj() creates a single object.
     * @param createFunc Factory invoked once per pooled object.
     */
    explicit ObjPool(uint64_t poolSize, std::function<std::shared_ptr<T>()> createFunc)
        : poolSize_(poolSize), createFunction_(std::move(createFunc))
    {
        for (uint64_t i = 0; i < poolSize_; ++i) {
            objPool_.push(createFunction_());
        }
    }

    ~ObjPool() = default;

    /**
     * @brief Hands out a free object, growing the pool first if none is left.
     *
     * @return The object taken from the top of the free stack.
     * @throws std::overflow_error if growing would overflow the pool size.
     */
    std::shared_ptr<T> AcquireObj()
    {
        if (objPool_.empty()) {
            IncreaseCapacity();
        }
        // Move out of the stack slot; the slot is popped right after, so no
        // reference-count round trip is needed.
        std::shared_ptr<T> blockObj = std::move(objPool_.top());
        objPool_.pop();
        return blockObj;
    }

    /**
     * @brief Returns an object to the pool and clears the caller's handle.
     *
     * The object is reset through `ResetBlockObj()` before it becomes
     * available again.
     *
     * @param blockObj Handle to give back; it is empty on return.
     * @throws std::runtime_error if @p blockObj is null, or if every object of
     *         the pool is already free (nothing can legitimately be returned).
     */
    void FreeObj(std::shared_ptr<T> &blockObj)
    {
        if (blockObj == nullptr) {
            throw std::runtime_error("FreeObj Error: attempt to free a null blockObj.");
        }
        // `>=` rather than `==`: never accept more free objects than the pool owns.
        if (GetFreeObjNum() >= poolSize_) {
            throw std::runtime_error("no ObjPool is used, but caller holds an object to free. must be a bug!");
        }
        blockObj->ResetBlockObj();
        objPool_.push(std::move(blockObj));
        // A moved-from shared_ptr is already empty; reset() states the contract explicitly.
        blockObj.reset();
    }

    /// @return Number of objects currently available in the pool.
    uint64_t GetFreeObjNum() const { return objPool_.size(); }

    /// @return Total number of objects created by the pool so far.
    uint64_t GetPoolSize() const { return poolSize_; }

private:
    /**
     * @brief Grows the pool: multiplies its size by capacityMultiplier_, or
     *        creates a single object when the pool was constructed empty.
     *
     * poolSize_ is advanced one object at a time so it stays consistent with
     * the number of created objects even if the factory throws part-way.
     *
     * @throws std::overflow_error if the multiplied size would not fit in uint64_t.
     */
    void IncreaseCapacity()
    {
        if (poolSize_ > std::numeric_limits<uint64_t>::max() / capacityMultiplier_) {
            throw std::overflow_error("Cannot increase capacity: size would overflow.");
        }
        // 0 * multiplier is still 0, which would leave the pool empty and make
        // the caller read the top of an empty stack; grow to one object instead.
        const uint64_t targetSize = (poolSize_ == 0) ? 1 : poolSize_ * capacityMultiplier_;
        while (poolSize_ < targetSize) {
            objPool_.push(createFunction_());
            ++poolSize_;
        }
    }

    uint64_t poolSize_;                                   // total objects created by this pool
    std::function<std::shared_ptr<T>()> createFunction_;  // factory for new pooled objects
    std::stack<std::shared_ptr<T>> objPool_;              // currently free objects
    static constexpr uint64_t capacityMultiplier_ = 2;    // growth factor when the pool is exhausted
};
// Shared-ownership handle to an ObjPool specialized for BlockObj.
using BlockObjPoolSPtr = std::shared_ptr<ObjPool<BlockObj>>;
}
#endif