Triton Overview

Triton Op Support Overview

Triton Op int8 int16 int32 uint32 int64 fp16 fp32 bf16 bool
Creation Ops
arange × × × × × × × ×
cat ×
full ×
zeros ×
zeros_like ×
cast ×
Shape Manipulation Ops
broadcast ×
broadcast_to ×
expand_dims ×
interleave ×
join ×
permute × ×
ravel ×
reshape ×
split ×
trans × ×
view ×
Linear Algebra Ops
dot × × × × ×
dot_scaled × × × × × × × × ×
Memory/Pointer Ops
load ×
store ×
make_block_ptr × ×
make_tensor_descriptor × ×
load_tensor_descriptor × ×
store_tensor_descriptor × ×
advance × ×
Indexing Ops
flip ×
where × × ✓*
swizzle2d × × × × ×
Math Ops
add × ✓*
sub × ✓*
mul × ✓*
div × ✓*
floordiv(//) × × × × ✓*
mod × ✓*
neg × ×
invert(~) × × × ×
and(&) × × × ×
or(|) × × × ×
xor(^) × × × ×
not(!) × × × ×
lshift(<<) × × × × ×
rshift(>>) × × × × ×
gt × ✓*
ge × ✓*
lt × ✓*
le × ✓*
eq × ✓*
ne × ✓*
logical and
logical or
abs × ✓*
cdiv × ×
ceil × × × × × ×
clamp × × × × × ×
cos × × × × × ×
div_rn × × × × × ×
erf × × × × × ×
exp × × × × × ×
exp2 × × × × × ×
fdiv × × × × × ×
floor × × × × × ×
fma × × × × × ×
log × × × × × ×
log2 × × × × × ×
maximum × ✓*
minimum × ✓*
round × × × × × × × ×
rsqrt × × × × × ×
sigmoid × × × × × ×
sin × × × × × ×
softmax × × × × × ×
sqrt × × × × × ×
sqrt_rn × × × × × ×
umulhi × × × × × × × ×
Reduction Ops
argmax × ×
argmin × ×
max × ✓*
min × ✓*
reduce × ✓*
sum × ✓*
xor_sum × × × × ✓*
Scan/Sort Ops
associative_scan ×
cumprod ×
cumsum ×
histogram × × × × × ×
sort × × × × × × × × ×
gather × × × × × ×
Atomic Ops
atomic_add × ×
atomic_and × × × ×
atomic_cas × × ×
atomic_max × × ×
atomic_min × × ×
atomic_or × × × ×
atomic_xchg × ×
atomic_xor × × × ×
Random Number Generation
randint4x × × × ×
randint × × × ×
rand × × × × ×
randn × × × × ×
Iterators
range × × × × ×
static_range × × × × ×
Inline Assembly
inline_asm_elementwise × × × × × × × × ×
Compiler Hint Ops
assume × × × × × × × ×
debug_barrier ×
max_constancy × × × × × × × × ×
max_contiguous × × × × × × × × ×
multiple_of × × × × × × × × ×
Debug Ops
static_print ×
static_assert ×
device_print × ×
device_assert × × × × × × × ×

Constraints

  • dot: takes two inputs, A[batch (optional), M, K] and B[batch (optional), K, N].

  • gather: for triton.gather(x, index, axis) where x has n dimensions, only axis=n-1 (the last axis) is currently supported.

  • permute: triton.permute(x, dims) does not support dims=[2, 1, 0].

  • trans: triton.trans(x, dims) does not support dims=[2, 1, 0].

  • device_print: requires one environment variable to be set: TRITON_DEVICE_PRINT=1.

  • device_assert: takes effect only when both environment variables are set: TRITON_DEBUG=1 and TRITON_DEVICE_PRINT=1.

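  • atomic_add: Ascend does not support using atomic_add to perform a multi-core add while also saving the intermediate results; use a regular add to save intermediate results instead.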

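  • Atomic ops: on the Ascend backend, sem supports only its default value "acq_rel", and any other value is treated as the default; scope supports only its default value "gpu", and any other value is treated as the default.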

  • atomic_or/atomic_xor/atomic_and/atomic_xchg/atomic_cas: not yet supported inside loops on Ascend.

  • permute: transposing non-adjacent axes is not supported, e.g. (0, 1, 2) -> (2, 1, 0).

  • trans: transposing non-adjacent axes is not supported, e.g. (0, 1, 2) -> (2, 1, 0).

  • umulhi: negative inputs are not supported.

  • mod: for int64, only values in the range -2^24 to 2^24 are supported.

  • Random ops (rand, randn, randint, randint4x): the supported data types listed in the table apply only to the op's output.

  • Tensor descriptor ops: currently supported only when used together, i.e. make_tensor_descriptor, load_tensor_descriptor and store_tensor_descriptor must be used as a set.

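  • ALL: because int8 receives special handling, it occupies more on-chip memory (UB) and can easily trigger a "ub overflow" error at compile time; adjusting the tiling usually resolves it.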

  • ALL: the total size of all tensors that coexist in a Triton kernel must not exceed 96KB, or 192KB if double buffering is disabled.

  • ALL: no tensor may have a dimension whose size is less than 1.

  • ALL: ✓* marks ops for which Triton internally converts bool to int8 for the computation and still runs to produce a result.

  • ALL: computing with scalar tensors of shape "[[]]" is not supported.
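The 96KB / 192KB tensor budget and the "every dimension >= 1" rule can be checked on the host before tuning tile sizes. The sketch below is a rough estimate under stated assumptions: 1KB = 1024 bytes, the listed tensors are all live at the same time, and each dtype uses its standard element size (bool counted as 1 byte, since it is computed as int8). `tensor_bytes`, `fits_ub` and `DTYPE_BYTES` are hypothetical helpers, not Triton APIs. Because int8 gets special handling on Ascend, the real footprint for int8 tensors may be larger than this estimate.

```python
from math import prod

# Assumed element sizes in bytes for the dtypes in the support table.
# bool is counted as 1 byte because it is computed as int8 (see the ✓* note).
DTYPE_BYTES = {
    "int8": 1, "bool": 1, "int16": 2, "fp16": 2, "bf16": 2,
    "int32": 4, "uint32": 4, "fp32": 4, "int64": 8,
}

# Assumption: "96KB" / "192KB" mean multiples of 1024 bytes.
UB_LIMIT_DOUBLE_BUFFER = 96 * 1024
UB_LIMIT_SINGLE_BUFFER = 192 * 1024


def tensor_bytes(shape, dtype):
    """Bytes used by one tensor; rejects any dimension smaller than 1."""
    if any(d < 1 for d in shape):
        raise ValueError(f"every dimension must be >= 1, got {shape}")
    return prod(shape) * DTYPE_BYTES[dtype]


def fits_ub(tensors, double_buffer=True):
    """tensors: iterable of (shape, dtype) pairs live at the same time.

    Returns (fits, total_bytes, limit_bytes).
    """
    total = sum(tensor_bytes(shape, dtype) for shape, dtype in tensors)
    limit = UB_LIMIT_DOUBLE_BUFFER if double_buffer else UB_LIMIT_SINGLE_BUFFER
    return total <= limit, total, limit
```

For example, two fp32 tiles of shape (128, 128) need 131072 bytes: over the 98304-byte limit with double buffering, but within the 196608-byte limit once double buffering is disabled.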
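The dot shape rule (A[batch?, M, K] times B[batch?, K, N]) can be validated on the host. A minimal sketch, assuming both operands have the same rank (2-D, or 3-D with a leading batch dimension) and that the batch dimensions must match exactly with no broadcasting; `dot_output_shape` is a hypothetical helper, not a Triton API.

```python
def dot_output_shape(a_shape, b_shape):
    """Result shape of dot(A, B) for A[batch?, M, K] and B[batch?, K, N]."""
    # Assumption: both operands are 2-D, or both are 3-D with a batch dim.
    if len(a_shape) != len(b_shape) or len(a_shape) not in (2, 3):
        raise ValueError("A and B must both be 2-D or both be 3-D (with batch)")
    *a_batch, m, k_a = a_shape
    *b_batch, k_b, n = b_shape
    if k_a != k_b:
        raise ValueError(f"inner dimensions differ: K={k_a} vs K={k_b}")
    # Assumption: batch dimensions must be identical (no broadcasting).
    if a_batch != b_batch:
        raise ValueError(f"batch dimensions differ: {a_batch} vs {b_batch}")
    return (*a_batch, m, n)
```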
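To make the gather restriction concrete, here is a pure-Python reference for gather along the last axis of a 2-D input, where out[i][j] = x[i][index[i][j]], together with a check that only the last axis is requested. Both helpers (`check_gather_axis`, `gather_last_axis`) are hypothetical illustrations of the semantics, not Triton code, and they assume index has the same number of rows as x.

```python
def check_gather_axis(ndim, axis):
    """Only the last axis (axis == ndim - 1) is supported for gather."""
    if axis != ndim - 1:
        raise ValueError(f"gather only supports axis={ndim - 1}, got {axis}")
    return axis


def gather_last_axis(x, index):
    """Reference gather on axis 1 of a 2-D nested list.

    out[i][j] = x[i][index[i][j]]; an out-of-range index raises IndexError.
    """
    return [[row[j] for j in idx_row] for row, idx_row in zip(x, index)]
```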
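The permute/trans restriction on non-adjacent axes can be pre-checked on the host. The sketch below is deliberately conservative: it assumes only the identity or a single swap of two adjacent axes, such as (0, 2, 1), is safe, and rejects everything else, including (2, 1, 0). The backend may accept more permutations than this; `conservative_transpose_ok` is a hypothetical helper.

```python
def conservative_transpose_ok(dims):
    """Conservative pre-check for permute/trans dims on Ascend.

    Assumption: accept only the identity or one swap of two adjacent axes.
    """
    n = len(dims)
    if sorted(dims) != list(range(n)):
        raise ValueError(f"{dims} is not a permutation of 0..{n - 1}")
    # Positions whose axis changed; exactly two moved positions means a swap.
    moved = [i for i, d in enumerate(dims) if d != i]
    if not moved:
        return True  # identity permutation
    return len(moved) == 2 and moved[1] - moved[0] == 1
```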
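A minimal sketch of enabling the debug ops from Python. The variable names come from the device_print and device_assert constraints; the helper `enable_device_debug` is hypothetical, and setting the variables before `import triton` (so they are visible when kernels are compiled) is an assumption about when they are read. Exporting them in the shell before launching the script has the same effect, since child processes inherit the environment.

```python
import os


def enable_device_debug(with_assert=False):
    """Set the environment variables needed by device_print / device_assert.

    Assumption: call this before importing triton and compiling kernels.
    """
    os.environ["TRITON_DEVICE_PRINT"] = "1"  # needed by device_print and device_assert
    if with_assert:
        os.environ["TRITON_DEBUG"] = "1"  # additionally needed by device_assert
    # Report the current values of the two variables.
    return {k: os.environ[k] for k in ("TRITON_DEVICE_PRINT", "TRITON_DEBUG") if k in os.environ}
```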
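Because int64 mod only handles values between -2^24 and 2^24, it can be worth validating host-side operands before launching a kernel. A minimal sketch, assuming the range is inclusive at both ends; `mod_int64_in_range` is a hypothetical helper that takes any iterable of Python ints (dividends and divisors alike).

```python
INT64_MOD_LIMIT = 2 ** 24


def mod_int64_in_range(values):
    """True if every int64 mod operand lies in [-2**24, 2**24].

    Assumption: both ends of the documented range are inclusive.
    """
    return all(-INT64_MOD_LIMIT <= v <= INT64_MOD_LIMIT for v in values)
```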