@@ -172,14 +172,14 @@ py::array build_sample_idx(const py::array_t<int32_t> &sizes_,
{
num_samples = ceil(float(num_epochs * tokens_per_epoch - add_extra_token_to_sequence) / seq_length);
}
- int32_t *sample_idx = new int32_t[2 * (num_samples + 1)];
+ int64_t *sample_idx = new int64_t[2 * (num_samples + 1)];
// Index into sample_idx.
int64_t sample_index = 0;
// Index into doc_idx.
int64_t doc_idx_index = 0;
// Begining offset for each document.
- int32_t doc_offset = 0;
+ int64_t doc_offset = 0;
// Start with first document and no offset.
sample_idx[2 * sample_index] = doc_idx_index;
sample_idx[2 * sample_index + 1] = doc_offset;
@@ -232,7 +232,7 @@ py::array build_sample_idx(const py::array_t<int32_t> &sizes_,
delete[] mem; });
// Return the numpy array.
- const auto byte_size = sizeof(int32_t);
+ const auto byte_size = sizeof(int64_t);
return py::array(std::vector<int64_t>{num_samples + 1, 2}, // shape
{2 * byte_size, byte_size}, // C-style contiguous strides
sample_idx, // the data pointer