* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
* Description: Testing read-write lock
*/
#include <cstdlib>
#include <string>
#include <vector>
#include <securec.h>
#include "datasystem/common/object_cache/lock.h"
#include "ut/common.h"
// Single-character selectors used by the RWLockTestObject helpers to choose which latch scenario to run.
static constexpr char READ{ 'r' };
static constexpr char WRITE{ 'w' };
namespace datasystem {
namespace ut {
// Fixture for the read-write lock test cases; it adds no members of its own on top of CommonTest.
class RWLockTest : public CommonTest {
};
class RWLockTestObject {
public:
explicit RWLockTestObject(int incTimes)
: incTimes_(incTimes), criticalVal_(0), readCounter_(0), lockFrame_(nullptr), lk_(nullptr)
{
size_t size = sizeof(uint32_t) + sizeof(uint8_t);
lockFrame_ = malloc(size);
memset_s(lockFrame_, size, 0, size);
lk_ = std::make_shared<datasystem::object_cache::ShmLock>(lockFrame_, size, 0);
lk_->Init();
}
RWLockTestObject(RWLockTestObject &&other) noexcept
: incTimes_(other.incTimes_),
criticalVal_(other.criticalVal_),
readCounter_(other.readCounter_),
lockFrame_(other.lockFrame_),
lk_(other.lk_)
{
}
RWLockTestObject(const RWLockTestObject &other) = delete;
~RWLockTestObject()
{
if (lockFrame_) {
free(lockFrame_);
}
}
void IncreaseCriticalVal()
{
for (int i = 0; i < incTimes_; i++) {
DS_ASSERT_OK(lk_->WLatch(60));
std::this_thread::sleep_for(std::chrono::milliseconds(1));
criticalVal_++;
lk_->UnWLatch();
}
}
void ReadCriticalVal()
{
DS_ASSERT_OK(lk_->RLatch(60));
__atomic_add_fetch(&readCounter_, 1, __ATOMIC_SEQ_CST);
std::this_thread::sleep_for(std::chrono::milliseconds(1));
lk_->UnRLatch();
}
int GetCriticalValue()
{
return criticalVal_;
}
int GetReaderCounter()
{
return readCounter_;
}
int GetLockFrameVal()
{
return *(static_cast<int *>(lockFrame_));
}
void UnmatchedLockUnlock(char ops)
{
if (ops == WRITE) {
DS_ASSERT_OK(lk_->WLatch(60));
lk_->UnRLatch();
}
if (ops == READ) {
lk_->UnWLatch();
DS_ASSERT_OK(lk_->RLatch(60));
lk_->UnWLatch();
}
}
void LockValueOverflowFail(char ops)
{
if (ops == 'w') {
DS_ASSERT_OK(lk_->RLatch(60));
lk_->UnRLatch();
lk_->UnRLatch();
}
if (ops == 'r') {
DS_ASSERT_OK(lk_->WLatch(60));
lk_->UnWLatch();
lk_->UnWLatch();
}
}
void ThreadALock()
{
for (int i = 0; i < incTimes_; i++) {
DS_ASSERT_OK(lk_->RLatch(60));
}
}
void ThreadBUnLock()
{
lk_->UnRLatch();
}
private:
int incTimes_;
int criticalVal_;
int readCounter_;
void *lockFrame_;
std::shared_ptr<datasystem::object_cache::Lock> lk_;
};
// Starts numWriters threads that each run IncreaseCriticalVal(); afterwards the shared value must equal
// numWriters * incTimes and the value read from the lock frame must be 0.
TEST_F(RWLockTest, ConcurrentWriteSuccess)
{
    LOG(INFO) << "Test whether the rwlock can protect the critical section from concurrent writing";
    int numWriters = 100;
    int incTimes = 100;
    RWLockTestObject lockObj(incTimes);

    std::vector<std::thread> writers;
    for (int idx = 0; idx < numWriters; ++idx) {
        writers.emplace_back(&RWLockTestObject::IncreaseCriticalVal, &lockObj);
    }
    for (auto &writer : writers) {
        writer.join();
    }

    int criticalVal = lockObj.GetCriticalValue();
    int lockVal = lockObj.GetLockFrameVal();
    EXPECT_EQ(criticalVal, numWriters * incTimes);
    EXPECT_EQ(lockVal, 0);
}
// Starts numReaders threads that each run ReadCriticalVal() once; afterwards the reader counter must equal
// numReaders and the value read from the lock frame must be 0.
TEST_F(RWLockTest, ConcurrentReadSuccess)
{
    LOG(INFO) << "Test whether the rwlock can allow concurrent read of the critical section";
    int numReaders = 100;
    int incTimes = 1;
    RWLockTestObject lockObj(incTimes);

    std::vector<std::thread> readers;
    for (int idx = 0; idx < numReaders; ++idx) {
        readers.emplace_back(&RWLockTestObject::ReadCriticalVal, &lockObj);
    }
    for (auto &reader : readers) {
        reader.join();
    }

    int readCnt = lockObj.GetReaderCounter();
    int lockVal = lockObj.GetLockFrameVal();
    EXPECT_EQ(readCnt, numReaders);
    EXPECT_EQ(lockVal, 0);
}
// Spawns writer and reader threads alternately (one writer, then one reader, per round) and checks the
// write total, the reader count and the final value read from the lock frame.
TEST_F(RWLockTest, ConcurrentReadWriteSuccess)
{
    LOG(INFO) << "Test the effectiveness of rwlock in the scenario of concurrent read and write interleaving";
    int numReaders = 100;
    int numWriters = 100;
    int incTimes = 100;
    RWLockTestObject lockObj(incTimes);

    std::vector<std::thread> workers;
    // A single round counter stands in for the two lock-step indices; a round keeps spawning whichever kind
    // of thread still has quota left.
    for (int round = 0; round < numWriters || round < numReaders; ++round) {
        if (round < numWriters) {
            workers.emplace_back(&RWLockTestObject::IncreaseCriticalVal, &lockObj);
        }
        if (round < numReaders) {
            workers.emplace_back(&RWLockTestObject::ReadCriticalVal, &lockObj);
        }
    }
    for (auto &worker : workers) {
        worker.join();
    }

    int criticalVal = lockObj.GetCriticalValue();
    int readCnt = lockObj.GetReaderCounter();
    int lockVal = lockObj.GetLockFrameVal();
    EXPECT_EQ(criticalVal, numWriters * incTimes);
    EXPECT_EQ(readCnt, numReaders);
    EXPECT_EQ(lockVal, 0);
}
// Runs the mismatched latch/unlatch sequences on one object, first with 'w' and then with 'r', and checks
// the value read from the lock frame after each step (1, then 2).
TEST_F(RWLockTest, UnmatchedLockUnlockFail)
{
    LOG(INFO) << "Test unmatched Lock call fail";
    int incTimes = 1;
    RWLockTestObject lockObj(incTimes);
    int lockVal = 0;

    lockObj.UnmatchedLockUnlock('w');
    lockVal = lockObj.GetLockFrameVal();
    EXPECT_EQ(lockVal, 1);

    lockObj.UnmatchedLockUnlock('r');
    lockVal = lockObj.GetLockFrameVal();
    EXPECT_EQ(lockVal, 2);
}
// Latches once and unlatches twice for each selector; the value read from the lock frame is expected to be
// 0 after both sequences.
TEST_F(RWLockTest, LockValueOverflowFail)
{
    LOG(INFO) << "Test lock value overflow fail";
    int incTimes = 1;
    RWLockTestObject lockObj(incTimes);
    int lockVal = 0;

    lockObj.LockValueOverflowFail(WRITE);
    lockVal = lockObj.GetLockFrameVal();
    EXPECT_EQ(lockVal, 0);

    lockObj.LockValueOverflowFail(READ);
    lockVal = lockObj.GetLockFrameVal();
    EXPECT_EQ(lockVal, 0);
}
// One thread takes the read latch incTimes times while numReaders other threads each call read-unlatch
// once; the value read from the lock frame is expected to be 2 * incTimes afterwards.
// NOTE(review): presumably each held read latch adds 2 to the lock word and unlatches from non-owning
// threads are rejected -- confirm against the ShmLock implementation.
TEST_F(RWLockTest, ThreadAUnlockThreadBFail)
{
    LOG(INFO) << "Test thread A unlocks thread B fail";
    int numReaders = 100;
    int incTimes = 100;
    RWLockTestObject lockObj(incTimes);

    std::thread locker(&RWLockTestObject::ThreadALock, &lockObj);
    std::vector<std::thread> unlockers;
    for (int idx = 0; idx < numReaders; ++idx) {
        unlockers.emplace_back(&RWLockTestObject::ThreadBUnLock, &lockObj);
    }

    locker.join();
    for (auto &unlocker : unlockers) {
        unlocker.join();
    }

    int lockVal = lockObj.GetLockFrameVal();
    EXPECT_EQ(lockVal, 2 * incTimes);
}
}
}