#ifdef USE_RPC_FRAMEWORK
#include "torch_npu/csrc/distributed/rpc/init.h"
#include <pybind11/chrono.h>
#include <pybind11/operators.h>
#include <torch/csrc/distributed/rpc/py_rref.h>
#include <torch/csrc/distributed/rpc/python_functions.h>
#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
#include <torch/csrc/distributed/rpc/request_callback_impl.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/rref_context.h>
#include <torch/csrc/distributed/rpc/types.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/types.h>
#include "torch_npu/csrc/distributed/rpc/tensorpipe_agent.h"
#include "torch_npu/csrc/distributed/rpc/testing/faulty_tensorpipe_agent.h"
namespace torch_npu {
namespace distributed {
namespace rpc {
// A 100000 ms duration constant. It is not referenced anywhere else in the
// visible part of this file.
// NOTE(review): presumably a timeout for deleting all RRef users, carried
// over from the upstream RPC init code - confirm before removing.
constexpr std::chrono::milliseconds kDeleteAllUsersTimeout(100000);
// Pull the upstream torch::distributed::rpc names into this namespace so the
// binding code below can use them unqualified.
using torch::distributed::rpc::DeviceMap;
using torch::distributed::rpc::kDefaultInitMethod;
using torch::distributed::rpc::kDefaultRpcTimeoutSeconds;
using torch::distributed::rpc::RequestCallbackImpl;
using torch::distributed::rpc::RpcAgent;
using torch::distributed::rpc::RpcBackendOptions;
using torch::distributed::rpc::TensorPipeRpcBackendOptions;
// Shorthand for a pybind11 class binding of T whose holder type is
// std::shared_ptr<T>.
template <typename T>
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;
/**
 * Creates the `_distributed_rpc` submodule of `torch_npu._C` and registers the
 * RPC bindings on it.
 *
 * Three classes are bound, in this order:
 *   - `TensorPipeAgent`, with the Python class
 *     `torch.distributed.rpc.TensorPipeAgent` as its parent;
 *   - `FaultyTensorPipeRpcBackendOptions`, with
 *     `torch.distributed.rpc._TensorPipeRpcBackendOptionsBase` as its parent;
 *   - `FaultyTensorPipeAgent`, with the `TensorPipeAgent` class registered
 *     above as its parent.
 *
 * Neither parameter is read.
 * NOTE(review): the signature looks like a CPython no-argument method entry
 * point - confirm against the method table that registers it.
 *
 * @return Python `True`.
 * @throws python_error when importing `torch_npu._C` yields a null object.
 */
PyObject *rpc_npu_init(PyObject *_unused, PyObject *noargs)
{
    THPObjectPtr npuCoreModule(PyImport_ImportModule("torch_npu._C"));
    if (!npuCoreModule) {
        throw python_error();
    }

    auto npuCore = py::handle(npuCoreModule).cast<py::module>();
    py::module bindings = npuCore.def_submodule("_distributed_rpc", "distributed rpc bindings");
    py::module torchRpc = py::module::import("torch.distributed.rpc");

    // Every bound agent method below runs with the GIL released.
    using ReleaseGil = py::call_guard<py::gil_scoped_release>;

    // Member-pointer types used to pick one specific overload. All of them are
    // expressed as TensorPipeAgent members, including the one whose address is
    // taken through RpcAgent.
    using OwnInfoGetter = const WorkerInfo &(TensorPipeAgent::*)() const;
    using InfoByNameGetter = const WorkerInfo &(TensorPipeAgent::*)(const std::string &) const;
    using InfoByIdGetter = const WorkerInfo &(TensorPipeAgent::*)(worker_id_t) const;
    using AllInfosGetter = std::vector<WorkerInfo> (TensorPipeAgent::*)() const;
    using DeviceMapGetter = DeviceMap (TensorPipeAgent::*)(const WorkerInfo &) const;

    // Methods exposed identically on TensorPipeAgent and FaultyTensorPipeAgent.
    // The three `get_worker_info` overloads are registered in a fixed order:
    // no-argument, by name, by id.
    const auto bindSharedAgentApi = [](auto &agentClass) {
        agentClass.def("join", &TensorPipeAgent::join, ReleaseGil(), py::arg("shutdown") = false,
                       py::arg("timeout") = 0);
        agentClass.def("shutdown", &TensorPipeAgent::shutdown, ReleaseGil());
        agentClass.def("get_worker_info", (OwnInfoGetter) &RpcAgent::getWorkerInfo, ReleaseGil());
        agentClass.def("get_worker_info", (InfoByNameGetter) &TensorPipeAgent::getWorkerInfo, ReleaseGil());
        agentClass.def("get_worker_info", (InfoByIdGetter) &TensorPipeAgent::getWorkerInfo, ReleaseGil());
        agentClass.def("get_worker_infos", (AllInfosGetter) &TensorPipeAgent::getWorkerInfos, ReleaseGil());
    };

    // ---- TensorPipeAgent ----
    {
        shared_ptr_class_<TensorPipeAgent> agentClass(bindings, "TensorPipeAgent",
                                                      torchRpc.attr("TensorPipeAgent"));
        agentClass.def(
            py::init([](const c10::intrusive_ptr<::c10d::Store> &store, std::string workerName,
                        worker_id_t workerId, c10::optional<int> groupSize,
                        TensorPipeRpcBackendOptions backendOptions,
                        std::unordered_map<std::string, DeviceMap> reverseMaps,
                        std::vector<c10::Device> agentDevices) {
                // The shared_ptr is built by hand because it carries a custom
                // deleter (destroy_without_gil).
                return std::shared_ptr<TensorPipeAgent>(
                    new TensorPipeAgent(store, std::move(workerName), workerId, groupSize,
                                        std::move(backendOptions), std::move(reverseMaps),
                                        std::move(agentDevices), std::make_unique<RequestCallbackImpl>()),
                    torch::impl::destroy_without_gil<TensorPipeAgent>);
            }),
            py::arg("store"), py::arg("name"), py::arg("rank"), py::arg("world_size"),
            py::arg("rpc_backend_options"), py::arg("reverse_device_maps"), py::arg("devices"));

        bindSharedAgentApi(agentClass);

        agentClass.def("_get_device_map", (DeviceMapGetter) &TensorPipeAgent::getDeviceMap, ReleaseGil());
        agentClass.def("_get_backend_options", &TensorPipeAgent::getBackendOptions, ReleaseGil());
        agentClass.def("_update_group_membership", &TensorPipeAgent::updateGroupMembership, ReleaseGil());
        agentClass.def_readonly("is_static_group", &TensorPipeAgent::isStaticGroup_);
        agentClass.def_property_readonly("store", &TensorPipeAgent::getStore);
    }

    // ---- FaultyTensorPipeRpcBackendOptions ----
    {
        shared_ptr_class_<FaultyTensorPipeRpcBackendOptions> faultyOptionsClass(
            bindings, "FaultyTensorPipeRpcBackendOptions", torchRpc.attr("_TensorPipeRpcBackendOptionsBase"));
        faultyOptionsClass.def(
            py::init<int, float, std::string, std::vector<std::string>, std::unordered_map<std::string, float>,
                     int>(),
            py::arg("num_worker_threads"), py::arg("rpc_timeout"), py::arg("init_method"),
            py::arg("messages_to_fail"), py::arg("messages_to_delay"), py::arg("num_fail_sends"));
        faultyOptionsClass.def_readwrite("num_worker_threads", &TensorPipeRpcBackendOptions::numWorkerThreads);
        faultyOptionsClass.def_readwrite("messages_to_fail", &FaultyTensorPipeRpcBackendOptions::messagesToFail);
        faultyOptionsClass.def_readwrite("messages_to_delay", &FaultyTensorPipeRpcBackendOptions::messagesToDelay);
        faultyOptionsClass.def_readwrite("num_fail_sends", &FaultyTensorPipeRpcBackendOptions::numFailSends);
    }

    // ---- FaultyTensorPipeAgent ----
    {
        // The parent is the TensorPipeAgent class registered on this submodule
        // above, so it must already exist at this point.
        shared_ptr_class_<FaultyTensorPipeAgent> faultyAgentClass(bindings, "FaultyTensorPipeAgent",
                                                                  bindings.attr("TensorPipeAgent"));
        faultyAgentClass.def(
            py::init([](const c10::intrusive_ptr<::c10d::Store> &store, std::string workerName,
                        worker_id_t workerId, c10::optional<int> groupSize,
                        FaultyTensorPipeRpcBackendOptions faultyOptions,
                        std::unordered_map<std::string, DeviceMap> reverseMaps,
                        std::vector<c10::Device> agentDevices) {
                return std::shared_ptr<FaultyTensorPipeAgent>(
                    new FaultyTensorPipeAgent(store, std::move(workerName), workerId, groupSize,
                                              std::move(faultyOptions), std::move(reverseMaps),
                                              std::move(agentDevices), std::make_unique<RequestCallbackImpl>()),
                    torch::impl::destroy_without_gil<FaultyTensorPipeAgent>);
            }),
            py::arg("store"), py::arg("name"), py::arg("rank"), py::arg("world_size"), py::arg("opts"),
            py::arg("reverse_device_maps"), py::arg("devices"));

        bindSharedAgentApi(faultyAgentClass);
    }

    Py_RETURN_TRUE;
}
}
}
}
#endif