// RUN: triton-opt %s --triton-nvidia-interleave-tmem --allow-unregistered-dialect | FileCheck %s

// Register layout used in this file for the tensor arguments/results of
// @sink_load and for the results of ttg.convert_layout / arith.truncf.
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>
// Register layout used in this file by every ttng.tmem_load result and every
// ttng.tmem_store source operand.
#blocked1 = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 1], warpsPerCTA = [4, 2], order = [0, 1]}>

// Shared-memory encoding used for the ttg.local_alloc buffer in @sink_load and
// for the barrier argument of @arrive_barrier.
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
#smem = #ttg.shared_memory
// Tensor-memory encoding shared by all TMEM memdescs in this file.
#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 64, unpacked = true>

module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:100"} {

// Exercises placement of TMEM loads (and their layout conversions) relative to
// their users and to intervening ops. The input issues both subslice loads of
// %arg0 up front, and consumes them later with arith.truncf. The expectations
// below describe three groups in the output, each ending at a truncf:
//   1. local_alloc, then the first load + convert_layout, then its truncf;
//   2. the second load + convert_layout, then the tmem_store to %arg2, then
//      its truncf;
//   3. the load of %arg2 + convert_layout, then the unregistered op, then its
//      truncf.
// The label directive scopes these expectations to this function, matching the
// other tests in this file.
// CHECK-LABEL: @sink_load
tt.func public @sink_load(%arg0: !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>,
                          %arg1: tensor<128x128xf16, #blocked>,
                          %arg2: !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>)
                          -> (tensor<128x64xf16, #blocked>, tensor<128x64xf16, #blocked>, tensor<128x128xf16, #blocked>) {

  // Group 1: in the output the local_alloc (written after both loads in the
  // input) comes before the first load/convert pair and its truncf.
  // CHECK: ttg.local_alloc
  // CHECK: ttng.tmem_load
  // CHECK: ttg.convert_layout
  // CHECK: arith.truncf
  %subslice0 = ttng.tmem_subslice %arg0 {N = 0 : i32} : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable>
  %subtile0 = ttng.tmem_load %subslice0 : !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x64xf32, #blocked1>
  %outLHS = ttg.convert_layout %subtile0 : tensor<128x64xf32, #blocked1> -> tensor<128x64xf32, #blocked>
  %subslice1 = ttng.tmem_subslice %arg0 {N = 64 : i32} : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable>
  %subtile1 = ttng.tmem_load %subslice1 : !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x64xf32, #blocked1>
  %outRHS = ttg.convert_layout %subtile1 : tensor<128x64xf32, #blocked1> -> tensor<128x64xf32, #blocked>

  // Group 2: the second load/convert pair stays ahead of the tmem_store to
  // %arg2 even though its truncf user comes after the store.
  // NOTE(review): presumably because %arg0 and %arg2 are both function
  // arguments and may alias -- confirm against the pass implementation.
  // CHECK: ttng.tmem_load
  // CHECK: ttg.convert_layout
  // CHECK: ttng.tmem_store
  // CHECK: arith.truncf
  %4 = ttg.local_alloc %arg1 : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #smem>
  %5 = arith.truncf %outLHS : tensor<128x64xf32, #blocked> to tensor<128x64xf16, #blocked>

  %true = arith.constant true
  %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked1>
  ttng.tmem_store %cst, %arg2, %true : tensor<128x128xf32, #blocked1> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
  %6 = arith.truncf %outRHS : tensor<128x64xf32, #blocked> to tensor<128x64xf16, #blocked>

  // Group 3: the load of %arg2 and its convert_layout remain before the
  // unregistered op, whose effects are unknown to the pass; only the truncf
  // follows it.
  // CHECK: ttng.tmem_load
  // CHECK: ttg.convert_layout
  // CHECK: "unknow_may_side_effect"() : () -> ()
  // CHECK: arith.truncf
  %7 = ttng.tmem_load %arg2 : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked1>
  %8 = ttg.convert_layout %7 : tensor<128x128xf32, #blocked1> -> tensor<128x128xf32, #blocked>
  "unknow_may_side_effect"() : () -> ()
  %9 = arith.truncf %8 : tensor<128x128xf32, #blocked> to tensor<128x128xf16, #blocked>

  ttg.local_dealloc %4 : !ttg.memdesc<128x128xf16, #shared, #smem>
  tt.return %5, %6, %9 : tensor<128x64xf16, #blocked>, tensor<128x64xf16, #blocked>, tensor<128x128xf16, #blocked>
}

// CHECK-LABEL: @interleave_load_store_ws
// Load/scale/store of two accumulator halves inside a warp-specialized
// partition. The input performs both loads and both multiplies before either
// store; the expectations require the output to group each half as
// load -> mulf -> store, with the two tmem_subslice ops hoisted ahead of both
// groups.
tt.func @interleave_load_store_ws() {
  %0 = ttng.tmem_alloc : () -> (!ttg.memdesc<2x128x128xf32, #tmem, #ttng.tensor_memory, mutable>)
  ttg.warp_specialize(%0)
  default{
    ttg.warp_yield
  }
  // CHECK: partition0
  partition0(%arg0: !ttg.memdesc<2x128x128xf32, #tmem, #ttng.tensor_memory, mutable>) num_warps(8) {
    %c0 = arith.constant 0 : i32
    %c1 = arith.constant 1 : i32
    %c32 = arith.constant 32 : i32
    %alpha = arith.constant dense<0.5> : tensor<128x64xf32, #blocked1>
    %true = arith.constant true

    // NOTE(review): %i ranges over [0, 32) while the indexed leading dimension
    // has size 2; this looks irrelevant to the ordering being tested -- confirm.
    // CHECK: scf.for
    scf.for %i = %c0 to %c32 step %c1 : i32 {
      // CHECK: memdesc_index
      %cur_acc = ttg.memdesc_index %arg0[%i] : !ttg.memdesc<2x128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>

      // Both subslices are expected directly after the memdesc_index.
      // CHECK-NEXT: [[S0:%.+]] = ttng.tmem_subslice %{{.+}} {N = 0 : i32}
      // CHECK-NEXT: [[S1:%.+]] = ttng.tmem_subslice %{{.+}} {N = 64 : i32}

      // First half (columns starting at N = 0): load, scale, store back.
      // CHECK-NEXT: [[L0:%.+]] = ttng.tmem_load [[S0]]
      // CHECK-NEXT: [[M0:%.+]] = arith.mulf [[L0]]
      // CHECK-NEXT: ttng.tmem_store [[M0]], [[S0]]
      %slice0 = ttng.tmem_subslice %cur_acc {N = 0 : i32} : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable>
      %val0 = ttng.tmem_load %slice0 : !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x64xf32, #blocked1>
      %mul0 = arith.mulf %val0, %alpha : tensor<128x64xf32, #blocked1>

      // Second half (columns starting at N = 64): load, scale, store back.
      // CHECK-NEXT: [[L1:%.+]] = ttng.tmem_load [[S1]]
      // CHECK-NEXT: [[M1:%.+]] = arith.mulf [[L1]]
      // CHECK-NEXT: ttng.tmem_store [[M1]], [[S1]]
      %slice1 = ttng.tmem_subslice %cur_acc {N = 64 : i32} : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable>
      %val1 = ttng.tmem_load %slice1 : !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x64xf32, #blocked1>
      %mul1 = arith.mulf %val1, %alpha : tensor<128x64xf32, #blocked1>

      // In the input both stores trail both loads; the output is expected to
      // pair each store with its own load/mulf as described above.
      ttng.tmem_store %mul0, %slice0, %true : tensor<128x64xf32, #blocked1> -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable>
      ttng.tmem_store %mul1, %slice1, %true : tensor<128x64xf32, #blocked1> -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable>

    }
    ttg.warp_return
  } : (!ttg.memdesc<2x128x128xf32, #tmem, #ttng.tensor_memory, mutable>) -> ()
  tt.return
}

// CHECK-LABEL: @arrive_barrier
// A tmem_load whose only user ("user") sits after an arrive_barrier. In the
// input the load precedes a tmem_store to a separate allocation. The
// expectations require the output order: both allocs, the store, the load,
// then the arrive_barrier -- i.e. the load ends up after the store to the
// other allocation but still before the barrier arrive.
tt.func @arrive_barrier(%arg0: !ttg.memdesc<1xi64, #shared, #smem, mutable>) {
  %true = arith.constant true
  %cst = arith.constant dense<0.0> : tensor<128x128xf32, #blocked1>

  // CHECK-COUNT-2: ttng.tmem_alloc
  %alloc = ttng.tmem_alloc : () -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
  %noalias_alloc = ttng.tmem_alloc : () -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
  // Load reads %alloc; the store writes %noalias_alloc, a distinct allocation.
  // CHECK-NEXT: tmem_store
  // CHECK-NEXT: tmem_load
  %0 = ttng.tmem_load %alloc : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked1>
  ttng.tmem_store %cst, %noalias_alloc, %true : tensor<128x128xf32, #blocked1> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
  // The barrier arrive must directly follow the load in the output.
  // CHECK-NEXT: arrive_barrier
  ttng.arrive_barrier %arg0, 1 : !ttg.memdesc<1xi64, #shared, #smem, mutable>
  "user"(%0) : (tensor<128x128xf32, #blocked1>) -> ()
  tt.return
}

// CHECK-LABEL: @sink_alloc_op
// Two TMEM allocations, each indexed and then stored to. In the input,
// %alloc0 and %subview0 are defined first but used only by the final store.
// The expectations require the output to contain %alloc1, its memdesc_index
// and its store first, with %alloc0 appearing immediately after that store,
// followed by its own memdesc_index and store.
tt.func @sink_alloc_op(%arg0: tensor<128x128xf32, #blocked1>) {
  %c0 = arith.constant 0 : i32
  %true = arith.constant true

  // Defined early in the input; expected later in the output (see ALLOC0).
  %alloc0 = ttng.tmem_alloc : () -> !ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>
  %subview0 = ttg.memdesc_index %alloc0[%c0] : !ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
  // CHECK: [[ALLOC1:%.+]] = ttng.tmem_alloc
  %alloc1 = ttng.tmem_alloc : () -> !ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>
  // CHECK: [[SUBVIEW1:%.+]] = ttg.memdesc_index [[ALLOC1]]
  %subview1 = ttg.memdesc_index %alloc1[%c0] : !ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
  // CHECK-NEXT: tmem_store %arg0, [[SUBVIEW1]]
  ttng.tmem_store %arg0, %subview1, %true : tensor<128x128xf32, #blocked1> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
  // The first allocation is expected right after the store above.
  // CHECK-NEXT: [[ALLOC0:%.+]] = ttng.tmem_alloc
  // CHECK: [[SUBVIEW0:%.+]] = ttg.memdesc_index [[ALLOC0]]
  // CHECK-NEXT: tmem_store %arg0, [[SUBVIEW0]]
  ttng.tmem_store %arg0, %subview0, %true : tensor<128x128xf32, #blocked1> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
  tt.return
}

}