* Copyright 2026 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mfusion/Conversion/MfuseToTorch/MfuseToTorch.h"
#include <algorithm>
#include <iterator>
#include <string>
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/TypeID.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mfusion/Conversion/MfuseToTorch/MfuseMetaToTorch.h"
#include "mfusion/Conversion/MfuseToTorch/MfuseAclnnToTorch.h"
#include "mfusion/Dialect/Mfuse/IR/Mfuse.h"
#include "mfusion/Dialect/Mfuse/IR/MfuseDialect.h"
#include "mfusion/Support/Logging.h"
#include "mfusion/Conversion/PdllHelper.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "MfuseToTorch.pdll.h.inc"
#include "llvm/Support/CommandLine.h"
namespace {
namespace TorchD = mlir::torch::Torch;
/// Registers the Mfuse -> Torch type conversion rules on `converter`.
///
/// Rules added here:
///  - ranked tensors whose encoding is a dictionary containing the scalar
///    marker become `!torch.float` / `!torch.int` (chosen by element type);
///  - every other ranked tensor becomes a `!torch.vtensor` with the same shape
///    (dynamic dims mapped to `kUnknownSize`) and element type;
///  - unranked tensors become shapeless `!torch.vtensor`;
///  - mfuse string / none / list types map to their Torch counterparts.
///
/// The list rule captures `converter` by reference and is stored inside that
/// very converter, so the converter object must stay at the same address for
/// as long as the rule may run.
void populateMfuseToTorchTypeConversions(mlir::TypeConverter &converter) {
  converter.addConversion([](mlir::RankedTensorType type) -> mlir::Type {
    auto encoding = type.getEncoding();
    if (encoding) {
      auto dictAttr = mlir::dyn_cast<mlir::DictionaryAttr>(encoding);
      if (dictAttr && dictAttr.contains(mlir::mfuse::kScalarMarkerAttr)) {
        // Scalar-marked tensors stand in for Torch scalars; only the element
        // type (not the marker's value) decides which scalar type is used.
        mlir::Type elementType = type.getElementType();
        if (mlir::isa<mlir::FloatType>(elementType)) {
          return TorchD::FloatType::get(type.getContext());
        }
        if (mlir::isa<mlir::IntegerType>(elementType)) {
          return TorchD::IntType::get(type.getContext());
        }
        // Any other element type falls through to the plain tensor mapping.
      }
    }
    llvm::SmallVector<int64_t> shape;
    auto sizes = type.getShape();
    shape.reserve(sizes.size());
    std::transform(sizes.begin(), sizes.end(), std::back_inserter(shape),
                   [](int64_t dim) { return dim == mlir::ShapedType::kDynamic ? TorchD::kUnknownSize : dim; });
    return TorchD::ValueTensorType::get(type.getContext(), llvm::ArrayRef<int64_t>(shape), type.getElementType());
  });
  converter.addConversion([](mlir::UnrankedTensorType type) -> mlir::Type {
    return TorchD::ValueTensorType::get(type.getContext(), std::nullopt, type.getElementType());
  });
  converter.addConversion(
    [](mlir::mfuse::StringType type) -> mlir::Type { return TorchD::StringType::get(type.getContext()); });
  converter.addConversion(
    [](mlir::mfuse::NoneType type) -> mlir::Type { return TorchD::NoneType::get(type.getContext()); });
  converter.addConversion([&](mlir::mfuse::ListType type) -> mlir::Type {
    mlir::Type containedType = converter.convertType(type.getContainedType());
    if (!containedType) {
      // Report the failure (null type) instead of building a list type around
      // a null element type.
      return mlir::Type();
    }
    return TorchD::ListType::get(type.getContext(), containedType);
  });
}
namespace {
/// Tries to materialize a scalar-marked dense constant as a Torch scalar
/// constant op (`torch.constant.int` / `.float` / `.bool`).
///
/// \param builder   Builder used to create the constant at its current
///                  insertion point.
/// \param denseAttr Payload of the constant being lowered.
/// \param loc       Location attached to the created op.
/// \param input     Value whose type carries the scalar-marker encoding.
/// \returns The created scalar constant, or a null Value when `input` is not a
///          ranked tensor, carries no scalar marker, the payload is not a
///          splat, or the element type is neither float nor integer. No op is
///          created in those cases.
mlir::Value convertDenseElementsAttrToTorchConstant(mlir::OpBuilder &builder, mlir::DenseElementsAttr denseAttr,
                                                    mlir::Location loc, mlir::Value input) {
  // Callers may pass any constant result; only ranked tensors have an encoding.
  auto rankedType = mlir::dyn_cast<mlir::RankedTensorType>(input.getType());
  if (!rankedType) {
    return {};
  }
  auto encoding = rankedType.getEncoding();
  if (!encoding) {
    return {};
  }
  auto dictAttr = mlir::dyn_cast<mlir::DictionaryAttr>(encoding);
  if (!dictAttr || !dictAttr.contains(mlir::mfuse::kScalarMarkerAttr)) {
    return {};
  }
  // Every path below reads the payload through getSplatValue, which is only
  // valid on splat attributes.
  if (!denseAttr.isSplat()) {
    return {};
  }

  mlir::Type elemType = denseAttr.getType().getElementType();
  const bool isFloatElem = mlir::isa<mlir::FloatType>(elemType);
  const bool isIntElem = mlir::isa<mlir::IntegerType>(elemType);

  auto scalarAttr = dictAttr.get(mlir::mfuse::kScalarMarkerAttr);
  if (auto strAttr = mlir::dyn_cast<mlir::StringAttr>(scalarAttr)) {
    // A string marker names the Torch scalar type to produce; the payload is
    // converted to that type when its element type differs.
    llvm::StringRef typeStr = strAttr.getValue();
    if (typeStr == "!torch.int") {
      if (isFloatElem) {
        double val = denseAttr.getSplatValue<mlir::APFloat>().convertToDouble();
        return builder.create<TorchD::ConstantIntOp>(loc, builder.getI64IntegerAttr(static_cast<int64_t>(val)));
      }
      int64_t val = denseAttr.getSplatValue<mlir::APInt>().getSExtValue();
      return builder.create<TorchD::ConstantIntOp>(loc, builder.getI64IntegerAttr(val));
    }
    if (typeStr == "!torch.float") {
      if (isIntElem) {
        int64_t val = denseAttr.getSplatValue<mlir::APInt>().getSExtValue();
        return builder.create<TorchD::ConstantFloatOp>(loc, builder.getF64FloatAttr(static_cast<double>(val)));
      }
      double val = denseAttr.getSplatValue<mlir::APFloat>().convertToDouble();
      return builder.create<TorchD::ConstantFloatOp>(loc, builder.getF64FloatAttr(val));
    }
    if (typeStr == "!torch.bool") {
      // Mirror the int/float branches: accept a float payload too instead of
      // reading it through the integer accessor.
      const bool boolValue = isFloatElem ? !denseAttr.getSplatValue<mlir::APFloat>().isZero()
                                         : denseAttr.getSplatValue<mlir::APInt>().getBoolValue();
      return builder.create<TorchD::ConstantBoolOp>(loc, builder.getBoolAttr(boolValue));
    }
  }

  // Marker without a recognised type string: pick the scalar kind from the
  // payload's element type.
  if (isFloatElem) {
    double val = denseAttr.getSplatValue<mlir::APFloat>().convertToDouble();
    return builder.create<TorchD::ConstantFloatOp>(loc, builder.getF64FloatAttr(val));
  }
  if (isIntElem) {
    // NOTE(review): sign extension turns a set i1 into -1 — confirm i1
    // payloads are never expected on this path.
    int64_t val = denseAttr.getSplatValue<mlir::APInt>().getSExtValue();
    return builder.create<TorchD::ConstantIntOp>(loc, builder.getI64IntegerAttr(val));
  }
  return {};
}
}
class MfuseToTorchTypeConverter : public mlir::TypeConverter {
private:
static mlir::Value tryConvertConstant(mlir::OpBuilder &builder, mlir::Type toType, mlir::Value input,
mlir::Location loc) {
if (auto cst = input.getDefiningOp<mlir::mfuse::ConstantOp>()) {
if (auto denseAttr = mlir::dyn_cast<mlir::DenseElementsAttr>(cst.getValue())) {
return convertDenseElementsAttrToTorchConstant(builder, denseAttr, loc, input);
}
}
return {};
}
public:
MfuseToTorchTypeConverter() {
addConversion([](mlir::Type type) { return type; });
populateMfuseToTorchTypeConversions(*this);
addTargetMaterialization(
[](mlir::OpBuilder &builder, mlir::Type toType, mlir::ValueRange inputs, mlir::Location loc) -> mlir::Value {
if (inputs.size() != 1) return {};
mlir::Value input = inputs[0];
if (auto v = tryConvertConstant(builder, toType, input, loc)) return v;
return builder.create<mlir::UnrealizedConversionCastOp>(loc, toType, inputs).getResult(0);
});
addSourceMaterialization(
[](mlir::OpBuilder &builder, mlir::Type toType, mlir::ValueRange inputs, mlir::Location loc) -> mlir::Value {
if (inputs.size() != 1) return {};
return builder.create<mlir::UnrealizedConversionCastOp>(loc, toType, inputs).getResult(0);
});
}
};
/// Lowers `mfuse.constant` to Torch.
///
/// Scalar-marked constants become Torch scalar constant ops; every other
/// dense constant becomes a `torch.vtensor.literal` whose shape comes from the
/// op's result type and whose element type comes from the attribute.
class ConvertMfuseConstantToTorch : public mlir::OpConversionPattern<mlir::mfuse::ConstantOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  /// Fails to match when the result type cannot be converted, is not a ranked
  /// tensor, or the value is not a dense elements attribute.
  mlir::LogicalResult matchAndRewrite(mlir::mfuse::ConstantOp op, OpAdaptor adaptor,
                                      mlir::ConversionPatternRewriter &rewriter) const override {
    mlir::Location loc = op.getLoc();
    mlir::Attribute value = op.getValue();
    mlir::Type resultType = op.getResult().getType();
    mlir::Type convertedType = getTypeConverter()->convertType(resultType);
    if (!convertedType) {
      return rewriter.notifyMatchFailure(op, "failed to convert result type");
    }
    // Both the scalar path and the tensor-literal path read the ranked result
    // type, so reject anything else up front instead of dereferencing a null
    // type later.
    auto rankedResultType = mlir::dyn_cast<mlir::RankedTensorType>(resultType);
    if (!rankedResultType) {
      return rewriter.notifyMatchFailure(op, "result type must be a ranked tensor");
    }
    auto denseAttr = mlir::dyn_cast<mlir::DenseElementsAttr>(value);
    if (!denseAttr) {
      return rewriter.notifyMatchFailure(op, "value must be a dense elements attribute");
    }

    // Scalar-marked constants lower to torch.constant.{int,float,bool}.
    if (auto torchCst = convertDenseElementsAttrToTorchConstant(rewriter, denseAttr, loc, op.getResult())) {
      rewriter.replaceOp(op, torchCst);
      return mlir::success();
    }

    // Otherwise emit a value-tensor literal with the result's shape.
    llvm::SmallVector<int64_t> shape;
    auto sizes = rankedResultType.getShape();
    shape.reserve(sizes.size());
    std::transform(sizes.begin(), sizes.end(), std::back_inserter(shape),
                   [](int64_t dim) { return dim == mlir::ShapedType::kDynamic ? TorchD::kUnknownSize : dim; });
    auto tensorType = denseAttr.getType();
    auto vtensorType =
      TorchD::ValueTensorType::get(rewriter.getContext(), llvm::ArrayRef<int64_t>(shape), tensorType.getElementType());
    if (vtensorType) {
      auto torchTensor = rewriter.create<TorchD::ValueTensorLiteralOp>(loc, vtensorType, denseAttr);
      rewriter.replaceOp(op, torchTensor.getResult());
      return mlir::success();
    }
    return rewriter.notifyMatchFailure(op, "failed to convert tensor type");
  }
};
class ConvertAkgCallOp : public mlir::OpConversionPattern<mlir::mfuse::AkgCallOp> {
public:
using OpConversionPattern::OpConversionPattern;
mlir::LogicalResult matchAndRewrite(mlir::mfuse::AkgCallOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
llvm::SmallVector<mlir::Type, 4> resultTypes;
if (failed(getTypeConverter()->convertTypes(op.getResultTypes(), resultTypes))) {
return rewriter.notifyMatchFailure(op, "failed to convert result types");
}
auto numInputs = adaptor.getOperands().size();
auto numOutputs = op->getNumResults();
std::string opName = llvm::formatv("torch.mfusion.akg_call__i{0}_o{1}", numInputs, numOutputs).str();
mlir::OperationState subgraphState(op.getLoc(), "torch.constant.str");
subgraphState.addAttribute("value", op.getSubgraphAttr());
subgraphState.addTypes(TorchD::StringType::get(op.getContext()));
mlir::Operation *subgraphConst = rewriter.create(subgraphState);
mlir::Value subgraphValue = subgraphConst->getResult(0);
mlir::OperationState state(op.getLoc(), "torch.operator");
state.addOperands(adaptor.getOperands());
state.addOperands(subgraphValue);
state.addTypes(resultTypes);
state.addAttribute("name", rewriter.getStringAttr(opName));
state.addAttribute("mfusion.subgraph_mlir", op.getSubgraphMlirAttr());
state.addAttribute("mfusion.is_dynamic", op.getIsDynamicAttr());
mlir::Operation *newOp = rewriter.create(state);
rewriter.replaceOp(op, newOp->getResults());
return mlir::success();
}
};
class ConvertBishengCallOp : public mlir::OpConversionPattern<mlir::mfuse::BishengCallOp> {
public:
using OpConversionPattern::OpConversionPattern;
mlir::LogicalResult matchAndRewrite(mlir::mfuse::BishengCallOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
llvm::SmallVector<mlir::Type, 4> resultTypes;
if (failed(getTypeConverter()->convertTypes(op.getResultTypes(), resultTypes))) {
return rewriter.notifyMatchFailure(op, "failed to convert result types");
}
auto numInputs = adaptor.getOperands().size();
auto numOutputs = op->getNumResults();
std::string opName = llvm::formatv("torch.mfusion.bisheng_call__i{0}_o{1}", numInputs, numOutputs).str();
mlir::OperationState subgraphState(op.getLoc(), "torch.constant.str");
subgraphState.addAttribute("value", op.getSubgraphAttr());
subgraphState.addTypes(TorchD::StringType::get(op.getContext()));
mlir::Operation *subgraphConst = rewriter.create(subgraphState);
mlir::Value subgraphValue = subgraphConst->getResult(0);
mlir::OperationState state(op.getLoc(), "torch.operator");
state.addOperands(adaptor.getOperands());
state.addOperands(subgraphValue);
state.addTypes(resultTypes);
state.addAttribute("name", rewriter.getStringAttr(opName));
state.addAttribute("mfusion.subgraph_mlir", op.getSubgraphMlirAttr());
state.addAttribute("mfusion.is_dynamic", op.getIsDynamicAttr());
mlir::Operation *newOp = rewriter.create(state);
rewriter.replaceOp(op, newOp->getResults());
return mlir::success();
}
};
class ConvertDvmCallOp : public mlir::OpConversionPattern<mlir::mfuse::DvmCallOp> {
public:
using OpConversionPattern::OpConversionPattern;
mlir::LogicalResult matchAndRewrite(mlir::mfuse::DvmCallOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
llvm::SmallVector<mlir::Type, 4> resultTypes;
if (failed(getTypeConverter()->convertTypes(op.getResultTypes(), resultTypes))) {
return rewriter.notifyMatchFailure(op, "failed to convert result types");
}
auto numInputs = adaptor.getOperands().size();
auto numOutputs = op->getNumResults();
std::string opName = llvm::formatv("torch.mfusion.dvm_call__i{0}_o{1}", numInputs, numOutputs).str();
mlir::OperationState subgraphState(op.getLoc(), "torch.constant.str");
subgraphState.addAttribute("value", op.getSubgraphAttr());
subgraphState.addTypes(TorchD::StringType::get(op.getContext()));
mlir::Operation *subgraphConst = rewriter.create(subgraphState);
mlir::Value subgraphValue = subgraphConst->getResult(0);
mlir::OperationState state(op.getLoc(), "torch.operator");
state.addOperands(adaptor.getOperands());
state.addOperands(subgraphValue);
state.addTypes(resultTypes);
state.addAttribute("name", rewriter.getStringAttr(opName));
state.addAttribute("mfusion.subgraph_mlir", op.getSubgraphMlirAttr());
state.addAttribute("mfusion.is_dynamic", op.getIsDynamicAttr());
mlir::Operation *newOp = rewriter.create(state);
rewriter.replaceOp(op, newOp->getResults());
return mlir::success();
}
};
/// Module pass that converts Mfuse dialect operations to the Torch dialect.
///
/// Patterns are built once in initialize() (generated PDLL patterns, function
/// signature conversion, the meta/aclnn pattern sets, and the local
/// kernel-call and constant patterns) and applied as a partial conversion in
/// runOnOperation().
struct ConvertMfuseToTorchPass : public mlir::OperationPass<mlir::ModuleOp> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvertMfuseToTorchPass)

  ConvertMfuseToTorchPass()
      : mlir::OperationPass<mlir::ModuleOp>(mlir::TypeID::get<ConvertMfuseToTorchPass>()) {}

  mlir::StringRef getName() const override { return "convert-mfuse-to-torch"; }
  mlir::StringRef getArgument() const override { return "convert-mfuse-to-torch"; }
  mlir::StringRef getDescription() const override { return "Convert Mfuse operations to Torch dialect operations"; }

  /// Kernel backend selector; forwarded to the meta/aclnn pattern populators.
  mlir::Pass::Option<std::string> kernelGenerator{
    *this, "kernel-generator",
    llvm::cl::desc("Kernel backend: dvm (transpose as attrs on mm) or akg/bisheng (permute before mm)."),
    llvm::cl::init("dvm")};

  /// Creates a fresh, default-constructed instance of this pass.
  std::unique_ptr<mlir::Pass> clonePass() const override {
    return std::make_unique<ConvertMfuseToTorchPass>();
  }

  static bool classof(const mlir::Pass *pass) {
    return pass->getTypeID() == mlir::TypeID::get<ConvertMfuseToTorchPass>();
  }

  /// Declares the dialects this pass may create ops from.
  void getDependentDialects(mlir::DialectRegistry &registry) const override {
    registry.insert<mlir::arith::ArithDialect>();
    registry.insert<TorchD::TorchDialect>();
    registry.insert<mlir::mfuse::MfuseDialect>();
    registry.insert<mlir::func::FuncDialect>();
    registry.insert<mlir::pdl::PDLDialect>();
    registry.insert<mlir::pdl_interp::PDLInterpDialect>();
  }

  /// Builds and freezes the full pattern set. The patterns keep a reference to
  /// `converter_`, which lives in this pass object.
  mlir::LogicalResult initialize(mlir::MLIRContext *ctx) override {
    mlir::RewritePatternSet patternList(ctx);
    mlir::registerConversionPDLFunctions(patternList);
    mlir::registerPDLLHelperFunctions(patternList);
    populateGeneratedPDLLPatterns(patternList, mlir::PDLConversionConfig(&converter_));
    mlir::populateFunctionOpInterfaceTypeConversionPattern<mlir::func::FuncOp>(patternList, converter_);
    llvm::StringRef kg = kernelGenerator.getValue();
    mlir::populateMfuseMetaToTorchConversionPatterns(converter_, patternList, kg);
    mlir::populateMfuseAclnnToTorchConversionPatterns(converter_, patternList, kg);
    patternList.add<ConvertAkgCallOp, ConvertBishengCallOp, ConvertDvmCallOp>(converter_, ctx);
    patternList.add<ConvertMfuseConstantToTorch>(converter_, ctx);
    patterns_ = std::move(patternList);
    return mlir::success();
  }

  /// Runs the partial conversion: the Mfuse dialect is illegal; Torch, func and
  /// unrealized casts are legal; `func.func` / `func.return` are legal only
  /// once their types are legal under the converter.
  void runOnOperation() override {
    MLOG(DEBUG) << "convert-mfuse-to-torch pass start";
    mlir::ModuleOp module = getOperation();
    mlir::MLIRContext *ctx = &getContext();
    mlir::ConversionTarget target(*ctx);
    target.addIllegalDialect<mlir::mfuse::MfuseDialect>();
    target.addLegalDialect<TorchD::TorchDialect>();
    target.addLegalDialect<mlir::func::FuncDialect>();
    target.addLegalOp<mlir::UnrealizedConversionCastOp>();
    target.addDynamicallyLegalOp<mlir::func::FuncOp>(
      [&](mlir::func::FuncOp op) { return converter_.isSignatureLegal(op.getFunctionType()); });
    target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
      [&](mlir::func::ReturnOp op) { return converter_.isLegal(op.getOperandTypes()); });
    if (mlir::failed(mlir::applyPartialConversion(module, target, patterns_))) {
      signalPassFailure();
    }
    MLOG(DEBUG) << "convert-mfuse-to-torch pass end";
  }

  mlir::FrozenRewritePatternSet patterns_;  // Built in initialize().
  MfuseToTorchTypeConverter converter_;     // Referenced by the patterns above.
};
}
namespace mlir {
/// Factory for the convert-mfuse-to-torch pass; returns a newly constructed
/// ConvertMfuseToTorchPass owned by the caller.
std::unique_ptr<Pass> createConvertMfuseToTorchPass() {
  auto pass = std::make_unique<ConvertMfuseToTorchPass>();
  return pass;
}
}  // namespace mlir