from dataclasses import dataclass, field, fields, is_dataclass, MISSING
from typing import Any, Callable, Dict, List, Literal, Optional, TypeVar, Union, get_type_hints, get_origin, get_args
import logging
from functools import wraps

from mindspeed.fsdp.utils.log import print_rank

logger = logging.getLogger(__name__)


def create_nested_dataclass_field(
    parent_cls: type,
    field_name: str,
    nested_data: Dict[str, Any],
) -> None:
    """
    Attach a dynamically generated nested dataclass to ``parent_cls``.

    A dataclass named ``f"{field_name.capitalize()}Field"`` is generated from
    ``nested_data`` via ``create_nested_dataclass``. Then:

    1. ``parent_cls.<field_name>`` is set to a ``dataclasses.Field`` whose
       ``default_factory`` is the generated class.
    2. The generated class is recorded as the annotation of ``field_name`` in
       ``parent_cls.__annotations__``. An ``__annotations__`` dict owned by
       ``parent_cls`` itself is created first if none exists, so inherited
       annotations are never mutated.

    ``@dataclass`` is NOT re-applied to ``parent_cls``. Within this file the
    function is also called with a dataclass *instance* (from
    ``allow_extra_fields``), in which case both attributes are set on that
    instance.

    Args:
        parent_cls: Class (or instance) that receives the field attribute and annotation.
        field_name: Attribute name to create.
        nested_data: Mapping used to build the nested dataclass.
    """
    generated_cls = create_nested_dataclass(
        f"{field_name.capitalize()}Field", nested_data
    )
    setattr(parent_cls, field_name, field(default_factory=generated_cls))

    # Only create an own annotations dict when missing, to avoid writing into
    # an annotations mapping that belongs to something else.
    owns_annotations = '__annotations__' in parent_cls.__dict__
    if not owns_annotations:
        parent_cls.__annotations__ = {}
    parent_cls.__annotations__[field_name] = generated_cls


def create_nested_dataclass(cls_name: str, data: Dict[str, Any]) -> type:
    """
    Build a dataclass type named ``cls_name`` whose fields mirror ``data``.

    Each key of ``data`` becomes a field whose default is the matching value:

    * ``dict`` values become a nested dataclass, named ``f"{cls_name}_{key}"``
      and built recursively. It is used as the field's ``default_factory``.
    * Unhashable values (``list``, ``set``, ``bytearray``, list subclasses, ...)
      cannot be plain dataclass defaults. Each instance therefore receives its
      own shallow copy of the configured value through ``default_factory``.
    * All other values are used directly as the field default.

    Each field is annotated with the value's type, or ``Any`` when the value
    is ``None``.

    Args:
        cls_name: Name of the generated class.
        data: Mapping of field names to default values.

    Returns:
        The generated dataclass type.
    """
    import copy

    nested_fields = {}
    nested_annotations = {}
    for k, v in data.items():
        if isinstance(v, dict):
            sub_cls = create_nested_dataclass(f"{cls_name}_{k}", v)
            nested_fields[k] = field(default_factory=sub_cls)
            nested_annotations[k] = sub_cls
            continue

        if type(v).__hash__ is None:
            # @dataclass raises ValueError for mutable/unhashable defaults.
            # A per-instance shallow copy keeps the configured value without
            # sharing one mutable object between instances.
            nested_fields[k] = field(default_factory=lambda v=v: copy.copy(v))
        else:
            nested_fields[k] = field(default=v)
        nested_annotations[k] = type(v) if v is not None else Any

    dynamic_cls = type(cls_name, (), nested_fields)
    dynamic_cls.__annotations__ = nested_annotations
    return dataclass(dynamic_cls)


def allow_extra_fields(cls):
    """
    Decorator: let a dataclass accept keyword arguments it does not declare.

    The ``__init__`` generated by ``@dataclass`` is replaced by a wrapper that:
    1. Splits keyword arguments into declared fields and unknown names.
    2. Logs a warning (via ``print_rank``) for every unknown name, in order.
    3. Calls the original ``__init__`` with positional arguments and declared
       fields only, so defaults and ``__post_init__`` behave as usual.
    4. Sets each unknown name as an instance attribute and records its raw
       value in ``self._extra_fields``. Dict values are turned into an
       instance of a dynamically generated nested dataclass first.

    Args:
        cls: The dataclass to be decorated (apply this above ``@dataclass``).

    Returns:
        ``cls`` itself, with its ``__init__`` replaced.

    Example:
        @allow_extra_fields
        @dataclass
        class MyDataClass:
            name: str = "default"
            age: int = 0

        # This will work and store 'extra_field' dynamically
        obj = MyDataClass(name="John", age=25, extra_field="value")
    """
    dataclass_init = cls.__init__

    def _populate(target: Any, values: Dict[str, Any]) -> None:
        # Write ``values`` onto ``target``. A dict entry descends into the
        # already-existing nested attribute instead of replacing it.
        for attr, item in values.items():
            if not isinstance(item, dict):
                setattr(target, attr, item)
                continue
            _populate(getattr(target, attr), item)

    @wraps(dataclass_init)
    def new_init(self, *args, **kwargs):
        """Wrapped ``__init__`` that tolerates undeclared keyword arguments."""
        declared = {f.name for f in fields(cls)}
        known = {name: val for name, val in kwargs.items() if name in declared}
        unknown = {name: val for name, val in kwargs.items() if name not in declared}

        for name in unknown:
            print_rank(logger.warning,
                f"Field '{name}' is not defined in {cls.__name__}. "
                f"Adding as dynamic attribute."
            )

        dataclass_init(self, *args, **known)

        if not hasattr(self, '_extra_fields'):
            self._extra_fields = {}

        for name, val in unknown.items():
            if not isinstance(val, dict):
                setattr(self, name, val)
            else:
                # Registers the generated nested class as an annotation on
                # ``self`` so it can be read back with get_type_hints.
                create_nested_dataclass_field(self, name, val)
                nested_obj = get_type_hints(self)[name]()
                _populate(nested_obj, val)
                setattr(self, name, nested_obj)
            self._extra_fields[name] = val

    cls.__init__ = new_init
    return cls


def _optional_dataclass_type(field_type: Any) -> Optional[type]:
    """Return ``D`` when ``field_type`` is ``Optional[D]`` / ``D | None`` for a dataclass ``D``, else None."""
    import types

    # typing.Union covers Optional[...]; types.UnionType (3.10+) covers ``X | None``.
    union_origins = (Union, getattr(types, "UnionType", Union))
    if get_origin(field_type) not in union_origins:
        return None
    non_none = [arg for arg in get_args(field_type) if arg is not type(None)]
    # Only the unambiguous single-dataclass case is unwrapped.
    if len(non_none) == 1 and is_dataclass(non_none[0]):
        return non_none[0]
    return None


def instantiate_dataclass(dataclass_type: type, data: Dict[str, Any]) -> Any:
    """
    Recursively instantiate dataclass from dictionary data.

    A dict value is converted into a nested dataclass instance when its
    field is annotated either with a dataclass type or with
    ``Optional[D]`` / ``D | None`` for a dataclass ``D``. Other values,
    including ``None``, are passed through unchanged. Keys that are not
    declared fields are forwarded to the constructor as-is. Classes decorated
    with ``@allow_extra_fields`` accept them; plain dataclasses will raise
    ``TypeError``.

    Args:
        dataclass_type: Dataclass type to instantiate
        data: Dictionary data

    Returns:
        Instantiated dataclass object, or ``data`` unchanged when
        ``dataclass_type`` is not a dataclass.

    Raises:
        RuntimeError: If the type hints of ``dataclass_type`` cannot be resolved.
    """
    if not is_dataclass(dataclass_type):
        return data

    try:
        type_hints = get_type_hints(dataclass_type)
    except Exception as e:
        raise RuntimeError(f"Failed to get type hints for {dataclass_type}: {e}") from e

    # Values for declared fields (nested dicts converted where applicable).
    dataclass_kwargs = {}
    # Values for undeclared keys, forwarded untouched.
    extra_kwargs = {}
    dataclass_field_names = {f.name for f in fields(dataclass_type)}

    for key, value in data.items():
        if key not in dataclass_field_names:
            extra_kwargs[key] = value
            continue

        field_type = type_hints[key]
        if isinstance(value, dict):
            nested_type = (
                field_type if is_dataclass(field_type)
                else _optional_dataclass_type(field_type)
            )
            if nested_type is not None:
                value = instantiate_dataclass(nested_type, value)
        dataclass_kwargs[key] = value

    # Known fields first, then extras; @allow_extra_fields separates them again.
    return dataclass_type(**{**dataclass_kwargs, **extra_kwargs})