RFC: SwiGLU Operator Fusion
Metadata
| Item | Content |
|---|---|
| Status | Approved |
| Author | genius52 |
| Created Date | 2026-02-06 |
| Related Links | https://gitcode.com/Ascend/msmodeling/pull/75 |
1. Overview
The SwiGLU (Swish Gated Linear Unit) activation function consists of linear transformations, SiLU activation, and element-wise multiplication. Traditional execution approaches break it down into separate operators, leading to significant kernel launch overhead and inefficient memory access.
This RFC proposes implementing SwiGLU fusion through pattern matching on PyTorch graphs. The solution replaces the computational pattern of the SwiGLU activation with the torch.ops.tensor_cast.swiglu operator. Performance is further improved by the Sink Split optimization, which converts static parameters into dynamic inputs.
2. Solution Design
2.1 Recommended Approach
Implement SwiGLU fusion using PyTorch graph-based pattern matching, which identifies SwiGLU patterns and replaces them with calls to torch.ops.tensor_cast.swiglu.default, combined with the Sink Split optimization for additional performance. The solution reuses the pattern registration and replacement mechanisms of the existing pattern-matching infrastructure.
Core Implementation Files
- tensor_cast/compilation/patterns/swiglu.py: SwiGLU pattern definition and registration
- tensor_cast/compilation/passes/pattern_match_pass.py: Pattern matching and replacement pass implementation
- tensor_cast/compilation/freezing_passes/sink_split_pass.py: Sink Split optimization implementation
Interface
- Custom operator: tensor_cast::swiglu
  - input: gate: Tensor, up: Tensor
  - output: Tensor, swiglu(gate, up)
Current Scope of SwiGLU Operator:
The current torch.ops.tensor_cast.swiglu operator only handles the activation computation segment, not including the generation of gate and up projections. These projections are passed as parameters to the operator.
graph LR
A[x] --> B[gate_proj]
B[gate_proj] --> D[silu/sigmoid]
A --> C[up_proj]
C[up_proj] --> F[mul]
D --> F
F --> G[swiglu_out]
style B fill:#ccf,stroke:#333,stroke-width:2px
style C fill:#ccf,stroke:#333,stroke-width:2px
style D fill:#fcc,stroke:#333,stroke-width:2px
style F fill:#fcc,stroke:#333,stroke-width:2px
Implementation Details:
- Activation segment only: The operator matches and replaces only the activation computation: gate → fp32 conversion → sigmoid → fp16 conversion → mul with up
- Gate and Up as inputs: The linear transformations for the gate and up projections are generated upstream and passed as inputs to the swiglu operator
- Future matmul fusion: The current implementation does not handle the matrix multiplications for the gate and up projections; this will be completed during GMM operator integration
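As a reference for what the fused activation segment computes, the following is a minimal pure-Python sketch. It assumes the operator is numerically equivalent to SiLU(gate) * up, consistent with the front-end formula in section 2.5.2. The name swiglu_reference is hypothetical, and the sketch works on Python floats rather than tensors, so the fp32/fp16 conversions are not modeled.

```python
import math
from typing import List, Sequence


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def swiglu_reference(gate: Sequence[float], up: Sequence[float]) -> List[float]:
    """Hypothetical reference: element-wise SiLU(gate) * up.

    Assumption: this mirrors the activation segment only; gate and up are
    the already-computed projection outputs, as described above.
    """
    if len(gate) != len(up):
        raise ValueError("gate and up must have the same length")
    return [g * sigmoid(g) * u for g, u in zip(gate, up)]
```

A zero gate always yields a zero output, which is a quick sanity check when comparing a fused kernel against this reference.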
Core Implementation
Based on tensor_cast/compilation/patterns/swiglu.py implementation:
- SwiGLUPattern class: Defines SwiGLU pattern matching and replacement logic
- create method: Returns pattern, replacement, get_inputs triplet
- pattern function: Matches activation computation segment: gate → fp32 conversion → sigmoid → fp16 conversion → mul with up
- replacement function: Uses torch.ops.tensor_cast.swiglu to replace the original pattern
- Supported data types: torch.float16, torch.bfloat16
2.2 SwiGLU Pattern Detection and Fusion
2.2.1 Pattern Detection and Registration
Based on the implementation in tensor_cast/compilation/patterns/swiglu.py, the pattern detection happens through a registration mechanism:
- Iterate through each data type, get pattern, replacement, example_inputs
- Register patterns via the PyTorch pattern matcher's register_pattern function
- After registration, patterns become available for matching in computation graphs
Note: Pattern matching focuses on the activation computation segment: gate → fp32 conversion → sigmoid → fp16 conversion → multiplication with up tensor.
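The registration loop described above can be sketched in plain Python as follows. This is an illustration only, not the code in swiglu.py: PatternEntry, create_swiglu_pattern, register_swiglu_patterns and fused_swiglu are hypothetical stand-ins, dtypes are represented as strings, and the pattern and replacement operate on floats so that their equivalence can be checked without PyTorch.

```python
import math
from typing import Callable, List, NamedTuple, Tuple

SUPPORTED_DTYPES = ("float16", "bfloat16")  # mirrors the dtypes listed in 2.1


class PatternEntry(NamedTuple):
    dtype: str
    pattern: Callable[[float, float], float]
    replacement: Callable[[float, float], float]
    example_inputs: Tuple[float, float]


def _sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def fused_swiglu(gate: float, up: float) -> float:
    """Stand-in for the fused operator (assumed to be SiLU(gate) * up)."""
    return gate * _sigmoid(gate) * up


def create_swiglu_pattern(dtype: str):
    """Hypothetical analogue of a create method returning a triplet."""

    def pattern(gate: float, up: float) -> float:
        gate_fp32 = float(gate)               # stands in for the fp32 up-cast
        activated = gate_fp32 * _sigmoid(gate_fp32)
        return activated * up                 # down-cast is not modeled here

    def replacement(gate: float, up: float) -> float:
        return fused_swiglu(gate, up)

    example_inputs = (0.5, -1.25)
    return pattern, replacement, example_inputs


def register_swiglu_patterns(registry: List[PatternEntry]) -> List[PatternEntry]:
    """Iterate over dtypes and append one entry per dtype to the registry."""
    for dtype in SUPPORTED_DTYPES:
        pattern, replacement, example_inputs = create_swiglu_pattern(dtype)
        registry.append(PatternEntry(dtype, pattern, replacement, example_inputs))
    return registry


registry = register_swiglu_patterns([])
```

Checking that pattern and replacement agree on the example inputs for every registered entry is a cheap guard against registering a replacement that changes numerics.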
2.2.2 Pattern Matching and Replacement Process
PatternMatchPass performs iterative pattern matching in the computation graph by loading registered patterns and detecting SwiGLU patterns. The process replaces matched patterns with single torch.ops.tensor_cast.swiglu.default calls and continues optimizing until no more patterns can be matched. This approach uses direct pattern-to-operator replacement instead of complex grouping strategies, achieving simpler and more efficient fusion.
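The match-until-fixpoint behaviour can be illustrated on a toy graph. The sketch below is purely structural and is not the PatternMatchPass implementation: the graph is a dictionary from node name to (op, inputs), the op names to_fp32, sigmoid, to_fp16, mul and swiglu are placeholders mirroring the chain described in 2.2.1, and all function names are hypothetical.

```python
from typing import Dict, Iterable, Optional, Tuple

Node = Tuple[str, Tuple[str, ...]]  # (op name, input node names)
Graph = Dict[str, Node]


def _producer(graph: Graph, name: str, op: str) -> Optional[Node]:
    """Return the node called `name` if it exists and has the given op."""
    node = graph.get(name)
    return node if node is not None and node[0] == op else None


def match_once(graph: Graph) -> bool:
    """Replace one to_fp32 -> sigmoid -> to_fp16 -> mul chain, if any."""
    for out_name, (op, args) in list(graph.items()):
        if op != "mul" or len(args) != 2:
            continue
        # mul is commutative, so try both argument orders.
        for act_name, up_name in (args, args[::-1]):
            down = _producer(graph, act_name, "to_fp16")
            sig = down and _producer(graph, down[1][0], "sigmoid")
            up_cast = sig and _producer(graph, sig[1][0], "to_fp32")
            if up_cast:
                gate_name = up_cast[1][0]
                graph[out_name] = ("swiglu", (gate_name, up_name))
                return True
    return False


def run_pattern_match(graph: Graph) -> int:
    """Apply match_once until nothing matches; return the replacement count."""
    replaced = 0
    while match_once(graph):
        replaced += 1
    return replaced


def eliminate_dead_nodes(graph: Graph, outputs: Iterable[str]) -> Graph:
    """Keep only nodes reachable from the given output names."""
    live, stack = set(), list(outputs)
    while stack:
        name = stack.pop()
        if name in live or name not in graph:
            continue
        live.add(name)
        stack.extend(graph[name][1])
    return {name: node for name, node in graph.items() if name in live}
```

The loop terminates because every replacement turns one mul node into a swiglu node, and the number of mul nodes is finite.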
2.2.3 Placement in Compilation Pipeline
graph LR
A[Before Freezing] --> B[Quantization Passes]
B --> C[Pattern Match Pass]
C --> D[Pattern Match Passes]
D --> E[After Freezing]
E --> F[Sink Split Pass]
style C fill:#ffcccc,stroke:#333,stroke-width:2px
style F fill:#ccffcc,stroke:#333,stroke-width:2px
2.2.4 Rationale for the New Approach
The new approach using PyTorch Graph-based Pattern Matching and Sink Split:
- Simplified Architecture: Direct pattern-to-operator replacement eliminates complex grouping logic
- Enhanced Flexibility: Pattern registration system allows easy addition of new operator patterns
- Better Performance: Sink Split optimization converts static parameters into dynamic inputs for memory efficiency
- Seamless Integration: torch.ops.tensor_cast.swiglu integrates naturally with existing tensor_cast operations
- Higher Compatibility: Works with various quantization strategies and model structures through flexible pattern matching
2.3 Performance Modeling
The performance characteristics of torch.ops.tensor_cast.swiglu are handled through the standard tensor_cast performance modeling infrastructure.
FLOPs Calculation
- Matrix multiplication operations (upstream): 2 * M * N * K
- SiLU activation operations: 3 * M * N (sigmoid + multiplication + one more multiplication)
- Total computation depends on upstream operators and the fused activation
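The two formulas above can be combined into a small estimator. This is a sketch under stated assumptions: M is taken as the number of rows (tokens), K as the input feature size and N as the projection output size, and the total assumes exactly two upstream projections (gate and up). The function names are hypothetical and are not part of the tensor_cast performance modeling infrastructure.

```python
def matmul_flops(m: int, n: int, k: int) -> int:
    """FLOPs of one (m x k) @ (k x n) matrix multiplication: 2 * M * N * K."""
    return 2 * m * n * k


def swiglu_activation_flops(m: int, n: int) -> int:
    """FLOPs of the fused activation segment: 3 * M * N."""
    return 3 * m * n


def gated_mlp_flops(m: int, n: int, k: int) -> int:
    """Assumed total: gate projection + up projection + fused activation."""
    return 2 * matmul_flops(m, n, k) + swiglu_activation_flops(m, n)
```

For M=4, N=8, K=16 each projection costs 1024 FLOPs and the activation costs 96, so the activation is a small fraction of the total. This suggests the fusion benefit comes mainly from fewer kernel launches and less memory traffic rather than from reduced arithmetic.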
2.4 Finding and Validating the Operator
Registration locations:
- tensor_cast/compilation/patterns/swiglu.py: SwiGLU pattern registration
- tensor_cast/compilation/passes/pattern_match_pass.py: Pattern matching infrastructure
Graph pattern recognition:
- Pattern gate → fp32 conversion → sigmoid → fp16 conversion → multiplication with up tensor
- Ignores transparent operations like reshape/cast
- Supports both torch.float16 and torch.bfloat16 data types
Validation methods:
- Check if SwiGLU patterns are registered via register_all_patterns()
- Run compilation pipeline with pattern matching enabled
- Use graph observers to verify torch.ops.tensor_cast.swiglu calls appear after transformations
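The last validation step can be approximated by counting matching call nodes in a transformed graph. The sketch below is duck-typed: it relies only on graph.nodes, node.op and node.target, which are the attribute names used by torch.fx graphs, and it assumes that str(node.target) contains the qualified operator name. The helper names and the fake graph are hypothetical, so the example runs without PyTorch and is not the project's graph observer.

```python
from types import SimpleNamespace


def count_calls(graph_module, target_substring: str) -> int:
    """Count call_function nodes whose target name contains the substring."""
    return sum(
        1
        for node in graph_module.graph.nodes
        if node.op == "call_function" and target_substring in str(node.target)
    )


def make_fake_graph_module(targets):
    """Build a stand-in object exposing graph.nodes for demonstration."""
    nodes = [SimpleNamespace(op="call_function", target=t) for t in targets]
    nodes.append(SimpleNamespace(op="output", target="output"))
    return SimpleNamespace(graph=SimpleNamespace(nodes=nodes))


fake = make_fake_graph_module(
    ["aten.mm.default", "tensor_cast.swiglu.default", "aten.mm.default"]
)
swiglu_calls = count_calls(fake, "tensor_cast.swiglu")
```

In a real check, a count of zero after the pattern match pass would indicate that the pattern was not registered or did not match the traced graph.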
2.5 Alternative Approaches
2.5.1 Individual Pass Approach (Abandoned)
The original approach using individual passes like swiglu_fusion_pass.py was abandoned for the following reasons:
- Poor Performance: Individual passes created significant overhead due to complex grouping logic
- Inflexibility: Difficult to adapt to different model structures and optimization requirements
- Limited Extensibility: Hard to add support for new operator patterns
- Inability to Fuse with GMM Operators: The grouping approach prevented seamless integration with GMM operations
2.5.2 Front-end Fusion
Front-end fusion combines the linear transformations, SiLU activation, and multiplication at model-definition time: SwiGLU(x, W_gate, W_up) = SiLU(x @ W_gate) * (x @ W_up)
Comparison:
- Individual Pass (Abandoned): Drawbacks: poor performance, inflexibility, poor extensibility, inability to fuse with GMM operators
- Front-end Fusion: Advantages: compile-time optimizations, hardware-specific optimizations, better memory locality. Drawbacks: lack of flexibility, high loading overhead, difficult to maintain
- This RFC: Advantages: high performance, strong flexibility, good extensibility, seamless integration. Drawbacks: requires graph matching, higher development complexity
2.5.3 Why Back-end Fusion is Better Suited for SwiGLU
Back-end fusion is the optimal choice for SwiGLU given its specific usage scenarios:
- Quantization Strategy Selection:
  - SwiGLU is commonly used in large language models that typically employ int4/int8 quantization for optimal performance
  - The code explicitly excludes quantized nodes because post-quantization linear transformations differ from the original floating-point computation patterns
  - Back-end fusion can dynamically adjust based on quantization strategies to ensure optimal fusion on non-quantized operators
- Graph Structure Flexibility:
  - In large models, SiLU branches may consist of various variants (SiLU, sigmoid variants, etc.)
  - Model compilation may introduce transparent nodes and graph optimizations
  - Matching stability requirements necessitate pattern recognition during the compilation phase
- Performance Maximization and Safety:
  - SwiGLU typically appears in groups within models (e.g., multiple parallel SwiGLUs in FeedForward layers)
  - The back-end approach enables precise dependency checking to ensure independence and safety of these groups
  - Achieves higher optimization potential and performance gains compared to front-end fusion
  - Explicit exclusion of quantized nodes prevents performance degradation caused by quantization
Advantages Summary:
- Computational graphs with fewer nodes and better kernel fusion opportunities
- More accurate performance modeling through aggregated matmul properties
- Enhanced cycle detection ensures fusion safety and prevents calculation errors
- Explicit exclusion of quantized nodes to avoid performance degradation after quantization
- Superior performance compared to the abandoned individual pass approach
- Seamless integration with GMM operators for improved hardware utilization
2.6 Sink Split Optimization
Based on tensor_cast/compilation/freezing_passes/sink_split_pass.py, the Sink Split optimization enhances SwiGLU fusion by converting static split parameters into dynamic inputs for performance improvement.
2.6.1 SwiGLU Sink Split Working Mechanism
SwiGLU Split Configuration:
# Binary operation configuration
binary_ops = [
    torch.ops.aten.mul.Tensor,
    torch.ops.tensor_cast.swiglu.default,
]
for op in binary_ops:
    # gate[0] and up[1] can be split, output[0] can be split
    add_config(op, {0, 1}, {0})
Sink Split Working Principle:
Traditional Split Pattern:
# Multiple getitem+split combinations
getitem1 = input_tensor[0] # Slice
split1 = split(getitem1, size1) # Split
getitem2 = input_tensor[1] # Slice
split2 = split(getitem2, size2) # Split
# SwiGLU needs to combine results
swiglu_out = swiglu(split1, split2) # Fusion
Sink Split Optimization:
# Single split tree merging
dynamic_size = get_dynamic_sizes() # Convert static to dynamic input
output_list = split(input_tensor, dynamic_size) # One-time split
swiglu_out = swiglu(output_list[0], output_list[1]) # Direct indexing
Key Formulas:
- Static→Dynamic Conversion: split_sizes=[a,b,c] → dynamic_sizes=a+b+c, unified memory allocation
- Tree Merging Optimization: Time complexity O(n) → O(1), reduces memory allocation and fragmentation
- Memory Continuity: Physical memory continuous → Logical slice access, improves cache hit rate
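One way to read the configuration add_config(op, {0, 1}, {0}) is that, because swiglu is element-wise, a split applied to both inputs can be moved ("sunk") below the operator without changing the result. The sketch below checks that equivalence on plain Python lists. It is an interpretation offered for illustration, not the logic in sink_split_pass.py: all function names are hypothetical, and swiglu is assumed to compute SiLU(gate) * up.

```python
import math
from typing import List, Sequence


def swiglu(gate: Sequence[float], up: Sequence[float]) -> List[float]:
    """Assumed element-wise semantics: SiLU(gate) * up."""
    return [g / (1.0 + math.exp(-g)) * u for g, u in zip(gate, up)]


def split(values: Sequence[float], sizes: Sequence[int]) -> List[List[float]]:
    """Split a sequence into consecutive chunks of the given sizes."""
    if sum(sizes) != len(values):
        raise ValueError("split sizes must sum to the input length")
    chunks, start = [], 0
    for size in sizes:
        chunks.append(list(values[start:start + size]))
        start += size
    return chunks


def split_then_fuse(gate, up, sizes):
    """Before sinking: split both inputs, then one swiglu call per chunk."""
    return [swiglu(g, u) for g, u in zip(split(gate, sizes), split(up, sizes))]


def fuse_then_split(gate, up, sizes):
    """After sinking: a single swiglu call followed by a single split."""
    return split(swiglu(gate, up), sizes)
```

Both orderings perform the same per-element arithmetic, so the outputs are identical; the second form issues one operator call and one split regardless of the number of chunks.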
2.6.2 Performance Improvement Effects
SwiGLU achieves the following performance improvements through Sink Split optimization:
- Memory Access Efficiency: Reduces memory fragmentation and improves data locality
- Kernel Call Reduction: Merges split operations into single operations, decreasing communication overhead
- Hardware Utilization Enhancement: Uses more continuous data access patterns to optimize hardware resource usage
2.7 SwiGLU and GMM Fusion
GMM Fusion Mechanism: SwiGLU not only optimizes independently but also deeply integrates with various types of GMM (Grouped MatMul) operators to achieve maximum performance enhancement:
Supported GMM Types:
- torch.ops.tensor_cast.static_quant_linear: Static quantized linear layer
- torch.ops.tensor_cast.static_quant_linear_int4: Static quantized int4 linear layer
- torch.ops.tensor_cast.fp8_linear: FP8 linear layer
- torch.ops.tensor_cast.mxfp4_linear: MXFP4 linear layer
- torch.ops.tensor_cast.grouped_matmul: Generic grouped matrix multiplication
GMM-SwiGLU Joint Optimization:
- Parameter Unification: GMM and SwiGLU share split optimization configuration with unified iterator handling
- Tree Merging: GMM configured via add_config(op, {0}, {0}) works collaboratively with the SwiGLU tree merging mechanism
- Static to Dynamic Conversion: Converts static split parameters to dynamic inputs to reduce memory fragmentation
- Hardware Optimization: Achieves end-to-end memory access optimization by aggregating SwiGLU outputs with GMM inputs
3. Implementation Plan
3.1 Implementation Steps
- Pattern Registration: Implement SwiGLU pattern definitions for data types
- Matching Integration: Integrate with PatternMatchPass for replacement
- Optimization Configuration: Configure SinkSplitPass for split optimization
- GMM Integration: Configure SwiGLU fusion with various GMM operators
- Validation Testing: Verify functionality correctness and performance gains
3.2 Expected Benefits
- Architecture Simplification: Direct replacement eliminates complex grouping logic
- Performance Optimization: Enhanced memory access and computation efficiency through Sink Split
- Strong Extensibility: Easy to support new operator patterns and quantization strategies
- Hardware Utilization: Collaborative enhancement with GMM operators for optimal hardware resource efficiency