* Copyright 2026 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mfusion/Conversion/TorchToMfuse/TorchToMfuse.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "mfusion/Conversion/MfuseTypeConverter.h"
#include "mfusion/Conversion/TorchToMfuse/TorchAtenToMfuse.h"
#include "mfusion/Conversion/TorchToMfuse/TorchNpuToMfuse.h"
#include "mfusion/Conversion/PdllHelper.h"
#include "mfusion/Dialect/Mfuse/IR/Mfuse.h"
#include "mfusion/Dialect/Mfuse/IR/MfuseDialect.h"
#include "mfusion/Support/Logging.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "TorchToMfuse.pdll.h.inc"
namespace {
namespace TorchD = mlir::torch::Torch;
/// Registers the Torch -> Mfuse type conversion rules on `converter`.
///
/// Rules are added in this order:
///   - TorchD::ValueTensorType -> the result of `toBuiltinTensor()`
///   - TorchD::IntType         -> rank-0 tensor of i64 with a scalar-marker encoding (marker value "")
///   - TorchD::FloatType       -> rank-0 tensor of f64 with a scalar-marker encoding (marker value "")
///   - TorchD::BoolType        -> rank-0 tensor of i1 with a scalar-marker encoding (marker value "!torch.bool")
///   - TorchD::StringType      -> mlir::mfuse::StringType
///   - TorchD::DeviceType      -> mlir::mfuse::StringType
///   - TorchD::NoneType        -> mlir::mfuse::NoneType
///   - TorchD::ListType        -> 1-D dynamically sized i64 tensor when the contained type is
///                                TorchD::IntType, otherwise mlir::mfuse::ListType of the converted
///                                contained type
///
/// The list rule captures `converter` by reference and calls back into it.
void populateTorchToMfuseTypeConversions(mlir::TypeConverter &converter) {
  // Builds a rank-0 ranked tensor of `elemTy` whose encoding is a dictionary with the single
  // entry kScalarMarkerAttr -> `markerText`.
  auto makeScalarTensor = [](mlir::MLIRContext *context, mlir::Type elemTy, const char *markerText) -> mlir::Type {
    mlir::NamedAttribute marker(mlir::StringAttr::get(context, mlir::mfuse::kScalarMarkerAttr),
                                mlir::StringAttr::get(context, markerText));
    return mlir::RankedTensorType::get({}, elemTy, mlir::DictionaryAttr::get(context, {marker}));
  };

  // Tensor rule.
  converter.addConversion([](TorchD::ValueTensorType type) -> mlir::Type { return type.toBuiltinTensor(); });

  // Scalar rules: each scalar becomes a marked rank-0 tensor.
  converter.addConversion([makeScalarTensor](TorchD::IntType type) -> mlir::Type {
    mlir::MLIRContext *context = type.getContext();
    return makeScalarTensor(context, mlir::IntegerType::get(context, 64), "");
  });
  converter.addConversion([makeScalarTensor](TorchD::FloatType type) -> mlir::Type {
    mlir::MLIRContext *context = type.getContext();
    return makeScalarTensor(context, mlir::Float64Type::get(context), "");
  });
  converter.addConversion([makeScalarTensor](TorchD::BoolType type) -> mlir::Type {
    mlir::MLIRContext *context = type.getContext();
    return makeScalarTensor(context, mlir::IntegerType::get(context, 1), "!torch.bool");
  });

  // String, device and none rules.
  converter.addConversion([](TorchD::StringType type) -> mlir::Type {
    return mlir::mfuse::StringType::get(type.getContext());
  });
  converter.addConversion([](TorchD::DeviceType type) -> mlir::Type {
    return mlir::mfuse::StringType::get(type.getContext());
  });
  converter.addConversion([](TorchD::NoneType type) -> mlir::Type {
    return mlir::mfuse::NoneType::get(type.getContext());
  });

  // List rule: int lists become a 1-D dynamic i64 tensor, every other list keeps list form with
  // its contained type run through `converter`.
  converter.addConversion([&converter](TorchD::ListType type) -> mlir::Type {
    mlir::MLIRContext *context = type.getContext();
    mlir::Type containedTy = type.getContainedType();
    if (!mlir::isa<TorchD::IntType>(containedTy)) {
      return mlir::mfuse::ListType::get(context, converter.convertType(containedTy));
    }
    return mlir::RankedTensorType::get({mlir::ShapedType::kDynamic}, mlir::IntegerType::get(context, 64));
  });
}
/// Type converter used by the Torch -> Mfuse conversion.
///
/// On construction it registers, in order: an identity conversion, the Mfuse type conversions
/// and materializations, the Torch -> Mfuse type conversions, and finally the Torch
/// materializations defined in this class.
class TorchToMfuseTypeConverter : public mlir::TypeConverter {
 public:
  TorchToMfuseTypeConverter() {
    // Identity rule goes in first; the more specific rules are registered after it.
    addConversion([](mlir::Type type) { return type; });
    mlir::mfuse::populateMfuseTypeConversions(*this);
    mlir::mfuse::populateMfuseTypeMaterializations(*this);
    populateTorchToMfuseTypeConversions(*this);
    addTorchMaterializations();
  }

 private:
  /// If `input` is defined by an `OpTy` and `toType` is a ranked tensor type, creates an
  /// mfuse ConstantOp of that type holding the op's value attribute and returns its result.
  /// Returns a null Value in every other case.
  template <typename OpTy>
  static mlir::Value tryConvertConstant(mlir::OpBuilder &builder, mlir::Type toType, mlir::Value input,
                                        mlir::Location loc) {
    auto constantOp = input.getDefiningOp<OpTy>();
    if (!constantOp) {
      return mlir::Value();
    }
    auto tensorTy = mlir::dyn_cast<mlir::RankedTensorType>(toType);
    if (!tensorTy) {
      return mlir::Value();
    }
    auto payload = mlir::DenseElementsAttr::get(tensorTy, constantOp.getValueAttr());
    auto mfuseConstant = builder.create<mlir::mfuse::ConstantOp>(loc, tensorTy, payload);
    return mfuseConstant.getResult();
  }

  /// Registers the target and source materializations for Torch values.
  void addTorchMaterializations() {
    // Target side: try to rebuild Torch int/float/bool constants as mfuse constants (in that
    // order); anything else is bridged with an unrealized conversion cast.
    addTargetMaterialization(
      [](mlir::OpBuilder &builder, mlir::Type toType, mlir::ValueRange inputs, mlir::Location loc) -> mlir::Value {
        if (inputs.size() != 1) {
          return mlir::Value();
        }
        mlir::Value source = inputs[0];
        mlir::Value folded = tryConvertConstant<TorchD::ConstantIntOp>(builder, toType, source, loc);
        if (!folded) {
          folded = tryConvertConstant<TorchD::ConstantFloatOp>(builder, toType, source, loc);
        }
        if (!folded) {
          folded = tryConvertConstant<TorchD::ConstantBoolOp>(builder, toType, source, loc);
        }
        if (folded) {
          return folded;
        }
        auto bridge = builder.create<mlir::UnrealizedConversionCastOp>(loc, toType, inputs);
        return bridge.getResult(0);
      });

    // Source side: always bridge a single input with an unrealized conversion cast.
    addSourceMaterialization(
      [](mlir::OpBuilder &builder, mlir::Type toType, mlir::ValueRange inputs, mlir::Location loc) -> mlir::Value {
        if (inputs.size() != 1) {
          return mlir::Value();
        }
        auto bridge = builder.create<mlir::UnrealizedConversionCastOp>(loc, toType, inputs);
        return bridge.getResult(0);
      });
  }
};
/// Module-level pass that converts Torch operations to Mfuse dialect operations.
///
/// The rewrite patterns (generated PDLL patterns plus the Aten and NPU conversion patterns) are
/// built once in `initialize()`; `runOnOperation()` then applies them as a partial conversion.
struct ConvertTorchToMfusePass
    : public mlir::PassWrapper<ConvertTorchToMfusePass, mlir::OperationPass<mlir::ModuleOp>> {
  mlir::StringRef getArgument() const final { return "convert-torch-to-mfuse"; }
  mlir::StringRef getDescription() const final { return "Convert Torch operations to Mfuse dialect operations"; }

  /// Declares the dialects this pass depends on.
  void getDependentDialects(mlir::DialectRegistry &registry) const override {
    registry.insert<TorchD::TorchDialect>();
    registry.insert<mlir::arith::ArithDialect>();
    registry.insert<mlir::mfuse::MfuseDialect>();
    registry.insert<mlir::func::FuncDialect>();
    registry.insert<mlir::pdl::PDLDialect>();
    registry.insert<mlir::pdl_interp::PDLInterpDialect>();
  }

  /// Builds the pattern set against `converter_` and stores it in `patterns_`.
  /// Always returns success.
  mlir::LogicalResult initialize(mlir::MLIRContext *ctx) override {
    mlir::RewritePatternSet patternList(ctx);
    // Helper functions are registered before the generated PDLL patterns are added.
    mlir::registerConversionPDLFunctions(patternList);
    mlir::registerPDLLHelperFunctions(patternList);
    populateGeneratedPDLLPatterns(patternList, mlir::PDLConversionConfig(&converter_));
    mlir::populateAtenToMfuseConversionPatterns(converter_, patternList);
    mlir::populateNpuToMfuseConversionPatterns(converter_, patternList);
    patterns_ = std::move(patternList);
    return mlir::success();
  }

  /// Applies the stored patterns to the module as a partial conversion and signals pass
  /// failure if the conversion fails.
  void runOnOperation() override {
    MLOG(DEBUG) << "convert-torch-to-mfuse pass start";
    mlir::ModuleOp module = getOperation();
    mlir::MLIRContext *ctx = &getContext();

    mlir::ConversionTarget target(*ctx);
    target.addLegalDialect<mlir::arith::ArithDialect>();
    target.addLegalDialect<mlir::mfuse::MfuseDialect>();
    target.addLegalOp<mlir::UnrealizedConversionCastOp>();
    target.addLegalOp<TorchD::BindSymbolicShapeOp>();
    target.addLegalOp<TorchD::SymbolicIntOp>();
    // A tensor literal is treated as legal (left as is) unless its value is a dense splat whose
    // result dtype is known and accepted by getTorchScalarTypeInt.
    target.addDynamicallyLegalOp<TorchD::ValueTensorLiteralOp>([](TorchD::ValueTensorLiteralOp op) {
      auto denseAttr = mlir::dyn_cast<mlir::DenseElementsAttr>(op.getValueAttr());
      if (!denseAttr || !denseAttr.isSplat()) {
        return true;
      }
      auto tensorType = mlir::cast<TorchD::ValueTensorType>(op.getType());
      mlir::Type dtype = tensorType.getOptionalDtype();
      return !dtype || mlir::failed(mlir::getTorchScalarTypeInt(dtype));
    });

    if (mlir::failed(mlir::applyPartialConversion(module, target, patterns_))) {
      signalPassFailure();
    }
    MLOG(DEBUG) << "convert-torch-to-mfuse pass end";
  }

  // Declared before `patterns_` so that it is destroyed after the patterns that were built
  // against it in initialize() (defensive ordering; members are destroyed in reverse order).
  TorchToMfuseTypeConverter converter_;
  mlir::FrozenRewritePatternSet patterns_;
};
}
namespace mlir {
/// Creates a new instance of the Torch -> Mfuse conversion pass.
std::unique_ptr<Pass> createConvertTorchToMfusePass() {
  std::unique_ptr<Pass> pass = std::make_unique<ConvertTorchToMfusePass>();
  return pass;
}
}  // namespace mlir