#include <functional>
#include <c10/core/DeviceGuard.h>
#include <c10/core/StreamGuard.h>
#include <c10/util/Exception.h>
#include <c10/util/hash.h>
#include <c10d/comm.hpp>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/function_hook.h>
#include <torch/csrc/autograd/functions/accumulate_grad.h>
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/autograd/utils/grad_layout_contract.h>
#include <torch/csrc/autograd/utils/lambda_post_hook.h>
#include <c10d/debug.h>
#include "torch_npu/csrc/distributed/ProcessGroupHCCL.hpp"
#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
#include "torch_npu/csrc/framework/utils/OpPreparation.h"
#include "torch_npu/csrc/core/NPUBridge.h"
#include "torch_npu/csrc/core/NPUStorageImpl.h"
#include "torch_npu/csrc/distributed/reducer.hpp"
namespace c10d_npu {
namespace {
int64_t physical_numel(at::Tensor self) {
auto sizes = torch_npu::NPUBridge::GetNpuStorageImpl(self)->npu_desc_.storage_sizes_;
int64_t n = 1;
for (auto s : sizes) {
n *= s;
}
return n;
}
// Sentinel value held by div_factor_ until set_divide_factor() computes the
// real divide factor.
constexpr int kUnsetDivFactor = -1;
// File-local debug level. It is initialised to Off and is not assigned
// anywhere in the visible part of this file.
// NOTE(review): as written, debug_level() below therefore always reports Off
// unless code outside this view changes g_debug_level -- confirm intent.
c10d::DebugLevel g_debug_level = c10d::DebugLevel::Off;
// Returns the current file-local debug level.
c10d::DebugLevel debug_level() noexcept {
return g_debug_level;
}
// REDUCER_CHECK(cond, logger_, ...): when `cond` is false, forwards the
// message arguments to the logger's set_error_and_log() (if the logger is
// still alive) and then raises through TORCH_CHECK with the same arguments
// plus an INTERNAL distributed error code.
// The weak_ptr is promoted exactly once, so the logger cannot expire between
// the liveness test and the call that uses it.
#define REDUCER_CHECK(cond, logger_, ...) \
if (C10_UNLIKELY_OR_CONST(!(cond))) { \
    if (auto reducer_check_logger = logger_.lock()) { \
        reducer_check_logger->set_error_and_log(__VA_ARGS__); \
    } \
    TORCH_CHECK(false, ##__VA_ARGS__, DIST_ERROR(ErrCode::INTERNAL)); \
}
}
// Defines the TimerRegistry typed registry: entries are keyed by
// c10::DeviceType, produce a std::unique_ptr<Timer>, and their creators take
// a c10::Device argument. The Reducer constructor queries it through
// TimerRegistry()->Create(device.type(), device).
C10_DEFINE_TYPED_REGISTRY(
TimerRegistry,
c10::DeviceType,
Timer,
std::unique_ptr,
c10::Device);
// Constructs a Reducer over `params`.
//
// In order, the constructor:
//   1. moves/copies the arguments into members and resets the iteration state;
//   2. asserts that at least one parameter was supplied;
//   3. detects whether the parameters live on more than one device index;
//   4. creates a timer from TimerRegistry unless the module is a multi-device
//      PrivateUse1 module;
//   5. defaults expect_sparse_gradients_ to all-false when it was passed empty
//      and asserts it has one entry per parameter;
//   6. builds the buckets from `bucket_indices` / `per_bucket_size_limits`;
//   7. registers a post hook on each parameter's gradient accumulator that
//      calls autograd_hook(variable_index);
//   8. sizes backward_stats_ and, when find_unused_parameters is set,
//      initialises the local-used maps.
Reducer::Reducer(
std::vector<at::Tensor> params,
std::vector<std::vector<size_t>> bucket_indices,
std::vector<size_t> per_bucket_size_limits,
c10::intrusive_ptr<c10d::ProcessGroup> process_group,
std::vector<bool> expect_sparse_gradients,
int64_t bucket_bytes_cap,
bool find_unused_parameters,
bool gradient_as_bucket_view,
std::unordered_map<size_t, std::string> paramNames,
int64_t first_bucket_bytes_cap)
: params_(std::move(params)),
process_group_(std::move(process_group)),
expect_sparse_gradients_(std::move(expect_sparse_gradients)),
expect_autograd_hooks_(false),
require_finalize_(false),
next_bucket_(0),
has_marked_unused_parameters_(false),
find_unused_parameters_(find_unused_parameters),
gradient_as_bucket_view_(gradient_as_bucket_view),
local_used_map_reduced_(false),
num_iterations_(0),
num_buckets_ready_(0),
has_rebuilt_bucket_(false),
bucket_bytes_cap_(bucket_bytes_cap),
// The divide factor is computed lazily by set_divide_factor().
div_factor_(kUnsetDivFactor),
static_graph_(false),
comm_hook_(nullptr),
ddp_debug_level_(debug_level()),
param_names_(std::move(paramNames)),
first_bucket_bytes_cap_(first_bucket_bytes_cap) {
C10_LOG_API_USAGE_ONCE("torch.distributed.ddp.reducer");
TORCH_INTERNAL_ASSERT(params_.size() >= 1,
"Expected at least one parameter.",
DIST_ERROR(ErrCode::PARAM));
if (ddp_debug_level_ != c10d::DebugLevel::Off) {
LOG(INFO) << "Reducer initialized with bucket_bytes_cap: "
<< bucket_bytes_cap_
<< " first_bucket_bytes_cap: " << first_bucket_bytes_cap;
}
// Multi-device detection: stop as soon as a second distinct device index is
// seen among the parameters.
{
std::set<int> unique_devices;
for (const auto& v : params_) {
auto device_idx = int(v.device().index());
if (unique_devices.find(device_idx) == unique_devices.end()) {
unique_devices.insert(device_idx);
if (unique_devices.size() > 1) {
is_multi_device_module_ = true;
break;
}
}
}
}
// The timer is created from the first parameter's device, except for
// multi-device PrivateUse1 modules, for which timer_ is left untouched.
c10::Device device = params_[0].device();
if (!(device.type() == c10::DeviceType::PrivateUse1 &&
is_multi_device_module_)) {
timer_ = TimerRegistry()->Create(device.type(), device);
}
// An empty expectation list means "no parameter expects a sparse gradient".
if (expect_sparse_gradients_.empty()) {
expect_sparse_gradients_ = std::vector<bool>(params_.size(), false);
}
TORCH_INTERNAL_ASSERT(expect_sparse_gradients_.size() == params_.size(),
DIST_ERROR(ErrCode::PARAM));
// Bucket initialisation is performed while holding mutex_.
{
std::lock_guard<std::mutex> lock(mutex_);
initialize_buckets(std::move(bucket_indices),
std::move(per_bucket_size_limits));
}
// Hook registration: one post hook per parameter's gradient accumulator.
{
const auto variable_count = params_.size();
grad_accumulators_.resize(variable_count);
for (const auto variable_index : c10::irange(variable_count)) {
auto& variable = params_[variable_index];
auto grad_accumulator =
torch::autograd::impl::grad_accumulator(variable);
#ifndef _WIN32
using torch::distributed::autograd::ThreadLocalDistAutogradContext;
#endif
// hooks_ stores the value returned by add_post_hook together with the
// accumulator; the destructor uses that pair to remove the hook again.
hooks_.emplace_back(
grad_accumulator->add_post_hook(std::make_unique<
torch::autograd::utils::
LambdaPostHook>(
[=](const torch::autograd::variable_list& outputs,
const torch::autograd::variable_list& ) {
#ifndef _WIN32
this->rpc_context_.set(
ThreadLocalDistAutogradContext::getContextPtr());
#endif
this->autograd_hook(variable_index);
return outputs;
})),
grad_accumulator);
// The accumulator -> index map is only filled when unused-parameter
// detection is enabled; search_unused_parameters() iterates over it.
if (find_unused_parameters_) {
gradAccToVariableMap_[grad_accumulator.get()] = variable_index;
}
numGradHooksTriggeredMap_[variable_index] = 0;
// Note: this duplicate check runs after the hook above has already been
// registered on the accumulator.
REDUCER_CHECK(
grad_accumulators_[variable_index] == nullptr,
logger_,
c10::str(
"Reducer tried to register duplicate grad "
"accumulator for variable ",
variable_index), DIST_ERROR(ErrCode::PTR));
grad_accumulators_[variable_index] = std::move(grad_accumulator);
}
}
// One backward-stats slot per parameter.
{
const auto variable_count = params_.size();
backward_stats_.resize(variable_count);
}
if (find_unused_parameters_) {
initialize_local_used_map();
}
}
// Removes every post hook that the constructor registered on the gradient
// accumulators. The destructor is declared noexcept(false) because the
// assertion below raises if a hook cannot be removed.
Reducer::~Reducer() noexcept(false)
{
    for (auto it = hooks_.begin(); it != hooks_.end(); ++it) {
        // Each entry pairs the hook key with the accumulator it was added to.
        auto& key = it->first;
        auto& grad_accumulator = it->second;
        TORCH_INTERNAL_ASSERT(
            grad_accumulator->del_post_hook(key),
            "Reducer attempts to delete a non-existing hook.", DIST_ERROR(ErrCode::INTERNAL));
    }
}
// True when unused-parameter detection is enabled and the graph has not been
// declared static.
bool Reducer::dynamic_graph_find_unused() const
{
    if (static_graph_) {
        return false;
    }
    return find_unused_parameters_;
}
// True only for a static graph whose iteration counter is exactly 1.
bool Reducer::static_graph_first_iteration() const
{
    if (!static_graph_) {
        return false;
    }
    return num_iterations_ == 1;
}
// True for a static graph once the iteration counter has passed 1.
bool Reducer::static_graph_after_first_iteration() const
{
    if (!static_graph_) {
        return false;
    }
    return num_iterations_ > 1;
}
// Returns the current value of ddp_graph_static_, read while holding mutex_.
bool Reducer::ddp_graph_static()
{
    std::lock_guard<std::mutex> guard(mutex_);
    const bool graph_is_static = ddp_graph_static_;
    return graph_is_static;
}
void Reducer::initialize_local_used_map()
{
const auto variable_count = params_.size();
at::TensorOptions options;
options = options.dtype(at::kInt);
local_used_map_ =
at::zeros({static_cast<long>(variable_count)}, options);
options = options.device(params_[0].device());
local_used_map_dev_ =
at::empty({static_cast<long>(variable_count)}, options);
}
// Validates `grad` against the bucket view it will be copied into.
//   - Raises (REDUCER_CHECK) when the tensor options' types differ.
//   - Asserts both tensors are on the same device.
//   - Warns once when the strides differ; this is not treated as an error.
//   - Unless gradient_as_bucket_view_ is set, asserts that grad does not
//     alias the bucket view.
void Reducer::check_grad_layout(
    const at::Tensor& grad,
    const at::Tensor& bucket_view)
{
    REDUCER_CHECK(
        grad.options().type_equal(bucket_view.options()),
        logger_,
        c10::str("Expected ", bucket_view.toString(), ", got ", grad.toString()),
        DIST_ERROR(ErrCode::PARAM));
    TORCH_INTERNAL_ASSERT(grad.device() == bucket_view.device(),
                          DIST_ERROR(ErrCode::PARAM));

    const bool strides_match = grad.strides() == bucket_view.strides();
    if (!strides_match) {
        TORCH_WARN_ONCE(
            "Grad strides do not match bucket view strides. "
            "This may indicate grad was not created according to the "
            "gradient layout contract, or that the param's strides "
            "changed since DDP was constructed. This is not an error, "
            "but may impair performance.\n"
            "grad.sizes() = ", grad.sizes(),
            ", strides() = ", grad.strides(), "\n",
            "bucket_view.sizes() = ", bucket_view.sizes(),
            ", strides() = ", bucket_view.strides());
    }

    // With gradient_as_bucket_view_ the gradient is allowed to alias the view.
    if (gradient_as_bucket_view_) {
        return;
    }
    TORCH_INTERNAL_ASSERT(!grad.is_alias_of(bucket_view),
                          DIST_ERROR(ErrCode::PARAM));
}
void Reducer::mark_variable_ready_dense(size_t variable_index)
{
const auto replica_index = 0;
const auto& bucket_index = variable_locators_[variable_index];
auto& bucket = buckets_[bucket_index.bucket_index];
auto& replica = bucket.replicas[replica_index];
auto& variable = replica.variables[bucket_index.intra_bucket_index];
auto& bucket_view = replica.bucket_views_in[bucket_index.intra_bucket_index];
runGradCallbackForVariable(variable, [&](auto& grad) {
if (grad.defined()) {
this->check_grad_layout(grad, bucket_view);
if (!grad.is_alias_of(bucket_view)) {
if (torch_npu::NPUBridge::GetNpuStorageImpl(grad)->npu_desc_.npu_format_ !=
torch_npu::NPUBridge::GetNpuStorageImpl(variable)->npu_desc_.npu_format_) {
grad = at_npu::native::NPUNativeFunctions::npu_format_cast(grad,
torch_npu::NPUBridge::GetNpuStorageImpl(variable)->npu_desc_.npu_format_);
}
if (comm_hook_ == nullptr) {
if (!grad.requires_grad()) {
at_npu::native::NPUNativeFunctions::copy_memory_(bucket_view, grad.mul(float(1.) / div_factor_), true);
} else {
C10_LOG_EVERY_N(WARNING, 1000)
<< "Using DistributedDataParallel with create_graph=True "
<< " is not well-supported. The higher-order gradient will "
<< " not be synchronized across ranks, and backpropagation "
<< " through all_reduce operations will not occur.";
at_npu::native::NPUNativeFunctions::copy_memory_(bucket_view, grad.mul(float(1.) / div_factor_), true);
}
} else {
at_npu::native::NPUNativeFunctions::copy_memory_(bucket_view, grad, true);
}
if (gradient_as_bucket_view_) {
grad = bucket_view;
return true;
}
} else {
if (comm_hook_ == nullptr) {
bucket_view.div_(div_factor_);
}
}
} else {
if (this->dynamic_graph_find_unused() ||
this->static_graph_first_iteration()) {
REDUCER_CHECK(
local_used_map_[variable_index].item<int>() == 0,
logger_,
"Encountered gradient which is undefined, but still allreduced by "
"DDP reducer. This indicates a bug in DDP implementation, please "
"report a bug with a repro to PyTorch.", DIST_ERROR(ErrCode::INTERNAL)
);
}
bucket_view.zero_();
}
return false;
});
}
void Reducer::mark_variable_ready_sparse(size_t variable_index)
{
const auto replica_index = 0;
const auto& bucket_index = variable_locators_[variable_index];
auto& bucket = buckets_[bucket_index.bucket_index];
auto& replica = bucket.replicas[replica_index];
auto& variable = replica.variables[bucket_index.intra_bucket_index];
runGradCallbackForVariable(variable, [&](auto& grad) {
REDUCER_CHECK(
grad.defined(), logger_, "Expected sparse gradient to be defined.", DIST_ERROR(ErrCode::PARAM));
REDUCER_CHECK(
grad.options().layout() == c10::kSparse,
logger_,
"Expected variable to have sparse gradient.", DIST_ERROR(ErrCode::TYPE));
replica.contents = grad;
if (comm_hook_ == nullptr) {
replica.contents.div_(div_factor_);
}
return true;
});
}
std::vector<c10d::GradBucket> Reducer::get_grad_buckets(
bool return_zero_tensors) const
{
std::lock_guard<std::mutex> lock(mutex_);
std::vector<c10d::GradBucket> gradBuckets;
gradBuckets.reserve(buckets_.size());
for (const auto i : c10::irange(buckets_.size())) {
auto& bucket = buckets_[i];
auto variables_for_bucket = get_variables_for_bucket(i, bucket);
gradBuckets.emplace_back(
i,
buckets_.size(),
return_zero_tensors ? at::zeros_like(bucket.replicas[0].contents)
: bucket.replicas[0].contents,
bucket.replicas[0].offsets,
bucket.replicas[0].lengths,
bucket.replicas[0].sizes_vec,
variables_for_bucket,
c10::nullopt);
}
return gradBuckets;
}
// Stores the forward-pass work handle and the static-world-size flag under
// mutex_. set_divide_factor() later consults both.
void Reducer::set_forward_pass_work_handle(
    c10::intrusive_ptr<c10d::Work> forwardPassWorkHandle,
    bool useStaticWorldSize)
{
    std::lock_guard<std::mutex> guard(mutex_);
    auto& handle = forwardPassWorkHandle_;
    handle.workHandle = std::move(forwardPassWorkHandle);
    handle.useStaticWorldSize = useStaticWorldSize;
}
// Returns the device-side local-used map tensor, read while holding mutex_.
at::Tensor Reducer::get_local_used_map_on_device() const
{
    std::lock_guard<std::mutex> guard(mutex_);
    at::Tensor device_map = local_used_map_dev_;
    return device_map;
}
// Under mutex_, records every parameter (in index order) as a rebuilt
// parameter, but only if buckets should be rebuilt and no rebuilt index has
// been recorded yet.
void Reducer::push_rebuilt_params_for_all_indices()
{
    std::lock_guard<std::mutex> guard(mutex_);
    if (should_rebuild_buckets() && rebuilt_param_indices_.empty()) {
        const auto total = params_.size();
        for (size_t idx = 0; idx < total; ++idx) {
            push_rebuilt_params(idx);
        }
    }
}
void Reducer::push_rebuilt_params(const size_t& index)
{
rebuilt_params_.push_back(params_[index]);
rebuilt_param_indices_.push_back(index);
}
// Computes div_factor_ once. It defaults to the process group size; when a
// forward-pass work handle is present and a static world size was not
// requested, the handle is waited on and the factor is taken from the first
// result tensor instead.
void Reducer::set_divide_factor()
{
    // Already computed in an earlier call.
    if (div_factor_ != kUnsetDivFactor) {
        return;
    }

    div_factor_ = process_group_->getSize();

    auto& pending_work = forwardPassWorkHandle_.workHandle;
    if (!pending_work || forwardPassWorkHandle_.useStaticWorldSize) {
        return;
    }

    pending_work->wait();
    auto results = pending_work->result();
    TORCH_INTERNAL_ASSERT(results.size() > 0, DIST_ERROR(ErrCode::INTERNAL));
    div_factor_ = results.front().item().to<int>();
}
void Reducer::delay_all_reduce()
{
std::lock_guard<std::mutex> lock(this->mutex_);
if (should_collect_runtime_stats()) {
record_backward_compute_end_time();
record_backward_comm_start_time();
}
all_reduce_local_used_map();
unused_parameters_.clear();
for (const auto variable_index : c10::irange(params_.size())) {
if (numGradHooksTriggeredMap_[variable_index] == 0) {
unused_parameters_.push_back(variable_index);
}
require_finalize_ = true;
set_divide_factor();
if (expect_sparse_gradients_[variable_index]) {
mark_variable_ready_sparse(variable_index);
} else {
mark_variable_ready_dense(variable_index);
}
}
for (auto& bucket : buckets_) {
all_reduce_bucket(bucket);
}
finalize_backward();
}
// Installs the logger used by REDUCER_CHECK and graph-static reporting.
// `logger` is a by-value sink parameter, so it is moved into the member
// rather than copied a second time.
void Reducer::set_logger(std::weak_ptr<c10d::Logger> logger)
{
    logger_ = std::move(logger);
}
// Post hook invoked for the parameter at `index` (registered by the
// constructor on its gradient accumulator). Runs under mutex_.
//
// Does nothing unless prepare_for_backward() armed the reducer. Otherwise it
// records the ready order, updates the local-used map when unused parameters
// are tracked, and then either
//   - only counts the hook (static graph, first iteration), or
//   - marks the previously found unused parameters ready (once per
//     iteration) and marks this parameter ready; for a static graph after
//     the first iteration this happens only when its per-iteration hook
//     counter reaches zero.
void Reducer::autograd_hook(size_t index)
{
    std::lock_guard<std::mutex> guard(this->mutex_);

    // Hooks that fire outside a prepared backward pass are ignored.
    if (!expect_autograd_hooks_) {
        return;
    }

    grad_ready_order_indices_.push_back(index);

    if (dynamic_graph_find_unused() || static_graph_first_iteration()) {
        auto& param = get_param_from_index(index);
        runGradCallbackForVariable(param, [&](auto& grad) {
            if (grad.defined()) {
                local_used_map_[index] = 1;
            }
            // The gradient itself is not modified here.
            return false;
        });
    }

    if (static_graph_first_iteration()) {
        numGradHooksTriggeredMap_[index] += 1;
        return;
    }

    if (!has_marked_unused_parameters_) {
        has_marked_unused_parameters_ = true;
        for (const auto& unused_index : unused_parameters_) {
            mark_variable_ready(unused_index);
        }
    }

    if (!static_graph_after_first_iteration()) {
        if (should_rebuild_buckets()) {
            push_rebuilt_params(index);
        }
        mark_variable_ready(index);
        return;
    }

    REDUCER_CHECK(
        numGradHooksTriggeredMapPerIteration_[index] > 0,
        logger_,
        "Your training graph has changed in this iteration, ",
        "e.g., one parameter is unused in first iteration, but ",
        "then got used in the second iteration. this is not ",
        "compatible with static_graph set to True.", DIST_ERROR(ErrCode::NOT_SUPPORT));
    // Only the last expected hook of this iteration marks the parameter ready.
    if (--numGradHooksTriggeredMapPerIteration_[index] == 0) {
        if (should_rebuild_buckets()) {
            push_rebuilt_params(index);
        }
        mark_variable_ready(index);
    }
}
// Copies the local-used map into its device tensor (non-blocking flag set)
// and starts an allreduce on it; the returned work is kept in
// local_used_work_.
void Reducer::all_reduce_local_used_map()
{
    local_used_map_dev_.copy_(local_used_map_, true);

    std::vector<at::Tensor> tensors_to_reduce{local_used_map_dev_};
    local_used_work_ = process_group_->allreduce(tensors_to_reduce);
}
// Returns the tensor stored in the first bucket replica for the parameter at
// `index`, resolved through variable_locators_.
at::Tensor& Reducer::get_param_from_index(size_t index)
{
    const auto& locator = variable_locators_[index];
    auto& first_replica = buckets_[locator.bucket_index].replicas[0];
    return first_replica.variables[locator.intra_bucket_index];
}
// Raises an error if the parameter at `index` has already been marked ready
// during the current iteration (i.e. it is present in
// perIterationReadyParams_). Returns normally otherwise.
//
// The error text names the parameter index and, when a name is registered in
// param_names_, the parameter name as well.
void Reducer::checkAndRaiseMarkedTwiceError(size_t index)
{
bool marked_twice =
perIterationReadyParams_.find(index) != perIterationReadyParams_.end();
if (marked_twice) {
auto param_name = param_names_.find(index);
const bool found_param_name = param_name != param_names_.end();
// A parameter name is required whenever the debug level is not Off.
TORCH_INTERNAL_ASSERT(
ddp_debug_level_ == c10d::DebugLevel::Off ||
found_param_name,
"Expected to find parameter name in debug mode.", DIST_ERROR(ErrCode::PARAM));
std::string paramInfo = c10::str(
"Parameter at index ",
index,
found_param_name ? c10::str(" with name ", param_name->second) : "",
" has been marked as ready twice. This means that multiple autograd engine ",
" hooks have fired for this particular parameter during this iteration.");
// Without a name, tell the user how to get names printed.
if (!found_param_name) {
paramInfo += c10::str(
" You can set the environment variable TORCH_DISTRIBUTED_DEBUG to either",
" INFO or DETAIL to print parameter names for further debugging.");
}
std::string common_error = c10::str(
"Expected to mark a variable ready only once. ",
"",
"This error is caused by one of the following reasons: ",
"1) Use of a module parameter outside the `forward` function. ",
"Please make sure model parameters are not shared across multiple ",
"concurrent forward-backward passes. or try to use _set_static_graph() ",
"as a workaround if this module graph does not change ",
"during training loop.",
"2) Reused parameters in multiple reentrant backward passes. For ",
"example, if you use multiple `checkpoint` functions to wrap the ",
"same part of your model, it would result in the same set of ",
"parameters been used by different reentrant backward passes ",
"multiple times, and hence marking a variable ready multiple times. ",
"DDP does not support such use cases in default. You can try to ",
"use _set_static_graph() as a workaround if your module graph ",
"does not change over iterations.");
common_error += c10::str("\n", paramInfo);
// The two checks below test opposite conditions, so exactly one of them
// raises. The first one (with reason "3)") fires when
// has_marked_unused_parameters_ is false; the second fires when it is true.
// NOTE(review): reason "3)" talks about unused-parameter detection, yet it is
// attached to the case where unused parameters have NOT been marked --
// confirm whether the two conditions are swapped.
REDUCER_CHECK(
has_marked_unused_parameters_,
logger_,
common_error,
"3) Incorrect unused parameter detection. The return value of the ",
"`forward` function is inspected by the distributed data parallel ",
"wrapper to figure out if any of the module's parameters went ",
"unused. For unused parameters, DDP would not expect gradients from ",
"then. However, if an unused parameter becomes part of the autograd ",
"graph at a later point in time (e.g., in a reentrant backward when ",
"using `checkpoint`), the gradient will show up unexpectedly. If all ",
"parameters in the model participate in the backward pass, you can ",
"disable unused parameter detection by passing the keyword argument ",
"`find_unused_parameters=False` to ",
"`torch.nn.parallel.DistributedDataParallel`. If unused parameters ",
"in the model do not change over iterations, You can try to use ",
"_set_static_graph() as a workaround if this module graph does not ",
"change during training loop.", DIST_ERROR(ErrCode::PARAM));
REDUCER_CHECK(!has_marked_unused_parameters_, logger_, common_error, DIST_ERROR(ErrCode::PARAM));
}
}
void Reducer::mark_variable_ready(size_t variable_index)
{
REDUCER_CHECK(variable_index < variable_locators_.size(), logger_,
"Out of range variable index.", DIST_ERROR(ErrCode::PARAM));
checkAndRaiseMarkedTwiceError(variable_index);
perIterationReadyParams_.insert(variable_index);
backward_stats_[variable_index] =
current_time_in_nanos() - backward_compute_start_time_;
require_finalize_ = true;
const auto& bucket_index = variable_locators_[variable_index];
auto& bucket = buckets_[bucket_index.bucket_index];
auto& replica = bucket.replicas[0];
set_divide_factor();
if (bucket.expect_sparse_gradient) {
mark_variable_ready_sparse(variable_index);
} else {
mark_variable_ready_dense(variable_index);
}
if (--replica.pending == 0) {
if (--bucket.pending == 0) {
mark_bucket_ready(bucket_index.bucket_index);
}
}
if (next_bucket_ == buckets_.size()) {
if (dynamic_graph_find_unused()) {
all_reduce_local_used_map();
}
torch::autograd::Engine::get_default_engine().queue_callback([=] {
std::lock_guard<std::mutex> lock(this->mutex_);
if (should_collect_runtime_stats()) {
record_backward_compute_end_time();
}
TORCH_INTERNAL_ASSERT(next_bucket_ == buckets_.size(),
DIST_ERROR(ErrCode::INTERNAL));
if (static_graph_after_first_iteration() &&
should_rebuild_buckets()) {
for (const auto& unused_index : unused_parameters_) {
push_rebuilt_params(unused_index);
}
}
this->finalize_backward();
});
}
}
// Runs the installed communication hook on `grad_bucket`; when no hook is
// installed, falls back to the built-in allreduce hook.
c10::intrusive_ptr<c10::ivalue::Future> Reducer::run_comm_hook(
    c10d::GradBucket& grad_bucket)
{
    if (comm_hook_ != nullptr) {
        return comm_hook_->runHook(grad_bucket);
    }
    return run_allreduce_hook(grad_bucket);
}
// Runs a c10d::_AllReduceBySumCommHook, built on this reducer's process
// group, on `grad_bucket` and returns the resulting future.
c10::intrusive_ptr<c10::ivalue::Future> Reducer::run_allreduce_hook(
    c10d::GradBucket& grad_bucket)
{
    c10d::_AllReduceBySumCommHook sum_hook(process_group_);
    auto future = sum_hook.runHook(grad_bucket);
    return future;
}
// Wraps the contents of `bucket` in a c10d::GradBucket and launches the
// communication hook on it; the returned future is stored in
// bucket.future_work. The GradBucket index passed is the current value of
// next_bucket_.
void Reducer::all_reduce_bucket(Bucket& bucket)
{
    // Collect the contents tensor of every replica; only the first is used.
    std::vector<at::Tensor> tensors;
    tensors.reserve(bucket.replicas.size());
    for (auto it = bucket.replicas.begin(); it != bucket.replicas.end(); ++it) {
        tensors.push_back(it->contents);
    }

    auto variables_for_bucket = get_variables_for_bucket(next_bucket_, bucket);
    auto& first_replica = bucket.replicas[0];
    c10d::GradBucket grad_bucket(
        next_bucket_,
        buckets_.size(),
        tensors[0],
        first_replica.offsets,
        first_replica.lengths,
        first_replica.sizes_vec,
        variables_for_bucket,
        c10::nullopt);
    bucket.future_work = run_comm_hook(grad_bucket);
}
std::vector<at::Tensor> Reducer::get_variables_for_bucket(
size_t bucket_index,
const Bucket& bucket) const
{
if (has_rebuilt_bucket_ &&
cached_variables_for_bucket_.find(bucket_index) !=
cached_variables_for_bucket_.end()) {
return cached_variables_for_bucket_[bucket_index];
}
std::vector<at::Tensor> variables_for_bucket;
variables_for_bucket.reserve(bucket.variable_indices.size());
for (const auto& variable_index : bucket.variable_indices) {
auto& replica = bucket.replicas[0];
auto& bucket_index_for_variable = variable_locators_[variable_index];
auto& variable =
replica.variables[bucket_index_for_variable.intra_bucket_index];
variables_for_bucket.emplace_back(variable);
}
if (has_rebuilt_bucket_) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
cached_variables_for_bucket_.find(bucket_index) ==
cached_variables_for_bucket_.end(), DIST_ERROR(ErrCode::PARAM));
cached_variables_for_bucket_.insert(
{bucket_index, std::move(variables_for_bucket)});
return cached_variables_for_bucket_[bucket_index];
} else {
return variables_for_bucket;
}
}
// Launches reduction for the bucket at `bucket_index` and for every
// consecutive following bucket that has no pending work. Buckets are launched
// strictly in index order: a bucket ahead of next_bucket_ is left for a later
// call.
void Reducer::mark_bucket_ready(size_t bucket_index)
{
    TORCH_INTERNAL_ASSERT(bucket_index >= next_bucket_,
                          DIST_ERROR(ErrCode::PARAM));

    // Not this bucket's turn yet.
    if (bucket_index > next_bucket_) {
        return;
    }

    while (next_bucket_ < buckets_.size() && buckets_[next_bucket_].pending == 0) {
        ++num_buckets_ready_;
        // The communication start time is recorded with the first ready bucket.
        if (num_buckets_ready_ == 1 && should_collect_runtime_stats()) {
            record_backward_comm_start_time();
        }
        all_reduce_bucket(buckets_[next_bucket_]);
        ++next_bucket_;
    }
}
// Stores `futs` as the installed futures, or appends them when a list has
// already been installed.
void Reducer::install_futures(c10::List<c10::intrusive_ptr<c10::ivalue::Future>> futs)
{
    if (installed_futures_) {
        installed_futures_->append(futs);
        return;
    }
    installed_futures_ = std::move(futs);
}
// (Re)builds buckets_ and variable_locators_ from the given assignment.
//
// bucket_indices[b] lists the parameter indices placed in bucket b and
// per_bucket_sizes[b] is stored as that bucket's size limit; both vectors
// must have the same length. Must not be called while autograd hooks are
// expected (i.e. during backward).
//
// A single-parameter bucket may hold a sparse gradient; in that case the
// replica only references the parameter. Every other bucket gets one flat
// contents tensor sized by the sum of physical_numel() over its parameters,
// plus per-parameter views created by initialize_bucket_views().
void Reducer::initialize_buckets(
std::vector<std::vector<size_t>> bucket_indices,
std::vector<size_t> per_bucket_sizes)
{
#ifndef _WIN32
using torch::distributed::autograd::ThreadLocalDistAutogradContext;
this->rpc_context_.set(ThreadLocalDistAutogradContext::getContextPtr());
#endif
REDUCER_CHECK(
!expect_autograd_hooks_,
logger_,
"`initialize_buckets` must NOT be called during autograd execution.", DIST_ERROR(ErrCode::PARAM));
// Drop any previous bucket layout before building the new one.
buckets_.clear();
variable_locators_.clear();
variable_locators_.resize(params_.size());
const auto bucket_count = bucket_indices.size();
buckets_.reserve(bucket_count);
TORCH_INTERNAL_ASSERT(bucket_count == per_bucket_sizes.size(), DIST_ERROR(ErrCode::PARAM));
for (const auto bucket_index : c10::irange(bucket_count)) {
Bucket bucket;
bucket.bucket_size_limit = per_bucket_sizes[bucket_index];
REDUCER_CHECK(
bucket_indices[bucket_index].size() > 0,
logger_,
"Empty bucket specified.", DIST_ERROR(ErrCode::PARAM));
// Only a single-parameter bucket may expect a sparse gradient.
// NOTE(review): expect_sparse_gradients_ and params_ are indexed here before
// variable_index has been range-checked (the assertions come later in this
// loop body) -- confirm callers always pass valid indices.
if (bucket_indices[bucket_index].size() == 1) {
const auto variable_index = bucket_indices[bucket_index].front();
bucket.expect_sparse_gradient =
expect_sparse_gradients_[variable_index];
} else {
for (const auto variable_index : bucket_indices[bucket_index]) {
REDUCER_CHECK(
!expect_sparse_gradients_[variable_index],
logger_,
"Buckets with more than one variable cannot include variables ",
"that expect a sparse gradient.", DIST_ERROR(ErrCode::PARAM));
}
}
BucketReplica replica;
if (bucket.expect_sparse_gradient) {
// Sparse bucket: no flat buffer, the replica just references the parameter.
const auto variable_index = bucket_indices[bucket_index].front();
const auto& variable = params_[variable_index];
TORCH_INTERNAL_ASSERT(bucket_indices[bucket_index].size() == 1, DIST_ERROR(ErrCode::PARAM));
replica.variables = {variable};
} else {
// Dense bucket: device and dtype are taken from the first parameter and
// every further parameter must match them.
at::TensorOptions options;
size_t offset = 0;
const size_t num_variables = bucket_indices[bucket_index].size();
replica.variables.reserve(num_variables);
replica.offsets.reserve(num_variables);
replica.lengths.reserve(num_variables);
replica.sizes_vec.reserve(num_variables);
for (const auto variable_index : bucket_indices[bucket_index]) {
TORCH_INTERNAL_ASSERT(
variable_index < params_.size(),
"Out of range variable index specified.", DIST_ERROR(ErrCode::PARAM));
const auto& variable = params_[variable_index];
if (!options.has_device()) {
options = options.device(variable.device());
} else {
REDUCER_CHECK(
variable.device() == options.device(),
logger_,
"All parameters in a bucket must be ",
"placed on the same device.", DIST_ERROR(ErrCode::PARAM));
}
if (!options.has_dtype()) {
options = options.dtype(variable.dtype());
} else {
REDUCER_CHECK(
variable.dtype() == options.dtype(),
logger_,
"All parameters in a bucket must have the same dtype.", DIST_ERROR(ErrCode::TYPE));
}
// The slot length comes from the NPU storage sizes, not variable.numel().
const auto length = physical_numel(variable);
replica.variables.push_back(variable);
replica.offsets.push_back(offset);
replica.lengths.push_back(length);
replica.sizes_vec.push_back(variable.sizes());
offset += length;
}
// `offset` now equals the total number of elements needed by the bucket.
replica.contents = at::empty({static_cast<long>(offset)}, options);
initialize_bucket_views(replica, replica.contents);
}
bucket.replicas.push_back(std::move(replica));
// Record, for every parameter, which bucket and which slot it lives in.
size_t intra_bucket_index = 0;
for (const auto variable_index : bucket_indices[bucket_index]) {
TORCH_INTERNAL_ASSERT(
variable_index < variable_locators_.size(),
"Out of range variable index specified.", DIST_ERROR(ErrCode::PARAM));
variable_locators_[variable_index] =
VariableLocator(bucket_index, intra_bucket_index++);
}
bucket.variable_indices = std::move(bucket_indices[bucket_index]);
buckets_.push_back(std::move(bucket));
}
}
// Creates one view into `contents` per variable of `replica` and appends it
// to replica.bucket_views_in.
//
// Without gradient_as_bucket_view_ each view is the flat slice
// [offset, offset + length) of `contents`. With gradient_as_bucket_view_ the
// slice is additionally viewed with the variable's sizes, and an already
// defined gradient that does not alias the view is copied into it and then
// rebound to the view.
//
// bucket_views_out is set to a copy of bucket_views_in once, after all views
// exist, instead of being re-copied on every loop iteration.
void Reducer::initialize_bucket_views(
    Reducer::BucketReplica& replica,
    at::Tensor& contents)
{
    for (const auto i : c10::irange(replica.variables.size())) {
        auto& v = replica.variables[i];
        const auto offset = replica.offsets[i];
        const auto length = replica.lengths[i];
        auto flat_view = contents.narrow(0, offset, length);

        if (!gradient_as_bucket_view_) {
            replica.bucket_views_in.push_back(std::move(flat_view));
            continue;
        }

        replica.bucket_views_in.push_back(flat_view.view(v.sizes()));
        auto& bucket_view = replica.bucket_views_in.back();
        runGradCallbackForVariable(v, [&](auto& grad) {
            // Move an existing gradient into the bucket and alias it.
            if (grad.defined() && !grad.is_alias_of(bucket_view)) {
                bucket_view.copy_(grad);
                grad = bucket_view;
                return true;
            }
            return false;
        });
    }

    // Output views start out identical to the input views.
    replica.bucket_views_out = replica.bucket_views_in;
}
// Rebuilds replica.bucket_views_out as flat slices of `tensor`, one per
// variable of the replica, using the replica's recorded offsets and lengths.
// Any previous output views are discarded first.
void Reducer::populate_bucket_views_out(
    Reducer::BucketReplica& replica,
    at::Tensor& tensor) const
{
    const size_t variable_count = replica.variables.size();
    replica.bucket_views_out.clear();
    // The final size is known up front, so allocate once.
    replica.bucket_views_out.reserve(variable_count);
    for (size_t i = 0; i < variable_count; i++) {
        const auto offset = replica.offsets[i];
        const auto length = replica.lengths[i];
        replica.bucket_views_out.push_back(
            tensor.narrow(0, offset, length));
    }
}
// Advances the iteration counter under mutex_ and, when runtime statistics
// are being collected, records the forward compute start time.
void Reducer::prepare_for_forward()
{
    std::lock_guard<std::mutex> guard(mutex_);
    ++num_iterations_;
    if (!should_collect_runtime_stats()) {
        return;
    }
    record_forward_compute_start_time();
}
void Reducer::reset_bucket_counting()
{
next_bucket_ = 0;
num_buckets_ready_ = 0;
for (auto& bucket : buckets_) {
for (auto& replica : bucket.replicas) {
replica.pending = replica.variables.size();
}
bucket.pending = bucket.replicas.size();
}
if (static_graph_) {
numGradHooksTriggeredMapPerIteration_ = numGradHooksTriggeredMap_;
}
}
// Finds the parameters that do not take part in the autograd graph rooted at
// `outputs` and appends their indices to unused_parameters_.
//
// The graph is walked from the grad_fn of every output along next_edges();
// a parameter counts as unused when its gradient accumulator (as registered
// in gradAccToVariableMap_) was never reached. At the Detail debug level each
// unused parameter is logged by name. A one-time warning is emitted when no
// unused parameter was found.
//
// For a graph not declared static, the set of unused parameters is compared
// with the previous iteration's; the first difference clears
// ddp_graph_static_ and is reported to the logger when one is available.
void Reducer::search_unused_parameters(
    const std::vector<torch::autograd::Variable>& outputs)
{
    std::unordered_set<torch::autograd::Node*> seen;
    std::vector<torch::autograd::Node*> queue;
    RECORD_FUNCTION(
        "torch.distributed.ddp.reducer::search_unused_parameters",
        std::vector<c10::IValue>());

    // Seed the traversal with the grad functions of the outputs.
    for (const auto& output : outputs) {
        const auto& grad_fn = output.grad_fn();
        if (grad_fn) {
            queue.push_back(grad_fn.get());
        }
    }

    // Depth-first walk; `seen` collects every node reached through an edge.
    while (!queue.empty()) {
        auto fn = queue.back();
        queue.pop_back();
        for (const auto& edge : fn->next_edges()) {
            if (auto next_ptr = edge.function.get()) {
                const bool was_inserted = seen.insert(next_ptr).second;
                if (was_inserted) {
                    queue.push_back(next_ptr);
                }
            }
        }
    }

    for (const auto& it : gradAccToVariableMap_) {
        if (seen.count(it.first) == 0) {
            if (ddp_debug_level_ == c10d::DebugLevel::Detail) {
                const auto param_info = param_names_.find(it.second);
                TORCH_INTERNAL_ASSERT(
                    param_info != param_names_.end(),
                    "Did not find variable index ",
                    it.second,
                    " in DDP parameter name mapping!", DIST_ERROR(ErrCode::PARAM));
                const auto& param_name = param_info->second;
                LOG(INFO) << "[Rank " << process_group_->getRank() << "]: "
                          << "Parameter " << param_name << " at index " << it.second
                          << " is marked as unused.";
            }
            unused_parameters_.push_back(it.second);
        }
    }

    if (unused_parameters_.empty()) {
        TORCH_WARN_ONCE(
            "find_unused_parameters=True was specified in DDP constructor, "
            "but did not find any unused parameters in the forward pass. This flag "
            "results in an extra traversal of the autograd graph every iteration, "
            " which can adversely affect performance. If your model indeed never "
            "has any unused parameters in the forward pass, consider turning this "
            "flag off. Note that this warning may be a false positive if your model "
            "has flow control causing later iterations to have unused parameters.");
    }

    if (!static_graph_ && ddp_graph_static_) {
        if (num_iterations_ > 1) {
            ddp_graph_static_ =
                prev_iteration_unused_parameters_ == unused_parameters_;
            if (!ddp_graph_static_) {
                // The logger is held weakly and may be unset or expired;
                // promote it before use instead of dereferencing blindly.
                if (auto logger = logger_.lock()) {
                    logger->log_if_graph_static(false);
                }
            }
        }
        prev_iteration_unused_parameters_ = unused_parameters_;
    }
}
// Arms the reducer for an upcoming backward pass (under mutex_): records the
// backward start time, enables autograd hooks, clears the per-iteration
// bookkeeping and, when unused parameters are tracked on a dynamic graph,
// recomputes unused_parameters_ from `outputs`.
void Reducer::prepare_for_backward(
    const std::vector<torch::autograd::Variable>& outputs)
{
    std::lock_guard<std::mutex> guard(mutex_);

    backward_compute_start_time_ = current_time_in_nanos();
    if (should_collect_runtime_stats()) {
        record_backward_compute_start_time();
    }

    // From here on autograd_hook() reacts to fired hooks.
    expect_autograd_hooks_ = true;
    grad_ready_order_indices_.clear();

    reset_bucket_counting();

    has_marked_unused_parameters_ = false;
    perIterationReadyParams_.clear();

    if (!dynamic_graph_find_unused()) {
        return;
    }
    unused_parameters_.clear();
    search_unused_parameters(outputs);
}
// Copies the reduced values from the output bucket view at
// `intra_bucket_index` into the gradient of `variable`. A globally unused
// parameter is left untouched (the callback returns false). An undefined
// gradient is first allocated with the variable's sizes, the view's tensor
// options and the variable's NPU format.
void Reducer::copy_bucket_to_grad(
    torch::autograd::Variable& variable,
    Reducer::BucketReplica& replica,
    size_t intra_bucket_index,
    bool global_unused)
{
    const auto& bucket_view = replica.bucket_views_out[intra_bucket_index];
    runGradCallbackForVariable(variable, [&](auto& grad) {
        if (global_unused) {
            return false;
        }
        if (!grad.defined()) {
            grad = at_npu::native::OpPreparation::ApplyTensorWithFormat(
                variable.sizes(), bucket_view.options(),
                torch_npu::NPUBridge::GetNpuStorageImpl(variable)->npu_desc_.npu_format_);
        }
        at_npu::native::NPUNativeFunctions::copy_memory_(grad, bucket_view, true);
        return true;
    });
}
// Returns the names (from param_names_) of all parameters that have not been
// marked ready in the current iteration.
std::vector<std::string> Reducer::getUnmarkedParamsForIteration()
{
    std::vector<std::string> unmarked_names;
    for (auto entry = param_names_.begin(); entry != param_names_.end(); ++entry) {
        const bool is_marked = perIterationReadyParams_.count(entry->first) != 0;
        if (!is_marked) {
            unmarked_names.push_back(entry->second);
        }
    }
    return unmarked_names;
}
// Returns, in ascending order, the indices of all parameters that have not
// been marked ready in the current iteration.
std::vector<size_t> Reducer::getUnmarkedParamIndicesForIteration()
{
    std::vector<size_t> pending_indices;
    const auto total = params_.size();
    for (size_t idx = 0; idx < total; ++idx) {
        if (perIterationReadyParams_.count(idx) == 0) {
            pending_indices.push_back(idx);
        }
    }
    return pending_indices;
}
// Writes the reduced contents of a dense bucket back to the gradients of its
// variables.
//
// When unused parameters are tracked, a variable is "globally unused" if its
// entry in the local-used map is zero; the first time a zero is seen in this
// state the pending local-used allreduce is waited on and the map is
// refreshed from the device tensor before deciding.
//
// Without gradient_as_bucket_view_ the output view is copied into the
// gradient. With it, the output view is copied into the input view (unless
// they alias) and the gradient is bound to the input view; a defined gradient
// that does not alias the input view is an error.
void Reducer::finalize_bucket_dense(Bucket& bucket)
{
    auto& replica = bucket.replicas[0];
    const auto variable_count = replica.variables.size();

    for (size_t intra_bucket_index = 0; intra_bucket_index < variable_count;
         ++intra_bucket_index) {
        auto& variable = replica.variables[intra_bucket_index];

        bool global_unused = false;
        if (static_graph_ || find_unused_parameters_) {
            const size_t variable_index = bucket.variable_indices[intra_bucket_index];
            const auto locally_unused = [&]() {
                return local_used_map_[variable_index].item<int>() == 0;
            };
            global_unused = locally_unused();
            if (global_unused && !local_used_map_reduced_) {
                // Pull in the reduced map before trusting a zero entry.
                local_used_work_->wait();
                local_used_map_.copy_(local_used_map_dev_);
                global_unused = locally_unused();
                local_used_map_reduced_ = true;
            }
        }

        if (!gradient_as_bucket_view_) {
            RECORD_FUNCTION(
                "torch.distributed.ddp.reducer::copy_bucket_to_grad",
                std::vector<c10::IValue>({variable}));
            copy_bucket_to_grad(variable, replica, intra_bucket_index, global_unused);
            continue;
        }

        const auto& bucket_view_out = replica.bucket_views_out[intra_bucket_index];
        auto& bucket_view_in = replica.bucket_views_in[intra_bucket_index];
        if (!bucket_view_in.is_alias_of(bucket_view_out)) {
            bucket_view_in.copy_(bucket_view_out);
        }

        runGradCallbackForVariable(variable, [&](auto& grad) {
            if (global_unused) {
                return false;
            }
            if (!grad.defined()) {
                grad = bucket_view_in;
            } else if (!grad.is_alias_of(bucket_view_in)) {
                REDUCER_CHECK(
                    false,
                    logger_,
                    "Detected at least one parameter gradient is not the "
                    "expected DDP bucket view with gradient_as_bucket_view=True. "
                    "This may happen (for example) if multiple allreduce hooks "
                    "were registered onto the same parameter. If you hit this error, "
                    "please file an issue with a minimal repro.", DIST_ERROR(ErrCode::PARAM));
            }
            return true;
        });
    }
}
// Finishes the current backward pass: waits for each bucket's pending
// communication, hands the result to the bucket, and resets the per-iteration
// state touched here (flags, div_factor_, installed_futures_, local_used_map_).
void Reducer::finalize_backward()
{
    // Both flags must be set on entry; they are cleared for the next iteration.
    TORCH_INTERNAL_ASSERT(expect_autograd_hooks_,
                          DIST_ERROR(ErrCode::INTERNAL));
    expect_autograd_hooks_ = false;
    TORCH_INTERNAL_ASSERT(require_finalize_, DIST_ERROR(ErrCode::INTERNAL));
    require_finalize_ = false;

    for (auto& bucket : buckets_) {
        // The two paths differ only in the assertion message and in how the
        // future's value is converted to a tensor.
        const bool no_custom_hook = (comm_hook_ == nullptr);
        if (no_custom_hook) {
            TORCH_INTERNAL_ASSERT(
                bucket.future_work,
                "Expected bucket.work not to be null. "
                "This may indicate that allreduce hooks were not "
                "properly installed.",
                DIST_ERROR(ErrCode::PARAM));
        } else {
            TORCH_INTERNAL_ASSERT(bucket.future_work,
                                  "Expected bucket.future_work not to be null. "
                                  "This may indicate that communication hook "
                                  "was not properly installed.",
                                  DIST_ERROR(ErrCode::PARAM));
        }

        bucket.future_work->wait();
        auto reduced = no_custom_hook
            ? c10d::detail::parseCppCommHookResult(bucket.future_work->value())
            : comm_hook_->parseHookResult(bucket.future_work->value());

        auto& first_replica = bucket.replicas[0];
        if (bucket.expect_sparse_gradient) {
            first_replica.contents.copy_(reduced);
        } else {
            populate_bucket_views_out(first_replica, reduced);
        }

        div_factor_ = kUnsetDivFactor;

        // Sparse buckets are not passed through finalize_bucket_dense.
        if (!bucket.expect_sparse_gradient) {
            finalize_bucket_dense(bucket);
        }
    }

    if (installed_futures_.has_value()) {
        c10::collectAll(*installed_futures_)->wait();
        installed_futures_ = c10::nullopt;
    }

    // Make sure the pending local_used_work_ has completed if nothing waited
    // on it earlier in this pass.
    if ((dynamic_graph_find_unused() || static_graph_first_iteration()) &&
        !local_used_map_reduced_) {
        local_used_work_->wait();
    }

    if (dynamic_graph_find_unused()) {
        local_used_map_.fill_(0);
        local_used_map_reduced_ = false;
    }

    if (should_collect_runtime_stats()) {
        record_backward_comm_end_time();
    }
}
// Invokes `cb` on the gradient of `variable`. On non-Windows builds, when an
// RPC context pointer is set, the call is forwarded to that context instead;
// otherwise the callback runs directly on variable.mutable_grad().
void Reducer::runGradCallbackForVariable(
    at::Tensor& variable,
    GradCallback&& cb)
{
#ifndef _WIN32
    auto* rpc_ctx = rpc_context_.context_ptr.load();
    if (rpc_ctx != nullptr) {
        rpc_ctx->runGradCallbackForVariable(variable, std::move(cb));
        return;
    }
#endif
    cb(variable.mutable_grad());
}
#ifndef _WIN32
// Publishes the raw pointer of `new_context_ptr` through the atomic
// `context_ptr`. The owning holder is replaced only when the published raw
// pointer actually changed.
void Reducer::RpcContext::set(ContextPtr&& new_context_ptr)
{
    auto* const incoming_raw = new_context_ptr.get();
    const auto previous_raw = context_ptr.exchange(incoming_raw);
    if (previous_raw == incoming_raw) {
        return;
    }
    context_ptr_holder = std::move(new_context_ptr);
}
#endif
// Replaces `bucket_indices` with the layout obtained after broadcasting it
// through process_group_.
//
// Two broadcasts are performed:
//   1. a flat int tensor holding all indices of all buckets, followed by one
//      trailing element that carries the number of buckets;
//   2. an int tensor holding the size of each bucket.
// `bucket_indices` is then rebuilt from the received data.
//
// @param bucket_indices  in/out: per-bucket lists of variable indices.
void Reducer::sync_bucket_indices(
std::vector<std::vector<size_t>>& bucket_indices)
{
auto num_buckets = bucket_indices.size();
std::vector<size_t> bucket_sizes;
bucket_sizes.reserve(num_buckets);
int64_t total_size = 0;
// Record every bucket's size and the total number of indices.
for (const auto i : c10::irange(num_buckets)) {
auto bucket_size = bucket_indices.at(i).size();
bucket_sizes.push_back(bucket_size);
total_size += static_cast<int64_t>(bucket_size);
}
// `options` describes the tensors that take part in the broadcast: int dtype on
// the device of the first parameter.
at::TensorOptions options;
options = options.dtype(at::kInt);
options = options.device(params_[0].device());
// Staging tensor created without a device option so it can be accessed through
// an accessor; it has one extra slot at the end for the bucket count.
auto indices_tensor = at::empty({total_size + 1}, at::kInt);
auto indices_accessor = indices_tensor.accessor<int, 1>();
auto indices_accessor_Index = 0;
// Flatten all bucket indices, bucket after bucket.
for (const auto i : c10::irange(num_buckets)) {
const auto& bucket_size = bucket_indices.at(i).size();
for (const auto j : c10::irange(bucket_size)) {
indices_accessor[indices_accessor_Index++] = static_cast<int>(bucket_indices[i][j]);
}
}
// The last slot stores the number of buckets.
indices_accessor[indices_accessor_Index] = static_cast<int>(num_buckets);
auto indices_tensor_device = at::empty({total_size + 1}, options);
indices_tensor_device.copy_(indices_tensor, true);
std::vector<at::Tensor> indices_tensor_list = {indices_tensor_device};
process_group_->broadcast(indices_tensor_list)->wait();
// Copy the broadcast result back into the staging tensor (non_blocking = false).
indices_tensor.copy_(indices_tensor_list.front(), false);
// From here on, num_buckets is the value received through the broadcast.
// indices_accessor_Index still points at the trailing slot.
num_buckets = static_cast<size_t>(indices_accessor[indices_accessor_Index]);
auto bucket_sizes_tensor = at::empty({static_cast<int64_t>(num_buckets)}, at::kInt);
auto bucket_sizes_accessor = bucket_sizes_tensor.accessor<int, 1>();
// Fill with the local sizes; the index is clamped to the last local bucket in
// case the received bucket count exceeds the local one.
// NOTE(review): `bucket_sizes.size() - 1` wraps around if bucket_sizes is empty,
// in which case `.at()` throws - confirm callers never pass an empty layout.
for (const auto i : c10::irange(num_buckets)) {
bucket_sizes_accessor[i] =
static_cast<int>(bucket_sizes.at(std::min(i, (bucket_sizes.size() - 1))));
}
auto bucket_sizes_tensor_device = at::empty({static_cast<int64_t>(num_buckets)}, options);
bucket_sizes_tensor_device.copy_(bucket_sizes_tensor, true);
std::vector<at::Tensor> bucket_sizes_tensor_list = {
bucket_sizes_tensor_device};
process_group_->broadcast(bucket_sizes_tensor_list)->wait();
bucket_sizes_tensor.copy_(
bucket_sizes_tensor_list.front(), false);
// Rebuild bucket_indices from the received sizes and the received flat index list.
bucket_indices.clear();
bucket_indices.reserve(num_buckets);
indices_accessor_Index = 0;
for (const auto i : c10::irange(num_buckets)) {
const auto& bucket_size = bucket_sizes_accessor[i];
std::vector<size_t> bucket;
bucket.reserve(bucket_size);
for (const auto j : c10::irange(bucket_size)) {
(void)j;
bucket.push_back(indices_accessor[indices_accessor_Index++]);
}
bucket_indices.emplace_back(std::move(bucket));
}
}
// Recomputes the bucket layout from rebuilt_params_ / rebuilt_param_indices_,
// synchronises it through sync_bucket_indices and re-initialises the buckets.
//
// Holds mutex_ for the whole call and first runs
// ensure_prior_reduction_finished().
//
// @return false when should_rebuild_buckets() is false or no rebuilt parameters
//         were recorded; true after the buckets have been re-initialised.
bool Reducer::rebuild_buckets()
{
std::lock_guard<std::mutex> lock(mutex_);
ensure_prior_reduction_finished();
if (!should_rebuild_buckets() || rebuilt_params_.empty()) {
return false;
}
// The recorded tensors and indices must be parallel lists that cover every
// model parameter.
TORCH_INTERNAL_ASSERT(
rebuilt_params_.size() == rebuilt_param_indices_.size(),
c10::str(
"rebuilt parameter tensors size is not same as rebuilt parameter indices size: ",
rebuilt_params_.size(),
" versus ",
rebuilt_param_indices_.size()), DIST_ERROR(ErrCode::PARAM));
TORCH_INTERNAL_ASSERT(
params_.size() == rebuilt_param_indices_.size(),
c10::str(
"rebuilt parameter indices size is not same as original model parameters size.",
"Original model param size is: ",
params_.size(),
" versus rebuilt params size of: ",
rebuilt_param_indices_.size()), DIST_ERROR(ErrCode::PARAM));
std::vector<std::vector<size_t>> rebuilt_bucket_indices;
// Size limits handed to the assignment routine: first_bucket_bytes_cap_
// followed by bucket_bytes_cap_.
std::vector<size_t> bucket_size_limits;
bucket_size_limits.push_back(first_bucket_bytes_cap_);
bucket_size_limits.push_back(bucket_bytes_cap_);
std::vector<size_t> per_bucket_size_limits;
// When DDP_SET_LAST_BUCKET_CAP is exactly "1", the inputs are reversed before
// the assignment and the resulting buckets are reversed back afterwards.
// NOTE(review): presumably this makes the first (small) cap apply to the last
// bucket instead of the first - confirm intent.
auto ddp_set_last_bucket_as_small =
(c10d::getCvarString({"DDP_SET_LAST_BUCKET_CAP"}, "N/A").compare("1") == 0);
if (ddp_set_last_bucket_as_small) {
std::reverse(rebuilt_params_.begin(), rebuilt_params_.end());
std::reverse(rebuilt_param_indices_.begin(), rebuilt_param_indices_.end());
}
std::tie(rebuilt_bucket_indices, per_bucket_size_limits) =
c10d_npu::compute_bucket_assignment_by_size(
rebuilt_params_,
bucket_size_limits,
expect_sparse_gradients_,
rebuilt_param_indices_,
logger_);
if (ddp_set_last_bucket_as_small) {
std::reverse(rebuilt_bucket_indices.begin(), rebuilt_bucket_indices.end());
std::reverse(per_bucket_size_limits.begin(), per_bucket_size_limits.end());
}
// With any debug level other than Off, log the bucket count and size limits.
if (ddp_debug_level_ != c10d::DebugLevel::Off) {
TORCH_INTERNAL_ASSERT(
rebuilt_bucket_indices.size() == per_bucket_size_limits.size(), DIST_ERROR(ErrCode::PARAM))
LOG(INFO) << rebuilt_bucket_indices.size()
<< " buckets rebuilt with size limits: "
<< c10::Join(", ", per_bucket_size_limits)
<< " bytes.";
}
// Synchronise the layout via broadcast before applying it locally.
sync_bucket_indices(rebuilt_bucket_indices);
has_rebuilt_bucket_ = true;
// The recorded parameters are cleared once the layout has been computed.
rebuilt_params_.clear();
rebuilt_param_indices_.clear();
initialize_buckets(
std::move(rebuilt_bucket_indices), std::move(per_bucket_size_limits));
return true;
}
// Installs `iface` as the communication hook. Fails through REDUCER_CHECK if a
// hook is already installed.
void Reducer::register_comm_hook(std::unique_ptr<c10d::CommHookInterface> iface)
{
    const bool no_hook_installed = (comm_hook_ == nullptr);
    REDUCER_CHECK(
        no_hook_installed,
        logger_,
        "register_comm_hook or register_builtin_comm_hook can only be called once.",
        DIST_ERROR(ErrCode::PTR));
    comm_hook_ = std::move(iface);
}
// Installs one of the built-in communication hooks (ALLREDUCE or
// FP16_COMPRESS). Fails through REDUCER_CHECK if a hook is already installed;
// an unrecognised type only emits a one-time warning and installs nothing.
void Reducer::register_builtin_comm_hook(c10d::BuiltinCommHookType comm_hook_type)
{
    const bool no_hook_installed = (comm_hook_ == nullptr);
    REDUCER_CHECK(
        no_hook_installed,
        logger_,
        "register_builtin_comm_hook or register_comm_hook can only be called once.",
        DIST_ERROR(ErrCode::PTR));

    if (comm_hook_type == c10d::BuiltinCommHookType::ALLREDUCE) {
        comm_hook_ = std::make_unique<c10d::AllReduceCommHook>(process_group_);
        LOG(INFO) << "Built-in communication hook ALLREDUCE is registered.";
    } else if (comm_hook_type == c10d::BuiltinCommHookType::FP16_COMPRESS) {
        comm_hook_ = std::make_unique<c10d::FP16CompressCommHook>(process_group_);
        LOG(INFO) << "Built-in communication hook FP16_COMPRESS is registered.";
    } else {
        TORCH_WARN_ONCE("Unknown built-in DDP comm hook type is provided. No comm hook will be used.");
    }
}
// Verifies that the previous iteration's reduction was finalized.
//
// If require_finalize_ is still set, an explanatory error message is assembled
// (tailored to static_graph_, find_unused_parameters_ and ddp_debug_level_) and
// raised through REDUCER_CHECK. Otherwise the function does nothing.
void Reducer::ensure_prior_reduction_finished()
{
    if (!require_finalize_) {
        return;
    }

    auto unmarked_param_indices = getUnmarkedParamIndicesForIteration();
    TORCH_INTERNAL_ASSERT(unmarked_param_indices.size() > 0, DIST_ERROR(ErrCode::PARAM));

    std::string kBaseErrorMsg =
        "Expected to have finished reduction in the prior iteration before "
        "starting a new one. "
        ""
        "This error indicates that your module has parameters that were "
        "not used in producing loss. ";
    std::string kOutputsNotUsedInLossErrorMsg =
        "making sure all "
        "`forward` function outputs participate in calculating loss. ";
    std::string kDDPBugErrorMsg =
        "\nIf you already have done the above, then the distributed "
        "data parallel module wasn't able to locate the output tensors in the "
        "return value of your module's `forward` function. "
        "Please include the loss function and the structure of the return "
        "value of `forward` of your module when reporting this issue (e.g. "
        "list, dict, iterable).";

    // Pick the explanation that matches the current DDP configuration.
    if (static_graph_) {
        kBaseErrorMsg =
            "Expected to have finished reduction in the prior iteration before "
            "starting a new one. "
            "This error indicates that your training graph has changed "
            "in this iteration, e.g., one parameter is used in first "
            "iteration, but then got unused in the second iteration. "
            "this is not compatible with static_graph set to True.";
    } else if (!find_unused_parameters_) {
        kBaseErrorMsg +=
            "You can enable unused parameter detection by passing the "
            "keyword argument `find_unused_parameters=True` to "
            "`torch.nn.parallel.DistributedDataParallel`, and by \n";
        kBaseErrorMsg += kOutputsNotUsedInLossErrorMsg;
        kBaseErrorMsg += kDDPBugErrorMsg;
    } else {
        kBaseErrorMsg +=
            "Since `find_unused_parameters=True` is enabled, this likely "
            " means that not all `forward` outputs participate in computing loss. You can fix this by ";
        kBaseErrorMsg += kOutputsNotUsedInLossErrorMsg;
        kBaseErrorMsg += kDDPBugErrorMsg;
    }

    // The index list is streamed directly from the vector; the previously
    // computed (and never used) comma-joined copy has been removed.
    const std::string unmarked_param_indices_info = c10::str(
        "\n",
        "Parameter indices which did not receive grad for rank ",
        process_group_->getRank(),
        ": ",
        unmarked_param_indices);

    if (ddp_debug_level_ == c10d::DebugLevel::Off) {
        // Without debug mode only the indices are reported, plus a hint on how
        // to obtain parameter names.
        kBaseErrorMsg += unmarked_param_indices_info;
        kBaseErrorMsg +=
            "\n In addition, you can set the environment variable "
            "TORCH_DISTRIBUTED_DEBUG to either INFO or DETAIL to print out information "
            "about which particular parameters did not receive gradient on this rank "
            "as part of this error";
    } else {
        // In debug mode every offending parameter name is logged and appended
        // to the message ahead of the indices.
        auto unmarkedParams = getUnmarkedParamsForIteration();
        TORCH_INTERNAL_ASSERT(unmarkedParams.size() > 0, DIST_ERROR(ErrCode::PARAM));
        for (const auto& s : unmarkedParams) {
            LOG(INFO) << "[Rank " << process_group_->getRank() << "] "
                      << "Parameter: " << s
                      << " did not get gradient in backwards pass.";
        }
        const std::string unmarkedParamInfo = c10::Join(", ", unmarkedParams);
        kBaseErrorMsg += c10::str(
            "\n",
            "Parameters which did not receive grad for rank ",
            process_group_->getRank(),
            ": ",
            unmarkedParamInfo);
        kBaseErrorMsg += unmarked_param_indices_info;
    }

    REDUCER_CHECK(false, logger_, kBaseErrorMsg, DIST_ERROR(ErrCode::PARAM));
}
void Reducer::set_ddp_runtime_logging_sample_rate(int sample_rate)
{
ddp_runtime_logging_sample_rate_ = sample_rate;
}
int Reducer::get_ddp_runtime_logging_sample_rate() const
{
return ddp_runtime_logging_sample_rate_;
}
// Decides whether runtime statistics are collected for the current iteration.
//
// Returns true for iterations 1..10, and afterwards for every iteration whose
// number is a multiple of the configured sample rate. Returns false when
// num_iterations_ is not positive.
bool Reducer::should_collect_runtime_stats()
{
    if (num_iterations_ <= 0) {
        return false;
    }
    if (num_iterations_ <= 10) {
        return true;
    }
    const auto sample_rate = get_ddp_runtime_logging_sample_rate();
    // A zero sample rate would make the modulo below undefined behaviour;
    // treat it as "never sample" instead.
    if (sample_rate == 0) {
        return false;
    }
    return num_iterations_ % sample_rate == 0;
}
// Records the kForwardStart event on the timer, if a timer is present.
void Reducer::record_forward_compute_start_time()
{
    if (!timer_) {
        return;
    }
    timer_->record(Timer::Event::kForwardStart);
}
// Records the kBackwardComputeStart event on the timer, if a timer is present.
void Reducer::record_backward_compute_start_time()
{
    if (!timer_) {
        return;
    }
    timer_->record(Timer::Event::kBackwardComputeStart);
}
// Records the kBackwardComputeEnd event on the timer, if a timer is present.
void Reducer::record_backward_compute_end_time()
{
    if (!timer_) {
        return;
    }
    timer_->record(Timer::Event::kBackwardComputeEnd);
}
// Records the kBackwardCommStart event on the timer, if a timer is present.
void Reducer::record_backward_comm_start_time()
{
    if (!timer_) {
        return;
    }
    timer_->record(Timer::Event::kBackwardCommStart);
}
// Records the kBackwardCommEnd event on the timer, if a timer is present.
void Reducer::record_backward_comm_end_time()
{
    if (!timer_) {
        return;
    }
    timer_->record(Timer::Event::kBackwardCommEnd);
}
// Switches the reducer into static-graph mode under mutex_. Only allowed while
// num_iterations_ is still 0; afterwards the local used map is (re)initialised.
void Reducer::set_static_graph()
{
    std::lock_guard<std::mutex> guard(mutex_);
    const bool no_iteration_run_yet = (num_iterations_ == 0);
    REDUCER_CHECK(
        no_iteration_run_yet,
        logger_,
        "set_static_graph() should be called before training loop starts "
        "and after DistributedDataParallel is constructed.", DIST_ERROR(ErrCode::PARAM));
    static_graph_ = true;
    initialize_local_used_map();
}
namespace {
struct BucketKey {
BucketKey(c10::ScalarType type, c10::Device device)
: type(std::move(type)), device(std::move(device)) {}
const c10::ScalarType type;
const c10::Device device;
static size_t hash(const BucketKey& key)
{
return c10::get_hash(key.type, key.device);
}
};
inline bool operator==(const BucketKey& lhs, const BucketKey& rhs)
{
return lhs.type == rhs.type && lhs.device == rhs.device;
}
}
// Partitions `tensors` into buckets bounded by `bucket_size_limits`.
//
// Tensors are grouped by (scalar type, device). Each group accumulates tensors
// until its byte size (physical_numel * element_size) reaches the group's
// current limit; the bucket is then emitted and the group moves on to the next
// entry of `bucket_size_limits` (staying on the last entry once it is reached).
// A tensor flagged in `expect_sparse_gradient` always gets a bucket of its own
// with a size limit of 0.
//
// @param tensors                 tensors to bucket; must be non-empty and dense.
// @param bucket_size_limits      byte limits, applied in order per group.
// @param expect_sparse_gradient  empty, or one flag per tensor.
// @param tensor_indices          empty, or the index to report for tensors[i].
//                                When empty, position i is used and the
//                                resulting buckets are sorted by their
//                                smallest index.
// @param logger                  optional logger that receives the error before
//                                it is thrown.
// @return (bucket index lists, per-bucket size limits), parallel vectors.
std::tuple<std::vector<std::vector<size_t>>, std::vector<size_t>> compute_bucket_assignment_by_size(
    const std::vector<at::Tensor>& tensors,
    const std::vector<size_t>& bucket_size_limits,
    const std::vector<bool>& expect_sparse_gradient,
    const std::vector<int64_t>& tensor_indices,
    const c10::optional<std::weak_ptr<c10d::Logger>>& logger)
{
    using BucketWithLimit = std::tuple<std::vector<size_t>, size_t>;

    TORCH_INTERNAL_ASSERT(expect_sparse_gradient.empty() ||
                          (tensors.size() == expect_sparse_gradient.size()),
                          DIST_ERROR(ErrCode::PARAM));
    TORCH_INTERNAL_ASSERT(tensors.size() > 0, DIST_ERROR(ErrCode::PARAM));

    std::vector<BucketWithLimit> result;
    size_t kNoSizeLimit = 0;
    result.reserve(tensors.size());

    // Per-group position inside bucket_size_limits.
    std::unordered_map<BucketKey, std::vector<size_t>::const_iterator,
                       c10::hash<BucketKey>>
        bucket_size_limit_iterators;
    // Per-group bucket that is currently being filled.
    std::unordered_map<BucketKey, BucketAccumulator, c10::hash<BucketKey>>
        buckets;

    for (const auto i : c10::irange(tensors.size())) {
        const auto& tensor = tensors[i];
        auto msg = std::string("No support for sparse tensors.");
        if (logger.has_value()) {
            REDUCER_CHECK(!tensor.is_sparse(), logger.value(), msg, DIST_ERROR(ErrCode::NOT_SUPPORT));
        } else {
            TORCH_CHECK(!tensor.is_sparse(), msg, DIST_ERROR(ErrCode::NOT_SUPPORT));
        }

        auto tensor_index = i;
        if (!tensor_indices.empty()) {
            tensor_index = tensor_indices[i];
        }

        // Tensors expected to produce sparse gradients are never merged with
        // others.
        if (!expect_sparse_gradient.empty() &&
            expect_sparse_gradient[tensor_index]) {
            result.emplace_back(std::vector<size_t>({tensor_index}),
                                kNoSizeLimit);
            continue;
        }

        auto key = BucketKey(tensor.scalar_type(), tensor.device());
        auto& bucket = buckets[key];
        bucket.indices.push_back(tensor_index);
        bucket.size +=
            static_cast<size_t>(physical_numel(tensor) * tensor.element_size());

        // A group starts with the first size limit.
        if (bucket_size_limit_iterators.count(key) == 0) {
            bucket_size_limit_iterators[key] = bucket_size_limits.begin();
        }
        auto& bucket_size_limit_iterator = bucket_size_limit_iterators[key];
        const auto bucket_size_limit = *bucket_size_limit_iterator;
        bucket.size_limit = bucket_size_limit;

        if (bucket.size >= bucket_size_limit) {
            // Bucket is full: emit it, start a fresh one and advance to the
            // next limit unless the last one is already in use.
            result.emplace_back(std::move(bucket.indices), bucket.size_limit);
            bucket = BucketAccumulator();
            auto next = bucket_size_limit_iterator + 1;
            if (next != bucket_size_limits.end()) {
                bucket_size_limit_iterator = next;
            }
        }
    }

    // Emit whatever is left in the partially filled buckets.
    for (auto& it : buckets) {
        auto& bucket = it.second;
        if (!bucket.indices.empty()) {
            result.emplace_back(std::move(bucket.indices), bucket.size_limit);
        }
    }

    if (tensor_indices.empty()) {
        // Order buckets by their smallest tensor index. The index vectors are
        // only inspected, so they are bound by reference instead of being
        // copied on every comparison.
        std::sort(result.begin(), result.end(),
                  [](const BucketWithLimit& a, const BucketWithLimit& b) {
                      const auto& indices_a = std::get<0>(a);
                      const auto& indices_b = std::get<0>(b);
                      const auto amin =
                          std::min_element(indices_a.begin(), indices_a.end());
                      const auto bmin =
                          std::min_element(indices_b.begin(), indices_b.end());
                      return *amin < *bmin;
                  });
    }

    // Split the (indices, limit) pairs into the two parallel output vectors,
    // moving the index lists since `result` is not used afterwards.
    std::vector<std::vector<size_t>> bucket_indices;
    bucket_indices.reserve(result.size());
    std::vector<size_t> per_bucket_size_limits;
    per_bucket_size_limits.reserve(result.size());
    for (auto& bucket_indices_with_size : result) {
        bucket_indices.emplace_back(std::move(std::get<0>(bucket_indices_with_size)));
        per_bucket_size_limits.emplace_back(
            std::get<1>(bucket_indices_with_size));
    }
    return std::make_tuple(std::move(bucket_indices), std::move(per_bucket_size_limits));
}
// Checks that the sizes and strides of every tensor in `params` match the
// values received through a broadcast on `process_group`.
//
// The local sizes and strides of all parameters are packed into one int64
// tensor, broadcast, and the received values are then compared element by
// element against the local ones. A mismatch raises an error (through
// REDUCER_CHECK when a logger is supplied, TORCH_CHECK otherwise).
//
// @param process_group  group used for the broadcast.
// @param params         parameters to verify; params[0] selects the device used
//                       for the broadcast tensor.
// @param logger         optional logger that receives the error before it is
//                       thrown.
void verify_params_across_processes(
const c10::intrusive_ptr<c10d::ProcessGroup>& process_group,
const std::vector<at::Tensor>& params,
const c10::optional<std::weak_ptr<c10d::Logger>>& logger)
{
// Each tensor contributes dim() sizes plus dim() strides.
size_t i = 0;
for (const auto& t : params) {
i += static_cast<size_t>(2 * t.dim());
}
at::TensorOptions options;
options = options.dtype(at::kLong);
auto metadata = at::empty({static_cast<long>(i)}, options);
auto metadata_accessor = metadata.accessor<int64_t, 1>();
// Pack layout: for every tensor, all sizes followed by all strides.
i = 0;
for (const auto& t : params) {
for (const auto& sz : t.sizes()) {
metadata_accessor[i++] = sz;
}
for (const auto& str : t.strides()) {
metadata_accessor[i++] = str;
}
}
// Move a copy to the device of the first parameter for the broadcast.
auto metadata_dev = metadata.clone().to(params[0].device());
std::vector<at::Tensor> vec{metadata_dev};
process_group->broadcast(vec)->wait();
// `control` receives the broadcast result so it can be read through an accessor.
// `i` still equals the total element count at this point.
auto control = at::empty({static_cast<long>(i)}, options);
control.copy_(metadata_dev, false);
auto control_accessor = control.accessor<int64_t, 1>();
// Walk the received values in the same order they were packed. `i++` appears
// inside the check condition and advances exactly once per compared element.
i = 0;
for (const auto p : c10::irange(params.size())) {
const auto& t = params[p];
for (const auto& sz : t.sizes()) {
auto msg = c10::str("params[", p, "] in this process",
" with sizes ",
t.sizes(),
" appears not to match sizes of the same param in process 0.");
if (logger.has_value()) {
REDUCER_CHECK(sz == control_accessor[i++], logger.value(), msg, DIST_ERROR(ErrCode::PARAM))
} else {
TORCH_CHECK(sz == control_accessor[i++], msg, DIST_ERROR(ErrCode::PARAM))
}
}
for (const auto& str : t.strides()) {
auto msg = c10::str("params[", p, "] in this process",
" with sizes ",
t.sizes(),
" appears not to match strides of the same param in process 0.");
if (logger.has_value()) {
REDUCER_CHECK(str == control_accessor[i++], logger.value(), msg, DIST_ERROR(ErrCode::PARAM))
} else {
TORCH_CHECK(str == control_accessor[i++], msg, DIST_ERROR(ErrCode::PARAM))
}
}
}
}
}