import re
import torch
import torch_npu
from megatron.training import get_args
def get_layer_id(name):
if name:
matches = re.findall(r'\.(\d+)\.?', str(name))
if matches:
return matches[0]
return -1
return -1
class SwapTensor:
"""Manage the transmission of tensors between device and host"""
def __init__(self, tensor, layer_name):
self.tensor = tensor
self.size = tensor.size()
self.storage_size = tensor.storage().size()
self.tensor_cpu = torch.empty(tensor.shape, dtype=tensor.dtype, pin_memory=True, device='cpu')
self.d2h_event = None
self.h2d_event = torch.npu.Event()
self.stat = "device"
self.layer_name = layer_name
self.prefetch_data_ptr = tensor.data_ptr()
self.storage_data_ptr = tensor.storage().data_ptr()
self.layer_id = None
self.first_tensor = False
self.last_tensor = False
self.is_slice_tensor = tensor.storage().size() != tensor.numel()
self.stream = None
self.layer_index = 0
def launch_d2h(self, stream):
if self.stat != "device":
return
forward_event = torch.npu.Event()
forward_event.record()
with torch.no_grad():
with torch_npu.npu.stream(stream):
stream.wait_event(forward_event)
if self.is_slice_tensor:
self.tensor_cpu.copy_(self.tensor, non_blocking=True)
else:
self.tensor_cpu.storage().copy_(self.tensor.storage(), non_blocking=True)
self.stat = "d2h"
def wait_d2h_finished(self, stream, need_wait=False):
if self.stat != "d2h":
return
if need_wait:
torch.npu.current_stream().wait_stream(stream)
torch.npu.default_stream().wait_stream(stream)
self.tensor.storage().resize_(0)
self.stat = "host"
def launch_h2d(self, stream, flag):
if self.stat != "host":
return
backward_event = torch.npu.Event()
backward_event.record()
if flag:
self.tensor.storage().resize_(self.storage_size)
with torch.no_grad():
with torch_npu.npu.stream(stream):
stream.wait_event(backward_event)
if self.is_slice_tensor:
self.tensor.copy_(self.tensor_cpu, non_blocking=True)
else:
self.tensor.storage().copy_(self.tensor_cpu.storage(), non_blocking=True)
self.h2d_event.record()
self.stat = "h2d"
def wait_h2d_finished(self, stream, need_wait=False):
if self.stat != "h2d":
return
if need_wait:
torch.npu.current_stream().wait_stream(stream)
torch.npu.default_stream().wait_stream(stream)
self.stat = "device"
class SwapPrefetch:
swap_prefetch = None
def __init__(self, prefetch_args):
swap_list, vpp, interval, num_prefetch = prefetch_args
all_args = get_args()
self.prefetch_stream = torch_npu.npu.Stream(device=torch.npu.current_device())
self.pp = all_args.pipeline_model_parallel_size
self.vpp = min(vpp, num_prefetch)
self.first_layer_id = 0
if isinstance(getattr(all_args, 'noop_layers', None), set):
for layer_id in swap_list[0]:
if layer_id != '':
self.first_layer_id = int(layer_id)
break
self.swap_tensors = []
self.layer_name = ""
self.data_ptr = {}
self.prefetch_list = []
self.prefetch_data_ptr_list = []
self.cur_micro_num = 0
self.remove_num = 0
self.forward_flag = False
self.interval = interval
self.slice_tensor_storage_ptr = {}
self.slice_tensor_storage_ptr_list = []
self.eval_end_flag = False
@staticmethod
def no_swap_tensor(ori_tensor):
if ori_tensor.numel() * ori_tensor.element_size() * 2 < 1024 * 1024:
return True
if ori_tensor.grad_fn is None:
return True
if ori_tensor.storage().size() == 0:
return True
if ori_tensor.storage().size() != ori_tensor.numel():
return True
if ori_tensor._base is not None and ori_tensor._base.dim() >= 5:
return True
return False
def pack_hook(self, ori_tensor):
"""hook function for forward to device to host
Args:
ori_tensor (Tensor): Device tensor expected to be swaped to host
Returns:
class: Information about swap and swaped host Tensor.
"""
args = get_args()
if args.eval_interval:
if args.curr_iteration % args.eval_interval != 0:
self.eval_end_flag = False
if args.curr_iteration and args.curr_iteration % args.eval_interval == 0 and not self.eval_end_flag:
self.prefetch_data_ptr_list = []
self.prefetch_list = []
self.slice_tensor_storage_ptr_list = []
self.eval_end_flag = True
if self.no_swap_tensor(ori_tensor):
return ori_tensor
swap_tensor = SwapTensor(ori_tensor, self.layer_name)
if not self.swap_tensors:
swap_tensor.first_tensor = True
if ori_tensor.storage().size() != ori_tensor.numel():
swap_tensor.is_slice_tensor = True
if ori_tensor.storage().data_ptr() not in self.slice_tensor_storage_ptr:
if self.swap_tensors and self.swap_tensors[0].layer_id != 0:
self.slice_tensor_storage_ptr[ori_tensor.storage().data_ptr()] = \
[f'{len(self.prefetch_list) - 1}_{len(self.swap_tensors)}']
else:
self.slice_tensor_storage_ptr[ori_tensor.storage().data_ptr()] = \
[f'{len(self.prefetch_list)}_{len(self.swap_tensors)}']
else:
if self.swap_tensors and self.swap_tensors[0].layer_id != 0:
self.slice_tensor_storage_ptr[ori_tensor.storage().data_ptr()].append(
f'{len(self.prefetch_list) - 1}_{len(self.swap_tensors)}')
else:
self.slice_tensor_storage_ptr[ori_tensor.storage().data_ptr()].append(
f'{len(self.prefetch_list)}_{len(self.swap_tensors)}')
if ori_tensor.storage().data_ptr() in self.data_ptr:
self.swap_tensors[self.data_ptr[ori_tensor.storage().data_ptr()]].stat = 'h2d'
swap_tensor.stat = 'd2h'
swap_tensor.tensor_cpu = self.swap_tensors[self.data_ptr[ori_tensor.storage().data_ptr()]].tensor_cpu
self.data_ptr[ori_tensor.storage().data_ptr()] = len(self.swap_tensors)
else:
self.data_ptr[ori_tensor.storage().data_ptr()] = len(self.swap_tensors)
swap_tensor.launch_d2h(self.prefetch_stream)
swap_tensor.stream = self.prefetch_stream
swap_tensor.layer_id = int(get_layer_id(swap_tensor.layer_name))
self.swap_tensors.append(swap_tensor)
self.forward_flag = True
return swap_tensor
def unpack_hook(self, swap_tensor):
"""hook function for backward to host to device
Args:
swap_tensor(class): Information about swap and swaped host Tensor.
Returns:
Tensor: Device tensor swaped from host
"""
if isinstance(swap_tensor, torch.Tensor):
return swap_tensor
swap_tensor.wait_h2d_finished(self.prefetch_stream, swap_tensor.last_tensor)
self.prefetch_list[self.cur_micro_num][swap_tensor.layer_index].remove(swap_tensor)
if len(self.prefetch_list[self.cur_micro_num][swap_tensor.layer_index]) == 0:
self.prefetch_list[self.cur_micro_num].remove(
self.prefetch_list[self.cur_micro_num][swap_tensor.layer_index])
self.prefetch_data_ptr_list[self.cur_micro_num].remove(
self.prefetch_data_ptr_list[self.cur_micro_num][swap_tensor.layer_index])
self.slice_tensor_storage_ptr_list[self.cur_micro_num].remove(
self.slice_tensor_storage_ptr_list[self.cur_micro_num][swap_tensor.layer_index])
if len(self.prefetch_list[self.cur_micro_num]) == 0:
self.prefetch_list.remove(self.prefetch_list[self.cur_micro_num])
self.prefetch_data_ptr_list.remove(self.prefetch_data_ptr_list[self.cur_micro_num])
self.slice_tensor_storage_ptr_list.remove(self.slice_tensor_storage_ptr_list[self.cur_micro_num])
self.remove_num += 1
if self.remove_num // self.pp == self.vpp:
self.remove_num = 0
self.forward_flag = False
return swap_tensor.tensor
def hook_swap_manager_forward(self, forward_func, layer_name):
"""
Wrap a neural network layer's forward pass with tensor saving hooks
Args:
forward_func: Original forward function of the layer
layer_name (str): Identifier for the target layer, used to track layer context
in subsequent hook operations
Returns:
Callable[..., Any]: Wrapped forward function with tensor saving logic
"""
def custom_forward(*args, **kargs):
self.layer_name = layer_name
with torch.autograd.graph.saved_tensors_hooks(self.pack_hook, self.unpack_hook):
return forward_func(*args, **kargs)
return custom_forward
def update_slice_tensor_stat(self, swap_tensor):
if swap_tensor.is_slice_tensor and swap_tensor.storage_data_ptr in self.slice_tensor_storage_ptr:
_, index = self.slice_tensor_storage_ptr[swap_tensor.storage_data_ptr][0].split('_')
if swap_tensor != self.swap_tensors[int(index)]:
swap_tensor.stat = 'host'
return False
return True
def sync_d2h(self, module_name):
"""
Synchronize d2h tensor transfers and manage prefetch pipeline
Args:
module_name (str): Identifier for the neural network module/layer being processed
"""
if not self.swap_tensors:
return
if self.swap_tensors[0].layer_id <= self.first_layer_id:
self.first_layer_id = self.swap_tensors[0].layer_id
elif self.prefetch_list and self.swap_tensors[0].layer_id <= self.prefetch_list[-1][-1][-1].layer_id:
self.first_layer_id = self.swap_tensors[0].layer_id
first_resize_tensor = False
for swap_tensor in self.swap_tensors:
if self.swap_tensors[0].layer_id > self.first_layer_id and self.prefetch_list:
swap_tensor.layer_index = len(self.prefetch_list[-1])
if swap_tensor.layer_id == int(get_layer_id(module_name)) \
and swap_tensor.stat == "d2h":
if not self.update_slice_tensor_stat(swap_tensor):
continue
if not first_resize_tensor:
swap_tensor.first_tensor = True
first_resize_tensor = True
swap_tensor.wait_d2h_finished(swap_tensor.stream, swap_tensor.first_tensor)
self.swap_tensors[-1].last_tensor = True
if self.swap_tensors[-1].stat == 'host':
if self.swap_tensors[0].layer_id > self.first_layer_id and self.prefetch_list:
self.prefetch_list[-1].append(self.swap_tensors)
self.prefetch_data_ptr_list[-1].append(self.data_ptr)
self.slice_tensor_storage_ptr_list[-1].append(self.slice_tensor_storage_ptr)
else:
self.prefetch_list.append([self.swap_tensors])
self.prefetch_data_ptr_list.append([self.data_ptr])
self.slice_tensor_storage_ptr_list.append([self.slice_tensor_storage_ptr])
self.swap_tensors = []
self.data_ptr = {}
self.slice_tensor_storage_ptr = {}
if self.vpp == 1:
self.cur_micro_num = 0
else:
if not self.remove_num and len(self.prefetch_list) > self.pp:
self.cur_micro_num = self.pp * (self.vpp - 1)
elif self.remove_num and self.remove_num % self.pp == 0:
self.cur_micro_num = self.pp * (self.vpp - 1 - self.remove_num // self.pp)
def h2d_special_tensor(self, swap_tensor):
"""
Handle h2d transfer for sliced tensor storage with conflict resolution
Args:
swap_tensor (SwapTensor): Tensor wrapper requiring H2D transfer
"""
if swap_tensor.is_slice_tensor:
if swap_tensor.storage_data_ptr in self.slice_tensor_storage_ptr_list[self.cur_micro_num][swap_tensor.layer_index]:
_, index = self.slice_tensor_storage_ptr_list[self.cur_micro_num][swap_tensor.layer_index][swap_tensor.storage_data_ptr][
0].split('_')
if swap_tensor == self.prefetch_list[self.cur_micro_num][swap_tensor.layer_index][int(index)]:
swap_tensor.launch_h2d(self.prefetch_stream, True)
del self.slice_tensor_storage_ptr_list[self.cur_micro_num][swap_tensor.layer_index][swap_tensor.storage_data_ptr]
else:
swap_tensor.launch_h2d(self.prefetch_stream, False)
else:
swap_tensor.launch_h2d(self.prefetch_stream, True)
def h2d(self, module_name):
if not self.prefetch_list:
return
if self.vpp != 1 and not self.forward_flag:
self.cur_micro_num = self.pp * (self.vpp - 1 - self.remove_num // self.pp)
for swap_tensor_list in self.prefetch_list[self.cur_micro_num]:
for swap_tensor in reversed(swap_tensor_list):
if swap_tensor.layer_id + self.interval == int(get_layer_id(module_name)) \
and swap_tensor.stat == "host" \
and swap_tensor.storage_data_ptr in self.prefetch_data_ptr_list[self.cur_micro_num][swap_tensor.layer_index]:
del self.prefetch_data_ptr_list[self.cur_micro_num][swap_tensor.layer_index][swap_tensor.storage_data_ptr]
self.h2d_special_tensor(swap_tensor)
def get_swap_prefetch(prefetch_args):
if SwapPrefetch.swap_prefetch is None:
SwapPrefetch.swap_prefetch = SwapPrefetch(prefetch_args)
return SwapPrefetch.swap_prefetch
def pre_forward_hook_func(module_name, prefetch_args):
def custom_func(module, *args, **kargs):
get_swap_prefetch(prefetch_args).sync_d2h(module_name)
return custom_func
def post_backward_hook_func(module_name, prefetch_args):
def custom_func(module, *args, **kargs):
get_swap_prefetch(prefetch_args).h2d(module_name)
return custom_func
def prefetch_tensor(module, name, prefetch_args):
get_swap_prefetch(prefetch_args).hook_swap_manager_forward(module.forward, name)
def prefetch_register_post_backward_hook(module, name, prefetch_args):
module.register_backward_hook(post_backward_hook_func(name, prefetch_args))
def prefetch_register_pre_forward_hook(module, name, prefetch_args):
module.register_forward_hook(pre_forward_hook_func(name, prefetch_args))