#ifdef USE_RPC_FRAMEWORK
#pragma once
#include <atomic>
#include <thread>
#include <c10/core/thread_pool.h>
#include <c10/util/strong_type.h>
#include <torch/csrc/distributed/rpc/agent_utils.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/tensorpipe_agent.h>
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
#include <torch/csrc/distributed/c10d/Store.hpp>
// Forward declarations of tensorpipe_npu types. Where this header uses them,
// it does so only through std::shared_ptr or const references, for which an
// incomplete type is sufficient.
namespace tensorpipe_npu {
class Context;
class Error;
class Listener;
class Message;
class Pipe;
// tensorpipe_npu::transport::Context (distinct from tensorpipe_npu::Context).
namespace transport {
class Context;
}
// tensorpipe_npu::channel::Context (distinct from tensorpipe_npu::Context).
namespace channel {
class Context;
}
}
namespace torch_npu {
namespace distributed {
namespace rpc {
// Pull the torch::distributed::rpc names used by this agent into
// torch_npu::distributed::rpc so they can be written unqualified.
// Because this is a header, every includer also sees these names as members
// of torch_npu::distributed::rpc.
using torch::distributed::rpc::collectCurrentNames;
using torch::distributed::rpc::collectNames;
using torch::distributed::rpc::createExceptionResponse;
using torch::distributed::rpc::DeviceMap;
using torch::distributed::rpc::JitFuture;
using torch::distributed::rpc::kRpcTimeoutErrorStr;
using torch::distributed::rpc::kSecToMsConversion;
using torch::distributed::rpc::kUnsetRpcTimeout;
using torch::distributed::rpc::makeRPCError;
// Unqualified `Message` in this namespace therefore means the torch RPC
// message, not the tensorpipe_npu::Message forward-declared earlier.
using torch::distributed::rpc::Message;
using torch::distributed::rpc::MessageType;
using torch::distributed::rpc::removeCurrentName;
using torch::distributed::rpc::RequestCallback;
using torch::distributed::rpc::RpcAgent;
using torch::distributed::rpc::RPCErrorType;
using torch::distributed::rpc::syncCallCount;
using torch::distributed::rpc::TensorPipeRpcBackendOptions;
using torch::distributed::rpc::worker_id_t;
using torch::distributed::rpc::WorkerInfo;
// Priority values for the transports (shm, ibv, uv).
// NOTE(review): presumably a larger value is preferred when several
// transports are usable -- confirm against the registration code.
constexpr int64_t kShmTransportPriority{200};
constexpr int64_t kIbvTransportPriority{100};
constexpr int64_t kUvTransportPriority{0};
// Priority values for the channels (cma, multiplexed uv, basic, NPU basic).
// NOTE(review): same ordering assumption as for the transports -- confirm.
constexpr int64_t kCmaChannelPriority{1200};
constexpr int64_t kMultiplexedUvChannelPriority{1100};
constexpr int64_t kBasicChannelPriority{1000};
constexpr int64_t kNpuBasicChannelPriority{0};
// Time point on std::chrono::steady_clock, in that clock's native duration.
using steady_clock_time_point = std::chrono::steady_clock::time_point;
// One registrable transport: the transport context together with the
// priority and address string stored alongside it.
struct TransportRegistration {
// Shared handle to the tensorpipe_npu transport context.
std::shared_ptr<tensorpipe_npu::transport::Context> transport;
// Priority of this transport (see the k*TransportPriority constants).
int64_t priority;
// NOTE(review): presumably the address this transport listens on -- confirm
// against the code that fills this struct.
std::string address;
};
// Declares the c10 registry TensorPipeTransportRegistry for
// TransportRegistration objects. This is only the declaration; the registry
// must be defined in a source file.
C10_DECLARE_REGISTRY(TensorPipeTransportRegistry, TransportRegistration);
// One registrable channel: the channel context together with the priority
// stored alongside it.
struct ChannelRegistration {
// Shared handle to the tensorpipe_npu channel context.
std::shared_ptr<tensorpipe_npu::channel::Context> channel;
// Priority of this channel (see the k*ChannelPriority constants).
int64_t priority;
};
// Declares the c10 registry TensorPipeChannelRegistry for ChannelRegistration
// objects. This is only the declaration; the registry must be defined in a
// source file.
C10_DECLARE_REGISTRY(TensorPipeChannelRegistry, ChannelRegistration);
// Default worker-thread count.
// NOTE(review): presumably the default size of the agent's thread pool --
// confirm where this constant is consumed.
constexpr int kDefaultNumWorkerThreads{16};
// Describes a network source as a rank plus a machine address held as raw
// bytes. This is the return type of TensorPipeAgent::getNetworkSourceInfo().
struct NetworkSourceInfo {
// Worker id (rank) of the source.
worker_id_t srcRank;
// Machine address as a byte sequence.
// NOTE(review): the byte encoding is not visible here -- confirm with the
// producer before interpreting it.
std::vector<uint8_t> srcMachineAddr;
};
// Running network counters; every counter starts at zero.
// Used as the mapped type of TensorPipeAgent::NetworkDataDict.
struct AggregatedNetworkData {
    // Number of calls recorded.
    uint64_t numCalls = 0;
    // Total bytes sent.
    uint64_t totalSentBytes = 0;
    // Total bytes received.
    uint64_t totalRecvBytes = 0;
    // Number of errors recorded.
    uint64_t totalErrors = 0;
};
// RpcAgent implementation built on tensorpipe_npu: it owns a tensorpipe_npu
// Context, a Listener and one client Pipe per connected worker (see context_,
// listener_ and connectedPipes_). The class is non-copyable.
//
// Most members are only declared here and defined out of line, so the notes
// on them describe the declaration; behavioural guesses are marked as such.
class TensorPipeAgent : public RpcAgent {
public:
// Constructs the agent from the c10d store, this worker's name and id, an
// optional world size, the backend options, reverse device maps keyed by
// string, the device list, and the request callback (ownership of the
// callback is transferred through the unique_ptr).
TensorPipeAgent(const c10::intrusive_ptr<::c10d::Store> &store, std::string selfName, worker_id_t selfId,
c10::optional<int> worldSize, TensorPipeRpcBackendOptions opts,
std::unordered_map<std::string, DeviceMap> reverseDeviceMaps, std::vector<c10::Device> devices,
std::unique_ptr<RequestCallback> cb);
// Copying is disabled.
TensorPipeAgent(const TensorPipeAgent &) = delete;
TensorPipeAgent &operator=(const TensorPipeAgent &) = delete;
// Sends `message` to worker `to` and returns a JitFuture.
// rpcTimeoutSeconds defaults to kUnsetRpcTimeout; deviceMap defaults to an
// empty map.
// NOTE(review): presumably the future completes with the response message or
// an error -- confirm in the implementation.
c10::intrusive_ptr<JitFuture> send(const WorkerInfo &to, c10::intrusive_ptr<Message> message,
const float rpcTimeoutSeconds = kUnsetRpcTimeout,
const DeviceMap &deviceMap = {}) override;
// RpcAgent override; `shutdown` defaults to false and `timeout` to 0.
void join(bool shutdown = false, float timeout = 0) override;
// RpcAgent override that intentionally does nothing (empty body).
void sync() override{};
// RpcAgent start/shutdown hooks, defined out of line.
void startImpl() override;
void shutdownImpl() override;
~TensorPipeAgent() override;
// Worker lookup by name or by id; both return a reference to a WorkerInfo.
const WorkerInfo &getWorkerInfo(const std::string &workerName) const override;
const WorkerInfo &getWorkerInfo(worker_id_t workerId) const override;
// Returns the WorkerInfo entries by value.
std::vector<WorkerInfo> getWorkerInfos() const override;
// Updates group membership for `workerInfo`; `isJoin` selects the direction.
// NOTE(review): `devices` and `reverseDeviceMaps` are taken by const value,
// so each call copies both containers. Switching to const& would have to be
// mirrored in the out-of-line definition.
void updateGroupMembership(const WorkerInfo &workerInfo, const std::vector<c10::Device> devices,
const std::unordered_map<std::string, DeviceMap> reverseDeviceMaps, bool isJoin);
// Returns metrics as a string-to-string map.
std::unordered_map<std::string, std::string> getMetrics() override;
// Records a GIL wait duration, given in microseconds.
void addGilWaitTime(const std::chrono::microseconds gilWaitTime) override;
// Returns the backend options by value.
TensorPipeRpcBackendOptions getBackendOptions() const;
// Returns the store handle by value.
const c10::intrusive_ptr<::c10d::Store> getStore() const;
// Returns the device map for destination worker `dest`, by value.
DeviceMap getDeviceMap(const WorkerInfo &dest) const override;
// Returns a reference to a device list.
const std::vector<c10::Device> &getDevices() const override;
// Network statistics keyed by string.
// NOTE(review): presumably the key is the destination worker name, as in
// trackNetworkData() -- confirm.
using NetworkDataDict = std::unordered_map<std::string, AggregatedNetworkData>;
// Returns the network statistics by value.
NetworkDataDict getNetworkData();
// Returns rank and machine address information (see NetworkSourceInfo).
NetworkSourceInfo getNetworkSourceInfo();
// Static helper returning a reference to an address string.
// NOTE(review): the referenced string must outlive the call; presumably it
// has static storage duration -- confirm in the definition.
static const std::string &guessAddress();
// Size accessors.
// NOTE(review): presumably these report the sizes of timeoutMap_, the
// pending response maps and messageIdToTimeout_ respectively -- confirm.
size_t timeoutMapSize();
size_t numPendingResponses();
size_t messageIdToTimeoutMapSize();
// Public and const, so it is fixed for the lifetime of the agent once the
// constructor has initialised it.
const bool isStaticGroup_;
protected:
// Writes `message` to the given pipe and reports the outcome through the
// callback, which receives a tensorpipe_npu::Error. Declared virtual and
// noexcept; protected so that subclasses can override it.
virtual void pipeWrite(const std::shared_ptr<tensorpipe_npu::Pipe> &, c10::intrusive_ptr<Message> message,
std::vector<c10::Device> &&devices, std::vector<c10::Stream> streams,
std::function<void(const tensorpipe_npu::Error &)>) noexcept;
private:
// Removes the entry for `messageId` from the timeout bookkeeping.
void removeFromTimeoutMap(uint64_t messageId);
// Set-up helpers, defined out of line.
void prepareNames(bool isStaticGroup);
void checkAndSetStaticGroup(const c10::intrusive_ptr<::c10d::Store> &store);
// Returns a reference to the URL string for `worker`.
const std::string &findWorkerURL(const WorkerInfo &worker) const;
void leaveGroup();
// Reads from the given pipe; the callback receives the error, the message
// that was read and a list of streams. Declared noexcept.
void pipeRead(const std::shared_ptr<tensorpipe_npu::Pipe> &,
std::function<void(const tensorpipe_npu::Error &, c10::intrusive_ptr<Message>,
std::vector<c10::Stream>)>) noexcept;
// Listener accept handler: receives the error and the accepted pipe.
void onListenerAccepted(const tensorpipe_npu::Error &error, std::shared_ptr<tensorpipe_npu::Pipe> &pipe);
// Server-side handling on `pipe`, defined out of line.
void respond(std::shared_ptr<tensorpipe_npu::Pipe> &pipe);
// Sends the response held by `futureResponseMessage` for `messageId` on
// `pipe`.
void sendCompletedResponseMessage(std::shared_ptr<tensorpipe_npu::Pipe> &pipe, JitFuture &futureResponseMessage,
uint64_t messageId, std::vector<c10::Stream> stream);
// Network statistics bookkeeping, keyed by destination worker name.
void trackNetworkData(uint64_t requestSize, uint64_t responseSize, const std::string &destWorkerName);
void trackNetworkError(uint64_t requestSize, const std::string &destWorkerName);
// Returns the devices to use for `message` with respect to `remoteName`.
inline std::vector<c10::Device> getDevicesForRemote(const std::string &remoteName, const Message &message) const;
// Pairs a JitFuture with an atomic flag.
// NOTE(review): presumably the flag is test-and-set by whoever completes the
// future first so it is completed at most once -- confirm in
// markFutureAsComplete()/markFutureWithError().
struct AtomicJitFuture {
// Creates the future with element type AnyClassType and the given devices.
explicit AtomicJitFuture(const std::vector<c10::Device> &devices)
{
jitFuture = c10::make_intrusive<at::ivalue::Future>(at::AnyClassType::get(), devices);
}
// Starts in the clear state.
std::atomic_flag isComplete = ATOMIC_FLAG_INIT;
c10::intrusive_ptr<JitFuture> jitFuture;
};
// Client-side state kept per pipe.
struct ClientPipe {
// Takes shared ownership of `pipe` (moved into pipe_).
explicit ClientPipe(std::shared_ptr<tensorpipe_npu::Pipe> pipe) : pipe_(std::move(pipe)) {}
std::shared_ptr<tensorpipe_npu::Pipe> pipe_;
// NOTE(review): presumably guards inError_ and pendingResponseMessage_ --
// confirm in the implementation.
mutable std::mutex mutex_;
// Starts as false.
bool inError_{false};
// Futures awaiting a response, keyed by a uint64_t (presumably the message
// id -- confirm).
std::unordered_map<uint64_t, std::shared_ptr<AtomicJitFuture>> pendingResponseMessage_;
};
// Store handle and backend options; both are const members.
const c10::intrusive_ptr<::c10d::Store> store_;
const TensorPipeRpcBackendOptions opts_;
// Reverse device maps keyed by string, and the device list.
std::unordered_map<std::string, DeviceMap> reverseDeviceMaps_;
std::vector<c10::Device> devices_;
// Thread pool owned by the agent.
c10::ThreadPool threadPool_;
// tensorpipe_npu context and listener, held by shared_ptr.
std::shared_ptr<tensorpipe_npu::Context> context_;
std::shared_ptr<tensorpipe_npu::Listener> listener_;
// NOTE(review): presumably guards connectedPipes_ -- confirm.
mutable std::mutex connectedPipesMutex_;
// Client pipes keyed by worker id.
std::unordered_map<worker_id_t, ClientPipe> connectedPipes_;
// Worker lookup tables: id -> info, name -> info, name -> URL.
std::unordered_map<worker_id_t, WorkerInfo> workerIdToInfo_;
std::unordered_map<std::string, WorkerInfo> workerNameToInfo_;
std::unordered_map<std::string, std::string> workerNameToURL_;
// Prefixed views over a c10d store.
// NOTE(review): the prefixes and key layout are set in the constructor,
// which is not visible here.
::c10d::PrefixStore rankToNameStore_;
::c10d::PrefixStore nameToAddressStore_;
::c10d::PrefixStore shutdownStore_;
// Starts at 0.
int worldSize_ = 0;
// Atomic message-id counter, starting at 0.
std::atomic<uint64_t> nextMessageID_{0};
// One entry of the timeout bookkeeping: the message id, the future to
// resolve, and the timeout duration.
struct TimeoutMessageMetadata {
// Stores the three values; responseFuture_ is moved into the member.
TimeoutMessageMetadata(uint64_t messageId_, std::shared_ptr<AtomicJitFuture> responseFuture_,
std::chrono::milliseconds timeout_)
: messageId(messageId_), responseFuture(std::move(responseFuture_)), timeout(timeout_)
{
}
uint64_t messageId;
std::shared_ptr<AtomicJitFuture> responseFuture;
std::chrono::milliseconds timeout;
};
// Ordered map (std::map) from a steady-clock time point to the messages
// associated with it, so the earliest time point is always first.
std::map<steady_clock_time_point, std::vector<TimeoutMessageMetadata>> timeoutMap_;
// Reverse index: uint64_t message id -> time point.
std::unordered_map<uint64_t, steady_clock_time_point> messageIdToTimeout_;
// Thread object for timeout handling.
// NOTE(review): presumably runs pollTimeoutRpcs() -- confirm where the
// thread is started.
std::thread timeoutThread_;
void pollTimeoutRpcs();
// NOTE(review): presumably guards timeoutMap_ and messageIdToTimeout_, with
// timeoutThreadCV_ used to wake the timeout thread -- confirm.
std::mutex timeoutMapMutex_;
std::condition_variable timeoutThreadCV_;
// Returns steady_clock::now() + timeout. The sum passes through
// time_point_cast to milliseconds, so sub-millisecond precision is dropped
// before the implicit conversion back to steady_clock_time_point.
inline steady_clock_time_point computeRpcMessageExpiryTime(std::chrono::milliseconds timeout) const
{
return std::chrono::time_point_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() + timeout);
}
// Error handler for a client pipe.
void handleClientError(ClientPipe &clientPipe, const tensorpipe_npu::Error &error);
// Holds a running sum and a running count for one metric.
struct TimeSeriesMetricsTracker {
uint64_t currentSum_;
uint64_t currentCount_;
// Both arguments default to 0; defined out of line.
explicit TimeSeriesMetricsTracker(uint64_t currentSum = 0, uint64_t currentCount = 0);
// Adds one data point.
void addData(uint64_t dataPoint);
// Returns the average as a float.
// NOTE(review): presumably currentSum_ / currentCount_; confirm how a zero
// count is handled.
float computeAverage() const;
};
// Metric trackers keyed by metric name.
// NOTE(review): presumably guarded by metricsMutex_ -- confirm.
std::unordered_map<std::string, TimeSeriesMetricsTracker> timeSeriesMetrics_;
std::mutex metricsMutex_;
// Scoped guard over a std::mutex. It locks in the constructor and unlocks in
// the destructor ONLY when `isStaticGroup` is true; when it is false the
// guard does nothing. Copy construction is deleted.
// NOTE(review): taking the lock only for static groups looks inverted,
// because membership can only change in a dynamic group. Confirm the
// intended polarity against the call sites before changing it.
struct GroupMembershipLockGuard {
GroupMembershipLockGuard(std::mutex &mutex, bool isStaticGroup) : ref_(mutex), isStaticGroup_(isStaticGroup)
{
if (isStaticGroup_) {
ref_.lock();
}
}
~GroupMembershipLockGuard()
{
// Mirrors the constructor: unlock only if the constructor locked.
if (isStaticGroup_) {
ref_.unlock();
}
}
GroupMembershipLockGuard(const GroupMembershipLockGuard &) = delete;
private:
std::mutex &ref_;
bool isStaticGroup_;
};
// Mutex for group membership data; mutable so const member functions can
// take it.
mutable std::mutex groupMembershipMutex_;
// Network statistics.
// NOTE(review): presumably guarded by networkDataMutex_ -- confirm.
NetworkDataDict networkData_;
std::mutex networkDataMutex_;
// Active-call counters, all starting at 0.
// NOTE(review): presumably guarded by callCountMutex_ with callCountCV_
// signalled on change -- confirm.
std::mutex callCountMutex_;
std::condition_variable callCountCV_;
int32_t clientActiveCalls_{0};
int32_t serverActiveCalls_{0};
int32_t serverActiveAsyncCalls_{0};
// Atomic shutdown flag, starting as false.
std::atomic<bool> shuttingDown_{false};
// Adjust the counter passed by reference.
void increaseCallCount(int32_t &count);
void decreaseCallCount(int32_t &count);
// Complete `atomicFuture` with a message (and streams) or with an error
// string.
void markFutureAsComplete(std::shared_ptr<AtomicJitFuture> atomicFuture, c10::intrusive_ptr<Message> message,
std::vector<c10::Stream> streams);
void markFutureWithError(std::shared_ptr<AtomicJitFuture> atomicFuture, std::string errorMsg);
};
}
}
}
#endif