import importlib
import json
from typing import Union
import torch
def get_diffusers_transformer_module(config_or_json_path: Union[str, dict]):
    """Resolve the diffusers model class named by a transformer config.

    Args:
        config_or_json_path: Either an already-parsed config mapping, or the
            path of a JSON config file to read (decoded as UTF-8).

    Returns:
        The attribute of ``diffusers.models`` whose name equals the config's
        ``_class_name`` entry.

    Raises:
        TypeError: If the argument is neither a ``str`` nor a ``dict``.
        ModuleNotFoundError: If ``diffusers.models`` cannot be imported.
        ValueError: If ``_class_name`` is absent or is not a string.
        KeyError: If ``dir(diffusers.models)`` does not list that name.
    """
    if isinstance(config_or_json_path, dict):
        config = config_or_json_path
    elif isinstance(config_or_json_path, str):
        with open(config_or_json_path, encoding="utf-8") as config_file:
            config = json.load(config_file)
    else:
        raise TypeError(
            f"config_or_json_path must be a dict or a path to config.json, but got {type(config_or_json_path)}."
        )

    # The import is attempted before the config is inspected, so a missing
    # diffusers installation is reported ahead of any config problem.
    try:
        diffusers_models = importlib.import_module("diffusers.models")
    except ModuleNotFoundError as import_error:
        raise ModuleNotFoundError("Import models from diffusers failed.") from import_error

    class_name = config.get("_class_name")
    # None is not a str, so one isinstance check also covers the missing key.
    if not isinstance(class_name, str):
        raise ValueError(
            "Unable to find _class_name attribute or _class_name not a str from the diffusers transformer config.json."
        )

    if class_name in dir(diffusers_models):
        return getattr(diffusers_models, class_name)
    raise KeyError(f"The class {class_name} is not supported by diffusers.")
# VAE stride pair per diffusers transformer class name. "Default" is the
# fallback entry used for any class name that is not listed explicitly.
# NOTE(review): the pair is presumably (temporal, spatial) stride — confirm.
_model_class_to_vae_stride = {
    "WanTransformer3DModel": (4, 8),
    "HunyuanVideoTransformer3DModel": (4, 8),
    "HunyuanVideo15Transformer3DModel": (4, 16),
    "Default": (4, 8),
}


def model_class_to_vae_stride(model_class: str) -> tuple:
    """Return the VAE stride pair registered for ``model_class``.

    Args:
        model_class: Name of the transformer class to look up.

    Returns:
        The tuple stored for ``model_class``, or the ``"Default"`` entry when
        the name is not present in the table.
    """
    fallback = _model_class_to_vae_stride["Default"]
    return _model_class_to_vae_stride.get(model_class, fallback)
def generate_hunyuanvideo_input(**kwargs):
    """Build the extra placeholder input for the HunyuanVideo transformer.

    Keyword Args:
        batch_size: Size of the first mask dimension; must be an ``int``.
        seq_lens: Size of the second mask dimension; must be an ``int``.
        dtype: Optional dtype forwarded to ``torch.zeros``.

    Returns:
        A dict with one key, ``"encoder_attention_mask"``, holding a
        ``[batch_size, seq_lens]`` tensor allocated on the ``meta`` device.

    Raises:
        AssertionError: If ``batch_size`` or ``seq_lens`` is missing or is
            not an ``int``.
    """
    batch_size, seq_lens = kwargs.get("batch_size"), kwargs.get("seq_lens")
    assert isinstance(batch_size, int)
    assert isinstance(seq_lens, int)

    # Meta device: only shape and dtype are materialised, no storage.
    mask = torch.zeros(
        (batch_size, seq_lens),
        dtype=kwargs.get("dtype"),
        device=torch.device("meta"),
    )
    return {"encoder_attention_mask": mask}
def generate_hunyuanvideo15_input(**kwargs):
    """Build the extra placeholder inputs for the HunyuanVideo 1.5 transformer.

    Keyword Args:
        batch_size: Size of the leading dimension; must be an ``int``.
        seq_lens: Size of the sequence dimension; must be an ``int``.
        dtype: Optional dtype used for every tensor created here.
        text_embed_2_dim: When given, adds ``"encoder_hidden_states_2"`` with
            shape ``[batch_size, seq_lens, text_embed_2_dim]``.
        image_embed_dim: When given, adds ``"image_embeds"`` as a
            ``SafeMetaTensor``.

    Returns:
        A dict of meta-device tensors. ``"encoder_attention_mask"`` and
        ``"encoder_attention_mask_2"`` are always present and refer to the
        same tensor object.

    Raises:
        AssertionError: If ``batch_size`` or ``seq_lens`` is missing or is
            not an ``int``.
    """
    batch_size = kwargs.get("batch_size")
    seq_lens = kwargs.get("seq_lens")
    assert isinstance(batch_size, int)
    assert isinstance(seq_lens, int)

    dtype = kwargs.get("dtype")
    meta = torch.device("meta")

    mask = torch.zeros((batch_size, seq_lens), dtype=dtype, device=meta)
    inputs = {
        "encoder_attention_mask": mask,
        "encoder_attention_mask_2": mask,
    }

    text_embed_2_dim = kwargs.get("text_embed_2_dim")
    if text_embed_2_dim is not None:
        inputs["encoder_hidden_states_2"] = torch.zeros(
            (batch_size, seq_lens, text_embed_2_dim),
            dtype=dtype,
            device=meta,
        )

    image_embed_dim = kwargs.get("image_embed_dim")
    if image_embed_dim is not None:
        # NOTE(review): all three dimensions use image_embed_dim (no batch or
        # sequence size) — confirm this cubic shape is intentional.
        inputs["image_embeds"] = SafeMetaTensor((image_embed_dim,) * 3, dtype=dtype)

    return inputs
# Builders of extra placeholder forward inputs, keyed by the diffusers
# transformer class name. Classes not listed here get an empty dict.
_model_class_input = {
    "HunyuanVideoTransformer3DModel": generate_hunyuanvideo_input,
    "HunyuanVideo15Transformer3DModel": generate_hunyuanvideo15_input,
}


def model_class_to_input(model_class):
    """Return the extra-input builder registered for ``model_class``.

    Args:
        model_class: Name of the transformer class to look up.

    Returns:
        The registered builder, or a callable that accepts any keyword
        arguments and returns an empty dict when the name is not registered.
    """
    try:
        return _model_class_input[model_class]
    except KeyError:
        # Unregistered classes need no extra inputs.
        return lambda **kwargs: {}
def get_ulysses_split_dim(hidden_states: torch.Tensor, ulysses_size: int) -> int:
    """Pick the dimension of ``hidden_states`` to split across Ulysses ranks.

    Args:
        hidden_states: Tensor whose second-to-last dimension is inspected.
        ulysses_size: Divisor that half of that dimension is tested against.

    Returns:
        ``-2`` when ``hidden_states.shape[-2] // 2`` is an exact multiple of
        ``ulysses_size``, otherwise ``-1``.

    Raises:
        ValueError: If ``hidden_states`` is ``None``.
    """
    if hidden_states is None:
        raise ValueError("hidden_states is None")

    # NOTE(review): the reason for halving before the divisibility test is
    # not visible in this file — confirm against the caller.
    half_length = hidden_states.shape[-2] // 2
    return -2 if half_length % ulysses_size == 0 else -1
class SafeMetaTensor(torch.Tensor):
    """
    A meta-device tensor subclass that enables safe boolean indexing (e.g., `x[mask]`)
    during model tracing or initialization.

    Plain meta tensors crash when used in boolean indexing because PyTorch cannot
    determine output shapes without real data. This class sidesteps the issue by
    returning a valid-shaped meta tensor directly in `__getitem__`.

    Use this class for all meta tensors involved in indexing—especially boolean masks
    and the tensors they index into. Avoid operations that convert it back to plain tensors
    (e.g., `.bool()`, `.float()`), as they break the subclass protection.
    """

    @staticmethod
    def __new__(cls, shape, dtype=None, device=None, requires_grad=False):
        """Allocate an uninitialised meta tensor of ``shape`` as this subclass.

        Args:
            shape: Shape passed to ``torch.empty``.
            dtype: Optional dtype passed to ``torch.empty``.
            device: ``None`` or any spec of the meta device, either the
                string ``"meta"`` or a ``torch.device``.
            requires_grad: Passed to ``torch.empty``.

        Raises:
            ValueError: If ``device`` is given and is not the meta device.
        """
        if device is not None:
            # Normalise through torch.device first: a str never compares equal
            # to a torch.device, so comparing the raw argument would reject
            # the perfectly valid spelling device="meta".
            try:
                is_meta = torch.device(device).type == "meta"
            except (RuntimeError, TypeError):
                # Unparsable device specs are simply "not meta".
                is_meta = False
            if not is_meta:
                raise ValueError("SafeMetaTensor only supports 'meta' device.")
        return torch.empty(shape, dtype=dtype, device="meta", requires_grad=requires_grad).as_subclass(cls)

    def __bool__(self):
        """Always report truthy; a meta tensor has no data to evaluate."""
        return True

    def __getitem__(self, idx):
        """Index the tensor, special-casing boolean meta masks.

        For a boolean mask on the meta device, the result is a fresh
        ``SafeMetaTensor`` with the same shape and dtype as ``self``. Any
        other index is delegated to ``torch.Tensor.__getitem__``.
        """
        if isinstance(idx, torch.Tensor) and idx.dtype == torch.bool and idx.device.type == "meta":
            # The number of selected rows is unknowable without data, so the
            # full shape of self is used as the placeholder result shape.
            # NOTE(review): this ignores the mask's own shape — presumably
            # masks are 1-D over the first dimension; confirm with callers.
            return SafeMetaTensor((self.shape[0],) + self.shape[1:], dtype=self.dtype)
        return super().__getitem__(idx)