"""
Thread-safe configuration loading utilities.
This module provides utilities for safely loading HuggingFace model configurations
in multi-threaded environments, preventing race conditions that can occur when
multiple threads try to download and cache the same model simultaneously.
"""
import hashlib
import os
import time
from pathlib import Path
from typing import Union
import filelock
from transformers import AutoConfig
from transformers.configuration_utils import PretrainedConfig
# Substrings (matched against the lower-cased error text) that mark a failure
# as permanent: retrying cannot fix a missing config file / repository or an
# HTTP 401/403 response.
_NON_RETRYABLE_ERROR_PHRASES = (
    "does not appear to have a file named config.json",
    "repository not found",
    "entry not found",
    "401 client error",
    "403 client error",
)


def _config_lock_path(path: Union[str, Path]) -> Path:
    """Return the base path of the lock file used to serialize loads of ``path``."""
    # MD5 is used only to derive a stable, filesystem-safe name from the path
    # string; it is not used for any security purpose.
    path_hash = hashlib.md5(str(path).encode()).hexdigest()
    lock_dir = os.getenv("MEGATRON_CONFIG_LOCK_DIR")
    if lock_dir:
        return Path(lock_dir) / f".megatron_config_lock_{path_hash}"
    return Path.home() / ".cache" / "huggingface" / f".megatron_config_lock_{path_hash}"


def safe_load_config_with_retry(
    path: Union[str, Path], trust_remote_code: bool = False, max_retries: int = 3, base_delay: float = 1.0, **kwargs
) -> PretrainedConfig:
    """
    Load a model configuration under a per-path file lock, retrying on failure.

    Each attempt acquires a file lock whose name is derived from ``path`` and
    then calls ``AutoConfig.from_pretrained``. The lock file lives in the
    directory named by the ``MEGATRON_CONFIG_LOCK_DIR`` environment variable
    when it is set, otherwise in ``~/.cache/huggingface``. The directory is
    created if it does not exist.

    Args:
        path: Model identifier or filesystem path handed to
            ``AutoConfig.from_pretrained``.
        trust_remote_code: Forwarded to ``AutoConfig.from_pretrained``.
        max_retries: Number of additional attempts after the first one fails.
            Negative values are treated as ``0`` so that one attempt is always
            made.
        base_delay: Base back-off in seconds. The wait before retry ``n``
            (0-based) is ``base_delay * 2**n`` plus up to 0.1 s of jitter.
        **kwargs: Extra keyword arguments forwarded to
            ``AutoConfig.from_pretrained``.

    Returns:
        The configuration object returned by ``AutoConfig.from_pretrained``.

    Raises:
        ValueError: Immediately, if the error text matches one of the known
            non-retryable phrases (missing config/repository, 401/403), or
            after all attempts have failed. The original exception is chained
            as ``__cause__``.
    """
    # Always make at least one attempt. With a negative ``max_retries`` the
    # loop would otherwise never run and the function would fail with
    # "Last error: None" without having tried to load anything.
    total_attempts = max(0, max_retries) + 1
    last_exception = None

    for attempt in range(total_attempts):
        try:
            # Resolved inside the ``try`` so that a failure to determine or
            # create the lock directory is handled like any other load error.
            lock_file = _config_lock_path(path)
            lock_file.parent.mkdir(parents=True, exist_ok=True)
            # The lock is released when the ``with`` block exits, so it is not
            # held while this caller sleeps between retries.
            with filelock.FileLock(str(lock_file) + ".lock", timeout=60):
                return AutoConfig.from_pretrained(path, trust_remote_code=trust_remote_code, **kwargs)
        except Exception as e:
            # Deliberately broad: lock-acquisition timeouts and loader errors
            # are both candidates for a retry unless classified below.
            last_exception = e
            error_msg = str(e).lower()
            if any(phrase in error_msg for phrase in _NON_RETRYABLE_ERROR_PHRASES):
                raise ValueError(
                    f"Failed to load configuration from {path}. "
                    f"Ensure the path is valid and contains a config.json file. "
                    f"Error: {e}"
                ) from e
            if attempt < total_attempts - 1:
                # Exponential back-off; the sub-second part of the wall clock
                # adds up to 0.1 s of jitter so callers do not retry in
                # lockstep.
                delay = base_delay * (2**attempt) + (time.time() % 1) * 0.1
                time.sleep(delay)

    raise ValueError(
        f"Failed to load configuration from {path} after {total_attempts} attempts. "
        f"This might be due to network issues or concurrent access conflicts. "
        f"Last error: {last_exception}"
    ) from last_exception