"""
Layer Normalization
=============
"""
import pytest
import torch
import triton
import triton.language as tl
import torch_npu
@triton.jit
def _layer_norm_fwd_fused(
    X,  # input base pointer
    Y,  # output base pointer
    W,  # weight pointer, one value per column
    B,  # bias pointer, one value per column
    Mean,  # per-row mean output pointer
    Rstd,  # per-row reciprocal standard deviation output pointer
    stride,  # offset between the starts of consecutive rows, in elements
    N,  # number of columns normalized in each row
    eps,  # added to the variance before the square root
    BLOCK_SIZE: tl.constexpr,  # columns handled per loop iteration
):
    """Fused layer-norm forward pass; each program instance normalizes one row.

    Program ``row`` reads ``N`` elements starting at ``X + row * stride``,
    accumulates their mean and variance in float32, stores the mean at
    ``Mean + row`` and ``1 / sqrt(var + eps)`` at ``Rstd + row``, and writes
    ``(x - mean) * rstd * w + b`` starting at ``Y + row * stride``.

    The row is walked in ``BLOCK_SIZE``-wide chunks in three passes over X
    (mean, variance, output), so ``N`` may be larger than ``BLOCK_SIZE``.
    Elements within a row are addressed as ``base + cols``: the kernel assumes
    unit stride along the column axis, and X and Y share the same ``stride``.
    """
    row = tl.program_id(0)
    # Advance both base pointers to the start of this program's row.
    Y += row * stride
    X += row * stride
    # Pass 1: mean. This initial value is overwritten right after the loop.
    mean = 0
    _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        # Out-of-range lanes load 0. and therefore add nothing to the sum.
        a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
        _mean += a
    mean = tl.sum(_mean, axis=0) / N
    # Pass 2: biased variance (sum of squared deviations divided by N).
    _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
        # Masked lanes hold 0., and 0. - mean is not zero, so zero them explicitly
        # to keep them out of the variance.
        x = tl.where(cols < N, x - mean, 0.)
        _var += x * x
    var = tl.sum(_var, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    # Save the per-row statistics.
    tl.store(Mean + row, mean)
    tl.store(Rstd + row, rstd)
    # Pass 3: normalize, apply the affine weight/bias, and write the output row.
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        # Masked lanes of w and b are left unspecified; they are never stored.
        w = tl.load(W + cols, mask=mask)
        b = tl.load(B + cols, mask=mask)
        x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
        x_hat = (x - mean) * rstd
        # NOTE(review): x_hat is float32, so y is presumably float32 too; the
        # store relies on tl.store converting it to Y's element type.
        y = x_hat * w + b
        tl.store(Y + cols, y, mask=mask)
@torch.inference_mode()
def layer_norm(x, normalized_shape, weight, bias, eps=1e-5):
    """Apply layer normalization over the last dimension of ``x`` via Triton.

    Args:
        x: input tensor of any rank >= 1; each slice along the last dimension
            is normalized independently.
        normalized_shape: must describe the last dimension of ``x``, either as
            an int ``N`` or a one-element sequence ``(N,)``. Only last-dimension
            normalization is supported by the kernel.
        weight: per-column scale with ``N`` elements.
        bias: per-column shift with ``N`` elements.
        eps: added to the variance before taking the square root.

    Returns:
        A new tensor with the same shape and dtype as ``x``.

    Raises:
        ValueError: if ``normalized_shape`` is not ``(x.shape[-1],)``.
    """
    if isinstance(normalized_shape, int):
        normalized_shape = (normalized_shape,)
    N = x.shape[-1]
    if tuple(normalized_shape) != (N,):
        raise ValueError(
            f"only normalization over the last dimension is supported: "
            f"expected normalized_shape == ({N},), got {tuple(normalized_shape)}"
        )
    # Compute the row count explicitly: reshape(-1, 0) is ambiguous and raises.
    M = x.shape[:-1].numel()
    # The kernel addresses columns as `base + cols` and uses one row stride for
    # both input and output, so the input must be row-contiguous and the output
    # must share its layout (empty_like(x) would copy x's possibly odd strides).
    x_arg = x.reshape(M, N).contiguous()
    y = torch.empty_like(x_arg)
    if y.numel() == 0:
        # Nothing to normalize; avoid launching a zero-size grid / zero block.
        return y.reshape(x.shape)
    weight = weight.contiguous()
    bias = bias.contiguous()
    mean = torch.empty((M, ), dtype=torch.float32, device=x.device)
    rstd = torch.empty((M, ), dtype=torch.float32, device=x.device)
    # Cap a single block at 64 KiB of input elements; the kernel loops over
    # the row in BLOCK_SIZE chunks, so wider rows are still handled.
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
    # One program per row.
    _layer_norm_fwd_fused[(M, )](
        x_arg, y, weight, bias, mean, rstd,
        x_arg.stride(0), N, eps,
        BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, num_ctas=1)
    return y.reshape(x.shape)
def _layer_norm(M, N, dtype, eps=1e-5, device='npu'):
    """Check the Triton ``layer_norm`` against PyTorch on random ``(M, N)`` data.

    Creates random weight, bias and input tensors of ``dtype`` on ``device``,
    runs both implementations, asserts they agree within an absolute tolerance
    of 1e-2, prints both outputs and reports success.

    Raises:
        AssertionError: if any element differs by more than 1e-2; the message
            includes the maximum absolute error.
    """
    x_shape = (M, N)
    w_shape = (x_shape[-1], )
    # Forward-only check: ``layer_norm`` runs under ``torch.inference_mode``,
    # so no tensor needs ``requires_grad`` and no autograd graph is built.
    weight = torch.rand(w_shape, dtype=dtype, device=device)
    bias = torch.rand(w_shape, dtype=dtype, device=device)
    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device)
    y_tri = layer_norm(x, w_shape, weight, bias, eps)
    y_ref = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype)
    # The message expression is evaluated only when the assertion fails.
    assert torch.allclose(y_tri, y_ref, atol=1e-2, rtol=0), (
        f"Layer Normalization {M},{N} {dtype} FAILED: max abs error "
        f"{(y_tri.float() - y_ref.float()).abs().max().item()}"
    )
    print(f"y_tri: {y_tri}")
    print(f"y_ref: {y_ref}")
    print(f"Layer Normalization {M},{N} {dtype} PASSED!")
if __name__ == "__main__":
    # Run the forward-pass check once per input dtype, in a fixed order.
    for _check_dtype in (torch.float16, torch.bfloat16, torch.float32):
        _layer_norm(128, 128, _check_dtype)