* This file is part of the MindStudio project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* ------------------------------------------------------------------------- */
#define private public
#include "runtime/inject_helpers/MemoryContext.h"
#undef private
#include <string>
#include <vector>
#include <sys/stat.h>
#include <sys/types.h>
#include <gtest/gtest.h>
#include "runtime/RuntimeOrigin.h"
#include "acl_rt_impl/AscendclImplOrigin.h"
#include "utils/FileSystem.h"
#include "utils/PipeCall.h"
#include "mockcpp/mockcpp.hpp"
constexpr uint64_t MEM_ADDR = 0x12c045400000U;
constexpr uint64_t MEM_SIZE = 0x1000U;
TEST(MemoryContext, record_and_discard_multiple_normal_memory_section)
{
auto &inst = MemoryContext::Instance();
MemoryContext::Instance().DiscardAll();
uint32_t count = 0U;
MemoryContext::Instance().Append(reinterpret_cast<void *>(MEM_ADDR), MEM_SIZE);
count++;
ASSERT_EQ(inst.memSectionMap_.size(), count);
MemoryContext::Instance().Append(reinterpret_cast<void *>(MEM_ADDR + MEM_SIZE), 2 * MEM_SIZE);
count++;
ASSERT_EQ(inst.memSectionMap_.size(), count);
MemoryContext::Instance().Discard(reinterpret_cast<void *>(MEM_ADDR));
count--;
ASSERT_EQ(inst.memSectionMap_.size(), count);
MemoryContext::Instance().Discard(reinterpret_cast<void *>(MEM_ADDR + MEM_SIZE));
count--;
ASSERT_EQ(inst.memSectionMap_.size(), count);
}
TEST(MemoryContext, record_null_memory_section)
{
auto &inst = MemoryContext::Instance();
MemoryContext::Instance().DiscardAll();
uint32_t count = 0U;
MemoryContext::Instance().Append(nullptr, MEM_SIZE);
ASSERT_EQ(inst.memSectionMap_.size(), count);
}
TEST(MemoryContext, record_and_discard_the_same_memory_section)
{
auto &inst = MemoryContext::Instance();
uint32_t count = 0U;
MemoryContext::Instance().Append(reinterpret_cast<void *>(MEM_ADDR), MEM_SIZE);
count++;
ASSERT_EQ(inst.memSectionMap_.size(), count);
MemoryContext::Instance().Append(reinterpret_cast<void *>(MEM_ADDR), MEM_SIZE * 2);
ASSERT_EQ(inst.memSectionMap_.size(), count);
MemoryContext::Instance().Discard(reinterpret_cast<void *>(MEM_ADDR));
count--;
ASSERT_EQ(inst.memSectionMap_.size(), count);
MemoryContext::Instance().Append(reinterpret_cast<void *>(MEM_ADDR), MEM_SIZE);
count++;
ASSERT_EQ(inst.memSectionMap_.size(), count);
MemoryContext::Instance().DiscardAll();
count = 0U;
ASSERT_EQ(inst.memSectionMap_.size(), count);
}
TEST(MemoryContext, record_and_discard_all_memory_section)
{
auto &inst = MemoryContext::Instance();
uint32_t count = 0U;
MemoryContext::Instance().Append(reinterpret_cast<void *>(MEM_ADDR), MEM_SIZE);
count++;
ASSERT_EQ(inst.memSectionMap_.size(), count);
MemoryContext::Instance().Append(reinterpret_cast<void *>(MEM_ADDR + MEM_SIZE), MEM_SIZE * 2);
count++;
ASSERT_EQ(inst.memSectionMap_.size(), count);
MemoryContext::Instance().DiscardAll();
count = 0U;
ASSERT_EQ(inst.memSectionMap_.size(), count);
}
TEST(MemoryContext, create_snapshot_for_memory_without_stream_and_restore_success)
{
GlobalMockObject::verify();
auto &inst = MemoryContext::Instance();
uint32_t count = 0U;
MemoryContext::Instance().Append(reinterpret_cast<void *>(MEM_ADDR), MEM_SIZE);
count++;
ASSERT_EQ(inst.memSectionMap_.size(), count);
MemoryContext::Instance().Append(reinterpret_cast<void *>(MEM_ADDR + MEM_SIZE), MEM_SIZE * 2);
count++;
ASSERT_EQ(inst.memSectionMap_.size(), count);
bool ret = false;
void *addr = reinterpret_cast<void *>(MEM_ADDR);
MOCKER(&aclrtMallocImplOrigin)
.expects(exactly(2))
.with(outBoundP(&addr, sizeof(void *)), any(), any())
.will(returnValue(ACL_SUCCESS));
MOCKER(&aclrtMemcpyImplOrigin)
.expects(exactly(4))
.will(returnValue(ACL_SUCCESS));
MOCKER(&aclrtFreeImplOrigin)
.expects(exactly(2))
.will(returnValue(ACL_SUCCESS));
MOCKER(&aclrtCtxGetCurrentDefaultStreamImplOrigin)
.stubs()
.will(returnValue(ACL_ERROR_BAD_ALLOC));
ret = inst.Backup();
ASSERT_EQ(ret, true);
ret = inst.Restore();
ASSERT_EQ(ret, true);
MemoryContext::Instance().DiscardAll();
count = 0U;
ASSERT_EQ(inst.memSectionMap_.size(), count);
GlobalMockObject::verify();
}
TEST(MemoryContext, create_snapshot_for_memory_with_stream_and_restore_success)
{
GlobalMockObject::verify();
auto &inst = MemoryContext::Instance();
uint32_t count = 0U;
MemoryContext::Instance().Append(reinterpret_cast<void *>(MEM_ADDR), MEM_SIZE);
count++;
ASSERT_EQ(inst.memSectionMap_.size(), count);
MemoryContext::Instance().Append(reinterpret_cast<void *>(MEM_ADDR + MEM_SIZE), MEM_SIZE * 2);
count++;
ASSERT_EQ(inst.memSectionMap_.size(), count);
bool ret = false;
void *addr = reinterpret_cast<void *>(MEM_ADDR);
aclrtStream stream = reinterpret_cast<void *>(0xfffff8c0U);
MOCKER(&aclrtMallocImplOrigin)
.expects(exactly(2))
.with(outBoundP(&addr, sizeof(void *)), any(), any())
.will(returnValue(ACL_SUCCESS));
MOCKER(&aclrtMemcpyAsyncImplOrigin)
.defaults()
.will(returnValue(ACL_SUCCESS));
MOCKER(&aclrtSynchronizeStreamImplOrigin)
.defaults()
.will(returnValue(ACL_SUCCESS));
MOCKER(&aclrtCtxGetCurrentDefaultStreamImplOrigin)
.expects(once())
.with(outBoundP(&stream, sizeof(stream)))
.will(returnValue(ACL_SUCCESS));
ret = inst.Backup();
ASSERT_EQ(ret, true);
ret = inst.Restore();
ASSERT_EQ(ret, true);
MemoryContext::Instance().DiscardAll();
count = 0U;
ASSERT_EQ(inst.memSectionMap_.size(), count);
GlobalMockObject::verify();
}
TEST(MemoryContext, create_snapshot_failed_for_invalid_addr)
{
auto &inst = MemoryContext::Instance();
uint32_t count = 0U;
MemoryContext::Instance().DiscardAll();
int a = 10;
MemoryContext::Instance().Append(&a, MEM_SIZE);
count++;
ASSERT_EQ(inst.memSectionMap_.size(), count);
bool ret = false;
ret = inst.Backup();
ASSERT_EQ(ret, false);
MemoryContext::Instance().DiscardAll();
count = 0U;
ASSERT_EQ(inst.memSectionMap_.size(), count);
}
TEST(MemoryContext, create_snapshot_failed_for_invalid_section)
{
auto &inst = MemoryContext::Instance();
MemoryContext::Instance().DiscardAll();
uint32_t count = 0U;
MemoryContext::MemSection section;
section.originAddr = nullptr;
ASSERT_EQ(MemoryContext::Instance().CreateSnapshot(section), false);
}
TEST(MemoryContext, create_snapshot_failed_due_to_alloc_failure)
{
GlobalMockObject::verify();
auto &inst = MemoryContext::Instance();
uint32_t count = 0U;
MemoryContext::Instance().Append(reinterpret_cast<void *>(MEM_ADDR), MEM_SIZE);
count++;
ASSERT_EQ(inst.memSectionMap_.size(), count);
bool ret = false;
MOCKER(aclrtMallocImplOrigin)
.expects(once())
.will(returnValue(ACL_ERROR_BAD_ALLOC));
ret = inst.Backup();
ASSERT_EQ(ret, false);
MemoryContext::Instance().DiscardAll();
count = 0U;
ASSERT_EQ(inst.memSectionMap_.size(), count);
GlobalMockObject::verify();
}
TEST(MemoryContext, create_snapshot_without_stream_failed_due_to_memcpy_failure)
{
GlobalMockObject::verify();
auto &inst = MemoryContext::Instance();
uint32_t count = 0U;
MemoryContext::Instance().Append(reinterpret_cast<void *>(MEM_ADDR), MEM_SIZE);
count++;
ASSERT_EQ(inst.memSectionMap_.size(), count);
bool ret = false;
void *addr = reinterpret_cast<void *>(MEM_ADDR);
MOCKER(&aclrtMallocImplOrigin)
.expects(once())
.with(outBoundP(&addr, sizeof(void *)), any(), any())
.will(returnValue(ACL_SUCCESS));
MOCKER(&aclrtMemcpyImplOrigin)
.expects(once())
.will(returnValue(ACL_ERROR_BAD_ALLOC));
MOCKER(&aclrtCtxGetCurrentDefaultStreamImplOrigin)
.defaults()
.will(returnValue(ACL_ERROR_BAD_ALLOC));
ret = inst.Backup();
ASSERT_EQ(ret, false);
MemoryContext::Instance().DiscardAll();
count = 0U;
ASSERT_EQ(inst.memSectionMap_.size(), count);
GlobalMockObject::verify();
}
TEST(MemoryContext, create_snapshot_with_stream_failed_due_to_memcpy_failure)
{
GlobalMockObject::verify();
auto &inst = MemoryContext::Instance();
uint32_t count = 0U;
MemoryContext::Instance().Append(reinterpret_cast<void *>(MEM_ADDR), MEM_SIZE);
count++;
ASSERT_EQ(inst.memSectionMap_.size(), count);
bool ret = false;
void *addr = reinterpret_cast<void *>(MEM_ADDR);
rtStream_t stream = reinterpret_cast<void *>(0xfffff8c0U);
MOCKER(&aclrtMallocImplOrigin)
.expects(once())
.with(outBoundP(&addr, sizeof(void *)), any(), any())
.will(returnValue(ACL_SUCCESS));
MOCKER(&aclrtMemcpyAsyncImplOrigin)
.expects(once())
.will(returnValue(ACL_ERROR_BAD_ALLOC));
MOCKER(&aclrtCtxGetCurrentDefaultStreamImplOrigin)
.defaults()
.with(outBoundP(&stream, sizeof(stream)))
.will(returnValue(ACL_SUCCESS));
ret = inst.Backup();
ASSERT_EQ(ret, false);
MemoryContext::Instance().DiscardAll();
count = 0U;
ASSERT_EQ(inst.memSectionMap_.size(), count);
GlobalMockObject::verify();
}
TEST(MemoryContext, memcpy_with_stream_success)
{
GlobalMockObject::verify();
auto &inst = MemoryContext::Instance();
rtStream_t stream = reinterpret_cast<void *>(0xfffff8c0U);
MemoryContext::MemSection section = {reinterpret_cast<void *>(MEM_ADDR),
reinterpret_cast<void *>(MEM_ADDR + MEM_SIZE), MEM_SIZE, aclrtMemMallocPolicy::ACL_MEM_MALLOC_HUGE_FIRST};
bool ret = false;
MOCKER(&aclrtMemcpyAsyncImplOrigin)
.defaults()
.will(returnValue(ACL_SUCCESS));
MOCKER(&aclrtSynchronizeStreamImplOrigin)
.defaults()
.will(returnValue(ACL_SUCCESS));
ret = inst.MemCopyWithStream(section, stream,
MemoryContext::MemCopyDirection::ORIGIN_TO_SNAPSHOT);
ASSERT_EQ(ret, true);
ret = inst.MemCopyWithStream(section, stream,
MemoryContext::MemCopyDirection::SNAPSHOT_TO_ORIGIN);
ASSERT_EQ(ret, true);
MemoryContext::Instance().DiscardAll();
GlobalMockObject::verify();
}
TEST(MemoryContext, memcpy_with_stream_failed_due_to_stream_sync_failure)
{
GlobalMockObject::verify();
auto &inst = MemoryContext::Instance();
rtStream_t stream = reinterpret_cast<void *>(0xfffff8c0U);
MemoryContext::MemSection section = {reinterpret_cast<void *>(MEM_ADDR),
reinterpret_cast<void *>(MEM_ADDR + MEM_SIZE), MEM_SIZE};
bool ret = false;
MOCKER(&aclrtMemcpyAsyncImplOrigin)
.defaults()
.will(returnValue(ACL_SUCCESS));
MOCKER(&aclrtSynchronizeStreamImplOrigin)
.defaults()
.will(returnValue(ACL_ERROR_BAD_ALLOC));
ret = inst.MemCopyWithStream(section, stream,
MemoryContext::MemCopyDirection::ORIGIN_TO_SNAPSHOT);
ASSERT_EQ(ret, false);
ret = inst.MemCopyWithStream(section, stream,
MemoryContext::MemCopyDirection::SNAPSHOT_TO_ORIGIN);
ASSERT_EQ(ret, false);
MemoryContext::Instance().DiscardAll();
GlobalMockObject::verify();
}
TEST(MemoryContext, memcpy_with_stream_failed_due_to_memcpyasync_failure)
{
GlobalMockObject::verify();
auto &inst = MemoryContext::Instance();
rtStream_t stream = reinterpret_cast<void *>(0xfffff8c0U);
MemoryContext::MemSection section = {reinterpret_cast<void *>(MEM_ADDR),
reinterpret_cast<void *>(MEM_ADDR + MEM_SIZE), MEM_SIZE};
bool ret = false;
MOCKER(&aclrtMemcpyAsyncImplOrigin)
.defaults()
.will(returnValue(ACL_ERROR_BAD_ALLOC));
ret = inst.MemCopyWithStream(section, stream,
MemoryContext::MemCopyDirection::ORIGIN_TO_SNAPSHOT);
ASSERT_EQ(ret, false);
ret = inst.MemCopyWithStream(section, stream,
MemoryContext::MemCopyDirection::SNAPSHOT_TO_ORIGIN);
ASSERT_EQ(ret, false);
MemoryContext::Instance().DiscardAll();
GlobalMockObject::verify();
}
TEST(MemoryContext, memcpy_sync_success)
{
GlobalMockObject::verify();
auto &inst = MemoryContext::Instance();
MemoryContext::MemSection section = {reinterpret_cast<void *>(MEM_ADDR),
reinterpret_cast<void *>(MEM_ADDR + MEM_SIZE), MEM_SIZE};
bool ret = false;
MOCKER(&aclrtMemcpyImplOrigin)
.defaults()
.will(returnValue(ACL_SUCCESS));
ret = inst.MemCopySync(section,
MemoryContext::MemCopyDirection::ORIGIN_TO_SNAPSHOT);
ASSERT_EQ(ret, true);
ret = inst.MemCopySync(section,
MemoryContext::MemCopyDirection::SNAPSHOT_TO_ORIGIN);
ASSERT_EQ(ret, true);
MemoryContext::Instance().DiscardAll();
GlobalMockObject::verify();
}
TEST(MemoryContext, memcpy_sync_failed_due_to_memcpy_failure)
{
GlobalMockObject::verify();
auto &inst = MemoryContext::Instance();
MemoryContext::MemSection section = {reinterpret_cast<void *>(MEM_ADDR),
reinterpret_cast<void *>(MEM_ADDR + MEM_SIZE), MEM_SIZE};
bool ret = false;
MOCKER(&aclrtMemcpyImplOrigin)
.defaults()
.will(returnValue(ACL_ERROR_BAD_ALLOC));
ret = inst.MemCopySync(section,
MemoryContext::MemCopyDirection::ORIGIN_TO_SNAPSHOT);
ASSERT_EQ(ret, false);
ret = inst.MemCopySync(section,
MemoryContext::MemCopyDirection::SNAPSHOT_TO_ORIGIN);
ASSERT_EQ(ret, false);
MemoryContext::Instance().DiscardAll();
GlobalMockObject::verify();
}
TEST(MemoryContext, restore_failed_due_to_memcpy_failure)
{
auto &inst = MemoryContext::Instance();
uint32_t count = 0U;
MemoryContext::Instance().Append(reinterpret_cast<void *>(MEM_ADDR), MEM_SIZE);
count++;
ASSERT_EQ(inst.memSectionMap_.size(), count);
bool ret = false;
void *addr = reinterpret_cast<void *>(MEM_ADDR);
rtStream_t stream = reinterpret_cast<void *>(0xfffff8c0U);
MOCKER(&aclrtMallocImplOrigin)
.defaults()
.with(outBoundP(&addr, sizeof(void *)), any(), any())
.will(returnValue(ACL_SUCCESS));
MOCKER(&aclrtMemcpyAsyncImplOrigin)
.defaults()
.will(returnValue(ACL_SUCCESS));
MOCKER(&aclrtFreeImplOrigin)
.defaults()
.will(returnValue(ACL_SUCCESS));
MOCKER(&aclrtCtxGetCurrentDefaultStreamImplOrigin)
.defaults()
.with(outBoundP(&stream, sizeof(stream)))
.will(returnValue(ACL_SUCCESS));
MOCKER(&aclrtSynchronizeStreamImplOrigin)
.defaults()
.will(returnValue(ACL_SUCCESS));
ret = inst.Backup();
ASSERT_EQ(ret, true);
MOCKER(&aclrtMemcpyAsyncImplOrigin)
.expects(once())
.will(returnValue(ACL_ERROR_BAD_ALLOC));
ret = inst.Restore();
ASSERT_EQ(ret, false);
MemoryContext::Instance().DiscardAll();
count = 0U;
ASSERT_EQ(inst.memSectionMap_.size(), count);
GlobalMockObject::verify();
}
TEST(MemoryContext, restore_failed_due_to_invalid_addr)
{
auto &inst = MemoryContext::Instance();
MemoryContext::MemSection section = {nullptr,
reinterpret_cast<void *>(MEM_ADDR + MEM_SIZE), MEM_SIZE};
bool ret = false;
ret = inst.RestoreFromSnapshot(section);
ASSERT_EQ(ret, false);
section.originAddr = reinterpret_cast<void *>(MEM_ADDR);
section.snapshotAddr = nullptr;
ret = inst.RestoreFromSnapshot(section);
ASSERT_EQ(ret, false);
}