Reducing Computation with Caching
Diffusion models iterate over multiple timesteps during inference. Every block is computed at every timestep, yet the latents of adjacent steps are highly similar, so much of this computation is redundant. MindIE SD provides the following caching acceleration methods to reduce it:
- DiTCache: Caches and reuses intermediate results at the block granularity, suitable for models with many blocks.
- AttentionCache: Caches and reuses attention computation results at the Attention layer granularity, suitable for models where attention computation accounts for a large proportion.
- Timestep Optimization: Reduces or skips some timesteps during the diffusion process, suitable for scenarios requiring fine-grained control over inference steps.
Each strategy can be used on its own, or the most suitable combination can be chosen based on the model's characteristics:
- DiTCache (try first): Block-granularity caching with the broadest applicability; recommended as the first option.
- AttentionCache (alternative): Suitable for models in which attention accounts for a large share of the computation; finer-grained than DiTCache.
- Timestep Optimization (supplement): Complements the other strategies and can further reduce the number of steps on top of DiTCache or AttentionCache.
Interface Description
CacheConfig
Cache configuration class that defines parameters such as the caching method, number of blocks, and number of steps.
from mindiesd import CacheConfig
| Parameter | Type | Required | Default | Description |
|---|---|---|---|---|
| method | str | Yes | - | Caching method: "attention_cache" or "dit_block_cache" |
| blocks_count | int | Yes | - | Number of blocks per step |
| steps_count | int | Yes | - | Total number of iteration steps |
| step_start | int | No | 0 | Step index at which caching starts |
| step_interval | int | No | 1 | Caching interval, in steps |
| step_end | int | No | 10000 | Step index at which caching ends |
| block_start | int | No | 0 | Block index at which caching starts |
| block_end | int | No | 10000 | Block index at which caching ends |
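The table does not spell out how these fields combine. The sketch below encodes one plausible reading as a pure function so that a schedule can be inspected before running a model. `ToyCacheConfig` and `reuses_cache` are hypothetical names, not part of `mindiesd`, and the rules in the comments (exclusive `step_end`/`block_end`, every `step_interval`-th step forcing a recomputation) are assumptions to verify against the actual implementation.

```python
from dataclasses import dataclass


@dataclass
class ToyCacheConfig:
    # Hypothetical mirror of the CacheConfig fields (not the mindiesd class).
    method: str
    blocks_count: int
    steps_count: int
    step_start: int = 0
    step_interval: int = 1
    step_end: int = 10000
    block_start: int = 0
    block_end: int = 10000


def reuses_cache(cfg: ToyCacheConfig, step: int, block: int) -> bool:
    """True if (step, block) would reuse a cached result under the assumed rules."""
    # Assumption: caching is active only for step_start <= step < step_end.
    if not cfg.step_start <= step < min(cfg.step_end, cfg.steps_count):
        return False
    # Assumption: every step_interval-th step from step_start is a forced
    # recomputation that refreshes the cache (so step_interval=1 reuses nothing).
    if (step - cfg.step_start) % cfg.step_interval == 0:
        return False
    # Assumption: only blocks with block_start <= block < block_end are skipped.
    return cfg.block_start <= block < min(cfg.block_end, cfg.blocks_count)


cfg = ToyCacheConfig("dit_block_cache", blocks_count=4, steps_count=6,
                     step_start=1, step_interval=2, step_end=5,
                     block_start=1, block_end=3)
for step in range(cfg.steps_count):
    row = "".join("R" if reuses_cache(cfg, step, b) else "C" for b in range(cfg.blocks_count))
    print(step, row)  # C = compute, R = reuse
```

Under these assumed rules, only steps 2 and 4 reuse blocks 1 and 2; every other (step, block) pair is computed.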
CacheAgent
Cache agent class that manages cache application based on configuration.
from mindiesd import CacheAgent
Constructor Parameters:
| Parameter | Type | Required | Default | Description |
|---|---|---|---|---|
| config | CacheConfig | Yes | - | Cache configuration object |
apply Method:
apply(function: callable, *args, **kwargs) -> Any
| Parameter | Type | Required | Description |
|---|---|---|---|
| function | callable | Yes | The function to execute (block or attention module) |
| *args | - | No | Positional arguments for the function |
| **kwargs | - | No | Keyword arguments for the function |
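Because `apply` receives only the function and its arguments, the agent has to work out the current step and block on its own. The toy class below shows one way this could be done, by counting calls. `ToyCacheAgent` is hypothetical. It assumes one `apply` call per block per step, in block order, and on a cache hit it simply returns the stored output; a production cache may store something else, such as a residual.

```python
from typing import Any, Callable, Dict


class ToyCacheAgent:
    """Hypothetical stand-in for CacheAgent; not the MindIE SD implementation."""

    def __init__(self, blocks_count: int, step_start: int, step_interval: int, step_end: int):
        self.blocks_count = blocks_count
        self.step_start = step_start
        self.step_interval = step_interval
        self.step_end = step_end
        self.calls = 0
        self.store: Dict[int, Any] = {}  # block index -> last computed output

    def _reuse(self, step: int) -> bool:
        # Assumed rule: inside the window, every step_interval-th step recomputes.
        in_window = self.step_start <= step < self.step_end
        return in_window and (step - self.step_start) % self.step_interval != 0

    def apply(self, function: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
        # The call counter alone identifies (step, block) under the ordering assumption.
        step, block = divmod(self.calls, self.blocks_count)
        self.calls += 1
        if self._reuse(step) and block in self.store:
            return self.store[block]  # cache hit: the function is not executed
        out = function(*args, **kwargs)
        self.store[block] = out  # refresh this block's cache entry
        return out


runs = []


def block(x):
    runs.append(x)  # record real executions
    return x + 1


agent = ToyCacheAgent(blocks_count=2, step_start=1, step_interval=2, step_end=4)
for step in range(4):
    h = 0
    for _ in range(2):
        h = agent.apply(block, h)
print(len(runs))  # 6: steps 0, 1 and 3 run both blocks; step 2 reuses them
```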
DiTCache
Background
During inference, the DiT model iterates over T steps. At each step t, all blocks are fully computed, and each block contains numerous computation operations (as shown in the figure below). However, the latents of adjacent steps are very similar, so nearly identical intermediate results are computed repeatedly, which wastes computation and slows inference.

Principle
DiTCache exploits the activation similarity between adjacent iteration steps or between adjacent blocks: it reuses local model features and skips specified DiTBlocks, which reduces redundant computation and accelerates inference.
Optimization Method
A search script first calculates, from the target speedup ratio, the minimum number of blocks to skip. Starting from candidate start and end blocks, it then traverses all possible combinations and selects the configuration with the minimum MSE loss as the optimal solution. The core optimization: on a cache hit, the results computed by a specific block range at step N are reused directly at step M, turning the forward pass of that DiTBlock sequence into a simple tensor read.

- Calculate the minimum number of blocks to cache based on the speedup ratio.
- In each step, block0 must always be computed, so traversal starts from block1; the three smallest groups of block_start and block_end are then calculated.
- Traverse all combinations obtained in the previous step, calculate the MSE before and after caching, and select the configuration with the minimum MSE loss as the optimal solution.
- Configure the parameters obtained from the above steps into the model, and enable caching during inference to complete the acceleration.
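The search described in these steps can be sketched as follows. This is a simplified illustration, not the MindIE SD search script. It assumes an idealized cost model in which every block costs the same, so skipping k of N blocks gives a speedup of N / (N - k). It also collapses the two-stage search (three candidate groups, then their combinations) into one exhaustive pass over contiguous ranges. `min_blocks_to_skip`, `search_block_range` and the user-supplied `mse_fn` are hypothetical names.

```python
from typing import Callable, Optional, Tuple


def min_blocks_to_skip(blocks_count: int, target_speedup: float) -> int:
    """Smallest k with blocks_count / (blocks_count - k) >= target_speedup."""
    # Idealized cost model (assumption): all blocks cost the same.
    for k in range(blocks_count):
        if blocks_count / (blocks_count - k) >= target_speedup:
            return k
    return blocks_count - 1  # target unreachable: skip everything except block0


def search_block_range(blocks_count: int, skip: int,
                       mse_fn: Callable[[int, int], float]
                       ) -> Optional[Tuple[int, int, float]]:
    """Try every contiguous [block_start, block_end) range of length `skip`."""
    best = None
    # block0 is always computed, so candidate ranges start at block1.
    for start in range(1, blocks_count - skip + 1):
        end = start + skip
        loss = mse_fn(start, end)  # user-supplied: MSE of cached vs. uncached output
        if best is None or loss < best[2]:
            best = (start, end, loss)
    return best


k = min_blocks_to_skip(60, 1.5)
best = search_block_range(60, k, lambda s, e: (s - 25) ** 2 / 100.0)  # toy MSE
print(k, best)  # 20 (25, 45, 0.0)
```

In practice `mse_fn` would run the model with and without caching on calibration prompts; the toy lambda only stands in for that expensive measurement.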
Optimization Procedure
Import the CacheConfig and CacheAgent interfaces.
from mindiesd import CacheConfig, CacheAgent
Initialize CacheConfig in the model's initialization method.
config = CacheConfig(
    method="dit_block_cache",
    blocks_count=len(transformer.single_blocks),  # Number of blocks with caching enabled
    steps_count=args.infer_steps,                 # Total number of inference iteration steps
    step_start=args.cache_start_steps,            # Step index to start caching
    step_interval=args.cache_interval,            # Interval steps for forced recomputation
    step_end=args.infer_steps-1,                  # Step index to stop caching
    block_start=args.single_block_start,          # Block index to start caching in each step
    block_end=args.single_block_end               # Block index to stop caching in each step
)
Initialize the cache variable in the Transformer's init method.
self.cache = None
Initialize CacheAgent and assign it to the Transformer.
cache_agent = CacheAgent(config)
# Enable ditcache
pipeline.transformer.cache = cache_agent
In the Transformer's forward method, call the apply method to enable caching during inference. The first argument of apply is the block; the remaining arguments are the same as in the original block call.
for index_block, block in enumerate(self.transformer_blocks):
    # Enable ditcache
    hidden_states, encoder_hidden_states = self.cache.apply(
        block,
        hidden_states=hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        encoder_hidden_states_mask=encoder_hidden_states_mask,
        temb=temb,
        image_rotary_emb=image_rotary_emb,
        joint_attention_kwargs=attention_kwargs,
        txt_pad_len=txt_pad_len
    )
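The steps above set `self.cache = None` in `__init__` and attach an agent only later, so a forward pass run before the agent is attached would fail on `self.cache.apply`. The following is a minimal sketch of one way to guard against this. `ToyTransformer` and `PassThroughAgent` are hypothetical, and the single-value blocks stand in for the real blocks that return `(hidden_states, encoder_hidden_states)`.

```python
from typing import Any, Callable, List, Optional


class PassThroughAgent:
    """Hypothetical agent that never caches: apply() just calls the function."""

    def apply(self, function: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
        return function(*args, **kwargs)


class ToyTransformer:
    def __init__(self, blocks: List[Callable[..., float]]):
        self.transformer_blocks = blocks
        self.cache: Optional[Any] = None  # attached later, e.g. transformer.cache = agent

    def forward(self, hidden_states: float) -> float:
        for block in self.transformer_blocks:
            if self.cache is None:
                # No agent attached: call the block directly, as the original code did.
                hidden_states = block(hidden_states=hidden_states)
            else:
                # Block first, then exactly the keyword arguments of the direct call.
                hidden_states = self.cache.apply(block, hidden_states=hidden_states)
        return hidden_states


model = ToyTransformer([lambda hidden_states: hidden_states + 1,
                        lambda hidden_states: hidden_states * 2])
print(model.forward(1.0))  # 4.0, computed without an agent
model.cache = PassThroughAgent()
print(model.forward(1.0))  # 4.0, now routed through apply()
```

Keeping the keyword arguments identical in both branches means that attaching or detaching the agent changes only where results come from, not what each block receives.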
Example
For detailed examples, please refer to the cache directory under examples.
AttentionCache
Background
During inference, the model iterates over T steps. Each step contains multiple blocks, and each block contains numerous computation operations, such as STA (as shown in the figure below). However, the attention layers of the same block produce very similar results at adjacent steps, so nearly identical intermediate results are computed repeatedly, which wastes computation and slows inference.

Principle
AttentionCache exploits the feature similarity between adjacent timesteps. Unlike DiTCache, which skips whole blocks, AttentionCache reuses attention results within blocks; skipping some attention layers in this way reduces redundant computation and accelerates inference.
Optimization Method
A search script first calculates, from the target speedup ratio, the minimum number of attention operations to skip. It then traverses the start and end steps and, among all possible combinations, selects the configuration with the minimum MSE loss as the optimal solution. The core optimization trades memory for time: the cached attention results of each block at step N are reused directly at step M, significantly reducing the computation in the attention layers.

- Based on the speedup ratio, calculate the minimum number of attention operations to skip.
- Calculate min_interval and step_end from the start step and min_skip_attention, traverse all resulting candidates, calculate the MSE before and after caching, and select the configuration with the minimum MSE loss as the optimal solution.
- Configure the parameters obtained from the above steps into the model, and enable caching during inference to complete the acceleration.
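The formulas linking min_interval, step_end, the start step and min_skip_attention are not given here. The sketch below is one hypothetical way to organize such a search. It counts the attention calls a schedule would skip, discards schedules below the required count, and keeps the surviving schedule with the lowest MSE. `skipped_attention`, `search_attention_schedule` and `mse_fn` are illustrative names; `mse_fn` stands in for running the model with and without caching, and the interval rule is an assumption.

```python
from typing import Callable, Optional, Tuple


def skipped_attention(steps_count: int, blocks_count: int,
                      step_start: int, step_interval: int, step_end: int) -> int:
    """Attention calls skipped if every block reuses attention on cached steps."""
    end = min(step_end, steps_count)
    # Assumption: inside [step_start, step_end), every step_interval-th step recomputes.
    cached = sum(1 for s in range(step_start, end) if (s - step_start) % step_interval != 0)
    return cached * blocks_count


def search_attention_schedule(steps_count: int, blocks_count: int, step_start: int,
                              min_skip: int, mse_fn: Callable[[int, int], float]
                              ) -> Optional[Tuple[int, int, float]]:
    best = None
    for interval in range(2, steps_count + 1):  # interval 1 would never reuse anything
        for end in range(step_start + 1, steps_count + 1):
            if skipped_attention(steps_count, blocks_count, step_start, interval, end) < min_skip:
                continue  # this schedule skips too little to reach the target speedup
            loss = mse_fn(interval, end)  # user-supplied: MSE with vs. without caching
            if best is None or loss < best[2]:
                best = (interval, end, loss)
    return best


best = search_attention_schedule(10, 2, 2, 8, lambda interval, end: float(interval))
print(best)  # (2, 10, 2.0)
```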
Optimization Procedure
Import the CacheConfig and CacheAgent interfaces.
from mindiesd import CacheConfig, CacheAgent
Initialize CacheConfig. For attention_cache, block_start and block_end can use default values.
config = CacheConfig(
    method="attention_cache",
    blocks_count=len(transformer.transformer_blocks),  # Number of blocks contained in the transformer model
    steps_count=args.infer_steps,                      # Total number of inference iteration steps
    step_start=args.start_step,                        # Step index to start caching
    step_interval=args.attentioncache_interval,        # Interval steps for forced recomputation
    step_end=args.end_step                             # Step index to stop caching
)
Initialize the cache variable in the init method of the Transformer's block module.
self.cache = None
Initialize CacheAgent and assign the same agent to every block.
cache_agent = CacheAgent(config)
# Cache the attention part within each block
for block in transformer.transformer_blocks:
    block.cache = cache_agent
In the forward method of the Transformer's block module, call the apply method to enable caching during inference. The first argument of apply is the original attention module (self.attn); the remaining arguments are the same as in the original call.
# Enable attention cache
attn_output = self.cache.apply(
    self.attn,
    hidden_states=img_modulated,
    encoder_hidden_states=txt_modulated,
    encoder_hidden_states_mask=encoder_hidden_states_mask,
    image_rotary_emb=image_rotary_emb,
    **joint_attention_kwargs,
)
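Because a single `CacheAgent` instance is assigned to every block, the agent sees the attention calls of all blocks interleaved and must keep one cached result per block. The toy below illustrates this with a hypothetical `SharedAttentionAgent` that, by assumption, reuses attention on every odd step. Only the attention output is cached; the residual path in `ToyBlock.forward` still runs on every step.

```python
from typing import Any, Callable, Dict


class SharedAttentionAgent:
    """Hypothetical agent shared by all blocks (assumes one call per block per step)."""

    def __init__(self, blocks_count: int):
        self.blocks_count = blocks_count
        self.calls = 0
        self.store: Dict[int, Any] = {}  # block index -> cached attention output

    def apply(self, function: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
        step, block = divmod(self.calls, self.blocks_count)
        self.calls += 1
        if step % 2 == 1 and block in self.store:
            return self.store[block]  # reuse this block's attention output
        self.store[block] = function(*args, **kwargs)
        return self.store[block]


class ToyBlock:
    def __init__(self, attn: Callable[..., float]):
        self.attn = attn
        self.cache = None  # set to the shared agent after construction

    def forward(self, hidden_states: float) -> float:
        attn_output = self.cache.apply(self.attn, hidden_states=hidden_states)
        return hidden_states + attn_output  # residual path always runs


attn_calls = []


def attn(hidden_states):
    attn_calls.append(hidden_states)
    return hidden_states * 0.5


agent = SharedAttentionAgent(blocks_count=2)
blocks = [ToyBlock(attn), ToyBlock(attn)]
for b in blocks:
    b.cache = agent  # one agent shared by every block
for step in range(4):
    h = 1.0
    for b in blocks:
        h = b.forward(h)
print(len(attn_calls))  # 4: attention runs on steps 0 and 2 only
```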
FAQ
Q: Qwen-Image-Edit-2509 reports an error during inference after enabling AttentionCache: RuntimeError: NPU out of memory.
A: Enabling AttentionCache increases memory consumption, so a single card can easily run out of memory. Eight-card inference is recommended.
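A rough estimate shows why memory grows: the agent has to keep at least one attention output per block alive between steps. The helper below is a back-of-the-envelope sketch. It assumes one cached tensor of shape (batch, tokens, hidden) per block, and the sizes in the example are hypothetical rather than Qwen-Image-Edit-2509's actual dimensions.

```python
def attention_cache_bytes(blocks_count: int, batch: int, tokens: int,
                          hidden: int, dtype_bytes: int = 2) -> int:
    """Bytes needed to hold one attention output per block (assumed cache layout)."""
    return blocks_count * batch * tokens * hidden * dtype_bytes


# Hypothetical sizes; dtype_bytes=2 corresponds to bf16/fp16.
size = attention_cache_bytes(blocks_count=60, batch=1, tokens=4096, hidden=3072)
print(size / 2**30)  # 1.40625 (GiB)
```

This cost comes on top of the model weights and activations, and it scales linearly with sequence length, so higher resolutions or longer videos make the problem worse.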
Timestep Optimization
Principle
By reducing, adjusting, or skipping certain steps in the diffusion model's denoising process, timestep optimization reduces the number of DiT module invocations while minimizing precision loss, which avoids redundant computation and accelerates inference.
Optimization Method
- Reduce the number of timesteps: for example, lowering it from 50 to 20 to speed up inference.
- Adastep sampling: An adaptive, dynamic timestep-skipping algorithm. Its core idea is to evaluate the state of the latent in real time during inference and dynamically skip steps whose inter-step differences are small, achieving fast convergence. Currently, this method is used only in CogVideoX; other models use DiTCache and AttentionCache instead.
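The idea behind adaptive step skipping can be illustrated with a generic loop. This is not the Adastep algorithm itself, whose exact criterion is not described here. The function `adaptive_denoise`, the relative-change threshold and the fixed `skip` count are assumptions, and a real sampler would also need its scheduler to account for the skipped timesteps.

```python
from typing import Callable, List, Sequence, Tuple


def adaptive_denoise(latent: float, timesteps: Sequence[int],
                     step_fn: Callable[[float, int], float],
                     threshold: float, skip: int = 1) -> Tuple[float, List[int]]:
    """Run step_fn over timesteps, skipping `skip` steps after any small update."""
    executed: List[int] = []
    i = 0
    while i < len(timesteps):
        t = timesteps[i]
        new_latent = step_fn(latent, t)
        executed.append(t)
        # Relative change between consecutive latents (scalar toy; use a tensor norm in practice).
        rel_change = abs(new_latent - latent) / max(abs(latent), 1e-8)
        latent = new_latent
        i += 1
        if rel_change < threshold:
            i += skip  # the latent is changing slowly: skip the next `skip` steps
    return latent, executed


latent, executed = adaptive_denoise(1.0, list(range(10)), lambda x, t: x * 0.99, threshold=0.05)
print(executed)  # [0, 2, 4, 6, 8]
```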