* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* MindIE is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#include "sequence.h"
#include <algorithm>
#include "log.h"
namespace mindie_llm {
/**
 * Reports whether a status counts as "finished".
 * Depends on the declaration order of SequenceStatus: every enumerator ordered after SWAPPED is finished.
 */
bool sequence_status::IsFinish(const SequenceStatus status) {
    const bool orderedAfterSwapped = SequenceStatus::SWAPPED < status;
    return orderedAfterSwapped;
}
/**
 * Factory that builds a SequenceData from prompt token ids.
 * @param tPromptTokenIds prompt tokens; copied by the SequenceData constructor.
 * @return a new SequenceData holding those prompt tokens.
 */
SequenceData SequenceData::FromSequence(const std::vector<TokenId> &tPromptTokenIds) {
    return SequenceData{tPromptTokenIds};
}
// Builds sequence data from prompt token ids; the ids are copied into promptTokenIds.
// All other members keep their default initialization.
SequenceData::SequenceData(const std::vector<TokenId> &tPromptTokenIds) : promptTokenIds(tPromptTokenIds) {}
void SequenceData::ResetStateForRecompute() {
numComputedTokens_ = 0;
stage_ = SequenceStage::PREFILL;
layerwiseRecompute_ = true;
auto it =
std::find_if(outputTokenIds.rbegin(), outputTokenIds.rend(), [](int val) { return val != PLACEHOLDER_TOKEN; });
outputTokenIds.erase(it.base(), outputTokenIds.end());
}
// Returns how many tokens have been marked computed so far.
// The count is advanced by UpdateNumComputedTokens and reset by ResetStateForRecompute.
size_t SequenceData::GetNumComputedTokens() const {
    const size_t computed = numComputedTokens_;
    return computed;
}
size_t SequenceData::GetNumUncomputedTokens() {
size_t totalLength = GetLength();
size_t numComputedTokens = GetNumComputedTokens();
Assert(totalLength >= numComputedTokens);
return (totalLength - numComputedTokens);
}
/**
 * Records additional computed tokens.
 * @param numNewComputedTokens number of tokens newly computed.
 *
 * When no uncomputed tokens remain, the stage switches to DECODE.
 */
void SequenceData::UpdateNumComputedTokens(size_t numNewComputedTokens) {
    numComputedTokens_ += numNewComputedTokens;
    const bool everythingComputed = GetNumUncomputedTokens() == 0;
    if (!everythingComputed) {
        return;
    }
    stage_ = SequenceStage::DECODE;
}
// Sets the layerwise stage: PREFILL when isPrefill is true, DECODE otherwise.
void SequenceData::SetLayerwiseStage(bool isPrefill) {
    layerwiseStage_ = isPrefill ? SequenceStage::PREFILL : SequenceStage::DECODE;
}
 * During recomputation, the previously processed tokens must be prefilled again.
*/
// Total sequence length: prompt tokens plus output tokens (placeholders included).
size_t SequenceData::GetLength() {
    const size_t promptLen = promptTokenIds.size();
    const size_t outputLen = outputTokenIds.size();
    return promptLen + outputLen;
}
// Creates a sequence with no prompt tokens; data_ is default-constructed.
// Every new sequence starts in the WAITING status.
Sequence::Sequence(SequenceId seqId, int blockSize)
: seqId_(seqId), blockSize_(blockSize), status_(SequenceStatus::WAITING) {}
// Creates a sequence whose prompt tokens are copied from `inputs` (via SequenceData::FromSequence).
// Every new sequence starts in the WAITING status.
Sequence::Sequence(SequenceId seqId, int blockSize, const std::vector<TokenId> &inputs)
: seqId_(seqId),
blockSize_(blockSize),
data_(SequenceData::FromSequence(inputs)),
status_(SequenceStatus::WAITING) {}
// Total number of tokens (prompt + output) held by this sequence.
size_t Sequence::GetLen() {
    const size_t length = data_.GetLength();
    return length;
}
/**
 * Returns the number of output tokens.
 * @param containsPlaceholder when true, PLACEHOLDER_TOKEN entries are counted too;
 *                            when false, only real tokens are counted.
 */
size_t Sequence::GetOutputLen(bool containsPlaceholder) {
    const auto &outputs = data_.outputTokenIds;
    if (containsPlaceholder) {
        return outputs.size();
    }
    // Iteration direction does not affect the count, so scan front to back.
    const auto realTokenCount = std::count_if(outputs.begin(), outputs.end(),
                                              [](const auto &token) { return token != PLACEHOLDER_TOKEN; });
    return static_cast<size_t>(realTokenCount);
}
// True while the sequence data is in the PREFILL stage.
bool Sequence::IsPrefill() const {
    return SequenceStage::PREFILL == data_.stage_;
}
// True when the layerwise stage (see SequenceData::SetLayerwiseStage) is PREFILL.
bool Sequence::IsLayerwisePrefill() const {
    return SequenceStage::PREFILL == data_.layerwiseStage_;
}
bool Sequence::IsFinished() const { return sequence_status::IsFinish(status_); }
// Forwards to the underlying SequenceData computed-token counter.
size_t Sequence::GetNumComputedTokens() const {
    const size_t computed = data_.GetNumComputedTokens();
    return computed;
}
/**
 * Number of tokens still to be computed for this sequence.
 * In the DECODE stage this is always reported as 1. Otherwise the remaining
 * prompt/output tokens tracked by SequenceData are returned.
 */
size_t Sequence::GetNumUncomputedTokens() {
    if (data_.stage_ != SequenceStage::DECODE) {
        return data_.GetNumUncomputedTokens();
    }
    return 1;
}
void Sequence::ResetStateForCompute() { data_.ResetStateForRecompute(); }
const std::vector<TokenId> Sequence::GetTokenIds() const {
std::vector<TokenId> tokenIds = data_.promptTokenIds;
tokenIds.insert(tokenIds.end(), data_.outputTokenIds.begin(), data_.outputTokenIds.end());
return tokenIds;
}
// Returns the extra hash value most recently stored via SetExtraHash.
HashValue Sequence::GetExtraHash() const {
    const HashValue stored = hashValue_;
    return stored;
}
// Stores an extra hash value on the sequence, overwriting any previous one.
void Sequence::SetExtraHash(HashValue hashValue) {
    this->hashValue_ = hashValue;
}
}