// NOTE: Assertions have been autogenerated by mlir/utils/generate-test-checks.py
// RUN: triton-opt %s -allow-unregistered-dialect -split-input-file -tritonamdgpu-canonicalize-pointers -canonicalize -verify-diagnostics | FileCheck %s
// Smallest case: the offset added to the splatted pointer is itself the splat
// of a scalar (program id * 1024). The expectations that follow show the add
// performed on the scalar pointer, with a single splat feeding the load.
module attributes {"ttg.num-warps" = 4 : i32} {
  tt.func @conversion1(%arg0: !tt.ptr<f32>) -> tensor<1024xf32> {
    %block_size = arith.constant 1024 : i32
    %pid = tt.get_program_id x : i32
    %block_start = arith.muli %pid, %block_size : i32
    // Uniform offset: every element holds the same scalar.
    %start_splat = tt.splat %block_start : i32 -> tensor<1024xi32>
    %base_splat = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
    %ptrs = tt.addptr %base_splat, %start_splat : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
    %vals = tt.load %ptrs : tensor<1024x!tt.ptr<f32>>
    tt.return %vals : tensor<1024xf32>
  }
}
// CHECK-LABEL: tt.func @conversion1(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>) -> tensor<1024xf32> {
// CHECK: %[[VAL_1:.*]] = arith.constant 1024 : i32
// CHECK: %[[VAL_2:.*]] = tt.get_program_id x : i32
// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_2]], %[[VAL_1]] : i32
// CHECK: %[[VAL_4:.*]] = tt.addptr %[[VAL_0]], %[[VAL_3]] : !tt.ptr<f32>, i32
// CHECK: %[[VAL_5:.*]] = tt.splat %[[VAL_4]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
// CHECK: %[[VAL_6:.*]] = tt.load %[[VAL_5]] : tensor<1024x!tt.ptr<f32>>
// CHECK: tt.return %[[VAL_6]] : tensor<1024xf32>
// CHECK: }
// -----
// The offset is the sum of a splatted scalar (program id * 1024) and a
// tt.make_range. The expectations that follow show the scalar part added to
// the scalar pointer and the make_range part kept as a 32-bit tensor offset.
module attributes {"ttg.num-warps" = 4 : i32} {
  tt.func @conversion2(%arg0: !tt.ptr<f32>) -> tensor<1024xf32> {
    %block_size = arith.constant 1024 : i32
    %pid = tt.get_program_id x : i32
    %block_start = arith.muli %pid, %block_size : i32
    %lane = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
    %start_splat = tt.splat %block_start : i32 -> tensor<1024xi32>
    // Uniform part (splat) + non-uniform part (range).
    %offs = arith.addi %start_splat, %lane : tensor<1024xi32>
    %base_splat = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
    %ptrs = tt.addptr %base_splat, %offs : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
    %vals = tt.load %ptrs : tensor<1024x!tt.ptr<f32>>
    tt.return %vals : tensor<1024xf32>
  }
}
// CHECK-LABEL: tt.func @conversion2(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>) -> tensor<1024xf32> {
// CHECK: %[[VAL_1:.*]] = arith.constant 1024 : i32
// CHECK: %[[VAL_2:.*]] = tt.get_program_id x : i32
// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_2]], %[[VAL_1]] : i32
// CHECK: %[[VAL_4:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
// CHECK: %[[VAL_5:.*]] = tt.addptr %[[VAL_0]], %[[VAL_3]] : !tt.ptr<f32>, i32
// CHECK: %[[VAL_6:.*]] = tt.splat %[[VAL_5]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
// CHECK: %[[VAL_7:.*]] = tt.addptr %[[VAL_6]], %[[VAL_4]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
// CHECK: %[[VAL_8:.*]] = tt.load %[[VAL_7]] : tensor<1024x!tt.ptr<f32>>
// CHECK: tt.return %[[VAL_8]] : tensor<1024xf32>
// CHECK: }
// -----
// Two chained tt.addptr ops that reuse the same (splat + range) offset. The
// expectations that follow show two scalar pointer increments, and the two
// range offsets sign-extended to i64 and summed before one tensor tt.addptr.
module attributes {"ttg.num-warps" = 4 : i32} {
  tt.func @conversion3(%arg0: !tt.ptr<f32>) -> tensor<1024xf32> {
    %block_size = arith.constant 1024 : i32
    %pid = tt.get_program_id x : i32
    %block_start = arith.muli %pid, %block_size : i32
    %lane = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
    %start_splat = tt.splat %block_start : i32 -> tensor<1024xi32>
    %offs = arith.addi %start_splat, %lane : tensor<1024xi32>
    %base_splat = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
    // First increment, then a second one applied to its result.
    %ptrs_once = tt.addptr %base_splat, %offs : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
    %ptrs_twice = tt.addptr %ptrs_once, %offs : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
    %vals = tt.load %ptrs_twice : tensor<1024x!tt.ptr<f32>>
    tt.return %vals : tensor<1024xf32>
  }
}
// CHECK-LABEL: tt.func @conversion3(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>) -> tensor<1024xf32> {
// CHECK: %[[VAL_1:.*]] = arith.constant 1024 : i32
// CHECK: %[[VAL_2:.*]] = tt.get_program_id x : i32
// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_2]], %[[VAL_1]] : i32
// CHECK: %[[VAL_4:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
// CHECK: %[[VAL_5:.*]] = tt.addptr %[[VAL_0]], %[[VAL_3]] : !tt.ptr<f32>, i32
// CHECK: %[[VAL_6:.*]] = arith.extsi %[[VAL_4]] : tensor<1024xi32> to tensor<1024xi64>
// CHECK: %[[VAL_7:.*]] = tt.addptr %[[VAL_5]], %[[VAL_3]] : !tt.ptr<f32>, i32
// CHECK: %[[VAL_8:.*]] = arith.extsi %[[VAL_4]] : tensor<1024xi32> to tensor<1024xi64>
// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_6]] : tensor<1024xi64>
// CHECK: %[[VAL_10:.*]] = tt.splat %[[VAL_7]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
// CHECK: %[[VAL_11:.*]] = tt.addptr %[[VAL_10]], %[[VAL_9]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi64>
// CHECK: %[[VAL_12:.*]] = tt.load %[[VAL_11]] : tensor<1024x!tt.ptr<f32>>
// CHECK: tt.return %[[VAL_12]] : tensor<1024xf32>
// CHECK: }
// -----
//
// This is the same as conversion3, except that %arg0 is annotated with
// `tt.pointer_range = 32`, so the `arith.extsi` operations disappear and all
// the offsets stay 32 bits wide.
//
// Two chained tt.addptr ops on a pointer argument annotated with
// tt.pointer_range = 32. The expectations that follow keep every offset as i32.
module attributes {"ttg.num-warps" = 4 : i32} {
  tt.func @conversion4(%arg0: !tt.ptr<f32> {tt.pointer_range = 32 : i32}) -> tensor<1024xf32> {
    %block_size = arith.constant 1024 : i32
    %pid = tt.get_program_id x : i32
    %block_start = arith.muli %pid, %block_size : i32
    %lane = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
    %start_splat = tt.splat %block_start : i32 -> tensor<1024xi32>
    %offs = arith.addi %start_splat, %lane : tensor<1024xi32>
    %base_splat = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
    // First increment, then a second one applied to its result.
    %ptrs_once = tt.addptr %base_splat, %offs : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
    %ptrs_twice = tt.addptr %ptrs_once, %offs : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
    %vals = tt.load %ptrs_twice : tensor<1024x!tt.ptr<f32>>
    tt.return %vals : tensor<1024xf32>
  }
}
// CHECK-LABEL: tt.func @conversion4(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32> {tt.pointer_range = 32 : i32}) -> tensor<1024xf32> {
// CHECK: %[[VAL_1:.*]] = arith.constant 1024 : i32
// CHECK: %[[VAL_2:.*]] = tt.get_program_id x : i32
// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_2]], %[[VAL_1]] : i32
// CHECK: %[[VAL_4:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
// CHECK: %[[VAL_5:.*]] = tt.addptr %[[VAL_0]], %[[VAL_3]] : !tt.ptr<f32>, i32
// CHECK: %[[VAL_6:.*]] = tt.addptr %[[VAL_5]], %[[VAL_3]] : !tt.ptr<f32>, i32
// CHECK: %[[VAL_7:.*]] = arith.addi %[[VAL_4]], %[[VAL_4]] : tensor<1024xi32>
// CHECK: %[[VAL_8:.*]] = tt.splat %[[VAL_6]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
// CHECK: %[[VAL_9:.*]] = tt.addptr %[[VAL_8]], %[[VAL_7]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
// CHECK: %[[VAL_10:.*]] = tt.load %[[VAL_9]] : tensor<1024x!tt.ptr<f32>>
// CHECK: tt.return %[[VAL_10]] : tensor<1024xf32>
// CHECK: }
// -----
// A pointer tensor whose offsets are loaded from memory is passed through
// ttg.convert_layout before the final load. The expectations that follow show
// the layout conversion applied to the offset tensor, with the base pointer
// splatted directly in the destination layout.
#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  tt.func public @convertLayoutOp(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<i32>, %arg2: tensor<1024xi32, #blocked>) -> tensor<1024xf32, #blocked1> {
    %data_base = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked>
    %index_base = tt.splat %arg1 : !tt.ptr<i32> -> tensor<1024x!tt.ptr<i32>, #blocked>
    // Load the i32 offsets that are then applied to the f32 pointer.
    %index_ptrs = tt.addptr %index_base, %arg2 : tensor<1024x!tt.ptr<i32>, #blocked>, tensor<1024xi32, #blocked>
    %indices = tt.load %index_ptrs : tensor<1024x!tt.ptr<i32>, #blocked>
    %data_ptrs = tt.addptr %data_base, %indices : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked>
    // Layout change on the pointer tensor itself.
    %converted_ptrs = ttg.convert_layout %data_ptrs : tensor<1024x!tt.ptr<f32>, #blocked> -> tensor<1024x!tt.ptr<f32>, #blocked1>
    %vals = tt.load %converted_ptrs : tensor<1024x!tt.ptr<f32>, #blocked1>
    tt.return %vals : tensor<1024xf32, #blocked1>
  }
}
// CHECK: #[[$ATTR_0:.+]] = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>
// CHECK: #[[$ATTR_1:.+]] = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>
// CHECK-LABEL: tt.func public @convertLayoutOp(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>, %[[VAL_1:.*]]: !tt.ptr<i32>, %[[VAL_2:.*]]: tensor<1024xi32, #[[$ATTR_0]]>) -> tensor<1024xf32, #[[$ATTR_1]]> {
// CHECK: %[[VAL_3:.*]] = tt.splat %[[VAL_1]] : !tt.ptr<i32> -> tensor<1024x!tt.ptr<i32>, #[[$ATTR_0]]>
// CHECK: %[[VAL_4:.*]] = tt.addptr %[[VAL_3]], %[[VAL_2]] : tensor<1024x!tt.ptr<i32>, #[[$ATTR_0]]>, tensor<1024xi32, #[[$ATTR_0]]>
// CHECK: %[[VAL_5:.*]] = tt.load %[[VAL_4]] : tensor<1024x!tt.ptr<i32>, #[[$ATTR_0]]>
// CHECK: %[[VAL_6:.*]] = arith.extsi %[[VAL_5]] : tensor<1024xi32, #[[$ATTR_0]]> to tensor<1024xi64, #[[$ATTR_0]]>
// CHECK: %[[VAL_7:.*]] = ttg.convert_layout %[[VAL_6]] : tensor<1024xi64, #[[$ATTR_0]]> -> tensor<1024xi64, #[[$ATTR_1]]>
// CHECK: %[[VAL_8:.*]] = arith.trunci %[[VAL_7]] : tensor<1024xi64, #[[$ATTR_1]]> to tensor<1024xi32, #[[$ATTR_1]]>
// CHECK: %[[VAL_9:.*]] = tt.splat %[[VAL_0]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #[[$ATTR_1]]>
// CHECK: %[[VAL_10:.*]] = tt.addptr %[[VAL_9]], %[[VAL_8]] : tensor<1024x!tt.ptr<f32>, #[[$ATTR_1]]>, tensor<1024xi32, #[[$ATTR_1]]>
// CHECK: %[[VAL_11:.*]] = tt.load %[[VAL_10]] : tensor<1024x!tt.ptr<f32>, #[[$ATTR_1]]>
// CHECK: tt.return %[[VAL_11]] : tensor<1024xf32, #[[$ATTR_1]]>
// CHECK: }
// -----
// A pointer tensor that was already advanced once is carried through scf.for,
// advanced on every iteration, and advanced once more after the loop. The
// expectations that follow show the loop carrying a scalar pointer plus an
// i64 offset tensor in place of the pointer tensor.
module attributes {"ttg.num-warps" = 4 : i32} {
  tt.func @forOp(%arg0: !tt.ptr<f32>, %arg1: tensor<1024xf32>) -> tensor<1024xf32> {
    %block_size = arith.constant 1024 : i32
    %lb = arith.constant 0 : index
    %ub = arith.constant 128 : index
    %step = arith.constant 1 : index
    %pid = tt.get_program_id x : i32
    %block_start = arith.muli %pid, %block_size : i32
    %lane = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
    %start_splat = tt.splat %block_start : i32 -> tensor<1024xi32>
    %offs = arith.addi %start_splat, %lane : tensor<1024xi32>
    %base_splat = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
    // Initial value of the loop-carried pointer tensor.
    %init_ptrs = tt.addptr %base_splat, %offs : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
    %loop:2 = scf.for %iv = %lb to %ub step %step iter_args(%cur_ptrs = %init_ptrs, %acc = %arg1) -> (tensor<1024x!tt.ptr<f32>>, tensor<1024xf32>) {
      %next_ptrs = tt.addptr %cur_ptrs, %offs : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
      %vals = tt.load %next_ptrs : tensor<1024x!tt.ptr<f32>>
      %sum = arith.addf %vals, %acc : tensor<1024xf32>
      scf.yield %next_ptrs, %sum : tensor<1024x!tt.ptr<f32>>, tensor<1024xf32>
    }
    // The pointer result of the loop is advanced again before the final load.
    %final_ptrs = tt.addptr %loop#0, %offs : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
    %result = tt.load %final_ptrs : tensor<1024x!tt.ptr<f32>>
    tt.return %result : tensor<1024xf32>
  }
}
// CHECK-LABEL: tt.func @forOp(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1024xf32>) -> tensor<1024xf32> {
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 1024 : i32
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 128 : index
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_6:.*]] = tt.get_program_id x : i32
// CHECK: %[[VAL_7:.*]] = arith.muli %[[VAL_6]], %[[VAL_2]] : i32
// CHECK: %[[VAL_8:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
// CHECK: %[[VAL_9:.*]] = tt.addptr %[[VAL_0]], %[[VAL_7]] : !tt.ptr<f32>, i32
// CHECK: %[[VAL_10:.*]] = arith.extsi %[[VAL_8]] : tensor<1024xi32> to tensor<1024xi64>
// CHECK: %[[VAL_11:.*]]:3 = scf.for %[[VAL_12:.*]] = %[[VAL_3]] to %[[VAL_4]] step %[[VAL_5]] iter_args(%[[VAL_13:.*]] = %[[VAL_9]], %[[VAL_14:.*]] = %[[VAL_10]], %[[VAL_15:.*]] = %[[VAL_1]]) -> (!tt.ptr<f32>, tensor<1024xi64>, tensor<1024xf32>) {
// CHECK: %[[VAL_16:.*]] = tt.addptr %[[VAL_13]], %[[VAL_7]] : !tt.ptr<f32>, i32
// CHECK: %[[VAL_17:.*]] = arith.extsi %[[VAL_8]] : tensor<1024xi32> to tensor<1024xi64>
// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_17]], %[[VAL_14]] : tensor<1024xi64>
// CHECK: %[[VAL_19:.*]] = tt.splat %[[VAL_16]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
// CHECK: %[[VAL_20:.*]] = tt.addptr %[[VAL_19]], %[[VAL_18]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi64>
// CHECK: %[[VAL_21:.*]] = tt.load %[[VAL_20]] : tensor<1024x!tt.ptr<f32>>
// CHECK: %[[VAL_22:.*]] = arith.addf %[[VAL_21]], %[[VAL_15]] : tensor<1024xf32>
// CHECK: scf.yield %[[VAL_16]], %[[VAL_18]], %[[VAL_22]] : !tt.ptr<f32>, tensor<1024xi64>, tensor<1024xf32>
// CHECK: }
// CHECK: %[[VAL_23:.*]] = tt.addptr %[[VAL_24:.*]]#0, %[[VAL_7]] : !tt.ptr<f32>, i32
// CHECK: %[[VAL_25:.*]] = arith.extsi %[[VAL_8]] : tensor<1024xi32> to tensor<1024xi64>
// CHECK: %[[VAL_26:.*]] = arith.addi %[[VAL_25]], %[[VAL_24]]#1 : tensor<1024xi64>
// CHECK: %[[VAL_27:.*]] = tt.splat %[[VAL_23]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
// CHECK: %[[VAL_28:.*]] = tt.addptr %[[VAL_27]], %[[VAL_26]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi64>
// CHECK: %[[VAL_29:.*]] = tt.load %[[VAL_28]] : tensor<1024x!tt.ptr<f32>>
// CHECK: tt.return %[[VAL_29]] : tensor<1024xf32>
// CHECK: }
// -----
// Like the previous loop test, but the loop-carried pointer tensor starts as
// the bare splat of %arg0 (no tt.addptr before the loop). The expectations
// that follow show the loop starting from %arg0 and an all-zero i64 offset.
module attributes {"ttg.num-warps" = 4 : i32} {
  tt.func @forOp2(%arg0: !tt.ptr<f32>, %arg1: tensor<1024xf32>) -> tensor<1024xf32> {
    %block_size = arith.constant 1024 : i32
    %lb = arith.constant 0 : index
    %ub = arith.constant 128 : index
    %step = arith.constant 1 : index
    %pid = tt.get_program_id x : i32
    %block_start = arith.muli %pid, %block_size : i32
    %lane = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
    %start_splat = tt.splat %block_start : i32 -> tensor<1024xi32>
    %offs = arith.addi %start_splat, %lane : tensor<1024xi32>
    %base_splat = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
    %loop:2 = scf.for %iv = %lb to %ub step %step iter_args(%cur_ptrs = %base_splat, %acc = %arg1) -> (tensor<1024x!tt.ptr<f32>>, tensor<1024xf32>) {
      %next_ptrs = tt.addptr %cur_ptrs, %offs : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
      %vals = tt.load %next_ptrs : tensor<1024x!tt.ptr<f32>>
      %sum = arith.addf %vals, %acc : tensor<1024xf32>
      scf.yield %next_ptrs, %sum : tensor<1024x!tt.ptr<f32>>, tensor<1024xf32>
    }
    // The pointer result of the loop is advanced again before the final load.
    %final_ptrs = tt.addptr %loop#0, %offs : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
    %result = tt.load %final_ptrs : tensor<1024x!tt.ptr<f32>>
    tt.return %result : tensor<1024xf32>
  }
}
// CHECK-LABEL: tt.func @forOp2(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1024xf32>) -> tensor<1024xf32> {
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<0> : tensor<1024xi64>
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1024 : i32
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 128 : index
// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_7:.*]] = tt.get_program_id x : i32
// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_7]], %[[VAL_3]] : i32
// CHECK: %[[VAL_9:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
// CHECK: %[[VAL_10:.*]]:3 = scf.for %[[VAL_11:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] iter_args(%[[VAL_12:.*]] = %[[VAL_0]], %[[VAL_13:.*]] = %[[VAL_2]], %[[VAL_14:.*]] = %[[VAL_1]]) -> (!tt.ptr<f32>, tensor<1024xi64>, tensor<1024xf32>) {
// CHECK: %[[VAL_15:.*]] = tt.addptr %[[VAL_12]], %[[VAL_8]] : !tt.ptr<f32>, i32
// CHECK: %[[VAL_16:.*]] = arith.extsi %[[VAL_9]] : tensor<1024xi32> to tensor<1024xi64>
// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_16]], %[[VAL_13]] : tensor<1024xi64>
// CHECK: %[[VAL_18:.*]] = tt.splat %[[VAL_15]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
// CHECK: %[[VAL_19:.*]] = tt.addptr %[[VAL_18]], %[[VAL_17]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi64>
// CHECK: %[[VAL_20:.*]] = tt.load %[[VAL_19]] : tensor<1024x!tt.ptr<f32>>
// CHECK: %[[VAL_21:.*]] = arith.addf %[[VAL_20]], %[[VAL_14]] : tensor<1024xf32>
// CHECK: scf.yield %[[VAL_15]], %[[VAL_17]], %[[VAL_21]] : !tt.ptr<f32>, tensor<1024xi64>, tensor<1024xf32>
// CHECK: }
// CHECK: %[[VAL_22:.*]] = tt.addptr %[[VAL_23:.*]]#0, %[[VAL_8]] : !tt.ptr<f32>, i32
// CHECK: %[[VAL_24:.*]] = arith.extsi %[[VAL_9]] : tensor<1024xi32> to tensor<1024xi64>
// CHECK: %[[VAL_25:.*]] = arith.addi %[[VAL_24]], %[[VAL_23]]#1 : tensor<1024xi64>
// CHECK: %[[VAL_26:.*]] = tt.splat %[[VAL_22]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
// CHECK: %[[VAL_27:.*]] = tt.addptr %[[VAL_26]], %[[VAL_25]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi64>
// CHECK: %[[VAL_28:.*]] = tt.load %[[VAL_27]] : tensor<1024x!tt.ptr<f32>>
// CHECK: tt.return %[[VAL_28]] : tensor<1024xf32>
// CHECK: }
// -----
// Two nested scf.for loops: the outer loop only forwards its pointer tensor
// and accumulator to the inner loop, which advances the pointer and loads.
// The expectations that follow show both loops carrying a scalar pointer, an
// i64 offset tensor and the accumulator.
module attributes {"ttg.num-warps" = 4 : i32} {
  tt.func @forNested(%arg0: !tt.ptr<f32>, %arg1: tensor<1024xf32>) -> tensor<1024xf32> {
    %block_size = arith.constant 1024 : i32
    %lb = arith.constant 0 : index
    %ub = arith.constant 128 : index
    %step = arith.constant 1 : index
    %pid = tt.get_program_id x : i32
    %block_start = arith.muli %pid, %block_size : i32
    %lane = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
    %start_splat = tt.splat %block_start : i32 -> tensor<1024xi32>
    %offs = arith.addi %start_splat, %lane : tensor<1024xi32>
    %base_splat = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
    %outer:2 = scf.for %i = %lb to %ub step %step iter_args(%outer_ptrs = %base_splat, %outer_acc = %arg1) -> (tensor<1024x!tt.ptr<f32>>, tensor<1024xf32>) {
      %inner:2 = scf.for %j = %lb to %ub step %step iter_args(%inner_ptrs = %outer_ptrs, %inner_acc = %outer_acc) -> (tensor<1024x!tt.ptr<f32>>, tensor<1024xf32>) {
        %next_ptrs = tt.addptr %inner_ptrs, %offs : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
        %vals = tt.load %next_ptrs : tensor<1024x!tt.ptr<f32>>
        %sum = arith.addf %vals, %inner_acc : tensor<1024xf32>
        scf.yield %next_ptrs, %sum : tensor<1024x!tt.ptr<f32>>, tensor<1024xf32>
      }
      // The outer loop yields the inner loop's results unchanged.
      scf.yield %inner#0, %inner#1 : tensor<1024x!tt.ptr<f32>>, tensor<1024xf32>
    }
    %final_ptrs = tt.addptr %outer#0, %offs : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
    %result = tt.load %final_ptrs : tensor<1024x!tt.ptr<f32>>
    tt.return %result : tensor<1024xf32>
  }
}
// CHECK-LABEL: tt.func @forNested(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1024xf32>) -> tensor<1024xf32> {
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<0> : tensor<1024xi64>
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1024 : i32
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 128 : index
// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_7:.*]] = tt.get_program_id x : i32
// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_7]], %[[VAL_3]] : i32
// CHECK: %[[VAL_9:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
// CHECK: %[[VAL_10:.*]]:3 = scf.for %[[VAL_11:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] iter_args(%[[VAL_12:.*]] = %[[VAL_0]], %[[VAL_13:.*]] = %[[VAL_2]], %[[VAL_14:.*]] = %[[VAL_1]]) -> (!tt.ptr<f32>, tensor<1024xi64>, tensor<1024xf32>) {
// CHECK: %[[VAL_15:.*]]:3 = scf.for %[[VAL_16:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] iter_args(%[[VAL_17:.*]] = %[[VAL_12]], %[[VAL_18:.*]] = %[[VAL_13]], %[[VAL_19:.*]] = %[[VAL_14]]) -> (!tt.ptr<f32>, tensor<1024xi64>, tensor<1024xf32>) {
// CHECK: %[[VAL_20:.*]] = tt.addptr %[[VAL_17]], %[[VAL_8]] : !tt.ptr<f32>, i32
// CHECK: %[[VAL_21:.*]] = arith.extsi %[[VAL_9]] : tensor<1024xi32> to tensor<1024xi64>
// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_21]], %[[VAL_18]] : tensor<1024xi64>
// CHECK: %[[VAL_23:.*]] = tt.splat %[[VAL_20]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
// CHECK: %[[VAL_24:.*]] = tt.addptr %[[VAL_23]], %[[VAL_22]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi64>
// CHECK: %[[VAL_25:.*]] = tt.load %[[VAL_24]] : tensor<1024x!tt.ptr<f32>>
// CHECK: %[[VAL_26:.*]] = arith.addf %[[VAL_25]], %[[VAL_19]] : tensor<1024xf32>
// CHECK: scf.yield %[[VAL_20]], %[[VAL_22]], %[[VAL_26]] : !tt.ptr<f32>, tensor<1024xi64>, tensor<1024xf32>
// CHECK: }
// CHECK: scf.yield %[[VAL_27:.*]]#0, %[[VAL_27]]#1, %[[VAL_27]]#2 : !tt.ptr<f32>, tensor<1024xi64>, tensor<1024xf32>
// CHECK: }
// CHECK: %[[VAL_28:.*]] = tt.addptr %[[VAL_29:.*]]#0, %[[VAL_8]] : !tt.ptr<f32>, i32
// CHECK: %[[VAL_30:.*]] = arith.extsi %[[VAL_9]] : tensor<1024xi32> to tensor<1024xi64>
// CHECK: %[[VAL_31:.*]] = arith.addi %[[VAL_30]], %[[VAL_29]]#1 : tensor<1024xi64>
// CHECK: %[[VAL_32:.*]] = tt.splat %[[VAL_28]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
// CHECK: %[[VAL_33:.*]] = tt.addptr %[[VAL_32]], %[[VAL_31]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi64>
// CHECK: %[[VAL_34:.*]] = tt.load %[[VAL_33]] : tensor<1024x!tt.ptr<f32>>
// CHECK: tt.return %[[VAL_34]] : tensor<1024xf32>
// CHECK: }
// -----
// scf.if that yields a pointer tensor from both branches: the "then" branch
// adds (splat + range), the "else" branch adds only the splat. The
// expectations that follow show the scf.if yielding a scalar pointer and an
// i64 offset tensor, the latter being all zeros on the "else" side.
module attributes {"ttg.num-warps" = 4 : i32} {
  tt.func @ifOp(%arg0: !tt.ptr<f32>, %arg1: tensor<1024xf32>, %arg2: i1) -> tensor<1024xf32> {
    %block_size = arith.constant 1024 : i32
    %pid = tt.get_program_id x : i32
    %block_start = arith.muli %pid, %block_size : i32
    %lane = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
    %start_splat = tt.splat %block_start : i32 -> tensor<1024xi32>
    %offs = arith.addi %start_splat, %lane : tensor<1024xi32>
    %base_splat = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
    %selected_ptrs = scf.if %arg2 -> (tensor<1024x!tt.ptr<f32>>) {
      %then_ptrs = tt.addptr %base_splat, %offs : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
      scf.yield %then_ptrs : tensor<1024x!tt.ptr<f32>>
    } else {
      // Only the uniform part of the offset is applied here.
      %else_ptrs = tt.addptr %base_splat, %start_splat : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
      scf.yield %else_ptrs : tensor<1024x!tt.ptr<f32>>
    }
    %vals = tt.load %selected_ptrs : tensor<1024x!tt.ptr<f32>>
    tt.return %vals : tensor<1024xf32>
  }
}
// CHECK-LABEL: tt.func @ifOp(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1024xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: i1) -> tensor<1024xf32> {
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant dense<0> : tensor<1024xi64>
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1024 : i32
// CHECK: %[[VAL_5:.*]] = tt.get_program_id x : i32
// CHECK: %[[VAL_6:.*]] = arith.muli %[[VAL_5]], %[[VAL_4]] : i32
// CHECK: %[[VAL_7:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
// CHECK: %[[VAL_8:.*]]:2 = scf.if %[[VAL_2]] -> (!tt.ptr<f32>, tensor<1024xi64>) {
// CHECK: %[[VAL_9:.*]] = tt.addptr %[[VAL_0]], %[[VAL_6]] : !tt.ptr<f32>, i32
// CHECK: %[[VAL_10:.*]] = arith.extsi %[[VAL_7]] : tensor<1024xi32> to tensor<1024xi64>
// CHECK: scf.yield %[[VAL_9]], %[[VAL_10]] : !tt.ptr<f32>, tensor<1024xi64>
// CHECK: } else {
// CHECK: %[[VAL_11:.*]] = tt.addptr %[[VAL_0]], %[[VAL_6]] : !tt.ptr<f32>, i32
// CHECK: scf.yield %[[VAL_11]], %[[VAL_3]] : !tt.ptr<f32>, tensor<1024xi64>
// CHECK: }
// CHECK: %[[VAL_12:.*]] = arith.trunci %[[VAL_13:.*]]#1 : tensor<1024xi64> to tensor<1024xi32>
// CHECK: %[[VAL_14:.*]] = tt.splat %[[VAL_13]]#0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
// CHECK: %[[VAL_15:.*]] = tt.addptr %[[VAL_14]], %[[VAL_12]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
// CHECK: %[[VAL_16:.*]] = tt.load %[[VAL_15]] : tensor<1024x!tt.ptr<f32>>
// CHECK: tt.return %[[VAL_16]] : tensor<1024xf32>
// CHECK: }
// -----
// scf.while carrying a pointer tensor and an i32 range; the body advances the
// pointer by the range. The loop condition comes from an unregistered op.
// The expectations that follow show the loop carrying a single i64 offset
// tensor, with the splat of %arg0 rebuilt after the loop.
module attributes {"ttg.num-warps" = 4 : i32} {
  tt.func @whileOp(%arg0: !tt.ptr<f32>, %arg1: tensor<1024xf32>) -> tensor<1024xf32> {
    // The constants and %offs below are not referenced by the loop.
    %block_size = arith.constant 1024 : i32
    %lb = arith.constant 0 : index
    %ub = arith.constant 128 : index
    %step = arith.constant 1 : index
    %pid = tt.get_program_id x : i32
    %block_start = arith.muli %pid, %block_size : i32
    %lane = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
    %start_splat = tt.splat %block_start : i32 -> tensor<1024xi32>
    %offs = arith.addi %start_splat, %lane : tensor<1024xi32>
    %base_splat = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
    %loop:2 = scf.while (%before_ptrs = %base_splat, %before_offs = %lane) : (tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>) -> (tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>) {
      %keep_going = "dummy.evaluate_condition"() : () -> i1
      scf.condition(%keep_going) %before_ptrs, %before_offs : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
    } do {
    ^bb0(%body_ptrs: tensor<1024x!tt.ptr<f32>>, %body_offs: tensor<1024xi32>):
      %advanced_ptrs = tt.addptr %body_ptrs, %body_offs : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
      scf.yield %advanced_ptrs, %body_offs : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
    }
    %vals = tt.load %loop#0 : tensor<1024x!tt.ptr<f32>>
    tt.return %vals : tensor<1024xf32>
  }
}
// CHECK-LABEL: tt.func @whileOp(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1024xf32>) -> tensor<1024xf32> {
// CHECK: %[[VAL_2:.*]] = arith.constant dense<0> : tensor<1024xi64>
// CHECK: %[[VAL_3:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
// CHECK: %[[VAL_4:.*]] = scf.while (%[[VAL_5:.*]] = %[[VAL_2]]) : (tensor<1024xi64>) -> tensor<1024xi64> {
// CHECK: %[[VAL_6:.*]] = "dummy.evaluate_condition"() : () -> i1
// CHECK: scf.condition(%[[VAL_6]]) %[[VAL_5]] : tensor<1024xi64>
// CHECK: } do {
// CHECK: ^bb0(%[[VAL_7:.*]]: tensor<1024xi64>):
// CHECK: %[[VAL_8:.*]] = arith.extsi %[[VAL_3]] : tensor<1024xi32> to tensor<1024xi64>
// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_7]] : tensor<1024xi64>
// CHECK: scf.yield %[[VAL_9]] : tensor<1024xi64>
// CHECK: }
// CHECK: %[[VAL_10:.*]] = arith.trunci %[[VAL_4]] : tensor<1024xi64> to tensor<1024xi32>
// CHECK: %[[VAL_11:.*]] = tt.splat %[[VAL_0]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
// CHECK: %[[VAL_12:.*]] = tt.addptr %[[VAL_11]], %[[VAL_10]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
// CHECK: %[[VAL_13:.*]] = tt.load %[[VAL_12]] : tensor<1024x!tt.ptr<f32>>
// CHECK: tt.return %[[VAL_13]] : tensor<1024xf32>
// CHECK: }
// -----
// cf.cond_br forwarding a pointer tensor to two successor blocks: ^bb1 gets
// the bare splat of %arg0, ^bb2 gets the splat advanced by (splat + range).
// The expectations that follow show each successor receiving a scalar pointer
// and an i64 offset tensor instead.
module attributes {"ttg.num-warps" = 4 : i32} {
  tt.func @condBranch(%arg0: !tt.ptr<f32>, %arg1: i1) -> tensor<1024xf32> {
    %block_size = arith.constant 1024 : i32
    %pid = tt.get_program_id x : i32
    %block_start = arith.muli %pid, %block_size : i32
    %lane = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
    %start_splat = tt.splat %block_start : i32 -> tensor<1024xi32>
    %offs = arith.addi %start_splat, %lane : tensor<1024xi32>
    %base_splat = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
    %advanced_ptrs = tt.addptr %base_splat, %offs : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
    cf.cond_br %arg1, ^bb1(%base_splat : tensor<1024x!tt.ptr<f32>>), ^bb2(%advanced_ptrs : tensor<1024x!tt.ptr<f32>>)
  ^bb1(%plain_ptrs: tensor<1024x!tt.ptr<f32>>):  // pred: ^bb0
    %plain_vals = tt.load %plain_ptrs : tensor<1024x!tt.ptr<f32>>
    tt.return %plain_vals : tensor<1024xf32>
  ^bb2(%shifted_ptrs: tensor<1024x!tt.ptr<f32>>):  // pred: ^bb0
    %shifted_vals = tt.load %shifted_ptrs : tensor<1024x!tt.ptr<f32>>
    tt.return %shifted_vals : tensor<1024xf32>
  }
}
// CHECK-LABEL: tt.func @condBranch(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>,
// CHECK-SAME: %[[VAL_1:.*]]: i1) -> tensor<1024xf32> {
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<0> : tensor<1024xi64>
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1024 : i32
// CHECK: %[[VAL_4:.*]] = tt.get_program_id x : i32
// CHECK: %[[VAL_5:.*]] = arith.muli %[[VAL_4]], %[[VAL_3]] : i32
// CHECK: %[[VAL_6:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
// CHECK: %[[VAL_7:.*]] = tt.addptr %[[VAL_0]], %[[VAL_5]] : !tt.ptr<f32>, i32
// CHECK: %[[VAL_8:.*]] = arith.extsi %[[VAL_6]] : tensor<1024xi32> to tensor<1024xi64>
// CHECK: cf.cond_br %[[VAL_1]], ^bb1(%[[VAL_0]], %[[VAL_2]] : !tt.ptr<f32>, tensor<1024xi64>), ^bb2(%[[VAL_7]], %[[VAL_8]] : !tt.ptr<f32>, tensor<1024xi64>)
// CHECK: ^bb1(%[[VAL_9:.*]]: !tt.ptr<f32>, %[[VAL_10:.*]]: tensor<1024xi64>):
// CHECK: %[[VAL_11:.*]] = arith.trunci %[[VAL_10]] : tensor<1024xi64> to tensor<1024xi32>
// CHECK: %[[VAL_12:.*]] = tt.splat %[[VAL_9]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
// CHECK: %[[VAL_13:.*]] = tt.addptr %[[VAL_12]], %[[VAL_11]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
// CHECK: %[[VAL_14:.*]] = tt.load %[[VAL_13]] : tensor<1024x!tt.ptr<f32>>
// CHECK: tt.return %[[VAL_14]] : tensor<1024xf32>
// CHECK: ^bb2(%[[VAL_15:.*]]: !tt.ptr<f32>, %[[VAL_16:.*]]: tensor<1024xi64>):
// CHECK: %[[VAL_17:.*]] = arith.trunci %[[VAL_16]] : tensor<1024xi64> to tensor<1024xi32>
// CHECK: %[[VAL_18:.*]] = tt.splat %[[VAL_15]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
// CHECK: %[[VAL_19:.*]] = tt.addptr %[[VAL_18]], %[[VAL_17]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
// CHECK: %[[VAL_20:.*]] = tt.load %[[VAL_19]] : tensor<1024x!tt.ptr<f32>>
// CHECK: tt.return %[[VAL_20]] : tensor<1024xf32>
// CHECK: }
// -----
// REWRITE branch gets DCEd
// Unconditional cf.br forwarding an advanced pointer tensor to ^bb1, where it
// is loaded. %arg1 is not used by the body. The expectations that follow
// contain no cf.br: the load appears in the same block as the tt.addptr ops.
module attributes {"ttg.num-warps" = 4 : i32} {
  tt.func @branch(%arg0: !tt.ptr<f32>, %arg1: i1) -> tensor<1024xf32> {
    %block_size = arith.constant 1024 : i32
    %pid = tt.get_program_id x : i32
    %block_start = arith.muli %pid, %block_size : i32
    %lane = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
    %start_splat = tt.splat %block_start : i32 -> tensor<1024xi32>
    %offs = arith.addi %start_splat, %lane : tensor<1024xi32>
    %base_splat = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
    %advanced_ptrs = tt.addptr %base_splat, %offs : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
    cf.br ^bb1(%advanced_ptrs : tensor<1024x!tt.ptr<f32>>)
  ^bb1(%incoming_ptrs: tensor<1024x!tt.ptr<f32>>):  // pred: ^bb0
    %vals = tt.load %incoming_ptrs : tensor<1024x!tt.ptr<f32>>
    tt.return %vals : tensor<1024xf32>
  }
}
// CHECK-LABEL: tt.func @branch(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>,
// CHECK-SAME: %[[VAL_1:.*]]: i1) -> tensor<1024xf32> {
// CHECK: %[[VAL_2:.*]] = arith.constant 1024 : i32
// CHECK: %[[VAL_3:.*]] = tt.get_program_id x : i32
// CHECK: %[[VAL_4:.*]] = arith.muli %[[VAL_3]], %[[VAL_2]] : i32
// CHECK: %[[VAL_5:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
// CHECK: %[[VAL_6:.*]] = tt.addptr %[[VAL_0]], %[[VAL_4]] : !tt.ptr<f32>, i32
// CHECK: %[[VAL_7:.*]] = tt.splat %[[VAL_6]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
// CHECK: %[[VAL_8:.*]] = tt.addptr %[[VAL_7]], %[[VAL_5]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
// CHECK: %[[VAL_9:.*]] = tt.load %[[VAL_8]] : tensor<1024x!tt.ptr<f32>>
// CHECK: tt.return %[[VAL_9]] : tensor<1024xf32>
// CHECK: }
// -----
// The following is a simple case of a tile offset like: (A*B + C + D) where B,C are Uniform and A,D are not. So
// we expect that the Uniform offset (which can be added to the scalar pointer) will be simply C and the NonUniform
// offset will be A*B+D
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
// 2-D offset built as broadcast(range16 * splat(%arg2)) +
// broadcast(splat(program id * 256) + range256). The expectations that follow
// show the scalar (program id * 256) added to the scalar pointer, while the
// tensor offset excludes it.
module attributes {"ttg.num-warps" = 4 : i32} {
  tt.func @tile_offset(%arg0: !tt.ptr<f16>, %arg1: i32, %arg2: i32) -> tensor<16x256xf16, #blocked> {
    %c256 = arith.constant 256 : i32
    %pid = tt.get_program_id x : i32
    %block_start = arith.muli %pid, %c256 : i32
    %range256 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
    %start_splat = tt.splat %block_start : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
    // Uniform splat + non-uniform range along the 256-wide dimension.
    %shifted256 = arith.addi %start_splat, %range256 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
    %range16 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
    %range16_col = tt.expand_dims %range16 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi32, #blocked>
    %arg2_splat = tt.splat %arg2 : i32 -> tensor<16x1xi32, #blocked>
    // Non-uniform range scaled by a uniform scalar.
    %scaled16 = arith.muli %range16_col, %arg2_splat : tensor<16x1xi32, #blocked>
    %shifted256_row = tt.expand_dims %shifted256 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked>
    %scaled16_tile = tt.broadcast %scaled16 : tensor<16x1xi32, #blocked> -> tensor<16x256xi32, #blocked>
    %shifted256_tile = tt.broadcast %shifted256_row : tensor<1x256xi32, #blocked> -> tensor<16x256xi32, #blocked>
    %tile_offs = arith.addi %scaled16_tile, %shifted256_tile : tensor<16x256xi32, #blocked>
    %base_splat = tt.splat %arg0 : !tt.ptr<f16> -> tensor<16x256x!tt.ptr<f16>, #blocked>
    %tile_ptrs = tt.addptr %base_splat, %tile_offs : tensor<16x256x!tt.ptr<f16>, #blocked>, tensor<16x256xi32, #blocked>
    %tile = tt.load %tile_ptrs : tensor<16x256x!tt.ptr<f16>, #blocked>
    tt.return %tile : tensor<16x256xf16, #blocked>
  }
}
// CHECK: #[[$ATTR_0:.+]] = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
// CHECK-LABEL: tt.func @tile_offset(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f16>,
// CHECK-SAME: %[[VAL_1:.*]]: i32,
// CHECK-SAME: %[[VAL_2:.*]]: i32) -> tensor<16x256xf16, #[[$ATTR_0]]> {
// CHECK: %[[VAL_3:.*]] = arith.constant 256 : i32
// CHECK: %[[VAL_4:.*]] = tt.get_program_id x : i32
// CHECK: %[[VAL_5:.*]] = arith.muli %[[VAL_4]], %[[VAL_3]] : i32
// CHECK: %[[VAL_6:.*]] = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #[[$ATTR_0]]}>>
// CHECK: %[[VAL_7:.*]] = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #[[$ATTR_0]]}>>
// CHECK: %[[VAL_8:.*]] = tt.expand_dims %[[VAL_7]] {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #[[$ATTR_0]]}>> -> tensor<16x1xi32, #[[$ATTR_0]]>
// CHECK: %[[VAL_9:.*]] = tt.splat %[[VAL_2]] : i32 -> tensor<16x1xi32, #[[$ATTR_0]]>
// CHECK: %[[VAL_10:.*]] = arith.muli %[[VAL_8]], %[[VAL_9]] : tensor<16x1xi32, #[[$ATTR_0]]>
// CHECK: %[[VAL_11:.*]] = tt.broadcast %[[VAL_10]] : tensor<16x1xi32, #[[$ATTR_0]]> -> tensor<16x256xi32, #[[$ATTR_0]]>
// CHECK: %[[VAL_12:.*]] = tt.expand_dims %[[VAL_6]] {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #[[$ATTR_0]]}>> -> tensor<1x256xi32, #[[$ATTR_0]]>
// CHECK: %[[VAL_13:.*]] = tt.broadcast %[[VAL_12]] : tensor<1x256xi32, #[[$ATTR_0]]> -> tensor<16x256xi32, #[[$ATTR_0]]>
// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_11]], %[[VAL_13]] : tensor<16x256xi32, #[[$ATTR_0]]>
// CHECK: %[[VAL_15:.*]] = tt.addptr %[[VAL_0]], %[[VAL_5]] : !tt.ptr<f16>, i32
// CHECK: %[[VAL_16:.*]] = tt.splat %[[VAL_15]] : !tt.ptr<f16> -> tensor<16x256x!tt.ptr<f16>, #[[$ATTR_0]]>
// CHECK: %[[VAL_17:.*]] = tt.addptr %[[VAL_16]], %[[VAL_14]] : tensor<16x256x!tt.ptr<f16>, #[[$ATTR_0]]>, tensor<16x256xi32, #[[$ATTR_0]]>
// CHECK: %[[VAL_18:.*]] = tt.load %[[VAL_17]] : tensor<16x256x!tt.ptr<f16>, #[[$ATTR_0]]>
// CHECK: tt.return %[[VAL_18]] : tensor<16x256xf16, #[[$ATTR_0]]>
// CHECK: }
// -----
// The following is a more complex case in which a multiplication is also involved. It is useful to walk through it.
// The offset added to the pointer is:
// %12 = %10 + %11
// which can be transformed into:
// = %7 + %9
// = %5 * %6 + %8
// = %4 * %arg1 + %8
// = (%3 + %2) * %arg1 + %8
// = (%1 + %2) * %arg1 + %8
// = (U + N) * U + N
// where U means uniform (e.g., a splat) and N means non-uniform (e.g., a make_range).
// The scalar offset we want is (%1 * %arg1), while the variable offset should be (%2 * %arg1 + %8).
// Layout alias used by every ranked tensor type in this split.
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
module attributes {"ttg.num-warps" = 4 : i32} {
// Test input: a 128x16 load whose offset is (splat(%1) + range(0, 128)) * splat(%arg1) + range(0, 16).
// The assertions that follow this module match a scalar `arith.muli` of %1 by %arg1 that is
// added to %arg0 before the splat, while the make_range-based terms stay in the tensor offset.
tt.func public @matmul_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}) -> tensor<128x16xf16, #blocked> {
%c128_i32 = arith.constant 128 : i32
%0 = tt.get_program_id x : i32
// %1: scalar part of the row index (program id times 128).
%1 = arith.muli %0, %c128_i32 : i32
// %2: per-element part of the row index.
%2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
%3 = tt.splat %1 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
// %4 mixes a splat operand with a make_range operand.
%4 = arith.addi %3, %2 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
%5 = tt.expand_dims %4 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked>
// %6/%7: the row index is multiplied by a splat of %arg1.
%6 = tt.splat %arg1 : i32 -> tensor<128x1xi32, #blocked>
%7 = arith.muli %5, %6 : tensor<128x1xi32, #blocked>
// %8/%9: column index, built only from a make_range.
%8 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
%9 = tt.expand_dims %8 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked>
%10 = tt.broadcast %7 : tensor<128x1xi32, #blocked> -> tensor<128x16xi32, #blocked>
%11 = tt.broadcast %9 : tensor<1x16xi32, #blocked> -> tensor<128x16xi32, #blocked>
// %12: the full 2-D offset that is added to the splat base pointer.
%12 = arith.addi %10, %11 : tensor<128x16xi32, #blocked>
%13 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x16x!tt.ptr<f16>, #blocked>
%14 = tt.addptr %13, %12 : tensor<128x16x!tt.ptr<f16>, #blocked>, tensor<128x16xi32, #blocked>
%15 = tt.load %14 : tensor<128x16x!tt.ptr<f16>, #blocked>
tt.return %15 : tensor<128x16xf16, #blocked>
}
}
// CHECK: #[[$ATTR_1:.+]] = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
// CHECK-LABEL: tt.func public @matmul_kernel(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f16> {tt.divisibility = 16 : i32},
// CHECK-SAME: %[[VAL_1:.*]]: i32 {tt.divisibility = 16 : i32}) -> tensor<128x16xf16, #[[$ATTR_1]]> {
// CHECK: %[[VAL_2:.*]] = arith.constant 128 : i32
// CHECK: %[[VAL_3:.*]] = tt.get_program_id x : i32
// CHECK: %[[VAL_4:.*]] = arith.muli %[[VAL_3]], %[[VAL_2]] : i32
// CHECK: %[[VAL_5:.*]] = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #[[$ATTR_1]]}>>
// CHECK: %[[VAL_6:.*]] = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #[[$ATTR_1]]}>>
// CHECK: %[[VAL_7:.*]] = tt.expand_dims %[[VAL_5]] {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<128x1xi32, #[[$ATTR_1]]>
// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_4]], %[[VAL_1]] : i32
// CHECK: %[[VAL_9:.*]] = tt.splat %[[VAL_1]] : i32 -> tensor<128x1xi32, #[[$ATTR_1]]>
// CHECK: %[[VAL_10:.*]] = arith.muli %[[VAL_7]], %[[VAL_9]] : tensor<128x1xi32, #[[$ATTR_1]]>
// CHECK: %[[VAL_11:.*]] = tt.broadcast %[[VAL_10]] : tensor<128x1xi32, #[[$ATTR_1]]> -> tensor<128x16xi32, #[[$ATTR_1]]>
// CHECK: %[[VAL_12:.*]] = tt.expand_dims %[[VAL_6]] {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #[[$ATTR_1]]}>> -> tensor<1x16xi32, #[[$ATTR_1]]>
// CHECK: %[[VAL_13:.*]] = tt.broadcast %[[VAL_12]] : tensor<1x16xi32, #[[$ATTR_1]]> -> tensor<128x16xi32, #[[$ATTR_1]]>
// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_11]], %[[VAL_13]] : tensor<128x16xi32, #[[$ATTR_1]]>
// CHECK: %[[VAL_15:.*]] = tt.addptr %[[VAL_0]], %[[VAL_8]] : !tt.ptr<f16>, i32
// CHECK: %[[VAL_16:.*]] = tt.splat %[[VAL_15]] : !tt.ptr<f16> -> tensor<128x16x!tt.ptr<f16>, #[[$ATTR_1]]>
// CHECK: %[[VAL_17:.*]] = tt.addptr %[[VAL_16]], %[[VAL_14]] : tensor<128x16x!tt.ptr<f16>, #[[$ATTR_1]]>, tensor<128x16xi32, #[[$ATTR_1]]>
// CHECK: %[[VAL_18:.*]] = tt.load %[[VAL_17]] : tensor<128x16x!tt.ptr<f16>, #[[$ATTR_1]]>
// CHECK: tt.return %[[VAL_18]] : tensor<128x16xf16, #[[$ATTR_1]]>
// CHECK: }
// -----
// Test input: an `arith.select` whose two operands are tensors of pointers.
// The assertions that follow this module match two selects in its place: one on the scalar
// pointers (!tt.ptr<f32>) and one on i64 offset tensors, with a dense<0> constant standing in
// as the offset on the %5 side.
module attributes {"ttg.num-warps" = 4 : i32} {
tt.func @select(%arg0: !tt.ptr<f32>, %arg1: i1) -> tensor<1024xf32> {
%c1024_i32 = arith.constant 1024 : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c1024_i32 : i32
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
%3 = tt.splat %1 : i32 -> tensor<1024xi32>
%4 = arith.addi %3, %2 : tensor<1024xi32>
// %5 is the splat base pointer with no offset applied; %6 is the same base plus %4.
%5 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
%6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
// A single scalar i1 condition chooses between the two pointer tensors.
%7 = arith.select %arg1, %5, %6 : tensor<1024x!tt.ptr<f32>>
%8 = tt.load %7 : tensor<1024x!tt.ptr<f32>>
tt.return %8 : tensor<1024xf32>
}
}
// CHECK-LABEL: tt.func @select(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>,
// CHECK-SAME: %[[VAL_1:.*]]: i1) -> tensor<1024xf32> {
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<0> : tensor<1024xi64>
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1024 : i32
// CHECK: %[[VAL_4:.*]] = tt.get_program_id x : i32
// CHECK: %[[VAL_5:.*]] = arith.muli %[[VAL_4]], %[[VAL_3]] : i32
// CHECK: %[[VAL_6:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
// CHECK: %[[VAL_7:.*]] = tt.addptr %[[VAL_0]], %[[VAL_5]] : !tt.ptr<f32>, i32
// CHECK: %[[VAL_8:.*]] = arith.extsi %[[VAL_6]] : tensor<1024xi32> to tensor<1024xi64>
// CHECK: %[[VAL_9:.*]] = arith.select %[[VAL_1]], %[[VAL_0]], %[[VAL_7]] : !tt.ptr<f32>
// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_1]], %[[VAL_2]], %[[VAL_8]] : tensor<1024xi64>
// CHECK: %[[VAL_11:.*]] = arith.trunci %[[VAL_10]] : tensor<1024xi64> to tensor<1024xi32>
// CHECK: %[[VAL_12:.*]] = tt.splat %[[VAL_9]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
// CHECK: %[[VAL_13:.*]] = tt.addptr %[[VAL_12]], %[[VAL_11]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
// CHECK: %[[VAL_14:.*]] = tt.load %[[VAL_13]] : tensor<1024x!tt.ptr<f32>>
// CHECK: tt.return %[[VAL_14]] : tensor<1024xf32>
// CHECK: }
// -----
// Test input: the base pointer itself is the result of a scalar `arith.select` between two
// pointer arguments, and only afterwards is it splat and offset.
// The assertions that follow this module match the scalar select, then a scalar `tt.addptr`
// with %1, then the splat, then a tensor `tt.addptr` with the make_range.
module attributes {"ttg.num-ctas" = 1 : i32} {
tt.func @where_kernel(%arg0: !tt.ptr<i64>, %arg1: !tt.ptr<i64>, %cst: i8) -> tensor<1024xi64> {
%c0_i8 = arith.constant 0 : i8
%c1024_i32 = arith.constant 1024 : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c1024_i32 : i32
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
%3 = tt.splat %1 : i32 -> tensor<1024xi32>
%4 = arith.addi %3, %2 : tensor<1024xi32>
// The constant is the left-hand operand here; the assertions match it on the right-hand side.
%5 = arith.cmpi ne, %c0_i8, %cst : i8
// Scalar select between the two base pointers.
%6 = arith.select %5, %arg0, %arg1 : !tt.ptr<i64>
%7 = tt.splat %6 : !tt.ptr<i64> -> tensor<1024x!tt.ptr<i64>>
%8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<i64>>, tensor<1024xi32>
%9 = tt.load %8 : tensor<1024x!tt.ptr<i64>>
tt.return %9 : tensor<1024xi64>
}
}
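// NOTE: the function signature is matched on a single line here, and in the other places where the check-same lines were removed, because FileCheck did not accept the check-same form there; the reason is not known.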
// CHECK: tt.func @where_kernel(%[[VAL_0:.*]]: !tt.ptr<i64>, %[[VAL_1:.*]]: !tt.ptr<i64>, %[[VAL_3:.*]]: i8) -> tensor<1024xi64> {
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : i8
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1024 : i32
// CHECK: %[[VAL_6:.*]] = tt.get_program_id x : i32
// CHECK: %[[VAL_7:.*]] = arith.muli %[[VAL_6]], %[[VAL_5]] : i32
// CHECK: %[[VAL_8:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
// CHECK: %[[VAL_9:.*]] = arith.cmpi ne, %[[VAL_3]], %[[VAL_4]] : i8
// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_9]], %[[VAL_0]], %[[VAL_1]] : !tt.ptr<i64>
// CHECK: %[[VAL_11:.*]] = tt.addptr %[[VAL_10]], %[[VAL_7]] : !tt.ptr<i64>, i32
// CHECK: %[[VAL_12:.*]] = tt.splat %[[VAL_11]] : !tt.ptr<i64> -> tensor<1024x!tt.ptr<i64>>
// CHECK: %[[VAL_13:.*]] = tt.addptr %[[VAL_12]], %[[VAL_8]] : tensor<1024x!tt.ptr<i64>>, tensor<1024xi32>
// CHECK: %[[VAL_14:.*]] = tt.load %[[VAL_13]] : tensor<1024x!tt.ptr<i64>>
// CHECK: tt.return %[[VAL_14]] : tensor<1024xi64>
// CHECK: }
// -----
// Test input: an `scf.for` that carries a tensor of pointers as an iter_arg and has a
// `tt.divisibility_arg1` attribute attached to the loop.
// The assertions that follow this module match a loop with three iter_args (scalar pointer,
// i64 offset tensor, f32 tensor) whose attribute dictionary holds both `tt.divisibility_arg1`
// and `tt.divisibility_arg2`.
module attributes {"ttg.num-warps" = 4 : i32} {
tt.func @forOpWithHints(%arg0: !tt.ptr<f32>, %arg1: tensor<1024xf32>) -> tensor<1024xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c128 = arith.constant 128 : index
%0 = tt.get_program_id x : i32
%1 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
%2 = tt.splat %0 : i32 -> tensor<1024xi32>
// %3 = splat(program id) + make_range, i.e. a splat operand plus a make_range operand.
%3 = arith.addi %2, %1 : tensor<1024xi32>
%4 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
%5 = tt.addptr %4, %3 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
%6:2 = scf.for %arg2 = %c0 to %c128 step %c1 iter_args(%arg3 = %5, %arg4 = %arg1) -> (tensor<1024x!tt.ptr<f32>>, tensor<1024xf32>) {
%9 = tt.load %arg3 : tensor<1024x!tt.ptr<f32>>
// Two pointer increments per iteration: first by %3, then by the pure splat %2.
%10 = tt.addptr %arg3, %3 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
%11 = tt.addptr %10, %2 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
%12 = arith.addf %9, %arg4 : tensor<1024xf32>
scf.yield %11, %12 : tensor<1024x!tt.ptr<f32>>, tensor<1024xf32>
} {tt.divisibility_arg1 = dense<16> : tensor<1xi32>}
// The loop-carried pointer result is offset once more after the loop and then loaded.
%7 = tt.addptr %6#0, %3 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
%8 = tt.load %7 : tensor<1024x!tt.ptr<f32>>
tt.return %8 : tensor<1024xf32>
}
}
// CHECK-LABEL: tt.func @forOpWithHints(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1024xf32>) -> tensor<1024xf32> {
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 128 : index
// CHECK: %[[VAL_5:.*]] = tt.get_program_id x : i32
// CHECK: %[[VAL_6:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
// CHECK: %[[VAL_7:.*]] = tt.addptr %[[VAL_0]], %[[VAL_5]] : !tt.ptr<f32>, i32
// CHECK: %[[VAL_8:.*]] = arith.extsi %[[VAL_6]] : tensor<1024xi32> to tensor<1024xi64>
// CHECK: %[[VAL_9:.*]]:3 = scf.for %[[VAL_10:.*]] = %[[VAL_2]] to %[[VAL_4]] step %[[VAL_3]] iter_args(%[[VAL_11:.*]] = %[[VAL_7]], %[[VAL_12:.*]] = %[[VAL_8]], %[[VAL_13:.*]] = %[[VAL_1]]) -> (!tt.ptr<f32>, tensor<1024xi64>, tensor<1024xf32>) {
// CHECK: %[[VAL_14:.*]] = arith.trunci %[[VAL_12]] : tensor<1024xi64> to tensor<1024xi32>
// CHECK: %[[VAL_15:.*]] = tt.splat %[[VAL_11]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
// CHECK: %[[VAL_16:.*]] = tt.addptr %[[VAL_15]], %[[VAL_14]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
// CHECK: %[[VAL_17:.*]] = tt.load %[[VAL_16]] : tensor<1024x!tt.ptr<f32>>
// CHECK: %[[VAL_18:.*]] = tt.addptr %[[VAL_11]], %[[VAL_5]] : !tt.ptr<f32>, i32
// CHECK: %[[VAL_19:.*]] = arith.extsi %[[VAL_6]] : tensor<1024xi32> to tensor<1024xi64>
// CHECK: %[[VAL_20:.*]] = arith.addi %[[VAL_19]], %[[VAL_12]] : tensor<1024xi64>
// CHECK: %[[VAL_21:.*]] = tt.addptr %[[VAL_18]], %[[VAL_5]] : !tt.ptr<f32>, i32
// CHECK: %[[VAL_22:.*]] = arith.addf %[[VAL_17]], %[[VAL_13]] : tensor<1024xf32>
// CHECK: scf.yield %[[VAL_21]], %[[VAL_20]], %[[VAL_22]] : !tt.ptr<f32>, tensor<1024xi64>, tensor<1024xf32>
// CHECK: } {tt.divisibility_arg1 = dense<16> : tensor<1xi32>, tt.divisibility_arg2 = dense<16> : tensor<1xi32>}
// CHECK: %[[VAL_23:.*]] = tt.addptr %[[VAL_24:.*]]#0, %[[VAL_5]] : !tt.ptr<f32>, i32
// CHECK: %[[VAL_25:.*]] = arith.extsi %[[VAL_6]] : tensor<1024xi32> to tensor<1024xi64>
// CHECK: %[[VAL_26:.*]] = arith.addi %[[VAL_25]], %[[VAL_24]]#1 : tensor<1024xi64>
// CHECK: %[[VAL_27:.*]] = tt.splat %[[VAL_23]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
// CHECK: %[[VAL_28:.*]] = tt.addptr %[[VAL_27]], %[[VAL_26]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi64>
// CHECK: %[[VAL_29:.*]] = tt.load %[[VAL_28]] : tensor<1024x!tt.ptr<f32>>
// CHECK: tt.return %[[VAL_29]] : tensor<1024xf32>
// CHECK: }
// -----
// Test input: only scalar pointers (no tensor of pointers) carried through an `scf.for`.
// The assertions that follow this module match the same sequence of operations as the input:
// one `tt.addptr` before the loop, then a `tt.store` and a `tt.addptr` inside it.
module attributes {"ttg.num-warps" = 4 : i32} {
tt.func public @scalar_pointers(%arg0: !tt.ptr<i64> {tt.divisibility = 16 : i32}) {
%c1_i32 = arith.constant 1 : i32
%c0_i64 = arith.constant 0 : i64
%c100_i32 = arith.constant 100 : i32
%1 = tt.addptr %arg0, %c1_i32 : !tt.ptr<i64>, i32
// The loop stores zero through the carried pointer and advances it by one element per iteration.
%2 = scf.for %arg3 = %c1_i32 to %c100_i32 step %c1_i32 iter_args(%arg4 = %1) -> (!tt.ptr<i64>) : i32 {
tt.store %arg4, %c0_i64 : !tt.ptr<i64>
%3 = tt.addptr %arg4, %c1_i32 : !tt.ptr<i64>, i32
scf.yield %3 : !tt.ptr<i64>
}
tt.return
}
}
// CHECK: tt.func public @scalar_pointers(%[[VAL_0:.*]]: !tt.ptr<i64> {tt.divisibility = 16 : i32}) {
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : i64
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : i32
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 100 : i32
// CHECK: %[[VAL_6:.*]] = tt.addptr %[[VAL_0]], %[[VAL_4]] : !tt.ptr<i64>, i32
// CHECK: %[[VAL_7:.*]] = scf.for %[[VAL_8:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_4]] iter_args(%[[VAL_9:.*]] = %[[VAL_6]]) -> (!tt.ptr<i64>) : i32 {
// CHECK: tt.store %[[VAL_9]], %[[VAL_3]] : !tt.ptr<i64>
// CHECK: %[[VAL_10:.*]] = tt.addptr %[[VAL_9]], %[[VAL_4]] : !tt.ptr<i64>, i32
// CHECK: scf.yield %[[VAL_10]] : !tt.ptr<i64>
// CHECK: }
// CHECK: tt.return
// CHECK: }
// -----
// Test input: a scalar pointer yielded from both branches of an `scf.if`.
// The assertions that follow this module match the same sequence of operations as the input.
// %arg1 is not used by the function body.
module attributes {"ttg.num-warps" = 4 : i32} {
tt.func @scalar_if(%arg0: !tt.ptr<f32>, %arg1: tensor<1024xf32>, %arg2: i1) -> f32 {
%c1_i32 = arith.constant 1 : i32
%c100_i32 = arith.constant 100 : i32
%1 = tt.addptr %arg0, %c1_i32 : !tt.ptr<f32>, i32
// Each branch advances %1 by a different constant and yields the resulting scalar pointer.
%2 = scf.if %arg2 -> (!tt.ptr<f32>) {
%4 = tt.addptr %1, %c1_i32 : !tt.ptr<f32>, i32
scf.yield %4 : !tt.ptr<f32>
} else {
%4 = tt.addptr %1, %c100_i32 : !tt.ptr<f32>, i32
scf.yield %4 : !tt.ptr<f32>
}
%3 = tt.load %2 : !tt.ptr<f32>
tt.return %3 : f32
}
}
// CHECK-LABEL: tt.func @scalar_if(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1024xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: i1) -> f32 {
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : i32
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 100 : i32
// CHECK: %[[VAL_5:.*]] = tt.addptr %[[VAL_0]], %[[VAL_3]] : !tt.ptr<f32>, i32
// CHECK: %[[VAL_6:.*]] = scf.if %[[VAL_2]] -> (!tt.ptr<f32>) {
// CHECK: %[[VAL_7:.*]] = tt.addptr %[[VAL_5]], %[[VAL_3]] : !tt.ptr<f32>, i32
// CHECK: scf.yield %[[VAL_7]] : !tt.ptr<f32>
// CHECK: } else {
// CHECK: %[[VAL_8:.*]] = tt.addptr %[[VAL_5]], %[[VAL_4]] : !tt.ptr<f32>, i32
// CHECK: scf.yield %[[VAL_8]] : !tt.ptr<f32>
// CHECK: }
// CHECK: %[[VAL_9:.*]] = tt.load %[[VAL_6]] : !tt.ptr<f32>
// CHECK: tt.return %[[VAL_9]] : f32
// CHECK: }
// -----
// Test input: a scalar pointer carried through an `scf.while`.
// The assertions that follow this module match the same sequence of operations as the input.
// %arg1 is not used by the function body.
module attributes {"ttg.num-warps" = 4 : i32} {
tt.func @scalar_while(%arg0: !tt.ptr<f32>, %arg1: f32) -> f32 {
%c128_i32 = arith.constant 128 : i32
%0 = tt.get_program_id x : i32
%1 = tt.addptr %arg0, %0 : !tt.ptr<f32>, i32
%2 = scf.while (%arg2 = %1) : (!tt.ptr<f32>) -> !tt.ptr<f32> {
// Unregistered op used as the loop condition; the RUN line at the top of the file passes
// -allow-unregistered-dialect.
%4 = "dummy.evaluate_condition"() : () -> i1
scf.condition(%4) %arg2 : !tt.ptr<f32>
} do {
^bb0(%arg2: !tt.ptr<f32>):
// The body advances the carried pointer by 128 elements.
%4 = tt.addptr %arg2, %c128_i32 : !tt.ptr<f32>, i32
scf.yield %4 : !tt.ptr<f32>
}
%3 = tt.load %2 : !tt.ptr<f32>
tt.return %3 : f32
}
}
// CHECK-LABEL: tt.func @scalar_while(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>,
// CHECK-SAME: %[[VAL_1:.*]]: f32) -> f32 {
// CHECK: %[[VAL_2:.*]] = arith.constant 128 : i32
// CHECK: %[[VAL_3:.*]] = tt.get_program_id x : i32
// CHECK: %[[VAL_4:.*]] = tt.addptr %[[VAL_0]], %[[VAL_3]] : !tt.ptr<f32>, i32
// CHECK: %[[VAL_5:.*]] = scf.while (%[[VAL_6:.*]] = %[[VAL_4]]) : (!tt.ptr<f32>) -> !tt.ptr<f32> {
// CHECK: %[[VAL_7:.*]] = "dummy.evaluate_condition"() : () -> i1
// CHECK: scf.condition(%[[VAL_7]]) %[[VAL_6]] : !tt.ptr<f32>
// CHECK: } do {
// CHECK: ^bb0(%[[VAL_8:.*]]: !tt.ptr<f32>):
// CHECK: %[[VAL_9:.*]] = tt.addptr %[[VAL_8]], %[[VAL_2]] : !tt.ptr<f32>, i32
// CHECK: scf.yield %[[VAL_9]] : !tt.ptr<f32>
// CHECK: }
// CHECK: %[[VAL_10:.*]] = tt.load %[[VAL_5]] : !tt.ptr<f32>
// CHECK: tt.return %[[VAL_10]] : f32
// CHECK: }
// -----
// Test input: scalar pointers passed as block arguments through an unstructured `cf.cond_br`.
// The assertions that follow this module match each successor receiving an additional i64
// operand (a constant 0) next to the pointer, and a scalar `tt.addptr` of pointer and i64
// placed before the load in each destination block.
module attributes {"ttg.num-warps" = 4 : i32} {
tt.func @scalar_cond_branch(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: i1) -> f32 {
cf.cond_br %arg2, ^bb1(%arg0 : !tt.ptr<f32>), ^bb2(%arg1 : !tt.ptr<f32>)
^bb1(%0: !tt.ptr<f32>): // pred: ^bb0
%1 = tt.load %0 : !tt.ptr<f32>
tt.return %1 : f32
^bb2(%2: !tt.ptr<f32>): // pred: ^bb0
%3 = tt.load %2 : !tt.ptr<f32>
tt.return %3 : f32
}
}
// CHECK-LABEL: tt.func @scalar_cond_branch(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>, %[[VAL_1:.*]]: !tt.ptr<f32>, %[[VAL_2:.*]]: i1) -> f32 {
// CHECK: %[[VAL_3:.*]] = arith.constant 0 : i64
// CHECK: cf.cond_br %[[VAL_2]], ^bb1(%[[VAL_0]], %[[VAL_3]] : !tt.ptr<f32>, i64), ^bb2(%[[VAL_1]], %[[VAL_3]] : !tt.ptr<f32>, i64)
// CHECK: ^bb1(%[[VAL_4:.*]]: !tt.ptr<f32>, %[[VAL_5:.*]]: i64):
// CHECK: %[[VAL_6:.*]] = tt.addptr %[[VAL_4]], %[[VAL_5]] : !tt.ptr<f32>, i64
// CHECK: %[[VAL_7:.*]] = tt.load %[[VAL_6]] : !tt.ptr<f32>
// CHECK: tt.return %[[VAL_7]] : f32
// CHECK: ^bb2(%[[VAL_8:.*]]: !tt.ptr<f32>, %[[VAL_9:.*]]: i64):
// CHECK: %[[VAL_10:.*]] = tt.addptr %[[VAL_8]], %[[VAL_9]] : !tt.ptr<f32>, i64
// CHECK: %[[VAL_11:.*]] = tt.load %[[VAL_10]] : !tt.ptr<f32>
// CHECK: tt.return %[[VAL_11]] : f32
// CHECK: }
// -----
// Test input: an `scf.for` with two tensor-of-pointer iter_args whose positions are swapped by
// the yield ("flip-flop").
// The assertions that follow this module match a loop with five iter_args: a (scalar pointer,
// i64 offset tensor) pair for each original pointer iter_arg, plus the f32 tensor, with the
// two pairs swapped in the yield.
module attributes {"ttg.num-warps" = 4 : i32} {
tt.func @flipFlopForOpSimple(%arg0: !tt.ptr<f32>, %arg1: tensor<1024xf32>) -> tensor<1024xf32> {
%c1024_i32 = arith.constant 1024 : i32
%c0 = arith.constant 0 : index
%c128 = arith.constant 128 : index
%c1 = arith.constant 1 : index
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c1024_i32 : i32
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
%3 = tt.splat %1 : i32 -> tensor<1024xi32>
%4 = arith.addi %3, %2 : tensor<1024xi32>
%5 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
// %6 and %60 are two separate but identical tt.addptr ops; each one initializes its own iter_arg.
%6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
%60 = tt.addptr %5, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
%7:3 = scf.for %arg2 = %c0 to %c128 step %c1 iter_args(%arg30 = %60, %arg3 = %6, %arg4 = %arg1) -> (tensor<1024x!tt.ptr<f32>>, tensor<1024x!tt.ptr<f32>>, tensor<1024xf32>) {
%10 = tt.addptr %arg3, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
%11 = tt.load %10 : tensor<1024x!tt.ptr<f32>>
%12 = arith.addf %11, %arg4 : tensor<1024xf32>
// %100 has no users; the assertions for the loop body contain no op derived from it.
%100 = tt.addptr %arg30, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
// Swap: slot 0 (%arg30) receives %10, which was derived from %arg3; slot 1 (%arg3) receives %arg30.
scf.yield %10, %arg30, %12 : tensor<1024x!tt.ptr<f32>>, tensor<1024x!tt.ptr<f32>>, tensor<1024xf32>
}
%8 = tt.addptr %7#0, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
%9 = tt.load %8 : tensor<1024x!tt.ptr<f32>>
tt.return %9 : tensor<1024xf32>
}
}
// CHECK-LABEL: tt.func @flipFlopForOpSimple(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>, %[[VAL_1:.*]]: tensor<1024xf32>) -> tensor<1024xf32> {
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 1024 : i32
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 128 : index
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_6:.*]] = tt.get_program_id x : i32
// CHECK: %[[VAL_7:.*]] = arith.muli %[[VAL_6]], %[[VAL_2]] : i32
// CHECK: %[[VAL_8:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
// CHECK: %[[VAL_9:.*]] = tt.addptr %[[VAL_0]], %[[VAL_7]] : !tt.ptr<f32>, i32
// CHECK: %[[VAL_10:.*]] = arith.extsi %[[VAL_8]] : tensor<1024xi32> to tensor<1024xi64>
// CHECK: %[[VAL_11:.*]] = tt.addptr %[[VAL_0]], %[[VAL_7]] : !tt.ptr<f32>, i32
// CHECK: %[[VAL_12:.*]] = arith.extsi %[[VAL_8]] : tensor<1024xi32> to tensor<1024xi64>
// CHECK: %[[VAL_13:.*]]:5 = scf.for %[[VAL_14:.*]] = %[[VAL_3]] to %[[VAL_4]] step %[[VAL_5]] iter_args(%[[VAL_15:.*]] = %[[VAL_11]], %[[VAL_16:.*]] = %[[VAL_12]], %[[VAL_17:.*]] = %[[VAL_9]], %[[VAL_18:.*]] = %[[VAL_10]], %[[VAL_19:.*]] = %[[VAL_1]]) -> (!tt.ptr<f32>, tensor<1024xi64>, !tt.ptr<f32>, tensor<1024xi64>, tensor<1024xf32>) {
// CHECK: %[[VAL_20:.*]] = tt.addptr %[[VAL_17]], %[[VAL_7]] : !tt.ptr<f32>, i32
// CHECK: %[[VAL_21:.*]] = arith.extsi %[[VAL_8]] : tensor<1024xi32> to tensor<1024xi64>
// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_21]], %[[VAL_18]] : tensor<1024xi64>
// CHECK: %[[VAL_23:.*]] = tt.splat %[[VAL_20]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
// CHECK: %[[VAL_24:.*]] = tt.addptr %[[VAL_23]], %[[VAL_22]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi64>
// CHECK: %[[VAL_25:.*]] = tt.load %[[VAL_24]] : tensor<1024x!tt.ptr<f32>>
// CHECK: %[[VAL_26:.*]] = arith.addf %[[VAL_25]], %[[VAL_19]] : tensor<1024xf32>
// CHECK: scf.yield %[[VAL_20]], %[[VAL_22]], %[[VAL_15]], %[[VAL_16]], %[[VAL_26]] : !tt.ptr<f32>, tensor<1024xi64>, !tt.ptr<f32>, tensor<1024xi64>, tensor<1024xf32>
// CHECK: }
// CHECK: %[[VAL_27:.*]] = tt.addptr %[[VAL_28:.*]]#0, %[[VAL_7]] : !tt.ptr<f32>, i32
// CHECK: %[[VAL_29:.*]] = arith.extsi %[[VAL_8]] : tensor<1024xi32> to tensor<1024xi64>
// CHECK: %[[VAL_30:.*]] = arith.addi %[[VAL_29]], %[[VAL_28]]#1 : tensor<1024xi64>
// CHECK: %[[VAL_31:.*]] = tt.splat %[[VAL_27]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
// CHECK: %[[VAL_32:.*]] = tt.addptr %[[VAL_31]], %[[VAL_30]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi64>
// CHECK: %[[VAL_33:.*]] = tt.load %[[VAL_32]] : tensor<1024x!tt.ptr<f32>>
// CHECK: tt.return %[[VAL_33]] : tensor<1024xf32>
// CHECK: }
// -----
// Test input: like @flipFlopForOpSimple, but with two different base pointers (%arg0 and
// %arg00), two accumulators, and both pointer results of the loop used after it.
// The assertions that follow this module match a loop with six iter_args: for each original
// (pointer, accumulator) pair there is a scalar pointer, an i64 offset tensor and the f32 tensor.
module attributes {"ttg.num-warps" = 4 : i32} {
tt.func @flipFlopForOpComplex(%arg0: !tt.ptr<f32>, %arg00: !tt.ptr<f32>, %arg1: tensor<1024xf32>) -> (tensor<1024xf32>, tensor<1024xf32>) {
%c1024_i32 = arith.constant 1024 : i32
%c0 = arith.constant 0 : index
%c128 = arith.constant 128 : index
%c1 = arith.constant 1 : index
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c1024_i32 : i32
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
%3 = tt.splat %1 : i32 -> tensor<1024xi32>
// First pointer chain: base %arg0, offset %4.
%4 = arith.addi %3, %2 : tensor<1024xi32>
%5 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
%6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
// Second pointer chain: base %arg00, offset %40 (computed from the same operands as %4).
%40 = arith.addi %3, %2 : tensor<1024xi32>
%50 = tt.splat %arg00 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
%60 = tt.addptr %50, %40 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
%7:4 = scf.for %arg2 = %c0 to %c128 step %c1 iter_args(%arg3 = %6, %arg4 = %arg1, %arg30 = %60, %arg40 = %arg1) -> (tensor<1024x!tt.ptr<f32>>, tensor<1024xf32>, tensor<1024x!tt.ptr<f32>>, tensor<1024xf32>) {
%10 = tt.addptr %arg3, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
%11 = tt.load %10 : tensor<1024x!tt.ptr<f32>>
%12 = arith.addf %11, %arg4 : tensor<1024xf32>
%100 = tt.addptr %arg30, %40 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
%110 = tt.load %100 : tensor<1024x!tt.ptr<f32>>
%120 = arith.addf %110, %arg40 : tensor<1024xf32>
// Swap: the values derived from (%arg30, %arg40) are yielded into the (%arg3, %arg4) slots and vice versa.
scf.yield %100, %120, %10, %12 : tensor<1024x!tt.ptr<f32>>, tensor<1024xf32>, tensor<1024x!tt.ptr<f32>>, tensor<1024xf32>
}
// Both loop-carried pointer results are offset once more and loaded.
%8 = tt.addptr %7#0, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
%9 = tt.load %8 : tensor<1024x!tt.ptr<f32>>
%80 = tt.addptr %7#2, %40 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
%90 = tt.load %80 : tensor<1024x!tt.ptr<f32>>
tt.return %9, %90 : tensor<1024xf32>, tensor<1024xf32>
}
}
// CHECK-LABEL: tt.func @flipFlopForOpComplex(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>, %[[VAL_1:.*]]: !tt.ptr<f32>, %[[VAL_2:.*]]: tensor<1024xf32>) -> (tensor<1024xf32>, tensor<1024xf32>) {
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1024 : i32
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 128 : index
// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_7:.*]] = tt.get_program_id x : i32
// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_7]], %[[VAL_3]] : i32
// CHECK: %[[VAL_9:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
// CHECK: %[[VAL_10:.*]] = tt.addptr %[[VAL_0]], %[[VAL_8]] : !tt.ptr<f32>, i32
// CHECK: %[[VAL_11:.*]] = arith.extsi %[[VAL_9]] : tensor<1024xi32> to tensor<1024xi64>
// CHECK: %[[VAL_12:.*]] = tt.addptr %[[VAL_1]], %[[VAL_8]] : !tt.ptr<f32>, i32
// CHECK: %[[VAL_13:.*]] = arith.extsi %[[VAL_9]] : tensor<1024xi32> to tensor<1024xi64>
// CHECK: %[[VAL_14:.*]]:6 = scf.for %[[VAL_15:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] iter_args(%[[VAL_16:.*]] = %[[VAL_10]], %[[VAL_17:.*]] = %[[VAL_11]], %[[VAL_18:.*]] = %[[VAL_2]], %[[VAL_19:.*]] = %[[VAL_12]], %[[VAL_20:.*]] = %[[VAL_13]], %[[VAL_21:.*]] = %[[VAL_2]]) -> (!tt.ptr<f32>, tensor<1024xi64>, tensor<1024xf32>, !tt.ptr<f32>, tensor<1024xi64>, tensor<1024xf32>) {
// CHECK: %[[VAL_22:.*]] = tt.addptr %[[VAL_16]], %[[VAL_8]] : !tt.ptr<f32>, i32
// CHECK: %[[VAL_23:.*]] = arith.extsi %[[VAL_9]] : tensor<1024xi32> to tensor<1024xi64>
// CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_23]], %[[VAL_17]] : tensor<1024xi64>
// CHECK: %[[VAL_25:.*]] = tt.splat %[[VAL_22]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
// CHECK: %[[VAL_26:.*]] = tt.addptr %[[VAL_25]], %[[VAL_24]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi64>
// CHECK: %[[VAL_27:.*]] = tt.load %[[VAL_26]] : tensor<1024x!tt.ptr<f32>>
// CHECK: %[[VAL_28:.*]] = arith.addf %[[VAL_27]], %[[VAL_18]] : tensor<1024xf32>
// CHECK: %[[VAL_29:.*]] = tt.addptr %[[VAL_19]], %[[VAL_8]] : !tt.ptr<f32>, i32
// CHECK: %[[VAL_30:.*]] = arith.extsi %[[VAL_9]] : tensor<1024xi32> to tensor<1024xi64>
// CHECK: %[[VAL_31:.*]] = arith.addi %[[VAL_30]], %[[VAL_20]] : tensor<1024xi64>
// CHECK: %[[VAL_32:.*]] = tt.splat %[[VAL_29]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
// CHECK: %[[VAL_33:.*]] = tt.addptr %[[VAL_32]], %[[VAL_31]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi64>
// CHECK: %[[VAL_34:.*]] = tt.load %[[VAL_33]] : tensor<1024x!tt.ptr<f32>>
// CHECK: %[[VAL_35:.*]] = arith.addf %[[VAL_34]], %[[VAL_21]] : tensor<1024xf32>
// CHECK: scf.yield %[[VAL_29]], %[[VAL_31]], %[[VAL_35]], %[[VAL_22]], %[[VAL_24]], %[[VAL_28]] : !tt.ptr<f32>, tensor<1024xi64>, tensor<1024xf32>, !tt.ptr<f32>, tensor<1024xi64>, tensor<1024xf32>
// CHECK: }
// CHECK: %[[VAL_36:.*]] = tt.addptr %[[VAL_37:.*]]#0, %[[VAL_8]] : !tt.ptr<f32>, i32
// CHECK: %[[VAL_38:.*]] = arith.extsi %[[VAL_9]] : tensor<1024xi32> to tensor<1024xi64>
// CHECK: %[[VAL_39:.*]] = arith.addi %[[VAL_38]], %[[VAL_37]]#1 : tensor<1024xi64>
// CHECK: %[[VAL_40:.*]] = tt.splat %[[VAL_36]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
// CHECK: %[[VAL_41:.*]] = tt.addptr %[[VAL_40]], %[[VAL_39]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi64>
// CHECK: %[[VAL_42:.*]] = tt.load %[[VAL_41]] : tensor<1024x!tt.ptr<f32>>
// CHECK: %[[VAL_43:.*]] = tt.addptr %[[VAL_37]]#3, %[[VAL_8]] : !tt.ptr<f32>, i32
// CHECK: %[[VAL_44:.*]] = arith.extsi %[[VAL_9]] : tensor<1024xi32> to tensor<1024xi64>
// CHECK: %[[VAL_45:.*]] = arith.addi %[[VAL_44]], %[[VAL_37]]#4 : tensor<1024xi64>
// CHECK: %[[VAL_46:.*]] = tt.splat %[[VAL_43]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
// CHECK: %[[VAL_47:.*]] = tt.addptr %[[VAL_46]], %[[VAL_45]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi64>
// CHECK: %[[VAL_48:.*]] = tt.load %[[VAL_47]] : tensor<1024x!tt.ptr<f32>>
// CHECK: tt.return %[[VAL_42]], %[[VAL_48]] : tensor<1024xf32>, tensor<1024xf32>
// CHECK: }
// -----
// test_functional_regressions.test_inductor_cummax_bool
// tt.bitcast immediately materializes the fat pointer, ending the analysis
// Test input: pointers to i1 are offset, bitcast to pointers to i8, and then loaded from or
// stored to; a two-operand `tt.scan` sits between the load and the stores.
// The visible assertions after this module keep the %arg0 and %arg1 chains as
// tt.splat -> tt.addptr -> tt.bitcast on tensors of pointers, using the make_range directly as
// the offset.
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
tt.func public @test_inductor_cummax_bool(%arg0: !tt.ptr<i1> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<i1> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr<i64> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) {
%cst = arith.constant dense<0> : tensor<64xi8, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
// %0 is used both as the pointer offset and (sign-extended, %6) as the second scan operand.
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
// Load chain for %arg0: splat, addptr, bitcast from i1 pointers to i8 pointers, load.
%1 = tt.splat %arg0 : !tt.ptr<i1> -> tensor<64x!tt.ptr<i1>, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%2 = tt.addptr %1, %0 : tensor<64x!tt.ptr<i1>, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>, tensor<64xi32, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%3 = tt.bitcast %2 : tensor<64x!tt.ptr<i1>, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> -> tensor<64x!tt.ptr<i8>, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%4 = tt.load %3 : tensor<64x!tt.ptr<i8>, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
// Convert the loaded i8 values to i1 by comparing them against zero.
%5 = arith.cmpi ne, %4, %cst : tensor<64xi8, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%6 = arith.extsi %0 : tensor<64xi32, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> to tensor<64xi64, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
// Scan over (i1 value, i64 index) pairs along axis 0. The combiner keeps (%arg3, %arg4) when
// %arg3 is unsigned-greater than %arg5, or when they are equal and %arg4 is signed-greater
// than %arg6; otherwise it keeps (%arg5, %arg6).
%7:2 = "tt.scan"(%5, %6) <{axis = 0 : i32, reverse = false}> ({
^bb0(%arg3: i1, %arg4: i64, %arg5: i1, %arg6: i64):
%14 = arith.cmpi ugt, %arg3, %arg5 : i1
%15 = arith.cmpi eq, %arg3, %arg5 : i1
%16 = arith.cmpi sgt, %arg4, %arg6 : i64
%17 = arith.andi %15, %16 : i1
%18 = arith.ori %14, %17 : i1
%19 = arith.select %18, %arg3, %arg5 : i1
%20 = arith.select %18, %arg4, %arg6 : i64
tt.scan.return %19, %20 : i1, i64
}) : (tensor<64xi1, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>, tensor<64xi64, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>) -> (tensor<64xi1, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>, tensor<64xi64, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>)
// Store chain for %arg1: same splat/addptr/bitcast shape as the load chain; the i1 scan result
// is zero-extended to i8 before the store.
%8 = tt.splat %arg1 : !tt.ptr<i1> -> tensor<64x!tt.ptr<i1>, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%9 = tt.addptr %8, %0 : tensor<64x!tt.ptr<i1>, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>, tensor<64xi32, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%10 = tt.bitcast %9 : tensor<64x!tt.ptr<i1>, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> -> tensor<64x!tt.ptr<i8>, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%11 = arith.extui %7#0 : tensor<64xi1, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>> to tensor<64xi8, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
tt.store %10, %11 : tensor<64x!tt.ptr<i8>, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
// Store chain for %arg2: i64 pointers, no bitcast; stores the i64 scan result.
%12 = tt.splat %arg2 : !tt.ptr<i64> -> tensor<64x!tt.ptr<i64>, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
%13 = tt.addptr %12, %0 : tensor<64x!tt.ptr<i64>, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>, tensor<64xi32, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
tt.store %13, %7#1 : tensor<64x!tt.ptr<i64>, #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
tt.return
}
}
// CHECK-LABEL: tt.func public @test_inductor_cummax_bool(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<i1> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %[[VAL_1:.*]]: !tt.ptr<i1> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %[[VAL_2:.*]]: !tt.ptr<i64> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) {
// CHECK: %[[VAL_3:.*]] = arith.constant dense<0> : tensor<64xi8, #[[$ATTR_0]]>
// CHECK: %[[VAL_4:.*]] = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #[[$ATTR_0]]>
// CHECK: %[[VAL_5:.*]] = tt.splat %[[VAL_0]] : !tt.ptr<i1> -> tensor<64x!tt.ptr<i1>, #[[$ATTR_0]]>
// CHECK: %[[VAL_6:.*]] = tt.addptr %[[VAL_5]], %[[VAL_4]] : tensor<64x!tt.ptr<i1>, #[[$ATTR_0]]>, tensor<64xi32, #[[$ATTR_0]]>
// CHECK: %[[VAL_7:.*]] = tt.bitcast %[[VAL_6]] : tensor<64x!tt.ptr<i1>, #[[$ATTR_0]]> -> tensor<64x!tt.ptr<i8>, #[[$ATTR_0]]>
// CHECK: %[[VAL_8:.*]] = tt.load %[[VAL_7]] : tensor<64x!tt.ptr<i8>, #[[$ATTR_0]]>
// CHECK: %[[VAL_9:.*]] = arith.cmpi ne, %[[VAL_8]], %[[VAL_3]] : tensor<64xi8, #[[$ATTR_0]]>
// CHECK: %[[VAL_10:.*]] = arith.extsi %[[VAL_4]] : tensor<64xi32, #[[$ATTR_0]]> to tensor<64xi64, #[[$ATTR_0]]>
// CHECK: %[[VAL_11:.*]]:2 = "tt.scan"(%[[VAL_9]], %[[VAL_10]]) <{axis = 0 : i32, reverse = false}> ({
// CHECK: ^bb0(%[[VAL_12:.*]]: i1, %[[VAL_13:.*]]: i64, %[[VAL_14:.*]]: i1, %[[VAL_15:.*]]: i64):
// CHECK: %[[VAL_16:.*]] = arith.cmpi ugt, %[[VAL_12]], %[[VAL_14]] : i1
// CHECK: %[[VAL_17:.*]] = arith.cmpi eq, %[[VAL_12]], %[[VAL_14]] : i1
// CHECK: %[[VAL_18:.*]] = arith.cmpi sgt, %[[VAL_13]], %[[VAL_15]] : i64
// CHECK: %[[VAL_19:.*]] = arith.andi %[[VAL_17]], %[[VAL_18]] : i1
// CHECK: %[[VAL_20:.*]] = arith.ori %[[VAL_16]], %[[VAL_19]] : i1
// CHECK: %[[VAL_21:.*]] = arith.select %[[VAL_20]], %[[VAL_12]], %[[VAL_14]] : i1
// CHECK: %[[VAL_22:.*]] = arith.select %[[VAL_20]], %[[VAL_13]], %[[VAL_15]] : i64
// CHECK: tt.scan.return %[[VAL_21]], %[[VAL_22]] : i1, i64
// CHECK: }) : (tensor<64xi1, #[[$ATTR_0]]>, tensor<64xi64, #[[$ATTR_0]]>) -> (tensor<64xi1, #[[$ATTR_0]]>, tensor<64xi64, #[[$ATTR_0]]>)
// CHECK: %[[VAL_23:.*]] = tt.splat %[[VAL_1]] : !tt.ptr<i1> -> tensor<64x!tt.ptr<i1>, #[[$ATTR_0]]>
// CHECK: %[[VAL_24:.*]] = tt.addptr %[[VAL_23]], %[[VAL_4]] : tensor<64x!tt.ptr<i1>, #[[$ATTR_0]]>, tensor<64xi32, #[[$ATTR_0]]>
// CHECK: %[[VAL_25:.*]] = tt.bitcast %[[VAL_24]] : tensor<64x!tt.ptr<i1>, #[[$ATTR_0]]> -> tensor<64x!tt.ptr<i8>, #[[$ATTR_0]]>
// CHECK: %[[VAL_26:.*]] = arith.extui %[[VAL_27:.*]]#0 : tensor<64xi1, #[[$ATTR_0]]> to tensor<64xi8, #[[$ATTR_0]]>
// CHECK: tt.store %[[VAL_25]], %[[VAL_26]] : tensor<64x!tt.ptr<i8>, #[[$ATTR_0]]>
// CHECK: %[[VAL_28:.*]] = tt.splat %[[VAL_2]] : !tt.ptr<i64> -> tensor<64x!tt.ptr<i64>, #[[$ATTR_0]]>
// CHECK: %[[VAL_29:.*]] = tt.addptr %[[VAL_28]], %[[VAL_4]] : tensor<64x!tt.ptr<i64>, #[[$ATTR_0]]>, tensor<64xi32, #[[$ATTR_0]]>
// CHECK: tt.store %[[VAL_29]], %[[VAL_27]]#1 : tensor<64x!tt.ptr<i64>, #[[$ATTR_0]]>
// CHECK: tt.return
// CHECK: }
// -----
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
  // Scalar f16 case: a value is loaded through pointer arithmetic on %arg0 and
  // then atomically added (fadd, acq_rel, gpu scope) to the location in %arg1.
  // The assertions that follow this module list the same op sequence as the
  // input.
  tt.func public @test_atomic_rmw(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) {
    %mask = arith.constant true
    %pid = tt.get_program_id x : i32
    // Offset the source pointer by the program id and read one element.
    %src = tt.addptr %arg0, %pid : !tt.ptr<f16>, i32
    %val = tt.load %src : !tt.ptr<f16>
    // The atomic uses the raw function argument as its pointer operand.
    %old = tt.atomic_rmw fadd, acq_rel, gpu, %arg1, %val, %mask : (!tt.ptr<f16>, f16, i1) -> f16
    tt.return
  }
}
// CHECK-LABEL: tt.func public @test_atomic_rmw(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %[[VAL_1:.*]]: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) {
// CHECK: %[[VAL_2:.*]] = arith.constant true
// CHECK: %[[VAL_3:.*]] = tt.get_program_id x : i32
// CHECK: %[[VAL_4:.*]] = tt.addptr %[[VAL_0]], %[[VAL_3]] : !tt.ptr<f16>, i32
// CHECK: %[[VAL_5:.*]] = tt.load %[[VAL_4]] : !tt.ptr<f16>
// CHECK: %[[VAL_6:.*]] = tt.atomic_rmw fadd, acq_rel, gpu, %[[VAL_1]], %[[VAL_5]], %[[VAL_2]] : (!tt.ptr<f16>, f16, i1) -> f16
// CHECK: tt.return
// CHECK: }
// -----
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
  // bf16 variant of the scalar atomic test: load one element through pointer
  // arithmetic on %arg0, then fadd it atomically into the location in %arg1.
  // The assertions that follow this module list the same op sequence as the
  // input.
  tt.func public @test_atomic_rmw_bf16(%arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) {
    %mask = arith.constant true
    %pid = tt.get_program_id x : i32
    // Offset the source pointer by the program id and read one element.
    %src = tt.addptr %arg0, %pid : !tt.ptr<bf16>, i32
    %val = tt.load %src : !tt.ptr<bf16>
    // The atomic uses the raw function argument as its pointer operand.
    %old = tt.atomic_rmw fadd, acq_rel, gpu, %arg1, %val, %mask : (!tt.ptr<bf16>, bf16, i1) -> bf16
    tt.return
  }
}
// CHECK-LABEL: tt.func public @test_atomic_rmw_bf16(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %[[VAL_1:.*]]: !tt.ptr<bf16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) {
// CHECK: %[[VAL_2:.*]] = arith.constant true
// CHECK: %[[VAL_3:.*]] = tt.get_program_id x : i32
// CHECK: %[[VAL_4:.*]] = tt.addptr %[[VAL_0]], %[[VAL_3]] : !tt.ptr<bf16>, i32
// CHECK: %[[VAL_5:.*]] = tt.load %[[VAL_4]] : !tt.ptr<bf16>
// CHECK: %[[VAL_6:.*]] = tt.atomic_rmw fadd, acq_rel, gpu, %[[VAL_1]], %[[VAL_5]], %[[VAL_2]] : (!tt.ptr<bf16>, bf16, i1) -> bf16
// CHECK: tt.return
// CHECK: }
// -----
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
// Kernel whose only pointer argument is never used. The diagnostic annotation
// directly below is line-relative (+1) and must stay immediately above the
// tt.func line it targets; this case has no output assertions of its own.
// expected-remark@+1 {{expected at least 1 use of unrealized_cast}}
tt.func public @empty_kernel(%arg0: !tt.ptr<i8> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) {
tt.return
}
}
// -----
// Layout aliases for this test case. They are pure shorthand for the inline
// attributes and parse to the same IR.
#blocked = #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>
#slice1 = #ttg.slice<{dim = 1, parent = #blocked}>
#slice2 = #ttg.slice<{dim = 2, parent = #blocked}>
#slice1_1 = #ttg.slice<{dim = 1, parent = #slice1}>
#slice1_0 = #ttg.slice<{dim = 0, parent = #slice1}>
#slice2_1 = #ttg.slice<{dim = 1, parent = #slice2}>
#slice2_0 = #ttg.slice<{dim = 0, parent = #slice2}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
  // Loads a 32x2x16 tile through %arg0, max-reduces it along axis 1 and stores
  // the 32x1x16 result through %arg1. Pointer tensors are built step by step
  // with interleaved broadcast / addptr / expand_dims ops; the assertions that
  // follow this module show the offsets combined as integer tensors with one
  // splat + addptr per base pointer.
  tt.func public @test_reduce(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) {
    %c16_2d = arith.constant dense<16> : tensor<32x1xi32, #slice1>
    %c16_mid = arith.constant dense<16> : tensor<1x2x1xi32, #blocked>
    %c16_3d = arith.constant dense<16> : tensor<32x1x1xi32, #blocked>
    %c2_3d = arith.constant dense<2> : tensor<32x1x1xi32, #blocked>
    // Two 0..32 ranges along the first dimension, in different slice layouts.
    %rows_a = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #slice1_1>
    %rows_b = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #slice2_1>
    %rows_a_2d = tt.expand_dims %rows_a {axis = 1 : i32} : tensor<32xi32, #slice1_1> -> tensor<32x1xi32, #slice1>
    %rows_b_2d = tt.expand_dims %rows_b {axis = 1 : i32} : tensor<32xi32, #slice2_1> -> tensor<32x1xi32, #slice2>
    %rows_b_3d = tt.expand_dims %rows_b_2d {axis = 2 : i32} : tensor<32x1xi32, #slice2> -> tensor<32x1x1xi32, #blocked>
    %rows_x2 = arith.muli %rows_b_3d, %c2_3d : tensor<32x1x1xi32, #blocked>
    %row_offs = arith.muli %rows_x2, %c16_3d : tensor<32x1x1xi32, #blocked>
    // Input pointers: base + offsets along dim 0.
    %in_base = tt.splat %arg0 : !tt.ptr<f32> -> tensor<32x1x1x!tt.ptr<f32>, #blocked>
    %in_rows = tt.addptr %in_base, %row_offs : tensor<32x1x1x!tt.ptr<f32>, #blocked>, tensor<32x1x1xi32, #blocked>
    // Offsets along the middle (size 2) dimension.
    %mid = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32, #slice2_0>
    %mid_2d = tt.expand_dims %mid {axis = 0 : i32} : tensor<2xi32, #slice2_0> -> tensor<1x2xi32, #slice2>
    %mid_3d = tt.expand_dims %mid_2d {axis = 2 : i32} : tensor<1x2xi32, #slice2> -> tensor<1x2x1xi32, #blocked>
    %mid_offs = arith.muli %mid_3d, %c16_mid : tensor<1x2x1xi32, #blocked>
    %in_rows_bc = tt.broadcast %in_rows : tensor<32x1x1x!tt.ptr<f32>, #blocked> -> tensor<32x2x1x!tt.ptr<f32>, #blocked>
    %mid_offs_bc = tt.broadcast %mid_offs : tensor<1x2x1xi32, #blocked> -> tensor<32x2x1xi32, #blocked>
    %in_rows_mid = tt.addptr %in_rows_bc, %mid_offs_bc : tensor<32x2x1x!tt.ptr<f32>, #blocked>, tensor<32x2x1xi32, #blocked>
    // Two identical 0..16 ranges: one feeds the store side, one the load side.
    %cols_a = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice1_0>
    %cols_b = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice1_0>
    %cols_a_2d = tt.expand_dims %cols_a {axis = 0 : i32} : tensor<16xi32, #slice1_0> -> tensor<1x16xi32, #slice1>
    %cols_b_2d = tt.expand_dims %cols_b {axis = 0 : i32} : tensor<16xi32, #slice1_0> -> tensor<1x16xi32, #slice1>
    %cols_b_3d = tt.expand_dims %cols_b_2d {axis = 1 : i32} : tensor<1x16xi32, #slice1> -> tensor<1x1x16xi32, #blocked>
    %in_ptrs_bc = tt.broadcast %in_rows_mid : tensor<32x2x1x!tt.ptr<f32>, #blocked> -> tensor<32x2x16x!tt.ptr<f32>, #blocked>
    %cols_bc = tt.broadcast %cols_b_3d : tensor<1x1x16xi32, #blocked> -> tensor<32x2x16xi32, #blocked>
    %in_ptrs = tt.addptr %in_ptrs_bc, %cols_bc : tensor<32x2x16x!tt.ptr<f32>, #blocked>, tensor<32x2x16xi32, #blocked>
    %loaded = tt.load %in_ptrs : tensor<32x2x16x!tt.ptr<f32>, #blocked>
    // Max-reduction over axis 1.
    %reduced = "tt.reduce"(%loaded) <{axis = 1 : i32}> ({
    ^bb0(%lhs: f32, %rhs: f32):
      %max = arith.maxnumf %lhs, %rhs : f32
      tt.reduce.return %max : f32
    }) : (tensor<32x2x16xf32, #blocked>) -> tensor<32x16xf32, #slice1>
    %reduced_3d = tt.expand_dims %reduced {axis = 1 : i32} : tensor<32x16xf32, #slice1> -> tensor<32x1x16xf32, #blocked>
    // Output pointers are built in 2-D and only then expanded to 3-D.
    %out_row_offs = arith.muli %rows_a_2d, %c16_2d : tensor<32x1xi32, #slice1>
    %out_base = tt.splat %arg1 : !tt.ptr<f32> -> tensor<32x1x!tt.ptr<f32>, #slice1>
    %out_rows = tt.addptr %out_base, %out_row_offs : tensor<32x1x!tt.ptr<f32>, #slice1>, tensor<32x1xi32, #slice1>
    %out_rows_bc = tt.broadcast %out_rows : tensor<32x1x!tt.ptr<f32>, #slice1> -> tensor<32x16x!tt.ptr<f32>, #slice1>
    %out_cols_bc = tt.broadcast %cols_a_2d : tensor<1x16xi32, #slice1> -> tensor<32x16xi32, #slice1>
    %out_ptrs_2d = tt.addptr %out_rows_bc, %out_cols_bc : tensor<32x16x!tt.ptr<f32>, #slice1>, tensor<32x16xi32, #slice1>
    %out_ptrs = tt.expand_dims %out_ptrs_2d {axis = 1 : i32} : tensor<32x16x!tt.ptr<f32>, #slice1> -> tensor<32x1x16x!tt.ptr<f32>, #blocked>
    tt.store %out_ptrs, %reduced_3d : tensor<32x1x16x!tt.ptr<f32>, #blocked>
    tt.return
  }
}
// CHECK: #[[$ATTR_3:.+]] = #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [8, 1, 4], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>
// CHECK-LABEL: tt.func public @test_reduce(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %[[VAL_1:.*]]: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) {
// CHECK: %[[VAL_2:.*]] = arith.constant dense<16> : tensor<32x1xi32, #ttg.slice<{dim = 1, parent = #[[$ATTR_3]]}>>
// CHECK: %[[VAL_3:.*]] = arith.constant dense<16> : tensor<1x2x1xi32, #[[$ATTR_3]]>
// CHECK: %[[VAL_4:.*]] = arith.constant dense<16> : tensor<32x1x1xi32, #[[$ATTR_3]]>
// CHECK: %[[VAL_5:.*]] = arith.constant dense<2> : tensor<32x1x1xi32, #[[$ATTR_3]]>
// CHECK: %[[VAL_6:.*]] = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 1, parent = #[[$ATTR_3]]}>}>>
// CHECK: %[[VAL_7:.*]] = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_3]]}>}>>
// CHECK: %[[VAL_8:.*]] = tt.expand_dims %[[VAL_7]] {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_3]]}>}>> -> tensor<32x1xi32, #ttg.slice<{dim = 2, parent = #[[$ATTR_3]]}>>
// CHECK: %[[VAL_9:.*]] = tt.expand_dims %[[VAL_8]] {axis = 2 : i32} : tensor<32x1xi32, #ttg.slice<{dim = 2, parent = #[[$ATTR_3]]}>> -> tensor<32x1x1xi32, #[[$ATTR_3]]>
// CHECK: %[[VAL_10:.*]] = arith.muli %[[VAL_9]], %[[VAL_5]] : tensor<32x1x1xi32, #[[$ATTR_3]]>
// CHECK: %[[VAL_11:.*]] = arith.muli %[[VAL_10]], %[[VAL_4]] : tensor<32x1x1xi32, #[[$ATTR_3]]>
// CHECK: %[[VAL_12:.*]] = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_3]]}>}>>
// CHECK: %[[VAL_13:.*]] = tt.broadcast %[[VAL_11]] : tensor<32x1x1xi32, #[[$ATTR_3]]> -> tensor<32x2x1xi32, #[[$ATTR_3]]>
// CHECK: %[[VAL_14:.*]] = tt.expand_dims %[[VAL_12]] {axis = 0 : i32} : tensor<2xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_3]]}>}>> -> tensor<1x2xi32, #ttg.slice<{dim = 2, parent = #[[$ATTR_3]]}>>
// CHECK: %[[VAL_15:.*]] = tt.expand_dims %[[VAL_14]] {axis = 2 : i32} : tensor<1x2xi32, #ttg.slice<{dim = 2, parent = #[[$ATTR_3]]}>> -> tensor<1x2x1xi32, #[[$ATTR_3]]>
// CHECK: %[[VAL_16:.*]] = arith.muli %[[VAL_15]], %[[VAL_3]] : tensor<1x2x1xi32, #[[$ATTR_3]]>
// CHECK: %[[VAL_17:.*]] = tt.broadcast %[[VAL_16]] : tensor<1x2x1xi32, #[[$ATTR_3]]> -> tensor<32x2x1xi32, #[[$ATTR_3]]>
// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_17]], %[[VAL_13]] : tensor<32x2x1xi32, #[[$ATTR_3]]>
// CHECK: %[[VAL_19:.*]] = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 1, parent = #[[$ATTR_3]]}>}>>
// CHECK: %[[VAL_20:.*]] = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 1, parent = #[[$ATTR_3]]}>}>>
// CHECK: %[[VAL_21:.*]] = tt.broadcast %[[VAL_18]] : tensor<32x2x1xi32, #[[$ATTR_3]]> -> tensor<32x2x16xi32, #[[$ATTR_3]]>
// CHECK: %[[VAL_22:.*]] = tt.expand_dims %[[VAL_20]] {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 1, parent = #[[$ATTR_3]]}>}>> -> tensor<1x16xi32, #ttg.slice<{dim = 1, parent = #[[$ATTR_3]]}>>
// CHECK: %[[VAL_23:.*]] = tt.expand_dims %[[VAL_22]] {axis = 1 : i32} : tensor<1x16xi32, #ttg.slice<{dim = 1, parent = #[[$ATTR_3]]}>> -> tensor<1x1x16xi32, #[[$ATTR_3]]>
// CHECK: %[[VAL_24:.*]] = tt.broadcast %[[VAL_23]] : tensor<1x1x16xi32, #[[$ATTR_3]]> -> tensor<32x2x16xi32, #[[$ATTR_3]]>
// CHECK: %[[VAL_25:.*]] = arith.addi %[[VAL_24]], %[[VAL_21]] : tensor<32x2x16xi32, #[[$ATTR_3]]>
// CHECK: %[[VAL_26:.*]] = tt.splat %[[VAL_0]] : !tt.ptr<f32> -> tensor<32x2x16x!tt.ptr<f32>, #[[$ATTR_3]]>
// CHECK: %[[VAL_27:.*]] = tt.addptr %[[VAL_26]], %[[VAL_25]] : tensor<32x2x16x!tt.ptr<f32>, #[[$ATTR_3]]>, tensor<32x2x16xi32, #[[$ATTR_3]]>
// CHECK: %[[VAL_28:.*]] = tt.load %[[VAL_27]] : tensor<32x2x16x!tt.ptr<f32>, #[[$ATTR_3]]>
// CHECK: %[[VAL_29:.*]] = "tt.reduce"(%[[VAL_28]]) <{axis = 1 : i32}> ({
// CHECK: ^bb0(%[[VAL_30:.*]]: f32, %[[VAL_31:.*]]: f32):
// CHECK: %[[VAL_32:.*]] = arith.maxnumf %[[VAL_30]], %[[VAL_31]] : f32
// CHECK: tt.reduce.return %[[VAL_32]] : f32
// CHECK: }) : (tensor<32x2x16xf32, #[[$ATTR_3]]>) -> tensor<32x16xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_3]]}>>
// CHECK: %[[VAL_33:.*]] = tt.expand_dims %[[VAL_29]] {axis = 1 : i32} : tensor<32x16xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_3]]}>> -> tensor<32x1x16xf32, #[[$ATTR_3]]>
// CHECK: %[[VAL_34:.*]] = tt.expand_dims %[[VAL_6]] {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 1, parent = #[[$ATTR_3]]}>}>> -> tensor<32x1xi32, #ttg.slice<{dim = 1, parent = #[[$ATTR_3]]}>>
// CHECK: %[[VAL_35:.*]] = arith.muli %[[VAL_34]], %[[VAL_2]] : tensor<32x1xi32, #ttg.slice<{dim = 1, parent = #[[$ATTR_3]]}>>
// CHECK: %[[VAL_36:.*]] = tt.broadcast %[[VAL_35]] : tensor<32x1xi32, #ttg.slice<{dim = 1, parent = #[[$ATTR_3]]}>> -> tensor<32x16xi32, #ttg.slice<{dim = 1, parent = #[[$ATTR_3]]}>>
// CHECK: %[[VAL_37:.*]] = tt.expand_dims %[[VAL_19]] {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 1, parent = #[[$ATTR_3]]}>}>> -> tensor<1x16xi32, #ttg.slice<{dim = 1, parent = #[[$ATTR_3]]}>>
// CHECK: %[[VAL_38:.*]] = tt.broadcast %[[VAL_37]] : tensor<1x16xi32, #ttg.slice<{dim = 1, parent = #[[$ATTR_3]]}>> -> tensor<32x16xi32, #ttg.slice<{dim = 1, parent = #[[$ATTR_3]]}>>
// CHECK: %[[VAL_39:.*]] = arith.addi %[[VAL_38]], %[[VAL_36]] : tensor<32x16xi32, #ttg.slice<{dim = 1, parent = #[[$ATTR_3]]}>>
// CHECK: %[[VAL_40:.*]] = tt.expand_dims %[[VAL_39]] {axis = 1 : i32} : tensor<32x16xi32, #ttg.slice<{dim = 1, parent = #[[$ATTR_3]]}>> -> tensor<32x1x16xi32, #[[$ATTR_3]]>
// CHECK: %[[VAL_41:.*]] = tt.splat %[[VAL_1]] : !tt.ptr<f32> -> tensor<32x1x16x!tt.ptr<f32>, #[[$ATTR_3]]>
// CHECK: %[[VAL_42:.*]] = tt.addptr %[[VAL_41]], %[[VAL_40]] : tensor<32x1x16x!tt.ptr<f32>, #[[$ATTR_3]]>, tensor<32x1x16xi32, #[[$ATTR_3]]>
// CHECK: tt.store %[[VAL_42]], %[[VAL_33]] : tensor<32x1x16x!tt.ptr<f32>, #[[$ATTR_3]]>
// CHECK: tt.return
// CHECK: }
// -----
// Layout alias for this test case. It is pure shorthand for the inline
// attribute and parses to the same IR.
#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
  // Masked 64-element copy from %arg0 to %arg1 using 64-bit offsets. The load
  // is masked by 0 <= offset < %arg2 / 2 and the store by 0 <= offset < %arg2.
  // The assertions that follow this module list the same op sequence as the
  // input, with the splat + addptr form kept.
  tt.func public @block_copy_kernel(%arg0: !tt.ptr<i1> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<i1> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) {
    %zero = arith.constant dense<0> : tensor<64xi64, #blocked>
    %two = arith.constant 2 : i32
    %block = arith.constant 64 : i32
    %pid = tt.get_program_id x : i32
    %block_start = arith.muli %pid, %block : i32
    %half_n = arith.divsi %arg2, %two : i32
    %half_n_i64 = arith.extsi %half_n : i32 to i64
    %src = tt.bitcast %arg0 : !tt.ptr<i1> -> !tt.ptr<i8>
    %block_start_i64 = arith.extsi %block_start : i32 to i64
    %n_i64 = arith.extsi %arg2 : i32 to i64
    %dst = tt.bitcast %arg1 : !tt.ptr<i1> -> !tt.ptr<i8>
    // Per-element i64 offsets: block start plus lane index.
    %src_splat = tt.splat %src : !tt.ptr<i8> -> tensor<64x!tt.ptr<i8>, #blocked>
    %start_splat = tt.splat %block_start_i64 : i64 -> tensor<64xi64, #blocked>
    %lane = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked>
    %lane_i64 = arith.extsi %lane : tensor<64xi32, #blocked> to tensor<64xi64, #blocked>
    %offs = arith.addi %start_splat, %lane_i64 : tensor<64xi64, #blocked>
    %src_ptrs = tt.addptr %src_splat, %offs : tensor<64x!tt.ptr<i8>, #blocked>, tensor<64xi64, #blocked>
    // Load mask: offsets in [0, %arg2 / 2).
    %nonneg = arith.cmpi sge, %offs, %zero : tensor<64xi64, #blocked>
    %half_splat = tt.splat %half_n_i64 : i64 -> tensor<64xi64, #blocked>
    %lt_half = arith.cmpi slt, %offs, %half_splat : tensor<64xi64, #blocked>
    %load_mask = arith.andi %nonneg, %lt_half : tensor<64xi1, #blocked>
    %data = tt.load %src_ptrs, %load_mask : tensor<64x!tt.ptr<i8>, #blocked>
    %dst_splat = tt.splat %dst : !tt.ptr<i8> -> tensor<64x!tt.ptr<i8>, #blocked>
    %dst_ptrs = tt.addptr %dst_splat, %offs : tensor<64x!tt.ptr<i8>, #blocked>, tensor<64xi64, #blocked>
    // Store mask: offsets in [0, %arg2).
    %n_splat = tt.splat %n_i64 : i64 -> tensor<64xi64, #blocked>
    %lt_n = arith.cmpi slt, %offs, %n_splat : tensor<64xi64, #blocked>
    %store_mask = arith.andi %nonneg, %lt_n : tensor<64xi1, #blocked>
    tt.store %dst_ptrs, %data, %store_mask : tensor<64x!tt.ptr<i8>, #blocked>
    tt.return
  }
}
// CHECK: #[[$ATTR_4:.+]] = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
// CHECK-LABEL: tt.func public @block_copy_kernel(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<i1> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %[[VAL_1:.*]]: !tt.ptr<i1> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %[[VAL_2:.*]]: i32 {tt.divisibility = 16 : i32}) {
// CHECK: %[[VAL_3:.*]] = arith.constant dense<0> : tensor<64xi64, #[[$ATTR_4]]>
// CHECK: %[[VAL_4:.*]] = arith.constant 2 : i32
// CHECK: %[[VAL_5:.*]] = arith.constant 64 : i32
// CHECK: %[[VAL_6:.*]] = tt.get_program_id x : i32
// CHECK: %[[VAL_7:.*]] = arith.muli %[[VAL_6]], %[[VAL_5]] : i32
// CHECK: %[[VAL_8:.*]] = arith.divsi %[[VAL_2]], %[[VAL_4]] : i32
// CHECK: %[[VAL_9:.*]] = arith.extsi %[[VAL_8]] : i32 to i64
// CHECK: %[[VAL_10:.*]] = tt.bitcast %[[VAL_0]] : !tt.ptr<i1> -> !tt.ptr<i8>
// CHECK: %[[VAL_11:.*]] = arith.extsi %[[VAL_7]] : i32 to i64
// CHECK: %[[VAL_12:.*]] = arith.extsi %[[VAL_2]] : i32 to i64
// CHECK: %[[VAL_13:.*]] = tt.bitcast %[[VAL_1]] : !tt.ptr<i1> -> !tt.ptr<i8>
// CHECK: %[[VAL_14:.*]] = tt.splat %[[VAL_10]] : !tt.ptr<i8> -> tensor<64x!tt.ptr<i8>, #[[$ATTR_4]]>
// CHECK: %[[VAL_15:.*]] = tt.splat %[[VAL_11]] : i64 -> tensor<64xi64, #[[$ATTR_4]]>
// CHECK: %[[VAL_16:.*]] = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #[[$ATTR_4]]>
// CHECK: %[[VAL_17:.*]] = arith.extsi %[[VAL_16]] : tensor<64xi32, #[[$ATTR_4]]> to tensor<64xi64, #[[$ATTR_4]]>
// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_15]], %[[VAL_17]] : tensor<64xi64, #[[$ATTR_4]]>
// CHECK: %[[VAL_19:.*]] = tt.addptr %[[VAL_14]], %[[VAL_18]] : tensor<64x!tt.ptr<i8>, #[[$ATTR_4]]>, tensor<64xi64, #[[$ATTR_4]]>
// CHECK: %[[VAL_20:.*]] = arith.cmpi sge, %[[VAL_18]], %[[VAL_3]] : tensor<64xi64, #[[$ATTR_4]]>
// CHECK: %[[VAL_21:.*]] = tt.splat %[[VAL_9]] : i64 -> tensor<64xi64, #[[$ATTR_4]]>
// CHECK: %[[VAL_22:.*]] = arith.cmpi slt, %[[VAL_18]], %[[VAL_21]] : tensor<64xi64, #[[$ATTR_4]]>
// CHECK: %[[VAL_23:.*]] = arith.andi %[[VAL_20]], %[[VAL_22]] : tensor<64xi1, #[[$ATTR_4]]>
// CHECK: %[[VAL_24:.*]] = tt.load %[[VAL_19]], %[[VAL_23]] : tensor<64x!tt.ptr<i8>, #[[$ATTR_4]]>
// CHECK: %[[VAL_25:.*]] = tt.splat %[[VAL_13]] : !tt.ptr<i8> -> tensor<64x!tt.ptr<i8>, #[[$ATTR_4]]>
// CHECK: %[[VAL_26:.*]] = tt.addptr %[[VAL_25]], %[[VAL_18]] : tensor<64x!tt.ptr<i8>, #[[$ATTR_4]]>, tensor<64xi64, #[[$ATTR_4]]>
// CHECK: %[[VAL_27:.*]] = tt.splat %[[VAL_12]] : i64 -> tensor<64xi64, #[[$ATTR_4]]>
// CHECK: %[[VAL_28:.*]] = arith.cmpi slt, %[[VAL_18]], %[[VAL_27]] : tensor<64xi64, #[[$ATTR_4]]>
// CHECK: %[[VAL_29:.*]] = arith.andi %[[VAL_20]], %[[VAL_28]] : tensor<64xi1, #[[$ATTR_4]]>
// CHECK: tt.store %[[VAL_26]], %[[VAL_24]], %[[VAL_29]] : tensor<64x!tt.ptr<i8>, #[[$ATTR_4]]>
// CHECK: tt.return
// CHECK: }
// -----
module attributes {} {
  // Masked element-wise kernel: loads 1024 f32 values from %arg0, applies the
  // extern function __ocml_asin_f32 and stores the result to %arg1. The offset
  // is splat(pid * 1024) + make_range; the assertions that follow this module
  // show the scalar part applied to the base pointer before the splat and the
  // make_range part applied afterwards.
  tt.func public @asin_kernel(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: i32) {
    %block = arith.constant 1024 : i32
    %pid = tt.get_program_id x : i32
    %block_start = arith.muli %pid, %block : i32
    %lane = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
    %start_splat = tt.splat %block_start : i32 -> tensor<1024xi32>
    %offs = arith.addi %start_splat, %lane : tensor<1024xi32>
    // Bounds mask: offs < %arg2.
    %n_splat = tt.splat %arg2 : i32 -> tensor<1024xi32>
    %in_bounds = arith.cmpi slt, %offs, %n_splat : tensor<1024xi32>
    %src_splat = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
    %src_ptrs = tt.addptr %src_splat, %offs : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
    %x = tt.load %src_ptrs, %in_bounds : tensor<1024x!tt.ptr<f32>>
    %y = tt.extern_elementwise %x {libname = "", libpath = "", pure = true, symbol = "__ocml_asin_f32"} : (tensor<1024xf32>) -> tensor<1024xf32>
    %dst_splat = tt.splat %arg1 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
    %dst_ptrs = tt.addptr %dst_splat, %offs : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
    tt.store %dst_ptrs, %y, %in_bounds : tensor<1024x!tt.ptr<f32>>
    tt.return
  }
}
// CHECK-LABEL: tt.func public @asin_kernel(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>, %[[VAL_1:.*]]: !tt.ptr<f32>, %[[VAL_2:.*]]: i32) {
// CHECK: %[[VAL_3:.*]] = arith.constant 1024 : i32
// CHECK: %[[VAL_4:.*]] = tt.get_program_id x : i32
// CHECK: %[[VAL_5:.*]] = arith.muli %[[VAL_4]], %[[VAL_3]] : i32
// CHECK: %[[VAL_6:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
// CHECK: %[[VAL_7:.*]] = tt.splat %[[VAL_5]] : i32 -> tensor<1024xi32>
// CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_6]] : tensor<1024xi32>
// CHECK: %[[VAL_9:.*]] = tt.splat %[[VAL_2]] : i32 -> tensor<1024xi32>
// CHECK: %[[VAL_10:.*]] = arith.cmpi slt, %[[VAL_8]], %[[VAL_9]] : tensor<1024xi32>
// CHECK: %[[VAL_11:.*]] = tt.addptr %[[VAL_0]], %[[VAL_5]] : !tt.ptr<f32>, i32
// CHECK: %[[VAL_12:.*]] = tt.splat %[[VAL_11]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
// CHECK: %[[VAL_13:.*]] = tt.addptr %[[VAL_12]], %[[VAL_6]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
// CHECK: %[[VAL_14:.*]] = tt.load %[[VAL_13]], %[[VAL_10]] : tensor<1024x!tt.ptr<f32>>
// CHECK: %[[VAL_15:.*]] = tt.extern_elementwise %[[VAL_14]] {libname = "", libpath = "", pure = true, symbol = "__ocml_asin_f32"} : (tensor<1024xf32>) -> tensor<1024xf32>
// CHECK: %[[VAL_16:.*]] = tt.addptr %[[VAL_1]], %[[VAL_5]] : !tt.ptr<f32>, i32
// CHECK: %[[VAL_17:.*]] = tt.splat %[[VAL_16]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
// CHECK: %[[VAL_18:.*]] = tt.addptr %[[VAL_17]], %[[VAL_6]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
// CHECK: tt.store %[[VAL_18]], %[[VAL_15]], %[[VAL_10]] : tensor<1024x!tt.ptr<f32>>
// CHECK: tt.return
// CHECK: }
// -----
// Test case: pointer arithmetic around tt.elementwise_inline_asm.
// Both tt.addptr ops add a tt.make_range offset to a splat of a function
// argument pointer, with no scalar offset component. The expected output
// listed after this module keeps the same op sequence (splat, tensor addptr,
// load / inline asm / store).
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
tt.func public @inline_asm(%arg0: !tt.ptr<i8>, %arg1: !tt.ptr<i8>) {
// Offsets 0..511, shared by the load address and the store address.
%0 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32>
%1 = tt.splat %arg0 : !tt.ptr<i8> -> tensor<512x!tt.ptr<i8>>
%2 = tt.addptr %1, %0 : tensor<512x!tt.ptr<i8>>, tensor<512xi32>
%3 = tt.load %2 : tensor<512x!tt.ptr<i8>>
// The inline asm consumes the loaded values only; it takes no pointer operand.
%4 = tt.elementwise_inline_asm "shl.b32 $0, $0, 3;" {constraints = "=r,r", packed_element = 4 : i32, pure = true} %3 : tensor<512xi8> -> tensor<512xi8>
%5 = tt.splat %arg1 : !tt.ptr<i8> -> tensor<512x!tt.ptr<i8>>
%6 = tt.addptr %5, %0 : tensor<512x!tt.ptr<i8>>, tensor<512xi32>
tt.store %6, %4 : tensor<512x!tt.ptr<i8>>
tt.return
}
}
// CHECK-LABEL: tt.func public @inline_asm(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<i8>,
// CHECK-SAME: %[[VAL_1:.*]]: !tt.ptr<i8>) {
// CHECK: %[[VAL_2:.*]] = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32>
// CHECK: %[[VAL_3:.*]] = tt.splat %[[VAL_0]] : !tt.ptr<i8> -> tensor<512x!tt.ptr<i8>>
// CHECK: %[[VAL_4:.*]] = tt.addptr %[[VAL_3]], %[[VAL_2]] : tensor<512x!tt.ptr<i8>>, tensor<512xi32>
// CHECK: %[[VAL_5:.*]] = tt.load %[[VAL_4]] : tensor<512x!tt.ptr<i8>>
// CHECK: %[[VAL_6:.*]] = tt.elementwise_inline_asm "shl.b32 $0, $0, 3;" {constraints = "=r,r", packed_element = 4 : i32, pure = true} %[[VAL_5]] : tensor<512xi8> -> tensor<512xi8>
// CHECK: %[[VAL_7:.*]] = tt.splat %[[VAL_1]] : !tt.ptr<i8> -> tensor<512x!tt.ptr<i8>>
// CHECK: %[[VAL_8:.*]] = tt.addptr %[[VAL_7]], %[[VAL_2]] : tensor<512x!tt.ptr<i8>>, tensor<512xi32>
// CHECK: tt.store %[[VAL_8]], %[[VAL_6]] : tensor<512x!tt.ptr<i8>>
// CHECK: tt.return
// CHECK: }
// -----
// Test case: gather-style addressing where the second tt.addptr uses offsets
// that were themselves loaded from memory as i16 values.
// In the expected output listed after this module:
//   * the first address is rebuilt as a scalar addptr of %arg0 by
//     (program_id * 256), splatted, then offset by the make_range tensor;
//   * the loaded i16 offsets are sign-extended to i32 (arith.extsi) before
//     being used by the second tensor addptr.
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
tt.func public @_compute_indx(
%arg0: !tt.ptr<i16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
%arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
%arg2: i32 {tt.divisibility = 16 : i32},
%arg3: i32 {tt.divisibility = 16 : i32}
) -> tensor<256xi32> {
%c256_i32 = arith.constant 256 : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c256_i32 : i32
%2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32>
// Per-element index: program_id * 256 + [0, 256).
%3 = tt.splat %1 : i32 -> tensor<256xi32>
%4 = arith.addi %3, %2 : tensor<256xi32>
// Mask: index < %arg3; reused by both loads below.
%5 = tt.splat %arg3 : i32 -> tensor<256xi32>
%6 = arith.cmpi slt, %4, %5 : tensor<256xi32>
%7 = tt.splat %arg0 : !tt.ptr<i16> -> tensor<256x!tt.ptr<i16>>
%8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<i16>>, tensor<256xi32>
%9 = tt.load %8, %6 : tensor<256x!tt.ptr<i16>>
// Scalar base for the second access: %arg1 advanced by program_id * %arg2.
%10 = arith.muli %0, %arg2 : i32
%11 = tt.addptr %arg1, %10 : !tt.ptr<i32>, i32
%12 = tt.splat %11 : !tt.ptr<i32> -> tensor<256x!tt.ptr<i32>>
// The offset operand here is the i16 tensor loaded above, not an i32 tensor.
%13 = tt.addptr %12, %9 : tensor<256x!tt.ptr<i32>>, tensor<256xi16>
%14 = tt.load %13, %6 : tensor<256x!tt.ptr<i32>>
tt.return %14 : tensor<256xi32>
}
}
// CHECK-LABEL: tt.func public @_compute_indx(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<i16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
// CHECK-SAME: %[[VAL_1:.*]]: !tt.ptr<i32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
// CHECK-SAME: %[[VAL_2:.*]]: i32 {tt.divisibility = 16 : i32},
// CHECK-SAME: %[[VAL_3:.*]]: i32 {tt.divisibility = 16 : i32}) -> tensor<256xi32> {
// CHECK: %[[VAL_4:.*]] = arith.constant 256 : i32
// CHECK: %[[VAL_5:.*]] = tt.get_program_id x : i32
// CHECK: %[[VAL_6:.*]] = arith.muli %[[VAL_5]], %[[VAL_4]] : i32
// CHECK: %[[VAL_7:.*]] = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32>
// CHECK: %[[VAL_8:.*]] = tt.splat %[[VAL_6]] : i32 -> tensor<256xi32>
// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_7]] : tensor<256xi32>
// CHECK: %[[VAL_10:.*]] = tt.splat %[[VAL_3]] : i32 -> tensor<256xi32>
// CHECK: %[[VAL_11:.*]] = arith.cmpi slt, %[[VAL_9]], %[[VAL_10]] : tensor<256xi32>
// CHECK: %[[VAL_12:.*]] = tt.addptr %[[VAL_0]], %[[VAL_6]] : !tt.ptr<i16>, i32
// CHECK: %[[VAL_13:.*]] = tt.splat %[[VAL_12]] : !tt.ptr<i16> -> tensor<256x!tt.ptr<i16>>
// CHECK: %[[VAL_14:.*]] = tt.addptr %[[VAL_13]], %[[VAL_7]] : tensor<256x!tt.ptr<i16>>, tensor<256xi32>
// CHECK: %[[VAL_15:.*]] = tt.load %[[VAL_14]], %[[VAL_11]] : tensor<256x!tt.ptr<i16>>
// CHECK: %[[VAL_16:.*]] = arith.muli %[[VAL_5]], %[[VAL_2]] : i32
// CHECK: %[[VAL_17:.*]] = tt.addptr %[[VAL_1]], %[[VAL_16]] : !tt.ptr<i32>, i32
// CHECK: %[[VAL_18:.*]] = arith.extsi %[[VAL_15]] : tensor<256xi16> to tensor<256xi32>
// CHECK: %[[VAL_19:.*]] = tt.splat %[[VAL_17]] : !tt.ptr<i32> -> tensor<256x!tt.ptr<i32>>
// CHECK: %[[VAL_20:.*]] = tt.addptr %[[VAL_19]], %[[VAL_18]] : tensor<256x!tt.ptr<i32>>, tensor<256xi32>
// CHECK: %[[VAL_21:.*]] = tt.load %[[VAL_20]], %[[VAL_11]] : tensor<256x!tt.ptr<i32>>
// CHECK: tt.return %[[VAL_21]] : tensor<256xi32>
// CHECK: }
// -----
// Test case: a tensor of pointers produced by tt.addptr is narrowed with
// amdgpu.extract_slice (256x256 -> 128x256) before being loaded.
// In the expected output listed after this module the slice is applied to the
// offset tensor instead of the pointer tensor: the i32 offsets are extended to
// i64, sliced, truncated back to i32, and added to a splat of %arg0 that
// already has the sliced 128x256 shape.
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [2, 2], order = [1, 0]}>
module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
tt.func @conversion_extract_slice(%arg0: !tt.ptr<f32>, %arg1: tensor<256x256xi32, #blocked>) -> tensor<128x256xf32, #blocked> {
%3 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<256x256x!tt.ptr<f32>, #blocked>
// The offsets come straight from a function argument.
%4 = tt.addptr %3, %arg1 : tensor<256x256x!tt.ptr<f32>, #blocked>, tensor<256x256xi32, #blocked>
// Slice starting at [0, 0]; the result keeps the same #blocked encoding.
%5 = amdgpu.extract_slice %4 [0, 0] : tensor<256x256x!tt.ptr<f32>, #blocked> to tensor<128x256x!tt.ptr<f32>, #blocked>
%6 = tt.load %5 : tensor<128x256x!tt.ptr<f32>, #blocked>
tt.return %6 : tensor<128x256xf32, #blocked>
}
}
// CHECK-LABEL: tt.func @conversion_extract_slice(
// CHECK-SAME: %[[ARG_0:.*]]: !tt.ptr<f32>, %[[ARG_1:.*]]: tensor<256x256xi32, #blocked>) -> tensor<128x256xf32, #blocked> {
// CHECK: %[[VAR_0:.*]] = arith.extsi %[[ARG_1]] : tensor<256x256xi32, #blocked> to tensor<256x256xi64, #blocked>
// CHECK: %[[VAR_1:.*]] = amdgpu.extract_slice %[[VAR_0]] [0, 0] : tensor<256x256xi64, #blocked> to tensor<128x256xi64, #blocked>
// CHECK: %[[VAR_2:.*]] = arith.trunci %[[VAR_1]] : tensor<128x256xi64, #blocked> to tensor<128x256xi32, #blocked>
// CHECK: %[[VAR_3:.*]] = tt.splat %[[ARG_0]] : !tt.ptr<f32> -> tensor<128x256x!tt.ptr<f32>, #blocked>
// CHECK: %[[VAR_4:.*]] = tt.addptr %[[VAR_3]], %[[VAR_2]] : tensor<128x256x!tt.ptr<f32>, #blocked>, tensor<128x256xi32, #blocked>
// CHECK: %[[VAR_5:.*]] = tt.load %[[VAR_4]] : tensor<128x256x!tt.ptr<f32>, #blocked>
// CHECK: tt.return %[[VAR_5]] : tensor<128x256xf32, #blocked>
// CHECK: }
// -----
// Test case: a ub.poison tensor of pointers is yielded from the else branch of
// an scf.if whose then branch yields a tt.addptr result.
// This case has no output patterns; it is verified only through the diagnostic
// annotation inside the function, which requires a remark saying that
// canonicalize-pointers is skipped because of ub.poison.
// NOTE: that annotation uses a +1 line offset, so it must remain on the line
// immediately before the ub.poison op. Do not insert lines between them.
module attributes {"ttg.num-warps" = 4 : i32} {
tt.func @ifOpPoison(%arg0: !tt.ptr<f32>, %arg1: tensor<1024xf32>, %arg2: i1) -> tensor<1024xf32> {
%c1024_i32 = arith.constant 1024 : i32
// expected-remark@+1 {{skipping canonicalize-pointers due to ub.poison}}
%poison = ub.poison : tensor<1024x!tt.ptr<f32>>
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c1024_i32 : i32
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
%3 = tt.splat %1 : i32 -> tensor<1024xi32>
%4 = arith.addi %3, %2 : tensor<1024xi32>
%5 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
// Then branch: a regular pointer computation. Else branch: the poison value.
%6 = scf.if %arg2 -> (tensor<1024x!tt.ptr<f32>>) {
%8 = tt.addptr %5, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
scf.yield %8 : tensor<1024x!tt.ptr<f32>>
} else {
scf.yield %poison : tensor<1024x!tt.ptr<f32>>
}
%7 = tt.load %6 : tensor<1024x!tt.ptr<f32>>
tt.return %7 : tensor<1024xf32>
}
}
// -----
// Test case: attributes attached to a tt.addptr that gets rewritten.
// The input addptr carries both tt.divisibility and an unrelated misc.misc
// attribute. In the expected output listed after this module, the rewritten
// scalar addptr carries tt.divisibility = 16 and no misc.misc attribute.
module attributes {"ttg.num-warps" = 4 : i32} {
tt.func @propagate_divisibility(%arg0: !tt.ptr<f32>) -> tensor<1024xf32> {
%c1024_i32 = arith.constant 1024 : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c1024_i32 : i32
// The offset is a splat of a scalar, so every element has the same offset.
%2 = tt.splat %1 : i32 -> tensor<1024xi32>
%3 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
// Two attributes on purpose: one that the output retains and one it does not.
%4 = tt.addptr %3, %2 {tt.divisibility = 16 : i32, misc.misc = 3 : i32} : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
%5 = tt.load %4 : tensor<1024x!tt.ptr<f32>>
tt.return %5 : tensor<1024xf32>
}
}
// CHECK-LABEL: tt.func @propagate_divisibility(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>) -> tensor<1024xf32> {
// CHECK: %[[VAL_1:.*]] = arith.constant 1024 : i32
// CHECK: %[[VAL_2:.*]] = tt.get_program_id x : i32
// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_2]], %[[VAL_1]] : i32
// CHECK: %[[VAL_4:.*]] = tt.addptr %[[VAL_0]], %[[VAL_3]] {tt.divisibility = 16 : i32} : !tt.ptr<f32>, i32
// CHECK: %[[VAL_5:.*]] = tt.splat %[[VAL_4]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
// CHECK: %[[VAL_6:.*]] = tt.load %[[VAL_5]] : tensor<1024x!tt.ptr<f32>>
// CHECK: tt.return %[[VAL_6]] : tensor<1024xf32>
// CHECK: }
// -----
// Test case: a per-dimension (dense, two-element) tt.divisibility attribute on
// a 2-D tensor tt.addptr whose offset is a splat of a scalar.
// In the expected output listed after this module, the addptr is rewritten to
// a scalar addptr followed by a splat, and that scalar addptr carries no
// tt.divisibility attribute.
module attributes {"ttg.num-warps" = 4 : i32} {
tt.func @divisiblity_changeing_dims(%arg0: !tt.ptr<f32>) -> tensor<1024x32xf32> {
%c1024_i32 = arith.constant 1024 : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c1024_i32 : i32
%2 = tt.splat %1 : i32 -> tensor<1024x32xi32>
%3 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x32x!tt.ptr<f32>>
// The attribute has one entry per tensor dimension (rank 2).
%4 = tt.addptr %3, %2 {tt.divisibility = dense<[1, 16]> : tensor<2xi32>} : tensor<1024x32x!tt.ptr<f32>>, tensor<1024x32xi32>
%5 = tt.load %4 : tensor<1024x32x!tt.ptr<f32>>
tt.return %5 : tensor<1024x32xf32>
}
}
// CHECK-LABEL: tt.func @divisiblity_changeing_dims(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>) -> tensor<1024x32xf32> {
// CHECK: %[[VAL_1:.*]] = arith.constant 1024 : i32
// CHECK: %[[VAL_2:.*]] = tt.get_program_id x : i32
// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_2]], %[[VAL_1]] : i32
// CHECK: %[[VAL_4:.*]] = tt.addptr %[[VAL_0]], %[[VAL_3]] : !tt.ptr<f32>, i32
// CHECK: %[[VAL_5:.*]] = tt.splat %[[VAL_4]] : !tt.ptr<f32> -> tensor<1024x32x!tt.ptr<f32>>
// CHECK: %[[VAL_6:.*]] = tt.load %[[VAL_5]] : tensor<1024x32x!tt.ptr<f32>>
// CHECK: tt.return %[[VAL_6]] : tensor<1024x32xf32>
// CHECK: }