__all__ = []
import argparse
import math
from typing import Any
import os
import re
import sys
import stat
from tools.flight_recorder.components.fr_logger import FlightRecorderLogger
from tools.flight_recorder.components.types import (
Group,
MatchInfo,
MatchState,
MatchStateRecord,
Membership,
Op,
)
logger: FlightRecorderLogger = FlightRecorderLogger()
try:
from tabulate import tabulate
except ModuleNotFoundError:
logger.debug("tabulate is not installed. Proceeding without it.")
PATH_WHITE_LIST_REGEX = re.compile(r"[^_A-Za-z0-9/.-]")
MAX_READ_FILE_SIZE_4G = 4294967296
MAX_READ_FILE_SIZE_32G = 34359738368
MAX_READ_FILE_SIZE_512G = 549755813888
WRITE_FILE_NOT_PERMITTED_STAT = stat.S_IWGRP | stat.S_IWOTH | stat.S_IROTH | stat.S_IXOTH
READ_FILE_NOT_PERMITTED_STAT = stat.S_IWGRP | stat.S_IWOTH
def type_to_str(value_type):
return " or ".join(ii.__name__ for ii in value_type) if isinstance(value_type, tuple) else value_type.__name__
def check_type(value, value_type, param_name="value"):
if not isinstance(value, value_type):
raise TypeError("{} must be {}, not {}.".format(param_name, type_to_str(value_type), type(value).__name__))
def get_valid_path(path):
check_type(path, str, "path")
if not path or len(path) == 0:
raise ValueError("The value of the path cannot be empty.")
if PATH_WHITE_LIST_REGEX.search(path):
raise ValueError("Input path contains invalid characters.")
path = os.path.expanduser(path)
if os.path.islink(os.path.abspath(path)):
raise ValueError("The value of the path cannot be a symbolic link: {}.".format(path))
real_path = os.path.realpath(path)
if len(real_path) > 4096:
raise ValueError("The length of file path should be less than 4096.")
if real_path != path and PATH_WHITE_LIST_REGEX.search(real_path):
raise ValueError("Input path contains invalid characters.")
return real_path
def is_belong_to_user_or_group(file_stat):
return file_stat.st_uid == os.getuid() or file_stat.st_gid in os.getgroups()
def get_valid_read_path(path, size_max=MAX_READ_FILE_SIZE_4G, check_user_stat=True, is_dir=False):
real_path = get_valid_path(path)
if not is_dir and not os.path.isfile(real_path):
raise ValueError("The path {} doesn't exists or not a file.".format(path))
if is_dir and not os.path.isdir(real_path):
raise ValueError("The path {} doesn't exists or not a directory.".format(path))
file_stat = os.stat(real_path)
if check_user_stat and not sys.platform.startswith("win") and not is_belong_to_user_or_group(file_stat):
raise ValueError("The file {} doesn't belong to the current user or group.".format(path))
if check_user_stat and os.stat(path).st_mode & READ_FILE_NOT_PERMITTED_STAT > 0:
raise ValueError("The file {} is group writable, or is others writable.".format(path))
if not os.access(real_path, os.R_OK) or file_stat.st_mode & stat.S_IRUSR == 0:
raise ValueError("Current user doesn't have read permission to the file {}.".format(path))
if not is_dir and size_max > 0 and file_stat.st_size > size_max:
raise ValueError("The file {} exceeds size limitation of {}.".format(path, size_max))
return real_path
def check_write_directory(dir_name, check_user_stat=True):
real_dir_name = get_valid_path(dir_name)
if not os.path.isdir(real_dir_name):
raise ValueError("The file writen directory {} doesn't exists.".format(dir_name))
file_stat = os.stat(real_dir_name)
if check_user_stat and not sys.platform.startswith("win") and not is_belong_to_user_or_group(file_stat):
raise ValueError("The file writen directory {} doesn't belong to the current user or group.".format(dir_name))
if not os.access(real_dir_name, os.W_OK):
raise ValueError("Current user doesn't have writen permission to file writen directory {}.".format(dir_name))
def get_valid_write_path(path, check_user_stat=True, is_dir=False, warn_exists=True):
real_path = get_valid_path(path)
real_path_dir = real_path if is_dir else os.path.dirname(real_path)
check_write_directory(real_path_dir, check_user_stat=check_user_stat)
if not is_dir and os.path.exists(real_path):
if os.path.isdir(real_path):
raise ValueError("The file {} exist and is a directory.".format(path))
if check_user_stat and os.stat(real_path).st_uid != os.getuid():
raise ValueError("The file {} doesn't belong to the current user.".format(path))
if check_user_stat and os.stat(real_path).st_mode & WRITE_FILE_NOT_PERMITTED_STAT > 0:
raise ValueError("The file {} permission for others is not 0, or is group writable.".format(path))
if not os.access(real_path, os.W_OK):
raise ValueError("The file {} exist and not writable.".format(path))
if warn_exists:
logger.warning("%s already exist. The original file will be overwritten.", path)
return real_path
def format_frame(frame: dict[str, str]) -> str:
name = frame.get("name", "unknown")
filename = frame.get("filename", "unknown_file")
line = frame.get("line", "unknown_line")
return f"{name} at {filename}:{line}"
def format_frames(frames: list[dict[str, str]]) -> str:
formatted_frames = []
for frame in frames:
formatted_frames.append(format_frame(frame))
return "\n".join(formatted_frames)
def match_one_event(
event_a: dict[Any, Any],
event_b: dict[Any, Any],
memberships: dict[str, set[Any]],
pg_name: str,
) -> MatchInfo:
op_a = Op(event_a, memberships, pg_name)
op_b = Op(event_b, memberships, pg_name)
return op_a.match(op_b)
def check_size_alltoall(alltoall_cases: list[dict[str, Any]]) -> tuple[bool, int, int]:
input_numel = 0
output_numel = 0
for e in alltoall_cases:
input_sizes = e.get("input_sizes", [])
output_sizes = e.get("output_sizes", [])
if input_sizes and len(input_sizes) > 0:
input_numel += math.prod(input_sizes[0])
if output_sizes and len(output_sizes) > 0:
output_numel += math.prod(output_sizes[0])
return input_numel != output_numel, input_numel, output_numel
class ProcessGroupData:
def __init__(self, pg_guids: dict[tuple[str, int], str], pg_name: str, desc: str, mismatch: dict[str, int]):
self.pg_guids, self.pg_name, self.desc, self.mismatch = pg_guids, pg_name, desc, mismatch
def check_current_entry_match(
all_entries: dict[int, list[dict[str, Any]]],
current_entry: dict[str, Any],
_memberships: dict[str, set[Any]],
pg_data: ProcessGroupData,
match_record: MatchStateRecord,
) -> None:
pg_guids, pg_name, mismatch, desc = pg_data.pg_guids, pg_data.pg_name, pg_data.mismatch, pg_data.desc
for rank in match_record.expected_ranks.intersection(set(match_record.other_ranks)):
for entry_idx, entry in enumerate(all_entries[rank]):
if (
pg_guids.get((entry.get("process_group", [None])[0], rank)) == pg_name
and entry.get("collective_seq_id") == match_record.entry_state.collective_seq_id
):
match_info = match_one_event(current_entry, entry, _memberships, pg_name)
if match_info.state in [MatchState.FULLY_MATCHED, MatchState.UNDECIDED] and mismatch[pg_name] == 0:
match_record.found_ranks.add(rank)
match_record.found_idx[rank] = entry_idx
match_record.has_undecided_case = match_info.state == MatchState.UNDECIDED
else:
match_record.candidate_ranks.add(rank)
match_record.candidate_idx[rank] = entry_idx
if match_info.state not in [
MatchState.FULLY_MATCHED,
MatchState.UNDECIDED,
]:
match_record.errors.add((rank, match_info))
break
class EntryContext:
def __init__(self, all_entries, current_entry, dumps_ranks, first_rank):
self.all_entries = all_entries
self.current_entry = current_entry
self.dumps_ranks = dumps_ranks
self.first_rank = first_rank
def error_analysis(
entry_context: EntryContext,
match_record: MatchStateRecord,
mismatch: dict[str, int],
version: tuple[int, int],
pg_name: str,
) -> None:
all_entries = entry_context.all_entries
current_entry = entry_context.current_entry
dumps_ranks = entry_context.dumps_ranks
first_rank = entry_context.first_rank
major_v, minor_v = version[0], version[1]
if (
match_record.candidate_ranks | match_record.found_ranks
) != match_record.expected_ranks and match_record.expected_ranks - (
match_record.candidate_ranks | match_record.found_ranks
) <= dumps_ranks:
mismatch[pg_name] += 1
logger_msg = "Not all ranks joining collective, sequence number: %s"
missing_ranks = match_record.expected_ranks - (match_record.candidate_ranks | match_record.found_ranks)
match_record.entry_state.log(
logger, logger_msg, format_frames, additional_info={"missing_ranks": missing_ranks}
)
match_record.candidate_ranks.update(match_record.found_ranks)
match_record.candidate_idx.update(match_record.found_idx)
match_record.found_idx.clear()
match_record.found_ranks.clear()
elif len(match_record.candidate_ranks) == 1 and dumps_ranks == match_record.expected_ranks:
if match_record.has_undecided_case:
alltoall_cases = [current_entry] + [
all_entries[rank][match_record.found_idx[rank]] for rank in match_record.found_ranks
]
fail_check, total_input_numel, total_output_numel = check_size_alltoall(alltoall_cases)
if major_v <= 2 and minor_v <= 3:
fail_check = False
if fail_check:
mismatch[pg_name] += 1
logger_msg = "Input/output mismatch in the collective sequence number: %s"
match_record.entry_state.log(
logger,
logger_msg,
format_frames,
additional_info={"total_numel": (total_input_numel, total_output_numel)},
)
match_record.candidate_ranks.update(match_record.found_ranks)
match_record.candidate_idx.update(match_record.found_idx)
match_record.found_idx.clear()
match_record.found_ranks.clear()
match_record.errors.add((first_rank, MatchInfo(MatchState.SIZE_OR_SYNTAX_MISMATCH)))
else:
match_record.found_ranks.update(match_record.candidate_ranks)
match_record.found_idx.update(match_record.candidate_idx)
match_record.candidate_idx.clear()
match_record.candidate_ranks.clear()
else:
match_record.found_ranks.update(match_record.candidate_ranks)
match_record.found_idx.update(match_record.candidate_idx)
match_record.candidate_idx.clear()
match_record.candidate_ranks.clear()
elif len(match_record.errors) > 0:
mismatch[pg_name] += 1
logger_msg = "Collective sequence number: %s has errors"
match_record.entry_state.log(logger, logger_msg, format_frames, errors=match_record.errors)
match_record.candidate_ranks.update(match_record.found_ranks)
match_record.candidate_idx.update(match_record.found_idx)
match_record.found_idx.clear()
match_record.found_ranks.clear()
else:
match_record.candidate_ranks.update(match_record.found_ranks)
match_record.candidate_idx.update(match_record.found_idx)
match_record.found_idx.clear()
match_record.found_ranks.clear()
if match_record.expected_ranks - dumps_ranks:
mismatch[pg_name] += 1
logger.info(
"We cannot decide what's wrong with this collective entry "
"because we missed FR dumps from ranks (%s) so we don't have enough "
"information. If you want to debug further use -j to dump all raw trace",
str(match_record.expected_ranks - dumps_ranks),
)
else:
logger.info(
"No errors found for this collective entry, There could be some "
"other reasons why we see collective timeout."
)
def just_print_entries(
all_entries: dict[int, list[dict[str, Any]]],
_groups: dict[str, Group],
_memberships: dict[str, set[Any]],
_pg_guids: dict[tuple[str, int], str],
args: argparse.Namespace,
) -> None:
rows = []
ranks = sorted(all_entries.keys())
headers = [f"Rank {rank}" for rank in ranks if args.selected_ranks is None or rank in args.selected_ranks]
progress = True
while progress:
progress = False
row = []
for rank in ranks:
if args.selected_ranks is not None and rank not in args.selected_ranks:
continue
if len(all_entries[rank]) == 0:
row.append("")
else:
entry = all_entries[rank].pop(0)
process_group = entry.get("process_group", [None, None])
pg_name = _pg_guids.get((process_group[0], rank))
if (
args.pg_filters is None
or process_group[1] in args.pg_filters
or process_group[0] in args.pg_filters
):
row.append(str(Op(entry, _memberships, pg_name)))
else:
row.append("")
progress = True
if progress:
rows.append(row)
logger.info(tabulate(rows, headers=headers))
def check_no_missing_dump_files(entries: dict[int, Any], memberships: list[Membership]) -> None:
try:
all_ranks = {int(m.global_rank) for m in memberships}
except (ValueError, TypeError) as e:
raise ValueError(f"Cannot extract rank from memberships. Invalid global_rank value encountered: {e}") from e
try:
dumps_ranks = {int(key) for key in entries.keys()}
except (ValueError, TypeError) as e:
raise ValueError(f"Cannot extract rank from entries keys. Invalid key value encountered: {e}") from e
missing_ranks = all_ranks - dumps_ranks
if missing_ranks:
raise ValueError(
f"Missing dump files for {len(missing_ranks)} ranks: {sorted(missing_ranks)}\n"
f"Expected ranks: {sorted(all_ranks)}\n"
f"Found dumps for: {sorted(dumps_ranks)}"
)
def check_version(version_by_ranks: dict[str, str], expected_version: str) -> None:
for rank, actual_version in version_by_ranks.items():
if actual_version != expected_version:
raise ValueError(f"Version mismatch at rank {rank}: " f"expected {expected_version}, got {actual_version}")
def get_version_detail(version_str: str) -> tuple[int, int]:
parts = version_str.split(".")
if len(parts) != 2:
raise ValueError(f"Invalid version format: expected 'X.Y', got '{version_str}'")
try:
major, minor = int(parts[0]), int(parts[1])
except ValueError as e:
raise ValueError(f"Version components must be integers: '{version_str}'") from e
return major, minor
def align_trace_from_beginning(
entries: dict[int, list[dict[str, Any]]],
) -> dict[int, list[dict[str, Any]]]:
"""
Align the trace entries by record ID for entries.
This function takes a dictionary of rank names to lists of trace entries as input.
Each trace entry is a dictionary containing information about a collective operation,
including its unique identifier (`record_id` is monotonically increasing as we write into the ring buffer).
The function finds the largest starting point across all ranks by taking the maximum
`record_id` value of the first entry in each rank. Finally, it filters out any
entries with `record_id` values less than the maximum starting point.
The function returns the updated dictionary of sorted and filtered trace entries.
Args:
entries (Dict[str, List[Dict[str, Any]]]): A dictionary of rank names to lists of trace entries.
Returns:
entries (Dict[str, List[Dict[str, Any]]]): Entries sorted by record ID and filtered by the maximum starting point.
"""
maximum_starting_record_id = 0
for rank in entries:
first_record_id = entries[rank][0].get("record_id", -1)
maximum_starting_record_id = max(maximum_starting_record_id, first_record_id)
for rank in entries:
entries[rank] = [entry for entry in entries[rank] if entry.get("record_id", -1) >= maximum_starting_record_id]
return entries