"""
Persistent Matmul
=====================
This script demonstrates persistent kernel implementations of matrix multiplication using Triton.
Various matmul methods are included, such as naive, persistent, and TMA (Tensor Memory Accelerator) based approaches.
The kernels support both FP16 and FP8 data types but the FP8 implementation is only available on CUDA devices with compute capability >= 9.0.
Triton and cuBLAS implementations are benchmarked under different configurations and evaluated using the proton profiler.
Users can pass command-line arguments to specify matrix dimensions and iteration steps flexibly.
.. code-block:: bash
# FP8
python 09-persistent-matmul.py --prec fp8 --K_range 128 1024 --K_step 128
# FP16
python 09-persistent-matmul.py --prec fp16 --K_range 128 1024 --K_step 128
Note that currently this tutorial will fail on devices with a small shared memory size, such as RTX-4090.
"""
import argparse
import time
import torch
import triton
import triton.language as tl
import triton.tools.experimental_descriptor
import triton.profiler as proton
# cuBLASLt handle used as the vendor baseline; it only exists when a CUDA device is present.
cublas = None
if torch.cuda.is_available():
    from triton._C.libtriton import nvidia

    # 32 MiB scratch workspace handed to cuBLASLt.
    cublas_workspace = torch.empty(32 * 1024 * 1024, device="cuda", dtype=torch.uint8)
    cublas = nvidia.cublas.CublasLt(cublas_workspace)
def is_cuda():
    """Return True when the active Triton driver targets the CUDA backend."""
    target = triton.runtime.driver.active.get_current_target()
    return target.backend == "cuda"
def supports_tma():
    """Return True on CUDA when the current device's compute-capability major version is >= 9."""
    if not is_cuda():
        return False
    major, _ = torch.cuda.get_device_capability()
    return major >= 9
def _matmul_launch_metadata(grid, kernel, args):
ret = {}
M, N, K = args["M"], args["N"], args["K"]
ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}]"
if "c_ptr" in args:
bytes_per_elem = args["c_ptr"].element_size()
else:
bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2
ret[f"flops{bytes_per_elem * 8}"] = 2. * M * N * K
ret["bytes"] = bytes_per_elem * (M * K + N * K + M * N)
return ret
# Newer Triton builds expose ``tl.nv_tma_desc_type``; its presence selects how TMA
# descriptors are passed to kernels (see TmaAutoTuneHelper).
HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl)
print("TMA benchmarks will be running with experimental grid constant TMA descriptor."
      if HAS_TMA_DESC else "TMA benchmarks will be running without grid constant TMA descriptor.")
class TmaAutoTuneHelper:
    """Own named TMA descriptor buffers whose contents depend on autotuned block sizes.

    Buffers are allocated up front and filled later (e.g. from a launch-grid callback) once
    the block shape is known. When ``HAS_TMA_DESC`` is true, the buffers are CPU tensors
    handed to the kernel wrapped in :class:`KernelParamWrapper`. Otherwise they are CUDA
    tensors, filled through a pinned host staging buffer and copied with
    ``non_blocking=True``.
    """

    class KernelParamWrapper:
        """Kernel-argument wrapper exposing the host address of a CPU descriptor buffer."""

        def __init__(self, desc):
            self.desc = desc

        def tma_desc_cpu_ptr(self):
            """Return ``data_ptr()`` of the wrapped descriptor tensor."""
            return self.desc.data_ptr()

    # Number of int8 elements (i.e. bytes) reserved per descriptor.
    TMA_SIZE = 128

    def __init__(self):
        utils = triton.runtime.driver.active.utils
        self.fill_1d_tma_descriptor_inner = utils.fill_1d_tma_descriptor
        self.fill_2d_tma_descriptor_inner = utils.fill_2d_tma_descriptor
        if HAS_TMA_DESC:
            self.descriptors = {}
        else:
            self.cuda_descriptors = {}

    def init_tma_descriptor(self, name):
        """Allocate an uninitialised descriptor buffer registered under ``name``."""
        on_host = HAS_TMA_DESC
        registry = self.descriptors if on_host else self.cuda_descriptors
        registry[name] = torch.empty(TmaAutoTuneHelper.TMA_SIZE, device="cpu" if on_host else "cuda",
                                     dtype=torch.int8)

    def _fill_descriptor(self, name, write):
        """Have ``write(host_ptr)`` encode descriptor ``name`` in host memory, then publish it."""
        if not HAS_TMA_DESC:
            device_desc = self.cuda_descriptors[name]
            staging = torch.empty_like(device_desc, device="cpu", pin_memory=True)
            write(staging.data_ptr())
            device_desc.copy_(staging, non_blocking=True)
            return
        host_desc = self.descriptors[name]
        # NOTE(review): presumably the descriptor alignment TMA requires — confirm.
        assert host_desc.data_ptr() % 64 == 0
        write(host_desc.data_ptr())

    def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size):
        """Encode a 1-D descriptor of extent ``dim`` at ``ptr`` with block size ``block_dim``."""
        self._fill_descriptor(
            name, lambda dst: self.fill_1d_tma_descriptor_inner(ptr, dim, block_dim, element_size, dst))

    def fill_2d_tma_descriptor(self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size):
        """Encode a 2-D descriptor with extents (dim1, dim0) and block shape (block_dim1, block_dim0)."""
        self._fill_descriptor(
            name,
            lambda dst: self.fill_2d_tma_descriptor_inner(ptr, dim1, dim0, block_dim1, block_dim0, element_size, dst),
        )

    def get_tma_descriptor_kernel_param(self, name):
        """Return the object to pass as the kernel argument for descriptor ``name``."""
        if not HAS_TMA_DESC:
            assert self.cuda_descriptors[name] is not None
            return self.cuda_descriptors[name]
        assert self.descriptors[name] is not None
        return self.KernelParamWrapper(self.descriptors[name])
@triton.jit(launch_metadata=_matmul_launch_metadata)
def matmul_kernel(a_ptr, b_ptr, c_ptr,
                  M, N, K,
                  stride_am, stride_ak,
                  stride_bk, stride_bn,
                  stride_cm, stride_cn,
                  BLOCK_SIZE_M: tl.constexpr,
                  BLOCK_SIZE_N: tl.constexpr,
                  BLOCK_SIZE_K: tl.constexpr,
                  GROUP_SIZE_M: tl.constexpr,
                  ):
    """Non-persistent tiled matmul ``C = A @ B``: each program computes one output tile.

    A is addressed as (M, K) via ``stride_am``/``stride_ak``, B as (K, N) via
    ``stride_bk``/``stride_bn`` and C as (M, N) via ``stride_cm``/``stride_cn``.
    Accumulation is in fp32; the tile is cast to fp8 e4m3 when C is fp8, else to fp16.
    """
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    # Grouped ordering: consecutive program ids walk down GROUP_SIZE_M tile-rows before
    # moving to the next tile-column (presumably for L2 reuse). The last group may be shorter.
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    start_m = pid_m * BLOCK_SIZE_M
    start_n = pid_n * BLOCK_SIZE_N

    offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N)
    # Clamp out-of-range rows/cols to 0 so loads stay in bounds; they are masked off at the store.
    offs_am = tl.where(offs_am < M, offs_am, 0)
    offs_bn = tl.where(offs_bn < N, offs_bn, 0)
    # Contiguity/divisibility hints for the compiler.
    offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
    offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Mask the K tail; masked elements load as 0 and add nothing to the dot.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        accumulator = tl.dot(a, b, accumulator)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Output precision follows C's element type (fp8 e4m3, otherwise fp16).
    if (c_ptr.dtype.element_ty == tl.float8e4nv):
        c = accumulator.to(tl.float8e4nv)
    else:
        c = accumulator.to(tl.float16)

    # Recompute unclamped offsets so the mask drops rows/cols beyond M/N.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
def matmul(a, b):
configs = {
torch.float8_e4m3fn: {
"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8, "num_stages": 4,
"num_warps": 8
}, torch.float16: {
"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8, "num_stages": 3,
"num_warps": 8
}
}
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
assert a.dtype == b.dtype, "Incompatible dtypes"
M, K = a.shape
K, N = b.shape
dtype = a.dtype
c = torch.empty((M, N), device=a.device, dtype=dtype)
grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), )
matmul_kernel[grid](
a, b, c,
M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
BLOCK_SIZE_M=configs[dtype]["BLOCK_SIZE_M"],
BLOCK_SIZE_N=configs[dtype]["BLOCK_SIZE_N"],
BLOCK_SIZE_K=configs[dtype]["BLOCK_SIZE_K"],
GROUP_SIZE_M=configs[dtype]["GROUP_SIZE_M"],
num_stages=configs[dtype]["num_stages"],
num_warps=configs[dtype]["num_warps"],
)
return c
@triton.jit(launch_metadata=_matmul_launch_metadata)
def matmul_kernel_persistent(a_ptr, b_ptr, c_ptr,
                             M, N, K,
                             stride_am, stride_ak,
                             stride_bk, stride_bn,
                             stride_cm, stride_cn,
                             BLOCK_SIZE_M: tl.constexpr,
                             BLOCK_SIZE_N: tl.constexpr,
                             BLOCK_SIZE_K: tl.constexpr,
                             GROUP_SIZE_M: tl.constexpr,
                             NUM_SMS: tl.constexpr,
                             ):
    """Persistent matmul ``C = A @ B`` where each program loops over several output tiles.

    Program ``start_pid`` processes tiles start_pid, start_pid + NUM_SMS, ... . The split
    assumes the launch grid has exactly min(NUM_SMS, num_tiles) programs; a smaller grid
    would leave tiles unprocessed. Pointer/stride conventions and the fp32-accumulate then
    fp8-or-fp16 output cast are the same as in the non-persistent kernel.
    """
    start_pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    num_tiles = num_pid_m * num_pid_n

    # The first num_tiles % NUM_SMS programs own one extra tile.
    tiles_per_SM = num_tiles // NUM_SMS
    if start_pid < num_tiles % NUM_SMS:
        tiles_per_SM += 1

    # Start one stride early: the first ki == 0 step advances tile_id to start_pid.
    tile_id = start_pid - NUM_SMS
    ki = -1

    offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K)

    num_pid_in_group = GROUP_SIZE_M * num_pid_n

    # Loop-carried tile coordinates; defined before the loop and overwritten at each ki == 0.
    pid_m = 0
    pid_n = 0
    offs_am = tl.arange(0, BLOCK_SIZE_M)
    offs_bn = tl.arange(0, BLOCK_SIZE_N)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # Tile loop and K loop flattened into one: ki cycles 0..k_tiles-1 for every owned tile.
    for _ in range(0, k_tiles * tiles_per_SM):
        ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
        if ki == 0:
            # New tile: advance and map the linear tile id to (pid_m, pid_n) with grouped ordering.
            tile_id += NUM_SMS
            group_id = tile_id // num_pid_in_group
            first_pid_m = group_id * GROUP_SIZE_M
            group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
            pid_m = first_pid_m + (tile_id % group_size_m)
            pid_n = (tile_id % num_pid_in_group) // group_size_m

            start_m = pid_m * BLOCK_SIZE_M
            start_n = pid_n * BLOCK_SIZE_N
            offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)
            offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N)
            # Clamp out-of-range rows/cols to 0 for in-bounds loads; masked at the store.
            offs_am = tl.where(offs_am < M, offs_am, 0)
            offs_bn = tl.where(offs_bn < N, offs_bn, 0)
            offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
            offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)
        offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
        a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
        b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

        # Mask the K tail of the last k step; masked elements load as 0.
        a = tl.load(a_ptrs, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0)
        accumulator = tl.dot(a, b, accumulator)

        if ki == k_tiles - 1:
            # Last k step of this tile: cast, store with an M/N bounds mask, reset the accumulator.
            offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
            offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
            c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
            c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
            if (c_ptr.dtype.element_ty == tl.float8e4nv):
                c = accumulator.to(tl.float8e4nv)
            else:
                c = accumulator.to(tl.float16)
            tl.store(c_ptrs, c, mask=c_mask)
            accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
def matmul_persistent(a, b):
configs = {
torch.float8_e4m3fn: {
"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8, "num_stages": 4,
"num_warps": 8
}, torch.float16: {
"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8, "num_stages": 3,
"num_warps": 8
}
}
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
assert a.dtype == b.dtype, "Incompatible dtypes"
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
M, K = a.shape
K, N = b.shape
dtype = a.dtype
c = torch.empty((M, N), device=a.device, dtype=dtype)
grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])), )
matmul_kernel_persistent[grid](
a, b, c,
M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
BLOCK_SIZE_M=configs[dtype]["BLOCK_SIZE_M"],
BLOCK_SIZE_N=configs[dtype]["BLOCK_SIZE_N"],
BLOCK_SIZE_K=configs[dtype]["BLOCK_SIZE_K"],
GROUP_SIZE_M=configs[dtype]["GROUP_SIZE_M"],
NUM_SMS=NUM_SMS,
num_stages=configs[dtype]["num_stages"],
num_warps=configs[dtype]["num_warps"],
)
return c
@triton.jit(launch_metadata=_matmul_launch_metadata)
def matmul_kernel_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr,
                                 M, N, K,
                                 BLOCK_SIZE_M: tl.constexpr,
                                 BLOCK_SIZE_N: tl.constexpr,
                                 BLOCK_SIZE_K: tl.constexpr,
                                 GROUP_SIZE_M: tl.constexpr,
                                 FP8_OUTPUT: tl.constexpr,
                                 NUM_SMS: tl.constexpr):
    """Persistent matmul ``C = A @ B.T`` with operands accessed through TMA descriptors.

    ``a_desc_ptr`` is read in [BLOCK_SIZE_M, BLOCK_SIZE_K] blocks at (row, k) offsets.
    ``b_desc_ptr`` is read in [BLOCK_SIZE_N, BLOCK_SIZE_K] blocks at (col, k) offsets and
    transposed in the dot. ``c_desc_ptr`` receives [BLOCK_SIZE_M, BLOCK_SIZE_N] tiles.
    ``FP8_OUTPUT`` selects fp8 e4m3 vs fp16 as both the load element type and the output
    type. Program ``start_pid`` processes tiles start_pid, start_pid + NUM_SMS, ... .
    """
    dtype = tl.float8e4nv if FP8_OUTPUT else tl.float16
    start_pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    num_tiles = num_pid_m * num_pid_n

    # The first num_tiles % NUM_SMS programs own one extra tile.
    tiles_per_SM = num_tiles // NUM_SMS
    if start_pid < num_tiles % NUM_SMS:
        tiles_per_SM += 1

    # Start one stride early: the first ki == 0 step advances tile_id to start_pid.
    tile_id = start_pid - NUM_SMS
    ki = -1

    # Loop-carried tile coordinates (scalars here); overwritten at each ki == 0 step.
    pid_m = 0
    pid_n = 0
    offs_am = 0
    offs_bn = 0

    num_pid_in_group = GROUP_SIZE_M * num_pid_n

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # Tile loop and K loop flattened into one: ki cycles 0..k_tiles-1 for every owned tile.
    for _ in range(0, k_tiles * tiles_per_SM):
        ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
        if ki == 0:
            # New tile: advance and map the linear tile id to (pid_m, pid_n) with grouped ordering.
            tile_id += NUM_SMS
            group_id = tile_id // num_pid_in_group
            first_pid_m = group_id * GROUP_SIZE_M
            group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
            pid_m = first_pid_m + (tile_id % group_size_m)
            pid_n = (tile_id % num_pid_in_group) // group_size_m

            offs_am = pid_m * BLOCK_SIZE_M
            offs_bn = pid_n * BLOCK_SIZE_N

        offs_k = ki * BLOCK_SIZE_K

        # No explicit masks: boundary handling is left to the descriptors
        # (NOTE(review): presumably TMA zero-fills OOB loads and clips OOB stores — confirm).
        a = tl._experimental_descriptor_load(a_desc_ptr, [offs_am, offs_k], [BLOCK_SIZE_M, BLOCK_SIZE_K], dtype)
        b = tl._experimental_descriptor_load(b_desc_ptr, [offs_bn, offs_k], [BLOCK_SIZE_N, BLOCK_SIZE_K], dtype)
        accumulator = tl.dot(a, b.T, accumulator)

        if ki == k_tiles - 1:
            # Last k step of this tile: cast, store through the C descriptor, reset the accumulator.
            c = accumulator.to(dtype)

            tl._experimental_descriptor_store(c_desc_ptr, c, [offs_am, offs_bn])
            accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
def matmul_tma_persistent(a, b):
configs = {
torch.float8_e4m3fn: {
"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8, "num_stages": 4,
"num_warps": 8
}, torch.float16: {
"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8, "num_stages": 3,
"num_warps": 8
}
}
assert a.shape[1] == b.shape[1], "Incompatible dimensions"
assert a.dtype == b.dtype, "Incompatible dtypes"
M, K = a.shape
N, K = b.shape
dtype = a.dtype
c = torch.empty((M, N), device=a.device, dtype=dtype)
desc_a = triton.tools.experimental_descriptor.create_2d_tma_descriptor(a.data_ptr(), M, K,
configs[dtype]["BLOCK_SIZE_M"],
configs[dtype]["BLOCK_SIZE_K"],
a.element_size())
desc_b = triton.tools.experimental_descriptor.create_2d_tma_descriptor(b.data_ptr(), N, K,
configs[dtype]["BLOCK_SIZE_N"],
configs[dtype]["BLOCK_SIZE_K"],
b.element_size())
desc_c = triton.tools.experimental_descriptor.create_2d_tma_descriptor(c.data_ptr(), M, N,
configs[dtype]["BLOCK_SIZE_M"],
configs[dtype]["BLOCK_SIZE_N"],
c.element_size())
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])), )
matmul_kernel_tma_persistent[grid](
desc_a, desc_b, desc_c,
M, N, K,
BLOCK_SIZE_M=configs[dtype]["BLOCK_SIZE_M"],
BLOCK_SIZE_N=configs[dtype]["BLOCK_SIZE_N"],
BLOCK_SIZE_K=configs[dtype]["BLOCK_SIZE_K"],
GROUP_SIZE_M=configs[dtype]["GROUP_SIZE_M"],
FP8_OUTPUT=dtype == torch.float8_e4m3fn,
NUM_SMS=NUM_SMS,
num_stages=configs[dtype]["num_stages"],
num_warps=configs[dtype]["num_warps"],
)
return c
@triton.jit(launch_metadata=_matmul_launch_metadata)
def matmul_kernel_device_tma_persistent(workspace_ptr,
                                        tiles_per_update: tl.constexpr,
                                        a_ptr, b_ptr, c_ptr,
                                        M, N, K,
                                        BLOCK_SIZE_M: tl.constexpr,
                                        BLOCK_SIZE_N: tl.constexpr,
                                        BLOCK_SIZE_K: tl.constexpr,
                                        GROUP_SIZE_M: tl.constexpr,
                                        NUM_SMS: tl.constexpr):
    """Persistent TMA matmul ``C = A @ B.T`` that builds its TMA descriptors on the device.

    Each program writes descriptors for A (M x K), B (N x K) and C (M x N) into its own
    slice of ``workspace_ptr``, then runs the same flattened tile/K loop as the host-descriptor
    TMA kernel. Every ``tiles_per_update`` tiles the descriptors are rebuilt with identical
    arguments. Loads and the output use C's element type.
    """
    dtype = c_ptr.dtype.element_ty
    start_pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    num_tiles = num_pid_m * num_pid_n

    # Per-program slice: three consecutive TMA_SIZE-element slots holding the A, B and C
    # descriptors. Assumes workspace_ptr holds >= (grid size) * 3 * TMA_SIZE elements.
    TMA_SIZE: tl.constexpr = 128
    workspace_base = workspace_ptr + start_pid * 3 * TMA_SIZE
    a_desc_ptr = workspace_base
    b_desc_ptr = workspace_base + TMA_SIZE
    c_desc_ptr = workspace_base + 2 * TMA_SIZE

    tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=a_desc_ptr, global_address=a_ptr,
                                                         load_size=[BLOCK_SIZE_M, BLOCK_SIZE_K], global_size=[M, K],
                                                         element_ty=a_ptr.dtype.element_ty)
    tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=b_desc_ptr, global_address=b_ptr,
                                                         load_size=[BLOCK_SIZE_N, BLOCK_SIZE_K], global_size=[N, K],
                                                         element_ty=b_ptr.dtype.element_ty)
    tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=c_desc_ptr, global_address=c_ptr,
                                                         load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N], global_size=[M, N],
                                                         element_ty=c_ptr.dtype.element_ty)
    # Tensormap-proxy acquire fences on the freshly written descriptors; NOTE(review):
    # presumably required before the TMA loads/stores below may use them — confirm.
    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)

    # The first num_tiles % NUM_SMS programs own one extra tile.
    tiles_per_SM = num_tiles // NUM_SMS
    if start_pid < num_tiles % NUM_SMS:
        tiles_per_SM += 1

    # Start one stride early: the first ki == 0 step advances tile_id to start_pid.
    tile_id = start_pid - NUM_SMS
    ki = -1
    # Tiles since the last descriptor rebuild (reaches 0 on this program's first tile).
    ni = -1

    # Loop-carried tile coordinates; overwritten at each ki == 0 step.
    pid_m = 0
    pid_n = 0
    offs_am = 0
    offs_bn = 0

    num_pid_in_group = GROUP_SIZE_M * num_pid_n

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # Tile loop and K loop flattened into one: ki cycles 0..k_tiles-1 for every owned tile.
    for _ in range(0, k_tiles * tiles_per_SM):
        ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
        if ki == 0:
            ni += 1

            # Rebuild the descriptors before this program's tiles number tiles_per_update,
            # 2 * tiles_per_update, ... (0-based). The arguments are unchanged, so this
            # presumably simulates the per-problem updates of a grouped GEMM.
            if ni == tiles_per_update:
                tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=a_desc_ptr, global_address=a_ptr,
                                                                     load_size=[BLOCK_SIZE_M,
                                                                                BLOCK_SIZE_K], global_size=[M, K],
                                                                     element_ty=a_ptr.dtype.element_ty)
                tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=b_desc_ptr, global_address=b_ptr,
                                                                     load_size=[BLOCK_SIZE_N,
                                                                                BLOCK_SIZE_K], global_size=[N, K],
                                                                     element_ty=b_ptr.dtype.element_ty)
                tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=c_desc_ptr, global_address=c_ptr,
                                                                     load_size=[BLOCK_SIZE_M,
                                                                                BLOCK_SIZE_N], global_size=[M, N],
                                                                     element_ty=c_ptr.dtype.element_ty)
                tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
                tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
                tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)

                ni = 0

            # New tile: advance and map the linear tile id to (pid_m, pid_n) with grouped ordering.
            tile_id += NUM_SMS
            group_id = tile_id // num_pid_in_group
            first_pid_m = group_id * GROUP_SIZE_M
            group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
            pid_m = first_pid_m + (tile_id % group_size_m)
            pid_n = (tile_id % num_pid_in_group) // group_size_m

            offs_am = pid_m * BLOCK_SIZE_M
            offs_bn = pid_n * BLOCK_SIZE_N

        offs_k = ki * BLOCK_SIZE_K

        # B is read as (N, K) blocks and transposed in the dot.
        a = tl._experimental_descriptor_load(a_desc_ptr, [offs_am, offs_k], [BLOCK_SIZE_M, BLOCK_SIZE_K], dtype)
        b = tl._experimental_descriptor_load(b_desc_ptr, [offs_bn, offs_k], [BLOCK_SIZE_N, BLOCK_SIZE_K], dtype)
        accumulator = tl.dot(a, b.T, accumulator)

        if ki == k_tiles - 1:
            # Last k step of this tile: cast, store through the C descriptor, reset the accumulator.
            c = accumulator.to(dtype)

            tl._experimental_descriptor_store(c_desc_ptr, c, [offs_am, offs_bn])

            accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
def matmul_device_tma_persistent(a, b, tiles_per_update):
configs = {
torch.float8_e4m3fn: {
"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8, "num_stages": 4,
"num_warps": 8
}, torch.float16: {
"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8, "num_stages": 3,
"num_warps": 8
}
}
assert a.shape[1] == b.shape[1], "Incompatible dimensions"
assert a.dtype == b.dtype, "Incompatible dtypes"
M, K = a.shape
N, K = b.shape
dtype = a.dtype
c = torch.empty((M, N), device=a.device, dtype=dtype)
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
tma_size = 128
workspace = torch.empty(NUM_SMS * 3 * tma_size, dtype=torch.uint8, device="cuda")
grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])), )
matmul_kernel_device_tma_persistent[grid](
workspace,
tiles_per_update,
a, b, c,
M, N, K,
BLOCK_SIZE_M=configs[dtype]["BLOCK_SIZE_M"],
BLOCK_SIZE_N=configs[dtype]["BLOCK_SIZE_N"],
BLOCK_SIZE_K=configs[dtype]["BLOCK_SIZE_K"],
GROUP_SIZE_M=configs[dtype]["GROUP_SIZE_M"],
NUM_SMS=NUM_SMS,
num_stages=configs[dtype]["num_stages"],
num_warps=configs[dtype]["num_warps"],
)
return c
@triton.autotune(
    # Two candidate configs. num_consumer_groups / num_buffers_warp_spec are
    # warp-specialization knobs of triton.Config; NOTE(review): not every Triton release
    # accepts them — confirm against the build in use. NUM_CONSUMER_GROUPS is a separate
    # kernel meta-parameter: the kernel body never reads it, but the host-side grid function
    # divides BLOCK_SIZE_M by it when sizing the A and C descriptors.
    configs=[
        triton.Config(
            {
                "BLOCK_SIZE_M": 128,
                "BLOCK_SIZE_N": 256,
                "BLOCK_SIZE_K": 64,
                "GROUP_SIZE_M": 8,
                "NUM_CONSUMER_GROUPS": 2,
            },
            num_stages=2,
            num_warps=4,
            num_consumer_groups=2,
            num_buffers_warp_spec=3,
        ),
        triton.Config(
            {
                "BLOCK_SIZE_M": 64,
                "BLOCK_SIZE_N": 64,
                "BLOCK_SIZE_K": 128,
                "GROUP_SIZE_M": 8,
                "NUM_CONSUMER_GROUPS": 1,
            },
            num_stages=3,
            num_warps=4,
            num_consumer_groups=0,
            num_buffers_warp_spec=3,
        ),
    ],
    # Re-tune whenever the problem shape changes.
    key=["M", "N", "K"],
    # Benchmark candidate configs using CUDA graphs.
    use_cuda_graph=True,
)
@triton.jit(launch_metadata=_matmul_launch_metadata)
def matmul_persistent_tma_ws_cooperative_kernel(
    a_desc_ptr,
    b_desc_ptr,
    c_desc_ptr,
    M,
    N,
    K,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    FP8_OUTPUT: tl.constexpr,
    NUM_CONSUMER_GROUPS: tl.constexpr,
):
    """Persistent TMA matmul ``C = A @ B.T``, autotuned over warp-specialization settings.

    Programs stride over output tiles (ids program_id, program_id + num_programs, ...). For
    each tile, the full K reduction runs in an inner loop and the result is stored through
    ``c_desc_ptr``. B is read as (N, K) blocks and transposed in the dot. ``FP8_OUTPUT``
    selects fp8 e4m3 vs fp16 for loads and the output.
    """
    dtype = tl.float8e4nv if FP8_OUTPUT else tl.float16
    num_tiles = tl.cdiv(M, BLOCK_SIZE_M) * tl.cdiv(N, BLOCK_SIZE_N)

    for pid in range(tl.program_id(0), num_tiles, tl.num_programs(0)):
        # Grouped (pid_m, pid_n) mapping; the extra "% num_pid_in_group" differs from the other
        # persistent kernels' form but still visits every tile of the group exactly once.
        num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
        num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m

        offs_am = pid_m * BLOCK_SIZE_M
        offs_bn = pid_n * BLOCK_SIZE_N
        offs_k = 0

        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
            a = tl._experimental_descriptor_load(
                a_desc_ptr,
                [offs_am, offs_k],
                [BLOCK_SIZE_M, BLOCK_SIZE_K],
                dtype,
            )
            b = tl._experimental_descriptor_load(b_desc_ptr, [offs_bn, offs_k], [BLOCK_SIZE_N, BLOCK_SIZE_K], dtype)
            accumulator = tl.dot(a, b.T, accumulator)
            offs_k += BLOCK_SIZE_K

        c = accumulator.to(dtype)
        tl._experimental_descriptor_store(c_desc_ptr, c, [offs_am, offs_bn])
def matmul_persistent_tma_ws_cooperative(a, b):
    """Compute ``a @ b.T`` with the autotuned warp-specialized TMA persistent kernel.

    Descriptors are (re)filled inside the grid callback, because their block shapes depend
    on the config the autotuner is currently trying.

    Args:
        a: (M, K) tensor.
        b: (N, K) tensor with the same dtype as ``a``.

    Returns:
        A new (M, N) tensor of ``a.dtype``.
    """
    assert a.shape[1] == b.shape[1], "Incompatible dimensions"
    assert a.dtype == b.dtype, "Incompatible dtypes"
    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
    M, K = a.shape
    N, K = b.shape
    dtype = a.dtype
    c = torch.empty((M, N), device=a.device, dtype=dtype)

    desc_helper = TmaAutoTuneHelper()
    for tensor_name in ("a", "b", "c"):
        desc_helper.init_tma_descriptor(tensor_name)

    def grid(meta):
        # A and C descriptor rows are split across consumer groups.
        rows_per_group = meta["BLOCK_SIZE_M"] // meta["NUM_CONSUMER_GROUPS"]
        desc_helper.fill_2d_tma_descriptor("a", a.data_ptr(), M, K, rows_per_group, meta["BLOCK_SIZE_K"],
                                           a.element_size())
        desc_helper.fill_2d_tma_descriptor("b", b.data_ptr(), N, K, meta["BLOCK_SIZE_N"], meta["BLOCK_SIZE_K"],
                                           b.element_size())
        desc_helper.fill_2d_tma_descriptor("c", c.data_ptr(), M, N, rows_per_group, meta["BLOCK_SIZE_N"],
                                           c.element_size())
        tile_count = triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"])
        return (min(NUM_SMS, tile_count), )

    desc_a, desc_b, desc_c = (desc_helper.get_tma_descriptor_kernel_param(n) for n in ("a", "b", "c"))

    matmul_persistent_tma_ws_cooperative_kernel[grid](
        desc_a, desc_b, desc_c,
        M, N, K,
        FP8_OUTPUT=dtype == torch.float8_e4m3fn,
    )
    return c
def cublas_matmul(a, b):
    """Compute the (M, N) product of ``a`` (M, K) and ``b`` (N, K) with cuBLASLt.

    The call is wrapped in a proton scope that records bytes moved and the flop count
    (keyed ``flops<bits>`` by the input element width).

    Returns:
        A new (M, N) tensor of ``a.dtype``.

    Raises:
        AssertionError: if the K dimensions or the dtypes of ``a`` and ``b`` differ.
        RuntimeError: if no cuBLASLt handle was created (no CUDA device).
    """
    assert a.shape[1] == b.shape[1], "Incompatible dimensions"
    # Same dtype contract as the Triton wrappers: a single dtype is used for inputs and output.
    assert a.dtype == b.dtype, "Incompatible dtypes"
    if cublas is None:
        # Fail clearly instead of with "'NoneType' object has no attribute 'matmul'".
        raise RuntimeError("cuBLASLt is unavailable: cublas_matmul requires a CUDA device")
    M, K = a.shape
    N, K = b.shape
    dtype = a.dtype
    c = torch.empty((M, N), device=a.device, dtype=dtype)
    bytes_per_elem = a.element_size()
    flops_str = f"flops{bytes_per_elem * 8}"
    with proton.scope(f"cublas [M={M}, N={N}, K={K}]",
                      {"bytes": bytes_per_elem * (M * K + N * K + M * N), flops_str: 2. * M * N * K}):
        cublas.matmul(a, b, c)
    return c
def torch_matmul(a, b):
    """Reference ``a @ b.T`` via ``torch.matmul`` for ``a`` (M, K) and ``b`` (N, K).

    The call runs inside a proton scope that records bytes moved and the flop count
    (keyed ``flops<bits>`` by the input element width).
    """
    M, K = a.shape
    N, K = b.shape
    width = a.element_size()
    scope_name = f"torch [M={M}, N={N}, K={K}]"
    scope_metrics = {"bytes": width * (M * K + N * K + M * N), f"flops{width * 8}": 2. * M * N * K}
    with proton.scope(scope_name, scope_metrics):
        result = torch.matmul(a, b.T)
    return result
def bench(K, dtype, tiles_per_update, reps=10):
    """Profile every available matmul variant on an 8192 x 8192 x K problem.

    Each variant is launched ``reps`` times with a 10 ms pause between launches while proton
    session 0 is active. cuBLAS runs only when a cuBLASLt handle exists, ``torch.matmul``
    only for fp16, and the TMA-based kernels only when ``supports_tma()`` is true.
    """
    M = N = 8192
    a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype)
    b = torch.randn((K, N), device="cuda", dtype=torch.float16).to(dtype)
    # Keep B as a contiguous (N, K) matrix; the non-TMA kernels get its (K, N) transpose view.
    b = b.T.contiguous()

    def run_repeatedly(fn, *fn_args):
        for _ in range(reps):
            fn(*fn_args)
            time.sleep(0.01)

    proton.activate(0)

    if cublas is not None:
        run_repeatedly(cublas_matmul, a, b)
    if dtype == torch.float16:
        run_repeatedly(torch_matmul, a, b)
    run_repeatedly(matmul, a, b.T)
    run_repeatedly(matmul_persistent, a, b.T)
    if supports_tma():
        run_repeatedly(matmul_tma_persistent, a, b)
        run_repeatedly(matmul_persistent_tma_ws_cooperative, a, b)
        scope_name = (f"matmul_kernel_device_tma_persistent [M={M}, N={N}, K={K}, "
                      f"tiles_per_update={tiles_per_update:02}]")
        with proton.scope(scope_name):
            run_repeatedly(matmul_device_tma_persistent, a, b, tiles_per_update)

    proton.deactivate(0)
def validate(M, N, K, dtype, tiles_per_update):
    """Check each available matmul variant against the naive Triton kernel and print verdicts.

    Inputs are random (M, K) and (N, K) matrices. B is drawn as (K, N) and then stored
    transposed and contiguous. Outputs are cast to fp16 before
    ``torch.allclose(..., atol=1.0)`` so fp8 results can be compared too.

    Variants unavailable here are skipped:
    - ``torch.matmul`` for non-fp16 inputs;
    - cuBLAS without a handle;
    - TMA kernels without TMA support.
    """
    a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype)
    b = torch.randn((K, N), device="cuda", dtype=torch.float16).to(dtype)
    b = b.T.contiguous()

    tma = supports_tma()
    torch_result = torch_matmul(a, b) if dtype == torch.float16 else None
    cublas_result = cublas_matmul(a, b) if cublas is not None else None
    naive_result = matmul(a, b.T)
    persistent_result = matmul_persistent(a, b.T)
    tma_persistent_result = matmul_tma_persistent(a, b) if tma else None
    device_tma_persistent_result = matmul_device_tma_persistent(a, b, tiles_per_update) if tma else None
    matmul_persistent_tma_ws_cooperative_result = matmul_persistent_tma_ws_cooperative(a, b) if tma else None

    reference = naive_result.to(torch.float16)

    def verdict(result):
        # The naive kernel is the single reference for every check, matching the printed
        # "naive vs" heading. Comparing the TMA variants against cuBLAS instead would crash
        # whenever cuBLAS is unavailable.
        return "✅" if torch.allclose(reference, result.to(torch.float16), atol=1.0) else "❌"

    checks = [
        ("torch", torch_result),
        ("cublas", cublas_result),
        ("persistent", persistent_result),
        ("TMA persistent", tma_persistent_result),
        ("Device TMA persistent", device_tma_persistent_result),
        ("TMA persistent with warp specialization", matmul_persistent_tma_ws_cooperative_result),
    ]

    print(f"M={M}, N={N}, K={K} verification naive vs: ", end="")
    for label, result in checks:
        if result is not None:
            print(f"{label}: {verdict(result)} ", end="")
    print()
def show_profile(precision, profile_name):
    """Print the proton profile ``<profile_name>.hatchet``.

    Always shows ``time/ms``. For ``precision`` 'fp8' or 'fp16' the matching TFLOP/s column
    is shown first.
    """
    import triton.profiler.viewer as proton_viewer

    throughput = {"fp8": "tflop8/s", "fp16": "tflop16/s"}.get(precision)
    metrics = ["time/ms"] if throughput is None else [throughput, "time/ms"]
    proton_viewer.parse(metrics, f"{profile_name}.hatchet", depth=100)
if __name__ == "__main__":
    import sys

    parser = argparse.ArgumentParser()
    parser.add_argument("-K", type=int, required=False, default=512)
    parser.add_argument("--K_range", type=int, nargs=2)
    parser.add_argument("--K_step", type=int, default=512)
    parser.add_argument(
        "--tiles_per_update",
        type=int,
        default=1,
        help=
        "Number of output tiles calculated for each update of the tma descriptor in matmul_kernel_device_tma_persistent",
    )
    parser.add_argument("--prec", type=str, choices=["fp8", "fp16"], default="fp16")
    args = parser.parse_args()

    if args.prec == 'fp8' and (not hasattr(torch, "float8_e4m3fn") or not is_cuda()):
        print("This example requires CUDA with fp8 support.")
        # sys.exit instead of the site-injected exit(), which is meant for the interactive
        # interpreter and is missing when Python runs with -S.
        sys.exit(1)

    dtype = torch.float8_e4m3fn if args.prec == 'fp8' else torch.float16

    # A single -K value is benchmarked as the one-element range [K, K].
    if args.K and args.K_range is None:
        args.K_range = [args.K, args.K]
        args.K_step = 1

    torch.manual_seed(0)

    validate(32, 32, 32, dtype, args.tiles_per_update)
    validate(8192, 8192, 512, dtype, args.tiles_per_update)

    proton.start("matmul", hook="triton")
    for K in range(args.K_range[0], args.K_range[1] + 1, args.K_step):
        bench(K, dtype, args.tiles_per_update)
    proton.finalize()
    show_profile(args.prec, "matmul")