from contextlib import nullcontext
from typing import Optional, Union
import torch
import torch.distributed
from torch import Tensor
from megatron.core import parallel_state, tensor_parallel
from megatron.core.enums import Fp8Recipe
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_block import TransformerBlock, TransformerBlockSubmodules
from megatron.core.utils import WrappedTensor, deprecate_inference_params, make_viewless_tensor
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.inference.contexts import BaseInferenceContext
from megatron.core.fp8_utils import get_fp8_context
from megatron.core.extensions.transformer_engine import te_checkpoint
class Qwen3vlTransformerBlock(TransformerBlock):
    """Qwen3-VL transformer block with DeepStack visual-embedding injection.

    Same as Megatron's ``TransformerBlock`` except that ``forward`` accepts
    ``deepstack_visual_embeds``: a list of tensors where entry ``i`` is added to
    the hidden states immediately after the ``i``-th layer of this block (local
    index, i.e. position in ``self.layers``) has run. The injection is applied on
    the plain forward path and on both activation-recompute methods
    (``'uniform'`` and ``'block'``).
    """

    def __init__(
        self,
        config: TransformerConfig,
        spec: Union[TransformerBlockSubmodules, ModuleSpec],
        post_layer_norm: bool = True,
        pre_process: bool = True,
        post_process: bool = True,
    ):
        """Build the block; every argument is forwarded unchanged to ``TransformerBlock``."""
        super().__init__(
            config=config,
            spec=spec,
            post_layer_norm=post_layer_norm,
            pre_process=pre_process,
            post_process=post_process,
        )

    @staticmethod
    def _add_deepstack(
        hidden_states: Tensor,
        layer_idx: int,
        deepstack_visual_embeds: Optional[Union[list, tuple]],
    ) -> Tensor:
        """Return ``hidden_states`` plus the DeepStack embedding for ``layer_idx``, if any.

        Args:
            hidden_states: output of the layer identified by ``layer_idx``.
            layer_idx: index into ``deepstack_visual_embeds``.
            deepstack_visual_embeds: per-layer embeddings, or ``None``.

        Returns:
            ``hidden_states`` itself when there is no embedding for ``layer_idx``;
            otherwise the sum, cast back to ``hidden_states.dtype``.
        """
        if deepstack_visual_embeds is None or not (0 <= layer_idx < len(deepstack_visual_embeds)):
            return hidden_states
        # Out-of-place add so the layer / checkpoint output is never mutated in place.
        # Casting back keeps the hidden-state dtype stable even if the embedding has a
        # wider dtype (the same result dtype an in-place ``+=`` would give).
        return (hidden_states + deepstack_visual_embeds[layer_idx]).to(hidden_states.dtype)

    def _deepstack_process(
        self, hidden_states: torch.Tensor, visual_pos_masks: torch.Tensor, visual_embeds: torch.Tensor
    ) -> torch.Tensor:
        """Add ``visual_embeds`` to the rows of ``hidden_states`` selected by ``visual_pos_masks``.

        ``hidden_states`` is modified in place and also returned. The mask is moved to
        ``hidden_states``'s device and ``visual_embeds`` to its device and dtype first.
        ``visual_embeds`` must be broadcast-compatible with
        ``hidden_states[visual_pos_masks, :]`` (presumably one row per selected
        position — TODO confirm with callers).

        Note: not invoked by ``forward`` / ``_checkpointed_forward`` in this class.
        """
        visual_pos_masks = visual_pos_masks.to(hidden_states.device)
        visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype)
        # Mask (advanced) indexing already returns a copy, so no explicit clone is
        # needed before writing the sum back into ``hidden_states``.
        hidden_states[visual_pos_masks, :] = hidden_states[visual_pos_masks, :] + visual_embeds
        return hidden_states

    def forward(
        self,
        hidden_states: Union[Tensor, WrappedTensor],
        attention_mask: Optional[Tensor],
        context: Optional[Tensor] = None,
        context_mask: Optional[Tensor] = None,
        rotary_pos_emb: Optional[Tensor] = None,
        rotary_pos_cos: Optional[Tensor] = None,
        rotary_pos_sin: Optional[Tensor] = None,
        attention_bias: Optional[Tensor] = None,
        inference_context: Optional[BaseInferenceContext] = None,
        packed_seq_params: Optional[PackedSeqParams] = None,
        sequence_len_offset: Optional[Tensor] = None,
        deepstack_visual_embeds: Optional[list[torch.Tensor]] = None,
        *,
        inference_params: Optional[BaseInferenceContext] = None,
    ):
        """Run every layer of this block, injecting DeepStack visual embeddings.

        Arguments mirror ``TransformerBlock.forward``; additionally:
            deepstack_visual_embeds: optional list; entry ``i`` is added to the hidden
                states right after local layer ``i`` runs. Entries beyond the number
                of local layers are ignored.
            inference_params: deprecated alias of ``inference_context``.

        Returns:
            The final hidden states (passed through ``final_layernorm`` when it is
            not ``None``), as a viewless tensor.
        """
        inference_context = deprecate_inference_params(inference_context, inference_params)

        if isinstance(hidden_states, WrappedTensor):
            hidden_states = hidden_states.unwrap()

        # When this block does no pre-processing, its input is taken from
        # ``self.input_tensor`` instead of the ``hidden_states`` argument.
        if not self.pre_process:
            hidden_states = self.input_tensor

        if inference_context and not self.training:
            inference_context.current_batch_size = hidden_states.size(1)

        hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True)

        # With sequence parallelism the layers run inside a fork of the CUDA RNG tracker.
        rng_context = (
            tensor_parallel.get_cuda_rng_tracker().fork()
            if self.config.sequence_parallel
            else nullcontext()
        )

        # Delayed-scaling FP8 wraps the whole block in a single FP8 context; any other
        # FP8 recipe opens one context per layer, keyed by ``layer.layer_number - 1``.
        use_outer_fp8_context = self.config.fp8 and self.config.fp8_recipe == Fp8Recipe.delayed
        use_inner_fp8_context = self.config.fp8 and self.config.fp8_recipe != Fp8Recipe.delayed
        outer_fp8_context = get_fp8_context(self.config) if use_outer_fp8_context else nullcontext()

        # NOTE(review): DeepStack entries are matched by the *local* layer index on this
        # rank (``enumerate(self.layers)``), not by ``layer.layer_number``. Under pipeline
        # parallelism this is only correct if the caller slices or omits the list per
        # stage — confirm.
        with rng_context, outer_fp8_context:
            if self.config.recompute_granularity == 'full' and self.training:
                hidden_states = self._checkpointed_forward(
                    hidden_states=hidden_states,
                    attention_mask=attention_mask,
                    context=context,
                    context_mask=context_mask,
                    rotary_pos_emb=rotary_pos_emb,
                    attention_bias=attention_bias,
                    packed_seq_params=packed_seq_params,
                    deepstack_visual_embeds=deepstack_visual_embeds,
                    use_inner_fp8_context=use_inner_fp8_context,
                )
            else:
                for layer_idx, layer in enumerate(self.layers):
                    inner_fp8_context = (
                        get_fp8_context(self.config, layer.layer_number - 1)
                        if use_inner_fp8_context
                        else nullcontext()
                    )
                    with self.offload_context, inner_fp8_context:
                        hidden_states, context = layer(
                            hidden_states=hidden_states,
                            attention_mask=attention_mask,
                            context=context,
                            context_mask=context_mask,
                            rotary_pos_emb=rotary_pos_emb,
                            rotary_pos_cos=rotary_pos_cos,
                            rotary_pos_sin=rotary_pos_sin,
                            attention_bias=attention_bias,
                            inference_context=inference_context,
                            packed_seq_params=packed_seq_params,
                            sequence_len_offset=sequence_len_offset,
                        )
                    hidden_states = self._add_deepstack(
                        hidden_states, layer_idx, deepstack_visual_embeds
                    )
                    if (
                        torch.is_grad_enabled()
                        and self.config.cpu_offloading
                        and self.group_prefetch_offload_commit_async is not None
                    ):
                        hidden_states = self.group_prefetch_offload_commit_async(hidden_states)

        if self.final_layernorm is not None:
            hidden_states = self.final_layernorm(hidden_states)
            hidden_states = make_viewless_tensor(
                inp=hidden_states, requires_grad=True, keep_graph=True
            )

        return hidden_states

    def _checkpointed_forward(
        self,
        hidden_states: Tensor,
        attention_mask: Tensor,
        context: Tensor,
        context_mask: Tensor,
        rotary_pos_emb: Tensor,
        attention_bias: Tensor,
        packed_seq_params: PackedSeqParams,
        deepstack_visual_embeds: list,
        use_inner_fp8_context: bool,
    ):
        """Forward method with activation checkpointing.

        Honours ``config.recompute_method``:
          * ``'uniform'``: local layers are checkpointed in consecutive chunks of
            ``config.recompute_num_layers`` (the last chunk may be shorter).
          * ``'block'``: the first ``config.recompute_num_layers`` layers (shifted by
            the FP8 skip count) are checkpointed one at a time; the rest run normally.

        DeepStack embeddings are added right after their layer under both methods.
        They are handed to the checkpoint function as positional inputs, not captured
        by closure, so the checkpoint treats them as inputs (and can return their
        gradients) rather than as hidden closed-over state.

        Note: unlike the non-recompute path, layers here receive no
        ``rotary_pos_cos`` / ``rotary_pos_sin`` / ``sequence_len_offset`` and get
        ``inference_context=None``.

        Raises:
            ValueError: if ``config.recompute_method`` is neither ``'uniform'`` nor ``'block'``.
        """

        def custom(start: int, end: int):
            """Return a function that runs local layers ``[start, end)`` of this block."""

            def custom_forward(
                hidden_states, attention_mask, context, context_mask, rotary_pos_emb, *deepstack_chunk
            ):
                # ``deepstack_chunk[k]`` belongs to local layer ``start + k``.
                for offset, index in enumerate(range(start, end)):
                    layer = self._get_layer(index)
                    inner_fp8_context = (
                        get_fp8_context(self.config, layer.layer_number - 1)
                        if use_inner_fp8_context
                        else nullcontext()
                    )
                    with inner_fp8_context:
                        hidden_states, context = layer(
                            hidden_states=hidden_states,
                            attention_mask=attention_mask,
                            context=context,
                            context_mask=context_mask,
                            rotary_pos_emb=rotary_pos_emb,
                            attention_bias=attention_bias,
                            inference_context=None,
                            packed_seq_params=packed_seq_params,
                        )
                    hidden_states = self._add_deepstack(hidden_states, offset, deepstack_chunk)
                return hidden_states, context

            return custom_forward

        def deepstack_slice(start: int, end: int) -> tuple:
            """DeepStack embeddings for local layers ``[start, end)`` (a prefix of that range)."""
            if deepstack_visual_embeds is None:
                return ()
            return tuple(deepstack_visual_embeds[start:end])

        def checkpoint_handler(forward_func, hidden, cross_context, deepstack_chunk):
            """Checkpoint ``forward_func`` via ``te_checkpoint`` (FP8) or ``tensor_parallel.checkpoint``."""
            args = (
                hidden,
                attention_mask,
                cross_context,
                context_mask,
                rotary_pos_emb,
                *deepstack_chunk,
            )
            if self.config.fp8:
                return te_checkpoint(
                    forward_func,
                    self.config.distribute_saved_activations,
                    tensor_parallel.random.get_cuda_rng_tracker,
                    parallel_state.get_tensor_model_parallel_group(),
                    *args,
                )
            return tensor_parallel.checkpoint(
                forward_func, self.config.distribute_saved_activations, *args
            )

        num_layers = self.num_layers_per_pipeline_rank
        if self.config.recompute_method == 'uniform':
            chunk_size = self.config.recompute_num_layers
            for start in range(0, num_layers, chunk_size):
                # Clamp so a chunk size that does not divide the layer count never
                # indexes past the last local layer.
                end = min(start + chunk_size, num_layers)
                hidden_states, context = checkpoint_handler(
                    custom(start, end), hidden_states, context, deepstack_slice(start, end)
                )
        elif self.config.recompute_method == 'block':
            recompute_skip_num_layers = 0
            for layer_idx in range(num_layers):
                # Under FP8, layers whose input does not require grad are not
                # recomputed; they push the recompute window further back.
                if self.config.fp8 and not hidden_states.requires_grad:
                    recompute_skip_num_layers += 1
                forward_func = custom(layer_idx, layer_idx + 1)
                deepstack_chunk = deepstack_slice(layer_idx, layer_idx + 1)
                if (
                    recompute_skip_num_layers
                    <= layer_idx
                    < self.config.recompute_num_layers + recompute_skip_num_layers
                ):
                    hidden_states, context = checkpoint_handler(
                        forward_func, hidden_states, context, deepstack_chunk
                    )
                else:
                    hidden_states, context = forward_func(
                        hidden_states,
                        attention_mask,
                        context,
                        context_mask,
                        rotary_pos_emb,
                        *deepstack_chunk,
                    )
        else:
            raise ValueError("Invalid activation recompute method.")

        return hidden_states