* Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "datasystem/common/rpc/rpc_stub_cache_mgr.h"
#include <mutex>
#include "datasystem/common/inject/inject_point.h"
#include "datasystem/common/perf/perf_manager.h"
#include "datasystem/common/rpc/zmq/zmq_stub_conn.h"
#include "datasystem/common/util/gflag/common_gflags.h"
#include "datasystem/common/util/net_util.h"
#include "datasystem/common/util/status_helper.h"
#include "datasystem/protos/master_object.stub.rpc.pb.h"
#include "datasystem/protos/master_stream.stub.rpc.pb.h"
#include "datasystem/protos/stream_posix.stub.rpc.pb.h"
#include "datasystem/protos/worker_object.stub.rpc.pb.h"
#include "datasystem/protos/worker_stream.stub.rpc.pb.h"
DS_DEFINE_int32(oc_worker_worker_pool_size, 3, "Number of parallel connections between worker/worker. Default is 3.");
DS_DEFINE_int32(sc_worker_worker_pool_size, 3, "Number of parallel connections between worker/worker. Default is 3.");
namespace datasystem {
namespace {
constexpr int64_t SLOW_RPC_STUB_THRESHOLD_MS = 2;
const char *GetSlowPhase(int64_t lookupElapsedMs, int64_t accessElapsedMs, int64_t createElapsedMs)
{
if (createElapsedMs >= lookupElapsedMs && createElapsedMs >= accessElapsedMs) {
return "create";
}
if (accessElapsedMs >= lookupElapsedMs) {
return "access";
}
return "lookup";
}
void LogStubGetEvent(const char *eventName, const HostPort &hostPort, StubType type, bool cacheHit,
int64_t lookupElapsedMs, int64_t getDataElapsedMs, int64_t accessElapsedMs,
int64_t createElapsedMs, int attempts, const Status *rc = nullptr)
{
const auto totalElapsedMs = lookupElapsedMs + getDataElapsedMs + accessElapsedMs + createElapsedMs;
const auto phase = GetSlowPhase(lookupElapsedMs, accessElapsedMs, createElapsedMs);
if (rc == nullptr) {
LOG_IF(INFO, totalElapsedMs > SLOW_RPC_STUB_THRESHOLD_MS)
<< FormatString("[%s] dst=%s type=%d hit=%d phase=%s lookup=%ldms access=%ldms create=%ldms "
"total=%ldms retry=%d trace=%s",
eventName, hostPort.ToString(), static_cast<int>(type), cacheHit, phase, lookupElapsedMs,
accessElapsedMs, createElapsedMs, totalElapsedMs, attempts,
Trace::Instance().GetTraceID());
return;
}
if (totalElapsedMs > SLOW_RPC_STUB_THRESHOLD_MS) {
LOG(ERROR) << FormatString("[%s] dst=%s type=%d phase=%s total=%ldms retry=%d error=%s trace=%s", eventName,
hostPort.ToString(), static_cast<int>(type), phase, totalElapsedMs, attempts,
rc->ToString(), Trace::Instance().GetTraceID());
return;
}
LOG(ERROR) << FormatString("[%s] dst=%s type=%d retry=%d error=%s trace=%s", eventName, hostPort.ToString(),
static_cast<int>(type), attempts, rc->ToString(), Trace::Instance().GetTraceID());
}
}
Status RpcStubCacheMgr::Init(uint64_t maxStubCount, const HostPort &localAddress)
{
LOG(INFO) << FormatString("Init RpcStubCacheMgr for %s, max cache num: %d", localAddress.ToString(), maxStubCount);
std::lock_guard<std::mutex> lck(initMutex_);
(void)ZmqStubConnMgr::Instance();
auto policy = std::make_unique<LruCountPolicy>();
policy->SetCacheCount(maxStubCount);
RETURN_IF_NOT_OK(LruForRpcStubCacheMgr::Builder()
.SetPolicy(std::move(policy))
.SetNumPartitions(stubPriorityNum_)
.Build(&lruCache_));
localAddress_ = localAddress;
InitCreators();
init_ = true;
return Status::OK();
}
Status RpcStubCacheMgr::CreateRpcStub(StubType type, const std::shared_ptr<RpcChannel> &channel,
std::shared_ptr<RpcStubBase> &stub)
{
switch (type) {
case StubType::WORKER_WORKER_OC_SVC:
stub = std::make_shared<WorkerWorkerOCService_Stub>(channel);
break;
case StubType::WORKER_MASTER_OC_SVC:
stub = std::make_shared<master::MasterOCService_Stub>(channel, FLAGS_node_timeout_s * TO_MILLISECOND);
break;
case StubType::WORKER_WORKER_SC_SVC:
stub = std::make_shared<WorkerWorkerSCService_Stub>(channel);
break;
case StubType::WORKER_MASTER_SC_SVC:
stub = std::make_shared<master::MasterSCService_Stub>(channel);
break;
case StubType::MASTER_WORKER_OC_SVC:
stub = std::make_shared<MasterWorkerOCService_Stub>(channel);
break;
case StubType::MASTER_WORKER_SC_SVC:
stub = std::make_shared<MasterWorkerSCService_Stub>(channel);
break;
case StubType::MASTER_MASTER_OC_SVC:
stub = std::make_shared<master::MasterOCService_Stub>(channel);
break;
case StubType::WORKER_WORKER_TRANS_SVC:
stub = std::make_shared<WorkerWorkerTransportService_Stub>(channel);
break;
default:
RETURN_STATUS(K_RUNTIME_ERROR, "Unsupport type: " + std::to_string(static_cast<int>(type)));
}
return stub->GetInitStatus();
}
Status RpcStubCacheMgr::CreateRpcChannel(const HostPort &hostPort, const std::string &serviceName,
std::shared_ptr<RpcChannel> &channel, size_t poolSize)
{
CHECK_FAIL_RETURN_STATUS(channel == nullptr, K_RUNTIME_ERROR, "channel is not nullptr");
RpcCredential cred;
RETURN_IF_NOT_OK(RpcAuthKeyManager::CreateCredentials(WORKER_SERVER_NAME, cred));
channel = std::make_shared<RpcChannel>(hostPort, cred);
RETURN_RUNTIME_ERROR_IF_NULL(channel);
if (!serviceName.empty()) {
channel->SetServiceTcpDirect(serviceName);
}
if (poolSize > 0) {
channel->SetServiceConnectPoolSize(serviceName, poolSize);
}
return Status::OK();
}
bool RpcStubCacheMgr::EnableOcWorkerWorkerDirectPort()
{
INJECT_POINT("RpcStubCacheMgr.EnableOcWorkerWorkerDirectPort", []() { return true; });
return FLAGS_oc_worker_worker_direct_port > 0;
}
bool RpcStubCacheMgr::EnableScWorkerWorkerDirectPort()
{
return FLAGS_sc_worker_worker_direct_port > 0;
}
void RpcStubCacheMgr::InitCreators()
{
creators_.emplace(
StubType::WORKER_WORKER_OC_SVC, [](const HostPort &hostPort, std::shared_ptr<RpcStubBase> &rpcStub) {
return CreatorTemplate(
[&hostPort](std::shared_ptr<RpcChannel> &channel) {
RETURN_IF_NOT_OK(CreateRpcChannel(
hostPort, EnableOcWorkerWorkerDirectPort() ? WorkerWorkerOCService_Stub::FullServiceName() : "",
channel, FLAGS_oc_worker_worker_pool_size));
return Status::OK();
},
StubType::WORKER_WORKER_OC_SVC, rpcStub);
});
creators_.emplace(
StubType::WORKER_MASTER_OC_SVC, [](const HostPort &hostPort, std::shared_ptr<RpcStubBase> &rpcStub) {
return CreatorTemplate(
[&hostPort](std::shared_ptr<RpcChannel> &channel) { return CreateRpcChannel(hostPort, "", channel); },
StubType::WORKER_MASTER_OC_SVC, rpcStub);
});
creators_.emplace(
StubType::WORKER_WORKER_SC_SVC, [](const HostPort &hostPort, std::shared_ptr<RpcStubBase> &rpcStub) {
return CreatorTemplate(
[&hostPort](std::shared_ptr<RpcChannel> &channel) {
return CreateRpcChannel(
hostPort, EnableScWorkerWorkerDirectPort() ? WorkerWorkerSCService_Stub::FullServiceName() : "",
channel, FLAGS_sc_worker_worker_pool_size);
},
StubType::WORKER_WORKER_SC_SVC, rpcStub);
});
creators_.emplace(
StubType::WORKER_MASTER_SC_SVC, [](const HostPort &hostPort, std::shared_ptr<RpcStubBase> &rpcStub) {
return CreatorTemplate(
[&hostPort](std::shared_ptr<RpcChannel> &channel) { return CreateRpcChannel(hostPort, "", channel); },
StubType::WORKER_MASTER_SC_SVC, rpcStub);
});
creators_.emplace(
StubType::MASTER_WORKER_OC_SVC, [](const HostPort &hostPort, std::shared_ptr<RpcStubBase> &rpcStub) {
return CreatorTemplate(
[&hostPort](std::shared_ptr<RpcChannel> &channel) { return CreateRpcChannel(hostPort, "", channel); },
StubType::MASTER_WORKER_OC_SVC, rpcStub);
});
creators_.emplace(
StubType::MASTER_WORKER_SC_SVC, [](const HostPort &hostPort, std::shared_ptr<RpcStubBase> &rpcStub) {
return CreatorTemplate(
[&hostPort](std::shared_ptr<RpcChannel> &channel) { return CreateRpcChannel(hostPort, "", channel); },
StubType::MASTER_WORKER_SC_SVC, rpcStub);
});
creators_.emplace(
StubType::MASTER_MASTER_OC_SVC, [](const HostPort &hostPort, std::shared_ptr<RpcStubBase> &rpcStub) {
return CreatorTemplate(
[&hostPort](std::shared_ptr<RpcChannel> &channel) { return CreateRpcChannel(hostPort, "", channel); },
StubType::MASTER_MASTER_OC_SVC, rpcStub);
});
creators_.emplace(
StubType::WORKER_WORKER_TRANS_SVC, [](const HostPort &hostPort, std::shared_ptr<RpcStubBase> &rpcStub) {
return CreatorTemplate(
[&hostPort](std::shared_ptr<RpcChannel> &channel) { return CreateRpcChannel(hostPort, "", channel); },
StubType::WORKER_WORKER_TRANS_SVC, rpcStub);
});
}
Status RpcStubCacheMgr::GetStub(const HostPort &hostPort, StubType type, std::shared_ptr<RpcStubBase> &rpcStub)
{
Timer timer;
int64_t lookupElapsedMs = 0;
int64_t getDataElapsedMs = 0;
int64_t accessElapsedMs = 0;
int64_t createElapsedMs = 0;
bool cacheHit = false;
PerfPoint point(PerfKey::WORKER_RPC_STUB_CACHE_LOOKUP);
std::shared_ptr<RpcStubCacheMgrObj> encapsulatedData = nullptr;
if (lruCache_->Lookup(HashKeyForRpcStubCacheMgr(hostPort, type), &encapsulatedData).IsOk()) {
lookupElapsedMs = static_cast<int64_t>(timer.ElapsedMilliSecondAndReset());
cacheHit = true;
rpcStub = encapsulatedData->GetData();
getDataElapsedMs = static_cast<int64_t>(timer.ElapsedMilliSecondAndReset());
if (rpcStub != nullptr) {
LogStubGetEvent("SLOW_RPC_STUB_GET", hostPort, type, cacheHit, lookupElapsedMs, getDataElapsedMs,
accessElapsedMs, createElapsedMs, 0);
return Status::OK();
}
} else {
lookupElapsedMs = static_cast<int64_t>(timer.ElapsedMilliSecondAndReset());
}
point.RecordAndReset(PerfKey::WORKER_RPC_STUB_CACHE_FIND_CREATOR);
auto creator = creators_.find(type);
if (creator == creators_.end() || creator->second == nullptr) {
RETURN_STATUS(K_RUNTIME_ERROR, "Unsupported type: " + std::to_string(static_cast<int>(type)));
}
auto newEncapsulatedData = std::make_shared<RpcStubCacheMgrObj>(hostPort, type);
Status rc;
const int maxRetries = 5;
const int retryIntervalMs = 100;
int attempts = 0;
{
point.RecordAndReset(PerfKey::WORKER_RPC_STUB_CACHE_ACCESS);
newEncapsulatedData->GetWriteLck();
Raii raii([&newEncapsulatedData]() { newEncapsulatedData->ReleaseWriteLck(); });
do {
rc = lruCache_->Access(HashKeyForRpcStubCacheMgr(hostPort, type), newEncapsulatedData);
if (rc.GetCode() == K_TRY_AGAIN) {
attempts++;
if (attempts < maxRetries) {
std::this_thread::sleep_for(std::chrono::milliseconds(retryIntervalMs));
} else {
RETURN_STATUS_LOG_ERROR(K_RUNTIME_ERROR, "Get error after retry: " + rc.GetMsg());
}
}
} while (rc.GetCode() == K_TRY_AGAIN);
accessElapsedMs = static_cast<int64_t>(timer.ElapsedMilliSecondAndReset());
RETURN_IF_NOT_OK(rc);
point.RecordAndReset(PerfKey::WORKER_RPC_STUB_CACHE_CONNECT);
rc = creator->second(hostPort, rpcStub);
createElapsedMs = static_cast<int64_t>(timer.ElapsedMilliSecond());
point.Record();
if (rc.IsError()) {
LogStubGetEvent("RPC_STUB_GET_FAIL", hostPort, type, cacheHit, lookupElapsedMs, getDataElapsedMs,
accessElapsedMs, createElapsedMs, attempts, &rc);
LOG_IF_ERROR(Remove(hostPort, type), "remove rpc stub failed");
return rc;
}
newEncapsulatedData->SetDataWithoutLck(rpcStub);
}
LogStubGetEvent("SLOW_RPC_STUB_GET", hostPort, type, cacheHit, lookupElapsedMs, getDataElapsedMs,
accessElapsedMs, createElapsedMs, attempts);
return Status::OK();
}
Status RpcStubCacheMgr::Remove(const HostPort &hostPort, StubType type)
{
Timer timer;
Status rc = lruCache_->Remove(HashKeyForRpcStubCacheMgr(hostPort, type));
auto totalElapsedMs = static_cast<int64_t>(timer.ElapsedMilliSecond());
LOG_IF(INFO, totalElapsedMs > SLOW_RPC_STUB_THRESHOLD_MS
|| (rc.IsError() && rc.GetCode() != StatusCode::K_NOT_FOUND))
<< FormatString("[SLOW_RPC_STUB_REMOVE] dst=%s type=%d total=%ldms error=%s trace=%s",
hostPort.ToString(), static_cast<int>(type), totalElapsedMs, rc.ToString(),
Trace::Instance().GetTraceID());
return rc;
}
namespace stub_priority {
StubPriority GetStubPriority(StubType type)
{
switch (type) {
case StubType::MASTER_WORKER_OC_SVC:
return StubPriority::LOW;
case StubType::WORKER_WORKER_TRANS_SVC:
case StubType::WORKER_WORKER_OC_SVC:
case StubType::WORKER_MASTER_OC_SVC:
case StubType::WORKER_WORKER_SC_SVC:
case StubType::WORKER_MASTER_SC_SVC:
case StubType::MASTER_WORKER_SC_SVC:
case StubType::MASTER_MASTER_OC_SVC:
return StubPriority::HIGH;
#ifdef WITH_TESTS
case StubType::TEST_TYPE_1:
return StubPriority::LOW;
case StubType::TEST_TYPE_2:
case StubType::TEST_TYPE_3:
return StubPriority::HIGH;
#endif
}
return StubPriority::INVALID;
}
}
}