Fused Softmax
In this section, you will use Triton to write a fused softmax kernel. In the process, you will learn:
- The advantages of kernel fusion for bandwidth-bound operations.
- Reduction operations in Triton.
Computing the Row-wise Softmax of X with Native PyTorch
```python
import torch
import torch_npu
import triton
import triton.language as tl


def naive_softmax(x):
    """
    Compute the row-wise softmax of x using native PyTorch.

    The maximum element of each row is subtracted to avoid overflow.
    Softmax is invariant to this shift.
    """
    # Read MN elements; write M elements.
    x_max = x.max(dim=1)[0]
    # Read MN + M elements; write MN elements.
    z = x - x_max[:, None]
    # Read MN elements; write MN elements.
    numerator = torch.exp(z)
    # Read MN elements; write M elements.
    denominator = numerator.sum(dim=1)
    # Read MN + M elements; write MN elements.
    ret = numerator / denominator[:, None]
    # Total: read 5 × MN + 2 × M elements; write 3 × MN + 2 × M elements.
    return ret
```
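As a quick check of the docstring's claim, here is a minimal pure-Python sketch (standard library only, no torch; the helper names softmax_unshifted and softmax_shifted are hypothetical) that compares the direct formula with the max-shifted one. Python's math.exp raises OverflowError for large arguments. In float32 tensor code, the same overflow would appear as inf and then nan.

```python
import math


def softmax_unshifted(row):
    # Direct formula exp(x_i) / sum_j exp(x_j); math.exp raises OverflowError for large x_i.
    exps = [math.exp(v) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_shifted(row):
    # Subtract the row maximum first: every exponent is <= 0, so exp() never exceeds 1.
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


small = [1.0, 2.0, 3.0]
large = [1000.0, 1001.0, 1002.0]  # the same values shifted by +999

# For small inputs, both versions agree up to rounding.
print(softmax_unshifted(small))
print(softmax_shifted(small))

try:
    softmax_unshifted(large)
except OverflowError:
    print("unshifted softmax overflowed")

# The shifted version is unaffected by the +999 offset.
print(softmax_shifted(large) == softmax_shifted(small))  # True
```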
Purpose of Kernel Fusion
For an M × N input x, computing y = naive_softmax(x) in native PyTorch requires reading 5 × MN + 2 × M elements from DRAM and writing back 3 × MN + 2 × M elements. Most of this traffic comes from intermediate results (x_max, z, numerator, denominator) that make a round trip through DRAM only to be read again by the next operation. A more efficient approach is a custom "fused" kernel that reads x only once and performs all the necessary computations on-chip.
Such a kernel reads MN elements and writes back MN elements, for a total of 2 × MN. Because softmax is bandwidth-bound, the theoretical speedup is about 4×: (8 × MN + 4 × M) / (2 × MN) = 4 + 2/N ≈ 4.
torch.jit.script aims to perform this kind of kernel fusion automatically, but the result is still far from ideal.
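The element counts above can be turned into a small bandwidth model. This sketch assumes that every access goes to DRAM, with no cache reuse, and it counts elements rather than bytes. The function names are hypothetical. It uses the 1823 × 781 shape from the unit test later in this section.

```python
def naive_traffic(m, n):
    # naive_softmax reads 5*M*N + 2*M elements and writes 3*M*N + 2*M elements.
    reads = 5 * m * n + 2 * m
    writes = 3 * m * n + 2 * m
    return reads + writes


def fused_traffic(m, n):
    # A fused kernel reads X once (M*N elements) and writes Y once (M*N elements).
    return 2 * m * n


def theoretical_speedup(m, n):
    # (8*M*N + 4*M) / (2*M*N) simplifies to 4 + 2/N.
    return naive_traffic(m, n) / fused_traffic(m, n)


m, n = 1823, 781  # shape used in the unit test
print(naive_traffic(m, n), fused_traffic(m, n))
print(theoretical_speedup(m, n))  # slightly above 4
```

Because the ratio simplifies to 4 + 2/N, the 2 × M terms only matter for very narrow rows. For any realistic N, the theoretical speedup is essentially 4×.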
Compute Kernel
The softmax kernel works as follows: each program loads a set of rows of the input matrix X, strided by the number of programs, normalizes each row, and writes the result back to the output matrix Y. Note that an important limitation of Triton is that each block must have a power-of-two number of elements. To handle arbitrary input shapes, the kernel therefore pads each row internally to BLOCK_SIZE elements and guards the memory operations with a mask.
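The row assignment described above is a grid-stride loop. Here is a pure-Python sketch of it. The helper rows_for_program is hypothetical and mirrors tl.range(row_start, n_rows, row_step), with row_start = tl.program_id(0) and row_step = tl.num_programs(0).

```python
def rows_for_program(pid, n_rows, num_programs):
    # Mirrors `for row_idx in tl.range(row_start, n_rows, row_step)` with
    # row_start = pid and row_step = num_programs.
    return list(range(pid, n_rows, num_programs))


n_rows, num_programs = 10, 4
assignment = {pid: rows_for_program(pid, n_rows, num_programs) for pid in range(num_programs)}
print(assignment)  # {0: [0, 4, 8], 1: [1, 5, 9], 2: [2, 6], 3: [3, 7]}

# Every row is handled by exactly one program.
all_rows = sorted(r for rows in assignment.values() for r in rows)
print(all_rows == list(range(n_rows)))  # True
```

For the unit-test shape of 1823 rows and 32 programs, programs 0 to 30 each handle 57 rows and program 31 handles 56. The work is therefore balanced to within one row.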
```python
@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr):
    # Starting row of this program.
    row_start = tl.program_id(0)
    # Each program advances by the total number of programs.
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step):
        # The stride indicates how much the pointer must be increased to advance one row.
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # BLOCK_SIZE is the smallest power of two greater than or equal to n_cols,
        # so each row fits in a single block.
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into on-chip memory. The mask is needed because BLOCK_SIZE may be
        # greater than n_cols; padded lanes are filled with -inf so they do not affect the result.
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # Subtract the maximum value for numerical stability.
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate.
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write the output back to DRAM.
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)
```
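The masking logic can be emulated for a single row in plain Python. In this sketch, lists stand in for Triton blocks, and next_pow2 is a hypothetical stand-in for triton.next_power_of_2. The sketch assumes the row has at least one column. It shows why padding with other=-float('inf') is safe: padded lanes never win the maximum, exp(-inf) is exactly 0, and the store mask discards them.

```python
import math


def next_pow2(n):
    # Smallest power of two >= n (hypothetical stand-in for triton.next_power_of_2).
    p = 1
    while p < n:
        p *= 2
    return p


def softmax_row_padded(row):
    # Emulates one loop iteration of the kernel for a row with at least one element.
    n_cols = len(row)
    block_size = next_pow2(n_cols)
    col_offsets = list(range(block_size))
    mask = [c < n_cols for c in col_offsets]
    # Like tl.load(..., mask=mask, other=-float('inf')): padded lanes read as -inf.
    block = [row[c] if ok else -math.inf for c, ok in zip(col_offsets, mask)]
    row_max = max(block)  # a -inf lane can never be the maximum
    numerator = [math.exp(v - row_max) for v in block]  # exp(-inf) == 0.0
    denominator = sum(numerator)  # padded lanes add exactly 0.0
    softmax_output = [v / denominator for v in numerator]
    # Like tl.store(..., mask=mask): only the first n_cols lanes are written back.
    return [v for v, ok in zip(softmax_output, mask) if ok]


def softmax_row_plain(row):
    # Reference: the same computation without any padding.
    row_max = max(row)
    numerator = [math.exp(v - row_max) for v in row]
    denominator = sum(numerator)
    return [v / denominator for v in numerator]


row = [0.5, -1.0, 2.0, 3.5, 0.0]  # 5 columns, padded to a block of 8
print(softmax_row_padded(row) == softmax_row_plain(row))  # True
```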
Next, create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
```python
kernels = {}


def softmax(x):
    n_rows, n_cols = x.shape
    # The block size for each loop iteration is the smallest power of two
    # greater than or equal to the number of columns in `x`.
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    # Allocate output space.
    y = torch.empty_like(x)
    # Look up a cached (kernel, num_programs) pair for this BLOCK_SIZE.
    # No register or occupancy query is performed here: on a cache miss,
    # the number of programs is simply fixed at 32.
    kernel, num_programs = kernels.get(BLOCK_SIZE, (None, 0))
    if kernel is None:
        num_programs = 32
        kernel = softmax_kernel
        kernels[BLOCK_SIZE] = (kernel, num_programs)
    # Never launch more programs than there are rows.
    num_programs = min(num_programs, n_rows)
    # Launch a 1-D grid; program i handles rows i, i + num_programs, i + 2 * num_programs, ...
    kernel[(num_programs, 1, 1)](
        y,
        x,
        x.stride(0),
        y.stride(0),
        n_rows,
        n_cols,
        BLOCK_SIZE,
    )
    return y
```
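The launch-configuration logic of the helper can be isolated and checked without an NPU. In this pure-Python sketch, next_pow2, launch_config and launch_cache are hypothetical names. The fixed value of 32 is taken from the helper above.

```python
def next_pow2(n):
    # Smallest power of two >= n (hypothetical stand-in for triton.next_power_of_2).
    p = 1
    while p < n:
        p *= 2
    return p


FIXED_NUM_PROGRAMS = 32  # the value hard-coded in the softmax helper
launch_cache = {}  # BLOCK_SIZE -> number of programs, like the `kernels` dict


def launch_config(n_rows, n_cols):
    block_size = next_pow2(n_cols)
    num_programs = launch_cache.get(block_size)
    if num_programs is None:
        num_programs = FIXED_NUM_PROGRAMS
        # Cache the unclamped value, as the helper does before applying min().
        launch_cache[block_size] = num_programs
    # Never launch more programs than there are rows.
    num_programs = min(num_programs, n_rows)
    return block_size, (num_programs, 1, 1)


print(launch_config(1823, 781))  # (1024, (32, 1, 1))
print(launch_config(5, 781))     # (1024, (5, 1, 1))
print(launch_cache)              # {1024: 32}
```

As in the helper, the cache stores the program count before it is clamped with min(..., n_rows). A call with only a few rows therefore does not shrink the grid for later, larger inputs with the same BLOCK_SIZE. For the unit-test shape, BLOCK_SIZE is 1024, so 1024 - 781 = 243 lanes of each row are masked.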
Unit Test
Test the kernel on a matrix with an irregular number of rows and columns to verify that the padding mechanism works.
```python
torch.manual_seed(0)
x = torch.randn(1823, 781, device='npu')
y_triton = softmax(x)
y_torch = torch.softmax(x, dim=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
print(y_triton)
print(y_torch)
print(f'The maximum difference between torch and triton is '
      f'{torch.max(torch.abs(y_triton - y_torch))}')
```
Output:
```
tensor([[0.0002, 0.0017, 0.0009, ..., 0.0009, 0.0013, 0.0073],
        [0.0001, 0.0004, 0.0006, ..., 0.0006, 0.0004, 0.0003],
        [0.0007, 0.0002, 0.0006, ..., 0.0011, 0.0004, 0.0039],
        ...,
        [0.0021, 0.0002, 0.0015, ..., 0.0012, 0.0014, 0.0022],
        [0.0003, 0.0002, 0.0007, ..., 0.0005, 0.0006, 0.0007],
        [0.0034, 0.0014, 0.0005, ..., 0.0007, 0.0016, 0.0028]],
       device='npu:0')
tensor([[0.0002, 0.0017, 0.0009, ..., 0.0009, 0.0013, 0.0073],
        [0.0001, 0.0004, 0.0006, ..., 0.0006, 0.0004, 0.0003],
        [0.0007, 0.0002, 0.0006, ..., 0.0011, 0.0004, 0.0039],
        ...,
        [0.0021, 0.0002, 0.0015, ..., 0.0012, 0.0014, 0.0022],
        [0.0003, 0.0002, 0.0007, ..., 0.0005, 0.0006, 0.0007],
        [0.0034, 0.0014, 0.0005, ..., 0.0007, 0.0016, 0.0028]],
       device='npu:0')
The maximum difference between torch and triton is 1.4901161193847656e-08
```
The reported maximum difference of 1.4901161193847656e-08 shows that the Triton and PyTorch results agree to within floating-point rounding error. The two outputs are indistinguishable at the printed precision.
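The assertion in the unit test relies on torch.allclose, which checks |a - b| <= atol + rtol × |b| elementwise, with defaults rtol=1e-05 and atol=1e-08. The following pure-Python sketch reproduces that criterion. The function name is hypothetical, and unlike torch.allclose, the sketch does not broadcast. It also confirms that the reported maximum difference is exactly 2^-26.

```python
def allclose(a, b, rtol=1e-05, atol=1e-08):
    # Same elementwise criterion as torch.allclose: |a - b| <= atol + rtol * |b|.
    # Unlike torch.allclose, this sketch requires equal-length sequences (no broadcasting).
    if len(a) != len(b):
        raise ValueError("inputs must have the same length")
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))


max_diff = 1.4901161193847656e-08
print(max_diff == 2.0 ** -26)  # True

# A 2^-26 discrepancy on a typical softmax output value passes easily...
print(allclose([0.0073 + max_diff], [0.0073]))  # True
# ...but against a reference of exactly 0.0, only atol applies, and 2^-26 > 1e-08.
print(allclose([max_diff], [0.0]))  # False
```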