#pragma once
#include <mutex>
#include <thread>
#include <unordered_map>
#include <variant>
#include <future>
#include <atomic>
#include <string>
#include <c10d/ProcessGroup.hpp>
#include <c10d/Store.hpp>
#include <c10d/Utils.hpp>
#include <c10d/Work.hpp>
#include "torch_npu/csrc/distributed/LCCLUtils.hpp"
#include "torch_npu/csrc/npu/Event.h"
namespace c10d_npu {
// Backend name reported by ProcessGroupLCCL::getBackendName().
// Declared `inline` (C++17) so every translation unit including this header
// shares a single object. A plain namespace-scope `const` would have internal
// linkage: one copy per TU, each with its own dynamic initializer. The in-class
// (implicitly inline) getBackendName() would then odr-use a different entity
// in every TU.
inline const std::string LCCL_BACKEND_NAME = "lccl";
/**
 * c10d backend whose collectives are issued through LCCL communicators
 * (at_npu::lccl::LcclComm) on NPU devices.
 *
 * Communicators, NPU streams and NPU events are held in maps keyed by a
 * string device key. These presumably are populated lazily via getLCCLComm();
 * verify against the implementation.
 */
class ProcessGroupLCCL : public c10d::Backend {
public:
    /**
     * Work handle returned by the collectives of ProcessGroupLCCL.
     *
     * Holds the devices and communicators the collective ran on, plus start/end
     * NPU events for those devices. Completion, wait and timeout logic are
     * implemented out of line.
     *
     * NOTE(review): the collectives return this type as c10::intrusive_ptr, so
     * shared_from_this() (from std::enable_shared_from_this) is only usable if an
     * instance is also owned by a std::shared_ptr. Confirm whether this base is needed.
     */
    class WorkLCCL : public c10d::Work, public std::enable_shared_from_this<WorkLCCL> {
    public:
        /// Creates a work handle for a collective running on `devices`.
        explicit WorkLCCL(const std::vector<at::Device>& devices);
        ~WorkLCCL() override;

        bool isCompleted() override;
        bool isSuccess() const override;
        /// Waits for the collective. How `timeout` is applied is defined in the implementation.
        bool wait(std::chrono::milliseconds timeout) override;
        void synchronize() override;
        /// Whether the NPU work tracked by this handle has finished executing.
        bool finishedNPUExecution();
        std::vector<at::Tensor> result() override;
        /// Future tied to this work. Kept public to match the access of
        /// c10d::Work::getFuture(), so it is callable through a WorkLCCL as well
        /// as through the base class.
        c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;

    protected:
        std::vector<at::Device> devices_;
        std::vector<at_npu::lccl::LcclComm> lcclComms_;
        std::shared_ptr<std::vector<c10_npu::NPUEvent>> lcclStartEvents_;
        std::shared_ptr<std::vector<c10_npu::NPUEvent>> lcclEndEvents_;
        bool blockingWait_ = false;
        // Value-initialized so the timeout is never indeterminate if a
        // constructor does not assign it.
        std::chrono::milliseconds opTimeout_{};
        std::chrono::time_point<std::chrono::steady_clock> workStartTime_;

    private:
        void synchronizeInternal(std::chrono::milliseconds timeout);
        void checkAndSetException() const;
        void checkAndThrowException() const;
        bool finishedNPUExecutionInternal() const;

        std::shared_ptr<std::vector<at::Tensor>> outputs_;
        c10::intrusive_ptr<c10d::Store> store_;
        c10::intrusive_ptr<at::ivalue::Future> future_;
        std::vector<at::Tensor> lazy_destroy_tensors_;

        // The owning process group fills in the protected/private state above.
        friend class ProcessGroupLCCL;
    };

    /**
     * @param store c10d key/value store handed to this backend.
     * @param rank  rank of this process within the group.
     * @param size  number of ranks in the group.
     */
    ProcessGroupLCCL(
        const c10::intrusive_ptr<c10d::Store>& store,
        int rank,
        int size);
    ~ProcessGroupLCCL() override;

    /// Returns LCCL_BACKEND_NAME ("lccl").
    const std::string getBackendName() const override
    {
        return LCCL_BACKEND_NAME;
    }

    c10::intrusive_ptr<c10d::Work> allreduce(
        std::vector<at::Tensor>& tensors,
        const c10d::AllreduceOptions& opts = c10d::AllreduceOptions()) override;

    c10::intrusive_ptr<c10d::Work> allgather(
        std::vector<std::vector<at::Tensor>>& outputTensors,
        std::vector<at::Tensor>& inputTensors,
        const c10d::AllgatherOptions& opts = c10d::AllgatherOptions()) override;

    c10::intrusive_ptr<c10d::Work> broadcast(
        std::vector<at::Tensor>& tensors,
        const c10d::BroadcastOptions& opts = c10d::BroadcastOptions()) override;

    c10::intrusive_ptr<c10d::Work> reduce_scatter(
        std::vector<at::Tensor>& outputTensors,
        std::vector<std::vector<at::Tensor>>& inputTensors,
        const c10d::ReduceScatterOptions& opts = c10d::ReduceScatterOptions()) override;

    // Defined out of line. Presumably the default operation timeout in
    // milliseconds (per its name); confirm in the implementation file.
    static const int64_t kProcessGroupLCCLOpTimeoutMillis;

protected:
    /// Returns the LCCL communicators for `devices`, identified by `devicesKey`.
    std::vector<at_npu::lccl::LcclComm>& getLCCLComm(
        const std::string& devicesKey,
        const std::vector<at::Device>& devices);

    c10::intrusive_ptr<c10d::Store> store_;
    bool blockingWait_ = false;
    // Value-initialized so the timeout is never indeterminate if the
    // constructor does not assign it.
    std::chrono::milliseconds opTimeout_{};
    // Per-device-key state, keyed by the same string passed to getLCCLComm().
    std::unordered_map<std::string, std::vector<c10_npu::NPUStream>> lcclStreams_;
    std::unordered_map<std::string, std::vector<at_npu::lccl::LcclComm>> devLCCLCommMap_;
    std::unordered_map<std::string, std::vector<c10_npu::NPUEvent>> lcclEvents_;
    // NOTE(review): presumably guards the maps above; verify which members it protects.
    std::mutex mutex_;

private:
    // Shared driver for the collectives above: invokes `fn` for the given
    // input/output tensors and tags the work with `opType`.
    template <typename Fn>
    c10::intrusive_ptr<c10d::Work> collective(
        std::vector<at::Tensor>& input,
        std::vector<at::Tensor>& output,
        Fn fn,
        c10d::OpType opType);

    // As above, additionally taking `pre`/`post` callables. When they run
    // relative to `fn` is defined in the implementation.
    template <typename Fn, typename PreProcess, typename PostProcess>
    c10::intrusive_ptr<c10d::Work> collective(
        std::vector<at::Tensor>& input,
        std::vector<at::Tensor>& output,
        Fn fn,
        PreProcess pre,
        PostProcess post,
        c10d::OpType opType);
};
}