/**
* @copyright Copyright (c) 2024 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the BSD 3-Clause License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://opensource.org/licenses/BSD-3-Clause
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <netdb.h>
#include <pthread.h>

#include <atomic>
#include <cstdint>
#include <functional>
#include <list>
#include <map>
#include <mutex>
#include <string>
#include <thread>
#include <unordered_map>
#include <utility>
#include <vector>

#include "StoreMessagePacker.hpp"
namespace c10d {
namespace torch_npu {
// Identifies one waiting client as a (socket, workerFd) pair; used as the element type of
// ParallelTcpServer::keyWaitingSockets_ lists and as the key of socketWaitKeyNum_.
using PI = std::pair<int, int>;
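/**
* @brief Thin wrapper around pthread_spinlock_t exposing lock()/try_lock()/unlock(),
*        so it can be used with std::lock_guard and similar lock helpers.
*/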
class SpinLock {
public:
    // Initializes the spinlock for threads of the current process only.
    // NOTE(review): the pthread_spin_init result is ignored because the constructor is noexcept.
    SpinLock() noexcept
    {
        pthread_spin_init(&spinlock_, PTHREAD_PROCESS_PRIVATE);
    }

    virtual ~SpinLock() noexcept
    {
        pthread_spin_destroy(&spinlock_);
    }

    // POSIX leaves the use of a copied pthread_spinlock_t undefined, and a copy would be
    // destroyed twice, so the wrapper is neither copyable nor movable.
    SpinLock(const SpinLock &) = delete;
    SpinLock &operator=(const SpinLock &) = delete;
    SpinLock(SpinLock &&) = delete;
    SpinLock &operator=(SpinLock &&) = delete;

    // Spins until the lock is acquired. Must not be called by the thread already holding it.
    void lock() noexcept
    {
        pthread_spin_lock(&spinlock_);
    }

    // Attempts to acquire the lock without spinning.
    // @return true if the lock was acquired, false if it is currently held.
    bool try_lock() noexcept
    {
        return pthread_spin_trylock(&spinlock_) == 0;
    }

    // Releases a lock held by the calling thread.
    void unlock() noexcept
    {
        pthread_spin_unlock(&spinlock_);
    }

private:
    pthread_spinlock_t spinlock_{};
};
/**
* @brief Per-client I/O state the store server keeps for one client socket fd:
*        receive buffer, queued requests and pending send buffer.
*/
class ClientIoContext {
public:
    // Binds the context to client descriptor `fd`, starting with event mask `events`.
    explicit ClientIoContext(int fd, uint32_t events) : currentEvents_(events), fd_(fd) {}

public:
    // Reads pending data from the client socket (defined out of line; presumably fills recBuf_
    // and parses complete messages into requests_ — confirm in the .cpp).
    void ReceiveData() noexcept;

    // Whether another parsed request is waiting to be handled (presumably !requests_.empty()).
    bool HasNextReq() const noexcept;

    // Hands out the next queued request (presumably popping it from requests_).
    StoreMessage NextRequest() noexcept;

    // Queues `response` for sending to the client (presumably serialized into sendBuf_).
    void SendResponse(const StoreMessage &response) noexcept;

    // Whether no outgoing bytes remain buffered.
    bool SendBufEmpty() const noexcept;

    // Attempts to write buffered outgoing bytes to the client socket.
    void FlushSendBuf() noexcept;

    // Event mask currently associated with this client; read and updated by the server loop.
    uint32_t currentEvents_;

private:
    const int fd_;
    uint32_t recSize_ = 0;
    std::vector<uint8_t> recBuf_;
    std::vector<uint8_t> sendBuf_;
    std::list<StoreMessage> requests_;
};
// Request handler registered with ParallelTcpServer: it is given the client file descriptor
// and a request message, and yields a StoreMessage (presumably the reply sent back to that
// client — confirm in the server implementation).
using ServerProcFn = std::function<StoreMessage(int fd, const StoreMessage &request)>;
/**
* @brief epoll-based server (TCP host/port or local socket path) that hands each request
*        to a registered message processor.
*/
class ParallelTcpServer {
public:
explicit ParallelTcpServer(uint32_t threadNum, const std::string host, uint16_t port, uint32_t listenThreadNum,
ServerProcFn process) noexcept;
explicit ParallelTcpServer(uint32_t threadNum, const std::string localSocketPath, uint32_t listenThreadNum,
ServerProcFn process) noexcept;
int Start() noexcept;
void Stop() noexcept;
inline void SetKeysWaitingSocket(const std::vector<std::string> &keys, int socket, int workerFd, int64_t waitCount) noexcept
{
std::lock_guard<SpinLock> lockGuard{ spinLock_ };
for (auto &key : keys) {
keyWaitingSockets_[key].emplace_back(std::make_pair(socket, workerFd));
}
socketWaitKeyNum_[std::make_pair(socket, workerFd)] = waitCount;
}
void WakeupWaitingClients(const std::string &key) noexcept;
private:
static int CreateSocketWithFamily(const std::string host, uint16_t port, int family) noexcept;
static int CreateSocketAndListen(const ::addrinfo &addr) noexcept;
static int CreateSocket(const std::string host, uint16_t port) noexcept;
static int CreateLocalSocket(const std::string &localSocketPath) noexcept;
static int CreateEpoll(int targetFd = -1) noexcept;
void LoopProcessClients(int epollFd) noexcept;
void ProcessListenEvent() noexcept;
void ProcessClientEvent(int epFd, int fd, uint32_t event, std::unordered_map<int, ClientIoContext> &ctx) noexcept;
static int SetNonBlocking(int fd) noexcept;
static int SetBlockSocketTimeout(int fd) noexcept;
private:
const uint32_t listenThreadNum_{ 1 };
const uint32_t threadNum_{ 0 };
const std::uint16_t port_{ 0 };
const std::string host_{};
const std::string localSocketPath_{};
const ServerProcFn process_{ nullptr };
int listenSocket_{ -1 };
bool isLocalServer_{ false };
std::vector<int> epClientFds_;
std::vector<std::thread> clientThreads_;
std::vector<std::thread> listenThreads_;
uint8_t *buffer_{ nullptr };
std::atomic<bool> running_{ false };
SpinLock spinLock_;
std::unordered_map<std::string, std::list<PI>> keyWaitingSockets_;
std::map<PI, int64_t> socketWaitKeyNum_;
};
}
}