516cc34c创建于 2024年7月24日历史提交
import torch


class PipelineCheckpointManager:
    instance = None

    def __init__(self, num_of_chunks=2):
        self.open_ri_pipe = False
        self.do_pre_recompute = False
        self.checkpoint_list = []
        self.chunk_list = [[] for i in range(num_of_chunks)]
        self.chunk_do_recompute = True

    def batch_fin(self, chunk_idx):
        self.chunk_list[chunk_idx].append(self.checkpoint_list)
        self.checkpoint_list = []

    def iter_fin(self):
        if len(self.checkpoint_list) != 0:
            raise RuntimeError("recompute list is not empty")

        for batch_list_for_chunk in self.chunk_list:
            for layer_list_for_batch in batch_list_for_chunk:
                if len(layer_list_for_batch) != 0:
                    raise RuntimeError(
                        f"{torch.distributed.get_rank()} batch recompute list is not empty, got {layer_list_for_batch[0]} in index 0")
            batch_list_for_chunk.clear()
        self.open_ri_pipe = False
        self.do_pre_recompute = False
        self.chunk_do_recompute = True

    def add_recompute(self, recompute_func):
        self.checkpoint_list.append(recompute_func)

    def recompute_next_force(self, chunk_idx):
        layer_list_for_head_batch = []
        while len(layer_list_for_head_batch) == 0:
            if len(self.chunk_list[chunk_idx]) == 0:
                print(f"rank-{torch.distributed.get_rank()} warning, poping empty list")
                break
            layer_list_for_head_batch = self.chunk_list[chunk_idx].pop(0)

        for layer_func in layer_list_for_head_batch:
            layer_func()

    def recompute_next(self, chunk_idx):
        layer_list_for_head_batch = self.chunk_list[chunk_idx].pop(0)
        if self.do_pre_recompute:
            for layer_func in layer_list_for_head_batch:
                layer_func()

    def disable_recompute(self):
        self.chunk_do_recompute = False

    def enable_recompute(self):
        self.chunk_do_recompute = True


def get_pipeline_checkpoint_manager(num_of_chunks=2):
    if PipelineCheckpointManager.instance is None:
        PipelineCheckpointManager.instance = PipelineCheckpointManager(num_of_chunks=num_of_chunks)
    return PipelineCheckpointManager.instance