* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
* Description: Defines object reference info class ObjectMemoryRefTable class and ObjectGlobalRefTable.
*/
#include "datasystem/common/object_cache/object_ref_info.h"
#include <algorithm>
#include <chrono>
#include <cmath>
#include "datasystem/common/immutable_string/immutable_string.h"
#include "datasystem/common/inject/inject_point.h"
#include "datasystem/common/log/log.h"
#include "datasystem/common/metrics/kv_metrics.h"
#include "datasystem/common/rpc/rpc_constants.h"
#include "datasystem/common/shared_memory/allocator.h"
#include "datasystem/common/string_intern/string_ref.h"
#include "datasystem/common/util/format.h"
#include "datasystem/common/util/raii.h"
#include "datasystem/common/util/strings_util.h"
#include "datasystem/common/util/thread_local.h"
#include "datasystem/common/util/timer.h"
#include "datasystem/common/util/uuid_generator.h"
#include "datasystem/utils/status.h"
namespace datasystem {
namespace object_cache {
namespace {
constexpr uint64_t SHM_REF_RECONCILE_MIN_EXPIRE_MS = 1000UL;
constexpr uint64_t SHM_REF_RECONCILE_MAX_EXPIRE_MS = RPC_TIMEOUT;
constexpr uint64_t SHM_REF_FLUSH_INTERVAL_MS = 1000UL;
constexpr size_t MAX_MAYBE_EXPIRED_SHM_IDS_PER_RECONCILE = 1024;
uint64_t GetCurrentTimeMs()
{
INJECT_POINT("shm_ref.GetCurrentTimeMs", [](uint64_t value) { return value; });
return static_cast<uint64_t>(GetSteadyClockTimeStampMs());
}
uint64_t GetMaybeExpiredDeadlineMs(int64_t requestTimeoutMs)
{
auto expireMs = std::min<uint64_t>(std::max<int64_t>(requestTimeoutMs, SHM_REF_RECONCILE_MIN_EXPIRE_MS),
SHM_REF_RECONCILE_MAX_EXPIRE_MS);
return GetCurrentTimeMs() + expireMs;
}
}
SharedMemoryRefTable::SharedMemoryRefTable()
{
if (maybeExpiredFlushThread_ != nullptr) {
return;
}
maybeExpiredFlushExit_ = false;
maybeExpiredFlushPost_.Clear();
maybeExpiredFlushThread_ = std::make_unique<Thread>([this]() {
while (!maybeExpiredFlushExit_) {
(void)maybeExpiredFlushPost_.WaitFor(SHM_REF_FLUSH_INTERVAL_MS);
if (maybeExpiredFlushExit_) {
break;
}
FlushMaybeExpiredQueue(GetCurrentTimeMs());
}
});
}
SharedMemoryRefTable::~SharedMemoryRefTable()
{
if (maybeExpiredFlushThread_ == nullptr) {
return;
}
maybeExpiredFlushExit_ = true;
maybeExpiredFlushPost_.Set();
maybeExpiredFlushThread_->join();
}
Status SharedMemoryRefTable::GetShmUnit(const ShmKey &shmId, std::shared_ptr<ShmUnit> &shmUnit)
{
TbbMemoryObjectRefTable::const_accessor shmAccessor;
auto found = shmRefTable_.find(shmAccessor, shmId);
if (!found) {
RETURN_STATUS(K_NOT_FOUND, FormatString("Get a not found shm: %s ", shmId));
}
shmUnit = shmAccessor->second.first;
return Status::OK();
}
void SharedMemoryRefTable::ClientTableGetOrInsert(const ClientKey &clientId,
TbbMemoryClientRefTable::const_accessor &accessor)
{
while (!clientRefTable_.find(accessor, clientId)) {
TbbMemoryClientRefTable::accessor writeAccessor;
if (clientRefTable_.insert(writeAccessor, clientId)) {
auto clientInfo = std::make_shared<ObjectRefInfo<ShmKey>>();
writeAccessor->second = std::move(clientInfo);
}
}
}
void SharedMemoryRefTable::InitShmRefForClient(const ClientKey &clientId, bool supportMultiShmRefCount)
{
LOG(INFO) << "Initializing shared memory reference for client " << clientId
<< ", supportMultiShmRefCount: " << supportMultiShmRefCount;
TbbMemoryClientRefTable::accessor accessor;
if (clientRefTable_.insert(accessor, clientId)) {
auto clientInfo = std::make_shared<ObjectRefInfo<ShmKey>>(!supportMultiShmRefCount);
accessor->second = std::move(clientInfo);
} else {
LOG(WARNING) << "Client " << clientId << " already exists";
}
}
void SharedMemoryRefTable::AddShmUnit(const ClientKey &clientId, std::shared_ptr<ShmUnit> &shmUnit,
int64_t requestTimeoutMs)
{
const auto &shmId = shmUnit->GetId();
TbbMemoryClientRefTable::const_accessor clientAccessor;
TbbMemoryObjectRefTable::accessor objectAccessor;
ClientTableGetOrInsert(clientId, clientAccessor);
if (clientAccessor->second->AddRef(shmId)) {
shmUnit->IncrementRefCount();
}
if (!shmRefTable_.find(objectAccessor, shmId)) {
shmRefTable_.emplace(objectAccessor, shmId, std::make_pair(shmUnit, std::unordered_set<ClientKey>()));
}
objectAccessor->second.second.emplace(clientId);
if (shmUnit->GetRefCount() == 1) {
datasystem::memory::Allocator::Instance()->ChangeNoRefPageCount(-1);
datasystem::memory::Allocator::Instance()->ChangeRefPageCount(1);
metrics::GetGauge(static_cast<uint16_t>(metrics::KvMetricId::WORKER_SHM_REF_TABLE_BYTES))
.Inc(static_cast<int64_t>(shmUnit->size));
}
RecordMaybeExpiredShm(clientId, shmId, requestTimeoutMs);
METRIC_INC(metrics::KvMetricId::WORKER_SHM_REF_ADD_TOTAL);
metrics::GetGauge(static_cast<uint16_t>(metrics::KvMetricId::WORKER_SHM_REF_TABLE_SIZE))
.Set(static_cast<int64_t>(shmRefTable_.size()));
VLOG(1) << "AddShmUnit for shmid: " << shmUnit->id << " client id: " << clientId;
}
void SharedMemoryRefTable::AddShmUnits(TbbMemoryClientRefTable::const_accessor &clientAccessor,
std::vector<std::shared_ptr<ShmUnit>> &shmUnits, int64_t requestTimeoutMs)
{
TbbMemoryObjectRefTable::accessor objectAccessor;
const ClientKey &clientId = clientAccessor->first;
uint64_t addedCount = 0;
uint64_t firstRefBytes = 0;
for (auto &shmUnit : shmUnits) {
if (shmUnit == nullptr) {
continue;
}
const auto &shmId = shmUnit->GetId();
if (clientAccessor->second->AddRef(shmId)) {
shmUnit->IncrementRefCount();
}
if (!shmRefTable_.find(objectAccessor, shmId)) {
shmRefTable_.emplace(objectAccessor, shmId,
std::make_pair(shmUnit, std::unordered_set<ClientKey>{ clientId }));
} else {
objectAccessor->second.second.emplace(clientId);
}
objectAccessor.release();
if (shmUnit->GetRefCount() == 1) {
datasystem::memory::Allocator::Instance()->ChangeNoRefPageCount(-1);
datasystem::memory::Allocator::Instance()->ChangeRefPageCount(1);
firstRefBytes += shmUnit->size;
}
RecordMaybeExpiredShm(clientId, shmId, requestTimeoutMs);
++addedCount;
VLOG(1) << "AddShmUnit for shmid: " << shmUnit->id << " client id: " << clientId;
}
if (addedCount > 0) {
METRIC_ADD(metrics::KvMetricId::WORKER_SHM_REF_ADD_TOTAL, addedCount);
metrics::GetGauge(static_cast<uint16_t>(metrics::KvMetricId::WORKER_SHM_REF_TABLE_SIZE))
.Set(static_cast<int64_t>(shmRefTable_.size()));
}
if (firstRefBytes > 0) {
metrics::GetGauge(static_cast<uint16_t>(metrics::KvMetricId::WORKER_SHM_REF_TABLE_BYTES))
.Inc(static_cast<int64_t>(firstRefBytes));
}
}
Status SharedMemoryRefTable::RemoveShmUnit(const ClientKey &clientId, const ShmKey &shmId)
{
TbbMemoryClientRefTable::accessor clientAccessor;
TbbMemoryObjectRefTable::accessor shmAccessor;
auto found = clientRefTable_.find(clientAccessor, clientId);
if (!found) {
LOG(WARNING) << FormatString("The clientId not exists in clientRefTable_, clientId: %s, shmId: %s", clientId,
shmId);
return Status::OK();
}
if (!shmRefTable_.find(shmAccessor, shmId)) {
LOG(WARNING) << FormatString("The shmId not exists in shmRefTable_, clientId: %s, shmId: %s", clientId, shmId);
return Status::OK();
}
auto shmUnit = shmAccessor->second.first;
#ifdef WITH_TESTS
INJECT_POINT("RemoveShmUnit");
#endif
RemoveShmUnitDetail(clientId, false, shmAccessor, clientAccessor);
ClearMaybeExpiredShmIds(clientId, { shmId });
METRIC_INC(metrics::KvMetricId::WORKER_SHM_REF_REMOVE_TOTAL);
metrics::GetGauge(static_cast<uint16_t>(metrics::KvMetricId::WORKER_SHM_REF_TABLE_SIZE))
.Set(static_cast<int64_t>(shmRefTable_.size()));
VLOG(1) << "RemoveShmUnit for shmid: " << shmId << " client id: " << clientId;
return Status::OK();
}
void SharedMemoryRefTable::RemoveShmUnitDetail(const ClientKey &clientId, bool forceRemoveClient,
TbbMemoryObjectRefTable::accessor &shmAccessor,
TbbMemoryClientRefTable::accessor &clientAccessor)
{
const auto &shmUnit = shmAccessor->second.first;
const auto shmId = shmUnit->GetId();
uint32_t refCount = 0;
if (clientAccessor->second->RemoveAndGetRefCnt(shmId, refCount)) {
if (shmUnit->GetRefCount() > 0) {
shmUnit->DecrementRefCount();
} else {
LOG(WARNING) << "RemoveShmUnit: The value of refCount is 0 and cannot be decreased. id:"
<< BytesUuidToString(shmId.ToString());
}
}
if (shmUnit->GetRefCount() == 0) {
datasystem::memory::Allocator::Instance()->ChangeNoRefPageCount(1);
datasystem::memory::Allocator::Instance()->ChangeRefPageCount(-1);
metrics::GetGauge(static_cast<uint16_t>(metrics::KvMetricId::WORKER_SHM_REF_TABLE_BYTES))
.Dec(static_cast<int64_t>(shmUnit->size));
}
if (!forceRemoveClient && refCount > 0) {
return;
}
auto &clientIds = shmAccessor->second.second;
(void)clientIds.erase(clientId);
if (clientIds.empty()) {
VLOG(1) << FormatString("Erase shmId: %s", shmAccessor->first);
if (!shmRefTable_.erase(shmAccessor)) {
LOG(ERROR) << FormatString("Fail to erase shmId: %s", shmAccessor->first);
}
}
}
#ifdef WITH_TESTS
bool SharedMemoryRefTable::Contains(const ClientKey &clientId, const ShmKey &shmId) const
{
TbbMemoryClientRefTable::accessor accessor;
if (clientRefTable_.find(accessor, clientId)) {
return accessor->second->Contains(shmId);
}
return false;
}
#endif
void SharedMemoryRefTable::GetClientRefIds(const ClientKey &clientId, std::vector<ShmKey> &shmIds) const
{
TbbMemoryClientRefTable::accessor accessor;
if (clientRefTable_.find(accessor, clientId)) {
accessor->second->GetRefIds(shmIds);
}
}
void SharedMemoryRefTable::RecordMaybeExpiredShm(const ClientKey &clientId, const ShmKey &shmId,
int64_t requestTimeoutMs)
{
auto expireTimeMs = GetMaybeExpiredDeadlineMs(requestTimeoutMs);
std::lock_guard<std::shared_mutex> lock(maybeExpiredShmQueueMutex_);
maybeExpiredShmQueue_.push({ requestTimeoutMs, expireTimeMs, clientId, shmId });
VLOG(1) << "RecordMaybeExpiredShm: clientId=" << clientId << " shmId=" << shmId
<< " requestTimeoutMs=" << requestTimeoutMs << " expireTimeMs=" << expireTimeMs;
}
void SharedMemoryRefTable::GetMaybeExpiredShmIds(const ClientKey &clientId, std::vector<ShmKey> &shmIds)
{
shmIds.clear();
std::shared_lock<std::shared_mutex> lock(maybeExpiredShmTableMutex_);
auto iter = maybeExpiredShmTable_.find(clientId);
if (iter == maybeExpiredShmTable_.end()) {
return;
}
auto count = std::min(iter->second.size(), MAX_MAYBE_EXPIRED_SHM_IDS_PER_RECONCILE);
shmIds.reserve(count);
auto shmIter = iter->second.begin();
for (size_t i = 0; i < count; ++i, ++shmIter) {
shmIds.emplace_back(*shmIter);
}
}
void SharedMemoryRefTable::ClearMaybeExpiredShmIds(const ClientKey &clientId, const std::vector<ShmKey> &shmIds)
{
if (shmIds.empty()) {
return;
}
std::lock_guard<std::shared_mutex> lock(maybeExpiredShmTableMutex_);
auto iter = maybeExpiredShmTable_.find(clientId);
if (iter == maybeExpiredShmTable_.end()) {
return;
}
for (const auto &shmId : shmIds) {
(void)iter->second.erase(shmId);
}
if (iter->second.empty()) {
maybeExpiredShmTable_.erase(iter);
}
}
void SharedMemoryRefTable::ReconcileClientShmRefs(const ClientKey &clientId,
const std::vector<ShmKey> &confirmedExpiredShmIds,
std::vector<ShmKey> &maybeExpiredShmIds)
{
for (const auto &shmId : confirmedExpiredShmIds) {
LOG(INFO) << "confirmed expired, try decreasing ref count for " << shmId << ", client: " << clientId;
Status rc = RemoveShmUnit(clientId, shmId);
if (rc.IsError()) {
LOG(WARNING) << FormatString("[ObjectKey %s] DoDecrease failed, error: %s", shmId, rc.ToString());
}
}
ClearMaybeExpiredShmIds(clientId, confirmedExpiredShmIds);
GetMaybeExpiredShmIds(clientId, maybeExpiredShmIds);
}
Status SharedMemoryRefTable::RemoveClient(const ClientKey &clientId)
{
TbbMemoryClientRefTable::accessor clientAccessor;
if (!clientRefTable_.find(clientAccessor, clientId)) {
return Status::OK();
}
std::vector<ShmKey> shmIds;
clientAccessor->second->GetRefIds(shmIds);
uint64_t removedCount = 0;
for (const auto &shmId : shmIds) {
TbbMemoryObjectRefTable::accessor shmAccessor;
if (!shmRefTable_.find(shmAccessor, shmId)) {
continue;
}
RemoveShmUnitDetail(clientId, true, shmAccessor, clientAccessor);
++removedCount;
}
if (removedCount > 0) {
METRIC_ADD(metrics::KvMetricId::WORKER_SHM_REF_REMOVE_TOTAL, removedCount);
METRIC_ADD(metrics::KvMetricId::WORKER_REMOVE_CLIENT_REFS_TOTAL, removedCount);
metrics::GetGauge(static_cast<uint16_t>(metrics::KvMetricId::WORKER_SHM_REF_TABLE_SIZE))
.Set(static_cast<int64_t>(shmRefTable_.size()));
}
clientRefTable_.erase(clientAccessor);
std::lock_guard<std::shared_mutex> lock(maybeExpiredShmTableMutex_);
maybeExpiredShmTable_.erase(clientId);
return Status::OK();
}
void SharedMemoryRefTable::FlushMaybeExpiredQueue(uint64_t nowMs)
{
std::vector<MaybeExpiredShmItem> expiredItems;
size_t queueSize = 0;
{
std::lock_guard<std::shared_mutex> lockForQueue(maybeExpiredShmQueueMutex_);
queueSize = maybeExpiredShmQueue_.size();
while (!maybeExpiredShmQueue_.empty() && maybeExpiredShmQueue_.top().expireTimeMs <= nowMs) {
auto item = maybeExpiredShmQueue_.top();
expiredItems.push_back(item);
maybeExpiredShmQueue_.pop();
}
}
constexpr int logIntervalSec = 120;
LOG_EVERY_T(INFO, logIntervalSec) << "The size of maybeExpiredShmQueue_ is " << queueSize
<< ", the size of shmRefTable_ is " << shmRefTable_.size();
std::vector<MaybeExpiredShmItem> failedItems;
for (const auto &item : expiredItems) {
if (!ClientContainsShm(item.clientId, item.shmId)) {
continue;
}
std::lock_guard<std::shared_mutex> lockForTable(maybeExpiredShmTableMutex_);
auto &shmIds = maybeExpiredShmTable_[item.clientId];
if (!shmIds.emplace(item.shmId).second) {
failedItems.emplace_back(item);
}
}
for (const auto &item : failedItems) {
RecordMaybeExpiredShm(item.clientId, item.shmId, item.requestTimeoutMs);
}
}
bool SharedMemoryRefTable::ClientContainsShm(const ClientKey &clientId, const ShmKey &shmId) const
{
TbbMemoryClientRefTable::const_accessor accessor;
if (!clientRefTable_.find(accessor, clientId)) {
return false;
}
return accessor->second->Contains(shmId);
}
}
}