import json
import torch
import transformers
from mindspeed_mm.models.text_encoder.hunyuan15_byt5.format_prompt import MultilingualPromptFormat
class Hunyuan15GlyphTokenizer:
    """Tokenizer wrapper for the HunyuanVideo-1.5 glyph (byT5) text branch.

    Wraps a Hugging Face tokenizer loaded with ``AutoTokenizer.from_pretrained``,
    registers color/font special tokens built from annotation JSON files, and,
    when called, extracts quoted "glyph" texts from a prompt, formats them with
    ``MultilingualPromptFormat`` and tokenizes the result. Attribute lookups that
    fail on the wrapper are forwarded to the wrapped tokenizer.
    """

    def __init__(
        self,
        **config,
    ):
        """Build the wrapper from keyword configuration.

        Recognised keys:
            pretrained_model_name_or_path: Passed to ``AutoTokenizer.from_pretrained``.
            color_ann_path: Path to the color annotation JSON (default ``""``).
            font_ann_path: Path to the font annotation JSON (default ``""``).
            byt5_max_length: Length every tokenized output is padded/truncated
                to (default 256).
        """
        pretrained_model_name_or_path = config.get("pretrained_model_name_or_path", None)
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            cache_dir=None,
        )
        self.color_ann_path = config.get("color_ann_path", "")
        self.font_ann_path = config.get("font_ann_path", "")
        self.byt5_max_length = config.get("byt5_max_length", 256)
        # Register color and multilingual font tokens on the tokenizer only;
        # no text encoder is passed, so nothing is resized here.
        self.add_special_token(
            tokenizer=self.tokenizer,
            text_encoder=None,
            add_color=True,
            add_font=True,
            color_ann_path=self.color_ann_path,
            font_ann_path=self.font_ann_path,
            multilingual=True,
        )
        self.prompt_format = MultilingualPromptFormat(
            font_path=self.font_ann_path,
            color_path=self.color_ann_path,
        )

    @staticmethod
    def add_special_token(
        tokenizer,
        text_encoder,
        add_color,
        add_font,
        color_ann_path,
        font_ann_path,
        multilingual=False,
        token_len=1510,
        token_add=True,
    ):
        """Add color/font special tokens to a tokenizer and optionally resize a text encoder.

        Args:
            tokenizer: Hugging Face tokenizer to extend. May be ``None``, in which
                case no tokens are added and ``token_len`` is used for resizing.
            text_encoder: Optional model exposing ``resize_token_embeddings``
                (the HunyuanVideo second text encoder); ``None`` skips resizing.
            add_color (bool): Whether to add ``<color-i>`` tokens.
            add_font (bool): Whether to add font tokens.
            color_ann_path (str): Path to the color annotation JSON; only its
                number of entries is used.
            font_ann_path (str): Path to the font annotation JSON.
            multilingual (bool): If True, one ``<{key[:2]}-font-{value}>`` token is
                built per key/value of the font JSON; otherwise ``<font-i>`` for
                ``i`` in ``range(len(font_json))``.
            token_len (int): Embedding size used for resizing when
                ``tokenizer`` is ``None``.
            token_add (bool): Whether to add the tokens to ``tokenizer``.
        """
        # JSON text is UTF-8 by specification; read it explicitly as such rather
        # than relying on the platform's default encoding.
        with open(font_ann_path, "r", encoding="utf-8") as f:
            idx_font_dict = json.load(f)
        with open(color_ann_path, "r", encoding="utf-8") as f:
            idx_color_dict = json.load(f)

        if multilingual:
            # The first two characters of each key act as a prefix
            # (presumably a language code - confirm against the annotation file).
            font_token = [
                f"<{font_code[:2]}-font-{font_idx}>"
                for font_code, font_idx in idx_font_dict.items()
            ]
        else:
            font_token = [f"<font-{i}>" for i in range(len(idx_font_dict))]
        color_token = [f"<color-{i}>" for i in range(len(idx_color_dict))]

        # Colors are added before fonts.
        additional_special_tokens = []
        if add_color:
            additional_special_tokens += color_token
        if add_font:
            additional_special_tokens += font_token

        # A None tokenizer is supported for resizing (see token_len), so it
        # must not be dereferenced here.
        if token_add and tokenizer is not None:
            tokenizer.add_tokens(additional_special_tokens, special_tokens=True)

        if text_encoder is not None:
            text_encoder.resize_token_embeddings(
                len(tokenizer) if tokenizer is not None else token_len,
                mean_resizing=False,
            )

    def __getattr__(self, name):
        """Forward unknown attribute lookups to the wrapped tokenizer.

        Python only calls ``__getattr__`` after normal attribute lookup failed.

        Raises:
            AttributeError: If ``tokenizer`` itself is missing, if ``name`` is
                defined on this object but its lookup failed, or if the wrapped
                tokenizer does not have ``name``.
        """
        if name == "tokenizer":
            # ``tokenizer`` is not set (e.g. the instance was created via
            # ``__new__`` while unpickling/copying, or ``__init__`` failed before
            # assigning it). Reading ``self.tokenizer`` below would re-enter this
            # method forever, so fail with a normal AttributeError instead.
            raise AttributeError(f"{type(self).__name__!r} object has no attribute 'tokenizer'")
        if name in dir(self):
            # The name exists on this object/class but its lookup still failed
            # (e.g. a descriptor raised AttributeError); don't mask that by
            # delegating to the tokenizer.
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")
        return getattr(self.tokenizer, name)

    @staticmethod
    def _find_quoted(prompt, open_quote, close_quote):
        """Return the substrings between successive ``open_quote``/``close_quote`` pairs, in order.

        Scanning stops at the first opening quote that has no matching close.
        """
        spans = []
        start = 0
        while True:
            open_idx = prompt.find(open_quote, start)
            if open_idx == -1:
                break
            body_start = open_idx + len(open_quote)
            close_idx = prompt.find(close_quote, body_start)
            if close_idx == -1:
                break
            spans.append(prompt[body_start:close_idx])
            start = close_idx + len(close_quote)
        return spans

    def _extract_glyph_texts(self, prompt):
        """Extract the glyph texts quoted in ``prompt``.

        Texts in ASCII double quotes (``"..."``) are collected first, then texts
        in curly double quotes (``“...”``, U+201C/U+201D); each group is in order
        of appearance. An unmatched opening quote ends the scan for that quote
        style. Duplicates are removed, keeping the first occurrence.

        Args:
            prompt: Input prompt string.
        Returns:
            List[str]: The extracted glyph texts.
        """
        en_results = self._find_quoted(prompt, '"', '"')
        zh_results = self._find_quoted(prompt, "“", "”")
        # dict.fromkeys keeps insertion order, giving an ordered de-duplication.
        return list(dict.fromkeys(en_results + zh_results))

    @staticmethod
    def get_byt5_text_tokens(byt5_tokenizer, byt5_max_length, text_prompt):
        """Tokenize ``text_prompt`` padded and truncated to ``byt5_max_length``.

        Args:
            byt5_tokenizer: The byT5 (Hugging Face) tokenizer.
            byt5_max_length: Target sequence length.
            text_prompt: Text to tokenize.
        Returns:
            The tokenizer's output mapping (a Hugging Face ``BatchEncoding`` for
            HF tokenizers) holding PyTorch ``input_ids`` and ``attention_mask``
            tensors; for a single string these are ``(1, byt5_max_length)``.
        """
        byt5_text_inputs = byt5_tokenizer(
            text_prompt,
            padding="max_length",
            max_length=byt5_max_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        return byt5_text_inputs

    def __call__(
        self,
        prompt,
        **kwargs,
    ):
        """Tokenize the glyph texts quoted in ``prompt``.

        Args:
            prompt: Input prompt string.
            **kwargs: Accepted for interface compatibility; ignored.
        Returns:
            The tokenizer output for the formatted glyph prompt when ``prompt``
            contains quoted text. Otherwise a dict with all-zero int64
            ``input_ids`` and ``attention_mask`` of shape ``(1, byt5_max_length)``.
        """
        glyph_texts = self._extract_glyph_texts(prompt)
        if glyph_texts:
            # No style information is parsed from the prompt, so every text gets
            # a None color and font family.
            text_styles = [{"color": None, "font-family": None} for _ in glyph_texts]
            formatted_text = self.prompt_format.format_prompt(glyph_texts, text_styles)
            return self.get_byt5_text_tokens(
                self.tokenizer, self.byt5_max_length, formatted_text
            )
        # Placeholder sized by byt5_max_length with integer ids so it matches the
        # shape/dtype of the tokenized branch; the mask is all zeros.
        return {
            "input_ids": torch.zeros((1, self.byt5_max_length), dtype=torch.int64),
            "attention_mask": torch.zeros((1, self.byt5_max_length), dtype=torch.int64),
        }