#include "TritonParser.h"
#include <algorithm>
#include <array>
#include "FileUtil.h"
#include "JsonUtil.h"
#include "ParserStatusManager.h"
#include "ServerLog.h"
#include "ThreadPool.h"
#include "ProtocolDefs.h"
#include "TimelineProtocolEvent.h"
#include "TritonProtocolEvent.h"
#include "TritonService.h"
#include "WsSender.h"
namespace Dic::Module::Triton {
using namespace Dic::Server;
namespace {
/**
 * Tells whether a JSON value is an object that carries every key in expectedKeys.
 *
 * @tparam N number of keys to look for.
 * @param jsonObj JSON value to inspect; anything that is not an object yields false.
 * @param expectedKeys member names that must all be present.
 * @return true only when jsonObj is an object and none of the keys is missing.
 */
template <size_t N> bool HasExpectedMembers(const json_t &jsonObj, const std::array<const char *, N> &expectedKeys) {
    if (!jsonObj.IsObject()) {
        return false;
    }
    // Stop at the first key that is absent.
    for (const char *requiredKey : expectedKeys) {
        if (!jsonObj.HasMember(requiredKey)) {
            return false;
        }
    }
    return true;
}
} // namespace
/**
 * Gives access to the one TritonParser object held in a function-local static.
 *
 * @return reference to that object; it is constructed the first time this function runs.
 */
TritonParser &TritonParser::Instance() {
    static TritonParser parser;
    return parser;
}
/**
 * Runs one full parse cycle for parseDir: preparation, parsing, result notification.
 *
 * The parsing step (ParseImpl) is submitted to the ThreadPool together with the current trace id;
 * this function then retrieves the task's result and hands it to AfterParse.
 *
 * @param parseDir path handed to BeforeParse and ParseImpl.
 */
void TritonParser::Parse(const std::string &parseDir) {
    Timeline::ParserStatusManager::Instance().WaitStartParse();
    BeforeParse(parseDir);
    // parseDir is captured by value so the task owns its own copy of the path.
    auto pending = ThreadPool::Instance().AddTask(
        [this, parseDir]() { return ParseImpl(parseDir); }, TraceIdManager::GetTraceId());
    // NOTE(review): presumably a std::future, in which case get() blocks until the task is done — confirm.
    AfterParse(pending.get());
}
/**
 * Checks whether filePath names the Triton memory file.
 *
 * The visible code does not consult any parse state: it only requires a non-empty path that passes
 * the read-security check and whose file name equals tritonMemFileName.
 *
 * @param filePath path to test.
 * @return true when all of the conditions above hold, false otherwise.
 */
bool TritonParser::IsParsed(const std::string &filePath) const {
    // Short-circuit keeps the security check from running on an empty path.
    const bool pathUsable = !filePath.empty() && FileUtil::CheckPathSecurity(filePath, CHECK_FILE_READ);
    if (!pathUsable) {
        return false;
    }
    const std::string baseName = FileUtil::GetFileName(filePath);
    return baseName == tritonMemFileName;
}
/**
 * Prepares parser and service state for a new parse run.
 *
 * Resets TritonService, empties the list of queued paths and, if parsedDir passes the read-security
 * check, queues it as the only entry of parsedFiles.
 *
 * @param parsedDir path requested for parsing; it is not queued when the security check rejects it.
 */
void TritonParser::BeforeParse(const std::string &parsedDir) {
    TritonService::Instance().Reset();
    // Clear before the security check: if the new path is rejected, the entry left over from a
    // previous run must not stay queued, otherwise ParseImpl would parse that stale path and
    // success would be reported for a request that was actually refused.
    parsedFiles.clear();
    if (!FileUtil::CheckPathSecurity(parsedDir, CHECK_FILE_READ)) {
        ServerLog::Error("Triton file dir is not safe, please check log for more information");
        return;
    }
    parsedFiles.push_back(parsedDir);
}
void TritonParser::AfterParse(const ParseResult &result) const {
if (!result) {
auto event = std::make_unique<Protocol::ParseFailEvent>();
event->moduleName = Protocol::MODULE_TIMELINE;
event->result = false;
std::string path = parsedFiles.empty() ? "" : parsedFiles.front();
event->body.rankId = path;
event->body.error = result.GetErrorMsg();
event->body.dbPath = path;
Dic::SendEvent(std::move(event));
return;
}
auto event = std::make_unique<Protocol::TritonParseSuccessEvent>();
event->moduleName = Protocol::MODULE_TRITON;
event->result = true;
Protocol::TritonParseSuccessEventBody body;
event->body = body;
Dic::SendEvent(std::move(event));
}
/**
 * Parses every path queued in parsedFiles and stops at the first one that fails.
 *
 * @param parsedDir not used here; the paths to parse come from parsedFiles.
 * @return {false, "Not found triton file"} when nothing is queued, the failing ParseOneTriton
 *         result when a file cannot be parsed, otherwise {true, "success"}.
 */
ParseResult TritonParser::ParseImpl([[maybe_unused]] const std::string &parsedDir) {
    if (parsedFiles.empty()) {
        ServerLog::Error("Not found need parsed File.");
        return {false, "Not found triton file"};
    }
    for (const std::string &filePath : parsedFiles) {
        ParseResult result = ParseOneTriton(filePath);
        // Propagate the failure instead of discarding it, so the caller can report it.
        if (!result) {
            ServerLog::Error("Failed to parse triton file.");
            return result;
        }
    }
    return {true, "success"};
}
/**
 * Checks that fileName is non-empty and passes the read-security check.
 *
 * @param fileName path of the Triton file to check.
 * @param error receives a description of the problem when the check fails; left untouched on success.
 * @return true when the file name is acceptable, false otherwise.
 */
bool TritonParser::CheckFileValid(const std::string &fileName, std::string &error) {
    if (fileName.empty()) {
        error = "Triton file name is required";
        return false;
    }
    if (!FileUtil::CheckPathSecurity(fileName, CHECK_FILE_READ)) {
        error = "Triton file not satisfy safety requirement, please check the log for more information";
        return false;
    }
    return true;
}
ParseResult TritonParser::ParseOneTriton(const std::string &memFile) {
document_t jsonDoc = JsonUtil::ReadJsonFromFile(memFile);
if (!CheckDataValid(jsonDoc)) {
return {false, "Invalid Data"};
}
TritonMemeHeader header;
auto &headerJson = jsonDoc["Header"];
JsonUtil::SetByJsonKeyValue(header.kernelName, headerJson, "KernelName");
std::map<std::string, TritonRecord> scopeMap;
auto &jsonRecord = jsonDoc["Record"];
for (const json_t &recordItem : jsonRecord.GetArray()) {
TritonRecord tritonRecord;
std::string status, errMsg, scope;
JsonUtil::SetByJsonKeyValue(scope, recordItem, "scope");
JsonUtil::SetByJsonKeyValue(status, recordItem, "status");
JsonUtil::SetByJsonKeyValue(errMsg, recordItem, "err_msg");
TritonService::Instance().UpdateCompileInfo(scope, {status, errMsg});
header.memTypes.push_back(scope);
auto &memInfoArray = recordItem["memory_info_array"];
tritonRecord.segments.reserve(memInfoArray.Size());
for (const json_t &memInfo : memInfoArray.GetArray()) {
TritonTensorSegment segment;
JsonUtil::SetByJsonKeyValue(segment.allocate, memInfo, "alloc_time_in_ir");
JsonUtil::SetByJsonKeyValue(segment.buffer, memInfo, "buffer");
JsonUtil::SetByJsonKeyValue(segment.sourceLocation, memInfo, "source_location");
JsonUtil::SetByJsonKeyValue(segment.tmpBuf, memInfo, "is_tmpbuf");
auto lifeTime = JsonUtil::GetVector<uint64_t>(memInfo, "life_time_in_ir");
if (lifeTime.size() >= 2) {
segment.start = lifeTime[0];
segment.end = lifeTime[1];
} else {
segment.start = 0;
segment.end = 0;
}
uint64_t extend = memInfo["extent"].GetUint64() / 8;
uint64_t blockCount = memInfo["offset"].Size();
segment.size = extend * blockCount;
segment.blocks.reserve(blockCount);
for (const auto &block : memInfo["offset"].GetArray()) {
TritonTensorBlock blockData(segment);
blockData.offset = block.GetUint64();
blockData.size = extend;
segment.blocks.emplace_back(std::move(blockData));
}
tritonRecord.segments.emplace_back(std::move(segment));
}
scopeMap[scope] = std::move(tritonRecord);
}
TritonService::Instance().SetHeader(std::move(header));
TritonService::Instance().UpdateRecord(std::move(scopeMap));
return {true, "Success"};
}
bool TritonParser::CheckDataValid(document_t &json) {
if (json.IsNull()) {
return false;
}
if (!json.IsObject()) {
return false;
}
constexpr std::array<const char *, 2> rootKeys = {"Header", "Record"};
if (!HasExpectedMembers(json, rootKeys)) {
ServerLog::Error("Triton json root required keys are missing");
return false;
}
const json_t &header = json["Header"];
constexpr std::array<const char *, 1> headerKeys = {"KernelName"};
if (!HasExpectedMembers(header, headerKeys)) {
ServerLog::Error("Triton Header required keys are missing");
return false;
}
if (!header["KernelName"].IsString()) {
ServerLog::Error("Triton Header value types are invalid: KernelName must be string");
return false;
}
if (!json["Record"].IsArray()) {
ServerLog::Error("Triton Record must be array");
return false;
}
constexpr std::array<const char *, 3> recordItemKeys = {"scope", "status", "memory_info_array"};
constexpr std::array<const char *, 7> memoryInfoKeys = {
"alloc_time_in_ir", "buffer", "extent", "is_tmpbuf", "life_time_in_ir", "offset", "source_location"};
for (const json_t &recordItem : json["Record"].GetArray()) {
if (!HasExpectedMembers(recordItem, recordItemKeys)) {
ServerLog::Error("Triton Record item required keys are missing");
return false;
}
bool recordTypesOk = recordItem["scope"].IsString() && recordItem["status"].IsString() &&
recordItem["memory_info_array"].IsArray();
if (!recordTypesOk) {
ServerLog::Error("Triton Record item value types are invalid");
return false;
}
for (const json_t &memInfo : recordItem["memory_info_array"].GetArray()) {
if (!HasExpectedMembers(memInfo, memoryInfoKeys)) {
ServerLog::Error("Triton memory_info item required keys are missing");
return false;
}
bool typesOk = memInfo["alloc_time_in_ir"].IsInt64() && memInfo["buffer"].IsString() &&
memInfo["extent"].IsInt64() && memInfo["is_tmpbuf"].IsBool() && memInfo["offset"].IsArray() &&
memInfo["source_location"].IsString();
if (!typesOk) {
ServerLog::Error("Triton memory_info item value types are invalid");
return false;
}
const json_t &lifeTime = memInfo["life_time_in_ir"];
if (!lifeTime.IsArray()) {
ServerLog::Error("Triton life_time_in_ir format is invalid");
return false;
}
}
}
return true;
}
}