"""
Persistent Matmul
=====================
This script demonstrates persistent kernel implementations of matrix multiplication using Triton.
Various matmul methods are included, such as naive, persistent, and TMA (Tensor Memory Accelerator) based approaches.
The kernels support both FP16 and FP8 data types but the FP8 implementation is only available on CUDA devices with compute capability >= 9.0.
Triton and cuBLAS implementations are benchmarked under different configurations and evaluated using the proton profiler.
Users can pass command-line arguments to specify matrix dimensions and iteration steps flexibly.
.. code-block:: bash
# FP8
python 09-persistent-matmul.py --prec fp8 --K_range 128 1024 --K_step 128
# FP16
python 09-persistent-matmul.py --prec fp16 --K_range 128 1024 --K_step 128
Note that currently this tutorial will fail on devices with a small shared memory size, such as RTX-4090.
"""
import argparse
import itertools
import torch
import triton
import triton.language as tl
import triton.profiler as proton
from triton.tools.tensor_descriptor import TensorDescriptor
from contextlib import contextmanager
from typing import Optional
# cuBLASLt handle used by the "cublas" baseline. It is built around a 32 MiB
# byte workspace allocated on the GPU. When CUDA is unavailable the baseline is
# disabled by leaving `cublas` as None (bench/validate check for that).
if not torch.cuda.is_available():
    cublas = None
else:
    from triton._C.libtriton import nvidia
    cublas_workspace = torch.empty(32 * 1024 * 1024, device="cuda", dtype=torch.uint8)
    cublas = nvidia.cublas.CublasLt(cublas_workspace)
def is_cuda():
    """Return True when Triton's active driver targets the CUDA backend."""
    target = triton.runtime.driver.active.get_current_target()
    return target.backend == "cuda"
def supports_tma():
    """Return True on the CUDA backend with a device of compute capability major >= 9."""
    if not is_cuda():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 9
def is_hopper():
    """Return True if the current CUDA device's compute capability major version is exactly 9."""
    major, _minor = torch.cuda.get_device_capability()
    return major == 9
def supports_ws():
    """Return True when warp specialization may be requested: CUDA backend, compute capability major >= 9."""
    if is_cuda():
        return torch.cuda.get_device_capability()[0] >= 9
    return False
def _matmul_launch_metadata(grid, kernel, args):
    """Build profiler launch metadata for a matmul kernel launch.

    Args:
        grid: launch grid (unused).
        kernel: the launched kernel; only its ``name`` is read.
        args: bound kernel arguments. ``M``, ``N`` and ``K`` are required, and
            ``WARP_SPECIALIZE`` (default False) adds a ``_ws`` suffix to the name.
            The element size comes from ``c_ptr.element_size()`` when ``c_ptr``
            is present. Otherwise it is 1 byte if ``FP8_OUTPUT`` is truthy, else 2.

    Returns:
        A dict with ``name``, ``flops<bits>`` (2*M*N*K) and ``bytes``
        (element size times the element count of A, B and C).
    """
    M, N, K = args["M"], args["N"], args["K"]
    suffix = "_ws" if args.get("WARP_SPECIALIZE", False) else ""
    name = f"{kernel.name}{suffix} [M={M}, N={N}, K={K}]"
    if "c_ptr" in args:
        elem_bytes = args["c_ptr"].element_size()
    elif args["FP8_OUTPUT"]:
        elem_bytes = 1
    else:
        elem_bytes = 2
    return {
        "name": name,
        f"flops{elem_bytes * 8}": 2. * M * N * K,
        "bytes": elem_bytes * (M * K + N * K + M * N),
    }
# Feature flags, evaluated once at import time.
# Device-side descriptors need TMA-capable hardware and a Triton build that
# exposes tl.make_tensor_descriptor.
HAS_TENSOR_DESC = (supports_tma()
                   and hasattr(tl, "make_tensor_descriptor"))
# Host-side descriptors come from triton.tools.tensor_descriptor.TensorDescriptor.
HAS_HOST_TENSOR_DESC = (supports_tma()
                        and hasattr(triton.tools.tensor_descriptor, "TensorDescriptor"))
# Warp specialization is only enabled when device-side descriptors are available too.
HAS_WARP_SPECIALIZE = (supports_ws()
                       and HAS_TENSOR_DESC)
def matmul_get_configs(pre_hook=None):
    """Return the autotuning configs for matmul_kernel, matmul_kernel_tma and matmul_kernel_persistent.

    Every config has BLOCK_SIZE_M=128 and GROUP_SIZE_M=8. The sweep covers
    BLOCK_SIZE_N in {128, 256}, BLOCK_SIZE_K in {64, 128}, num_stages in {2, 3, 4}
    and num_warps in {4, 8}. ``pre_hook`` is attached to every config.
    """
    configs = []
    for BM, BN, BK, s, w in itertools.product([128], [128, 256], [64, 128], [2, 3, 4], [4, 8]):
        meta = {'BLOCK_SIZE_M': BM, 'BLOCK_SIZE_N': BN, "BLOCK_SIZE_K": BK, "GROUP_SIZE_M": 8}
        configs.append(triton.Config(meta, num_stages=s, num_warps=w, pre_hook=pre_hook))
    return configs
@triton.autotune(
    configs=matmul_get_configs(),
    key=["M", "N", "K"],
)
@triton.jit(launch_metadata=_matmul_launch_metadata)
def matmul_kernel(a_ptr, b_ptr, c_ptr,
                  M, N, K,
                  stride_am, stride_ak,
                  stride_bk, stride_bn,
                  stride_cm, stride_cn,
                  BLOCK_SIZE_M: tl.constexpr,
                  BLOCK_SIZE_N: tl.constexpr,
                  BLOCK_SIZE_K: tl.constexpr,
                  GROUP_SIZE_M: tl.constexpr,
                  ):
    """Naive (one program per output tile) matmul kernel computing C = A @ B.

    A is addressed as (M, K) and B as (K, N) through the explicit strides. Each
    program computes one BLOCK_SIZE_M x BLOCK_SIZE_N tile of C, accumulating in
    float32. The result is cast to float8e4nv when C's element type is float8e4nv,
    otherwise to float16.
    """
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    # Grouped tile ordering: consecutive program ids step down GROUP_SIZE_M
    # tile-rows before moving to the next tile-column. The last group may have
    # fewer rows than GROUP_SIZE_M.
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    start_m = pid_m * BLOCK_SIZE_M
    start_n = pid_n * BLOCK_SIZE_N
    offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N)
    # Rows/columns past M/N are redirected to index 0 so the loads need no M/N
    # mask. Those lanes are dropped by the masked store at the end.
    offs_am = tl.where(offs_am < M, offs_am, 0)
    offs_bn = tl.where(offs_bn < N, offs_bn, 0)
    # Alignment/contiguity hints for the compiler.
    offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
    offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Zero-fill the K tail so a partial last block contributes nothing.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        accumulator = tl.dot(a, b, accumulator)
        # Advance both pointer blocks by one K step.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    if (c_ptr.dtype.element_ty == tl.float8e4nv):
        c = accumulator.to(tl.float8e4nv)
    else:
        c = accumulator.to(tl.float16)
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    # Only write elements that fall inside the M x N output.
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
def matmul(a, b):
    """Naive Triton matmul: return ``a @ b``.

    Args:
        a: (M, K) tensor.
        b: (K, N) tensor with the same dtype as ``a`` (any strides).

    Returns:
        A new (M, N) tensor with the inputs' dtype, computed with one program
        per output tile.
    """
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.dtype == b.dtype, "Incompatible dtypes"
    M, K = a.shape
    K, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)

    def grid(META):
        tiles_m = triton.cdiv(M, META["BLOCK_SIZE_M"])
        tiles_n = triton.cdiv(N, META["BLOCK_SIZE_N"])
        return (tiles_m * tiles_n, )

    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
    )
    return c
def matmul_tma_set_block_size_hook(nargs):
    """Autotune pre-hook that sizes the host descriptors' block shapes for the chosen config.

    A gets [BLOCK_SIZE_M, BLOCK_SIZE_K] and B gets [BLOCK_SIZE_N, BLOCK_SIZE_K].
    C gets [BLOCK_SIZE_M, BLOCK_SIZE_N], halved along N when EPILOGUE_SUBTILE is
    set (default False), because the subtiled epilogue stores two half-width tiles.
    """
    bm, bn, bk = nargs["BLOCK_SIZE_M"], nargs["BLOCK_SIZE_N"], nargs["BLOCK_SIZE_K"]
    c_cols = bn // 2 if nargs.get("EPILOGUE_SUBTILE", False) else bn
    nargs["a_desc"].block_shape = [bm, bk]
    nargs["b_desc"].block_shape = [bn, bk]
    nargs["c_desc"].block_shape = [bm, c_cols]
@triton.autotune(
    configs=matmul_get_configs(pre_hook=matmul_tma_set_block_size_hook),
    key=["M", "N", "K", "WARP_SPECIALIZE"],
)
@triton.jit(launch_metadata=_matmul_launch_metadata)
def matmul_kernel_tma(a_desc, b_desc, c_desc,
                      M, N, K,
                      BLOCK_SIZE_M: tl.constexpr,
                      BLOCK_SIZE_N: tl.constexpr,
                      BLOCK_SIZE_K: tl.constexpr,
                      GROUP_SIZE_M: tl.constexpr,
                      FP8_OUTPUT: tl.constexpr,
                      WARP_SPECIALIZE: tl.constexpr,
                      ):
    """Non-persistent matmul kernel using tensor descriptors: C = A @ B.T.

    A is read in [BLOCK_SIZE_M, BLOCK_SIZE_K] blocks. B, laid out as (N, K), is
    read in [BLOCK_SIZE_N, BLOCK_SIZE_K] blocks that are transposed before the
    dot. The descriptors' block shapes are set per config by the autotune
    pre-hook. Each program writes one BLOCK_SIZE_M x BLOCK_SIZE_N tile of C, as
    float8e4nv if FP8_OUTPUT else float16. WARP_SPECIALIZE is forwarded to the
    K loop's tl.range.

    There is no explicit bounds masking. NOTE(review): this presumably relies on
    the descriptor loads zero-filling and the stores clipping at the tensor
    boundary; confirm.
    """
    dtype = tl.float8e4nv if FP8_OUTPUT else tl.float16
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    # Grouped tile ordering: step down GROUP_SIZE_M tile-rows before the next column.
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    # Descriptor loads take element offsets of the block's top-left corner.
    offs_am = pid_m * BLOCK_SIZE_M
    offs_bn = pid_n * BLOCK_SIZE_N
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in tl.range(k_tiles, warp_specialize=WARP_SPECIALIZE):
        offs_k = k * BLOCK_SIZE_K
        a = a_desc.load([offs_am, offs_k])
        b = b_desc.load([offs_bn, offs_k])
        # B is stored (N, K), so transpose its block to (K, N) for the dot.
        accumulator = tl.dot(a, b.T, accumulator)
    c = accumulator.to(dtype)
    offs_cm = pid_m * BLOCK_SIZE_M
    offs_cn = pid_n * BLOCK_SIZE_N
    c_desc.store([offs_cm, offs_cn], c)
def matmul_tma(a, b, warp_specialize: bool):
    """Non-persistent Triton matmul over host-built TMA descriptors: return ``a @ b.T``.

    Args:
        a: (M, K) tensor.
        b: (N, K) tensor (B stored transposed), same dtype as ``a``.
        warp_specialize: forwarded to the kernel's WARP_SPECIALIZE flag.

    Returns:
        A new (M, N) tensor with the inputs' dtype. FP8 output is selected when
        the inputs are ``torch.float8_e4m3fn``.
    """
    assert a.shape[1] == b.shape[1], "Incompatible dimensions"
    assert a.dtype == b.dtype, "Incompatible dtypes"
    M, K = a.shape
    N, K = b.shape
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)

    # Placeholder [1, 1] block shape. matmul_tma_set_block_size_hook replaces it
    # per autotune config before launch.
    placeholder = [1, 1]
    a_desc, b_desc, c_desc = [TensorDescriptor(t, t.shape, t.stride(), placeholder) for t in (a, b, c)]

    def grid(META):
        return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), )

    matmul_kernel_tma[grid](
        a_desc, b_desc, c_desc,
        M, N, K,
        FP8_OUTPUT=a.dtype == torch.float8_e4m3fn,
        WARP_SPECIALIZE=warp_specialize,
    )
    return c
@triton.jit
def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS):
    # Map a linear tile index to its (pid_m, pid_n) tile coordinates using
    # grouped ordering. Tiles are taken in groups of GROUP_SIZE_M tile-rows,
    # stepping down M within a group before advancing along N. The last group
    # may contain fewer rows.
    # NUM_SMS is accepted but not used here.
    group_id = tile_id // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (tile_id % group_size_m)
    pid_n = (tile_id % num_pid_in_group) // group_size_m
    return pid_m, pid_n
@triton.autotune(
    configs=matmul_get_configs(),
    key=["M", "N", "K"],
)
@triton.jit(launch_metadata=_matmul_launch_metadata)
def matmul_kernel_persistent(a_ptr, b_ptr, c_ptr,
                             M, N, K,
                             stride_am, stride_ak,
                             stride_bk, stride_bn,
                             stride_cm, stride_cn,
                             BLOCK_SIZE_M: tl.constexpr,
                             BLOCK_SIZE_N: tl.constexpr,
                             BLOCK_SIZE_K: tl.constexpr,
                             GROUP_SIZE_M: tl.constexpr,
                             NUM_SMS: tl.constexpr,
                             ):
    """Persistent matmul kernel computing C = A @ B with explicit strides.

    Each program starts at tile ``start_pid`` and strides through the tile space
    by NUM_SMS. Every tile is covered exactly once only when the launch grid has
    min(NUM_SMS, num_tiles) programs, which is what matmul_persistent launches.

    Per tile, it follows the same scheme as matmul_kernel:
    - rows/columns past M/N are clamped to index 0 for loading;
    - the K tail is zero-masked;
    - accumulation is in float32;
    - the masked store casts to float8e4nv or float16 according to C's element type.
    """
    start_pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    num_tiles = num_pid_m * num_pid_n
    # Second tile counter used only by the epilogue. It starts NUM_SMS behind and
    # is advanced before use, so it equals tile_id at the store.
    # NOTE(review): presumably kept separate from tile_id to help the flattened
    # loop's pipelining; confirm before merging the two.
    tile_id_c = start_pid - NUM_SMS
    offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True):
        pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS)
        start_m = pid_m * BLOCK_SIZE_M
        start_n = pid_n * BLOCK_SIZE_N
        offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)
        offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N)
        # Out-of-range rows/columns load from index 0; they are masked off at store.
        offs_am = tl.where(offs_am < M, offs_am, 0)
        offs_bn = tl.where(offs_bn < N, offs_bn, 0)
        offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
        offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)
        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        for ki in range(k_tiles):
            # Pointers are recomputed from ki each step rather than advanced.
            offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
            a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
            b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
            # Zero-fill the K tail.
            a = tl.load(a_ptrs, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, other=0.0)
            b = tl.load(b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0)
            accumulator = tl.dot(a, b, accumulator)
        tile_id_c += NUM_SMS
        pid_m, pid_n = _compute_pid(tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS)
        offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
        c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
        if (c_ptr.dtype.element_ty == tl.float8e4nv):
            c = accumulator.to(tl.float8e4nv)
        else:
            c = accumulator.to(tl.float16)
        tl.store(c_ptrs, c, mask=c_mask)
def matmul_persistent(a, b):
    """Persistent Triton matmul: return ``a @ b``.

    Args:
        a: (M, K) tensor.
        b: (K, N) tensor with the same dtype as ``a`` (any strides).

    Returns:
        A new (M, N) tensor with the inputs' dtype. The launch uses
        min(SM count, tile count) programs, and the kernel loops over the
        remaining tiles with stride NUM_SMS.
    """
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.dtype == b.dtype, "Incompatible dtypes"
    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
    M, K = a.shape
    K, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)

    def grid(META):
        num_tiles = triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])
        return (min(NUM_SMS, num_tiles), )

    matmul_kernel_persistent[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        NUM_SMS=NUM_SMS,
    )
    return c
def matmul_tma_persistent_get_configs(pre_hook=None):
    """Return the autotuning configs for the TMA and descriptor persistent kernels.

    Every config has BLOCK_SIZE_M=128 and GROUP_SIZE_M=8. The sweep covers
    BLOCK_SIZE_N in {128, 256}, BLOCK_SIZE_K in {64, 128}, num_stages in {2, 3, 4},
    num_warps in {4, 8} and EPILOGUE_SUBTILE in {True, False}. ``pre_hook`` is
    attached to every config.
    """
    space = itertools.product([128], [128, 256], [64, 128], [2, 3, 4], [4, 8], [True, False])
    return [
        triton.Config(
            {'BLOCK_SIZE_M': BM, 'BLOCK_SIZE_N': BN, "BLOCK_SIZE_K": BK, "GROUP_SIZE_M": 8,
             "EPILOGUE_SUBTILE": SUBTILE},
            num_stages=s, num_warps=w, pre_hook=pre_hook)
        for BM, BN, BK, s, w, SUBTILE in space
    ]
@triton.autotune(
    configs=matmul_tma_persistent_get_configs(pre_hook=matmul_tma_set_block_size_hook),
    key=["M", "N", "K", "WARP_SPECIALIZE"],
)
@triton.jit(launch_metadata=_matmul_launch_metadata)
def matmul_kernel_tma_persistent(a_desc, b_desc, c_desc,
                                 M, N, K,
                                 BLOCK_SIZE_M: tl.constexpr,
                                 BLOCK_SIZE_N: tl.constexpr,
                                 BLOCK_SIZE_K: tl.constexpr,
                                 GROUP_SIZE_M: tl.constexpr,
                                 FP8_OUTPUT: tl.constexpr,
                                 EPILOGUE_SUBTILE: tl.constexpr,
                                 NUM_SMS: tl.constexpr,
                                 WARP_SPECIALIZE: tl.constexpr,
                                 ):
    """Persistent matmul kernel over host-built tensor descriptors: C = A @ B.T.

    A is read in [BLOCK_SIZE_M, BLOCK_SIZE_K] blocks. B, laid out as (N, K), is
    read in [BLOCK_SIZE_N, BLOCK_SIZE_K] blocks that are transposed before the
    dot. Descriptor block shapes are set per config by the autotune pre-hook.

    Each program walks tiles start_pid, start_pid + NUM_SMS, ..., so the grid is
    expected to be min(NUM_SMS, num_tiles) programs. The output is float8e4nv
    when FP8_OUTPUT is set, otherwise float16. With EPILOGUE_SUBTILE, each tile
    is stored as two BLOCK_SIZE_N // 2 wide halves, which matches the halved
    c_desc block shape chosen by the pre-hook.
    """
    dtype = tl.float8e4nv if FP8_OUTPUT else tl.float16
    start_pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    num_tiles = num_pid_m * num_pid_n
    # Epilogue-only copy of the tile counter. It lags by NUM_SMS and is advanced
    # before use, so it equals tile_id at the store.
    tile_id_c = start_pid - NUM_SMS
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True, warp_specialize=WARP_SPECIALIZE):
        pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS)
        offs_am = pid_m * BLOCK_SIZE_M
        offs_bn = pid_n * BLOCK_SIZE_N
        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        for ki in range(k_tiles):
            offs_k = ki * BLOCK_SIZE_K
            a = a_desc.load([offs_am, offs_k])
            b = b_desc.load([offs_bn, offs_k])
            accumulator = tl.dot(a, b.T, accumulator)
        tile_id_c += NUM_SMS
        pid_m, pid_n = _compute_pid(tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS)
        offs_am_c = pid_m * BLOCK_SIZE_M
        offs_bn_c = pid_n * BLOCK_SIZE_N
        if EPILOGUE_SUBTILE:
            # (BM, BN) -> (BM, 2, BN/2) -> (BM, BN/2, 2), then split the last axis.
            # acc0 holds columns [0, BN/2) and acc1 holds [BN/2, BN).
            acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2))
            acc = tl.permute(acc, (0, 2, 1))
            acc0, acc1 = tl.split(acc)
            c0 = acc0.to(dtype)
            c_desc.store([offs_am_c, offs_bn_c], c0)
            c1 = acc1.to(dtype)
            c_desc.store([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2], c1)
        else:
            accumulator = accumulator.to(dtype)
            c_desc.store([offs_am_c, offs_bn_c], accumulator)
def matmul_tma_persistent(a, b, warp_specialize: bool):
    """Persistent Triton matmul over host-built TMA descriptors: return ``a @ b.T``.

    Args:
        a: (M, K) tensor.
        b: (N, K) tensor (B stored transposed), same dtype as ``a``.
        warp_specialize: forwarded to the kernel's WARP_SPECIALIZE flag.

    Returns:
        A new (M, N) tensor with the inputs' dtype. FP8 output is selected when
        the inputs are ``torch.float8_e4m3fn``. The launch uses min(SM count,
        tile count) programs.
    """
    assert a.shape[1] == b.shape[1], "Incompatible dimensions"
    assert a.dtype == b.dtype, "Incompatible dtypes"
    M, K = a.shape
    N, K = b.shape
    dtype = a.dtype
    c = torch.empty((M, N), device=a.device, dtype=dtype)
    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count

    # Placeholder block shape; the autotune pre-hook resizes it per config.
    dummy_block = [1, 1]
    a_desc, b_desc, c_desc = [TensorDescriptor(t, t.shape, t.stride(), dummy_block) for t in (a, b, c)]

    def grid(META):
        tiles = triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])
        return (min(NUM_SMS, tiles), )

    matmul_kernel_tma_persistent[grid](
        a_desc, b_desc, c_desc,
        M, N, K,
        FP8_OUTPUT=dtype == torch.float8_e4m3fn,
        NUM_SMS=NUM_SMS,
        WARP_SPECIALIZE=warp_specialize,
    )
    return c
def prune_invalid_configs(configs, named_args, **kwargs):
    """Early-prune hook: when FLATTEN is False, drop configs that enable EPILOGUE_SUBTILE.

    A config without an EPILOGUE_SUBTILE entry counts as subtiled (default True),
    so it is dropped too. The test is ``FLATTEN is False``, so only the literal
    ``False`` triggers pruning. ``named_args`` is unused.
    """
    flatten = kwargs["FLATTEN"]
    kept = []
    for conf in configs:
        subtiled = conf.kwargs.get("EPILOGUE_SUBTILE", True)
        if subtiled and flatten is False:
            continue
        kept.append(conf)
    return kept
@triton.autotune(configs=matmul_tma_persistent_get_configs(), key=["M", "N", "K", "WARP_SPECIALIZE", "FLATTEN"],
                 prune_configs_by={'early_config_prune': prune_invalid_configs})
@triton.jit(launch_metadata=_matmul_launch_metadata)
def matmul_kernel_descriptor_persistent(
    a_ptr,
    b_ptr,
    c_ptr,
    M,
    N,
    K,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    EPILOGUE_SUBTILE: tl.constexpr,
    NUM_SMS: tl.constexpr,
    WARP_SPECIALIZE: tl.constexpr,
    FLATTEN: tl.constexpr,
):
    """Persistent matmul kernel that builds its tensor descriptors on the device: C = A @ B.T.

    The descriptors use hard-coded strides ([K, 1] for A and B, [N, 1] for C),
    so A (M, K), B (N, K) and C (M, N) must be dense row-major. The output dtype
    is C's element type.

    FLATTEN and WARP_SPECIALIZE are forwarded to the persistent tile loop's
    tl.range. Configs with EPILOGUE_SUBTILE are pruned by prune_invalid_configs
    when FLATTEN is False.

    NOTE(review): tl.make_tensor_descriptor presumably needs the scratch
    allocator installed via triton.set_allocator in the host wrapper; confirm.
    """
    dtype = c_ptr.dtype.element_ty
    start_pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    num_tiles = num_pid_m * num_pid_n
    a_desc = tl.make_tensor_descriptor(
        a_ptr,
        shape=[M, K],
        strides=[K, 1],
        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
    )
    b_desc = tl.make_tensor_descriptor(
        b_ptr,
        shape=[N, K],
        strides=[K, 1],
        block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
    )
    # With EPILOGUE_SUBTILE the tile is stored as two half-width column blocks.
    c_desc = tl.make_tensor_descriptor(
        c_ptr,
        shape=[M, N],
        strides=[N, 1],
        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N if not EPILOGUE_SUBTILE else BLOCK_SIZE_N // 2],
    )
    # Epilogue-only copy of the tile counter. It lags by NUM_SMS and is advanced
    # before use, so it equals tile_id at the store.
    tile_id_c = start_pid - NUM_SMS
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=FLATTEN, warp_specialize=WARP_SPECIALIZE):
        pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS)
        offs_am = pid_m * BLOCK_SIZE_M
        offs_bn = pid_n * BLOCK_SIZE_N
        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        for ki in range(k_tiles):
            offs_k = ki * BLOCK_SIZE_K
            a = a_desc.load([offs_am, offs_k])
            b = b_desc.load([offs_bn, offs_k])
            accumulator = tl.dot(a, b.T, accumulator)
        tile_id_c += NUM_SMS
        pid_m, pid_n = _compute_pid(tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS)
        offs_cm = pid_m * BLOCK_SIZE_M
        offs_cn = pid_n * BLOCK_SIZE_N
        if EPILOGUE_SUBTILE:
            # Split the tile into its left [0, BN/2) and right [BN/2, BN) column halves.
            acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2))
            acc = tl.permute(acc, (0, 2, 1))
            acc0, acc1 = tl.split(acc)
            c0 = acc0.to(dtype)
            c_desc.store([offs_cm, offs_cn], c0)
            c1 = acc1.to(dtype)
            c_desc.store([offs_cm, offs_cn + BLOCK_SIZE_N // 2], c1)
        else:
            c = accumulator.to(dtype)
            c_desc.store([offs_cm, offs_cn], c)
def matmul_descriptor_persistent(a, b, warp_specialize: bool):
    """Persistent Triton matmul with device-side tensor descriptors: return ``a @ b.T``.

    The kernel builds its descriptors with hard-coded row-major strides ([K, 1]
    for A and B, [N, 1] for C). Non-contiguous inputs are therefore copied to
    dense memory first. Contiguous inputs are passed through without a copy.

    Args:
        a: (M, K) CUDA tensor.
        b: (N, K) CUDA tensor (B stored transposed), same dtype as ``a``.
        warp_specialize: enable warp specialization of the persistent tile loop.
            On Hopper this also disables loop flattening, and the autotuner then
            prunes the EPILOGUE_SUBTILE configs.

    Returns:
        A new (M, N) tensor with the inputs' dtype.
    """
    assert a.shape[1] == b.shape[1], "Incompatible dimensions"
    assert a.dtype == b.dtype, "Incompatible dtypes"
    # The kernel ignores the real strides and assumes dense row-major operands.
    # Without this a strided view (e.g. a transpose) would be read from the wrong
    # addresses. .contiguous() returns the tensor itself when it is already dense.
    a = a.contiguous()
    b = b.contiguous()
    M, K = a.shape
    N, K = b.shape
    dtype = a.dtype
    c = torch.empty((M, N), device=a.device, dtype=dtype)
    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count

    # Scratch allocator for Triton; presumably used for the descriptors the
    # kernel creates with tl.make_tensor_descriptor.
    def alloc_fn(size: int, alignment: int, stream: Optional[int]):
        return torch.empty(size, device="cuda", dtype=torch.int8)

    triton.set_allocator(alloc_fn)

    flatten = False if (warp_specialize and is_hopper()) else True
    grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])), )
    matmul_kernel_descriptor_persistent[grid](
        a,
        b,
        c,
        M,
        N,
        K,
        NUM_SMS=NUM_SMS,
        WARP_SPECIALIZE=warp_specialize,
        FLATTEN=flatten,
    )
    return c
def cublas_matmul(a, b):
    """cuBLAS baseline: return ``a @ b.T`` computed with the module-level CublasLt handle.

    The call is wrapped in a proton scope that records bytes moved and the FLOP
    count, so it shows up next to the Triton kernels in the profile.

    Args:
        a: (M, K) tensor.
        b: (N, K) tensor (B stored transposed).

    Returns:
        A new (M, N) tensor with ``a``'s dtype.

    Raises:
        AssertionError: if the K dimensions or the dtypes of ``a`` and ``b``
            differ. These are the same checks the Triton wrappers perform.
    """
    assert a.shape[1] == b.shape[1], "Incompatible dimensions"
    # Previously missing: mismatched dtypes would otherwise reach CublasLt unchecked.
    assert a.dtype == b.dtype, "Incompatible dtypes"
    M, K = a.shape
    N, K = b.shape
    dtype = a.dtype
    c = torch.empty((M, N), device=a.device, dtype=dtype)
    bytes_per_elem = a.element_size()
    flops_str = f"flops{bytes_per_elem * 8}"
    with proton.scope(f"cublas [M={M}, N={N}, K={K}]",
                      {"bytes": bytes_per_elem * (M * K + N * K + M * N), flops_str: 2. * M * N * K}):
        cublas.matmul(a, b, c)
    return c
def torch_matmul(a, b):
    """PyTorch baseline: return ``torch.matmul(a, b.T)`` inside a proton scope.

    ``a`` is (M, K) and ``b`` is (N, K). The scope records bytes moved and the
    FLOP count, using ``a``'s element size.
    """
    M, K = a.shape
    N, K = b.shape
    width = a.element_size()
    metrics = {"bytes": width * (M * K + N * K + M * N), f"flops{width * 8}": 2. * M * N * K}
    with proton.scope(f"torch [M={M}, N={N}, K={K}]", metrics):
        c = torch.matmul(a, b.T)
    return c
@contextmanager
def proton_context():
    """Keep proton profiling active only for the body of the ``with`` block.

    Calls ``proton.activate(0)`` on entry and always calls ``proton.deactivate(0)``
    on exit, including when the body raises. The 0 is presumably the session id
    created by ``proton.start``; confirm.
    """
    session = 0
    proton.activate(session)
    try:
        yield
    finally:
        proton.deactivate(session)
def bench_fn(label, reps, warmup_reps, fn, *args):
    """Call ``fn(*args)`` ``warmup_reps`` times unprofiled, then ``reps`` times with proton active.

    Prints a "Benchmarking <label>: ..." line that is overwritten with "done" at the end.
    """
    def run(times):
        for _ in range(times):
            fn(*args)

    print(f"Benchmarking {label}: ...", end="")
    run(warmup_reps)
    with proton_context():
        run(reps)
    print(f"\rBenchmarking {label}: done")
def bench(K, dtype, reps=10000, warmup_reps=10000):
    """Benchmark every available matmul implementation for M = N = 8192 and the given K.

    Inputs are drawn as float16 normals on the GPU and cast to ``dtype``. B is
    kept as a contiguous (N, K) tensor, and the kernels that take B as (K, N)
    receive its transpose view. Each implementation goes through ``bench_fn``.
    """
    M = N = 8192

    def run(label, fn, *args):
        bench_fn(label, reps, warmup_reps, fn, *args)

    a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype)
    b = torch.randn((K, N), device="cuda", dtype=torch.float16).to(dtype).T.contiguous()

    if cublas is not None:
        run("cublas", cublas_matmul, a, b)
    if dtype == torch.float16:
        run("torch", torch_matmul, a, b)
    b_kn = b.T  # (K, N) view for the kernels that take B untransposed
    run("naive", matmul, a, b_kn)
    run("persistent", matmul_persistent, a, b_kn)

    ws_options = [False, True] if HAS_WARP_SPECIALIZE else [False]
    for ws in ws_options:
        suffix = "_ws" if ws else ""
        # Host-descriptor kernels are skipped with warp specialization on Hopper.
        if HAS_HOST_TENSOR_DESC and not (is_hopper() and ws):
            run(f"tma_persistent{suffix}", lambda x, y, ws=ws: matmul_tma_persistent(x, y, ws), a, b)
            run(f"tma{suffix}", lambda x, y, ws=ws: matmul_tma(x, y, ws), a, b)
        if HAS_TENSOR_DESC:
            run(f"descriptor_persistent{suffix}", lambda x, y, ws=ws: matmul_descriptor_persistent(x, y, ws), a, b)
def run_test(expect, fn, a, b, label, enabled=True):
    """Check ``fn(a, b)`` against ``expect`` and print a one-line status for ``label``.

    Prints ✅ when the result, cast to ``expect``'s dtype, is allclose to
    ``expect`` with atol=1.0, and ❌ otherwise. Prints ⭕ without calling ``fn``
    when ``enabled`` is falsy.
    """
    print(f" {label}: ...", end="")
    if not enabled:
        icon = "⭕"
    elif torch.allclose(expect, fn(a, b).to(expect.dtype), atol=1.0):
        icon = "✅"
    else:
        icon = "❌"
    print(f"\r {label}: {icon} ")
def validate(M, N, K, dtype):
    """Check every implementation against the naive Triton kernel on random inputs and print the results.

    A is (M, K) and B is generated as (K, N) and then stored as a contiguous
    (N, K) tensor. Implementations that are unavailable, or skipped (host-descriptor
    kernels with warp specialization on Hopper), are reported as disabled.
    """
    print(f"{M=}, {N=}, {K=}, verification naive vs: ")
    a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype)
    b = torch.randn((K, N), device="cuda", dtype=torch.float16).to(dtype)
    b = b.T.contiguous()
    naive_result = matmul(a, b.T).to(torch.float16)
    run_test(naive_result, torch_matmul, a, b, "Torch", enabled=dtype == torch.float16)
    run_test(naive_result, cublas_matmul, a, b, "cuBLAS", enabled=cublas is not None)
    run_test(naive_result, matmul_persistent, a, b.T, "Persistent")
    kernels = [
        (matmul_tma, "TMA", HAS_HOST_TENSOR_DESC),
        (matmul_tma_persistent, "TMA Persistent", HAS_HOST_TENSOR_DESC),
        (matmul_descriptor_persistent, "Tensor Descriptor Persistent", HAS_TENSOR_DESC),
    ]
    ws_options = [False, True] if HAS_WARP_SPECIALIZE else [False]
    for kernel, base_label, available in kernels:
        for ws in ws_options:
            label = f"{base_label} (warp_specialize={ws})"
            skipped = is_hopper() and ws and kernel != matmul_descriptor_persistent
            enabled = available and (not ws or HAS_TENSOR_DESC) and (not skipped)
            run_test(naive_result, lambda x, y: kernel(x, y, ws), a, b, label, enabled)
    print()
def show_profile(precision, profile_name):
    """Print the proton profile ``<profile_name>.hatchet`` as a tree.

    Always shows time/ms. For 'fp8' or 'fp16' the matching TFLOP/s column is
    shown first.
    """
    import triton.profiler.viewer as proton_viewer
    if precision == 'fp8':
        metric_names = ["tflop8/s", "time/ms"]
    elif precision == 'fp16':
        metric_names = ["tflop16/s", "time/ms"]
    else:
        metric_names = ["time/ms"]
    tree, metrics = proton_viewer.parse(metric_names, f"{profile_name}.hatchet")
    proton_viewer.print_tree(tree, metrics)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-K", type=int, required=False, default=512)
    parser.add_argument("--K_range", type=int, nargs=2)
    parser.add_argument("--K_step", type=int, default=512)
    parser.add_argument("--prec", type=str, choices=["fp8", "fp16"], default="fp16")
    args = parser.parse_args()

    use_fp8 = args.prec == 'fp8'
    if use_fp8 and (not hasattr(torch, "float8_e4m3fn") or not is_cuda()):
        print("This example requires CUDA with fp8 support.")
    else:
        dtype = torch.float8_e4m3fn if use_fp8 else torch.float16
        # Without --K_range, benchmark the single value given by -K.
        if args.K and args.K_range is None:
            args.K_range = [args.K, args.K]
            args.K_step = 1
        torch.manual_seed(0)
        validate(32, 32, 32, dtype)
        validate(8192, 8192, args.K_range[0], dtype)
        # Start the profiling session inactive; bench_fn activates it only
        # around the measured repetitions.
        proton.start("matmul", hook="triton")
        proton.deactivate()
        k_first, k_last = args.K_range
        for K in range(k_first, k_last + 1, args.K_step):
            bench(K, dtype)
        proton.finalize()
        show_profile(args.prec, "matmul")