/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/

/*!
* \file sk_common.cpp
* \brief Common utilities: kernel architecture detection, ELF function-symbol extraction and caching,
*        SoC name lookup, and device core-count queries
*/


#include "sk_common.h"

#include <acl/acl.h>
#include <unordered_map>
#include <string>
#include <cstring>
#include <elf.h>
#include "sk_log.h"

namespace {

// Substring looked for in the SoC name returned by aclrtGetSocName();
// GetCurrentSkKernelArch() selects DAV_3510 when the SoC name contains it.
constexpr const char* ASCEND950_SOC_NAME = "Ascend950";

} // namespace

/**
 * \brief Pick the kernel architecture from the SoC name reported by ACL.
 *
 * \return SkKernelArch::DAV_3510 when the SoC name contains ASCEND950_SOC_NAME;
 *         SkKernelArch::DAV_2201 otherwise, including when no SoC name is available
 *         (that case is logged at info level).
 */
SkKernelArch GetCurrentSkKernelArch()
{
    const char* const reportedSoc = aclrtGetSocName();
    if (reportedSoc != nullptr) {
        // Substring match: any SoC name that contains the Ascend950 marker maps to DAV_3510.
        const bool isAscend950 = (strstr(reportedSoc, ASCEND950_SOC_NAME) != nullptr);
        return isAscend950 ? SkKernelArch::DAV_3510 : SkKernelArch::DAV_2201;
    }

    // NOTE(review): to_string(SkKernelArch) is declared outside this file; it has to yield
    // a C string for the %s conversion below - confirm.
    SK_LOGI("Kernel arch detection: soc name is null, fallback to arch=%s",
            to_string(SkKernelArch::DAV_2201));
    return SkKernelArch::DAV_2201;
}

/**
 * \brief Map a kernel architecture to its symbol-name suffix.
 *
 * \param arch Architecture to translate.
 * \return "dav_2201" or "dav_3510" for the known architectures, "unknown" for any other value.
 *         The returned pointer refers to a string literal.
 */
const char* GetSkKernelArchSymbolSuffix(SkKernelArch arch)
{
    if (arch == SkKernelArch::DAV_2201) {
        return "dav_2201";
    }
    if (arch == SkKernelArch::DAV_3510) {
        return "dav_3510";
    }
    return "unknown";
}

namespace {

// Binding of a function symbol as recorded in the symbol cache.
// ExtractFunctionSymbols() derives it from the upper four bits of Elf64_Sym::st_info:
// STB_WEAK -> WEAK, STB_GLOBAL -> GLOBAL, anything else -> LOCAL.
enum class SymBindType : uint8_t {
    LOCAL = 0,
    GLOBAL = 1,
    WEAK = 2,
};

/**
 * \brief Printable name of a symbol binding.
 *
 * \param bindType Binding to translate.
 * \return "LOCAL", "GLOBAL" or "WEAK"; "UNKNOWN" for a value outside the enumerators.
 */
const char* SymBindTypeToStr(SymBindType bindType)
{
    if (bindType == SymBindType::LOCAL) {
        return "LOCAL";
    }
    if (bindType == SymBindType::GLOBAL) {
        return "GLOBAL";
    }
    if (bindType == SymBindType::WEAK) {
        return "WEAK";
    }
    return "UNKNOWN";
}

// One cached function symbol. Kept as a plain aggregate: ExtractFunctionSymbols()
// brace-initializes it in member order (name, size, bindType).
struct FuncSymbolInfo {
    std::string name;       // symbol name copied out of the ELF string table
    uint64_t size;          // Elf64_Sym::st_size of the symbol
    SymBindType bindType;   // binding derived from Elf64_Sym::st_info
};
// Function symbols of one binary, keyed by Elf64_Sym::st_value.
using SkFuncSymbolTable = std::unordered_map<uint64_t, FuncSymbolInfo>;
// Per-binary symbol tables, keyed by the ACL binary handle passed to GetFuncSymbolInfo().
using SkBinSymbolTable = std::unordered_map<aclrtBinHandle, SkFuncSymbolTable>;

// Location of the ".symtab" / ".strtab" sections inside an in-memory ELF image.
// The pointers refer into the image buffer handed to FindElfSymbolTables(); they do not
// own any memory. Every member has a default so that a default-constructed instance is
// in a defined "nothing found" state even if the lookup bails out early.
struct ElfSymbolTables {
    const Elf64_Sym* symTbl = nullptr;  // start of the symbol table
    size_t symSize = 0;                 // size of the symbol table in bytes
    const char* strTbl = nullptr;       // start of the symbol-name string table
    size_t strTblSize = 0;              // size of the string table in bytes
};

/**
 * \brief Check that the byte range [offset, offset + size) lies inside a buffer of binSize bytes.
 *
 * The comparison is written in subtraction form so that a huge `size` coming from a
 * malformed section header cannot wrap `offset + size` around and slip past the check.
 *
 * \param binSize Total size of the buffer in bytes.
 * \param offset  Start of the range, relative to the beginning of the buffer.
 * \param size    Length of the range in bytes.
 * \return true when the whole range fits; an empty range at offset == binSize is accepted.
 */
static bool ValidateElfSectionRange(size_t binSize, uint64_t offset, uint64_t size)
{
    const uint64_t total = static_cast<uint64_t>(binSize);
    if (offset > total) {
        return false;
    }
    // offset <= total here, so the subtraction cannot underflow.
    return size <= total - offset;
}

// Returns true when the section name starting at `name` (with `maxLen` readable bytes left in
// the section-name table) is exactly `expected`, terminator included. Requiring the terminator
// inside the table keeps a truncated prefix such as ".sym" at the very end from matching.
static bool SectionNameEquals(const char* name, size_t maxLen, const char* expected)
{
    const size_t expectedLen = std::strlen(expected);
    return maxLen > expectedLen && std::memcmp(name, expected, expectedLen + 1) == 0;
}

/**
 * \brief Locate the ".symtab" and ".strtab" sections of an in-memory ELF64 image.
 *
 * The image is validated before it is trusted: header size, ELF magic and 64-bit class,
 * section-header-table bounds, section-name-table bounds, and the bounds of every candidate
 * section. Sections that fail a per-section check are skipped rather than treated as fatal.
 * If several sections carry the same name, the last one wins.
 *
 * Files with no section headers, or that use extended section numbering
 * (e_shnum == 0 or e_shstrndx == SHN_XINDEX), are rejected by the e_shstrndx check.
 *
 * \param binAddr Start of the ELF image; must not be null (callers check this).
 * \param binSize Size of the image in bytes.
 * \param tables  Output. Reset on entry; on success its pointers refer into the binAddr
 *                buffer, so that buffer must stay valid while `tables` is in use.
 * \return true when both sections were found.
 */
static bool FindElfSymbolTables(const char* binAddr, size_t binSize, ElfSymbolTables& tables) {
    constexpr size_t ELF64_EHDR_SIZE = sizeof(Elf64_Ehdr);
    constexpr size_t ELF64_SHDR_SIZE = sizeof(Elf64_Shdr);

    // Reset the outputs first so that every failure path leaves them in a defined state.
    tables.symTbl = nullptr;
    tables.symSize = 0;
    tables.strTbl = nullptr;
    tables.strTblSize = 0;

    if (binSize < ELF64_EHDR_SIZE) {
        SK_LOGE("Invalid ELF: binSize=%zu < minimum header size %zu", binSize, ELF64_EHDR_SIZE);
        return false;
    }

    const Elf64_Ehdr* ehdr = reinterpret_cast<const Elf64_Ehdr*>(binAddr);

    // Everything below reads Elf64_* layouts, so refuse anything that is not a 64-bit ELF.
    if (std::memcmp(ehdr->e_ident, ELFMAG, SELFMAG) != 0 || ehdr->e_ident[EI_CLASS] != ELFCLASS64) {
        SK_LOGE("Invalid ELF: bad magic or not a 64-bit ELF image");
        return false;
    }

    // Subtraction form: the section header table must fit in what is left after e_shoff.
    const uint64_t shTblBytes = static_cast<uint64_t>(ehdr->e_shnum) * ELF64_SHDR_SIZE;
    if (ehdr->e_shoff > binSize || shTblBytes > binSize - ehdr->e_shoff) {
        SK_LOGE("Invalid ELF: shoff=0x%lx, shnum=%u exceed binSize=%zu", 
                static_cast<uint64_t>(ehdr->e_shoff), ehdr->e_shnum, binSize);
        return false;
    }

    if (ehdr->e_shstrndx >= ehdr->e_shnum) {
        SK_LOGE("Invalid ELF: e_shstrndx=%u >= e_shnum=%u", ehdr->e_shstrndx, ehdr->e_shnum);
        return false;
    }

    const Elf64_Shdr* shHdr = reinterpret_cast<const Elf64_Shdr*>(binAddr + ehdr->e_shoff);

    const Elf64_Shdr& shstrtabHdr = shHdr[ehdr->e_shstrndx];
    if (!ValidateElfSectionRange(binSize, shstrtabHdr.sh_offset, shstrtabHdr.sh_size)) {
        SK_LOGE("Invalid ELF: shstrtab offset=0x%lx size=0x%lx exceed binSize", 
                static_cast<uint64_t>(shstrtabHdr.sh_offset), static_cast<uint64_t>(shstrtabHdr.sh_size));
        return false;
    }

    const char* shStrTbl = binAddr + shstrtabHdr.sh_offset;
    const size_t shStrTblSize = shstrtabHdr.sh_size;

    for (uint16_t i = 0; i < ehdr->e_shnum; ++i) {
        const Elf64_Shdr& sec = shHdr[i];

        // SHT_NULL entries are unused and SHT_NOBITS sections occupy no bytes in the file.
        if (sec.sh_type == SHT_NULL || sec.sh_type == SHT_NOBITS) {
            continue;
        }
        // The name offset must point inside the section-name table.
        if (sec.sh_name >= shStrTblSize) {
            continue;
        }
        // The section contents must lie inside the image.
        if (!ValidateElfSectionRange(binSize, sec.sh_offset, sec.sh_size)) {
            continue;
        }

        const char* secName = shStrTbl + sec.sh_name;
        const size_t secNameMaxLen = shStrTblSize - sec.sh_name;

        if (SectionNameEquals(secName, secNameMaxLen, ".symtab")) {
            tables.symTbl = reinterpret_cast<const Elf64_Sym*>(binAddr + sec.sh_offset);
            tables.symSize = sec.sh_size;
        } else if (SectionNameEquals(secName, secNameMaxLen, ".strtab")) {
            tables.strTbl = binAddr + sec.sh_offset;
            tables.strTblSize = sec.sh_size;
        }
    }
    return (tables.symTbl != nullptr && tables.strTbl != nullptr);
}

static void ExtractFunctionSymbols(const ElfSymbolTables& tables, SkFuncSymbolTable& funcSymTable) 
{
    constexpr size_t ELF64_SYM_SIZE = sizeof(Elf64_Sym);
    size_t symCount = tables.symSize / ELF64_SYM_SIZE;
    
    for (size_t i = 0; i < symCount; ++i) {
        const Elf64_Sym& sym = tables.symTbl[i];
        
        if ((sym.st_info & 0xf) != STT_FUNC || sym.st_size == 0) {
            continue;
        }
        
        if (sym.st_name >= tables.strTblSize) {
            continue;
        }
        
        const char* name = tables.strTbl + sym.st_name;
        if (name == nullptr || name[0] == '\0') {
            continue;
        }
        
        SymBindType bindType = (sym.st_info >> 4) == STB_WEAK ? SymBindType::WEAK
                              : (sym.st_info >> 4) == STB_GLOBAL ? SymBindType::GLOBAL
                              : SymBindType::LOCAL;
        funcSymTable[sym.st_value] = {name, sym.st_size, bindType};
    }
}

/**
 * \brief Build the function-symbol map of an in-memory ELF image.
 *
 * \param binAddr Start of the ELF image.
 * \param binSize Size of the image in bytes.
 * \return Map from symbol address to symbol info. Empty when the arguments are invalid or
 *         the symbol/string table sections cannot be located (both cases are logged).
 */
SkFuncSymbolTable BuildFuncSymbolTable(const char* binAddr, size_t binSize)
{
    SkFuncSymbolTable collected;

    const bool hasImage = (binAddr != nullptr) && (binSize != 0);
    if (!hasImage) {
        SK_LOGE("Invalid bin parameters: binAddr=%p, binSize=%zu", binAddr, binSize);
        return collected;
    }

    ElfSymbolTables located;
    if (FindElfSymbolTables(binAddr, binSize, located)) {
        ExtractFunctionSymbols(located, collected);
        SK_LOGI("total %zu function symbols found", collected.size());
    } else {
        SK_LOGE("Failed to find valid symtab or strtab sections");
    }
    return collected;
}

} // namespace

/**
 * \brief Look up the function symbol located at funcAddr inside a loaded binary.
 *
 * The symbol table of each binary handle is built on first use and kept in a function-local
 * static cache for later calls. A table that turned out empty is cached as well, so the ELF
 * image is parsed at most once per handle.
 *
 * NOTE(review): the cache is accessed without any locking in this function; callers
 * presumably serialize access - confirm.
 *
 * \param binHdl      Handle identifying the binary; used as the cache key.
 * \param binAddr     Start of the binary's ELF image; must be non-null even on a cache hit.
 * \param binSize     Size of the image in bytes; must be non-zero even on a cache hit.
 * \param funcAddr    Symbol address (Elf64_Sym::st_value) to look up.
 * \param symbolName  Output: symbol name. Written only on success.
 * \param funcSize    Output: symbol size. Written only on success.
 * \param symbolBind  Output: "LOCAL", "GLOBAL" or "WEAK". Written only on success.
 * \return true when a symbol with exactly this address was found.
 */
bool GetFuncSymbolInfo(aclrtBinHandle binHdl, const char* binAddr, size_t binSize, uint64_t funcAddr,
                       std::string& symbolName, uint64_t& funcSize, std::string& symbolBind)
{
    if (binAddr == nullptr || binSize == 0) {
        SK_LOGE("Invalid bin parameters: binAddr=%p, binSize=%zu", binAddr, binSize);
        return false;
    }

    static SkBinSymbolTable perBinCache;
    auto binEntry = perBinCache.find(binHdl);
    if (binEntry == perBinCache.end()) {
        // First request for this handle: parse the image once and remember the result.
        SK_LOGI("Building symbol table for binHdl=%p", binHdl);
        binEntry = perBinCache.emplace(binHdl, BuildFuncSymbolTable(binAddr, binSize)).first;
    }

    const SkFuncSymbolTable& symbols = binEntry->second;
    const auto hit = symbols.find(funcAddr);
    if (hit == symbols.end()) {
        SK_LOGW("Function symbol not found for addr=0x%lx", funcAddr);
        return false;
    }

    const FuncSymbolInfo& info = hit->second;
    symbolName = info.name;
    funcSize = info.size;
    symbolBind = SymBindTypeToStr(info.bindType);
    SK_LOGI("Found symbol: name=%s, addr=0x%lx, size=0x%lx, bind=%s",
            symbolName.c_str(), funcAddr, funcSize, symbolBind.c_str());
    return true;
}

/**
 * \brief Fetch the SoC name from ACL as a std::string.
 *
 * \return The SoC name, or an empty string when ACL reports none (logged as an error).
 */
std::string GetSocName()
{
    SK_LOGI("Init socName");

    const char* const rawName = aclrtGetSocName();
    if (rawName != nullptr) {
        std::string socName(rawName);
        SK_LOGI("Soc name: %s", socName.c_str());
        return socName;
    }

    SK_LOGE("Failed to get soc name");
    return std::string();
}

// ==================== Device Core Number Utilities ====================

int64_t GetDeviceCubeCoreNum()
{
    int32_t deviceId = 0;
    aclError ret = aclrtGetDevice(&deviceId);
    if (ret != ACL_SUCCESS) {
        SK_LOGE("[DeviceCores] Failed to get deviceId, ret=%d", ret);
        return 0;
    }
    int64_t cubeNum = 0;
    ret = aclrtGetDeviceInfo(deviceId, ACL_DEV_ATTR_CUBE_CORE_NUM, &cubeNum);
    if (ret != ACL_SUCCESS) {
        SK_LOGE("[DeviceCores] Failed to get cube core num, ret=%d", ret);
        return 0;
    }
    return cubeNum;
}

int64_t GetDeviceVecCoreNum()
{
    int32_t deviceId = 0;
    aclError ret = aclrtGetDevice(&deviceId);
    if (ret != ACL_SUCCESS) {
        SK_LOGE("[DeviceCores] Failed to get deviceId, ret=%d", ret);
        return 0;
    }
    int64_t vecNum = 0;
    ret = aclrtGetDeviceInfo(deviceId, ACL_DEV_ATTR_VECTOR_CORE_NUM, &vecNum);
    if (ret != ACL_SUCCESS) {
        SK_LOGE("[DeviceCores] Failed to get vec core num, ret=%d", ret);
        return 0;
    }
    return vecNum;
}

/**
 * \brief Fetch both the cube and the vector core counts of the current device.
 *
 * The cube count is queried first; if it is not positive the function returns immediately
 * and vecNum is left untouched.
 *
 * \param cubeNum Output: cube core count (also written when the value is invalid).
 * \param vecNum  Output: vector core count (written only after the cube count was valid).
 * \return ACL_SUCCESS when both counts are positive, ACL_ERROR_INVALID_PARAM otherwise.
 */
aclError GetDeviceCoreNums(int64_t& cubeNum, int64_t& vecNum)
{
    cubeNum = GetDeviceCubeCoreNum();
    const bool cubeValid = (cubeNum > 0);
    if (!cubeValid) {
        SK_LOGE("[DeviceCores] GetDeviceCubeCoreNum returned invalid value: %ld", cubeNum);
        return ACL_ERROR_INVALID_PARAM;
    }

    vecNum = GetDeviceVecCoreNum();
    const bool vecValid = (vecNum > 0);
    if (!vecValid) {
        SK_LOGE("[DeviceCores] GetDeviceVecCoreNum returned invalid value: %ld", vecNum);
        return ACL_ERROR_INVALID_PARAM;
    }

    SK_LOGI("[DeviceCores] Get core nums: cube=%ld, vec=%ld", cubeNum, vecNum);
    return ACL_SUCCESS;
}