#include "ProcessGroupLCCL.hpp"
#include "torch_npu/csrc/core/NPUBridge.h"
#include "torch_npu/csrc/core/npu/DeviceUtils.h"
#include "torch_npu/csrc/core/npu/NPUGuard.h"
#include "torch_npu/csrc/core/npu/NPUStreamUtils.h"
#include "torch_npu/csrc/framework/FormatHelper.h"
#include "torch_npu/csrc/framework/OpCommand.h"
namespace c10d_npu {
namespace {
constexpr int64_t kSynchronizeBusyWaitMillis = 10;
void syncStreams(const std::vector<at::Device> &devices, std::vector<c10_npu::NPUEvent> &lcclEvents,
std::vector<c10_npu::NPUStream> &lcclStreams)
{
for (size_t i = 0; i < devices.size(); ++i) {
c10_npu::NPUStream &lcclStream = lcclStreams[i];
c10_npu::NPUEvent &lcclEvent = lcclEvents[i];
lcclEvent.record(c10_npu::getCurrentNPUStream(devices[i].index()));
ASCEND_LOGI("Event: record lccl group is successfully executed, event=%p", lcclEvent.event());
lcclEvent.block(lcclStream);
ASCEND_LOGI("Event: block lccl group is successfully executed, event=%p", lcclEvent.event());
}
}
}
ProcessGroupLCCL::WorkLCCL::WorkLCCL(const std::vector<at::Device> &devices)
: devices_(devices), workStartTime_(std::chrono::steady_clock::now())
{
lcclEndEvents_ = std::make_shared<std::vector<c10_npu::NPUEvent>>(devices.size());
lcclComms_.resize(devices.size());
}
ProcessGroupLCCL::WorkLCCL::~WorkLCCL()
{}
bool ProcessGroupLCCL::WorkLCCL::isCompleted()
{
checkAndSetException();
return exception() || finishedNPUExecutionInternal();
}
bool ProcessGroupLCCL::WorkLCCL::isSuccess() const
{
if (exception()) {
return false;
}
return finishedNPUExecutionInternal();
}
void ProcessGroupLCCL::WorkLCCL::synchronizeInternal(std::chrono::milliseconds timeout)
{
for (const auto i: c10::irange(devices_.size())) {
auto currentStream = c10_npu::getCurrentNPUStream(devices_[i].index());
(*lcclEndEvents_)[i].block(currentStream);
ASCEND_LOGI("Event: block lccl work is successfully executed, event=%p", (*lcclEndEvents_)[i].event());
}
if (blockingWait_) {
while (!isCompleted()) {
auto currentTimepoint = std::chrono::steady_clock::now();
if (std::chrono::duration_cast<std::chrono::milliseconds>(currentTimepoint - workStartTime_) > opTimeout_) {
throw std::runtime_error("Operation has exceeded timeout limit!");
}
checkAndThrowException();
std::this_thread::sleep_for(std::chrono::milliseconds(kSynchronizeBusyWaitMillis));
}
checkAndThrowException();
}
}
bool ProcessGroupLCCL::WorkLCCL::wait(std::chrono::milliseconds timeout)
{
synchronizeInternal(timeout);
return true;
}
void ProcessGroupLCCL::WorkLCCL::synchronize()
{
synchronizeInternal(kNoTimeout);
}
bool ProcessGroupLCCL::WorkLCCL::finishedNPUExecution()
{
checkAndSetException();
return finishedNPUExecutionInternal();
}
std::vector<at::Tensor> ProcessGroupLCCL::WorkLCCL::result()
{
return *outputs_;
}
void ProcessGroupLCCL::WorkLCCL::checkAndThrowException() const
{
checkAndSetException();
if (exception()) {
std::rethrow_exception(exception());
}
}
void ProcessGroupLCCL::WorkLCCL::checkAndSetException() const
{
if (exception()) {
return;
}
}
bool ProcessGroupLCCL::WorkLCCL::finishedNPUExecutionInternal() const
{
if (!c10_npu::NpuSysCtrl::GetInstance().GetInitFlag()) {
return false;
}
try {
for (const auto i: c10::irange(devices_.size())) {
if (!(*lcclEndEvents_)[i].query()) {
return false;
}
}
} catch (const std::exception &e) {
if (std::string(e.what()).find("driver shutting down") == std::string::npos) {
throw std::runtime_error(DIST_ERROR(ErrCode::INTERNAL));
}
LOG(INFO) << "[Rank " << rank_ << "] Event query failed with exception: " << e.what();
}
return true;
}
c10::intrusive_ptr<c10::ivalue::Future> ProcessGroupLCCL::WorkLCCL::getFuture()
{
return future_;
}
const int64_t ProcessGroupLCCL::kProcessGroupLCCLOpTimeoutMillis = 10 * 1000;
ProcessGroupLCCL::ProcessGroupLCCL(const c10::intrusive_ptr<c10d::Store> &store, int rank, int size)
: c10d::Backend(rank, size), blockingWait_(false), store_(store),
opTimeout_(ProcessGroupLCCL::kProcessGroupLCCLOpTimeoutMillis)
{}
std::vector<at_npu::lccl::LcclComm> &ProcessGroupLCCL::getLCCLComm(
const std::string &devicesKey,
const std::vector<at::Device> &devices)
{
if (devicesKey.empty()) {
throw std::runtime_error("Not able to create/get the lccll Communicator since "
"the NPU devices are not known" +
DIST_ERROR(ErrCode::PARAM));
}
{
std::lock_guard<std::mutex> lock(mutex_);
if (devLCCLCommMap_.find(devicesKey) != devLCCLCommMap_.end()) {
return devLCCLCommMap_[devicesKey];
}
}
std::vector<at_npu::lccl::LcclComm> lcclComms;
lcclComms.resize(devices.size());
c10_npu::OptionalNPUGuard npuGuard;
std::vector<c10_npu::NPUStream> streamVal;
streamVal.reserve(devices.size());
for (size_t i = 0; i < devices.size(); ++i) {
npuGuard.set_index(devices[i].index());
auto ret = at_npu::lccl::LcclCommInitRankLocal(size_, rank_, &lcclComms[i]);
TORCH_CHECK(ret == 0, "init lccl comm failed, error code:", ret, PTA_ERROR(ErrCode::INTERNAL));
streamVal.push_back(c10_npu::getNPUStreamFromPool(devices[i].index()));
}
lcclStreams_.emplace(devicesKey, std::move(streamVal));
lcclEvents_.emplace(std::piecewise_construct, std::make_tuple(devicesKey), std::make_tuple(devices.size()));
std::lock_guard<std::mutex> lock(mutex_);
devLCCLCommMap_.emplace(devicesKey, std::move(lcclComms));
return devLCCLCommMap_[devicesKey];
}
template<typename Fn, typename PreProcess, typename PostProcess>
c10::intrusive_ptr<c10d::Work> ProcessGroupLCCL::collective(std::vector<at::Tensor> &inputs,
std::vector<at::Tensor> &outputs, Fn fn, PreProcess pre,
PostProcess post, c10d::OpType opType)
{
const auto devices = getDeviceList(inputs);
for (const auto& device : devices) {
c10_npu::detail::checkCurrentStreamNotExternal(device.index(), "ProcessGroupLCCL::collective");
}
auto key = getKeyFromDevices(devices);
std::vector<at_npu::lccl::LcclComm> lcclComms;
lcclComms = getLCCLComm(key, devices);
auto &lcclStreams = lcclStreams_[key];
syncStreams(devices, lcclEvents_[key], lcclStreams);
auto work = c10::make_intrusive<ProcessGroupLCCL::WorkLCCL>(devices);
work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputs);
c10_npu::OptionalNPUGuard npuGuard;
pre(lcclStreams, work);
for (const auto i: c10::irange(inputs.size())) {
npuGuard.set_index(devices[i].index());
c10_npu::NPUStream &lcclStream = lcclStreams[i];
c10_npu::NPUCachingAllocator::recordStream(inputs[i].storage().data_ptr(), lcclStream);
}
{
for (const auto i: c10::irange(inputs.size())) {
npuGuard.set_index(devices[i].index());
c10_npu::NPUStream &lcclStream = lcclStreams[i];
auto ret = fn(inputs[i], outputs[i], lcclComms[i], lcclStream);
TORCH_CHECK(ret == 0, "LCCL function error:", opTypeToString(opType).c_str(), ", error code is", ret, "\n");
}
}
post(lcclStreams, work);
{
c10_npu::NPUMultiStreamGuard guard(lcclStreams);
work->future_ = c10::make_intrusive<at::ivalue::Future>(c10::ListType::create(c10::TensorType::get()), devices);
work->future_->markCompleted(at::IValue(*work->outputs_));
}
for (size_t i = 0; i < inputs.size(); ++i) {
c10_npu::NPUStream &lcclStream = lcclStreams_[key][i];
(*(work->lcclEndEvents_))[i].record(lcclStream);
ASCEND_LOGI("Event: record lccl work is successfully executed, event=%p", (*(work->lcclEndEvents_))[i].event());
work->lcclComms_[i] = lcclComms[i];
}
work->blockingWait_ = blockingWait_;
work->opTimeout_ = opTimeout_;
return work;
}
template<typename Fn>
c10::intrusive_ptr<c10d::Work> ProcessGroupLCCL::collective(std::vector<at::Tensor> &inputs,
std::vector<at::Tensor> &outputs, Fn fn,
c10d::OpType opType)
{
return collective(
inputs, outputs, fn,
[](std::vector<c10_npu::NPUStream> &, c10::intrusive_ptr<ProcessGroupLCCL::WorkLCCL> &) {
},
[](std::vector<c10_npu::NPUStream> &, c10::intrusive_ptr<ProcessGroupLCCL::WorkLCCL> &) {
},
opType);
}
c10::intrusive_ptr<c10d::Work> ProcessGroupLCCL::allreduce(std::vector<at::Tensor> &tensors,
const c10d::AllreduceOptions &opts)
{
checkTensors(tensors);
std::vector<at::Tensor> tensors_cp = {tensors[0]};
std::string functionName = __FUNCTION__;
return collective(
tensors_cp, tensors_cp,
[&](at::Tensor &input, at::Tensor &output, at_npu::lccl::LcclComm comm, c10_npu::NPUStream &stream) {
auto lcclType = getLcclDataType(input.scalar_type());
checkSupportedDataType(lcclType, functionName);
RECORD_FUNCTION("LcclAllreduce", std::vector<c10::IValue>({input}));
auto inputDataPtr = input.data_ptr();
auto outputDataPtr = output.data_ptr();
auto numel = getNumelForLCCL(input);
auto lcclReduceOp = getLcclReduceOp(opts.reduceOp, input);
auto lccl_call = [inputDataPtr, outputDataPtr, numel, lcclType, lcclReduceOp, stream, comm]() -> int {
auto lccl_result = at_npu::lccl::LcclAllReduce(inputDataPtr, outputDataPtr, numel, lcclType,
lcclReduceOp, comm, stream.stream(false));
return lccl_result;
};
at_npu::native::OpCommand::RunOpApiV3("LcclAllreduce", lccl_call, false, &stream);
return 0;
},
c10d::OpType::ALLREDUCE);
}
c10::intrusive_ptr<c10d::Work> ProcessGroupLCCL::allgather(std::vector<std::vector<at::Tensor>> &outputTensors,
std::vector<at::Tensor> &inputTensors,
const c10d::AllgatherOptions &opts)
{
checkTensors(inputTensors);
auto inputTensors_ = castOriginFormat(inputTensors);
bool same_size = CheckTensorsSameSize(outputTensors.back());
if (same_size) {
auto outputFlattened = FlattenForScatterGather(outputTensors, inputTensors, size_);
checkTensors(outputFlattened);
return collective(
inputTensors_, outputFlattened,
[&](at::Tensor &input, at::Tensor &output, at_npu::lccl::LcclComm comm, c10_npu::NPUStream &stream) {
RECORD_FUNCTION("LcclAllgather", std::vector<c10::IValue>({input}));
c10_npu::NPUCachingAllocator::recordStream(output.storage().data_ptr(), stream);
auto inputDataPtr = input.data_ptr();
auto outputDataPtr = output.data_ptr();
auto numel = getNumelForLCCL(input);
auto lcclType = getLcclDataType(input.scalar_type());
auto lccl_call = [inputDataPtr, outputDataPtr, numel, lcclType, comm, stream]() -> int {
auto lccl_result = at_npu::lccl::LcclAllGather(inputDataPtr, outputDataPtr, numel, lcclType, comm,
stream.stream(false));
return lccl_result;
};
at_npu::native::OpCommand::RunOpApiV3("LcclAllgather", lccl_call, false, &stream);
return 0;
},
[&](std::vector<c10_npu::NPUStream> &, c10::intrusive_ptr<ProcessGroupLCCL::WorkLCCL> &) {
},
[&](std::vector<c10_npu::NPUStream> &lcclStreams, c10::intrusive_ptr<ProcessGroupLCCL::WorkLCCL> &work) {
for (const auto i: c10::irange(outputTensors.size())) {
c10_npu::NPUStreamGuard guard(lcclStreams[i]);
for (const auto j: c10::irange(outputTensors[0].size())) {
c10_npu::NPUCachingAllocator::recordStream(outputTensors[i][j].storage().data_ptr(),
lcclStreams[i]);
outputTensors[i][j].copy_(outputFlattened[i][j], true);
}
}
},
c10d::OpType::ALLGATHER);
} else {
TORCH_CHECK(false, "lccl doesn't support to all_gather different shape");
}
}
c10::intrusive_ptr<c10d::Work> ProcessGroupLCCL::broadcast(std::vector<at::Tensor> &tensors,
const c10d::BroadcastOptions &opts)
{
checkTensors(tensors);
return collective(
tensors, tensors,
[&](at::Tensor &input, at::Tensor &output, at_npu::lccl::LcclComm comm, c10_npu::NPUStream &stream) {
RECORD_FUNCTION("LcclBroadcast", std::vector<c10::IValue>({input}));
const auto root = opts.rootRank * tensors.size() + opts.rootTensor;
auto inputDataPtr = input.data_ptr();
auto numel = getNumelForLCCL(input);
auto lcclType = getLcclDataType(input.scalar_type());
auto lccl_call = [inputDataPtr, numel, lcclType, root, comm, stream]() -> int {
auto lccl_result =
at_npu::lccl::LcclBroadcast(inputDataPtr, numel, lcclType, root, comm, stream.stream(false));
return lccl_result;
};
at_npu::native::OpCommand::RunOpApiV3("LcclBroadcast", lccl_call, false, &stream);
return 0;
},
c10d::OpType::BROADCAST);
}
c10::intrusive_ptr<c10d::Work> ProcessGroupLCCL::reduce_scatter(std::vector<at::Tensor> &outputTensors,
std::vector<std::vector<at::Tensor>> &inputTensors,
const c10d::ReduceScatterOptions &opts)
{
checkTensors(outputTensors);
bool same_size = CheckTensorsSameSize(inputTensors.back());
if (same_size) {
auto inputFlattened = FlattenForScatterGather(inputTensors, outputTensors, size_);
checkTensors(inputFlattened);
std::string functionName = __FUNCTION__;
return collective(
inputFlattened, outputTensors,
[&](at::Tensor &input, at::Tensor &output, at_npu::lccl::LcclComm comm, c10_npu::NPUStream &stream) {
auto lcclType = getLcclDataType(input.scalar_type());
checkSupportedDataType(lcclType, functionName);
RECORD_FUNCTION("LcclReduceScatter", std::vector<c10::IValue>({input}));
c10_npu::NPUCachingAllocator::recordStream(output.storage().data_ptr(), stream);
auto inputDataPtr = input.data_ptr();
auto outputDataPtr = output.data_ptr();
auto numel = getNumelForLCCL(output);
auto lcclReduceOp = getLcclReduceOp(opts.reduceOp, input);
auto lccl_call = [inputDataPtr, outputDataPtr, numel, lcclType, lcclReduceOp, stream, comm]() -> int {
auto lccl_result = at_npu::lccl::LcclReduceScatter(inputDataPtr, outputDataPtr, numel, lcclType,
lcclReduceOp, comm, stream.stream(false));
return lccl_result;
};
at_npu::native::OpCommand::RunOpApiV3("LcclReduceScatter", lccl_call, false, &stream);
return 0;
},
[&](std::vector<c10_npu::NPUStream> &lcclStreams, c10::intrusive_ptr<ProcessGroupLCCL::WorkLCCL> &work) {
for (const auto i: c10::irange(inputTensors.size())) {
c10_npu::NPUStreamGuard guard(lcclStreams[i]);
for (const auto j: c10::irange(inputTensors[0].size())) {
c10_npu::NPUCachingAllocator::recordStream(inputTensors[i][j].storage().data_ptr(),
lcclStreams[i]);
inputFlattened[i][j].copy_(inputTensors[i][j], true);
}
}
},
[&](std::vector<c10_npu::NPUStream> &, c10::intrusive_ptr<ProcessGroupLCCL::WorkLCCL> &) {
},
c10d::OpType::REDUCE_SCATTER);
} else {
TORCH_CHECK(false, "lccl doesn't support to reduce_scatter different shape");
}
}
ProcessGroupLCCL::~ProcessGroupLCCL()
{
{
std::lock_guard<std::mutex> lock(mutex_);
for (auto &it: devLCCLCommMap_) {
auto &lcclComms = it.second;
for (const auto &lcclComm: lcclComms) {
at_npu::lccl::LcclCommDestroy(lcclComm);
}
}
}
}
}