import os
import json
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
from functools import cached_property
import torch
import torch_npu
import torch.nn as nn
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, StaticCache
from transformers.models.chameleon.modeling_chameleon import (
ChameleonRotaryEmbedding,
ChameleonMLP,
ChameleonImageVocabularyMapping,
repeat_kv
)
from megatron.core import mpu, tensor_parallel
from megatron.training import get_args
from mindspeed.ops.npu_rotary_position_embedding import npu_rotary_position_embedding
from mindspeed_mm.models.common.module import MultiModalModule
from mindspeed_mm.models.common.normalize import LlamaRMSNorm
class RotaryEmbedding(ChameleonRotaryEmbedding):
    """Rotary embedding that returns the raw ``cos``/``sin`` tables.

    ``forward`` rebuilds ``inv_freq`` in float32 whenever the stored tensor has
    a different dtype, and returns the tables without casting them to the
    dtype of ``x``.
    """

    @torch.no_grad()
    def forward(self, x, position_ids):
        """Compute the cosine and sine tables for ``position_ids``.

        Args:
            x: Tensor consulted only for its device type (autocast selection).
            position_ids: Position indices; indexed as ``[:, None, :]``, so a
                2-D ``(batch, seq_len)`` tensor is expected.

        Returns:
            ``(cos, sin)``, each the cos/sin of the frequency table duplicated
            along the last dimension.
        """
        if self.inv_freq.dtype != torch.float32:
            # The stored frequencies are not float32: recompute them from
            # ``base`` and ``dim`` and keep them on the original device.
            exponents = torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim
            rebuilt = 1.0 / (self.base ** exponents)
            self.inv_freq = rebuilt.to(self.inv_freq.device)

        batch = position_ids.shape[0]
        inv_freq_col = self.inv_freq[None, :, None].float().expand(batch, -1, 1)
        positions_row = position_ids[:, None, :].float()

        # autocast is addressed by device type; "mps" and non-string values
        # fall back to "cpu".
        kind = x.device.type
        if not isinstance(kind, str) or kind == "mps":
            kind = "cpu"

        # Autocast is explicitly disabled so the angles are computed in float32.
        with torch.autocast(device_type=kind, enabled=False):
            angles = torch.matmul(inv_freq_col.float(), positions_row.float()).transpose(1, 2)
            table = torch.cat((angles, angles), dim=-1)
            cos, sin = table.cos(), table.sin()
        return cos, sin
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Rotate the query and key tensors with rotary position embeddings.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which `cos` and `sin` are unsqueezed so that they
            broadcast against `q` and `k`. With `cos`/`sin` of shape
            [batch_size, seq_len, head_dim], use 1 when `q`/`k` are
            [batch_size, heads, seq_len, head_dim] and 2 (or -2) when they are
            [batch_size, seq_len, heads, head_dim].

    Returns:
        `tuple(torch.Tensor)`: the rotated query and key tensors, in that order.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)
    # The trailing 0 is forwarded positionally as the fused op's fourth
    # argument (presumably its rotation mode -- confirm against MindSpeed docs).
    return tuple(npu_rotary_position_embedding(states, cos_b, sin_b, 0) for states in (q, k))
@dataclass
class ChameleonMLPConfig:
    """Minimal config object handed to ``ChameleonMLP``.

    Field order matters: the decoder layer builds this config positionally as
    ``(hidden_size, intermediate_size, mlp_bias, hidden_act)``.
    """

    # Width of the residual stream.
    hidden_size: int = 4096
    # Width of the MLP's inner projection.
    intermediate_size: int = 11008
    # Whether the MLP linear layers carry a bias term.
    mlp_bias: bool = False
    # Name of the activation function.
    hidden_act: str = "silu"
class ChameleonLayerNorm(nn.LayerNorm):
    """
    LayerNorm but computes stats only over the last dim because Chameleon applies gamma and beta
    from each shard separately to each head, instead of reducing. We can apply each head's own
    gamma/beta by repeat-interleaving weights from each shard, but the stats have to be computed
    in the last dimension. This module applies gamma/beta manually to fulfill this requirement.

    The affine parameters are created with shape ``[model_parallel_size, *hidden_size]`` and are
    repeat-interleaved ``n_heads_per_mp`` times along dim 0 before being applied.
    """

    def __init__(self, hidden_size, model_parallel_size, n_heads_per_mp, *args, **kwargs):
        """
        Args:
            hidden_size: Int or sequence of ints; its last entry is the normalised dimension.
            model_parallel_size: Leading dimension of the stored weight/bias.
            n_heads_per_mp: Repeat factor applied to each weight/bias row in ``forward``.
            *args, **kwargs: Forwarded to ``nn.LayerNorm`` (e.g. ``eps``, ``elementwise_affine``).
        """
        if isinstance(hidden_size, int):
            hidden_size = (hidden_size,)
        super().__init__([model_parallel_size, *hidden_size], *args, **kwargs)
        # Statistics are taken over the last dimension only, not the full parameter shape.
        self.normalized_shape = (hidden_size[-1],)
        self.n_heads_per_mp = n_heads_per_mp

    def repeat_param(self, param):
        """Repeat each row of ``param`` ``n_heads_per_mp`` times along dim 0."""
        return param.repeat_interleave(self.n_heads_per_mp, dim=0)

    def forward(self, hidden_states):
        """Normalise over the last dim, then apply the per-head gamma/beta when present."""
        # Use the configured epsilon rather than a hard-coded 1e-5 so that an ``eps``
        # passed to the constructor is honoured (the default is still 1e-5).
        hidden_states = F.layer_norm(hidden_states, self.normalized_shape, None, None, eps=self.eps)
        # weight/bias are None when the layer was built without affine parameters.
        if self.weight is not None:
            hidden_states = hidden_states * self.repeat_param(self.weight)
        if self.bias is not None:
            hidden_states = hidden_states + self.repeat_param(self.bias)
        return hidden_states
class ChameleonAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Queries and keys are layer-normalised per head, rotated with rotary position
    embeddings, and fed to ``torch_npu.npu_fusion_attention`` in ``BSND`` layout
    (batch, sequence, heads, head_dim).
    """

    def __init__(
        self,
        num_attention_heads,
        num_key_value_heads,
        hidden_size,
        attention_dropout,
        attention_bias,
        rope_theta,
        max_position_embeddings,
        layer_idx: Optional[int] = None,
        **kwargs
    ):
        """
        Args:
            num_attention_heads: Number of query heads.
            num_key_value_heads: Number of key/value heads.
            hidden_size: Model width; must be divisible by ``num_attention_heads``.
            attention_dropout: Dropout probability on attention weights (training only).
            attention_bias: Whether the q/k/v/o projections carry a bias.
            rope_theta: Base of the rotary embedding.
            max_position_embeddings: Forwarded to the rotary embedding.
            layer_idx: Index of the owning decoder layer; stored, not otherwise used here.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If ``hidden_size`` is not divisible by ``num_attention_heads``.
        """
        super().__init__()
        self.layer_idx = layer_idx
        self.attention_dropout = attention_dropout
        self.hidden_size = hidden_size
        self.num_heads = num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta
        self.is_causal = True
        self.model_parallel_size = 1

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=attention_bias)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=attention_bias)
        self.q_norm = ChameleonLayerNorm(
            self.head_dim, self.model_parallel_size, self.num_heads // self.model_parallel_size
        )
        self.k_norm = ChameleonLayerNorm(
            self.head_dim, self.model_parallel_size, self.num_key_value_heads // self.model_parallel_size
        )
        self._init_rope()

    def _init_rope(self):
        """Create the rotary embedding module for this layer's head dimension."""
        self.rotary_emb = RotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        output_attentions: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Run self-attention over ``hidden_states``.

        Args:
            hidden_states: Input of shape ``(batch, seq_len, hidden_size)``.
            attention_mask: Optional 4-D mask; it is sliced on its last dimension
                to the key sequence length. When ``None`` no mask is handed to the
                fused kernel, so causality must be supplied by the caller's mask.
            position_ids: Positions forwarded to the rotary embedding.
            output_attentions: Accepted for interface compatibility; attention
                weights are never returned by this implementation.
            **kwargs: Accepted and ignored.

        Returns:
            ``(attn_output, None, None)`` where ``attn_output`` has shape
            ``(batch, seq_len, hidden_size)``.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Per-head layer norm operates on (tokens, heads, head_dim).
        query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
        query_states = self.q_norm(query_states)
        key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim)
        key_states = self.k_norm(key_states)

        # BSND layout from here on.
        query_states = query_states.reshape(bsz, q_len, self.num_heads, self.head_dim)
        key_states = key_states.reshape(bsz, q_len, self.num_key_value_heads, self.head_dim)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)

        cos, sin = self.rotary_emb(value_states, position_ids)
        # cos/sin are unsqueezed on the heads axis (-2) to broadcast over BSND.
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=-2)

        # Expand key/value heads to match the query heads. The heads axis is
        # dim 2 in BSND; the imported `repeat_kv` helper assumes heads on dim 1
        # and would repeat the sequence axis here, so repeat explicitly.
        if self.num_key_value_groups > 1:
            key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=2)
            value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=2)

        # Bind the mask unconditionally so the default `attention_mask=None`
        # does not raise NameError below.
        causal_mask = None
        if attention_mask is not None:
            causal_mask = attention_mask[:, :, :, : key_states.shape[1]]

        # Attention dropout must only be active in training mode.
        keep_prob = 1 - self.attention_dropout if self.training else 1.0

        attn_output = torch_npu.npu_fusion_attention(
            query_states,
            key_states,
            value_states.to(query_states.dtype),
            atten_mask=causal_mask,
            input_layout="BSND",
            scale=self.head_dim ** -0.5,
            head_num=self.num_heads,
            keep_prob=keep_prob,
        )[0]

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        return attn_output, None, None
class ChameleonDecoderLayer(nn.Module):
    """One pre-norm decoder block: self-attention then MLP, each with a residual."""

    def __init__(
        self,
        hidden_size,
        intermediate_size,
        mlp_bias,
        hidden_act,
        num_attention_heads,
        num_key_value_heads,
        attention_dropout,
        attention_bias,
        rms_norm_eps,
        dropout,
        rope_theta,
        max_position_embeddings,
        layer_idx: int
    ):
        """Build the attention, MLP, the two RMS norms and the shared dropout."""
        super().__init__()
        self.hidden_size = hidden_size

        attention_kwargs = dict(
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            hidden_size=hidden_size,
            attention_dropout=attention_dropout,
            attention_bias=attention_bias,
            rope_theta=rope_theta,
            max_position_embeddings=max_position_embeddings,
            layer_idx=layer_idx,
        )
        self.self_attn = ChameleonAttention(**attention_kwargs)

        mlp_config = ChameleonMLPConfig(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            mlp_bias=mlp_bias,
            hidden_act=hidden_act,
        )
        self.mlp = ChameleonMLP(mlp_config)

        self.input_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps)
        # A single dropout module is applied to both sub-layer outputs.
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        output_attentions: Optional[Union[torch.Tensor, bool]] = True,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                mask forwarded unchanged to the self-attention module.
            position_ids (`torch.LongTensor`, *optional*):
                positions forwarded unchanged to the self-attention module.
            output_attentions (`bool` or `torch.Tensor`, *optional*):
                Whether to append the attention module's weights entry to the returned tuple. A
                tensor value is converted with `.item()` first.
            kwargs (`dict`, *optional*):
                Extra keyword arguments, passed through to the self-attention module.

        Returns:
            `(hidden_states,)`, or `(hidden_states, self_attn_weights)` when `output_attentions` is truthy.
        """
        normed_input = self.input_layernorm(hidden_states)

        # The flag may arrive as a tensor; unwrap it to a Python scalar.
        if isinstance(output_attentions, torch.Tensor):
            output_attentions = output_attentions.item()

        attn_out, attn_weights, _ = self.self_attn(
            hidden_states=normed_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            **kwargs,
        )
        after_attention = hidden_states + self.dropout(attn_out)

        mlp_out = self.mlp(self.post_attention_layernorm(after_attention))
        block_output = after_attention + self.dropout(mlp_out)

        if output_attentions:
            return (block_output, attn_weights)
        return (block_output,)
class ChameleonModel(nn.Module):
    """Token embedding, a stack of ``ChameleonDecoderLayer`` blocks and a final RMS norm."""

    def __init__(
        self,
        pad_token_id,
        vocab_size,
        hidden_size,
        intermediate_size,
        num_hidden_layers,
        num_attention_heads,
        num_key_value_heads,
        attention_dropout,
        attention_bias,
        mlp_bias,
        hidden_act,
        vocabulary_map_path,
        swin_norm,
        rms_norm_eps,
        dropout,
        rope_theta,
        max_position_embeddings,
        **kwargs,
    ):
        """Build the model.

        ``vocabulary_map_path`` is a directory containing ``config.json``; its
        ``"vocabulary_map"`` entry is handed to ``ChameleonImageVocabularyMapping``.
        ``swin_norm`` and ``**kwargs`` are accepted but not used in this class.
        """
        super().__init__()
        args = get_args()
        self.padding_idx = pad_token_id
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size, pad_token_id)

        config_file = os.path.join(vocabulary_map_path, 'config.json')
        with open(config_file, 'r', encoding='utf-8') as handle:
            loaded = json.load(handle)
        # NOTE(review): `.get` yields None when the key is absent -- confirm the
        # mapping class tolerates that, or that the key is always present.
        self.vocabulary_mapping = ChameleonImageVocabularyMapping(loaded.get("vocabulary_map"))

        # Everything except the layer index is identical across layers.
        shared_layer_kwargs = dict(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            mlp_bias=mlp_bias,
            hidden_act=hidden_act,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            attention_dropout=attention_dropout,
            attention_bias=attention_bias,
            rms_norm_eps=rms_norm_eps,
            dropout=dropout,
            rope_theta=rope_theta,
            max_position_embeddings=max_position_embeddings,
        )
        self.layers = nn.ModuleList(
            [ChameleonDecoderLayer(layer_idx=idx, **shared_layer_kwargs) for idx in range(num_hidden_layers)]
        )
        self.norm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps)
        # Used only for its truthiness when deciding whether to recompute.
        self.gradient_checkpointing = args.recompute_granularity

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Embed ``input_ids`` and run them through every decoder layer.

        ``past_key_values`` is consulted only for its current sequence length
        (position offset and mask sizing); the returned ``past_key_values`` is
        always ``None``.

        Returns:
            ``BaseModelOutputWithPast`` with the normalised last hidden state and,
            when requested, the per-layer hidden states and attention entries.
        """
        inputs_embeds = self.embed_tokens(input_ids)

        offset = 0 if past_key_values is None else past_key_values.get_seq_length()
        cache_position = torch.arange(
            offset, offset + inputs_embeds.shape[1], device=inputs_embeds.device
        )
        position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        hidden_states = inputs_embeds
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        recompute = self.gradient_checkpointing and self.training

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if recompute:
                # The flag is wrapped in a tensor for the checkpoint call; the
                # layer unwraps it again with `.item()`.
                attn_flag = None if output_attentions is None else torch.tensor(output_attentions)
                layer_outputs = tensor_parallel.checkpoint(
                    decoder_layer,
                    False,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    attn_flag,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]
            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)
        if output_hidden_states:
            # Also record the final, normalised state.
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        """Build the boolean attention mask handed to the decoder layers.

        A 4-D ``attention_mask`` is validated (its max must be 0) and converted
        to bool directly. Otherwise an upper-triangular mask filled with the
        dtype's minimum is built, entries at or before ``cache_position`` are
        cleared, and positions where a 2-D ``attention_mask`` is 0 are masked as
        well. In the returned tensor ``True`` marks every non-zero entry.
        ``output_attentions`` is not used.

        Raises:
            ValueError: If a 4-D ``attention_mask`` has a non-zero maximum.
        """
        seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length()
        dtype, device = input_tensor.dtype, input_tensor.device
        fill = torch.finfo(dtype).min
        query_len = input_tensor.shape[1]

        if isinstance(past_key_values, StaticCache):
            key_len = past_key_values.get_max_length()
        elif isinstance(attention_mask, torch.Tensor):
            key_len = attention_mask.shape[-1]
        else:
            key_len = seen_tokens + query_len + 1

        if attention_mask is not None and attention_mask.dim() == 4:
            if attention_mask.max() != 0:
                raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
            return attention_mask.to(torch.bool)

        mask = torch.full((query_len, key_len), fill_value=fill, dtype=dtype, device=device)
        if query_len != 1:
            mask = torch.triu(mask, diagonal=1)
        # Zero out every key position that is not strictly after the query's
        # cache position.
        mask *= torch.arange(key_len, device=device) > cache_position.reshape(-1, 1)
        mask = mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)

        if attention_mask is not None:
            # expand() returns a view; clone before writing in place.
            mask = mask.clone()
            width = attention_mask.shape[-1]
            window = mask[:, :, :, :width]
            # A sum of exactly 0 means "allowed by the causal mask, but the
            # 2-D attention mask is 0 there".
            is_padding = (window + attention_mask[:, None, None, :]) == 0
            mask[:, :, :, :width] = window.masked_fill(is_padding, fill)

        return mask.to(torch.bool)
class ChameleonForConditionalGeneration(MultiModalModule):
    """``ChameleonModel`` with a linear language-model head and next-token loss."""

    def __init__(
        self,
        max_position_embeddings,
        vocab_size,
        hidden_size,
        intermediate_size,
        pad_token_id,
        num_layers,
        num_heads,
        num_key_value_heads,
        attention_dropout,
        attention_bias,
        mlp_bias,
        hidden_act,
        vocabulary_map_path,
        swin_norm,
        rms_norm_eps,
        dropout,
        rope_theta,
        **kwargs
    ):
        """Build the backbone and the LM head.

        All named arguments are forwarded to ``ChameleonModel`` (``num_layers`` as
        ``num_hidden_layers``, ``num_heads`` as ``num_attention_heads``). From
        ``**kwargs`` only ``output_attentions`` (default ``True``) and
        ``output_hidden_states`` (default ``False``) are read; they become the
        defaults used by ``forward``.
        """
        super().__init__(config=None)
        self.max_position_embeddings = max_position_embeddings
        self.vocab_size = vocab_size
        self.output_attentions = kwargs.get("output_attentions", True)
        self.output_hidden_states = kwargs.get("output_hidden_states", False)
        self.model = ChameleonModel(
            pad_token_id=pad_token_id,
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_layers,
            num_attention_heads=num_heads,
            num_key_value_heads=num_key_value_heads,
            attention_dropout=attention_dropout,
            attention_bias=attention_bias,
            mlp_bias=mlp_bias,
            hidden_act=hidden_act,
            vocabulary_map_path=vocabulary_map_path,
            swin_norm=swin_norm,
            rms_norm_eps=rms_norm_eps,
            dropout=dropout,
            rope_theta=rope_theta,
            max_position_embeddings=max_position_embeddings
        )
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        self.forward_call_times = 0
        # When True, logits of image tokens are forced to the dtype minimum.
        self.mask_image_logits = False

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Compute logits and, when ``labels`` is given, the shifted cross-entropy loss.

        Args:
            input_ids: Token ids passed to the backbone.
            labels: Optional targets; position ``t`` of the logits is scored
                against ``labels[..., t + 1]``.
            output_attentions: Overrides the instance default when not ``None``.
            output_hidden_states: Overrides the instance default when not ``None``.

        Returns:
            ``CausalLMOutputWithPast`` with ``loss`` (or ``None``), float ``logits``
            and the backbone's ``past_key_values``, ``hidden_states`` and ``attentions``.
        """
        output_attentions = output_attentions if output_attentions is not None else self.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states

        outputs = self.model(
            input_ids=input_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        if self.mask_image_logits:
            # Push image-token logits to the minimum representable value.
            image_tokens = self.model.vocabulary_mapping.image_tokens
            logits[:, :, image_tokens] = torch.finfo(logits.dtype).min

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict token n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.vocab_size)
            shift_labels = shift_labels.view(-1)
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    @property
    def device(self) -> torch.device:
        """The device of the module (assuming that all the module parameters are in the same device).

        Falls back to the first buffer when the module has no parameters.

        Raises:
            IndexError: If the module has neither parameters nor buffers.
        """
        # Pull only the first tensor instead of materialising every parameter
        # into a tuple just to read element 0.
        first = next(self.parameters(), None)
        if first is None:
            first = next(self.buffers(), None)
        if first is None:
            raise IndexError("cannot determine device: module has no parameters or buffers")
        return first.device