from dataclasses import dataclass, field, fields, is_dataclass, MISSING
from typing import Any, Callable, Dict, List, Literal, Optional, TypeVar, Union, get_type_hints, get_origin, get_args
import logging
from functools import wraps
from mindspeed.fsdp.utils.log import print_rank
# Module-level logger; allow_extra_fields routes its "undefined field" warnings
# through print_rank(logger.warning, ...).
logger = logging.getLogger(__name__)
def create_nested_dataclass_field(parent_cls: type, field_name: str, nested_data: Dict[str, Any]) -> None:
    """
    Attach a field named ``field_name`` whose type is a dataclass generated from ``nested_data``.

    The generated class is named ``f"{field_name.capitalize()}Field"``. The
    attribute ``field_name`` on ``parent_cls`` is set to
    ``field(default_factory=<generated class>)``, and the generated class is
    recorded as its annotation in ``parent_cls``'s own ``__annotations__``.
    If ``parent_cls`` has no ``__annotations__`` entry in its own ``__dict__``,
    a fresh dict is created first, so inherited annotations are never mutated.

    The ``field(...)`` object only becomes a real dataclass field if ``@dataclass``
    processes ``parent_cls`` after this call. ``allow_extra_fields`` also calls
    this with an instance, then immediately overwrites the attribute.

    Args:
        parent_cls: Class (or object) that receives the attribute and annotation.
        field_name: Name of the field to add.
        nested_data: Mapping used to build the nested dataclass and its defaults.
    """
    generated_type = create_nested_dataclass(f"{field_name.capitalize()}Field", nested_data)
    setattr(parent_cls, field_name, field(default_factory=generated_type))
    # NOTE(review): on Pythons with lazily evaluated class annotations (PEP 649/749),
    # '__annotations__' may be absent from a class __dict__ even when the class has
    # annotations; assigning {} there could hide them — confirm if used with classes.
    own_namespace = parent_cls.__dict__
    if '__annotations__' not in own_namespace:
        parent_cls.__annotations__ = {}
    parent_cls.__annotations__[field_name] = generated_type
def create_nested_dataclass(cls_name: str, data: Dict[str, Any]) -> type:
    """
    Build a dataclass type whose fields and default values mirror ``data``.

    Each key of ``data`` becomes a field, in ``data``'s iteration order:

    - dict values become a nested dataclass named ``f"{cls_name}_{key}"``
      (built recursively), used as both the field's annotation and its
      ``default_factory``;
    - unhashable values (list, set, or any other mutable container) get a
      ``default_factory`` returning a deep copy of the value. Every instance
      starts with the provided data, and no two instances share a container
      (dataclasses rejects unhashable plain defaults);
    - all other values are used directly as the field default.

    Non-dict fields are annotated with ``type(value)``, or ``Any`` when the
    value is ``None``.

    Args:
        cls_name: ``__name__`` given to the generated class.
        data: Mapping of field names to default values.

    Returns:
        The generated dataclass type. Calling it with no arguments yields an
        instance populated with (copies of) the values in ``data``.
    """
    import copy

    def _copying_factory(template: Any) -> Callable[[], Any]:
        # Bind ``template`` now (avoids late-binding in the loop) and hand each
        # new instance its own deep copy.
        return lambda: copy.deepcopy(template)

    namespace: Dict[str, Any] = {}
    annotations: Dict[str, Any] = {}
    for key, value in data.items():
        if isinstance(value, dict):
            sub_cls = create_nested_dataclass(f"{cls_name}_{key}", value)
            namespace[key] = field(default_factory=sub_cls)
            annotations[key] = sub_cls
            continue
        annotations[key] = type(value) if value is not None else Any
        if type(value).__hash__ is None:
            # Previously lists/sets got an *empty* factory, silently dropping the
            # data, and other unhashable values made dataclass() raise ValueError.
            namespace[key] = field(default_factory=_copying_factory(value))
        else:
            namespace[key] = field(default=value)
    # Supply annotations at class creation so dataclass() sees them in field order.
    namespace["__annotations__"] = annotations
    return dataclass(type(cls_name, (), namespace))
def allow_extra_fields(cls):
    """
    Decorator: let a dataclass accept keyword arguments it does not declare.

    The class's ``__init__`` is replaced by a wrapper that:
    1. Splits incoming keyword arguments into declared dataclass fields and
       undeclared ("extra") ones, logging a warning for each extra one.
    2. Calls the original ``__init__`` with all positional arguments and the
       declared keyword arguments. A dataclass ``__post_init__`` therefore still
       runs, before any extra field is attached.
    3. Attaches each extra argument to the instance. Dict values become
       instances of a dynamically generated nested dataclass; every other value
       is set as-is.
    4. Records each extra argument's raw value in ``self._extra_fields``.

    Args:
        cls: The dataclass to be decorated. ``fields(cls)`` is evaluated on each
            instantiation, so ``@dataclass`` must already have been applied.

    Returns:
        ``cls`` itself, with ``__init__`` replaced.

    Example:
        @allow_extra_fields
        @dataclass
        class MyDataClass:
            name: str = "default"
            age: int = 0

        # This will work and store 'extra_field' dynamically
        obj = MyDataClass(name="John", age=25, extra_field="value")
    """
    wrapped_init = cls.__init__

    def _fill_from_mapping(target: Any, mapping: Dict[str, Any]) -> None:
        # Copy mapping values onto target; nested dicts descend into the
        # sub-object already stored under that attribute name.
        for attr_name, attr_value in mapping.items():
            if not isinstance(attr_value, dict):
                setattr(target, attr_name, attr_value)
                continue
            _fill_from_mapping(getattr(target, attr_name), attr_value)

    @wraps(wrapped_init)
    def new_init(self, *args, **kwargs):
        """
        Initialise declared fields via the original ``__init__``, then attach extras.

        Args:
            *args: Positional arguments, forwarded unchanged to the original ``__init__``.
            **kwargs: Keyword arguments; declared field names go to the original
                ``__init__``, everything else becomes a dynamic attribute.
        """
        declared_names = {f.name for f in fields(cls)}
        known_kwargs = {}
        unknown_kwargs = {}
        for key, value in kwargs.items():
            if key not in declared_names:
                unknown_kwargs[key] = value
                print_rank(logger.warning,
                    f"Field '{key}' is not defined in {cls.__name__}. "
                    f"Adding as dynamic attribute."
                )
            else:
                known_kwargs[key] = value

        wrapped_init(self, *args, **known_kwargs)

        # Reuse an existing registry if one is already reachable on the instance.
        if not hasattr(self, '_extra_fields'):
            self._extra_fields = {}

        for name, raw in unknown_kwargs.items():
            if not isinstance(raw, dict):
                setattr(self, name, raw)
            else:
                # Registers the generated nested type in the instance's own
                # __annotations__, then reads it back to build the nested object.
                create_nested_dataclass_field(self, name, raw)
                nested_type = get_type_hints(self)[name]
                nested_obj = nested_type()
                _fill_from_mapping(nested_obj, raw)
                setattr(self, name, nested_obj)
            self._extra_fields[name] = raw

    cls.__init__ = new_init
    return cls
def instantiate_dataclass(dataclass_type: type, data: Dict[str, Any]) -> Any:
    """
    Recursively instantiate a dataclass from dictionary data.

    A value is converted recursively when all of these hold:

    - its key is a declared field of ``dataclass_type``;
    - the value is a dict;
    - the field's type hint is a dataclass, or a ``Union``/``Optional`` that
      contains exactly one dataclass (e.g. ``Optional[Inner]`` or ``Inner | None``).

    All other values are passed to the constructor unchanged. That includes
    keys that are not declared fields, so such keys need a class that accepts
    extra keyword arguments (e.g. one decorated with ``allow_extra_fields``).

    Args:
        dataclass_type: Dataclass type to instantiate. When it is not a
            dataclass, ``data`` is returned unchanged.
        data: Mapping of constructor keyword arguments.

    Returns:
        The instantiated dataclass object, or ``data`` itself when
        ``dataclass_type`` is not a dataclass.

    Raises:
        RuntimeError: If the type hints of ``dataclass_type`` cannot be resolved.
    """
    import types

    if not is_dataclass(dataclass_type):
        return data
    try:
        type_hints = get_type_hints(dataclass_type)
    except Exception as e:
        raise RuntimeError(f"Failed to get type hints for {dataclass_type}: {e}") from e

    # typing.Union covers Optional[X]/Union[...]; types.UnionType covers `X | None` (3.10+).
    union_origins = (Union, getattr(types, "UnionType", Union))

    def _target_dataclass(hint: Any) -> Optional[Any]:
        # Dataclass a field hint refers to; ambiguous unions (several dataclasses) yield None.
        if is_dataclass(hint):
            return hint
        if get_origin(hint) in union_origins:
            candidates = [arg for arg in get_args(hint) if is_dataclass(arg)]
            if len(candidates) == 1:
                return candidates[0]
        return None

    dataclass_field_names = {f.name for f in fields(dataclass_type)}
    dataclass_kwargs: Dict[str, Any] = {}
    extra_kwargs: Dict[str, Any] = {}
    for key, value in data.items():
        if key not in dataclass_field_names:
            extra_kwargs[key] = value
            continue
        if isinstance(value, dict):
            target = _target_dataclass(type_hints[key])
            if target is not None:
                value = instantiate_dataclass(target, value)
        dataclass_kwargs[key] = value
    # Declared fields first, then extras (same keyword order as before).
    return dataclass_type(**dataclass_kwargs, **extra_kwargs)