import os
import contextlib
import json
import fsspec
import numpy as np
from functools import partial
from typing import Any, Optional, Literal, TypedDict, Union
import torch
import torch.distributed as dist
from transformers import PreTrainedTokenizer
from datasets import DatasetDict, concatenate_datasets, interleave_datasets
from datasets import Dataset, IterableDataset
from datasets import load_dataset, load_from_disk
from ..utils.arguments import DataArguments, ModelArguments, TrainingArguments
from .converter import get_local_rank, align_dataset
from .parser import DatasetAttr
from .parser import get_dataset_list
from .template import Template
from .processor import (
DatasetProcessor,
PackedSFTDatasetProcessor,
PretrainDatasetProcessor,
SFTDatasetProcessor
)
from mindspeed_llm.fsdp2.utils.logging import get_logger
from mindspeed_llm.fsdp2.distributed.parallel_state import ParallelState
logger = get_logger(__name__)
# Lookup from a data file's extension (without the leading dot) to the loader
# type name that is handed to `load_dataset` as `path` for local files.
# Insertion order is kept stable because the keys are echoed in an error message.
FILEEXT2TYPE = {
    # binary / columnar formats
    "arrow": "arrow",
    "parquet": "parquet",
    # delimited and structured text formats
    "csv": "csv",
    "json": "json",
    "jsonl": "json",
    # plain text
    "txt": "text",
}
FILEEXT2TYPE = {key: FILEEXT2TYPE[key] for key in ("arrow", "csv", "json", "jsonl", "parquet", "txt")}
class DatasetModule(TypedDict):
    r"""Typed mapping that carries the training and evaluation datasets."""

    # Training data, either map-style or iterable; may be None.
    train_dataset: Union["Dataset", "IterableDataset", None]
    # Evaluation data: one dataset, or a mapping from a name to a dataset; may be None.
    eval_dataset: Union["Dataset", "IterableDataset", dict[str, "Dataset"], None]
def merge_dataset(
    all_datasets: list[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
) -> Union["Dataset", "IterableDataset"]:
    r"""Combine several datasets into one according to `data_args.mix_strategy`.

    A list holding exactly one dataset is returned unchanged. The `concat`
    strategy concatenates the datasets; any strategy whose name starts with
    `interleave` interleaves them using `data_args.interleave_probs` and `seed`,
    stopping at the first exhausted dataset when the name ends with `under`
    and when all are exhausted otherwise.

    Raises:
        ValueError: if the mixing strategy is not recognised.
    """
    # Nothing to merge when only one dataset was supplied.
    if len(all_datasets) == 1:
        return all_datasets[0]

    strategy = data_args.mix_strategy

    if strategy == "concat":
        if data_args.streaming:
            logger.info_rank0("The samples between different datasets will not be mixed in streaming mode.")

        return concatenate_datasets(all_datasets)

    if not strategy.startswith("interleave"):
        raise ValueError(f"Unknown mixing strategy: {data_args.mix_strategy}.")

    if not data_args.streaming:
        logger.info_rank0("We recommend using `mix_strategy=concat` in non-streaming mode.")

    if strategy.endswith("under"):
        stop_rule = "first_exhausted"
    else:
        stop_rule = "all_exhausted"

    return interleave_datasets(
        datasets=all_datasets,
        probabilities=data_args.interleave_probs,
        seed=seed,
        stopping_strategy=stop_rule,
    )
def split_dataset(
    dataset: Optional[Union["Dataset", "IterableDataset"]],
    eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]],
    data_args: "DataArguments",
    seed: int,
) -> "DatasetDict":
    r"""Split the dataset and return a dataset dict containing the train set and validation set.

    Supports both map-style and iterable datasets.

    Args:
        dataset: training data; when `data_args.val_size` is positive a validation
            split is carved out of it.
        eval_dataset: explicit evaluation data. A dict is stored under
            `validation_<name>` keys, anything else under `validation`.
        data_args: supplies `val_size`, `streaming` and `buffer_size`.
        seed: seed used for shuffling and for the train/validation split.

    Returns:
        A `DatasetDict` with a `train` entry (when `dataset` is given) plus the
        validation entries described above.

    Raises:
        ValueError: if both `eval_dataset` and a positive `val_size` are given.
    """
    if eval_dataset is not None and data_args.val_size > 1e-6:
        raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.")

    dataset_dict = {}
    if dataset is not None:
        if data_args.streaming:
            dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)

        if data_args.val_size > 1e-6:
            if data_args.streaming:
                # In streaming mode `val_size` is truncated to an integer count of examples:
                # the first ones form the validation set and the rest the training set.
                dataset_dict["validation"] = dataset.take(int(data_args.val_size))
                dataset_dict["train"] = dataset.skip(int(data_args.val_size))
            else:
                # Values above 1 are an absolute number of examples, otherwise a fraction.
                val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
                # Split exactly once; the result is only renamed from "test" to "validation".
                splits = dataset.train_test_split(test_size=val_size, seed=seed)
                dataset_dict = {"train": splits["train"], "validation": splits["test"]}
        else:
            dataset_dict["train"] = dataset

    if eval_dataset is not None:
        if isinstance(eval_dataset, dict):
            dataset_dict.update({f"validation_{name}": data for name, data in eval_dataset.items()})
        else:
            if data_args.streaming:
                eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)

            dataset_dict["validation"] = eval_dataset

    return DatasetDict(dataset_dict)
def get_dataset_module(dataset: Union["Dataset", "DatasetDict"]) -> "DatasetModule":
    r"""Build a dataset module from a single dataset or from a dataset dict.

    A plain dataset becomes the train dataset. For a `DatasetDict`, the `train`
    entry becomes the train dataset; the `validation` entry, when present, becomes
    the eval dataset, otherwise every `validation_<name>` entry is gathered into a
    dict keyed by `<name>` (and only stored when that dict is non-empty).
    """
    dataset_module: DatasetModule = {}

    # A single dataset is treated as training data.
    if not isinstance(dataset, DatasetDict):
        dataset_module["train_dataset"] = dataset
        return dataset_module

    if "train" in dataset:
        dataset_module["train_dataset"] = dataset["train"]

    if "validation" in dataset:
        dataset_module["eval_dataset"] = dataset["validation"]
        return dataset_module

    # No unified validation split: look for per-dataset validation splits.
    prefix = "validation_"
    named_evals = {key[len(prefix) :]: dataset[key] for key in dataset.keys() if key.startswith(prefix)}
    if len(named_evals):
        dataset_module["eval_dataset"] = named_evals

    return dataset_module
def setup_fs(path: str, anon: bool = False) -> "fsspec.AbstractFileSystem":
    r"""Create the fsspec filesystem that matches the protocol prefix of `path`.

    Args:
        path: a path starting with `s3://`, `gs://` or `gcs://`.
        anon: when true, `anon=True` is forwarded to the filesystem constructor.

    Raises:
        ValueError: if the protocol is not supported or the path does not exist.
    """
    if path.startswith("s3://"):
        protocol = "s3"
    elif path.startswith(("gs://", "gcs://")):
        protocol = "gcs"
    else:
        raise ValueError(f"Unsupported protocol in path: {path}. Use 's3://' or 'gs://'.")

    # Only pass the option when anonymous access was requested.
    storage_options = {}
    if anon:
        storage_options["anon"] = anon

    fs = fsspec.filesystem(protocol, **storage_options)
    if fs.exists(path):
        return fs

    raise ValueError(f"Path does not exist: {path}.")
def _read_json_with_fs(fs: "fsspec.AbstractFileSystem", path: str) -> list[Any]:
    r"""Read one JSON or JSONL file through the given filesystem object.

    A path ending in `.jsonl` is parsed line by line, skipping blank lines;
    any other path is parsed as a single JSON document.
    """
    line_delimited = path.endswith(".jsonl")
    with fs.open(path, "r") as handle:
        if not line_delimited:
            return json.load(handle)

        records = []
        for raw_line in handle:
            # Whitespace-only lines carry no record.
            if raw_line.strip():
                records.append(json.loads(raw_line))

        return records
def read_cloud_json(cloud_path: str) -> list[Any]:
    r"""Read a JSON/JSONL file, or every JSON/JSONL file of a directory, from cloud storage (S3 or GCS).

    Args:
        cloud_path: str
            Cloud path in the format:
            - 's3://bucket-name/file.json' for AWS S3
            - 'gs://bucket-name/file.jsonl' or 'gcs://bucket-name/file.jsonl' for Google Cloud Storage

    Returns:
        The records of all matching files, concatenated in listing order.

    Raises:
        ValueError: if the path is invalid or contains no `.json`/`.jsonl` file.
        TypeError: if a file does not contain a list of records.
    """
    # Best effort: try anonymous access first, then fall back to the default credentials.
    try:
        fs = setup_fs(cloud_path, anon=True)
    except Exception:
        fs = setup_fs(cloud_path)

    if fs.isdir(cloud_path):
        # "name" is the key fsspec documents for listing entries; "Key" is kept as a
        # fallback for listings that only expose the S3-style field.
        candidates = [entry["name"] if "name" in entry else entry["Key"] for entry in fs.listdir(cloud_path)]
    else:
        candidates = [cloud_path]

    # Materialise the matches: a lazy `filter` object is always truthy, which would
    # make the emptiness check below unreachable.
    files = [file for file in candidates if file.endswith((".json", ".jsonl"))]
    if not files:
        raise ValueError(f"No JSON/JSONL files found in the specified path: {cloud_path}.")

    records: list[Any] = []
    for file in files:
        content = _read_json_with_fs(fs, file)
        if not isinstance(content, list):
            raise TypeError(f"Expected a list of records in {file}, got {type(content).__name__}.")

        # Extending in place avoids the quadratic copying of `sum(lists, [])`.
        records.extend(content)

    return records
def has_tokenized_data(path: "os.PathLike") -> bool:
    r"""Tell whether `path` is an existing directory that contains at least one entry."""
    if not os.path.isdir(path):
        return False

    entries = os.listdir(path)
    return len(entries) > 0
def check_version(requirement: str, mandatory: bool = False) -> None:
    r"""Optionally check the package version.

    Args:
        requirement: a pip-style requirement string, e.g. `modelscope>=1.14.0`.
        mandatory: when true the check cannot be skipped through the
            `DISABLE_VERSION_CHECK` environment variable.

    The check itself is delegated to `require_version`, which receives the
    requirement and a hint describing how to fix a mismatch.
    """
    # Neither `is_env_enabled` nor `require_version` is imported in this module, so the
    # environment flag is read directly and `require_version` is imported locally.
    # NOTE(review): the accepted truthy values ("true", "y", "1") are assumed — confirm
    # against the project's own `is_env_enabled` helper if one exists.
    version_check_disabled = os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ("true", "y", "1")
    if version_check_disabled and not mandatory:
        logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
        return

    from transformers.utils.versions import require_version

    # These packages are installed without build isolation.
    if "gptmodel" in requirement or "autoawq" in requirement:
        pip_command = f"pip install {requirement} --no-build-isolation"
    else:
        pip_command = f"pip install {requirement}"

    if mandatory:
        hint = f"To fix: run `{pip_command}`."
    else:
        hint = f"To fix: run `{pip_command}` or set `DISABLE_VERSION_CHECK=1` to skip this check."

    require_version(requirement, hint)
def _load_single_dataset(
    dataset_attr: "DatasetAttr",
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "TrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
    r"""Load a single dataset and align it to the standard format.

    The source is selected by `dataset_attr.load_from`: a hub (`hf_hub`, `ms_hub`,
    `om_hub`), a loading `script`, a `cloud_file`, or a local `file`/directory under
    `data_args.dataset_dir`. After loading, the dataset is optionally resampled to
    `dataset_attr.num_samples` (non-streaming only), truncated to
    `data_args.max_samples`, and finally passed through `align_dataset`.

    Raises:
        ValueError: if a local path is missing, contains no files, has an
            unsupported extension, or mixes file types.
        NotImplementedError: if `dataset_attr.load_from` is unknown.
    """
    logger.info_rank0(f"Loading dataset {dataset_attr}...")
    data_path, data_name, data_dir, data_files = None, None, None, None
    if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]:
        data_path = dataset_attr.dataset_name
        data_name = dataset_attr.subset
        data_dir = dataset_attr.folder

    elif dataset_attr.load_from == "script":
        data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
        data_name = dataset_attr.subset
        data_dir = dataset_attr.folder

    elif dataset_attr.load_from == "cloud_file":
        data_path = dataset_attr.dataset_name

    elif dataset_attr.load_from == "file":
        local_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
        if os.path.isdir(local_path):
            data_files = [os.path.join(local_path, file_name) for file_name in os.listdir(local_path)]
        elif os.path.isfile(local_path):
            data_files = [local_path]
        else:
            raise ValueError(f"File {local_path} not found.")

        # An empty directory would otherwise surface as a bare IndexError below.
        if not data_files:
            raise ValueError(f"No data files found in {local_path}.")

        # The loader type is derived from the extension of the first file; all files must agree.
        data_path = FILEEXT2TYPE.get(os.path.splitext(data_files[0])[-1][1:], None)
        if data_path is None:
            raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))

        if any(data_path != FILEEXT2TYPE.get(os.path.splitext(data_file)[-1][1:], None) for data_file in data_files):
            raise ValueError("File types should be identical.")

    else:
        raise NotImplementedError(f"Unknown load type: {dataset_attr.load_from}.")

    if dataset_attr.load_from == "ms_hub":
        check_version("modelscope>=1.14.0", mandatory=True)
        from modelscope import MsDataset
        from modelscope.utils.config_ds import MS_DATASETS_CACHE

        cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
        dataset = MsDataset.load(
            dataset_name=data_path,
            subset_name=data_name,
            data_dir=data_dir,
            data_files=data_files,
            split=dataset_attr.split,
            cache_dir=cache_dir,
            token=model_args.ms_hub_token,
            use_streaming=data_args.streaming,
        )
        if isinstance(dataset, MsDataset):
            dataset = dataset.to_hf_dataset()

    elif dataset_attr.load_from == "om_hub":
        check_version("openmind>=0.8.0", mandatory=True)
        from openmind import OmDataset
        from openmind.utils.hub import OM_DATASETS_CACHE

        cache_dir = model_args.cache_dir or OM_DATASETS_CACHE
        dataset = OmDataset.load_dataset(
            path=data_path,
            name=data_name,
            data_dir=data_dir,
            data_files=data_files,
            split=dataset_attr.split,
            cache_dir=cache_dir,
            token=model_args.om_hub_token,
            streaming=data_args.streaming,
        )

    elif dataset_attr.load_from == "cloud_file":
        dataset = Dataset.from_list(read_cloud_json(data_path), split=dataset_attr.split)

    else:
        # Local files are always loaded as a map-style dataset first and converted to an
        # iterable dataset afterwards when streaming is requested.
        dataset = load_dataset(
            path=data_path,
            name=data_name,
            data_dir=data_dir,
            data_files=data_files,
            split=dataset_attr.split,
            cache_dir=model_args.cache_dir,
            token=model_args.hf_hub_token,
            num_proc=data_args.preprocessing_num_workers,
            streaming=data_args.streaming and dataset_attr.load_from != "file",
        )
        if data_args.streaming and dataset_attr.load_from == "file":
            dataset = dataset.to_iterable_dataset(num_shards=training_args.dataloader_num_workers)

    if dataset_attr.num_samples is not None and not data_args.streaming:
        target_num = dataset_attr.num_samples
        # Take up to `num_samples` distinct examples first ...
        indexes = np.random.permutation(len(dataset))[:target_num]
        target_num -= len(indexes)
        if target_num > 0:
            # ... then draw the remainder with replacement when the dataset is too small.
            expand_indexes = np.random.choice(len(dataset), target_num)
            indexes = np.concatenate((indexes, expand_indexes), axis=0)

        assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched."
        dataset = dataset.select(indexes)
        logger.info_rank0(f"Sampled {dataset_attr.num_samples} examples from dataset {dataset_attr}.")

    if data_args.max_samples is not None:
        if data_args.streaming:
            # Iterable datasets support neither `len()` nor `select()`; take a prefix instead.
            dataset = dataset.take(data_args.max_samples)
        else:
            max_samples = min(data_args.max_samples, len(dataset))
            dataset = dataset.select(range(max_samples))

    return align_dataset(dataset, dataset_attr, data_args, training_args)
def _get_merged_dataset(
    dataset_names: Optional[list[str]],
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "TrainingArguments",
    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
    return_dict: bool = False,
) -> Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]:
    r"""Load every named dataset and return them merged, or as a name-to-dataset dict.

    Returns `None` when no dataset names are given. With `return_dict=True` the
    loaded datasets are returned keyed by name; otherwise they are merged with
    `merge_dataset` using `training_args.seed`.

    Raises:
        ValueError: if a dataset's `ranking` flag does not fit the training stage.
    """
    if dataset_names is None:
        return None

    dataset_attrs = get_dataset_list(dataset_names, data_args.dataset_dir)
    loaded = {}
    for name, attr in zip(dataset_names, dataset_attrs):
        # Ranking datasets belong to the "rm" stage only, and the "rm" stage accepts nothing else.
        if stage == "rm":
            unsuitable = attr.ranking is False
        else:
            unsuitable = attr.ranking is True

        if unsuitable:
            raise ValueError("The dataset is not applicable in the current training stage.")

        loaded[name] = _load_single_dataset(attr, model_args, data_args, training_args)

    if not return_dict:
        return merge_dataset(list(loaded.values()), data_args, seed=training_args.seed)

    return loaded
def _get_dataset_processor(
    data_args: "DataArguments",
    stage: Literal["pt", "sft"],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    do_generate: bool = False,
) -> "DatasetProcessor":
    r"""Return the dataset processor that matches the training stage.

    Args:
        data_args: supplies the `packing` and `neat_packing` switches.
        stage: `pt` selects the pretraining processor; `sft` selects the packed or
            plain SFT processor depending on `data_args.packing`.
        template: forwarded to the processor.
        tokenizer: forwarded to the processor.
        do_generate: generation-style processing; not supported by any stage handled here.

    Raises:
        RuntimeError: if the stage is unsupported, or if `do_generate` is requested
            for the `sft` stage.
    """
    if stage == "pt":
        dataset_processor_class = PretrainDatasetProcessor
    elif stage == "sft":
        if do_generate:
            # Previously this case fell through to the "unsupported stage" error, which
            # wrongly blamed the stage even though it was 'sft'.
            raise RuntimeError(
                "The 'sft' stage does not support `do_generate=True`: "
                "only the training-style SFT dataset processors are available."
            )

        if data_args.packing:
            if data_args.neat_packing:
                from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence

                # Replace `OptimizedTypedSequence.__init__` process-wide with a version that
                # delegates to `TypedSequence.__init__`, forwarding only `type`, `try_type`
                # and `optimized_int_type` and ignoring any other keyword argument.
                # NOTE(review): presumably this bypasses the optimized per-column type
                # selection for packed data — confirm against the `datasets` source.
                def __init__(self, data, **kwargs):
                    return TypedSequence.__init__(
                        self,
                        data,
                        type=kwargs.pop("type", None),
                        try_type=kwargs.pop("try_type", None),
                        optimized_int_type=kwargs.pop("optimized_int_type", None),
                    )

                OptimizedTypedSequence.__init__ = __init__

            dataset_processor_class = PackedSFTDatasetProcessor
        else:
            dataset_processor_class = SFTDatasetProcessor
    else:
        raise RuntimeError("The stage parameter currently is only supported to be set as 'pt' or 'sft', please check it")

    return dataset_processor_class(template=template, tokenizer=tokenizer, data_args=data_args)
def _get_preprocessed_dataset(
    dataset: Optional[Union["Dataset", "IterableDataset"]],
    data_args: "DataArguments",
    training_args: "TrainingArguments",
    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    is_eval: bool = False,
) -> Optional[Union["Dataset", "IterableDataset"]]:
    r"""Run the stage-specific processor over the dataset and return the result.

    Returns `None` when no dataset is given. The original columns (taken from the
    first example) are removed by the batched `map` call. On global rank 0 one
    processed example is printed.

    Raises:
        RuntimeError: on rank 0, if the processed dataset yields no example.
    """
    if dataset is None:
        return None

    # `is_eval` is forwarded as the processor's `do_generate` flag.
    processor = _get_dataset_processor(data_args, stage, template, tokenizer, do_generate=is_eval)

    first_raw_example = next(iter(dataset))
    original_columns = list(first_raw_example.keys())

    # Caching, multiprocessing and the progress description only apply to map-style datasets.
    map_options = {}
    if not data_args.streaming:
        map_options["num_proc"] = data_args.preprocessing_num_workers
        map_options["load_from_cache_file"] = (not data_args.overwrite_cache) or (get_local_rank() != 0)
        map_options["desc"] = "Running tokenizer on dataset"

    dataset = dataset.map(
        processor.preprocess_dataset,
        batched=True,
        batch_size=data_args.preprocessing_batch_size,
        remove_columns=original_columns,
        **map_options,
    )

    # Only the global main process prints a sample.
    if dist.get_rank() != 0:
        return dataset

    try:
        header = "eval example:" if is_eval else "training example:"
        print(header)
        processor.print_data_example(next(iter(dataset)))
    except StopIteration:
        if stage == "pt":
            raise RuntimeError("Cannot find sufficient samples, consider increasing dataset size.")

        raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")

    return dataset
@contextlib.contextmanager
def main_process_first(local=True, desc="work"):
    r"""
    Context manager that lets the main process run the wrapped body before the other processes.

    When the world size is greater than one, every non-main process waits on a barrier over the
    `dp_fsdp` group before running the body, and the main process joins that barrier once its
    own body has finished (also when the body raises). With a world size of one the body simply
    runs.

    A typical use is `datasets`'s `map`: the main process does the work and writes the cache,
    and the replicas then load the cached result.

    Args:
        local (`bool`, *optional*, defaults to `True`):
            if `True`, the main process is the one whose local rank is 0; if `False`, it is the
            one whose global rank is 0.
        desc (`str`, *optional*, defaults to `"work"`):
            a work description used in the printed messages
    """
    # Single-process run: nothing to synchronise.
    if not dist.get_world_size() > 1:
        yield
        return

    if local:
        main_process_desc = "main local process"
        is_main_process = get_local_rank() == 0
    else:
        main_process_desc = "main process"
        is_main_process = dist.get_rank() == 0

    ps = ParallelState()
    try:
        if not is_main_process:
            # Replicas block here until the main process reaches its barrier below.
            print(f"waiting for the {main_process_desc} to perform {desc}")
            dist.barrier(group=ps.get_group('dp_fsdp'))
        yield
    finally:
        if is_main_process:
            # The main process is done: release the waiting replicas.
            print(f"{main_process_desc} completed {desc}, releasing all replicas")
            dist.barrier(group=ps.get_group('dp_fsdp'))
def get_dataset(
    template: "Template",
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "TrainingArguments",
    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
    tokenizer: "PreTrainedTokenizer"
) -> "DatasetModule":
    r"""Get the train dataset and optionally the evaluation dataset.

    If `data_args.tokenized_path` already holds data, it is loaded from disk and returned
    directly. Otherwise the datasets named in `data_args.dataset` and
    `data_args.eval_dataset` are loaded, preprocessed and split; when `tokenized_path`
    is set, the result is also saved there by the main process.

    Raises:
        ValueError: if `tokenized_path` is set but empty while streaming is enabled.
    """
    # Load a previously tokenized dataset when one is available.
    if data_args.tokenized_path is not None:
        if has_tokenized_data(data_args.tokenized_path):
            logger.info_rank0("Loading dataset from disk will ignore other data arguments.")
            tokenized_data = load_from_disk(data_args.tokenized_path)
            dataset_module = get_dataset_module(tokenized_data)
            # The saved data may lack a train split; only convert it when present.
            if data_args.streaming and "train_dataset" in dataset_module:
                dataset_module["train_dataset"] = dataset_module["train_dataset"].to_iterable_dataset()

            logger.info_rank0(f"Loaded tokenized dataset from {data_args.tokenized_path}.")
            return dataset_module

        if data_args.streaming:
            raise ValueError("Turn off `streaming` when saving dataset to disk.")

    # Without a shared filesystem, the main process of each node has to do the work.
    use_local_main = not data_args.data_shared_file_system

    with main_process_first(desc="load dataset", local=use_local_main):
        dataset = _get_merged_dataset(
            data_args.dataset,
            model_args,
            data_args,
            training_args,
            stage
        )
        eval_dataset = _get_merged_dataset(
            data_args.eval_dataset,
            model_args,
            data_args,
            training_args,
            stage,
            return_dict=data_args.eval_on_each_dataset,
        )

    with main_process_first(desc="pre-process dataset", local=use_local_main):
        dataset = _get_preprocessed_dataset(
            dataset, data_args, training_args, stage, template, tokenizer, is_eval=False
        )
        if isinstance(eval_dataset, dict):
            for eval_name, eval_data in eval_dataset.items():
                eval_dataset[eval_name] = _get_preprocessed_dataset(
                    eval_data, data_args, training_args, stage, template, tokenizer, is_eval=True
                )
        else:
            eval_dataset = _get_preprocessed_dataset(
                eval_dataset, data_args, training_args, stage, template, tokenizer, is_eval=True
            )

        dataset_dict = split_dataset(dataset, eval_dataset, data_args, seed=training_args.seed)
        if data_args.tokenized_path is not None:
            # Only one writer per filesystem: letting every rank write the same path
            # concurrently risks clobbering the files. The writer is chosen with the same
            # rule `main_process_first` uses.
            is_main_process = (get_local_rank() == 0) if use_local_main else (dist.get_rank() == 0)
            if is_main_process:
                dataset_dict.save_to_disk(data_args.tokenized_path)

            logger.info_rank0(f"Tokenized dataset is saved at {data_args.tokenized_path}.")
            logger.info_rank0(f"Please launch the training with `tokenized_path: {data_args.tokenized_path}`.")

        return get_dataset_module(dataset_dict)