* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "om2_malloc_helper.h"
#include <unordered_map>
#include "ge_common/debug/ge_log.h"
#include "runtime/rt_external_mem.h"
namespace gert {
namespace {
using ge::FAILED;
struct MemTypeInfo {
aclrtMemMallocPolicy policy;
aclError (*handler)(void **, size_t, uint16_t, aclrtMemMallocPolicy);
};
// Allocates memory through aclrtMallocWithCfg, tagging the request with a module id.
//
// ptr:       receives the allocation. Must not be null (it is dereferenced unconditionally).
//            Holds nullptr whenever the call fails.
// size:      size forwarded to aclrtMallocWithCfg.
// module_id: stored in the ACL_RT_MEM_ATTR_MODULE_ID attribute of the malloc config.
// policy:    malloc policy forwarded to aclrtMallocWithCfg.
//
// Returns the aclError reported by aclrtMallocWithCfg; a failure is also logged.
aclError MallocDevice(void **ptr, size_t size, uint16_t module_id, aclrtMemMallocPolicy policy) {
  // Start from a known state, the same way MallocHost and MallocForTaskScheduler do.
  *ptr = nullptr;
  // Value-initialize both structs so that no indeterminate bytes are handed to the runtime;
  // only the attribute type and moduleId are set explicitly below.
  aclrtMallocAttribute attr{};
  attr.attr = ACL_RT_MEM_ATTR_MODULE_ID;
  attr.value.moduleId = module_id;
  aclrtMallocConfig cfg{};
  cfg.attrs = &attr;
  cfg.numAttrs = 1UL;
  const aclError ret = aclrtMallocWithCfg(ptr, size, policy, &cfg);
  if (ret != ACL_SUCCESS) {
    // Do not let the caller observe whatever the failed call may have left in *ptr.
    *ptr = nullptr;
    GELOGE(FAILED, "[Call][aclrtMallocWithCfg] failed, size: %zu, policy: %d, ret: %d", size,
           static_cast<int32_t>(policy), ret);
  }
  return ret;
}
// Allocates memory through aclrtMallocHostWithCfg, tagging the request with a module id.
//
// ptr:       receives the allocation. Must not be null (it is dereferenced unconditionally).
//            Holds nullptr whenever the call fails.
// size:      size forwarded to aclrtMallocHostWithCfg.
// module_id: stored in the ACL_RT_MEM_ATTR_MODULE_ID attribute of the malloc config.
//
// Returns the aclError reported by aclrtMallocHostWithCfg; a failure is also logged.
aclError MallocHost(void **ptr, size_t size, uint16_t module_id) {
  *ptr = nullptr;
  // Value-initialize both structs so that no indeterminate bytes are handed to the runtime;
  // only the attribute type and moduleId are set explicitly below.
  aclrtMallocAttribute attr{};
  attr.attr = ACL_RT_MEM_ATTR_MODULE_ID;
  attr.value.moduleId = module_id;
  aclrtMallocConfig cfg{};
  cfg.attrs = &attr;
  cfg.numAttrs = 1UL;
  const aclError ret = aclrtMallocHostWithCfg(ptr, size, &cfg);
  if (ret != ACL_SUCCESS) {
    // Do not let the caller observe whatever the failed call may have left in *ptr.
    *ptr = nullptr;
    GELOGE(FAILED, "[Call][aclrtMallocHostWithCfg] failed, size: %zu, ret: %d", size, ret);
  }
  return ret;
}
// Allocates memory through aclrtMallocForTaskScheduler, tagging the request with a module id.
// Note the parameter order: `policy` comes before `module_id` here, unlike MallocDevice.
//
// ptr:       receives the allocation. Must not be null (it is dereferenced unconditionally).
//            Holds nullptr whenever the call fails.
// size:      size forwarded to aclrtMallocForTaskScheduler.
// policy:    malloc policy forwarded to aclrtMallocForTaskScheduler.
// module_id: stored in the ACL_RT_MEM_ATTR_MODULE_ID attribute of the malloc config.
//
// Returns the aclError reported by aclrtMallocForTaskScheduler; a failure is also logged.
aclError MallocForTaskScheduler(void **ptr, size_t size, aclrtMemMallocPolicy policy, uint16_t module_id) {
  *ptr = nullptr;
  // Value-initialize both structs so that no indeterminate bytes are handed to the runtime;
  // only the attribute type and moduleId are set explicitly below.
  aclrtMallocAttribute attr{};
  attr.attr = ACL_RT_MEM_ATTR_MODULE_ID;
  attr.value.moduleId = module_id;
  aclrtMallocConfig cfg{};
  cfg.attrs = &attr;
  cfg.numAttrs = 1UL;
  const aclError ret = aclrtMallocForTaskScheduler(ptr, size, policy, &cfg);
  if (ret != ACL_SUCCESS) {
    // Do not let the caller observe whatever the failed call may have left in *ptr.
    *ptr = nullptr;
    GELOGE(FAILED, "[Call][aclrtMallocForTaskScheduler] failed, size: %zu, ret: %d", size, ret);
  }
  return ret;
}
// Dispatch-table adapter: takes the common (ptr, size, module_id, policy) argument order and
// forwards to MallocForTaskScheduler, which expects policy before module_id.
aclError HandleTs(void **ptr, size_t size, uint16_t module_id, aclrtMemMallocPolicy policy) {
  const aclError status = MallocForTaskScheduler(ptr, size, policy, module_id);
  return status;
}
// Dispatch-table adapter for MallocHost. The trailing policy parameter exists only so the
// signature matches the other handlers; MallocHost takes no policy, so it is ignored.
aclError HandleHost(void **ptr, size_t size, uint16_t module_id, aclrtMemMallocPolicy /* policy */) {
  const aclError status = MallocHost(ptr, size, module_id);
  return status;
}
}
// Allocates memory for the requested runtime memory type by dispatching to the matching
// allocation routine.
//
// ptr:       receives the allocation; a null `ptr` is rejected with ACL_ERROR_RT_PARAM_INVALID.
//            *ptr is reset to nullptr before anything else is attempted.
// size:      requested size; a size of 0 succeeds immediately and leaves *ptr as nullptr.
// mem_type:  key into the dispatch table below. A type that is not listed is logged as a
//            warning and served by MallocDevice with ACL_MEM_TYPE_HIGH_BAND_WIDTH.
// module_id: forwarded unchanged to the selected allocation routine.
//
// Returns ACL_SUCCESS, ACL_ERROR_RT_PARAM_INVALID, or whatever the selected routine returns.
aclError Om2Malloc(void **ptr, size_t size, rtMemType_t mem_type, uint16_t module_id) {
  if (ptr == nullptr) {
    GELOGE(FAILED, "[Check][Om2Malloc] ptr is nullptr.");
    return ACL_ERROR_RT_PARAM_INVALID;
  }

  *ptr = nullptr;
  if (size == 0U) {
    return ACL_SUCCESS;
  }

  // Memory type -> {policy, routine}. The policy stored for RT_MEMORY_HOST is never used,
  // because HandleHost ignores its policy argument.
  static const std::unordered_map<rtMemType_t, MemTypeInfo> kMemTypeMap = {
      {RT_MEMORY_TS, {ACL_MEM_MALLOC_HUGE_FIRST, &HandleTs}},
      {RT_MEMORY_HOST, {ACL_MEM_TYPE_HIGH_BAND_WIDTH, &HandleHost}},
      {RT_MEMORY_HBM, {ACL_MEM_TYPE_HIGH_BAND_WIDTH, &MallocDevice}},
      {RT_MEMORY_DEFAULT, {ACL_MEM_TYPE_HIGH_BAND_WIDTH, &MallocDevice}},
      {RT_MEMORY_RDMA_HBM, {ACL_MEM_TYPE_HIGH_BAND_WIDTH, &MallocDevice}},
      {RT_MEMORY_SPM, {ACL_MEM_TYPE_HIGH_BAND_WIDTH, &MallocDevice}},
      {RT_MEMORY_P2P_HBM, {ACL_MEM_MALLOC_HUGE_FIRST_P2P, &MallocDevice}},
      {RT_MEMORY_DDR, {ACL_MEM_TYPE_LOW_BAND_WIDTH, &MallocDevice}},
      {RT_MEMORY_DDR_NC, {ACL_MEM_TYPE_LOW_BAND_WIDTH, &MallocDevice}},
      {RT_MEMORY_P2P_DDR, {ACL_MEM_MALLOC_HUGE_FIRST_P2P, &MallocDevice}},
  };

  const auto entry = kMemTypeMap.find(mem_type);
  if (entry == kMemTypeMap.end()) {
    // Not in the table: warn and use the default allocation path.
    GELOGW("[Call][Om2Malloc] unknown mem_type: %u, fallback to default", mem_type);
    return MallocDevice(ptr, size, module_id, ACL_MEM_TYPE_HIGH_BAND_WIDTH);
  }

  const MemTypeInfo &info = entry->second;
  return info.handler(ptr, size, module_id, info.policy);
}
}