"""Dataloaders."""
import torch
from transformers import DataCollatorForSeq2Seq
from megatron.training import get_args, get_tokenizer
from megatron.core import parallel_state
from megatron.legacy.data.data_samplers import MegatronPretrainingSampler, MegatronPretrainingRandomSampler
from mindspeed_llm.tasks.preprocess.collator import PairwiseDataCollatorWithPadding
def build_pretraining_data_loader(dataset, consumed_samples):
    """Build a dataloader for the given input dataset.

    Args:
        dataset: Dataset to draw samples from. ``None`` or an empty dataset
            produces no loader.
        consumed_samples: Number of already-consumed samples, forwarded to
            the batch sampler.

    Returns:
        A ``torch.utils.data.DataLoader`` over ``dataset``, or ``None`` when
        ``dataset`` is ``None`` or has length 0.

    Raises:
        ValueError: If ``args.dataloader_type`` is neither ``'single'`` nor
            ``'cyclic'``.
    """
    if dataset is None or len(dataset) == 0:
        return None

    args = get_args()

    # Validate the sampler type before touching parallel state so that a bad
    # configuration is reported with the dedicated message.
    loader_type = args.dataloader_type
    if loader_type not in ('single', 'cyclic'):
        raise ValueError('{} dataloader type is not supported.'.format(
            args.dataloader_type))

    # Keyword arguments shared by both sampler flavours.
    sampler_kwargs = dict(
        total_samples=len(dataset),
        consumed_samples=consumed_samples,
        micro_batch_size=args.micro_batch_size,
        data_parallel_rank=parallel_state.get_data_parallel_rank(),
        data_parallel_size=parallel_state.get_data_parallel_world_size(),
    )
    if loader_type == 'single':
        batch_sampler = MegatronPretrainingSampler(**sampler_kwargs)
    else:
        batch_sampler = MegatronPretrainingRandomSampler(
            dataset=dataset,
            data_sharding=True,
            **sampler_kwargs)

    # The pairwise check takes precedence over the instruction check; any
    # other dataset falls back to the DataLoader's default collation.
    if args.is_pairwise_dataset:
        collator_cls = PairwiseDataCollatorWithPadding
    elif args.is_instruction_dataset:
        collator_cls = DataCollatorForSeq2Seq
    else:
        collator_cls = None

    collator = None
    if collator_cls is not None:
        # The underlying tokenizer is only looked up when a collator actually
        # needs it, so plain pretraining does not depend on the tokenizer
        # wrapper exposing a ``.tokenizer`` attribute.
        tokenizer = get_tokenizer().tokenizer
        tokenizer.padding_side = args.tokenizer_padding_side
        # When ``no_pad_to_seq_lengths`` is set, use the configured multiple;
        # otherwise pass the sequence length as the padding multiple.
        if args.no_pad_to_seq_lengths:
            pad_multiple = args.pad_to_multiple_of
        else:
            pad_multiple = args.seq_length
        collator = collator_cls(
            tokenizer,
            pad_to_multiple_of=pad_multiple,
            return_tensors='pt',
            padding=True,
        )

    return torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batch_sampler,
        num_workers=args.num_workers,
        generator=torch.Generator().manual_seed(args.seed),
        collate_fn=collator,
        pin_memory=True)