import numpy as np
from mlir import ir
from mlir.dialects import arith
from mlir.dialects import func
from mlir.dialects import gpu
from mlir.dialects import memref
from mlir.dialects import nvgpu
from mlir.dialects import nvvm
from mlir.dialects import llvm
from mlir.dialects import builtin
from mlir.dialects import scf
from mlir.dialects import vector
from mlir.extras import types as T
TMA_LAST_DIM_F16 = 64
WARP_SIZE = 32
WARP_GROUP_SIZE = WARP_SIZE * 4
PRODUCER_REGISTER_SIZE = 40
CONSUMER_REGISTER_SIZE = 232
PRODUCER_PRIMARY_THREAD = 128
CONSUMER_PRIMARY_THREAD = 0
MLIR_DYNAMIC = -9223372036854775808
DEBUG = False
class TmaDescriptorBuilder:
"""A class that builds a TMA descriptor."""
def __init__(self, swizzle, l2promo, oob, interleave, tma_box_shape, memref_ty):
self.swizzle = swizzle
self.l2promo = l2promo
self.oob = oob
self.interleave = interleave
self.tma_box_shape = tma_box_shape
self.memref_ty = memref_ty
@property
def tensormap_descriptor_ty(self):
"""Returns a tensormap descriptor type."""
tensorMemrefType = ir.MemRefType.get(
self.tma_box_shape,
self.memref_ty.element_type,
memory_space=ir.Attribute.parse("3"),
)
return nvgpu.TensorMapDescriptorType.get(
tensorMemrefType,
self.swizzle,
self.l2promo,
self.oob,
self.interleave,
)
def tma_descriptor_op(self, device_ptr):
"""Returns a tensormap descriptor op."""
tma_descriptor_ty = self.tensormap_descriptor_ty
device_unranked_memref = memref.CastOp(
ir.UnrankedMemRefType.get(
self.memref_ty.element_type, self.memref_ty.memory_space
),
device_ptr,
)
tma_descriptor_op = nvgpu.TmaCreateDescriptorOp(
tma_descriptor_ty, device_unranked_memref, map(c, self.tma_box_shape)
)
return tma_descriptor_op.result
def debug_print(fmt, *args, predicate=None, threadNumber=-1, forcePrint=False):
if not DEBUG and not forcePrint:
return
type_formats = []
for arg in args:
ty_format = None
if ir.IndexType.isinstance(arg.type):
ty_format = "%llu"
if ir.IntegerType.isinstance(arg.type):
width = ir.IntegerType(arg.type).width
if width == 64:
ty_format = "%llu"
elif width == 32:
ty_format = "%d"
elif width == 1:
ty_format = "%i"
if ir.F32Type.isinstance(arg.type):
ty_format = "%f"
if ty_format is None:
raise NotImplementedError(arg.type)
type_formats.append(ty_format)
if threadNumber != -1:
tidx = gpu.thread_id(gpu.Dimension.x)
predicate = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(threadNumber))
scf.yield_([])
if_op = scf.IfOp(predicate)
with ir.InsertionPoint(if_op.then_block):
gpu.printf(fmt.format(*type_formats) + "\n", args)
scf.yield_([])
def get_type_size(ty):
if ir.FloatType.isinstance(ty):
return ir.FloatType(ty).width // 8
if ir.IntegerType.isinstance(ty):
return ir.IntegerType(ty).width // 8
raise NotImplementedError(ty)
def get_mlir_ty(dtype):
if dtype == np.float16:
return T.f16()
if dtype == np.float32:
return T.f32()
if dtype == np.float64:
return T.f64()
if dtype == np.int32:
return T.i32()
if dtype == np.int64:
return T.i64()
raise NotImplementedError(dtype)
def c(value, ty=None):
ty = T.index() if ty is None else ty
return arith.constant(ty, value)
def make_kernel_name(
input_type=np.float16,
output_type=np.float32,
M=4096,
N=4096,
K=4096,
BLOCK_M=128,
BLOCK_N=128,
BLOCK_K=128,
num_stages=3,
use_warp_specialization=False,
):
kernelName = "warpspecialized" if use_warp_specialization else "multistage"
return (
kernelName
+ "_"
+ str(M)
+ "x"
+ str(N)
+ "x"
+ str(K)
+ "_"
+ str(BLOCK_M)
+ "x"
+ str(BLOCK_N)
+ "x"
+ str(BLOCK_K)
+ "_"
+ str(num_stages)
)
def generate_matmul_ws(
input_type=np.float16,
output_type=np.float32,
M=4096,
N=4096,
K=4096,
BLOCK_M=128,
BLOCK_N=128,
BLOCK_K=128,
num_stages=3,
):
assert input_type == np.float16
assert output_type == np.float32
assert BLOCK_M == 128
assert BLOCK_N == 128
assert BLOCK_K == 64
assert M % BLOCK_M == 0
assert N % BLOCK_N == 0
assert K % BLOCK_K == 0
module = ir.Module.create()
token_ty = gpu.AsyncTokenType.get()
a_elem_ty = get_mlir_ty(input_type)
b_elem_ty = get_mlir_ty(input_type)
c_elem_ty = get_mlir_ty(output_type)
a_ty = ir.MemRefType.get([M, K], a_elem_ty)
b_ty = ir.MemRefType.get((K, N), b_elem_ty)
c_ty = ir.MemRefType.get((M, N), c_elem_ty)
a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
b_tile_shape = (BLOCK_K, BLOCK_N)
txcount = (b_tile_shape[0] * b_tile_shape[1] * get_type_size(a_elem_ty)) + (
a_tile_shape[0] * a_tile_shape[1] * get_type_size(b_elem_ty)
)
smem_space_str = "#gpu.address_space<workgroup>"
smem_space = ir.Attribute.parse(smem_space_str)
mbar_ty = ir.Type.parse(
"!nvgpu.mbarrier.group<memorySpace = "
+ str(smem_space)
+ ", num_barriers = "
+ str(num_stages)
+ ">"
)
acc_ty = ir.Type.parse(
"!nvgpu.warpgroup.accumulator<fragmented=vector<"
+ str(BLOCK_M)
+ "x"
+ str(BLOCK_N)
+ "x"
+ str(c_elem_ty)
+ ">>"
)
a_wgmma_ty = ir.Type.parse(
"!nvgpu.warpgroup.descriptor<tensor=memref<"
+ str(BLOCK_M)
+ "x"
+ str(BLOCK_K)
+ "x"
+ str(a_elem_ty)
+ ", "
+ smem_space_str
+ ">>"
)
b_wgmma_ty = ir.Type.parse(
"!nvgpu.warpgroup.descriptor<tensor=memref<"
+ str(BLOCK_K)
+ "x"
+ str(BLOCK_N)
+ "x"
+ str(a_elem_ty)
+ ", "
+ smem_space_str
+ ">>"
)
kernelName = make_kernel_name(
input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, num_stages, True
)
with ir.InsertionPoint(module.body):
fop = func.FuncOp(kernelName, ([a_ty, b_ty, c_ty], []))
with ir.InsertionPoint(fop.add_entry_block()):
a_host = fop.arguments[0]
b_host = fop.arguments[1]
c_host = fop.arguments[2]
lhs_tile_bytes = BLOCK_M * BLOCK_K * get_type_size(a_elem_ty)
rhs_tile_bytes = BLOCK_N * BLOCK_K * get_type_size(b_elem_ty)
smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
smem_size_output = BLOCK_M * BLOCK_N * get_type_size(c_elem_ty)
smem_size = max(smem_size_input, smem_size_output)
t1 = gpu.wait(token_ty, [])
a_device, t2 = gpu.alloc(a_ty, token_ty, [t1], [], [])
b_device, t3 = gpu.alloc(b_ty, token_ty, [t2], [], [])
c_device, t4 = gpu.alloc(c_ty, token_ty, [t3], [], [])
t5 = gpu.memcpy(token_ty, [t4], a_device, a_host)
t6 = gpu.memcpy(token_ty, [t5], b_device, b_host)
t7 = gpu.wait(token_ty, [t6])
a_tma_desc = TmaDescriptorBuilder(
nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
nvgpu.TensorMapOOBKind.OOB_ZERO,
nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
a_tma_shape,
a_ty,
)
b_tma_desc = TmaDescriptorBuilder(
nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
nvgpu.TensorMapOOBKind.OOB_ZERO,
nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
b_tma_shape,
b_ty,
)
a_tma_desc_op = a_tma_desc.tma_descriptor_op(a_device)
b_tma_desc_op = b_tma_desc.tma_descriptor_op(b_device)
cta_m = M // BLOCK_M
cta_n = N // BLOCK_N
assert M % BLOCK_M == 0 and N % BLOCK_N == 0
grid = (cta_m, cta_n, 1)
block = (WARP_GROUP_SIZE * 2, 1, 1)
launch_op = gpu.LaunchOp(
token_ty,
[t7],
*map(c, grid),
*map(c, block),
dynamicSharedMemorySize=c(smem_size, ty=T.i32()),
)
launch_op.body.blocks.append(*([T.index()] * 12))
with ir.InsertionPoint(launch_op.body.blocks[0]):
memref.assume_alignment(c_device, 16)
dynamic_smem = gpu.dynamic_shared_memory(
ir.MemRefType.get((MLIR_DYNAMIC,), T.i8(), memory_space=smem_space)
)
ticks = c(10000000)
tidx = gpu.thread_id(gpu.Dimension.x)
wgPrimaryThread = arith.cmpi(
arith.CmpIPredicate.eq, arith.remui(tidx, c(WARP_GROUP_SIZE)), c(0)
)
warp_id = arith.divui(tidx, c(32))
warpgroup_id = arith.divui(warp_id, c(4))
is_producer = arith.cmpi(
arith.CmpIPredicate.eq,
warpgroup_id,
c(1 if PRODUCER_PRIMARY_THREAD == 128 else 0),
)
is_consumer = arith.cmpi(
arith.CmpIPredicate.eq,
warpgroup_id,
c(0 if CONSUMER_PRIMARY_THREAD == 0 else 1),
)
producerPrimaryThread = arith.cmpi(
arith.CmpIPredicate.eq, tidx, c(PRODUCER_PRIMARY_THREAD)
)
consumerPrimaryThread = arith.cmpi(
arith.CmpIPredicate.eq, tidx, c(CONSUMER_PRIMARY_THREAD)
)
bidx = gpu.block_id(gpu.Dimension.x)
bidy = gpu.block_id(gpu.Dimension.y)
dimX = arith.muli(bidx, c(BLOCK_M))
dimY = arith.muli(bidy, c(BLOCK_N))
mbarTMA = nvgpu.mbarrier_create(mbar_ty)
mbarDONE = nvgpu.mbarrier_create(mbar_ty)
for i in range(num_stages):
nvgpu.mbarrier_init(mbarTMA, c(1), c(i), predicate=wgPrimaryThread)
nvgpu.mbarrier_init(mbarDONE, c(1), c(i), predicate=wgPrimaryThread)
gpu.barrier()
nvgpu.tma_prefetch_descriptor(a_tma_desc_op, predicate=wgPrimaryThread)
nvgpu.tma_prefetch_descriptor(b_tma_desc_op, predicate=wgPrimaryThread)
ns = num_stages if num_stages == 1 else num_stages - 1
with ir.InsertionPoint(scf.IfOp(is_producer).then_block):
nvvm.setmaxregister(
PRODUCER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.decrease
)
for_op = scf.ForOp(
c(0), c(K // BLOCK_K), c(1), [arith.constant(T.bool(), 1)]
)
with ir.InsertionPoint(for_op.body):
phaseParity = for_op.inner_iter_args[0]
iv = for_op.induction_variable
stage = arith.remui(iv, c(num_stages))
debug_print(
"[prod] iv={} | mbarDONE[{}] try_wait phase={}",
iv,
stage,
phaseParity,
predicate=producerPrimaryThread,
)
nvgpu.MBarrierTryWaitParityOp(
mbarDONE, phaseParity, ticks, mbarId=stage
)
debug_print(
"[prod] iv={} | mbarDONE[{}] try_wait phase={} [done]",
iv,
stage,
phaseParity,
predicate=producerPrimaryThread,
)
p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
phaseParity = arith.select(
p,
arith.xori(phaseParity, arith.constant(T.bool(), 1)),
phaseParity,
)
a_offset = arith.muli(stage, c(lhs_tile_bytes))
a_tma_slice = memref.view(
ir.MemRefType.get(
a_tma_shape, a_elem_ty, memory_space=smem_space
),
dynamic_smem,
a_offset,
[],
)
b_offset = arith.addi(
arith.muli(stage, c(rhs_tile_bytes)),
c(lhs_tile_bytes * num_stages),
)
b_tma_slice_1 = memref.view(
ir.MemRefType.get(
b_tma_shape, b_elem_ty, memory_space=smem_space
),
dynamic_smem,
b_offset,
[],
)
b_offset2 = arith.addi(
b_offset,
c(BLOCK_K * TMA_LAST_DIM_F16 * get_type_size(b_elem_ty)),
)
b_tma_slice_2 = memref.view(
ir.MemRefType.get(
b_tma_shape, b_elem_ty, memory_space=smem_space
),
dynamic_smem,
b_offset2,
[],
)
debug_print(
"[prod] a_offset={} b_offset={} b_offset2={}",
a_offset,
b_offset,
b_offset2,
predicate=producerPrimaryThread,
)
coord = arith.muli(c(64), iv)
nvgpu.TmaAsyncLoadOp(
a_tma_slice,
mbarTMA,
a_tma_desc_op,
coordinates=[coord, dimX],
mbarId=stage,
predicate=producerPrimaryThread,
)
nvgpu.TmaAsyncLoadOp(
b_tma_slice_1,
mbarTMA,
b_tma_desc_op,
coordinates=[dimY, coord],
mbarId=stage,
predicate=producerPrimaryThread,
)
dimY2 = arith.addi(dimY, c(64))
nvgpu.TmaAsyncLoadOp(
b_tma_slice_2,
mbarTMA,
b_tma_desc_op,
coordinates=[dimY2, coord],
mbarId=stage,
predicate=producerPrimaryThread,
)
debug_print(
"[prod] iv={} | mbarTMA[{}] arrive",
iv,
stage,
predicate=producerPrimaryThread,
)
nvgpu.mbarrier_arrive_expect_tx(
mbarTMA, c(txcount), stage, predicate=producerPrimaryThread
)
debug_print(
"[prod] iv={} | mbarTMA[{}] arrive [done]",
iv,
stage,
predicate=producerPrimaryThread,
)
scf.yield_([phaseParity])
scf.yield_([])
if_op = scf.IfOp(is_consumer)
with ir.InsertionPoint(if_op.then_block):
nvvm.setmaxregister(
CONSUMER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.increase
)
acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)
for_op = scf.ForOp(
c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(T.bool(), 0)]
)
with ir.InsertionPoint(for_op.body):
phaseParity = for_op.inner_iter_args[1]
iv = for_op.induction_variable
stage = arith.remui(iv, c(num_stages))
debug_print(
"[cons] iv={} | mbarTMA[{}] try_wait phase={}",
iv,
stage,
phaseParity,
predicate=consumerPrimaryThread,
)
nvgpu.MBarrierTryWaitParityOp(
mbarTMA, phaseParity, ticks, mbarId=stage
)
debug_print(
"[cons] iv={} | mbarTMA[{}] try_wait phase={} [done]",
iv,
stage,
phaseParity,
predicate=consumerPrimaryThread,
)
a_offset = arith.muli(stage, c(lhs_tile_bytes))
a_tile_slice = memref.view(
ir.MemRefType.get(
a_tile_shape, a_elem_ty, memory_space=smem_space
),
dynamic_smem,
a_offset,
[],
)
b_offset = arith.addi(
arith.muli(stage, c(rhs_tile_bytes)),
c(lhs_tile_bytes * num_stages),
)
b_tile_slice = memref.view(
ir.MemRefType.get(
b_tile_shape, b_elem_ty, memory_space=smem_space
),
dynamic_smem,
b_offset,
[],
)
debug_print(
"[cons] a_offset={} b_offset={}",
a_offset,
b_offset,
predicate=consumerPrimaryThread,
)
da = nvgpu.WarpgroupGenerateDescriptorOp(
a_wgmma_ty, a_tile_slice, a_tma_desc_op
)
db = nvgpu.WarpgroupGenerateDescriptorOp(
b_wgmma_ty, b_tile_slice, b_tma_desc_op
)
carry_acc = for_op.inner_iter_args[0]
new_acc = nvgpu.WarpgroupMmaOp(
acc.type, da, db, carry_acc, transposeB=True
)
if num_stages == 1:
p_arrive = consumerPrimaryThread
else:
p1 = arith.cmpi(arith.CmpIPredicate.sgt, iv, c(0))
p_arrive = arith.andi(consumerPrimaryThread, p1)
with ir.InsertionPoint(scf.IfOp(p_arrive).then_block):
p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(0))
barId = arith.select(
p, c(num_stages - 1), arith.subi(stage, c(1))
)
debug_print(
"[cons] iv={} | mbarDONE[{}] arrive ",
iv,
barId,
predicate=consumerPrimaryThread,
)
nvgpu.mbarrier_arrive(
ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE, barId
)
debug_print(
"[cons] iv={} | mbarDONE[{}] arrive [done]",
iv,
barId,
predicate=consumerPrimaryThread,
)
scf.yield_([])
p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
phaseParity = arith.select(
p,
arith.xori(phaseParity, arith.constant(T.bool(), 1)),
phaseParity,
)
scf.yield_([new_acc, phaseParity])
nvvm.WgmmaWaitGroupSyncOp(0)
with ir.InsertionPoint(scf.IfOp(consumerPrimaryThread).then_block):
barId = c((K // BLOCK_K) % num_stages)
nvgpu.mbarrier_arrive(
ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE, barId
)
scf.yield_([])
acc_smem_ty = ir.MemRefType.get(
(BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space
)
acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), [])
debug_print("[cons] | Storing", predicate=consumerPrimaryThread)
nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem)
scf.yield_([])
gpu.barrier()
fd = ir.MemRefType.get(
[BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space
)
collapsed_smem = memref.view(fd, dynamic_smem, c(0), [])
rty = ir.MemRefType.get(
(BLOCK_M, BLOCK_N),
c_elem_ty,
ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"),
)
c_device_per_block = memref.SubViewOp(
rty,
c_device,
[dimX, dimY],
[],
[],
[MLIR_DYNAMIC, MLIR_DYNAMIC],
[BLOCK_M, BLOCK_N],
[1, 1],
)
vlen = 1
for_op = scf.ForOp(
tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE * 2)
)
with ir.InsertionPoint(for_op.body):
x = arith.divui(for_op.induction_variable, c(BLOCK_M))
y = arith.remui(for_op.induction_variable, c(BLOCK_N))
vdata = vector.load(
ir.VectorType.get((vlen,), c_elem_ty),
collapsed_smem,
[for_op.induction_variable],
)
vector.store(vdata, c_device_per_block, [x, y])
scf.yield_([])
gpu.terminator()
t8 = gpu.wait(token_ty, [launch_op])
t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
gpu.dealloc(token_ty, [t8], a_device)
gpu.dealloc(token_ty, [t8], b_device)
gpu.wait(token_ty, [t9])
gpu.dealloc(token_ty, [t8], c_device)
func.ReturnOp([])
fop.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
module.operation.verify()
return module
def generate_matmul_multistage(
input_type=np.float16,
output_type=np.float32,
M=4096,
N=4096,
K=4096,
BLOCK_M=128,
BLOCK_N=128,
BLOCK_K=64,
num_stages=3,
):
assert input_type == np.float16
assert output_type == np.float32
assert BLOCK_M == 128
assert BLOCK_N == 128
assert BLOCK_K == 64
assert M % BLOCK_M == 0
assert N % BLOCK_N == 0
assert K % BLOCK_K == 0
module = ir.Module.create()
token_ty = gpu.AsyncTokenType.get()
a_elem_ty = get_mlir_ty(input_type)
b_elem_ty = get_mlir_ty(input_type)
c_elem_ty = get_mlir_ty(output_type)
a_ty = ir.MemRefType.get([M, K], a_elem_ty)
b_ty = ir.MemRefType.get((K, N), b_elem_ty)
c_ty = ir.MemRefType.get((M, N), c_elem_ty)
a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
b_tile_shape = (BLOCK_K, BLOCK_N)
txcount = (b_tile_shape[0] * b_tile_shape[1] * get_type_size(a_elem_ty)) + (
a_tile_shape[0] * a_tile_shape[1] * get_type_size(b_elem_ty)
)
smem_space_str = "#gpu.address_space<workgroup>"
smem_space = ir.Attribute.parse(smem_space_str)
mbar_ty = ir.Type.parse(
"!nvgpu.mbarrier.group<memorySpace = "
+ str(smem_space)
+ ", num_barriers = "
+ str(num_stages)
+ ">"
)
acc_ty = ir.Type.parse(
"!nvgpu.warpgroup.accumulator<fragmented=vector<"
+ str(BLOCK_M)
+ "x"
+ str(BLOCK_N)
+ "x"
+ str(c_elem_ty)
+ ">>"
)
a_wgmma_ty = ir.Type.parse(
"!nvgpu.warpgroup.descriptor<tensor=memref<"
+ str(BLOCK_M)
+ "x"
+ str(BLOCK_K)
+ "x"
+ str(a_elem_ty)
+ ", "
+ smem_space_str
+ ">>"
)
b_wgmma_ty = ir.Type.parse(
"!nvgpu.warpgroup.descriptor<tensor=memref<"
+ str(BLOCK_K)
+ "x"
+ str(BLOCK_N)
+ "x"
+ str(a_elem_ty)
+ ", "
+ smem_space_str
+ ">>"
)
with ir.InsertionPoint(module.body):
kernelName = make_kernel_name(
input_type,
output_type,
M,
N,
K,
BLOCK_M,
BLOCK_N,
BLOCK_K,
num_stages,
False,
)
fop = func.FuncOp(kernelName, ([a_ty, b_ty, c_ty], []))
with ir.InsertionPoint(fop.add_entry_block()):
a_host = fop.arguments[0]
b_host = fop.arguments[1]
c_host = fop.arguments[2]
lhs_tile_bytes = BLOCK_M * BLOCK_K * get_type_size(a_elem_ty)
rhs_tile_bytes = BLOCK_N * BLOCK_K * get_type_size(b_elem_ty)
smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
smem_size_output = BLOCK_M * BLOCK_N * get_type_size(c_elem_ty)
smem_size = max(smem_size_input, smem_size_output)
t1 = gpu.wait(token_ty, [])
a_device, t2 = gpu.alloc(a_ty, token_ty, [t1], [], [])
b_device, t3 = gpu.alloc(b_ty, token_ty, [t2], [], [])
c_device, t4 = gpu.alloc(c_ty, token_ty, [t3], [], [])
t5 = gpu.memcpy(token_ty, [t4], a_device, a_host)
t6 = gpu.memcpy(token_ty, [t5], b_device, b_host)
t7 = gpu.wait(token_ty, [t6])
a_tma_desc = TmaDescriptorBuilder(
nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
nvgpu.TensorMapOOBKind.OOB_ZERO,
nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
a_tma_shape,
a_ty,
)
b_tma_desc = TmaDescriptorBuilder(
nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
nvgpu.TensorMapOOBKind.OOB_ZERO,
nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
b_tma_shape,
b_ty,
)
a_tma_desc_op = a_tma_desc.tma_descriptor_op(a_device)
b_tma_desc_op = b_tma_desc.tma_descriptor_op(b_device)
cta_m = M // BLOCK_M
cta_n = N // BLOCK_N
assert M % BLOCK_M == 0 and N % BLOCK_N == 0
grid = (cta_m, cta_n, 1)
block = (WARP_GROUP_SIZE, 1, 1)
launch_op = gpu.LaunchOp(
token_ty,
[t7],
*map(c, grid),
*map(c, block),
dynamicSharedMemorySize=c(smem_size, ty=T.i32()),
)
launch_op.body.blocks.append(*([T.index()] * 12))
with ir.InsertionPoint(launch_op.body.blocks[0]):
memref.assume_alignment(c_device, 16)
dynamic_smem = gpu.dynamic_shared_memory(
ir.MemRefType.get((MLIR_DYNAMIC,), T.i8(), memory_space=smem_space)
)
ticks = c(10000000)
tidx = gpu.thread_id(gpu.Dimension.x)
primaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(0))
warpId = arith.divui(tidx, c(32))
bidx = gpu.block_id(gpu.Dimension.x)
bidy = gpu.block_id(gpu.Dimension.y)
dimX = arith.muli(bidx, c(BLOCK_M))
dimY = arith.muli(bidy, c(BLOCK_N))
mbarTMA = nvgpu.mbarrier_create(mbar_ty)
for i in range(num_stages):
nvgpu.mbarrier_init(mbarTMA, c(1), c(i), predicate=primaryThread)
gpu.barrier()
nvgpu.tma_prefetch_descriptor(a_tma_desc_op, predicate=primaryThread)
nvgpu.tma_prefetch_descriptor(b_tma_desc_op, predicate=primaryThread)
ns = num_stages if num_stages == 1 else num_stages - 1
for_op = scf.ForOp(c(0), c(ns), c(1))
with ir.InsertionPoint(for_op.body):
iv = for_op.induction_variable
a_offset = arith.muli(iv, c(lhs_tile_bytes))
a_tma_slice = memref.view(
ir.MemRefType.get(
a_tma_shape, a_elem_ty, memory_space=smem_space
),
dynamic_smem,
a_offset,
[],
)
b_offset = arith.addi(
arith.muli(iv, c(rhs_tile_bytes)),
c(lhs_tile_bytes * num_stages),
)
b_tma_slice_1 = memref.view(
ir.MemRefType.get(
b_tma_shape, b_elem_ty, memory_space=smem_space
),
dynamic_smem,
b_offset,
[],
)
b_offset2 = arith.addi(
b_offset,
c(BLOCK_K * TMA_LAST_DIM_F16 * get_type_size(b_elem_ty)),
)
b_tma_slice_2 = memref.view(
ir.MemRefType.get(
b_tma_shape, b_elem_ty, memory_space=smem_space
),
dynamic_smem,
b_offset2,
[],
)
coord = arith.muli(c(64), iv)
dimY2 = arith.addi(dimY, c(64))
debug_print(
"[Prologue] TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
a_offset,
b_offset,
b_offset2,
coord,
dimX,
dimY,
coord,
predicate=primaryThread,
)
nvgpu.TmaAsyncLoadOp(
a_tma_slice,
mbarTMA,
a_tma_desc_op,
coordinates=[coord, dimX],
mbarId=iv,
predicate=primaryThread,
)
nvgpu.TmaAsyncLoadOp(
b_tma_slice_1,
mbarTMA,
b_tma_desc_op,
coordinates=[dimY, coord],
mbarId=iv,
predicate=primaryThread,
)
nvgpu.TmaAsyncLoadOp(
b_tma_slice_2,
mbarTMA,
b_tma_desc_op,
coordinates=[dimY2, coord],
mbarId=iv,
predicate=primaryThread,
)
debug_print(
"[Prologue] mbarTMA[{}] arrive", iv, predicate=primaryThread
)
nvgpu.mbarrier_arrive_expect_tx(
mbarTMA, c(txcount), iv, predicate=primaryThread
)
debug_print(
"[Prologue] mbarTMA[{}] arrive [done]",
iv,
predicate=primaryThread,
)
scf.yield_([])
acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)
for_op = scf.ForOp(
c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(T.bool(), 0)]
)
with ir.InsertionPoint(for_op.body):
phaseParity = for_op.inner_iter_args[1]
iv = for_op.induction_variable
stage = arith.remui(iv, c(num_stages))
debug_print(
"[MainLoop] mbarTMA[{}] try_wait phase={}",
stage,
phaseParity,
predicate=primaryThread,
)
nvgpu.MBarrierTryWaitParityOp(
mbarTMA, phaseParity, ticks, mbarId=stage
)
debug_print(
"[MainLoop] mbarTMA[{}] try_wait phase={} [done]",
stage,
phaseParity,
predicate=primaryThread,
)
a_offset = arith.muli(stage, c(lhs_tile_bytes))
a_tile_slice = memref.view(
ir.MemRefType.get(
a_tile_shape, a_elem_ty, memory_space=smem_space
),
dynamic_smem,
a_offset,
[],
)
b_offset = arith.addi(
arith.muli(stage, c(rhs_tile_bytes)),
c(lhs_tile_bytes * num_stages),
)
b_tile_slice = memref.view(
ir.MemRefType.get(
b_tile_shape, b_elem_ty, memory_space=smem_space
),
dynamic_smem,
b_offset,
[],
)
debug_print(
"[MainLoop] iv={} MMA a_offset={} b_offset={}",
iv,
a_offset,
b_offset,
predicate=primaryThread,
)
da = nvgpu.WarpgroupGenerateDescriptorOp(
a_wgmma_ty, a_tile_slice, a_tma_desc_op
)
db = nvgpu.WarpgroupGenerateDescriptorOp(
b_wgmma_ty, b_tile_slice, b_tma_desc_op
)
carry_acc = for_op.inner_iter_args[0]
new_acc = nvgpu.WarpgroupMmaOp(
acc.type, da, db, carry_acc, transposeB=True
)
if num_stages == 1:
nvvm.WgmmaWaitGroupSyncOp(0)
p1 = arith.cmpi(
arith.CmpIPredicate.ult,
arith.addi(iv, c(ns)),
c(K // BLOCK_K),
)
p = arith.andi(primaryThread, p1)
nextStage = arith.addi(iv, c(ns))
nextSlot = arith.remui(nextStage, c(num_stages))
a_offset = arith.muli(nextSlot, c(lhs_tile_bytes))
debug_print(
"[MainLoop] mbarTMA[{}] arrive",
nextSlot,
predicate=p,
)
nvgpu.mbarrier_arrive_expect_tx(
mbarTMA, c(txcount), nextSlot, predicate=p
)
debug_print(
"[MainLoop] mbarTMA[{}] arrive [done]",
nextSlot,
predicate=p,
)
a_tma_slice = memref.view(
ir.MemRefType.get(
a_tma_shape, a_elem_ty, memory_space=smem_space
),
dynamic_smem,
a_offset,
[],
)
b_offset = arith.addi(
arith.muli(nextSlot, c(rhs_tile_bytes)),
c(lhs_tile_bytes * num_stages),
)
b_tma_slice_1 = memref.view(
ir.MemRefType.get(
b_tma_shape, b_elem_ty, memory_space=smem_space
),
dynamic_smem,
b_offset,
[],
)
b_offset2 = arith.addi(
b_offset,
c(BLOCK_K * TMA_LAST_DIM_F16 * get_type_size(b_elem_ty)),
)
b_tma_slice_2 = memref.view(
ir.MemRefType.get(
b_tma_shape, b_elem_ty, memory_space=smem_space
),
dynamic_smem,
b_offset2,
[],
)
coord = arith.muli(c(64), nextStage)
debug_print(
"[MainLoop] iv={} TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
iv,
a_offset,
b_offset,
b_offset2,
coord,
dimX,
dimY,
coord,
predicate=p,
)
nvgpu.TmaAsyncLoadOp(
a_tma_slice,
mbarTMA,
a_tma_desc_op,
coordinates=[coord, dimX],
mbarId=nextSlot,
predicate=p,
)
nvgpu.TmaAsyncLoadOp(
b_tma_slice_1,
mbarTMA,
b_tma_desc_op,
coordinates=[dimY, coord],
mbarId=nextSlot,
predicate=p,
)
dimY2 = arith.addi(dimY, c(64))
nvgpu.TmaAsyncLoadOp(
b_tma_slice_2,
mbarTMA,
b_tma_desc_op,
coordinates=[dimY2, coord],
mbarId=nextSlot,
predicate=p,
)
p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
phaseParity = arith.select(
p,
arith.xori(phaseParity, arith.constant(T.bool(), 1)),
phaseParity,
)
scf.yield_([new_acc, phaseParity])
nvvm.WgmmaWaitGroupSyncOp(0)
acc_smem_ty = ir.MemRefType.get(
(BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space
)
acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), [])
debug_print("Storing", predicate=primaryThread)
nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem)
gpu.barrier()
fd = ir.MemRefType.get(
[BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space
)
collapsed_smem = memref.view(fd, dynamic_smem, c(0), [])
rty = ir.MemRefType.get(
(BLOCK_M, BLOCK_N),
c_elem_ty,
ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"),
)
c_device_per_block = memref.SubViewOp(
rty,
c_device,
[dimX, dimY],
[],
[],
[MLIR_DYNAMIC, MLIR_DYNAMIC],
[BLOCK_M, BLOCK_N],
[1, 1],
)
vlen = 1
for_op = scf.ForOp(
tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE)
)
with ir.InsertionPoint(for_op.body):
x = arith.divui(for_op.induction_variable, c(BLOCK_M))
y = arith.remui(for_op.induction_variable, c(BLOCK_N))
vdata = vector.load(
ir.VectorType.get((vlen,), c_elem_ty),
collapsed_smem,
[for_op.induction_variable],
)
vector.store(vdata, c_device_per_block, [x, y])
scf.yield_([])
gpu.terminator()
t8 = gpu.wait(token_ty, [launch_op])
t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
gpu.dealloc(token_ty, [t8], a_device)
gpu.dealloc(token_ty, [t8], b_device)
gpu.wait(token_ty, [t9])
gpu.dealloc(token_ty, [t8], c_device)
func.ReturnOp([])
fop.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
module.operation.verify()
return module