* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* MindIE is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#include "model_exec_output_handler.h"
#include <chrono>
#include <string>
#include "error_queue.h"
#include "live_infer_context.h"
#include "log.h"
#include "msServiceProfiler/msServiceProfiler.h"
#include "policy/dynamic_batch_recorder.h"
#include "policy/stage_policy/edge_cloud_policy.h"
#include "policy/stage_policy/stage_policy.h"
using namespace mindie_llm;
using namespace model_execute_data;
// Process-wide running total of decoded tokens. Entry4Executor adds every forwarded response's
// speculativeTokenNum to it and prints the total in the per-response token log line.
std::atomic<int> g_decodeTokenCount{0};
/**
 * Constructs the handler that converts model-execution batch results into engine responses.
 *
 * @param cb               Callback that forwards a Response to the manager. It is stored directly
 *                         and also handed to the buffered responder.
 * @param pdRole           Role of this instance; compared against Role::P when deciding whether
 *                         to publish prefilled KV-cache responses.
 * @param config           Scheduler configuration. The buffer-response settings
 *                         (bufferResponseEnabled, prefillExpectedTime, decodeExpectedTime) and
 *                         dpRankId_ are read from it here; the shared pointer is kept for later use.
 * @param latencypredictor Predictor fed with per-request samples in
 *                         ConvertSequenceGroupOutputToResponse when dynamicBatchSizeEnable is set.
 * @param localDPRank      Local data-parallel rank; selects which LiveInferContext instance is used.
 *
 * NOTE(review): bufferedResponser_ is built from bufferResponseConfig_. This is only correct if
 * bufferResponseConfig_ is declared before bufferedResponser_ in the class definition, because
 * members are initialized in declaration order, not initializer-list order. Confirm against the header.
 */
ModelExecOutputHandler::ModelExecOutputHandler(ForwardRespToManagerCall cb, Role pdRole, SchedulerConfigSPtr &config,
std::shared_ptr<LatencyPredictor> latencypredictor, size_t localDPRank)
: role_(pdRole),
forwardRespToManagerCall_(cb),
schedulerConfig_(config),
bufferResponseConfig_({config->bufferResponseEnabled, config->prefillExpectedTime, config->decodeExpectedTime}),
latencypredictor_(latencypredictor),
localDPRank_(localDPRank),
bufferedResponser_(cb, bufferResponseConfig_),
dpRankId_(config->dpRankId_) {}
/**
 * Emits a PUBLISH_KV_COMPLETE response for every sequence in the batch whose first sample reports
 * ITERATION_CONTINUE, i.e. whose prefill has just produced a token and which keeps running.
 *
 * Each response carries the sequence group's block table (srcBlockTable), its current thinking
 * flag, and this handler's local DP rank. The group's isThinking_ flag is then cleared.
 *
 * Sequences that are still in a non-final chunk of a chunked prefill are skipped. Outputs whose
 * sequence group can no longer be found are logged and skipped.
 *
 * @param modelBatchResult Batch result from the executor; must not be null.
 * @throws std::runtime_error if any output carries no samples.
 */
void ModelExecOutputHandler::AsyncPublishPrefilledKvCache(ModelBatchResultSPtr &modelBatchResult) {
    // The context instance depends only on localDPRank_, so fetch it once for the whole batch.
    LiveInferContextSPtr liveInferContext = LiveInferContext::GetInstance(localDPRank_);
    for (const model_execute_data::CompletionSequenceGroupOutput &output : modelBatchResult->outputs()) {
        if (output.samples_size() == 0) {
            throw std::runtime_error("There is no sample in output.");
        }
        // Bind by reference: copying the protobuf messages on every iteration was pure overhead.
        const model_execute_data::SequenceOutput &firstSample = output.samples(0);
        SequenceGroupSPtr seqGroup = liveInferContext->GetSeqGroup(firstSample.seq_id());
        // Only the last chunk of a chunked prefill completes the prefill.
        if (schedulerConfig_->enableChunkedPrefill && seqGroup != nullptr && !seqGroup->isLastChunk_) {
            continue;
        }
        if (firstSample.finish_reason() != static_cast<int64_t>(InferStatusType::ITERATION_CONTINUE)) {
            continue;
        }
        if (seqGroup == nullptr) {
            MINDIE_LLM_LOG_INFO("Can not find sequence group, seqId=" << firstSample.seq_id());
            continue;
        }
        ResponseSPtr response = std::make_shared<Response>(seqGroup->metrics_.inferReqId_);
        response->transferStatusFlag = TransferStatusType::PUBLISH_KV_COMPLETE;
        response->responseContents.resize(1);
        response->responseContents[0].srcBlockTable = seqGroup->pBlockTable;
        response->responseContents[0].isThinking = seqGroup->isThinking_;
        response->responseContents[0].singleLLMPrefillReqHandlerId = localDPRank_;
        MINDIE_LLM_LOG_INFO_REQUEST("[LlmEngine|Request-Publish Complete] DP RankId: "
                                    << dpRankId_ << ". Request Prefill Complete, requestId: "
                                    << seqGroup->metrics_.inferReqId_ << ", seqId: " << firstSample.seq_id()
                                    << ", pInstanceId:" << seqGroup->pInstanceId
                                    << ", localDPRank_:" << localDPRank_);
        seqGroup->isThinking_ = false;
        forwardRespToManagerCall_(response);
    }
}
/**
 * Entry point for a finished model-execution batch.
 *
 * For every output in the batch it:
 *   - converts the output into a Response;
 *   - records the generated tokens and finish status (greedy path for one sample, parallel-sampling
 *     path for more than one);
 *   - drops the token of a non-final chunked-prefill chunk;
 *   - feeds the layerwise-disaggregation mixin.
 *
 * After the loop it:
 *   - refreshes the latency predictor's batch statistics under the LATENCY_FIRST stage policy;
 *   - decrements asyncBatchNum_;
 *   - stamps batch-level metrics onto the collected responses and forwards them, through the
 *     buffered responder when buffering is enabled;
 *   - marks the stage-policy inference end time;
 *   - on Role::P / FlexP instances, publishes prefilled-KV-cache responses.
 *
 * @param modelBatchResult Batch result from the executor.
 * @throws std::runtime_error if modelBatchResult is null or any output has no samples. Helpers
 *         called from here may throw as well.
 *
 * NOTE(review): asyncBatchNum_ is only decremented on the normal paths. If anything throws before
 * the decrement, the counter stays incremented. Confirm that callers handle this.
 */
void ModelExecOutputHandler::Entry4Executor(ModelBatchResultSPtr &modelBatchResult) {
if (modelBatchResult == nullptr) {
throw std::runtime_error("modelBatchResult is nullptr.");
}
// Empty batch: nothing to convert; only decrement asyncBatchNum_.
if (modelBatchResult->outputs_size() == 0) {
asyncBatchNum_.fetch_sub(1);
return;
}
LiveInferContextSPtr liveInferContext = LiveInferContext::GetInstance(localDPRank_);
// Parallel arrays: index i of each vector describes responsesToCallback[i].
std::vector<ResponseSPtr> responsesToCallback;
std::vector<uint64_t> queueWaitTimes;
std::vector<uint64_t> prefixCachedTokenNums;
ForwardMode lastForwardMode = ForwardMode::DUMMY;
bool layerwiseNeedUpdate = false;
ForwardMode lwdCurrBatchType = ForwardMode::DUMMY;
if (schedulerConfig_->layerwiseDisaggregated) {
// Prefill maps to ForwardMode value 0, otherwise value 1.
// NOTE(review): this relies on the numeric layout of ForwardMode; confirm.
lwdCurrBatchType = static_cast<ForwardMode>(!modelBatchResult->layerwise_is_prefill());
}
std::deque<SequenceGroupSPtr> recomputeInDBatchQueue;
for (const CompletionSequenceGroupOutput &output : modelBatchResult->outputs()) {
uint64_t queueWaitTime = 0;
uint64_t currentPrefixCachedTokenNums = 0;
auto spanConvert = PROF(INFO, Domain("Engine").SpanStart("ConvertOutputToResponse"));
if (output.samples_size() == 0) {
throw std::runtime_error("There is no sample in output.");
}
// May return nullptr (unknown sequence group or no reportable tokens); token bookkeeping below still runs.
ResponseSPtr response =
ConvertSequenceGroupOutputToResponse(output, queueWaitTime, currentPrefixCachedTokenNums);
PROF(spanConvert.SpanEnd());
if (output.samples_size() > 1) {
HandleParallelSampling(output, liveInferContext);
} else if (output.samples_size() == 1) {
HandleGreedySampling(output.samples(0), response);
}
SequenceGroupSPtr seqGroup = liveInferContext->GetSeqGroup(output.samples(0).seq_id());
// Tokens produced by non-final chunks of a chunked prefill are not reported to the caller.
bool discardChunkedPrefillReqToken =
(seqGroup != nullptr) && ((schedulerConfig_->enableChunkedPrefill) && (!seqGroup->isLastChunk_));
if (response != nullptr) {
if (discardChunkedPrefillReqToken) {
MINDIE_LLM_LOG_DEBUG_REQUEST("The output token of the chunked prefill request need to be discard.");
} else {
queueWaitTimes.push_back(queueWaitTime);
responsesToCallback.push_back(response);
for (size_t i = 0; i < response->responseContents.size(); i++) {
g_decodeTokenCount += response->responseContents[i].speculativeTokenNum;
}
MINDIE_LLM_LOG_INFO_TOKEN("[LlmEngine|Request-Response] DP RankId: "
<< dpRankId_ << ". Response generated, requestId: " << response->reqId
<< ", batchsize: " << modelBatchResult->outputs_size()
<< ", total decoded tokens: " << g_decodeTokenCount);
prefixCachedTokenNums.push_back(currentPrefixCachedTokenNums);
}
}
// NOTE(review): layerwiseNeedUpdate is overwritten on every iteration, so only the last output's
// result reaches LwdProcessRecomputeSeq. Confirm this is intended.
layerwiseNeedUpdate =
layerwiseMixin_.LwdProcessResponse(schedulerConfig_->layerwiseDisaggregated, seqGroup, lastForwardMode,
lwdCurrBatchType, recomputeInDBatchQueue);
}
layerwiseMixin_.LwdProcessRecomputeSeq(layerwiseNeedUpdate, lastForwardMode, recomputeInDBatchQueue);
layerwiseMixin_.LwdHandlerSubBatchCnt(schedulerConfig_->layerwiseDisaggregated, stagePolicy_, lwdCurrBatchType);
// Under the latency-first stage policy, let the recorder's predictor refresh its batch statistics.
if (schedulerConfig_->stageSelectPolicy == static_cast<uint32_t>(StagePolicyType::LATENCY_FIRST)) {
auto &recorder = DynamicBatchRecorder::GetInstance(localDPRank_);
auto predictor = recorder.GetLatencyPredictor();
if (predictor != nullptr) {
predictor->UpdateBatchStats();
}
}
asyncBatchNum_.fetch_sub(1);
// batchSize is the number of responses actually forwarded, not outputs_size().
for (size_t i = 0; i < responsesToCallback.size(); i++) {
ResponseSPtr response = responsesToCallback[i];
response->metrics.batchSize = responsesToCallback.size();
response->metrics.queueWaitTime = queueWaitTimes.at(i);
response->metrics.prefixCachedTokenNum = prefixCachedTokenNums.at(i);
if (bufferResponseConfig_.bufferResponseEnabled) {
bufferedResponser_.TryRespond(response);
} else {
forwardRespToManagerCall_(response);
}
}
// The first output's sequence group stands in for the whole batch in the checks below.
SequenceGroupSPtr seqGroup = liveInferContext->GetSeqGroup(modelBatchResult->outputs().at(0).samples(0).seq_id());
if (seqGroup != nullptr && stagePolicy_ != nullptr) {
stagePolicy_->MarkInferenceEndTimeStamp();
}
if (role_ == Role::P ||
(seqGroup != nullptr && liveInferContext->GetInferInstanceFlexRole4Req(seqGroup->requestId) == Role::FlexP)) {
AsyncPublishPrefilledKvCache(modelBatchResult);
}
}
/**
 * Finds the root sequence group that a (parallel-sampling) output belongs to.
 *
 * Each sample's parent_seq_id() is looked up in the context's seq-root map, in sample order, and
 * the first registered group is returned.
 *
 * @param output           Output whose samples are inspected.
 * @param liveInferContext Context holding the seq-root map.
 * @return The root SequenceGroup, or nullptr when no sample's parent is registered.
 */
SequenceGroupSPtr ModelExecOutputHandler::FindRootSequenceGroup(const CompletionSequenceGroupOutput &output,
                                                                LiveInferContextSPtr &liveInferContext) const {
    for (const auto &candidate : output.samples()) {
        SequenceGroupSPtr root = liveInferContext->GetSeqGroupFromSeqRootMap(candidate.parent_seq_id());
        if (root != nullptr) {
            return root;
        }
    }
    return nullptr;
}
/**
 * Routes a sequence that has stopped running into the matching completion queue.
 *
 * ITERATION_CONTINUE is ignored. END_OF_SENTENCE goes to finishedSeqIds_. Any other finish
 * reason goes to execExceptionSeqIds_.
 *
 * @param seqId        Sequence that produced the status.
 * @param finishReason Raw InferStatusType value reported by the executor.
 */
void ModelExecOutputHandler::ProcessSequenceStatus(SequenceId seqId, int64_t finishReason) {
    const bool stillRunning = finishReason == static_cast<int64_t>(InferStatusType::ITERATION_CONTINUE);
    if (!stillRunning) {
        MINDIE_LLM_LOG_INFO_REQUEST("[LlmEngine|Request-End] DP RankId: " << dpRankId_ << ". Sequence finished. seqId: "
                                    << seqId << "; finishReason: " << finishReason);
        const bool endedNormally = finishReason == static_cast<int64_t>(InferStatusType::END_OF_SENTENCE);
        ConcurrentDeque<SequenceId> &target = endedNormally ? finishedSeqIds_ : execExceptionSeqIds_;
        target.PushBack(seqId);
    }
}
/**
 * Advances the thinking-budget state machine of a sequence group by one output token.
 *
 * Only active when early-stopping ids are configured. Behavior:
 *   - the start-thinking token switches thinking on;
 *   - while thinking, every token (including the stop token itself) is counted;
 *   - the stop-thinking token switches thinking off;
 *   - if the count reaches thinkingBudget_ while still thinking, exceededThinkingbudget_ is set.
 *
 * @param seqGrp      Group whose thinking state is updated; must not be null.
 * @param outputToken Token just produced for the group.
 */
void ModelExecOutputHandler::UpdateThinkingStatus(SequenceGroupSPtr &seqGrp, int64_t outputToken) {
    if (schedulerConfig_->earlyStoppingIds.empty()) {
        return;
    }
    if (outputToken == schedulerConfig_->startThinkingId) {
        seqGrp->isThinking_ = true;
    }
    // Outside a thinking section there is nothing to count or to stop.
    if (!seqGrp->isThinking_) {
        return;
    }
    seqGrp->thinkingTokens++;
    if (outputToken == schedulerConfig_->stopThinkingId) {
        seqGrp->isThinking_ = false;
        return;
    }
    if (seqGrp->thinkingTokens >= seqGrp->thinkingBudget_) {
        seqGrp->exceededThinkingbudget_ = true;
    }
}
/**
 * Appends the configured early-stopping token ids to the first content of a response. The caller
 * uses this once a sequence has exceeded its thinking budget.
 *
 * Each appended token:
 *   - gets a log-prob of 0;
 *   - advances speculativeTokenNum, so it keeps matching the number of appended tokens;
 *   - when the request asked for top-log-probs (topLogProbs_ > 0), is also recorded topLogProbs_
 *     times in the top-k arrays with log-prob 0, keeping those arrays aligned.
 *
 * Does nothing when the response is null or has no contents.
 *
 * @param seqGrp   Group the response belongs to; must not be null.
 * @param response Response to extend in place.
 */
void ModelExecOutputHandler::UpdateResponse(SequenceGroupSPtr &seqGrp, ResponseSPtr &response) {
    if (response == nullptr || response->responseContents.empty()) {
        return;
    }
    auto &content = response->responseContents[0];
    // The scheduler configuration is shared and only read here, so bind a const reference.
    const std::vector<TokenId> &stopIds = schedulerConfig_->earlyStoppingIds;
    content.outTokenIds.insert(content.outTokenIds.end(), stopIds.begin(), stopIds.end());
    content.speculativeTokenNum += stopIds.size();
    content.outLogProbs.insert(content.outLogProbs.end(), stopIds.size(), 0.0F);
    if (seqGrp->topLogProbs_ > 0) {
        for (const TokenId token : stopIds) {
            content.topLogProbTokenIds.insert(content.topLogProbTokenIds.end(), seqGrp->topLogProbs_, token);
        }
        content.topLogProbs.insert(content.topLogProbs.end(), seqGrp->topLogProbs_ * stopIds.size(), 0.0F);
    }
}
/**
 * Records the tokens of a single-sample (greedy) output.
 *
 * Steps:
 *   1. Every non-placeholder token is pushed onto seqIdToOutputTokenQueue_.
 *   2. When speculation is enabled (speculationGamma > 0), only the first num_speculative_tokens()
 *      tokens are considered.
 *   3. Groups with thinking enabled and a positive budget have their thinking state advanced. If
 *      the budget is exceeded, the early-stopping ids are appended to the response.
 *   4. The sample's finish status is routed last.
 *
 * @param sample   The single sample of the output.
 * @param response Response built for this output; may be extended in place.
 */
void ModelExecOutputHandler::HandleGreedySampling(const model_execute_data::SequenceOutput &sample,
                                                  ResponseSPtr &response) {
    auto spanGreedySampling = PROF(INFO, Domain("Engine").SpanStart("HandleGreedySampling"));
    SequenceGroupSPtr seqGrp = LiveInferContext::GetInstance(localDPRank_)->GetSeqGroup(sample.seq_id());
    int64_t consumed = 0;
    for (const int64_t token : sample.output_token()) {
        const bool speculativeLimitReached =
            schedulerConfig_->speculationGamma > 0 && consumed >= sample.num_speculative_tokens();
        if (speculativeLimitReached) {
            break;
        }
        ++consumed;
        if (token == PLACEHOLDER_TOKEN) {
            if (schedulerConfig_->layerwiseDisaggregated) {
                MINDIE_LLM_LOG_INFO("[layerwiseDisaggregated|handler] " << "seq id is " << sample.seq_id()
                                    << ", output_token is -1");
            }
            continue;
        }
        seqIdToOutputTokenQueue_.PushBack(std::pair<SequenceId, TokenId>{sample.seq_id(), token});
        const bool tracksThinking = seqGrp != nullptr && seqGrp->enableThinking_ && seqGrp->thinkingBudget_ > 0;
        if (tracksThinking) {
            UpdateThinkingStatus(seqGrp, token);
        }
    }
    if (seqGrp != nullptr && seqGrp->exceededThinkingbudget_) {
        UpdateResponse(seqGrp, response);
    }
    ProcessSequenceStatus(sample.seq_id(), sample.finish_reason());
    PROF(spanGreedySampling.SpanEnd());
}
/**
 * Records the tokens and statuses of a multi-sample (parallel-sampling) output.
 *
 * If no root group can be found, every sample's seq id is pushed onto the exception queue.
 * Otherwise:
 *   - Pass 1 ensures each non-EOS sample has a group under the root: a new one is created, or an
 *     existing forked branch is overwritten with its parent's tokens.
 *   - Pass 2 queues the output tokens (for groups that were not just created) and routes each
 *     sample's finish status.
 *   - Branches the root still tracks but that are absent from this output are pushed onto the
 *     exception queue.
 *
 * @param output           Output with more than one sample.
 * @param liveInferContext Context used to resolve the root group.
 */
void ModelExecOutputHandler::HandleParallelSampling(const model_execute_data::CompletionSequenceGroupOutput &output,
                                                    LiveInferContextSPtr &liveInferContext) {
    auto spanParallelSampling = PROF(INFO, Domain("Engine").SpanStart("HandleParallelSampling"));
    SequenceGroupSPtr rootSeqGrp = FindRootSequenceGroup(output, liveInferContext);
    if (rootSeqGrp == nullptr) {
        for (const model_execute_data::SequenceOutput &sample : output.samples()) {
            execExceptionSeqIds_.PushBack(sample.seq_id());
        }
        // Close the profiling span on this early-exit path as well; previously it was left open.
        PROF(spanParallelSampling.SpanEnd());
        return;
    }
    // Pass 1: create or refresh the group of every live sample before any tokens are queued.
    for (const model_execute_data::SequenceOutput &sample : output.samples()) {
        SequenceId seqId = sample.seq_id();
        if (seqId == EOS_SEQUENCE_ID) {
            continue;
        }
        if (rootSeqGrp->seqId2ParallelSeqGroup_.Count(seqId) == 0) {
            CreateNewSequenceGroup(sample, rootSeqGrp, liveInferContext);
        } else if (seqId != sample.parent_seq_id()) {
            UpdateSequenceGroup(sample, rootSeqGrp);
        }
    }
    // Pass 2: queue tokens and route finish statuses.
    std::unordered_set<SequenceId> outputSeqIds;
    for (const model_execute_data::SequenceOutput &sample : output.samples()) {
        SequenceId seqId = sample.seq_id();
        if (seqId == EOS_SEQUENCE_ID) {
            continue;
        }
        SequenceGroupSPtr seqGrp = rootSeqGrp->seqId2ParallelSeqGroup_.Get(seqId).value();
        // Freshly created groups already hold these tokens (CreateNewSequenceGroup copied them in).
        if (!seqGrp->isNewSeqGroup_) {
            for (TokenId outputToken : sample.output_token()) {
                seqIdToOutputTokenQueue_.PushBack(std::pair<SequenceId, TokenId>{seqId, outputToken});
            }
        }
        ProcessSequenceStatus(sample.seq_id(), sample.finish_reason());
        outputSeqIds.insert(seqId);
    }
    // Branches that vanished from this output are reported as exceptions.
    std::vector<SequenceId> parallelSeqIds = rootSeqGrp->seqId2ParallelSeqGroup_.KeySet();
    for (const auto &seqId : parallelSeqIds) {
        if (outputSeqIds.count(seqId) == 0) {
            execExceptionSeqIds_.PushBack(seqId);
        }
    }
    PROF(spanParallelSampling.SpanEnd());
}
/**
 * Creates a sequence group for a sample that forked from an existing parallel-sampling branch.
 *
 * Construction:
 *   - The new Sequence reuses the parent's prompt token ids and block size.
 *   - It starts with this sample's output tokens.
 *   - The group is flagged isNewSeqGroup_ / needUpdate_ and records its parent seq id.
 *   - It is registered under rootSeqGrp and in the context's seq-root map.
 *   - Its computed-token count is set to the new sequence's length.
 *
 * @param sample           Sample describing the new branch.
 * @param rootSeqGrp       Root group of the request; must not be null.
 * @param liveInferContext Context whose seq-root map receives the new seq id.
 * @throws std::runtime_error if the parent branch is not registered under the root or is null.
 */
void ModelExecOutputHandler::CreateNewSequenceGroup(const model_execute_data::SequenceOutput &sample,
                                                    SequenceGroupSPtr &rootSeqGrp,
                                                    LiveInferContextSPtr &liveInferContext) const {
    // Look the parent up once. The previous check-then-fetch did two lookups, and the entry could
    // disappear in between if the map is modified concurrently.
    auto parentEntry = rootSeqGrp->seqId2ParallelSeqGroup_.Get(sample.parent_seq_id());
    if (!parentEntry) {
        MINDIE_LLM_LOG_ERROR("Can not find sequence group for parent seq id=" << sample.parent_seq_id());
        throw std::runtime_error("Can not find sequence group for parent seq id");
    }
    SequenceGroupSPtr parentSeqGrp = parentEntry.value();
    // Same guard as UpdateSequenceGroup: fail loudly instead of dereferencing a null parent.
    if (parentSeqGrp == nullptr) {
        throw std::runtime_error("parentSeqGrp is null.");
    }
    std::vector<TokenId> promptTokenIds = parentSeqGrp->firstSeq->data_.promptTokenIds;
    SequenceSPtr newSeq =
        std::make_shared<Sequence>(sample.seq_id(), parentSeqGrp->firstSeq->blockSize_, promptTokenIds);
    std::vector<TokenId> &outputTokenIds = newSeq->data_.outputTokenIds;
    std::copy(sample.output_token().begin(), sample.output_token().end(), std::back_inserter(outputTokenIds));
    SequenceGroupSPtr newSeqGrp =
        std::make_shared<SequenceGroup>(rootSeqGrp->requestId, std::vector<SequenceSPtr>{newSeq});
    newSeqGrp->isNewSeqGroup_ = true;
    newSeqGrp->needUpdate_ = true;
    newSeqGrp->parentSeqId_ = sample.parent_seq_id();
    rootSeqGrp->seqId2ParallelSeqGroup_.Insert(sample.seq_id(), newSeqGrp);
    liveInferContext->AddIntoSeqRootMap(sample.seq_id(), rootSeqGrp);
    newSeqGrp->UpdateNumComputedTokens(newSeqGrp->firstSeq->GetLen());
}
/**
 * Overwrites the token history of an existing parallel-sampling branch with its parent's.
 *
 * The branch's first sequence receives copies of the parent's prompt and output token ids. The
 * group is flagged needUpdate_ and records the new parent seq id.
 *
 * @param sample     Sample whose seq_id() names the branch and parent_seq_id() its parent.
 * @param rootSeqGrp Root group of the request.
 * @throws std::runtime_error if the root is null, or if either the parent or the branch is
 *         missing or null.
 */
void ModelExecOutputHandler::UpdateSequenceGroup(const model_execute_data::SequenceOutput &sample,
                                                 SequenceGroupSPtr &rootSeqGrp) const {
    if (rootSeqGrp == nullptr) {
        throw std::runtime_error("rootSeqGrp is null.");
    }
    // One lookup per key: the previous check-then-fetch did two lookups, and the entry could
    // disappear in between if the map is modified concurrently.
    auto parentEntry = rootSeqGrp->seqId2ParallelSeqGroup_.Get(sample.parent_seq_id());
    if (!parentEntry) {
        MINDIE_LLM_LOG_ERROR("Can not find sequence group for parent seq id=" << sample.parent_seq_id());
        throw std::runtime_error("Can not find sequence group for parent seq id");
    }
    SequenceGroupSPtr parentSeqGrp = parentEntry.value();
    if (parentSeqGrp == nullptr) {
        throw std::runtime_error("parentSeqGrp is null.");
    }
    auto seqEntry = rootSeqGrp->seqId2ParallelSeqGroup_.Get(sample.seq_id());
    if (!seqEntry) {
        MINDIE_LLM_LOG_ERROR("Can not find sequence group for seq id=" << sample.seq_id());
        throw std::runtime_error("Can not find sequence group for seq id");
    }
    SequenceGroupSPtr seqGrp = seqEntry.value();
    if (seqGrp == nullptr) {
        throw std::runtime_error("seqGrp is null.");
    }
    seqGrp->needUpdate_ = true;
    seqGrp->firstSeq->data_.promptTokenIds = parentSeqGrp->firstSeq->data_.promptTokenIds;
    seqGrp->firstSeq->data_.outputTokenIds = parentSeqGrp->firstSeq->data_.outputTokenIds;
    seqGrp->parentSeqId_ = sample.parent_seq_id();
}
/**
 * Appends one ResponseContent per sample of the output to the response.
 *
 * Only the first num_speculative_tokens() entries of output_token() are reported. Samples whose
 * reported tokens are all PLACEHOLDER_TOKEN, or that report no tokens at all, are skipped.
 * numParallelTokens is taken from the first sample, so the output must be non-empty.
 *
 * @param response Response to extend.
 * @param output   Output whose samples are converted.
 */
void ModelExecOutputHandler::AddOutputsToResponse(
    ResponseSPtr response, const model_execute_data::CompletionSequenceGroupOutput &output) const {
    response->numParallelTokens = static_cast<size_t>(output.samples(0).num_parallel_tokens());
    for (const model_execute_data::SequenceOutput &sample : output.samples()) {
        // Clamp the reported count to the slots actually present. A malformed count previously
        // indexed past the end of output_token() (both in the scan and in the copy below).
        const int availableTokens = sample.output_token_size();
        int tokenNum = sample.num_speculative_tokens();
        if (tokenNum < 0) {
            tokenNum = 0;
        } else if (tokenNum > availableTokens) {
            tokenNum = availableTokens;
        }
        // Count trailing placeholders; if every reported slot is a placeholder, skip the sample.
        // NOTE(review): placeholders are only used for this skip test and are not trimmed from
        // outTokenIds. Confirm that is intended.
        int trailingPlaceholderNum = 0;
        while (trailingPlaceholderNum < tokenNum &&
               sample.output_token(tokenNum - 1 - trailingPlaceholderNum) == PLACEHOLDER_TOKEN) {
            trailingPlaceholderNum++;
        }
        if (trailingPlaceholderNum == tokenNum) {
            continue;
        }
        response->responseContents.emplace_back(ResponseContent{
            .seqId = sample.seq_id(),
            .parentSeqId = sample.parent_seq_id(),
            .finishReason = static_cast<InferStatusType>(sample.finish_reason()),
            // Kept equal to outTokenIds.size() (UpdateResponse relies on advancing both together).
            .speculativeTokenNum = static_cast<size_t>(tokenNum),
            .outTokenIds =
                std::vector<TokenId>(sample.output_token().begin(), sample.output_token().begin() + tokenNum),
            .outLogProbs = std::vector<float>(sample.logprob().begin(), sample.logprob().end()),
            .cumLogProb = sample.cumulative_logprobs(),
            .truncationIndex = sample.truncation_index(),
            .topLogProbTokenIds = std::vector<TokenId>(sample.top_token_ids().begin(), sample.top_token_ids().end()),
            .topLogProbs = std::vector<float>(sample.top_logprobs().begin(), sample.top_logprobs().end()),
            .srcBlockTable = {},
            .singleLLMPrefillReqHandlerId = 0,
            .pdErrorCode = 0,
            .isThinking = false});
    }
}
/**
 * Builds the Response for one output of the batch and updates the owning group's bookkeeping.
 *
 * The owning group is found first through the seq-root map (populated for parallel-sampling
 * branches), falling back to a direct lookup of the first sample's seq id. Then:
 *   - iterTimes is advanced and responseTime_ is stamped;
 *   - the arrive time is recorded when response buffering is enabled;
 *   - the latency predictor is fed when dynamic batch sizing is enabled.
 *
 * @param output                       Output to convert; must contain at least one sample.
 * @param queueWaitTime                Out: the group's recorded queue wait time.
 * @param currentPrefixCachedTokenNums Out: the group's prefix-cached token count.
 * @return The response, or nullptr when no group is found or no sample has reportable tokens.
 *
 * NOTE(review): the group's counters are updated even when nullptr is returned for lack of
 * reportable tokens. Confirm this is intended.
 */
ResponseSPtr ModelExecOutputHandler::ConvertSequenceGroupOutputToResponse(
    const model_execute_data::CompletionSequenceGroupOutput &output, uint64_t &queueWaitTime,
    uint64_t &currentPrefixCachedTokenNums) {
    LiveInferContextSPtr liveInferContext = LiveInferContext::GetInstance(localDPRank_);
    SequenceGroupSPtr seqGroup = FindRootSequenceGroup(output, liveInferContext);
    if (seqGroup == nullptr) {
        seqGroup = liveInferContext->GetSeqGroup(output.samples(0).seq_id());
    }
    if (seqGroup == nullptr) {
        MINDIE_LLM_LOG_DEBUG_REQUEST("Can not find sequence group.");
        return nullptr;
    }
    ResponseSPtr response = std::make_shared<Response>(seqGroup->metrics_.inferReqId_);
    SetResponseFlags(output, response);
    response->iterTimes = seqGroup->iterTimes;
    seqGroup->iterTimes++;
    seqGroup->metrics_.responseTime_ = std::chrono::high_resolution_clock::now();
    queueWaitTime = seqGroup->metrics_.queueWaitTime_;
    currentPrefixCachedTokenNums = seqGroup->metrics_.prefixCachedTokenNum_;
    if (bufferResponseConfig_.bufferResponseEnabled) {
        bufferedResponser_.RecordArriveTime(seqGroup->metrics_.inferReqId_, seqGroup->arriveTime);
    }
    // Guard the predictor like Entry4Executor guards the recorder's predictor, so a missing
    // predictor does not crash the handler.
    if (schedulerConfig_->dynamicBatchSizeEnable && latencypredictor_ != nullptr) {
        uint32_t numOutputTokens = 1;
        if (output.samples_size() > 0) {
            numOutputTokens = static_cast<uint32_t>(output.samples(0).num_speculative_tokens());
        }
        latencypredictor_->AddPercentileData(seqGroup, schedulerConfig_, numOutputTokens);
    }
    AddOutputsToResponse(response, output);
    if (response->responseContents.empty()) {
        return nullptr;
    }
    return response;
}
/**
 * Sets the status flags of a freshly built response.
 *
 * Flags set:
 *   - isEos, when no sample reports ITERATION_CONTINUE;
 *   - inferStatusFlag, from the first sample's finish reason;
 *   - transferStatusFlag = PREFILL_COMPLETE, on Role::P instances or when the request's flex role
 *     is FlexP.
 *
 * @param output   Output the response was built from; must contain at least one sample.
 * @param response Response to update.
 */
void ModelExecOutputHandler::SetResponseFlags(const model_execute_data::CompletionSequenceGroupOutput &output,
                                              ResponseSPtr response) const {
    bool anySequenceContinues = false;
    for (const auto &sample : output.samples()) {
        if (sample.finish_reason() == static_cast<int64_t>(InferStatusType::ITERATION_CONTINUE)) {
            anySequenceContinues = true;
            break;
        }
    }
    if (!anySequenceContinues) {
        response->isEos = true;
        MINDIE_LLM_LOG_INFO_REQUEST("[LlmEngine|Request-End] DP RankId: " << dpRankId_ << ". Send eos response. seqId: "
                                    << output.samples(0).seq_id());
    }
    const model_execute_data::SequenceOutput &leadSample = output.samples(0);
    response->inferStatusFlag = static_cast<InferStatusType>(leadSample.finish_reason());
    LiveInferContextSPtr liveInferContext = LiveInferContext::GetInstance(localDPRank_);
    SequenceGroupSPtr seqGroup = liveInferContext->GetSeqGroup(leadSample.seq_id());
    const bool actsAsPrefill =
        role_ == Role::P ||
        (seqGroup != nullptr && liveInferContext->GetInferInstanceFlexRole4Req(seqGroup->requestId) == Role::FlexP);
    if (actsAsPrefill) {
        response->transferStatusFlag = TransferStatusType::PREFILL_COMPLETE;
    }
}
// Accessors returning mutable references to the handler's internal queues and counter, so callers
// can drain or modify them directly.

// Sequences that finished with END_OF_SENTENCE.
ConcurrentDeque<SequenceId> &ModelExecOutputHandler::GetFinishedSeqIds()
{
    return finishedSeqIds_;
}

// Sequences that stopped for any other reason, or whose group could not be tracked.
ConcurrentDeque<SequenceId> &ModelExecOutputHandler::GetExceptionSeqIds()
{
    return execExceptionSeqIds_;
}

// (seq id, token id) pairs for every token produced.
ConcurrentDeque<std::pair<SequenceId, TokenId>> &ModelExecOutputHandler::GetSeqIdToOutputTokenQueue()
{
    return seqIdToOutputTokenQueue_;
}

// Counter that Entry4Executor decrements once per processed batch.
std::atomic<size_t> &ModelExecOutputHandler::GetAsyncBatchNum()
{
    return asyncBatchNum_;
}