// RUN: triton-opt --split-input-file --nvws-lower-warp-group %s | FileCheck %s

#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 8}>
#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = true, elementBitWidth = 8}>
#shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {

  // A single group requesting 8 warps, while the module runs with 4.
  // The group is expected to be lowered into partition0 of a
  // ttg.warp_specialize op, with its warp count carried over. The %false
  // constant used by the MMA is expected to be re-created inside the partition.
  // CHECK-LABEL: @warp_group
  //       CHECK-NOT: nvws.warp_group
  //       CHECK:   ttg.warp_specialize
  //       CHECK-NEXT:   default
  //       CHECK:   partition0{{.*}} num_warps(8)
  //       CHECK-NEXT:   arith.constant
  //       CHECK-NEXT:   ttng.tc_gen5_mma
  tt.func @warp_group(%a: !ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory>,
                  %b: !ttg.memdesc<128x256xf8E5M2, #shared1, #ttg.shared_memory>,
                  %c: !ttg.memdesc<128x256xf8E5M2, #shared1, #ttg.shared_memory, mutable>,
                  %accUse: i1,
                  %pred: i1,
                  %barrier: !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>) {
    %false = arith.constant false
    nvws.warp_group
    partition0 num_warps(8) {
      ttng.tc_gen5_mma %a, %b, %c, %accUse, %pred, %barrier[%false] {is_async} :
         !ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory>,
         !ttg.memdesc<128x256xf8E5M2, #shared1, #ttg.shared_memory>,
         !ttg.memdesc<128x256xf8E5M2, #shared1, #ttg.shared_memory, mutable>,
         !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>
      nvws.warp_group.return
    }
    tt.return
  }
}

// -----

#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 8}>
#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = true, elementBitWidth = 8}>
#shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {

  // A single group whose warp count (4) equals the module's "ttg.num-warps".
  // Its body is expected to be placed directly in the default region of the
  // resulting ttg.warp_specialize, and no separate partition is expected.
  // CHECK-LABEL: @warp_default
  //       CHECK-NOT: nvws.warp_group
  //       CHECK:   ttg.warp_specialize
  //       CHECK-NEXT:   default
  //       CHECK-NEXT:   ttng.tc_gen5_mma
  //       CHECK-NOT:   partition0
  tt.func @warp_default(%a: !ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory>,
                  %b: !ttg.memdesc<128x256xf8E5M2, #shared1, #ttg.shared_memory>,
                  %c: !ttg.memdesc<128x256xf8E5M2, #shared1, #ttg.shared_memory, mutable>,
                  %accUse: i1,
                  %pred: i1,
                  %barrier: !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>) {
    %false = arith.constant false
    nvws.warp_group
    partition0 num_warps(4) {
      ttng.tc_gen5_mma %a, %b, %c, %accUse, %pred, %barrier[%false] {is_async} :
         !ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory>,
         !ttg.memdesc<128x256xf8E5M2, #shared1, #ttg.shared_memory>,
         !ttg.memdesc<128x256xf8E5M2, #shared1, #ttg.shared_memory, mutable>,
         !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>
      nvws.warp_group.return
    }
    tt.return
  }
}

// -----

// The operands below are f16, so the NVMMA encodings use a 16-bit element width.
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 16}>
#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = true, elementBitWidth = 16}>
#shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
#acc_tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 256, unpacked = true>
#blocked = #ttg.blocked<{sizePerThread = [1, 256], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {

  // Two groups.
  // The first (4 warps, matching the module) issues the MMA and is expected to
  // land in the default region. The second waits on the barrier, loads the
  // accumulator from tensor memory and stores it. It is expected to become
  // partition0, with the captured %c0 constant re-created inside it.
  // CHECK-LABEL: @warp_multiple_group
  //       CHECK-NOT: nvws.warp_group
  //       CHECK:   ttg.warp_specialize(%
  //       CHECK-NEXT:   default
  //       CHECK-NEXT:   ttng.tc_gen5_mma
  //       CHECK:   partition0(%{{.*}} num_warps(4)
  //       CHECK-NEXT:   arith.constant
  //       CHECK-NEXT:   ttg.local_load
  //       CHECK-NEXT:   ttng.wait_barrier
  //       CHECK-NEXT:   ttng.tmem_load
  //       CHECK-NEXT:   tt.store
  //       CHECK-NEXT:   ttg.warp_return
  //       CHECK-NEXT:   }
  tt.func @warp_multiple_group(%a: !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>,
                  %b: !ttg.memdesc<128x256xf16, #shared1, #ttg.shared_memory>,
                  %c: !ttg.memdesc<128x256xf16, #acc_tmem, #ttng.tensor_memory, mutable>,
                  %d: tensor<128x256x!tt.ptr<f16>, #blocked>,
                  %accUse: i1,
                  %pred: i1,
                  %barrier: !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>) {
    %false = arith.constant false
    %c0 = arith.constant 0 : i32
    nvws.warp_group
    partition0 num_warps(4) {
      ttng.tc_gen5_mma %a, %b, %c, %accUse, %pred, %barrier[%false] {is_async} :
         !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>,
         !ttg.memdesc<128x256xf16, #shared1, #ttg.shared_memory>,
         !ttg.memdesc<128x256xf16, #acc_tmem, #ttng.tensor_memory, mutable>,
         !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>
      nvws.warp_group.return
    }
    partition1 num_warps(4) {
      ttng.wait_barrier %barrier, %c0 : !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>
      %c_reg = ttng.tmem_load %c : !ttg.memdesc<128x256xf16, #acc_tmem, #ttng.tensor_memory, mutable> -> tensor<128x256xf16, #blocked>
      tt.store %d, %c_reg : tensor<128x256x!tt.ptr<f16>, #blocked>
      nvws.warp_group.return
    }
    tt.return
  }
}