* This file is part of the MindStudio project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
*/
#include "memory_compare.h"
#include <algorithm>
#include <cerrno>
#include <cstdio>
#include <fstream>
#include <sstream>
#include <utility>
#include "file.h"
#include "utils.h"
#include "config_info.h"
#include "record_info.h"
#include "ustring.h"
#include "bit_field.h"
namespace MemScope {
// Returns the process-wide MemoryCompare object.
// The function-local static is constructed on the first call only, so the `config`
// passed on any later call is ignored.
MemoryCompare& MemoryCompare::GetInstance(Config config)
{
    static MemoryCompare compareInstance{config};
    return compareInstance;
}
// Stores a copy of the run configuration (level type, output directory, ...) for later use.
MemoryCompare::MemoryCompare(Config config) : config_(config)
{
}
// Reads one CSV field from `ss` and consumes the separating comma, if present.
// A field that starts with '"' is parsed as an RFC 4180 quoted field: it ends at the
// matching closing quote, may contain commas, and a doubled quote ("") inside it
// yields a single literal '"'. An unquoted field runs up to the next comma or the
// end of the stream.
std::string MemoryCompare::ReadQuotedField(std::stringstream& ss)
{
    std::string field;
    if (ss.peek() != '"') {
        std::getline(ss, field, ',');
        return field;
    }
    ss.get();  // opening quote
    char ch = 0;
    while (ss.get(ch)) {
        if (ch != '"') {
            field.push_back(ch);
            continue;
        }
        // Inside a quoted field, "" is an escaped quote; a lone quote closes the field.
        if (ss.peek() == '"') {
            ss.get();
            field.push_back('"');
            continue;
        }
        break;
    }
    if (ss.peek() == ',') {
        ss.get();
    }
    return field;
}
// Ordering predicate for CSV rows: ascending by the "Timestamp(ns)" column.
// A timestamp that cannot be parsed as uint64 is logged and treated as UINT64_MAX,
// so such rows sort after every valid one.
bool Compare(const std::unordered_map<std::string, std::string> &a,
    const std::unordered_map<std::string, std::string> &b)
{
    uint64_t lhs = 0;
    const std::string &lhsText = a.at("Timestamp(ns)");
    if (!Utility::StrToUint64(lhs, lhsText)) {
        LOG_WARN("StrToUint64 failed, the str is %s.", lhsText.c_str());
        lhs = UINT64_MAX;
    }

    uint64_t rhs = 0;
    const std::string &rhsText = b.at("Timestamp(ns)");
    if (!Utility::StrToUint64(rhs, rhsText)) {
        LOG_WARN("StrToUint64 failed, the str is %s.", rhsText.c_str());
        rhs = UINT64_MAX;
    }

    return lhs < rhs;
}
// Loads a comparison input file into `data`, grouped by device id.
// Only files whose last '.'-separated component is "csv" are accepted; anything else
// is logged as an unsupported format and leaves `data` untouched.
// After loading, each device's rows are ordered by timestamp (see Compare).
void MemoryCompare::ReadFile(std::string &path, std::unordered_map<DEVICEID, ORIGINAL_FILE_DATA> &data)
{
    std::vector<std::string> fileName;
    Utility::Split(path, std::back_inserter(fileName), ".");
    if (fileName.empty() || fileName.back() != "csv") {
        LOG_ERROR("The file %s is an unsupported format.", path.c_str());
        return;
    }
    LOG_INFO("Read csv file: %s.", path.c_str());
    ReadCsvFile(path, data);
    for (auto &pair : data) {
        // Stable sort: rows with equal timestamps keep their file order, which
        // GetMemoryUsage relies on when it scans forward from an op/kernel row.
        std::stable_sort(pair.second.begin(), pair.second.end(), Compare);
    }
}
bool MemoryCompare::CheckCsvHeader(std::string &path, std::ifstream& file, std::vector<std::string> &headerData)
{
if (!file.is_open()) {
LOG_ERROR("The path: %s open failed!", path.c_str());
return false;
}
std::string line;
getline(file, line);
std::string normalizedLine = NormalizeString(line);
if (normalizedLine + "\n" != std::string(MEMSCOPE_HEADERS)) {
return false;
}
Utility::Split(normalizedLine, std::back_inserter(headerData), ",");
return true;
}
// Returns `line` with every '\r' and '\n' removed and leading/trailing whitespace trimmed.
std::string MemoryCompare::NormalizeString(const std::string& line)
{
    std::string result = line;
    result.erase(
        std::remove_if(result.begin(), result.end(), [](unsigned char c) {
            return c == '\r' || c == '\n';
        }),
        result.end()
    );
    // std::isspace requires a value representable as unsigned char (or EOF). A plain char
    // holding a non-ASCII byte is negative on signed-char platforms, which would be UB.
    auto isSpace = [](char c) {
        return std::isspace(static_cast<unsigned char>(c)) != 0;
    };
    auto start = result.begin();
    auto end = result.end();
    while (start != end && isSpace(*start)) {
        ++start;
    }
    while (start != end && isSpace(*(end - 1))) {
        --end;
    }
    return std::string(start, end);
}
// True when `name` is one of the framework event types recognized by the comparison:
// "PTA" or "MINDSPORE" (exact, case-sensitive match).
bool IsSupportedFramework(const std::string& name)
{
    static const std::unordered_set<std::string> knownFrameworks{"PTA", "MINDSPORE"};
    return knownFrameworks.count(name) != 0;
}
// Parses the CSV file at `path` into `data`, keyed by the numeric "Device Id" column.
// Each data row becomes a column-name -> value map.
// - Rows whose device id is GD_INVALID_NUM, "host" or "N/A" are skipped; rows whose id is
//   not a uint64 are logged and skipped.
// - A row with the wrong number of fields, or a framework event type ("PTA"/"MINDSPORE")
//   that differs from the first one seen (framework_), clears `data` and aborts.
void MemoryCompare::ReadCsvFile(std::string &path, std::unordered_map<DEVICEID, ORIGINAL_FILE_DATA> &data)
{
    std::ifstream csvFile(path, std::ios::in);
    std::vector<std::string> headerData;
    if (!CheckCsvHeader(path, csvFile, headerData)) {
        LOG_ERROR("The headers of %s file is illegal!", path.c_str());
        return ;
    }
    std::string line;
    uint64_t countLine = 1;
    while (std::getline(csvFile, line)) {
        ++countLine;
        // The header line is CR-stripped before validation; do the same for data rows so
        // CRLF files do not leak '\r' into the last column.
        if (!line.empty() && line.back() == '\r') {
            line.pop_back();
        }
        std::vector<std::string> lineData;
        std::stringstream ss(line);
        while (ss.good()) {
            std::string singleValue = ReadQuotedField(ss);
            Utility::ToSafeString(singleValue);
            lineData.emplace_back(std::move(singleValue));
        }
        if (lineData.size() != headerData.size()) {
            // countLine is uint64_t: print it with a matching conversion (not %d).
            LOG_ERROR("The file %s on line %llu is invalid!", path.c_str(),
                static_cast<unsigned long long>(countLine));
            data.clear();
            return ;
        }
        std::unordered_map<std::string, std::string> tempLine;
        for (size_t index = 0; index < headerData.size(); ++index) {
            tempLine.insert({headerData[index], lineData[index]});
        }
        // References into an unordered_map stay valid across later insertions.
        const std::string &eventType = tempLine["Event Type"];
        if (IsSupportedFramework(eventType)) {
            if (framework_.empty()) {
                framework_ = eventType;
            }
            if (framework_ != eventType) {
                LOG_ERROR("The content of the file %s is invalid.", path.c_str());
                data.clear();
                return ;
            }
        }
        const std::string &deviceField = tempLine["Device Id"];
        if (deviceField == std::to_string(GD_INVALID_NUM) || deviceField == "host" || deviceField == "N/A") {
            continue;
        }
        uint64_t deviceId = 0;
        if (!Utility::StrToUint64(deviceId, deviceField)) {
            LOG_WARN("StrToUint64 failed, the str is %s.", deviceField.c_str());
            continue;
        }
        // tempLine is per-iteration, so it can be moved instead of copied.
        data[deviceId].emplace_back(std::move(tempLine));
    }
    csvFile.close();
}
// Appends (Name, Event, row index) to `dataList` for every row of `originData` whose
// "Event Type" belongs to the levels enabled in config_.levelType:
// - op level: ATB_END / ATEN_END (rejected with an error when framework_ is "MINDSPORE");
// - kernel level: KERNEL_LAUNCH.
// If a selected row's Name fails Utility::CheckStrIsStartsWithInvalidChar, `dataList`
// is cleared and the function returns.
void MemoryCompare::ReadNameIndexData(const ORIGINAL_FILE_DATA &originData, NAME_WITH_INDEX &dataList)
{
    LOG_DEBUG("Read kernelLaunch/op data.");
    std::unordered_set<std::string> eventMap;
    BitField<decltype(config_.levelType)> levelType(config_.levelType);
    if (levelType.checkBit(static_cast<size_t>(LevelType::LEVEL_OP))) {
        if (framework_ == "MINDSPORE") {
            LOG_ERROR("Comparison of the MindSpore framework under the op level is not supported.");
            return ;
        }
        eventMap.insert("ATB_END");
        eventMap.insert("ATEN_END");
    }
    if (levelType.checkBit(static_cast<size_t>(LevelType::LEVEL_KERNEL))) {
        eventMap.insert("KERNEL_LAUNCH");
    }
    // Const column lookup; a missing column reads as an empty string. This avoids
    // copying each row map just to use operator[].
    static const std::string kMissingField;
    auto fieldOf = [](const std::unordered_map<std::string, std::string> &row,
        const std::string &key) -> const std::string & {
        auto it = row.find(key);
        return it == row.end() ? kMissingField : it->second;
    };
    for (size_t index = 0; index < originData.size(); ++index) {
        const auto &lineData = originData[index];
        if (eventMap.find(fieldOf(lineData, "Event Type")) == eventMap.end()) {
            continue;
        }
        const std::string &name = fieldOf(lineData, "Name");
        if (Utility::CheckStrIsStartsWithInvalidChar(name.c_str())) {
            LOG_ERROR("Name %s is invalid!", name.c_str());
            dataList.clear();
            return ;
        }
        dataList.emplace_back(std::make_tuple(name, fieldOf(lineData, "Event"), index));
    }
}
// Finds the first row at or after `index` whose "Event Type" equals framework_, and
// parses the "size" entry of its "Attr" column into `memDiff`.
// - No such row: memDiff is set to 0.
// - "size" missing, or not parsable as int64: a warning is logged and memDiff is left
//   as the caller initialized it (on a parse failure, whatever StrToInt64 leaves).
void MemoryCompare::GetMemoryUsage(size_t index, const ORIGINAL_FILE_DATA &data, int64_t &memDiff)
{
    LOG_DEBUG("Get memorypool usage.");
    // Const column lookup; a missing column reads as an empty string. Scanned rows are
    // not copied.
    static const std::string kMissingField;
    auto fieldOf = [](const std::unordered_map<std::string, std::string> &row,
        const std::string &key) -> const std::string & {
        auto it = row.find(key);
        return it == row.end() ? kMissingField : it->second;
    };
    size_t found = data.size();
    for (size_t i = index; i < data.size(); ++i) {
        if (fieldOf(data[i], "Event Type") == framework_) {
            found = i;
            break;
        }
    }
    if (found == data.size()) {
        memDiff = 0;
        return ;
    }
    // Kept as non-const locals in case the Utility helpers take non-const references.
    std::string attr = fieldOf(data[found], "Attr");
    std::string attrKey = "size";
    std::string attrValue = Utility::ExtractAttrValueByKey(attr, attrKey);
    if (attrValue.empty()) {
        LOG_WARN("Attr has no \"size\" value");
        return ;
    }
    if (!Utility::StrToInt64(memDiff, attrValue)) {
        LOG_WARN("Alloc Size to int64_t failed!");
    }
}
// Writes every device's comparison rows (result_) to a newly created comparison CSV file.
// Returns false, with a log message, when there is nothing to write, when the file
// cannot be created, or when a write/flush fails.
bool MemoryCompare::WriteCompareDataToCsv()
{
    LOG_DEBUG("Write compare result data to csv file.");
    if (result_.empty()) {
        LOG_WARN("Empty comparison result data!");
        return false;
    }
    if (!Utility::FileCreateManager::GetInstance(config_.outputDir).CreateCsvFile(&compareFile_,
        GD_INVALID_NUM, MEMORY_COMPARE_FILE_PREFIX, COMPARE_DIR, std::string(STEP_INTER_HEADERS))) {
        LOG_ERROR("Create comparison csv file failed!");
        return false;
    }
    for (const auto& pair : result_) {
        const auto& rows = pair.second;
        // BuildDiff walks the Myers path from its end back to its start, so rows were
        // appended in reverse; emit them back-to-front without mutating result_.
        for (auto it = rows.rbegin(); it != rows.rend(); ++it) {
            if (fprintf(compareFile_, "%s\n", it->c_str()) < 0) {
                const int err = errno;  // fprintf sets errno on failure
                std::cout << "[msmemscope] Error: Fail to write data to csv file, errno:" << err << std::endl;
                return false;
            }
        }
    }
    // The handle otherwise stays open (and buffered) until this object is destroyed;
    // flush so the data is on disk once completion is reported.
    if (fflush(compareFile_) != 0) {
        const int err = errno;
        std::cout << "[msmemscope] Error: Fail to write data to csv file, errno:" << err << std::endl;
        return false;
    }
    return true;
}
// Builds one comparison CSV row and appends it to result_[deviceId]. Columns are:
// event, name, device id, base size, compare size, compare - base.
// A side whose name (tuple element 0) is empty is absent: its size column is "N/A" and
// it contributes 0 to the difference. When both sides are present, the name/event are
// taken from `compareData`. Sizes come from GetMemoryUsage at the row index (element 2).
void MemoryCompare::CalcuMemoryDiff(const DEVICEID deviceId,
    const std::tuple<std::string, std::string, size_t> &baseData,
    const std::tuple<std::string, std::string, size_t> &compareData)
{
    // Quote a value per RFC 4180 only when it contains a separator, a quote or a line
    // break. Values come from input fields that may themselves have been quoted, so
    // writing them raw could shift the output columns.
    auto toCsvField = [](const std::string &value) -> std::string {
        if (value.find_first_of(",\"\r\n") == std::string::npos) {
            return value;
        }
        std::string quoted = "\"";
        for (char c : value) {
            if (c == '"') {
                quoted += '"';
            }
            quoted += c;
        }
        quoted += '"';
        return quoted;
    };
    std::string name;
    std::string event;
    int64_t baseAllocSize = 0;
    int64_t compareAllocSize = 0;
    std::string baseMemDiff = "N/A";
    if (!std::get<0>(baseData).empty()) {
        name = std::get<0>(baseData);
        event = std::get<1>(baseData);
        GetMemoryUsage(std::get<2>(baseData), baseFileOriginData_[deviceId], baseAllocSize);
        baseMemDiff = std::to_string(baseAllocSize);
    }
    std::string compareMemDiff = "N/A";
    if (!std::get<0>(compareData).empty()) {
        name = std::get<0>(compareData);
        event = std::get<1>(compareData);
        GetMemoryUsage(std::get<2>(compareData), compareFileOriginData_[deviceId], compareAllocSize);
        compareMemDiff = std::to_string(compareAllocSize);
    }
    int64_t diffAllocSize = Utility::GetSubResult(compareAllocSize, baseAllocSize);
    std::string row = toCsvField(event);
    row += ",";
    row += toCsvField(name);
    row += "," + std::to_string(deviceId) + "," + baseMemDiff + "," + compareMemDiff;
    row += "," + std::to_string(diffAllocSize);
    result_[deviceId].emplace_back(std::move(row));
}
std::shared_ptr<PathNode> MemoryCompare::BuildPath(const NAME_WITH_INDEX &baseLists,
const NAME_WITH_INDEX &compareLists)
{
LOG_DEBUG("Start to build myers path.");
const int64_t n = static_cast<int64_t>(baseLists.size());
const int64_t m = static_cast<int64_t>(compareLists.size());
const int64_t max = m + n + 1;
const int64_t size = 1 + 2 * max;
const int64_t middle = size / 2;
std::vector<std::shared_ptr<PathNode>> diagonal(size, nullptr);
diagonal[middle + 1] = std::make_shared<Snake>(0, -1);
auto start_time = Utility::GetTimeMicroseconds();
for (int64_t d = 0; d < max; ++d) {
for (int64_t k = -d; k <= d; k += KSTEPSIZE) {
auto end_time = Utility::GetTimeMicroseconds();
if ((end_time - start_time) >= MAXLOOPTIME) {
LOG_ERROR("Memory comparison build path failed! Reaching maximum loop time limit!");
break;
}
int64_t kmiddle = middle + k;
int64_t kplus = kmiddle + 1;
int64_t kminus = kmiddle - 1;
int64_t i;
std::shared_ptr<PathNode> prev;
if ((k == -d) || (k != d && diagonal[kminus]->i < diagonal[kplus]->i)) {
i = diagonal[kplus]->i;
prev = diagonal[kplus];
} else {
i = diagonal[kminus]->i + 1;
prev = diagonal[kminus];
}
int64_t j = i - k;
diagonal[kminus] = nullptr;
std::shared_ptr<PathNode> node = std::make_shared<DiffNode>(i, j, prev);
while (i < n && j < m && (std::get<0>(baseLists[i]) == std::get<0>(compareLists[j]))) {
++i;
++j;
}
if (i > node->i) {
node = std::make_shared<Snake>(i, j, node);
}
diagonal[kmiddle] = node;
if (i >= n && j >= m) {
return diagonal[kmiddle];
}
}
}
return nullptr;
}
void MemoryCompare::BuildDiff(std::shared_ptr<PathNode> path, const DEVICEID deviceId,
const NAME_WITH_INDEX &baseLists, const NAME_WITH_INDEX &compareLists)
{
LOG_DEBUG("Start to build myers diff.");
if (path == nullptr) {
LOG_WARN("Empty myers path!");
return ;
}
auto start_time = Utility::GetTimeMicroseconds();
while (path && path->prev && path->prev->j >= 0) {
auto end_time = Utility::GetTimeMicroseconds();
if ((end_time - start_time) >= MAXLOOPTIME) {
LOG_ERROR("Memory compare build diff failed! Reaching maximum loop time limit!");
break;
}
if (path->IsSnake()) {
int endi = path->i;
int endj = path->j;
int beginj = path->prev->j;
for (int i = endi - 1, j = endj - 1; j >= beginj; --i, --j) {
CalcuMemoryDiff(deviceId, baseLists[i], compareLists[j]);
}
} else {
int i = path->i;
int j = path->j;
int prei = path->prev->i;
if (prei < i) {
CalcuMemoryDiff(deviceId, baseLists[i - 1], {});
} else {
CalcuMemoryDiff(deviceId, {}, compareLists[j - 1]);
}
}
path = path->prev;
}
}
// Diffs one device's base and compare name lists (BuildPath, then BuildDiff) and appends
// the resulting rows to result_. Logs a warning and does nothing when both lists are empty.
void MemoryCompare::MyersDiff(const DEVICEID deviceId, const NAME_WITH_INDEX &baseLists,
    const NAME_WITH_INDEX &compareLists)
{
    LOG_DEBUG("Start to compare with Myers algorithm.");
    const bool nothingToCompare = baseLists.empty() && compareLists.empty();
    if (!nothingToCompare) {
        BuildDiff(BuildPath(baseLists, compareLists), deviceId, baseLists, compareLists);
        return;
    }
    LOG_WARN("Device %s has empty kernelLaunch/op data!", std::to_string(deviceId).c_str());
}
void MemoryCompare::RunComparison(const std::vector<std::string> &paths)
{
LOG_INFO("Start to compare memory data.");
auto start_time = Utility::GetTimeMicroseconds();
std::string pathBase = paths[0];
std::string pathCompare = paths[1];
ReadFile(pathBase, baseFileOriginData_);
ReadFile(pathCompare, compareFileOriginData_);
if (baseFileOriginData_.empty() || compareFileOriginData_.empty()) {
std::cout << "[msmemscope] ERROR: Memory comparison failed!" << std::endl;
return ;
}
for (const auto& pair : baseFileOriginData_) {
deviceIdSet_.insert(pair.first);
}
for (const auto& pair : compareFileOriginData_) {
deviceIdSet_.insert(pair.first);
}
for (const auto& deviceId : deviceIdSet_) {
NAME_WITH_INDEX baseLists {};
NAME_WITH_INDEX compareLists {};
ReadNameIndexData(baseFileOriginData_[deviceId], baseLists);
ReadNameIndexData(compareFileOriginData_[deviceId], compareLists);
MyersDiff(deviceId, baseLists, compareLists);
}
if (!WriteCompareDataToCsv()) {
std::cout << "[msmemscope] ERROR: Memory comparison failed!" << std::endl;
} else {
auto end_time = Utility::GetTimeMicroseconds();
LOG_INFO("The memory comparison has been completed "
"in a total time of %.6f(s)", (end_time-start_time) / MICROSEC);
}
return ;
}
// Closes the comparison CSV handle, if one was obtained through CreateCsvFile in
// WriteCompareDataToCsv, and clears the pointer.
MemoryCompare::~MemoryCompare()
{
    if (compareFile_ == nullptr) {
        return;
    }
    std::fclose(compareFile_);
    compareFile_ = nullptr;
}
}