from . import patterns # noqa: F401
from .compile_backend import CompilerBackend
# Module-level cache mapping device_name -> CompilerBackend instance.
# Filled lazily by get_backend(); None is a valid key (the default
# device_name), so one shared backend exists for the "unspecified" case.
_backend_by_device = {}
def get_backend(*, device_name=None):
    """
    Get the compilation backend for 'torch.compile'.

    One ``CompilerBackend`` is constructed per distinct ``device_name`` and
    memoized in the module-level ``_backend_by_device`` dict. Later calls with
    the same ``device_name`` return that same instance. No locking is
    performed around the lookup-then-insert sequence.

    Args:
        device_name: Keyword-only. Forwarded to ``CompilerBackend`` and used as
            the cache key. ``None`` (the default) is itself a valid key.

    Returns:
        Callable: The compilation backend function.
    """
    cached = _backend_by_device.get(device_name)
    if cached is not None:
        # Fast path: a backend for this device was already built.
        return cached

    fresh = CompilerBackend(device_name=device_name)
    _backend_by_device[device_name] = fresh
    return fresh