#!/usr/bin/env python3
# tools/symbolshare.py
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.  The
# ASF licenses this file to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance with the
# License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
# License for the specific language governing permissions and limitations
# under the License.
#

import argparse
import os
import subprocess
import tempfile

from elftools.elf.elffile import ELFFile


def find_symbol_in_section(elf_path, section_names):
    """Collect data-object symbols whose address falls inside the given sections.

    Args:
        elf_path: Path of the ELF file to read.
        section_names: Names of the sections (e.g. ``.data``) whose address
            ranges are searched. Names missing from the file are reported on
            stdout and skipped.

    Returns:
        A list of ``(name, address, size, section_name)`` tuples, one per
        ``STT_OBJECT`` symbol whose ``st_value`` lies in
        ``[sh_addr, sh_addr + sh_size)`` of a requested section. Each symbol
        is reported at most once, under the first requested section that
        contains it. An empty list is returned (after printing a message)
        when the file has no ``.symtab`` section.
    """
    symbols = []
    sections = []

    with open(elf_path, "rb") as elf_file:
        elf = ELFFile(elf_file)

        # Resolve each requested section name to its [start, end) address range.
        for section_name in section_names:
            section = elf.get_section_by_name(section_name)
            if section is None:
                print(f"Section '{section_name}' not found in {elf_path}")
                continue

            start = section["sh_addr"]
            sections.append((start, start + section["sh_size"], section_name))

        # Retrieve symbols from the symbol table
        symtab = elf.get_section_by_name(".symtab")
        if symtab is None:
            print(f"No symbol table found in {elf_path}")
            return symbols

        for symbol in symtab.iter_symbols():
            # Only data objects are exported; the type does not depend on the
            # section, so filter once per symbol instead of once per range.
            if symbol["st_info"]["type"] != "STT_OBJECT":
                continue

            value = symbol["st_value"]
            for start, end, section_name in sections:
                if start <= value < end:
                    symbols.append(
                        (symbol.name, value, symbol["st_size"], section_name)
                    )
                    # Stop at the first match: overlapping ranges (or a section
                    # name given twice) must not emit the same symbol twice.
                    break

    return symbols


def objfile_iter(path):
    """Extract static archives with ``ar x`` and yield the extracted member paths.

    Each archive is unpacked into its own fresh directory created with
    ``tempfile.mkdtemp()``. Those directories are NOT removed after a
    successful extraction, so the yielded paths stay valid after iteration;
    the caller owns their cleanup.

    Args:
        path: An iterable of archive paths, or a single archive path
            (``str``, ``bytes`` or ``os.PathLike``).

    Yields:
        ``os.path.join(temp_dir, member)`` for every file found in the
        extraction directory, in sorted name order per archive.

    Raises:
        subprocess.CalledProcessError: if ``ar`` exits with a non-zero status.
        FileNotFoundError: if the ``ar`` executable cannot be found.
    """
    import shutil

    # A bare path would otherwise be iterated character by character.
    if isinstance(path, (str, bytes, os.PathLike)):
        path = [path]

    for archive in path:
        temp_dir = tempfile.mkdtemp()
        try:
            # check=True: a failed extraction must not silently yield nothing.
            subprocess.run(
                ["ar", "x", os.path.abspath(archive)], cwd=temp_dir, check=True
            )
        except (OSError, subprocess.CalledProcessError):
            # Do not leak the (possibly partially filled) directory on failure.
            shutil.rmtree(temp_dir, ignore_errors=True)
            raise
        # NOTE(review): members with identical names presumably overwrite each
        # other during `ar x`; only the surviving file is yielded.
        for filename in sorted(os.listdir(temp_dir)):
            yield os.path.join(temp_dir, filename)


def args_parser(argv=None):
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Argument list to parse, without the program name. ``None``
            (the default) parses ``sys.argv[1:]``, as argparse does.

    Returns:
        An ``argparse.Namespace`` with ``elf`` (str), ``ld`` (str, default
        ``"tmp.ld"``), ``section`` (non-empty list of str) and ``tasking``
        (bool, default ``False``).
    """
    parser = argparse.ArgumentParser(description="Show symbols in ELF sections.")
    parser.add_argument("-e", "--elf", required=True, help="Path to the ELF file")
    parser.add_argument("-l", "--ld", default="tmp.ld", help="Output link script")
    parser.add_argument(
        "-s", "--section", required=True, nargs="+", help="Section names to inspect"
    )
    parser.add_argument(
        "--tasking",
        action="store_true",
        default=False,
        help="Write symbol names double-quoted in the output link script",
    )

    return parser.parse_args(argv)


if __name__ == "__main__":
    args = args_parser()
    section_symbols = find_symbol_in_section(args.elf, args.section)

    # One assignment per symbol: name (double-quoted in --tasking mode,
    # bare otherwise) bound to its address in lowercase hex.
    if args.tasking:
        assignment = '"{}" = 0x{:x};\n'
    else:
        assignment = "{} = 0x{:x};\n"

    with open(args.ld, "w") as fd:
        fd.write("/* This file is auto-generated by tools/symbolshare.py */\n")
        fd.writelines(assignment.format(sym[0], sym[1]) for sym in section_symbols)