import logging
import math
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Iterable, List, Optional, Type, Union
import numpy
import torch
from mindspeed_llm.fsdp2.data.megatron_data.blended_dataset import BlendedDataset
from mindspeed_llm.fsdp2.data.megatron_data.blended_megatron_dataset_config import BlendedMegatronDatasetConfig
from mindspeed_llm.fsdp2.data.megatron_data.megatron_dataset import LowLevelDataset, MegatronDataset
from mindspeed_llm.fsdp2.data.megatron_data.megatron_utils import Split, normalize, need_to_build_dataset
from mindspeed_llm.fsdp2.utils.logging import get_logger
# Module-level logger created through the project's logging helper.
logger = get_logger(__name__)

# Alias: a "mid-level" dataset is a MegatronDataset.
MidLevelDataset = MegatronDataset

# A "top-level" dataset is either a BlendedDataset or a single mid-level dataset.
TopLevelDataset = Union[BlendedDataset, MidLevelDataset]

# Union of every dataset flavour referenced by this module's type annotations.
DistributedDataset = Union[
    TopLevelDataset, MidLevelDataset, LowLevelDataset, torch.utils.data.Dataset
]
class BlendedMegatronDatasetBuilder(object):
    """Builder class for the BlendedDataset and MegatronDataset classes

    Args:
        cls (Type[MegatronDataset]): The class to instantiate, must inherit from MegatronDataset
        sizes (List[Optional[int]]): The minimum total number of samples to draw, or None, per
            split
        is_built_on_rank (Callable): A callable which returns True if the dataset should be built
            on the current rank and False otherwise. It should be Megatron Core parallelism aware
            i.e. global rank, local group rank, and virtual rank may inform its return value.
        config (BlendedMegatronDatasetConfig): The config object which informs dataset creation
    """

    def __init__(
        self,
        cls: Type[MidLevelDataset],
        sizes: List[Optional[int]],
        is_built_on_rank: Callable,
        config: BlendedMegatronDatasetConfig,
    ):
        self.cls = cls
        self.sizes = sizes
        self.is_built_on_rank = is_built_on_rank
        self.config = config

        logger.info_rank0(
            f"Building {cls.__name__} splits with sizes={self.sizes} and config={self.config}",
        )

        if not self.config.mock:
            # A split without a requested size cannot honour explicit blend weights, so
            # "size is None" must imply "weights are None" for every configured split.
            for split in Split:
                size_is_none = self.sizes[split.value] is None
                if self.config.blend_per_split is None:
                    weights_are_none = self.config.blend[1] is None
                else:
                    if self.config.blend_per_split[split.value] is None:
                        # No blend configured for this split, nothing to validate.
                        continue
                    weights_are_none = self.config.blend_per_split[split.value][1] is None
                if size_is_none:
                    assert (
                        weights_are_none
                    ), f"size_is_none => weights_are_none fails for {split.name} split"

        if torch.distributed.is_initialized():
            gb_rank = torch.distributed.get_rank()
            if gb_rank == 0:
                assert (
                    self.is_built_on_rank()
                ), "is_built_on_rank must return True when global rank = 0"

    def build(self) -> List[Optional[TopLevelDataset]]:
        """Build all dataset splits according to the provided blend(s)

        This method is distributed-aware and must be called on all ranks.

        The dataset splits returned can vary according to the config. Supply config.blend and
        config.split to build BlendedDataset and/or MegatronDataset splits from the same
        distribution. Supply config.blend_per_split to build BlendedDataset and/or MegatronDataset
        splits from separate distributions. In either case, for each split, handle the following
        cases:

        (1) The split is None
            - do nothing

        (2) The split has one contributing dataset, and...

            (a) 'size' is not None
                - Build a mid-level dataset with low-level dataset sampling in proportion to the
                  size

            (b) 'size' is None
                - Build mid-level datasets with no excess low-level dataset sampling

        (3) The split has multiple contributing datasets, and...

            (a) 'weights' is not None and 'size' is not None
                - Build mid-level datasets with low-level dataset sampling in proportion to their
                  weights and the size
                - Build a top-level dataset of length marginally greater than 'size' with
                  mid-level dataset sampling in proportion to their weights and the size

            (b) 'weights' is not None and 'size' is None
                - Error

            (c) 'weights' is None and 'size' is not None
                - Build mid-level datasets with no excess low-level dataset sampling
                - Build a top-level dataset of length 'size' (capped at the sum of the mid-level
                  dataset lengths) with mid-level dataset sampling in proportion to their lengths
                  and the size

            (d) 'weights' is None and 'size' is None
                - Build mid-level datasets with no excess low-level dataset sampling
                - Build a top-level dataset with no excess mid-level dataset sampling

        Returns:
            List[Optional[TopLevelDataset]]: A list containing a dataset instance (or None) per
                split

        Raises:
            IndexError: When a BlendedDataset requests more samples from a contributing dataset
                than that dataset holds
        """
        datasets = self._build_blended_dataset_splits()

        for dataset in datasets:
            if dataset is None or len(dataset) == 0:
                continue
            if not isinstance(dataset, BlendedDataset):
                # Only blends carry a dataset_index that needs verifying.
                continue

            built_anew = dataset.built_anew_on_cache_miss or any(
                x.built_anew_on_cache_miss for x in dataset.datasets
            )
            if not built_anew:
                logger.info_rank0(
                    (
                        f"NumPy indices for {type(dataset).__name__} {dataset.split.name} "
                        f"split are fully cached, skipping verification"
                    ),
                )
                continue

            logger.info_rank0(
                (
                    f"Verifying NumPy indices for {type(dataset).__name__} "
                    f"{dataset.split.name} split"
                ),
            )

            # Check the blend size.
            assert dataset.size is None or dataset.size == dataset.dataset_index.shape[0]

            # Check that the blend never asks a contributing dataset for more samples than it
            # has. numpy.unique yields each referenced dataset position and its request count.
            dataset_indices, dataset_sizes = numpy.unique(
                dataset.dataset_index, return_counts=True
            )
            for index, size in zip(dataset_indices, dataset_sizes):
                contributor = dataset.datasets[index]
                if len(contributor) < size:
                    # Report `index` (the position in dataset.datasets), not the position in
                    # the unique() output, which differs when some dataset is never sampled.
                    raise IndexError(
                        f"The {dataset.split.name} blend oversamples the contributing "
                        f"datasets and, e.g., requests {size} samples from "
                        f"{type(contributor).__name__} {index} with size "
                        f"{len(contributor)}. This is unexpected. "
                        f"Please file an issue."
                    )

        return datasets

    def _build_blended_dataset_splits(self) -> List[Optional[TopLevelDataset]]:
        """Build all dataset splits according to the provided blend(s)

        See the BlendedMegatronDatasetBuilder.build alias for more information.

        Returns:
            List[Optional[TopLevelDataset]]: A list containing a dataset instance (or None) per
                split

        Raises:
            Exception: When config.mock is set and the mock dataset fails to build
        """
        # Mock data: there is no dataset path, so build straight from the split matrix.
        if self.config.mock:
            split = self.config.split_matrix
            try:
                return self._build_megatron_dataset_splits(None, split, self.sizes)
            except Exception as error:
                raise Exception(
                    f"{self.cls.__name__} failed to build as a mock data generator"
                ) from error

        # All splits come from the same distribution.
        if self.config.blend:
            return self._build_splits_from_shared_blend()

        # Each split comes from its own distribution.
        return self._build_splits_from_per_split_blends()

    def _build_splits_from_shared_blend(self) -> List[Optional[TopLevelDataset]]:
        """Build every split from the single blend given by config.blend"""
        prefixes, weights = self.config.blend
        if weights is not None:
            weights = normalize(weights)

        split = self.config.split_matrix

        # A single un-weighted prefix needs no blending at all.
        if len(prefixes) == 1 and weights is None:
            return self._build_megatron_dataset_splits(prefixes[0], split, self.sizes)

        if weights is None:
            # No requested sizes for the mid-level datasets.
            sizes_per_dataset_target = None
            sizes_per_dataset_buffer = [[None for _ in Split] for _ in prefixes]
        else:
            # Samples we plan to use per dataset ...
            sizes_per_dataset_target = _get_size_per_split_per_dataset(weights, self.sizes)
            # ... and samples we actually build per dataset (0.5% margin on top).
            sizes_per_dataset_buffer = _get_size_per_split_per_dataset(
                weights, self.sizes, margin=0.5
            )

        megatron_datasets = self._build_megatron_datasets_parallel(
            prefixes, split, sizes_per_dataset_buffer
        )

        blended_datasets = [None] * len(Split)
        for i in range(len(Split)):
            if split[i] is None:
                continue

            if weights is None:
                # Blend according to the dataset lengths and (maybe) the requested size.
                weights_i, size_i = self._blend_by_dataset_length(
                    megatron_datasets[i], prefixes, self.sizes[i]
                )
            elif self.sizes[i] is not None:
                # Blend according to the client-specified weights and size.
                weights_i = weights
                size_i = sum(per_dataset[i] for per_dataset in sizes_per_dataset_target)
            else:
                raise ValueError(
                    "Using client-specified weights requires client-specified size"
                )

            blended_datasets[i] = self.build_generic_dataset(
                BlendedDataset,
                self.is_built_on_rank,
                True,  # synchronize_ranks
                megatron_datasets[i],
                weights_i,
                size_i,
                self.config,
            )

        return blended_datasets

    def _build_splits_from_per_split_blends(self) -> List[Optional[TopLevelDataset]]:
        """Build each split from its own blend given by config.blend_per_split"""
        blended_datasets = [None] * len(Split)

        for i in range(len(Split)):
            # Pretend split i is the only split and spans the whole low-level dataset
            # (0.0 to 1.0), so the shared per-prefix builders can be reused as-is.
            split_spoof = [None] * len(Split)
            split_spoof[i] = (0.0, 1.0)
            sizes_spoof = [0] * len(Split)
            sizes_spoof[i] = self.sizes[i]

            blend = self.config.blend_per_split[i]
            if blend is None:
                continue

            prefixes, weights = blend
            if weights is not None:
                weights = normalize(weights)

            # A single prefix needs no blending at all.
            if len(prefixes) == 1:
                blended_datasets[i] = self._build_megatron_dataset_splits(
                    prefixes[0], split_spoof, sizes_spoof
                )[i]
                continue

            if weights is None:
                # No requested sizes for the mid-level datasets.
                sizes_per_dataset_target = None
                sizes_per_dataset_buffer = [[None for _ in Split] for _ in prefixes]
            else:
                # Samples we plan to use per dataset ...
                sizes_per_dataset_target = _get_size_per_split_per_dataset(
                    weights, sizes_spoof
                )
                # ... and samples we actually build per dataset (0.5% margin on top).
                sizes_per_dataset_buffer = _get_size_per_split_per_dataset(
                    weights, sizes_spoof, margin=0.5
                )

            megatron_datasets = self._build_megatron_datasets_parallel(
                prefixes, split_spoof, sizes_per_dataset_buffer
            )[i]

            if weights is None:
                # Blend according to the dataset lengths and (maybe) the requested size.
                weights, size = self._blend_by_dataset_length(
                    megatron_datasets, prefixes, self.sizes[i]
                )
            elif self.sizes[i] is not None:
                # Blend according to the client-specified weights and size.
                size = sum(per_dataset[i] for per_dataset in sizes_per_dataset_target)
            else:
                raise RuntimeError(
                    "Using client-specified weights requires client-specified size"
                )

            blended_datasets[i] = self.build_generic_dataset(
                BlendedDataset,
                self.is_built_on_rank,
                True,  # synchronize_ranks
                megatron_datasets,
                weights,
                size,
                self.config,
            )

        return blended_datasets

    @staticmethod
    def _blend_by_dataset_length(
        mid_level_datasets: List[Optional[MidLevelDataset]],
        prefixes: List[str],
        requested_size: Optional[int],
    ):
        """Derive blend weights and blend size from the mid-level dataset lengths

        Args:
            mid_level_datasets (List[Optional[MidLevelDataset]]): The contributing datasets of
                one split
            prefixes (List[str]): The prefixes the datasets were built from
            requested_size (Optional[int]): The requested split size, or None

        Returns:
            Tuple[List[int], Optional[int]]: The per-dataset weights (their lengths) and the
                blend size, which is requested_size capped at the total length, or None when no
                size was requested
        """
        try:
            weights = [len(mid_level_dataset) for mid_level_dataset in mid_level_datasets]
        except TypeError:
            # len() failed on an entry (e.g. a None placeholder); fall back to zero weights.
            weights = [0 for _ in prefixes]

        if requested_size is None:
            return weights, None
        return weights, min(requested_size, sum(weights))

    def _build_megatron_datasets_parallel(
        self, prefixes: List[str], split: List[float], sizes_per_dataset: List[List[int]]
    ) -> List[List[Optional[MegatronDataset]]]:
        """Build the megatron datasets for a list of prefixes in parallel

        Args:
            prefixes (List[str]): The list of prefix strings
            split (List[float]): The dataset split ratios (must sum to 1.00)
            sizes_per_dataset (List[List[int]]): The number of samples to request
                per MegatronDataset per split

        Returns:
            List[List[Optional[MegatronDataset]]]: For each split, have a list of
                MegatronDataset per prefix
        """
        megatron_datasets = [[] for _ in range(len(Split))]

        def _threading_helper(num_workers: int) -> None:
            """Build the splits of every prefix on a thread pool and collect them per split"""
            with ThreadPoolExecutor(max_workers=num_workers) as executor:
                all_futures = [
                    executor.submit(
                        self._build_megatron_dataset_splits,
                        prefix,
                        split,
                        sizes_per_dataset[i],
                        False,  # synchronize_ranks, barriers are handled by the caller below
                    )
                    for i, prefix in enumerate(prefixes)
                ]
                # Collect in submission order so each per-split list stays aligned with
                # `prefixes`. Future.result() re-raises any worker exception by itself.
                for future in all_futures:
                    for j, megatron_dataset in enumerate(future.result()):
                        megatron_datasets[j].append(megatron_dataset)

        num_dataset_builder_threads = self.config.num_dataset_builder_threads

        if torch.distributed.is_initialized():
            rank = torch.distributed.get_rank()

            # First, build on rank 0 while every other rank waits at the barrier below.
            if rank == 0:
                num_workers = num_dataset_builder_threads
                if num_workers > 1:
                    # Scale the thread count up (at most x2) for the rank-0-only phase. A
                    # setting of 1 is left alone, keeping the build serial.
                    num_workers *= min(2, max(1, torch.cuda.device_count()))
                _threading_helper(num_workers)

            torch.distributed.barrier()

            # Then, build on the other ranks.
            if rank != 0:
                _threading_helper(num_dataset_builder_threads)
        else:
            _threading_helper(num_dataset_builder_threads)

        return megatron_datasets

    def _build_megatron_dataset_splits(
        self,
        dataset_path: Optional[str],
        split: List[float],
        sizes: List[int],
        synchronize_ranks: bool = True,
    ) -> List[Optional[MidLevelDataset]]:
        """Build each MidLevelDataset split from a single LowLevelDataset

        Args:
            dataset_path (Optional[str]): The path on disk which defines the underlying
                LowLevelDataset, or None for mock dataset classes
            split (List[Tuple[float, float]]): The dataset split matrix; entry i is None or a
                (begin, end) pair of fractions of the low-level dataset
            sizes (List[int]): The number of total samples to draw from each split
            synchronize_ranks (bool): Whether to call barrier for rank-0 / barrier / other-ranks
                behavior. Set to False when we enforce this behavior at higher level.

        Returns:
            List[Optional[MidLevelDataset]]: The MidLevelDataset (or None) per split
        """
        # Short-cut when this rank builds nothing. build_generic_dataset issues one barrier
        # per non-None split on the building ranks, so match that count here.
        if torch.distributed.is_initialized() and not self.is_built_on_rank():
            for i in range(len(Split)):
                if split[i] is not None and synchronize_ranks:
                    torch.distributed.barrier()
            return [None] * len(Split)

        # Build the low-level dataset.
        low_level_dataset = self.cls.build_low_level_dataset(dataset_path, self.config)

        # Turn the fractional (begin, end) bounds of each split into element indices.
        num_elements = self.cls.numel_low_level_dataset(low_level_dataset)
        split_indices = []
        for i, _ in enumerate(Split):
            if split[i] is None:
                split_indices.append(None)
                continue
            beg = int(round(split[i][0] * float(num_elements)))
            end = int(round(split[i][1] * float(num_elements)))
            split_indices.append(numpy.arange(start=beg, stop=end, step=1, dtype=numpy.int32))

        # Build the mid-level dataset of every non-None split.
        mid_level_datasets = []
        for i, _split in enumerate(Split):
            if split[i] is None:
                mid_level_datasets.append(None)
                continue
            mid_level_datasets.append(
                self.build_generic_dataset(
                    self.cls,
                    self.is_built_on_rank,
                    synchronize_ranks,
                    low_level_dataset,
                    dataset_path,
                    split_indices[i],
                    sizes[i],
                    _split,
                    self.config,
                )
            )

        return mid_level_datasets

    @staticmethod
    def build_generic_dataset(
        cls: Union[Type[DistributedDataset], Callable],
        is_built_on_rank: Callable,
        synchronize_ranks: bool,
        *args: Any,
    ) -> Optional[Union[DistributedDataset, Iterable]]:
        """Build the DistributedDataset

        Return None if and only if the underlying dataset class is not built on the current rank
        and torch.distributed is initialized.

        Args:
            cls (Union[Type[DistributedDataset], Callable]): The DistributedDataset class to be
                built. In special cases, e.g. when we are building the low level dataset for a
                RawMegatronDataset instance, we can accept a Callable which returns an Iterable.
            is_built_on_rank (Callable): A callable which returns True if the dataset should be
                built on the current rank and False otherwise. Only consulted when
                torch.distributed is initialized.
            synchronize_ranks (bool): Whether to call barrier for rank-0 / barrier / other-ranks
                behavior. Set to False when we enforce this behavior at higher level.
            args (Tuple[Any]): The positional arguments used to build the provided
                DistributedDataset class

        Raises:
            Exception: When the dataset constructor raises an OSError

        Returns:
            Optional[Union[DistributedDataset, Iterable]]: The DistributedDataset instantiation,
                the Iterable instantiation, or None
        """
        if not torch.distributed.is_initialized():
            return cls(*args)

        dataset = None

        # NOTE(review): need_to_build_dataset() presumably marks the rank(s) that build before
        # the barrier; confirm against megatron_utils.
        to_build_flag = need_to_build_dataset()

        # First, build on the flagged rank(s).
        if to_build_flag and is_built_on_rank():
            try:
                dataset = cls(*args)
            except OSError as err:
                log = (
                    "Failed to write dataset materials to the data cache directory. "
                    "Please supply a directory to which you have write access via "
                    "the path_to_cache attribute in BlendedMegatronDatasetConfig and "
                    "retry. Refer to the preserved traceback above for more information."
                )
                raise Exception(log) from err

        if synchronize_ranks:
            torch.distributed.barrier()

        # After the barrier, build on the remaining ranks.
        if not to_build_flag and is_built_on_rank():
            dataset = cls(*args)

        return dataset
def _get_size_per_split_per_dataset(
normalized_weights: List[float], target_size_per_split: List[int], margin: float = 0.0
) -> List[List[int]]:
"""Determine the contribution of the MegatronDataset splits to the BlendedDataset splits
Args:
normalized_weights (List[float]): e.g. [0.3, 0.7]
target_size_per_split (List[int]): The number of samples to target for each BlendedDataset
split
margin (float): The relative quantity of extra samples to build per per split per dataset,
as a percentage
Returns:
List[List[int]]: The number of samples to request per MegatronDataset per split
"""
assert numpy.isclose(sum(normalized_weights), 1.0)
sizes_per_dataset = [
[
int(math.ceil(math.ceil(target_size * weight) * (1 + margin / 100)))
for target_size in target_size_per_split
]
for weight in normalized_weights
]
return sizes_per_dataset