import importlib
import torch
import torch_npu
import torch.nn as nn
from mindspeed_mm.models.text_encoder.hunyuan_mllm_text_encoder import HunyuanMLLmModel
from mindspeed_mm.models.text_encoder.hunyuan15_mllm_text_encoder import Hunyuan15MLLmModel
from mindspeed_mm.models.text_encoder.hunyuan15_glyph_text_encoder import Hunyuan15GlyphModel
from mindspeed_mm.models.text_encoder.stepllm_text_encoder import StepLLmModel
from mindspeed_mm.utils.utils import get_dtype
TRANSFORMERS_TEXT_ENCODER_MAPPING = {
"T5": "T5EncoderModel",
"MT5": "MT5EncoderModel",
"UMT5": "UMT5EncoderModel",
"CLIP": "CLIPTextModel",
"Auto": "AutoModel",
"BertModel": "BertModel",
"CLIPWithProjection": "CLIPTextModelWithProjection",
}
CUSTOM_TEXT_ENCODER_MAPPING = {
"StepLLmModel": StepLLmModel,
"HunyuanMLLmModel": HunyuanMLLmModel,
"Hunyuan15MLLmModel": Hunyuan15MLLmModel,
"Hunyuan15GlyphModel": Hunyuan15GlyphModel,
}
class TextEncoder(nn.Module):
"""
Configuration for initializing one or more Text Encoder Model instances.
Args:
config (Optional[dict, list(dict)]): the general config for Text Encoder Model.
- If `config` is a dictionary, it specifies the parameters for a single Text Encoder Model instance.
e.g.
{
(1) args for our feautrues
"backend": type-str, "hf" or "om",
"model_id": type-str, "AutoModel" or other automodel name,
"dtype": type-str, dtype of text encoder
(2) args for automodel.from_pretrained() of transformers or openmind
"pretrained_model_name_or_path": type-str, local path or hub path,
"local_files_only": type-bool,
...
}
- If `config` is a list of dictionaries, each dictionary in the list will be used to instantiate a separate Text Encoder Model instance,
effectively allowing the creation of multiple Text Encoder based on different configurations.
"""
def __init__(self, config):
super().__init__()
if isinstance(config, list) or isinstance(config, tuple):
self.text_encoders = nn.ModuleList()
for config_i in config:
text_encoder_i = self._init_text_encoder(config_i)
self.text_encoders.append(text_encoder_i)
else:
self.text_encoders = self._init_text_encoder(config)
def get_model(self):
return self.text_encoders
def encode(self, input_ids, mask, offload_cpu=False, **kwargs):
"""
Encode input sequences with encoder-wise offload in different scenarios.
Scenarios:
1. Single encoder + single input tensor -> (Tensor outputs, Tensor masks)
2. Multiple encoders + single-step inputs (List[Tensor]) -> (List[Tensor] outputs, List[Tensor] masks)
(Each encoder processes its corresponding input tensor)
3. Single/multiple encoders + multi-step inputs (List[List[Tensor]]) -> (List[List[Tensor]] outputs, List[List[Tensor]] masks)
(Each encoder processes its corresponding input tensor in each step)
Args:
input_ids (Union[Tensor, List[Tensor], List[List[Tensor]]]): Input token IDs matching scenario structures
mask (Union[Tensor, List[Tensor], List[List[Tensor]]]): Attention masks matching input_ids structure
**kwargs: Contains 'offload_cpu' (bool, optional) - only used in scenario 3
Returns:
Tuple[Union[Tensor, List[Tensor], List[List[Tensor]]], Union[Tensor, List[Tensor], List[List[Tensor]]]]:
- outputs: Encoded outputs matching input structure
- masks: Corresponding attention masks
"""
encoders = self.text_encoders
is_multi_encoder = isinstance(encoders, nn.ModuleList)
num_encoders = len(encoders) if is_multi_encoder else 1
if not is_multi_encoder and isinstance(input_ids, torch.Tensor):
return self._single_encode(encoders, input_ids, mask, **kwargs)
if is_multi_encoder and isinstance(input_ids, list) and not isinstance(input_ids[0], list):
outputs, masks = [], []
for encoder, inp, msk in zip(encoders, input_ids, mask):
out, att = self._single_encode(encoder, inp, msk, **kwargs)
outputs.append(out)
masks.append(att)
return outputs, masks
num_steps = len(input_ids)
if not is_multi_encoder:
encoders = [encoders]
outputs = [[None for _ in range(num_encoders)] for _ in range(num_steps)]
result_masks = [[None for _ in range(num_encoders)] for _ in range(num_steps)]
for encoder_idx, encoder in enumerate(encoders):
if offload_cpu:
encoder.to(torch_npu.npu.current_device())
for step in range(num_steps):
step_input = input_ids[step][encoder_idx]
step_mask = mask[step][encoder_idx]
out, att = self._single_encode(encoder, step_input, step_mask, **kwargs)
outputs[step][encoder_idx] = out
result_masks[step][encoder_idx] = att
if offload_cpu:
encoder.to(torch.device('cpu'))
torch_npu.npu.empty_cache()
return outputs, result_masks
def _single_encode(self, text_encoder, input_ids, attention_mask, **kwargs):
*BN, L = input_ids.shape
input_ids = input_ids.to(text_encoder.device).view(-1, L)
attention_mask = attention_mask.to(text_encoder.device).view(-1, L)
model_attention_mask = attention_mask if text_encoder.use_attention_mask else None
model_kwargs = {}
if text_encoder.using_kwargs:
for k in text_encoder.using_kwargs:
if k in kwargs.keys():
model_kwargs[k] = kwargs[k]
else:
raise ValueError(f"{k} is not in kwargs")
output = text_encoder(
input_ids=input_ids,
attention_mask=model_attention_mask,
output_hidden_states=text_encoder.hidden_state_skip_layer is not None,
**model_kwargs
)
emb = output[text_encoder.output_key]
if text_encoder.hidden_state_skip_layer:
emb = emb[-(text_encoder.hidden_state_skip_layer + 1)]
if text_encoder.ucg_rate is not None and text_encoder.ucg_rate > 0.0:
def expand_dims_like(x, y):
while x.dim() != y.dim():
x = x.unsqueeze(-1)
return x
emb = (
expand_dims_like(
torch.bernoulli(
(1.0 - text_encoder.ucg_rate) * torch.ones(emb.shape[0], device=emb.device, dtype=emb.dtype)),
emb,
)
* emb
)
if text_encoder.output_key in ["last_hidden_state", "hidden_states"]:
emb = emb.view(*BN, emb.shape[-2], -1)
elif text_encoder.output_key in ["pooler_output", "text_embeds"]:
emb = emb.view(*BN, -1)
else:
raise NotImplementedError(f"Text encoder output_key: {text_encoder.output_key} is not implemented! ")
if text_encoder.use_attention_mask:
attention_mask = model_attention_mask
attention_mask = attention_mask.view(*BN, -1)
return emb, attention_mask
def _init_text_encoder(self, config):
if not isinstance(config, dict):
config = config.to_dict()
backend = config.pop("hub_backend")
use_attention_mask = config.pop("use_attention_mask", True)
ucg_rate = config.pop("ucg_rate", None)
output_key = config.pop("output_key", "last_hidden_state")
hidden_state_skip_layer = config.pop("hidden_state_skip_layer", None)
using_kwargs = config.pop("using_kwargs", None)
config["pretrained_model_name_or_path"] = config.pop("from_pretrained")
config["torch_dtype"] = get_dtype(config.pop("dtype"))
config["local_files_only"] = True
try:
from megatron.training import get_args
config["trust_remote_code"] = get_args().trust_remote_code
except (ImportError, AssertionError, AttributeError):
config["trust_remote_code"] = False
model_id = config.pop("model_id")
config.pop("load_in_8bit", None)
if model_id in TRANSFORMERS_TEXT_ENCODER_MAPPING:
module = importlib.import_module("transformers")
self.automodel_name = TRANSFORMERS_TEXT_ENCODER_MAPPING[model_id]
automodel = getattr(module, self.automodel_name)
text_encoder = automodel.from_pretrained(**config)
elif model_id in CUSTOM_TEXT_ENCODER_MAPPING:
automodel = CUSTOM_TEXT_ENCODER_MAPPING[model_id]
text_encoder = automodel.from_pretrained(**config)
else:
raise ValueError(f"Model ID {model_id} is not supported for text encoder")
setattr(text_encoder, "ucg_rate", ucg_rate)
setattr(text_encoder, "use_attention_mask", use_attention_mask)
setattr(text_encoder, "output_key", output_key)
setattr(text_encoder, "hidden_state_skip_layer", hidden_state_skip_layer)
setattr(text_encoder, "using_kwargs", using_kwargs)
if hidden_state_skip_layer and output_key not in ["hidden_states"]:
raise ValueError("If use hidden_state_skip, the output_keys must in [`hidden_states`]")
return text_encoder