from copy import deepcopy
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.core import mpu, InferenceParams
from megatron.core.distributed import DistributedDataParallel as LocalDDP
from megatron.inference.text_generation.communication import recv_from_prev_pipeline_rank_, send_to_next_pipeline_rank
from megatron.core.transformer.module import Float16Module as MegatronFloat16Module
from megatron.training import get_args
from megatron.training.checkpointing import load_checkpoint
from megatron.training.training import get_model
from megatron.training.utils import unwrap_model
from mindspeed_mm.utils.transformer_model_config import get_model_config
class ParallelWrapper:
def __init__(self, model):
model = get_model(model, wrap_with_ddp=False)
load_checkpoint(model, None, None, 'load')
self.model = unwrap_model(model, (torchDDP, LocalDDP, MegatronFloat16Module))[0].eval()
args = get_args()
self.vlm_config = deepcopy(args.mm.model)
self.text_decoder_config = get_model_config(self.vlm_config.text_decoder)
self.pipeline_size_larger_than_one = args.pipeline_model_parallel_size > 1
self.pipelining_batch_x_seqlen = getattr(self.text_decoder_config, 'inference_batch_times_seqlen_threshold', 512)
if self.vlm_config.generation_config.max_length and self.vlm_config.generation_config.max_new_tokens:
print(
"warning: Both `max_new_tokens` (= %s) and `max_length` (= %s) seem to have been set. `max_new_tokens` will take precedence. "
% (self.vlm_config.generation_config.max_new_tokens, self.vlm_config.generation_config.max_length))
self.inference_params = None
self.split_batch = getattr(self.vlm_config.generation_config, "split_batch", True)
def __call__(self, **kwargs):
"""Invocation of the forward methods. """
input_ids = kwargs.get("input_ids", None)
batch_size = input_ids.size(0)
sequence_length = input_ids.size(1)
if self.inference_params is None and hasattr(self.vlm_config.generation_config,
"kv_cache") and self.vlm_config.generation_config.kv_cache:
max_length = _get_max_length(self.vlm_config.generation_config, sequence_length)
self.inference_params = InferenceParams(batch_size, max_length)
model_forward_kwargs = kwargs
if self.pipeline_size_larger_than_one:
if not self.split_batch:
return self._with_pipelining_forward_step_without_split_batch(model_forward_kwargs)
else:
current_batch_x_seqlen = input_ids.size(0) * input_ids.size(1)
micro_batch_size = 1
if current_batch_x_seqlen >= self.pipelining_batch_x_seqlen:
micro_batch_size = max(1, self.pipelining_batch_x_seqlen // input_ids.size(1))
return self._with_pipelining_forward_step(model_forward_kwargs, micro_batch_size)
return self._no_pipelining_forward_step(model_forward_kwargs)
def __getattr__(self, item):
if item in self.__dict__:
return self.__dict__[item]
return getattr(self.model, item)
def _forward(self, **kwargs):
if self.inference_params is not None:
kwargs["inference_params"] = self.inference_params
return self.model(**kwargs)
def _forward_step_helper(self, model_forward_kwargs, recv_buffer=None):
"""Single forward step. Update the allocate memory flag so
only the first time the memory is allocated."""
tokens = model_forward_kwargs.get("input_ids", None)
batch_size = tokens.size(0)
sequence_length = tokens.size(1)
if recv_buffer is None:
recv_buffer = _allocate_recv_buffer(batch_size, sequence_length)
recv_from_prev_pipeline_rank_(recv_buffer)
self.model.set_input_tensor(recv_buffer)
output_tensor = self._forward(**model_forward_kwargs)
send_to_next_pipeline_rank(output_tensor)
return output_tensor
def _no_pipelining_forward_step(self, model_forward_kwargs, recv_buffer=None):
"""If recv_buffer is none, we will allocate one on the fly."""
output_tensor = self._forward_step_helper(model_forward_kwargs, recv_buffer=recv_buffer)
tokens = model_forward_kwargs.get("input_ids", None)
if self.inference_params is not None:
self.inference_params.sequence_len_offset += tokens.size(1)
logits = None
if mpu.is_pipeline_last_stage():
logits = output_tensor
return logits
def _with_pipelining_forward_step_without_split_batch(self, model_forward_kwargs):
input_ids = model_forward_kwargs.get("input_ids", None)
batch_size = input_ids.size(0)
sequence_length = input_ids.size(1)
logits = None
if mpu.is_pipeline_last_stage():
args = get_args()
logits = torch.empty(
(batch_size, sequence_length, args.padded_vocab_size),
dtype=torch.float32, device=torch.cuda.current_device())
recv_buffer = _allocate_recv_buffer(batch_size, sequence_length)
output = self._forward_step_helper(model_forward_kwargs, recv_buffer=recv_buffer)
if self.inference_params is not None:
self.inference_params.batch_size_offset += batch_size
if mpu.is_pipeline_last_stage():
if isinstance(output, dict):
logits = output["logits"]
if self.inference_params is not None:
self.inference_params.sequence_len_offset += sequence_length
self.inference_params.batch_size_offset = 0
return logits
def _with_pipelining_forward_step(self, model_forward_kwargs, micro_batch_size):
"""No interleaving is supported.
Ensure the first dimension of `model_forward_kwargs` is the batch size.
"""
first_dims = [v.shape[0] for k, v in model_forward_kwargs.items() if (k != "position_ids" and k != "cache_position" and v is not None)]
if "position_ids" in model_forward_kwargs.keys():
first_dims.append(model_forward_kwargs["position_ids"].shape[1])
if not len(set(first_dims)) == 1:
raise Exception(
"All values in model_forward_kwargs must have the same first dimension, which represents the batch size.")
tokens = model_forward_kwargs.get("input_ids")
sequence_length = tokens.size(1)
batch_size = tokens.size(0)
num_micro_batches, last_chunk = divmod(batch_size,
micro_batch_size)
if last_chunk > 0:
num_micro_batches += 1
logits = None
if mpu.is_pipeline_last_stage():
args = get_args()
logits = torch.empty(
(batch_size, sequence_length, args.padded_vocab_size),
dtype=torch.float32, device=torch.cuda.current_device())
recv_buffer = _allocate_recv_buffer(micro_batch_size, sequence_length)
for micro_batch_index in range(num_micro_batches):
start = micro_batch_index * micro_batch_size
end = min(start + micro_batch_size, batch_size)
this_micro_batch_size = end - start
model_forward_kwargs = {
key: (value[start:end, ...] if value is not None else None)
for key, value in model_forward_kwargs.items()
}
if this_micro_batch_size != micro_batch_size:
recv_buffer = None
output = self._forward_step_helper(model_forward_kwargs, recv_buffer=recv_buffer)
if self.inference_params is not None:
self.inference_params.batch_size_offset += this_micro_batch_size
if mpu.is_pipeline_last_stage():
if isinstance(output, dict):
logits = output["logits"][start:end, ...]
if self.inference_params is not None:
self.inference_params.sequence_len_offset += sequence_length
self.inference_params.batch_size_offset = 0
return logits
def _get_recv_buffer_dtype(args):
"""Receive happens between the layers."""
if args.fp32_residual_connection:
return torch.float
return args.params_dtype
def _allocate_recv_buffer(batch_size, sequence_length):
"""Receive happens between the layers with size [s, b, h]."""
if mpu.is_pipeline_first_stage():
return None
args = get_args()
vlm_config = deepcopy(args.mm.model)
text_decoder_config = get_model_config(vlm_config.text_decoder)
recv_size = (sequence_length, batch_size, text_decoder_config.hidden_size)
return torch.empty(recv_size,
dtype=_get_recv_buffer_dtype(args),
device=torch.cuda.current_device())
def _get_max_length(config, inputs_seq_length):
if config.max_new_tokens and not config.max_length:
max_length = inputs_seq_length + config.max_new_tokens
elif config.max_length and not config.max_new_tokens:
if config.max_length < inputs_seq_length:
raise ValueError(
f"generation config max length:{config.max_length} must larger than inputs seq length:{inputs_seq_length}")
max_length = config.max_length
elif config.max_length and config.max_new_tokens:
max_length = inputs_seq_length + config.max_new_tokens
else:
raise ValueError("You must set either `max_new_tokens` or `max_length`.")
return max_length