# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import pickle
import re
from collections import defaultdict

from tools.flight_recorder.components.fr_logger import FlightRecorderLogger
from tools.flight_recorder.components.utils import get_valid_read_path

# Directory-depth limit, relative to ``trace_dir``, enforced by read_dir's walk.
MAX_DEPTH = 3

logger: FlightRecorderLogger = FlightRecorderLogger()

# Allow-list consulted by SafeUnpickler.find_class: maps a module name to the
# global names in that module which a dump is permitted to reference. Any
# global not listed here is rejected with pickle.UnpicklingError.
SAFE_CLASSES = {
    # Built-in types considered safe to reconstruct while unpickling
    "builtins": {"str", "int", "float", "list", "dict", "tuple"},
}

# Splits a trace file name into (prefix, rank digits). The prefix group is
# lazy, so the whole trailing run of digits goes to the rank group. The prefix
# may be empty and is limited to at most 100 characters from [A-Za-z0-9_].
exp = re.compile(r"^([a-zA-Z0-9_]{0,100}?)(\d+)$")


class SafeUnpickler(pickle.Unpickler):
    """Unpickler that refuses to resolve any global not allow-listed in ``SAFE_CLASSES``.

    ``find_class`` is the hook ``pickle`` uses to look up module-level objects
    referenced by a stream, so overriding it restricts what a dump may load.
    """

    def find_class(self, module, name):
        """Return ``module.name`` when allow-listed; otherwise raise ``UnpicklingError``."""
        permitted = SAFE_CLASSES.get(module, ())
        if name not in permitted:
            raise pickle.UnpicklingError(f"Forbidden class: {module}.{name}")
        return super().find_class(module, name)


def read_dump(prefix, filename):
    """Load a single rank's dump file through ``SafeUnpickler``.

    Args:
        prefix: File-name prefix shared by the dump files; the remainder of the
            base name after this prefix must be the integer rank.
        filename: Path of the dump file.

    Returns:
        A ``(rank, dump)`` tuple, where ``dump`` is the unpickled object.

    Raises:
        ValueError: If the base name does not start with ``prefix`` or the
            remainder is not an integer.
        Exception: Any error raised while opening or unpickling the file
            (e.g. ``OSError``, ``pickle.UnpicklingError``); it is logged and
            then re-raised unchanged.
    """
    basename = os.path.basename(filename)
    rank_error = f"Cannot extract rank from '{basename}' with prefix '{prefix}'."
    # Without this check a mismatched name could silently parse to a wrong rank.
    if not basename.startswith(prefix):
        raise ValueError(rank_error)
    try:
        rank = int(basename[len(prefix):])
    except ValueError as e:
        raise ValueError(rank_error) from e
    filename = get_valid_read_path(filename)
    try:
        with open(filename, "rb") as infile:
            dump = SafeUnpickler(infile).load()
    except Exception as e:
        logger.error(f"Failed to load data from {filename}: {e}")
        # Re-raise: there is no dump to return, and falling through would fail
        # with a confusing UnboundLocalError instead of the real cause.
        raise
    return rank, dump


def determine_prefix(files):
    """Infer the single file-name prefix shared by the trace files in ``files``.

    Each name is matched against ``exp``; names that match are split into a
    prefix and a numeric rank, and names that do not match are ignored.

    Args:
        files: Iterable of file names to inspect.

    Returns:
        The one distinct prefix found among the matching names.

    Raises:
        ValueError: If no prefix, or more than one distinct prefix, is found.
    """
    ranks_by_prefix: defaultdict[str, set[int]] = defaultdict(set)
    for name in files:
        match = exp.search(name)
        if match is None:
            continue
        candidate, rank_digits = match.groups()
        ranks_by_prefix[candidate].add(int(rank_digits))
    if len(ranks_by_prefix) != 1:
        raise ValueError(
            "Unable to automatically determine the common prefix for the trace file names. "
            "Please specify --prefix argument manually"
        )
    (prefix,) = ranks_by_prefix
    return prefix


def read_dir(args):
    """Load recorder data for all ranks found under ``args.trace_dir``.

    The directory tree is walked top-down. Directories nested more than
    ``MAX_DEPTH`` levels below ``trace_dir`` are not descended into, and an
    error is logged for each directory whose subdirectories are skipped.
    Files whose names contain ``"py_traceback"`` or do not start with the
    prefix are ignored.

    Args:
        args: Namespace-like object with ``prefix`` (file-name prefix of the
            dump files, or ``None`` to infer it via ``determine_prefix``) and
            ``trace_dir`` (root directory to search).

    Returns:
        ``(details, version)``: ``details`` maps each rank parsed by
        ``read_dump`` to its loaded dump; ``version`` is ``str()`` of the
        ``"version"`` entry of the first dump loaded, or ``""`` if none.

    Raises:
        ValueError: If the prefix cannot be inferred, or a matching file name
            does not yield an integer rank (raised by ``read_dump``).
    """
    prefix = args.prefix
    path = args.trace_dir
    details = {}
    version = ""
    for root, dirs, files in os.walk(path):
        # Depth of ``root`` below ``path`` (0 for ``path`` itself). relpath is
        # used rather than counting separators so that a trailing separator on
        # ``path`` does not shift every depth by one.
        rel = os.path.relpath(root, path)
        depth = 0 if rel == os.curdir else rel.count(os.sep) + 1
        if depth >= MAX_DEPTH and dirs:
            logger.error(
                f"The current file depth has exceeded the maximum depth limit, which is set to {MAX_DEPTH}; "
                f"skipping subdirectories of {root}."
            )
            # Prune in place (honoured by os.walk when topdown=True) instead of
            # breaking, so shallower sibling directories are still scanned.
            dirs[:] = []
        if prefix is None:
            # NOTE(review): the prefix is inferred from the first directory the
            # top-down walk yields (trace_dir itself); if it holds no trace
            # files, determine_prefix raises even when subdirectories do, so
            # --prefix must be given explicitly for nested layouts.
            prefix = determine_prefix(files)
        for f in files:
            if "py_traceback" in f or not f.startswith(prefix):
                continue
            file_path = os.path.join(root, f)
            rank, dump = read_dump(prefix, file_path)
            if rank in details:
                # Same rank in several directories: the later file wins, as
                # before, but make the overwrite visible.
                logger.error(f"Duplicate dump for rank {rank}: {file_path} replaces a previously loaded one.")
            details[rank] = dump
            if not version:
                version = str(dump["version"])
    return details, version