* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "akg/Transforms/Passes.h"
#include "akg/Utils/AnalysisCommon.hpp"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/Passes.h"
namespace mlir {
// Pull the tablegen-generated pass code from Passes.h.inc into namespace mlir.
// Each nested guard defines one GEN_PASS_* selector macro before the include;
// because the guards are nested, the include is only reached when none of
// the three macros was already defined by an earlier header or translation
// unit section.
#ifndef GEN_PASS_DECL_STORELOADELIM
#define GEN_PASS_DECL_STORELOADELIM
#ifndef GEN_PASS_DEF_STORELOADELIM
#define GEN_PASS_DEF_STORELOADELIM
#ifndef GEN_PASS_CLASSES
#define GEN_PASS_CLASSES
#include "akg/Transforms/Passes.h.inc"
#endif
#endif
#endif
}
#define DEBUG_TYPE "elim-store-load"
namespace mlir {
namespace {
// Pass that forwards the value written by a memref/affine store to the loads
// reading the same memref, so that the loads (and, when every load could be
// removed, the store and its buffer) can be erased.
struct StoreLoadElimPass : public StoreLoadElimBase<StoreLoadElimPass> {
 public:
  void runOnOperation() override;

 private:
  // Returns the result value of `loadOp`, which must be an
  // affine::AffineLoadOp or a memref::LoadOp. Any other operation trips the
  // assertion; in builds without assertions a null Value is returned.
  Value getLoadResult(Operation *loadOp) const {
    if (auto affineLoad = dyn_cast<affine::AffineLoadOp>(loadOp)) {
      return affineLoad.getResult();
    }
    if (auto memrefLoad = dyn_cast<memref::LoadOp>(loadOp)) {
      return memrefLoad.getResult();
    }
    assert(false && "can only get result from AffineLoad or memref::LoadOp.");
    return Value();
  }

  // Collects the loads of the memref written by `storeOp` whose result may be
  // replaced by the stored value.
  //
  // Returns an empty vector (meaning "do not touch anything") when:
  //   * the store target is missing or is not a MemRefType,
  //   * the memref has any user other than `storeOp` and plain loads
  //     (e.g. a second store),
  //   * a load lives in a block that is not nested under the store's region,
  //   * a load, or the operation enclosing it in the store's block, comes
  //     before the store (write-after-read).
  //
  // NOTE(review): the access indices of the store and of the loads are not
  // compared here; presumably earlier passes guarantee that they address the
  // same element -- confirm against the pipeline.
  SmallVector<Operation *> getPossibleElimLoads(Operation *storeOp) const {
    SmallVector<Operation *> elimLoads;
    auto memref = CommonUtils::getStoreMemref(storeOp);
    if (!memref || !isa<MemRefType>(memref.getType())) {
      return {};
    }
    Block *storeBlock = storeOp->getBlock();
    for (Operation *user : memref.getUsers()) {
      if (user == storeOp) {
        continue;
      }
      // Any non-load user makes forwarding unsafe for the whole memref.
      if (!isa<memref::LoadOp, affine::AffineLoadOp>(user)) {
        return {};
      }
      Block *loadBlock = user->getBlock();
      if (loadBlock == storeBlock) {
        // Same block: the load must not precede the store.
        if (user->isBeforeInBlock(storeOp)) {
          return {};
        }
      } else {
        Region *storeRegion = storeBlock->getParent();
        Region *loadRegion = loadBlock->getParent();
        bool isNestBranch = storeRegion && loadRegion && storeRegion->isAncestor(loadRegion);
        if (!isNestBranch) {
          return {};
        }
        // The load sits in a nested region. Find the operation in the store's
        // block that encloses it; if that operation precedes the store, the
        // load executes before the write and the stored value would not
        // dominate the load's users.
        Operation *enclosingOp = storeBlock->findAncestorOpInBlock(*user);
        if (enclosingOp && enclosingOp->isBeforeInBlock(storeOp)) {
          return {};
        }
      }
      elimLoads.push_back(user);
    }
    return elimLoads;
  }
};
void StoreLoadElimPass::runOnOperation() {
SmallVector<Operation *> toElimStores;
SmallVector<Operation *> toElimLoads;
getOperation()->walk([&](Operation *op) {
if (dyn_cast<memref::StoreOp>(op) || dyn_cast<affine::AffineStoreOp>(op)) {
auto memref = CommonUtils::getStoreMemref(op);
if (!memref || !isa<MemRefType>(memref.getType())) {
return;
}
auto elimLoads = getPossibleElimLoads(op);
size_t eraseSize = 0;
for (auto loadOp : elimLoads) {
Value storeValue = CommonUtils::getStoreValue(op);
Value loadResult = getLoadResult(loadOp);
loadResult.replaceAllUsesWith(storeValue);
if (loadOp->use_empty()) {
toElimLoads.push_back(loadOp);
eraseSize++;
}
}
bool isGlobalBuffer = memref.getDefiningOp() == nullptr;
bool elimAllLoads = eraseSize > 0 && eraseSize == elimLoads.size();
if (elimAllLoads && !isGlobalBuffer) {
toElimStores.push_back(op);
}
}
});
for (auto loadOp : toElimLoads) {
loadOp->erase();
}
for (auto storeOp : toElimStores) {
auto memref = CommonUtils::getStoreMemref(storeOp);
if (storeOp->use_empty()) {
storeOp->erase();
}
if (memref && isa<MemRefType>(memref.getType())) {
auto memrefOp = memref.getDefiningOp();
if (memrefOp && memrefOp->use_empty()) {
memrefOp->erase();
}
}
}
}
}
// Factory returning a new StoreLoadElimPass instance owned by the caller.
std::unique_ptr<Pass> createStoreLoadElimPass() {
  std::unique_ptr<Pass> pass = std::make_unique<StoreLoadElimPass>();
  return pass;
}
}