* Copyright 2026 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include "akg/Transforms/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Pass/Pass.h"
// Include the generated pass code from Passes.h.inc inside namespace mlir.
// The guards are nested: each GEN_PASS_* macro is only defined (and the
// include only reached) when all of the macros tested before it were not
// already defined in this translation unit.
// NOTE(review): AllocBufferShrinkBase, used further down, is presumably
// provided by this include - confirm against the generated file.
namespace mlir {
#ifndef GEN_PASS_DECL_ALLOCBUFFERSHRINK
#define GEN_PASS_DECL_ALLOCBUFFERSHRINK
#ifndef GEN_PASS_DEF_ALLOCBUFFERSHRINK
#define GEN_PASS_DEF_ALLOCBUFFERSHRINK
#ifndef GEN_PASS_CLASSES
#define GEN_PASS_CLASSES
#include "akg/Transforms/Passes.h.inc"
#endif
#endif
#endif
}
#define DEBUG_TYPE "alloc-buffer-shrink"
namespace mlir {
namespace {
// Sentinel stored in a dimension mapping when an allocation dimension has no
// corresponding index position in the view being accessed (for example a unit
// dimension that was folded away by a collapse_shape).
static constexpr int kDimAbsorbed = -1;
// One load/store reached from the allocation, together with the dimension
// mapping that was in effect for the value it accesses.
struct AccessInfo {
// The load or store operation itself.
Operation *op;
// dimMapping[d] is the index position of `op` that corresponds to dimension d
// of the original allocation, or kDimAbsorbed when there is none.
SmallVector<int> dimMapping;
};
/// Returns true when `op` is one of the four load/store operations handled by
/// this pass: the affine and the memref flavours of load and store.
static bool isAccessOp(Operation *op) {
  const bool isAffineAccess = isa<affine::AffineLoadOp, affine::AffineStoreOp>(op);
  const bool isMemRefAccess = isa<memref::LoadOp, memref::StoreOp>(op);
  return isAffineAccess || isMemRefAccess;
}
/// Returns the affine.for / scf.for operation whose induction variable is
/// `idx`, or nullptr when `idx` is not a block argument, its block is not owned
/// by one of those loops, or it is a block argument other than the induction
/// variable.
static Operation *getOwningLoopOp(Value idx) {
  auto arg = dyn_cast<BlockArgument>(idx);
  if (!arg) return nullptr;

  Operation *owner = arg.getOwner()->getParentOp();
  Value inductionVar;
  if (auto affineLoop = dyn_cast<affine::AffineForOp>(owner)) {
    inductionVar = affineLoop.getInductionVar();
  } else if (auto scfLoop = dyn_cast<scf::ForOp>(owner)) {
    inductionVar = scfLoop.getInductionVar();
  } else {
    return nullptr;
  }
  return inductionVar == arg ? owner : nullptr;
}
/// Returns a copy of the index operands of a load/store operation. Operations
/// that are not one of the four supported access ops yield an empty vector.
static SmallVector<Value> getAccessIndices(Operation *op) {
  // All four op classes expose getIndices(); copy the range into a vector.
  auto copyIndices = [](auto accessOp) {
    auto indexOperands = accessOp.getIndices();
    return SmallVector<Value>(indexOperands.begin(), indexOperands.end());
  };

  if (auto affineLoad = dyn_cast<affine::AffineLoadOp>(op)) return copyIndices(affineLoad);
  if (auto affineStore = dyn_cast<affine::AffineStoreOp>(op)) return copyIndices(affineStore);
  if (auto memrefLoad = dyn_cast<memref::LoadOp>(op)) return copyIndices(memrefLoad);
  if (auto memrefStore = dyn_cast<memref::StoreOp>(op)) return copyIndices(memrefStore);
  return {};
}
/// Returns the memref operand of a load/store operation, or a null Value when
/// `op` is not one of the four supported access ops.
static Value getAccessMemRef(Operation *op) {
  Value accessed;
  if (auto affineLoad = dyn_cast<affine::AffineLoadOp>(op)) {
    accessed = affineLoad.getMemRef();
  } else if (auto affineStore = dyn_cast<affine::AffineStoreOp>(op)) {
    accessed = affineStore.getMemRef();
  } else if (auto memrefLoad = dyn_cast<memref::LoadOp>(op)) {
    accessed = memrefLoad.getMemRef();
  } else if (auto memrefStore = dyn_cast<memref::StoreOp>(op)) {
    accessed = memrefStore.getMemRef();
  }
  return accessed;
}
// Pass that replaces memref.alloc buffers with smaller ones: every static
// dimension that all accesses index with the induction variable of one common
// enclosing loop is reduced to extent 1 and indexed with a constant zero.
// The member fields below hold per-allocation scratch state that
// processAlloc() resets for each allocation it handles.
struct AllocBufferShrinkPass : public AllocBufferShrinkBase<AllocBufferShrinkPass> {
public:
// Collects every memref.alloc under the root operation and tries to shrink it.
void runOnOperation() override;
private:
// Rank of the allocation currently being processed.
unsigned rank = 0;
// shrinkDims[d] is true when allocation dimension d will be reduced to 1.
SmallVector<bool> shrinkDims;
// dimLoops[d] is the loop found by findCommonLoopForDim(d), or nullptr.
SmallVector<Operation *> dimLoops;
// Loads/stores reached from the allocation (directly or through views).
SmallVector<AccessInfo> accessInfos;
// View-like ops (subview, collapse/expand_shape, memory_space_cast) between
// the allocation and its accesses, in the order they were discovered.
SmallVector<Operation *> viewOps;
// Maps each old memref value to its rebuilt, shrunk counterpart.
DenseMap<Value, Value> replacementMap;
// Constant index 0 used as the index of every shrunk dimension.
Value zeroIdx;
// Propagates a dimension mapping through a collapse_shape; returns false
// when the collapse cannot be tracked.
bool computeCollapseMapping(MemRefType srcType, ArrayRef<ReassociationIndices> reassociation,
ArrayRef<int> currentMapping, SmallVectorImpl<int> &newMapping);
// Propagates a dimension mapping through an expand_shape; returns false when
// the expansion cannot be tracked.
bool computeExpandMapping(MemRefType resultType, ArrayRef<ReassociationIndices> reassociation,
ArrayRef<int> currentMapping, SmallVectorImpl<int> &newMapping);
// Walks the users of `memref`, filling accessInfos and viewOps; returns false
// on any unsupported user.
bool collectAccessOpsRecursive(Value memref, SmallVector<int> currentMapping);
// Runs the whole analysis + rewrite for one allocation; returns true when the
// allocation was shrunk.
bool processAlloc(memref::AllocOp allocOp);
// True when every affine load/store in accessInfos uses an identity map.
bool allAccessMapsAreIdentity();
// Returns the single loop whose induction variable indexes dimension d in
// every access (and which encloses all accesses), or nullptr.
Operation *findCommonLoopForDim(unsigned d);
// Fills shrinkDims and dimLoops for the given allocation type.
void analyzeShrinkableDims(MemRefType memrefType);
// Clears shrinkDims entries that conflict with subview offsets/sizes.
void validateSubviewOffsets();
// Rewrites the "SymShapeAttr" entry of a dictionary memory space so shrunk
// dimensions read "1"; other memory spaces are returned unchanged.
Attribute updateMemSpaceForShrink(Attribute memorySpace);
// Creates the shrunk replacement allocation right before `allocOp`.
memref::AllocOp createShrunkAlloc(memref::AllocOp allocOp, MemRefType memrefType);
// Rebuilds every op in viewOps on top of the new allocation.
void rebuildViewChain(memref::AllocOp oldAlloc, memref::AllocOp newAlloc);
void rebuildSubView(memref::SubViewOp subviewOp);
void rebuildCollapseShape(memref::CollapseShapeOp collapseOp);
void rebuildExpandShape(memref::ExpandShapeOp expandOp);
void rebuildMemorySpaceCast(memref::MemorySpaceCastOp castOp);
// Replaces every recorded load/store with one on the rebuilt memref.
void rewriteAccessOps();
// Erases the old view ops and allocation once they have no uses left.
void eraseOldOps(memref::AllocOp allocOp);
};
/// Propagates the allocation-dimension mapping through a memref.collapse_shape.
///
/// \param srcType        type of the collapse source.
/// \param reassociation  reassociation groups of the collapse.
/// \param currentMapping allocation dim -> source dim (or kDimAbsorbed).
/// \param newMapping     receives allocation dim -> result dim (or
///                       kDimAbsorbed); always resized to `rank`.
/// \returns false when a tracked dimension is merged with anything other than
///          static unit dimensions, or when the tracked dimension itself is
///          dynamic inside a multi-dimension group.
bool AllocBufferShrinkPass::computeCollapseMapping(MemRefType srcType, ArrayRef<ReassociationIndices> reassociation,
                                                   ArrayRef<int> currentMapping, SmallVectorImpl<int> &newMapping) {
  newMapping.assign(rank, kDimAbsorbed);
  for (unsigned allocDim = 0; allocDim < rank; ++allocDim) {
    const int64_t srcDim = currentMapping[allocDim];
    if (srcDim == kDimAbsorbed) continue;

    // Find the reassociation group holding the tracked source dimension.
    unsigned groupIdx = 0;
    while (groupIdx < reassociation.size() && !llvm::is_contained(reassociation[groupIdx], srcDim)) ++groupIdx;
    if (groupIdx == reassociation.size()) continue;  // Not in any group: stays absorbed.
    const ReassociationIndices &group = reassociation[groupIdx];

    if (group.size() != 1) {
      const int64_t srcExtent = srcType.getDimSize(srcDim);
      if (ShapedType::isDynamic(srcExtent)) return false;
      // A unit dimension folded into a larger group no longer has an index of
      // its own, so it stays absorbed.
      if (srcExtent == 1) continue;
      // The tracked dimension keeps its identity only when every other member
      // of the group is a static unit dimension (a dynamic extent never
      // compares equal to 1).
      const bool othersAreUnit = llvm::all_of(
        group, [&](int64_t member) { return member == srcDim || srcType.getDimSize(member) == 1; });
      if (!othersAreUnit) return false;
    }
    newMapping[allocDim] = static_cast<int>(groupIdx);
  }
  return true;
}
/// Propagates the allocation-dimension mapping through a memref.expand_shape.
///
/// \param resultType     type of the expand result.
/// \param reassociation  reassociation groups of the expand (one per source dim).
/// \param currentMapping allocation dim -> source dim (or kDimAbsorbed).
/// \param newMapping     receives allocation dim -> result dim (or
///                       kDimAbsorbed); always resized to `rank`.
/// \returns false when a tracked source dimension is out of range, or expands
///          into a group that has a dynamic extent or more than one non-unit
///          extent.
bool AllocBufferShrinkPass::computeExpandMapping(MemRefType resultType, ArrayRef<ReassociationIndices> reassociation,
                                                 ArrayRef<int> currentMapping, SmallVectorImpl<int> &newMapping) {
  newMapping.assign(rank, kDimAbsorbed);
  const int64_t numGroups = static_cast<int64_t>(reassociation.size());
  for (unsigned allocDim = 0; allocDim < rank; ++allocDim) {
    const int64_t srcDim = currentMapping[allocDim];
    if (srcDim == kDimAbsorbed) continue;
    if (srcDim >= numGroups) return false;

    const ReassociationIndices &group = reassociation[srcDim];
    if (group.size() == 1) {
      newMapping[allocDim] = static_cast<int>(group.front());
      continue;
    }

    // In a multi-dimension group the tracked dimension follows the single
    // non-unit result dimension; if every extent is 1 it becomes absorbed.
    int survivor = kDimAbsorbed;
    for (int64_t resultDim : group) {
      const int64_t extent = resultType.getDimSize(resultDim);
      if (ShapedType::isDynamic(extent)) return false;
      if (extent == 1) continue;
      if (survivor != kDimAbsorbed) return false;
      survivor = static_cast<int>(resultDim);
    }
    newMapping[allocDim] = survivor;
  }
  return true;
}
/// Depth-first walk over the users of `memref`.
///
/// Loads/stores are appended to `accessInfos` with the mapping in effect.
/// Supported view ops (rank-preserving unit-stride subview, collapse_shape,
/// expand_shape, memory_space_cast) are appended to `viewOps` before their own
/// users are visited, so a view always precedes the views built on top of it.
///
/// \param memref          value whose users are inspected.
/// \param currentMapping  allocation dim -> dim of `memref` (or kDimAbsorbed).
/// \returns false as soon as an unsupported user or untrackable view is found;
///          `accessInfos` / `viewOps` may then be partially filled.
bool AllocBufferShrinkPass::collectAccessOpsRecursive(Value memref, SmallVector<int> currentMapping) {
  for (Operation *user : memref.getUsers()) {
    if (isAccessOp(user)) {
      accessInfos.push_back(AccessInfo{user, currentMapping});
      continue;
    }

    SmallVector<int> mappingThroughView;
    Value viewResult;
    if (auto subview = dyn_cast<memref::SubViewOp>(user)) {
      const auto srcRank = subview.getSourceType().getRank();
      const auto dstRank = subview.getType().getRank();
      const bool keepsAllocRank = (srcRank == dstRank) && (dstRank == rank);
      if (!keepsAllocRank) return false;
      // Dynamic strides are encoded with a sentinel that is never 1, so they
      // are rejected here as well.
      if (llvm::any_of(subview.getStaticStrides(), [](int64_t stride) { return stride != 1; })) return false;
      mappingThroughView = currentMapping;
      viewResult = subview.getResult();
    } else if (auto collapse = dyn_cast<memref::CollapseShapeOp>(user)) {
      const bool tracked = computeCollapseMapping(collapse.getSrcType(), collapse.getReassociationIndices(),
                                                  currentMapping, mappingThroughView);
      if (!tracked) return false;
      viewResult = collapse.getResult();
    } else if (auto expand = dyn_cast<memref::ExpandShapeOp>(user)) {
      const bool tracked = computeExpandMapping(expand.getResultType(), expand.getReassociationIndices(),
                                                currentMapping, mappingThroughView);
      if (!tracked) return false;
      viewResult = expand.getResult();
    } else if (auto spaceCast = dyn_cast<memref::MemorySpaceCastOp>(user)) {
      mappingThroughView = currentMapping;
      viewResult = spaceCast.getResult();
    } else {
      // Any other user makes the allocation ineligible.
      return false;
    }

    viewOps.push_back(user);
    if (!collectAccessOpsRecursive(viewResult, mappingThroughView)) return false;
  }
  return true;
}
/// Returns the memory-space attribute to use for a shrunk buffer.
///
/// Only a DictionaryAttr carrying an ArrayAttr named "SymShapeAttr" with exactly
/// `rank` elements is rewritten: the element of every shrunk dimension is
/// replaced by the string attribute "1". Any other attribute (including null)
/// is returned unchanged.
Attribute AllocBufferShrinkPass::updateMemSpaceForShrink(Attribute memorySpace) {
  if (!memorySpace) return memorySpace;
  auto dict = dyn_cast<DictionaryAttr>(memorySpace);
  if (!dict) return memorySpace;

  const StringRef symShapeKey = "SymShapeAttr";
  auto oldSymShapes = dict.getAs<ArrayAttr>(symShapeKey);
  if (!oldSymShapes || oldSymShapes.size() != rank) return memorySpace;

  MLIRContext *ctx = &getContext();
  SmallVector<Attribute> shrunkSymShapes(oldSymShapes.begin(), oldSymShapes.end());
  for (unsigned dim = 0; dim < rank; ++dim) {
    if (shrinkDims[dim]) shrunkSymShapes[dim] = StringAttr::get(ctx, "1");
  }
  Attribute shrunkSymShapeAttr = ArrayAttr::get(ctx, shrunkSymShapes);

  // Rebuild the dictionary, swapping only the symbolic-shape entry.
  SmallVector<NamedAttribute> rebuilt;
  for (NamedAttribute entry : dict) {
    if (entry.getName().getValue() == symShapeKey) {
      rebuilt.push_back(NamedAttribute(entry.getName(), shrunkSymShapeAttr));
    } else {
      rebuilt.push_back(entry);
    }
  }
  return DictionaryAttr::get(ctx, rebuilt);
}
bool AllocBufferShrinkPass::allAccessMapsAreIdentity() {
for (auto &info : accessInfos) {
if (auto loadOp = dyn_cast<affine::AffineLoadOp>(info.op)) {
if (!loadOp.getAffineMap().isIdentity()) return false;
} else if (auto storeOp = dyn_cast<affine::AffineStoreOp>(info.op)) {
if (!storeOp.getAffineMap().isIdentity()) return false;
}
}
return true;
}
/// Finds the loop that drives allocation dimension `d`.
///
/// Every access that still has an index for dimension `d` must use the
/// induction variable of one and the same affine.for / scf.for as that index,
/// and every recorded access (including those where the dimension is absorbed)
/// must be nested inside that loop.
///
/// \returns the loop operation, or nullptr when no access indexes the
///          dimension or any of the conditions above fails.
Operation *AllocBufferShrinkPass::findCommonLoopForDim(unsigned d) {
  Operation *sharedLoop = nullptr;
  for (const AccessInfo &access : accessInfos) {
    const int indexPos = access.dimMapping[d];
    if (indexPos < 0) continue;  // Dimension absorbed on this path.

    SmallVector<Value> indices = getAccessIndices(access.op);
    if (indexPos >= static_cast<int>(indices.size())) return nullptr;

    Operation *drivingLoop = getOwningLoopOp(indices[indexPos]);
    if (!drivingLoop) return nullptr;
    if (sharedLoop && sharedLoop != drivingLoop) return nullptr;
    sharedLoop = drivingLoop;
  }

  // Still null means no access had an index for this dimension.
  if (!sharedLoop) return nullptr;

  const bool enclosesAll =
    llvm::all_of(accessInfos, [&](const AccessInfo &access) { return sharedLoop->isAncestor(access.op); });
  return enclosesAll ? sharedLoop : nullptr;
}
/// Recomputes `shrinkDims` and `dimLoops` for an allocation of type
/// `memrefType`. Dynamic dimensions and dimensions of extent <= 1 are never
/// candidates; any other dimension is shrinkable exactly when
/// findCommonLoopForDim() finds a loop for it.
void AllocBufferShrinkPass::analyzeShrinkableDims(MemRefType memrefType) {
  shrinkDims.assign(rank, false);
  dimLoops.assign(rank, nullptr);

  const ArrayRef<int64_t> extents = memrefType.getShape();
  for (unsigned dim = 0; dim < rank; ++dim) {
    const int64_t extent = extents[dim];
    const bool isCandidate = !ShapedType::isDynamic(extent) && extent > 1;
    if (!isCandidate) continue;

    Operation *loop = findCommonLoopForDim(dim);
    dimLoops[dim] = loop;
    shrinkDims[dim] = (loop != nullptr);
  }
}
void AllocBufferShrinkPass::validateSubviewOffsets() {
for (auto *op : viewOps) {
auto subviewOp = dyn_cast<memref::SubViewOp>(op);
if (!subviewOp) continue;
auto offsets = subviewOp.getStaticOffsets();
auto sizes = subviewOp.getStaticSizes();
auto sourceType = subviewOp.getSourceType();
for (unsigned j = 0; j < rank; j++) {
bool hasNonZeroOffset = (offsets[j] != 0);
bool hasPartialSize = (!ShapedType::isDynamic(sizes[j]) && sizes[j] < sourceType.getDimSize(j));
if (!hasNonZeroOffset && !hasPartialSize) continue;
shrinkDims[j] = false;
if (!dimLoops[j]) continue;
for (unsigned d = 0; d < rank; d++) {
if (shrinkDims[d] && dimLoops[d] && dimLoops[j]->isAncestor(dimLoops[d])) shrinkDims[d] = false;
}
}
}
}
/// Creates the replacement allocation immediately before `allocOp`.
///
/// Shrunk dimensions get extent 1; all other dimensions keep their extent and,
/// when dynamic, their size operand. Element type, layout, symbol operands and
/// alignment are taken over from the original; the memory space goes through
/// updateMemSpaceForShrink().
memref::AllocOp AllocBufferShrinkPass::createShrunkAlloc(memref::AllocOp allocOp, MemRefType memrefType) {
  const ArrayRef<int64_t> oldShape = memrefType.getShape();
  ValueRange dynamicSizes = allocOp.getDynamicSizes();

  SmallVector<int64_t> shrunkShape;
  SmallVector<Value> keptDynamicSizes;
  unsigned nextDynamic = 0;
  for (unsigned dim = 0; dim < rank; ++dim) {
    // Always consume the operand of a dynamic dimension so the running operand
    // index stays aligned with the shape.
    const bool wasDynamic = ShapedType::isDynamic(oldShape[dim]);
    Value dynamicSize = wasDynamic ? dynamicSizes[nextDynamic++] : Value();

    if (shrinkDims[dim]) {
      shrunkShape.push_back(1);
      continue;
    }
    shrunkShape.push_back(oldShape[dim]);
    if (wasDynamic) keptDynamicSizes.push_back(dynamicSize);
  }

  Attribute shrunkMemorySpace = updateMemSpaceForShrink(memrefType.getMemorySpace());
  auto shrunkType =
    MemRefType::get(shrunkShape, memrefType.getElementType(), memrefType.getLayout(), shrunkMemorySpace);

  OpBuilder builder(allocOp);
  return builder.create<memref::AllocOp>(allocOp.getLoc(), shrunkType, keptDynamicSizes, allocOp.getSymbolOperands(),
                                         allocOp.getAlignmentAttr());
}
/// Re-creates `subviewOp` on top of the already rebuilt source value.
///
/// Offsets and strides are reused unchanged; the size of every shrunk
/// dimension is forced to the static value 1. The new result is recorded in
/// `replacementMap`. Nothing happens when the source has no replacement.
void AllocBufferShrinkPass::rebuildSubView(memref::SubViewOp subviewOp) {
  Value newSource = replacementMap.lookup(subviewOp.getSource());
  if (!newSource) return;

  // One builder serves both the attribute and the new op; static sizes are
  // expressed as index attributes, matching rebuildExpandShape().
  OpBuilder builder(subviewOp);
  const OpFoldResult unitSize = builder.getIndexAttr(1);

  SmallVector<OpFoldResult> newSizes = subviewOp.getMixedSizes();
  for (unsigned d = 0; d < rank; ++d) {
    if (shrinkDims[d]) newSizes[d] = unitSize;
  }

  auto newSubview = builder.create<memref::SubViewOp>(subviewOp.getLoc(), newSource, subviewOp.getMixedOffsets(),
                                                      newSizes, subviewOp.getMixedStrides());
  replacementMap[subviewOp.getResult()] = newSubview.getResult();
}
/// Re-creates `collapseOp` with the same reassociation on top of the rebuilt
/// source and records the new result in `replacementMap`. Does nothing when the
/// source has no (non-null) replacement.
void AllocBufferShrinkPass::rebuildCollapseShape(memref::CollapseShapeOp collapseOp) {
  auto found = replacementMap.find(collapseOp.getSrc());
  if (found == replacementMap.end() || !found->second) return;
  // Copy the value out: inserting into the map below may invalidate `found`.
  Value rebuiltSource = found->second;

  OpBuilder builder(collapseOp);
  Value rebuiltResult =
    builder.create<memref::CollapseShapeOp>(collapseOp.getLoc(), rebuiltSource, collapseOp.getReassociationIndices())
      .getResult();
  replacementMap[collapseOp.getResult()] = rebuiltResult;
}
/// Re-creates `expandOp` on top of the already rebuilt source value.
///
/// Every result dimension that belongs to a source dimension which was shrunk
/// to 1 becomes a static 1. All other result dimensions keep their extent:
/// static extents as index attributes, dynamic extents by reusing the SSA size
/// operand of the original op. The new result is recorded in `replacementMap`.
/// Nothing happens when the source has no replacement.
void AllocBufferShrinkPass::rebuildExpandShape(memref::ExpandShapeOp expandOp) {
  Value newSource = replacementMap.lookup(expandOp.getSrc());
  if (!newSource) return;

  // The replacement of a ranked memref is always a ranked memref; cast<>
  // asserts on that instead of silently yielding a null type.
  auto newSourceType = cast<MemRefType>(newSource.getType());
  auto oldSourceType = expandOp.getSrcType();
  auto oldResultType = expandOp.getResultType();
  auto reassociation = expandOp.getReassociationIndices();

  SmallVector<int64_t> newOutputShape(oldResultType.getShape().begin(), oldResultType.getShape().end());
  for (unsigned g = 0; g < reassociation.size(); ++g) {
    const int64_t newSrcSize = newSourceType.getDimSize(g);
    if (newSrcSize == oldSourceType.getDimSize(g) || newSrcSize != 1) continue;
    for (int64_t resultDim : reassociation[g]) newOutputShape[resultDim] = 1;
  }

  OpBuilder builder(expandOp);
  // Dynamic size operands of the original op, in result-dimension order.
  ValueRange oldDynamicSizes = expandOp.getOutputShape();
  unsigned nextDynamic = 0;
  SmallVector<OpFoldResult> outputShapeOfr;
  outputShapeOfr.reserve(newOutputShape.size());
  for (unsigned dim = 0; dim < newOutputShape.size(); ++dim) {
    // Consume the operand of every originally dynamic dimension, even when the
    // dimension is now static, so the operand index stays aligned.
    Value oldDynamicSize;
    if (ShapedType::isDynamic(oldResultType.getDimSize(dim))) oldDynamicSize = oldDynamicSizes[nextDynamic++];

    if (ShapedType::isDynamic(newOutputShape[dim])) {
      // A dynamic extent must be passed as a value; wrapping the dynamic
      // sentinel in an index attribute would produce an invalid op.
      outputShapeOfr.push_back(oldDynamicSize);
    } else {
      outputShapeOfr.push_back(builder.getIndexAttr(newOutputShape[dim]));
    }
  }

  auto newExpand = builder.create<memref::ExpandShapeOp>(expandOp.getLoc(), newOutputShape, newSource, reassociation,
                                                         outputShapeOfr);
  replacementMap[expandOp.getResult()] = newExpand.getResult();
}
/// Re-creates `castOp` on top of the already rebuilt source value.
///
/// The new result type takes shape, element type and layout from the rebuilt
/// source and the memory space from the old result (passed through
/// updateMemSpaceForShrink()). The new result is recorded in `replacementMap`.
/// Nothing happens when the source has no replacement.
void AllocBufferShrinkPass::rebuildMemorySpaceCast(memref::MemorySpaceCastOp castOp) {
  Value newSource = replacementMap.lookup(castOp.getSource());
  if (!newSource) return;

  // Both types are dereferenced unconditionally below, so use cast<> (which
  // asserts on a mismatch) rather than dyn_cast<> (which would yield null).
  auto newSourceType = cast<MemRefType>(newSource.getType());
  auto oldResultType = cast<MemRefType>(castOp.getResult().getType());

  Attribute targetMemSpace = updateMemSpaceForShrink(oldResultType.getMemorySpace());
  auto newResultType = MemRefType::get(newSourceType.getShape(), newSourceType.getElementType(),
                                       newSourceType.getLayout(), targetMemSpace);

  OpBuilder builder(castOp);
  auto newCast = builder.create<memref::MemorySpaceCastOp>(castOp.getLoc(), newResultType, newSource);
  replacementMap[castOp.getResult()] = newCast.getResult();
}
/// Resets `replacementMap`, seeds it with old allocation -> new allocation and
/// rebuilds every recorded view op in discovery order, so each view finds the
/// replacement of its source already in the map.
void AllocBufferShrinkPass::rebuildViewChain(memref::AllocOp oldAlloc, memref::AllocOp newAlloc) {
  replacementMap.clear();
  replacementMap[oldAlloc.getResult()] = newAlloc.getResult();

  for (Operation *viewOp : viewOps) {
    if (auto subview = dyn_cast<memref::SubViewOp>(viewOp)) {
      rebuildSubView(subview);
      continue;
    }
    if (auto collapse = dyn_cast<memref::CollapseShapeOp>(viewOp)) {
      rebuildCollapseShape(collapse);
      continue;
    }
    if (auto expand = dyn_cast<memref::ExpandShapeOp>(viewOp)) {
      rebuildExpandShape(expand);
      continue;
    }
    if (auto spaceCast = dyn_cast<memref::MemorySpaceCastOp>(viewOp)) {
      rebuildMemorySpaceCast(spaceCast);
    }
  }
}
void AllocBufferShrinkPass::rewriteAccessOps() {
for (auto &info : accessInfos) {
Value oldMemref = getAccessMemRef(info.op);
SmallVector<Value> oldIndices = getAccessIndices(info.op);
Value newMemref = replacementMap.lookup(oldMemref);
if (!newMemref) newMemref = oldMemref;
SmallVector<Value> newIndices(oldIndices);
for (unsigned d = 0; d < rank; d++) {
int mapped = info.dimMapping[d];
if (shrinkDims[d] && mapped >= 0 && mapped < static_cast<int>(newIndices.size())) newIndices[mapped] = zeroIdx;
}
OpBuilder builder(info.op);
Operation *newOp = nullptr;
if (auto loadOp = dyn_cast<affine::AffineLoadOp>(info.op)) {
auto newLoad = builder.create<affine::AffineLoadOp>(info.op->getLoc(), newMemref, newIndices);
loadOp.getResult().replaceAllUsesWith(newLoad.getResult());
newOp = newLoad;
} else if (auto loadOp = dyn_cast<memref::LoadOp>(info.op)) {
auto newLoad = builder.create<memref::LoadOp>(info.op->getLoc(), newMemref, newIndices);
loadOp.getResult().replaceAllUsesWith(newLoad.getResult());
newOp = newLoad;
} else if (auto storeOp = dyn_cast<affine::AffineStoreOp>(info.op)) {
newOp =
builder.create<affine::AffineStoreOp>(info.op->getLoc(), storeOp.getValueToStore(), newMemref, newIndices);
} else if (auto storeOp = dyn_cast<memref::StoreOp>(info.op)) {
newOp = builder.create<memref::StoreOp>(info.op->getLoc(), storeOp.getValueToStore(), newMemref, newIndices);
} else {
continue;
}
for (auto namedAttr : info.op->getAttrs()) {
if (namedAttr.getName().getValue() != "map") newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
}
info.op->erase();
}
}
/// Erases the superseded view ops and then the old allocation, each only if it
/// has no remaining uses. Views are visited in reverse discovery order so that
/// a view is considered after the views built on top of it.
void AllocBufferShrinkPass::eraseOldOps(memref::AllocOp allocOp) {
  for (Operation *viewOp : llvm::reverse(viewOps)) {
    if (viewOp->use_empty()) viewOp->erase();
  }
  if (allocOp->use_empty()) allocOp->erase();
}
/// Analyses one allocation and, if any dimension can be shrunk, rewrites it.
///
/// Steps: collect accesses and views, require identity affine maps, determine
/// shrinkable dimensions, veto those sliced by subviews, create the shrunk
/// allocation, make sure a constant-zero index exists in the allocation's
/// block, rebuild the views, rewrite the accesses and erase the old ops.
///
/// \returns true when the allocation was replaced, false when it was left
///          untouched.
bool AllocBufferShrinkPass::processAlloc(memref::AllocOp allocOp) {
  MemRefType allocType = allocOp.getType();
  rank = allocType.getRank();
  if (rank == 0) return false;

  accessInfos.clear();
  viewOps.clear();

  // Every allocation dimension initially maps to itself.
  SmallVector<int> identityMapping;
  identityMapping.reserve(rank);
  for (unsigned dim = 0; dim < rank; ++dim) identityMapping.push_back(static_cast<int>(dim));

  const bool analyzable = collectAccessOpsRecursive(allocOp.getResult(), identityMapping) && !accessInfos.empty() &&
                          allAccessMapsAreIdentity();
  if (!analyzable) return false;

  analyzeShrinkableDims(allocType);
  validateSubviewOffsets();
  if (llvm::none_of(shrinkDims, [](bool shrink) { return shrink; })) return false;

  memref::AllocOp newAlloc = createShrunkAlloc(allocOp, allocType);

  // Reuse the cached zero constant only when it lives in the same block as the
  // new allocation; otherwise create a fresh one right after the allocation.
  const bool zeroIsReusable = zeroIdx && zeroIdx.getDefiningOp()->getBlock() == newAlloc->getBlock();
  if (!zeroIsReusable) {
    OpBuilder zeroBuilder(newAlloc->getContext());
    zeroBuilder.setInsertionPointAfter(newAlloc);
    zeroIdx = zeroBuilder.create<arith::ConstantIndexOp>(newAlloc.getLoc(), 0);
  }

  rebuildViewChain(allocOp, newAlloc);
  rewriteAccessOps();
  eraseOldOps(allocOp);
  return true;
}
void AllocBufferShrinkPass::runOnOperation() {
zeroIdx = Value();
SmallVector<memref::AllocOp> allocOps;
getOperation()->walk([&](memref::AllocOp op) { allocOps.push_back(op); });
for (auto allocOp : allocOps) processAlloc(allocOp);
}
}
/// Factory: returns a new instance of the alloc-buffer-shrink pass.
std::unique_ptr<Pass> createAllocBufferShrinkPass() {
  auto pass = std::make_unique<AllocBufferShrinkPass>();
  return pass;
}
}