new file mode 100644
@@ -0,0 +1,116 @@
+diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+index 34a6f5d74b13..c3c6be05a301 100644
+--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+@@ -529,6 +529,17 @@ def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
+ return true;
+ }
+
++ bool mustBufferizeInPlace(OpOperand &opOperand,
++ const AnalysisState &state) {
++ // ToMemrefOps always bufferize inplace.
++ //
++ // This was removed upstream in this commit:
++ // https://github.com/llvm/llvm-project/commit/4ef60283403f8277ff1048de5905af99fd2ed81d
++ // which means that this is probably a huge hack.
++ // TODO: Properly investigate this!
++ return true;
++ }
++
+ bool bufferizesToMemoryWrite(OpOperand &opOperand,
+ const AnalysisState &state) {
+ return !getReadOnly();
+diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h
+index d2dd95416241..e9da68c44a22 100644
+--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h
+@@ -27,6 +27,9 @@ std::unique_ptr<Pass> createTensorBufferizePass();
+ /// Creates an instance of the concat decomposition pass.
+ std::unique_ptr<Pass> createDecomposeTensorConcatPass();
+
++/// Creates an instance of the concat removal pass.
++std::unique_ptr<Pass> createConcatRemovalPass();
++
+ //===----------------------------------------------------------------------===//
+ // Registration
+ //===----------------------------------------------------------------------===//
+diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td
+index d30e5fcb604d..c88515fb015c 100644
+--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td
+@@ -43,4 +43,31 @@ def DecomposeTensorConcat : Pass<"decompose-tensor-concat"> {
+ let dependentDialects = ["tensor::TensorDialect"];
+ }
+
++def ConcatRemoval : Pass<"concat-removal"> {
++ let summary = "Remove concat by rewriting ops to write directly into the output buffer";
++ let description = [{
++ Rewrite ops which results are concatenated to write instead directly
++ into a big tensor shared by all ops, thus removing the necessity of
++ concatenating the partial output of the ops. e.g., rewrite:
++
++ %0 = tensor.empty() : tensor<8xf32>
++ %1 = linalg.elemwise(%0: tensor<8xf32>) outs(%0: tensor<8xf32>)
++ %2 = tensor.empty() : tensor<8xf32>
++ %3 = linalg.elemwise(%2: tensor<8xf32>) outs(%2: tensor<8xf32>)
++ %concat = tensor.concat dim(0) %1, %3 : (tensor<8xf32>, tensor<8xf32>) -> tensor<16xf32>
++
++ into:
++
++ %0 = tensor.empty() : tensor<16xf32>
++ %1 = extract_slice %0 : tensor<16xf32> -> tensor<8xf32>
++ %2 = linalg.elemwise(%1: tensor<8xf32>) outs(%1: tensor<8xf32>)
++ %3 = insert_slice %2 into %0
++ %4 = extract_slice %3 : tensor<16xf32> -> tensor<8xf32>
++ %5 = linalg.elemwise(%4: tensor<8xf32>) outs(%4: tensor<8xf32>)
++ %6 = insert_slice %5 into %4
++ }];
++ let constructor = "mlir::tensor::createConcatRemovalPass()";
++ let dependentDialects = ["tensor::TensorDialect"];
++}
++
+ #endif // MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES
+diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+index 44b8377bd6aa..4cad1aaa212f 100644
+--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+@@ -74,6 +74,12 @@ void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns,
+ /// that it can be bufferized into a sequence of copies.
+ void populateDecomposeTensorConcatPatterns(RewritePatternSet &patterns);
+
++/// Populates `patterns` with patterns that rewrite ops which results
++/// are concatenated to write instead directly into a big tensor shared
++/// by all ops, thus removing the necessity of concatenating the partial
++/// output of the ops.
++void populateConcatRemovalPatterns(RewritePatternSet &patterns);
++
+ /// Populates `patterns` with patterns that fold operations like `tensor.pad`
+ /// and `tensor.extract_slice` into `tensor.pack` and `tensor.unpack` operations
+ /// respectively.
+diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+index f590e3d9da8e..dc625fbbe67b 100644
+--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+@@ -223,6 +223,12 @@ bool OneShotAnalysisState::isValueWritten(Value value) const {
+ }
+
+ bool OneShotAnalysisState::isWritable(Value value) const {
++ // THIS IS A HUGE HACK AND WE SHOULD FIND A BETTER WAY TO DO THIS!
++ if (dyn_cast<ToTensorOp>(getOwnerOfValue(value))) {
++ llvm::errs() << "WARNING: Assuming ToTensorOp is writable\n";
++ return true;
++ }
++
+ // TODO: Out-of-place bufferized value could be considered writable.
+ // Query BufferizableOpInterface to see if the BlockArgument is writable.
+ if (auto bufferizableOp =
+@@ -1189,7 +1195,8 @@ checkPreBufferizationAssumptions(Operation *op, const DominanceInfo &domInfo,
+ }
+
+ for (OpOperand &opOperand : op->getOpOperands()) {
+- if (isa<TensorType>(opOperand.get().getType())) {
++ // ToMemrefOp does not write into the buffer ?
++ if (isa<TensorType>(opOperand.get().getType()) && !dyn_cast<ToMemrefOp>(op.getOperation())) {
+ if (wouldCreateReadAfterWriteInterference(
+ opOperand, domInfo, state,
+ /*checkConsistencyOnly=*/true)) {