* Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
*/
#include <iostream>
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "safe_io.h"
namespace mindie_llm {
constexpr int TestJsonDepth = 10;
class TestSafe : public ::testing::Test {
void SetUp() override {
}
};
std::string TestCreatNestedJson(int depth)
{
std::string out;
for (int i = 0; i < depth; i++) {
out += "{\"a\":";
}
out += "\"b\"";
for (int i = 0; i < depth; i++) {
out += "}";
}
return out;
}
std::string TestCreatNestedArrayJson(int depth)
{
std::string out;
for (int i = 0; i < depth; i++) {
out += "[";
}
out += "\"b\"";
for (int i = 0; i < depth; i++) {
out += "]";
}
return out;
}
int TestGetJsonDepth(const nlohmann::json& j)
{
if (j.is_object()) {
int maxLevel = 0;
for (auto& element : j.items()) {
if (element.value().is_object()) {
int subLevel = TestGetJsonDepth(element.value());
if (subLevel > maxLevel) {
maxLevel = subLevel;
}
}
}
return maxLevel + 1;
} else if (j.is_array()) {
int maxLevel = 0;
for (auto& element : j) {
if (element.is_array()) {
int subLevel = TestGetJsonDepth(element);
if (subLevel > maxLevel) {
maxLevel = subLevel;
}
}
}
return maxLevel + 1;
} else {
return 0;
}
}
TEST_F(TestSafe, TestCheckJsonDepthCallback)
{
nlohmann::json dummObj{};
ASSERT_EQ(true, CheckJsonDepthCallbackNoLogger(TestJsonDepth, Json::parse_event_t::object_start, dummObj));
ASSERT_EQ(true, CheckJsonDepthCallbackNoLogger(TestJsonDepth, Json::parse_event_t::array_start, dummObj));
ASSERT_EQ(true, CheckJsonDepthCallbackNoLogger(TestJsonDepth, Json::parse_event_t::object_end, dummObj));
ASSERT_EQ(true, CheckJsonDepthCallbackNoLogger(TestJsonDepth, Json::parse_event_t::array_end, dummObj));
ASSERT_EQ(true, CheckJsonDepthCallbackNoLogger(TestJsonDepth, Json::parse_event_t::key, dummObj));
ASSERT_EQ(true, CheckJsonDepthCallbackNoLogger(TestJsonDepth, Json::parse_event_t::value, dummObj));
ASSERT_EQ(false, CheckJsonDepthCallbackNoLogger(TestJsonDepth + 1, Json::parse_event_t::object_start, dummObj));
ASSERT_EQ(false, CheckJsonDepthCallbackNoLogger(TestJsonDepth + 1, Json::parse_event_t::array_start, dummObj));
ASSERT_EQ(true, CheckJsonDepthCallbackNoLogger(TestJsonDepth + 1, Json::parse_event_t::object_end, dummObj));
ASSERT_EQ(true, CheckJsonDepthCallbackNoLogger(TestJsonDepth + 1, Json::parse_event_t::array_end, dummObj));
ASSERT_EQ(true, CheckJsonDepthCallbackNoLogger(TestJsonDepth + 1, Json::parse_event_t::key, dummObj));
ASSERT_EQ(true, CheckJsonDepthCallbackNoLogger(TestJsonDepth + 1, Json::parse_event_t::value, dummObj));
std::vector<int> testCase = {10, 20};
for (int i = 0; i < testCase.size(); i++) {
int depth = testCase[i];
SetJsonDepthLimit(depth);
ASSERT_EQ(depth, GetJsonDepthLimit());
ASSERT_EQ(true, CheckJsonDepthCallbackNoLogger(depth, Json::parse_event_t::object_start, dummObj));
ASSERT_EQ(true, CheckJsonDepthCallbackNoLogger(depth, Json::parse_event_t::array_start, dummObj));
ASSERT_EQ(true, CheckJsonDepthCallbackNoLogger(depth, Json::parse_event_t::object_end, dummObj));
ASSERT_EQ(true, CheckJsonDepthCallbackNoLogger(depth, Json::parse_event_t::array_end, dummObj));
ASSERT_EQ(true, CheckJsonDepthCallbackNoLogger(depth, Json::parse_event_t::key, dummObj));
ASSERT_EQ(true, CheckJsonDepthCallbackNoLogger(depth, Json::parse_event_t::value, dummObj));
ASSERT_EQ(false, CheckJsonDepthCallbackNoLogger(depth + 1, Json::parse_event_t::object_start, dummObj));
ASSERT_EQ(false, CheckJsonDepthCallbackNoLogger(depth + 1, Json::parse_event_t::array_start, dummObj));
ASSERT_EQ(true, CheckJsonDepthCallbackNoLogger(depth + 1, Json::parse_event_t::object_end, dummObj));
ASSERT_EQ(true, CheckJsonDepthCallbackNoLogger(depth + 1, Json::parse_event_t::array_end, dummObj));
ASSERT_EQ(true, CheckJsonDepthCallbackNoLogger(depth + 1, Json::parse_event_t::key, dummObj));
ASSERT_EQ(true, CheckJsonDepthCallbackNoLogger(depth + 1, Json::parse_event_t::value, dummObj));
}
}
TEST_F(TestSafe, TestCheckJsonDepthCallbackWithParse)
{
std::string strNest10 = TestCreatNestedJson(TestJsonDepth);
auto objBase = nlohmann::json::parse(strNest10);
int depBase = TestGetJsonDepth(objBase);
ASSERT_EQ(depBase, TestJsonDepth);
auto objNest10 = nlohmann::json::parse(strNest10, CheckJsonDepthCallbackNoLogger);
int depNest10 = TestGetJsonDepth(objNest10);
ASSERT_EQ(depNest10, TestJsonDepth);
std::string strNest11 = TestCreatNestedJson(TestJsonDepth + 1);
auto objNest11 = nlohmann::json::parse(strNest11, CheckJsonDepthCallbackNoLogger);
int depNest11 = TestGetJsonDepth(objNest11);
ASSERT_EQ(depNest11, TestJsonDepth + 1);
std::string strArr10 = TestCreatNestedArrayJson(TestJsonDepth);
auto arrBase = nlohmann::json::parse(strArr10);
depBase = TestGetJsonDepth(arrBase);
ASSERT_EQ(depBase, TestJsonDepth);
auto arrNest10 = nlohmann::json::parse(strArr10, CheckJsonDepthCallbackNoLogger);
int depArr10 = TestGetJsonDepth(arrNest10);
ASSERT_EQ(depArr10, TestJsonDepth);
std::string strArr11 = TestCreatNestedArrayJson(TestJsonDepth + 1);
auto arrNest11 = nlohmann::json::parse(strArr11, CheckJsonDepthCallbackNoLogger);
int depArr11 = TestGetJsonDepth(arrNest11);
ASSERT_EQ(depArr11, TestJsonDepth + 1);
}
}