#ifdef USE_RPC_FRAMEWORK
#pragma once
#include <torch/csrc/distributed/rpc/utils.h>
#include "torch_npu/csrc/core/npu/NPUException.h"
// Forward declarations of the tensorpipe_npu types named by the declarations
// in this header, so the header itself does not have to pull in tensorpipe_npu.
namespace tensorpipe_npu {
class Allocation;
class Descriptor;
class Message;
} // namespace tensorpipe_npu
namespace torch_npu {
namespace distributed {
namespace rpc {
// Re-export selected upstream torch RPC names so they can be used unqualified
// inside torch_npu::distributed::rpc.
using torch::distributed::rpc::Message;
using torch::distributed::rpc::MessageType;
using torch::distributed::rpc::cloneSparseTensors;

// Picks a stream for `device` out of `streams` and returns a reference to it.
// NOTE(review): presumably the returned reference aliases an element of
// `streams`, so `streams` must outlive it; matching rules and the behaviour
// when no stream matches live in the out-of-line definition — confirm there.
const c10::Stream &getStreamForDevice(const std::vector<c10::Stream> &streams,
                                      const c10::Device &device);
// Abstract per-device-type hook used when tensor storage has to be put into, or
// received out of, a tensorpipe_npu message. Concrete converters implement both
// pure virtual functions; instances are handed out via the converter registry.
class TensorpipeDeviceTypeConverter {
public:
    // Prepares `storage` for sending as part of `message`, using `streams` for
    // any device-side work.
    // Returns an optional byte buffer. NOTE(review): presumably engaged when
    // the data had to be copied into host memory and empty when the storage
    // can be referenced directly — confirm against the implementations.
    virtual c10::optional<std::vector<char>> prepareTensorForSending(
        const c10::Storage &storage,
        const std::vector<c10::Stream> &streams,
        tensorpipe_npu::Message &message) const = 0;

    // Allocates `length` bytes on device `deviceIndex` for an incoming tensor and
    // describes that destination in `allocation`.
    // Returns the DataPtr owning the newly allocated memory.
    virtual at::DataPtr allocateTensorForReceiving(
        int deviceIndex,
        size_t length,
        const std::vector<c10::Stream> &streams,
        tensorpipe_npu::Allocation &allocation) const = 0;

    // Virtual so converters can be destroyed through a base pointer.
    // Kept last to preserve the original vtable order.
    virtual ~TensorpipeDeviceTypeConverter() = default;
};
// Process-wide table of converters, with one atomic slot per c10::DeviceType,
// indexed by static_cast<size_t>(device type). Slots hold non-owning pointers.
// NOTE(review): the definition lives in a .cpp; presumably slots start out null
// and are filled in by TensorpipeDeviceTypeConverterRegistrar — confirm.
extern std::array<
    std::atomic<const TensorpipeDeviceTypeConverter *>,
    static_cast<size_t>(c10::DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)>
    device_type_converter_registry;
class TensorpipeDeviceTypeConverterRegistrar {
public:
TensorpipeDeviceTypeConverterRegistrar(c10::DeviceType, const TensorpipeDeviceTypeConverter *);
};
// Registers a heap-allocated `ConverterType` instance for ::c10::DeviceType::DevType.
// It does this by defining a uniquely named static
// TensorpipeDeviceTypeConverterRegistrar in the expanding translation unit.
//   DevType       - enumerator name inside ::c10::DeviceType (e.g. PrivateUse1).
//   ConverterType - default-constructible class deriving from
//                   TensorpipeDeviceTypeConverter.
// The converter is created with `new` and nothing in this header deletes it.
// The expansion ends with ';', so a trailing ';' at the call site is optional.
// The variable name is pasted from DevType; C10_ANONYMOUS_VARIABLE adds a
// unique suffix so several registrations can share a translation unit.
#define C10_REGISTER_TENSORPIPE_DEVICE_TYPE_CONVERTER(DevType, ConverterType) \
static ::torch_npu::distributed::rpc::TensorpipeDeviceTypeConverterRegistrar C10_ANONYMOUS_VARIABLE( \
g_##DevType)(::c10::DeviceType::DevType, new ConverterType());
// Returns the converter pointer currently stored for `type` in
// device_type_converter_registry (an atomic load; no bounds check on `type`).
// NOTE(review): presumably null when no converter was registered for `type` —
// callers should check before dereferencing.
inline const TensorpipeDeviceTypeConverter *getDeviceTypeConverter(c10::DeviceType type)
{
    const auto &slot = device_type_converter_registry[static_cast<size_t>(type)];
    return slot.load();
}
// Holds the data that backs an outgoing tensorpipe_npu::Message; it is returned
// alongside that message by tensorpipeSerialize.
// NOTE(review): presumably it must be kept alive until the write completes,
// because the message likely points into these buffers — confirm.
struct TensorpipeWriteBuffers {
    // Boxed on the heap. NOTE(review): presumably so their addresses stay fixed
    // when this struct is moved (e.g. out of the returned tuple) — confirm.
    std::unique_ptr<MessageType> type;
    std::unique_ptr<int64_t> id;

    // Serialized byte buffers for the message.
    std::vector<char> payload;
    std::vector<char> pickle;

    // Tensors whose storage is referenced by the outgoing message.
    std::vector<torch::Tensor> tensors;

    // Extra byte copies of tensor data. NOTE(review): presumably the engaged
    // results of TensorpipeDeviceTypeConverter::prepareTensorForSending — confirm.
    std::vector<std::vector<char>> copiedTensors;
};
// Destination memory for an incoming message. tensorpipeAllocate produces it
// and tensorpipeDeserialize consumes it (by rvalue reference).
struct TensorpipeReadBuffers {
    // Boxed on the heap. NOTE(review): presumably so the addresses handed to
    // tensorpipe stay valid when this struct is moved — confirm.
    std::unique_ptr<MessageType> type;
    std::unique_ptr<int64_t> id;

    // Receive buffers for the serialized message bytes.
    std::vector<char> payload;
    std::vector<char> pickle;

    // Owning pointers to the memory allocated for each incoming tensor.
    std::vector<c10::DataPtr> tensors;
};
// Converts `rpcMessage` into a tensorpipe_npu::Message plus the
// TensorpipeWriteBuffers holding the data it describes.
// `devices` and `streams` are passed through for device tensors.
// NOTE(review): exact device/stream semantics are in the .cpp — confirm there.
std::tuple<tensorpipe_npu::Message, TensorpipeWriteBuffers> tensorpipeSerialize(
    c10::intrusive_ptr<Message> rpcMessage,
    std::vector<c10::Device> devices,
    const std::vector<c10::Stream> &streams);

// Given the descriptor of an incoming message, allocates receive buffers.
// Returns the tensorpipe_npu::Allocation describing them together with the
// TensorpipeReadBuffers that own them.
std::pair<tensorpipe_npu::Allocation, TensorpipeReadBuffers> tensorpipeAllocate(
    const tensorpipe_npu::Descriptor &tpDescriptor,
    const std::vector<c10::Stream> &streams);

// Rebuilds an RPC Message from a received descriptor and its filled buffers.
// Both arguments are taken by rvalue reference.
c10::intrusive_ptr<Message> tensorpipeDeserialize(
    tensorpipe_npu::Descriptor &&tpDescriptor,
    TensorpipeReadBuffers &&holder);
}
}
}
#endif