import inspect
import sqlite3
import threading
import weakref
import ast
from collections import OrderedDict
from typing import (
Any,
Callable,
Dict,
Final,
Iterator,
List,
Optional,
Tuple,
Type,
Union,
overload,
)
import triton
from triton._C import libentryC
import torch
import torch_npu
torch_device_fn = torch.npu
from .code_cache import config_cache_dir
# Number of devices reported by the active device backend.
DEVICE_COUNT = torch_device_fn.device_count()
# Leading "major.minor" components of the installed Triton version.  They are
# parsed with int() rather than eval() so that the version text is never
# executed as Python code.
version = triton.__version__.split(".")
major_version, minor_version = int(version[0]), int(version[1])
def quote_identifier(name: str) -> str:
    """Validate *name* as a plain identifier and return it in double quotes.

    The first character must be a letter or an underscore, and every
    character must be an ASCII letter, an ASCII digit, or an underscore.

    Raises:
        ValueError: if *name* is empty, starts with a disallowed character,
            or contains a character outside the allowed set.
    """
    if not name:
        raise ValueError("empty identifier")

    first = name[0]
    if first != "_" and not first.isalpha():
        raise ValueError("identifier must start with letter or _")

    legal_chars = frozenset(
        "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789_"
    )
    if any(ch not in legal_chars for ch in name):
        raise ValueError("identifier contains illegal char")

    # Double any embedded quote before wrapping the identifier in quotes.
    escaped = name.replace('"', '""')
    return f'"{escaped}"'
class LibTuner(triton.runtime.Autotuner):
    """Autotuner whose tuned configurations are persisted in a SQLite database.

    On construction, rows previously stored for this kernel are loaded into
    ``self.cache`` (see ``preload``).  ``store`` writes the cache back and is
    registered through ``weakref.finalize``.  Each kernel uses its own table,
    named after the wrapped function, in ``TunedConfig.db`` under
    ``config_cache_dir()``.
    """

    def __init__(
        self,
        fn,
        arg_names,
        configs,
        key,
        reset_to_zero,
        restore_value,
        pre_hook=None,
        post_hook=None,
        prune_configs_by: Optional[Dict] = None,
        warmup=None,
        rep=None,
        use_cuda_graph=False,
    ):
        """Initialise the base autotuner and load persisted configurations.

        The arguments are forwarded to ``triton.runtime.Autotuner.__init__``.
        Under Triton 2 only the subset accepted by that call is forwarded
        (``pre_hook``, ``post_hook`` and ``use_cuda_graph`` are not passed).
        """
        # Triton 2.x and 3.0/3.1: fall back to explicit benchmark defaults
        # when the caller did not provide them.
        if major_version == 2 or (major_version == 3 and minor_version <= 1):
            if warmup is None:
                warmup = 25
            if rep is None:
                rep = 100
        if major_version == 2:
            super().__init__(
                fn,
                arg_names,
                configs,
                key,
                reset_to_zero,
                restore_value,
                prune_configs_by,
                warmup,
                rep,
            )
            # Follow the ``.fn`` chain down to the plain Python function.
            self.base_fn = fn
            while not inspect.isfunction(self.base_fn):
                self.base_fn = self.base_fn.fn
        else:
            # NOTE(review): this branch relies on the base class to provide
            # ``self.base_fn`` -- confirm for the Triton versions in use.
            super().__init__(
                fn,
                arg_names,
                configs,
                key,
                reset_to_zero,
                restore_value,
                pre_hook,
                post_hook,
                prune_configs_by,
                warmup,
                rep,
                use_cuda_graph,
            )
        self.__name__ = self.base_fn.__name__
        # The table name is interpolated into SQL text, so it is validated
        # and quoted once here.
        self.table_name = quote_identifier(self.__name__)
        self.cache_path = config_cache_dir() / "TunedConfig.db"
        self.preload()
        # NOTE(review): ``self.store`` is a bound method, so the finalizer
        # holds a reference to this instance; per the ``weakref.finalize``
        # documentation such an object is not collected early and the
        # callback runs at interpreter exit.
        weakref.finalize(self, self.store)

    def preload(self):
        """Load every stored (key, config) row for this kernel into ``self.cache``.

        Creates the table if it does not exist.  ``self.volumn`` records the
        cache size after loading so that ``store`` can tell whether anything
        new was tuned.

        Raises:
            ValueError, SyntaxError: if a stored key or config value is not a
                Python literal.
            sqlite3.Error: on database failures.
        """
        connect = sqlite3.connect(self.cache_path)
        try:
            c = connect.cursor()
            c.execute(
                f"CREATE TABLE IF NOT EXISTS {self.table_name} (key TEXT PRIMARY KEY, config TEXT)"
            )
            cursor = c.execute(f"SELECT key, config from {self.table_name}")

            # Number of trailing non-kwarg fields in the stored config text.
            attrs = -5 if major_version == 2 else -4
            for key_str, config_str in cursor:
                # Keys are written with ``str(key)``.  Parsing the whole text
                # as one literal restores one-element and empty tuples
                # correctly and tolerates elements containing ", ".
                parsed_key = ast.literal_eval(key_str)
                if isinstance(parsed_key, (tuple, list)):
                    key = tuple(parsed_key)
                else:
                    key = (parsed_key,)

                # Config text is a ", "-separated list of "name: value" items:
                # the kernel kwargs first, then the fixed trailing fields.
                cfg_ls = [item.split(": ") for item in config_str.split(", ")]
                config = triton.Config({})
                for k, v in cfg_ls[:attrs]:
                    config.kwargs[k] = ast.literal_eval(v)
                config.num_warps = ast.literal_eval(cfg_ls[attrs][1])
                config.num_ctas = ast.literal_eval(cfg_ls[attrs + 1][1])
                config.num_stages = ast.literal_eval(cfg_ls[attrs + 2][1])
                if major_version == 2:
                    config.enable_warp_specialization = ast.literal_eval(
                        cfg_ls[attrs + 3][1]
                    )
                    config.enable_persistent = ast.literal_eval(cfg_ls[attrs + 4][1])
                else:
                    config.maxnreg = ast.literal_eval(cfg_ls[attrs + 3][1])
                self.cache[key] = config
        finally:
            # Release the connection even when a row fails to parse.
            connect.close()
        self.volumn = len(self.cache)

    def store(self):
        """Write the in-memory cache to the database.

        Does nothing when the cache size is unchanged since ``preload``.
        Rows whose key already exists are left untouched (``INSERT OR IGNORE``).
        """
        if len(self.cache) == self.volumn:
            return
        connect = sqlite3.connect(self.cache_path)
        try:
            c = connect.cursor()
            c.execute(
                f"CREATE TABLE IF NOT EXISTS {self.table_name} (key TEXT PRIMARY KEY, config TEXT)"
            )
            c.executemany(
                f"INSERT OR IGNORE INTO {self.table_name} (key, config) VALUES (?, ?)",
                ((str(key), str(config)) for key, config in self.cache.items()),
            )
            connect.commit()
        finally:
            connect.close()
def libtuner(
    configs,
    key,
    prune_configs_by=None,
    reset_to_zero=None,
    restore_value=None,
    pre_hook=None,
    post_hook=None,
    warmup=25,
    rep=100,
    use_cuda_graph=False,
):
    """
    Return a decorator that wraps a kernel in a ``LibTuner``.

    The decorated object must expose ``arg_names``; every option given here
    is forwarded unchanged to the ``LibTuner`` constructor.
    """
    # Keyword options are identical for every kernel decorated with this call.
    tuner_options = dict(
        pre_hook=pre_hook,
        post_hook=post_hook,
        prune_configs_by=prune_configs_by,
        warmup=warmup,
        rep=rep,
        use_cuda_graph=use_cuda_graph,
    )

    def wrap(kernel_fn):
        return LibTuner(
            kernel_fn,
            kernel_fn.arg_names,
            configs,
            key,
            reset_to_zero,
            restore_value,
            **tuner_options,
        )

    return wrap
class LibEntry(triton.KernelInterface):
    """Launcher that caches the kernel returned by a wrapped Triton function.

    The first call for a given argument signature is delegated to
    ``self.fn.run`` and its result is cached per device.  Later calls with
    the same signature launch the cached kernel directly.
    """

    def __init__(
        self,
        fn,
    ):
        """Wrap *fn*, which is a ``JITFunction`` or a chain of wrappers around one."""
        self.fn = fn
        self.arg_names = fn.arg_names
        # Pointer alignment (in bytes) recorded in the cache key.
        self.divisibility = 16
        # One cache dictionary per device index.
        self.kernel_cache = tuple(dict() for _ in range(DEVICE_COUNT))

        # Follow the ``.fn`` chain down to the underlying JITFunction.
        while not isinstance(fn, triton.runtime.JITFunction):
            fn = fn.fn
        self.jit_function: triton.runtime.JITFunction = fn
        self.specialize_indices = [
            p.num
            for p in self.jit_function.params
            if not p.is_constexpr and not p.do_not_specialize
        ]
        self.do_not_specialize_indices = [
            p.num
            for p in self.jit_function.params
            if not p.is_constexpr and p.do_not_specialize
        ]
        # Serialises the first run for each cache key (see ``run``).
        self.lock = threading.Lock()
        self.signature = fn.signature

    def key(self, spec_args, dns_args, const_args):
        """Build the hashable cache key for one call.

        Args:
            spec_args: values of specialised parameters.  Objects with
                ``data_ptr`` contribute ``(dtype, aligned)``; anything else
                contributes ``(type, value)``.
            dns_args: values of do-not-specialise parameters.  Only the dtype,
                the type, or an integer width tag is recorded.
            const_args: constexpr values, used as-is.

        Returns:
            A tuple of the three groups of key parts, in that order.
        """

        def spec_arg(arg):
            if hasattr(arg, "data_ptr"):
                return (arg.dtype, arg.data_ptr() % self.divisibility == 0)
            return (type(arg), arg)

        def dns_arg(arg):
            if hasattr(arg, "data_ptr"):
                return arg.dtype
            if not isinstance(arg, int):
                return type(arg)
            if -(2**31) <= arg <= 2**31 - 1:
                return "i32"
            if 2**63 <= arg <= 2**64 - 1:
                return "u64"
            return "i64"

        spec_key = [spec_arg(arg) for arg in spec_args]
        dns_key = [dns_arg(arg) for arg in dns_args]
        # Unpacking accepts any iterable for ``const_args``, not only a list.
        return (*spec_key, *dns_key, *const_args)

    def run(self, *args, **kwargs):
        """Launch the kernel, using the cached kernel when one exists.

        ``kwargs`` must contain ``grid``: either a sequence or a callable
        that receives the merged argument/constexpr mapping and returns one.

        Returns:
            ``(kernel, constexprs)`` -- the kernel object and the constexpr
            values collected for it.

        Raises:
            KeyError: if ``grid`` is missing from ``kwargs``.
            RuntimeError: if a wrapper in the ``fn`` chain is neither an
                Autotuner nor a Heuristics object, or if a kernel parameter
                could not be resolved for the launch.
        """
        grid = kwargs["grid"]

        # Collect arguments by category.
        spec_args = []  # specialised arguments
        dns_args = []  # do-not-specialise arguments
        const_args = []  # constexpr arguments
        k_args = OrderedDict()  # arguments handed to the cached kernel
        param_names = list(self.signature.parameters.keys())
        # For these Triton versions constexpr values are also included in the
        # argument list passed to the cached kernel.
        pass_constexprs = major_version == 3 and 3 <= minor_version <= 6

        for i, arg in enumerate(args):
            hashable_arg = arg
            if (
                hasattr(arg, "__class__")
                and arg.__class__.__name__ == "TensorDescriptor"
            ):
                # Key on a tuple of the descriptor's attributes instead of
                # the descriptor object itself.
                hashable_arg = (
                    "TensorDescriptor",
                    tuple(arg.shape) if hasattr(arg, "shape") else None,
                    tuple(arg.strides) if hasattr(arg, "strides") else None,
                    tuple(arg.block_shape) if hasattr(arg, "block_shape") else None,
                    arg.padding if hasattr(arg, "padding") else None,
                )
            if i in self.specialize_indices:
                k_args[param_names[i]] = arg
                spec_args.append(hashable_arg)
            elif i in self.do_not_specialize_indices:
                k_args[param_names[i]] = arg
                dns_args.append(hashable_arg)
            else:
                if pass_constexprs:
                    k_args[param_names[i]] = arg
                const_args.append(hashable_arg)

        # Parameters not given positionally: take them from kwargs, else from
        # the declared default; parameters with neither are skipped.
        for p in self.jit_function.params[len(args) :]:
            if p.name in kwargs:
                val = kwargs[p.name]
            elif p.default is inspect._empty:
                continue
            else:
                val = p.default

            if p.is_constexpr:
                const_args.append(val)
                if pass_constexprs:
                    k_args[p.name] = val
            elif p.do_not_specialize:
                dns_args.append(val)
                k_args[p.name] = val
            else:
                spec_args.append(val)
                k_args[p.name] = val

        entry_key = self.key(spec_args, dns_args, const_args)
        device = torch_device_fn.current_device()
        cache = self.kernel_cache[device]
        while entry_key not in cache:
            # The first run for a key happens under the lock; the key is
            # checked again once the lock is held in case another thread
            # filled the cache in the meantime.
            with self.lock:
                if entry_key in cache:
                    break
                kernel = self.fn.run(*args, **kwargs)
                fn = self.fn
                # Collect constexpr values needed later for grid computation.
                constexprs = {}
                tune_constexprs = {}
                heur_constexprs = {}
                while not isinstance(fn, triton.runtime.JITFunction):
                    if isinstance(fn, triton.runtime.Autotuner):
                        config = fn.best_config
                        constexprs["num_warps"] = config.num_warps
                        constexprs["num_stages"] = config.num_stages
                        constexprs["num_ctas"] = config.num_ctas
                        constexprs = {**constexprs, **config.kwargs}
                        tune_constexprs = {**tune_constexprs, **config.kwargs}
                    elif isinstance(fn, triton.runtime.Heuristics):
                        for v, heur in fn.values.items():
                            heur_constexprs[v] = heur(
                                {
                                    **dict(zip(fn.arg_names, args)),
                                    **kwargs,
                                    **constexprs,
                                }
                            )
                            constexprs[v] = heur_constexprs[v]
                    else:
                        raise RuntimeError("Invalid Runtime Function")
                    fn = fn.fn
                # Fill in remaining constexprs from their declared defaults.
                for p in self.jit_function.params:
                    if (
                        p.is_constexpr
                        and p.name not in constexprs
                        and (p.default is not inspect._empty)
                    ):
                        constexprs[p.name] = p.default
                cache[entry_key] = (
                    kernel,
                    constexprs,
                    tune_constexprs,
                    heur_constexprs,
                )
            # ``self.fn.run`` has already been invoked for this call.
            return kernel, constexprs

        kernel, constexprs, tune_constexprs, heur_constexprs = cache[entry_key]

        if callable(grid):
            meta = {**dict(zip(self.arg_names, args)), **kwargs, **constexprs}
            grid = grid(meta)
        # Pad to at least three dimensions; ``tuple()`` lets a list-valued
        # grid work here instead of raising TypeError on concatenation.
        grid = tuple(grid) + (1, 1)

        if pass_constexprs:
            # Assemble the full argument list in signature order, resolving
            # each parameter from the call arguments first and then from the
            # cached tuned / heuristic / constexpr values.
            all_args = []
            missing_keys = []
            for name in param_names:
                if name in k_args:
                    all_args.append(k_args[name])
                elif name in tune_constexprs:
                    all_args.append(tune_constexprs[name])
                elif name in heur_constexprs:
                    all_args.append(heur_constexprs[name])
                elif name in constexprs:
                    all_args.append(constexprs[name])
                else:
                    missing_keys.append(name)
            if missing_keys:
                raise RuntimeError(
                    f"[libentry]: probably a bug, the following kernel params where not captured: {missing_keys}"
                )
            kernel[grid[0:3]](*all_args)
        else:
            kernel[grid[0:3]](*k_args.values())
        return kernel, constexprs
def libentry():
    """
    Return a decorator that wraps a Triton function in a ``LibEntry``.

    Functions that are ``InterpretedFunction`` instances are returned
    unchanged.
    """

    def wrap(kernel_fn):
        from triton.runtime.interpreter import InterpretedFunction

        return (
            kernel_fn
            if isinstance(kernel_fn, InterpretedFunction)
            else LibEntry(kernel_fn)
        )

    return wrap