/*
 * -------------------------------------------------------------------------
* This file is part of the MindStudio project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
*/
#ifndef PROFILER_SERVER_RLMICROBATCHMEGATRONCLASSIFIER_H
#define PROFILER_SERVER_RLMICROBATCHMEGATRONCLASSIFIER_H
#include <queue>
#include "RLDomainObject.h"
#include "RLProtocolResponse.h"
#include "RLMicroBatchClassifierBase.h"
namespace Dic::Module::RL {
using namespace Protocol;
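/**
 * @brief:
 * State-machine algorithm that classifies and aggregates micro batches.
 * States: 0 -- initial, 1 -- forward-propagation phase, 2 -- backward-propagation phase, 3 -- finished
 * State transitions:
 * 0-->1 : a forward operator is received
 * 0-->3 : traversal is complete
 * 1-->1 : a forward operator is received
 *       : a backward operator is received that lies within the forward operator's time range; count++
 * 1-->2 : a backward operator is received that lies outside the forward operator's time range
 * 1-->3 : traversal is complete
 * 2-->2 : a backward operator is received while count is non-zero; count--
 * 2-->1 : a forward operator is received; count = 0
 * 2-->3 : traversal is complete
 * Micro batch aggregation (transformer example):
 * For a transformer, the forward-propagation operator is named transformerBlock and the backward operator is
 * named transformerLayer.
 * During forward propagation, a single transformerBlock may contain n transformerLayer operators within its range;
 * correspondingly, during backward propagation, one micro batch corresponds to n transformerLayer operators.
 */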
class RLMicroBatchMegatronClassifier : public RLMicroBatchClassifierBase {
public:
virtual ~RLMicroBatchMegatronClassifier() = default;
protected:
* @brief 查询数据
*/
std::vector<Protocol::RLPipelineNode> QueryMicroBatchSlices(
const std::string &fileId, const RLMstxConfig &config, const Protocol::RLPipelineNode &taskNode) override;
* @brief: 分类聚合
*/
std::vector<Protocol::RLPipelineNode> MicroBatchClassifier(std::vector<RLPipelineNode> &nodes) override;
void Clear();
private:
* @brief 封装前向传播microBatch的生成
*/
void PushFPNode(std::vector<Protocol::RLPipelineNode> &res);
* @brief 封装反向传播microBatch的生成
*/
void PushBPNode(std::vector<Protocol::RLPipelineNode> &res);
* @brief 设置当前的状态和node
*/
void SetStateAndNode(const RLPipelineNode &node, State state);
void InitStateProcess(std::vector<Protocol::RLPipelineNode> &res, const Protocol::RLPipelineNode &node);
void FPStateProcess(std::vector<Protocol::RLPipelineNode> &res, const RLPipelineNode &node);
void BPStateProcess(std::vector<Protocol::RLPipelineNode> &res, const RLPipelineNode &node);
private:
State state = Init;
std::queue<int> countQue;
int count = 0;
RLPipelineNode current;
};
}
#endif