diff --git a/megatron/core/datasets/helpers.cpp b/megatron/core/datasets/helpers.cpp
index 71299996..ee2bc103 100644
--- a/megatron/core/datasets/helpers.cpp
+++ b/megatron/core/datasets/helpers.cpp
@@ -172,14 +172,14 @@ py::array build_sample_idx(const py::array_t<int32_t> &sizes_,
   {
     num_samples = ceil(float(num_epochs * tokens_per_epoch - add_extra_token_to_sequence) / seq_length);
   }
-  int32_t *sample_idx = new int32_t[2 * (num_samples + 1)];
+  int64_t *sample_idx = new int64_t[2 * (num_samples + 1)];
 
   // Index into sample_idx.
   int64_t sample_index = 0;
   // Index into doc_idx.
   int64_t doc_idx_index = 0;
   // Begining offset for each document.
-  int32_t doc_offset = 0;
+  int64_t doc_offset = 0;
   // Start with first document and no offset.
   sample_idx[2 * sample_index] = doc_idx_index;
   sample_idx[2 * sample_index + 1] = doc_offset;
@@ -232,7 +232,7 @@ py::array build_sample_idx(const py::array_t<int32_t> &sizes_,
 	delete[] mem; });
 
   // Return the numpy array.
-  const auto byte_size = sizeof(int32_t);
+  const auto byte_size = sizeof(int64_t);
   return py::array(std::vector<int64_t>{num_samples + 1, 2}, // shape
                    {2 * byte_size, byte_size},               // C-style contiguous strides
                    sample_idx,                               // the data pointer