from __future__ import annotations, division
import ast
import hashlib
import inspect
import itertools
import os
import re
import textwrap
from collections import defaultdict
from functools import cached_property
from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, overload, Dict, Any, Tuple
from types import ModuleType
from triton._C.libtriton import get_cache_invalidating_env_vars
from .driver import driver
from . import _async_compile
# Dotted name of the Triton root package, obtained by chopping the ".runtime.jit" suffix
# off this module's own name (assumes this file is imported as <root>.runtime.jit).
TRITON_MODULE = __name__[:-len(".runtime.jit")]
# Type variable standing for the wrapped kernel function; parameterises KernelInterface/JITFunction.
T = TypeVar("T")
class DependenciesFinder(ast.NodeVisitor):
    """
    AST visitor that collects the dependencies of a JITFunction.

    The visitor feeds the function's source, plus the ``cache_key`` of every
    dependency it resolves, into a SHA-256 hasher. The resulting digest (``ret``)
    therefore changes when the source of the function -- or of its dependencies --
    changes, which invalidates the JITFunction's hash.

    It also records the "interesting" global variables the function reads, together
    with the globals dict they came from (``used_global_vals``). When the kernel is
    launched, these are compared against their current values, and an error is raised
    if any of them changed (otherwise we could recompile).
    """

    def __init__(self, name, globals, src) -> None:
        # name: name of the function being analysed (used in error messages).
        # globals: the function's module globals, used to resolve free names.
        # src: the function's source text; it seeds the hash.
        super().__init__()
        self.name = name
        self.hasher = hashlib.sha256(src.encode("utf-8"))
        self.globals = globals
        # Builtins that kernels may use freely; reads of these names are not tracked as globals.
        self.supported_python_builtins = {
            'float',
            'getattr',
            'int',
            'isinstance',
            'len',
            'list',
            'max',
            'min',
            'print',
            'range',
        }
        # (variable name, id(globals dict)) -> (value seen at analysis time, the globals dict).
        self.used_global_vals: Dict[Tuple[str, int], Tuple[Any, Dict[str, Any]]] = {}
        # True while visiting parameter default expressions (see visit_arguments).
        self.visiting_arg_default_value = False

    @property
    def ret(self):
        """Hex digest of everything hashed so far."""
        return self.hasher.hexdigest()

    def _is_triton_builtin(self, node, func):
        # NOTE(review): not referenced anywhere in the visible code. ``node.func`` looks like an
        # AST node, so ``inspect.isbuiltin(node.func)`` is presumably always False -- confirm intent.
        if inspect.isbuiltin(node.func):
            return True
        module = getattr(func, "__module__", "")
        return module.startswith(TRITON_MODULE)

    def _update_hash(self, func):
        """If *func* is a JITFunction, merge its tracked globals and fold its cache key into the hash.

        Raises RuntimeError if the same global (same name and same globals dict) was recorded
        with different values by this function and by *func*.
        """
        if isinstance(func, JITFunction):
            # Merge our used global variables with the callee's, refusing conflicting values.
            for k in self.used_global_vals.keys() & func.used_global_vals.keys():
                var_name, _ = k
                v1, _ = self.used_global_vals[k]
                v2, _ = func.used_global_vals[k]
                if v1 != v2:
                    raise RuntimeError(
                        f"Global variable {var_name} has value {v1} when compiling {self.name}, but inner kernel {func.__name__} has conflicting value {v2} from when it was first compiled. This is not allowed."
                    )
            self.used_global_vals.update(func.used_global_vals)
            # The noinline flag is part of the key so toggling it changes the hash.
            func_key = func.cache_key
            func_key += str(getattr(func, "noinline", False))
            self.hasher.update(func_key.encode("utf-8"))

    def visit_Name(self, node):
        """Resolve a name: store targets yield the identifier, loads yield the global value (or None)."""
        if type(node.ctx) is ast.Store:
            return node.id
        if node.id in self.local_names:
            # A local variable hides any global of the same name.
            return None
        val = self.globals.get(node.id, None)
        # Only track "interesting" globals that users might plausibly rebind:
        # - default-argument values are evaluated once at definition time, so they are skipped;
        # - modules, JITFunctions and objects flagged ``__triton_builtin__`` are skipped;
        # - the supported Python builtins are skipped.
        if (val is not None
                and not self.visiting_arg_default_value
                and type(val) is not ModuleType
                and not isinstance(val, JITFunction) and not getattr(val, "__triton_builtin__", False)
                and node.id not in self.supported_python_builtins):
            self.used_global_vals[(node.id, id(self.globals))] = (val, self.globals)
        self._update_hash(val)
        return val

    def visit_Tuple(self, node):
        # Visit each element; for assignment targets this yields the list of assigned names.
        return [self.visit(elt) for elt in node.elts]

    def visit_Attribute(self, node):
        """Resolve ``lhs.attr`` to a Python object, hashing it if it is a JITFunction."""
        lhs = self.visit(node.value)
        # NOTE(review): visit() returns resolved Python values, not AST nodes, so this loop
        # appears to run only if a resolved value is itself an ast.Attribute -- confirm intent.
        while isinstance(lhs, ast.Attribute):
            lhs = self.visit(lhs.value)
        # Unresolvable bases and attributes of the top-level triton package are not followed.
        if lhs is None or (getattr(lhs, "__name__", "") == TRITON_MODULE):
            return None
        ret = getattr(lhs, node.attr)
        self._update_hash(ret)
        return ret

    def visit_FunctionDef(self, node):
        # Positional parameters start out as the only known local names.
        self.local_names = {arg.arg for arg in node.args.args}
        self.generic_visit(node)

    def visit_arguments(self, node):
        # The ast.NodeVisitor docs say that if we implement visit_arguments, we need to visit
        # all the children ourselves.

        def visit_defaults(defaults):
            # Default expressions are visited with visiting_arg_default_value set, so the
            # globals they read are hashed but not recorded as used globals.
            try:
                assert not self.visiting_arg_default_value
                self.visiting_arg_default_value = True
                for expr in defaults:
                    if expr is not None:
                        self.visit(expr)
            finally:
                self.visiting_arg_default_value = False

        for arg in itertools.chain(node.posonlyargs, node.args, [node.vararg] if node.vararg else [], node.kwonlyargs):
            self.visit(arg)
        visit_defaults(node.kw_defaults)
        if node.kwarg is not None:
            self.visit(node.kwarg)
        visit_defaults(node.defaults)

    def visitAssnTarget(self, node):
        """Register the name(s) bound by an assignment target as locals."""
        # A Name target yields its identifier and a Tuple target yields a list of them.
        # Other targets contribute whatever visit() returns (possibly None).
        target = self.visit(node)
        if isinstance(target, list):
            self.local_names |= set(target)
        else:
            self.local_names.add(target)

    def visit_Assign(self, node):
        if len(node.targets) != 1:
            # TODO(jlebar): I don't actually know how to hit this. You don't get it from
            # `a, b = ...` -- in that case, node.targets is a single Tuple.
            raise TypeError("Simultaneous multiple assignment is not supported.")
        self.visitAssnTarget(node.targets[0])
        # Then visit the right-hand side (and the target again, now that it is local).
        self.generic_visit(node)

    def visit_AnnAssign(self, node):
        self.visitAssnTarget(node.target)
        self.generic_visit(node)

    def visit_For(self, node):
        # The loop variable(s) become locals before the body is visited.
        self.visitAssnTarget(node.target)
        self.generic_visit(node)
def _normalize_ty(ty) -> str:
    """Turn an annotation into a string.

    Strings are returned unchanged and classes become their ``__name__``.
    Anything else falls back to ``repr``.
    """
    if isinstance(ty, str):
        return ty
    return ty.__name__ if isinstance(ty, type) else repr(ty)
class KernelParam:
    """Represents a parameter (name plus metadata) to a @jit'ed function."""

    def __init__(self, num: int, param: inspect.Parameter, do_not_specialize: bool,
                 do_not_specialize_on_alignment: bool):
        # num: position of the parameter in the kernel signature.
        # param: the underlying inspect.Parameter.
        # do_not_specialize / do_not_specialize_on_alignment: opt-outs from value specialization.
        self.num = num
        self._param = param
        self.do_not_specialize = do_not_specialize
        self.do_not_specialize_on_alignment = do_not_specialize_on_alignment

    @cached_property
    def name(self):
        return self._param.name

    @cached_property
    def annotation(self):
        """The annotation as a string, or ``""`` when the parameter is unannotated."""
        # ``inspect.Parameter.empty`` is a sentinel class: test identity, not equality, so
        # annotation objects with an unusual ``__eq__`` cannot confuse the check.
        if not self._param.annotation or self._param.annotation is inspect.Parameter.empty:
            return ""
        return _normalize_ty(self._param.annotation)

    @cached_property
    def annotation_type(self):
        """Scalar type string derived from the annotation, or ``""``.

        An annotation containing "uint"/"int" maps to "u"/"i" followed by whatever text
        comes after that substring (e.g. "tl.int32" -> "i32"). "bool" maps to "u1".
        """
        # NOTE(review): this is plain substring matching, so an annotation that merely
        # contains "int" (e.g. "pointer") would yield a non-numeric width -- confirm intended.
        annotation = self.annotation
        for ty1, ty2 in [("uint", 'u'), ("int", 'i')]:
            width = annotation[annotation.find(ty1) + len(ty1):]
            if width and ty1 in annotation:
                return f"{ty2}{width}"
        if annotation == "bool":
            return "u1"
        return ""

    @cached_property
    def is_constexpr(self):
        return "constexpr" in self.annotation

    @cached_property
    def is_const(self):
        # "const" also matches "constexpr", which is excluded explicitly.
        return "const" in self.annotation and not self.is_constexpr

    @property
    def default(self):
        return self._param.default

    @property
    def has_default(self):
        """True when the parameter declares a default value.

        Uses an identity check against the ``inspect.Parameter.empty`` sentinel, so default
        values that overload ``!=`` (e.g. arrays) still produce a plain bool.
        """
        return self._param.default is not inspect.Parameter.empty
def compute_spec_key(v, align):
    """Return the specialization class of a kernel argument.

    * ``"D"``: divisible by 16. This covers pointer-like objects whose ``data_ptr()``
      is 16-aligned and ints that are multiples of 16. Both only count when *align* is true.
    * ``"1"``: an int equal to 1.
    * ``"N"``: anything else.
    """
    if align and hasattr(v, "data_ptr") and v.data_ptr() % 16 == 0:
        return "D"
    # bool is a subclass of int, so booleans are classified as ints here.
    if not isinstance(v, int):
        return "N"
    if align and v % 16 == 0:
        return "D"
    return "1" if v == 1 else "N"
# Memo of (dtype, is_const) -> mangled pointer type string, filled lazily by mangle_type.
dtype2str = {}


def mangle_type(arg, is_const=False):
    """Return the signature type string for a runtime kernel argument.

    * ``None`` gives ``"none"`` and bools give ``"i1"``.
    * Ints give ``"i32"``, ``"u64"`` or ``"i64"``, depending on the range they fit in.
    * Floats give ``"fp32"``.
    * Objects exposing ``tma_desc_cpu_ptr`` give ``"nvTmaDesc"``.
    * Anything else is treated as a pointer to ``arg.dtype``, giving ``"*<dtype>"``,
      or ``"*k<dtype>"`` when *is_const*.
    """
    if arg is None:
        return "none"
    # bool must be tested before int because bool is a subclass of int.
    if isinstance(arg, bool):
        return "i1"
    if isinstance(arg, int):
        if -(2**31) <= arg <= 2**31 - 1:
            return "i32"
        if 2**63 <= arg <= 2**64 - 1:
            return "u64"
        return "i64"
    if isinstance(arg, float):
        return "fp32"
    if hasattr(arg, "tma_desc_cpu_ptr"):
        return "nvTmaDesc"
    dtype = arg.dtype
    memo_key = (dtype, is_const)
    mangled = dtype2str.get(memo_key, None)
    if mangled is None:
        prefix = "*k" if is_const else "*"
        # e.g. "torch.float32" -> "float32" -> "fp32"
        mangled = prefix + type_canonicalisation_dict[str(dtype).split('.')[-1]]
        dtype2str[memo_key] = mangled
    return mangled
class KernelInterface(Generic[T]):
    """Mixin giving kernels the ``kernel[grid](*args, **kwargs)`` launch syntax.

    Subclasses provide ``run(*args, grid=..., warmup=..., **kwargs)``.
    """

    run: T

    def __getitem__(self, grid) -> T:
        """Return a callable that launches ``self.run`` with *grid* bound and ``warmup=False``.

        Indexing only captures the grid; the launch happens when the returned
        proxy is called.
        """
        return lambda *args, **kwargs: self.run(*args, grid=grid, warmup=False, **kwargs)
def serialize_specialization_data(name, signature, constants, attrs, options, key):
    """Encode everything needed to recompile one kernel specialization as a JSON string.

    Values in *constants* whose class is named ``dtype`` are stored as ``str(value)``.
    *attrs* contributes ``attrs.to_dict()`` and *options* contributes its ``__dict__``.
    The output is what ``JITFunction.preload`` parses back with ``json.loads``.
    """
    import json

    def jsonable(value):
        # dtype objects are not JSON-serialisable, so they are stored by their string form.
        return str(value) if value.__class__.__name__ == "dtype" else value

    payload = {
        'name': name,
        'signature': signature,
        'constants': {const_name: jsonable(const_val) for const_name, const_val in constants.items()},
        'attrs': attrs.to_dict(),
        'options': options.__dict__,
        'key': key,
    }
    return json.dumps(payload)
def create_function_from_signature(sig, kparams, backend):
    """
    Build a fast, memoisable replacement for ``sig.bind(...)`` + ``apply_defaults()``.

    The returned function is generated with ``exec``. It accepts the kernel's arguments
    and returns a 5-tuple:

    * a dict mapping every parameter name to its bound value;
    * a tuple of type strings for the non-constexpr parameters, followed by their
      specialization keys;
    * a tuple of the constexpr argument values;
    * a tuple of the non-constexpr argument values;
    * a dict of any extra keyword arguments.

    Generating it once per kernel avoids running the generic binding machinery, which
    makes up much of the launch overhead, on every launch.
    """
    assert len(sig.parameters) == len(kparams)

    param_decls = []
    binding_entries = []
    constexpr_names = []
    runtime_names = []
    type_exprs = []
    spec_exprs = []

    for (name, param), kparam in zip(sig.parameters.items(), kparams):
        # Parameters with a Python default read it from a namespace global named default_<name>.
        if param.default is inspect.Parameter.empty:
            param_decls.append(name)
        else:
            param_decls.append(f"{name}=default_{name}")
        binding_entries.append(f"'{name}': {name}")

        if kparam.is_constexpr:
            constexpr_names.append(name)
            continue

        runtime_names.append(name)
        if not kparam.do_not_specialize:
            align_flag = 'False' if kparam.do_not_specialize_on_alignment else 'True'
            spec_exprs.append('compute_spec_key(%s, align=%s)' % (name, align_flag))
        # A recognised annotation fixes the type statically; otherwise it is derived from the value.
        if kparam.annotation_type:
            type_exprs.append('"%s"' % kparam.annotation_type)
        else:
            type_exprs.append('mangle_type(%s, %s)' % (name, 'True' if kparam.is_const else 'False'))

    def as_tuple_items(exprs):
        # Every item is followed by ", " so one element still forms a tuple and none gives "()".
        return ''.join(expr + ', ' for expr in exprs)

    param_decls.append('**excess_kwargs')
    func_body = "def dynamic_func(%s):\n return {%s}, (%s), (%s), (%s), excess_kwargs" % (
        ', '.join(param_decls),
        ', '.join(binding_entries),
        as_tuple_items(type_exprs + spec_exprs),
        as_tuple_items(constexpr_names),
        as_tuple_items(runtime_names),
    )

    namespace = {
        f"default_{name}": param.default
        for name, param in sig.parameters.items()
        if param.default is not inspect.Parameter.empty
    }
    namespace['mangle_type'] = mangle_type
    namespace['compute_spec_key'] = backend.compute_spec_key
    exec(func_body, namespace)
    return namespace['dynamic_func']
# Maps dtype names (e.g. the part after the last "." of a framework dtype's str()) to the
# short type strings used in kernel signatures.
type_canonicalisation_dict = {
    "bool": "i1",
    "float8e4nv": "fp8e4nv",
    "float8e5": "fp8e5",
    "float8e4b15": "fp8e4b15",
    "float8_e4m3fn": "fp8e4nv",
    "float8e4b8": "fp8e4b8",
    "float8_e4m3fnuz": "fp8e4b8",
    "float8_e5m2": "fp8e5",
    "float8e5b16": "fp8e5b16",
    "float8_e5m2fnuz": "fp8e5b16",
    "float16": "fp16",
    "bfloat16": "bf16",
    "float32": "fp32",
    "float64": "fp64",
    "int8": "i8",
    "int16": "i16",
    "int32": "i32",
    "int64": "i64",
    "uint8": "u8",
    "uint16": "u16",
    "uint32": "u32",
    "uint64": "u64",
}
# Canonical names also map to themselves, so already-canonical strings pass through lookups.
# The comprehension is evaluated before update() runs, so the dict is not mutated while it is
# being iterated, and no loop variable leaks into the module namespace.
type_canonicalisation_dict.update({canon: canon for canon in type_canonicalisation_dict.values()})
class JITFunction(KernelInterface[T]):
    """A Python function wrapped by ``@triton.jit``.

    It keeps the function's source and per-parameter metadata, and lazily builds a fast
    argument binder on first launch. Compiled kernels are cached per device under a
    string key derived from argument types/specializations, constexpr values and extra
    kwargs. Launching is done via ``fn[grid](*args, **kwargs)``.
    """

    # Class-wide optional hooks. ``cache_hook`` is called before compiling; if it returns a
    # truthy value, compilation is skipped. ``compiled_hook`` is called after compiling.
    cache_hook = None
    compiled_hook = None

    @staticmethod
    def _key_of(arg):
        # Objects with a ``dtype`` are keyed by it; scalars by their value range; None by None.
        if hasattr(arg, "dtype"):
            return arg.dtype
        elif isinstance(arg, bool):
            return "i1"
        elif isinstance(arg, int):
            if -(2**31) <= arg and arg <= 2**31 - 1:
                return "i32"
            elif 2**63 <= arg and arg <= 2**64 - 1:
                return "u64"
            else:
                return "i64"
        elif isinstance(arg, float):
            return "fp32"
        elif arg is None:
            return None
        else:
            raise TypeError(f"Unsupported type {type(arg)} for {arg}")

    @staticmethod
    def _type_of(key, is_const=False):
        # None -> "*i8"; strings pass through; dtypes -> "*<canonical>" (or "*k<canonical>" if const).
        if key is None:
            return "*i8"
        elif isinstance(key, str):
            return key
        dtype_str = str(key).split(".")[-1]
        dtype_str = type_canonicalisation_dict[dtype_str]
        const_str = "*k" if is_const else "*"
        return const_str + dtype_str

    def _make_constants(self, constexpr_key):
        # Pair constexpr parameter indices with the given values.
        constants = dict(zip(self.constexprs, constexpr_key))
        return constants

    def _call_hook(
        self,
        key,
        signature,
        device,
        constants,
        options,
        configs,
        is_warmup,
        before,
    ):
        """Invoke ``cache_hook`` (before=True) or ``compiled_hook`` (before=False), if set.

        Returns False when no hook is installed, otherwise whatever the hook returns.
        """
        hook = JITFunction.cache_hook if before else JITFunction.compiled_hook
        if hook is None:
            return False

        name = self.fn.__name__
        module = self.fn.__module__
        # ``key`` is the flat cache-key string built in ``run`` (not a tuple), so per-argument
        # types are taken from ``signature``; parameters absent from it are compile-time constants.
        arg_reprs = ", ".join(
            [f"{param.name}: {signature.get(param.name, 'constexpr')}" for param in self.params])
        repr = f"{name}[num_warps={options.num_warps}, num_ctas={options.num_ctas}, num_stages={options.num_stages}, enable_fp_fusion={options.enable_fp_fusion}]({arg_reprs})"

        class JitFunctionInfo:

            def __init__(self, module, name, jit_function):
                self.module = module
                self.name = name
                self.jit_function = jit_function

        specialization_data = serialize_specialization_data(name, signature, constants, configs[0], options, key)

        kwargs = {
            'signature': signature,
            'device': device,
            'constants': constants,
            'num_warps': options.num_warps,
            'num_ctas': options.num_ctas,
            'num_stages': options.num_stages,
            'enable_fp_fusion': options.enable_fp_fusion,
            'extern_libs': options.extern_libs,
            'configs': configs,
            'specialization_data': specialization_data,
            'is_warmup': is_warmup,
        }

        return hook(
            key=key,
            repr=repr,
            fn=JitFunctionInfo(module, name, self),
            compile={"key": key, **kwargs},
            is_manual_warmup=is_warmup,
            already_compiled=False,
        )

    def add_pre_run_hook(self, hook):
        '''
        Add a hook that will be executed prior to the execution of run
        function with args and kwargs passed into the kernel
        '''
        assert callable(hook)
        self.pre_run_hooks.append(hook)

    def create_binder(self, backend):
        """
        Precompute as much as possible: import compiler entry points and build the
        argument binder plus index lists used by ``run``.
        """
        from ..compiler import CompiledKernel, compile, ASTSource, make_backend
        self.CompiledKernel = CompiledKernel
        self.compile = compile
        self.ASTSource = ASTSource
        self.make_backend = make_backend
        self.binder = create_function_from_signature(self.signature, self.params, backend)
        self.constexpr_indices = [i for (i, p) in enumerate(self.params) if p.is_constexpr]
        self.non_constexpr_indices = [i for (i, p) in enumerate(self.params) if not p.is_constexpr]
        self.specialised_indices = [
            i for (i, p) in enumerate(self.params) if (not p.do_not_specialize) and (not p.is_constexpr)
        ]

    def run(self, *args, grid, warmup, **kwargs):
        """Bind arguments, compile on cache miss, and launch unless *warmup*.

        Returns the (possibly cached) compiled kernel, or None when ``cache_hook``
        vetoed compilation.
        """
        # TRITON_DEBUG=1 in the environment forces debug on.
        kwargs["debug"] = kwargs.get("debug", False) or os.environ.get("TRITON_DEBUG", "0") == "1"

        # parse options
        from ..compiler import make_backend
        device = driver.active.get_current_device()
        stream = driver.active.get_current_stream(device)
        target = driver.active.get_current_target()
        backend = make_backend(target)

        # Execute pre run hooks with args and kwargs
        for hook in self.pre_run_hooks:
            hook(*args, **kwargs)

        if self.binder is None:
            self.create_binder(backend)

        bound_args, sig_and_spec, constexpr_vals, non_constexpr_vals, excess_kwargs = self.binder(*args, **kwargs)

        # compute cache key
        key = ''.join(sig_and_spec) + str((constexpr_vals, excess_kwargs))
        kernel = self.cache[device].get(key, None)

        if kernel is None:
            # Kernel is not cached; we have to compile.
            options = backend.parse_options(kwargs)

            # deprecated arguments
            assert "device_type" not in kwargs, "device_type option is deprecated; current target will be used"
            assert "device" not in kwargs, "device option is deprecated; current device will be used"
            assert "stream" not in kwargs, "stream option is deprecated; current stream will be used"
            for k in excess_kwargs:
                if k not in options.__dict__:
                    raise KeyError("Keyword argument %s was specified but unrecognised" % k)

            bound_vals = tuple(bound_args.values())

            # The binder emits one type string per non-constexpr parameter first,
            # followed by the specialization keys.
            sigkeys = [self.params[i].name for i in self.non_constexpr_indices]
            sigvals = sig_and_spec[:len(sigkeys)]
            signature = {k: ('*i8' if (v == 'none') else v) for (k, v) in zip(sigkeys, sigvals)}

            configs = (backend.get_attrs_descriptor(self.params, bound_vals), )
            constant_params = configs[0].get_constants()
            constants = {
                p.name: v
                for (v, p) in zip(bound_vals, self.params)
                if p.is_constexpr or (p.num in constant_params) or v is None
            }
            for i, arg in constants.items():
                if callable(arg):
                    raise TypeError(f"Callable constexpr at index {i} is not supported")

            kernel = self._do_compile(key, signature, device, backend, target, constants, options, configs[0], warmup)
            if kernel is None:
                return None

        # Check that used global values have not changed.
        not_present = object()
        for (name, _), (val, globals_dict) in self.used_global_vals.items():
            if (newVal := globals_dict.get(name, not_present)) != val:
                raise RuntimeError(
                    f"Global variable {name} has changed since we compiled this kernel, from {val} to {newVal}")

        if not warmup:
            # canonicalize grid: a callable grid is evaluated on the bound arguments;
            # missing trailing dimensions default to 1.
            assert grid is not None
            if callable(grid):
                grid = grid(bound_args)
            grid_size = len(grid)
            grid_0 = grid[0]
            grid_1 = grid[1] if grid_size > 1 else 1
            grid_2 = grid[2] if grid_size > 2 else 1

            # An asynchronously compiled kernel arrives as an object with ``result()``.
            if hasattr(kernel, "result"):
                kernel = kernel.result()

            # launch kernel
            launch_metadata = kernel.launch_metadata(grid, stream, *non_constexpr_vals)
            kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
                       self.CompiledKernel.launch_enter_hook, self.CompiledKernel.launch_exit_hook, *non_constexpr_vals)
        return kernel

    def __init__(self, fn, version=None, do_not_specialize=None, do_not_specialize_on_alignment=None, debug=None,
                 noinline=None, repr=None, launch_metadata=None):
        do_not_specialize = do_not_specialize if do_not_specialize else []
        do_not_specialize_on_alignment = do_not_specialize_on_alignment if do_not_specialize_on_alignment else []

        self.fn = fn
        self.module = fn.__module__
        self.version = version
        self.signature = inspect.signature(fn)
        self.do_not_specialize = do_not_specialize
        self.do_not_specialize_on_alignment = do_not_specialize_on_alignment
        self.starting_line_number = inspect.getsourcelines(fn)[1]
        self.repr = lambda _: fn.__name__ if repr is None else repr(_)
        self.launch_metadata = launch_metadata

        self.binder = None

        # Parameters may be opted out of specialization either by index or by name.
        self.params = []
        for i, param in enumerate(self.signature.parameters.values()):
            dns = i in do_not_specialize or param.name in do_not_specialize
            dns_oa = i in do_not_specialize_on_alignment or param.name in do_not_specialize_on_alignment
            self.params.append(KernelParam(i, param, dns, dns_oa))

        # function source code (without decorators)
        self.src = textwrap.dedent(inspect.getsource(fn))
        self.src = self.src[re.search(r"^def\s+\w+\s*\(", self.src, re.MULTILINE).start():]
        # cache of just-in-time compiled kernels: device -> cache key -> kernel
        self.cache = defaultdict(dict)
        self.hash = None

        # Map of global variables used by the function and any functions it
        # transitively calls, plus their values. The values are collected when
        # the function is first compiled. Then every time we run the function,
        # we check that the values of the globals match what's expected,
        # otherwise we raise an error.
        self.used_global_vals: Dict[Tuple[str, int], Tuple[Any, Dict[str, Any]]] = {}

        # JITFunction can be instantiated as kernel
        # when called with a grid using __getitem__
        self.kernel = None
        self.noinline = noinline

        self.arg_names = [p.name for p in self.params]
        self.constexprs = [p.num for p in self.params if p.is_constexpr]

        # Hooks that will be called prior to executing "run"
        self.pre_run_hooks = []

        # reuse docs of wrapped function
        self.__doc__ = fn.__doc__
        self.__name__ = fn.__name__
        self.__globals__ = fn.__globals__
        self.__module__ = fn.__module__

    @property
    def cache_key(self):
        """Hash of the source and its dependencies plus the starting line; computed once per ``src``."""
        # TODO : hash should be attribute of `self`
        if self.hash is None:
            dependencies_finder = DependenciesFinder(name=self.__name__, globals=self.__globals__, src=self.src)
            dependencies_finder.visit(self.parse())
            self.hash = dependencies_finder.ret + str(self.starting_line_number)
            self.used_global_vals = dict(sorted(dependencies_finder.used_global_vals.items()))
        return self.hash

    def warmup(self, *args, grid, **kwargs):
        """Compile without launching; bare torch dtypes in *args* are wrapped as MockTensors."""
        return self.run(grid=grid, warmup=True, *map(MockTensor.wrap_dtype, args), **kwargs)

    def preload(self, specialization_data):
        """Compile the specialization described by JSON from ``serialize_specialization_data``.

        Raises RuntimeError if the data was produced for a differently named function.
        """
        from ..compiler import make_backend
        from triton.backends.compiler import AttrsDescriptor
        import json
        import triton.language as tl
        device = driver.active.get_current_device()
        deserialized_obj = json.loads(specialization_data)
        if deserialized_obj['name'] != self.fn.__name__:
            raise RuntimeError(
                f"Specialization data is for {deserialized_obj['name']} but trying to preload for {self.fn.__name__}")
        constants = {
            key: tl.dtype(value) if tl.dtype.is_dtype(value) else value
            for key, value in deserialized_obj['constants'].items()
        }
        signature = dict(deserialized_obj['signature'].items())
        # JSON turns tuples into lists; convert them back.
        options = {
            key: tuple(value) if isinstance(value, list) else value
            for key, value in deserialized_obj['options'].items()
        }
        key = deserialized_obj['key']
        target = driver.active.get_current_target()
        backend = make_backend(target)
        # ``_do_compile`` uses ``self.ASTSource`` and ``self.compile``, which only
        # ``create_binder`` sets up; ensure they exist when preloading before the first launch.
        if self.binder is None:
            self.create_binder(backend)
        options = backend.parse_options(options)
        attrs = AttrsDescriptor.from_dict(deserialized_obj['attrs'])
        return self._do_compile(
            key,
            signature,
            device,
            backend,
            target,
            constants,
            options,
            attrs,
            warmup=True,
        )

    def _do_compile(self, key, signature, device, backend, target, constants, options, attrs, warmup):
        """Compile (synchronously or via the active async mode) and store the result in the device cache.

        Returns None if ``cache_hook`` vetoed compilation; otherwise the kernel, or whatever
        the async mode's ``submit`` returns.
        """
        kernel_cache = self.cache[device]

        if self._call_hook(key, signature, device, constants, options, [attrs], warmup, before=True):
            return None
        src = self.ASTSource(self, signature, constants, attrs)

        async_mode = _async_compile.active_mode.get()
        if async_mode is not None:
            from triton.compiler.compiler import get_cache_key

            env_vars = get_cache_invalidating_env_vars()
            cache_key = get_cache_key(src, backend, options, env_vars)

            def async_compile():
                return self.compile(src, target=target, options=options.__dict__, _env_vars=env_vars)

            def finalize_compile(kernel):
                kernel_cache[key] = kernel
                self._call_hook(key, signature, device, constants, options, [attrs], warmup, before=False)

            kernel = async_mode.submit(cache_key, async_compile, finalize_compile)
        else:
            kernel = self.compile(src, target=target, options=options.__dict__)
            kernel_cache[key] = kernel
            self._call_hook(key, signature, device, constants, options, [attrs], warmup, before=False)
        return kernel

    def parse(self):
        """Parse ``self.src`` and check that it is a module holding exactly one function definition."""
        tree = ast.parse(self.src)
        assert isinstance(tree, ast.Module)
        assert len(tree.body) == 1
        assert isinstance(tree.body[0], ast.FunctionDef)
        return tree

    def __call__(self, *args, **kwargs):
        raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel")

    def __setattr__(self, name, value):
        super(JITFunction, self).__setattr__(name, value)
        # Changing the source invalidates the cached hash, so cache_key is recomputed.
        if name == "src":
            self.hash = None

    def __repr__(self):
        return f"JITFunction({self.module}:{self.fn.__name__})"
@overload
def jit(fn: T) -> JITFunction[T]:
    ...


@overload
def jit(
    *,
    version=None,
    repr: Optional[Callable] = None,
    launch_metadata: Optional[Callable] = None,
    do_not_specialize: Optional[Iterable[int]] = None,
    do_not_specialize_on_alignment: Optional[Iterable[int]] = None,
    debug: Optional[bool] = None,
    noinline: Optional[bool] = None,
) -> Callable[[T], JITFunction[T]]:
    ...


def jit(
    fn: Optional[T] = None,
    *,
    version=None,
    repr: Optional[Callable] = None,
    launch_metadata: Optional[Callable] = None,
    do_not_specialize: Optional[Iterable[int]] = None,
    do_not_specialize_on_alignment: Optional[Iterable[int]] = None,
    debug: Optional[bool] = None,
    noinline: Optional[bool] = None,
) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]:
    """
    Decorator for JIT-compiling a function using the Triton compiler.

    Usable bare (``@jit``) or with keyword options (``@jit(noinline=True)``).
    With the environment variable ``TRITON_INTERPRET=1`` the function is wrapped in an
    ``InterpretedFunction``; otherwise it is wrapped in a ``JITFunction``.

    :note: When a jit'd function is called, arguments are
        implicitly converted to pointers if they have a :code:`.data_ptr()` method
        and a `.dtype` attribute.

    :note: This function will be compiled and run on the GPU. It will only have access to:

           * python primitives,
           * builtins within the triton package,
           * arguments to this function,
           * other jit'd functions

    :param fn: the function to be jit-compiled
    :type fn: Callable
    """

    def decorator(fn: T) -> JITFunction[T]:
        assert callable(fn)
        # Both wrapper classes take the same keyword options.
        wrapper_options = dict(
            version=version,
            do_not_specialize=do_not_specialize,
            do_not_specialize_on_alignment=do_not_specialize_on_alignment,
            debug=debug,
            noinline=noinline,
            repr=repr,
            launch_metadata=launch_metadata,
        )
        if os.getenv("TRITON_INTERPRET", "0") == "1":
            from .interpreter import InterpretedFunction
            return InterpretedFunction(fn, **wrapper_options)
        return JITFunction(fn, **wrapper_options)

    # Bare ``@jit`` passes the function directly; ``@jit(...)`` needs the decorator back.
    return decorator if fn is None else decorator(fn)
class MockTensor:
    """
    Stand-in for a real tensor when only its dtype matters, e.g.:
    kernel.warmup(MockTensor(torch.float32), ...)
    """

    @staticmethod
    def wrap_dtype(arg):
        """Wrap a bare ``torch`` dtype object in a MockTensor; return any other argument unchanged."""
        # The module check only runs for objects whose class is named "dtype".
        is_torch_dtype = arg.__class__.__name__ == "dtype" and arg.__module__ == "torch"
        return MockTensor(arg) if is_torch_dtype else arg

    def __init__(self, dtype):
        self.dtype = dtype

    @staticmethod
    def data_ptr():
        # Address 0 stands in for a real allocation.
        return 0

    @staticmethod
    def ptr_range():
        return 0
class TensorWrapper:
    """A tensor-like *base* whose element type is overridden by *dtype*.

    Storage-related queries (pointer, strides, element size) are forwarded to *base*.
    Copying operations return a new wrapper carrying the same *dtype*.
    """

    def __init__(self, base, dtype):
        self.dtype = dtype
        self.base = base
        # Mirror a few attributes of the wrapped object for callers that read them directly.
        self.data = base.data
        self.device = base.device
        self.shape = self.base.shape

    def data_ptr(self):
        return self.base.data_ptr()

    def stride(self, i):
        return self.base.stride(i)

    def __str__(self) -> str:
        return f"TensorWrapper[{self.dtype}]({self.base})"

    def element_size(self):
        return self.base.element_size()

    def cpu(self):
        moved = self.base.cpu()
        return TensorWrapper(moved, self.dtype)

    def copy_(self, other):
        # *other* is expected to be another TensorWrapper; its base is copied into ours.
        self.base.copy_(other.base)

    def clone(self):
        duplicate = self.base.clone()
        return TensorWrapper(duplicate, self.dtype)

    def to(self, device):
        moved = self.base.to(device)
        return TensorWrapper(moved, self.dtype)
def reinterpret(tensor, dtype):
    """Return *tensor* viewed as having element type *dtype*.

    Wrappers are re-wrapped around their base. When *dtype* equals the base's own
    dtype, the base itself is returned. Raises TypeError for objects without
    ``data_ptr``.
    """
    if isinstance(tensor, TensorWrapper):
        base = tensor.base
        return base if dtype == base.dtype else TensorWrapper(base, dtype)
    if hasattr(tensor, "data_ptr"):
        return TensorWrapper(tensor, dtype)
    raise TypeError(f"Cannot reinterpret a {type(tensor)}.")
def get_jit_fn_file_line(fn):
    """Return ``(file_name, line_number)`` of the ``def`` line of the JITFunction behind *fn*.

    *fn* may be a JITFunction or any chain of wrappers exposing ``.fn``.
    """
    jit_fn = fn
    while not isinstance(jit_fn, JITFunction):
        jit_fn = jit_fn.fn
    py_fn = jit_fn.fn
    file_name = py_fn.__code__.co_filename
    source_lines, first_line = inspect.getsourcelines(py_fn)
    # Skip leading lines (e.g. decorators) up to the first line that starts with "def ";
    # if there is none, the reported line stays at the start of the source.
    def_offset = next((idx for idx, text in enumerate(source_lines) if text.strip().startswith("def ")), 0)
    return file_name, first_line + def_offset