import functools
import contextlib
from typing import TypeVar, Union, Iterator, Callable, Dict
from codegen.utils import S, T, context
from codegen.model import (NativeFunction, NativeFunctionsGroup, BackendIndex, DispatchKey)
import codegen.local as local
# Constrained type variable used by the decorators in this module: the wrapped
# callable's first "function" argument may be a NativeFunction, a
# NativeFunctionsGroup, or the union of the two.
F = TypeVar(
    'F',
    NativeFunction,
    NativeFunctionsGroup,
    Union[NativeFunction, NativeFunctionsGroup],
)
@contextlib.contextmanager
def native_function_manager(g: Union[NativeFunctionsGroup, NativeFunction]) -> Iterator[None]:
    """Run the enclosed block with per-function context and local settings.

    Args:
        g: Either a single ``NativeFunction`` or a ``NativeFunctionsGroup``.
            For a group, its ``out`` member is used as the representative
            function.

    Yields:
        ``None``. While the body runs, two managers are active:

        * ``context(...)``, given a callable that builds a message naming
          the function's ``func`` (presumably used to annotate errors --
          confirm against ``codegen.utils.context``).
        * ``local.parametrize(...)``, configured from the function's
          ``use_const_ref_for_mutable_tensors`` and
          ``part_of_structured_group`` attributes.
    """
    # A group is represented by its ``out`` member; a plain function by itself.
    f = g.out if isinstance(g, NativeFunctionsGroup) else g

    def describe() -> str:
        # Built lazily: only evaluated if ``context`` decides to call it.
        return f'in native_functions.yaml func:\n {f.func}'

    # ``context`` is entered first so that it also covers the evaluation of
    # the ``parametrize`` arguments and the entry of ``parametrize`` itself.
    with context(describe), local.parametrize(
        param_use_const_ref_for_mutable_tensors=f.use_const_ref_for_mutable_tensors,
        param_use_ilistref_for_tensor_lists=f.part_of_structured_group,
    ):
        yield
def with_native_function(func: Callable[[F], T]) -> Callable[[F], T]:
    """Decorate a one-argument function to run under ``native_function_manager``.

    Args:
        func: Callable taking a single native function (or group) ``f``.

    Returns:
        A callable with the same signature and metadata as ``func`` that
        enters ``native_function_manager(f)`` before calling ``func(f)`` and
        returns its result.
    """
    def managed(f: F) -> T:
        with native_function_manager(f):
            result = func(f)
        return result

    # Copy name, docstring, etc. from ``func`` onto the wrapper.
    return functools.update_wrapper(managed, func)
def method_with_native_function(func: Callable[[S, F], T]) -> Callable[[S, F], T]:
    """Method flavour of ``with_native_function``.

    Args:
        func: Callable taking a receiver ``slf`` followed by a native
            function (or group) ``f``.

    Returns:
        A callable with the same signature and metadata as ``func`` that
        enters ``native_function_manager(f)`` before calling
        ``func(slf, f)`` and returns its result.
    """
    def managed(slf: S, f: F) -> T:
        with native_function_manager(f):
            result = func(slf, f)
        return result

    # Copy name, docstring, etc. from ``func`` onto the wrapper.
    return functools.update_wrapper(managed, func)
def with_native_function_and_index(func: Callable[[F, BackendIndex], T]) -> Callable[[F, BackendIndex], T]:
    """Decorate a ``(f, backend_index)`` function to run under ``native_function_manager``.

    Args:
        func: Callable taking a native function (or group) ``f`` and a
            ``BackendIndex``.

    Returns:
        A callable with the same signature and metadata as ``func`` that
        enters ``native_function_manager(f)`` before calling
        ``func(f, backend_index)`` and returns its result. The index is
        passed through untouched.
    """
    def managed(f: F, backend_index: BackendIndex) -> T:
        with native_function_manager(f):
            result = func(f, backend_index)
        return result

    # Copy name, docstring, etc. from ``func`` onto the wrapper.
    return functools.update_wrapper(managed, func)
def with_native_function_and_indices(
    func: Callable[[F, Dict[DispatchKey, BackendIndex]], T]
) -> Callable[[F, Dict[DispatchKey, BackendIndex]], T]:
    """Decorate a ``(f, backend_indices)`` function to run under ``native_function_manager``.

    Args:
        func: Callable taking a native function (or group) ``f`` and a
            mapping from ``DispatchKey`` to ``BackendIndex``.

    Returns:
        A callable with the same signature and metadata as ``func`` that
        enters ``native_function_manager(f)`` before calling
        ``func(f, backend_indices)`` and returns its result. The mapping is
        passed through untouched.
    """
    def managed(f: F, backend_indices: Dict[DispatchKey, BackendIndex]) -> T:
        with native_function_manager(f):
            result = func(f, backend_indices)
        return result

    # Copy name, docstring, etc. from ``func`` onto the wrapper.
    return functools.update_wrapper(managed, func)