* Copyright 2026 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "akg/Transforms/AddOutParameter.h"
#include <algorithm>
#include <iterator>
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/Passes.h"
#include "bishengir/Dialect/HACC/IR/HACC.h"
#define DEBUG_TYPE "add-out-parameter"
namespace mlir {
#define GEN_PASS_DECL_ADDOUTPARAMETER
#define GEN_PASS_DEF_ADDOUTPARAMETER
#include "akg/Transforms/Passes.h.inc"
}
namespace mlir {
using hacc::InputIdxAttr;
using hacc::KernelArgType;
using hacc::KernelArgTypeAttr;
using hacc::OutputIdxAttr;
namespace {
/// Annotates the arguments of `f` with HACC kernel-argument attributes.
///
/// Arguments [0, nInputs) are tagged as kernel inputs and receive an input
/// index equal to their position. Arguments [nInputs, nInputs + nOutputs) are
/// tagged as kernel outputs and receive an output index counted from zero.
/// `builder` is only used to obtain the MLIR context.
static void setHaccIOArgAttrs(func::FuncOp f, unsigned nInputs, unsigned nOutputs, OpBuilder &builder) {
  auto *context = builder.getContext();
  const unsigned totalArgs = nInputs + nOutputs;
  for (unsigned argPos = 0; argPos < totalArgs; ++argPos) {
    if (argPos < nInputs) {
      f.setArgAttr(argPos, KernelArgTypeAttr::name, KernelArgTypeAttr::get(context, KernelArgType::kInput));
      f.setArgAttr(argPos, InputIdxAttr::name, InputIdxAttr::get(context, argPos));
      continue;
    }
    // Output indices restart at zero after the input arguments.
    unsigned outputPos = argPos - nInputs;
    f.setArgAttr(argPos, KernelArgTypeAttr::name, KernelArgTypeAttr::get(context, KernelArgType::kOutput));
    f.setArgAttr(argPos, OutputIdxAttr::name, OutputIdxAttr::get(context, outputPos));
  }
}
/// Returns true when `v` is an argument of the entry block of `func` whose
/// position is below `origNumInputs`, i.e. it is one of the function's
/// original inputs. Any other value (op result, argument of another block,
/// argument at or beyond `origNumInputs`) yields false.
static bool isAliasOfOriginalInput(Value v, func::FuncOp func, unsigned origNumInputs) {
  auto blockArg = mlir::dyn_cast<BlockArgument>(v);
  if (!blockArg) return false;
  const bool ownedByEntry = blockArg.getOwner() == &func.front();
  return ownedByEntry && blockArg.getArgNumber() < origNumInputs;
}
/// Gathers every `func.return` inside `func` into `returns` and appends the
/// operands of the first one found to `origReturnValues`.
///
/// Succeeds without appending anything when the function has no return op.
/// Emits an error on the first return op and fails when its operand count
/// differs from `origNumResults`.
static LogicalResult collectOriginalReturnValues(func::FuncOp func, unsigned origNumResults,
                                                 SmallVectorImpl<func::ReturnOp> &returns,
                                                 SmallVectorImpl<Value> &origReturnValues) {
  func.walk([&returns](func::ReturnOp retOp) { returns.push_back(retOp); });
  if (returns.empty()) return success();

  // The first collected return acts as the representative for the function.
  func::ReturnOp firstRet = returns.front();
  if (firstRet.getNumOperands() == origNumResults) {
    for (Value operand : firstRet.getOperands()) origReturnValues.push_back(operand);
    return success();
  }

  firstRet.emitError() << "number of return values (" << firstRet.getNumOperands()
                       << ") does not match number of function results (" << origNumResults << ")";
  return failure();
}
/// Converts a reassociation attribute (an array of arrays of integer
/// attributes) into a vector of `ReassociationIndices`. A null attribute
/// produces an empty vector. The nested attributes are accessed with `cast`,
/// so they must already have the expected kinds.
static SmallVector<ReassociationIndices> parseReassociation(ArrayAttr reassocAttr) {
  SmallVector<ReassociationIndices> groups;
  if (reassocAttr) {
    groups.reserve(reassocAttr.size());
    for (Attribute groupAttr : reassocAttr) {
      auto indexAttrs = cast<ArrayAttr>(groupAttr);
      ReassociationIndices group;
      group.reserve(indexAttrs.size());
      for (Attribute indexAttr : indexAttrs) {
        group.push_back(cast<IntegerAttr>(indexAttr).getInt());
      }
      groups.push_back(std::move(group));
    }
  }
  return groups;
}
/// Mirrors stores made to the source of a reassociating view op into the out
/// argument.
///
/// A `BackViewOpTy` op is created at the very start of the enclosing
/// function's entry block; it views `outArg` with the type of `shapeOp`'s
/// source, reusing `shapeOp`'s reassociation. Every `memref.store` that writes
/// directly to the source is then followed by an identical store into that
/// back view. `shapeOp` itself is left in place.
///
/// Succeeds without changing anything when either type is not a `MemRefType`
/// or when no non-empty parent function is found.
template <typename SrcOpTy, typename BackViewOpTy>
static LogicalResult rewriteSrcOutShapeOpToOutViewImpl(SrcOpTy shapeOp, Value outArg) {
  Value viewSource = shapeOp.getSrc();
  auto sourceType = dyn_cast<MemRefType>(viewSource.getType());
  auto outArgType = dyn_cast<MemRefType>(outArg.getType());
  if (!sourceType || !outArgType) return success();

  auto parentFunc = shapeOp->template getParentOfType<func::FuncOp>();
  if (!parentFunc || parentFunc.empty()) return success();

  // Build the inverse view at the top of the entry block.
  Block &entryBlock = parentFunc.front();
  auto reassociation = parseReassociation(shapeOp.getReassociation());
  OpBuilder viewBuilder(&entryBlock, entryBlock.begin());
  Value mirror =
    viewBuilder.template create<BackViewOpTy>(shapeOp.getLoc(), sourceType, outArg, reassociation).getResult();

  // Snapshot the stores first: the duplicates created below must not be revisited.
  SmallVector<memref::StoreOp, 4> directStores;
  for (Operation *user : viewSource.getUsers()) {
    auto store = dyn_cast<memref::StoreOp>(user);
    if (store && store.getMemRef() == viewSource) directStores.push_back(store);
  }

  OpBuilder storeBuilder(parentFunc.getContext());
  for (memref::StoreOp store : directStores) {
    storeBuilder.setInsertionPointAfter(store);
    storeBuilder.create<memref::StoreOp>(store.getLoc(), store.getValue(), mirror, store.getIndices());
  }
  return success();
}
/// `memref.expand_shape` flavour of rewriteSrcOutShapeOpToOutViewImpl: the
/// back view over `outArg` is a `memref.collapse_shape`.
static LogicalResult rewriteSrcOutExpandOpToOutView(memref::ExpandShapeOp expandOp, Value outArg) {
  using ForwardOp = memref::ExpandShapeOp;
  using InverseOp = memref::CollapseShapeOp;
  return rewriteSrcOutShapeOpToOutViewImpl<ForwardOp, InverseOp>(expandOp, outArg);
}
/// `memref.collapse_shape` flavour of rewriteSrcOutShapeOpToOutViewImpl: the
/// back view over `outArg` is a `memref.expand_shape`.
static LogicalResult rewriteSrcOutCollapseToOutView(memref::CollapseShapeOp collapseOp, Value outArg) {
  using ForwardOp = memref::CollapseShapeOp;
  using InverseOp = memref::ExpandShapeOp;
  return rewriteSrcOutShapeOpToOutViewImpl<ForwardOp, InverseOp>(collapseOp, outArg);
}
/// Redirects a temporary buffer that feeds a reassociating view op onto the
/// out argument.
///
/// A `NewViewOpTy` op viewing `outArg` with the temporary's type is created
/// right after the temporary's defining op, or at the start of its owning
/// block when the temporary is a block argument. All uses of the temporary
/// other than `srcOp` are switched to that new view, uses of `srcOp`'s result
/// are replaced by `outArg`, and `srcOp` is erased. If the temporary came from
/// a `memref.alloc` that is left without uses, the alloc is erased as well.
///
/// Succeeds without changing anything when either type is not a `MemRefType`.
template <typename SrcOpTy, typename NewViewOpTy>
static LogicalResult rewriteTempSourceToOutViewImpl(SrcOpTy srcOp, Value outArg) {
  Value tmp = srcOp.getSrc();
  auto tmpType = dyn_cast<MemRefType>(tmp.getType());
  auto outArgType = dyn_cast<MemRefType>(outArg.getType());
  if (!tmpType || !outArgType) return success();

  auto reassociation = parseReassociation(srcOp.getReassociation());

  // Choose where the replacement view goes before creating anything.
  OpBuilder viewBuilder(srcOp->getContext());
  if (Operation *producer = tmp.getDefiningOp()) {
    viewBuilder.setInsertionPointAfter(producer);
  } else {
    Block *owner = mlir::cast<BlockArgument>(tmp).getOwner();
    if (!owner) return success();
    viewBuilder.setInsertionPointToStart(owner);
  }
  Value replacement =
    viewBuilder.template create<NewViewOpTy>(srcOp.getLoc(), tmpType, outArg, reassociation).getResult();

  // Snapshot the uses to rewrite; updating an operand edits the use list.
  SmallVector<OpOperand *> pendingUses;
  for (OpOperand &use : tmp.getUses()) {
    if (use.getOwner() != srcOp.getOperation()) pendingUses.push_back(&use);
  }
  for (OpOperand *use : pendingUses) use->set(replacement);

  srcOp.getResult().replaceAllUsesWith(outArg);
  srcOp.erase();

  auto tmpAlloc = tmp.getDefiningOp<memref::AllocOp>();
  if (tmpAlloc && tmpAlloc->use_empty()) tmpAlloc->erase();
  return success();
}
/// `memref.collapse_shape` flavour of rewriteTempSourceToOutViewImpl: the
/// replacement view over `outArg` is a `memref.expand_shape`.
static LogicalResult rewriteTempCollapseSourceToOutView(memref::CollapseShapeOp collapseOp, Value outArg) {
  using RemovedOp = memref::CollapseShapeOp;
  using ReplacementViewOp = memref::ExpandShapeOp;
  return rewriteTempSourceToOutViewImpl<RemovedOp, ReplacementViewOp>(collapseOp, outArg);
}
/// `memref.expand_shape` flavour of rewriteTempSourceToOutViewImpl: the
/// replacement view over `outArg` is a `memref.collapse_shape`.
static LogicalResult rewriteTempExpandSourceToOutView(memref::ExpandShapeOp expandOp, Value outArg) {
  using RemovedOp = memref::ExpandShapeOp;
  using ReplacementViewOp = memref::CollapseShapeOp;
  return rewriteTempSourceToOutViewImpl<RemovedOp, ReplacementViewOp>(expandOp, outArg);
}
/// Emits IR at `b`'s insertion point that reshapes `outArg` to `tmpTy` and
/// returns the reshaped value.
///
/// A rank-1 index memref is allocated with `memref.alloc` and filled with one
/// extent per dimension of `tmpTy`: a constant for static entries of
/// `staticDims`, otherwise a `memref.dim` of `outArg` at the same position.
/// The filled memref is the shape operand of the final `memref.reshape`.
///
/// NOTE(review): the dynamic branch queries `outArg` at index `i` of `tmpTy`;
/// this looks correct only when both have matching dimensions at that
/// position — confirm for reshapes that change rank. The shape memref is not
/// deallocated here.
static Value buildReshapeFromOut(OpBuilder &b, Location loc, MemRefType tmpTy, Value outArg,
                                 ArrayRef<int64_t> staticDims) {
  const unsigned rank = tmpTy.getRank();
  auto shapeBufferTy = MemRefType::get({static_cast<int64_t>(rank)}, IndexType::get(b.getContext()));
  Value shapeBuffer = b.create<memref::AllocOp>(loc, shapeBufferTy).getResult();

  for (unsigned dim = 0; dim < rank; ++dim) {
    const int64_t extent = staticDims[dim];
    Value extentVal;
    if (!ShapedType::isDynamic(extent)) {
      extentVal = b.create<arith::ConstantIndexOp>(loc, extent);
    } else {
      extentVal = b.create<memref::DimOp>(loc, outArg, dim);
    }
    Value slot = b.create<arith::ConstantIndexOp>(loc, dim);
    b.create<memref::StoreOp>(loc, extentVal, shapeBuffer, ValueRange{slot});
  }

  return b.create<memref::ReshapeOp>(loc, tmpTy, outArg, shapeBuffer).getResult();
}
/// Removes a shape buffer that is no longer needed.
///
/// Does nothing unless `tmpShape` is produced by a `memref.alloc` and every
/// remaining user is a `memref.store` writing into it. In that case those
/// stores are erased, followed by the alloc once it has no uses left.
static void tryEraseOldShapeMemref(Value tmpShape) {
  auto shapeAlloc = tmpShape.getDefiningOp<memref::AllocOp>();
  if (!shapeAlloc) return;

  // Any user that is not a store into the buffer keeps it alive; bail out
  // before erasing anything.
  SmallVector<memref::StoreOp> fillStores;
  for (Operation *user : tmpShape.getUsers()) {
    auto store = dyn_cast<memref::StoreOp>(user);
    if (!store || store.getMemRef() != tmpShape) return;
    fillStores.push_back(store);
  }

  for (memref::StoreOp store : fillStores) store.erase();
  if (shapeAlloc->use_empty()) shapeAlloc->erase();
}
/// Redirects a temporary buffer that feeds a `memref.reshape` onto the out
/// argument.
///
/// A reshape of `outArg` to the temporary's type is built (via
/// buildReshapeFromOut) right after the temporary's defining op, or at the
/// start of its owning block when the temporary is a block argument. All uses
/// of the temporary other than `reshapeOp` are switched to that value, uses of
/// `reshapeOp`'s result are replaced by `outArg`, and `reshapeOp` is erased.
/// A now-unused `memref.alloc` behind the temporary and the old shape buffer
/// are cleaned up afterwards.
///
/// Succeeds without changing anything when either type is not a `MemRefType`.
static LogicalResult rewriteTempReshapeSourceToOutView(memref::ReshapeOp reshapeOp, Value outArg) {
  Value tmp = reshapeOp.getSource();
  auto tmpType = dyn_cast<MemRefType>(tmp.getType());
  auto outArgType = dyn_cast<MemRefType>(outArg.getType());
  if (!tmpType || !outArgType) return success();

  SmallVector<int64_t, 4> tmpDims(tmpType.getShape().begin(), tmpType.getShape().end());
  // Read everything needed from reshapeOp up front; it is erased below.
  Location loc = reshapeOp.getLoc();
  Value oldShape = reshapeOp.getShape();

  OpBuilder viewBuilder(reshapeOp->getContext());
  if (Operation *producer = tmp.getDefiningOp()) {
    viewBuilder.setInsertionPointAfter(producer);
  } else {
    Block *owner = mlir::cast<BlockArgument>(tmp).getOwner();
    if (!owner) return success();
    viewBuilder.setInsertionPointToStart(owner);
  }
  Value replacement = buildReshapeFromOut(viewBuilder, loc, tmpType, outArg, tmpDims);

  // Snapshot the uses to rewrite; updating an operand edits the use list.
  SmallVector<OpOperand *> pendingUses;
  for (OpOperand &use : tmp.getUses()) {
    if (use.getOwner() != reshapeOp.getOperation()) pendingUses.push_back(&use);
  }
  for (OpOperand *use : pendingUses) use->set(replacement);

  reshapeOp.getResult().replaceAllUsesWith(outArg);
  reshapeOp.erase();

  auto tmpAlloc = tmp.getDefiningOp<memref::AllocOp>();
  if (tmpAlloc && tmpAlloc->use_empty()) tmpAlloc->erase();

  tryEraseOldShapeMemref(oldShape);
  return success();
}
/// Mirrors stores made to the source of a `memref.reshape` into the out
/// argument and removes the reshape.
///
/// A reshape of `outArg` back to the source's type is built (via
/// buildReshapeFromOut) at the very start of the enclosing function's entry
/// block. Every `memref.store` that writes directly to the source is followed
/// by an identical store into that back view. Uses of the reshape result are
/// then replaced by `outArg`, the reshape is erased, and its shape buffer is
/// cleaned up if nothing else needs it.
///
/// Succeeds without changing anything when either type is not a `MemRefType`
/// or when no non-empty parent function is found.
static LogicalResult rewriteSrcOutReshapeOpToOutView(memref::ReshapeOp reshapeOp, Value outArg) {
  Value src = reshapeOp.getSource();
  auto srcTy = dyn_cast<MemRefType>(src.getType());
  auto outTy = dyn_cast<MemRefType>(outArg.getType());
  if (!srcTy || !outTy) return success();
  auto func = reshapeOp->getParentOfType<func::FuncOp>();
  if (!func || func.empty()) return success();

  Block &entry = func.front();
  SmallVector<int64_t, 4> staticDims(srcTy.getShape().begin(), srcTy.getShape().end());
  OpBuilder topBuilder(&entry, entry.begin());
  Value backView = buildReshapeFromOut(topBuilder, reshapeOp.getLoc(), srcTy, outArg, staticDims);

  // Snapshot the stores first so the duplicates created below are not revisited.
  SmallVector<memref::StoreOp, 4> srcStores;
  for (Operation *user : src.getUsers()) {
    if (auto store = dyn_cast<memref::StoreOp>(user)) {
      if (store.getMemRef() == src) srcStores.push_back(store);
    }
  }
  for (auto store : srcStores) {
    OpBuilder b(store->getContext());
    b.setInsertionPointAfter(store);
    b.create<memref::StoreOp>(store.getLoc(), store.getValue(), backView, store.getIndices());
  }

  // Capture the shape operand BEFORE erasing: reading it from an erased op is
  // a use-after-free.
  Value oldShape = reshapeOp.getShape();
  // The reshape result is still used (at least by the return op). Erasing an
  // op with live uses is invalid, so forward those uses to the out argument
  // first. The caller (processReturnValue) has already checked that the result
  // and `outArg` have the same type.
  reshapeOp.getResult().replaceAllUsesWith(outArg);
  reshapeOp.erase();
  tryEraseOldShapeMemref(oldShape);
  return success();
}
/// Handles a return value produced directly by `memref.alloc`: all uses of the
/// allocation are replaced by `outArg` and the alloc is erased once it has no
/// uses. Values not defined by `memref.alloc` are left untouched. Always
/// succeeds.
static LogicalResult handleAllocReturn(Value oldResVal, Value outArg) {
  auto allocOp = oldResVal.getDefiningOp<memref::AllocOp>();
  if (!allocOp) return success();

  Value allocated = allocOp.getResult();
  if (allocated != outArg) allocated.replaceAllUsesWith(outArg);
  if (allocOp->use_empty()) allocOp->erase();
  return success();
}
/// Handles a return value produced by `memref.collapse_shape`.
///
/// If the collapse source is itself listed in `origReturnValues` or
/// `newReturnValues`, the "source is also an output" rewrite is used;
/// otherwise the source is treated as a temporary. Values not defined by a
/// collapse op are ignored.
static LogicalResult handleCollapseReturn(Value oldResVal, Value outArg, SmallVector<Value> &origReturnValues,
                                          SmallVector<Value> &newReturnValues) {
  auto collapseOp = oldResVal.getDefiningOp<memref::CollapseShapeOp>();
  if (!collapseOp) return success();

  Value collapseSource = collapseOp.getSrc();
  const bool sourceIsReturned =
    llvm::is_contained(origReturnValues, collapseSource) || llvm::is_contained(newReturnValues, collapseSource);
  if (!sourceIsReturned) return rewriteTempCollapseSourceToOutView(collapseOp, outArg);
  return rewriteSrcOutCollapseToOutView(collapseOp, outArg);
}
/// Handles a return value produced by `memref.expand_shape`.
///
/// If the expand source is itself listed in `origReturnValues` or
/// `newReturnValues`, the "source is also an output" rewrite is used;
/// otherwise the source is treated as a temporary. Values not defined by an
/// expand op are ignored.
static LogicalResult handleExpandReturn(Value oldResVal, Value outArg, SmallVector<Value> &origReturnValues,
                                        SmallVector<Value> &newReturnValues) {
  auto expandOp = oldResVal.getDefiningOp<memref::ExpandShapeOp>();
  if (!expandOp) return success();

  Value expandSource = expandOp.getSrc();
  const bool sourceIsReturned =
    llvm::is_contained(origReturnValues, expandSource) || llvm::is_contained(newReturnValues, expandSource);
  if (!sourceIsReturned) return rewriteTempExpandSourceToOutView(expandOp, outArg);
  return rewriteSrcOutExpandOpToOutView(expandOp, outArg);
}
/// Handles a return value produced by `memref.reshape`.
///
/// If the reshape source is itself listed in `origReturnValues` or
/// `newReturnValues`, the "source is also an output" rewrite is used;
/// otherwise the source is treated as a temporary. Values not defined by a
/// reshape op are ignored.
static LogicalResult handleReshapeReturn(Value oldResVal, Value outArg, SmallVector<Value> &origReturnValues,
                                         SmallVector<Value> &newReturnValues) {
  auto reshapeOp = oldResVal.getDefiningOp<memref::ReshapeOp>();
  if (!reshapeOp) return success();

  Value reshapeSource = reshapeOp.getSource();
  const bool sourceIsReturned =
    llvm::is_contained(origReturnValues, reshapeSource) || llvm::is_contained(newReturnValues, reshapeSource);
  if (!sourceIsReturned) return rewriteTempReshapeSourceToOutView(reshapeOp, outArg);
  return rewriteSrcOutReshapeOpToOutView(reshapeOp, outArg);
}
/// Returns true when both `srcTy` and `resTy` are non-null, statically shaped,
/// and each holds exactly one element (an empty shape counts as one element).
/// `reshapeOp` is not inspected.
static bool isReshapeSpecialCase(memref::ReshapeOp reshapeOp, MemRefType srcTy, MemRefType resTy) {
  if (!srcTy || !resTy) return false;
  if (!srcTy.hasStaticShape() || !resTy.hasStaticShape()) return false;

  // Product of all extents; a rank-0 shape yields 1.
  auto elementCount = [](ArrayRef<int64_t> dims) -> int64_t {
    int64_t count = 1;
    for (int64_t extent : dims) {
      if (ShapedType::isDynamic(extent)) return ShapedType::kDynamic;
      count *= extent;
    }
    return count;
  };

  const int64_t srcCount = elementCount(srcTy.getShape());
  const int64_t resCount = elementCount(resTy.getShape());
  return srcCount == 1 && resCount == 1;
}
/// Rewrites one operand of a `func.return` so that the data ends up in its
/// out argument.
///
/// \param resultIdx        Position of the operand in the return op.
/// \param oldResVal        The value currently returned at that position.
/// \param maybeOutArg      Out argument assigned to this result, or a null
///                         Value when none was assigned.
/// \param ret              The return op being processed (used for diagnostics).
/// \param origReturnValues Operands of the representative return op.
/// \param newReturnValues  Per-result replacement values; entry `resultIdx` is
///                         set to the out argument once a rewrite succeeds.
///
/// Block arguments are accepted unchanged. Otherwise the defining op must be a
/// `memref.alloc`, `memref.reshape`, `memref.collapse_shape` or
/// `memref.expand_shape`. A single-element static reshape between rank 0 and
/// rank 1 is first turned into an expand/collapse with an empty reassociation
/// and then handled as such.
///
/// Fails (after emitting an error on `ret`) when no out argument is available,
/// when the types differ, or when the defining op is not one of the supported
/// kinds.
static LogicalResult processReturnValue(unsigned resultIdx, Value oldResVal, Value maybeOutArg, func::ReturnOp ret,
                                        SmallVector<Value> &origReturnValues, SmallVector<Value> &newReturnValues) {
  if (mlir::isa<BlockArgument>(oldResVal)) {
    return success();
  }
  if (!maybeOutArg) {
    ret.emitError() << "no out-argument is assigned for this return value, "
                       "but it is not a block argument: "
                    << oldResVal;
    return failure();
  }
  Value outArg = maybeOutArg;
  if (oldResVal.getType() != outArg.getType()) {
    ret.emitError() << "type mismatch between return value (" << oldResVal.getType() << ") and out argument ("
                    << outArg.getType() << ")";
    return failure();
  }

  // oldResVal is not a block argument, so it has a defining op from here on.
  if (isa<memref::AllocOp>(oldResVal.getDefiningOp())) {
    if (failed(handleAllocReturn(oldResVal, outArg))) return failure();
    newReturnValues[resultIdx] = outArg;
    return success();
  }

  if (auto reshapeOp = dyn_cast<memref::ReshapeOp>(oldResVal.getDefiningOp())) {
    auto srcTy = mlir::dyn_cast<MemRefType>(reshapeOp.getSource().getType());
    auto resTy = mlir::dyn_cast<MemRefType>(reshapeOp.getResult().getType());
    OpBuilder b(reshapeOp);
    // When this is true both types are non-null, so the rank queries below are safe.
    const bool singleElement = isReshapeSpecialCase(reshapeOp, srcTy, resTy);
    if (singleElement && srcTy.getRank() == 0 && resTy.getRank() == 1) {
      // Rank 0 -> rank 1: express as an expand_shape with an empty reassociation.
      SmallVector<ReassociationIndices, 1> reassoc;
      auto expand = b.create<memref::ExpandShapeOp>(reshapeOp.getLoc(), resTy, reshapeOp.getSource(), reassoc);
      reshapeOp.replaceAllUsesWith(expand.getResult());
      reshapeOp.erase();
      oldResVal = expand.getResult();
    } else if (singleElement && srcTy.getRank() == 1 && resTy.getRank() == 0) {
      // Rank 1 -> rank 0: express as a collapse_shape with an empty reassociation.
      SmallVector<ReassociationIndices, 1> reassoc;
      auto collapse = b.create<memref::CollapseShapeOp>(reshapeOp.getLoc(), resTy, reshapeOp.getSource(), reassoc);
      reshapeOp.replaceAllUsesWith(collapse.getResult());
      reshapeOp.erase();
      oldResVal = collapse.getResult();
    } else {
      if (failed(handleReshapeReturn(oldResVal, outArg, origReturnValues, newReturnValues))) return failure();
      newReturnValues[resultIdx] = outArg;
      return success();
    }
  }

  // oldResVal may have been replaced above, so query its defining op afresh.
  if (isa<memref::CollapseShapeOp>(oldResVal.getDefiningOp())) {
    if (failed(handleCollapseReturn(oldResVal, outArg, origReturnValues, newReturnValues))) return failure();
    newReturnValues[resultIdx] = outArg;
    return success();
  }
  if (isa<memref::ExpandShapeOp>(oldResVal.getDefiningOp())) {
    if (failed(handleExpandReturn(oldResVal, outArg, origReturnValues, newReturnValues))) return failure();
    newReturnValues[resultIdx] = outArg;
    return success();
  }

  ret.emitError() << "unsupported pattern for return value when converting to "
                     "out-parameter: "
                  << oldResVal;
  // An error diagnostic was emitted, so the result must be failure; reporting
  // success here would let the pass continue on IR it could not convert.
  return failure();
}
/// Runs processReturnValue on every operand of every `func.return` in `func`.
///
/// `maybeOutArgs[i]` is the out argument assigned to result `i` (null when
/// there is none). Fails when a return op has an operand count different from
/// `origNumResults` (an error is emitted on that op) or when processing any
/// operand fails. Succeeds immediately when the function has no return op.
static LogicalResult handleAllReturns(func::FuncOp func, unsigned origNumResults, SmallVector<Value> maybeOutArgs,
                                      SmallVector<Value> &origReturnValues) {
  SmallVector<func::ReturnOp> allReturns;
  func.walk([&allReturns](func::ReturnOp retOp) { allReturns.push_back(retOp); });
  if (allReturns.empty()) return success();

  // Working copy that records which results have been redirected so far.
  SmallVector<Value> rewrittenValues(origReturnValues.begin(), origReturnValues.end());
  for (func::ReturnOp retOp : allReturns) {
    auto operands = retOp.getOperands();
    if (operands.size() != origNumResults) {
      retOp.emitError() << "number of return values (" << operands.size()
                        << ") does not match number of function results (" << origNumResults << ")";
      return failure();
    }
    for (unsigned idx = 0; idx < origNumResults; ++idx) {
      LogicalResult status =
        processReturnValue(idx, operands[idx], maybeOutArgs[idx], retOp, origReturnValues, rewrittenValues);
      if (failed(status)) return failure();
    }
  }
  return success();
}
/// Erases every `OpTy` inside `func` that has no uses. Candidates are gathered
/// during the walk and erased afterwards.
template <typename OpTy>
static void eraseDeadOpsOfType(func::FuncOp func) {
  SmallVector<OpTy, 4> unused;
  func.walk([&unused](OpTy candidate) {
    if (!candidate->use_empty()) return;
    unused.push_back(candidate);
  });
  for (OpTy victim : unused) victim->erase();
}
/// Erases unused `memref.expand_shape`, `memref.collapse_shape` and
/// `memref.alloc` ops from `func`, in that order.
static void eraseDeadOps(func::FuncOp func) {
// Each call only removes ops that have no uses when it runs, so an op that
// loses its last user in an earlier call can be removed by a later one; for
// example an alloc is only picked up by the last call once the views using it
// are gone.
eraseDeadOpsOfType<memref::ExpandShapeOp>(func);
eraseDeadOpsOfType<memref::CollapseShapeOp>(func);
eraseDeadOpsOfType<memref::AllocOp>(func);
}
static void rebuildAllReturns(const SmallVectorImpl<func::ReturnOp> &returns, unsigned origNumResults,
ArrayRef<bool> needOut, ArrayRef<Value> resultToOutArg) {
for (auto ret : returns) {
SmallVector<Value, 4> newRetOperands;
newRetOperands.reserve(origNumResults);
auto oldOps = ret.getOperands();
for (unsigned i = 0; i < origNumResults; ++i) {
if (needOut[i]) {
newRetOperands.push_back(resultToOutArg[i]);
} else {
newRetOperands.push_back(oldOps[i]);
}
}
OpBuilder builder(ret);
builder.create<func::ReturnOp>(ret.getLoc(), newRetOperands);
ret.erase();
}
}
/// Rewrites the body of `func` so that results flagged in `needOut` are
/// written through the out arguments already appended to the entry block.
///
/// The entry block must have exactly `origNumInputs` arguments plus one per
/// set entry of `needOut`; out arguments are matched to results in order.
/// After validating that, every return value is processed, all return ops are
/// rebuilt, and dead view/alloc ops are erased.
///
/// Succeeds trivially for functions without results, without a body, or
/// without any return op. Fails with a diagnostic on argument or operand count
/// mismatches, or when a return value cannot be processed.
static LogicalResult rewriteReturnsAndAllocToUseOutParams(func::FuncOp func, unsigned origNumInputs,
                                                          unsigned origNumResults, ArrayRef<bool> needOut) {
  if (origNumResults == 0 || func.empty()) return success();

  Block &entry = func.front();
  if (entry.getNumArguments() < origNumInputs) {
    func.emitError() << "entry block arg count is less than original inputs";
    return failure();
  }

  // The first return op visited by the walk is used for the sanity check.
  func::ReturnOp firstRet;
  func.walk([&firstRet](func::ReturnOp candidate) {
    if (!firstRet) firstRet = candidate;
  });
  if (!firstRet) return success();
  if (firstRet.getNumOperands() != origNumResults) {
    firstRet.emitError() << "unexpected number of return operands: got " << firstRet.getNumOperands() << ", expected "
                         << origNumResults;
    return failure();
  }

  unsigned expectedTotalArgs = origNumInputs;
  for (unsigned idx = 0; idx < origNumResults; ++idx) {
    expectedTotalArgs += needOut[idx] ? 1u : 0u;
  }
  if (entry.getNumArguments() != expectedTotalArgs) {
    func.emitError() << "entry block arg count mismatch after selective out-param "
                        "transform: expected "
                     << expectedTotalArgs << " got " << entry.getNumArguments();
    return failure();
  }

  // Out arguments follow the original inputs, one per flagged result, in order.
  SmallVector<Value> resultToOutArg(origNumResults, Value());
  unsigned nextOutArg = origNumInputs;
  for (unsigned idx = 0; idx < origNumResults; ++idx) {
    if (!needOut[idx]) continue;
    resultToOutArg[idx] = entry.getArgument(nextOutArg);
    ++nextOutArg;
  }

  SmallVector<func::ReturnOp> returns;
  SmallVector<Value> origReturnValues;
  if (failed(collectOriginalReturnValues(func, origNumResults, returns, origReturnValues))) return failure();
  if (failed(handleAllReturns(func, origNumResults, resultToOutArg, origReturnValues))) return failure();

  rebuildAllReturns(returns, origNumResults, needOut, resultToOutArg);
  eraseDeadOps(func);
  return success();
}
/// Converts the results of `func` into trailing out-parameters.
///
/// - No results: only the HACC input attributes are set.
/// - Declaration (no body): every result type is appended to the inputs and
///   the arguments are annotated as inputs/outputs.
/// - Definition without any `func.return`: only input attributes are set.
/// - Otherwise, for each result whose returned value (in the first return op
///   found) is not one of the original entry-block inputs, an argument of the
///   result's type is appended to the signature and the entry block, the
///   arguments are annotated, and the body is rewritten through
///   rewriteReturnsAndAllocToUseOutParams.
///
/// The function's result types are kept as they are. `builder` is only used
/// for its context. Returns failure, after emitting a diagnostic, when the
/// representative return op or the entry block does not match the signature,
/// or when the body rewrite fails.
static LogicalResult transformFunc(func::FuncOp func, OpBuilder &builder) {
  auto funcTy = func.getFunctionType();
  auto *ctx = builder.getContext();
  SmallVector<Type> origInputs(funcTy.getInputs().begin(), funcTy.getInputs().end());
  SmallVector<Type> origResults(funcTy.getResults().begin(), funcTy.getResults().end());
  unsigned origNumInputs = origInputs.size();
  unsigned origNumResults = origResults.size();

  if (origNumResults == 0) {
    setHaccIOArgAttrs(func, origNumInputs, 0, builder);
    return success();
  }

  if (func.empty()) {
    // Declaration: there is no body to inspect, so every result gets an out-parameter.
    SmallVector<Type> newInputs;
    newInputs.reserve(origNumInputs + origNumResults);
    newInputs.append(origInputs.begin(), origInputs.end());
    newInputs.append(origResults.begin(), origResults.end());
    auto newFuncTy = FunctionType::get(ctx, newInputs, origResults);
    func.setFunctionType(newFuncTy);
    setHaccIOArgAttrs(func, origNumInputs, origNumResults, builder);
    return success();
  }

  // The first return op visited decides which results need an out-parameter.
  func::ReturnOp reprRet;
  func.walk([&](func::ReturnOp r) {
    if (!reprRet) reprRet = r;
  });
  if (!reprRet) {
    setHaccIOArgAttrs(func, origNumInputs, 0, builder);
    return success();
  }
  if (reprRet.getNumOperands() != origNumResults) {
    reprRet.emitError() << "number of return values (" << reprRet.getNumOperands()
                        << ") does not match number of function results (" << origNumResults << ")";
    // Stop here: the loop below indexes the operands by result position and
    // would read out of range.
    return failure();
  }

  // Validate the entry block before touching the signature so a failure
  // leaves the function unmodified.
  Block &entry = func.front();
  if (entry.getNumArguments() != origNumInputs) {
    func.emitError() << "unexpected number of entry block arguments before transform: got " << entry.getNumArguments()
                     << ", expected " << origNumInputs;
    return failure();
  }

  // A result that simply returns an original input needs no out-parameter.
  SmallVector<bool, 4> needOut(origNumResults, true);
  for (unsigned i = 0; i < origNumResults; ++i) {
    Value rv = reprRet.getOperand(i);
    if (isAliasOfOriginalInput(rv, func, origNumInputs)) {
      needOut[i] = false;
    }
  }

  SmallVector<Type> newInputs;
  newInputs.reserve(origNumInputs + origNumResults);
  newInputs.append(origInputs.begin(), origInputs.end());
  unsigned nOutputs = 0;
  for (unsigned i = 0; i < origNumResults; ++i) {
    if (needOut[i]) {
      newInputs.push_back(origResults[i]);
      ++nOutputs;
    }
  }
  auto newFuncTy = FunctionType::get(ctx, newInputs, origResults);
  func.setFunctionType(newFuncTy);

  // Keep the entry block arguments in sync with the new signature.
  for (unsigned i = 0; i < origNumResults; ++i) {
    if (needOut[i]) {
      (void)entry.addArgument(origResults[i], func.getLoc());
    }
  }

  setHaccIOArgAttrs(func, origNumInputs, nOutputs, builder);
  if (failed(rewriteReturnsAndAllocToUseOutParams(func, origNumInputs, origNumResults, needOut))) return failure();
  return success();
}
/// Module pass that applies transformFunc to every `func.func` in the module.
/// Processing stops at the first function that fails: an error is emitted on
/// that function and the pass is marked as failed.
struct AddOutParameter : public mlir::impl::AddOutParameterBase<AddOutParameter> {
  AddOutParameter() = default;

  void runOnOperation() override {
    ModuleOp module = getOperation();
    if (!module) {
      signalPassFailure();
      return;
    }

    OpBuilder builder(module.getContext());

    // Collect first: transformFunc mutates the functions it is given.
    SmallVector<func::FuncOp, 8> worklist;
    module.walk([&worklist](func::FuncOp funcOp) { worklist.push_back(funcOp); });

    for (func::FuncOp funcOp : worklist) {
      if (failed(transformFunc(funcOp, builder))) {
        funcOp.emitError("AddOutParameter pass failed for this function");
        signalPassFailure();
        return;
      }
    }
  }
};
}
}
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> mlir::createAddOutParameterPass() {
return std::make_unique<mlir::AddOutParameter>();
}