import types
from typing import List, Tuple
import torch
from .cache import CacheState
class DiTBlockCache(torch.nn.Module):
"""Cache-aware block wrapper used only on the configured cache range."""
def __init__(
self,
block,
state: CacheState,
block_index,
block_start: int,
block_end: int,
make_wrapped_forward,
):
super().__init__()
self._inner = block
self._state = state
self._block_index = block_index
self._block_start = block_start
self._block_end = block_end
self.forward = types.MethodType(
make_wrapped_forward(self)(block.forward),
self,
)
def __getattr__(self, item):
try:
return super().__getattr__(item)
except AttributeError:
if hasattr(self._inner, item):
return getattr(self._inner, item)
raise
def apply(self, func: callable, *args, **kwargs):
hidden_states = kwargs.pop("hidden_states", None)
if hidden_states is None:
raise ValueError("[DiTBlockCache] Input 'hidden_states' is None.")
encoder_hidden_states = kwargs.pop("encoder_hidden_states", None)
if self._state.reuse:
return self._reuse(hidden_states, encoder_hidden_states)
if encoder_hidden_states is None:
res = func(hidden_states, *args, **kwargs)
else:
res = func(hidden_states, encoder_hidden_states, *args, **kwargs)
self._update_cache(res, hidden_states, encoder_hidden_states)
return res
def _reuse(self, hidden_states, encoder_hidden_states):
state = self._state
if state.delta_hidden is None:
raise RuntimeError("[DiTBlockCache] Cache delta is empty before reuse.")
is_range_start = self._block_index == self._block_start
if state.delta_encoder is not None:
if encoder_hidden_states is None:
raise ValueError("[DiTBlockCache] 'encoder_hidden_states' is required for two-output cache reuse.")
if is_range_start:
return (
hidden_states + state.delta_hidden,
encoder_hidden_states + state.delta_encoder,
)
return hidden_states, encoder_hidden_states
return hidden_states + state.delta_hidden if is_range_start else hidden_states
def _update_cache(self, res, ori_hidden_states, ori_encoder_hidden_states):
state = self._state
output_count = len(res) if isinstance(res, (List, Tuple)) else 1
if output_count not in (1, 2):
raise RuntimeError(f"[DiTBlockCache] The output count must be 1 or 2, but got {output_count}.")
is_range_start = self._block_index == self._block_start
is_range_end = self._block_index == (self._block_end - 1)
if is_range_start:
state.range_hidden = ori_hidden_states
state.range_encoder = ori_encoder_hidden_states
if not is_range_end:
return
if state.range_hidden is None:
raise RuntimeError("[DiTBlockCache] Missing cache range input for hidden_states.")
if output_count == 2:
hidden_states, encoder_hidden_states = res
if hidden_states is None or encoder_hidden_states is None:
raise RuntimeError("[DiTBlockCache] Cache function output is None.")
if state.range_encoder is None:
raise ValueError("[DiTBlockCache] 'encoder_hidden_states' is required when output count is 2.")
state.delta_hidden = hidden_states - state.range_hidden
state.delta_encoder = encoder_hidden_states - state.range_encoder
return
if res is None:
raise RuntimeError("[DiTBlockCache] Cache function output is None.")
state.delta_hidden = res - state.range_hidden
state.delta_encoder = None