// RUN: triton-opt %s -split-input-file -triton-reorder-broadcast | FileCheck %s
// CHECK-LABEL: @test_splat_elementwise_pattern
tt.func @test_splat_elementwise_pattern(%arg0: f32) -> (tensor<128x128xf32>, tensor<128x128x!tt.ptr<f32>>) {
// CHECK-DAG: %[[a:.*]] = arith.constant 1.000000e+00 : f32
// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : i64
%c1 = arith.constant 1 : i64
%a = arith.constant dense<1.0> : tensor<128x128xf32>
// CHECK-DAG: %[[add:.*]] = arith.addf %arg0, %[[a]] : f32
// CHECK-NEXT: %[[splat:.*]] = tt.splat %[[add]] : f32 -> tensor<128x128xf32>
%b = tt.splat %arg0 : f32 -> tensor<128x128xf32>
%add = arith.addf %a, %b : tensor<128x128xf32>
// CHECK-NEXT: %[[ptr:.*]] = tt.int_to_ptr %[[c1]] : i64 -> !tt.ptr<f32>
// CHECK-NEXT: %{{.*}} = tt.splat %[[ptr]] : !tt.ptr<f32> -> tensor<128x128x!tt.ptr<f32>>
%c1_t = tt.splat %c1 : i64 -> tensor<128x128xi64>
%ptr = tt.int_to_ptr %c1_t : tensor<128x128xi64> -> tensor<128x128x!tt.ptr<f32>>
tt.return %add, %ptr : tensor<128x128xf32>, tensor<128x128x!tt.ptr<f32>>
}
// CHECK-LABEL: @test_broadcast_elementwise_pattern
tt.func @test_broadcast_elementwise_pattern(%arg0: tensor<128x1xf32>) -> (tensor<128x128xf32>, tensor<128x32xf32>) {
// CHECK: %[[one:.*]] = arith.constant dense<1.000000e+00> : tensor<128x1xf32>
// CHECK-NEXT: %[[abs:.*]] = math.absf %arg0 : tensor<128x1xf32>
// CHECK-NEXT: %{{.*}} = tt.broadcast %[[abs]] : tensor<128x1xf32> -> tensor<128x128xf32>
%broadcast = tt.broadcast %arg0 : tensor<128x1xf32> -> tensor<128x128xf32>
%abs = math.absf %broadcast : tensor<128x128xf32>
// CHECK-NEXT: %[[add:.*]] = arith.addf %arg0, %[[one]] : tensor<128x1xf32>
// CHECK-NEXT: %{{.*}} = tt.broadcast %[[add]] : tensor<128x1xf32> -> tensor<128x32xf32>
%broadcast2 = tt.broadcast %arg0 : tensor<128x1xf32> -> tensor<128x32xf32>
%one = arith.constant dense<1.0> : tensor<128x32xf32>
%add = arith.addf %one, %broadcast2 : tensor<128x32xf32>
tt.return %abs, %add : tensor<128x128xf32>, tensor<128x32xf32>
}
// CHECK-LABEL: @test_broadcast_binary_op_pattern
tt.func @test_broadcast_binary_op_pattern(%arg0: tensor<128x1xf32>, %arg1: tensor<128x1xf32>, %arg2: tensor<1x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
// CHECK: %[[mul:.*]] = arith.mulf %{{.*}}, %{{.*}} : tensor<128x1xf32>
// CHECK-NEXT: %{{.*}} = tt.broadcast %[[mul]] : tensor<128x1xf32> -> tensor<128x128xf32>
%broadcast0 = tt.broadcast %arg0 : tensor<128x1xf32> -> tensor<128x128xf32>
%broadcast1 = tt.broadcast %arg1 : tensor<128x1xf32> -> tensor<128x128xf32>
%mul = arith.mulf %broadcast0, %broadcast1 : tensor<128x128xf32>
// CHECK: %[[mul:.*]] = arith.mulf %{{.*}}, %{{.*}} : tensor<128x128xf32>
%broadcast2 = tt.broadcast %arg2 : tensor<1x128xf32> -> tensor<128x128xf32>
%mul1 = arith.mulf %broadcast0, %broadcast2 : tensor<128x128xf32>
tt.return %mul, %mul1 : tensor<128x128xf32>, tensor<128x128xf32>
}
// CHECK-LABEL: @test_broadcast_mix_type_op_pattern
tt.func @test_broadcast_mix_type_op_pattern(%arg0: tensor<128x1xf32>, %arg1: f32, %arg2: tensor<1x128xf32>, %arg3: tensor<128x1xi1>) -> (tensor<128x128xf32>) {
// CHECK: %[[sel:.*]] = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<128x1xi1>, tensor<128x1xf32>
// CHECK-NEXT: %{{.*}} = tt.broadcast %[[sel]] : tensor<128x1xf32> -> tensor<128x128xf32>
%broadcast0 = tt.broadcast %arg0 : tensor<128x1xf32> -> tensor<128x128xf32>
%broadcast1 = tt.splat %arg1 : f32 -> tensor<128x128xf32>
%cond = tt.broadcast %arg3 : tensor<128x1xi1> -> tensor<128x128xi1>
%sel = arith.select %cond, %broadcast0, %broadcast1 : tensor<128x128xi1>, tensor<128x128xf32>
tt.return %sel : tensor<128x128xf32>
}