"""
Warp-Group MMA
==============
Warp-Group MMA (also known as WGMMA or MMAv3) is a Hopper-specific instruction
for performing matrix multiply-accumulate operations using the Tensor Cores.
WGMMA instructions are asynchronous, meaning they can be pipelined.
In this tutorial, we will cover how to use WGMMAs in Gluon. We will build a
simple matmul kernel to demonstrate practical uses of WGMMA, and show an example
where WGMMAs can be pipelined for better performance.
"""
import pytest
import torch
import triton
import itertools
from triton.experimental import gluon
from triton.experimental.gluon import language as gl
from triton.experimental.gluon.nvidia.hopper import TensorDescriptor
from triton.experimental.gluon.language.nvidia.hopper import (
tma,
mbarrier,
fence_async_shared,
warpgroup_mma_init,
warpgroup_mma,
warpgroup_mma_wait,
)
def is_hopper():
    """Return True when the active Triton target is CUDA with compute capability 9.x."""
    backend = triton.runtime.driver.active.get_current_target().backend
    if backend != "cuda":
        # Do not query the CUDA device at all on non-CUDA backends.
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major == 9
# When executed as a script, refuse to continue on anything but Hopper.
if __name__ == "__main__":
    if not is_hopper():
        raise RuntimeError("This tutorial requires a Hopper NVIDIA GPU")
# Single-tile WGMMA demo kernel: loads one A, B and C tile, computes
# D = A @ B + C with one warp-group MMA, and writes D back.
#
# Parameters:
#   a_desc, b_desc, c_desc, d_desc: tensor descriptors for the operands; each
#       descriptor's block shape is used as the full tile shape.
#   LHS_IN_REG: constexpr flag; when true A is loaded into registers before the
#       MMA, otherwise the shared-memory buffer is passed directly.
#   INSTR_SHAPE_N: constexpr N dimension of the MMA instruction shape.
#   num_warps: constexpr warp count, also used as the warp tiling along dim 0.
@gluon.jit
def small_mma_kernel(a_desc, b_desc, c_desc, d_desc,
                     LHS_IN_REG: gl.constexpr, INSTR_SHAPE_N: gl.constexpr, num_warps: gl.constexpr):
    # One mbarrier tracks completion of all three tile loads below.
    bar = gl.allocate_shared_memory(gl.int64, [1], mbarrier.MBarrierLayout())
    mbarrier.init(bar, count=1)
    # Shared-memory staging buffers, shaped and laid out as the descriptors say.
    a_smem = gl.allocate_shared_memory(a_desc.dtype, a_desc.block_type.shape, a_desc.layout)
    b_smem = gl.allocate_shared_memory(b_desc.dtype, b_desc.block_type.shape, b_desc.layout)
    c_smem = gl.allocate_shared_memory(c_desc.dtype, c_desc.block_type.shape, c_desc.layout)
    # The barrier is told to expect the combined byte count of the A, B and C
    # tiles, then the three async copies are issued against it.
    mbarrier.expect(bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes + c_desc.block_type.nbytes)
    tma.async_copy_global_to_shared(a_desc, [0, 0], bar, a_smem)
    tma.async_copy_global_to_shared(b_desc, [0, 0], bar, b_smem)
    tma.async_copy_global_to_shared(c_desc, [0, 0], bar, c_smem)
    # The barrier is used exactly once here, hence the fixed phase 0.
    mbarrier.wait(bar, phase=0)
    mbarrier.invalidate(bar)
    # Instruction shape [m, n, k]: m is fixed at 16, k spans 256 bits of the
    # operand dtype, and n is the caller-selected INSTR_SHAPE_N.
    m: gl.constexpr = 16
    k: gl.constexpr = 256 // a_desc.dtype.primitive_bitwidth
    n: gl.constexpr = INSTR_SHAPE_N
    # All warps are tiled along the first dimension.
    warps_per_cta: gl.constexpr = [num_warps, 1]
    # Register layout of the accumulator; it also carries the instruction shape.
    c_layout: gl.constexpr = gl.NVMMADistributedLayout(
        version=[3, 0],
        warps_per_cta=warps_per_cta,
        instr_shape=[m, n, k],
    )
    # Register layout used for A when it is fed from registers; k_width is the
    # number of operand elements that fit in 32 bits.
    a_reg_layout: gl.constexpr = gl.DotOperandLayout(
        operand_index=0,
        parent=c_layout,
        k_width=32 // a_desc.dtype.primitive_bitwidth,
    )
    # Compile-time check: shared-memory operands must use an NVMMASharedLayout.
    gl.static_assert(isinstance(a_smem.type.layout, gl.NVMMASharedLayout))
    gl.static_assert(isinstance(b_smem.type.layout, gl.NVMMASharedLayout))
    # A goes to the MMA either as a register tensor or as the smem buffer.
    if LHS_IN_REG:
        a = a_smem.load(a_reg_layout)
    else:
        a = a_smem
    # The accumulator input is loaded into registers in the MMA layout.
    c = c_smem.load(c_layout)
    # Issue the WGMMA asynchronously; use_acc=True makes it accumulate onto c.
    d = warpgroup_mma(a, b_smem, c, is_async=True, use_acc=True)
    # Wait until no WGMMA is outstanding. The result is threaded through `deps`
    # and the returned value is what gets used afterwards.
    d = warpgroup_mma_wait(num_outstanding=0, deps=(d, ))
    # Stage the result in shared memory and copy it out through the D descriptor.
    d_smem = gl.allocate_shared_memory(d_desc.dtype, d_desc.block_type.shape, d_desc.layout)
    d_smem.store(d)
    # NOTE(review): presumably orders the smem store above before the async
    # copy reads d_smem - confirm against the Gluon docs.
    fence_async_shared()
    tma.async_copy_shared_to_global(d_desc, [0, 0], d_smem)
    # Wait until no store is pending before the kernel ends.
    tma.store_wait(pendings=0)
def small_mma(A, B, C, D, INSTR_SHAPE_N, LHS_IN_REG=False, num_warps=4):
    """Launch `small_mma_kernel` as a single program over whole-tensor tiles.

    Every tensor is wrapped in a descriptor whose block shape equals the
    tensor's own shape. A and B get fp16 default shared layouts; C and D
    share one fp32 default layout derived from C's shape.
    """
    acc_layout = gl.NVMMASharedLayout.get_default_for(C.shape, gl.float32)
    tensor_layout_pairs = (
        (A, gl.NVMMASharedLayout.get_default_for(A.shape, gl.float16)),
        (B, gl.NVMMASharedLayout.get_default_for(B.shape, gl.float16)),
        (C, acc_layout),
        (D, acc_layout),
    )
    descriptors = [
        TensorDescriptor.from_tensor(tensor, tensor.shape, layout)
        for tensor, layout in tensor_layout_pairs
    ]
    launcher = small_mma_kernel[(1, )]
    launcher(*descriptors, LHS_IN_REG, INSTR_SHAPE_N, num_warps=num_warps)
@pytest.mark.parametrize("M, N, K", [(64, 32, 32), (64, 256, 128)])
@pytest.mark.parametrize("LHS_IN_REG", [False, True])
@pytest.mark.parametrize("INSTR_SHAPE_N", [16, 64])
@pytest.mark.parametrize("num_warps", [4, 8])
@pytest.mark.skipif(not is_hopper(), reason="Requires Hopper")
def test_small_mma(M, N, K, LHS_IN_REG, INSTR_SHAPE_N, num_warps):
    """Check `small_mma` against `A @ B + C` computed by PyTorch."""
    # Skip instruction shapes whose N exceeds what this tile and warp count allow.
    m_reps = triton.cdiv(M, 16)
    n_reps = triton.cdiv(num_warps, m_reps)
    n_limit = max(N // n_reps, 8)
    if INSTR_SHAPE_N > n_limit:
        pytest.skip(f"INSTR_SHAPE_N={INSTR_SHAPE_N} is too large for M={M}, N={N}, num_warps={num_warps}")

    lhs = torch.randn(M, K, device="cuda", dtype=torch.float16)
    rhs = torch.randn(K, N, device="cuda", dtype=torch.float16)
    addend = torch.randn(M, N, device="cuda", dtype=torch.float32)
    out = torch.empty_like(addend)

    small_mma(lhs, rhs, addend, out, INSTR_SHAPE_N, LHS_IN_REG, num_warps)

    expected = lhs @ rhs + addend
    torch.testing.assert_close(expected, out, atol=1e-3, rtol=1e-1)
# Script mode: time `small_mma` for every LHS_IN_REG / INSTR_SHAPE_N pairing.
if __name__ == "__main__":
    print("Benchmarking WGMMA")
    print("==================")
    M, N, K = 64, 128, 128
    num_warps = 4
    A = torch.randn(M, K, device="cuda", dtype=torch.float16)
    B = torch.randn(K, N, device="cuda", dtype=torch.float16)
    C = torch.randn(M, N, device="cuda", dtype=torch.float32)
    D = torch.empty_like(C)
    print("LHS_IN_REG INSTR_SHAPE_N time (us)")
    for LHS_IN_REG in (False, True):
        for INSTR_SHAPE_N in (16, 32, 64, 128):
            # The lambda is consumed immediately, so it sees the current loop values.
            ms = triton.testing.do_bench(lambda: small_mma(A, B, C, D, INSTR_SHAPE_N, LHS_IN_REG, num_warps))
            print(f"{LHS_IN_REG!s:>10} {INSTR_SHAPE_N:>13} {ms*1000:>9.2f}")
    print()
@gluon.constexpr_function
def get_warps_per_cta(BLOCK_M, BLOCK_N, num_warps):
    """Choose how `num_warps` warps are tiled over the accumulator as [warps_m, warps_n].

    Starts from a [4, 1] tiling and doubles one dimension at a time until the
    product equals `num_warps`. The M dimension is doubled only while
    `BLOCK_M` is larger than the rows already covered (16 rows per warp);
    otherwise the N dimension is doubled.

    Args:
        BLOCK_M: tile size along M.
        BLOCK_N: tile size along N. Not used by the computation; kept so the
            signature matches the sibling layout helpers.
        num_warps: total warp count. Must be 4 times a power of two.

    Returns:
        A two-element list `[warps_m, warps_n]` whose product is `num_warps`.

    Raises:
        ValueError: if `num_warps` cannot be reached by doubling from 4.
    """
    warps_per_cta = [4, 1]
    m = 16
    # Doubling from 4 only ever produces 4, 8, 16, ... Stop as soon as the
    # target is met or overshot, so an unreachable target (e.g. 2 or 12) is
    # reported instead of looping forever.
    while warps_per_cta[0] * warps_per_cta[1] < num_warps:
        if BLOCK_M > m * warps_per_cta[0]:
            warps_per_cta[0] *= 2
        else:
            warps_per_cta[1] *= 2
    if warps_per_cta[0] * warps_per_cta[1] != num_warps:
        raise ValueError(f"num_warps must be 4 times a power of two, got {num_warps}")
    return warps_per_cta
@gluon.constexpr_function
def get_instr_shape_n(BLOCK_M, BLOCK_N, num_warps):
    """Pick the largest instruction-shape N that fits the given tile and warp count.

    Candidates are the multiples of 8 from 256 down to 8. A candidate is
    accepted when it does not exceed `max(BLOCK_N // nReps, 8)` and divides
    `BLOCK_N` evenly, where `nReps = cdiv(num_warps, cdiv(BLOCK_M, 16))`.

    Args:
        BLOCK_M: tile size along M.
        BLOCK_N: tile size along N.
        num_warps: total warp count.

    Returns:
        The selected N, a multiple of 8 in the range [8, 256].

    Raises:
        AssertionError: if no candidate divides `BLOCK_N` (for example when
            `BLOCK_N` is not a multiple of 8).
    """
    m = 16
    mReps = triton.cdiv(BLOCK_M, m)
    nReps = triton.cdiv(num_warps, mReps)
    maxN = max(BLOCK_N // nReps, 8)
    # The search stops at 8. Stepping below that would reach n == 0 and make
    # `BLOCK_N % n` raise ZeroDivisionError instead of the intended diagnostic.
    for n in range(256, 7, -8):
        if n <= maxN and BLOCK_N % n == 0:
            return n
    # Raised explicitly so the check survives `python -O`.
    raise AssertionError("expected to find a valid n")
@gluon.constexpr_function
def pick_wgmma_layout(dtype, BLOCK_M, BLOCK_N, num_warps):
    """Build the version-3 `NVMMADistributedLayout` for a BLOCK_M x BLOCK_N accumulator.

    The instruction shape is `[16, n, 256 // bitwidth(dtype)]` with `n` chosen
    by `get_instr_shape_n`; the warp tiling comes from `get_warps_per_cta`.
    """
    instr_shape = [
        16,
        get_instr_shape_n(BLOCK_M, BLOCK_N, num_warps),
        256 // dtype.primitive_bitwidth,
    ]
    return gl.NVMMADistributedLayout(
        version=[3, 0],
        warps_per_cta=get_warps_per_cta(BLOCK_M, BLOCK_N, num_warps),
        instr_shape=instr_shape,
    )
# Blocked matmul kernel: each program computes one BLOCK_M x BLOCK_N tile of C
# by looping over K in BLOCK_K steps, loading an A and a B tile per step and
# accumulating with a WGMMA that is waited on immediately (no overlap).
#
# Parameters:
#   a_desc, b_desc, c_desc: tensor descriptors; block sizes are read from them.
#   TRANSPOSE_B: constexpr flag; when true B tiles are fetched at [off_n, k]
#       and the shared-memory view is permuted before the MMA.
#   num_warps: constexpr warp count used to pick the accumulator layout.
@gluon.jit
def blocked_matmul_kernel(a_desc, b_desc, c_desc,
                          TRANSPOSE_B: gl.constexpr, num_warps: gl.constexpr):
    # Tile sizes come from the descriptors: M and N from C's block, K from A's.
    BLOCK_M: gl.constexpr = c_desc.block_type.shape[0]
    BLOCK_N: gl.constexpr = c_desc.block_type.shape[1]
    BLOCK_K: gl.constexpr = a_desc.block_type.shape[1]
    dtype: gl.constexpr = a_desc.dtype
    # Full reduction length, read from A's second dimension.
    K = a_desc.shape[1]
    # One staging buffer each for A and B, reused on every K step.
    a_smem = gl.allocate_shared_memory(dtype, a_desc.block_type.shape, a_desc.layout)
    b_smem = gl.allocate_shared_memory(dtype, b_desc.block_type.shape, b_desc.layout)
    # 2D launch grid: axis 0 walks M tiles, axis 1 walks N tiles.
    pid_m = gl.program_id(axis=0)
    pid_n = gl.program_id(axis=1)
    off_m = pid_m * BLOCK_M
    off_n = pid_n * BLOCK_N
    # fp32 accumulator, zero-initialised in the layout chosen for this tile.
    mma_layout: gl.constexpr = pick_wgmma_layout(dtype, BLOCK_M, BLOCK_N, num_warps)
    acc = gl.zeros((BLOCK_M, BLOCK_N), dtype=gl.float32, layout=mma_layout)
    bar = gl.allocate_shared_memory(gl.int64, [1], mbarrier.MBarrierLayout())
    mbarrier.init(bar, count=1)
    # The same barrier is reused every iteration, so the waited phase alternates.
    phase = 0
    for k in range(0, K, BLOCK_K):
        # Expect the bytes of one A tile plus one B tile, then issue both loads.
        mbarrier.expect(bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes)
        tma.async_copy_global_to_shared(a_desc, [off_m, k], bar, a_smem)
        # B's coordinates are swapped when it is stored transposed.
        if TRANSPOSE_B:
            tma.async_copy_global_to_shared(b_desc, [off_n, k], bar, b_smem)
        else:
            tma.async_copy_global_to_shared(b_desc, [k, off_n], bar, b_smem)
        mbarrier.wait(bar, phase=phase)
        phase ^= 1
        # For transposed B, hand the MMA a permuted view of the buffer.
        if TRANSPOSE_B:
            b = b_smem.permute((1, 0))
        else:
            b = b_smem
        # Issue the WGMMA and wait for it right away, so the next iteration
        # can safely overwrite a_smem and b_smem.
        acc = warpgroup_mma(a_smem, b, acc, is_async=True)
        acc = warpgroup_mma_wait(num_outstanding=0, deps=(acc, ))
    mbarrier.invalidate(bar)
    # Convert the accumulator to the operand dtype, stage it in shared memory,
    # and store it at this program's tile offset.
    c_smem = gl.allocate_shared_memory(dtype, c_desc.block_type.shape, c_desc.layout)
    c_smem.store(acc.to(dtype))
    # NOTE(review): presumably orders the smem store above before the async
    # copy reads c_smem - confirm against the Gluon docs.
    fence_async_shared()
    tma.async_copy_shared_to_global(c_desc, [off_m, off_n], c_smem)
    tma.store_wait(pendings=0)
def blocked_matmul(A, B, C, BLOCK_M, BLOCK_N, BLOCK_K, TRANSPOSE_B, num_warps):
    """Launch `blocked_matmul_kernel` over a grid of BLOCK_M x BLOCK_N output tiles.

    Descriptors for A, B and C are built with fp16 default shared layouts.
    When `TRANSPOSE_B` is true, B is tiled as [BLOCK_N, BLOCK_K] instead of
    [BLOCK_K, BLOCK_N].
    """

    def describe(tensor, block_shape):
        # Wrap a tensor in a descriptor using the default fp16 shared layout.
        layout = gl.NVMMASharedLayout.get_default_for(block_shape, gl.float16)
        return TensorDescriptor.from_tensor(tensor, block_shape, layout)

    if TRANSPOSE_B:
        b_block_shape = [BLOCK_N, BLOCK_K]
    else:
        b_block_shape = [BLOCK_K, BLOCK_N]

    a_desc = describe(A, [BLOCK_M, BLOCK_K])
    b_desc = describe(B, b_block_shape)
    c_desc = describe(C, [BLOCK_M, BLOCK_N])

    rows, cols = C.shape
    grid = (triton.cdiv(rows, BLOCK_M), triton.cdiv(cols, BLOCK_N))
    blocked_matmul_kernel[grid](a_desc, b_desc, c_desc, TRANSPOSE_B, num_warps=num_warps)
@pytest.mark.parametrize("M, N, K", [(208, 416, 304), (2000, 1000, 2000)])
@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(64, 64, 64), (128, 128, 128)])
@pytest.mark.parametrize("TRANSPOSE_B", [False, True])
@pytest.mark.parametrize("num_warps", [4, 8])
@pytest.mark.skipif(not is_hopper(), reason="Requires Hopper")
def test_blocked_matmul(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, TRANSPOSE_B, num_warps):
    """Compare `blocked_matmul` with a PyTorch matmul for both B orientations."""
    rhs_shape = (N, K) if TRANSPOSE_B else (K, N)
    lhs = torch.randn(M, K, device="cuda", dtype=torch.float16)
    rhs = torch.randn(rhs_shape, device="cuda", dtype=torch.float16)
    out = torch.empty(M, N, device="cuda", dtype=torch.float16)

    blocked_matmul(lhs, rhs, out, BLOCK_M, BLOCK_N, BLOCK_K, TRANSPOSE_B, num_warps)

    if TRANSPOSE_B:
        expected = lhs @ rhs.T
    else:
        expected = lhs @ rhs
    torch.testing.assert_close(expected, out, rtol=1e-3, atol=1e-1)
def find_configs(occupancy, dtype, num_buffers=1):
    """Enumerate promising matmul block configs for a target occupancy.

    Every combination of BLOCK_M/N/K in {32, 64, 128, 256} and num_warps in
    {4, 8} is checked against a shared-memory budget and a register budget.
    Of the survivors, only those with the largest (and, if present, second
    largest) instruction-shape N are kept, and within each of those groups
    only the configs with the largest BLOCK_M * BLOCK_K.

    Args:
        occupancy: number of programs assumed to share the budgets; both
            budgets are divided by it.
        dtype: torch dtype of the operands, used for its element size.
        num_buffers: number of A/B staging buffers the kernel allocates.

    Returns:
        A list of `(BLOCK_M, BLOCK_N, BLOCK_K, num_warps, instr_shape_n,
        occupancy)` tuples. The list is empty when nothing fits the budgets.
    """
    dtype_bytes = torch.tensor([], dtype=dtype).element_size()
    # 228 KiB split across `occupancy`, minus 1 KiB of headroom
    # (assumed to cover barriers and similar overhead - confirm).
    smem = 228 * 1024 // occupancy - 1024
    configs = []
    BLOCK_MNK = [32, 64, 128, 256]
    for BLOCK_M, BLOCK_N, BLOCK_K, num_warps in itertools.product(BLOCK_MNK, BLOCK_MNK, BLOCK_MNK, [4, 8]):
        # 64K registers split across `occupancy`, minus 16 per thread
        # (assumed baseline usage - confirm).
        regs = 64 * 1024 // occupancy - 16 * num_warps * 32
        a_smem = BLOCK_M * BLOCK_K * dtype_bytes
        b_smem = BLOCK_N * BLOCK_K * dtype_bytes
        acc_smem = BLOCK_M * BLOCK_N * dtype_bytes
        # The A/B buffers and the output staging buffer are budgeted as
        # alternatives (max), not as a sum.
        if max((a_smem + b_smem) * num_buffers, acc_smem) > smem:
            continue
        # One register per accumulator element.
        acc_regs = BLOCK_M * BLOCK_N
        # Reject configs needing 256 or more accumulator registers per thread.
        if acc_regs // num_warps // 32 >= 256:
            continue
        if acc_regs > regs:
            continue
        instr_shape_n = get_instr_shape_n(BLOCK_M, BLOCK_N, num_warps).value
        configs.append((BLOCK_M, BLOCK_N, BLOCK_K, num_warps, instr_shape_n, occupancy))

    # Nothing fits the budgets (e.g. a very high occupancy): return an empty
    # list rather than failing on `top_instr_shape_n[0]` below.
    if not configs:
        return []

    def filter_configs(configs, instr_shape_n):
        # Keep configs with this instruction N and the largest BLOCK_M * BLOCK_K.
        max_n_configs = [cfg for cfg in configs if cfg[4] == instr_shape_n]
        max_block_mk = max(cfg[0] * cfg[2] for cfg in max_n_configs)
        return [cfg for cfg in max_n_configs if cfg[0] * cfg[2] == max_block_mk]

    top_instr_shape_n = sorted({cfg[4] for cfg in configs}, reverse=True)
    result_configs = filter_configs(configs, top_instr_shape_n[0])
    if len(top_instr_shape_n) > 1:
        result_configs += filter_configs(configs, top_instr_shape_n[1])
    return result_configs
# Script mode: benchmark `blocked_matmul` on the configs chosen for occupancy 1 and 2.
if __name__ == "__main__":
    print("Benchmarking selected configs")
    print("=============================")
    configs = [cfg for occ in (1, 2) for cfg in find_configs(occupancy=occ, dtype=torch.float16)]
    M, N, K = 8192, 8192, 16 * 1024
    A = torch.randn(M, K, device="cuda", dtype=torch.float16)
    B = torch.randn(K, N, device="cuda", dtype=torch.float16)
    C = torch.empty(M, N, device="cuda", dtype=torch.float16)
    print("BLOCK_M BLOCK_N BLOCK_K num_warps instr_shape_n occupancy time (ms) tflops/s")
    for cfg in configs:
        BLOCK_M, BLOCK_N, BLOCK_K, num_warps, instr_shape_n, occupancy = cfg
        # The lambda is consumed immediately, so it sees the current config.
        ms = triton.testing.do_bench(lambda: blocked_matmul(A, B, C, BLOCK_M, BLOCK_N, BLOCK_K, False, num_warps))
        # A matmul performs 2 * M * N * K floating-point operations.
        flops = 2 * M * N * K
        tflops_per_sec = flops * 1e-12 / (ms * 1e-3)
        print(f"{BLOCK_M:>7} {BLOCK_N:>7} {BLOCK_K:>7} {num_warps:>9} {instr_shape_n:>13} "
              f"{occupancy:>9} {ms:>9.2f} {tflops_per_sec:>8.2f}")
    print()
# Pipelined blocked matmul kernel: same tiling as the non-pipelined version,
# but A and B are double-buffered and the wait for each WGMMA is deferred to
# the next iteration, after that iteration's tile loads have completed.
#
# Parameters:
#   a_desc, b_desc, c_desc: tensor descriptors; block sizes are read from them.
#   num_warps: constexpr warp count used to pick the accumulator layout.
@gluon.jit
def blocked_matmul_pipelined_kernel(a_desc, b_desc, c_desc, num_warps: gl.constexpr):
    # Tile sizes come from the descriptors: M and N from C's block, K from A's.
    BLOCK_M: gl.constexpr = c_desc.block_type.shape[0]
    BLOCK_N: gl.constexpr = c_desc.block_type.shape[1]
    BLOCK_K: gl.constexpr = a_desc.block_type.shape[1]
    dtype: gl.constexpr = a_desc.dtype
    # Full reduction length, read from A's second dimension.
    K = a_desc.shape[1]
    # Two staging buffers per operand (leading dimension of size 2).
    a_smem = gl.allocate_shared_memory(dtype, [2] + a_desc.block_type.shape, a_desc.layout)
    b_smem = gl.allocate_shared_memory(dtype, [2] + b_desc.block_type.shape, b_desc.layout)
    # Which of the two buffers the current iteration fills and consumes.
    index = 0
    pid_m = gl.program_id(axis=0)
    pid_n = gl.program_id(axis=1)
    off_m = pid_m * BLOCK_M
    off_n = pid_n * BLOCK_N
    mma_layout: gl.constexpr = pick_wgmma_layout(dtype, BLOCK_M, BLOCK_N, num_warps)
    # NOTE(review): warpgroup_mma_init presumably wraps the zero accumulator so
    # it can be passed to warpgroup_mma_wait before any WGMMA has been issued -
    # confirm against the Gluon docs.
    acc = warpgroup_mma_init(gl.zeros((BLOCK_M, BLOCK_N), dtype=gl.float32, layout=mma_layout))
    bar = gl.allocate_shared_memory(gl.int64, [1], mbarrier.MBarrierLayout())
    mbarrier.init(bar, count=1)
    # The same barrier is reused every iteration, so the waited phase alternates.
    phase = 0
    for k in range(0, K, BLOCK_K):
        # Select this iteration's buffer pair.
        a = a_smem.index(index)
        b = b_smem.index(index)
        # Load the next A and B tiles. The WGMMA issued in the previous
        # iteration has not been waited on yet, and it reads the other pair.
        mbarrier.expect(bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes)
        tma.async_copy_global_to_shared(a_desc, [off_m, k], bar, a)
        tma.async_copy_global_to_shared(b_desc, [k, off_n], bar, b)
        mbarrier.wait(bar, phase=phase)
        phase ^= 1
        # Only now wait for the previous WGMMA, then issue this iteration's
        # WGMMA asynchronously and leave it in flight.
        acc = warpgroup_mma_wait(num_outstanding=0, deps=(acc, ))
        acc = warpgroup_mma(a, b, acc, is_async=True)
        # Flip to the other buffer pair for the next iteration's loads.
        index ^= 1
    # Drain the WGMMA issued by the final iteration.
    acc = warpgroup_mma_wait(num_outstanding=0, deps=(acc, ))
    mbarrier.invalidate(bar)
    # Convert the accumulator to the operand dtype, stage it in shared memory,
    # and store it at this program's tile offset.
    c_smem = gl.allocate_shared_memory(dtype, c_desc.block_type.shape, c_desc.layout)
    c_smem.store(acc.to(dtype))
    # NOTE(review): presumably orders the smem store above before the async
    # copy reads c_smem - confirm against the Gluon docs.
    fence_async_shared()
    tma.async_copy_shared_to_global(c_desc, [off_m, off_n], c_smem)
    tma.store_wait(pendings=0)
def blocked_matmul_pipelined(A, B, C, BLOCK_M, BLOCK_N, BLOCK_K, num_warps):
    """Launch `blocked_matmul_pipelined_kernel` over a grid of BLOCK_M x BLOCK_N output tiles.

    A, B and C are wrapped in descriptors with fp16 default shared layouts,
    tiled as [BLOCK_M, BLOCK_K], [BLOCK_K, BLOCK_N] and [BLOCK_M, BLOCK_N].
    """
    block_shapes = {
        "a": [BLOCK_M, BLOCK_K],
        "b": [BLOCK_K, BLOCK_N],
        "c": [BLOCK_M, BLOCK_N],
    }
    layouts = {
        key: gl.NVMMASharedLayout.get_default_for(shape, gl.float16)
        for key, shape in block_shapes.items()
    }
    a_desc = TensorDescriptor.from_tensor(A, block_shapes["a"], layouts["a"])
    b_desc = TensorDescriptor.from_tensor(B, block_shapes["b"], layouts["b"])
    c_desc = TensorDescriptor.from_tensor(C, block_shapes["c"], layouts["c"])

    rows, cols = C.shape
    grid = (triton.cdiv(rows, BLOCK_M), triton.cdiv(cols, BLOCK_N))
    blocked_matmul_pipelined_kernel[grid](a_desc, b_desc, c_desc, num_warps=num_warps)
@pytest.mark.parametrize("M, N, K", [(208, 416, 304), (2000, 1000, 2000)])
@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(64, 64, 64), (128, 128, 128)])
@pytest.mark.parametrize("num_warps", [4, 8])
@pytest.mark.skipif(not is_hopper(), reason="Requires Hopper")
def test_blocked_matmul_pipelined(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, num_warps):
    """Compare `blocked_matmul_pipelined` with a PyTorch matmul."""
    lhs = torch.randn(M, K, device="cuda", dtype=torch.float16)
    rhs = torch.randn(K, N, device="cuda", dtype=torch.float16)
    out = torch.empty(M, N, device="cuda", dtype=torch.float16)

    blocked_matmul_pipelined(lhs, rhs, out, BLOCK_M, BLOCK_N, BLOCK_K, num_warps)

    expected = lhs @ rhs
    torch.testing.assert_close(expected, out, rtol=1e-3, atol=1e-1)
# Script mode: benchmark the pipelined matmul on configs sized for two A/B
# buffers, plus one hand-picked config. Reuses M, N, K, A, B and C that the
# earlier script section defined at module level.
if __name__ == "__main__":
    print("Benchmarking pipelined matmul")
    print("=============================")
    configs = [
        *find_configs(occupancy=1, dtype=torch.float16, num_buffers=2),
        *find_configs(occupancy=2, dtype=torch.float16, num_buffers=2),
        [64, 256, 128, 4, 256, 2],
    ]
    print("BLOCK_M BLOCK_N BLOCK_K num_warps instr_shape_n occupancy time (ms) tflops/s")
    for cfg in configs:
        BLOCK_M, BLOCK_N, BLOCK_K, num_warps, instr_shape_n, occupancy = cfg
        # The lambda is consumed immediately, so it sees the current config.
        ms = triton.testing.do_bench(lambda: blocked_matmul_pipelined(A, B, C, BLOCK_M, BLOCK_N, BLOCK_K, num_warps))
        flops = 2 * M * N * K
        tflops_per_sec = flops * 1e-12 / (ms * 1e-3)
        print(f"{BLOCK_M:>7} {BLOCK_N:>7} {BLOCK_K:>7} {num_warps:>9} {instr_shape_n:>13} "
              f"{occupancy:>9} {ms:>9.2f} {tflops_per_sec:>8.2f}")
    print()