@@ -8,6 +8,7 @@ import fnmatch
import glob
import hashlib
import json
+import random
import os
import tempfile
import threading
@@ -898,6 +899,7 @@ def safetensors_weights_iterator(
*,
safetensors_prefetch_num_threads: int = DEFAULT_SAFETENSORS_PREFETCH_NUM_THREADS,
safetensors_prefetch_block_size: int = DEFAULT_SAFETENSORS_PREFETCH_BLOCK_SIZE,
+ shuffle_safetensors_files: bool = False,
) -> Generator[tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model safetensor files.
@@ -911,6 +913,13 @@ def safetensors_weights_iterator(
sorted_files = sorted(hf_weights_files, key=_natural_sort_key)
+ if shuffle_safetensors_files:
+ if torch.distributed.is_initialized():
+ rank = torch.distributed.get_rank()
+ else:
+ rank = 0
+ rng = random.Random(42 + rank)
+ rng.shuffle(sorted_files)
fs_type = _get_fs_type(sorted_files)
is_net_fs = fs_type in ("nfs", "nfs4", "lustre")
total_bytes = _get_checkpoints_size_bytes(sorted_files)