// RUN: triton-opt %s -split-input-file -tritongpu-combine-tensor-select-and-if | FileCheck %s

#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
// CHECK-LABEL: @select_if_combine
  tt.func public @select_if_combine(%arg0: tensor<64xf32, #blocked>, %dst_ptr: tensor<64x!tt.ptr<f32>, #blocked>, %cnd: i1) attributes {noinline = false} {
    // CHECK: %[[CST0:.*]] = arith.constant dense<0.000000e+00>
    %cst = arith.constant dense<0.000000e+00> : tensor<64xf32, #blocked>
    // CHECK: %[[CST1:.*]] = arith.constant dense<1.000000e+00>
    %cst_1 = arith.constant dense<1.000000e+00> : tensor<64xf32, #blocked>
    // CHECK-NOT: arith.select
    %sel = arith.select %cnd, %cst, %cst_1 : tensor<64xf32, #blocked>
    // CHECK: %[[IF_RES:.*]] = scf.if
    scf.if %cnd {
      tt.store %dst_ptr, %arg0 : tensor<64x!tt.ptr<f32>, #blocked>
    // CHECK: scf.yield %[[CST0]]
    }
    // CHECK: else
    // CHECK: scf.yield %[[CST1]]
    // CHECK: tt.store %{{.*}}, %[[IF_RES]]
    tt.store %dst_ptr, %sel : tensor<64x!tt.ptr<f32>, #blocked>
    tt.return
  }
}

// -----

// CHECK-LABEL: @if_multiple_sel
tt.func @if_multiple_sel(%arg0: i1, %arg1: i32, %arg2: i32, %arg3: f32, %arg4: f32) -> (i32, f32, i32){
// CHECK-NOT: select
// CHECK: %[[R:.+]]:3 = scf.if %{{.*}} -> (i32, i32, f32) {
// CHECK:   scf.yield {{.*}} : i32, i32, f32
// CHECK: } else {
// CHECK:   scf.yield {{.*}} : i32, i32, f32
// CHECK: }
// CHECK: tt.return %[[R]]#1, %[[R]]#2, %[[R]]#0 : i32, f32, i32
  %0 = arith.select %arg0, %arg1, %arg2 : i32
  %1 = arith.select %arg0, %arg3, %arg4 : f32
  %2 = scf.if %arg0 -> (i32) {
    %3 = arith.subi %arg1, %arg2 : i32
    scf.yield %3 : i32
  } else {
    scf.yield %arg1 : i32
  }
  tt.return %0, %1, %2 : i32, f32, i32
}