/**
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Description: Test ClientManager class
*/
#include "datasystem/worker/client_manager/client_manager.h"
#include <gtest/gtest.h>
#include <sys/socket.h>
#include <atomic>
#include <chrono>
#include <string>
#include <thread>
#include "ut/common.h"
#include "datasystem/common/inject/inject_point.h"
#include "datasystem/common/rpc/unix_sock_fd.h"
#include "datasystem/common/string_intern/string_ref.h"
#include "datasystem/common/util/uuid_generator.h"
DS_DECLARE_uint32(max_client_num);
using namespace datasystem::worker;
namespace datasystem {
namespace ut {
/** Fixture for the tests that construct a ClientInfo directly; it adds nothing to CommonTest. */
class ClientInfoTest : public CommonTest {};
/** Fixture for the tests that go through ClientManager::Instance(); it adds nothing to CommonTest. */
class ClientManagerTest : public CommonTest {};
// RemoveShmUnit on a freshly constructed ClientInfo is expected to return false
// when it is handed an empty (null) ShmUnit pointer.
TEST_F(ClientInfoTest, TestInvalid)
{
    // Default-constructed shared_ptr: holds no ShmUnit at all.
    std::shared_ptr<ShmUnit> emptyUnit;

    int fd = 1;
    bool uniqueCountFlag = false;
    auto key = ClientKey::Intern("client1");
    ClientInfo info(fd, key, uniqueCountFlag);

    EXPECT_FALSE(info.RemoveShmUnit(emptyUnit));
}
// For a client added with the third AddClient argument set to true (presumably the
// unique-count flag, going by the test name), adding the same ShmUnit twice keeps
// refCount at 1 and a single removal brings it back to 0.
TEST_F(ClientManagerTest, TestAddShmUnitUniqueCount)
{
    auto &manager = ClientManager::Instance();
    DS_ASSERT_OK(manager.Init());

    int fd = 1;
    auto key = ClientKey::Intern(GetBytesUuid());
    DS_ASSERT_OK(manager.AddClient(key, fd, true));

    // The manager must report back the fd the client was registered with.
    int fetchedFd;
    DS_ASSERT_OK(manager.GetClientSocketFd(key, fetchedFd));
    ASSERT_TRUE(fd == fetchedFd);

    std::shared_ptr<ShmUnit> unit = std::make_shared<ShmUnit>();
    unit->id = ShmKey::Intern(GetBytesUuid());
    unit->fd = fd;

    // First add takes the reference.
    DS_ASSERT_OK(manager.AddShmUnit(key, unit));
    ASSERT_EQ(unit->refCount, 1);

    // Second add of the same unit must not bump the count again.
    DS_ASSERT_OK(manager.AddShmUnit(key, unit));
    ASSERT_EQ(unit->refCount, 1);

    // One removal releases it completely.
    DS_ASSERT_OK(manager.RemoveShmUnit(key, unit));
    ASSERT_EQ(unit->refCount, 0);
}
// For a client added with the third AddClient argument set to false, every AddShmUnit
// call on the same unit is counted: three adds give refCount 3 and a per-client count
// of 3, and each RemoveShmUnit takes one reference away. A further remove after the
// count reached 0 still returns OK and leaves refCount at 0.
TEST_F(ClientManagerTest, TestAddShmUnit)
{
    ClientManager &clientMgr = ClientManager::Instance();
    DS_ASSERT_OK(clientMgr.Init());
    auto clientId = ClientKey::Intern(GetBytesUuid());
    int socketFd = 1;
    DS_ASSERT_OK(clientMgr.AddClient(clientId, socketFd, false));

    // Initialised so that a failure message never prints an indeterminate value.
    int getSocketFd = -1;
    DS_ASSERT_OK(clientMgr.GetClientSocketFd(clientId, getSocketFd));
    ASSERT_EQ(socketFd, getSocketFd);

    std::shared_ptr<ClientInfo> clientInfo = clientMgr.GetClientInfo(clientId);
    ASSERT_TRUE(clientInfo != nullptr);

    auto shmUnit = std::make_shared<ShmUnit>();
    shmUnit->id = ShmKey::Intern(GetBytesUuid());
    shmUnit->fd = socketFd;

    DS_ASSERT_OK(clientMgr.AddShmUnit(clientId, shmUnit));
    DS_ASSERT_OK(clientMgr.AddShmUnit(clientId, shmUnit));
    DS_ASSERT_OK(clientMgr.AddShmUnit(clientId, shmUnit));
    // Exact check: the assertions below rely on the count being exactly 3 here.
    ASSERT_EQ(shmUnit->refCount, 3);

    // Look entries up through find() so the check itself can never insert a key.
    std::unordered_map<ShmKey, uint32_t> shmUnitIds;
    clientInfo->GetShmUnitIds(shmUnitIds);
    auto iter = shmUnitIds.find(shmUnit->id);
    ASSERT_TRUE(iter != shmUnitIds.end());
    ASSERT_EQ(iter->second, 3u);

    DS_ASSERT_OK(clientMgr.RemoveShmUnit(clientId, shmUnit));
    ASSERT_EQ(shmUnit->refCount, 2);
    clientInfo->GetShmUnitIds(shmUnitIds);
    iter = shmUnitIds.find(shmUnit->id);
    ASSERT_TRUE(iter != shmUnitIds.end());
    ASSERT_EQ(iter->second, 2u);

    DS_ASSERT_OK(clientMgr.RemoveShmUnit(clientId, shmUnit));
    DS_ASSERT_OK(clientMgr.RemoveShmUnit(clientId, shmUnit));
    ASSERT_EQ(shmUnit->refCount, 0);
    // The map is deliberately reused: the key was present before this call, so this
    // also checks that GetShmUnitIds does not leave stale entries behind.
    clientInfo->GetShmUnitIds(shmUnitIds);
    ASSERT_TRUE(shmUnitIds.find(shmUnit->id) == shmUnitIds.end());

    // Removing once more than was added is expected to succeed and stay at 0.
    DS_ASSERT_OK(clientMgr.RemoveShmUnit(clientId, shmUnit));
    ASSERT_EQ(shmUnit->refCount, 0);
}
// Adds one reader and one writer session id to a client's ClientInfo, checks that each
// set then holds exactly one entry, and removes both again.
TEST_F(ClientManagerTest, TestSessionQuantityChange)
{
    ClientManager &clientMgr = ClientManager::Instance();
    DS_ASSERT_OK(clientMgr.Init());
    auto clientId = ClientKey::Intern(GetBytesUuid());
    int socketFd = 1;
    DS_ASSERT_OK(clientMgr.AddClient(clientId, socketFd));

    auto clientInfoPtr = clientMgr.GetClientInfo(clientId);
    // Must be fatal: every statement below dereferences the pointer.
    ASSERT_NE(clientInfoPtr, nullptr);

    std::string readerSessionId = GetBytesUuid();
    ASSERT_TRUE(clientInfoPtr->AddReaderSessionId(readerSessionId));
    std::string writerSessionId = GetBytesUuid();
    ASSERT_TRUE(clientInfoPtr->AddWriterSessionId(writerSessionId));

    std::unordered_set<std::string> readerSessionIds;
    std::unordered_set<std::string> writerSessionIds;
    clientInfoPtr->GetReaderSessionIds(readerSessionIds);
    clientInfoPtr->GetWriterSessionIds(writerSessionIds);
    EXPECT_EQ(readerSessionIds.size(), size_t(1));
    EXPECT_EQ(writerSessionIds.size(), size_t(1));

    DS_ASSERT_OK(clientInfoPtr->RemoveReaderSessionId(readerSessionId));
    DS_ASSERT_OK(clientInfoPtr->RemoveWriterSessionId(writerSessionId));

    // Called only to exercise these members; their results are not checked.
    clientInfoPtr->GetClientId();
    clientInfoPtr->SetUniqueCount(false);
}
// CheckClientId must succeed for a client that was added and fail for a key that was
// never added to the manager in this test.
TEST_F(ClientManagerTest, TestCheckClientId)
{
    ClientManager &clientMgr = ClientManager::Instance();
    DS_ASSERT_OK(clientMgr.Init());
    auto clientId = ClientKey::Intern(GetBytesUuid());
    int socketFd = 1;
    DS_ASSERT_OK(clientMgr.AddClient(clientId, socketFd));

    DS_ASSERT_OK(clientMgr.CheckClientId(clientId));
    DS_ASSERT_NOT_OK(clientMgr.CheckClientId(ClientKey::Intern("123344")));
}
// Exercises the RemoveShmUnit overload that takes only the unit (no client key):
// after one AddShmUnit for a single client, that call must leave refCount at 0.
TEST_F(ClientManagerTest, TestRemoveShmUnitOfAllClient)
{
    auto &manager = ClientManager::Instance();
    DS_ASSERT_OK(manager.Init());

    int fd = 1;
    auto key = ClientKey::Intern(GetBytesUuid());
    DS_ASSERT_OK(manager.AddClient(key, fd, true));

    // The manager must report back the fd the client was registered with.
    int fetchedFd;
    DS_ASSERT_OK(manager.GetClientSocketFd(key, fetchedFd));
    ASSERT_TRUE(fd == fetchedFd);

    std::shared_ptr<ShmUnit> unit = std::make_shared<ShmUnit>();
    unit->id = ShmKey::Intern(GetBytesUuid());
    unit->fd = fd;

    DS_ASSERT_OK(manager.AddShmUnit(key, unit));
    ASSERT_GE(unit->refCount, 1);

    // Client-less removal.
    DS_ASSERT_OK(manager.RemoveShmUnit(unit));
    ASSERT_EQ(unit->refCount, 0);
}
// Registers a lost-client handler of type UDS_HEARTBEAT for a client bound to a real
// Unix domain socket, shuts the socket down, and expects the handler to have run
// within 100 ms.
TEST_F(ClientManagerTest, TestRunCallbackByUDSHeartbeat)
{
    ClientManager &clientMgr = ClientManager::Instance();
    DS_ASSERT_OK(clientMgr.Init());
    uint32_t lockId = 0;
    auto clientId = ClientKey::Intern("clientId");

    // Bind a socket to a fresh path under /tmp; remove any leftover file first.
    std::string sockPath = "/tmp/" + GetStringUuid();
    unlink(sockPath.c_str());
    UnixSockFd sockFd;
    sockaddr_un addr{};
    DS_ASSERT_OK(sockFd.CreateUnixSocket());
    DS_ASSERT_OK(UnixSockFd::SetUpSockPath(sockPath, addr));
    DS_ASSERT_OK(sockFd.BindUds(addr, RPC_SOCK_MODE));
    int socketFd = sockFd.GetFd();
    // NOTE(review): other tests in this file pass `true` as the fifth AddClient argument,
    // while a string literal is passed here - confirm which overload this resolves to.
    ASSERT_TRUE(clientMgr.AddClient(clientId, true, socketFd, "", "", "", "", lockId));

    // Atomic because the handler is presumably invoked from a ClientManager thread
    // while this thread reads the flag (the RPC heartbeat test does the same).
    std::atomic<bool> called{ false };
    auto callback = [&called]() {
        LOG(INFO) << "run callback";
        called = true;
    };
    ASSERT_TRUE(clientMgr.RegisterLostHandler(clientId, callback, HeartbeatType::UDS_HEARTBEAT));
    ASSERT_FALSE(called);

    const int ret = shutdown(socketFd, SHUT_RDWR);
    // Give the manager time to notice the shutdown and run the handler.
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
    const bool callbackRan = called.load();

    // Clean up before asserting, so a failed expectation leaves neither the client
    // (whose handler references `called`) nor the socket file behind.
    clientMgr.RemoveClient(clientId);
    unlink(sockPath.c_str());

    ASSERT_EQ(ret, 0);
    ASSERT_TRUE(callbackRan);
}
// (Disabled) Registers a lost-client handler of type RPC_HEARTBEAT for a client with no
// socket (fd -1), waits two seconds without sending any heartbeat, and expects the
// client to be reported lost and the handler to have run.
TEST_F(ClientManagerTest, DISABLED_TestRunCallbackByRPCHeartbeat)
{
    // Inject-point overrides for the heartbeat interval and threshold.
    // NOTE(review): these are not cleared at the end of the test - confirm whether the
    // fixture resets inject points between tests.
    datasystem::inject::Set("ClientManager.Init.heartbeatInterval", "call(500)");
    datasystem::inject::Set("ClientManager.IsClientLost.heartbeatThreshold", "call(1)");
    ClientManager &clientMgr = ClientManager::Instance();
    DS_ASSERT_OK(clientMgr.Init());
    uint32_t lockId = 0;
    auto clientId = ClientKey::Intern("clientId");
    // NOTE(review): other tests in this file pass `true` as the fifth AddClient argument,
    // while a string literal is passed here - confirm which overload this resolves to.
    ASSERT_TRUE(clientMgr.AddClient(clientId, true, -1, "", "", "", "", lockId));

    std::atomic<bool> called{ false };
    auto callback = [&called]() {
        LOG(INFO) << "run callback";
        called = true;
    };
    ASSERT_TRUE(clientMgr.RegisterLostHandler(clientId, callback, HeartbeatType::RPC_HEARTBEAT));
    ASSERT_FALSE(called);

    std::this_thread::sleep_for(std::chrono::seconds(2));

    // Collect the results, then clean up before asserting so a failed expectation does
    // not leave the client (whose handler references `called`) registered.
    auto clientInfo = clientMgr.GetClientInfo(clientId);
    const bool found = (clientInfo != nullptr);
    const bool lost = found && clientInfo->IsClientLost();
    const bool callbackRan = called.load();
    clientMgr.RemoveClient(clientId);

    // Checked first: without it a failed lookup would have been a null dereference.
    ASSERT_TRUE(found);
    ASSERT_TRUE(lost);
    ASSERT_TRUE(callbackRan);
}
// Adding an already registered client id must fail with "Failed to insert client" every
// time. The loop runs 20 times against a limit of 10, presumably to show that rejected
// duplicates do not count towards the client limit.
TEST_F(ClientManagerTest, TestAddClientFailed)
{
    // FLAGS_max_client_num is process-wide; put the previous value back when the test
    // body exits (including early returns from a failed ASSERT_*), so the override
    // cannot leak into other tests.
    struct FlagGuard {
        decltype(FLAGS_max_client_num) saved;
        ~FlagGuard()
        {
            FLAGS_max_client_num = saved;
        }
    } flagGuard{ FLAGS_max_client_num };
    FLAGS_max_client_num = 10;

    ClientManager &clientMgr = ClientManager::Instance();
    DS_ASSERT_OK(clientMgr.Init());
    auto clientId = ClientKey::Intern(GetBytesUuid());
    uint32_t lockId = 0;
    DS_ASSERT_OK(clientMgr.AddClient(clientId, true, -1, "", true, "", "", lockId));

    for (int i = 0; i < 20; i++) {
        auto status = clientMgr.AddClient(clientId, true, -1, "", true, "", "", lockId);
        ASSERT_TRUE(status.GetMsg().find("Failed to insert client") != std::string::npos) << status.GetMsg();
    }
}
// With the client limit set to 1: the first client added with the second AddClient
// argument true (presumably "shm enabled", going by the test name) succeeds and gets a
// non-zero lock id, a second such client is rejected with "Client number upper to the
// limit", and ten clients added with that argument false all succeed with lock id 0.
TEST_F(ClientManagerTest, TestMaxClientNumOnlyCountsShmEnabled)
{
    // FLAGS_max_client_num is process-wide; put the previous value back when the test
    // body exits (including early returns from a failed ASSERT_*), so the override
    // cannot leak into other tests.
    struct FlagGuard {
        decltype(FLAGS_max_client_num) saved;
        ~FlagGuard()
        {
            FLAGS_max_client_num = saved;
        }
    } flagGuard{ FLAGS_max_client_num };
    FLAGS_max_client_num = 1;

    ClientManager &clientMgr = ClientManager::Instance();
    DS_ASSERT_OK(clientMgr.Init());

    // First client with the flag set: accepted.
    auto shmClientId1 = ClientKey::Intern("shm_client_1");
    uint32_t lockId1 = 0;
    DS_ASSERT_OK(clientMgr.AddClient(shmClientId1, true, -1, "", true, "", "", lockId1));
    ASSERT_NE(lockId1, 0u);

    // Second client with the flag set: over the limit.
    auto shmClientId2 = ClientKey::Intern("shm_client_2");
    uint32_t lockId2 = 0;
    auto rc = clientMgr.AddClient(shmClientId2, true, -1, "", true, "", "", lockId2);
    ASSERT_TRUE(rc.IsError());
    ASSERT_TRUE(rc.GetMsg().find("Client number upper to the limit") != std::string::npos) << rc.ToString();

    // Clients with the flag cleared are not limited; the lock id is seeded with a
    // non-zero value to prove AddClient resets it to 0.
    for (int i = 0; i < 10; ++i) {
        auto nonShmClientId = ClientKey::Intern("non_shm_client_" + std::to_string(i));
        uint32_t nonShmLockId = 123;
        DS_ASSERT_OK(clientMgr.AddClient(nonShmClientId, false, -1, "", true, "", "", nonShmLockId));
        ASSERT_EQ(nonShmLockId, 0u);
    }
}
// Checks how GetClientCount reacts to UpdateLastHeartbeat(key, flag) and RemoveClient:
// clients updated with flag == true stop being counted, clients updated with
// flag == false are counted again, and repeating either update changes nothing further.
// Removing a client that is already excluded leaves the count alone, while removing a
// counted client lowers it.
TEST_F(ClientManagerTest, TestRemovableClientCount)
{
    auto &manager = ClientManager::Instance();
    DS_ASSERT_OK(manager.Init());

    // Interned key of the idx-th client used throughout this test.
    auto keyFor = [](size_t idx) { return ClientKey::Intern("client_id" + std::to_string(idx)); };

    size_t total = 100;
    uint32_t lockId;

    // Register every client once ...
    for (size_t idx = 0; idx < total; ++idx) {
        auto key = keyFor(idx);
        ASSERT_TRUE(manager.AddClient(key, true, -1, "", true, "", "", lockId));
    }
    // ... and verify that a second registration of the same ids is refused.
    for (size_t idx = 0; idx < total; ++idx) {
        auto key = keyFor(idx);
        ASSERT_FALSE(manager.AddClient(key, true, -1, "", true, "", "", lockId));
    }
    ASSERT_EQ(manager.GetClientCount(), total);

    size_t flagged = 30;
    int rounds = 3;

    // Flag the first `flagged` clients several times over: the count drops once only.
    for (int round = 0; round < rounds; ++round) {
        for (size_t idx = 0; idx < flagged; ++idx) {
            auto key = keyFor(idx);
            ASSERT_TRUE(manager.UpdateLastHeartbeat(key, true));
        }
        ASSERT_EQ(manager.GetClientCount(), total - flagged);
    }

    // Clear the flag several times over: the count is restored once only.
    for (int round = 0; round < rounds; ++round) {
        for (size_t idx = 0; idx < flagged; ++idx) {
            auto key = keyFor(idx);
            ASSERT_TRUE(manager.UpdateLastHeartbeat(key, false));
        }
        ASSERT_EQ(manager.GetClientCount(), total);
    }

    // Flag them again, then remove them: removal must not lower the count a second time.
    for (size_t idx = 0; idx < flagged; ++idx) {
        auto key = keyFor(idx);
        ASSERT_TRUE(manager.UpdateLastHeartbeat(key, true));
    }
    ASSERT_EQ(manager.GetClientCount(), total - flagged);
    for (size_t idx = 0; idx < flagged; ++idx) {
        auto key = keyFor(idx);
        manager.RemoveClient(key);
    }
    ASSERT_EQ(manager.GetClientCount(), total - flagged);

    // Removing the next `flagged` clients, which are still counted, lowers the count.
    size_t removedSoFar = flagged * 2;
    for (size_t idx = flagged; idx < removedSoFar; ++idx) {
        auto key = keyFor(idx);
        manager.RemoveClient(key);
    }
    ASSERT_EQ(manager.GetClientCount(), total - removedSoFar);
}
}
}