import os
import time
import sys
import logging
import torch
from mindspeed_mm.fsdp.params.tools_args import MemoryProfiler
from mindspeed_mm.fsdp.utils.device import (IS_NPU_AVAILABLE, get_torch_device,
get_max_memory_reserved, get_max_memory_allocated, reset_peak_memory_stats)
from mindspeed_mm.fsdp.distributed.parallel_state import get_parallel_state
if IS_NPU_AVAILABLE:
import torch_npu
logger = logging.getLogger(__name__)
def _record(stacks: str = "all", max_entries: int = sys.maxsize):
get_torch_device().memory._record_memory_history(stacks=stacks, max_entries=max_entries)
def _dump(dump_path: str = "./memory_snapshot", dump_ranks=None):
rank_id = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
if dump_ranks is None or rank_id in dump_ranks:
str_time = time.strftime("%Y-%m-%d-%H-%M")
file_path = os.path.join(dump_path, f"snapshot_{str_time}_{rank_id}.pickle")
get_torch_device().memory._dump_snapshot(file_path)
if torch.distributed.is_initialized():
torch.distributed.barrier()
if rank_id == 0:
print(f"memory snapshot dump to {dump_path}")
def _stop():
get_torch_device().memory._record_memory_history(enabled=None)
class MemoryProfiler:
def __init__(self):
self.enable = False
self.start_step = None
self.end_step = None
self.save_path = None
self.dump_ranks = None
self.stacks = None
self.max_entries = None
self.current_step = 0
self.mem_info = False
def reset(self, config: MemoryProfiler):
if config is not None:
self.enable = config.enable
self.mem_info = config.mem_info
if self.enable:
rank_id = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
self.start_step = config.start_step
self.end_step = config.end_step
self.save_path = config.save_path
self.current_step = 0
self.dump_ranks = config.dump_ranks
self.stacks = config.stacks
self.max_entries = sys.maxsize if config.max_entries is None else config.max_entries
if rank_id == 0 and not os.path.exists(self.save_path):
os.makedirs(self.save_path)
self.step()
else:
self.enable = False
def step(self):
if self.enable:
if self.start_step == self.current_step:
_record(self.stacks, self.max_entries)
if self.end_step == self.current_step:
_dump(self.save_path, self.dump_ranks)
_stop()
self.enable = False
if self.mem_info:
max_memory_reserved = get_max_memory_reserved()
max_memory_allocated = get_max_memory_allocated()
ps = get_parallel_state()
print(f"\nstep: {self.current_step} \
global_rank: {torch.distributed.get_rank()} \
tp_rank: {ps.get_tp_rank()} \
dp_rank: {ps.get_dp_rank()} \
max_memory_reserved: {max_memory_reserved} \
max_memory_allocated: {max_memory_allocated}\n",
end="",
flush=True)
reset_peak_memory_stats()
self.current_step += 1
def stop(self):
if self.enable:
if self.start_step is not None and self.current_step > self.start_step:
_dump(self.save_path, self.dump_ranks)
_stop()
self.enable = False
memory_profiler = MemoryProfiler()