/*
 * -------------------------------------------------------------------------
 * This file is part of the MindStudio project.
 * Copyright (c) 2025 Huawei Technologies Co.,Ltd.
 *
 * MindStudio is licensed under Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *
 *          http://license.coscl.org.cn/MulanPSL2
 *
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 * -------------------------------------------------------------------------
 */

#include "BandwidthContentionAnalyzer.h"

#include <utility>

#include "DataBaseManager.h"
#include "ServerLog.h"

namespace Dic {
namespace Module {
namespace Communication {
bool BandwidthContentionAnalyzer::QueryAdvisorData(const std::string &clusterPath) {
    std::vector<IterationsOrRanksObject> rankList;
    auto communicationDatabase = Timeline::DataBaseManager::Instance().GetClusterDatabase(clusterPath);
    if (!communicationDatabase || !communicationDatabase->QueryRanksHandler(rankList)) {
        Server::ServerLog::Error("Failed to get ranks data when query bandwidth contention data.");
        return false;
    }

    for (const auto &rank : rankList) {
        data.matMulData.insert({rank.iterationOrRankId, {}});
        data.SDMAData.insert({rank.iterationOrRankId, {}});
        auto summaryDatabase = Timeline::DataBaseManager::Instance().GetSummaryDatabaseByRankId(rank.iterationOrRankId);
        if (!summaryDatabase) {
            summaryDatabase = Timeline::DataBaseManager::Instance().GetSummaryDatabaseWithCluster(
                clusterPath, rank.iterationOrRankId);
        }
        if (!summaryDatabase) {
            Server::ServerLog::Error("Failed to get summary database connection.");
            continue;
        }
        summaryDatabase->QueryBandwidthContentionMatMulData(data.matMulData[rank.iterationOrRankId]);
        communicationDatabase->QueryBandwidthContentionAnalyzerData(
            data.SDMAData[rank.iterationOrRankId], rank.iterationOrRankId);
    }
    return true;
}

/**
 * Detects bandwidth contention: SDMA operations that overlap in time with a MatMul operation on
 * the same rank and whose bandwidth is strictly below BANDWIDTH_CONTENTION_ANALYZER_THRESHOLD.
 *
 * One BandwidthContentionAnalyzerStatistics record (rank, name, duration, bandwidth of the SDMA op)
 * is appended to `statistics` per such SDMA operation; existing entries in `statistics` are kept.
 * Intervals are treated as closed, so an SDMA op ending exactly when a MatMul starts (or vice versa)
 * counts as overlapping. Ranks without SDMA data are skipped.
 *
 * The two-pointer sweep only finds every overlap if both per-rank lists are ordered by ascending
 * startTime. NOTE(review): presumably guaranteed by the database queries that fill `data` — confirm.
 */
void BandwidthContentionAnalyzer::ComputeStatistics() {
    for (const auto &rankEntry : data.matMulData) {
        const auto &rankId = rankEntry.first;
        const auto &matMulOps = rankEntry.second;
        // find() instead of operator[]: resolve the list once and do not insert empty entries.
        const auto sdmaIt = data.SDMAData.find(rankId);
        if (sdmaIt == data.SDMAData.end()) {
            continue;
        }
        const auto &sdmaOps = sdmaIt->second;

        size_t sdmaIndex = 0;
        size_t matMulIndex = 0;
        while (sdmaIndex < sdmaOps.size() && matMulIndex < matMulOps.size()) {
            const auto &sdmaOp = sdmaOps[sdmaIndex];
            const auto &matMulOp = matMulOps[matMulIndex];
            if (sdmaOp.startTime + sdmaOp.duration < matMulOp.startTime) {
                // SDMA op ends before this MatMul starts, hence before every later MatMul too.
                ++sdmaIndex;
                continue;
            }
            if (matMulOp.startTime + matMulOp.duration < sdmaOp.startTime) {
                // MatMul ends before this SDMA op starts, hence before every later SDMA op too.
                ++matMulIndex;
                continue;
            }
            // The intervals overlap; report the SDMA op at most once, then move past it.
            if (sdmaOp.bandwidth < BANDWIDTH_CONTENTION_ANALYZER_THRESHOLD) {
                BandwidthContentionAnalyzerStatistics op;
                op.rankId = rankId;
                op.name = sdmaOp.name;
                op.duration = sdmaOp.duration;
                op.bandwidth = sdmaOp.bandwidth;
                statistics.emplace_back(std::move(op));
            }
            ++sdmaIndex;
        }
    }
}

/**
 * Copies the findings gathered by ComputeStatistics() into @p info.
 *
 * Sets the advisor title and appends one row per entry of `statistics` to the columns
 * "rankId", "name", "duration(us)" and "bandwidth(GB/s)" of info.statistics. The columns are
 * created if missing (so all four exist even without findings); existing column contents are kept.
 *
 * @param info advisor record to populate
 */
void BandwidthContentionAnalyzer::AssembleAdvisor(Dic::Protocol::CommunicationAdvisorInfo &info) {
    info.name = BANDWIDTHCONTENTION_ANALYZER_TITLE;
    // Resolve each column once rather than four string-keyed lookups per row. operator[] inserts an
    // empty column when absent and leaves an existing one untouched. Holding references across the
    // later insertions assumes info.statistics is a std::map / std::unordered_map (element references
    // survive insertion there) — NOTE(review): confirm against the protocol definition.
    auto &rankIdColumn = info.statistics["rankId"];
    auto &nameColumn = info.statistics["name"];
    auto &durationColumn = info.statistics["duration(us)"];
    auto &bandwidthColumn = info.statistics["bandwidth(GB/s)"];
    for (const auto &item : statistics) {
        rankIdColumn.emplace_back(item.rankId);
        nameColumn.emplace_back(item.name);
        durationColumn.emplace_back(std::to_string(item.duration));
        bandwidthColumn.emplace_back(std::to_string(item.bandwidth));
    }
}
}
}
}