* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is part of the MindStudio project.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------*/
#include "analysis/csrc/application/summary/task_time_assembler.h"
#include "analysis/csrc/infrastructure/dfx/error_code.h"
#include "analysis/csrc/domain/services/environment/context.h"
#include "analysis/csrc/domain/services/device_context/device_context.h"
namespace Analysis {
namespace Application {
using namespace Analysis::Utils;
using namespace Analysis::Domain::Environment;
using namespace Analysis::Viewer::Database;
// Creates the task-time summary assembler and defines the output column titles.
// Both arguments are forwarded unchanged to the SummaryAssembler base class.
TaskTimeAssembler::TaskTimeAssembler(const std::string &name, const std::string &profPath)
    : SummaryAssembler(name, profPath)
{
    // Keep this order in sync with the row built in AssembleTaskTime(); sortRes()
    // addresses the device column as index 0 and the start-time column as index 6.
    headers_ = {
        "Device_id",
        "kernel_name",
        "kernel_type",
        "stream_id",
        "task_id",
        "task_time(us)",
        "task_start(us)",
        "task_stop(us)"
    };
}
// Builds the task-time rows from the data held in dataInventory and writes them out.
//
// AscendTaskData is mandatory; TaskInfoData and HostTask are optional and only
// enrich the rows (op name / task type) when they are present.
//
// Returns:
//   DATA_NOT_EXIST   - the inventory holds no AscendTaskData vector.
//   ASSEMBLE_FAILED  - no row could be produced.
//   ASSEMBLE_SUCCESS - rows were sorted and handed to WriteToFile().
uint8_t TaskTimeAssembler::AssembleData(DataInventory &dataInventory)
{
    auto geTaskInfo = dataInventory.GetPtr<std::vector<TaskInfoData>>();
    auto hostTasks = dataInventory.GetPtr<std::vector<HostTask>>();
    auto deviceTasks = dataInventory.GetPtr<std::vector<AscendTaskData>>();

    // Without device-side task records there is nothing to report.
    if (deviceTasks == nullptr) {
        WARN("AscendTask not exists, can't export task time data.");
        return DATA_NOT_EXIST;
    }

    // Optional enrichment sources: index them when available, otherwise just warn.
    if (taskInfoPresent(geTaskInfo != nullptr)) {
        FormatTaskInfoData(*geTaskInfo);
    } else {
        WARN("No ge data collected, maybe the TaskInfo table is not created, now try to export data with no ge data");
    }
    if (hostTasks != nullptr) {
        FormatHostTaskData(*hostTasks);
    } else {
        WARN("No host data collected, maybe the HostTask table is not created, now try to export data with no host data");
    }

    AssembleTaskTime(*deviceTasks);
    if (!res_.empty()) {
        sortRes();
        WriteToFile(File::PathJoin({profPath_, OUTPUT_PATH, TASK_TIME_SUMMARY_NAME}), {});
        return ASSEMBLE_SUCCESS;
    }

    ERROR("Can't match any task time data, failed to generate task_time_*.csv");
    return ASSEMBLE_FAILED;
}
void TaskTimeAssembler::FormatTaskInfoData(const std::vector<TaskInfoData> &taskInfoData)
{
if (taskInfoData.empty()) {
WARN("task info data is empty, no task info data for task time, check ge info please.");
return;
}
for (const auto &taskInfoDatum : taskInfoData) {
TaskId taskId{static_cast<uint16_t>(taskInfoDatum.streamId), static_cast<uint16_t>(taskInfoDatum.batchId),
taskInfoDatum.taskId, taskInfoDatum.contextId, taskInfoDatum.deviceId};
formatedTaskInfo_.insert({taskId, {taskInfoDatum.opName, taskInfoDatum.taskType}});
}
}
void TaskTimeAssembler::FormatHostTaskData(const std::vector<HostTask> &hostTaskData)
{
if (hostTaskData.empty()) {
WARN("host task data is empty, no host task data for task time, check host task please.");
return;
}
for (const auto &taskInfoDatum : hostTaskData) {
TaskId taskId{static_cast<uint16_t>(taskInfoDatum.streamId), taskInfoDatum.batchId,
taskInfoDatum.taskId, taskInfoDatum.contextId, taskInfoDatum.deviceId};
opNameMap.insert({taskId, taskInfoDatum.kernelNameStr});
}
}
std::vector<AscendTaskData> FilterAscendTaskData(const std::vector<AscendTaskData> &ascendTaskDatas) {
struct TupleHash {
size_t operator()(const std::tuple<uint32_t, uint32_t, uint32_t>& key) const {
const size_t h1 = std::hash<uint32_t>()(std::get<0>(key));
const size_t h2 = std::hash<uint32_t>()(std::get<1>(key));
const size_t h3 = std::hash<uint32_t>()(std::get<2>(key));
return h1 ^ (h2 << 1) ^ (h3 << 2);
}
};
struct TupleEqual {
bool operator()(const std::tuple<uint32_t, uint32_t, uint32_t>& a,
const std::tuple<uint32_t, uint32_t, uint32_t>& b) const {
return std::get<0>(a) == std::get<0>(b) &&
std::get<1>(a) == std::get<1>(b) &&
std::get<2>(a) == std::get<2>(b);
}
};
typedef std::tuple<uint32_t, uint32_t, uint32_t> KeyType;
std::unordered_map<KeyType, std::vector<AscendTaskData>, TupleHash, TupleEqual> groups_dict;
groups_dict.reserve(ascendTaskDatas.size() / 2);
for (const auto& item : ascendTaskDatas) {
groups_dict[KeyType(item.streamId, item.taskId, item.batchId)].push_back(item);
}
std::vector<AscendTaskData> result;
size_t estimated_size = 0;
for (const auto& pair : groups_dict) {
estimated_size += pair.second.size();
}
result.reserve(estimated_size);
for (const auto& pair : groups_dict) {
const auto& tasks = pair.second;
if (tasks.size() == 1) {
result.push_back(tasks[0]);
continue;
}
bool mix = false;
bool staticGraph = true;
for (const auto& task : tasks) {
if (task.contextId > 0 && task.contextId != INVALID_CONTEXT_ID ) {
break;
}
if (task.contextId == 0) {
mix = true;
staticGraph = false;
}
}
if (mix || staticGraph) {
for (const auto& task : tasks) {
if (mix && task.contextId == 0) {
result.push_back(task);
}
if (staticGraph) {
result.push_back(task);
}
}
}
}
return result;
}
void TaskTimeAssembler::AssembleTaskTime(const std::vector<AscendTaskData> &ascendTaskData)
{
if (ascendTaskData.empty()) {
WARN("ascend task data is empty, no ascend task data for task time, check ascend task please.");
return;
}
const std::string DIVIDE_CHAR = "\t";
for (auto &ascendTaskDatum : FilterAscendTaskData(ascendTaskData)) {
TaskId taskId{static_cast<uint16_t>(ascendTaskDatum.streamId), static_cast<uint16_t>(ascendTaskDatum.batchId),
ascendTaskDatum.taskId, ascendTaskDatum.contextId, ascendTaskDatum.deviceId};
auto taskInfoIt = formatedTaskInfo_.find(taskId);
const std::string opName = [&]() -> std::string {
auto opNameIt = opNameMap.find(taskId);
if (opNameIt != opNameMap.end()) {
return opNameIt->second;
}
if (taskInfoIt != formatedTaskInfo_.end()) {
return taskInfoIt->second.first;
}
return NA;
}();
const std::string taskType = (taskInfoIt != formatedTaskInfo_.end() && taskInfoIt->second.second != NA)
? taskInfoIt->second.second
: ascendTaskDatum.taskType;
auto row = {std::to_string(ascendTaskDatum.deviceId),
opName, taskType, std::to_string(taskId.streamId),
std::to_string(taskId.taskId),
DivideByPowersOfTenWithPrecision(static_cast<uint64_t>(ascendTaskDatum.duration)),
DivideByPowersOfTenWithPrecision(ascendTaskDatum.timestamp) + DIVIDE_CHAR,
DivideByPowersOfTenWithPrecision(ascendTaskDatum.end) + DIVIDE_CHAR};
res_.emplace_back(row);
}
}
// Sorts the assembled rows by device id first and by task start time second.
//
// Fix: the previous comparator returned the start-time comparison in both branches,
// so the device column never acted as a sort key and rows of different devices
// were interleaved.
//
// NOTE(review): both keys are compared as strings (lexicographically), so device "10"
// sorts before device "2" and start times with a different number of digits would be
// misordered - confirm whether a numeric comparison is required.
void TaskTimeAssembler::sortRes()
{
    const size_t DEVICE_INDEX = 0;      // column "Device_id"
    const size_t START_TIME_INDEX = 6;  // column "task_start(us)"
    using RowType = std::vector<std::string>;
    std::sort(res_.begin(), res_.end(),
              [DEVICE_INDEX, START_TIME_INDEX](const RowType &lhs, const RowType &rhs) {
                  const std::string &lhsDevice = lhs.at(DEVICE_INDEX);
                  const std::string &rhsDevice = rhs.at(DEVICE_INDEX);
                  if (lhsDevice != rhsDevice) {
                      return lhsDevice < rhsDevice;
                  }
                  // Same device: fall back to the start time.
                  return lhs.at(START_TIME_INDEX) < rhs.at(START_TIME_INDEX);
              });
}
}
}