import ctypes
import matplotlib.pyplot as plt
import triton
from triton._C.libtriton import nvidia
import torch
import csv
from dataclasses import dataclass
import inspect
@dataclass
class PerfRecord:
    """Performance figures for a single benchmark measurement."""

    # elapsed time, in nanoseconds
    time_ns: float
    # floating-point operation count attributed to the measurement
    flops: float
    # byte count attributed to the measurement
    bytes: float
def parse_profile(profile_path, useful_op_regex):
    """
    Build a PerfRecord from a (proton) profile.

    Args:
        profile_path: path of the profile, handed to ``triton.profiler.viewer.read``.
        useful_op_regex: regex matched against leaf-node names; only matching
            leaves contribute to the flops and bytes totals.

    Returns:
        A PerfRecord whose ``flops``/``bytes`` cover the matching leaves only,
        while ``time_ns`` is summed over every leaf in the profile.
    """
    from triton.profiler import viewer

    graph_frame, _, _, _ = viewer.read(profile_path)

    # flops + bytes: restricted to the "useful" leaves
    useful_query = f"MATCH ('*', c) WHERE c.'name' =~ '{useful_op_regex}' AND c IS LEAF"
    useful_df = graph_frame.filter(useful_query).dataframe
    total_bytes = int(useful_df["bytes"].sum())
    # only sum the flops columns that are actually present in the dataframe
    flop_columns = [col for col in ["flops8", "flops16"] if col in useful_df.columns]
    total_flops = int(sum(useful_df[flop_columns].sum()))

    # time: every leaf counts, including the ones that did not match the regex
    every_leaf_query = "MATCH ('*', c) WHERE c IS LEAF"
    every_leaf_df = graph_frame.filter(every_leaf_query).dataframe
    total_time_ns = every_leaf_df["time (ns)"].sum()

    return PerfRecord(time_ns=total_time_ns, flops=total_flops, bytes=total_bytes)
def write_csv(xs, perfs, fpath):
    """
    Write one CSV row per (x, perf) pair next to ``fpath``.

    Args:
        xs: iterable of x values, one per measurement.
        perfs: iterable of records exposing ``flops``, ``bytes`` and ``time_ns``
            attributes, paired with ``xs`` in order.
        fpath: output location as a ``str`` or path-like object; its suffix is
            replaced with ``.csv``.

    Returns:
        The ``pathlib.Path`` of the CSV file that was written.
    """
    from pathlib import Path

    # accept plain strings as well as Path objects
    csv_path = Path(fpath).with_suffix(".csv")
    with csv_path.open("w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["x", "flops", "bytes", "time_ns"])
        writer.writerows([x, p.flops, p.bytes, p.time_ns] for x, p in zip(xs, perfs))
    return csv_path
def compute_roofline(*args, bench_fn, intensity_proxy_name, intensity_proxy_values, out_path, verbose, **kwargs):
    """
    Run ``bench_fn`` once per intensity-proxy value and dump the results to CSV.

    Args:
        *args: positional arguments forwarded to ``bench_fn`` (without the proxy).
        bench_fn: benchmark callable; each call must return a record exposing
            ``flops``, ``bytes`` and ``time_ns`` attributes.
        intensity_proxy_name: name of the ``bench_fn`` parameter that is swept.
        intensity_proxy_values: iterable of values to sweep over.
        out_path: output location handed to ``write_csv``.
        verbose: when truthy, print a banner and one throughput line per value.
        **kwargs: keyword arguments forwarded to ``bench_fn``.

    Returns:
        Whatever ``write_csv`` returns for the collected measurements.

    Raises:
        TypeError: if ``intensity_proxy_name`` is not a string.
        ValueError: if ``bench_fn`` has no parameter with that name.
    """
    # validate input args
    if not isinstance(intensity_proxy_name, str):
        raise TypeError("intensity_proxy must be a string naming a parameter in target_fn")
    sig = inspect.signature(bench_fn)
    if intensity_proxy_name not in sig.parameters:
        # partials and other callables may not have a __name__
        fn_name = getattr(bench_fn, "__name__", repr(bench_fn))
        raise ValueError(f"Parameter '{intensity_proxy_name}' not found in {fn_name} signature")

    # decide how the proxy value is handed to bench_fn
    pos_index = list(sig.parameters).index(intensity_proxy_name)
    proxy_kind = sig.parameters[intensity_proxy_name].kind
    positional_kinds = (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
    # Positional injection only lands in the right slot when every parameter
    # before it is supplied positionally; otherwise bind the value by name.
    inject_positionally = proxy_kind in positional_kinds and pos_index <= len(args)

    def inject_proxy_and_call(val):
        if inject_positionally:
            return bench_fn(*args[:pos_index], val, *args[pos_index:], **kwargs)
        return bench_fn(*args, **{intensity_proxy_name: val}, **kwargs)

    # materialise once: the values are iterated here and again when writing the CSV
    proxy_values = list(intensity_proxy_values)

    # collect performance data
    perfs = []
    if verbose:
        print("=========================================")
        print(f"{out_path}...")
        print("=========================================")
    for val in proxy_values:
        perf = inject_proxy_and_call(val)
        perfs.append(perf)
        if verbose:
            # (flops | bytes) per nanosecond * 1e-3 == tera-units per second
            has_time = perf.time_ns > 0
            tflops = perf.flops / perf.time_ns * 1e-3 if has_time else 0.0
            tbps = perf.bytes / perf.time_ns * 1e-3 if has_time else 0.0
            print(f"{intensity_proxy_name}: {val:5d} | TFLOPS: {tflops:#.4g} | TBPS: {tbps:.2f}")
    # write to csv
    return write_csv(proxy_values, perfs, out_path)
def get_memset_tbps():
    """
    Measure device memset bandwidth through the CUDA driver API.

    Repeatedly issues ``cuMemsetD8Async`` over a 4 GiB (``1 << 32`` bytes)
    device buffer and converts the time reported by
    ``triton.testing.do_bench`` (treated as milliseconds) into TB/s.

    Returns:
        The measured bandwidth in TB/s.

    Raises:
        RuntimeError: if torch was not built for CUDA, if ``cuInit`` fails, or
            if the memset call itself reports an error.
        OSError: if the CUDA driver library cannot be loaded.
    """
    if torch.version.cuda is None:
        raise RuntimeError("get_memset_tbps is only supported on CUDA")
    try:
        cuda = ctypes.CDLL("libcuda.so")
    except OSError:
        # NOTE(review): some installs appear to ship only the versioned
        # soname, so retry with it before giving up.
        cuda = ctypes.CDLL("libcuda.so.1")
    cuda.cuInit.argtypes = [ctypes.c_uint]
    cuda.cuInit.restype = ctypes.c_int
    if cuda.cuInit(0) != 0:
        raise RuntimeError("cuInit failed")
    cuda.cuMemsetD8Async.argtypes = [ctypes.c_uint64, ctypes.c_ubyte, ctypes.c_size_t, ctypes.c_void_p]
    cuda.cuMemsetD8Async.restype = ctypes.c_int
    n_bytes = 1 << 32
    buf = torch.empty(n_bytes, device="cuda", dtype=torch.uint8)
    dptr = ctypes.c_uint64(buf.data_ptr())

    def memset():
        # a null stream handle is passed as the last argument
        return cuda.cuMemsetD8Async(dptr, ctypes.c_ubyte(0), ctypes.c_size_t(n_bytes), ctypes.c_void_p(0))

    # Fail loudly once up front: a rejected call would otherwise be timed as
    # if it had done the work and yield a meaningless bandwidth.
    status = memset()
    if status != 0:
        raise RuntimeError(f"cuMemsetD8Async failed with CUDA error code {status}")
    time_ms = triton.testing.do_bench(memset, rep=1000)
    # bytes / seconds -> B/s, then scale to TB/s
    tbps = (n_bytes / (time_ms * 1e-3)) * 1e-12
    return tbps
def get_cublas_tflops(dtype):
    """
    Benchmark an 8192 x 8192 x 8192 cuBLASLt matmul and return its throughput.

    Args:
        dtype: one of ``"fp16"``, ``"bf16"`` or ``"fp8"``; any other key raises
            ``KeyError``.

    Returns:
        ``2 * M * N * K`` divided by the time reported by
        ``triton.testing.do_bench`` (treated as milliseconds), scaled to TFLOP/s.
    """
    torch_dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.float8_e4m3fn}
    elem_type = torch_dtypes[dtype]

    # 32 MiB byte buffer handed to the cuBLASLt wrapper as workspace
    workspace = torch.empty(32 * 1024 * 1024, device="cuda", dtype=torch.uint8)
    handle = nvidia.cublas.CublasLt(workspace)

    dev = "cuda"
    size_m = size_n = size_k = 8192
    # operands are sampled in fp32 and then cast to the requested dtype
    lhs = torch.randn(size_m, size_k, device=dev, dtype=torch.float32).to(elem_type)
    rhs = torch.randn(size_k, size_n, device=dev, dtype=torch.float32).to(elem_type).T
    out = torch.empty((size_m, size_n), device=dev, dtype=elem_type)

    def run_matmul():
        return handle.matmul(lhs, rhs, out)

    elapsed_ms = triton.testing.do_bench(run_matmul, rep=1000)
    # flop / ms * 1e-9 == TFLOP/s
    return 2 * size_m * size_n * size_k / elapsed_ms * 1e-9
def _parse_int_field(text):
    """Parse a CSV cell as an int, also accepting float notation like "12.0" or "1e9"."""
    try:
        # exact path first so large integers are not rounded through a float
        return int(text)
    except ValueError:
        return int(float(text))


def load_perf_csv(path):
    """
    Load a perf CSV with columns ``x``, ``flops``, ``bytes`` and ``time_ns``.

    A ``time`` column is accepted in place of ``time_ns``; ``time_ns`` wins
    when both are present.

    Args:
        path: path of the CSV file to read.

    Returns:
        A tuple ``(xs, flops, bytes, times)`` of four equally long lists of ints.

    Raises:
        ValueError: if the file has no header or no ``time_ns``/``time``
            column, or if a cell cannot be parsed as a number.
    """
    xs, flops, bytes_, times = [], [], [], []
    with open(path, "r", newline="") as f:
        reader = csv.DictReader(f)
        # fieldnames is None for an empty file; treat that as "no columns"
        fieldnames = reader.fieldnames or []
        has_time_ns = "time_ns" in fieldnames
        has_time = "time" in fieldnames
        if not (has_time_ns or has_time):
            raise ValueError(f"CSV {path} missing time_ns/time column")
        time_key = "time_ns" if has_time_ns else "time"
        for row in reader:
            xs.append(int(row["x"]))
            # counts may have been written as floats (e.g. "1000.0")
            flops.append(_parse_int_field(row["flops"]))
            bytes_.append(_parse_int_field(row["bytes"]))
            times.append(int(float(row[time_key])))
    return xs, flops, bytes_, times
def validate_perfs(perfs):
    """
    Check that every series shares the x values of the first series.

    Args:
        perfs: non-empty sequence of ``(xs, flops, bytes, times)`` tuples; the
            first entry is the reference.

    Raises:
        ValueError: if a series has a different number of x values than the
            reference, or differs from it at any position. The message names
            the index of the offending series.
    """
    xs_ref = perfs[0][0]
    for series_idx, entry in enumerate(perfs[1:], start=1):
        xs = entry[0]
        # compare lengths first so the element-wise check cannot run out of bounds
        if len(xs) != len(xs_ref):
            raise ValueError(f"x length mismatch between series[0] and series[{series_idx}]")
        if any(x != x_ref for x, x_ref in zip(xs, xs_ref)):
            raise ValueError(f"x mismatch between series[0] and series[{series_idx}]")
def plot_roofline(series, flops_dtype, out_path, max_tbps, max_tflops, title="", xlabel="", labels=None):
    """
    Plot achieved TFLOP/s for one or more perf CSVs against a roofline.

    Args:
        series: paths of perf CSV files; all must share the same x values.
        flops_dtype: dtype name forwarded to ``get_cublas_tflops`` when the
            compute roof is measured.
        out_path: where the figure is saved.
        max_tbps: bandwidth roof in TB/s as a number, or ``"memset"`` to
            measure it with ``get_memset_tbps``.
        max_tflops: compute roof in TFLOP/s as a number, or ``"cublas"`` to
            measure it with ``get_cublas_tflops``.
        title: plot title.
        xlabel: x-axis label.
        labels: optional legend labels, one per series; a series without a
            label falls back to the stem of its file name.

    Raises:
        ValueError: if a roof is neither a number nor its supported keyword,
            or if the series do not share the same x values.
    """
    from bisect import bisect_left
    from numbers import Real
    from pathlib import Path

    perfs = [load_perf_csv(p) for p in series]
    validate_perfs(perfs)
    xs, flops_ref, bytes_ref, _ = perfs[0]

    # resolve the roofs: any real number is taken as-is, otherwise measure
    if not isinstance(max_tbps, Real):
        if max_tbps != "memset":
            raise ValueError(f"max_tbps must be a number or 'memset', got {max_tbps!r}")
        max_tbps = get_memset_tbps()
    if not isinstance(max_tflops, Real):
        if max_tflops != "cublas":
            raise ValueError(f"max_tflops must be a number or 'cublas', got {max_tflops!r}")
        max_tflops = get_cublas_tflops(flops_dtype)

    fig, ax = plt.subplots(figsize=(7, 5), dpi=120)
    try:
        ax.set_xlabel(xlabel)
        ax.set_ylabel("performance [TFLOP/s]")
        ax.set_title(title)
        xmin, xmax = min(xs), max(xs)
        # pad the x range by 5% on each side (or by 1 when there is a single x)
        dx = 0.05 * (xmax - xmin) if xmax > xmin else 1.0
        ax.set_xlim(xmin - dx, xmax + dx)
        ax.set_ylim(100, max_tflops + 500)

        # operational intensity (flops per byte) of the reference series
        opints = [f / b for f, b in zip(flops_ref, bytes_ref)]
        # First point whose intensity reaches the ridge max_tflops / max_tbps.
        # NOTE(review): bisect assumes opints is non-decreasing along xs — confirm.
        knee = bisect_left(opints, max_tflops / max_tbps)
        if knee > 0:
            # bandwidth-bound segment: from the first point up to the knee
            x_bw = [xs[0], xs[knee - 1]]
            y_bw = [opints[0] * max_tbps, max_tflops]
        else:
            x_bw = y_bw = []
        # compute-bound segment: flat at max_tflops from the knee onwards
        x_comp = xs[max(knee - 1, 0):]
        y_comp = [max_tflops] * len(x_comp)

        grey = "#7f7f7f"
        ax.plot(x_bw, y_bw, linestyle="--", color=grey, label=f"BW-bound - {max_tbps:.1f} TB/s [memset]", zorder=1)
        ax.plot(x_comp, y_comp, linestyle=":", color=grey, label=f"Compute-bound - {max_tflops:.0f} TFLOP/s [cuBLAS]",
                zorder=1)

        for idx, (pth, (_, f, _, t)) in enumerate(zip(series, perfs)):
            # flops per nanosecond * 1e-3 == TFLOP/s; zero-time samples plot as 0
            perf_tflops = [ff / tt * 1e-3 if tt > 0 else 0.0 for ff, tt in zip(f, t)]
            label = labels[idx] if labels and idx < len(labels) else Path(pth).stem
            ax.plot(xs, perf_tflops, label=label, linewidth=1.8, zorder=2)

        ax.legend(frameon=False, loc="lower right")
        ax.grid(True, which="both", ls=":", lw=0.5)
        fig.tight_layout()
        # save this figure explicitly rather than whatever pyplot considers current
        fig.savefig(out_path)
    finally:
        # release the figure so repeated calls do not accumulate open figures
        plt.close(fig)
if __name__ == "__main__":
    # Command-line entry point: plot one or more perf CSV series against a
    # roofline whose bandwidth and compute roofs are measured on this machine.
    import argparse
    parser = argparse.ArgumentParser(description="Plot roofline(s) from perf CSV series")
    parser.add_argument("--series", type=str, nargs="+", required=True,
                        help="list of .csv files; columns must be `x`, `flops`, `bytes`, `time_ns`")
    parser.add_argument("--dtype", type=str, required=True, choices=["fp16", "bf16", "fp8"],
                        help="data type used for compute-bound roof")
    parser.add_argument("--out_path", type=str, required=True, help="path to write the output image")
    parser.add_argument("--title", type=str, default="", help="plot title")
    parser.add_argument("--xlabel", type=str, default="", help="x-axis label")
    parser.add_argument("--labels", type=str, nargs="+", default=None,
                        help="optional list of names for each series, in order; must match number of --series")
    args = parser.parse_args()
    if args.labels is not None and len(args.labels) != len(args.series):
        parser.error("--labels must have the same number of entries as --series")
    # plot_roofline requires both roofs; "memset" / "cublas" ask it to measure them
    plot_roofline(args.series, args.dtype, args.out_path, max_tbps="memset", max_tflops="cublas", title=args.title,
                  xlabel=args.xlabel, labels=args.labels)