"""
算子预估
"""
from mindspeed.auto_settings.config.model_config import get_model_config
from mindspeed.auto_settings.config.system_config import get_system_config
from mindspeed.auto_settings.module.memory.memory_modeling import MemoryModeling
from mindspeed.auto_settings.utils.logger import get_logger
from mindspeed.auto_settings.utils.utils import get_num_warmup_micro_batches
class MemoryCost(object):
    """Estimate the peak memory of a candidate configuration and plan recomputation.

    Call ``train_models()`` once to build the underlying memory model, then call
    ``get_memory_cost(config)`` for each configuration to evaluate.
    """

    def __init__(self):
        self.logger = get_logger("Memory")
        # Stays None until train_models() runs; get_memory_cost() requires it to be set.
        self.memory_model: MemoryModeling = None

    def train_models(self):
        """Create the memory model from the model config and fit it.

        The system config's ``work_dir`` is handed to ``MemoryModeling.modeling``.
        """
        work_dir = get_system_config().work_dir
        self.memory_model = MemoryModeling(get_model_config())
        self.memory_model.modeling(working_dir=work_dir)

    def get_memory_cost(self, config):
        """Estimate memory for ``config`` and decide how many layers need recompute.

        ``train_models()`` must have been called beforehand, because this method
        calls ``self.memory_model.estimate(config)``.

        Args:
            config: Candidate configuration. This method reads ``config.pp`` and
                ``config.layers_per_vpp`` and forwards the object to the memory
                model and to ``get_num_warmup_micro_batches``.

        Returns:
            dict with the keys ``layer_calculate``, ``peak_memory``,
            ``need_recompute``, ``new_memory`` and ``recompute_layer``
            (see ``_plan_recompute`` for how each value is derived).
        """
        model_config = get_model_config()
        num_layers = model_config.num_layers // config.pp
        # Fall back to the per-stage layer count when layers_per_vpp is unset/zero.
        layers_per_vpp = config.layers_per_vpp or num_layers
        device_mem_cap = get_system_config().memory_cap

        memory_per_layer, peak_stage_mem, optimizer_peak = self.memory_model.estimate(config)
        self.logger.debug(f"before recompute, memory = {peak_stage_mem}")

        # Only the warm-up count is needed; the total micro-batch count is ignored.
        warmup_micro_batches, _ = get_num_warmup_micro_batches(config, model_config)

        return self._plan_recompute(
            memory_per_layer=memory_per_layer,
            peak_stage_mem=peak_stage_mem,
            optimizer_peak=optimizer_peak,
            device_mem_cap=device_mem_cap,
            warmup_micro_batches=warmup_micro_batches,
            layers_per_vpp=layers_per_vpp,
            num_layers=num_layers,
            pp=config.pp,
        )

    @staticmethod
    def _plan_recompute(memory_per_layer, peak_stage_mem, optimizer_peak, device_mem_cap,
                        warmup_micro_batches, layers_per_vpp, num_layers, pp):
        """Pure arithmetic behind ``get_memory_cost``; returns the result dict.

        Args:
            memory_per_layer: First value returned by ``MemoryModeling.estimate``.
            peak_stage_mem: Second value returned by ``MemoryModeling.estimate``.
            optimizer_peak: Third value returned by ``MemoryModeling.estimate``.
            device_mem_cap: ``memory_cap`` from the system config.
            warmup_micro_batches: Warm-up micro-batch count.
            layers_per_vpp: Layers per virtual pipeline chunk (or per stage).
            num_layers: Layers per pipeline stage.
            pp: Pipeline-parallel size.

        Returns:
            dict:
                ``peak_memory``: ``max(peak_stage_mem, optimizer_peak)``.
                ``need_recompute``: False when the extra per-layer memory fits
                in the headroom, True otherwise.
                ``layer_calculate``: how many multiples of
                ``memory_per_layer * pp`` fit in the headroom (never negative).
                ``recompute_layer``: ``num_layers - layer_calculate``.
                ``new_memory``: ``peak_stage_mem`` plus the per-layer memory kept.
        """
        peak_mem = max(peak_stage_mem, optimizer_peak)
        # Headroom left on the device on top of the stage peak.
        oom_cap = device_mem_cap - peak_stage_mem
        # NOTE(review): one layer's share is excluded here and in the partial
        # case below; the rationale is not visible in this file.
        max_release_mem = warmup_micro_batches * layers_per_vpp * memory_per_layer - memory_per_layer

        if max_release_mem <= oom_cap:
            return {
                "layer_calculate": 0,
                "peak_memory": peak_mem,
                "need_recompute": False,
                "new_memory": max_release_mem + peak_stage_mem,
                "recompute_layer": 0
            }

        per_layer_all_stages = memory_per_layer * pp
        if oom_cap > 0 and per_layer_all_stages > 0:
            layer_calculate = oom_cap // per_layer_all_stages
        else:
            # No headroom (or nothing to divide by): floor division would give a
            # negative layer count that pushes new_memory below the real peak,
            # or raise ZeroDivisionError. Report that no layer can be kept.
            layer_calculate = 0

        release_mem = layer_calculate * memory_per_layer * pp
        if 0 < layer_calculate < num_layers:
            release_mem -= memory_per_layer

        return {
            "layer_calculate": layer_calculate,
            "peak_memory": peak_mem,
            "need_recompute": True,
            "new_memory": release_mem + peak_stage_mem,
            "recompute_layer": num_layers - layer_calculate
        }