* This file is part of the MindStudio project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* ------------------------------------------------------------------------- */
#include "MemoryContext.h"
#include <string>
#include "utils/InjectLogger.h"
void MemoryContext::Append(const MemAddr addr, uint64_t size)
{
std::lock_guard<std::mutex> mapLock(mapMtx_);
if (addr == nullptr) {
WARN_LOG("Failed to add memory map, devPtr is nullptr.");
return;
}
memSectionMap_[addr] = {addr, nullptr, size, aclrtMemMallocPolicy::ACL_MEM_MALLOC_HUGE_FIRST};
}
/**
 * Stops tracking the section registered under addr.
 *
 * The section's snapshot buffer is freed when one has been allocated, then
 * the entry is removed from memSectionMap_. An address that is not tracked
 * is silently ignored.
 *
 * @param addr key of the section to remove.
 */
void MemoryContext::Discard(MemAddr addr)
{
    std::lock_guard<std::mutex> guard(mapMtx_);
    auto entry = memSectionMap_.find(addr);
    if (entry != memSectionMap_.end()) {
        const auto &tracked = entry->second;
        if (tracked.snapshotAddr != nullptr) {
            // The result of the free call is intentionally ignored.
            (void)aclrtFreeImplOrigin(tracked.snapshotAddr);
        }
        memSectionMap_.erase(entry);
    }
}
bool MemoryContext::Backup()
{
std::lock_guard<std::mutex> mapLock(mapMtx_);
for (auto &p : memSectionMap_) {
auto ret = CreateSnapshot(p.second, GetCurrentStream());
if (!ret) {
return false;
}
}
return true;
}
bool MemoryContext::Restore()
{
std::lock_guard<std::mutex> mapLock(mapMtx_);
for (const auto &p : memSectionMap_) {
auto ret = RestoreFromSnapshot(p.second, GetCurrentStream());
if (!ret) {
return false;
}
}
return true;
}
/**
 * Destructor: delegates all cleanup to DiscardAll().
 */
MemoryContext::~MemoryContext()
{
    this->DiscardAll();
}
void MemoryContext::DiscardAll()
{
for (const auto &p : memSectionMap_) {
if (p.second.snapshotAddr != nullptr) {
(void)aclrtFreeImplOrigin(p.second.snapshotAddr);
}
}
std::lock_guard<std::mutex> mapLock(mapMtx_);
memSectionMap_.clear();
stream_ = nullptr;
}
bool MemoryContext::CreateSnapshot(MemSection §ion, aclrtStream stream) const
{
aclError ret;
if (section.originAddr == nullptr) {
ERROR_LOG("failed to create snapshot for an invalid addr");
return false;
}
if (section.snapshotAddr == nullptr) {
ret = aclrtMallocImplOrigin(§ion.snapshotAddr, section.size, section.policy);
if (ret != ACL_SUCCESS) {
ERROR_LOG("failed to create memory snapshot for 0x%lx ret=%d",
reinterpret_cast<uint64_t>(section.originAddr), ret);
return false;
}
}
MemCopyDirection dir = MemCopyDirection::ORIGIN_TO_SNAPSHOT;
if (stream == nullptr) {
return MemCopySync(section, dir);
}
return MemCopyWithStream(section, stream, dir);
}
bool MemoryContext::RestoreFromSnapshot(const MemSection §ion, aclrtStream stream) const
{
if (section.originAddr == nullptr) {
ERROR_LOG("failed to restore an invalid addr");
return false;
}
if (section.snapshotAddr == nullptr) {
ERROR_LOG("there is no snapshot memory allocated for 0x%lx",
reinterpret_cast<uint64_t>(section.originAddr));
return false;
}
MemCopyDirection dir = MemCopyDirection::SNAPSHOT_TO_ORIGIN;
if (stream == nullptr) {
return MemCopySync(section, dir);
}
return MemCopyWithStream(section, stream, dir);
}
/**
 * Returns the stream cached in stream_, looking it up on demand.
 *
 * When nothing is cached, aclrtCtxGetCurrentDefaultStreamImplOrigin() is asked
 * to fill stream_. If that call fails, stream_ is reset to nullptr and the
 * failure is reported at debug level. Access to stream_ is serialized with
 * streamMtx_.
 *
 * @return the cached or freshly obtained stream, or nullptr when the lookup failed.
 */
aclrtStream MemoryContext::GetCurrentStream()
{
    std::lock_guard<std::mutex> guard(streamMtx_);
    if (stream_ == nullptr) {
        const aclError status = aclrtCtxGetCurrentDefaultStreamImplOrigin(&stream_);
        if (status != ACL_SUCCESS) {
            stream_ = nullptr;
            DEBUG_LOG("aclrtCtxGetCurrentDefaultStreamImplOrigin failed. ret=%d", status);
        }
    }
    return stream_;
}
bool MemoryContext::MemCopyWithStream(const MemSection §ion, aclrtStream stream, MemCopyDirection dir) const
{
aclError ret;
MemAddr dst;
MemAddr src;
if (dir == MemCopyDirection::ORIGIN_TO_SNAPSHOT) {
dst = section.snapshotAddr;
src = section.originAddr;
} else {
dst = section.originAddr;
src = section.snapshotAddr;
}
ret = aclrtMemcpyAsyncImplOrigin(dst, section.size, src, section.size, ACL_MEMCPY_DEVICE_TO_DEVICE, stream);
if (ret != ACL_SUCCESS) {
ERROR_LOG("aclrtMemcpyAsyncImpl copy return failed. ret=%d, dst_addr=%p, dst_size=%lu, "
"src_addr=%p, src_size=%lu", ret, dst, section.size, src, section.size);
return false;
}
ret = aclrtSynchronizeStreamImplOrigin(stream);
if (ret != ACL_SUCCESS) {
ERROR_LOG("aclrtSynchronizeStreamImpl failed. ret=%d", ret);
return false;
}
return true;
}
bool MemoryContext::MemCopySync(const MemSection §ion, MemCopyDirection dir) const
{
MemAddr dst;
MemAddr src;
if (dir == MemCopyDirection::ORIGIN_TO_SNAPSHOT) {
dst = section.snapshotAddr;
src = section.originAddr;
} else {
dst = section.originAddr;
src = section.snapshotAddr;
}
aclError ret = aclrtMemcpyImplOrigin(dst, section.size, src, section.size, ACL_MEMCPY_DEVICE_TO_DEVICE);
if (ret != ACL_SUCCESS) {
ERROR_LOG("aclrtMemcpyImpl copy return failed. ret=%d, dst_addr=%p, dst_size=%lu, "
"src_addr=%p, src_size=%lu", ret, section.snapshotAddr,
section.size, section.originAddr, section.size);
return false;
}
return true;
}