Layer Normalization

In this section, you will use Triton to write a high-performance layer normalization kernel that runs faster than the PyTorch implementation.

Compute Kernel

import pytest
import torch
import triton
import triton.language as tl
import torch_npu

@triton.jit
def _layer_norm_fwd_fused(
    X,  # Pointer to the input
    Y,  # Pointer to the output
    W,  # Pointer to the weights
    B,  # Pointer to the biases
    Mean,  # Pointer to the mean
    Rstd,  # Pointer to the 1/std
    stride,  # Number of elements to be added when the pointer moves by one row
    N,  # Number of columns in X
    eps,  # Epsilon used to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    # Map the program ID to the corresponding rows of X and Y for computation.
    row = tl.program_id(0)
    Y += row * stride
    X += row * stride
    # Calculate the mean.
    mean = 0
    _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
        _mean += a
    mean = tl.sum(_mean, axis=0) / N
    # Calculate the variance.
    _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
        x = tl.where(cols < N, x - mean, 0.)
        _var += x * x
    var = tl.sum(_var, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    # Write mean/rstd.
    tl.store(Mean + row, mean)
    tl.store(Rstd + row, rstd)
    # Normalize and apply linear transformation.
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        w = tl.load(W + cols, mask=mask)
        b = tl.load(B + cols, mask=mask)
        x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
        x_hat = (x - mean) * rstd
        y = x_hat * w + b
        # Write the output.
        tl.store(Y + cols, y, mask=mask)

LayerNorm Implementation Defined by Using Triton

@torch.inference_mode()
def layer_norm(x, weight, bias, eps=1e-5):
    # Allocate the output tensor with the same shape and data type as the input.
    y = torch.empty_like(x)
    
    # Flatten the input x into a two-dimensional shape [-1, feature_dim] for processing the last dimension.
    x_arg = x.reshape(-1, x.shape[-1])
    M, N = x_arg.shape
    
    mean = torch.empty((M, ), dtype=torch.float32, device=x.device)
    rstd = torch.empty((M, ), dtype=torch.float32, device=x.device)

    BLOCK_SIZE = 1024

    # enqueue kernel
    kernel = _layer_norm_fwd_fused[(M,)](# M indicates the number of blocks, and launch grid=(M,)
        x_arg, y, weight, bias, mean, rstd,  # Input, output, and intermediate variables
        x_arg.stride(0), N, eps,
        BLOCK_SIZE=BLOCK_SIZE)
    # Return the normalized output.   
    return y

# Call layer normalization during forward pass.
def _layer_norm(M, N, dtype, eps=1e-5, device='npu'):
    # Construct data.
    x_shape = (M, N)
    w_shape = (x_shape[-1], )
    weight = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True)
    bias = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True)
    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device)
    dy = .1 * torch.randn_like(x)
    x.requires_grad_(True)
    # Forward pass
    y_tri = layer_norm(x, weight, bias, eps)
    y_ref = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype)
    # Determine whether the results are approximate.
    assert torch.allclose(y_tri, y_ref, atol=1e-2, rtol=0)
    print(f"y_tri: {y_tri}")
    print(f"y_ref: {y_ref}")
    print(f"Layer Normalization {M},{N} {dtype} PASSED!")

# Perform the test.
if __name__ == '__main__':
    _layer_norm(128, 128, torch.float16)
    _layer_norm(128, 128, torch.bfloat16)
    _layer_norm(128, 128, torch.float32)

Result

y_tri: tensor([[ 0.2512,  0.0647,  0.8389,  ...,  2.3652,  1.5039,  1.1904],
        [ 1.0908,  1.5391,  0.2269,  ...,  1.6846,  1.0996,  0.9614],
        [-0.2974,  0.5918,  0.3225,  ...,  2.2891, -0.8418,  0.6885],
        ...,
        [ 0.5225, -0.0068,  0.4968,  ..., -1.1221,  1.7422,  0.6143],
        [ 0.4463,  1.2441,  0.2224,  ...,  2.2969, -0.3311,  0.6177],
        [-0.0113,  0.8423,  0.3696,  ...,  1.3838,  1.2471,  0.8750]],
       device='npu:0', dtype=torch.float16)
y_ref: tensor([[ 0.2512,  0.0647,  0.8389,  ...,  2.3652,  1.5039,  1.1904],
        [ 1.0908,  1.5391,  0.2269,  ...,  1.6846,  1.0996,  0.9614],
        [-0.2974,  0.5918,  0.3225,  ...,  2.2891, -0.8418,  0.6885],
        ...,
        [ 0.5225, -0.0068,  0.4968,  ..., -1.1221,  1.7422,  0.6143],
        [ 0.4463,  1.2441,  0.2224,  ...,  2.2969, -0.3311,  0.6177],
        [-0.0113,  0.8423,  0.3696,  ...,  1.3838,  1.2471,  0.8750]],
       device='npu:0', dtype=torch.float16, grad_fn=<NativeLayerNormBackward0>)
Layer Normalization 128,128 torch.float16 PASSED!
y_tri: tensor([[-0.4180,  0.9648,  0.8633,  ...,  0.7656,  0.8438,  0.3633],
        [ 0.4453,  0.5352,  0.9102,  ...,  1.1875, -0.0562,  0.5391],
        [ 1.3125,  0.9961,  0.9219,  ...,  0.9688,  0.0025,  0.5156],
        ...,
        [-0.1426,  0.6289,  0.9609,  ...,  0.9648, -0.1260, -0.1270],
        [ 1.1641,  0.6680,  0.8281,  ...,  0.9258,  0.9062,  0.1768],
        [-0.2129,  0.7109,  0.9141,  ...,  0.7891, -0.0767,  0.5156]],
       device='npu:0', dtype=torch.bfloat16)
y_ref: tensor([[-0.4180,  0.9648,  0.8633,  ...,  0.7656,  0.8438,  0.3633],
        [ 0.4453,  0.5352,  0.9102,  ...,  1.1875, -0.0562,  0.5391],
        [ 1.3125,  0.9961,  0.9219,  ...,  0.9688,  0.0025,  0.5156],
        ...,
        [-0.1426,  0.6289,  0.9609,  ...,  0.9648, -0.1260, -0.1270],
        [ 1.1641,  0.6680,  0.8281,  ...,  0.9258,  0.9062,  0.1768],
        [-0.2129,  0.7109,  0.9141,  ...,  0.7891, -0.0767,  0.5156]],
       device='npu:0', dtype=torch.bfloat16, grad_fn=<NativeLayerNormBackward0>)
Layer Normalization 128,128 torch.bfloat16 PASSED!
y_tri: tensor([[-0.2980,  0.2922,  0.6481,  ...,  0.9786,  0.7304,  0.8982],
        [ 1.5911,  0.0474,  0.6518,  ...,  0.8013,  0.2435,  1.3748],
        [ 1.3024,  0.6265,  0.6473,  ...,  0.8423,  0.0984, -1.1839],
        ...,
        [-0.2195,  0.1359,  0.6461,  ...,  0.8319,  1.0899,  1.5015],
        [ 0.6371,  0.3687,  0.6530,  ...,  0.9359,  0.0818,  0.6499],
        [ 0.1178,  0.3639,  0.6475,  ...,  0.7221,  0.4622,  1.4510]],
       device='npu:0')
y_ref: tensor([[-0.2980,  0.2922,  0.6481,  ...,  0.9786,  0.7304,  0.8982],
        [ 1.5911,  0.0474,  0.6518,  ...,  0.8013,  0.2435,  1.3748],
        [ 1.3024,  0.6265,  0.6473,  ...,  0.8423,  0.0984, -1.1839],
        ...,
        [-0.2195,  0.1359,  0.6461,  ...,  0.8319,  1.0899,  1.5015],
        [ 0.6371,  0.3687,  0.6530,  ...,  0.9359,  0.0818,  0.6499],
        [ 0.1178,  0.3639,  0.6475,  ...,  0.7221,  0.4622,  1.4510]],
       device='npu:0', grad_fn=<NativeLayerNormBackward0>)
Layer Normalization 128,128 torch.float32 PASSED!

"Layer Normalization 128,128 torch.float16 PASSED!",
"Layer Normalization 128,128 torch.bfloat16 PASSED!",
The result "Layer Normalization 128,128 torch.float32 PASSED!" indicates that the output of float16, bfloat16, and float32 data types on Triton is the same as that on PyTorch.