import ast
import builtins
import contextlib
import copy
import inspect
import re
import warnings
import textwrap
import itertools
from dataclasses import dataclass
from types import ModuleType
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union, Iterable, List
import triton.language.extra.cann.extension as extension
from triton.extension.buffer.language import core as bl
from triton.extension.buffer.language.builder import setup_unified_builder_with_buffer_builder
from .. import knobs, language
from .._C.libtriton import ir, gluon_ir, buffer_ir
from .._C.libtriton.ascend import ir as ascend_ir
from ..language import constexpr, str_to_ty, tensor, tuple as tl_tuple
from ..language.core import _unwrap_if_constexpr, base_value, base_type
from ..runtime.jit import get_jit_fn_file_line, get_full_name, JITCallable, BoundConstexprFunction, ConstexprFunction, JITFunction
from .._utils import find_paths_if, get_iterable_path, set_iterable_path
from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct)
def check_identifier_legality(name, type):
    """Validate that ``name`` is a plain ASCII identifier and return it.

    Args:
        name: Candidate identifier string.
        type: Human-readable kind of identifier (e.g. ``"function"``); only
            used in the error message.

    Returns:
        ``name`` unchanged when it is legal.

    Raises:
        CompilationError: if ``name`` is not of the form
            ``[a-zA-Z_][a-zA-Z0-9_]*``.
    """
    # `re.fullmatch` is used on purpose: `re.match` with a trailing `$` also
    # accepts a string that ends in a newline (e.g. "kernel\n").
    if not re.fullmatch(r'[a-zA-Z_][a-zA-Z0-9_]*', name):
        raise CompilationError(f"invalid {type} identifier: {name}", name)
    return name
# Registry consulted by `CodeGenerator.visit_With`: it is keyed by the object
# obtained from visiting the callee of a single `with` item, and each value is
# a handler invoked as `handler(generator, with_node)`.
WITH_DISPATCH = {}
from triton.language.extra.cann.extension.dispatch import ASCEND_WITH_DISPATCH
from triton.language.extra.cann.extension.builder import setup_unified_builder
# Register the handlers provided by the Ascend (CANN) extension.
WITH_DISPATCH.update(ASCEND_WITH_DISPATCH)
def mangle_fn(name, arg_tys, constants, caller_context):
    """Build the mangled name of a function specialisation.

    The result is ``<name>__<arg types>__<constants>`` where the argument part
    joins ``ty.mangle()`` for every type and the constant part joins
    ``<key>c<repr(value)>`` for every key of ``constants`` in sorted order.
    Characters ``.``, ``'``, ``[`` and ``]`` in the constant part are escaped.
    When ``caller_context`` is not ``None`` its ``mangle()`` output is appended.
    """
    # None of the replacement strings contains a character that is itself
    # escaped, so a single translation pass equals chained `str.replace`.
    escapes = str.maketrans({'.': '_d_', "'": '_sq_', '[': '_', ']': '_'})
    type_part = '_'.join(ty.mangle() for ty in arg_tys)
    const_part = '_'.join(f'{key}c{repr(constants[key])}' for key in sorted(constants))
    mangled = f'{name}__{type_part}__{const_part.translate(escapes)}'
    if caller_context is None:
        return mangled
    return mangled + caller_context.mangle()
def _is_triton_value(o: Any) -> bool:
    """Return True when ``o`` is an instance of ``base_value``."""
    return isinstance(o, base_value)
def _is_triton_tensor(o: Any) -> bool:
    """Return True when ``o`` is an instance of ``tensor``."""
    return isinstance(o, tensor)
def _is_constexpr(o: Any) -> bool:
    """Return True for values treated as compile-time constants.

    That is ``None`` and instances of ``constexpr``, ``language.core.dtype``
    or ``JITCallable``.
    """
    if o is None:
        return True
    return isinstance(o, (constexpr, language.core.dtype, JITCallable))
def _is_non_scalar_tensor(o: Any) -> bool:
    """Return a truthy value when ``o`` is a block tensor with ``numel != 1``."""
    if not _is_triton_tensor(o):
        return False
    return o.type.is_block() and o.type.numel != 1
def _is_list_like(o: Any) -> bool:
return isinstance(o, (list, tuple))
def _check_fn_args(node, fn, args):
if fn.noinline:
for idx, arg in enumerate(args):
if not _is_constexpr(arg) and _is_non_scalar_tensor(arg):
raise UnsupportedLanguageConstruct(
fn.src, node,
f'Function {fn.__name__} is marked noinline, but was called with non-scalar argument {fn.arg_names[idx]}:{arg}'
)
def _is_namedtuple(val):
return isinstance(val, type) and issubclass(val, tuple) and hasattr(val, "_fields")
def _apply_to_tuple_values(value, fn):
    """Map ``fn`` over the elements of a tuple value and rebuild a ``language.tuple``.

    ``value`` must be a namedtuple instance or a ``language.tuple``; its field
    names are carried over to the result. Elements for which ``fn`` returns
    ``None`` are wrapped in ``constexpr`` so that every element has a ``type``.
    """
    if _is_namedtuple(type(value)):
        fields = value._fields
    elif isinstance(value, language.tuple):
        fields = value.type.fields
    else:
        assert False, f"Unsupported type {type(value)}"
    mapped = [fn(element) for element in value]
    for index, element in enumerate(mapped):
        if element is None:
            mapped[index] = constexpr(element)
    element_types = [element.type for element in mapped]
    return language.tuple(mapped, language.tuple_type(element_types, fields))
def flatten_values_to_ir(values: Iterable[base_value]):
    """Collect the IR handles of ``values`` into one flat list.

    Each value appends its own handles through ``_flatten_ir``; the iteration
    order of ``values`` is preserved.
    """
    flat = []
    for value in values:
        value._flatten_ir(flat)
    return flat
def unflatten_ir_values(handles: List[ir.value], types: List[base_type]):
    """Lazily rebuild frontend values from a flat list of IR handles.

    This is a generator: for every type in ``types`` it asks the type to
    consume handles starting at the current position and yields the value.
    Once exhausted it asserts that every handle was consumed.
    """
    position = 0
    for frontend_ty in types:
        rebuilt, position = frontend_ty._unflatten_ir(handles, position)
        yield rebuilt
    assert position == len(handles)
# Python types accepted as a compile-time (non-tensor) condition by
# `visit_If` and `visit_IfExp` after constexpr unwrapping.
_condition_types = {bool, int, type(None)}
def _clone_triton_value(val):
handles = []
val._flatten_ir(handles)
clone, _ = val.type._unflatten_ir(handles, 0)
return clone
def _clone_scope(scope):
    """Copy a name -> value mapping, cloning every Triton value in it.

    Entries that are not Triton values are carried over as-is.
    """
    cloned = {}
    for name, val in scope.items():
        cloned[name] = _clone_triton_value(val) if _is_triton_value(val) else val
    return cloned
class enter_sub_region:
    """Context manager that snapshots generator state around a nested region.

    On entry it clones the generator's ``lscope`` and ``local_defs``, clears
    ``local_defs`` and records the builder's insertion block and point. On
    exit it restores the insertion point and puts the snapshots back.
    """

    def __init__(self, generator):
        self.generator = generator

    def __enter__(self):
        gen = self.generator
        self.liveins = _clone_scope(gen.lscope)
        self.prev_defs = _clone_scope(gen.local_defs)
        # Definitions made inside the region are tracked from scratch.
        gen.local_defs = {}
        builder = gen.builder
        self.insert_block = builder.get_insertion_block()
        self.insert_point = builder.get_insertion_point()
        return self.liveins, self.insert_block

    def __exit__(self, *args, **kwargs):
        gen = self.generator
        gen.builder.restore_insertion_point(self.insert_point)
        gen.lscope = self.liveins
        gen.local_defs = self.prev_defs
class ContainsReturnChecker(ast.NodeVisitor):
    """AST visitor that reports whether a subtree contains a ``return``.

    Every ``visit_*`` method returns a bool. Assignments are never descended
    into, and names/attributes resolved through ``gscope`` are delegated to
    ``_visit_function``, which currently always answers ``False``.
    """

    def __init__(self, gscope):
        self.gscope = gscope

    def _visit_stmts(self, body) -> bool:
        """Return True as soon as one statement of ``body`` contains a return."""
        for stmt in body:
            if self.visit(stmt):
                return True
        return False

    def _visit_function(self, fn) -> bool:
        """Hook for callee inspection; callees are not inspected."""
        return False

    def generic_visit(self, node) -> bool:
        """OR together the results for all child AST nodes of ``node``."""
        found = False
        for _, field in ast.iter_fields(node):
            children = field if isinstance(field, list) else [field]
            for child in children:
                if isinstance(child, ast.AST):
                    # Short-circuits: once found, no further child is visited.
                    found = found or self.visit(child)
        return found

    def visit_Attribute(self, node: ast.Attribute) -> bool:
        base = node.value
        if not isinstance(base, ast.Name):
            return self.visit(base)
        if base.id not in self.gscope:
            return False
        fn = getattr(self.gscope[base.id], node.attr)
        return self._visit_function(fn)

    def visit_Name(self, node: ast.Name) -> bool:
        if type(node.ctx) is ast.Store or node.id not in self.gscope:
            return False
        return self._visit_function(self.gscope[node.id])

    def visit_Return(self, node: ast.Return) -> bool:
        return True

    def visit_Assign(self, node: ast.Assign) -> bool:
        # Assignments are not descended into.
        return False

    def visit_AugAssign(self, node: ast.AugAssign) -> bool:
        # Augmented assignments are not descended into.
        return False

    def visit_Module(self, node: ast.Module) -> bool:
        return self._visit_stmts(node.body)

    def visit_FunctionDef(self, node: ast.FunctionDef) -> bool:
        return self._visit_stmts(node.body)

    def visit_If(self, node: ast.If) -> bool:
        # An empty `orelse` list simply contributes False.
        return self._visit_stmts(node.body) or self._visit_stmts(node.orelse)

    def visit_IfExp(self, node: ast.IfExp) -> bool:
        return self.visit(node.body) or self.visit(node.orelse)

    def visit_Call(self, node: ast.Call) -> bool:
        # Only the callee expression is inspected, not the arguments.
        return self.visit(node.func)
class ASTFunction:
    """Prototype of a function being generated: types, constants and attributes.

    ``constants`` maps argument paths to compile-time values and ``attrs``
    maps argument paths to ``(name, value)`` attribute pairs.
    """

    def __init__(self, ret_types, arg_types, constants, attrs):
        self.ret_types = ret_types
        self.arg_types = arg_types
        self.constants = constants
        self.attrs = attrs

    def _value_paths(self):
        """Return the paths of arguments that are neither constants nor ``None``."""

        def is_runtime_value(path, ty):
            return path not in self.constants and ty is not None

        return list(find_paths_if(self.arg_types, is_runtime_value))

    def flatten_ir_types(self, builder: ir.builder, types: List[base_type]) -> List[ir.type]:
        """Flatten frontend ``types`` into a list of IR types, skipping ``None``."""
        flat = []
        for ty in types:
            if ty is not None:
                ty._flatten_ir_types(builder, flat)
        return flat

    def return_types_ir(self, builder: ir.builder) -> List[ir.type]:
        """Return the flattened IR types of ``self.ret_types``."""
        return self.flatten_ir_types(builder, self.ret_types)

    def serialize(self, builder: ir.builder):
        """Build the IR function type from the runtime arguments and the returns."""
        runtime_types = [get_iterable_path(self.arg_types, path) for path in self._value_paths()]
        arg_types_ir = self.flatten_ir_types(builder, runtime_types)
        ret_types_ir = self.return_types_ir(builder)
        return builder.get_function_ty(arg_types_ir, ret_types_ir)

    def deserialize(self, fn):
        """Rebuild the frontend argument values from the IR arguments of ``fn``.

        Runtime arguments are unflattened from ``fn``'s IR arguments (their
        attributes are set on ``fn`` on the way); constants are filled in as
        ``constexpr`` values; every remaining slot holds ``constexpr(None)``.
        """

        def make_template(ty):
            if isinstance(ty, (list, tuple, language.tuple_type)):
                return language.tuple([make_template(x) for x in ty], ty)
            return language.constexpr(None)

        vals = make_template(self.arg_types)
        runtime_paths = self._value_paths()
        handles = [fn.args(i) for i in range(fn.get_num_args())]
        cursor = 0
        for path in runtime_paths:
            ty = get_iterable_path(self.arg_types, path)
            # Attributes are attached to the first IR argument of this value.
            for attr_name, attr_val in self.attrs.get(path, []):
                fn.set_arg_attr(cursor, attr_name, attr_val)
            val, cursor = ty._unflatten_ir(handles, cursor)
            set_iterable_path(vals, path, val)
        for path, const in self.constants.items():
            set_iterable_path(vals, path, language.constexpr(const))
        return vals
@dataclass(frozen=True)
class BoundJITMethod:
    """Immutable pair of a value and a JIT function.

    NOTE(review): the field names mirror Python's bound-method attributes, so
    this presumably represents a JIT function bound to a receiver - confirm
    against the call sites.
    """
    # The value the function is paired with.
    __self__: base_value
    # The JIT function itself.
    __func__: JITFunction
class CodeGenerator(ast.NodeVisitor):
def __init__(self, context, prototype, gscope, function_name, jit_fn: JITFunction, *, options, codegen_fns,
             module_map, is_gluon, module=None, is_kernel=False, function_types: Optional[Dict] = None,
             noinline=False, caller_context=None, file_name: Optional[str] = None, begin_line=0):
    """Set up builders, scopes and bookkeeping for generating one function.

    Args:
        context: IR context handed to every builder created here.
        prototype: Object providing ``serialize``/``deserialize`` for the
            function signature (used by ``visit_FunctionDef``).
        gscope: Global name -> object mapping of the function being compiled.
        function_name: Name of the generated function; for kernels only the
            part after the last ``.`` is kept and it must be a legal identifier.
        jit_fn: The JIT function whose source is being compiled.
        options: Compile options; stored on the builder. ``force_simt_only``
            and ``arch`` are read when present.
        codegen_fns: Stored on the builder as ``codegen_fns``.
        module_map: Mapping from module name to replacement module, or
            ``None`` for no replacements.
        is_gluon: Select the Gluon builder/semantic instead of the Triton one.
        module: Existing IR module to emit into; a new one is created if ``None``.
        is_kernel: Whether the function is a kernel entry point.
        function_types: Shared map of function return types; a new dict is
            used if ``None``.
        noinline: Stored and later passed to ``get_or_insert_function``.
        caller_context: Optional caller context; stored as-is.
        file_name: Source file used for IR locations.
        begin_line: First source line of the function, used for IR locations.
    """
    # `module_map` may be None; normalise once so that every use below
    # (builder attribute and gscope remapping) sees a real mapping.
    module_map = {} if module_map is None else module_map
    self.context = context
    self.is_gluon = is_gluon
    if is_gluon:
        from triton.experimental.gluon.language._semantic import GluonSemantic
        self.builder = gluon_ir.GluonOpBuilder(context)
        self.semantic = GluonSemantic(self.builder)
    else:
        from triton.language.semantic import TritonSemantic
        compile_mode = "simt" if getattr(options, "force_simt_only", False) else "simd"
        self.builder = ir.builder(context, compile_mode=compile_mode)
        self.semantic = TritonSemantic(self.builder)
    self.name_loc_as_prefix = None
    self.file_name = file_name
    self.begin_line = begin_line - 1
    self.builder.set_loc(file_name, begin_line, 0)
    self.builder.options = options
    # Extension builders share the same context and initial location and are
    # attached to the main builder by the setup helpers.
    self.ascend_builder = ascend_ir.ascendnpu_ir_builder(context, getattr(options, "arch", ""))
    self.ascend_builder.set_loc(file_name, begin_line, 0)
    setup_unified_builder(self.builder, self.ascend_builder)
    self.buffer_builder = buffer_ir.buffer_builder(context)
    self.buffer_builder.set_loc(file_name, begin_line, 0)
    setup_unified_builder_with_buffer_builder(self.builder, self.buffer_builder)
    self.builder.codegen_fns = codegen_fns
    self.builder.module_map = module_map
    self.module = self.builder.create_module() if module is None else module
    self.function_ret_types = {} if function_types is None else function_types
    self.prototype = prototype
    # Remap globals through `module_map`: modules are replaced wholesale,
    # other objects are re-fetched by name from the replacement module.
    self.gscope = {}
    for k, v in gscope.items():
        if isinstance(v, ModuleType):
            self.gscope[k] = module_map.get(v.__name__, v)
            continue
        module_name = getattr(v, "__module__", "")
        if module_name in module_map:
            self.gscope[k] = getattr(module_map[module_name], v.__name__)
        else:
            self.gscope[k] = v
    self.lscope = {}
    self.jit_fn = jit_fn
    if is_kernel:
        function_name = function_name[function_name.rfind('.') + 1:]
        function_name = check_identifier_legality(function_name, "function")
    self.function_name = function_name
    self.is_kernel = is_kernel
    self.cur_node = None
    self.noinline = noinline
    self.caller_context = caller_context
    self.scf_stack = []
    self.ret_type = None
    self.local_defs: Dict[str, tensor] = {}
    self.dereference_name: Callable[[str], Any] = self._define_name_lookup()
    self.fn = None
    self.visiting_arg_default_value = False
# Names resolvable inside a kernel without being defined locally or globally:
# a few Python builtins under their own name, plus Triton replacements for
# `print`, `min` and `max`.
builtin_namespace: Dict[str, Any] = {
    builtin.__name__: builtin
    for builtin in (len, list, range, float, int, isinstance, getattr, hasattr)
}
builtin_namespace['print'] = language.core.device_print
builtin_namespace['min'] = language.minimum
builtin_namespace['max'] = language.maximum
def _unsupported(self, node, message):
    """Build (not raise) an ``UnsupportedLanguageConstruct`` for ``node``.

    The error carries the source of the JIT function being compiled; callers
    are expected to ``raise`` the returned exception.
    """
    source = self.jit_fn.src
    return UnsupportedLanguageConstruct(source, node, message)
def _is_constexpr_global(self, name):
absent_marker = object()
val = self.gscope.get(name, absent_marker)
if val is absent_marker:
return False
if _is_constexpr(val):
return True
return False
def _define_name_lookup(self):
    """Create the function used to resolve a name inside the kernel.

    The returned callable searches, in order, the local scope, the global
    scope and ``builtin_namespace``, and raises ``NameError`` when the name
    is found nowhere. A global is only handed out when it is of an allowed
    kind (see the list below); any other global raises ``NameError``.
    """
    missing = object()

    def local_lookup(name: str, absent):
        return self.lscope.get(name, absent)

    def global_lookup(name: str, absent):
        val = self.gscope.get(name, absent)
        owner_module = getattr(val, "__module__", "")
        # Built as a list on purpose: every entry is evaluated, exactly one
        # truthy entry is enough to allow the access.
        allowed_because = [
            val is absent,
            name in self.builtin_namespace,
            type(val) is ModuleType,
            isinstance(val, JITCallable),
            getattr(val, "__triton_builtin__", False),
            getattr(val, "__triton_aggregate__", False),
            owner_module.startswith(("triton.language", "triton.experimental.gluon.language")),
            isinstance(val, language.dtype),
            _is_namedtuple(val),
            self._is_constexpr_global(name),
            self.visiting_arg_default_value,
            knobs.compilation.allow_non_constexpr_globals,
        ]
        if any(allowed_because):
            return val
        raise NameError(
            textwrap.dedent(f"""\
                Cannot access global variable {name} from within @jit'ed
                function. Triton kernels can only access global variables that
                are instanstiated as constexpr (`x = triton.language.constexpr(42)`). Note that this is different from
                annotating a variable as constexpr (`x: triton.language.constexpr = 42`), which is not supported. Alternatively, set the
                envvar TRITON_ALLOW_NON_CONSTEXPR_GLOBALS=1, but we do not
                promise to support this forever.""").replace("\n", " "))

    def name_lookup(name: str) -> Any:
        for lookup in (local_lookup, global_lookup, self.builtin_namespace.get):
            found = lookup(name, missing)
            if found is not missing:
                return found
        raise NameError(f'{name} is not defined')

    return name_lookup
@contextlib.contextmanager
def _name_loc_prefix(self, prefix):
self.name_loc_as_prefix = prefix
yield
self.name_loc_as_prefix = None
def _maybe_set_loc_to_name(self, val, name):
    """Wrap the location of ``val`` in a name location carrying ``name``.

    Works on a raw IR value / block argument, or on every IR handle of a
    Triton value. Anything else is left untouched.
    """
    if isinstance(val, (ir.value, ir.block_argument)):
        handles = [val]
    elif _is_triton_value(val):
        handles = []
        val._flatten_ir(handles)
    else:
        return
    for handle in handles:
        named_loc = self.builder.create_name_loc(name, handle.get_loc())
        handle.set_loc(named_loc)
def set_value(self, name: str, value: Union[base_value, constexpr]) -> None:
    """Bind ``name`` to ``value`` in the current scope.

    The binding is written to ``self.lscope`` (name resolution) and to
    ``self.local_defs`` (names defined in the region being visited).
    FIXME: control flow is not taken into account when recording definitions.
    """
    for scope in (self.lscope, self.local_defs):
        scope[name] = value
def _get_insertion_point_and_loc(self, builder=None):
_builder = self.builder if not builder else builder
loc = _builder.get_loc()
ip = _builder.get_insertion_point()
return ip, loc
def _set_insertion_point_and_loc(self, ip, loc, builder=None):
_builder = self.builder if not builder else builder
_builder.restore_insertion_point(ip)
_builder.set_loc(loc)
def _find_carries(self, node, liveins):
    """Determine which live-in variables are re-assigned by a loop body.

    The body of ``node`` is visited once into a scratch block that is erased
    afterwards. A Triton live-in whose IR handles differ after that visit is
    loop-carried. The scopes are reset to the live-ins before returning.

    Returns:
        ``(names, init_handles, init_tys)``: carried variable names, the
        flattened IR handles of their initial values, and the frontend types
        of those initial values.
    """
    scratch = self.builder.create_block()
    self.builder.set_insertion_point_to_start(scratch)
    self.scf_stack.append(node)
    self.visit_compound_statement(node.body)
    self.scf_stack.pop()
    scratch.erase()

    names, init_handles, init_tys = [], [], []
    for name, live_val in liveins.items():
        if not _is_triton_value(live_val):
            assert name not in self.local_defs, f'Loop carried variable {name} is not a triton value'
            continue
        loop_val = self.lscope[name]
        self._verify_loop_carried_variable(name, loop_val, live_val)
        live_handles = flatten_values_to_ir([live_val])
        loop_handles = flatten_values_to_ir([loop_val])
        if live_handles == loop_handles:
            continue
        names.append(name)
        init_tys.append(live_val.type)
        init_handles.extend(live_handles)

    # Undo the effect of the trial visit on the scopes.
    self.lscope = liveins.copy()
    self.local_defs = {}
    return names, init_handles, init_tys
def visit_compound_statement(self, stmts):
    """Visit a statement or a list of statements in order.

    Visiting stops right after the first ``return`` statement; anything that
    follows it in the same list is skipped.
    """
    statements = stmts if _is_list_like(stmts) else [stmts]
    for statement in statements:
        self.visit(statement)
        if isinstance(statement, ast.Return):
            return
def visit_Module(self, node):
    """Visit every child of the module through the stock ``NodeVisitor`` traversal."""
    visit_children = ast.NodeVisitor.generic_visit
    visit_children(self, node)
def visit_List(self, node):
    """Lower a list literal to a ``language.tuple`` of its visited elements."""
    ctx_result = self.visit(node.ctx)
    assert ctx_result is None
    visited = [self.visit(element) for element in node.elts]
    return language.tuple(visited)
def visit_ListComp(self, node: ast.ListComp):
    """Unroll a list comprehension over a Triton tuple at compile time.

    Only the form ``[expr for name in tuple_value]`` is supported: a single
    generator, a plain name as target, no ``if`` clauses. The element
    expression is visited once per item with ``name`` bound to that item.

    Returns:
        A Triton tuple holding the visited element expressions.

    Raises:
        ValueError: if there is more than one ``for`` clause.
        NotImplementedError: for ``if`` clauses, non-name targets, or when
            the iterable is not a Triton tuple.
    """
    if len(node.generators) != 1:
        raise ValueError("nested comprehensions are not supported")
    comp = node.generators[0]
    # Filters used to be ignored silently, producing unfiltered results.
    if comp.ifs:
        raise NotImplementedError("`if` clauses in comprehensions are not supported")
    # Unpacking targets such as `for a, b in ...` have no `.id`.
    if not isinstance(comp.target, ast.Name):
        raise NotImplementedError("only a single name is supported as comprehension target")
    iterable = self.visit(comp.iter)
    if not isinstance(iterable, tl_tuple):
        raise NotImplementedError("only tuple comprehensions are supported")
    results = []
    for item in iterable:
        self.set_value(comp.target.id, item)
        results.append(self.visit(node.elt))
    return tl_tuple(results)
def visit_Return(self, node):
    """Emit a return op and record the function's return type.

    Constexpr / int / float results (also inside tuples) are first turned
    into tensors. All returns of a function must agree on the type. After
    the return op a fresh block becomes the insertion point.

    Raises:
        TypeError: if this return's type differs from an earlier one.
    """

    def decay(value):
        if isinstance(value, language.tuple):
            return _apply_to_tuple_values(value, decay)
        if isinstance(value, (language.constexpr, int, float)):
            return self.semantic.to_tensor(value)
        return value

    ret_value = decay(self.visit(node.value))
    handles = []
    if ret_value is not None:
        assert isinstance(ret_value, language.core.base_value)
        ret_value._flatten_ir(handles)
        ret_ty = ret_value.type
    else:
        ret_ty = language.void
    self.builder.ret(handles)

    if self.ret_type is None:
        self.ret_type = ret_ty
    elif self.ret_type != ret_ty:
        raise TypeError(f'Inconsistent return types: {self.ret_type} and {ret_ty}')

    # Code following the return is emitted into a separate block.
    post_ret_block = self.builder.create_block()
    self.builder.set_insertion_point_to_end(post_ret_block)
def visit_Starred(self, node) -> Any:
    """Expand ``*expr``: the operand must be a Triton tuple; return its ``values``."""
    unpacked = self.visit(node.value)
    assert isinstance(unpacked, language.core.tuple)
    return unpacked.values
def visit_FunctionDef(self, node):
    """Generate the IR function for the (single, top-level) function definition.

    Steps: evaluate argument defaults, create the IR function from the
    prototype, bind arguments to their names, visit the body, then emit the
    final terminator and finalize the function.

    Raises:
        UnsupportedLanguageConstruct: on a nested function definition.
    """
    arg_names, kwarg_names = self.visit(node.args)
    # `self.fn` is only set once a function has been created, so a second
    # FunctionDef means a nested definition.
    if self.fn:
        raise self._unsupported(node, "nested function definition is not supported.")
    # Defaults align with the *last* parameters, hence the reversed walk.
    for i, default_value in enumerate(node.args.defaults[::-1]):
        arg_node = node.args.args[-i - 1]
        annotation = arg_node.annotation
        name = arg_node.arg
        st_target = ast.Name(id=name, ctx=ast.Store())
        # Each default is lowered as a synthetic (annotated) assignment.
        if annotation is None:
            init_node = ast.Assign(targets=[st_target], value=default_value)
        else:
            init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation)
        try:
            assert not self.visiting_arg_default_value
            # While set, the global-name lookup accepts any global.
            self.visiting_arg_default_value = True
            self.visit(init_node)
        finally:
            self.visiting_arg_default_value = False
    # Kernels are public symbols, every other function is private.
    visibility = "public" if self.is_kernel else "private"
    fn_ty = self.prototype.serialize(self.builder)
    self.fn = self.builder.get_or_insert_function(self.module, self.function_name, fn_ty, visibility, self.noinline)
    self.module.push_back(self.fn)
    entry = self.fn.add_entry_block()
    arg_values = self.prototype.deserialize(self.fn)
    if self.caller_context is not None:
        self.caller_context.initialize_callee(self.fn, self.builder)
    # Bind every parameter name to its value.
    for arg_name, arg_value in zip(arg_names, arg_values):
        self._maybe_set_loc_to_name(arg_value, arg_name)
        self.set_value(arg_name, arg_value)
    # Remember where we were emitting so it can be restored at the end.
    insert_pt = self.builder.get_insertion_block()
    self.builder.set_insertion_point_to_start(entry)
    self.visit_compound_statement(node.body)
    # `visit_Return` always moves to a fresh block, so the current block must
    # still be open here.
    assert not self.builder.get_insertion_block().has_terminator()
    if self.ret_type is None or self.ret_type == language.void:
        self.ret_type = language.void
        self.builder.ret([])
    else:
        # The return type is only known now: update the prototype and the
        # function type accordingly.
        if isinstance(self.ret_type, language.tuple_type):
            self.prototype.ret_types = self.ret_type.types
        else:
            self.prototype.ret_types = [self.ret_type]
        self.fn.reset_type(self.prototype.serialize(self.builder))
        # Terminate the trailing block with poison values of the return types.
        self.builder.ret([self.builder.create_poison(ty) for ty in self.prototype.return_types_ir(self.builder)])
    self.fn.finalize()
    if insert_pt:
        self.builder.set_insertion_point_to_end(insert_pt)
def visit_arguments(self, node):
    """Return ``(positional argument names, visited **kwargs node)``."""
    arg_names = [self.visit(arg) for arg in node.args]
    kwarg_names = self.visit(node.kwarg)
    return arg_names, kwarg_names
def visit_arg(self, node):
    """Visit the children of a parameter node and return the parameter name."""
    name = node.arg
    ast.NodeVisitor.generic_visit(self, node)
    return name
def visit_AnnAssign(self, node):
    """Handle ``target: annotation = value``.

    A ``constexpr`` annotation defines a compile-time constant in the local
    scope (re-definition is an error) and returns it. Every other annotation
    is treated like a plain assignment and delegated to ``visit_Assign``.

    Raises:
        ValueError: if a constexpr target is already defined.
    """
    annotation = self.visit(node.annotation)
    if annotation == constexpr:
        target = self.visit(node.target)
        value = self.visit(node.value)
        if target in self.lscope:
            raise ValueError(f'{target} is already defined.'
                             f' constexpr cannot be reassigned.')
        value = constexpr(value)
        self.lscope[target] = value
        return self.lscope[target]
    # The right-hand side is visited by `visit_Assign`; visiting it here as
    # well would emit its IR a second time.
    return self.visit_Assign(node)
def assignTarget(self, target, value):
    """Store ``value`` into an assignment target.

    Subscript targets go through ``visit_Subscript_Store``; tuple targets are
    assigned element-wise from ``value.values``; names are bound with
    ``set_value``.

    Raises:
        NotImplementedError: for attribute targets.
    """
    assert isinstance(target.ctx, ast.Store)
    if isinstance(target, ast.Subscript):
        return self.visit_Subscript_Store(target, value)
    if isinstance(target, ast.Tuple):
        for index, element in enumerate(target.elts):
            self.assignTarget(element, value.values[index])
        return
    if isinstance(target, ast.Attribute):
        raise NotImplementedError("Attribute assignment is not supported in triton")
    assert isinstance(target, ast.Name)
    bound_name = self.visit(target)
    self.set_value(bound_name, value)
def visit_Assign(self, node):
    """Evaluate the right-hand side and store it into the single target.

    Plain Python values are converted to tensors; ``None``, Triton values,
    dtypes and tuples (converted element-wise) are kept. Also used for
    ``AnnAssign`` nodes, whose single ``target`` is handled the same way.
    """

    def _sanitize_value(value):
        if isinstance(value, language.tuple):
            return _apply_to_tuple_values(value, _sanitize_value)
        native_nontensor_types = (language.dtype, language.tuple)
        value = _unwrap_if_constexpr(value)
        keep_as_is = (value is None or _is_triton_value(value) or isinstance(value, native_nontensor_types))
        return value if keep_as_is else self.semantic.to_tensor(value)

    targets = [node.target] if isinstance(node, ast.AnnAssign) else node.targets
    assert len(targets) == 1
    target = targets[0]
    # When assigning to a plain name, expose it as the location prefix while
    # the right-hand side is being generated.
    if isinstance(target, ast.Name):
        naming = self._name_loc_prefix(target.id)
    else:
        naming = contextlib.nullcontext()
    with naming:
        values = _sanitize_value(self.visit(node.value))
    self.assignTarget(target, values)
def visit_AugAssign(self, node):
    """Lower ``target op= value`` to ``target = target op value``.

    A copy of the target in load context serves as the left operand. The
    synthetic assignment is visited and the updated target value returned.
    """
    load_target = copy.deepcopy(node.target)
    load_target.ctx = ast.Load()
    expanded = ast.Assign(
        targets=[node.target],
        value=ast.BinOp(load_target, node.op, node.value),
    )
    self.visit(expanded)
    return self.visit(load_target)
def visit_Name(self, node):
    """Return the identifier for a store target, otherwise the value it names."""
    being_stored = type(node.ctx) is ast.Store
    return node.id if being_stored else self.dereference_name(node.id)
def visit_Store(self, node):
    """Visit a ``Store`` context node via the stock traversal; yields ``None``."""
    return ast.NodeVisitor.generic_visit(self, node)
def visit_Load(self, node):
    """Visit a ``Load`` context node via the stock traversal; yields ``None``."""
    return ast.NodeVisitor.generic_visit(self, node)
def visit_Tuple(self, node):
    """Lower a tuple expression to a ``language.tuple`` of its visited elements."""
    return language.tuple([self.visit(element) for element in node.elts])
def _apply_binary_method(self, method_name, lhs, rhs):
    """Apply the binary dunder ``method_name`` to ``lhs`` and ``rhs``.

    A tensor on the left handles the operation itself. Otherwise a tensor on
    the right handles it through the reflected method (``__add__`` becomes
    ``__radd__``). With no tensor involved the method is called on ``lhs``,
    which is first wrapped in ``constexpr`` when only ``rhs`` is one.
    """
    if _is_triton_tensor(lhs):
        forward = getattr(lhs, method_name)
        return forward(rhs, _semantic=self.semantic)
    if _is_triton_tensor(rhs):
        reverse_method_name = re.sub(r"__(.*)__", r"__r\1__", method_name)
        reflected = getattr(rhs, reverse_method_name)
        return reflected(lhs, _semantic=self.semantic)
    if isinstance(rhs, constexpr) and not isinstance(lhs, (constexpr, language.tuple)):
        lhs = constexpr(lhs)
    return getattr(lhs, method_name)(rhs)
def visit_BinOp(self, node):
    """Evaluate a binary operation through the matching dunder method.

    Raises:
        UnsupportedLanguageConstruct: if the operator has no entry in
            ``_method_name_for_bin_op`` (e.g. ``@``).
    """
    lhs = self.visit(node.left)
    rhs = self.visit(node.right)
    op_type = type(node.op)
    method_name = self._method_name_for_bin_op.get(op_type)
    if method_name is None:
        # `node.op` is an operator *instance*; the name lives on its class.
        raise self._unsupported(node,
                                "AST binary operator '{}' is not (currently) implemented.".format(op_type.__name__))
    return self._apply_binary_method(method_name, lhs, rhs)
# Maps each supported binary-operator AST class to the dunder method that
# `visit_BinOp` invokes for it. Operators missing here (e.g. `ast.MatMult`)
# are reported as unsupported.
_method_name_for_bin_op: Dict[Type[ast.operator], str] = {
    ast.Add: '__add__',
    ast.Sub: '__sub__',
    ast.Mult: '__mul__',
    ast.Div: '__truediv__',
    ast.FloorDiv: '__floordiv__',
    ast.Mod: '__mod__',
    ast.Pow: '__pow__',
    ast.LShift: '__lshift__',
    ast.RShift: '__rshift__',
    ast.BitAnd: '__and__',
    ast.BitOr: '__or__',
    ast.BitXor: '__xor__',
}
def visit_then_else_blocks(self, node, liveins, then_block, else_block):
    """Generate both branches of an ``if`` and work out which names they yield.

    Args:
        node: The ``if`` AST node (``body`` / ``orelse`` are visited).
        liveins: Scope snapshot taken before the ``if``.
        then_block: Block to emit the then-branch into.
        else_block: Block to emit the else-branch into; only used when the
            node has an ``orelse``.

    Returns:
        ``(then_defs, else_defs, then_block, else_block, names)`` where the
        blocks are the insertion blocks at the end of each branch and
        ``names`` lists the variables whose value must be merged afterwards.
    """
    # Then-branch.
    self.builder.set_insertion_point_to_start(then_block)
    self.visit_compound_statement(node.body)
    # Visiting may have moved to another block; keep the one we ended in.
    then_block = self.builder.get_insertion_block()
    then_defs = self.local_defs.copy()
    then_vals = self.lscope.copy()
    # Without an else-branch the "else" values are simply the live-ins.
    else_defs = {}
    else_vals = liveins.copy()
    if node.orelse:
        self.builder.set_insertion_point_to_start(else_block)
        # The else-branch starts from the state before the `if`.
        self.lscope = liveins.copy()
        self.local_defs = {}
        self.visit_compound_statement(node.orelse)
        else_defs = self.local_defs.copy()
        else_block = self.builder.get_insertion_block()
        else_vals = self.lscope.copy()
    names = []
    # Live-in Triton values whose handles differ between the branches.
    for name, value in liveins.items():
        if not _is_triton_value(value):
            continue
        then_handles = flatten_values_to_ir([then_vals[name]])
        else_handles = flatten_values_to_ir([else_vals[name]])
        if then_handles == else_handles:
            continue
        names.append(name)
        then_defs[name] = then_vals[name]
        else_defs[name] = else_vals[name]
        # A redefinition must keep the Python class and Triton type of the
        # initial value.
        for defs, block_name in [(then_defs, 'then'), (else_defs, 'else')]:
            type_equal = type(defs[name]) == type(value)
            assert type_equal and defs[name].type == value.type, \
                f'initial value for `{name}` is of type {value}, '\
                f'but the {block_name} block redefines it as {defs[name]}'
    # Names newly defined in *both* branches; sorted for a stable order.
    for name in sorted(then_defs.keys() & else_defs.keys()):
        if name in names:
            continue
        then_val = then_defs[name]
        then_ty = then_val.type
        else_val = else_defs[name]
        else_ty = else_val.type
        type_equal = type(then_val) == type(else_val)
        assert type_equal and then_ty == else_ty, \
            f'Mismatched type for {name} between then block ({then_ty}) '\
            f'and else block ({else_ty})'
        names.append(name)
    return then_defs, else_defs, then_block, else_block, names
def visit_if_top_level(self, cond, node):
    """Lower an ``if`` with explicit blocks and branches.

    ``visit_If`` selects this lowering when the statement contains a
    ``return``. Both branches jump to a common block whose block arguments
    carry the merged variables.
    """
    with enter_sub_region(self) as sr:
        liveins, ip_block = sr
        then_block = self.builder.create_block()
        else_block = self.builder.create_block()
        # Conditional branch at the end of the block we were emitting into.
        self.builder.set_insertion_point_to_end(ip_block)
        self.builder.create_cond_branch(cond.handle, then_block, else_block)
        then_defs, else_defs, then_block, else_block, names = \
            self.visit_then_else_blocks(node, liveins, then_block, else_block)
        # Block that follows the conditional.
        endif_block = self.builder.create_block()
        # Terminate the then-branch.
        self.builder.set_insertion_point_to_end(then_block)
        assert not then_block.has_terminator(), f"{then_block}"
        then_handles = flatten_values_to_ir(then_defs[name] for name in names)
        self.builder.create_branch(endif_block, then_handles)
        # Terminate the else-branch.
        self.builder.set_insertion_point_to_end(else_block)
        assert not else_block.has_terminator(), f"{else_block}"
        else_handles = flatten_values_to_ir(else_defs[name] for name in names)
        self.builder.create_branch(endif_block, else_handles)
        assert len(then_handles) == len(else_handles)
        # One block argument per merged handle.
        for then_h, else_h in zip(then_handles, else_handles):
            ty = then_h.get_type()
            assert ty == else_h.get_type()
            endif_block.add_argument(ty)
    # Continue after the conditional; this happens after the sub-region has
    # restored the previous insertion point and scopes.
    self.builder.set_insertion_point_to_start(endif_block)
    res_handles = [endif_block.arg(i) for i in range(len(then_handles))]
    types = [then_defs[name].type for name in names]
    new_values = unflatten_ir_values(res_handles, types)
    # Rebind the merged names to the block arguments.
    for name, new_value in zip(names, new_values):
        self.set_value(name, new_value)
def visit_if_scf(self, cond, node):
    """Lower an ``if`` to a single structured if-op.

    ``visit_If`` selects this lowering when the statement contains no
    ``return``. Both branches are generated into detached blocks first, then
    merged into the regions of the if-op, which yields the merged variables
    as its results.
    """
    with enter_sub_region(self) as sr:
        liveins, _ = sr
        ip, last_loc = self._get_insertion_point_and_loc()
        then_block = self.builder.create_block()
        else_block = self.builder.create_block() if node.orelse else None
        then_defs, else_defs, then_block, else_block, names = \
            self.visit_then_else_blocks(node, liveins, then_block, else_block)
        then_handles = flatten_values_to_ir(then_defs[name] for name in names)
        for name, val in zip(names, then_handles):
            self._maybe_set_loc_to_name(val, name)
        # Create the if-op where we were emitting before the branches.
        self._set_insertion_point_and_loc(ip, last_loc)
        if_op = self.builder.create_if_op([h.get_type() for h in then_handles], cond.handle, True)
        then_block.merge_block_before(if_op.get_then_block())
        self.builder.set_insertion_point_to_end(if_op.get_then_block())
        if len(names) > 0:
            self.builder.create_yield_op(then_handles)
        # Without an `orelse` there is no detached else block to merge.
        if not node.orelse:
            else_block = if_op.get_else_block()
        else:
            else_block.merge_block_before(if_op.get_else_block())
        self.builder.set_insertion_point_to_end(if_op.get_else_block())
        if len(names) > 0:
            else_handles = flatten_values_to_ir(else_defs[name] for name in names)
            for name, val in zip(names, else_handles):
                self._maybe_set_loc_to_name(val, name)
            self.builder.create_yield_op(else_handles)
    # After the sub-region restored the scopes: rebind the merged names to
    # the results of the if-op.
    res_handles = [if_op.get_result(i) for i in range(len(then_handles))]
    types = [then_defs[name].type for name in names]
    new_values = unflatten_ir_values(res_handles, types)
    for name, new_value in zip(names, new_values):
        self.set_value(name, new_value)
def visit_If(self, node):
    """Lower an ``if`` statement.

    A compile-time condition (bool / int / None after constexpr unwrapping)
    selects one branch, and only that branch is visited. A tensor condition
    is converted to ``int1`` and lowered with blocks when the statement
    contains a ``return``, otherwise with a structured if-op.

    Raises:
        UnsupportedLanguageConstruct: for a tensor condition with more than
            one element, for a ``return`` nested inside a loop, or for a
            compile-time condition of an unsupported type.
    """
    cond = self.visit(node.test)

    if not _is_triton_tensor(cond):
        static_cond = _unwrap_if_constexpr(cond)
        if type(static_cond) not in _condition_types:
            raise self._unsupported(
                node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format(
                    ', '.join(_.__name__ for _ in _condition_types),
                    type(static_cond).__name__))
        self.visit_compound_statement(node.body if static_cond else node.orelse)
        return

    if _is_non_scalar_tensor(cond):
        raise self._unsupported(node, "Boolean value of Tensor with more than one value is ambiguous")
    if cond.type.is_block():
        # Single-element block: accepted with a warning and reduced to a scalar.
        warnings.warn(
            "If conditional called with multidimensional Tensor instead of scalar; please use \"if (%s).item()\" instead"
            % ast.unparse(node.test))
        cond = language.core._unsplat(cond, _semantic=self.semantic, _generator=self)
    cond = cond.to(language.int1, _semantic=self.semantic)

    if not ContainsReturnChecker(self.gscope).visit(node):
        self.visit_if_scf(cond, node)
        return
    if self.scf_stack:
        raise self._unsupported(
            node, "Cannot have `return` statements inside `while` or `for` statements in triton.")
    self.visit_if_top_level(cond, node)
def visit_IfExp(self, node):
    """Lower a conditional expression ``a if cond else b``.

    A tensor condition produces a structured if-op whose single result is the
    selected value (both operands are converted to tensors and must have the
    same type). A compile-time condition visits only the selected operand.

    Raises:
        UnsupportedLanguageConstruct: for a compile-time condition of an
            unsupported type.
    """
    cond = self.visit(node.test)
    if _is_triton_tensor(cond):
        cond = cond.to(language.int1, _semantic=self.semantic)
        with enter_sub_region(self):
            ip, last_loc = self._get_insertion_point_and_loc()
            # Generate each operand into its own detached block.
            then_block = self.builder.create_block()
            self.builder.set_insertion_point_to_start(then_block)
            then_val = self.semantic.to_tensor(self.visit(node.body))
            then_block = self.builder.get_insertion_block()
            else_block = self.builder.create_block()
            self.builder.set_insertion_point_to_start(else_block)
            else_val = self.semantic.to_tensor(self.visit(node.orelse))
            else_block = self.builder.get_insertion_block()
            # Back to where the expression appears, to create the if-op.
            self._set_insertion_point_and_loc(ip, last_loc)
            assert then_val.type == else_val.type, \
                f'Ternary expression with dynamic condition has inconsistent types {then_val.type} and {else_val.type}'
            ret_type = then_val.type
            # A void result means the if-op has no results and yields nothing.
            ret_type_ir = [ret_type.to_ir(self.builder)] if ret_type != language.void else []
            if_op = self.builder.create_if_op(ret_type_ir, cond.handle, True)
            then_block.merge_block_before(if_op.get_then_block())
            if ret_type_ir:
                self.builder.set_insertion_point_to_end(if_op.get_then_block())
                self.builder.create_yield_op([then_val.handle])
            self.builder.set_insertion_point_to_end(if_op.get_then_block())
            else_block.merge_block_before(if_op.get_else_block())
            if ret_type_ir:
                self.builder.set_insertion_point_to_end(if_op.get_else_block())
                self.builder.create_yield_op([else_val.handle])
            return language.core.tensor(if_op.get_result(0), ret_type) if ret_type_ir else None
    else:
        cond = _unwrap_if_constexpr(cond)
        if type(cond) not in _condition_types:
            raise self._unsupported(
                node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format(
                    ', '.join(_.__name__ for _ in _condition_types),
                    type(cond).__name__))
        # Only the selected operand is visited.
        if cond:
            return self.visit(node.body)
        else:
            return self.visit(node.orelse)
def visit_With(self, node):
    """
    Handle 'with' statements with dispatch pattern for Ascend extensions,
    falling back to the context manager protocol for general cases.
    This implementation:
    1. First tries the dispatch mechanism: for a single item whose context
       expression is a call, the visited callee is looked up in
       ``WITH_DISPATCH`` and a registered handler takes over entirely.
    2. Otherwise every item must be a call; each call is evaluated with
       ``_semantic`` injected to obtain a context manager, all managers are
       entered in order, the body is visited, and the managers are exited
       in reverse order.

    Raises:
        UnsupportedLanguageConstruct: if an item is not a call, or if the
            statement contains a ``return``.
    """
    if len(node.items) == 1:
        context = node.items[0].context_expr
        if isinstance(context, ast.Call):
            withitemClass = self.visit(context.func)
            handler = WITH_DISPATCH.get(withitemClass)
            if handler:
                return handler(self, node)
    # The generic path reads `.func`, `.args` and `.keywords` of every item,
    # so reject anything that is not a call with a proper diagnostic.
    for item in node.items:
        if not isinstance(item.context_expr, ast.Call):
            raise self._unsupported(node, "`with` items must be calls that create a context manager")
    # Checked before any context manager is created or entered, so that a
    # rejected statement leaves no manager entered-but-never-exited.
    if ContainsReturnChecker(self.gscope).visit(node):
        raise self._unsupported(node, "Cannot have `return` statements inside `with` statements in triton ")
    cm_list = []
    for item in node.items:
        call = item.context_expr
        fn = self.visit(call.func)
        args = [self.visit(arg) for arg in call.args]
        kws = dict(self.visit(kw) for kw in call.keywords)
        cm = fn(*args, _semantic=self.semantic, **kws)
        cm_list.append(cm)
    for cm, item in zip(cm_list, node.items):
        res = cm.__enter__()
        if item.optional_vars is not None:
            var_name = self.visit(item.optional_vars)
            self.set_value(var_name, res)
    self.visit_compound_statement(node.body)
    for cm in reversed(cm_list):
        cm.__exit__(None, None, None)
def visit_Pass(self, node):
    """``pass`` generates no code."""
    return None
def visit_Compare(self, node):
    """Evaluate a single comparison ``lhs <op> rhs``.

    ``is`` / ``is not`` are decided at compile time on the constexpr-unwrapped
    operands and return a ``constexpr``. All other supported operators are
    dispatched to the matching dunder method.

    Raises:
        UnsupportedLanguageConstruct: for chained comparisons (``a < b < c``)
            and for operators without an entry in
            ``_method_name_for_comp_op`` (e.g. ``in``).
    """
    if not (len(node.comparators) == 1 and len(node.ops) == 1):
        raise self._unsupported(node, "simultaneous multiple comparison is not supported")
    lhs = self.visit(node.left)
    rhs = self.visit(node.comparators[0])
    op_type = type(node.ops[0])
    if op_type is ast.Is:
        return constexpr(_unwrap_if_constexpr(lhs) is _unwrap_if_constexpr(rhs))
    if op_type is ast.IsNot:
        return constexpr(_unwrap_if_constexpr(lhs) is not _unwrap_if_constexpr(rhs))
    method_name = self._method_name_for_comp_op.get(op_type)
    if method_name is None:
        # `node.ops[0]` is an operator *instance*; the name lives on its class.
        raise self._unsupported(
            node, "AST comparison operator '{}' is not (currently) implemented.".format(op_type.__name__))
    return self._apply_binary_method(method_name, lhs, rhs)
# Maps each comparison-operator AST class to the dunder method that
# `visit_Compare` invokes for it. `is` / `is not` are handled separately in
# `visit_Compare`; other operators missing here are reported as unsupported.
_method_name_for_comp_op: Dict[Type[ast.cmpop], str] = {
    ast.Eq: '__eq__', ast.NotEq: '__ne__', ast.Lt: '__lt__', ast.LtE: '__le__', ast.Gt: '__gt__', ast.GtE: '__ge__'
}
def visit_UnaryOp(self, node):
    """Evaluate a unary operation through the matching dunder method.

    Tensors receive the call with ``_semantic`` injected. Other operands are
    called directly; if that raises ``AttributeError``, ``not`` falls back to
    Python's ``not`` wrapped in ``constexpr`` and every other operator is
    reported as unsupported for the operand's type.

    Raises:
        UnsupportedLanguageConstruct: if the operator has no entry in
            ``_method_name_for_unary_op`` or is not available on the operand.
    """
    operand = self.visit(node.operand)
    fn = self._method_name_for_unary_op.get(type(node.op))
    if fn is None:
        # `node.op` is an operator *instance*; the name lives on its class.
        raise self._unsupported(
            node, f"AST unary operator '{type(node.op).__name__}' is not (currently) implemented.")
    if _is_triton_tensor(operand):
        return getattr(operand, fn)(_semantic=self.semantic)
    try:
        return getattr(operand, fn)()
    except AttributeError:
        if fn == "__not__":
            return constexpr(not operand)
        raise self._unsupported(
            node, f"AST unary operator '{fn}' is not (currently) implemented on type {type(operand).__name__}")
# Maps each unary-operator AST class to the dunder method that
# `visit_UnaryOp` invokes for it. `__not__` is not a Python protocol method;
# `visit_UnaryOp` falls back to Python's `not` when an operand lacks it.
_method_name_for_unary_op: Dict[Type[ast.unaryop], str] = {
    ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Not: '__not__', ast.Invert: '__invert__'
}
def _verify_loop_carried_variable(self, name, loop_val, live_val):
    """Assert that a loop-carried variable keeps a consistent kind and type.

    ``loop_val`` is the value bound to ``name`` inside the loop and
    ``live_val`` the value it had on entry.  Both must be triton values, they
    must share the same Python class, and, when ``loop_val`` is a triton
    tensor, the same ``.type``.  All checks are ``assert`` statements.
    """
    for candidate in (loop_val, live_val):
        assert _is_triton_value(candidate), f'cannot reassign constexpr {name} in the loop'
    assert type(loop_val) is type(live_val), (
        f'Loop carried variable {name} changed type, was {type(loop_val)} but is now {type(live_val)}')
    assert (loop_val.type == live_val.type) or not _is_triton_tensor(loop_val), (
        f'Loop-carried variable {name} has initial type {live_val.type} '
        f'but is re-assigned to {loop_val.type} in loop! '
        f'Please make sure that the type stays consistent.')
def visit_While(self, node):
    """Lower a Python ``while`` loop to a while op built through ``self.builder``.

    The op has a "before" region that evaluates the loop condition and an
    "after" region that holds the loop body.  The names returned by
    ``_find_carries`` are threaded through both regions as block arguments
    and come back out as the op's results.  An ``else`` clause on the loop
    is not implemented.
    """
    with enter_sub_region(self) as sr:
        liveins, insert_block = sr
        # Remember where we are: `_find_carries` may move the insertion point
        # (NOTE(review): presumably it dry-runs the body - confirm).
        ip, last_loc = self._get_insertion_point_and_loc()
        names, init_handles, init_fe_tys = self._find_carries(node, liveins)
        init_tys = [h.get_type() for h in init_handles]
        self._set_insertion_point_and_loc(ip, last_loc)
        while_op = self.builder.create_while_op(init_tys, init_handles)

        # --- "before" region: evaluates the loop condition -----------------
        before_block = self.builder.create_block_with_parent(while_op.get_before(), init_tys)
        self.builder.set_insertion_point_to_start(before_block)
        block_args = [before_block.arg(i) for i in range(len(init_handles))]
        condition_args = unflatten_ir_values(block_args, init_fe_tys)
        # Rebind every carried name to this region's block argument so the
        # condition expression reads the per-iteration value.
        for name, val in zip(names, condition_args):
            self.lscope[name] = val
            self.local_defs[name] = val
            self._maybe_set_loc_to_name(val, name)
        cond = self.visit(node.test)
        # A `language.condition` wrapper carries an extra `disable_licm` flag;
        # record it on the op and continue with the wrapped condition value.
        if isinstance(cond, language.condition):
            if cond.disable_licm:
                while_op.set_attr("llvm.loop_annotation", self.builder.get_disable_loop_licm_attr())
            cond = cond.condition
        self.builder.set_insertion_point_to_end(before_block)
        # The condition op forwards the block arguments to the "after" region.
        self.builder.create_condition_op(cond.handle, block_args)

        # --- "after" region: the loop body ----------------------------------
        after_block = self.builder.create_block_with_parent(while_op.get_after(), init_tys)
        self.builder.set_insertion_point_to_start(after_block)
        body_handles = [after_block.arg(i) for i in range(len(init_handles))]
        body_args = unflatten_ir_values(body_handles, init_fe_tys)
        for name, val in zip(names, body_args):
            self.lscope[name] = val
            self.local_defs[name] = val
            self._maybe_set_loc_to_name(val, name)
        # Track the enclosing loop node while its body is being visited.
        self.scf_stack.append(node)
        self.visit_compound_statement(node.body)
        self.scf_stack.pop()
        # Yield whatever the carried names are bound to at the end of the body.
        yield_handles = flatten_values_to_ir(self.lscope[name] for name in names)
        self.builder.create_yield_op(yield_handles)

    # After the loop, the carried names refer to the while op's results.
    result_handles = [while_op.get_result(i) for i in range(len(init_handles))]
    result_vals = unflatten_ir_values(result_handles, init_fe_tys)
    for name, new_def in zip(names, result_vals):
        self.lscope[name] = new_def
        self.local_defs[name] = new_def
        self._maybe_set_loc_to_name(new_def, name)

    # `while ... else` is not implemented: any statement in `orelse` trips the
    # assertion (the line after it is unreachable while assertions are on).
    for stmt in node.orelse:
        assert False, "Not implemented"
        ast.NodeVisitor.generic_visit(self, stmt)
def visit_Subscript_Load(self, node):
    """Evaluate ``value[slice]`` in load context.

    Triton values are indexed by routing their ``__getitem__`` through
    ``call_Method``; anything else is indexed with ordinary Python
    subscripting.
    """
    assert isinstance(node.ctx, ast.Load)
    container = self.visit(node.value)
    index = self.visit(node.slice)
    if not _is_triton_value(container):
        return container[index]
    return self.call_Method(node, container.__getitem__, container, [index], {})
def visit_Subscript_Store(self, node, value):
    """Reject subscript assignment (``x[i] = value``); always raises ``NotImplementedError``."""
    message = "__setitem__ is not supported in triton"
    raise NotImplementedError(message)
def visit_Subscript(self, node):
    """Visit a subscript expression by delegating to ``visit_Subscript_Load``."""
    loaded = self.visit_Subscript_Load(node)
    return loaded
def visit_ExtSlice(self, node):
    """Visit every dimension of an extended slice and return the results as a list."""
    visited_dims = []
    for dim in node.dims:
        visited_dims.append(self.visit(dim))
    return visited_dims
def visit_For(self, node):
    """Lower a Python ``for`` loop.

    Supported iterators:

    * ``language.static_range`` - unrolled here: the body is visited once
      per value with the target bound to a ``constexpr``.
    * ``language.range`` / ``extension.parallel`` - lowered to a for op,
      with the iterator's options recorded as attributes on the op.
    * builtin ``range`` - lowered to a for op with default options.

    Any other iterator raises ``RuntimeError``; non-integer bounds or step
    raise ``TypeError``.  An ``else`` clause on a lowered loop is not
    supported.
    """
    IteratorClass = self.visit(node.iter.func)
    iter_args = [self.visit(arg) for arg in node.iter.args]
    iter_kwargs = dict(self.visit(keyword) for keyword in node.iter.keywords)
    if IteratorClass == language.static_range:
        # Compile-time unrolling: no loop op is created.
        iterator = IteratorClass(*iter_args, **iter_kwargs)
        static_range = range(iterator.start.value, iterator.end.value, iterator.step.value)
        for i in static_range:
            self.lscope[node.target.id] = constexpr(i)
            self.visit_compound_statement(node.body)
            for stmt in node.orelse:
                ast.NodeVisitor.generic_visit(self, stmt)
        return
    # Loop options; they keep these defaults for the builtin `range`.
    num_stages = None
    loop_unroll_factor = None
    disallow_acc_multi_buffer = False
    flatten = False
    warp_specialize = False
    disable_licm = False
    bind_sub_block = None
    if IteratorClass in [language.range, extension.parallel]:
        iterator = IteratorClass(*iter_args, **iter_kwargs)
        lb = iterator.start
        ub = iterator.end
        step = iterator.step
        num_stages = iterator.num_stages
        loop_unroll_factor = iterator.loop_unroll_factor
        disallow_acc_multi_buffer = iterator.disallow_acc_multi_buffer
        flatten = iterator.flatten
        warp_specialize = iterator.warp_specialize
        disable_licm = iterator.disable_licm
        if (IteratorClass is extension.parallel):
            bind_sub_block = iterator.bind_sub_block
    elif IteratorClass is range:
        # `ast.Constant` replaces the deprecated `ast.Num` alias, which warns
        # on Python 3.12+ and no longer exists on 3.14; both build the same
        # constant node, so the visited default is unchanged.
        lb = iter_args[0] if len(iter_args) > 1 else self.visit(ast.Constant(0))
        ub = iter_args[1] if len(iter_args) > 1 else self.visit(node.iter.args[0])
        step = iter_args[2] if len(iter_args) > 2 else self.visit(ast.Constant(1))
    else:
        raise RuntimeError('Only `range` and `static_range` iterators are currently supported')
    # A negative constexpr step is rewritten as a positive step over swapped
    # bounds; the real induction value is reconstructed inside the body below.
    negative_step = False
    if _is_constexpr(step) and step.value < 0:
        step = constexpr(-step.value)
        negative_step = True
        lb, ub = ub, lb
    lb = self.semantic.to_tensor(lb)
    ub = self.semantic.to_tensor(ub)
    step = self.semantic.to_tensor(step)
    if not lb.dtype.is_int() or not ub.dtype.is_int() or not step.dtype.is_int():
        raise TypeError(f"For loop bounds and step must all be ints, are ({lb.dtype}, {ub.dtype}, {step.dtype})")
    # The induction variable uses the integer type all three values promote to.
    iv_type = self.semantic.integer_promote_impl(lb.dtype, ub.dtype)
    iv_type = self.semantic.integer_promote_impl(iv_type, step.dtype)
    iv_ir_type = iv_type.to_ir(self.builder)
    iv_is_signed = iv_type.int_signedness == language.core.dtype.SIGNEDNESS.SIGNED
    # From here on lb/ub/step are IR handles cast to the induction type.
    lb = lb.handle
    ub = ub.handle
    step = step.handle
    lb = self.builder.create_int_cast(lb, iv_ir_type, iv_is_signed)
    ub = self.builder.create_int_cast(ub, iv_ir_type, iv_is_signed)
    step = self.builder.create_int_cast(step, iv_ir_type, iv_is_signed)
    # Placeholder induction value; its uses are replaced with the real
    # induction variable once the for op exists.
    iv = self.builder.create_poison(iv_ir_type)
    self.set_value(node.target.id, language.core.tensor(iv, iv_type))
    with enter_sub_region(self) as sr:
        liveins, insert_block = sr
        # Remember where we are: `_find_carries` may move the insertion point
        # (NOTE(review): presumably it dry-runs the body - confirm).
        ip, last_loc = self._get_insertion_point_and_loc()
        names, init_handles, init_tys = self._find_carries(node, liveins)
        self._set_insertion_point_and_loc(ip, last_loc)
        for_op = self.builder.create_for_op(lb, ub, step, init_handles)
        # Record the loop options as attributes on the op.
        if _unwrap_if_constexpr(num_stages) is not None:
            for_op.set_attr("tt.num_stages", self.builder.get_int32_attr(num_stages))
        if _unwrap_if_constexpr(loop_unroll_factor) is not None:
            for_op.set_attr("tt.loop_unroll_factor", self.builder.get_int32_attr(loop_unroll_factor))
        if disallow_acc_multi_buffer:
            for_op.set_attr("tt.disallow_acc_multi_buffer", self.builder.get_unit_attr())
        if flatten:
            for_op.set_attr("tt.flatten", self.builder.get_unit_attr())
        if warp_specialize:
            for_op.set_attr("tt.warp_specialize", self.builder.get_unit_attr())
        if disable_licm:
            for_op.set_attr("tt.disable_licm", self.builder.get_unit_attr())
        if (IteratorClass is extension.parallel):
            for_op.set_attr("hivm.parallel_loop", self.builder.get_unit_attr())
        self.scf_stack.append(node)
        for_op_body = for_op.get_body(0)
        self.builder.set_insertion_point_to_start(for_op_body)
        # Carried values start at block argument 1, hence `i + 1`
        # (NOTE(review): argument 0 is presumably the induction variable - confirm).
        block_handles = [for_op_body.arg(i + 1) for i in range(len(init_handles))]
        block_args = unflatten_ir_values(block_handles, init_tys)
        for name, val in zip(names, block_args):
            self._maybe_set_loc_to_name(val, name)
            self.set_value(name, val)
        self.visit_compound_statement(node.body)
        self.scf_stack.pop()
        yield_handles = flatten_values_to_ir(self.lscope[name] for name in names)
        if len(yield_handles) > 0:
            self.builder.create_yield_op(yield_handles)
        for_op_region = for_op_body.get_parent()
        assert for_op_region.size() == 1, "We use SCF, so the loop body should only have one block"
        # Materialize the real induction value at the top of the body and
        # substitute it for the placeholder created before the loop.
        self.builder.set_insertion_point_to_start(for_op_body)
        iv = for_op.get_induction_var()
        if negative_step:
            # Bounds were swapped above, so map the ascending counter back
            # onto the descending sequence the user wrote.
            iv = self.builder.create_sub(ub, iv)
            iv = self.builder.create_add(iv, lb)
        self.lscope[node.target.id].handle.replace_all_uses_with(iv)
        self.set_value(node.target.id, language.core.tensor(iv, iv_type))
        self._maybe_set_loc_to_name(iv, node.target.id)
    # After the loop, the carried names refer to the for op's results.
    result_handles = [for_op.get_result(i) for i in range(len(init_handles))]
    result_values = unflatten_ir_values(result_handles, init_tys)
    for name, val in zip(names, result_values):
        self.set_value(name, val)
        self._maybe_set_loc_to_name(val, name)
    # `for ... else` is not supported on lowered loops: any statement in
    # `orelse` trips the assertion (the line after it is unreachable while
    # assertions are on).
    for stmt in node.orelse:
        assert False, "Don't know what to do with else after for"
        ast.NodeVisitor.generic_visit(self, stmt)
def visit_Slice(self, node):
    """Visit ``lower``, ``upper`` and ``step`` (in that order) and wrap them in ``language.slice``."""
    bounds = [self.visit(part) for part in (node.lower, node.upper, node.step)]
    return language.slice(*bounds)
def visit_Index(self, node):
    """Visit the expression wrapped by an ``Index`` node and return its result."""
    wrapped = node.value
    return self.visit(wrapped)
def visit_keyword(self, node) -> Tuple[str, Any]:
    """Return ``(name, value)`` for a keyword argument, with the value expression visited."""
    keyword_name = node.arg
    keyword_value = self.visit(node.value)
    return (keyword_name, keyword_value)
def visit_Assert(self, node) -> Any:
    """Lower an ``assert`` statement to ``language.core.device_assert``.

    The condition is visited first; the message is visited only when one is
    present, otherwise an empty string is used.
    """
    test = self.visit(node.test)
    if node.msg is None:
        msg = ""
    else:
        msg = self.visit(node.msg)
    return language.core.device_assert(test, msg, _semantic=self.semantic)
def call_JitFunction(self, fn: JITFunction, args, kwargs, caller_context=None):
    """Emit a call to another JIT function, generating its body on first use.

    Arguments are bound to the callee's signature and split into constexpr
    and non-constexpr values.  These, the callee's full name and the caller
    context feed ``mangle_fn`` to produce the symbol name.  If the module
    has no function with that name yet, a nested ``CodeGenerator`` generates
    it and its return type is cached in ``self.function_ret_types``.

    Returns ``None`` when the callee's return type is ``language.void``,
    otherwise the call results rebuilt into a front-end value.
    """
    # Bind positional/keyword arguments to parameter names, then order them
    # as listed in `fn.arg_names`.
    args = inspect.getcallargs(fn.fn, *args, **kwargs)
    args = [args[name] for name in fn.arg_names]
    # Values of these kinds are wrapped as constexpr.
    for i, arg in enumerate(args):
        if isinstance(arg, (language.dtype, float, int, bool, JITFunction)):
            args[i] = language.core.constexpr(arg)
    # Split into constexpr entries (keyed by path) and runtime values
    # (NOTE(review): paths presumably index into nested arguments - confirm
    # against `find_paths_if`).
    args_cst = find_paths_if(args, lambda _, x: _is_constexpr(x))
    args_cst = {path: get_iterable_path(args, path) for path in args_cst}
    args_path = find_paths_if(args, lambda _, x: not _is_constexpr(x))
    args_val = [get_iterable_path(args, path) for path in args_path]
    # An explicit caller context wins over the generator's own.
    caller_context = caller_context or self.caller_context
    fn_name = mangle_fn(get_full_name(fn), [arg.type for arg in args_val], args_cst, caller_context)
    if not self.module.has_function(fn_name):
        # First use of this specialization: generate the callee into the
        # same module.
        file_name, begin_line = get_jit_fn_file_line(fn)
        arg_types = [
            language.core.constexpr if arg is None or isinstance(arg,
                                                                 (bool, int, language.core.dtype)) else arg.type
            for arg in args
        ]
        prototype = ASTFunction([], arg_types, args_cst, dict())
        generator = CodeGenerator(self.context, prototype, fn.get_capture_scope(), module=self.module, jit_fn=fn,
                                  function_name=fn_name, function_types=self.function_ret_types,
                                  noinline=fn.noinline, file_name=file_name, begin_line=begin_line,
                                  options=self.builder.options, codegen_fns=self.builder.codegen_fns,
                                  module_map=self.builder.module_map, caller_context=caller_context,
                                  is_gluon=self.is_gluon)
        try:
            generator.visit(fn.parse())
        except Exception as e:
            # In front-end debugging mode the original exception propagates
            # untouched; otherwise it is reported against the caller's
            # source and current node, chained to the original error.
            if knobs.compilation.front_end_debugging:
                raise
            raise CompilationError(self.jit_fn.src, self.cur_node, repr(e)) from e
        callee_ret_type = generator.ret_type
        self.function_ret_types[fn_name] = callee_ret_type
    else:
        # Already generated: reuse the cached return type.
        callee_ret_type = self.function_ret_types[fn_name]
    symbol = self.module.get_function(fn_name)
    args_val = flatten_values_to_ir(args_val)
    call_op = self.builder.call(symbol, args_val)
    if callee_ret_type == language.void:
        return None
    handles = [call_op.get_result(i) for i in range(call_op.get_num_results())]
    return next(unflatten_ir_values(handles, [callee_ret_type]))
def call_Function(self, node, fn, args, kws):
    """Dispatch a call to ``fn`` according to what kind of callable it is.

    * Bound JIT methods / bound constexpr functions are unbound first: the
      receiver is inserted as the first argument.
    * ``JITFunction`` callees go through ``call_JitFunction``.
    * Methods of triton values, triton builtins and ``ConstexprFunction``
      objects are called directly, receiving ``_semantic`` / ``_generator``
      when their signature declares those parameters.
    * Anything else is called as a plain Python function and its result is
      wrapped in ``constexpr`` unless it is already a triton value.
    """
    if isinstance(fn, (BoundJITMethod, BoundConstexprFunction)):
        # Note: mutates the caller's `args` list in place.
        args.insert(0, fn.__self__)
        fn = fn.__func__
    if isinstance(fn, JITFunction):
        _check_fn_args(node, fn, args)
        return self.call_JitFunction(fn, args, kws)
    if (hasattr(fn, '__self__') and _is_triton_value(fn.__self__)) or language.core.is_builtin(fn)or isinstance(
            fn, ConstexprFunction):
        # Extension builtins use `self.ascend_builder`; everything else uses
        # `self.builder`.  The current insertion point and location are
        # applied to the chosen builder before the call.
        ip, last_loc = self._get_insertion_point_and_loc()
        _builder = self.ascend_builder if extension.is_builtin(fn) else self.builder
        self._set_insertion_point_and_loc(ip, last_loc, _builder)
        extra_kwargs = dict()
        # For a ConstexprFunction the signature of its `__call__` is inspected.
        if isinstance(fn, ConstexprFunction):
            sig = inspect.signature(fn.__call__)
        else:
            sig = inspect.signature(fn)
        # Pass the internal handles only to callees that declare them.
        if '_semantic' in sig.parameters:
            extra_kwargs["_semantic"] = self.semantic
        if '_generator' in sig.parameters:
            extra_kwargs['_generator'] = self
        try:
            ret = fn(*args, **extra_kwargs, **kws)
            # Plain tuples returned by the callee become `language.tuple`.
            if isinstance(ret, tuple):
                ret = language.tuple(ret)
            # Read the insertion point back from the builder that was used
            # and apply it through the default path.
            ip, last_loc = self._get_insertion_point_and_loc(_builder)
            self._set_insertion_point_and_loc(ip, last_loc)
            return ret
        except Exception as e:
            # In front-end debugging mode the original exception propagates;
            # otherwise it is wrapped with the call site and chained.
            if knobs.compilation.front_end_debugging:
                raise
            raise CompilationError(self.jit_fn.src, node, repr(e)) from e
    # Plain Python callable.  Functions from the builtin namespace receive
    # unwrapped (non-constexpr) arguments.
    if fn in self.builtin_namespace.values():
        args = map(_unwrap_if_constexpr, args)
    ret = fn(*args, **kws)

    def wrap_constexpr(x):
        # Triton values pass through; everything else becomes a constexpr.
        if _is_triton_value(x):
            return x
        return constexpr(x)

    if isinstance(ret, (builtins.tuple, language.tuple)):
        return _apply_to_tuple_values(ret, wrap_constexpr)
    return wrap_constexpr(ret)
def call_Method(self, node, fn, fn_self, args, kws):
    """Call ``fn`` as a method of ``fn_self``.

    A ``JITFunction`` gets ``fn_self`` inserted at the front of ``args``
    (in place); any other callable is passed through unchanged.  The call is
    then delegated to ``call_Function``.
    """
    needs_explicit_receiver = isinstance(fn, JITFunction)
    if needs_explicit_receiver:
        args.insert(0, fn_self)
    return self.call_Function(node, fn, args, kws)
def visit_Call(self, node):
    """Evaluate a call expression.

    Callees registered in ``statically_implemented_functions`` (unless the
    callee is a ``BoundJITMethod``) are handled by their static handler.
    A callee flagged ``_must_use_result`` whose call node is marked
    ``_is_unused`` raises ``CompilationError``.  Otherwise keywords and then
    positional arguments are visited, list-valued positional results are
    spliced in flat, and the call goes through ``call_Function``.
    """
    callee = _unwrap_if_constexpr(self.visit(node.func))
    if not isinstance(callee, BoundJITMethod):
        handler = self.statically_implemented_functions.get(callee)
        if handler is not None:
            return handler(self, node)

    must_use = getattr(callee, '_must_use_result', False)
    if must_use and getattr(node, '_is_unused', False):
        parts = ["The result of %s is not being used." % ast.unparse(node.func)]
        # A string-valued flag carries an extra explanation for the user.
        if isinstance(must_use, str):
            parts.append(must_use)
        raise CompilationError(self.jit_fn.src, node, " ".join(parts))

    keyword_values = dict(self.visit(keyword) for keyword in node.keywords)
    visited = [self.visit(arg) for arg in node.args]
    # A positional argument that evaluated to a list contributes each of its
    # elements as a separate argument.
    positional = []
    for item in visited:
        if isinstance(item, list):
            positional.extend(item)
        else:
            positional.append(item)
    return self.call_Function(node, callee, positional, keyword_values)
def visit_Constant(self, node):
    """Wrap a literal's Python value in a ``constexpr``."""
    literal = node.value
    return constexpr(literal)
def visit_BoolOp(self, node: ast.BoolOp):
    """Lower ``a and b and ...`` / ``a or b or ...``.

    Operands are visited left to right.  A non-tensor operand is evaluated
    with ``bool()`` and may short-circuit the whole expression: a falsy
    operand of ``and`` or a truthy operand of ``or`` is returned as-is and
    the remaining operands are not visited.  Non-tensor operands that do not
    short-circuit are dropped.  Tensor operands are collected and combined
    with ``_apply_binary_method``; block-typed tensors additionally trigger
    a deprecation ``UserWarning``.  If no tensor operand was collected, the
    last visited value is the result.

    Raises whatever ``self._unsupported`` builds when the operator has no
    entry in ``_method_name_for_bool_op``.
    """
    method_name = self._method_name_for_bool_op.get(type(node.op))
    if method_name is None:
        # `node.op` is an operator *instance*; only its class carries
        # `__name__`, so the name must be read from the type.
        raise self._unsupported(
            node, "AST boolean operator '{}' is not (currently) implemented.".format(type(node.op).__name__))
    nontrivial_values = []
    for subnode in node.values:
        value = self.visit(subnode)
        if not _is_triton_tensor(value):
            # Not a tensor, so its truthiness is known now.
            bv = bool(value)
            if (bv is False) and (method_name == "logical_and"):
                return value
            if (bv is True) and (method_name == "logical_or"):
                return value
            # Otherwise it cannot affect the result and is skipped.
        else:
            if value.type.is_block():
                lineno = getattr(node, "lineno", None)
                if lineno is not None:
                    # Node line numbers are offset by `begin_line` to get
                    # the line reported in the warning.
                    lineno += self.begin_line
                warnings.warn_explicit(
                    "Logical operators 'and' and 'or' are deprecated for non-scalar tensors; please use '&' or '|' instead",
                    category=UserWarning,
                    filename=self.file_name,
                    lineno=lineno,
                    source=ast.unparse(node),
                )
            nontrivial_values.append(value)
    if len(nontrivial_values) == 0:
        # Every operand was a non-short-circuiting non-tensor: the result
        # is the last one visited.
        nontrivial_values.append(value)
    # Fold from the right: the last two entries are combined repeatedly.
    while len(nontrivial_values) >= 2:
        rhs = nontrivial_values.pop()
        lhs = nontrivial_values.pop()
        res = self._apply_binary_method(method_name, lhs, rhs)
        nontrivial_values.append(res)
    assert len(nontrivial_values) == 1
    return nontrivial_values[0]
_method_name_for_bool_op: Dict[Type[ast.boolop], str] = {ast.And: 'logical_and', ast.Or: 'logical_or'}
def visit_Attribute(self, node):
    """Evaluate ``value.attr``.

    * ``tensor.T`` is lowered to ``self.semantic.permute(tensor, (1, 0))``.
    * On a ``constexpr``, every attribute except ``value`` and ``type`` is
      looked up on the wrapped Python value.
    * A ``JITFunction`` attribute of a triton value is returned bound to it
      as a ``BoundJITMethod``.
    """
    owner = self.visit(node.value)
    attr_name = node.attr
    if _is_triton_tensor(owner) and attr_name == "T":
        return self.semantic.permute(owner, (1, 0))
    if isinstance(owner, constexpr) and attr_name not in ("value", "type"):
        owner = owner.value
    member = getattr(owner, attr_name)
    is_jit_method = _is_triton_value(owner) and isinstance(member, JITFunction)
    return BoundJITMethod(owner, member) if is_jit_method else member
def visit_Expr(self, node):
    """Visit an expression statement, first marking its value as unused.

    The ``_is_unused`` marker is what ``visit_Call`` checks for callees that
    require their result to be used.
    """
    setattr(node.value, "_is_unused", True)
    ast.NodeVisitor.generic_visit(self, node)
def visit_NoneType(self, node):
    """Visitor for a ``None`` node: always evaluates to ``None``."""
    result = None
    return result
def visit_JoinedStr(self, node):
    """Evaluate an f-string at compile time and return the resulting ``str``.

    Literal parts are converted with ``str``.  Each ``{...}`` part is
    visited and must evaluate to a constexpr; its value is formatted with
    the part's conversion (``!r``, ``!s``, ``!a``), if it has one.

    Raises whatever ``self._unsupported`` builds for a non-constexpr
    interpolation, and ``AssertionError`` for an unexpected node type.
    """
    pieces = []
    for part in node.values:
        if isinstance(part, ast.Constant):
            pieces.append(str(part.value))
            continue
        if not isinstance(part, ast.FormattedValue):
            raise AssertionError("encountered unexpected node of type {} in a JoinedStr node".format(type(part)))
        conversion_code = part.conversion
        evaluated = self.visit(part.value)
        if not _is_constexpr(evaluated):
            raise self._unsupported(
                node,
                "Cannot evaluate f-string containing non-constexpr conversion values, found conversion of type "
                + str(type(evaluated)))
        # A negative conversion code means no conversion was requested.
        if conversion_code < 0:
            template = "{}"
        else:
            template = "{!" + chr(conversion_code) + "}"
        pieces.append(template.format(evaluated.value))
    return ''.join(pieces)
def visit(self, node):
    """Visit ``node`` while tracking the current node and builder location.

    ``None`` is accepted and yields ``None``.  For nodes that carry
    ``lineno`` and ``col_offset``, the builder location is set from
    ``self.file_name``, ``self.begin_line + node.lineno`` and the column
    (wrapped in a name location when ``self.name_loc_as_prefix`` is set).
    Exceptions other than ``CompilationError`` are wrapped in a
    ``CompilationError`` unless front-end debugging is enabled.
    """
    if node is None:
        return
    with warnings.catch_warnings():
        # Deprecation warnings raised while visiting are silenced.
        warnings.simplefilter("ignore", DeprecationWarning)
        warnings.simplefilter("ignore", PendingDeprecationWarning)
        last_node = self.cur_node
        last_loc = self.builder.get_loc()
        self.cur_node = node
        if hasattr(node, 'lineno') and hasattr(node, 'col_offset'):
            here_loc = self.builder.create_loc(self.file_name, self.begin_line + node.lineno, node.col_offset)
            if self.name_loc_as_prefix is not None:
                self.builder.set_loc(self.builder.create_name_loc(self.name_loc_as_prefix, here_loc))
            else:
                self.builder.set_loc(here_loc)
            # Note: `last_loc` is re-read here, i.e. after the location for
            # this node has been installed.
            last_loc = self.builder.get_loc()
        try:
            ret = super().visit(node)
        except CompilationError:
            # Already carries source information; propagate unchanged.
            raise
        except Exception as e:
            if knobs.compilation.front_end_debugging:
                raise
            # Report against the kernel source and the node being visited;
            # `from None` drops the original exception context.
            raise CompilationError(self.jit_fn.src, self.cur_node, repr(e)) from None
        # Restore the previous node and re-apply the saved location, but only
        # when a location was recorded.
        if last_loc:
            self.cur_node = last_node
            self.builder.set_loc(last_loc)
        return ret
def generic_visit(self, node):
    """Fallback for AST node types without a dedicated visitor: always raises."""
    node_kind = type(node).__name__
    raise self._unsupported(node, "unsupported AST node type: {}".format(node_kind))
def execute_static_assert(self, node: ast.Call) -> None:
    """Evaluate a ``static_assert(cond[, msg])`` call at compile time.

    Raises:
        TypeError: unless there are one or two positional arguments and no
            keyword arguments.
        NotImplementedError: if the condition does not evaluate to a ``bool``.
        CompileTimeAssertionFailure: if the condition is ``False``.  The
            message is evaluated only in this case; a failure while
            evaluating it is folded into the message text.
    """
    arg_count = len(node.args)
    if arg_count not in (1, 2) or len(node.keywords) != 0:
        raise TypeError("`static_assert` requires one or two positional arguments only")

    passed = _unwrap_if_constexpr(self.visit(node.args[0]))
    if not isinstance(passed, bool):
        raise NotImplementedError(
            "Assertion condition could not be determined at compile-time. Make sure that it depends only on `constexpr` values"
        )
    if passed:
        return None

    message = ""
    if arg_count != 1:
        try:
            message = self.visit(node.args[1])
        except Exception as e:
            message = "<failed to evaluate assertion message: " + repr(e) + ">"
    raise CompileTimeAssertionFailure(self.jit_fn.src, node, _unwrap_if_constexpr(message))
def static_executor(python_fn):
    """Build a static call handler that runs ``python_fn`` at compile time.

    The returned handler visits the call's keyword and positional arguments,
    strips their constexpr wrappers, calls ``python_fn`` with them and wraps
    the result in a ``constexpr``.
    """

    def ret(self, node: ast.Call):
        kws = {}
        for keyword in node.keywords:
            name, value = self.visit(keyword)
            kws[name] = _unwrap_if_constexpr(value)
        args = []
        for arg in node.args:
            args.append(_unwrap_if_constexpr(self.visit(arg)))
        return constexpr(python_fn(*args, **kws))

    return ret
# Gluon's language module is imported here so its `static_assert` and
# `static_print` can be registered below.
from ..experimental.gluon import language as ttgl
# Callables that are executed at compile time instead of being lowered to a
# call.  `visit_Call` looks the callee up here and, on a hit, invokes the
# mapped handler with `(self, node)`.
statically_implemented_functions: Dict[object, Callable[[ast.Call], Any]] = {
    language.core.static_assert: execute_static_assert,
    language.core.static_print: static_executor(print),
    ttgl.static_assert: execute_static_assert,
    ttgl.static_print: static_executor(print),
    int: static_executor(int),
    len: static_executor(len),
    extension.int64: static_executor(extension.int64),
}
def ast_to_ttir(fn, src, context, options, codegen_fns, module_map, module=None):
    """Generate and verify the IR module for the JIT function ``fn``.

    Argument types are built from ``src.signature``, pairing each position
    with its entry in ``src.constants`` when one exists.  A ``CodeGenerator``
    then visits the parsed function.  The resulting module gets ``context``
    attached and is verified.

    Returns the verified module.  Raises ``RuntimeError`` if verification
    fails (after printing the module, except for gluon functions).
    """
    arg_types = [None] * len(fn.arg_names)
    # Constants are consumed in step with the signature: an entry applies to
    # the current position when the first element of its key equals it.
    pending_constants = iter(src.constants.items())
    const_key, const_val = next(pending_constants, (None, None))
    for position, (arg_name, type_str) in enumerate(src.signature.items()):
        slot = fn.arg_names.index(arg_name)
        if const_key is not None and const_key[0] == position:
            bound_constant = const_val
            const_key, const_val = next(pending_constants, (None, None))
        else:
            bound_constant = None
        arg_types[slot] = str_to_ty(type_str, bound_constant)

    prototype = ASTFunction([], arg_types, src.constants, src.attrs)
    file_name, begin_line = get_jit_fn_file_line(fn)

    # `fn.repr` receives a small proxy exposing the constants whose key has
    # length one (keyed by argument name) and the signature.
    from collections import namedtuple
    constants = {fn.arg_names[key[0]]: src.constants[key] for key in src.constants if len(key) == 1}
    SpecializationProxy = namedtuple("SpecializationProxy", ["constants", "signature"])
    proxy = SpecializationProxy(constants, src.signature)

    generator = CodeGenerator(
        context,
        prototype,
        gscope=fn.get_capture_scope(),
        function_name=fn.repr(proxy),
        jit_fn=fn,
        is_kernel=True,
        file_name=file_name,
        begin_line=begin_line,
        options=options,
        codegen_fns=codegen_fns,
        module_map=module_map,
        module=module,
        is_gluon=fn.is_gluon(),
    )
    generator.visit(fn.parse())

    module = generator.module
    module.context = context
    if module.verify_with_diagnostics():
        return module
    if not fn.is_gluon():
        print(module)
    raise RuntimeError("error encountered during parsing")