* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* MindIE is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#include "seq_group_builder_from_infer_req.h"
#include <functional>
#include "common_util.h"
#include "log.h"
#include "request_response/request.h"
#include "request_response/request_id.h"
namespace mindie_llm {
/**
 * Builds the SequenceGroup that represents one inference request.
 *
 * The sampling parameters and the initial sequence are derived from the request, wrapped in a
 * SequenceGroup, and the request-level fields are then copied onto the group. When parallel
 * sampling is enabled (beam search, or bestOf > 1) the group is inserted into its own
 * seqId2ParallelSeqGroup_ map under the first sequence's id, and that id is recorded as the
 * parent sequence id.
 *
 * @param request          Inference request to convert.
 * @param schedulerConfig  Scheduler configuration forwarded to the sequence and group initialisers.
 * @param rankId           Rank id forwarded to the SequenceGroup constructor.
 * @return The newly created sequence group.
 */
SequenceGroupSPtr SeqGroupBuilderFromInferReq::BuildSequenceGroup(RequestSPtr request,
                                                                  SchedulerConfigSPtr schedulerConfig,
                                                                  size_t rankId) {
    // Sampling configuration; parallel sampling is implied by beam search or bestOf > 1.
    SamplingParamsSPtr samplingParams = GetSampleParamFromInferRequest(request, true);
    samplingParams->enableParallelSampling = (samplingParams->useBeamsearch || samplingParams->bestOf > 1);

    // The group starts out with exactly one sequence built from the request's prompt.
    std::vector<SequenceSPtr> initialSeqs{InitSeqFromInferRequest(request, schedulerConfig)};

    // The literal "None" means "no LoRA adapter"; in that case the optional stays empty.
    std::optional<std::string> loraAdapterId;
    if (request->loraId != "None") {
        loraAdapterId = request->loraId;
    }

    std::string groupRequestId = request->requestId;
    SequenceGroupSPtr group =
        std::make_shared<SequenceGroup>(groupRequestId, initialSeqs, samplingParams, loraAdapterId, rankId);
    InitSeqGrpFromInferRequest(request, groupRequestId, group, schedulerConfig);

    if (samplingParams->enableParallelSampling) {
        // Register the group under its first sequence and remember that sequence as the parent.
        group->seqId2ParallelSeqGroup_.Insert(group->firstSeq->seqId_, group);
        group->parentSeqId_ = group->firstSeq->seqId_;
    }
    return group;
}
/**
 * Copies the sampling-related fields of an inference request into a fresh SamplingParams object.
 *
 * @param request               Request whose sampling fields are read.
 * @param chooseV2BlockManager  When true, n and bestOf are taken from the request (n falls back
 *                              to 1, bestOf falls back to n). When false, both are left as
 *                              SamplingParams default-initialises them.
 * @return Newly allocated sampling parameters.
 */
SamplingParamsSPtr SeqGroupBuilderFromInferReq::GetSampleParamFromInferRequest(RequestSPtr request,
                                                                               bool chooseV2BlockManager) {
    SamplingParamsSPtr params = std::make_shared<SamplingParams>();

    // Candidate count and beam search.
    if (chooseV2BlockManager) {
        const uint32_t numReturn = request->n.value_or(1);
        params->n = numReturn;
        params->bestOf = request->bestOf.value_or(numReturn);
    }
    params->useBeamsearch = request->useBeamSearch.value_or(false);

    // Token-distribution shaping.
    params->doSample = request->doSample;
    params->seed = request->seed;
    params->temperature = request->temperature;
    params->topK = request->topK;
    params->topP = request->topP;
    params->typicalP = request->typicalP;

    // Penalties (the request spells the presence field "presencyPenalty").
    params->repetitionPenalty = request->repetitionPenalty;
    params->frequencyPenalty = request->frequencyPenalty;
    params->presencePenalty = request->presencyPenalty;

    // Output length and stop conditions; the stop lists are copied only when the request set them.
    params->maxOutputLen = request->maxOutputLen;
    params->includeStopStrInOutput = request->includeStopStrInOutput;
    const auto &stopIds = request->stopTokenIds;
    if (stopIds.has_value()) {
        params->stopTokenIds = stopIds.value();
    }
    const auto &stopStrs = request->stopStrings;
    if (stopStrs.has_value()) {
        params->stopStrings = stopStrs.value();
    }

    // Remaining pass-through options.
    params->watermark = request->watermark;
    params->logprobs = request->logprobs;
    params->topLogprobs = request->topLogprobs;
    params->responseFormat = request->responseFormat;

    return params;
}
/**
 * Creates the initial Sequence for a request from its prompt token ids.
 *
 * A fresh sequence id is always generated, skipping the reserved SIMULATE_SEQUENCE_ID. Simulate
 * requests then discard that id and use SIMULATE_SEQUENCE_ID instead.
 *
 * @param request          Request providing input_ids, requestId and isSimulateRequest.
 * @param schedulerConfig  Provides maxSeqLen (prompt length limit) and cacheBlockSize.
 * @return The newly created sequence.
 * @throws std::invalid_argument if the prompt has maxSeqLen tokens or more.
 */
SequenceSPtr SeqGroupBuilderFromInferReq::InitSeqFromInferRequest(RequestSPtr request,
                                                                  SchedulerConfigSPtr schedulerConfig) {
    // Keep drawing ids until one differs from the id reserved for simulate requests.
    SequenceId newSeqId = IDUtils::GenerateSequenceId();
    for (; newSeqId == SIMULATE_SEQUENCE_ID; newSeqId = IDUtils::GenerateSequenceId()) {
    }

    if (request->isSimulateRequest) {
        newSeqId = SIMULATE_SEQUENCE_ID;
        MINDIE_LLM_LOG_DEBUG("[SimulateInference] Detected simulate inference request, requestId: "
                             << request->requestId << ", assigned fixed seqId: " << newSeqId);
    }

    // Work on a private copy of the prompt tokens.
    std::vector<TokenId> promptTokenIds = request->input_ids;
    const size_t promptLen = promptTokenIds.size();
    const bool promptFits = promptLen < static_cast<size_t>(schedulerConfig->maxSeqLen);
    if (!promptFits) {
        throw std::invalid_argument("prompt tokens too long.");
    }

    return std::make_shared<Sequence>(newSeqId, schedulerConfig->cacheBlockSize, promptTokenIds);
}
* Reads the parameters passed via Tensor in the front-end (service layer) request and assembles them into the SequenceGroup.
*/
void SeqGroupBuilderFromInferReq::InitSeqGrpFromInferRequest(RequestSPtr request, RequestIdNew inferReqId,
                                                             SequenceGroupSPtr &seqGroup,
                                                             SchedulerConfigSPtr schedulerConfigSptr) {
    // Copies request-level attributes onto an already constructed sequence group.
    //
    // Parameters:
    //   request             - source of the attributes; its isThinking flag is reset to false here.
    //   inferReqId          - id stored in the group's metrics.
    //   seqGroup            - group to fill; must already hold its first sequence.
    //   schedulerConfigSptr - provides maxSeqLen and maxIterTimes for the output-length limits.
    //
    // Effects beyond plain field copies:
    //   * sets the first sequence's extra hash from the LoRA id (if any);
    //   * derives maxIterTimes_ / maxOutputLen_ and writes the clamped value back into
    //     seqGroup->sampling->maxOutputLen.

    seqGroup->metrics_.inferReqId_ = inferReqId;
    seqGroup->priority_ = request->priority;
    seqGroup->isSynchronous_ = request->isSynchronous;

    // Instance routing; both ids fall back to 0 when the request carries none.
    seqGroup->pInstanceId = request->pInstanceId.value_or(0);
    seqGroup->dpInstanceId_ = request->dpInstanceIds.empty() ? 0 : request->dpInstanceIds[0];
    seqGroup->pBlockTable = request->srcBlockTable;
    seqGroup->prefillReplayTokenIds_ = request->prefillReplayTokenIds;

    // Output / decoding options.
    seqGroup->skipSpecialTokens_ = request->skipSpecialTokens;
    seqGroup->ignoreEos_ = request->ignoreEos;
    seqGroup->topLogProbs_ = request->topLogprobs.value_or(0);

    // Thinking state: the request's flag is cleared once it has been copied into the group.
    seqGroup->isThinking_ = request->isThinking;
    seqGroup->enableThinking_ = request->enableThinking.value_or(false);
    seqGroup->thinkingBudget_ = request->thinkingBudget.value_or(0);
    request->isThinking = false;
    if (seqGroup->isThinking_) {
        seqGroup->thinkingTokens = 1;
    }

    // LoRA: the literal "None" maps to INVALID_HASH_VALUE; any other id is hashed.
    seqGroup->loraId_ = request->loraId;
    if (seqGroup->loraId_.has_value()) {
        const std::string &loraName = seqGroup->loraId_.value();
        if (loraName == "None") {
            seqGroup->firstSeq->SetExtraHash(INVALID_HASH_VALUE);
        } else {
            HashValue loraHash = std::hash<std::string>{}(loraName);
            // The lowest bit is always set, so the hash is never 0.
            // NOTE(review): presumably this keeps it distinct from a "no extra hash" value - confirm.
            loraHash |= 1;
            seqGroup->firstSeq->SetExtraHash(loraHash);
        }
    }

    // Output-length limits. The remaining budget is maxSeqLen minus the prompt length. The
    // subtraction is unsigned, so a prompt longer than maxSeqLen would wrap around to a huge
    // value; clamp the budget to 0 in that case instead of relying on an earlier check elsewhere.
    const size_t promptLen = seqGroup->seqs_.at(0)->data_.promptTokenIds.size();
    const uint32_t inputTokenNum = static_cast<uint32_t>(promptLen);
    using BudgetT = decltype(schedulerConfigSptr->maxSeqLen - inputTokenNum);
    const bool promptFits = promptLen < static_cast<size_t>(schedulerConfigSptr->maxSeqLen);
    const BudgetT remainingLen = promptFits ? (schedulerConfigSptr->maxSeqLen - inputTokenNum) : BudgetT{0};

    seqGroup->maxIterTimes_ = std::min(remainingLen, schedulerConfigSptr->maxIterTimes);
    seqGroup->maxOutputLen_ = std::min(seqGroup->maxIterTimes_, seqGroup->sampling->maxOutputLen);
    seqGroup->sampling->maxOutputLen = seqGroup->maxOutputLen_;
}
}