import os
import sys
import time
import torch
import glob
import os, sys, time, subprocess
from typing import Optional, List, Tuple
import torch.distributed as dist
from mindspeed_llm.fsdp2.data.megatron_data.megatron_dataset_handler import _get_data_format
from mindspeed_llm.fsdp2.data.megatron_data.megatron_utils import is_shared_path
from mindspeed_llm.fsdp2.utils.logging import get_logger
from mindspeed_llm.fsdp2.data.megatron_data.megatron_gpt_dataset import MockGPTDataset, GPTDataset
from mindspeed_llm.fsdp2.data.megatron_data.megatron_gpt_dataset import GPTDatasetConfig
from mindspeed_llm.fsdp2.data.megatron_data.indexed_dataset import IndexedDataset
from mindspeed_llm.fsdp2.data.megatron_data.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
from mindspeed_llm.fsdp2.data.megatron_data.megatron_tokenizer import _AutoTokenizer
from mindspeed_llm.fsdp2.distributed.parallel_state import ParallelState
logger = get_logger(__name__)
def convert_datasets(model_args, data_args, shared: bool, ):
IDX_EXT = ".idx"
BIN_EXT = ".bin"
was_list = isinstance(data_args.dataset['file_name'], (list, tuple))
paths = [str(p).strip() for p in data_args.dataset['file_name']] if was_list else [
p.strip() for p in str(data_args.dataset['file_name']).split(",") if p.strip()
]
if not paths:
return
rank = dist.get_rank() if dist.is_initialized() else 0
if "LOCAL_RANK" in os.environ:
local_rank = int(os.environ["LOCAL_RANK"])
else:
local_rank = rank % max(1, (torch.cuda.device_count() if torch.cuda.is_available() else 1))
should_convert = (rank == 0) if shared else (local_rank == 0)
out_map = {}
user_out = getattr(data_args, "output_prefix", None)
for raw in paths:
p = raw.strip().strip('"').strip("'")
if os.path.isfile(p):
auto_prefix = os.path.splitext(p)[0]
raw_base = os.path.splitext(os.path.basename(p))[0]
elif os.path.isdir(p):
auto_prefix = os.path.join(p, os.path.basename(os.path.normpath(p)))
raw_base = os.path.basename(os.path.normpath(p))
else:
raise FileNotFoundError(f"[DataConvert] Expected raw file/dir but got: {p}")
if user_out:
user_prefix = str(user_out).strip().strip('"').strip("'")
if len(paths) == 1:
out_prefix = user_prefix
else:
out_prefix = f"{user_prefix}_{raw_base}"
else:
out_prefix = auto_prefix
os.makedirs(os.path.dirname(out_prefix) or ".", exist_ok=True)
out_map[p] = {
"out_prefix": out_prefix,
"base": out_prefix,
}
if should_convert:
for raw in paths:
p = raw.strip().strip('"').strip("'")
meta = out_map[p]
out_prefix = meta["out_prefix"]
logger.info_rank0(f"[DataConvert] Converting: {p} -> {out_prefix}")
cmd = [
sys.executable, os.path.abspath("mindspeed_llm/fsdp2/data/megatron_data/megatron_preprocess_dataset.py"),
"--input", p,
"--tokenizer-type", "PretrainedFromHF",
"--handler-name", "GeneralPretrainHandler",
"--output-prefix", out_prefix,
"--workers", "4",
"--log-interval", "1000",
"--n-subs", "1",
]
cmd += ["--json-keys"] + list(["text"])
if getattr(model_args, "tokenizer_model", False):
cmd += ["--tokenizer-model", str(data_args.tokenizer_model)]
if getattr(model_args, "model_name_or_path", False):
cmd += ["--model-name-or-path", str(model_args.model_name_or_path)]
if getattr(data_args, "append_eod", False):
cmd.append("--append-eod")
if getattr(data_args, "enable_thinking", None) is not None:
cmd += ["--enable-thinking", str(data_args.enable_thinking)]
if getattr(data_args, "prompt_type", None):
cmd += ["--prompt-type", data_args.prompt_type]
if getattr(data_args, "cutoff_len", None):
cmd += ["--seq-length", str(data_args.cutoff_len)]
subprocess.run(cmd, check=True)
if dist.is_initialized():
dist.barrier()
new_paths = []
for raw in paths:
q = raw.strip().strip('"').strip("'")
if q not in out_map:
continue
meta = out_map[q]
base = meta["base"]
current_matches = []
if os.path.exists(base + IDX_EXT) and os.path.exists(base + BIN_EXT):
current_matches.append(base)
else:
dir_name = os.path.dirname(base) or "."
prefix_name = os.path.basename(base)
all_files = sorted(os.listdir(dir_name))
for f in all_files:
if (f.startswith(prefix_name) and "_text_document" in f) and f.endswith(IDX_EXT):
cand = os.path.join(dir_name, f[:-len(IDX_EXT)])
if os.path.exists(cand + BIN_EXT):
current_matches.append(cand)
if not current_matches:
raise FileNotFoundError(
f"[DataConvert] Missing output prefix for training: {base}[*_text_document or *_packed]"
)
new_paths.extend(current_matches)
data_args.dataset['file_name'] = new_paths
def _is_raw_data_path(path: str) -> bool:
"""Return True if the path is a raw file/dir recognizable by _get_data_format."""
p = str(path).strip().strip('"').strip("'")
if os.path.isfile(p):
data_files = [p]
elif os.path.isdir(p):
data_files = [os.path.join(p, f) for f in os.listdir(p)]
else:
return False
if not data_files:
return False
_, fmt = _get_data_format(data_files)
return fmt is not None
def get_document_dataset(model_args, data_args):
data_path = getattr(data_args, "dataset", None)['file_name']
if data_path:
if isinstance(data_path, (list, tuple)):
raw_path = str(data_path[0])
else:
raw_path = str(data_path).split(",")[0]
raw_path = raw_path.strip().strip('"').strip("'")
if _is_raw_data_path(raw_path):
logger.info_rank0("[InitHook] Megatron initialization completed. Starting data preprocessing...")
if os.path.isfile(raw_path):
base_dir = os.path.dirname(raw_path)
elif os.path.isdir(raw_path):
base_dir = raw_path
else:
base_dir = os.path.dirname(raw_path) or "."
shared = is_shared_path(base_dir)
convert_datasets(model_args, data_args, shared)
logger.info_rank0("[InitHook] Data preprocessing finished.")
else:
raise Exception(f"'{raw_path}' is incorrect!")
else:
raise Exception(f"'{data_path}' is empty!")
def get_blend_from_list(
blend: Optional[List[str]],
) -> Optional[Tuple[List[str], Optional[List[float]]]]:
"""Get the BlendedMegatronDatasetConfig blend from the blend list
Args:
blend (Optional[List[str]]): The blend list, which can be either (1) a list of prefixes, e.g. ["path/to/dataset_1_prefix", "path/to/dataset_2_prefix"], or (2) a flattened, zipped list of weights and prefixes, e.g. ["30", "path/to/dataset_1_prefix", "70", "path/to/dataset_2_prefix"]
Returns:
Optional[Tuple[List[str], Optional[List[float]]]]: The blend, consisting of a list of dataset prefixes and optionally a list of dataset weights, e.g. [["path/to/dataset_1_prefix", "path/to/dataset_2_prefix"], [30.0, 70.0]].
"""
if blend is None:
return None
if len(blend) % 2 == 1:
weight_per_dataset = None
raw_prefix_per_dataset = blend
else:
raw_weight_per_dataset, raw_prefix_per_dataset = zip(
*[(blend[i], blend[i + 1]) for i in range(0, len(blend), 2)]
)
weight_per_dataset = []
for rwpd in raw_weight_per_dataset:
try:
weight = float(rwpd)
except ValueError:
weight = None
weight_per_dataset.append(weight)
is_none = map(lambda _: _ is None, weight_per_dataset)
if any(is_none):
assert all(is_none)
weight_per_dataset = None
raw_prefix_per_dataset = blend
prefix_per_dataset = [rppd.strip() for rppd in raw_prefix_per_dataset]
return prefix_per_dataset, weight_per_dataset
def core_gpt_dataset_config_from_args(model_args, data_args, training_args):
tokenizer = _AutoTokenizer(
model_args.model_name_or_path,
vocab_extra_ids=0,
model_max_length=data_args.cutoff_len,
use_fast=False,
prompt_type=None,
)
return GPTDatasetConfig(
random_seed=training_args.seed,
sequence_length=data_args.cutoff_len,
blend=get_blend_from_list(data_args.dataset['file_name']),
blend_per_split=None,
split=data_args.split,
path_to_cache=None,
mmap_bin_files=True,
tokenizer=tokenizer,
reset_position_ids=False,
reset_attention_mask=data_args.reset_attention_mask,
eod_mask_loss=False,
create_attention_mask=data_args.create_attention_mask_in_dataloader,
)
def get_global_batch_size(training_args):
ps = ParallelState()
world_size = ps.get_group_size("dp_fsdp") if dist.is_initialized() else 1
global_batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * world_size
return global_batch_size
def is_dataset_built_on_rank():
ps = ParallelState()
return ps.get_fsdp_group_size() > 0
def get_train_valid_test_num_samples(training_args, data_args):
"""Train/valid/test num samples."""
global_batch_size = get_global_batch_size(training_args)
if training_args.max_steps > 0:
train_iters = training_args.max_steps
elif training_args.num_train_epochs > 0:
train_iters = 0
raw_paths = data_args.dataset['file_name'] if isinstance(data_args.dataset['file_name'], list) else str(data_args.dataset['file_name'] or "").split(',')
dataset_path = [
p.strip() for p in raw_paths
if p.strip() and os.path.isfile(f"{p.strip()}.bin") and os.path.isfile(f"{p.strip()}.idx")
]
for dataset in dataset_path:
train_iters += len(IndexedDataset(dataset, multimodal=False, mmap=True))
train_samples = train_iters * global_batch_size
return (
train_samples,
0,
0,
)
def train_valid_test_datasets_provider(model_args, data_args, training_args):
"""Build the train test and validation datasets.
Args:
train_val_test_num_samples : A list containing the number of samples in train test and validation.
"""
get_document_dataset(model_args, data_args)
config = core_gpt_dataset_config_from_args(model_args, data_args, training_args)
train_val_test_num_samples = get_train_valid_test_num_samples(training_args, data_args)
if config.mock:
dataset_type = MockGPTDataset
else:
dataset_type = GPTDataset
logger.info_rank0("> building train, validation, and test datasets for GPT ...")
train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder(
dataset_type,
train_val_test_num_samples,
is_dataset_built_on_rank,
config
).build()
logger.info_rank0("> finished creating GPT datasets ...")
return train_ds, valid_ds, test_ds