# Copyright (c) 2025, HUAWEI CORPORATION.  All rights reserved.
import copy
import multiprocessing
import os
import sys
from pathlib import Path
from typing import Any, List, Tuple

import hydra
from omegaconf import DictConfig

from mindspeed_rl.config_cls.data_handler_config import DataHandlerConfig
from mindspeed_rl.config_cls.validate_config import validate_data_handler_config
from mindspeed_rl.datasets.indexed_dataset import IndexedDatasetBuilder
from mindspeed_rl.datasets.preprocess_data import (
    merge_datasets,
    build_splitter,
    cut_range_to_subs,
    handle_subset
)
from mindspeed_rl.utils.tokenizer import get_tokenizer
from mindspeed_rl.datasets.data_handler import build_dataset, get_dataset_handler
from mindspeed_rl.utils.loggers import Loggers

# Logger instance for data processing
logger = Loggers(name="process_data")

# Absolute directory containing this script; anchor for the paths below.
cur_file_dir = Path(__file__).absolute().parent

# Path to the model templates configuration file (a JSON file, despite the
# "_DIR" suffix in its name).
# NOTE(review): this is resolved relative to this script's own directory,
# while the Hydra configs below come from "../configs/datasets" -- confirm
# which location is intended.
TEMPLATES_DIR = os.path.join(cur_file_dir, "./configs/model/templates.json")

# The Hydra config name is the first command-line argument. It is popped off
# sys.argv here because @hydra.main (applied to `main` further down) needs it
# at decoration time; the remaining arguments are left for Hydra.
# NOTE: this runs whenever the module is imported, not only when it is
# executed as a script.
if len(sys.argv) < 2:
    sys.exit(
        "Missing required argument <config_name>: the name of a config under "
        "configs/datasets, optionally followed by Hydra overrides."
    )
config_name = sys.argv.pop(1)

# Base directory (parent of this script's directory, symlinks resolved) used
# to resolve relative paths from the config and to confine them.
base_dir = os.path.realpath(os.path.join(cur_file_dir, ".."))


def _resolve_within_base_dir(path: str) -> str:
    """Resolve a relative ``path`` against ``base_dir`` and confine it there.

    Args:
        path: A path interpreted relative to ``base_dir``.

    Returns:
        The canonical absolute path, with symlinks and ``..`` resolved.

    Raises:
        ValueError: If the resolved path lies outside ``base_dir``.
    """
    resolved = os.path.realpath(os.path.join(base_dir, path))
    # Compare whole path components. A plain ``startswith`` check would accept
    # sibling directories that merely share a prefix (``/repo2`` vs ``/repo``).
    try:
        inside = os.path.commonpath([base_dir, resolved]) == base_dir
    except ValueError:
        # commonpath raises when the paths share no common root (e.g. different
        # drives on Windows); such a path is certainly outside base_dir.
        inside = False
    if not inside:
        raise ValueError(
            f"Invalid path: {resolved} is not within the allowed directory {base_dir}"
        )
    return resolved


def resolve_relative_path(args: DataHandlerConfig) -> None:
    """Resolve relative paths to absolute paths within allowed directory.

    Converts relative ``input``, ``tokenizer_name_or_path`` and
    ``output_prefix`` values to absolute paths under ``base_dir`` (in place on
    ``args``). It verifies that each resolved path stays inside ``base_dir``,
    which prevents directory traversal via ``..`` or symlinks.

    Absolute paths are accepted unchanged. ``None`` values are left untouched
    (presumably for ``validate_data_handler_config`` to handle). A relative
    ``tokenizer_name_or_path`` is always treated as a local path under
    ``base_dir``.

    Args:
        args: Data handler configuration object containing paths to resolve.

    Raises:
        ValueError: If a resolved path is outside the allowed base directory.
    """
    for attr in ("input", "tokenizer_name_or_path", "output_prefix"):
        value = getattr(args, attr)
        if value is None or os.path.isabs(value):
            continue
        setattr(args, attr, _resolve_within_base_dir(value))


def _subset_prefix_name(index: int, last_index: int, prefix_name: str) -> str:
    """Build the file-name prefix for the output of subset ``index``.

    Produces e.g. ``"001_of_003_<prefix_name>"``. Indices are zero-padded to at
    least three digits (the historical format) and wider when there are more
    than 1000 subsets, so the names keep sorting in subset order.
    """
    width = max(3, len(str(last_index)))
    return f"{str(index).zfill(width)}_of_{str(last_index).zfill(width)}_{prefix_name}"


def _strip_idx_suffix(path: str) -> str:
    """Return ``path`` without its trailing ``.idx`` extension.

    Unlike ``path.replace('.idx', '')``, this never alters a ``.idx`` that
    occurs elsewhere in the path (e.g. inside a directory name).
    """
    return path[:-len(".idx")] if path.endswith(".idx") else path


def _merge_subset_outputs(
        sub_idx_files: List[dict],
        first_prefix_name: str,
        target_prefixname: str,
) -> None:
    """Merge the per-subset indexed datasets, then delete the per-subset files.

    Produces one merged dataset per output key.

    Args:
        sub_idx_files: One mapping per subset, in subset order, from output key
            to that subset's ``.idx`` file path (as returned by ``handle_subset``).
        first_prefix_name: File-name prefix used for subset 0's outputs.
        target_prefixname: File-name prefix of the merged output.
    """
    for key in sub_idx_files[0]:
        # pool.map returns results in input order, so these are already in
        # subset order.
        idx_files = [files[key] for files in sub_idx_files]
        first_dir, first_name = os.path.split(idx_files[0])
        # Swap subset 0's prefix for the final prefix, in the file name only.
        target_idx = os.path.join(
            first_dir, first_name.replace(first_prefix_name, target_prefixname, 1)
        )
        target_bin = _strip_idx_suffix(target_idx) + ".bin"

        builder = IndexedDatasetBuilder(target_bin)
        for idx_file in idx_files:
            builder.add_index(_strip_idx_suffix(idx_file))
        builder.finalize(target_idx)

        # Clean up temporary files
        for idx_file in idx_files:
            os.remove(idx_file)
            os.remove(_strip_idx_suffix(idx_file) + ".bin")


def _preprocess_subsets_in_parallel(
        args: DataHandlerConfig, raw_data: Any, tokenizer: Any, splitter: Any
) -> None:
    """Process ``raw_data`` in parallel subsets and merge the results.

    The dataset is cut into chunks of ``len(raw_data) // args.n_subs`` samples.
    ``cut_range_to_subs`` decides how any remainder is handled. Each chunk is
    processed by ``handle_subset`` in a process pool, and the per-subset
    outputs are merged under ``args.output_prefix``.

    Raises:
        ValueError: If ``args.n_subs`` is not positive or the dataset is empty.
    """
    if args.n_subs < 1:
        raise ValueError(f"n_subs must be a positive integer, got {args.n_subs}")
    num_samples = len(raw_data)
    if num_samples == 0:
        raise ValueError(f"Cannot split an empty dataset into subsets: {args.input}")

    target_dir, target_prefixname = os.path.split(args.output_prefix)

    # Clamp to 1 so that n_subs > num_samples cannot produce a zero-sized chunk.
    chunk_size = max(1, num_samples // args.n_subs)
    start_ends = cut_range_to_subs(num_samples, chunk_size)
    subsets = [raw_data.select(range(start, end)) for start, end in start_ends]
    last_index = len(subsets) - 1

    # Multi-processing setup
    params_list: List[List[Any]] = []
    for k, subset in enumerate(subsets):
        args_ = copy.deepcopy(args)
        # Rewrite only the file-name component. str.replace on the full path
        # would also rename any directory that shares the prefix's name.
        args_.output_prefix = os.path.join(
            target_dir, _subset_prefix_name(k, last_index, target_prefixname)
        )
        params_list.append([args_, subset, tokenizer, splitter])

    # The with-block tears the pool down even if a worker raises.
    with multiprocessing.Pool() as pool:
        sub_idx_files = pool.map(handle_subset, params_list)
        pool.close()
        pool.join()

    _merge_subset_outputs(
        sub_idx_files,
        _subset_prefix_name(0, last_index, target_prefixname),
        target_prefixname,
    )


def preprocess(config: DictConfig) -> None:
    """Execute data preprocessing pipeline.

    Main preprocessing function that handles dataset loading, tokenization,
    splitting, and serialization. It does one of two things:
    - If ``merge_group_keys`` is set, it merges existing datasets.
    - Otherwise it processes a single dataset, in one process when
      ``n_subs == 1`` and in parallel subsets when ``n_subs != 1``.

    Args:
        config: Hydra/OmegaConf configuration object (DictConfig) containing
                preprocessing parameters. This is converted to DataHandlerConfig
                internally for structured access.

    Raises:
        ValueError: If a configured relative path escapes the base directory,
            or (in the parallel path) if ``n_subs`` is not positive or the
            dataset is empty.
    """
    args = DataHandlerConfig(config)
    resolve_relative_path(args)
    validate_data_handler_config(args)

    if args.merge_group_keys is not None:
        merge_datasets(args)
        return

    tokenizer = get_tokenizer(
        args.tokenizer_name_or_path,
        prompt_type=args.prompt_type,
        prompt_type_path=args.prompt_type_path,
        enable_thinking=args.enable_thinking
    )
    splitter = build_splitter(args)

    logger.info(f"building dataset: {args.input}")
    raw_data = build_dataset(args)

    if args.n_subs == 1:
        handler = get_dataset_handler(args, raw_data, tokenizer, splitter)
        # Serialize to bin and idx files
        handler.serialize_to_disk()
    else:
        _preprocess_subsets_in_parallel(args, raw_data, tokenizer, splitter)


@hydra.main(config_path="../configs/datasets", config_name=config_name)
def main(config: DictConfig) -> None:
    """Command-line entry point, wrapped by Hydra.

    Hydra composes the config named by the first command-line argument
    (``config_path="../configs/datasets"``) and passes it in here. All real
    work is delegated to :func:`preprocess`.

    Args:
        config: The composed Hydra/OmegaConf configuration.
    """
    return preprocess(config)


if __name__ == '__main__':
    main()