"""
Group GEMM
============================
This group GEMM kernel launches a fixed number of CTAs to compute a group
of GEMMs. The scheduling is static and is performed on the device.
"""
import torch
import triton
import triton.language as tl
# Candidate configurations: two tile shapes (128x128x32 and 64x64x32), each
# tried with a grid of 84 and of 128 programs. Re-tuning is keyed on the
# number of problems in the group.
@triton.autotune(
    configs=[
        triton.Config({
            'BLOCK_SIZE_M': 128,
            'BLOCK_SIZE_N': 128,
            'BLOCK_SIZE_K': 32,
            'NUM_SM': 84,
        }),
        triton.Config({
            'BLOCK_SIZE_M': 128,
            'BLOCK_SIZE_N': 128,
            'BLOCK_SIZE_K': 32,
            'NUM_SM': 128,
        }),
        triton.Config({
            'BLOCK_SIZE_M': 64,
            'BLOCK_SIZE_N': 64,
            'BLOCK_SIZE_K': 32,
            'NUM_SM': 84,
        }),
        triton.Config({
            'BLOCK_SIZE_M': 64,
            'BLOCK_SIZE_N': 64,
            'BLOCK_SIZE_K': 32,
            'NUM_SM': 128,
        }),
    ],
    key=['group_size'],
)
@triton.jit
def grouped_matmul_kernel(
    # device tensors holding one matrix base address per problem
    group_a_ptrs,
    group_b_ptrs,
    group_c_ptrs,
    # device tensor of problem sizes, read as three consecutive entries
    # <M, N, K> per problem (index g * 3 + {0, 1, 2})
    group_gemm_sizes,
    # device tensor of row strides, read as three consecutive entries
    # <lda, ldb, ldc> per problem (index g * 3 + {0, 1, 2})
    g_lds,
    # number of problems in the group
    group_size,
    # stride used to walk the global tile list; the host launches a grid of
    # exactly this many programs
    NUM_SM: tl.constexpr,
    # tile sizes
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    """Compute every GEMM of a group with a fixed-size grid of programs.

    The output tiles of all problems are numbered consecutively: problem 0's
    tiles first, then problem 1's, and so on. Within a problem, tiles are
    numbered row-major over the (M-tile, N-tile) grid. Each program starts at
    the tile whose global index equals its program id and then keeps adding
    NUM_SM to that index, computing every tile it lands on.

    Every matrix pointer is reinterpreted as a float16 pointer, the last
    dimension of each matrix is addressed with unit stride, and all loads and
    stores are unmasked, so each tile is assumed to lie fully inside its
    matrix.
    """
    tile_idx = tl.program_id(0)
    # global index of the first tile of the problem currently being visited
    last_problem_end = 0
    for g in range(group_size):
        # read the size of the current problem
        gm = tl.load(group_gemm_sizes + g * 3)
        gn = tl.load(group_gemm_sizes + g * 3 + 1)
        gk = tl.load(group_gemm_sizes + g * 3 + 2)
        num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M)
        num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N)
        num_tiles = num_m_tiles * num_n_tiles
        # process every tile of this problem that belongs to this program;
        # the loop is skipped entirely when tile_idx lies outside the
        # problem's range of global tile indices
        while (tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles):
            # fetch the row strides and base pointers of the current problem
            k = gk
            lda = tl.load(g_lds + g * 3)
            ldb = tl.load(g_lds + g * 3 + 1)
            ldc = tl.load(g_lds + g * 3 + 2)
            a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(tl.float16))
            b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(tl.float16))
            c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(tl.float16))
            # convert the global tile index into (row, column) tile coordinates
            tile_idx_in_gemm = tile_idx - last_problem_end
            tile_m_idx = tile_idx_in_gemm // num_n_tiles
            tile_n_idx = tile_idx_in_gemm % num_n_tiles

            # pointer tiles for the first K-slice of A and B
            offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
            offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
            offs_k = tl.arange(0, BLOCK_SIZE_K)
            a_ptrs = a_ptr + offs_am[:, None] * lda + offs_k[None, :]
            b_ptrs = b_ptr + offs_k[:, None] * ldb + offs_bn[None, :]
            # accumulate in float32; the result is narrowed to float16 on store
            accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
            for kk in range(0, tl.cdiv(k, BLOCK_SIZE_K)):
                # compiler hints on the pointer tiles (see tl.multiple_of);
                # the values returned by these calls are not used
                tl.multiple_of(a_ptrs, [16, 16])
                tl.multiple_of(b_ptrs, [16, 16])
                # unmasked loads: the tile is assumed to be full
                a = tl.load(a_ptrs)
                b = tl.load(b_ptrs)
                accumulator += tl.dot(a, b)
                # advance both pointer tiles by one K-slice
                a_ptrs += BLOCK_SIZE_K
                b_ptrs += BLOCK_SIZE_K * ldb
            c = accumulator.to(tl.float16)

            offs_cm = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
            offs_cn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
            c_ptrs = c_ptr + ldc * offs_cm[:, None] + offs_cn[None, :]

            # unmasked store: the tile is assumed to be full
            tl.store(c_ptrs, c)

            # move on to this program's next tile
            tile_idx += NUM_SM

        # the next problem's tiles start right after this problem's tiles
        last_problem_end = last_problem_end + num_tiles
def group_gemm_fn(group_A, group_B):
    """Multiply ``group_A[i] @ group_B[i]`` for every ``i`` in one kernel launch.

    Args:
        group_A: sequence of 2-D tensors; ``group_A[i]`` has shape (M_i, K_i).
        group_B: sequence of 2-D tensors; ``group_B[i]`` has shape (K_i, N_i).

    Returns:
        A list with one newly allocated (M_i, N_i) CUDA tensor per input pair,
        in input order. An empty group yields an empty list without touching
        the GPU.

    Raises:
        AssertionError: if the two groups differ in length or a pair has
            mismatched inner dimensions.

    Notes:
        ``grouped_matmul_kernel`` reinterprets every matrix pointer as
        float16, addresses the last dimension with unit stride, and issues
        unmasked loads and stores. Inputs are therefore expected to be float16
        CUDA tensors, contiguous along their last dimension, with M, N and K
        that are multiples of the tile sizes the autotuner may pick. None of
        this is checked here.
    """
    assert len(group_A) == len(group_B)
    group_size = len(group_A)
    if group_size == 0:
        # Nothing to compute. Returning here also avoids building empty
        # address tables and launching (and autotuning) a kernel that would
        # not execute a single loop iteration.
        return []

    device = torch.device('cuda')
    A_addrs = []
    B_addrs = []
    C_addrs = []
    g_sizes = []
    g_lds = []
    group_C = []
    for A, B in zip(group_A, group_B):
        assert A.shape[1] == B.shape[0]
        M, K = A.shape
        _, N = B.shape
        C = torch.empty((M, N), device=device, dtype=A.dtype)
        group_C.append(C)
        A_addrs.append(A.data_ptr())
        B_addrs.append(B.data_ptr())
        C_addrs.append(C.data_ptr())
        # Three entries per problem, in the order the kernel reads them.
        g_sizes += [M, N, K]
        g_lds += [A.stride(0), B.stride(0), C.stride(0)]

    # Raw addresses are stored as 64-bit integers; the kernel casts each
    # loaded value to a float16 pointer.
    d_a_ptrs = torch.tensor(A_addrs, dtype=torch.int64, device=device)
    d_b_ptrs = torch.tensor(B_addrs, dtype=torch.int64, device=device)
    d_c_ptrs = torch.tensor(C_addrs, dtype=torch.int64, device=device)
    d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=device)
    d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=device)

    # The grid size is taken from the autotuned config, so it always matches
    # the stride the kernel uses to walk the tile list.
    grid = lambda META: (META['NUM_SM'], )
    grouped_matmul_kernel[grid](
        d_a_ptrs,
        d_b_ptrs,
        d_c_ptrs,
        d_g_sizes,
        d_g_lds,
        group_size,
    )

    return group_C
# Correctness check: four square float16 problems of decreasing size are run
# through the grouped kernel and compared against torch.matmul.
group_m = [1024, 512, 256, 128]
group_n = [1024, 512, 256, 128]
group_k = [1024, 512, 256, 128]
group_A = []
group_B = []
assert len(group_m) == len(group_n)
assert len(group_n) == len(group_k)
group_size = len(group_m)
for M, N, K in zip(group_m, group_n, group_k):
    # A is drawn before B for every problem.
    A = torch.rand((M, K), device="cuda", dtype=torch.float16)
    B = torch.rand((K, N), device="cuda", dtype=torch.float16)
    group_A.append(A)
    group_B.append(B)

tri_out = group_gemm_fn(group_A, group_B)
ref_out = list(map(torch.matmul, group_A, group_B))
for i in range(group_size):
    # Absolute tolerance only; the relative tolerance is disabled.
    assert torch.allclose(ref_out[i], tri_out[i], atol=1e-2, rtol=0)
def triton_perf_fn(a_ptrs, b_ptrs, c_ptrs, sizes, lds, group_size):
    """Launch ``grouped_matmul_kernel`` once on already-built device tables.

    Args:
        a_ptrs, b_ptrs, c_ptrs: device tensors of matrix base addresses.
        sizes: device tensor of per-problem sizes.
        lds: device tensor of per-problem row strides.
        group_size: number of problems in the group.
    """

    def launch_grid(meta):
        # One program per 'NUM_SM' of the config being launched.
        return (meta['NUM_SM'], )

    kernel_args = (a_ptrs, b_ptrs, c_ptrs, sizes, lds, group_size)
    grouped_matmul_kernel[launch_grid](*kernel_args)
def torch_perf_fn(group_A, group_B):
    """Run one ``torch.matmul`` per (A, B) pair and discard the products.

    Pairs are formed positionally; iteration stops at the shorter group.
    """
    pairs = zip(group_A, group_B)
    for lhs, rhs in pairs:
        torch.matmul(lhs, rhs)
@triton.testing.perf_report(
    triton.testing.Benchmark(
        # the x-axis sweeps the square problem size N over 128 .. 1024
        x_names=['N'],
        x_vals=[2**i for i in range(7, 11)],
        # one plotted line per provider
        line_arg='provider',
        line_vals=['cublas', 'triton'],
        line_names=["cuBLAS", "Triton"],
        styles=[('green', '-'), ('blue', '-')],
        ylabel="runtime(ms)",
        plot_name="group-gemm-performance",
        args={},
    ))
def benchmark(N, provider):
    """Time a group of four N x N float16 GEMMs with the chosen provider.

    Args:
        N: edge length of every (square) matrix in the group.
        provider: ``'cublas'`` to time a loop of ``torch.matmul`` calls, or
            ``'triton'`` to time the grouped kernel.

    Returns:
        ``(ms, min_ms, max_ms)``: the median, 20th-percentile and
        80th-percentile runtimes in milliseconds, in the (mean, min, max)
        order that ``perf_report`` unpacks.

    Raises:
        ValueError: if ``provider`` is not one of the two known values.
    """
    group_size = 4
    group_A = []
    group_B = []
    A_addrs = []
    B_addrs = []
    C_addrs = []
    g_sizes = []
    g_lds = []
    # The output buffers are referenced only through this list; it keeps them
    # alive while their raw addresses are in use by the kernel.
    group_C = []
    for _ in range(group_size):
        A = torch.rand((N, N), device="cuda", dtype=torch.float16)
        B = torch.rand((N, N), device="cuda", dtype=torch.float16)
        C = torch.empty((N, N), device="cuda", dtype=torch.float16)
        group_A.append(A)
        group_B.append(B)
        group_C.append(C)
        A_addrs.append(A.data_ptr())
        B_addrs.append(B.data_ptr())
        C_addrs.append(C.data_ptr())
        g_sizes += [N, N, N]
        # Freshly allocated (N, N) tensors, so N is passed as every row stride.
        g_lds += [N, N, N]

    d_a_ptrs = torch.tensor(A_addrs, device="cuda")
    d_b_ptrs = torch.tensor(B_addrs, device="cuda")
    d_c_ptrs = torch.tensor(C_addrs, device="cuda")
    d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device="cuda")
    d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device="cuda")

    # do_bench returns one value per requested quantile, in this order.
    quantiles = [0.5, 0.2, 0.8]
    if provider == 'cublas':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch_perf_fn(group_A, group_B), quantiles=quantiles)
    elif provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: triton_perf_fn(d_a_ptrs, d_b_ptrs, d_c_ptrs, d_g_sizes, d_g_lds, group_size), quantiles=quantiles)
    else:
        raise ValueError(f"unknown provider: {provider!r}")
    # The y-axis is plain runtime, so the low quantile is the minimum and the
    # high quantile is the maximum; they must not be swapped.
    return ms, min_ms, max_ms
# Run the sweep declared on `benchmark`, showing the plots and printing the data.
benchmark.run(show_plots=True, print_data=True)