* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "akg/Transforms/CopyRemoval.h"
#include <algorithm>
#include <iterator>
#include "akg/Transforms/Passes.h"
#include "llvm/IR/Module.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Interfaces/CopyOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/Passes.h"
// Pull in the TableGen-generated pass code for this pass.
// NOTE(review): presumably this is where `CopyRemovalBase` (used below) is
// defined — confirm against the akg/Transforms pass .td file.
// If GEN_PASS_DECL_COPYREMOVAL is already defined, the whole include is
// skipped; otherwise all three selector macros are defined before including.
namespace mlir {
#ifndef GEN_PASS_DECL_COPYREMOVAL
#define GEN_PASS_DECL_COPYREMOVAL
#ifndef GEN_PASS_DEF_COPYREMOVAL
#define GEN_PASS_DEF_COPYREMOVAL
#ifndef GEN_PASS_CLASSES
#define GEN_PASS_CLASSES
#include "akg/Transforms/Passes.h.inc"
#endif
#endif
#endif
}  // namespace mlir
// Debug tag conventionally consumed by LLVM_DEBUG / -debug-only; no LLVM_DEBUG
// use is visible in this file.
#define DEBUG_TYPE "copy-removal"
using namespace mlir;
// Lets `Allocate` / `Free` be written unqualified below.
using namespace MemoryEffects;
namespace {
/// Removes copies expressed through `CopyOpInterface` inside functions.
///
/// For a copy `copy(src, dest)` the pass replaces every use of `src` with
/// `dest`, schedules `src`'s defining op (and `src`'s dealloc, if any) for
/// erasure, and schedules the copy itself for erasure when `dest` is not used
/// again between the copy and the last op that bounds `dest`'s lifetime.
struct CopyRemovalPass : public CopyRemovalBase<CopyRemovalPass> {
 public:
  void runOnOperation() override;

 private:
  /// Returns the op defining `value` if that op declares an `Allocate` memory
  /// effect; otherwise returns nullptr (including for block arguments).
  Operation *getAllocationOp(Value value) const {
    Operation *op = value.getDefiningOp();
    if (op == nullptr) {
      return nullptr;
    }
    auto effects = dyn_cast<MemoryEffectOpInterface>(op);
    return (effects && effects.hasEffect<Allocate>()) ? op : nullptr;
  }

  /// Returns the first user of `value` that declares a `Free` memory effect,
  /// or nullptr if there is none.
  Operation *getDeallocationOp(Value value) const {
    auto valueUsers = value.getUsers();
    auto it = llvm::find_if(valueUsers, [](Operation *op) {
      auto effects = dyn_cast<MemoryEffectOpInterface>(op);
      return effects && effects.hasEffect<Free>();
    });
    return it == valueUsers.end() ? nullptr : *it;
  }

  /// Conservatively reports whether `op` may write memory through `val`.
  ///
  /// - Ops implementing MemoryEffectOpInterface: true iff `op` directly uses
  ///   `val` and declares a Write effect (on any value).
  /// - Ops with recursive memory effects: true iff some op nested in one of
  ///   their regions satisfies this predicate.
  /// - Any other op: true, because its effects are unknown.
  static bool doesOpHaveWriteEffect(Value val, Operation *op) {
    if (auto memEffect = dyn_cast<MemoryEffectOpInterface>(op)) {
      if (!llvm::is_contained(val.getUsers(), op)) {
        return false;
      }
      SmallVector<MemoryEffects::EffectInstance, 1> effects;
      memEffect.getEffects(effects);
      return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
        return isa<MemoryEffects::Write>(effect.getEffect());
      });
    }
    if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>()) {
      for (Region &region : op->getRegions()) {
        auto walkResult = region.walk([&](Operation *nestedOp) {
          return doesOpHaveWriteEffect(val, nestedOp) ? WalkResult::interrupt() : WalkResult::advance();
        });
        if (walkResult.wasInterrupted()) {
          return true;
        }
      }
      return false;
    }
    return true;
  }

  /// Returns true if `op` is a direct user of `val` (nested uses inside `op`'s
  /// regions are not counted).
  static bool doesOpUseVal(Value val, Operation *op) { return llvm::is_contained(val.getUsers(), op); }

  /// Returns true if some op encountered after `start` and before `end`
  /// satisfies `checkPropertiesOfOperation(val, op)`.
  ///
  /// - If `end` is nested below an op of `start`'s region, the search runs up
  ///   to that enclosing op, and the enclosing op itself is also checked.
  /// - If both are in the same region, the rest of `start`'s block is scanned.
  ///   When `end` lives in another block, every block reachable through CFG
  ///   successors is then scanned up to `end` (each block at most once).
  /// - If `start`'s region does not enclose `end`'s region, returns false.
  bool hasInterveningOp(const Value val, Operation *start, Operation *end,
                        std::function<bool(Value, Operation *)> checkPropertiesOfOperation) const {
    std::function<bool(Operation *, Operation *)> recur = [&](Operation *fromOp, Operation *untilOp) {
      Region *fromRegion = fromOp->getParentRegion();
      Region *untilRegion = untilOp->getParentRegion();
      if (!fromRegion->isAncestor(untilRegion)) {
        return false;
      }
      if (fromRegion != untilRegion) {
        // `untilOp` is nested deeper than `fromOp`: search up to its enclosing
        // op. The recursive result must be honoured; discarding it would miss
        // matching ops between `fromOp` and that enclosing op.
        Operation *untilParentOp = untilOp->getParentOp();
        return recur(fromOp, untilParentOp) || checkPropertiesOfOperation(val, untilParentOp);
      }

      Block *fromBlock = fromOp->getBlock();
      // Scan the remainder of `fromOp`'s own block (stopping at `untilOp`).
      for (auto it = std::next(fromOp->getIterator()), blockEnd = fromBlock->end();
           it != blockEnd && &*it != untilOp; ++it) {
        if (checkPropertiesOfOperation(val, &*it)) {
          return true;
        }
      }
      if (untilOp->getBlock() == fromBlock) {
        return false;
      }

      // Follow CFG successors until `untilOp` is reached on each path.
      SmallVector<Block *, 2> worklist;
      for (Block *succ : fromBlock->getSuccessors()) {
        worklist.push_back(succ);
      }
      SmallPtrSet<Block *, 4> visited;
      while (!worklist.empty()) {
        Block *blk = worklist.pop_back_val();
        // `insert(...).second` is false when `blk` was already scanned; the
        // CFG may contain cycles, so skip revisits (and only revisits).
        if (!visited.insert(blk).second) {
          continue;
        }
        bool reachedUntilOp = false;
        for (Operation &op : *blk) {
          if (&op == untilOp) {
            reachedUntilOp = true;
            break;
          }
          if (checkPropertiesOfOperation(val, &op)) {
            return true;
          }
        }
        // A path only continues past this block if `untilOp` was not in it.
        if (!reachedUntilOp) {
          for (Block *succ : blk->getSuccessors()) {
            worklist.push_back(succ);
          }
        }
      }
      return false;
    };
    return recur(start, end);
  }

  /// Replaces every use of the copy's source with the copy's target.
  void replaceDest4StoreOp(CopyOpInterface copyOp) const {
    Value src = copyOp.getSource();
    Value dest = copyOp.getTarget();
    src.replaceAllUsesWith(dest);
  }

  /// Forwards `dest` into all uses of `src` for `copyOp` and records the ops
  /// that become dead in `opsToErase`.
  ///
  /// The IR is left untouched when:
  /// - `copyOp` is already scheduled for erasure;
  /// - `src` has no defining op (e.g. it is a block argument);
  /// - `src` is produced by memref.get_global / expand_shape / collapse_shape /
  ///   reshape / subview.
  ///
  /// Otherwise `src`'s dealloc (if any) and `src`'s defining op are scheduled
  /// for erasure, the copy is scheduled when `dest` is not used after it, and
  /// all uses of `src` are replaced with `dest`.
  void removeCopy(CopyOpInterface copyOp, llvm::SmallPtrSet<Operation *, 4> &opsToErase) {
    if (opsToErase.count(copyOp) != 0) {
      return;
    }
    Value src = copyOp.getSource();
    Value dest = copyOp.getTarget();
    Operation *srcDefOp = src.getDefiningOp();
    // Without a defining op there is nothing to erase. Bailing out here also
    // avoids calling isa<> on a null pointer and inserting nullptr into
    // `opsToErase`.
    if (srcDefOp == nullptr ||
        isa<memref::GetGlobalOp, memref::ExpandShapeOp, memref::CollapseShapeOp, memref::ReshapeOp,
            memref::SubViewOp>(srcDefOp)) {
      return;
    }

    // Without a dealloc of `dest`, the last op of the last block in `src`'s
    // region bounds the search for later uses of `dest`.
    Operation *lastOpUsingDest = &src.getParentRegion()->back().back();
    if (Operation *srcDeallocOp = getDeallocationOp(src)) {
      (void)opsToErase.insert(srcDeallocOp);
    }
    Operation *destDeallocOp = getDeallocationOp(dest);
    if (destDeallocOp != nullptr) {
      lastOpUsingDest = destDeallocOp;
    }

    // Drop the copy only if nothing between it and `lastOpUsingDest` uses
    // `dest`, and `lastOpUsingDest` itself either does not use `dest` or is
    // `dest`'s dealloc.
    if (!hasInterveningOp(dest, copyOp, lastOpUsingDest, &doesOpUseVal) &&
        (!doesOpUseVal(dest, lastOpUsingDest) || destDeallocOp != nullptr)) {
      (void)opsToErase.insert(copyOp);
    }

    // NOTE(review): the replacement below happens even when the copy is kept.
    // Nothing here checks that `dest` is defined before every use of `src`, or
    // that `dest` is not read while `src` is being produced. Confirm that the
    // IR this pass runs on guarantees both.
    (void)opsToErase.insert(srcDefOp);
    replaceDest4StoreOp(copyOp);
  }
};
void CopyRemovalPass::runOnOperation() {
SmallVector<func::FuncOp, 2> funcs;
getOperation()->walk([&](func::FuncOp func) { funcs.push_back(func); });
for (auto func : funcs) {
llvm::SmallPtrSet<Operation *, 4> opsToErase;
func.walk([&](CopyOpInterface copyOp) { removeCopy(copyOp, opsToErase); });
for (Operation *op : opsToErase) {
assert(op->use_empty() &&
"uses remaining for copy ops, memref allocation and deallocation "
"ops that should have ready to be erased");
op->erase();
}
}
return;
}
}
/// Factory for the copy-removal pass; the caller owns the returned pass.
std::unique_ptr<Pass> mlir::createCopyRemovalPass() {
  std::unique_ptr<Pass> pass = std::make_unique<CopyRemovalPass>();
  return pass;
}