/*
 * -------------------------------------------------------------------------
 * This file is part of the MindStudio project.
 * Copyright (c) 2025 Huawei Technologies Co.,Ltd.
 *
 * MindStudio is licensed under Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *
 *          http://license.coscl.org.cn/MulanPSL2
 *
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 * -------------------------------------------------------------------------
 */
#include "MetaDataParser.h"
#include "ServerLog.h"
#include "JsonUtil.h"
#include "FileReader.h"
#include "NumberUtil.h"

namespace Dic::Module::Timeline {
std::vector<ParallelGroupInfo> MetaDataParser::ParserParallelGroupInfoByText(const std::string &text) {
    std::vector<ParallelGroupInfo> res;
    if (text.empty()) {
        return res;
    }
    try {
        std::string error;
        auto groupInfoJson = JsonUtil::TryParse<kParseNumbersAsStringsFlag>(text, error);
        if (!error.empty()) {
            Server::ServerLog::Error("Fail to convert parallel group info, error: ", error);
            return res;
        }
        return ConvertGroupInfoJsonToObject(groupInfoJson.value());
    } catch (const std::exception &e) {
        Server::ServerLog::Error("Fail to parser parallel group info. Error: ", e.what());
        return res;
    }
}

std::vector<ParallelGroupInfo> MetaDataParser::ParserParallelGroupInfoByFilePath(const std::string &filePath) {
    FileReader reader;
    std::string fileContext = reader.ReadJsonArray(filePath, 0, 0);
    std::vector<ParallelGroupInfo> res;
    if (fileContext.empty()) {
        Server::ServerLog::Error("Fail to read meta data file.");
        return res;
    }
    try {
        std::string error;
        auto metaDataJsonOpt = JsonUtil::TryParse<kParseNumbersAsStringsFlag>(fileContext, error);
        if (!error.empty()) {
            Server::ServerLog::Error("Fail to parser meta data file, error: ", error);
            return res;
        }
        if (!metaDataJsonOpt.value().IsObject()) {
            Server::ServerLog::Error("Fail to parser meta data file, data in wrong format.");
            return res;
        }
        if (metaDataJsonOpt.value().HasMember("parallel_group_info")) {
            document_t groupInfo;
            groupInfo.CopyFrom(metaDataJsonOpt.value()["parallel_group_info"], groupInfo.GetAllocator());
            res = ConvertGroupInfoJsonToObject(groupInfo);
        }
        return res;
    } catch (const std::exception &e) {
        Server::ServerLog::Error("Fail to parser meta data file context. Error: ", e.what());
        return res;
    }
}

std::optional<DistributedArgs> MetaDataParser::ParserDistributedArgsByFilePath(const std::string &filePath) {
    FileReader reader;
    std::string fileContext = reader.ReadJsonArray(filePath, 0, 0);
    if (fileContext.empty()) {
        Server::ServerLog::Error("Fail to read meta data file.");
        return std::nullopt;
    }
    try {
        std::string error;
        auto metaDataJsonOpt = JsonUtil::TryParse(fileContext, error);
        if (!error.empty()) {
            Server::ServerLog::Error("Fail to parser meta data file, error: ", error);
            return std::nullopt;
        }
        if (!metaDataJsonOpt.value().IsObject()) {
            Server::ServerLog::Error("Fail to parser meta data file, data in wrong format.");
            return std::nullopt;
        }
        if (!metaDataJsonOpt.value().HasMember("distributed_args")) {
            return std::nullopt;
        }
        document_t distributedArgsInfo;
        distributedArgsInfo.CopyFrom(metaDataJsonOpt.value()["distributed_args"], distributedArgsInfo.GetAllocator());
        return ConvertDistributedArgsJsonToObject(distributedArgsInfo);
    } catch (const std::exception &e) {
        Server::ServerLog::Error("Fail to parser meta data file context. Error: ", e.what());
        return std::nullopt;
    }
}

std::vector<ParallelGroupInfo> MetaDataParser::ConvertGroupInfoJsonToObject(const document_t &json) {
    std::vector<ParallelGroupInfo> res;
    if (!json.IsObject()) {
        Server::ServerLog::Error("Fail to convert parallel group info, data in wrong format.");
        return res;
    }
    for (auto iter = json.MemberBegin(); iter != json.MemberEnd(); ++iter) {
        ParallelGroupInfo info;
        info.group = JsonUtil::GetStringWithoutKey(iter->name);
        info.globalRanks = JsonUtil::GetVector<std::string>(iter->value, GLOBAL_RANKS);
        info.groupName = JsonUtil::GetString(iter->value, GLOBAL_NAME);
        res.push_back(info);
    }
    return res;
}

std::optional<DistributedArgs> MetaDataParser::ConvertDistributedArgsJsonToObject(const Dic::document_t &json) {
    if (!json.IsObject()) {
        Server::ServerLog::Error("Value of key distributed_args in profiler_metadata.json is not valid json format.");
        return std::nullopt;
    }
    for (const auto &item : DISTRIBUTED_ARGS_INT_KEY) {
        if (!json.HasMember(item.c_str()) || !json[item.c_str()].IsInt64()) {
            Server::ServerLog::Error("Value of key distributed_args in profiler_metadata.json lacks ", item,
                " key or "
                "value of this key is not of int type.");
            return std::nullopt;
        }
    }
    for (const auto &item : DISTRIBUTED_ARGS_BOOL_KEY) {
        if (!json.HasMember(item.c_str()) || !json[item.c_str()].IsBool()) {
            Server::ServerLog::Error("Value of key distributed_args in profiler_metadata.json lacks ", item,
                " key or "
                "value of this key is not of bool type.");
            return std::nullopt;
        }
    }
    DistributedArgs args;
    args.config.tpSize = NumberUtil::IntToUint32(json["tensor_model_parallel_size"].GetInt());
    args.config.ppSize = NumberUtil::IntToUint32(json["pipeline_model_parallel_size"].GetInt());
    args.config.dpSize = NumberUtil::IntToUint32(json["data_parallel_size"].GetInt());
    args.config.cpSize = NumberUtil::IntToUint32(json["context_parallel_size"].GetInt());
    args.config.epSize = NumberUtil::IntToUint32(json["expert_model_parallel_size"].GetInt());
    args.worldSize = NumberUtil::IntToUint32(json["world_size"].GetInt());
    args.sequenceParallel = json["sequence_parallel"].GetBool();
    return std::optional<DistributedArgs>(args);
}
}