@startuml QuaRot Sequence
!theme plain
skinparam sequenceMessageAlign center
skinparam roundcorner 10
skinparam maxmessagesize 60
skinparam ParticipantPadding 20
participant "Runner" as Runner
participant "QuaRotProcessor" as Processor
participant "ModelAdapter" as Adapter
participant "QuaRotOnlineProcessor" as OnlineProcessor
==pre_run Phase==
Runner -> Processor: pre_run()
activate Processor
note over Processor: Obtain rotation configuration information from ModelAdapter.
Processor -> Adapter: Obtain configuration information.
activate Adapter
note right of Adapter: get_ln_fuse_map()\nget_bake_names()\nget_rotate_map(block_size)\nReturn the fusion mappings, bake names, and rotation pairs.
Adapter --> Processor: (pre_run_fused_ln, fused_map,\npre_run_bake_names, bake_names,\npre_run_pairs, rotate_pairs)
deactivate Adapter
note over Processor: Perform rotation operations in the pre_run phase.
Processor -> Processor: get_rotate_command(pre_run_pairs)
note right: Convert the RotatePair entries into a list of RotateCommand objects.
Processor -> Processor: Perform rotation operations.
note right: _fuse_norm(): fuses the LayerNorm and Linear layers.\n_bake_mean(): fuses the mean into the Linear layer weights.\n_rotate(): performs rotation.\n(For example, right-rotate embed_tokens.)
note over Processor: Prepare rotation commands for subsequent loops.
Processor -> Processor: get_rotate_command(rotate_pairs)
note right: Generate rotation commands for all layers.\nSave them to self.rotate_commands.
opt If config.online is True
Processor -> OnlineProcessor: pre_run()
activate OnlineProcessor
note over OnlineProcessor: Create an online rotation matrix.
OnlineProcessor -> OnlineProcessor: _get_available_device()
note right: Obtain an available device (NPU, CUDA, or CPU).
OnlineProcessor -> Adapter: Obtain model information.
activate Adapter
note right of Adapter: get_num_attention_heads()\nget_head_dim()
Adapter --> OnlineProcessor: (num_attn_heads, head_dim)
deactivate Adapter
OnlineProcessor -> OnlineProcessor: Create a rotation matrix and save rotation information.
note right: _add_online_rotations(): creates rot1, rot2, and rot_online_o_proj\n(Hadamard mode)\ntorch.eye(): creates an identity matrix.\nCreate QuarotOnlineRotationInfo to save rotation information.
OnlineProcessor --> Processor: Return
deactivate OnlineProcessor
end
Processor --> Runner: Return
deactivate Processor
==Loop through each DecoderLayer==
loop For each DecoderLayer (layer_idx = 0 to num_hidden_layers-1)
Runner -> Processor: preprocess(request)
activate Processor
note right of Runner: request.name = "model.layers.{layer_idx}"\nInformation about the module that contains the DecoderLayer.
note over Processor: Filter operations related to this layer based on the prefix.
Processor -> Processor: Filter operations related to this layer.
note right: prefix = request.name\n_filter_fused_map(): filters fused mapping.\n_filter_bake_names(): filters bake names.\n_filter_commands(): filters rotation commands.
note over Processor: Perform rotation operations at this layer.
Processor -> Processor: Perform rotation operations.
note right: _fuse_norm(): fuses the LayerNorm and Linear layers.\n_bake_mean(): fuses the mean into the Linear layer weights.\n_rotate(): performs rotation. The rotation commands come from the model adapter.
opt If config.online is True
Processor -> OnlineProcessor: preprocess(request)
activate OnlineProcessor
OnlineProcessor -> Adapter: Obtain the module pairs for this layer.
activate Adapter
note right of Adapter: get_layer_wise_ov_pair()\nget_layer_wise_up_down_pair()\nReturn the o_proj/v_proj and up_proj/down_proj mapping pairs.
Adapter --> OnlineProcessor: (ov_pairs, up_down_pairs)
deactivate Adapter
OnlineProcessor -> OnlineProcessor: Perform online rotations.
note right: Extract layer_idx.\nIf the conditions are met:\n- online_rotate_o_proj_input(): rotates the o_proj input.\n- online_rotate_down_proj(): rotates down_proj.\n(Use the Kronecker product to rotate the matrix.)
OnlineProcessor -> OnlineProcessor: Register the online rotation HookIR.
note right: Register QuarotKroneckerRotationHookIR for down_proj.\nRegister QuarotHeadsRotationHookIR for o_proj.\nThese hooks perform the rotations during the forward pass.
OnlineProcessor --> Processor: Return
deactivate OnlineProcessor
end
Processor --> Runner: Return
deactivate Processor
end
@enduml