* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include <thread>
#include <chrono>
#include "ut/common.h"
#include "datasystem/common/object_cache/safe_object.h"
#include "datasystem/common/immutable_string/immutable_string.h"
#include "datasystem/common/immutable_string/immutable_string_pool.h"
#include "datasystem/common/object_cache/object_base.h"
#include "datasystem/common/object_cache/safe_table.h"
#include "datasystem/common/util/raii.h"
#include "datasystem/common/util/thread_pool.h"
#include "datasystem/common/util/uuid_generator.h"
namespace datasystem {
namespace ut {
/**
 * @brief Fixture shared by the SafeObject / SafeTable tests in this file.
 * It owns the reader/writer worker threads started by a test and offers small helpers
 * for joining them, producing object keys and drawing random integers.
 */
class SafeObjectTest : public CommonTest {
public:
    SafeObjectTest() : numObjects_(0), numWriterThreads_(0), numReaderThreads_(0), currObjKey_(0)
    {
    }

    /**
     * @brief Join every writer and reader thread that is still joinable, then drop the handles.
     * Calling it again (or with no threads started) is a no-op.
     */
    void JoinThreads()
    {
        for (auto &writerThread : writerThreads_) {
            if (writerThread.joinable()) {
                writerThread.join();
            }
        }
        writerThreads_.clear();
        for (auto &readerThread : readerThreads_) {
            if (readerThread.joinable()) {
                readerThread.join();
            }
        }
        readerThreads_.clear();
    }

    /**
     * @brief Hand out the next integer object key.
     * @return 0 (after resetting the counter) when the counter currently equals numObjects_;
     * otherwise the counter's value before it is incremented by one.
     */
    int GetNextObjKey()
    {
        // compare_exchange_strong stores the observed value into its first argument when the
        // exchange fails, so compare against a scratch copy to keep numObjects_ intact.
        int wrapAt = numObjects_;
        if (currObjKey_.compare_exchange_strong(wrapAt, 0)) {
            return 0;
        }
        return currObjKey_.fetch_add(1);
    }

    /**
     * @brief Draw a uniformly distributed integer from the closed range [min, max].
     * @param[in] min Lower bound (inclusive).
     * @param[in] max Upper bound (inclusive).
     * @return The random value.
     */
    int GetRandomInt(int min, int max)
    {
        // The engine is seeded once per thread; the random_device is only needed for that seed.
        static thread_local std::mt19937 randGen(std::random_device{}());
        std::uniform_int_distribution<int> distr(min, max);
        return distr(randGen);
    }

    /**
     * @brief Create a new object key.
     * @return The string produced by GetStringUuid().
     */
    std::string NewObjectKey()
    {
        return GetStringUuid();
    }

protected:
    int numObjects_;                         // Number of objects a test works with.
    int numWriterThreads_;                   // Number of writer threads a test starts.
    int numReaderThreads_;                   // Number of reader threads a test starts.
    std::atomic<int> currObjKey_;            // Counter behind GetNextObjKey().
    std::vector<std::thread> writerThreads_; // Writer threads, joined by JoinThreads().
    std::vector<std::thread> readerThreads_; // Reader threads, joined by JoinThreads().
};
/**
 * @brief Abstract payload type used by the tests to store a polymorphic object.
 * Carries one integer and leaves mustOverride() to subclasses.
 */
class BaseObj : public ObjectInterface {
public:
    explicit BaseObj(int bVal) : baseVal_{ bVal }
    {
    }

    // Log the stored base value; subclasses may replace this.
    virtual void print()
    {
        LOG(INFO) << "base print has val: " << baseVal_ << std::endl;
    }

    // Hook that every concrete subclass has to provide.
    virtual void mustOverride(int num) = 0;

    // Nothing to release for this test object; always reports success.
    Status FreeResources() override
    {
        return Status::OK();
    }

protected:
    int baseVal_;  // Value supplied at construction time.
};
/**
 * @brief Concrete BaseObj used by the tests; adds a second integer and one non-virtual method.
 */
class DerivedObj : public BaseObj {
public:
    DerivedObj(int dVal, int bVal) : BaseObj(bVal), derivedVal_{ dVal }
    {
    }

    // Log both the inherited and the derived value.
    void print() override
    {
        LOG(INFO) << "derived print. base has val: " << baseVal_ << " and derived val: " << derivedVal_ << std::endl;
    }

    // Log the number that was passed in.
    void mustOverride(int num) override
    {
        LOG(INFO) << "must override: " << num << std::endl;
    }

    // Only reachable through a DerivedObj pointer/reference; BaseObj has no such member.
    void notVirtualFunction()
    {
        LOG(INFO) << "called notVirtualFunction. base class does not have this";
    }

private:
    int derivedVal_;  // Value supplied at construction time.
};
// Construct SafeObject<std::string> once from a std::string and once from a unique_ptr, then
// modify the wrapped strings through operator* and operator-> and log the results.
TEST_F(SafeObjectTest, TestConstruct1)
{
    using SafeString = SafeObject<std::string>;
    const size_t insertPos = 6;

    std::string plainData("hello1");
    auto ownedData = std::make_unique<std::string>("hello2");
    auto fromValue = std::make_shared<SafeString>(plainData);
    auto fromUnique = std::make_shared<SafeString>(std::move(ownedData));

    // First object: fetch a reference to the wrapped string and edit it in place.
    std::string &wrappedText = *(*fromValue);
    LOG(INFO) << "Data is: " << wrappedText;
    wrappedText.insert(insertPos, " world");
    LOG(INFO) << "Data is: " << *(*fromValue);

    // Second object: edit the wrapped string through the arrow operator.
    SafeString &secondObj = *fromUnique;
    secondObj->insert(insertPos, " world");
    LOG(INFO) << "Data is: " << *(*fromUnique);
}
// Create two empty SafeObjects (default constructed and constructed from a null unique_ptr)
// and populate them afterwards through both SetRealObject overloads.
TEST_F(SafeObjectTest, TestConstruct2)
{
    using SafeString = SafeObject<std::string>;

    std::string plainData("hello1");
    auto ownedData = std::make_unique<std::string>("hello2");

    auto emptyByDefault = std::make_shared<SafeString>();
    auto emptyFromNull = std::make_shared<SafeString>(std::unique_ptr<std::string>(nullptr));

    emptyByDefault->SetRealObject(plainData);
    emptyFromNull->SetRealObject(std::move(ownedData));

    LOG(INFO) << "Data is: " << *(*emptyByDefault);
}
// Writers only: ten threads each append one character, ten times, to a randomly chosen
// SafeObject while holding its write lock. Afterwards the combined length of both strings
// must have grown by exactly threads * loops characters.
TEST_F(SafeObjectTest, TestLocks1)
{
    using SafeString = SafeObject<std::string>;
    const int numLoops = 10;
    numWriterThreads_ = 10;
    numObjects_ = 2;
    std::vector<std::shared_ptr<SafeString>> safeStrings;
    safeStrings.push_back(std::make_shared<SafeString>("abc"));
    safeStrings.push_back(std::make_shared<SafeString>("def"));

    // No worker is running yet, so the strings can be read without locking.
    const size_t initialLen = (*(*safeStrings[0])).size() + (*(*safeStrings[1])).size();

    for (int threadId = 0; threadId < numWriterThreads_; ++threadId) {
        writerThreads_.push_back(std::thread([this, threadId, &numLoops, &safeStrings]() {
            for (int i = 0; i < numLoops; ++i) {
                int id = GetRandomInt(0, (numObjects_ - 1));
                DS_ASSERT_OK(safeStrings[id]->WLock());
                LOG(INFO) << "Thread[" << threadId << "] got lock and appending an \"a\" to SafeObject[" << id << "]";
                *(*(safeStrings[id])) += "a";
                safeStrings[id]->WUnlock();
            }
        }));
    }
    JoinThreads();

    LOG(INFO) << "SafeObject[0]: " << *(*(safeStrings[0]));
    LOG(INFO) << "SafeObject[1]: " << *(*(safeStrings[1]));

    // Every loop iteration appends exactly one character under the write lock.
    const size_t expectedLen = initialLen + static_cast<size_t>(numWriterThreads_) * static_cast<size_t>(numLoops);
    ASSERT_EQ((*(*safeStrings[0])).size() + (*(*safeStrings[1])).size(), expectedLen);
}
// Readers only: ten threads repeatedly take the read lock of a randomly chosen SafeObject,
// log its content and release the lock again.
TEST_F(SafeObjectTest, TestLocks2)
{
    using SafeString = SafeObject<std::string>;
    const int numLoops = 10;
    numReaderThreads_ = 10;
    numObjects_ = 2;

    std::vector<std::shared_ptr<SafeString>> safeStrings;
    safeStrings.emplace_back(std::make_shared<SafeString>("abc"));
    safeStrings.emplace_back(std::make_shared<SafeString>("def"));

    for (int readerId = 0; readerId < numReaderThreads_; ++readerId) {
        readerThreads_.emplace_back([this, readerId, &numLoops, &safeStrings]() {
            for (int round = 0; round < numLoops; ++round) {
                const int target = GetRandomInt(0, numObjects_ - 1);
                auto &picked = safeStrings[target];
                DS_ASSERT_OK(picked->RLock());
                LOG(INFO) << "Thread[" << readerId << "]: " << *(*picked);
                picked->RUnlock();
            }
        });
    }

    JoinThreads();
}
// Mixed load: ten reader threads and two writer threads work on two SafeObjects at the same
// time. Readers log the content under the read lock; writers pause briefly and then append one
// character under the write lock.
TEST_F(SafeObjectTest, TestLocks3)
{
    using SafeString = SafeObject<std::string>;
    const int numLoops = 100;
    const int numWLoops = 20;
    numReaderThreads_ = 10;
    numWriterThreads_ = 2;
    numObjects_ = 2;

    std::vector<std::shared_ptr<SafeString>> safeStrings;
    safeStrings.emplace_back(std::make_shared<SafeString>("abc"));
    safeStrings.emplace_back(std::make_shared<SafeString>("def"));

    // Start the readers.
    for (int readerId = 0; readerId < numReaderThreads_; ++readerId) {
        readerThreads_.emplace_back([this, readerId, &numLoops, &safeStrings]() {
            for (int round = 0; round < numLoops; ++round) {
                const int target = GetRandomInt(0, numObjects_ - 1);
                auto &picked = safeStrings[target];
                DS_ASSERT_OK(picked->RLock());
                LOG(INFO) << "Thread[" << readerId << "]: " << *(*picked);
                picked->RUnlock();
            }
        });
    }

    // Start the writers.
    for (int writerId = 0; writerId < numWriterThreads_; ++writerId) {
        writerThreads_.emplace_back([this, writerId, &numWLoops, &safeStrings]() {
            for (int round = 0; round < numWLoops; ++round) {
                const int target = GetRandomInt(0, numObjects_ - 1);
                std::this_thread::sleep_for(std::chrono::microseconds(100));
                auto &picked = safeStrings[target];
                DS_ASSERT_OK(picked->WLock());
                LOG(INFO) << "Thread[" << writerId << "] got lock and appending an \"a\" to SafeObject[" << target
                          << "]";
                *(*picked) += "a";
                picked->WUnlock();
            }
        });
    }

    JoinThreads();
}
// Concurrent readers and writers on two SafeObjects (the scenario is the same as in TestLocks3):
// readers log under the read lock, writers sleep 100 microseconds and append under the write lock.
TEST_F(SafeObjectTest, TestLocks4)
{
    using SafeString = SafeObject<std::string>;
    const int numLoops = 100;
    const int numWLoops = 20;
    numReaderThreads_ = 10;
    numWriterThreads_ = 2;
    numObjects_ = 2;

    std::vector<std::shared_ptr<SafeString>> safeStrings;
    safeStrings.emplace_back(std::make_shared<SafeString>("abc"));
    safeStrings.emplace_back(std::make_shared<SafeString>("def"));

    auto readerBody = [this, &numLoops, &safeStrings](int threadId) {
        for (int i = 0; i < numLoops; ++i) {
            const int id = GetRandomInt(0, numObjects_ - 1);
            DS_ASSERT_OK(safeStrings[id]->RLock());
            LOG(INFO) << "Thread[" << threadId << "]: " << *(*(safeStrings[id]));
            safeStrings[id]->RUnlock();
        }
    };
    auto writerBody = [this, &numWLoops, &safeStrings](int threadId) {
        for (int i = 0; i < numWLoops; ++i) {
            const int id = GetRandomInt(0, numObjects_ - 1);
            std::this_thread::sleep_for(std::chrono::microseconds(100));
            DS_ASSERT_OK(safeStrings[id]->WLock());
            LOG(INFO) << "Thread[" << threadId << "] got lock and appending an \"a\" to SafeObject[" << id << "]";
            *(*(safeStrings[id])) += "a";
            safeStrings[id]->WUnlock();
        }
    };

    for (int threadId = 0; threadId < numReaderThreads_; ++threadId) {
        readerThreads_.emplace_back(readerBody, threadId);
    }
    for (int threadId = 0; threadId < numWriterThreads_; ++threadId) {
        writerThreads_.emplace_back(writerBody, threadId);
    }

    // Both lambdas are locals of this test, so the threads must finish before it returns.
    JoinThreads();
}
// Delete the wrapped object through one shared_ptr and check that lock attempts made through a
// second shared_ptr to the same SafeObject report K_NOT_FOUND.
TEST_F(SafeObjectTest, TestDelete)
{
    using SafeString = SafeObject<std::string>;

    auto owner = std::make_shared<SafeString>("abc");
    std::shared_ptr<SafeString> alias = owner;

    // Delete while holding the write lock, then drop the first reference.
    DS_ASSERT_OK(owner->WLock());
    owner->DeleteObject();
    owner->WUnlock();
    owner.reset();

    Status lockRc = alias->WLock();
    ASSERT_TRUE(lockRc.GetCode() == StatusCode::K_NOT_FOUND);

    lockRc = alias->RLock();
    ASSERT_TRUE(lockRc.GetCode() == StatusCode::K_NOT_FOUND);
}
// Exercise TryWLock: it is expected to return K_TRY_AGAIN while a write or read lock is held,
// K_OK when the object is unlocked, and K_NOT_FOUND once the object has been deleted.
TEST_F(SafeObjectTest, TestTryWriteLocks)
{
    using SafeString = SafeObject<std::string>;
    auto objPtr1 = std::make_shared<SafeString>("abc");

    // Write lock held.
    DS_ASSERT_OK(objPtr1->WLock());
    Status rc = objPtr1->TryWLock();
    LOG(INFO) << "TryWLock returned: " << rc.ToString();
    ASSERT_TRUE(rc.GetCode() == StatusCode::K_TRY_AGAIN);
    objPtr1->WUnlock();

    // Read lock held.
    DS_ASSERT_OK(objPtr1->RLock());
    rc = objPtr1->TryWLock();
    LOG(INFO) << "TryWLock returned: " << rc.ToString();
    ASSERT_TRUE(rc.GetCode() == StatusCode::K_TRY_AGAIN);
    objPtr1->RUnlock();

    // Same two checks once more, after the locks have been taken and released before.
    DS_ASSERT_OK(objPtr1->WLock());
    rc = objPtr1->TryWLock();
    LOG(INFO) << "TryWLock returned: " << rc.ToString();
    ASSERT_TRUE(rc.GetCode() == StatusCode::K_TRY_AGAIN);
    objPtr1->WUnlock();

    // The lock result is checked here as well, matching the other acquisitions in this test.
    DS_ASSERT_OK(objPtr1->RLock());
    rc = objPtr1->TryWLock();
    LOG(INFO) << "TryWLock returned: " << rc.ToString();
    ASSERT_TRUE(rc.GetCode() == StatusCode::K_TRY_AGAIN);
    objPtr1->RUnlock();

    // Nothing is held: the try-lock succeeds and this thread now owns the write lock.
    rc = objPtr1->TryWLock();
    LOG(INFO) << "TryWLock returned: " << rc.ToString();
    ASSERT_TRUE(rc.GetCode() == StatusCode::K_OK);

    // Delete under the write lock; later attempts must report the object as gone.
    objPtr1->DeleteObject();
    objPtr1->WUnlock();
    rc = objPtr1->TryWLock();
    LOG(INFO) << "TryWLock returned: " << rc.ToString();
    ASSERT_TRUE(rc.GetCode() == StatusCode::K_NOT_FOUND);
}
// Exercise TryRLock: it is expected to return K_TRY_AGAIN while a write lock is held (also after
// DeleteObject, as long as that write lock is still held) and K_OK while only a read lock is held.
TEST_F(SafeObjectTest, TestTryReadLocks)
{
    using SafeString = SafeObject<std::string>;
    auto objPtr1 = std::make_shared<SafeString>("abc");

    // Write lock held.
    DS_ASSERT_OK(objPtr1->WLock());
    Status rc = objPtr1->TryRLock();
    LOG(INFO) << "TryRLock returned: " << rc.ToString();
    ASSERT_TRUE(rc.GetCode() == StatusCode::K_TRY_AGAIN);
    objPtr1->WUnlock();

    // Read lock held: a second reader gets in, so two read unlocks are needed afterwards.
    DS_ASSERT_OK(objPtr1->RLock());
    rc = objPtr1->TryRLock();
    LOG(INFO) << "TryRLock returned: " << rc.ToString();
    ASSERT_TRUE(rc.GetCode() == StatusCode::K_OK);
    objPtr1->RUnlock();
    objPtr1->RUnlock();

    // Same two checks once more, after the locks have been taken and released before.
    DS_ASSERT_OK(objPtr1->WLock());
    rc = objPtr1->TryRLock();
    LOG(INFO) << "TryRLock returned: " << rc.ToString();
    ASSERT_TRUE(rc.GetCode() == StatusCode::K_TRY_AGAIN);
    objPtr1->WUnlock();

    DS_ASSERT_OK(objPtr1->RLock());
    rc = objPtr1->TryRLock();
    LOG(INFO) << "TryRLock returned: " << rc.ToString();
    ASSERT_TRUE(rc.GetCode() == StatusCode::K_OK);
    objPtr1->RUnlock();
    objPtr1->RUnlock();

    // Delete while keeping the write lock: the try-lock is refused because the lock is busy.
    DS_ASSERT_OK(objPtr1->WLock());
    objPtr1->DeleteObject();
    rc = objPtr1->TryRLock();
    LOG(INFO) << "TryRLock returned: " << rc.ToString();
    ASSERT_TRUE(rc.GetCode() == StatusCode::K_TRY_AGAIN);

    // Release the write lock so the SafeObject is not destroyed while still locked.
    objPtr1->WUnlock();
}
// Basic SafeTable operations: Get on a missing key, the three Insert overloads (value,
// unique_ptr, shared_ptr<SafeObject>) including duplicate detection, Contains, and Get.
TEST_F(SafeObjectTest, TestTableBasic)
{
    using SafeString = SafeObject<std::string>;
    using StringTable = SafeTable<std::string, std::string>;
    const int numKeys = 3;

    StringTable safeTable;
    std::vector<std::string> keys;
    for (int keyIdx = 0; keyIdx < numKeys; ++keyIdx) {
        keys.push_back(NewObjectKey());
    }

    // The table is empty, so a lookup has to fail.
    std::shared_ptr<SafeString> fetchedObj;
    Status rc = safeTable.Get(keys[0], fetchedObj);
    LOG(INFO) << "Get returned: " << rc.ToString();
    ASSERT_TRUE(rc.GetCode() == StatusCode::K_NOT_FOUND);

    // Insert by value; inserting the same key again is rejected.
    std::string valueData("hello1");
    DS_ASSERT_OK(safeTable.Insert(keys[0], valueData));
    rc = safeTable.Insert(keys[0], valueData);
    LOG(INFO) << "Insert returned: " << rc.ToString();
    ASSERT_TRUE(rc.GetCode() == StatusCode::K_DUPLICATED);

    // Contains for a present and an absent key.
    std::string missingKey("not_a_key");
    DS_ASSERT_OK(safeTable.Contains(keys[0]));
    rc = safeTable.Contains(missingKey);
    ASSERT_TRUE(rc.GetCode() == StatusCode::K_NOT_FOUND);

    // Insert by unique_ptr; a second unique_ptr for the same key is rejected.
    auto uniqueData = std::make_unique<std::string>("hello2");
    DS_ASSERT_OK(safeTable.Insert(keys[1], std::move(uniqueData)));
    uniqueData = std::make_unique<std::string>("hello2");
    rc = safeTable.Insert(keys[1], std::move(uniqueData));
    LOG(INFO) << "Insert returned: " << rc.ToString();
    ASSERT_TRUE(rc.GetCode() == StatusCode::K_DUPLICATED);

    // Insert an already wrapped object; inserting it again is rejected.
    auto wrappedData = std::make_shared<SafeString>("hello3");
    DS_ASSERT_OK(safeTable.Insert(keys[2], wrappedData));
    rc = safeTable.Insert(keys[2], wrappedData);
    LOG(INFO) << "Insert returned: " << rc.ToString();
    ASSERT_TRUE(rc.GetCode() == StatusCode::K_DUPLICATED);

    // Finally fetch one of the stored entries.
    DS_ASSERT_OK(safeTable.Get(keys[1], fetchedObj));
    LOG(INFO) << "Fetched object: " << *(*fetchedObj);
}
// Reserve a key with ReserveAndLock, fill in the real object afterwards, and check that
// reserving the same key a second time reports K_DUPLICATED.
TEST_F(SafeObjectTest, ReserveAndCreate)
{
    using SafeString = SafeObject<std::string>;
    using StringTable = SafeTable<std::string, std::string>;

    StringTable safeTable;
    std::string reservedKey = NewObjectKey();
    std::shared_ptr<SafeString> reservedObj;
    DS_ASSERT_OK(safeTable.ReserveAndLock(reservedKey, reservedObj));

    // A further try-lock on the reserved entry is expected to be refused.
    Status rc = reservedObj->TryWLock(true);
    LOG(INFO) << "TryWLock returned: " << rc.ToString();
    ASSERT_TRUE(rc.GetCode() == StatusCode::K_TRY_AGAIN);

    // Attach the payload, unlock, and let go of our handle.
    auto payload = std::make_unique<std::string>("hello");
    reservedObj->SetRealObject(std::move(payload));
    reservedObj->WUnlock();
    reservedObj.reset();

    rc = safeTable.ReserveAndLock(reservedKey, reservedObj);
    LOG(INFO) << "ReserveAndLock returned: " << rc.ToString();
    ASSERT_TRUE(rc.GetCode() == StatusCode::K_DUPLICATED);
}
// Erase behaviour: a missing key yields K_NOT_FOUND; after either Erase overload the table is
// empty and a handle fetched before the erase reports K_NOT_FOUND on WLock.
TEST_F(SafeObjectTest, TestErase)
{
    using SafeString = SafeObject<std::string>;
    using StringTable = SafeTable<std::string, std::string>;

    StringTable safeTable;
    std::string eraseKey = NewObjectKey();

    // Nothing has been inserted yet.
    Status rc = safeTable.Erase(eraseKey);
    LOG(INFO) << "Erase returned: " << rc.ToString();
    ASSERT_TRUE(rc.GetCode() == StatusCode::K_NOT_FOUND);

    // Erase by key only.
    std::string valueData("hello1");
    DS_ASSERT_OK(safeTable.Insert(eraseKey, valueData));
    std::shared_ptr<SafeString> handle;
    DS_ASSERT_OK(safeTable.Get(eraseKey, handle));
    DS_ASSERT_OK(safeTable.Erase(eraseKey));
    ASSERT_EQ(safeTable.GetSize(), 0ul);
    rc = handle->WLock();
    ASSERT_TRUE(rc.GetCode() == StatusCode::K_NOT_FOUND);

    // Erase by key plus the object that was fetched for it.
    DS_ASSERT_OK(safeTable.Insert(eraseKey, valueData));
    handle = nullptr;
    DS_ASSERT_OK(safeTable.Get(eraseKey, handle));
    DS_ASSERT_OK(safeTable.Erase(eraseKey, *handle));
    ASSERT_EQ(safeTable.GetSize(), 0ul);
    rc = handle->WLock();
    ASSERT_TRUE(rc.GetCode() == StatusCode::K_NOT_FOUND);
}
// Insert three entries and walk the table with a range-based for loop, logging each entry.
TEST_F(SafeObjectTest, TestIterate1)
{
    using StringTable = SafeTable<std::string, std::string>;
    const int numKeys = 3;

    StringTable safeTable;
    std::string payload("hello");
    for (int inserted = 0; inserted < numKeys; ++inserted) {
        DS_ASSERT_OK(safeTable.Insert(NewObjectKey(), payload));
    }

    for (auto &entry : safeTable) {
        LOG(INFO) << "key: " << entry.first << " data: " << *(*entry.second);
    }
}
// Iterate slowly over the table while two threads try to insert new entries. Going by the log
// text below, the inserts are expected to wait until the iteration has finished.
TEST_F(SafeObjectTest, TestIterate2)
{
    using StringTable = SafeTable<std::string, std::string>;
    const int numKeys = 3;

    StringTable safeTable;
    numReaderThreads_ = 2;
    std::vector<std::string> keys;
    std::string seedData("hello");
    for (int keyIdx = 0; keyIdx < numKeys; ++keyIdx) {
        keys.push_back(NewObjectKey());
        DS_ASSERT_OK(safeTable.Insert(keys[keyIdx], seedData));
    }

    {
        // The iterator lives only inside this scope, so it is destroyed before JoinThreads().
        StringTable::Iterator iter = std::begin(safeTable);

        for (int threadId = 0; threadId < numReaderThreads_; ++threadId) {
            readerThreads_.emplace_back([this, threadId, &safeTable, &keys]() {
                std::string payload("test");
                std::string freshKey = NewObjectKey();
                LOG(INFO) << "Try to insert. Should lock on the table lock because iteration in progress.";
                DS_ASSERT_OK(safeTable.Insert(freshKey, payload));
                LOG(INFO) << "Insert done now";
            });
        }

        // Give the inserting threads time to start before walking the table.
        std::this_thread::sleep_for(std::chrono::seconds(2));
        for (; iter != std::end(safeTable); ++iter) {
            iter->second->RLock();
            LOG(INFO) << "key: " << iter->first << " data: " << *(*iter->second);
            iter->second->RUnlock();
            std::this_thread::sleep_for(std::chrono::seconds(2));
        }
    }

    JoinThreads();
}
// Store a DerivedObj behind a BaseObj pointer in the table, call a virtual method through the
// base interface, then obtain the derived pointer with GetDerived and call a method that only
// the derived class has.
TEST_F(SafeObjectTest, TestBaseObj)
{
    using SafeBase = SafeObject<BaseObj>;
    using BaseTable = SafeTable<std::string, BaseObj>;
    BaseTable safeTable;
    std::string key1 = NewObjectKey();
    std::unique_ptr<BaseObj> baseObj = std::make_unique<DerivedObj>(13, 99);
    DS_ASSERT_OK(safeTable.Insert(key1, std::move(baseObj)));

    std::shared_ptr<SafeBase> fetchedObj = nullptr;
    DS_ASSERT_OK(safeTable.Get(key1, fetchedObj));
    fetchedObj->Get()->print();

    DerivedObj *childClass = SafeBase::GetDerived<DerivedObj>(*fetchedObj);
    // Fail the test instead of dereferencing a null pointer if the downcast did not work.
    ASSERT_NE(childClass, nullptr);
    childClass->notVirtualFunction();
}
// Check IsWLockedByCurrentThread from two pool threads: it must be true for the thread holding
// the write lock and false for a thread that only holds the read lock, even while a second
// thread is trying to take the write lock.
TEST_F(SafeObjectTest, TestIsWLockedByCurrentThread)
{
    using SafeString = SafeObject<std::string>;
    auto objPtr = std::make_shared<SafeString>("abc");
    ThreadPool threadPool(2);

    auto fut = threadPool.Submit([objPtr]() {
        DS_ASSERT_OK(objPtr->WLock());
        ASSERT_TRUE(objPtr->IsWLockedByCurrentThread());
        objPtr->WUnlock();

        DS_ASSERT_OK(objPtr->RLock());
        // Hold the read lock long enough for the other task to start its write-lock attempt.
        std::this_thread::sleep_for(std::chrono::milliseconds(50));
        ASSERT_FALSE(objPtr->IsWLockedByCurrentThread());
        objPtr->RUnlock();
    });

    auto fut2 = threadPool.Submit([objPtr]() {
        // Start a little later than the first task.
        std::this_thread::sleep_for(std::chrono::milliseconds(25));
        DS_ASSERT_OK(objPtr->WLock());
        ASSERT_TRUE(objPtr->IsWLockedByCurrentThread());
        objPtr->WUnlock();
    });

    fut.get();
    fut2.get();
}
// Stress test that runs three tasks against one table for ten seconds:
//   1. iterate over the whole table,
//   2. ReserveGetAndLock a fixed key and unlock it again,
//   3. Get the same key, write-lock it and erase it.
// The test passes when all three tasks finish after the stop flag is cleared.
TEST_F(SafeObjectTest, TestReadWriteLock)
{
    using SafeBase = SafeObject<BaseObj>;
    using BaseTable = SafeTable<ImmutableString, BaseObj>;

    BaseTable safeTable;
    std::string key = NewObjectKey();
    ImmutableStringPool::Instance().Init();
    intern::StringPool::InitAll();

    std::atomic<bool> running{ true };
    const int printBatch = 100;
    const int threadCount = 3;
    const int delayMs = 10;
    ThreadPool threadPool(threadCount);

    // Task 1: repeatedly walk the table, collecting the keys that were seen.
    auto iterateTask = threadPool.Submit([&safeTable, &running, delayMs] {
        int64_t rounds = 0;
        while (running.load()) {
            std::vector<std::string> seenKeys;
            for (const auto &entry : safeTable) {
                std::this_thread::sleep_for(std::chrono::milliseconds(1));
                seenKeys.emplace_back(entry.first);
            }
            ++rounds;
            if (rounds % printBatch == 0) {
                LOG(INFO) << "iterator loop:" << rounds;
                std::this_thread::sleep_for(std::chrono::milliseconds(delayMs));
            }
        }
    });

    // Task 2: reserve-or-get the key under lock; the guard unlocks at the end of each round.
    auto reserveTask = threadPool.Submit([&safeTable, &running, &key, delayMs] {
        int64_t rounds = 0;
        while (running.load()) {
            std::shared_ptr<SafeBase> safeObj;
            bool isInsert;
            Status rc = safeTable.ReserveGetAndLock(key, safeObj, isInsert);
            if (rc.IsError()) {
                continue;
            }
            Raii unlockGuard([&safeObj] { safeObj->WUnlock(); });
            ++rounds;
            if (rounds % printBatch == 0) {
                LOG(INFO) << "ReserveGetAndLock loop:" << rounds;
                std::this_thread::sleep_for(std::chrono::milliseconds(delayMs));
            }
        }
    });

    // Task 3: look the key up, take its write lock and erase it; retry whenever a step fails.
    auto eraseTask = threadPool.Submit([&safeTable, &running, &key, delayMs] {
        int64_t rounds = 0;
        while (running.load()) {
            std::shared_ptr<SafeBase> safeObj;
            Status rc = safeTable.Get(key, safeObj);
            if (rc.IsError()) {
                continue;
            }
            rc = safeObj->WLock(true);
            if (rc.IsError()) {
                continue;
            }
            Raii unlockGuard([&safeObj] { safeObj->WUnlock(); });
            rc = safeTable.Erase(key, *safeObj);
            ASSERT_TRUE(rc.IsOk());
            ++rounds;
            if (rounds % printBatch == 0) {
                LOG(INFO) << "Get and erase loop:" << rounds;
                std::this_thread::sleep_for(std::chrono::milliseconds(delayMs));
            }
        }
    });

    // Let the tasks compete for a while, then ask them to stop and wait for each one.
    const int testTime = 10000;
    std::this_thread::sleep_for(std::chrono::milliseconds(testTime));
    running.store(false);
    iterateTask.get();
    reserveTask.get();
    eraseTask.get();
}
}
}