// RUN: triton-opt %s -split-input-file -allow-unregistered-dialect -tritongpu-partition-loops -verify-diagnostics -canonicalize | FileCheck %s
#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
!ty = tensor<1xi32, #blocked>
module attributes {"ttg.num-warps" = 4 : i32} {
// CHECK-LABEL: @no_partitions
tt.func @no_partitions(%lb: i32, %ub: i32, %step: i32) {
// CHECK-NEXT: scf.for
scf.for %i = %lb to %ub step %step : i32 {
// CHECK-NEXT: op_a
"op_a"() : () -> ()
} {ttg.partition.stages = [], ttg.warp_specialize.tag = 0 : i32}
tt.return
}
// CHECK-LABEL: @one_partition
tt.func @one_partition(%lb: i32, %ub: i32, %step: i32) {
// CHECK-NEXT: scf.for
scf.for %i = %lb to %ub step %step : i32 {
// CHECK-NEXT: op_a
"op_a"() {ttg.partition = 0} : () -> ()
} {ttg.partition.stages = [0], ttg.warp_specialize.tag = 0 : i32}
tt.return
}
// CHECK-LABEL: @two_empty_partitions
tt.func @two_empty_partitions(%lb: i32, %ub: i32, %step: i32) {
// CHECK-NEXT: nvws.warp_group
// CHECK-NEXT: partition0 num_warps(4)
// CHECK-NEXT: scf.for [[I:%.*]] = %arg0 to %arg1 step %arg2
// CHECK-NEXT: "op_a"([[I]])
// CHECK-NEXT: }
// CHECK-NEXT: nvws.warp_group.yield
// CHECK-NEXT: }
// CHECK-NEXT: partition1 num_warps(4)
// CHECK-NEXT: scf.for [[I:%.*]] = %arg0 to %arg1 step %arg2
// CHECK-NEXT: "op_a"([[I]])
// CHECK-NEXT: }
// CHECK-NEXT: nvws.warp_group.return
scf.for %i = %lb to %ub step %step : i32 {
"op_a"(%i) : (i32) -> ()
} {ttg.partition.stages = [0, 0], ttg.warp_specialize.tag = 0 : i32}
tt.return
}
// CHECK-LABEL: @empty_partition_fwd_root
tt.func @empty_partition_fwd_root(%lb: i32, %ub: i32, %step: i32) {
// CHECK-NEXT: [[C0:%.*]] = arith.constant 0
%c0_i32 = arith.constant 0 : i32
// CHECK: partition0
// CHECK-NEXT: scf.for [[I:%.*]] = {{.*}} iter_args([[K:%.*]] = [[C0]])
// CHECK-NEXT: "op_a"([[I]], [[K]])
scf.for %i = %lb to %ub step %step iter_args(%k = %c0_i32) -> i32 : i32 {
%0 = "op_a"(%i, %k) : (i32, i32) -> i32
scf.yield %0 : i32
} {ttg.partition.stages = [0, 0], ttg.warp_specialize.tag = 0 : i32}
tt.return
}
// CHECK-LABEL: @multiple_partitions
tt.func @multiple_partitions(%lb: i32, %ub: i32, %step: i32) {
// CHECK: partition0 num_warps(4)
// CHECK-NEXT: scf.for
// CHECK-NEXT: [[X:%.*]] = "op_a"
// CHECK-NEXT: "op_b"([[X]])
// CHECK-NEXT: "op_b"([[X]])
// CHECK-NEXT: }
// CHECK: partition1
// CHECK-NEXT: scf.for [[I:%arg[0-9]+]]
// CHECK-NEXT: [[Y:%.*]] = arith.addi [[I]], [[I]]
// CHECK-NEXT: [[X:%.*]] = "op_a"([[Y]])
// CHECK-NEXT: "op_b"([[X]])
// CHECK-NEXT: "op_b"([[X]])
// CHECK-NEXT: }
// CHECK: partition2
// CHECK-NEXT: scf.for [[I:%arg[0-9]+]]
// CHECK-NEXT: [[Y:%.*]] = arith.addi [[I]], [[I]]
// CHECK-NEXT: [[Z:%.*]] = arith.addi [[I]], [[Y]]
// CHECK-NEXT: [[X:%.*]] = "op_a"([[Z]])
// CHECK-NEXT: "op_b"([[X]])
// CHECK-NEXT: "op_b"([[X]])
// CHECK-NEXT: }
scf.for %i = %lb to %ub step %step : i32 {
%a = arith.addi %i, %i : i32
%b = arith.addi %i, %a : i32
%0 = "op_a"(%i) {ttg.partition = 0} : (i32) -> i32
"op_b"(%0) {ttg.partition = 0} : (i32) -> ()
"op_b"(%0) {ttg.partition = 0} : (i32) -> ()
%1 = "op_a"(%a) {ttg.partition = 1} : (i32) -> i32
"op_b"(%1) {ttg.partition = 1} : (i32) -> ()
"op_b"(%1) {ttg.partition = 1} : (i32) -> ()
%2 = "op_a"(%b) {ttg.partition = 2} : (i32) -> i32
"op_b"(%2) {ttg.partition = 2} : (i32) -> ()
"op_b"(%2) {ttg.partition = 2} : (i32) -> ()
} {ttg.partition.stages = [0, 0, 0], ttg.warp_specialize.tag = 0 : i32}
tt.return
}
// CHECK-LABEL: @multiple_partitions_two_loops
tt.func @multiple_partitions_two_loops(%lb: i32, %ub: i32, %step: i32,
%c0 : i32, %c1 : i32, %c2 : i32) {
// CHECK: partition0 num_warps(4)
// CHECK-NEXT: op_00b
// CHECK-NEXT: [[RET:%.*]]:3 = scf.for [[I:%.*]] = [[LB:%.*]] to [[UB:%.*]] step [[STEP:%.*]] iter_args([[ARG0:%.*]] = {{.*}}, [[ARG1:%.*]] = {{.*}}, [[ARG2:%.*]] = {{.*}}) -> (i32, i32, i32) : i32 {
// CHECK-NEXT: [[X:%.*]] = "op_a"
// CHECK-NEXT: "op_b"([[ARG0]])
// CHECK-NEXT: "op_b"([[X]])
// CHECK-NEXT: arith.addi
// CHECK-NEXT: arith.addi
// CHECK-NEXT: arith.addi
// CHECK-NEXT: scf.yield
// CHECK-NEXT: }
// CHECK-NEXT: "op_00e"([[RET]]#0)
// CHECK: partition1
// CHECK-NEXT: op_01b
// CHECK-NEXT: [[RET:%.*]]:3 = scf.for [[I:%.*]] = [[LB:%.*]] to [[UB:%.*]] step [[STEP:%.*]] iter_args([[ARG0:%.*]] = {{.*}}, [[ARG1:%.*]] = {{.*}}, [[ARG2:%.*]] = {{.*}}) -> (i32, i32, i32) : i32 {
// CHECK-NEXT: [[Y:%.*]] = arith.addi [[I]], [[I]]
// CHECK-NEXT: [[X:%.*]] = "op_a"([[Y]])
// CHECK-NEXT: "op_b"([[ARG1]])
// CHECK-NEXT: "op_b"([[X]])
// CHECK-NEXT: arith.addi
// CHECK-NEXT: arith.addi
// CHECK-NEXT: arith.addi
// CHECK-NEXT: scf.yield
// CHECK-NEXT: }
// CHECK-NEXT: "op_01e"([[RET]]#1)
// CHECK: partition2
// CHECK-NEXT: op_02b
// CHECK-NEXT: [[RET:%.*]]:3 = scf.for [[I:%.*]] = [[LB:%.*]] to [[UB:%.*]] step [[STEP:%.*]] iter_args([[ARG0:%.*]] = {{.*}}, [[ARG1:%.*]] = {{.*}}, [[ARG2:%.*]] = {{.*}}) -> (i32, i32, i32) : i32 {
// CHECK-NEXT: [[Y:%.*]] = arith.addi [[I]], [[I]]
// CHECK-NEXT: [[Z:%.*]] = arith.addi [[I]], [[Y]]
// CHECK-NEXT: [[X:%.*]] = "op_a"([[Z]])
// CHECK-NEXT: "op_b"([[ARG2]])
// CHECK-NEXT: "op_b"([[X]])
// CHECK-NEXT: arith.addi
// CHECK-NEXT: arith.addi
// CHECK-NEXT: arith.addi
// CHECK-NEXT: scf.yield
// CHECK-NEXT: }
// CHECK-NEXT: "op_02e"([[RET]]#2)
"op_00b"() {ttg.partition = 0, ttg.warp_specialize.tag = 0} : () -> ()
"op_01b"() {ttg.partition = 1, ttg.warp_specialize.tag = 0} : () -> ()
"op_02b"() {ttg.partition = 2, ttg.warp_specialize.tag = 0} : () -> ()
%ret:3 = scf.for %i = %lb to %ub step %step iter_args(%arg0 = %c0, %arg1 = %c1, %arg2 = %c2) -> (i32, i32, i32) : i32 {
%a = arith.addi %i, %i : i32
%b = arith.addi %i, %a : i32
%0 = "op_a"(%i) {ttg.partition = 0} : (i32) -> i32
"op_b"(%arg0) {ttg.partition = 0} : (i32) -> ()
"op_b"(%0) {ttg.partition = 0} : (i32) -> ()
%1 = "op_a"(%a) {ttg.partition = 1} : (i32) -> i32
"op_b"(%arg1) {ttg.partition = 1} : (i32) -> ()
"op_b"(%1) {ttg.partition = 1} : (i32) -> ()
%2 = "op_a"(%b) {ttg.partition = 2} : (i32) -> i32
"op_b"(%arg2) {ttg.partition = 2} : (i32) -> ()
"op_b"(%2) {ttg.partition = 2} : (i32) -> ()
%v0 = arith.addi %arg0, %arg0 : i32
%v1 = arith.addi %arg1, %arg1 : i32
%v2 = arith.addi %arg2, %arg2 : i32
scf.yield %v0, %v1, %v2: i32, i32, i32
} {ttg.partition.stages = [0, 0, 0], ttg.warp_specialize.tag = 0 : i32}
"op_00e"(%ret#0) {ttg.partition = 0, ttg.warp_specialize.tag = 0} : (i32) -> ()
"op_01e"(%ret#1) {ttg.partition = 1, ttg.warp_specialize.tag = 0} : (i32) -> ()
"op_02e"(%ret#2) {ttg.partition = 2, ttg.warp_specialize.tag = 0} : (i32) -> ()
// CHECK: partition0 num_warps(4)
// CHECK-NEXT: op_10b
// CHECK-NEXT: scf.for
// CHECK: } {ttg.warp_specialize.tag = 1
// CHECK-NEXT: op_10e
// CHECK: partition1
// CHECK-NEXT: op_11b
// CHECK-NEXT: scf.for
// CHECK: } {ttg.warp_specialize.tag = 1
// CHECK-NEXT: op_11e
// CHECK: partition2
// CHECK-NEXT: op_12b
// CHECK-NEXT: scf.for
// CHECK: } {ttg.warp_specialize.tag = 1
// CHECK-NEXT: op_12e
"op_10b"() {ttg.partition = 0, ttg.warp_specialize.tag = 1} : () -> ()
"op_11b"() {ttg.partition = 1, ttg.warp_specialize.tag = 1} : () -> ()
"op_12b"() {ttg.partition = 2, ttg.warp_specialize.tag = 1} : () -> ()
scf.for %i = %lb to %ub step %step : i32 {
%a = arith.addi %i, %i : i32
%b = arith.addi %i, %a : i32
%0 = "op_a"(%i) {ttg.partition = 0} : (i32) -> i32
"op_b"(%0) {ttg.partition = 0} : (i32) -> ()
"op_b"(%0) {ttg.partition = 0} : (i32) -> ()
%1 = "op_a"(%a) {ttg.partition = 1} : (i32) -> i32
"op_b"(%1) {ttg.partition = 1} : (i32) -> ()
"op_b"(%1) {ttg.partition = 1} : (i32) -> ()
%2 = "op_a"(%b) {ttg.partition = 2} : (i32) -> i32
"op_b"(%2) {ttg.partition = 2} : (i32) -> ()
"op_b"(%2) {ttg.partition = 2} : (i32) -> ()
} {ttg.partition.stages = [0, 0, 0], ttg.warp_specialize.tag = 1 : i32}
"op_10e"() {ttg.partition = 0, ttg.warp_specialize.tag = 1} : () -> ()
"op_11e"() {ttg.partition = 1, ttg.warp_specialize.tag = 1} : () -> ()
"op_12e"() {ttg.partition = 2, ttg.warp_specialize.tag = 1} : () -> ()
tt.return
}
// CHECK-LABEL: @split_block_arguments
tt.func @split_block_arguments(%lb: i32, %ub: i32, %step: i32) {
// CHECK-NEXT: [[C0:%.*]] = arith.constant 0
// CHECK-NEXT: [[C1:%.*]] = arith.constant 1
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
// CHECK: partition0
// CHECK-NEXT: scf.for {{.*}} iter_args([[A:%.*]] = [[C0]])
// CHECK-NEXT: [[X:%.*]] = "op_a"([[A]])
// CHECK-NEXT: yield [[X]] : i32
// CHECK: partition1
// CHECK-NEXT: scf.for {{.*}} iter_args([[B:%.*]] = [[C1]])
// CHECK-NEXT: [[X:%.*]] = "op_b"([[B]])
// CHECK-NEXT: yield [[X]] : i32
scf.for %i = %lb to %ub step %step iter_args(%a = %c0_i32, %b = %c1_i32) -> (i32, i32) : i32 {
%0 = "op_a"(%a) {ttg.partition = 0} : (i32) -> i32
%1 = "op_b"(%b) {ttg.partition = 1} : (i32) -> i32
scf.yield %0, %1 : i32, i32
} {ttg.partition.stages = [0, 0], ttg.warp_specialize.tag = 0 : i32}
tt.return
}
// CHECK-LABEL: @partition_outputs
tt.func @partition_outputs(%lb: i32, %ub: i32, %step: i32) -> (!ty, !ty, !ty) {
// CHECK-NEXT: [[CST0:%.*]] = arith.constant dense<0>
// CHECK-NEXT: [[CST1:%.*]] = arith.constant dense<1>
// CHECK-NEXT: [[CST2:%.*]] = arith.constant dense<2>
%cst0 = arith.constant dense<0> : !ty
%cst1 = arith.constant dense<1> : !ty
%cst2 = arith.constant dense<2> : !ty
// CHECK-NEXT: [[B_BUF:%.*]] = ttg.local_alloc
// CHECK-NEXT: [[C_BUF:%.*]] = ttg.local_alloc
// CHECK-NEXT: [[A_OUT:%.*]] = nvws.warp_group
// CHECK-NEXT: partition0
// CHECK-NEXT: [[OUT:%.*]] = scf.for [[I:%arg[0-9]+]] {{.*}} iter_args([[A:%.*]] = [[CST0]])
// CHECK-NEXT: [[X:%.*]] = "op_a"([[I]], [[A]])
// CHECK-NEXT: yield [[X]]
// CHECK-NEXT: }
// CHECK-NEXT: nvws.warp_group.yield [[OUT]]
// CHECK: partition1 num_warps(4)
// CHECK-NEXT: [[OUT:%.*]] = scf.for [[I:%arg[0-9]+]] {{.*}} iter_args([[B:%.*]] = [[CST1]])
// CHECK-NEXT: [[X:%.*]] = "op_b"([[I]], [[B]])
// CHECK-NEXT: yield [[X]]
// CHECK-NEXT: }
// CHECK-NEXT: local_store [[OUT]], [[B_BUF]]
// CHECK: partition2 num_warps(4)
// CHECK-NEXT: [[OUT:%.*]] = scf.for [[I:%arg[0-9]+]] {{.*}} iter_args([[C:%.*]] = [[CST2]])
// CHECK-NEXT: [[X:%.*]] = "op_c"([[I]], [[C]])
// CHECK-NEXT: yield [[X]]
// CHECK-NEXT: }
// CHECK-NEXT: local_store [[OUT]], [[C_BUF]]
%outs:3 = scf.for %i = %lb to %ub step %step iter_args(%a = %cst0, %b = %cst1, %c = %cst2) -> (!ty, !ty, !ty) : i32 {
%0 = "op_a"(%i, %a) {ttg.partition = 0} : (i32, !ty) -> !ty
%1 = "op_b"(%i, %b) {ttg.partition = 1} : (i32, !ty) -> !ty
%2 = "op_c"(%i, %c) {ttg.partition = 2} : (i32, !ty) -> !ty
scf.yield %0, %1, %2 : !ty, !ty, !ty
} {ttg.partition.stages = [0, 0, 0], ttg.warp_specialize.tag = 0 : i32}
// CHECK: [[B_OUT:%.*]] = ttg.local_load [[B_BUF]]
// CHECK-NEXT: local_dealloc [[B_BUF]]
// CHECK-NEXT: [[C_OUT:%.*]] = ttg.local_load [[C_BUF]]
// CHECK-NEXT: local_dealloc [[C_BUF]]
// CHECK-NEXT: tt.return [[A_OUT]], [[B_OUT]], [[C_OUT]]
tt.return %outs#0, %outs#1, %outs#2 : !ty, !ty, !ty
}
// CHECK-LABEL: @future_conditional_self_use
tt.func @future_conditional_self_use(%lb: i32, %ub: i32, %step: i32, %cond: i1) {
%c0_i32 = arith.constant 0 : i32
scf.for %i = %lb to %ub step %step iter_args(%k = %c0_i32) -> i32 : i32 {
%0 = "op_a"() {ttg.partition = 0 : i32} : () -> i32
scf.if %cond {
"use"(%k) : (i32) -> ()
} {ttg.partition = 0 : i32}
scf.yield %0 : i32
} {ttg.partition.stages = [0], ttg.warp_specialize.tag = 0 : i32}
tt.return
}
// CHECK-LABEL: @trivial_tensor_captures
tt.func @trivial_tensor_captures(%arg0: f16, %lb: i32, %ub: i32, %step: i32) {
%0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32>
%1 = tt.splat %arg0 : f16 -> tensor<32xf16>
// CHECK: [[RANGE:%.*]] = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32>
// CHECK-NEXT: [[SPLAT:%.*]] = tt.splat %arg0 : f16 -> tensor<32xf16>
// CHECK-NEXT: nvws.warp_group
scf.for %i = %lb to %ub step %step : i32 {
// CHECK: partition1 num_warps(4)
// CHECK-NEXT: scf.for
// CHECK-NEXT: "use"([[RANGE]], [[SPLAT]])
"use"(%0, %1) {ttg.partition = 1} : (tensor<256xi32>, tensor<32xf16>) -> ()
} {ttg.partition.stages = [0, 0], ttg.warp_specialize.tag = 0 : i32}
tt.return
}
// CHECK-LABEL: @tensor_captures_over_smem
tt.func @tensor_captures_over_smem(%lb: i32, %ub: i32, %step: i32) {
// CHECK: [[VALUE:%.*]] = "value"()
%0 = "value"() : () -> tensor<32xf16, #blocked>
// CHECK: nvws.warp_group
scf.for %i = %lb to %ub step %step : i32 {
// CHECK: partition1
// CHECK-NEXT: scf.for
// CHECK-NEXT: "use"([[VALUE]])
"use"(%0) {ttg.partition = 1} : (tensor<32xf16, #blocked>) -> ()
} {ttg.partition.stages = [0, 0], ttg.warp_specialize.tag = 0 : i32}
tt.return
}
// CHECK-LABEL: @dce_before_warp_allocation
tt.func @dce_before_warp_allocation(%lb: i32, %ub: i32, %step: i32) {
%cst = arith.constant dense<0> : tensor<128xi32, #blocked>
// CHECK: nvws.warp_group
// CHECK: partition1 num_warps(4)
// CHECK: partition2 num_warps(4)
scf.for %i = %lb to %ub step %step iter_args(%idxs = %cst) -> tensor<128xi32, #blocked> : i32 {
%do_prologue = "prologue_cond"(%i) : (i32) -> i1
%0 = scf.if %do_prologue -> tensor<128xi32, #blocked> {
%1 = tt.splat %i : i32 -> tensor<128xi32, #blocked>
%2 = arith.addi %1, %idxs : tensor<128xi32, #blocked>
scf.yield %2 : tensor<128xi32, #blocked>
} else {
scf.yield %idxs : tensor<128xi32, #blocked>
}
"op_a"(%0) {ttg.partition = 0 : i32} : (tensor<128xi32, #blocked>) -> ()
"op_b"(%i) {ttg.partition = 1 : i32} : (i32) -> ()
"op_c"(%0) {ttg.partition = 2 : i32} : (tensor<128xi32, #blocked>) -> ()
scf.yield %0 : tensor<128xi32, #blocked>
} {ttg.partition.stages = [0, 0, 0], ttg.warp_specialize.tag = 0 : i32}
tt.return
}
// CHECK-LABEL: @capture_order
tt.func @capture_order(%arg0: i32) {
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32, #blocked>
%1 = arith.extsi %0 : tensor<4xi32, #blocked> to tensor<4xi64, #blocked>
// CHECK: [[VALUE:%.*]] = tt.make_range
// CHECK-NEXT: [[EXT:%.*]] = arith.extsi [[VALUE]]
// CHECK: nvws.warp_group
// CHECK: partition1
// CHECK-NEXT: scf.for
scf.for %arg1 = %c0_i32 to %arg0 step %c1_i32 : i32 {
// CHECK-NEXT: "use"([[VALUE]])
"use"(%0) : (tensor<4xi32, #blocked>) -> ()
// CHECK-NEXT: "use"([[EXT]])
"use"(%1) : (tensor<4xi64, #blocked>) -> ()
} {ttg.partition.stages = [1 : i32, 0 : i32], ttg.warp_specialize.tag = 0 : i32}
tt.return
}
// CHECK-LABEL: @clone_then_capture
tt.func @clone_then_capture(%arg0: i32) {
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
// CHECK: [[TT:%.*]] = "tensor_op"()
// CHECK: [[V:%.*]] = arith.addi [[TT]], [[TT]]
%0 = "tensor_op"() : () -> tensor<4xi32, #blocked>
%1 = arith.addi %0, %0 : tensor<4xi32, #blocked>
// CHECK: partition1
// CHECK: scf.for
scf.for %arg1 = %c0_i32 to %arg0 step %c1_i32 : i32 {
// CHECK: "use"([[V]])
"use"(%1) {ttg.partition = 1 : i32} : (tensor<4xi32, #blocked>) -> ()
} {ttg.partition.stages = [0 : i32, 1 : i32], ttg.warp_specialize.tag = 0 : i32}
tt.return
}
}
// -----
#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
!ty = tensor<1xi32, #blocked>
module attributes {"ttg.num-warps" = 4 : i32} {
tt.func @still_has_ssa_deps(%lb: i32, %ub: i32, %step: i32) {
scf.for %i = %lb to %ub step %step : i32 {
// expected-warning @below {{non-root partition #0 has direct SSA consumer}}
%0 = "op_a"() {ttg.partition = 0} : () -> !ty
// expected-note @below {{use at distance 0 in partition #1 here}}
"op_b"(%0) {ttg.partition = 1} : (!ty) -> ()
} {ttg.partition.stages = [0, 1], ttg.warp_specialize.tag = 0 : i32}
tt.return
}
}