diff --git a/megatron/core/datasets/gpt_dataset.py b/megatron/core/datasets/gpt_dataset.py
index a645f89..7de00b7 100644
--- a/megatron/core/datasets/gpt_dataset.py
+++ b/megatron/core/datasets/gpt_dataset.py
@@ -340,9 +340,11 @@ class GPTDataset(MegatronDataset):
         else:
             cache_hit = False
 
+        from megatron.training import get_args
+        args = get_args()
         if not path_to_cache or (
             not cache_hit
-            and (not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0)
+            and (not torch.distributed.is_initialized() or torch.distributed.get_rank() % args.tensor_model_parallel_size == 0)
         ):
 
             log_single_rank(