* Copyright 2026 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mfusion/Dialect/Mfuse/Support/FusedOpUtils.h"
#include <vector>
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mfusion/Dialect/Mfuse/IR/Mfuse.h"
#include "mfusion/Dialect/Mfuse/Transforms/Cluster/Utils.h"
#include "mfusion/Support/Logging.h"
namespace mlir {
namespace mfuse {
namespace {
bool isFiniteValue(DenseElementsAttr denseAttr) {
Type elementType = denseAttr.getElementType();
if (isa<FloatType>(elementType)) {
auto apfloatValues = denseAttr.getValues<APFloat>();
if (!apfloatValues.empty()) {
APFloat value = *apfloatValues.begin();
return !value.isNaN() && !value.isInfinity();
}
}
return true;
}
bool checkForGroupedMatmul(const std::string &opName, DenseElementsAttr denseAttr, size_t idx) {
Type elementType = denseAttr.getElementType();
if (auto intType = dyn_cast<IntegerType>(elementType)) {
const size_t kGroupListIndex = 7;
if (intType.getWidth() == 1 ||
(opName == "mfuse.grouped_matmul" && idx == kGroupListIndex && intType.getWidth() == 64)) {
return false;
}
}
return true;
}
struct PartitionState {
size_t preservedOps{0};
size_t segmentCount{0};
int prevIndex{-1};
int segmentStart{-1};
bool initialized{false};
};
struct SliceAnalysis {
bool analyzed{false};
bool valid{false};
};
bool isBetterPartition(size_t lhsPreservedOps, size_t lhsSegmentCount, size_t rhsPreservedOps, size_t rhsSegmentCount) {
if (lhsPreservedOps != rhsPreservedOps) {
return lhsPreservedOps > rhsPreservedOps;
}
return lhsSegmentCount < rhsSegmentCount;
}
}
llvm::DenseSet<Operation *> collectConstantsToCluster(llvm::ArrayRef<Operation *> clusterOps) {
llvm::DenseSet<Operation *> clusterOpSet(clusterOps.begin(), clusterOps.end());
llvm::DenseSet<Operation *> constantsToCluster;
const auto &opIndexInfo = getConstInputIndexInfo();
for (Operation *op : clusterOps) {
std::string opName = op->getName().getStringRef().str();
const std::unordered_set<size_t> *constIndices = nullptr;
auto iter = opIndexInfo.find(opName);
if (iter != opIndexInfo.end()) {
constIndices = &iter->second;
}
for (size_t idx = 0; idx < op->getNumOperands(); ++idx) {
Value operand = op->getOperand(idx);
Operation *defOp = operand.getDefiningOp();
if (defOp == nullptr || clusterOpSet.contains(defOp) || constantsToCluster.contains(defOp)) {
continue;
}
auto constOp = dyn_cast<mfuse::ConstantOp>(defOp);
if (!constOp) {
continue;
}
Type resultType = constOp.getResult().getType();
if (!isa<TensorType>(resultType)) {
MLOG(DEBUG) << "Constant at index " << idx << " for " << opName << " is not a tensor type, keeping in cluster";
constantsToCluster.insert(defOp);
continue;
}
if (constIndices != nullptr && constIndices->count(idx) != 0) {
MLOG(DEBUG) << "Constant at index " << idx << " for " << opName << " is value-dependent, keeping in cluster";
constantsToCluster.insert(defOp);
continue;
}
Attribute valueAttr = constOp.getValueAttr();
if (!valueAttr) {
continue;
}
auto denseAttr = dyn_cast<DenseElementsAttr>(valueAttr);
if (!denseAttr || denseAttr.getNumElements() != 1 || !isFiniteValue(denseAttr)) {
continue;
}
if (!checkForGroupedMatmul(opName, denseAttr, idx)) {
continue;
}
MLOG(DEBUG) << "Constant at index " << idx << " for " << opName
<< " is single-element finite value, keeping in cluster";
constantsToCluster.insert(defOp);
}
}
return constantsToCluster;
}
llvm::SetVector<Value> findExternalInputs(llvm::ArrayRef<Operation *> clusterOps,
const llvm::DenseSet<Operation *> &clusterOpSet) {
llvm::SetVector<Value> externalInputs;
for (Operation *op : clusterOps) {
for (Value operand : op->getOperands()) {
Operation *defOp = operand.getDefiningOp();
if (defOp == nullptr || !clusterOpSet.contains(defOp)) {
externalInputs.insert(operand);
}
}
}
return externalInputs;
}
llvm::SetVector<Value> findExternalOutputs(llvm::ArrayRef<Operation *> clusterOps,
const llvm::DenseSet<Operation *> &constantsToCluster,
const llvm::DenseSet<Operation *> &clusterOpSet) {
llvm::SetVector<Value> externalOutputs;
for (Operation *op : clusterOps) {
if (constantsToCluster.contains(op)) {
continue;
}
for (Value result : op->getResults()) {
for (Operation *user : result.getUsers()) {
if (!clusterOpSet.contains(user)) {
externalOutputs.insert(result);
break;
}
}
}
}
return externalOutputs;
}
Operation *findValidInsertPoint(const llvm::SmallVector<Operation *> &clusterOps,
const llvm::SetVector<Value> &externalInputs,
const llvm::SetVector<Value> &externalOutputs,
const llvm::DenseSet<Operation *> &clusterOpSet) {
Operation *insertPoint = clusterOps.front();
for (Value input : externalInputs) {
Operation *defOp = input.getDefiningOp();
if (defOp && defOp->getBlock() == insertPoint->getBlock() && insertPoint->isBeforeInBlock(defOp)) {
insertPoint = defOp;
}
}
for (Value output : externalOutputs) {
for (Operation *user : output.getUsers()) {
if (clusterOpSet.contains(user)) {
continue;
}
if (user->getBlock() == insertPoint->getBlock() && !insertPoint->isBeforeInBlock(user)) {
MLOG(DEBUG) << "FusedOp insert point " << insertPoint->getLoc() << " is not before non-cluster user "
<< user->getLoc();
return nullptr;
}
}
}
return insertPoint;
}
bool buildCluster(llvm::ArrayRef<Operation *> baseClusterOps, ClusterBuildInfo &buildInfo) {
if (baseClusterOps.size() < kMinClusterSize) {
return false;
}
buildInfo = ClusterBuildInfo{};
buildInfo.clusterOps.assign(baseClusterOps.begin(), baseClusterOps.end());
buildInfo.constantsToCluster = collectConstantsToCluster(buildInfo.clusterOps);
std::copy(buildInfo.constantsToCluster.begin(), buildInfo.constantsToCluster.end(),
std::back_inserter(buildInfo.clusterOps));
std::sort(buildInfo.clusterOps.begin(), buildInfo.clusterOps.end(),
[](Operation *lhs, Operation *rhs) { return lhs->isBeforeInBlock(rhs); });
buildInfo.clusterOpSet.insert(buildInfo.clusterOps.begin(), buildInfo.clusterOps.end());
buildInfo.externalInputs = findExternalInputs(buildInfo.clusterOps, buildInfo.clusterOpSet);
buildInfo.externalOutputs =
findExternalOutputs(buildInfo.clusterOps, buildInfo.constantsToCluster, buildInfo.clusterOpSet);
if (buildInfo.externalOutputs.empty()) {
return false;
}
buildInfo.insertPoint = findValidInsertPoint(buildInfo.clusterOps, buildInfo.externalInputs,
buildInfo.externalOutputs, buildInfo.clusterOpSet);
return buildInfo.insertPoint != nullptr;
}
llvm::SmallVector<llvm::SmallVector<Operation *>> partitionClusterOps(llvm::ArrayRef<Operation *> baseClusterOps,
size_t minClusterSize) {
llvm::SmallVector<llvm::SmallVector<Operation *>> partitions;
const size_t clusterSize = baseClusterOps.size();
if (clusterSize < minClusterSize) {
return partitions;
}
std::vector<std::vector<SliceAnalysis>> sliceCache(clusterSize, std::vector<SliceAnalysis>(clusterSize + 1));
auto getOrAnalyze = [&](size_t start, size_t end) -> SliceAnalysis & {
SliceAnalysis &analysis = sliceCache[start][end];
if (!analysis.analyzed) {
analysis.analyzed = true;
ClusterBuildInfo buildInfo;
analysis.valid = buildCluster(baseClusterOps.slice(start, end - start), buildInfo);
}
return analysis;
};
std::vector<PartitionState> dp(clusterSize + 1);
dp[0].initialized = true;
for (size_t end = 1; end <= clusterSize; ++end) {
if (dp[end - 1].initialized) {
dp[end] = dp[end - 1];
dp[end].prevIndex = static_cast<int>(end - 1);
dp[end].segmentStart = -1;
dp[end].initialized = true;
}
for (size_t start = 0; start + minClusterSize <= end; ++start) {
SliceAnalysis &analysis = getOrAnalyze(start, end);
if (!analysis.valid || !dp[start].initialized) {
continue;
}
size_t candidatePreservedOps = dp[start].preservedOps + (end - start);
size_t candidateSegmentCount = dp[start].segmentCount + 1;
if (isBetterPartition(candidatePreservedOps, candidateSegmentCount, dp[end].preservedOps, dp[end].segmentCount)) {
dp[end].preservedOps = candidatePreservedOps;
dp[end].segmentCount = candidateSegmentCount;
dp[end].prevIndex = static_cast<int>(start);
dp[end].segmentStart = static_cast<int>(start);
dp[end].initialized = true;
}
}
}
if (!dp[clusterSize].initialized || dp[clusterSize].preservedOps < minClusterSize) {
return partitions;
}
for (int end = static_cast<int>(clusterSize); end > 0;) {
const PartitionState &state = dp[end];
if (state.segmentStart >= 0) {
size_t start = static_cast<size_t>(state.segmentStart);
partitions.emplace_back(baseClusterOps.slice(start, static_cast<size_t>(end) - start).begin(),
baseClusterOps.slice(start, static_cast<size_t>(end) - start).end());
}
end = state.prevIndex;
}
std::reverse(partitions.begin(), partitions.end());
return partitions;
}
}
}