import json
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
import transformers
from transformers.modeling_outputs import ModelOutput
from mindspeed_mm.models.text_encoder.hunyuan_mllm_text_encoder import HunyuanMLLmModel
@dataclass
class TextEncoderModelOutput(ModelOutput):
    """
    Container for the results produced by a text encoder forward pass.

    Every field defaults to ``None``. The field order is part of the
    interface, because instances may be built positionally as
    ``(hidden_state, attention_mask, hidden_states_list)``.

    Args:
        hidden_state (`torch.FloatTensor`):
            The hidden state selected as the encoder output. Presumably of shape
            `(batch_size, sequence_length, hidden_size)` -- TODO confirm against the wrapped model.
        attention_mask (`torch.LongTensor`, *optional*):
            Attention mask that accompanies `hidden_state`; `None` when no mask was used.
        hidden_states_list (`tuple(torch.FloatTensor)`, *optional*):
            Per-layer hidden states of the wrapped model, supplied when the caller asked
            for all hidden states.
        text_outputs (`list`, *optional*):
            Presumably decoded texts; not populated by the code visible in this file --
            verify against callers.
        image_features (`list`, *optional*):
            Presumably image features for multimodal inputs; not populated by the code
            visible in this file -- verify against callers.
    """

    # Selected encoder output.
    hidden_state: torch.FloatTensor = None
    # Mask matching `hidden_state`, or None when masking is disabled.
    attention_mask: Optional[torch.LongTensor] = None
    # All hidden states of the wrapped model, when requested.
    hidden_states_list: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Optional extras; see the class docstring for caveats.
    text_outputs: Optional[list] = None
    image_features: Optional[list] = None
class Hunyuan15MLLmModel(HunyuanMLLmModel):
    """
    Text-encoder wrapper built on top of `HunyuanMLLmModel`.

    Adds a `from_pretrained` constructor that loads a `transformers` model plus a
    JSON prompt-template file, and an `encode` method that returns a selected
    hidden state with the template prefix optionally cropped off.
    """

    def __init__(
        self,
        model,
        template_info,
        image_embed_interleave=2,
    ):
        """
        Forward all arguments unchanged to the parent constructor.

        Args:
            model: The wrapped language model.
            template_info (dict): Template entry; `encode` reads its optional
                ``"crop_start"`` key.
            image_embed_interleave (int): Passed through to the parent class. Defaults to 2.
        """
        super().__init__(
            model=model,
            template_info=template_info,
            image_embed_interleave=image_embed_interleave,
        )

    @classmethod
    def from_pretrained(cls, **config):
        """
        Build an encoder from a pretrained `transformers` checkpoint.

        The following keys are consumed from ``config``; everything that remains is
        forwarded to ``<model class>.from_pretrained(**config)``.

        Args:
            template_file_path (str): Path to a JSON file mapping template ids to
                template info. Required.
            template_id (str): Key to select inside the template file.
                Defaults to ``"hyv-llm-encode-video"``.
            image_embed_interleave (int): Forwarded to the constructor. Defaults to 4.
            model_type (str): Name of the class looked up on the `transformers`
                module. Defaults to ``"AutoModel"``.

        Returns:
            An instance of ``cls`` wrapping the loaded model with gradients disabled.

        Raises:
            KeyError: If ``template_file_path`` is missing from ``config`` or
                ``template_id`` is not present in the template file.
        """
        template_file_path = config.pop("template_file_path")
        template_id = config.pop("template_id", "hyv-llm-encode-video")
        # JSON text is UTF-8; do not depend on the platform's default encoding.
        with open(template_file_path, "r", encoding="utf-8") as f:
            templates = json.load(f)

        # NOTE(review): this default (4) differs from the __init__ default (2) -- confirm intended.
        image_embed_interleave = config.pop("image_embed_interleave", 4)
        model_type = config.pop("model_type", "AutoModel")

        model = getattr(transformers, model_type).from_pretrained(**config)
        # When the loaded model exposes a `language_model` sub-module, keep only that part.
        if hasattr(model, "language_model"):
            model = model.language_model
        # Alias the model's norm under the attribute name that `encode` uses.
        model.final_layer_norm = model.norm
        model.requires_grad_(False)

        # Use `cls` so subclasses get an instance of their own type.
        return cls(
            model=model,
            template_info=templates[template_id],
            image_embed_interleave=image_embed_interleave,
        )

    def encode(
        self,
        batch_encoding,
        use_attention_mask=True,
        output_hidden_states=False,
        hidden_state_skip_layer=2,
        device=None,
        use_template=True,
        apply_final_norm=False,
    ):
        """
        Run the wrapped model and return the selected hidden state.

        Args:
            batch_encoding (dict): Tokenizer output. Must contain ``"input_ids"``, and
                ``"attention_mask"`` when `use_attention_mask` is True.
            use_attention_mask (bool): Whether to pass the attention mask to the model
                and return it. Defaults to True.
            output_hidden_states (bool): If True, the returned output also carries the
                model's full tuple of hidden states. Defaults to False.
            hidden_state_skip_layer (int, optional): How many layers to skip counting from
                the end of the model's hidden states; 0 selects the last entry. If None,
                ``outputs[self.output_key]`` is used instead. Defaults to 2.
            device (optional): Device the inputs are moved to. Defaults to
                ``self.model.device``.
            use_template (bool): If True and the template defines a positive
                ``"crop_start"``, the first ``crop_start`` positions are removed from the
                hidden state (and from the attention mask). Defaults to True.
            apply_final_norm (bool): If True and `hidden_state_skip_layer` > 0, apply
                ``self.model.final_layer_norm`` to the selected hidden state.
                Defaults to False.

        Returns:
            TextEncoderModelOutput: ``hidden_state`` and ``attention_mask`` are always
            set (the mask is None when `use_attention_mask` is False);
            ``hidden_states_list`` is set only when `output_hidden_states` is True.
        """
        device = self.model.device if device is None else device
        attention_mask = (
            batch_encoding["attention_mask"].to(device) if use_attention_mask else None
        )
        outputs = self.model(
            input_ids=batch_encoding["input_ids"].to(device),
            attention_mask=attention_mask,
            # Hidden states are needed whenever a specific layer has to be picked.
            output_hidden_states=output_hidden_states
            or hidden_state_skip_layer is not None,
        )

        if hidden_state_skip_layer is not None:
            last_hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)]
            # The norm is applied to intermediate layers only; presumably the final
            # layer's output is already normalized by the model -- TODO confirm.
            if hidden_state_skip_layer > 0 and apply_final_norm:
                last_hidden_state = self.model.final_layer_norm(last_hidden_state)
        else:
            # `self.output_key` is not set in this class; presumably provided by the parent.
            last_hidden_state = outputs[self.output_key]

        # A template without "crop_start" means nothing to crop; guarding for None
        # avoids a TypeError from comparing None with an int.
        crop_start = self.template_info.get("crop_start")
        if use_template and crop_start is not None and crop_start > 0:
            last_hidden_state = last_hidden_state[:, crop_start:]
            attention_mask = (
                attention_mask[:, crop_start:] if use_attention_mask else None
            )

        if output_hidden_states:
            return TextEncoderModelOutput(
                last_hidden_state, attention_mask, outputs.hidden_states
            )
        return TextEncoderModelOutput(last_hidden_state, attention_mask)