# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

__all__ = []

import argparse
import math
from typing import Any
import os
import re
import sys
import stat

from tools.flight_recorder.components.fr_logger import FlightRecorderLogger
from tools.flight_recorder.components.types import (
    Group,
    MatchInfo,
    MatchState,
    MatchStateRecord,
    Membership,
    Op,
)

logger: FlightRecorderLogger = FlightRecorderLogger()

try:
    from tabulate import tabulate
except ModuleNotFoundError:
    logger.debug("tabulate is not installed. Proceeding without it.")

PATH_WHITE_LIST_REGEX = re.compile(r"[^_A-Za-z0-9/.-]")
MAX_READ_FILE_SIZE_4G = 4294967296  # 4G, 4 * 1024 * 1024 * 1024
MAX_READ_FILE_SIZE_32G = 34359738368  # 32G, 32 * 1024 * 1024 * 1024
MAX_READ_FILE_SIZE_512G = 549755813888  # 512G, 512 * 1024 * 1024 * 1024

# group not writable, others no permission, max stat is 750
WRITE_FILE_NOT_PERMITTED_STAT = stat.S_IWGRP | stat.S_IWOTH | stat.S_IROTH | stat.S_IXOTH
# group not writable, others not writable, max stat is 755
READ_FILE_NOT_PERMITTED_STAT = stat.S_IWGRP | stat.S_IWOTH


def type_to_str(value_type):
    return " or ".join(ii.__name__ for ii in value_type) if isinstance(value_type, tuple) else value_type.__name__


def check_type(value, value_type, param_name="value"):
    if not isinstance(value, value_type):
        raise TypeError("{} must be {}, not {}.".format(param_name, type_to_str(value_type), type(value).__name__))


def get_valid_path(path):
    check_type(path, str, "path")
    if not path or len(path) == 0:
        raise ValueError("The value of the path cannot be empty.")
    if PATH_WHITE_LIST_REGEX.search(path):  # Check special char
        raise ValueError("Input path contains invalid characters.")  # Not printing out the path value for invalid char
    path = os.path.expanduser(path)
    if os.path.islink(os.path.abspath(path)):  # when checking link, get rid of the "/" at the path tail if any
        raise ValueError("The value of the path cannot be a symbolic link: {}.".format(path))

    real_path = os.path.realpath(path)

    if len(real_path) > 4096:
        raise ValueError("The length of file path should be less than 4096.")

    if real_path != path and PATH_WHITE_LIST_REGEX.search(real_path):  # Check special char again
        raise ValueError("Input path contains invalid characters.")  # Not printing out the path value for invalid char

    return real_path


def is_belong_to_user_or_group(file_stat):
    return file_stat.st_uid == os.getuid() or file_stat.st_gid in os.getgroups()


def get_valid_read_path(path, size_max=MAX_READ_FILE_SIZE_4G, check_user_stat=True, is_dir=False):
    real_path = get_valid_path(path)
    if not is_dir and not os.path.isfile(real_path):
        raise ValueError("The path {} doesn't exists or not a file.".format(path))
    if is_dir and not os.path.isdir(real_path):
        raise ValueError("The path {} doesn't exists or not a directory.".format(path))

    file_stat = os.stat(real_path)
    if check_user_stat and not sys.platform.startswith("win") and not is_belong_to_user_or_group(file_stat):
        raise ValueError("The file {} doesn't belong to the current user or group.".format(path))
    if check_user_stat and os.stat(path).st_mode & READ_FILE_NOT_PERMITTED_STAT > 0:
        raise ValueError("The file {} is group writable, or is others writable.".format(path))
    if not os.access(real_path, os.R_OK) or file_stat.st_mode & stat.S_IRUSR == 0:  # At least been 400
        raise ValueError("Current user doesn't have read permission to the file {}.".format(path))
    if not is_dir and size_max > 0 and file_stat.st_size > size_max:
        raise ValueError("The file {} exceeds size limitation of {}.".format(path, size_max))
    return real_path


def check_write_directory(dir_name, check_user_stat=True):
    real_dir_name = get_valid_path(dir_name)
    if not os.path.isdir(real_dir_name):
        raise ValueError("The file writen directory {} doesn't exists.".format(dir_name))

    file_stat = os.stat(real_dir_name)
    if check_user_stat and not sys.platform.startswith("win") and not is_belong_to_user_or_group(file_stat):
        raise ValueError("The file writen directory {} doesn't belong to the current user or group.".format(dir_name))
    if not os.access(real_dir_name, os.W_OK):
        raise ValueError("Current user doesn't have writen permission to file writen directory {}.".format(dir_name))


def get_valid_write_path(path, check_user_stat=True, is_dir=False, warn_exists=True):
    real_path = get_valid_path(path)
    real_path_dir = real_path if is_dir else os.path.dirname(real_path)
    check_write_directory(real_path_dir, check_user_stat=check_user_stat)

    if not is_dir and os.path.exists(real_path):
        if os.path.isdir(real_path):
            raise ValueError("The file {} exist and is a directory.".format(path))
        if check_user_stat and os.stat(real_path).st_uid != os.getuid():  # Has to be exactly belonging to current user
            raise ValueError("The file {} doesn't belong to the current user.".format(path))
        if check_user_stat and os.stat(real_path).st_mode & WRITE_FILE_NOT_PERMITTED_STAT > 0:
            raise ValueError("The file {} permission for others is not 0, or is group writable.".format(path))
        if not os.access(real_path, os.W_OK):
            raise ValueError("The file {} exist and not writable.".format(path))
        if warn_exists:
            logger.warning("%s already exist. The original file will be overwritten.", path)
    return real_path


def format_frame(frame: dict[str, str]) -> str:
    name = frame.get("name", "unknown")
    filename = frame.get("filename", "unknown_file")
    line = frame.get("line", "unknown_line")
    return f"{name} at {filename}:{line}"


def format_frames(frames: list[dict[str, str]]) -> str:
    formatted_frames = []
    for frame in frames:
        formatted_frames.append(format_frame(frame))
    return "\n".join(formatted_frames)


def match_one_event(
    event_a: dict[Any, Any],
    event_b: dict[Any, Any],
    memberships: dict[str, set[Any]],
    pg_name: str,
) -> MatchInfo:
    op_a = Op(event_a, memberships, pg_name)
    op_b = Op(event_b, memberships, pg_name)
    return op_a.match(op_b)


def check_size_alltoall(alltoall_cases: list[dict[str, Any]]) -> tuple[bool, int, int]:
    input_numel = 0
    output_numel = 0
    for e in alltoall_cases:
        input_sizes = e.get("input_sizes", [])
        output_sizes = e.get("output_sizes", [])

        if input_sizes and len(input_sizes) > 0:
            input_numel += math.prod(input_sizes[0])

        if output_sizes and len(output_sizes) > 0:
            output_numel += math.prod(output_sizes[0])
    return input_numel != output_numel, input_numel, output_numel


class ProcessGroupData:
    def __init__(self, pg_guids: dict[tuple[str, int], str], pg_name: str, desc: str, mismatch: dict[str, int]):
        self.pg_guids, self.pg_name, self.desc, self.mismatch = pg_guids, pg_name, desc, mismatch


def check_current_entry_match(
    all_entries: dict[int, list[dict[str, Any]]],
    current_entry: dict[str, Any],
    _memberships: dict[str, set[Any]],
    pg_data: ProcessGroupData,
    match_record: MatchStateRecord,
) -> None:
    pg_guids, pg_name, mismatch, desc = pg_data.pg_guids, pg_data.pg_name, pg_data.mismatch, pg_data.desc
    for rank in match_record.expected_ranks.intersection(set(match_record.other_ranks)):
        for entry_idx, entry in enumerate(all_entries[rank]):
            # step over ops from other PGs
            # only check match state when seq_id matches
            if (
            pg_guids.get((entry.get("process_group", [None])[0], rank)) == pg_name
            and entry.get("collective_seq_id") == match_record.entry_state.collective_seq_id
            ):
                match_info = match_one_event(current_entry, entry, _memberships, pg_name)
                if match_info.state in [MatchState.FULLY_MATCHED, MatchState.UNDECIDED] and mismatch[pg_name] == 0:
                    match_record.found_ranks.add(rank)
                    match_record.found_idx[rank] = entry_idx
                    match_record.has_undecided_case = match_info.state == MatchState.UNDECIDED
                else:
                    match_record.candidate_ranks.add(rank)
                    match_record.candidate_idx[rank] = entry_idx
                    if match_info.state not in [
                        MatchState.FULLY_MATCHED,
                        MatchState.UNDECIDED,
                    ]:
                        match_record.errors.add((rank, match_info))
                break


class EntryContext:
    def __init__(self, all_entries, current_entry, dumps_ranks, first_rank):
        self.all_entries = all_entries
        self.current_entry = current_entry
        self.dumps_ranks = dumps_ranks
        self.first_rank = first_rank


def error_analysis(
    entry_context: EntryContext,
    match_record: MatchStateRecord, # all
    mismatch: dict[str, int],  # all
    version: tuple[int, int],  # 2
    pg_name: str,  # all, mismatch
) -> None:
    all_entries = entry_context.all_entries
    current_entry = entry_context.current_entry
    dumps_ranks = entry_context.dumps_ranks
    first_rank = entry_context.first_rank
    major_v, minor_v = version[0], version[1]
    # case one: not every rank join the collective or in the flight recorder.
    if (
        match_record.candidate_ranks | match_record.found_ranks
    ) != match_record.expected_ranks and match_record.expected_ranks - (
        match_record.candidate_ranks | match_record.found_ranks
    ) <= dumps_ranks:
        mismatch[pg_name] += 1
        logger_msg = "Not all ranks joining collective, sequence number: %s"
        missing_ranks = match_record.expected_ranks - (match_record.candidate_ranks | match_record.found_ranks)
        match_record.entry_state.log(
            logger, logger_msg, format_frames, additional_info={"missing_ranks": missing_ranks}
        )
        match_record.candidate_ranks.update(match_record.found_ranks)
        match_record.candidate_idx.update(match_record.found_idx)
        match_record.found_idx.clear()
        match_record.found_ranks.clear()
    elif len(match_record.candidate_ranks) == 1 and dumps_ranks == match_record.expected_ranks:
        # case two: alltoall or alltoall_base case.
        if match_record.has_undecided_case:
            alltoall_cases = [current_entry] + [
                all_entries[rank][match_record.found_idx[rank]] for rank in match_record.found_ranks
            ]
            fail_check, total_input_numel, total_output_numel = check_size_alltoall(alltoall_cases)
            if major_v <= 2 and minor_v <= 3:
                # We don't log the input/output sizes for alltoall before v2.4,
                # so we don't consider the size mismatch as an error for now.
                fail_check = False
            if fail_check:
                # When we see errors in all_to_all, it's hard to tell which rank is the source of the error.
                mismatch[pg_name] += 1
                logger_msg = "Input/output mismatch in the collective sequence number: %s"
                match_record.entry_state.log(
                    logger,
                    logger_msg,
                    format_frames,
                    additional_info={"total_numel": (total_input_numel, total_output_numel)},
                )
                match_record.candidate_ranks.update(match_record.found_ranks)
                match_record.candidate_idx.update(match_record.found_idx)
                match_record.found_idx.clear()
                match_record.found_ranks.clear()
                match_record.errors.add((first_rank, MatchInfo(MatchState.SIZE_OR_SYNTAX_MISMATCH)))
            else:
                match_record.found_ranks.update(match_record.candidate_ranks)
                match_record.found_idx.update(match_record.candidate_idx)
                match_record.candidate_idx.clear()
                match_record.candidate_ranks.clear()
        # case three: all joined and everything matches on all ranks.
        else:
            match_record.found_ranks.update(match_record.candidate_ranks)
            match_record.found_idx.update(match_record.candidate_idx)
            match_record.candidate_idx.clear()
            match_record.candidate_ranks.clear()
    # case four: mismatch cases due to not same type, size mismatch or state mismatch.
    elif len(match_record.errors) > 0:
        mismatch[pg_name] += 1
        logger_msg = "Collective sequence number: %s has errors"
        match_record.entry_state.log(logger, logger_msg, format_frames, errors=match_record.errors)
        match_record.candidate_ranks.update(match_record.found_ranks)
        match_record.candidate_idx.update(match_record.found_idx)
        match_record.found_idx.clear()
        match_record.found_ranks.clear()
    # partial analysis case when we cannot decide what's wrong with this collective entry.
    else:
        match_record.candidate_ranks.update(match_record.found_ranks)
        match_record.candidate_idx.update(match_record.found_idx)
        match_record.found_idx.clear()
        match_record.found_ranks.clear()
        if match_record.expected_ranks - dumps_ranks:
            mismatch[pg_name] += 1
            logger.info(
                "We cannot decide what's wrong with this collective entry "
                "because we missed FR dumps from ranks (%s) so we don't have enough "
                "information. If you want to debug further use -j to dump all raw trace",
                str(match_record.expected_ranks - dumps_ranks),
            )
        else:
            logger.info(
                "No errors found for this collective entry, There could be some "
                "other reasons why we see collective timeout."
            )


def just_print_entries(
    all_entries: dict[int, list[dict[str, Any]]],
    _groups: dict[str, Group],
    _memberships: dict[str, set[Any]],
    _pg_guids: dict[tuple[str, int], str],
    args: argparse.Namespace,
) -> None:
    rows = []
    ranks = sorted(all_entries.keys())
    headers = [f"Rank {rank}" for rank in ranks if args.selected_ranks is None or rank in args.selected_ranks]
    progress = True
    while progress:
        progress = False
        row = []
        for rank in ranks:
            if args.selected_ranks is not None and rank not in args.selected_ranks:
                continue
            if len(all_entries[rank]) == 0:
                row.append("")
            else:
                entry = all_entries[rank].pop(0)
                process_group = entry.get("process_group", [None, None])
                pg_name = _pg_guids.get((process_group[0], rank))

                if (
                    args.pg_filters is None
                    or process_group[1] in args.pg_filters
                    or process_group[0] in args.pg_filters
                ):
                    row.append(str(Op(entry, _memberships, pg_name)))
                else:
                    row.append("")
                progress = True
        if progress:
            rows.append(row)

    logger.info(tabulate(rows, headers=headers))


def check_no_missing_dump_files(entries: dict[int, Any], memberships: list[Membership]) -> None:
    try:
        all_ranks = {int(m.global_rank) for m in memberships}
    except (ValueError, TypeError) as e:
        raise ValueError(f"Cannot extract rank from memberships. Invalid global_rank value encountered: {e}") from e

    try:
        dumps_ranks = {int(key) for key in entries.keys()}
    except (ValueError, TypeError) as e:
        raise ValueError(f"Cannot extract rank from entries keys. Invalid key value encountered: {e}") from e

    missing_ranks = all_ranks - dumps_ranks
    if missing_ranks:
        raise ValueError(
            f"Missing dump files for {len(missing_ranks)} ranks: {sorted(missing_ranks)}\n"
            f"Expected ranks: {sorted(all_ranks)}\n"
            f"Found dumps for: {sorted(dumps_ranks)}"
        )


def check_version(version_by_ranks: dict[str, str], expected_version: str) -> None:
    for rank, actual_version in version_by_ranks.items():
        if actual_version != expected_version:
            raise ValueError(f"Version mismatch at rank {rank}: " f"expected {expected_version}, got {actual_version}")


def get_version_detail(version_str: str) -> tuple[int, int]:
    parts = version_str.split(".")
    if len(parts) != 2:
        raise ValueError(f"Invalid version format: expected 'X.Y', got '{version_str}'")

    try:
        major, minor = int(parts[0]), int(parts[1])
    except ValueError as e:
        raise ValueError(f"Version components must be integers: '{version_str}'") from e

    return major, minor


def align_trace_from_beginning(
    entries: dict[int, list[dict[str, Any]]],
) -> dict[int, list[dict[str, Any]]]:
    """
    Align the trace entries by record ID for entries.
    This function takes a dictionary of rank names to lists of trace entries as input.
    Each trace entry is a dictionary containing information about a collective operation,
    including its unique identifier (`record_id` is monotonically increasing as we write into the ring buffer).
    The function finds the largest starting point across all ranks by taking the maximum
    `record_id` value of the first entry in each rank. Finally, it filters out any
    entries with `record_id` values less than the maximum starting point.
    The function returns the updated dictionary of sorted and filtered trace entries.

    Args:
        entries (Dict[str, List[Dict[str, Any]]]): A dictionary of rank names to lists of trace entries.

    Returns:
        entries (Dict[str, List[Dict[str, Any]]]): Entries sorted by record ID and filtered by the maximum starting point.
    """

    maximum_starting_record_id = 0
    for rank in entries:
        # Although this is a ring buffer, we already sort the entries by `record_id` when dumping, we just
        # need to find the largest starting point. For example, if the buffer has the following entries:
        # Rank 0: [0, 1, 2, 3, 4, 5, 6]
        # Rank 1: [1, 2, 3, 4, 5, 6, 7]
        # Rank 2: [2, 3, 4, 5, 6, 7, 8]
        # Rank 3: [0, 1, 2, 3, 4, 5, None]
        # Then we should start from collective 2 not 0 because any collective before,
        # we don't have complete records from all ranks so we need to ignore them.
        first_record_id = entries[rank][0].get("record_id", -1)
        maximum_starting_record_id = max(maximum_starting_record_id, first_record_id)

    for rank in entries:
        entries[rank] = [entry for entry in entries[rank] if entry.get("record_id", -1) >= maximum_starting_record_id]

    return entries