import copy
import multiprocessing
import os
import sys
from pathlib import Path
from typing import Any, List, Tuple
import hydra
from omegaconf import DictConfig
from mindspeed_rl.config_cls.data_handler_config import DataHandlerConfig
from mindspeed_rl.config_cls.validate_config import validate_data_handler_config
from mindspeed_rl.datasets.indexed_dataset import IndexedDatasetBuilder
from mindspeed_rl.datasets.preprocess_data import (
merge_datasets,
build_splitter,
cut_range_to_subs,
handle_subset
)
from mindspeed_rl.utils.tokenizer import get_tokenizer
from mindspeed_rl.datasets.data_handler import build_dataset, get_dataset_handler
from mindspeed_rl.utils.loggers import Loggers
# Module-level logger for this script, registered under the name "process_data".
logger = Loggers(name="process_data")

# Absolute path of the directory that contains this file (a pathlib.Path).
cur_file_dir = Path(__file__).absolute().parent

# Path string to configs/model/templates.json under this file's directory.
# NOTE(review): despite the "_DIR" suffix this points at a file, and it is not
# referenced anywhere in the visible part of this module — confirm it is used.
TEMPLATES_DIR = os.path.join(cur_file_dir, "./configs/model/templates.json")

# The first command-line argument is removed from sys.argv at import time and
# later passed to @hydra.main as the config name. list.pop(1) raises IndexError
# when the script is started without any argument.
# Presumably it is popped so Hydra does not see it as an override — TODO confirm.
config_name = sys.argv.pop(1)

# Parent of this file's directory with symlinks resolved; relative paths in the
# configuration are resolved against it by resolve_relative_path().
base_dir = os.path.realpath(os.path.join(cur_file_dir, ".."))
def _is_within_directory(path: str, directory: str) -> bool:
    """Return True if *path* is *directory* itself or lies beneath it.

    The comparison is done on whole path components, so ``/srv/app_evil`` is
    not considered to be inside ``/srv/app`` (a plain ``str.startswith`` test
    would accept it).
    """
    try:
        common = os.path.commonpath([path, directory])
    except ValueError:
        # commonpath raises when the paths cannot be compared at all
        # (e.g. different drives on Windows); such a path is not inside.
        return False
    return os.path.normcase(common) == os.path.normcase(directory)


def resolve_relative_path(args: DataHandlerConfig) -> None:
    """Resolve relative paths to absolute paths within the allowed directory.

    For each of ``args.input``, ``args.tokenizer_name_or_path`` and
    ``args.output_prefix`` that is not already absolute, the value is joined
    onto the module-level ``base_dir``, canonicalised with
    ``os.path.realpath`` (which also collapses ``..`` and follows symlinks)
    and written back onto ``args``. The resolved path must be ``base_dir``
    itself or lie beneath it. Absolute paths are left untouched and are not
    checked.

    Args:
        args: Data handler configuration object whose path attributes are
            updated in place.

    Raises:
        ValueError: If a relative path resolves to a location outside
            ``base_dir``.
    """
    for attr_name in ("input", "tokenizer_name_or_path", "output_prefix"):
        current = getattr(args, attr_name)
        if os.path.isabs(current):
            continue

        resolved = os.path.realpath(os.path.join(base_dir, current))
        # The attribute is updated before validation, so on failure ``args``
        # already carries the resolved (rejected) path.
        setattr(args, attr_name, resolved)

        if not _is_within_directory(resolved, base_dir):
            raise ValueError(
                f"Invalid path: {resolved} is not within the allowed directory {base_dir}"
            )
def preprocess(config: DictConfig) -> None:
    """Execute the data preprocessing pipeline.

    The configuration is wrapped in a ``DataHandlerConfig``, its relative
    paths are resolved, and it is validated. After that one of three paths is
    taken:

    * ``merge_group_keys`` is set: ``merge_datasets`` is called and nothing
      else happens.
    * ``n_subs == 1``: a single dataset handler is built and serialized.
    * otherwise: the dataset is cut into chunks, each chunk is processed by
      ``handle_subset`` in a worker process, and the per-chunk ``.idx`` /
      ``.bin`` files are merged into one pair per key and then deleted.

    Args:
        config: Hydra/OmegaConf configuration object (DictConfig) containing
            preprocessing parameters. It is converted to DataHandlerConfig
            internally for structured access.
    """
    args = DataHandlerConfig(config)
    resolve_relative_path(args)
    validate_data_handler_config(args)

    if args.merge_group_keys is not None:
        merge_datasets(args)
        return

    tokenizer = get_tokenizer(
        args.tokenizer_name_or_path,
        prompt_type=args.prompt_type,
        prompt_type_path=args.prompt_type_path,
        enable_thinking=args.enable_thinking
    )
    splitter = build_splitter(args)

    logger.info(f"building dataset: {args.input}")
    raw_data = build_dataset(args)

    if args.n_subs == 1:
        handler = get_dataset_handler(args, raw_data, tokenizer, splitter)
        handler.serialize_to_disk()
        return

    target_prefix = args.output_prefix
    target_dir, target_prefixname = os.path.split(target_prefix)

    num_samples = len(raw_data)
    # NOTE(review): when the dataset has fewer samples than n_subs the chunk
    # size passed here is 0 — confirm that validation or cut_range_to_subs
    # handles that case.
    start_ends = cut_range_to_subs(num_samples, num_samples // args.n_subs)
    subsets = [raw_data.select(range(x[0], x[1])) for x in start_ends]

    # Chunk files are named "<k>_of_<last>_<basename>", both zero-padded to
    # three digits. The count comes from the chunks actually produced.
    last_tag = str(len(subsets) - 1).zfill(3)

    params_list: List[List[Any]] = []
    for k, subset in enumerate(subsets):
        args_ = copy.deepcopy(args)
        # Only the final path component is renamed. A plain str.replace on the
        # whole prefix would also rewrite parent directories that happen to
        # contain the basename (e.g. "/out/alpaca/alpaca").
        args_.output_prefix = os.path.join(
            target_dir,
            f'{str(k).zfill(3)}_of_{last_tag}_{target_prefixname}'
        )
        params_list.append([args_, subset, tokenizer, splitter])

    # The context manager guarantees the workers are torn down even when
    # pool.map raises; on success the pool is still closed and joined first.
    with multiprocessing.Pool() as pool:
        sub_idx_files = pool.map(handle_subset, params_list)
        pool.close()
        pool.join()

    # Each worker result is indexed by key; for every key, merge the chunk
    # index files into one target index/bin pair.
    for key in sub_idx_files[0].keys():
        idx_files = [x[key] for x in sub_idx_files]
        idx_files.sort()
        # After sorting, the first entry is chunk 000; stripping its chunk tag
        # gives the name of the merged index file.
        target_idx = idx_files[0].replace(
            f'000_of_{last_tag}_{target_prefixname}',
            target_prefixname
        )
        target_bin = target_idx.replace('.idx', '.bin')
        idx = IndexedDatasetBuilder(target_bin)
        for idx_file in idx_files:
            idx.add_index(idx_file.replace('.idx', ''))
        idx.finalize(target_idx)

        # The per-chunk files are no longer needed once merged.
        for idx_file in idx_files:
            os.remove(idx_file)
            os.remove(idx_file.replace('.idx', '.bin'))
@hydra.main(
    config_name=config_name,
    config_path="../configs/datasets",
)
def main(config: DictConfig) -> None:
    """Command-line entry point: hand the Hydra config to ``preprocess``.

    Args:
        config: DictConfig that Hydra loads from ``../configs/datasets``
            using the config name taken from the command line.
    """
    preprocess(config)


# Run the pipeline only when this file is executed as a script.
if __name__ == '__main__':
    main()