// RUN: triton-opt %s -split-input-file --tritongpu-warp-spec-task-partition=num-consumer-groups=2 | FileCheck %s

// CHECK-LABEL: @matmul_persistent_tma_ws_cooperative_kernel
// CHECK: %[[#GA:]] = tt.experimental_descriptor_load {{.*}} {async_task_id = dense<0> : vector<1xi32>}
// CHECK: %[[#LA:]] = triton_gpu.local_alloc %[[#GA]]
// CHECK: %[[#GB:]] = tt.experimental_descriptor_load {{.*}} {async_task_id = dense<0> : vector<1xi32>}
// CHECK: %[[#LB:]] = triton_gpu.local_alloc %[[#GB]]
// CHECK: %[[#C:]] = triton_nvidia_gpu.warp_group_dot %[[#LA]], %[[#LB]], {{.*}} {async_task_id = dense<[1, 2]> : vector<2xi32>

// Layout aliases referenced by the tensor and memdesc types in the module below.
// Blocked layout (2x2 warps per CTA) of the 128x64 f16 tile loaded through %arg0.
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
// Blocked layout (1x4 warps per CTA) of the 64x256 f16 tile loaded through %arg1
// and of the 128x256 f16 tile that is stored through %arg2.
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
// NVIDIA MMA layout (versionMajor = 3) carried by the 128x256 f32 accumulator of warp_group_dot.
#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 16]}>
// Shared-memory layout of the memdescs produced by local_alloc and consumed as the two dot operands.
#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>
// Input IR for the warp-spec task-partition pass invoked by the command at the
// top of this file. The module is configured for 1 CTA, 4 warps of 32 threads,
// and target "cuda:90".
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
  // Tiled matmul-style kernel: each program walks over output tiles in a strided
  // loop, accumulates a 128x256 f32 tile from 128x64 and 64x256 f16 operand tiles,
  // and stores the result as f16.
  //   %arg0, %arg1 : descriptors (tt.nv_tma_desc) the two operand tiles are loaded from.
  //   %arg2        : descriptor (tt.nv_tma_desc) the result tile is stored to.
  //   %arg3..%arg5 : i32 sizes, each marked divisible by 16. They are divided by the
  //                  tile sizes 128, 256 and 64 respectively, so they presumably play
  //                  the roles of M, N and K -- the IR itself does not name them.
  // The expectations at the top of the file look for async_task_id = 0 on both
  // descriptor loads and async_task_id = [1, 2] on the warp_group_dot.
  tt.func public @matmul_persistent_tma_ws_cooperative_kernel(%arg0: !tt.ptr<i8, 0> {tt.nv_tma_desc = 1 : i32}, %arg1: !tt.ptr<i8, 0> {tt.nv_tma_desc = 1 : i32}, %arg2: !tt.ptr<i8, 0> {tt.nv_tma_desc = 1 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
    %c127_i32 = arith.constant 127 : i32
    %c8_i32 = arith.constant 8 : i32
    %c128_i32 = arith.constant 128 : i32
    %c256_i32 = arith.constant 256 : i32
    %c0_i32 = arith.constant 0 : i32
    %c63_i32 = arith.constant 63 : i32
    %c255_i32 = arith.constant 255 : i32
    %c1_i32 = arith.constant 1 : i32
    %c64_i32 = arith.constant 64 : i32
    // Zero-initialised accumulator used as the initial iter_arg of the inner loop.
    %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma>
    // %1 = (%arg3 + 127) / 128 : number of 128-row tiles.
    %0 = arith.addi %arg3, %c127_i32 : i32
    %1 = arith.divsi %0, %c128_i32 : i32
    // %3 = (%arg4 + 255) / 256 : number of 256-column tiles.
    %2 = arith.addi %arg4, %c255_i32 : i32
    %3 = arith.divsi %2, %c256_i32 : i32
    // %4 = total number of output tiles.
    %4 = arith.muli %1, %3 : i32
    %5 = tt.get_program_id x : i32
    %6 = tt.get_num_programs x : i32
    // %7 = number of tiles in one group of 8 row-tiles (8 * column-tile count).
    %7 = arith.muli %3, %c8_i32 : i32
    // %9 = (%arg5 + 63) / 64 : trip count of the inner reduction loop.
    %8 = arith.addi %arg5, %c63_i32 : i32
    %9 = arith.divsi %8, %c64_i32 : i32
    // Outer loop: this program handles tile ids %5, %5 + %6, %5 + 2 * %6, ... below %4.
    scf.for %arg6 = %5 to %4 step %6  : i32 {
      // Map the linear tile id to (row-tile, column-tile) using groups of up to 8 row-tiles:
      //   %10 = group index, %11 = first row-tile of the group,
      //   %13 = row-tiles in this group (min of the remaining row-tiles and 8),
      //   %14 = tile index within the group,
      //   %16 = row-tile index, %17 = column-tile index.
      %10 = arith.divsi %arg6, %7 : i32
      %11 = arith.muli %10, %c8_i32 : i32
      %12 = arith.subi %1, %11 : i32
      %13 = arith.minsi %12, %c8_i32 : i32
      %14 = arith.remsi %arg6, %7 : i32
      %15 = arith.remsi %14, %13 : i32
      %16 = arith.addi %11, %15 : i32
      %17 = arith.divsi %14, %13 : i32
      // Offsets of the tile: %18 = row-tile index * 128, %19 = column-tile index * 256.
      %18 = arith.muli %16, %c128_i32 : i32
      %19 = arith.muli %17, %c256_i32 : i32
      // NOTE(review): %true and %false are not referenced anywhere else in this function.
      %true = arith.constant true
      %false = arith.constant false
      // Inner loop: %9 iterations carrying the accumulator (%arg8, starting at %cst)
      // and a reduction offset (%arg9, starting at 0 and advanced by 64 per iteration).
      %20:2 = scf.for %arg7 = %c0_i32 to %9 step %c1_i32 iter_args(%arg8 = %cst, %arg9 = %c0_i32) -> (tensor<128x256xf32, #mma>, i32)  : i32 {
        // Load the 128x64 tile at [%18, %arg9] through %arg0 and place it in shared memory.
        %23 = tt.experimental_descriptor_load %arg0[%18, %arg9] : !tt.ptr<i8, 0> -> tensor<128x64xf16, #blocked>
        %24 = triton_gpu.local_alloc %23 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>
        // Load the 64x256 tile at [%arg9, %19] through %arg1 and place it in shared memory.
        %25 = tt.experimental_descriptor_load %arg1[%arg9, %19] : !tt.ptr<i8, 0> -> tensor<64x256xf16, #blocked1>
        %26 = triton_gpu.local_alloc %25 : (tensor<64x256xf16, #blocked1>) -> !tt.memdesc<64x256xf16, #shared, #triton_gpu.shared_memory>
        // Dot of the two shared-memory operands with %arg8 as the accumulator operand.
        %27 = triton_nvidia_gpu.warp_group_dot %24, %26, %arg8 {inputPrecision = 0 : i32} : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x256xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x256xf32, #mma>
        // Advance the reduction offset by one 64-wide step.
        %28 = arith.addi %arg9, %c64_i32 : i32
        scf.yield %27, %28 : tensor<128x256xf32, #mma>, i32
      }
      // Epilogue: truncate the f32 accumulator to f16, convert it from the #mma layout
      // to #blocked1, and store the 128x256 tile at [%18, %19] through %arg2.
      %21 = arith.truncf %20#0 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma>
      %22 = triton_gpu.convert_layout %21 : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked1>
      tt.experimental_descriptor_store %arg2[%18, %19], %22 : !tt.ptr<i8, 0>, tensor<128x256xf16, #blocked1>
    }
    tt.return
  }
}