* Copyright 2026 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mfusion/Conversion/TorchToMfuse/TorchAtenToMfuse.h"
#include <algorithm>
#include <limits>
#include <numeric>
#include <string>
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "mfusion/Dialect/Mfuse/IR/Mfuse.h"
#include "mfusion/Dialect/Mfuse/Support/OpConstants.h"
namespace mlir {
namespace TorchD = mlir::torch::Torch;
namespace {
using TorchScalarType = torch::torch_upstream::ScalarType;
/// Returns the integer value of a torch ScalarType enumerator.
static int64_t toTorchScalarTypeInt(TorchScalarType scalarType) {
  const auto rawValue = static_cast<int64_t>(scalarType);
  return rawValue;
}
/// Looks up the torch ScalarType code for the f32, f64, bf16 and f16 element types.
/// Any other type yields failure().
static FailureOr<int64_t> getTorchFloatScalarTypeInt(Type type) {
  TorchScalarType matched;
  if (isa<Float32Type>(type)) {
    matched = TorchScalarType::Float;
  } else if (isa<Float64Type>(type)) {
    matched = TorchScalarType::Double;
  } else if (type.isBF16()) {
    matched = TorchScalarType::BFloat16;
  } else if (type.isF16()) {
    matched = TorchScalarType::Half;
  } else {
    return failure();
  }
  return toTorchScalarTypeInt(matched);
}
/// Looks up the torch ScalarType code for the integer element types handled here:
/// signless i1 (Bool), signed 64/32/16/8-bit (Long/Int/Short/Char) and unsigned 8-bit (Byte).
/// Any other type yields failure().
static FailureOr<int64_t> getTorchIntegerScalarTypeInt(Type type) {
  TorchScalarType matched;
  if (type.isSignlessInteger(1)) {
    matched = TorchScalarType::Bool;
  } else if (type.isSignedInteger(64)) {
    matched = TorchScalarType::Long;
  } else if (type.isSignedInteger(32)) {
    matched = TorchScalarType::Int;
  } else if (type.isSignedInteger(16)) {
    matched = TorchScalarType::Short;
  } else if (type.isUnsignedInteger(8)) {
    matched = TorchScalarType::Byte;
  } else if (type.isSignedInteger(8)) {
    matched = TorchScalarType::Char;
  } else {
    return failure();
  }
  return toTorchScalarTypeInt(matched);
}
/// Looks up the torch ScalarType code for the torch quantized element types
/// (QUInt8, QInt8, QInt16, QInt32). Any other type yields failure().
static FailureOr<int64_t> getTorchQuantizedScalarTypeInt(Type type) {
  TorchScalarType matched;
  if (isa<TorchD::QUInt8Type>(type)) {
    matched = TorchScalarType::QUInt8;
  } else if (isa<TorchD::QInt8Type>(type)) {
    matched = TorchScalarType::QInt8;
  } else if (isa<TorchD::QInt16Type>(type)) {
    matched = TorchScalarType::QInt16;
  } else if (isa<TorchD::QInt32Type>(type)) {
    matched = TorchScalarType::QInt32;
  } else {
    return failure();
  }
  return toTorchScalarTypeInt(matched);
}
/// Looks up the torch ScalarType code for a complex type based on its element type:
/// f16 -> ComplexHalf, f32 -> ComplexFloat, f64 -> ComplexDouble.
/// Any other element type yields failure().
static FailureOr<int64_t> getTorchComplexScalarTypeInt(ComplexType complexType) {
  const Type component = complexType.getElementType();
  TorchScalarType matched;
  if (component.isF16()) {
    matched = TorchScalarType::ComplexHalf;
  } else if (component.isF32()) {
    matched = TorchScalarType::ComplexFloat;
  } else if (component.isF64()) {
    matched = TorchScalarType::ComplexDouble;
  } else {
    return failure();
  }
  return toTorchScalarTypeInt(matched);
}
/// Looks up the torch ScalarType code for the 8-bit float element types
/// (E5M2, E4M3FN, E5M2FNUZ, E4M3FNUZ). Any other type yields failure().
static FailureOr<int64_t> getTorchFloat8ScalarTypeInt(Type type) {
  TorchScalarType matched;
  if (isa<Float8E5M2Type>(type)) {
    matched = TorchScalarType::Float8_e5m2;
  } else if (isa<Float8E4M3FNType>(type)) {
    matched = TorchScalarType::Float8_e4m3fn;
  } else if (isa<Float8E5M2FNUZType>(type)) {
    matched = TorchScalarType::Float8_e5m2fnuz;
  } else if (isa<Float8E4M3FNUZType>(type)) {
    matched = TorchScalarType::Float8_e4m3fnuz;
  } else {
    return failure();
  }
  return toTorchScalarTypeInt(matched);
}
}
/// Returns the integer torch ScalarType code for `type`.
///
/// The per-family lookups are tried in this order: float, integer, quantized,
/// complex, float8. A ComplexType is decided solely by the complex lookup.
/// Returns failure() when no family recognises the type.
FailureOr<int64_t> getTorchScalarTypeInt(Type type) {
  using Classifier = FailureOr<int64_t> (*)(Type);
  const Classifier leadingClassifiers[] = {&getTorchFloatScalarTypeInt, &getTorchIntegerScalarTypeInt,
                                           &getTorchQuantizedScalarTypeInt};
  for (Classifier classify : leadingClassifiers) {
    FailureOr<int64_t> code = classify(type);
    if (succeeded(code)) {
      return code;
    }
  }
  if (auto complexType = dyn_cast<ComplexType>(type)) {
    return getTorchComplexScalarTypeInt(complexType);
  }
  // Last family: its result (success or failure) is the overall result.
  return getTorchFloat8ScalarTypeInt(type);
}
namespace {
/// Collects the elements of a torch list-construct value into `out` when all of
/// them are constant ints.
///
/// Returns false if `listVal` is not a list construct (leaving `out` untouched) or
/// if some element is not a constant int (in which case `out` holds the constants
/// that preceded the offending element).
bool isConstantListInt(Value listVal, llvm::SmallVectorImpl<int64_t> &out) {
  llvm::SmallVector<Value, 4> listElements;
  const bool isListConstruct = TorchD::getListConstructElements(listVal, listElements);
  if (!isListConstruct) {
    return false;
  }
  out.clear();
  out.reserve(listElements.size());
  return std::all_of(listElements.begin(), listElements.end(), [&out](Value element) {
    int64_t constant = 0;
    if (!matchPattern(element, TorchD::m_TorchConstantInt(&constant))) {
      return false;
    }
    out.push_back(constant);
    return true;
  });
}
/// Resolves the `dim` operand of a reduction into a list of non-negative axes.
///
/// A None operand or an empty list selects every axis `0 .. inputRank-1`.
/// Otherwise each entry must be a constant int; it is normalised with
/// toPositiveDim, range-checked, and rejected if it repeats an earlier entry.
/// The entries keep the order in which they were written.
/// On any violation a match failure is reported on `op` and failure() is returned.
template <typename OpTy>
FailureOr<llvm::SmallVector<int64_t, 4>> getReductionDims(OpTy op, Value dimsValue, int64_t inputRank,
                                                          ConversionPatternRewriter &rewriter) {
  llvm::SmallVector<int64_t, 4> axes;
  llvm::SmallVector<Value, 4> axisValues;
  const bool dimIsNone = isa<TorchD::NoneType>(dimsValue.getType());
  if (!dimIsNone && !TorchD::getListConstructElements(dimsValue, axisValues)) {
    (void)rewriter.notifyMatchFailure(op, "dim must come from list construct");
    return failure();
  }
  // `axisValues` is empty both for None and for an explicit empty list: reduce all axes.
  if (axisValues.empty()) {
    axes.resize(inputRank);
    std::iota(axes.begin(), axes.end(), 0);
    return axes;
  }
  axes.reserve(axisValues.size());
  for (Value axisValue : axisValues) {
    int64_t axis = 0;
    if (!matchPattern(axisValue, TorchD::m_TorchConstantInt(&axis))) {
      (void)rewriter.notifyMatchFailure(op, "dim list must be constant ints");
      return failure();
    }
    axis = TorchD::toPositiveDim(axis, inputRank);
    if (!TorchD::isValidDim(axis, inputRank)) {
      (void)rewriter.notifyMatchFailure(op, "dim out of range");
      return failure();
    }
    const bool alreadyListed = std::find(axes.begin(), axes.end(), axis) != axes.end();
    if (alreadyListed) {
      (void)rewriter.notifyMatchFailure(op, "duplicate reduction dims are not supported");
      return failure();
    }
    axes.push_back(axis);
  }
  return axes;
}
/// Multiplies the sizes of the axes in `dims` of `inputType`.
///
/// Reports a match failure on `op` and returns failure() when a listed axis is
/// dynamic, has a size <= 0, or when the running product would exceed int64.
template <typename OpTy>
FailureOr<int64_t> getStaticReductionSize(OpTy op, ArrayRef<int64_t> dims, RankedTensorType inputType,
                                          ConversionPatternRewriter &rewriter) {
  const auto reject = [&](const char *reason) -> FailureOr<int64_t> {
    (void)rewriter.notifyMatchFailure(op, reason);
    return failure();
  };
  constexpr int64_t kMaxSize = std::numeric_limits<int64_t>::max();
  int64_t product = 1;
  for (int64_t axis : dims) {
    const int64_t extent = inputType.getDimSize(axis);
    if (extent == ShapedType::kDynamic) {
      return reject("reduced dimensions must be statically known");
    }
    if (extent <= 0) {
      return reject("reduced dimensions must be positive");
    }
    // Guard the multiplication below against signed overflow.
    if (product > kMaxSize / extent) {
      return reject("reduction size overflows int64");
    }
    product *= extent;
  }
  return product;
}
}
/// Generic lowering of a two-input op: builds `TargetOp` from the first two
/// converted operands of `SourceOp` and the converted result type.
/// Any further operands of the source op are ignored.
template <typename SourceOp, typename TargetOp>
struct ConvertBinaryOpPattern : public OpConversionPattern<SourceOp> {
  using OpConversionPattern<SourceOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
                                ConversionPatternRewriter &rewriter) const override {
    auto convertedOperands = adaptor.getOperands();
    const bool hasBothInputs = convertedOperands.size() >= 2;
    if (!hasBothInputs) {
      return rewriter.notifyMatchFailure(op, "binary op requires at least 2 operands");
    }
    Type convertedResultType = this->getTypeConverter()->convertType(op.getType());
    if (!convertedResultType) {
      return rewriter.notifyMatchFailure(op, "result type conversion failed");
    }
    auto lowered =
      rewriter.create<TargetOp>(op.getLoc(), convertedResultType, convertedOperands[0], convertedOperands[1]);
    rewriter.replaceOp(op, lowered.getResult());
    return success();
  }
};
/// Lowers a shape-only torch op (reshape, view, squeeze, unsqueeze, ...) to
/// mfuse.reshape, taking the target shape from the converted result type.
/// The lowering is refused when the result type has more than one dynamic dimension.
template <typename SourceOp>
struct ConvertReshapeLikeOp : public OpConversionPattern<SourceOp> {
  using OpConversionPattern<SourceOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
                                ConversionPatternRewriter &rewriter) const override {
    // convertType() returns a null Type on failure; dyn_cast_if_present tolerates
    // that, whereas plain dyn_cast must not be handed a null value.
    auto outType = mlir::dyn_cast_if_present<RankedTensorType>(this->getTypeConverter()->convertType(op.getType()));
    if (!outType) {
      return rewriter.notifyMatchFailure(op, "result must be ranked tensor");
    }
    const auto outShape = outType.getShape();
    const auto dynamicDimCount = std::count(outShape.begin(), outShape.end(), ShapedType::kDynamic);
    if (dynamicDimCount > 1) {
      return rewriter.notifyMatchFailure(op, "result has more than one dynamic dimension");
    }
    rewriter.replaceOpWithNewOp<mlir::mfuse::ReshapeOp>(op, outType, adaptor.getSelf());
    return success();
  }
};
/// Lowers aten.broadcast_to to mfuse.broadcast_to.
/// The size operand must be a list construct of constant ints; it is only
/// validated here, the target shape comes from the converted result type.
struct ConvertAtenBroadcastTo : public OpConversionPattern<TorchD::AtenBroadcastToOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(TorchD::AtenBroadcastToOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const override {
    llvm::SmallVector<int64_t, 4> sizeInts;
    if (!isConstantListInt(op.getSize(), sizeInts)) {
      return rewriter.notifyMatchFailure(op, "size must be a list construct of constant ints for mfuse.broadcast_to");
    }
    // convertType() returns a null Type on failure; never build an op with it.
    Type resultType = getTypeConverter()->convertType(op.getType());
    if (!resultType) {
      return rewriter.notifyMatchFailure(op, "result type conversion failed");
    }
    rewriter.replaceOpWithNewOp<mlir::mfuse::BroadcastToOp>(op, resultType, adaptor.getSelf());
    return success();
  }
};
// Convolution hyper-parameter values that isNarrowDefaultMetaConv2d compares against
// to recognise a convolution that can be emitted without explicit attributes.
// Each stride/padding/dilation/output-padding value is compared per spatial axis.
constexpr int64_t kConvGroupsDefault = 1;
constexpr int64_t kConvStrideDefault = 1;
constexpr int64_t kConvPaddingDefault = 0;
constexpr int64_t kConvDilationDefault = 1;
constexpr int64_t kConvOutputPaddingDefault = 0;
/// Reads a two-element constant-int list (one value per spatial axis) from
/// `listVal` into `out0` / `out1`. Reports a match failure on `op` when the value
/// is not a list construct of exactly mfuse::kDim2 elements or an element is not a
/// constant int.
static LogicalResult extractConvParamList2(TorchD::AtenConvolutionOp op, ConversionPatternRewriter &rewriter,
                                           Value listVal, int64_t &out0, int64_t &out1) {
  llvm::SmallVector<Value, 2> entries;
  const bool isList = TorchD::getListConstructElements(listVal, entries);
  if (!isList || entries.size() != mfuse::kDim2) {
    return rewriter.notifyMatchFailure(op, "conv spatial param list must be a 2-element constant int list");
  }
  // The entries are matched in order; the second is only inspected when the first is constant.
  if (!matchPattern(entries[0], TorchD::m_TorchConstantInt(&out0))) {
    return rewriter.notifyMatchFailure(op, "conv spatial param list must be constant ints");
  }
  if (!matchPattern(entries[1], TorchD::m_TorchConstantInt(&out1))) {
    return rewriter.notifyMatchFailure(op, "conv spatial param list must be constant ints");
  }
  return success();
}
/// Extracts the constant hyper-parameters of an aten.convolution op.
///
/// Outputs: `groups` (must be >= 1), `transposed`, and per-spatial-axis pairs for
/// stride (`s0`,`s1`), padding (`p0`,`p1`), dilation (`d0`,`d1`) and output padding
/// (`o0`,`o1`). Fails (after reporting on `op`) as soon as one of them is not a
/// constant of the expected form.
static LogicalResult extractAtenConvolutionHyperParams(TorchD::AtenConvolutionOp op,
                                                       ConversionPatternRewriter &rewriter, int64_t &groups,
                                                       bool &transposed, int64_t &s0, int64_t &s1, int64_t &p0,
                                                       int64_t &p1, int64_t &d0, int64_t &d1, int64_t &o0,
                                                       int64_t &o1) {
  groups = 0;
  const bool groupsIsConstant = matchPattern(op.getGroups(), TorchD::m_TorchConstantInt(&groups));
  if (!groupsIsConstant || groups < 1) {
    return rewriter.notifyMatchFailure(op, "groups must be a constant positive int");
  }

  transposed = false;
  if (!matchPattern(op.getTransposed(), TorchD::m_TorchConstantBool(&transposed))) {
    return rewriter.notifyMatchFailure(op, "transposed must be a constant bool");
  }

  // Reset every spatial output before filling them pair by pair.
  for (int64_t *slot : {&s0, &s1, &p0, &p1, &d0, &d1, &o0, &o1}) {
    *slot = 0;
  }

  struct SpatialPair {
    Value list;
    int64_t *first;
    int64_t *second;
  };
  const SpatialPair pairs[] = {
    {op.getStride(), &s0, &s1},
    {op.getPadding(), &p0, &p1},
    {op.getDilation(), &d0, &d1},
    {op.getOutputPadding(), &o0, &o1},
  };
  for (const SpatialPair &pair : pairs) {
    if (failed(extractConvParamList2(op, rewriter, pair.list, *pair.first, *pair.second))) {
      return failure();
    }
  }
  return success();
}
/// True when the convolution has no bias, is not transposed, and every
/// hyper-parameter equals its kConv*Default value.
static bool isNarrowDefaultMetaConv2d(int64_t groups, bool transposed, int64_t s0, int64_t s1, int64_t p0, int64_t p1,
                                      int64_t d0, int64_t d1, int64_t o0, int64_t o1, bool noBias) {
  if (!noBias || transposed || groups != kConvGroupsDefault) {
    return false;
  }
  const bool defaultStride = (s0 == kConvStrideDefault) && (s1 == kConvStrideDefault);
  const bool defaultPadding = (p0 == kConvPaddingDefault) && (p1 == kConvPaddingDefault);
  const bool defaultDilation = (d0 == kConvDilationDefault) && (d1 == kConvDilationDefault);
  const bool defaultOutputPadding = (o0 == kConvOutputPaddingDefault) && (o1 == kConvOutputPaddingDefault);
  return defaultStride && defaultPadding && defaultDilation && defaultOutputPadding;
}
/// Replaces `op` with the matching mfuse convolution op.
///
/// Three forms are emitted:
///  * all-default, bias-free convolution -> AclnnConv2DOp built from input and weight only;
///  * bias-free convolution              -> AclnnConv2DOp with explicit attributes;
///  * convolution with bias              -> AclnnConv2DWithBiasOp with explicit attributes.
/// Always returns success().
template <typename ConvolutionAdaptor>
static LogicalResult replaceAtenConvolutionWithAclnn(TorchD::AtenConvolutionOp op, ConvolutionAdaptor adaptor,
                                                     ConversionPatternRewriter &rewriter, RankedTensorType resultType,
                                                     int64_t groups, bool transposed, int64_t s0, int64_t s1,
                                                     int64_t p0, int64_t p1, int64_t d0, int64_t d1, int64_t o0,
                                                     int64_t o1, bool noBias) {
  Value activation = adaptor.getInput();
  Value kernel = adaptor.getWeight();

  const bool allDefaults = isNarrowDefaultMetaConv2d(groups, transposed, s0, s1, p0, p1, d0, d1, o0, o1, noBias);
  if (allDefaults) {
    rewriter.replaceOpWithNewOp<mlir::mfuse::AclnnConv2DOp>(op, resultType, activation, kernel);
    return success();
  }

  auto stride = rewriter.getI64ArrayAttr({s0, s1});
  auto padding = rewriter.getI64ArrayAttr({p0, p1});
  auto dilation = rewriter.getI64ArrayAttr({d0, d1});
  auto outputPadding = rewriter.getI64ArrayAttr({o0, o1});
  auto isTransposed = rewriter.getBoolAttr(transposed);
  auto groupCount = rewriter.getI64IntegerAttr(groups);

  if (!noBias) {
    Value bias = adaptor.getBias();
    rewriter.replaceOpWithNewOp<mlir::mfuse::AclnnConv2DWithBiasOp>(op, resultType, activation, kernel, bias, stride,
                                                                    padding, dilation, isTransposed, outputPadding,
                                                                    groupCount);
  } else {
    rewriter.replaceOpWithNewOp<mlir::mfuse::AclnnConv2DOp>(op, resultType, activation, kernel, stride, padding,
                                                            dilation, isTransposed, outputPadding, groupCount);
  }
  return success();
}
/// Lowers aten.convolution with rank-4 input and weight to the mfuse Aclnn
/// convolution ops. All hyper-parameters must be compile-time constants.
struct ConvertAtenConvolution : public OpConversionPattern<TorchD::AtenConvolutionOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(TorchD::AtenConvolutionOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const override {
    int64_t groups = 0;
    bool transposed = false;
    int64_t s0 = 0;
    int64_t s1 = 0;
    int64_t p0 = 0;
    int64_t p1 = 0;
    int64_t d0 = 0;
    int64_t d1 = 0;
    int64_t o0 = 0;
    int64_t o1 = 0;
    // The helper reports its own match-failure reason.
    if (failed(extractAtenConvolutionHyperParams(op, rewriter, groups, transposed, s0, s1, p0, p1, d0, d1, o0, o1))) {
      return failure();
    }
    // convertType() returns a null Type on failure; dyn_cast_if_present tolerates
    // that, whereas plain dyn_cast must not be handed a null value.
    auto resultType = dyn_cast_if_present<RankedTensorType>(getTypeConverter()->convertType(op.getResult().getType()));
    if (!resultType) {
      return rewriter.notifyMatchFailure(op, "result type conversion failed");
    }
    auto inTy = dyn_cast<RankedTensorType>(adaptor.getInput().getType());
    auto wTy = dyn_cast<RankedTensorType>(adaptor.getWeight().getType());
    if (!inTy || inTy.getRank() != 4 || !wTy || wTy.getRank() != 4) {
      return rewriter.notifyMatchFailure(op, "only rank-4 input and weight are supported for mfuse conv lowering");
    }
    // The bias operand is optional; its absence is encoded as a torch None value.
    const bool noBias = isa<TorchD::NoneType>(op.getBias().getType());
    return replaceAtenConvolutionWithAclnn(op, adaptor, rewriter, resultType, groups, transposed, s0, s1, p0, p1, d0,
                                           d1, o0, o1, noBias);
  }
};
/// Lowers aten.expand to mfuse.broadcast_to.
/// The size operand must be a list construct of constant ints; it is only
/// validated here, the target shape comes from the converted result type.
struct ConvertAtenExpand : public OpConversionPattern<TorchD::AtenExpandOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(TorchD::AtenExpandOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const override {
    llvm::SmallVector<int64_t, 4> sizeInts;
    if (!isConstantListInt(op.getSize(), sizeInts)) {
      return rewriter.notifyMatchFailure(op, "size must be a list construct of constant ints for mfuse.broadcast_to");
    }
    // convertType() returns a null Type on failure; never build an op with it.
    Type resultType = getTypeConverter()->convertType(op.getType());
    if (!resultType) {
      return rewriter.notifyMatchFailure(op, "result type conversion failed");
    }
    rewriter.replaceOpWithNewOp<mlir::mfuse::BroadcastToOp>(op, resultType, adaptor.getSelf());
    return success();
  }
};
/// Lowers aten.mean.dim to mfuse.reduce_mean.
///
/// Requirements: ranked input, constant `dim` list (or None / empty for "all axes"),
/// constant `keepdim`, floating-point result element type, and statically known,
/// positive sizes for every reduced axis.
struct ConvertAtenMeanDim : public OpConversionPattern<TorchD::AtenMeanDimOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(TorchD::AtenMeanDimOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const override {
    Value self = adaptor.getSelf();
    auto inType = dyn_cast<RankedTensorType>(self.getType());
    if (!inType) {
      return rewriter.notifyMatchFailure(op, "input must be ranked tensor");
    }
    // The helper reports its own match-failure reason.
    auto dimsOr = getReductionDims(op, op.getDim(), inType.getRank(), rewriter);
    if (failed(dimsOr)) {
      return failure();
    }
    bool keepdimValue = false;
    if (!matchPattern(op.getKeepdim(), TorchD::m_TorchConstantBool(&keepdimValue))) {
      return rewriter.notifyMatchFailure(op, "keepdim must be constant bool");
    }
    // convertType() returns a null Type on failure; dyn_cast_if_present tolerates
    // that, whereas plain dyn_cast must not be handed a null value.
    auto outType = dyn_cast_if_present<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
    if (!outType) {
      return rewriter.notifyMatchFailure(op, "result type conversion failed");
    }
    if (!isa<FloatType>(outType.getElementType())) {
      return rewriter.notifyMatchFailure(op, "result element type must be floating point");
    }
    // Only used as a validity check: the reduced extent itself is not needed here.
    if (failed(getStaticReductionSize(op, *dimsOr, inType, rewriter))) {
      return failure();
    }
    auto dimsAttr = rewriter.getI64ArrayAttr(*dimsOr);
    auto keepdimAttr = rewriter.getBoolAttr(keepdimValue);
    rewriter.replaceOpWithNewOp<mfuse::ReduceMeanOp>(op, outType, self, dimsAttr, keepdimAttr);
    return success();
  }
};
/// Lowers aten.var_mean.correction to mfuse AclnnVarMeanOp.
///
/// Requirements: ranked input, constant `dim` list (or None / empty for "all axes"),
/// constant int `correction` and constant `keepdim`. The reduction axes are passed
/// to the target op in ascending order.
struct ConvertAtenVarMeanCorrection : public OpConversionPattern<TorchD::AtenVarMeanCorrectionOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(TorchD::AtenVarMeanCorrectionOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const override {
    Value self = adaptor.getSelf();
    auto inType = dyn_cast<RankedTensorType>(self.getType());
    if (!inType) {
      return rewriter.notifyMatchFailure(op, "input must be ranked tensor");
    }
    // Shared helper: normalises negative axes, range-checks them and rejects
    // duplicates, reporting its own match-failure reason.
    auto dimsOr = getReductionDims(op, op.getDim(), inType.getRank(), rewriter);
    if (failed(dimsOr)) {
      return failure();
    }
    llvm::SmallVector<int64_t, 4> dims = *dimsOr;
    std::sort(dims.begin(), dims.end());

    int64_t correction = 0;
    if (!matchPattern(op.getCorrection(), TorchD::m_TorchConstantInt(&correction))) {
      return rewriter.notifyMatchFailure(op, "correction must be constant int");
    }
    bool keepdimValue = false;
    if (!matchPattern(op.getKeepdim(), TorchD::m_TorchConstantBool(&keepdimValue))) {
      return rewriter.notifyMatchFailure(op, "keepdim must be constant bool");
    }
    // Result 0 is the variance, result 1 the mean.
    auto varType = getTypeConverter()->convertType(op.getResult(0).getType());
    auto meanType = getTypeConverter()->convertType(op.getResult(1).getType());
    if (!varType || !meanType) {
      return rewriter.notifyMatchFailure(op, "result type conversion failed");
    }
    auto dimsAttr = rewriter.getI64ArrayAttr(dims);
    auto correctionAttr = rewriter.getI64IntegerAttr(correction);
    auto keepdimAttr = rewriter.getBoolAttr(keepdimValue);
    rewriter.replaceOpWithNewOp<mfuse::AclnnVarMeanOp>(op, varType, meanType, self, dimsAttr, correctionAttr,
                                                       keepdimAttr);
    return success();
  }
};
/// Lowers aten.permute to mfuse.permute.
///
/// `dims` must be a constant int list whose length equals the input rank and
/// which, after normalising negative entries, names every axis exactly once.
struct ConvertAtenPermute : public OpConversionPattern<TorchD::AtenPermuteOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(TorchD::AtenPermuteOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const override {
    Value self = adaptor.getSelf();
    auto inType = mlir::dyn_cast<RankedTensorType>(self.getType());
    if (!inType) {
      return rewriter.notifyMatchFailure(op, "input type must be ranked tensor");
    }
    const int64_t inputRank = inType.getRank();
    llvm::SmallVector<int64_t, 4> permValues;
    if (!isConstantListInt(op.getDims(), permValues)) {
      return rewriter.notifyMatchFailure(op, "dims must be a list construct of constant ints");
    }
    if (permValues.size() != static_cast<size_t>(inputRank)) {
      return rewriter.notifyMatchFailure(op, "dims size must match input rank");
    }
    // Track which axes were already named so that repeated entries are rejected.
    llvm::SmallVector<bool, 4> seen(static_cast<size_t>(inputRank), false);
    for (int64_t &dim : permValues) {
      dim = TorchD::toPositiveDim(dim, inputRank);
      if (!TorchD::isValidDim(dim, inputRank)) {
        return rewriter.notifyMatchFailure(op, "dim out of range");
      }
      if (seen[dim]) {
        return rewriter.notifyMatchFailure(op, "dims must not contain repeated entries");
      }
      seen[dim] = true;
    }
    // convertType() returns a null Type on failure; report it instead of asserting in cast<>.
    auto outType =
      dyn_cast_if_present<RankedTensorType>(getTypeConverter()->convertType(op->getResult(0).getType()));
    if (!outType) {
      return rewriter.notifyMatchFailure(op, "result type conversion failed");
    }
    auto permAttr = rewriter.getI64ArrayAttr(permValues);
    rewriter.replaceOpWithNewOp<mlir::mfuse::PermuteOp>(op, outType, self, permAttr);
    return success();
  }
};
/// Lowers aten.full to mfuse.full.
///
/// `size` must be a constant int list (validated only; the shape comes from the
/// converted result type). dtype / layout / device / pin_memory are forwarded as
/// attributes when they are constants; a None or non-constant operand leaves the
/// corresponding attribute unset.
struct ConvertAtenFull : public OpConversionPattern<TorchD::AtenFullOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(TorchD::AtenFullOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const override {
    llvm::SmallVector<int64_t, 4> sizeInts;
    if (!isConstantListInt(op.getSize(), sizeInts)) {
      return rewriter.notifyMatchFailure(op, "size must be a list construct of constant ints");
    }
    // convertType() returns a null Type on failure; dyn_cast_if_present tolerates
    // that, whereas plain dyn_cast must not be handed a null value.
    auto outType = dyn_cast_if_present<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
    if (!outType) {
      return rewriter.notifyMatchFailure(op, "result type conversion failed");
    }

    mlir::IntegerAttr dtypeAttr;
    int64_t dtypeInt = 0;
    if (!isa<TorchD::NoneType>(op.getDtype().getType()) &&
        matchPattern(op.getDtype(), TorchD::m_TorchConstantInt(&dtypeInt))) {
      dtypeAttr = rewriter.getI64IntegerAttr(dtypeInt);
    }

    mlir::IntegerAttr layoutAttr;
    int64_t layoutInt = 0;
    if (!isa<TorchD::NoneType>(op.getLayout().getType()) &&
        matchPattern(op.getLayout(), TorchD::m_TorchConstantInt(&layoutInt))) {
      layoutAttr = rewriter.getI64IntegerAttr(layoutInt);
    }

    mlir::StringAttr deviceAttr;
    if (!isa<TorchD::NoneType>(op.getDevice().getType())) {
      if (auto constOp = op.getDevice().getDefiningOp<TorchD::ConstantDeviceOp>()) {
        deviceAttr = constOp.getValueAttr();
      }
    }

    mlir::BoolAttr pinMemoryAttr;
    bool pinMemoryVal = false;
    if (!isa<TorchD::NoneType>(op.getPinMemory().getType()) &&
        matchPattern(op.getPinMemory(), TorchD::m_TorchConstantBool(&pinMemoryVal))) {
      pinMemoryAttr = rewriter.getBoolAttr(pinMemoryVal);
    }

    rewriter.replaceOpWithNewOp<mlir::mfuse::FullOp>(op, outType, adaptor.getFillValue(), dtypeAttr, layoutAttr,
                                                     deviceAttr, pinMemoryAttr);
    return success();
  }
};
/// Lowers a splat torch.vtensor.literal to mfuse.full fed by a rank-0
/// mfuse.constant holding the splat value. Non-splat literals are not handled.
struct ConvertValueTensorLiteral : public OpConversionPattern<TorchD::ValueTensorLiteralOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(TorchD::ValueTensorLiteralOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const override {
    auto denseAttr = dyn_cast<DenseElementsAttr>(op.getValueAttr());
    if (!denseAttr || !denseAttr.isSplat()) {
      return rewriter.notifyMatchFailure(op, "only splat vtensor literals are supported");
    }
    // convertType() returns a null Type on failure; dyn_cast_if_present tolerates
    // that, whereas plain dyn_cast must not be handed a null value.
    auto outType = dyn_cast_if_present<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
    if (!outType) {
      return rewriter.notifyMatchFailure(op, "result type conversion failed");
    }
    // Finish every check that can still fail before creating any new op, so a
    // match failure never leaves a stray constant behind.
    auto dtypeIntOrFailure = getTorchScalarTypeInt(outType.getElementType());
    if (failed(dtypeIntOrFailure)) {
      return rewriter.notifyMatchFailure(op, "unsupported dtype for splat vtensor literal");
    }

    auto fillType = RankedTensorType::get({}, outType.getElementType());
    auto fillAttr = DenseElementsAttr::get(fillType, denseAttr.getSplatValue<mlir::Attribute>());
    auto fillValue = rewriter.create<mfuse::ConstantOp>(op.getLoc(), fillType, fillAttr);
    auto dtypeAttr = rewriter.getI64IntegerAttr(*dtypeIntOrFailure);
    // Layout, device and pin_memory are deliberately left unset (null attributes).
    IntegerAttr layoutAttr;
    StringAttr deviceAttr;
    BoolAttr pinMemoryAttr;
    rewriter.replaceOpWithNewOp<mfuse::FullOp>(op, outType, fillValue, dtypeAttr, layoutAttr, deviceAttr,
                                               pinMemoryAttr);
    return success();
  }
};
/// Lowers aten.sum.dim_IntList to mfuse.reduce_sum.
///
/// Requirements: ranked input, constant `dim` list (or None / empty for "all axes")
/// whose entries are in range and unique, and constant `keepdim`.
struct ConvertAtenSumDimIntList : public OpConversionPattern<TorchD::AtenSumDimIntListOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(TorchD::AtenSumDimIntListOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const override {
    Value self = adaptor.getSelf();
    // Report unsupported inputs as a match failure rather than asserting in cast<>.
    auto inType = dyn_cast<RankedTensorType>(self.getType());
    if (!inType) {
      return rewriter.notifyMatchFailure(op, "input must be ranked tensor");
    }
    // Shared helper: normalises negative axes, range-checks them and rejects
    // duplicates, reporting its own match-failure reason.
    auto dimsOr = getReductionDims(op, op.getDim(), inType.getRank(), rewriter);
    if (failed(dimsOr)) {
      return failure();
    }
    bool keepdimValue = false;
    if (!matchPattern(op.getKeepdim(), TorchD::m_TorchConstantBool(&keepdimValue))) {
      return rewriter.notifyMatchFailure(op, "keepdim must be constant bool");
    }
    // convertType() returns a null Type on failure; dyn_cast_if_present tolerates that.
    auto outType = dyn_cast_if_present<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
    if (!outType) {
      return rewriter.notifyMatchFailure(op, "result type conversion failed");
    }
    auto dimsAttr = rewriter.getI64ArrayAttr(*dimsOr);
    auto keepdimAttr = rewriter.getBoolAttr(keepdimValue);
    rewriter.replaceOpWithNewOp<mlir::mfuse::ReduceSumOp>(op, outType, self, dimsAttr, keepdimAttr);
    return success();
  }
};
/// Lowers aten.transpose.int to mfuse.permute with the identity permutation in
/// which the two (normalised) constant axes `dim0` and `dim1` are swapped.
struct ConvertAtenTransposeInt : public OpConversionPattern<TorchD::AtenTransposeIntOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(TorchD::AtenTransposeIntOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const override {
    Value self = adaptor.getSelf();
    int64_t dim0 = 0;
    int64_t dim1 = 0;
    if (!matchPattern(op.getDim0(), TorchD::m_TorchConstantInt(&dim0))) {
      return rewriter.notifyMatchFailure(op, "dim0 must be constant");
    }
    if (!matchPattern(op.getDim1(), TorchD::m_TorchConstantInt(&dim1))) {
      return rewriter.notifyMatchFailure(op, "dim1 must be constant");
    }
    // Report unsupported types as a match failure rather than asserting in cast<>.
    auto inType = dyn_cast<RankedTensorType>(self.getType());
    if (!inType) {
      return rewriter.notifyMatchFailure(op, "input must be ranked tensor");
    }
    const int64_t inputRank = inType.getRank();
    // convertType() returns a null Type on failure; dyn_cast_if_present tolerates that.
    auto outType =
      dyn_cast_if_present<RankedTensorType>(getTypeConverter()->convertType(op->getResult(0).getType()));
    if (!outType) {
      return rewriter.notifyMatchFailure(op, "result type conversion failed");
    }
    dim0 = TorchD::toPositiveDim(dim0, inputRank);
    if (!TorchD::isValidDim(dim0, inputRank)) {
      return rewriter.notifyMatchFailure(op, "dim0 out of range");
    }
    dim1 = TorchD::toPositiveDim(dim1, inputRank);
    if (!TorchD::isValidDim(dim1, inputRank)) {
      return rewriter.notifyMatchFailure(op, "dim1 out of range");
    }
    llvm::SmallVector<int64_t, 4> permValues(inputRank);
    std::iota(std::begin(permValues), std::end(permValues), 0);
    std::swap(permValues[dim0], permValues[dim1]);
    auto permAttr = rewriter.getI64ArrayAttr(permValues);
    rewriter.replaceOpWithNewOp<mlir::mfuse::PermuteOp>(op, outType, self, permAttr);
    return success();
  }
};
/// Lowers aten.slice.Tensor to mfuse.slice.
///
/// `dim`, `start`, `end` and `step` must be constant ints, `step` must be positive
/// and the sliced axis must have a static size. Negative `start` / `end` are
/// normalised against that size, then clamped so that 0 <= start <= end <= size.
///
/// NOTE(review): none of the populate* functions visible in this file registers
/// this pattern - confirm whether that is intentional.
struct ConvertAtenSliceTensor : public OpConversionPattern<TorchD::AtenSliceTensorOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(TorchD::AtenSliceTensorOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const override {
    Value self = adaptor.getSelf();
    // Report unsupported types as a match failure rather than asserting in cast<>.
    auto inType = dyn_cast<RankedTensorType>(self.getType());
    if (!inType) {
      return rewriter.notifyMatchFailure(op, "input must be ranked tensor");
    }
    // convertType() returns a null Type on failure; dyn_cast_if_present tolerates that.
    auto outType = dyn_cast_if_present<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
    if (!outType) {
      return rewriter.notifyMatchFailure(op, "result type conversion failed");
    }
    int64_t dim = 0;
    int64_t start = 0;
    int64_t end = 0;
    int64_t step = 0;
    if (!matchPattern(op.getDim(), TorchD::m_TorchConstantInt(&dim))) {
      return rewriter.notifyMatchFailure(op, "dim must be constant");
    }
    if (!matchPattern(op.getStart(), TorchD::m_TorchConstantInt(&start))) {
      return rewriter.notifyMatchFailure(op, "start must be constant");
    }
    if (!matchPattern(op.getEnd(), TorchD::m_TorchConstantInt(&end))) {
      return rewriter.notifyMatchFailure(op, "end must be constant");
    }
    if (!matchPattern(op.getStep(), TorchD::m_TorchConstantInt(&step))) {
      return rewriter.notifyMatchFailure(op, "step must be constant");
    }
    if (step < 1) {
      return rewriter.notifyMatchFailure(op, "step must be positive");
    }
    const int64_t inputRank = inType.getRank();
    dim = TorchD::toPositiveDim(dim, inputRank);
    if (!TorchD::isValidDim(dim, inputRank)) {
      return rewriter.notifyMatchFailure(op, "dim out of range");
    }
    const int64_t dimSize = inType.getDimSize(dim);
    // The clamping below needs a real extent; the dynamic sentinel would corrupt it.
    if (dimSize == ShapedType::kDynamic) {
      return rewriter.notifyMatchFailure(op, "sliced dimension must be statically known");
    }
    start = std::clamp<int64_t>(TorchD::toPositiveDim(start, dimSize), 0, dimSize);
    end = std::clamp<int64_t>(TorchD::toPositiveDim(end, dimSize), start, dimSize);
    rewriter.replaceOpWithNewOp<mlir::mfuse::SliceOp>(
      op, outType, self, rewriter.getI64IntegerAttr(dim), rewriter.getI64IntegerAttr(start),
      rewriter.getI64IntegerAttr(end), rewriter.getI64IntegerAttr(step));
    return success();
  }
};
// Epsilon used when the op's `eps` operand is None.
// NOTE(review): confirm this matches the default the frontend expects for a None eps.
constexpr double kDefaultRmsNormEpsilon = 1e-6;

/// Lowers aten.rms_norm to mfuse AclnnRmsNormOp.
///
/// Requirements: non-None weight, constant-float (or None) `eps`, constant int
/// `normalized_shape` no longer than the input rank, ranked input.
/// The target op has two results; only the normalised output replaces `op`.
/// The second result type has the input shape with the trailing normalised axes
/// set to 1.
struct ConvertAtenRmsNorm : public OpConversionPattern<TorchD::AtenRmsNormOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(TorchD::AtenRmsNormOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const override {
    if (isa<TorchD::NoneType>(op.getWeight().getType())) {
      return rewriter.notifyMatchFailure(op, "weight must not be None");
    }
    double epsVal = kDefaultRmsNormEpsilon;
    if (!isa<TorchD::NoneType>(op.getEps().getType())) {
      if (!matchPattern(op.getEps(), TorchD::m_TorchConstantFloat(&epsVal))) {
        return rewriter.notifyMatchFailure(op, "eps must be a constant float");
      }
    }
    llvm::SmallVector<int64_t, 4> normalizedShape;
    if (!isConstantListInt(op.getNormalizedShape(), normalizedShape)) {
      return rewriter.notifyMatchFailure(op, "normalized_shape must be constant int list");
    }
    Value input = adaptor.getInput();
    Value weight = adaptor.getWeight();
    // Report unsupported types as a match failure rather than asserting in cast<>.
    auto inputType = dyn_cast<RankedTensorType>(input.getType());
    if (!inputType) {
      return rewriter.notifyMatchFailure(op, "input must be ranked tensor");
    }
    const int64_t inputRank = inputType.getRank();
    const int64_t numNorm = static_cast<int64_t>(normalizedShape.size());
    if (numNorm > inputRank) {
      return rewriter.notifyMatchFailure(op, "normalized_shape rank exceeds input rank");
    }
    // convertType() returns a null Type on failure; dyn_cast_if_present tolerates that.
    auto yOutType = dyn_cast_if_present<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
    if (!yOutType) {
      return rewriter.notifyMatchFailure(op, "result type conversion failed");
    }
    auto inputShape = inputType.getShape();
    SmallVector<int64_t> rstdShape(inputShape.begin(), inputShape.end());
    // Collapse the trailing normalised axes to size 1.
    for (int64_t i = inputRank - numNorm; i < inputRank; ++i) {
      rstdShape[i] = 1;
    }
    auto rstdType = RankedTensorType::get(rstdShape, inputType.getElementType());
    auto epsilonAttr = rewriter.getF64FloatAttr(epsVal);
    SmallVector<Type, 2> resultTypes = {yOutType, rstdType};
    auto rmsNormOp = rewriter.create<mfuse::AclnnRmsNormOp>(op.getLoc(), resultTypes, input, weight, epsilonAttr);
    rewriter.replaceOp(op, rmsNormOp.getYOut());
    return success();
  }
};
/// Registers the hand-written (non-template) Aten -> Mfuse conversion patterns.
static void populateAtenToMfuseCustomPatterns(TypeConverter &converter, RewritePatternSet &patterns) {
  // Listed in registration order.
  patterns.add<ConvertAtenBroadcastTo, ConvertAtenConvolution, ConvertAtenExpand, ConvertAtenFull,
               ConvertAtenRmsNorm, ConvertAtenSumDimIntList, ConvertAtenTransposeInt, ConvertAtenMeanDim,
               ConvertAtenPermute, ConvertValueTensorLiteral>(converter, patterns.getContext());
}
/// Registers ConvertReshapeLikeOp for every torch op that is lowered to mfuse.reshape.
static void populateAtenToMfuseReshapeLikeOpPatterns(TypeConverter &converter, RewritePatternSet &patterns) {
  patterns.add<ConvertReshapeLikeOp<TorchD::AtenReshapeOp>, ConvertReshapeLikeOp<TorchD::AtenSqueezeDimOp>,
               ConvertReshapeLikeOp<TorchD::AtenSqueezeOp>, ConvertReshapeLikeOp<TorchD::AtenUnsqueezeOp>,
               ConvertReshapeLikeOp<TorchD::AtenViewOp>>(converter, patterns.getContext());
}
/// Registers ConvertBinaryOpPattern for each (torch op, mfuse op) pair of
/// two-input elementwise ops. The registration order is alphabetical by torch op.
static void populateAtenToMfuseBinaryOpPatterns(TypeConverter &converter, RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();

  // div / eq / ge / gt / le (tensor and scalar variants)
  patterns.add<ConvertBinaryOpPattern<TorchD::AtenDivTensorOp, mfuse::DivOp>,
               ConvertBinaryOpPattern<TorchD::AtenDivScalarOp, mfuse::DivOp>,
               ConvertBinaryOpPattern<TorchD::AtenEqTensorOp, mfuse::EqOp>,
               ConvertBinaryOpPattern<TorchD::AtenEqScalarOp, mfuse::EqOp>,
               ConvertBinaryOpPattern<TorchD::AtenGeTensorOp, mfuse::GeOp>,
               ConvertBinaryOpPattern<TorchD::AtenGeScalarOp, mfuse::GeOp>,
               ConvertBinaryOpPattern<TorchD::AtenGtTensorOp, mfuse::GtOp>,
               ConvertBinaryOpPattern<TorchD::AtenGtScalarOp, mfuse::GtOp>,
               ConvertBinaryOpPattern<TorchD::AtenLeTensorOp, mfuse::LeOp>,
               ConvertBinaryOpPattern<TorchD::AtenLeScalarOp, mfuse::LeOp>>(converter, context);

  // logical_and / logical_or / lt / maximum / minimum
  patterns.add<ConvertBinaryOpPattern<TorchD::AtenLogicalAndOp, mfuse::LogicalAndOp>,
               ConvertBinaryOpPattern<TorchD::AtenLogicalOrOp, mfuse::LogicalOrOp>,
               ConvertBinaryOpPattern<TorchD::AtenLtTensorOp, mfuse::LtOp>,
               ConvertBinaryOpPattern<TorchD::AtenLtScalarOp, mfuse::LtOp>,
               ConvertBinaryOpPattern<TorchD::AtenMaximumOp, mfuse::MaximumOp>,
               ConvertBinaryOpPattern<TorchD::AtenMinimumOp, mfuse::MinimumOp>>(converter, context);

  // mul / ne / pow
  patterns.add<ConvertBinaryOpPattern<TorchD::AtenMulTensorOp, mfuse::MulOp>,
               ConvertBinaryOpPattern<TorchD::AtenMulScalarOp, mfuse::MulOp>,
               ConvertBinaryOpPattern<TorchD::AtenNeTensorOp, mfuse::NeOp>,
               ConvertBinaryOpPattern<TorchD::AtenNeScalarOp, mfuse::NeOp>,
               ConvertBinaryOpPattern<TorchD::AtenPowTensorTensorOp, mfuse::PowOp>,
               ConvertBinaryOpPattern<TorchD::AtenPowTensorScalarOp, mfuse::PowOp>>(converter, context);
}
/// Adds every Aten -> Mfuse conversion pattern defined in this file to `patterns`:
/// the custom patterns, the binary-op patterns, the reshape-like patterns and
/// finally the var_mean.correction pattern.
void populateAtenToMfuseConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  populateAtenToMfuseCustomPatterns(converter, patterns);
  populateAtenToMfuseBinaryOpPatterns(converter, patterns);
  populateAtenToMfuseReshapeLikeOpPatterns(converter, patterns);
  patterns.add<ConvertAtenVarMeanCorrection>(converter, context);
}
}