* -------------------------------------------------------------------------
* This file is part of the MindStudio project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* MindStudio is licensed under Mulan PSL v2.
* You can be obtained a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
*/
#include "WsSessionManager.h"
#include "DataBaseManager.h"
#include "TraceTime.h"
#include "RankOffsetHandler.h"
#include "RankOffsetCalculator.h"
namespace Dic {
namespace Module {
namespace Timeline {
using namespace Dic::Server;
struct FailedRank {
std::string rankId;
std::string reason;
};
struct SortedSliceQuery {
std::vector<Protocol::SimpleSlice> slices;
std::set<std::string> processIds;
};
namespace {
constexpr int32_t RANK_OFFSET_ERROR_CODE = 1;
void SetRankOffsetError(RankOffsetResponse &response, const std::string &message) {
response.body.error.code = RANK_OFFSET_ERROR_CODE;
response.body.error.message = message;
}
void SortSlicesByTimestamp(std::vector<Protocol::SimpleSlice> &slices) {
std::sort(slices.begin(), slices.end(),
[](const Protocol::SimpleSlice &a, const Protocol::SimpleSlice &b) { return a.timestamp < b.timestamp; });
}
int64_t FindSelectedSliceIndex(
const std::vector<Protocol::SimpleSlice> &pidSlices, uint64_t startTime, uint64_t duration) {
for (size_t i = 0; i < pidSlices.size(); ++i) {
if (pidSlices[i].timestamp == startTime && pidSlices[i].duration == duration) {
return static_cast<int64_t>(i);
}
}
return -1;
}
std::vector<std::string> GetTargetRankIds(const std::string &baseRankId) {
auto allRankIds = DataBaseManager::Instance().GetAllRankId();
std::vector<std::string> targetRankIds;
for (const auto &rankId : allRankIds) {
if (rankId != baseRankId) {
targetRankIds.push_back(rankId);
}
}
return targetRankIds;
}
std::string BuildFailedRanksMessage(const std::vector<FailedRank> &failedRanks) {
std::string details;
for (size_t i = 0; i < failedRanks.size(); ++i) {
if (i > 0) {
details += "; ";
}
details += failedRanks[i].rankId + " (" + failedRanks[i].reason + ")";
}
return "Failed ranks: " + details;
}
bool QuerySlicesByName(VirtualTraceDatabase &db, const std::string &rankId, const std::string &sliceName,
const std::string &metaType, RankOffsetSide side, SortedSliceQuery &query) {
if (metaType == "TEXT") {
return db.QueryTextSlicesByName(sliceName, metaType, query.slices, query.processIds);
}
if (side == RankOffsetSide::HOST) {
return db.QueryHostSlicesByName(sliceName, metaType, query.slices, query.processIds);
}
if (side == RankOffsetSide::DEVICE) {
return db.QueryDeviceSlicesByName(rankId, sliceName, metaType, query.slices, query.processIds);
}
return true;
}
SortedSliceQuery QueryAndSortSlices(
const std::string &rankId, const std::string &sliceName, const std::string &metaType, RankOffsetSide side) {
SortedSliceQuery query;
auto db = DataBaseManager::Instance().GetTraceDatabaseByRankId(rankId);
if (db == nullptr) {
return query;
}
QuerySlicesByName(*db, rankId, sliceName, metaType, side, query);
SortSlicesByTimestamp(query.slices);
return query;
}
const Protocol::SimpleSlice *PickSliceByIndex(const std::vector<Protocol::SimpleSlice> &slices, int64_t index) {
if (index < 0 || static_cast<int64_t>(slices.size()) <= index) {
return nullptr;
}
return &slices[index];
}
}
bool RankOffsetHandler::HandleRequest(std::unique_ptr<Protocol::Request> requestPtr) {
RankOffsetRequest &request = dynamic_cast<RankOffsetRequest &>(*requestPtr.get());
auto responsePtr = std::make_unique<RankOffsetResponse>();
RankOffsetResponse &response = *responsePtr.get();
SetBaseResponse(request, response);
std::string errorMsg;
if (!request.params.CheckParams(errorMsg)) {
ServerLog::Warn(errorMsg);
SetRankOffsetError(response, errorMsg);
SendResponse(std::move(responsePtr), false);
return false;
}
response.body.baseOffset = TraceTime::Instance().GetOffsetByFileId(request.params.rankId);
RankOffsetSide side = RankOffsetCalculator::GetSide(request.params.metaType);
if (side == RankOffsetSide::UNSUPPORTED) {
errorMsg = "Unsupported metaType: " + request.params.metaType;
ServerLog::Warn(errorMsg);
SetRankOffsetError(response, errorMsg);
SendResponse(std::move(responsePtr), false);
return false;
}
RankOffsetAlignType alignType = RankOffsetCalculator::ParseAlignType(request.params.alignType);
if (alignType == RankOffsetAlignType::UNSUPPORTED) {
errorMsg = "Unsupported alignType: " + request.params.alignType;
ServerLog::Warn(errorMsg);
SetRankOffsetError(response, errorMsg);
SendResponse(std::move(responsePtr), false);
return false;
}
int64_t baseIndex = -1;
Protocol::SimpleSlice baseSlice;
if (!ResolveBaseSlice(request.params, side, baseSlice, baseIndex, errorMsg)) {
ServerLog::Warn(errorMsg);
SetRankOffsetError(response, errorMsg);
SendResponse(std::move(responsePtr), false);
return false;
}
auto targetRankIds = GetTargetRankIds(request.params.rankId);
if (targetRankIds.empty()) {
errorMsg = "No target ranks found. Only base rank is loaded.";
ServerLog::Warn(errorMsg);
SetRankOffsetError(response, errorMsg);
SendResponse(std::move(responsePtr), false);
return false;
}
std::vector<FailedRank> failedRanks;
for (const auto &targetRankId : targetRankIds) {
ProcessTargetRank(targetRankId, request.params.sliceName, request.params.metaType, side, alignType, baseSlice,
baseIndex, response.body.baseOffset, response.body.successList, failedRanks);
}
if (!failedRanks.empty()) {
SetRankOffsetError(response, BuildFailedRanksMessage(failedRanks));
}
bool hasSuccess = !response.body.successList.empty();
SendResponse(std::move(responsePtr), hasSuccess);
return hasSuccess;
}
bool RankOffsetHandler::ResolveBaseSlice(const RankOffsetParams ¶ms, RankOffsetSide side,
Protocol::SimpleSlice &outSlice, int64_t &outIndex, std::string &errorMsg) {
auto baseDb = DataBaseManager::Instance().GetTraceDatabaseByRankId(params.rankId);
if (baseDb == nullptr) {
errorMsg = "Base rank database not found for rankId: " + params.rankId;
return false;
}
SortedSliceQuery baseQuery;
if (!QuerySlicesByName(*baseDb, params.rankId, params.sliceName, params.metaType, side, baseQuery)) {
errorMsg = "Failed to query base slices from database.";
return false;
}
SortSlicesByTimestamp(baseQuery.slices);
std::vector<Protocol::SimpleSlice> &pidSlices = baseQuery.slices;
if (pidSlices.empty()) {
errorMsg = "Base slice not found for sliceName: " + params.sliceName + ", pid: " + params.pid;
return false;
}
int64_t selectedIndex = FindSelectedSliceIndex(pidSlices, params.startTime, params.duration);
if (selectedIndex < 0) {
errorMsg = "Selected slice not found for sliceName: " + params.sliceName +
", startTime: " + std::to_string(params.startTime) + ", duration: " + std::to_string(params.duration);
return false;
}
outSlice = pidSlices[selectedIndex];
outIndex = selectedIndex;
return true;
}
void RankOffsetHandler::ProcessTargetRank(const std::string &targetRankId, const std::string &sliceName,
const std::string &metaType, RankOffsetSide side, RankOffsetAlignType alignType,
const Protocol::SimpleSlice &baseSlice, int64_t baseIndex, uint64_t baseOffset,
std::vector<RankOffsetItem> &successList, std::vector<FailedRank> &failedRanks) {
auto query = QueryAndSortSlices(targetRankId, sliceName, metaType, side);
if (query.slices.empty()) {
failedRanks.push_back({targetRankId, "Slice not found"});
return;
}
const auto *targetSlice = PickSliceByIndex(query.slices, baseIndex);
if (targetSlice == nullptr) {
failedRanks.push_back({targetRankId,
"Insufficient slices: has " + std::to_string(query.slices.size()) + ", need " +
std::to_string(baseIndex + 1)});
return;
}
int64_t offset =
RankOffsetCalculator::CalculateOffset(baseSlice, *targetSlice, alignType) + static_cast<int64_t>(baseOffset);
RankOffsetItem item;
item.rankId = targetRankId;
item.offset = offset;
item.processId.assign(query.processIds.begin(), query.processIds.end());
successList.push_back(item);
}
}
}
}