* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* MindIE is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#ifndef BLOCK_TRACKER_H
#define BLOCK_TRACKER_H
#include <functional>
#include <limits>
#include <unordered_map>
#include "device_aware_block_allocator.h"
#include "math_utils.h"
#include "sequence.h"
namespace mindie_llm {
/**
 * Hash functor for std::pair keys, suitable as the Hash parameter of unordered containers.
 *
 * The std::hash values of both elements are mixed with a fixed 64-bit constant and two
 * shifted copies of the first element's hash. The formula is not symmetric in the two
 * elements.
 */
struct PairHash {
    /**
     * @param p pair whose element types both have a std::hash specialization.
     * @return the combined hash of p.first and p.second.
     */
    template <typename T1, typename T2>
    size_t operator()(const std::pair<T1, T2> &p) const
    {
        const auto firstHash = std::hash<T1>{}(p.first);
        const auto secondHash = std::hash<T2>{}(p.second);
        // The shift amounts are constants declared outside this header.
        const auto shiftedUp = firstHash << mindie_llm::HASH_SHIFT_LEFT;
        const auto shiftedDown = firstHash >> mindie_llm::HASH_SHIFT_RIGHT;
        return firstHash ^ (secondHash + 0x9e3779b97f4a7c15ULL + shiftedUp + shiftedDown);
    }
};
/**
 * Declaration only; the definition is not part of this header.
 *
 * NOTE(review): judging by the names, this presumably derives a block's hash from the
 * previous block's hash, the block's token ids and an additional hash value — confirm
 * against the definition.
 *
 * @param prevBlockHash hash value passed in for the preceding block.
 * @param tokenIds token ids, taken by non-const reference; whether the callee modifies
 *                 them is not visible from this declaration.
 * @param extraHash additional hash value supplied by the caller.
 * @return the resulting HashValue.
 */
HashValue ComputeHashValueForSeq(HashValue prevBlockHash, std::vector<TokenId> &tokenIds, HashValue extraHash);
class BlockComputedAttr {
struct BlockInfo {
bool active = false;
bool computed = false;
TimeStamp lastAccessed = DEFAULT_LAST_ACCESSED_TIME;
};
public:
BlockComputedAttr(size_t numBlocks, BlockId beginBlockId) : blockInfos_(numBlocks), beginBlockId_(beginBlockId) {}
~BlockComputedAttr() = default;
void Reset(BlockId blockId);
void Enable(BlockId blockId);
void Disable(BlockId blockId);
void SetActive(BlockId blockId, bool active);
bool IsActive(BlockId blockId) const;
void SetComputed(BlockId blockId, bool computed);
bool IsComputed(BlockId blockId) const;
void UpdateAccessTime(BlockId blockId, TimeStamp now);
TimeStamp LastAccessed(BlockId blockId) const;
private:
std::vector<BlockInfo> blockInfos_;
BlockId beginBlockId_;
bool IsValidBlockId(BlockId blockId) const {
if (beginBlockId_ > std::numeric_limits<BlockId>::max() - static_cast<BlockId>(blockInfos_.size())) {
throw std::runtime_error("blockId range overflow!");
}
if (blockId >= beginBlockId_ && blockId < beginBlockId_ + static_cast<BlockId>(blockInfos_.size())) {
return true;
} else {
throw std::runtime_error("blockId is error!");
}
}
};
/**
 * Keeps a token count per (sequence id, index) key in seqIdToNumComputedTokens_.
 *
 * All member functions except the constructors are declared here and defined elsewhere.
 * NOTE(review): the size_t half of the map key is presumably the rank index passed to
 * GetCachedTokensNum — confirm in the implementation file.
 *
 * A default-constructed tracker is in a fully defined state: allocator_ is
 * default-constructed, blockSize_ is 0, enableCaching_ is false and rankSize_ is 1.
 */
class SeqsBlocksComputedTracker {
public:
    SeqsBlocksComputedTracker() = default;

    /**
     * @param allocator block allocator handle stored by this tracker.
     * @param blockSize value stored in blockSize_.
     * @param enableCaching value stored in enableCaching_.
     * @param rankSize value stored in rankSize_.
     */
    SeqsBlocksComputedTracker(DeviceAwareBlockAllocatorSPtr allocator, size_t blockSize, bool enableCaching,
                              size_t rankSize)
        : allocator_(allocator), blockSize_(blockSize), enableCaching_(enableCaching), rankSize_(rankSize) {}

    ~SeqsBlocksComputedTracker() = default;

    // Defined out of line. blockHashes is a non-const reference; whether it is read,
    // written or both is not visible from this declaration.
    size_t GetCachedTokensNum(const SequenceSPtr &seq, size_t rankIdx, std::vector<HashValue> &blockHashes,
                              bool seqPrefillFlag);

    // Defined out of line; overload taking only the sequence.
    size_t GetCachedTokensNum(const SequenceSPtr &seq);

    // Defined out of line.
    void RemoveSeq(SequenceId seqId);

private:
    DeviceAwareBlockAllocatorSPtr allocator_;
    // Default member initializers keep a default-constructed tracker from holding
    // indeterminate values; the parameterized constructor overrides all of them.
    size_t blockSize_{0};
    bool enableCaching_{false};
    size_t rankSize_{1};
    std::unordered_map<std::pair<SequenceId, size_t>, size_t, PairHash> seqIdToNumComputedTokens_ = {};
};
/**
 * Holds one TimeStamp per sequence id in seqIdToLastAccessTime_, together with a block
 * allocator handle.
 *
 * All member functions except the constructors are declared here and defined elsewhere.
 */
class SeqsLastAccessBlocksTracker {
public:
    SeqsLastAccessBlocksTracker() = default;

    /**
     * @param allocator block allocator handle stored by this tracker.
     */
    explicit SeqsLastAccessBlocksTracker(DeviceAwareBlockAllocatorSPtr allocator) : allocator_(allocator) {}

    ~SeqsLastAccessBlocksTracker() = default;

    // Sequence registration and removal; defined out of line.
    void AddSeq(SequenceId seqId);

    void RemoveSeq(SequenceId seqId);

    // Defined out of line.
    // NOTE(review): presumably records `time` as the sequence's last access time — confirm.
    void UpdateSeqLastAccess(SequenceId seqId, TimeStamp time);

    // Defined out of line. rankedBlockIds is a non-const reference to a vector of block-id
    // vectors.
    // NOTE(review): the outer index is presumably the rank — confirm.
    void UpdateSeqBlocksLastAccess(SequenceId seqId, std::vector<std::vector<BlockId>> &rankedBlockIds);

private:
    DeviceAwareBlockAllocatorSPtr allocator_;
    std::unordered_map<SequenceId, TimeStamp> seqIdToLastAccessTime_{};
};
}
#endif