import json
import os
import uuid
from abc import ABC, abstractmethod
from typing import Dict, List, Optional
import base64
import hashlib
import functools
import sysconfig
from triton import __version__, knobs
class CacheManager(ABC):
    """Abstract interface for storing and retrieving cached compilation artifacts.

    Implementations are constructed with a cache *key* and optional
    *override* / *dump* flags, and expose both single-file access
    (``get_file`` / ``put``) and grouped access (``get_group`` / ``put_group``).
    """

    def __init__(self, key, override=False, dump=False):
        """Accept the construction arguments; the base class keeps no state."""

    @abstractmethod
    def get_file(self, filename) -> Optional[str]:
        """Return a path for *filename* if it is cached, otherwise ``None``."""

    @abstractmethod
    def put(self, data, filename, binary=True) -> str:
        """Store *data* under *filename* and return the path it was written to."""

    @abstractmethod
    def get_group(self, filename: str) -> Optional[Dict[str, str]]:
        """Return the child-name -> path mapping recorded for *filename*, or ``None``."""

    @abstractmethod
    def put_group(self, filename: str, group: Dict[str, str]):
        """Record *group* (child name -> path) under the group name *filename*."""
class FileCacheManager(CacheManager):
    """Cache manager that keeps artifacts as files in ``<root>/<key>``.

    The root directory comes from ``knobs.cache``: ``dump_dir`` when *dump*
    is set, else ``override_dir`` when *override* is set, else ``dir``.
    """

    def __init__(self, key, override=False, dump=False):
        self.key = key
        # Only the dump and default modes set a lock path; ``put`` asserts it
        # is set, so writes through an override-mode manager are rejected.
        self.lock_path = None
        if dump:
            self.cache_dir = os.path.join(knobs.cache.dump_dir, self.key)
            self.lock_path = os.path.join(self.cache_dir, "lock")
            os.makedirs(self.cache_dir, exist_ok=True)
        elif override:
            # Override directories are only looked up, never created here.
            self.cache_dir = os.path.join(knobs.cache.override_dir, self.key)
        else:
            self.cache_dir = knobs.cache.dir
            if self.cache_dir:
                self.cache_dir = os.path.join(self.cache_dir, self.key)
                self.lock_path = os.path.join(self.cache_dir, "lock")
                os.makedirs(self.cache_dir, exist_ok=True)
            else:
                raise RuntimeError("Could not create or locate cache dir")

    def _make_path(self, filename) -> str:
        """Return the absolute location of *filename* inside this cache dir."""
        return os.path.join(self.cache_dir, filename)

    def has_file(self, filename) -> bool:
        """Return whether *filename* exists in this cache dir.

        Raises:
            RuntimeError: if no cache directory is configured.
        """
        if not self.cache_dir:
            raise RuntimeError("Could not create or locate cache dir")
        return os.path.exists(self._make_path(filename))

    def get_file(self, filename) -> Optional[str]:
        """Return the path of *filename* if it is cached, otherwise ``None``."""
        if self.has_file(filename):
            return self._make_path(filename)
        return None

    def get_group(self, filename: str) -> Optional[Dict[str, str]]:
        """Load the group recorded under *filename*.

        Returns:
            A mapping of child name -> path restricted to children whose path
            still exists on disk, or ``None`` if the group file is missing or
            has no ``child_paths`` entry.
        """
        grp_filename = f"__grp__{filename}"
        if not self.has_file(grp_filename):
            return None
        grp_filepath = self._make_path(grp_filename)
        with open(grp_filepath) as f:
            grp_data = json.load(f)
        child_paths = grp_data.get("child_paths", None)
        if child_paths is None:
            return None
        # Silently drop children whose files have disappeared since the group was written.
        return {c: p for c, p in child_paths.items() if os.path.exists(p)}

    def put_group(self, filename: str, group: Dict[str, str]) -> str:
        """Record *group* (child name -> path) as JSON and return the group file path.

        Raises:
            RuntimeError: if no cache directory is configured.
        """
        if not self.cache_dir:
            raise RuntimeError("Could not create or locate cache dir")
        grp_contents = json.dumps({"child_paths": group})
        grp_filename = f"__grp__{filename}"
        return self.put(grp_contents, grp_filename, binary=False)

    def put(self, data, filename, binary=True) -> str:
        """Atomically write *data* to ``<cache_dir>/<filename>`` and return that path.

        The *binary* argument is accepted for interface compatibility but is
        ignored: ``bytes`` data is written in binary mode, anything else is
        converted with ``str()`` and written in text mode.

        The data is first written into a private scratch directory and then
        moved into place with ``os.replace``. The scratch directory is removed
        even when the write or the rename fails, so interrupted writes no
        longer leave ``tmp.pid_*`` directories behind in the cache.

        Raises:
            RuntimeError: if no cache directory is configured.
        """
        import shutil

        if not self.cache_dir:
            raise RuntimeError("Could not create or locate cache dir")
        binary = isinstance(data, bytes)
        if not binary:
            data = str(data)
        assert self.lock_path is not None
        filepath = self._make_path(filename)
        # PID + random UUID gives every call its own scratch dir, so concurrent
        # writers never share a temporary file; the PID shows who created it.
        rnd_id = str(uuid.uuid4())
        pid = os.getpid()
        temp_dir = os.path.join(self.cache_dir, f"tmp.pid_{pid}_{rnd_id}")
        os.makedirs(temp_dir, exist_ok=True)
        temp_path = os.path.join(temp_dir, filename)
        mode = "wb" if binary else "w"
        try:
            with open(temp_path, mode) as f:
                f.write(data)
            # os.replace is atomic on POSIX when it succeeds, so readers of
            # filepath never observe a partially written file.
            os.replace(temp_path, filepath)
        finally:
            shutil.rmtree(temp_dir, ignore_errors=True)
        return filepath
class RemoteCacheBackend(ABC):
    """
    A backend implementation for accessing a remote/distributed cache.

    Subclasses must implement :meth:`get` and :meth:`put`. Deriving from
    ``ABC`` makes the ``@abstractmethod`` markers actually enforced; without
    an ``ABCMeta`` metaclass the decorator is silently ignored.
    """

    def __init__(self, key: str):
        """Accept the cache *key*; the base class keeps no state."""

    @abstractmethod
    def get(self, filenames: List[str]) -> Dict[str, bytes]:
        """Fetch *filenames*, returning a filename -> data mapping for those found."""

    @abstractmethod
    def put(self, filename: str, data: bytes):
        """Store *data* remotely under *filename*."""
class RedisRemoteCacheBackend(RemoteCacheBackend):
    """Remote cache backend storing each artifact as a Redis string value.

    Host, port and the key template are read from ``knobs.cache.redis``.
    """

    def __init__(self, key):
        # Imported lazily so ``redis`` is only needed when this backend is used.
        import redis

        self._key = key
        # Template formatted with ``key`` and ``filename`` in ``_get_key``.
        self._key_fmt = knobs.cache.redis.key_format
        self._redis = redis.Redis(
            host=knobs.cache.redis.host,
            port=knobs.cache.redis.port,
        )

    def _get_key(self, filename: str) -> str:
        """Return the Redis key under which *filename* is stored."""
        return self._key_fmt.format(key=self._key, filename=filename)

    def get(self, filenames: List[str]) -> Dict[str, bytes]:
        """Fetch *filenames* in one round-trip; missing entries are omitted.

        The client is created without ``decode_responses``, so values come
        back as raw bytes.
        """
        if not filenames:
            # MGET requires at least one key; an empty lookup finds nothing.
            return {}
        results = self._redis.mget([self._get_key(f) for f in filenames])
        return {filename: result for filename, result in zip(filenames, results) if result is not None}

    def put(self, filename: str, data: bytes) -> None:
        """Store *data* under the Redis key derived from *filename*."""
        self._redis.set(self._get_key(filename), data)
class RemoteCacheManager(CacheManager):
    """Cache manager backed by a remote store and mirrored into a local file cache.

    The backend class comes from ``knobs.cache.remote_manager_class``. In
    dump or override mode every call is delegated to a local
    ``FileCacheManager``. Otherwise artifacts are read from / written to the
    remote backend and then materialized locally, so callers always receive
    a filesystem path.
    """

    def __init__(self, key, override=False, dump=False):
        remote_cache_cls = knobs.cache.remote_manager_class
        if not remote_cache_cls:
            raise RuntimeError(
                "Unable to instantiate RemoteCacheManager, TRITON_REMOTE_CACHE_BACKEND doesn't point to a valid class")
        self._backend = remote_cache_cls(key)
        self._override = override
        self._dump = dump
        # Local mirror: receives materialized copies of remote artifacts and
        # serves all requests directly in dump/override mode.
        self._file_cache_manager = FileCacheManager(key, override=override, dump=dump)

    def _materialize(self, filename: str, data: bytes):
        """Write remotely fetched *data* into the local cache; return its path."""
        return self._file_cache_manager.put(data, filename, binary=True)

    def get_file(self, filename: str) -> Optional[str]:
        """Return a local path for *filename*, fetching it remotely if needed."""
        if self._dump or self._override:
            return self._file_cache_manager.get_file(filename)
        results = self._backend.get([filename])
        if len(results) == 0:
            return None
        (_, data), = results.items()
        return self._materialize(filename, data)

    def put(self, data, filename: str, binary=True) -> str:
        """Upload *data* (non-bytes are ``str``-ified and UTF-8 encoded) and return its local path."""
        if self._dump or self._override:
            return self._file_cache_manager.put(data, filename, binary=binary)
        if not isinstance(data, bytes):
            data = str(data).encode("utf-8")
        self._backend.put(filename, data)
        return self._materialize(filename, data)

    def get_group(self, filename: str) -> Optional[Dict[str, str]]:
        """Return child name -> local path for the group *filename*, or ``None``.

        Children missing from the backend are omitted from the result.
        """
        if self._dump or self._override:
            return self._file_cache_manager.get_group(filename)
        grp_filename = f"__grp__{filename}"
        grp_filepath = self.get_file(grp_filename)
        if grp_filepath is None:
            return None
        with open(grp_filepath) as f:
            grp_data = json.load(f)
        # In remote mode ``child_paths`` is a list of child file *names*
        # (see put_group), not a mapping of paths.
        child_names = grp_data.get("child_paths", None)
        if child_names is None:
            return None
        if not child_names:
            # Empty group: skip the backend round-trip; some stores (e.g.
            # Redis MGET) reject an empty key list.
            return {}
        result = {}
        for child_name, data in self._backend.get(child_names).items():
            result[child_name] = self._materialize(child_name, data)
        return result

    def put_group(self, filename: str, group: Dict[str, str]):
        """Record the child names of *group* under *filename*; return the local group-file path."""
        if self._dump or self._override:
            return self._file_cache_manager.put_group(filename, group)
        # Only names are stored: local paths are meaningless on other hosts,
        # and get_group re-fetches each child from the backend by name.
        grp_contents = json.dumps({"child_paths": sorted(list(group.keys()))})
        grp_filename = f"__grp__{filename}"
        return self.put(grp_contents, grp_filename)
def _base32(key):
return base64.b32encode(bytes.fromhex(key)).decode("utf-8").rstrip("=")
def get_cache_manager(key) -> CacheManager:
    """Return the default-mode cache manager for the hex *key*.

    Uses ``knobs.cache.manager_class`` when set, else ``FileCacheManager``.
    """
    manager_cls = knobs.cache.manager_class
    if not manager_cls:
        manager_cls = FileCacheManager
    return manager_cls(_base32(key))
def get_override_manager(key) -> CacheManager:
    """Return an override-mode cache manager for the hex *key*.

    Uses ``knobs.cache.manager_class`` when set, else ``FileCacheManager``.
    """
    manager_cls = knobs.cache.manager_class
    if not manager_cls:
        manager_cls = FileCacheManager
    return manager_cls(_base32(key), override=True)
def get_dump_manager(key) -> CacheManager:
    """Return a dump-mode cache manager for the hex *key*.

    Uses ``knobs.cache.manager_class`` when set, else ``FileCacheManager``.
    """
    manager_cls = knobs.cache.manager_class
    if not manager_cls:
        manager_cls = FileCacheManager
    return manager_cls(_base32(key), dump=True)
def make_so_cache_key(version_hash, signature, constants, ids, **kwargs):
    """Build the cache key for a compiled launcher shared object.

    Argument types starting with ``*`` (pointers) are collapsed to ``"ptr"``,
    so the pointee type does not contribute to the key. The version hash, the
    concatenated argument types, *constants*, *ids* and then every extra
    keyword value (in call order) are joined with ``-``, hashed with SHA-256
    and returned as unpadded base32 text.
    """
    type_codes = []
    for arg_type in signature.values():
        type_codes.append("ptr" if arg_type[0] == "*" else arg_type)
    fields = [version_hash, "".join(type_codes), constants, ids, *kwargs.values()]
    raw_key = "-".join(f"{field}" for field in fields)
    digest = hashlib.sha256(raw_key.encode("utf-8")).hexdigest()
    return _base32(digest)
@functools.lru_cache()
def triton_key():
    """Return a string fingerprinting this Triton installation for cache keys.

    The result is ``__version__`` immediately followed by ``-``-joined SHA-256
    hex digests of, in order: this source file, every module found under
    ``compiler/`` and ``backends/``, the ``_C/libtriton`` extension library,
    and every module found under ``language/``. Memoized by ``lru_cache``, so
    it is computed once per process.
    """
    import pkgutil

    triton_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    def sha256_of_file(path):
        with open(path, "rb") as fh:
            return hashlib.sha256(fh.read()).hexdigest()

    def sha256_of_package(pkg_dir, pkg_prefix):
        return [
            sha256_of_file(info.module_finder.find_spec(info.name).origin)
            for info in pkgutil.walk_packages([pkg_dir], prefix=pkg_prefix)
        ]

    digests = [sha256_of_file(__file__)]
    digests.extend(sha256_of_package(os.path.join(triton_root, "compiler"), "triton.compiler."))
    digests.extend(sha256_of_package(os.path.join(triton_root, "backends"), "triton.backends."))

    # Hash the native library incrementally in 1 MiB chunks rather than
    # reading it into memory in one go.
    ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1]
    lib_hasher = hashlib.sha256()
    with open(os.path.join(triton_root, "_C", f"libtriton.{ext}"), "rb") as fh:
        for block in iter(functools.partial(fh.read, 1024**2), b""):
            lib_hasher.update(block)
    digests.append(lib_hasher.hexdigest())

    digests.extend(sha256_of_package(os.path.join(triton_root, "language"), "triton.language."))
    return f'{__version__}' + '-'.join(digests)
def get_cache_key(src, backend, backend_options, env_vars):
    """Combine the Triton fingerprint with source, backend, options and env vars.

    Components are joined with ``-``; *env_vars* contributes its items in
    sorted order so the key does not depend on dict insertion order.
    """
    key_parts = [
        triton_key(),
        src.hash(),
        backend.hash(),
        backend_options.hash(),
        str(sorted(env_vars.items())),
    ]
    return "-".join(f"{part}" for part in key_parts)