#include "torch_npu/csrc/core/npu/NPUException.h"
#include "torch_npu/csrc/core/npu/NPUFunctions.h"
#include "torch_npu/csrc/logging/LogContext.h"
#include "torch_npu/csrc/distributed/symm_mem/NPUSymmetricMemoryUtils.hpp"
#include "torch_npu/csrc/distributed/symm_mem/NPUSHMEMExtension.h"
#include "torch_npu/csrc/distributed/symm_mem/NPUSHMEMSymmetricMemory.hpp"
namespace c10d {
namespace symmetric_memory {
constexpr size_t npu_signal_pad_size = 2048;
static NPUStoreExchange storeExchange = NPUStoreExchange("NPUSHMEMSymmetricMemory");
NPUSHMEMAllocation::~NPUSHMEMAllocation()
{
if (is_finalizing()) {
return;
}
auto device = c10::Device(at::DeviceType::PrivateUse1, device_idx);
at::DeviceGuard device_guard(device);
TORCH_NPU_SYMMEM_LOGD("~NPUSHMEMAllocation, start Shmem_free, ptr is %p.", ptr);
Aclshmem_free(ptr);
TORCH_NPU_SYMMEM_LOGD("~NPUSHMEMAllocation, end Shmem_free, ptr is %p.", ptr);
}
NPUSHMEMSymmetricMemory::NPUSHMEMSymmetricMemory(
std::shared_ptr<NPUSHMEMAllocation> allocation,
const std::string& group_name)
: allocation_(allocation),
buffer_size_(allocation->buffer_size),
device_idx_(allocation->device_idx),
group_name_(group_name)
{
static int exchanged_n_times = 0;
auto device = c10::Device(at::DeviceType::PrivateUse1, device_idx_);
at::DeviceGuard device_guard(device);
auto global_rank = get_group_info("0").rank;
const GroupInfo& group_info = get_group_info(group_name_);
auto store = group_info.store;
rank_ = group_info.rank;
world_size_ = group_info.world_size;
if (rank_to_global_rank_.empty()) {
rank_to_global_rank_ =
storeExchange.all_gather(store, rank_, world_size_, global_rank);
exchanged_n_times++;
if (rank_ == 0) {
std::stringstream ss;
for (size_t i = 0; i < rank_to_global_rank_.size(); ++i) {
ss << rank_to_global_rank_[i];
if (i != rank_to_global_rank_.size() - 1) {
ss << ", ";
}
}
TORCH_NPU_SYMMEM_LOGD("[rank %d] rank_to_global_rank: %s, group_name: %s, exchanged_n_times: %d.",
rank_, (ss.str()).c_str(), group_name_.c_str(), exchanged_n_times);
}
}
TORCH_INTERNAL_ASSERT(!rank_to_global_rank_.empty());
for (int r = 0; r < world_size_; ++r) {
auto buffer = Aclshmem_ptr(allocation->ptr, rank_to_global_rank_[r]);
TORCH_CHECK(buffer != nullptr, "shmem_ptr return nullptr with ptr ", allocation->ptr, DIST_ERROR(ErrCode::MEMORY));
buffers_.push_back(buffer);
TORCH_NPU_SYMMEM_LOGD("[rank %d] NPUSHMEMSymmetricMemory shmem_ptr, r is %d, rank_to_global_rank is %d, ptr is %p, shmem_ptr is %p.",
rank_, r, rank_to_global_rank_[r], allocation->ptr, buffer);
}
TORCH_NPU_SYMMEM_LOGD("NPUSHMEMSymmetricMemory created, buffer_size is %d, device_idx is %d, group_name is %s.",
allocation->buffer_size, allocation->device_idx, group_name_.c_str());
}
NPUSHMEMSymmetricMemory::~NPUSHMEMSymmetricMemory()
{
TORCH_NPU_SYMMEM_LOGD("NPUSHMEMSymmetricMemory destroy, group_name is %s", group_name_.c_str());
}
std::vector<void*> NPUSHMEMSymmetricMemory::get_buffer_ptrs()
{
return buffers_;
}
std::vector<void*> NPUSHMEMSymmetricMemory::get_signal_pad_ptrs()
{
return signal_pads_;
}
void** NPUSHMEMSymmetricMemory::get_buffer_ptrs_dev()
{
return buffers_dev_;
}
void** NPUSHMEMSymmetricMemory::get_signal_pad_ptrs_dev()
{
return signal_pads_dev_;
}
size_t NPUSHMEMSymmetricMemory::get_buffer_size()
{
return buffer_size_;
}
size_t NPUSHMEMSymmetricMemory::get_signal_pad_size()
{
return npu_signal_pad_size;
}
bool NPUSHMEMSymmetricMemory::has_multicast_support()
{
return false;
}
void* NPUSHMEMSymmetricMemory::get_multicast_ptr()
{
return nullptr;
}
at::Tensor NPUSHMEMSymmetricMemory::get_buffer(
int rank,
c10::IntArrayRef sizes,
c10::ScalarType dtype,
int64_t storage_offset)
{
throw std::runtime_error("NPUSHMEMSymmetricMemory does not support get_buffer" + DIST_ERROR(ErrCode::NOT_SUPPORT));
}
at::Tensor NPUSHMEMSymmetricMemory::get_signal_pad(
int rank,
c10::IntArrayRef sizes,
std::optional<c10::ScalarType> dtype,
int64_t storage_offset)
{
throw std::runtime_error("NPUSHMEMSymmetricMemory does not support get_signal_pad" + DIST_ERROR(ErrCode::NOT_SUPPORT));
}
void NPUSHMEMSymmetricMemory::barrier(int channel, size_t timeout_ms)
{
}
void NPUSHMEMSymmetricMemory::put_signal(
int dst_rank,
int channel,
size_t timeout_ms)
{
}
void NPUSHMEMSymmetricMemory::wait_signal(
int src_rank,
int channel,
size_t timeout_ms)
{
}
int NPUSHMEMSymmetricMemory::get_rank()
{
return rank_;
}
int NPUSHMEMSymmetricMemory::get_world_size()
{
return world_size_;
}
void* NPUSHMEMSymmetricMemoryAllocator::alloc(
size_t size,
int device_idx,
const std::optional<std::string>& group_name)
{
TORCH_CHECK(
group_name == std::nullopt,
"NPUSHMEMSymmetricMemoryAllocator::alloc "
"must not be called with a group_name", DIST_ERROR(ErrCode::PARAM));
TORCH_NPU_SYMMEM_LOGD("NPUSHMEMSymmetricMemoryAllocator alloc start, size is %d, device is %d, group_name is %s",
size, device_idx, group_name == std::nullopt ? "" : (*group_name).c_str());
c10_npu::LazySetDevice(device_idx);
auto group_info = get_group_info("0");
auto store = group_info.store;
int rank = group_info.rank;
int world_size = group_info.world_size;
npushmem_extension::initialize_npushmem_with_store(store, rank, world_size);
auto ptr = Aclshmem_malloc(size);
TORCH_CHECK(ptr != nullptr, "shmem_malloc return nullptr with size ", size, DIST_ERROR(ErrCode::MEMORY));
auto allocation =
std::make_shared<NPUSHMEMAllocation>(ptr, size, device_idx);
allocations_.try_emplace(ptr, std::move(allocation));
TORCH_NPU_SYMMEM_LOGD("NPUSHMEMSymmetricMemoryAllocator alloc end, size is %d, device is %d, group_name is %s, ptr is %p",
size, device_idx, group_name == std::nullopt ? "" : (*group_name).c_str(), ptr);
return ptr;
}
void NPUSHMEMSymmetricMemoryAllocator::free(void* ptr)
{
TORCH_NPU_SYMMEM_LOGD("NPUSHMEMSymmetricMemoryAllocator free start, ptr is %p", ptr);
allocations_.erase(ptr);
TORCH_NPU_SYMMEM_LOGD("NPUSHMEMSymmetricMemoryAllocator free end, ptr is %p", ptr);
}
size_t NPUSHMEMSymmetricMemoryAllocator::get_alloc_size(void* ptr)
{
auto it = allocations_.find(ptr);
if (it == allocations_.end()) {
TORCH_CHECK(false, ptr, " is not allocated with NPUSHMEMSymmetricMemoryAllocator", DIST_ERROR(ErrCode::PARAM));
}
return it->second->buffer_size;
}
c10::intrusive_ptr<SymmetricMemory> NPUSHMEMSymmetricMemoryAllocator::rendezvous(
void* ptr,
const std::optional<std::string>& group_name)
{
TORCH_NPU_SYMMEM_LOGD("NPUSHMEMSymmetricMemoryAllocator rendezvous start, ptr is %p, group_name is %s", ptr, (*group_name).c_str());
TORCH_CHECK(group_name.has_value(), "rendezvous, group_name is invalid.", DIST_ERROR(ErrCode::PARAM));
{
auto it = symm_mems_.find(std::make_tuple(ptr, *group_name));
if (it != symm_mems_.end()) {
return it->second;
}
}
auto it = allocations_.find(ptr);
TORCH_CHECK(it != allocations_.end(), "rendezvous, ptr is invalid.", DIST_ERROR(ErrCode::PARAM));
auto symm_mem =
c10::make_intrusive<NPUSHMEMSymmetricMemory>(it->second, *group_name);
symm_mems_[std::make_tuple(ptr, *group_name)] = symm_mem;
TORCH_NPU_SYMMEM_LOGD("NPUSHMEMSymmetricMemoryAllocator rendezvous end, ptr is %p, group_name is %s", ptr, (*group_name).c_str());
return symm_mem;
}
bool NPUSHMEMSymmetricMemoryAllocator::has_multicast_support(int device_idx)
{
throw std::runtime_error("NPUSHMEMSymmetricMemoryAllocator does not support has_multicast_support" + DIST_ERROR(ErrCode::NOT_SUPPORT));
}
struct RegisterNPUSHMEMSymmetricMemoryAllocator {
RegisterNPUSHMEMSymmetricMemoryAllocator()
{
register_allocator(
c10::DeviceType::PrivateUse1,
c10::make_intrusive<NPUSHMEMSymmetricMemoryAllocator>());
}
};
static RegisterNPUSHMEMSymmetricMemoryAllocator register_allocator_;
}
}