import functools
import os
import re
import subprocess
import tempfile
# First line of an instruction record in `cuobjdump -sass` output: a `/*XXXX*/`
# offset comment, the instruction text up to and including the first ';'
# (group 1), then a `/* 0x<16 chars> */` encoding comment (group 2).
# NOTE(review): the offset comment must be exactly 4 word characters; an offset
# printed with more digits would not match and would end parsing — TODO confirm
# cuobjdump's formatting for large kernels.
FLINE_RE = re.compile(r'\s*/\*\w{4}\*/\s*([^;]*;)\s*/\* 0x(\w{16}) \*/\s*')
# Second line of an instruction record: only a `/* 0x<16 chars> */` encoding
# comment (group 1). parseCtrl decodes the scheduling control bits from it.
SLINE_RE = re.compile(r'\s*/\* 0x(\w{16}) \*/\s*')
# A "Function : <name>" header line; group 1 is the function name.
FNAME_RE = re.compile(r'\s*Function : (\w+)\s*')
# A `BRA` / `BRA.U` instruction ending in ';': group 1 is everything through
# "BRA " or "BRA.U ", group 2 is the hex branch-target address.
BRA_RE = re.compile(r'(.*BRA(?:\.U)? )(0x\w+);')
def parseCtrl(sline):
    """Render the control bits of a SASS encoding line as a compact string.

    The 64-bit word captured by ``SLINE_RE`` is sliced into five fields (named
    after the original variable names): stall (bits 41-44), yield (bit 45),
    write barrier (bits 46-48), read barrier (bits 49-51) and wait mask
    (bits 52-57).

    Args:
        sline: the second line of an instruction record; must match ``SLINE_RE``.

    Returns:
        ``"<wait>:<read>:<write>:<yield>:<stall>"``, where
        - wait is two decimal digits, or ``--`` when the mask is 0;
        - read/write barriers print ``-`` for the value 7;
        - yield prints ``Y`` when its bit is clear and ``-`` otherwise;
        - stall is printed in hex.
    """
    ctrl_word = int(SLINE_RE.match(sline).group(1), 16)

    def field(shift, mask):
        return (ctrl_word >> shift) & mask

    stall = field(41, 0xf)
    yield_bit = field(45, 0x1)
    write_barrier = field(46, 0x7)
    read_barrier = field(49, 0x7)
    wait_mask = field(52, 0x3f)

    pieces = [
        f'{wait_mask:02d}' if wait_mask else '--',
        '-' if read_barrier == 7 else str(read_barrier),
        '-' if write_barrier == 7 else str(write_barrier),
        '-' if yield_bit else 'Y',
        f'{stall:x}',
    ]
    return ':'.join(pieces)
def processSassLines(fline, sline, labels):
    """Parse one two-line SASS instruction record.

    Args:
        fline: the instruction line; must match ``FLINE_RE``.
        sline: the following encoding line; must match ``SLINE_RE``.
        labels: mapping of branch-target address -> label index. It is updated
            in place: each newly seen ``BRA`` target gets the next free index.

    Returns:
        A ``(ctrl, asm)`` tuple, where
        - ``ctrl`` is the string produced by ``parseCtrl``;
        - ``asm`` is the instruction text with a trailing ``" ;"`` collapsed
          to ``";"``.
    """
    instr = FLINE_RE.match(fline).group(1)
    if instr.endswith(" ;"):
        instr = f'{instr[:-2]};'
    ctrl = parseCtrl(sline)
    branch = BRA_RE.match(instr)
    if branch is not None:
        # First sighting of a target assigns it the next sequential label index.
        labels.setdefault(int(branch.group(2), 16), len(labels))
    return ctrl, instr
@functools.lru_cache()
def get_sass(cubin_asm, fun=None):
    """Disassemble an in-memory cubin image by staging it in a temporary file.

    The bytes are written to a fresh temporary file and handed to ``extract``.
    The file is deleted afterwards whether or not extraction succeeds. Results
    are memoized per ``(cubin_asm, fun)`` by ``functools.lru_cache`` with its
    default size, so both arguments must be hashable.

    Args:
        cubin_asm: raw cubin contents, written verbatim in binary mode.
        fun: optional function name forwarded to ``extract``.

    Returns:
        Whatever ``extract`` returns for the staged file.
    """
    handle, cubin_path = tempfile.mkstemp()
    try:
        # Close the file before extracting so the external disassembler sees every byte.
        with os.fdopen(handle, 'wb') as cubin_file:
            cubin_file.write(cubin_asm)
        return extract(cubin_path, fun)
    finally:
        os.remove(cubin_path)
def path_to_cuobjdump():
    """Return the cuobjdump executable path configured in Triton's knobs.

    ``triton`` is imported lazily, at call time, rather than at module import.
    """
    from triton import knobs as triton_knobs
    cuobjdump_knob = triton_knobs.nvidia.cuobjdump
    return cuobjdump_knob.path
def extract(file_path, fun):
    """Run ``cuobjdump -sass`` on a file and render its first function as text.

    Only the first ``Function : <name>`` header in the tool's output is
    rendered. The output format is:
    - a ``Function:<name>`` line;
    - one ``<ctrl>\\t<asm>`` line per instruction, where ``ctrl`` comes from
      ``parseCtrl`` and ``asm`` from ``processSassLines``;
    - a trailing blank line.

    Branch targets are printed as ``LBB<n>:`` labels, and ``BRA`` operands are
    rewritten from hex addresses to those label names.

    Truncated output no longer raises IndexError. Parsing simply stops at the
    last complete two-line instruction record.

    Args:
        file_path: path of the binary passed to cuobjdump.
        fun: when not None, forwarded as cuobjdump's ``-fun`` filter.

    Returns:
        The rendered listing, or None when the output has no function header.

    Raises:
        subprocess.CalledProcessError: if cuobjdump exits with non-zero status.
    """
    cuobjdump = path_to_cuobjdump()
    if fun is None:
        cmd = [cuobjdump, "-sass", file_path]
    else:
        cmd = [cuobjdump, "-fun", fun, "-sass", file_path]
    sass_lines = subprocess.check_output(cmd).splitlines()
    num_lines = len(sass_lines)

    def line_at(i):
        # Decode lazily (only lines we inspect); past-the-end reads yield ''
        # so that truncated output terminates parsing instead of raising.
        return sass_lines[i].decode() if i < num_lines else ''

    # Locate the first "Function : <name>" header.
    line_idx = 0
    while line_idx < num_lines and FNAME_RE.match(line_at(line_idx)) is None:
        line_idx += 1
    if line_idx >= num_lines:
        return None
    fname = FNAME_RE.match(line_at(line_idx)).group(1)
    line_idx += 2  # skip the header line and the following .headerflags line

    # Each instruction is two lines: text + first encoding word, then the second
    # encoding word carrying the control bits. Require both lines to be present.
    labels = {}  # branch-target address -> label index (filled by processSassLines)
    asm_buffer = []  # (ctrl, asm) per instruction, buffered so labels are known before printing
    while line_idx + 1 < num_lines and FLINE_RE.match(line_at(line_idx)) is not None:
        asm_buffer.append(processSassLines(line_at(line_idx), line_at(line_idx + 1), labels))
        line_idx += 2

    out = [f'Function:{fname}\n']
    for idx, (ctrl, asm) in enumerate(asm_buffer):
        # Assumes instructions are 16 bytes each and contiguous from offset 0
        # — TODO confirm for every target this is used with.
        offset = idx * 16
        if offset in labels:
            out.append(f'LBB{labels[offset]}:\n')
        branch = BRA_RE.match(asm)
        if branch is not None:
            # processSassLines registered every BRA target, so this lookup succeeds.
            target_name = f'LBB{labels[int(branch.group(2), 16)]}'
            asm = BRA_RE.sub(rf'\1{target_name};', asm)
        out.append(f'{ctrl}\t{asm}\n')
    out.append('\n')
    return ''.join(out)