""" search 过程中的数据类
"""
from dataclasses import dataclass, fields
from typing import List, Optional
from tinker.utils.block_args import BlockArgs
from tinker.profiler.profile_classes import ProfileArgs
@dataclass(frozen=True)
class SearchArgs(ProfileArgs):
    """Search parameters; extends ``ProfileArgs`` with parallel-strategy fields.

    The first four attributes below are listed as documented for this class;
    they are not declared here (presumably inherited from ``ProfileArgs`` —
    confirm against that class).

    Attributes:
        mbs: micro batch size.
        tp: tensor parallel size.
        sp: sequence parallel size.
        ep: expert parallel.
        dp: data-parallel value, i.e. the number of data-parallel replicas.
        pp: pipeline-parallel value, i.e. the number of pipeline stages.
        recompute: recomputation flag; must be 0 or 1.
        dist_opt: distributed-optimizer flag; must be 0 or 1.

    Note:
        ``recompute`` and ``dist_opt`` default to ``None``, but
        ``__post_init__`` rejects every value outside ``{0, 1}``. Both must
        therefore be passed explicitly when an instance is created.
    """

    dp: Optional[int] = None
    pp: Optional[int] = None
    recompute: Optional[int] = None
    dist_opt: Optional[int] = None

    def __post_init__(self):
        """Check that ``recompute`` and ``dist_opt`` are each 0 or 1.

        Raises:
            ValueError: if either attribute is not in ``{0, 1}``; this
                includes the ``None`` default.
        """
        # NOTE(review): no parent __post_init__ is chained here; confirm that
        # ProfileArgs does not define one that should also run.
        valid_values = {0, 1}
        # Same order as the field declarations, so the first offending
        # attribute reported is unchanged.
        for name in ("recompute", "dist_opt"):
            value = getattr(self, name)
            if value not in valid_values:
                raise ValueError(f"Invalid value for {name}, expected 0 or 1, got {value}")
@dataclass(frozen=True)
class ResultArgs(SearchArgs):
    """Search result parameters; extends ``SearchArgs``.

    Only the three fields below are declared here. The remaining attributes
    documented for this class (``mbs``, ``tp``, ``sp``, ``ep``, ``dp``,
    ``pp``, ``recompute``, ``dist_opt``) come from the base classes.

    Attributes:
        num_layers_list: layer-count configuration for the different stages,
            stored as a string; ``None`` when not set.
        gbs: presumably the global batch size — TODO confirm with callers;
            ``None`` when not set.
        blocks: list of ``BlockArgs`` entries; ``None`` when not set.
    """

    # Per-stage layer-count configuration, stored as a string.
    num_layers_list: Optional[str] = None
    # NOTE(review): looks like the global batch size — confirm.
    gbs: Optional[int] = None
    # Block descriptions attached to this result.
    blocks: Optional[List[BlockArgs]] = None
@dataclass
class Metrics:
    """Performance data.

    A mutable record: unlike the other dataclasses in this module it is not
    frozen, so its fields can be reassigned after construction.

    Attributes:
        time_costs: list of time costs (presumably one entry per stage —
            TODO confirm).
        mem_costs: list of memory costs (presumably one entry per stage —
            TODO confirm).
        time_cost: single time-cost figure.
        mem_cost: single memory-cost figure.
        tokens_per_npu_per_sec: optional throughput figure; ``None`` when it
            has not been set.
    """

    # Required fields; these have no defaults.
    time_costs: list
    mem_costs: list
    time_cost: float
    mem_cost: float
    # Optional field; may be omitted at construction.
    tokens_per_npu_per_sec: Optional[float] = None
@dataclass(frozen=True)
class TaskParam:
    """Task-level parameters, passed as the argument for non-uniform interval partitioning.

    Attributes:
        search_args: the ``SearchArgs`` for this task.
        blocks: list of ``BlockArgs`` entries for this task.
    """

    # Search configuration for the task.
    search_args: SearchArgs
    # Blocks to be partitioned.
    blocks: List[BlockArgs]
@dataclass(frozen=True)
class StageData:
    """Immutable record for data kept during the dynamic-programming process.

    Attributes:
        num_npu_before: integer NPU count (presumably the NPUs used by the
            preceding stages — TODO confirm).
        stage_time_max_min: float stage-time value (presumably the minimised
            maximum stage time — TODO confirm).
        num_layer_list: list of layer counts (presumably one per stage —
            TODO confirm).
        stage_mem_max: float value (presumably the maximum stage memory —
            TODO confirm).
    """

    # All four fields are required and cannot be reassigned (frozen).
    num_npu_before: int
    stage_time_max_min: float
    num_layer_list: list
    stage_mem_max: float