import os
import random
import time
import logging
from collections import Counter
import numpy as np
import torch
import torch.nn.functional as F
from megatron.training import print_rank_0, get_args
from megatron.core import parallel_state, mpu
from megatron.legacy.data.dataset_utils import get_train_valid_test_split_
from mindspeed_llm.training.tokenizer import build_tokenizer
from mindspeed_llm.tasks.utils.error_utils import check_equal
from mindspeed_llm.tasks.preprocess.mtf_dataset import MTFDataset, get_packed_indexed_dataset
from mindspeed_llm.tasks.preprocess.templates import get_model_template
from mindspeed_llm.training.utils import compute_actual_seq_len
logger = logging.getLogger(__name__)
def build_train_valid_test_datasets(
        data_prefix,
        splits_string,
        seq_length: int,
        train_valid_test_num_samples,
        seed,
):
    """Build train, valid, and test datasets from the first data prefix.

    A tokenizer is built from the global args only to read its pad and EOS
    token ids. Only ``data_prefix[0]`` is used; further prefixes are ignored.

    Returns:
        tuple ``(train, valid, test)`` as produced by
        ``_build_train_valid_test_datasets``.
    """
    tokenizer = build_tokenizer(get_args())
    pad_id, eos_id = tokenizer.pad, tokenizer.eos
    train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
        data_prefix=data_prefix[0],
        splits_string=splits_string,
        seq_length=seq_length,
        pad_token=pad_id,
        eos_token=eos_id,
        train_valid_test_num_samples=train_valid_test_num_samples,
        seed=seed,
    )
    return train_ds, valid_ds, test_ds
def _build_train_valid_test_datasets(
        data_prefix,
        splits_string,
        seq_length: int,
        pad_token: int,
        eos_token: int,
        train_valid_test_num_samples,
        seed,
):
    """Split one packed dataset into train/valid/test ``DecoderPackedMTFDataset`` objects.

    Documents are partitioned into contiguous index ranges computed from
    ``splits_string``. A split whose range is empty yields ``None``.

    Returns:
        tuple ``(train, valid, test)``; each entry is a dataset or ``None``.
    """
    packed_indexed_dataset = get_packed_indexed_dataset(data_prefix=data_prefix)
    # The document count is taken from the first field of the packed dataset;
    # presumably all fields hold the same number of documents.
    first_field = [*packed_indexed_dataset.values()][0]
    splits = get_train_valid_test_split_(splits_string, len(first_field))

    print_rank_0(' > dataset split:')
    for split_pos, split_label in enumerate(('train', 'validation', 'test')):
        begin, end = splits[split_pos], splits[split_pos + 1]
        print_rank_0(' {}:'.format(split_label))
        print_rank_0(' document indices in [{}, {}) total of {} '
                     'documents'.format(begin, end, end - begin))

    def make_split(split_pos, split_name):
        # Empty ranges produce no dataset at all.
        begin, end = splits[split_pos], splits[split_pos + 1]
        if not end > begin:
            return None
        return DecoderPackedMTFDataset(
            name=split_name,
            data_prefix=data_prefix,
            documents=np.arange(start=begin, stop=end, step=1, dtype=np.int32),
            seq_length=seq_length,
            pad_token=pad_token,
            eos_token=eos_token,
            num_samples=train_valid_test_num_samples[split_pos],
            seed=seed,
        )

    return tuple(make_split(pos, label) for pos, label in enumerate(('train', 'valid', 'test')))
class DecoderPackedMTFDataset(torch.utils.data.Dataset):
    """Map-style dataset serving samples of a packed ``MTFDataset``.

    Samples are read in the order given by the index that
    ``_build_index_mappings`` builds (or loads from its cache file). Depending
    on ``args``, ``__getitem__`` returns one of four things:

    - pairwise (chosen/rejected) tokens;
    - a packed sample with reset position ids and ``actual_seq_len``
      boundaries;
    - a plainly truncated sample;
    - a turn-aware truncated instruction sample.
    """

    def __init__(
            self,
            name,
            data_prefix,
            documents,
            num_samples,
            seq_length: int,
            pad_token: int,
            eos_token: int,
            seed,
    ):
        self.args = get_args()
        self.mtf_dataset = MTFDataset(name=name, data_prefix=data_prefix, documents=documents)
        self.pad_token = pad_token
        self.seq_length = seq_length
        self.eos_token = eos_token
        self.shuffle_index = _build_index_mappings(name=name,
                                                   data_prefix=data_prefix,
                                                   start_index=documents[0],
                                                   nb_documents=len(documents),
                                                   mtf_dataset=self.mtf_dataset,
                                                   num_samples=num_samples,
                                                   seq_length=seq_length,
                                                   seed=seed,
                                                   shuffle=not self.args.no_shuffle)
        # Bookkeeping for the optional document-id logging in __getitem__.
        self.cur_batch_index = []
        self.iteration = 1

    def __len__(self):
        """Return the number of entries in the sample index."""
        return len(self.shuffle_index)

    @staticmethod
    def _get_neat_pack_position_ids(data: torch.Tensor):
        """Build position ids from a neat-pack attention mask.

        Each non-zero value in ``data`` labels one packed sub-sequence. The
        result holds ``0..count-1`` for every label in ascending label order,
        followed by one 0 for every 0 entry in ``data``.

        This presumes each label's tokens are contiguous and laid out in
        ascending label order, with zeros (padding) at the end. TODO: confirm
        this against the packing code.
        """
        counter = Counter(data.tolist())
        zeros_nums = counter.pop(0, 0)
        position_ids = []
        for key in sorted(counter):
            position_ids.extend(range(counter[key]))
        position_ids.extend([0] * zeros_nums)
        return torch.tensor(position_ids, dtype=torch.long, device=data.device)

    def __getitem__(self, idx):
        """Return the sample stored at position ``idx`` of the sample index.

        The returned dict depends on ``args``:

        - pairwise datasets: the result of ``_cut_pairwise_token``.
        - ``reset_attention_mask``: ``input_ids``, ``attention_mask``,
          ``labels`` and ``position_ids`` (all passed through ``_cut_token``),
          plus ``actual_seq_len``. Without MTP, ``actual_seq_len`` holds the
          sub-sequence end offsets right-padded with -1 to the sample length.
          With ``mtp_num_layers`` set, it is a ``(mtp_num_layers + 1, n)``
          tensor: the offsets plus one shifted copy per MTP depth.
        - ``cut_max_seqlen``: ``input_ids``, ``attention_mask`` and ``labels``
          passed through ``_cut_token``.
        - otherwise: the result of ``_cut_instruction_token``.
        """
        doc_idx = self.shuffle_index[idx]

        if (self.args.return_document_ids
                and mpu.get_tensor_model_parallel_rank() == 0
                and mpu.get_pipeline_model_parallel_rank() == 0
                and mpu.get_context_parallel_rank() == 0):
            self.cur_batch_index.append(doc_idx)
            # Integer division keeps this an int-to-int comparison.
            if len(self.cur_batch_index) == self.args.global_batch_size // self.args.data_parallel_size:
                print("current iteration: {}, current rank:{}, data_parallel_rank:{}, document_ids:{}".format(
                    self.iteration, torch.distributed.get_rank(), mpu.get_data_parallel_rank(), self.cur_batch_index))
                self.cur_batch_index = []
                self.iteration += 1

        item = self.mtf_dataset[doc_idx]

        if self.args.is_pairwise_dataset:
            return self._cut_pairwise_token(item, np.int64)

        data = torch.from_numpy(item['input_ids'])
        seq_length = data.numel()
        # A freshly allocated tensor, so it may be modified in place below.
        position_ids = torch.arange(self.args.seq_length, dtype=torch.long, device=data.device)

        if self.args.reset_attention_mask:
            # Positions of the EOD tokens in the real sample. Boolean-mask
            # indexing returns a new tensor, so the in-place updates of
            # position_ids below cannot alter it.
            eod_index = position_ids[:seq_length][data == self.eos_token]
            if self.args.reset_position_ids:
                # Restart position numbering right after every EOD token.
                prev_index = 0
                for i in eod_index.tolist():
                    position_ids[(i + 1):] -= i + 1 - prev_index
                    prev_index = i + 1

            if self.args.neat_pack:
                position_ids = self._get_neat_pack_position_ids(torch.from_numpy(item['attention_mask']))
                actual_seq_len = torch.tensor(compute_actual_seq_len(position_ids))
                if not self.args.reset_position_ids:
                    position_ids = torch.arange(self.args.seq_length, dtype=torch.long, device=data.device)
            else:
                # End offset of each EOD-terminated sub-sequence.
                actual_seq_len = eod_index + 1

            if self.args.mtp_num_layers:
                # Depth d shifts every boundary left by d, except boundaries
                # that fall on a multiple of the sample length.
                base_seq_len = actual_seq_len.tolist()
                mtp_res = [base_seq_len]
                for depth in range(1, self.args.mtp_num_layers + 1):
                    mtp_res.append([j if j % seq_length == 0 else j - depth for j in base_seq_len])
                seq_len_field = torch.tensor(mtp_res)
            else:
                # Pad while actual_seq_len is still a tensor: F.pad rejects
                # lists, which previously broke the MTP path.
                pad_length = seq_length - len(actual_seq_len)
                seq_len_field = F.pad(actual_seq_len, (0, pad_length), value=-1)

            return {
                "input_ids": self._cut_token(item['input_ids'], np.int64),
                "attention_mask": self._cut_token(item["attention_mask"], np.int64),
                "labels": self._cut_token(item["labels"], np.int64),
                "position_ids": self._cut_token(position_ids.numpy(), np.int64),
                "actual_seq_len": seq_len_field
            }
        elif self.args.cut_max_seqlen:
            return {
                "input_ids": self._cut_token(item['input_ids'], np.int64),
                "attention_mask": self._cut_token(item["attention_mask"], np.int64),
                "labels": self._cut_token(item["labels"], np.int64)
            }
        else:
            return self._cut_instruction_token(item, np.int64)

    def _cut_token(self, token, dtype):
        """Cast ``token`` to ``dtype``, after truncating it to ``self.seq_length``.

        Truncation is skipped when ``args.no_cut_token`` is set.
        """
        if not self.args.no_cut_token and len(token) >= self.seq_length:
            token = token[:self.seq_length]
        return token.astype(dtype)

    def _cut_instruction_token(self, item, dtype):
        """Truncate an instruction sample to at most ``self.seq_length`` tokens.

        With ``labels`` present and no ``dataset_additional_keys`` requested:

        - samples that already fit are returned unchanged (cast to ``dtype``);
        - longer samples are cut turn by turn. Each turn's remaining budget is
          split between prompt and response by ``_infer_seqlen``.

        Otherwise only ``input_ids`` is truncated, and any requested additional
        keys present in ``item`` are passed through unchanged.

        Returns:
            dict with ``input_ids``, an all-ones ``attention_mask``, and either
            ``labels`` or the requested additional keys.
        """
        IGNORE_INDEX = -100
        if "labels" in item.keys() and not self.args.dataset_additional_keys:
            token_length = len(item["input_ids"])
            if token_length <= self.seq_length:
                return {
                    "input_ids": item["input_ids"].astype(dtype),
                    "attention_mask": np.ones_like(item["input_ids"]).astype(dtype),
                    "labels": item["labels"].astype(dtype)
                }

            template = None
            if hasattr(self.args, "prompt_type") and self.args.prompt_type is not None:
                template = get_model_template(self.args.prompt_type, self.args.prompt_type_path,
                                              self.args.enable_thinking, self.args.reasoning_effort,
                                              self.args.drop_thinking)
            efficient_eos = template is not None and template.efficient_eos

            prompt_begin_list, prompt_end_list = get_prompt_index(item["labels"], IGNORE_INDEX)
            multi_turns = len(prompt_begin_list)
            total_length = 0
            if efficient_eos:
                # Reserve one slot for the EOS token re-appended after truncation.
                total_length = 1
                prompt_end_list = [x - 1 for x in prompt_end_list]
                eos_token_id = item["input_ids"][token_length - 1]
                item["input_ids"] = item["input_ids"][:token_length]
                item["labels"] = item["labels"][:token_length]

            cutoff_len = self.seq_length
            input_ids = np.array([], dtype=dtype)
            labels = np.array([], dtype=dtype)
            for turn_idx in range(multi_turns):
                if total_length >= cutoff_len:
                    break
                prompt_begin = prompt_begin_list[turn_idx]
                prompt_end = prompt_end_list[turn_idx]
                source_ids = item["input_ids"][prompt_begin:prompt_end]
                mask_ids = item["labels"][prompt_begin:prompt_end]
                # A response spans from the end of its prompt to the start of
                # the next prompt (or to the end of the sample for the last turn).
                if turn_idx != multi_turns - 1:
                    target_ids = item["labels"][prompt_end:prompt_begin_list[turn_idx + 1]]
                else:
                    target_ids = item["labels"][prompt_end:]
                source_len, target_len = _infer_seqlen(len(source_ids), len(target_ids), cutoff_len - total_length)
                source_ids = source_ids[:source_len]
                target_ids = target_ids[:target_len]
                mask_ids = mask_ids[:source_len]
                total_length += source_len + target_len
                # Response tokens come from `labels` for both outputs; this
                # presumes labels equal input_ids on response positions.
                input_ids = np.concatenate((input_ids, source_ids, target_ids), axis=0)
                labels = np.concatenate((labels, mask_ids, target_ids), axis=0)

            if efficient_eos:
                input_ids = np.concatenate((input_ids, np.array([eos_token_id], dtype=dtype)), axis=0)
                labels = np.concatenate((labels, np.array([eos_token_id], dtype=dtype)), axis=0)

            return {
                "input_ids": input_ids.astype(dtype),
                "attention_mask": np.ones_like(input_ids).astype(dtype),
                "labels": labels.astype(dtype)
            }

        input_ids = item["input_ids"][:self.seq_length]
        # Guard against an unset (None) option so iteration cannot fail.
        requested_keys = self.args.dataset_additional_keys or []
        add_vals = {key: item[key] for key in requested_keys if key in item.keys()}
        return dict(
            {
                "input_ids": input_ids.astype(dtype),
                "attention_mask": np.ones_like(input_ids).astype(dtype)
            }, **add_vals
        )

    def _cut_pairwise_token(self, item, dtype):
        """Cut prompt and response proportionally for pairwise datasets.

        The prompt length is the first position of ``chosen_labels`` that is
        not IGNORE_INDEX; the rejected sequence is assumed to share that
        prompt. Prompt and responses are truncated with ``_infer_seqlen``,
        using the longer of the two responses. Prompt positions are labelled
        IGNORE_INDEX in both outputs.
        """
        IGNORE_INDEX = -100
        prompt_length = (item["chosen_labels"] != IGNORE_INDEX).nonzero()[0][0]
        prompt_ids = item["chosen_input_ids"][:prompt_length]
        chosen_ids = item["chosen_input_ids"][prompt_length:]
        rejected_ids = item["rejected_input_ids"][prompt_length:]
        source_len, target_len = _infer_seqlen(
            len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), self.seq_length
        )
        prompt_ids = prompt_ids[:source_len]
        chosen_ids = chosen_ids[:target_len]
        rejected_ids = rejected_ids[:target_len]
        chosen_input_ids = np.append(prompt_ids, chosen_ids)
        chosen_labels = np.append(IGNORE_INDEX * np.ones(source_len), chosen_ids)
        rejected_input_ids = np.append(prompt_ids, rejected_ids)
        rejected_labels = np.append(IGNORE_INDEX * np.ones(source_len), rejected_ids)
        return {
            "chosen_input_ids": chosen_input_ids.astype(dtype),
            "chosen_attention_mask": np.ones_like(chosen_input_ids).astype(dtype),
            "chosen_labels": chosen_labels.astype(dtype),
            "rejected_input_ids": rejected_input_ids.astype(dtype),
            "rejected_attention_mask": np.ones_like(rejected_input_ids).astype(dtype),
            "rejected_labels": rejected_labels.astype(dtype)
        }
def get_prompt_index(labels, ignored_label):
    """Locate the runs of ``ignored_label`` (prompt spans) in ``labels``.

    Returns:
        ``(begins, ends)``. ``begins[k]`` is the index where the k-th run
        starts; ``ends[k]`` is the first index after it. A run that reaches
        the end of ``labels`` has a begin but no end entry.
    """
    begins, ends = [], []
    previous_ignored = False
    for position, value in enumerate(labels):
        current_ignored = value == ignored_label
        if current_ignored and not previous_ignored:
            begins.append(position)
        elif previous_ignored and not current_ignored:
            ends.append(position)
        previous_ignored = current_ignored
    return begins, ends
def _infer_seqlen(source_len: int, target_len: int, cutoff_len: int):
r"""
Computes the real sequence length after truncation by the cutoff_len.
"""
if target_len * 2 < cutoff_len:
max_target_len = cutoff_len
elif source_len * 2 < cutoff_len:
max_target_len = cutoff_len - source_len
else:
max_target_len = int(cutoff_len * (target_len / (source_len + target_len)))
new_target_len = min(max_target_len, target_len)
max_source_len = max(cutoff_len - new_target_len, 0)
new_source_len = min(max_source_len, source_len)
return new_source_len, new_target_len
def _build_index_mappings(
        name,
        data_prefix,
        start_index,
        nb_documents,
        mtf_dataset,
        num_samples: int,
        seq_length: int,
        seed,
        shuffle=True
):
    """Build (or load from cache) the document order used by the dataset.

    On the building rank(s), the index is produced in three steps:

    1. Append whole epochs from ``_build_shuffle_idx`` (or
       ``_build_sequential_idx`` when ``shuffle`` is False) until the index
       holds at least ``num_samples`` entries.
    2. Optionally shuffle the full list with Python's ``random`` module.
    3. Save it to ``<data_prefix>_<name>_indexmap_..._decoder_packed_idx.npy``.

    All ranks then pass a barrier and memory-map the saved file.

    Args:
        name: split name, used in the cache file name.
        data_prefix: dataset path prefix; the cache file is written next to it.
        start_index: global id of the split's first document.
        nb_documents: number of documents in the split.
        mtf_dataset: not used in this function.
        num_samples: minimum number of entries in the index.
        seq_length: not used in this function.
        seed: seed for the NumPy RNG used by per-epoch shuffling.
        shuffle: shuffle each epoch's document ids when True.

    Returns:
        1-D array of document ids, memory-mapped from the cache file. Its
        length is a whole number of epochs, so it may exceed ``num_samples``.
    """
    np_rng = np.random.RandomState(seed=seed)
    args = get_args()
    # NOTE(review): the cache key below omits start_index/nb_documents,
    # args.full_shuffle_instruction_dataset and args.global_batch_size (used
    # for padding). A stale file may be reused after any of these change.
    _filename = data_prefix
    _filename += '_{}_indexmap'.format(name)
    _filename += '_{}ns'.format(num_samples)
    _filename += '_{}s'.format(seed)
    _filename += '_{}'.format("padded_samples") if args.padded_samples else ""
    _filename += '_{}'.format("shuffle") if shuffle else ""
    shuffle_idx_filename = _filename + '_decoder_packed_idx.npy'

    # Build on ranks whose global rank is a multiple of the local device count
    # (one per node, assuming node-contiguous rank numbering; TODO confirm), or
    # on every rank for the Ray-based RL stages.
    # NOTE(review): in the Ray stages several processes may write the same file
    # concurrently. Confirm this is safe.
    if torch.distributed.get_rank() % torch.cuda.device_count() == 0 or args.stage in ["ray_ppo", "ray_online_dpo", "ray_grpo"]:
        if not os.path.isfile(shuffle_idx_filename):
            print_rank_0(' > WARNING: could not find index map files, building '
                         'the indices on rank 0 ...')
            start_time = time.time()
            epoch = 0  # counted but not used further
            shuffle_idx = []
            # Append whole epochs until at least num_samples ids are available.
            while len(shuffle_idx) < num_samples:
                if shuffle:
                    new_document_ids = _build_shuffle_idx(nb_documents=nb_documents, start_index=start_index,
                                                          np_rng=np_rng)
                else:
                    new_document_ids = _build_sequential_idx(nb_documents=nb_documents, start_index=start_index)
                shuffle_idx.extend(new_document_ids.tolist())
                epoch += 1
            # NOTE(review): random.shuffle uses the global `random` state, not
            # `seed`. Reproducibility relies on `random` being seeded elsewhere.
            if args.full_shuffle_instruction_dataset:
                random.shuffle(shuffle_idx)
            np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True)
            print_rank_0(' > elasped time to build and save shuffle-idx and sample-idx mapping'
                         ' (seconds): {:4f}'.format(time.time() - start_time))

    # Wait for the building rank(s) before anyone reads the file.
    torch.distributed.barrier()
    # Sanity check: reducing 1 over the DP, PP and CP groups in turn gives the
    # product of their sizes, which must equal world_size / TP size.
    counts = torch.cuda.LongTensor([1])
    torch.distributed.all_reduce(counts, group=parallel_state.get_data_parallel_group())
    torch.distributed.all_reduce(counts, group=parallel_state.get_pipeline_model_parallel_group())
    torch.distributed.all_reduce(counts, group=parallel_state.get_context_parallel_group())
    item = (torch.distributed.get_world_size() //
            torch.distributed.get_world_size(group=parallel_state.get_tensor_model_parallel_group()))
    check_equal(counts[0].item(), item)

    start_time = time.time()
    print_rank_0(' > loading shuffle-idx mapping from {}'.format(
        shuffle_idx_filename))
    # NOTE(review): 'r+' maps the shared cache file writable, so any in-place
    # change to the returned array would be written back to disk. 'r' may
    # suffice if callers only read it.
    shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True, mmap_mode='r+')
    print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
        time.time() - start_time))
    return shuffle_idx
def _build_sequential_idx(nb_documents: int, start_index):
"""Build the range [0, dataset_size)."""
_args = get_args()
batch_size = _args.global_batch_size
remainder_docs = nb_documents % batch_size
dtype_ = np.int64
stop = start_index + nb_documents
result = np.arange(start=start_index, stop=stop, step=1, dtype=dtype_)
if _args.padded_samples and remainder_docs != 0:
padded_idnex = np.arange(start=0, stop=batch_size - remainder_docs, step=1, dtype=dtype_)
result = np.append(result, padded_idnex)
return result
def _build_shuffle_idx(nb_documents: int, start_index, np_rng):
    """Return one epoch of this split's document ids, permuted by ``np_rng``.

    The ordered ids from ``_build_sequential_idx`` are shuffled in place with
    ``np_rng.shuffle`` and then returned.
    """
    epoch_ids = _build_sequential_idx(nb_documents, start_index)
    np_rng.shuffle(epoch_ids)
    return epoch_ids