#include "torch_npu/csrc/distributed/symm_mem/NPUSHMEMInterface.h"
#include "torch_npu/csrc/core/npu/register/FunctionLoader.h"
#include "torch_npu/csrc/core/npu/NPUException.h"
namespace c10d {
namespace symmetric_memory {
// File-local shorthands bound to the "libshmem" library handle.
// Each macro is #undef'd first so an earlier definition (if any) does not clash.
// TORCH_NPU_LOAD_FUNC(name): registers symbol `name` with the libshmem function loader.
#undef TORCH_NPU_LOAD_FUNC
#define TORCH_NPU_LOAD_FUNC(funcName) \
TORCH_NPU_REGISTER_FUNCTION(libshmem, funcName)
// TORCH_NPU_GET_FUNC(name): fetches the address of registered symbol `name` from libshmem.
// The wrappers below treat a nullptr result as "symbol not exported" and guard the
// lookup with try/catch where absence must not be fatal.
#undef TORCH_NPU_GET_FUNC
#define TORCH_NPU_GET_FUNC(funcName) \
TORCH_NPU_GET_FUNCTION(libshmem, funcName)
// Register the library, then every symbol this file may resolve.
TORCH_NPU_REGISTER_LIBRARY(libshmem)
// Newer aclshmem / aclshmemx entry points (preferred where a wrapper supports both names).
TORCH_NPU_LOAD_FUNC(aclshmemx_set_conf_store_tls)
TORCH_NPU_LOAD_FUNC(aclshmemx_get_uniqueid)
TORCH_NPU_LOAD_FUNC(aclshmemx_set_attr_uniqueid_args)
TORCH_NPU_LOAD_FUNC(aclshmemx_init_attr)
TORCH_NPU_LOAD_FUNC(aclshmem_malloc)
TORCH_NPU_LOAD_FUNC(aclshmem_free)
TORCH_NPU_LOAD_FUNC(aclshmem_ptr)
TORCH_NPU_LOAD_FUNC(aclshmem_finalize)
// Legacy shmem_* entry points, used directly by the Shmem_* wrappers and as
// fallbacks by some Aclshmem* wrappers.
TORCH_NPU_LOAD_FUNC(shmem_set_conf_store_tls)
TORCH_NPU_LOAD_FUNC(shmem_set_attr)
TORCH_NPU_LOAD_FUNC(shmem_get_uniqueid)
TORCH_NPU_LOAD_FUNC(shmem_set_attr_uniqueid_args)
TORCH_NPU_LOAD_FUNC(shmem_malloc)
TORCH_NPU_LOAD_FUNC(shmem_free)
TORCH_NPU_LOAD_FUNC(shmem_ptr)
TORCH_NPU_LOAD_FUNC(shmem_finalize)
// Forwards the config-store TLS settings to libshmem.
// Symbol resolution prefers aclshmemx_set_conf_store_tls and falls back to the legacy
// shmem_set_conf_store_tls. A successful lookup is cached in a function-local static;
// an unsuccessful one is retried on the next call.
// Throws (via TORCH_CHECK, NOT_FOUND) when neither symbol can be resolved.
// Returns whatever the underlying library call returns.
int32_t Aclshmemx_set_conf_store_tls(bool enable, const char *tls_info, const uint32_t tls_info_len)
{
    using SetConfStoreTlsFn = int32_t (*)(bool, const char *, const uint32_t);
    static SetConfStoreTlsFn cachedFn = nullptr;
    if (!cachedFn) {
        cachedFn = reinterpret_cast<SetConfStoreTlsFn>(TORCH_NPU_GET_FUNC(aclshmemx_set_conf_store_tls));
        if (!cachedFn) {
            // Older libshmem builds only export the legacy name.
            cachedFn = reinterpret_cast<SetConfStoreTlsFn>(TORCH_NPU_GET_FUNC(shmem_set_conf_store_tls));
        }
    }
    TORCH_CHECK(cachedFn != nullptr, "Failed to find function ",
        "aclshmemx_set_conf_store_tls or shmem_set_conf_store_tls", PTA_ERROR(ErrCode::NOT_FOUND));
    return cachedFn(enable, tls_info, tls_info_len);
}
// Legacy-API wrapper: forwards all arguments to libshmem's shmem_set_attr.
// There is no aclshmemx_* fallback for this entry point.
// The resolved symbol is cached in a function-local static once found.
// Throws (via TORCH_CHECK, NOT_FOUND) if the symbol is missing; otherwise returns the
// library call's result.
int32_t Shmem_set_attr(int32_t my_rank, int32_t n_ranks, uint64_t local_mem_size, const char *ip_port,
    shmem_init_attr_t **attributes)
{
    using SetAttrFn = int32_t (*)(int32_t, int32_t, uint64_t, const char *, shmem_init_attr_t **);
    static SetAttrFn cachedFn = nullptr;
    if (!cachedFn) {
        cachedFn = reinterpret_cast<SetAttrFn>(TORCH_NPU_GET_FUNC(shmem_set_attr));
    }
    TORCH_CHECK(cachedFn != nullptr, "Failed to find function ", "shmem_set_attr", PTA_ERROR(ErrCode::NOT_FOUND));
    return cachedFn(my_rank, n_ranks, local_mem_size, ip_port, attributes);
}
// Legacy-API wrapper: passes `uid` to libshmem's shmem_get_uniqueid and returns its result.
// The resolved symbol is cached in a function-local static once found.
// Throws (via TORCH_CHECK, NOT_FOUND) if the symbol is missing.
int Shmemx_get_uniqueid(shmem_uniqueid_t *uid)
{
    using GetUniqueIdFn = int32_t (*)(shmem_uniqueid_t *);
    static GetUniqueIdFn cachedFn = nullptr;
    if (!cachedFn) {
        cachedFn = reinterpret_cast<GetUniqueIdFn>(TORCH_NPU_GET_FUNC(shmem_get_uniqueid));
    }
    TORCH_CHECK(cachedFn != nullptr, "Failed to find function ",
        "shmem_get_uniqueid", PTA_ERROR(ErrCode::NOT_FOUND));
    return cachedFn(uid);
}
// Reports whether libshmem exports aclshmemx_get_uniqueid.
// The probe runs once; the answer is stored in a function-local static.
// Any exception thrown during the lookup is swallowed and reported as "not available".
bool Aclshmemx_get_uniqueid_exist()
{
    static const bool available = [] {
        bool found = false;
        try {
            found = (TORCH_NPU_GET_FUNC(aclshmemx_get_uniqueid) != nullptr);
        } catch (...) {
            found = false;
        }
        return found;
    }();
    return available;
}
// Passes `uid` to libshmem's aclshmemx_get_uniqueid and returns its result.
// Unlike several other wrappers in this file, there is deliberately no fallback to the
// legacy shmem_get_uniqueid. That symbol takes a shmem_uniqueid_t* rather than an
// aclshmemx_uniqueid_t*, so calling it through this signature would pass the wrong type.
// Legacy callers use Shmemx_get_uniqueid instead.
// The resolved symbol is cached in a function-local static once found.
// Throws (via TORCH_CHECK, NOT_FOUND) if aclshmemx_get_uniqueid is missing.
int32_t Aclshmemx_get_uniqueid(aclshmemx_uniqueid_t *uid)
{
    typedef int32_t (*ShmemApiFunc)(aclshmemx_uniqueid_t *);
    static ShmemApiFunc shmem_get_uniqueid_func = nullptr;
    if (shmem_get_uniqueid_func == nullptr) {
        shmem_get_uniqueid_func = (ShmemApiFunc)TORCH_NPU_GET_FUNC(aclshmemx_get_uniqueid);
    }
    // The message names only the symbol that was actually searched for.
    TORCH_CHECK(shmem_get_uniqueid_func, "Failed to find function ",
        "aclshmemx_get_uniqueid", PTA_ERROR(ErrCode::NOT_FOUND));
    return shmem_get_uniqueid_func(uid);
}
// Legacy-API wrapper: forwards rank info, unique id and attribute struct to libshmem's
// shmem_set_attr_uniqueid_args, returning its result.
// The resolved symbol is cached in a function-local static once found.
// Throws (via TORCH_CHECK, NOT_FOUND) if the symbol is missing.
int Shmem_set_attr_uniqueid_args(int rank_id, int nranks, const shmem_uniqueid_t *uid, shmem_init_attr_t *attr)
{
    using SetAttrUniqueIdArgsFn = int32_t (*)(int, int, const shmem_uniqueid_t *, shmem_init_attr_t *);
    static SetAttrUniqueIdArgsFn cachedFn = nullptr;
    if (!cachedFn) {
        cachedFn = reinterpret_cast<SetAttrUniqueIdArgsFn>(TORCH_NPU_GET_FUNC(shmem_set_attr_uniqueid_args));
    }
    TORCH_CHECK(cachedFn != nullptr, "Failed to find function ",
        "shmem_set_attr_uniqueid_args", PTA_ERROR(ErrCode::NOT_FOUND));
    return cachedFn(rank_id, nranks, uid, attr);
}
// Forwards rank info, local memory size, unique id and attribute struct to libshmem's
// aclshmemx_set_attr_uniqueid_args, returning its result.
// The resolved symbol is cached in a function-local static once found.
// Throws (via TORCH_CHECK, NOT_FOUND) if the symbol is missing.
int Aclshmemx_set_attr_uniqueid_args(int rank_id, int nranks, int64_t local_mem_size,
    aclshmemx_uniqueid_t *uid, aclshmemx_init_attr_t *aclshmem_attr)
{
    using SetAttrUniqueIdArgsFn = int32_t (*)(int, int, int64_t, aclshmemx_uniqueid_t *, aclshmemx_init_attr_t *);
    static SetAttrUniqueIdArgsFn cachedFn = nullptr;
    if (!cachedFn) {
        cachedFn = reinterpret_cast<SetAttrUniqueIdArgsFn>(TORCH_NPU_GET_FUNC(aclshmemx_set_attr_uniqueid_args));
    }
    TORCH_CHECK(cachedFn != nullptr, "Failed to find function ",
        "aclshmemx_set_attr_uniqueid_args", PTA_ERROR(ErrCode::NOT_FOUND));
    return cachedFn(rank_id, nranks, local_mem_size, uid, aclshmem_attr);
}
// Forwards the bootstrap flags and attribute struct to libshmem's aclshmemx_init_attr,
// returning its result.
// The resolved symbol is cached in a function-local static once found.
// Throws (via TORCH_CHECK, NOT_FOUND) if the symbol is missing.
int Aclshmemx_init_attr(aclshmemx_bootstrap_t bootstrap_flags, aclshmemx_init_attr_t *attributes)
{
    using InitAttrFn = int32_t (*)(aclshmemx_bootstrap_t, aclshmemx_init_attr_t *);
    static InitAttrFn cachedFn = nullptr;
    if (!cachedFn) {
        cachedFn = reinterpret_cast<InitAttrFn>(TORCH_NPU_GET_FUNC(aclshmemx_init_attr));
    }
    TORCH_CHECK(cachedFn != nullptr, "Failed to find function ",
        "aclshmemx_init_attr", PTA_ERROR(ErrCode::NOT_FOUND));
    return cachedFn(bootstrap_flags, attributes);
}
// Requests `size` bytes from libshmem and returns the pointer the library hands back.
// Symbol resolution prefers aclshmem_malloc and falls back to the legacy shmem_malloc.
// The resolved pointer is cached in a function-local static once found.
// Throws (via TORCH_CHECK, NOT_FOUND) when neither symbol can be resolved.
void *Aclshmem_malloc(size_t size)
{
    using MallocFn = void *(*)(size_t);
    static MallocFn cachedFn = nullptr;
    if (!cachedFn) {
        cachedFn = reinterpret_cast<MallocFn>(TORCH_NPU_GET_FUNC(aclshmem_malloc));
        if (!cachedFn) {
            cachedFn = reinterpret_cast<MallocFn>(TORCH_NPU_GET_FUNC(shmem_malloc));
        }
    }
    TORCH_CHECK(cachedFn != nullptr, "Failed to find function ",
        "aclshmem_malloc or shmem_malloc", PTA_ERROR(ErrCode::NOT_FOUND));
    return cachedFn(size);
}
// Hands `ptr` back to libshmem's free routine.
// Symbol resolution prefers aclshmem_free and falls back to the legacy shmem_free.
// The resolved pointer is cached in a function-local static once found.
// Throws (via TORCH_CHECK, NOT_FOUND) when neither symbol can be resolved.
void Aclshmem_free(void *ptr)
{
    using FreeFn = void (*)(void *);
    static FreeFn cachedFn = nullptr;
    if (!cachedFn) {
        cachedFn = reinterpret_cast<FreeFn>(TORCH_NPU_GET_FUNC(aclshmem_free));
        if (!cachedFn) {
            cachedFn = reinterpret_cast<FreeFn>(TORCH_NPU_GET_FUNC(shmem_free));
        }
    }
    TORCH_CHECK(cachedFn != nullptr, "Failed to find function ",
        "aclshmem_free or shmem_free", PTA_ERROR(ErrCode::NOT_FOUND));
    cachedFn(ptr);
}
// Asks libshmem to translate `ptr` for processing element `pe` and returns the result.
// Symbol resolution prefers aclshmem_ptr and falls back to the legacy shmem_ptr.
// The resolved pointer is cached in a function-local static once found.
// Throws (via TORCH_CHECK, NOT_FOUND) when neither symbol can be resolved.
void *Aclshmem_ptr(void *ptr, int pe)
{
    using PtrFn = void *(*)(void *, int);
    static PtrFn cachedFn = nullptr;
    if (!cachedFn) {
        cachedFn = reinterpret_cast<PtrFn>(TORCH_NPU_GET_FUNC(aclshmem_ptr));
        if (!cachedFn) {
            cachedFn = reinterpret_cast<PtrFn>(TORCH_NPU_GET_FUNC(shmem_ptr));
        }
    }
    TORCH_CHECK(cachedFn != nullptr, "Failed to find function ",
        "aclshmem_ptr or shmem_ptr", PTA_ERROR(ErrCode::NOT_FOUND));
    return cachedFn(ptr, pe);
}
// Reports whether libshmem exports either aclshmem_finalize or the legacy shmem_finalize.
// The probe runs once; the answer is stored in a function-local static.
// Both lookups are always attempted. If either lookup throws, the exception is
// swallowed and the answer is "not available".
bool Aclshmem_finalize_exist()
{
    static const bool available = [] {
        bool found = false;
        try {
            const bool hasAclName = TORCH_NPU_GET_FUNC(aclshmem_finalize) != nullptr;
            const bool hasLegacyName = TORCH_NPU_GET_FUNC(shmem_finalize) != nullptr;
            found = hasAclName || hasLegacyName;
        } catch (...) {
            found = false;
        }
        return found;
    }();
    return available;
}
// Invokes libshmem's finalize routine and returns its result.
// Symbol resolution prefers aclshmem_finalize and falls back to the legacy shmem_finalize.
// The resolved pointer is cached in a function-local static once found.
// Throws (via TORCH_CHECK, NOT_FOUND) when neither symbol can be resolved.
int Aclshmem_finalize(void)
{
    using FinalizeFn = int (*)(void);
    static FinalizeFn cachedFn = nullptr;
    if (!cachedFn) {
        cachedFn = reinterpret_cast<FinalizeFn>(TORCH_NPU_GET_FUNC(aclshmem_finalize));
        if (!cachedFn) {
            cachedFn = reinterpret_cast<FinalizeFn>(TORCH_NPU_GET_FUNC(shmem_finalize));
        }
    }
    TORCH_CHECK(cachedFn != nullptr, "Failed to find function ",
        "aclshmem_finalize or shmem_finalize", PTA_ERROR(ErrCode::NOT_FOUND));
    return cachedFn();
}
}
}