from copy import deepcopy
from typing import List
import torch.cuda as cuda
from mindspeed.auto_settings.config.system_config import SystemConfig
from mindspeed.auto_settings.module.parse.profiling_parse.profiling_constant import NumberConstant
from mindspeed.auto_settings.config.search_config import SearchConfig
class ProfilingConfig:
    """
    Basic parameters of profiling.

    Holds a private copy of the search configuration, a few values derived
    from it (layers per ``pp`` slice, ``vpp`` factor, micro-batch count) and
    the node layout. The ``search_first_operator_idx_*`` helpers map the
    sequences passed as ``fw_norm_index`` / ``bw_norm_index`` onto
    per-micro-batch boundaries.

    NOTE(review): the index sequences are presumably positions of norm
    operators in a profiled operator list -- confirm with the caller.
    """

    def __init__(self, search_cfg: SearchConfig, args=None):
        """
        Build the profiling parameters.

        Args:
            search_cfg: Search configuration; a deep copy is stored on the
                instance, while the derived values below are read from the
                object that was passed in.
            args: Optional source of the node layout. A ``SystemConfig``
                supplies ``nnodes`` / ``nproc_per_node`` / ``node_rank``
                directly; any other truthy object must expose ``world_size``
                and ``rank``, which are combined with ``cuda.device_count()``.
                A falsy value selects the defaults (1 node, 8 devices,
                node rank 0).
        """
        self.search_cfg = deepcopy(search_cfg)
        # Floor division of the layer count by ``pp``.
        self.per_micro_layer = search_cfg.num_layers // search_cfg.pp
        # A falsy ``vpp`` (None or 0) is treated as 1.
        self.vpp = search_cfg.vpp if search_cfg.vpp else 1
        # gbs // (mbs * dp), scaled by the vpp factor.
        self.micro_num = search_cfg.gbs // (search_cfg.mbs * search_cfg.dp) * self.vpp
        # Always starts at 0; nothing in this class assigns it afterwards.
        self.stage_id = 0
        if args:
            if isinstance(args, SystemConfig):
                self.nodes = args.nnodes
                self.devices_per_node = args.nproc_per_node
                self.node_rank = args.node_rank
            else:
                # Layout derived from the global world size / rank and the
                # device count reported by ``cuda.device_count()``.
                self.nodes = args.world_size // cuda.device_count()
                self.devices_per_node = cuda.device_count()
                self.node_rank = args.rank // cuda.device_count()
        else:
            self.nodes = 1
            self.devices_per_node = 8
            self.node_rank = 0

    def search_first_operator_idx_for_per_layer_enable_pp_last_stage(self, fw_norm_index, bw_norm_index):
        """
        Locate per-micro-batch boundaries for the last ``pp`` stage.

        Walks ``self.micro_num`` micro batches while advancing one cursor
        through ``fw_norm_index`` and one through ``bw_norm_index``; the step
        sizes come from ``NumberConstant`` via ``_calculate_fw_idx`` /
        ``_calculate_bw_idx``.

        Args:
            fw_norm_index: Flat indexable sequence read at the forward cursor.
            bw_norm_index: Flat indexable sequence read at the backward cursor.

        Returns:
            Tuple ``(fw_layer_start, bw_layer_end, recompute_fw,
            fw_per_micro_opt_num, bw_per_micro_opt_num)``. The first three are
            lists of one-element lists; ``recompute_fw`` stays empty unless
            ``dist_train`` is falsy and ``is_full_recompute()`` is true. The
            last two are differences between entries of the index sequences.
        """
        fw_layer_start = []
        bw_layer_end = []
        recompute_fw = []
        warm_micro_num = self._calculate_warm_micro_num()
        bw_idx = 0
        fw_idx = 0
        for micro in range(self.micro_num):
            # Which block of ``vpp * pp`` consecutive micro batches this one is in.
            i = micro // (self.vpp * self.search_cfg.pp)
            fw_layer_start.append([fw_norm_index[fw_idx]])
            fw_idx = self._calculate_fw_idx(fw_idx, i, micro)
            bw_idx = self._calculate_bw_idx(bw_idx, i, micro)
            # The backward cursor has already been advanced, so ``bw_idx - 1``
            # is the last entry belonging to this micro batch.
            bw_layer_end.append([bw_norm_index[bw_idx - 1]])
            if not self.search_cfg.dist_train and self.search_cfg.is_full_recompute():
                if warm_micro_num <= micro + 1:
                    # Entry at the forward cursor is recorded as recompute,
                    # then the cursor skips over it.
                    recompute_fw.append([fw_norm_index[fw_idx]])
                    fw_idx += NumberConstant.FW_NORM_OP_NUM_ENABLE_PP_OTHER_STAGE
                # NOTE(review): nesting of this check is inferred (the source
                # indentation was lost) -- confirm it sits beside, not inside,
                # the check above.
                if micro == self.micro_num - 1:
                    # Reuses the name ``i``; the outer value is recomputed at
                    # the top of the next iteration, and this is the last one.
                    for i in range(warm_micro_num - 1):
                        # NOTE(review): the offset accumulates (0, 1x, 3x, ...
                        # of the constant) -- confirm this is intended.
                        fw_idx += i * NumberConstant.FW_NORM_OP_NUM_ENABLE_PP_OTHER_STAGE
                        recompute_fw.append([fw_norm_index[fw_idx]])
        if self.vpp > 1:
            # Requires at least two micro batches.
            fw_per_micro_opt_num = fw_layer_start[1][0] - fw_layer_start[0][0]
        else:
            fw_per_micro_opt_num = fw_norm_index[2] - fw_norm_index[0]
        bw_per_micro_opt_num = bw_norm_index[2] - bw_norm_index[0]
        return fw_layer_start, bw_layer_end, recompute_fw, fw_per_micro_opt_num, bw_per_micro_opt_num

    def search_first_operator_idx_for_per_layer_enable_pp_other_stage(self, fw_norm_index, bw_norm_index):
        """
        Locate per-micro-batch boundaries for a ``pp`` stage other than the last.

        Both inputs are first regrouped into consecutive pairs (a trailing odd
        element is dropped). The first ``warm_micro_num`` micro batches take
        one forward pair each; every later micro batch takes two pairs, the
        first recorded as recompute and the second as forward start.

        Args:
            fw_norm_index: Flat sequence of forward indices (sliceable).
            bw_norm_index: Flat sequence of backward indices (sliceable).

        Returns:
            Tuple ``(fw_layer_start, bw_layer_end, recompute_fw,
            fw_per_micro_opt_num, bw_per_micro_opt_num)`` with the same layout
            as the last-stage variant. Requires at least two micro batches
            because ``fw_layer_start[1]`` is read.
        """
        fw_layer_start = []
        bw_layer_end = []
        recompute_fw = []
        # Regroup the flat lists into [first, second] pairs.
        fw_norm_index = [fw_norm_index[i * 2: (i + 1) * 2] for i in range(len(fw_norm_index) // 2)]
        bw_norm_index = [bw_norm_index[i * 2: (i + 1) * 2] for i in range(len(bw_norm_index) // 2)]
        warm_micro_num = self._calculate_warm_micro_num()
        for micro in range(self.micro_num):
            if micro < warm_micro_num:
                fw_layer_start.append([fw_norm_index[micro][0]])
            else:
                # Pair ``2 * micro - warm_micro_num`` is the recompute entry and
                # the pair right after it is this micro batch's forward start.
                fw_layer_start.append([fw_norm_index[micro + micro - warm_micro_num + 1][0]])
                recompute_fw.append([fw_norm_index[micro + micro - warm_micro_num][0]])
                # NOTE(review): nesting of this check is inferred (the source
                # indentation was lost) -- confirm it belongs in this branch.
                if micro == self.micro_num - 1:
                    # The final ``warm_micro_num`` pairs are all recorded as recompute.
                    recompute_fw.extend(
                        [[index[0]] for index in fw_norm_index[len(fw_norm_index) - warm_micro_num:]])
            bw_layer_end.append([bw_norm_index[micro][-1]])
        if not self.search_cfg.dist_train and self.search_cfg.is_full_recompute():
            if len(recompute_fw) != self.micro_num:
                # Top up the recompute list from the pairs located
                # ``micro_num`` positions after each missing index.
                for i in range(len(recompute_fw), self.micro_num):
                    recompute_fw.append([fw_norm_index[i + self.micro_num][0]])
            bw_per_micro_opt_num = bw_norm_index[0][-1] - recompute_fw[0][0]
        else:
            bw_per_micro_opt_num = bw_norm_index[1][0] - bw_norm_index[0][0]
        fw_per_micro_opt_num = fw_layer_start[1][0] - fw_layer_start[0][0]
        return fw_layer_start, bw_layer_end, recompute_fw, fw_per_micro_opt_num, bw_per_micro_opt_num

    def search_first_operator_idx_for_per_layer_enable_pp(self, fw_norm_index, bw_norm_index):
        """
        Dispatch to the last-stage or other-stage search.

        The last-stage variant is used when ``self.stage_id`` equals
        ``self.search_cfg.pp - 1``; otherwise the other-stage variant is used.
        Arguments and return value are passed through unchanged.
        """
        if self.stage_id == self.search_cfg.pp - 1:
            return self.search_first_operator_idx_for_per_layer_enable_pp_last_stage(fw_norm_index, bw_norm_index)
        else:
            return self.search_first_operator_idx_for_per_layer_enable_pp_other_stage(fw_norm_index, bw_norm_index)

    def search_first_operator_idx_for_per_layer_disable_pp(self, fw_norm_index, bw_norm_index):
        """
        Locate per-micro-batch boundaries for the non-``pp`` case.

        With full recompute (and ``dist_train`` falsy) both inputs are split
        into ``self.micro_num`` equally sized chunks; otherwise they are read
        with a fixed stride of three entries per micro batch.

        Args:
            fw_norm_index: Flat sequence of forward indices.
            bw_norm_index: Flat sequence of backward indices.

        Returns:
            Tuple ``(fw_layer_start, bw_layer_end, recompute_fw,
            fw_per_micro_opt_num, bw_per_micro_opt_num)``; ``recompute_fw`` is
            only filled in the full-recompute branch.
        """
        fw_layer_start = []
        bw_layer_end = []
        recompute_fw = []
        if not self.search_cfg.dist_train and self.search_cfg.is_full_recompute():
            # Chunk size per micro batch; entries beyond
            # ``chunk * micro_num`` are ignored.
            fw_micro_rms_num = len(fw_norm_index) // self.micro_num
            fw_norm_index = [fw_norm_index[fw_micro_rms_num * i:fw_micro_rms_num * (i + 1)]
                             for i in range(self.micro_num)]
            bw_micro_rms_num = len(bw_norm_index) // self.micro_num
            bw_norm_index = [bw_norm_index[bw_micro_rms_num * i:bw_micro_rms_num * (i + 1)]
                             for i in range(self.micro_num)]
            # Spans are measured on the first chunk only.
            fw_per_micro_opt_num = fw_norm_index[0][2] - fw_norm_index[0][0]
            bw_per_micro_opt_num = bw_norm_index[0][2] - bw_norm_index[0][0]
            for micro in range(self.micro_num):
                fw_layer_start.append([fw_norm_index[micro][0]])
                bw_layer_end.append([bw_norm_index[micro][-1]])
                # The fourth entry of a chunk, when present, is recorded as recompute.
                if len(fw_norm_index[micro]) > 3:
                    recompute_fw.append([fw_norm_index[micro][3]])
        else:
            fw_per_micro_opt_num = fw_norm_index[2] - fw_norm_index[0]
            bw_per_micro_opt_num = bw_norm_index[2] - bw_norm_index[0]
            for micro in range(self.micro_num):
                # Three entries per micro batch: first one forward, last one backward.
                fw_layer_start.append([fw_norm_index[3 * micro]])
                bw_layer_end.append([bw_norm_index[3 * (micro + 1) - 1]])
        return fw_layer_start, bw_layer_end, recompute_fw, fw_per_micro_opt_num, bw_per_micro_opt_num

    def _calculate_warm_micro_num(self):
        """
        Return the warm-up micro-batch count used by the search helpers.

        ``pp * (vpp - 1) + 1 + (pp - stage_id - 1) * 2`` when ``vpp != 1``,
        otherwise ``pp - stage_id``.
        """
        if self.vpp != 1:
            return self.search_cfg.pp * (self.vpp - 1) + 1 + (self.search_cfg.pp - self.stage_id - 1) * 2
        else:
            return self.search_cfg.pp - self.stage_id

    def _calculate_fw_idx(self, fw_idx, i, micro):
        """
        Advance the forward cursor for one micro batch.

        Args:
            fw_idx: Current forward cursor.
            i: Index of the ``vpp * pp`` block that ``micro`` belongs to.
            micro: Micro-batch number.

        Returns:
            ``fw_idx`` plus the OTHER_STAGE step when ``vpp > 1`` and ``micro``
            is among the first ``pp`` micro batches of its block, otherwise
            plus the LAST_STAGE step.
        """
        if i * (self.vpp * self.search_cfg.pp) <= micro < i * (
                self.vpp * self.search_cfg.pp) + self.search_cfg.pp and self.vpp > 1:
            fw_idx += NumberConstant.FW_NORM_OP_NUM_ENABLE_PP_OTHER_STAGE
        else:
            fw_idx += NumberConstant.FW_NORM_OP_NUM_ENABLE_PP_LAST_STAGE
        return fw_idx

    def _calculate_bw_idx(self, bw_idx, i, micro):
        """
        Advance the backward cursor for one micro batch.

        Args:
            bw_idx: Current backward cursor.
            i: Index of the ``vpp * pp`` block that ``micro`` belongs to.
            micro: Micro-batch number.

        Returns:
            ``bw_idx`` plus the LAST_STAGE step when ``micro`` is among the
            first ``pp`` micro batches of its block or ``vpp == 1``, otherwise
            plus the OTHER_STAGE step. The ``FW_``-prefixed constants are used
            for the backward cursor as well.
        """
        if i * (self.vpp * self.search_cfg.pp) <= micro < i * (
                self.vpp * self.search_cfg.pp) + self.search_cfg.pp or self.vpp == 1:
            bw_idx += NumberConstant.FW_NORM_OP_NUM_ENABLE_PP_LAST_STAGE
        else:
            bw_idx += NumberConstant.FW_NORM_OP_NUM_ENABLE_PP_OTHER_STAGE
        return bw_idx
class ProfilingLayerInfo:
    """
    Profiling record for one layer.

    Every attribute created in ``__init__`` is a list, so two records can be
    merged attribute by attribute with :meth:`extend_attr`.
    """

    def __init__(self):
        self.time = []
        self.start_memory = []
        self.peak_memory = []
        self.reserved_memory = []
        self.operator_info = []
        self.communication_info = []

    def extend_attr(self, new_layer):
        """
        Append the list attributes of ``new_layer`` onto this record in place.

        Only attributes of ``self`` that currently hold a list take part.
        An attribute that is missing on ``new_layer``, or set to ``None``
        there, contributes nothing.

        Args:
            new_layer: Record to merge from. ``None`` is accepted and leaves
                this record unchanged.
        """
        if new_layer is None:
            return
        for attr_name, own_values in vars(self).items():
            if not isinstance(own_values, list):
                continue
            incoming = getattr(new_layer, attr_name, None)
            if incoming is not None:
                # In-place extend keeps the identity of the existing list, so
                # no attribute re-assignment is needed.
                own_values.extend(incoming)
class ProfilingModelInfo:
    """
    Profiling record for a whole model.

    Combines one ``ProfilingLayerInfo`` per phase (embedding, forward, loss,
    backward, optimizer) with several list-valued attributes and two scalar
    fields (``search_cfg``, ``stage_id``). Records can be merged with
    :meth:`extend_stage_info`.
    """

    def __init__(self):
        # Per-phase layer records.
        self.embedding = ProfilingLayerInfo()
        self.forward = ProfilingLayerInfo()
        self.loss = ProfilingLayerInfo()
        self.backward = ProfilingLayerInfo()
        self.optimizer = ProfilingLayerInfo()
        # List-valued attributes; all of them are merged by extend_stage_info.
        self.hccl_memory = []
        self.cann_and_driver_memory = []
        self.recompute_memory = []
        self.communication_matrix = []
        self.context_parallel_comm = []
        self.pipeline_parallel_comm = []
        self.data_parallel_comm = []
        self.tensor_parallel_comm = []
        self.expert_parallel_comm = []
        # Scalar fields; never touched by extend_stage_info.
        self.search_cfg = None
        self.stage_id = 0
        self.mc2_total_time = []
        self.matmul_total_time = []

    def extend_stage_info(self, new_model):
        """
        Merge another model record into this one in place.

        List attributes are extended with the matching list of ``new_model``;
        ``ProfilingLayerInfo`` attributes are merged through their
        ``extend_attr``. Attributes of any other type are left alone.
        An attribute that is missing on ``new_model``, or set to ``None``
        there, contributes nothing.

        Args:
            new_model: Record to merge from. ``None`` is accepted and leaves
                this record unchanged.
        """
        if new_model is None:
            return
        for attr_name, own_value in vars(self).items():
            if isinstance(own_value, list):
                incoming = getattr(new_model, attr_name, None)
                if incoming is not None:
                    own_value.extend(incoming)
            elif isinstance(own_value, ProfilingLayerInfo):
                other_layer = getattr(new_model, attr_name, None)
                if other_layer is not None:
                    own_value.extend_attr(other_layer)
class BaseParallelCommunication:
    """
    Basic parallel communication information.

    Common fields shared by the per-strategy communication records: four
    float values that start at ``0.0`` and a ``details`` list that starts
    empty.
    """

    def __init__(self):
        # Float fields, all zero until assigned by the code that uses them.
        self.total_time_ms: float = 0.0
        self.wait_time_ms: float = 0.0
        self.overlap_time_ms: float = 0.0
        self.avg_ratio: float = 0.0
        # A fresh list per instance, annotated as holding dict entries.
        self.details: List[dict] = []
class ExpertParallelCommunication(BaseParallelCommunication):
    """
    Expert parallel communication.

    Adds ``min_comm_time_ms`` to the common communication fields.
    """

    def __init__(self):
        super().__init__()
        # Starts at zero like the inherited float fields.
        self.min_comm_time_ms: float = 0.0
class TensorParallelCommunication(BaseParallelCommunication):
    """
    Tensor parallel communication.

    Adds ``fixed_time_ms`` and ``fixed_wait_time_ms`` to the common
    communication fields.
    """

    def __init__(self):
        super().__init__()
        # Both start at zero like the inherited float fields.
        self.fixed_time_ms: float = 0.0
        self.fixed_wait_time_ms: float = 0.0
class ContextParallelCommunication(BaseParallelCommunication):
    """
    Context parallel communication.

    Adds ``vector_time_ms`` to the common communication fields.
    """

    def __init__(self):
        super().__init__()
        # Starts at zero like the inherited float fields.
        self.vector_time_ms: float = 0.0
class DataParallelCommunication(BaseParallelCommunication):
    """
    Data parallel communication.

    Adds six float fields to the common communication fields: a
    ``zero`` / ``ag`` / ``rs`` triple prefixed ``mlp_`` and the same triple
    prefixed ``other_``.
    """

    def __init__(self):
        super().__init__()
        # ``mlp_`` triple, all zero initially.
        self.mlp_zero_time_ms: float = 0.0
        self.mlp_ag_time_ms: float = 0.0
        self.mlp_rs_time_ms: float = 0.0
        # ``other_`` triple, all zero initially.
        self.other_zero_time_ms: float = 0.0
        self.other_ag_time_ms: float = 0.0
        self.other_rs_time_ms: float = 0.0
class PipelineParallelCommunication(BaseParallelCommunication):
    """
    Pipeline parallel communication.

    Adds ``min_pp_time`` to the common communication fields.
    """

    def __init__(self):
        super().__init__()
        # Starts at zero; unlike the other fields its name has no ``_ms`` suffix.
        self.min_pp_time: float = 0.0