6782b9c1创建于 2025年5月14日历史提交
diff --git a/third_party/tsl/third_party/llvm/concat.patch b/third_party/tsl/third_party/llvm/concat.patch
new file mode 100644
index 0000000..86c5db2
--- /dev/null
+++ b/third_party/tsl/third_party/llvm/concat.patch
@@ -0,0 +1,78 @@
+diff --git a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
+index 8556d9570..1dd4355ad 100644
+--- a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
++++ b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
+@@ -26,6 +26,24 @@ def ApplyDecomposeTensorConcatPatternsOp : Op<Transform_Dialect,
+   let assemblyFormat = "attr-dict";
+ }
+ 
++def ApplySimplifyTensorConcatPatternsOp : Op<Transform_Dialect,
++    "apply_patterns.tensor.simplify_concat",
++    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
++  let description = [{
++    Indicates that tensor.concat ops should be simplified, reducing the number
++    of the input operands by merging common source operands together.
++
++    For example, given a tensor.concat with four inputs, where the first two
++    come from extract_slices of a common value (arg0), and the other two inputs
++    come from extract_slices of another common value (arg1), then the
++    tensor.concat can be rewritten as a concat of just arg0 and arg1.
++
++    Also considers the case of more complex simplifcations, where extract_slices
++    are needed to merge common source operands.
++  }];
++
++  let assemblyFormat = "attr-dict";
++}
+ 
+ def ApplyDropRedundantInsertSliceRankExpansionPatternsOp : Op<Transform_Dialect,
+     "apply_patterns.tensor.drop_redundant_insert_slice_rank_expansion",
+diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h
+index e9da68c44..cc87b0a8c 100644
+--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h
++++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h
+@@ -30,6 +30,9 @@ std::unique_ptr<Pass> createDecomposeTensorConcatPass();
+ /// Creates an instance of the concat removal pass.
+ std::unique_ptr<Pass> createConcatRemovalPass();
+ 
++/// Creates an instance of the concat simplification pass.
++std::unique_ptr<Pass> createSimplifyTensorConcatPass();
++
+ //===----------------------------------------------------------------------===//
+ // Registration
+ //===----------------------------------------------------------------------===//
+diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td
+index c88515fb0..148fd2c33 100644
+--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td
++++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td
+@@ -70,4 +70,15 @@ def ConcatRemoval : Pass<"concat-removal"> {
+   let dependentDialects = ["tensor::TensorDialect"];
+ }
+ 
++def SimplifyTensorConcat : Pass<"simplify-tensor-concat"> {
++  let summary = "TODO"
++                "TODO";
++  let description = [{
++    TODO
++    TODO
++  }];
++  let constructor = "mlir::tensor::createSimplifyTensorConcatPass()";
++  let dependentDialects = ["tensor::TensorDialect"];
++}
++
+ #endif // MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES
+diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+index 4cad1aaa2..744b86e7a 100644
+--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
++++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+@@ -80,6 +80,10 @@ void populateDecomposeTensorConcatPatterns(RewritePatternSet &patterns);
+ /// output of the ops.
+ void populateConcatRemovalPatterns(RewritePatternSet &patterns);
+ 
++/// Populates `patterns` with patterns that simplify `tensor.concat` by merging
++/// contiguous extract slice ops.
++void populateSimplifyTensorConcatPatterns(RewritePatternSet &patterns);
++
+ /// Populates `patterns` with patterns that fold operations like `tensor.pad`
+ /// and `tensor.extract_slice` into `tensor.pack` and `tensor.unpack` operations
+ /// respectively.