* -------------------------------------------------------------------------
* This file is part of the MindStudio project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
*/
#ifndef PROFILER_SERVER_TABLE_H
#define PROFILER_SERVER_TABLE_H
#include <cstdint>
#include <string>
#include <string_view>
#include <memory>
#include <vector>
#include <variant>
#include <functional>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include "SqlitePreparedStatement.h"
#include "SqliteResultSet.h"
#include "sqlite3.h"
#include "DataBaseManager.h"
#include "ServerLog.h"
namespace Dic::Module::Timeline {
using namespace Dic::Server;
// Sort direction accepted by Table::OrderBy.
enum class TableOrder {
    ASC = 0,  // rendered as " ASC " in the ORDER BY clause
    DESC = 1, // rendered as " DESC " in the ORDER BY clause
};
// SQL fragments accumulated by the Table builder until a query is executed.
// Member order is significant for aggregate initialization; do not reorder.
struct SqlStruct {
    std::string conditionStr{}; // " AND ..." predicates appended after "WHERE 1 = 1 "
    std::string orderByStr{};   // " ORDER BY ..." clause; empty when no ordering was requested
    std::string groupByStr{};   // " GROUP BY ..." clause; empty when no grouping was requested
    std::string selectStr{};    // "SELECT ..." column list; empty until the first Select call
};
template <typename T> class Table {
public:
Table() noexcept = default;
Table &Select(std::string_view str) {
if (std::empty(SelectStr())) {
SelectStr() = "SELECT " + std::string(str);
} else {
SelectStr() += "," + std::string(str);
}
auto it = GetAssignMap().find(str);
if (it != GetAssignMap().end()) {
AssignFuncs().emplace_back(it->second);
} else {
ServerLog::Warn("Select column is not exist");
}
return *this;
}
template <typename... Args> Table &Select(std::string_view str, const Args &...args) {
if (std::empty(SelectStr())) {
SelectStr() = "SELECT " + std::string(str);
} else {
SelectStr() += " , " + std::string(str);
}
auto it = GetAssignMap().find(str);
if (it != GetAssignMap().end()) {
AssignFuncs().emplace_back(it->second);
} else {
ServerLog::Warn("Select column is not exist");
}
Select(args...);
return *this;
}
Table &Eq(std::string_view str, std::variant<uint32_t, uint64_t, std::string> value) {
ConditionStr() += " AND " + std::string(str) + " = ? ";
Values().emplace_back(value);
return *this;
}
Table &NotEq(std::string_view str, std::variant<uint32_t, uint64_t, std::string> value) {
ConditionStr() += " AND " + std::string(str) + " != ? ";
Values().emplace_back(value);
return *this;
}
Table &Less(std::string_view str, std::variant<uint32_t, uint64_t, std::string> value) {
ConditionStr() += " AND " + std::string(str) + " < ? ";
Values().emplace_back(value);
return *this;
}
Table &LessEq(std::string_view str, std::variant<uint32_t, uint64_t, std::string> value) {
ConditionStr() += " AND " + std::string(str) + " <= ? ";
Values().emplace_back(value);
return *this;
}
Table &Greater(std::string_view str, std::variant<uint32_t, uint64_t, std::string> value) {
ConditionStr() += " AND " + std::string(str) + " > ? ";
Values().emplace_back(value);
return *this;
}
Table &GreaterEq(std::string_view str, std::variant<uint32_t, uint64_t, std::string> value) {
ConditionStr() += " AND " + std::string(str) + " >= ? ";
Values().emplace_back(value);
return *this;
}
Table &Like(std::string_view str, std::string value) {
ConditionStr() += " AND " + std::string(str) + " LIKE ? ";
Values().emplace_back(value);
return *this;
}
template <typename Y>
static inline constexpr bool is_one_of_basic_types =
std::disjunction_v<std::is_same<Y, uint32_t>, std::is_same<Y, uint64_t>, std::is_same<Y, std::string>>;
template <typename Y> std::string SaveParamListAndGetPlaceholderStr(const std::vector<Y> &inputList) {
static_assert(is_one_of_basic_types<Y>, "Fail to save param and get placeholder str, unknown type.");
std::string res;
for (size_t i = 0; i < inputList.size(); ++i) {
if (i == 0) {
res.append("?");
} else {
res.append(", ?");
}
Values().emplace_back(inputList[i]);
}
return res;
}
* 调用此函数需要先校验inputs不为空
* @param str
* @param inputs
* @return
*/
template <typename Y> Table &In(std::string_view str, const std::vector<Y> &inputs) {
ConditionStr() += " AND " + std::string(str) + " IN ( ";
std::string placeholderStr = SaveParamListAndGetPlaceholderStr(inputs);
ConditionStr() += placeholderStr + " ) ";
return *this;
}
template <typename Y> Table &NotIn(std::string_view str, const std::vector<Y> &inputs) {
ConditionStr() += " AND " + std::string(str) + " NOT IN ( ";
std::string placeholderStr = SaveParamListAndGetPlaceholderStr(inputs);
ConditionStr() += placeholderStr + " ) ";
return *this;
}
Table &OrderBy(std::string_view columnName, TableOrder order) {
if (std::empty(OrderByStr())) {
OrderByStr() = " ORDER BY " + std::string(columnName);
} else {
OrderByStr() += " , " + std::string(columnName);
}
if (order == TableOrder::DESC) {
OrderByStr() += " DESC ";
} else {
OrderByStr() += " ASC ";
}
return *this;
}
Table &GroupBy(std::string_view columnName) {
if (std::empty(GroupByStr())) {
GroupByStr() = " GROUP BY " + std::string(columnName);
} else {
GroupByStr() += " , " + std::string(columnName);
}
return *this;
}
virtual std::vector<T> ExcuteQuery(const std::string &fileId) {
std::vector<T> result;
ExcuteQuery(fileId, result);
return result;
}
virtual void ExcuteQuery(const std::string &fileId, std::vector<T> &result) {
auto database = DataBaseManager::Instance().GetTraceDatabaseByRankId(fileId);
if (database == nullptr) {
ClearThreadLocal();
return;
}
std::string sql =
SelectStr() + " FROM " + GetTableName() + " WHERE 1 = 1 " + ConditionStr() + OrderByStr() + GroupByStr();
auto stmt = database->CreatPreparedStatement(sql);
if (stmt == nullptr) {
ServerLog::Warn(GetTableName() + " Failed to get stmt.");
ClearThreadLocal();
return;
}
for (const auto &item : Values()) {
if (std::holds_alternative<uint32_t>(item)) {
stmt->BindParams(std::get<uint32_t>(item));
} else if (std::holds_alternative<uint64_t>(item)) {
stmt->BindParams(std::get<uint64_t>(item));
} else if (std::holds_alternative<std::string>(item)) {
stmt->BindParams(std::get<std::string>(item));
}
}
auto resultSet = stmt->ExecuteQuery();
if (resultSet == nullptr) {
ServerLog::Warn(GetTableName() + " Failed to get result set.", stmt->GetErrorMessage());
ClearThreadLocal();
return;
}
while (resultSet->Next()) {
T t;
for (const auto &item : AssignFuncs()) {
item(t, resultSet);
}
result.emplace_back(t);
}
ClearThreadLocal();
}
virtual void ExcuteQuery(sqlite3 *db, std::vector<T> &result) {
auto stmt = std::make_unique<SqlitePreparedStatement>(db);
if (stmt == nullptr) {
ServerLog::Warn(GetTableName() + " Failed to get stmt.");
ClearThreadLocal();
return;
}
std::string sql =
SelectStr() + " FROM " + GetTableName() + " WHERE 1 = 1 " + ConditionStr() + OrderByStr() + GroupByStr();
if (!stmt->Prepare(sql)) {
ServerLog::Error("Failed prepare sql. ", stmt->GetErrorMessage());
ClearThreadLocal();
return;
}
for (const auto &item : Values()) {
if (std::holds_alternative<uint32_t>(item)) {
stmt->BindParams(std::get<uint32_t>(item));
} else if (std::holds_alternative<uint64_t>(item)) {
stmt->BindParams(std::get<uint64_t>(item));
} else if (std::holds_alternative<std::string>(item)) {
stmt->BindParams(std::get<std::string>(item));
}
}
auto resultSet = stmt->ExecuteQuery();
if (resultSet == nullptr) {
ServerLog::Warn(GetTableName() + " Failed to get result set.", stmt->GetErrorMessage());
ClearThreadLocal();
return;
}
while (resultSet->Next()) {
T t;
for (const auto &item : AssignFuncs()) {
item(t, resultSet);
}
result.emplace_back(t);
}
ClearThreadLocal();
}
virtual uint64_t Count(const std::string &fileId) {
auto database = DataBaseManager::Instance().GetTraceDatabaseByRankId(fileId);
if (database == nullptr) {
ClearThreadLocal();
return 0;
}
uint64_t count = 0;
std::string sql = "SELECT COUNT(*) AS count FROM " + GetTableName() + " WHERE 1 = 1 " + ConditionStr();
auto stmt = database->CreatPreparedStatement(sql);
if (stmt == nullptr) {
ServerLog::Warn(GetTableName() + " Failed to get stmt.");
ClearThreadLocal();
return count;
}
for (const auto &item : Values()) {
if (std::holds_alternative<uint32_t>(item)) {
stmt->BindParams(std::get<uint32_t>(item));
} else if (std::holds_alternative<uint64_t>(item)) {
stmt->BindParams(std::get<uint64_t>(item));
} else if (std::holds_alternative<std::string>(item)) {
stmt->BindParams(std::get<std::string>(item));
}
}
auto resultSet = stmt->ExecuteQuery();
if (resultSet == nullptr) {
ServerLog::Warn(GetTableName() + " Failed to get result set.", stmt->GetErrorMessage());
ClearThreadLocal();
return count;
}
if (resultSet->Next()) {
count = resultSet->GetUint64("count");
}
ClearThreadLocal();
return count;
}
virtual std::string GetDbPath(const std::string &fileId) {
auto database = DataBaseManager::Instance().GetTraceDatabaseByRankId(fileId);
if (database == nullptr) {
std::string empty;
return empty;
}
const std::string nameKey = database->GetDbPath();
return nameKey;
}
virtual ~Table() = default;
protected:
using assign = std::function<void(T &, const std::unique_ptr<SqliteResultSet> &)>;
inline std::string &SelectStr() const { return GetSqlStruct().selectStr; }
inline std::string &ConditionStr() const { return GetSqlStruct().conditionStr; }
inline std::string &OrderByStr() const { return GetSqlStruct().orderByStr; }
inline std::string &GroupByStr() const { return GetSqlStruct().groupByStr; }
inline SqlStruct &GetSqlStruct() const {
thread_local SqlStruct sqlStruct;
return sqlStruct;
}
std::vector<std::variant<uint32_t, uint64_t, std::string>> &Values() {
thread_local std::vector<std::variant<uint32_t, uint64_t, std::string>> values;
return values;
}
std::vector<assign> &AssignFuncs() {
thread_local std::vector<assign> assignFuncs;
return assignFuncs;
}
* 每次调用完成需要重置threadlocal变量
*/
void ClearThreadLocal() {
AssignFuncs().clear();
Values().clear();
ConditionStr().clear();
SelectStr().clear();
OrderByStr().clear();
GroupByStr().clear();
}
virtual const std::unordered_map<std::string_view, assign> &GetAssignMap() = 0;
virtual const std::string &GetTableName() = 0;
};
}
#endif