import csv
import glob
import logging
import math
import os
import sys
import time
import argparse
from collections import defaultdict
from dataclasses import dataclass, fields
from typing import Dict, List, Optional, Tuple, Union
sys.path.append("./")
from tinker.profiler.profile_classes import ProfileArgs
from tinker.search.arguments import print_args, preprocess_args
from tinker.search.data import TaskParam, SearchArgs, ResultArgs, Metrics
from tinker.search.process import ResultOutputHandler
from tinker.utils.block_args import BlockArgs, BlockCost, DetailedInfo
from tinker.utils.utils import load_infos, convert_to_pp_stage_block_idx
from tinker.utils.logger import logger, init_log
# Profile arguments (e.g. tp/sp/ep/mbs) identifying one profiled configuration.
FeaturesType = ProfileArgs
# Mapping from a profiled configuration to {block_name: value}.
ProfileDataType = Dict[FeaturesType, Dict[str, float]]
class FixedValueDict:
    """Dict-like stand-in that answers every lookup with the same constant.

    Returned as a fallback when no profiled data exists for a feature set, so
    any later ``obj[key]`` or ``obj.get(key)`` yields the placeholder value.
    """

    def __init__(self, fixed_value):
        self.fixed_value = fixed_value

    def __getitem__(self, key):
        # The key is deliberately ignored: every key maps to the same value.
        return self.fixed_value

    def get(self, key, default=None):
        # ``default`` exists only for dict-API compatibility; no key is ever
        # missing, so it is never returned.
        return self[key]
class ProfiledData:
    """Profiled per-block costs, keyed first by profile arguments, then by block name."""

    # Placeholder cost used when a feature set has no profiled data.
    _MISSING_COST = 20000000.0

    def __init__(self):
        # {features: {block_name: BlockCost}}
        self._block_data = defaultdict(dict)

    @staticmethod
    def _get_data(datas: "ProfileDataType", features: "FeaturesType",
                  block_name="") -> "Union[Dict, FixedValueDict, float]":
        """
        Look up profiled data for ``features``, falling back to a large placeholder cost.

        :param datas: mapping from features to ``{block_name: value}``
        :param features: the profile arguments to look up
        :param block_name: when non-empty, return only that block's value
        :return: ``datas[features][block_name]`` when ``block_name`` is given, else
            ``datas[features]``. If ``features`` is absent, returns the placeholder
            ``_MISSING_COST`` when ``block_name`` is given, or a ``FixedValueDict``
            that yields it for every key otherwise.
        """
        if features not in datas:
            # The message must report the value actually returned below.
            logger.info(f"feature {features} not in profiled data, using {ProfiledData._MISSING_COST}")
            if block_name:
                return ProfiledData._MISSING_COST
            return FixedValueDict(ProfiledData._MISSING_COST)
        if block_name:
            return datas[features][block_name]
        return datas[features]

    def add_data(self, data: "Tuple[float, ...]", features: "FeaturesType", block_name: str):
        """Store ``BlockCost(*data)`` as the cost of ``block_name`` under ``features``."""
        self._block_data[features][block_name] = BlockCost(*data)

    def get_data_by_args(self, profiled_args: "ProfileArgs"):
        """
        Return ``{block_name: BlockCost}`` for ``profiled_args``, or ``{}`` if it was never profiled.

        Uses ``.get`` rather than indexing: indexing the underlying ``defaultdict``
        would insert an empty entry for unknown arguments, which would then be
        reported by ``get_profiled_args_list`` as if it had been profiled.
        """
        return self._block_data.get(profiled_args, {})

    def get_profiled_args_list(self) -> "List[FeaturesType]":
        """Return every profile-argument key that has data, in insertion order."""
        return list(self._block_data.keys())

    def get_block_names(self):
        """
        Return the block names recorded under the first profiled arguments.

        Block names are in insertion order (i.e. the order rows were added).
        Returns ``[]`` when nothing has been profiled.
        """
        for block_data in self._block_data.values():
            if isinstance(block_data, dict):
                return list(block_data)
        return []
@dataclass
class BlockNames:
    """
    Names of the four kinds of profiled blocks.

    Must stay in sync with block_profiler: the fields are filled positionally from
    ``ProfiledData.get_block_names()``, so this relies on the profiler emitting the
    blocks in exactly this order (somewhat fragile).
    """
    pre: str  # head (pre-processing) block, placed first by init_blocks
    block: str  # repeated layer block, instantiated num_layers times by init_blocks
    post1: str  # first tail (post-processing) block
    post2: str  # second tail (post-processing) block, placed last by init_blocks
class TinkerCostModel:
    """
    Analytical time/memory cost model for pipeline-parallel training strategies.

    Built from profiler output under ``args.profiled_data_path``:
    ``profiled_data.csv`` (per-block costs keyed by profile arguments) and
    ``p2p_intra_node.csv`` / ``p2p_inter_node.csv`` (p2p bandwidth tables).
    """

    def __init__(self, args):
        self._band_data_ready = None
        self._block_data_ready = None
        self.profiled_data = ProfiledData()
        self._read_block_data(args.profiled_data_path)
        self.inter_band = None
        self.intra_band = None
        self._read_band_time(args.profiled_data_path)
        # Filled positionally: assumes the profiled CSV lists the blocks in the
        # order pre, block, post1, post2.
        self.block_names = BlockNames(*self.profiled_data.get_block_names())
        # Number of non-repeated blocks (every BlockNames field except ``block``).
        self.num_other_block = len(fields(self.block_names)) - 1
        self.num_procs_per_node = args.num_npus_per_node
        self.args = args

    @property
    def _data_ready(self):
        """Whether both the bandwidth data and the block data have been loaded."""
        return self._band_data_ready and self._block_data_ready

    @staticmethod
    def calc_reserved_mem_costs(pp: int, blocks: "List[BlockArgs]") -> "List[float]":
        """
        Return the extra reserved memory (e.g. from fragmentation) added to each stage.

        Three ``max_reserved_mem`` values are used. ``blocks[0]`` is the head
        block, ``blocks[-1]`` is the last tail block, and ``blocks[1]`` represents
        an ordinary layer block.

        1. first stage: max(head, ordinary)
        2. middle stages: ordinary
        3. last stage: max(tail, ordinary)
        4. single stage (pp == 1): max(head, tail, ordinary)

        :param pp: pipeline-parallel degree, i.e. number of stages (>= 1)
        :param blocks: all model blocks, head first and tail last (at least 2)
        :return: one reserved-memory value per stage (``pp`` entries)
        """
        reserved_mem_costs = []
        first_stage_mem_reserved = blocks[0].max_reserved_mem
        last_stage_mem_reserved = blocks[-1].max_reserved_mem
        other_stage_mem_reserved = blocks[1].max_reserved_mem
        if pp == 1:
            reserved_mem_costs.append(max(first_stage_mem_reserved, last_stage_mem_reserved, other_stage_mem_reserved))
            return reserved_mem_costs
        reserved_mem_costs.append(max(first_stage_mem_reserved, other_stage_mem_reserved))
        for _ in range(1, pp - 1):
            reserved_mem_costs.append(other_stage_mem_reserved)
        reserved_mem_costs.append(max(last_stage_mem_reserved, other_stage_mem_reserved))
        return reserved_mem_costs

    @staticmethod
    def get_num_fwd_act(pp: int, stage: int, micro_batch_num: int) -> int:
        """
        Return the peak number of forward activations a stage keeps under 1F1B scheduling.

        :param pp: pipeline-parallel degree of this run, i.e. total number of stages
        :param stage: index of the current stage; the first stage is 0
        :param micro_batch_num: number of micro-batches in the pipeline, i.e. gbs // dp // mbs
        :return: ``min(pp - stage, micro_batch_num)``
        """
        return min(pp - stage, micro_batch_num)

    @staticmethod
    def get_stage_mem_cost(current_blocks, num_fwd_act):
        """
        Return the memory cost of a stage made of ``current_blocks``.

        Each block is set to keep ``num_fwd_act`` forward activations. The result is
        the sum of each block's ``block_mem()`` plus the recompute working memory.
        Reserved memory is not included. The first block is flagged ``is_first``
        only while its memory is computed.
        """
        mem_cost = 0
        head_block = current_blocks[0]
        head_block.is_first = True
        for block in current_blocks:
            block.num_fwd_act = num_fwd_act
            mem_cost += block.block_mem()
        mem_cost += TinkerCostModel.calc_recompute_mem(current_blocks)
        head_block.is_first = False
        return mem_cost

    @staticmethod
    def calc_recompute_mem(blocks: "List[BlockArgs]"):
        """Return the largest ``data.act`` among blocks with ``recompute`` set (0 if none recompute)."""
        recompute_work_block = max(blocks, key=lambda x: x.data.act if x.recompute else 0)
        recompute_work_mem = recompute_work_block.data.act if recompute_work_block.recompute else 0
        return recompute_work_mem

    @staticmethod
    def get_pp_range(num_npus, num_layers, p_args: "ProfileArgs"):
        """
        Yield every pipeline degree ``pp`` in ``[1, min(num_layers, num_npus)]`` that is feasible.

        A ``pp`` is feasible when ``num_npus`` is divisible by ``p_args.npu_used * pp``
        and ``num_npus // tp // pp`` is at least ``p_args.ep``.
        """
        for pp in range(1, min(num_layers, num_npus) + 1):
            if num_npus % (p_args.npu_used * pp) == 0 and num_npus // p_args.tp // pp >= p_args.ep:
                yield pp

    @staticmethod
    def _read_band_file(file_path: str):
        """
        Return the bandwidth values from the first row after the header of a CSV file.

        :raises FileNotFoundError: if ``file_path`` does not exist
        :raises ValueError: if the row after the header is missing or empty
        """
        with open(file_path, newline='') as f:
            reader = csv.reader(f)
            next(reader, None)  # skip the header row
            row = next(reader, None)
        # Fail early instead of leaking StopIteration (header-only file) or
        # returning [] (blank row), which would only crash later during lookups.
        if not row:
            raise ValueError(f"bandwidth file {file_path} has no data row after the header")
        return [float(band) for band in row]

    @staticmethod
    def _refresh_blocks(param):
        """Reset the per-stage state that ``calculate_cost`` sets on ``param.blocks``."""
        for block in param.blocks:
            block.num_fwd_act = None
            block.is_first = False

    def node_comm_time(self, data_size, inter_node=True):
        """
        Return the p2p communication time for a message of ``data_size``.

        The bandwidth table is indexed by ``floor(log2(data_size)) - 1``, i.e. the
        size is rounded down to a power of two. Sizes below the first bucket use
        the first entry, and sizes beyond the table use the last entry.

        :param data_size: message size (same unit as the blocks' ``data.in_size``/``out_size``)
        :param inter_node: use the inter-node table if True, otherwise the intra-node table
        :return: ``data_size / (band * 0.001)``; 0 for an empty message
        :raises RuntimeError: if bandwidth data has not been loaded yet
        :raises ValueError: if ``data_size`` is negative
        """
        if not self._band_data_ready:
            raise RuntimeError("band data not ready yet, run `_read_band_time` first.")
        if data_size < 0:
            raise ValueError(f'communicate data size invalid: {data_size} < 0')
        if data_size == 0:
            return 0
        bands = self.inter_band if inter_node else self.intra_band
        # Clamp at 0: for sizes below 1 the exponent is negative, and a negative
        # index would silently wrap around to the largest-message bandwidth.
        # log2 is exact for powers of two, unlike math.log(x, 2).
        index = max(0, int(math.log2(data_size)) - 1)
        index = min(index, len(bands) - 1)
        return data_size / (bands[index] * 0.001)

    def p2p_comm_time(self, block_args: "BlockArgs", num_npu_before: int, head=False, tail=False):
        """
        Return the p2p time for a stage boundary next to ``block_args``.

        With ``head`` the block's ``data.in_size`` is used; with ``tail``, its
        ``data.out_size``. The transfer counts as inter-node when
        ``num_npu_before`` is a non-zero multiple of ``num_procs_per_node``.

        :raises ValueError: if neither ``head`` nor ``tail`` is set
        """
        if not head and not tail:
            raise ValueError("When calculate p2p communicate time, either head or tail should be set to True")
        comm_size = block_args.data.in_size if head else block_args.data.out_size
        is_cross_nodes = num_npu_before != 0 and num_npu_before % self.num_procs_per_node == 0
        return self.node_comm_time(comm_size, is_cross_nodes)

    def get_profile_arg_list(self) -> "List[ProfileArgs]":
        """Return every profile-argument set present in the profiled data."""
        return self.profiled_data.get_profiled_args_list()

    def get_block_args(self, block_name: str, profiled_args: "ProfileArgs") -> "BlockArgs":
        """
        Build a ``BlockArgs`` for ``block_name`` from the data profiled under ``profiled_args``.

        :raises KeyError: if that block has no profiled data for ``profiled_args``
        """
        data = self.profiled_data.get_data_by_args(profiled_args)
        if block_name not in data:
            raise KeyError(f"{block_name} is not defined in profiled_data")
        return BlockArgs(self.args, profiled_args, data[block_name])

    def init_blocks(self, profile_args: "ProfileArgs", num_layers: int) -> "List[BlockArgs]":
        """
        Build the model's block list: head block, ``num_layers`` layer blocks, then two tail blocks.

        Call this once the ``ProfileArgs`` are fixed. Every layer block gets an
        ``attention_mask_mem`` of ``seq_length ** 2 / 1024 / 1024``. That is
        presumably MB for a 1-byte-per-element mask; confirm against BlockArgs.
        """
        block_list = [self.get_block_args(self.block_names.pre, profile_args)]
        block_list.extend([self.get_block_args(self.block_names.block, profile_args) for _ in range(num_layers)])
        block_list.append(self.get_block_args(self.block_names.post1, profile_args))
        block_list.append(self.get_block_args(self.block_names.post2, profile_args))
        attention_mask_mem = self.args.seq_length * self.args.seq_length / 1024.0 / 1024.0
        for block in block_list[1:-2]:
            block.attention_mask_mem = attention_mask_mem
        return block_list

    def get_stage_status(self, current_blocks, num_npu_before, is_first_stage, is_last_stage):
        """
        Compute the time cost of one stage.

        The time is the sum of the blocks' ``block_time()`` plus the receive p2p
        time (skipped on the first stage) and the send p2p time (skipped on the
        last stage). ``num_npu_before`` is advanced by the head block's
        ``num_npu_block`` before the send time is computed.

        :return: ``(num_npu_before, time_cost, input_comm, output_comm)``
        """
        time_cost = 0
        head_block = current_blocks[0]
        tail_block = current_blocks[-1]
        for block in current_blocks:
            time_cost += block.block_time()
        input_comm = 0 if is_first_stage else self.p2p_comm_time(head_block, num_npu_before, head=True)
        num_npu_before += head_block.num_npu_block
        output_comm = 0 if is_last_stage else self.p2p_comm_time(tail_block, num_npu_before, tail=True)
        time_cost += input_comm + output_comm
        return num_npu_before, time_cost, input_comm, output_comm

    def calculate_cost(self, param: "TaskParam", pp_stage_block_intervals: list, detail=False):
        """
        Estimate per-stage iteration time and memory for one pipeline split.

        Each stage's iteration time is ``sum(stage times) + (micro_batch_num - 1) * stage_time``.

        :param param: task whose ``blocks`` and ``search_args`` (``dp``, ``pp``) describe the strategy
        :param pp_stage_block_intervals: for each stage, an inclusive ``(head_idx, tail_idx)``
            pair of indices into ``param.blocks``
        :param detail: when True, log a per-stage breakdown
        :return: ``Metrics(time_costs, mem_costs, max(time_costs), max(mem_costs))``
        """
        detail_infos = []
        time_costs = []
        mem_costs = []
        num_npu_before = 0
        profile_args = param.blocks[0].profile_args
        micro_batch_num = self.args.global_batch_size // param.search_args.dp // profile_args.mbs
        pp = param.search_args.pp
        reserved_mem_costs = TinkerCostModel.calc_reserved_mem_costs(pp, param.blocks)
        try:
            for p in range(pp):
                if detail:
                    logger.info(f'stage {p}'.center(80, '='))
                    detailed_info = DetailedInfo()
                    detail_infos.append(detailed_info)
                head_idx, tail_idx = pp_stage_block_intervals[p]
                stage_blocks = param.blocks[head_idx:tail_idx + 1]
                # Stays set until _refresh_blocks so the detail pass below sees the same state.
                param.blocks[head_idx].is_first = True
                num_fwd_act = TinkerCostModel.get_num_fwd_act(pp, p, micro_batch_num)
                mem_cost = 0
                for block in stage_blocks:
                    block.num_fwd_act = num_fwd_act
                    mem_cost += block.block_mem()
                num_npu_before, time_cost, input_comm, output_comm = self.get_stage_status(
                    stage_blocks, num_npu_before, p == 0, p == pp - 1)
                recompute_mem = TinkerCostModel.calc_recompute_mem(stage_blocks)
                mem_cost += recompute_mem
                mem_cost += reserved_mem_costs[p]
                time_costs.append(time_cost)
                mem_costs.append(mem_cost)
                if detail:
                    for block in stage_blocks:
                        block.block_time(detail=detail, detail_info=detailed_info)
                        block.block_mem(detail=detail, detail_info=detailed_info)
                    detailed_info.set_and_print(input_comm, output_comm, recompute_mem, reserved_mem_costs[p], mem_cost)
                    logger.info('stage %d total Memory: %.3f MB', p, mem_cost)
            # One pass of a micro-batch through every stage; the remaining
            # micro-batches each add one more stage_time.
            bubble_time = sum(time_costs)
            time_costs = [bubble_time + (micro_batch_num - 1) * stage_time for stage_time in time_costs]
            if detail:
                logger.info('Time Cost with Bubble'.center(80, '='))
                logger.info('Sum(unit time): %.3f ms', bubble_time / 1000)
                for time_cost, detail_info in zip(time_costs, detail_infos):
                    detail_info.print_time(bubble_time, micro_batch_num, time_cost)
        finally:
            # Always undo the per-stage block state, even if a cost computation raised.
            self._refresh_blocks(param)
        return Metrics(time_costs, mem_costs, max(time_costs), max(mem_costs))

    def _read_block_data(self, data_path: str):
        """
        Load per-block profiled costs from ``<data_path>/profiled_data.csv``.

        Each non-blank row after the header is ``file_name, block_name, value, ...``.
        The file name is parsed into a ``ProfileArgs``, and the remaining columns
        become the block's cost tuple via ``ProfiledData.add_data``.

        :raises RuntimeError: if the file cannot be read or any row cannot be parsed
        """
        file_path = os.path.join(data_path, 'profiled_data.csv')
        try:
            with open(file_path, 'r', newline='') as f:
                reader = csv.reader(f)
                next(reader, None)  # skip the header row
                for row in reader:
                    if all(not cell.strip() for cell in row):
                        continue
                    filename_without_suffix, block_name = row[0], row[1]
                    costs = tuple(float(value) for value in row[2:])
                    profile_args = ProfileArgs.new_from_file_name(filename_without_suffix)
                    self.profiled_data.add_data(costs, profile_args, block_name)
        except Exception as e:
            raise RuntimeError(f'Load profiled data: {file_path} failed.') from e
        self._block_data_ready = True

    def _read_band_time(self, data_path):
        """
        Load the intra-/inter-node p2p bandwidth tables from ``data_path``.

        If only one of the two files exists, its table is used for both kinds of
        communication.

        :raises RuntimeError: if neither bandwidth file exists
        """
        intra_band_file = os.path.join(data_path, "p2p_intra_node.csv")
        inter_band_file = os.path.join(data_path, "p2p_inter_node.csv")
        logger.info(intra_band_file)
        logger.info(inter_band_file)
        try:
            self.intra_band = self._read_band_file(intra_band_file)
        except FileNotFoundError:
            logger.error("intra-node bandwidth file is not found.")
        try:
            self.inter_band = self._read_band_file(inter_band_file)
        except FileNotFoundError:
            logger.error("inter-node bandwidth file is not found, using intra-node bandwidth instead.")
            self.inter_band = self.intra_band
        if self.inter_band is None and self.intra_band is None:
            raise RuntimeError("Intra-node or inter-node bandwidth file is required.")
        if self.intra_band is None:
            # Mirror the inter->intra fallback so intra-node lookups never index None.
            logger.warning("using inter-node bandwidth for intra-node communication instead.")
            self.intra_band = self.inter_band
        self._band_data_ready = True
def run(args: argparse.Namespace):
    """
    Simulate one explicitly specified parallel strategy and report its cost.

    Does nothing unless ``args.mode == 'simulate'``. The strategy is taken from
    ``args.simu_tp`` / ``simu_sp`` / ``simu_ep`` / ``simu_pp`` and
    ``args.micro_batch_size``. The layer split comes from ``args.num_layer_list``,
    comma-separated ints (presumably layers per pipeline stage).

    :raises ValueError: if ``num_npus`` is not divisible by ``simu_pp * tp``, or
        ``global_batch_size`` is not divisible by ``dp * mbs``
    """
    if args.mode != 'simulate':
        return
    init_log(None, logging.DEBUG)
    start_time = time.perf_counter()
    preprocess_args(args)
    load_infos(args)
    print_args(args)
    cost_model = TinkerCostModel(args)
    pred_profiled_args = ProfileArgs(tp=args.simu_tp, sp=args.simu_sp, ep=args.simu_ep, mbs=args.micro_batch_size)
    pred_blocks = cost_model.init_blocks(pred_profiled_args, args.num_layers)
    # A single divisibility check suffices: the previous second check on the
    # same product (tp * pp) could never trigger.
    npu_used = pred_profiled_args.tp * args.simu_pp
    if args.num_npus % npu_used != 0:
        raise ValueError(
            "incorrect num_npus={}, pp={}, tp={}, the former must be divided into the latter two.".format(
                args.num_npus, args.simu_pp, pred_profiled_args.tp
            ))
    dp = args.num_npus // npu_used
    local_batch_size = dp * pred_profiled_args.mbs
    if args.global_batch_size % local_batch_size:
        raise ValueError("incorrect gbs={}, dp={}, mbs={}, the former must be divided into the latter two.".format(
            args.global_batch_size, dp, args.micro_batch_size
        ))
    cost_model_args = dict(dp=dp, dist_opt=args.dist_opt, recompute=args.recompute)
    for block in pred_blocks:
        block.update_cost_model_args(cost_model_args)
    # The head block and the two tail blocks are never recomputed.
    for block in [pred_blocks[0], pred_blocks[-2], pred_blocks[-1]]:
        block.recompute = False
    split_way = list(map(int, args.num_layer_list.split(',')))
    intervals = convert_to_pp_stage_block_idx(split_way, len(pred_blocks))
    search_args = SearchArgs(pp=args.simu_pp, **cost_model_args, **pred_profiled_args.__dict__)
    task_param = TaskParam(search_args=search_args, blocks=pred_blocks)
    strategy = ResultArgs(
        gbs=args.global_batch_size,
        num_layers_list=args.num_layer_list,
        blocks=task_param.blocks,
        **task_param.search_args.__dict__
    )
    metrics = cost_model.calculate_cost(task_param, intervals, args.detail)
    result_output_handler = ResultOutputHandler(args, cost_model, [(strategy, metrics)])
    result_output_handler.print_and_write_to_file(1, save=False)
    end_time = time.perf_counter()
    logger.info(f"[TOTAL TIME] {end_time - start_time} s.")