#include <ATen/record_function.h>
#include <algorithm>
#include <map>
#include <mutex>
#include <tuple>
#include <unordered_set>
#include <unistd.h>
#include <fstream>
#include <iostream>
#include <functional>
#include <cstdlib>
#include <linux/limits.h>
#ifndef BUILD_LIBTORCH
#include <pybind11/pybind11.h>
#include <pybind11/eval.h>
#include <pybind11/embed.h>
#endif
#include <c10/util/Optional.h>
#include <c10/util/irange.h>
#include <c10d/ParamCommsUtils.hpp>
#include <c10d/TraceUtils.h>
#include <c10d/Utils.hpp>
#include <c10d/TCPStore.hpp>
#include <c10d/PrefixStore.hpp>
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
#include "torch_npu/csrc/distributed/control_plane/Handlers.hpp"
#include <arpa/inet.h>
#include "op_plugin/OpInterface.h"
#include "third_party/acl/inc/acl/acl.h"
#include "third_party/acl/inc/acl/acl_base.h"
#include "torch_npu/csrc/aten/CustomFunctions.h"
#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
#include "torch_npu/csrc/core/npu/GetCANNInfo.h"
#include "torch_npu/csrc/core/npu/NPUFunctions.h"
#include "torch_npu/csrc/core/NPUBridge.h"
#include "torch_npu/csrc/core/NPUStorageImpl.h"
#include "torch_npu/csrc/core/npu/NPUCachingAllocator.h"
#include "torch_npu/csrc/core/npu/NPUGuard.h"
#include "torch_npu/csrc/core/npu/NPUGraph.h"
#include "torch_npu/csrc/core/npu/NPUGraphsUtils.h"
#include "torch_npu/csrc/core/npu/NPUAffinityController.h"
#include "torch_npu/csrc/core/npu/NPUStream.h"
#include "torch_npu/csrc/core/npu/NPUStreamUtils.h"
#include "torch_npu/csrc/core/npu/register/OptionsManager.h"
#include "torch_npu/csrc/core/npu/sys_ctrl/npu_sys_ctrl.h"
#include "torch_npu/csrc/core/npu/interface/OpInterface.h"
#include "torch_npu/csrc/distributed/HCCLUtils.hpp"
#include "torch_npu/csrc/distributed/HcclCompile.h"
#include "torch_npu/csrc/distributed/TraceUtils.h"
#include "torch_npu/csrc/framework/OpHook.h"
#include "torch_npu/csrc/framework/FormatHelper.h"
#include "torch_npu/csrc/framework/utils/OpPreparation.h"
#include "torch_npu/csrc/logging/LogContext.h"
#include "torch_npu/csrc/core/npu/NpuVariables.h"
#include "torch_npu/csrc/distributed/ProcessGroupHCCL.hpp"
#ifndef BUILD_LIBTORCH
#include "torch_npu/csrc/toolkit/profiler/common/utils.h"
#include "torch_npu/csrc/profiler/npu_profiler.h"
namespace py = pybind11;
using namespace py::literals;
#endif
namespace c10d_npu {
namespace {
// Offset added to the HCCL exec timeout when deriving kOpWaitTimeout in the
// ProcessGroupHCCL constructor (used when no event timeout is configured).
static constexpr uint32_t kOpWaitTimeoutOffset = 30U;
// Value passed to AclrtSetOpWaitTimeout; recomputed once, under
// kOpWaitTimeoutInitFlag, by the ProcessGroupHCCL constructor.
static uint32_t kOpWaitTimeout = 1868U;
static std::once_flag kOpWaitTimeoutInitFlag;
// Fallback used when OptionsManager::GetHCCLExecTimeout() does not return a
// usable value (see checkExec and the ProcessGroupHCCL constructor).
static int32_t defaultExecTimeout = 1836;
// NOTE(review): presumably a suffix for keys of point-to-point communicators;
// its use is outside this part of the file - confirm.
constexpr const char* P2P_DEVICE_KEY = "_p2p";
using hcclUs = std::chrono::steady_clock::time_point;
constexpr int32_t MAX_GROUP_NAME_LEN = 128;
constexpr int32_t NSLB_JOBID_OFFSET = 32;
// Bit flags (distinct bits, so they can be OR-ed together).
// NOTE(review): names suggest coalescing state; usage is not visible here.
static constexpr int CoalActive = 0x01, CoalColl = 0x02, CoalP2P = 0x04;
// c10d reduce ops that map directly onto an HCCL reduce op; consulted by
// getHcclReduceOp after its special cases.
std::map<c10d::ReduceOp, HcclReduceOp> hcclOp = {
{c10d::ReduceOp::MIN, HCCL_REDUCE_MIN},
{c10d::ReduceOp::MAX, HCCL_REDUCE_MAX},
{c10d::ReduceOp::SUM, HCCL_REDUCE_SUM},
{c10d::ReduceOp::PRODUCT, HCCL_REDUCE_PROD},
};
// Reduce ops rejected by getHcclReduceOp, with the name used in its error text.
std::map<c10d::ReduceOp, std::string> unsupportedOp = {
{c10d::ReduceOp::BAND, "BAND"},
{c10d::ReduceOp::BOR, "BOR"},
{c10d::ReduceOp::BXOR, "BXOR"}
};
bool nslb_is_end = false;
// Last device error keyword found by the event-query helpers
// (startedNPUExecutionInternal / finishedNPUExecutionInternal).
std::string device_error_msg;
// Set to true by the same helpers when an event query reports "FORCE STOP".
bool force_stop_error_flag = false;
// The three values below are read from OptionsManager once, during static
// initialization of this translation unit.
const char* nslb_path = c10_npu::option::OptionsManager::GetNslbPath();
bool status_save_enable = c10_npu::option::OptionsManager::CheckStatusSaveEnable();
std::string status_save_path = c10_npu::option::OptionsManager::GetStatusSavePath();
// Chooses a stream for `device` based on the kind of its current stream:
// a sync-launch current stream yields a stream from the sync-launch pool,
// otherwise a stream is taken from the regular pool, with the priority
// controlled by the TORCH_HCCL_HIGH_PRIORITY setting.
inline c10_npu::NPUStream getNPUStreamByCurrentType(c10::DeviceIndex device = -1)
{
    auto activeStream = c10_npu::getCurrentNPUStream(device);
    if (activeStream.isSyncLaunchStream()) {
        return c10_npu::getNPUStreamFromSyncLaunchPool(device);
    }

    const bool wantHighPriority = c10d::getCvarBool(TORCH_HCCL_HIGH_PRIORITY, false);
    auto pooledStream = c10_npu::getStreamFromPool(wantHighPriority, device);
    TORCH_NPU_HCCL_LOGD("Get stream, stream id: %zu", static_cast<size_t>(pooledStream.id()))
    return pooledStream;
}
// Returns the product of the storage sizes recorded in the tensor's NPU
// storage descriptor (1 when the size list is empty).
int64_t physical_numel(const at::Tensor& self)
{
    auto storageSizes = torch_npu::NPUBridge::GetNpuStorageImpl(self)->npu_desc_.storage_sizes_;
    int64_t product = 1;
    for (auto it = storageSizes.begin(); it != storageSizes.end(); ++it) {
        product *= *it;
    }
    return product;
}
// Element count to hand to HCCL for `self`.
// Base-format tensors use the logical numel(). Tensors in an internal format
// use the physical element count of their storage and must start at the
// beginning of that storage; otherwise a NOT_SUPPORT error is raised.
uint64_t getNumelForHCCL(const at::Tensor& self)
{
    if (at_npu::native::FormatHelper::IsBaseFormatType(self)) {
        return self.numel();
    }

    const bool startsAtStorageBegin = (self.storage().data_ptr().get() == self.data_ptr());
    TORCH_CHECK(startsAtStorageBegin, "For a tensor of internal format, it's storage_offset must be 0", DIST_ERROR(ErrCode::NOT_SUPPORT));
    return physical_numel(self);
}
// Translates a c10d reduce op into the HCCL reduce op to launch.
//  - AVG and PREMUL_SUM are both issued as SUM (PREMUL_SUM only for
//    half/float/bfloat16/double inputs).
//  - SUM on bool tensors is issued as MAX.
//  - BAND/BOR/BXOR are rejected with NOT_SUPPORT; anything not present in
//    the hcclOp table is rejected with NOT_FOUND.
HcclReduceOp getHcclReduceOp(const c10d::ReduceOp reduceOp, at::Tensor& input)
{
    if (reduceOp == c10d::ReduceOp::AVG) {
        return HCCL_REDUCE_SUM;
    }

    if (reduceOp == c10d::ReduceOp::PREMUL_SUM) {
        const auto dtype = input.scalar_type();
        const bool isAllowedType = dtype == at::kHalf || dtype == at::kFloat ||
            dtype == at::kBFloat16 || dtype == at::kDouble;
        TORCH_CHECK(
            isAllowedType,
            "PreMulSum Data type must be half, float, bfloat16 or double",
            DIST_ERROR(ErrCode::TYPE));
        return HCCL_REDUCE_SUM;
    }

    if (reduceOp == c10d::ReduceOp::SUM && input.scalar_type() == at::kBool) {
        return HCCL_REDUCE_MAX;
    }

    const auto rejected = unsupportedOp.find(reduceOp);
    if (rejected != unsupportedOp.end()) {
        TORCH_CHECK(false,
            "Cannot use ReduceOp." + rejected->second + " with HCCL",
            DIST_ERROR(ErrCode::NOT_SUPPORT));
    }

    const auto mapped = hcclOp.find(reduceOp);
    if (mapped == hcclOp.end()) {
        TORCH_CHECK(false, "Unhandled ReduceOp", DIST_ERROR(ErrCode::NOT_FOUND));
    }
    return mapped->second;
}
// Raises a NOT_SUPPORT error naming `functionName` unless `type` is one of
// the HCCL data types listed below.
void checkSupportedDataType(HcclDataType type, std::string functionName)
{
    static std::set<HcclDataType> supportedDataTypes = {
        HCCL_DATA_TYPE_INT8,
        HCCL_DATA_TYPE_INT16,
        HCCL_DATA_TYPE_INT32,
        HCCL_DATA_TYPE_FP16,
        HCCL_DATA_TYPE_FP32,
        HCCL_DATA_TYPE_BFP16,
        HCCL_DATA_TYPE_INT64};

    const bool isSupported = supportedDataTypes.find(type) != supportedDataTypes.end();
    TORCH_CHECK(
        isSupported,
        "HCCL "+functionName+": Unsupported data type ",
        getHcclDataTypeSerialString(type), DIST_ERROR(ErrCode::NOT_SUPPORT));
}
// Builds a comma-separated list of the device indices, in input order
// (empty string for an empty list).
std::string getKeyFromDevices(const std::vector<at::Device>& devices)
{
    std::string key;
    bool needSeparator = false;
    for (const auto& dev : devices) {
        if (needSeparator) {
            key += ",";
        }
        key += std::to_string(dev.index());
        needSeparator = true;
    }
    return key;
}
// Returns the index of the first device as a string.
// The list must be non-empty; this is not checked here.
std::string getKeyFromDevice(const std::vector<at::Device>& devices)
{
    const auto& firstDevice = devices.front();
    return std::to_string(firstDevice.index());
}
// Collects the device of every tensor, preserving order.
std::vector<at::Device> getDeviceList(const std::vector<at::Tensor>& tensors)
{
    std::vector<at::Device> deviceList;
    deviceList.reserve(tensors.size());
    for (size_t idx = 0; idx < tensors.size(); ++idx) {
        deviceList.push_back(tensors[idx].device());
    }
    return deviceList;
}
// Returns a one-element list holding the device of the first tensor.
// The tensor list must be non-empty; this is not checked here.
std::vector<at::Device> getDevice(const std::vector<at::Tensor>& tensors)
{
    std::vector<at::Device> single;
    single.reserve(1);
    single.push_back(tensors.front().device());
    return single;
}
// Maps a rank to an NPU device by taking the rank modulo the number of
// visible NPUs. Raises a VALUE error for a negative rank or when no NPU is
// available.
at::Device getDeviceForRank(int rank)
{
    TORCH_CHECK(rank >= 0, "Invalid rank ", rank, DIST_ERROR(ErrCode::VALUE));

    auto numNPUs = c10_npu::device_count();
    TORCH_CHECK(numNPUs > 0, "Invalid device number", numNPUs, DIST_ERROR(ErrCode::VALUE));

    const auto deviceIdx = static_cast<int16_t>(rank % numNPUs);
    return at::Device(c10::DeviceType::PrivateUse1, deviceIdx);
}
// Builds the key for a send/recv pair as "<lower rank>:<higher rank>", so
// both peers of a pair produce the same key.
std::string getKeySendRecv(int myRank, int peer)
{
    const auto [lowRank, highRank] = std::minmax(myRank, peer);
    std::string key = std::to_string(lowRank);
    key += ":";
    key += std::to_string(highRank);
    return key;
}
// For each device i: records hcclEvents[i] on that device's current stream
// and makes hcclStreams[i] wait on it, so work queued on the HCCL stream is
// ordered after what has already been queued on the current stream.
// hcclEvents and hcclStreams are indexed in parallel with devices.
void syncStreams(
    const std::vector<at::Device>& devices,
    std::vector<c10_npu::NPUEvent>& hcclEvents,
    std::vector<c10_npu::NPUStream>& hcclStreams)
{
    const size_t deviceCount = devices.size();
    for (size_t idx = 0; idx < deviceCount; ++idx) {
        c10_npu::NPUEvent& event = hcclEvents[idx];
        c10_npu::NPUStream& targetStream = hcclStreams[idx];
        event.record(c10_npu::getCurrentNPUStream(devices[idx].index()));
        event.block(targetStream);
        TORCH_NPU_HCCL_LOGI("Event: record and block hccl group is successfully executed, event=%p", event.event());
    }
}
// Extracts a printable message from an exception_ptr by rethrowing it.
// std::exception-derived exceptions yield what(); anything else yields a
// fixed placeholder. A null pointer is rejected with a PTR error.
std::string getExceptionMsgFromExceptionPtr(const std::exception_ptr& exceptionPtr)
{
    TORCH_CHECK(exceptionPtr != nullptr, DIST_ERROR(ErrCode::PTR));

    std::string message;
    try {
        std::rethrow_exception(exceptionPtr);
    } catch (const std::exception& e) {
        message = e.what();
    } catch (...) {
        message = "Unknown exception type";
    }
    return message;
}
// Returns whether deterministic behaviour is requested.
// The answer is computed on the first call and cached for the process:
//  - true if the HCCL_DETERMINISTIC environment variable equals "true"
//    (case-insensitive);
//  - otherwise the value of at::globalContext().deterministicAlgorithms()
//    at the time of that first call.
bool getDeterministicState()
{
    static bool cachedDeterministicState = []() {
        const char* envValue = std::getenv("HCCL_DETERMINISTIC");
        if (envValue != nullptr) {
            TORCH_NPU_ENV_LOGI("get env HCCL_DETERMINISTIC = %s", envValue);
            std::string valueStr(envValue);
            // tolower() is only defined for values representable as unsigned
            // char (or EOF); go through unsigned char so that bytes >= 0x80 in
            // the environment value cannot trigger undefined behaviour.
            std::transform(valueStr.begin(), valueStr.end(), valueStr.begin(),
                [](unsigned char ch) { return static_cast<char>(::tolower(ch)); });
            if (valueStr == "true") {
                return true;
            }
        }
        return at::globalContext().deterministicAlgorithms();
    }();
    return cachedDeterministicState;
}
// Initializes `config` and fills in the fields this backend controls:
//  - hcclBufferSize: P2P buffer size when isP2P is set, otherwise the
//    regular HCCL buffer size (both from OptionsManager);
//  - hcclDeterministic: 0xffffffff when CANN is at least 8.2.RC1, otherwise
//    1/0 from getDeterministicState().
// When the comm-name capability is not supported, the leading size_t of the
// struct is overwritten with 32.
// NOTE(review): that leading field is presumably the config size used for
// struct versioning - confirm against the HcclCommConfig definition.
void getHcclCommConfig(HcclCommConfig* config, bool isP2P = false)
{
    HcclCommConfigInit(config);

    if (isP2P) {
        config->hcclBufferSize = c10_npu::option::OptionsManager::GetP2PBufferSize();
    } else {
        config->hcclBufferSize = c10_npu::option::OptionsManager::GetHcclBufferSize();
    }

    // Evaluated once; the CANN version does not change while the process runs.
    static const bool isCannVersionGteBase = []() {
        const std::string baseCannModule = "CANN";
        const std::string baseCannversion = "8.2.RC1";
        return IsGteCANNVersion(baseCannversion, baseCannModule);
    }();

    if (!isCannVersionGteBase) {
        config->hcclDeterministic = getDeterministicState() ? 1 : 0;
    } else {
        config->hcclDeterministic = 0xffffffff;
    }

    const bool commNameSupported =
        isHcclFeatureSupported(HcclCommConfigCapability::HCCL_COMM_CONFIG_COMM_NAME);
    if (!commNameSupported) {
        size_t *leadingSizeField = reinterpret_cast<size_t *>(config);
        *leadingSizeField = 32;
    }
}
// Rejects a config that sets a communicator name when the running CANN
// version does not support that capability. An empty name always passes.
void checkHcclCommConfigValid(const HcclCommConfig* config)
{
    const bool hasCommName = strlen(config->hcclCommName) > 0;
    if (!hasCommName) {
        return;
    }
    TORCH_CHECK(isHcclFeatureSupported(HcclCommConfigCapability::HCCL_COMM_CONFIG_COMM_NAME),
        "The current version of CANN does not support the hcclCommName:", config->hcclCommName,
        DIST_ERROR(ErrCode::NOT_SUPPORT));
}
// Reads the launcher-provided environment variables LOCAL_RANK, RANK,
// GROUP_RANK and LOCAL_WORLD_SIZE.
// Returns a map with:
//   "enable"           -> "true" if all four are set, otherwise "false"
//   "local_rank", "global_rank", "nodes_rank", "local_world_size"
//                      -> the raw value of each variable that is set
// A warning is emitted for every missing variable. Each warning keeps its own
// TORCH_NPU_WARN_ONCE call site on purpose.
// NOTE(review): presumably the "once" state is per call site, in which case
// merging the sites into a loop would change which warnings appear - confirm.
std::unordered_map<std::string, std::string> checkEnvVarOrLogWarning()
{
    std::unordered_map<std::string, std::string> result;
    result["enable"] = "true";

    const char* local_rank_env = getenv("LOCAL_RANK");
    if (local_rank_env != nullptr) {
        result["local_rank"] = local_rank_env;
    } else {
        result["enable"] = "false";
        TORCH_NPU_WARN_ONCE("Environment variable 'LOCAL_RANK' is not set. And HCCL_ZERO_COPY will not enable.",
            "Please try to launch the process by using torchrun or configure the 'LOCAL_RANK' environment variable.");
    }

    const char* global_rank_env = getenv("RANK");
    if (global_rank_env != nullptr) {
        result["global_rank"] = global_rank_env;
    } else {
        result["enable"] = "false";
        TORCH_NPU_WARN_ONCE("Environment variable 'RANK' is not set. And HCCL_ZERO_COPY will not enable.",
            "Please try to launch the process by using torchrun or configure the 'RANK' environment variable.");
    }

    const char* nodes_rank_env = getenv("GROUP_RANK");
    if (nodes_rank_env != nullptr) {
        result["nodes_rank"] = nodes_rank_env;
    } else {
        result["enable"] = "false";
        TORCH_NPU_WARN_ONCE("Environment variable 'GROUP_RANK' is not set. And HCCL_ZERO_COPY will not enable.",
            "Please try to launch the process by using torchrun or configure the 'GROUP_RANK' environment variable.");
    }

    const char* local_world_size_env = getenv("LOCAL_WORLD_SIZE");
    if (local_world_size_env != nullptr) {
        result["local_world_size"] = local_world_size_env;
    } else {
        result["enable"] = "false";
        TORCH_NPU_WARN_ONCE("Environment variable 'LOCAL_WORLD_SIZE' is not set. And HCCL_ZERO_COPY will not enable.",
            "Please try to launch the process by using torchrun or configure the 'LOCAL_WORLD_SIZE' environment variable.");
    }

    return result;
}
// If `split_sizes` is empty, fills it with `group_size` equal entries of
// tensor.size(0) / group_size. A non-empty `split_sizes` is left untouched.
// Raises a PARAM error when group_size is not positive or when dim 0 is not
// divisible by group_size.
void fill_equal_split_sizes_when_empty(std::vector<int64_t>& split_sizes, at::Tensor tensor, int group_size)
{
    if (!split_sizes.empty()) {
        return;
    }

    TORCH_CHECK(group_size > 0, "Invalid group size within current process group", group_size, DIST_ERROR(ErrCode::PARAM));
    const bool dividesEvenly = (tensor.size(0) % group_size == 0);
    TORCH_CHECK(
        dividesEvenly,
        "Tensor's dim 0 does not divide equally across group size",
        DIST_ERROR(ErrCode::PARAM));

    const int64_t perRank = static_cast<int64_t>(tensor.size(0) / group_size);
    // split_sizes is known to be empty here, so assign() yields exactly the
    // group_size appended entries.
    split_sizes.assign(static_cast<size_t>(group_size), perRank);
}
// Validates split sizes against a tensor's first dimension.
//  - Empty split_sizes: dim 0 must be divisible by group_size (PARAM error).
//  - Otherwise: there must be exactly group_size entries and they must sum
//    to dim 0 (TYPE error).
void check_split_sizes(const std::vector<int64_t>& split_sizes, const at::Tensor& tensor, int group_size)
{
    if (split_sizes.empty()) {
        TORCH_CHECK(tensor.size(0) % group_size == 0, "Tensor's dim 0 does not divide equally across group size",
            DIST_ERROR(ErrCode::PARAM));
        return;
    }

    const bool countMatches = (split_sizes.size() == static_cast<size_t>(group_size));
    TORCH_CHECK(
        countMatches, "Number of tensor splits not equal to group size",
        DIST_ERROR(ErrCode::TYPE));

    const auto total = c10::sum_integers(split_sizes);
    TORCH_CHECK(total == tensor.size(0), "Split sizes dosen't match total dim 0 size", DIST_ERROR(ErrCode::TYPE));
}
// Ensures `path` is usable as a writable directory: if it is not already
// writable, tries to create it (owner rwx, group r-x). When that also fails,
// throws std::runtime_error carrying `errormessage` plus a NOT_FOUND code.
void checkAndMakePath(const char* path, std::string errormessage)
{
    const bool alreadyWritable = (access(path, W_OK) == 0);
    if (alreadyWritable) {
        return;
    }

    const bool created = (mkdir(path, S_IRWXU | S_IRGRP | S_IXGRP) == 0);
    if (created) {
        return;
    }

    throw std::runtime_error(errormessage + DIST_ERROR(ErrCode::NOT_FOUND));
}
// Creates `path` if it does not exist (mode: owner rw, group r) by opening
// it for writing, then closes the descriptor again.
// Throws std::runtime_error with a NOT_FOUND code if open() fails.
void createFile(const char* path)
{
    const int fileDescriptor = open(path, O_WRONLY | O_CREAT, S_IRUSR | S_IWUSR | S_IRGRP);
    if (fileDescriptor != -1) {
        close(fileDescriptor);
        return;
    }
    throw std::runtime_error("Create file failed. Please check whether input file is valid." + DIST_ERROR(ErrCode::NOT_FOUND));
}
// True when the SoC version falls into one of two half-open ranges:
// [Ascend910B1, Ascend310B1) or [Ascend910_9391, Ascend950).
// The SoC version is queried once and the result is cached.
inline bool IsCompatibleSoc()
{
    static const bool is_compatible = []() {
        const auto version = c10_npu::GetSocVersion();
        const bool inFirstRange = (version >= c10_npu::SocVersion::Ascend910B1) &&
            (version < c10_npu::SocVersion::Ascend310B1);
        if (inFirstRange) {
            return true;
        }
        const bool inSecondRange = (version >= c10_npu::SocVersion::Ascend910_9391) &&
            (version < c10_npu::SocVersion::Ascend950);
        return inSecondRange;
    }();
    return is_compatible;
}
}
// Sleep interval, in milliseconds, of the polling loop used by
// WorkHCCL::synchronizeInternal when blocking wait is enabled.
constexpr int64_t kSynchronizeBusyWaitMillis = 1;
constexpr int64_t maxOpNumPerSyncPoint = 2;
// Out-of-class definitions of ProcessGroupHCCL static data members.
// NOTE(review): the meaning of most of these is defined by code outside this
// part of the file; only the initial values are established here.
const int64_t ProcessGroupHCCL::kProcessGroupHCCLOpTimeoutMillis = 10 * 1000;
// One counter per thread (thread_local).
thread_local uint64_t ProcessGroupHCCL::hcclActiveGroupCounter_ = 0;
const int64_t ProcessGroupHCCL::kWatchdogThreadSleepMillis = 1000;
std::string ProcessGroupHCCL::perfdumppath = "";
ProcessGroupHCCL* ProcessGroupHCCL::global_ = nullptr;
std::unordered_map<std::string, ProcessGroupHCCL::StatusStruct> ProcessGroupHCCL::StatusOutput_;
int ProcessGroupHCCL::deviceId_ = -1;
// Set to the group size by the constructor when the options carry no
// explicit global_ranks_in_group.
int ProcessGroupHCCL::numRanks_ = -1;
std::string ProcessGroupHCCL::exceptionMessage_ = "";
std::atomic<bool> ProcessGroupHCCL::shouldDump_(false);
// Read by the WorkHCCL constructor to decide whether start events are created.
std::atomic<bool> ProcessGroupHCCL::monitorThreadEnabled_(false);
// Dumps the process-wide HCCL trace buffer and returns the serialized result.
// The three flags are forwarded unchanged to HCCLTraceBuffer::dump.
std::string dump_hccl_trace(
    bool includeCollectives,
    bool includeStackTraces,
    bool onlyActive)
{
    auto* traceBuffer = HCCLTraceBuffer::get();
    return traceBuffer->dump(c10::nullopt, includeCollectives, includeStackTraces, onlyActive);
}
// JSON variant of dump_hccl_trace: forwards both flags to
// HCCLTraceBuffer::dump_json and returns the resulting string.
std::string dump_hccl_trace_json(bool includeCollectives, bool onlyActive)
{
    auto* traceBuffer = HCCLTraceBuffer::get();
    return traceBuffer->dump_json(c10::nullopt, includeCollectives, onlyActive);
}
// Control-plane handler "dump_hccl_trace_pickle".
// Accepted query parameters (each "true" or "false"):
//   includecollectives (default true), includestacktraces (default true),
//   onlyactive (default false).
// Responds 400 on an invalid value or on any unrecognised parameter;
// otherwise returns the trace dump as application/octet-stream.
static c10d::control_plane::RegisterHandler dumpHcclHandler{
    "dump_hccl_trace_pickle",
    [](const c10d::control_plane::Request& req, c10d::control_plane::Response& res) {
        const std::string includeCollectivesStr = "includecollectives";
        const std::string includeStackTracesStr = "includestacktraces";
        const std::string onlyActiveStr = "onlyactive";

        // Known flags with their defaults; overwritten by request parameters.
        std::unordered_map<std::string, bool> flags = {
            {includeCollectivesStr, true},
            {includeStackTracesStr, true},
            {onlyActiveStr, false}};

        const auto& params = req.params();
        size_t recognizedCount = 0;
        for (const auto& [paramName, paramValue] : params) {
            const auto flagIt = flags.find(paramName);
            if (flagIt == flags.end()) {
                continue;
            }
            ++recognizedCount;

            const bool isTrue = (paramValue == "true");
            if (!isTrue && paramValue != "false") {
                res.setStatus(400);
                res.setContent(
                    "Invalid value for " + paramName +
                        " valid values are true or false",
                    "text/plain");
                return;
            }
            flagIt->second = isTrue;
        }

        // Any parameter that was not recognised makes the request invalid.
        if (recognizedCount < params.size()) {
            res.setStatus(400);
            res.setContent(
                "Invalid parameters - unexpected param passed in", "text/plain");
            return;
        }

        res.setContent(
            dump_hccl_trace(
                flags[includeCollectivesStr],
                flags[includeStackTracesStr],
                flags[onlyActiveStr]),
            "application/octet-stream");
    }};
// Control-plane handler "dump_hccl_trace_json".
// Accepted query parameters (each "true" or "false"):
//   includecollectives (default true), onlyactive (default false).
// Responds 400 on an invalid value or on any unrecognised parameter;
// otherwise responds 200 with the JSON trace dump.
static c10d::control_plane::RegisterHandler jsonDumpHcclHandler{
    "dump_hccl_trace_json",
    [](const c10d::control_plane::Request& req, c10d::control_plane::Response& res) {
        const std::string includeCollectivesStr = "includecollectives";
        const std::string onlyActiveStr = "onlyactive";

        // Known flags with their defaults; overwritten by request parameters.
        std::unordered_map<std::string, bool> flags = {
            {includeCollectivesStr, true}, {onlyActiveStr, false}};

        const auto& params = req.params();
        size_t recognizedCount = 0;
        for (const auto& [paramName, paramValue] : params) {
            const auto flagIt = flags.find(paramName);
            if (flagIt == flags.end()) {
                continue;
            }
            ++recognizedCount;

            const bool isTrue = (paramValue == "true");
            if (!isTrue && paramValue != "false") {
                res.setStatus(400);
                res.setContent(
                    "Invalid value for " + paramName +
                        " valid values are true or false",
                    "text/plain");
                return;
            }
            flagIt->second = isTrue;
        }

        // Any parameter that was not recognised makes the request invalid.
        if (recognizedCount < params.size()) {
            res.setStatus(400);
            res.setContent(
                "Invalid parameters - unexpected param passed in", "text/plain");
            return;
        }

        res.setStatus(200);
        res.setContent(
            dump_hccl_trace_json(
                flags[includeCollectivesStr],
                flags[onlyActiveStr]),
            "application/json");
    }};
// Accessor for the process-wide, initially empty, optional C++ trace dumper.
// Returns a mutable reference so callers can install or replace the dumper.
c10::optional<std::function<void(std::function<void(const std::string &)>)>> &get_cpp_trace_dumper()
{
    using LineSink = std::function<void(const std::string &)>;
    using TraceDumper = std::function<void(LineSink)>;
    static c10::optional<TraceDumper> registeredDumper(c10::nullopt);
    return registeredDumper;
}
// Accessor for the process-wide GIL checker callback (null until a caller
// assigns one through the returned reference).
gil_checker_t &get_gil_checker()
{
    static gil_checker_t registeredChecker = nullptr;
    return registeredChecker;
}
// Runs the registered GIL checker on a detached thread and returns a future
// for its result. An exception thrown by the checker is delivered through
// the future. Fails immediately if no GIL checker has been registered.
std::future<bool> launchAsyncGilCheck()
{
    std::promise<bool> checkPromise;
    std::future<bool> checkFuture = checkPromise.get_future();
    TORCH_CHECK(get_gil_checker(), "Can't check GIL with null GIL checker");

    // The promise is moved into the worker so it outlives this call.
    std::thread checker([promise = std::move(checkPromise)]() mutable {
        try {
            auto& gil_checker = get_gil_checker();
            const bool checkResult = (*gil_checker)();
            promise.set_value(checkResult);
        } catch (...) {
            promise.set_exception(std::current_exception());
        }
    });
    checker.detach();
    return checkFuture;
}
// Writes a one-line summary of a work object: sequence number, op type,
// input/output element counts and the timeout in milliseconds.
std::ostream& operator<<(std::ostream& output, const ProcessGroupHCCL::WorkHCCL& workHCCL)
{
    // Build the text first, then write it in a single insertion.
    const std::string summary = c10::str(
        "WorkHCCL(",
        "SeqNum=", workHCCL.seq_,
        ", OpType=", opTypeToString(workHCCL.opType_),
        ", NumelIn=", workHCCL.numelIn_,
        ", NumelOut=", workHCCL.numelOut_,
        ", Timeout(ms)=", workHCCL.opTimeout_.count(),
        ")");
    output << summary;
    return output;
}
// Returns the first known device-error keyword that occurs in `error_msg`,
// checking the keywords in the order listed below, or an empty string when
// none of them occurs.
std::string get_device_error(const std::string& error_msg)
{
    static const std::vector<std::string> device_errors = {
        "UCE ERROR",
        "HBM MULTI BIT ECC ERROR",
        "SUSPECT MEM ERROR",
        "HCCS LINK ERROR",
        "UB LINK ERROR",
        "HCCL OP RETRY FAILED",
        "SUSPECT REMOTE ERROR"
    };

    const auto match = std::find_if(device_errors.begin(), device_errors.end(),
        [&error_msg](const std::string& keyword) {
            return error_msg.find(keyword) != std::string::npos;
        });
    if (match == device_errors.end()) {
        return "";
    }
    return *match;
}
// Appends all of `tensors` to the shelf while holding the shelf mutex.
void TensorShelf::stash(std::vector<at::Tensor>& tensors)
{
    const std::lock_guard<std::mutex> guard(mutex_);
    auto appendAt = tVector_.end();
    tVector_.insert(appendAt, tensors.begin(), tensors.end());
}
// Appends the contents of another shelf to this one.
// Only this shelf's mutex is taken (inside the vector overload); `other` is
// read through get() without locking.
void TensorShelf::stash(TensorShelf& other)
{
    this->stash(other.get());
}
// Appends a single tensor to the shelf while holding the shelf mutex.
void TensorShelf::stash(const at::Tensor& tensor)
{
    const std::lock_guard<std::mutex> guard(mutex_);
    tVector_.push_back(tensor);
}
void TensorShelf::unstash()
{
this->clear();
}
// Reports, under the shelf mutex, whether nothing is stashed.
bool TensorShelf::empty()
{
    const std::lock_guard<std::mutex> guard(mutex_);
    const bool nothingStashed = tVector_.empty();
    return nothingStashed;
}
// Removes all stashed tensors while holding the shelf mutex.
void TensorShelf::clear()
{
    const std::lock_guard<std::mutex> guard(mutex_);
    tVector_.clear();
}
std::vector<at::Tensor>& TensorShelf::get()
{
return tVector_;
}
// Creates a work item for `opType` issued by `rank` on `devices`.
//  - Records the creation time as the work start time.
//  - Start events (one per device) are only created when desync debugging,
//    status saving, or the monitor thread is enabled; otherwise
//    hcclStartEvents_ stays unset.
//  - End events and communicator slots are always sized to one per device.
ProcessGroupHCCL::WorkHCCL::WorkHCCL(
    const std::vector<at::Device>& devices,
    int rank,
    c10d::OpType opType,
    uint64_t seq,
    bool desyncDebug)
    : Work(rank, opType),
      devices_(devices),
      workStartTime_(std::chrono::steady_clock::now()),
      seq_(seq)
{
    const size_t deviceCount = devices.size();

    const bool needStartEvents =
        desyncDebug || (status_save_enable) || ProcessGroupHCCL::monitorThreadEnabled_.load();
    if (needStartEvents) {
        hcclStartEvents_ = std::make_shared<std::vector<c10_npu::NPUEvent>>();
        hcclStartEvents_->reserve(deviceCount);
        for (size_t created = 0; created < deviceCount; ++created) {
            hcclStartEvents_->emplace_back(ACL_EVENT_CAPTURE_STREAM_PROGRESS);
        }
    }

    hcclEndEvents_ = std::make_shared<std::vector<c10_npu::NPUEvent>>(deviceCount);
    hcclComms_.resize(deviceCount);
}
// Copy constructor.
// Copies the listed members from `w`. The start/end event vectors are held
// through shared pointers (see the primary constructor), so the copy refers
// to the same event objects as the source rather than to new ones.
// Only the members named below are copied; any other member is left
// default-initialized.
ProcessGroupHCCL::WorkHCCL::WorkHCCL(const WorkHCCL& w)
: Work(w.rank_, w.opType_),
std::enable_shared_from_this<WorkHCCL>(w),
devices_(w.devices_),
hcclStartEvents_(w.hcclStartEvents_),
hcclComms_(w.hcclComms_),
hcclEndEvents_(w.hcclEndEvents_),
blockingWait_(w.blockingWait_),
opTimeout_(w.opTimeout_),
ownedEphermeralTimeout_(w.ownedEphermeralTimeout_),
workStartTime_(w.workStartTime_),
seq_(w.seq_),
startTraceUpdated_(w.startTraceUpdated_),
numelIn_(w.numelIn_),
numelOut_(w.numelOut_),
store_(w.store_),
is_dispatched(w.is_dispatched),
is_reported(w.is_reported),
is_dumped(w.is_dumped),
trace_id_(w.trace_id_)
{
// The Work base is constructed from rank/op type only, so the stored
// exception is carried over explicitly here.
exception_ = w.exception_;
}
// Nothing to release by hand; members clean up through their own destructors.
ProcessGroupHCCL::WorkHCCL::~WorkHCCL() = default;
// Returns true once the work has either failed (an exception is recorded)
// or finished executing on the NPU. Refreshes the recorded exception first.
bool ProcessGroupHCCL::WorkHCCL::isCompleted()
{
    checkAndSetException();
    if (exception()) {
        return true;
    }
    return finishedNPUExecutionInternal();
}
// Returns true once the work has either failed (an exception is recorded)
// or started executing on the NPU. Refreshes the recorded exception first.
// `errorHandling` is forwarded to the start-event query.
bool ProcessGroupHCCL::WorkHCCL::isStarted(ErrorHandlingMode errorHandling)
{
    checkAndSetException();
    if (exception()) {
        return true;
    }
    return startedNPUExecutionInternal(errorHandling);
}
// True only when no exception is recorded, the communicators report no
// HCCL error, and the work has finished executing on the NPU.
bool ProcessGroupHCCL::WorkHCCL::isSuccess() const
{
    if (exception()) {
        return false;
    }
    if (checkForHCCLErrors(hcclComms_)) {
        return false;
    }
    return finishedNPUExecutionInternal();
}
// If no exception is recorded yet, polls the communicators for an HCCL
// error and stores the result (possibly null) as the work's exception.
// A newly found error is logged.
void ProcessGroupHCCL::WorkHCCL::checkAndSetException()
{
    // Keep an exception that is already recorded.
    if (exception()) {
        return;
    }

    // Query the communicators before taking the lock.
    auto detected = checkForHCCLErrors(hcclComms_);

    std::unique_lock<std::mutex> lock(mutex_);
    exception_ = detected;
    if (!exception_) {
        return;
    }
    TORCH_NPU_HCCL_LOGE("[Rank %d], found async exception when checking for HCCL errors: %s", rank_,
        getExceptionMsgFromExceptionPtr(exception_).c_str());
}
void ProcessGroupHCCL::WorkHCCL::setException(std::exception_ptr exception_ptr)
{
std::unique_lock<std::mutex> lock(mutex_);
exception_ = exception_ptr;
}
// Refreshes the recorded exception, then reports whether all end events
// have completed.
bool ProcessGroupHCCL::WorkHCCL::finishedNPUExecution()
{
    checkAndSetException();
    const bool finished = finishedNPUExecutionInternal();
    return finished;
}
// Reports whether every per-device start event has completed.
//
// Returns false when:
//  - no start events were created for this work (the constructor only
//    creates them when desync debug, status saving or the monitor thread is
//    enabled), so "started" cannot be determined;
//  - at least one start event has not completed yet;
//  - the event query failed with a known device error or with "FORCE STOP"
//    (the error is remembered in device_error_msg / force_stop_error_flag).
//
// For other query failures: unless the message contains
// "driver shutting down", the failure is logged once and, if `errorHandling`
// requests tear-down, rethrown as std::runtime_error. Otherwise the failure
// is logged at INFO level and the work is treated as started.
bool ProcessGroupHCCL::WorkHCCL::startedNPUExecutionInternal(ErrorHandlingMode errorHandling) const
{
    // Without this guard the loop below would dereference a null pointer for
    // work items constructed without start events.
    if (!hcclStartEvents_) {
        return false;
    }
    try {
        for (const auto i : c10::irange(devices_.size())) {
            if (!(*hcclStartEvents_)[i].query()) {
                return false;
            }
        }
    } catch (const std::exception& e) {
        std::string exceptionMsg = std::string(e.what());
        std::string device_error = get_device_error(exceptionMsg);
        if (!device_error.empty()) {
            TORCH_NPU_HCCL_LOGI("Find %s when startedNPUExecutionInternal.", device_error.c_str());
            device_error_msg = device_error;
            return false;
        }
        if (exceptionMsg.find("FORCE STOP") != std::string::npos) {
            TORCH_NPU_HCCL_LOGI("Find FORCE STOP when startedNPUExecutionInternal.");
            force_stop_error_flag = true;
            return false;
        }
        if (exceptionMsg.find("driver shutting down") == std::string::npos) {
            // Log the unexpected failure only once to avoid flooding the log
            // when this is polled repeatedly.
            std::call_once(print_flag, [&exceptionMsg]() {
                TORCH_NPU_HCCL_LOGE("Find exception when startedNPUExecutionInternal, %s.", exceptionMsg.c_str());
            });
            if (SHOULD_TEAR_DOWN(errorHandling)) {
                throw std::runtime_error(DIST_ERROR(ErrCode::INTERNAL));
            }
        }
        LOG(INFO) << "[Rank " << rank_ << "] Event query failed with exception: " << e.what();
    }
    return true;
}
// Reports whether every per-device end event has completed.
//
// Returns false when the NPU system controller is not initialized, when at
// least one end event has not completed, or when the event query failed with
// a known device error or with "FORCE STOP" (the error is remembered in
// device_error_msg / force_stop_error_flag).
// Any other query failure is rethrown as std::runtime_error, except for
// messages containing "driver shutting down", which are logged at INFO level
// and treated as finished.
bool ProcessGroupHCCL::WorkHCCL::finishedNPUExecutionInternal() const
{
if (!c10_npu::NpuSysCtrl::GetInstance().GetInitFlag()) {
return false;
}
try {
for (const auto i : c10::irange(devices_.size())) {
if (!(*hcclEndEvents_)[i].query()) {
return false;
}
}
} catch (const std::exception& e) {
// Classify the failure by keywords in the exception text. The order of
// the checks matters: device errors first, then FORCE STOP.
std::string exceptionMsg = std::string(e.what());
std::string device_error = get_device_error(exceptionMsg);
if (!device_error.empty()) {
TORCH_NPU_HCCL_LOGI("Find %s when finishedNPUExecutionInternal.", device_error.c_str());
device_error_msg = device_error;
return false;
}
if (exceptionMsg.find("FORCE STOP") != std::string::npos) {
TORCH_NPU_HCCL_LOGI("Find FORCE STOP when finishedNPUExecutionInternal.");
force_stop_error_flag = true;
return false;
}
if (exceptionMsg.find("driver shutting down") == std::string::npos) {
TORCH_NPU_HCCL_LOGE("Find exception when finishedNPUExecutionInternal, %s.", exceptionMsg.c_str());
throw std::runtime_error(DIST_ERROR(ErrCode::INTERNAL));
}
// Only reached for "driver shutting down": fall through and report the
// work as finished.
LOG(INFO) << "[Rank " << rank_ << "] Event query failed with exception: " << e.what();
}
return true;
}
// Checks whether the work has been running longer than its timeout.
// `timeout` overrides the work's own opTimeout_ when provided.
// Returns false while still within the timeout. Once timed out, returns
// true; if no exception was recorded yet, a std::runtime_error describing
// the timeout is logged and recorded as the work's exception.
bool ProcessGroupHCCL::WorkHCCL::checkTimeout(c10::optional<std::chrono::milliseconds> timeout)
{
    const auto now = std::chrono::steady_clock::now();
    const auto timeElapsed = std::chrono::duration_cast<std::chrono::milliseconds>(now - workStartTime_);
    const auto limit = timeout ? *timeout : opTimeout_;
    if (timeElapsed < limit) {
        return false;
    }

    // Timed out, but an earlier failure is already recorded: keep that one.
    if (exception()) {
        return true;
    }

    const std::string exceptionMsg = c10::str(
        "[Rank ",
        rank_,
        "] ",
        "Watchdog caught collective operation timeout: ",
        *this,
        " ran for ",
        timeElapsed.count(),
        " milliseconds before timing out.");
    LOG(ERROR) << exceptionMsg;
    setException(std::make_exception_ptr(std::runtime_error(exceptionMsg)));
    return true;
}
// Computes the dispatch timeout used by WorkHCCL::checkDispatch.
// Default is 600 s. When the configured HCCL exec timeout (seconds) lies
// strictly between 150 and 630, the dispatch timeout becomes that value
// minus 30 s. The chosen value is logged and returned in milliseconds.
std::chrono::milliseconds GetDispatchTimeout() noexcept
{
    constexpr uint32_t kDefaultSeconds = 600U;
    constexpr uint32_t kOffsetSeconds = 30U;
    constexpr uint32_t kMinSeconds = 120U;

    uint32_t timeoutSeconds = kDefaultSeconds;
    const int32_t hccl_exec_timeout = c10_npu::option::OptionsManager::GetHCCLExecTimeout();
    if (hccl_exec_timeout > 0) {
        const uint32_t execSeconds = static_cast<uint32_t>(hccl_exec_timeout);
        const bool belowUpperBound = execSeconds < kDefaultSeconds + kOffsetSeconds;
        const bool aboveLowerBound = execSeconds > kMinSeconds + kOffsetSeconds;
        if (belowUpperBound && aboveLowerBound) {
            timeoutSeconds = execSeconds - kOffsetSeconds;
        }
    }

    TORCH_NPU_HCCL_LOGI("set dispatchTimeout_ %u s.", timeoutSeconds);
    return std::chrono::milliseconds(timeoutSeconds * 1000U);
}
// Evaluated once during static initialization of this translation unit.
std::chrono::milliseconds dispatchTimeout_ = GetDispatchTimeout();
// Watches for work that has not been dispatched within dispatchTimeout_.
//  - Not dispatched and not yet reported: once the elapsed time since work
//    creation exceeds dispatchTimeout_, logs a warning and marks the work as
//    reported.
//  - Dispatched after having been reported: logs that the earlier warning
//    can be ignored and clears the reported mark.
void ProcessGroupHCCL::WorkHCCL::checkDispatch()
{
    // The sequence number is initialized from a 64-bit value; print it with a
    // matching 64-bit conversion instead of %u so the variadic argument is
    // read with the right width and large values are not truncated.
    const auto seqForLog = static_cast<unsigned long long>(seq_);
    if (!*is_dispatched && !is_reported) {
        auto currentTimepoint = std::chrono::steady_clock::now();
        auto timeElapsed = std::chrono::duration_cast<std::chrono::milliseconds>(currentTimepoint - workStartTime_);
        if (timeElapsed > dispatchTimeout_) {
            std::string repo_info = c10_npu::getRepoInfo();
            TORCH_NPU_HCCL_LOGW("Process group work %s, seq_num %llu dispatch timeout. %s", opTypeToString(opType_).c_str(), seqForLog, repo_info.c_str());
            is_reported = true;
        }
    } else if (*is_dispatched && is_reported) {
        TORCH_NPU_HCCL_LOGW("Process group work %s, seq_num %llu dispatch success. This warning log can be ignored.", opTypeToString(opType_).c_str(), seqForLog);
        is_reported = false;
    }
}
// Decides whether this work has run long enough to warrant a one-time dump.
// The threshold is the HCCL exec timeout minus 60 s, but at least 60 s; a
// non-positive configured exec timeout falls back to defaultExecTimeout.
// Returns true exactly once per work item (is_dumped is set on that call)
// and false on every other call.
bool ProcessGroupHCCL::WorkHCCL::checkExec()
{
    if (is_dumped) {
        return false;
    }
    // Read the configured timeout once per process.
    static int32_t hccl_exec_timeout = c10_npu::option::OptionsManager::GetHCCLExecTimeout();
    if (hccl_exec_timeout <= 0) {
        hccl_exec_timeout = defaultExecTimeout;
    }
    const int32_t timeoutSeconds = std::max(60, hccl_exec_timeout - 60);
    auto currentTimepoint = std::chrono::steady_clock::now();
    auto timeElapsed = std::chrono::duration_cast<std::chrono::milliseconds>(currentTimepoint - workStartTime_);
    // Compare against a seconds duration instead of computing
    // "timeoutSeconds * 1000" in 32-bit arithmetic, which overflows for large
    // exec timeouts; <chrono> converts using its wider representation.
    if (timeElapsed > std::chrono::seconds(timeoutSeconds)) {
        is_dumped = true;
        return true;
    }
    return false;
}
// Synchronizes with no explicit timeout, then unregisters this work from
// the in-flight work registry when that feature is enabled.
void ProcessGroupHCCL::WorkHCCL::synchronize()
{
    synchronizeInternal(kNoTimeout);

    if (!c10d::allow_inflight_collective_as_graph_input()) {
        return;
    }
    c10d::unregister_work(
        c10::intrusive_ptr<
            ProcessGroupHCCL::WorkHCCL>::unsafe_reclaim_from_nonowning(this));
}
// Reacts to a recorded exception: logs an error and, when `errorHandling`
// requests tear-down, logs a second message and rethrows the exception.
// Does nothing when no exception is recorded.
void ProcessGroupHCCL::WorkHCCL::handleException(ErrorHandlingMode errorHandling)
{
    if (!exception_) {
        return;
    }

    const auto failureMsg = c10::str(
        "Some HCCL operations have failed or timed out. Due to the ",
        "asynchronous nature of ASCEND kernels, subsequent NPU operations ",
        "might run on corrupted/incomplete data.");
    LOG(ERROR) << failureMsg;
    C10_LOG_API_USAGE_ONCE("ProcessGroupHCCL.WorkHCCL.handleException");

    if (SHOULD_TEAR_DOWN(errorHandling)) {
        const auto tearDownMsg = c10::str(
            "To avoid data inconsistency, we are taking the entire process down.");
        LOG(ERROR) << tearDownMsg;
        std::rethrow_exception(exception_);
    }
}
// Refreshes the recorded exception and rethrows it if one is present.
void ProcessGroupHCCL::WorkHCCL::checkAndThrowException()
{
    checkAndSetException();
    if (!exception()) {
        return;
    }
    std::rethrow_exception(exception());
}
// Makes the current streams wait for this work and releases the bookkeeping
// tied to it.
//
// Steps, in order:
//  1. For every device, block the device's current stream on the work's end
//     event; if barrier tensors are present, additionally synchronize the
//     whole device.
//  2. Erase the stream records that were registered with the caching
//     allocator for recorded inputs and outputs.
//  3. Drop tensors kept alive for the selected multi-stream memory reuse mode.
//  4. With blocking wait enabled, poll until the work completes or times out
//     (`timeout`, or the work's own timeout when kNoTimeout is passed). If an
//     exception is recorded afterwards, abort the communicators and handle
//     the exception in TearDown mode.
//  5. Run the op post-hook when op hooks are enabled.
void ProcessGroupHCCL::WorkHCCL::synchronizeInternal(std::chrono::milliseconds timeout)
{
for (const auto i : c10::irange(devices_.size())) {
auto currentStream = c10_npu::getCurrentNPUStream(devices_[i].index());
(*hcclEndEvents_)[i].block(currentStream);
TORCH_NPU_HCCL_LOGI("Event: block hccl work is successfully executed, event=%p", (*hcclEndEvents_)[i].event());
// A barrier additionally waits for the whole device to go idle.
if (!barrierTensors_.empty()) {
c10_npu::NPUGuard npuGuard(devices_[i]);
c10_npu::npuSynchronizeDevice();
}
}
if (!recorded_inputs_.empty()) {
auto multi_stream_memory_reuse_mode = c10_npu::option::OptionsManager::GetMultiStreamMemoryReuse();
for (auto i = 0; i < recorded_inputs_.size(); ++i) {
// The storage is tracked through a weak reference (lock()); it may
// already be gone.
auto storage = recorded_inputs_[i].first.lock();
if (storage) {
c10_npu::NPUCachingAllocator::eraseStream(storage->data_ptr(), recorded_inputs_[i].second);
} else if (multi_stream_memory_reuse_mode == c10_npu::option::ERASE_RECORD_STREAM_WITH_OPTIMIZE) {
// Storage already released: fall back to the block pointer that is
// stored at the same index in recorded_block_ptr_for_inputs_.
c10_npu::NPUCachingAllocator::eraseStreamWithBlockPtr(recorded_block_ptr_for_inputs_[i], recorded_inputs_[i].second, static_cast<void*>(this));
}
}
}
if (!recorded_outputs_.empty()) {
for (auto it = recorded_outputs_.begin(); it != recorded_outputs_.end(); ++it) {
auto storage = it->first.lock();
if (storage) {
c10_npu::NPUCachingAllocator::eraseStream(storage->data_ptr(), it->second);
}
}
}
if (c10_npu::option::OptionsManager::GetMultiStreamMemoryReuse() == c10_npu::option::ERASE_RECORD_STREAM) {
lazy_destroy_tensors_.clear();
} else if (c10_npu::option::OptionsManager::GetMultiStreamMemoryReuse() == c10_npu::option::AVOID_RECORD_STREAM) {
stashed_for_allocator_safety_.clear();
}
if (blockingWait_) {
// Poll for completion, sleeping kSynchronizeBusyWaitMillis between checks.
while (!isCompleted()) {
bool timedOut = checkTimeout(timeout == kNoTimeout ? c10::nullopt : c10::make_optional(timeout));
if (timedOut) {
std::string exceptionMsg = c10::str(
"[Rank ",
rank_,
"] Work ",
(*this),
" timed out in blocking wait "
"(TORCH_HCCL_BLOCKING_WAIT=1 or HCCL_BLOCKING_WAIT=1).");
LOG(ERROR) << exceptionMsg;
break;
}
std::this_thread::sleep_for(std::chrono::milliseconds(kSynchronizeBusyWaitMillis));
}
if (exception()) {
// Unqualified call resolves to WorkHCCL::abort() (communicator
// teardown), not ::abort().
abort();
handleException(TearDown);
}
}
if (C10_UNLIKELY(at_npu::native::env::CheckOpHookEnable())) {
at_npu::native::OpHook::GetInstance().PostHook();
}
}
// Keeps `tensors` alive on this work item by appending them to
// lazy_destroy_tensors_ (which synchronizeInternal clears in the same mode).
// No-op for an empty list or when the multi-stream memory reuse mode is not
// ERASE_RECORD_STREAM.
void ProcessGroupHCCL::WorkHCCL::lazyDestroy(std::vector<at::Tensor> tensors)
{
    if (tensors.empty()) {
        return;
    }
    const bool eraseRecordStreamMode =
        c10_npu::option::OptionsManager::GetMultiStreamMemoryReuse() == c10_npu::option::ERASE_RECORD_STREAM;
    if (!eraseRecordStreamMode) {
        return;
    }

    TORCH_NPU_HCCL_LOGD("Lazy destroy tensors, tensors size: %zu", tensors.size());
    for (const auto& tensor : tensors) {
        lazy_destroy_tensors_.push_back(tensor);
    }
}
// Synchronizes with the given timeout, unregisters this work from the
// in-flight work registry when that feature is enabled, and returns true.
bool ProcessGroupHCCL::WorkHCCL::wait(std::chrono::milliseconds timeout)
{
    synchronizeInternal(timeout);

    const bool trackInflightWork = c10d::allow_inflight_collective_as_graph_input();
    if (trackInflightWork) {
        c10d::unregister_work(
            c10::intrusive_ptr<
                ProcessGroupHCCL::WorkHCCL>::unsafe_reclaim_from_nonowning(this));
    }
    return true;
}
// Destroys every HCCL communicator attached to this work item.
void ProcessGroupHCCL::WorkHCCL::abort()
{
    for (const auto& hcclComm : hcclComms_) {
        // The constructor sizes hcclComms_ with default-constructed entries,
        // so a slot can still be unset when abort() runs; skip such slots
        // instead of dereferencing them.
        if (hcclComm) {
            hcclComm->destroyHcclComm();
        }
    }
}
// Returns the future associated with this work item (may be unset if no
// future was attached).
c10::intrusive_ptr<c10::ivalue::Future> ProcessGroupHCCL::WorkHCCL::getFuture()
{
    return this->future_;
}
// Returns a copy of the output tensor list.
// outputs_ is dereferenced without a check, so it must have been set.
std::vector<at::Tensor> ProcessGroupHCCL::WorkHCCL::result()
{
    const auto& outputTensors = *outputs_;
    return outputTensors;
}
// Process-wide counter; each ProcessGroupHCCL takes the current value as its
// uid_ and post-increments it (see the constructor's initializer list).
static std::atomic<size_t> process_group_id = 0;
// Constructs an HCCL process group for `rank` out of `size` ranks on top of `store`.
//
// Work done here, in order:
//   1. copy `options`, name the group, and (once per process) derive and apply
//      the ACL op-wait timeout from HCCL_EVENT_TIMEOUT / HCCL_EXEC_TIMEOUT;
//   2. read the TORCH_HCCL_* configuration variables;
//   3. unwrap PrefixStore layers to find the global store and the TCPStore
//      host/port;
//   4. reconcile blocking-wait, async error handling and desync-debug settings;
//   5. create the watchdog and, under ENABLE_HCCL_ERROR_CHECKING, adjust the
//      watchdog timeout and start the watchdog thread;
//   6. for a group with no explicit global_ranks_in_group, register it as
//      global_ and optionally prepare the HCCL zero-copy feature.
ProcessGroupHCCL::ProcessGroupHCCL(
    const c10::intrusive_ptr<c10d::Store>& store,
    int rank,
    int size,
    c10::intrusive_ptr<Options> options)
    : c10d::Backend(rank, size),
      store_(store),
      options_(c10::make_intrusive<Options>(*options.get())),
      hcclCommCounter_(0),
      traceKeyStart_("HCCL_" + std::to_string(rank) + "_trace_start"),
      traceKeyEnd_("HCCL_" + std::to_string(rank) + "_trace_end"),
      terminateProcessGroup_(false),
      terminateHeartbeatMonitorThread_(false),
      collectiveDebugInfoMode_(false),
      uid_(process_group_id++)
{
    std::string groupName = "group_name_" + options->group_id;
    this->setGroupName(groupName);
    int32_t hccl_event_timeout = c10_npu::option::OptionsManager::GetHCCLEventTimeout();
    int32_t hccl_exec_timeout = c10_npu::option::OptionsManager::GetHCCLExecTimeout();
    // A negative exec timeout falls back to the built-in default.
    if (hccl_exec_timeout < 0) {
        hccl_exec_timeout = defaultExecTimeout;
    }
    // The op-wait timeout is applied once per process, by the first group constructed.
    std::call_once(kOpWaitTimeoutInitFlag, [&]() {
        if (hccl_event_timeout > 0) {
            kOpWaitTimeout = static_cast<uint32_t>(hccl_event_timeout);
            if (hccl_event_timeout <= hccl_exec_timeout) {
                TORCH_NPU_WARN_ONCE("The value of HCCL_EVENT_TIMEOUT:", hccl_event_timeout, " is less than or equal to the value of HCCL_EXEC_TIMEOUT:", hccl_exec_timeout, ".");
            } else if (hccl_exec_timeout == 0) {
                TORCH_NPU_WARN_ONCE("The value of HCCL_EXEC_TIMEOUT was set to 0(never timeout), so it is bigger than the value of HCCL_EVENT_TIMEOUT:", hccl_event_timeout, ".");
            }
        } else if (hccl_event_timeout == 0) {
            kOpWaitTimeout = 0;
        } else {
            // Event timeout is negative here: derive the wait timeout from the exec timeout.
            if (hccl_exec_timeout == 0) {
                kOpWaitTimeout = 0;
            } else {
                kOpWaitTimeout = static_cast<uint32_t>(hccl_exec_timeout) + kOpWaitTimeoutOffset;
                // Unsigned wrap-around guard: saturate instead of ending up below the exec timeout.
                if (kOpWaitTimeout <= static_cast<uint32_t>(hccl_exec_timeout)) {
                    kOpWaitTimeout = UINT_MAX;
                }
            }
        }
        TORCH_NPU_HCCL_LOGI("Set op wait timeout to %u.", kOpWaitTimeout);
        NPU_CHECK_ERROR(c10_npu::acl::AclrtSetOpWaitTimeout(kOpWaitTimeout));
    });

    blockingWait_ = c10d::getCvarBool(TORCH_HCCL_BLOCKING_WAIT, false);
    logPrefix_ = createLogPrefix();
    if (options_->global_ranks_in_group.empty()) {
        numRanks_ = size_;
    }
    dumpOnException_ = c10d::getCvarBool(TORCH_HCCL_DUMP_ON_TIMEOUT, false);
    heartbeat_ = 1ULL;
    monitorThreadEnabled_.store(c10d::getCvarBool(TORCH_HCCL_ENABLE_MONITORING, false));
    heartbeatTimeoutInSec_ = c10d::getCvarInt(TORCH_HCCL_HEARTBEAT_TIMEOUT_SEC, 60 * 10);
    waitTimeoutDumpInMilSec_ = c10d::getCvarInt(TORCH_HCCL_WAIT_TIMEOUT_DUMP_MILSEC, 60 * 1000);
    coordCheckIntervalMilSec_ = c10d::getCvarInt(TORCH_HCCL_COORD_CHECK_MILSEC, 1000);
    hcclTraceBufferSize_ = c10d::getCvarInt(TORCH_HCCL_TRACE_BUFFER_SIZE, 0);

    // If the store is a PrefixStore, use the store underneath all prefixes as the global store.
    c10d::PrefixStore *prefixStore = dynamic_cast<c10d::PrefixStore *>(store_.get());
    globalStore_ = prefixStore ? prefixStore->getUnderlyingNonPrefixStore() : store_;

    // Walk down the PrefixStore chain until a TCPStore with a non-empty host is found.
    c10::intrusive_ptr<c10d::Store> getTcpStore = store_;
    while (getTcpStore) {
        c10d::PrefixStore *asPrefixStore = dynamic_cast<c10d::PrefixStore *>(getTcpStore.get());
        c10d::TCPStore *tcpStore = dynamic_cast<c10d::TCPStore *>(getTcpStore.get());
        if (tcpStore) {
            if (!(tcpStore->getHost().empty())) {
                tcpMasterAddr = tcpStore->getHost();
                tcpMasterPort = tcpStore->getPort();
                break;
            }
        }
        if (asPrefixStore) {
            getTcpStore = asPrefixStore->getUnderlyingStore();
        } else {
            break;
        }
    }

    asyncErrorHandling_ =
        static_cast<ErrorHandlingMode>(c10_npu::option::OptionsManager::CheckUseHcclAsyncErrorHandleEnable());
    desyncDebug_ = static_cast<bool>(c10_npu::option::OptionsManager::CheckUseDesyncDebugEnable());
    if (blockingWait_) {
        // Blocking wait takes precedence: async error handling and desync debug are switched off.
        if (asyncErrorHandling_ != NoHandling || desyncDebug_) {
            LOG(INFO) << "[Rank " << rank_ << "] TORCH_HCCL_BLOCKING_WAIT and "
                      << "HCCL_ASYNC_ERROR_HANDLING|HCCL_DESYNC_DEBUG"
                      << "should not both be enabled. "
                      << "Only TORCH_HCCL_BLOCKING_WAIT is being used in this process.";
            asyncErrorHandling_ = NoHandling;
            desyncDebug_ = false;
            LOG(INFO) << logPrefix()
                      << "TORCH_HCCL_BLOCKING_WAIT is enabled, NO watchdog thread is created.";
        }
    } else {
        // Desync debug is forced to run together with TearDown error handling.
        if (desyncDebug_ && asyncErrorHandling_ == NoHandling) {
            LOG(INFO) << "[Rank " << rank_
                      << "] HCCL_DESYNC_DEBUG and HCCL_ASYNC_ERROR_HANDLING "
                      << "must both be enabled. "
                      << "Enabling HCCL_ASYNC_ERROR_HANDLING.";
            asyncErrorHandling_ = TearDown;
        }
    }

    // Result is intentionally unused; globalRank() caches its value on first call.
    globalRank();
    watchdog_ = std::make_unique<Watchdog>(this);
#ifdef ENABLE_HCCL_ERROR_CHECKING
    if (asyncErrorHandling_ == TearDown) {
        // hccl_exec_timeout is a 32-bit int; scale it to milliseconds in 64-bit
        // arithmetic so large values cannot overflow the comparison or the message.
        const long long execTimeoutMs = static_cast<long long>(hccl_exec_timeout) * 1000;
        if ((options_->timeout).count() != DEFAULT_TIMEOUT) {
            // The caller chose a watchdog timeout explicitly: only warn about suspicious combinations.
            if ((options_->timeout).count() <= execTimeoutMs) {
                TORCH_NPU_WARN("The watchdog timeout ", (options_->timeout).count(), "ms(which is set by init_process_group) is less than or equal to HCCL execution timeout ",
                               execTimeoutMs, "ms! The plog may not be recorded.");
            } else if (hccl_exec_timeout == 0) {
                TORCH_NPU_WARN("The HCCL execution timeout was set to 0(never timeout), so it is bigger than watchdog timeout ",
                               (options_->timeout).count(), "ms which is set by init_process_group! The plog may not be recorded. You can disable watchdog by 'export HCCL_ASYNC_ERROR_HANDLING=0'.");
            }
        } else {
            // Default watchdog timeout: derive it from the exec timeout (plus 1800 s), or never time out.
            if (hccl_exec_timeout == 0) {
                options_->timeout = std::chrono::milliseconds(LLONG_MAX);
            } else {
                long long watchdog_timeout = (static_cast<long long>(hccl_exec_timeout) + 1800) * 1000;
                if (watchdog_timeout <= execTimeoutMs) {
                    watchdog_timeout = LLONG_MAX;
                }
                options_->timeout = std::chrono::milliseconds(watchdog_timeout);
            }
        }
    }
    watchdog_->start();
#endif
    if (options_->global_ranks_in_group.empty()) {
        global_ = this;
        if (c10_npu::option::OptionsManager::IsHcclZeroCopyEnable() && c10_npu::NPUCachingAllocator::checkConfigExpandableSegments()) {
            TORCH_NPU_HCCL_LOGI("Set the HCCL_ZERO_COPY environment variable in ExpandableSegments. Try to enable the HCCL_ZERO_COPY feature.");
            std::unordered_map<std::string, std::string> envMap = checkEnvVarOrLogWarning();
            if (envMap["enable"] == "true") {
                auto local_rank = std::stoi(envMap["local_rank"]);
                if (!c10_npu::NpuSysCtrl::GetInstance().GetInitFlag()) {
                    TORCH_NPU_HCCL_LOGW("Device is not initialized, init device %d by rank config.", local_rank);
                    c10_npu::NpuSysCtrl::SysStatus status = c10_npu::NpuSysCtrl::GetInstance().Initialize(local_rank);
                }
                int32_t device_id = -1;
                NPU_CHECK_ERROR(c10_npu::GetDevice(&device_id));
                // The rank configuration wins over whatever device is currently selected.
                if (device_id != local_rank) {
                    TORCH_NPU_HCCL_LOGW("Device is %d, set device %d by rank config.", device_id, local_rank);
                    device_id = local_rank;
                }
                NPU_CHECK_ERROR(c10_npu::SetDevice(device_id));
                std::vector<std::shared_ptr<HCCLComm>> hcclComms(1);
                createHCCLCommForZeroCopy(hcclComms, envMap);
                c10_npu::NPUCachingAllocator::buildServerMemMapForHccl(device_id, hcclComms[0]);
            } else {
                TORCH_NPU_HCCL_LOGI("Because the environment variables are not fully configured, the HCCL_ZERO_COPY feature cannot be enabled.");
            }
        } else {
            TORCH_NPU_HCCL_LOGI("The IsHcclZeroCopyEnable function return %d, the checkConfigExpandableSegments function return %d.",
                                c10_npu::option::OptionsManager::IsHcclZeroCopyEnable(), c10_npu::NPUCachingAllocator::checkConfigExpandableSegments());
        }
    }
    TORCH_NPU_HCCL_LOGI("process group created, group id is %s.", options_->group_id.c_str());
}
// Intentionally a no-op in this backend.
void ProcessGroupHCCL::setSequenceNumberForGroup()
{
}
// Returns the current value of seq_.
uint64_t ProcessGroupHCCL::getSequenceNumberForGroup()
{
    const uint64_t currentSeq = seq_;
    return currentSeq;
}
// Destroys every HCCL communicator stored in `hcclCommsMap` and logs, per map
// entry, how many communicators were destroyed.
//
// @param hcclCommsMap  device-name key -> communicators; the map itself is not
//                      modified, only the communicators are destroyed.
// @param rank          rank printed in the log line.
// @param abortReason   optional reason appended to the log line.
void abortCommsFromMap(
    std::unordered_map<std::string, std::vector<std::shared_ptr<HCCLComm>>>& hcclCommsMap,
    const int rank,
    c10::optional<std::string> abortReason)
{
    for (auto& it : hcclCommsMap) {
        auto& devName = it.first;
        auto& hcclComms = it.second;
        for (const auto& hcclComm : hcclComms) {
            hcclComm->destroyHcclComm();
        }
        // A leading space before "communicators" keeps the count from running
        // into the word (previously rendered as e.g. "Destroyed 2communicators").
        if (abortReason.has_value()) {
            LOG(INFO) << "[Rank " << rank << "] Destroyed " << hcclComms.size()
                      << " communicators on ASCEND device " << devName
                      << " for reason: " << *abortReason;
        } else {
            LOG(INFO) << "[Rank " << rank << "] Destroyed " << hcclComms.size()
                      << " communicators on ASCEND device " << devName;
        }
    }
}
// Destroys all communicators in devHCCLCommMap_ while holding mutex_.
// The map entries themselves are left in place. Always returns true.
bool ProcessGroupHCCL::abort(c10::optional<std::string> abortReason)
{
    {
        std::lock_guard<std::mutex> guard(mutex_);
        abortCommsFromMap(devHCCLCommMap_, rank_, abortReason);
    }
    return true;
}
void ProcessGroupHCCL::waitForFutureOrTimeout(
std::future<bool> &fut,
const std::chrono::milliseconds &timeOutMilSec,
const std::string &futDescription,
bool throwException)
{
std::string errorMsg;
TORCH_CHECK(fut.valid(), "Expected a valid future");
std::future_status status = fut.wait_for(timeOutMilSec);
if (status == std::future_status::ready) {
try {
bool result = fut.get();
if (result) {
LOG(INFO) << logPrefix()
<< "future is successfully executed for: " << futDescription;
}
} catch (const std::exception &e) {
errorMsg = c10::str(
logPrefix(),
"Exception thrown when waitng for future ",
futDescription,
": ",
e.what());
LOG(ERROR) << errorMsg;
} catch (...) {
errorMsg = c10::str(
logPrefix(),
"Unknown exception thrown when waitng for future ",
futDescription);
LOG(ERROR) << errorMsg;
}
} else {
errorMsg = c10::str(
logPrefix(),
"Future for ",
futDescription,
" timed out after ",
timeOutMilSec.count(),
" ms");
LOG(ERROR) << errorMsg;
}
if (throwException && !errorMsg.empty()) {
C10_THROW_ERROR(DistBackendError, errorMsg);
}
}
void ProcessGroupHCCL::shutdown()
{
LOG(INFO) << logPrefix() << "Starting to destroy process group, flushing operations.";
if (terminateProcessGroup_.exchange(true)) {
return;
}
std::vector<c10_npu::NPUStream> hcclStreamsToSync;
{
std::lock_guard<std::mutex> lock(mutex_);
for (const auto& it : hcclStreams_) {
hcclStreamsToSync.insert(
hcclStreamsToSync.end(),
it.second.begin(),
it.second.end());
}
}
if (hcclStreamsToSync.empty()) {
LOG(INFO) << logPrefix()
<< "Skip shutdown stream synchronization because no HCCL streams were created.";
} else {
for (const auto& hcclStream : hcclStreamsToSync) {
hcclStream.synchronize();
}
}
if (windowMem_.has_value()) {
std::vector<at::Device> devices = {windowMem_->device()};
auto comm = getHcclCommByDevices(devices);
if (comm && comm->getHcclComm() != nullptr) {
auto ret = hcclCommDeregister(comm->getHcclComm(), windowHandle_);
if (ret != HCCL_SUCCESS) {
TORCH_NPU_HCCL_LOGE("Call HcclCommDeregister failed.");
}
}
windowHandle_ = nullptr;
windowMem_ = c10::nullopt;
}
terminateHeartbeatMonitorThread_.store(true);
watchdog_->notify();
monitorWakeUpCV_.notify_one();
#ifdef ENABLE_HCCL_ERROR_CHECKING
watchdog_->join();
LOG(INFO) << logPrefix() << "Watchdog thread joined.";
if (hcclHeartbeatMonitorThread_.joinable()) {
hcclHeartbeatMonitorThread_.join();
LOG(INFO) << logPrefix() << "Heartbeat monitor thread joined.";
}
#endif
LOG(INFO) << logPrefix() << "Watchdog joined, destroying HCCL communicators.";
{
std::lock_guard<std::mutex> lock(mutex_);
for (auto& it : devHCCLCommMap_) {
auto& hcclComms = it.second;
for (const auto& hcclComm : hcclComms) {
hcclComm->destroyHcclComm();
}
}
devHCCLCommMap_.clear();
devHCCLCommNameMap_.clear();
p2pSendRecvKeys_.clear();
hcclCommCounter_ = 0;
}
LOG(INFO) << logPrefix() << "Destroy complete.";
}
// Deletes key "0" and every key recorded in TCPStoreKeyList_ from store_.
// Any exception aborts the remaining deletions and is reported via the error
// log; TCPStoreKeyList_ is cleared in both the success and the failure case.
void ProcessGroupHCCL::deleteTCPStoreKey()
{
    bool deleted = true;
    try {
        store_->deleteKey("0");
        for (const auto &storeKey : TCPStoreKeyList_) {
            TORCH_NPU_HCCL_LOGD("Delete TCP store key: %s", storeKey.c_str());
            store_->deleteKey(storeKey);
        }
    } catch(...) {
        deleted = false;
    }

    if (deleted) {
        TORCH_NPU_HCCL_LOGI("Delete TCP store key success.");
    } else {
        TORCH_NPU_HCCL_LOGE("Delete TCP store key failed.");
    }
    TCPStoreKeyList_.clear();
}
// Destroys all cached communicators and then forgets them: under mutex_ the
// communicator maps and the P2P key set are emptied and the communicator
// counter is reset to zero.
void ProcessGroupHCCL::abortAndClearHcclComm(c10::optional<std::string> abortReason)
{
    std::lock_guard<std::mutex> guard(mutex_);

    abortCommsFromMap(devHCCLCommMap_, rank_, abortReason);

    devHCCLCommMap_.clear();
    devHCCLCommNameMap_.clear();
    p2pSendRecvKeys_.clear();
    hcclCommCounter_ = 0;
}
// Destructor: deregisters the window memory (if any), clears global_ when this
// is the group without explicit global ranks, stops the watchdog / heartbeat
// monitor threads and destroys every cached communicator.
ProcessGroupHCCL::~ProcessGroupHCCL()
{
    LOG(INFO) << logPrefix() << "ProcessGroupHCCL destructor entered.";

    if (windowMem_.has_value()) {
        std::vector<at::Device> devices = {windowMem_->device()};
        auto comm = getHcclCommByDevices(devices);
        // Check the communicator wrapper itself before dereferencing it, the
        // same way shutdown() does; a missing communicator means there is
        // nothing to deregister.
        if (comm && comm->getHcclComm() != nullptr) {
            auto ret = hcclCommDeregister(comm->getHcclComm(), windowHandle_);
            if (ret != HCCL_SUCCESS) {
                ASCEND_LOGE("Call HcclCommDeregister failed.");
            }
        }
        windowHandle_ = nullptr;
        windowMem_ = c10::nullopt;
    }

    if (options_->global_ranks_in_group.empty()) {
        global_ = nullptr;
    }

    // Ask both background threads to stop and wake them up.
    terminateProcessGroup_.store(true);
    terminateHeartbeatMonitorThread_.store(true);
    watchdog_->notify();
    monitorWakeUpCV_.notify_one();
#ifdef ENABLE_HCCL_ERROR_CHECKING
    watchdog_->join();
    if (hcclHeartbeatMonitorThread_.joinable()) {
        hcclHeartbeatMonitorThread_.join();
        LOG(INFO) << logPrefix()
                  << "ProcessGroupHCCL heart beat monitor thread joined.";
    }
#endif
    {
        std::lock_guard<std::mutex> lock(mutex_);
        for (auto& it : devHCCLCommMap_) {
            auto& hcclComms = it.second;
            for (const auto& hcclComm : hcclComms) {
                hcclComm->destroyHcclComm();
            }
        }
        devHCCLCommMap_.clear();
        p2pSendRecvKeys_.clear();
    }
    LOG(INFO) << logPrefix() << "ProcessGroupHCCL destructor completed.";
    TORCH_NPU_HCCL_LOGI("process group destroyed, group id is %s.", options_->group_id.c_str());
}
#ifndef BUILD_LIBTORCH
// Starts dumpPythonTraceback() on a detached thread and returns a future that
// receives its result, or the exception it threw.
std::future<bool> ProcessGroupHCCL::launchAsyncPythonTracebackDump()
{
    std::promise<bool> dumpPromise;
    std::future<bool> dumpFuture = dumpPromise.get_future();

    // The promise is moved into the thread; the thread is detached so the
    // caller is never blocked on it.
    std::thread([this, p = std::move(dumpPromise)]() mutable {
        try {
            p.set_value(this->dumpPythonTraceback());
        } catch (...) {
            p.set_exception(std::current_exception());
        }
    }).detach();

    return dumpFuture;
}
// Writes the Python stack of every interpreter thread to a per-rank file.
// The path prefix is read from TORCH_HCCL_DEBUG_INFO_TEMP_FILE (default
// "/tmp/hccl_trace_rank_"); the embedded script appends "<rank>_py_traceback".
// The GIL is held (PyGILState_Ensure / PyGILState_Release) around the script.
// C++ exceptions raised while running it are logged and swallowed, so this
// function always returns true.
bool ProcessGroupHCCL::dumpPythonTraceback()
{
std::string filePath = c10d::getCvarString({"TORCH_HCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/hccl_trace_rank_");
PyGILState_STATE gil = PyGILState_Ensure();
try {
// `path` and `rank` are exposed to the script through its locals dict.
py::dict locals = py::dict("path"_a=filePath.c_str(), "rank"_a=rank_);
py::exec(R"(
import sys
import os
import traceback
import threading
from torch_npu.utils._path_manager import PathManager
try:
py_stacks = 'pid: {}\n'.format(os.getpid())
threadInfos = {}
for thread in threading.enumerate():
threadInfos[thread.ident] = thread
for thread_id, stack in sys._current_frames().items():
stack_list = traceback.format_list(traceback.extract_stack(stack))
py_stacks += 'thread {}:\n'.format(threadInfos[thread_id] if thread_id in threadInfos.keys() else thread_id)
py_stacks += ''.join(stack_list)
dump_file = '{path}{rank}_py_traceback'.format(**locals())
PathManager.check_input_file_path(dump_file)
with open(dump_file, 'w') as f:
f.write(py_stacks)
except Exception as e:
print(e);
)", py::globals(), locals);
} catch (const std::exception& e) {
LOG(ERROR) << logPrefix() << "dumpPythonTraceback error: " << e.what();
} catch (...) {
LOG(ERROR) << logPrefix() << "dumpPythonTraceback Unknown exception type.";
}
// Released on every path, since both catch blocks fall through to here.
PyGILState_Release(gil);
return true;
}
// Dumps debugging information for this process group.
//
// First launches the Python traceback dump on a separate thread and waits up
// to 3000 ms for it. Then, serialized across callers by a function-local
// mutex, writes the HCCL trace through DebugInfoWriter when
// hcclTraceBufferSize_ is positive.
//
// @return true when an HCCL trace was written, false when tracing is disabled.
bool ProcessGroupHCCL::dumpDebuggingInfo()
{
    auto fut = launchAsyncPythonTracebackDump();
    auto kGilCheckTimeout = std::chrono::milliseconds(3000);
    auto futStatus = fut.wait_for(kGilCheckTimeout);
    if (futStatus != std::future_status::ready) {
        TORCH_CHECK(
            futStatus != std::future_status::deferred,
            "Expected the future of dumpping python traceback to have been launched eagerly.");
        LOG(ERROR)
            << "Could not acquire GIL within 3000 ms when dump python traceback, possible GIL induced hang";
    } else {
        // Report success only when the dump thread actually finished in time;
        // logging it unconditionally contradicted the error line above.
        LOG(INFO) << "Could dump python traceback";
    }

    // One dump at a time per process, whichever process group triggers it.
    static std::mutex writeDebugInfoMutex;
    std::lock_guard<std::mutex> lock(writeDebugInfoMutex);
    LOG(ERROR) << logPrefix() << "ProcessGroupHCCL preparing to dump debug info.";
    if (hcclTraceBufferSize_ > 0) {
        auto hcclTrace = dump_hccl_trace(true, true, false);
        DebugInfoWriter &writer = DebugInfoWriter::getWriter(globalRank());
        LOG(ERROR) << logPrefix() << "ProcessGroupHCCL dumping hccl trace to "
                   << writer.getWriterTarget();
        writer.write(hcclTrace);
        return true;
    }
    return false;
}
// Runs dumpDebuggingInfo() asynchronously, waits for it for at most
// waitTimeoutDumpInMilSec_ milliseconds, then bumps heartbeat_ and clears
// shouldDump_.
void ProcessGroupHCCL::dumpTraceAndResetStatus()
{
    auto dumpTask = std::async(std::launch::async, [this]() { return this->dumpDebuggingInfo(); });
    const std::chrono::milliseconds dumpDeadline(waitTimeoutDumpInMilSec_);
    waitForFutureOrTimeout(dumpTask, dumpDeadline, "Flight recorder dump in heartbeatMonitor");

    heartbeat_++;
    shouldDump_.store(false);
}
#endif
void ProcessGroupHCCL::terminateProcess(std::string errMsg)
{
LOG(FATAL) << logPrefix() << errMsg;
}
// Returns the elapsed time from `start` to `end` in whole milliseconds
// (negative when `end` precedes `start`).
//
// The underlying millisecond count is wider than `int`; instead of letting the
// conversion wrap for spans longer than INT_MAX ms (about 24.8 days), the
// result saturates at INT_MAX / INT_MIN.
int computeDeltaMS(
    std::chrono::time_point<std::chrono::steady_clock> start,
    std::chrono::time_point<std::chrono::steady_clock> end)
{
    const auto deltaMs =
        std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();
    if (deltaMs > INT_MAX) {
        return INT_MAX;
    }
    if (deltaMs < INT_MIN) {
        return INT_MIN;
    }
    return static_cast<int>(deltaMs);
}
#ifndef BUILD_LIBTORCH
// Body of the heartbeat monitor thread.
//
// Each iteration sleeps on monitorWakeUpCV_ for monitorPollInterval ms and
// returns as soon as terminateHeartbeatMonitorThread_ is set. Otherwise it
//   - (checkDumpSignal only) dumps debug info when shouldDump_ is set;
//   - (checkDumpSignal only) polls globalStore_ for EXCEPTION_DUMP and, the
//     first time it is found, dumps and records that in hasGlobalDumped;
//   - every heartbeatTimeoutInSec_ seconds checks that heartbeat_ changed
//     since the previous check and flags a stuck watchdog if it did not;
//   - launches a dump when dumpPipe reports that one was requested.
//
// NOTE(review): the while (true) loop contains no break; the only exit visible
// here is the return on termination, so the statements after the loop look
// unreachable as written — confirm whether that is intended.
void ProcessGroupHCCL::heartbeatMonitor()
{
uint64_t heartBeatCounter = 0ULL;
std::string errorMsg;
std::string exitMsg;
// Dump-signal handling is limited to groups created without explicit global ranks.
bool checkDumpSignal = (dumpOnException_ && options_->global_ranks_in_group.empty());
// Poll at the store-check interval when watching for dump signals, otherwise
// only once per heartbeat timeout.
int monitorPollInterval = checkDumpSignal ? coordCheckIntervalMilSec_
: heartbeatTimeoutInSec_ * 1000;
auto lastTimePollStore = std::chrono::steady_clock::now();
auto lastTimeHeartBeatCheck = std::chrono::steady_clock::now();
c10::optional<DumpPipe> dumpPipe = c10::nullopt;
if (options_->global_ranks_in_group.empty()) {
dumpPipe.emplace(rank_);
}
while (true) {
std::unique_lock<std::mutex> lock(monitorMutex_);
// wait_for returns true only when the predicate holds, i.e. termination was requested.
if (monitorWakeUpCV_.wait_for(lock,
std::chrono::milliseconds(monitorPollInterval),
[&]{ return terminateHeartbeatMonitorThread_.load(); })) {
return;
}
auto currentTime = std::chrono::steady_clock::now();
if (checkDumpSignal) {
// Case 1: a dump was requested from within this process.
if (shouldDump_.load()) {
errorMsg = c10::str(
logPrefix(),
"Received a dump signal from this local rank and will ",
"start to dump the debug info. ",
"Last enqueued HCCL work: ",
pgStatus_->lastEnqueuedSeq,
", last completed HCCL work: ",
pgStatus_->lastCompletedSeq,
".");
exitMsg = c10::str(
"ProcessGroupHCCL's watchdog detected an exception from the local rank. ",
"This is most likely caused by incorrect usages of collectives, e.g., wrong ",
"sizes used across ranks, the order of collectives is not same for all ranks ",
"or the scheduled collective, for some reason, didn't run. Additionally, ",
"this can be caused by GIL deadlock or other reasons such as network errors or ",
"bugs in the communications library (e.g. HCCL), etc. We tried our best to ",
"dump the debug info into the storage to help you debug the issue.");
dumpTraceAndResetStatus();
}
// Case 2: poll the global store for a dump signal. Throttled by both the
// work-list update time and the store poll interval, and skipped once a
// global dump has already been performed.
if (computeDeltaMS(lastWorkListUpdateTime_, currentTime) >= kWatchdogThreadSleepMillis &&
computeDeltaMS(lastTimePollStore, currentTime) >= coordCheckIntervalMilSec_ && !hasGlobalDumped) {
lastTimePollStore = currentTime;
bool checkExceptionDump = false;
try {
checkExceptionDump =
globalStore_->check({std::string(EXCEPTION_DUMP)});
} catch (const std::exception &e) {
// A failing store check is itself treated as a reason to dump.
LOG(ERROR)
<< logPrefix()
<< "Failed to get exception dump flag from the global store."
<< e.what();
dumpTraceAndResetStatus();
}
if (checkExceptionDump) {
// Stays -1 when the rank cannot be read back from the store.
int timeOutRank = -1;
if (!shouldDump_.load()) {
LOG(ERROR)
<< logPrefix()
<< "First PG on this rank detecting the dump signal through tcpstore.";
}
shouldDump_.store(true);
try {
auto vec = globalStore_->get(std::string(EXCEPTION_DUMP));
TORCH_CHECK_WITH(
DistBackendError,
vec.size() == sizeof(int),
"Invalid size for the timeout rank ID");
// The stored value is the raw bytes of an int (size checked above).
std::memcpy(&timeOutRank, vec.data(), vec.size());
} catch (const std::exception &e) {
LOG(ERROR) << logPrefix()
<< "Failed to get timeout rank ID from the global store."
<< e.what();
}
errorMsg = c10::str(
logPrefix(),
"Received a global dump signal from rank ",
timeOutRank,
", and will start to dump the debug info. ",
"Last enqueued HCCL work: ",
pgStatus_->lastEnqueuedSeq,
", last completed HCCL work: ",
pgStatus_->lastCompletedSeq,
".");
exitMsg = c10::str(
"ProcessGroupHCCL's watchdog detected a dump signal from rank ",
timeOutRank,
" and notified the current rank. ",
"This is most likely caused by incorrect usages of collectives, e.g., wrong ",
"sizes used across ranks, the order of collectives is not same for all ranks ",
"or the scheduled collective, for some reason, didn't run. Additionally, ",
"this can be caused by GIL deadlock or other reasons such as network errors or ",
"bugs in the communications library (e.g. HCCL), etc. We tried our best to ",
"dump the debug info into the storage to help you debug the issue.");
dumpTraceAndResetStatus();
hasGlobalDumped = true;
}
}
}
// Heartbeat check: heartbeat_ must have changed since the last check.
if (computeDeltaMS(lastTimeHeartBeatCheck, currentTime) >=
heartbeatTimeoutInSec_ * 1000) {
lastTimeHeartBeatCheck = currentTime;
auto heartbeat = heartbeat_.load();
if (heartbeat != heartBeatCounter) {
heartBeatCounter = heartbeat;
} else {
// No progress since the previous check.
if (!shouldDump_.load()) {
LOG(ERROR)
<< logPrefix()
<< "First PG on this rank that detected no heartbeat of its watchdog.";
}
shouldDump_.store(true);
errorMsg = c10::str(
logPrefix(),
"Heartbeat monitor timed out! Process will be terminated after dumping debug info.",
" workMetaList_.size()=",
workMetaList_.size());
exitMsg = c10::str(
"ProcessGroupHCCL's watchdog got stuck for ",
heartbeatTimeoutInSec_,
" seconds without making progress in monitoring enqueued collectives. ",
"This typically indicates a HCCL/CUDA API hang blocking the watchdog, ",
"and could be triggered by another thread holding the GIL inside a ",
"CUDA api, or other deadlock-prone behaviors.",
"If you suspect the watchdog is not actually stuck and a longer timeout would help, ",
"you can either increase the timeout (TORCH_HCCL_HEARTBEAT_TIMEOUT_SEC) to a larger value "
"or disable the heartbeat monitor (TORCH_HCCL_ENABLE_MONITORING=0)."
"If either of aforementioned helps, feel free to file an issue to PyTorch about the short timeout "
"or false positive abort; otherwise, please attempt to debug the hang. "
"workMetaList_.size() = ",
workMetaList_.size(),
"");
if (checkDumpSignal) {
dumpTraceAndResetStatus();
}
}
}
// On-demand dump requested through the dump pipe.
// NOTE(review): `fut` comes from std::async, so its destructor at the end of
// this scope waits for the dump to finish — confirm that blocking is intended.
if (dumpPipe.has_value() && dumpPipe->shouldDump()) {
std::future<bool> fut = std::async(
std::launch::async, [this]() {
return this->dumpDebuggingInfo();
});
}
}
// Post-loop reporting and termination (see the reachability note in the header comment).
LOG(ERROR) << errorMsg;
auto &cpp_dumper = get_cpp_trace_dumper();
if (cpp_dumper.has_value()) {
LOG(INFO) << "Dumping c++ stacktraces:";
cpp_dumper.value()([](const std::string &line) {
LOG(ERROR) << line;
});
}
// Wait one more heartbeat timeout before the final decision below.
if ((terminateProcessGroup_.load() || collectiveDebugInfoMode_.load() ||
shouldDump_.load()) &&
!terminateHeartbeatMonitorThread_.load()) {
std::this_thread::sleep_for(std::chrono::seconds(heartbeatTimeoutInSec_));
}
if (!terminateHeartbeatMonitorThread_.load()) {
const auto finalExitMsg = c10::str(logPrefix(), exitMsg);
if (monitorThreadEnabled_.load()) {
terminateProcess(finalExitMsg);
} else {
LOG(ERROR)
<< "PGHCCL Monitor Thread is disabled, but would have killed this job:\n"
<< finalExitMsg;
}
}
}
#endif
// Binds the watchdog to its owning process group, caches the group's rank and
// reads the desync-debug option.
ProcessGroupHCCL::Watchdog::Watchdog(ProcessGroupHCCL *pg)
{
    this->pg_ = pg;
    this->rank_ = this->pg_->getRank();

    const auto desyncOption = c10_npu::option::OptionsManager::CheckUseDesyncDebugEnable();
    this->desyncDebug_ = static_cast<bool>(desyncOption);
}
// Wakes one waiter on the owning group's work-list condition variable.
void ProcessGroupHCCL::Watchdog::notify()
{
    auto &workListCV = pg_->workMetaListCV_;
    workListCV.notify_one();
}
// Launches the watchdog thread running Watchdog::run().
// Fails the check if a watchdog thread is already running.
void ProcessGroupHCCL::Watchdog::start()
{
    TORCH_CHECK(
        !hcclCommWatchdogThread_.joinable(), "Watchdog thread already started");

    std::thread worker(&ProcessGroupHCCL::Watchdog::run, this);
    hcclCommWatchdogThread_ = std::move(worker);

    TORCH_NPU_HCCL_LOGI("Watchdog thread started.");
}
// Joins the watchdog thread if it is joinable; otherwise does nothing.
void ProcessGroupHCCL::Watchdog::join()
{
    if (!hcclCommWatchdogThread_.joinable()) {
        return;
    }
    hcclCommWatchdogThread_.join();
    LOG(INFO) << pg_->logPrefix() << "ProcessGroupHCCL watchdog thread joined.";
}
void ProcessGroupHCCL::Watchdog::run()
{
c10_npu::SetThreadType(c10_npu::ThreadType::WATCHDOG_THREAD);
try {
LOG(INFO) << "[Rank " << rank_ << "] HCCL watchdog thread started!";
#ifndef BUILD_LIBTORCH
if (ProcessGroupHCCL::monitorThreadEnabled_.load()) {
pg_->hcclHeartbeatMonitorThread_ = std::thread(&ProcessGroupHCCL::heartbeatMonitor, pg_);
}
#endif
runLoop();
LOG(INFO) << "[Rank " << rank_ << "] HCCL watchdog thread terminated normally";
} catch (std::exception& e) {
const auto exitMsg = c10::str(
"[Rank ",
rank_,
"] HCCL watchdog thread terminated with exception: ",
e.what());
LOG(ERROR) << exitMsg;
if (status_save_enable) {
if (ProcessGroupHCCL::exceptionMessage_.empty()) {
ProcessGroupHCCL::exceptionMessage_ = e.what();
}
pg_->recordHcclStatus(status_save_path, true, true);
}
watchDogException_ = std::make_exception_ptr(std::runtime_error(exitMsg));
std::rethrow_exception(watchDogException_);
} catch (...) {
const auto exitMsg = c10::str(
"[Rank ",
rank_,
"] HCCL watchdog thread terminated with exception: unknown");
LOG(ERROR) << exitMsg;
if (status_save_enable) {
pg_->recordHcclStatus(status_save_path, true);
}
watchDogException_ = std::make_exception_ptr(std::runtime_error(exitMsg));
std::rethrow_exception(watchDogException_);
}
}
// Looks up `signal` in `store` and returns the rank stored under it.
//
// @return -1 when the key is absent or the existence check itself failed;
//         otherwise the int decoded from the stored bytes.
// @throws DistBackendError when the stored value is not exactly sizeof(int)
//         bytes (which includes the case where reading the value failed).
int ProcessGroupHCCL::Watchdog::getSignalSrcRank(
    c10::intrusive_ptr<c10d::Store>& store,
    const std::string& signal)
{
    bool present = false;
    try {
        present = store->check({signal});
    } catch (const std::exception& e) {
        LOG(WARNING) << pg_->logPrefix() << "Failed to check the signal " << signal
                     << " on TCPStore, " << e.what();
    }

    int sourceRank = -1;
    if (!present) {
        return sourceRank;
    }

    // Key exists: read the raw value; a failed read leaves `vec` empty.
    std::vector<uint8_t> vec;
    try {
        vec = store->get(std::string(signal));
    } catch (const std::exception& e) {
        LOG(ERROR) << pg_->logPrefix() << "Failed to get source rank of the signal "
                   << signal << " from TCPStore." << e.what();
    }
    TORCH_CHECK_WITH(
        DistBackendError,
        vec.size() == sizeof(int),
        "Invalid size for the timeout rank ID");
    std::memcpy(&sourceRank, vec.data(), vec.size());
    return sourceRank;
}
// Placeholder: currently performs no work.
void ProcessGroupHCCL::Watchdog::checkAndSetRemoteError()
{
}
void ProcessGroupHCCL::Watchdog::setDesyncDebug(bool desyncDebug)
{
desyncDebug_ = desyncDebug;
}
// Records the start of `work` in the store under traceKeyStart_, at most once
// per work item. Skipped when the group is terminating or a previous store
// update failed; a failed update sets storeError_.
void ProcessGroupHCCL::logWorkStart(WorkHCCL& work)
{
    if (work.startTraceUpdated_ || terminateProcessGroup_.load() || storeError_) {
        return;
    }

    work.startTraceUpdated_ = true;
    const bool updated = c10d::traceUpdate(store_, traceKeyStart_, work.seq_, opTypeToString(work.opType_));
    storeError_ = !updated;
}
// Records the end of `work` in the store under traceKeyEnd_, first recording
// its start if that has not happened yet. Skipped when the group is
// terminating or a previous store update failed; a failed update sets
// storeError_.
void ProcessGroupHCCL::logWorkEnd(WorkHCCL& work)
{
    const bool skip = terminateProcessGroup_.load() || storeError_;
    if (skip) {
        return;
    }

    if (!work.startTraceUpdated_) {
        logWorkStart(work);
    }

    const bool updated = c10d::traceUpdate(store_, traceKeyEnd_, work.seq_, opTypeToString(work.opType_));
    storeError_ = !updated;
}
std::string ProcessGroupHCCL::createLogPrefix() const
{
if (!pg_desc_.empty() && pg_desc_ != "undefined") {
return c10::str("[PG ", pg_name_, " (", pg_desc_, ") Rank ", rank_, "] ");
}
return c10::str("[PG ", pg_name_, " Rank ", rank_, "] ");
}
// Returns the cached log prefix (logPrefix_) by reference.
const std::string &ProcessGroupHCCL::logPrefix() const
{
    const std::string &cachedPrefix = logPrefix_;
    return cachedPrefix;
}
// Returns the process-wide "global rank".
// The value lives in a function-local static initialized from rank_ on the
// very first call, whichever instance makes it; every later call, on any
// instance, returns that same value.
const int &ProcessGroupHCCL::globalRank() const
{
    static int cachedGlobalRank = rank_;
    return cachedGlobalRank;
}
const std::vector<uint32_t>& ProcessGroupHCCL::groupRanks() const
{
if (options_->global_ranks_in_group.empty()) {
static std::vector<uint32_t> globalRanks(size_);
std::iota(globalRanks.begin(), globalRanks.end(), 0);
return globalRanks;
}
return options_->global_ranks_in_group;
}
// Adds `timeout` to ephemeralTimeoutActive_ under mtxTimeoutExtension_.
void ProcessGroupHCCL::addEphemeralTimeout(const std::chrono::milliseconds& timeout)
{
    std::lock_guard<std::mutex> guard(mtxTimeoutExtension_);
    ephemeralTimeoutActive_ = ephemeralTimeoutActive_ + timeout;
}
void ProcessGroupHCCL::checkHcclComms()
{
if (asyncErrorHandling_ == NoHandling) {
return;
}
std::lock_guard<std::mutex> maplock(mutex_);
std::unordered_set<std::string> checkErrors;
for (const auto & [name, hcclComms] : devHCCLCommMap_) {
auto exception_ptr = checkForHCCLErrors(hcclComms);
if (exception_ptr) {
checkErrors.insert(name);
if (reportedErrorComms_.find(name) == reportedErrorComms_.end()) {
auto exceptionMsg = c10::str(
"[Rank",
rank_,
"] checkHcclComms found HcclComms vector ",
name,
" got ERROR via HcclGetCommAsyncError : ");
TORCH_NPU_HCCL_LOGE("[Rank %d] checkHcclComms found HcclComms vector %s got ERROR via HcclGetCommAsyncError : %s",
rank_, name.c_str(), getExceptionMsgFromExceptionPtr(exception_ptr).c_str());
LOG(ERROR) << exceptionMsg << getExceptionMsgFromExceptionPtr(exception_ptr).c_str();
C10_LOG_API_USAGE_ONCE("ProcessGroupHCCL.handleException");
reportedErrorComms_.insert(name);
ProcessGroupHCCL::exceptionMessage_ = getExceptionMsgFromExceptionPtr(exception_ptr);
if (SHOULD_TEAR_DOWN(asyncErrorHandling_)) {
TORCH_NPU_HCCL_LOGE("To avoid data inconsistency, we are taking the entire process down.");
LOG(ERROR) << "To avoid data inconsistency, we are taking the entire process down.";
std::rethrow_exception(exception_ptr);
}
}
}
}
for (auto it = reportedErrorComms_.begin(); it != reportedErrorComms_.end();) {
if (checkErrors.find(*it) == checkErrors.end()) {
TORCH_NPU_HCCL_LOGI("[Rank %d] HcclComms vector %s error status cleared/recovered.", rank_, it->c_str());
it = reportedErrorComms_.erase(it);
} else {
++it;
}
}
}
void ProcessGroupHCCL::Watchdog::runLoop()
{
bool needSetDevice = true;
std::list<ProcessGroupHCCL::WorkHCCL> completedWorkList;
auto lastrecordtime = std::chrono::steady_clock::now();
auto timenow = std::chrono::steady_clock::now();
bool recordflag = false;
int kThousandMillis = 1000;
while (!pg_->terminateProcessGroup_.load()) {
if (status_save_enable) {
checkAndMakePath(status_save_path.c_str(), "Open shared directory failed. Please check whether input path is valid.");
timenow = std::chrono::steady_clock::now();
recordflag = (std::chrono::duration_cast<std::chrono::milliseconds>(timenow - lastrecordtime).count() > (c10_npu::option::OptionsManager::GetStatusSaveInterval() * kThousandMillis));
}
{
std::unique_lock<std::mutex> lock(pg_->workMetaListMutex_);
pg_->workMetaListCV_.wait_for(lock, std::chrono::milliseconds(kWatchdogThreadSleepMillis),
[&]() -> bool { return pg_->terminateProcessGroup_.load(); });
if (pg_->watchdogStatus == WatchdogStatus::STOP) {
continue;
}
pg_->checkHcclComms();
for (auto it = pg_->workMetaList_.begin(); it != pg_->workMetaList_.end();
) {
auto& work = *it;
try {
if (needSetDevice) {
c10::DeviceIndex device = static_cast<int>(work.devices_[0].index());
c10_npu::SetThreadAffinity(device);
NPU_CHECK_ERROR(c10_npu::SetDevice(device));
ProcessGroupHCCL::deviceId_ = static_cast<int>(work.devices_[0].index());
needSetDevice = false;
}
} catch (const std::exception& e) {
std::string exceptionMsg = std::string(e.what());
std::string device_error = get_device_error(exceptionMsg);
if (!device_error.empty()) {
TORCH_NPU_HCCL_LOGI("Find %s when runloop setDevice.", device_error.c_str());
device_error_msg = device_error;
}
if (exceptionMsg.find("FORCE STOP") == std::string::npos) {
force_stop_error_flag = true;
TORCH_NPU_HCCL_LOGI("Find FORCE STOP when runloop setDevice.");
}
}
if (!pg_->terminateProcessGroup_.load()) {
work.checkAndSetException();
}
work.checkDispatch();
bool exec_timeout = work.checkExec();
if (pg_->dumpOnException_ && exec_timeout) {
try {
auto rank = pg_->globalRank();
auto vec = std::vector<uint8_t>(
reinterpret_cast<uint8_t *>(&rank),
reinterpret_cast<uint8_t *>(&rank) + sizeof(rank));
pg_->globalStore_->set(std::string(EXCEPTION_DUMP), vec);
if (!ProcessGroupHCCL::shouldDump_.load()) {
LOG(ERROR) << pg_->logPrefix()
<< "First watchdog exec timeout to set the dump signal.";
}
ProcessGroupHCCL::shouldDump_.store(true);
} catch (const std::exception &e) {
LOG(ERROR) << pg_->logPrefix()
<< "Failed to set exec timeout dump signal in tcpstore. "
<< "Error: " << e.what();
}
}
bool timedOut = work.checkTimeout();
if (work.exception()) {
if (pg_->dumpOnException_) {
try {
auto rank = pg_->globalRank();
auto vec = std::vector<uint8_t>(
reinterpret_cast<uint8_t *>(&rank),
reinterpret_cast<uint8_t *>(&rank) + sizeof(rank));
pg_->globalStore_->set(std::string(EXCEPTION_DUMP), vec);
if (!ProcessGroupHCCL::shouldDump_.load()) {
LOG(ERROR) << pg_->logPrefix()
<< "First watchdog to set the dump signal.";
}
ProcessGroupHCCL::shouldDump_.store(true);
std::this_thread::sleep_for(
std::chrono::seconds(pg_->heartbeatTimeoutInSec_));
} catch (const std::exception &e) {
LOG(ERROR) << pg_->logPrefix()
<< "Failed to set dump signal in tcpstore. "
<< "Error: " << e.what();
}
}
if (desyncDebug_ && timedOut) {
try {
auto desyncMsg = retrieveDesyncReport(pg_->store_, "HCCL", pg_->getRank(), pg_->getSize());
LOG(ERROR) << desyncMsg;
} catch (const std::exception& e) {
LOG(ERROR) << "Failed to retrieve HCCL_DESYNC_DEBUG report. "
<< " Please file an issue. Error: " << e.what();
} catch (...) {
LOG(ERROR) << "Failed to rerieve HCCL_DESYNC_DEBUG report with unknown error."
<< " Please file an issue.";
}
}
work.handleException(pg_->asyncErrorHandling_);
}
if (pg_->desyncDebug_) {
if (work.isStarted(pg_->asyncErrorHandling_)) {
pg_->logWorkStart(work);
}
if (work.isCompleted()) {
pg_->logWorkEnd(work);
}
}
if (ProcessGroupHCCL::monitorThreadEnabled_.load() && pg_->pgStatus_->lastStartedSeq < static_cast<int64_t>(work.seq_) &&
work.isStarted(pg_->asyncErrorHandling_)) {
pg_->pgStatus_->lastStartedSeq = static_cast<int64_t>(work.seq_);
pg_->pgStatus_->lastStartedWorkName = opTypeToString(work.opType_);
pg_->pgStatus_->lastStartedNumelIn = work.numelIn_;
pg_->pgStatus_->lastStartedNumelOut = work.numelOut_;
}
if (work.isCompleted()) {
if (*(work.is_dispatched) && work.is_reported) {
TORCH_NPU_HCCL_LOGW("Process group work %s, seq_num %u dispatch success. This warning log can be ignored.", opTypeToString(work.opType_).c_str(), work.seq_);
work.is_reported = false;
}
{
std::lock_guard<std::mutex> timeoutLock(pg_->mtxTimeoutExtension_);
if (work.ownedEphermeralTimeout_.count() > 0) {
pg_->ephemeralTimeoutActive_ -= work.ownedEphermeralTimeout_;
pg_->ephemeralTimeoutInflight_ -= work.ownedEphermeralTimeout_;
}
}
if (status_save_enable && !work.exception()) {
pg_->is_refreshed = pg_->refreshStatusInfo(work, "end");
}
pg_->pgStatus_->lastCompletedSeq = static_cast<int64_t>(work.seq_);
pg_->pgStatus_->lastCompletedWorkName = opTypeToString(work.opType_);
pg_->pgStatus_->lastCompletedNumelIn = work.numelIn_;
pg_->pgStatus_->lastCompletedNumelOut = work.numelOut_;
HCCLTraceBuffer::get()->retire_id(work.trace_id_, true);
it = pg_->workMetaList_.erase(it);
c10_npu::NPUGraph::dec_pending_event_queries();
} else {
if (status_save_enable && work.isStarted(pg_->asyncErrorHandling_)) {
pg_->is_refreshed = pg_->refreshStatusInfo(work, "start");
}
++it;
}
pg_->heartbeat_++;
}
}
if (status_save_enable && pg_->is_refreshed) {
pg_->updateStatusOutput();
}
if (recordflag && pg_->recordHcclStatus(status_save_path)) {
lastrecordtime = std::chrono::steady_clock::now();
}
}
if (status_save_enable) {
pg_->recordHcclStatus(status_save_path);
}
if (pg_->terminateProcessGroup_.load()) {
if (status_save_enable) {
pg_->recordHcclStatus(status_save_path, true);
}
std::unique_lock<std::mutex> lock(pg_->workMetaListMutex_);
pg_->workMetaList_.clear();
}
}
// Work-level entry point for asynchronous HCCL error detection.
// Forwards hcclComms unchanged to checkForHCCLErrorsInternal and returns
// whatever exception_ptr that helper produces.
std::exception_ptr ProcessGroupHCCL::WorkHCCL::checkForHCCLErrors(
const std::vector<std::shared_ptr<HCCLComm>>& hcclComms) const
{
return checkForHCCLErrorsInternal(hcclComms);
}
// Process-group-level entry point for asynchronous HCCL error detection.
// Forwards hcclComms unchanged to checkForHCCLErrorsInternal and returns
// whatever exception_ptr that helper produces.
std::exception_ptr ProcessGroupHCCL::checkForHCCLErrors(
const std::vector<std::shared_ptr<HCCLComm>>& hcclComms)
{
return checkForHCCLErrorsInternal(hcclComms);
}
std::exception_ptr ProcessGroupHCCL::checkForHCCLErrorsInternal(
const std::vector<std::shared_ptr<HCCLComm>>& hcclComms)
{
for (const auto& hcclComm : hcclComms) {
HcclResult hcclAsyncErr = hcclComm->checkForHcclError();
if (hcclAsyncErr != HCCL_SUCCESS) {
auto errmsg = c10_npu::c10_npu_get_error_message();
return std::make_exception_ptr(std::runtime_error(errmsg ? errmsg : ""));
}
}
return nullptr;
}
// Distributes the HCCL root info through the c10d key-value store.
//
// hcclID        - on the publishing rank: the root info to publish;
//                 on every other rank: destination buffer that receives it.
// isSingleP2POp - when true the store key is devicesKey, otherwise it is the
//                 current value of hcclCommCounter_ (which is then incremented).
// devicesKey    - store key used for single P2P operations.
// p2pRank       - rank inside the P2P pair; P2P rank 0 publishes.
//
// The publisher (rank_ == 0, or P2P rank 0 of a single P2P op) also stores a
// version blob (PYTORCH_NPU_VERSION when defined, followed by __DATE__) under
// "version_key". Receivers compare that blob with their own and only warn on
// mismatch. Any failure while fetching the root info is rethrown as a
// std::runtime_error carrying the rank and the store key.
void ProcessGroupHCCL::broadcastMasterID(
    HcclRootInfo* hcclID,
    bool isSingleP2POp,
    const std::string& devicesKey,
    int p2pRank)
{
    const std::string storeKey = isSingleP2POp ? devicesKey : std::to_string(hcclCommCounter_++);

    // Build this process's version fingerprint.
    const std::string versionKey = "version_key";
    const char* buildDate = __DATE__ != nullptr ? __DATE__ : "";
    std::vector<uint8_t> localVersion;
#ifdef PYTORCH_NPU_VERSION
    const char* pluginVersion = PYTORCH_NPU_VERSION != nullptr ? PYTORCH_NPU_VERSION : "";
    localVersion.insert(localVersion.end(), pluginVersion, pluginVersion + strlen(pluginVersion));
#endif
    localVersion.insert(localVersion.end(), buildDate, buildDate + strlen(buildDate));

    const bool isPublisher = (rank_ == 0) || (isSingleP2POp && p2pRank == 0);
    if (isPublisher) {
        auto* rawBegin = reinterpret_cast<uint8_t*>(hcclID);
        std::vector<uint8_t> rootInfoBytes(rawBegin, rawBegin + HCCL_ROOT_INFO_BYTES);
        store_->set(storeKey, rootInfoBytes);
        store_->set(versionKey, localVersion);
        TCPStoreKeyList_.emplace(storeKey);
        return;
    }

    // Receiver path: fetch the root info published by the root rank.
    try {
        auto fetched = store_->get(storeKey);
        TORCH_CHECK(fetched.size() == HCCL_ROOT_INFO_BYTES, DIST_ERROR(ErrCode::PARAM));
        TCPStoreKeyList_.emplace(storeKey);
        std::memcpy(hcclID, fetched.data(), fetched.size());
    } catch (const std::exception& e) {
        std::string exceptionMsg = c10::str(
            "[",
            rank_,
            "] is setting up HCCL communicator and "
            "retrieving hcclUniqueId from [0] via c10d key-value store by key '",
            storeKey,
            "', but store->get('",
            storeKey,
            "') got error: ");
        throw std::runtime_error(exceptionMsg + e.what() +
            ". This may indicate a possible application crash on rank 0 or a network set up issue." +
            DIST_ERROR(ErrCode::INTERNAL));
    } catch (...) {
        throw std::runtime_error(c10::str(
            "Unknown exception while [",
            rank_,
            "] is setting up HCCL communicator and "
            "retrieving hcclUniqueId from [0] via c10d key-value store by key '",
            storeKey,
            "'",
            ". This may indicate a possible application crash on rank 0 or a network set up issue.") +
            DIST_ERROR(ErrCode::INTERNAL));
    }

    // Version comparison is advisory only.
    auto publishedVersion = store_->get(versionKey);
    if (publishedVersion != localVersion) {
        TORCH_NPU_WARN("PTA version mismatch");
    }
}
// Appends one "<op> <dataVol> <rank>" line to the per-communicator log file
// "<MASTER_ADDR>_<commName>_<currRank>.log" under nslb_path.
//
// opName    - operation name; written in lower case.
// dataVol   - data volume text, written verbatim.
// currRank  - rank number used in both the file name and the record.
// hcclComms - communicators used to look up the communicator name.
//
// MASTER_ADDR must be set (TORCH_CHECK otherwise). When HCCL_ALGO is set and
// the log file was not writable before this call, a "HCCL_ALGO=<value>" line
// is written first. Failure to create the directory or the file is reported
// as a std::runtime_error.
void ProcessGroupHCCL::recordDataVol(std::string opName, const std::string dataVol, const int currRank,
    std::vector<std::shared_ptr<HCCLComm>>& hcclComms)
{
    TORCH_NPU_HCCL_LOGD("Record data volume for HCCL op %s, data volume %s, current rank %d.", opName.c_str(), dataVol.c_str(), currRank);
    std::ofstream logStream;
    std::string commName = getHcclCommNameWithoutInit(hcclComms);
    auto masterAddrEnv = getenv("MASTER_ADDR");
    auto hcclAlgoEnv = getenv("HCCL_ALGO");
    TORCH_CHECK(masterAddrEnv != nullptr, "Unable to fetch master IP addr, environment variable is null.", DIST_ERROR(ErrCode::NOT_FOUND));

    const std::string baseName = std::string(masterAddrEnv) + "_" + commName + "_" + std::to_string(currRank) + ".log";
    const std::string logPath = c10::str(nslb_path, "/", baseName);

    // Decide about the HCCL_ALGO header before the file is (possibly) created below.
    const bool writeAlgoHeader = (hcclAlgoEnv != nullptr) && (access(logPath.c_str(), W_OK) != 0);

    try {
        const bool dirWritable = access(nslb_path, W_OK) == 0;
        if (!dirWritable && mkdir(nslb_path, S_IRWXU | S_IRGRP | S_IXGRP) != 0) {
            throw std::exception();
        }
        const bool fileWritable = access(logPath.c_str(), W_OK) == 0;
        if (!fileWritable) {
            const int fd = open(logPath.c_str(), O_WRONLY | O_CREAT, S_IRUSR | S_IWUSR | S_IRGRP);
            if (fd == -1) {
                throw std::exception();
            }
            close(fd);
        }
        logStream.open(logPath, std::ios::app);
    } catch (std::exception&) {
        throw std::runtime_error("Open shared directory failed. Please check whether input path is valid." + DIST_ERROR(ErrCode::NOT_FOUND));
    }

    std::transform(opName.begin(), opName.end(), opName.begin(), ::tolower);
    if (writeAlgoHeader) {
        logStream << "HCCL_ALGO=" << hcclAlgoEnv << "\n";
    }
    logStream << opName << " " << dataVol << " " << std::to_string(currRank) << "\n";
    logStream.close();
}
// Updates the cached StatusInfo snapshot from the given work item.
//
// work   - work whose sequence number and op type are recorded.
// status - status label to store alongside it.
//
// Returns false (and changes nothing) when the snapshot already holds the same
// sequence number and status; returns true after refreshing it. commIds is
// filled once from options_->global_ranks_in_group (space-terminated ids) and
// falls back to "all" when that list is empty.
bool ProcessGroupHCCL::refreshStatusInfo(ProcessGroupHCCL::WorkHCCL work, std::string status)
{
    const bool alreadyRecorded = (StatusInfo.seq == work.seq_) && (StatusInfo.status == status);
    if (alreadyRecorded) {
        return false;
    }

    StatusInfo.seq = work.seq_;
    StatusInfo.pgId = options_->group_id;
    StatusInfo.opType = opTypeToString(work.opType_);

    if (StatusInfo.commIds == "") {
        for (const auto globalRank : options_->global_ranks_in_group) {
            StatusInfo.commIds += (std::to_string(globalRank) + " ");
        }
        // Still empty means no explicit rank list was configured.
        if (StatusInfo.commIds == "") {
            StatusInfo.commIds = "all";
        }
    }

    StatusInfo.status = status;
    return true;
}
// Publishes the current StatusInfo snapshot into StatusOutput_ under this
// group's id (skipped while the snapshot has no pgId yet) and clears the
// is_refreshed flag. Both steps happen while StatusMapmutex_ is held.
void ProcessGroupHCCL::updateStatusOutput()
{
    std::lock_guard<std::mutex> guard(StatusMapmutex_);
    const bool snapshotPopulated = !StatusInfo.pgId.empty();
    if (snapshotPopulated) {
        StatusOutput_[options_->group_id] = StatusInfo;
    }
    is_refreshed = false;
}
// Dumps the collected per-group status (StatusOutput_) as one JSON document
// into "<path>/torch_hccl_status-<rank>_<master>_<device>_<nranks>_<pid>_<t0>.log".
//
// path  - directory that receives the status file (created if necessary by
//         checkAndMakePath).
// end   - when true, "global_pg_end_time" is the time captured by the first
//         call made with end == true instead of the first-record time.
// error - when true, the file is written even for a group that has an explicit
//         global_ranks_in_group list.
//
// Returns true when nothing needed to be written for a group with an explicit
// rank list (and error is false) or after the file has been written; returns
// false when there is no status to write. StatusMapmutex_ is held for the
// whole call.
bool ProcessGroupHCCL::recordHcclStatus(const std::string path, bool end, bool error)
{
    TORCH_NPU_HCCL_LOGI("Record HCCL status, path %s, end %d, error %d.", path.c_str(), end, error);
    std::unique_lock<std::mutex> lock(StatusMapmutex_);
    if (!options_->global_ranks_in_group.empty() && !error) {
        return true;
    }
    if (StatusOutput_.empty()) {
        return false;
    }

    // pid and the first-record timestamp are captured once per process so the
    // file name stays stable across calls.
    static auto pid = getpid();
    static std::chrono::time_point<std::chrono::system_clock> firstrecordtime = std::chrono::system_clock::now();
    auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(firstrecordtime.time_since_epoch()).count();
    auto end_duration = duration;
    if (end) {
        static std::chrono::time_point<std::chrono::system_clock> endrecordtime = std::chrono::system_clock::now();
        end_duration = std::chrono::duration_cast<std::chrono::milliseconds>(endrecordtime.time_since_epoch()).count();
    }

    std::ofstream outfile;
    std::stringstream fileName;
    // Declared as const char*: the fallback below assigns a string literal,
    // which must not be bound to a mutable char* (ill-formed since C++11).
    static const char* master_addr = getenv("MASTER_ADDR");
    if (master_addr == nullptr) {
        master_addr = "127.0.0.1";
        TORCH_NPU_HCCL_LOGW("Unable to fetch master IP addr, environment variable is null, it will use 127.0.0.1");
    }

    int global_rank = rank_;
    if (!options_->global_ranks_in_group.empty()) {
        global_rank = static_cast<int>(options_->global_ranks_in_group[rank_]);
    }
    fileName << "torch_hccl_status-" << std::to_string(global_rank) << "_" << master_addr << "_" << std::to_string(deviceId_) << "_";
    fileName << std::to_string(numRanks_) << "_" << std::to_string(pid) << "_" << std::to_string(duration) << ".log";
    const bool isMaster = (global_rank == 0);

    std::string out_file_path = c10::str(path, "/", fileName.str());
    checkAndMakePath(path.c_str(), "Open shared directory failed. Please check whether input path is valid.");
    createFile(out_file_path.c_str());

    using json = nlohmann::json;
    json result;
    std::list<json> last_comm_ops;
    for (auto info = StatusOutput_.begin(); info != StatusOutput_.end(); info++) {
        json comm_op;
        comm_op["seq"] = info->second.seq;
        comm_op["op_type"] = info->second.opType;
        comm_op["pg_id"] = info->second.pgId;
        comm_op["comm_ids"] = info->second.commIds;
        comm_op["status"] = info->second.status;
        last_comm_ops.emplace_back(comm_op);
    }
    if (!last_comm_ops.empty()) {
        result["last_comm_op"] = last_comm_ops;
    }
    result["is_master"] = isMaster;
    result["exception_message"] = exceptionMessage_;
    result["global_pg_end_time"] = end_duration;

    // The file is rewritten from scratch on every call.
    std::string result_str = result.dump();
    outfile.open(out_file_path.c_str(), std::ios::trunc);
    outfile << result_str << std::endl;
    outfile.close();
    return true;
}
void ProcessGroupHCCL::recordComm(std::string filename, std::string opName, const int currRank, std::vector<std::shared_ptr<HCCLComm>>& hcclComms)
{
TORCH_NPU_HCCL_LOGD("Record HCCL comm, filename %s, opName %s, currRank %d.", filename.c_str(), opName.c_str(), currRank);
std::ofstream outfile;
std::string commName = getHcclCommNameWithoutInit(hcclComms);
if (isFileExists(filename)) {
try {
outfile.open(filename, std::ios::app);
} catch (std::exception& e) {
throw std::runtime_error("Open shared directory failed. Please check whether file is valid." + DIST_ERROR(ErrCode::UNAVAIL));
}
} else {
TORCH_CHECK(false, filename, " is not exist!", DIST_ERROR(ErrCode::NOT_FOUND));
}
std::transform(opName.begin(), opName.end(), opName.begin(), ::tolower);
const std::vector<uint32_t>& ranks = groupRanks();
std::stringstream ss;
for (size_t i = 0; i < ranks.size(); ++i) {
ss << ranks[i];
if (i != ranks.size() - 1) {
ss << ", ";
}
}
std::string group_ranks = ss.str();
CommStruct comm_struct {commName, opName};
if (commset.find(comm_struct) == commset.end()) {
outfile << "[COMM]:" << commName << "," << opName << "," << group_ranks << "\n";
outfile.close();
commset.insert(comm_struct);
}
}
// Returns the communicators cached under devicesKey, creating them on a miss.
//
// devicesKey - cache key; must not be empty (std::runtime_error otherwise).
// devices    - devices the communicators belong to; their indices are added
//              to usedDeviceIdxs_.
// commType, commConfig, p2pRank - forwarded to createHCCLComm on a cache miss.
//
// The cache lookup is done under mutex_; on a miss the group's ranks are
// recorded in the trace buffer before createHCCLComm is called.
std::vector<std::shared_ptr<HCCLComm>>& ProcessGroupHCCL::getHCCLComm(
    const std::string& devicesKey,
    const std::vector<at::Device>& devices,
    HcclCommType commType,
    HcclCommConfig* commConfig,
    int p2pRank)
{
    if (devicesKey.empty()) {
        throw std::runtime_error(
            "Not able to create/get the HCCL Communicator since the NPU devices are not known" +
            DIST_ERROR(ErrCode::PARAM));
    }

    for (const auto& dev : devices) {
        usedDeviceIdxs_.insert(dev.index());
    }

    {
        std::lock_guard<std::mutex> guard(mutex_);
        const auto cached = devHCCLCommMap_.find(devicesKey);
        if (cached != devHCCLCommMap_.end()) {
            return cached->second;
        }
    }

    HCCLTraceBuffer::get()->record_pg_ranks(std::make_tuple(pg_name_, pg_desc_), groupRanks());
    return createHCCLComm(devicesKey, devices, commType, commConfig, p2pRank);
}
// Fills the NSLB-DP identifiers (world rank id and job id) into
// options_->hccl_config and, when *commConfig is non-null, into that config.
//
// commConfig - address of the caller's config pointer; the pointed-to config
//              is updated only when it is non-null.
//
// The world rank id is parsed from the RANK environment variable; when RANK is
// unset the function logs and returns without changing anything. The job id is
// (tcpMasterPort << NSLB_JOBID_OFFSET) + the master's IPv4 address in host
// byte order. If tcpMasterAddr is not a dotted-quad IPv4 literal, the address
// part is 0 and a warning is logged.
void ProcessGroupHCCL::setNSLBCommConfig(HcclCommConfig** commConfig)
{
    const char* envPtr = std::getenv("RANK");
    if (envPtr == nullptr) {
        TORCH_NPU_HCCL_LOGI("Failed to get env info for NSLB-DP.");
        return;
    }
    uint32_t worldRankID = std::stoi(std::string(envPtr));
    options_->hccl_config["hccl_world_rank_id"] = worldRankID;

    uint32_t masterPort = tcpMasterPort;
    // Zero-initialised: inet_pton leaves the destination untouched on failure,
    // so without this the job id would be built from indeterminate bytes.
    struct sockaddr_in sa{};
    std::string master_addr = tcpMasterAddr;
    if (inet_pton(AF_INET, master_addr.c_str(), &(sa.sin_addr)) != 1) {
        TORCH_NPU_HCCL_LOGW("Master addr %s is not a valid IPv4 address, NSLB-DP job id is built without the master IP.",
            master_addr.c_str());
    }
    uint32_t masterIp = ntohl(sa.sin_addr.s_addr);

    uint64_t jobID = masterPort;
    jobID = (jobID << NSLB_JOBID_OFFSET);
    jobID += masterIp;
    options_->hccl_config["hccl_job_id"] = jobID;

    if ((*commConfig) != nullptr) {
        (*commConfig)->hcclWorldRankID = worldRankID;
        (*commConfig)->hcclJobID = jobID;
    }
}
// Picks the NPU stream HCCL should use on the given device.
//
// Without an "hccl_buffer_name" entry in options_->hccl_config the result is
// simply getNPUStreamByCurrentType(device.index()). With one, the stream
// previously registered for (buffer name, device index) is reused; if none is
// registered yet, a new stream is obtained, registered under that buffer name
// and returned.
c10_npu::NPUStream ProcessGroupHCCL::getHcclNPUStream(const at::Device &device)
{
    auto bufferEntry = options_->hccl_config.find("hccl_buffer_name");
    if (bufferEntry == options_->hccl_config.end()) {
        return getNPUStreamByCurrentType(device.index());
    }

    auto bufferName = std::get<std::string>(bufferEntry->second);
    auto sharedStream = getHcclStreamByBufferName(bufferName, device.index());
    if (!sharedStream) {
        // First user of this buffer name on this device: allocate and register.
        auto freshStream = getNPUStreamByCurrentType(device.index());
        auto registered = setHcclStreamByBufferName(bufferName, device.index(), freshStream);
        TORCH_NPU_HCCL_LOGD("HCCL use alloc new stream with bufferName = %s, device_index = %d, stream id = %lu. result = %d", bufferName.c_str(), device.index(), freshStream.id(), registered);
        return freshStream;
    }

    TORCH_NPU_HCCL_LOGD("HCCL use the same stream with bufferName = %s, device_index = %d, stream id = %lu", bufferName.c_str(), device.index(), sharedStream->id());
    return sharedStream.value();
}
// Creates communicators through the root-info based interface
// (HCCLComm::create_config); used when no faster creation path succeeded.
//
// devicesKey - key identifying the device set; also the store key for a
//              single P2P operation.
// devices    - devices to create one communicator each for.
// commType   - DEFAULT or P2P; any other value throws std::runtime_error.
// commConfig - optional caller-supplied config (DEFAULT type only); when null
//              a config is built from the group's options.
// hcclComms  - output; slot i receives the communicator for devices[i]. Must
//              already be sized to devices.size().
// streamVal  - output; one HCCL stream is appended per device.
// p2pRank    - rank inside the P2P pair (P2P type only).
//
// The publishing rank obtains the root info via hcclGetRootInfo and shares it
// through broadcastMasterID before the communicators are created.
void ProcessGroupHCCL::createHCCLCommOrigin(
    const std::string& devicesKey,
    const std::vector<at::Device>& devices,
    HcclCommType commType,
    HcclCommConfig* commConfig,
    std::vector<std::shared_ptr<HCCLComm>> &hcclComms,
    std::vector<c10_npu::NPUStream> &streamVal,
    int p2pRank)
{
    HcclRootInfo hcclID;
    const bool isSingleP2POp = (commType == HcclCommType::P2P);
    if (rank_ == 0 || (isSingleP2POp && p2pRank == 0)) {
        HCCL_CHECK_ERROR(hcclGetRootInfo(&hcclID));
    }
    broadcastMasterID(&hcclID, isSingleP2POp, devicesKey, p2pRank);

    c10_npu::OptionalNPUGuard npuGuard;
    auto startTime = std::chrono::steady_clock::now();
    for (size_t i = 0; i < devices.size(); ++i) {
        int numRanks = getSize();
        int rank = getRank() * static_cast<int>(devices.size()) + static_cast<int>(i);
        HcclCommConfig config;
        // NSLB identifiers are only attached for the global group.
        if (options_->global_ranks_in_group.empty()) {
            setNSLBCommConfig(&commConfig);
        }
        npuGuard.set_index(devices[i].index());
        switch (commType) {
            case HcclCommType::DEFAULT:
                if (commConfig != nullptr) {
                    checkHcclCommConfigValid(commConfig);
                    hcclComms[i] = HCCLComm::create_config(numRanks, rank, hcclID, commConfig);
                } else {
                    config = createHcclCommConfigWithOptions();
                    hcclComms[i] = HCCLComm::create_config(numRanks, rank, hcclID, &config);
                }
                hcclComms[i]->hcclCommType = static_cast<int>(HcclCommType::DEFAULT);
                break;
            case HcclCommType::P2P:
                // A P2P communicator always spans exactly the two peers.
                numRanks = 2;
                rank = p2pRank;
                getHcclCommConfig(&config, true);
                hcclComms[i] = HCCLComm::create_config(numRanks, rank, hcclID, &config);
                hcclComms[i]->hcclCommType = static_cast<int>(HcclCommType::P2P);
                hcclComms[i]->p2pPeer = getP2pPeer();
                break;
            default:
                throw std::runtime_error(
                    "create/get the HCCL Communicator failed for comm type:" +
                    std::to_string(static_cast<int>(commType)) + DIST_ERROR(ErrCode::PARAM));
        }
        streamVal.push_back(getHcclNPUStream(devices[i]));
    }
    auto endTime = std::chrono::steady_clock::now();
    auto timeElapsed = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime);
    // count() is a 64-bit value; narrow it explicitly so it matches the %d specifier.
    TORCH_NPU_HCCL_LOGI("Create hccl comm by hcclCommInitRootInfoConfig success, group id is %s, commType is %d, use %d ms.",
        options_->group_id.c_str(), static_cast<int>(commType), static_cast<int>(timeElapsed.count()));
}
// Tries to create the global communicator(s) from a rank table file via
// HCCLComm::createGlobalHcclComm.
//
// devicesKey, p2pRank - accepted for signature parity with the other creation
//                       paths; not used here.
// devices    - devices to create one communicator each for.
// commType   - must be DEFAULT, together with an empty global_ranks_in_group,
//              for this path to apply.
// commConfig - optional caller-supplied config; when null a config is built
//              from the group's options.
// hcclComms  - output; slot i receives the communicator for devices[i].
// streamVal  - output; one HCCL stream is appended per created communicator.
//
// Returns true when all communicators were created. Returns false, so the
// caller can fall back to another interface, when the rank table file is
// missing/unreadable, the underlying API is unavailable, the request is for a
// sub/P2P communicator, or creation fails.
bool ProcessGroupHCCL::createHCCLCommEx(
    const std::string& devicesKey,
    const std::vector<at::Device>& devices,
    HcclCommType commType,
    HcclCommConfig* commConfig,
    std::vector<std::shared_ptr<HCCLComm>> &hcclComms,
    std::vector<c10_npu::NPUStream> &streamVal,
    int p2pRank)
{
    std::string rankTableFile = c10_npu::option::OptionsManager::GetRankTableFilePath();
    if (rankTableFile.empty() || !checkFilePathReadable(rankTableFile)) {
        TORCH_NPU_HCCL_LOGI("The rank_table_file is not available, switch to original interface.");
        return false;
    }
    if (c10_npu::option::OptionsManager::GetHCCLConnectTimeout() < 300) {
        TORCH_NPU_WARN_ONCE("When creating an HCCL process group using the RANK_TABLE_FILE method, the connection may time out. ",
            "It is recommended to set the timeout duration of HCCL_CONNECT_TIMEOUT to 300 seconds or more.");
    }
    if (!hcclCommInitClusterInfoConfigExist()) {
        TORCH_NPU_HCCL_LOGI("The hcclCommInitClusterInfoConfig is not exist, switch to original interface.");
        return false;
    }
    c10_npu::OptionalNPUGuard npuGuard;
    if (!(options_->global_ranks_in_group.empty() && commType == HcclCommType::DEFAULT)) {
        TORCH_NPU_HCCL_LOGI("createHCCLCommEx only handles global comm with ranktable, skip for sub/P2P comm.");
        return false;
    }

    // Storage for the default config must outlive the loop body: commConfig is
    // pointed at it on the first iteration and dereferenced again on later ones.
    HcclCommConfig defaultConfig;
    auto startTime = std::chrono::steady_clock::now();
    for (size_t i = 0; i < devices.size(); ++i) {
        int rank = getRank() * static_cast<int>(devices.size()) + static_cast<int>(i);
        npuGuard.set_index(devices[i].index());
        if (commConfig == nullptr) {
            defaultConfig = createHcclCommConfigWithOptions();
            commConfig = &defaultConfig;
        }
        auto comm = HCCLComm::createGlobalHcclComm(rankTableFile.c_str(), rank, commConfig);
        if (comm == nullptr) {
            TORCH_NPU_HCCL_LOGI("Create global hccl comm with ranktable failed, switch to original interface.");
            return false;
        }
        hcclComms[i] = comm;
        streamVal.push_back(getHcclNPUStream(devices[i]));
    }
    auto endTime = std::chrono::steady_clock::now();
    auto timeElapsed = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime);
    TORCH_NPU_HCCL_LOGI("Create global hccl comm with ranktable success, take %d milliseconds", static_cast<int>(timeElapsed.count()));
    return true;
}
// Tries to derive sub-group or P2P communicators from the global process
// group's communicator via HCCLComm::createSubHcclComm.
//
// devicesKey - key identifying the device set; hashed to form the sub-comm id
//              for P2P communicators.
// devices    - devices to create one communicator each for.
// commType   - DEFAULT (sub group) or P2P.
// commConfig - optional caller-supplied config; when null a config is built
//              from the group's options (with the P2P buffer size for P2P).
// hcclComms  - output; slot i receives the communicator for devices[i].
// streamVal  - output; one HCCL stream is appended per created communicator.
// p2pRank    - rank inside the P2P pair (P2P type only).
//
// Returns true when all communicators were created. Returns false, so the
// caller can fall back to another interface, when the sub-comm API or the
// global process group is unavailable, the global communicator cannot be
// obtained, or creation fails.
bool ProcessGroupHCCL::createHCCLCommSub(
    const std::string& devicesKey,
    const std::vector<at::Device>& devices,
    HcclCommType commType,
    HcclCommConfig* commConfig,
    std::vector<std::shared_ptr<HCCLComm>> &hcclComms,
    std::vector<c10_npu::NPUStream> &streamVal,
    int p2pRank)
{
    if (!hcclCreateSubCommConfigExist()) {
        TORCH_NPU_HCCL_LOGI("The hcclCreateSubCommConfig is not exist, switch to original interface.");
        return false;
    }
    if (global_ == nullptr) {
        TORCH_NPU_HCCL_LOGI("The global process group is not exist, switch to original interface.");
        return false;
    }
    std::shared_ptr<HCCLComm> globalHcclComm = nullptr;
    try {
        globalHcclComm = global_->getHcclCommByDevices(devices);
    } catch (const std::exception& e) {
        TORCH_NPU_HCCL_LOGI("Get global HCCL communicator failed: %s, switch to original interface.", e.what());
        return false;
    }
    if (!globalHcclComm) {
        TORCH_NPU_HCCL_LOGI("Create sub hccl comm failed, globalHcclComm is nullptr, switch to original interface.");
        return false;
    }

    c10_npu::OptionalNPUGuard npuGuard;
    uint64_t hcclid = (std::hash<string>{}(options_->group_id));
    // The default config lives outside the loop so that commConfig never points
    // at an object whose lifetime ended with a previous iteration.
    const bool useDefaultConfig = (commConfig == nullptr);
    HcclCommConfig config;
    auto subStartTime = std::chrono::steady_clock::now();
    for (size_t i = 0; i < devices.size(); ++i) {
        int numRanks = getSize();
        int rank = getRank() * static_cast<int>(devices.size()) + static_cast<int>(i);
        npuGuard.set_index(devices[i].index());
        if (useDefaultConfig) {
            config = createHcclCommConfigWithOptions();
            if (commType == HcclCommType::P2P) {
                // A P2P communicator always spans exactly the two peers.
                numRanks = 2;
                rank = p2pRank;
                config.hcclBufferSize = c10_npu::option::OptionsManager::GetP2PBufferSize();
            }
            commConfig = &config;
        }
        std::shared_ptr<HCCLComm> subComm = nullptr;
        if (commType == HcclCommType::P2P) {
            uint32_t peer = static_cast<uint32_t>(getP2pPeer());
            uint32_t lowRank = rank_ < peer ? rank_ : peer;
            uint32_t highRank = rank_ < peer ? peer : rank_;
            std::vector<uint32_t> p2pRanks;
            if (options_->global_ranks_in_group.empty()) {
                p2pRanks = {lowRank, highRank};
            } else {
                // Translate group-local ranks into global ranks.
                TORCH_CHECK(highRank < options_->global_ranks_in_group.size(), "p2p rank id must be smaller than group size", DIST_ERROR(ErrCode::PARAM));
                p2pRanks = {options_->global_ranks_in_group[lowRank], options_->global_ranks_in_group[highRank]};
            }
            hcclid = (std::hash<string>{}(devicesKey));
            std::string p2pName = "group" + options_->group_id + "_p2p_" + std::to_string(lowRank) + "_" + std::to_string(highRank);
            // Names are only replaced when the config already carries one.
            if (strlen(commConfig->hcclCommName) > 0) {
#ifndef BUILD_LIBTORCH
                torch_npu::toolkit::profiler::Utils::safe_strcpy_s(commConfig->hcclCommName, p2pName.c_str(), COMM_NAME_MAX_LENGTH);
#else
                strncpy(commConfig->hcclCommName, p2pName.c_str(), COMM_NAME_MAX_LENGTH - 1);
                commConfig->hcclCommName[COMM_NAME_MAX_LENGTH - 1] = '\0';
#endif
            }
            if (strlen(commConfig->hcclUdi) > 0) {
#ifndef BUILD_LIBTORCH
                torch_npu::toolkit::profiler::Utils::safe_strcpy_s(commConfig->hcclUdi, p2pName.c_str(), UDI_MAX_LENGTH);
#else
                // Bound the copy by the UDI buffer's own length, as the branch above does.
                strncpy(commConfig->hcclUdi, p2pName.c_str(), UDI_MAX_LENGTH - 1);
                commConfig->hcclUdi[UDI_MAX_LENGTH - 1] = '\0';
#endif
            }
            subComm = HCCLComm::createSubHcclComm(globalHcclComm, numRanks, p2pRanks.data(), hcclid, rank, commConfig);
        } else {
            subComm = HCCLComm::createSubHcclComm(globalHcclComm, numRanks, options_->global_ranks_in_group.data(), hcclid, rank, commConfig);
        }
        if (subComm == nullptr) {
            TORCH_NPU_HCCL_LOGI("Create sub hccl comm by hcclCreateSubCommConfig failed, group id is %s, subCommId is %llu, devicesKey is %s, switch to original interface.",
                options_->group_id.c_str(), hcclid, devicesKey.c_str());
            return false;
        }
        hcclComms[i] = subComm;
        if (commType == HcclCommType::P2P) {
            hcclComms[i]->p2pPeer = getP2pPeer();
        }
        streamVal.push_back(getHcclNPUStream(devices[i]));
    }
    auto subEndTime = std::chrono::steady_clock::now();
    auto subTimeElapsed = std::chrono::duration_cast<std::chrono::milliseconds>(subEndTime - subStartTime);
    TORCH_NPU_HCCL_LOGI("Create sub hccl comm by hcclCreateSubCommConfig success, group id is %s, subCommId is %llu, devicesKey is %s, use %d ms.",
        options_->group_id.c_str(), hcclid, devicesKey.c_str(), static_cast<int>(subTimeElapsed.count()));
    return true;
}
// Creates the node-local communicator used for HCCL zero copy and stores it in
// hcclComms[0].
//
// hcclComms - output; must have at least one slot.
// envMap    - string settings read by key: "global_rank" (logging only),
//             "local_rank", "nodes_rank" and "local_world_size".
//
// The process whose "local_rank" is "0" obtains the root info and publishes it
// in the store under "hccl_zero_copy_<nodes_rank>_<hcclCommCounter_>"; every
// other process fetches it from there. Failures while fetching are rethrown as
// std::runtime_error carrying the rank and the store key.
void ProcessGroupHCCL::createHCCLCommForZeroCopy(
    std::vector<std::shared_ptr<HCCLComm>> &hcclComms,
    std::unordered_map<std::string, std::string> &envMap)
{
    TORCH_NPU_HCCL_LOGI("Rank %s create process group HCCL communicator for hccl zero copy", envMap["global_rank"].c_str());

    const std::string localRootRank = "0";
    const bool isLocalRoot = (envMap["local_rank"] == localRootRank);

    HcclRootInfo hcclID;
    if (isLocalRoot) {
        HCCL_CHECK_ERROR(hcclGetRootInfo(&hcclID));
    }

    const std::string storeKey = "hccl_zero_copy_" + envMap["nodes_rank"] + "_" + std::to_string(hcclCommCounter_);
    if (isLocalRoot) {
        auto* rawBegin = reinterpret_cast<uint8_t*>(&hcclID);
        std::vector<uint8_t> rootInfoBytes(rawBegin, rawBegin + HCCL_ROOT_INFO_BYTES);
        store_->set(storeKey, rootInfoBytes);
    } else {
        try {
            auto fetched = store_->get(storeKey);
            TORCH_CHECK(fetched.size() == HCCL_ROOT_INFO_BYTES, DIST_ERROR(ErrCode::PARAM));
            std::memcpy(&hcclID, fetched.data(), fetched.size());
        } catch (const std::exception& e) {
            std::string exceptionMsg = c10::str(
                "[",
                rank_,
                "] is setting up HCCL communicator and "
                "retrieving hcclUniqueId from [0] via c10d key-value store by key '",
                storeKey,
                "', but store->get('",
                storeKey,
                "') got error: ");
            throw std::runtime_error(exceptionMsg + e.what() +
                ". This may indicate a possible application crash on rank 0 or a network set up issue." +
                DIST_ERROR(ErrCode::INTERNAL));
        } catch (...) {
            throw std::runtime_error(c10::str(
                "Unknown exception while [",
                rank_,
                "] is setting up HCCL communicator and "
                "retrieving hcclUniqueId from [0] via c10d key-value store by key '",
                storeKey,
                "'",
                ". This may indicate a possible application crash on rank 0 or a network set up issue.") +
                DIST_ERROR(ErrCode::INTERNAL));
        }
    }

    const int localWorldSize = std::stoi(envMap["local_world_size"]);
    const int localRank = std::stoi(envMap["local_rank"]);
    hcclComms[0] = HCCLComm::create(localWorldSize, localRank, hcclID);
}
// Creates the communicators for devicesKey, registers the associated streams,
// events and counters, caches the communicators in devHCCLCommMap_ and returns
// a reference to the cached vector.
//
// devicesKey - cache key for the device set.
// devices    - devices to create one communicator each for.
// commType   - DEFAULT or P2P.
// commConfig - optional caller-supplied config, forwarded to the creation path.
// p2pRank    - rank inside the P2P pair (P2P type only).
//
// Creation strategy: sub-group/P2P requests try createHCCLCommSub, global
// requests try createHCCLCommEx; if that returns false, createHCCLCommOrigin
// is used. Any HCCL groups that are currently open (hcclActiveGroupCounter_)
// are ended before creation and started again afterwards.
std::vector<std::shared_ptr<HCCLComm>>& ProcessGroupHCCL::createHCCLComm(
    const std::string& devicesKey,
    const std::vector<at::Device>& devices,
    HcclCommType commType,
    HcclCommConfig* commConfig,
    int p2pRank)
{
    std::vector<std::shared_ptr<HCCLComm>> hcclComms;
    hcclComms.resize(devices.size());
    std::vector<c10_npu::NPUStream> streamVal;
    streamVal.reserve(devices.size());
    // The enum is cast explicitly so the argument matches the %d specifier.
    TORCH_NPU_HCCL_LOGI("Create HCCL comm, devicesKey %s, commType %d, p2pRank %d.", devicesKey.c_str(), static_cast<int>(commType), p2pRank);

    // Close every open HCCL group before creating communicators.
    for (const auto i : c10::irange(hcclActiveGroupCounter_)) {
        (void)i;
        auto hccl_call = [this]() -> HcclResult {
            return hcclGroupEnd();
        };
        at_npu::native::OpCommand::RunOpApiV3("hcclGroupEnd", hccl_call);
    }

    bool isSubComm = !(options_->global_ranks_in_group.empty() && commType == HcclCommType::DEFAULT);
    bool created = false;
    if (isSubComm) {
        created = createHCCLCommSub(devicesKey, devices, commType, commConfig, hcclComms, streamVal, p2pRank);
        if (!created) {
            TORCH_NPU_HCCL_LOGI("Sub comm derivation failed, fallback to original interface.");
        }
    } else {
        created = createHCCLCommEx(devicesKey, devices, commType, commConfig, hcclComms, streamVal, p2pRank);
    }
    if (!created) {
        // A fast path that failed partway may already have appended streams;
        // drop them so the fallback's streams line up with `devices` by index.
        streamVal.clear();
        createHCCLCommOrigin(devicesKey, devices, commType, commConfig, hcclComms, streamVal, p2pRank);
    }

    // Re-open the HCCL groups that were closed above.
    for (const auto i : c10::irange(hcclActiveGroupCounter_)) {
        (void)i;
        auto hccl_call = [this]() -> HcclResult {
            return hcclGroupStart();
        };
        at_npu::native::OpCommand::RunOpApiV3("hcclGroupStart", hccl_call);
    }

    hcclStreams_.emplace(devicesKey, std::move(streamVal));
    hcclEvents_.emplace(std::piecewise_construct, std::make_tuple(devicesKey), std::make_tuple(devices.size()));
    rateCtrlEvents_.emplace(std::piecewise_construct, std::make_tuple(devicesKey), std::make_tuple(devices.size()));
    collectiveCnts_.emplace(std::piecewise_construct, std::make_tuple(devicesKey), std::make_tuple(devices.size()));

    std::lock_guard<std::mutex> lock(mutex_);
    devHCCLCommMap_.emplace(devicesKey, std::move(hcclComms));
    if (commType == HcclCommType::P2P) {
        // Remember every P2P key created by this rank.
        auto iter = p2pSendRecvKeys_.find(rank_);
        if (iter == p2pSendRecvKeys_.end()) {
            p2pSendRecvKeys_.emplace(rank_, std::vector<std::string>{devicesKey});
        } else {
            iter->second.push_back(devicesKey);
        }
    }
    return devHCCLCommMap_[devicesKey];
}
// Returns the id of the first HCCL stream registered for the current device,
// or for the (rank_, peer) send/recv key when p2p is requested and the P2P
// specific communicator path is active (root-info-config API present and a
// non-zero P2P buffer size).
//
// p2p  - whether the caller is asking about a point-to-point stream.
// peer - peer rank; must be >= 0 when the P2P key is used.
//
// Returns -1 when no stream is registered under the resulting key.
int64_t ProcessGroupHCCL::getStreamId(bool p2p, int peer)
{
    int currentDevice = -1;
    NPU_CHECK_ERROR(c10_npu::GetDevice(&currentDevice));
    std::vector<at::Device> devices = {at::Device(c10::DeviceType::PrivateUse1, currentDevice)};
    auto key = getKeyFromDevices(devices);

    const bool useP2PKey =
        p2p && hcclCommInitRootInfoConfigExist() && c10_npu::option::OptionsManager::GetP2PBufferSize() != 0;
    if (useP2PKey) {
        TORCH_CHECK(
            peer >= 0,
            "In p2p scenarios, the passed 'dst rank id' : ",
            peer,
            " is error, ",
            "expected value >= 0.",
            DIST_ERROR(ErrCode::PARAM));
        key = getKeySendRecv(rank_, peer);
    }

    const auto entry = hcclStreams_.find(key);
    if (entry == hcclStreams_.end() || entry->second.empty()) {
        return -1;
    }
    return entry->second[0].id();
}
// Returns the id of the collective HCCL stream for the given device.
//
// When a non-empty stream list is already registered under the device's key,
// the id of its first stream is returned. Otherwise a stream is taken from the
// pool (high priority when the group option or TORCH_HCCL_HIGH_PRIORITY asks
// for it), registered under the key via emplace, and its id is returned.
int64_t ProcessGroupHCCL::getCollStreamId(at::Device device)
{
    const auto key = getKeyFromDevice({device});

    const auto registered = hcclStreams_.find(key);
    if (registered != hcclStreams_.end() && !registered->second.empty()) {
        auto hcclStream = registered->second[0];
        return hcclStream.id();
    }

    bool force_high = c10d::getCvarBool(TORCH_HCCL_HIGH_PRIORITY, false);
    auto pooledStream = c10_npu::getStreamFromPool(
        options_->is_high_priority_stream || force_high, device.index());
    hcclStreams_.emplace(key, std::vector<c10_npu::NPUStream>{pooledStream});
    return pooledStream.id();
}
// Returns the id of the first HCCL stream used for point-to-point traffic.
//
// device     - device used to build the key for batched P2P.
// peer       - peer rank used to build the send/recv key otherwise.
// is_batched - 1 selects the device-based key, any other value the
//              (rank_, peer) send/recv key.
//
// Returns -1, after logging all known keys and the key that was searched for,
// when no stream is registered under the resulting key.
int64_t ProcessGroupHCCL::getP2PStreamId(
    at::Device device,
    int peer,
    int is_batched)
{
    std::string key;
    if (is_batched != 1) {
        key = getKeySendRecv(rank_, peer);
    } else {
        std::vector<at::Device> devices = {device};
        key = getKeyFromDevices(devices);
    }

    const auto entry = hcclStreams_.find(key);
    const bool hasStream = (entry != hcclStreams_.end()) && !entry->second.empty();
    if (hasStream) {
        return entry->second[0].id();
    }

    // Nothing registered: dump the known keys to help diagnose the miss.
    LOG(INFO) << "getP2PStreamId: keys: (";
    for (const auto& known : hcclStreams_) {
        LOG(INFO) <<"<"<< known.first << ">, ";
    }
    LOG(INFO) << ")" << std::endl;
    LOG(INFO) << "the key to look for: <" << key << ">" << std::endl;
    return -1;
}
// Allocates the window tensor, registers it with the communicator of its
// device and exchanges the registration with the given peers.
//
// windowSize - number of elements (dtype kChar) to allocate; must be > 0.
// peerRanks  - ranks passed to hcclCommExchangeMem.
//
// Fails a TORCH_CHECK when window memory is already registered or when HCCL
// zero copy is enabled.
void ProcessGroupHCCL::windowRegisterAndExchange(int64_t windowSize, std::vector<uint32_t>& peerRanks)
{
    TORCH_CHECK(windowSize > 0, "Window memory must be greater than 0.", DIST_ERROR(ErrCode::PARAM));
    TORCH_CHECK(!windowMem_, "Window memory cannnot be registered repeatedly.", DIST_ERROR(ErrCode::UNAVAIL));
    TORCH_CHECK(!c10_npu::option::OptionsManager::IsHcclZeroCopyEnable(),
        "Window memory register unsupport set HCCL_ZERO_COPY=1", DIST_ERROR(ErrCode::UNAVAIL));

    // Allocate the backing tensor on the NPU device type.
    auto byteTensorOptions = at::TensorOptions(c10::DeviceType::PrivateUse1).dtype(at::kChar);
    windowMem_ = at::empty({windowSize}, byteTensorOptions);

    // Look up the communicator that belongs to the tensor's device.
    std::vector<at::Device> windowDevices = {windowMem_->device()};
    auto windowComm = getHcclCommByDevices(windowDevices);

    HCCL_CHECK_ERROR(hcclCommRegister(windowComm->getHcclComm(), windowMem_->data_ptr(), windowSize, &windowHandle_, 0));
    HCCL_CHECK_ERROR(hcclCommExchangeMem(windowComm->getHcclComm(), windowHandle_, peerRanks.data(), peerRanks.size()));
}
// Returns a reference to the tensor held in windowMem_.
// Fails the TORCH_CHECK (ErrCode::UNAVAIL) when windowMem_ holds no value,
// i.e. no window memory has been registered yet.
const at::Tensor& ProcessGroupHCCL::getWindowMem()
{
TORCH_CHECK(windowMem_, "window memory must be registered before get.", DIST_ERROR(ErrCode::UNAVAIL))
return windowMem_.value();
}
// Overwrites the timeout stored in this process group's options
// (options_->timeout) with the given value.
void ProcessGroupHCCL::setTimeout(std::chrono::milliseconds timeout)
{
options_->timeout = timeout;
}
// Sets the timeout of a work item from the given options plus any currently
// active ephemeral timeout extension.
//
// work   - work item whose opTimeout_ and ownedEphermeralTimeout_ are set.
// option - options providing the base timeout.
//
// Under mtxTimeoutExtension_: the work's timeout is the base timeout plus
// ephemeralTimeoutActive_ (when that is positive); the work takes ownership of
// the part of the active extension that is not yet in flight, and the
// in-flight amount is then raised to the active amount.
void ProcessGroupHCCL::assignTimeoutToWork(const c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>& work,
    const c10::intrusive_ptr<ProcessGroupHCCL::Options>& option)
{
    std::chrono::milliseconds effectiveTimeout = option->timeout;

    std::lock_guard<std::mutex> guard(mtxTimeoutExtension_);
    const bool extensionActive = ephemeralTimeoutActive_.count() > 0;
    if (extensionActive) {
        effectiveTimeout += ephemeralTimeoutActive_;
    }
    work->opTimeout_ = effectiveTimeout;

    const auto notYetInflight = ephemeralTimeoutActive_ - ephemeralTimeoutInflight_;
    work->ownedEphermeralTimeout_ = notYetInflight;
    ephemeralTimeoutInflight_ = ephemeralTimeoutActive_;
}
namespace {
// Validates a tensor list in which every tensor must live on a different NPU.
//
// The list must contain exactly one tensor. Each tensor must be a dense NPU
// tensor with the same scalar type, sizes and strides as the first one, be
// contiguous in its suggested memory format, have storage offset 0 when it
// uses an internal (non-base) format, and sit on a device not used by an
// earlier tensor in the list. Any violation fails a TORCH_CHECK.
void check_npu_tensors_different_devices(const std::vector<at::Tensor>& tensors)
{
    TORCH_CHECK(tensors.size() != 0, "Tensor list must be nonempty", DIST_ERROR(ErrCode::PARAM));
    TORCH_CHECK(tensors.size() == 1, "Tensor list mustn't be larger than the number of available NPUs", DIST_ERROR(ErrCode::VALUE));

    const auto& reference = tensors.front();
    std::unordered_set<decltype(reference.get_device())> seenDevices;
    seenDevices.reserve(tensors.size());

    for (const auto& tensor : tensors) {
        TORCH_CHECK(torch_npu::utils::is_npu(tensor) && !tensor.is_sparse(),
            "Tensors must be NPU and dense", DIST_ERROR(ErrCode::TYPE));
        TORCH_CHECK(tensor.scalar_type() == reference.scalar_type(),
            "Tensors must have identical type", DIST_ERROR(ErrCode::TYPE));
        TORCH_CHECK(tensor.sizes() == reference.sizes(),
            "Tensors must have identical size", DIST_ERROR(ErrCode::TYPE));
        TORCH_CHECK(tensor.strides() == reference.strides(),
            "Tensors must have identical strides", DIST_ERROR(ErrCode::TYPE));
        TORCH_CHECK(tensor.is_contiguous(tensor.suggest_memory_format()),
            "Tensors must be contiguous", DIST_ERROR(ErrCode::TYPE));
        // Internal formats are only accepted when the tensor starts at the beginning of its storage.
        TORCH_CHECK(at_npu::native::FormatHelper::IsBaseFormatType(tensor) ||
            (tensor.storage().data_ptr().get() == tensor.data_ptr()),
            "For a tensor of internal format, it's storage_offset must be 0", DIST_ERROR(ErrCode::NOT_SUPPORT));
        const bool firstOnDevice = seenDevices.insert(tensor.get_device()).second;
        TORCH_CHECK(firstOnDevice, "Tensors must be on distinct NPU devices", DIST_ERROR(ErrCode::TYPE));
    }
}
// Validates a tensor list in which all tensors must live on the same NPU.
//
// The list must be non-empty. Each tensor must be a dense NPU tensor with the
// same scalar type and device as the first one, be non-overlapping and dense,
// and have storage offset 0 when it uses an internal (non-base) format. Any
// violation fails a TORCH_CHECK.
void check_npu_tensors_same_device(const std::vector<at::Tensor>& tensors)
{
    TORCH_CHECK(tensors.size() != 0, "Tensor list must be nonempty", DIST_ERROR(ErrCode::PARAM));

    const auto& reference = tensors.front();
    for (const auto& tensor : tensors) {
        TORCH_CHECK(torch_npu::utils::is_npu(tensor) && !tensor.is_sparse(),
            "Tensors must be NPU and dense", DIST_ERROR(ErrCode::TYPE));
        TORCH_CHECK(tensor.scalar_type() == reference.scalar_type(),
            "Tensors must have identical type", DIST_ERROR(ErrCode::TYPE));
        TORCH_CHECK(tensor.is_non_overlapping_and_dense(),
            "Tensors must be non-overlapping and dense", DIST_ERROR(ErrCode::TYPE));
        TORCH_CHECK(tensor.get_device() == reference.get_device(),
            "Tensors must be on same NPU device", DIST_ERROR(ErrCode::TYPE));
        // Internal formats are only accepted when the tensor starts at the beginning of its storage.
        TORCH_CHECK(at_npu::native::FormatHelper::IsBaseFormatType(tensor) ||
            (tensor.storage().data_ptr().get() == tensor.data_ptr()),
            "For a tensor of internal format, it's storage_offset must be 0", DIST_ERROR(ErrCode::NOT_SUPPORT));
    }
}
// Validates one tensor: it must be a dense NPU tensor, contiguous in its
// suggested memory format, and start at its storage base when it uses an
// internal (non-base) format. Violations are raised through TORCH_CHECK.
void check_npu_single_tensor(const at::Tensor& tensor)
{
    TORCH_CHECK(
        torch_npu::utils::is_npu(tensor) && !tensor.is_sparse(),
        "Tensors must be NPU and dense",
        DIST_ERROR(ErrCode::TYPE));
    TORCH_CHECK(
        tensor.is_contiguous(tensor.suggest_memory_format()),
        "Tensors must be contiguous",
        DIST_ERROR(ErrCode::TYPE));

    const bool internalFormatWithOffset = !at_npu::native::FormatHelper::IsBaseFormatType(tensor) &&
        (tensor.storage().data_ptr().get() != tensor.data_ptr());
    TORCH_CHECK(
        !internalFormatWithOffset,
        "For a tensor of internal format, it's storage_offset must be 0",
        DIST_ERROR(ErrCode::NOT_SUPPORT));
}
// Returns true when every tensor in the list has the same size as the first
// one. An empty list yields true because the predicate is never invoked.
bool check_same_size(const std::vector<at::Tensor>& input_tensors)
{
    return std::all_of(
        input_tensors.begin(),
        input_tensors.end(),
        [&input_tensors](const at::Tensor& candidate) {
            return input_tensors[0].is_same_size(candidate);
        });
}
// Returns true when at least one tensor in the list has a null data pointer.
bool has_empty_tensor(const std::vector<at::Tensor>& tensors)
{
    return std::any_of(tensors.begin(), tensors.end(), [](const at::Tensor& candidate) {
        return candidate.data_ptr() == nullptr;
    });
}
// Returns a list parallel to inputTensors. Base-format tensors are passed
// through unchanged; every other tensor is format-cast to the origin format
// recorded in its NPU storage descriptor.
std::vector<at::Tensor> cast_to_origin_format(const std::vector<at::Tensor>& inputTensors)
{
    std::vector<at::Tensor> converted(inputTensors.size());
    for (size_t pos = 0; pos < inputTensors.size(); ++pos) {
        const auto& source = inputTensors[pos];
        if (at_npu::native::FormatHelper::IsBaseFormatType(source)) {
            converted[pos] = source;
            continue;
        }
        auto origin_format = torch_npu::NPUBridge::GetNpuStorageImpl(source)->npu_desc_.origin_format_;
        converted[pos] = at_npu::native::custom_ops::npu_format_cast(source, origin_format);
    }
    return converted;
}
// Returns a list parallel to inputTensors. Base-format tensors are reused
// as-is; for every other tensor a fresh tensor with the same sizes, dtype and
// device is allocated through NPUNativeFunctions::empty.
std::vector<at::Tensor> create_base_format_tensors(const std::vector<at::Tensor>& inputTensors)
{
    std::vector<at::Tensor> result(inputTensors.size());
    size_t slot = 0;
    for (const auto& source : inputTensors) {
        if (at_npu::native::FormatHelper::IsBaseFormatType(source)) {
            result[slot] = source;
        } else {
            // Only dtype and device are set explicitly; the remaining optionals come from default TensorOptions.
            auto options = at::TensorOptions().dtype(source.dtype()).device(source.device());
            result[slot] = at_npu::native::NPUNativeFunctions::empty(
                source.sizes(),
                options.dtype().toScalarType(),
                options.layout_opt(),
                options.device_opt(),
                options.pinned_memory_opt(),
                options.memory_format_opt());
        }
        ++slot;
    }
    return result;
}
// Builds one flat buffer per device for scatter/gather style collectives.
//
// tensor_lists: per-device lists; each must hold world_size * num_devices tensors.
// other:        per-device counterpart tensors; same length as tensor_lists.
// world_size:   number of collective participants.
//
// Each list must live on the same device as its counterpart in `other`, and
// every tensor in it must have the same numel as that counterpart. The flat
// buffers come from c10d::newLikeFlat. Violations are raised through TORCH_CHECK.
std::vector<at::Tensor> flatten_for_scatter_gather(
    std::vector<std::vector<at::Tensor>>& tensor_lists,
    std::vector<at::Tensor>& other,
    size_t world_size)
{
    TORCH_CHECK(
        tensor_lists.size() == other.size(),
        "Tensor list operands to scatter/gather must have the same length",
        DIST_ERROR(ErrCode::VALUE));

    const size_t deviceCount = tensor_lists.size();
    std::vector<at::Tensor> flatBuffers(deviceCount);
    for (size_t dev = 0; dev < deviceCount; ++dev) {
        const auto& perDevice = tensor_lists[dev];
        const auto& counterpart = other[dev];

        TORCH_CHECK(
            perDevice.size() == world_size * deviceCount,
            "Tensor list input to scatter/gather must match number of collective participants",
            DIST_ERROR(ErrCode::PARAM));
        TORCH_CHECK(
            perDevice.front().get_device() == counterpart.get_device(),
            "Corresponding input/output tensors to scatter/gather must all reside on the same device",
            DIST_ERROR(ErrCode::PARAM));
        for (const auto& element : perDevice) {
            TORCH_CHECK(
                element.numel() == counterpart.numel(),
                "All tensor operands to scatter/gather must have the same size",
                DIST_ERROR(ErrCode::PARAM));
        }

        flatBuffers[dev] = c10d::newLikeFlat(tensor_lists, dev);
    }
    return flatBuffers;
}
// Creates the NSLB "end" marker file "<nslb_path>/end_<MASTER_ADDR>_<pid>.log".
//
// The NSLB directory is created when it is not writable/present, and the
// marker file is created (empty) when it is not already writable.
// Throws std::runtime_error("NSLB set end failed." ...) when the directory or
// the file cannot be created.
void nslb_record_end()
{
    // getenv() returns nullptr when MASTER_ADDR is unset; streaming a null char* into the
    // formatter is undefined, so an empty address is used for the file name instead.
    const char* master_addr = getenv("MASTER_ADDR");
    const std::string end_file_path =
        c10::str(nslb_path, "/end_", master_addr != nullptr ? master_addr : "", "_", getpid(), ".log");

    // mkdir() is only attempted when the directory is not already writable.
    bool succeeded = (access(nslb_path, W_OK) == 0) || (mkdir(nslb_path, S_IRWXU | S_IRGRP | S_IXGRP) == 0);

    if (succeeded && access(end_file_path.c_str(), W_OK) != 0) {
        const int fd = open(end_file_path.c_str(), O_WRONLY | O_CREAT, S_IRUSR | S_IWUSR | S_IRGRP);
        succeeded = (fd != -1);
        if (succeeded) {
            close(fd);
        }
    }

    if (!succeeded) {
        throw std::runtime_error("NSLB set end failed." + DIST_ERROR(ErrCode::NOT_FOUND));
    }
}
}
// Creates the WorkHCCL object for one operation and, when `record` is true,
// registers it with HCCLTraceBuffer and stores the returned id in trace_id_.
//
// devices:        must contain exactly one device, otherwise std::runtime_error is thrown.
// profilingTitle: may be nullptr; an empty title is recorded in that case.
// inputs/outputs: forwarded to the trace buffer only.
// The start event is handed to the trace buffer only when desyncDebug_ is set.
c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL> ProcessGroupHCCL::initWork(
    std::vector<at::Device> devices,
    int rank,
    c10d::OpType opType,
    const char* profilingTitle,
    const std::vector<at::Tensor>& inputs,
    const std::vector<at::Tensor>& outputs,
    bool record)
{
    if (devices.size() != 1) {
        throw std::runtime_error("ProcessGroupHCCL support one device per process only" + DIST_ERROR(ErrCode::NOT_SUPPORT));
    }

    auto work = c10::make_intrusive<ProcessGroupHCCL::WorkHCCL>(devices, rank, opType, seq_, desyncDebug_);
    if (!record) {
        return work;
    }

    const bool isP2P = c10d::isP2POp(opType);
    const char* title = (profilingTitle != nullptr) ? profilingTitle : "";
    auto* startEvent = desyncDebug_ ? &((*(work->hcclStartEvents_))[0]) : nullptr;
    auto* endEvent = &((*(work->hcclEndEvents_))[0]);

    work->trace_id_ = HCCLTraceBuffer::get()->record(
        uid_,
        std::make_tuple(pg_name_, pg_desc_),
        seqCollective_,
        seqP2P_,
        seq_,
        title,
        inputs,
        outputs,
        startEvent,
        endEvent,
        options_->timeout,
        pgStatus_,
        isP2P);
    return work;
}
// Appends a copy of `work` to workMetaList_ and updates the "last enqueued"
// fields of pgStatus_.
//
// Before enqueueing:
//  - a pending device_error_msg is cleared and rethrown as std::runtime_error;
//  - a pending force_stop_error_flag is cleared and reported as "FORCE STOP.";
//  - nothing is enqueued while the watchdog status is STOP or once
//    terminateProcessGroup_ is set.
void ProcessGroupHCCL::workEnqueue(c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL> work)
{
    if (!device_error_msg.empty()) {
        TORCH_NPU_HCCL_LOGE("Find %s when workEnqueue, throw %s.", device_error_msg.c_str(), device_error_msg.c_str());
        std::string errorMsg = device_error_msg + " happened with workEnqueue.";
        // Reset before throwing so the same error is reported only once.
        device_error_msg = "";
        throw std::runtime_error(errorMsg + PTA_ERROR(ErrCode::ACL));
    }
    if (force_stop_error_flag) {
        force_stop_error_flag = false;
        TORCH_NPU_HCCL_LOGE("force_stop_error_flag is true when workEnqueue, throw FORCE STOP.");
        throw std::runtime_error("FORCE STOP." + PTA_ERROR(ErrCode::ACL));
    }
    if (watchdogStatus == WatchdogStatus::STOP) {
        return;
    }
    if (!terminateProcessGroup_.load()) {
        std::lock_guard<std::mutex> lock(workMetaListMutex_);
        workMetaList_.emplace_back(*work);
        pgStatus_->lastEnqueuedSeq = static_cast<int64_t>(work->seq_);
        pgStatus_->lastEnqueuedWorkName = opTypeToString(work->opType_);
        pgStatus_->lastEnqueuedNumelIn = work->numelIn_;
        pgStatus_->lastEnqueuedNumelOut = work->numelOut_;
    }
}
// Constructs the HCCL backend options.
// The base c10d::Backend::Options is initialized with HCCL_BACKEND_NAME,
// opTimeout starts at kProcessGroupHCCLOpTimeoutMillis, and the caller's
// is_high_priority_stream flag is stored unchanged.
ProcessGroupHCCL::Options::Options(bool is_high_priority_stream)
: c10d::Backend::Options(HCCL_BACKEND_NAME),
opTimeout(kProcessGroupHCCLOpTimeoutMillis),
is_high_priority_stream(is_high_priority_stream)
{
}
// Returns the single HCCL communicator that getHCCLComm yields for `devices`.
// Raises through TORCH_CHECK when the lookup does not yield exactly one communicator.
std::shared_ptr<HCCLComm> ProcessGroupHCCL::getHcclCommByDevices(const std::vector<at::Device>& devices)
{
    const auto deviceKey = getKeyFromDevices(devices);
    auto& comms = getHCCLComm(deviceKey, devices);
    TORCH_CHECK(
        comms.size() == 1,
        "expect hcclComms.size() = 1, but hcclComms.size() = ",
        comms.size(),
        DIST_ERROR(ErrCode::VALUE));
    return comms.front();
}
int64_t ProcessGroupHCCL::getHcclComm(int rankid)
{
at::Device device = getDeviceForRank(rankid);
std::vector<at::Device> devices = {device};
const auto key = getKeyFromDevices(devices);
auto& hcclComms = getHCCLComm(key, devices);
TORCH_CHECK(hcclComms.size() == 1, "expect hcclComms.size() = 1, but hcclComms.size() = ",
hcclComms.size(), DIST_ERROR(ErrCode::VALUE));
auto ret_hcom = hcclComms[0]->getHcclComm();
int64_t hccl_comm = static_cast<int64_t>(reinterpret_cast<intptr_t>(ret_hcom));
return hccl_comm;
}
void ProcessGroupHCCL::resumeHcclComm(int device_id)
{
TORCH_NPU_HCCL_LOGI("resumeHcclComm, device_id %d.", device_id);
at::Device device = at::Device(c10::DeviceType::PrivateUse1, device_id);
std::vector<at::Device> devices = {device};
auto key = getKeyFromDevices(devices);
{
std::lock_guard<std::mutex> lock(mutex_);
if (devHCCLCommMap_.find(key) != devHCCLCommMap_.end()) {
auto& hcclComms = devHCCLCommMap_[key];
for (const auto& hcclComm : hcclComms) {
auto comm = hcclComm->getHcclComm();
TORCH_NPU_HCCL_LOGI("getHcclComm, hcclComm is %p, device_id %d.", comm, device_id);
HCCL_CHECK_ERROR(at_npu::hccl::HcclCommResumeFace(comm));
}
}
if (p2pSendRecvKeys_.find(rank_) != p2pSendRecvKeys_.end()) {
auto p2pKeys = p2pSendRecvKeys_[rank_];
for (const auto& p2pKey : p2pKeys) {
if (devHCCLCommMap_.find(p2pKey) != devHCCLCommMap_.end()) {
auto& hcclComms = devHCCLCommMap_[p2pKey];
for (const auto& hcclComm : hcclComms) {
auto comm = hcclComm->getHcclComm();
TORCH_NPU_HCCL_LOGI("getHcclComm, hcclComm is %p, device_id %d.", comm, device_id);
HCCL_CHECK_ERROR(at_npu::hccl::HcclCommResumeFace(comm));
}
}
}
}
}
TORCH_NPU_HCCL_LOGI("resumeHcclComm success, group id is %s, device_id is %d.", options_->group_id.c_str(), device_id);
}
bool ProcessGroupHCCL::setCommWorkingDevNic(
const HcclComm& comm,
int nranks,
std::vector<uint32_t>& ranks,
std::vector<bool>& useBackup,
int rankid,
int hcclCommType,
int p2pPeer)
{
HcclComm sendComm = comm;
uint32_t sendnRank = 0;
std::vector<uint32_t> sendRanks;
std::vector<bool> sendUseBackup;
if (hcclCommType == 1) {
int p2pRank = rankid <= p2pPeer ? 0 : 1;
bool isSendRecvSelf = rank_ == p2pPeer;
uint32_t p2pTargetRank = isSendRecvSelf ? 0 : 1 - p2pRank;
for (int i = 0; i < nranks; i++) {
if (ranks[i] == static_cast<uint32_t>(rankid)) {
sendRanks.push_back(p2pRank);
sendUseBackup.push_back(useBackup[i]);
sendnRank++;
}
if (ranks[i] == p2pTargetRank) {
sendRanks.push_back(p2pTargetRank);
sendUseBackup.push_back(useBackup[i]);
sendnRank++;
}
}
} else {
for (int i = 0; i < nranks; i++) {
uint32_t localrank = 0;
for (uint32_t val : groupRanks()) {
if (ranks[i] == val) {
sendRanks.push_back(localrank);
sendUseBackup.push_back(useBackup[i]);
sendnRank++;
break;
}
localrank++;
}
}
}
if (sendnRank == 0) {
return true;
}
bool useBackupArr[sendUseBackup.size()];
uint32_t sendRanksArr[sendRanks.size()];
for (size_t i = 0; i < sendnRank; i++) {
useBackupArr[i] = sendUseBackup[i];
sendRanksArr[i] = sendRanks[i];
}
auto ret = hcclCommWorkingDevNicSet(sendComm, sendRanksArr, useBackupArr, sendnRank);
if (ret != HCCL_SUCCESS) {
TORCH_NPU_HCCL_LOGI("Fail to hcclCommWorkingDevNicSet, ret is %d.", ret);
return false;
}
return true;
}
bool ProcessGroupHCCL::setSwitchNicComm(int rankid, int nranks, std::vector<uint32_t>& ranks, std::vector<bool>& useBackup)
{
if (!hcclCommWorkingDevNicSetExist()) {
TORCH_NPU_HCCL_LOGI("The hcclCommWorkingDevNicSet does not exist. Skip it.");
return true;
}
at::Device device = getDeviceForRank(rankid);
std::vector<at::Device> devices = {device};
auto key = getKeyFromDevices(devices);
{
std::lock_guard<std::mutex> lock(mutex_);
if (devHCCLCommMap_.find(key) != devHCCLCommMap_.end()) {
auto& hcclComms = devHCCLCommMap_[key];
for (auto& hcclComm : hcclComms) {
HcclComm comm = hcclComm->getHcclComm();
bool result = setCommWorkingDevNic(comm, nranks, ranks, useBackup, rankid, hcclComm->hcclCommType, hcclComm->p2pPeer);
if (!result) {
return false;
}
}
} else {
return true;
}
}
TORCH_NPU_HCCL_LOGI("Succeed to hcclCommWorkingDevNicSet");
return true;
}
void ProcessGroupHCCL::setWatchdogStatus(int status)
{
watchdogStatus = WatchdogStatus(status);
if (watchdogStatus == WatchdogStatus::RUN) {
device_error_msg = "";
force_stop_error_flag = false;
}
}
// Empties workMetaList_ while holding workMetaListMutex_.
void ProcessGroupHCCL::clearWorkMetaList()
{
    std::lock_guard<std::mutex> guard(workMetaListMutex_);
    workMetaList_.clear();
}
// Registers `hccl_comm_name` as the communicator name for the current device.
//
// The name length must be in [1, COMM_NAME_MAX_LENGTH - 1]. The name is stored
// in devHCCLCommNameMap_ under the current device's key; when an entry already
// exists it must equal `hccl_comm_name`, otherwise TORCH_CHECK raises.
void ProcessGroupHCCL::setHcclCommName(const std::string& hccl_comm_name)
{
    const auto nameSize = hccl_comm_name.size();
    TORCH_CHECK(nameSize > 0 && nameSize < COMM_NAME_MAX_LENGTH,
        "The length of the name must be between 1 and ", COMM_NAME_MAX_LENGTH - 1, ", Invalid hcclCommName:",
        hccl_comm_name, DIST_ERROR(ErrCode::VALUE));
    // nameSize is a size_t but the format uses %d; the value is already bounded by
    // COMM_NAME_MAX_LENGTH above, so narrowing to int is safe and matches the specifier.
    TORCH_NPU_HCCL_LOGI("Set HCCL comm name, hccl_comm_name %s, size %d.", hccl_comm_name.c_str(),
        static_cast<int>(nameSize));

    c10::DeviceIndex indexFromCurDevice = c10_npu::current_device();
    at::Device device = at::Device(c10::DeviceType::PrivateUse1, indexFromCurDevice);
    std::vector<at::Device> devices = {device};
    const auto key = getKeyFromDevices(devices);

    std::lock_guard<std::mutex> lock(mutex_);
    // emplace() keeps an existing entry untouched, so `first->second` is the name in effect.
    auto hcclCommNameIter = devHCCLCommNameMap_.emplace(key, hccl_comm_name);
    const auto& currentHcclCommName = hcclCommNameIter.first->second;
    TORCH_CHECK(currentHcclCommName == hccl_comm_name,
        "The current ProcessGroup has already set the name and cannot be duplicated, Invalid hcclCommName:",
        hccl_comm_name, ", current hcclCommName:", currentHcclCommName, DIST_ERROR(ErrCode::VALUE));
}
// Returns the HCCL communicator name for the current device.
//
// rankid:    must be >= 0; only used to warn when rankid % device_count differs
//            from the current device index.
// init_comm: when false and no communicator is cached for the current device,
//            an empty string is returned without creating one.
//
// Otherwise the communicator is obtained through getHCCLComm (with the name
// stored in devHCCLCommNameMap_ applied to the config when it is non-empty)
// and its name is queried with HcclGetCommNameFace.
std::string ProcessGroupHCCL::getHcclCommName(int rankid, bool init_comm)
{
    TORCH_NPU_HCCL_LOGI("Get HCCL comm name, rankid %d, init_comm %d.", rankid, init_comm);
    TORCH_CHECK(rankid >= 0, "Invalid rank ", rankid, DIST_ERROR(ErrCode::VALUE));
    auto numNPUs = c10_npu::device_count();
    // Cast so the count is printed as a number even if device_count() returns a narrow integer type.
    TORCH_CHECK(numNPUs > 0, "Invalid device number ", static_cast<int>(numNPUs), DIST_ERROR(ErrCode::VALUE));
    c10::DeviceIndex indexFromRank = static_cast<c10::DeviceIndex>(rankid % numNPUs);
    c10::DeviceIndex indexFromCurDevice = c10_npu::current_device();
    if (indexFromRank != indexFromCurDevice) {
        // The fragments are joined with explicit spaces so the sentences do not run together.
        std::string warning_message = "The indexFromRank " + std::to_string(indexFromRank) +
            " is not equal indexFromCurDevice " + std::to_string(indexFromCurDevice) +
            ", which might be normal if the number of devices on your collective communication server is inconsistent." +
            " Otherwise, you need to check if the current device is correct when calling the interface." +
            " If it's incorrect, it might have introduced an error.";
        TORCH_WARN_ONCE(warning_message);
    }
    at::Device device = at::Device(c10::DeviceType::PrivateUse1, indexFromCurDevice);
    std::vector<at::Device> devices = {device};
    const auto key = getKeyFromDevices(devices);
    if (!init_comm) {
        std::lock_guard<std::mutex> lock(mutex_);
        if (devHCCLCommMap_.find(key) == devHCCLCommMap_.end()) {
            return "";
        }
    }
    HcclCommConfig config = createHcclCommConfigWithOptions();
    std::string hcclCommName = "";
    {
        std::lock_guard<std::mutex> lock(mutex_);
        // NOTE(review): operator[] inserts an empty entry when the key is absent, which a later
        // setHcclCommName() will see as an existing name - confirm this is intended.
        hcclCommName = devHCCLCommNameMap_[key];
    }
    if (!hcclCommName.empty()) {
#ifndef BUILD_LIBTORCH
        torch_npu::toolkit::profiler::Utils::safe_strcpy_s(config.hcclCommName, hcclCommName.c_str(),
            COMM_NAME_MAX_LENGTH);
#else
        strncpy(config.hcclCommName, hcclCommName.c_str(), COMM_NAME_MAX_LENGTH - 1);
        config.hcclCommName[COMM_NAME_MAX_LENGTH - 1] = '\0';
#endif
    }
    std::vector<std::shared_ptr<HCCLComm>> hcclComms = getHCCLComm(key, devices, HcclCommType::DEFAULT, &config);
    TORCH_CHECK(hcclComms.size() == 1, "expect hcclComms.size() = 1, but hcclComms.size() = ",
        hcclComms.size(), DIST_ERROR(ErrCode::VALUE));
    HcclComm hcom = hcclComms[0]->getHcclComm();
    char commName[MAX_GROUP_NAME_LEN] = {};
    HCCL_CHECK_ERROR(at_npu::hccl::HcclGetCommNameFace(hcom, commName));
    return std::string(commName);
}
// Returns the name of the single communicator in `hcclComms`, as reported by
// HcclGetCommNameFace. Raises through TORCH_CHECK when the list does not hold
// exactly one communicator, and through HCCL_CHECK_ERROR when the query fails.
std::string ProcessGroupHCCL::getHcclCommNameWithoutInit(std::vector<std::shared_ptr<HCCLComm>>& hcclComms) const
{
    TORCH_CHECK(hcclComms.size() == 1, "expect hcclComms.size() = 1, but hcclComms.size() = ",
        hcclComms.size(), DIST_ERROR(ErrCode::VALUE));
    HcclComm ret_hcom = hcclComms[0]->getHcclComm();
    // Zero-initialized (as in getHcclCommName) so the buffer is always NUL-terminated
    // when it is converted to std::string.
    char commName[MAX_GROUP_NAME_LEN] = {};
    HCCL_CHECK_ERROR(at_npu::hccl::HcclGetCommNameFace(ret_hcom, commName));
    return std::string(commName);
}
// Serializes a string map as `{\"key\": \"value\",...}`.
// Every quote is emitted with a leading backslash (i.e. the output is already
// escaped for embedding in another string). Keys and values are copied
// verbatim, and entries appear in the map's iteration order.
std::string mapToJson(const std::unordered_map<std::string, std::string>& map)
{
    const char* const escapedQuote = "\\\"";
    std::string json = "{";
    const char* separator = "";
    for (const auto& entry : map) {
        json += separator;
        json += escapedQuote;
        json += entry.first;
        json += escapedQuote;
        json += ": ";
        json += escapedQuote;
        json += entry.second;
        json += escapedQuote;
        // Every entry after the first is preceded by a comma.
        separator = ",";
    }
    json += "}";
    return json;
}
#ifndef BUILD_LIBTORCH
// Builds the mstx message describing one HCCL operation.
//
// Returns "" when mstx is disabled. Otherwise returns mapToJson() of a
// dictionary with opName, groupName, dataType ("na" for an unlisted type),
// count and streamId, plus srcRank / destRank when the respective argument is
// not -1. Communicator names are looked up once per HcclComm and kept in a
// function-local static cache.
std::string ProcessGroupHCCL::getMstxHcclMsg(
    const std::string &opName, uint64_t dataCnt, HcclDataType dataType, HcclComm comm, int64_t streamId,
    int srcRank, int dstRank)
{
    static const std::map<HcclDataType, std::string> dataTypes = {
        {HCCL_DATA_TYPE_INT8, "int8"},
        {HCCL_DATA_TYPE_INT16, "int16"},
        {HCCL_DATA_TYPE_INT32, "int32"},
        {HCCL_DATA_TYPE_FP16, "fp16"},
        {HCCL_DATA_TYPE_FP32, "fp32"},
        {HCCL_DATA_TYPE_INT64, "int64"},
        {HCCL_DATA_TYPE_UINT64, "uint64"},
        {HCCL_DATA_TYPE_UINT8, "uint8"},
        {HCCL_DATA_TYPE_UINT16, "uint16"},
        {HCCL_DATA_TYPE_UINT32, "uint32"},
        {HCCL_DATA_TYPE_FP64, "fp64"},
        {HCCL_DATA_TYPE_BFP16, "bfp16"}
    };
    static std::map<HcclComm, std::string> commNames;
    if (!torch_npu::profiler::mstxEnable()) {
        return "";
    }

    // Keys are inserted in a fixed order because the serialized output follows the map's iteration order.
    std::unordered_map<std::string, std::string> msgDict;
    msgDict["opName"] = opName;

    auto cachedName = commNames.find(comm);
    if (cachedName == commNames.end()) {
        char commName[MAX_GROUP_NAME_LEN];
        HCCL_CHECK_ERROR(at_npu::hccl::HcclGetCommNameFace(comm, commName));
        cachedName = commNames.insert({comm, std::string(commName)}).first;
    }
    msgDict["groupName"] = cachedName->second;

    if (srcRank != -1) {
        msgDict["srcRank"] = std::to_string(srcRank);
    }
    if (dstRank != -1) {
        msgDict["destRank"] = std::to_string(dstRank);
    }

    const auto typeEntry = dataTypes.find(dataType);
    msgDict["dataType"] = (typeEntry != dataTypes.end()) ? typeEntry->second : std::string("na");
    msgDict["count"] = std::to_string(dataCnt);
    msgDict["streamId"] = std::to_string(streamId);
    return mapToJson(msgDict);
}
#endif
// Runs the silent-data-corruption check (_npu_silent_check_v2) on `input`.
//
// The check is skipped when the tensor is neither float nor bfloat16, or when
// it requires grad. For op types other than SEND / RECV / UNKNOWN it is also
// skipped outside the backward call state, and for ALLREDUCE on tensors with
// at most one element. Per-op-type step and cache tensors are created on first
// use and kept in silenceCheckCache_.
void ProcessGroupHCCL::silenceCheck(at::Tensor &input, c10d::OpType opType)
{
    const auto dtype = input.scalar_type();
    const bool checkedDtype = (dtype == at::kFloat) || (dtype == at::kBFloat16);
    if (!checkedDtype || input.requires_grad()) {
        return;
    }

    const bool alwaysChecked = (opType == c10d::OpType::SEND) || (opType == c10d::OpType::RECV) ||
        (opType == c10d::OpType::UNKNOWN);
    if (!alwaysChecked) {
        if (c10_npu::model_state().get_call_state() != c10_npu::CallStateMode::L_BACKWARD) {
            return;
        }
        if (opType == c10d::OpType::ALLREDUCE && input.numel() <= 1) {
            return;
        }
    }

    auto cacheEntry = silenceCheckCache_.find(opType);
    if (cacheEntry == silenceCheckCache_.end()) {
        // First use for this op type: a one-element long step counter and a three-element float cache.
        at::Tensor stepTensor = at::zeros({1}, input.options().dtype(at::kLong));
        at::Tensor cacheTensor = at::zeros({3}, input.options().dtype(at::kFloat));
        cacheEntry = silenceCheckCache_.emplace(
            opType, std::make_pair(std::move(stepTensor), std::move(cacheTensor))).first;
    }

    at::Tensor val = at::norm(input);
    static double min_steps = 100.0;
    const auto upperThresh = c10_npu::option::OptionsManager::GetSilenceUpperThresh();
    const auto sigmaThresh = c10_npu::option::OptionsManager::GetSilenceSigmaThresh();
    op_plugin::_npu_silent_check_v2(
        val,
        input,
        cacheEntry->second.second,
        cacheEntry->second.first,
        min_steps,
        upperThresh.first,
        sigmaThresh.first,
        upperThresh.second,
        sigmaThresh.second,
        static_cast<int64_t>(c10_npu::option::OptionsManager::GetSilenceCheckFlag()));
}
// Builds the HcclCommConfig used to create a communicator for this group.
//
// Starts from getHcclCommConfig(), sets the communicator name to the group
// name when the HCCL_COMM_CONFIG_COMM_NAME capability is supported, then
// applies every recognised key of options_->hccl_config. A key holding a
// value of an unexpected type raises through TORCH_CHECK. String values that
// do not fit their destination field are truncated with a warning.
HcclCommConfig ProcessGroupHCCL::createHcclCommConfigWithOptions()
{
    HcclCommConfig config;
    getHcclCommConfig(&config);

    if (isHcclFeatureSupported(HcclCommConfigCapability::HCCL_COMM_CONFIG_COMM_NAME)) {
        std::string groupName = getGroupName();
#ifndef BUILD_LIBTORCH
        torch_npu::toolkit::profiler::Utils::safe_strcpy_s(config.hcclCommName, groupName.c_str(), COMM_NAME_MAX_LENGTH);
#else
        strncpy(config.hcclCommName, groupName.c_str(), COMM_NAME_MAX_LENGTH - 1);
        config.hcclCommName[COMM_NAME_MAX_LENGTH - 1] = '\0';
#endif
    }

    auto& userConfig = options_->hccl_config;
    if (userConfig.empty()) {
        return config;
    }

    // --- plain integer options ---
    if (const auto it = userConfig.find("hccl_buffer_size"); it != userConfig.end()) {
        const auto* value = std::get_if<uint32_t>(&it->second);
        TORCH_CHECK(value != nullptr, "Value type of hccl_buffer_size should be int.", DIST_ERROR(ErrCode::TYPE));
        config.hcclBufferSize = *value;
    }

    if (const auto it = userConfig.find("group_name"); it != userConfig.end()) {
        const auto* value = std::get_if<std::string>(&it->second);
        TORCH_CHECK(value != nullptr, "Value type of group_name should be string.", DIST_ERROR(ErrCode::TYPE));
        uint32_t udiLength = value->length();
        if (value->length() >= UDI_MAX_LENGTH) {
            udiLength = UDI_MAX_LENGTH - 1;
            TORCH_NPU_WARN("The length of group_name has exceeded the limit UDI_MAX_LENGTH which will be truncated to UDI_MAX_LENGTH - 1.");
        }
        strncpy(config.hcclUdi, value->c_str(), udiLength);
        config.hcclUdi[udiLength] = '\0';
    }

    if (const auto it = userConfig.find("qos_traffic_class"); it != userConfig.end()) {
        const auto* value = std::get_if<uint32_t>(&it->second);
        TORCH_CHECK(value != nullptr, "Value type of qos_traffic_class should be int.", DIST_ERROR(ErrCode::TYPE));
        config.hcclRdmaTrafficClass = *value;
    }

    if (const auto it = userConfig.find("qos_service_level"); it != userConfig.end()) {
        const auto* value = std::get_if<uint32_t>(&it->second);
        TORCH_CHECK(value != nullptr, "Value type of qos_service_level should be int.", DIST_ERROR(ErrCode::TYPE));
        config.hcclRdmaServiceLevel = *value;
    }

    if (const auto it = userConfig.find("hccl_sdma_qos"); it != userConfig.end()) {
        const auto* value = std::get_if<uint32_t>(&it->second);
        TORCH_CHECK(value != nullptr, "Value type of hccl_sdma_qos should be int.", DIST_ERROR(ErrCode::TYPE));
        config.hcclQos = *value;
    }

    if (const auto it = userConfig.find("hccl_op_expansion_mode"); it != userConfig.end()) {
        const auto* value = std::get_if<uint32_t>(&it->second);
        TORCH_CHECK(value != nullptr, "Value type of hccl_op_expansion_mode should be int.", DIST_ERROR(ErrCode::TYPE));
        config.hcclOpExpansionMode = *value;
    }

    if (const auto it = userConfig.find("hccl_world_rank_id"); it != userConfig.end()) {
        const auto* value = std::get_if<uint32_t>(&it->second);
        TORCH_CHECK(value != nullptr, "Value type of hccl_world_rank_id should be int.", DIST_ERROR(ErrCode::TYPE));
        config.hcclWorldRankID = *value;
    }

    if (const auto it = userConfig.find("hccl_job_id"); it != userConfig.end()) {
        const auto* value = std::get_if<uint64_t>(&it->second);
        TORCH_CHECK(value != nullptr, "Value type of hccl_job_id should be int.", DIST_ERROR(ErrCode::TYPE));
        config.hcclJobID = *value;
    }

    // hccl_exec_timeout accepts a signed value (negative only warns) or an unsigned value up to INT32_MAX.
    if (const auto it = userConfig.find("hccl_exec_timeout"); it != userConfig.end()) {
        if (const auto* signedValue = std::get_if<int32_t>(&it->second)) {
            config.hcclExecTimeOut = *signedValue;
            if (config.hcclExecTimeOut < 0) {
                TORCH_NPU_WARN_ONCE("Value type of hccl_exec_timeout less than 0.");
            }
        } else if (const auto* unsignedValue = std::get_if<uint32_t>(&it->second)) {
            const uint32_t value = *unsignedValue;
            TORCH_CHECK(
                !(value > INT32_MAX),
                "Value type of hccl_exec_timeout exceeds INT32_MAX(2147483647).",
                DIST_ERROR(ErrCode::TYPE));
            config.hcclExecTimeOut = static_cast<int32_t>(value);
        } else {
            TORCH_CHECK(false, "Value type of hccl_exec_timeout should be int.", DIST_ERROR(ErrCode::TYPE));
        }
    }

    // --- bounded string options: copy at most (limit - 1) characters and terminate explicitly ---
    if (const auto it = userConfig.find("hccl_algo"); it != userConfig.end()) {
        const auto* value = std::get_if<std::string>(&it->second);
        TORCH_CHECK(value != nullptr, "Value type of hccl_algo should be string.", DIST_ERROR(ErrCode::TYPE));
        uint32_t length = value->length();
        if (length >= HCCL_COMM_ALGO_MAX_LENGTH) {
            length = HCCL_COMM_ALGO_MAX_LENGTH - 1;
            TORCH_NPU_WARN("The length of hccl_algo has exceeded the limit HCCL_COMM_ALGO_MAX_LENGTH(1600) which will be truncated to HCCL_COMM_ALGO_MAX_LENGTH - 1.");
        }
        strncpy(config.hcclAlgo, value->c_str(), length);
        config.hcclAlgo[length] = '\0';
    }

    if (const auto it = userConfig.find("hccl_retry_enable"); it != userConfig.end()) {
        const auto* value = std::get_if<std::string>(&it->second);
        TORCH_CHECK(value != nullptr, "Value type of hccl_retry_enable should be string.", DIST_ERROR(ErrCode::TYPE));
        uint32_t length = value->length();
        if (length >= HCCL_COMM_RETRY_ENABLE_MAX_LENGTH) {
            length = HCCL_COMM_RETRY_ENABLE_MAX_LENGTH - 1;
            TORCH_NPU_WARN("The length of hccl_retry_enable has exceeded the limit HCCL_COMM_RETRY_ENABLE_MAX_LENGTH(50) which will be truncated to HCCL_COMM_RETRY_ENABLE_MAX_LENGTH - 1.");
        }
        strncpy(config.hcclRetryEnable, value->c_str(), length);
        config.hcclRetryEnable[length] = '\0';
    }

    if (const auto it = userConfig.find("hccl_retry_params"); it != userConfig.end()) {
        const auto* value = std::get_if<std::string>(&it->second);
        TORCH_CHECK(value != nullptr, "Value type of hccl_retry_params should be string.", DIST_ERROR(ErrCode::TYPE));
        uint32_t length = value->length();
        if (length >= HCCL_COMM_RETRY_PARAMS_MAX_LENGTH) {
            length = HCCL_COMM_RETRY_PARAMS_MAX_LENGTH - 1;
            TORCH_NPU_WARN("The length of hccl_retry_params has exceeded the limit HCCL_COMM_RETRY_PARAMS_MAX_LENGTH(128) which will be truncated to HCCL_COMM_RETRY_PARAMS_MAX_LENGTH - 1.");
        }
        strncpy(config.hcclRetryParams, value->c_str(), length);
        config.hcclRetryParams[length] = '\0';
    }

    if (const auto it = userConfig.find("hccl_buffer_name"); it != userConfig.end()) {
        const auto* value = std::get_if<std::string>(&it->second);
        TORCH_CHECK(value != nullptr, "Value type of hccl_buffer_name should be string.", DIST_ERROR(ErrCode::TYPE));
        uint32_t length = value->length();
        if (length >= BUFFER_NAME_MAX_LENGTH) {
            length = BUFFER_NAME_MAX_LENGTH - 1;
            TORCH_NPU_WARN("The length of hccl_buffer_name has exceeded the limit BUFFER_NAME_MAX_LENGTH(128) which will be truncated to BUFFER_NAME_MAX_LENGTH - 1.");
        }
        strncpy(config.hcclBufferName, value->c_str(), length);
        config.hcclBufferName[length] = '\0';
    }

    return config;
}
template <typename Fn, typename PreProcess, typename PostProcess>
c10::intrusive_ptr<c10d::Work> ProcessGroupHCCL::collective(
std::vector<at::Tensor>& inputs,
std::vector<at::Tensor>& outputs,
Fn fn,
PreProcess pre,
PostProcess post,
c10d::OpType opType)
{
c10_npu::CaptureStatus capture_status = c10_npu::currentStreamCaptureStatusMayInitCtx();
seqCollective_++;
seq_++;
op_id_++;
const auto devices = getDeviceList(inputs);
for (const auto& device : devices) {
c10_npu::detail::checkCurrentStreamNotExternal(device.index(), "ProcessGroupHCCL::collective");
}
auto key = getKeyFromDevices(devices);
HcclCommConfig config = createHcclCommConfigWithOptions();
std::vector<std::shared_ptr<HCCLComm>> hcclComms = getHCCLComm(key, devices, HcclCommType::DEFAULT, &config);
auto& hcclStreams = hcclStreams_[key];
syncStreams(devices, hcclEvents_[key], hcclStreams);
auto work = initWork(devices, rank_, opType, "", inputs, outputs, true);
work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputs);
c10_npu::OptionalNPUGuard npuGuard;
bool need_update_core_num = false;
uint32_t current_aic_num = 0;
uint32_t current_aiv_num = 0;
if (c10_npu::is_core_control_enabled()) {
auto npu_stream = c10_npu::getCurrentNPUStream();
current_aic_num = c10_npu::GetStreamResLimit(npu_stream, c10_npu::acl::ACL_RT_DEV_RES_CUBE_CORE);
current_aiv_num = c10_npu::GetStreamResLimit(npu_stream, c10_npu::acl::ACL_RT_DEV_RES_VECTOR_CORE);
if (current_aic_num != cached_aic_num || current_aiv_num != cached_aiv_num) {
need_update_core_num = true;
c10_npu::UseStreamResInCurrentThread(npu_stream);
cached_aic_num = current_aic_num;
cached_aiv_num = current_aiv_num;
}
}
if (desyncDebug_ || status_save_enable) {
for (const auto i : c10::irange(devices.size())) {
c10_npu::NPUStream& hcclStream = hcclStreams[i];
(*(work->hcclStartEvents_))[i].record(hcclStream);
}
}
if (c10_npu::model_state().get_model_mode() == c10_npu::ModelMode::L_TRAIN
&& c10_npu::option::OptionsManager::GetSilenceCheckFlag() != c10_npu::option::CHECK_CLOSE
&& opType != c10d::OpType::UNKNOWN && opType != c10d::OpType::BROADCAST) {
for (const auto i : c10::irange(inputs.size())) {
npuGuard.set_index(devices[i].index());
c10_npu::NPUStreamGuard guard(hcclStreams[i]);
silenceCheck(inputs[i], opType);
}
}
pre(hcclStreams, work);
if (nslb_path != nullptr && !nslb_is_end) {
auto nslb_num = c10_npu::option::OptionsManager::GetNslbCntVal();
if (op_id_ <= nslb_num) {
size_t dataVol = 0;
for (auto tensor:inputs) {
dataVol += tensor.storage().nbytes();
}
const char* global_rank = getenv("RANK");
TORCH_CHECK(global_rank != nullptr, "Unable to fetch global rank for NSLB.", DIST_ERROR(ErrCode::NOT_FOUND));
recordDataVol(opTypeToString(opType), std::to_string(dataVol), strtol(global_rank, nullptr, 10), hcclComms);
}
if (op_id_ >= nslb_num) {
nslb_is_end = true;
nslb_record_end();
}
}
static bool perf_dump_enable = c10_npu::option::OptionsManager::CheckPerfDumpEnable();
if (perf_dump_enable) {
if (perfdumppath.empty()) {
auto pid = getpid();
int device_id = c10_npu::current_device();
std::ostringstream oss;
oss << "perf_pt_" << pid << "_" << device_id << ".log";
std::string log_file_name = oss.str();
auto perfDumpPath = c10_npu::option::OptionsManager::GetPerfDumpPath();
char abs_path[PATH_MAX] = {'\0'};
if (realpath(perfDumpPath.c_str(), abs_path) == nullptr) {
TORCH_CHECK(0, "perfDumpPath is not realpath.", DIST_ERROR(ErrCode::NOT_FOUND));
}
auto path_temp = c10::str(perfDumpPath, "/", log_file_name);
if (isFileExists(path_temp)) {
perfdumppath = path_temp;
std::ofstream outfile;
try {
outfile.open(perfdumppath, std::ios::app);
} catch (std::exception& e) {
throw std::runtime_error("Open shared directory failed. Please check whether perfdumppath is valid." + DIST_ERROR(ErrCode::NOT_FOUND));
}
const std::vector<uint32_t>& ranks = groupRanks();
outfile << "[GLOBAL RANKID]:" << ranks[rank_] << "\n";
outfile.close();
}
} else {
recordComm(perfdumppath, opTypeToString(opType), rank_, hcclComms);
}
}
for (const auto i : c10::irange(inputs.size())) {
npuGuard.set_index(devices[i].index());
c10_npu::NPUStream& hcclStream = hcclStreams[i];
auto multi_stream_memory_reuse_mode = c10_npu::option::OptionsManager::GetMultiStreamMemoryReuse();
if (multi_stream_memory_reuse_mode == c10_npu::option::AVOID_RECORD_STREAM) {
if (opType == c10d::OpType::UNKNOWN) {
c10_npu::NPUCachingAllocator::recordStream(inputs[i].storage().data_ptr(), hcclStream);
work->recorded_inputs_.push_back(std::make_pair(inputs[i].storage().getWeakStorageImpl(), hcclStream));
} else {
work->stashed_for_allocator_safety_.push_back(inputs[i]);
}
} else {
c10_npu::NPUCachingAllocator::recordStream(inputs[i].storage().data_ptr(), hcclStream);
if (multi_stream_memory_reuse_mode == c10_npu::option::ERASE_RECORD_STREAM ||
multi_stream_memory_reuse_mode == c10_npu::option::ERASE_RECORD_STREAM_WITH_OPTIMIZE) {
work->recorded_inputs_.push_back(std::make_pair(inputs[i].storage().getWeakStorageImpl(), hcclStream));
if (multi_stream_memory_reuse_mode == c10_npu::option::ERASE_RECORD_STREAM_WITH_OPTIMIZE) {
auto block_ptr = c10_npu::NPUCachingAllocator::getBlockPtr(inputs[i].storage().data_ptr());
work->recorded_block_ptr_for_inputs_.push_back(block_ptr);
c10_npu::NPUCachingAllocator::recordHcclWorkForBlock(block_ptr, static_cast<void*>(work.get()));
}
}
}
}
{
for (const auto i : c10::irange(inputs.size())) {
npuGuard.set_index(devices[i].index());
c10_npu::NPUStream& hcclStream = hcclStreams[i];
if (need_update_core_num) {
c10_npu::SetStreamResLimit(hcclStream, c10_npu::acl::ACL_RT_DEV_RES_CUBE_CORE, current_aic_num);
c10_npu::SetStreamResLimit(hcclStream, c10_npu::acl::ACL_RT_DEV_RES_VECTOR_CORE, current_aiv_num);
}
hcclUs startut = std::chrono::steady_clock::now();
HCCL_CHECK_ERROR(fn(inputs[i], outputs[i], hcclComms[i]->getHcclComm(), hcclStream, work->is_dispatched), opTypeToString(opType).c_str());
if (c10_npu::option::OptionsManager::GetMultiStreamMemoryReuse() == c10_npu::option::ERASE_RECORD_STREAM) {
work->recorded_outputs_.push_back(
std::make_pair(outputs[i].storage().getWeakStorageImpl(), hcclStream));
} else if (c10_npu::option::OptionsManager::GetMultiStreamMemoryReuse() == c10_npu::option::AVOID_RECORD_STREAM) {
work->stashed_for_allocator_safety_.push_back(outputs[i]);
}
}
}
post(hcclStreams, work);
{
c10_npu::NPUMultiStreamGuard guard(hcclStreams);
work->future_ = c10::make_intrusive<at::ivalue::Future>(
c10::ListType::create(c10::TensorType::get()),
devices);
work->future_->markCompleted(at::IValue(*work->outputs_));
}
for (size_t i = 0; i < inputs.size(); ++i) {
c10_npu::NPUStream& hcclStream = hcclStreams_[key][i];
(*(work->hcclEndEvents_))[i].record(hcclStream);
TORCH_NPU_HCCL_LOGI("Event: record hccl work is successfully executed, event=%p", (*(work->hcclEndEvents_))[i].event());
work->hcclComms_[i] = hcclComms[i];
}
work->blockingWait_ = blockingWait_;
work->opTimeout_ = options_->timeout;
work->store_ = store_;
assignTimeoutToWork(work, options_);
work->numelIn_ = 0;
work->numelOut_ = 0;
for (const auto& input : inputs) {
work->numelIn_ += static_cast<size_t>(input.numel());
}
for (const auto& output : outputs) {
work->numelOut_ += static_cast<size_t>(output.numel());
}
c10_npu::NPUGraph::inc_pending_event_queries();
if (asyncErrorHandling_ != NoHandling && capture_status == c10_npu::CaptureStatus::None) {
workEnqueue(work);
} else {
c10_npu::NPUGraph::dec_pending_event_queries();
}
return work;
}
// Runs fn once per (inputs[i], outputs[i]) pair and tracks all of the calls
// with a single WorkHCCL.
//
// Unlike a per-device loop, every per-tensor step in this function uses
// devices[0], hcclStreams[0] and hcclComms[0], whatever the loop index is.
//
// inputs   Tensors passed to fn as its first argument, pairwise with outputs.
// outputs  Tensors passed to fn as its second argument; a copy of this vector
//          is stored in work->outputs_ and used to complete the future.
// fn       Invoked as fn(inputs[i], outputs[i], comm, stream, work->is_dispatched);
//          its result is checked with HCCL_CHECK_ERROR.
// pre      Invoked once as pre(hcclStreams, work) before the first fn call.
// post     Invoked once as post(hcclStreams, work) after the last fn call.
// opType   Used when creating the work, for the silent check, and in
//          log / error text via opTypeToString.
// Returns the WorkHCCL created by initWork for this call.
// NOTE(review): inputs[0] / outputs[0] are dereferenced unconditionally near the
// end, so both vectors are presumably required to be non-empty — confirm callers.
template <typename Fn, typename PreProcess, typename PostProcess>
c10::intrusive_ptr<c10d::Work> ProcessGroupHCCL::collectiveCoalesced(
std::vector<at::Tensor>& inputs,
std::vector<at::Tensor>& outputs,
Fn fn,
PreProcess pre,
PostProcess post,
c10d::OpType opType)
{
// Sampled up front; consulted at the very end to decide whether the work is enqueued.
c10_npu::CaptureStatus capture_status = c10_npu::currentStreamCaptureStatusMayInitCtx();
// Per-process-group counters; op_id_ also drives the NSLB window below.
seq_++;
op_id_++;
const auto devices = getDevice(inputs);
auto key = getKeyFromDevice(devices);
NPU_CHECK_ERROR(c10_npu::SetDevice(devices[0].index()));
HcclCommConfig config = createHcclCommConfigWithOptions();
std::vector<std::shared_ptr<HCCLComm>> hcclComms = getHCCLComm(key, devices, HcclCommType::DEFAULT, &config);
auto& hcclStreams = hcclStreams_[key];
// NOTE(review): presumably orders the HCCL streams after the current compute
// streams via hcclEvents_[key] — confirm in syncStreams.
syncStreams(devices, hcclEvents_[key], hcclStreams);
// Core control: read the cube/vector core limits of the current stream and, if
// they differ from the cached values, remember that the HCCL stream limits must
// be refreshed in the launch loop below.
bool need_update_core_num = false;
uint32_t current_aic_num = 0;
uint32_t current_aiv_num = 0;
if (c10_npu::is_core_control_enabled()) {
auto npu_stream = c10_npu::getCurrentNPUStream();
current_aic_num = c10_npu::GetStreamResLimit(npu_stream, c10_npu::acl::ACL_RT_DEV_RES_CUBE_CORE);
current_aiv_num = c10_npu::GetStreamResLimit(npu_stream, c10_npu::acl::ACL_RT_DEV_RES_VECTOR_CORE);
if (current_aic_num != cached_aic_num || current_aiv_num != cached_aiv_num) {
need_update_core_num = true;
c10_npu::UseStreamResInCurrentThread(npu_stream);
cached_aic_num = current_aic_num;
cached_aiv_num = current_aiv_num;
}
}
// While a coalescing group is active, mark that it contains a collective and
// insist that every op in the group uses the same device and communicator.
if (coalescing_state_ & CoalActive) {
coalescing_state_ |= CoalColl;
if (coalescedDevice_.index() < 0) {
// First op of the group: adopt its device.
coalescedDevice_ = devices[0];
} else {
for (const auto& device : devices) {
TORCH_CHECK(
coalescedDevice_.index() == device.index(),
"Expecting same device across coalesced P2P operations. "
"Got device ", device.index(), " but expected ", coalescedDevice_.index());
}
}
if (coalescedComm_ == nullptr) {
// First op of the group: adopt its communicator.
coalescedComm_ = hcclComms[0];
} else {
TORCH_CHECK(
coalescedComm_ == hcclComms[0],
"Expecting same communicator across coalesced P2P operations.");
}
}
auto work = initWork(devices, rank_, opType);
work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputs);
c10_npu::OptionalNPUGuard npuGuard;
// With desync debugging on, record a start event on the HCCL stream.
if (desyncDebug_) {
c10_npu::NPUStream& hcclStream = hcclStreams[0];
(*(work->hcclStartEvents_))[0].record(hcclStream);
}
// Silent check of every input: only in training mode, only when the check is
// not switched off, and never for OpType::UNKNOWN.
if (c10_npu::model_state().get_model_mode() == c10_npu::ModelMode::L_TRAIN
&& c10_npu::option::OptionsManager::GetSilenceCheckFlag() != c10_npu::option::CHECK_CLOSE
&& opType != c10d::OpType::UNKNOWN) {
for (const auto i : c10::irange(inputs.size())) {
npuGuard.set_index(devices[0].index());
c10_npu::NPUStreamGuard guard(hcclStreams[0]);
silenceCheck(inputs[i], opType);
}
}
pre(hcclStreams, work);
// NSLB recording: log the data volume of the first nslb_num operations, then
// close the recording. When op_id_ == nslb_num both branches below run.
if (nslb_path != nullptr && !nslb_is_end) {
auto nslb_num = c10_npu::option::OptionsManager::GetNslbCntVal();
if (op_id_ <= nslb_num) {
size_t dataVol = 0;
// Volume is the sum of the storage sizes (not the view sizes) of all inputs.
for (auto tensor:inputs) {
dataVol += tensor.storage().nbytes();
}
const char* global_rank = getenv("RANK");
TORCH_CHECK(global_rank != nullptr, "Unable to fetch global rank for NSLB.", DIST_ERROR(ErrCode::NOT_FOUND));
recordDataVol(opTypeToString(opType), std::to_string(dataVol), strtol(global_rank, nullptr, 10), hcclComms);
}
if (op_id_ >= nslb_num) {
nslb_is_end = true;
nslb_record_end();
}
}
// Perf dump. The enable flag is read once (function-local static).
static bool perf_dump_enable = c10_npu::option::OptionsManager::CheckPerfDumpEnable();
if (perf_dump_enable) {
if (perfdumppath.empty()) {
// No dump file adopted yet: build "perf_pt_<pid>_<device>.log" under the
// configured directory. perfdumppath is only set if that file already exists,
// so this branch is retried on later calls until it does.
auto pid = getpid();
int device_id = c10_npu::current_device();
std::ostringstream oss;
oss << "perf_pt_" << pid << "_" << device_id << ".log";
std::string log_file_name = oss.str();
auto perfDumpPath = c10_npu::option::OptionsManager::GetPerfDumpPath();
// realpath() is used purely as an existence/resolvability check; the resolved
// path written to abs_path is not used afterwards.
char abs_path[PATH_MAX] = {'\0'};
if (realpath(perfDumpPath.c_str(), abs_path) == nullptr) {
TORCH_CHECK(0, "perfDumpPath is not realpath.", DIST_ERROR(ErrCode::NOT_FOUND));
}
auto path_temp = c10::str(perfDumpPath, "/", log_file_name);
if (isFileExists(path_temp)) {
perfdumppath = path_temp;
std::ofstream outfile;
// NOTE(review): outfile never enables stream exceptions, so open() reports
// failure through the stream state rather than by throwing; this catch is
// likely unreachable and a failed open goes unnoticed — confirm intent.
try {
outfile.open(perfdumppath, std::ios::app);
} catch (std::exception& e) {
throw std::runtime_error("Open shared directory failed. Please check whether perfdumppath is valid." + DIST_ERROR(ErrCode::NOT_FOUND));
}
const std::vector<uint32_t>& ranks = groupRanks();
outfile << "[GLOBAL RANKID]:" << ranks[rank_] << "\n";
outfile.close();
}
} else {
recordComm(perfdumppath, opTypeToString(opType), rank_, hcclComms);
}
}
// Allocator bookkeeping for the inputs, depending on the memory-reuse mode:
//   AVOID_RECORD_STREAM            -> keep the tensor alive in the work object;
//   otherwise                      -> recordStream on the HCCL stream, and
//   ERASE_RECORD_STREAM[_WITH_OPTIMIZE] additionally remember the storage so
//   the record can be undone later; _WITH_OPTIMIZE also registers the work
//   against the allocator block.
for (const auto i : c10::irange(inputs.size())) {
npuGuard.set_index(devices[0].index());
c10_npu::NPUStream& hcclStream = hcclStreams[0];
auto multi_stream_memory_reuse_mode = c10_npu::option::OptionsManager::GetMultiStreamMemoryReuse();
if (multi_stream_memory_reuse_mode == c10_npu::option::AVOID_RECORD_STREAM) {
work->stashed_for_allocator_safety_.push_back(inputs[i]);
} else {
c10_npu::NPUCachingAllocator::recordStream(inputs[i].storage().data_ptr(), hcclStream);
if (multi_stream_memory_reuse_mode == c10_npu::option::ERASE_RECORD_STREAM ||
multi_stream_memory_reuse_mode == c10_npu::option::ERASE_RECORD_STREAM_WITH_OPTIMIZE) {
work->recorded_inputs_.push_back(std::make_pair(inputs[i].storage().getWeakStorageImpl(), hcclStream));
if (multi_stream_memory_reuse_mode == c10_npu::option::ERASE_RECORD_STREAM_WITH_OPTIMIZE) {
auto block_ptr = c10_npu::NPUCachingAllocator::getBlockPtr(inputs[i].storage().data_ptr());
work->recorded_block_ptr_for_inputs_.push_back(block_ptr);
c10_npu::NPUCachingAllocator::recordHcclWorkForBlock(block_ptr, static_cast<void*>(work.get()));
}
}
}
}
// Launch loop: one fn call per tensor pair, all on hcclStreams[0] / hcclComms[0].
{
for (const auto i : c10::irange(inputs.size())) {
npuGuard.set_index(devices[0].index());
c10_npu::NPUStream& hcclStream = hcclStreams[0];
if (need_update_core_num) {
c10_npu::SetStreamResLimit(hcclStream, c10_npu::acl::ACL_RT_DEV_RES_CUBE_CORE, current_aic_num);
c10_npu::SetStreamResLimit(hcclStream, c10_npu::acl::ACL_RT_DEV_RES_VECTOR_CORE, current_aiv_num);
}
// NOTE(review): startut is never read in this function.
hcclUs startut = std::chrono::steady_clock::now();
HCCL_CHECK_ERROR(fn(inputs[i], outputs[i], hcclComms[0]->getHcclComm(), hcclStream, work->is_dispatched), opTypeToString(opType).c_str());
// Outputs are remembered only for exactly ERASE_RECORD_STREAM (not for the
// _WITH_OPTIMIZE variant) and stashed for AVOID_RECORD_STREAM.
if (c10_npu::option::OptionsManager::GetMultiStreamMemoryReuse() == c10_npu::option::ERASE_RECORD_STREAM) {
work->recorded_outputs_.push_back(
std::make_pair(outputs[i].storage().getWeakStorageImpl(), hcclStream));
} else if (c10_npu::option::OptionsManager::GetMultiStreamMemoryReuse() == c10_npu::option::AVOID_RECORD_STREAM) {
work->stashed_for_allocator_safety_.push_back(outputs[i]);
}
}
}
post(hcclStreams, work);
// Create the future while the HCCL streams are current and complete it
// immediately with the output tensors.
{
c10_npu::NPUMultiStreamGuard guard(hcclStreams);
work->future_ = c10::make_intrusive<at::ivalue::Future>(
c10::ListType::create(c10::TensorType::get()),
devices);
work->future_->markCompleted(at::IValue(*work->outputs_));
}
// A single end event on stream 0 marks completion of the whole batch.
c10_npu::NPUStream& hcclStream = hcclStreams_[key][0];
(*(work->hcclEndEvents_))[0].record(hcclStream);
TORCH_NPU_HCCL_LOGI("Event: record hccl work is successfully executed, event=%p", (*(work->hcclEndEvents_))[0].event());
work->hcclComms_[0] = hcclComms[0];
work->blockingWait_ = blockingWait_;
work->opTimeout_ = options_->timeout;
work->store_ = store_;
assignTimeoutToWork(work, options_);
// Element counts are taken from the first tensor only, not summed over all.
work->numelIn_ = static_cast<size_t>(inputs[0].numel());
work->numelOut_ = static_cast<size_t>(outputs[0].numel());
// The pending-event-query counter is incremented before the capture status is
// examined and rolled back if the work is not handed to workEnqueue.
c10_npu::NPUGraph::inc_pending_event_queries();
if (asyncErrorHandling_ != NoHandling && capture_status == c10_npu::CaptureStatus::None) {
workEnqueue(work);
} else {
c10_npu::NPUGraph::dec_pending_event_queries();
}
return work;
}
// Issues a point-to-point HCCL operation (fn) for every tensor in `tensors`.
//
// Two modes, selected by coalescing_state_:
//   - non-zero: the tensors are stashed in coalescedTensors_, no WorkHCCL is
//     created, pre/post are not called, each fn call is wrapped in
//     hcclGroupStart / hcclGroupEnd, and nullptr is returned;
//   - zero: a WorkHCCL is created, pre/post are called around the launches and
//     the work is returned.
//
// tensors  Buffers to send or receive; also used as the work's outputs.
// fn       Invoked as fn(tensors[i], comm, stream, is_dispatched, p2pTargetRank);
//          its result is checked with HCCL_CHECK_ERROR.
// peer     Peer rank; compared against rank_ and used to derive the
//          communicator key and target rank.
// opType   Used for work creation, the SEND-only silent check and log/error text.
// pre      Invoked as pre(hcclStreams_[key], work) before launching (non-coalescing only).
// post     Invoked as post(hcclStreams_[key], work) after launching (non-coalescing only).
template <typename Fn, typename PreProcess, typename PostProcess>
c10::intrusive_ptr<c10d::Work> ProcessGroupHCCL::pointToPoint(
std::vector<at::Tensor>& tensors,
Fn fn,
int peer,
c10d::OpType opType,
PreProcess pre,
PostProcess post)
{
// recordStream is called unconditionally on every tensor further down, so the
// AVOID_RECORD_STREAM mode cannot be honoured here; tell the user once.
if (c10_npu::option::OptionsManager::GetMultiStreamMemoryReuse() == c10_npu::option::AVOID_RECORD_STREAM) {
TORCH_NPU_WARN_ONCE("MULTI_STREAM_MEMORY_REUSE=2 has no effect for point-to-point collectives.");
}
// Sampled up front; consulted at the end to decide whether the work is enqueued.
c10_npu::CaptureStatus capture_status = c10_npu::currentStreamCaptureStatusMayInitCtx();
const auto devices = getDeviceList(tensors);
int p2pRank = 0;
int p2pTargetRank = 0;
bool isSendRecvSelf = false;
std::string key;
std::vector<std::shared_ptr<HCCLComm>> hcclComms;
// Communicator selection.
// Dedicated P2P communicator (only outside coalescing, with a non-zero P2P
// buffer size and when hcclCommInitRootInfoConfigExist() holds): the side with
// the lower-or-equal rank is rank 0 inside it, the other side is rank 1, and
// the target is the opposite one (0 when sending to self).
// Otherwise the regular device-keyed communicator is used and the target is
// `peer` itself.
if (hcclCommInitRootInfoConfigExist() && c10_npu::option::OptionsManager::GetP2PBufferSize() != 0 && coalescing_state_ == 0) {
key = getKeySendRecv(rank_, peer);
p2pRank = rank_ <= peer ? 0 : 1;
isSendRecvSelf = rank_ == peer;
p2pTargetRank = isSendRecvSelf ? 0 : 1 - p2pRank;
setP2pPeer(peer);
hcclComms = getHCCLComm(key, devices, HcclCommType::P2P, nullptr, p2pRank);
} else {
p2pTargetRank = peer;
key = getKeyFromDevices(devices);
hcclComms = getHCCLComm(key, devices);
}
op_id_++;
// Sequence numbering and coalescing bookkeeping: inside a coalescing group
// seqP2P_ advances only for the first P2P op (CoalP2P not yet set); outside a
// group it advances on every call.
if (coalescing_state_ & CoalActive) {
if ((coalescing_state_ & CoalP2P) == 0) {
seqP2P_++;
}
coalescing_state_ |= CoalP2P;
if (coalescedDevice_.index() < 0) {
// First op of the group: adopt its device.
coalescedDevice_ = devices[0];
} else {
for (const auto& device : devices) {
TORCH_CHECK(
coalescedDevice_.index() == device.index(),
"Expecting same device across coalesced P2P operations. "
"Got device ", device.index(), " but expected ", coalescedDevice_.index());
}
}
if (coalescedComm_ == nullptr) {
// First op of the group: adopt its communicator.
coalescedComm_ = hcclComms[0];
} else {
TORCH_CHECK(
coalescedComm_ == hcclComms[0],
"Expecting same communicator across coalesced P2P operations.");
}
} else {
seqP2P_++;
}
// NOTE(review): presumably orders the HCCL streams after the current compute
// streams via hcclEvents_[key] — confirm in syncStreams.
syncStreams(devices, hcclEvents_[key], hcclStreams_[key]);
// `work` stays null while coalescing; every later use is guarded by
// !coalescing_state_.
c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL> work;
if (coalescing_state_) {
coalescedTensors_.stash(tensors);
} else {
work = initWork(devices, rank_, opType, "", tensors, tensors, true);
work->outputs_ = std::make_shared<std::vector<at::Tensor>>(tensors);
}
c10_npu::OptionalNPUGuard npuGuard;
// With desync debugging on (and a work object available), record start events.
if (!coalescing_state_ && desyncDebug_) {
for (const auto i : c10::irange(devices.size())) {
c10_npu::NPUStream& hcclStream = hcclStreams_[key][i];
(*(work->hcclStartEvents_))[i].record(hcclStream);
}
}
// Silent check: training mode, check not switched off, and SEND operations only.
if (c10_npu::model_state().get_model_mode() == c10_npu::ModelMode::L_TRAIN
&& c10_npu::option::OptionsManager::GetSilenceCheckFlag() != c10_npu::option::CHECK_CLOSE
&& opType == c10d::OpType::SEND) {
for (const auto i : c10::irange(tensors.size())) {
npuGuard.set_index(devices[i].index());
c10_npu::NPUStreamGuard guard(hcclStreams_[key][i]);
silenceCheck(tensors[i], opType);
}
}
if (!coalescing_state_) {
pre(hcclStreams_[key], work);
}
// NSLB recording: log the data volume of the first nslb_num operations, then
// close the recording. When op_id_ == nslb_num both branches below run.
if (nslb_path != nullptr && !nslb_is_end) {
auto nslb_num = c10_npu::option::OptionsManager::GetNslbCntVal();
if (op_id_ <= nslb_num) {
size_t dataVol = 0;
// Volume is the sum of the storage sizes (not the view sizes) of all tensors.
for (auto tensor : tensors) {
dataVol += tensor.storage().nbytes();
}
const char* global_rank = getenv("RANK");
TORCH_CHECK(global_rank != nullptr, "Unable to fetch global rank for NSLB.",
DIST_ERROR(ErrCode::NOT_FOUND));
recordDataVol(opTypeToString(opType), std::to_string(dataVol), strtol(global_rank, nullptr, 10), hcclComms);
}
if (op_id_ >= nslb_num) {
nslb_is_end = true;
nslb_record_end();
}
}
// Perf dump. The enable flag is read once (function-local static).
static bool perf_dump_enable = c10_npu::option::OptionsManager::CheckPerfDumpEnable();
if (perf_dump_enable) {
if (perfdumppath.empty()) {
// No dump file adopted yet: build "perf_pt_<pid>_<device>.log" under the
// configured directory. perfdumppath is only set if that file already exists,
// so this branch is retried on later calls until it does.
auto pid = getpid();
int device_id = c10_npu::current_device();
std::ostringstream oss;
oss << "perf_pt_" << pid << "_" << device_id << ".log";
std::string log_file_name = oss.str();
auto perfDumpPath = c10_npu::option::OptionsManager::GetPerfDumpPath();
// realpath() is used purely as an existence/resolvability check; the resolved
// path written to abs_path is not used afterwards.
char abs_path[PATH_MAX] = {'\0'};
if (realpath(perfDumpPath.c_str(), abs_path) == nullptr) {
TORCH_CHECK(0, "perfDumpPath is not realpath.", DIST_ERROR(ErrCode::NOT_FOUND));
}
auto path_temp = c10::str(perfDumpPath, "/", log_file_name);
if (isFileExists(path_temp)) {
perfdumppath = path_temp;
std::ofstream outfile;
// NOTE(review): outfile never enables stream exceptions, so open() reports
// failure through the stream state rather than by throwing; this catch is
// likely unreachable and a failed open goes unnoticed — confirm intent.
try {
outfile.open(perfdumppath, std::ios::app);
} catch (std::exception& e) {
throw std::runtime_error("Open shared directory failed. Please check whether perfdumppath is valid." + DIST_ERROR(ErrCode::NOT_FOUND));
}
const std::vector<uint32_t>& ranks = groupRanks();
outfile << "[GLOBAL RANKID]:" << ranks[rank_] << "\n";
outfile.close();
}
} else {
recordComm(perfdumppath, opTypeToString(opType), rank_, hcclComms);
}
}
// Allocator bookkeeping: every tensor is recorded on its HCCL stream. When a
// work object exists and memory reuse is not CLOSE, the storage is also
// remembered on the work; _WITH_OPTIMIZE additionally registers the work
// against the allocator block.
for (const auto i : c10::irange(tensors.size())) {
npuGuard.set_index(devices[i].index());
c10_npu::NPUStream& hcclStream = hcclStreams_[key][i];
auto multi_stream_memory_reuse_mode = c10_npu::option::OptionsManager::GetMultiStreamMemoryReuse();
c10_npu::NPUCachingAllocator::recordStream(tensors[i].storage().data_ptr(), hcclStream);
if (!coalescing_state_ && multi_stream_memory_reuse_mode != c10_npu::option::CLOSE) {
work->recorded_inputs_.push_back(std::make_pair(tensors[i].storage().getWeakStorageImpl(), hcclStream));
if (multi_stream_memory_reuse_mode == c10_npu::option::ERASE_RECORD_STREAM_WITH_OPTIMIZE) {
auto block_ptr = c10_npu::NPUCachingAllocator::getBlockPtr(tensors[i].storage().data_ptr());
work->recorded_block_ptr_for_inputs_.push_back(block_ptr);
c10_npu::NPUCachingAllocator::recordHcclWorkForBlock(block_ptr, static_cast<void*>(work.get()));
}
}
}
// Launch loop: one fn call per tensor on its own device/stream/communicator.
{
for (const auto i : c10::irange(tensors.size())) {
npuGuard.set_index(devices[i].index());
c10_npu::NPUStream& hcclStream = hcclStreams_[key][i];
// NOTE(review): startut is never read in this function.
hcclUs startut = std::chrono::steady_clock::now();
// While coalescing there is no work object, so fn gets a throw-away flag.
std::shared_ptr<bool> is_dispatched = coalescing_state_ ? std::make_shared<bool>(false) : work->is_dispatched;
// While coalescing, each fn call is bracketed by hcclGroupStart / hcclGroupEnd,
// both submitted through RunOpApiV3.
if (coalescing_state_) {
auto hccl_call = [this]() -> HcclResult {
return hcclGroupStart();
};
at_npu::native::OpCommand::RunOpApiV3("hcclGroupStart", hccl_call);
}
HCCL_CHECK_ERROR(fn(tensors[i], hcclComms[i]->getHcclComm(), hcclStream, is_dispatched, p2pTargetRank), opTypeToString(opType).c_str());
if (coalescing_state_) {
auto hccl_call = [this]() -> HcclResult {
return hcclGroupEnd();
};
at_npu::native::OpCommand::RunOpApiV3("hcclGroupEnd", hccl_call);
}
}
}
// Completion bookkeeping only applies when a work object was created.
if (!coalescing_state_) {
post(hcclStreams_[key], work);
// Create the future while the HCCL streams are current and complete it
// immediately with the output tensors.
{
c10_npu::NPUMultiStreamGuard guard(hcclStreams_[key]);
work->future_ = c10::make_intrusive<at::ivalue::Future>(
c10::ListType::create(c10::TensorType::get()),
devices);
work->future_->markCompleted(at::IValue(*work->outputs_));
}
// Record an end event per tensor. The work-wide fields set inside this loop do
// not depend on i except numelIn_/numelOut_, which end up holding the element
// count of the last tensor.
for (const auto i : c10::irange(tensors.size())) {
c10_npu::NPUStream& hcclStream = hcclStreams_[key][i];
(*(work->hcclEndEvents_))[i].record(hcclStream);
work->hcclComms_[i] = hcclComms[i];
work->blockingWait_ = blockingWait_;
work->opTimeout_ = options_->timeout;
work->store_ = store_;
assignTimeoutToWork(work, options_);
work->numelIn_ = work->numelOut_ = static_cast<size_t>(tensors[i].numel());
}
// The pending-event-query counter is incremented before the capture status is
// examined and rolled back if the work is not handed to workEnqueue.
c10_npu::NPUGraph::inc_pending_event_queries();
if (asyncErrorHandling_ != NoHandling && capture_status == c10_npu::CaptureStatus::None) {
workEnqueue(work);
} else {
c10_npu::NPUGraph::dec_pending_event_queries();
}
}
return coalescing_state_ ? nullptr : work;
}
// Convenience overload of collective() for callers that need no pre- or
// post-processing: forwards to the full overload with two callbacks that do
// nothing.
template <typename Fn>
c10::intrusive_ptr<c10d::Work> ProcessGroupHCCL::collective(
    std::vector<at::Tensor>& inputs,
    std::vector<at::Tensor>& outputs,
    Fn fn,
    c10d::OpType opType)
{
    // Empty hooks with the (streams, work) shape the full overload invokes.
    auto skipPre = [](std::vector<c10_npu::NPUStream>&, c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>&) {};
    auto skipPost = [](std::vector<c10_npu::NPUStream>&, c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>&) {};
    return collective(inputs, outputs, fn, skipPre, skipPost, opType);
}
// Convenience overload of pointToPoint() for callers that need no pre- or
// post-processing: forwards to the full overload with two callbacks that do
// nothing.
template <typename Fn>
c10::intrusive_ptr<c10d::Work> ProcessGroupHCCL::pointToPoint(
    std::vector<at::Tensor>& tensors,
    Fn fn,
    int peer,
    c10d::OpType opType)
{
    // Empty hooks with the (streams, work) shape the full overload invokes.
    auto skipPre = [](std::vector<c10_npu::NPUStream>&, c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>&) {};
    auto skipPost = [](std::vector<c10_npu::NPUStream>&, c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>&) {};
    return pointToPoint(tensors, fn, peer, opType, skipPre, skipPost);
}
// Mutable non-local integer, initialized to 0.
// NOTE(review): judging by the name this is presumably an allreduce call
// ID/counter, but nothing in the surrounding code reads or writes it —
// confirm whether it is still needed.
int g_allreduceID = 0;
// All-reduce of tensors[0] through HcclAllreduce.
//
// The reduction runs in place on a one-element working list (tensors_cp):
//   - bool / byte inputs are cast to int beforehand and copied back afterwards;
//   - PREMUL_SUM is realised by scaling the working tensor before the call;
//   - AVG is realised by dividing the working tensor by getSize() after it.
// Returns the work object produced by collective().
c10::intrusive_ptr<c10d::Work> ProcessGroupHCCL::allreduce(
    std::vector<at::Tensor>& tensors,
    const c10d::AllreduceOptions& opts)
{
    check_npu_tensors_different_devices(tensors);
    if (C10_UNLIKELY(at_npu::native::env::CheckOpHookEnable())) {
        at_npu::native::OpHook::GetInstance().PreHook("allreduce", tensors);
    }

    // Element 0 may be swapped for an int-typed cast by the pre-process step.
    std::vector<at::Tensor> tensors_cp = {tensors[0]};
    std::string functionName = __FUNCTION__;

    // Submits the HCCL all-reduce for one (input, output) pair.
    auto launchOp = [&](at::Tensor& input, at::Tensor& output, HcclComm comm, c10_npu::NPUStream& stream, std::shared_ptr<bool> is_dispatched) {
        auto dataType = getHcclDataType(input.scalar_type());
        checkSupportedDataType(dataType, functionName);
        RECORD_FUNCTION("HcclAllreduce", std::vector<c10::IValue>({input}));
        auto sendBuf = input.data_ptr();
        auto recvBuf = output.data_ptr();
        auto count = getNumelForHCCL(input);
        auto redOp = getHcclReduceOp(opts.reduceOp, input);
        // Everything is captured by value; the task is handed to RunOpApiV3.
        auto task = [sendBuf, recvBuf, count, dataType, redOp, comm, stream, is_dispatched]() -> int {
#ifndef BUILD_LIBTORCH
            torch_npu::profiler::MstxRange range(
                getMstxHcclMsg("HcclAllreduce", count, dataType, comm, stream.id(), -1, -1), stream.stream(false),
                torch_npu::profiler::DOMAIN_COMMUNICATION);
#endif
            if (c10_npu::is_core_control_enabled()) {
                c10_npu::UseStreamResInCurrentThread(stream.stream(false));
            }
            auto status = hcclAllReduce(
                sendBuf, recvBuf, count, dataType, redOp, comm, stream.stream(false));
            *is_dispatched = true;
            return status;
        };
        at_npu::native::OpCommand::RunOpApiV3("HcclAllreduce", task, false, &stream);
        return HCCL_SUCCESS;
    };

    // Runs on the HCCL stream before the reduction is launched.
    auto beforeLaunch = [&](std::vector<c10_npu::NPUStream>& hcclStreams, c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>&) {
        const bool needsIntCast =
            tensors[0].scalar_type() == at::kBool || tensors[0].scalar_type() == at::kByte;
        if (needsIntCast) {
            c10_npu::NPUStreamGuard guard(hcclStreams[0]);
            tensors_cp[0] = at_npu::native::custom_ops::_npu_dtype_cast(tensors[0], at::kInt);
        }
        if (opts.reduceOp == c10d::ReduceOp::PREMUL_SUM) {
            c10_npu::NPUStreamGuard guard(hcclStreams[0]);
            const auto* supplement =
                reinterpret_cast<c10d::NCCLPreMulSumSupplement*>(
                    opts.reduceOp.supplement_.get());
            auto factor = supplement->double_factor;
            for (auto& working : tensors_cp) {
                working.mul_(factor);
            }
        }
    };

    // Runs on the HCCL stream after the reduction has been launched.
    auto afterLaunch = [&](std::vector<c10_npu::NPUStream>& hcclStreams, c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>&) {
        if (opts.reduceOp == c10d::ReduceOp::AVG) {
            c10_npu::NPUStreamGuard guard(hcclStreams[0]);
            for (auto& working : tensors_cp) {
                working.div_(getSize());
            }
        }
        // A dtype mismatch means the working tensor is the int cast: copy the
        // result back into the caller's tensor.
        if (tensors_cp[0].scalar_type() != tensors[0].scalar_type()) {
            c10_npu::NPUStreamGuard guard(hcclStreams[0]);
            c10_npu::NPUCachingAllocator::recordStream(tensors_cp[0].storage().data_ptr(), hcclStreams[0]);
            tensors[0].copy_(tensors_cp[0]);
        }
    };

    return collective(tensors_cp, tensors_cp, launchOp, beforeLaunch, afterLaunch, c10d::OpType::ALLREDUCE);
}
// Submits a batch of point-to-point transfers as one HcclBatchSendRecv call.
//
// Entry i of op_type, tensors and remote_rank_list together describe one
// transfer. The batch is driven through collective() with tensors[0] as the
// nominal input/output; tensors[1..] are stashed or recorded on the HCCL
// stream by the pre-process step, depending on the memory-reuse mode.
//
// op_type           Per-transfer operation name: "isend" maps to HCCL_SEND,
//                   "irecv" to HCCL_RECV, anything else to
//                   HCCL_SEND_RECV_RESERVED. Must not be empty.
// tensors           Per-transfer buffer; needs at least op_type.size() entries.
// remote_rank_list  Per-transfer remote rank; needs at least op_type.size()
//                   entries, each within the uint32_t range.
// Returns the work object produced by collective().
// Throws std::runtime_error if the lists are empty or too short, or if a
// remote rank is out of range.
c10::intrusive_ptr<c10d::Work> ProcessGroupHCCL::batch_isend_irecv_inner(
    std::vector<std::string>& op_type,
    std::vector<at::Tensor>& tensors,
    std::vector<int64_t> remote_rank_list)
{
    // Everything below indexes tensors / remote ranks by op index and reads
    // element 0, so reject inconsistent input up front instead of running off
    // the end of a vector.
    const size_t opCount = op_type.size();
    if (opCount == 0 || tensors.size() < opCount || remote_rank_list.size() < opCount) {
        throw std::runtime_error(
            "batch_isend_irecv requires at least one op and a tensor and remote rank for every op, but got " +
            std::to_string(opCount) + " op(s), " + std::to_string(tensors.size()) + " tensor(s) and " +
            std::to_string(remote_rank_list.size()) + " remote rank(s)" + DIST_ERROR(ErrCode::PARAM));
    }
    if (C10_UNLIKELY(at_npu::native::env::CheckOpHookEnable())) {
        at_npu::native::OpHook::GetInstance().PreHook("batch_isend_irecv", tensors);
    }
    std::vector<at::Tensor> tensors_tmp = {tensors[0]};
    return collective(
        tensors_tmp,
        tensors_tmp,
        [&](at::Tensor& input, at::Tensor& output, HcclComm comm, c10_npu::NPUStream& stream, std::shared_ptr<bool> is_dispatched) {
            RECORD_FUNCTION("HcclBatchSendRecv", std::vector<c10::IValue>({input}));
            auto itemNum = static_cast<uint32_t>(op_type.size());
            // Snapshot raw pointers / counts / types now; the task below captures
            // them by value rather than holding references to the tensors.
            std::vector<void *> tensor_ptr_list;
            std::vector<uint64_t> numel_list;
            std::vector<HcclDataType> type_list;
            tensor_ptr_list.reserve(op_type.size());
            numel_list.reserve(op_type.size());
            type_list.reserve(op_type.size());
            for (size_t i = 0; i < op_type.size(); ++i) {
                tensor_ptr_list.push_back(tensors[i].data_ptr());
                numel_list.push_back(getNumelForHCCL(tensors[i]));
                type_list.push_back(getHcclDataType(tensors[i].scalar_type()));
            }
            // HCCL takes 32-bit unsigned ranks: validate before narrowing.
            std::vector<uint32_t> remote_rank_list_cast;
            remote_rank_list_cast.reserve(remote_rank_list.size());
            for (size_t i = 0; i < remote_rank_list.size(); ++i) {
                if (remote_rank_list[i] < 0 || remote_rank_list[i] > std::numeric_limits<uint32_t>::max()) {
                    throw std::runtime_error("Value at index " + std::to_string(i) +
                                             " (" + std::to_string(remote_rank_list[i]) +
                                             ") is out of uint32_t range" + DIST_ERROR(ErrCode::VALUE));
                }
                remote_rank_list_cast.push_back(static_cast<uint32_t>(remote_rank_list[i]));
            }
            auto hccl_call = [tensor_ptr_list, numel_list, type_list, remote_rank_list_cast, op_type, itemNum, comm, stream, is_dispatched]() -> int {
                // Heap-backed storage instead of a runtime-sized stack array: standard
                // C++, and safe for large batches.
                std::vector<HcclSendRecvItem> sendRecvInfo;
                sendRecvInfo.reserve(itemNum);
                for (size_t i = 0; i < op_type.size(); ++i) {
                    HcclSendRecvType currType;
                    if (op_type[i] == "isend") {
                        currType = HcclSendRecvType::HCCL_SEND;
                    } else if (op_type[i] == "irecv") {
                        currType = HcclSendRecvType::HCCL_RECV;
                    } else {
                        currType = HcclSendRecvType::HCCL_SEND_RECV_RESERVED;
                    }
                    sendRecvInfo.push_back(HcclSendRecvItem{currType,
                                                            tensor_ptr_list[i],
                                                            numel_list[i],
                                                            type_list[i],
                                                            remote_rank_list_cast[i]
                                                            });
                }
#ifndef BUILD_LIBTORCH
                // The profiler range is labelled with the first item only; the batch is
                // guaranteed non-empty by the validation at function entry.
                torch_npu::profiler::MstxRange range(
                    getMstxHcclMsg("HcclBatchSendRecv", sendRecvInfo[0].count, sendRecvInfo[0].dataType, comm, stream.id(), -1, -1),
                    stream.stream(false), torch_npu::profiler::DOMAIN_COMMUNICATION);
#endif
                if (c10_npu::is_core_control_enabled()) {
                    c10_npu::UseStreamResInCurrentThread(stream.stream(false));
                }
                auto hccl_result = hcclBatchIsendIrecv(sendRecvInfo.data(), itemNum, comm, stream.stream(false));
                *is_dispatched = true;
                return hccl_result;
            };
            at_npu::native::OpCommand::RunOpApiV3("HcclBatchSendRecv", hccl_call, false, &stream);
            return HCCL_SUCCESS;
        },
        [&](std::vector<c10_npu::NPUStream>& hcclStreams, c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>& work) {
            // Silent check on every buffer that is not a receive, in training mode and
            // when the check is not switched off.
            if (c10_npu::model_state().get_model_mode() == c10_npu::ModelMode::L_TRAIN
                && c10_npu::option::OptionsManager::GetSilenceCheckFlag() != c10_npu::option::CHECK_CLOSE) {
                for (size_t i = 0; i < op_type.size(); ++i) {
                    if (op_type[i] != "irecv") {
                        c10_npu::NPUStreamGuard guard(hcclStreams[0]);
                        silenceCheck(tensors[i], c10d::OpType::SEND);
                    }
                }
            }
            // collective() only sees tensors[0] (via tensors_tmp); apply the
            // memory-reuse bookkeeping to the remaining tensors here.
            auto mode = c10_npu::option::OptionsManager::GetMultiStreamMemoryReuse();
            for (size_t i = 1; i < tensors.size(); ++i) {
                if (mode == c10_npu::option::AVOID_RECORD_STREAM) {
                    work->stashed_for_allocator_safety_.push_back(tensors[i]);
                } else {
                    c10_npu::NPUCachingAllocator::recordStream(
                        tensors[i].storage().data_ptr(), hcclStreams[0]);
                    if (mode == c10_npu::option::ERASE_RECORD_STREAM ||
                        mode == c10_npu::option::ERASE_RECORD_STREAM_WITH_OPTIMIZE) {
                        work->recorded_inputs_.push_back(
                            std::make_pair(tensors[i].storage().getWeakStorageImpl(), hcclStreams[0]));
                        if (mode == c10_npu::option::ERASE_RECORD_STREAM_WITH_OPTIMIZE) {
                            auto block_ptr = c10_npu::NPUCachingAllocator::getBlockPtr(
                                tensors[i].storage().data_ptr());
                            work->recorded_block_ptr_for_inputs_.push_back(block_ptr);
                            c10_npu::NPUCachingAllocator::recordHcclWorkForBlock(
                                block_ptr, static_cast<void*>(work.get()));
                        }
                    }
                }
            }
        },
        [&](std::vector<c10_npu::NPUStream>& hcclStreams, c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>&) {},
        c10d::OpType::UNKNOWN);
}
// Public entry point for batched isend/irecv.
//
// Routes the call through the dispatcher op
// "npu_custom_dist::wrap_batch_isend_irecv_inner" and, when in-flight
// collectives are allowed as graph inputs, registers the returned work
// against every tensor of the batch.
c10::intrusive_ptr<c10d::Work> ProcessGroupHCCL::batch_isend_irecv(
    std::vector<std::string>& op_type,
    std::vector<at::Tensor>& tensors,
    std::vector<int64_t> remote_rank_list)
{
    using InnerOpSignature = c10::intrusive_ptr<::c10d::Work>(
        at::TensorList,
        std::vector<std::string>,
        std::vector<int64_t>,
        c10::intrusive_ptr<c10d_npu::ProcessGroupHCCL>);

    // Schema lookup happens once; the typed handle is reused on later calls.
    static auto innerOp = c10::Dispatcher::singleton()
        .findSchemaOrThrow("npu_custom_dist::wrap_batch_isend_irecv_inner", "")
        .typed<InnerOpSignature>();

    auto work = innerOp.call(
        tensors,
        op_type,
        remote_rank_list,
        c10::intrusive_ptr<c10d_npu::ProcessGroupHCCL>::unsafe_reclaim_from_nonowning(this));

    if (!c10d::allow_inflight_collective_as_graph_input()) {
        return work;
    }
    for (auto& tensor : tensors) {
        c10d::register_work(tensor, work);
    }
    return work;
}
// Mutable non-local integer, initialized to 100000.
// NOTE(review): judging by the name this is presumably a broadcast call
// ID/counter, but nothing in the surrounding code reads or writes it —
// confirm whether it is still needed.
int g_broadcastID = 100000;
// Broadcast of `tensors` through HcclBroadcast, in place (the same vector is
// passed to collective() as both input and output).
//
// The HCCL root index is opts.rootRank * tensors.size() + opts.rootTensor.
// On the rank whose group rank equals opts.rootRank, the tensors are passed
// through the silent check first (training mode only, and only when the check
// is not switched off).
// Returns the work object produced by collective().
c10::intrusive_ptr<c10d::Work> ProcessGroupHCCL::broadcast(
    std::vector<at::Tensor>& tensors,
    const c10d::BroadcastOptions& opts)
{
    check_npu_tensors_different_devices(tensors);
    if (C10_UNLIKELY(at_npu::native::env::CheckOpHookEnable())) {
        at_npu::native::OpHook::GetInstance().PreHook("broadcast", tensors);
    }

    // Submits the HCCL broadcast for one tensor.
    auto launchOp = [&](at::Tensor& input, at::Tensor& output, HcclComm comm, c10_npu::NPUStream& stream, std::shared_ptr<bool> is_dispatched) {
        RECORD_FUNCTION("HcclBroadcast", std::vector<c10::IValue>({input}));
        const auto rootIndex = opts.rootRank * tensors.size() + opts.rootTensor;
        auto buffer = input.data_ptr();
        auto count = getNumelForHCCL(input);
        auto dataType = getHcclDataType(input.scalar_type());
        // Everything is captured by value; the task is handed to RunOpApiV3.
        auto task = [buffer, count, dataType, rootIndex, comm, stream, is_dispatched]() -> int {
#ifndef BUILD_LIBTORCH
            torch_npu::profiler::MstxRange range(
                getMstxHcclMsg("HcclBroadcast", count, dataType, comm, stream.id(), -1, -1), stream.stream(false),
                torch_npu::profiler::DOMAIN_COMMUNICATION);
#endif
            if (c10_npu::is_core_control_enabled()) {
                c10_npu::UseStreamResInCurrentThread(stream.stream(false));
            }
            auto status = hcclBroadcast(buffer, count, dataType, rootIndex, comm, stream.stream(false));
            *is_dispatched = true;
            return status;
        };
        at_npu::native::OpCommand::RunOpApiV3("HcclBroadcast", task, false, &stream);
        return HCCL_SUCCESS;
    };

    // Silent check of the source data, performed on the HCCL stream.
    auto beforeLaunch = [&](std::vector<c10_npu::NPUStream>& hcclStreams, c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>&) {
        if (c10_npu::model_state().get_model_mode() != c10_npu::ModelMode::L_TRAIN) {
            return;
        }
        if (c10_npu::option::OptionsManager::GetSilenceCheckFlag() == c10_npu::option::CHECK_CLOSE) {
            return;
        }
        const std::vector<uint32_t>& groupRankList = groupRanks();
        if (!(opts.rootRank == groupRankList[rank_])) {
            return;
        }
        for (auto& tensor : tensors) {
            c10_npu::NPUStreamGuard guard(hcclStreams[0]);
            silenceCheck(tensor, c10d::OpType::BROADCAST);
        }
    };

    // Nothing to do after the launch.
    auto afterLaunch = [&](std::vector<c10_npu::NPUStream>&, c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>&) {};

    return collective(tensors, tensors, launchOp, beforeLaunch, afterLaunch, c10d::OpType::BROADCAST);
}
// All-reduce of every tensor in `tensors`, issued as one HcclAllreduce per
// tensor through collectiveCoalesced().
//
// The reductions run in place on a working copy of the list (tensors_cp):
//   - bool / byte tensors are cast to int beforehand and copied back afterwards;
//   - PREMUL_SUM is realised by scaling the working tensors before the calls;
//   - AVG is realised by dividing the working tensors by getSize() after them.
// Returns the work object produced by collectiveCoalesced().
c10::intrusive_ptr<c10d::Work> ProcessGroupHCCL::allreduce_coalesced(
    std::vector<at::Tensor>& tensors,
    const c10d::AllreduceCoalescedOptions& opts)
{
    check_npu_tensors_same_device(tensors);

    // Individual entries may be swapped for int-typed casts by the pre-process step.
    std::vector<at::Tensor> tensors_cp = tensors;
    std::string functionName = __FUNCTION__;

    // Submits the HCCL all-reduce for one (input, output) pair.
    auto launchOp = [&](at::Tensor& input, at::Tensor& output, HcclComm comm, c10_npu::NPUStream& stream, std::shared_ptr<bool> is_dispatched) {
        auto dataType = getHcclDataType(input.scalar_type());
        checkSupportedDataType(dataType, functionName);
        RECORD_FUNCTION("HcclAllreduce", std::vector<c10::IValue>({input}));
        auto sendBuf = input.data_ptr();
        auto recvBuf = output.data_ptr();
        auto count = getNumelForHCCL(input);
        auto redOp = getHcclReduceOp(opts.reduceOp, input);
        // Everything is captured by value; the task is handed to RunOpApiV3.
        auto task = [sendBuf, recvBuf, count, dataType, redOp, comm, stream, is_dispatched]() -> int {
#ifndef BUILD_LIBTORCH
            torch_npu::profiler::MstxRange range(
                getMstxHcclMsg("HcclAllreduce", count, dataType, comm, stream.id(), -1, -1), stream.stream(false),
                torch_npu::profiler::DOMAIN_COMMUNICATION);
#endif
            if (c10_npu::is_core_control_enabled()) {
                c10_npu::UseStreamResInCurrentThread(stream.stream(false));
            }
            auto status = hcclAllReduce(
                sendBuf, recvBuf, count, dataType, redOp, comm, stream.stream(false));
            *is_dispatched = true;
            return status;
        };
        at_npu::native::OpCommand::RunOpApiV3("HcclAllreduce", task, false, &stream);
        return HCCL_SUCCESS;
    };

    // Runs on the HCCL stream before the reductions are launched.
    auto beforeLaunch = [&](std::vector<c10_npu::NPUStream>& hcclStreams, c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>&) {
        for (size_t idx = 0; idx < tensors.size(); ++idx) {
            const auto dtype = tensors[idx].scalar_type();
            if (dtype != at::kBool && dtype != at::kByte) {
                continue;
            }
            c10_npu::NPUStreamGuard guard(hcclStreams[0]);
            tensors_cp[idx] = at_npu::native::custom_ops::_npu_dtype_cast(tensors[idx], at::kInt);
        }
        if (opts.reduceOp == c10d::ReduceOp::PREMUL_SUM) {
            c10_npu::NPUStreamGuard guard(hcclStreams[0]);
            const auto* supplement =
                reinterpret_cast<c10d::NCCLPreMulSumSupplement*>(
                    opts.reduceOp.supplement_.get());
            auto factor = supplement->double_factor;
            for (auto& working : tensors_cp) {
                working.mul_(factor);
            }
        }
    };

    // Runs on the HCCL stream after the reductions have been launched.
    auto afterLaunch = [&](std::vector<c10_npu::NPUStream>& hcclStreams, c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>&) {
        if (opts.reduceOp == c10d::ReduceOp::AVG) {
            c10_npu::NPUStreamGuard guard(hcclStreams[0]);
            for (auto& working : tensors_cp) {
                working.div_(getSize());
            }
        }
        // A dtype mismatch means the working tensor is the int cast: copy the
        // result back into the caller's tensor.
        for (size_t idx = 0; idx < tensors.size(); ++idx) {
            if (tensors_cp[idx].scalar_type() == tensors[idx].scalar_type()) {
                continue;
            }
            c10_npu::NPUStreamGuard guard(hcclStreams[0]);
            c10_npu::NPUCachingAllocator::recordStream(tensors_cp[idx].storage().data_ptr(), hcclStreams[0]);
            tensors[idx].copy_(tensors_cp[idx]);
        }
    };

    return collectiveCoalesced(tensors_cp, tensors_cp, launchOp, beforeLaunch, afterLaunch, c10d::OpType::ALLREDUCE);
}
// Reduce of tensors[0] towards opts.rootRank through HcclReduce.
//
// The reduction runs in place on a one-element working list (tensors_cp):
//   - bool / byte inputs are cast to int beforehand and copied back afterwards;
//   - PREMUL_SUM is realised by scaling the working tensor before the call;
//   - AVG is realised by dividing the working tensor by getSize() after it.
// Returns the work object produced by collective().
c10::intrusive_ptr<c10d::Work> ProcessGroupHCCL::reduce(
    std::vector<at::Tensor>& tensors,
    const c10d::ReduceOptions& opts)
{
    check_npu_tensors_different_devices(tensors);
    if (C10_UNLIKELY(at_npu::native::env::CheckOpHookEnable())) {
        at_npu::native::OpHook::GetInstance().PreHook("reduce", tensors);
    }

    std::string functionName = __FUNCTION__;
    uint64_t rootRank = opts.rootRank;
    // Element 0 may be swapped for an int-typed cast by the pre-process step.
    std::vector<at::Tensor> tensors_cp = {tensors[0]};

    // Submits the HCCL reduce for one (input, output) pair.
    auto launchOp = [&](at::Tensor& input, at::Tensor& output, HcclComm comm, c10_npu::NPUStream& stream, std::shared_ptr<bool> is_dispatched) {
        auto dataType = getHcclDataType(input.scalar_type());
        checkSupportedDataType(dataType, functionName);
        RECORD_FUNCTION("HcclReduce", std::vector<c10::IValue>({input}));
        auto sendBuf = input.data_ptr();
        auto recvBuf = output.data_ptr();
        auto count = getNumelForHCCL(input);
        auto redOp = getHcclReduceOp(opts.reduceOp, input);
        // Everything is captured by value; the task is handed to RunOpApiV3.
        auto task = [sendBuf, recvBuf, count, dataType, redOp, rootRank, comm, stream, is_dispatched]() -> int {
#ifndef BUILD_LIBTORCH
            torch_npu::profiler::MstxRange range(
                getMstxHcclMsg("HcclReduce", count, dataType, comm, stream.id(), -1, -1), stream.stream(false),
                torch_npu::profiler::DOMAIN_COMMUNICATION);
#endif
            if (c10_npu::is_core_control_enabled()) {
                c10_npu::UseStreamResInCurrentThread(stream.stream(false));
            }
            auto status = hcclReduce(
                sendBuf, recvBuf, count, dataType, redOp, rootRank, comm, stream.stream(false));
            *is_dispatched = true;
            return status;
        };
        at_npu::native::OpCommand::RunOpApiV3("HcclReduce", task, false, &stream);
        return HCCL_SUCCESS;
    };

    // Runs on the HCCL stream before the reduction is launched.
    auto beforeLaunch = [&](std::vector<c10_npu::NPUStream>& hcclStreams, c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>&) {
        const bool needsIntCast =
            tensors[0].scalar_type() == at::kBool || tensors[0].scalar_type() == at::kByte;
        if (needsIntCast) {
            c10_npu::NPUStreamGuard guard(hcclStreams[0]);
            tensors_cp[0] = at_npu::native::custom_ops::_npu_dtype_cast(tensors[0], at::kInt);
        }
        if (opts.reduceOp == c10d::ReduceOp::PREMUL_SUM) {
            c10_npu::NPUStreamGuard guard(hcclStreams[0]);
            const auto* supplement =
                reinterpret_cast<c10d::NCCLPreMulSumSupplement*>(
                    opts.reduceOp.supplement_.get());
            auto factor = supplement->double_factor;
            for (auto& working : tensors_cp) {
                working.mul_(factor);
            }
        }
    };

    // Runs on the HCCL stream after the reduction has been launched.
    auto afterLaunch = [&](std::vector<c10_npu::NPUStream>& hcclStreams, c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>&) {
        if (opts.reduceOp == c10d::ReduceOp::AVG) {
            c10_npu::NPUStreamGuard guard(hcclStreams[0]);
            for (auto& working : tensors_cp) {
                working.div_(getSize());
            }
        }
        // A dtype mismatch means the working tensor is the int cast: copy the
        // result back into the caller's tensor.
        if (tensors_cp[0].scalar_type() != tensors[0].scalar_type()) {
            c10_npu::NPUStreamGuard guard(hcclStreams[0]);
            c10_npu::NPUCachingAllocator::recordStream(tensors_cp[0].storage().data_ptr(), hcclStreams[0]);
            tensors[0].copy_(tensors_cp[0]);
        }
    };

    return collective(tensors_cp, tensors_cp, launchOp, beforeLaunch, afterLaunch, c10d::OpType::REDUCE);
}
/**
 * Out-of-place reduce: hcclReduce from inputTensor into outputTensor with opts.rootRank as root argument.
 *
 * Bool/Byte inputs and outputs are replaced by Int casts before the call; when the output was
 * replaced, the result is copied back into outputTensor afterwards. PREMUL_SUM multiplies the
 * input list by the supplement's double_factor before the call; AVG divides the output list by
 * getSize() after the call.
 *
 * @param outputTensor destination; validated by check_npu_single_tensor.
 * @param inputTensor  source; must have the same numel as outputTensor.
 * @param opts         supplies rootRank and reduceOp.
 * @return the work handle returned by collective().
 */
c10::intrusive_ptr<c10d::Work> ProcessGroupHCCL::_reduce_oop(
    at::Tensor& outputTensor,
    at::Tensor& inputTensor,
    const c10d::ReduceOptions& opts)
{
    check_npu_single_tensor(outputTensor);
    TORCH_CHECK(
        outputTensor.numel() == inputTensor.numel(),
        "output tensor must have the same numel as input tensor", DIST_ERROR(ErrCode::PARAM));

    uint64_t rootRank = opts.rootRank;
    std::vector<at::Tensor> inputTensors = {inputTensor};
    std::vector<at::Tensor> outputTensors = {outputTensor};
    // Taken here so the string holds this method's name rather than a lambda's.
    std::string callerName = __FUNCTION__;

    // Builds the hcclReduce task for one (input, output) pair and hands it to RunOpApiV3.
    auto launchReduce = [&](at::Tensor& input, at::Tensor& output, HcclComm comm, c10_npu::NPUStream& stream,
                            std::shared_ptr<bool> is_dispatched) {
        auto dataType = getHcclDataType(input.scalar_type());
        checkSupportedDataType(dataType, callerName);
        RECORD_FUNCTION("HcclReduce", std::vector<c10::IValue>({input}));
        auto sendBuf = input.data_ptr();
        auto recvBuf = output.data_ptr();
        auto count = getNumelForHCCL(input);
        auto hcclOp = getHcclReduceOp(opts.reduceOp, input);
        auto reduceTask = [sendBuf, recvBuf, count, dataType, hcclOp, rootRank, comm, stream, is_dispatched]() -> int {
#ifndef BUILD_LIBTORCH
            torch_npu::profiler::MstxRange range(
                getMstxHcclMsg("HcclReduce", count, dataType, comm, stream.id(), -1, -1), stream.stream(false),
                torch_npu::profiler::DOMAIN_COMMUNICATION);
#endif
            if (c10_npu::is_core_control_enabled()) {
                c10_npu::UseStreamResInCurrentThread(stream.stream(false));
            }
            auto ret = hcclReduce(
                sendBuf, recvBuf, count, dataType, hcclOp, rootRank, comm, stream.stream(false));
            *is_dispatched = true;
            return ret;
        };
        at_npu::native::OpCommand::RunOpApiV3("HcclReduce", reduceTask, false, &stream);
        return HCCL_SUCCESS;
    };

    // Pre-step: promote Bool/Byte input and output to Int, then apply PREMUL_SUM scaling.
    auto prepare = [&](std::vector<c10_npu::NPUStream>& hcclStreams,
                       c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>&) {
        const auto inType = inputTensors[0].scalar_type();
        if (inType == at::kBool || inType == at::kByte) {
            c10_npu::NPUStreamGuard guard(hcclStreams[0]);
            inputTensors[0] = at_npu::native::custom_ops::_npu_dtype_cast(inputTensors[0], at::kInt);
        }
        const auto outType = outputTensors[0].scalar_type();
        if (outType == at::kBool || outType == at::kByte) {
            c10_npu::NPUStreamGuard guard(hcclStreams[0]);
            outputTensors[0] = at_npu::native::custom_ops::_npu_dtype_cast(outputTensors[0], at::kInt);
        }
        if (opts.reduceOp == c10d::ReduceOp::PREMUL_SUM) {
            c10_npu::NPUStreamGuard guard(hcclStreams[0]);
            const auto* preMulSupplement =
                reinterpret_cast<c10d::NCCLPreMulSumSupplement*>(
                    opts.reduceOp.supplement_.get());
            auto factor = preMulSupplement->double_factor;
            for (auto& item : inputTensors) {
                item.mul_(factor);
            }
        }
    };

    // Post-step: AVG division, then copy back when the output entry no longer has the caller's dtype.
    auto finish = [&](std::vector<c10_npu::NPUStream>& hcclStreams,
                      c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>&) {
        if (opts.reduceOp == c10d::ReduceOp::AVG) {
            c10_npu::NPUStreamGuard guard(hcclStreams[0]);
            for (auto& item : outputTensors) {
                item.div_(getSize());
            }
        }
        if (outputTensors[0].scalar_type() != outputTensor.scalar_type()) {
            c10_npu::NPUStreamGuard guard(hcclStreams[0]);
            c10_npu::NPUCachingAllocator::recordStream(outputTensors[0].storage().data_ptr(), hcclStreams[0]);
            outputTensor.copy_(outputTensors[0]);
        }
    };

    return collective(inputTensors, outputTensors, launchReduce, prepare, finish, c10d::OpType::REDUCE);
}
// Byte multiple that byte_alignment() pads reshaped tensors up to.
constexpr int64_t ADDRESS_ALIGNMENT_BYTE = 512;

/**
 * Reshape a tensor to {1, numel} and zero-pad its second dimension so the total byte size becomes a
 * multiple of ADDRESS_ALIGNMENT_BYTE.
 *
 * No padding is applied when the element size is 0 or when the SoC version is at least Ascend950.
 * Bool tensors are cast to Int around the padding call and cast back to Bool afterwards.
 *
 * @param tensors the tensor to reshape and pad (a single tensor despite the plural name).
 * @return the reshaped, possibly padded tensor.
 */
at::Tensor ProcessGroupHCCL::byte_alignment(at::Tensor& tensors) const
{
    static bool no_need_padding = c10_npu::GetSocVersion() >= c10_npu::SocVersion::Ascend950;
    at::Tensor flat = at::reshape(tensors, {1, tensors.numel()});
    const auto elemSize = tensors.element_size();
    if (elemSize == 0 || no_need_padding) {
        return flat;
    }

    // Bytes missing to reach the next alignment boundary, then the same amount in elements.
    int64_t totalBytes = tensors.numel() * elemSize;
    int64_t remainder = totalBytes % ADDRESS_ALIGNMENT_BYTE;
    int64_t padBytes = (remainder == 0) ? 0 : (ADDRESS_ALIGNMENT_BYTE - remainder);
    int64_t padElems = padBytes / elemSize;
    if (padElems == 0) {
        return flat;
    }

    const bool wasBool = flat.scalar_type() == at::ScalarType::Bool;
    if (wasBool) {
        flat = at_npu::native::custom_ops::_npu_dtype_cast(flat, at::ScalarType::Int);
    }
    flat = op_plugin::constant_pad_nd(flat, {0, padElems}, 0);
    if (wasBool) {
        flat = at_npu::native::custom_ops::_npu_dtype_cast(flat, at::ScalarType::Bool);
    }
    return flat;
}
/**
 * Reduce-scatter with per-rank split sizes, issued through hcclReduceScatterV.
 *
 * inputSplitSizes is expressed in rows of dimension 0 of inputTensor; it is filled with equal
 * splits by fill_equal_split_sizes_when_empty when empty and validated by check_split_sizes.
 * Each split is converted to an element count (rows * elements-per-row) and to a running
 * displacement before being passed to HCCL. PREMUL_SUM multiplies the input by the supplement's
 * double_factor before the call; AVG divides the output by getSize() after the call.
 *
 * @param outputTensor    destination tensor; must share inputTensor's dtype.
 * @param inputTensor     source tensor.
 * @param inputSplitSizes per-rank row counts; may be modified in place when passed in empty.
 * @param opts            supplies reduceOp.
 * @return the work handle returned by collective().
 */
c10::intrusive_ptr<c10d::Work> ProcessGroupHCCL::_reduce_scatter_base_uneven_inner(
    at::Tensor& outputTensor,
    at::Tensor& inputTensor,
    std::vector<int64_t>& inputSplitSizes,
    const c10d::ReduceScatterOptions& opts)
{
    check_npu_single_tensor(outputTensor);
    check_npu_single_tensor(inputTensor);
    TORCH_CHECK(inputTensor.dtype() == outputTensor.dtype(), "output tensor must have the same type as input tensor", DIST_ERROR(ErrCode::PARAM));
    std::vector<at::Tensor> inputTensors = {inputTensor};
    std::vector<at::Tensor> outputTensors = {outputTensor};
    check_npu_tensors_different_devices(inputTensors);
    check_npu_tensors_different_devices(outputTensors);
    fill_equal_split_sizes_when_empty(inputSplitSizes, inputTensor, size_);
    check_split_sizes(inputSplitSizes, inputTensor, size_);

    const size_t inputSize = inputSplitSizes.size();
    // Elements per row of dim 0. Kept as int64_t: narrowing to int would truncate rows of 2^31 or
    // more elements and corrupt the counts and displacements handed to HCCL.
    const int64_t inputRowSize = inputTensor.size(0) != 0 ? inputTensor.numel() / inputTensor.size(0) : 1;
    std::vector<uint64_t> inputCounts;
    std::vector<uint64_t> inputSpl;
    inputCounts.reserve(inputSize);
    inputSpl.reserve(inputSize);
    // inputSpl[i] is the element offset of split i: the prefix sum of the preceding counts.
    inputSpl.push_back(0);
    for (size_t i = 0; i < inputSize; i++) {
        inputCounts.push_back(static_cast<uint64_t>(inputSplitSizes[i] * inputRowSize));
        if (i > 0) {
            inputSpl.push_back(inputSpl[i - 1] + inputCounts[i - 1]);
        }
    }

    auto inputTensors_ = cast_to_origin_format(inputTensors);
    auto outputTensors_ = cast_to_origin_format(outputTensors);
    return collective(
        inputTensors_,
        outputTensors_,
        [&](at::Tensor& input, at::Tensor& output, HcclComm comm, c10_npu::NPUStream& stream, std::shared_ptr<bool> is_dispatched) {
            RECORD_FUNCTION("HcclReduceScatterV", std::vector<c10::IValue>({input}));
            auto inputDataPtr = input.data_ptr();
            auto outputDataPtr = output.data_ptr();
            uint64_t outputCount = output.numel();
            auto numel = getNumelForHCCL(output);
            auto hcclReduceOp = getHcclReduceOp(opts.reduceOp, input);
            auto hcclType = getHcclDataType(input.scalar_type());
            // The count and displacement vectors are copied into the task, so it owns its own arrays.
            auto hccl_call = [
                inputDataPtr,
                inputCounts,
                inputSpl,
                outputDataPtr,
                outputCount,
                hcclType,
                hcclReduceOp,
                numel,
                comm,
                stream,
                is_dispatched]() -> int {
#ifndef BUILD_LIBTORCH
                torch_npu::profiler::MstxRange range(
                    getMstxHcclMsg("HcclReduceScatterV", numel, hcclType, comm, stream.id(), -1, -1),
                    stream.stream(false), torch_npu::profiler::DOMAIN_COMMUNICATION);
#endif
                if (c10_npu::is_core_control_enabled()) {
                    c10_npu::UseStreamResInCurrentThread(stream.stream(false));
                }
                auto hccl_result = hcclReduceScatterV(
                    inputDataPtr,
                    inputCounts.data(),
                    inputSpl.data(),
                    outputDataPtr,
                    outputCount,
                    hcclType,
                    hcclReduceOp,
                    comm,
                    stream.stream(false));
                *is_dispatched = true;
                return hccl_result;
            };
            at_npu::native::OpCommand::RunOpApiV3("HcclReduceScatterV", hccl_call, false, &stream);
            return HCCL_SUCCESS;
        },
        [&](std::vector<c10_npu::NPUStream>& hcclStreams, c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>&) {
            // Pre-step: PREMUL_SUM scales the input in place under hcclStreams[0].
            if (opts.reduceOp == c10d::ReduceOp::PREMUL_SUM) {
                c10_npu::NPUStreamGuard guard(hcclStreams[0]);
                const auto* preMulSupplement =
                    reinterpret_cast<c10d::NCCLPreMulSumSupplement*>(
                        opts.reduceOp.supplement_.get());
                auto scale_factor = preMulSupplement->double_factor;
                for (auto& tensor : inputTensors_) {
                    tensor.mul_(scale_factor);
                }
            }
        },
        [&](std::vector<c10_npu::NPUStream>& hcclStreams, c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>&) {
            // Post-step: AVG divides the output by getSize() under hcclStreams[0].
            if (opts.reduceOp == c10d::ReduceOp::AVG) {
                c10_npu::NPUStreamGuard guard(hcclStreams[0]);
                for (auto& tensor : outputTensors_) {
                    tensor.div_(getSize());
                }
            }
        },
        c10d::OpType::REDUCE_SCATTER);
}
/**
 * Dispatcher entry for the uneven reduce-scatter: looks up
 * "npu_custom_dist::wrap_reduce_scatter_base_uneven_inner" once and calls it with this process
 * group, a copy of opts.reduceOp and opts.timeout.count().
 *
 * When c10d::allow_inflight_collective_as_graph_input() is true, the returned work is registered
 * against outputTensor via c10d::register_work.
 *
 * @return the work handle returned by the dispatched operator.
 */
c10::intrusive_ptr<c10d::Work> ProcessGroupHCCL::_reduce_scatter_base_uneven(
    at::Tensor& outputTensor,
    at::Tensor& inputTensor,
    std::vector<int64_t>& inputSplitSizes,
    const c10d::ReduceScatterOptions& opts)
{
    using WrappedSignature = c10::intrusive_ptr<::c10d::Work>(
        at::Tensor&,
        at::Tensor&,
        std::vector<int64_t>,
        c10::intrusive_ptr<c10d_npu::ProcessGroupHCCL>,
        c10::intrusive_ptr<c10d::ReduceOp>,
        int64_t);

    // Resolved on first call and reused afterwards.
    static auto wrappedOp =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("npu_custom_dist::wrap_reduce_scatter_base_uneven_inner", "")
            .typed<WrappedSignature>();

    auto pendingWork = wrappedOp.call(
        outputTensor,
        inputTensor,
        inputSplitSizes,
        c10::intrusive_ptr<c10d_npu::ProcessGroupHCCL>::unsafe_reclaim_from_nonowning(this),
        c10::make_intrusive<c10d::ReduceOp>(opts.reduceOp),
        opts.timeout.count());

    if (c10d::allow_inflight_collective_as_graph_input()) {
        c10d::register_work(outputTensor, pendingWork);
    }
    return pendingWork;
}
/**
 * All-gather with per-rank split sizes, issued through hcclAllGatherV.
 *
 * outputSplitSizes is expressed in rows of dimension 0 of outputTensor; it is filled with equal
 * splits by fill_equal_split_sizes_when_empty when empty and validated by check_split_sizes.
 * Each split is converted to an element count (rows * elements-per-row) and to a running
 * displacement before being passed to HCCL.
 *
 * @param outputTensor     destination tensor; must share inputTensor's dtype.
 * @param inputTensor      source tensor.
 * @param outputSplitSizes per-rank row counts; may be modified in place when passed in empty.
 * @param opts             not read by this function.
 * @return the work handle returned by collective().
 */
c10::intrusive_ptr<c10d::Work> ProcessGroupHCCL::_allgather_base_uneven_inner(
    at::Tensor& outputTensor,
    at::Tensor& inputTensor,
    std::vector<int64_t>& outputSplitSizes,
    const c10d::AllgatherOptions& opts)
{
    check_npu_single_tensor(outputTensor);
    check_npu_single_tensor(inputTensor);
    TORCH_CHECK(inputTensor.dtype() == outputTensor.dtype(), "output tensor must have the same type as input tensor", DIST_ERROR(ErrCode::PARAM));
    std::vector<at::Tensor> inputTensors = {inputTensor};
    std::vector<at::Tensor> outputTensors = {outputTensor};
    check_npu_tensors_different_devices(inputTensors);
    check_npu_tensors_different_devices(outputTensors);
    fill_equal_split_sizes_when_empty(outputSplitSizes, outputTensor, size_);
    check_split_sizes(outputSplitSizes, outputTensor, size_);

    const size_t outputSize = outputSplitSizes.size();
    // Elements per row of dim 0. Kept as int64_t: narrowing to int would truncate rows of 2^31 or
    // more elements and corrupt the counts and displacements handed to HCCL.
    const int64_t outputRowSize = outputTensor.size(0) != 0 ? outputTensor.numel() / outputTensor.size(0) : 1;
    std::vector<uint64_t> outputCounts;
    std::vector<uint64_t> outputSpl;
    outputCounts.reserve(outputSize);
    outputSpl.reserve(outputSize);
    // outputSpl[i] is the element offset of split i: the prefix sum of the preceding counts.
    outputSpl.push_back(0);
    for (size_t i = 0; i < outputSize; i++) {
        outputCounts.push_back(static_cast<uint64_t>(outputSplitSizes[i] * outputRowSize));
        if (i > 0) {
            outputSpl.push_back(outputSpl[i - 1] + outputCounts[i - 1]);
        }
    }

    auto inputTensors_ = cast_to_origin_format(inputTensors);
    auto outputTensors_ = cast_to_origin_format(outputTensors);
    return collective(
        inputTensors_,
        outputTensors_,
        [&](at::Tensor& input, at::Tensor& output, HcclComm comm, c10_npu::NPUStream& stream, std::shared_ptr<bool> is_dispatched) {
            RECORD_FUNCTION("HcclAllGatherV", std::vector<c10::IValue>({input}));
            auto inputDataPtr = input.data_ptr();
            uint64_t inputCount = input.numel();
            auto outputDataPtr = output.data_ptr();
            auto numel = getNumelForHCCL(input);
            auto hcclType = getHcclDataType(input.scalar_type());
            // The count and displacement vectors are copied into the task, so it owns its own arrays.
            auto hccl_call = [
                inputDataPtr,
                inputCount,
                outputDataPtr,
                outputCounts,
                outputSpl,
                hcclType,
                numel,
                comm,
                stream,
                is_dispatched]() -> int {
#ifndef BUILD_LIBTORCH
                torch_npu::profiler::MstxRange range(
                    getMstxHcclMsg("HcclAllGatherV", numel, hcclType, comm, stream.id(), -1, -1),
                    stream.stream(false), torch_npu::profiler::DOMAIN_COMMUNICATION);
#endif
                if (c10_npu::is_core_control_enabled()) {
                    c10_npu::UseStreamResInCurrentThread(stream.stream(false));
                }
                auto hccl_result = hcclAllGatherV(
                    inputDataPtr,
                    inputCount,
                    outputDataPtr,
                    outputCounts.data(),
                    outputSpl.data(),
                    hcclType,
                    comm,
                    stream.stream(false));
                *is_dispatched = true;
                return hccl_result;
            };
            at_npu::native::OpCommand::RunOpApiV3("HcclAllGatherV", hccl_call, false, &stream);
            return HCCL_SUCCESS;
        },
        // No pre- or post-processing is needed for this collective.
        [&](std::vector<c10_npu::NPUStream>&, c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>&) {},
        [&](std::vector<c10_npu::NPUStream>&, c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>&) {},
        c10d::OpType::ALLGATHER);
}
/**
 * Dispatcher entry for the uneven all-gather: looks up
 * "npu_custom_dist::wrap_allgather_base_uneven_inner" once and calls it with this process group
 * and opts.timeout.count().
 *
 * When c10d::allow_inflight_collective_as_graph_input() is true, the returned work is registered
 * against outputTensor via c10d::register_work.
 *
 * @return the work handle returned by the dispatched operator.
 */
c10::intrusive_ptr<c10d::Work> ProcessGroupHCCL::_allgather_base_uneven(
    at::Tensor& outputTensor,
    at::Tensor& inputTensor,
    std::vector<int64_t>& outputSplitSizes,
    const c10d::AllgatherOptions& opts)
{
    using WrappedSignature = c10::intrusive_ptr<::c10d::Work>(
        at::Tensor&,
        at::Tensor&,
        std::vector<int64_t>,
        c10::intrusive_ptr<c10d_npu::ProcessGroupHCCL>,
        int64_t);

    // Resolved on first call and reused afterwards.
    static auto wrappedOp =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("npu_custom_dist::wrap_allgather_base_uneven_inner", "")
            .typed<WrappedSignature>();

    auto pendingWork = wrappedOp.call(
        outputTensor,
        inputTensor,
        outputSplitSizes,
        c10::intrusive_ptr<c10d_npu::ProcessGroupHCCL>::unsafe_reclaim_from_nonowning(this),
        opts.timeout.count());

    if (c10d::allow_inflight_collective_as_graph_input()) {
        c10d::register_work(outputTensor, pendingWork);
    }
    return pendingWork;
}
/**
 * All-gather into lists of output tensors. One of three paths is taken:
 *
 *  1. check_same_size(outputTensors.back()) is true: input and outputs are reshaped and padded by
 *     byte_alignment(), gathered through a flattened buffer with hcclAllGather, and each slice is
 *     copied back into the caller's output tensor with the padding stripped.
 *  2. Otherwise, when hcclAllGatherVExist() is true and no output tensor is empty: the outputs are
 *     flattened and concatenated, gathered with hcclAllGatherV using per-output counts and
 *     displacements, and the slices are copied back into the caller's output tensors.
 *  3. Otherwise: one hcclBroadcast per output tensor is issued and waited on, and a freshly
 *     initialised work object is returned (see the warning emitted on that path).
 *
 * @param outputTensors per-device lists of output tensors; must not be empty.
 * @param inputTensors  input tensors, validated by check_npu_tensors_different_devices.
 * @param opts          timeout is forwarded to the broadcast options on path 3.
 * @return the work handle of the collective, or the placeholder work on path 3.
 */
c10::intrusive_ptr<c10d::Work> ProcessGroupHCCL::allgather(
    std::vector<std::vector<at::Tensor>>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const c10d::AllgatherOptions& opts)
{
    // outputTensors.back() and outputTensors[0] below are undefined on an empty list.
    TORCH_CHECK(!outputTensors.empty(), "output tensor list must not be empty", DIST_ERROR(ErrCode::PARAM));
    check_npu_tensors_same_device(outputTensors.back());
    check_npu_tensors_different_devices(inputTensors);
    if (C10_UNLIKELY(at_npu::native::env::CheckOpHookEnable())) {
        at_npu::native::OpHook::GetInstance().PreHook("allgather", outputTensors, inputTensors);
    }
    auto inputTensors_ = cast_to_origin_format(inputTensors);
    bool same_size = check_same_size(outputTensors.back());
    if (same_size) {
        // Unpadded element count of each output; used to strip the alignment padding afterwards.
        // A std::vector replaces the former variable-length array, which is not standard C++.
        std::vector<uint64_t> output_nums;
        output_nums.reserve(outputTensors[0].size());
        for (const auto& out : outputTensors[0]) {
            output_nums.push_back(static_cast<uint64_t>(out.numel()));
        }
        std::vector<at::Tensor> byte_alignment_inputTensors_ = {byte_alignment(inputTensors_[0])};
        std::vector<at::Tensor> byte_alignment_outputTensors_;
        byte_alignment_outputTensors_.reserve(outputTensors[0].size());
        for (auto& out : outputTensors[0]) {
            byte_alignment_outputTensors_.push_back(byte_alignment(out));
        }
        std::vector<std::vector<at::Tensor>> byte_alignment_outputTensors;
        byte_alignment_outputTensors.push_back(byte_alignment_outputTensors_);
        auto outputFlattened =
            flatten_for_scatter_gather(byte_alignment_outputTensors, byte_alignment_inputTensors_, size_);
        check_npu_tensors_different_devices(outputFlattened);
        return collective(
            byte_alignment_inputTensors_,
            outputFlattened,
            [&](at::Tensor& input, at::Tensor& output, HcclComm comm, c10_npu::NPUStream& stream, std::shared_ptr<bool> is_dispatched) {
                RECORD_FUNCTION("HcclAllgather", std::vector<c10::IValue>({input}));
                if (c10_npu::option::OptionsManager::GetMultiStreamMemoryReuse() != c10_npu::option::AVOID_RECORD_STREAM) {
                    c10_npu::NPUCachingAllocator::recordStream(output.storage().data_ptr(), stream);
                }
                auto inputDataPtr = input.data_ptr();
                auto outputDataPtr = output.data_ptr();
                auto numel = getNumelForHCCL(input);
                auto hcclType = getHcclDataType(input.scalar_type());
                auto hccl_call = [inputDataPtr, outputDataPtr, numel, hcclType, comm, stream, is_dispatched]() -> int {
#ifndef BUILD_LIBTORCH
                    torch_npu::profiler::MstxRange range(
                        getMstxHcclMsg("HcclAllGather", numel, hcclType, comm, stream.id(), -1, -1), stream.stream(false),
                        torch_npu::profiler::DOMAIN_COMMUNICATION);
#endif
                    if (c10_npu::is_core_control_enabled()) {
                        c10_npu::UseStreamResInCurrentThread(stream.stream(false));
                    }
                    auto hccl_result = hcclAllGather(inputDataPtr, outputDataPtr, numel, hcclType, comm, stream.stream(false));
                    *is_dispatched = true;
                    return hccl_result;
                };
                at_npu::native::OpCommand::RunOpApiV3("HcclAllgather", hccl_call, false, &stream);
                return HCCL_SUCCESS;
            },
            [&](std::vector<c10_npu::NPUStream>&, c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>&) {},
            [&](std::vector<c10_npu::NPUStream>& hcclStreams, c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>& work) {
                // Hand the temporaries to the work object, then copy results back to the caller's tensors.
                work->lazyDestroy(byte_alignment_inputTensors_);
                work->lazyDestroy(outputFlattened);
                const auto memReuseMode = c10_npu::option::OptionsManager::GetMultiStreamMemoryReuse();
                for (const auto i : c10::irange(outputTensors.size())) {
                    c10_npu::NPUStreamGuard guard(hcclStreams[i]);
                    for (const auto j : c10::irange(outputTensors[0].size())) {
                        if (memReuseMode == c10_npu::option::AVOID_RECORD_STREAM) {
                            work->stashed_for_allocator_safety_.push_back(outputTensors[i][j]);
                        } else {
                            c10_npu::NPUCachingAllocator::recordStream(
                                outputTensors[i][j].storage().data_ptr(), hcclStreams[i]);
                            if (memReuseMode == c10_npu::option::ERASE_RECORD_STREAM) {
                                work->recorded_outputs_.push_back(
                                    std::make_pair(outputTensors[i][j].storage().getWeakStorageImpl(), hcclStreams[i]));
                            }
                        }
                        // Drop the padding columns, restore the caller's shape, then copy.
                        at::Tensor output_tensor = outputFlattened[i][j].slice(1, 0, output_nums[j]);
                        at::Tensor output_tensor_shape = at::reshape(output_tensor, outputTensors[i][j].sizes());
                        outputTensors[i][j].copy_(output_tensor_shape, true);
                    }
                }
            },
            c10d::OpType::ALLGATHER);
    } else if (hcclAllGatherVExist() && !has_empty_tensor(outputTensors.back())) {
        std::vector<at::Tensor> lastOutputTensors = outputTensors.back();
        std::vector<uint64_t> outputCounts;
        std::vector<uint64_t> outputSpl;
        outputCounts.reserve(lastOutputTensors.size());
        outputSpl.reserve(lastOutputTensors.size());
        // outputSpl[i] is the element offset of output i inside the concatenated buffer.
        outputSpl.push_back(0);
        for (size_t i = 0; i < lastOutputTensors.size(); i++) {
            outputCounts.push_back(lastOutputTensors[i].numel());
            if (i > 0) {
                outputSpl.push_back(outputSpl[i - 1] + outputCounts[i - 1]);
            }
        }
        std::vector<at::Tensor> flattenedOutputTensors;
        flattenedOutputTensors.reserve(lastOutputTensors.size());
        for (size_t i = 0; i < lastOutputTensors.size(); i++) {
            flattenedOutputTensors.push_back(at::flatten(lastOutputTensors[i]));
        }
        std::vector<at::Tensor> inputFlattened = {at::flatten(inputTensors[0])};
        std::vector<at::Tensor> outputFlattened = {at::cat(flattenedOutputTensors, 0)};
        return collective(
            inputFlattened,
            outputFlattened,
            [&](at::Tensor& input, at::Tensor& output, HcclComm comm, c10_npu::NPUStream& stream, std::shared_ptr<bool> is_dispatched) {
                RECORD_FUNCTION("HcclAllGatherV", std::vector<c10::IValue>({input}));
                auto inputDataPtr = input.data_ptr();
                uint64_t inputCount = input.numel();
                auto outputDataPtr = output.data_ptr();
                auto numel = getNumelForHCCL(input);
                auto hcclType = getHcclDataType(input.scalar_type());
                // The count and displacement vectors are copied into the task, so it owns its own arrays.
                auto hccl_call = [
                    inputDataPtr,
                    inputCount,
                    outputDataPtr,
                    outputCounts,
                    outputSpl,
                    hcclType,
                    numel,
                    comm,
                    stream,
                    is_dispatched]() -> int {
#ifndef BUILD_LIBTORCH
                    torch_npu::profiler::MstxRange range(
                        getMstxHcclMsg("HcclAllGatherV", numel, hcclType, comm, stream.id(), -1, -1),
                        stream.stream(false), torch_npu::profiler::DOMAIN_COMMUNICATION);
#endif
                    if (c10_npu::is_core_control_enabled()) {
                        c10_npu::UseStreamResInCurrentThread(stream.stream(false));
                    }
                    auto hccl_result = hcclAllGatherV(
                        inputDataPtr,
                        inputCount,
                        outputDataPtr,
                        outputCounts.data(),
                        outputSpl.data(),
                        hcclType,
                        comm,
                        stream.stream(false));
                    *is_dispatched = true;
                    return hccl_result;
                };
                at_npu::native::OpCommand::RunOpApiV3("HcclAllGatherV", hccl_call, false, &stream);
                return HCCL_SUCCESS;
            },
            [&](std::vector<c10_npu::NPUStream>&, c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>&) {},
            [&](std::vector<c10_npu::NPUStream>& hcclStreams, c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>& work) {
                // Hand the temporaries to the work object, then copy results back to the caller's tensors.
                work->lazyDestroy(inputFlattened);
                work->lazyDestroy(outputFlattened);
                const auto memReuseMode = c10_npu::option::OptionsManager::GetMultiStreamMemoryReuse();
                for (const auto i : c10::irange(outputTensors.size())) {
                    c10_npu::NPUStreamGuard guard(hcclStreams[i]);
                    for (const auto j : c10::irange(outputTensors[0].size())) {
                        if (memReuseMode == c10_npu::option::AVOID_RECORD_STREAM) {
                            work->stashed_for_allocator_safety_.push_back(outputTensors[i][j]);
                        } else {
                            c10_npu::NPUCachingAllocator::recordStream(
                                outputTensors[i][j].storage().data_ptr(), hcclStreams[i]);
                            if (memReuseMode == c10_npu::option::ERASE_RECORD_STREAM) {
                                work->recorded_outputs_.push_back(
                                    std::make_pair(outputTensors[i][j].storage().getWeakStorageImpl(), hcclStreams[i]));
                            }
                        }
                        // Slice this output's range out of the concatenated buffer and restore its shape.
                        at::Tensor output_tensor = outputFlattened[i].slice(0, outputSpl[j], outputSpl[j] + outputCounts[j]);
                        at::Tensor output_tensor_reshape = at::reshape(output_tensor, outputTensors[i][j].sizes());
                        outputTensors[i][j].copy_(output_tensor_reshape, true);
                    }
                }
            },
            c10d::OpType::ALLGATHER);
    } else {
        TORCH_NPU_WARN_ONCE("The current allgather operator has a defect in handling different tensor shape, \
the work event forces a wait operation, and the allgather wait on the python side would be fake");
        const auto num_devices = outputTensors.size();
        const auto num_reduces = outputTensors[0].size();
        std::vector<c10::intrusive_ptr<c10d::Work>> works;
        works.reserve(num_reduces);
        for (const auto i : c10::irange(num_reduces)) {
            std::vector<at::Tensor> outputs_multi_dev(num_devices);
            for (const auto j : c10::irange(num_devices)) {
                outputs_multi_dev[j] = outputTensors[j][i];
                // The slot that belongs to this rank is seeded with the local input before broadcasting.
                if (i == (rank_ * num_devices + j)) {
                    outputs_multi_dev[j].copy_(inputTensors[j]);
                }
            }
            auto broadcastOpts = c10d::BroadcastOptions{
                static_cast<int64_t>(i / num_devices),
                static_cast<int64_t>(i % num_devices),
                opts.timeout};
            auto work = collective(
                outputs_multi_dev, outputs_multi_dev, [&](at::Tensor& input, at::Tensor& output, HcclComm comm, c10_npu::NPUStream& stream, std::shared_ptr<bool> is_dispatched) {
                    RECORD_FUNCTION("HcclBroadcast", std::vector<c10::IValue>({input}));
                    const auto root = broadcastOpts.rootRank * num_devices + broadcastOpts.rootTensor;
                    auto inputDataPtr = input.data_ptr();
                    auto numel = getNumelForHCCL(input);
                    auto hcclType = getHcclDataType(input.scalar_type());
#ifndef BUILD_LIBTORCH
                    torch_npu::profiler::MstxRange range(
                        getMstxHcclMsg("HcclBroadcast", numel, hcclType, comm, stream.id(), -1, -1), stream.stream(false),
                        torch_npu::profiler::DOMAIN_COMMUNICATION);
#endif
                    auto hccl_result = hcclBroadcast(inputDataPtr, numel, hcclType, root, comm, stream.stream());
                    *is_dispatched = true;
                    return hccl_result;
                },
                c10d::OpType::BROADCAST);
            works.push_back(work);
        }
        // Every broadcast is waited on here, so the work returned below is only a placeholder.
        for (auto& work : works) {
            work->wait();
        }
        auto fake_work = initWork(getDeviceList(inputTensors), rank_, c10d::OpType::ALLGATHER);
        assignTimeoutToWork(fake_work, options_);
        return fake_work;
    }
}
/**
 * Coalesced all-gather: passes each (input, output) pair to collectiveCoalesced(), which is given
 * a launcher that enqueues hcclAllGather for the pair.
 *
 * @param outputs destination tensors.
 * @param inputs  source tensors; converted with cast_to_origin_format before use.
 * @param opts    not read by this function.
 * @return the work handle returned by collectiveCoalesced().
 */
c10::intrusive_ptr<c10d::Work> ProcessGroupHCCL::allgather_into_tensor_coalesced(
    std::vector<at::Tensor>& outputs,
    std::vector<at::Tensor>& inputs,
    const c10d::AllgatherOptions& opts)
{
    auto originFormatInputs = cast_to_origin_format(inputs);

    // Builds the hcclAllGather task for one (input, output) pair and hands it to RunOpApiV3.
    auto launchGather = [&](at::Tensor& input, at::Tensor& output, HcclComm comm, c10_npu::NPUStream& stream,
                            std::shared_ptr<bool> is_dispatched) {
        RECORD_FUNCTION("HcclAllgatherBase", std::vector<c10::IValue>({input}));
        if (c10_npu::option::OptionsManager::GetMultiStreamMemoryReuse() != c10_npu::option::AVOID_RECORD_STREAM) {
            c10_npu::NPUCachingAllocator::recordStream(output.storage().data_ptr(), stream);
        }
        auto sendBuf = input.data_ptr();
        auto recvBuf = output.data_ptr();
        auto count = getNumelForHCCL(input);
        auto dataType = getHcclDataType(input.scalar_type());
        auto gatherTask = [sendBuf, recvBuf, count, dataType, comm, stream, is_dispatched]() -> int {
#ifndef BUILD_LIBTORCH
            torch_npu::profiler::MstxRange range(
                getMstxHcclMsg("HcclAllGather", count, dataType, comm, stream.id(), -1, -1), stream.stream(false),
                torch_npu::profiler::DOMAIN_COMMUNICATION);
#endif
            if (c10_npu::is_core_control_enabled()) {
                c10_npu::UseStreamResInCurrentThread(stream.stream(false));
            }
            auto ret = hcclAllGather(sendBuf, recvBuf, count, dataType, comm, stream.stream(false));
            *is_dispatched = true;
            return ret;
        };
        at_npu::native::OpCommand::RunOpApiV3("HcclAllGather", gatherTask, false, &stream);
        return HCCL_SUCCESS;
    };

    // Neither a pre-step nor a post-step is required.
    auto noPreStep = [&](std::vector<c10_npu::NPUStream>&, c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>&) {};
    auto noPostStep = [&](std::vector<c10_npu::NPUStream>&, c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>&) {};

    return collectiveCoalesced(
        originFormatInputs, outputs, launchGather, noPreStep, noPostStep, c10d::OpType::ALLGATHER);
}
/**
 * All-gather that writes straight into the supplied output tensors via hcclAllGather, with no
 * flattening or copy-back step.
 *
 * @param outputTensors destination tensors, validated by check_npu_tensors_different_devices.
 * @param inputTensors  source tensors, validated likewise and converted with cast_to_origin_format.
 * @param opts          not read by this function.
 * @return the work handle returned by collective().
 */
c10::intrusive_ptr<c10d::Work> ProcessGroupHCCL::allgather_togather(
    std::vector<at::Tensor>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const c10d::AllgatherOptions& opts)
{
    check_npu_tensors_different_devices(inputTensors);
    check_npu_tensors_different_devices(outputTensors);
    if (C10_UNLIKELY(at_npu::native::env::CheckOpHookEnable())) {
        at_npu::native::OpHook::GetInstance().PreHook("allgather_togather", outputTensors, inputTensors);
    }
    auto originFormatInputs = cast_to_origin_format(inputTensors);

    // Builds the hcclAllGather task for one (input, output) pair and hands it to RunOpApiV3.
    auto launchGather = [&](at::Tensor& input, at::Tensor& output, HcclComm comm, c10_npu::NPUStream& stream,
                            std::shared_ptr<bool> is_dispatched) {
        RECORD_FUNCTION("HcclAllgatherTogather", std::vector<c10::IValue>({input}));
        if (c10_npu::option::OptionsManager::GetMultiStreamMemoryReuse() != c10_npu::option::AVOID_RECORD_STREAM) {
            c10_npu::NPUCachingAllocator::recordStream(output.storage().data_ptr(), stream);
        }
        auto sendBuf = input.data_ptr();
        auto recvBuf = output.data_ptr();
        auto count = getNumelForHCCL(input);
        auto dataType = getHcclDataType(input.scalar_type());
        auto gatherTask = [sendBuf, recvBuf, count, dataType, comm, stream, is_dispatched]() -> int {
#ifndef BUILD_LIBTORCH
            torch_npu::profiler::MstxRange range(
                getMstxHcclMsg("HcclAllGather", count, dataType, comm, stream.id(), -1, -1), stream.stream(false),
                torch_npu::profiler::DOMAIN_COMMUNICATION);
#endif
            if (c10_npu::is_core_control_enabled()) {
                c10_npu::UseStreamResInCurrentThread(stream.stream(false));
            }
            auto ret = hcclAllGather(sendBuf, recvBuf, count, dataType, comm, stream.stream(false));
            *is_dispatched = true;
            return ret;
        };
        at_npu::native::OpCommand::RunOpApiV3("HcclAllGather", gatherTask, false, &stream);
        return HCCL_SUCCESS;
    };

    // Neither a pre-step nor a post-step is required.
    auto noPreStep = [&](std::vector<c10_npu::NPUStream>&, c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>&) {};
    auto noPostStep = [&](std::vector<c10_npu::NPUStream>&, c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>&) {};

    return collective(
        originFormatInputs, outputTensors, launchGather, noPreStep, noPostStep, c10d::OpType::ALLGATHER);
}
/**
 * Single-tensor all-gather: gathers inputTensor into outputTensor with hcclAllGather.
 *
 * @param outputTensor destination; numel must equal size_ * inputTensor.numel() and the dtype
 *                     must match inputTensor.
 * @param inputTensor  source tensor.
 * @param opts         not read by this function.
 * @return the work handle returned by collective().
 */
c10::intrusive_ptr<c10d::Work> ProcessGroupHCCL::_allgather_base(
    at::Tensor& outputTensor,
    at::Tensor& inputTensor,
    const c10d::AllgatherOptions& opts)
{
    TORCH_CHECK(
        inputTensor.dtype() == outputTensor.dtype(),
        "output tensor must have the same type as input tensor", DIST_ERROR(ErrCode::PARAM));
    TORCH_CHECK(
        inputTensor.numel() * size_ == outputTensor.numel(),
        "output tensor size must be equal to world_size times input tensor size", DIST_ERROR(ErrCode::PARAM));

    std::vector<at::Tensor> inputTensors = {inputTensor};
    std::vector<at::Tensor> outputTensors = {outputTensor};
    check_npu_tensors_different_devices(inputTensors);
    check_npu_tensors_different_devices(outputTensors);
    if (C10_UNLIKELY(at_npu::native::env::CheckOpHookEnable())) {
        at_npu::native::OpHook::GetInstance().PreHook("_allgather_base", outputTensors, inputTensors);
    }
    auto originFormatInputs = cast_to_origin_format(inputTensors);

    // Builds the hcclAllGather task for one (input, output) pair and hands it to RunOpApiV3.
    auto launchGather = [&](at::Tensor& input, at::Tensor& output, HcclComm comm, c10_npu::NPUStream& stream,
                            std::shared_ptr<bool> is_dispatched) {
        RECORD_FUNCTION("HcclAllgatherBase", std::vector<c10::IValue>({input}));
        if (c10_npu::option::OptionsManager::GetMultiStreamMemoryReuse() != c10_npu::option::AVOID_RECORD_STREAM) {
            c10_npu::NPUCachingAllocator::recordStream(output.storage().data_ptr(), stream);
        }
        auto sendBuf = input.data_ptr();
        auto recvBuf = output.data_ptr();
        auto count = getNumelForHCCL(input);
        auto dataType = getHcclDataType(input.scalar_type());
        auto gatherTask = [sendBuf, recvBuf, count, dataType, comm, stream, is_dispatched]() -> int {
#ifndef BUILD_LIBTORCH
            torch_npu::profiler::MstxRange range(
                getMstxHcclMsg("HcclAllGather", count, dataType, comm, stream.id(), -1, -1), stream.stream(false),
                torch_npu::profiler::DOMAIN_COMMUNICATION);
#endif
            if (c10_npu::is_core_control_enabled()) {
                c10_npu::UseStreamResInCurrentThread(stream.stream(false));
            }
            auto ret = hcclAllGather(sendBuf, recvBuf, count, dataType, comm, stream.stream(false));
            *is_dispatched = true;
            return ret;
        };
        at_npu::native::OpCommand::RunOpApiV3("HcclAllGather", gatherTask, false, &stream);
        return HCCL_SUCCESS;
    };

    // Neither a pre-step nor a post-step is required.
    auto noPreStep = [&](std::vector<c10_npu::NPUStream>&, c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>&) {};
    auto noPostStep = [&](std::vector<c10_npu::NPUStream>&, c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>&) {};

    return collective(
        originFormatInputs, outputTensors, launchGather, noPreStep, noPostStep, c10d::OpType::ALLGATHER);
}
c10::intrusive_ptr<c10d::Work> ProcessGroupHCCL::reduce_scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const c10d::ReduceScatterOptions& opts)
{
check_npu_tensors_different_devices(outputTensors);
check_npu_tensors_same_device(inputTensors.back());
if (C10_UNLIKELY(at_npu::native::env::CheckOpHookEnable())) {
at_npu::native::OpHook::GetInstance().PreHook("reduce_scatter", outputTensors, inputTensors);
}
bool same_size = check_same_size(inputTensors.back());
if (same_size) {
auto inputFlattened = flatten_for_scatter_gather(inputTensors, outputTensors, size_);
check_npu_tensors_different_devices(inputFlattened);
std::string functionName = __FUNCTION__;
return collective(
inputFlattened,
outputTensors,
[&](at::Tensor& input, at::Tensor& output, HcclComm comm, c10_npu::NPUStream& stream, std::shared_ptr<bool> is_dispatched) {
auto hcclType = getHcclDataType(input.scalar_type());
checkSupportedDataType(hcclType, functionName);
RECORD_FUNCTION("HcclReduceScatter", std::vector<c10::IValue>({input}));
if (c10_npu::option::OptionsManager::GetMultiStreamMemoryReuse() != c10_npu::option::AVOID_RECORD_STREAM) {
c10_npu::NPUCachingAllocator::recordStream(output.storage().data_ptr(), stream);
}
auto inputDataPtr = input.data_ptr();
auto outputDataPtr = output.data_ptr();
auto numel = getNumelForHCCL(output);
auto hcclReduceOp = getHcclReduceOp(opts.reduceOp, input);
auto hccl_call = [inputDataPtr, outputDataPtr, numel, hcclType, hcclReduceOp, comm, stream, is_dispatched]() -> int {
#ifndef BUILD_LIBTORCH
torch_npu::profiler::MstxRange range(
getMstxHcclMsg("HcclReduceScatter", numel, hcclType, comm, stream.id(), -1, -1), stream.stream(false),
torch_npu::profiler::DOMAIN_COMMUNICATION);
#endif
if (c10_npu::is_core_control_enabled()) {
c10_npu::UseStreamResInCurrentThread(stream.stream(false));
}
auto hccl_result = hcclReduceScatter(
inputDataPtr, outputDataPtr, numel, hcclType, hcclReduceOp, comm, stream.stream(false));
*is_dispatched = true;
return hccl_result;
};
at_npu::native::OpCommand::RunOpApiV3("HcclReduceScatter", hccl_call, false, &stream);
return HCCL_SUCCESS;
},
[&](std::vector<c10_npu::NPUStream>& hcclStreams, c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>& work) {
work->lazyDestroy(inputFlattened);
auto multi_stream_memory_reuse_mode = c10_npu::option::OptionsManager::GetMultiStreamMemoryReuse();
for (const auto i : c10::irange(inputTensors.size())) {
c10_npu::NPUStreamGuard guard(hcclStreams[i]);
for (const auto j : c10::irange(inputTensors[0].size())) {
if (multi_stream_memory_reuse_mode == c10_npu::option::AVOID_RECORD_STREAM) {
work->stashed_for_allocator_safety_.push_back(inputTensors[i][j]);
} else {
c10_npu::NPUCachingAllocator::recordStream(inputTensors[i][j].storage().data_ptr(), hcclStreams[i]);
if (multi_stream_memory_reuse_mode == c10_npu::option::ERASE_RECORD_STREAM ||
multi_stream_memory_reuse_mode == c10_npu::option::ERASE_RECORD_STREAM_WITH_OPTIMIZE) {
work->recorded_inputs_.push_back(
std::make_pair(inputTensors[i][j].storage().getWeakStorageImpl(), hcclStreams[i]));
if (multi_stream_memory_reuse_mode == c10_npu::option::ERASE_RECORD_STREAM_WITH_OPTIMIZE) {
auto block_ptr = c10_npu::NPUCachingAllocator::getBlockPtr(inputTensors[i][j].storage().data_ptr());
work->recorded_block_ptr_for_inputs_.push_back(block_ptr);
c10_npu::NPUCachingAllocator::recordHcclWorkForBlock(block_ptr, static_cast<void*>(work.get()));
}
}
}
inputFlattened[i][j].copy_(inputTensors[i][j], true);
if (opts.reduceOp == c10d::ReduceOp::PREMUL_SUM) {
const auto* preMulSupplement =
reinterpret_cast<c10d::NCCLPreMulSumSupplement*>(
opts.reduceOp.supplement_.get());
auto scale_factor = preMulSupplement->double_factor;
inputFlattened[i][j].mul_(scale_factor);
}
}
}
},
[&](std::vector<c10_npu::NPUStream>& hcclStreams, c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>&) {
if (opts.reduceOp == c10d::ReduceOp::AVG) {
c10_npu::NPUStreamGuard guard(hcclStreams[0]);
for (auto& tensor : outputTensors) {
tensor.div_(getSize());
}
}
},
c10d::OpType::REDUCE_SCATTER);
} else if (hcclReduceScatterVExist()) {
std::vector<uint64_t> inputCounts;
std::vector<uint64_t> inputSpl;
std::vector<at::Tensor> lastInputTensors = inputTensors.back();
inputSpl.push_back(0);
for (size_t i = 0; i < lastInputTensors.size(); i++) {
inputCounts.push_back(lastInputTensors[i].numel());
if (i > 0) {
inputSpl.push_back(inputSpl[i - 1] + inputCounts[i - 1]);
}
}
std::vector<at::Tensor> flattenedInputTensors;
for (size_t i = 0; i < lastInputTensors.size(); i++) {
flattenedInputTensors.push_back(at::flatten(lastInputTensors[i]));
}
std::vector<at::Tensor> inputFlattened = {at::cat(flattenedInputTensors, 0)};
std::vector<at::Tensor> outputFlattened = {at::flatten(outputTensors[0])};
return collective(
inputFlattened,
outputFlattened,
[&](at::Tensor& input, at::Tensor& output, HcclComm comm, c10_npu::NPUStream& stream, std::shared_ptr<bool> is_dispatched) {
RECORD_FUNCTION("HcclReduceScatterV", std::vector<c10::IValue>({input}));
auto inputDataPtr = input.data_ptr();
auto outputDataPtr = output.data_ptr();
uint64_t outputCount = output.numel();
auto numel = getNumelForHCCL(output);
auto hcclReduceOp = getHcclReduceOp(opts.reduceOp, input);
auto hcclType = getHcclDataType(input.scalar_type());
auto hccl_call = [
inputDataPtr,
inputCounts,
inputSpl,
outputDataPtr,
outputCount,
hcclType,
hcclReduceOp,
numel,
comm,
stream,
is_dispatched]() -> int {
#ifndef BUILD_LIBTORCH
torch_npu::profiler::MstxRange range(
getMstxHcclMsg("HcclReduceScatterV", numel, hcclType, comm, stream.id(), -1, -1),
stream.stream(false), torch_npu::profiler::DOMAIN_COMMUNICATION);
#endif
if (c10_npu::is_core_control_enabled()) {
c10_npu::UseStreamResInCurrentThread(stream.stream(false));
}
auto hccl_result = hcclReduceScatterV(
inputDataPtr,
inputCounts.data(),
inputSpl.data(),
outputDataPtr,
outputCount,
hcclType,
hcclReduceOp,
comm,
stream.stream(false));
*is_dispatched = true;
return hccl_result;
};
at_npu::native::OpCommand::RunOpApiV3("HcclReduceScatterV", hccl_call, false, &stream);
return HCCL_SUCCESS;
},
[&](std::vector<c10_npu::NPUStream>& hcclStreams, c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>&) {
if (opts.reduceOp == c10d::ReduceOp::PREMUL_SUM) {
c10_npu::NPUStreamGuard guard(hcclStreams[0]);
const auto* preMulSupplement =
reinterpret_cast<c10d::NCCLPreMulSumSupplement*>(
opts.reduceOp.supplement_.get());
auto scale_factor = preMulSupplement->double_factor;
for (auto& tensor : inputFlattened) {
tensor.mul_(scale_factor);
}
}
},
[&](std::vector<c10_npu::NPUStream>& hcclStreams, c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>& work) {
work->lazyDestroy(inputFlattened);
work->lazyDestroy(outputFlattened);
for (const auto i : c10::irange(outputTensors.size())) {
c10_npu::NPUStreamGuard guard(hcclStreams[i]);
if (c10_npu::option::OptionsManager::GetMultiStreamMemoryReuse() == c10_npu::option::AVOID_RECORD_STREAM) {
work->stashed_for_allocator_safety_.push_back(outputTensors[i]);
} else {
c10_npu::NPUCachingAllocator::recordStream(
outputTensors[i].storage().data_ptr(), hcclStreams[i]);
if (c10_npu::option::OptionsManager::GetMultiStreamMemoryReuse() == c10_npu::option::ERASE_RECORD_STREAM) {
work->recorded_outputs_.push_back(
std::make_pair(outputTensors[i].storage().getWeakStorageImpl(), hcclStreams[i]));
}
}
at::Tensor output_tensor_reshape = at::reshape(outputFlattened[i], outputTensors[i].sizes());
outputTensors[i].copy_(output_tensor_reshape, true);
}
if (opts.reduceOp == c10d::ReduceOp::AVG) {
c10_npu::NPUStreamGuard guard(hcclStreams[0]);
for (auto& tensor : outputTensors) {
tensor.div_(getSize());
}
}
},
c10d::OpType::REDUCE_SCATTER);
} else {
TORCH_NPU_WARN_ONCE("The current reduce_scatter operator has a defect in handling different tensor shape,",
"the work event forces a wait operation in c++ side, and the reduce_scatter wait on the python side would be fake");
auto outputTensor = outputTensors.back();
auto inputTensors_ = inputTensors.back();
const auto num_reduces = inputTensors_.size();
std::vector<c10::intrusive_ptr<c10d::Work>> works;
for (const auto i : c10::irange(num_reduces)) {
auto& input = inputTensors_[i];
auto& output = (i == rank_) ? outputTensor : input;
auto reduceOpts = c10d::ReduceOptions{
opts.reduceOp,
static_cast<int64_t>(i),
static_cast<int64_t>(0),
opts.timeout};
auto work = _reduce_oop(output, input, reduceOpts);
works.push_back(work);
}
for (auto& work : works) {
work->wait();
}
auto fake_work = initWork(getDeviceList(outputTensors), rank_, c10d::OpType::REDUCE_SCATTER);
assignTimeoutToWork(fake_work, options_);
return fake_work;
}
}
// Reduce-scatter from one input tensor into one output tensor.
//
// Preconditions enforced here:
//   * input and output share the same dtype (ErrCode::TYPE otherwise);
//   * input.numel() == output.numel() * size_ (ErrCode::PARAM otherwise).
//
// PREMUL_SUM scales the input in place by the supplement's double factor before
// the collective is issued; AVG divides the output by getSize() afterwards.
// Returns the work handle produced by collective().
c10::intrusive_ptr<c10d::Work> ProcessGroupHCCL::_reduce_scatter_base(
    at::Tensor& outputTensor,
    at::Tensor& inputTensor,
    const c10d::ReduceScatterOptions& opts)
{
    check_npu_single_tensor(inputTensor);
    TORCH_CHECK(
        inputTensor.dtype() == outputTensor.dtype(),
        "input tensor must be the same type as the output tensor.",
        DIST_ERROR(ErrCode::TYPE));
    TORCH_CHECK(
        inputTensor.numel() == outputTensor.numel() * size_,
        "input tensor must be the same size as output size times world size",
        DIST_ERROR(ErrCode::PARAM));

    std::vector<at::Tensor> inputList{inputTensor};
    std::vector<at::Tensor> outputList{outputTensor};

    if (C10_UNLIKELY(at_npu::native::env::CheckOpHookEnable())) {
        at_npu::native::OpHook::GetInstance().PreHook("_reduce_scatter_base", outputList, inputList);
    }

    std::string callerName = __FUNCTION__;

    // Issues the HcclReduceScatter call for one (input, output) pair on `stream`.
    auto issueReduceScatter = [&](at::Tensor& input,
                                  at::Tensor& output,
                                  HcclComm comm,
                                  c10_npu::NPUStream& stream,
                                  std::shared_ptr<bool> is_dispatched) {
        const bool recordOutput =
            c10_npu::option::OptionsManager::GetMultiStreamMemoryReuse() != c10_npu::option::AVOID_RECORD_STREAM;
        if (recordOutput) {
            c10_npu::NPUCachingAllocator::recordStream(output.storage().data_ptr(), stream);
        }

        auto dataType = getHcclDataType(input.scalar_type());
        checkSupportedDataType(dataType, callerName);
        RECORD_FUNCTION("HcclReduceScatterBase", std::vector<c10::IValue>({input}));

        auto sendBuf = input.data_ptr();
        auto recvBuf = output.data_ptr();
        auto count = getNumelForHCCL(output);
        auto reduceKind = getHcclReduceOp(opts.reduceOp, input);

        // Everything is captured by value: the task is handed to RunOpApiV3 and
        // must not refer to locals of this lambda.
        auto task = [sendBuf, recvBuf, count, dataType, reduceKind, comm, stream, is_dispatched]() -> int {
#ifndef BUILD_LIBTORCH
            torch_npu::profiler::MstxRange range(
                getMstxHcclMsg("HcclReduceScatter", count, dataType, comm, stream.id(), -1, -1),
                stream.stream(false),
                torch_npu::profiler::DOMAIN_COMMUNICATION);
#endif
            if (c10_npu::is_core_control_enabled()) {
                c10_npu::UseStreamResInCurrentThread(stream.stream(false));
            }
            auto status = hcclReduceScatter(
                sendBuf, recvBuf, count, dataType, reduceKind, comm, stream.stream(false));
            *is_dispatched = true;
            return status;
        };
        at_npu::native::OpCommand::RunOpApiV3("HcclReduceScatter", task, false, &stream);
        return HCCL_SUCCESS;
    };

    // Pre-processing: apply the PREMUL_SUM scale factor to the input in place.
    auto scaleInputs = [&](std::vector<c10_npu::NPUStream>& hcclStreams,
                           c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>&) {
        if (opts.reduceOp == c10d::ReduceOp::PREMUL_SUM) {
            c10_npu::NPUStreamGuard guard(hcclStreams[0]);
            const auto* supplement =
                reinterpret_cast<c10d::NCCLPreMulSumSupplement*>(opts.reduceOp.supplement_.get());
            auto factor = supplement->double_factor;
            for (auto& in : inputList) {
                in.mul_(factor);
            }
        }
    };

    // Post-processing: AVG divides the output by the group size.
    auto averageOutputs = [&](std::vector<c10_npu::NPUStream>& hcclStreams,
                              c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>&) {
        if (opts.reduceOp == c10d::ReduceOp::AVG) {
            c10_npu::NPUStreamGuard guard(hcclStreams[0]);
            for (auto& out : outputList) {
                out.div_(getSize());
            }
        }
    };

    return collective(
        inputList,
        outputList,
        issueReduceScatter,
        scaleInputs,
        averageOutputs,
        c10d::OpType::REDUCE_SCATTER);
}
// Coalesced reduce-scatter: each input tensor is reduce-scattered into the
// output tensor paired with it by collectiveCoalesced().
//
// PREMUL_SUM multiplies every input tensor in place by the supplement's double
// factor before the collectives are issued; AVG divides every output tensor by
// getSize() afterwards.
c10::intrusive_ptr<c10d::Work> ProcessGroupHCCL::reduce_scatter_tensor_coalesced(
    std::vector<at::Tensor>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const c10d::ReduceScatterOptions& opts)
{
    check_npu_tensors_same_device(outputTensors);
    check_npu_tensors_same_device(inputTensors);

    std::string callerName = __FUNCTION__;

    // Issues one HcclReduceScatter for a single (input, output) pair.
    auto issueReduceScatter = [&](at::Tensor& input,
                                  at::Tensor& output,
                                  HcclComm comm,
                                  c10_npu::NPUStream& stream,
                                  std::shared_ptr<bool> is_dispatched) {
        const bool recordOutput =
            c10_npu::option::OptionsManager::GetMultiStreamMemoryReuse() != c10_npu::option::AVOID_RECORD_STREAM;
        if (recordOutput) {
            c10_npu::NPUCachingAllocator::recordStream(output.storage().data_ptr(), stream);
        }

        auto dataType = getHcclDataType(input.scalar_type());
        checkSupportedDataType(dataType, callerName);
        RECORD_FUNCTION("HcclReduceScatterBase", std::vector<c10::IValue>({input}));

        auto sendBuf = input.data_ptr();
        auto recvBuf = output.data_ptr();
        auto count = getNumelForHCCL(output);
        auto reduceKind = getHcclReduceOp(opts.reduceOp, input);

        // Captured by value so the task stays valid after this lambda returns.
        auto task = [sendBuf, recvBuf, count, dataType, reduceKind, comm, stream, is_dispatched]() -> int {
#ifndef BUILD_LIBTORCH
            torch_npu::profiler::MstxRange range(
                getMstxHcclMsg("HcclReduceScatter", count, dataType, comm, stream.id(), -1, -1),
                stream.stream(false),
                torch_npu::profiler::DOMAIN_COMMUNICATION);
#endif
            if (c10_npu::is_core_control_enabled()) {
                c10_npu::UseStreamResInCurrentThread(stream.stream(false));
            }
            auto status = hcclReduceScatter(
                sendBuf, recvBuf, count, dataType, reduceKind, comm, stream.stream(false));
            *is_dispatched = true;
            return status;
        };
        at_npu::native::OpCommand::RunOpApiV3("HcclReduceScatter", task, false, &stream);
        return HCCL_SUCCESS;
    };

    // Pre-processing: PREMUL_SUM scales the inputs in place.
    auto scaleInputs = [&](std::vector<c10_npu::NPUStream>& hcclStreams,
                           c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>&) {
        if (opts.reduceOp == c10d::ReduceOp::PREMUL_SUM) {
            c10_npu::NPUStreamGuard guard(hcclStreams[0]);
            const auto* supplement =
                reinterpret_cast<c10d::NCCLPreMulSumSupplement*>(opts.reduceOp.supplement_.get());
            auto factor = supplement->double_factor;
            for (auto& in : inputTensors) {
                in.mul_(factor);
            }
        }
    };

    // Post-processing: AVG divides the outputs by the group size.
    auto averageOutputs = [&](std::vector<c10_npu::NPUStream>& hcclStreams,
                              c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>&) {
        if (opts.reduceOp == c10d::ReduceOp::AVG) {
            c10_npu::NPUStreamGuard guard(hcclStreams[0]);
            for (auto& out : outputTensors) {
                out.div_(getSize());
            }
        }
    };

    return collectiveCoalesced(
        inputTensors,
        outputTensors,
        issueReduceScatter,
        scaleInputs,
        averageOutputs,
        c10d::OpType::REDUCE_SCATTER);
}
// Barrier implemented as an allreduce over one single-element float tensor per
// participating device. The tensors are moved into the returned work object
// (barrierTensors_) so they stay alive with it.
c10::intrusive_ptr<c10d::Work> ProcessGroupHCCL::barrier(const c10d::BarrierOptions& opts)
{
    std::vector<at::Device> barrierDevices;
    if (!usedDeviceIdxs_.empty()) {
        // Use every device index already recorded in usedDeviceIdxs_.
        for (auto idx : usedDeviceIdxs_) {
            barrierDevices.push_back(at::Device(c10::DeviceType::PrivateUse1, idx));
        }
    } else {
        // No device index recorded yet: fall back to a single device.
        // NOTE(review): guessedIdx is computed from rank_ but never used; the
        // device pushed below carries no explicit index. Confirm this is intended.
        auto npuCount = c10_npu::device_count();
        int16_t guessedIdx = static_cast<int16_t>(rank_ % std::max(static_cast<int>(npuCount), 1));
        barrierDevices.push_back(at::Device(c10::DeviceType::PrivateUse1));
    }

    std::vector<at::Tensor> syncTensors;
    syncTensors.reserve(barrierDevices.size());

    c10_npu::OptionalNPUGuard deviceGuard;
    for (auto& dev : barrierDevices) {
        deviceGuard.set_index(dev.index());
        syncTensors.push_back(
            at::ones({1}, at::TensorOptions().device(c10::DeviceType::PrivateUse1).dtype(at::kFloat)));
    }

    auto work = allreduce(syncTensors);

    // allreduce must have produced a WorkHCCL; otherwise report a parameter error.
    auto hcclWork = dynamic_cast<ProcessGroupHCCL::WorkHCCL*>(work.get());
    TORCH_CHECK(hcclWork, DIST_ERROR(ErrCode::PARAM));
    hcclWork->barrierTensors_ = std::move(syncTensors);
    return work;
}
// Begins a coalescing section: clears the state left by any previous section,
// sets the CoalActive bit and opens an HCCL group.
void ProcessGroupHCCL::startCoalescing()
{
    // Reset the bookkeeping that endCoalescing() later inspects.
    coalescedTensors_.clear();
    coalescedComm_ = nullptr;
    coalescedDevice_.set_index(-1);

    coalescing_state_ |= CoalActive;
    groupStart();
}
// Ends a coalescing section and returns a work object covering the whole group.
//
// If no communicator was recorded during the section, the HCCL group is closed,
// the coalescing state is reset and nullptr is returned. Otherwise a work
// object is created for the coalesced device, the group is closed, the end
// event is recorded on the first HCCL stream for that device, and the work is
// enqueued when coalescing was active and no stream capture is in progress.
c10::intrusive_ptr<c10d::Work> ProcessGroupHCCL::endCoalescing(c10d::OpType optype)
{
    const bool nothingCoalesced = (coalescedComm_ == nullptr);
    if (nothingCoalesced) {
        groupEnd();
        coalescing_state_ = 0;
        return nullptr;
    }

    TORCH_CHECK(
        coalescedDevice_.index() >= 0,
        "Something went wrong. Did you call end_coalescing before start_coalescing?");

    auto groupComm = coalescedComm_;
    auto groupDevice = coalescedDevice_;
    std::vector<at::Device> deviceList = {groupDevice};

    const auto streamKey = getKeyFromDevice(deviceList);
    auto& streams = hcclStreams_[streamKey];
    c10_npu::NPUStream& firstStream = streams[0];

    // Profiler title: "hccl:<op>_coalesced" for a specific op type, otherwise
    // the generic "hccl:coalesced".
    std::string profilerTitle = "hccl:coalesced";
    if (optype != c10d::OpType::COALESCED) {
        profilerTitle = "hccl:" + opTypeToString(optype) + "_coalesced";
    }

    c10_npu::CaptureStatus captureStatus = c10_npu::currentStreamCaptureStatusMayInitCtx();
    bool shouldEnqueue = (coalescing_state_) && captureStatus == c10_npu::CaptureStatus::None;

    auto work = initWork(
        std::vector<c10::Device>{groupDevice},
        rank_,
        optype,
        profilerTitle.c_str(),
        {},
        {},
        shouldEnqueue);
    work->hcclComms_[0] = groupComm;
    work->blockingWait_ = blockingWait_;
    work->opTimeout_ = options_->timeout;

    if (desyncDebug_) {
        (*(work->hcclStartEvents_))[0].record(firstStream);
    }

    NPU_CHECK_ERROR(c10_npu::SetDevice(groupDevice.index()));
    groupEnd();

    (*(work->hcclEndEvents_))[0].record(firstStream);

    if (shouldEnqueue) {
        c10_npu::NPUGraph::inc_pending_event_queries();
        workEnqueue(work);
    }

    {
        // The future is created and completed while the HCCL streams are current.
        c10_npu::NPUMultiStreamGuard streamGuard(streams);
        work->future_ = c10::make_intrusive<at::ivalue::Future>(
            c10::ListType::create(c10::TensorType::get()),
            deviceList);
        work->future_->markCompleted(at::IValue(std::vector<at::Tensor>{}));
    }

    // Leave the process group ready for the next coalescing section.
    coalescing_state_ = 0;
    coalescedComm_ = nullptr;
    coalescedTensors_.clear();
    return work;
}
// Submits hcclGroupStart through OpCommand::RunOpApiV3 and then increments the
// active-group counter.
void ProcessGroupHCCL::groupStart()
{
    auto beginGroup = [this]() -> HcclResult {
        return hcclGroupStart();
    };
    at_npu::native::OpCommand::RunOpApiV3("hcclGroupStart", beginGroup);
    ++hcclActiveGroupCounter_;
}
// Submits hcclGroupEnd through OpCommand::RunOpApiV3 and then decrements the
// active-group counter.
void ProcessGroupHCCL::groupEnd()
{
    auto finishGroup = [this]() -> HcclResult {
        return hcclGroupEnd();
    };
    at_npu::native::OpCommand::RunOpApiV3("hcclGroupEnd", finishGroup);
    --hcclActiveGroupCounter_;
}
// Convenience overload: ends the coalescing section with the generic
// COALESCED op type.
c10::intrusive_ptr<c10d::Work> ProcessGroupHCCL::endCoalescing()
{
    const auto genericOp = c10d::OpType::COALESCED;
    return endCoalescing(genericOp);
}
// Gathers one tensor from every rank onto opts.rootRank using grouped
// point-to-point HcclSend/HcclRecv calls.
//
// outputTensors: on the root, exactly one list holding getSize() tensors whose
//                type and sizes match inputTensors[0]; on non-root ranks it
//                must be empty.
// inputTensors:  checked by assertSingleElementInput; inputTensors[0] is the
//                tensor this rank contributes.
// Argument errors are raised through invalidArgument (ValueError). A
// std::runtime_error is thrown unless both CheckCompatibleImpl() and
// IsCompatibleSoc() report true.
c10::intrusive_ptr<c10d::Work> ProcessGroupHCCL::gather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const c10d::GatherOptions& opts)
{
// Raises a ValueError whose message is prefixed with the function name.
static auto invalidArgument = [](const std::string& msg) {
C10_THROW_ERROR(ValueError, "ProcessGroupHCCL::gather: " + msg);
};
c10d::assertRootRank(invalidArgument, opts.rootRank, size_);
check_npu_tensors_different_devices(inputTensors);
c10d::assertSingleElementInput(invalidArgument, inputTensors);
std::vector<at::Tensor> outputs;
if (getRank() == opts.rootRank) {
// Root: validate the shape of the output list before using it.
if (outputTensors.size() != 1) {
std::stringstream ss;
ss << "requires a single-element output list containing a list with "
<< getSize() << " tensors.";
invalidArgument(ss.str());
} else if (outputTensors[0].size() != static_cast<size_t>(getSize())) {
std::stringstream ss;
ss << "Incorrect output list size " << outputTensors[0].size()
<< ". Output list size should be " << getSize()
<< ", same as size of the process group.";
invalidArgument(ss.str());
}
const auto& options = inputTensors[0].options();
const auto& sizes = inputTensors[0].sizes();
c10d::assertTypeAndSizesMatch(invalidArgument, outputTensors[0], options, sizes);
outputs = outputTensors[0];
} else {
// Non-root: no output is expected; `outputs` receives one default-constructed
// tensor as a placeholder.
if (!outputTensors.empty()) {
invalidArgument("requires empty output on non-root");
}
outputs = {};
outputs.emplace_back();
}
if (C10_UNLIKELY(at_npu::native::env::CheckOpHookEnable())) {
at_npu::native::OpHook::GetInstance().PreHook("gather", outputs, inputTensors);
}
bool is_compatible_soc = IsCompatibleSoc();
bool use_compatible_impl = at_npu::native::env::CheckCompatibleImpl();
// Only the send/recv based implementation exists; reject every other configuration.
if (!use_compatible_impl || !is_compatible_soc) {
throw std::runtime_error("ProcessGroupHCCL does not support gather" + DIST_ERROR(ErrCode::NOT_SUPPORT));
}
std::vector<at::Tensor> collectiveInputs;
collectiveInputs.push_back(inputTensors[0]);
// The same single-tensor list is passed as both input and output of collective();
// the per-rank destinations are reached through the captured `outputs` instead.
return collective(
collectiveInputs,
collectiveInputs,
// State is captured by value (opts, outputs, inputTensors and both flags).
[this, opts, outputs, use_compatible_impl, is_compatible_soc, inputTensors]
(at::Tensor& input, at::Tensor& output, HcclComm comm, c10_npu::NPUStream& stream, std::shared_ptr<bool> is_dispatched) {
// Repeats the check already made before collective() was called.
if (!use_compatible_impl || !is_compatible_soc) {
return HCCL_E_INTERNAL;
}
RECORD_FUNCTION("HcclGather_SendRecv", std::vector<c10::IValue>({}));
const auto root = static_cast<int32_t>(opts.rootRank);
// All sends/receives below are issued between groupStart() and groupEnd().
groupStart();
if (c10_npu::is_core_control_enabled()) {
c10_npu::UseStreamResInCurrentThread(stream.stream(false));
}
if (getRank() == root) {
// Root: receive from every other rank, and copy the local input into its own slot.
for (const auto r : c10::irange(static_cast<int>(size_))) {
if (r != root) {
// Empty tensors are skipped; no HcclRecv is issued for them.
if (outputs[r].numel() > 0) {
auto outputDataPtr = outputs[r].data_ptr();
auto numel = getNumelForHCCL(outputs[r]);
auto hcclType = getHcclDataType(outputs[r].scalar_type());
auto hccl_call = [outputDataPtr, numel, hcclType, r, comm, stream, is_dispatched]() -> int {
#ifndef BUILD_LIBTORCH
torch_npu::profiler::MstxRange range(
getMstxHcclMsg("HcclRecv", numel, hcclType, comm, stream.id(),
static_cast<uint32_t>(r), -1),
stream.stream(false),
torch_npu::profiler::DOMAIN_COMMUNICATION);
#endif
auto hccl_result = hcclRecv(outputDataPtr, numel, hcclType,
static_cast<uint32_t>(r), comm, stream.stream(false));
*is_dispatched = true;
return hccl_result;
};
at_npu::native::OpCommand::RunOpApiV3("HcclRecv", hccl_call, false, &stream);
}
} else {
outputs[r].copy_(inputTensors[0], true);
}
}
} else {
// Non-root: send the local input to the root, unless it is empty.
if (inputTensors[0].numel() > 0) {
auto inputDataPtr = inputTensors[0].data_ptr();
auto numel = getNumelForHCCL(inputTensors[0]);
auto hcclType = getHcclDataType(inputTensors[0].scalar_type());
auto hccl_call = [inputDataPtr, numel, hcclType, root, comm, stream, is_dispatched]() -> int {
#ifndef BUILD_LIBTORCH
torch_npu::profiler::MstxRange range(
getMstxHcclMsg("HcclSend", numel, hcclType, comm, stream.id(), -1,
static_cast<uint32_t>(root)),
stream.stream(false),
torch_npu::profiler::DOMAIN_COMMUNICATION);
#endif
auto hccl_result = hcclSend(inputDataPtr, numel, hcclType,
static_cast<uint32_t>(root), comm, stream.stream(false));
*is_dispatched = true;
return hccl_result;
};
at_npu::native::OpCommand::RunOpApiV3("HcclSend", hccl_call, false, &stream);
}
}
groupEnd();
return HCCL_SUCCESS;
},
// No pre-processing is needed.
[&](std::vector<c10_npu::NPUStream>&, c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>&) {},
// Post-processing: register the input tensors according to the configured
// multi-stream memory reuse mode. AVOID_RECORD_STREAM stashes the tensor in
// the work object; otherwise the HCCL stream is recorded with the caching
// allocator, and the ERASE_* modes additionally remember the weak storage
// (and, for the optimized variant, the allocator block) on the work object.
[&](std::vector<c10_npu::NPUStream>& hcclStreams, c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>& work) {
auto multi_stream_memory_reuse_mode = c10_npu::option::OptionsManager::GetMultiStreamMemoryReuse();
for (const auto i : c10::irange(inputTensors.size())) {
c10_npu::NPUStreamGuard guard(hcclStreams[i]);
if (multi_stream_memory_reuse_mode == c10_npu::option::AVOID_RECORD_STREAM) {
work->stashed_for_allocator_safety_.push_back(inputTensors[i]);
} else {
c10_npu::NPUCachingAllocator::recordStream(inputTensors[i].storage().data_ptr(), hcclStreams[i]);
if (multi_stream_memory_reuse_mode == c10_npu::option::ERASE_RECORD_STREAM ||
multi_stream_memory_reuse_mode == c10_npu::option::ERASE_RECORD_STREAM_WITH_OPTIMIZE) {
work->recorded_inputs_.push_back(
std::make_pair(inputTensors[i].storage().getWeakStorageImpl(), hcclStreams[i]));
if (multi_stream_memory_reuse_mode == c10_npu::option::ERASE_RECORD_STREAM_WITH_OPTIMIZE) {
auto block_ptr = c10_npu::NPUCachingAllocator::getBlockPtr(inputTensors[i].storage().data_ptr());
work->recorded_block_ptr_for_inputs_.push_back(block_ptr);
c10_npu::NPUCachingAllocator::recordHcclWorkForBlock(block_ptr, static_cast<void*>(work.get()));
}
}
}
}
},
c10d::OpType::GATHER);
}
// Scatters one tensor per rank from opts.rootRank to every rank.
//
// outputTensors: checked by assertSingleElementInput; outputTensors[0] receives
//                this rank's slice.
// inputTensors:  on the root, exactly one list holding getSize() tensors whose
//                type and sizes match outputTensors[0]; on non-root ranks it
//                must be empty on entry.
// Argument errors are raised through invalidArgument (ValueError).
//
// Two implementations are selected at runtime:
//   * CheckCompatibleImpl() && IsCompatibleSoc(): grouped HcclSend/HcclRecv;
//   * otherwise: a single HcclScatter over a flattened input buffer.
// NOTE(review): in the HcclScatter path a non-root rank appends a list of empty
// placeholder tensors to the caller's `inputTensors` (it is a non-const
// reference), so the caller's vector is modified.
c10::intrusive_ptr<c10d::Work> ProcessGroupHCCL::scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const c10d::ScatterOptions& opts)
{
// Raises a ValueError whose message is prefixed with the function name.
static auto invalidArgument = [](const std::string& msg) {
C10_THROW_ERROR(ValueError, "ProcessGroupHCCL::scatter: " + msg);
};
c10d::assertRootRank(invalidArgument, opts.rootRank, size_);
check_npu_tensors_different_devices(outputTensors);
c10d::assertSingleElementInput(invalidArgument, outputTensors);
if (getRank() == opts.rootRank) {
// Root: validate the shape of the input list before using it.
if (inputTensors.size() != 1) {
std::stringstream ss;
ss << "requires a single-element input list containing a list with "
<< getSize() << " tensors.";
invalidArgument(ss.str());
} else if (inputTensors[0].size() != static_cast<size_t>(getSize())) {
std::stringstream ss;
ss << "Incorrect input list size " << inputTensors[0].size()
<< ". Input list size should be " << getSize()
<< ", same as size of the process group.";
invalidArgument(ss.str());
}
const auto& options = outputTensors[0].options();
const auto& sizes = outputTensors[0].sizes();
c10d::assertTypeAndSizesMatch(invalidArgument, inputTensors[0], options, sizes);
} else {
if (inputTensors.size() != 0) {
invalidArgument("requires empty input on non-root");
}
}
if (C10_UNLIKELY(at_npu::native::env::CheckOpHookEnable())) {
at_npu::native::OpHook::GetInstance().PreHook("scatter", outputTensors, inputTensors);
}
bool use_compatible_impl = at_npu::native::env::CheckCompatibleImpl();
bool is_compatible_soc = IsCompatibleSoc();
std::vector<at::Tensor> collectiveInputs;
std::vector<at::Tensor> inputFlattened;
bool need_flatten_copy = false;
if (use_compatible_impl && is_compatible_soc) {
// Send/recv path: collective() only needs one tensor as its input; the
// per-rank tensors are reached through the captured inputTensors.
if (getRank() == opts.rootRank) {
collectiveInputs.push_back(inputTensors[0][0]);
} else {
collectiveInputs.push_back(outputTensors[0]);
}
} else {
// HcclScatter path: every rank builds a flattened input buffer.
if (getRank() == opts.rootRank) {
inputFlattened = flatten_for_scatter_gather(inputTensors, outputTensors, size_);
} else {
// Non-root ranks have no real input, so placeholders shaped like the output
// are appended to inputTensors before flattening.
std::vector<at::Tensor> empty;
for (int i = 0; i < size_; i++) {
empty.push_back(at::empty_like(outputTensors[0]));
}
inputTensors.push_back(empty);
inputFlattened = flatten_for_scatter_gather(inputTensors, outputTensors, size_);
}
collectiveInputs = inputFlattened;
need_flatten_copy = true;
}
return collective(
collectiveInputs,
outputTensors,
// State is captured by value, including a copy of inputTensors taken after the
// placeholder list (if any) was appended above.
[this, opts, use_compatible_impl, inputTensors, is_compatible_soc]
(at::Tensor& input, at::Tensor& output, HcclComm comm, c10_npu::NPUStream& stream, std::shared_ptr<bool> is_dispatched) {
if (use_compatible_impl && is_compatible_soc) {
RECORD_FUNCTION("HcclScatter_SendRecv", std::vector<c10::IValue>({}));
const auto root = static_cast<int32_t>(opts.rootRank);
if (c10_npu::option::OptionsManager::GetMultiStreamMemoryReuse() != c10_npu::option::AVOID_RECORD_STREAM) {
c10_npu::NPUCachingAllocator::recordStream(output.storage().data_ptr(), stream);
}
// All sends/receives below are issued between groupStart() and groupEnd().
groupStart();
if (c10_npu::is_core_control_enabled()) {
c10_npu::UseStreamResInCurrentThread(stream.stream(false));
}
if (getRank() == root) {
// Root: send slice r to every other rank and copy its own slice locally.
for (const auto r : c10::irange(static_cast<int>(size_))) {
if (r != root) {
const at::Tensor& sendTensor = inputTensors[0][r];
// Empty tensors are skipped; no HcclSend is issued for them.
if (sendTensor.numel() > 0) {
auto inputDataPtr = sendTensor.data_ptr();
auto numel = getNumelForHCCL(sendTensor);
auto hcclType = getHcclDataType(sendTensor.scalar_type());
auto hccl_call = [inputDataPtr, numel, hcclType, r, comm, stream, is_dispatched]() -> int {
#ifndef BUILD_LIBTORCH
torch_npu::profiler::MstxRange range(
getMstxHcclMsg("HcclSend", numel, hcclType, comm, stream.id(), -1,
static_cast<uint32_t>(r)),
stream.stream(false),
torch_npu::profiler::DOMAIN_COMMUNICATION);
#endif
auto hccl_result = hcclSend(inputDataPtr, numel, hcclType,
static_cast<uint32_t>(r), comm, stream.stream(false));
*is_dispatched = true;
return hccl_result;
};
at_npu::native::OpCommand::RunOpApiV3("HcclSend", hccl_call, false, &stream);
}
} else {
output.copy_(inputTensors[0][root]);
}
}
} else {
// Non-root: receive this rank's slice from the root, unless it is empty.
if (output.numel() > 0) {
auto outputDataPtr = output.data_ptr();
auto numel = getNumelForHCCL(output);
auto hcclType = getHcclDataType(output.scalar_type());
auto hccl_call = [outputDataPtr, numel, hcclType, root, comm, stream, is_dispatched]() -> int {
#ifndef BUILD_LIBTORCH
torch_npu::profiler::MstxRange range(
getMstxHcclMsg("HcclRecv", numel, hcclType, comm, stream.id(),
static_cast<uint32_t>(root), -1),
stream.stream(false),
torch_npu::profiler::DOMAIN_COMMUNICATION);
#endif
auto hccl_result = hcclRecv(outputDataPtr, numel, hcclType,
static_cast<uint32_t>(root), comm, stream.stream(false));
*is_dispatched = true;
return hccl_result;
};
at_npu::native::OpCommand::RunOpApiV3("HcclRecv", hccl_call, false, &stream);
}
}
groupEnd();
return HCCL_SUCCESS;
} else {
// HcclScatter path: one collective call over the flattened input.
RECORD_FUNCTION("HcclScatter", std::vector<c10::IValue>({input}));
const auto root = opts.rootRank;
if (c10_npu::option::OptionsManager::GetMultiStreamMemoryReuse() != c10_npu::option::AVOID_RECORD_STREAM) {
c10_npu::NPUCachingAllocator::recordStream(output.storage().data_ptr(), stream);
}
auto inputDataPtr = input.data_ptr();
auto outputDataPtr = output.data_ptr();
auto numel = getNumelForHCCL(output);
auto hcclType = getHcclDataType(input.scalar_type());
auto hccl_call = [inputDataPtr, outputDataPtr, numel, hcclType, root, comm, stream, is_dispatched]() -> int {
#ifndef BUILD_LIBTORCH
torch_npu::profiler::MstxRange range(
getMstxHcclMsg("HcclScatter", numel, hcclType, comm, stream.id(), -1, -1), stream.stream(false),
torch_npu::profiler::DOMAIN_COMMUNICATION);
#endif
if (c10_npu::is_core_control_enabled()) {
c10_npu::UseStreamResInCurrentThread(stream.stream(false));
}
auto hccl_result = hcclScatter(inputDataPtr, outputDataPtr, numel, hcclType, root, comm, stream.stream(false));
*is_dispatched = true;
return hccl_result;
};
at_npu::native::OpCommand::RunOpApiV3("HcclScatter", hccl_call, false, &stream);
return HCCL_SUCCESS;
}
},
// Pre-processing: register the input tensors according to the configured
// multi-stream memory reuse mode (stash in the work object, or record the HCCL
// stream with the caching allocator plus ERASE_* bookkeeping). In the
// HcclScatter path each input is also copied into the flattened buffer.
[&](std::vector<c10_npu::NPUStream>& hcclStreams, c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>& work) {
if (need_flatten_copy) {
work->lazyDestroy(inputFlattened);
auto multi_stream_memory_reuse_mode = c10_npu::option::OptionsManager::GetMultiStreamMemoryReuse();
for (const auto i : c10::irange(inputTensors.size())) {
c10_npu::NPUStreamGuard guard(hcclStreams[i]);
for (const auto j : c10::irange(inputTensors[0].size())) {
if (multi_stream_memory_reuse_mode == c10_npu::option::AVOID_RECORD_STREAM) {
work->stashed_for_allocator_safety_.push_back(inputTensors[i][j]);
} else {
c10_npu::NPUCachingAllocator::recordStream(inputTensors[i][j].storage().data_ptr(), hcclStreams[i]);
if (multi_stream_memory_reuse_mode == c10_npu::option::ERASE_RECORD_STREAM ||
multi_stream_memory_reuse_mode == c10_npu::option::ERASE_RECORD_STREAM_WITH_OPTIMIZE) {
work->recorded_inputs_.push_back(
std::make_pair(inputTensors[i][j].storage().getWeakStorageImpl(), hcclStreams[i]));
if (multi_stream_memory_reuse_mode == c10_npu::option::ERASE_RECORD_STREAM_WITH_OPTIMIZE) {
auto block_ptr = c10_npu::NPUCachingAllocator::getBlockPtr(inputTensors[i][j].storage().data_ptr());
work->recorded_block_ptr_for_inputs_.push_back(block_ptr);
c10_npu::NPUCachingAllocator::recordHcclWorkForBlock(block_ptr, static_cast<void*>(work.get()));
}
}
}
// Asynchronous copy of the input into its slot of the flattened buffer.
inputFlattened[i][j].copy_(inputTensors[i][j], true);
}
}
} else {
// Send/recv path: only the root holds input tensors to register.
auto multi_stream_memory_reuse_mode = c10_npu::option::OptionsManager::GetMultiStreamMemoryReuse();
if (getRank() == opts.rootRank) {
for (const auto j : c10::irange(inputTensors[0].size())) {
c10_npu::NPUStreamGuard guard(hcclStreams[0]);
if (multi_stream_memory_reuse_mode == c10_npu::option::AVOID_RECORD_STREAM) {
work->stashed_for_allocator_safety_.push_back(inputTensors[0][j]);
} else {
c10_npu::NPUCachingAllocator::recordStream(inputTensors[0][j].storage().data_ptr(), hcclStreams[0]);
if (multi_stream_memory_reuse_mode == c10_npu::option::ERASE_RECORD_STREAM ||
multi_stream_memory_reuse_mode == c10_npu::option::ERASE_RECORD_STREAM_WITH_OPTIMIZE) {
work->recorded_inputs_.push_back(
std::make_pair(inputTensors[0][j].storage().getWeakStorageImpl(), hcclStreams[0]));
if (multi_stream_memory_reuse_mode == c10_npu::option::ERASE_RECORD_STREAM_WITH_OPTIMIZE) {
auto block_ptr = c10_npu::NPUCachingAllocator::getBlockPtr(inputTensors[0][j].storage().data_ptr());
work->recorded_block_ptr_for_inputs_.push_back(block_ptr);
c10_npu::NPUCachingAllocator::recordHcclWorkForBlock(block_ptr, static_cast<void*>(work.get()));
}
}
}
}
}
}
},
// No post-processing is needed.
[&](std::vector<c10_npu::NPUStream>&, c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>&) {},
c10d::OpType::SCATTER);
}
// Point-to-point send of `tensors` to rank `dstRank`.
//
// The tensors are first passed through cast_to_origin_format(); the resulting
// list is what gets handed to pointToPoint(). `tag` is not used by this
// implementation.
c10::intrusive_ptr<c10d::Work> ProcessGroupHCCL::send(std::vector<at::Tensor>& tensors, int dstRank, int tag)
{
    check_npu_tensors_different_devices(tensors);
    if (C10_UNLIKELY(at_npu::native::env::CheckOpHookEnable())) {
        at_npu::native::OpHook::GetInstance().PreHook("send", tensors);
    }

    auto sendTensors = cast_to_origin_format(tensors);

    // Issues a single HcclSend for `input` towards `dst_rank` on `stream`.
    auto issueSend = [&](at::Tensor& input,
                         HcclComm comm,
                         c10_npu::NPUStream& stream,
                         std::shared_ptr<bool> is_dispatched,
                         int dst_rank) {
        RECORD_FUNCTION("HcclSend", std::vector<c10::IValue>({input}));
        auto sendBuf = input.data_ptr();
        auto count = getNumelForHCCL(input);
        auto dataType = getHcclDataType(input.scalar_type());

        // Captured by value so the task stays valid after this lambda returns.
        auto task = [sendBuf, count, dataType, dst_rank, comm, stream, is_dispatched]() -> int {
#ifndef BUILD_LIBTORCH
            torch_npu::profiler::MstxRange range(
                getMstxHcclMsg("HcclSend", count, dataType, comm, stream.id(), -1, dst_rank),
                stream.stream(false),
                torch_npu::profiler::DOMAIN_COMMUNICATION);
#endif
            if (c10_npu::is_core_control_enabled()) {
                c10_npu::UseStreamResInCurrentThread(stream.stream(false));
            }
            auto status = hcclSend(
                sendBuf, count, dataType, static_cast<uint32_t>(dst_rank), comm, stream.stream(false));
            *is_dispatched = true;
            return status;
        };
        at_npu::native::OpCommand::RunOpApiV3("HcclSend", task, false, &stream);
        return HCCL_SUCCESS;
    };

    auto work = pointToPoint(sendTensors, issueSend, dstRank, c10d::OpType::SEND);
    return work;
}
// Point-to-point receive into `tensors` from rank `srcRank`.
//
// Data is received into the list returned by create_base_format_tensors();
// afterwards any destination tensor that is not in a base format is filled by
// copying from its counterpart in that list. `tag` is not used by this
// implementation.
c10::intrusive_ptr<c10d::Work> ProcessGroupHCCL::recv(std::vector<at::Tensor>& tensors, int srcRank, int tag)
{
    check_npu_tensors_different_devices(tensors);
    if (C10_UNLIKELY(at_npu::native::env::CheckOpHookEnable())) {
        at_npu::native::OpHook::GetInstance().PreHook("recv", tensors);
    }

    auto tensors_ = create_base_format_tensors(tensors);

    // Issues a single HcclRecv into `output` from `src_rank` on `stream`.
    auto issueRecv = [&](at::Tensor& output,
                         HcclComm comm,
                         c10_npu::NPUStream& stream,
                         std::shared_ptr<bool> is_dispatched,
                         int src_rank) {
        RECORD_FUNCTION("HcclRecv", std::vector<c10::IValue>({output}));
        c10_npu::NPUCachingAllocator::recordStream(output.storage().data_ptr(), stream);
        auto recvBuf = output.data_ptr();
        auto count = getNumelForHCCL(output);
        auto dataType = getHcclDataType(output.scalar_type());

        // Captured by value so the task stays valid after this lambda returns.
        auto task = [recvBuf, count, dataType, src_rank, comm, stream, is_dispatched]() -> int {
#ifndef BUILD_LIBTORCH
            torch_npu::profiler::MstxRange range(
                getMstxHcclMsg("HcclRecv", count, dataType, comm, stream.id(), src_rank, -1),
                stream.stream(false),
                torch_npu::profiler::DOMAIN_COMMUNICATION);
#endif
            if (c10_npu::is_core_control_enabled()) {
                c10_npu::UseStreamResInCurrentThread(stream.stream(false));
            }
            auto status = hcclRecv(
                recvBuf, count, dataType, static_cast<uint32_t>(src_rank), comm, stream.stream(false));
            *is_dispatched = true;
            return status;
        };
        at_npu::native::OpCommand::RunOpApiV3("HcclRecv", task, false, &stream);
        return HCCL_SUCCESS;
    };

    // No pre-processing is needed.
    auto noPreProcess = [&](std::vector<c10_npu::NPUStream>&,
                            c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>&) {};

    // Post-processing: record the HCCL stream for every receive buffer, remember
    // it on the work object unless memory reuse is CLOSE, and copy the data back
    // into destination tensors that are not in a base format.
    auto finishRecv = [&](std::vector<c10_npu::NPUStream>& hcclStreams,
                          c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>& work) {
        for (const auto idx : c10::irange(tensors_.size())) {
            c10_npu::NPUStreamGuard guard(hcclStreams[idx]);
            c10_npu::NPUCachingAllocator::recordStream(tensors_[idx].storage().data_ptr(), hcclStreams[idx]);
            if (c10_npu::option::OptionsManager::GetMultiStreamMemoryReuse() != c10_npu::option::CLOSE) {
                work->recorded_outputs_.push_back(
                    std::make_pair(tensors_[idx].storage().getWeakStorageImpl(), hcclStreams[idx]));
            }
            if (!at_npu::native::FormatHelper::IsBaseFormatType(tensors[idx])) {
                tensors[idx].copy_(tensors_[idx], true);
            }
        }
    };

    auto work = pointToPoint(
        tensors_,
        issueRecv,
        srcRank,
        c10d::OpType::RECV,
        noPreProcess,
        finishRecv);
    return work;
}
// Receiving from an arbitrary source rank is not implemented for HCCL; this
// always raises a NOT_SUPPORT error. Both parameters (tensor list and tag) are
// ignored.
//
// The message names recvAnysource explicitly: plain recv() is supported by this
// process group, so the earlier "does not support recv" wording was misleading.
c10::intrusive_ptr<c10d::Work> ProcessGroupHCCL::recvAnysource(std::vector<at::Tensor>& , int )
{
    TORCH_CHECK(false, "ProcessGroupHCCL does not support recvAnysource", DIST_ERROR(ErrCode::NOT_SUPPORT));
}
/**
 * All-to-all on a single input tensor and a single output tensor.
 *
 * Two paths, selected by the split-size arguments:
 *  - both split lists empty: the element count is divided evenly by the group size and one
 *    hcclAlltoAll call is enqueued with a single per-rank count for each direction;
 *  - otherwise: per-rank element counts and displacements are derived from the split lists
 *    and one hcclAlltoAllV call is enqueued.
 *
 * @param outputTensor     tensor the collective writes into.
 * @param inputTensor      tensor the collective reads from.
 * @param outputSplitSizes per-rank sizes along dim 0 of outputTensor. If empty while
 *                         inputSplitSizes is not, it is filled IN PLACE with
 *                         `outputTensor.size(0) / ranks` for every rank.
 * @param inputSplitSizes  per-rank sizes along dim 0 of inputTensor. If empty while
 *                         outputSplitSizes is not, it is filled IN PLACE the same way.
 * @param opts             not referenced by this function.
 * @return the work handle returned by collective().
 * @throws via TORCH_CHECK when the rank count is not positive, when (even-split path) the tensors
 *         differ in numel/type, are 0-dim, or dim 0 is not divisible by the group size; the helper
 *         checks (check_npu_single_tensor, check_split_sizes, ...) may raise as well.
 */
c10::intrusive_ptr<c10d::Work> ProcessGroupHCCL::alltoall_base(
    at::Tensor& outputTensor,
    at::Tensor& inputTensor,
    std::vector<int64_t>& outputSplitSizes,
    std::vector<int64_t>& inputSplitSizes,
    const c10d::AllToAllOptions& opts)
{
    check_npu_single_tensor(outputTensor);
    check_npu_single_tensor(inputTensor);
    int ranks = getSize();
    TORCH_CHECK(ranks > 0, "Invalid rank count within current process group: ", ranks, DIST_ERROR(ErrCode::PARAM));
    std::vector<at::Tensor> inputTensors = {inputTensor};
    std::vector<at::Tensor> outputTensors = {outputTensor};
    if (C10_UNLIKELY(at_npu::native::env::CheckOpHookEnable())) {
        at_npu::native::OpHook::GetInstance().PreHook("alltoall_base", outputTensors, inputTensors);
    }
    // The collective operates on the results of cast_to_origin_format(); the caller's tensors are
    // only touched again in the last callback, when they are not in a base format.
    auto inputTensors_ = cast_to_origin_format(inputTensors);
    auto outputTensors_ = cast_to_origin_format(outputTensors);
    if (inputSplitSizes.empty() && outputSplitSizes.empty()) {
        // ---- Even split: one count per direction, hcclAlltoAll ----
        TORCH_CHECK(
            outputTensor.numel() == inputTensor.numel() &&
            outputTensor.type() == inputTensor.type(),
            "Tensors are not equal in size or data type",
            DIST_ERROR(ErrCode::PARAM));
        TORCH_CHECK(
            inputTensor.dim() >= 1 && outputTensor.dim() >= 1,
            "Scalar tensors (0-dim tensors) are not supported in alltoall. "
            "inputTensor.dim()=", inputTensor.dim(), ", outputTensor.dim()=", outputTensor.dim(),
            ". Please reshape tensors to at least 1-D before calling alltoall.",
            DIST_ERROR(ErrCode::NOT_SUPPORT));
        TORCH_CHECK(
            outputTensor.size(0) % ranks == 0,
            "Tensor's dim 0 does not divide equally across group size",
            DIST_ERROR(ErrCode::PARAM));
        uint64_t output_counts = static_cast<uint64_t>(outputTensor.numel() / ranks);
        uint64_t input_counts = static_cast<uint64_t>(inputTensor.numel() / ranks);
        check_npu_tensors_different_devices(inputTensors);
        check_npu_tensors_different_devices(outputTensors);
        return collective(
            inputTensors_,
            outputTensors_,
            [&](at::Tensor& input,
                at::Tensor& output,
                HcclComm comm,
                c10_npu::NPUStream& stream, std::shared_ptr<bool> is_dispatched) {
                RECORD_FUNCTION("HcclAlltoAll", std::vector<c10::IValue>({input}));
                auto inputDataPtr = input.data_ptr();
                auto outputDataPtr = output.data_ptr();
                auto inputhcclDataType = getHcclDataType(input.scalar_type());
                auto outputhcclDataType = getHcclDataType(output.scalar_type());
                // Everything is captured by value: the closure is handed to RunOpApiV3 and must
                // not refer to locals of this frame.
                auto hccl_call = [inputDataPtr,
                                  input_counts,
                                  inputhcclDataType,
                                  outputDataPtr,
                                  output_counts,
                                  outputhcclDataType,
                                  comm,
                                  stream,
                                  is_dispatched]() -> int {
#ifndef BUILD_LIBTORCH
                    torch_npu::profiler::MstxRange range(
                        getMstxHcclMsg("HcclAlltoAll", input_counts, inputhcclDataType, comm, stream.id(), -1, -1),
                        stream.stream(false), torch_npu::profiler::DOMAIN_COMMUNICATION);
#endif
                    if (c10_npu::is_core_control_enabled()) {
                        c10_npu::UseStreamResInCurrentThread(stream.stream(false));
                    }
                    auto hccl_result = hcclAlltoAll(
                        inputDataPtr,
                        input_counts,
                        inputhcclDataType,
                        outputDataPtr,
                        output_counts,
                        outputhcclDataType,
                        comm,
                        stream.stream(false));
                    *is_dispatched = true;
                    return hccl_result;
                };
                at_npu::native::OpCommand::RunOpApiV3("HcclAlltoAll", hccl_call, false, &stream);
                return HCCL_SUCCESS;
            },
            [&](std::vector<c10_npu::NPUStream>&, c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>&) {},
            [&](std::vector<c10_npu::NPUStream>& hcclStreams, c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>& work) {
                for (size_t i = 0; i < outputTensors_.size(); ++i) {
                    c10_npu::NPUStreamGuard guard(hcclStreams[i]);
                    // Output lifetime bookkeeping depends on the multi-stream memory reuse mode:
                    // either stash the tensor on the work object, or record the HCCL stream on its
                    // storage (and, in ERASE_RECORD_STREAM mode, remember it in recorded_outputs_).
                    if (c10_npu::option::OptionsManager::GetMultiStreamMemoryReuse() == c10_npu::option::AVOID_RECORD_STREAM) {
                        work->stashed_for_allocator_safety_.push_back(outputTensors_[i]);
                    } else {
                        c10_npu::NPUCachingAllocator::recordStream(outputTensors_[i].storage().data_ptr(), hcclStreams[i]);
                        if (c10_npu::option::OptionsManager::GetMultiStreamMemoryReuse() == c10_npu::option::ERASE_RECORD_STREAM) {
                            work->recorded_outputs_.push_back(
                                std::make_pair(outputTensors_[i].storage().getWeakStorageImpl(), hcclStreams[i]));
                        }
                    }
                    // Caller's tensor is not in a base format: copy the result back from the
                    // tensor the collective actually wrote to.
                    if (!at_npu::native::FormatHelper::IsBaseFormatType(outputTensors[i])) {
                        outputTensors[i].copy_(outputTensors_[i], true);
                    }
                }
            },
            c10d::OpType::ALLTOALL_BASE);
    } else {
        // ---- Explicit split sizes: per-rank counts/displacements, hcclAlltoAllV ----
        // A missing split list is filled with an even division of dim 0. This writes to the
        // caller's vector; check_split_sizes() below then validates both lists.
        uint64_t index = static_cast<uint64_t>(outputTensor.size(0) / ranks);
        if (outputSplitSizes.empty()) {
            for (int i = 0; i < ranks; i++) {
                outputSplitSizes.push_back(index);
            }
        }
        index = static_cast<uint64_t>(inputTensor.size(0) / ranks);
        if (inputSplitSizes.empty()) {
            for (int i = 0; i < ranks; i++) {
                inputSplitSizes.push_back(index);
            }
        }
        check_split_sizes(inputSplitSizes, inputTensor, size_);
        check_split_sizes(outputSplitSizes, outputTensor, size_);
        // Elements per dim-0 row. Kept in int64_t: narrowing to int would truncate (or turn
        // negative) for rows with more than INT_MAX elements and corrupt the counts below.
        const int64_t inputRowSize =
            inputTensor.size(0) != 0 ? inputTensor.numel() / inputTensor.size(0) : 1;
        const int64_t outputRowSize =
            outputTensor.size(0) != 0 ? outputTensor.numel() / outputTensor.size(0) : 1;
        std::vector<uint64_t> inputCounts;
        std::vector<uint64_t> inputSpl;
        std::vector<uint64_t> outputCounts;
        std::vector<uint64_t> outputSpl;
        inputCounts.reserve(inputSplitSizes.size());
        inputSpl.reserve(inputSplitSizes.size() + 1);
        outputCounts.reserve(outputSplitSizes.size());
        outputSpl.reserve(outputSplitSizes.size() + 1);
        // Displacements are running sums of the counts, starting at 0.
        inputSpl.push_back(0);
        outputSpl.push_back(0);
        for (size_t i = 0; i < outputSplitSizes.size(); i++) {
            outputCounts.push_back(static_cast<uint64_t>(outputSplitSizes[i] * outputRowSize));
            if (i > 0) {
                outputSpl.push_back(outputSpl[i - 1] + outputCounts[i - 1]);
            }
        }
        for (size_t i = 0; i < inputSplitSizes.size(); i++) {
            inputCounts.push_back(static_cast<uint64_t>(inputSplitSizes[i] * inputRowSize));
            if (i > 0) {
                inputSpl.push_back(inputSpl[i - 1] + inputCounts[i - 1]);
            }
        }
        check_npu_tensors_different_devices(inputTensors);
        check_npu_tensors_different_devices(outputTensors);
        return collective(
            inputTensors_,
            outputTensors_,
            [&](at::Tensor& input, at::Tensor& output, HcclComm comm, c10_npu::NPUStream& stream, std::shared_ptr<bool> is_dispatched) {
                RECORD_FUNCTION("HcclAlltoAllV", std::vector<c10::IValue>({input}));
                auto inputDataPtr = input.data_ptr();
                auto outputDataPtr = output.data_ptr();
                auto inputhcclDataType = getHcclDataType(input.scalar_type());
                auto outputhcclDataType = getHcclDataType(output.scalar_type());
                // The count/displacement vectors are copied into the closure so that the
                // pointers passed to hcclAlltoAllV stay valid for as long as the closure lives.
                auto hccl_call = [inputDataPtr,
                                  inputCounts,
                                  inputSpl,
                                  inputhcclDataType,
                                  outputDataPtr,
                                  outputCounts,
                                  outputSpl,
                                  outputhcclDataType,
                                  comm,
                                  stream,
                                  is_dispatched]() -> int {
#ifndef BUILD_LIBTORCH
                    torch_npu::profiler::MstxRange range(
                        getMstxHcclMsg("HcclAlltoAllV", static_cast<uint64_t>(inputCounts.size()),
                                       inputhcclDataType, comm, stream.id(), -1, -1),
                        stream.stream(false), torch_npu::profiler::DOMAIN_COMMUNICATION);
#endif
                    if (c10_npu::is_core_control_enabled()) {
                        c10_npu::UseStreamResInCurrentThread(stream.stream(false));
                    }
                    auto hccl_result = hcclAlltoAllV(
                        inputDataPtr,
                        inputCounts.data(),
                        inputSpl.data(),
                        inputhcclDataType,
                        outputDataPtr,
                        outputCounts.data(),
                        outputSpl.data(),
                        outputhcclDataType,
                        comm,
                        stream.stream(false));
                    *is_dispatched = true;
                    return hccl_result;
                };
                at_npu::native::OpCommand::RunOpApiV3("HcclAlltoAllV", hccl_call, false, &stream);
                return HCCL_SUCCESS;
            },
            [&](std::vector<c10_npu::NPUStream>&, c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>&) {},
            [&](std::vector<c10_npu::NPUStream>& hcclStreams, c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>& work) {
                for (size_t i = 0; i < outputTensors_.size(); ++i) {
                    c10_npu::NPUStreamGuard guard(hcclStreams[i]);
                    // Same output bookkeeping as in the even-split path above.
                    if (c10_npu::option::OptionsManager::GetMultiStreamMemoryReuse() == c10_npu::option::AVOID_RECORD_STREAM) {
                        work->stashed_for_allocator_safety_.push_back(outputTensors_[i]);
                    } else {
                        c10_npu::NPUCachingAllocator::recordStream(outputTensors_[i].storage().data_ptr(), hcclStreams[i]);
                        if (c10_npu::option::OptionsManager::GetMultiStreamMemoryReuse() == c10_npu::option::ERASE_RECORD_STREAM) {
                            work->recorded_outputs_.push_back(
                                std::make_pair(outputTensors_[i].storage().getWeakStorageImpl(), hcclStreams[i]));
                        }
                    }
                    if (!at_npu::native::FormatHelper::IsBaseFormatType(outputTensors[i])) {
                        outputTensors[i].copy_(outputTensors_[i], true);
                    }
                }
            },
            c10d::OpType::ALLTOALL);
    }
}
// All-to-all over per-rank tensor lists: one input tensor and one output tensor per rank.
//
// Validation: output_tensors.size() must equal size_ and input_tensors.size(); every tensor must
// pass check_npu_single_tensor() and live on the device of output_tensors[0].
//
// Two implementations are selected by CheckCompatibleImpl() && IsCompatibleSoc():
//  - both true  ("compatible" path): one hcclSend per non-empty input tensor and one hcclRecv per
//    non-empty output tensor, issued between groupStart()/groupEnd() directly on the caller's tensors;
//  - otherwise  ("flatten" path): every tensor is reshaped to {numel, 1}, the lists are concatenated
//    into one input and one output buffer, a single hcclAlltoAllV is enqueued on those buffers, and
//    the last callback splits the output buffer and copies the pieces back into output_tensors.
//
// `opts` is not referenced by this function.
// Returns the work handle produced by collective().
c10::intrusive_ptr<c10d::Work> ProcessGroupHCCL::alltoall(
std::vector<at::Tensor>& output_tensors,
std::vector<at::Tensor>& input_tensors,
const c10d::AllToAllOptions& opts)
{
TORCH_CHECK(output_tensors.size() == size_,
"the size of output_tensors and worldsize must equal", DIST_ERROR(ErrCode::PARAM));
TORCH_CHECK(output_tensors.size() == input_tensors.size(),
"the size of input_tensors and output_tensors must equal", DIST_ERROR(ErrCode::PARAM));
// Indexing [0] relies on the size checks above (presumably size_ > 0 — TODO confirm).
auto device = output_tensors[0].device();
for (const auto r : c10::irange(output_tensors.size())) {
check_npu_single_tensor(output_tensors[r]);
check_npu_single_tensor(input_tensors[r]);
TORCH_CHECK(device == output_tensors[r].device() && device == input_tensors[r].device(),
"tensors must be on the same device", DIST_ERROR(ErrCode::PARAM));
}
if (C10_UNLIKELY(at_npu::native::env::CheckOpHookEnable())) {
at_npu::native::OpHook::GetInstance().PreHook("alltoall", output_tensors, input_tensors);
}
bool use_compatible_impl = at_npu::native::env::CheckCompatibleImpl();
bool is_compatible_soc = IsCompatibleSoc();
// Tensors actually handed to collective(): the first caller tensors on the compatible path,
// the concatenated buffers on the flatten path.
std::vector<at::Tensor> collectiveInputs;
std::vector<at::Tensor> collectiveOutputs;
std::vector<at::Tensor> input_tensors_flattened;
std::vector<at::Tensor> output_tensors_flattened;
std::vector<at::Tensor> input_tensors_;
std::vector<at::Tensor> output_tensors_;
// Per-rank element counts (numel of each caller tensor); only filled on the flatten path.
std::vector<int64_t> output_split_sizes;
std::vector<int64_t> input_split_sizes;
bool view_as_byte = false;
bool need_flatten_copy = false;
if (use_compatible_impl && is_compatible_soc) {
collectiveInputs.push_back(input_tensors[0]);
collectiveOutputs.push_back(output_tensors[0]);
} else {
// Float8 tensors are reinterpreted as Byte before concatenation; the decision is taken from
// input_tensors[0] only and applied to inputs and outputs alike.
// NOTE(review): presumably a workaround for float8 not being handled further down — confirm.
if (input_tensors[0].dtype() == at::ScalarType::Float8_e5m2 ||
input_tensors[0].dtype() == at::ScalarType::Float8_e4m3fn) {
view_as_byte = true;
}
if (view_as_byte) {
for (size_t i = 0; i < input_tensors.size(); i++) {
input_split_sizes.push_back(input_tensors[i].numel());
input_tensors_flattened.push_back(
at::reshape(input_tensors[i], {input_tensors[i].numel(), 1}).view(at::ScalarType::Byte));
}
for (size_t i = 0; i < output_tensors.size(); i++) {
output_split_sizes.push_back(output_tensors[i].numel());
output_tensors_flattened.push_back(
at::reshape(output_tensors[i], {output_tensors[i].numel(), 1}).view(at::ScalarType::Byte));
}
} else {
for (size_t i = 0; i < input_tensors.size(); i++) {
input_split_sizes.push_back(input_tensors[i].numel());
input_tensors_flattened.push_back(at::reshape(input_tensors[i], {input_tensors[i].numel(), 1}));
}
for (size_t i = 0; i < output_tensors.size(); i++) {
output_split_sizes.push_back(output_tensors[i].numel());
output_tensors_flattened.push_back(at::reshape(output_tensors[i], {output_tensors[i].numel(), 1}));
}
}
// One {sum(numel), 1} buffer per direction; results are copied back in the last callback.
std::vector<at::Tensor> in_tensors = {at::cat(input_tensors_flattened, 0)};
std::vector<at::Tensor> out_tensors = {at::cat(output_tensors_flattened, 0)};
input_tensors_ = cast_to_origin_format(in_tensors);
output_tensors_ = cast_to_origin_format(out_tensors);
check_npu_tensors_different_devices(in_tensors);
check_npu_tensors_different_devices(out_tensors);
collectiveInputs = input_tensors_;
collectiveOutputs = output_tensors_;
need_flatten_copy = true;
}
// NOTE(review): `ranks` is not referenced again in this function.
int ranks = getSize();
int inputsize = static_cast<int>(input_tensors.size());
int outsize = static_cast<int>(output_tensors.size());
// Counts and displacements for hcclAlltoAllV. Displacements start at 0 and accumulate the
// preceding split sizes. Only populated beyond the initial 0 on the flatten path.
std::vector<uint64_t> input_counts;
std::vector<uint64_t> input_spl;
std::vector<uint64_t> output_counts;
std::vector<uint64_t> output_spl;
input_spl.push_back(0);
output_spl.push_back(0);
// Same condition as the flatten path above, written via De Morgan.
if (!use_compatible_impl || !is_compatible_soc) {
output_counts.push_back(static_cast<uint64_t>(output_split_sizes[0]));
input_counts.push_back(static_cast<uint64_t>(input_split_sizes[0]));
for (int i = 1; i < outsize; i++) {
output_counts.push_back(static_cast<uint64_t>(output_split_sizes[i]));
output_spl.push_back(output_spl[i - 1] + static_cast<uint64_t>(output_split_sizes[i - 1]));
}
for (int i = 1; i < inputsize; i++) {
input_counts.push_back(static_cast<uint64_t>(input_split_sizes[i]));
input_spl.push_back(input_spl[i - 1] + static_cast<uint64_t>(input_split_sizes[i - 1]));
}
}
return collective(
collectiveInputs,
collectiveOutputs,
// Communication callback. The tensor lists and count vectors are captured by value (copies).
[this, input_tensors, output_tensors, use_compatible_impl, input_counts, input_spl,
output_counts, output_spl, is_compatible_soc]
(at::Tensor& input, at::Tensor& output, HcclComm comm, c10_npu::NPUStream& stream,
std::shared_ptr<bool> is_dispatched) {
if (use_compatible_impl && is_compatible_soc) {
// Compatible path: grouped point-to-point send/recv with every peer index r.
RECORD_FUNCTION("HcclAlltoAll_SendRecv", std::vector<c10::IValue>({}));
if (c10_npu::option::OptionsManager::GetMultiStreamMemoryReuse() != c10_npu::option::AVOID_RECORD_STREAM) {
c10_npu::NPUCachingAllocator::recordStream(output.storage().data_ptr(), stream);
}
groupStart();
if (c10_npu::is_core_control_enabled()) {
c10_npu::UseStreamResInCurrentThread(stream.stream(false));
}
for (const int r : c10::irange(static_cast<int>(input_tensors.size()))) {
// Empty tensors are skipped: no send is issued for an input with numel() == 0.
if (input_tensors[r].numel() > 0) {
auto inputDataPtr = input_tensors[r].data_ptr();
auto numel = getNumelForHCCL(input_tensors[r]);
auto hcclType = getHcclDataType(input_tensors[r].scalar_type());
auto hccl_call = [inputDataPtr, numel, hcclType, r, comm, stream, is_dispatched]() -> int {
#ifndef BUILD_LIBTORCH
torch_npu::profiler::MstxRange range(
getMstxHcclMsg("HcclSend", numel, hcclType, comm, stream.id(), -1,
static_cast<uint32_t>(r)),
stream.stream(false),
torch_npu::profiler::DOMAIN_COMMUNICATION);
#endif
auto hccl_result = hcclSend(inputDataPtr, numel, hcclType,
static_cast<uint32_t>(r), comm, stream.stream(false));
*is_dispatched = true;
return hccl_result;
};
at_npu::native::OpCommand::RunOpApiV3("HcclSend", hccl_call, false, &stream);
}
// Likewise no recv is issued for an output with numel() == 0.
if (output_tensors[r].numel() > 0) {
auto outputDataPtr = output_tensors[r].data_ptr();
auto numel = getNumelForHCCL(output_tensors[r]);
auto hcclType = getHcclDataType(output_tensors[r].scalar_type());
auto hccl_call = [outputDataPtr, numel, hcclType, r, comm, stream, is_dispatched]() -> int {
#ifndef BUILD_LIBTORCH
torch_npu::profiler::MstxRange range(
getMstxHcclMsg("HcclRecv", numel, hcclType, comm, stream.id(),
static_cast<uint32_t>(r), -1),
stream.stream(false),
torch_npu::profiler::DOMAIN_COMMUNICATION);
#endif
auto hccl_result = hcclRecv(outputDataPtr, numel, hcclType,
static_cast<uint32_t>(r), comm, stream.stream(false));
*is_dispatched = true;
return hccl_result;
};
at_npu::native::OpCommand::RunOpApiV3("HcclRecv", hccl_call, false, &stream);
}
}
groupEnd();
return HCCL_SUCCESS;
} else {
// Flatten path: a single hcclAlltoAllV on the concatenated buffers.
RECORD_FUNCTION("HcclAlltoAllV", std::vector<c10::IValue>({input}));
auto inputDataPtr = input.data_ptr();
auto outputDataPtr = output.data_ptr();
auto inputhcclDataType = getHcclDataType(input.scalar_type());
auto outputhcclDataType = getHcclDataType(output.scalar_type());
// The count/displacement vectors are copied into the closure, so the .data() pointers
// passed to hcclAlltoAllV point into storage owned by the closure itself.
auto hccl_call = [inputDataPtr,
input_counts,
input_spl,
inputhcclDataType,
outputDataPtr,
output_counts,
output_spl,
outputhcclDataType,
comm,
stream,
is_dispatched]() -> int {
#ifndef BUILD_LIBTORCH
torch_npu::profiler::MstxRange range(
getMstxHcclMsg("HcclAlltoAllV", static_cast<uint64_t>(input_counts.size()),
inputhcclDataType, comm, stream.id(), -1, -1),
stream.stream(false), torch_npu::profiler::DOMAIN_COMMUNICATION);
#endif
if (c10_npu::is_core_control_enabled()) {
c10_npu::UseStreamResInCurrentThread(stream.stream(false));
}
auto hccl_result = hcclAlltoAllV(
inputDataPtr,
input_counts.data(),
input_spl.data(),
inputhcclDataType,
outputDataPtr,
output_counts.data(),
output_spl.data(),
outputhcclDataType,
comm,
stream.stream(false));
*is_dispatched = true;
return hccl_result;
};
at_npu::native::OpCommand::RunOpApiV3("HcclAlltoAllV", hccl_call, false, &stream);
return HCCL_SUCCESS;
}
},
// First bookkeeping callback (presumably invoked by collective() before the communication
// callback — confirm against collective()). Registers tensor lifetimes with the work object
// and the caching allocator according to the multi-stream memory reuse mode.
[&](std::vector<c10_npu::NPUStream>& hcclStreams, c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>& work) {
if (need_flatten_copy) {
// Flatten path: only the temporary concatenated buffers need tracking.
work->lazyDestroy(input_tensors_);
work->lazyDestroy(output_tensors_);
c10_npu::NPUStreamGuard guard(hcclStreams[0]);
if (c10_npu::option::OptionsManager::GetMultiStreamMemoryReuse() == c10_npu::option::AVOID_RECORD_STREAM) {
work->stashed_for_allocator_safety_.push_back(output_tensors_[0]);
} else {
c10_npu::NPUCachingAllocator::recordStream(output_tensors_[0].storage().data_ptr(), hcclStreams[0]);
if (c10_npu::option::OptionsManager::GetMultiStreamMemoryReuse() == c10_npu::option::ERASE_RECORD_STREAM) {
work->recorded_outputs_.push_back(
std::make_pair(output_tensors_[0].storage().getWeakStorageImpl(), hcclStreams[0]));
}
}
} else {
// Compatible path: every caller tensor is used directly, so each one is tracked,
// always against hcclStreams[0].
auto multi_stream_memory_reuse_mode = c10_npu::option::OptionsManager::GetMultiStreamMemoryReuse();
for (const auto i : c10::irange(input_tensors.size())) {
c10_npu::NPUStreamGuard guard(hcclStreams[0]);
if (multi_stream_memory_reuse_mode == c10_npu::option::AVOID_RECORD_STREAM) {
work->stashed_for_allocator_safety_.push_back(input_tensors[i]);
} else {
c10_npu::NPUCachingAllocator::recordStream(input_tensors[i].storage().data_ptr(), hcclStreams[0]);
if (multi_stream_memory_reuse_mode == c10_npu::option::ERASE_RECORD_STREAM ||
multi_stream_memory_reuse_mode == c10_npu::option::ERASE_RECORD_STREAM_WITH_OPTIMIZE) {
work->recorded_inputs_.push_back(
std::make_pair(input_tensors[i].storage().getWeakStorageImpl(), hcclStreams[0]));
// In the WITH_OPTIMIZE mode the allocator block of each input is additionally
// associated with this work object.
if (multi_stream_memory_reuse_mode == c10_npu::option::ERASE_RECORD_STREAM_WITH_OPTIMIZE) {
auto block_ptr = c10_npu::NPUCachingAllocator::getBlockPtr(input_tensors[i].storage().data_ptr());
work->recorded_block_ptr_for_inputs_.push_back(block_ptr);
c10_npu::NPUCachingAllocator::recordHcclWorkForBlock(block_ptr, static_cast<void*>(work.get()));
}
}
}
}
for (const auto i : c10::irange(output_tensors.size())) {
c10_npu::NPUStreamGuard guard(hcclStreams[0]);
if (multi_stream_memory_reuse_mode == c10_npu::option::AVOID_RECORD_STREAM) {
work->stashed_for_allocator_safety_.push_back(output_tensors[i]);
} else {
c10_npu::NPUCachingAllocator::recordStream(output_tensors[i].storage().data_ptr(), hcclStreams[0]);
if (multi_stream_memory_reuse_mode == c10_npu::option::ERASE_RECORD_STREAM ||
multi_stream_memory_reuse_mode == c10_npu::option::ERASE_RECORD_STREAM_WITH_OPTIMIZE) {
work->recorded_outputs_.push_back(
std::make_pair(output_tensors[i].storage().getWeakStorageImpl(), hcclStreams[0]));
}
}
}
}
},
// Second bookkeeping callback (presumably invoked after the communication callback — confirm).
// On the flatten path it scatters the concatenated output buffer back into output_tensors.
[&](std::vector<c10_npu::NPUStream>& hcclStreams, c10::intrusive_ptr<ProcessGroupHCCL::WorkHCCL>& work) {
if (need_flatten_copy) {
c10_npu::NPUStreamGuard guard(hcclStreams[0]);
std::vector<at::Tensor> out_tensors;
// Undo the Byte reinterpretation, using the dtype of input_tensors[0].
if (view_as_byte) {
out_tensors = {output_tensors_[0].view(input_tensors[0].scalar_type())};
} else {
out_tensors = output_tensors_;
}
// Split along dim 0 by the per-rank numel recorded earlier, then restore each caller shape.
std::vector<at::Tensor> output_results = at::split(out_tensors[0], output_split_sizes, 0);
// NOTE(review): `int i` is compared against an unsigned size() here.
for (int i = 0; i < output_results.size(); i++) {
at::Tensor reshaped = at::reshape(output_results[i], {output_results[i].numel(), 1});
if (view_as_byte) {
reshaped = reshaped.view(input_tensors[0].scalar_type());
}
output_tensors[i].copy_(at::reshape(reshaped, output_tensors[i].sizes()), true);
}
}
},
c10d::OpType::ALLTOALL);
}
}