/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
#include <algorithm>
#include "shmemi_logger.h"
#include "dl_api.h"
#include "dl_acl_api.h"
#include "hybm_device_mem_segment.h"
#include "hybm_ex_info_transfer.h"
#include "mem_entity_default.h"
#include "mem_entity_inter.h"

namespace shm {

// Per-thread flag set by SetThreadAclDevice() once AclrtSetDevice has succeeded on the
// calling thread; thread_local so every thread performs its own device binding once.
thread_local bool MemEntityDefault::isSetDevice_ = false;

// Construct an entity with the given id; it stays uninitialized until Initialize() succeeds.
// The id is range-checked in Initialize(), not here.
MemEntityDefault::MemEntityDefault(int id) noexcept : id_(id), initialized(false) {}

/**
 * Destructor: releases the transport manager and the segments through
 * ReleaseResources(), which does nothing if the entity is not initialized.
 */
MemEntityDefault::~MemEntityDefault()
{
    SHM_LOG_INFO("Deconstruct MemEntity begin, try to release resource.");
    ReleaseResources();
}

int32_t MemEntityDefault::Initialize(const hybm_options *options) noexcept
{
    std::lock_guard<std::mutex> lock(mutex_);
    if (initialized) {
        SHM_LOG_WARN("The MemEntity has already been initialized, no action needs.");
        return ACLSHMEM_SUCCESS;
    }
    SHM_VALIDATE_RETURN((id_ >= 0 && (uint32_t)(id_) < HYBM_ENTITY_NUM_MAX),
                       "input entity id is invalid, input: " << id_ << " must be less than: " << HYBM_ENTITY_NUM_MAX,
                       ACLSHMEM_INVALID_PARAM);

    SHM_LOG_ERROR_RETURN_IT_IF_NOT_OK(CheckOptions(options), "check options failed.");

    options_ = *options;

    SHM_LOG_ERROR_RETURN_IT_IF_NOT_OK(LoadExtendLibrary(), "LoadExtendLibrary failed.");
    SHM_LOG_ERROR_RETURN_IT_IF_NOT_OK(InitSegment(), "InitSegment failed.");

    auto ret = InitTransManager();
    if (ret != ACLSHMEM_SUCCESS) {
        SHM_LOG_ERROR("init transport manager failed");
        return ret;
    }

    initialized = true;
    return ACLSHMEM_SUCCESS;
}

/**
 * Bind the calling thread to the device reported by HybmGetInitDeviceId().
 * A successful binding is remembered in the thread_local isSetDevice_, so the
 * ACL call is made at most once per thread; a failure is not remembered and
 * the next call retries.
 *
 * @return ACLSHMEM_SUCCESS, or the error returned by AclrtSetDevice.
 */
int32_t MemEntityDefault::SetThreadAclDevice()
{
    if (!isSetDevice_) {
        auto setResult = DlAclApi::AclrtSetDevice(HybmGetInitDeviceId());
        if (setResult != ACLSHMEM_SUCCESS) {
            SHM_LOG_ERROR("Set device id to be " << HybmGetInitDeviceId() << " failed: " << setResult);
            return setResult;
        }
        isSetDevice_ = true;
        SHM_LOG_DEBUG("Set device id to be " << HybmGetInitDeviceId() << " success.");
    }
    return ACLSHMEM_SUCCESS;
}

/**
 * Tear the entity down: closes the transport manager and drops the segments
 * through ReleaseResources(); does nothing if the entity is not initialized.
 */
void MemEntityDefault::UnInitialize() noexcept
{
    SHM_LOG_INFO("MemEntity UnInitialize begin, try to release resource.");
    ReleaseResources();
}

int32_t MemEntityDefault::ReserveMemorySpace(void **reservedMem) noexcept
{
    std::lock_guard<std::mutex> lock(mutex_);
    if (!initialized) {
        SHM_LOG_INFO("the object is not initialized, please check whether Initialize is called.");
        return ACLSHMEM_NOT_INITED;
    }
    if (hbmSegment_ != nullptr) {
        auto ret = hbmSegment_->ReserveMemorySpace(&hbmGva_);
        if (ret != ACLSHMEM_SUCCESS) {
            SHM_LOG_ERROR("Failed to reserver HBM memory space ret: " << ret);
            return ret;
        }
        *reservedMem = hbmGva_;
    }
    if (dramSegment_ != nullptr) {
        auto ret = dramSegment_->ReserveMemorySpace(&dramGva_);
        if (ret != ACLSHMEM_SUCCESS) {
            UnReserveMemorySpace();
            SHM_LOG_ERROR("Failed to reserver DRAM memory space ret: " << ret);
            return ret;
        }
        *reservedMem = dramGva_;
    }
    return ACLSHMEM_SUCCESS;
}

/**
 * Release the address ranges reserved by ReserveMemorySpace().
 * The transport manager, if present, is closed and dropped first; then each
 * existing segment unreserves its range and the cached base pointer is cleared.
 *
 * @return ACLSHMEM_SUCCESS, or ACLSHMEM_NOT_INITED if Initialize() has not succeeded.
 */
int32_t MemEntityDefault::UnReserveMemorySpace() noexcept
{
    std::lock_guard<std::mutex> guard(mutex_);
    if (initialized) {
        if (transportManager_ != nullptr) {
            transportManager_->CloseDevice();
            transportManager_ = nullptr;
        }
        if (hbmSegment_) {
            hbmSegment_->UnReserveMemorySpace();
            hbmGva_ = nullptr;
        }
        if (dramSegment_) {
            dramSegment_->UnReserveMemorySpace();
            dramGva_ = nullptr;
        }
        return ACLSHMEM_SUCCESS;
    }

    SHM_LOG_INFO("the object is not initialized, please check whether Initialize is called.");
    return ACLSHMEM_NOT_INITED;
}

int32_t MemEntityDefault::AllocLocalMemory(uint64_t size, hybm_mem_type mType, uint32_t flags, hybm_mem_slice_t &slice) noexcept
{
    std::lock_guard<std::mutex> lock(mutex_);
    if (!initialized) {
        SHM_LOG_INFO("the object is not initialized, please check whether Initialize is called.");
        return ACLSHMEM_NOT_INITED;
    }

    if ((size % DEVICE_LARGE_PAGE_SIZE) != 0) {
        SHM_LOG_ERROR("allocate memory size: " << size << " invalid, page size is: " << DEVICE_LARGE_PAGE_SIZE);
        return ACLSHMEM_INVALID_PARAM;
    }

    auto segment = mType == HYBM_MEM_TYPE_DEVICE ? hbmSegment_ : dramSegment_;
    if (segment == nullptr) {
        SHM_LOG_ERROR("allocate memory with mType: " << mType << ", no segment.");
        return ACLSHMEM_INVALID_PARAM;
    }

    std::shared_ptr<MemSlice> realSlice;
    auto ret = segment->AllocLocalMemory(size, realSlice);
    if (ret != 0) {
        SHM_LOG_ERROR("segment allocate slice with size: " << size << " failed: " << ret);
        return ret;
    }

    slice = realSlice->ConvertToId();
    transport::TransportMemoryRegion info;
    info.size = realSlice->size_;
    info.addr = realSlice->vAddress_;
    info.access = transport::REG_MR_ACCESS_FLAG_BOTH_READ_WRITE;
    info.flags = 
        segment->GetMemoryType() == HYBM_MEM_TYPE_DEVICE ? transport::REG_MR_FLAG_HBM : transport::REG_MR_FLAG_DRAM;
    if (transportManager_ != nullptr) {
        ret = transportManager_->RegisterMemoryRegion(info);
        if (ret != 0) {
            SHM_LOG_ERROR("register memory region allocate failed: " << ret << ", info: " << info);
            auto res = segment->ReleaseSliceMemory(realSlice);
            if (res != ACLSHMEM_SUCCESS) {
                SHM_LOG_ERROR("failed to release slice memory: " << res);
            }
            return ret;
        }
    }

    return UpdateHybmDeviceInfo(0);
}

int32_t MemEntityDefault::RegisterLocalMemory(const void *ptr, uint64_t size, uint32_t flags,
                                              hybm_mem_slice_t &slice) noexcept
{
    if (ptr == nullptr || size == 0) {
        SHM_LOG_ERROR("input ptr or size(" << size << ") is invalid");
        return ACLSHMEM_INVALID_PARAM;
    }

    if ((size % DEVICE_LARGE_PAGE_SIZE) != 0) {
        uint64_t originalSize = size;
        size = ((size + DEVICE_LARGE_PAGE_SIZE - 1) / DEVICE_LARGE_PAGE_SIZE) * DEVICE_LARGE_PAGE_SIZE;
        SHM_LOG_INFO("size: " << originalSize << " not aligned to large page (" << DEVICE_LARGE_PAGE_SIZE <<
                    "), rounded to: " << size);
    }

    auto addr = static_cast<uint64_t>(reinterpret_cast<ptrdiff_t>(ptr));
    bool isHbm = (addr >= HYBM_DEVICE_VA_START && addr < HYBM_DEVICE_VA_END);
    SHM_LOG_INFO("Hbm: " << isHbm << std::hex << ", addrs: 0x" << addr
                        << ", start: 0x" << HYBM_DEVICE_VA_START << ", end: 0x" << HYBM_DEVICE_VA_END);
    std::shared_ptr<MemSegment> segment = nullptr;
    if (dramSegment_ == nullptr) {
        segment = hbmSegment_;
    } else {
        segment = dramSegment_;
    }
    SHM_VALIDATE_RETURN(segment != nullptr, "address for segment is null.", ACLSHMEM_INVALID_PARAM);

    std::shared_ptr<MemSlice> realSlice;
    auto ret = segment->RegisterMemory(ptr, size, realSlice);
    if (ret != 0) {
        SHM_LOG_ERROR("segment register slice with size: " << size << " failed: " << ret);
        return ret;
    }

    if (transportManager_ != nullptr) {
        transport::TransportMemoryRegion mr;
        mr.addr = (uint64_t)(ptrdiff_t)ptr;
        mr.size = size;
        mr.flags = (isHbm ? transport::REG_MR_FLAG_HBM : transport::REG_MR_FLAG_DRAM);
        ret = transportManager_->RegisterMemoryRegion(mr);
        if (ret != 0) {
            SHM_LOG_ERROR("register MR: " << mr << " to transport failed: " << ret);
            return ret;
        }
    }

    slice = realSlice->ConvertToId();
    return ACLSHMEM_SUCCESS;
}

/**
 * Free a slice previously obtained from this entity.
 *
 * The slice is looked up in the HBM segment first, then in the DRAM segment.
 * Its memory region is deregistered from the transport BEFORE the backing memory
 * is released, so the transport never keeps a registration over freed memory.
 * Transport and release failures are logged but do not change the result.
 *
 * @param slice slice id to free; an unknown id is logged and ignored.
 * @param flags currently unused.
 * @return ACLSHMEM_SUCCESS, or ACLSHMEM_INVALID_PARAM if the entity is not initialized.
 */
int32_t MemEntityDefault::FreeLocalMemory(hybm_mem_slice_t slice, uint32_t flags) noexcept
{
    (void)flags;
    if (!initialized) {
        SHM_LOG_INFO("the object is not initialized, please check whether Initialize is called.");
        return ACLSHMEM_INVALID_PARAM;
    }

    std::shared_ptr<MemSlice> memSlice;
    std::shared_ptr<MemSegment> owner;
    if (hbmSegment_ != nullptr && (memSlice = hbmSegment_->GetMemSlice(slice)) != nullptr) {
        owner = hbmSegment_;
    } else if (dramSegment_ != nullptr && (memSlice = dramSegment_->GetMemSlice(slice)) != nullptr) {
        owner = dramSegment_;
    }

    if (memSlice == nullptr || owner == nullptr) {
        SHM_LOG_WARN("free local memory: input slice not found in any segment.");
        return ACLSHMEM_SUCCESS;
    }

    if (transportManager_ != nullptr) {
        auto ret = transportManager_->UnregisterMemoryRegion(memSlice->vAddress_);
        if (ret != ACLSHMEM_SUCCESS) {
            SHM_LOG_ERROR("UnregisterMemoryRegion failed, please check input slice.");
        }
    }

    auto res = owner->ReleaseSliceMemory(memSlice);
    if (res != ACLSHMEM_SUCCESS) {
        SHM_LOG_ERROR("failed to release slice memory: " << res);
    }
    SHM_LOG_DEBUG("free local memory successed.");
    return ACLSHMEM_SUCCESS;
}

/**
 * Export this entity's identity (version, rank id, role, NIC) for peers.
 *
 * The info is serialized only when a transport manager exists; otherwise an empty
 * payload is appended to `desc`.
 *
 * @param desc  writer that receives the serialized EntityExportInfo.
 * @param flags currently unused.
 * @return ACLSHMEM_SUCCESS, ACLSHMEM_NOT_INITED, ACLSHMEM_INNER_ERROR if the NIC name
 *         does not fit or serialization fails, or the error from desc.Append().
 */
int32_t MemEntityDefault::ExportExchangeInfo(ExchangeInfoWriter &desc, uint32_t flags) noexcept
{
    (void)flags;
    if (!initialized) {
        SHM_LOG_INFO("the object is not initialized, please check whether Initialize is called.");
        return ACLSHMEM_NOT_INITED;
    }

    std::string info;
    EntityExportInfo exportInfo;
    exportInfo.version = EXPORT_INFO_VERSION;
    exportInfo.rankId = options_.rankId;
    exportInfo.role = static_cast<uint16_t>(options_.role);
    if (transportManager_ != nullptr) {
        const auto &nic = transportManager_->GetNic();
        if (nic.size() >= sizeof(exportInfo.nic)) {
            SHM_LOG_ERROR("transport get nic(" << nic << ") too long.");
            return ACLSHMEM_INNER_ERROR;
        }
        // Copy the terminating '\0' from c_str() as well: the importer reads nic as a
        // C string, and exportInfo is not zero-initialized. The check above
        // guarantees nic.size() + 1 <= sizeof(exportInfo.nic).
        std::copy_n(nic.c_str(), nic.size() + 1, exportInfo.nic);
        auto ret = LiteralExInfoTranslater<EntityExportInfo>{}.Serialize(exportInfo, info);
        if (ret != ACLSHMEM_SUCCESS) {
            SHM_LOG_ERROR("export info failed: " << ret);
            return ACLSHMEM_INNER_ERROR;
        }
    }

    auto ret = desc.Append(info.data(), info.size());
    if (ret != 0) {
        SHM_LOG_ERROR("export to string wrong size: " << info.size());
        return ret;
    }

    return ACLSHMEM_SUCCESS;
}

/**
 * Export exchange info for a slice, or for the entity itself when `slice` is null.
 *
 * @return ACLSHMEM_NOT_INITED before Initialize() succeeds; otherwise the result of
 *         ExportWithoutSlice() (null slice) or ExportWithSlice().
 */
int32_t MemEntityDefault::ExportExchangeInfo(hybm_mem_slice_t slice, ExchangeInfoWriter &desc, uint32_t flags) noexcept
{
    if (initialized) {
        return (slice == nullptr) ? ExportWithoutSlice(desc, flags) : ExportWithSlice(slice, desc, flags);
    }

    SHM_LOG_INFO("the object is not initialized, please check whether Initialize is called.");
    return ACLSHMEM_NOT_INITED;
}

int32_t MemEntityDefault::ImportExchangeInfo(const ExchangeInfoReader desc[], uint32_t count, void *addresses[],
                                             uint32_t flags) noexcept
{
    if (!initialized) {
        SHM_LOG_ERROR("the object is not initialized, please check whether Initialize is called.");
        return ACLSHMEM_NOT_INITED;
    }

    auto ret = SetThreadAclDevice();
    if (ret != ACLSHMEM_SUCCESS) {
        return ACLSHMEM_INNER_ERROR;
    }

    if (desc == nullptr) {
        SHM_LOG_ERROR("the input desc is nullptr.");
        return ACLSHMEM_INNER_ERROR;
    }

    std::unordered_map<uint32_t, std::vector<transport::TransportMemoryKey>> tempKeyMap;
    ret = ImportForTransport(desc, count, flags);
    if (ret != ACLSHMEM_SUCCESS) {
        return ret;
    }

    if (desc[0].LeftBytes() == 0) {
        SHM_LOG_INFO("no segment need import.");
        return ACLSHMEM_SUCCESS;
    }

    uint64_t magic;
    if (desc[0].Test(magic) < 0) {
        SHM_LOG_ERROR("left import data no magic size.");
        return ACLSHMEM_SUCCESS;
    }

    auto currentSegment = magic == DRAM_SLICE_EXPORT_INFO_MAGIC ? dramSegment_ : hbmSegment_;
    std::vector<std::string> infos;
    for (auto i = 0U; i < count; i++) {
        infos.emplace_back(desc[i].LeftToString());
    }

    ret = currentSegment->Import(infos, addresses);
    if (ret != ACLSHMEM_SUCCESS) {
        SHM_LOG_ERROR("segment import infos failed: " << ret);
        return ret;
    }

    return ACLSHMEM_SUCCESS;
}

/**
 * Report how many bytes one exported slice occupies: the segment's own export
 * size (HBM segment preferred, DRAM otherwise) plus a SliceExportTransportKey
 * when a transport manager exists.
 *
 * @param size receives the byte count on success; untouched on failure.
 * @return ACLSHMEM_SUCCESS, ACLSHMEM_INNER_ERROR when there is no segment, or the
 *         error from the segment.
 */
int32_t MemEntityDefault::GetExportSliceInfoSize(size_t &size) noexcept
{
    const auto &segment = (hbmSegment_ != nullptr) ? hbmSegment_ : dramSegment_;
    if (segment == nullptr) {
        SHM_LOG_ERROR("segment is null.");
        return ACLSHMEM_INNER_ERROR;
    }

    size_t sliceBytes = 0;
    auto status = segment->GetExportSliceSize(sliceBytes);
    if (status != 0) {
        SHM_LOG_ERROR("GetExportSliceSize for segment failed: " << status);
        return status;
    }

    const size_t transportBytes = (transportManager_ != nullptr) ? sizeof(SliceExportTransportKey) : 0U;
    size = sliceBytes + transportBytes;
    return ACLSHMEM_SUCCESS;
}

/**
 * Copy a user context of `size` bytes from host memory into this entity's slot in
 * the device user-context area, then record the size in the device metadata.
 *
 * The slot for entity id_ starts at
 * HYBM_DEVICE_USER_CONTEXT_ADDR + id_ * HYBM_DEVICE_USER_CONTEXT_PRE_SIZE.
 *
 * @param context host buffer; must be non-null.
 * @param size    byte count, at most HYBM_DEVICE_USER_CONTEXT_PRE_SIZE.
 * @return ACLSHMEM_SUCCESS, ACLSHMEM_NOT_INITED, ACLSHMEM_INVALID_PARAM,
 *         ACLSHMEM_INNER_ERROR on copy failure, or the result of UpdateHybmDeviceInfo().
 */
int32_t MemEntityDefault::SetExtraContext(const void *context, uint32_t size) noexcept
{
    if (!initialized) {
        SHM_LOG_INFO("the object is not initialized, please check whether Initialize is called.");
        return ACLSHMEM_NOT_INITED;
    }

    SHM_ASSERT_RETURN(context != nullptr, ACLSHMEM_INVALID_PARAM);
    const bool tooLarge = size > HYBM_DEVICE_USER_CONTEXT_PRE_SIZE;
    if (tooLarge) {
        SHM_LOG_ERROR("set extra context failed, context size is too large: " << size << " limit: "
                                                                             << HYBM_DEVICE_USER_CONTEXT_PRE_SIZE);
        return ACLSHMEM_INVALID_PARAM;
    }

    uint64_t slotAddr = HYBM_DEVICE_USER_CONTEXT_ADDR + id_ * HYBM_DEVICE_USER_CONTEXT_PRE_SIZE;
    SHM_LOG_DEBUG("set extra context to addr: 0x" << std::hex << slotAddr << " size: " << size);
    auto copyResult = DlAclApi::AclrtMemcpy(reinterpret_cast<void *>(slotAddr), HYBM_DEVICE_USER_CONTEXT_PRE_SIZE,
                                            context, size, ACL_MEMCPY_HOST_TO_DEVICE);
    if (copyResult != ACLSHMEM_SUCCESS) {
        SHM_LOG_ERROR("memcpy user context failed, ret: " << copyResult);
        return ACLSHMEM_INNER_ERROR;
    }

    return UpdateHybmDeviceInfo(size);
}

/**
 * Ask every existing segment (HBM first, then DRAM) to unmap its mappings.
 * Only logs when the entity is not initialized.
 */
void MemEntityDefault::Unmap() noexcept
{
    if (initialized) {
        if (hbmSegment_) {
            hbmSegment_->Unmap();
        }
        if (dramSegment_) {
            dramSegment_->Unmap();
        }
        return;
    }

    SHM_LOG_INFO("the object is not initialized, please check whether Initialize is called.");
}

int32_t MemEntityDefault::Mmap() noexcept
{
    if (!initialized) {
        SHM_LOG_INFO("the object is not initialized, please check whether Initialize is called.");
        return ACLSHMEM_NOT_INITED;
    }

    if (hbmSegment_ != nullptr) {
        auto ret = hbmSegment_->Mmap();
        if (ret != ACLSHMEM_SUCCESS) {
            return ret;
        }
    }

    if (dramSegment_ != nullptr) {
        auto ret = dramSegment_->Mmap();
        if (ret != ACLSHMEM_SUCCESS) {
            return ret;
        }
    }
    return ACLSHMEM_SUCCESS;
}

int32_t MemEntityDefault::RemoveImported(const std::vector<uint32_t> &ranks) noexcept
{
    if (!initialized) {
        SHM_LOG_INFO("the object is not initialized, please check whether Initialize is called.");
        return ACLSHMEM_NOT_INITED;
    }

    if (hbmSegment_ != nullptr) {
        auto ret = hbmSegment_->RemoveImported(ranks);
        if (ret != ACLSHMEM_SUCCESS) {
            return ret;
        }
    }

    if (dramSegment_ != nullptr) {
        auto ret = dramSegment_->RemoveImported(ranks);
        if (ret != ACLSHMEM_SUCCESS) {
            return ret;
        }
    }

    return ACLSHMEM_SUCCESS;
}

/**
 * Tell whether [ptr, ptr + length) lies inside the HBM or the DRAM segment range,
 * as reported by each segment's MemoryInRange().
 *
 * @return false if the entity is not initialized or no segment contains the range.
 */
bool MemEntityDefault::CheckAddressInEntity(const void *ptr, uint64_t length) const noexcept
{
    if (!initialized) {
        SHM_LOG_ERROR("the object is not initialized, please check whether Initialize is called.");
        return false;
    }

    const bool inHbm = hbmSegment_ != nullptr && hbmSegment_->MemoryInRange(ptr, length);
    return inHbm || (dramSegment_ != nullptr && dramSegment_->MemoryInRange(ptr, length));
}

int MemEntityDefault::CheckOptions(const hybm_options *options) noexcept
{
    if (options == nullptr) {
        SHM_LOG_ERROR("initialize with nullptr.");
        return ACLSHMEM_INVALID_PARAM;
    }

    if (options->rankId >= options->rankCount) {
        SHM_LOG_ERROR("local rank id: " << options->rankId << " invalid, total is " << options->rankCount);
        return ACLSHMEM_INVALID_PARAM;
    }

    return ACLSHMEM_SUCCESS;
}

/**
 * Load the extension library for each device data operator enabled in
 * options_.bmDataOpType (RDMA, SDMA, UDMA, in that order).
 *
 * UDMA is only available when built with ACLSHMEM_UDMA_SUPPORT; otherwise
 * requesting it returns ACLSHMEM_NOT_SUPPORTED.
 *
 * @return ACLSHMEM_SUCCESS or the first loading error.
 */
int MemEntityDefault::LoadExtendLibrary() noexcept
{
    if ((options_.bmDataOpType & HYBM_DOP_TYPE_DEVICE_RDMA) != 0) {
        const auto rdmaStatus = DlApi::LoadExtendLibrary(DL_EXT_LIB_DEVICE_RDMA);
        if (rdmaStatus != 0) {
            SHM_LOG_ERROR("LoadExtendLibrary for DEVICE RDMA failed: " << rdmaStatus);
            return rdmaStatus;
        }
    }

    if ((options_.bmDataOpType & HYBM_DOP_TYPE_DEVICE_SDMA) != 0) {
        const auto sdmaStatus = DlApi::LoadExtendLibrary(DL_EXT_LIB_DEVICE_SDMA);
        if (sdmaStatus != 0) {
            SHM_LOG_ERROR("LoadExtendLibrary for DEVICE SDMA failed: " << sdmaStatus);
            return sdmaStatus;
        }
    }

    if ((options_.bmDataOpType & HYBM_DOP_TYPE_DEVICE_UDMA) == 0) {
        return ACLSHMEM_SUCCESS;
    }
#if defined(ACLSHMEM_UDMA_SUPPORT)
    const auto udmaStatus = DlApi::LoadExtendLibrary(DL_EXT_LIB_DEVICE_UDMA);
    if (udmaStatus != 0) {
        SHM_LOG_ERROR("LoadExtendLibrary for DEVICE UDMA failed: " << udmaStatus);
        return udmaStatus;
    }
    return ACLSHMEM_SUCCESS;
#else
    SHM_LOG_ERROR("DEVICE UDMA support is not enabled in this build.");
    return ACLSHMEM_NOT_SUPPORTED;
#endif
}

/**
 * Write this entity's HybmDeviceMeta into its per-entity slot of the device
 * metadata area, at HYBM_DEVICE_META_ADDR + HYBM_DEVICE_GLOBAL_META_SIZE +
 * id_ * HYBM_DEVICE_PRE_META_SIZE.
 *
 * The struct is value-initialized so that any field SetHybmDeviceInfo() does not
 * assign reaches the device as zero rather than as indeterminate stack bytes.
 *
 * @param extCtxSize size of the user extra context to record.
 * @return ACLSHMEM_SUCCESS, or ACLSHMEM_INNER_ERROR if the host-to-device copy fails.
 */
int MemEntityDefault::UpdateHybmDeviceInfo(uint32_t extCtxSize) noexcept
{
    HybmDeviceMeta info{};
    auto addr = HYBM_DEVICE_META_ADDR + HYBM_DEVICE_GLOBAL_META_SIZE + id_ * HYBM_DEVICE_PRE_META_SIZE;

    SetHybmDeviceInfo(info);
    info.extraContextSize = extCtxSize;
    auto ret = DlAclApi::AclrtMemcpy(reinterpret_cast<void *>(addr), DEVICE_LARGE_PAGE_SIZE, &info,
                                     sizeof(HybmDeviceMeta), ACL_MEMCPY_HOST_TO_DEVICE);
    if (ret != ACLSHMEM_SUCCESS) {
        SHM_LOG_ERROR("update hybm info memory failed, ret: " << ret << ", addr:" << reinterpret_cast<void *>(addr));
        return ACLSHMEM_INNER_ERROR;
    }
    SHM_LOG_DEBUG("update hybm info memory success, addr:" << reinterpret_cast<void *>(addr));
    return ACLSHMEM_SUCCESS;
}

/**
 * Fill the host-side HybmDeviceMeta fields from this entity's id and options.
 * extraContextSize is reset to 0; qpInfoAddress is the transport's QP info
 * address, or 0 when there is no transport manager.
 */
void MemEntityDefault::SetHybmDeviceInfo(HybmDeviceMeta &info)
{
    info.entityId = id_;
    info.rankId = options_.rankId;
    info.rankSize = options_.rankCount;
    info.symmetricSize = options_.deviceVASpace;
    info.extraContextSize = 0;
    info.qpInfoAddress = (transportManager_ == nullptr)
                             ? 0UL
                             : (uint64_t)(ptrdiff_t)transportManager_->GetQpInfo();
}

/**
 * Export one slice for peers.
 *
 * The slice is looked up in the HBM segment, then the DRAM segment. When a
 * transport manager exists, a SliceExportTransportKey (magic, rank, address,
 * memory key) is appended first, followed by the segment's own export data.
 *
 * @param slice slice id to export.
 * @param desc  writer that receives the exported records.
 * @param flags currently unused.
 * @return ACLSHMEM_SUCCESS, ACLSHMEM_INVALID_PARAM if the slice is unknown, or the
 *         error from the segment, the transport, or desc.Append().
 */
int32_t MemEntityDefault::ExportWithSlice(hybm_mem_slice_t slice, ExchangeInfoWriter &desc, uint32_t flags) noexcept
{
    (void)flags;
    uint64_t exportMagic = 0;
    std::string info;
    std::shared_ptr<MemSlice> realSlice;
    std::shared_ptr<MemSegment> currentSegment;
    if (hbmSegment_ != nullptr) {
        realSlice = hbmSegment_->GetMemSlice(slice);
        currentSegment = hbmSegment_;
        exportMagic = HBM_SLICE_EXPORT_INFO_MAGIC;
    }
    if (realSlice == nullptr && dramSegment_ != nullptr) {
        realSlice = dramSegment_->GetMemSlice(slice);
        currentSegment = dramSegment_;
        exportMagic = DRAM_SLICE_EXPORT_INFO_MAGIC;
    }
    if (realSlice == nullptr) {
        SHM_LOG_ERROR("cannot find input slice for export.");
        return ACLSHMEM_INVALID_PARAM;
    }

    auto ret = currentSegment->Export(realSlice, info);
    if (ret != 0) {
        SHM_LOG_ERROR("export to string failed: " << ret);
        return ret;
    }

    if (transportManager_ != nullptr) {
        SliceExportTransportKey transportKey{exportMagic, options_.rankId, realSlice->vAddress_};
        ret = transportManager_->QueryMemoryKey(realSlice->vAddress_, transportKey.key);
        if (ret != 0) {
            SHM_LOG_ERROR("query memory key when export slice failed: " << ret);
            return ret;
        }
        ret = desc.Append(transportKey);
        if (ret != 0) {
            SHM_LOG_ERROR("append transport key failed: " << ret);
            return ret;
        }
    }
    // Process-wide lock around the final Append; NOTE(review): its purpose is not
    // visible in this file (the transport-key Append above is not covered by it).
    static std::mutex debug_mutex;
    std::lock_guard<std::mutex> lock(debug_mutex);
    int status = desc.Append(info.data(), info.size());
    if (status != 0) {
        SHM_LOG_ERROR("export to string wrong size: " << info.size() << " ret: " << status);
        // Return the Append error itself; `ret` is 0 at this point.
        return status;
    }
    return ACLSHMEM_SUCCESS;
}

/**
 * Export the entity-level info via ExportExchangeInfo(desc, flags), then append an
 * empty payload to `desc`.
 *
 * @return ACLSHMEM_SUCCESS or the first error encountered.
 */
int32_t MemEntityDefault::ExportWithoutSlice(ExchangeInfoWriter &desc, uint32_t flags) noexcept
{
    int32_t status = ExportExchangeInfo(desc, flags);
    if (status != 0) {
        SHM_LOG_ERROR("ExportExchangeInfo failed: " << status);
        return status;
    }

    std::string payload;
    status = desc.Append(payload.data(), payload.size());
    if (status != 0) {
        SHM_LOG_ERROR("export to string wrong size: " << payload.size());
        return status;
    }

    return ACLSHMEM_SUCCESS;
}

int32_t MemEntityDefault::ImportForTransportPrecheck(const ExchangeInfoReader desc[],
                                                     uint32_t &count,
                                                     bool &importInfoEntity)
{
    int ret = ACLSHMEM_SUCCESS;
    uint64_t magic;
    EntityExportInfo entityExportInfo;
    SliceExportTransportKey transportKey;
    for (auto i = 0U; i < count; i++) {
        ret = desc[i].Test(magic);
        if (ret != 0) {
            SHM_LOG_ERROR("read magic from import : " << i << " failed.");
            return ACLSHMEM_INNER_ERROR;
        }

        if (magic == ENTITY_EXPORT_INFO_MAGIC) {
            ret = desc[i].Read(entityExportInfo);
            if (ret == 0) {
                importedRanks_[entityExportInfo.rankId] = entityExportInfo;
                importInfoEntity = true;
            }
        } else if (magic == DRAM_SLICE_EXPORT_INFO_MAGIC || magic == HBM_SLICE_EXPORT_INFO_MAGIC) {
            ret = desc[i].Read(transportKey);
            if (ret == 0) {
                std::unique_lock<std::mutex> uniqueLock{importMutex_};
                importedMemories_[transportKey.rankId][transportKey.address] = transportKey.key;
            }
        } else {
            SHM_LOG_ERROR("magic(" << std::hex << magic << ") invalid");
            ret = ACLSHMEM_INNER_ERROR;
        }

        if (ret != 0) {
            SHM_LOG_ERROR("read info for transport failed: " << ret);
            return ret;
        }
    }
    return ACLSHMEM_SUCCESS;
}

int32_t MemEntityDefault::ImportForTransport(const ExchangeInfoReader desc[], uint32_t count, uint32_t flags) noexcept
{
    if (transportManager_ == nullptr) {
        return ACLSHMEM_SUCCESS;
    }

    int ret = ACLSHMEM_SUCCESS;
    bool importInfoEntity = false;
    ret = ImportForTransportPrecheck(desc, count, importInfoEntity);
    if (ret != ACLSHMEM_SUCCESS) {
        return ret;
    }

    transport::HybmTransPrepareOptions transOptions;
    std::unique_lock<std::mutex> uniqueLock{importMutex_};
    for (auto &rank : importedRanks_) {
        if (options_.role != HYBM_ROLE_PEER && static_cast<hybm_role_type>(rank.second.role) == options_.role) {
            continue;
        }

        transOptions.options[rank.first].role = static_cast<hybm_role_type>(rank.second.role);
        transOptions.options[rank.first].nic = rank.second.nic;
    }
    for (auto &mr : importedMemories_) {
        auto pos = transOptions.options.find(mr.first);
        if (pos != transOptions.options.end()) {
            for (auto &key : mr.second) {
                pos->second.memKeys.emplace_back(key.second);
            }
        }
    }
    uniqueLock.unlock();

    if (options_.role != HYBM_ROLE_PEER || importInfoEntity) {
        ret = transportManager_->ConnectWithOptions(transOptions);
        if (ret != 0) {
            SHM_LOG_ERROR("Transport Manager ConnectWithOptions failed: " << ret);
            return ret;
        }
        if (importInfoEntity) {
            return UpdateHybmDeviceInfo(0);
        }
    }

    return ACLSHMEM_SUCCESS;
}

/**
 * Create the segments requested by options_.memType: HBM for HYBM_MEM_TYPE_DEVICE,
 * then DRAM for HYBM_MEM_TYPE_HOST.
 *
 * @return ACLSHMEM_SUCCESS or the first creation error.
 */
Result MemEntityDefault::InitSegment()
{
    SHM_LOG_DEBUG("Initialize segment with type: " << std::hex << options_.memType);
    const bool wantDevice = (options_.memType & HYBM_MEM_TYPE_DEVICE) != 0;
    const bool wantHost = (options_.memType & HYBM_MEM_TYPE_HOST) != 0;

    if (wantDevice) {
        auto hbmResult = InitHbmSegment();
        if (hbmResult != ACLSHMEM_SUCCESS) {
            SHM_LOG_ERROR("InitHbmSegment() failed: " << hbmResult);
            return hbmResult;
        }
    }

    if (!wantHost) {
        return ACLSHMEM_SUCCESS;
    }
    auto dramResult = InitDramSegment();
    if (dramResult != ACLSHMEM_SUCCESS) {
        SHM_LOG_ERROR("InitDramSegment() failed: " << dramResult);
        return dramResult;
    }
    return ACLSHMEM_SUCCESS;
}

/**
 * Create the HBM segment sized by options_.deviceVASpace. A zero size means no
 * HBM segment is wanted; that is not an error.
 *
 * @return ACLSHMEM_SUCCESS, or ACLSHMEM_INVALID_PARAM if segment creation fails.
 */
Result MemEntityDefault::InitHbmSegment()
{
    if (options_.deviceVASpace == 0) {
        SHM_LOG_INFO("Hbm rank space is zero.");
        return ACLSHMEM_SUCCESS;
    }

    MemSegmentOptions hbmOptions;
    hbmOptions.size = options_.deviceVASpace;
    hbmOptions.segType = HYBM_MST_HBM;
    hbmOptions.devId = HybmGetInitDeviceId();
    hbmOptions.role = options_.role;
    hbmOptions.dataOpType = options_.bmDataOpType;
    hbmOptions.rankId = options_.rankId;
    hbmOptions.rankCnt = options_.rankCount;

    hbmSegment_ = MemSegment::Create(hbmOptions, id_);
    SHM_VALIDATE_RETURN(hbmSegment_ != nullptr, "create segment failed", ACLSHMEM_INVALID_PARAM);
    return ACLSHMEM_SUCCESS;
}

/**
 * Create the DRAM segment sized by options_.hostVASpace. A zero size means no
 * DRAM segment is wanted. When device RDMA is enabled the segment is created
 * with shared = false.
 *
 * @return ACLSHMEM_SUCCESS, or ACLSHMEM_INNER_ERROR if segment creation fails.
 */
Result MemEntityDefault::InitDramSegment()
{
    if (options_.hostVASpace == 0) {
        SHM_LOG_INFO("Dram rank space is zero.");
        return ACLSHMEM_SUCCESS;
    }

    MemSegmentOptions dramOptions;
    dramOptions.size = options_.hostVASpace;
    dramOptions.devId = HybmGetInitDeviceId();
    dramOptions.segType = HYBM_MST_DRAM;
    dramOptions.rankId = options_.rankId;
    dramOptions.rankCnt = options_.rankCount;
    dramOptions.dataOpType = options_.bmDataOpType;
    dramOptions.flags = options_.flags;
    const bool rdmaEnabled = (options_.bmDataOpType & HYBM_DOP_TYPE_DEVICE_RDMA) != 0;
    if (rdmaEnabled) {
        dramOptions.shared = false;
    }

    dramSegment_ = MemSegment::Create(dramOptions, id_);
    if (dramSegment_ != nullptr) {
        return ACLSHMEM_SUCCESS;
    }
    SHM_LOG_ERROR("Failed to create dram segment");
    return ACLSHMEM_INNER_ERROR;
}

/**
 * Create and open the transport manager when more than one rank exists and a
 * device data operator is enabled.
 *
 * The transport type is chosen by priority: RDMA (TT_HCCP), then SDMA (TT_SDMA),
 * then UDMA (TT_UDMA, only when built with ACLSHMEM_UDMA_SUPPORT).
 * On any failure transportManager_ is left null.
 *
 * @return ACLSHMEM_SUCCESS, ACLSHMEM_NOT_SUPPORTED for UDMA without build support,
 *         ACLSHMEM_INNER_ERROR if the manager cannot be created, or the OpenDevice error.
 */
Result MemEntityDefault::InitTransManager()
{
    if (options_.rankCount <= 1) {
        SHM_LOG_INFO("rank total count : " << options_.rankCount << ", no transport.");
        return ACLSHMEM_SUCCESS;
    }
    if (((options_.bmDataOpType & HYBM_DOP_TYPE_DEVICE_RDMA) == 0) &&
        ((options_.bmDataOpType & HYBM_DOP_TYPE_DEVICE_SDMA) == 0) &&
        ((options_.bmDataOpType & HYBM_DOP_TYPE_DEVICE_UDMA) == 0)) {
        SHM_LOG_DEBUG("Data operator type mathchs none of RDMA, SDMA or UDMA, skip transport init.");
        return ACLSHMEM_SUCCESS;
    }

    if (options_.bmDataOpType & HYBM_DOP_TYPE_DEVICE_RDMA) {
        transportManager_ = transport::TransportManager::Create(transport::TransportType::TT_HCCP);
    } else if (options_.bmDataOpType & HYBM_DOP_TYPE_DEVICE_SDMA) {
        transportManager_ = transport::TransportManager::Create(transport::TransportType::TT_SDMA);
    } else if (options_.bmDataOpType & HYBM_DOP_TYPE_DEVICE_UDMA) {
#if defined(ACLSHMEM_UDMA_SUPPORT)
        transportManager_ = transport::TransportManager::Create(transport::TransportType::TT_UDMA);
#else
        SHM_LOG_ERROR("DEVICE UDMA support is not enabled in this build.");
        return ACLSHMEM_NOT_SUPPORTED;
#endif
    }

    // Create() may fail; never dereference a null manager below.
    if (transportManager_ == nullptr) {
        SHM_LOG_ERROR("Failed to create transport manager.");
        return ACLSHMEM_INNER_ERROR;
    }

    transport::TransportOptions options;
    options.rankId = options_.rankId;
    options.rankCount = options_.rankCount;
    options.protocol = options_.bmDataOpType;
    options.role = options_.role;
    options.nic = options_.nic;
    auto ret = transportManager_->OpenDevice(options);
    if (ret != 0) {
        SHM_LOG_ERROR("Failed to open device, ret: " << ret);
        transportManager_ = nullptr;
    }
    return ret;
}

/**
 * Ask a segment whether `remoteRank` is reachable via SDMA. The HBM segment is
 * consulted if it exists, otherwise the DRAM segment; without any segment the
 * answer is false.
 */
bool MemEntityDefault::SdmaReaches(uint32_t remoteRank) const noexcept
{
    const auto &segment = (hbmSegment_ != nullptr) ? hbmSegment_ : dramSegment_;
    return segment != nullptr && segment->CheckSdmaReaches(remoteRank);
}

/**
 * Compute the data operators usable towards `remoteRank`.
 *
 * If SDMA reaches the rank, MTE is always usable, and SDMA/UDMA are usable when
 * enabled in options_.bmDataOpType. RDMA is usable whenever it is enabled,
 * independent of SDMA reachability.
 *
 * @return bitmask of HYBM_DOP_TYPE_* flags.
 */
hybm_data_op_type MemEntityDefault::CanReachDataOperators(uint32_t remoteRank) const noexcept
{
    const bool sdmaReach = SdmaReaches(remoteRank);
    const bool sdmaEnabled = (options_.bmDataOpType & HYBM_DOP_TYPE_DEVICE_SDMA) != 0;
    const bool rdmaEnabled = (options_.bmDataOpType & HYBM_DOP_TYPE_DEVICE_RDMA) != 0;
    const bool udmaEnabled = (options_.bmDataOpType & HYBM_DOP_TYPE_DEVICE_UDMA) != 0;

    uint32_t reachable = 0U;
    if (sdmaReach) {
        reachable |= HYBM_DOP_TYPE_MTE;  // SDMA reachability implies MTE reachability
        if (sdmaEnabled) {
            reachable |= HYBM_DOP_TYPE_DEVICE_SDMA;
        }
        if (udmaEnabled) {
            reachable |= HYBM_DOP_TYPE_DEVICE_UDMA;
        }
    }
    if (rdmaEnabled) {
        reachable |= HYBM_DOP_TYPE_DEVICE_RDMA;
    }

    return static_cast<hybm_data_op_type>(reachable);
}

/**
 * Return the base of the reserved range for `memType`: the HBM base for
 * HYBM_MEM_TYPE_DEVICE, the DRAM base for HYBM_MEM_TYPE_HOST, nullptr otherwise.
 */
void *MemEntityDefault::GetReservedMemoryPtr(hybm_mem_type memType) noexcept
{
    void *base = nullptr;
    if (memType == HYBM_MEM_TYPE_DEVICE) {
        base = hbmGva_;
    } else if (memType == HYBM_MEM_TYPE_HOST) {
        base = dramGva_;
    }
    return base;
}

void MemEntityDefault::ReleaseResources()
{
    std::lock_guard<std::mutex> lock(mutex_);
    if (!initialized) {
        return;
    }
    if (transportManager_) {
        transportManager_->CloseDevice();
        transportManager_ = nullptr;
    }
    hbmSegment_.reset();
    dramSegment_.reset();
    initialized = false;
}

}  // namespace shm