#ifdef USE_RPC_FRAMEWORK
#include "torch_npu/csrc/distributed/rpc/tensorpipe_agent.h"
#include <limits>
#include <thread>
#include <tuple>
#include <utility>
#include <c10/core/StreamGuard.h>
#include <c10/util/irange.h>
#include <fmt/format.h>
#include <torch/csrc/distributed/rpc/agent_utils.h>
#include <torch/csrc/distributed/rpc/utils.h>
#include <unistd.h>
#include "third_party/Tensorpipe/tensorpipe/common/device_id.h"
#include "third_party/Tensorpipe/tensorpipe/tensorpipe.h"
#include "third_party/acl/inc/acl/acl_rt.h"
#include "torch_npu/csrc/core/npu/sys_ctrl/npu_sys_ctrl.h"
#include "torch_npu/csrc/distributed/rpc/tensorpipe_utils.h"
#include "torch_npu/csrc/core/npu/NPUFunctions.h"
#include "torch_npu/csrc/core/npu/NPUException.h"
namespace torch_npu {
namespace distributed {
namespace rpc {
namespace {
// Name of the environment variable read by guessAddress() to pick the network interface.
const std::string kSocketIfnameEnvVar = "TP_SOCKET_IFNAME";
// Address guessAddress() falls back to when the interface/hostname lookup reports an error.
const std::string kDefaultUvAddress = "127.0.0.1";
// Keys used for the entries of the map returned by getMetrics().
const std::string kGilAverageWaitTime = "agent.gil_average_wait_time_us";
const std::string kThreadPoolSize = "agent.thread_pool_size";
const std::string kNumIdleThreads = "agent.num_idle_threads";
const std::string kClientActiveCalls = "agent.client_active_calls";
const std::string kServerActiveCalls = "agent.server_active_calls";
const std::string kServerActiveAsyncCalls = "agent.server_active_async_calls";
// Maps every tensor of an outgoing message to the device it should use on `remoteName`.
//
// CPU tensors without an entry for CPU in `deviceMap` stay on CPU. Any non-CPU tensor must have an
// entry in `deviceMap`, otherwise a TORCH_CHECK failure is raised. If no tensor was mapped through
// `deviceMap` at all, an empty vector is returned instead of a list of CPU devices.
std::vector<c10::Device> getDevicesForTensors(const std::vector<torch::Tensor> &tensors, const DeviceMap &deviceMap,
                                              const std::string &remoteName)
{
    const auto missingMappingMsg = c10::str(
        "TensorPipe RPC backend only supports CPU tensors by default, please "
        "move your tensors to CPU before sending them over RPC, or call "
        "`set_device_map` on `TensorPipeRpcBackendOptions` to explicitly "
        "configure device mapping. ",
        "Request device mapping is not available for destination ", remoteName);

    std::vector<c10::Device> targets;
    targets.reserve(tensors.size());
    bool usedDeviceMap = false;

    for (const auto &tensor : tensors) {
        const bool onCpu = tensor.device().is_cpu();
        const auto entry = onCpu ? deviceMap.find(c10::kCPU) : deviceMap.find(tensor.device());
        if (entry == deviceMap.end()) {
            // Only CPU tensors are allowed to travel without an explicit mapping.
            TORCH_CHECK(onCpu, missingMappingMsg, " for device ", tensor.device(),
                        " but received a tensor on that device.", DIST_ERROR(ErrCode::PARAM));
            targets.emplace_back(c10::kCPU);
            continue;
        }
        targets.push_back(entry->second);
        usedDeviceMap = true;
    }

    // Nothing went through the device map: report "no devices" rather than all-CPU.
    if (!usedDeviceMap) {
        targets.clear();
    }
    return targets;
}
// Returns one stream per entry of `devices`, each obtained from the guard implementation's
// global stream pool. All devices must share the type of the first device.
std::vector<c10::Stream> getStreamsFromPoolForDevices(const std::vector<c10::Device> &devices)
{
    std::vector<c10::Stream> pooled;
    if (devices.empty()) {
        return pooled;
    }
    c10::impl::VirtualGuardImpl guardImpl(devices.front().type());
    pooled.reserve(devices.size());
    for (size_t i = 0; i < devices.size(); ++i) {
        const c10::Device &current = devices[i];
        TORCH_INTERNAL_ASSERT(current.type() == guardImpl.type(), DIST_ERROR(ErrCode::PARAM));
        pooled.push_back(guardImpl.getStreamFromGlobalPool(current));
    }
    return pooled;
}
// Returns, for each entry of `devices`, the stream reported by the guard implementation's
// getStream() for that device. All devices must share the type of the first device.
std::vector<c10::Stream> getCurrentStreamsForDevices(const std::vector<c10::Device> &devices)
{
    std::vector<c10::Stream> current;
    if (devices.empty()) {
        return current;
    }
    c10::impl::VirtualGuardImpl guardImpl(devices.front().type());
    current.reserve(devices.size());
    for (size_t i = 0; i < devices.size(); ++i) {
        const c10::Device &dev = devices[i];
        TORCH_INTERNAL_ASSERT(dev.type() == guardImpl.type(), DIST_ERROR(ErrCode::PARAM));
        current.push_back(guardImpl.getStream(dev));
    }
    return current;
}
// Collects the distinct non-CPU devices used by `tensors`, ordered by device index.
// All non-CPU tensors must share one device type and carry an explicit device index.
std::vector<c10::Device> getDevicesOfTensors(const std::vector<torch::Tensor> &tensors)
{
    c10::optional<c10::impl::VirtualGuardImpl> guardImpl;
    std::vector<bool> seenIndex;
    size_t distinctDevices = 0;

    for (const torch::Tensor &tensor : tensors) {
        if (tensor.is_cpu()) {
            continue;
        }
        c10::Device dev = tensor.device();
        if (!guardImpl.has_value()) {
            // The first non-CPU tensor fixes the device type and sizes the per-index flags.
            guardImpl.emplace(dev.type());
            seenIndex.resize(guardImpl->deviceCount());
        }
        TORCH_INTERNAL_ASSERT(dev.type() == guardImpl->type(), DIST_ERROR(ErrCode::PARAM));
        TORCH_INTERNAL_ASSERT(dev.has_index(), DIST_ERROR(ErrCode::PARAM));
        if (seenIndex[dev.index()]) {
            continue;
        }
        seenIndex[dev.index()] = true;
        ++distinctDevices;
    }

    std::vector<c10::Device> result;
    result.reserve(distinctDevices);
    for (size_t idx = 0; idx < seenIndex.size(); ++idx) {
        if (seenIndex[idx]) {
            result.emplace_back(guardImpl->type(), static_cast<c10::DeviceIndex>(idx));
        }
    }
    return result;
}
// For every producer stream, records an event on it and makes the consumer stream that
// getStreamForDevice() selects for the same device block on that event.
void makeStreamsWaitOnOthers(const std::vector<c10::Stream> &consumers, const std::vector<c10::Stream> &producers)
{
    for (size_t i = 0; i < producers.size(); ++i) {
        const c10::Stream &source = producers[i];
        const c10::Stream &waiter = getStreamForDevice(consumers, source.device());
        c10::Event sync(source.device_type());
        sync.record(source);
        sync.block(waiter);
    }
}
}
// Registries queried by startImpl(): one yields TransportRegistration objects, the other
// ChannelRegistration objects. Creators are added below with C10_REGISTER_CREATOR.
C10_DEFINE_REGISTRY_WITHOUT_WARNING(TensorPipeTransportRegistry, TransportRegistration);
C10_DEFINE_REGISTRY_WITHOUT_WARNING(TensorPipeChannelRegistry, ChannelRegistration);
// Returns the address to advertise for this process. The value is computed once (function-local
// static) and reused afterwards.
//
// If the TP_SOCKET_IFNAME environment variable is set, the address of that interface is looked up;
// otherwise the address for the hostname is looked up. On lookup failure a warning is logged and
// kDefaultUvAddress is used.
const std::string &TensorPipeAgent::guessAddress()
{
    static const std::string uvAddress = []() -> std::string {
        const char *ifnameEnv = std::getenv(kSocketIfnameEnvVar.c_str());
        tensorpipe_npu::Error error;
        std::string result;

        if (ifnameEnv == nullptr) {
            std::tie(error, result) = tensorpipe_npu::transport::uv::lookupAddrForHostname();
            if (!error) {
                return result;
            }
            LOG(WARNING) << "Failed to look up the IP address for the hostname (" << error.what()
                         << "), defaulting to Default Address";
            return kDefaultUvAddress;
        }

        std::tie(error, result) = tensorpipe_npu::transport::uv::lookupAddrForIface(ifnameEnv);
        if (!error) {
            return result;
        }
        LOG(WARNING) << "Failed to look up the IP address for interface " << ifnameEnv << " (" << error.what()
                     << "), defaulting to Default Address";
        return kDefaultUvAddress;
    }();
    return uvAddress;
}
namespace {
std::unique_ptr<TransportRegistration> makeUvTransport()
{
auto context = tensorpipe_npu::transport::uv::create();
std::string address = TensorPipeAgent::guessAddress();
return std::make_unique<TransportRegistration>(
TransportRegistration{std::move(context), kUvTransportPriority, std::move(address)});
}
C10_REGISTER_CREATOR(TensorPipeTransportRegistry, uv, makeUvTransport);
#if TENSORPIPE_HAS_SHM_TRANSPORT
std::unique_ptr<TransportRegistration> makeShmTransport()
{
auto context = tensorpipe_npu::transport::shm::create();
return std::make_unique<TransportRegistration>(
TransportRegistration{std::move(context), kShmTransportPriority, ""});
}
C10_REGISTER_CREATOR(TensorPipeTransportRegistry, shm, makeShmTransport);
#endif
#if TENSORPIPE_HAS_IBV_TRANSPORT
std::unique_ptr<TransportRegistration> makeIbvTransport()
{
auto context = tensorpipe_npu::transport::ibv::create();
std::string address = TensorPipeAgent::guessAddress();
return std::make_unique<TransportRegistration>(
TransportRegistration{std::move(context), kIbvTransportPriority, std::move(address)});
}
C10_REGISTER_CREATOR(TensorPipeTransportRegistry, ibv, makeIbvTransport);
#endif
std::unique_ptr<ChannelRegistration> makeBasicChannel()
{
auto context = tensorpipe_npu::channel::basic::create();
return std::make_unique<ChannelRegistration>(ChannelRegistration{std::move(context), kBasicChannelPriority});
}
C10_REGISTER_CREATOR(TensorPipeChannelRegistry, basic, makeBasicChannel);
#if TENSORPIPE_HAS_CMA_CHANNEL
std::unique_ptr<ChannelRegistration> makeCmaChannel()
{
auto context = tensorpipe_npu::channel::cma::create();
return std::make_unique<ChannelRegistration>(ChannelRegistration{std::move(context), kCmaChannelPriority});
}
C10_REGISTER_CREATOR(TensorPipeChannelRegistry, cma, makeCmaChannel);
#endif
constexpr static int kNumUvThreads = 16;
std::unique_ptr<ChannelRegistration> makeMultiplexedUvChannel()
{
std::vector<std::shared_ptr<tensorpipe_npu::transport::Context>> contexts;
std::vector<std::shared_ptr<tensorpipe_npu::transport::Listener>> listeners;
for (const auto laneIdx C10_UNUSED : c10::irange(kNumUvThreads)) {
auto context = tensorpipe_npu::transport::uv::create();
std::string address = TensorPipeAgent::guessAddress();
contexts.push_back(std::move(context));
listeners.push_back(contexts.back()->listen(address));
}
auto context = tensorpipe_npu::channel::mpt::create(std::move(contexts), std::move(listeners));
return std::make_unique<ChannelRegistration>(
ChannelRegistration{std::move(context), kMultiplexedUvChannelPriority});
}
C10_REGISTER_CREATOR(TensorPipeChannelRegistry, mpt_uv, makeMultiplexedUvChannel);
}
// Initializes the tracker with an existing running sum and sample count.
TensorPipeAgent::TimeSeriesMetricsTracker::TimeSeriesMetricsTracker(uint64_t currentSum, uint64_t currentCount)
: currentSum_(currentSum), currentCount_(currentCount)
{
}
// Records one sample: bumps the sample count and adds `dataPoint` to the running sum.
void TensorPipeAgent::TimeSeriesMetricsTracker::addData(uint64_t dataPoint)
{
    ++currentCount_;
    currentSum_ += dataPoint;
}
// Returns the mean of all recorded samples, or 0 when no sample has been recorded.
float TensorPipeAgent::TimeSeriesMetricsTracker::computeAverage() const
{
    if (currentCount_ == 0) {
        return 0.0f;
    }
    return currentSum_ / static_cast<float>(currentCount_);
}
// Drops the timeout bookkeeping of `messageId`, if any: removes its entry from the vector stored
// under its expiration time in timeoutMap_, erases that map slot when it becomes empty, and
// erases the id from messageIdToTimeout_. Runs under timeoutMapMutex_.
void TensorPipeAgent::removeFromTimeoutMap(uint64_t messageId)
{
    std::unique_lock<std::mutex> lock(timeoutMapMutex_);

    const auto timeoutEntry = messageIdToTimeout_.find(messageId);
    if (timeoutEntry == messageIdToTimeout_.end()) {
        return;
    }

    auto &expirationTime = timeoutEntry->second;
    auto &pendingAtExpiry = timeoutMap_[expirationTime];
    const auto match = std::find_if(pendingAtExpiry.begin(), pendingAtExpiry.end(),
                                    [messageId](const auto &metadata) { return metadata.messageId == messageId; });
    if (match != pendingAtExpiry.end()) {
        pendingAtExpiry.erase(match);
    }
    if (pendingAtExpiry.empty()) {
        timeoutMap_.erase(expirationTime);
    }
    // Erased last: expirationTime refers into this map's entry.
    messageIdToTimeout_.erase(messageId);
}
void TensorPipeAgent::prepareNames(bool isStaticGroup)
{
std::unordered_map<std::string, worker_id_t> nameToId;
if (isStaticGroup) {
nameToId = collectNames(rankToNameStore_, workerInfo_.id_, workerInfo_.name_, worldSize_);
} else {
nameToId = collectCurrentNames(rankToNameStore_, workerInfo_.id_, workerInfo_.name_);
}
for (const auto &entry : nameToId) {
const auto &workerName = entry.first;
const auto &workerId = entry.second;
workerIdToInfo_.emplace(workerId, WorkerInfo(workerName, workerId));
workerNameToInfo_.emplace(workerName, WorkerInfo(workerName, workerId));
}
}
// Publishes whether this agent belongs to a static group under the store key "rpcIsStaticGroup"
// and verifies that every member agrees.
//
// compareSet() is called with an empty expected value and "true"/"false" as the desired value.
// The value it returns must equal this agent's own value; if it differs, another member already
// recorded the opposite mode and a TORCH_CHECK failure is raised.
//
// @param store  Store shared by all members of the RPC group.
void TensorPipeAgent::checkAndSetStaticGroup(const c10::intrusive_ptr<::c10d::Store> &store)
{
    const std::string isStaticGroupKey("rpcIsStaticGroup");
    const std::string isStaticGroupStr = isStaticGroup_ ? "true" : "false";
    const std::vector<uint8_t> isStaticGroupVec(isStaticGroupStr.begin(), isStaticGroupStr.end());

    const std::vector<uint8_t> returnedVec =
        store->compareSet(isStaticGroupKey, std::vector<uint8_t>(), isStaticGroupVec);
    const std::string returnedVal(returnedVec.begin(), returnedVec.end());

    // Previously the returned value was computed and then ignored, so a mismatch went unnoticed.
    TORCH_CHECK(returnedVal == isStaticGroupStr,
                "RPC group mixes statically and dynamically initialized members which is not supported. ",
                "This member is initialized with static group = ", isStaticGroupStr,
                " but the store records static group = ", returnedVal, DIST_ERROR(ErrCode::NOT_SUPPORT));
}
// Constructs the agent.
//
// The RpcAgent base receives this worker's name/id, the request callback and the default RPC
// timeout (opts.rpcTimeoutSeconds converted to milliseconds). The group is considered static
// exactly when `worldSize` holds a value. The thread pool is sized from opts_.numWorkerThreads and
// the TensorPipe context is named after this worker. Three store wrappers are built on top of
// `store` with the prefixes "names", "addrs" and "shutdown".
TensorPipeAgent::TensorPipeAgent(const c10::intrusive_ptr<::c10d::Store> &store, std::string selfName,
worker_id_t selfId, c10::optional<int> worldSize, TensorPipeRpcBackendOptions opts,
std::unordered_map<std::string, DeviceMap> reverseDeviceMaps,
std::vector<c10::Device> devices, std::unique_ptr<RequestCallback> cb)
: RpcAgent(WorkerInfo(std::move(selfName), selfId), std::move(cb),
std::chrono::milliseconds((long)(opts.rpcTimeoutSeconds * kSecToMsConversion))),
isStaticGroup_(worldSize.has_value()),
store_(store),
opts_(std::move(opts)),
reverseDeviceMaps_(std::move(reverseDeviceMaps)),
devices_(std::move(devices)),
threadPool_(opts_.numWorkerThreads),
context_(std::make_shared<tensorpipe_npu::Context>(tensorpipe_npu::ContextOptions().name(workerInfo_.name_))),
rankToNameStore_("names", store),
nameToAddressStore_("addrs", store),
shutdownStore_("shutdown", store)
{
// Hand the device id reported by NpuSysCtrl to TensorPipe.
tensorpipe_npu::setDeviceId(c10_npu::NpuSysCtrl::GetInstance().InitializedDeviceID());
// worldSize_ is only assigned for static groups.
if (isStaticGroup_) {
worldSize_ = worldSize.value();
}
checkAndSetStaticGroup(store);
// Populate the worker id/name lookup tables.
prepareNames(isStaticGroup_);
// Register the tracker that addGilWaitTime() feeds and getMetrics() reads.
timeSeriesMetrics_.emplace(kGilAverageWaitTime, TimeSeriesMetricsTracker());
}
// Logs the destruction and calls shutdown() so the agent is stopped before its members are destroyed.
TensorPipeAgent::~TensorPipeAgent()
{
VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is being destroyed";
shutdown();
}
void TensorPipeAgent::startImpl()
{
VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is starting";
std::vector<std::string> addresses;
int lowestPriority = std::numeric_limits<int>::max();
std::string lowestPriorityTransport;
for (auto &key : TensorPipeTransportRegistry()->Keys()) {
int64_t priority = -1;
if (opts_.transports.has_value()) {
auto iter = std::find(opts_.transports->begin(), opts_.transports->end(), key);
if (iter == opts_.transports->end()) {
continue;
}
priority = opts_.transports->size() - 1 - (iter - opts_.transports->begin());
}
std::unique_ptr<TransportRegistration> reg = TensorPipeTransportRegistry()->Create(key);
if (reg == nullptr || reg->transport == nullptr) {
TORCH_CHECK(false, "TensorPipeTransport get nullptr", DIST_ERROR(ErrCode::PTR));
}
if (!reg->transport->isViable()) {
continue;
}
if (priority == -1) {
priority = reg->priority;
}
if (priority < lowestPriority) {
lowestPriority = priority;
lowestPriorityTransport = key;
}
addresses.push_back(c10::str(key, "://", reg->address));
context_->registerTransport(priority, std::move(key), std::move(reg->transport));
}
for (auto &key : TensorPipeChannelRegistry()->Keys()) {
int64_t priority = -1;
if (opts_.channels.has_value()) {
auto iter = std::find(opts_.channels->begin(), opts_.channels->end(), key);
if (iter == opts_.channels->end()) {
continue;
}
priority = opts_.channels->size() - 1 - (iter - opts_.channels->begin());
}
std::unique_ptr<ChannelRegistration> reg = TensorPipeChannelRegistry()->Create(key);
if (!reg->channel->isViable()) {
continue;
}
if (priority == -1) {
priority = reg->priority;
}
context_->registerChannel(priority, std::move(key), std::move(reg->channel));
}
listener_ = context_->listen(addresses);
const auto address = listener_->url(lowestPriorityTransport);
nameToAddressStore_.set(workerInfo_.name_, address);
for (const auto &p : workerNameToInfo_) {
const auto &name = p.first;
auto nodeAddrData = nameToAddressStore_.get(name);
auto nodeAddrStr = std::string((const char *)nodeAddrData.data(), nodeAddrData.size());
workerNameToURL_.insert({name, nodeAddrStr});
}
timeoutThread_ = std::thread(&TensorPipeAgent::pollTimeoutRpcs, this);
listener_->accept([this](const tensorpipe_npu::Error &error, std::shared_ptr<tensorpipe_npu::Pipe> pipe) {
onListenerAccepted(error, pipe);
});
}
// Listener accept callback. On error, logs a warning unless the error is a ListenerClosedError
// seen while the agent is no longer running, then stops accepting. On success, re-arms the
// listener for the next pipe and starts serving requests on the accepted pipe.
void TensorPipeAgent::onListenerAccepted(const tensorpipe_npu::Error &error,
                                         std::shared_ptr<tensorpipe_npu::Pipe> &pipe)
{
    if (error) {
        const bool closedDuringShutdown =
            error.isOfType<tensorpipe_npu::ListenerClosedError>() && !rpcAgentRunning_.load();
        if (!closedDuringShutdown) {
            LOG(WARNING) << "RPC agent for " << workerInfo_.name_
                         << " encountered error when accepting incoming pipe: " << error.what();
        }
        return;
    }

    // Arm the next accept before handling this pipe.
    listener_->accept(
        [this](const tensorpipe_npu::Error &nextError, std::shared_ptr<tensorpipe_npu::Pipe> nextPipe) {
            onListenerAccepted(nextError, nextPipe);
        });

    VLOG(1) << "RPC agent for " << workerInfo_.name_ << " accepted incoming pipe from " << pipe->getRemoteName();
    respond(pipe);
}
// Reads one message from `pipe` and hands the result to `fn`.
//
// Steps: read the descriptor, take streams from the pool for devices_, allocate receive buffers
// for the descriptor, read the payload, deserialize it into an RPC Message and call
// fn(error, message, streams). If either read reports an error, `fn` is called with that error,
// a null message and no streams.
void TensorPipeAgent::pipeRead(
const std::shared_ptr<tensorpipe_npu::Pipe> &pipe,
std::function<void(const tensorpipe_npu::Error &, c10::intrusive_ptr<Message>, std::vector<c10::Stream>)>
fn) noexcept
{
// `pipe` is captured by value so the callback holds its own shared_ptr copy.
pipe->readDescriptor([this, fn{std::move(fn)}, pipe](const tensorpipe_npu::Error &error,
tensorpipe_npu::Descriptor tpDescriptor) mutable {
if (error) {
fn(error, c10::intrusive_ptr<Message>(), {});
return;
}
std::vector<c10::Stream> streams;
{
// devices_ is read under the group membership lock.
GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
streams = getStreamsFromPoolForDevices(devices_);
}
tensorpipe_npu::Allocation tpAllocation;
TensorpipeReadBuffers tpBuffers;
std::tie(tpAllocation, tpBuffers) = tensorpipeAllocate(tpDescriptor, streams);
// The descriptor, the buffers (wrapped in a shared_ptr), `fn` and the streams are moved into
// the read callback so they stay alive until the read completes.
pipe->read(std::move(tpAllocation),
[tpDescriptor{std::move(tpDescriptor)},
tpBuffers{std::make_shared<TensorpipeReadBuffers>(std::move(tpBuffers))}, fn{std::move(fn)},
streams{std::move(streams)}](const tensorpipe_npu::Error &error) mutable {
if (error) {
fn(error, c10::intrusive_ptr<Message>(), {});
return;
}
c10::intrusive_ptr<Message> rpcMessage =
tensorpipeDeserialize(std::move(tpDescriptor), std::move(*tpBuffers));
fn(error, std::move(rpcMessage), std::move(streams));
});
});
}
// Serializes `rpcMessage` for the given devices/streams and writes it to `pipe`.
// The serialization buffers and the streams are kept alive by the write callback, which
// forwards the write's error status to `fn`.
void TensorPipeAgent::pipeWrite(const std::shared_ptr<tensorpipe_npu::Pipe> &pipe,
                                c10::intrusive_ptr<Message> rpcMessage, std::vector<c10::Device> &&devices,
                                std::vector<c10::Stream> streams,
                                std::function<void(const tensorpipe_npu::Error &)> fn) noexcept
{
    tensorpipe_npu::Message tpMessage;
    TensorpipeWriteBuffers tpBuffers;
    std::tie(tpMessage, tpBuffers) = tensorpipeSerialize(std::move(rpcMessage), std::move(devices), streams);

    // Shared ownership lets the (copyable) callback carry the buffers until the write finishes.
    auto heldBuffers = std::make_shared<TensorpipeWriteBuffers>(std::move(tpBuffers));
    pipe->write(std::move(tpMessage), [heldBuffers{std::move(heldBuffers)}, fn{std::move(fn)},
                                       streams{std::move(streams)}](const tensorpipe_npu::Error &writeError) {
        fn(writeError);
    });
}
// Sends the outcome of a finished request back over `pipe`.
//
// Nothing is sent when the agent is no longer running. If the future holds an error, an exception
// response carrying its error message is sent. Otherwise the response message is tagged with
// `messageId` and its target devices are resolved; if resolution throws, or if any output tensor
// lives on a non-CPU device that is not in devices_, the response is replaced by an exception
// response. The write result is only logged.
void TensorPipeAgent::sendCompletedResponseMessage(std::shared_ptr<tensorpipe_npu::Pipe> &pipe,
                                                   JitFuture &futureResponseMessage, uint64_t messageId,
                                                   std::vector<c10::Stream> streams)
{
    if (!rpcAgentRunning_.load()) {
        LOG(WARNING) << "RPC agent for " << workerInfo_.name_ << " won't send response to request #" << messageId
                     << " to " << pipe->getRemoteName() << ", as the agent is shutting down";
        return;
    }

    VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is sending response to request #" << messageId << " to "
            << pipe->getRemoteName();

    c10::intrusive_ptr<Message> outgoing;
    std::vector<c10::Device> devices;

    if (futureResponseMessage.hasError()) {
        outgoing = createExceptionResponse(futureResponseMessage.tryRetrieveErrorMessage(), messageId);
    } else {
        outgoing = futureResponseMessage.value().toCustomClass<Message>();
        outgoing->setId(messageId);

        try {
            devices = getDevicesForRemote(pipe->getRemoteName(), *outgoing);
        } catch (const std::exception &e) {
            outgoing = createExceptionResponse(e.what(), messageId);
        }

        for (const auto &tensor : outgoing->tensors()) {
            const auto device = tensor.device();
            if (device.is_cpu()) {
                continue;
            }
            GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
            if (std::find(devices_.begin(), devices_.end(), device) != devices_.end()) {
                continue;
            }
            std::ostringstream oss;
            std::copy(devices_.begin(), devices_.end(), std::ostream_iterator<c10::Device>(oss, ", "));
            outgoing = createExceptionResponse(
                c10::str("RPC detected that a user-function output tensor on device ", device,
                         ". This device is not one of the input tensor devices: ", oss.str(),
                         "which is not yet supported."),
                messageId);
            // The message being iterated was just replaced; stop immediately.
            break;
        }
    }

    pipeWrite(pipe, std::move(outgoing), std::move(devices), std::move(streams),
              [this, pipe, messageId](const tensorpipe_npu::Error &error) {
                  if (!error) {
                      VLOG(1) << "RPC agent for " << workerInfo_.name_ << " done sending response to request #"
                              << messageId << " to " << pipe->getRemoteName();
                      return;
                  }
                  LOG(WARNING) << "RPC agent for " << workerInfo_.name_
                               << " encountered error when sending response to request #" << messageId << " to "
                               << pipe->getRemoteName() << ": " << error.what();
              });
}
// Serves requests arriving on `pipe`.
//
// Reads one request; on a read error it logs a warning (unless shuttingDown_ is set) and stops
// serving this pipe. Otherwise it immediately arms the next read on the same pipe, counts the call
// in serverActiveCalls_, and runs the request callback cb_ on the thread pool. When the resulting
// future completes, both call counters are decreased and the response is sent with
// sendCompletedResponseMessage().
void TensorPipeAgent::respond(std::shared_ptr<tensorpipe_npu::Pipe> &pipe)
{
pipeRead(pipe, [this, pipe](const tensorpipe_npu::Error &error, c10::intrusive_ptr<Message> requestMessage,
std::vector<c10::Stream> streams) mutable {
if (error) {
// Read errors are not reported once shuttingDown_ is set.
if (shuttingDown_.load()) {
} else {
LOG(WARNING) << "RPC agent for " << workerInfo_.name_
<< " encountered error when reading incoming request from " << pipe->getRemoteName()
<< ": " << error.what();
}
return;
}
// Arm the next read before processing this request.
respond(pipe);
uint64_t messageId = requestMessage->id();
increaseCallCount(serverActiveCalls_);
VLOG(1) << "RPC agent for " << workerInfo_.name_ << " received request #" << messageId << " from "
<< pipe->getRemoteName();
threadPool_.run(
[this, pipe, messageId, requestMessage{std::move(requestMessage)}, streams{std::move(streams)}]() mutable {
VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is running request #" << messageId << " from "
<< pipe->getRemoteName() << " in thread pool";
VLOG(1) << "TensorpipeAgent::respond set deciveID="
<< c10_npu::NpuSysCtrl::GetInstance().InitializedDeviceID() << "pid=" << getpid()
<< " thread_id=" << std::this_thread::get_id();
// Select the NPU device on this pool thread before running the callback.
c10_npu::SetDevice(c10_npu::NpuSysCtrl::GetInstance().InitializedDeviceID());
c10::intrusive_ptr<JitFuture> futureResponseMessage;
try {
futureResponseMessage = cb_->operator()(*requestMessage, std::move(streams));
}
catch (const std::exception & ) {
// A throwing callback is turned into a future that already holds the error.
futureResponseMessage = c10::make_intrusive<JitFuture>(at::AnyClassType::get());
futureResponseMessage->setError(std::current_exception());
}
increaseCallCount(serverActiveAsyncCalls_);
futureResponseMessage->addCallback([this, pipe, messageId](JitFuture &futureResponseMessage) mutable {
VLOG(1) << "FutureResponseMessage set deciveID="
<< c10_npu::NpuSysCtrl::GetInstance().InitializedDeviceID() << "pid=" << getpid()
<< " thread_id=" << std::this_thread::get_id();
// The completion callback also selects the device before touching streams.
c10_npu::SetDevice(c10_npu::NpuSysCtrl::GetInstance().InitializedDeviceID());
decreaseCallCount(serverActiveCalls_);
decreaseCallCount(serverActiveAsyncCalls_);
auto streams = getCurrentStreamsForDevices(futureResponseMessage.devices());
sendCompletedResponseMessage(pipe, futureResponseMessage, messageId, std::move(streams));
});
VLOG(1) << "RPC agent for " << workerInfo_.name_ << " done running request #" << messageId << " from "
<< pipe->getRemoteName() << " in thread pool";
});
});
}
// Sends `requestMessage` to `toWorkerInfo` and returns a future for the response.
//
// The request must be of a request type and the agent must still be running, otherwise a
// TORCH_CHECK failure is raised. A pipe to the destination is created on first use and cached in
// connectedPipes_. The message gets a fresh id, its future is registered as pending on the pipe
// and, unless the effective timeout is zero, in the timeout map. Target devices come from
// `deviceMap` when it is non-empty, otherwise from getDevicesForRemote(). After the write
// completes a read is issued for the response; the matching pending future is then completed with
// the response or with an error. Write/read errors fail every pending future on the pipe via
// handleClientError().
//
// rpcTimeoutSeconds: kUnsetRpcTimeout selects the agent-wide timeout from getRpcTimeout().
c10::intrusive_ptr<JitFuture> TensorPipeAgent::send(const WorkerInfo &toWorkerInfo,
c10::intrusive_ptr<Message> requestMessage,
const float rpcTimeoutSeconds, const DeviceMap &deviceMap)
{
TORCH_CHECK(requestMessage->isRequest(), "TensorPipeAgent::send(..) is only for sending requests.", DIST_ERROR(ErrCode::NOT_SUPPORT));
if (!rpcAgentRunning_.load()) {
auto err = c10::str("Node ", RpcAgent::getWorkerInfo().id_, "tried to send() a message of type ",
requestMessage->type(), " but RPC is no longer running on this node.");
TORCH_CHECK(false, err, DIST_ERROR(ErrCode::INTERNAL));
}
const auto &url = findWorkerURL(toWorkerInfo);
decltype(connectedPipes_)::iterator it;
{
// Look up the cached pipe for this destination, connecting on first use.
std::unique_lock<std::mutex> lock(connectedPipesMutex_);
it = connectedPipes_.find(toWorkerInfo.id_);
if (it == connectedPipes_.end()) {
std::tie(it, std::ignore) =
connectedPipes_.emplace(std::piecewise_construct, std::forward_as_tuple(toWorkerInfo.id_),
std::forward_as_tuple(context_->connect(
url, tensorpipe_npu::PipeOptions().remoteName(toWorkerInfo.name_))));
}
}
ClientPipe &clientPipe = it->second;
std::shared_ptr<TensorPipeAgent::AtomicJitFuture> futureResponseMessage;
{
GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
futureResponseMessage = std::make_shared<AtomicJitFuture>(devices_);
}
uint64_t messageId = nextMessageID_++;
requestMessage->setId(messageId);
{
// Register the future so the response (or an error) can find it by message id.
std::unique_lock<std::mutex> lock(clientPipe.mutex_);
clientPipe.pendingResponseMessage_[messageId] = futureResponseMessage;
}
std::vector<c10::Device> devices;
if (deviceMap.empty()) {
devices = getDevicesForRemote(clientPipe.pipe_->getRemoteName(), *requestMessage);
} else {
devices = getDevicesForTensors(requestMessage->tensors(), deviceMap, clientPipe.pipe_->getRemoteName());
}
// Sanity check: the future must be completed from a thread-pool thread.
futureResponseMessage->jitFuture->addCallback([this](JitFuture & ) {
TORCH_INTERNAL_ASSERT(this->threadPool_.inThreadPool(), "Future marked complete from outside the thread pool", DIST_ERROR(ErrCode::INTERNAL));
});
increaseCallCount(clientActiveCalls_);
auto timeout = rpcTimeoutSeconds == kUnsetRpcTimeout
? getRpcTimeout()
: std::chrono::milliseconds(static_cast<int>(rpcTimeoutSeconds * kSecToMsConversion));
steady_clock_time_point expirationTime;
// A timeout of zero means the request is not tracked for expiry.
if (timeout.count() != 0) {
expirationTime = computeRpcMessageExpiryTime(timeout);
{
std::unique_lock<std::mutex> lock(timeoutMapMutex_);
auto &timeoutFuturesVector = timeoutMap_[expirationTime];
messageIdToTimeout_.emplace(messageId, expirationTime);
timeoutFuturesVector.emplace_back(messageId, futureResponseMessage, timeout);
}
// Wake the timeout thread so it can take the new expiry into account.
timeoutThreadCV_.notify_one();
}
VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is sending request #" << messageId << " to "
<< clientPipe.pipe_->getRemoteName();
std::vector<c10::Stream> streams;
{
GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
streams = getStreamsFromPoolForDevices(devices_);
}
// The send streams must wait for work queued on the current streams of the tensors' devices.
makeStreamsWaitOnOthers(streams, getCurrentStreamsForDevices(getDevicesOfTensors(requestMessage->tensors())));
pipeWrite(clientPipe.pipe_, std::move(requestMessage), std::move(devices), std::move(streams),
[this, &clientPipe, messageId](const tensorpipe_npu::Error &error) mutable {
if (error) {
// A PipeClosedError after the agent stopped running is not logged.
if (error.isOfType<tensorpipe_npu::PipeClosedError>() && !rpcAgentRunning_.load()) {
} else {
LOG(WARNING) << "RPC agent for " << workerInfo_.name_
<< " encountered error when sending outgoing request #" << messageId << " to "
<< clientPipe.pipe_->getRemoteName() << ": " << error.what();
}
handleClientError(clientPipe, error);
return;
}
VLOG(1) << "RPC agent for " << workerInfo_.name_ << " sent request #" << messageId << " to "
<< clientPipe.pipe_->getRemoteName();
pipeRead(clientPipe.pipe_, [this, &clientPipe](const tensorpipe_npu::Error &error,
c10::intrusive_ptr<Message> responseMessage,
std::vector<c10::Stream> streams) {
if (error) {
if (error.isOfType<tensorpipe_npu::PipeClosedError>() && !rpcAgentRunning_.load()) {
} else {
LOG(WARNING) << "RPC agent for " << workerInfo_.name_
<< " encountered error when reading incoming response from "
<< clientPipe.pipe_->getRemoteName() << ": " << error.what();
}
handleClientError(clientPipe, error);
return;
}
// The id is taken from the response itself, not from the request that triggered this read.
uint64_t messageId = responseMessage->id();
VLOG(1) << "RPC agent for " << workerInfo_.name_ << " received response #" << messageId
<< " from " << clientPipe.pipe_->getRemoteName();
std::shared_ptr<AtomicJitFuture> futureResponseMessage;
{
std::lock_guard<std::mutex> lock(clientPipe.mutex_);
TORCH_INTERNAL_ASSERT(!clientPipe.inError_, "Shouldn't be in error state", DIST_ERROR(ErrCode::INTERNAL));
auto it = clientPipe.pendingResponseMessage_.find(messageId);
TORCH_INTERNAL_ASSERT(it != clientPipe.pendingResponseMessage_.end(), "message ID ",
messageId, " is not recognized", DIST_ERROR(ErrCode::INTERNAL));
futureResponseMessage = std::move(it->second);
clientPipe.pendingResponseMessage_.erase(it);
}
removeFromTimeoutMap(messageId);
if (responseMessage->type() == MessageType::EXCEPTION) {
// The payload of an EXCEPTION response is used as the error text.
markFutureWithError(
std::move(futureResponseMessage),
std::string(responseMessage->payload().begin(), responseMessage->payload().end()));
} else {
markFutureAsComplete(std::move(futureResponseMessage), std::move(responseMessage),
std::move(streams));
}
});
});
return futureResponseMessage->jitFuture;
}
// Marks `clientPipe` as being in error and fails every response still pending on it with the
// text of `error`, removing each message from the timeout bookkeeping.
void TensorPipeAgent::handleClientError(ClientPipe &clientPipe, const tensorpipe_npu::Error &error)
{
    // Take the pending futures out under the pipe lock; they are completed after it is released.
    decltype(clientPipe.pendingResponseMessage_) orphaned;
    {
        std::lock_guard<std::mutex> lock(clientPipe.mutex_);
        clientPipe.inError_ = true;
        std::swap(orphaned, clientPipe.pendingResponseMessage_);
    }

    std::string errorMsg = error.what();
    for (auto it = orphaned.begin(); it != orphaned.end(); ++it) {
        markFutureWithError(std::move(it->second), errorMsg);
        removeFromTimeoutMap(it->first);
    }
}
// Body of the timeout thread started in startImpl().
//
// While the agent is running it waits until the earliest entry of timeoutMap_ has expired, removes
// that entry together with the matching ids in messageIdToTimeout_, and then, with the lock
// released, fails each affected future with a TIMEOUT RPC error. Returns once rpcAgentRunning_
// is false.
void TensorPipeAgent::pollTimeoutRpcs()
{
while (rpcAgentRunning_.load()) {
std::unique_lock<std::mutex> lock(timeoutMapMutex_);
// Wait loop: leaves via `break` when the earliest timeout has expired, or returns on shutdown.
for (;;) {
if (!rpcAgentRunning_.load()) {
return;
}
if (!timeoutMap_.empty()) {
steady_clock_time_point earliestTimeout = timeoutMap_.begin()->first;
if (std::chrono::steady_clock::now() >= earliestTimeout) {
break;
}
timeoutThreadCV_.wait_until(lock, earliestTimeout);
} else {
// Nothing pending: sleep until notified.
timeoutThreadCV_.wait(lock);
}
}
std::vector<TimeoutMessageMetadata> timedOutFutures = std::move(timeoutMap_.begin()->second);
timeoutMap_.erase(timeoutMap_.begin());
for (auto &timeoutMetadata : timedOutFutures) {
messageIdToTimeout_.erase(timeoutMetadata.messageId);
}
// Futures are completed without holding timeoutMapMutex_.
lock.unlock();
for (auto &timeoutMetadata : timedOutFutures) {
std::string errorMsg =
fmt::format(kRpcTimeoutErrorStr, timeoutMetadata.timeout.count());
auto err = makeRPCError(errorMsg, RPCErrorType::TIMEOUT);
markFutureWithError(std::move(timeoutMetadata.responseFuture), std::move(err));
}
}
}
// Waits until this agent has no active client calls, removes its name from rankToNameStore_
// and sets shuttingDown_. callCountMutex_ is held for the whole sequence.
void TensorPipeAgent::leaveGroup()
{
    std::unique_lock<std::mutex> lock(callCountMutex_);
    while (!(clientActiveCalls_ == 0)) {
        callCountCV_.wait(lock);
    }
    removeCurrentName(rankToNameStore_, workerInfo_.id_, workerInfo_.name_);
    shuttingDown_ = true;
}
// Blocks until outstanding client calls have drained.
//
// For a non-static group this simply calls leaveGroup(). For a static group it repeats:
// wait for the local client call count to reach zero, synchronize with the other workers through
// syncCallCount(), then exchange call counts; the loop ends when the total reported by
// syncCallCount() is zero. If `shutdown` is true, shuttingDown_ is set and one more
// syncCallCount() is performed before leaving. The unnamed float parameter is not used.
void TensorPipeAgent::join(bool shutdown, float )
{
VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is joining";
if (!isStaticGroup_) {
leaveGroup();
return;
}
while (true) {
{
std::unique_lock<std::mutex> lock(callCountMutex_);
callCountCV_.wait(lock, [this] { return clientActiveCalls_ == 0; });
}
VLOG(1) << "RPC agent for " << workerInfo_.name_ << " completed all client calls and is entering a barrier";
syncCallCount(shutdownStore_, worldSize_);
{
// The call count is read and exchanged under callCountMutex_.
std::unique_lock<std::mutex> lock(callCountMutex_);
VLOG(1) << "RPC agent for " << workerInfo_.name_ << " exited the barrier and found " << clientActiveCalls_
<< " active client calls";
int totalClientActiveCalls = syncCallCount(shutdownStore_, worldSize_, clientActiveCalls_);
VLOG(1) << "RPC agent for " << workerInfo_.name_ << " completed sync call counts and got a total of "
<< totalClientActiveCalls << " active client calls across all workers";
if (totalClientActiveCalls == 0) {
if (shutdown) {
shuttingDown_ = true;
syncCallCount(shutdownStore_, worldSize_);
}
break;
}
}
}
VLOG(1) << "RPC agent for " << workerInfo_.name_ << " done joining";
}
// Tears the agent down in order: wakes and joins the timeout thread, joins the TensorPipe
// context, then waits for the thread pool to finish its queued work.
void TensorPipeAgent::shutdownImpl()
{
LOG(INFO) << "RPC agent for " << workerInfo_.name_ << " is shutting down";
// Wake the timeout thread so it can observe that the agent stopped running.
timeoutThreadCV_.notify_one();
if (timeoutThread_.joinable()) {
timeoutThread_.join();
}
VLOG(1) << "RPC agent for " << workerInfo_.name_ << " done waiting for timeout thread to join";
context_->join();
VLOG(1) << "RPC agent for " << workerInfo_.name_ << " done waiting for TensorPipe context to join";
threadPool_.waitWorkComplete();
VLOG(1) << "RPC agent for " << workerInfo_.name_ << " done waiting for thread pool to complete work";
}
// Looks up a worker by name.
//
// @param workerName  Name of the worker to find.
// @return Reference to the WorkerInfo stored in workerNameToInfo_.
// Raises a TORCH_CHECK failure when no worker with that name is known. Previously the result of
// find() was dereferenced unconditionally, which is undefined behaviour for an unknown name.
const WorkerInfo &TensorPipeAgent::getWorkerInfo(const std::string &workerName) const
{
    const WorkerInfo *found = nullptr;
    {
        // Both the lookup and the end() comparison happen under the membership lock.
        GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
        const auto it = workerNameToInfo_.find(workerName);
        if (it != workerNameToInfo_.end()) {
            found = &it->second;
        }
    }
    TORCH_CHECK(found != nullptr, "name:", workerInfo_.name_, ",rank:", workerInfo_.id_,
                " could not find destination name ", workerName, DIST_ERROR(ErrCode::PARAM));
    return *found;
}
// Looks up a worker by id.
//
// @param workerId  Id of the worker to find.
// @return Reference to the WorkerInfo stored in workerIdToInfo_.
// Raises a TORCH_CHECK failure when no worker with that id is known. Previously the result of
// find() was dereferenced unconditionally, which is undefined behaviour for an unknown id.
const WorkerInfo &TensorPipeAgent::getWorkerInfo(worker_id_t workerId) const
{
    const WorkerInfo *found = nullptr;
    {
        // Both the lookup and the end() comparison happen under the membership lock.
        GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
        const auto it = workerIdToInfo_.find(workerId);
        if (it != workerIdToInfo_.end()) {
            found = &it->second;
        }
    }
    TORCH_CHECK(found != nullptr, "name:", workerInfo_.name_, ",rank:", workerInfo_.id_,
                " could not find destination id ", workerId, DIST_ERROR(ErrCode::PARAM));
    return *found;
}
std::vector<WorkerInfo> TensorPipeAgent::getWorkerInfos() const
{
std::vector<WorkerInfo> workerInfos;
workerInfos.reserve(workerNameToInfo_.size());
for (auto &item : workerNameToInfo_) {
workerInfos.emplace_back(item.second);
}
return workerInfos;
}
const std::string &TensorPipeAgent::findWorkerURL(const WorkerInfo &worker) const
{
std::unordered_map<std::string, std::string>::const_iterator it;
{
GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
it = workerNameToURL_.find(worker.name_);
}
return it->second;
}
// Applies a membership change for `workerInfo`.
//
// Join (isJoin == true): records the worker in the id/name tables, reads its URL from
// nameToAddressStore_, and merges in any entries of `reverseDeviceMaps` / `devices` not present yet.
// Leave (isJoin == false): removes the worker from the tables and drops every entry of
// reverseDeviceMaps_ / devices_ that is absent from the supplied `reverseDeviceMaps` / `devices`.
//
// The membership lock now covers both paths. Previously only the join path was locked, so the
// leave path mutated the same containers that other methods read under the lock.
void TensorPipeAgent::updateGroupMembership(const WorkerInfo &workerInfo, const std::vector<c10::Device> devices,
                                            const std::unordered_map<std::string, DeviceMap> reverseDeviceMaps,
                                            bool isJoin)
{
    // Copies on purpose: `workerInfo` may refer to an element that is erased below.
    std::string name = workerInfo.name_;
    worker_id_t id = workerInfo.id_;

    GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
    if (isJoin) {
        workerIdToInfo_.emplace(id, workerInfo);
        workerNameToInfo_.emplace(name, workerInfo);

        auto nodeAddrData = nameToAddressStore_.get(name);
        auto nodeAddrStr = std::string(reinterpret_cast<const char *>(nodeAddrData.data()), nodeAddrData.size());
        workerNameToURL_.insert({name, nodeAddrStr});

        for (const auto &it : reverseDeviceMaps) {
            if (reverseDeviceMaps_.find(it.first) == reverseDeviceMaps_.end()) {
                reverseDeviceMaps_[it.first] = it.second;
            }
        }
        for (const auto &it : devices) {
            if (std::find(devices_.begin(), devices_.end(), it) == devices_.end()) {
                devices_.push_back(it);
            }
        }
    } else {
        workerIdToInfo_.erase(id);
        workerNameToInfo_.erase(name);
        workerNameToURL_.erase(name);

        for (auto it = reverseDeviceMaps_.begin(); it != reverseDeviceMaps_.end();) {
            if (reverseDeviceMaps.find(it->first) == reverseDeviceMaps.end()) {
                it = reverseDeviceMaps_.erase(it);
            } else {
                ++it;
            }
        }
        for (auto it = devices_.begin(); it != devices_.end();) {
            if (std::find(devices.begin(), devices.end(), *it) == devices.end()) {
                it = devices_.erase(it);
            } else {
                ++it;
            }
        }
    }
}
std::unordered_map<std::string, std::string> TensorPipeAgent::getMetrics()
{
std::unordered_map<std::string, std::string> metrics;
metrics[kThreadPoolSize] = c10::to_string(threadPool_.size());
metrics[kNumIdleThreads] = c10::to_string(threadPool_.numAvailable());
{
std::unique_lock<std::mutex> lock(callCountMutex_);
metrics[kClientActiveCalls] = c10::to_string(clientActiveCalls_);
metrics[kServerActiveCalls] = c10::to_string(serverActiveCalls_);
metrics[kServerActiveAsyncCalls] = c10::to_string(serverActiveAsyncCalls_);
}
if (isGILProfilingEnabled()) {
{
std::unique_lock<std::mutex> lock(metricsMutex_);
auto averageGilWaitTime = timeSeriesMetrics_[kGilAverageWaitTime].computeAverage();
lock.unlock();
metrics[kGilAverageWaitTime] = c10::to_string(averageGilWaitTime);
}
}
return metrics;
}
// Records one GIL wait-time sample (tick count of `gilWaitTime`, i.e.
// microseconds) in the kGilAverageWaitTime series, under metricsMutex_.
void TensorPipeAgent::addGilWaitTime(const std::chrono::microseconds gilWaitTime)
{
    const auto waitMicros = gilWaitTime.count();
    std::unique_lock<std::mutex> metricsGuard(metricsMutex_);
    timeSeriesMetrics_[kGilAverageWaitTime].addData(waitMicros);
}
// Returns a copy of the per-destination network statistics; the copy is made
// while networkDataMutex_ is held.
TensorPipeAgent::NetworkDataDict TensorPipeAgent::getNetworkData()
{
    std::unique_lock<std::mutex> dataGuard(networkDataMutex_);
    NetworkDataDict snapshot = networkData_;
    return snapshot;
}
// Builds the NetworkSourceInfo for this agent: its own worker id plus the
// address bytes stored for its own name in nameToAddressStore_.
NetworkSourceInfo TensorPipeAgent::getNetworkSourceInfo()
{
    const auto &self = RpcAgent::getWorkerInfo();
    return NetworkSourceInfo{self.id_, nameToAddressStore_.get(self.name_)};
}
// Accounts one call to `destWorkerName`: bumps the call counter and adds
// the request/response sizes to the sent/received byte totals. The entry
// for the destination is created on first use.
void TensorPipeAgent::trackNetworkData(uint64_t requestSize, uint64_t responseSize, const std::string &destWorkerName)
{
    std::unique_lock<std::mutex> dataGuard(networkDataMutex_);
    auto &stats = networkData_[destWorkerName];
    stats.numCalls++;
    stats.totalSentBytes += requestSize;
    stats.totalRecvBytes += responseSize;
}
// Accounts one failed call to `destWorkerName`: bumps the call counter, adds
// the request size to the sent byte total and bumps the error counter. The
// entry for the destination is created on first use.
void TensorPipeAgent::trackNetworkError(uint64_t requestSize, const std::string &destWorkerName)
{
    std::unique_lock<std::mutex> dataGuard(networkDataMutex_);
    auto &stats = networkData_[destWorkerName];
    stats.numCalls++;
    stats.totalSentBytes += requestSize;
    stats.totalErrors++;
}
// Increments `count` under callCountMutex_, then wakes every waiter on
// callCountCV_ (the notification is issued after the lock is released).
void TensorPipeAgent::increaseCallCount(int32_t &count)
{
    {
        std::lock_guard<std::mutex> countGuard(callCountMutex_);
        count += 1;
    }
    callCountCV_.notify_all();
}
// Decrements `count` under callCountMutex_, then wakes every waiter on
// callCountCV_ (the notification is issued after the lock is released).
void TensorPipeAgent::decreaseCallCount(int32_t &count)
{
    {
        std::lock_guard<std::mutex> countGuard(callCountMutex_);
        count -= 1;
    }
    callCountCV_.notify_all();
}
void TensorPipeAgent::markFutureAsComplete(std::shared_ptr<AtomicJitFuture> atomicFuture,
c10::intrusive_ptr<Message> message, std::vector<c10::Stream> streams)
{
if (!atomicFuture->isComplete.test_and_set()) {
threadPool_.run([this, atomicFuture{std::move(atomicFuture)}, message{std::move(message)},
streams{std::move(streams)}]() mutable {
VLOG(1) << "TensorpipeAgent::respond set deciveID="
<< c10_npu::NpuSysCtrl::GetInstance().InitializedDeviceID() << "pid=" << getpid()
<< " thread_id=" << std::this_thread::get_id();
c10_npu::SetDevice(c10_npu::NpuSysCtrl::GetInstance().InitializedDeviceID());
c10::MultiStreamGuard guard(streams);
std::vector<c10::weak_intrusive_ptr<c10::StorageImpl>> storages = message->getStorages();
atomicFuture->jitFuture->markCompleted(std::move(message), std::move(storages));
decreaseCallCount(clientActiveCalls_);
});
}
}
// Fails `atomicFuture` with a std::runtime_error carrying `errorMsg`, on the
// agent's thread pool. The isComplete flag is claimed with test_and_set();
// when it reports that the flag was already set, nothing is scheduled. The
// scheduled task selects the initialized NPU device, sets the error on the
// future and finally decrements clientActiveCalls_.
void TensorPipeAgent::markFutureWithError(std::shared_ptr<AtomicJitFuture> atomicFuture, std::string errorMsg)
{
    if (atomicFuture->isComplete.test_and_set()) {
        return;
    }
    threadPool_.run([this, future = std::move(atomicFuture), reason = std::move(errorMsg)]() mutable {
        // NOTE(review): the log label says "respond" and spells "deciveID";
        // the text is kept as-is here - confirm whether it should be corrected.
        VLOG(1) << "TensorpipeAgent::respond set deciveID="
                << c10_npu::NpuSysCtrl::GetInstance().InitializedDeviceID() << "pid=" << getpid()
                << " thread_id=" << std::this_thread::get_id();
        c10_npu::SetDevice(c10_npu::NpuSysCtrl::GetInstance().InitializedDeviceID());
        future->jitFuture->setError(std::make_exception_ptr(std::runtime_error(reason)));
        decreaseCallCount(clientActiveCalls_);
    });
}
// Determines the devices to use for the tensors of `message` when talking to
// `remoteName`.
// Requests consult opts_.deviceMaps, responses consult reverseDeviceMaps_; the
// selected table is copied while the group-membership lock is held.
//   - With a device map for `remoteName`, the result comes from
//     getDevicesForTensors().
//   - Without one, every tensor must live on the CPU (TORCH_CHECK otherwise)
//     and an empty vector is returned.
std::vector<c10::Device> TensorPipeAgent::getDevicesForRemote(const std::string &remoteName,
                                                              const Message &message) const
{
    std::unordered_map<std::string, DeviceMap> deviceMaps;
    {
        GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
        deviceMaps = message.isRequest() ? opts_.deviceMaps : reverseDeviceMaps_;
    }

    const auto iter = deviceMaps.find(remoteName);
    if (iter != deviceMaps.end()) {
        // getDevicesForTensors() formats its own error text around the remote
        // name, so it must receive the name itself. Passing the pre-formatted
        // error string here produced a nested, duplicated message.
        return getDevicesForTensors(message.tensors(), iter->second, remoteName);
    }

    const auto errStr = c10::str(
        "TensorPipe RPC backend only supports CPU tensors by default, please "
        "move your tensors to CPU before sending them over RPC, or call "
        "`set_device_map` on `TensorPipeRpcBackendOptions` to explicitly "
        "configure device mapping. ",
        message.isRequest() ? "Request" : "Response", " device mapping is not available for destination ", remoteName);
    for (const auto &t : message.tensors()) {
        TORCH_CHECK(t.device().is_cpu(), errStr, ", but found tensor on device: ", t.device(), DIST_ERROR(ErrCode::PARAM));
    }
    return {};
}
// Returns a copy of the device map configured in opts_.deviceMaps for the
// worker named `dst.name_`, or an empty map when none is configured.
DeviceMap TensorPipeAgent::getDeviceMap(const WorkerInfo &dst) const
{
    const auto found = opts_.deviceMaps.find(dst.name_);
    if (found != opts_.deviceMaps.end()) {
        return found->second;
    }
    return {};
}
// Returns the store_ handle held by this agent.
const c10::intrusive_ptr<::c10d::Store> TensorPipeAgent::getStore() const
{
    return store_;
}
// Returns a copy of the backend options (opts_) this agent holds.
TensorPipeRpcBackendOptions TensorPipeAgent::getBackendOptions() const
{
    return opts_;
}
const std::vector<c10::Device> &TensorPipeAgent::getDevices() const
{
GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
return devices_;
}
// Returns the number of entries in timeoutMap_, read under timeoutMapMutex_.
size_t TensorPipeAgent::timeoutMapSize()
{
    std::lock_guard<std::mutex> mapGuard(timeoutMapMutex_);
    const size_t entryCount = timeoutMap_.size();
    return entryCount;
}
// Returns the current value of clientActiveCalls_ (converted to size_t),
// read under callCountMutex_.
size_t TensorPipeAgent::numPendingResponses()
{
    std::lock_guard<std::mutex> countGuard(callCountMutex_);
    return static_cast<size_t>(clientActiveCalls_);
}
// Returns the number of entries in messageIdToTimeout_, read under
// timeoutMapMutex_.
size_t TensorPipeAgent::messageIdToTimeoutMapSize()
{
    std::lock_guard<std::mutex> mapGuard(timeoutMapMutex_);
    const size_t entryCount = messageIdToTimeout_.size();
    return entryCount;
}
}
}
}
#endif