/**
 * Copyright 2026 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "mfusion/Dialect/Mfuse/IR/Mfuse.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/TypeID.h"

namespace {

/// UT-only pass that materializes an `mfuse.reshape` through the C++ builder
/// (rather than from parsed IR), per its description "to trigger symbolic
/// inference".
///
/// The pass only acts on a `func.func` that:
///   * carries the attribute `mfuse.ut_create_reshape = true`,
///   * has a body (external declarations are skipped), and
///   * whose first argument is a rank-2 `RankedTensorType` shaped `D0 x ?`
///     (static dim 0, dynamic dim 1).
/// For such a function it inserts, immediately before the entry block's
/// terminator, a reshape of that argument:
///   tensor<D0x?xT>  ->  tensor<?xD0xT>
/// The created op's result is intentionally left unused. Any function not
/// matching the conditions above is left untouched; the pass never fails.
struct MfuseUtCreateReshapePass
    : public mlir::PassWrapper<MfuseUtCreateReshapePass, mlir::OperationPass<mlir::func::FuncOp>> {
  // cppcheck-suppress unknownMacro
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MfuseUtCreateReshapePass)

  mlir::StringRef getArgument() const final { return "mfuse-ut-create-reshape"; }
  mlir::StringRef getDescription() const final {
    return "UT-only pass: create mfuse.reshape via C++ builder to trigger symbolic inference";
  }

  void getDependentDialects(mlir::DialectRegistry &registry) const override {
    // mfuse provides the op this pass creates; func is the anchor op's dialect.
    registry.insert<mlir::mfuse::MfuseDialect>();
    registry.insert<mlir::func::FuncDialect>();
  }

  void runOnOperation() override {
    mlir::func::FuncOp func = getOperation();

    // Opt-in only: the test must explicitly tag the function.
    auto marker = func->getAttrOfType<mlir::BoolAttr>("mfuse.ut_create_reshape");
    if (!marker || !marker.getValue()) {
      return;
    }
    // A declaration has no entry block. Its argument count comes from the
    // function type, so it may be non-zero, but getArgument()/getBody().front()
    // would then read from an empty region. Bail out before touching the body.
    if (func.isExternal() || func.getNumArguments() == 0) {
      return;
    }

    mlir::Value input = func.getArgument(0);
    auto inputType = mlir::dyn_cast<mlir::RankedTensorType>(input.getType());
    if (!inputType || inputType.getRank() != 2) {
      return;
    }
    const int64_t d0 = inputType.getDimSize(0);
    const int64_t d1 = inputType.getDimSize(1);
    // Only the `D0 x ?` layout is handled: static dim 0, dynamic dim 1.
    if (mlir::ShapedType::isDynamic(d0) || !mlir::ShapedType::isDynamic(d1)) {
      return;
    }

    // Result type tensor<?xD0xT>: the leading dim is left dynamic and the
    // trailing dim is the input's static D0.
    auto outType = mlir::RankedTensorType::get({mlir::ShapedType::kDynamic, d0}, inputType.getElementType());

    // Insert right before the entry block's terminator so the op sees `input`.
    mlir::Block &entry = func.getBody().front();
    mlir::OpBuilder builder(entry.getTerminator());
    (void)builder.create<mlir::mfuse::ReshapeOp>(func.getLoc(), outType, input);
  }
};

// Static-storage registrar: its constructor registers MfuseUtCreateReshapePass
// (command-line name "mfuse-ut-create-reshape") with the global pass registry.
// `static` is unnecessary here because the enclosing anonymous namespace
// already gives the object internal linkage.
mlir::PassRegistration<MfuseUtCreateReshapePass> utCreateReshapePassRegistration;
}  // namespace