--- a/weight_utils.py
+++ b/weight_utils.py
@@ -8,6 +8,7 @@ import fnmatch
 import glob
 import hashlib
 import json
+import random
 import os
 import tempfile
 import threading
@@ -898,6 +899,7 @@ def safetensors_weights_iterator(
     *,
     safetensors_prefetch_num_threads: int = DEFAULT_SAFETENSORS_PREFETCH_NUM_THREADS,
     safetensors_prefetch_block_size: int = DEFAULT_SAFETENSORS_PREFETCH_BLOCK_SIZE,
+    shuffle_safetensors_files: bool = False,
 ) -> Generator[tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model safetensor files.

@@ -911,6 +913,13 @@ def safetensors_weights_iterator(

     sorted_files = sorted(hf_weights_files, key=_natural_sort_key)

+    if shuffle_safetensors_files:
+        if torch.distributed.is_initialized():
+            rank = torch.distributed.get_rank()
+        else:
+            rank = 0
+        rng = random.Random(42 + rank)
+        rng.shuffle(sorted_files)
     fs_type = _get_fs_type(sorted_files)
     is_net_fs = fs_type in ("nfs", "nfs4", "lustre")
     total_bytes = _get_checkpoints_size_bytes(sorted_files)