import os
import stat
import time
import json
import re
import atexit
import torch
from mindspeed.auto_settings.utils.logger import get_logger
from mindspeed.auto_settings.utils.singleton import Singleton
from mindspeed.auto_settings.utils.utils import standardize_path
class AutoPatcher(metaclass=Singleton):
def __init__(self, save_path):
self.module_profiling_step = 5
self.stop_profiling_step = 10
self.curr_step = 0
self.unit_gb = 1024 ** 3
self.context = {}
self.handles = []
self.profile_modules = ('embedding', '0', 'final_layernorm', 'output_layer')
self.save_path = standardize_path(save_path, check_read=False)
atexit.register(self.export_to_file)
self.logger = get_logger("AutoPatcher")
def export_to_file(self):
local_rank = int(os.environ.get('LOCAL_RANK', 0))
if local_rank == 0:
self.logger.info(f"rank: {torch.distributed.get_rank()} saving context: {self.context}")
flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
modes = stat.S_IWUSR | stat.S_IRUSR
if os.path.exists(self.save_path):
self.logger.warning(f'{self.save_path} will be overwrited !')
with open(self.save_path, 'w') as fout:
fout.write(json.dumps(self.context))
os.chmod(self.save_path, 0o640)
@staticmethod
def get_memory_status():
memory = torch.npu.memory_allocated()
max_memory = torch.npu.max_memory_allocated()
return memory, max_memory
def should_profiling(self, collect_step_time=False):
if collect_step_time:
return self.module_profiling_step <= self.curr_step < self.stop_profiling_step
else:
return self.curr_step < self.module_profiling_step
def hook_train_step(self, train_step):
def custom_train_step(*args, **kwargs):
if self.should_profiling(collect_step_time=True):
for handle in self.handles:
handle.remove()
torch.cuda.synchronize()
start_time = time.time()
result = train_step(*args, **kwargs)
torch.cuda.synchronize()
step_time = time.time() - start_time
if self.should_profiling(collect_step_time=True):
cur_step_time = self.context.get('step_time', 0)
cur_step_time = (cur_step_time * (self.curr_step - self.module_profiling_step) + step_time) \
/ (self.curr_step - self.module_profiling_step + 1)
self.context['step_time'] = cur_step_time
self.curr_step += 1
return result
return custom_train_step
def forward_pre_hook(self, module_name):
if module_name not in self.context.keys():
self.context[module_name] = dict()
def hook(module, *args, **kargs):
if self.should_profiling(collect_step_time=False):
if module_name not in self.context:
self.context[module_name] = {}
torch.npu.synchronize()
mem, _ = self.get_memory_status()
self.context[module_name]['time'] = time.time()
self.context[module_name]['memory'] = mem
self.context[module_name]['max_memory'] = mem
torch.npu.reset_max_memory_allocated()
return hook
def forward_post_hook(self, module_name):
def hook(module, *args, **kargs):
if self.should_profiling(collect_step_time=False):
torch.npu.synchronize()
self.context[module_name]['time'] = (time.time() - self.context[module_name]['time']) * 1000
mem, max_mem = self.get_memory_status()
mem1, mem2 = self.context[module_name]['memory'], self.context[module_name]['max_memory']
self.context[module_name]['memory'] = (mem - mem1) / self.unit_gb
self.context[module_name]['max_memory'] = (max_mem - mem2) / self.unit_gb
return hook
def register_recursive_hook(self, prefix_name, model, ctx):
model = model[0] if isinstance(model, list) else model
for name, module in model.named_children():
next_name = prefix_name + "." + name if prefix_name != "" else name
self.logger.info(f"hook next_name: {next_name}")
match_ret = re.search(r'[^.]+$', next_name)
if match_ret and match_ret.group(0) in self.profile_modules:
self.handles.append(module.register_forward_pre_hook(self.forward_pre_hook(name)))
self.handles.append(module.register_forward_hook(self.forward_post_hook(name)))
continue
self.register_recursive_hook(next_name, module, ctx)