"""
A dataloader implementing dynamic batching. It selects appropriate samples from a buffer based on text length
and concatenates them into a new batch of a specified `max_seq_len`. The key workflow is as follows:
1. `_get_data_from_dataloader` fetches a batch from the original dataloader and splits it along the batch
dimension into individual samples.
2. `batching_strategy` places the obtained samples into a buffer.
3. When the buffer is full, `batching_strategy` selects a subset of samples such that the sum of their text
lengths is as close as possible to `max_seq_len`.
4. `collate_fn` concatenates the selected samples into a single batch and outputs it.
"""
import traceback
import warnings
from collections.abc import Iterable
import torch
from torch.utils.data import default_collate
from megatron.training import print_rank_0
from mindspeed_mm.data.dataloader.batching_strategy import TextBatchingStrategy
class DynamicBatchingDataLoader:
def __init__(
self,
dataloader,
max_seq_len: int,
dynamic_batch_buffer_size: int,
drop_last: bool = False,
vision_layout: str = 'TND',
consumed_train_samples: int = 0
) -> None:
print_rank_0("[INFO] initializing dynamic batching DataLoader")
self.vision_layout = vision_layout
self.dataloader = dataloader
self.num_step = len(self.dataloader)
self.batching_strategy = TextBatchingStrategy(
max_seq_len=max_seq_len,
buffer_size=dynamic_batch_buffer_size,
)
self.drop_last = drop_last
self.consumed_train_samples = consumed_train_samples
self.non_packing_data = {}
print_rank_0("[INFO] Successfully initialize dynamic batching DataLoader")
def __len__(self):
return len(self.dataloader)
def __iter__(self):
self.step = 0
self._data_iter = iter(self.dataloader)
self._batched_data_iter = self.dynamic_batching_data_generator()
return self
def __next__(self):
return next(self._batched_data_iter)
def dynamic_batching_data_generator(self):
while True:
if self.batching_strategy.is_full_filled():
micro_batch = self._get_micro_batch()
yield micro_batch
self.step += 1
try:
processing_item = self._get_data_from_dataloader()
except Exception as e:
if isinstance(e, StopIteration):
if not self.drop_last and not self.batching_strategy.empty():
while not self.batching_strategy.empty():
micro_batch = self._get_micro_batch()
yield micro_batch
self.step += 1
return
else:
return
else:
print(f"DynamicBatchDataset iter data exception: {e} \n{traceback.format_exc()}")
raise
if isinstance(processing_item, dict):
processing_item = [processing_item]
for item in processing_item:
self.batching_strategy.put_item(item)
def _get_data_from_dataloader(self):
data = next(self._data_iter)
origin_mbs = data['input_ids'].shape[0]
for data_name, value in data.items():
if not isinstance(value, Iterable) or isinstance(value, (str, bytes)):
self.non_packing_data[data_name] = value
elif data_name != 'pixel_values' and len(value) != origin_mbs:
warnings.warn(
f"The iterable data {data_name} (of type {type(value)}) in micro batch extracted from "
f"the original dataloader has a length inconsistent with the micro batch size (mbs). To ensure "
f"correct decomposition into mbs individual samples, it has been moved to non_packing_data. "
f"Please verify the actual purpose of this data and apply appropriate adjustments."
)
self.non_packing_data[data_name] = value
for data_name in self.non_packing_data:
data.pop(data_name)
data_names = data.keys()
data['input_ids'] = [data['input_ids'][i][mask] for i, mask in enumerate(data['attention_mask'].bool())]
data['labels'] = [data['labels'][i][mask] for i, mask in enumerate(data['attention_mask'].bool())]
data['attention_mask'] = [
data['attention_mask'][i][mask]
for i, mask in enumerate(data['attention_mask'].bool())
]
pixel_length = data['image_grid_thw'][:, 0] * data['image_grid_thw'][:, 1] * data['image_grid_thw'][:, 2]
data['pixel_values'] = data['pixel_values'].split(pixel_length.tolist())
return [dict(zip(data_names, row)) for row in zip(*data.values())]
def _get_micro_batch(self):
micro_batch = self.batching_strategy.get_micro_batch()
self.consumed_train_samples += len(micro_batch)
micro_batch = self.collect_fn(micro_batch, self.vision_layout)
return micro_batch
def collect_fn(self, features, vision_layout):
seqlens = torch.tensor([len(feature['input_ids']) for feature in features], dtype=torch.long)
batch = {"seqlens": seqlens}
for data_name, value in self.non_packing_data.items():
batch[data_name] = value
for input_name in features[0].keys():
if input_name in ('input_ids', 'attention_mask', 'labels'):
batch[input_name] = torch.cat([feature[input_name] for feature in features]).unsqueeze(0)
else:
if input_name == 'pixel_values' and vision_layout == "TND":
batch[input_name] = torch.cat([feature[input_name] for feature in features])
else:
batch[input_name] = default_collate([feature[input_name] for feature in features])
if "attention_mask" in batch.keys():
batch["indices"] = torch.arange(len(batch["attention_mask"][0]))
else:
raise ValueError("Need attention mask to generate indices")
return batch