from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Optional
import logging

from mindspeed_mm.fsdp.data.data_utils.func_utils.convert import DatasetAttr
from mindspeed_mm.fsdp.data.data_utils.func_utils.convert import DataArguments as BasicDataAruments
from mindspeed_mm.fsdp.data.data_utils.func_utils.model_args import ProcessorArguments
from mindspeed_mm.config.arguments.base_args import BaseArguments

logger = logging.getLogger(__name__)


class DataSetArguments(BaseArguments):
    dataset_type: str = field(
        metadata={"help": "Type of dataset to use."}
    )
    basic_parameters: BasicDataAruments = field(default_factory=BasicDataAruments)
    preprocess_parameters: Optional[ProcessorArguments] = field(default_factory=ProcessorArguments)
    attr: DatasetAttr = field(default_factory=DatasetAttr)


class CollateArguments(BaseArguments):
    model_name: str = field(metadata={"help": "Name of the model for which collation is configured."})
    ignore_pad_token_for_loss: bool = field(
        default=False,
        metadata={"help": ""}
    )
    pad_to_multiple_of: int = field(
        default=8,
        metadata={"help": "Pad sequences to a multiple of this value for efficient processing."}
    )


class DataloaderArguments(BaseArguments):
    dataloader_mode: str = field(metadata={"help": "Mode of dataloader."})
    sampler_type: str = field(metadata={"help": "Type of sampler to use."})
    shuffle: Optional[bool] = field(metadata={"help": "Whether to shuffle the data during training."})
    drop_last: bool = field(metadata={"help": "Whether to drop the last incomplete batch if dataset size is not divisible by batch size."})
    pin_memory: bool = field(metadata={"help": "Whether to pin memory for faster data transfer to GPU."})
    collate_param: CollateArguments = field(default_factory=CollateArguments)
    num_workers: int = field(default=2, metadata={"help": "Number of worker processes for data loading."})
    enable_preload: bool = field(
        default=False,
        metadata={"help": "Whether to enable async data preloading to overlap CPU-H2D transfer with training."},
    )


class DataArguments(BaseArguments):
    dataset_param: DataSetArguments = field(default_factory=DataSetArguments)
    dataloader_param: DataloaderArguments = field(default_factory=DataloaderArguments)