"""Forward step utilities."""
from collections.abc import Iterable
from functools import wraps
import torch
from megatron.training import get_args
from megatron.core import mpu, ModelParallelConfig, InferenceParams
from megatron.core.pipeline_parallel import p2p_communication
from megatron.core.tensor_parallel.mappings import gather_from_tensor_model_parallel_region
from megatron.inference.text_generation.forward_step import _get_recv_buffer_dtype
def inference_forward_step_init_wrapper(fn):
@wraps(fn)
def wrapper(self, *args, **kwargs):
fn(self, *args, **kwargs)
args = get_args()
if not args.use_kv_cache:
self.inference_context = None
return wrapper
def _forward_step_helper(self, tokens, position_ids, attention_mask, recv_buffer=None):
"""Single forward step. Update the allocate memory flag so
only the first time the memory is allocated."""
config = ModelParallelConfig
batch_size = tokens.size(0)
sequence_length = tokens.size(1)
if recv_buffer is None:
recv_buffer = _allocate_recv_buffer(batch_size, sequence_length)
recv_buffer = recv_buffer if mpu.is_pipeline_first_stage() else recv_buffer.shape
input_tensor = p2p_communication.recv_forward(recv_buffer, config)
self.model.set_input_tensor(input_tensor)
output_tensor = self._forward(tokens, position_ids, attention_mask)
p2p_communication.send_forward(output_tensor, config)
return output_tensor
def _no_pipelining_forward_step_wrapper(_no_pipelining_forward_step):
@wraps(_no_pipelining_forward_step)
def wrapper(self, tokens, position_ids, attention_mask, recv_buffer=None):
"""If recv_buffer is none, we will allocate one on the fly."""
args = get_args()
output_tensor = self._forward_step_helper(tokens, position_ids,
attention_mask,
recv_buffer=recv_buffer)
if self.inference_context:
self.inference_context.sequence_len_offset += tokens.size(1)
logits = None
if mpu.is_pipeline_last_stage():
if args.sequence_parallel:
logits = gather_from_tensor_model_parallel_region(output_tensor)
else:
logits = output_tensor
return logits
return wrapper
def _with_pipelining_forward_step_wrapper(_with_pipelining_forward_step):
@wraps(_with_pipelining_forward_step)
def wrapper(self, tokens, position_ids, attention_mask, micro_batch_size):
"""No interleaving is supported."""
args = get_args()
sequence_length = tokens.size(1)
batch_size = tokens.size(0)
num_micro_batches, last_chunk = divmod(batch_size,
micro_batch_size)
if last_chunk > 0:
num_micro_batches += 1
logits = None
if mpu.is_pipeline_last_stage():
args = get_args()
if getattr(args, "task", False) and args.task[0] == 'needlebench':
logits = torch.empty(
(batch_size, 100, args.padded_vocab_size),
dtype=torch.float32, device=torch.cuda.current_device())
else:
logits = torch.empty(
(batch_size, sequence_length, args.padded_vocab_size),
dtype=torch.float32, device=torch.cuda.current_device())
recv_buffer = _allocate_recv_buffer(micro_batch_size, sequence_length)
for micro_batch_index in range(num_micro_batches):
start = micro_batch_index * micro_batch_size
end = min(start + micro_batch_size, batch_size)
this_micro_batch_size = end - start
tokens2use = tokens[start:end, ...]
position_ids2use = position_ids[start:end, ...]
if this_micro_batch_size != micro_batch_size:
recv_buffer = None
output = self._forward_step_helper(tokens2use, position_ids2use,
attention_mask,
recv_buffer=recv_buffer)
if self.inference_context:
self.inference_context.batch_size_offset += this_micro_batch_size
if mpu.is_pipeline_last_stage():
if args.sequence_parallel:
output = gather_from_tensor_model_parallel_region(output)
logits[start:end, ...] = output
if self.inference_context:
self.inference_context.sequence_len_offset += sequence_length
self.inference_context.batch_size_offset = 0
return logits
return wrapper
def _allocate_recv_buffer(batch_size, sequence_length):
"""Receive happens between the layers with size [s, b, h]."""
if mpu.is_pipeline_first_stage():
return None
args = get_args()
if args.sequence_parallel:
sequence_length = sequence_length // mpu.get_tensor_model_parallel_world_size()
recv_size = (sequence_length, batch_size, args.hidden_size)
return torch.empty(recv_size,
dtype=_get_recv_buffer_dtype(args),
device=torch.cuda.current_device())