import functools
import hashlib
import importlib
import os
import shutil
import subprocess
import sysconfig
import tempfile
from pathlib import Path
import torch
import triton
import triton.language as tl
from triton.common.backend import (BaseBackend, compute_core_version_key, register_backend)
from triton.compiler.make_launcher import make_so_cache_key
from triton.runtime.cache import get_cache_manager
from triton.runtime.driver import DriverBase
def build_for_backend(name, src, srcdir):
    """Compile one C source file into a shared library next to ``srcdir``.

    Args:
        name: Base name of the output file; the interpreter's ``EXT_SUFFIX``
            is appended to it.
        src: Path of the C source file handed to the compiler.
        srcdir: Directory that receives the output and is also added to the
            include search path.

    Returns:
        The path of the shared object the compiler was asked to write.

    Raises:
        RuntimeError: If ``CC`` is unset and neither ``gcc`` nor ``clang`` is
            found on ``PATH``.
        subprocess.CalledProcessError: If the compiler exits with a non-zero
            status.
    """
    ext_suffix = sysconfig.get_config_var('EXT_SUFFIX')
    output_path = os.path.join(srcdir, f"{name}{ext_suffix}")

    # An explicit CC always wins; otherwise probe PATH, preferring gcc.
    compiler = os.environ.get("CC")
    if compiler is None:
        for candidate in ("gcc", "clang"):
            compiler = shutil.which(candidate)
            if compiler is not None:
                break
        else:
            raise RuntimeError("Failed to find C compiler. Please specify via CC environment variable.")

    # Use the public accessor when sysconfig has one, else the private one.
    scheme_getter = getattr(sysconfig, 'get_default_scheme', None)
    if scheme_getter is None:
        scheme_getter = sysconfig._get_default_scheme
    scheme = scheme_getter()
    # The 'posix_local' scheme is swapped for 'posix_prefix' before the
    # include directory is looked up.
    if scheme == 'posix_local':
        scheme = 'posix_prefix'
    python_include = sysconfig.get_paths(scheme=scheme)["include"]

    command = [
        compiler,
        src,
        f"-I{python_include}",
        f"-I{srcdir}",
        "-shared",
        "-fPIC",
        "-o",
        output_path,
    ]
    subprocess.check_call(command)
    return output_path
class ExtensionUtils:
    """Singleton wrapper around the compiled ``extension_backend.c`` helpers.

    Exposes two callables taken from the compiled module: ``load_binary`` and
    ``get_device_properties``.
    """

    def __new__(cls):
        # Reuse the object stored on the class once it has been created.
        if hasattr(cls, 'instance'):
            return cls.instance
        cls.instance = super().__new__(cls)
        return cls.instance

    def __init__(self):
        # NOTE: this runs on every ExtensionUtils() call, including calls for
        # which __new__ returned the already-existing object.
        here = os.path.dirname(os.path.realpath(__file__))
        source_text = (Path(here) / "extension_backend.c").read_text()

        # The cache entry is keyed on the SHA-256 of the C source text.
        digest = hashlib.sha256(source_text.encode("utf-8")).hexdigest()
        manager = get_cache_manager(digest)
        so_name = "ext_utils.so"
        so_path = manager.get_file(so_name)

        if so_path is None:
            # Nothing cached under this key: compile in a scratch directory
            # and hand the resulting bytes to the cache manager.
            with tempfile.TemporaryDirectory() as build_dir:
                c_path = os.path.join(build_dir, "main.c")
                with open(c_path, "w") as handle:
                    handle.write(source_text)
                built = build_for_backend("ext_utils", c_path, build_dir)
                with open(built, "rb") as handle:
                    so_path = manager.put(handle.read(), so_name, binary=True)

        import importlib.util
        module_spec = importlib.util.spec_from_file_location("ext_utils", so_path)
        module = importlib.util.module_from_spec(module_spec)
        module_spec.loader.exec_module(module)

        self.load_binary = module.load_binary
        self.get_device_properties = module.get_device_properties
class ExtensionDriver(DriverBase):
    """Singleton driver whose only state is an ``ExtensionUtils`` handle."""

    def __new__(cls):
        # Hand back the stored object when one has already been created.
        if hasattr(cls, 'instance'):
            return cls.instance
        cls.instance = super().__new__(cls)
        return cls.instance

    def __init__(self):
        # Runs on every ExtensionDriver() call, so ``utils`` is reassigned
        # each time.
        self.utils = ExtensionUtils()
class ExtensionBackend(BaseBackend):
    """Dummy Triton backend that stops compilation at TTGIR.

    Its launcher stub does not run any kernel; it only counts how many times
    it was invoked. The path of the most recently resolved stub is published
    on the class as ``stub_so_path``.
    """

    # Path of the last launcher stub returned by make_launcher_stub().
    stub_so_path = ""

    def __init__(self, device_type: str) -> None:
        """Create the backend for ``device_type`` and attach its driver."""
        super().__init__(device_type)
        self.driver = ExtensionDriver()
        self.version_key = None
        # Per-instance memo for get_device_properties(), keyed by device.
        # A functools.lru_cache on the method would key on ``self`` and keep
        # every instance alive for the lifetime of the cache.
        self._device_properties_cache = {}

    def add_stages(self, stages, options, language):
        """Drop every stage except ``ast``, ``ttir`` and ``ttgir``, in place.

        ``options`` and ``language`` are accepted but not used.
        """
        kept_stages = ("ast", "ttir", "ttgir")
        # Collect first: a dict must not be resized while it is iterated.
        for stage_name in [key for key in stages if key not in kept_stages]:
            stages.pop(stage_name)

    def add_meta_info(self, ir, cur_module, next_module, metadata, asm):
        """Record a fixed kernel name in ``metadata``; other args are unused."""
        metadata["name"] = "extension_backend_name"

    def get_driver(self):
        """Return the driver object created in ``__init__``."""
        return self.driver

    def get_stream(self):
        """Return an empty string as the stream value."""
        return ""

    def get_device_properties(self, device):
        """Return the driver's device properties, memoized per ``device``.

        ``device`` is used only as the cache key (so it must be hashable); the
        underlying driver call takes no arguments.
        """
        cache = self._device_properties_cache
        if device not in cache:
            cache[device] = self.driver.utils.get_device_properties()
        return cache[device]

    def get_current_device(self):
        """Return ``torch.device("cpu")``."""
        return torch.device("cpu")

    def set_current_device(self, device):
        """Do nothing; ``device`` is ignored."""
        pass

    def get_load_binary_fn(self):
        """Return the ``load_binary`` callable exposed by the driver utils."""
        return self.driver.utils.load_binary

    def get_kernel_bin(self):
        """Return ``"ttgir"`` as the kernel binary identifier."""
        return "ttgir"

    def get_architecture_descriptor(self, **kwargs):
        """Return an empty string; keyword arguments are ignored."""
        return ""

    def get_version_key(self):
        """Return the core version key, computing it on first use."""
        if self.version_key is None:
            self.version_key = compute_core_version_key()
        return self.version_key

    def make_launcher_stub(self, name, signature, constants):
        """Return the path of the launcher stub, building it when not cached.

        The cache key is derived from the version key, ``signature`` and
        ``constants``. The resolved path is also stored on the class as
        ``stub_so_path``.
        """
        so_cache_key = make_so_cache_key(self.get_version_key(), signature, constants)
        so_cache_manager = get_cache_manager(so_cache_key)
        so_name = f"{name}.so"
        so_path = so_cache_manager.get_file(so_name)
        if so_path is None:
            # Nothing cached under this key: build in a scratch directory and
            # hand the resulting bytes to the cache manager.
            with tempfile.TemporaryDirectory() as tmpdir:
                src = self._generate_launcher(constants, signature)
                src_path = os.path.join(tmpdir, "main.c")
                with open(src_path, "w") as f:
                    f.write(src)
                so = build_for_backend(name, src_path, tmpdir)
                with open(so, "rb") as f:
                    so_path = so_cache_manager.put(f.read(), so_name, binary=True)
        type(self).stub_so_path = so_path
        return so_path

    def _generate_launcher(self, constants, signature):
        """Return the C source of the counting launcher module.

        ``constants`` and ``signature`` are accepted but not used. The module
        exports ``launch``, which bumps a counter and returns ``None``, and
        ``launch_counter``, which bumps the same counter and returns it.
        """
        # Compared with a naive version, launch() releases the PyLong returned
        # by launch_counter() instead of leaking it, and the int64_t counter
        # is converted with PyLong_FromLongLong so it is not truncated where
        # ``long`` is 32 bits wide.
        src = """
#define __EXTENSION_BACKEND__
#include <Python.h>
#include <stdio.h>
static PyObject* launch_counter(PyObject* self, PyObject* args) {
static int64_t launch_counter = 0;
launch_counter += 1;
return PyLong_FromLongLong(launch_counter);
}
static PyObject* launch(PyObject* self, PyObject* args) {
if (PyErr_Occurred()) {
return NULL;
}
PyObject* count = launch_counter(self, args);
Py_XDECREF(count);
Py_RETURN_NONE;
}
static PyMethodDef ModuleMethods[] = {
{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"},
{"launch_counter", launch_counter, METH_VARARGS, "Entry point to get launch counter"},
{NULL, NULL, 0, NULL} // sentinel
};
static struct PyModuleDef ModuleDef = {
PyModuleDef_HEAD_INIT,
\"__triton_launcher\",
NULL, //documentation
-1, //size
ModuleMethods
};
PyMODINIT_FUNC PyInit___triton_launcher(void) {
PyObject *m = PyModule_Create(&ModuleDef);
if(m == NULL) {
return NULL;
}
PyModule_AddFunctions(m, ModuleMethods);
return m;
}
"""
        return src
def test_dummy_backend():
    # End-to-end check that kernels launched after registering
    # ExtensionBackend for "cpu" go through its launcher stub.
    register_backend("cpu", ExtensionBackend)

    @triton.jit
    def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
        # The xnumel argument is overwritten with the constant 10.
        xnumel = 10
        xoffset = tl.program_id(0) * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:]
        # Mask off lanes at or beyond xnumel.
        xmask = xindex < xnumel
        x0 = xindex
        tmp0 = tl.load(in_ptr0 + (x0), xmask)
        tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp0, xmask)

    inp = torch.randn(10)
    out = torch.randn(10)
    # First launch: presumably triggers compilation, which is what fills in
    # ExtensionBackend.stub_so_path via make_launcher_stub -- TODO confirm.
    kernel[(10, )](inp, out, 10, XBLOCK=16)
    # Load the stub shared object directly to reach its launch_counter function.
    # NOTE(review): importlib.util is only imported inside
    # ExtensionUtils.__init__; this line relies on that having run.
    spec = importlib.util.spec_from_file_location("__triton_launcher", ExtensionBackend.stub_so_path)
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    launch_counter = getattr(mod, "launch_counter")

    for _ in range(100):
        kernel[(10, )](inp, out, 10, XBLOCK=16)

    # NOTE(review): launch_counter() increments the counter before returning it
    # (see the C source in _generate_launcher), so this holds even if no launch
    # reached the stub -- confirm whether a stricter bound is intended.
    assert launch_counter() > 0