// RUN: triton-opt %s -split-input-file -canonicalize -triton-combine | FileCheck %s
// We don't combine if the dot result is used by more than one op.
// CHECK-LABEL: @test_combine_dot_add_invalid_pattern
// Here %dot_out feeds two separate arith.addf ops (one with %d, one with %e),
// so neither addend may be folded into the dot accumulator. Both adds are
// expected to remain, on consecutive lines, in the output.
tt.func @test_combine_dot_add_invalid_pattern() -> (tensor<128x128xf32>, tensor<128x128xf32>) {
// CHECK-DAG: %[[d:.*]] = arith.constant dense<3.000000e+00> : tensor<128x128xf32>
// CHECK-DAG: %[[e:.*]] = arith.constant dense<4.000000e+00> : tensor<128x128xf32>
%a = arith.constant dense<1.0> : tensor<128x128xf32>
%b = arith.constant dense<2.0> : tensor<128x128xf32>
%zero = arith.constant dense<0.0> : tensor<128x128xf32>
%d = arith.constant dense<3.0> : tensor<128x128xf32>
%e = arith.constant dense<4.0> : tensor<128x128xf32>
%dot_out = tt.dot %a, %b, %zero : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
// CHECK: arith.addf %{{.*}}, %[[d]] : tensor<128x128xf32>
%res0 = arith.addf %dot_out, %d : tensor<128x128xf32>
// CHECK-NEXT: arith.addf %{{.*}}, %[[e]] : tensor<128x128xf32>
%res1 = arith.addf %dot_out, %e : tensor<128x128xf32>
tt.return %res0, %res1 : tensor<128x128xf32>, tensor<128x128xf32>
}
// CHECK-LABEL: @test_combine_dot_add_pattern
// A single-use tt.dot whose accumulator is the all-zero constant, followed by
// arith.addf(%dot_out, %d), is expected to become one tt.dot that takes %d as
// its accumulator. The add disappears and the dot result is returned directly.
tt.func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>) {
// CHECK-DAG: %[[d:.*]] = arith.constant dense<3.000000e+00> : tensor<128x128xf32>
// CHECK-DAG: %[[b:.*]] = arith.constant dense<2.000000e+00> : tensor<128x128xf32>
// CHECK-DAG: %[[a:.*]] = arith.constant dense<1.000000e+00> : tensor<128x128xf32>
%a = arith.constant dense<1.0> : tensor<128x128xf32>
%b = arith.constant dense<2.0> : tensor<128x128xf32>
%zero = arith.constant dense<0.0> : tensor<128x128xf32>
%d = arith.constant dense<3.0> : tensor<128x128xf32>
%dot_out = tt.dot %a, %b, %zero : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
// CHECK-NEXT: %[[res:.*]] = tt.dot %[[a]], %[[b]], %[[d]] : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
// CHECK-NEXT: tt.return %[[res]] : tensor<128x128xf32>
%res = arith.addf %dot_out, %d : tensor<128x128xf32>
tt.return %res : tensor<128x128xf32>
}
// CHECK-LABEL: @test_combine_dot_add_rev_pattern
// Same as test_combine_dot_add_pattern, but the addend is the FIRST operand
// of the add: arith.addf(%d, %dot_out). The expected result is again a single
// tt.dot with %d as its accumulator.
tt.func @test_combine_dot_add_rev_pattern() -> (tensor<128x128xf32>) {
// CHECK-DAG: %[[d:.*]] = arith.constant dense<3.000000e+00> : tensor<128x128xf32>
// CHECK-DAG: %[[b:.*]] = arith.constant dense<2.000000e+00> : tensor<128x128xf32>
// CHECK-DAG: %[[a:.*]] = arith.constant dense<1.000000e+00> : tensor<128x128xf32>
%a = arith.constant dense<1.0> : tensor<128x128xf32>
%b = arith.constant dense<2.0> : tensor<128x128xf32>
%zero = arith.constant dense<0.0> : tensor<128x128xf32>
%d = arith.constant dense<3.0> : tensor<128x128xf32>
%dot_out = tt.dot %a, %b, %zero : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
// CHECK-NEXT: %[[res:.*]] = tt.dot %[[a]], %[[b]], %[[d]] : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32>
// CHECK-NEXT: tt.return %[[res]] : tensor<128x128xf32>
%res = arith.addf %d, %dot_out : tensor<128x128xf32>
tt.return %res : tensor<128x128xf32>
}
// CHECK-LABEL: @test_combine_addptr_pattern
// Two chained tt.addptr ops with splatted constant i32 offsets (10 and 15) are
// expected to collapse into a single tt.addptr whose offset is the folded
// constant 25. That combined pointer is the value the function returns.
tt.func @test_combine_addptr_pattern(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>> {
%off0 = arith.constant 10 : i32
%off1 = arith.constant 15 : i32
// CHECK-NEXT: %[[cst:.*]] = arith.constant dense<25> : tensor<8xi32>
%base_ = tt.splat %base : !tt.ptr<f32> -> tensor<8x!tt.ptr<f32>>
// CHECK-NEXT: %[[tmp0:.*]] = tt.splat %{{.*}} : !tt.ptr<f32> -> tensor<8x!tt.ptr<f32>>
%idx0 = tt.splat %off0 : i32 -> tensor<8xi32>
%idx1 = tt.splat %off1 : i32 -> tensor<8xi32>
// Capture the combined pointer rather than hard-coding its SSA number.
// CHECK-NEXT: %[[res:.*]] = tt.addptr %[[tmp0]], %[[cst]] : tensor<8x!tt.ptr<f32>>, tensor<8xi32>
%ptr0 = tt.addptr %base_, %idx0 : tensor<8x!tt.ptr<f32>>, tensor<8xi32>
%ptr1 = tt.addptr %ptr0, %idx1 : tensor<8x!tt.ptr<f32>>, tensor<8xi32>
// CHECK-NEXT: tt.return %[[res]] : tensor<8x!tt.ptr<f32>>
tt.return %ptr1 : tensor<8x!tt.ptr<f32>>
}
// CHECK-LABEL: @test_combine_addptr_pattern_i64
// i64 variant of test_combine_addptr_pattern. The first offset is a splatted
// scalar constant and the second is already a dense tensor constant. The two
// tt.addptr ops are expected to merge into one with offset dense<25>.
tt.func @test_combine_addptr_pattern_i64(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>> {
%off0 = arith.constant 10 : i64
%off1 = arith.constant dense<15> : tensor<8xi64>
// CHECK-NEXT: %[[cst:.*]] = arith.constant dense<25> : tensor<8xi64>
%base_ = tt.splat %base : !tt.ptr<f32> -> tensor<8x!tt.ptr<f32>>
// CHECK-NEXT: %[[tmp0:.*]] = tt.splat %{{.*}} : !tt.ptr<f32> -> tensor<8x!tt.ptr<f32>>
%idx0 = tt.splat %off0 : i64 -> tensor<8xi64>
// Capture the combined pointer rather than hard-coding its SSA number.
// CHECK-NEXT: %[[res:.*]] = tt.addptr %[[tmp0]], %[[cst]] : tensor<8x!tt.ptr<f32>>, tensor<8xi64>
%ptr0 = tt.addptr %base_, %idx0 : tensor<8x!tt.ptr<f32>>, tensor<8xi64>
%ptr1 = tt.addptr %ptr0, %off1 : tensor<8x!tt.ptr<f32>>, tensor<8xi64>
// CHECK-NEXT: tt.return %[[res]] : tensor<8x!tt.ptr<f32>>
tt.return %ptr1 : tensor<8x!tt.ptr<f32>>
}
// CHECK-LABEL: @test_combine_addptr_pattern_scalar
// Scalar variant: two chained tt.addptr ops on a plain !tt.ptr<f32> with
// constant i32 offsets 10 and 15 are expected to become a single tt.addptr
// by 25. That single result is the value returned.
tt.func @test_combine_addptr_pattern_scalar(%base: !tt.ptr<f32>) -> !tt.ptr<f32> {
%off0 = arith.constant 10 : i32
%off1 = arith.constant 15 : i32
// CHECK-NEXT: %[[cst:.*]] = arith.constant 25 : i32
// Capture the combined pointer rather than hard-coding its SSA number.
// CHECK-NEXT: %[[res:.*]] = tt.addptr %{{.*}}, %[[cst]] : !tt.ptr<f32>, i32
%ptr0 = tt.addptr %base, %off0 : !tt.ptr<f32>, i32
%ptr1 = tt.addptr %ptr0, %off1 : !tt.ptr<f32>, i32
// CHECK-NEXT: tt.return %[[res]] : !tt.ptr<f32>
tt.return %ptr1 : !tt.ptr<f32>
}
// CHECK-LABEL: @test_not_combine_addptr_pattern_1
// The first offset %idx0 is a function argument rather than a constant, so the
// two offsets cannot be folded together. Both tt.addptr ops are expected to
// remain, on consecutive lines.
tt.func @test_not_combine_addptr_pattern_1(%base: !tt.ptr<f32>, %idx0: tensor<8xi32>) -> tensor<8x!tt.ptr<f32>> {
%off1 = arith.constant 15 : i32
%base_ = tt.splat %base : !tt.ptr<f32> -> tensor<8x!tt.ptr<f32>>
%idx1 = tt.splat %off1 : i32 -> tensor<8xi32>
// CHECK: tt.addptr
// CHECK-NEXT: tt.addptr
%ptr0 = tt.addptr %base_, %idx0 : tensor<8x!tt.ptr<f32>>, tensor<8xi32>
%ptr1 = tt.addptr %ptr0, %idx1 : tensor<8x!tt.ptr<f32>>, tensor<8xi32>
tt.return %ptr1 : tensor<8x!tt.ptr<f32>>
}
// CHECK-LABEL: @test_not_combine_addptr_pattern
// The offsets have different integer types (i16, then i32) and must not be
// combined. Both splatted constants and both tt.addptr ops are expected to
// remain, with each addptr still using its own offset.
tt.func @test_not_combine_addptr_pattern(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>> {
%off0 = arith.constant 10 : i16
%off1 = arith.constant 15 : i32
// CHECK-DAG: %[[cst:.*]] = arith.constant dense<10> : tensor<8xi16>
// CHECK-DAG: %[[cst1:.*]] = arith.constant dense<15> : tensor<8xi32>
%base_ = tt.splat %base : !tt.ptr<f32> -> tensor<8x!tt.ptr<f32>>
%idx0 = tt.splat %off0 : i16 -> tensor<8xi16>
%idx1 = tt.splat %off1 : i32 -> tensor<8xi32>
// CHECK: %[[addptr0:.*]] = tt.addptr %{{.*}}, %[[cst]] : tensor<8x!tt.ptr<f32>>, tensor<8xi16>
// CHECK-NEXT: tt.addptr %[[addptr0]], %[[cst1]] : tensor<8x!tt.ptr<f32>>, tensor<8xi32>
%ptr0 = tt.addptr %base_, %idx0 : tensor<8x!tt.ptr<f32>>, tensor<8xi16>
%ptr1 = tt.addptr %ptr0, %idx1 : tensor<8x!tt.ptr<f32>>, tensor<8xi32>
tt.return %ptr1 : tensor<8x!tt.ptr<f32>>
}
// CHECK-LABEL: @test_not_combine_addptr_pattern_overflow
// 127 + 1 does not fit in a signed i8, so folding the two i8 offsets would
// overflow. Both offsets and both tt.addptr ops are expected to remain, with
// each addptr still using its own offset.
tt.func @test_not_combine_addptr_pattern_overflow(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>> {
%off0 = arith.constant 127 : i8
%off1 = arith.constant 1 : i8
// CHECK-DAG: %[[cst:.*]] = arith.constant dense<127> : tensor<8xi8>
// CHECK-DAG: %[[cst1:.*]] = arith.constant dense<1> : tensor<8xi8>
%base_ = tt.splat %base : !tt.ptr<f32> -> tensor<8x!tt.ptr<f32>>
%idx0 = tt.splat %off0 : i8 -> tensor<8xi8>
%idx1 = tt.splat %off1 : i8 -> tensor<8xi8>
// CHECK: %[[addptr0:.*]] = tt.addptr %{{.*}}, %[[cst]] : tensor<8x!tt.ptr<f32>>, tensor<8xi8>
// CHECK-NEXT: tt.addptr %[[addptr0]], %[[cst1]] : tensor<8x!tt.ptr<f32>>, tensor<8xi8>
%ptr0 = tt.addptr %base_, %idx0 : tensor<8x!tt.ptr<f32>>, tensor<8xi8>
%ptr1 = tt.addptr %ptr0, %idx1 : tensor<8x!tt.ptr<f32>>, tensor<8xi8>
tt.return %ptr1 : tensor<8x!tt.ptr<f32>>
}
// CHECK-LABEL: @test_combine_select_masked_load_pattern
// Each tt.load is masked by tt.splat(%cond), and its "other" operand is
// %false_val. The load is then wrapped in arith.select(%cond, load, %false_val),
// using that same %false_val. Each select is expected to be removed, so the
// function returns the two load results directly.
tt.func @test_combine_select_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %cond: i1) -> (tensor<8xf32>, tensor<8xf32>) {
%mask = tt.splat %cond : i1 -> tensor<8xi1>
%false_val = arith.constant dense<0.0> : tensor<8xf32>
// CHECK: %[[res1:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} : tensor<8x!tt.ptr<f32>>
%x = tt.load %ptr, %mask, %false_val : tensor<8x!tt.ptr<f32>>
%0 = arith.select %cond, %x, %false_val : tensor<8xf32>
// CHECK: %[[res2:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} : tensor<8x!tt.ptr<f32>>
%y = tt.load %ptr, %mask, %false_val : tensor<8x!tt.ptr<f32>>
%1 = arith.select %cond, %y, %false_val : tensor<8xf32>
// CHECK: tt.return %[[res1]], %[[res2]] : tensor<8xf32>, tensor<8xf32>
tt.return %0, %1 : tensor<8xf32>, tensor<8xf32>
}
// CHECK-LABEL: @test_combine_select_masked_load_fail_pattern
// Negative cases for the select(masked load) combine: in each case one
// precondition is missing, and the arith.select is expected to survive.
tt.func @test_combine_select_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %dummy_load: tensor<8xf32>, %dummy_broadcast: tensor<8xi1>, %cond0: i1, %cond1: i1) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) {
%false_val = arith.constant dense<0.0> : tensor<8xf32>
// Case 1: the value at the "load" position (%dummy_load) is a block argument, not an op result. Select should not be canonicalized.
// CHECK: %{{.*}} = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
%0 = arith.select %cond0, %dummy_load, %false_val : tensor<8xf32>
// Case 2: the load's mask (%dummy_broadcast) is a block argument, not the result of a splat/broadcast op. Select should not be canonicalized.
%real_load0 = tt.load %ptr, %dummy_broadcast, %false_val : tensor<8x!tt.ptr<f32>>
// CHECK: %{{.*}} = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
%1 = arith.select %cond0, %real_load0, %false_val : tensor<8xf32>
// Case 3: the load's mask is a splat of %cond0, but the select is conditioned on %cond1. Select should not be canonicalized.
%cond0_ = tt.splat %cond0 : i1 -> tensor<8xi1>
%real_load1 = tt.load %ptr, %cond0_, %false_val : tensor<8x!tt.ptr<f32>>
// CHECK: %{{.*}} = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
%2 = arith.select %cond1, %real_load1, %false_val : tensor<8xf32>
tt.return %0, %1, %2 : tensor<8xf32>, tensor<8xf32>, tensor<8xf32>
}
// CHECK-LABEL: @test_combine_broadcast_constant_pattern
// A tt.broadcast of a dense constant is expected to fold into a dense constant
// of the broadcast shape, which is then returned directly. The %cst parameter
// is not used by the body.
tt.func @test_combine_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> {
// CHECK: %[[cst:.*]] = arith.constant dense<1.000000e+00> : tensor<8x2xf32>
%const = arith.constant dense<1.0> : tensor<8x1xf32>
%bst_out = tt.broadcast %const : tensor<8x1xf32> -> tensor<8x2xf32>
// CHECK-NEXT: tt.return %[[cst]] : tensor<8x2xf32>
tt.return %bst_out : tensor<8x2xf32>
}
// CHECK-LABEL: @test_canonicalize_masked_load_pattern
// Loads masked by constant masks:
// - an all-true mask is expected to be dropped;
// - an all-false mask is expected to replace the load with its "other" value.
tt.func @test_canonicalize_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) {
%true_mask = arith.constant dense<true> : tensor<8xi1>
%false_mask = arith.constant dense<false> : tensor<8xi1>
%other_val = arith.constant dense<0.0> : tensor<8xf32>
// true_mask without other
// CHECK: %[[res1:.*]] = tt.load %{{.*}} : tensor<8x!tt.ptr<f32>>
%x = tt.load %ptr, %true_mask : tensor<8x!tt.ptr<f32>>
// true_mask with other
// CHECK: %[[res2:.*]] = tt.load %{{.*}} : tensor<8x!tt.ptr<f32>>
%y = tt.load %ptr, %true_mask, %other_val : tensor<8x!tt.ptr<f32>>
// false_mask with other. It should become "other" (i.e., %y)
%z = tt.load %ptr, %false_mask, %y : tensor<8x!tt.ptr<f32>>
// CHECK: tt.return %[[res1]], %[[res2]], %[[res2]] : tensor<8xf32>, tensor<8xf32>, tensor<8xf32>
tt.return %x, %y, %z: tensor<8xf32>, tensor<8xf32>, tensor<8xf32>
}
// CHECK-LABEL: @test_canonicalize_masked_load_fail_pattern
// The mask is a function argument, so its value is not known at compile time.
// Both loads are expected to keep their mask (and "other") operands, and both
// load results are returned unchanged.
tt.func @test_canonicalize_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %mask: tensor<8xi1>) -> (tensor<8xf32>, tensor<8xf32>) {
%other_val = arith.constant dense<0.0> : tensor<8xf32>
// Case: value at the "mask" position is not an "op". Load should not be canonicalized.
// CHECK: %[[res1:.*]] = tt.load %{{.*}}, %{{.*}} : tensor<8x!tt.ptr<f32>>
%x = tt.load %ptr, %mask : tensor<8x!tt.ptr<f32>>
// CHECK: %[[res2:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} : tensor<8x!tt.ptr<f32>>
%y = tt.load %ptr, %mask, %other_val : tensor<8x!tt.ptr<f32>>
// CHECK: tt.return %[[res1]], %[[res2]] : tensor<8xf32>, tensor<8xf32>
tt.return %x, %y: tensor<8xf32>, tensor<8xf32>
}
// CHECK-LABEL: @test_canonicalize_masked_store_pattern
// A store with an all-true mask is expected to lose its mask operand. A store
// with an all-false mask is expected to be erased, so the return immediately
// follows the first store.
tt.func @test_canonicalize_masked_store_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val: tensor<8xf32>) {
%true_mask = arith.constant dense<true> : tensor<8xi1>
%false_mask = arith.constant dense<false> : tensor<8xi1>
// CHECK: tt.store %{{.*}}, %{{.*}} : tensor<8x!tt.ptr<f32>>
tt.store %ptr, %val, %true_mask : tensor<8x!tt.ptr<f32>>
// The following store should disappear.
// CHECK-NEXT: tt.return
tt.store %ptr, %val, %false_mask : tensor<8x!tt.ptr<f32>>
tt.return
}
// CHECK-LABEL: @test_canonicalize_masked_store_fail_pattern
// The mask is a function argument, so its value is not known at compile time.
// The store is expected to keep all three operands.
tt.func @test_canonicalize_masked_store_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val: tensor<8xf32>, %mask: tensor<8xi1>) {
// Case: value at the "mask" position is not an "op". Store should not be canonicalized.
// CHECK: tt.store %{{.*}}, %{{.*}}, %{{.*}} : tensor<8x!tt.ptr<f32>>
tt.store %ptr, %val, %mask : tensor<8x!tt.ptr<f32>>
tt.return
}
// CHECK-LABEL: @test_canonicalize_expand_dims
// expand_dims(splat(x)) is expected to become one splat straight to the
// expanded shape. expand_dims(broadcast(y)) is expected to be reordered into
// broadcast(expand_dims(y)), and the following tt.broadcast is folded into
// that broadcast, giving the 8x8 result directly.
tt.func @test_canonicalize_expand_dims(%arg0: tensor<f32>, %arg1: tensor<1xf32>) -> (tensor<1x8xf32>, tensor<8x8xf32>) {
%splat = tt.splat %arg0 : tensor<f32> -> tensor<8xf32>
// CHECK: %{{.*}} = tt.splat %arg0 : tensor<f32> -> tensor<1x8xf32>
%ed = tt.expand_dims %splat {axis = 0 : i32} : tensor<8xf32> -> tensor<1x8xf32>
// CHECK-NEXT: %[[ed2:.*]] = tt.expand_dims %arg1 {axis = 0 : i32} : tensor<1xf32> -> tensor<1x1xf32>
// CHECK-NEXT: %{{.*}} = tt.broadcast %[[ed2]] : tensor<1x1xf32> -> tensor<8x8xf32>
%bc = tt.broadcast %arg1 : tensor<1xf32> -> tensor<8xf32>
%ed2 = tt.expand_dims %bc {axis = 0 : i32} : tensor<8xf32> -> tensor<1x8xf32>
%bc2 = tt.broadcast %ed2 : tensor<1x8xf32> -> tensor<8x8xf32>
tt.return %ed, %bc2 : tensor<1x8xf32>, tensor<8x8xf32>
}
// CHECK-LABEL: @test_canonicalize_view
// Expected tt.reshape folds:
// - reshape(reshape(x)) becomes a single reshape from x;
// - reshape(splat(s)) becomes a splat of the new shape;
// - a reshape to the identical type is removed, so its user reads %arg0 directly.
tt.func @test_canonicalize_view(%arg0: tensor<8xf32>, %arg1: tensor<f32>) -> (tensor<4x2xf32>, tensor<2x2x2xf32>, tensor<8xf32>) {
%view0 = tt.reshape %arg0 allow_reorder : tensor<8xf32> -> tensor<2x4xf32>
// CHECK: %{{.*}} = tt.reshape %arg0 allow_reorder : tensor<8xf32> -> tensor<4x2xf32>
%view1 = tt.reshape %view0 allow_reorder : tensor<2x4xf32> -> tensor<4x2xf32>
%splat = tt.splat %arg1 : tensor<f32> -> tensor<8xf32>
// CHECK: %{{.*}} = tt.splat %arg1 : tensor<f32> -> tensor<2x2x2xf32>
%view2 = tt.reshape %splat allow_reorder : tensor<8xf32> -> tensor<2x2x2xf32>
%view3 = tt.reshape %arg0 allow_reorder : tensor<8xf32> -> tensor<8xf32>
// CHECK: %{{.*}} = arith.addf %arg0, %arg0 : tensor<8xf32>
%add = arith.addf %view3, %arg0 : tensor<8xf32>
tt.return %view1, %view2, %add : tensor<4x2xf32>, tensor<2x2x2xf32>, tensor<8xf32>
}
// CHECK-LABEL: @test_canonicalize_broadcast
// Expected tt.broadcast folds:
// - broadcast(broadcast(x)) becomes a single broadcast from x;
// - broadcast(splat(s)) becomes a splat of the broadcast shape;
// - a broadcast to the identical type is removed, so its user reads %arg0 directly.
tt.func @test_canonicalize_broadcast(%arg0: tensor<1x1x8xf32>, %arg1: tensor<f32>) -> (tensor<4x2x8xf32>, tensor<8x8xf32>, tensor<1x1x8xf32>) {
%broadcast0 = tt.broadcast %arg0 : tensor<1x1x8xf32> -> tensor<1x2x8xf32>
// CHECK: %{{.*}} = tt.broadcast %arg0 : tensor<1x1x8xf32> -> tensor<4x2x8xf32>
%broadcast1 = tt.broadcast %broadcast0 : tensor<1x2x8xf32> -> tensor<4x2x8xf32>
%splat = tt.splat %arg1 : tensor<f32> -> tensor<1x8xf32>
// CHECK: %{{.*}} = tt.splat %arg1 : tensor<f32> -> tensor<8x8xf32>
%broadcast2 = tt.broadcast %splat : tensor<1x8xf32> -> tensor<8x8xf32>
%broadcast3 = tt.broadcast %arg0 : tensor<1x1x8xf32> -> tensor<1x1x8xf32>
// CHECK: %{{.*}} = arith.addf %arg0, %arg0 : tensor<1x1x8xf32>
%add = arith.addf %broadcast3, %arg0 : tensor<1x1x8xf32>
tt.return %broadcast1, %broadcast2, %add : tensor<4x2x8xf32>, tensor<8x8xf32>, tensor<1x1x8xf32>
}
// CHECK-LABEL: @test_fold_views
// tt.reshape, tt.broadcast, and tt.expand_dims applied to a dense splat
// constant are each expected to fold into a new dense constant with the
// op's result shape.
tt.func @test_fold_views() -> (tensor<16x8xf32>, tensor<16x128xf32>, tensor<1x1x128xf32>) {
%a = arith.constant dense<1.0> : tensor<1x128xf32>
// CHECK-DAG: %{{.*}} = arith.constant dense<1.{{.*}}> : tensor<16x8xf32>
%b = tt.reshape %a allow_reorder : tensor<1x128xf32> -> tensor<16x8xf32>
// CHECK-DAG: %{{.*}} = arith.constant dense<1.{{.*}}> : tensor<16x128xf32>
%c = tt.broadcast %a : tensor<1x128xf32> -> tensor<16x128xf32>
// CHECK-DAG: %{{.*}} = arith.constant dense<1.{{.*}}> : tensor<1x1x128xf32>
%d = tt.expand_dims %a {axis = 0: i32} : tensor<1x128xf32> -> tensor<1x1x128xf32>
tt.return %b, %c, %d : tensor<16x8xf32>, tensor<16x128xf32>, tensor<1x1x128xf32>
}
// CHECK-LABEL: @test_nop_transpose
// A tt.trans with the identity order [0, 1] is expected to fold away, so the
// function returns its argument directly.
tt.func @test_nop_transpose(%arg0: tensor<2x4xf32>) -> (tensor<2x4xf32>) {
%a = tt.trans %arg0 {order = array<i32: 0, 1>} : tensor<2x4xf32> -> tensor<2x4xf32>
// CHECK: tt.return %arg0
tt.return %a : tensor<2x4xf32>
}
// CHECK-LABEL: @test_nested_transpose
// Two consecutive tt.trans ops are expected to compose into one. Applying
// order [2, 1, 0] on top of [1, 0, 2] selects %arg0's dims [2, 0, 1]:
// 2x4x8 -> 8x2x4.
tt.func @test_nested_transpose(%arg0: tensor<2x4x8xf32>) -> (tensor<8x2x4xf32>) {
%a = tt.trans %arg0 {order = array<i32: 1, 0, 2>} : tensor<2x4x8xf32> -> tensor<4x2x8xf32>
%b = tt.trans %a {order = array<i32: 2, 1, 0>} : tensor<4x2x8xf32> -> tensor<8x2x4xf32>
// CHECK: %[[res:.*]] = tt.trans %arg0 {order = array<i32: 2, 0, 1>}
// CHECK: tt.return %[[res]]
tt.return %b : tensor<8x2x4xf32>
}