// RUN: triton-opt %s -split-input-file -tritonamdgpu-reorder-instructions | FileCheck %s
// Check that we place local_alloc, local_store (optional) and local_load right after definition of their operands
// in cases where local_alloc is in the loop but its operand is not.
// This is useful for making sure that Q tensor in FA is hoisted out of the main loop and kept in registers
// throughout the computation.
// CHECK-LABEL: hoist_q_out_of_the_loop
// CHECK: %[[TRUNCF:.+]] = arith.truncf
// CHECK-NEXT: %[[ALLOC:.+]] = triton_gpu.local_alloc %[[TRUNCF]]
// CHECK-NEXT: triton_gpu.local_load %[[ALLOC]]
// CHECK: scf.for
// Layout aliases referenced by the function below: two blocked register layouts
// with opposite `order`, one shared-memory layout and the MFMA result layout.
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}>
#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx90a", "triton_gpu.threads-per-warp" = 64 : i32} {
// Test input: the first dot operand (%45) is fully computed before the loop,
// while its local_alloc/local_load pair is written inside the loop body.
tt.func public @hoist_q_out_of_the_loop(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
// Scalar constants and the zero accumulator, all defined outside the loop.
%c0_i32 = arith.constant 0 : i32
%cst = arith.constant 1.44269502 : f32
%c128_i32 = arith.constant 128 : i32
%c128_i64 = arith.constant 128 : i64
%c0_i64 = arith.constant 0 : i64
%cst_2 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mfma>
// Base pointer: %arg0 offset by program_id(y) * %arg7.
%1 = tt.get_program_id y : i32
%2 = arith.muli %1, %arg7 : i32
%3 = tt.addptr %arg0, %2 : !tt.ptr<f16>, i32
// Load the 256x128 f16 tile once, scale it by %cst in f32 and truncate back to
// f16. The truncf result %45 is the value the expected output anchors on.
%12 = tt.splat %3 : !tt.ptr<f16> -> tensor<256x128x!tt.ptr<f16>, #blocked1>
%41 = tt.load %12 : tensor<256x128x!tt.ptr<f16>, #blocked1>
%42 = arith.extf %41 : tensor<256x128xf16, #blocked1> to tensor<256x128xf32, #blocked1>
%43 = tt.splat %cst : f32 -> tensor<256x128xf32, #blocked1>
%44 = arith.mulf %42, %43 : tensor<256x128xf32, #blocked1>
%45 = arith.truncf %44 : tensor<256x128xf32, #blocked1> to tensor<256x128xf16, #blocked1>
%54:1 = scf.for %arg21 = %c0_i32 to %arg20 step %c128_i32 iter_args(%arg26 = %c0_i64) -> (i64) : i32 {
// Second operand: loaded inside the loop on every iteration.
%73 = tt.splat %3 : !tt.ptr<f16> -> tensor<128x128x!tt.ptr<f16>, #blocked2>
%74 = tt.load %73 : tensor<128x128x!tt.ptr<f16>, #blocked2>
// Alloc/load pair whose only tensor input (%45) is defined above the loop.
// The directives at the top of this test expect this pair to end up directly
// after the truncf and before scf.for.
%75 = triton_gpu.local_alloc %45 : (tensor<256x128xf16, #blocked1>) -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory>
%76 = triton_gpu.local_load %75 : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
// Alloc/load pair fed by %74, which is produced inside the loop.
%77 = triton_gpu.local_alloc %74 : (tensor<128x128xf16, #blocked2>) -> !tt.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory>
%78 = triton_gpu.local_load %77 : !tt.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory> -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>
// The dot result %79 has no users in this function; only the i64 counter is
// carried through the loop.
%79 = tt.dot %76, %78, %cst_2 : tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<256x128xf32, #mfma>
%107 = arith.addi %arg26, %c128_i64 : i64
scf.yield %107 : i64
} {tt.divisibility_arg1 = dense<128> : tensor<1xi32>}
tt.return
}
}
// -----
// Check that reordering described in hoist_q_out_of_the_loop is not done in the case where both
// local_alloc and the op defining its src tensor are in the loop.
// CHECK-LABEL: no_hoist_q_type_reordering
// CHECK: scf.for
// CHECK: %[[TRUNCF:.+]] = arith.truncf
// CHECK-NEXT: arith.constant
// Layout aliases referenced by the function below: two blocked register layouts
// with opposite `order`, one shared-memory layout and the MFMA result layout.
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}>
#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx90a", "triton_gpu.threads-per-warp" = 64 : i32} {
// Negative counterpart of the hoisting test: here the truncf that feeds the
// first local_alloc is itself inside the loop, as is the zero accumulator.
tt.func public @no_hoist_q_type_reordering(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
// Scalar constants defined outside the loop.
%c0_i32 = arith.constant 0 : i32
%cst = arith.constant 1.44269502 : f32
%c128_i32 = arith.constant 128 : i32
%c128_i64 = arith.constant 128 : i64
%c0_i64 = arith.constant 0 : i64
// Base pointer: %arg0 offset by program_id(y) * %arg7.
%1 = tt.get_program_id y : i32
%2 = arith.muli %1, %arg7 : i32
%3 = tt.addptr %arg0, %2 : !tt.ptr<f16>, i32
// Load the 256x128 f16 tile and scale it in f32; the truncation back to f16 is
// deliberately left for the loop body.
%12 = tt.splat %3 : !tt.ptr<f16> -> tensor<256x128x!tt.ptr<f16>, #blocked1>
%41 = tt.load %12 : tensor<256x128x!tt.ptr<f16>, #blocked1>
%42 = arith.extf %41 : tensor<256x128xf16, #blocked1> to tensor<256x128xf32, #blocked1>
%43 = tt.splat %cst : f32 -> tensor<256x128xf32, #blocked1>
%44 = arith.mulf %42, %43 : tensor<256x128xf32, #blocked1>
%54:1 = scf.for %arg21 = %c0_i32 to %arg20 step %c128_i32 iter_args(%arg26 = %c0_i64) -> (i64) : i32 {
// truncf followed immediately by arith.constant: the directives at the top of
// this test expect exactly this adjacency to survive the pass, i.e. the
// local_alloc of %45 must not be placed right after the truncf.
%45 = arith.truncf %44 : tensor<256x128xf32, #blocked1> to tensor<256x128xf16, #blocked1>
%cst_2 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mfma>
// Second operand: loaded inside the loop on every iteration.
%73 = tt.splat %3 : !tt.ptr<f16> -> tensor<128x128x!tt.ptr<f16>, #blocked2>
%74 = tt.load %73 : tensor<128x128x!tt.ptr<f16>, #blocked2>
// Alloc/load pair for the first dot operand; its source %45 is defined in this
// same loop body.
%75 = triton_gpu.local_alloc %45 : (tensor<256x128xf16, #blocked1>) -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory>
%76 = triton_gpu.local_load %75 : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
// Alloc/load pair for the second dot operand, fed by the in-loop load %74.
%77 = triton_gpu.local_alloc %74 : (tensor<128x128xf16, #blocked2>) -> !tt.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory>
%78 = triton_gpu.local_load %77 : !tt.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory> -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>
// The dot result %79 has no users in this function; only the i64 counter is
// carried through the loop.
%79 = tt.dot %76, %78, %cst_2 : tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<256x128xf32, #mfma>
%107 = arith.addi %arg26, %c128_i64 : i64
scf.yield %107 : i64
} {tt.divisibility_arg1 = dense<128> : tensor<1xi32>}
tt.return
}
}
// -----
// Layout aliases for this test: one blocked register layout, the MFMA result
// layout (named #mma here) and one shared-memory layout.
#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>
// Expected relative order in the pass output: global load, then the buffer
// allocation, then the store of the loaded value into that buffer, then the
// load back out of the same buffer.
// CHECK-LABEL: order_load_alloc_local_load_local_store
// CHECK: %[[LOAD:.+]] = tt.load
// CHECK: %[[ALLOC:.+]] = triton_gpu.local_alloc
// CHECK: triton_gpu.local_store %[[LOAD]], %[[ALLOC]]
// CHECK: triton_gpu.local_load %[[ALLOC]]
module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
// Straight-line (loop-free) input: a mutable buffer is allocated without an
// initializer and filled through an explicit local_store, with unrelated
// constants interleaved between the memory ops.
tt.func public @order_load_alloc_local_load_local_store(%arg0: tensor<32x32x!tt.ptr<f32>, #blocked>) attributes {noinline = false} {
%9 = tt.load %arg0 : tensor<32x32x!tt.ptr<f32>, #blocked>
// Zero accumulator for the dot below.
%cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
// Mutable buffer with no source operand; its contents come from the
// local_store of %9 on the next line.
%10 = triton_gpu.local_alloc : () -> !tt.memdesc<32x32xf32, #shared, mutable>
triton_gpu.local_store %9, %10 : tensor<32x32xf32, #blocked> -> !tt.memdesc<32x32xf32, #shared, mutable>
// Constant second dot operand.
%cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
// Read the buffer back in the layout of dot operand 0.
%11 = triton_gpu.local_load %10 : !tt.memdesc<32x32xf32, #shared, mutable> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
%12 = tt.dot %11, %cst_0, %cst : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma>
// Convert the result to the blocked layout and write it through the same
// pointer tensor that was loaded from.
%13 = triton_gpu.convert_layout %12 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked>
tt.store %arg0, %13 : tensor<32x32x!tt.ptr<f32>, #blocked>
tt.return
}
}
// -----
// Move loads (and independent local_stores) as early as possible.
// For example in the matmul_loop below, the scf.for loop looks like this after pipeliner:
// scf.for ... {
// // stage 1
// %a = tt.local_load %a_tile
// %b = tt.local_load %b_tile
// tt.dot %c, %a, %b
// // stage 0
// %aptr = tt.addptr %aptr, %k
// %a_next = tt.load %aptr
// %bptr = tt.addptr %bptr, %k
// %b_next = tt.load %bptr
// tt.local_store %a_next
// tt.local_store %b_next
// yield
// }
//
// Solution for num_stages=2 :
// scf.for ... {
// // stage 0.a
// %aptr = tt.addptr %aptr, %k
// %a_next = tt.load %aptr
// %bptr = tt.addptr %bptr, %k
// %b_next = tt.load %bptr
// // stage 1
// %a = tt.local_load %a_tile
// %b = tt.local_load %b_tile
// tt.dot %c, %a, %b
// // stage 0.b
// tt.local_store %a_next
// tt.local_store %b_next
// yield
// }
//
// Solution for num_stages=3 (double-buffered) :
// scf.for ... {
// // stage 1
// tt.local_store %a_next_1
// tt.local_store %b_next_1
// // stage 0
// %aptr = tt.addptr %aptr, %k
// %a_next_2 = tt.load %aptr
// %bptr = tt.addptr %bptr, %k
// %b_next_2 = tt.load %bptr
// // stage 2
// %a = tt.local_load %a_tile
// %b = tt.local_load %b_tile
// tt.dot %c, %a, %b
// yield
// }
// Layout aliases shared by the pipelined-loop tests in the module that follows.
// #blocked / #blocked1: register layouts of the two global-load operand tiles.
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
// Dot result layout. NOTE(review): this is an nvidia_mma encoding although the
// enclosing module declares a hip target; presumably intentional because the
// tests only look at op ordering -- confirm before changing.
#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = []}>
// Shared-memory layouts. #shared and #shared1 back the matmul_loop buffers and
// #shared2 the indirect_bmm_vector buffers; #shared3 and #shared4 are not
// referenced in the part of the file visible here.
#shared = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}>
#shared2 = #triton_gpu.shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], hasLeadingOffset = false}>
#shared3 = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}>
#shared4 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32, triton_gpu.target = "hip:gfx942"} {
// CHECK-LABEL: tt.func @matmul_loop
// CHECK: %{{.*}}:6 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}})
// Stage 0.a
// CHECK: %[[ADDPTR_20:.*]] = tt.addptr %[[ARG6]], %{{.*}}
// CHECK: %[[SUBI_21:.*]] = arith.subi %{{.*}}, %{{.*}}
// CHECK: %[[CMPI_22:.*]] = arith.cmpi slt, %[[ARG5]], %[[SUBI_21]]
// CHECK: %[[SPLAT_23:.*]] = tt.splat %[[CMPI_22]]
// CHECK: %[[LOAD_24:.*]] = tt.load %[[ADDPTR_20]], %[[SPLAT_23]]
// CHECK: %[[ADDPTR_25:.*]] = tt.addptr %[[ARG7]], %{{.*}}
// CHECK: %[[SPLAT_26:.*]] = tt.splat %[[CMPI_22]]
// CHECK: %[[LOAD_27:.*]] = tt.load %[[ADDPTR_25]], %[[SPLAT_26]]
// Stage 1
// CHECK: %[[LOCAL_LOAD_28:.*]] = triton_gpu.local_load %[[ARG10]]
// CHECK: %[[LOCAL_LOAD_29:.*]] = triton_gpu.local_load %[[ARG11]]
// CHECK: %[[MULF_30:.*]] = arith.mulf %[[LOCAL_LOAD_29]], %{{.*}}
// CHECK: %[[DOT_31:.*]] = tt.dot %[[LOCAL_LOAD_28]], %[[MULF_30]], %[[ARG8]]
// Stage 0.b
// CHECK: %[[ADDI_32:.*]] = arith.addi %[[ARG9]], %{{.*}}
// CHECK: %[[CMPI_33:.*]] = arith.cmpi slt, %[[ADDI_32]], %{{.*}}
// CHECK: %[[SELECT_34:.*]] = arith.select %[[CMPI_33]], %[[ADDI_32]], %{{.*}}
// CHECK: %[[MEMDESC_SUBVIEW_35:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_34]], %{{.*}}, %{{.*}}]
// CHECK: triton_gpu.local_store %[[LOAD_24]], %[[MEMDESC_SUBVIEW_35]]
// CHECK: %[[MEMDESC_SUBVIEW_36:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_34]], %{{.*}}, %{{.*}}]
// CHECK: triton_gpu.local_store %[[LOAD_27]], %[[MEMDESC_SUBVIEW_36]]
// CHECK: scf.yield %[[ADDPTR_20]], %[[ADDPTR_25]], %[[DOT_31]], %[[SELECT_34]], %[[MEMDESC_SUBVIEW_35]], %[[MEMDESC_SUBVIEW_36]]
// CHECK: }
// Test input for the single-buffer case: each operand has one 1-slot buffer,
// and the loop body is written as local_load + dot first, then the global loads
// and local_stores for the next tile. %arg0/%arg1/%arg2 are the loop lower
// bound, upper bound and step; %arg3/%arg4 are the two operand base pointers.
tt.func @matmul_loop(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<f16> {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #mma> {
// Constants: slot indices, the scale applied to the second operand, the
// per-iteration pointer increments, the zero accumulator and the fill value
// used by the masked loads of the second operand.
%c1_i32 = arith.constant 1 : i32
%c0_i32 = arith.constant 0 : i32
%cst = arith.constant dense<4.000000e+00> : tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
%cst_0 = arith.constant dense<4> : tensor<32x128xi32, #blocked>
%cst_1 = arith.constant dense<4> : tensor<128x32xi32, #blocked1>
%cst_2 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
%cst_3 = arith.constant dense<0.000000e+00> : tensor<32x128xf16, #blocked>
// Pointer tensor for the 128x32 operand: splat base + broadcast column range.
%0 = tt.splat %arg3 : !tt.ptr<f16> -> tensor<128x32x!tt.ptr<f16>, #blocked1>
%1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1>
%3 = tt.broadcast %2 : tensor<1x32xi32, #blocked1> -> tensor<128x32xi32, #blocked1>
%4 = tt.addptr %0, %3 : tensor<128x32x!tt.ptr<f16>, #blocked1>, tensor<128x32xi32, #blocked1>
// Pointer tensor for the 32x128 operand, built the same way.
%5 = tt.splat %arg4 : !tt.ptr<f16> -> tensor<32x128x!tt.ptr<f16>, #blocked>
%6 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked>
%8 = tt.broadcast %7 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked>
%9 = tt.addptr %5, %8 : tensor<32x128x!tt.ptr<f16>, #blocked>, tensor<32x128xi32, #blocked>
// One mutable buffer per operand, each with a leading dimension of 1.
%10 = triton_gpu.local_alloc : () -> !tt.memdesc<1x128x32xf16, #shared, #triton_gpu.shared_memory, mutable>
%11 = triton_gpu.local_alloc : () -> !tt.memdesc<1x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable>
// Prologue: load the first tiles, masked on lb < ub, and store them to slot 0.
%12 = arith.cmpi slt, %arg0, %arg1 : index
%13 = tt.splat %12 : i1 -> tensor<128x32xi1, #blocked1>
%14 = tt.load %4, %13 : tensor<128x32x!tt.ptr<f16>, #blocked1>
%15 = tt.splat %12 : i1 -> tensor<32x128xi1, #blocked>
%16 = tt.load %9, %15, %cst_3 : tensor<32x128x!tt.ptr<f16>, #blocked>
%17 = triton_gpu.memdesc_subview %10[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable>
triton_gpu.local_store %14, %17 : tensor<128x32xf16, #blocked1> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable>
%18 = triton_gpu.memdesc_subview %11[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable>
triton_gpu.local_store %16, %18 : tensor<32x128xf16, #blocked> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable>
// Loop-carried values: both pointer tensors, the accumulator, the current slot
// index and the two buffer views read by this iteration.
%19:6 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %4, %arg7 = %9, %arg8 = %cst_2, %arg9 = %c0_i32, %arg10 = %17, %arg11 = %18) -> (tensor<128x32x!tt.ptr<f16>, #blocked1>, tensor<32x128x!tt.ptr<f16>, #blocked>, tensor<128x128xf32, #mma>, i32, !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable>, !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable>) {
// Mask for the next-tile loads: iv < ub - step.
%20 = arith.subi %arg1, %arg2 : index
%21 = arith.cmpi slt, %arg5, %20 : index
// Consume the tiles already in the buffers: local_load both, scale the second
// operand by %cst and accumulate the dot into %arg8.
%22 = triton_gpu.local_load %arg10 : !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
%23 = triton_gpu.local_load %arg11 : !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
%24 = arith.mulf %23, %cst : tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
%25 = tt.dot %22, %24, %arg8 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma>
// Advance both pointer tensors and issue the masked global loads for the next
// tile. These loads do not use any result of the dot above.
%26 = tt.addptr %arg6, %cst_1 : tensor<128x32x!tt.ptr<f16>, #blocked1>, tensor<128x32xi32, #blocked1>
%27 = tt.addptr %arg7, %cst_0 : tensor<32x128x!tt.ptr<f16>, #blocked>, tensor<32x128xi32, #blocked>
%28 = tt.splat %21 : i1 -> tensor<128x32xi1, #blocked1>
%29 = tt.load %26, %28 : tensor<128x32x!tt.ptr<f16>, #blocked1>
%30 = tt.splat %21 : i1 -> tensor<32x128xi1, #blocked>
%31 = tt.load %27, %30, %cst_3 : tensor<32x128x!tt.ptr<f16>, #blocked>
// Next slot index: (%arg9 + 1) if that is < 1, otherwise wrap to 0.
%32 = arith.addi %arg9, %c1_i32 : i32
%33 = arith.cmpi slt, %32, %c1_i32 : i32
%34 = arith.select %33, %32, %c0_i32 : i32
// Store the freshly loaded tiles into the selected slot of each buffer.
%35 = triton_gpu.memdesc_subview %10[%34, %c0_i32, %c0_i32] : !tt.memdesc<1x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable>
triton_gpu.local_store %29, %35 : tensor<128x32xf16, #blocked1> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable>
%36 = triton_gpu.memdesc_subview %11[%34, %c0_i32, %c0_i32] : !tt.memdesc<1x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable>
triton_gpu.local_store %31, %36 : tensor<32x128xf16, #blocked> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable>
scf.yield %26, %27, %25, %34, %35, %36 : tensor<128x32x!tt.ptr<f16>, #blocked1>, tensor<32x128x!tt.ptr<f16>, #blocked>, tensor<128x128xf32, #mma>, i32, !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable>, !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable>
}
// Release both buffers and return the final accumulator.
triton_gpu.local_dealloc %10 : !tt.memdesc<1x128x32xf16, #shared, #triton_gpu.shared_memory, mutable>
triton_gpu.local_dealloc %11 : !tt.memdesc<1x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable>
tt.return %19#2 : tensor<128x128xf32, #mma>
}
// This example tests that tt.load overlaps with independent ttg.local_store which
// overlaps with independent tt.dot.
// num_stages == 3, double buffered
// CHECK-LABEL: tt.func @matmul_loop_mb
// CHECK: %{{.*}}:8 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}, %[[ARG12:.*]] = %{{.*}}, %[[ARG13:.*]] = %{{.*}})
// Stage 0
// CHECK: %[[ADDPTR_28:.*]] = tt.addptr %[[ARG6]], %{{.*}}
// CHECK: %[[MULI_29:.*]] = arith.muli %{{.*}}, %{{.*}}
// CHECK: %[[SUBI_30:.*]] = arith.subi %{{.*}}, %[[MULI_29]]
// CHECK: %[[CMPI_31:.*]] = arith.cmpi slt, %[[ARG5]], %[[SUBI_30]]
// CHECK: %[[SPLAT_32:.*]] = tt.splat %[[CMPI_31]]
// CHECK: %[[LOAD_33:.*]] = tt.load %[[ADDPTR_28]], %[[SPLAT_32]]
// CHECK: %[[ADDPTR_34:.*]] = tt.addptr %[[ARG7]], %{{.*}}
// CHECK: %[[SPLAT_35:.*]] = tt.splat %[[CMPI_31]]
// CHECK: %[[LOAD_36:.*]] = tt.load %[[ADDPTR_34]], %[[SPLAT_35]]
// Stage 1
// CHECK: %[[ADDI_37:.*]] = arith.addi %[[ARG9]], %{{.*}}
// CHECK: %[[CMPI_38:.*]] = arith.cmpi slt, %[[ADDI_37]], %{{.*}}
// CHECK: %[[SELECT_39:.*]] = arith.select %[[CMPI_38]], %[[ADDI_37]], %{{.*}}
// CHECK: %[[MEMDESC_SUBVIEW_40:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_39]], %{{.*}}, %{{.*}}]
// CHECK: triton_gpu.local_store %[[ARG12]], %[[MEMDESC_SUBVIEW_40]]
// CHECK: %[[MEMDESC_SUBVIEW_41:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_39]], %{{.*}}, %{{.*}}]
// CHECK: triton_gpu.local_store %[[ARG13]], %[[MEMDESC_SUBVIEW_41]]
// Stage 2
// CHECK: %[[LOCAL_LOAD_42:.*]] = triton_gpu.local_load %[[ARG10]]
// CHECK: %[[LOCAL_LOAD_43:.*]] = triton_gpu.local_load %[[ARG11]]
// CHECK: %[[MULF_44:.*]] = arith.mulf %[[LOCAL_LOAD_43]], %{{.*}}
// CHECK: %[[DOT_45:.*]] = tt.dot %[[LOCAL_LOAD_42]], %[[MULF_44]], %[[ARG8]]
// CHECK: scf.yield %[[ADDPTR_28]], %[[ADDPTR_34]], %[[DOT_45]], %[[SELECT_39]], %[[MEMDESC_SUBVIEW_40]], %[[MEMDESC_SUBVIEW_41]], %[[LOAD_33]], %[[LOAD_36]]
// CHECK: }
// Test input for the two-slot buffer case: each operand buffer has a leading
// dimension of 2, and the tiles loaded from global memory are carried across
// one iteration in registers (%arg12/%arg13) before being stored to a buffer.
// %arg0/%arg1/%arg2 are the loop lower bound, upper bound and step; %arg3/%arg4
// are the two operand base pointers.
tt.func @matmul_loop_mb(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<f16> {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #mma> {
// Constants: the look-ahead factor (2) for the load mask, slot indices, the
// scale applied to the second operand, the per-iteration pointer increments,
// the zero accumulator and the fill value used by masked loads.
%c2 = arith.constant 2 : index
%c2_i32 = arith.constant 2 : i32
%c1_i32 = arith.constant 1 : i32
%c0_i32 = arith.constant 0 : i32
%cst = arith.constant dense<4.000000e+00> : tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
%cst_0 = arith.constant dense<4> : tensor<32x128xi32, #blocked>
%cst_1 = arith.constant dense<4> : tensor<128x32xi32, #blocked1>
%cst_2 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
%cst_3 = arith.constant dense<0.000000e+00> : tensor<32x128xf16, #blocked>
// Pointer tensor for the 128x32 operand: splat base + broadcast column range.
%0 = tt.splat %arg3 : !tt.ptr<f16> -> tensor<128x32x!tt.ptr<f16>, #blocked1>
%1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1>
%3 = tt.broadcast %2 : tensor<1x32xi32, #blocked1> -> tensor<128x32xi32, #blocked1>
%4 = tt.addptr %0, %3 : tensor<128x32x!tt.ptr<f16>, #blocked1>, tensor<128x32xi32, #blocked1>
// Pointer tensor for the 32x128 operand, built the same way.
%5 = tt.splat %arg4 : !tt.ptr<f16> -> tensor<32x128x!tt.ptr<f16>, #blocked>
%6 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked>
%8 = tt.broadcast %7 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked>
%9 = tt.addptr %5, %8 : tensor<32x128x!tt.ptr<f16>, #blocked>, tensor<32x128xi32, #blocked>
// One mutable buffer per operand, each with two slots.
%10 = triton_gpu.local_alloc : () -> !tt.memdesc<2x128x32xf16, #shared, #triton_gpu.shared_memory, mutable>
%11 = triton_gpu.local_alloc : () -> !tt.memdesc<2x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable>
// Prologue, first round of loads: masked on lb < ub.
%12 = arith.cmpi slt, %arg0, %arg1 : index
%13 = tt.splat %12 : i1 -> tensor<128x32xi1, #blocked1>
%14 = tt.load %4, %13 : tensor<128x32x!tt.ptr<f16>, #blocked1>
%15 = tt.splat %12 : i1 -> tensor<32x128xi1, #blocked>
%16 = tt.load %9, %15, %cst_3 : tensor<32x128x!tt.ptr<f16>, #blocked>
// Prologue, second round of loads: pointers advanced once, masked on
// lb + step < ub. These tiles (%22, %24) are not stored here; they enter the
// loop as the initial values of %arg12 / %arg13.
%17 = arith.addi %arg0, %arg2 : index
%18 = arith.cmpi slt, %17, %arg1 : index
%19 = tt.addptr %4, %cst_1 : tensor<128x32x!tt.ptr<f16>, #blocked1>, tensor<128x32xi32, #blocked1>
%20 = tt.addptr %9, %cst_0 : tensor<32x128x!tt.ptr<f16>, #blocked>, tensor<32x128xi32, #blocked>
%21 = tt.splat %18 : i1 -> tensor<128x32xi1, #blocked1>
%22 = tt.load %19, %21 : tensor<128x32x!tt.ptr<f16>, #blocked1>
%23 = tt.splat %18 : i1 -> tensor<32x128xi1, #blocked>
%24 = tt.load %20, %23, %cst_3 : tensor<32x128x!tt.ptr<f16>, #blocked>
// Store the first-round tiles into slot 0 of each buffer.
%25 = triton_gpu.memdesc_subview %10[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<2x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable>
triton_gpu.local_store %14, %25 : tensor<128x32xf16, #blocked1> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable>
%26 = triton_gpu.memdesc_subview %11[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<2x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable>
triton_gpu.local_store %16, %26 : tensor<32x128xf16, #blocked> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable>
// Loop-carried values: both pointer tensors (already advanced once), the
// accumulator, the current slot index, the two buffer views read by this
// iteration, and the two register tiles loaded one step earlier.
%27:8 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %19, %arg7 = %20, %arg8 = %cst_2, %arg9 = %c0_i32, %arg10 = %25, %arg11 = %26, %arg12 = %22, %arg13 = %24) -> (tensor<128x32x!tt.ptr<f16>, #blocked1>, tensor<32x128x!tt.ptr<f16>, #blocked>, tensor<128x128xf32, #mma>, i32, !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable>, !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable>, tensor<128x32xf16, #blocked1>, tensor<32x128xf16, #blocked>) {
// Mask for this iteration's global loads: iv < ub - 2 * step.
%28 = arith.muli %arg2, %c2 : index
%29 = arith.subi %arg1, %28 : index
%30 = arith.cmpi slt, %arg5, %29 : index
// Consume the tiles already in the buffers: local_load both, scale the second
// operand by %cst and accumulate the dot into %arg8.
%31 = triton_gpu.local_load %arg10 : !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
%32 = triton_gpu.local_load %arg11 : !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
%33 = arith.mulf %32, %cst : tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
%34 = tt.dot %31, %33, %arg8 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma>
// Advance both pointer tensors and issue the masked global loads. Their
// results (%38, %40) are only yielded, not stored in this iteration.
%35 = tt.addptr %arg6, %cst_1 : tensor<128x32x!tt.ptr<f16>, #blocked1>, tensor<128x32xi32, #blocked1>
%36 = tt.addptr %arg7, %cst_0 : tensor<32x128x!tt.ptr<f16>, #blocked>, tensor<32x128xi32, #blocked>
%37 = tt.splat %30 : i1 -> tensor<128x32xi1, #blocked1>
%38 = tt.load %35, %37 : tensor<128x32x!tt.ptr<f16>, #blocked1>
%39 = tt.splat %30 : i1 -> tensor<32x128xi1, #blocked>
%40 = tt.load %36, %39, %cst_3 : tensor<32x128x!tt.ptr<f16>, #blocked>
// Next slot index: (%arg9 + 1) if that is < 2, otherwise wrap to 0.
%41 = arith.addi %arg9, %c1_i32 : i32
%42 = arith.cmpi slt, %41, %c2_i32 : i32
%43 = arith.select %42, %41, %c0_i32 : i32
// Store the carried-in register tiles (%arg12 / %arg13), not this iteration's
// loads, into the selected slot of each buffer.
%44 = triton_gpu.memdesc_subview %10[%43, %c0_i32, %c0_i32] : !tt.memdesc<2x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable>
triton_gpu.local_store %arg12, %44 : tensor<128x32xf16, #blocked1> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable>
%45 = triton_gpu.memdesc_subview %11[%43, %c0_i32, %c0_i32] : !tt.memdesc<2x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable>
triton_gpu.local_store %arg13, %45 : tensor<32x128xf16, #blocked> -> !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable>
scf.yield %35, %36, %34, %43, %44, %45, %38, %40 : tensor<128x32x!tt.ptr<f16>, #blocked1>, tensor<32x128x!tt.ptr<f16>, #blocked>, tensor<128x128xf32, #mma>, i32, !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable>, !tt.memdesc<32x128xf16, #shared1, #triton_gpu.shared_memory, mutable>, tensor<128x32xf16, #blocked1>, tensor<32x128xf16, #blocked>
}
// Release both buffers and return the final accumulator.
triton_gpu.local_dealloc %10 : !tt.memdesc<2x128x32xf16, #shared, #triton_gpu.shared_memory, mutable>
triton_gpu.local_dealloc %11 : !tt.memdesc<2x32x128xf16, #shared1, #triton_gpu.shared_memory, mutable>
tt.return %27#2 : tensor<128x128xf32, #mma>
}
// This example shows dependent loads and verifies all are moved early.
// CHECK-LABEL: tt.func @indirect_bmm_vector
// CHECK: %{{.*}}:7 = scf.for %[[ARG6:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}, %[[ARG12:.*]] = %{{.*}}, %[[ARG13:.*]] = %{{.*}})
// Stage 0
// CHECK: %[[ADDPTR_20:.*]] = tt.addptr %[[ARG8]], %{{.*}}
// CHECK: %[[SUBI_21:.*]] = arith.subi %{{.*}}, %{{.*}}
// CHECK: %[[CMPI_22:.*]] = arith.cmpi slt, %[[ARG6]], %[[SUBI_21]]
// CHECK: %[[SPLAT_23:.*]] = tt.splat %[[CMPI_22]]
// CHECK: %[[LOAD_24:.*]] = tt.load %[[ADDPTR_20]], %[[SPLAT_23]]
// Stage 1.a
// CHECK: %[[EXPAND_DIMS_25:.*]] = tt.expand_dims %[[ARG13]] {axis = 1 : i32}
// CHECK: %[[BROADCAST_26:.*]] = tt.broadcast %[[EXPAND_DIMS_25]]
// CHECK: %[[MULI_27:.*]] = arith.muli %{{.*}}, %[[BROADCAST_26]]
// CHECK: %[[ADDPTR_28:.*]] = tt.addptr %{{.*}}, %[[MULI_27]]
// CHECK: %[[SPLAT_29:.*]] = tt.splat %[[CMPI_22]]
// CHECK: %[[LOAD_30:.*]] = tt.load %[[ADDPTR_28]], %[[SPLAT_29]]
// CHECK: %[[ADDPTR_31:.*]] = tt.addptr %[[ARG9]], %{{.*}}
// CHECK: %[[SUBI_32:.*]] = arith.subi %{{.*}}, %{{.*}}
// CHECK: %[[CMPI_33:.*]] = arith.cmpi slt, %[[ARG6]], %[[SUBI_32]]
// CHECK: %[[SPLAT_34:.*]] = tt.splat %[[CMPI_33]]
// CHECK: %[[LOAD_35:.*]] = tt.load %[[ADDPTR_31]], %[[SPLAT_34]]
// Stage 2
// CHECK: %[[LOCAL_LOAD_36:.*]] = triton_gpu.local_load %[[ARG11]]
// CHECK: %[[LOCAL_LOAD_37:.*]] = triton_gpu.local_load %[[ARG12]]
// CHECK: %[[DOT_38:.*]] = tt.dot %[[LOCAL_LOAD_36]], %[[LOCAL_LOAD_37]], %[[ARG7]]
// Stage 1.b
// CHECK: %[[ADDI_39:.*]] = arith.addi %[[ARG10]], %{{.*}}
// CHECK: %[[CMPI_40:.*]] = arith.cmpi slt, %[[ADDI_39]], %{{.*}}
// CHECK: %[[SELECT_41:.*]] = arith.select %[[CMPI_40]], %[[ADDI_39]], %{{.*}}
// CHECK: %[[MEMDESC_SUBVIEW_42:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_41]], %{{.*}}, %{{.*}}]
// CHECK: triton_gpu.local_store %[[LOAD_24]], %[[MEMDESC_SUBVIEW_42]]
// CHECK: %[[MEMDESC_SUBVIEW_43:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_41]], %{{.*}}, %{{.*}}]
// CHECK: triton_gpu.local_store %[[LOAD_30]], %[[MEMDESC_SUBVIEW_43]]
// CHECK: scf.yield %[[DOT_38]], %[[ADDPTR_20]], %[[ADDPTR_31]], %[[SELECT_41]], %[[MEMDESC_SUBVIEW_42]], %[[MEMDESC_SUBVIEW_43]], %[[LOAD_35]]
// CHECK: }
// Input IR for the reordering test: a loop that performs a tt.dot on two
// 16x16 f16 tiles read from mutable shared-memory buffers, while issuing
// masked global loads for later iterations. The second operand's addresses
// are computed from an i64 vector that is itself loaded through %arg3.
// Loop-carried values: accumulator, two pointers, buffer index, two memdesc
// subviews, and the most recently loaded i64 vector.
tt.func @indirect_bmm_vector(%arg0: tensor<16x16xi64, #blocked> {tt.constancy = 16 : i32, tt.divisibility = 16 : i32}, %arg1: index, %arg2: tensor<16x16x!tt.ptr<f16>, #blocked1> {tt.contiguity = 2 : i32, tt.divisibility = 16 : i32}, %arg3: tensor<16x!tt.ptr<i64>, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, %arg4: tensor<16x16xi32, #blocked1> {tt.constancy = 16 : i32, tt.divisibility = 16 : i32}, %arg5: tensor<16x16x!tt.ptr<f16>, #blocked> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}) -> tensor<16x16xf32, #mma> {
%c2 = arith.constant 2 : index
%c0_i32 = arith.constant 0 : i32
%cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma>
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c1_i32 = arith.constant 1 : i32
%cst_0 = arith.constant dense<1> : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
// Two mutable shared-memory buffers, each holding a single 16x16 f16 slot.
%0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>
%1 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>
// Prologue: %2 is (%arg1 > 0) and %5 is (%arg1 > 1); they mask the loads
// issued before the loop.
%2 = arith.cmpi sgt, %arg1, %c0 : index
%3 = tt.splat %2 : i1 -> tensor<16xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%4 = tt.load %arg3, %3 : tensor<16x!tt.ptr<i64>, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%5 = arith.cmpi sgt, %arg1, %c1 : index
// Advance the i64 pointer vector by one element (%cst_0 is dense<1>).
%6 = tt.addptr %arg3, %cst_0 : tensor<16x!tt.ptr<i64>, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%7 = tt.splat %2 : i1 -> tensor<16x16xi1, #blocked1>
%8 = tt.load %arg2, %7 : tensor<16x16x!tt.ptr<f16>, #blocked1>
// Indirect address: expand the loaded i64 vector to 16x1, broadcast to 16x16,
// multiply by %arg0 and use the result as the offset from %arg5.
%9 = tt.expand_dims %4 {axis = 1 : i32} : tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi64, #blocked>
%10 = tt.broadcast %9 : tensor<16x1xi64, #blocked> -> tensor<16x16xi64, #blocked>
%11 = arith.muli %arg0, %10 : tensor<16x16xi64, #blocked>
%12 = tt.addptr %arg5, %11 : tensor<16x16x!tt.ptr<f16>, #blocked>, tensor<16x16xi64, #blocked>
%13 = tt.splat %2 : i1 -> tensor<16x16xi1, #blocked>
%14 = tt.load %12, %13 : tensor<16x16x!tt.ptr<f16>, #blocked>
// Load the next i64 vector through the advanced pointer, masked by %5.
%15 = tt.splat %5 : i1 -> tensor<16xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%16 = tt.load %6, %15 : tensor<16x!tt.ptr<i64>, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
// Store the two prologue tiles into slot 0 of each buffer.
%17 = triton_gpu.memdesc_subview %0[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>
triton_gpu.local_store %8, %17 : tensor<16x16xf16, #blocked1> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>
%18 = triton_gpu.memdesc_subview %1[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>
triton_gpu.local_store %14, %18 : tensor<16x16xf16, #blocked> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>
// Main loop over [%c0, %arg1) with step 1.
%19:7 = scf.for %arg6 = %c0 to %arg1 step %c1 iter_args(%arg7 = %cst, %arg8 = %arg2, %arg9 = %6, %arg10 = %c0_i32, %arg11 = %17, %arg12 = %18, %arg13 = %16) -> (tensor<16x16xf32, #mma>, tensor<16x16x!tt.ptr<f16>, #blocked1>, tensor<16x!tt.ptr<i64>, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, i32, !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>, !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>, tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) {
// Load masks: %21 is (%arg6 < %arg1 - 2), %23 is (%arg6 < %arg1 - 1).
%20 = arith.subi %arg1, %c2 : index
%21 = arith.cmpi slt, %arg6, %20 : index
%22 = arith.subi %arg1, %c1 : index
%23 = arith.cmpi slt, %arg6, %22 : index
// Read both operand tiles from the carried subviews and accumulate into %arg7.
%24 = triton_gpu.local_load %arg11 : !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
%25 = triton_gpu.local_load %arg12 : !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
%26 = tt.dot %24, %25, %arg7 : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma>
// Advance both carried pointers, then issue the masked global loads.
%27 = tt.addptr %arg8, %arg4 : tensor<16x16x!tt.ptr<f16>, #blocked1>, tensor<16x16xi32, #blocked1>
%28 = tt.addptr %arg9, %cst_0 : tensor<16x!tt.ptr<i64>, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%29 = tt.splat %23 : i1 -> tensor<16x16xi1, #blocked1>
%30 = tt.load %27, %29 : tensor<16x16x!tt.ptr<f16>, #blocked1>
// Same indirect address computation as in the prologue, using the carried
// i64 vector %arg13.
%31 = tt.expand_dims %arg13 {axis = 1 : i32} : tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi64, #blocked>
%32 = tt.broadcast %31 : tensor<16x1xi64, #blocked> -> tensor<16x16xi64, #blocked>
%33 = arith.muli %arg0, %32 : tensor<16x16xi64, #blocked>
%34 = tt.addptr %arg5, %33 : tensor<16x16x!tt.ptr<f16>, #blocked>, tensor<16x16xi64, #blocked>
%35 = tt.splat %23 : i1 -> tensor<16x16xi1, #blocked>
%36 = tt.load %34, %35 : tensor<16x16x!tt.ptr<f16>, #blocked>
// Load of the following i64 vector, masked by %21.
%37 = tt.splat %21 : i1 -> tensor<16xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%38 = tt.load %28, %37 : tensor<16x!tt.ptr<i64>, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
// Next buffer index: (%arg10 + 1) if it is < 1, otherwise 0. Starting from
// the initial value 0 this always selects slot 0.
%39 = arith.addi %arg10, %c1_i32 : i32
%40 = arith.cmpi slt, %39, %c1_i32 : i32
%41 = arith.select %40, %39, %c0_i32 : i32
// Store the freshly loaded tiles into the selected slot of each buffer.
%42 = triton_gpu.memdesc_subview %0[%41, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>
triton_gpu.local_store %30, %42 : tensor<16x16xf16, #blocked1> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>
%43 = triton_gpu.memdesc_subview %1[%41, %c0_i32, %c0_i32] : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>
triton_gpu.local_store %36, %43 : tensor<16x16xf16, #blocked> -> !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>
scf.yield %26, %27, %28, %41, %42, %43, %38 : tensor<16x16xf32, #mma>, tensor<16x16x!tt.ptr<f16>, #blocked1>, tensor<16x!tt.ptr<i64>, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, i32, !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>, !tt.memdesc<16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>, tensor<16xi64, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
}
// Release both buffers and return the final accumulator.
triton_gpu.local_dealloc %0 : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>
triton_gpu.local_dealloc %1 : !tt.memdesc<1x16x16xf16, #shared2, #triton_gpu.shared_memory, mutable>
tt.return %19#0 : tensor<16x16xf32, #mma>
}
}
// -----
// This test ensures that loads will not be moved across `for` loops.
// CHECK-LABEL: tt.func public @_attn_bwd
// CHECK: tt.load
// CHECK: tt.load
// CHECK: scf.for
// CHECK: }
// CHECK: scf.for
// CHECK: }
// The loads matched below are moved before the independent `tt.store` ops but not before the `for` ops.
// CHECK: tt.load
// CHECK: tt.load
// CHECK: tt.load
// CHECK: tt.load
// CHECK: tt.load
// CHECK: tt.load
// CHECK: tt.store
// CHECK: tt.store
// CHECK: scf.for
// CHECK: }
// CHECK: scf.for
// CHECK: }
// CHECK: tt.store
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 4], order = [0, 1]}>
#blocked3 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}>
#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
#mma1 = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>
#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}>
#shared1 = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>
#shared2 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1], hasLeadingOffset = false}>
#shared3 = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} {
tt.func public @_attn_bwd(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg7: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg8: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg9: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32 {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
%c-1_i32 = arith.constant -1 : i32
%cst = arith.constant dense<0.000000e+00> : tensor<128x32xf32, #mma>
%c128_i32 = arith.constant 128 : i32
%c8_i32 = arith.constant 8 : i32
%c32_i32 = arith.constant 32 : i32
%c1_i32 = arith.constant 1 : i32
%c0_i32 = arith.constant 0 : i32
%c16_i32 = arith.constant 16 : i32
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1>
%cst_1 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma>
%cst_2 = arith.constant dense<0.693147182> : tensor<128x64xf32, #mma>
%0 = tt.get_program_id z : i32
%1 = arith.muli %0, %arg14 : i32
%2 = arith.extsi %1 : i32 to i64
%3 = arith.remsi %0, %arg13 : i32
%4 = arith.muli %arg11, %3 : i32
%5 = arith.divsi %0, %arg13 : i32
%6 = arith.muli %arg10, %5 : i32
%7 = arith.addi %4, %6 : i32
%8 = arith.extsi %7 : i32 to i64
%9 = tt.get_program_id x : i32
%10 = tt.addptr %arg0, %8 : !tt.ptr<f16>, i64
%11 = tt.addptr %arg1, %8 : !tt.ptr<f16>, i64
%12 = tt.addptr %arg2, %8 : !tt.ptr<f16>, i64
%13 = tt.addptr %arg4, %8 : !tt.ptr<f16>, i64
%14 = tt.addptr %arg5, %8 : !tt.ptr<f16>, i64
%15 = tt.addptr %arg6, %8 : !tt.ptr<f16>, i64
%16 = tt.addptr %arg7, %8 : !tt.ptr<f16>, i64
%17 = tt.addptr %arg8, %2 : !tt.ptr<f32>, i64
%18 = tt.addptr %arg9, %2 : !tt.ptr<f32>, i64
%19 = arith.muli %9, %c128_i32 : i32
%20 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%21 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%22 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>>
%23 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>>
%24 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%25 = tt.splat %19 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%26 = tt.splat %19 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%27 = tt.splat %19 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>>
%28 = tt.splat %19 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>>
%29 = tt.splat %19 : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%30 = arith.addi %25, %20 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%31 = arith.addi %26, %21 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%32 = arith.addi %27, %22 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>>
%33 = arith.addi %28, %23 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>>
%34 = arith.addi %29, %24 : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%35 = tt.expand_dims %30 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xi32, #mma>
%36 = tt.expand_dims %31 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked>
%37 = tt.expand_dims %32 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> -> tensor<128x1xi32, #mma1>
%38 = tt.splat %arg12 : i32 -> tensor<128x1xi32, #mma>
%39 = tt.splat %arg12 : i32 -> tensor<128x1xi32, #blocked>
%40 = arith.muli %35, %38 : tensor<128x1xi32, #mma>
%41 = arith.muli %36, %39 : tensor<128x1xi32, #blocked>
%42 = tt.splat %11 : !tt.ptr<f16> -> tensor<128x1x!tt.ptr<f16>, #blocked>
%43 = tt.addptr %42, %41 : tensor<128x1x!tt.ptr<f16>, #blocked>, tensor<128x1xi32, #blocked>
%44 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>>
%45 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
%46 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%47 = tt.expand_dims %44 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>> -> tensor<1x64xi32, #mma>
%48 = tt.expand_dims %45 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
%49 = tt.expand_dims %46 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1>
%50 = tt.broadcast %43 : tensor<128x1x!tt.ptr<f16>, #blocked> -> tensor<128x64x!tt.ptr<f16>, #blocked>
%51 = tt.broadcast %47 : tensor<1x64xi32, #mma> -> tensor<128x64xi32, #mma>
%52 = tt.broadcast %48 : tensor<1x64xi32, #blocked> -> tensor<128x64xi32, #blocked>
%53 = tt.addptr %50, %52 : tensor<128x64x!tt.ptr<f16>, #blocked>, tensor<128x64xi32, #blocked>
%54 = tt.load %53 : tensor<128x64x!tt.ptr<f16>, #blocked>
%55 = tt.splat %12 : !tt.ptr<f16> -> tensor<128x1x!tt.ptr<f16>, #blocked>
%56 = tt.addptr %55, %41 : tensor<128x1x!tt.ptr<f16>, #blocked>, tensor<128x1xi32, #blocked>
%57 = tt.broadcast %56 : tensor<128x1x!tt.ptr<f16>, #blocked> -> tensor<128x64x!tt.ptr<f16>, #blocked>
%58 = tt.addptr %57, %52 : tensor<128x64x!tt.ptr<f16>, #blocked>, tensor<128x64xi32, #blocked>
%59 = tt.load %58 : tensor<128x64x!tt.ptr<f16>, #blocked>
%60 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
%61 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>>
%62 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%63 = tt.splat %19 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
%64 = tt.splat %19 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%65 = arith.addi %63, %60 : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>>
%66 = arith.addi %64, %62 : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
%67 = tt.expand_dims %65 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x16xi32, #blocked2>
%68 = tt.splat %arg12 : i32 -> tensor<1x16xi32, #blocked2>
%69 = arith.muli %67, %68 : tensor<1x16xi32, #blocked2>
%70 = tt.splat %10 : !tt.ptr<f16> -> tensor<1x16x!tt.ptr<f16>, #blocked2>
%71 = tt.addptr %70, %69 : tensor<1x16x!tt.ptr<f16>, #blocked2>, tensor<1x16xi32, #blocked2>
%72 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>>
%73 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>>
%74 = tt.expand_dims %72 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<64x1xi32, #blocked2>
%75 = tt.expand_dims %73 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<64x1xi32, #blocked3>
%76 = tt.broadcast %71 : tensor<1x16x!tt.ptr<f16>, #blocked2> -> tensor<64x16x!tt.ptr<f16>, #blocked2>
%77 = tt.broadcast %74 : tensor<64x1xi32, #blocked2> -> tensor<64x16xi32, #blocked2>
%78 = tt.addptr %76, %77 : tensor<64x16x!tt.ptr<f16>, #blocked2>, tensor<64x16xi32, #blocked2>
%79 = tt.expand_dims %66 {axis = 1 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<16x1xi32, #blocked1>
%80 = tt.splat %arg12 : i32 -> tensor<16x1xi32, #blocked1>
%81 = arith.muli %79, %80 : tensor<16x1xi32, #blocked1>
%82 = tt.splat %13 : !tt.ptr<f16> -> tensor<16x1x!tt.ptr<f16>, #blocked1>
%83 = tt.addptr %82, %81 : tensor<16x1x!tt.ptr<f16>, #blocked1>, tensor<16x1xi32, #blocked1>
%84 = tt.broadcast %83 : tensor<16x1x!tt.ptr<f16>, #blocked1> -> tensor<16x64x!tt.ptr<f16>, #blocked1>
%85 = tt.broadcast %49 : tensor<1x64xi32, #blocked1> -> tensor<16x64xi32, #blocked1>
%86 = tt.addptr %84, %85 : tensor<16x64x!tt.ptr<f16>, #blocked1>, tensor<16x64xi32, #blocked1>
%87 = tt.splat %17 : !tt.ptr<f32> -> tensor<16x!tt.ptr<f32>, #triton_gpu.slice<{dim = 0, parent = #mma1}>>
%88 = tt.broadcast %37 : tensor<128x1xi32, #mma1> -> tensor<128x16xi32, #mma1>
%89 = tt.splat %18 : !tt.ptr<f32> -> tensor<16x!tt.ptr<f32>, #triton_gpu.slice<{dim = 0, parent = #mma1}>>
%90 = arith.muli %arg12, %c16_i32 : i32
%91 = tt.splat %90 : i32 -> tensor<64x16xi32, #blocked2>
%92 = tt.splat %90 : i32 -> tensor<16x64xi32, #blocked1>
%93:5 = scf.for %arg15 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg16 = %cst_1, %arg17 = %cst_1, %arg18 = %19, %arg19 = %78, %arg20 = %86) -> (tensor<128x64xf32, #mma>, tensor<128x64xf32, #mma>, i32, tensor<64x16x!tt.ptr<f16>, #blocked2>, tensor<16x64x!tt.ptr<f16>, #blocked1>) : i32 {
%206 = tt.load %arg19 : tensor<64x16x!tt.ptr<f16>, #blocked2>
%207 = tt.splat %arg18 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>>
%208 = arith.addi %207, %61 : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>>
%209 = tt.addptr %87, %208 : tensor<16x!tt.ptr<f32>, #triton_gpu.slice<{dim = 0, parent = #mma1}>>, tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>>
%210 = tt.load %209 : tensor<16x!tt.ptr<f32>, #triton_gpu.slice<{dim = 0, parent = #mma1}>>
%211 = triton_gpu.local_alloc %54 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>
%212 = triton_gpu.local_load %211 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>>
%213 = triton_gpu.local_alloc %206 : (tensor<64x16xf16, #blocked2>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>
%214 = triton_gpu.local_load %213 : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>>
%215 = tt.dot %212, %214, %cst_0 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> -> tensor<128x16xf32, #mma1>
%216 = tt.expand_dims %210 {axis = 0 : i32} : tensor<16xf32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> -> tensor<1x16xf32, #mma1>
%217 = tt.broadcast %216 : tensor<1x16xf32, #mma1> -> tensor<128x16xf32, #mma1>
%218 = arith.subf %215, %217 : tensor<128x16xf32, #mma1>
%219 = math.exp2 %218 : tensor<128x16xf32, #mma1>
%220 = tt.expand_dims %208 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> -> tensor<1x16xi32, #mma1>
%221 = tt.broadcast %220 : tensor<1x16xi32, #mma1> -> tensor<128x16xi32, #mma1>
%222 = arith.cmpi sge, %221, %88 : tensor<128x16xi32, #mma1>
%223 = arith.select %222, %219, %cst_0 : tensor<128x16xi1, #mma1>, tensor<128x16xf32, #mma1>
%224 = tt.load %arg20 : tensor<16x64x!tt.ptr<f16>, #blocked1>
%225 = arith.truncf %223 : tensor<128x16xf32, #mma1> to tensor<128x16xf16, #mma1>
%226 = triton_gpu.local_alloc %225 : (tensor<128x16xf16, #mma1>) -> !tt.memdesc<128x16xf16, #shared2, #triton_gpu.shared_memory>
%227 = triton_gpu.local_load %226 : !tt.memdesc<128x16xf16, #shared2, #triton_gpu.shared_memory> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
%228 = triton_gpu.local_alloc %224 : (tensor<16x64xf16, #blocked1>) -> !tt.memdesc<16x64xf16, #shared3, #triton_gpu.shared_memory>
%229 = triton_gpu.local_load %228 : !tt.memdesc<16x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
%230 = tt.dot %227, %229, %arg16 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, #mma>
%231 = tt.addptr %89, %208 : tensor<16x!tt.ptr<f32>, #triton_gpu.slice<{dim = 0, parent = #mma1}>>, tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>>
%232 = tt.load %231 : tensor<16x!tt.ptr<f32>, #triton_gpu.slice<{dim = 0, parent = #mma1}>>
%233 = triton_gpu.local_alloc %224 : (tensor<16x64xf16, #blocked1>) -> !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory>
%234 = tt.trans %233 {order = array<i32: 1, 0>} : !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>
%235 = triton_gpu.local_load %234 : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>>
%236 = triton_gpu.local_alloc %59 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>
%237 = triton_gpu.local_load %236 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>>
%238 = tt.dot %237, %235, %cst_0 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> -> tensor<128x16xf32, #mma1>
%239 = tt.expand_dims %232 {axis = 0 : i32} : tensor<16xf32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> -> tensor<1x16xf32, #mma1>
%240 = tt.broadcast %239 : tensor<1x16xf32, #mma1> -> tensor<128x16xf32, #mma1>
%241 = arith.subf %238, %240 : tensor<128x16xf32, #mma1>
%242 = arith.mulf %223, %241 : tensor<128x16xf32, #mma1>
%243 = arith.truncf %242 : tensor<128x16xf32, #mma1> to tensor<128x16xf16, #mma1>
%244 = triton_gpu.local_alloc %206 : (tensor<64x16xf16, #blocked2>) -> !tt.memdesc<64x16xf16, #shared2, #triton_gpu.shared_memory>
%245 = tt.trans %244 {order = array<i32: 1, 0>} : !tt.memdesc<64x16xf16, #shared2, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared3, #triton_gpu.shared_memory>
%246 = triton_gpu.local_load %245 : !tt.memdesc<16x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
%247 = triton_gpu.local_alloc %243 : (tensor<128x16xf16, #mma1>) -> !tt.memdesc<128x16xf16, #shared2, #triton_gpu.shared_memory>
%248 = triton_gpu.local_load %247 : !tt.memdesc<128x16xf16, #shared2, #triton_gpu.shared_memory> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
%249 = tt.dot %248, %246, %arg17 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, #mma>
%250 = arith.addi %arg18, %c16_i32 : i32
%251 = tt.addptr %arg19, %91 : tensor<64x16x!tt.ptr<f16>, #blocked2>, tensor<64x16xi32, #blocked2>
%252 = tt.addptr %arg20, %92 : tensor<16x64x!tt.ptr<f16>, #blocked1>, tensor<16x64xi32, #blocked1>
scf.yield %230, %249, %250, %251, %252 : tensor<128x64xf32, #mma>, tensor<128x64xf32, #mma>, i32, tensor<64x16x!tt.ptr<f16>, #blocked2>, tensor<16x64x!tt.ptr<f16>, #blocked1>
}
%94 = arith.addi %19, %c128_i32 : i32
%95 = arith.subi %arg14, %94 : i32
%96 = arith.divsi %95, %c32_i32 : i32
%97 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>>
%98 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>>
%99 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%100 = tt.splat %94 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>>
%101 = tt.splat %94 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%102 = arith.addi %100, %97 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>>
%103 = arith.addi %101, %99 : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%104 = tt.expand_dims %102 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x32xi32, #blocked3>
%105 = tt.splat %arg12 : i32 -> tensor<1x32xi32, #blocked3>
%106 = arith.muli %104, %105 : tensor<1x32xi32, #blocked3>
%107 = tt.splat %10 : !tt.ptr<f16> -> tensor<1x32x!tt.ptr<f16>, #blocked3>
%108 = tt.addptr %107, %106 : tensor<1x32x!tt.ptr<f16>, #blocked3>, tensor<1x32xi32, #blocked3>
%109 = tt.broadcast %108 : tensor<1x32x!tt.ptr<f16>, #blocked3> -> tensor<64x32x!tt.ptr<f16>, #blocked3>
%110 = tt.broadcast %75 : tensor<64x1xi32, #blocked3> -> tensor<64x32xi32, #blocked3>
%111 = tt.addptr %109, %110 : tensor<64x32x!tt.ptr<f16>, #blocked3>, tensor<64x32xi32, #blocked3>
%112 = tt.expand_dims %103 {axis = 1 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked>
%113 = tt.splat %arg12 : i32 -> tensor<32x1xi32, #blocked>
%114 = arith.muli %112, %113 : tensor<32x1xi32, #blocked>
%115 = tt.splat %13 : !tt.ptr<f16> -> tensor<32x1x!tt.ptr<f16>, #blocked>
%116 = tt.addptr %115, %114 : tensor<32x1x!tt.ptr<f16>, #blocked>, tensor<32x1xi32, #blocked>
%117 = tt.broadcast %116 : tensor<32x1x!tt.ptr<f16>, #blocked> -> tensor<32x64x!tt.ptr<f16>, #blocked>
%118 = tt.broadcast %48 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked>
%119 = tt.addptr %117, %118 : tensor<32x64x!tt.ptr<f16>, #blocked>, tensor<32x64xi32, #blocked>
%120 = tt.splat %17 : !tt.ptr<f32> -> tensor<32x!tt.ptr<f32>, #triton_gpu.slice<{dim = 0, parent = #mma}>>
%121 = tt.splat %18 : !tt.ptr<f32> -> tensor<32x!tt.ptr<f32>, #triton_gpu.slice<{dim = 0, parent = #mma}>>
%122 = arith.muli %arg12, %c32_i32 : i32
%123 = tt.splat %122 : i32 -> tensor<64x32xi32, #blocked3>
%124 = tt.splat %122 : i32 -> tensor<32x64xi32, #blocked>
%125:5 = scf.for %arg15 = %c0_i32 to %96 step %c1_i32 iter_args(%arg16 = %93#0, %arg17 = %93#1, %arg18 = %94, %arg19 = %111, %arg20 = %119) -> (tensor<128x64xf32, #mma>, tensor<128x64xf32, #mma>, i32, tensor<64x32x!tt.ptr<f16>, #blocked3>, tensor<32x64x!tt.ptr<f16>, #blocked>) : i32 {
%206 = tt.load %arg19 : tensor<64x32x!tt.ptr<f16>, #blocked3>
%207 = tt.splat %arg18 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>>
%208 = arith.addi %207, %98 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>>
%209 = tt.addptr %120, %208 : tensor<32x!tt.ptr<f32>, #triton_gpu.slice<{dim = 0, parent = #mma}>>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>>
%210 = tt.load %209 : tensor<32x!tt.ptr<f32>, #triton_gpu.slice<{dim = 0, parent = #mma}>>
%211 = triton_gpu.local_alloc %54 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>
%212 = triton_gpu.local_load %211 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
%213 = triton_gpu.local_alloc %206 : (tensor<64x32xf16, #blocked3>) -> !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory>
%214 = triton_gpu.local_load %213 : !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
%215 = tt.dot %212, %214, %cst : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x32xf32, #mma>
%216 = tt.expand_dims %210 {axis = 0 : i32} : tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> -> tensor<1x32xf32, #mma>
%217 = tt.broadcast %216 : tensor<1x32xf32, #mma> -> tensor<128x32xf32, #mma>
%218 = arith.subf %215, %217 : tensor<128x32xf32, #mma>
%219 = math.exp2 %218 : tensor<128x32xf32, #mma>
%220 = tt.load %arg20 : tensor<32x64x!tt.ptr<f16>, #blocked>
%221 = arith.truncf %219 : tensor<128x32xf32, #mma> to tensor<128x32xf16, #mma>
%222 = triton_gpu.convert_layout %221 : tensor<128x32xf16, #mma> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
%223 = triton_gpu.local_alloc %220 : (tensor<32x64xf16, #blocked>) -> !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory>
%224 = triton_gpu.local_load %223 : !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
%225 = tt.dot %222, %224, %arg16 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, #mma>
%226 = tt.addptr %121, %208 : tensor<32x!tt.ptr<f32>, #triton_gpu.slice<{dim = 0, parent = #mma}>>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #mma}>>
%227 = tt.load %226 : tensor<32x!tt.ptr<f32>, #triton_gpu.slice<{dim = 0, parent = #mma}>>
%228 = triton_gpu.local_alloc %220 : (tensor<32x64xf16, #blocked>) -> !tt.memdesc<32x64xf16, #shared, #triton_gpu.shared_memory>
%229 = tt.trans %228 {order = array<i32: 1, 0>} : !tt.memdesc<32x64xf16, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory>
%230 = triton_gpu.local_load %229 : !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
%231 = triton_gpu.local_alloc %59 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>
%232 = triton_gpu.local_load %231 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
%233 = tt.dot %232, %230, %cst : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x32xf32, #mma>
%234 = tt.expand_dims %227 {axis = 0 : i32} : tensor<32xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> -> tensor<1x32xf32, #mma>
%235 = tt.broadcast %234 : tensor<1x32xf32, #mma> -> tensor<128x32xf32, #mma>
%236 = arith.subf %233, %235 : tensor<128x32xf32, #mma>
%237 = arith.mulf %219, %236 : tensor<128x32xf32, #mma>
%238 = arith.truncf %237 : tensor<128x32xf32, #mma> to tensor<128x32xf16, #mma>
%239 = triton_gpu.local_alloc %206 : (tensor<64x32xf16, #blocked3>) -> !tt.memdesc<64x32xf16, #shared2, #triton_gpu.shared_memory>
%240 = tt.trans %239 {order = array<i32: 1, 0>} : !tt.memdesc<64x32xf16, #shared2, #triton_gpu.shared_memory> -> !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory>
%241 = triton_gpu.local_load %240 : !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
%242 = triton_gpu.convert_layout %238 : tensor<128x32xf16, #mma> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
%243 = tt.dot %242, %241, %arg17 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, #mma>
%244 = arith.addi %arg18, %c32_i32 : i32
%245 = tt.addptr %arg19, %123 : tensor<64x32x!tt.ptr<f16>, #blocked3>, tensor<64x32xi32, #blocked3>
%246 = tt.addptr %arg20, %124 : tensor<32x64x!tt.ptr<f16>, #blocked>, tensor<32x64xi32, #blocked>
scf.yield %225, %243, %244, %245, %246 : tensor<128x64xf32, #mma>, tensor<128x64xf32, #mma>, i32, tensor<64x32x!tt.ptr<f16>, #blocked3>, tensor<32x64x!tt.ptr<f16>, #blocked>
}
%126 = tt.splat %16 : !tt.ptr<f16> -> tensor<128x1x!tt.ptr<f16>, #mma>
%127 = tt.addptr %126, %40 : tensor<128x1x!tt.ptr<f16>, #mma>, tensor<128x1xi32, #mma>
%128 = tt.broadcast %127 : tensor<128x1x!tt.ptr<f16>, #mma> -> tensor<128x64x!tt.ptr<f16>, #mma>
%129 = tt.addptr %128, %51 : tensor<128x64x!tt.ptr<f16>, #mma>, tensor<128x64xi32, #mma>
%130 = arith.truncf %125#0 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma>
tt.store %129, %130 : tensor<128x64x!tt.ptr<f16>, #mma>
%131 = tt.splat %arg3 : f32 -> tensor<128x64xf32, #mma>
%132 = arith.mulf %125#1, %131 : tensor<128x64xf32, #mma>
%133 = tt.splat %15 : !tt.ptr<f16> -> tensor<128x1x!tt.ptr<f16>, #mma>
%134 = tt.addptr %133, %40 : tensor<128x1x!tt.ptr<f16>, #mma>, tensor<128x1xi32, #mma>
%135 = tt.broadcast %134 : tensor<128x1x!tt.ptr<f16>, #mma> -> tensor<128x64x!tt.ptr<f16>, #mma>
%136 = tt.addptr %135, %51 : tensor<128x64x!tt.ptr<f16>, #mma>, tensor<128x64xi32, #mma>
%137 = arith.truncf %132 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma>
tt.store %136, %137 : tensor<128x64x!tt.ptr<f16>, #mma>
%138 = tt.splat %10 : !tt.ptr<f16> -> tensor<128x1x!tt.ptr<f16>, #blocked>
%139 = tt.addptr %138, %41 : tensor<128x1x!tt.ptr<f16>, #blocked>, tensor<128x1xi32, #blocked>
%140 = tt.broadcast %139 : tensor<128x1x!tt.ptr<f16>, #blocked> -> tensor<128x64x!tt.ptr<f16>, #blocked>
%141 = tt.addptr %140, %52 : tensor<128x64x!tt.ptr<f16>, #blocked>, tensor<128x64xi32, #blocked>
%142 = tt.load %141 : tensor<128x64x!tt.ptr<f16>, #blocked>
%143 = tt.splat %13 : !tt.ptr<f16> -> tensor<128x1x!tt.ptr<f16>, #blocked>
%144 = tt.addptr %143, %41 : tensor<128x1x!tt.ptr<f16>, #blocked>, tensor<128x1xi32, #blocked>
%145 = tt.broadcast %144 : tensor<128x1x!tt.ptr<f16>, #blocked> -> tensor<128x64x!tt.ptr<f16>, #blocked>
%146 = tt.addptr %145, %52 : tensor<128x64x!tt.ptr<f16>, #blocked>, tensor<128x64xi32, #blocked>
%147 = tt.load %146 : tensor<128x64x!tt.ptr<f16>, #blocked>
%148 = tt.splat %17 : !tt.ptr<f32> -> tensor<128x!tt.ptr<f32>, #triton_gpu.slice<{dim = 1, parent = #mma1}>>
%149 = tt.splat %17 : !tt.ptr<f32> -> tensor<128x!tt.ptr<f32>, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%150 = tt.addptr %148, %33 : tensor<128x!tt.ptr<f32>, #triton_gpu.slice<{dim = 1, parent = #mma1}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>>
%151 = tt.addptr %149, %34 : tensor<128x!tt.ptr<f32>, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%152 = tt.load %150 : tensor<128x!tt.ptr<f32>, #triton_gpu.slice<{dim = 1, parent = #mma1}>>
%153 = tt.load %151 : tensor<128x!tt.ptr<f32>, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%154 = tt.expand_dims %152 {axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> -> tensor<128x1xf32, #mma1>
%155 = tt.expand_dims %153 {axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xf32, #mma>
%156 = tt.splat %11 : !tt.ptr<f16> -> tensor<1x16x!tt.ptr<f16>, #blocked2>
%157 = tt.addptr %156, %69 : tensor<1x16x!tt.ptr<f16>, #blocked2>, tensor<1x16xi32, #blocked2>
%158 = tt.broadcast %157 : tensor<1x16x!tt.ptr<f16>, #blocked2> -> tensor<64x16x!tt.ptr<f16>, #blocked2>
%159 = tt.addptr %158, %77 : tensor<64x16x!tt.ptr<f16>, #blocked2>, tensor<64x16xi32, #blocked2>
%160 = tt.splat %12 : !tt.ptr<f16> -> tensor<1x16x!tt.ptr<f16>, #blocked2>
%161 = tt.addptr %160, %69 : tensor<1x16x!tt.ptr<f16>, #blocked2>, tensor<1x16xi32, #blocked2>
%162 = tt.broadcast %161 : tensor<1x16x!tt.ptr<f16>, #blocked2> -> tensor<64x16x!tt.ptr<f16>, #blocked2>
%163 = tt.addptr %162, %77 : tensor<64x16x!tt.ptr<f16>, #blocked2>, tensor<64x16xi32, #blocked2>
%164 = tt.splat %18 : !tt.ptr<f32> -> tensor<128x!tt.ptr<f32>, #triton_gpu.slice<{dim = 1, parent = #mma1}>>
%165 = tt.splat %18 : !tt.ptr<f32> -> tensor<128x!tt.ptr<f32>, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%166 = tt.addptr %164, %33 : tensor<128x!tt.ptr<f32>, #triton_gpu.slice<{dim = 1, parent = #mma1}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma1}>>
%167 = tt.addptr %165, %34 : tensor<128x!tt.ptr<f32>, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%168 = tt.load %166 : tensor<128x!tt.ptr<f32>, #triton_gpu.slice<{dim = 1, parent = #mma1}>>
%169 = tt.load %167 : tensor<128x!tt.ptr<f32>, #triton_gpu.slice<{dim = 1, parent = #mma}>>
%170 = tt.broadcast %154 : tensor<128x1xf32, #mma1> -> tensor<128x16xf32, #mma1>
%171 = tt.broadcast %37 : tensor<128x1xi32, #mma1> -> tensor<128x16xi32, #mma1>
%172 = tt.expand_dims %168 {axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma1}>> -> tensor<128x1xf32, #mma1>
%173 = tt.broadcast %172 : tensor<128x1xf32, #mma1> -> tensor<128x16xf32, #mma1>
%174 = arith.muli %arg12, %c16_i32 : i32
%175 = tt.splat %174 : i32 -> tensor<64x16xi32, #blocked2>
%176 = triton_gpu.local_alloc : () -> !tt.memdesc<1x64x16xf16, #shared1, #triton_gpu.shared_memory, mutable>
%177:5 = scf.for %arg15 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg16 = %cst_1, %arg17 = %19, %arg18 = %159, %arg19 = %163, %arg20 = %c-1_i32) -> (tensor<128x64xf32, #mma>, i32, tensor<64x16x!tt.ptr<f16>, #blocked2>, tensor<64x16x!tt.ptr<f16>, #blocked2>, i32) : i32 {
%206 = arith.addi %arg20, %c1_i32 : i32
%207 = arith.cmpi slt, %206, %c1_i32 : i32
%208 = arith.select %207, %206, %c0_i32 : i32
%209 = tt.load %arg18 : tensor<64x16x!tt.ptr<f16>, #blocked2>
%210 = tt.load %arg19 : tensor<64x16x!tt.ptr<f16>, #blocked2>
%211 = triton_gpu.memdesc_subview %176[%208, %c0_i32, %c0_i32] : !tt.memdesc<1x64x16xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory, mutable>
triton_gpu.local_store %210, %211 : tensor<64x16xf16, #blocked2> -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory, mutable>
%212 = triton_gpu.local_load %211 : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>>
%213 = triton_gpu.local_alloc %142 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>
%214 = triton_gpu.local_load %213 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>>
%215 = triton_gpu.local_alloc %209 : (tensor<64x16xf16, #blocked2>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>
%216 = triton_gpu.local_load %215 : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>>
%217 = tt.dot %214, %216, %cst_0 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> -> tensor<128x16xf32, #mma1>
%218 = arith.subf %217, %170 : tensor<128x16xf32, #mma1>
%219 = math.exp2 %218 : tensor<128x16xf32, #mma1>
%220 = tt.splat %arg17 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>>
%221 = arith.addi %220, %61 : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>>
%222 = tt.expand_dims %221 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> -> tensor<1x16xi32, #mma1>
%223 = tt.broadcast %222 : tensor<1x16xi32, #mma1> -> tensor<128x16xi32, #mma1>
%224 = arith.cmpi sge, %171, %223 : tensor<128x16xi32, #mma1>
%225 = arith.select %224, %219, %cst_0 : tensor<128x16xi1, #mma1>, tensor<128x16xf32, #mma1>
%226 = triton_gpu.local_alloc %147 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>
%227 = triton_gpu.local_load %226 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>>
%228 = tt.dot %227, %212, %cst_0 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, kWidth = 4}>> -> tensor<128x16xf32, #mma1>
%229 = arith.subf %228, %173 : tensor<128x16xf32, #mma1>
%230 = arith.mulf %225, %229 : tensor<128x16xf32, #mma1>
%231 = arith.truncf %230 : tensor<128x16xf32, #mma1> to tensor<128x16xf16, #mma1>
%232 = triton_gpu.local_alloc %209 : (tensor<64x16xf16, #blocked2>) -> !tt.memdesc<64x16xf16, #shared2, #triton_gpu.shared_memory>
%233 = tt.trans %232 {order = array<i32: 1, 0>} : !tt.memdesc<64x16xf16, #shared2, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared3, #triton_gpu.shared_memory>
%234 = triton_gpu.local_load %233 : !tt.memdesc<16x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
%235 = triton_gpu.local_alloc %231 : (tensor<128x16xf16, #mma1>) -> !tt.memdesc<128x16xf16, #shared2, #triton_gpu.shared_memory>
%236 = triton_gpu.local_load %235 : !tt.memdesc<128x16xf16, #shared2, #triton_gpu.shared_memory> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
%237 = tt.dot %236, %234, %arg16 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, #mma>
%238 = arith.addi %arg17, %c16_i32 : i32
%239 = tt.addptr %arg18, %175 : tensor<64x16x!tt.ptr<f16>, #blocked2>, tensor<64x16xi32, #blocked2>
%240 = tt.addptr %arg19, %175 : tensor<64x16x!tt.ptr<f16>, #blocked2>, tensor<64x16xi32, #blocked2>
scf.yield %237, %238, %239, %240, %208 : tensor<128x64xf32, #mma>, i32, tensor<64x16x!tt.ptr<f16>, #blocked2>, tensor<64x16x!tt.ptr<f16>, #blocked2>, i32
}
triton_gpu.local_dealloc %176 : !tt.memdesc<1x64x16xf16, #shared1, #triton_gpu.shared_memory, mutable>
%178 = arith.divsi %19, %c32_i32 : i32
%179 = arith.muli %178, %c32_i32 : i32
%180 = arith.subi %19, %179 : i32
%181 = tt.splat %180 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>>
%182 = arith.addi %181, %97 : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>>
%183 = tt.expand_dims %182 {axis = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x32xi32, #blocked3>
%184 = arith.muli %183, %105 : tensor<1x32xi32, #blocked3>
%185 = tt.splat %11 : !tt.ptr<f16> -> tensor<1x32x!tt.ptr<f16>, #blocked3>
%186 = tt.addptr %185, %184 : tensor<1x32x!tt.ptr<f16>, #blocked3>, tensor<1x32xi32, #blocked3>
%187 = tt.broadcast %186 : tensor<1x32x!tt.ptr<f16>, #blocked3> -> tensor<64x32x!tt.ptr<f16>, #blocked3>
%188 = tt.addptr %187, %110 : tensor<64x32x!tt.ptr<f16>, #blocked3>, tensor<64x32xi32, #blocked3>
%189 = tt.splat %12 : !tt.ptr<f16> -> tensor<1x32x!tt.ptr<f16>, #blocked3>
%190 = tt.addptr %189, %184 : tensor<1x32x!tt.ptr<f16>, #blocked3>, tensor<1x32xi32, #blocked3>
%191 = tt.broadcast %190 : tensor<1x32x!tt.ptr<f16>, #blocked3> -> tensor<64x32x!tt.ptr<f16>, #blocked3>
%192 = tt.addptr %191, %110 : tensor<64x32x!tt.ptr<f16>, #blocked3>, tensor<64x32xi32, #blocked3>
%193 = tt.broadcast %155 : tensor<128x1xf32, #mma> -> tensor<128x32xf32, #mma>
%194 = tt.expand_dims %169 {axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xf32, #mma>
%195 = tt.broadcast %194 : tensor<128x1xf32, #mma> -> tensor<128x32xf32, #mma>
%196 = arith.muli %arg12, %c32_i32 : i32
%197 = tt.splat %196 : i32 -> tensor<64x32xi32, #blocked3>
%198 = triton_gpu.local_alloc : () -> !tt.memdesc<1x64x32xf16, #shared1, #triton_gpu.shared_memory, mutable>
%199:4 = scf.for %arg15 = %c0_i32 to %178 step %c1_i32 iter_args(%arg16 = %177#0, %arg17 = %188, %arg18 = %192, %arg19 = %c-1_i32) -> (tensor<128x64xf32, #mma>, tensor<64x32x!tt.ptr<f16>, #blocked3>, tensor<64x32x!tt.ptr<f16>, #blocked3>, i32) : i32 {
%206 = arith.addi %arg19, %c1_i32 : i32
%207 = arith.cmpi slt, %206, %c1_i32 : i32
%208 = arith.select %207, %206, %c0_i32 : i32
%209 = tt.load %arg17 : tensor<64x32x!tt.ptr<f16>, #blocked3>
%210 = tt.load %arg18 : tensor<64x32x!tt.ptr<f16>, #blocked3>
%211 = triton_gpu.memdesc_subview %198[%208, %c0_i32, %c0_i32] : !tt.memdesc<1x64x32xf16, #shared1, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory, mutable>
triton_gpu.local_store %210, %211 : tensor<64x32xf16, #blocked3> -> !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory, mutable>
%212 = triton_gpu.local_load %211 : !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
%213 = triton_gpu.local_alloc %142 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>
%214 = triton_gpu.local_load %213 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
%215 = triton_gpu.local_alloc %209 : (tensor<64x32xf16, #blocked3>) -> !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory>
%216 = triton_gpu.local_load %215 : !tt.memdesc<64x32xf16, #shared1, #triton_gpu.shared_memory> -> tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
%217 = tt.dot %214, %216, %cst : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x32xf32, #mma>
%218 = arith.subf %217, %193 : tensor<128x32xf32, #mma>
%219 = math.exp2 %218 : tensor<128x32xf32, #mma>
%220 = triton_gpu.local_alloc %147 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>
%221 = triton_gpu.local_load %220 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
%222 = tt.dot %221, %212, %cst : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x32xf32, #mma>
%223 = arith.subf %222, %195 : tensor<128x32xf32, #mma>
%224 = arith.mulf %219, %223 : tensor<128x32xf32, #mma>
%225 = arith.truncf %224 : tensor<128x32xf32, #mma> to tensor<128x32xf16, #mma>
%226 = triton_gpu.local_alloc %209 : (tensor<64x32xf16, #blocked3>) -> !tt.memdesc<64x32xf16, #shared2, #triton_gpu.shared_memory>
%227 = tt.trans %226 {order = array<i32: 1, 0>} : !tt.memdesc<64x32xf16, #shared2, #triton_gpu.shared_memory> -> !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory>
%228 = triton_gpu.local_load %227 : !tt.memdesc<32x64xf16, #shared3, #triton_gpu.shared_memory> -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
%229 = triton_gpu.convert_layout %225 : tensor<128x32xf16, #mma> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
%230 = tt.dot %229, %228, %arg16 : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x64xf32, #mma>
%231 = tt.addptr %arg17, %197 : tensor<64x32x!tt.ptr<f16>, #blocked3>, tensor<64x32xi32, #blocked3>
%232 = tt.addptr %arg18, %197 : tensor<64x32x!tt.ptr<f16>, #blocked3>, tensor<64x32xi32, #blocked3>
scf.yield %230, %231, %232, %208 : tensor<128x64xf32, #mma>, tensor<64x32x!tt.ptr<f16>, #blocked3>, tensor<64x32x!tt.ptr<f16>, #blocked3>, i32
}
triton_gpu.local_dealloc %198 : !tt.memdesc<1x64x32xf16, #shared1, #triton_gpu.shared_memory, mutable>
%200 = tt.splat %14 : !tt.ptr<f16> -> tensor<128x1x!tt.ptr<f16>, #mma>
%201 = tt.addptr %200, %40 : tensor<128x1x!tt.ptr<f16>, #mma>, tensor<128x1xi32, #mma>
%202 = tt.broadcast %201 : tensor<128x1x!tt.ptr<f16>, #mma> -> tensor<128x64x!tt.ptr<f16>, #mma>
%203 = tt.addptr %202, %51 : tensor<128x64x!tt.ptr<f16>, #mma>, tensor<128x64xi32, #mma>
%204 = arith.mulf %199#0, %cst_2 : tensor<128x64xf32, #mma>
%205 = arith.truncf %204 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma>
tt.store %203, %205 : tensor<128x64x!tt.ptr<f16>, #mma>
tt.return
}
}
// -----
// Test case: ordering of a convert_layout relative to local_dealloc ops.
// In the input below, the convert_layout of %arg0 is written before the two
// local_dealloc ops, while its only user (the arith.addf) comes after them.
// The directives that follow expect the output to contain both local_dealloc
// ops first and the convert_layout only after them.
// CHECK-LABEL: sink_convert_dealloc
// CHECK-COUNT-2: triton_gpu.local_dealloc %{{.+}} : !tt.memdesc<4x128x64xf16, #shared, mutable>
// CHECK: triton_gpu.convert_layout %arg0 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #blocked1>
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}>
#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}>
module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @sink_convert_dealloc(%arg0: tensor<32x32xf32, #blocked>) attributes {noinline = false} {
// Two mutable allocations with no initializer; they are only deallocated below.
%0 = triton_gpu.local_alloc : () -> !tt.memdesc<4x128x64xf16, #shared, mutable>
%1 = triton_gpu.local_alloc : () -> !tt.memdesc<4x128x64xf16, #shared, mutable>
// Layout conversion placed ahead of the deallocs in the input; the expected
// output has it after both of them.
%2 = triton_gpu.convert_layout %arg0 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #blocked1>
triton_gpu.local_dealloc %0 : !tt.memdesc<4x128x64xf16, #shared, mutable>
triton_gpu.local_dealloc %1 : !tt.memdesc<4x128x64xf16, #shared, mutable>
// Only user of %2 in this function.
%3 = arith.addf %2, %2 : tensor<32x32xf32, #blocked1>
tt.return
}
}
// -----
// Test case: ordering of a tt.load relative to a gpu.barrier.
// In the input below, a gpu.barrier sits between the first local_alloc and the
// tt.load whose result initializes the second local_alloc. The directives that
// follow expect gpu.barrier to still appear before the tt.load in the output,
// i.e. the load is not moved above the barrier.
// NOTE(review): #blocked1 is declared but not referenced anywhere in this
// split; presumably carried over from the previous test case - confirm.
// CHECK-LABEL: anchor_barrier
// CHECK: gpu.barrier
// CHECK: tt.load %arg0 : tensor<32x32x!tt.ptr<f16>, #blocked>
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}>
#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}>
module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @anchor_barrier(%arg0: tensor<32x32x!tt.ptr<f16>, #blocked>) attributes {noinline = false} {
// Allocation with no initializer, written before the barrier.
%0 = triton_gpu.local_alloc : () -> !tt.memdesc<4x128x64xf16, #shared, mutable>
gpu.barrier
// Load written after the barrier; expected to remain after it in the output.
%2 = tt.load %arg0 : tensor<32x32x!tt.ptr<f16>, #blocked>
// NOTE(review): the source tensor is 32x32 while the result memdesc is
// 4x128x64; presumably the op does not verify matching shapes - confirm.
%1 = triton_gpu.local_alloc %2 : (tensor<32x32xf16, #blocked>) -> !tt.memdesc<4x128x64xf16, #shared, mutable>
triton_gpu.local_dealloc %0 : !tt.memdesc<4x128x64xf16, #shared, mutable>
triton_gpu.local_dealloc %1 : !tt.memdesc<4x128x64xf16, #shared, mutable>
tt.return
}
}