* Copyright 2025 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "akg/Dialect/Affine/Transforms/AKGLoopFusion.h"
#include <regex>
#include <sstream>
#include <string>
#include <vector>
#include "akg/Dialect/Affine/Analysis/DependenceAnalysis.h"
#include "akg/Dialect/Affine/Analysis/AKGLoopFusionAnalyzer.h"
#include "akg/Dialect/Affine/Analysis/AKGLoopFusionBuilder.h"
#include "akg/Dialect/MindSpore/IR/MindSporeOps.h"
#include "akg/Utils/AnalysisCommon.hpp"
#include "akg/Analysis/SymbolicShapeAnalysis.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Value.h"
namespace mlir {
// Select the AKGLoopFusion sections of the included Passes.h.inc. The pass struct
// below derives from mlir::impl::AKGLoopFusionBase, which presumably is emitted by
// this include (NOTE(review): confirm against the generated file).
#define GEN_PASS_DEF_AKGLOOPFUSION
#define GEN_PASS_DECL_AKGLOOPFUSION
#include "akg/Dialect/Affine/Passes.h.inc"
}  // namespace mlir
// Debug category name for this translation unit (presumably consumed by the LLVM
// debug-output macros from llvm/Support/Debug.h; no use is visible in this file section).
#define DEBUG_TYPE "akg-loop-fusion"
namespace {
/// Function pass that fuses affine loops.
///
/// runOnOperation() first pre-processes the function (loop interchange and
/// temporary replacement of dynamic loop-bound values by prime constants), then
/// runs the fusion planner/code generator on every block, and finally restores
/// the replaced values.
struct AKGLoopFusion : public mlir::impl::AKGLoopFusionBase<AKGLoopFusion> {
  AKGLoopFusion() = default;

  /// Pass entry point.
  void runOnOperation() override;

 private:
  // ---- Constants -----------------------------------------------------------

  /// Prefix of the axis keys generated for bound values that cannot be tied to a
  /// symbolic shape entry; the key is this prefix followed by `newSymbolCount`.
  static constexpr auto NEW_SYMBOL = "si";
  /// Placeholder constants substituted for dynamic loop-bound values by
  /// replaceDimWithPrimes(); one entry is consumed per distinct axis key.
  static constexpr int64_t kPrimes[] = {1000003, 1000033, 1000037, 1000039, 1000081,
                                        1000099, 1000117, 1000121, 1000133, 1000139};
  /// Number of entries in `kPrimes`.
  static constexpr size_t kNumPrimes = sizeof(kPrimes) / sizeof(*kPrimes);

  // ---- State ---------------------------------------------------------------

  /// Running counter used to build fresh axis keys.
  /// NOTE(review): this is a plain static shared by all pass instances; confirm the
  /// pass is never run concurrently on several functions.
  static int64_t newSymbolCount;
  /// Maps each inserted prime constant to the value it temporarily stands for.
  llvm::DenseMap<mlir::Value, mlir::Value> primeToDimMap;
  /// When true, the dependence graph and the fusion plans are dumped.
  bool printFusionInfo{false};

  // ---- Driver steps --------------------------------------------------------

  void runPreProcess();
  void runOnBlock(mlir::Block *block);

  // ---- Dynamic-dimension <-> prime substitution ----------------------------

  void replaceDimWithPrimes(mlir::func::FuncOp funcOp);
  void restoreDimFromPrimes(mlir::func::FuncOp funcOp);
  void collectDimOperationsFromLoops(mlir::func::FuncOp funcOp,
                                     llvm::StringMap<llvm::SmallVector<mlir::Value>> &axisToDimValues);

  // ---- Small queries -------------------------------------------------------

  std::optional<llvm::SmallVector<std::string>> getSymShapeAttrFromValue(mlir::Value source);
  std::optional<int64_t> getConstantDimIndex(mlir::Value dimIndex);
  llvm::SmallVector<int64_t> getDynamicDimIndicesOfValue(mlir::Value v);
};
}  // namespace

int64_t AKGLoopFusion::newSymbolCount = 0;
/// Looks up the symbolic shape (one string per dimension) associated with `source`.
///
/// Lookup order:
///   1. the type of `source` itself;
///   2. if `source` is an entry-block argument of a func.func, the matching input
///      type of the function signature;
///   3. if `source` is produced by bufferization.to_memref, the tensor operand
///      (handled recursively, so steps 1-3 are applied to the tensor).
///
/// \param source value whose symbolic shape is requested.
/// \return the symbolic shape, or std::nullopt when none can be found.
std::optional<llvm::SmallVector<std::string>> AKGLoopFusion::getSymShapeAttrFromValue(mlir::Value source) {
  // Only ranked tensors and memrefs are queried for a symbolic shape.
  auto getSymShape = [](mlir::Type type) -> std::optional<llvm::SmallVector<std::string>> {
    if (!mlir::isa<mlir::RankedTensorType, mlir::MemRefType>(type)) {
      return std::nullopt;
    }
    return mlir::SymbolicShapeAnalysis::getInstance().getSymbolicShape(type);
  };

  if (auto symShape = getSymShape(source.getType())) {
    return symShape;
  }

  if (auto blockArg = mlir::dyn_cast<mlir::BlockArgument>(source)) {
    mlir::Block *owner = blockArg.getOwner();
    auto funcOp = mlir::dyn_cast_or_null<mlir::func::FuncOp>(owner->getParentOp());
    // Only arguments of the function's entry block line up with the signature
    // inputs; indexing the input list with an argument number of another block
    // could run past its end.
    if (!funcOp || !owner->isEntryBlock()) {
      return std::nullopt;
    }
    return getSymShape(funcOp.getFunctionType().getInput(blockArg.getArgNumber()));
  }

  if (auto toMemref = source.getDefiningOp<mlir::bufferization::ToMemrefOp>()) {
    // The recursive call starts by inspecting the tensor's own type, so no
    // separate type check is needed here.
    return getSymShapeAttrFromValue(toMemref.getTensor());
  }
  return std::nullopt;
}
/// Extracts the integer held by a constant-defined value.
///
/// \param dimIndex value expected to be produced by an arith constant.
/// \return the constant's integer value when `dimIndex` is defined by
///         arith::ConstantIndexOp, or by arith::ConstantOp carrying an
///         IntegerAttr; std::nullopt otherwise.
std::optional<int64_t> AKGLoopFusion::getConstantDimIndex(mlir::Value dimIndex) {
  if (auto indexConst = dimIndex.getDefiningOp<mlir::arith::ConstantIndexOp>()) {
    return indexConst.value();
  }

  auto genericConst = dimIndex.getDefiningOp<mlir::arith::ConstantOp>();
  if (!genericConst) {
    return std::nullopt;
  }

  auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(genericConst.getValue());
  if (!intAttr) {
    return std::nullopt;
  }
  return intAttr.getInt();
}
/// Lists the positions of the dynamic dimensions of `v`.
///
/// \param v value to inspect.
/// \return the dimension indices, in increasing order, whose extent equals
///         mlir::ShapedType::kDynamic; empty when `v` is not a ranked shaped value.
llvm::SmallVector<int64_t> AKGLoopFusion::getDynamicDimIndicesOfValue(mlir::Value v) {
  llvm::SmallVector<int64_t> dynamicPositions;
  if (auto shapedTy = mlir::dyn_cast<mlir::ShapedType>(v.getType()); shapedTy && shapedTy.hasRank()) {
    int64_t position = 0;
    for (int64_t extent : shapedTy.getShape()) {
      if (extent == mlir::ShapedType::kDynamic) {
        dynamicPositions.push_back(position);
      }
      ++position;
    }
  }
  return dynamicPositions;
}
/// Groups the operands of all affine.for bounds in `funcOp` by the axis they denote.
///
/// For every upper- and lower-bound operand of every affine.for:
///   * if it is produced by memref.dim with a constant index, the axis key is the
///     symbolic-shape entry of the dim source at that index;
///   * otherwise, if the operand is used by a memref.alloc, the axis key is the
///     symbolic-shape entry of the alloc result at the dynamic dimension the
///     operand sizes;
///   * otherwise a fresh key "si<N>" is generated from `newSymbolCount`.
/// Operands for which no key can be determined are skipped.
///
/// \param funcOp           function whose loops are scanned.
/// \param axisToDimValues  output map: axis key -> bound operands for that axis
///                         (an operand may be appended more than once).
void AKGLoopFusion::collectDimOperationsFromLoops(mlir::func::FuncOp funcOp,
                                                  llvm::StringMap<llvm::SmallVector<mlir::Value>> &axisToDimValues) {
  auto processBoundOperand = [&](mlir::Value operand) {
    // Case 1: the bound is a memref.dim result.
    if (auto dimOp = operand.getDefiningOp<mlir::memref::DimOp>()) {
      auto constIndex = getConstantDimIndex(dimOp.getIndex());
      if (!constIndex.has_value()) {
        return;
      }
      auto symShape = getSymShapeAttrFromValue(dimOp.getSource());
      // Reject negative as well as too-large indices before subscripting.
      if (!symShape.has_value() || *constIndex < 0 || *constIndex >= static_cast<int64_t>(symShape->size())) {
        return;
      }
      axisToDimValues[(*symShape)[*constIndex]].push_back(operand);
      return;
    }

    // Case 2: look for a memref.alloc that consumes the bound value. The uses of
    // the operand itself are walked, which also works when the operand is a block
    // argument (no defining op) or is not the first result of its producer.
    mlir::Value allocResult;
    int64_t allocOperandIdx = -1;
    for (mlir::OpOperand &use : operand.getUses()) {
      if (auto allocOp = mlir::dyn_cast<mlir::memref::AllocOp>(use.getOwner())) {
        allocResult = allocOp.getResult();
        allocOperandIdx = static_cast<int64_t>(use.getOperandNumber());
        break;
      }
    }

    // Case 3: no alloc user, so give the value its own fresh axis key.
    if (!allocResult) {
      std::string axisKey = std::string(NEW_SYMBOL) + std::to_string(newSymbolCount);
      ++newSymbolCount;
      axisToDimValues[axisKey].push_back(operand);
      return;
    }

    auto dynDims = getDynamicDimIndicesOfValue(allocResult);
    if (dynDims.empty()) {
      llvm::errs() << "This op has no dynamic shape\n";
      return;
    }
    // The operand position is only a valid index into `dynDims` when the operand
    // is one of the alloc's dynamic sizes; any later operand is skipped.
    if (allocOperandIdx < 0 || allocOperandIdx >= static_cast<int64_t>(dynDims.size())) {
      return;
    }
    const int64_t dynamicIndex = dynDims[allocOperandIdx];
    auto symShape = getSymShapeAttrFromValue(allocResult);
    if (!symShape.has_value() || dynamicIndex >= static_cast<int64_t>(symShape->size())) {
      return;
    }
    axisToDimValues[(*symShape)[dynamicIndex]].push_back(operand);
  };

  funcOp.walk([&](mlir::affine::AffineForOp forOp) {
    for (mlir::Value operand : forOp.getUpperBoundOperands()) {
      processBoundOperand(operand);
    }
    for (mlir::Value operand : forOp.getLowerBoundOperands()) {
      processBoundOperand(operand);
    }
  });
}
/// Temporarily replaces dynamic loop-bound values by distinct prime constants.
///
/// The bound values are grouped per axis by collectDimOperationsFromLoops(). For
/// each axis one entry of `kPrimes` is materialised as an arith.constant (index)
/// right after the earliest value of the group, every value of the group has all
/// its uses redirected to that constant, and the pair (constant -> earliest value)
/// is recorded in `primeToDimMap` so restoreDimFromPrimes() can undo the change.
///
/// Stops early (leaving the remaining axes untouched) when the primes run out, and
/// returns immediately when two values of one axis are defined in different blocks.
///
/// \param funcOp function to rewrite.
void AKGLoopFusion::replaceDimWithPrimes(mlir::func::FuncOp funcOp) {
  llvm::StringMap<llvm::SmallVector<mlir::Value>> axisToDimValues;
  collectDimOperationsFromLoops(funcOp, axisToDimValues);
  if (axisToDimValues.empty()) {
    return;
  }

  mlir::OpBuilder builder(funcOp.getContext());
  size_t primeIndex = 0;
  for (auto &entry : axisToDimValues) {
    if (primeIndex >= kNumPrimes) {
      llvm::errs() << "Warning: Not enough primes for all dim axes\n";
      break;
    }
    auto &dimValues = entry.getValue();

    // Pick the value defined first, so the constant inserted after it is
    // available to the users of every value in the group.
    mlir::Value earliestDim;
    for (mlir::Value dimVal : dimValues) {
      if (!earliestDim) {
        earliestDim = dimVal;
        continue;
      }
      auto *op1 = earliestDim.getDefiningOp();
      // A value without defining op (block argument) is kept as the earliest.
      if (!op1) break;
      auto *op2 = dimVal.getDefiningOp();
      if (!op2) {
        earliestDim = dimVal;
        break;
      }
      if (op1->getBlock() == op2->getBlock()) {
        if (op2->isBeforeInBlock(op1)) {
          earliestDim = dimVal;
        }
      } else {
        llvm::errs() << "Dim ops for the same axisKey are in different blocks\n";
        return;
      }
    }

    mlir::Value representativeDim = earliestDim ? earliestDim : dimValues.front();
    builder.setInsertionPointAfterValue(representativeDim);
    auto primeConst = builder.create<mlir::arith::ConstantIndexOp>(representativeDim.getLoc(), kPrimes[primeIndex]);
    primeToDimMap[primeConst] = representativeDim;
    for (mlir::Value dimValue : dimValues) {
      dimValue.replaceAllUsesWith(primeConst);
    }
    ++primeIndex;
  }
}
/// Undoes replaceDimWithPrimes(): every recorded prime constant has its uses
/// redirected to the value stored for it in `primeToDimMap`, the constant op is
/// erased once it has no users left, and the map is emptied.
///
/// \param funcOp function being processed (not read by this routine).
void AKGLoopFusion::restoreDimFromPrimes(mlir::func::FuncOp funcOp) {
  for (auto &mapping : primeToDimMap) {
    mlir::Value placeholder = mapping.first;
    mlir::Value original = mapping.second;
    placeholder.replaceAllUsesWith(original);

    auto constOp = placeholder.getDefiningOp<mlir::arith::ConstantIndexOp>();
    if (constOp && constOp.use_empty()) {
      constOp.erase();
    }
  }
  primeToDimMap.clear();
}
/// Plans and applies loop fusion for the operations of one block.
///
/// Builds the memref dependence graph of `block`, lets the fusion analyzer produce
/// a list of plans, and executes the plans in order through the code-generation
/// helper. A plan is skipped when
///   * its source and destination resolve to the same node id, or
///   * a later plan targets the same pair of node ids (in either direction).
/// Diagnostics are written to llvm::errs().
///
/// \param block block whose loops are candidates for fusion.
void AKGLoopFusion::runOnBlock(mlir::Block *block) {
  auto dependenceGraph = mlir::akg::MemRefDependenceGraphForFusion(block);
  if (!dependenceGraph.init()) {
    return;
  }
  if (printFusionInfo) {
    dependenceGraph.dump();
  }

  mlir::func::FuncOp funcOp = getOperation();
  mlir::akg::FusionAnalyzer analyzer(dependenceGraph, funcOp);
  analyzer.plan();
  if (analyzer.fusionPlans.empty()) {
    return;
  }
  if (printFusionInfo) {
    analyzer.dump();
  }

  mlir::akg::FusionCodeGenHelper codegenerator(dependenceGraph);
  const size_t numPlans = analyzer.fusionPlans.size();
  for (size_t i = 0; i < numPlans; ++i) {
    auto &plan = analyzer.fusionPlans[i];
    // Ids are resolved through the helper's alias lookup for each plan.
    auto actualSrcId = codegenerator.getAliasId(plan.fusedBand.from);
    auto actualDstId = codegenerator.getAliasId(plan.fusedBand.to);
    if (actualSrcId == actualDstId) {
      continue;
    }

    // Defer to a later plan that addresses the same pair of nodes.
    bool hasConflict = false;
    for (size_t j = i + 1; j < numPlans; ++j) {
      auto &futurePlan = analyzer.fusionPlans[j];
      auto futureSrcId = codegenerator.getAliasId(futurePlan.fusedBand.from);
      auto futureDstId = codegenerator.getAliasId(futurePlan.fusedBand.to);
      const bool samePair = actualSrcId == futureSrcId && actualDstId == futureDstId;
      const bool reversedPair = actualSrcId == futureDstId && actualDstId == futureSrcId;
      if (samePair || reversedPair) {
        hasConflict = true;
        break;
      }
    }
    if (hasConflict) {
      continue;
    }

    // dyn_cast_or_null: a node without an operation falls through to the warning
    // below instead of tripping the cast's null assertion.
    auto srcFor = mlir::dyn_cast_or_null<mlir::affine::AffineForOp>(dependenceGraph.getNode(actualSrcId)->op);
    auto dstFor = mlir::dyn_cast_or_null<mlir::affine::AffineForOp>(dependenceGraph.getNode(actualDstId)->op);
    if (!srcFor || !dstFor) {
      llvm::errs() << "Warning: Could not find valid operations for fusion plan: node " << plan.fusedBand.from << " to "
                   << plan.fusedBand.to << "\n";
      continue;
    }

    if (plan.fusionType == "V") {
      codegenerator.doVFuse(actualSrcId, actualDstId, srcFor, dstFor, plan);
    } else if (plan.fusionType == "H") {
      codegenerator.doHFuse(actualSrcId, actualDstId, srcFor, dstFor, plan);
    } else {
      llvm::errs() << "Warning: Unknown fusion type '" << plan.fusionType << "' for fusion plan: node "
                   << plan.fusedBand.from << " to " << plan.fusedBand.to << "\n";
    }
  }
}
void AKGLoopFusion::runPreProcess() {
mlir::func::FuncOp funcOp = getOperation();
funcOp.walk([&](mlir::affine::AffineForOp inner) {
if (auto outer = mlir::dyn_cast<mlir::affine::AffineForOp>(inner->getParentOp())) {
if (mlir::CommonUtils::isReduceAxis(funcOp, inner->getParentOp())) {
mlir::affine::interchangeLoops(outer, inner);
}
}
});
replaceDimWithPrimes(funcOp);
}
void AKGLoopFusion::runOnOperation() {
auto funcOp = getOperation();
runPreProcess();
for (mlir::Region ®ion : funcOp->getRegions()) {
for (mlir::Block &block : region.getBlocks()) {
runOnBlock(&block);
}
}
restoreDimFromPrimes(funcOp);
}
/// Factory: returns a newly allocated AKGLoopFusion pass instance.
std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> mlir::createAKGLoopFusionPass() {
  auto fusionPass = std::make_unique<AKGLoopFusion>();
  return fusionPass;
}