"""
算子预估
"""
import math
import os
from dataclasses import dataclass
import numpy as np
import operator
import functools
from mindspeed.auto_settings.config.model_config import get_model_config
from mindspeed.auto_settings.config.system_config import get_system_config
from mindspeed.auto_settings.utils.logger import get_logger
from mindspeed.auto_settings.module.operator.operator import Operator
from mindspeed.auto_settings.module.communication.communication import Communication
from mindspeed.auto_settings.utils.utils import get_num_warmup_micro_batches, get_module_info, get_prof_dir, \
get_black_prof_file
from mindspeed.auto_settings.config.search_config import SearchConfig
@dataclass
class PipelineParallelParas:
num_stages: int
vpp: int
fwd_durations: float
bwd_durations: float
num_microbatches: int
comm_matrix: list
class TimeCostBlack(object):
number_samples = 100
band_width_undirectional = 25
def __init__(self):
self.logger = get_logger("TimeCostBlack")
self.operator = Operator()
self.communication = Communication()
def crop_config(self, config: SearchConfig):
"""
config
"""
config.pipeline_model_parallel_size = 1
config.num_layers_per_virtual_pipeline_stage = config.num_layers // config.pp
return config
def get_module_time(self, config: SearchConfig, module_name, node_rank):
tmp_config = self.crop_config(config)
path = get_black_prof_file(tmp_config)
module = get_module_info(path, module_name)
fwd_time = module.get('time', float('inf')) * 1000
forward_step_time = get_module_info(path, 'forward_step_time') * 1000
backward_step_time = get_module_info(path, 'backward_step_time') * 1000
return fwd_time, fwd_time / forward_step_time * backward_step_time
def get_chunks_time_module_level(self, config: SearchConfig):
forward_time_each_chunk = []
backward_time_each_chunk = []
nnodes = get_system_config().nnodes
num_chunks = config.pp * config.vpp
if num_chunks == 1:
fwd_time = 0
bwd_time = 0
embedding = self.get_module_time(config, 'embedding', 0)
fwd_time += embedding[0]
bwd_time += embedding[1]
transformer = self.get_module_time(config, '0', 0)
fwd_time += transformer[0] * config.num_layers_per_virtual_pipeline_stage
bwd_time += transformer[1] * config.num_layers_per_virtual_pipeline_stage
final_norm = self.get_module_time(config, 'final_layernorm', 0)
fwd_time += final_norm[0]
bwd_time += final_norm[1]
output_layer = self.get_module_time(config, 'output_layer', 0)
fwd_time += output_layer[0]
bwd_time += output_layer[1]
loss = self.get_module_time(config, 'loss', 0)
fwd_time += loss[0]
bwd_time += loss[1]
forward_time_each_chunk.append(fwd_time)
backward_time_each_chunk.append(bwd_time)
return forward_time_each_chunk, backward_time_each_chunk
def get_iteration_time(self, config: SearchConfig):
iteration_times = np.array([0 for _ in range(TimeCostBlack.number_samples)]).astype(np.float64)
time_chunks = self.get_chunks_time_module_level(config)
for i in range(self.number_samples):
iteration_times[i], _ = self.pipeline_costmodel(config, time_chunks[0], time_chunks[1])
return iteration_times
def get_send_recv_time(self, shape):
data_size = functools.reduce(operator.mul, shape) * 2 / (1024 ** 3)
return (data_size / self.band_width_undirectional) * 1e6
def pipeline_costmodel(self, config: SearchConfig, fwd_time_chunks, bwd_time_chunks):
model_config = get_model_config()
send_recv_time = self.get_send_recv_time(
[model_config.seq_length, config.micro_batch_size, model_config.hidden_size]
)
comm_matrix = [[send_recv_time] * config.pipeline_model_parallel_size \
for _ in range(config.pipeline_model_parallel_size)]
for i in range(config.pipeline_model_parallel_size):
comm_matrix[i][i] = 0
paras = PipelineParallelParas(
num_stages=config.pipeline_model_parallel_size,
vpp=config.vpp,
fwd_durations=fwd_time_chunks,
bwd_durations=bwd_time_chunks,
num_microbatches=config.mbs,
comm_matrix=comm_matrix
)
scheduler_1f1b = self.get_schedule_1f1b(paras)
e2e_time_1f1b, stage_start_time = self.time_model_nfmb(paras, scheduler_1f1b)
return e2e_time_1f1b, stage_start_time
def get_schedule_1f1b(self, paras):
pp_stages = paras.num_stages
vpp = paras.vpp
num_microbatches = paras.num_microbatches
computation_placement = list(range(pp_stages * vpp)) + list(range(pp_stages * vpp - 1, -1, -1))
fwd_bwd_order = ([f'F_{i}' for i in range(pp_stages * vpp)] +
[f'B_{i}' for i in range(pp_stages * vpp - 1, -1, -1)])
def get_stage_list(fwd_seq, bwd_seq, num_advanced):
stage_order = []
n = len(fwd_seq)
num_advanced = min(n, num_advanced)
for idx in range(n):
if idx < num_advanced:
stage_order.append(fwd_seq[idx])
else:
stage_order.append(fwd_seq[idx])
stage_order.append(bwd_seq[idx - num_advanced])
if idx == n - 1:
for i in range(num_advanced):
stage_order.append(bwd_seq[i - num_advanced])
return stage_order
def get_stage_schedule(all_jobs_array, comp_placement, num_stages, vpp):
stage_job_list = []
for s in range(num_stages):
stage_chunk_id = [index for index, element in enumerate(comp_placement) if (element % num_stages) == s]
if vpp > 1:
warmup = num_stages * (vpp + 1) - 2 * (s + 1)
else:
warmup = num_stages - s - 1
fwds = all_jobs_array[stage_chunk_id[0:vpp]]
fwd_list = np.concatenate([fwds[:, index:index + num_stages].flatten()
for index in range(0, np.size(all_jobs_array, 1), num_stages)])
bwds = all_jobs_array[stage_chunk_id[vpp:]]
bwd_list = np.concatenate([bwds[:, index:index + num_stages].flatten()
for index in range(0, np.size(all_jobs_array, 1), num_stages)])
stage_s_list = get_stage_list(fwd_list, bwd_list, warmup)
stage_job_list.append(stage_s_list)
return stage_job_list
all_jobs = np.array([[s + f'-{i}' for i in range(num_microbatches)] for s in fwd_bwd_order])
stage_list = get_stage_schedule(all_jobs, computation_placement, pp_stages, vpp)
return stage_list
def time_model_nfmb(self, paras, stage_list):
num_pp_stages = paras.num_stages
num_mb = paras.num_microbatches
comm_matrix = paras.comm_matrix
vpp = paras.vpp
chunk_placement = list(range(num_pp_stages)) * vpp + list(range(num_pp_stages - 1, -1, -1)) * vpp
fwd_bwd_comp_order = ([f'F_{i}' for i in range(num_pp_stages * vpp)] +
[f'B_{i}' for i in range(num_pp_stages * vpp - 1, -1, -1)])
chunk_stage_map = dict(zip(fwd_bwd_comp_order, chunk_placement))
fwd_bwd_list = ([f"F_{j}-{i}" for i in range(num_mb) for j in range(num_pp_stages * vpp)]
+ [f"B_{j}-{i}" for i in range(num_mb) for j in range(num_pp_stages * vpp)])
values = [0 for _ in range(num_pp_stages * vpp * num_mb * 2)]
start_time = dict(zip(fwd_bwd_list, values))
fwd_bwd_durations = dict()
for j in range(num_pp_stages * vpp):
for i in range(num_mb):
fwd_bwd_durations[f"F_{j}-{i}"] = paras.fwd_durations[j]
fwd_bwd_durations[f"B_{j}-{i}"] = paras.bwd_durations[j]
start_time[f"F_{0}-{0}"] = 0.1
for s in range(num_pp_stages - 1):
start_time[f"F_{s + 1}-{0}"] = start_time[f"F_{s}-{0}"] + paras.fwd_durations[s] + comm_matrix[s][s + 1]
def get_prev_task_time(task_start_time,
task_list,
pp_stage_id,
task_idx,
chunk_stage_map,
comp_order,
model_chunk_times,
comm_time_matrix):
current_task = task_list[pp_stage_id][task_idx]
previous_task = task_list[pp_stage_id][task_idx - 1]
previous_task_name, _ = previous_task.split('-')
stage_id_previous_task = chunk_stage_map[previous_task_name]
chunk_position = comp_order.index(previous_task_name)
if chunk_position < len(comp_order) - 1:
stage_id_next = chunk_stage_map[comp_order[chunk_position + 1]]
comm_time = comm_time_matrix[stage_id_previous_task][stage_id_next]
else:
comm_time = 0.01
end_time_previous_task = (task_start_time[previous_task]
+ model_chunk_times[previous_task]
+ comm_time)
chunk_name, cur_mb_index = current_task.split('-')
chunk_position = comp_order.index(chunk_name)
if chunk_position > 0:
previous_chunk = comp_order[chunk_position - 1]
dependent_task = previous_chunk + '-' + cur_mb_index
comm_time = comm_time_matrix[chunk_stage_map[previous_chunk]][chunk_stage_map[chunk_name]]
end_time_dependent_task = (task_start_time[dependent_task]
+ model_chunk_times[dependent_task]
+ comm_time)
completed_flag = task_start_time[previous_task] > 0 and task_start_time[dependent_task] > 0
else:
end_time_dependent_task = 0.1
completed_flag = task_start_time[previous_task] > 0
return end_time_previous_task, end_time_dependent_task, completed_flag
begin_up = [1] * num_pp_stages
remaining = [num_mb * vpp * 2 - begin_up[p] for p in range(num_pp_stages)]
remaining_flag = True
count = 0
while remaining_flag:
ids_old = []
ids_new = []
for s in range(num_pp_stages):
ids_old.append(remaining[s])
if remaining[s]:
microbatch_idx = len(stage_list[0]) - remaining[s]
(end_time_prev_task_same_stage,
end_time_dependent_task_same_microbatch,
job_flag) = get_prev_task_time(start_time, stage_list, s, microbatch_idx, chunk_stage_map,
fwd_bwd_comp_order, fwd_bwd_durations, comm_matrix)
if job_flag:
start_time[stage_list[s][microbatch_idx]] = max(end_time_prev_task_same_stage,
end_time_dependent_task_same_microbatch)
remaining[s] = remaining[s] - 1
ids_new.append(remaining[s])
if all(item == 0 for item in remaining):
remaining_flag = False
if ids_old == ids_new:
count += 1
if count == 3:
self.logger.info("stage list is locked")
start_time[f'B_0-{num_mb - 1}'] = 1e7
break
e2e_time = start_time[f'B_0-{num_mb - 1}'] + paras.bwd_durations[-1]
stage_start_time = [[start_time[job_name] for job_name in stage_list[s]] for s in range(num_pp_stages)]
return e2e_time, stage_start_time