* Copyright 2026 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <optional>
#include "mfusion/Conversion/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
namespace mlir {
namespace TorchD = torch::Torch;
namespace {
/// Extracts the integer payload of `v` when it is defined by a
/// `TorchD::ConstantIntOp`.
///
/// \param v Value to inspect; may be null.
/// \return The constant integer, or std::nullopt when `v` is null or is not
///         produced by a constant-int op.
static std::optional<int64_t> getConstInt(Value v) {
  if (!v) {
    return std::nullopt;
  }
  auto constOp = v.getDefiningOp<TorchD::ConstantIntOp>();
  if (!constOp) {
    return std::nullopt;
  }
  return constOp.getValueAttr().getInt();
}
/// Returns the statically known size of the last dimension of `v`.
///
/// \param v Value whose type is inspected.
/// \return The last entry of the tensor type's size list, or std::nullopt when
///         `v` is not a `TorchD::BaseTensorType`, carries no size list, has an
///         empty size list, or the last size equals `TorchD::kUnknownSize`.
static std::optional<int64_t> getKnownLastDim(Value v) {
  auto type = mlir::dyn_cast<TorchD::BaseTensorType>(v.getType());
  const bool hasShape = type && type.hasSizes();
  if (!hasShape) {
    return std::nullopt;
  }
  auto dims = type.getSizes();
  // The emptiness test short-circuits, so back() is only read on a non-empty list.
  if (dims.empty() || dims.back() == TorchD::kUnknownSize) {
    return std::nullopt;
  }
  return dims.back();
}
/// Tells whether `dim` addresses the last dimension of `v`.
///
/// `-1` is always accepted. `rank - 1` is accepted only when the type of `v`
/// is a `TorchD::BaseTensorType` with a non-empty size list, so that the rank
/// is available.
///
/// \param v   Value whose type supplies the rank.
/// \param dim Dimension index to classify.
static bool isLastDimIndex(Value v, int64_t dim) {
  auto type = mlir::dyn_cast<TorchD::BaseTensorType>(v.getType());
  // A rank of 0 stands for "rank not usable": non-tensor, unsized, or scalar.
  const int64_t rank = (type && type.hasSizes()) ? static_cast<int64_t>(type.getSizes().size()) : 0;
  return dim == -1 || (rank > 0 && dim == rank - 1);
}
/// The two elements of the tensor list consumed by a matched `aten.cat`,
/// in list order (see matchRotateHalfCatOperands).
struct RotateHalfCatOperands {
// First list element; matchRotateHalfSlices requires it to be produced by an AtenNegOp.
Value negVal;
// Second list element; matchRotateHalfSlices requires it to be produced by an AtenSliceTensorOp.
Value leftVal;
};
/// The pair of slice ops found under a rotate-half concatenation, together
/// with the tensor both of them slice (see matchRotateHalfSlices).
struct RotateHalfSlices {
// Common `self` operand of both slice ops.
Value base;
// Slice feeding the second cat operand directly.
TorchD::AtenSliceTensorOp left;
// Slice whose negation feeds the first cat operand.
TorchD::AtenSliceTensorOp right;
};
/// Matches `v` against `aten.cat([a, b], dim)` where `dim` addresses the last
/// dimension and the tensor list is built by a two-operand
/// `prim.ListConstruct`.
///
/// The dimension is accepted both as `-1` and, when the rank of `v` is known,
/// as `rank - 1`, which is the same rule isLastDimSlice applies to the slice
/// ops so the two checks agree.
///
/// \param v Candidate result of the concatenation.
/// \return The two list elements in order, or std::nullopt when the shape of
///         the IR does not match.
static std::optional<RotateHalfCatOperands> matchRotateHalfCatOperands(Value v) {
  auto catOp = v.getDefiningOp<TorchD::AtenCatOp>();
  if (!catOp) {
    return std::nullopt;
  }
  auto catDim = getConstInt(catOp.getDim());
  // `v` is the cat result, so its rank is the rank the cat dim refers to.
  if (!catDim || !isLastDimIndex(v, *catDim)) {
    return std::nullopt;
  }
  auto listOp = catOp.getTensors().getDefiningOp<TorchD::PrimListConstructOp>();
  if (!listOp || listOp->getNumOperands() != 2) {
    return std::nullopt;
  }
  return RotateHalfCatOperands{listOp->getOperand(0), listOp->getOperand(1)};
}
static std::optional<RotateHalfSlices> matchRotateHalfSlices(const RotateHalfCatOperands &operands) {
auto negOp = operands.negVal.getDefiningOp<TorchD::AtenNegOp>();
if (!negOp) {
return std::nullopt;
}
Value rightVal = negOp.getSelf();
auto sliceRight = rightVal.getDefiningOp<TorchD::AtenSliceTensorOp>();
auto sliceLeft = operands.leftVal.getDefiningOp<TorchD::AtenSliceTensorOp>();
if (!sliceRight || !sliceLeft) {
return std::nullopt;
}
Value xRightBase = sliceRight.getSelf();
Value xLeftBase = sliceLeft.getSelf();
if (xRightBase != xLeftBase) {
return std::nullopt;
}
return RotateHalfSlices{xLeftBase, sliceLeft, sliceRight};
}
/// Returns true when `sliceOp` has a constant dim operand that addresses the
/// last dimension of `base` (as decided by isLastDimIndex).
static bool isLastDimSlice(Value base, TorchD::AtenSliceTensorOp sliceOp) {
  const std::optional<int64_t> sliceDim = getConstInt(sliceOp.getDim());
  if (!sliceDim.has_value()) {
    return false;
  }
  return isLastDimIndex(base, sliceDim.value());
}
/// Returns true when the step operand of `sliceOp` is the constant integer 1.
static bool hasUnitStep(TorchD::AtenSliceTensorOp sliceOp) {
  const std::optional<int64_t> stride = getConstInt(sliceOp.getStep());
  if (!stride.has_value()) {
    return false;
  }
  return stride.value() == 1;
}
/// Returns true when both the start and end operands of `sliceOp` are constant
/// integers equal to `expectedStart` and `expectedEnd` respectively.
static bool hasConstRange(TorchD::AtenSliceTensorOp sliceOp, int64_t expectedStart, int64_t expectedEnd) {
  const std::optional<int64_t> begin = getConstInt(sliceOp.getStart());
  if (!begin.has_value() || begin.value() != expectedStart) {
    return false;
  }
  const std::optional<int64_t> finish = getConstInt(sliceOp.getEnd());
  return finish.has_value() && finish.value() == expectedEnd;
}
static bool hasRotateHalfSliceParams(const RotateHalfSlices &slices) {
if (!isLastDimSlice(slices.base, slices.left) || !isLastDimSlice(slices.base, slices.right)) {
return false;
}
if (!hasUnitStep(slices.left) || !hasUnitStep(slices.right)) {
return false;
}
auto lastDim = getKnownLastDim(slices.base);
if (!lastDim || (*lastDim % 2) != 0) {
return false;
}
int64_t half = *lastDim / 2;
constexpr int64_t kEndMax = 9223372036854775807LL;
return hasConstRange(slices.left, 0, half) && hasConstRange(slices.right, half, kEndMax);
}
/// Recognises `v` as `cat([neg(x[..., half:]), x[..., :half]], -1)`.
///
/// Chains matchRotateHalfCatOperands, matchRotateHalfSlices and
/// hasRotateHalfSliceParams.
///
/// \param v Candidate rotate-half result.
/// \return The sliced tensor `x` on a match, a null Value otherwise.
static Value matchRotateHalfCat(Value v) {
  if (auto catOperands = matchRotateHalfCatOperands(v)) {
    auto slicePair = matchRotateHalfSlices(*catOperands);
    if (slicePair && hasRotateHalfSliceParams(*slicePair)) {
      return slicePair->base;
    }
  }
  return Value();
}
/// Result of a successful matchRoPE: the ops that form the decomposed
/// expression and the three values handed to the fused operator.
struct RoPEMatchState {
// Root add op the match started from.
TorchD::AtenAddTensorOp addOp;
// Multiply that has `x` itself as one operand.
TorchD::AtenMulTensorOp cosMulOp;
// Multiply that has the rotate-half concatenation as one operand.
TorchD::AtenMulTensorOp sinMulOp;
// Operand of cosMulOp that is identical to the tensor sliced by rotate-half.
Value x;
// The other operand of cosMulOp.
Value cos;
// Operand of sinMulOp that is not the rotate-half concatenation.
Value sin;
};
/// Matches `addOp` against `x * cos + rotate_half(x) * sin`.
///
/// The add must have a constant alpha of 1 and both inputs must be
/// `aten.mul.Tensor` ops. The role of each multiply (cos side / sin side) and
/// the operand order inside each multiply are not fixed: both role
/// assignments are tried, and within the sin multiply `self` is probed for
/// the rotate-half pattern before `other`.
///
/// \param addOp    Root add op.
/// \param state    Filled in only when a role assignment succeeds.
/// \param rewriter Used to report match-failure reasons.
/// \return success() when the pattern was recognised, failure() otherwise.
static LogicalResult matchRoPE(TorchD::AtenAddTensorOp addOp, RoPEMatchState &state, PatternRewriter &rewriter) {
  const std::optional<int64_t> alpha = getConstInt(addOp.getAlpha());
  if (!alpha.has_value() || alpha.value() != 1) {
    return rewriter.notifyMatchFailure(addOp, "RoPE requires add alpha == 1");
  }

  auto firstMul = addOp.getSelf().getDefiningOp<TorchD::AtenMulTensorOp>();
  auto secondMul = addOp.getOther().getDefiningOp<TorchD::AtenMulTensorOp>();
  if (!(firstMul && secondMul)) {
    return rewriter.notifyMatchFailure(addOp, "add inputs must come from torch.aten.mul.Tensor");
  }

  // Tries one role assignment; writes `state` only on success.
  auto bindRoles = [&](TorchD::AtenMulTensorOp cosMul, TorchD::AtenMulTensorOp sinMul) -> LogicalResult {
    // Find rotate_half(x) among the operands of the sin multiply.
    Value sinOperand = sinMul.getOther();
    Value rotatedSource = matchRotateHalfCat(sinMul.getSelf());
    if (!rotatedSource) {
      sinOperand = sinMul.getSelf();
      rotatedSource = matchRotateHalfCat(sinMul.getOther());
    }
    if (!rotatedSource) {
      return rewriter.notifyMatchFailure(addOp, "rotate_half(x) not matched");
    }

    // The cos multiply must use the very tensor that rotate_half slices.
    Value cosLhs = cosMul.getSelf();
    Value cosRhs = cosMul.getOther();
    const bool xIsLhs = (cosLhs == rotatedSource);
    if (!xIsLhs && cosRhs != rotatedSource) {
      return rewriter.notifyMatchFailure(addOp, "x in cos mul must match rotate_half input");
    }

    state.addOp = addOp;
    state.cosMulOp = cosMul;
    state.sinMulOp = sinMul;
    state.x = xIsLhs ? cosLhs : cosRhs;
    state.cos = xIsLhs ? cosRhs : cosLhs;
    state.sin = sinOperand;
    return success();
  };

  if (succeeded(bindRoles(firstMul, secondMul))) {
    return success();
  }
  if (succeeded(bindRoles(secondMul, firstMul))) {
    return success();
  }
  return failure();
}
/// Rewrite pattern rooted at `aten.add.Tensor`: when matchRoPE recognises the
/// decomposed rotary expression, the add is replaced by a `TorchD::OperatorOp`
/// named "torch.npu.npu_rotary_mul" taking (x, cos, sin) and producing the
/// add's original result type.
class TorchFuseRoPEPattern : public OpRewritePattern<TorchD::AtenAddTensorOp> {
 public:
  using OpRewritePattern<TorchD::AtenAddTensorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TorchD::AtenAddTensorOp op, PatternRewriter &rewriter) const override {
    RoPEMatchState matched;
    if (failed(matchRoPE(op, matched, rewriter))) {
      return failure();
    }

    auto fusedName = rewriter.getStringAttr("torch.npu.npu_rotary_mul");
    SmallVector<Value> fusedOperands{matched.x, matched.cos, matched.sin};
    // The trailing 0 is the region count of the created operator op.
    rewriter.replaceOpWithNewOp<TorchD::OperatorOp>(op, op.getResult().getType(), fusedName, fusedOperands, 0);
    return success();
  }
};
/// Module pass that greedily applies TorchFuseRoPEPattern to fuse the
/// decomposed RoPE expression into `torch.npu.npu_rotary_mul`.
struct TorchFuseRoPEPass : public PassWrapper<TorchFuseRoPEPass, OperationPass<ModuleOp>> {
  /// Command-line argument under which the pass is registered.
  StringRef getArgument() const final { return "torch-fuse-rope"; }

  /// One-line description of the pass.
  StringRef getDescription() const final {
    return "Fuse decomposed RoPE on Torch dialect into torch.npu.npu_rotary_mul";
  }

  /// The rewrite creates Torch dialect ops, so the dialect is registered as a dependency.
  void getDependentDialects(DialectRegistry &registry) const override { registry.insert<TorchD::TorchDialect>(); }

  /// Runs the greedy driver over the module and signals pass failure when the driver reports failure.
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    patterns.add<TorchFuseRoPEPattern>(&getContext());
    if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
      signalPassFailure();
    }
  }
};
}
/// Creates a new instance of the RoPE fusion pass.
std::unique_ptr<Pass> createTorchFuseRoPEPass() {
  std::unique_ptr<Pass> pass = std::make_unique<TorchFuseRoPEPass>();
  return pass;
}
}