* -------------------------------------------------------------------------
* This file is part of the MindStudio project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
*/
#include <set>
#include <unordered_set>
#include <algorithm>
#include "ServerLog.h"
#include "RLMstxConfigReader.h"
#include "RLMstxConfigManager.h"
namespace Dic::Module::RL {
RLMstxConfigManager &RLMstxConfigManager::Instance() {
static RLMstxConfigManager instance;
return instance;
}
std::vector<RLMstxConfig> RLMstxConfigManager::GetRLMstxConfig() { return config; }
std::vector<std::string> RLMstxConfigManager::GetMstxTaskNameList() {
std::set<std::string> resSet;
for (const auto &item : config) {
for (const auto &task : item.taskConfigs) {
resSet.insert(task.taskName);
}
}
return std::vector<std::string>(resSet.begin(), resSet.end());
}
std::string RLMstxConfigManager::GetTaskTypeByName(const std::string &name) {
std::string type = "";
for (const auto &item : config) {
for (const auto &task : item.taskConfigs) {
if (name == task.taskName) {
return task.roleName;
}
}
}
return type;
}
RLMstxConfig RLMstxConfigManager::GetMstxConfigByTaskName(const std::vector<std::string> &taskNames) {
std::unordered_set<std::string> taskNameSet{taskNames.begin(), taskNames.end()};
auto it = std::find_if(config.begin(), config.end(), [&taskNameSet](const RLMstxConfig &configItem) {
for (const auto &taskConfig : configItem.taskConfigs) {
if (taskNameSet.count(taskConfig.taskName) == 0) {
return false;
}
}
return true;
});
if (it == config.end()) {
Server::ServerLog::Warn("No total matching config could be found");
for (const std::string &taskName : taskNameSet) {
auto iterator = std::find_if(config.begin(), config.end(), [&taskName](const RLMstxConfig &configItem) {
return configItem.taskConfigMap.find(taskName) != configItem.taskConfigMap.end();
});
if (iterator != config.end()) {
return *iterator;
}
}
return {};
}
return *it;
}
void RLMstxConfigManager::InitConfig() {
RLMstxConfigReader reader;
config = reader.ReadConfigFile();
if (config.empty()) {
InitDefaultConf();
}
}
RLMstxConfigManager::RLMstxConfigManager() { InitConfig(); }
void RLMstxConfigManager::InitDefaultConf() {
InitVerlGrpoConf();
InitMindspeedRlGrpoConf();
}
void RLMstxConfigManager::InitVerlGrpoConf() {
RLMstxConfig defaultConf = {
.framework = "verl",
.algorithm = "GRPO",
};
TaskConfig gs{.roleName = "ActorRollout", .taskName = "generate_sequences"};
TaskConfig reward{.roleName = "Reward", .taskName = "compute_log_prob"};
MicroBatchConfig fp{.batchName = "TransformerBlock", .type = "FP"};
MicroBatchConfig bp{.batchName = "TransformerLayer", .type = "BP"};
reward.AddMicroBatchConf(std::move(fp));
reward.AddMicroBatchConf(std::move(bp));
TaskConfig actor{.roleName = "Actor", .taskName = "update_actor"};
MicroBatchConfig fp1{.batchName = "TransformerBlock", .type = "FP"};
MicroBatchConfig bp1{.batchName = "TransformerLayer", .type = "BP"};
actor.AddMicroBatchConf(std::move(fp1));
actor.AddMicroBatchConf(std::move(bp1));
TaskConfig refLog{.roleName = "Reference", .taskName = "compute_ref_log_prob"};
MicroBatchConfig fp2{.batchName = "TransformerBlock", .type = "FP"};
MicroBatchConfig bp2{.batchName = "TransformerLayer", .type = "BP"};
refLog.AddMicroBatchConf(std::move(fp2));
refLog.AddMicroBatchConf(std::move(bp2));
defaultConf.AddTaskConfig(std::move(gs));
defaultConf.AddTaskConfig(std::move(reward));
defaultConf.AddTaskConfig(std::move(actor));
defaultConf.AddTaskConfig(std::move(refLog));
config.push_back(defaultConf);
}
void RLMstxConfigManager::InitMindspeedRlGrpoConf() {
RLMstxConfig mindSpeedRlConf = {
.framework = "MindSpeed-RL",
.algorithm = "GRPO",
};
TaskConfig mindSpeedGS{.roleName = "ActorRollout", .taskName = "ActorHybridWorkerBase.generate_sequences"};
MicroBatchConfig fp{.batchName = "TransformerBlock", .type = "FP"};
MicroBatchConfig bp{.batchName = "TransformerLayer", .type = "BP"};
mindSpeedGS.AddMicroBatchConf(std::move(fp));
mindSpeedGS.AddMicroBatchConf(std::move(bp));
TaskConfig mindSpeedRef{.roleName = "Reference", .taskName = "IntegratedWorker.compute_ref_log_prob"};
MicroBatchConfig fp2{.batchName = "TransformerBlock", .type = "FP"};
MicroBatchConfig bp2{.batchName = "TransformerLayer", .type = "BP"};
mindSpeedRef.AddMicroBatchConf(std::move(fp2));
mindSpeedRef.AddMicroBatchConf(std::move(bp2));
TaskConfig mindSpeedActor{.roleName = "Actor", .taskName = "ActorRolloutHybrid.update_actor"};
MicroBatchConfig fp3{.batchName = "TransformerBlock", .type = "FP"};
MicroBatchConfig bp3{.batchName = "TransformerLayer", .type = "BP"};
mindSpeedActor.AddMicroBatchConf(std::move(fp3));
mindSpeedActor.AddMicroBatchConf(std::move(bp3));
TaskConfig mindSpeedReward{.roleName = "Reward", .taskName = "Reward.compute_rm_score"};
MicroBatchConfig fp4{.batchName = "TransformerBlock", .type = "FP"};
MicroBatchConfig bp4{.batchName = "TransformerLayer", .type = "BP"};
mindSpeedReward.AddMicroBatchConf(std::move(fp4));
mindSpeedReward.AddMicroBatchConf(std::move(bp4));
mindSpeedRlConf.AddTaskConfig(std::move(mindSpeedGS));
mindSpeedRlConf.AddTaskConfig(std::move(mindSpeedRef));
mindSpeedRlConf.AddTaskConfig(std::move(mindSpeedActor));
mindSpeedRlConf.AddTaskConfig(std::move(mindSpeedReward));
config.push_back(mindSpeedRlConf);
}
}