// RUN: triton-opt %s -split-input-file -allow-unregistered-dialect -tritongpu-assign-latencies=num-stages=3 -tritongpu-schedule-loops | FileCheck %s

#AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#BL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1], instrShape = [16, 8]}>
#ALs0 = #ttg.slice<{parent=#AL, dim=0}>
#BLs0 = #ttg.slice<{parent=#BL, dim=0}>
#CLs0 = #ttg.slice<{parent=#C, dim=0}>
#A = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth=2}>
#B = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth=2}>
module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} {
// CHECK-LABLE: @matmul_loop_load_acc
// CHECK: tt.load %{{.*}} {loop.cluster = 3 : i32, loop.stage = 0 : i32}
// CHECK: tt.load %{{.*}} {loop.cluster = 3 : i32, loop.stage = 0 : i32}
// CHECK: tt.load %{{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32}
// CHECK: tt.dot {{.*}} {loop.cluster = 0 : i32, loop.stage = 2 : i32}
tt.func @matmul_loop_load_acc(%lb : index, %ub : index, %step : index,
                  %A : !tt.ptr<f16> {tt.divisibility = 16 : i32},
                  %B : !tt.ptr<f16> {tt.divisibility = 16 : i32},
                  %C : !tt.ptr<f32> {tt.divisibility = 16 : i32},
                  %c_init: tensor<128x128xf32, #C>) -> tensor<128x128xf32, #C> {

  // A ptrs
  %a_ptr_splat = tt.splat %A : !tt.ptr<f16> -> tensor<128x32x!tt.ptr<f16>, #AL>
  %a_tmp0 = tt.make_range {end = 32: i32, start = 0: i32} : tensor<32xi32, #ALs0>
  %a_tmp1 = tt.expand_dims %a_tmp0 {axis = 0 : i32} : tensor<32xi32, #ALs0> -> tensor<1x32xi32, #AL>
  %a_offs = tt.broadcast %a_tmp1 : tensor<1x32xi32, #AL> -> tensor<128x32xi32, #AL>
  %a_ptr_init = tt.addptr %a_ptr_splat, %a_offs : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
  // B ptrs
  %b_ptr_splat = tt.splat %B : !tt.ptr<f16> -> tensor<32x128x!tt.ptr<f16>, #BL>
  %b_tmp0 = tt.make_range {end = 128: i32, start = 0: i32} : tensor<128xi32, #BLs0>
  %b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : tensor<128xi32, #BLs0> -> tensor<1x128xi32, #BL>
  %b_offs = tt.broadcast %b_tmp1 : tensor<1x128xi32, #BL> -> tensor<32x128xi32, #BL>
  %b_ptr_init = tt.addptr %b_ptr_splat, %b_offs : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
  // C ptrs
  %c_ptr_splat = tt.splat %C : !tt.ptr<f32> -> tensor<128x128x!tt.ptr<f32>, #C>
  %c_tmp0 = tt.make_range {end = 128: i32, start = 0: i32} : tensor<128xi32, #CLs0>
  %c_tmp1 = tt.expand_dims %c_tmp0 {axis = 0 : i32} : tensor<128xi32, #CLs0> -> tensor<1x128xi32, #C>
  %c_offs = tt.broadcast %c_tmp1 : tensor<1x128xi32, #C> -> tensor<128x128xi32, #C>
  %c_ptr_init = tt.addptr %c_ptr_splat, %c_offs : tensor<128x128x!tt.ptr<f32>, #C>, tensor<128x128xi32, #C>

  %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL>
  %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL>
  %c_off = arith.constant dense<4> : tensor<128x128xi32, #C>

  %loop:4 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %c_ptr = %c_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128x!tt.ptr<f32>, #C>, tensor<128x128xf32, #C>) {
    %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr<f16>, #AL>
    %a = ttg.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A>
    %b_ = tt.load %b_ptr : tensor<32x128x!tt.ptr<f16>, #BL>
    %b = ttg.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B>
    %c_ = tt.load %c_ptr : tensor<128x128x!tt.ptr<f32>, #C>
    %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>

    %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
    %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
    %next_c_ptr = tt.addptr %c_ptr, %c_off : tensor<128x128x!tt.ptr<f32>, #C>, tensor<128x128xi32, #C>
    scf.yield %next_a_ptr, %next_b_ptr, %next_c_ptr, %c : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128x!tt.ptr<f32>, #C>, tensor<128x128xf32, #C>
  }
  tt.return %loop#3: tensor<128x128xf32, #C>
}
}

// -----

#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 16]}>
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
#smem = #ttg.shared_memory

module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {

// CHECK-LABEL: @fused_loop
tt.func public @fused_loop(%arg5: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}) {
  %c10_i32 = arith.constant 10 : i32
  %false = arith.constant false
  %0 = ub.poison : !tt.tensordesc<tensor<64x256xf16>>
  %cst = arith.constant dense<0> : tensor<128x1xi64, #blocked>
  %c-1_i32 = arith.constant -1 : i32
  %c1_i32 = arith.constant 1 : i32
  %c0_i32 = arith.constant 0 : i32
  %c64_i32 = arith.constant 64 : i32
  %c1_i64 = arith.constant 1 : i64
  %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma>

  %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
  %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
  %3 = arith.extsi %arg7 : i32 to i64
  %4 = tt.make_tensor_descriptor %arg5, [%arg7, %arg7], [%3, %c1_i64] : <f16>, <tensor<64x256xf16>>
  %5 = tt.broadcast %2 : tensor<1x64xi32, #blocked> -> tensor<128x64xi32, #blocked>
  %7 = tt.splat %3 : i64 -> tensor<128x1xi64, #blocked>

  // CHECK: scf.for
  %8:9 = scf.for %arg29 = %c0_i32 to %arg7 step %c1_i32 iter_args(%arg30 = %c-1_i32, %arg31 = %4, %arg32 = %c0_i32, %arg33 = %arg5, %arg34 = %cst_0, %arg35 = %c0_i32, %arg36 = %cst, %arg37 = %0, %arg38 = %false) -> (i32, !tt.tensordesc<tensor<64x256xf16>>, i32, !tt.ptr<f16>, tensor<128x256xf32, #mma>, i32, tensor<128x1xi64, #blocked>, !tt.tensordesc<tensor<64x256xf16>>, i1)  : i32 {
    %9 = arith.addi %arg30, %c1_i32 : i32
    %10 = arith.cmpi eq, %arg30, %c10_i32 : i32
    %11 = arith.select %10, %c0_i32, %9 : i32
    %12 = arith.cmpi eq, %11, %c0_i32 : i32

    // This op is a distance 1 dependency of itself.
    // CHECK: {_test_marker_0, loop.cluster = 4 : i32, loop.stage = 0 : i32}
    %13 = arith.select %12, %c0_i32, %arg32 {_test_marker_0} : i32

    %14 = arith.select %12, %arg31, %arg37 : !tt.tensordesc<tensor<64x256xf16>>
    %15 = arith.select %12, %c10_i32, %arg35 : i32
    %16 = scf.if %12 -> (tensor<128x1xi64, #blocked>) {
      %32 = arith.muli %cst, %7 : tensor<128x1xi64, #blocked>
      scf.yield %32 : tensor<128x1xi64, #blocked>
    } else {
      scf.yield %arg36 : tensor<128x1xi64, #blocked>
    }
    %17 = tt.splat %arg33 : !tt.ptr<f16> -> tensor<128x1x!tt.ptr<f16>, #blocked>
    %18 = tt.addptr %17, %16 : tensor<128x1x!tt.ptr<f16>, #blocked>, tensor<128x1xi64, #blocked>
    %19 = tt.broadcast %18 : tensor<128x1x!tt.ptr<f16>, #blocked> -> tensor<128x64x!tt.ptr<f16>, #blocked>
    %20 = tt.addptr %19, %5 : tensor<128x64x!tt.ptr<f16>, #blocked>, tensor<128x64xi32, #blocked>
    %21 = tt.addptr %arg33, %c64_i32 : !tt.ptr<f16>, i32
    %22 = tt.load %20 : tensor<128x64x!tt.ptr<f16>, #blocked>
    %23 = ttg.local_alloc %22 : (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared, #smem>
    %24 = arith.muli %13, %c64_i32 : i32
    %25 = tt.descriptor_load %14[%24, %15] : !tt.tensordesc<tensor<64x256xf16>> -> tensor<64x256xf16, #blocked1>
    %26 = ttg.local_alloc %25 : (tensor<64x256xf16, #blocked1>) -> !ttg.memdesc<64x256xf16, #shared, #smem>
    %27 = ttng.warp_group_dot %23, %26, %arg34, %arg38 {inputPrecision = 0 : i32} : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x256xf16, #shared, #smem> -> tensor<128x256xf32, #mma>
    %28 = arith.addi %13, %c1_i32 : i32

    // This op is in the backward slice of `_test_marker_2` and the epilogue.
    // CHECK: {_test_marker_1, loop.cluster = 3 : i32, loop.stage = 1 : i32}
    %29 = arith.cmpi eq, %11, %c10_i32 {_test_marker_1} : i32

    // CHECK: {_test_marker_2, loop.cluster = 3 : i32, loop.stage = 1 : i32}
    %30 = arith.select %29, %arg5, %21 {_test_marker_2} : !tt.ptr<f16>

    %31 = arith.cmpi ne, %11, %c10_i32 : i32

    scf.if %29 {
      "use"(%27) : (tensor<128x256xf32, #mma>) -> ()
      // CHECK: {_test_marker_3, loop.cluster = 5 : i32, loop.stage = 2 : i32}
    } {_test_marker_3}
    scf.yield %11, %14, %28, %30, %27, %15, %16, %14, %31 : i32, !tt.tensordesc<tensor<64x256xf16>>, i32, !tt.ptr<f16>, tensor<128x256xf32, #mma>, i32, tensor<128x1xi64, #blocked>, !tt.tensordesc<tensor<64x256xf16>>, i1
  }
  tt.return
}

}

// -----

// CHECK-LABEL: @prologue_backward_slice
tt.func @prologue_backward_slice(%ub: i32, %cond: i1) {
  %c0_i32 = arith.constant 0 : i32
  %c1_i32 = arith.constant 1 : i32

  // CHECK: scf.for
  scf.for %i = %c0_i32 to %ub step %c1_i32 : i32 {
    // CHECK: scf.if
    %0 = scf.if %cond -> i32 {
      scf.yield %c0_i32 : i32
    } else {
      scf.yield %c1_i32 : i32
    }
    // CHECK: loop.cluster = 0 : i32, loop.stage = 0 : i32

    // CHECK: op.with_region
    %1 = "op.with_region"() ({
      "use"(%0) : (i32) -> ()
    }) : () -> i32
    // CHECK: loop.cluster = 1 : i32, loop.stage = 0 : i32

    // CHECK: op.with_region
    "op.with_region"() ({
      "use"(%1) : (i32) -> ()
    }) {tt.latency = 2 : i32} : () -> ()
    // CHECK: loop.cluster = 1 : i32, loop.stage = 0 : i32

  } {tt.num_stages = 3 : i32}

  tt.return
}

// -----

// CHECK-LABEL: @epilogue_forward_slice
tt.func @epilogue_forward_slice(%ub: i32, %cond: i1) {
  %c0_i32 = arith.constant 0 : i32
  %c1_i32 = arith.constant 1 : i32

  // CHECK: scf.for
  scf.for %i = %c0_i32 to %ub step %c1_i32 : i32 {
    // CHECK: "latency.op"() {loop.cluster = 3 : i32, loop.stage = 0 : i32
    %0 = "latency.op"() {tt.latency = 2 : i32} : () -> i32
    // CHECK: scf.if
    %1 = scf.if %cond -> i32 {
      scf.yield %0 : i32
    } else {
      scf.yield %c0_i32 : i32
    }
    // CHECK: {loop.cluster = 1 : i32, loop.stage = 2 : i32}

    // CHECK: "use"(%{{.*}}) {loop.cluster = 1 : i32, loop.stage = 2 : i32}
    "use"(%1) : (i32) -> ()

  } {tt.num_stages = 3 : i32}

  tt.return
}

// -----

// CHECK-LABEL: @prologue_latency
tt.func @prologue_latency(%ub: i32, %cond: i1) {
  %c0_i32 = arith.constant 0 : i32
  %c1_i32 = arith.constant 1 : i32

  // CHECK: scf.for
  scf.for %i = %c0_i32 to %ub step %c1_i32 : i32 {
    // CHECK: "some.op"() {loop.cluster = 0 : i32, loop.stage = 0 : i32}
    %0 = "some.op"() : () -> i32
    // CHECK: scf.if
    %1 = scf.if %cond -> i32 {
      scf.yield %0 : i32
    } else {
      scf.yield %c0_i32 : i32
    } {tt.latency = 2 : i32}
    // CHECK: loop.cluster = 0 : i32, loop.stage = 0 : i32

  } {tt.num_stages = 3 : i32}

  tt.return
}