// RUN: triton-opt %s --allow-unregistered-dialect --tritongpu-fuse-nested-loops -cse | FileCheck %s
// A function that contains only a return. Nothing beyond its label is
// verified.
// CHECK-LABEL: @empty_function
tt.func @empty_function() {
tt.return
}
// A single loop with no nested loop inside it, tagged "ttg.always-fuse".
// The directives expect "before.loop", the scf.for, its "body" op, its yield,
// the closing brace and "after.loop" on consecutive output lines.
// CHECK-LABEL: @no_fusion
tt.func @no_fusion(%lb: index, %ub: index, %step: index) -> index {
%c0 = arith.constant 0 : index
// CHECK: before.loop
"before.loop"() : () -> ()
// CHECK-NEXT: scf.for
%0 = scf.for %i = %lb to %ub step %step iter_args(%k = %c0) -> index {
// CHECK-NEXT: body
%1 = "body"(%i, %k) : (index, index) -> index
// CHECK-NEXT: yield
scf.yield %1 : index
// CHECK-NEXT: }
} {"ttg.always-fuse"}
// CHECK-NEXT: after.loop
"after.loop"() : () -> ()
tt.return %0 : index
}
// One inner loop nested in an outer loop, with a "prologue" op before it and
// an "epilogue" op after it. The directives expect a single scf.for running
// total_iters times whose body guards the prologue, the inner body and the
// epilogue with scf.if ops. The pseudo-code comments sketch the expected
// output next to the directives that verify it.
// CHECK-LABEL: @fuse_one_level_simple
// CHECK-SAME: [[LBI:%.*]]: i64, [[UBI:%.*]]: i64, [[STEPI:%.*]]: i64, [[LBJ:%.*]]: i64, [[UBJ:%.*]]: i64, [[STEPJ:%.*]]: i64
tt.func @fuse_one_level_simple(%lbi: i64, %ubi: i64, %stepi: i64, %lbj: i64, %ubj: i64, %stepj: i64) {
// len_i = len(range(lbi, ubi, stepi))
//
// CHECK-NEXT: [[DIFF_I:%.*]] = arith.subi [[UBI]], [[LBI]]
// CHECK-NEXT: [[LEN_I:%.*]] = arith.ceildivsi [[DIFF_I]], [[STEPI]]
// len_j = len(range(lbj, ubj, stepj))
//
// CHECK-NEXT: [[DIFF_J:%.*]] = arith.subi [[UBJ]], [[LBJ]]
// CHECK-NEXT: [[LEN_J:%.*]] = arith.ceildivsi [[DIFF_J]], [[STEPJ]]
// inner_len = max(1, len_j)
//
// CHECK-NEXT: [[PLEN0:%.*]] = arith.constant 0 : i64
// CHECK: [[LEN_J_CLAMP:%.*]] = arith.maxsi %c1_i64, [[LEN_J]]
// CHECK-NEXT: [[PLEN1:%.*]] = arith.addi [[PLEN0]], [[LEN_J_CLAMP]]
// CHECK-NEXT: [[INNER_LEN:%.*]] = arith.subi [[PLEN1]], %c0_i64
// total_iters = len_i * inner_len
//
// CHECK: [[TOTAL_ITERS:%.*]] = arith.muli [[LEN_I]], [[INNER_LEN]]
// T = -1
// i = lbi - stepi
// j = None
// for _ in range(total_iters):
//
// CHECK: [[I_INIT:%.*]] = arith.subi [[LBI]], [[STEPI]]
// CHECK: scf.for %{{.*}} = %c0_i64 to [[TOTAL_ITERS]] step %c1_i64 iter_args(
// CHECK-SAME: [[T_ARG:%.*]] = %c-1_i64, [[I_ARG:%.*]] = [[I_INIT]], [[J_ARG:%.*]] = %c0_i64) -> (i64, i64, i64) : i64 {
scf.for %i = %lbi to %ubi step %stepi : i64 {
// T = 0 if T == (inner_len - 1) else T + 1
//
// CHECK: [[T_PLUS_1:%.*]] = arith.addi [[T_ARG]], %c1_i64
// CHECK-NEXT: [[T_END:%.*]] = arith.subi [[INNER_LEN]], %c1_i64
// CHECK-NEXT: [[ROLLOVER:%.*]] = arith.cmpi eq, [[T_ARG]], [[T_END]]
// CHECK-NEXT: [[T:%.*]] = arith.select [[ROLLOVER]], %c0_i64, [[T_PLUS_1]]
// if T == 0:
//   i += stepi
//   prologue(i)
//   j = lbj
//
// CHECK: [[START:%.*]] = arith.subi %c0_i64, %c0_i64 : i64
// CHECK-NEXT: [[PROLOGUE_COND:%.*]] = arith.cmpi eq, [[T]], [[START]]
// CHECK-NEXT: [[JI:%.*]]:2 = scf.if [[PROLOGUE_COND]] -> (i64, i64) {
// CHECK-NEXT: [[I:%.*]] = arith.addi [[I_ARG]], [[STEPI]]
// CHECK-NEXT: "prologue"([[I]]) : (i64) -> ()
// CHECK-NEXT: yield [[LBJ]], [[I]]
// CHECK-NEXT: } else {
// CHECK-NEXT: yield [[J_ARG]], [[I_ARG]]
// CHECK-NEXT: }
"prologue"(%i) : (i64) -> ()
// if T >= 0 and T < len_j:
//   body(i, j)
//   j += stepj
//
// CHECK: [[END:%.*]] = arith.addi [[START]], [[LEN_J]]
// CHECK-NEXT: [[GE:%.*]] = arith.cmpi sge, [[T]], [[START]]
// CHECK-NEXT: [[LT:%.*]] = arith.cmpi slt, [[T]], [[END]]
// CHECK-NEXT: [[COND:%.*]] = arith.andi [[GE]], [[LT]]
// CHECK-NEXT: [[J_NEXT:%.*]] = scf.if [[COND]] -> (i64) {
// CHECK-NEXT: "body"([[JI]]#1, [[JI]]#0) : (i64, i64) -> ()
// CHECK-NEXT: [[J_INCR:%.*]] = arith.addi [[JI]]#0, [[STEPJ]]
// CHECK-NEXT: yield [[J_INCR]]
// CHECK-NEXT: } else {
// CHECK-NEXT: yield [[JI]]#0
// CHECK-NEXT: }
scf.for %j = %lbj to %ubj step %stepj : i64 {
"body"(%i, %j) : (i64, i64) -> ()
}
// if T == max(1, len_j) - 1:
//   epilogue(i)
// (i is not advanced here; the increment is in the T == 0 branch above)
//
// CHECK-NEXT: [[EPILOGUE_COND:%.*]] = arith.cmpi eq, [[T]], [[T_END]]
// CHECK-NEXT: scf.if [[EPILOGUE_COND]] {
// CHECK-NEXT: "epilogue"([[JI]]#1) : (i64) -> ()
// CHECK-NEXT: } else {
// CHECK-NEXT: }
"epilogue"(%i) : (i64) -> ()
// CHECK-NEXT: yield [[T]], [[JI]]#1, [[J_NEXT]] : i64, i64, i64
} {"ttg.always-fuse"}
tt.return
}
// Same nest shape as a single inner loop inside an outer loop, but here the
// outer loop carries %m, the inner loop carries %k (initialised from %m), and
// the "prologue", "body" and "epilogue" ops all produce index results that
// are consumed later. The directives expect the single output scf.for to
// carry six iter_args and the function to return result #2 of that loop.
// CHECK-LABEL: @fuse_one_level_inouts
// CHECK-SAME: [[LBI:%.*]]: i64, [[UBI:%.*]]: i64, [[STEPI:%.*]]: i64, [[LBJ:%.*]]: i64, [[UBJ:%.*]]: i64, [[STEPJ:%.*]]: i64
// CHECK-SAME: [[INOUT:%.*]]: index
tt.func @fuse_one_level_inouts(%lbi: i64, %ubi: i64, %stepi: i64, %lbj: i64, %ubj: i64, %stepj: i64, %inout: index) -> index {
// CHECK: [[I_INIT:%.*]] = arith.subi [[LBI]], [[STEPI]]
// CHECK: [[OUTER_OUTS:%.*]]:6 = scf.for %{{.*}} = %c0_i64 to [[TOTAL_ITERS:%.*]] step %c1_i64 iter_args(
// CHECK-SAME: [[T_ARG:%arg[0-9]+]] = %c-1_i64,
// CHECK-SAME: [[I_ARG:%arg[0-9]+]] = [[I_INIT]]
// CHECK-SAME: [[M:%arg[0-9]+]] = [[INOUT]]
// CHECK-SAME: [[J_ARG:%arg[0-9]+]] = %c0_i64
// CHECK-SAME: [[K_ARG:%arg[0-9]+]] = %c0
// CHECK-SAME: [[PROLOGUE_OUT_ARG:%arg[0-9]+]] = %c0
// CHECK-SAME: ) -> (i64, i64, index, i64, index, index) : i64 {
%outer_out = scf.for %i = %lbi to %ubi step %stepi iter_args(%m = %inout) -> index : i64 {
// if T == 0:
//   i += stepi
//   prologue(i)
//   j = lbj
//
// CHECK: [[PROLOGUE_OUTS:%.*]]:4 = scf.if %{{[0-9]+}} -> (i64, index, index, i64) {
// CHECK-NEXT: [[I:%.*]] = arith.addi [[I_ARG]], [[STEPI]]
// CHECK-NEXT: [[PROLOGUE_RES:%.*]] = "prologue"([[I]], [[INOUT]], [[M]]) : (i64, index, index) -> index
// CHECK-NEXT: yield [[LBJ]], [[PROLOGUE_RES]], [[M]], [[I]]
// CHECK-NEXT: } else {
// CHECK-NEXT: yield [[J_ARG]], [[PROLOGUE_OUT_ARG]], [[K_ARG]], [[I_ARG]]
// CHECK-NEXT: }
//
// J := [[PROLOGUE_OUTS]]#0
// PROLOGUE_OUT := [[PROLOGUE_OUTS]]#1
// K := [[PROLOGUE_OUTS]]#2
// I := [[PROLOGUE_OUTS]]#3
%prologue_out = "prologue"(%i, %inout, %m) : (i64, index, index) -> index
// if T >= 0 and T < len_j:
//   body(i, j, ...)
//   j += stepj
//
// CHECK: [[BODY_OUTS:%.*]]:2 = scf.if {{.*}} -> (i64, index) {
// CHECK-NEXT: [[BODY_OUT:%.*]] = "body"([[PROLOGUE_OUTS]]#3, [[PROLOGUE_OUTS]]#0, [[PROLOGUE_OUTS]]#2, [[PROLOGUE_OUTS]]#1, [[M]]) : (i64, i64, index, index, index) -> index
// CHECK-NEXT: [[J_INCR:%.*]] = arith.addi [[PROLOGUE_OUTS]]#0, [[STEPJ]]
// CHECK-NEXT: yield [[J_INCR]], [[BODY_OUT]]
// CHECK-NEXT: } else {
// CHECK-NEXT: yield [[PROLOGUE_OUTS]]#0, [[K_ARG]]
// CHECK-NEXT: }
%inner_out = scf.for %j = %lbj to %ubj step %stepj iter_args(%k = %m) -> index : i64 {
%body_out = "body"(%i, %j, %k, %prologue_out, %m) : (i64, i64, index, index, index) -> index
scf.yield %body_out : index
}
// if T == max(1, len_j) - 1:
//   epilogue(i, ...)
// (i is not advanced here; the increment is in the T == 0 branch above)
//
// CHECK: [[EPILOGUE_OUTS:%.*]] = scf.if {{.*}} -> (index) {
// CHECK-NEXT: [[EPILOGUE_OUT:%.*]] = "epilogue"([[PROLOGUE_OUTS]]#3, [[PROLOGUE_OUTS]]#1, [[BODY_OUTS]]#1, [[M]]) : (i64, index, index, index) -> index
// CHECK-NEXT: yield [[EPILOGUE_OUT]]
// CHECK-NEXT: } else {
// CHECK-NEXT: yield [[M]]
// CHECK-NEXT: }
%epilogue_out = "epilogue"(%i, %prologue_out, %inner_out, %m) : (i64, index, index, index) -> index
// CHECK-NEXT: yield %{{.*}}, [[PROLOGUE_OUTS]]#3, [[EPILOGUE_OUTS]], [[BODY_OUTS]]#0, [[BODY_OUTS]]#1, [[PROLOGUE_OUTS]]#1 : i64, i64, index, i64, index, index
scf.yield %epilogue_out : index
} {"ttg.always-fuse"}
// CHECK: return [[OUTER_OUTS]]#2
tt.return %outer_out : index
}
// Three sibling inner loops inside one outer loop. Each inner loop is
// preceded by its own "prologueN" op, whose result is that loop's initial
// iter_arg, and the last loop's result feeds "epilogue". The directives expect
// a single scf.for with twelve iter_args, one prologue scf.if and one body
// scf.if per inner loop, and a single epilogue scf.if.
// CHECK-LABEL: @multiple_loops
tt.func @multiple_loops(
// CHECK-SAME: [[LBI:%arg[0-9]+]]: i64, [[UBI:%arg[0-9]+]]: i64, [[STEPI:%arg[0-9]+]]: i64,
// CHECK-SAME: [[LBJ0:%arg[0-9]+]]: i64, [[UBJ0:%arg[0-9]+]]: i64, [[STEPJ0:%arg[0-9]+]]: i64,
// CHECK-SAME: [[LBJ1:%arg[0-9]+]]: i64, [[UBJ1:%arg[0-9]+]]: i64, [[STEPJ1:%arg[0-9]+]]: i64,
// CHECK-SAME: [[LBJ2:%arg[0-9]+]]: i64, [[UBJ2:%arg[0-9]+]]: i64, [[STEPJ2:%arg[0-9]+]]: i64,
// CHECK-SAME: [[M0:%arg[0-9]+]]: f32
%lbi: i64, %ubi: i64, %stepi: i64,
%lbj0: i64, %ubj0: i64, %stepj0: i64,
%lbj1: i64, %ubj1: i64, %stepj1: i64,
%lbj2: i64, %ubj2: i64, %stepj2: i64,
%m0: f32) -> f32 {
// Trip counts of the outer loop and of each inner loop.
// CHECK: [[DIFF_I:%.*]] = arith.subi [[UBI]], [[LBI]]
// CHECK-NEXT: [[LEN_I:%.*]] = arith.ceildivsi [[DIFF_I]], [[STEPI]]
// CHECK-NEXT: [[DIFF_J0:%.*]] = arith.subi [[UBJ0]], [[LBJ0]]
// CHECK-NEXT: [[LEN_J0:%.*]] = arith.ceildivsi [[DIFF_J0]], [[STEPJ0]]
// CHECK-NEXT: [[DIFF_J1:%.*]] = arith.subi [[UBJ1]], [[LBJ1]]
// CHECK-NEXT: [[LEN_J1:%.*]] = arith.ceildivsi [[DIFF_J1]], [[STEPJ1]]
// CHECK-NEXT: [[DIFF_J2:%.*]] = arith.subi [[UBJ2]], [[LBJ2]]
// CHECK-NEXT: [[LEN_J2:%.*]] = arith.ceildivsi [[DIFF_J2]], [[STEPJ2]]
// Running sum of the clamped inner lengths: PLEN(N+1) = PLEN(N) + max(1, len_jN).
// CHECK: [[PLEN0:%.*]] = arith.constant 0 : i64
// CHECK: [[LEN_J0_CLAMP:%.*]] = arith.maxsi %c1_i64, [[LEN_J0]]
// CHECK-NEXT: [[PLEN1:%.*]] = arith.addi [[PLEN0]], [[LEN_J0_CLAMP]]
// CHECK-NEXT: [[LEN_J1_CLAMP:%.*]] = arith.maxsi %c1_i64, [[LEN_J1]]
// CHECK-NEXT: [[PLEN2:%.*]] = arith.addi [[PLEN1]], [[LEN_J1_CLAMP]]
// CHECK-NEXT: [[LEN_J2_CLAMP:%.*]] = arith.maxsi %c1_i64, [[LEN_J2]]
// CHECK-NEXT: [[PLEN3:%.*]] = arith.addi [[PLEN2]], [[LEN_J2_CLAMP]]
// inner_len = PLEN3 - 2; total_iters = len_i * inner_len
// CHECK: [[INNER_LEN:%.*]] = arith.subi [[PLEN3]], %c2_i64
// CHECK-NEXT: [[TOTAL_ITERS:%.*]] = arith.muli [[LEN_I]], [[INNER_LEN]]
// CHECK: [[I_INIT:%.*]] = arith.subi [[LBI]], [[STEPI]]
// CHECK: [[OUTS:%.*]]:12 = scf.for %{{.*}} = %c0_i64 to [[TOTAL_ITERS]] step %c1_i64 iter_args(
// CHECK-SAME: [[T_ARG:%arg[0-9]+]] = %c-1_i64,
// CHECK-SAME: [[I_ARG:%arg[0-9]+]] = [[I_INIT]],
// CHECK-SAME: [[M:%arg[0-9]+]] = [[M0]],
// CHECK-SAME: [[J0_ARG:%arg[0-9]+]] = %c0_i64,
// CHECK-SAME: [[J1_ARG:%arg[0-9]+]] = %c0_i64,
// CHECK-SAME: [[J2_ARG:%arg[0-9]+]] = %c0_i64,
// CHECK-SAME: [[BODY0_ARG:%arg[0-9]+]] = %cst,
// CHECK-SAME: [[BODY1_ARG:%arg[0-9]+]] = %cst,
// CHECK-SAME: [[BODY2_ARG:%arg[0-9]+]] = %cst,
// CHECK-SAME: [[PROLOGUE0_ARG:%arg[0-9]+]] = %cst,
// CHECK-SAME: [[PROLOGUE1_ARG:%arg[0-9]+]] = %cst,
// CHECK-SAME: [[PROLOGUE2_ARG:%arg[0-9]+]] = %cst)
%mN = scf.for %i = %lbi to %ubi step %stepi iter_args(%m = %m0) -> f32 : i64 {
// T = 0 if T == (inner_len - 1) else T + 1
// CHECK: [[T_PLUS_1:%.*]] = arith.addi [[T_ARG]], %c1_i64
// CHECK-NEXT: [[T_END:%.*]] = arith.subi [[INNER_LEN]], %c1_i64
// CHECK-NEXT: [[ROLLOVER:%.*]] = arith.cmpi eq, [[T_ARG]], [[T_END]]
// CHECK-NEXT: [[T:%.*]] = arith.select [[ROLLOVER]], %c0_i64, [[T_PLUS_1]]
// Prologue of inner loop 0; this is also where i is advanced by stepi.
// CHECK: [[START0:%.*]] = arith.subi [[PLEN0]], %c0_i64
// CHECK-NEXT: [[PROLOGUE_COND0:%.*]] = arith.cmpi eq, [[T]], [[START0]]
// CHECK-NEXT: [[PROLOGUE0_OUTS:%.*]]:4 = scf.if [[PROLOGUE_COND0]]
// CHECK-NEXT: [[I:%.*]] = arith.addi [[I_ARG]], [[STEPI]]
// CHECK-NEXT: [[RES:%.*]] = "prologue0"([[I]], [[M]])
// CHECK-NEXT: yield [[LBJ0]], [[RES]], [[RES]], [[I]]
// CHECK-NEXT: else
// CHECK-NEXT: yield [[J0_ARG]], [[PROLOGUE0_ARG]], [[BODY0_ARG]], [[I_ARG]]
%k00 = "prologue0"(%i, %m) : (i64, f32) -> f32
// Body of inner loop 0, active while START0 <= T < START0 + len_j0.
// CHECK: [[END0:%.*]] = arith.addi [[START0]], [[LEN_J0]]
// CHECK-NEXT: [[GE0:%.*]] = arith.cmpi sge, [[T]], [[START0]]
// CHECK-NEXT: [[LT0:%.*]] = arith.cmpi slt, [[T]], [[END0]]
// CHECK-NEXT: [[BODY_COND0:%.*]] = arith.andi [[GE0]], [[LT0]]
// CHECK-NEXT: [[BODY0_OUTS:%.*]]:2 = scf.if [[BODY_COND0]]
// CHECK-NEXT: [[RES:%.*]] = "body0"([[PROLOGUE0_OUTS]]#3, [[PROLOGUE0_OUTS]]#0, [[PROLOGUE0_OUTS]]#2)
// CHECK-NEXT: [[NEXT_J0:%.*]] = arith.addi [[PROLOGUE0_OUTS]]#0, [[STEPJ0]]
// CHECK-NEXT: yield [[NEXT_J0]], [[RES]]
// CHECK-NEXT: else
// CHECK-NEXT: yield [[PROLOGUE0_OUTS]]#0, [[BODY0_ARG]]
%k0N = scf.for %j0 = %lbj0 to %ubj0 step %stepj0 iter_args(%k0 = %k00) -> f32 : i64 {
%res = "body0"(%i, %j0, %k0) : (i64, i64, f32) -> f32
scf.yield %res : f32
}
// Prologue of inner loop 1, consuming the result of inner loop 0.
// CHECK: [[START1:%.*]] = arith.subi [[PLEN1]], %c1_i64
// CHECK-NEXT: [[PROLOGUE_COND1:%.*]] = arith.cmpi eq, [[T]], [[START1]]
// CHECK-NEXT: [[PROLOGUE1_OUTS:%.*]]:3 = scf.if [[PROLOGUE_COND1]]
// CHECK-NEXT: [[RES:%.*]] = "prologue1"([[PROLOGUE0_OUTS]]#3, [[BODY0_OUTS]]#1)
// CHECK-NEXT: yield [[LBJ1]], [[RES]], [[RES]]
// CHECK-NEXT: else
// CHECK-NEXT: yield [[J1_ARG]], [[PROLOGUE1_ARG]], [[BODY1_ARG]]
%k10 = "prologue1"(%i, %k0N) : (i64, f32) -> f32
// Body of inner loop 1, active while START1 <= T < START1 + len_j1.
// CHECK: [[END1:%.*]] = arith.addi [[START1]], [[LEN_J1]]
// CHECK-NEXT: [[GE1:%.*]] = arith.cmpi sge, [[T]], [[START1]]
// CHECK-NEXT: [[LT1:%.*]] = arith.cmpi slt, [[T]], [[END1]]
// CHECK-NEXT: [[BODY_COND1:%.*]] = arith.andi [[GE1]], [[LT1]]
// CHECK-NEXT: [[BODY1_OUTS:%.*]]:2 = scf.if [[BODY_COND1]]
// CHECK-NEXT: [[RES:%.*]] = "body1"([[PROLOGUE0_OUTS]]#3, [[PROLOGUE1_OUTS]]#0, [[PROLOGUE1_OUTS]]#2)
// CHECK-NEXT: [[NEXT_J1:%.*]] = arith.addi [[PROLOGUE1_OUTS]]#0, [[STEPJ1]]
// CHECK-NEXT: yield [[NEXT_J1]], [[RES]]
// CHECK-NEXT: else
// CHECK-NEXT: yield [[PROLOGUE1_OUTS]]#0, [[BODY1_ARG]]
%k1N = scf.for %j1 = %lbj1 to %ubj1 step %stepj1 iter_args(%k1 = %k10) -> f32 : i64 {
%res = "body1"(%i, %j1, %k1) : (i64, i64, f32) -> f32
scf.yield %res : f32
}
// Prologue of inner loop 2, consuming the result of inner loop 1.
// CHECK: [[START2:%.*]] = arith.subi [[PLEN2]], %c2_i64
// CHECK-NEXT: [[PROLOGUE_COND2:%.*]] = arith.cmpi eq, [[T]], [[START2]]
// CHECK-NEXT: [[PROLOGUE2_OUTS:%.*]]:3 = scf.if [[PROLOGUE_COND2]]
// CHECK-NEXT: [[RES:%.*]] = "prologue2"([[PROLOGUE0_OUTS]]#3, [[BODY1_OUTS]]#1)
// CHECK-NEXT: yield [[LBJ2]], [[RES]], [[RES]]
// CHECK-NEXT: else
// CHECK-NEXT: yield [[J2_ARG]], [[PROLOGUE2_ARG]], [[BODY2_ARG]]
%k20 = "prologue2"(%i, %k1N) : (i64, f32) -> f32
// Body of inner loop 2, active while START2 <= T < START2 + len_j2.
// CHECK: [[END2:%.*]] = arith.addi [[START2]], [[LEN_J2]]
// CHECK-NEXT: [[GE2:%.*]] = arith.cmpi sge, [[T]], [[START2]]
// CHECK-NEXT: [[LT2:%.*]] = arith.cmpi slt, [[T]], [[END2]]
// CHECK-NEXT: [[BODY_COND2:%.*]] = arith.andi [[GE2]], [[LT2]]
// CHECK-NEXT: [[BODY2_OUTS:%.*]]:2 = scf.if [[BODY_COND2]]
// CHECK-NEXT: [[RES:%.*]] = "body2"([[PROLOGUE0_OUTS]]#3, [[PROLOGUE2_OUTS]]#0, [[PROLOGUE2_OUTS]]#2)
// CHECK-NEXT: [[NEXT_J2:%.*]] = arith.addi [[PROLOGUE2_OUTS]]#0, [[STEPJ2]]
// CHECK-NEXT: yield [[NEXT_J2]], [[RES]]
// CHECK-NEXT: else
// CHECK-NEXT: yield [[PROLOGUE2_OUTS]]#0, [[BODY2_ARG]]
%k2N = scf.for %j2 = %lbj2 to %ubj2 step %stepj2 iter_args(%k2 = %k20) -> f32 : i64 {
%res = "body2"(%i, %j2, %k2) : (i64, i64, f32) -> f32
scf.yield %res : f32
}
// Epilogue, run when T == inner_len - 1; its result becomes the carried M.
// CHECK: [[EPILOGUE_COND:%.*]] = arith.cmpi eq, [[T]], [[T_END]]
// CHECK-NEXT: [[EPILOGUE_OUTS:%.*]] = scf.if [[EPILOGUE_COND]]
// CHECK-NEXT: [[RES:%.*]] = "epilogue"([[PROLOGUE0_OUTS]]#3, [[BODY2_OUTS]]#1)
// CHECK-NEXT: yield [[RES]]
// CHECK-NEXT: else
// CHECK-NEXT: yield [[M]]
%out = "epilogue"(%i, %k2N) : (i64, f32) -> f32
// CHECK: scf.yield [[T]], [[PROLOGUE0_OUTS]]#3, [[EPILOGUE_OUTS]],
// CHECK-SAME: [[BODY0_OUTS]]#0, [[BODY1_OUTS]]#0, [[BODY2_OUTS]]#0,
// CHECK-SAME: [[PROLOGUE0_OUTS]]#1, [[PROLOGUE1_OUTS]]#1, [[PROLOGUE2_OUTS]]#1 :
scf.yield %out : f32
} {"ttg.always-fuse"}
// CHECK: return [[OUTS]]#2
tt.return %mN : f32
}
// Two independent two-level loop nests in the same function, both tagged
// "ttg.always-fuse". The directives expect exactly two scf.for ops in the
// output, i.e. one per nest, and no further scf.for before the return.
// CHECK-LABEL: @two_loop_nests
tt.func @two_loop_nests(%lbi: i64, %ubi: i64, %stepi: i64, %lbj: i64, %ubj: i64, %stepj: i64) {
// CHECK-COUNT-2: scf.for
scf.for %i = %lbi to %ubi step %stepi : i64 {
scf.for %j = %lbj to %ubj step %stepj : i64 {
"body"(%i, %j) : (i64, i64) -> ()
}
} {"ttg.always-fuse"}
scf.for %i = %lbi to %ubi step %stepi : i64 {
scf.for %j = %lbj to %ubj step %stepj : i64 {
"body"(%i, %j) : (i64, i64) -> ()
}
} {"ttg.always-fuse"}
// CHECK-NOT: scf.for
// CHECK: tt.return
tt.return
}
// The inner loop's lower bound, upper bound and step are computed inside the
// outer loop body, but only from function arguments (never from %i). The
// directives expect those three arith.addi ops to be the first ops of the
// function, ahead of the trip-count computations and of the scf.for.
// CHECK-LABEL: @hoist_loop_bound_computations
// CHECK-SAME: [[LBI:%.*]]: i64, [[UBI:%.*]]: i64, [[STEPI:%.*]]: i64
tt.func @hoist_loop_bound_computations(%lbi: i64, %ubi: i64, %stepi: i64) {
// CHECK-NEXT: [[LBJ:%.*]] = arith.addi [[LBI]], [[STEPI]]
// CHECK-NEXT: [[UBJ:%.*]] = arith.addi [[UBI]], [[STEPI]]
// CHECK-NEXT: [[STEPJ:%.*]] = arith.addi [[STEPI]], [[STEPI]]
// CHECK-NEXT: [[DIFF_I:%.*]] = arith.subi [[UBI]], [[LBI]]
// CHECK-NEXT: [[LEN_I:%.*]] = arith.ceildivsi [[DIFF_I]], [[STEPI]]
// CHECK-NEXT: [[DIFF_J:%.*]] = arith.subi [[UBJ]], [[LBJ]]
// CHECK-NEXT: [[LEN_J:%.*]] = arith.ceildivsi [[DIFF_J]], [[STEPJ]]
// CHECK: scf.for
scf.for %i = %lbi to %ubi step %stepi : i64 {
%lbj = arith.addi %lbi, %stepi : i64
%ubj = arith.addi %ubi, %stepi : i64
%stepj = arith.addi %stepi, %stepi : i64
// The hoisted values are then used inside the loop: j restarts from LBJ and
// advances by STEPJ.
// CHECK: [[J:%.*]]:2 = scf.if
// CHECK: yield [[LBJ]]
// CHECK: scf.if
// CHECK-NEXT: "body"
// CHECK-NEXT: arith.addi [[J]]#0, [[STEPJ]]
scf.for %j = %lbj to %ubj step %stepj : i64 {
"body"(%i, %j) : (i64, i64) -> ()
}
} {"ttg.always-fuse"}
tt.return
}
// The inner loop's upper bound is computed from the outer induction variable
// %i, so it differs per outer iteration. The directives expect both scf.for
// ops to still be present in the output, and no third one.
// CHECK-LABEL: @cannot_fuse
tt.func @cannot_fuse(%lbi: i64, %ubi: i64, %stepi: i64) {
// CHECK-COUNT-2: scf.for
scf.for %i = %lbi to %ubi step %stepi : i64 {
%lbj = arith.addi %lbi, %stepi : i64
// Upper bound depends on %i.
%ubj = arith.addi %ubi, %i : i64
%stepj = arith.addi %stepi, %stepi : i64
scf.for %j = %lbj to %ubj step %stepj : i64 {
"body"(%i, %j) : (i64, i64) -> ()
}
} {"ttg.always-fuse"}
// CHECK-NOT: scf.for
tt.return
}
// The outer loop iterates over i32 while the inner loop iterates over i16.
// The directives expect each trip count to be computed in its own type and
// the i16 inner length to be sign-extended to i32 afterwards.
// CHECK-LABEL: @upcast_i16_to_i32
// CHECK-SAME: [[LBI:%.*]]: i32, [[UBI:%.*]]: i32, [[STEPI:%.*]]: i32, [[LBJ:%.*]]: i16, [[UBJ:%.*]]: i16, [[STEPJ:%.*]]: i16
tt.func @upcast_i16_to_i32(%lbi: i32, %ubi: i32, %stepi: i32, %lbj: i16, %ubj: i16, %stepj: i16) {
// CHECK-NEXT: [[DIFF_I:%.*]] = arith.subi [[UBI]], [[LBI]] : i32
// CHECK-NEXT: [[LEN_I:%.*]] = arith.ceildivsi [[DIFF_I]], [[STEPI]] : i32
// CHECK-NEXT: [[DIFF_J:%.*]] = arith.subi [[UBJ]], [[LBJ]] : i16
// CHECK-NEXT: [[LEN_J:%.*]] = arith.ceildivsi [[DIFF_J]], [[STEPJ]] : i16
// CHECK: arith.extsi [[LEN_J]] : i16 to i32
scf.for %i = %lbi to %ubi step %stepi : i32 {
scf.for %j = %lbj to %ubj step %stepj : i16 {
"body"(%i, %j) : (i32, i16) -> ()
}
} {"ttg.always-fuse"}
tt.return
}
// Both loops iterate over `index`. The directives expect the trip counts to
// be computed in `index` and then both converted to i64 with
// arith.index_cast, the inner length first.
// CHECK-LABEL: @upcast_index_to_i64
// CHECK-SAME: [[LBI:%.*]]: index, [[UBI:%.*]]: index, [[STEPI:%.*]]: index, [[LBJ:%.*]]: index, [[UBJ:%.*]]: index, [[STEPJ:%.*]]: index
tt.func @upcast_index_to_i64(%lbi: index, %ubi: index, %stepi: index, %lbj: index, %ubj: index, %stepj: index) {
// CHECK-NEXT: [[DIFF_I:%.*]] = arith.subi [[UBI]], [[LBI]] : index
// CHECK-NEXT: [[LEN_I:%.*]] = arith.ceildivsi [[DIFF_I]], [[STEPI]] : index
// CHECK-NEXT: [[DIFF_J:%.*]] = arith.subi [[UBJ]], [[LBJ]] : index
// CHECK-NEXT: [[LEN_J:%.*]] = arith.ceildivsi [[DIFF_J]], [[STEPJ]] : index
// CHECK: arith.index_cast [[LEN_J]] : index to i64
// CHECK: arith.index_cast [[LEN_I]] : index to i64
scf.for %i = %lbi to %ubi step %stepi {
scf.for %j = %lbj to %ubj step %stepj {
"body"(%i, %j) : (index, index) -> ()
}
} {"ttg.always-fuse"}
tt.return
}
// A three-deep loop nest (i, j, k) with the attribute on the outermost loop
// only. The directives expect exactly one scf.for in the output before the
// return.
// CHECK-LABEL: @triple_loop_nest
tt.func @triple_loop_nest(
%lbi: i64, %ubi: i64, %stepi: i64,
%lbj: i64, %ubj: i64, %stepj: i64,
%lbk: i64, %ubk: i64, %stepk: i64) {
// CHECK-COUNT-1: scf.for
scf.for %i = %lbi to %ubi step %stepi : i64 {
scf.for %j = %lbj to %ubj step %stepj : i64 {
scf.for %k = %lbk to %ubk step %stepk : i64 {
"body"(%i, %j, %k) : (i64, i64, i64) -> ()
}
}
} {"ttg.always-fuse"}
// CHECK-NOT: scf.for
// CHECK: tt.return
tt.return
}
// The two inner loops carry tt.num_stages = 4 and 5, while the outer loop
// carries tt.num_stages = 6 together with "tt.disallow_acc_multi_buffer".
// The directives expect a single scf.for in the output and expect both
// outer-loop attributes (including the value 6) to be printed after it.
// CHECK-LABEL: @preserve_stage_count
tt.func @preserve_stage_count(%lb: i32, %ub: i32) {
%c1_i32 = arith.constant 1 : i32
// CHECK-COUNT-1: scf.for
scf.for %i = %lb to %ub step %c1_i32 : i32 {
scf.for %j = %lb to %ub step %c1_i32 : i32 {
"body"(%j) : (i32) -> ()
scf.yield
} {tt.num_stages = 4 : i32}
scf.for %j = %lb to %ub step %c1_i32 : i32 {
"body"(%j) : (i32) -> ()
scf.yield
} {tt.num_stages = 5 : i32}
} {"ttg.always-fuse", "tt.disallow_acc_multi_buffer", tt.num_stages = 6 : i32}
// CHECK: tt.disallow_acc_multi_buffer
// CHECK: tt.num_stages = 6 : i32
// CHECK-NOT: scf.for
tt.return
}
// The outer loop is tagged with tt.flatten rather than "ttg.always-fuse".
// The directives expect a guard that tests whether the inner trip count is
// zero:
//   - the "then" side keeps a loop over [LB, UB) that holds only "prologue"
//     and is closed on the very next line;
//   - the "else" side holds exactly one scf.for, in which the inner body is
//     guarded by `scf.if %true`.
// CHECK-LABEL: @fuse_attr_speculate
// CHECK-SAME: [[LB:%.*]]: i32, [[UB:%.*]]: i32
tt.func @fuse_attr_speculate(%lb: i32, %ub: i32) {
%c1_i32 = arith.constant 1 : i32
// CHECK: [[DIFF:%.*]] = arith.subi [[UB]], [[LB]]
// CHECK: [[LEN:%.*]] = arith.ceildivsi [[DIFF]], %c1_i32
// CHECK: [[IS_ZERO:%.*]] = arith.cmpi eq, [[LEN]], %c0_i32
// CHECK: scf.if [[IS_ZERO]]
// CHECK-NEXT: scf.for %{{.*}} = [[LB]] to [[UB]] step %c1_i32
// CHECK-NEXT: "prologue"
// CHECK-NEXT: }
// CHECK: else
// CHECK-COUNT-1: scf.for
// CHECK-NOT: scf.for
scf.for %i = %lb to %ub step %c1_i32 : i32 {
// CHECK: "prologue"
"prologue"(%i) : (i32) -> ()
// CHECK: scf.if %true
scf.for %j = %lb to %ub step %c1_i32 : i32 {
// CHECK-NEXT: "body"
"body"(%i, %j) : (i32, i32) -> ()
scf.yield
}
} {tt.flatten}
tt.return
}
// tt.flatten variant in which the inner loop's upper bound is computed inside
// the outer loop, after "prologue", from function arguments only. The
// directives expect that arith.addi to appear ahead of the trip-count
// computation and of the zero-length test that feeds the guarding scf.if.
// CHECK-LABEL: @speculate_hoist
// CHECK-SAME: [[LB:%.*]]: i32, [[UB:%.*]]: i32
tt.func @speculate_hoist(%lb: i32, %ub: i32) {
%c1_i32 = arith.constant 1 : i32
// CHECK: [[UBJ:%.*]] = arith.addi [[LB]], [[UB]]
// CHECK: [[DIFF:%.*]] = arith.subi [[UBJ]], [[LB]]
// CHECK: [[LEN:%.*]] = arith.ceildivsi [[DIFF]], %c1_i32
// CHECK: [[IS_ZERO:%.*]] = arith.cmpi eq, [[LEN]], %c0_i32
// CHECK: scf.if [[IS_ZERO]]
scf.for %i = %lb to %ub step %c1_i32 : i32 {
"prologue"(%i) : (i32) -> ()
%ubj = arith.addi %lb, %ub : i32
scf.for %j = %lb to %ubj step %c1_i32 : i32 {
"body"(%i, %j) : (i32, i32) -> ()
scf.yield
}
} {tt.flatten}
tt.return
}
// The first arith.addi in the outer loop body sits ahead of the inner loop,
// but its result is consumed only after the inner loop (by the second
// arith.addi and by the loop yield). The directives expect that first addi to
// show up inside the last scf.if, immediately before the addi that feeds
// "epilogue".
// CHECK-LABEL: @sink_prologue_to_epilogue
// CHECK-SAME: [[UB:%.*]]: i32
tt.func @sink_prologue_to_epilogue(%ub: i32) {
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
// CHECK: else
// CHECK: scf.for
%0 = scf.for %i = %c0_i32 to %ub step %c1_i32 iter_args(%k = %c0_i32) -> i32 : i32 {
// CHECK: [[PROLOGUE_OUTS:%.*]]:2 = scf.if
%0 = arith.addi %i, %ub : i32
// CHECK: scf.if %true
// CHECK-NEXT: "body"
scf.for %j = %c0_i32 to %ub step %c1_i32 : i32 {
"body"(%i, %j) : (i32, i32) -> ()
scf.yield
}
// CHECK: scf.if
// CHECK-NEXT: [[V0:%.*]] = arith.addi [[PROLOGUE_OUTS]]#1, [[UB]]
// CHECK-NEXT: [[V1:%.*]] = arith.addi [[V0]], [[UB]]
%1 = arith.addi %0, %ub : i32
// CHECK-NEXT: "epilogue"([[V1]])
"epilogue"(%1) : (i32) -> ()
scf.yield %0 : i32
} {tt.flatten}
tt.return
}
// -----
// The op tagged {increment} sits ahead of the inner loop and its result is
// the value yielded by the outer loop. The directives expect, inside one
// scf.for: an scf.if followed by the {increment} op, an scf.if whose first op
// is "body", and an scf.if whose "then" side holds only "epilogue".
// CHECK-LABEL: @prologue_output
tt.func @prologue_output(%ub: i32) {
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
// CHECK: scf.for
%0 = scf.for %i = %c0_i32 to %ub step %c1_i32 iter_args(%k = %c0_i32) -> i32 : i32 {
// CHECK: scf.if
// CHECK: {increment}
%next = arith.addi %k, %c1_i32 {increment} : i32
// CHECK: scf.if
scf.for %j = %c0_i32 to %ub step %c1_i32 : i32 {
// CHECK-NEXT: "body"
"body"(%i, %j) : (i32, i32) -> ()
}
// CHECK: scf.if {{%[0-9]+}} {
// CHECK-NEXT: "epilogue"
"epilogue"(%i) : (i32) -> ()
// CHECK-NEXT: } else {
scf.yield %next : i32
} {"ttg.always-fuse"}
tt.return
}