import ast
import inspect
import re
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import (Any, Callable, Dict, Generator, Iterable, List, Literal, NoReturn, Optional, Tuple, Type, TypeVar,
Union, overload)
from typing_extensions import ParamSpec, TypeAlias
from .._C import ir
from .errors import CodegenError, UnsupportedSyntaxError
from .function import Function, FunctionLocation
from .name_scope import NameScope
from .specialization import BaseArgType, IRArgType, PointerArgType, PlainArgType, Specialization, StructArgType
from ..common.compat import get_annotations, merge_dict
from ..language.core.constexpr import ConstExpr
from ..language.core.dtype import KnownTypes
from ..language.core.ir_value import GlobalAddress, IRHandle, IRValue, PlainValue, materialize_ir_value
from ..language.core.range import range as _range, static_range
from ..language.core.struct import BaseField, Struct
from ..language.core.tensor import BaseTensor
from ..language.core.utils import static_assert, global_builder
T = TypeVar("T")
P = ParamSpec("P")
@dataclass
class CodegenOptions:
capture_exceptions: bool = True
ir_multithreading: bool = True
@dataclass
class ReturnType:
py_type: Type[IRValue]
ir_type: ir.Type
ReturnTypesDict: TypeAlias = Dict[str, List[ReturnType]]
@dataclass
class VisitorState:
discard_everything: bool = False
inside_function: bool = False
return_allowed: bool = True
return_types: List[ReturnType] = field(default_factory=list)
visited_return_types: ReturnTypesDict = field(default_factory=dict)
@dataclass
class BlockInOut:
block: ir.Block
init_handles: Dict[str, IRHandle]
yield_values: Dict[str, IRValue]
yield_handles: Dict[str, IRHandle]
class FunctionVisitor(ast.NodeVisitor):
def __init__(
self,
source_lines: Optional[List[str]],
spec: Specialization,
global_vars: Dict[str, Any],
location: FunctionLocation,
options: CodegenOptions,
visited_return_types: Optional[ReturnTypesDict] = None,
is_kernel: bool = True,
):
super().__init__()
self.src = source_lines
self.ir_function: Optional[ir.FuncOp] = None
self.spec = spec
self.scope = NameScope(merge_dict(global_vars, spec.constexprs))
self.location = location
global_builder.get_ir_builder().set_insertion_point_to_start(global_builder.get_ir_module().get_body())
global_builder.get_ir_builder().set_loc(self.location.filename, self.location.line_offset, 0)
self.state = VisitorState()
self.options = options
self.is_kernel = is_kernel
if visited_return_types:
self.state.visited_return_types = visited_return_types
@staticmethod
def get_binary_method_name(op_class: Type[ast.operator]) -> str:
names: Dict[Type[ast.operator], str] = {
ast.Add: '__add__',
ast.Sub: '__sub__',
ast.Mult: '__mul__',
ast.Div: '__truediv__',
ast.FloorDiv: '__floordiv__',
ast.Mod: '__mod__',
ast.Pow: '__pow__',
ast.LShift: '__lshift__',
ast.RShift: '__rshift__',
ast.BitAnd: '__and__',
ast.BitOr: '__or__',
ast.BitXor: '__xor__',
}
name = names.get(op_class)
if name:
return name
raise NotImplementedError(f"Method for {op_class.__class__.__name__} is not implemented")
@staticmethod
def get_bool_method_name(op_class: Type[ast.boolop]) -> str:
names: Dict[Type[ast.boolop], str] = {
ast.And: 'logical_and',
ast.Or: 'logical_or',
}
name = names.get(op_class)
if name:
return name
raise NotImplementedError(f"Method for {op_class.__name__} is not implemented")
@staticmethod
def get_unary_method_name(op_class: Type[ast.unaryop]) -> str:
names: Dict[Type[ast.unaryop], str] = {
ast.USub: '__neg__',
ast.UAdd: '__pos__',
ast.Not: '__not__',
ast.Invert: '__invert__',
}
name = names.get(op_class)
if name:
return name
raise NotImplementedError(f"Method for {op_class.__name__} is not implemented")
@staticmethod
def get_compare_method_name(op_class: Type[ast.cmpop]) -> str:
names: Dict[Type[ast.cmpop], str] = {
ast.Eq: '__eq__',
ast.NotEq: '__ne__',
ast.Gt: '__gt__',
ast.GtE: '__ge__',
ast.Lt: '__lt__',
ast.LtE: '__le__',
}
name = names.get(op_class)
if name:
return name
raise NotImplementedError(f"Method for {op_class.__name__} is not implemented")
@staticmethod
def has_builder_support(value) -> bool:
return isinstance(value, (BaseTensor, GlobalAddress, PlainValue))
def raise_unsupported(self, node: ast.AST, message: Optional[str] = None, context: bool = False) -> NoReturn:
error = UnsupportedSyntaxError(node, self.src, message)
if context:
raise error
raise error from None
def apply_binary_method(self, method_name, lhs, rhs) -> Any:
reverse_method_name = re.sub(r"__(.*)__", r"__r\1__", method_name)
if not self.has_builder_support(lhs) and self.has_builder_support(rhs):
return getattr(rhs, reverse_method_name)(lhs)
result = getattr(lhs, method_name)(rhs)
if result is NotImplemented:
result = getattr(rhs, reverse_method_name)(lhs)
return result
def get_call_args(self, func: Callable, *args: Tuple[Any], **kwargs: Dict[str, Any]) -> Dict[str, Any]:
sig = inspect.signature(func)
bound_args = sig.bind(*args, **kwargs)
bound_args.apply_defaults()
return bound_args.arguments
def call_jit_function(self, fn: Function, args: Tuple[Any], kwargs: Dict[str, Any]) -> Optional[Any]:
base_fn = fn.fn
call_args = self.get_call_args(base_fn, *args, **kwargs)
annotations = get_annotations(base_fn)
runtime_args, constexprs = Function.split_args(call_args, annotations)
arg_values: Dict[str, IRValue] = {}
for name, value in runtime_args.items():
if isinstance(value, IRValue):
arg_values[name] = value
else:
arg_values[name] = materialize_ir_value(value)
arg_types = {name: IRArgType(type(value), value.to_ir().get_type()) for name, value in arg_values.items()}
fn_name = fn.node.name
ret_types = []
if global_builder.get_ir_module().has_function(fn_name):
ret_types = self.state.visited_return_types[fn_name]
else:
spec = Specialization(arg_types, constexprs)
with self.visit_region():
visitor = FunctionVisitor(fn.src, spec, base_fn.__globals__, fn.location, self.options,
self.state.visited_return_types, is_kernel=False)
visitor.visit(fn.node)
ret_types = visitor.state.return_types
self.state.visited_return_types[fn_name] = ret_types
ir_operands = [value.to_ir() for value in arg_values.values()]
op = global_builder.get_ir_builder().create_func_CallOp(fn.node.name, ir_operands,
[ret_type.ir_type for ret_type in ret_types])
if len(ret_types) == 0:
return None
return_values = [ret_type.py_type.from_ir(op.get_result(i)) for i, ret_type in enumerate(ret_types)]
if len(ret_types) == 1:
return return_values[0]
return return_values
def compute_inout(self, node: ast.AST, stmts: List[ast.stmt], ind_var: Optional[Tuple[str, ir.Type]] = None,
make_args: bool = False) -> BlockInOut:
with self.visit_region() as (outer_scope, _):
block = ir.Block()
if ind_var is not None:
name, ir_type = ind_var
arg = block.add_argument(ir_type)
self.scope.save(name, PlainValue(arg))
global_builder.get_ir_builder().set_insertion_point_to_start(block)
self.visit_statements(stmts)
for name in self.scope.redefined:
old_value = outer_scope.lookup(name)
new_value = self.scope.lookup(name)
if type(old_value) is not type(new_value):
self.raise_unsupported(
node, f"'{name}' was re-assigned to an object with different type: "
f"initial type is {old_value.__class__.__name__}, new type is {new_value.__class__.__name__}")
_, init_handles = self.mat_ir_values(outer_scope.lookup(name) for name in self.scope.redefined)
if make_args:
for handle in init_handles:
arg = block.add_argument(handle.get_type())
handle.replace_uses_in_block(block, arg)
yield_values, yield_handles = self.mat_ir_values(self.scope.lookup(name) for name in self.scope.redefined)
return BlockInOut(
block=block,
init_handles=dict(zip(self.scope.redefined, init_handles)),
yield_values=dict(zip(self.scope.redefined, yield_values)),
yield_handles=dict(zip(self.scope.redefined, yield_handles)),
)
def dereference_name(self, name: str) -> Optional[Any]:
return self.scope.lookup(name)
@overload
def ensure_bool_value(self, node: ast.expr, require_ir: Literal[True]) -> PlainValue:
...
@overload
def ensure_bool_value(self, node: ast.expr, require_ir: bool = False) -> Union[PlainValue, bool]:
...
def ensure_bool_value(self, node: ast.expr, require_ir: bool = False):
value = ConstExpr.unwrap(self.visit(node))
if isinstance(value, (int, float, type(None))):
if not require_ir:
return bool(value)
value = materialize_ir_value(bool(value))
if not isinstance(value, PlainValue):
self.raise_unsupported(
node, f"Condition expression must be evaluated as PlainValue, got {value.__class__.__name__}")
return value.cast(KnownTypes.bit)
def get_arg_value(self, arg_type: BaseArgType, handle: IRHandle) -> IRValue:
if isinstance(arg_type, PointerArgType):
return GlobalAddress(handle=handle, dtype=arg_type.dtype)
if isinstance(arg_type, PlainArgType):
return PlainValue(handle=handle, dtype=arg_type.dtype)
if isinstance(arg_type, IRArgType):
return arg_type.py_type.from_ir(handle)
if isinstance(arg_type, StructArgType):
return arg_type.py_type.from_ir(handle).create_local()
raise NotImplementedError(f"Not implemented for {arg_type.__class__.__name__}")
def mat_ir_values(self, values: Iterable[Any]) -> Tuple[Tuple[PlainValue, ...], Tuple[IRHandle, ...]]:
ir_values = tuple(materialize_ir_value(value) for value in values)
ir_handles = tuple(value.to_ir() for value in ir_values)
return ir_values, ir_handles
@contextmanager
def nest_scope(self) -> Generator[None, Any, None]:
outer_scope = self.scope
self.scope = outer_scope.inherit()
try:
yield
finally:
self.scope = outer_scope
def generic_visit(self, node: ast.AST) -> NoReturn:
self.raise_unsupported(node, f"{node.__class__.__name__} syntax is not supported in JIT function")
def visit(self, node: Optional[ast.AST]) -> Optional[Any]:
if node is None:
return None
if self.state.discard_everything:
return None
if self.state.inside_function and isinstance(node, ast.FunctionDef):
self.raise_unsupported(node, "Nested functions are not supported")
if not self.state.inside_function and not isinstance(node, ast.FunctionDef):
raise RuntimeError(f"JIT compilation is applicable to functions only, got {node.__class__.__name__} node")
if hasattr(node, "lineno") and hasattr(node, "col_offset"):
global_builder.get_ir_builder().set_loc(self.location.filename, self.location.line_offset + node.lineno,
node.col_offset)
try:
return super().visit(node)
except CodegenError:
raise
except Exception as e:
if self.options.capture_exceptions:
raise CodegenError(node, self.src, f"{e.__class__.__name__}: {e}") from e
raise
def visit_arguments(self, node: ast.arguments) -> Tuple[List[str], str]:
if node.defaults or node.kw_defaults:
self.raise_unsupported(node, "Default values for function arguments are not supported")
if node.posonlyargs:
self.raise_unsupported(node, "Positional-only arguments are not supported")
if node.kwonlyargs:
self.raise_unsupported(node, "Keyword-only arguments are not supported")
arg_names = [str(self.visit(arg)) for arg in node.args]
kwarg_name = str(self.visit(node.kwarg))
return arg_names, kwarg_name
def visit_arg(self, node: ast.arg) -> str:
return node.arg
@contextmanager
def visit_region(self) -> Generator[Tuple[NameScope, ir.InsertPoint], Any, None]:
outer_scope = self.scope
self.scope = outer_scope.inherit()
insert_point = global_builder.get_ir_builder().save_insertion_point()
return_allowed = self.state.return_allowed
self.state.return_allowed = False
try:
yield outer_scope.inherit(), insert_point
finally:
self.scope = outer_scope
global_builder.get_ir_builder().restore_insertion_point(insert_point)
self.state.return_allowed = return_allowed
def visit_statements(self, stmts: List[ast.stmt]) -> None:
for stmt in stmts:
self.visit(stmt)
def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
return self.visit_Assign(node)
def visit_Assert(self, node: ast.Assert) -> None:
test = self.visit(node.test)
try:
test = ConstExpr(test)
static_assert(test, self.visit(node.msg))
except TypeError:
self.raise_unsupported(
node,
f"An assertion turned out to test a runtime value {test!r}, only compile-time values are supported")
def visit_Assign(self, node: ast.Assign) -> None:
targets = [node.target] if isinstance(node, ast.AnnAssign) else node.targets
if len(targets) != 1:
self.raise_unsupported(node, "Assignment operator must have exactly one target")
lhs = targets[0]
rhs = self.visit(node.value)
if isinstance(lhs, ast.Subscript) and isinstance(lhs.ctx, ast.Store):
base = self.visit(lhs.value)
subscript = self.visit(lhs.slice)
base.__setitem__(subscript, rhs)
return
if isinstance(lhs, ast.Attribute) and isinstance(lhs.ctx, ast.Store):
base = self.visit(lhs.value)
if setter := getattr(base, "__setattrjit__", None):
setter(lhs.attr, rhs)
else:
setattr(base, lhs.attr, rhs)
return
lhs_names = []
if isinstance(lhs, ast.Name):
lhs_names.append(self.visit(lhs))
elif isinstance(lhs, ast.Tuple) and all(isinstance(elt, ast.Name) for elt in lhs.elts):
lhs_names.extend(self.visit(lhs))
else:
self.raise_unsupported(node, "Assignment target must be name or tuple of names")
rhs_values = []
if isinstance(rhs, Iterable) and len(lhs_names) != 1:
if len(rhs) != len(lhs_names):
self.raise_unsupported(
node, "Assignment operator must have equal number of names and values, "
f"got {len(lhs_names)} names and {len(rhs)} values")
rhs_values.extend(rhs)
else:
rhs_values.append(rhs)
if len(lhs_names) != len(rhs_values):
raise RuntimeError("Assignment operator must have equal number of names and values")
for name, value in zip(lhs_names, rhs_values):
self.scope.save(name, value)
def visit_AugAssign(self, node: ast.AugAssign) -> None:
lhs = node.target
if isinstance(lhs, ast.Name):
lhs = ast.Name(lhs.id, ctx=ast.Load())
elif isinstance(lhs, ast.Subscript):
lhs = ast.Subscript(lhs.value, lhs.slice, ctx=ast.Load())
else:
self.raise_unsupported(node, f"{lhs.__class__.__name__} is not supported as left operand of AugAssign")
rhs = ast.BinOp(lhs, node.op, node.value)
assign = ast.Assign(targets=[node.target], value=rhs)
self.visit(assign)
def visit_Attribute(self, node: ast.Attribute) -> Any:
lhs = self.visit(node.value)
attr = str(node.attr)
try:
value = getattr(lhs, attr)
if not isinstance(lhs, Struct) or not isinstance(value, BaseField):
return value
except AttributeError:
pass
getter = getattr(lhs, "__getattrjit__", None)
if getter:
return getter(attr)
raise AttributeError(f"'{lhs.__class__.__name__}' object has no attribute '{attr}'")
def visit_BinOp(self, node: ast.BinOp) -> Any:
lhs = self.visit(node.left)
rhs = self.visit(node.right)
method_name = self.get_binary_method_name(type(node.op))
return self.apply_binary_method(method_name, lhs, rhs)
def visit_BoolOp(self, node: ast.BoolOp) -> Any:
if len(node.values) != 2:
self.raise_unsupported(node, "Chained boolean operators are not supported, group pairs with parentheses")
lhs = self.visit(node.values[0])
rhs = self.visit(node.values[1])
method_name = self.get_bool_method_name(type(node.op))
return self.apply_binary_method(method_name, lhs, rhs)
def visit_Call(self, node: ast.Call) -> Optional[Any]:
fn = self.visit(node.func)
if not callable(fn):
self.raise_unsupported(node, f"{fn.__class__.__name__} instance is not callable")
args = [self.visit(arg) for arg in node.args]
kwargs = dict(self.visit(keyword) for keyword in node.keywords)
if isinstance(fn, Function):
return self.call_jit_function(fn, args, kwargs)
return fn(*args, **kwargs)
def visit_Compare(self, node: ast.Compare) -> Any:
if len(node.comparators) != 1 or len(node.ops) != 1:
self.raise_unsupported(node, "Only simple comparison is supported (one operation, one comparator)")
lhs = self.visit(node.left)
rhs = self.visit(node.comparators[0])
op = node.ops[0]
if isinstance(op, ast.Is):
return lhs is rhs
if isinstance(op, ast.IsNot):
return lhs is not rhs
method_name = self.get_compare_method_name(type(node.ops[0]))
return self.apply_binary_method(method_name, lhs, rhs)
def visit_Constant(self, node: ast.Constant) -> Any:
return node.value
def visit_Expr(self, node: ast.Expr) -> Optional[Any]:
return self.visit(node.value)
def parse_iterator(self, node: ast.For) -> tuple:
func = self.visit(node.iter.func)
args = [self.visit(arg) for arg in node.iter.args]
kwargs = dict(self.visit(keyword) for keyword in node.iter.keywords)
return func, args, kwargs
def handle_static_range(self, node: ast.For, args, kwargs, target):
range_obj = static_range(*args, **kwargs)
for i in range(range_obj.start, range_obj.stop, range_obj.step):
self.scope.save(target, i)
self.visit_statements(node.body)
def visit_For(self, node: ast.For) -> None:
if len(node.orelse) != 0:
self.raise_unsupported(node, "else statement is not allowed after for-loop")
target = self.visit(node.target)
if not isinstance(target, str):
self.raise_unsupported(node, f"For-loop target must be an identifier, got {target.__class__.__name__}")
func, args, kwargs = self.parse_iterator(node)
iter_args = None, None, None
if func is static_range:
self.handle_static_range(node, args, kwargs, target)
return
elif func is range or func is _range:
range_obj = _range(*args, **kwargs)
iter_args = range_obj.start, range_obj.stop, range_obj.step
else:
self.raise_unsupported(
node,
"Only for-loops with range or asc.language.range or asc.language.static_range are supported",
)
builder = global_builder.get_ir_builder()
start, stop, step = map(lambda s: materialize_ir_value(s, KnownTypes.int_), iter_args)
if start.dtype != stop.dtype or start.dtype != step.dtype:
self.raise_unsupported(node, "Loop bounds must have the same DataType")
yields = {}
with self.visit_region():
block_inout = self.compute_inout(node, node.body, (target, start.to_ir().get_type()), make_args=True)
op = builder.create_scf_ForOp(start.to_ir(), stop.to_ir(), step.to_ir(),
list(block_inout.init_handles.values()))
self.scope.save(target, PlainValue(op.get_induction_var()))
body = op.get_body()
body.clear()
ir.inline_block_at_end(block_inout.block, body, body.get_arguments())
builder.set_insertion_point_to_end(body)
builder.create_scf_YieldOp(list(block_inout.yield_handles.values()))
for i, (name, value) in enumerate(block_inout.yield_values.items()):
yields[name] = value.from_ir(op.get_result(i))
for name, value in yields.items():
self.scope.save(name, value)
def visit_FormattedValue(self, node: ast.FormattedValue) -> str:
value = self.visit(node.value)
template = "{"
if node.conversion >= 0:
template += f"!{chr(node.conversion)}"
spec = self.visit(node.format_spec)
if spec:
template += f":{spec}"
template += "}"
return template.format(value)
def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
self.state.inside_function = True
arg_types = self.spec.args.values()
builder = global_builder.get_ir_builder()
input_ir_types = [arg_type.to_ir() for arg_type in arg_types]
self.ir_function = builder.create_func_FuncOp(node.name, builder.get_function_type(input_ir_types))
self.ir_function.make_aicore()
if self.is_kernel:
self.ir_function.make_global()
entry = self.ir_function.add_entry_block()
arg_names = list(self.spec.args.keys())
self.ir_function.set_arg_names(arg_names)
builder.set_insertion_point_to_start(entry)
for i, (name, arg) in enumerate(self.spec.args.items()):
value = self.get_arg_value(arg, self.ir_function.get_arg(i))
self.scope.save(name, value)
self.visit_statements(node.body)
if not entry.has_terminator():
builder.create_func_ReturnOp()
if self.state.return_types:
result_ir_types = [ret_type.ir_type for ret_type in self.state.return_types]
self.ir_function.set_type(builder.get_function_type(input_ir_types, result_ir_types))
self.state.inside_function = False
self.state.discard_everything = False
def visit_If(self, node: ast.If) -> None:
cond = self.ensure_bool_value(node.test)
if isinstance(cond, bool):
if cond:
self.visit_statements(node.body)
else:
self.visit_statements(node.orelse)
return
yields = {}
with self.visit_region():
then_inout = self.compute_inout(node, node.body)
else_inout = self.compute_inout(node, node.orelse)
def merge_sorted(dict1: Dict[str, T], dict2: Dict[str, T]) -> Dict[str, T]:
dicts = merge_dict(dict1, dict2)
return {key: dicts[key] for key in sorted(dicts.keys())}
yield_values = merge_sorted(then_inout.yield_values, else_inout.yield_values)
ret_types = [value.to_ir().get_type() for value in yield_values.values()]
builder = global_builder.get_ir_builder()
op = builder.create_scf_IfOp(cond.to_ir(), ret_types, with_else=True)
def select(keys: Iterable[str], main: Dict[str, T], fallback: Dict[str, T]) -> Dict[str, T]:
return {key: main[key] if key in main else fallback[key] for key in keys}
then_block = op.get_then_block()
then_block.clear()
then_inout.block.merge_block_before(then_block)
builder.set_insertion_point_to_end(then_block)
then_yields = select(yield_values.keys(), then_inout.yield_handles, else_inout.init_handles)
builder.create_scf_YieldOp(list(then_yields.values()))
else_block = op.get_else_block()
else_block.clear()
else_inout.block.merge_block_before(else_block)
builder.set_insertion_point_to_end(else_block)
else_yields = select(yield_values.keys(), else_inout.yield_handles, then_inout.init_handles)
builder.create_scf_YieldOp(list(else_yields.values()))
for i, (name, value) in enumerate(yield_values.items()):
yields[name] = value.from_ir(op.get_result(i))
for name, value in yields.items():
self.scope.save(name, value)
def visit_IfExp(self, node: ast.IfExp) -> Any:
cond = self.ensure_bool_value(node.test, require_ir=True)
with self.visit_region() as (_, insert_point):
then_block = ir.Block()
builder = global_builder.get_ir_builder()
builder.set_insertion_point_to_start(then_block)
then_value = materialize_ir_value(self.visit(node.body))
else_block = ir.Block()
builder.set_insertion_point_to_start(else_block)
else_value = materialize_ir_value(self.visit(node.orelse))
builder.restore_insertion_point(insert_point)
if then_value.dtype != else_value.dtype:
self.raise_unsupported(
node,
f"Conditional operator has inconsistent result types: {then_value.dtype} / {else_value.dtype}")
ret_type = then_value.dtype
if not ret_type.is_numeric():
self.raise_unsupported(node, f"Conditional operator must have numeric result type, got {ret_type}")
op = builder.create_scf_IfOp(cond.to_ir(), [ret_type.to_ir()], with_else=True)
then_block.merge_block_before(op.get_then_block())
builder.set_insertion_point_to_end(op.get_then_block())
builder.create_scf_YieldOp([then_value.to_ir()])
else_block.merge_block_before(op.get_else_block())
builder.set_insertion_point_to_end(op.get_else_block())
builder.create_scf_YieldOp([else_value.to_ir()])
return PlainValue(op.get_result(0))
def visit_JoinedStr(self, node: ast.JoinedStr) -> str:
values = (self.visit(value) for value in node.values)
return "".join(values)
def visit_keyword(self, node: ast.keyword) -> Tuple[str, Any]:
return node.arg, self.visit(node.value)
def visit_List(self, node: ast.List) -> List[Optional[Any]]:
return [self.visit(elt) for elt in node.elts]
def visit_Name(self, node: ast.Name) -> Union[str, Optional[Any]]:
if isinstance(node.ctx, ast.Store):
return node.id
value = self.dereference_name(node.id)
return ConstExpr.unwrap(value)
def visit_Pass(self, node: ast.Pass) -> None:
pass
def visit_Return(self, node: ast.Return) -> None:
if not self.state.return_allowed:
self.raise_unsupported(node, "Return statement is not allowed in nested blocks")
value = self.visit(node.value)
self.state.discard_everything = True
if value is None:
return
if self.is_kernel:
self.raise_unsupported(node, "JIT kernel function cannot return objects")
values = []
if isinstance(value, Iterable):
values.extend(value)
else:
values.append(value)
ir_values = [materialize_ir_value(value) for value in values]
self.state.return_types = [ReturnType(type(value), value.to_ir().get_type()) for value in ir_values]
global_builder.get_ir_builder().create_func_ReturnOp([value.to_ir() for value in ir_values])
def visit_Slice(self, node: ast.Slice) -> slice:
return slice(self.visit(node.lower), self.visit(node.upper), self.visit(node.step))
def visit_Subscript(self, node: ast.Subscript) -> Any:
if not isinstance(node.ctx, ast.Load):
self.raise_unsupported(node, "Subscript operation is not allowed (must be Load context or assignment)")
value = self.visit(node.value)
slices = self.visit(node.slice)
return value.__getitem__(slices)
def visit_Tuple(self, node: ast.Tuple) -> Tuple[Optional[Any], ...]:
return tuple(self.visit(elt) for elt in node.elts)
def visit_UnaryOp(self, node: ast.UnaryOp) -> Any:
operand = self.visit(node.operand)
method_name = self.get_unary_method_name(type(node.op))
return getattr(operand, method_name)()
def visit_With(self, node: ast.With) -> None:
if len(node.items) != 1:
self.raise_unsupported(node, "Only one item in with-statement is supported")
item = node.items[0]
context = self.visit(item.context_expr)
with self.nest_scope():
entered = context.__enter__()
if isinstance(item.optional_vars, ast.Name):
self.scope.save(self.visit(item.optional_vars), entered)
self.visit_statements(node.body)
context.__exit__(None, None, None)
def visit_While(self, node: ast.While) -> None:
if len(node.orelse) != 0:
self.raise_unsupported(node, "else statement is not allowed after while-loop")
builder = global_builder.get_ir_builder()
after_inout = self.compute_inout(node, node.body, make_args=True)
ret_types = [handle.get_type() for handle in after_inout.init_handles.values()]
op = builder.create_scf_WhileOp(ret_types, list(after_inout.init_handles.values()))
with self.visit_region():
before_block = builder.create_block(op.get_before(), ret_types)
builder.set_insertion_point_to_start(before_block)
cond = self.ensure_bool_value(node.test, require_ir=True)
builder.create_scf_ConditionOp(cond.to_ir(), before_block.get_arguments())
with self.visit_region():
after_block = builder.create_block(op.get_after(), ret_types)
ir.inline_block_at_end(after_inout.block, after_block, after_block.get_arguments())
builder.set_insertion_point_to_end(after_block)
builder.create_scf_YieldOp(list(after_inout.yield_handles.values()))
for i, (name, value) in enumerate(after_inout.yield_values.items()):
self.scope.save(name, value.from_ir(op.get_result(i)))