* Copyright 2026 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mfusion/Conversion/Passes.h"
#include "mfusion/Dialect/Mfuse/Transforms/Outlining/FusionAttributes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mfusion/Dialect/Dvm/IR/Dvm.h"
#include "mfusion/Dialect/Mfuse/IR/Mfuse.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTMFUSETODVM
#include "mfusion/Conversion/Passes.h.inc"
namespace {
/// Returns true when `func` is a non-null function that carries the
/// `mfusion_attrs::kOutlined` attribute and whose `mfusion_attrs::kFusionType`
/// string attribute equals "dvm".
static bool isDvmOutlinedFunc(func::FuncOp func) {
  // A null handle, or a function that was never outlined, cannot be a DVM subgraph.
  const bool outlined = func && func->hasAttr(mfusion_attrs::kOutlined);
  if (!outlined)
    return false;

  // The fusion type must be present as a StringAttr; any other attribute kind
  // (or a missing attribute) yields a null handle here.
  auto kind = func->getAttrOfType<StringAttr>(mfusion_attrs::kFusionType);
  return kind ? kind.getValue() == "dvm" : false;
}
/// Returns true when the closest enclosing `func::FuncOp` of `op` exists and
/// is a DVM-outlined function (as decided by isDvmOutlinedFunc).
static bool isDvmOutlinedOp(Operation *op) {
  if (auto enclosing = op->getParentOfType<func::FuncOp>())
    return isDvmOutlinedFunc(enclosing);
  // Not nested inside any func::FuncOp.
  return false;
}
/// Conversion pattern: rewrites `mfuse::MulOp` into a `dvm::BinaryOp` tagged
/// with `dvm::BinaryOpType::Mul`.
///
/// The new op keeps the original result type and takes the adaptor (already
/// converted) lhs/rhs operands. The pattern refuses to match ops that are not
/// inside a DVM-outlined function.
struct ConvertMulOp : public OpConversionPattern<mfuse::MulOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(mfuse::MulOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const override {
    // Leave multiplications outside DVM subgraphs alone.
    if (!isDvmOutlinedOp(op.getOperation()))
      return failure();

    auto resultType = op.getResult().getType();
    auto mulKind = dvm::BinaryOpTypeAttr::get(getContext(), dvm::BinaryOpType::Mul);
    auto lhs = adaptor.getLhs();
    auto rhs = adaptor.getRhs();
    rewriter.replaceOpWithNewOp<dvm::BinaryOp>(op, resultType, mulKind, lhs, rhs);
    return success();
  }
};
/// Adds explicit memory ops at the boundaries of every DVM-outlined function
/// in `module`:
///  - Each entry-block argument of `RankedTensorType` gets a `dvm::LoadOp`
///    created at the top of the entry block. All other uses of the argument
///    are redirected to the load's result.
///  - Each `RankedTensorType` operand of every `func::ReturnOp` is replaced by
///    the result of a `dvm::StoreOp` created immediately before that return.
/// Functions without a body, and functions not tagged as DVM-outlined, are skipped.
static void insertLoadStoreOps(ModuleOp module) {
  for (func::FuncOp func : module.getOps<func::FuncOp>()) {
    // An external function has no blocks, so this single guard covers both
    // the declaration case and the non-DVM case.
    if (func.isExternal() || !isDvmOutlinedFunc(func))
      continue;

    // Loads go at the very top of the entry block. The insertion point stays
    // before the original first op, so loads appear in argument order.
    Block &entry = func.front();
    OpBuilder loadBuilder(&entry, entry.begin());
    for (auto arg : entry.getArguments()) {
      if (!mlir::isa<RankedTensorType>(arg.getType()))
        continue;
      auto load = loadBuilder.create<dvm::LoadOp>(func.getLoc(), arg.getType(), arg);
      // The load itself must keep consuming the raw argument.
      arg.replaceAllUsesExcept(load.getResult(), load);
    }

    // Stores are inserted right before each return and replace the returned value.
    func.walk([&](func::ReturnOp ret) {
      OpBuilder storeBuilder(ret);
      for (auto &slot : ret->getOpOperands()) {
        Value result = slot.get();
        if (!mlir::isa<RankedTensorType>(result.getType()))
          continue;
        auto store = storeBuilder.create<dvm::StoreOp>(ret.getLoc(), result.getType(), result);
        slot.set(store.getResult());
      }
    });
  }
}
}
/// Module pass `convert-mfuse-to-dvm`.
///
/// The pass works in two steps:
///  1. insertLoadStoreOps adds `dvm::LoadOp`/`dvm::StoreOp` at the boundaries
///     of every DVM-outlined function.
///  2. A partial dialect conversion rewrites `mfuse::MulOp` into
///     `dvm::BinaryOp`, but only for ops inside DVM-outlined functions.
///     Other Mfuse ops, and MulOps elsewhere in the module, remain legal and
///     are left unchanged.
// NOTE(review): this struct sits outside the anonymous namespace. Presumably
// that lets PassWrapper's implicit (name-based) TypeID work; confirm before
// moving it.
struct ConvertMfuseToDvmPass : public PassWrapper<ConvertMfuseToDvmPass, OperationPass<ModuleOp>> {
  /// Pass argument, e.g. as used on an opt-style command line.
  StringRef getArgument() const final { return "convert-mfuse-to-dvm"; }
  /// Human-readable summary of the pass.
  StringRef getDescription() const final { return "Convert outlined Mfuse subgraphs to DVM dialect operations"; }

  /// Dialects this pass depends on; they are registered so they can be loaded
  /// before the pass runs.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<mlir::mfuse::MfuseDialect>();
    registry.insert<mlir::dvm::DvmDialect>();
    registry.insert<mlir::func::FuncDialect>();
    registry.insert<mlir::arith::ArithDialect>();
  }

  void runOnOperation() override {
    ModuleOp module = getOperation();
    MLIRContext *ctx = &getContext();

    // Step 1: rewrite DVM function boundaries directly, outside the conversion
    // driver.
    insertLoadStoreOps(module);

    // Step 2: the listed dialects are legal. The one exception is
    // mfuse::MulOp inside a DVM-outlined function, because the op-specific
    // dynamic legality rule takes precedence over the dialect-wide rule.
    ConversionTarget target(*ctx);
    target.addLegalDialect<mlir::dvm::DvmDialect>();
    target.addLegalDialect<mlir::func::FuncDialect>();
    target.addLegalDialect<mlir::arith::ArithDialect>();
    target.addLegalDialect<mlir::mfuse::MfuseDialect>();
    target.addDynamicallyLegalOp<mlir::mfuse::MulOp>(
        [](mlir::mfuse::MulOp op) { return !isDvmOutlinedOp(op.getOperation()); });

    RewritePatternSet patterns(ctx);
    patterns.add<ConvertMulOp>(ctx);
    // Partial conversion: only illegal ops must be converted; failure to
    // legalize one of them fails the pass.
    if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
      signalPassFailure();
    }
  }
};
/// Factory returning a fresh, caller-owned instance of the
/// `convert-mfuse-to-dvm` pass.
std::unique_ptr<Pass> createConvertMfuseToDvmPass() {
  std::unique_ptr<Pass> pass = std::make_unique<ConvertMfuseToDvmPass>();
  return pass;
}
}