// RUN: triton-opt --split-input-file %s -triton-licm | FileCheck %s
tt.func @hoist_load_without_mask(%arg0: tensor<1024x!tt.ptr<f32>>, %arg1: tensor<1024xi32>, %arg2: tensor<1024xi32>, %arg3: i32, %arg4 : i32, %arg5: tensor<1024x!tt.ptr<f32>>) {
  %zeros = arith.constant dense<0.000000e+00> : tensor<1024xf32>
  %step = arith.constant 1 : i32
  // The loop below loads from %arg0 without a mask on every iteration. The
  // directives expect that load to appear in front of the scf.for, masked by a
  // splat of the `lower bound < upper bound` comparison, with its user addf
  // following it and no tt.load left after the loop header.
  // CHECK-LABEL: hoist_load_without_mask
  // CHECK: %[[TRIP_COUNT_CMP:.*]] = arith.cmpi slt, %[[LB:.*]], %[[UB:.*]]
  // CHECK: %[[SPLAT:.*]] = tt.splat %[[TRIP_COUNT_CMP]]
  // CHECK: %[[LOAD:.*]] = tt.load %[[_:.*]], %[[SPLAT]]
  // CHECK: arith.addf %[[LOAD]], %[[LOAD]]
  // CHECK: scf.for
  // CHECK-NOT: tt.load
  %sum = scf.for %iv = %arg3 to %arg4 step %step iter_args(%acc = %zeros) -> (tensor<1024xf32>) : i32 {
    %val = tt.load %arg0 : tensor<1024x!tt.ptr<f32>>
    %twice = arith.addf %val, %val : tensor<1024xf32>
    %next = arith.addf %acc, %twice : tensor<1024xf32>
    scf.yield %next : tensor<1024xf32>
  }
  tt.store %arg5, %sum : tensor<1024x!tt.ptr<f32>>
  tt.return
}
// -----
tt.func @hoist_two_loads_without_mask(%arg0: tensor<1024x!tt.ptr<f32>>, %arg1: tensor<1024xi32>, %arg2: tensor<1024xi32>, %arg3: i32, %arg4 : i32, %arg5: tensor<1024x!tt.ptr<f32>>, %arg6: tensor<1024x!tt.ptr<f32>>) {
  %zeros = arith.constant dense<0.000000e+00> : tensor<1024xf32>
  %step = arith.constant 1 : i32
  // Two unmasked loads (from %arg0 and %arg6) sit in the loop body. The
  // directives expect each one ahead of the scf.for with its own
  // cmpi + splat mask built from the same loop bounds, followed by the addf
  // that combines them, and no tt.load after the loop header.
  // CHECK-LABEL: hoist_two_loads_without_mask
  // CHECK: %[[TRIP_COUNT_CMP_1:.*]] = arith.cmpi slt, %[[LB:.*]], %[[UB:.*]]
  // CHECK: %[[SPLAT_1:.*]] = tt.splat %[[TRIP_COUNT_CMP_1]]
  // CHECK: %[[LOAD_1:.*]] = tt.load %[[_:.*]], %[[SPLAT_1]]
  // CHECK: %[[TRIP_COUNT_CMP_2:.*]] = arith.cmpi slt, %[[LB]], %[[UB]]
  // CHECK: %[[SPLAT_2:.*]] = tt.splat %[[TRIP_COUNT_CMP_2]]
  // CHECK: %[[LOAD_2:.*]] = tt.load %[[_:.*]], %[[SPLAT_2]]
  // CHECK: arith.addf %[[LOAD_1]], %[[LOAD_2]]
  // CHECK: scf.for
  // CHECK-NOT: tt.load
  %sum = scf.for %iv = %arg3 to %arg4 step %step iter_args(%acc = %zeros) -> (tensor<1024xf32>) : i32 {
    %lhs = tt.load %arg0 : tensor<1024x!tt.ptr<f32>>
    %rhs = tt.load %arg6 : tensor<1024x!tt.ptr<f32>>
    %pair = arith.addf %lhs, %rhs : tensor<1024xf32>
    %next = arith.addf %acc, %pair : tensor<1024xf32>
    scf.yield %next : tensor<1024xf32>
  }
  tt.store %arg5, %sum : tensor<1024x!tt.ptr<f32>>
  tt.return
}
// -----
tt.func @hoist_load_with_mask(%arg0: tensor<1024x!tt.ptr<f32>>, %arg1: tensor<1024xi32>, %arg2: tensor<1024xi32>, %arg3: i32, %arg4 : i32, %arg5: tensor<1024x!tt.ptr<f32>>) {
  %zeros = arith.constant dense<0.000000e+00> : tensor<1024xf32>
  %step = arith.constant 1 : i32
  // The in-loop load already carries a mask computed outside the loop. The
  // directives expect the load in front of the scf.for, using the andi of the
  // splatted loop-bounds comparison with that existing mask, and no tt.load
  // after the loop header.
  // CHECK-LABEL: hoist_load_with_mask
  // CHECK: %[[MASK:.*]] = arith.cmpi
  // CHECK: %[[TRIP_COUNT_CMP:.*]] = arith.cmpi slt, %[[LB:.*]], %[[UB:.*]]
  // CHECK: %[[SPLAT:.*]] = tt.splat %[[TRIP_COUNT_CMP]]
  // CHECK: %[[AND:.*]] = arith.andi %[[SPLAT]], %[[MASK]]
  // CHECK: %[[LOAD:.*]] = tt.load %[[_:.*]], %[[AND]]
  // CHECK: arith.addf %[[LOAD]], %[[LOAD]]
  // CHECK: scf.for
  // CHECK-NOT: tt.load
  %mask = arith.cmpi slt, %arg1, %arg2 : tensor<1024xi32>
  %sum = scf.for %iv = %arg3 to %arg4 step %step iter_args(%acc = %zeros) -> (tensor<1024xf32>) : i32 {
    %val = tt.load %arg0, %mask : tensor<1024x!tt.ptr<f32>>
    %twice = arith.addf %val, %val : tensor<1024xf32>
    %next = arith.addf %acc, %twice : tensor<1024xf32>
    scf.yield %next : tensor<1024xf32>
  }
  tt.store %arg5, %sum, %mask : tensor<1024x!tt.ptr<f32>>
  tt.return
}
// -----
tt.func @cannot_hoist_with_print_in_loop(%arg0: tensor<1024x!tt.ptr<f32>>, %arg1: tensor<1024xi32>, %arg2: tensor<1024xi32>, %arg3: i32, %arg4 : i32, %arg5: tensor<1024x!tt.ptr<f32>>) {
  %cst = arith.constant dense<0.000000e+00> : tensor<1024xf32>
  %c1_i32 = arith.constant 1 : i32
  // Negative test: the loop body contains a tt.print. The directives require
  // that no tt.load shows up between the function header and the scf.for, and
  // that the load and both addf ops are still found after the loop header.
  // The label anchors these directives to this function's output so they
  // cannot match text belonging to a neighbouring test case.
  // CHECK-LABEL: cannot_hoist_with_print_in_loop
  // CHECK-NOT: tt.load
  // CHECK: scf.for
  // CHECK: tt.load
  // CHECK: arith.addf
  // CHECK: arith.addf
  %0 = arith.cmpi slt, %arg1, %arg2 : tensor<1024xi32>
  %1 = scf.for %arg7 = %arg3 to %arg4 step %c1_i32 iter_args(%arg6 = %cst) -> (tensor<1024xf32>) : i32 {
    %2 = tt.load %arg0, %0 : tensor<1024x!tt.ptr<f32>>
    %3 = arith.addf %2, %2 : tensor<1024xf32>
    %4 = arith.addf %arg6, %3 : tensor<1024xf32>
    tt.print " x: " {hex = false, isSigned = array<i32: 0>} : %4 : tensor<1024xf32>
    scf.yield %4 : tensor<1024xf32>
  }
  tt.store %arg5, %1, %0 : tensor<1024x!tt.ptr<f32>>
  tt.return
}
// -----
tt.func @cannot_hoist_with_assert_in_loop(%arg0: tensor<1024x!tt.ptr<f32>>, %arg1: tensor<1024xi32>, %arg2: tensor<1024xi32>, %arg3: i32, %arg4 : i32, %arg5: tensor<1024x!tt.ptr<f32>>) {
  %cst = arith.constant dense<0.000000e+00> : tensor<1024xf32>
  %c1_i32 = arith.constant 1 : i32
  // Negative test: the loop body starts with a tt.assert. The directives
  // require that no tt.load shows up between the function header and the
  // scf.for, and that the load and both addf ops are still found after the
  // loop header. The label anchors these directives to this function's output
  // so they cannot match text belonging to a neighbouring test case.
  // CHECK-LABEL: cannot_hoist_with_assert_in_loop
  // CHECK-NOT: tt.load
  // CHECK: scf.for
  // CHECK: tt.load
  // CHECK: arith.addf
  // CHECK: arith.addf
  %0 = arith.cmpi slt, %arg1, %arg2 : tensor<1024xi32>
  %cmp = arith.cmpi sge, %arg4, %arg3 : i32
  %1 = scf.for %arg7 = %arg3 to %arg4 step %c1_i32 iter_args(%arg6 = %cst) -> (tensor<1024xf32>) : i32 {
    tt.assert %cmp, "cond must be true " : i1
    %2 = tt.load %arg0, %0 : tensor<1024x!tt.ptr<f32>>
    %3 = arith.addf %2, %2 : tensor<1024xf32>
    %4 = arith.addf %arg6, %3 : tensor<1024xf32>
    scf.yield %4 : tensor<1024xf32>
  }
  tt.store %arg5, %1, %0 : tensor<1024x!tt.ptr<f32>>
  tt.return
}
// -----
tt.func @cannot_hoist_with_store_in_loop(%arg0: tensor<1024x!tt.ptr<f32>>, %arg1: tensor<1024xi32>, %arg2: tensor<1024xi32>, %arg3: i32, %arg4 : i32, %arg5: tensor<1024x!tt.ptr<f32>>, %tmp: tensor<1024x!tt.ptr<f32>>) {
  %cst = arith.constant dense<0.000000e+00> : tensor<1024xf32>
  %c1_i32 = arith.constant 1 : i32
  // Negative test: the loop body stores to %tmp on every iteration. The
  // directives require that no tt.load shows up between the function header
  // and the scf.for, and that the load and both addf ops are still found after
  // the loop header. The label anchors these directives to this function's
  // output so they cannot match text belonging to a neighbouring test case.
  // CHECK-LABEL: cannot_hoist_with_store_in_loop
  // CHECK-NOT: tt.load
  // CHECK: scf.for
  // CHECK: tt.load
  // CHECK: arith.addf
  // CHECK: arith.addf
  %0 = arith.cmpi slt, %arg1, %arg2 : tensor<1024xi32>
  %1 = scf.for %arg7 = %arg3 to %arg4 step %c1_i32 iter_args(%arg6 = %cst) -> (tensor<1024xf32>) : i32 {
    %2 = tt.load %arg0, %0 : tensor<1024x!tt.ptr<f32>>
    %3 = arith.addf %2, %2 : tensor<1024xf32>
    %4 = arith.addf %arg6, %3 : tensor<1024xf32>
    tt.store %tmp, %4, %0 : tensor<1024x!tt.ptr<f32>>
    scf.yield %4 : tensor<1024xf32>
  }
  tt.store %arg5, %1, %0 : tensor<1024x!tt.ptr<f32>>
  tt.return
}
// -----
tt.func @hoist_cond_no_hoist_load_from_scf_while(%ptr: tensor<1024x!tt.ptr<f32>>, %arg1: i32, %arg2 : i32) {
  %init = arith.constant dense<0.000000e+00> : tensor<1024xf32>
  // The condition region only uses values defined outside the loop, so the
  // directives expect the constant, the addi and the cmpi ahead of scf.while.
  // The tt.load in the `do` region is expected to stay inside the loop,
  // followed by its addf and the yield.
  // CHECK-LABEL: hoist_cond_no_hoist_load_from_scf_while
  // CHECK: %[[CST42:.*]] = arith.constant 42
  // CHECK: %[[ADD:.*]] = arith.addi %[[_:.*]], %[[CST42]]
  // CHECK: %[[COND:.*]] = arith.cmpi slt, %[[ADD]], %[[_:.*]]
  // CHECK: scf.while
  // CHECK: do
  // CHECK: tt.load
  // CHECK: arith.addf
  // CHECK: scf.yield
  %result = scf.while (%carried = %init) : (tensor<1024xf32>) -> (tensor<1024xf32>) {
    %c42 = arith.constant 42 : i32
    %biased = arith.addi %arg1, %c42 : i32
    %keep_going = arith.cmpi slt, %biased, %arg2 : i32
    scf.condition(%keep_going) %carried : tensor<1024xf32>
  } do {
  ^bb0(%forwarded: tensor<1024xf32>):
    // The forwarded value is not read; each iteration recomputes from memory.
    %loaded = tt.load %ptr : tensor<1024x!tt.ptr<f32>>
    %doubled = arith.addf %loaded, %loaded : tensor<1024xf32>
    scf.yield %doubled : tensor<1024xf32>
  }
  tt.store %ptr, %result : tensor<1024x!tt.ptr<f32>>
  tt.return
}