import os
from contextlib import contextmanager
import torch
import triton
import triton.language as tl
@contextmanager
def enable_dump_context(pass_name="1"):
    """Temporarily set the ``MLIR_ENABLE_DUMP`` environment variable.

    Args:
        pass_name: Value assigned to ``MLIR_ENABLE_DUMP`` while the ``with``
            block runs. Callers in this file pass ``"1"`` (the default) or a
            kernel name such as ``"_kernel"``.

    Yields:
        None.

    On exit, including when the body raises, the variable is put back exactly
    as it was before entry: the previous value is restored, or the variable is
    removed if it was unset. This keeps the setting from leaking into later
    code and avoids overwriting a value the surrounding environment chose.
    """
    previous = os.environ.get("MLIR_ENABLE_DUMP")
    os.environ["MLIR_ENABLE_DUMP"] = pass_name
    try:
        yield
    finally:
        if previous is None:
            # The variable was absent before entry, so drop it instead of
            # leaving behind a value nobody set.
            os.environ.pop("MLIR_ENABLE_DUMP", None)
        else:
            os.environ["MLIR_ENABLE_DUMP"] = previous
def test_fn_dump(capfd, device, fresh_triton_cache):
    """Check that MLIR IR dumps appear only when the dump filter matches.

    The same kernel is launched three times under different
    ``MLIR_ENABLE_DUMP`` settings: ``"1"``, the kernel's own name, and an
    unrelated name. Only the first two are expected to produce
    ``IR Dump Before`` output on stderr.

    Each launch uses a different ``BLOCK_SIZE`` constexpr. This presumably
    forces a new compilation, and therefore new dump output, instead of a
    cache hit (TODO confirm against Triton's JIT cache keying).
    """
    n_elems = 1024
    buf = torch.zeros(n_elems, device=device)

    def grid(meta):
        # One program per BLOCK_SIZE-sized chunk of the buffer.
        return (triton.cdiv(n_elems, meta["BLOCK_SIZE"]), )

    @triton.jit
    def _kernel(src, N, BLOCK_SIZE: tl.constexpr):
        offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        x = tl.load(src + offsets, mask=offsets < N) + 1
        tl.store(src + offsets, x, mask=offsets < N)

    def launch_and_capture(dump_filter, block_size):
        # Run one launch with the given dump filter and return the stderr
        # captured since the previous read.
        with enable_dump_context(dump_filter):
            _kernel[grid](buf, n_elems, block_size)
        return capfd.readouterr().err

    # Dump everything.
    stderr_all = launch_and_capture("1", 16)
    print(stderr_all)
    assert "IR Dump Before" in stderr_all
    assert "tt.func public @_kernel" in stderr_all

    # Dump filtered to this kernel's name, so output is still expected.
    stderr_named = launch_and_capture("_kernel", 32)
    assert "IR Dump Before" in stderr_named
    assert "tt.func public @_kernel" in stderr_named

    # Dump filtered to a different kernel name, so nothing is expected.
    stderr_other = launch_and_capture("_kernel2", 64)
    assert "IR Dump Before" not in stderr_other