import inspect
import json
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import transformers
from transformers.modeling_outputs import ModelOutput
@dataclass
class HunyuanMLLmModelOutput(ModelOutput):
    """Output record returned by ``HunyuanMLLmModel.forward``.

    Both fields are optional and default to ``None``; field order is significant
    for ``ModelOutput`` (tuple-style indexing follows declaration order).
    """

    # Tuple of hidden-state tensors (one entry per layer slot).
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Tuple of attention tensors; left as None unless explicitly provided.
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
class HunyuanMLLmModel(nn.Module):
    """Wrap a multimodal LLM so its output looks like a stack of text-encoder hidden states.

    ``forward`` runs the wrapped ``model`` with ``output_hidden_states=True``,
    selects one hidden layer, crops template tokens using offsets from
    ``template_info`` and, when ``pixel_values`` is given, moves the image
    embeddings in front of the text embeddings. Attribute lookups that fail on
    the wrapper are forwarded to the wrapped model.

    NOTE(review): ``hidden_state_skip_layer`` is read in ``forward`` but never
    assigned in this class; it must either be set on the wrapper by the caller
    or be resolvable on the wrapped model via ``__getattr__`` -- confirm.
    """

    def __init__(
        self,
        model,
        template_info,
        image_embed_interleave=2,
    ):
        """Store the wrapped model and template metadata.

        Args:
            model: Multimodal language model to wrap. It must expose ``dtype`` and,
                when called, return an object with a ``hidden_states`` sequence.
            template_info: Mapping of template offsets / token ids. Keys read here:
                ``crop_start``, ``image_emb_len``, ``image_emb_start``,
                ``image_emb_end``, ``double_return_token_id``.
            image_embed_interleave: Stride used to subsample image embeddings;
                only applied when it lies strictly between 0 and 6.
        """
        super().__init__()
        # Cast all parameters to the model's reported dtype (presumably a no-op
        # unless the checkpoint holds mixed-precision weights).
        self.model = model.to(model.dtype)
        self.template_info = template_info
        self.image_embed_interleave = image_embed_interleave

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        pixel_values=None,
        **kwargs
    ):
        """Encode a prompt (optionally with an image) and return cropped embeddings.

        Args:
            input_ids: Token ids, indexed as ``(batch, seq)``.
            attention_mask: Mask aligned with ``input_ids``. **Mutated in place**
                (via ``Tensor.set_``) so that afterwards it matches the returned
                embeddings.
            pixel_values: Optional image tensor; cast to the model dtype.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            ``HunyuanMLLmModelOutput`` whose ``hidden_states`` is a tuple of
            ``hidden_state_skip_layer + 1`` references to the same processed
            embedding tensor, so ``hidden_states[-(skip + 1)]`` yields it.

        Raises:
            ValueError: if ``pixel_values`` is given without ``input_ids`` /
                ``attention_mask``, or ``template_info["crop_start"]`` is not a
                positive value (the image path cannot map offsets without it).
        """
        crop_start = self.template_info.get("crop_start", None)
        if pixel_values is not None:
            # Validate before the (expensive) model call; these cases used to
            # fail with NameError / TypeError deep inside the merge logic.
            if input_ids is None or attention_mask is None:
                raise ValueError(
                    "input_ids and attention_mask are required when pixel_values is given"
                )
            if crop_start is None or crop_start <= 0:
                raise ValueError(
                    "template_info['crop_start'] must be positive when pixel_values is given"
                )

        model_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "output_hidden_states": True,
        }
        if pixel_values is not None:
            model_kwargs["pixel_values"] = pixel_values.to(self.model.dtype)
        prompt_embeds = self.model(**model_kwargs).hidden_states[-(self.hidden_state_skip_layer + 1)]

        if pixel_values is None:
            if crop_start is not None and crop_start > 0:
                # Drop the template prefix from embeddings and, in place, from the mask.
                prompt_embeds = prompt_embeds[:, crop_start:]
                if attention_mask is not None:
                    attention_mask.set_(attention_mask[:, crop_start:].contiguous())
        else:
            prompt_embeds, prompt_attention_mask = self._merge_image_and_text(
                input_ids, attention_mask, prompt_embeds, crop_start
            )
            # Replace the caller's mask contents so it matches the new layout.
            attention_mask.set_(prompt_attention_mask.contiguous())

        return HunyuanMLLmModelOutput(
            hidden_states=(prompt_embeds,) * (self.hidden_state_skip_layer + 1),
        )

    def _merge_image_and_text(self, input_ids, attention_mask, prompt_embeds, crop_start):
        """Crop template tokens, extract image embeddings, and re-join as ``[image, text]``.

        Returns:
            ``(embeddings, attention_mask)`` for the merged sequence. The text part
            of the mask is sliced from ``attention_mask``; the image part is all ones.
        """
        info = self.template_info
        image_emb_len = info.get("image_emb_len", 576)
        image_emb_start = info.get("image_emb_start", 5)
        image_emb_end = info.get("image_emb_end", 581)
        double_return_token_id = info.get("double_return_token_id", 271)

        # Positions after the image are shifted by image_emb_len - 1 between
        # input_ids space and embedding space (presumably one image placeholder
        # token expands to image_emb_len embeddings -- confirm against the processor).
        embed_shift = image_emb_len - 1
        text_crop_start = crop_start + embed_shift

        _, double_return_positions = torch.where(input_ids == double_return_token_id)
        if double_return_positions.shape[0] == 3:
            # NOTE(review): special case kept as-is: when exactly three matches are
            # found, the sequence length is appended as the final position. This
            # looks tailored to a single-sample batch whose last match was truncated.
            seq_end = torch.tensor([input_ids.shape[-1]], device=double_return_positions.device)
            double_return_positions = torch.cat((double_return_positions, seq_end))
        # Requires every sample to contain the same number of matches.
        last_double_return = double_return_positions.reshape(input_ids.shape[0], -1)[:, -1]

        text_embeds, text_masks, image_embeds, image_masks = [], [], [], []
        for i in range(input_ids.shape[0]):
            last = last_double_return[i].item()
            # Remove the 4 positions right before the last double-return token,
            # in embedding space and in input_ids (mask) space respectively.
            text_embeds.append(torch.cat((
                prompt_embeds[i, text_crop_start:last + embed_shift - 4],
                prompt_embeds[i, last + embed_shift:],
            )))
            text_masks.append(torch.cat((
                attention_mask[i, crop_start:last - 4],
                attention_mask[i, last:],
            )))
            image_embed = prompt_embeds[i, image_emb_start:image_emb_end]
            image_embeds.append(image_embed)
            image_masks.append(
                torch.ones(image_embed.shape[0], device=prompt_embeds.device, dtype=attention_mask.dtype)
            )

        image_embeds = torch.stack(image_embeds)
        image_masks = torch.stack(image_masks)
        # Subsample image tokens only for strides in (0, 6); other values keep all.
        if 0 < self.image_embed_interleave < 6:
            image_embeds = image_embeds[:, ::self.image_embed_interleave, :]
            image_masks = image_masks[:, ::self.image_embed_interleave]

        merged_embeds = torch.cat((image_embeds, torch.stack(text_embeds)), dim=1)
        merged_mask = torch.cat((image_masks, torch.stack(text_masks)), dim=1)
        return merged_embeds, merged_mask

    def __getattr__(self, name):
        """Resolve registered parameters/buffers/submodules, else delegate to ``self.model``.

        ``nn.Module.__getattr__`` never re-enters this method, so a missing
        ``model`` (e.g. mid-unpickling) raises AttributeError instead of recursing.
        """
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(super().__getattr__("model"), name)

    @classmethod
    def from_pretrained(cls, **config):
        """Build a wrapper from a JSON template file and a transformers checkpoint.

        Keys popped from ``config``:
            template_file_path (required): path to a JSON mapping of template ids.
            template_id: key into that mapping (default ``"hyv-llm-encode-video"``).
            image_embed_interleave: image-embedding stride (default 4).
            model_type: name of a ``transformers`` class (default ``"AutoModel"``).
        All remaining keys are forwarded to ``<model_type>.from_pretrained``.
        """
        template_file_path = config.pop("template_file_path")
        template_id = config.pop("template_id", "hyv-llm-encode-video")
        # JSON text is UTF-8; don't depend on the platform's default encoding.
        with open(template_file_path, "r", encoding="utf-8") as f:
            templates = json.load(f)
        image_embed_interleave = config.pop("image_embed_interleave", 4)
        model_type = config.pop("model_type", "AutoModel")
        model = getattr(transformers, model_type).from_pretrained(**config)
        return cls(
            model=model,
            template_info=templates[template_id],
            image_embed_interleave=image_embed_interleave,
        )