* -------------------------------------------------------------------------
* This file is part of the MindStudio project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
*/
#include "RLMicroBatchMegatronClassifier.h"
#include "RLMicroBatchFSDPClassifier.h"
#include "gtest/gtest.h"
using namespace Dic::Module::RL;
class MicroBatchTest : public testing::Test {
protected:
class MicroBatchMegaTronTest : public RLMicroBatchMegatronClassifier {
public:
std::vector<RLPipelineNode> Classifier(std::vector<RLPipelineNode> &items) {
return RLMicroBatchMegatronClassifier::MicroBatchClassifier(items);
}
};
class MicroBatchFSDPTest : public RLMicroBatchFSDPClassifier {
public:
std::vector<RLPipelineNode> Classifier(std::vector<RLPipelineNode> &items) {
return RLMicroBatchFSDPClassifier::MicroBatchClassifier(items);
}
std::vector<RLPipelineNode> NodeSortMerge(
const std::vector<RLPipelineNode> &left, const std::vector<RLPipelineNode> &right) {
return RLMicroBatchFSDPClassifier::NodeSortMerge(left, right);
}
};
std::vector<RLPipelineNode> generatorOneByOneData() {
std::vector<RLPipelineNode> res;
RLPipelineNode fp1{"", "FP", 0, 10, "transformerBlock", "rollout"};
res.push_back(fp1);
RLPipelineNode bp1{"", "BP", 12, 10, "transformerLayer", "rollout"};
res.push_back(bp1);
RLPipelineNode fp2{"", "FP", 24, 10, "transformerBlock", "rollout"};
res.push_back(fp2);
RLPipelineNode bp2{"", "BP", 35, 2, "transformerLayer", "rollout"};
res.push_back(bp2);
return res;
}
std::vector<RLPipelineNode> generatorData1() {
std::vector<RLPipelineNode> res;
RLPipelineNode fp1{"", "FP", 0, 50, "transformerBlock", "rollout"};
res.push_back(fp1);
RLPipelineNode bp1{"", "BP", 1, 2, "transformerLayer", "rollout"};
res.push_back(bp1);
RLPipelineNode bp2{"", "BP", 4, 4, "transformerLayer", "rollout"};
res.push_back(bp2);
RLPipelineNode fp2{"", "FP", 51, 10, "transformerBlock", "rollout"};
res.push_back(fp2);
RLPipelineNode bp3{"", "BP", 62, 2, "transformerLayer", "rollout"};
res.push_back(bp3);
RLPipelineNode bp4{"", "BP", 65, 3, "transformerLayer", "rollout"};
res.push_back(bp4);
RLPipelineNode bp5{"", "BP", 69, 10, "transformerLayer", "rollout"};
res.push_back(bp5);
return res;
}
std::vector<RLPipelineNode> generatorFPOverData() {
std::vector<RLPipelineNode> res;
RLPipelineNode fp1{"", "FP", 0, 10, "transformerBlock", "rollout"};
res.push_back(fp1);
RLPipelineNode bp1{"", "FP", 2, 8, "transformerLayer", "rollout"};
res.push_back(bp1);
return res;
}
};
/**
 * Megatron classifier on a strictly alternating, non-overlapping FP/BP timeline: every input node
 * is expected back in order with its type, start time and duration unchanged.
 */
TEST_F(MicroBatchTest, normalOnebyOne) {
    auto originalData = generatorOneByOneData();
    MicroBatchMegaTronTest classifier;
    auto res = classifier.Classifier(originalData);
    // Fatal check: if the size is wrong, stop before res[0..3] is indexed out of bounds.
    ASSERT_EQ(res.size(), 4u);
    const auto &node1 = res[0];
    EXPECT_EQ(node1.nodeType, "FP");
    EXPECT_EQ(node1.startTime, 0);
    EXPECT_EQ(node1.duration, 10);
    const auto &node2 = res[1];
    EXPECT_EQ(node2.nodeType, "BP");
    EXPECT_EQ(node2.startTime, 12);
    EXPECT_EQ(node2.duration, 10);
    const auto &node3 = res[2];
    EXPECT_EQ(node3.nodeType, "FP");
    EXPECT_EQ(node3.startTime, 24);
    EXPECT_EQ(node3.duration, 10);
    const auto &node4 = res[3];
    EXPECT_EQ(node4.nodeType, "BP");
    EXPECT_EQ(node4.startTime, 35);
    EXPECT_EQ(node4.duration, 2);
}
/**
 * Megatron classifier on the overlapping timeline from generatorData1. Expected output:
 * - the two BPs nested inside the first FP do not appear in the result;
 * - the BPs starting at 62 and 65 are reported as one BP spanning [62,68);
 * - the BP starting at 69 stays a separate node.
 */
TEST_F(MicroBatchTest, timeCover) {
    auto original = generatorData1();
    MicroBatchMegaTronTest classifier;
    auto res = classifier.Classifier(original);
    // Fatal check: if the size is wrong, stop before res[0..3] is indexed out of bounds.
    ASSERT_EQ(res.size(), 4u);
    const auto &node1 = res[0];
    EXPECT_EQ(node1.nodeType, "FP");
    EXPECT_EQ(node1.startTime, 0);
    EXPECT_EQ(node1.duration, 50);
    const auto &node2 = res[1];
    EXPECT_EQ(node2.nodeType, "FP");
    EXPECT_EQ(node2.startTime, 51);
    EXPECT_EQ(node2.duration, 10);
    const auto &node3 = res[2];
    EXPECT_EQ(node3.nodeType, "BP");
    EXPECT_EQ(node3.startTime, 62);
    EXPECT_EQ(node3.duration, 6);
    const auto &node4 = res[3];
    EXPECT_EQ(node4.nodeType, "BP");
    EXPECT_EQ(node4.startTime, 69);
    EXPECT_EQ(node4.duration, 10);
}
/**
 * FSDP classifier on a strictly alternating, non-overlapping FP/BP timeline: every input node is
 * expected back in order with its type, start time and duration unchanged.
 */
TEST_F(MicroBatchTest, fsdp_normalOnebyOne) {
    auto originalData = generatorOneByOneData();
    MicroBatchFSDPTest classifier;
    auto res = classifier.Classifier(originalData);
    // Fatal check: if the size is wrong, stop before res[0..3] is indexed out of bounds.
    ASSERT_EQ(res.size(), 4u);
    const auto &node1 = res[0];
    EXPECT_EQ(node1.nodeType, "FP");
    EXPECT_EQ(node1.startTime, 0);
    EXPECT_EQ(node1.duration, 10);
    const auto &node2 = res[1];
    EXPECT_EQ(node2.nodeType, "BP");
    EXPECT_EQ(node2.startTime, 12);
    EXPECT_EQ(node2.duration, 10);
    const auto &node3 = res[2];
    EXPECT_EQ(node3.nodeType, "FP");
    EXPECT_EQ(node3.startTime, 24);
    EXPECT_EQ(node3.duration, 10);
    const auto &node4 = res[3];
    EXPECT_EQ(node4.nodeType, "BP");
    EXPECT_EQ(node4.startTime, 35);
    EXPECT_EQ(node4.duration, 2);
}
/**
 * FSDP classifier on the overlapping timeline from generatorData1. Expected output:
 * - after the first FP, BPs [1,3) and [4,8) are reported as one BP [1,8) (duration 7);
 * - after the second FP, the three BPs are reported as one BP [62,79) (duration 17).
 */
TEST_F(MicroBatchTest, fsdp_timeCover) {
    auto original = generatorData1();
    MicroBatchFSDPTest classifier;
    auto res = classifier.Classifier(original);
    // Fatal check: if the size is wrong, stop before res[0..3] is indexed out of bounds.
    ASSERT_EQ(res.size(), 4u);
    const auto &node1 = res[0];
    EXPECT_EQ(node1.nodeType, "FP");
    EXPECT_EQ(node1.startTime, 0);
    EXPECT_EQ(node1.duration, 50);
    const auto &node2 = res[1];
    EXPECT_EQ(node2.nodeType, "BP");
    EXPECT_EQ(node2.startTime, 1);
    EXPECT_EQ(node2.duration, 7);
    const auto &node3 = res[2];
    EXPECT_EQ(node3.nodeType, "FP");
    EXPECT_EQ(node3.startTime, 51);
    EXPECT_EQ(node3.duration, 10);
    const auto &node4 = res[3];
    EXPECT_EQ(node4.nodeType, "BP");
    EXPECT_EQ(node4.startTime, 62);
    EXPECT_EQ(node4.duration, 17);
}
/**
 * FSDP classifier on generatorFPOverData. The FP [2,10) lies entirely inside FP [0,10), so only
 * one node, covering [0,10), is expected.
 */
TEST_F(MicroBatchTest, fsdp_fp_timer_cover) {
    auto original = generatorFPOverData();
    MicroBatchFSDPTest classifier;
    auto res = classifier.Classifier(original);
    // Fatal check: if the result is empty, stop before res[0] is indexed out of bounds.
    ASSERT_EQ(res.size(), 1u);
    EXPECT_EQ(res[0].startTime, 0);
    EXPECT_EQ(res[0].duration, 10);
}
/**
 * NodeSortMerge on two inputs (7 + 2 nodes). Every node from both inputs is expected to appear in
 * the result, including the FP at start time 0 that is present in both lists.
 */
TEST_F(MicroBatchTest, fsdp_node_sort_merge) {
    auto left = generatorData1();
    auto right = generatorFPOverData();
    MicroBatchFSDPTest classifier;
    auto res = classifier.NodeSortMerge(left, right);
    // Same value as the former literal 9, but compared size_t-to-size_t (no sign-compare warning).
    EXPECT_EQ(res.size(), left.size() + right.size());
}