/*
* -------------------------------------------------------------------------
* This file is part of the MindStudio project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
*/
#include "RLMicroBatchMegatronClassifier.h"
#include "RLMstxConfigManager.h"
#include "ServerLog.h"
#include "DataBaseManager.h"
#include "RenderEngine.h"
namespace Dic::Module::RL {
using namespace Server;
using namespace FullDb;
std::vector<Protocol::RLPipelineNode> RLMicroBatchMegatronClassifier::MicroBatchClassifier(
std::vector<RLPipelineNode> &nodes) {
if (nodes.empty()) {
return {};
}
Clear();
std::vector<Protocol::RLPipelineNode> res;
for (auto &node : nodes) {
switch (state) {
case Init: {
InitStateProcess(res, node);
break;
}
case FP: {
FPStateProcess(res, node);
break;
}
case BP: {
BPStateProcess(res, node);
break;
}
default:
break;
}
}
res.push_back(current);
return res;
}
// Emits the node held in `current` into `res`, appends the current value of `count` to `countQue`,
// and resets `count` to zero. `current` is left in a moved-from state afterwards.
void RLMicroBatchMegatronClassifier::PushFPNode(std::vector<Protocol::RLPipelineNode> &res) {
    res.push_back(std::move(current));
    countQue.push(count);
    count = 0;
}
// Emits the node held in `current` into `res`, then reloads `count` from the front of `countQue`
// (removing that entry). When the queue is empty, `count` is set to zero instead.
void RLMicroBatchMegatronClassifier::PushBPNode(std::vector<Protocol::RLPipelineNode> &res) {
    res.push_back(std::move(current));
    if (countQue.empty()) {
        count = 0;
        return;
    }
    count = countQue.front();
    countQue.pop();
}
void RLMicroBatchMegatronClassifier::SetStateAndNode(const RLPipelineNode &node, State newState) {
current = node;
this->state = newState;
}
// Resets the classifier: empties the count queue, restores a default node, zeroes the counter and
// returns the state machine to Init.
void RLMicroBatchMegatronClassifier::Clear() {
    countQue = std::queue<int>{};
    current = RLPipelineNode{};
    count = 0;
    state = Init;
}
// Handles a node while the machine is in the Init state.
//
// A node whose nodeType is "FP" becomes the node being accumulated, the machine moves to the FP
// state and `count` is zeroed. Any other node is dropped with an error log.
// `res` is not used by this handler.
void RLMicroBatchMegatronClassifier::InitStateProcess(
    std::vector<Protocol::RLPipelineNode> &res, const Protocol::RLPipelineNode &node) {
    if (node.nodeType != "FP") {
        Server::ServerLog::Error("Receive a backward pass operator at init state, ignored");
        return;
    }
    SetStateAndNode(node, FP);
    count = 0;
}
// Handles a node while the machine is in the FP state.
//
// A node is "covered" when its [startTime, startTime + duration] interval lies inside the interval
// of the node being accumulated.
//   - "FP" node, covered:     ignored.
//   - "FP" node, not covered: the accumulated node is emitted and `node` replaces it (state stays FP).
//   - other node, covered:    only `count` is incremented.
//   - other node, not covered: the accumulated node is emitted, `count` is reloaded from the front
//     of `countQue`, and the machine moves to BP with `node` as the accumulated node.
void RLMicroBatchMegatronClassifier::FPStateProcess(
    std::vector<Protocol::RLPipelineNode> &res, const RLPipelineNode &node) {
    const bool coveredByCurrent = node.startTime >= current.startTime &&
        node.startTime + node.duration <= current.startTime + current.duration;

    if (node.nodeType == "FP") {
        if (!coveredByCurrent) {
            PushFPNode(res);
            SetStateAndNode(node, FP);
        }
        return;
    }

    if (coveredByCurrent) {
        ++count;
        return;
    }

    // PushFPNode has just appended to countQue, so front() is always valid here.
    PushFPNode(res);
    count = countQue.front();
    countQue.pop();
    SetStateAndNode(node, BP);
}
// Handles a node while the machine is in the BP state.
//
// An "FP" node, or any node arriving when `count` is at most 1, closes the accumulated node: it is
// emitted and `node` becomes the new accumulated node (state FP for an "FP" node, BP otherwise).
// Otherwise the accumulated node is stretched to end where `node` ends and `count` is decremented.
void RLMicroBatchMegatronClassifier::BPStateProcess(
    std::vector<Protocol::RLPipelineNode> &res, const RLPipelineNode &node) {
    const bool isForward = node.nodeType == "FP";
    if (isForward || count <= 1) {
        PushBPNode(res);
        SetStateAndNode(node, isForward ? FP : BP);
        return;
    }
    const auto mergedEnd = node.startTime + node.duration;
    current.duration = mergedEnd - current.startTime;
    --count;
}
// Queries the micro-batch slices recorded for one task and converts them into pipeline nodes.
//
// Steps:
//   1. Collect the names of all micro batches configured for the task.
//   2. Query the slices with those names inside the task's time range.
//   3. Sort them by start time (longer slice first on ties) and convert them to nodes.
//      NOTE(review): the state-machine pass that resolves micro batches covering one another in
//      time is not invoked here; presumably the caller feeds the result to MicroBatchClassifier.
//
// @param fileId   identifier of the data file to query.
// @param config   configuration holding the per-task settings, keyed by task name.
// @param taskNode task whose name selects the config and whose time range bounds the query.
// @return one node per returned slice that has a micro-batch config entry, in sorted order; empty
//         when the task has no config, no micro-batch names, or the query returns nothing.
std::vector<Protocol::RLPipelineNode> RLMicroBatchMegatronClassifier::QueryMicroBatchSlices(
    const std::string &fileId, const RLMstxConfig &config, const Protocol::RLPipelineNode &taskNode) {
    // Single lookup: the iterator is reused instead of a second at() call.
    const auto taskConfigIt = config.taskConfigMap.find(taskNode.name);
    if (taskConfigIt == config.taskConfigMap.end()) {
        ServerLog::Error("[RL] task config not found when query micro batch");
        return {};
    }
    const auto &taskConfig = taskConfigIt->second;

    std::vector<std::string> microBatchNames;
    microBatchNames.reserve(taskConfig.microBatchConfigs.size());
    std::transform(taskConfig.microBatchConfigs.begin(), taskConfig.microBatchConfigs.end(),
        std::back_inserter(microBatchNames), [](const MicroBatchConfig &item) { return item.batchName; });
    if (microBatchNames.empty()) {
        return {};
    }

    FullDb::DataType type = DataBaseManager::Instance().GetDataType(fileId);
    auto microBatchInDbs = RenderEngine::Instance()->QueryMstxRLDetail(
        fileId, type, microBatchNames, taskNode.startTime, NumberSafe::Add(taskNode.startTime, taskNode.duration));
    if (microBatchInDbs.empty()) {
        return {};
    }

    // Earlier slices first; among slices starting together, the longer one comes first.
    std::sort(microBatchInDbs.begin(), microBatchInDbs.end(),
        [](const CompeteSliceDomain &left, const CompeteSliceDomain &right) {
            if (left.timestamp != right.timestamp) {
                return left.timestamp < right.timestamp;
            }
            return left.duration > right.duration;
        });

    std::vector<Protocol::RLPipelineNode> res;
    res.reserve(microBatchInDbs.size());
    for (const auto &sliceItem : microBatchInDbs) {
        RLPipelineNode microBatchNode;
        microBatchNode.stageType = taskNode.stageType;
        microBatchNode.name = sliceItem.name;
        // A slice without a config entry is skipped with a log rather than letting at() throw
        // std::out_of_range out of this query.
        const auto microBatchConfigIt = taskConfig.microBatchConfigMap.find(microBatchNode.name);
        if (microBatchConfigIt == taskConfig.microBatchConfigMap.end()) {
            ServerLog::Error("[RL] micro batch config not found when query micro batch, slice ignored");
            continue;
        }
        microBatchNode.nodeType = microBatchConfigIt->second.type;
        microBatchNode.startTime = sliceItem.timestamp;
        // NOTE(review): assumes endTime >= timestamp for every returned slice - confirm.
        microBatchNode.duration = sliceItem.endTime - sliceItem.timestamp;
        res.emplace_back(std::move(microBatchNode));
    }
    return res;
}
}