"""
Async Copy in Gluon
===================
Modern GPUs provide asynchronous instructions for long-running operations like
global memory reads and writes. Asynchronous operations allow overlapping memory
transactions with compute, also known as "pipelining".
Asynchronous instructions vary by GPU vendor and architecture, so this tutorial
focuses on NVIDIA GPUs. On NVIDIA GPUs, async copies transfer data between
global memory and shared memory, unlike `gl.load` and `gl.store` which
directly write to and read from the register file.
"""
import pytest
import torch
import triton
from triton.experimental import gluon
from triton.experimental.gluon import language as gl
from triton.experimental.gluon.language.nvidia.ampere import async_copy as cp
def is_ampere_or_newer():
    """Return True when the active Triton target is CUDA with compute capability major >= 8."""
    backend = triton.runtime.driver.active.get_current_target().backend
    if backend != "cuda":
        # Only query the CUDA device when the active backend is actually CUDA.
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 8
# When run as a script (not when imported, e.g. by pytest), fail fast on
# hardware that does not pass the Ampere-or-newer check.
if __name__ == "__main__":
    if not is_ampere_or_newer():
        raise RuntimeError("This tutorial requires Ampere or newer NVIDIA GPU")
# 1D memcpy that stages each tile through shared memory with an async copy.
# Each program handles one XBLOCK-sized tile of the `xnumel` elements.
@gluon.jit
def memcpy_1d_cpasync_kernel(in_ptr, out_ptr, xnumel, XBLOCK: gl.constexpr):
    # Register layout for the tile. The warp count (4) is hard-coded here and
    # matches the launcher's default `num_warps=4`.
    reg_layout: gl.constexpr = gl.BlockedLayout([1], [32], [4], [0])

    # Unswizzled shared-memory staging buffer holding one float32 tile.
    staging_layout: gl.constexpr = gl.SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=[0])
    staging = gl.allocate_shared_memory(gl.float32, [XBLOCK], layout=staging_layout)

    # Element indices owned by this program; the tail tile is masked.
    tile_start = gl.program_id(0) * XBLOCK
    idx = tile_start + gl.arange(0, XBLOCK, layout=reg_layout)
    in_bounds = idx < xnumel

    # Issue the async copy, commit it as a group, then wait until no
    # committed group is outstanding before reading the staging buffer.
    cp.async_copy_global_to_shared(staging, in_ptr + idx, mask=in_bounds)
    cp.commit_group()
    cp.wait_group(0)

    # Shared memory -> registers -> output.
    tile = staging.load(reg_layout)
    gl.store(out_ptr + idx, tile, mask=in_bounds)
def memcpy_1d_cpasync(input, output, XBLOCK=8192, num_warps=4):
    """Launch `memcpy_1d_cpasync_kernel` to copy `input` into `output`.

    One program is launched per `XBLOCK` elements (rounded up).

    NOTE(review): the kernel's register layout hard-codes 4 warps, so a
    `num_warps` other than 4 presumably does not match it -- confirm before
    changing the default.
    """
    total = input.numel()
    num_programs = triton.cdiv(total, XBLOCK)
    memcpy_1d_cpasync_kernel[(num_programs, )](input, output, total, XBLOCK, num_warps=num_warps)
@pytest.mark.parametrize("xnumel, XBLOCK", [(200, 128), (1000, 256)])
@pytest.mark.skipif(not is_ampere_or_newer(), reason="Requires Ampere or newer")
def test_memcpy_1d_cpasync(xnumel, XBLOCK):
    """The async-copy memcpy must reproduce its input bit-for-bit."""
    src = torch.randn(xnumel, device="cuda")
    dst = torch.empty_like(src)
    memcpy_1d_cpasync(src, dst, XBLOCK)
    # Zero tolerances: a copy must be exact.
    torch.testing.assert_close(src, dst, atol=0, rtol=0)
# Baseline 2D elementwise add (C = A + B) using plain `gl.load` / `gl.store`.
# Each program owns XBLOCK rows and walks the columns in YBLOCK-wide steps.
@gluon.jit
def elementwise_add_kernel(
    a_ptr, b_ptr, c_ptr, xnumel, ynumel,
    xstride_a, ystride_a, xstride_b, ystride_b, xstride_c, ystride_c,
    XBLOCK: gl.constexpr, YBLOCK: gl.constexpr,
):
    tile_layout: gl.constexpr = gl.BlockedLayout([1, 1], [1, 32], [1, 4], [1, 0])

    # Rows owned by this program, and whether each one is inside the tensor.
    rows = gl.program_id(0) * XBLOCK + gl.arange(0, XBLOCK, gl.SliceLayout(1, tile_layout))
    row_ok = (rows < xnumel)[:, None]

    # Per-row base pointers; the column offset is added inside the loop.
    a_rows = a_ptr + xstride_a * rows[:, None]
    b_rows = b_ptr + xstride_b * rows[:, None]
    c_rows = c_ptr + xstride_c * rows[:, None]

    for col_start in range(0, ynumel, YBLOCK):
        cols = col_start + gl.arange(0, YBLOCK, gl.SliceLayout(0, tile_layout))
        in_bounds = row_ok & (cols < ynumel)[None, :]
        cols_2d = cols[None, :]

        lhs = gl.load(a_rows + ystride_a * cols_2d, mask=in_bounds)
        rhs = gl.load(b_rows + ystride_b * cols_2d, mask=in_bounds)
        gl.store(c_rows + ystride_c * cols_2d, lhs + rhs, mask=in_bounds)
def elementwise_add(A, B, C, XBLOCK=32, YBLOCK=64):
    """Launch the baseline add kernel computing `C = A + B` for 2D tensors.

    All three tensors must share one shape. One program is launched per
    `XBLOCK` rows; the strides of each tensor are forwarded to the kernel.
    Returns whatever the kernel launch returns.
    """
    assert A.shape == B.shape == C.shape
    xnumel, ynumel = A.shape
    strides = (*A.stride(), *B.stride(), *C.stride())
    launch = elementwise_add_kernel[(triton.cdiv(xnumel, XBLOCK), )]
    return launch(A, B, C, xnumel, ynumel, *strides, XBLOCK, YBLOCK)
@pytest.mark.parametrize("xnumel, ynumel", [(1000, 2000)])
@pytest.mark.parametrize("XBLOCK, YBLOCK", [(32, 32), (128, 128)])
def test_elementwise_add(xnumel, ynumel, XBLOCK, YBLOCK):
    """The baseline add kernel must match `a + b` exactly."""
    shape = (xnumel, ynumel)
    a = torch.randn(*shape, device="cuda")
    b = torch.randn(*shape, device="cuda")
    out = torch.empty_like(a, device="cuda")
    elementwise_add(a, b, out, XBLOCK, YBLOCK)
    expected = a + b
    torch.testing.assert_close(expected, out, atol=0, rtol=0)
# 2D elementwise add where both operands are staged through shared memory
# with async copies. Every tile is copied and then waited on immediately, so
# nothing is overlapped yet.
@gluon.jit
def elementwise_add_cpasync_kernel(
    a_ptr, b_ptr, c_ptr, xnumel, ynumel,
    xstride_a, ystride_a, xstride_b, ystride_b, xstride_c, ystride_c,
    XBLOCK: gl.constexpr, YBLOCK: gl.constexpr,
    smem_layout: gl.constexpr,
):
    tile_layout: gl.constexpr = gl.BlockedLayout([1, 1], [1, 32], [1, 4], [1, 0])

    # One shared-memory tile per operand, typed after A's elements.
    elem_ty: gl.constexpr = a_ptr.dtype.element_ty
    a_tile = gl.allocate_shared_memory(elem_ty, [XBLOCK, YBLOCK], layout=smem_layout)
    b_tile = gl.allocate_shared_memory(elem_ty, [XBLOCK, YBLOCK], layout=smem_layout)

    # Rows owned by this program and their per-row base pointers.
    rows = gl.program_id(0) * XBLOCK + gl.arange(0, XBLOCK, gl.SliceLayout(1, tile_layout))
    row_ok = (rows < xnumel)[:, None]
    a_rows = a_ptr + xstride_a * rows[:, None]
    b_rows = b_ptr + xstride_b * rows[:, None]
    c_rows = c_ptr + xstride_c * rows[:, None]

    for col_start in range(0, ynumel, YBLOCK):
        cols = col_start + gl.arange(0, YBLOCK, gl.SliceLayout(0, tile_layout))
        in_bounds = row_ok & (cols < ynumel)[None, :]
        cols_2d = cols[None, :]

        # Both copies go into a single committed group.
        cp.async_copy_global_to_shared(a_tile, a_rows + ystride_a * cols_2d, mask=in_bounds)
        cp.async_copy_global_to_shared(b_tile, b_rows + ystride_b * cols_2d, mask=in_bounds)
        cp.commit_group()
        # Wait until no committed group is outstanding before reading the tiles.
        cp.wait_group(0)

        lhs = a_tile.load(tile_layout)
        rhs = b_tile.load(tile_layout)
        gl.store(c_rows + ystride_c * cols_2d, lhs + rhs, mask=in_bounds)
def elementwise_add_cpasync(A, B, C, smem_layout, XBLOCK=32, YBLOCK=64):
    """Launch the async-copy add kernel computing `C = A + B` for 2D tensors.

    `smem_layout` is the shared-memory layout used for the staging tiles.
    All three tensors must share one shape; one program is launched per
    `XBLOCK` rows. Returns whatever the kernel launch returns.
    """
    assert A.shape == B.shape == C.shape
    xnumel, ynumel = A.shape
    strides = (*A.stride(), *B.stride(), *C.stride())
    launch = elementwise_add_cpasync_kernel[(triton.cdiv(xnumel, XBLOCK), )]
    return launch(A, B, C, xnumel, ynumel, *strides, XBLOCK, YBLOCK, smem_layout)
@pytest.mark.parametrize("xnumel, ynumel", [(1000, 2000)])
@pytest.mark.parametrize("XBLOCK, YBLOCK", [(32, 32), (128, 128)])
@pytest.mark.skipif(not is_ampere_or_newer(), reason="Requires Ampere or newer")
def test_elementwise_add_cpasync(xnumel, ynumel, XBLOCK, YBLOCK):
    """The async-copy add kernel must match `a + b` exactly."""
    shape = (xnumel, ynumel)
    a = torch.randn(*shape, device="cuda")
    b = torch.randn(*shape, device="cuda")
    out = torch.empty_like(a, device="cuda")
    # Unswizzled shared layout with the last dimension fastest-varying.
    shared_layout = gl.SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=[1, 0])
    elementwise_add_cpasync(a, b, out, shared_layout, XBLOCK, YBLOCK)
    expected = a + b
    torch.testing.assert_close(expected, out, atol=0, rtol=0)
def get_throughput(ms, C):
    """Return the memory throughput of one elementwise add, in units of 2**40 bytes per second.

    The add reads two operands and writes one result, so three tensors'
    worth of data are counted, each taken to be the size of `C`.

    Args:
        ms: Elapsed time of one run, in milliseconds.
        C: Output tensor; only `numel()` and `element_size()` are used.

    Returns:
        Bytes moved divided by 2**40, per second (callers print it as "TB/s").

    Raises:
        ZeroDivisionError: if `ms` is 0.
    """
    nbytes = 3 * C.numel() * C.element_size()
    # Use true division: the previous `>> 30` floored the byte count to whole
    # GiB first, so workloads under 1 GiB reported 0 and others lost their
    # fractional part. Exact GiB multiples give the same result as before.
    tbytes = nbytes / (1 << 40)
    return tbytes / (ms * 1e-3)
if __name__ == "__main__":
    print("Benchmarking elementwise_add")
    print("============================")

    # Square 32K x 32K operands, shared by every benchmark in this script.
    xnumel = ynumel = 32 * 1024
    A = torch.randn(xnumel, ynumel, device="cuda")
    B = torch.randn(xnumel, ynumel, device="cuda")
    C = torch.empty_like(A, device="cuda")

    # Baseline: plain global loads and stores.
    def _run_baseline():
        return elementwise_add(A, B, C)

    ms = triton.testing.do_bench(_run_baseline)
    print(f"elementwise_add: {get_throughput(ms, C):.2f} TB/s")

    # Async copies through shared memory, waited on immediately.
    smem_layout = gl.SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=[1, 0])

    def _run_cpasync():
        return elementwise_add_cpasync(A, B, C, smem_layout)

    ms = triton.testing.do_bench(_run_cpasync)
    print(f"elementwise_add_cpasync: {get_throughput(ms, C):.2f} TB/s")
# Start the async copies for tile number `copy_idx` of A and B into shared
# buffer slot `copy_idx % num_buffers`, commit them as one group, and return
# the index of the next tile to copy.
@gluon.jit
def issue_loads(copy_idx, a_smem, b_smem, a_ptrs, ystride_a, b_ptrs, xmask, ynumel, y_idx, ystride_b,
                YBLOCK: gl.constexpr, num_buffers: gl.constexpr):
    slot = copy_idx % num_buffers

    # Columns covered by this tile; anything past `ynumel` is masked off.
    cols = copy_idx * YBLOCK + y_idx
    in_bounds = xmask & (cols < ynumel)[None, :]
    cols_2d = cols[None, :]

    cp.async_copy_global_to_shared(a_smem.index(slot), a_ptrs + ystride_a * cols_2d, in_bounds)
    cp.async_copy_global_to_shared(b_smem.index(slot), b_ptrs + ystride_b * cols_2d, in_bounds)
    cp.commit_group()
    return copy_idx + 1
# Consume tile number `read_idx`: load A and B from shared buffer slot
# `read_idx % num_buffers`, add them, store the masked result to C, and
# return the index of the next tile to read.
@gluon.jit
def perform_add(read_idx, a_smem, b_smem, c_ptrs, ynumel, ystride_c, y_idx, xmask, YBLOCK: gl.constexpr,
                num_buffers: gl.constexpr, layout: gl.constexpr):
    slot = read_idx % num_buffers
    lhs = a_smem.index(slot).load(layout)
    rhs = b_smem.index(slot).load(layout)
    total = lhs + rhs

    # Columns covered by this tile; anything past `ynumel` is masked off.
    cols = read_idx * YBLOCK + y_idx
    in_bounds = xmask & (cols < ynumel)[None, :]
    gl.store(c_ptrs + ystride_c * cols[None, :], total, mask=in_bounds)
    return read_idx + 1
# Software-pipelined 2D elementwise add (C = A + B).
#
# Each program owns XBLOCK rows and walks the columns in YBLOCK-wide tiles.
# Tiles are staged through `num_buffers` shared-memory slots per operand so
# that the async copies for later tiles are already in flight while an
# earlier tile is added and stored. `issue_loads` starts one tile's copies
# and `perform_add` consumes one tile; each returns the incremented index.
@gluon.jit
def elementwise_add_pipelined_kernel(
    a_ptr, b_ptr, c_ptr, xnumel, ynumel,
    xstride_a, ystride_a, xstride_b, ystride_b, xstride_c, ystride_c,
    XBLOCK: gl.constexpr, YBLOCK: gl.constexpr,
    smem_layout: gl.constexpr, num_buffers: gl.constexpr,
):
    pid = gl.program_id(0)
    # Register layout for one tile. NOTE(review): the warp count (4) is
    # hard-coded; the launcher does not pass `num_warps`, so this presumably
    # relies on the launch default -- confirm.
    layout: gl.constexpr = gl.BlockedLayout([1, 1], [1, 32], [1, 4], [1, 0])
    # Rows owned by this program and their per-row base pointers.
    xoffs = pid * XBLOCK + gl.arange(0, XBLOCK, gl.SliceLayout(1, layout))
    a_ptrs = a_ptr + xstride_a * xoffs[:, None]
    b_ptrs = b_ptr + xstride_b * xoffs[:, None]
    c_ptrs = c_ptr + xstride_c * xoffs[:, None]
    # Loop-invariant pieces of the per-tile column offsets and mask.
    y_idx = gl.arange(0, YBLOCK, gl.SliceLayout(0, layout))
    xmask = (xoffs < xnumel)[:, None]
    # `num_buffers` tile slots per operand, typed after A's elements.
    dtype: gl.constexpr = a_ptr.dtype.element_ty
    a_smem = gl.allocate_shared_memory(dtype, [num_buffers, XBLOCK, YBLOCK], layout=smem_layout)
    b_smem = gl.allocate_shared_memory(dtype, [num_buffers, XBLOCK, YBLOCK], layout=smem_layout)
    # Index of the next tile to copy and of the next tile to consume.
    copy_idx = 0
    read_idx = 0
    # Prologue: start num_buffers - 1 tile copies before consuming anything.
    for _ in gl.static_range(num_buffers - 1):
        copy_idx = issue_loads(copy_idx, a_smem, b_smem, a_ptrs, ystride_a, b_ptrs, xmask, ynumel, y_idx, ystride_b,
                               YBLOCK, num_buffers)
    # Steady state: one iteration per remaining tile (total tiles minus the
    # num_buffers - 1 already issued above). Each iteration starts one more
    # copy, then consumes the oldest tile.
    for _ in range(gl.cdiv(ynumel, YBLOCK) - (num_buffers - 1)):
        copy_idx = issue_loads(copy_idx, a_smem, b_smem, a_ptrs, ystride_a, b_ptrs, xmask, ynumel, y_idx, ystride_b,
                               YBLOCK, num_buffers)
        # Allow up to num_buffers - 1 committed groups to remain in flight,
        # so the oldest group (the tile read next) has completed.
        cp.wait_group(num_buffers - 1)
        read_idx = perform_add(read_idx, a_smem, b_smem, c_ptrs, ynumel, ystride_c, y_idx, xmask, YBLOCK, num_buffers,
                               layout)
    # Epilogue: consume the num_buffers - 1 tiles still in flight. The
    # allowed number of outstanding groups drops by one per step, reaching 0
    # on the last step.
    for i in gl.static_range(num_buffers - 1):
        cp.wait_group(num_buffers - 2 - i)
        read_idx = perform_add(read_idx, a_smem, b_smem, c_ptrs, ynumel, ystride_c, y_idx, xmask, YBLOCK, num_buffers,
                               layout)
def elementwise_add_pipelined(A, B, C, XBLOCK=32, YBLOCK=64, num_buffers=2):
    """Launch the pipelined add kernel computing `C = A + B` for 2D tensors.

    `num_buffers` is the number of shared-memory tile slots per operand.
    All three tensors must share one shape; one program is launched per
    `XBLOCK` rows. Returns whatever the kernel launch returns.
    """
    assert A.shape == B.shape == C.shape
    xnumel, ynumel = A.shape
    # Unswizzled shared layout with the last dimension fastest-varying.
    shared_layout = gl.SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=[1, 0])
    strides = (*A.stride(), *B.stride(), *C.stride())
    launch = elementwise_add_pipelined_kernel[(triton.cdiv(xnumel, XBLOCK), )]
    return launch(A, B, C, xnumel, ynumel, *strides, XBLOCK, YBLOCK, shared_layout, num_buffers)
@pytest.mark.parametrize("xnumel, ynumel", [(1000, 2000), (4000, 120)])
@pytest.mark.parametrize("XBLOCK, YBLOCK", [(32, 64)])
@pytest.mark.parametrize("num_buffers", [1, 2, 3])
@pytest.mark.skipif(not is_ampere_or_newer(), reason="Requires Ampere or newer")
def test_elementwise_add_pipelined(xnumel, ynumel, XBLOCK, YBLOCK, num_buffers):
    """The pipelined add kernel must match `a + b` exactly for each buffer count."""
    shape = (xnumel, ynumel)
    a = torch.randn(*shape, device="cuda")
    b = torch.randn(*shape, device="cuda")
    out = torch.empty_like(a, device="cuda")
    elementwise_add_pipelined(a, b, out, XBLOCK, YBLOCK, num_buffers)
    expected = a + b
    torch.testing.assert_close(expected, out, atol=0, rtol=0)
if __name__ == "__main__":
    # Benchmark the pipelined kernel on the A, B, C tensors created by the
    # earlier benchmark section, with two and then three buffers.
    def _run_double_buffered():
        return elementwise_add_pipelined(A, B, C, num_buffers=2)

    def _run_triple_buffered():
        return elementwise_add_pipelined(A, B, C, num_buffers=3)

    ms = triton.testing.do_bench(_run_double_buffered)
    print(f"elementwise_add_pipelined (double buffer): {get_throughput(ms, C):.2f} TB/s")

    ms = triton.testing.do_bench(_run_triple_buffered)
    print(f"elementwise_add_pipelined (triple buffer): {get_throughput(ms, C):.2f} TB/s")