/**
 * Copyright 2026 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "mfusion/Conversion/MfuseToTorch/MfuseMetaToTorch.h"

#include <algorithm>
#include <iterator>
#include <optional>

#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "mfusion/Dialect/Mfuse/IR/Mfuse.h"

namespace mlir {

namespace TorchD = mlir::torch::Torch;

namespace {

/// Returns true only when \p kernelGenerator is exactly the string "dvm".
static bool isDvmKernelGenerator(llvm::StringRef kernelGenerator) {
  const llvm::StringRef dvmName("dvm");
  return dvmName == kernelGenerator;
}

/// Returns the rank of \p ty when it is a builtin ranked tensor or a torch value tensor that carries sizes.
/// Every other type, including a torch value tensor without size information, yields std::nullopt.
static std::optional<int64_t> getTorchOrRankedTensorRank(mlir::Type ty) {
  if (auto rankedTy = mlir::dyn_cast<mlir::RankedTensorType>(ty)) {
    return static_cast<int64_t>(rankedTy.getRank());
  }
  auto torchTy = mlir::dyn_cast<TorchD::ValueTensorType>(ty);
  if (torchTy && torchTy.hasSizes()) {
    // The rank of a value tensor is the length of its size list.
    return static_cast<int64_t>(torchTy.getSizes().size());
  }
  return std::nullopt;
}

/// Returns true when both \p selfTy and \p otherTy have a known rank of exactly 2
/// (as reported by getTorchOrRankedTensorRank).
static bool isTwoDMatmulOperandTypes(mlir::Type selfTy, mlir::Type otherTy) {
  const auto hasMatrixRank = [](mlir::Type ty) {
    const std::optional<int64_t> rank = getTorchOrRankedTensorRank(ty);
    return rank.has_value() && rank.value() == 2;
  };
  return hasMatrixRank(selfTy) && hasMatrixRank(otherTy);
}

/// Emits a torch.aten.permute on \p v whose permutation keeps every leading dimension in place and exchanges
/// the last two. The result type carries the input sizes with the final two entries swapped and the same
/// (optional) dtype as the input.
/// Fails, without emitting anything, when \p v is not a torch value tensor with known sizes or its rank is below 2.
static mlir::FailureOr<mlir::Value> buildSwapLastTwoDimsPermute(mlir::Location loc, mlir::Value v,
                                                                mlir::ConversionPatternRewriter &rewriter) {
  auto tensorTy = mlir::dyn_cast<TorchD::ValueTensorType>(v.getType());
  if (!tensorTy || !tensorTy.hasSizes()) {
    return mlir::failure();
  }
  const auto inputSizes = tensorTy.getSizes();
  const int64_t numDims = static_cast<int64_t>(inputSizes.size());
  if (numDims < 2) {
    return mlir::failure();
  }
  const int64_t lastDim = numDims - 1;
  const int64_t secondLastDim = numDims - 2;

  // Result type: same sizes as the input with the two trailing extents exchanged.
  llvm::SmallVector<int64_t> swappedSizes(inputSizes.begin(), inputSizes.end());
  std::swap(swappedSizes[secondLastDim], swappedSizes[lastDim]);
  mlir::Type swappedTy = tensorTy.getWithSizesAndDtype(swappedSizes, tensorTy.getOptionalDtype());

  // Dimension order of the permutation: identity with the two trailing entries flipped.
  llvm::SmallVector<int64_t> order;
  order.reserve(static_cast<size_t>(numDims));
  for (int64_t dim = 0; dim < secondLastDim; ++dim) {
    order.push_back(dim);
  }
  order.push_back(lastDim);
  order.push_back(secondLastDim);

  // One torch.constant.int per entry, emitted in permutation order.
  llvm::SmallVector<mlir::Value> orderValues;
  orderValues.reserve(order.size());
  for (const int64_t dim : order) {
    orderValues.push_back(rewriter.create<TorchD::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(dim)));
  }

  mlir::MLIRContext *ctx = rewriter.getContext();
  auto intListTy = TorchD::ListType::get(ctx, TorchD::IntType::get(ctx));
  mlir::Value orderList = rewriter.create<TorchD::PrimListConstructOp>(loc, intListTy, orderValues);
  return rewriter.create<TorchD::AtenPermuteOp>(loc, swappedTy, v, orderList).getResult();
}

/// Maps an element type to the integer value of the matching torch_upstream::ScalarType.
///
/// Returns std::nullopt for none types (builtin, mfuse and torch), for signless integers whose width is not
/// 1, 8, 16, 32 or 64, and for unsigned integers other than 8 bits wide. Signless i8/i16/i32/i64 map to
/// Char/Short/Int/Long; every remaining type (including signless i1) is delegated to
/// TorchD::getScalarTypeForType.
/// NOTE(review): what getScalarTypeForType does for a type it does not recognise is decided by torch-mlir
/// (see torch-mlir/Dialect/Torch/Utils/Utils.cpp) — confirm callers never pass such a type.
std::optional<int64_t> getTorchScalarTypeInt(mlir::Type type) {
  using ScalarType = mlir::torch::torch_upstream::ScalarType;
  const auto toCode = [](ScalarType scalarType) -> std::optional<int64_t> {
    return static_cast<int64_t>(scalarType);
  };

  if (mlir::isa<mlir::NoneType, mlir::mfuse::NoneType, TorchD::NoneType>(type)) {
    return std::nullopt;
  }

  // Signless integers wider than one bit are mapped here; the 1-bit case is left to the generic lookup below.
  if (type.isSignlessInteger() && !type.isSignlessInteger(1)) {
    switch (type.getIntOrFloatBitWidth()) {
      case 8:
        return toCode(ScalarType::Char);
      case 16:
        return toCode(ScalarType::Short);
      case 32:
        return toCode(ScalarType::Int);
      case 64:
        return toCode(ScalarType::Long);
      default:
        return std::nullopt;
    }
  }

  // Among unsigned integers only the 8-bit one is forwarded to the generic lookup.
  if (type.isUnsignedInteger() && !type.isUnsignedInteger(8)) {
    return std::nullopt;
  }

  // Refer to: torch-mlir/Dialect/Torch/Utils/Utils.cpp
  return toCode(TorchD::getScalarTypeForType(type));
}

/// Emits a torch.constant.int holding the scalar-type code that getTorchScalarTypeInt reports for
/// \p dtypeType. Fails, without emitting anything, when no code is available for that type.
mlir::FailureOr<mlir::Value> buildTorchDtypeValue(mlir::Type dtypeType, mlir::Location loc,
                                                  mlir::ConversionPatternRewriter &rewriter) {
  const std::optional<int64_t> scalarTypeCode = getTorchScalarTypeInt(dtypeType);
  if (!scalarTypeCode.has_value()) {
    return mlir::failure();
  }
  mlir::IntegerAttr codeAttr = rewriter.getI64IntegerAttr(scalarTypeCode.value());
  auto dtypeConst = rewriter.create<TorchD::ConstantIntOp>(loc, codeAttr);
  return dtypeConst.getResult();
}


/// Emits one torch.constant.int per element of \p attr, in order, and packs them into a
/// torch.prim.ListConstruct of type !torch.list<int>.
/// Every element is accessed through mlir::cast<mlir::IntegerAttr>, so \p attr must contain integer
/// attributes only.
mlir::Value buildTorchIntListFromI64ArrayAttr(mlir::ArrayAttr attr, mlir::Location loc,
                                              mlir::ConversionPatternRewriter &rewriter) {
  const auto elements = attr.getValue();
  llvm::SmallVector<mlir::Value> constants;
  constants.reserve(elements.size());
  std::transform(elements.begin(), elements.end(), std::back_inserter(constants),
                 [&](mlir::Attribute element) -> mlir::Value {
                   const int64_t value = mlir::cast<mlir::IntegerAttr>(element).getInt();
                   return rewriter.create<TorchD::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(value)).getResult();
                 });

  mlir::MLIRContext *ctx = rewriter.getContext();
  auto intListTy = TorchD::ListType::get(ctx, TorchD::IntType::get(ctx));
  return rewriter.create<TorchD::PrimListConstructOp>(loc, intListTy, constants).getResult();
}

/// Same as `buildTorchIntListFromI64ArrayAttr` but accepts `DenseI64ArrayAttr` or `ArrayAttr` of `IntegerAttr`
/// (common for ODS `I64ArrayAttr`).
///
/// Fails if \p attr is null, is neither of the two attribute kinds, or is an `ArrayAttr` containing a
/// non-integer element. All values are decoded before any operation is created, so a failure never leaves
/// partially emitted constants behind.
///
/// \returns the result of a torch.prim.ListConstruct of type !torch.list<int> holding one
/// torch.constant.int per element, in attribute order.
static mlir::FailureOr<mlir::Value> buildTorchIntListFromI64LikeAttr(mlir::Attribute attr, mlir::Location loc,
                                                                     mlir::ConversionPatternRewriter &rewriter) {
  // A null attribute (e.g. an absent optional attribute) must not reach dyn_cast.
  if (!attr) {
    return mlir::failure();
  }

  // Phase 1: validate and decode without touching the IR.
  llvm::SmallVector<int64_t> ints;
  if (auto dense = mlir::dyn_cast<mlir::DenseI64ArrayAttr>(attr)) {
    const auto denseValues = dense.asArrayRef();
    ints.assign(denseValues.begin(), denseValues.end());
  } else if (auto arr = mlir::dyn_cast<mlir::ArrayAttr>(attr)) {
    ints.reserve(arr.size());
    for (mlir::Attribute element : arr) {
      auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(element);
      if (!intAttr) {
        return mlir::failure();
      }
      ints.push_back(intAttr.getInt());
    }
  } else {
    return mlir::failure();
  }

  // Phase 2: emit the constants and the list.
  llvm::SmallVector<mlir::Value> values;
  values.reserve(ints.size());
  for (const int64_t v : ints) {
    values.push_back(rewriter.create<TorchD::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(v)));
  }
  mlir::MLIRContext *ctx = rewriter.getContext();
  return rewriter
    .create<TorchD::PrimListConstructOp>(loc, TorchD::ListType::get(ctx, TorchD::IntType::get(ctx)), values)
    .getResult();
}

// ============================================================================
// =   Please keep the patterns in alphabetical order by operator name   =
// ============================================================================

/// Converts mfuse.broadcast_to -> torch.aten.expand.
///
/// The target size list is rebuilt from the shape of the op's ranked output type; a dynamic extent is
/// emitted as -1 and `implicit` is always false. The pattern does not match (with a diagnostic) when the
/// output is not a ranked tensor or its type cannot be converted; both checks run before any IR is emitted.
class ConvertMfuseBroadcastTo : public mlir::OpConversionPattern<mlir::mfuse::BroadcastToOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  mlir::LogicalResult matchAndRewrite(mlir::mfuse::BroadcastToOp op, OpAdaptor adaptor,
                                      mlir::ConversionPatternRewriter &rewriter) const override {
    auto outType = mlir::dyn_cast<mlir::RankedTensorType>(op.getOutput().getType());
    if (!outType) {
      return rewriter.notifyMatchFailure(op, "output must be ranked tensor");
    }
    mlir::Type torchResultType = getTypeConverter()->convertType(outType);
    if (!torchResultType) {
      return rewriter.notifyMatchFailure(op, "result type conversion failed");
    }

    const mlir::Location loc = op.getLoc();
    llvm::SmallVector<mlir::Value> sizeValues;
    sizeValues.reserve(outType.getRank());
    for (const int64_t dim : outType.getShape()) {
      // for dynamic shape, use -1 to represent the original dimension
      // NOTE(review): -1 only makes sense for dimensions the input already has — confirm that newly added
      // leading dimensions are never dynamic here.
      const int64_t size = (dim == mlir::ShapedType::kDynamic) ? -1 : dim;
      sizeValues.push_back(rewriter.create<TorchD::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(size)));
    }

    auto listType = TorchD::ListType::get(op.getContext(), TorchD::IntType::get(op.getContext()));
    mlir::Value sizeList = rewriter.create<TorchD::PrimListConstructOp>(loc, listType, sizeValues);

    mlir::Value input = adaptor.getInput();
    mlir::Value implicit = rewriter.create<TorchD::ConstantBoolOp>(loc, false);
    rewriter.replaceOpWithNewOp<TorchD::AtenExpandOp>(op, torchResultType, input, sizeList, implicit);
    return mlir::success();
  }
};

/// Builds torch.aten.convolution from mfuse.aclnn.conv2d / mfuse.aclnn.conv2d_with_bias hyper-parameters.
///
/// \p strideAttr, \p paddingAttr, \p dilationAttr and \p outputPaddingAttr are turned into torch int lists
/// (in that order) through buildTorchIntListFromI64LikeAttr; the first one that is rejected makes this
/// function fail. \p transposed and \p groups become torch.constant.bool / torch.constant.int operands and
/// \p bias is forwarded unchanged.
/// \returns the result value of the created torch.aten.convolution, or failure.
static mlir::FailureOr<mlir::Value> buildAtenConvolutionFromMfuseConvAttrs(
  mlir::ConversionPatternRewriter &rewriter, mlir::Location loc, mlir::Type resultType, mlir::Value input,
  mlir::Value weight, mlir::Value bias, mlir::Attribute strideAttr, mlir::Attribute paddingAttr,
  mlir::Attribute dilationAttr, bool transposed, mlir::Attribute outputPaddingAttr, int64_t groups) {
  const auto toIntList = [&](mlir::Attribute attr) { return buildTorchIntListFromI64LikeAttr(attr, loc, rewriter); };

  const auto strides = toIntList(strideAttr);
  if (mlir::failed(strides)) {
    return mlir::failure();
  }
  const auto paddings = toIntList(paddingAttr);
  if (mlir::failed(paddings)) {
    return mlir::failure();
  }
  const auto dilations = toIntList(dilationAttr);
  if (mlir::failed(dilations)) {
    return mlir::failure();
  }
  const auto outputPaddings = toIntList(outputPaddingAttr);
  if (mlir::failed(outputPaddings)) {
    return mlir::failure();
  }

  mlir::Value transposedFlag = rewriter.create<TorchD::ConstantBoolOp>(loc, transposed);
  mlir::Value groupCount = rewriter.create<TorchD::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(groups));
  auto convolution =
    rewriter.create<TorchD::AtenConvolutionOp>(loc, resultType, input, weight, bias, *strides, *paddings, *dilations,
                                               transposedFlag, *outputPaddings, groupCount);
  return convolution.getResult();
}

/// Converts mfuse.aclnn.conv2d -> torch.aten.convolution (with bias=None).
/// The bias operand is a freshly created torch.constant.none; all hyper-parameters are taken from the op's
/// attributes and forwarded to buildAtenConvolutionFromMfuseConvAttrs.
class ConvertMfuseAclnnConv2D : public mlir::OpConversionPattern<mlir::mfuse::AclnnConv2DOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  mlir::LogicalResult matchAndRewrite(mlir::mfuse::AclnnConv2DOp op, OpAdaptor adaptor,
                                      mlir::ConversionPatternRewriter &rewriter) const override {
    mlir::Value convInput = adaptor.getInput();
    mlir::Value convWeight = adaptor.getWeight();

    mlir::Type convertedTy = getTypeConverter()->convertType(op.getResult().getType());
    if (!convertedTy) {
      return rewriter.notifyMatchFailure(op, "result type conversion failed");
    }

    // aten.convolution always takes a bias operand; None stands for "no bias".
    mlir::Value absentBias = rewriter.create<TorchD::ConstantNoneOp>(op.getLoc());
    auto convolution = buildAtenConvolutionFromMfuseConvAttrs(
      rewriter, op.getLoc(), convertedTy, convInput, convWeight, absentBias, op.getStride(), op.getPadding(),
      op.getDilation(), op.getTransposed(), op.getOutputPadding(), op.getGroups());
    if (mlir::succeeded(convolution)) {
      rewriter.replaceOp(op, *convolution);
      return mlir::success();
    }
    return rewriter.notifyMatchFailure(
      op, "mfuse conv stride/padding/dilation/output_padding must be DenseI64ArrayAttr or ArrayAttr of integers");
  }
};

/// Converts mfuse.aclnn.conv2d_with_bias -> torch.aten.convolution (with bias operand).
/// This avoids emitting a separate torch.aten.add.Tensor after conv: the converted bias operand is passed
/// straight to buildAtenConvolutionFromMfuseConvAttrs together with the op's hyper-parameter attributes.
class ConvertMfuseAclnnConv2DWithBias : public mlir::OpConversionPattern<mlir::mfuse::AclnnConv2DWithBiasOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  mlir::LogicalResult matchAndRewrite(mlir::mfuse::AclnnConv2DWithBiasOp op, OpAdaptor adaptor,
                                      mlir::ConversionPatternRewriter &rewriter) const override {
    mlir::Value convInput = adaptor.getInput();
    mlir::Value convWeight = adaptor.getWeight();
    mlir::Value convBias = adaptor.getBias();

    mlir::Type convertedTy = getTypeConverter()->convertType(op.getResult().getType());
    if (!convertedTy) {
      return rewriter.notifyMatchFailure(op, "result type conversion failed");
    }

    auto convolution = buildAtenConvolutionFromMfuseConvAttrs(
      rewriter, op.getLoc(), convertedTy, convInput, convWeight, convBias, op.getStride(), op.getPadding(),
      op.getDilation(), op.getTransposed(), op.getOutputPadding(), op.getGroups());
    if (mlir::succeeded(convolution)) {
      rewriter.replaceOp(op, *convolution);
      return mlir::success();
    }
    return rewriter.notifyMatchFailure(
      op, "mfuse conv stride/padding/dilation/output_padding must be DenseI64ArrayAttr or ArrayAttr of integers");
  }
};

/// Converts mfuse.full -> torch.aten.full.
/// Reconstructs all original torch.aten.full arguments from the preserved attributes:
///  - size:       one torch.constant.int per extent of the (static) result shape;
///  - fill_value: the already-converted operand from the adaptor, forwarded unchanged;
///  - dtype:      the saved attribute, else derived from the result element type, else None;
///  - layout / device / pin_memory: the saved attribute, else None.
///
/// The pattern does not match when the result is not a ranked tensor, has a dynamic extent (the size list
/// is built from the result shape, which has no usable value for a dynamic dimension), or its type cannot
/// be converted. All of these checks run before any IR is emitted.
class ConvertMfuseFull : public mlir::OpConversionPattern<mlir::mfuse::FullOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  mlir::LogicalResult matchAndRewrite(mlir::mfuse::FullOp op, OpAdaptor adaptor,
                                      mlir::ConversionPatternRewriter &rewriter) const override {
    auto resultType = mlir::dyn_cast<mlir::RankedTensorType>(op.getResult().getType());
    if (!resultType) {
      return rewriter.notifyMatchFailure(op, "result must be ranked tensor");
    }
    // A dynamic extent is stored as the ShapedType::kDynamic sentinel; emitting it as a size would produce
    // a meaningless constant, so reject such ops instead.
    if (!resultType.hasStaticShape()) {
      return rewriter.notifyMatchFailure(op, "dynamic result shape is not supported");
    }
    mlir::Type torchResultType = getTypeConverter()->convertType(resultType);
    if (!torchResultType) {
      return rewriter.notifyMatchFailure(op, "result type conversion failed");
    }

    const mlir::Location loc = op.getLoc();

    // Build size list from result shape.
    llvm::SmallVector<mlir::Value> sizeValues;
    sizeValues.reserve(resultType.getRank());
    for (const int64_t dim : resultType.getShape()) {
      sizeValues.push_back(rewriter.create<TorchD::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(dim)));
    }
    auto listType = TorchD::ListType::get(op.getContext(), TorchD::IntType::get(op.getContext()));
    mlir::Value sizeList = rewriter.create<TorchD::PrimListConstructOp>(loc, listType, sizeValues);

    // fill_value is taken from the adaptor, i.e. after operand type conversion; no further conversion here.
    mlir::Value fillVal = adaptor.getFillValue();

    // Restore dtype from saved attribute, or build from element type; fall back to None when the element
    // type has no torch scalar type code.
    mlir::Value dtypeVal;
    if (auto dtypeAttr = op.getDtypeAttr()) {
      dtypeVal = rewriter.create<TorchD::ConstantIntOp>(loc, dtypeAttr);
    } else {
      auto dtypeValOrFailure = buildTorchDtypeValue(resultType.getElementType(), loc, rewriter);
      if (mlir::succeeded(dtypeValOrFailure)) {
        dtypeVal = *dtypeValOrFailure;
      } else {
        dtypeVal = rewriter.create<TorchD::ConstantNoneOp>(loc);
      }
    }

    // Restore layout from saved attribute, or None.
    mlir::Value layoutVal;
    if (auto layoutAttr = op.getLayoutAttr()) {
      layoutVal = rewriter.create<TorchD::ConstantIntOp>(loc, layoutAttr);
    } else {
      layoutVal = rewriter.create<TorchD::ConstantNoneOp>(loc);
    }

    // Restore device from saved attribute, or None.
    mlir::Value deviceVal;
    if (auto deviceAttr = op.getDeviceAttr()) {
      deviceVal = rewriter.create<TorchD::ConstantDeviceOp>(loc, deviceAttr);
    } else {
      deviceVal = rewriter.create<TorchD::ConstantNoneOp>(loc);
    }

    // Restore pin_memory from saved attribute, or None.
    mlir::Value pinMemoryVal;
    if (auto pinMemoryAttr = op.getPinMemoryAttr()) {
      pinMemoryVal = rewriter.create<TorchD::ConstantBoolOp>(loc, pinMemoryAttr.getValue());
    } else {
      pinMemoryVal = rewriter.create<TorchD::ConstantNoneOp>(loc);
    }

    rewriter.replaceOpWithNewOp<TorchD::AtenFullOp>(op, torchResultType, sizeList, fillVal, dtypeVal, layoutVal,
                                                    deviceVal, pinMemoryVal);
    return mlir::success();
  }
};

/// Converts mfuse.matmul -> torch.aten.mm (2D) or torch.aten.matmul (ND).
///
/// "2D" means both converted operands have a known rank of 2. For kernel-generator dvm and 2D operands,
/// trans_x1/trans_x2 are attached as the bool attributes dvm_trans_a/dvm_trans_b on torch.aten.mm and the
/// operands are passed through untouched. In every other case a requested transpose is expressed with
/// torch.aten.permute (swap last two dims) on the affected operand; the pattern fails when that permute
/// cannot be built or the result type cannot be converted.
class ConvertMfuseMatmul : public mlir::OpConversionPattern<mlir::mfuse::MatmulOp> {
 public:
  ConvertMfuseMatmul(mlir::TypeConverter &converter, mlir::MLIRContext *ctx, llvm::StringRef kernelGenerator)
      : OpConversionPattern(converter, ctx), kernelGenerator_(kernelGenerator.str()) {}

  mlir::LogicalResult matchAndRewrite(mlir::mfuse::MatmulOp op, OpAdaptor adaptor,
                                      mlir::ConversionPatternRewriter &rewriter) const override {
    const mlir::Location loc = op.getLoc();
    mlir::Value lhs = adaptor.getSelf();
    mlir::Value rhs = adaptor.getOther();
    const bool transposeLhs = op.getTransX1();
    const bool transposeRhs = op.getTransX2();

    mlir::Type convertedTy = getTypeConverter()->convertType(op.getResult().getType());
    if (!convertedTy) {
      return mlir::failure();
    }

    const bool bothMatrices = isTwoDMatmulOperandTypes(lhs.getType(), rhs.getType());

    // dvm consumes the transpose flags as attributes, so no permute is materialised on this path.
    if (bothMatrices && isDvmKernelGenerator(kernelGenerator_)) {
      auto mm = rewriter.create<TorchD::AtenMmOp>(loc, convertedTy, lhs, rhs);
      mm->setAttr("dvm_trans_a", rewriter.getBoolAttr(transposeLhs));
      mm->setAttr("dvm_trans_b", rewriter.getBoolAttr(transposeRhs));
      rewriter.replaceOp(op, mm.getResult());
      return mlir::success();
    }

    // Replaces `operand` by its last-two-dims transpose when `requested` is set.
    const auto transposeIfRequested = [&](bool requested, mlir::Value &operand) -> mlir::LogicalResult {
      if (!requested) {
        return mlir::success();
      }
      auto swapped = buildSwapLastTwoDimsPermute(loc, operand, rewriter);
      if (mlir::failed(swapped)) {
        return mlir::failure();
      }
      operand = *swapped;
      return mlir::success();
    };
    if (mlir::failed(transposeIfRequested(transposeLhs, lhs)) ||
        mlir::failed(transposeIfRequested(transposeRhs, rhs))) {
      return mlir::failure();
    }

    if (!bothMatrices) {
      rewriter.replaceOpWithNewOp<TorchD::AtenMatmulOp>(op, convertedTy, lhs, rhs);
      return mlir::success();
    }
    rewriter.replaceOpWithNewOp<TorchD::AtenMmOp>(op, convertedTy, lhs, rhs);
    return mlir::success();
  }

 private:
  /// Name of the kernel generator this pattern was configured with; compared against "dvm".
  std::string kernelGenerator_;
};

/// Converts mfuse.matmul_with_bias -> torch.aten.mm/matmul + torch.aten.add.Tensor.
/// Since torch.aten.mm/matmul don't support bias directly, we decompose it into
/// matmul followed by add. Transpose handling matches ConvertMfuseMatmul: for kernel-generator dvm and
/// rank-2 operands the flags become the dvm_trans_a/dvm_trans_b attributes on torch.aten.mm, otherwise a
/// requested transpose is materialised as torch.aten.permute on the operand.
///
/// The bias add uses an integer alpha of 1, the same constant the add/sub patterns in this file emit; an
/// integer alpha is acceptable for floating-point and integral tensors alike, whereas a floating-point
/// alpha is not valid for integral tensors under PyTorch's add semantics.
class ConvertMfuseMatmulWithBias : public mlir::OpConversionPattern<mlir::mfuse::MatmulWithBiasOp> {
 public:
  ConvertMfuseMatmulWithBias(mlir::TypeConverter &converter, mlir::MLIRContext *ctx, llvm::StringRef kernelGenerator)
      : OpConversionPattern(converter, ctx), kernelGenerator_(kernelGenerator.str()) {}

  mlir::LogicalResult matchAndRewrite(mlir::mfuse::MatmulWithBiasOp op, OpAdaptor adaptor,
                                      mlir::ConversionPatternRewriter &rewriter) const override {
    mlir::Value self = adaptor.getSelf();
    mlir::Value other = adaptor.getOther();
    mlir::Value bias = adaptor.getBias();
    const bool trans1 = op.getTransX1();
    const bool trans2 = op.getTransX2();

    auto resultType = getTypeConverter()->convertType(op.getResult().getType());
    if (!resultType) {
      return rewriter.notifyMatchFailure(op, "result type conversion failed");
    }

    mlir::Location loc = op.getLoc();
    const bool twoD = isTwoDMatmulOperandTypes(self.getType(), other.getType());
    const bool dvm = isDvmKernelGenerator(kernelGenerator_);

    mlir::Value matmulResult;
    if (twoD && dvm) {
      // dvm consumes the transpose flags as attributes; operands stay untouched.
      auto newMm = rewriter.create<TorchD::AtenMmOp>(loc, resultType, self, other);
      newMm->setAttr("dvm_trans_a", rewriter.getBoolAttr(trans1));
      newMm->setAttr("dvm_trans_b", rewriter.getBoolAttr(trans2));
      matmulResult = newMm.getResult();
    } else {
      if (trans1) {
        auto permOr = buildSwapLastTwoDimsPermute(loc, self, rewriter);
        if (mlir::failed(permOr)) {
          return rewriter.notifyMatchFailure(op, "cannot transpose self operand");
        }
        self = *permOr;
      }
      if (trans2) {
        auto permOr = buildSwapLastTwoDimsPermute(loc, other, rewriter);
        if (mlir::failed(permOr)) {
          return rewriter.notifyMatchFailure(op, "cannot transpose other operand");
        }
        other = *permOr;
      }
      if (twoD) {
        matmulResult = rewriter.create<TorchD::AtenMmOp>(loc, resultType, self, other).getResult();
      } else {
        matmulResult = rewriter.create<TorchD::AtenMatmulOp>(loc, resultType, self, other).getResult();
      }
    }

    // Add bias: torch.aten.add.Tensor(matmul_result, bias, alpha=1)
    constexpr int64_t kAlphaOne = 1;
    mlir::Value alphaOne = rewriter.create<TorchD::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(kAlphaOne));
    mlir::Value addResult = rewriter.create<TorchD::AtenAddTensorOp>(loc, resultType, matmulResult, bias, alphaOne);

    rewriter.replaceOp(op, addResult);
    return mlir::success();
  }

 private:
  /// Name of the kernel generator this pattern was configured with; compared against "dvm".
  std::string kernelGenerator_;
};

/// Converts mfuse::CastOp -> torch.prims.convert_element_type.
/// The target dtype operand is the torch scalar-type code of the result's element type. The pattern does
/// not match (with a diagnostic) when the result is not a ranked tensor, its type cannot be converted, or
/// the element type has no torch scalar-type code; nothing is emitted in those cases.
struct ConvertMfuseCast : public mlir::OpConversionPattern<mlir::mfuse::CastOp> {
  using OpConversionPattern::OpConversionPattern;

  mlir::LogicalResult matchAndRewrite(mlir::mfuse::CastOp op, OpAdaptor adaptor,
                                      mlir::ConversionPatternRewriter &rewriter) const override {
    auto rankedResultType = mlir::dyn_cast<mlir::RankedTensorType>(op.getResult().getType());
    if (!rankedResultType) {
      return rewriter.notifyMatchFailure(op, "result must be ranked tensor");
    }
    // A null converted type must not be handed to the op builder.
    mlir::Type resultType = getTypeConverter()->convertType(rankedResultType);
    if (!resultType) {
      return rewriter.notifyMatchFailure(op, "result type conversion failed");
    }

    mlir::Value input = adaptor.getInput();
    auto dtypeValOrFailure = buildTorchDtypeValue(rankedResultType.getElementType(), op.getLoc(), rewriter);
    if (mlir::failed(dtypeValOrFailure)) {
      return rewriter.notifyMatchFailure(op, "unsupported dtype for torch scalar type");
    }
    mlir::Value dtypeVal = *dtypeValOrFailure;
    rewriter.replaceOpWithNewOp<TorchD::PrimsConvertElementTypeOp>(op, resultType, input, dtypeVal);
    return mlir::success();
  }
};

/// Converts mfuse.permute -> torch.aten.permute.
/// Only structural checks are done here (perm present, input ranked, one perm entry per input dimension,
/// every entry an integer attribute); whether the entries form a valid permutation is left to upstream
/// passes. Each entry becomes a torch.constant.int and the entries are packed into a !torch.list<int>.
class ConvertMfusePermute : public mlir::OpConversionPattern<mlir::mfuse::PermuteOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  mlir::LogicalResult matchAndRewrite(mlir::mfuse::PermuteOp op, OpAdaptor adaptor,
                                      mlir::ConversionPatternRewriter &rewriter) const override {
    auto perm = op.getPermAttr();
    if (!perm) {
      return rewriter.notifyMatchFailure(op, "perm attribute must be present");
    }

    auto rankedInputTy = mlir::dyn_cast<mlir::RankedTensorType>(op.getInput().getType());
    if (!rankedInputTy) {
      return rewriter.notifyMatchFailure(op, "input must be ranked tensor");
    }

    const auto permEntries = perm.getValue();
    const auto expectedEntries = static_cast<size_t>(rankedInputTy.getRank());
    if (permEntries.size() != expectedEntries) {
      return rewriter.notifyMatchFailure(op, "perm size must match input rank");
    }

    // Turn every perm entry into a torch.constant.int, bailing out on the first non-integer entry.
    llvm::SmallVector<mlir::Value> dimValues;
    dimValues.reserve(permEntries.size());
    for (mlir::Attribute entry : permEntries) {
      if (auto dimAttr = mlir::dyn_cast<mlir::IntegerAttr>(entry)) {
        dimValues.push_back(rewriter.create<TorchD::ConstantIntOp>(op.getLoc(), dimAttr));
        continue;
      }
      return rewriter.notifyMatchFailure(op, "perm values must be integers");
    }

    mlir::Type convertedTy = getTypeConverter()->convertType(op.getResult().getType());
    if (!convertedTy) {
      return mlir::failure();
    }

    mlir::MLIRContext *ctx = op.getContext();
    auto intListTy = TorchD::ListType::get(ctx, TorchD::IntType::get(ctx));
    mlir::Value source = adaptor.getInput();
    mlir::Value permList = rewriter.create<TorchD::PrimListConstructOp>(op.getLoc(), intListTy, dimValues);
    rewriter.replaceOpWithNewOp<TorchD::AtenPermuteOp>(op, convertedTy, source, permList);
    return mlir::success();
  }
};

/// Converts mfuse::ReduceMeanOp -> torch.aten.mean.dim.
/// The reduced dimensions come from the op's `dimensions` attribute and keepdim from its `keepdim` flag.
/// The dtype operand is None when the result element type is a none type, otherwise the torch scalar-type
/// code of that element type. The pattern does not match (with a diagnostic) when the result is not a
/// ranked tensor, its type cannot be converted, or the element type has no torch scalar-type code.
struct ConvertMfuseReduceMean : public mlir::OpConversionPattern<mlir::mfuse::ReduceMeanOp> {
  using OpConversionPattern::OpConversionPattern;

  mlir::LogicalResult matchAndRewrite(mlir::mfuse::ReduceMeanOp op, OpAdaptor adaptor,
                                      mlir::ConversionPatternRewriter &rewriter) const override {
    auto rankedResultType = mlir::dyn_cast<mlir::RankedTensorType>(op.getResult().getType());
    if (!rankedResultType) {
      return rewriter.notifyMatchFailure(op, "result must be ranked tensor");
    }
    // A null converted type must not be handed to the op builder.
    mlir::Type resultType = getTypeConverter()->convertType(rankedResultType);
    if (!resultType) {
      return rewriter.notifyMatchFailure(op, "result type conversion failed");
    }

    mlir::Value input = adaptor.getInput();
    mlir::Value dimList = buildTorchIntListFromI64ArrayAttr(op.getDimensions(), op.getLoc(), rewriter);

    bool keepdim = op.getKeepdim();
    mlir::Value keepdimVal = rewriter.create<TorchD::ConstantBoolOp>(op.getLoc(), keepdim);

    mlir::Type dtypeType = rankedResultType.getElementType();
    mlir::Value dtypeVal;
    if (mlir::isa<mlir::NoneType, mlir::mfuse::NoneType, TorchD::NoneType>(dtypeType)) {
      dtypeVal = rewriter.create<TorchD::ConstantNoneOp>(op.getLoc());
    } else {
      auto dtypeValOrFailure = buildTorchDtypeValue(dtypeType, op.getLoc(), rewriter);
      if (mlir::failed(dtypeValOrFailure)) {
        return rewriter.notifyMatchFailure(op, "unsupported dtype for torch scalar type");
      }
      dtypeVal = *dtypeValOrFailure;
    }

    rewriter.replaceOpWithNewOp<TorchD::AtenMeanDimOp>(op, resultType, input, dimList, keepdimVal, dtypeVal);
    return mlir::success();
  }
};

/// Converts mfuse::ReduceSumOp -> torch.aten.sum.dim_IntList.
/// The reduced dimensions come from the op's `dimensions` attribute and keepdim from its `keepdim` flag.
/// The dtype operand is None when the result element type is a none type, otherwise the torch scalar-type
/// code of that element type. The pattern does not match (with a diagnostic) when the result is not a
/// ranked tensor, its type cannot be converted, or the element type has no torch scalar-type code.
struct ConvertMfuseReduceSum : public mlir::OpConversionPattern<mlir::mfuse::ReduceSumOp> {
  using OpConversionPattern::OpConversionPattern;

  mlir::LogicalResult matchAndRewrite(mlir::mfuse::ReduceSumOp op, OpAdaptor adaptor,
                                      mlir::ConversionPatternRewriter &rewriter) const override {
    auto rankedResultType = mlir::dyn_cast<mlir::RankedTensorType>(op.getResult().getType());
    if (!rankedResultType) {
      return rewriter.notifyMatchFailure(op, "result must be ranked tensor");
    }
    // A null converted type must not be handed to the op builder.
    mlir::Type resultType = getTypeConverter()->convertType(rankedResultType);
    if (!resultType) {
      return rewriter.notifyMatchFailure(op, "result type conversion failed");
    }

    mlir::Value input = adaptor.getInput();
    mlir::Value dimList = buildTorchIntListFromI64ArrayAttr(op.getDimensions(), op.getLoc(), rewriter);

    bool keepdim = op.getKeepdim();
    mlir::Value keepdimVal = rewriter.create<TorchD::ConstantBoolOp>(op.getLoc(), keepdim);

    mlir::Type dtypeType = rankedResultType.getElementType();
    mlir::Value dtypeVal;
    if (mlir::isa<mlir::NoneType, mlir::mfuse::NoneType, TorchD::NoneType>(dtypeType)) {
      dtypeVal = rewriter.create<TorchD::ConstantNoneOp>(op.getLoc());
    } else {
      auto dtypeValOrFailure = buildTorchDtypeValue(dtypeType, op.getLoc(), rewriter);
      if (mlir::failed(dtypeValOrFailure)) {
        return rewriter.notifyMatchFailure(op, "unsupported dtype for torch scalar type");
      }
      dtypeVal = *dtypeValOrFailure;
    }

    rewriter.replaceOpWithNewOp<TorchD::AtenSumDimIntListOp>(op, resultType, input, dimList, keepdimVal, dtypeVal);
    return mlir::success();
  }
};

/// Converts mfuse.reshape -> torch.aten.reshape.
/// The target shape list is taken from the reshape's ranked result type, one torch.constant.int per
/// extent; a dynamic extent is written as -1.
class ConvertMfuseReshape : public mlir::OpConversionPattern<mlir::mfuse::ReshapeOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  mlir::LogicalResult matchAndRewrite(mlir::mfuse::ReshapeOp op, OpAdaptor adaptor,
                                      mlir::ConversionPatternRewriter &rewriter) const override {
    auto rankedResultTy = mlir::dyn_cast<mlir::RankedTensorType>(op.getResult().getType());
    if (!rankedResultTy) {
      return rewriter.notifyMatchFailure(op, "result must be ranked tensor");
    }

    const auto targetShape = rankedResultTy.getShape();
    llvm::SmallVector<mlir::Value> extentValues;
    extentValues.reserve(targetShape.size());
    for (const int64_t d : targetShape) {
      // Dynamic extents have no concrete value; -1 is emitted in their place.
      extentValues.push_back(
        rewriter.create<TorchD::ConstantIntOp>(op.getLoc(), d == mlir::ShapedType::kDynamic ? -1 : d));
    }

    mlir::Type convertedTy = getTypeConverter()->convertType(rankedResultTy);
    if (!convertedTy) {
      return mlir::failure();
    }

    mlir::MLIRContext *ctx = op.getContext();
    auto intListTy = TorchD::ListType::get(ctx, TorchD::IntType::get(ctx));
    mlir::Value source = adaptor.getInput();
    auto shapeList = rewriter.create<TorchD::PrimListConstructOp>(op.getLoc(), intListTy, extentValues);
    rewriter.replaceOpWithNewOp<TorchD::AtenReshapeOp>(op, convertedTy, source, shapeList);
    return mlir::success();
  }
};

// Shared front half of the binary op conversions.
// On success returns the tuple (lhs, rhs, isRhsScalar, torchResultType) where lhs/rhs are the two converted
// operands, isRhsScalar tells whether rhs has torch float or torch int type, and torchResultType is the
// converted result type of `op`.
// Fails when the adaptor does not carry exactly two operands (with a diagnostic) or when the result type
// cannot be converted by the type converter of `pattern`.
template <typename SourceOp>
static mlir::FailureOr<std::tuple<mlir::Value, mlir::Value, bool, mlir::Type>> convertBinaryOpCommon(
  SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter,
  const OpConversionPattern<SourceOp> *pattern) {
  constexpr size_t kExpectedOperandCount = 2;
  auto convertedOperands = adaptor.getOperands();
  if (convertedOperands.size() != kExpectedOperandCount) {
    return rewriter.notifyMatchFailure(op, "binary op must have 2 operands");
  }

  mlir::Type convertedResultTy = pattern->getTypeConverter()->convertType(op.getResult().getType());
  if (!convertedResultTy) {
    return mlir::failure();
  }

  mlir::Value left = convertedOperands[0];
  mlir::Value right = convertedOperands[1];
  // A torch float/int on the right-hand side selects the tensor-scalar variant in the callers.
  const bool rightIsScalar = mlir::isa<TorchD::FloatType, TorchD::IntType>(right.getType());
  return std::make_tuple(left, right, rightIsScalar, convertedResultTy);
}

// Converts a binary SourceOp to TargetScalarOp when its right-hand operand is a torch scalar (float/int)
// and to TargetTensorOp otherwise. Used for the torch ops that take no alpha parameter.
template <typename SourceOp, typename TargetTensorOp, typename TargetScalarOp>
class ConvertBinaryOpPattern : public OpConversionPattern<SourceOp> {
 public:
  using OpConversionPattern<SourceOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
                                ConversionPatternRewriter &rewriter) const override {
    auto prepared = convertBinaryOpCommon(op, adaptor, rewriter, this);
    if (failed(prepared)) {
      return failure();
    }
    auto [left, right, rightIsScalar, convertedResultTy] = *prepared;

    if (rightIsScalar) {
      rewriter.replaceOpWithNewOp<TargetScalarOp>(op, convertedResultTy, left, right);
      return mlir::success();
    }
    rewriter.replaceOpWithNewOp<TargetTensorOp>(op, convertedResultTy, left, right);
    return mlir::success();
  }
};

// Converts a binary SourceOp to TargetScalarOp when its right-hand operand is a torch scalar (float/int)
// and to TargetTensorOp otherwise, passing an extra alpha operand fixed to the integer constant 1.
// Used for the torch ops that take an alpha parameter (add and sub).
template <typename SourceOp, typename TargetTensorOp, typename TargetScalarOp>
class ConvertBinaryOpWithAlphaPattern : public OpConversionPattern<SourceOp> {
 public:
  using OpConversionPattern<SourceOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
                                ConversionPatternRewriter &rewriter) const override {
    auto prepared = convertBinaryOpCommon(op, adaptor, rewriter, this);
    if (failed(prepared)) {
      return failure();
    }
    auto [left, right, rightIsScalar, convertedResultTy] = *prepared;

    // Add and sub have alpha input; it is always 1 here.
    constexpr int64_t kUnitAlpha = 1;
    auto unitAlpha = rewriter.create<TorchD::ConstantIntOp>(op->getLoc(), rewriter.getI64IntegerAttr(kUnitAlpha));

    if (rightIsScalar) {
      rewriter.replaceOpWithNewOp<TargetScalarOp>(op, convertedResultTy, left, right, unitAlpha);
      return mlir::success();
    }
    rewriter.replaceOpWithNewOp<TargetTensorOp>(op, convertedResultTy, left, right, unitAlpha);
    return mlir::success();
  }
};

// ============================================================================
// =   Please keep the patterns in alphabetical order by operator name   =
// ============================================================================

/// Registers the hand-written mfuse -> torch conversion patterns of this file in \p patterns.
/// The two matmul patterns additionally receive \p kernelGenerator; all others are constructed from
/// \p converter and the pattern set's context only.
static void populateMfuseMetaToTorchCustomPatterns(TypeConverter &converter, RewritePatternSet &patterns,
                                                   llvm::StringRef kernelGenerator) {
  MLIRContext *ctx = patterns.getContext();
  // Registration order is kept as: broadcast_to, cast, conv2d, conv2d_with_bias, full, ...
  patterns.add<ConvertMfuseBroadcastTo, ConvertMfuseCast, ConvertMfuseAclnnConv2D, ConvertMfuseAclnnConv2DWithBias,
               ConvertMfuseFull>(converter, ctx);
  // ... matmul, matmul_with_bias (kernel-generator aware) ...
  patterns.add<ConvertMfuseMatmul, ConvertMfuseMatmulWithBias>(converter, ctx, kernelGenerator);
  // ... permute, reduce_mean, reduce_sum, reshape.
  patterns.add<ConvertMfusePermute, ConvertMfuseReduceMean, ConvertMfuseReduceSum, ConvertMfuseReshape>(converter,
                                                                                                        ctx);
}

// Registers the conversions for the mfuse binary ops that accept either a tensor or a scalar right-hand
// operand; each one is mapped to the matching torch Tensor / Scalar op pair.
static void populateMfusePartBinaryOpToTorchTensorScalarPatterns(TypeConverter &converter,
                                                                 RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();

  // add and sub take an alpha parameter, hence ConvertBinaryOpWithAlphaPattern.
  patterns.add<ConvertBinaryOpWithAlphaPattern<mfuse::AddOp, TorchD::AtenAddTensorOp, TorchD::AtenAddScalarOp>>(
    converter, context);
  patterns.add<ConvertBinaryOpWithAlphaPattern<mfuse::SubOp, TorchD::AtenSubTensorOp, TorchD::AtenSubScalarOp>>(
    converter, context);

  // The remaining ops take no alpha parameter, hence ConvertBinaryOpPattern.
  patterns.add<ConvertBinaryOpPattern<mfuse::DivOp, TorchD::AtenDivTensorOp, TorchD::AtenDivScalarOp>>(converter,
                                                                                                       context);
  patterns.add<ConvertBinaryOpPattern<mfuse::EqOp, TorchD::AtenEqTensorOp, TorchD::AtenEqScalarOp>>(converter,
                                                                                                    context);
  patterns.add<ConvertBinaryOpPattern<mfuse::GeOp, TorchD::AtenGeTensorOp, TorchD::AtenGeScalarOp>>(converter,
                                                                                                    context);
  patterns.add<ConvertBinaryOpPattern<mfuse::GtOp, TorchD::AtenGtTensorOp, TorchD::AtenGtScalarOp>>(converter,
                                                                                                    context);
  patterns.add<ConvertBinaryOpPattern<mfuse::LeOp, TorchD::AtenLeTensorOp, TorchD::AtenLeScalarOp>>(converter,
                                                                                                    context);
  patterns.add<ConvertBinaryOpPattern<mfuse::LtOp, TorchD::AtenLtTensorOp, TorchD::AtenLtScalarOp>>(converter,
                                                                                                    context);
  patterns.add<ConvertBinaryOpPattern<mfuse::MulOp, TorchD::AtenMulTensorOp, TorchD::AtenMulScalarOp>>(converter,
                                                                                                       context);
  patterns.add<ConvertBinaryOpPattern<mfuse::NeOp, TorchD::AtenNeTensorOp, TorchD::AtenNeScalarOp>>(converter,
                                                                                                    context);
  patterns.add<ConvertBinaryOpPattern<mfuse::PowOp, TorchD::AtenPowTensorTensorOp, TorchD::AtenPowTensorScalarOp>>(
    converter, context);
}
}  // namespace

/// Public entry point: adds every mfuse -> torch conversion pattern defined in this file to \p patterns.
/// \p kernelGenerator is forwarded to the custom patterns, where the matmul conversions use it.
void populateMfuseMetaToTorchConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns,
                                                llvm::StringRef kernelGenerator) {
  // Hand-written per-op patterns first ...
  populateMfuseMetaToTorchCustomPatterns(converter, patterns, kernelGenerator);
  // ... followed by the templated tensor/scalar binary op patterns.
  populateMfusePartBinaryOpToTorchTensorScalarPatterns(converter, patterns);
}

}  // namespace mlir