* Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
*/
#ifndef OCK_TTP_CONTROLLER_TEST_COMMON_H
#define OCK_TTP_CONTROLLER_TEST_COMMON_H
#include <fstream>
#include <mutex>
#include <sys/stat.h>
#include <gtest/gtest.h>
#include <mockcpp/mockcpp.hpp>
#define private public
#define protected public
#include "common.h"
#include "controller.h"
#include "processor.h"
#include "replica_manager.h"
#include "mindx_engine.h"
#include "ttp_logger.h"
#undef protected
#undef private
using namespace ock::ttp;
const std::string BACKUP_IP = "127.0.0.1";
const std::string BACKUP_PORT = "1234";
const std::string CONTROLLER_IP = "0.0.0.0";
constexpr uint32_t CONTROLLER_PORT = 8555;
constexpr const int32_t WORLD_SIZE = 4;
constexpr int64_t COMMON_STEP = 2;
constexpr int64_t BACKUP_STEP = 1;
constexpr uint32_t UCE_NO_REBUILD = 2;
constexpr uint8_t MASK_NORMAL = 0;
constexpr uint8_t MASK_ERROR = 1;
constexpr int32_t REPLICA_NUM_ONE = 1;
constexpr int32_t REPLICA_NUM_TWO = 2;
constexpr int32_t REPLICA_SHIFT_ONE = 1;
constexpr int32_t REPLICA_SHIFT_TWO = 2;
constexpr int32_t REPLICA_SHIFT_THREE = 3;
constexpr uint32_t CHECK_COUNT_ONE = 1;
constexpr uint32_t CHECK_COUNT_TWO = 2;
constexpr uint32_t CHECK_COUNT_THREE = 3;
constexpr uint32_t CHECK_COUNT_FOUR = 4;
constexpr uint32_t CHECK_COUNT_FIVE = 5;
constexpr uint32_t CHECK_COUNT_SIX = 6;
constexpr uint32_t CHECK_COUNT_SEVEN = 7;
constexpr uint32_t SLEEP_TWO = 2;
class ControllerTest : public testing::Test {
public:
void SetUp() override {}
void TearDown() override
{
usleep(TTP_WAIT_TIME_1MS);
if (controller1 != nullptr) {
controller1->Destroy();
controller1 = nullptr;
Controller::GetInstance(true);
MindXEngine::GetInstance()->Destroy();
MindXEngine::GetInstance(true);
}
if (controller2 != nullptr) {
controller2->Destroy();
controller2 = nullptr;
Controller::GetInstance(true);
MindXEngine::GetInstance()->Destroy();
MindXEngine::GetInstance(true);
}
if (processor1 != nullptr) {
processor1->Destroy();
processor1 = nullptr;
}
if (processor2 != nullptr) {
processor2->Destroy();
processor2 = nullptr;
}
if (processor3 != nullptr) {
processor3->Destroy();
processor3 = nullptr;
}
if (processor4 != nullptr) {
processor4->Destroy();
processor4 = nullptr;
}
if (MindXEngine::GetInstance() != nullptr) {
MindXEngine::GetInstance()->Destroy();
}
GlobalMockObject::verify();
GlobalMockObject::reset();
}
int32_t CallBackFunc(void *ctx, int ctxSize)
{
SaveCkptContext *info = static_cast<SaveCkptContext *>(ctx);
if (info != nullptr) {
{
std::lock_guard<std::mutex> lock(ckptRankInfosRanksMutex);
ckptRankInfos.emplace(info->ranks);
}
}
ckptCount.fetch_add(1);
return 0;
}
int RenameFunc(void *ctx, int ctxSize)
{
renameCount.fetch_add(1);
return 0;
}
int ExitFunc(void *ctx, int ctxSize)
{
exitCount.fetch_add(1);
return 0;
}
int StopFunc(void *ctx, int ctxSize)
{
stopCount.fetch_add(1);
return 0;
}
int CleanFunc(void *ctx, int ctxSize)
{
cleanCount.fetch_add(1);
int32_t rank = *(static_cast<int32_t *>(ctx));
auto itr = lowLevelRanks.find(rank);
return itr != lowLevelRanks.end() ? UCE_NO_REBUILD : 0;
}
int RepairFunc(void *ctx, int ctxSize)
{
RepairContext *rc = static_cast<RepairContext *>(ctx);
if (rc->type == RepairType::RT_SEND) {
repairSendCount.fetch_add(1);
{
std::lock_guard<std::mutex> lock(repairRanksMutex);
repairRanks = rc->ranks;
}
{
std::lock_guard<std::mutex> lock(repairRankInfosMutex);
repairRankInfos["send"].emplace(rc->srcRank);
}
} else if (rc->type == RepairType::RT_UCE_HIGHLEVEL || rc->type == RepairType::RT_UCE_LOWLEVEL) {
repairUCECount.fetch_add(1);
std::lock_guard<std::mutex> lock(repairRankInfosMutex);
repairRankInfos["ucerecv"].emplace(rc->dstRank);
} else if (rc->type == RepairType::RT_RECV_REPAIR) {
repairZitRecvCount.fetch_add(1);
std::lock_guard<std::mutex> lock(repairRankInfosMutex);
repairRankInfos["otherrecv"].emplace(rc->dstRank);
} else if (rc->type == RepairType::RT_LOAD_CKPT) {
repairLoadCkpt.fetch_add(1);
} else if (rc->type == RepairType::RT_LOAD_REBUILD) {
repairLoadRebuild.fetch_add(1);
}
if (repairFlag.load() == true) {
return 0;
} else {
return 1;
}
}
int RollBackFunc(void *ctx, int ctxSize)
{
RepairContext *rc = reinterpret_cast<RepairContext *>(ctx);
repairRollbackCount.fetch_add(1);
return 0;
}
int Register(void *ctx, int ctxSize)
{
registerCount.fetch_add(1);
return 0;
}
int RebuildFunc(void *ctx, int ctxSize)
{
ZitRebuildContext *rc = reinterpret_cast<ZitRebuildContext *>(ctx);
std::lock_guard<std::mutex> lock(repairRanksMutex);
auto it = std::find(rc->commGroupIdx.begin(), rc->commGroupIdx.end(), 0);
if (it != rc->commGroupIdx.end()) {
int32_t dpcpIndex = std::distance(rc->commGroupIdx.begin(), it);
repairRanks = rc->commGroups[dpcpIndex];
}
return 0;
}
int PtCommFunc(void *ctx, int ctxSize)
{
ptCommCount.fetch_add(1);
return 0;
}
int UpPtCommFunc(void *ctx, int ctxSize)
{
upPtCommCount.fetch_add(1);
return 0;
}
int PauseTrainFunc(void *ctx, int ctxSize)
{
pauseTrainCount.fetch_add(1);
return 0;
}
int ContinueTrainFunc(void *ctx, int ctxSize)
{
continueTrainCount.fetch_add(1);
return 0;
}
static int ReportFaultRanks(void *ctx, int ctxSize)
{
ProcessFaultContext *nrsc = static_cast<ProcessFaultContext *>(ctx);
std::map<int32_t, int32_t> ranks = nrsc->errorInfoMap;
MindXEngine::GetInstance()->EventProcess(MindXEvent::MINDX_EVENT_STOP_TRAIN, &ranks, ranks.size());
return 0;
}
static int ReportFaultRanksUnexcepted(void *ctx, int ctxSize)
{
ProcessFaultContext *nrsc = static_cast<ProcessFaultContext *>(ctx);
NotifyRankInfo rankInfo {nrsc->errorInfoMap, TTP_WAIT_TIME_1MS};
MindXEngine::GetInstance()->EventProcess(
MindXEvent::MINDX_EVENT_NOTIFY_FAULT_RANKS, &rankInfo, sizeof(NotifyRankInfo));
return 0;
}
static int ReportStopComplete(void *ctx, int ctxSize)
{
StopCompleteContext *nrsc = static_cast<StopCompleteContext *>(ctx);
NotifyRankInfo rankInfo {nrsc->errorInfoMap, TTP_WAIT_TIME_1MS};
MindXEngine::GetInstance()->EventProcess(MindXEvent::MINDX_EVENT_NOTIFY_FAULT_RANKS,
&rankInfo, sizeof(NotifyRankInfo));
return 0;
}
static int ReportStopCompletePause(void *ctx, int ctxSize)
{
StopCompleteContext *nrsc = static_cast<StopCompleteContext *>(ctx);
std::map<int32_t, int32_t> ranks = nrsc->errorInfoMap;
return 0;
}
static int ReportStopCompleteUnexcepted(void *ctx, int ctxSize)
{
StopCompleteContext *nrsc = static_cast<StopCompleteContext *>(ctx);
std::map<int32_t, int32_t> ranks = nrsc->errorInfoMap;
MindXEngine::GetInstance()->EventProcess(MindXEvent::MINDX_EVENT_STOP_TRAIN,
&ranks, ranks.size());
return 0;
}
int ReportStrategies(void *ctx, int ctxSize)
{
canChange.store(true);
return 0;
}
int ReportResult(void *ctx, int ctxSize)
{
canChange.store(true);
reportResultCount.fetch_add(1);
return 0;
}
void ChangeStrategy(std::string strategy)
{
while (!canChange.load()) {
usleep(TTP_WAIT_TIME_1MS);
}
canChange.store(false);
std::string param = "zit test";
MindXEngine::GetInstance()->ChangeStrategy(strategy, param);
}
void InitProcessor(ProcessorPtr &proc)
{
int32_t ret;
proc = MakeRef<Processor>();
ASSERT_TRUE(proc != nullptr);
ret = proc->RegisterEventHandler(PROCESSOR_EVENT_EXIT, std::bind(&ControllerTest::ExitFunc,
this, std::placeholders::_1,
std::placeholders::_2));
ASSERT_EQ(ret, 0);
ret = proc->RegisterEventHandler(PROCESSOR_EVENT_RENAME, std::bind(&ControllerTest::RenameFunc,
this, std::placeholders::_1,
std::placeholders::_2));
ASSERT_EQ(ret, 0);
ret = proc->RegisterEventHandler(PROCESSOR_EVENT_SAVE_CKPT, std::bind(&ControllerTest::CallBackFunc,
this, std::placeholders::_1,
std::placeholders::_2));
ASSERT_EQ(ret, 0);
ret = proc->RegisterEventHandler(PROCESSOR_EVENT_DEVICE_STOP, std::bind(&ControllerTest::StopFunc,
this, std::placeholders::_1,
std::placeholders::_2));
ASSERT_EQ(ret, 0);
ret = proc->RegisterEventHandler(PROCESSOR_EVENT_DEVICE_CLEAN, std::bind(&ControllerTest::CleanFunc,
this, std::placeholders::_1,
std::placeholders::_2));
ASSERT_EQ(ret, 0);
ret = proc->RegisterEventHandler(PROCESSOR_EVENT_REPAIR, std::bind(&ControllerTest::RepairFunc,
this, std::placeholders::_1,
std::placeholders::_2));
ASSERT_EQ(ret, 0);
ret = proc->RegisterEventHandler(PROCESSOR_EVENT_ROLLBACK, std::bind(&ControllerTest::RollBackFunc,
this, std::placeholders::_1,
std::placeholders::_2));
ASSERT_EQ(ret, 0);
ret = proc->RegisterEventHandler(PROCESSOR_EVENT_DOWNGRADE_REBUILD, std::bind(&ControllerTest::RebuildFunc,
this, std::placeholders::_1,
std::placeholders::_2));
ASSERT_EQ(ret, 0);
ret = proc->RegisterEventHandler(PROCESSOR_EVENT_PT_COMM_OPERATE,
std::bind(&ControllerTest::PtCommFunc,
this, std::placeholders::_1, std::placeholders::_2));
ASSERT_EQ(ret, 0);
ret = proc->RegisterEventHandler(PROCESSOR_EVENT_UPGRADE_REBUILD, std::bind(&ControllerTest::UpPtCommFunc,
this, std::placeholders::_1,
std::placeholders::_2));
ASSERT_EQ(ret, 0);
ret = proc->RegisterEventHandler(PROCESSOR_EVENT_UPGRADE_REPAIR, std::bind(&ControllerTest::RepairFunc,
this, std::placeholders::_1,
std::placeholders::_2));
ASSERT_EQ(ret, 0);
ret = proc->RegisterEventHandler(PROCESSOR_EVENT_UPGRADE_ROLLBACK, std::bind(&ControllerTest::RollBackFunc,
this, std::placeholders::_1,
std::placeholders::_2));
ASSERT_EQ(ret, 0);
ret = proc->RegisterEventHandler(PROCESSOR_EVENT_PAUSE,
std::bind(&ControllerTest::PauseTrainFunc,
this, std::placeholders::_1, std::placeholders::_2));
ASSERT_EQ(ret, 0);
ret = proc->RegisterEventHandler(PROCESSOR_EVENT_CONTINUE,
std::bind(&ControllerTest::ContinueTrainFunc,
this, std::placeholders::_1, std::placeholders::_2));
ASSERT_EQ(ret, 0);
}
void InitController(ControllerPtr &ctrl)
{
int32_t ret;
ctrl = Controller::GetInstance();
ASSERT_TRUE(ctrl != nullptr);
if (std::getenv("MINDX_TASK_ID") == nullptr) {
return;
}
MindXEnginePtr engine = MindXEngine::GetInstance();
ret = engine->RegisterEventHandler(MindXEvent::MINDX_EVENT_REGISTER, std::bind(&ControllerTest::Register,
this, std::placeholders::_1,
std::placeholders::_2));
ASSERT_EQ(ret, 0);
ret = engine->RegisterEventHandler(MindXEvent::MINDX_EVENT_REPORT_FAULT_RANKS,
&ControllerTest::ReportFaultRanks);
ASSERT_EQ(ret, 0);
ret = engine->RegisterEventHandler(MindXEvent::MINDX_EVENT_REPORT_STOP_COMPLETE,
&ControllerTest::ReportStopComplete);
ASSERT_EQ(ret, 0);
ret = engine->RegisterEventHandler(MindXEvent::MINDX_EVENT_REPORT_STRATEGIES,
std::bind(&ControllerTest::ReportStrategies, this,
std::placeholders::_1, std::placeholders::_2));
ASSERT_EQ(ret, 0);
ret = engine->RegisterEventHandler(MindXEvent::MINDX_EVENT_REPORT_RESULT,
std::bind(&ControllerTest::ReportResult, this,
std::placeholders::_1, std::placeholders::_2));
ASSERT_EQ(ret, 0);
}
static std::vector<BackupInfo> SelectBackUpController()
{
std::vector<BackupInfo> backUps;
BackupInfo info;
info.rank = 1;
info.ip = BACKUP_IP;
info.port = BACKUP_PORT;
backUps.push_back(info);
return backUps;
}
void CountClean()
{
ckptCount.store(0);
renameCount.store(0);
exitCount.store(0);
stopCount.store(0);
cleanCount.store(0);
repairSendCount.store(0);
repairUCECount.store(0);
repairLoadCkpt.store(0);
repairLoadRebuild.store(0);
repairRollbackCount.store(0);
registerCount.store(0);
ptCommCount.store(0);
upPtCommCount.store(0);
pauseTrainCount.store(0);
continueTrainCount.store(0);
repairZitRecvCount.store(0);
canChange.store(false);
}
void MapInfoClean()
{
lowLevelRanks.clear();
{
std::lock_guard<std::mutex> lock(repairRanksMutex);
repairRanks.clear();
}
{
std::lock_guard<std::mutex> lock(ckptRankInfosRanksMutex);
ckptRankInfos.clear();
}
{
std::lock_guard<std::mutex> lock(repairRankInfosMutex);
repairRankInfos.clear();
}
}
void InitSource(int32_t controllerReplica = 2, bool enableARF = false, bool enableZIT = false,
int32_t replicaShift = 2)
{
ControllerTest::CountClean();
ControllerTest::MapInfoClean();
ControllerTest::InitController(controller1);
std::vector<int32_t> replicaCnt = { controllerReplica };
std::vector<int32_t> replicaOffset = { replicaShift };
int32_t ret = controller1->Initialize(0, WORLD_SIZE, enableLocalCopy, enableARF, enableZIT);
controller1->retrySwitch_ = true;
ASSERT_EQ(ret, 0);
std::string ip = CONTROLLER_IP;
int32_t port = CONTROLLER_PORT;
ret = controller1->Start(ip, port, testTlsOption);
ASSERT_EQ(ret, 0);
std::vector<int32_t> ranks = {0, 1, 2, 3};
std::vector<std::vector<int32_t>> groups = { ranks };
ControllerTest::InitProcessor(processor1);
ControllerTest::InitProcessor(processor2);
ControllerTest::InitProcessor(processor3);
ControllerTest::InitProcessor(processor4);
ret = processor1->Initialize(0, WORLD_SIZE, enableLocalCopy, testTlsOption, true, enableARF, enableZIT);
ASSERT_EQ(ret, 0);
ret = processor2->Initialize(1, WORLD_SIZE, enableLocalCopy, testTlsOption, true, enableARF, enableZIT);
ASSERT_EQ(ret, 0);
ret = processor3->Initialize(2, WORLD_SIZE, enableLocalCopy, testTlsOption, true, enableARF, enableZIT);
ASSERT_EQ(ret, 0);
ret = processor4->Initialize(3, WORLD_SIZE, enableLocalCopy, testTlsOption, true, enableARF, enableZIT);
ASSERT_EQ(ret, 0);
ret = processor1->Start(ip, port);
ASSERT_EQ(ret, 0);
ret = processor2->Start(ip, port);
ASSERT_EQ(ret, 0);
ret = processor3->Start(ip, port);
ASSERT_EQ(ret, 0);
ret = processor4->Start(ip, port);
ASSERT_EQ(ret, 0);
ret = processor1->ReportReplicaInfo(groups, replicaCnt, replicaOffset);
ASSERT_EQ(ret, 0);
ret = processor2->ReportReplicaInfo(groups, replicaCnt, replicaOffset);
ASSERT_EQ(ret, 0);
ret = processor3->ReportReplicaInfo(groups, replicaCnt, replicaOffset);
ASSERT_EQ(ret, 0);
ret = processor4->ReportReplicaInfo(groups, replicaCnt, replicaOffset);
ASSERT_EQ(ret, 0);
}
void ProcessorUpdate(ProcessorPtr &proc)
{
int32_t ret = proc->BeginUpdating(BACKUP_STEP);
ASSERT_EQ(ret, 0);
ret = proc->FinishedUpdate(COMMON_STEP);
ASSERT_EQ(ret, 0);
}
void HeartbeatUpdate()
{
int32_t ret = processor1->HeartbeatSend();
ASSERT_EQ(ret, 0);
ret = processor2->HeartbeatSend();
ASSERT_EQ(ret, 0);
ret = processor3->HeartbeatSend();
ASSERT_EQ(ret, 0);
ret = processor4->HeartbeatSend();
ASSERT_EQ(ret, 0);
}
public:
std::atomic<uint32_t> ckptCount;
std::atomic<uint32_t> renameCount;
std::atomic<uint32_t> exitCount;
std::atomic<uint32_t> stopCount;
std::atomic<uint32_t> cleanCount;
std::atomic<uint32_t> repairSendCount;
std::atomic<uint32_t> repairUCECount;
std::atomic<uint32_t> repairRollbackCount;
std::atomic<uint32_t> registerCount;
std::atomic<uint32_t> reportResultCount;
std::atomic<uint32_t> ptCommCount;
std::atomic<uint32_t> upPtCommCount;
std::atomic<uint32_t> pauseTrainCount;
std::atomic<uint32_t> continueTrainCount;
std::atomic<uint32_t> repairZitRecvCount;
std::atomic<uint32_t> repairLoadCkpt;
std::atomic<uint32_t> repairLoadRebuild;
std::mutex repairRankInfosMutex;
std::mutex repairRanksMutex;
std::mutex ckptRankInfosRanksMutex;
std::set<int32_t> lowLevelRanks;
std::vector<int32_t> repairRanks;
std::set<std::vector<std::vector<int32_t>>> ckptRankInfos;
std::map<std::string, std::set<std::vector<int32_t>>> repairRankInfos;
ProcessorPtr processor1 = nullptr;
ProcessorPtr processor2 = nullptr;
ProcessorPtr processor3 = nullptr;
ProcessorPtr processor4 = nullptr;
ControllerPtr controller1 = nullptr;
ControllerPtr controller2 = nullptr;
AccTlsOption testTlsOption;
bool enableLocalCopy = false;
std::atomic<bool> repairFlag = { true };
std::atomic<bool> canChange = { false };
};
#endif