from typing import no_type_check, Any, DefaultDict, Iterable, List, Set, Tuple
from collections import defaultdict
from copy import deepcopy
from dataclasses import replace
import math
import os
from mindspeed.auto_settings.config.model_config import ModelConfig
from mindspeed.auto_settings.config.search_config import SearchConfig
from mindspeed.auto_settings.search_space import SearchSpace
from mindspeed.auto_settings.utils.dtype import DTYPE
from mindspeed.auto_settings.utils.file_utils import restricted_read
from mindspeed.auto_settings.utils.logger import get_logger
from mindspeed.auto_settings.utils.mem_utils import mem_b_to_mb
from mindspeed.auto_settings.utils.utils import get_tp_for_profiling
_Param = DefaultDict[str, List[int]]
class StaticMemModeling:
PP4_FILENAME = "auto_settings_static_model_pp4.json"
EXPERT2_FILENAME = "auto_settings_static_model_expert2.json"
TP2_FILENAME = "auto_settings_static_model_tp2.json"
@no_type_check
def __init__(self, model_cfg: ModelConfig) -> None:
self.model_config = model_cfg
self._logger = get_logger("static_mem")
self._baseline_tp = get_tp_for_profiling()
self.params_inputs: _Param = None
self.params_per_layer_wo_experts: _Param = None
self.params_per_experts: _Param = None
self.params_outputs: _Param = None
self.params_tp_unaffected: Set[str] = set()
@staticmethod
def _cal_embedding_size(
vocab_size: int,
make_vocab_size_divisible_by: int,
tp: int,
hidden_size: int
) -> int:
division = make_vocab_size_divisible_by * tp
padded_vocab_size = int(math.ceil(vocab_size / division) * division)
return padded_vocab_size * hidden_size
@staticmethod
def _diff_params(left: _Param, right: _Param) -> _Param:
result = deepcopy(left)
for k, v in right.items():
result[k].extend([-params for params in v])
zeros_out_keys = list()
for k, v in result.items():
if sum(v) == 0:
zeros_out_keys.append(k)
for k in zeros_out_keys:
result.pop(k)
return result
@staticmethod
def _merge_params(params_list: Iterable[_Param]) -> _Param:
result = defaultdict(list)
for params in params_list:
for k, v in params.items():
result[k].extend(v)
return result
def generate_static_mem_profiling_list(self) -> List[Tuple[SearchConfig, str]]:
result: List[Tuple[SearchConfig, str]] = list()
pp4_cfg = SearchConfig()
pp4_cfg.copy_from_config(self.model_config)
pp4_cfg.tensor_model_parallel_size = self._baseline_tp
pp4_cfg.context_parallel_size = 1
pp4_cfg.pipeline_model_parallel_size = 4
pp4_cfg.num_layers = 4
pp4_cfg.untie_embeddings_and_output_weights = True
pp4_cfg.seq_length = 4096
if self.model_config.is_moe():
pp4_cfg.num_experts = 1
pp4_cfg.expert_model_parallel_size = 1
pp4_cfg.moe_tp_extend_ep = False
result.append((pp4_cfg, self.PP4_FILENAME))
if self.model_config.is_moe():
expert2_cfg = replace(
pp4_cfg,
pipeline_model_parallel_size=1,
num_layers=1,
num_experts=2,
expert_model_parallel_size=1
)
result.append((expert2_cfg, self.EXPERT2_FILENAME))
tp2 = self._baseline_tp * 2
pp = 1
if self.model_config.dist_train:
tp2 = 1
pp = 4
tp2_cfg = replace(
pp4_cfg,
tensor_model_parallel_size=tp2,
pipeline_model_parallel_size=pp,
num_layers=1,
untie_embeddings_and_output_weights=False
)
result.append((tp2_cfg, self.TP2_FILENAME))
for cfg, _ in result:
self._logger.debug(f"cfg.tp:{cfg.tensor_model_parallel_size}\
cfg.cp:{cfg.context_parallel_size}\
cfg.pp:{cfg.pipeline_model_parallel_size}")
cfg.prepare_for_profiling()
return result
def model_static_mem(self, working_dir: str) -> None:
def _decode(filename: str) -> Any:
filepath = os.path.join(working_dir, filename)
return restricted_read(filepath)
def _get_pp4_params_list(filename: str) -> List[_Param]:
result: List[_Param] = [None] * 4
for pp_rank, model_params in _decode(filename):
if not result[pp_rank]:
result[pp_rank] = defaultdict(list)
for p_name, p_size in model_params:
result[pp_rank][p_name].append(p_size)
return result
total_pp4_params = _get_pp4_params_list(self.PP4_FILENAME)
per_layer_w_experts_params = total_pp4_params[1]
self.params_inputs = self._diff_params(
total_pp4_params[0],
per_layer_w_experts_params
)
self.params_outputs = self._diff_params(
total_pp4_params[-1],
per_layer_w_experts_params
)
if self.model_config.is_moe():
self.params_per_experts = self._diff_params(
_get_pp4_params_list(self.EXPERT2_FILENAME)[0],
self._merge_params((
self.params_inputs,
per_layer_w_experts_params,
self.params_outputs
))
)
else:
self.params_per_experts = defaultdict(list)
self.params_per_layer_wo_experts = self._diff_params(
per_layer_w_experts_params,
self.params_per_experts
)
tp1_params = self._merge_params((
self.params_inputs,
self.params_per_layer_wo_experts,
self.params_per_experts,
self.params_outputs
))
tp2_params = _get_pp4_params_list(self.TP2_FILENAME)[0]
last_embedding_name = set(tp1_params.keys()).difference(tp2_params.keys())
last_embedding_params: Set[int] = set()
for k in last_embedding_name:
last_embedding_params.add(sum(self.params_outputs.pop(k)))
first_embedding_name: Set[str] = set()
for k, v in self.params_inputs.items():
if sum(v) in last_embedding_params:
first_embedding_name.add(k)
last_embedding_params.remove(sum(v))
for k in first_embedding_name:
self.params_inputs.pop(k)
for k in tp2_params.keys():
if sum(tp1_params.get(k, list())) == sum(tp2_params.get(k, list())):
self.params_tp_unaffected.add(k)
for k, v in self.params_inputs.items():
if k not in self.params_tp_unaffected:
self.params_inputs[k] = [p * self._baseline_tp for p in v]
for k, v in self.params_per_layer_wo_experts.items():
if k not in self.params_tp_unaffected:
self.params_per_layer_wo_experts[k] = [p * self._baseline_tp for p in v]
for k, v in self.params_per_experts.items():
if k not in self.params_tp_unaffected:
self.params_per_experts[k] = [p * self._baseline_tp for p in v]
for k, v in self.params_outputs.items():
if k not in self.params_tp_unaffected:
self.params_outputs[k] = [p * self._baseline_tp for p in v]
self._logger.debug("== first embedding name:")
for k in first_embedding_name:
self._logger.debug(k)
self._logger.debug("== headings params:")
for k, v in self.params_inputs.items():
self._logger.debug(f"{k}: {sum(v)} | {v}")
self._logger.debug("== layer_wo_experts params:")
for k, v in self.params_per_layer_wo_experts.items():
self._logger.debug(f"{k}: {sum(v)} | {v}")
self._logger.debug("== experts params:")
for k, v in self.params_per_experts.items():
self._logger.debug(f"{k}: {sum(v)} | {v}")
self._logger.debug("== outputs params:")
for k, v in self.params_outputs.items():
self._logger.debug(f"{k}: {sum(v)} | {v}")
self._logger.debug(f"== last embedding name:")
for k in last_embedding_name:
self._logger.debug(k)
self._logger.debug("== not tp affected params:")
for name in self.params_tp_unaffected:
self._logger.debug(name)
def cal_static_mem(self, cfg: SearchConfig) -> List[float]:
dtype = self.model_config.dtype
non_expert_zero1 = cfg.dp * cfg.cp
expert_zero1 = cfg.dp * cfg.cp / (cfg.ep if cfg.ep else 1)
def _cal_static_mem_per_stage(
non_expert_params: int,
expert_params: int,
not_zero1_div_bytes: int,
zero1_div_bytes: int
) -> float:
result = float(0)
if cfg.zero1:
result += non_expert_params * \
(not_zero1_div_bytes + zero1_div_bytes / non_expert_zero1)
result += expert_params * \
(not_zero1_div_bytes + zero1_div_bytes / expert_zero1)
else:
result += (non_expert_params + expert_params) * \
(not_zero1_div_bytes + zero1_div_bytes)
result = mem_b_to_mb(result * dtype.value[1])
result += 5000
return result
static_mem_stages: List[float] = list()
for stage_id in range(cfg.pp):
non_expert_params_per_stage, expert_params_per_stage = \
self._cal_num_params_per_stage(stage_id, cfg)
if dtype == DTYPE.fp16:
static_mem_per_stage = _cal_static_mem_per_stage(
non_expert_params_per_stage,
expert_params_per_stage,
1 + 1,
8
)
elif dtype == DTYPE.bf16:
static_mem_per_stage = _cal_static_mem_per_stage(
non_expert_params_per_stage,
expert_params_per_stage,
1 + 2,
6
)
else:
static_mem_per_stage = _cal_static_mem_per_stage(
non_expert_params_per_stage,
expert_params_per_stage,
1 + 1,
2
)
static_mem_stages.append(static_mem_per_stage)
return static_mem_stages
def _cal_num_params_per_stage(
self,
stage_id: int,
cfg: SearchConfig
) -> Tuple[int, int]:
def _cal_num_params(p_name: str, p_size: List[int], ep: int = 1):
if p_name in self.params_tp_unaffected:
return sum(p_size)
else:
return sum(p_size) // ep // cfg.tp
num_layers = self.model_config.num_layers
non_expert_params = 0
for p_name, p_size in self.params_per_layer_wo_experts.items():
non_expert_params += _cal_num_params(p_name, p_size)
non_expert_params *= num_layers // cfg.pp
expert_params = 0
if cfg.num_experts and cfg.ep:
for p_name, p_size in self.params_per_experts.items():
expert_params += _cal_num_params(p_name, p_size, ep=cfg.ep)
expert_params *= (num_layers * cfg.num_experts) // cfg.pp
if stage_id == 0:
for p_name, p_size in self.params_inputs.items():
non_expert_params += _cal_num_params(p_name, p_size)
non_expert_params += self._cal_embedding_size(
self.model_config.vocab_size,
self.model_config.make_vocab_size_divisible_by,
cfg.tp,
self.model_config.hidden_size
) // cfg.tp
if stage_id == cfg.pp - 1:
for p_name, p_size in self.params_outputs.items():
non_expert_params += _cal_num_params(p_name, p_size)
if cfg.pp != 1 or self.model_config.untie_embeddings_and_output_weights:
non_expert_params += self._cal_embedding_size(
self.model_config.vocab_size,
self.model_config.make_vocab_size_divisible_by,
cfg.tp,
self.model_config.hidden_size
) // cfg.tp
return non_expert_params, expert_params