from dataclasses import dataclass, field
from typing import List, Literal, Optional
from mindspeed_mm.config.arguments.base_args import BaseArguments
class RecomputePlanConfig(BaseArguments):
    """Recompute plan settings.

    Only the field declarations live here; any parsing or validation
    behaviour comes from ``BaseArguments`` (defined elsewhere).
    """

    # Modules the recompute plan applies to (list of strings). A fresh
    # empty list is created for every instance.
    apply_modules: List[str] = field(
        default_factory=list,
    )

    # NOTE(review): presumably forwarded to the activation-checkpoint
    # utility's ``use_reentrant`` switch -- confirm against the consumer.
    use_reentrant: bool = False
class ChunkLossPlanConfig(BaseArguments):
    """Chunk loss plan settings: the target module and the chunk sizes."""

    # Module that chunk loss is applied to; defaults to "lm_head".
    apply_module: str = field(
        metadata={"help": "module that applied chunk loss"},
        default="lm_head",
    )

    # Size of each individual chunk.
    chunk_size: int = field(
        metadata={"help": "Size of each chunk loss"},
        default=1024,
    )

    # Total chunk size.
    # NOTE(review): relationship to ``chunk_size`` (e.g. must it be a
    # multiple?) is not visible here -- confirm against the consumer.
    total_chunk_size: int = field(
        metadata={"help": "Size of total chunk loss"},
        default=4096,
    )
class LossArguments(BaseArguments):
    """Loss-related options: loss type and router auxiliary loss settings."""

    # Which loss function to use; the default "raw" selects the model's
    # own loss function (per the help text).
    loss_type: Optional[str] = field(
        default="raw",
        metadata={"help": "Type of loss function type, If not provided, will be computed based on raw model loss function"},
    )

    # Coefficient for the router auxiliary loss; defaults to 0.0.
    router_aux_loss_coef: float = field(
        default=0.0,
        metadata={"help": "Router Auxiliary Loss Coefficient"},
    )

    # Whether to apply router auxiliary loss offload; off by default.
    router_aux_loss_offload: bool = field(
        default=False,
        metadata={"help": "Whether apply router auxiliary loss offload"},
    )
class ActivationOffloadPlanConfig(BaseArguments):
    """Activation offload plan settings."""

    # Modules that activation offload is applied to. The default is None,
    # so the annotation is Optional to match the value callers actually see.
    apply_modules: Optional[List[str]] = field(
        default=None,
        metadata={"help": "module that applied activation offload"},
    )
class ChunkMbsPlanConfig(BaseArguments):
    """Chunked micro-batch-size (chunk_mbs) plan settings."""

    # Modules that chunk_mbs is applied to. The default is None, so the
    # annotation is Optional to match the value callers actually see.
    apply_modules: Optional[List[str]] = field(
        default=None,
        metadata={"help": "module that applied chunkmbs"},
    )

    # Chunked micro batch size.
    chunk_mbs: int = field(
        default=1,
        metadata={"help": "chunk_mbs, chunked micro batch size"},
    )

    # Dimension that holds the batch size.
    batch_dim: int = field(
        default=0,
        metadata={"help": "chunk_mbs, batchsize dim"},
    )

    # Indexes of the args to chunk; defaults to [0].
    # NOTE(review): presumably positions in the module's positional
    # arguments -- confirm against the consumer.
    # A factory is used so each instance gets its own list: a list literal
    # as ``default`` is rejected by dataclass processing and would
    # otherwise be shared between instances.
    chunk_arg_indexs: List[int] = field(
        default_factory=lambda: [0],
        metadata={"help": "chunk_mbs, chunk args indexs"},
    )

    # Names of the kwargs to chunk; empty by default. Built per instance
    # for the same reason as ``chunk_arg_indexs``.
    chunk_kwarg_names: List[str] = field(
        default_factory=list,
        metadata={"help": "chunk_mbs, chunk kwarg names"},
    )
class FeatureArguments(BaseArguments):
    """Top-level feature switches, each paired with its plan configuration.

    Every nested configuration is built through ``default_factory`` so that
    each instance gets its own sub-config object.
    """

    # --- recompute -------------------------------------------------------
    recompute_plan: RecomputePlanConfig = field(
        default_factory=RecomputePlanConfig,
    )

    # --- loss ------------------------------------------------------------
    loss_cfg: LossArguments = field(
        default_factory=LossArguments,
    )

    # --- chunk loss ------------------------------------------------------
    enable_chunk_loss: bool = field(
        metadata={"help": "Whether apply chunkloss for loss compute"},
        default=False,
    )
    enable_dynamic_chunk_loss: bool = field(
        metadata={"help": "Whether apply dynamic chunkloss for loss compute"},
        default=False,
    )
    chunkloss_plan: ChunkLossPlanConfig = field(
        default_factory=ChunkLossPlanConfig,
    )

    # --- activation offload ----------------------------------------------
    enable_activation_offload: bool = field(
        metadata={"help": "Whether apply activation offload"},
        default=False,
    )
    activation_offload_plan: ActivationOffloadPlanConfig = field(
        default_factory=ActivationOffloadPlanConfig,
    )

    # --- chunked micro batch size ----------------------------------------
    enable_chunk_mbs: bool = field(
        metadata={"help": "Whether apply chunk_mbs"},
        default=False,
    )
    chunkmbs_plan: ChunkMbsPlanConfig = field(
        default_factory=ChunkMbsPlanConfig,
    )