import logging
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Sequence, Tuple
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Takes a block's original forward callable and returns the callable that replaces it.
WrappedForwardFactory = Callable[[Callable[..., Any]], Callable[..., Any]]
# Takes one argument (the spec functions in this module name it ``agent``) and
# returns a WrappedForwardFactory.
MakeWrappedForward = Callable[[Any], WrappedForwardFactory]
# Takes a replacement block and stores it somewhere; returns nothing.
BlockSetter = Callable[[Any], None]
# Takes a model object and returns (block, setter) pairs for its blocks.
GetBlocksWithSetters = Callable[[Any], Sequence[Tuple[Any, BlockSetter]]]
@dataclass(frozen=True)
class DiTBlockCacheSpec:
    """Immutable bundle describing how block caching hooks into one model class.

    Attributes:
        model_type: Short label for the model family (the registrations in
            this module use values such as ``"Wan"``).
        get_blocks_with_setters: Callable that receives the model object and
            returns a sequence of ``(block, setter)`` pairs, where calling
            ``setter(new_block)`` stores a replacement for that block.
        make_wrapped_forward: Callable that receives an agent and returns a
            factory turning a block's original forward callable into its
            replacement.
    """

    model_type: str
    get_blocks_with_setters: GetBlocksWithSetters
    make_wrapped_forward: MakeWrappedForward
# Registry of cache specs keyed by class name; written by
# register_dit_block_cache_spec() and read by get_dit_block_cache_spec().
_DIT_BLOCK_CACHE_SPECS: Dict[str, DiTBlockCacheSpec] = {}
def register_dit_block_cache_spec(class_name: str, spec: DiTBlockCacheSpec) -> None:
    """Store ``spec`` in the module registry under ``class_name``.

    An existing entry for the same ``class_name`` is overwritten.

    Args:
        class_name: Registry key; must be truthy (non-empty).
        spec: The cache spec to associate with ``class_name``.

    Raises:
        ValueError: If ``class_name`` is empty or otherwise falsy.
    """
    if class_name:
        _DIT_BLOCK_CACHE_SPECS[class_name] = spec
        return
    raise ValueError("'class_name' must be a non-empty string.")
def get_dit_block_cache_spec(class_name: Optional[str]) -> Optional[DiTBlockCacheSpec]:
    """Look up the spec registered under ``class_name``.

    Args:
        class_name: Registry key, or ``None`` / empty.

    Returns:
        The registered spec, or ``None`` when ``class_name`` is falsy or has
        no registry entry.
    """
    # A falsy key short-circuits before the registry is consulted.
    return _DIT_BLOCK_CACHE_SPECS.get(class_name) if class_name else None
def _make_block_setter(container, index: int) -> BlockSetter:
    """Return a one-argument callable that writes into ``container[index]``.

    ``container`` and ``index`` are captured when this function is called, so
    setters built in a loop each keep their own target slot.

    Args:
        container: Any object supporting item assignment.
        index: Position in ``container`` the setter will overwrite.

    Returns:
        A callable ``setter(new_block)`` that performs
        ``container[index] = new_block`` and returns ``None``.
    """

    def _assign(new_block: Any) -> None:
        container[index] = new_block

    return _assign
def _module_list_blocks_with_setters(blocks) -> list[tuple[Any, BlockSetter]]:
    """Pair every element of ``blocks`` with a setter for its own position.

    Args:
        blocks: An iterable, index-assignable container of blocks.

    Returns:
        A new list of ``(block, setter)`` tuples in iteration order; each
        setter overwrites the matching index of ``blocks``.
    """
    pairs = []
    for position, module in enumerate(blocks):
        pairs.append((module, _make_block_setter(blocks, position)))
    return pairs
def replace_blocks_in_range(
    blocks_with_setters: Sequence[Tuple[Any, BlockSetter]],
    start: int,
    end: int,
    make_cache_block: Callable[[Any, int], Any],
) -> int:
    """Replace the blocks at flat indices ``[start, end)`` with cache blocks.

    Both ends of the range are clamped to the valid index range of
    ``blocks_with_setters``. Blocks that are already ``DiTBlockCache``
    instances are left untouched and are not counted.

    Args:
        blocks_with_setters: ``(block, setter)`` pairs; ``setter(new_block)``
            installs a replacement for ``block``.
        start: First flat index to consider (inclusive). Values below 0 are
            treated as 0.
        end: Flat index to stop at (exclusive). Values beyond the sequence
            length are treated as the length.
        make_cache_block: Called as ``make_cache_block(block, flat_idx)`` to
            build the replacement for each block that gets swapped.

    Returns:
        The number of blocks that were actually replaced.
    """
    # Function-scope import kept from the original; presumably this avoids an
    # import cycle with the cache agent package -- TODO confirm.
    from .cache_agent.dit_block_cache import DiTBlockCache

    # Clamp the lower bound as well as the upper one: a negative ``start``
    # would otherwise yield negative indices that wrap around to the tail of
    # the sequence and pass negative ``flat_idx`` values to the factory.
    bounded_start = max(start, 0)
    bounded_end = min(end, len(blocks_with_setters))

    replaced = 0
    for flat_idx in range(bounded_start, bounded_end):
        block, setter = blocks_with_setters[flat_idx]
        if isinstance(block, DiTBlockCache):
            # Already wrapped; skip so it is not wrapped a second time.
            continue
        setter(make_cache_block(block, flat_idx))
        replaced += 1
    return replaced
def _get_wan_blocks_with_setters(inner: Any) -> Sequence[Tuple[Any, BlockSetter]]:
    """Collect ``(block, setter)`` pairs from ``inner.blocks``.

    Args:
        inner: Model object expected to expose a ``blocks`` container.

    Returns:
        One pair per entry of ``inner.blocks``. An empty list is returned
        (after logging a warning) when the attribute is missing; an empty
        ``blocks`` container also logs a warning.
    """
    if hasattr(inner, "blocks"):
        found = _module_list_blocks_with_setters(inner.blocks)
        if len(found) == 0:
            logger.warning("WanTransformer3DModel.blocks is empty.")
        return found

    logger.warning("WanTransformer3DModel has no attribute 'blocks'.")
    return []
def _get_hunyuanvideo_blocks_with_setters(
    inner: Any,
) -> Sequence[Tuple[Any, BlockSetter]]:
    """Collect ``(block, setter)`` pairs from both HunyuanVideo block lists.

    Pairs from ``inner.transformer_blocks`` come first, followed by pairs
    from ``inner.single_transformer_blocks``, forming one flat sequence.

    Args:
        inner: Model object expected to expose both block containers.

    Returns:
        The combined list of pairs. An empty list is returned (after logging
        a warning) if either attribute is missing; a warning is also logged
        when both containers are empty.
    """
    # Checked in this order; the first missing attribute aborts the lookup.
    required = (
        (
            "transformer_blocks",
            "HunyuanVideoTransformer3DModel has no attribute 'transformer_blocks'.",
        ),
        (
            "single_transformer_blocks",
            "HunyuanVideoTransformer3DModel has no attribute 'single_transformer_blocks'.",
        ),
    )
    for attr_name, missing_message in required:
        if not hasattr(inner, attr_name):
            logger.warning(missing_message)
            return []

    combined = [
        *_module_list_blocks_with_setters(inner.transformer_blocks),
        *_module_list_blocks_with_setters(inner.single_transformer_blocks),
    ]
    if len(combined) == 0:
        logger.warning("HunyuanVideoTransformer3DModel transformer blocks are empty.")
    return combined
def _get_hunyuanvideo15_blocks_with_setters(
    inner: Any,
) -> Sequence[Tuple[Any, BlockSetter]]:
    """Collect ``(block, setter)`` pairs from ``inner.transformer_blocks``.

    Args:
        inner: Model object expected to expose ``transformer_blocks``.

    Returns:
        One pair per entry of ``inner.transformer_blocks``. An empty list is
        returned (after logging a warning) when the attribute is missing; an
        empty container also logs a warning.
    """
    if hasattr(inner, "transformer_blocks"):
        found = _module_list_blocks_with_setters(inner.transformer_blocks)
        if len(found) == 0:
            logger.warning("HunyuanVideo15Transformer3DModel.transformer_blocks is empty.")
        return found

    logger.warning("HunyuanVideo15Transformer3DModel has no attribute 'transformer_blocks'.")
    return []
def _wan_make_wrapped_forward(agent: Any) -> WrappedForwardFactory:
    """Build the forward-wrapping factory used for Wan blocks.

    Args:
        agent: Object whose ``apply`` method is invoked as
            ``agent.apply(original_forward, **inputs)`` on every wrapped call.

    Returns:
        A factory that accepts a block's original forward callable and
        returns a replacement routing the call through ``agent.apply``.
    """

    def _make_wrapped_forward(
        orig_forward_bound: Callable[..., Any],
    ) -> Callable[..., Any]:
        def _wrapped_forward(
            _self_block: Any,
            hidden_states: Any,
            encoder_hidden_states: Any,
            temb: Any,
            rotary_emb: Any,
        ) -> Any:
            # ``_self_block`` is accepted but unused; the original forward is
            # already captured as ``orig_forward_bound``.
            forward_inputs = {
                "hidden_states": hidden_states,
                "encoder_hidden_states": encoder_hidden_states,
                "temb": temb,
                "rotary_emb": rotary_emb,
            }
            return agent.apply(orig_forward_bound, **forward_inputs)

        return _wrapped_forward

    return _make_wrapped_forward
def _hunyuanvideo_make_wrapped_forward(agent: Any) -> WrappedForwardFactory:
    """Build the forward-wrapping factory used for HunyuanVideo blocks.

    Args:
        agent: Object whose ``apply`` method is invoked as
            ``agent.apply(original_forward, **inputs)`` on every wrapped call.

    Returns:
        A factory that accepts a block's original forward callable and
        returns a replacement routing the call through ``agent.apply``.
    """

    def _make_wrapped_forward(
        orig_forward_bound: Callable[..., Any],
    ) -> Callable[..., Any]:
        def _wrapped_forward(
            _self_block: Any,
            hidden_states: Any,
            encoder_hidden_states: Any,
            temb: Any,
            attention_mask: Any,
            image_rotary_emb: Any,
            token_replace_emb: Any,
            first_frame_num_tokens: Any,
        ) -> Any:
            # ``_self_block`` is accepted but unused; the original forward is
            # already captured as ``orig_forward_bound``.
            forward_inputs = {
                "hidden_states": hidden_states,
                "encoder_hidden_states": encoder_hidden_states,
                "temb": temb,
                "attention_mask": attention_mask,
                # The rotary embedding is handed on under the name ``freqs_cis``.
                "freqs_cis": image_rotary_emb,
                "token_replace_emb": token_replace_emb,
                "first_frame_num_tokens": first_frame_num_tokens,
            }
            return agent.apply(orig_forward_bound, **forward_inputs)

        return _wrapped_forward

    return _make_wrapped_forward
def _hunyuanvideo15_make_wrapped_forward(agent: Any) -> WrappedForwardFactory:
    """Build the forward-wrapping factory used for HunyuanVideo 1.5 blocks.

    Args:
        agent: Object whose ``apply`` method is invoked as
            ``agent.apply(original_forward, **inputs)`` on every wrapped call.

    Returns:
        A factory that accepts a block's original forward callable and
        returns a replacement routing the call through ``agent.apply``.
    """

    def _make_wrapped_forward(
        orig_forward_bound: Callable[..., Any],
    ) -> Callable[..., Any]:
        def _wrapped_forward(
            _self_block: Any,
            hidden_states: Any,
            encoder_hidden_states: Any,
            temb: Any,
            encoder_attention_mask: Any,
            image_rotary_emb: Any,
        ) -> Any:
            # ``_self_block`` is accepted but unused; the original forward is
            # already captured as ``orig_forward_bound``.
            forward_inputs = {
                "hidden_states": hidden_states,
                "encoder_hidden_states": encoder_hidden_states,
                "temb": temb,
                # Two inputs are renamed on the way through: the encoder mask
                # becomes ``attention_mask`` and the rotary embedding becomes
                # ``freqs_cis``.
                "attention_mask": encoder_attention_mask,
                "freqs_cis": image_rotary_emb,
            }
            return agent.apply(orig_forward_bound, **forward_inputs)

        return _wrapped_forward

    return _make_wrapped_forward
# Built-in registrations, executed at import time. Each key is the class name
# under which the spec is later looked up.
register_dit_block_cache_spec(
    class_name="WanTransformer3DModel",
    spec=DiTBlockCacheSpec(
        model_type="Wan",
        get_blocks_with_setters=_get_wan_blocks_with_setters,
        make_wrapped_forward=_wan_make_wrapped_forward,
    ),
)

register_dit_block_cache_spec(
    class_name="HunyuanVideoTransformer3DModel",
    spec=DiTBlockCacheSpec(
        model_type="HunyuanVideo",
        get_blocks_with_setters=_get_hunyuanvideo_blocks_with_setters,
        make_wrapped_forward=_hunyuanvideo_make_wrapped_forward,
    ),
)

register_dit_block_cache_spec(
    class_name="HunyuanVideo15Transformer3DModel",
    spec=DiTBlockCacheSpec(
        model_type="HunyuanVideo15",
        get_blocks_with_setters=_get_hunyuanvideo15_blocks_with_setters,
        make_wrapped_forward=_hunyuanvideo15_make_wrapped_forward,
    ),
)