"""
The 5th Generation Tensor Core™
===============================
This tutorial covers the APIs for interacting with Tensor Cores on Blackwell
GPUs. Blackwell Tensor Cores introduce a new memory space called Tensor Memory
that must be used to interact with the async MMA instructions.
In this tutorial, we will cover allocating and interacting with Tensor Memory
and demonstrate how to use the `tcgen05` MMA instructions. We will build a
simple matmul kernel to demonstrate practical uses of the APIs and show an
example of how to pipeline MMA instructions.
"""
import itertools
import pytest
import torch
import triton
from triton.experimental import gluon
from triton.experimental.gluon import language as gl
from triton.experimental.gluon.nvidia.hopper import TensorDescriptor
from triton.experimental.gluon.language.nvidia.blackwell import (
TensorMemoryLayout,
allocate_tensor_memory,
get_tmem_32x32b_reg_layout,
tma,
mbarrier,
tcgen05_mma,
tcgen05_commit,
fence_async_shared,
)
def is_blackwell():
    """Report whether the active Triton target is a CUDA GPU of compute capability 10.x (Blackwell)."""
    current_target = triton.runtime.driver.active.get_current_target()
    if current_target.backend != "cuda":
        return False
    major_version = torch.cuda.get_device_capability()[0]
    return major_version == 10
# Running this file as a script launches real kernels, so refuse to continue
# unless the current device is a Blackwell GPU.
if __name__ == "__main__":
    if not is_blackwell():
        raise RuntimeError("This tutorial requires a Blackwell NVIDIA GPU")
@gluon.jit
def tmem_example_kernel(in_ptr, out_ptr, M: gl.constexpr, N: gl.constexpr, num_warps: gl.constexpr):
    """Copy an M x N tile from `in_ptr` to `out_ptr` by way of Tensor Memory.

    The tile is loaded from global memory into registers, stored into a
    freshly allocated TMEM buffer, loaded back into registers and written to
    `out_ptr`. The output should therefore equal the input exactly. Element
    (m, n) lives at offset ``m * N + n`` in both buffers.

    The kernel does not use ``program_id`` and issues unmasked loads and
    stores, so both buffers must hold at least M * N elements.
    """
    # Each warp's 32 lanes cover 32 consecutive columns of a row; the warps
    # are also arranged along the column dimension.
    global_memory_layout: gl.constexpr = gl.BlockedLayout([1, 1], [1, 32], [1, num_warps], [1, 0])
    offs_m = gl.arange(0, M, gl.SliceLayout(1, global_memory_layout))
    offs_n = gl.arange(0, N, gl.SliceLayout(0, global_memory_layout))
    # Row-major offsets of the whole tile.
    offs = offs_m[:, None] * N + offs_n[None, :]
    input = gl.load(in_ptr + offs)

    # TMEM layout with a 64 x 64 block. `unpacked=True` presumably stores one
    # element per 32-bit TMEM cell (TODO confirm for sub-32-bit dtypes).
    tmem_layout: gl.constexpr = TensorMemoryLayout(
        block=(64, 64),
        unpacked=True,
    )
    tmem = allocate_tensor_memory(
        element_ty=in_ptr.dtype.element_ty,
        shape=[M, N],
        layout=tmem_layout,
    )

    # Register layout used for the TMEM store/load below. Its (64, 64) block
    # mirrors the TensorMemoryLayout block above.
    tmem_reg_layout: gl.constexpr = get_tmem_32x32b_reg_layout(
        M=64,
        N=64,
        shape=[M, N],
        num_warps=num_warps,
    )

    # Registers -> TMEM -> registers, using the TMEM register layout both ways.
    input = gl.convert_layout(input, tmem_reg_layout)
    tmem.store(input)
    output = tmem.load(tmem_reg_layout)

    # Back to the layout used for the global load before storing.
    output = gl.convert_layout(output, global_memory_layout)
    gl.store(out_ptr + offs, output)
@pytest.mark.parametrize("M", [64, 128, 256])
@pytest.mark.parametrize("N", [64, 128])
@pytest.mark.parametrize("num_warps", [4, 8])
@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
def test_tmem_example_kernel(M, N, num_warps):
    """A round trip through Tensor Memory must reproduce the input bit-for-bit."""
    src = torch.randn(M, N, dtype=torch.float32, device="cuda")
    dst = torch.empty_like(src)
    grid = (1, )
    tmem_example_kernel[grid](src, dst, M, N, num_warps=num_warps)
    torch.testing.assert_close(src, dst, atol=0, rtol=0)
@gluon.jit
def small_mma_kernel(a_desc, b_desc, c_desc, d_desc, tmem_block: gl.constexpr,
                     LHS_IN_TMEM: gl.constexpr, USE_COMMIT: gl.constexpr, num_warps: gl.constexpr):
    """Issue a single tcgen05 MMA over whole-tile operands, seeding the accumulator with C.

    A, B and C are each loaded as one TMA block at coordinate [0, 0]. C is
    copied into a TMEM accumulator before the MMA, and the accumulator is
    written to D afterwards. test_small_mma compares D with A @ B + C.

    Args:
        a_desc, b_desc, c_desc, d_desc: TMA descriptors. M and N come from
            d_desc's block shape and K from a_desc's.
        tmem_block: block shape of the accumulator TensorMemoryLayout. It is
            also reused for the LHS TMEM layout when LHS_IN_TMEM is set.
        LHS_IN_TMEM: copy A into TMEM and pass that buffer as the LHS operand
            instead of the shared-memory tile.
        USE_COMMIT: signal MMA completion with an explicit tcgen05_commit
            rather than the `mbarriers=` argument of tcgen05_mma.
        num_warps: warp count, used to build the TMEM register layouts.
    """
    # One barrier tracks the three TMA loads of A, B and C.
    bar = gl.allocate_shared_memory(gl.int64, [1], mbarrier.MBarrierLayout())
    mbarrier.init(bar, count=1)
    a_smem = gl.allocate_shared_memory(a_desc.dtype, a_desc.block_type.shape, a_desc.layout)
    b_smem = gl.allocate_shared_memory(b_desc.dtype, b_desc.block_type.shape, b_desc.layout)
    c_smem = gl.allocate_shared_memory(c_desc.dtype, c_desc.block_type.shape, c_desc.layout)
    # Announce the combined byte count of all three tiles before issuing the copies.
    mbarrier.expect(bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes + c_desc.block_type.nbytes)
    tma.async_copy_global_to_shared(a_desc, [0, 0], bar, a_smem)
    tma.async_copy_global_to_shared(b_desc, [0, 0], bar, b_smem)
    tma.async_copy_global_to_shared(c_desc, [0, 0], bar, c_smem)
    mbarrier.wait(bar, phase=0)
    # Reset the barrier before reusing it to track the MMA. After re-init its
    # next completion is phase 0 again, which is what the later wait expects.
    mbarrier.invalidate(bar)
    mbarrier.init(bar, count=1)
    M: gl.constexpr = d_desc.block_type.shape[0]
    N: gl.constexpr = d_desc.block_type.shape[1]
    K: gl.constexpr = a_desc.block_type.shape[1]

    # Accumulator in TMEM, seeded with C (shared memory -> registers -> TMEM).
    acc_tmem_layout: gl.constexpr = TensorMemoryLayout(tmem_block.value, unpacked=True)
    acc_tmem = allocate_tensor_memory(d_desc.dtype, [M, N], acc_tmem_layout)
    acc_reg_layout: gl.constexpr = get_tmem_32x32b_reg_layout(tmem_block[0], tmem_block[1], [M, N], num_warps)
    acc = c_smem.load(acc_reg_layout)
    acc_tmem.store(acc)

    if LHS_IN_TMEM:
        # Copy A into TMEM using a packed (unpacked=False) layout and use that
        # buffer as the LHS operand.
        lhs_tmem_layout: gl.constexpr = TensorMemoryLayout(tmem_block.value, unpacked=False)
        lhs_tmem = allocate_tensor_memory(a_desc.dtype, [M, K], lhs_tmem_layout)
        # NOTE(review): this register layout uses (M, K) as its block shape,
        # while lhs_tmem_layout uses tmem_block. Confirm the two are compatible
        # for every tested shape (e.g. M > 128).
        lhs_reg_layout: gl.constexpr = get_tmem_32x32b_reg_layout(M, K, [M, K], num_warps)
        lhs = a_smem.load(lhs_reg_layout)
        lhs_tmem.store(lhs)
        a = lhs_tmem
    else:
        a = a_smem

    # tcgen05_mma is asynchronous. Both branches attach `bar` to it (explicit
    # tcgen05_commit vs. the mbarriers= argument). The wait below blocks until
    # the MMA is done, so its operands and accumulator may then be touched.
    if USE_COMMIT:
        tcgen05_mma(a, b_smem, acc_tmem)
        tcgen05_commit(bar)
    else:
        tcgen05_mma(a, b_smem, acc_tmem, mbarriers=[bar], mbarrier_preds=[True])
    mbarrier.wait(bar, phase=0)
    mbarrier.invalidate(bar)

    # Epilogue: TMEM -> registers -> shared memory -> TMA store to D at [0, 0].
    d_smem = gl.allocate_shared_memory(d_desc.dtype, d_desc.block_type.shape, d_desc.layout)
    acc = acc_tmem.load(acc_reg_layout)
    d_smem.store(acc)
    # Fence so the TMA store observes the shared-memory writes above.
    fence_async_shared()
    tma.async_copy_shared_to_global(d_desc, [0, 0], d_smem)
    # Wait for the outstanding TMA store before returning.
    tma.store_wait(pendings=0)
def small_mma(A, B, C, D, tmem_block, LHS_IN_TMEM, USE_COMMIT, num_warps):
    """Launch `small_mma_kernel` as a single program over whole-tensor tiles.

    Every operand is described by one TMA block covering the entire tensor.
    A and B use fp16 NVMMA shared layouts. C and D share a single fp32 layout
    derived from C's shape. test_small_mma checks D against A @ B + C.
    """
    def _full_block_desc(tensor, layout):
        # A single TMA box spanning the whole tensor.
        return TensorDescriptor.from_tensor(tensor, tensor.shape, layout)

    lhs_layout = gl.NVMMASharedLayout.get_default_for(A.shape, gl.float16)
    rhs_layout = gl.NVMMASharedLayout.get_default_for(B.shape, gl.float16)
    acc_layout = gl.NVMMASharedLayout.get_default_for(C.shape, gl.float32)
    descriptors = [
        _full_block_desc(A, lhs_layout),
        _full_block_desc(B, rhs_layout),
        _full_block_desc(C, acc_layout),
        _full_block_desc(D, acc_layout),
    ]
    small_mma_kernel[(1, )](*descriptors, tmem_block, LHS_IN_TMEM, USE_COMMIT, num_warps=num_warps)
@pytest.mark.parametrize("M, N, K", [(128, 128, 128), (64, 128, 128), (64, 256, 256), (256, 64, 64)])
@pytest.mark.parametrize("LHS_IN_TMEM", [False, True])
@pytest.mark.parametrize("USE_COMMIT", [False, True])
@pytest.mark.parametrize("num_warps", [4, 8])
@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
def test_small_mma(M, N, K, LHS_IN_TMEM, USE_COMMIT, num_warps):
    """The single-tile tcgen05 MMA must match torch's A @ B + C."""
    torch.manual_seed(0)
    A = torch.randn(M, K, device="cuda", dtype=torch.float16)
    B = torch.randn(K, N, device="cuda", dtype=torch.float16)
    C = torch.randn(M, N, device="cuda", dtype=torch.float32)
    D = torch.empty_like(C)
    # TMEM block: at most 128 rows, full width along N.
    tmem_block = (min(128, M), N)
    small_mma(A, B, C, D, tmem_block, LHS_IN_TMEM, USE_COMMIT, num_warps)
    expected = A @ B + C
    torch.testing.assert_close(expected, D, atol=1e-3, rtol=1e-1)
@gluon.jit
def blocked_matmul_kernel(a_desc, b_desc, c_desc, TRANSPOSE_B: gl.constexpr, num_warps: gl.constexpr):
    """Compute one BLOCK_M x BLOCK_N tile of C = A @ B (A @ B^T when TRANSPOSE_B).

    The program at (pid_m, pid_n) owns the tile starting at
    (pid_m * BLOCK_M, pid_n * BLOCK_N).

    The K loop is fully serialized. Each step loads the A and B tiles, waits
    for the loads, issues the MMA, and waits for the MMA before the next step
    overwrites the single-buffered tiles.

    Nothing is masked. Partial tiles presumably rely on TMA's out-of-bounds
    handling (zero-fill on load, clipping on store); TODO confirm.

    When TRANSPOSE_B is set, B's descriptor is addressed as [n, k] and the
    shared tile is permuted before being passed to the MMA.
    """
    BLOCK_M: gl.constexpr = c_desc.block_type.shape[0]
    BLOCK_N: gl.constexpr = c_desc.block_type.shape[1]
    BLOCK_K: gl.constexpr = a_desc.block_type.shape[1]
    dtype: gl.constexpr = a_desc.dtype
    K = a_desc.shape[1]  # full (runtime) K extent of A
    pid_m = gl.program_id(axis=0)
    pid_n = gl.program_id(axis=1)
    off_m = pid_m * BLOCK_M
    off_n = pid_n * BLOCK_N
    # Single-buffered operand tiles.
    a_smem = gl.allocate_shared_memory(dtype, a_desc.block_type.shape, a_desc.layout)
    b_smem = gl.allocate_shared_memory(dtype, b_desc.block_type.shape, b_desc.layout)
    tma_bar = gl.allocate_shared_memory(gl.int64, [1], mbarrier.MBarrierLayout())
    mbarrier.init(tma_bar, count=1)
    mma_bar = gl.allocate_shared_memory(gl.int64, [1], mbarrier.MBarrierLayout())
    mbarrier.init(mma_bar, count=1)
    # Each barrier completes exactly once per K step, so one phase bit serves both.
    phase = 0
    tmem_layout: gl.constexpr = TensorMemoryLayout([BLOCK_M, BLOCK_N], unpacked=True)
    acc_tmem = allocate_tensor_memory(gl.float32, [BLOCK_M, BLOCK_N], tmem_layout)
    # acc_tmem is allocated without an initial value, so the first MMA must
    # overwrite it; later MMAs accumulate.
    use_acc = False
    for k in range(0, K, BLOCK_K):
        mbarrier.expect(tma_bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes)
        tma.async_copy_global_to_shared(a_desc, [off_m, k], tma_bar, a_smem)
        tma.async_copy_global_to_shared(b_desc, [off_n, k] if TRANSPOSE_B else [k, off_n], tma_bar, b_smem)
        mbarrier.wait(tma_bar, phase=phase)
        if TRANSPOSE_B:
            b = b_smem.permute((1, 0))
        else:
            b = b_smem
        tcgen05_mma(a_smem, b, acc_tmem, use_acc=use_acc)
        tcgen05_commit(mma_bar)
        # Wait for the MMA before the next iteration's TMA overwrites a_smem/b_smem.
        mbarrier.wait(mma_bar, phase=phase)
        use_acc = True
        phase ^= 1
    mbarrier.invalidate(tma_bar)
    mbarrier.invalidate(mma_bar)
    # Epilogue: fp32 accumulator -> registers -> cast to the input dtype -> TMA store.
    acc_reg_layout: gl.constexpr = get_tmem_32x32b_reg_layout(BLOCK_M, BLOCK_N, [BLOCK_M, BLOCK_N], num_warps)
    acc = acc_tmem.load(acc_reg_layout)
    c_smem = gl.allocate_shared_memory(dtype, c_desc.block_type.shape, c_desc.layout)
    c_smem.store(acc.to(dtype))
    # Fence so the TMA store observes the shared-memory writes above.
    fence_async_shared()
    tma.async_copy_shared_to_global(c_desc, [off_m, off_n], c_smem)
    tma.store_wait(pendings=0)
def blocked_matmul(A, B, C, BLOCK_M, BLOCK_N, BLOCK_K, TRANSPOSE_B, num_warps):
    """Launch `blocked_matmul_kernel` to compute C = A @ B, or A @ B.T when TRANSPOSE_B.

    One program is launched per BLOCK_M x BLOCK_N tile of C. The shared
    layouts are built for fp16, so A, B and C are expected to be fp16.
    With TRANSPOSE_B, B is stored as [N, K] and tiled as [BLOCK_N, BLOCK_K].
    Otherwise B is [K, N] and tiled as [BLOCK_K, BLOCK_N].
    """
    M, N = C.shape
    a_block = [BLOCK_M, BLOCK_K]
    b_block = [BLOCK_N, BLOCK_K] if TRANSPOSE_B else [BLOCK_K, BLOCK_N]
    c_block = [BLOCK_M, BLOCK_N]

    def fp16_desc(tensor, block):
        # TMA descriptor whose shared tile uses the default fp16 NVMMA layout.
        layout = gl.NVMMASharedLayout.get_default_for(block, gl.float16)
        return TensorDescriptor.from_tensor(tensor, block, layout)

    a_desc = fp16_desc(A, a_block)
    b_desc = fp16_desc(B, b_block)
    c_desc = fp16_desc(C, c_block)
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
    blocked_matmul_kernel[grid](a_desc, b_desc, c_desc, TRANSPOSE_B, num_warps=num_warps)
@pytest.mark.parametrize("M, N, K", [(208, 416, 304), (2000, 1000, 2000)])
@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(64, 64, 64), (128, 128, 128)])
@pytest.mark.parametrize("TRANSPOSE_B", [False, True])
@pytest.mark.parametrize("num_warps", [4, 8])
@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
def test_blocked_matmul(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, TRANSPOSE_B, num_warps):
    """The tiled tcgen05 matmul must agree with torch for both B orientations."""
    A = torch.randn(M, K, device="cuda", dtype=torch.float16)
    b_shape = (N, K) if TRANSPOSE_B else (K, N)
    B = torch.randn(b_shape, device="cuda", dtype=torch.float16)
    C = torch.empty(M, N, device="cuda", dtype=torch.float16)
    blocked_matmul(A, B, C, BLOCK_M, BLOCK_N, BLOCK_K, TRANSPOSE_B, num_warps)
    # Undo the storage transpose to get the logical [K, N] right-hand side.
    B_logical = B.T if TRANSPOSE_B else B
    torch.testing.assert_close(A @ B_logical, C, rtol=1e-3, atol=1e-1)
if __name__ == "__main__":
    print("Benchmarking selected configs")
    print("=============================")
    M, N, K = 8192, 8192, 16 * 1024
    A = torch.randn(M, K, device="cuda", dtype=torch.float16)
    B = torch.randn(K, N, device="cuda", dtype=torch.float16)
    C = torch.empty(M, N, device="cuda", dtype=torch.float16)
    # Floating-point work of one M x N x K matmul (one multiply + one add per
    # term). It is the same for every config, so compute it once.
    flops = 2 * M * N * K

    print("BLOCK_M BLOCK_N BLOCK_K num_warps time (ms) tflops/s")
    for BLOCK_MN, BLOCK_K, num_warps in itertools.product([64, 128], [64, 128, 256], [4]):
        # The fp16 A and B tiles take BLOCK_MN * BLOCK_K * 2 bytes each. Skip
        # configs whose operand tiles would exceed 224 KiB of shared memory.
        if (BLOCK_MN * BLOCK_K) * 4 // 1024 > 224:
            continue
        fn = lambda: blocked_matmul(A, B, C, BLOCK_MN, BLOCK_MN, BLOCK_K, False, num_warps)
        ms = triton.testing.do_bench(fn, warmup=100, rep=500)
        tflops_per_sec = flops * 1e-12 / (ms * 1e-3)
        print(f"{BLOCK_MN:>7} {BLOCK_MN:>7} {BLOCK_K:>7} {num_warps:>9} {ms:>9.2f} {tflops_per_sec:>8.2f}")
    print()
@gluon.jit
def get_and_increment(counter):
    """Map a running use counter onto a two-slot ring.

    Returns ``(index, phase, next_counter)``:
    - ``index`` is the slot, ``counter % 2``.
    - ``phase`` is the parity of completed passes over the ring, so it flips
      every two uses.
    - ``next_counter`` is ``counter + 1``.
    """
    slot = counter % 2
    wrap_parity = (counter // 2) & 1
    return slot, wrap_parity, counter + 1
@gluon.jit
def blocked_matmul_pipelined_kernel(a_desc, b_desc, c_desc, num_warps: gl.constexpr):
    """Double-buffered tcgen05 matmul: each program computes 2 * BLOCK_M rows by BLOCK_N columns of C = A @ B.

    The program's rows of A are split into two BLOCK_M halves:
    - U: rows starting at ``off_m``.
    - V: rows starting at ``off_m + BLOCK_M``.
    Both halves share each B tile and accumulate into separate TMEM
    accumulators.

    Operand tiles live in two-slot shared-memory rings:
    - The prologue loads k-steps 0 and 1.
    - Each steady-state iteration issues the MMAs for step i, then refills
      that slot with step i + 2 once those MMAs have finished reading it.
    - The tail issues the MMAs for the last two steps.
    The U half is written back before waiting on the V accumulator.

    Nothing is masked. Steps past the end of K and rows past the end of M
    presumably rely on TMA zero-filling out-of-bounds loads and clipping
    out-of-bounds stores (TODO confirm). Those zero tiles contribute nothing
    to the accumulators.
    """
    BLOCK_M: gl.constexpr = c_desc.block_type.shape[0]
    BLOCK_N: gl.constexpr = c_desc.block_type.shape[1]
    BLOCK_K: gl.constexpr = a_desc.block_type.shape[1]
    dtype: gl.constexpr = a_desc.dtype
    K = a_desc.shape[1]
    pid_m = gl.program_id(axis=0)
    pid_n = gl.program_id(axis=1)
    # Each program owns 2 * BLOCK_M rows of C: the U half, then the V half.
    off_m = pid_m * (2 * BLOCK_M)
    off_n = pid_n * BLOCK_N

    # Two-slot rings for the U rows of A, the V rows of A, and B.
    u_bufs = gl.allocate_shared_memory(dtype, [2] + a_desc.block_type.shape, a_desc.layout)
    v_bufs = gl.allocate_shared_memory(dtype, [2] + a_desc.block_type.shape, a_desc.layout)
    b_bufs = gl.allocate_shared_memory(dtype, [2] + b_desc.block_type.shape, b_desc.layout)

    # One fp32 accumulator per half.
    tmem_layout: gl.constexpr = TensorMemoryLayout([BLOCK_M, BLOCK_N], unpacked=True)
    ub_tmem = allocate_tensor_memory(gl.float32, [BLOCK_M, BLOCK_N], tmem_layout)
    vb_tmem = allocate_tensor_memory(gl.float32, [BLOCK_M, BLOCK_N], tmem_layout)

    # Per-slot barriers:
    # - mma_ub_bars / mma_vb_bars: completion of the U@B and V@B MMAs.
    # - load_ub_bars: TMA completion of the U and B tiles.
    # - load_v_bars: TMA completion of the V tile.
    mma_ub_bars = gl.allocate_shared_memory(gl.int64, [2, 1], mbarrier.MBarrierLayout())
    mma_vb_bars = gl.allocate_shared_memory(gl.int64, [2, 1], mbarrier.MBarrierLayout())
    load_ub_bars = gl.allocate_shared_memory(gl.int64, [2, 1], mbarrier.MBarrierLayout())
    load_v_bars = gl.allocate_shared_memory(gl.int64, [2, 1], mbarrier.MBarrierLayout())
    for i in gl.static_range(2):
        mbarrier.init(mma_ub_bars.index(i), count=1)
        mbarrier.init(mma_vb_bars.index(i), count=1)
        mbarrier.init(load_ub_bars.index(i), count=1)
        mbarrier.init(load_v_bars.index(i), count=1)

    load_counter = 0
    mma_counter = 0
    k = 0
    # ub_tmem/vb_tmem are allocated without an initial value, so the first MMA
    # into each must overwrite it (use_acc=False). These flags become True
    # once that first MMA has been issued.
    ub_acc = False
    vb_acc = False

    # Prologue: load k-step 0.
    load_index, load_phase, load_counter = get_and_increment(load_counter)
    load_ub_bar = load_ub_bars.index(load_index)
    mbarrier.expect(load_ub_bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes)
    tma.async_copy_global_to_shared(a_desc, [off_m, k], load_ub_bar, u_bufs.index(load_index))
    tma.async_copy_global_to_shared(b_desc, [k, off_n], load_ub_bar, b_bufs.index(load_index))
    load_v_bar = load_v_bars.index(load_index)
    mbarrier.expect(load_v_bar, a_desc.block_type.nbytes)
    tma.async_copy_global_to_shared(a_desc, [off_m + BLOCK_M, k], load_v_bar, v_bufs.index(load_index))
    k += BLOCK_K

    # Prologue: load k-step 1.
    load_index, load_phase, load_counter = get_and_increment(load_counter)
    load_ub_bar = load_ub_bars.index(load_index)
    mbarrier.expect(load_ub_bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes)
    tma.async_copy_global_to_shared(a_desc, [off_m, k], load_ub_bar, u_bufs.index(load_index))
    tma.async_copy_global_to_shared(b_desc, [k, off_n], load_ub_bar, b_bufs.index(load_index))
    load_v_bar = load_v_bars.index(load_index)
    mbarrier.expect(load_v_bar, a_desc.block_type.nbytes)
    tma.async_copy_global_to_shared(a_desc, [off_m + BLOCK_M, k], load_v_bar, v_bufs.index(load_index))
    k += BLOCK_K

    # Steady state: consume step i from its slot, then refill that slot with step i + 2.
    for _ in range(gl.cdiv(K, BLOCK_K) - 2):
        mma_index, mma_phase, mma_counter = get_and_increment(mma_counter)
        mbarrier.wait(load_ub_bars.index(mma_index), mma_phase)
        tcgen05_mma(u_bufs.index(mma_index), b_bufs.index(mma_index), ub_tmem, use_acc=ub_acc)
        tcgen05_commit(mma_ub_bars.index(mma_index))
        ub_acc = True
        mbarrier.wait(load_v_bars.index(mma_index), mma_phase)
        tcgen05_mma(v_bufs.index(mma_index), b_bufs.index(mma_index), vb_tmem, use_acc=vb_acc)
        tcgen05_commit(mma_vb_bars.index(mma_index))
        vb_acc = True
        load_index, load_phase, load_counter = get_and_increment(load_counter)
        # The U slot can be refilled once U@B has finished reading it.
        mbarrier.wait(mma_ub_bars.index(mma_index), mma_phase)
        load_ub_bar = load_ub_bars.index(load_index)
        mbarrier.expect(load_ub_bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes)
        tma.async_copy_global_to_shared(a_desc, [off_m, k], load_ub_bar, u_bufs.index(load_index))
        # B is read by both MMAs, so it (and V) are refilled only after V@B is done.
        mbarrier.wait(mma_vb_bars.index(mma_index), mma_phase)
        tma.async_copy_global_to_shared(b_desc, [k, off_n], load_ub_bar, b_bufs.index(load_index))
        load_v_bar = load_v_bars.index(load_index)
        mbarrier.expect(load_v_bar, a_desc.block_type.nbytes)
        tma.async_copy_global_to_shared(a_desc, [off_m + BLOCK_M, k], load_v_bar, v_bufs.index(load_index))
        k += BLOCK_K

    acc_reg_layout: gl.constexpr = get_tmem_32x32b_reg_layout(BLOCK_M, BLOCK_N, [BLOCK_M, BLOCK_N], num_warps)

    # Tail: MMAs for the last two k-steps; no further loads are issued.
    mma_index, mma_phase, mma_counter = get_and_increment(mma_counter)
    # These MMAs are issued without a commit, so the next completion of this
    # slot's MMA barriers has phase `epilogue_phase`. They are reused to track
    # the final MMAs below.
    ub_bar = mma_ub_bars.index(mma_index)
    vb_bar = mma_vb_bars.index(mma_index)
    epilogue_phase = mma_phase
    mbarrier.wait(load_ub_bars.index(mma_index), mma_phase)
    # When K spans at most two BLOCK_K steps, the loop above ran zero times and
    # these are the first MMAs into each accumulator. Honour ub_acc/vb_acc
    # instead of unconditionally accumulating into uninitialized TMEM.
    tcgen05_mma(u_bufs.index(mma_index), b_bufs.index(mma_index), ub_tmem, use_acc=ub_acc)
    mbarrier.wait(load_v_bars.index(mma_index), mma_phase)
    tcgen05_mma(v_bufs.index(mma_index), b_bufs.index(mma_index), vb_tmem, use_acc=vb_acc)
    mma_index, mma_phase, mma_counter = get_and_increment(mma_counter)
    mbarrier.wait(load_ub_bars.index(mma_index), mma_phase)
    tcgen05_mma(u_bufs.index(mma_index), b_bufs.index(mma_index), ub_tmem, use_acc=True)
    tcgen05_commit(ub_bar)
    mbarrier.wait(load_v_bars.index(mma_index), mma_phase)
    tcgen05_mma(v_bufs.index(mma_index), b_bufs.index(mma_index), vb_tmem, use_acc=True)
    tcgen05_commit(vb_bar)

    # Epilogue: write the U half while the last V MMA may still be running.
    mbarrier.wait(ub_bar, epilogue_phase)
    c_smem = gl.allocate_shared_memory(dtype, c_desc.block_type.shape, c_desc.layout)
    ub = ub_tmem.load(acc_reg_layout)
    c_smem.store(ub.to(dtype))
    fence_async_shared()
    tma.async_copy_shared_to_global(c_desc, [off_m, off_n], c_smem)
    mbarrier.wait(vb_bar, epilogue_phase)
    vb = vb_tmem.load(acc_reg_layout)
    # c_smem is reused for the V half, so drain the pending U store first.
    tma.store_wait(pendings=0)
    c_smem.store(vb.to(dtype))
    fence_async_shared()
    tma.async_copy_shared_to_global(c_desc, [off_m + BLOCK_M, off_n], c_smem)
    tma.store_wait(pendings=0)

    # Every load and MMA has been waited on, so all barriers are idle; release them.
    for i in gl.static_range(2):
        mbarrier.invalidate(mma_ub_bars.index(i))
        mbarrier.invalidate(mma_vb_bars.index(i))
        mbarrier.invalidate(load_ub_bars.index(i))
        mbarrier.invalidate(load_v_bars.index(i))


def blocked_matmul_pipelined(A, B, C, BLOCK_M, BLOCK_N, BLOCK_K, num_warps):
    """Compute C = A @ B with `blocked_matmul_pipelined_kernel` (fp16 operands and output).

    Each program covers 2 * BLOCK_M rows of C, so the grid's first dimension
    is ``cdiv(M, 2 * BLOCK_M)``.
    """
    M, N = C.shape
    a_layout = gl.NVMMASharedLayout.get_default_for([BLOCK_M, BLOCK_K], gl.float16)
    b_layout = gl.NVMMASharedLayout.get_default_for([BLOCK_K, BLOCK_N], gl.float16)
    c_layout = gl.NVMMASharedLayout.get_default_for([BLOCK_M, BLOCK_N], gl.float16)
    a_desc = TensorDescriptor.from_tensor(A, [BLOCK_M, BLOCK_K], a_layout)
    b_desc = TensorDescriptor.from_tensor(B, [BLOCK_K, BLOCK_N], b_layout)
    c_desc = TensorDescriptor.from_tensor(C, [BLOCK_M, BLOCK_N], c_layout)
    grid = (triton.cdiv(M, 2 * BLOCK_M), triton.cdiv(N, BLOCK_N))
    blocked_matmul_pipelined_kernel[grid](a_desc, b_desc, c_desc, num_warps=num_warps)


# (256, 256, 256) with BLOCK_K=128 has exactly two K steps, so the steady-state
# loop is skipped and the tail MMAs must initialize the accumulators.
@pytest.mark.parametrize("M, N, K", [(208, 416, 304), (2000, 1000, 2000), (256, 256, 256)])
@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(64, 64, 64), (128, 128, 128)])
@pytest.mark.parametrize("num_warps", [4, 8])
@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
def test_blocked_matmul_pipelined(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, num_warps):
    A = torch.randn(M, K, device="cuda", dtype=torch.float16)
    B = torch.randn(K, N, device="cuda", dtype=torch.float16)
    C = torch.empty(M, N, device="cuda", dtype=torch.float16)
    blocked_matmul_pipelined(A, B, C, BLOCK_M, BLOCK_N, BLOCK_K, num_warps)
    torch.testing.assert_close(A @ B, C, rtol=1e-3, atol=1e-1)
if __name__ == "__main__":
    # Reuses M, N, K and the A, B, C tensors created by the earlier __main__ benchmark.
    print("Benchmarking pipelined matmul")
    print("=============================")
    print("BLOCK_M BLOCK_N BLOCK_K num_warps time (ms) tflops/s")
    total_flops = 2 * M * N * K
    sweep = itertools.product([128], [128], [64, 128], [4, 8])
    for BLOCK_M, BLOCK_N, BLOCK_K, num_warps in sweep:

        def run():
            return blocked_matmul_pipelined(A, B, C, BLOCK_M, BLOCK_N, BLOCK_K, num_warps)

        ms = triton.testing.do_bench(run, warmup=200, rep=1000)
        tflops_per_sec = total_flops * 1e-12 / (ms * 1e-3)
        print(f"{BLOCK_M:>7} {BLOCK_N:>7} {BLOCK_K:>7} {num_warps:>9} {ms:>9.2f} {tflops_per_sec:>8.2f}")
    print()