0e5976c7创建于 2025年4月27日历史提交
# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
import re
import torch
import torch_npu
from megatron.training import get_args


def get_layer_id(name):
    if name:
        matches = re.findall(r'\.(\d+)\.?', str(name))
        if matches:
            return matches[0]
        return -1
    return -1


class SwapTensor:
    """Manage the transmission of tensors between device and host"""
    def __init__(self, tensor, layer_name):
        self.tensor = tensor
        self.size = tensor.size()
        self.storage_size = tensor.storage().size()
        self.tensor_cpu = torch.empty(tensor.shape, dtype=tensor.dtype, pin_memory=True, device='cpu')

        self.d2h_event = None
        self.h2d_event = torch.npu.Event()

        self.stat = "device"
        self.layer_name = layer_name

        self.prefetch_data_ptr = tensor.data_ptr()
        self.storage_data_ptr = tensor.storage().data_ptr()
        self.layer_id = None
        self.first_tensor = False
        self.last_tensor = False
        self.is_slice_tensor = tensor.storage().size() != tensor.numel()
        self.stream = None
        self.layer_index = 0

    # device to host
    def launch_d2h(self, stream):
        if self.stat != "device":
            return
        forward_event = torch.npu.Event()
        forward_event.record()
        with torch.no_grad():
            with torch_npu.npu.stream(stream):
                stream.wait_event(forward_event)
                if self.is_slice_tensor:
                    self.tensor_cpu.copy_(self.tensor, non_blocking=True)
                else:
                    self.tensor_cpu.storage().copy_(self.tensor.storage(), non_blocking=True)
                self.stat = "d2h"

    # synchronize d2h and resize 0
    def wait_d2h_finished(self, stream, need_wait=False):
        if self.stat != "d2h":
            return
        if need_wait:
            torch.npu.current_stream().wait_stream(stream)
            torch.npu.default_stream().wait_stream(stream)
        self.tensor.storage().resize_(0)
        self.stat = "host"

    # resize storage_size and host to device
    def launch_h2d(self, stream, flag):
        if self.stat != "host":
            return
        backward_event = torch.npu.Event()
        backward_event.record()
        if flag:
            self.tensor.storage().resize_(self.storage_size)
        with torch.no_grad():
            with torch_npu.npu.stream(stream):
                stream.wait_event(backward_event)
                if self.is_slice_tensor:
                    self.tensor.copy_(self.tensor_cpu, non_blocking=True)
                else:
                    self.tensor.storage().copy_(self.tensor_cpu.storage(), non_blocking=True)
                self.h2d_event.record()
                self.stat = "h2d"

    # synchronize h2d
    def wait_h2d_finished(self, stream, need_wait=False):
        if self.stat != "h2d":
            return
        if need_wait:
            torch.npu.current_stream().wait_stream(stream)
            torch.npu.default_stream().wait_stream(stream)
        self.stat = "device"


class SwapPrefetch:
    swap_prefetch = None

    def __init__(self, prefetch_args):
        swap_list, vpp, interval, num_prefetch = prefetch_args
        all_args = get_args()
        self.prefetch_stream = torch_npu.npu.Stream(device=torch.npu.current_device())
        self.pp = all_args.pipeline_model_parallel_size
        self.vpp = min(vpp, num_prefetch)
        self.first_layer_id = 0
        if isinstance(getattr(all_args, 'noop_layers', None), set):
            for layer_id in swap_list[0]:
                if layer_id != '':
                    self.first_layer_id = int(layer_id)
                    break

        self.swap_tensors = []
        self.layer_name = ""

        self.data_ptr = {}
        self.prefetch_list = []
        self.prefetch_data_ptr_list = []
        self.cur_micro_num = 0
        self.remove_num = 0
        self.forward_flag = False
        self.interval = interval
        self.slice_tensor_storage_ptr = {}
        self.slice_tensor_storage_ptr_list = []
        self.eval_end_flag = False

    @staticmethod
    def no_swap_tensor(ori_tensor):
        if ori_tensor.numel() * ori_tensor.element_size() * 2 < 1024 * 1024:
            return True
        if ori_tensor.grad_fn is None:
            return True
        if ori_tensor.storage().size() == 0:
            return True
        if ori_tensor.storage().size() != ori_tensor.numel():
            return True
        if ori_tensor._base is not None and ori_tensor._base.dim() >= 5:
            return True

        return False

    def pack_hook(self, ori_tensor):
        """hook function for forward to device to host
        Args:
            ori_tensor (Tensor): Device tensor expected to be swaped to host
        Returns:
            class: Information about swap and swaped host Tensor.
        """
        args = get_args()
        if args.eval_interval:
            if args.curr_iteration % args.eval_interval != 0:
                self.eval_end_flag = False
            if args.curr_iteration and args.curr_iteration % args.eval_interval == 0 and not self.eval_end_flag:
                self.prefetch_data_ptr_list = []
                self.prefetch_list = []
                self.slice_tensor_storage_ptr_list = []
                self.eval_end_flag = True

        if self.no_swap_tensor(ori_tensor):
            return ori_tensor
        swap_tensor = SwapTensor(ori_tensor, self.layer_name)
        if not self.swap_tensors:
            swap_tensor.first_tensor = True
        # Records the slice tensor status.
        if ori_tensor.storage().size() != ori_tensor.numel():
            swap_tensor.is_slice_tensor = True
            if ori_tensor.storage().data_ptr() not in self.slice_tensor_storage_ptr:
                if self.swap_tensors and self.swap_tensors[0].layer_id != 0:
                    self.slice_tensor_storage_ptr[ori_tensor.storage().data_ptr()] = \
                        [f'{len(self.prefetch_list) - 1}_{len(self.swap_tensors)}']
                else:
                    self.slice_tensor_storage_ptr[ori_tensor.storage().data_ptr()] = \
                        [f'{len(self.prefetch_list)}_{len(self.swap_tensors)}']
            else:
                if self.swap_tensors and self.swap_tensors[0].layer_id != 0:
                    self.slice_tensor_storage_ptr[ori_tensor.storage().data_ptr()].append(
                        f'{len(self.prefetch_list) - 1}_{len(self.swap_tensors)}')
                else:
                    self.slice_tensor_storage_ptr[ori_tensor.storage().data_ptr()].append(
                        f'{len(self.prefetch_list)}_{len(self.swap_tensors)}')

        # Records the same data_ptr tensor status.
        if ori_tensor.storage().data_ptr() in self.data_ptr:
            self.swap_tensors[self.data_ptr[ori_tensor.storage().data_ptr()]].stat = 'h2d'
            swap_tensor.stat = 'd2h'
            swap_tensor.tensor_cpu = self.swap_tensors[self.data_ptr[ori_tensor.storage().data_ptr()]].tensor_cpu
            self.data_ptr[ori_tensor.storage().data_ptr()] = len(self.swap_tensors)
        else:
            self.data_ptr[ori_tensor.storage().data_ptr()] = len(self.swap_tensors)

        swap_tensor.launch_d2h(self.prefetch_stream)
        swap_tensor.stream = self.prefetch_stream
        swap_tensor.layer_id = int(get_layer_id(swap_tensor.layer_name))
        self.swap_tensors.append(swap_tensor)
        self.forward_flag = True
        return swap_tensor

    def unpack_hook(self, swap_tensor):
        """hook function for backward to host to device
        Args:
            swap_tensor(class): Information about swap and swaped host Tensor.
        Returns:
            Tensor: Device tensor swaped from host
        """
        if isinstance(swap_tensor, torch.Tensor):
            return swap_tensor
        swap_tensor.wait_h2d_finished(self.prefetch_stream, swap_tensor.last_tensor)
        self.prefetch_list[self.cur_micro_num][swap_tensor.layer_index].remove(swap_tensor)
        # Remove prefetch completed list
        if len(self.prefetch_list[self.cur_micro_num][swap_tensor.layer_index]) == 0:
            self.prefetch_list[self.cur_micro_num].remove(
                self.prefetch_list[self.cur_micro_num][swap_tensor.layer_index])
            self.prefetch_data_ptr_list[self.cur_micro_num].remove(
                self.prefetch_data_ptr_list[self.cur_micro_num][swap_tensor.layer_index])
            self.slice_tensor_storage_ptr_list[self.cur_micro_num].remove(
                self.slice_tensor_storage_ptr_list[self.cur_micro_num][swap_tensor.layer_index])
            if len(self.prefetch_list[self.cur_micro_num]) == 0:
                self.prefetch_list.remove(self.prefetch_list[self.cur_micro_num])
                self.prefetch_data_ptr_list.remove(self.prefetch_data_ptr_list[self.cur_micro_num])
                self.slice_tensor_storage_ptr_list.remove(self.slice_tensor_storage_ptr_list[self.cur_micro_num])
                self.remove_num += 1
                if self.remove_num // self.pp == self.vpp:
                    self.remove_num = 0
        self.forward_flag = False
        return swap_tensor.tensor

    def hook_swap_manager_forward(self, forward_func, layer_name):
        """
        Wrap a neural network layer's forward pass with tensor saving hooks
        Args:
            forward_func: Original forward function of the layer
            layer_name (str): Identifier for the target layer, used to track layer context
            in subsequent hook operations     
        Returns:
            Callable[..., Any]: Wrapped forward function with tensor saving logic
        """
        def custom_forward(*args, **kargs):
            self.layer_name = layer_name
            with torch.autograd.graph.saved_tensors_hooks(self.pack_hook, self.unpack_hook):
                return forward_func(*args, **kargs)

        return custom_forward

    def update_slice_tensor_stat(self, swap_tensor):
        if swap_tensor.is_slice_tensor and swap_tensor.storage_data_ptr in self.slice_tensor_storage_ptr:
            _, index = self.slice_tensor_storage_ptr[swap_tensor.storage_data_ptr][0].split('_')
            if swap_tensor != self.swap_tensors[int(index)]:
                swap_tensor.stat = 'host'
                return False
        return True

    def sync_d2h(self, module_name):
        """
        Synchronize d2h tensor transfers and manage prefetch pipeline
        Args:
            module_name (str): Identifier for the neural network module/layer being processed
        """
        if not self.swap_tensors:
            return
        if self.swap_tensors[0].layer_id <= self.first_layer_id:
            self.first_layer_id = self.swap_tensors[0].layer_id
        elif self.prefetch_list and self.swap_tensors[0].layer_id <= self.prefetch_list[-1][-1][-1].layer_id:
            self.first_layer_id = self.swap_tensors[0].layer_id
        first_resize_tensor = False
        for swap_tensor in self.swap_tensors:
            if self.swap_tensors[0].layer_id > self.first_layer_id and self.prefetch_list:
                swap_tensor.layer_index = len(self.prefetch_list[-1])
            if swap_tensor.layer_id == int(get_layer_id(module_name)) \
                    and swap_tensor.stat == "d2h":
                if not self.update_slice_tensor_stat(swap_tensor):
                    continue
                if not first_resize_tensor:
                    swap_tensor.first_tensor = True
                    first_resize_tensor = True
                # During synchronization, let the first tensor wait for d2h
                swap_tensor.wait_d2h_finished(swap_tensor.stream, swap_tensor.first_tensor)
        self.swap_tensors[-1].last_tensor = True
        if self.swap_tensors[-1].stat == 'host':
            if self.swap_tensors[0].layer_id > self.first_layer_id and self.prefetch_list:
                self.prefetch_list[-1].append(self.swap_tensors)
                self.prefetch_data_ptr_list[-1].append(self.data_ptr)
                self.slice_tensor_storage_ptr_list[-1].append(self.slice_tensor_storage_ptr)
            else:
                self.prefetch_list.append([self.swap_tensors])
                self.prefetch_data_ptr_list.append([self.data_ptr])
                self.slice_tensor_storage_ptr_list.append([self.slice_tensor_storage_ptr])
            self.swap_tensors = []
            self.data_ptr = {}
            self.slice_tensor_storage_ptr = {}
        if self.vpp == 1:
            self.cur_micro_num = 0
        else:
            if not self.remove_num and len(self.prefetch_list) > self.pp:
                self.cur_micro_num = self.pp * (self.vpp - 1)
            elif self.remove_num and self.remove_num % self.pp == 0:
                self.cur_micro_num = self.pp * (self.vpp - 1 - self.remove_num // self.pp)

    def h2d_special_tensor(self, swap_tensor):
        """
        Handle h2d transfer for sliced tensor storage with conflict resolution
        Args:
            swap_tensor (SwapTensor): Tensor wrapper requiring H2D transfer
        """
        if swap_tensor.is_slice_tensor:
            if swap_tensor.storage_data_ptr in self.slice_tensor_storage_ptr_list[self.cur_micro_num][swap_tensor.layer_index]:
                _, index = self.slice_tensor_storage_ptr_list[self.cur_micro_num][swap_tensor.layer_index][swap_tensor.storage_data_ptr][
                    0].split('_')
                if swap_tensor == self.prefetch_list[self.cur_micro_num][swap_tensor.layer_index][int(index)]:
                    swap_tensor.launch_h2d(self.prefetch_stream, True)
                    del self.slice_tensor_storage_ptr_list[self.cur_micro_num][swap_tensor.layer_index][swap_tensor.storage_data_ptr]
            else:
                swap_tensor.launch_h2d(self.prefetch_stream, False)
        else:
            swap_tensor.launch_h2d(self.prefetch_stream, True)

    def h2d(self, module_name):
        if not self.prefetch_list:
            return
        if self.vpp != 1 and not self.forward_flag:
            self.cur_micro_num = self.pp * (self.vpp - 1 - self.remove_num // self.pp)
        for swap_tensor_list in self.prefetch_list[self.cur_micro_num]:
            for swap_tensor in reversed(swap_tensor_list):
                if swap_tensor.layer_id + self.interval == int(get_layer_id(module_name)) \
                        and swap_tensor.stat == "host" \
                        and swap_tensor.storage_data_ptr in self.prefetch_data_ptr_list[self.cur_micro_num][swap_tensor.layer_index]:
                    del self.prefetch_data_ptr_list[self.cur_micro_num][swap_tensor.layer_index][swap_tensor.storage_data_ptr]
                    # For slice tensors, only the first tensor is resized. Other h2d the tensor size
                    self.h2d_special_tensor(swap_tensor)


def get_swap_prefetch(prefetch_args):
    if SwapPrefetch.swap_prefetch is None:
        SwapPrefetch.swap_prefetch = SwapPrefetch(prefetch_args)

    return SwapPrefetch.swap_prefetch


def pre_forward_hook_func(module_name, prefetch_args):
    def custom_func(module, *args, **kargs):
        get_swap_prefetch(prefetch_args).sync_d2h(module_name)

    return custom_func


def post_backward_hook_func(module_name, prefetch_args):
    def custom_func(module, *args, **kargs):
        get_swap_prefetch(prefetch_args).h2d(module_name)

    return custom_func


# manage activation tensor
def prefetch_tensor(module, name, prefetch_args):
    get_swap_prefetch(prefetch_args).hook_swap_manager_forward(module.forward, name)


# register prefetch before backward, prefetch h2d
def prefetch_register_post_backward_hook(module, name, prefetch_args):
    module.register_backward_hook(post_backward_hook_func(name, prefetch_args))


# register prefetch after forward, sync d2h
def prefetch_register_pre_forward_hook(module, name, prefetch_args):
    module.register_forward_hook(pre_forward_hook_func(name, prefetch_args))