#!/usr/bin/env python3
# Copyright (c) 2025 Huawei Device Co., Ltd.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# NOTE(pronai) need idempotence test and negative tests for non-normal YAMLs

import io
import os
from pathlib import Path
import posixpath
import shutil
import subprocess
import sys
# YAML parser/dumper that preserves formatting
from ruamel import yaml
import numpy as np
import math

# Help text printed by main() for -h/--help.
usage = (
    f"Usage: {sys.argv[0]} [--all] inputs\n"
    "\n"
    "    inputs: a (possibly empty) list of paths to diagnostic YAMLs\n"
    "    --all: run on all YAMLs that look like they include diagnostics\n"
    "    -h | --help: print this help\n"
    "    "
)

# Highest acceptable probability (used as the threshold in compute_range) that new ids collide.
permissible_collision_chance = 0.001
# Number of commits assumed to be allocating new ids at the same time.
simultaneous_commits = 10
# New ids added per commit; ceil of the average seen in previous commits. Must be a positive int.
ids_per_commit = 2
# A commit adding several ids is modelled as that many single-id commits:
# this keeps the math simple and is a pessimistic approximation.
effective_alloc_rate = ids_per_commit * simultaneous_commits


# see https://en.wikipedia.org/wiki/Birthday_problem
def collision_prob(n, k=None):
    """Return the probability that k ids drawn uniformly from n values are not all distinct.

    P(no collision) = n! / ((n - k)! * n**k); the factorials cancel down to the
    product (n - k + 1) * ... * n, so no huge factorials are computed.

    Args:
        n: number of values the ids are drawn from.
        k: number of ids drawn; defaults to the module's ``effective_alloc_rate``.

    Returns:
        A float in [0, 1]; exactly 1.0 when k > n.
    """
    if k is None:
        k = effective_alloc_rate
    if k > n:
        # pigeonhole: more draws than values always collide; this also avoids
        # dividing by 0 ** k when n == 0
        return 1.0
    return 1 - math.prod(range(n - k + 1, n + 1)) / n ** k

def compute_range(prob=None, threshold=None, start=None):
    """Return the smallest range size n >= start with prob(n) <= threshold.

    Uses an exponential search for an acceptable upper bound followed by an
    integer bisection. ``prob`` must be non-increasing in n (higher n gives
    lower p), which holds for the birthday-problem probability.

    Args:
        prob: probability as a function of the range size; defaults to
            collision_prob.
        threshold: highest acceptable probability; defaults to
            permissible_collision_chance.
        start: smallest range size considered; defaults to
            effective_alloc_rate (collision_prob is 1 for smaller ranges).

    Returns:
        The smallest acceptable integer range size.
    """
    if prob is None:
        prob = collision_prob
    if threshold is None:
        threshold = permissible_collision_chance
    if start is None:
        start = effective_alloc_rate

    # a single tie rule for both phases: p == threshold counts as acceptable
    def acceptable(n):
        return prob(n) <= threshold

    if acceptable(start):
        return start
    # exponential search: grow hi until acceptable; lo is always the last unacceptable value
    lo, hi, step = start, start + 1, 2
    while not acceptable(hi):
        lo, hi, step = hi, hi + step, step * 2
    # bisection with the invariant: not acceptable(lo) and acceptable(hi)
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if acceptable(mid):
            hi = mid
        else:
            lo = mid
    return hi

# Smallest id range that keeps the collision chance acceptable (see compute_range).
min_range = compute_range()

# normalize() reads the ids into numpy arrays of this dtype, so the range must fit its bit width.
dtype = np.uint32
# explicit check rather than `assert`, which is stripped when running under `python -O`
if math.ceil(math.log2(min_range)) >= np.iinfo(dtype).bits:
    raise OverflowError(f"min_range={min_range} needs more bits than {np.dtype(dtype).name} provides")

# Return the ids used by the entries of doc, sorted, with a leading 0 sentinel, as int64.
def _taken_ids(doc):
    ids = np.fromiter((int(entry['id']) for entry in doc if 'id' in entry), dtype)
    # widen to signed so that gap arithmetic on duplicates goes negative instead of wrapping;
    # 0 is a sentinel so the interval 1..smallest-1 is counted too
    return np.sort(np.append(ids.astype(np.int64), 0))


# Return a random id that is neither in allocated (from _taken_ids) nor in graveyard.
def _pick_free_id(allocated, graveyard, in_path, rng):
    """Draw a free id uniformly at random.

    Free ids form the intervals between consecutive taken ids plus a tail
    interval above the largest one. The tail is sized so that at least
    ``min_range`` ids are free in total. An interval is chosen with probability
    proportional to its size, then an id is drawn uniformly inside it, so
    every free id is equally likely.

    Raises:
        ValueError: if an id occurs more than once among the taken ids.
    """
    # explicit int64: np.append(ints, []) would silently promote to float64, which
    # later makes rng.choice(<float>) fail when the graveyard is empty
    buried = np.fromiter((int(g) for g in graveyard), np.int64)
    taken = np.sort(np.concatenate((allocated, buried)))
    gaps = taken[1:] - taken[:-1] - 1
    duplicated = taken[:-1][gaps < 0]
    if duplicated.size:
        raise ValueError(f"Non unique ids in yaml {in_path}: {duplicated.tolist()}")
    gaps = np.append(gaps, max(0, min_range - int(gaps.sum())))
    interval = rng.choice(gaps.size, p=gaps / gaps.sum(dtype=float))
    return int(taken[interval] + 1 + rng.integers(gaps[interval]))


# Return contents with a blank line before each "- name:" item and the README pointer at the end.
def _layout(contents):
    usage_comment = "# See ets_frontend/ets2panda/util/diagnostic/README.md before contributing.\n"
    out = []
    prev = "not blank"  # so an item on the very first line also gets a blank line before it
    for line in contents.splitlines(keepends=True):
        if line.lstrip().startswith("- name:") and prev.strip() != "":
            out.append("\n")
        out.append(line)
        prev = line
    if prev.strip() != usage_comment.strip():
        out.append(usage_comment)
    return "".join(out)


# Write text to path + ".new", then move it over path; remove the temporary file on failure.
def _write_replacing(path, text):
    tmp_path = path + ".new"
    try:
        with open(tmp_path, "w", encoding="utf-8") as out_file:
            out_file.write(text)
        # unlike os.rename, os.replace also overwrites an existing target on Windows
        os.replace(tmp_path, path)
    except BaseException:
        Path(tmp_path).unlink(missing_ok=True)
        raise


def normalize(in_path):
    """Rewrite the diagnostics YAML at in_path in normalized form.

    Steps:
    - Sort the list stored under the document's first key by ``name``.
    - Sort the ``graveyard`` list, or create it from the unused ids in
      ``1..max(id)`` if it is missing.
    - Give every entry without an ``id`` a fresh random one (see _pick_free_id).
    - Put a blank line before each ``- name:`` item.
    - Append the README pointer comment unless the file already ends with it.

    The output goes to ``in_path + ".new"``, which is then moved over in_path.
    Errors raised before that leave the input file untouched, and no temporary
    file is left behind.

    Raises:
        ValueError: if the file contains duplicate ids.
    """
    parser = yaml.YAML()
    parser.preserve_quotes = True
    parser.width = 120
    with open(in_path, encoding="utf-8") as in_file:
        docs = parser.load(in_file)
    doc = docs[next(iter(docs))]
    doc.sort(key=lambda item: item['name'])
    rng = np.random.default_rng()
    if "graveyard" in docs:
        docs["graveyard"].sort()
    for diagnostic in doc:
        if "id" in diagnostic and "graveyard" in docs:
            continue
        # recomputed on every allocation: simpler than splitting the intervals, and
        # a commit only adds one or two ids, so at most it doubles the run time
        allocated = _taken_ids(doc)
        if "graveyard" not in docs:
            print(f"{in_path}: filling graveyard")
            unused = set(range(1, int(allocated[-1]) + 1)).difference(allocated.tolist())
            docs.insert(len(docs), "graveyard", sorted(unused))
        if 'id' in diagnostic:
            continue
        diagnostic.insert(1, 'id', _pick_free_id(allocated, docs["graveyard"], in_path, rng))
    sio = io.StringIO()
    parser.dump(docs, sio)
    _write_replacing(in_path, _layout(sio.getvalue()))

def known_file_paths():
    """Yield the git-tracked ``*.yaml`` files that contain a line starting with ``  message:``.

    ``git grep`` runs in repo_base, the directory two levels above this script.
    Each yielded path is repo_base joined with the file name git prints, which
    is relative to repo_base.

    Raises:
        subprocess.CalledProcessError: if git fails. Exit status 1 only means
            that nothing matched, so it yields no paths instead.
    """
    # __file__ instead of shutil.which(sys.argv[0]): which() returns None (and
    # dirname(None) raises) when the script is run as `python3 script.py`
    script_dir = os.path.dirname(os.path.abspath(__file__))
    repo_base = os.path.join(script_dir, "..", "..")
    cmd = ["git", "grep", "--files-with-matches", "-e", "^  message:", "--", "**.yaml"]
    result = subprocess.run(cmd, cwd=repo_base, stdout=subprocess.PIPE)
    if result.returncode > 1:
        raise subprocess.CalledProcessError(result.returncode, cmd)
    for relative in result.stdout.decode().splitlines():
        yield os.path.join(repo_base, relative)

def main():
    """Command-line entry point: normalize every YAML path given as an argument.

    ``-h``/``--help`` prints the usage text to stderr and exits with status 1.
    ``--all`` adds every path yielded by known_file_paths() to the inputs.
    """
    params = sys.argv[1:]
    # check only the real arguments; sys.argv[0] is the program name
    if "-h" in params or "--help" in params:
        print(usage, file=sys.stderr)
        sys.exit(1)
    if "--all" in params:
        params = [p for p in params if p != "--all"]
        params.extend(known_file_paths())
    for path in params:
        print("Processing", path)
        normalize(path)


if __name__ == "__main__":
    main()