* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <iostream>
#include "acl/acl.h"
#include "shmemi_host_common.h"
#include "gm2gm/shmemi_device_rma.h"
#include "host_device/shmem_common_types.h"
#include "host/data_plane/shmem_host_rma.h"
/**
 * Report whether ptr lies inside the half-open range [heap_base, heap_base + heap_size).
 *
 * The test is performed on the offset from heap_base instead of on a computed
 * upper bound, so a range whose end would wrap past UINT64_MAX is still handled
 * correctly rather than rejecting every address.
 *
 * @param ptr       Address to classify.
 * @param heap_base Numeric start address of the heap.
 * @param heap_size Heap length in bytes; with 0 no address is inside.
 * @return true when ptr is inside the heap, false otherwise.
 */
inline bool check_heap_addr(void *ptr, uint64_t heap_base, size_t heap_size)
{
    const uint64_t addr = (uint64_t)ptr;
    if (addr < heap_base) {
        return false;
    }
    // addr >= heap_base here, so the subtraction cannot underflow.
    return (addr - heap_base) < (uint64_t)heap_size;
}
/**
 * Tell whether ptr falls inside the host heap recorded in g_state.
 *
 * @param ptr Address to classify.
 * @return false when g_state.host_heap_base is null; otherwise the result of
 *         check_heap_addr over [host_heap_base, host_heap_base + g_state.heap_size).
 */
inline bool is_host_mem_heap(void *ptr)
{
    // Short-circuit keeps check_heap_addr from running when there is no host heap.
    const bool host_heap_present = (g_state.host_heap_base != nullptr);
    return host_heap_present && check_heap_addr(ptr, (uint64_t)g_state.host_heap_base, g_state.heap_size);
}
void *aclshmem_ptr(void *ptr, int32_t pe)
{
if (pe < 0 || pe >= aclshmem_n_pes()) {
SHM_LOG_ERROR("aclshmem_ptr Failed. PE: " << aclshmem_my_pe() << " Got Illegal PE !!");
return nullptr;
}
uint64_t heap_base = is_host_mem_heap(ptr) ? (uint64_t)g_state.host_heap_base : (uint64_t)g_state.heap_base;
if (!check_heap_addr(ptr, heap_base, g_state.heap_size)) {
SHM_LOG_ERROR("aclshmem_ptr Failed. PE: " << aclshmem_my_pe() << " Got Illegal Address !!");
return nullptr;
}
uint64_t offset = (uint64_t)ptr - heap_base;
void *symm_ptr = is_host_mem_heap(ptr) ? g_state.p2p_host_heap_base[pe] : g_state.p2p_device_heap_base[pe];
if (symm_ptr != nullptr) {
symm_ptr = reinterpret_cast<void*>(reinterpret_cast<uint64_t>(symm_ptr) + offset);
return symm_ptr;
}
SHM_LOG_ERROR("aclshmem_ptr Failed. PE: " << aclshmem_my_pe()
<< " g_state.p2p_" << (is_host_mem_heap(ptr) ? "host" : "device") << "_heap_base contains nullptr, Please Check Init Status!!");
return nullptr;
}
/**
 * Store the MTE configuration in g_state and call update_device_state().
 *
 * @param offset  Value written to g_state.mte_config.aclshmem_ub.
 * @param ub_size Value written to g_state.mte_config.ub_size.
 * @param sync_id Value written to g_state.mte_config.sync_id.
 * @return ACLSHMEM_SUCCESS once ACLSHMEM_CHECK_RET(update_device_state()) falls through.
 *         NOTE(review): ACLSHMEM_CHECK_RET presumably returns early on failure - confirm
 *         against the macro definition.
 */
int32_t aclshmemx_set_mte_config(uint64_t offset, uint32_t ub_size, uint32_t sync_id)
{
    auto &mte = g_state.mte_config;
    mte.aclshmem_ub = offset;
    mte.ub_size = ub_size;
    mte.sync_id = sync_id;

    ACLSHMEM_CHECK_RET(update_device_state());
    return ACLSHMEM_SUCCESS;
}
/**
 * Store the SDMA configuration in g_state and call update_device_state().
 *
 * @param offset  Value written to g_state.sdma_config.aclshmem_ub.
 * @param ub_size Value written to g_state.sdma_config.ub_size.
 * @param sync_id Value written to g_state.sdma_config.sync_id.
 * @return ACLSHMEM_SUCCESS once ACLSHMEM_CHECK_RET(update_device_state()) falls through.
 *         NOTE(review): ACLSHMEM_CHECK_RET presumably returns early on failure - confirm
 *         against the macro definition.
 */
int32_t aclshmemx_set_sdma_config(uint64_t offset, uint32_t ub_size, uint32_t sync_id)
{
    auto &sdma = g_state.sdma_config;
    sdma.aclshmem_ub = offset;
    sdma.ub_size = ub_size;
    sdma.sync_id = sync_id;

    ACLSHMEM_CHECK_RET(update_device_state());
    return ACLSHMEM_SUCCESS;
}
/**
 * Store the RDMA configuration in g_state and call update_device_state().
 *
 * @param offset  Value written to g_state.rdma_config.aclshmem_ub.
 * @param ub_size Value written to g_state.rdma_config.ub_size.
 * @param sync_id Value written to g_state.rdma_config.sync_id.
 * @return ACLSHMEM_SUCCESS once ACLSHMEM_CHECK_RET(update_device_state()) falls through.
 *         NOTE(review): ACLSHMEM_CHECK_RET presumably returns early on failure - confirm
 *         against the macro definition.
 */
int32_t aclshmemx_set_rdma_config(uint64_t offset, uint32_t ub_size, uint32_t sync_id)
{
    auto &rdma = g_state.rdma_config;
    rdma.aclshmem_ub = offset;
    rdma.ub_size = ub_size;
    rdma.sync_id = sync_id;

    ACLSHMEM_CHECK_RET(update_device_state());
    return ACLSHMEM_SUCCESS;
}
// Template for the typed put entry points aclshmem_<NAME>_put. ACLSHMEM_TYPE_FUNC
// (defined elsewhere) is handed this macro and presumably instantiates it once per
// supported (NAME, TYPE) pair. Each generated function forwards to
// aclshmemi_prepare_and_post_rma with ACLSHMEMI_OP_PUT / NO_NBI, sizeof(TYPE) as the
// element width, 1/1 in the two slots where the iput variants pass their strides,
// and the default stream and block number from g_state_host. A negative return is
// logged; the function itself returns nothing.
#define ACLSHMEM_TYPE_PUT(NAME, TYPE) \
    ACLSHMEM_HOST_API void aclshmem_##NAME##_put(TYPE *dest, TYPE *source, size_t nelems, int pe) \
    { \
        const int status = aclshmemi_prepare_and_post_rma( \
            "aclshmem_put_" #NAME "_mem", ACLSHMEMI_OP_PUT, NO_NBI, \
            (uint8_t *)dest, (uint8_t *)source, nelems, sizeof(TYPE), pe, \
            nullptr, 0, 0, 1, 1, \
            g_state_host.default_stream, g_state_host.default_block_num); \
        if (status >= 0) { \
            return; \
        } \
        SHM_LOG_ERROR("device calling transfer failed"); \
    }
ACLSHMEM_TYPE_FUNC(ACLSHMEM_TYPE_PUT)
#undef ACLSHMEM_TYPE_PUT
// Template for the typed put entry points aclshmem_<NAME>_put_nbi. Identical in
// shape to the plain typed put, except that NBI is passed instead of NO_NBI and the
// operation name carries an "_nbi" suffix. ACLSHMEM_TYPE_FUNC (defined elsewhere)
// presumably instantiates it once per supported (NAME, TYPE) pair. A negative
// return from aclshmemi_prepare_and_post_rma is logged only.
#define ACLSHMEM_TYPE_PUT_NBI(NAME, TYPE) \
    ACLSHMEM_HOST_API void aclshmem_##NAME##_put_nbi(TYPE *dest, TYPE *source, size_t nelems, int pe) \
    { \
        const int status = aclshmemi_prepare_and_post_rma( \
            "aclshmem_put_" #NAME "_mem_nbi", ACLSHMEMI_OP_PUT, NBI, \
            (uint8_t *)dest, (uint8_t *)source, nelems, sizeof(TYPE), pe, \
            nullptr, 0, 0, 1, 1, \
            g_state_host.default_stream, g_state_host.default_block_num); \
        if (status >= 0) { \
            return; \
        } \
        SHM_LOG_ERROR("device calling transfer failed"); \
    }
ACLSHMEM_TYPE_FUNC(ACLSHMEM_TYPE_PUT_NBI)
#undef ACLSHMEM_TYPE_PUT_NBI
// Template for the typed strided put entry points aclshmem_<NAME>_iput.
// The caller-supplied dst and sst values are forwarded to
// aclshmemi_prepare_and_post_rma in the two argument slots that the contiguous
// put variants fill with 1, 1. ACLSHMEM_TYPE_FUNC (defined elsewhere) presumably
// instantiates it once per supported (NAME, TYPE) pair. A negative return is
// logged only.
#define ACLSHMEM_TYPE_IPUT(NAME, TYPE) \
    ACLSHMEM_HOST_API void aclshmem_##NAME##_iput(TYPE *dest, TYPE *source, ptrdiff_t dst, ptrdiff_t sst, \
                                                  size_t nelems, int pe) \
    { \
        const int status = aclshmemi_prepare_and_post_rma( \
            "aclshmem_iput_" #NAME "", ACLSHMEMI_OP_PUT, NO_NBI, \
            (uint8_t *)dest, (uint8_t *)source, nelems, sizeof(TYPE), pe, \
            nullptr, 0, 0, dst, sst, \
            g_state_host.default_stream, g_state_host.default_block_num); \
        if (status >= 0) { \
            return; \
        } \
        SHM_LOG_ERROR("device calling transfer failed"); \
    }
ACLSHMEM_TYPE_FUNC(ACLSHMEM_TYPE_IPUT);
#undef ACLSHMEM_TYPE_IPUT
// Template for the fixed-width put entry points aclshmem_put<BITS>.
// The element width handed to aclshmemi_prepare_and_post_rma is (BITS) / 8 bytes.
// Note that the parameter named elem_size is forwarded in the slot where the typed
// variants pass their element count (nelems). ACLSHMEM_SIZE_FUNC (defined
// elsewhere) presumably instantiates this once per supported bit width. A negative
// return is logged only.
#define ACLSHMEM_PUT_SIZE(BITS) \
    ACLSHMEM_HOST_API void aclshmem_put##BITS(void *dst, void *src, uint32_t elem_size, int32_t pe) \
    { \
        const int status = aclshmemi_prepare_and_post_rma( \
            "aclshmem_put_" #BITS "", ACLSHMEMI_OP_PUT, NO_NBI, \
            (uint8_t *)dst, (uint8_t *)src, elem_size, (BITS) / 8, pe, \
            nullptr, 0, 0, 1, 1, \
            g_state_host.default_stream, g_state_host.default_block_num); \
        if (status >= 0) { \
            return; \
        } \
        SHM_LOG_ERROR("device calling transfer failed"); \
    }
ACLSHMEM_SIZE_FUNC(ACLSHMEM_PUT_SIZE);
#undef ACLSHMEM_PUT_SIZE
// Template for the fixed-width put entry points aclshmem_put<BITS>_nbi.
// Same as the plain fixed-width put except that NBI is passed instead of NO_NBI
// and the operation name carries an "_nbi" suffix. The element width is (BITS) / 8
// bytes and elem_size is forwarded in the element-count slot. ACLSHMEM_SIZE_FUNC
// (defined elsewhere) presumably instantiates this once per supported bit width.
// A negative return is logged only.
#define ACLSHMEM_PUT_SIZE_NBI(BITS) \
    ACLSHMEM_HOST_API void aclshmem_put##BITS##_nbi(void *dst, void *src, uint32_t elem_size, int32_t pe) \
    { \
        const int status = aclshmemi_prepare_and_post_rma( \
            "aclshmem_put_" #BITS "_nbi", ACLSHMEMI_OP_PUT, NBI, \
            (uint8_t *)dst, (uint8_t *)src, elem_size, (BITS) / 8, pe, \
            nullptr, 0, 0, 1, 1, \
            g_state_host.default_stream, g_state_host.default_block_num); \
        if (status >= 0) { \
            return; \
        } \
        SHM_LOG_ERROR("device calling transfer failed"); \
    }
ACLSHMEM_SIZE_FUNC(ACLSHMEM_PUT_SIZE_NBI);
#undef ACLSHMEM_PUT_SIZE_NBI
// Template for the fixed-width strided put entry points aclshmem_iput<BITS>.
// The element width is (BITS) / 8 bytes; dst and sst are forwarded in the two
// argument slots that the contiguous variants fill with 1, 1. ACLSHMEM_SIZE_FUNC
// (defined elsewhere) presumably instantiates this once per supported bit width.
// A negative return from aclshmemi_prepare_and_post_rma is logged only.
#define ACLSHMEM_IPUT_SIZE(BITS) \
    ACLSHMEM_HOST_API void aclshmem_iput##BITS(void *dest, void *source, ptrdiff_t dst, ptrdiff_t sst, \
                                               size_t nelems, int pe) \
    { \
        const int status = aclshmemi_prepare_and_post_rma( \
            "aclshmem_iput_" #BITS "", ACLSHMEMI_OP_PUT, NO_NBI, \
            (uint8_t *)dest, (uint8_t *)source, nelems, (BITS) / 8, pe, \
            nullptr, 0, 0, dst, sst, \
            g_state_host.default_stream, g_state_host.default_block_num); \
        if (status >= 0) { \
            return; \
        } \
        SHM_LOG_ERROR("device calling transfer failed"); \
    }
ACLSHMEM_SIZE_FUNC(ACLSHMEM_IPUT_SIZE);
#undef ACLSHMEM_IPUT_SIZE
// Template for the typed get entry points aclshmem_<NAME>_get. Each generated
// function forwards to aclshmemi_prepare_and_post_rma with ACLSHMEMI_OP_GET /
// NO_NBI, sizeof(TYPE) as the element width, 1/1 in the two slots where the iget
// variants pass their strides, and the default stream and block number from
// g_state_host. ACLSHMEM_TYPE_FUNC (defined elsewhere) presumably instantiates it
// once per supported (NAME, TYPE) pair. A negative return is logged only.
#define ACLSHMEM_TYPE_GET(NAME, TYPE) \
    ACLSHMEM_HOST_API void aclshmem_##NAME##_get(TYPE *dest, TYPE *source, size_t nelems, int pe) \
    { \
        const int status = aclshmemi_prepare_and_post_rma( \
            "aclshmem_get_" #NAME "_mem", ACLSHMEMI_OP_GET, NO_NBI, \
            (uint8_t *)dest, (uint8_t *)source, nelems, sizeof(TYPE), pe, \
            nullptr, 0, 0, 1, 1, \
            g_state_host.default_stream, g_state_host.default_block_num); \
        if (status >= 0) { \
            return; \
        } \
        SHM_LOG_ERROR("device calling transfer failed"); \
    }
ACLSHMEM_TYPE_FUNC(ACLSHMEM_TYPE_GET)
#undef ACLSHMEM_TYPE_GET
// Template for the typed get entry points aclshmem_<NAME>_get_nbi. Identical in
// shape to the plain typed get, except that NBI is passed instead of NO_NBI and the
// operation name carries an "_nbi" suffix. ACLSHMEM_TYPE_FUNC (defined elsewhere)
// presumably instantiates it once per supported (NAME, TYPE) pair. A negative
// return from aclshmemi_prepare_and_post_rma is logged only.
#define ACLSHMEM_TYPE_GET_NBI(NAME, TYPE) \
    ACLSHMEM_HOST_API void aclshmem_##NAME##_get_nbi(TYPE *dest, TYPE *source, size_t nelems, int pe) \
    { \
        const int status = aclshmemi_prepare_and_post_rma( \
            "aclshmem_get_" #NAME "_mem_nbi", ACLSHMEMI_OP_GET, NBI, \
            (uint8_t *)dest, (uint8_t *)source, nelems, sizeof(TYPE), pe, \
            nullptr, 0, 0, 1, 1, \
            g_state_host.default_stream, g_state_host.default_block_num); \
        if (status >= 0) { \
            return; \
        } \
        SHM_LOG_ERROR("device calling transfer failed"); \
    }
ACLSHMEM_TYPE_FUNC(ACLSHMEM_TYPE_GET_NBI)
#undef ACLSHMEM_TYPE_GET_NBI
// Template for the typed strided get entry points aclshmem_<NAME>_iget.
// The caller-supplied dst and sst values are forwarded to
// aclshmemi_prepare_and_post_rma in the two argument slots that the contiguous
// get variants fill with 1, 1. ACLSHMEM_TYPE_FUNC (defined elsewhere) presumably
// instantiates it once per supported (NAME, TYPE) pair. A negative return is
// logged only.
#define ACLSHMEM_TYPE_IGET(NAME, TYPE) \
    ACLSHMEM_HOST_API void aclshmem_##NAME##_iget(TYPE *dest, TYPE *source, ptrdiff_t dst, ptrdiff_t sst, \
                                                  size_t nelems, int pe) \
    { \
        const int status = aclshmemi_prepare_and_post_rma( \
            "aclshmem_iget_" #NAME "", ACLSHMEMI_OP_GET, NO_NBI, \
            (uint8_t *)dest, (uint8_t *)source, nelems, sizeof(TYPE), pe, \
            nullptr, 0, 0, dst, sst, \
            g_state_host.default_stream, g_state_host.default_block_num); \
        if (status >= 0) { \
            return; \
        } \
        SHM_LOG_ERROR("device calling transfer failed"); \
    }
ACLSHMEM_TYPE_FUNC(ACLSHMEM_TYPE_IGET);
#undef ACLSHMEM_TYPE_IGET
// Template for the fixed-width get entry points aclshmem_get<BITS>.
// The element width handed to aclshmemi_prepare_and_post_rma is (BITS) / 8 bytes.
// Note that the parameter named elem_size is forwarded in the slot where the typed
// variants pass their element count (nelems). ACLSHMEM_SIZE_FUNC (defined
// elsewhere) presumably instantiates this once per supported bit width. A negative
// return is logged only.
#define ACLSHMEM_GET_SIZE(BITS) \
    ACLSHMEM_HOST_API void aclshmem_get##BITS(void *dst, void *src, uint32_t elem_size, int32_t pe) \
    { \
        const int status = aclshmemi_prepare_and_post_rma( \
            "aclshmem_get_" #BITS "", ACLSHMEMI_OP_GET, NO_NBI, \
            (uint8_t *)dst, (uint8_t *)src, elem_size, (BITS) / 8, pe, \
            nullptr, 0, 0, 1, 1, \
            g_state_host.default_stream, g_state_host.default_block_num); \
        if (status >= 0) { \
            return; \
        } \
        SHM_LOG_ERROR("device calling transfer failed"); \
    }
ACLSHMEM_SIZE_FUNC(ACLSHMEM_GET_SIZE);
#undef ACLSHMEM_GET_SIZE
// Template for the fixed-width get entry points aclshmem_get<BITS>_nbi.
// Same as the plain fixed-width get except that NBI is passed instead of NO_NBI
// and the operation name carries an "_nbi" suffix. The element width is (BITS) / 8
// bytes and elem_size is forwarded in the element-count slot. ACLSHMEM_SIZE_FUNC
// (defined elsewhere) presumably instantiates this once per supported bit width.
// A negative return is logged only.
#define ACLSHMEM_GET_SIZE_NBI(BITS) \
    ACLSHMEM_HOST_API void aclshmem_get##BITS##_nbi(void *dst, void *src, uint32_t elem_size, int32_t pe) \
    { \
        const int status = aclshmemi_prepare_and_post_rma( \
            "aclshmem_get_" #BITS "_nbi", ACLSHMEMI_OP_GET, NBI, \
            (uint8_t *)dst, (uint8_t *)src, elem_size, (BITS) / 8, pe, \
            nullptr, 0, 0, 1, 1, \
            g_state_host.default_stream, g_state_host.default_block_num); \
        if (status >= 0) { \
            return; \
        } \
        SHM_LOG_ERROR("device calling transfer failed"); \
    }
ACLSHMEM_SIZE_FUNC(ACLSHMEM_GET_SIZE_NBI);
#undef ACLSHMEM_GET_SIZE_NBI
// Template for the fixed-width strided get entry points aclshmem_iget<BITS>.
// The element width is (BITS) / 8 bytes; dst and sst are forwarded in the two
// argument slots that the contiguous variants fill with 1, 1. ACLSHMEM_SIZE_FUNC
// (defined elsewhere) presumably instantiates this once per supported bit width.
// A negative return from aclshmemi_prepare_and_post_rma is logged only.
#define ACLSHMEM_IGET_SIZE(BITS) \
    ACLSHMEM_HOST_API void aclshmem_iget##BITS(void *dest, void *source, ptrdiff_t dst, ptrdiff_t sst, \
                                               size_t nelems, int pe) \
    { \
        const int status = aclshmemi_prepare_and_post_rma( \
            "aclshmem_iget_" #BITS "", ACLSHMEMI_OP_GET, NO_NBI, \
            (uint8_t *)dest, (uint8_t *)source, nelems, (BITS) / 8, pe, \
            nullptr, 0, 0, dst, sst, \
            g_state_host.default_stream, g_state_host.default_block_num); \
        if (status >= 0) { \
            return; \
        } \
        SHM_LOG_ERROR("device calling transfer failed"); \
    }
ACLSHMEM_SIZE_FUNC(ACLSHMEM_IGET_SIZE);
#undef ACLSHMEM_IGET_SIZE
// Template for the single-element put entry points aclshmem_<NAME>_p.
// Each generated function hands dst (as a byte pointer), value and pe to the
// matching aclshmemi_prepare_and_post_rma_<NAME>_p helper together with the
// default stream and block number from g_state_host. Whatever that helper returns
// is not inspected here. ACLSHMEM_TYPE_FUNC (defined elsewhere) presumably
// instantiates this once per supported (NAME, TYPE) pair.
#define ACLSHMEM_TYPENAME_P(NAME, TYPE) \
    ACLSHMEM_HOST_API void aclshmem_##NAME##_p(TYPE *dst, const TYPE value, int pe) \
    { \
        uint8_t *const remote_bytes = (uint8_t *)dst; \
        aclshmemi_prepare_and_post_rma_##NAME##_p( \
            "aclshmem_" #NAME "_p", remote_bytes, value, pe, \
            g_state_host.default_stream, g_state_host.default_block_num); \
    }
ACLSHMEM_TYPE_FUNC(ACLSHMEM_TYPENAME_P)
#undef ACLSHMEM_TYPENAME_P
// Template for the single-element get entry points aclshmem_<NAME>_g.
// Each generated function translates src for the target pe with aclshmem_ptr and
// copies sizeof(TYPE) bytes from the translated address into a local with
// aclrtMemcpy(..., ACL_MEMCPY_DEVICE_TO_HOST). It returns the copied value, or a
// value-initialized TYPE when aclshmem_ptr yields nullptr. A non-zero aclrtMemcpy
// status is logged and the local is still returned.
// NOTE(review): aclshmem_ptr can also return host-heap addresses while the copy
// kind is always DEVICE_TO_HOST - confirm this is intended.
// ACLSHMEM_TYPE_FUNC (defined elsewhere) presumably instantiates this once per
// supported (NAME, TYPE) pair.
#define ACLSHMEM_TYPENAME_G(NAME, TYPE) \
    ACLSHMEM_HOST_API TYPE aclshmem_##NAME##_g(TYPE *src, int32_t pe) \
    { \
        TYPE result {}; \
        void *remote = aclshmem_ptr(src, pe); \
        if (remote != nullptr) { \
            const int rc = aclrtMemcpy(&result, sizeof(TYPE), remote, sizeof(TYPE), ACL_MEMCPY_DEVICE_TO_HOST); \
            if (rc != 0) { \
                SHM_LOG_ERROR("aclshmem_g failed"); \
            } \
        } else { \
            SHM_LOG_ERROR("aclshmem_g failed"); \
        } \
        return result; \
    }
ACLSHMEM_TYPE_FUNC(ACLSHMEM_TYPENAME_G)
#undef ACLSHMEM_TYPENAME_G
/**
 * Post an ACLSHMEMI_OP_PUT / NO_NBI request through aclshmemi_prepare_and_post_rma
 * on the default stream of g_state_host.
 *
 * elem_size is forwarded together with an element width of 1 (in the slots where
 * the typed variants pass nelems and sizeof(TYPE)).
 *
 * @param dst       Destination address, passed on as a byte pointer.
 * @param src       Source address, passed on as a byte pointer.
 * @param elem_size Size forwarded to the RMA helper.
 * @param pe        Target PE index.
 *
 * A negative helper status is logged; nothing is returned to the caller.
 */
void aclshmem_putmem(void *dst, void *src, size_t elem_size, int32_t pe)
{
    uint8_t *const dst_bytes = static_cast<uint8_t *>(dst);
    uint8_t *const src_bytes = static_cast<uint8_t *>(src);

    const int status = aclshmemi_prepare_and_post_rma(
        "shmem putmem", ACLSHMEMI_OP_PUT, NO_NBI, dst_bytes, src_bytes, elem_size, 1, pe,
        nullptr, 0, 0, 1, 1, g_state_host.default_stream, g_state_host.default_block_num);
    if (status >= 0) {
        return;
    }
    SHM_LOG_ERROR("aclshmem_putmem failed");
}
/**
 * Post an ACLSHMEMI_OP_GET / NO_NBI request through aclshmemi_prepare_and_post_rma
 * on the default stream of g_state_host.
 *
 * elem_size is forwarded together with an element width of 1 (in the slots where
 * the typed variants pass nelems and sizeof(TYPE)).
 *
 * @param dst       Destination address, passed on as a byte pointer.
 * @param src       Source address, passed on as a byte pointer.
 * @param elem_size Size forwarded to the RMA helper.
 * @param pe        Target PE index.
 *
 * A negative helper status is logged; nothing is returned to the caller.
 */
void aclshmem_getmem(void *dst, void *src, size_t elem_size, int32_t pe)
{
    uint8_t *const dst_bytes = static_cast<uint8_t *>(dst);
    uint8_t *const src_bytes = static_cast<uint8_t *>(src);

    const int status = aclshmemi_prepare_and_post_rma(
        "shmem getmem", ACLSHMEMI_OP_GET, NO_NBI, dst_bytes, src_bytes, elem_size, 1, pe,
        nullptr, 0, 0, 1, 1, g_state_host.default_stream, g_state_host.default_block_num);
    if (status >= 0) {
        return;
    }
    SHM_LOG_ERROR("aclshmem_getmem failed");
}
/**
 * Post an ACLSHMEMI_OP_PUT / NBI request through aclshmemi_prepare_and_post_rma
 * on the default stream of g_state_host.
 *
 * elem_size is forwarded together with an element width of 1 (in the slots where
 * the typed variants pass nelems and sizeof(TYPE)).
 *
 * @param dst       Destination address, passed on as a byte pointer.
 * @param src       Source address, passed on as a byte pointer.
 * @param elem_size Size forwarded to the RMA helper.
 * @param pe        Target PE index.
 *
 * A negative helper status is logged; nothing is returned to the caller.
 */
void aclshmem_putmem_nbi(void *dst, void *src, size_t elem_size, int32_t pe)
{
    uint8_t *const dst_bytes = static_cast<uint8_t *>(dst);
    uint8_t *const src_bytes = static_cast<uint8_t *>(src);

    const int status = aclshmemi_prepare_and_post_rma(
        "aclshmem_putmem_nbi", ACLSHMEMI_OP_PUT, NBI, dst_bytes, src_bytes, elem_size, 1, pe,
        nullptr, 0, 0, 1, 1, g_state_host.default_stream, g_state_host.default_block_num);
    if (status >= 0) {
        return;
    }
    SHM_LOG_ERROR("aclshmem_putmem_nbi failed");
}
/**
 * Post an ACLSHMEMI_OP_GET / NBI request through aclshmemi_prepare_and_post_rma
 * on the default stream of g_state_host.
 *
 * elem_size is forwarded together with an element width of 1 (in the slots where
 * the typed variants pass nelems and sizeof(TYPE)).
 *
 * @param dst       Destination address, passed on as a byte pointer.
 * @param src       Source address, passed on as a byte pointer.
 * @param elem_size Size forwarded to the RMA helper.
 * @param pe        Target PE index.
 *
 * A negative helper status is logged; nothing is returned to the caller.
 */
void aclshmem_getmem_nbi(void *dst, void *src, size_t elem_size, int32_t pe)
{
    uint8_t *const dst_bytes = static_cast<uint8_t *>(dst);
    uint8_t *const src_bytes = static_cast<uint8_t *>(src);

    const int status = aclshmemi_prepare_and_post_rma(
        "aclshmem_getmem_nbi", ACLSHMEMI_OP_GET, NBI, dst_bytes, src_bytes, elem_size, 1, pe,
        nullptr, 0, 0, 1, 1, g_state_host.default_stream, g_state_host.default_block_num);
    if (status >= 0) {
        return;
    }
    SHM_LOG_ERROR("aclshmem_getmem_nbi failed");
}
/**
 * Post an ACLSHMEMI_OP_GET / NO_NBI request through aclshmemi_prepare_and_post_rma
 * on a caller-chosen stream.
 *
 * @param dst       Destination address, passed on as a byte pointer.
 * @param src       Source address, passed on as a byte pointer.
 * @param elem_size Size forwarded to the RMA helper together with an element width of 1.
 * @param pe        Target PE index.
 * @param stream    Stream to post on; a null stream is replaced by
 *                  g_state_host.default_stream.
 *
 * A negative helper status is logged; nothing is returned to the caller.
 */
void aclshmemx_getmem_on_stream(void* dst, void* src, size_t elem_size, int32_t pe, aclrtStream stream)
{
    if (!stream) {
        stream = g_state_host.default_stream;
    }

    uint8_t *const dst_bytes = static_cast<uint8_t *>(dst);
    uint8_t *const src_bytes = static_cast<uint8_t *>(src);
    const int status = aclshmemi_prepare_and_post_rma(
        "aclshmemx_getmem_on_stream", ACLSHMEMI_OP_GET, NO_NBI, dst_bytes, src_bytes, elem_size, 1, pe,
        nullptr, 0, 0, 1, 1, stream, g_state_host.default_block_num);
    if (status >= 0) {
        return;
    }
    SHM_LOG_ERROR("aclshmemx_getmem_on_stream failed");
}
/**
 * Post an ACLSHMEMI_OP_PUT / NO_NBI request through aclshmemi_prepare_and_post_rma
 * on a caller-chosen stream.
 *
 * @param dst       Destination address, passed on as a byte pointer.
 * @param src       Source address, passed on as a byte pointer.
 * @param elem_size Size forwarded to the RMA helper together with an element width of 1.
 * @param pe        Target PE index.
 * @param stream    Stream to post on; a null stream is replaced by
 *                  g_state_host.default_stream.
 *
 * Any non-zero helper status is logged; nothing is returned to the caller.
 * NOTE(review): aclshmemx_getmem_on_stream and the non-stream variants in this file
 * only log when the status is negative - confirm which check is intended.
 */
void aclshmemx_putmem_on_stream(void* dst, void* src, size_t elem_size, int32_t pe, aclrtStream stream)
{
    if (!stream) {
        stream = g_state_host.default_stream;
    }

    uint8_t *const dst_bytes = static_cast<uint8_t *>(dst);
    uint8_t *const src_bytes = static_cast<uint8_t *>(src);
    const int status = aclshmemi_prepare_and_post_rma(
        "aclshmemx_putmem_on_stream", ACLSHMEMI_OP_PUT, NO_NBI, dst_bytes, src_bytes, elem_size, 1, pe,
        nullptr, 0, 0, 1, 1, stream, g_state_host.default_block_num);
    if (status == 0) {
        return;
    }
    SHM_LOG_ERROR("aclshmemx_putmem_on_stream failed");
}
/**
 * Post an ACLSHMEMI_OP_QUIET request through aclshmemi_prepare_and_post_rma on a
 * caller-chosen stream.
 *
 * No buffers are involved: both address arguments are nullptr and the size,
 * element-width and PE arguments are 0.
 *
 * @param stream Stream to post on; a null stream is replaced by
 *               g_state_host.default_stream.
 *
 * Any non-zero helper status is logged; nothing is returned to the caller.
 */
void aclshmemx_quiet_on_stream(aclrtStream stream)
{
    if (!stream) {
        stream = g_state_host.default_stream;
    }

    const int status = aclshmemi_prepare_and_post_rma(
        "aclshmemx_quiet_on_stream", ACLSHMEMI_OP_QUIET, NO_NBI,
        nullptr, nullptr, 0, 0, 0,
        nullptr, 0, 0, 1, 1,
        stream, g_state_host.default_block_num);
    if (status == 0) {
        return;
    }
    SHM_LOG_ERROR("aclshmemx_quiet_on_stream failed");
}