"""Simplified dispatch system for Python, based on classes' typeclass implementation.
This module provides a dispatch-based polymorphism system allowing extensible
behavior for different types using the `impl` decorator.
"""
__all__ = ["dispatch"]
from functools import _find_impl
from typing import Any, Callable, Dict, Optional, TypeVar
_SignatureType = TypeVar("_SignatureType", bound=Callable)
class _Dispatch:
    """Internal dispatch representation with type-based routing logic.

    Implementations live in ``_exact_types``, keyed by exactly what was passed
    to :meth:`impl`: a type, a string name, or a tuple of types/names.
    Resolved lookups for non-tuple arguments are memoized in
    ``_dispatch_cache``; registering a new implementation clears that cache.
    """

    __slots__ = (
        "_signature",
        "_name",
        "_exact_types",
        "_dispatch_cache",
        "_doc",
        "_module",
    )

    def __init__(self, signature: Callable) -> None:
        # The signature function only supplies metadata; it is never called.
        self._signature = signature
        self._name = signature.__name__
        # Registration key (type, string, or tuple of those) -> implementation.
        self._exact_types: Dict[Any, Callable] = {}
        # Lookup key (a type, or a string value) -> resolved implementation.
        self._dispatch_cache: Dict[Any, Callable] = {}
        self._doc = signature.__doc__
        self._module = signature.__module__

    def __call__(self, instance: Any, *args, **kwargs) -> Any:
        """Dispatch to the implementation registered for ``instance``.

        How the lookup key is chosen:

        * a tuple is matched against tuple keys first (see
          :meth:`_dispatch_tuple`); if none matches, it is dispatched on its
          own type like any other object, so ``tuple``, NamedTuple classes or
          base-class registrations (e.g. ``object``) are reachable;
        * a class is dispatched on the class itself;
        * a string is dispatched on its value, falling back to ``str``;
        * anything else is dispatched on ``type(instance)``.

        The chosen implementation is called as ``impl(instance, *args, **kwargs)``.

        Raises:
            NotImplementedError: if no registered implementation matches.
        """
        if isinstance(instance, tuple):
            # Tuple lookups are not memoized.
            impl = self._dispatch_tuple(instance)
            if impl is None:
                impl = self._dispatch(instance, type(instance))
            if impl is None:
                raise NotImplementedError(self._format_no_implementation_error(instance))
            return impl(instance, *args, **kwargs)

        if isinstance(instance, type):
            cache_key = instance_type = instance
        elif isinstance(instance, str):
            # Cache by value so string-name registrations resolve per string.
            cache_key, instance_type = instance, str
        else:
            cache_key = instance_type = type(instance)

        impl = self._dispatch_cache.get(cache_key)
        if impl is None:
            impl = self._dispatch(instance, instance_type)
            if impl is None:
                raise NotImplementedError(self._format_no_implementation_error(instance))
            self._dispatch_cache[cache_key] = impl
        return impl(instance, *args, **kwargs)

    def impl(self, *target_types: Any) -> Callable[[Callable], Callable]:
        """Register an implementation for one or more keys.

        Every positional argument becomes its own key: a type, a string name
        (compared with a dispatched string or a class ``__name__``), or a
        tuple of those (matched element-wise against a tuple argument).
        The decorated function is returned unchanged.

        Usage:
            @mydispatch.impl(int)          # Register for a single type
            @mydispatch.impl(int, str)     # Register for multiple types
            @mydispatch.impl((list, str))  # Register for a tuple of types as a key

        Raises:
            ValueError: if called without any target.
        """
        if not target_types:
            raise ValueError(
                "\n✗ Missing argument to .impl()\n\n"
                "You must specify at least one target type.\n\n"
                "Examples:\n"
                f" @{self._name}.impl(str) # Single type\n"
                f" @{self._name}.impl(int, float) # Multiple types\n"
                f" @{self._name}.impl((list, str)) # Tuple key\n"
            )

        def decorator(func: Callable) -> Callable:
            for target in target_types:
                self._exact_types[target] = func
            # Earlier resolutions may now be shadowed by the new registration.
            self._dispatch_cache.clear()
            return func

        return decorator

    def __repr__(self) -> str:
        """Multi-line representation listing every registered implementation."""
        import inspect

        try:
            signature = str(inspect.signature(self._signature))
        except (TypeError, ValueError):
            # inspect cannot produce a signature for some callables.
            signature = "(...)"
        lines = [f"Dispatch({self._name}{signature})("]
        for key, impl in self._exact_types.items():
            # Not every callable has __name__ (e.g. functools.partial).
            impl_name = getattr(impl, "__name__", None) or repr(impl)
            lines.append(f" ({self._format_key(key)}): {impl_name} at {self._format_location(impl)}")
        lines.append(")")
        return "\n".join(lines)

    def _dispatch_tuple(self, instance: tuple) -> Optional[Callable]:
        """Resolve a tuple ``instance`` against tuple registration keys.

        Tried in order:
        1) exact key match, where types and strings in ``instance`` are kept
           as-is and every other element is replaced by its type;
        2) the first registered all-type tuple key whose elements are
           element-wise superclasses (strings count as ``str``); the first
           match in registration order wins, not the most specific one;
        3) the first registered tuple key whose element names equal the
           instance's element names (for dynamically generated classes).

        Returns None when no tuple key matches.
        """
        key = tuple(v if isinstance(v, (type, str)) else type(v) for v in instance)
        impl = self._exact_types.get(key)
        if impl is not None:
            return impl

        same_length_keys = [
            (registered_key, callback)
            for registered_key, callback in self._exact_types.items()
            if isinstance(registered_key, tuple) and len(registered_key) == len(key)
        ]

        key_types = tuple(v if isinstance(v, type) else type(v) for v in instance)
        for registered_key, callback in same_length_keys:
            if not all(isinstance(t, type) for t in registered_key):
                continue
            if all(self._safe_issubclass(k, rk) for k, rk in zip(key_types, registered_key)):
                return callback

        key_names = tuple(self._key_name(v) or str(v) for v in key)
        for registered_key, callback in same_length_keys:
            if tuple(self._key_name(rk) or str(rk) for rk in registered_key) == key_names:
                return callback
        return None

    def _dispatch(self, instance: Any, instance_type: type) -> Optional[Callable]:
        """Find the implementation for a given (non-tuple-key) lookup.

        Fallback order:
        0) exact string-key match, when ``instance`` is a string
        1) exact type match on ``instance_type``
        2) MRO-based match via ``functools._find_impl``: the most specific
           registered base wins, for classes and instances alike
        3) for classes only, an issubclass scan in registration order, used
           when ``_find_impl`` cannot decide (e.g. ambiguous ABCs)
        4) name-based fallback: match by class ``__name__`` (or the string
           value) for dynamically generated classes (e.g., HF transformers
           auto_map dynamic modules)

        Returns None when nothing matches.
        """
        if isinstance(instance, str):
            impl = self._exact_types.get(instance)
            if impl is not None:
                return impl

        impl = self._exact_types.get(instance_type)
        if impl is not None:
            return impl

        impl = self._find_type_impl(instance, instance_type)
        if impl is not None:
            return impl

        if isinstance(instance, (str, type)):
            inst_name = self._key_name(instance)
        else:
            inst_name = self._key_name(type(instance))
        if inst_name:
            for registered_key, callback in self._exact_types.items():
                reg_name = self._key_name(registered_key)
                if reg_name and str(reg_name) == inst_name:
                    return callback
        return None

    def _find_type_impl(self, instance: Any, instance_type: type) -> Optional[Callable]:
        """Resolve against registered *types* only, preferring the closest MRO base."""
        type_impls = {k: v for k, v in self._exact_types.items() if isinstance(k, type)}
        if not type_impls:
            return None
        is_class = isinstance(instance, type)
        try:
            impl = _find_impl(instance_type, type_impls)
        except (RuntimeError, TypeError):
            # _find_impl refuses ambiguous ABC matches. Classes fall back to
            # the registration-order issubclass scan below; ordinary instances
            # let the error propagate.
            if not is_class:
                raise
            impl = None
        if impl is None and is_class:
            for registered_type, callback in type_impls.items():
                if self._safe_issubclass(instance, registered_type):
                    return callback
        return impl

    @staticmethod
    def _key_name(obj: Any) -> Any:
        """Return ``obj`` itself for strings, otherwise its ``__name__`` (or None)."""
        if isinstance(obj, str):
            return obj
        return getattr(obj, "__name__", None)

    @staticmethod
    def _safe_issubclass(cls: Any, candidate: Any) -> bool:
        """``issubclass`` that treats a TypeError as "not a subclass"."""
        try:
            return issubclass(cls, candidate)
        except TypeError:
            return False

    @staticmethod
    def _format_key(key: Any) -> str:
        """Human-readable form of a registration key: ``int`` or ``(list, str)``."""

        def _display(obj: Any) -> str:
            return obj.__name__ if hasattr(obj, "__name__") else str(obj)

        if isinstance(key, tuple):
            return f"({', '.join(_display(t) for t in key)})"
        return _display(key)

    def _format_location(self, func: Callable) -> str:
        """Format the ``file:line`` location of a function for display."""
        try:
            import inspect
            import os

            filename = inspect.getfile(func)
            _, lineno = inspect.getsourcelines(func)
            filename = os.path.relpath(filename)
            return f"{filename}:{lineno}"
        except Exception:
            # Deliberately best-effort: this feeds repr() and error messages,
            # which must not fail because a source location is unavailable.
            return "<unknown location>"

    def _format_no_implementation_error(self, instance: Any) -> str:
        """Build the NotImplementedError message for an unmatched ``instance``.

        The message names the unmatched type(s), lists the registered keys,
        and shows an ``@<name>.impl(...)`` registration stub to copy.
        """
        if isinstance(instance, tuple):
            element_types = tuple(v if isinstance(v, type) else type(v) for v in instance)
            qualnames = ", ".join(
                t.__qualname__ if hasattr(t, "__qualname__") else str(t) for t in element_types
            )
            header = f"tuple of types ({qualnames})"
            suggestion = self._format_key(element_types)
            func_suffix = "tuple"
            type_hint = f"Tuple[{', '.join(t.__name__ for t in element_types)}]"
        else:
            target = instance if isinstance(instance, type) else type(instance)
            qualname = target.__qualname__ if hasattr(target, "__qualname__") else str(target)
            header = f"type '{qualname}'"
            suggestion = self._format_key(target)
            func_suffix = suggestion.lower().replace(".", "_")
            type_hint = suggestion

        lines = [
            f"\n✗ No implementation found for {header}",
            "",
            f"The dispatch function '{self._name}' has no implementation for this type.",
            "",
        ]
        if self._exact_types:
            lines.append("Available implementations:")
            lines.extend(f" • {self._format_key(key)}" for key in sorted(self._exact_types, key=str))
            intro = "To add support for this type, register an implementation:"
            extra_params = ""
        else:
            lines.append("No implementations registered yet.")
            intro = "To add support for this type:"
            extra_params = ", ..."
        lines.extend(
            [
                "",
                intro,
                f" @{self._name}.impl({suggestion})",
                f" def _{self._name}_{func_suffix}(instance: {type_hint}{extra_params}) -> ...:",
                " # Your implementation here",
            ]
        )
        return "\n".join(lines)
def dispatch(func: _SignatureType) -> _Dispatch:
    """Turn ``func`` into an extensible, type-dispatched callable.

    Only the metadata of ``func`` (name, docstring, module and signature) is
    used; its body is never executed. Attach implementations with
    ``@<dispatcher>.impl(SomeType)``.
    """
    dispatcher = _Dispatch(func)
    return dispatcher