* Copyright 2023-2026 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "akg/Transforms/Passes.h"
#include "akg/Utils/AnalysisCommon.hpp"
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/Passes.h"
namespace mlir {
#ifndef GEN_PASS_DECL_STORELOADELIM
#define GEN_PASS_DECL_STORELOADELIM
#ifndef GEN_PASS_DEF_STORELOADELIM
#define GEN_PASS_DEF_STORELOADELIM
#ifndef GEN_PASS_CLASSES
#define GEN_PASS_CLASSES
#include "akg/Transforms/Passes.h.inc"
#endif
#endif
#endif
}
#define DEBUG_TYPE "elim-store-load"
namespace mlir {
namespace {
/// Store-to-load forwarding for memref.store / affine.store.
///
/// For each store, a load of the same buffer is reported as forwardable when:
///   (1) both ops syntactically address the same element,
///   (2) the load runs after the store within the same execution of the store's block (the load is either
///       later in that block or nested inside an op that is later in that block), and
///   (3) no other user of the buffer that may write memory can run between the two.
/// Buffers that get another handle through one of their users (a view/cast or any op with a memref-typed result,
/// or being stored as a value) are skipped entirely, because writes through such handles are not visible here.
/// NOTE(review): aliasing that is not visible as a use of the buffer's SSA value (e.g. two function arguments
/// referring to the same memory, or the buffer itself being a view of another buffer) is not detected.
struct StoreLoadElimPass : public StoreLoadElimBase<StoreLoadElimPass> {
 public:
  void runOnOperation() override;

 private:
  /// Returns the SSA result of an affine.load or memref.load; asserts on any other op.
  Value getLoadResult(Operation *loadOp) const {
    if (auto affineLoad = dyn_cast<affine::AffineLoadOp>(loadOp)) {
      return affineLoad.getResult();
    }
    if (auto plainLoad = dyn_cast<memref::LoadOp>(loadOp)) {
      return plainLoad.getResult();
    }
    assert(false && "can only get result from AffineLoad or memref::LoadOp.");
    return Value();
  }

  /// Returns true if two loads/stores of the same flavour syntactically address the same element.
  /// Affine ops must have identical unified access maps and operands. memref ops must have identical index values.
  /// Mixed affine/memref pairs are reported as "not the same".
  bool accessSameLocation(Operation *op1, Operation *op2) const {
    SmallVector<Value, 4> vals1, vals2;
    if (isa<affine::AffineStoreOp, affine::AffineLoadOp>(op1) &&
        isa<affine::AffineStoreOp, affine::AffineLoadOp>(op2)) {
      AffineMap map1, map2;
      CommonUtils::getUnifiedAffineAccess(op1, map1, vals1);
      CommonUtils::getUnifiedAffineAccess(op2, map2, vals2);
      if (!map1 || !map2 || map1 != map2) return false;
    } else if (isa<memref::StoreOp, memref::LoadOp>(op1) && isa<memref::StoreOp, memref::LoadOp>(op2)) {
      llvm::append_range(vals1, CommonUtils::getStoreLoadIndices(op1));
      llvm::append_range(vals2, CommonUtils::getStoreLoadIndices(op2));
    } else {
      return false;
    }
    return vals1 == vals2;
  }

  /// Returns true only when the two accesses are guaranteed to touch different elements.
  /// Affine: same operands and some result dimension differs by a non-zero constant (e.g. A[i] vs A[i + 1]).
  /// memref: some index pair consists of two different integer constants.
  /// Anything else is conservatively treated as possibly overlapping.
  bool provablyDisjoint(Operation *op1, Operation *op2) const {
    if (isa<affine::AffineStoreOp, affine::AffineLoadOp>(op1) &&
        isa<affine::AffineStoreOp, affine::AffineLoadOp>(op2)) {
      AffineMap map1, map2;
      SmallVector<Value, 4> vals1, vals2;
      CommonUtils::getUnifiedAffineAccess(op1, map1, vals1);
      CommonUtils::getUnifiedAffineAccess(op2, map2, vals2);
      // The dimension/symbol positions of both maps only mean the same thing when the operand lists are equal.
      if (!map1 || !map2 || vals1 != vals2 || map1.getNumDims() != map2.getNumDims() ||
          map1.getNumSymbols() != map2.getNumSymbols() || map1.getNumResults() != map2.getNumResults()) {
        return false;
      }
      unsigned numDims = map1.getNumDims();
      unsigned numSymbols = map1.getNumSymbols();
      for (unsigned i = 0; i < map1.getNumResults(); ++i) {
        AffineExpr diff = simplifyAffineExpr(map1.getResult(i) - map2.getResult(i), numDims, numSymbols);
        AffineMap diffMap = AffineMap::get(numDims, numSymbols, diff);
        if (diffMap.isSingleConstant() && diffMap.getSingleConstantResult() != 0) {
          return true;
        }
      }
      return false;
    }
    if (isa<memref::StoreOp, memref::LoadOp>(op1) && isa<memref::StoreOp, memref::LoadOp>(op2)) {
      SmallVector<Value, 4> idx1, idx2;
      llvm::append_range(idx1, CommonUtils::getStoreLoadIndices(op1));
      llvm::append_range(idx2, CommonUtils::getStoreLoadIndices(op2));
      if (idx1.size() != idx2.size()) return false;
      for (size_t i = 0; i < idx1.size(); ++i) {
        APInt cst1, cst2;
        if (matchPattern(idx1[i], m_ConstantInt(&cst1)) && matchPattern(idx2[i], m_ConstantInt(&cst2)) &&
            cst1.getSExtValue() != cst2.getSExtValue()) {
          return true;
        }
      }
    }
    return false;
  }

  /// Returns true if `user` (a non-load user of `buffer`) gives out another handle to the buffer.
  /// This is either a store that writes `buffer` itself as the stored value, or any op with a memref-typed result
  /// (views, casts, ...). Writes through such handles cannot be tracked by this pass.
  bool exposesBuffer(Operation *user, Value buffer) const {
    if (isa<memref::StoreOp, affine::AffineStoreOp>(user)) {
      return CommonUtils::getStoreMemref(user) != buffer;
    }
    return llvm::any_of(user->getResultTypes(), [](Type type) { return isa<BaseMemRefType>(type); });
  }

  /// Returns true if `user` may write memory. Stores always do. Ops that do not implement
  /// MemoryEffectOpInterface are conservatively treated as writers.
  bool mayWriteMemory(Operation *user) const {
    if (isa<memref::StoreOp, affine::AffineStoreOp>(user)) return true;
    auto effects = dyn_cast<MemoryEffectOpInterface>(user);
    return !effects || effects.hasEffect<MemoryEffects::Write>();
  }

  /// Returns true if `writer` may execute after `storeOp` and before the load.
  /// `loadAncestor` is the load's ancestor in the store's block and is known to follow `storeOp`.
  bool mayExecuteBetween(Operation *storeOp, Operation *loadAncestor, Operation *writer) const {
    // A writer that encloses the store (a region op taking the buffer as operand) is not ordered w.r.t. it.
    if (writer->isProperAncestor(storeOp)) return true;
    Operation *writerAncestor = storeOp->getBlock()->findAncestorOpInBlock(*writer);
    // Ops outside the store's block cannot run between the store and a later op of the same block instance.
    if (!writerAncestor) return false;
    // Same enclosing op as the load (e.g. the same loop): it may run before the load on some iteration.
    if (writerAncestor == loadAncestor) return true;
    return storeOp->isBeforeInBlock(writerAncestor) && writerAncestor->isBeforeInBlock(loadAncestor);
  }

  /// Returns true if any potential writer of the buffer may clobber the element between `storeOp` and `loadOp`.
  /// Stores that provably touch a different element than the load are ignored.
  bool hasInterveningWrite(Operation *storeOp, Operation *loadOp, Operation *loadAncestor,
                           ArrayRef<Operation *> writers) const {
    for (Operation *writer : writers) {
      bool isStore = isa<memref::StoreOp, affine::AffineStoreOp>(writer);
      if (isStore && provablyDisjoint(writer, loadOp)) continue;
      if (mayExecuteBetween(storeOp, loadAncestor, writer)) return true;
    }
    return false;
  }

  /// Returns the loads whose result may be replaced by the value written by `storeOp`.
  /// Returns an empty list when the store's target is not a MemRefType value or the buffer is exposed
  /// through another handle.
  SmallVector<Operation *> getPossibleElimLoads(Operation *storeOp) const {
    Value buffer = CommonUtils::getStoreMemref(storeOp);
    if (!buffer || !isa<MemRefType>(buffer.getType())) {
      return {};
    }
    SmallVector<Operation *> candidateLoads;
    SmallVector<Operation *> writers;
    for (Operation *user : buffer.getUsers()) {
      if (user == storeOp) continue;
      if (isa<memref::LoadOp, affine::AffineLoadOp>(user)) {
        candidateLoads.push_back(user);
        continue;
      }
      if (exposesBuffer(user, buffer)) return {};
      if (mayWriteMemory(user)) writers.push_back(user);
    }

    Block *storeBlock = storeOp->getBlock();
    SmallVector<Operation *> elimLoads;
    for (Operation *loadOp : candidateLoads) {
      if (!accessSameLocation(storeOp, loadOp)) continue;
      // The load must run after the store in the same execution of the store's block. Loads nested in an op that
      // precedes the store, or living in another block, are rejected.
      Operation *loadAncestor = storeBlock->findAncestorOpInBlock(*loadOp);
      if (!loadAncestor || !storeOp->isBeforeInBlock(loadAncestor)) continue;
      if (hasInterveningWrite(storeOp, loadOp, loadAncestor, writers)) continue;
      elimLoads.push_back(loadOp);
    }
    return elimLoads;
  }
};
void StoreLoadElimPass::runOnOperation() {
DominanceInfo &domInfo = getAnalysis<DominanceInfo>();
SmallVector<Operation *> toElimLoads;
llvm::DenseSet<Operation *> processedLoads;
llvm::SmallDenseSet<Value> affectedMemrefs;
getOperation()->walk([&](Operation *op) {
if (!isa<memref::StoreOp, affine::AffineStoreOp>(op)) {
return;
}
auto memref = CommonUtils::getStoreMemref(op);
if (!memref || !isa<MemRefType>(memref.getType())) {
return;
}
auto elimLoads = getPossibleElimLoads(op);
for (auto loadOp : elimLoads) {
if (processedLoads.count(loadOp)) {
continue;
}
processedLoads.insert(loadOp);
Value storeValue = CommonUtils::getStoreValue(op);
if (!domInfo.properlyDominates(storeValue, loadOp)) {
continue;
}
Value loadResult = getLoadResult(loadOp);
loadResult.replaceAllUsesWith(storeValue);
if (loadOp->use_empty()) {
toElimLoads.push_back(loadOp);
affectedMemrefs.insert(memref);
}
}
});
for (auto loadOp : toElimLoads) {
loadOp->erase();
}
for (auto memref : affectedMemrefs) {
auto memrefOp = memref.getDefiningOp();
if (!memrefOp) continue;
SmallVector<Operation *> storesToErase;
bool onlyStores = true;
for (auto user : memref.getUsers()) {
if (isa<memref::StoreOp, affine::AffineStoreOp>(user)) {
storesToErase.push_back(user);
} else {
onlyStores = false;
break;
}
}
if (!onlyStores) continue;
for (auto storeOp : storesToErase) {
storeOp->erase();
}
if (memrefOp->use_empty()) {
memrefOp->erase();
}
}
}
}
/// Creates a fresh instance of the store-to-load forwarding pass, returned as a generic Pass owner.
std::unique_ptr<Pass> createStoreLoadElimPass() {
  std::unique_ptr<Pass> pass = std::make_unique<StoreLoadElimPass>();
  return pass;
}
}