"""
Original code by @bertmaher; profiling added by @apgoucher
"""
import cProfile
import pstats
import time
import numpy as np
import torch
import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor
@triton.jit
def nop_args(
    # Tensor-like arguments: main() passes either torch tensors or
    # TensorDescriptor objects here.
    t1, t2, t3, t4, t5,
    # Runtime (non-constexpr) arguments: main() passes a mix of ints,
    # bools, None and a tuple.
    nc1, nc2, nc3, nc4, nc5, nc6, nc7, nc8, nc9,
    # Compile-time constants.
    c1: tl.constexpr,
    c2: tl.constexpr,
    c3: tl.constexpr,
    c4: tl.constexpr,
    c5: tl.constexpr,
):
    # Deliberately empty: none of the arguments are read, so the kernel
    # body itself performs no work.
    pass
def do_bench_walltime(fn, *, n_warmup=1000, n_repeat=10000, n_rounds=25, synchronize=None):
    """Measure the mean wall time of calling ``fn`` and print a cProfile report.

    The function is called once (so that any first-call work happens outside
    the measurement), then ``n_warmup`` more times.  After that, ``n_rounds``
    timing rounds are run; each round times ``n_repeat`` back-to-back calls
    bracketed by ``synchronize()`` and records the mean time per call.
    Finally ``n_repeat`` further calls are executed under ``cProfile`` and the
    statistics, sorted by internal time, are printed to stdout.

    Args:
        fn: Zero-argument callable to benchmark.
        n_warmup: Number of untimed calls made after the first call.
        n_repeat: Number of calls per timing round and in the profiled run.
            Must be at least 1.
        n_rounds: Number of timing rounds, i.e. the length of the result.
        synchronize: Zero-argument callable invoked around each measured
            region.  Defaults to ``torch.cuda.synchronize``.

    Returns:
        A 1-D ``numpy`` array with ``n_rounds`` entries, each the mean wall
        time per call of one round, in milliseconds.

    Raises:
        ValueError: If ``n_repeat`` is smaller than 1.
    """
    if n_repeat < 1:
        # The per-call mean divides by n_repeat.
        raise ValueError("n_repeat must be at least 1")
    if synchronize is None:
        synchronize = torch.cuda.synchronize

    print("Compiling...")
    fn()
    synchronize()

    for _ in range(n_warmup):
        fn()
    synchronize()

    mses = []
    for _ in range(n_rounds):
        print(f"Running {n_repeat} benchmarking iterations...")
        synchronize()
        # perf_counter() is the clock intended for measuring short intervals;
        # time.time() can jump when the system clock is adjusted.
        start_time = time.perf_counter()
        for _ in range(n_repeat):
            fn()
        synchronize()
        end_time = time.perf_counter()
        # seconds -> milliseconds, averaged over the calls in this round.
        mses.append((end_time - start_time) * 1e3 / n_repeat)
    mses = np.array(mses)

    print("Running profiler...")
    profile = cProfile.Profile()
    profile.enable()
    try:
        for _ in range(n_repeat):
            fn()
        synchronize()
    finally:
        # Never leave the profiler active if fn() or synchronize() raises.
        profile.disable()
    stats = pstats.Stats(profile)
    stats.sort_stats("time")
    stats.print_stats()
    return mses
def main(use_tensor_desc: bool):
    """Benchmark launching ``nop_args`` and print the measured launch times.

    Args:
        use_tensor_desc: When true, the five tensor-like kernel arguments are
            ``TensorDescriptor`` objects built from 1x16 CUDA tensors with a
            ``[1, 16]`` block shape; otherwise they are one-element CUDA
            tensors.

    Prints the per-round timings converted from milliseconds to microseconds,
    followed by the element at index ``len // 2`` of the sorted timings (the
    median when the number of rounds is odd).
    """

    def make_tensor_arg():
        # Build one fresh tensor-like argument of the requested flavour.
        if use_tensor_desc:
            return TensorDescriptor.from_tensor(torch.zeros(1, 16, device="cuda"), block_shape=[1, 16])
        return torch.zeros(1, device="cuda")

    tensor_args = [make_tensor_arg() for _ in range(5)]
    runtime_args = [0, 1, 1024, 2**31 - 1, 2**64 - 1, False, True, None, (16, 16)]
    constexpr_args = [32, False, True, 0, 64]

    # The grid indexing stays inside the lambda so that it is part of what is
    # timed on every launch.
    timings_ms = do_bench_walltime(lambda: nop_args[(1,)](*tensor_args, *runtime_args, *constexpr_args))
    usecs = timings_ms * 1000.0

    print(usecs)
    middle = len(usecs) // 2
    print(sorted(usecs)[middle])
if __name__ == "__main__":
    # Run the benchmark once per argument flavour, plain tensors first.
    runs = (
        ("launch overhead of kernel with Tensor inputs", False),
        ("launch overhead of kernel with TensorDescriptor inputs", True),
    )
    for banner, with_tensor_desc in runs:
        print(banner)
        main(use_tensor_desc=with_tensor_desc)