* Copyright 2026 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mfusion/Conversion/MfuseToTorch/MfuseMetaToTorch.h"
#include <algorithm>
#include <iterator>
#include <optional>
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "mfusion/Dialect/Mfuse/IR/Mfuse.h"
namespace mlir {
namespace TorchD = mlir::torch::Torch;
namespace {
/// Returns true when the configured kernel generator is exactly "dvm"
/// (case-sensitive comparison).
static bool isDvmKernelGenerator(llvm::StringRef kernelGenerator) {
  constexpr llvm::StringLiteral kDvmGenerator("dvm");
  return kernelGenerator == kDvmGenerator;
}
/// Returns the statically known rank of `ty`.
///
/// Two kinds of type have a known rank: a builtin RankedTensorType, and a
/// Torch ValueTensorType that has sizes. Every other type yields
/// std::nullopt, including a ValueTensorType without sizes.
static std::optional<int64_t> getTorchOrRankedTensorRank(mlir::Type ty) {
  std::optional<int64_t> rank;
  if (auto rankedTy = mlir::dyn_cast<mlir::RankedTensorType>(ty)) {
    rank = static_cast<int64_t>(rankedTy.getRank());
  } else if (auto valueTensorTy = mlir::dyn_cast<TorchD::ValueTensorType>(ty)) {
    if (valueTensorTy.hasSizes()) {
      rank = static_cast<int64_t>(valueTensorTy.getSizes().size());
    }
  }
  return rank;
}
/// Returns true only when both matmul operand types have a statically known
/// rank of exactly 2.
static bool isTwoDMatmulOperandTypes(mlir::Type selfTy, mlir::Type otherTy) {
  constexpr int64_t kMatrixRank = 2;
  const std::optional<int64_t> selfRank = getTorchOrRankedTensorRank(selfTy);
  if (!selfRank.has_value() || *selfRank != kMatrixRank) {
    return false;
  }
  const std::optional<int64_t> otherRank = getTorchOrRankedTensorRank(otherTy);
  return otherRank.has_value() && *otherRank == kMatrixRank;
}
/// Emits a TorchD::AtenPermuteOp that exchanges the two innermost dimensions of `v`.
///
/// `v` must have a Torch ValueTensorType with known sizes and a rank of at
/// least 2. Otherwise failure() is returned and no IR is created.
///
/// The emitted permutation is [0, 1, ..., rank-3, rank-1, rank-2]. The result
/// type is `v`'s type with its last two sizes swapped; the dtype is unchanged.
static mlir::FailureOr<mlir::Value> buildSwapLastTwoDimsPermute(mlir::Location loc, mlir::Value v,
                                                                mlir::ConversionPatternRewriter &rewriter) {
  auto tensorTy = mlir::dyn_cast<TorchD::ValueTensorType>(v.getType());
  if (!tensorTy || !tensorTy.hasSizes()) {
    return mlir::failure();
  }
  auto originalSizes = tensorTy.getSizes();
  const int64_t rank = static_cast<int64_t>(originalSizes.size());
  if (rank < 2) {
    return mlir::failure();
  }
  const int64_t innerDim = rank - 1;
  const int64_t outerDim = rank - 2;

  llvm::SmallVector<int64_t> swappedSizes(originalSizes.begin(), originalSizes.end());
  std::swap(swappedSizes[outerDim], swappedSizes[innerDim]);
  mlir::Type permutedTy = tensorTy.getWithSizesAndDtype(swappedSizes, tensorTy.getOptionalDtype());

  // Source dimension for each output position: identity except the last two.
  llvm::SmallVector<mlir::Value> dimConstants;
  dimConstants.reserve(static_cast<size_t>(rank));
  for (int64_t pos = 0; pos < rank; ++pos) {
    int64_t srcDim = pos;
    if (pos == outerDim) {
      srcDim = innerDim;
    } else if (pos == innerDim) {
      srcDim = outerDim;
    }
    dimConstants.push_back(rewriter.create<TorchD::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(srcDim)));
  }

  mlir::MLIRContext *ctx = rewriter.getContext();
  mlir::Type intListTy = TorchD::ListType::get(ctx, TorchD::IntType::get(ctx));
  mlir::Value dimList = rewriter.create<TorchD::PrimListConstructOp>(loc, intListTy, dimConstants);
  auto permuteOp = rewriter.create<TorchD::AtenPermuteOp>(loc, permutedTy, v, dimList);
  return permuteOp.getResult();
}
/// Maps an MLIR element type to the integer value of the corresponding PyTorch
/// ScalarType.
///
/// Returns std::nullopt for none types and for integer-like types that have no
/// Torch scalar type. Signless integers are treated as signed:
///   i8 -> Char, i16 -> Short, i32 -> Int, i64 -> Long.
/// i1 and every other type are left to TorchD::getScalarTypeForType.
std::optional<int64_t> getTorchScalarTypeInt(mlir::Type type) {
  using mlir::torch::torch_upstream::ScalarType;
  if (mlir::isa<mlir::NoneType, mlir::mfuse::NoneType, TorchD::NoneType>(type)) {
    return std::nullopt;
  }
  // The guards below keep unsupported integer-like types away from
  // TorchD::getScalarTypeForType. That function reports a fatal error on
  // unhandled types instead of failing gracefully.
  if (type.isIndex()) {
    return std::nullopt;
  }
  if (type.isSignlessInteger() && !type.isSignlessInteger(1)) {
    if (type.isSignlessInteger(8)) return static_cast<int64_t>(ScalarType::Char);
    if (type.isSignlessInteger(16)) return static_cast<int64_t>(ScalarType::Short);
    if (type.isSignlessInteger(32)) return static_cast<int64_t>(ScalarType::Int);
    if (type.isSignlessInteger(64)) return static_cast<int64_t>(ScalarType::Long);
    return std::nullopt;
  }
  // Signed integers: only the widths Torch has scalar types for are forwarded.
  if (type.isSignedInteger() && !type.isSignedInteger(8) && !type.isSignedInteger(16) &&
      !type.isSignedInteger(32) && !type.isSignedInteger(64)) {
    return std::nullopt;
  }
  // Unsigned integers: only ui8 (Torch's Byte) is supported.
  if (type.isUnsignedInteger() && !type.isUnsignedInteger(8)) {
    return std::nullopt;
  }
  // NOTE(review): non-integer types without a Torch mapping are still
  // forwarded here unchanged.
  return static_cast<int64_t>(TorchD::getScalarTypeForType(type));
}
/// Materializes the Torch ScalarType code of `dtypeType` as a TorchD::ConstantIntOp.
///
/// Fails without creating IR when getTorchScalarTypeInt has no mapping for
/// the type.
mlir::FailureOr<mlir::Value> buildTorchDtypeValue(mlir::Type dtypeType, mlir::Location loc,
                                                  mlir::ConversionPatternRewriter &rewriter) {
  const std::optional<int64_t> scalarTypeCode = getTorchScalarTypeInt(dtypeType);
  if (!scalarTypeCode.has_value()) {
    return mlir::failure();
  }
  mlir::IntegerAttr codeAttr = rewriter.getI64IntegerAttr(scalarTypeCode.value());
  auto constantOp = rewriter.create<TorchD::ConstantIntOp>(loc, codeAttr);
  return constantOp.getResult();
}
/// Builds a `!torch.list<int>` from an ArrayAttr whose elements are IntegerAttr.
///
/// Each element goes through mlir::cast, so a non-integer element is treated
/// as a programming error rather than a recoverable match failure.
mlir::Value buildTorchIntListFromI64ArrayAttr(mlir::ArrayAttr attr, mlir::Location loc,
                                              mlir::ConversionPatternRewriter &rewriter) {
  mlir::MLIRContext *ctx = rewriter.getContext();
  llvm::SmallVector<mlir::Value> elementValues;
  elementValues.reserve(attr.size());
  std::transform(attr.begin(), attr.end(), std::back_inserter(elementValues),
                 [&](mlir::Attribute element) -> mlir::Value {
                   const int64_t value = mlir::cast<mlir::IntegerAttr>(element).getInt();
                   return rewriter.create<TorchD::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(value));
                 });
  auto intListTy = TorchD::ListType::get(ctx, TorchD::IntType::get(ctx));
  return rewriter.create<TorchD::PrimListConstructOp>(loc, intListTy, elementValues).getResult();
}
/// Builds a `!torch.list<int>` from either a DenseI64ArrayAttr or an ArrayAttr
/// of IntegerAttr.
///
/// Returns failure() in two cases: any other attribute kind, or an ArrayAttr
/// that contains a non-integer element.
static mlir::FailureOr<mlir::Value> buildTorchIntListFromI64LikeAttr(mlir::Attribute attr, mlir::Location loc,
                                                                     mlir::ConversionPatternRewriter &rewriter) {
  auto makeConstant = [&](int64_t value) -> mlir::Value {
    return rewriter.create<TorchD::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(value));
  };

  llvm::SmallVector<mlir::Value> elements;
  if (auto arrayAttr = mlir::dyn_cast<mlir::ArrayAttr>(attr)) {
    elements.reserve(arrayAttr.size());
    for (mlir::Attribute entry : arrayAttr) {
      auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(entry);
      if (!intAttr) {
        return mlir::failure();
      }
      elements.push_back(makeConstant(intAttr.getInt()));
    }
  } else if (auto denseAttr = mlir::dyn_cast<mlir::DenseI64ArrayAttr>(attr)) {
    auto rawValues = denseAttr.asArrayRef();
    elements.reserve(rawValues.size());
    for (int64_t value : rawValues) {
      elements.push_back(makeConstant(value));
    }
  } else {
    return mlir::failure();
  }

  mlir::MLIRContext *ctx = rewriter.getContext();
  auto intListTy = TorchD::ListType::get(ctx, TorchD::IntType::get(ctx));
  auto listOp = rewriter.create<TorchD::PrimListConstructOp>(loc, intListTy, elements);
  return listOp.getResult();
}
/// Lowers mfuse::BroadcastToOp to TorchD::AtenExpandOp.
///
/// The target sizes come from the ranked output type. Dynamic dimensions are
/// emitted as -1, which aten.expand treats as "keep the input's size for this
/// dimension". The `implicit` operand is always false.
class ConvertMfuseBroadcastTo : public mlir::OpConversionPattern<mlir::mfuse::BroadcastToOp> {
 public:
  using OpConversionPattern::OpConversionPattern;
  mlir::LogicalResult matchAndRewrite(mlir::mfuse::BroadcastToOp op, OpAdaptor adaptor,
                                      mlir::ConversionPatternRewriter &rewriter) const override {
    // Validate before creating any IR so that a failed match leaves no stray
    // ops. dyn_cast (instead of cast) avoids an assertion on non-ranked output.
    auto outType = mlir::dyn_cast<mlir::RankedTensorType>(op.getOutput().getType());
    if (!outType) {
      return rewriter.notifyMatchFailure(op, "output must be ranked tensor");
    }
    mlir::Type torchResultType = getTypeConverter()->convertType(outType);
    if (!torchResultType) {
      return rewriter.notifyMatchFailure(op, "result type conversion failed");
    }

    mlir::Location loc = op.getLoc();
    llvm::SmallVector<mlir::Value> sizeValues;
    sizeValues.reserve(outType.getRank());
    for (int64_t dim : outType.getShape()) {
      const int64_t torchDim = dim == mlir::ShapedType::kDynamic ? -1 : dim;
      sizeValues.push_back(rewriter.create<TorchD::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(torchDim)));
    }
    auto listType = TorchD::ListType::get(op.getContext(), TorchD::IntType::get(op.getContext()));
    mlir::Value sizeList = rewriter.create<TorchD::PrimListConstructOp>(loc, listType, sizeValues);
    mlir::Value implicit = rewriter.create<TorchD::ConstantBoolOp>(loc, false);
    rewriter.replaceOpWithNewOp<TorchD::AtenExpandOp>(op, torchResultType, adaptor.getInput(), sizeList, implicit);
    return mlir::success();
  }
};
/// Emits TorchD::AtenConvolutionOp from mfuse convolution attributes.
///
/// stride, padding, dilation and output_padding are converted to
/// `!torch.list<int>` via buildTorchIntListFromI64LikeAttr, in that order. The
/// first attribute that cannot be converted makes the whole build fail.
/// `transposed` and `groups` become Torch constants.
static mlir::FailureOr<mlir::Value> buildAtenConvolutionFromMfuseConvAttrs(
    mlir::ConversionPatternRewriter &rewriter, mlir::Location loc, mlir::Type resultType, mlir::Value input,
    mlir::Value weight, mlir::Value bias, mlir::Attribute strideAttr, mlir::Attribute paddingAttr,
    mlir::Attribute dilationAttr, bool transposed, mlir::Attribute outputPaddingAttr, int64_t groups) {
  // Positions of each converted list in `lists`, matching `listAttrs` below.
  constexpr size_t kStride = 0;
  constexpr size_t kPadding = 1;
  constexpr size_t kDilation = 2;
  constexpr size_t kOutputPadding = 3;
  const mlir::Attribute listAttrs[] = {strideAttr, paddingAttr, dilationAttr, outputPaddingAttr};

  llvm::SmallVector<mlir::Value, 4> lists;
  for (mlir::Attribute listAttr : listAttrs) {
    mlir::FailureOr<mlir::Value> list = buildTorchIntListFromI64LikeAttr(listAttr, loc, rewriter);
    if (mlir::failed(list)) {
      return mlir::failure();
    }
    lists.push_back(*list);
  }

  mlir::Value transposedVal = rewriter.create<TorchD::ConstantBoolOp>(loc, transposed);
  mlir::Value groupsVal = rewriter.create<TorchD::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(groups));
  auto convOp = rewriter.create<TorchD::AtenConvolutionOp>(loc, resultType, input, weight, bias, lists[kStride],
                                                           lists[kPadding], lists[kDilation], transposedVal,
                                                           lists[kOutputPadding], groupsVal);
  return convOp.getResult();
}
/// Lowers mfuse::AclnnConv2DOp, which has no bias operand, to
/// TorchD::AtenConvolutionOp with a `None` bias.
class ConvertMfuseAclnnConv2D : public mlir::OpConversionPattern<mlir::mfuse::AclnnConv2DOp> {
 public:
  using OpConversionPattern::OpConversionPattern;
  mlir::LogicalResult matchAndRewrite(mlir::mfuse::AclnnConv2DOp op, OpAdaptor adaptor,
                                      mlir::ConversionPatternRewriter &rewriter) const override {
    mlir::Type convertedTy = getTypeConverter()->convertType(op.getResult().getType());
    if (!convertedTy) {
      return rewriter.notifyMatchFailure(op, "result type conversion failed");
    }
    const mlir::Location loc = op.getLoc();
    mlir::Value bias = rewriter.create<TorchD::ConstantNoneOp>(loc);
    mlir::FailureOr<mlir::Value> convResult = buildAtenConvolutionFromMfuseConvAttrs(
        rewriter, loc, convertedTy, adaptor.getInput(), adaptor.getWeight(), bias, op.getStride(),
        op.getPadding(), op.getDilation(), op.getTransposed(), op.getOutputPadding(), op.getGroups());
    if (!mlir::succeeded(convResult)) {
      return rewriter.notifyMatchFailure(
          op, "mfuse conv stride/padding/dilation/output_padding must be DenseI64ArrayAttr or ArrayAttr of integers");
    }
    rewriter.replaceOp(op, convResult.value());
    return mlir::success();
  }
};
/// Lowers mfuse::AclnnConv2DWithBiasOp to TorchD::AtenConvolutionOp,
/// forwarding the converted bias operand.
class ConvertMfuseAclnnConv2DWithBias : public mlir::OpConversionPattern<mlir::mfuse::AclnnConv2DWithBiasOp> {
 public:
  using OpConversionPattern::OpConversionPattern;
  mlir::LogicalResult matchAndRewrite(mlir::mfuse::AclnnConv2DWithBiasOp op, OpAdaptor adaptor,
                                      mlir::ConversionPatternRewriter &rewriter) const override {
    mlir::Type convertedTy = getTypeConverter()->convertType(op.getResult().getType());
    if (!convertedTy) {
      return rewriter.notifyMatchFailure(op, "result type conversion failed");
    }
    mlir::FailureOr<mlir::Value> convResult = buildAtenConvolutionFromMfuseConvAttrs(
        rewriter, op.getLoc(), convertedTy, adaptor.getInput(), adaptor.getWeight(), adaptor.getBias(),
        op.getStride(), op.getPadding(), op.getDilation(), op.getTransposed(), op.getOutputPadding(),
        op.getGroups());
    if (!mlir::succeeded(convResult)) {
      return rewriter.notifyMatchFailure(
          op, "mfuse conv stride/padding/dilation/output_padding must be DenseI64ArrayAttr or ArrayAttr of integers");
    }
    rewriter.replaceOp(op, convResult.value());
    return mlir::success();
  }
};
/// Lowers mfuse::FullOp to TorchD::AtenFullOp.
///
/// The size list comes from the result's static shape, and the fill value is
/// the converted fill_value operand.
///
/// The optional dtype, layout, device and pin_memory attributes are forwarded
/// when present and become `None` otherwise. One exception: a missing dtype is
/// first derived from the result element type, and falls back to `None` only
/// if that type has no Torch scalar type.
class ConvertMfuseFull : public mlir::OpConversionPattern<mlir::mfuse::FullOp> {
 public:
  using OpConversionPattern::OpConversionPattern;
  mlir::LogicalResult matchAndRewrite(mlir::mfuse::FullOp op, OpAdaptor adaptor,
                                      mlir::ConversionPatternRewriter &rewriter) const override {
    // All checks run before any IR is created so a failed match leaves no stray ops.
    auto resultType = mlir::dyn_cast<mlir::RankedTensorType>(op.getResult().getType());
    if (!resultType) {
      return rewriter.notifyMatchFailure(op, "result must be ranked tensor");
    }
    // aten.full needs concrete sizes. Without this check, a dynamic dimension
    // would be emitted as the ShapedType::kDynamic sentinel value.
    if (!resultType.hasStaticShape()) {
      return rewriter.notifyMatchFailure(op, "result must have a static shape");
    }
    mlir::Type torchResultType = getTypeConverter()->convertType(resultType);
    if (!torchResultType) {
      return rewriter.notifyMatchFailure(op, "result type conversion failed");
    }

    mlir::Location loc = op.getLoc();
    auto createNone = [&]() -> mlir::Value { return rewriter.create<TorchD::ConstantNoneOp>(loc); };

    llvm::SmallVector<mlir::Value> sizeValues;
    sizeValues.reserve(resultType.getRank());
    for (int64_t dim : resultType.getShape()) {
      sizeValues.push_back(rewriter.create<TorchD::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(dim)));
    }
    auto listType = TorchD::ListType::get(op.getContext(), TorchD::IntType::get(op.getContext()));
    mlir::Value sizeList = rewriter.create<TorchD::PrimListConstructOp>(loc, listType, sizeValues);
    mlir::Value fillVal = adaptor.getFillValue();

    mlir::Value dtypeVal;
    if (auto dtypeAttr = op.getDtypeAttr()) {
      dtypeVal = rewriter.create<TorchD::ConstantIntOp>(loc, dtypeAttr);
    } else {
      auto derivedDtype = buildTorchDtypeValue(resultType.getElementType(), loc, rewriter);
      dtypeVal = mlir::succeeded(derivedDtype) ? *derivedDtype : createNone();
    }

    mlir::Value layoutVal;
    if (auto layoutAttr = op.getLayoutAttr()) {
      layoutVal = rewriter.create<TorchD::ConstantIntOp>(loc, layoutAttr);
    } else {
      layoutVal = createNone();
    }

    mlir::Value deviceVal;
    if (auto deviceAttr = op.getDeviceAttr()) {
      deviceVal = rewriter.create<TorchD::ConstantDeviceOp>(loc, deviceAttr);
    } else {
      deviceVal = createNone();
    }

    mlir::Value pinMemoryVal;
    if (auto pinMemoryAttr = op.getPinMemoryAttr()) {
      pinMemoryVal = rewriter.create<TorchD::ConstantBoolOp>(loc, pinMemoryAttr.getValue());
    } else {
      pinMemoryVal = createNone();
    }

    rewriter.replaceOpWithNewOp<TorchD::AtenFullOp>(op, torchResultType, sizeList, fillVal, dtypeVal, layoutVal,
                                                    deviceVal, pinMemoryVal);
    return mlir::success();
  }
};
/// Lowers mfuse::MatmulOp to TorchD::AtenMmOp or TorchD::AtenMatmulOp.
///
/// DVM path: when both operands are 2-D and the kernel generator is "dvm", a
/// single AtenMmOp is emitted on the untransposed operands. The transpose
/// flags are recorded as the bool attributes `dvm_trans_a` / `dvm_trans_b`.
///
/// Other cases: each operand whose transpose flag is set is first permuted
/// (last two dims swapped). AtenMmOp is then used for 2-D operands and
/// AtenMatmulOp for any other rank.
class ConvertMfuseMatmul : public mlir::OpConversionPattern<mlir::mfuse::MatmulOp> {
 public:
  ConvertMfuseMatmul(mlir::TypeConverter &converter, mlir::MLIRContext *ctx, llvm::StringRef kernelGenerator)
      : OpConversionPattern(converter, ctx), kernelGenerator_(kernelGenerator.str()) {}

  mlir::LogicalResult matchAndRewrite(mlir::mfuse::MatmulOp op, OpAdaptor adaptor,
                                      mlir::ConversionPatternRewriter &rewriter) const override {
    mlir::Value lhs = adaptor.getSelf();
    mlir::Value rhs = adaptor.getOther();
    const bool transposeLhs = op.getTransX1();
    const bool transposeRhs = op.getTransX2();
    mlir::Type mmType = getTypeConverter()->convertType(op.getResult().getType());
    if (!mmType) {
      return mlir::failure();
    }
    const mlir::Location loc = op.getLoc();
    const bool isMatrixPair = isTwoDMatmulOperandTypes(lhs.getType(), rhs.getType());

    if (isMatrixPair && isDvmKernelGenerator(kernelGenerator_)) {
      // The transpose flags are recorded as attributes instead of materializing a permute.
      auto mmOp = rewriter.create<TorchD::AtenMmOp>(loc, mmType, lhs, rhs);
      mmOp->setAttr("dvm_trans_a", rewriter.getBoolAttr(transposeLhs));
      mmOp->setAttr("dvm_trans_b", rewriter.getBoolAttr(transposeRhs));
      rewriter.replaceOp(op, mmOp.getResult());
      return mlir::success();
    }

    // Replaces `operand` with its last-two-dims transpose when `enabled`.
    // Returns false if the permute cannot be built.
    auto applyTranspose = [&](bool enabled, mlir::Value &operand) -> bool {
      if (!enabled) {
        return true;
      }
      mlir::FailureOr<mlir::Value> permuted = buildSwapLastTwoDimsPermute(loc, operand, rewriter);
      if (mlir::failed(permuted)) {
        return false;
      }
      operand = *permuted;
      return true;
    };
    if (!applyTranspose(transposeLhs, lhs) || !applyTranspose(transposeRhs, rhs)) {
      return mlir::failure();
    }

    if (!isMatrixPair) {
      rewriter.replaceOpWithNewOp<TorchD::AtenMatmulOp>(op, mmType, lhs, rhs);
    } else {
      rewriter.replaceOpWithNewOp<TorchD::AtenMmOp>(op, mmType, lhs, rhs);
    }
    return mlir::success();
  }

 private:
  std::string kernelGenerator_;
};
/// Lowers mfuse::MatmulWithBiasOp to a matmul followed by TorchD::AtenAddTensorOp.
///
/// DVM path: when both operands are 2-D and the kernel generator is "dvm", a
/// single AtenMmOp is emitted on the untransposed operands. The transpose
/// flags are recorded as the bool attributes `dvm_trans_a` / `dvm_trans_b`.
///
/// Other cases: each operand whose transpose flag is set is permuted first
/// (last two dims swapped). AtenMmOp is used for 2-D operands and AtenMatmulOp
/// for any other rank.
///
/// The bias is then added with alpha = 1.
class ConvertMfuseMatmulWithBias : public mlir::OpConversionPattern<mlir::mfuse::MatmulWithBiasOp> {
 public:
  ConvertMfuseMatmulWithBias(mlir::TypeConverter &converter, mlir::MLIRContext *ctx, llvm::StringRef kernelGenerator)
      : OpConversionPattern(converter, ctx), kernelGenerator_(kernelGenerator.str()) {}

  mlir::LogicalResult matchAndRewrite(mlir::mfuse::MatmulWithBiasOp op, OpAdaptor adaptor,
                                      mlir::ConversionPatternRewriter &rewriter) const override {
    mlir::Value self = adaptor.getSelf();
    mlir::Value other = adaptor.getOther();
    mlir::Value bias = adaptor.getBias();
    const bool trans1 = op.getTransX1();
    const bool trans2 = op.getTransX2();
    auto resultType = getTypeConverter()->convertType(op.getResult().getType());
    if (!resultType) {
      return mlir::failure();
    }
    mlir::Location loc = op.getLoc();
    const bool twoD = isTwoDMatmulOperandTypes(self.getType(), other.getType());
    const bool dvm = isDvmKernelGenerator(kernelGenerator_);

    mlir::Value matmulResult;
    if (twoD && dvm) {
      auto newMm = rewriter.create<TorchD::AtenMmOp>(loc, resultType, self, other);
      newMm->setAttr("dvm_trans_a", rewriter.getBoolAttr(trans1));
      newMm->setAttr("dvm_trans_b", rewriter.getBoolAttr(trans2));
      matmulResult = newMm.getResult();
    } else {
      if (trans1) {
        auto permOr = buildSwapLastTwoDimsPermute(loc, self, rewriter);
        if (mlir::failed(permOr)) {
          return mlir::failure();
        }
        self = *permOr;
      }
      if (trans2) {
        auto permOr = buildSwapLastTwoDimsPermute(loc, other, rewriter);
        if (mlir::failed(permOr)) {
          return mlir::failure();
        }
        other = *permOr;
      }
      if (twoD) {
        matmulResult = rewriter.create<TorchD::AtenMmOp>(loc, resultType, self, other).getResult();
      } else {
        matmulResult = rewriter.create<TorchD::AtenMatmulOp>(loc, resultType, self, other).getResult();
      }
    }

    // Use an integer alpha. PyTorch rejects a floating-point alpha for
    // integral tensors, whereas integer 1 is valid for every dtype. This also
    // matches the alpha that ConvertBinaryOpWithAlphaPattern emits for
    // AtenAddTensorOp.
    constexpr int64_t kAlphaOne = 1;
    mlir::Value alphaOne = rewriter.create<TorchD::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(kAlphaOne));
    mlir::Value addResult =
        rewriter.create<TorchD::AtenAddTensorOp>(loc, resultType, matmulResult, bias, alphaOne);
    rewriter.replaceOp(op, addResult);
    return mlir::success();
  }

 private:
  std::string kernelGenerator_;
};
/// Lowers mfuse::CastOp to TorchD::PrimsConvertElementTypeOp.
///
/// The target dtype is the Torch ScalarType code of the result element type.
/// The rewrite fails when the result is not a tensor, when its type cannot be
/// converted, or when the element type has no Torch scalar type.
struct ConvertMfuseCast : public mlir::OpConversionPattern<mlir::mfuse::CastOp> {
  using OpConversionPattern::OpConversionPattern;
  mlir::LogicalResult matchAndRewrite(mlir::mfuse::CastOp op, OpAdaptor adaptor,
                                      mlir::ConversionPatternRewriter &rewriter) const override {
    // dyn_cast instead of cast<RankedTensorType>: only the element type is
    // needed, and an unexpected type must fail the match rather than assert.
    auto srcResultType = mlir::dyn_cast<mlir::TensorType>(op.getResult().getType());
    if (!srcResultType) {
      return rewriter.notifyMatchFailure(op, "result must be a tensor");
    }
    // A null converted type must never reach the op builder.
    mlir::Type resultType = getTypeConverter()->convertType(srcResultType);
    if (!resultType) {
      return rewriter.notifyMatchFailure(op, "result type conversion failed");
    }
    auto dtypeValOrFailure = buildTorchDtypeValue(srcResultType.getElementType(), op.getLoc(), rewriter);
    if (mlir::failed(dtypeValOrFailure)) {
      return rewriter.notifyMatchFailure(op, "unsupported dtype for torch scalar type");
    }
    rewriter.replaceOpWithNewOp<TorchD::PrimsConvertElementTypeOp>(op, resultType, adaptor.getInput(),
                                                                   *dtypeValOrFailure);
    return mlir::success();
  }
};
/// Lowers mfuse::PermuteOp to TorchD::AtenPermuteOp.
///
/// Requirements:
///  - a `perm` attribute is present;
///  - the input is a ranked tensor;
///  - there is one perm entry per input dimension;
///  - every entry is an IntegerAttr.
/// Each entry becomes a TorchD::ConstantIntOp in the permutation list.
class ConvertMfusePermute : public mlir::OpConversionPattern<mlir::mfuse::PermuteOp> {
 public:
  using OpConversionPattern::OpConversionPattern;
  mlir::LogicalResult matchAndRewrite(mlir::mfuse::PermuteOp op, OpAdaptor adaptor,
                                      mlir::ConversionPatternRewriter &rewriter) const override {
    const mlir::Location loc = op.getLoc();
    auto permAttr = op.getPermAttr();
    if (!permAttr) {
      return rewriter.notifyMatchFailure(op, "perm attribute must be present");
    }
    auto rankedInputTy = mlir::dyn_cast<mlir::RankedTensorType>(op.getInput().getType());
    if (!rankedInputTy) {
      return rewriter.notifyMatchFailure(op, "input must be ranked tensor");
    }
    const auto perm = permAttr.getValue();
    const size_t inputRank = static_cast<size_t>(rankedInputTy.getRank());
    if (perm.size() != inputRank) {
      return rewriter.notifyMatchFailure(op, "perm size must match input rank");
    }

    llvm::SmallVector<mlir::Value> dimConstants;
    dimConstants.reserve(perm.size());
    for (auto entry : perm) {
      auto dimAttr = mlir::dyn_cast<mlir::IntegerAttr>(entry);
      if (!dimAttr) {
        return rewriter.notifyMatchFailure(op, "perm values must be integers");
      }
      dimConstants.push_back(rewriter.create<TorchD::ConstantIntOp>(loc, dimAttr));
    }

    mlir::Type convertedTy = getTypeConverter()->convertType(op.getResult().getType());
    if (!convertedTy) {
      return mlir::failure();
    }
    mlir::MLIRContext *ctx = op.getContext();
    auto intListTy = TorchD::ListType::get(ctx, TorchD::IntType::get(ctx));
    mlir::Value permList = rewriter.create<TorchD::PrimListConstructOp>(loc, intListTy, dimConstants);
    rewriter.replaceOpWithNewOp<TorchD::AtenPermuteOp>(op, convertedTy, adaptor.getInput(), permList);
    return mlir::success();
  }
};
struct ConvertMfuseReduceMean : public mlir::OpConversionPattern<mlir::mfuse::ReduceMeanOp> {
using OpConversionPattern::OpConversionPattern;
mlir::LogicalResult matchAndRewrite(mlir::mfuse::ReduceMeanOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::Value input = adaptor.getInput();
auto resultType = getTypeConverter()->convertType(op.getResult().getType());
mlir::Value dimList = buildTorchIntListFromI64ArrayAttr(op.getDimensions(), op.getLoc(), rewriter);
bool keepdim = op.getKeepdim();
mlir::Value keepdimVal = rewriter.create<TorchD::ConstantBoolOp>(op.getLoc(), keepdim);
mlir::Type dtypeType = mlir::cast<mlir::RankedTensorType>(op.getResult().getType()).getElementType();
mlir::Value dtypeVal;
if (mlir::isa<mlir::NoneType, mlir::mfuse::NoneType, TorchD::NoneType>(dtypeType)) {
dtypeVal = rewriter.create<TorchD::ConstantNoneOp>(op.getLoc());
} else {
auto dtypeValOrFailure = buildTorchDtypeValue(dtypeType, op.getLoc(), rewriter);
if (mlir::failed(dtypeValOrFailure)) {
return rewriter.notifyMatchFailure(op, "unsupported dtype for torch scalar type");
}
dtypeVal = *dtypeValOrFailure;
}
rewriter.replaceOpWithNewOp<TorchD::AtenMeanDimOp>(op, resultType, input, dimList, keepdimVal, dtypeVal);
return mlir::success();
}
};
struct ConvertMfuseReduceSum : public mlir::OpConversionPattern<mlir::mfuse::ReduceSumOp> {
using OpConversionPattern::OpConversionPattern;
mlir::LogicalResult matchAndRewrite(mlir::mfuse::ReduceSumOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::Value input = adaptor.getInput();
auto resultType = getTypeConverter()->convertType(op.getResult().getType());
mlir::Value dimList = buildTorchIntListFromI64ArrayAttr(op.getDimensions(), op.getLoc(), rewriter);
bool keepdim = op.getKeepdim();
mlir::Value keepdimVal = rewriter.create<TorchD::ConstantBoolOp>(op.getLoc(), keepdim);
mlir::Type dtypeType = mlir::cast<mlir::RankedTensorType>(op.getResult().getType()).getElementType();
mlir::Value dtypeVal;
if (mlir::isa<mlir::NoneType, mlir::mfuse::NoneType, TorchD::NoneType>(dtypeType)) {
dtypeVal = rewriter.create<TorchD::ConstantNoneOp>(op.getLoc());
} else {
auto dtypeValOrFailure = buildTorchDtypeValue(dtypeType, op.getLoc(), rewriter);
if (mlir::failed(dtypeValOrFailure)) {
return rewriter.notifyMatchFailure(op, "unsupported dtype for torch scalar type");
}
dtypeVal = *dtypeValOrFailure;
}
rewriter.replaceOpWithNewOp<TorchD::AtenSumDimIntListOp>(op, resultType, input, dimList, keepdimVal, dtypeVal);
return mlir::success();
}
};
/// Lowers mfuse::ReshapeOp to TorchD::AtenReshapeOp.
///
/// The target shape is read from the ranked result type, and dynamic
/// dimensions are emitted as -1.
class ConvertMfuseReshape : public mlir::OpConversionPattern<mlir::mfuse::ReshapeOp> {
 public:
  using OpConversionPattern::OpConversionPattern;
  mlir::LogicalResult matchAndRewrite(mlir::mfuse::ReshapeOp op, OpAdaptor adaptor,
                                      mlir::ConversionPatternRewriter &rewriter) const override {
    auto rankedResultTy = mlir::dyn_cast<mlir::RankedTensorType>(op.getResult().getType());
    if (!rankedResultTy) {
      return rewriter.notifyMatchFailure(op, "result must be ranked tensor");
    }
    const mlir::Location loc = op.getLoc();
    auto targetShape = rankedResultTy.getShape();
    llvm::SmallVector<mlir::Value> dimValues;
    dimValues.reserve(targetShape.size());
    for (int64_t extent : targetShape) {
      int64_t torchExtent = extent;
      if (extent == mlir::ShapedType::kDynamic) {
        torchExtent = -1;
      }
      dimValues.push_back(rewriter.create<TorchD::ConstantIntOp>(loc, torchExtent));
    }

    mlir::Type torchResultType = getTypeConverter()->convertType(rankedResultTy);
    if (!torchResultType) {
      return mlir::failure();
    }
    mlir::MLIRContext *ctx = op.getContext();
    auto intListTy = TorchD::ListType::get(ctx, TorchD::IntType::get(ctx));
    auto shapeList = rewriter.create<TorchD::PrimListConstructOp>(loc, intListTy, dimValues);
    rewriter.replaceOpWithNewOp<TorchD::AtenReshapeOp>(op, torchResultType, adaptor.getInput(), shapeList);
    return mlir::success();
  }
};
/// Shared prologue for the binary-op lowerings.
///
/// Returns a tuple (lhs, rhs, isRhsScalar, resultType):
///  - lhs / rhs: the two converted operands;
///  - isRhsScalar: whether the RHS is a Torch scalar (`!torch.float` or `!torch.int`);
///  - resultType: the converted result type.
/// Fails if the op does not have exactly two operands or the result type
/// cannot be converted.
template <typename SourceOp>
static mlir::FailureOr<std::tuple<mlir::Value, mlir::Value, bool, mlir::Type>> convertBinaryOpCommon(
    SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter,
    const OpConversionPattern<SourceOp> *pattern) {
  constexpr size_t kBinaryOpNumOperands = 2;
  constexpr size_t kLhsIndex = 0;
  constexpr size_t kRhsIndex = 1;

  auto operands = adaptor.getOperands();
  if (operands.size() != kBinaryOpNumOperands) {
    return rewriter.notifyMatchFailure(op, "binary op must have 2 operands");
  }
  mlir::Type torchResultType = pattern->getTypeConverter()->convertType(op.getResult().getType());
  if (!torchResultType) {
    return mlir::failure();
  }
  mlir::Value lhs = operands[kLhsIndex];
  mlir::Value rhs = operands[kRhsIndex];
  const bool isRhsScalar = mlir::isa<TorchD::FloatType, TorchD::IntType>(rhs.getType());
  return std::make_tuple(lhs, rhs, isRhsScalar, torchResultType);
}
/// Generic lowering for a binary op that has no `alpha` operand.
///
/// Emits TargetScalarOp when the converted RHS is a Torch scalar, and
/// TargetTensorOp otherwise.
template <typename SourceOp, typename TargetTensorOp, typename TargetScalarOp>
class ConvertBinaryOpPattern : public OpConversionPattern<SourceOp> {
 public:
  using OpConversionPattern<SourceOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
                                ConversionPatternRewriter &rewriter) const override {
    auto prologue = convertBinaryOpCommon(op, adaptor, rewriter, this);
    if (mlir::failed(prologue)) {
      return mlir::failure();
    }
    const auto &[lhs, rhs, rhsIsScalar, resultType] = *prologue;
    if (rhsIsScalar) {
      rewriter.replaceOpWithNewOp<TargetScalarOp>(op, resultType, lhs, rhs);
    } else {
      rewriter.replaceOpWithNewOp<TargetTensorOp>(op, resultType, lhs, rhs);
    }
    return mlir::success();
  }
};
/// Generic lowering for binary ops whose Torch counterparts take an extra
/// `alpha` operand.
///
/// Alpha is always the integer constant 1. TargetScalarOp is emitted when the
/// converted RHS is a Torch scalar, and TargetTensorOp otherwise.
template <typename SourceOp, typename TargetTensorOp, typename TargetScalarOp>
class ConvertBinaryOpWithAlphaPattern : public OpConversionPattern<SourceOp> {
 public:
  using OpConversionPattern<SourceOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
                                ConversionPatternRewriter &rewriter) const override {
    auto prologue = convertBinaryOpCommon(op, adaptor, rewriter, this);
    if (mlir::failed(prologue)) {
      return mlir::failure();
    }
    const auto &[lhs, rhs, rhsIsScalar, resultType] = *prologue;
    constexpr int64_t kAlphaOne = 1;
    mlir::Value alpha = rewriter.create<TorchD::ConstantIntOp>(op->getLoc(), rewriter.getI64IntegerAttr(kAlphaOne));
    if (rhsIsScalar) {
      rewriter.replaceOpWithNewOp<TargetScalarOp>(op, resultType, lhs, rhs, alpha);
    } else {
      rewriter.replaceOpWithNewOp<TargetTensorOp>(op, resultType, lhs, rhs, alpha);
    }
    return mlir::success();
  }
};
/// Registers the hand-written Mfuse -> Torch patterns.
///
/// Only the matmul patterns receive `kernelGenerator`; all others are built
/// from the converter and context alone.
static void populateMfuseMetaToTorchCustomPatterns(TypeConverter &converter, RewritePatternSet &patterns,
                                                   llvm::StringRef kernelGenerator) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<ConvertMfuseBroadcastTo, ConvertMfuseCast, ConvertMfuseAclnnConv2D, ConvertMfuseAclnnConv2DWithBias,
               ConvertMfuseFull>(converter, ctx);
  patterns.add<ConvertMfuseMatmul, ConvertMfuseMatmulWithBias>(converter, ctx, kernelGenerator);
  patterns.add<ConvertMfusePermute, ConvertMfuseReduceMean, ConvertMfuseReduceSum, ConvertMfuseReshape>(converter,
                                                                                                        ctx);
}
/// Registers the table-driven binary-op lowerings.
///
/// Each mfuse binary op maps to a tensor-tensor and a tensor-scalar Torch op.
/// The scalar form is chosen when the converted RHS is a Torch float/int.
static void populateMfusePartBinaryOpToTorchTensorScalarPatterns(TypeConverter &converter,
                                                                 RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  // Torch counterparts that take an `alpha` operand (always 1).
  patterns.add<ConvertBinaryOpWithAlphaPattern<mfuse::AddOp, TorchD::AtenAddTensorOp, TorchD::AtenAddScalarOp>>(
      converter, ctx);
  patterns.add<ConvertBinaryOpWithAlphaPattern<mfuse::SubOp, TorchD::AtenSubTensorOp, TorchD::AtenSubScalarOp>>(
      converter, ctx);
  // Plain two-operand lowerings.
  patterns.add<ConvertBinaryOpPattern<mfuse::DivOp, TorchD::AtenDivTensorOp, TorchD::AtenDivScalarOp>>(converter, ctx);
  patterns.add<ConvertBinaryOpPattern<mfuse::EqOp, TorchD::AtenEqTensorOp, TorchD::AtenEqScalarOp>>(converter, ctx);
  patterns.add<ConvertBinaryOpPattern<mfuse::GeOp, TorchD::AtenGeTensorOp, TorchD::AtenGeScalarOp>>(converter, ctx);
  patterns.add<ConvertBinaryOpPattern<mfuse::GtOp, TorchD::AtenGtTensorOp, TorchD::AtenGtScalarOp>>(converter, ctx);
  patterns.add<ConvertBinaryOpPattern<mfuse::LeOp, TorchD::AtenLeTensorOp, TorchD::AtenLeScalarOp>>(converter, ctx);
  patterns.add<ConvertBinaryOpPattern<mfuse::LtOp, TorchD::AtenLtTensorOp, TorchD::AtenLtScalarOp>>(converter, ctx);
  patterns.add<ConvertBinaryOpPattern<mfuse::MulOp, TorchD::AtenMulTensorOp, TorchD::AtenMulScalarOp>>(converter, ctx);
  patterns.add<ConvertBinaryOpPattern<mfuse::NeOp, TorchD::AtenNeTensorOp, TorchD::AtenNeScalarOp>>(converter, ctx);
  patterns.add<ConvertBinaryOpPattern<mfuse::PowOp, TorchD::AtenPowTensorTensorOp, TorchD::AtenPowTensorScalarOp>>(
      converter, ctx);
}
}
/// Public entry point: adds every Mfuse -> Torch conversion pattern to `patterns`.
///
/// The hand-written patterns go first (the matmul ones depend on
/// `kernelGenerator`), followed by the table-driven binary-op patterns.
void populateMfuseMetaToTorchConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns,
                                                llvm::StringRef kernelGenerator) {
  TypeConverter &typeConverter = converter;
  populateMfuseMetaToTorchCustomPatterns(typeConverter, patterns, kernelGenerator);
  populateMfusePartBinaryOpToTorchTensorScalarPatterns(typeConverter, patterns);
}
}