import math
from typing import Optional, Tuple, Union
import torch
from megatron.training import get_args
from megatron.training.arguments import core_transformer_config_from_args
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.activations import ACT2FN
from ..common.attention import WhisperAttention
from ..common.module import MultiModalModule
def shift_tokens_right(
    input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int
):
    """
    Shift input ids one token to the right.

    Builds a new tensor of the same shape as ``input_ids`` in which column 0
    holds ``decoder_start_token_id`` and column ``i`` (``i >= 1``) holds the
    value of ``input_ids[:, i - 1]``. Any ``-100`` entry in the shifted result
    is then replaced by ``pad_token_id``. ``input_ids`` itself is not modified.

    Args:
        input_ids: 2-D tensor of token ids (indexed as ``[:, position]``).
        pad_token_id: Value substituted for ``-100`` entries after shifting.
        decoder_start_token_id: Value written into the first column.

    Returns:
        A new tensor with the same shape, dtype and device as ``input_ids``.

    Raises:
        ValueError: If ``pad_token_id`` is ``None``.
    """
    # Validate up front so that no tensor work is done for an invalid call.
    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")

    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    # Slice assignment copies into the freshly allocated tensor, so no
    # intermediate clone of the source slice is required.
    shifted_input_ids[:, 1:] = input_ids[:, :-1]
    shifted_input_ids[:, 0] = decoder_start_token_id
    # -100 entries are not valid token ids; swap them for the padding id.
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
    return shifted_input_ids
class WhisperPositionalEmbedding(nn.Embedding):
    """
    Position table stored as an ``nn.Embedding`` weight matrix.

    ``forward`` indexes the weight matrix directly instead of performing a
    regular embedding lookup on ``input_ids``.
    """

    def __init__(
        self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None
    ):
        # ``padding_idx`` is part of the signature but is not forwarded to
        # ``nn.Embedding``; the table is built without a padding row.
        super().__init__(num_positions, embedding_dim)

    def forward(self, input_ids, past_key_values_length=0, position_ids=None):
        """
        Return position vectors.

        When ``position_ids`` is given, the rows it selects are returned.
        Otherwise a contiguous slice of ``input_ids.shape[1]`` rows starting
        at ``past_key_values_length`` is returned; only the shape of
        ``input_ids`` is used in that case.
        """
        if position_ids is not None:
            return self.weight[position_ids]

        first = past_key_values_length
        last = first + input_ids.shape[1]
        return self.weight[first:last]
class WhisperEncoderLayer(nn.Module):
    """
    One encoder block: self-attention followed by a two-layer feed-forward net.

    Each sub-block applies its LayerNorm first, then the transformation, then
    dropout, and finally adds the result back onto its input (residual).
    """

    def __init__(self, config):
        super().__init__()
        width = config.d_model
        heads = config.encoder_attention_heads

        self.embed_dim = width
        self.head_dim = width // heads

        # Submodules are created in a fixed order so parameter registration
        # (and therefore state-dict layout) stays stable.
        self.self_attn = WhisperAttention(
            query_dim=width,
            key_dim=width,
            num_heads=heads,
            head_dim=self.head_dim,
            dropout=config.attention_dropout,
            proj_qv_bias=True,
            proj_out_bias=True,
        )
        self.self_attn_layer_norm = nn.LayerNorm(width)

        self.dropout = config.dropout
        self.activation_dropout = config.activation_dropout
        self.activation_fn = ACT2FN[config.activation_function]

        self.fc1 = nn.Linear(width, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, width)
        self.final_layer_norm = nn.LayerNorm(width)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Run the block on ``hidden_states``.

        Args:
            hidden_states: Input activations; last dimension is ``embed_dim``.
            attention_mask: Passed through unchanged as ``mask`` to the
                attention module.

        Returns:
            Tensor with the same shape as ``hidden_states``.
        """
        drop = nn.functional.dropout
        training = self.training

        # --- self-attention sub-block ---
        attn_out = self.self_attn(
            query=self.self_attn_layer_norm(hidden_states),
            mask=attention_mask,
        )
        hidden_states = hidden_states + drop(attn_out, p=self.dropout, training=training)

        # --- feed-forward sub-block ---
        ffn_out = self.activation_fn(self.fc1(self.final_layer_norm(hidden_states)))
        ffn_out = drop(ffn_out, p=self.activation_dropout, training=training)
        ffn_out = drop(self.fc2(ffn_out), p=self.dropout, training=training)
        hidden_states = hidden_states + ffn_out

        # In float16, clamp to just below the dtype maximum whenever an inf
        # or NaN shows up in the activations.
        if hidden_states.dtype == torch.float16:
            if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
                limit = torch.finfo(hidden_states.dtype).max - 1000
                hidden_states = torch.clamp(hidden_states, min=-limit, max=limit)

        return hidden_states
class WhisperDecoderLayer(nn.Module):
    """
    One decoder block: self-attention, optional cross-attention, feed-forward.

    Every sub-block normalises its input, transforms it, applies dropout and
    adds the result onto the sub-block input (residual connection).
    """

    def __init__(self, config):
        super().__init__()
        width = config.d_model
        heads = config.decoder_attention_heads

        self.embed_dim = width
        self.head_dim = width // heads

        # Both attention modules share the same construction arguments.
        attn_kwargs = dict(
            query_dim=width,
            key_dim=width,
            num_heads=heads,
            head_dim=self.head_dim,
            dropout=config.attention_dropout,
            proj_qv_bias=True,
            proj_out_bias=True,
        )

        self.self_attn = WhisperAttention(**attn_kwargs)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.self_attn_layer_norm = nn.LayerNorm(width)

        self.encoder_attn = WhisperAttention(**attn_kwargs)
        self.encoder_attn_layer_norm = nn.LayerNorm(width)

        self.fc1 = nn.Linear(width, config.decoder_ffn_dim)
        self.fc2 = nn.Linear(config.decoder_ffn_dim, width)
        self.final_layer_norm = nn.LayerNorm(width)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Run the block on ``hidden_states``.

        Args:
            hidden_states: Decoder activations; last dimension is ``embed_dim``.
            attention_mask: Forwarded as ``mask`` to the self-attention module.
            encoder_hidden_states: When not ``None``, used as ``key`` for the
                cross-attention sub-block (called with ``mask=None``). When
                ``None``, the cross-attention sub-block is skipped.

        Returns:
            Tensor with the same shape as ``hidden_states``.
        """
        drop = nn.functional.dropout
        training = self.training

        # --- masked self-attention ---
        attn_out = self.self_attn(
            query=self.self_attn_layer_norm(hidden_states),
            mask=attention_mask,
        )
        hidden_states = hidden_states + drop(attn_out, p=self.dropout, training=training)

        # --- cross-attention over encoder output (optional) ---
        if encoder_hidden_states is not None:
            cross_out = self.encoder_attn(
                query=self.encoder_attn_layer_norm(hidden_states),
                key=encoder_hidden_states,
                mask=None,
            )
            hidden_states = hidden_states + drop(
                cross_out, p=self.dropout, training=training
            )

        # --- feed-forward ---
        ffn_out = self.activation_fn(self.fc1(self.final_layer_norm(hidden_states)))
        ffn_out = drop(ffn_out, p=self.activation_dropout, training=training)
        ffn_out = drop(self.fc2(ffn_out), p=self.dropout, training=training)
        return hidden_states + ffn_out
class WhisperEncoder(nn.Module):
    """
    Audio encoder: two 1-D convolutions, position embeddings, a stack of
    ``WhisperEncoderLayer`` blocks and a final LayerNorm.
    """

    def __init__(self, config):
        super().__init__()
        width = config.d_model

        # Plain configuration values kept on the module.
        self.dropout = config.dropout
        self.layerdrop = config.encoder_layerdrop
        self.num_mel_bins = config.num_mel_bins
        self.padding_idx = config.pad_token_id
        self.max_source_positions = config.max_source_positions
        self.embed_scale = math.sqrt(width) if config.scale_embedding else 1.0

        # Convolutional front-end; the second convolution uses stride 2.
        self.conv1 = nn.Conv1d(self.num_mel_bins, width, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(width, width, kernel_size=3, stride=2, padding=1)

        # Position table is excluded from gradient computation.
        self.embed_positions = nn.Embedding(self.max_source_positions, width)
        self.embed_positions.requires_grad_(False)

        blocks = [WhisperEncoderLayer(config) for _ in range(config.encoder_layers)]
        self.layers = nn.ModuleList(blocks)
        self.layer_norm = nn.LayerNorm(config.d_model)

    def _freeze_parameters(self):
        """Turn off gradients for every parameter and record that on the module."""
        for weight in self.parameters():
            weight.requires_grad = False
        self._requires_grad = False

    def get_input_embeddings(self) -> nn.Module:
        """Return the first convolution, which acts as the input embedding."""
        return self.conv1

    def set_input_embeddings(self, value: nn.Module):
        """Replace the first convolution with ``value``."""
        self.conv1 = value

    def forward(
        self,
        input_features,
        attn_mask=None,
    ):
        """
        Encode ``input_features``.

        Args:
            input_features: Input to ``conv1``; its channel dimension must
                match ``num_mel_bins``. After the two convolutions the time
                axis must be compatible with the full position table, which
                is added in its entirety.
            attn_mask: Passed positionally to every encoder layer.

        Returns:
            The layer-normalised output of the last encoder layer.
        """
        gelu = nn.functional.gelu

        features = gelu(self.conv1(input_features))
        # Move channels last so the embedding dimension is the final axis.
        features = gelu(self.conv2(features)).permute(0, 2, 1)

        hidden_states = features + self.embed_positions.weight
        hidden_states = nn.functional.dropout(
            hidden_states, p=self.dropout, training=self.training
        )

        for block in self.layers:
            hidden_states = block(hidden_states, attn_mask)

        return self.layer_norm(hidden_states)
class WhisperDecoder(nn.Module):
    """
    Text decoder: token + position embeddings, a stack of
    ``WhisperDecoderLayer`` blocks and a final LayerNorm.
    """

    def __init__(self, config):
        super().__init__()
        width = config.d_model

        # Plain configuration values kept on the module.
        self.dropout = config.dropout
        self.layerdrop = config.decoder_layerdrop
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_target_positions
        self.max_source_positions = config.max_source_positions
        self.embed_scale = math.sqrt(width) if config.scale_embedding else 1.0

        self.embed_tokens = nn.Embedding(config.vocab_size, width, self.padding_idx)
        self.embed_positions = WhisperPositionalEmbedding(
            self.max_target_positions, width
        )

        blocks = [WhisperDecoderLayer(config) for _ in range(config.decoder_layers)]
        self.layers = nn.ModuleList(blocks)
        self.layer_norm = nn.LayerNorm(width)

        # Boolean matrix that is True strictly above the diagonal; kept as a
        # non-persistent buffer so it is not written to the state dict.
        size = self.max_target_positions
        upper = torch.ones(size, size).bool().triu_(1)
        self.register_buffer("mask", upper, persistent=False)

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module with ``value``."""
        self.embed_tokens = value

    def forward(
        self,
        input_ids=None,
        encoder_hidden_states=None,
    ):
        """
        Decode ``input_ids``.

        Args:
            input_ids: Token ids; flattened to two dimensions with the last
                dimension kept as the sequence length.
            encoder_hidden_states: Forwarded to every decoder layer.

        Returns:
            The layer-normalised output of the last decoder layer.
        """
        original_shape = input_ids.size()
        seq_len = original_shape[-1]
        flat_ids = input_ids.view(-1, seq_len)

        hidden_states = self.embed_tokens(flat_ids) + self.embed_positions(flat_ids)
        hidden_states = nn.functional.dropout(
            hidden_states, p=self.dropout, training=self.training
        )

        # Crop the precomputed mask to the current length and expand it to
        # (original_shape[0], 1, seq_len, seq_len) without copying.
        causal = self.mask[None, None, :seq_len, :seq_len]
        causal = causal.expand(original_shape[0], 1, -1, -1)

        for block in self.layers:
            hidden_states = block(
                hidden_states,
                attention_mask=causal,
                encoder_hidden_states=encoder_hidden_states,
            )

        return self.layer_norm(hidden_states)
class WhisperModel(nn.Module):
    """Encoder-decoder wrapper that wires ``WhisperEncoder`` into ``WhisperDecoder``."""

    def __init__(self, config):
        super().__init__()
        self.encoder = WhisperEncoder(config)
        self.decoder = WhisperDecoder(config)

    # --- embedding accessors (operate on the decoder's token embedding) ---

    def get_input_embeddings(self):
        """Return the decoder's token embedding module."""
        return self.decoder.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the decoder's token embedding module with ``value``."""
        self.decoder.embed_tokens = value

    # --- submodule accessors ---

    def get_encoder(self):
        """Return the encoder submodule."""
        return self.encoder

    def get_decoder(self):
        """Return the decoder submodule."""
        return self.decoder

    def freeze_encoder(self):
        """Disable gradients for the encoder via its ``_freeze_parameters``."""
        self.encoder._freeze_parameters()

    def forward(
        self,
        input_features: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
    ):
        """
        Encode ``input_features`` and decode ``decoder_input_ids`` against it.

        Returns:
            Whatever the decoder returns for the given ids and encoder output.
        """
        memory = self.encoder(input_features)
        return self.decoder(
            input_ids=decoder_input_ids,
            encoder_hidden_states=memory,
        )
class WhisperForConditionalGeneration_mm(MultiModalModule):
    """
    Whisper encoder-decoder with a linear vocabulary projection head.

    ``forward`` derives the decoder inputs from ``labels`` (shifted right) and
    returns the projected logits; ``compute_loss`` turns logits and labels
    into a cross-entropy loss.
    """

    def __init__(self, config):
        # The base class is initialised without a config; the transformer
        # config stored on the module is built from the global arguments.
        super().__init__(config=None)
        self.config = core_transformer_config_from_args(get_args())
        self.model = WhisperModel(config)
        self.proj_out = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.max_target_positions = config.max_target_positions
        self.pad_token_id = config.pad_token_id
        self.decoder_start_token_id = config.decoder_start_token_id
        self.vocab_size = config.vocab_size
        if hasattr(config, "ckpt_path"):
            self.load_checkpoint(config.ckpt_path)
        else:
            print("Warning: no checkpoint found at ckpt_path, skipping loading ckpt.")

    def get_encoder(self):
        """Return the encoder of the wrapped model."""
        return self.model.get_encoder()

    def get_decoder(self):
        """Return the decoder of the wrapped model."""
        return self.model.get_decoder()

    def get_output_embeddings(self):
        """Return the vocabulary projection layer."""
        return self.proj_out

    def set_output_embeddings(self, new_embeddings):
        """Replace the vocabulary projection layer with ``new_embeddings``."""
        self.proj_out = new_embeddings

    def get_input_embeddings(self) -> nn.Module:
        """Return the input (token) embedding module of the wrapped model."""
        return self.model.get_input_embeddings()

    def freeze_encoder(self):
        """
        Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will
        not be updated during training.
        """
        self.model.encoder._freeze_parameters()

    def forward(
        self,
        input_features: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
    ):
        """
        Compute vocabulary logits for ``labels`` given ``input_features``.

        Args:
            input_features: Input passed to the encoder.
            labels: Target token ids. They are shifted right (start token in
                front, ``-100`` replaced by the pad id) to form the decoder
                inputs, so they are required despite the ``None`` default.

        Returns:
            Output of the projection head applied to the decoder output.

        Raises:
            ValueError: If ``labels`` is ``None``.
        """
        # The decoder inputs are derived from the labels, so fail with a clear
        # message instead of an AttributeError deep inside the shift helper.
        if labels is None:
            raise ValueError("labels must be provided to build decoder_input_ids.")

        decoder_input_ids = shift_tokens_right(
            labels, self.pad_token_id, self.decoder_start_token_id
        )
        outputs = self.model(
            input_features,
            decoder_input_ids=decoder_input_ids,
        )
        return self.proj_out(outputs)

    def compute_loss(self, output, label):
        """
        Return the cross-entropy loss between ``output`` logits and ``label``.

        ``output`` is flattened to ``(-1, vocab_size)`` and ``label`` to one
        dimension; ``label`` is first moved to the device of ``output``.
        """
        loss_fct = CrossEntropyLoss()
        label = label.to(output.device)
        # reshape (rather than view) also accepts non-contiguous logits and
        # matches how the labels are flattened.
        return loss_fct(output.reshape(-1, self.vocab_size), label.reshape(-1))

    def load_checkpoint(self, ckpt_path):
        """
        Load weights from ``ckpt_path`` non-strictly and print the result.

        A falsy path (``None`` or empty) only prints a warning.
        """
        if ckpt_path:
            # NOTE(security): torch.load unpickles the file, which can execute
            # arbitrary code. Only load checkpoints from trusted sources.
            load_params = torch.load(ckpt_path, map_location="cpu")
            # strict=False: report missing/unexpected keys rather than raising.
            print(self.load_state_dict(load_params, strict=False))
        else:
            print("Warning: ckpt path is None or empty, skipping loading ckpt.")