#include "torch_npu/csrc/npu/DataParallelComm.h"
#include <unordered_map>
#include <memory>
#include <bitset>
#include <thread>
#include <ATen/ATen.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/core/functional.h>
#include <c10/util/irange.h>
#include <c10/util/Optional.h>
#include <c10/util/ArrayRef.h>
#include <torch/csrc/utils/tensor_flatten.h>
#include <torch/csrc/autograd/variable.h>
#include "torch_npu/csrc/distributed/HCCLUtils.hpp"
#include "torch_npu/csrc/core/npu/NPUMacros.h"
#include "torch_npu/csrc/core/npu/NPUGuard.h"
#include "torch_npu/csrc/core/npu/DeviceUtils.h"
#include "torch_npu/csrc/core/npu/NPUException.h"
namespace torch_npu {
using namespace at;
namespace data_parallel {
template <typename T>
struct GetSecondArgType;
template <typename R, typename Arg0, typename Arg1, typename... Args>
struct GetSecondArgType<R(Arg0, Arg1, Args...)> {
typedef typename std::decay<Arg1>::type type;
};
constexpr auto count_max = std::numeric_limits<GetSecondArgType<decltype(HcclBroadcast)>::type>::max();
inline size_t get_max_count()
{
return count_max;
}
struct HcclCommList {
std::unique_ptr<HcclComm[]> comms;
int ndevices;
explicit HcclCommList(const std::vector<int>& devices): comms(new HcclComm[devices.size()]), ndevices(devices.size())
{
HCCL_CHECK_ERROR(c10d_npu::hcclCommInitAll(devices.size(), const_cast<int32_t*>(devices.data()), comms.get()));
}
HcclCommList(HcclCommList&& foo) = default;
HcclCommList& operator=(HcclCommList&& foo) = default;
~HcclCommList() = default;
void Release()
{
if (comms) {
for (const auto i : c10::irange(ndevices)) {
int dummy_var;
auto ret = c10_npu::GetDevice(&dummy_var);
if (ret != ACL_ERROR_NONE) {
NPU driver is already unloaded from the process.
In these cases, skip hcclCommDestroy */
return;
}
c10d_npu::hcclCommDestroy(comms[i]);
}
}
}
ArrayRef<HcclComm> ref() const
{
return ArrayRef<HcclComm>(comms.get(), ndevices);
}
};
using device_set = std::bitset<C10_COMPILE_TIME_MAX_NPUS>;
using device_list = std::vector<int>;
static std::unordered_map<device_list, HcclCommList, c10::hash<device_list>> _communicators;
using comm_list = std::vector<HcclComm>;
using stream_list = std::vector<c10::optional<c10_npu::NPUStream>>;
void ReleaseHcclCommList()
{
for (auto& [_, comm] : _communicators) {
comm.Release();
}
}
ArrayRef<HcclComm> get_communicators(TensorList inputs)
{
static auto get_device = [](const at::Tensor& t) -> int {return t.get_device();};
device_list devices = c10::fmap(inputs, get_device);
auto it = _communicators.find(devices);
if (it == _communicators.end()) {
std::tie(it, std::ignore) = _communicators.emplace(devices, devices);
}
return it->second.ref();
}
static inline void check_tensor(const at::Tensor& input, const at::optional<at::Tensor>& output,
int input_multiplier, int output_multiplier, int64_t ref_numel, ScalarType ref_dtype)
{
auto check_one = [&](const at::Tensor& tensor) {
if (!torch_npu::utils::is_npu(tensor) || tensor.is_sparse()) {
throw std::runtime_error("input and output elements have to be npu dense Tensors" + PTA_ERROR(ErrCode::TYPE));
}
if (ref_dtype != tensor.scalar_type()) {
throw std::runtime_error("all inputs and outputs must be of the same Tensor dtype" + PTA_ERROR(ErrCode::TYPE));
}
if (!tensor.is_contiguous()) {
throw std::runtime_error("all inputs and outputs have to be contiguous" + PTA_ERROR(ErrCode::TYPE));
}
};
check_one(input);
if (input.numel() != ref_numel) {
throw std::runtime_error("all inputs must have the same number of elements" + PTA_ERROR(ErrCode::PARAM));
}
if (output) {
check_one(*output);
if (input.get_device() != output->get_device()) {
throw std::runtime_error("input and output must be on the same device" + PTA_ERROR(ErrCode::PARAM));
}
if (output->numel() * output_multiplier != ref_numel * input_multiplier) {
throw std::runtime_error("output must be of size input_size * size_multiplier" + PTA_ERROR(ErrCode::PARAM));
}
}
}
void check_inputs(TensorList inputs, TensorList outputs, int input_multiplier, int output_multiplier)
{
size_t len = inputs.size();
if (len == 0) {
throw std::runtime_error("input sequence can't be empty" + PTA_ERROR(ErrCode::PARAM));
}
if (len != outputs.size()) {
std::stringstream err;
err << "inputs and outputs sequences have to be of the same length, but got input of length "
<< len << " and output of length " << outputs.size() << PTA_ERROR(ErrCode::PARAM);
throw std::runtime_error(err.str());
}
device_set devices;
int64_t numel = inputs[0].numel();
auto dtype = inputs[0].scalar_type();
for (const auto i : c10::irange(len)) {
auto input = inputs[i];
auto output = outputs[i];
check_tensor(input, output, input_multiplier, output_multiplier, numel, dtype);
auto input_device = input.get_device();
if (devices.test(input_device)) {
throw std::runtime_error("inputs must be on unique devices" + PTA_ERROR(ErrCode::TYPE));
}
devices.set(input_device);
}
}
struct unique_type_checker {
void show(size_t type_id)
{
if (!unique) {
return;
}
if (!type_id_) {
type_id_ = type_id;
}
unique = type_id_.value() == type_id;
}
c10::optional<size_t> type_id_;
bool unique = true;
};
bool is_available(TensorList tensors)
{
device_set devices;
for (auto& tensor : tensors) {
if (!torch_npu::utils::is_npu(tensor) || tensor.is_sparse()) {
return false;
}
if (!tensor.is_contiguous()) {
return false;
}
auto device = tensor.get_device();
if (devices[device]) {
return false;
}
devices[device] = true;
}
return true;
}
void broadcast(TensorList tensors, const stream_list& streams = {}, const comm_list& user_comms = {})
{
check_inputs(tensors, tensors, 1, 1);
auto data_type = c10d_npu::getHcclDataType(tensors[0].scalar_type());
int64_t numel = tensors[0].numel();
const auto comms = user_comms.empty() ? get_communicators(tensors) : ArrayRef<HcclComm>(user_comms);
c10_npu::OptionalNPUGuard device_guard;
std::vector<std::thread> threads;
for (size_t i = 0, num_tensors = tensors.size(); i < num_tensors; i++) {
int device = tensors[i].get_device();
threads.emplace_back([&, i, device]() {
NPU_CHECK_ERROR(c10_npu::SetDevice(device));
const auto stream = (streams.empty() || !streams[i]) ?
c10_npu::getCurrentNPUStream(device).stream() : streams[i]->stream();
TORCH_CHECK(
static_cast<uint64_t>(numel) <= static_cast<uint64_t>(count_max),
"Broadcast tensor has ",
numel,
" elements, which exceeds the maximum HCCL supports (",
count_max,
")" + PTA_ERROR(ErrCode::VALUE));
HcclComm comm = comms[i];
HCCL_CHECK_ERROR(c10d_npu::hcclBroadcast(tensors[i].data_ptr(), numel, data_type, 0, comm, stream));
});
}
for (auto &t : threads) {
t.join();
}
}
static inline std::vector<Tensor>& _broadcast_out_impl(const Tensor& tensor, std::vector<Tensor>& out_tensors)
{
std::vector<Tensor> hccl_list;
hccl_list.reserve(out_tensors.size() + 1);
hccl_list.emplace_back(tensor);
for (auto& out_tensor : out_tensors) {
hccl_list.emplace_back(out_tensor);
}
if (is_available(hccl_list)) {
broadcast(hccl_list);
} else {
for (auto& out_tensor : out_tensors) {
out_tensor.copy_(tensor, true);
}
}
return out_tensors;
}
std::vector<Tensor>& broadcast_out(const Tensor& tensor, std::vector<Tensor>& out_tensors)
{
for (const auto i : c10::irange(out_tensors.size())) {
TORCH_CHECK(torch_npu::utils::is_npu(out_tensors[i]),
"Expected all output tensors to be NPU tensors, but output tensor at index ",
i,
" has device '",
out_tensors[i].device(),
"'" + PTA_ERROR(ErrCode::TYPE));
TORCH_CHECK(out_tensors[i].sizes() == tensor.sizes(),
"Expected all output tensors to have same shape as the source tensor ",
tensor.sizes(),
", but output tensor at index ",
i,
" has shape ",
out_tensors[i].sizes(), PTA_ERROR(ErrCode::VALUE));
}
return _broadcast_out_impl(tensor, out_tensors);
}
std::vector<Tensor> broadcast(const Tensor& tensor, IntArrayRef devices)
{
std::vector<Tensor> diff_device_dst_tensors;
diff_device_dst_tensors.reserve(devices.size());
for (auto device : devices) {
TORCH_CHECK(device >= 0, "Expected non-negative device index, but got ", device, PTA_ERROR(ErrCode::VALUE));
if (device != tensor.get_device()) {
diff_device_dst_tensors.emplace_back(
at::empty(tensor.sizes(), tensor.options().device(at::Device(DeviceType::PrivateUse1, device))));
}
}
_broadcast_out_impl(tensor, diff_device_dst_tensors);
std::vector<Tensor> dst_tensors;
dst_tensors.reserve(devices.size());
auto it = diff_device_dst_tensors.begin();
for (auto device : devices) {
if (device != tensor.get_device()) {
dst_tensors.emplace_back(*it++);
} else {
dst_tensors.emplace_back(tensor);
}
}
TORCH_INTERNAL_ASSERT(it == diff_device_dst_tensors.end());
return dst_tensors;
}
tensor_list2d broadcast_coalesced(TensorList tensors, IntArrayRef devices, size_t buffer_size)
{
TORCH_CHECK(
std::all_of(tensors.begin(), tensors.end(), [&](const at::Tensor& t) { return t.get_device() == devices[0]; }),
"All tensors must be on devices[0]: ",
devices[0], PTA_ERROR(ErrCode::TYPE));
buffer_size = std::min(get_max_count(), buffer_size);
tensor_list2d outputs(devices.size());
outputs[0] = tensors.vec();
for (auto& o : outputs) {
o.reserve(tensors.size());
}
unique_type_checker type_checker;
c10_npu::NPUGuard device_guard{devices[0]};
for (auto& chunk : torch::utils::take_tensors(tensors, buffer_size)) {
auto type_id = chunk.type_id();
type_checker.show(type_id);
std::vector<at::Tensor> results;
if (chunk.options().is_sparse()) {
auto flat_tuple = torch::utils::flatten_sparse_tensors(chunk.tensors);
auto broadcast_indices = broadcast(flat_tuple.first, devices);
auto broadcast_values = broadcast(flat_tuple.second, devices);
results.reserve(devices.size());
for (size_t i = 1, num_devices = devices.size(); i < num_devices; ++i) {
device_guard.set_index(devices[i]);
auto& device_outputs = outputs[i];
auto& inds = broadcast_indices[i];
auto& vals = broadcast_values[i];
for (const auto& var : torch::utils::unflatten_sparse_tensors(inds, vals, chunk.tensors)) {
device_outputs.emplace_back(torch::autograd::make_variable(var.tensor_data(), false));
}
}
} else {
auto results = broadcast(torch::utils::flatten_dense_tensors(chunk.tensors), devices);
for (size_t i = 1, num_devices = devices.size(); i < num_devices; ++i) {
device_guard.set_index(devices[i]);
auto& device_outputs = outputs[i];
for (auto& var : torch::utils::unflatten_dense_tensors(results[i], chunk.tensors)) {
device_outputs.emplace_back(torch::autograd::make_variable(var.tensor_data(), false));
}
}
}
}
if (!type_checker.unique) {
for (auto& o : outputs) {
torch::utils::reorder_tensors_like(o, tensors);
}
}
return outputs;
}
std::vector<at::Tensor>& scatter_out(const at::Tensor& tensor, std::vector<at::Tensor>& out_tensors, int64_t dim,
const c10::optional<std::vector<c10::optional<c10_npu::NPUStream>>>& streams)
{
TORCH_CHECK(!out_tensors.empty(), "Expected at least one output tensor to scatter to" + PTA_ERROR(ErrCode::VALUE));
dim = at::maybe_wrap_dim(dim, tensor);
int64_t total_size = 0;
std::vector<int64_t> chunk_sizes;
chunk_sizes.reserve(out_tensors.size());
for (const auto i : c10::irange(out_tensors.size())) {
TORCH_CHECK(torch_npu::utils::is_npu(out_tensors[i]),
"Expected all output tensors to be NPU tensors, but output tensor at index ",
i,
" has device '",
out_tensors[i].device(),
"'" + PTA_ERROR(ErrCode::TYPE));
auto out_sizes = out_tensors[i].sizes().vec();
bool same_ndim = out_sizes.size() == static_cast<size_t>(tensor.dim());
if (same_ndim) {
total_size += out_sizes[dim];
chunk_sizes.emplace_back(out_sizes[dim]);
out_sizes[dim] = tensor.size(dim);
}
TORCH_CHECK(same_ndim && out_sizes == tensor.sizes(),
"Output tensor at index ",
i,
" has incorrect shape: ",
out_tensors[i].sizes(),
". Expected same shape except for scatter dim ",
dim,
" as the source tensor: ",
at::IntArrayRef(tensor.sizes()), PTA_ERROR(ErrCode::VALUE));
}
TORCH_CHECK(total_size == tensor.size(dim),
"Total size for output tensors along scatter dim ",
dim,
" does not match the source tensor size at dim ",
dim,
". Expected ",
tensor.size(dim),
", but got total size ",
total_size, PTA_ERROR(ErrCode::VALUE));
auto chunks = tensor.split_with_sizes(chunk_sizes, dim);
c10_npu::OptionalNPUStreamGuard npu_guard;
for (const auto i : c10::irange(chunks.size())) {
if (i < (streams ? streams->size() : 0U) && (*streams)[i]) {
const auto device_index = static_cast<int16_t>(out_tensors[i].get_device());
TORCH_CHECK((*streams)[i]->device_index() == device_index,
"Expected the device associated with the stream at index ",
i,
" (was ",
(*streams)[i]->device_index(),
") ",
"to match the device supplied at that index (expected ",
device_index,
")" + PTA_ERROR(ErrCode::VALUE));
npu_guard.reset_stream(*(*streams)[i]);
}
out_tensors[i].copy_(chunks[i], true);
}
return out_tensors;
}
std::vector<at::Tensor> scatter(const at::Tensor& tensor, at::IntArrayRef devices,
const c10::optional<std::vector<int64_t>>& chunk_sizes, int64_t dim,
const c10::optional<std::vector<c10::optional<c10_npu::NPUStream>>>& streams)
{
TORCH_CHECK(!devices.empty(), "Expected at least one device to scatter to" + PTA_ERROR(ErrCode::VALUE));
if (chunk_sizes.has_value()) {
TORCH_CHECK(chunk_sizes->size() == devices.size(),
"Expected devices and chunk_sizes to be of same length, but got len(devices) = ",
devices.size(),
" and len(chunk_sizes) = ",
chunk_sizes->size(), PTA_ERROR(ErrCode::VALUE));
}
dim = at::maybe_wrap_dim(dim, tensor);
std::vector<at::Tensor> chunks = chunk_sizes
? tensor.split_with_sizes(*chunk_sizes, dim)
: tensor.chunk(devices.size(), dim);
c10_npu::OptionalNPUStreamGuard npu_guard;
for (const auto i : c10::irange(chunks.size())) {
const auto device_index = static_cast<int16_t>(devices[i]);
if (device_index != tensor.get_device()) {
if (i < (streams ? streams->size() : 0U) && (*streams)[i]) {
TORCH_CHECK((*streams)[i]->device_index() == device_index,
"Expected the device associated with the stream at index ",
i,
" (was ",
(*streams)[i]->device_index(),
") to match the device supplied at that index (expected ",
device_index,
")" + PTA_ERROR(ErrCode::VALUE));
npu_guard.reset_stream(*(*streams)[i]);
}
TORCH_CHECK(device_index >= 0, "Expected non-negative device index, but got ", device_index, PTA_ERROR(ErrCode::VALUE));
chunks[i] = chunks[i].to({DeviceType::PrivateUse1, device_index}, true, false);
}
}
return chunks;
}
static inline at::Tensor& _gather_out_impl(at::TensorList tensors, at::Tensor& out_tensor, int64_t dim)
{
std::vector<int64_t> chunk_sizes;
chunk_sizes.reserve(tensors.size());
for (auto& tensor : tensors) {
chunk_sizes.emplace_back(tensor.size(dim));
}
auto chunks = out_tensor.split_with_sizes(chunk_sizes, dim);
for (const auto i : c10::irange(tensors.size())) {
chunks[i].copy_(tensors[i], torch_npu::utils::is_npu(out_tensor));
}
return out_tensor;
}
at::Tensor& gather_out(at::TensorList tensors, at::Tensor& out_tensor, int64_t dim)
{
TORCH_CHECK(!tensors.empty(), "Expected at least one tensor to gather from" + PTA_ERROR(ErrCode::VALUE));
int64_t total_size = 0;
auto& first = tensors.front();
const auto first_size = first.sizes();
dim = at::maybe_wrap_dim(dim, first);
std::vector<int64_t> expected_size(first_size.begin(), first_size.end());
for (const auto i : c10::irange(tensors.size())) {
const auto& tensor = tensors[i];
TORCH_CHECK(torch_npu::utils::is_npu(tensor),
"Expected all input tensors to be NPU tensors, but "
"tensor at index ",
i,
" has device '",
tensor.device(),
"'", PTA_ERROR(ErrCode::TYPE));
TORCH_CHECK(tensor.ndimension() == static_cast<int64_t>(expected_size.size()),
"Expected all input tensors to have the same number of dimensions, but ",
"tensor at index ",
i,
"has ",
tensor.ndimension(),
" dimensions, (expected ",
expected_size.size(),
")", PTA_ERROR(ErrCode::VALUE));
expected_size[dim] = tensor.size(dim);
for (const auto dimension : c10::irange(expected_size.size())) {
TORCH_CHECK(expected_size[dimension] == tensor.size(dimension),
"Input tensor at index ",
i,
" has invalid shape ",
tensor.sizes(),
", but expected ",
at::IntArrayRef(expected_size), PTA_ERROR(ErrCode::VALUE));
}
total_size += tensor.size(dim);
}
expected_size[dim] = total_size;
TORCH_CHECK(out_tensor.sizes() == expected_size,
"Expected out tensor to have shape ",
at::IntArrayRef(expected_size),
", but got ",
out_tensor.sizes(), PTA_ERROR(ErrCode::VALUE))
return _gather_out_impl(tensors, out_tensor, dim);
}
at::Tensor gather(at::TensorList tensors, int64_t dim, c10::optional<int32_t> destination_index)
{
TORCH_CHECK(!tensors.empty(), "Expected at least one tensor to gather from" + PTA_ERROR(ErrCode::VALUE));
int64_t total_size = 0;
auto& first = tensors.front();
const auto first_size = first.sizes();
dim = at::maybe_wrap_dim(dim, first);
std::vector<int64_t> expected_size(first_size.begin(), first_size.end());
auto memory_format = first.suggest_memory_format();
for (const auto i : c10::irange(tensors.size())) {
const auto& tensor = tensors[i];
TORCH_CHECK(
torch_npu::utils::is_npu(tensor),
"Expected all input tensors to be NPU tensors, but "
"tensor at index ",
i,
" has device ",
tensor.device(), PTA_ERROR(ErrCode::TYPE));
TORCH_CHECK(
tensor.ndimension() == static_cast<int64_t>(expected_size.size()),
"Expected all input tensors to have the same number of dimensions, but ",
"tensor at index ",
i,
"has ",
tensor.ndimension(),
" dimensions, (expected ",
expected_size.size(),
")" + PTA_ERROR(ErrCode::VALUE));
expected_size[dim] = tensor.size(dim);
for (const auto dimension : c10::irange(expected_size.size())) {
TORCH_CHECK(
expected_size[dimension] == tensor.size(dimension),
"Input tensor at index ",
i,
" has invalid shape ",
tensor.sizes(),
", but expected ",
at::IntArrayRef(expected_size), PTA_ERROR(ErrCode::VALUE));
}
total_size += tensor.size(dim);
if (memory_format != MemoryFormat::Contiguous && tensor.suggest_memory_format() != memory_format) {
memory_format = MemoryFormat::Contiguous;
}
}
expected_size[dim] = total_size;
at::Device device(DeviceType::CPU);
if (!destination_index || *destination_index != -1) {
device = at::Device(DeviceType::PrivateUse1, destination_index ? *destination_index : -1);
}
at::Tensor result = at::empty(expected_size, first.options().device(device), memory_format);
return _gather_out_impl(tensors, result, dim);
}
}
}