* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "akg/Transforms/CopyElision.h"
#include <algorithm>
#include "mlir/Interfaces/CopyOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
namespace mlir {
#ifndef GEN_PASS_DEF_COPYELISION
#define GEN_PASS_DEF_COPYELISION
#include "akg/Transforms/Passes.h.inc"
#endif
}
#define DEBUG_TYPE "copy-elision"
namespace mlir {
// Inline capacity of the small containers used by the pass.
static constexpr const int kVectorSizeFour = 4;

/// Removes redundant copies between memref buffers.
///
/// Every operation implementing `CopyOpInterface` whose target is not a block
/// argument (see `runOnOperation`) goes through two rewrites:
///  1. `reuseCopySourceAsTarget` redirects all uses of the copy target to the
///     copy source (through a `memref.cast` when the types differ), so that the
///     copy, the target allocation and one of the two deallocations can go.
///  2. `removeCopy` drops a copy whose target is never used afterwards.
/// Both rewrites only record operations in `opsToErase`; the erasure itself is
/// done by `runOnOperation` after the walk.
///
/// NOTE(review): the analysis only tracks uses of the `src`/`dest` SSA values.
/// Accesses through other values that alias the same memory (views created
/// from `src` or `dest`, the base buffer of a view used as `src`) are not seen.
struct CopyElisionPass : public mlir::impl::CopyElisionBase<CopyElisionPass> {
 public:
  void runOnOperation() override;
  CopyElisionPass() = default;
  CopyElisionPass(const CopyElisionPass &pass) = default;

 private:
  /// Returns the defining operation of `value` when that operation reports a
  /// `MemoryEffects::Allocate` effect; nullptr otherwise (block arguments and
  /// results of operations without such an effect).
  Operation *getAllocationOp(Value value) const {
    if (Operation *op = value.getDefiningOp()) {
      if (auto effects = dyn_cast<MemoryEffectOpInterface>(op)) {
        if (effects.hasEffect<MemoryEffects::Allocate>()) {
          return op;
        }
      }
    }
    return nullptr;
  }

  /// Returns a user of `value` that reports a `MemoryEffects::Free` effect, or
  /// nullptr if there is none. Users contained in `ignoredOps` are skipped:
  /// once an earlier rewrite merged two buffers, the surviving value may still
  /// have a deallocation that is only scheduled for erasure, and picking that
  /// one would make the analysis reason about the wrong lifetime.
  Operation *getDeallocationOp(Value value, const llvm::SmallPtrSetImpl<Operation *> &ignoredOps) const {
    auto valueUsers = value.getUsers();
    auto it = llvm::find_if(valueUsers, [&](Operation *op) {
      if (ignoredOps.count(op) != 0) {
        return false;
      }
      auto effects = dyn_cast<MemoryEffectOpInterface>(op);
      return effects && effects.hasEffect<MemoryEffects::Free>();
    });
    return it == valueUsers.end() ? nullptr : *it;
  }

  /// Conservatively decides whether `op` may write memory for `val`:
  ///  - ops implementing `MemoryEffectOpInterface`: true iff `op` uses `val`
  ///    directly and reports at least one `Write` effect (on any value);
  ///  - ops with `HasRecursiveMemoryEffects`: true iff an op nested in one of
  ///    its regions satisfies this predicate;
  ///  - any other op: assumed to write (true).
  static bool doesOpHaveWriteEffect(Value val, Operation *op) {
    if (auto memEffect = dyn_cast<MemoryEffectOpInterface>(op)) {
      if (!llvm::is_contained(val.getUsers(), op)) {
        return false;
      }
      SmallVector<MemoryEffects::EffectInstance, 1> effects;
      memEffect.getEffects(effects);
      return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
        return isa<MemoryEffects::Write>(effect.getEffect());
      });
    }
    if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>()) {
      for (Region &region : op->getRegions()) {
        WalkResult walkResult = region.walk([&](Operation *nestedOp) {
          return doesOpHaveWriteEffect(val, nestedOp) ? WalkResult::interrupt() : WalkResult::advance();
        });
        if (walkResult.wasInterrupted()) {
          return true;
        }
      }
      return false;
    }
    return true;
  }

  /// Returns true if `op`, or any operation nested in its regions, has `val` as
  /// an operand. Nested users must count: a loop whose body reads `val` is a
  /// use of `val` from the point of view of the enclosing block.
  static bool doesOpUseVal(Value val, Operation *op) {
    return llvm::any_of(val.getUsers(), [op](Operation *user) { return op->isAncestor(user); });
  }

  /// Returns true when `to` is located in the region of `from` or in a region
  /// nested inside it, i.e. when `hasInterveningOp(val, from, to, ...)` really
  /// scans the operations between the two instead of giving up.
  static bool canScanForward(Operation *from, Operation *to) {
    return from->getParentRegion()->isAncestor(to->getParentRegion());
  }

  /// Returns true if some operation executed after `start` and before `end`
  /// satisfies `checkPropertiesOfOperation(val, op)`; `start` and `end` are not
  /// checked themselves.
  ///
  /// When `end` is nested deeper than `start`, the operation enclosing `end` is
  /// checked as a whole (a conservative stand-in for the operations preceding
  /// `end` inside it) in addition to the paths leading to it. Inside a single
  /// region the CFG is followed through block successors until `end` is hit.
  ///
  /// NOTE(review): when `end` is neither in `start`'s region nor nested in it,
  /// nothing is scanned and false is returned; callers that cannot accept that
  /// answer must check `canScanForward` first.
  bool hasInterveningOp(Value val, Operation *start, Operation *end,
                        const std::function<bool(Value, Operation *)> &checkPropertiesOfOperation) const {
    std::function<bool(Operation *, Operation *)> recur = [&](Operation *fromOp, Operation *untilOp) -> bool {
      Region *fromRegion = fromOp->getParentRegion();
      Region *untilRegion = untilOp->getParentRegion();
      if (!fromRegion->isAncestor(untilRegion)) {
        return false;
      }
      if (fromRegion != untilRegion) {
        // The ops between `fromOp` and the op enclosing `untilOp` are covered by
        // the recursive call, so its answer must be propagated.
        Operation *untilParentOp = untilOp->getParentOp();
        return recur(fromOp, untilParentOp) || checkPropertiesOfOperation(val, untilParentOp);
      }
      // Same region: first the rest of `fromOp`'s own block.
      Block *fromBlock = fromOp->getBlock();
      for (auto iter = ++fromOp->getIterator(), blockEnd = fromBlock->end(); iter != blockEnd && &*iter != untilOp;
           ++iter) {
        if (checkPropertiesOfOperation(val, &*iter)) {
          return true;
        }
      }
      if (untilOp->getBlock() == fromBlock) {
        return false;
      }
      // Then every block reachable from it, stopping at `untilOp`.
      auto firstSuccessors = fromBlock->getSuccessors();
      SmallVector<Block *, 2> todoBlocks(firstSuccessors.begin(), firstSuccessors.end());
      SmallPtrSet<Block *, kVectorSizeFour> done;
      while (!todoBlocks.empty()) {
        Block *blk = todoBlocks.pop_back_val();
        // `insert(...).second` is true for a block seen for the first time;
        // only blocks that were already visited are skipped.
        if (!done.insert(blk).second) {
          continue;
        }
        for (Operation &op : *blk) {
          if (&op == untilOp) {
            break;
          }
          if (checkPropertiesOfOperation(val, &op)) {
            return true;
          }
          if (&op == blk->getTerminator()) {
            auto succ = blk->getSuccessors();
            todoBlocks.append(succ.begin(), succ.end());
          }
        }
      }
      return false;
    };
    return recur(start, end);
  }

  /// Returns true if `dest` is used after `srcDeallocOp`, scanning up to and
  /// including `lastOpOfCurrentRegion`; false when `srcDeallocOp` is null.
  bool isDestUsedAfterSrcDealloc(Value dest, Operation *srcDeallocOp, Operation *lastOpOfCurrentRegion) const {
    if (!srcDeallocOp) {
      return false;
    }
    return hasInterveningOp(dest, srcDeallocOp, lastOpOfCurrentRegion, &doesOpUseVal) ||
           doesOpUseVal(dest, lastOpOfCurrentRegion);
  }

  /// Decides whether every use of `dest` may be redirected to `src`.
  ///
  /// "After the copy" means up to `srcDeallocOp` when it exists, otherwise up to
  /// `lastOpOfCurrentRegion`. The rewrite is rejected when
  ///  - that end point (or the region end, seen from `srcDeallocOp`) cannot be
  ///    scanned from the copy, since the checks below would then see nothing;
  ///  - `dest` is used between its definition and the copy;
  ///  - `src` is written after the copy while `dest` is read or written after
  ///    the copy, or while `dest` stays live past the deallocation of `src`
  ///    (`destOutlivesSrc`), because `dest` would observe the new contents;
  ///  - `src` is read after the copy while `dest` is written after the copy.
  bool canReplaceDestWithSrc(CopyOpInterface copyOp, Value src, Value dest, Operation *srcDeallocOp,
                             Operation *destDefOp, Operation *lastOpOfCurrentRegion, bool destOutlivesSrc) const {
    Operation *lastOpUsingSrc = srcDeallocOp ? srcDeallocOp : lastOpOfCurrentRegion;
    if (!canScanForward(copyOp, lastOpUsingSrc) ||
        (srcDeallocOp && !canScanForward(srcDeallocOp, lastOpOfCurrentRegion))) {
      return false;
    }
    Operation *firstOpUsingDest = destDefOp ? destDefOp : &dest.getParentRegion()->front().front();
    bool isDestReadBefore =
        hasInterveningOp(dest, firstOpUsingDest, copyOp, &doesOpUseVal) || doesOpUseVal(dest, firstOpUsingDest);
    bool isDestWriteAfter = hasInterveningOp(dest, copyOp, lastOpUsingSrc, &doesOpHaveWriteEffect) ||
                            doesOpHaveWriteEffect(dest, lastOpUsingSrc);
    bool isSrcWriteAfter = hasInterveningOp(src, copyOp, lastOpUsingSrc, &doesOpHaveWriteEffect) ||
                           doesOpHaveWriteEffect(src, lastOpUsingSrc);
    bool isSrcReadAfter =
        hasInterveningOp(src, copyOp, lastOpUsingSrc, &doesOpUseVal) || doesOpUseVal(src, lastOpUsingSrc);
    bool isDestReadAfter =
        hasInterveningOp(dest, copyOp, lastOpUsingSrc, &doesOpUseVal) || doesOpUseVal(dest, lastOpUsingSrc);
    if (isDestReadBefore || (isSrcWriteAfter && (isDestReadAfter || isDestWriteAfter || destOutlivesSrc)) ||
        (isSrcReadAfter && isDestWriteAfter)) {
      return false;
    }
    return true;
  }

  /// Returns true when `src` can stand in for `dest` without any cast, i.e.
  /// both values have exactly the same type. For ranked memrefs this equals
  /// comparing shape, element type, memory space and layout one by one.
  bool isMemRefTypesCompatible(Value src, Value dest) const { return src.getType() == dest.getType(); }

  /// Replaces the copy target by the copy source when that is provably safe
  /// and records the copy, the target allocation and the now-redundant
  /// deallocation in `opsToErase`.
  void reuseCopySourceAsTarget(CopyOpInterface copyOp,
                               llvm::SmallPtrSet<Operation *, kVectorSizeFour> &opsToErase) const {
    if (opsToErase.count(copyOp) != 0) {
      return;
    }
    Value src = copyOp.getSource();
    Value dest = copyOp.getTarget();
    // Only a target allocated here can be dropped. Any other target (a view, a
    // global, the result of an unknown op) may alias memory observed through
    // another value, and redirecting its uses would lose the copied data.
    Operation *destDefOp = getAllocationOp(dest);
    if (!destDefOp) {
      return;
    }
    Operation *srcDeallocOp = getDeallocationOp(src, opsToErase);
    Operation *destDeallocOp = getDeallocationOp(dest, opsToErase);
    Operation *lastOpOfCurrentRegion = &src.getParentRegion()->back().back();
    // Must be evaluated before the RAUW below merges the two values.
    bool destOutlivesSrc = isDestUsedAfterSrcDealloc(dest, srcDeallocOp, lastOpOfCurrentRegion);
    if (!canReplaceDestWithSrc(copyOp, src, dest, srcDeallocOp, destDefOp, lastOpOfCurrentRegion,
                               destOutlivesSrc)) {
      return;
    }
    // Differing types are bridged with memref.cast; give up when such a cast
    // would be invalid IR (e.g. different memory spaces or element types).
    Type srcType = src.getType();
    Type destType = dest.getType();
    bool needsCast = !isMemRefTypesCompatible(src, dest);
    if (needsCast && !memref::CastOp::areCastCompatible(srcType, destType)) {
      return;
    }
    Value replacement = src;
    if (needsCast) {
      OpBuilder builder(copyOp);
      builder.setInsertionPointAfter(copyOp);
      replacement = builder.create<memref::CastOp>(copyOp.getLoc(), destType, src);
    }
    (void)opsToErase.insert(copyOp);
    (void)opsToErase.insert(destDefOp);
    // After the merge exactly one deallocation must survive: the later one.
    // If `dest` is used past `src`'s deallocation, keep `dest`'s (it now frees
    // the merged buffer); otherwise keep `src`'s. Erasing both would leak.
    if (destOutlivesSrc) {
      (void)opsToErase.insert(srcDeallocOp);
    } else if (destDeallocOp) {
      (void)opsToErase.insert(destDeallocOp);
    }
    dest.replaceAllUsesWith(replacement);
  }

  /// Records `copyOp` in `opsToErase` when its target is a local allocation
  /// that is not used after the copy (other than by its own deallocation).
  void removeCopy(CopyOpInterface copyOp, llvm::SmallPtrSet<Operation *, kVectorSizeFour> &opsToErase) const {
    if (opsToErase.count(copyOp) != 0) {
      return;
    }
    Value dest = copyOp.getTarget();
    // Memory not allocated here may be observed through other values.
    if (!getAllocationOp(dest)) {
      return;
    }
    Operation *destDeallocOp = getDeallocationOp(dest, opsToErase);
    // Without a deallocation, `dest` lives until the end of its own region.
    Operation *lastOpUsingDest = destDeallocOp ? destDeallocOp : &dest.getParentRegion()->back().back();
    if (!canScanForward(copyOp, lastOpUsingDest)) {
      return;
    }
    bool isDestUsedAfterCopy = hasInterveningOp(dest, copyOp, lastOpUsingDest, &doesOpUseVal);
    bool doesDestEscapeAtEnd = !destDeallocOp && doesOpUseVal(dest, lastOpUsingDest);
    if (!isDestUsedAfterCopy && !doesDestEscapeAtEnd) {
      (void)opsToErase.insert(copyOp);
    }
  }
};
void CopyElisionPass::runOnOperation() {
llvm::SmallPtrSet<Operation *, kVectorSizeFour> opsToErase;
getOperation()->walk([&](CopyOpInterface copyOp) {
if (isa<BlockArgument>(copyOp.getTarget())) {
return;
}
reuseCopySourceAsTarget(copyOp, opsToErase);
removeCopy(copyOp, opsToErase);
});
for (Operation *op : opsToErase) {
assert(op->use_empty() &&
"uses remaining for copy ops, memref allocation and deallocation "
"ops that should have ready to be erased");
op->erase();
}
return;
}
}
std::unique_ptr<mlir::Pass> mlir::createCopyElisionPass() { return std::make_unique<mlir::CopyElisionPass>(); }