Cache-Based Acceleration Features
DiTCache
- Background
  During inference, DiT models iterate through T steps, and every step computes the full set of blocks. Each block is compute-intensive, as shown below. However, latents from adjacent steps are highly similar, so nearly identical intermediate results are recomputed repeatedly, which adds redundant compute and slows down inference.
- Principle
  DiTCache reuses local model features based on the activation similarity between adjacent sampling steps or adjacent blocks. This lets it skip selected DiT blocks, reduce redundant computation, and accelerate inference.
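The reuse described above can be sketched with a toy block loop. This is a minimal illustration, not the mindiesd implementation: ToyBlockRangeCache, its run method, and the choice to cache the range's output-minus-input delta are assumptions made for this example, and plain floats stand in for tensors.

```python
from typing import Callable, List, Optional

Block = Callable[[float], float]


class ToyBlockRangeCache:
    """Hypothetical stand-in for a block-range cache (not the mindiesd API)."""

    def __init__(self, block_start: int, block_end: int) -> None:
        self.block_start = block_start
        self.block_end = block_end  # inclusive
        self.delta: Optional[float] = None  # cached (range output - range input)
        self.block_calls = 0  # number of real block forward passes

    def run(self, blocks: List[Block], x: float, reuse: bool) -> float:
        range_input = x
        i = 0
        while i < len(blocks):
            if i == self.block_start and reuse and self.delta is not None:
                # Cache hit: replace the whole block range with one addition.
                x = x + self.delta
                i = self.block_end + 1
                continue
            if i == self.block_start:
                range_input = x
            x = blocks[i](x)
            self.block_calls += 1
            if i == self.block_end:
                # Refresh the cache whenever the range is really computed.
                self.delta = x - range_input
            i += 1
        return x


# Four toy "blocks" that each add 1.0; blocks 1..2 are cache-enabled.
toy_blocks: List[Block] = [lambda v: v + 1.0] * 4
cache = ToyBlockRangeCache(block_start=1, block_end=2)
full_step = cache.run(toy_blocks, 0.0, reuse=False)   # computes all 4 blocks
cached_step = cache.run(toy_blocks, 0.5, reuse=True)  # computes blocks 0 and 3 only
```

In this toy run the second step executes two blocks instead of four. The result is exact here only because every toy block is a pure shift; with real activations the cached delta is an approximation.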
- Optimization method
  A search script first determines the minimum number of blocks that must be skipped to reach the target acceleration ratio. It then scans candidate start and end blocks and picks the configuration with the lowest MSE among all valid combinations. On a cache hit, the cached result of the selected block range from step N is reused directly at step M, so a full DiTBlock forward pass becomes a lightweight tensor read.
  - Compute the minimum number of blocks that need to be cached for the target speedup ratio.
  - Because block0 must always be computed, start scanning from block1 and search for the three best (block_start, block_end) candidates.
  - Evaluate all candidate combinations by measuring the MSE before and after caching, then choose the one with the smallest MSE loss.
  - Apply the resulting configuration in the model and enable cache during inference.
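The search steps above can be sketched as follows. This is a simplified illustration under stated assumptions, not the actual search script: min_blocks_to_skip and best_block_range are hypothetical names, every block is assumed to cost the same, non-block overhead is ignored, and the sketch evaluates every window directly instead of shortlisting three candidates first. The mse_of_range callback is a placeholder for a real before/after-caching MSE measurement.

```python
import math
from typing import Callable, Optional, Tuple


def min_blocks_to_skip(steps_count: int, blocks_count: int,
                       reused_steps: int, target_speedup: float) -> int:
    """Blocks to skip in each reused step, assuming equal cost per block."""
    total = steps_count * blocks_count
    needed_savings = total * (1.0 - 1.0 / target_speedup)
    return math.ceil(needed_savings / reused_steps)


def best_block_range(blocks_count: int, skip_count: int,
                     mse_of_range: Callable[[int, int], float]) -> Tuple[int, int]:
    """Scan windows of skip_count blocks, starting at block 1, and keep the lowest MSE."""
    if skip_count < 1 or skip_count > blocks_count - 1:
        raise ValueError("skip_count must fit between block 1 and the last block")
    best: Optional[Tuple[int, int]] = None
    best_mse = math.inf
    # Block 0 is always computed, so the first candidate window starts at 1.
    for block_start in range(1, blocks_count - skip_count + 1):
        block_end = block_start + skip_count - 1  # inclusive
        mse = mse_of_range(block_start, block_end)
        if mse < best_mse:
            best, best_mse = (block_start, block_end), mse
    assert best is not None
    return best


def toy_mse(block_start: int, block_end: int) -> float:
    """Made-up error curve for demonstration only; lowest when the window starts at 3."""
    return abs(block_start - 3) + 0.1


skip = min_blocks_to_skip(steps_count=50, blocks_count=38,
                          reused_steps=20, target_speedup=1.2)
chosen = best_block_range(blocks_count=6, skip_count=2, mse_of_range=toy_mse)
```

The returned pair maps onto the block_start and block_end fields of the configuration. A real search would replace toy_mse with a measurement on actual model outputs.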
- Workflow
  - Import CacheConfig and CacheAgent.

        from mindiesd import CacheConfig, CacheAgent

  - Initialize CacheConfig when the model is created.

        config = CacheConfig(
            method="dit_block_cache",
            blocks_count=len(transformer.single_blocks),  # number of cache-enabled blocks
            steps_count=args.infer_steps,                 # total inference steps
            step_start=args.cache_start_steps,            # first step index that uses cache
            step_interval=args.cache_interval,            # forced recompute interval
            step_end=args.infer_steps - 1,                # last step index that uses cache
            block_start=args.single_block_start,          # first cache-enabled block in each step
            block_end=args.single_block_end,              # last cache-enabled block in each step
        )

  - Initialize the cache variable in the transformer __init__ method.

        self.cache = None

  - Create CacheAgent from the config and assign it to the transformer.

        cache_agent = CacheAgent(config)
        # enable DiTCache
        pipeline.transformer.cache = cache_agent

  - Call apply in the transformer forward path. The first argument is the block itself, and the remaining arguments match the original implementation.

        for index_block, block in enumerate(self.transformer_blocks):
            # enable DiTCache
            hidden_states, encoder_hidden_states = self.cache.apply(
                block,
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                encoder_hidden_states_mask=encoder_hidden_states_mask,
                temb=temb,
                image_rotary_emb=image_rotary_emb,
                joint_attention_kwargs=attention_kwargs,
                txt_pad_len=txt_pad_len,
            )
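To make the step-related fields easier to picture, the sketch below lists which steps would reuse the cache under one plausible reading of the configuration. The exact semantics are an assumption, not confirmed by the source: reuse_steps is a hypothetical helper, and it assumes that inside the window from step_start to step_end a full recompute is forced every step_interval steps, counting from step_start.

```python
from typing import List


def reuse_steps(steps_count: int, step_start: int,
                step_end: int, step_interval: int) -> List[int]:
    """Steps that read from the cache, under the assumed interval semantics."""
    if step_interval < 1:
        raise ValueError("step_interval must be at least 1")
    reused = []
    for step in range(steps_count):
        in_window = step_start <= step <= step_end
        # Assumption: every step_interval-th step in the window is recomputed.
        forced_recompute = (step - step_start) % step_interval == 0
        if in_window and not forced_recompute:
            reused.append(step)
    return reused


schedule = reuse_steps(steps_count=10, step_start=2, step_end=9, step_interval=3)
# schedule == [3, 4, 6, 7, 9]; steps 0, 1, 2, 5 and 8 run the full computation
```

Under this reading, a larger step_interval means more reused steps and a higher speedup, at the cost of staler cached features.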
- Example
  See the cache example for a concrete implementation.
AttentionCache
- Background
  Inference iterates over T steps, and each step contains multiple blocks with expensive compute, including STA, as shown below. The Attention layers in a block are often highly similar across adjacent steps, so nearly identical intermediate results are recomputed repeatedly and inference becomes slower.
- Principle
  Unlike DiTCache, AttentionCache reuses the computed Attention result inside a block. It skips selected Attention layers based on their similarity across adjacent timesteps, which reduces redundant computation and improves inference speed.
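A toy wrapper shows the idea for a single Attention layer. It mirrors the call shape self.cache.apply(self.attn, ...) used in the workflow, but it is only a sketch: ToyAttentionCache, its reuse_steps argument, and its per-call step counter are assumptions for this example and say nothing about how CacheAgent is implemented.

```python
from typing import Any, Callable, Iterable


class ToyAttentionCache:
    """Hypothetical cache for one attention layer (not the mindiesd API)."""

    def __init__(self, reuse_steps: Iterable[int]) -> None:
        self.reuse_steps = set(reuse_steps)  # steps that read from the cache
        self.step = 0                        # advanced once per apply() call
        self.cached: Any = None
        self.attn_calls = 0                  # real attention executions

    def apply(self, attn_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
        if self.step in self.reuse_steps and self.cached is not None:
            # Reuse the result stored at the last computed step.
            out = self.cached
        else:
            out = attn_fn(*args, **kwargs)
            self.cached = out
            self.attn_calls += 1
        self.step += 1
        return out


def toy_attention(x: int) -> int:
    """Placeholder for an expensive attention call."""
    return x * 2


attn_cache = ToyAttentionCache(reuse_steps=[1, 2])
outputs = [attn_cache.apply(toy_attention, x) for x in [1, 2, 3, 4]]
# outputs == [2, 2, 2, 8]: steps 1 and 2 return the result computed at step 0
```

Only the Attention call is wrapped, so the rest of the block still runs at every step. That is the difference from DiTCache, which skips whole blocks.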
- Optimization method
  A search script first computes the minimum number of Attention executions that must be skipped to reach the target acceleration ratio. It then scans candidate start and end steps and chooses the configuration with the lowest MSE among all valid combinations. The key idea is to trade space for time by directly reusing the cached Attention results from step N at step M.
  - Compute the minimum number of Attention executions that must be skipped from the requested speedup ratio.
  - Based on the starting step and min_skip_attention, derive min_interval and step_end, traverse all valid candidates, and pick the one with minimum MSE loss.
  - Apply the resulting configuration and enable AttentionCache during inference.
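The first step above can be written as a short inequality. This is a sketch under assumptions that are not in the original: a is the fraction of total inference time T spent in Attention, A is the total number of Attention executions, k is the number of skipped executions, every execution costs the same, a cache read is treated as free, and r* is the requested speedup ratio.

```latex
\begin{aligned}
T' &= T\,(1 - a\,q), \qquad q = \frac{k}{A} \\
r  &= \frac{T}{T'} = \frac{1}{1 - a\,q} \;\ge\; r^{*}
      \;\Longrightarrow\; q \;\ge\; \frac{1 - 1/r^{*}}{a} \\
k_{\min} &= \left\lceil A \cdot \frac{1 - 1/r^{*}}{a} \right\rceil
\end{aligned}
```

For illustration only, with a = 0.5, r* = 1.25 and A = 3000, the bound gives q >= 0.2 / 0.5 = 0.4, so at least 1200 Attention executions must be skipped. The target is reachable only when 1 - 1/r* <= a; otherwise even skipping every Attention execution is not enough.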
- Workflow
  - Import CacheConfig and CacheAgent.

        from mindiesd import CacheConfig, CacheAgent

  - Initialize CacheConfig. For attention_cache, the default values of block_start and block_end are usually sufficient.

        config = CacheConfig(
            method="attention_cache",
            blocks_count=len(transformer.transformer_blocks),  # number of blocks in the transformer
            steps_count=args.infer_steps,                      # total inference steps
            step_start=args.start_step,                        # first step index that uses cache
            step_interval=args.attentioncache_interval,        # forced recompute interval
            step_end=args.end_step,                            # last step index that uses cache
        )

  - Initialize the cache variable in each transformer block.

        self.cache = None

  - Create CacheAgent and attach it to each block.

        cache_agent = CacheAgent(config)
        # cache only the attention part inside each block
        for block in transformer.transformer_blocks:
            block.cache = cache_agent

  - Use apply inside the block forward method. The first argument is the original attention function, and the remaining arguments match the original implementation.

        # enable attention cache
        attn_output = self.cache.apply(
            self.attn,
            hidden_states=img_modulated,
            encoder_hidden_states=txt_modulated,
            encoder_hidden_states_mask=encoder_hidden_states_mask,
            image_rotary_emb=image_rotary_emb,
            **joint_attention_kwargs,
        )
- FAQ
  Q: Why does Qwen-Image-Edit-2509 report RuntimeError: NPU out of memory after AttentionCache is enabled?
  A: AttentionCache increases device memory usage because cached Attention results stay resident between steps. On a single card, memory can become insufficient. Eight-card inference is recommended for this workload.
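A back-of-the-envelope estimate shows why the cache can be large. The sketch assumes one cached Attention output per block with shape (batch, seq_len, hidden) in a 2-byte dtype; attention_cache_bytes is a hypothetical helper, and the shapes below are illustrative numbers, not measurements of Qwen-Image-Edit-2509.

```python
def attention_cache_bytes(blocks_count: int, batch: int, seq_len: int,
                          hidden: int, bytes_per_elem: int = 2) -> int:
    """Bytes needed to keep one attention output per block resident."""
    return blocks_count * batch * seq_len * hidden * bytes_per_elem


# Illustrative shapes only (assumed): 60 blocks, 8192 tokens, hidden size 3072.
cache_bytes = attention_cache_bytes(blocks_count=60, batch=1,
                                    seq_len=8192, hidden=3072)
cache_gib = cache_bytes / 2**30  # 2.8125 GiB on top of weights and activations
```

The footprint grows linearly with sequence length and batch size, so high-resolution or multi-image inputs raise it quickly.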
Timestep optimization
- Principle
  Timestep optimization reduces, adjusts, or skips selected denoising steps in diffusion models. This lowers the number of executed DiT modules and avoids redundant compute while trying to preserve output quality.
- Optimization method
  - Modify the timestep count directly, for example from 50 steps down to 20 steps, to improve inference speed.
  - Use Adastep sampling, an adaptive and dynamic timestep-skipping algorithm. Its core idea is to evaluate the current latent state during inference and skip groups of steps whose changes are small enough, which allows faster convergence. This method is currently used only in CogVideoX; on other models it has been replaced by DiTCache and AttentionCache.
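Both ideas can be sketched in a few lines. This is a generic illustration and not the Adastep algorithm: subsample_timesteps and should_skip_next are hypothetical helpers, the evenly spaced selection is one of several possible ways to shorten a schedule, and the relative-change test with a fixed threshold is an assumed stand-in for the latent-state evaluation mentioned above.

```python
from typing import List, Sequence


def subsample_timesteps(timesteps: Sequence[int], target_count: int) -> List[int]:
    """Pick target_count evenly spaced entries, keeping the first and the last."""
    if target_count < 1:
        raise ValueError("target_count must be at least 1")
    if target_count >= len(timesteps):
        return list(timesteps)
    if target_count == 1:
        return [timesteps[0]]
    last = len(timesteps) - 1
    span = target_count - 1
    # Integer round-half-up of i * last / span, to avoid float rounding surprises.
    indices = [(2 * i * last + span) // (2 * span) for i in range(target_count)]
    return [timesteps[i] for i in indices]


def should_skip_next(prev_latent: Sequence[float], latent: Sequence[float],
                     threshold: float) -> bool:
    """True when the latent changed so little that the next step may be skipped."""
    change = sum(abs(a - b) for a, b in zip(latent, prev_latent))
    scale = sum(abs(b) for b in prev_latent) + 1e-12  # avoid division by zero
    return change / scale < threshold


full_schedule = list(range(49, -1, -1))            # 50 descending timesteps
short_schedule = subsample_timesteps(full_schedule, 20)
skip_small = should_skip_next([1.0, 1.0], [1.0, 1.001], threshold=0.01)
skip_large = should_skip_next([1.0, 1.0], [2.0, 2.0], threshold=0.01)
```

In practice the shorter schedule usually comes from the sampler's own step-count setting rather than from slicing a list; the helper only makes the 50-to-20 reduction concrete.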