/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
#include "dl_comm_def.h"
#include "dl_api.h"
#include "dl_acl_api.h"
#include "shmemi_net_util.h"
#include "hybm_device_mem_segment.h"
#include "host/shmem_host_def.h"
#include "hybm_device_mem_segment.h"
#include "hybm_vmm_based_segment.h"

namespace shm {
bool MemSegment::deviceInfoReady{false};
int MemSegment::deviceId_{-1};
int MemSegment::logicDeviceId_{-1};
uint32_t MemSegment::pid_{0};
uint32_t MemSegment::sdid_{0};
uint32_t MemSegment::serverId_{0};
uint32_t MemSegment::superPodId_{0};
AscendSocType MemSegment::socType_{AscendSocType::ASCEND_UNKNOWN};
std::string MemSegment::sysBoolId_{};
uint32_t MemSegment::bootIdHead_{0};

MemSegmentPtr MemSegment::Create(const MemSegmentOptions &options, int entityId)
{
    if (options.rankId >= options.rankCnt) {
        SHM_LOG_ERROR("rank(" << options.rankId << ") but total " << options.rankCnt);
        return nullptr;
    }

    auto ret = MemSegment::InitDeviceInfo();
    if (ret != ACLSHMEM_SUCCESS) {
        SHM_LOG_ERROR("MemSegment::InitDeviceInfo failed: " << ret);
        return nullptr;
    }

    MemSegmentPtr tmpSeg;
    switch (options.segType) {
#ifdef HAS_ACLRT_MEM_FABRIC_HANDLE
        case HYBM_MST_HBM:
        case HYBM_MST_DRAM:
            tmpSeg = std::make_shared<MemSegmentDevice>(options, entityId);
            break;
#else
        case HYBM_MST_HBM:
            if (socType_ == AscendSocType::ASCEND_950 || (HybmGetGvaVersion() == HYBM_GVA_V4)) {
                if (!CannVersionCheck("8.5")) {
                    SHM_LOG_ERROR("CANN version must be >= 8.5 for HybmVmmBasedSegment");
                    break;
                }
                tmpSeg = std::make_shared<HybmVmmBasedSegment>(options, entityId);
            } else {
                tmpSeg = std::make_shared<MemSegmentDevice>(options, entityId);
            }
            break;
        case HYBM_MST_DRAM:
            SHM_LOG_ERROR("Not support HOST_SIDE malloc now.");
            break;
#endif
        default:
            SHM_LOG_ERROR("Invalid memory seg type " << int(options.segType));
    }
    return tmpSeg;
}

bool MemSegment::CheckSdmaReaches(uint32_t rankId) const noexcept
{
    return false;
}

Result MemSegment::InitDeviceInfo()
{
    if (deviceInfoReady) {
        return ACLSHMEM_SUCCESS;
    }

    auto ret = DlAclApi::AclrtGetDevice(&deviceId_);
    if (ret != 0) {
        SHM_LOG_ERROR("get device id failed: " << ret);
        return ACLSHMEM_DL_FUNC_FAILED;
    }

    ret = DlAclApi::RtGetLogicDevIdByUserDevId(deviceId_, &logicDeviceId_);
    if (ret != 0 || logicDeviceId_ < 0) {
        SHM_LOG_ERROR("Failed to get logic deviceId: " << deviceId_ << ", ret=" << ret);
        return ACLSHMEM_INNER_ERROR;
    }

    ret = DlAclApi::RtDeviceGetBareTgid(&pid_);
    if (ret != ACLSHMEM_SUCCESS) {
        SHM_LOG_ERROR("get bare tgid failed: " << ret);
        return ACLSHMEM_DL_FUNC_FAILED;
    }

    int64_t value = 0;
    ret = DlAclApi::RtGetDeviceInfo(deviceId_, 0, INFO_TYPE_SDID, &value);
    if (ret != ACLSHMEM_SUCCESS) {
        SHM_LOG_ERROR("get sdid failed: " << ret);
        return ACLSHMEM_DL_FUNC_FAILED;
    }

    sdid_ = static_cast<uint32_t>(value);
    ret = DlAclApi::RtGetDeviceInfo(deviceId_, 0, INFO_TYPE_SERVER_ID, &value);
    if (ret != ACLSHMEM_SUCCESS) {
        SHM_LOG_ERROR("get server id failed: " << ret);
        return ACLSHMEM_DL_FUNC_FAILED;
    }
    serverId_ = static_cast<uint32_t>(value);
    SHM_LOG_DEBUG("local server=0x" << std::hex << serverId_);

    ret = DlAclApi::RtGetDeviceInfo(deviceId_, 0, INFO_TYPE_SUPER_POD_ID, &value);
    if (ret != ACLSHMEM_SUCCESS) {
        SHM_LOG_ERROR("get super pod id failed: " << ret);
        return ACLSHMEM_DL_FUNC_FAILED;
    }
    FillSysBootIdInfo();
    superPodId_ = static_cast<uint32_t>(value);
    if (superPodId_ == invalidSuperPodId && serverId_ == invalidServerId) {
        if (bootIdHead_ != 0) {
            serverId_ = bootIdHead_;
        } else {
            auto networks = utils::NetworkGetIpAddresses();
            if (networks.empty()) {
                SHM_LOG_ERROR("get local host ip address empty.");
                return ACLSHMEM_INNER_ERROR;
            }
            serverId_ = networks[0];
        }
    }

    socType_ = DlApi::GetAscendSocType();
    SHM_LOG_DEBUG("local sdid=0x" << std::hex << sdid_ << ", local server=0x" << std::hex << serverId_
                                 << ", spid=" << superPodId_);
    deviceInfoReady = true;
    return ACLSHMEM_SUCCESS;
}

void MemSegment::FillSysBootIdInfo() noexcept
{
    std::string bootIdPath("/proc/sys/kernel/random/boot_id");
    std::ifstream input(bootIdPath);
    input >> sysBoolId_;

    std::stringstream ss(sysBoolId_);
    ss >> std::hex >> bootIdHead_;
    SHM_LOG_DEBUG("os-boot-id: " << sysBoolId_ << ", head u32: " << std::hex << bootIdHead_);
}

bool MemSegment::CanLocalHostReaches(uint32_t superPodId, uint32_t serverId, uint32_t deviceId) noexcept
{
    if (superPodId != superPodId_ || serverId != serverId_) {
        return false;
    }
    return (socType_ != ASCEND_910B) || ((deviceId / ASC910B_CONN_RANKS) == (logicDeviceId_ / ASC910B_CONN_RANKS));
}

bool MemSegment::IsSdmaAccessible(uint32_t superPodId, uint32_t serverId, uint32_t deviceId) noexcept
{
    if (serverId == serverId_) {
        return (socType_ != ASCEND_910B) || ((deviceId / ASC910B_CONN_RANKS) == (logicDeviceId_ / ASC910B_CONN_RANKS));
    }

    if (superPodId == invalidSuperPodId || superPodId_ == invalidSuperPodId) {
        SHM_LOG_DEBUG("spid: " << superPodId << ", local: " << superPodId_ << " cannot reach.");
        return false;
    }

    return superPodId == superPodId_;
}
}