@startuml QuaRot Sequence
!theme plain
skinparam sequenceMessageAlign center
skinparam roundcorner 10
skinparam maxmessagesize 60
skinparam ParticipantPadding 20

participant "Runner" as Runner
participant "QuaRotProcessor" as Processor
participant "ModelAdapter" as Adapter
participant "QuaRotOnlineProcessor" as OnlineProcessor

== pre_run Phase ==

Runner -> Processor: pre_run()
activate Processor

note over Processor: Fetch the rotation configuration from ModelAdapter.

Processor -> Adapter: Request the fusion, bake, and rotation configuration.
activate Adapter
note right of Adapter: get_ln_fuse_map()\nget_bake_names()\nget_rotate_map(block_size)\nReturn the LayerNorm fusion maps, bake names, and rotation pairs.
Adapter --> Processor: (pre_run_fused_ln, fused_map,\npre_run_bake_names, bake_names,\npre_run_pairs, rotate_pairs)
deactivate Adapter

note over Processor: Perform rotation operations in the pre_run phase.

Processor -> Processor: get_rotate_command(pre_run_pairs)
note right: Convert the RotatePair entries into a list of RotateCommand objects.

Processor -> Processor: Perform rotation operations.
note right: _fuse_norm(): fuses each LayerNorm into the Linear layers that follow it.\n_bake_mean(): bakes the mean into the Linear layer weights.\n_rotate(): applies the rotations\n(for example, the right rotation of embed_tokens).
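' Illustrative sketch only (not taken from the implementation): why a right rotation
' of embed_tokens can be cancelled by rotating the weights that consume it.
' Assumptions: Q is orthogonal, Linear computes x @ W.T, and all names below
' (q, emb, w) are hypothetical stand-ins written in Python/NumPy.
'
'   import numpy as np
'   q, _ = np.linalg.qr(np.random.randn(8, 8))  # random orthogonal Q, so Q @ Q.T = I
'   emb = np.random.randn(5, 8)                  # stand-in for embed_tokens output
'   w = np.random.randn(3, 8)                    # stand-in Linear weight, shape (out, in)
'   # (emb Q)(W Q)^T = emb Q Q^T W^T = emb W^T, so the layer output is unchanged
'   assert np.allclose(emb @ w.T, (emb @ q) @ (w @ q).T)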

note over Processor: Prepare rotation commands for subsequent loops.

Processor -> Processor: get_rotate_command(rotate_pairs)
note right: Generate rotation commands for all layers.\nSave them to self.rotate_commands.

opt If config.online is True
    Processor -> OnlineProcessor: pre_run()
    activate OnlineProcessor
    
    note over OnlineProcessor: Create the online rotation matrices.
    
    OnlineProcessor -> OnlineProcessor: _get_available_device()
    note right: Select an available device (NPU, CUDA, or CPU).
    
    OnlineProcessor -> Adapter: Obtain model information.
    activate Adapter
    note right of Adapter: get_num_attention_heads()\nget_head_dim()
    Adapter --> OnlineProcessor: (num_attn_heads, head_dim)
    deactivate Adapter
    
    OnlineProcessor -> OnlineProcessor: Create the rotation matrices and save the rotation information.
    note right: _add_online_rotations(): creates rot1, rot2, and rot_online_o_proj\n(Hadamard mode).\ntorch.eye(): creates an identity matrix.\nQuarotOnlineRotationInfo stores the rotation information.
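    ' Illustrative sketch only: one common way to build a Hadamard rotation.
    ' Assumption: "Hadamard mode" means a normalised Hadamard matrix; how
    ' _add_online_rotations() builds rot1, rot2, and rot_online_o_proj is not
    ' shown in this diagram. Python, using SciPy's scipy.linalg.hadamard.
    '
    '   import numpy as np
    '   from scipy.linalg import hadamard
    '   n = 8                               # scipy.linalg.hadamard needs a power of two
    '   h = hadamard(n) / np.sqrt(n)        # scale the +1/-1 entries so that H is orthogonal
    '   assert np.allclose(h @ h.T, np.eye(n))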
    
    OnlineProcessor --> Processor: Return
    deactivate OnlineProcessor
end

Processor --> Runner: Return
deactivate Processor

== Loop through each DecoderLayer ==

loop For each DecoderLayer (layer_idx = 0 to num_hidden_layers-1)
    Runner -> Processor: preprocess(request)
    activate Processor
    note right of Runner: request.name = "model.layers.{layer_idx}"\nThe request carries the module information for this DecoderLayer.
    
    note over Processor: Select the operations for this layer by name prefix.
    
    Processor -> Processor: Filter operations related to this layer.
    note right: prefix = request.name\n_filter_fused_map(): filters the fusion map.\n_filter_bake_names(): filters the bake names.\n_filter_commands(): filters the rotation commands.
    
    note over Processor: Perform rotation operations for this layer.
    
    Processor -> Processor: Perform rotation operations.
    note right: _fuse_norm(): fuses each LayerNorm into the Linear layers that follow it.\n_bake_mean(): bakes the mean into the Linear layer weights.\n_rotate(): applies the rotation commands, which come from the model adapter.
    
    opt If config.online is True
        Processor -> OnlineProcessor: preprocess(request)
        activate OnlineProcessor
        
        OnlineProcessor -> Adapter: Obtain the module pairs for this layer.
        activate Adapter
        note right of Adapter: get_layer_wise_ov_pair()\nget_layer_wise_up_down_pair()\nReturn the o_proj/v_proj and up_proj/down_proj mapping pairs.
        Adapter --> OnlineProcessor: (ov_pairs, up_down_pairs)
        deactivate Adapter
        
        OnlineProcessor -> OnlineProcessor: Perform online rotations.
        note right: Extract layer_idx.\nIf the conditions are met:\n- online_rotate_o_proj_input(): rotates the o_proj input.\n- online_rotate_down_proj(): rotates down_proj.\n(The rotation is applied as a Kronecker product.)
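        ' Illustrative sketch only: applying a Kronecker-product rotation without
        ' materialising the large matrix. Assumption: the rotation has the form
        ' x @ kron(r1, r2); kron_rotate, r1, and r2 are hypothetical names, not
        ' the actual hook code. Python/NumPy.
        '
        '   import numpy as np
        '   def kron_rotate(x, r1, r2):
        '       # x has shape (..., m*n); the result equals x @ np.kron(r1, r2)
        '       m, n = r1.shape[0], r2.shape[0]
        '       blocks = x.reshape(*x.shape[:-1], m, n)
        '       return (r1.T @ blocks @ r2).reshape(x.shape)
        '
        '   x, r1, r2 = np.random.randn(2, 12), np.random.randn(3, 3), np.random.randn(4, 4)
        '   assert np.allclose(kron_rotate(x, r1, r2), x @ np.kron(r1, r2))
        '
        ' The reshaped form costs O(m*n*(m+n)) per vector instead of O((m*n)^2).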
        
        OnlineProcessor -> OnlineProcessor: Register the online rotation HookIR.
        note right: Register QuarotKroneckerRotationHookIR for down_proj.\nRegister QuarotHeadsRotationHookIR for o_proj.\nThese hooks apply the rotations during the forward pass.
        
        OnlineProcessor --> Processor: Return
        deactivate OnlineProcessor
    end
    
    Processor --> Runner: Return
    deactivate Processor
end

@enduml