* -------------------------------------------------------------------------
* This file is part of the MindStudio project.
* Copyright (c) 2026 Huawei Technologies Co.,Ltd.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
*/
#include "DataBaseManager.h"
#include "MemcpyOverallDatabaseAccesser.h"
#include "Paginator.h"
#include "QueryMemcpyOverallHandler.h"
namespace Dic::Module::Timeline {
bool QueryMemcpyOverallHandler::HandleRequest(std::unique_ptr<Protocol::Request> requestPtr) {
auto &request = dynamic_cast<MemcpyOverallRequest &>(*requestPtr);
std::unique_ptr<MemcpyOverallResponse> responsePtr = std::make_unique<MemcpyOverallResponse>();
MemcpyOverallResponse &response = *responsePtr;
SetBaseResponse(request, response);
auto database = DataBaseManager::Instance().GetTraceDatabaseByFileId(request.fileId);
if (database == nullptr) {
SetTimelineError(ErrorCode::CONNECT_DATABASE_FAILED);
SendResponse(std::move(responsePtr), false);
return false;
}
uint64_t minTimestamp = TraceTime::Instance().GetStartTime();
std::string error;
if (!request.params.CheckParams(minTimestamp, error)) {
SetTimelineError(ErrorCode::PARAMS_ERROR);
SendResponse(std::move(responsePtr), false, error);
return false;
}
std::string deviceId = DataBaseManager::Instance().GetDeviceIdFromRankId(request.params.rankId);
if (deviceId.empty()) {
ServerLog::Warn("DeviceId is empty for memcpy view overall statistics.");
SendResponse(std::move(responsePtr), true);
return true;
}
request.params.deviceId = deviceId;
if (!CalMemcpyData(request, response, error, database)) {
SetTimelineError(ErrorCode::QUERY_MEMCPY_OVERALL_FAILED);
SendResponse(std::move(responsePtr), false);
return false;
}
SendResponse(std::move(responsePtr), true);
return true;
}
void BuildMemcpyOverallResult(
const std::vector<MemcpyRecord> &records, MemcpyOverallResponse &response, uint32_t current, uint32_t pageSize) {
std::map<std::string, StatsAccumulator> threadMap;
std::map<std::string, std::map<std::string, StatsAccumulator>> typeMap;
std::map<std::string, std::string> threadNameMap;
for (const auto &rec : records) {
threadMap[rec.threadId].Update(rec.size, rec.duration);
threadNameMap[rec.threadId] = rec.threadName;
typeMap[rec.threadId][rec.memcpyType].Update(rec.size, rec.duration);
}
std::vector<MemcpyOverallRes> result;
result.reserve(threadMap.size());
for (auto &[tid, tStat] : threadMap) {
MemcpyOverallRes ts;
ts.key = tid;
ts.name = threadNameMap.at(tid);
ts.level = 1;
ts.totalSize = tStat.totalSize;
ts.totalTime = tStat.totalTime;
ts.number = tStat.count;
ts.avgSize = tStat.GetAvgSize();
ts.minSize = tStat.GetMinSize();
ts.maxSize = tStat.GetMaxSize();
ts.avgTime = tStat.GetAvgTime();
ts.minTime = tStat.GetMinTime();
ts.maxTime = tStat.GetMaxTime();
if (auto it = typeMap.find(tid); it != typeMap.end()) {
ts.children.reserve(it->second.size());
for (auto &[mtype, mStat] : it->second) {
MemcpyOverallRes tts;
tts.key = mtype;
tts.name = mtype;
tts.level = 2;
tts.totalSize = mStat.totalSize;
tts.totalTime = mStat.totalTime;
tts.number = mStat.count;
tts.avgSize = mStat.GetAvgSize();
tts.minSize = mStat.GetMinSize();
tts.maxSize = mStat.GetMaxSize();
tts.avgTime = mStat.GetAvgTime();
tts.minTime = mStat.GetMinTime();
tts.maxTime = mStat.GetMaxTime();
ts.children.push_back(std::move(tts));
}
}
result.push_back(std::move(ts));
}
Paginator<MemcpyOverallRes> paginator(result, pageSize);
response.pageParam.total = paginator.GetTotal();
response.details = paginator.GetPage(current);
}
bool QueryMemcpyOverallHandler::CalMemcpyData(MemcpyOverallRequest &request, MemcpyOverallResponse &response,
std::string &error, const std::shared_ptr<VirtualTraceDatabase> &database) {
const MemcpyOverallDatabaseAccesser accesser(database, request.fileId);
std::vector<MemcpyRecord> records;
if (!accesser.GetMemcpyRecords(request.params.startTime, request.params.endTime, records)) {
error = "Failed to query memcpy statistics.";
return false;
}
BuildMemcpyOverallResult(records, response, request.params.page.current, request.params.page.pageSize);
response.pageParam.current = request.params.page.current;
response.pageParam.pageSize = request.params.page.pageSize;
return true;
}
}