/**
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Description: Test of priority queue
*/
#include <atomic>
#include <cstring>
#include <memory>
#include <string>
#include <thread>
#include "datasystem/common/util/queue/priority_queue.h"
#include "datasystem/common/util/queue/queue.h"
#include "datasystem/common/util/random_data.h"
#include "datasystem/common/util/wait_post.h"
#include "ut/common.h"
namespace datasystem {
namespace ut {
// Fixture shared by every PriorityQueue test case in this file; it adds no state or hooks of its
// own on top of CommonTest.
class PriorityQueueTest : public CommonTest {
};
// Single-threaded round trip: fill the queue to capacity through Put/Add/Offer, check that a full
// queue answers further Add/Offer calls with K_TRY_AGAIN, drain it through Take/Remove/Poll, and
// check that an empty queue answers Remove/Poll with K_TRY_AGAIN as well.
TEST_F(PriorityQueueTest, SingleThreadPushAndPop)
{
    const size_t maxItems = 6;
    PriorityQueue<std::string> pq(maxItems);

    std::string msg1 = "Hello world1!";
    std::string msg2 = "Hello world2!";
    std::string msg3 = "Hello world3!";
    std::string msg4 = "Hello world4!";
    std::string msg5 = "Hello world5!";
    std::string msg6 = "Hello world6!";

    // Every insert entry point is called once with an lvalue and once with an rvalue.
    DS_ASSERT_OK(pq.Put(msg1));
    DS_ASSERT_OK(pq.Put(std::move(msg2)));
    DS_ASSERT_OK(pq.Add(msg3));
    DS_ASSERT_OK(pq.Add(std::move(msg4)));
    DS_ASSERT_OK(pq.Offer(msg5, 10));
    DS_ASSERT_OK(pq.Offer(std::move(msg6), 10));
    ASSERT_TRUE(pq.Full());

    // No slot is free any more, so Add and Offer are expected to give up with K_TRY_AGAIN.
    // The same string is deliberately reused for all four calls, including after std::move.
    std::string extra = "Hello world7!";
    Status rc = pq.Add(extra);
    ASSERT_EQ(rc.GetCode(), StatusCode::K_TRY_AGAIN);
    rc = pq.Add(std::move(extra));
    ASSERT_EQ(rc.GetCode(), StatusCode::K_TRY_AGAIN);
    rc = pq.Offer(extra, 10);
    ASSERT_EQ(rc.GetCode(), StatusCode::K_TRY_AGAIN);
    rc = pq.Offer(std::move(extra), 10);
    ASSERT_EQ(rc.GetCode(), StatusCode::K_TRY_AGAIN);

    // The items must come back in descending string order, whichever removal API is used.
    std::string out;
    DS_ASSERT_OK(pq.Take(&out));
    ASSERT_EQ(out, "Hello world6!");
    DS_ASSERT_OK(pq.Take(&out));
    ASSERT_EQ(out, "Hello world5!");
    DS_ASSERT_OK(pq.Remove(&out));
    ASSERT_EQ(out, "Hello world4!");
    DS_ASSERT_OK(pq.Remove(&out));
    ASSERT_EQ(out, "Hello world3!");
    DS_ASSERT_OK(pq.Poll(&out, 10));
    ASSERT_EQ(out, "Hello world2!");
    DS_ASSERT_OK(pq.Poll(&out, 10));
    ASSERT_EQ(out, "Hello world1!");
    ASSERT_TRUE(pq.Empty());

    // Nothing is left, so Remove and Poll are expected to give up with K_TRY_AGAIN.
    rc = pq.Remove(&out);
    ASSERT_EQ(rc.GetCode(), StatusCode::K_TRY_AGAIN);
    rc = pq.Poll(&out, 10);
    ASSERT_EQ(rc.GetCode(), StatusCode::K_TRY_AGAIN);
}
// Two writers fill an 8-slot queue, then two more writers call Put("test9") on the full queue.
// The main thread pops one item at a time and checks that each "test9" shows up as soon as a slot
// has been freed, followed by the remaining original items in descending order.
// Synchronisation between the phases relies on fixed sleeps.
TEST_F(PriorityQueueTest, MultiThreadPushAndPop1)
{
    const size_t maxItems = 8;
    PriorityQueue<std::string> pq(maxItems);
    std::vector<std::thread> writerThreads;
    std::vector<std::string> strs = { "test1", "test2", "test3", "test4", "test5", "test6", "test7", "test8" };

    // Phase 1: each writer inserts its own half of strs.
    for (int w = 0; w < 2; ++w) {
        writerThreads.emplace_back([&pq, &strs, w]() {
            const int base = 4 * w;
            for (int k = 0; k < 4; ++k) {
                DS_ASSERT_OK(pq.Put(strs[base + k]));
            }
        });
    }
    std::this_thread::sleep_for(std::chrono::seconds(2));
    ASSERT_TRUE(pq.Full());

    // Phase 2: two extra writers try to Put into the already full queue.
    for (int w = 0; w < 2; ++w) {
        writerThreads.emplace_back([&pq]() { DS_ASSERT_OK(pq.Put("test9")); });
    }
    std::this_thread::sleep_for(std::chrono::seconds(2));

    // Phase 3: the first pop yields the largest original item; each following pop is preceded by a
    // pause so that one pending "test9" can be inserted and is then the largest element.
    std::string out;
    DS_ASSERT_OK(pq.Poll(&out, 10));
    ASSERT_EQ(out, "test8");
    for (int round = 0; round < 2; ++round) {
        std::this_thread::sleep_for(std::chrono::seconds(1));
        DS_ASSERT_OK(pq.Poll(&out, 10));
        ASSERT_EQ(out, "test9");
    }

    // The rest are "test7" down to "test1", i.e. strs walked backwards from the second-to-last entry.
    for (auto it = strs.rbegin() + 1; it != strs.rend(); ++it) {
        DS_ASSERT_OK(pq.Poll(&out, 10));
        ASSERT_EQ(out, *it);
    }
    ASSERT_TRUE(pq.Empty());

    for (auto &worker : writerThreads) {
        worker.join();
    }
}
// Two readers call Take on a queue that starts out empty; the main thread then inserts two items
// and, once both readers have finished, checks that the queue is empty again.
TEST_F(PriorityQueueTest, MultiThreadPushAndPop2)
{
    const size_t maxItems = 8;
    PriorityQueue<std::string> pq(maxItems);

    std::vector<std::thread> readerThreads;
    for (int r = 0; r < 2; ++r) {
        readerThreads.emplace_back([&pq]() {
            std::string received;
            DS_ASSERT_OK(pq.Take(&received));
        });
    }

    // One item per reader.
    std::string payload = "Hello world!";
    for (int n = 0; n < 2; ++n) {
        DS_ASSERT_OK(pq.Put(payload));
    }

    for (auto &reader : readerThreads) {
        reader.join();
    }
    ASSERT_TRUE(pq.Empty());
}
// Ten writers and ten readers hammer a 100-slot queue of unique_ptr<int> concurrently. Each thread
// performs 30 operations and rotates through the three insert (Add/Offer/Put) or removal
// (Remove/Poll/Take) APIs. 300 items go in and 300 come out, so the queue must end up empty.
TEST_F(PriorityQueueTest, MultiThreadPushAndPop3)
{
    size_t capacity = 100;
    // Orders elements by the bucket value / 30, i.e. by the index of the writer that produced them.
    struct Comp {
        bool operator()(const std::unique_ptr<int> &p1, const std::unique_ptr<int> &p2)
        {
            return *p1 / 30 < *p2 / 30;
        }
    };
    PriorityQueue<std::unique_ptr<int>, Comp> pq(capacity);

    int numReaders = 10;
    int numWriters = 10;
    std::vector<std::thread> readerThreads;
    std::vector<std::thread> writerThreads;

    for (int w = 0; w < numWriters; ++w) {
        writerThreads.emplace_back([&pq, w]() {
            Status rc;
            for (int step = 0; step < 30; ++step) {
                const int value = w * 30 + step;
                const int op = step % 3;
                if (op == 0) {
                    // Add is retried for as long as it reports K_TRY_AGAIN.
                    do {
                        rc = pq.Add(std::make_unique<int>(value));
                    } while (rc.GetCode() == StatusCode::K_TRY_AGAIN);
                } else if (op == 1) {
                    // Offer is retried the same way.
                    do {
                        rc = pq.Offer(std::make_unique<int>(value), 100);
                    } while (rc.GetCode() == StatusCode::K_TRY_AGAIN);
                } else {
                    // Put is called exactly once.
                    rc = pq.Put(std::make_unique<int>(value));
                }
                ASSERT_TRUE(rc.IsOk());
            }
        });
    }

    for (int r = 0; r < numReaders; ++r) {
        readerThreads.emplace_back([&pq]() {
            Status rc;
            for (int step = 0; step < 30; ++step) {
                const int op = step % 3;
                std::unique_ptr<int> received;
                if (op == 0) {
                    do {
                        rc = pq.Remove(&received);
                    } while (rc.GetCode() == StatusCode::K_TRY_AGAIN);
                } else if (op == 1) {
                    do {
                        rc = pq.Poll(&received, 100);
                    } while (rc.GetCode() == StatusCode::K_TRY_AGAIN);
                } else {
                    rc = pq.Take(&received);
                }
                ASSERT_TRUE(rc.IsOk());
            }
        });
    }

    for (auto &writer : writerThreads) {
        writer.join();
    }
    for (auto &reader : readerThreads) {
        reader.join();
    }
    ASSERT_TRUE(pq.Empty());
}
// Ten writers concurrently insert 30 ints each (writer w produces w*30 .. w*30+29) into a 300-slot
// queue ordered by value / 30. After all writers have finished, the main thread drains the queue
// and checks that the items come out bucket by bucket, highest bucket first. The order inside a
// bucket is not checked.
TEST_F(PriorityQueueTest, MultiThreadPushAndPop4)
{
    size_t capacity = 300;
    struct Comp {
        bool operator()(const int &i1, const int &i2)
        {
            return i1 / 30 < i2 / 30;
        }
    };
    PriorityQueue<int, Comp> pq(capacity);

    int numWriters = 10;
    std::vector<std::thread> writerThreads;
    for (int w = 0; w < numWriters; ++w) {
        writerThreads.emplace_back([&pq, w]() {
            Status rc;
            for (int step = 0; step < 30; ++step) {
                int value = w * 30 + step;
                const int op = step % 3;
                if (op == 0) {
                    // Add is retried for as long as it reports K_TRY_AGAIN.
                    do {
                        rc = pq.Add(value);
                    } while (rc.GetCode() == StatusCode::K_TRY_AGAIN);
                } else if (op == 1) {
                    // Offer is retried the same way.
                    do {
                        rc = pq.Offer(value, 100);
                    } while (rc.GetCode() == StatusCode::K_TRY_AGAIN);
                } else {
                    // Put is called exactly once.
                    rc = pq.Put(value);
                }
                ASSERT_TRUE(rc.IsOk());
            }
        });
    }
    for (auto &writer : writerThreads) {
        writer.join();
    }
    ASSERT_TRUE(pq.Full());

    int popped = -1;
    for (int bucket = 9; bucket >= 0; --bucket) {
        const int lowest = bucket * 30;
        const int pastHighest = (bucket + 1) * 30;
        for (int n = 0; n < 30; ++n) {
            DS_ASSERT_OK(pq.Poll(&popped, 1000));
            ASSERT_TRUE(popped >= lowest);
            ASSERT_TRUE(popped < pastHighest);
        }
    }
}
// One writer fills a 100-slot queue with 0..99 under a comparator that looks only at parity.
// Two readers then move items from the priority queue into a 2-slot Queue, one item at a time
// under a shared mutex, and the main thread checks the order arriving from that Queue: all even
// values in insertion order first, then all odd values in insertion order.
TEST_F(PriorityQueueTest, MultiThreadPushAndPop5)
{
    int capacity = 100;
    struct Comp {
        bool operator()(const std::shared_ptr<int> &p1, const std::shared_ptr<int> &p2)
        {
            return *p1 % 2 > *p2 % 2;
        }
    };
    PriorityQueue<std::shared_ptr<int>, Comp> pq(capacity);

    // Odd values go in through Add with an lvalue, even values through Offer with an rvalue.
    std::thread writer([&pq, capacity]() {
        for (int value = 0; value < capacity; ++value) {
            if (value % 2 != 0) {
                auto sp = std::make_shared<int>(value);
                DS_ASSERT_OK(pq.Add(sp));
            } else {
                DS_ASSERT_OK(pq.Offer(std::make_shared<int>(value), 100));
            }
        }
    });
    writer.join();
    ASSERT_TRUE(pq.Full());

    std::vector<std::thread> readers;
    std::mutex mux;
    Queue<std::shared_ptr<int>> queue(2);
    for (int r = 0; r < 2; ++r) {
        readers.emplace_back([&pq, &queue, &mux, capacity]() {
            std::shared_ptr<int> carried;
            const int share = capacity / 2;
            for (int n = 0; n < share; ++n) {
                // Holding mux across the pop and the hand-over keeps the two readers from
                // interleaving, so items reach `queue` in the order they left `pq`.
                std::unique_lock<std::mutex> lock(mux);
                DS_ASSERT_OK(pq.Poll(&carried, 100));
                Status rc;
                do {
                    rc = queue.Offer(std::move(carried), 100);
                } while (rc.GetCode() == StatusCode::K_TRY_AGAIN);
                ASSERT_TRUE(rc.IsOk());
            }
        });
    }

    std::shared_ptr<int> received;
    Status rc;
    for (int parity = 0; parity < 2; ++parity) {
        for (int expected = parity; expected < capacity; expected += 2) {
            do {
                rc = queue.Remove(&received);
            } while (rc.GetCode() == StatusCode::K_TRY_AGAIN);
            ASSERT_TRUE(rc.IsOk());
            ASSERT_EQ(*received, expected);
        }
    }

    // Joined last-to-first.
    for (auto it = readers.rbegin(); it != readers.rend(); ++it) {
        it->join();
    }
    ASSERT_TRUE(pq.Empty());
    ASSERT_TRUE(queue.Empty());
}
// Stress test: 100 writers each offer 100 items into a 10-slot queue while 100 readers poll it.
// The number of successful polls must equal the 10000 items offered.
TEST_F(PriorityQueueTest, TestNumerousThreads)
{
    int capacity = 10;
    struct Comp {
        bool operator()(const std::unique_ptr<int> &p1, const std::unique_ptr<int> &p2)
        {
            return *p1 / 100 < *p2 / 100;
        }
    };
    PriorityQueue<std::unique_ptr<int>, Comp> pq(capacity);
    int numWriters = 100, numReaders = 100;
    std::vector<std::thread> writers;
    uint16_t done = 0;                  // Successful polls; only modified while mux2 is held.
    std::atomic<uint16_t> toDo(10000);  // Offers that have not succeeded yet.
    std::mutex mux1, mux2;              // mux1 serialises the writers, mux2 the readers.
    for (int i = 0; i < numWriters; ++i) {
        writers.emplace_back([&pq, &mux1, &toDo, i]() {
            Status rc;
            for (int j = 0; j < 100; ++j) {
                // Retry the same value until the queue accepts it; toDo is decremented only
                // after the item is actually in the queue.
                do {
                    std::unique_lock<std::mutex> lock(mux1);
                    rc = pq.Offer(std::make_unique<int>(j + i * 100), 100);
                    if (rc.IsOk()) {
                        --toDo;
                    }
                } while (rc.GetCode() == StatusCode::K_TRY_AGAIN);
                ASSERT_TRUE(rc.IsOk());
            }
        });
    }
    std::vector<std::thread> readers;
    for (int i = 0; i < numReaders; ++i) {
        readers.emplace_back([&pq, &mux2, &toDo, &done] {
            std::unique_ptr<int> buffer;
            do {
                std::unique_lock<std::mutex> lock(mux2);
                Status rc = pq.Poll(&buffer, 100);
                if (rc.IsOk()) {
                    ++done;
                }
                // toDo must be read BEFORE Empty(). Once toDo is 0 every item has already been
                // offered, so an empty queue then really means everything was consumed. Checking
                // Empty() first leaves a window in which the last item is offered and toDo drops
                // to 0 between the two reads, letting a reader quit with an item still queued.
            } while (toDo != 0 || !pq.Empty());
        });
    }
    for (auto &writer : writers) {
        writer.join();
    }
    for (auto &reader : readers) {
        reader.join();
    }
    ASSERT_EQ(done, 10000);
}
// Builds Status objects in place through EmplaceBack and checks that they are removed in
// descending order of their message text, with code and message intact.
TEST_F(PriorityQueueTest, TestEmplaceBack)
{
    const int capacity = 8;
    // Orders Status values by comparing their messages with std::less.
    struct Comp {
        std::less<std::string> cmp;
        bool operator()(const Status &s1, const Status &s2)
        {
            return cmp(s1.GetMsg(), s2.GetMsg());
        }
    };
    PriorityQueue<Status, Comp> pq(capacity);

    const std::string prefix = "Test EmplaceBack";
    for (int idx = 0; idx < capacity; ++idx) {
        DS_ASSERT_OK(pq.EmplaceBack(StatusCode::K_TRY_AGAIN, prefix + std::to_string(idx)));
    }
    ASSERT_TRUE(pq.Full());

    // The numeric suffix is a single digit, so message order equals index order.
    Status popped;
    for (int remaining = capacity; remaining > 0; --remaining) {
        const int idx = remaining - 1;
        DS_ASSERT_OK(pq.Remove(&popped));
        ASSERT_EQ(popped.GetCode(), StatusCode::K_TRY_AGAIN);
        ASSERT_EQ(popped.GetMsg(), prefix + std::to_string(idx));
    }
    ASSERT_TRUE(pq.Empty());
}
// Checks ordering stability under a "greater" comparator on value / 100: elements of the lower
// bucket must come out first, and elements that compare equal must come out in insertion order.
TEST_F(PriorityQueueTest, TestStabilityGreaterCmp)
{
    size_t capacity = 12;
    struct Comp {
        bool operator()(const std::shared_ptr<int> &p1, const std::shared_ptr<int> &p2)
        {
            return *p1 / 100 > *p2 / 100;
        }
    };
    PriorityQueue<std::shared_ptr<int>, Comp> pq(capacity);

    // Bucket 0 is inserted in two runs (ascending, then descending) with bucket 1 in between.
    const int insertOrder[] = { 0, 1, 2, 3, 100, 110, 120, 130, 7, 6, 5, 4 };
    for (int value : insertOrder) {
        DS_ASSERT_OK(pq.Offer(std::make_shared<int>(value), 100));
    }
    ASSERT_TRUE(pq.Full());

    // Expected: all of bucket 0 in the order it was inserted, then all of bucket 1 likewise.
    const int expectedOrder[] = { 0, 1, 2, 3, 7, 6, 5, 4, 100, 110, 120, 130 };
    std::shared_ptr<int> popped;
    for (int expected : expectedOrder) {
        DS_ASSERT_OK(pq.Remove(&popped));
        ASSERT_EQ(*popped, expected);
    }
    ASSERT_TRUE(pq.Empty());
}
// Peek must expose the front element without removing it: after a successful Peek the same
// element can still be taken, and only then is the queue empty.
TEST_F(PriorityQueueTest, TestPeek)
{
    const size_t maxItems = 1;
    PriorityQueue<std::string> pq(maxItems);

    std::string payload = "Hello world!";
    DS_ASSERT_OK(pq.Put(payload));

    // Peek fills in a pointer to the element.
    const std::string *front = nullptr;
    DS_ASSERT_OK(pq.Peek(front));
    ASSERT_EQ(*front, "Hello world!");

    std::string taken;
    DS_ASSERT_OK(pq.Take(&taken));
    ASSERT_EQ(taken, "Hello world!");
    ASSERT_TRUE(pq.Empty());
}
}
}