//===- TransformOps.cpp - Transform dialect operations --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Transform/IR/TransformOps.h"

#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/CSE.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include <optional>

#define DEBUG_TYPE "transform-dialect"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")

#define DEBUG_TYPE_MATCHER "transform-matcher"
#define DBGS_MATCHER() (llvm::dbgs() << "[" DEBUG_TYPE_MATCHER "] ")
#define DEBUG_MATCHER(x) DEBUG_WITH_TYPE(DEBUG_TYPE_MATCHER, x)

using namespace mlir;

static ParseResult parseSequenceOpOperands(
    OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
    Type &rootType,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
    SmallVectorImpl<Type> &extraBindingTypes);
static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op,
                                    Value root, Type rootType,
                                    ValueRange extraBindings,
                                    TypeRange extraBindingTypes);
static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op,
                                     ArrayAttr matchers, ArrayAttr actions);
static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
                                            ArrayAttr &matchers,
                                            ArrayAttr &actions);

/// Helper function to check if the given transform op is contained in (or
/// equal to) the given payload target op. In that case, an error is returned.
/// Transforming transform IR that is currently executing is generally unsafe.
static DiagnosedSilenceableFailure
ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform,
                                     Operation *payload) {
  Operation *transformAncestor = transform.getOperation();
  while (transformAncestor) {
    if (transformAncestor == payload) {
      DiagnosedDefiniteFailure diag =
          transform.emitDefiniteFailure()
          << "cannot apply transform to itself (or one of its ancestors)";
      diag.attachNote(payload->getLoc()) << "target payload op";
      return diag;
    }
    transformAncestor = transformAncestor->getParentOp();
  }
  return DiagnosedSilenceableFailure::success();
}

#define GET_OP_CLASSES
#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"

//===----------------------------------------------------------------------===//
// AlternativesOp
//===----------------------------------------------------------------------===//

OperandRange
transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) {
  if (!point.isParent() && getOperation()->getNumOperands() == 1)
    return getOperation()->getOperands();
  return OperandRange(getOperation()->operand_end(),
                      getOperation()->operand_end());
}

void transform::AlternativesOp::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
  for (Region &alternative : llvm::drop_begin(
           getAlternatives(),
           point.isParent() ? 0
                            : point.getRegionOrNull()->getRegionNumber() + 1)) {
    regions.emplace_back(&alternative, !getOperands().empty()
                                           ? alternative.getArguments()
                                           : Block::BlockArgListType());
  }
  if (!point.isParent())
    regions.emplace_back(getOperation()->getResults());
}

void transform::AlternativesOp::getRegionInvocationBounds(
    ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
  (void)operands;
  // The region corresponding to the first alternative is always executed, the
  // remaining may or may not be executed.
  bounds.reserve(getNumRegions());
  bounds.emplace_back(1, 1);
  bounds.resize(getNumRegions(), InvocationBounds(0, 1));
}

static void forwardEmptyOperands(Block *block, transform::TransformState &state,
                                 transform::TransformResults &results) {
  for (const auto &res : block->getParentOp()->getOpResults())
    results.set(res, {});
}

DiagnosedSilenceableFailure
transform::AlternativesOp::apply(transform::TransformRewriter &rewriter,
                                 transform::TransformResults &results,
                                 transform::TransformState &state) {
  SmallVector<Operation *> originals;
  if (Value scopeHandle = getScope())
    llvm::append_range(originals, state.getPayloadOps(scopeHandle));
  else
    originals.push_back(state.getTopLevel());

  for (Operation *original : originals) {
    if (original->isAncestor(getOperation())) {
      auto diag = emitDefiniteFailure()
                  << "scope must not contain the transforms being applied";
      diag.attachNote(original->getLoc()) << "scope";
      return diag;
    }
    if (!original->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
      auto diag = emitDefiniteFailure()
                  << "only isolated-from-above ops can be alternative scopes";
      diag.attachNote(original->getLoc()) << "scope";
      return diag;
    }
  }

  for (Region &reg : getAlternatives()) {
    // Clone the scope operations and make the transforms in this alternative
    // region apply to them by virtue of mapping the block argument (the only
    // visible handle) to the cloned scope operations. This effectively prevents
    // the transformation from accessing any IR outside the scope.
    auto scope = state.make_region_scope(reg);
    auto clones = llvm::to_vector(
        llvm::map_range(originals, [](Operation *op) { return op->clone(); }));
    auto deleteClones = llvm::make_scope_exit([&] {
      for (Operation *clone : clones)
        clone->erase();
    });
    if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
      return DiagnosedSilenceableFailure::definiteFailure();

    bool failed = false;
    for (Operation &transform : reg.front().without_terminator()) {
      DiagnosedSilenceableFailure result =
          state.applyTransform(cast<TransformOpInterface>(transform));
      if (result.isSilenceableFailure()) {
        LLVM_DEBUG(DBGS() << "alternative failed: " << result.getMessage()
                          << "\n");
        failed = true;
        break;
      }

      if (::mlir::failed(result.silence()))
        return DiagnosedSilenceableFailure::definiteFailure();
    }

    // If all operations in the given alternative succeeded, no need to consider
    // the rest. Replace the original scoping operation with the clone on which
    // the transformations were performed.
    if (!failed) {
      // We will be using the clones, so cancel their scheduled deletion.
      deleteClones.release();
      TrackingListener listener(state, *this);
      IRRewriter rewriter(getContext(), &listener);
      for (const auto &kvp : llvm::zip(originals, clones)) {
        Operation *original = std::get<0>(kvp);
        Operation *clone = std::get<1>(kvp);
        original->getBlock()->getOperations().insert(original->getIterator(),
                                                     clone);
        rewriter.replaceOp(original, clone->getResults());
      }
      detail::forwardTerminatorOperands(&reg.front(), state, results);
      return DiagnosedSilenceableFailure::success();
    }
  }
  return emitSilenceableError() << "all alternatives failed";
}

void transform::AlternativesOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  consumesHandle(getOperation()->getOpOperands(), effects);
  producesHandle(getOperation()->getOpResults(), effects);
  for (Region *region : getRegions()) {
    if (!region->empty())
      producesHandle(region->front().getArguments(), effects);
  }
  modifiesPayload(effects);
}

LogicalResult transform::AlternativesOp::verify() {
  for (Region &alternative : getAlternatives()) {
    Block &block = alternative.front();
    Operation *terminator = block.getTerminator();
    if (terminator->getOperands().getTypes() != getResults().getTypes()) {
      InFlightDiagnostic diag = emitOpError()
                                << "expects terminator operands to have the "
                                   "same type as results of the operation";
      diag.attachNote(terminator->getLoc()) << "terminator";
      return diag;
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// AnnotateOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::AnnotateOp::apply(transform::TransformRewriter &rewriter,
                             transform::TransformResults &results,
                             transform::TransformState &state) {
  SmallVector<Operation *> targets =
      llvm::to_vector(state.getPayloadOps(getTarget()));

  Attribute attr = UnitAttr::get(getContext());
  if (auto paramH = getParam()) {
    ArrayRef<Attribute> params = state.getParams(paramH);
    if (params.size() != 1) {
      if (targets.size() != params.size()) {
        return emitSilenceableError()
               << "parameter and target have different payload lengths ("
               << params.size() << " vs " << targets.size() << ")";
      }
      for (auto &&[target, attr] : llvm::zip_equal(targets, params))
        target->setAttr(getName(), attr);
      return DiagnosedSilenceableFailure::success();
    }
    attr = params[0];
  }
  for (auto *target : targets)
    target->setAttr(getName(), attr);
  return DiagnosedSilenceableFailure::success();
}

void transform::AnnotateOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  onlyReadsHandle(getTargetMutable(), effects);
  onlyReadsHandle(getParamMutable(), effects);
  modifiesPayload(effects);
}

//===----------------------------------------------------------------------===//
// ApplyCommonSubexpressionEliminationOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::ApplyCommonSubexpressionEliminationOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    ApplyToEachResultList &results, transform::TransformState &state) {
  // Make sure that this transform is not applied to itself. Modifying the
  // transform IR while it is being interpreted is generally dangerous.
  DiagnosedSilenceableFailure payloadCheck =
      ensurePayloadIsSeparateFromTransform(*this, target);
  if (!payloadCheck.succeeded())
    return payloadCheck;

  DominanceInfo domInfo;
  mlir::eliminateCommonSubExpressions(rewriter, domInfo, target);
  return DiagnosedSilenceableFailure::success();
}

void transform::ApplyCommonSubexpressionEliminationOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  transform::onlyReadsHandle(getTargetMutable(), effects);
  transform::modifiesPayload(effects);
}

//===----------------------------------------------------------------------===//
// ApplyDeadCodeEliminationOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure transform::ApplyDeadCodeEliminationOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    ApplyToEachResultList &results, transform::TransformState &state) {
  // Make sure that this transform is not applied to itself. Modifying the
  // transform IR while it is being interpreted is generally dangerous.
  DiagnosedSilenceableFailure payloadCheck =
      ensurePayloadIsSeparateFromTransform(*this, target);
  if (!payloadCheck.succeeded())
    return payloadCheck;

  // Maintain a worklist of potentially dead ops.
  SetVector<Operation *> worklist;

  // Helper function that adds all defining ops of used values (operands and
  // operands of nested ops).
  auto addDefiningOpsToWorklist = [&](Operation *op) {
    op->walk([&](Operation *op) {
      for (Value v : op->getOperands())
        if (Operation *defOp = v.getDefiningOp())
          if (target->isProperAncestor(defOp))
            worklist.insert(defOp);
    });
  };

  // Helper function that erases an op.
  auto eraseOp = [&](Operation *op) {
    // Remove op and nested ops from the worklist.
    op->walk([&](Operation *op) {
      const auto *it = llvm::find(worklist, op);
      if (it != worklist.end())
        worklist.erase(it);
    });
    rewriter.eraseOp(op);
  };

  // Initial walk over the IR.
  target->walk<WalkOrder::PostOrder>([&](Operation *op) {
    if (op != target && isOpTriviallyDead(op)) {
      addDefiningOpsToWorklist(op);
      eraseOp(op);
    }
  });

  // Erase all ops that have become dead.
  while (!worklist.empty()) {
    Operation *op = worklist.pop_back_val();
    if (!isOpTriviallyDead(op))
      continue;
    addDefiningOpsToWorklist(op);
    eraseOp(op);
  }

  return DiagnosedSilenceableFailure::success();
}

void transform::ApplyDeadCodeEliminationOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  transform::onlyReadsHandle(getTargetMutable(), effects);
  transform::modifiesPayload(effects);
}

//===----------------------------------------------------------------------===//
// ApplyPatternsOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    ApplyToEachResultList &results, transform::TransformState &state) {
  // Make sure that this transform is not applied to itself. Modifying the
  // transform IR while it is being interpreted is generally dangerous. Even
  // more so for the ApplyPatternsOp because the GreedyPatternRewriteDriver
  // performs many additional simplifications such as dead code elimination.
  DiagnosedSilenceableFailure payloadCheck =
      ensurePayloadIsSeparateFromTransform(*this, target);
  if (!payloadCheck.succeeded())
    return payloadCheck;

  // Gather all specified patterns.
  MLIRContext *ctx = target->getContext();
  RewritePatternSet patterns(ctx);
  if (!getRegion().empty()) {
    for (Operation &op : getRegion().front()) {
      cast<transform::PatternDescriptorOpInterface>(&op)
          .populatePatternsWithState(patterns, state);
    }
  }

  // Configure the GreedyPatternRewriteDriver.
  GreedyRewriteConfig config;
  config.listener =
      static_cast<RewriterBase::Listener *>(rewriter.getListener());
  FrozenRewritePatternSet frozenPatterns(std::move(patterns));

  config.maxIterations = getMaxIterations() == static_cast<uint64_t>(-1)
                             ? GreedyRewriteConfig::kNoLimit
                             : getMaxIterations();
  config.maxNumRewrites = getMaxNumRewrites() == static_cast<uint64_t>(-1)
                              ? GreedyRewriteConfig::kNoLimit
                              : getMaxNumRewrites();

  // Apply patterns and CSE repetitively until a fixpoint is reached. If no CSE
  // was requested, apply the greedy pattern rewrite only once. (The greedy
  // pattern rewrite driver already iterates to a fixpoint internally.)
  bool cseChanged = false;
  // One or two iterations should be sufficient. Stop iterating after a certain
  // threshold to make debugging easier.
  static const int64_t kNumMaxIterations = 50;
  int64_t iteration = 0;
  do {
    LogicalResult result = failure();
    if (target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
      // Op is isolated from above. Apply patterns and also perform region
      // simplification.
      result = applyPatternsAndFoldGreedily(target, frozenPatterns, config);
    } else {
      // Manually gather list of ops because the other
      // GreedyPatternRewriteDriver overloads only accepts ops that are isolated
      // from above. This way, patterns can be applied to ops that are not
      // isolated from above. Regions are not being simplified. Furthermore,
      // only a single greedy rewrite iteration is performed.
      SmallVector<Operation *> ops;
      target->walk([&](Operation *nestedOp) {
        if (target != nestedOp)
          ops.push_back(nestedOp);
      });
      result = applyOpPatternsAndFold(ops, frozenPatterns, config);
    }

    // A failure typically indicates that the pattern application did not
    // converge.
    if (failed(result)) {
      return emitSilenceableFailure(target)
             << "greedy pattern application failed";
    }

    if (getApplyCse()) {
      DominanceInfo domInfo;
      mlir::eliminateCommonSubExpressions(rewriter, domInfo, target,
                                          &cseChanged);
    }
  } while (cseChanged && ++iteration < kNumMaxIterations);

  if (iteration == kNumMaxIterations)
    return emitDefiniteFailure() << "fixpoint iteration did not converge";

  return DiagnosedSilenceableFailure::success();
}

LogicalResult transform::ApplyPatternsOp::verify() {
  if (!getRegion().empty()) {
    for (Operation &op : getRegion().front()) {
      if (!isa<transform::PatternDescriptorOpInterface>(&op)) {
        InFlightDiagnostic diag = emitOpError()
                                  << "expected children ops to implement "
                                     "PatternDescriptorOpInterface";
        diag.attachNote(op.getLoc()) << "op without interface";
        return diag;
      }
    }
  }
  return success();
}

void transform::ApplyPatternsOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  transform::onlyReadsHandle(getTargetMutable(), effects);
  transform::modifiesPayload(effects);
}

void transform::ApplyPatternsOp::build(
    OpBuilder &builder, OperationState &result, Value target,
    function_ref<void(OpBuilder &, Location)> bodyBuilder) {
  result.addOperands(target);

  OpBuilder::InsertionGuard g(builder);
  Region *region = result.addRegion();
  builder.createBlock(region);
  if (bodyBuilder)
    bodyBuilder(builder, result.location);
}

//===----------------------------------------------------------------------===//
// ApplyCanonicalizationPatternsOp
//===----------------------------------------------------------------------===//

void transform::ApplyCanonicalizationPatternsOp::populatePatterns(
    RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  for (Dialect *dialect : ctx->getLoadedDialects())
    dialect->getCanonicalizationPatterns(patterns);
  for (RegisteredOperationName op : ctx->getRegisteredOperations())
    op.getCanonicalizationPatterns(patterns, ctx);
}

//===----------------------------------------------------------------------===//
// ApplyConversionPatternsOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  MLIRContext *ctx = getContext();

  // Instantiate the default type converter if a type converter builder is
  // specified.
  std::unique_ptr<TypeConverter> defaultTypeConverter;
  transform::TypeConverterBuilderOpInterface typeConverterBuilder =
      getDefaultTypeConverter();
  if (typeConverterBuilder)
    defaultTypeConverter = typeConverterBuilder.getTypeConverter();

  // Configure conversion target.
  ConversionTarget conversionTarget(*getContext());
  if (getLegalOps())
    for (Attribute attr : cast<ArrayAttr>(*getLegalOps()))
      conversionTarget.addLegalOp(
          OperationName(cast<StringAttr>(attr).getValue(), ctx));
  if (getIllegalOps())
    for (Attribute attr : cast<ArrayAttr>(*getIllegalOps()))
      conversionTarget.addIllegalOp(
          OperationName(cast<StringAttr>(attr).getValue(), ctx));
  if (getLegalDialects())
    for (Attribute attr : cast<ArrayAttr>(*getLegalDialects()))
      conversionTarget.addLegalDialect(cast<StringAttr>(attr).getValue());
  if (getIllegalDialects())
    for (Attribute attr : cast<ArrayAttr>(*getIllegalDialects()))
      conversionTarget.addIllegalDialect(cast<StringAttr>(attr).getValue());

  // Gather all specified patterns.
  RewritePatternSet patterns(ctx);
  // Need to keep the converters alive until after pattern application because
  // the patterns take a reference to an object that would otherwise get out of
  // scope.
  SmallVector<std::unique_ptr<TypeConverter>> keepAliveConverters;
  if (!getPatterns().empty()) {
    for (Operation &op : getPatterns().front()) {
      auto descriptor =
          cast<transform::ConversionPatternDescriptorOpInterface>(&op);

      // Check if this pattern set specifies a type converter.
      std::unique_ptr<TypeConverter> typeConverter =
          descriptor.getTypeConverter();
      TypeConverter *converter = nullptr;
      if (typeConverter) {
        keepAliveConverters.emplace_back(std::move(typeConverter));
        converter = keepAliveConverters.back().get();
      } else {
        // No type converter specified: Use the default type converter.
        if (!defaultTypeConverter) {
          auto diag = emitDefiniteFailure()
                      << "pattern descriptor does not specify type "
                         "converter and apply_conversion_patterns op has "
                         "no default type converter";
          diag.attachNote(op.getLoc()) << "pattern descriptor op";
          return diag;
        }
        converter = defaultTypeConverter.get();
      }

      // Add descriptor-specific updates to the conversion target, which may
      // depend on the final type converter. In structural converters, the
      // legality of types dictates the dynamic legality of an operation.
      descriptor.populateConversionTargetRules(*converter, conversionTarget);

      descriptor.populatePatterns(*converter, patterns);
    }
  }

  // Attach a tracking listener if handles should be preserved. We configure the
  // listener to allow op replacements with different names, as conversion
  // patterns typically replace ops with replacement ops that have a different
  // name.
  TrackingListenerConfig trackingConfig;
  trackingConfig.requireMatchingReplacementOpName = false;
  ErrorCheckingTrackingListener trackingListener(state, *this, trackingConfig);
  ConversionConfig conversionConfig;
  if (getPreserveHandles())
    conversionConfig.listener = &trackingListener;

  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
  for (Operation *target : state.getPayloadOps(getTarget())) {
    // Make sure that this transform is not applied to itself. Modifying the
    // transform IR while it is being interpreted is generally dangerous.
    DiagnosedSilenceableFailure payloadCheck =
        ensurePayloadIsSeparateFromTransform(*this, target);
    if (!payloadCheck.succeeded())
      return payloadCheck;

    LogicalResult status = failure();
    if (getPartialConversion()) {
      status = applyPartialConversion(target, conversionTarget, frozenPatterns,
                                      conversionConfig);
    } else {
      status = applyFullConversion(target, conversionTarget, frozenPatterns,
                                   conversionConfig);
    }

    // Check dialect conversion state.
    DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
    if (failed(status)) {
      diag = emitSilenceableError() << "dialect conversion failed";
      diag.attachNote(target->getLoc()) << "target op";
    }

    // Check tracking listener error state.
    DiagnosedSilenceableFailure trackingFailure =
        trackingListener.checkAndResetError();
    if (!trackingFailure.succeeded()) {
      if (diag.succeeded()) {
        // Tracking failure is the only failure.
        return trackingFailure;
      } else {
        diag.attachNote() << "tracking listener also failed: "
                          << trackingFailure.getMessage();
        (void)trackingFailure.silence();
      }
    }

    if (!diag.succeeded())
      return diag;
  }

  return DiagnosedSilenceableFailure::success();
}

LogicalResult transform::ApplyConversionPatternsOp::verify() {
  if (getNumRegions() != 1 && getNumRegions() != 2)
    return emitOpError() << "expected 1 or 2 regions";
  if (!getPatterns().empty()) {
    for (Operation &op : getPatterns().front()) {
      if (!isa<transform::ConversionPatternDescriptorOpInterface>(&op)) {
        InFlightDiagnostic diag =
            emitOpError() << "expected pattern children ops to implement "
                             "ConversionPatternDescriptorOpInterface";
        diag.attachNote(op.getLoc()) << "op without interface";
        return diag;
      }
    }
  }
  if (getNumRegions() == 2) {
    Region &typeConverterRegion = getRegion(1);
    if (!llvm::hasSingleElement(typeConverterRegion.front()))
      return emitOpError()
             << "expected exactly one op in default type converter region";
    Operation *maybeTypeConverter = &typeConverterRegion.front().front();
    auto typeConverterOp = dyn_cast<transform::TypeConverterBuilderOpInterface>(
        maybeTypeConverter);
    if (!typeConverterOp) {
      InFlightDiagnostic diag = emitOpError()
                                << "expected default converter child op to "
                                   "implement TypeConverterBuilderOpInterface";
      diag.attachNote(maybeTypeConverter->getLoc()) << "op without interface";
      return diag;
    }
    // Check default type converter type.
    if (!getPatterns().empty()) {
      for (Operation &op : getPatterns().front()) {
        auto descriptor =
            cast<transform::ConversionPatternDescriptorOpInterface>(&op);
        if (failed(descriptor.verifyTypeConverter(typeConverterOp)))
          return failure();
      }
    }
  }
  return success();
}

void transform::ApplyConversionPatternsOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  if (!getPreserveHandles()) {
    transform::consumesHandle(getTargetMutable(), effects);
  } else {
    transform::onlyReadsHandle(getTargetMutable(), effects);
  }
  transform::modifiesPayload(effects);
}

void transform::ApplyConversionPatternsOp::build(
    OpBuilder &builder, OperationState &result, Value target,
    function_ref<void(OpBuilder &, Location)> patternsBodyBuilder,
    function_ref<void(OpBuilder &, Location)> typeConverterBodyBuilder) {
  result.addOperands(target);

  {
    OpBuilder::InsertionGuard g(builder);
    Region *region1 = result.addRegion();
    builder.createBlock(region1);
    if (patternsBodyBuilder)
      patternsBodyBuilder(builder, result.location);
  }
  {
    OpBuilder::InsertionGuard g(builder);
    Region *region2 = result.addRegion();
    builder.createBlock(region2);
    if (typeConverterBodyBuilder)
      typeConverterBodyBuilder(builder, result.location);
  }
}

//===----------------------------------------------------------------------===//
// ApplyToLLVMConversionPatternsOp
//===----------------------------------------------------------------------===//

void transform::ApplyToLLVMConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  Dialect *dialect = getContext()->getLoadedDialect(getDialectName());
  assert(dialect && "expected that dialect is loaded");
  auto *iface = cast<ConvertToLLVMPatternInterface>(dialect);
  // ConversionTarget is currently ignored because the enclosing
  // apply_conversion_patterns op sets up its own ConversionTarget.
  ConversionTarget target(*getContext());
  iface->populateConvertToLLVMConversionPatterns(
      target, static_cast<LLVMTypeConverter &>(typeConverter), patterns);
}

LogicalResult transform::ApplyToLLVMConversionPatternsOp::verifyTypeConverter(
    transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}

LogicalResult transform::ApplyToLLVMConversionPatternsOp::verify() {
  Dialect *dialect = getContext()->getLoadedDialect(getDialectName());
  if (!dialect)
    return emitOpError("unknown dialect or dialect not loaded: ")
           << getDialectName();
  auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
  if (!iface)
    return emitOpError(
               "dialect does not implement ConvertToLLVMPatternInterface or "
               "extension was not loaded: ")
           << getDialectName();
  return success();
}

//===----------------------------------------------------------------------===//
// ApplyLoopInvariantCodeMotionOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::ApplyLoopInvariantCodeMotionOp::applyToOne(
    transform::TransformRewriter &rewriter, LoopLikeOpInterface target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  // Currently, LICM does not remove operations, so we don't need tracking.
  // If this ever changes, add a LICM entry point that takes a rewriter.
  moveLoopInvariantCode(target);
  return DiagnosedSilenceableFailure::success();
}

void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  transform::onlyReadsHandle(getTargetMutable(), effects);
  transform::modifiesPayload(effects);
}

//===----------------------------------------------------------------------===//
// ApplyRegisteredPassOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    ApplyToEachResultList &results, transform::TransformState &state) {
  // Make sure that this transform is not applied to itself. Modifying the
  // transform IR while it is being interpreted is generally dangerous. Even
  // more so when applying passes because they may perform a wide range of IR
  // modifications.
  DiagnosedSilenceableFailure payloadCheck =
      ensurePayloadIsSeparateFromTransform(*this, target);
  if (!payloadCheck.succeeded())
    return payloadCheck;

  // Get pass or pass pipeline from registry.
  const PassRegistryEntry *info = PassPipelineInfo::lookup(getPassName());
  if (!info)
    info = PassInfo::lookup(getPassName());
  if (!info)
    return emitDefiniteFailure()
           << "unknown pass or pass pipeline: " << getPassName();

  // Create pass manager and run the pass or pass pipeline.
  PassManager pm(getContext());
  if (failed(info->addToPipeline(pm, getOptions(), [&](const Twine &msg) {
        emitError(msg);
        return failure();
      }))) {
    return emitDefiniteFailure()
           << "failed to add pass or pass pipeline to pipeline: "
           << getPassName();
  }
  if (failed(pm.run(target))) {
    auto diag = emitSilenceableError() << "pass pipeline failed";
    diag.attachNote(target->getLoc()) << "target op";
    return diag;
  }

  results.push_back(target);
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::CastOp::applyToOne(transform::TransformRewriter &rewriter,
                              Operation *target, ApplyToEachResultList &results,
                              transform::TransformState &state) {
  results.push_back(target);
  return DiagnosedSilenceableFailure::success();
}

void transform::CastOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  onlyReadsPayload(effects);
  onlyReadsHandle(getInputMutable(), effects);
  producesHandle(getOperation()->getOpResults(), effects);
}

bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  assert(inputs.size() == 1 && "expected one input");
  assert(outputs.size() == 1 && "expected one output");
  return llvm::all_of(
      std::initializer_list<Type>{inputs.front(), outputs.front()},
      llvm::IsaPred<transform::TransformHandleTypeInterface>);
}

//===----------------------------------------------------------------------===//
// CollectMatchingOp
//===----------------------------------------------------------------------===//

/// Applies matcher operations from the given `block` using
/// `blockArgumentMapping` to initialize block arguments. Updates `state`
/// accordingly. If any of the matcher produces a silenceable failure, discards
/// it (printing the content to the debug output stream) and returns failure. If
/// any of the matchers produces a definite failure, reports it and returns
/// failure. If all matchers in the block succeed, populates `mappings` with the
/// payload entities associated with the block terminator operands. Note that
/// `mappings` will be cleared before that.
static DiagnosedSilenceableFailure
matchBlock(Block &block,
           ArrayRef<SmallVector<transform::MappedValue>> blockArgumentMapping,
           transform::TransformState &state,
           SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings) {
  assert(block.getParent() && "cannot match using a detached block");
  auto matchScope = state.make_region_scope(*block.getParent());
  if (failed(
          state.mapBlockArguments(block.getArguments(), blockArgumentMapping)))
    return DiagnosedSilenceableFailure::definiteFailure();

  for (Operation &match : block.without_terminator()) {
    if (!isa<transform::MatchOpInterface>(match)) {
      return emitDefiniteFailure(match.getLoc())
             << "expected operations in the match part to "
                "implement MatchOpInterface";
    }
    DiagnosedSilenceableFailure diag =
        state.applyTransform(cast<transform::TransformOpInterface>(match));
    if (diag.succeeded())
      continue;

    return diag;
  }

  // Remember the values mapped to the terminator operands so we can
  // forward them to the action.
  ValueRange yieldedValues = block.getTerminator()->getOperands();
  // Our contract with the caller is that the mappings will contain only the
  // newly mapped values, clear the rest.
  mappings.clear();
  transform::detail::prepareValueMappings(mappings, yieldedValues, state);
  return DiagnosedSilenceableFailure::success();
}

/// Returns `true` if both types implement one of the interfaces provided as
/// template parameters.
template <typename... Tys>
static bool implementSameInterface(Type t1, Type t2) {
  return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false);
}

/// Returns `true` if both types implement one of the transform dialect
/// interfaces.
static bool implementSameTransformInterface(Type t1, Type t2) {
  return implementSameInterface<transform::TransformHandleTypeInterface,
                                transform::TransformParamTypeInterface,
                                transform::TransformValueHandleTypeInterface>(
      t1, t2);
}

//===----------------------------------------------------------------------===//
// CollectMatchingOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
                                    transform::TransformResults &results,
                                    transform::TransformState &state) {
  auto matcher = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
      getOperation(), getMatcher());
  if (matcher.isExternal()) {
    return emitDefiniteFailure()
           << "unresolved external symbol " << getMatcher();
  }

  SmallVector<SmallVector<MappedValue>, 2> rawResults;
  rawResults.resize(getOperation()->getNumResults());
  std::optional<DiagnosedSilenceableFailure> maybeFailure;
  for (Operation *root : state.getPayloadOps(getRoot())) {
    WalkResult walkResult = root->walk([&](Operation *op) {
      DEBUG_MATCHER({
        DBGS_MATCHER() << "matching ";
        op->print(llvm::dbgs(),
                  OpPrintingFlags().assumeVerified().skipRegions());
        llvm::dbgs() << " @" << op << "\n";
      });

      // Try matching.
      SmallVector<SmallVector<MappedValue>> mappings;
      SmallVector<transform::MappedValue> inputMapping({op});
      DiagnosedSilenceableFailure diag = matchBlock(
          matcher.getFunctionBody().front(),
          ArrayRef<SmallVector<transform::MappedValue>>(inputMapping), state,
          mappings);
      if (diag.isDefiniteFailure())
        return WalkResult::interrupt();
      if (diag.isSilenceableFailure()) {
        DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
                                     << " failed: " << diag.getMessage());
        return WalkResult::advance();
      }

      // If succeeded, collect results.
      for (auto &&[i, mapping] : llvm::enumerate(mappings)) {
        if (mapping.size() != 1) {
          maybeFailure.emplace(emitSilenceableError()
                               << "result #" << i << ", associated with "
                               << mapping.size()
                               << " payload objects, expected 1");
          return WalkResult::interrupt();
        }
        rawResults[i].push_back(mapping[0]);
      }
      return WalkResult::advance();
    });
    if (walkResult.wasInterrupted())
      return std::move(*maybeFailure);
    assert(!maybeFailure && "failure set but the walk was not interrupted");

    for (auto &&[opResult, rawResult] :
         llvm::zip_equal(getOperation()->getResults(), rawResults)) {
      results.setMappedValues(opResult, rawResult);
    }
  }
  return DiagnosedSilenceableFailure::success();
}

void transform::CollectMatchingOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  onlyReadsHandle(getRootMutable(), effects);
  producesHandle(getOperation()->getOpResults(), effects);
  onlyReadsPayload(effects);
}

LogicalResult transform::CollectMatchingOp::verifySymbolUses(
    SymbolTableCollection &symbolTable) {
  auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
      symbolTable.lookupNearestSymbolFrom(getOperation(), getMatcher()));
  if (!matcherSymbol ||
      !isa<TransformOpInterface>(matcherSymbol.getOperation()))
    return emitError() << "unresolved matcher symbol " << getMatcher();

  ArrayRef<Type> argumentTypes = matcherSymbol.getArgumentTypes();
  if (argumentTypes.size() != 1 ||
      !isa<TransformHandleTypeInterface>(argumentTypes[0])) {
    return emitError()
           << "expected the matcher to take one operation handle argument";
  }
  if (!matcherSymbol.getArgAttr(
          0, transform::TransformDialect::kArgReadOnlyAttrName)) {
    return emitError() << "expected the matcher argument to be marked readonly";
  }

  ArrayRef<Type> resultTypes = matcherSymbol.getResultTypes();
  if (resultTypes.size() != getOperation()->getNumResults()) {
    return emitError()
           << "expected the matcher to yield as many values as op has results ("
           << getOperation()->getNumResults() << "), got "
           << resultTypes.size();
  }

  for (auto &&[i, matcherType, resultType] :
       llvm::enumerate(resultTypes, getOperation()->getResultTypes())) {
    if (implementSameTransformInterface(matcherType, resultType))
      continue;

    return emitError()
           << "mismatching type interfaces for matcher result and op result #"
           << i;
  }

  return success();
}

//===----------------------------------------------------------------------===//
// ForeachMatchOp
//===----------------------------------------------------------------------===//

// This is fine because nothing is actually consumed by this op.
bool transform::ForeachMatchOp::allowsRepeatedHandleOperands() { return true; }

DiagnosedSilenceableFailure
transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
                                 transform::TransformResults &results,
                                 transform::TransformState &state) {
  SmallVector<std::pair<FunctionOpInterface, FunctionOpInterface>>
      matchActionPairs;
  matchActionPairs.reserve(getMatchers().size());
  SymbolTableCollection symbolTable;
  for (auto &&[matcher, action] :
       llvm::zip_equal(getMatchers(), getActions())) {
    auto matcherSymbol =
        symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
            getOperation(), cast<SymbolRefAttr>(matcher));
    auto actionSymbol =
        symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
            getOperation(), cast<SymbolRefAttr>(action));
    assert(matcherSymbol && actionSymbol &&
           "unresolved symbols not caught by the verifier");

    if (matcherSymbol.isExternal())
      return emitDefiniteFailure() << "unresolved external symbol " << matcher;
    if (actionSymbol.isExternal())
      return emitDefiniteFailure() << "unresolved external symbol " << action;

    matchActionPairs.emplace_back(matcherSymbol, actionSymbol);
  }

  DiagnosedSilenceableFailure overallDiag =
      DiagnosedSilenceableFailure::success();

  SmallVector<SmallVector<MappedValue>> matchInputMapping;
  SmallVector<SmallVector<MappedValue>> matchOutputMapping;
  SmallVector<SmallVector<MappedValue>> actionResultMapping;
  // Explicitly add the mapping for the first block argument (the op being
  // matched).
  matchInputMapping.emplace_back();
  transform::detail::prepareValueMappings(matchInputMapping,
                                          getForwardedInputs(), state);
  SmallVector<MappedValue> &firstMatchArgument = matchInputMapping.front();
  actionResultMapping.resize(getForwardedOutputs().size());

  for (Operation *root : state.getPayloadOps(getRoot())) {
    WalkResult walkResult = root->walk([&](Operation *op) {
      // If getRestrictRoot is not present, skip over the root op itself so we
      // don't invalidate it.
      if (!getRestrictRoot() && op == root)
        return WalkResult::advance();

      DEBUG_MATCHER({
        DBGS_MATCHER() << "matching ";
        op->print(llvm::dbgs(),
                  OpPrintingFlags().assumeVerified().skipRegions());
        llvm::dbgs() << " @" << op << "\n";
      });

      firstMatchArgument.clear();
      firstMatchArgument.push_back(op);

      // Try all the match/action pairs until the first successful match.
      for (auto [matcher, action] : matchActionPairs) {
        DiagnosedSilenceableFailure diag =
            matchBlock(matcher.getFunctionBody().front(), matchInputMapping,
                       state, matchOutputMapping);
        if (diag.isDefiniteFailure())
          return WalkResult::interrupt();
        if (diag.isSilenceableFailure()) {
          DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
                                       << " failed: " << diag.getMessage());
          continue;
        }

        auto scope = state.make_region_scope(action.getFunctionBody());
        if (failed(state.mapBlockArguments(
                action.getFunctionBody().front().getArguments(),
                matchOutputMapping))) {
          return WalkResult::interrupt();
        }

        for (Operation &transform :
             action.getFunctionBody().front().without_terminator()) {
          DiagnosedSilenceableFailure result =
              state.applyTransform(cast<TransformOpInterface>(transform));
          if (result.isDefiniteFailure())
            return WalkResult::interrupt();
          if (result.isSilenceableFailure()) {
            if (overallDiag.succeeded()) {
              overallDiag = emitSilenceableError() << "actions failed";
            }
            overallDiag.attachNote(action->getLoc())
                << "failed action: " << result.getMessage();
            overallDiag.attachNote(op->getLoc())
                << "when applied to this matching payload";
            (void)result.silence();
            continue;
          }
        }
        if (failed(detail::appendValueMappings(
                MutableArrayRef<SmallVector<MappedValue>>(actionResultMapping),
                action.getFunctionBody().front().getTerminator()->getOperands(),
                state, getFlattenResults()))) {
          emitDefiniteFailure()
              << "action @" << action.getName()
              << " has results associated with multiple payload entities, "
                 "but flattening was not requested";
          return WalkResult::interrupt();
        }
        break;
      }
      return WalkResult::advance();
    });
    if (walkResult.wasInterrupted())
      return DiagnosedSilenceableFailure::definiteFailure();
  }

  // The root operation should not have been affected, so we can just reassign
  // the payload to the result. Note that we need to consume the root handle to
  // make sure any handles to operations inside, that could have been affected
  // by actions, are invalidated.
  results.set(llvm::cast<OpResult>(getUpdated()),
              state.getPayloadOps(getRoot()));
  for (auto &&[result, mapping] :
       llvm::zip_equal(getForwardedOutputs(), actionResultMapping)) {
    results.setMappedValues(result, mapping);
  }
  return overallDiag;
}

void transform::ForeachMatchOp::getAsmResultNames(
    OpAsmSetValueNameFn setNameFn) {
  setNameFn(getUpdated(), "updated_root");
  for (Value v : getForwardedOutputs()) {
    setNameFn(v, "yielded");
  }
}

void transform::ForeachMatchOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // Bail if invalid.
  if (getOperation()->getNumOperands() < 1 ||
      getOperation()->getNumResults() < 1) {
    return modifiesPayload(effects);
  }

  consumesHandle(getRootMutable(), effects);
  onlyReadsHandle(getForwardedInputsMutable(), effects);
  producesHandle(getOperation()->getOpResults(), effects);
  modifiesPayload(effects);
}

/// Parses the comma-separated list of symbol reference pairs of the format
/// `@matcher -> @action`.
static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
                                            ArrayAttr &matchers,
                                            ArrayAttr &actions) {
  StringAttr matcher;
  StringAttr action;
  SmallVector<Attribute> matcherList;
  SmallVector<Attribute> actionList;
  do {
    if (parser.parseSymbolName(matcher) || parser.parseArrow() ||
        parser.parseSymbolName(action)) {
      return failure();
    }
    matcherList.push_back(SymbolRefAttr::get(matcher));
    actionList.push_back(SymbolRefAttr::get(action));
  } while (parser.parseOptionalComma().succeeded());

  matchers = parser.getBuilder().getArrayAttr(matcherList);
  actions = parser.getBuilder().getArrayAttr(actionList);
  return success();
}

/// Prints the comma-separated list of symbol reference pairs of the format
/// `@matcher -> @action`.
static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op,
                                     ArrayAttr matchers, ArrayAttr actions) {
  printer.increaseIndent();
  printer.increaseIndent();
  for (auto &&[matcher, action, idx] : llvm::zip_equal(
           matchers, actions, llvm::seq<unsigned>(0, matchers.size()))) {
    printer.printNewline();
    printer << cast<SymbolRefAttr>(matcher) << " -> "
            << cast<SymbolRefAttr>(action);
    if (idx != matchers.size() - 1)
      printer << ", ";
  }
  printer.decreaseIndent();
  printer.decreaseIndent();
}

LogicalResult transform::ForeachMatchOp::verify() {
  if (getMatchers().size() != getActions().size())
    return emitOpError() << "expected the same number of matchers and actions";
  if (getMatchers().empty())
    return emitOpError() << "expected at least one match/action pair";

  llvm::SmallPtrSet<Attribute, 8> matcherNames;
  for (Attribute name : getMatchers()) {
    if (matcherNames.insert(name).second)
      continue;
    emitWarning() << "matcher " << name
                  << " is used more than once, only the first match will apply";
  }

  return success();
}

/// Checks that the attributes of the function-like operation have correct
/// consumption effect annotations. If `alsoVerifyInternal`, checks for
/// annotations being present even if they can be inferred from the body.
static DiagnosedSilenceableFailure
verifyFunctionLikeConsumeAnnotations(FunctionOpInterface op, bool emitWarnings,
                                     bool alsoVerifyInternal = false) {
  auto transformOp = cast<transform::TransformOpInterface>(op.getOperation());
  llvm::SmallDenseSet<unsigned> consumedArguments;
  if (!op.isExternal()) {
    transform::getConsumedBlockArguments(op.getFunctionBody().front(),
                                         consumedArguments);
  }
  for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i) {
    bool isConsumed =
        op.getArgAttr(i, transform::TransformDialect::kArgConsumedAttrName) !=
        nullptr;
    bool isReadOnly =
        op.getArgAttr(i, transform::TransformDialect::kArgReadOnlyAttrName) !=
        nullptr;
    if (isConsumed && isReadOnly) {
      return transformOp.emitSilenceableError()
             << "argument #" << i << " cannot be both readonly and consumed";
    }
    if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) {
      return transformOp.emitSilenceableError()
             << "must provide consumed/readonly status for arguments of "
                "external or called ops";
    }
    if (op.isExternal())
      continue;

    if (consumedArguments.contains(i) && !isConsumed && isReadOnly) {
      return transformOp.emitSilenceableError()
             << "argument #" << i
             << " is consumed in the body but is not marked as such";
    }
    if (emitWarnings && !consumedArguments.contains(i) && isConsumed) {
      // Cannot use op.emitWarning() here as it would attempt to verify the op
      // before printing, resulting in infinite recursion.
      emitWarning(op->getLoc())
          << "op argument #" << i
          << " is not consumed in the body but is marked as consumed";
    }
  }
  return DiagnosedSilenceableFailure::success();
}

LogicalResult transform::ForeachMatchOp::verifySymbolUses(
    SymbolTableCollection &symbolTable) {
  assert(getMatchers().size() == getActions().size());
  auto consumedAttr =
      StringAttr::get(getContext(), TransformDialect::kArgConsumedAttrName);
  for (auto &&[matcher, action] :
       llvm::zip_equal(getMatchers(), getActions())) {
    // Presence and typing.
    auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
        symbolTable.lookupNearestSymbolFrom(getOperation(),
                                            cast<SymbolRefAttr>(matcher)));
    auto actionSymbol = dyn_cast_or_null<FunctionOpInterface>(
        symbolTable.lookupNearestSymbolFrom(getOperation(),
                                            cast<SymbolRefAttr>(action)));
    if (!matcherSymbol ||
        !isa<TransformOpInterface>(matcherSymbol.getOperation()))
      return emitError() << "unresolved matcher symbol " << matcher;
    if (!actionSymbol ||
        !isa<TransformOpInterface>(actionSymbol.getOperation()))
      return emitError() << "unresolved action symbol " << action;

    if (failed(verifyFunctionLikeConsumeAnnotations(matcherSymbol,
                                                    /*emitWarnings=*/false,
                                                    /*alsoVerifyInternal=*/true)
                   .checkAndReport())) {
      return failure();
    }
    if (failed(verifyFunctionLikeConsumeAnnotations(actionSymbol,
                                                    /*emitWarnings=*/false,
                                                    /*alsoVerifyInternal=*/true)
                   .checkAndReport())) {
      return failure();
    }

    // Input -> matcher forwarding.
    TypeRange operandTypes = getOperandTypes();
    TypeRange matcherArguments = matcherSymbol.getArgumentTypes();
    if (operandTypes.size() != matcherArguments.size()) {
      InFlightDiagnostic diag =
          emitError() << "the number of operands (" << operandTypes.size()
                      << ") doesn't match the number of matcher arguments ("
                      << matcherArguments.size() << ") for " << matcher;
      diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
      return diag;
    }
    for (auto &&[i, operand, argument] :
         llvm::enumerate(operandTypes, matcherArguments)) {
      if (matcherSymbol.getArgAttr(i, consumedAttr)) {
        InFlightDiagnostic diag =
            emitOpError()
            << "does not expect matcher symbol to consume its operand #" << i;
        diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
        return diag;
      }

      if (implementSameTransformInterface(operand, argument))
        continue;

      InFlightDiagnostic diag =
          emitError()
          << "mismatching type interfaces for operand and matcher argument #"
          << i << " of matcher " << matcher;
      diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
      return diag;
    }

    // Matcher -> action forwarding.
    TypeRange matcherResults = matcherSymbol.getResultTypes();
    TypeRange actionArguments = actionSymbol.getArgumentTypes();
    if (matcherResults.size() != actionArguments.size()) {
      return emitError() << "mismatching number of matcher results and "
                            "action arguments between "
                         << matcher << " (" << matcherResults.size() << ") and "
                         << action << " (" << actionArguments.size() << ")";
    }
    for (auto &&[i, matcherType, actionType] :
         llvm::enumerate(matcherResults, actionArguments)) {
      if (implementSameTransformInterface(matcherType, actionType))
        continue;

      return emitError() << "mismatching type interfaces for matcher result "
                            "and action argument #"
                         << i << "of matcher " << matcher << " and action "
                         << action;
    }

    // Action -> result forwarding.
    TypeRange actionResults = actionSymbol.getResultTypes();
    auto resultTypes = TypeRange(getResultTypes()).drop_front();
    if (actionResults.size() != resultTypes.size()) {
      InFlightDiagnostic diag =
          emitError() << "the number of action results ("
                      << actionResults.size() << ") for " << action
                      << " doesn't match the number of extra op results ("
                      << resultTypes.size() << ")";
      diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
      return diag;
    }
    for (auto &&[i, resultType, actionType] :
         llvm::enumerate(resultTypes, actionResults)) {
      if (implementSameTransformInterface(resultType, actionType))
        continue;

      InFlightDiagnostic diag =
          emitError() << "mismatching type interfaces for action result #" << i
                      << " of action " << action << " and op result";
      diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
      return diag;
    }
  }
  return success();
}

//===----------------------------------------------------------------------===//
// ForeachOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
                            transform::TransformResults &results,
                            transform::TransformState &state) {
  // We store the payloads before executing the body as ops may be removed from
  // the mapping by the TrackingRewriter while iteration is in progress.
  SmallVector<SmallVector<MappedValue>> payloads;
  detail::prepareValueMappings(payloads, getTargets(), state);
  size_t numIterations = payloads.empty() ? 0 : payloads.front().size();
  bool withZipShortest = getWithZipShortest();

  // In case of `zip_shortest`, set the number of iterations to the
  // smallest payload in the targets.
  if (withZipShortest) {
    numIterations =
        llvm::min_element(payloads, [&](const SmallVector<MappedValue> &A,
                                        const SmallVector<MappedValue> &B) {
          return A.size() < B.size();
        })->size();

    for (size_t argIdx = 0; argIdx < payloads.size(); argIdx++)
      payloads[argIdx].resize(numIterations);
  }

  // As we will be "zipping" over them, check all payloads have the same size.
  // `zip_shortest` adjusts all payloads to the same size, so skip this check
  // when true.
  for (size_t argIdx = 1; !withZipShortest && argIdx < payloads.size();
       argIdx++) {
    if (payloads[argIdx].size() != numIterations) {
      return emitSilenceableError()
             << "prior targets' payload size (" << numIterations
             << ") differs from payload size (" << payloads[argIdx].size()
             << ") of target " << getTargets()[argIdx];
    }
  }

  // Start iterating, indexing into payloads to obtain the right arguments to
  // call the body with - each slice of payloads at the same argument index
  // corresponding to a tuple to use as the body's block arguments.
  ArrayRef<BlockArgument> blockArguments = getBody().front().getArguments();
  SmallVector<SmallVector<MappedValue>> zippedResults(getNumResults(), {});
  for (size_t iterIdx = 0; iterIdx < numIterations; iterIdx++) {
    auto scope = state.make_region_scope(getBody());
    // Set up arguments to the region's block.
    for (auto &&[argIdx, blockArg] : llvm::enumerate(blockArguments)) {
      MappedValue argument = payloads[argIdx][iterIdx];
      // Note that each blockArg's handle gets associated with just a single
      // element from the corresponding target's payload.
      if (failed(state.mapBlockArgument(blockArg, {argument})))
        return DiagnosedSilenceableFailure::definiteFailure();
    }

    // Execute loop body.
    for (Operation &transform : getBody().front().without_terminator()) {
      DiagnosedSilenceableFailure result = state.applyTransform(
          llvm::cast<transform::TransformOpInterface>(transform));
      if (!result.succeeded())
        return result;
    }

    // Append yielded payloads to corresponding results from prior iterations.
    OperandRange yieldOperands = getYieldOp().getOperands();
    for (auto &&[result, yieldOperand, resTuple] :
         llvm::zip_equal(getResults(), yieldOperands, zippedResults))
      // NB: each iteration we add any number of ops/vals/params to a result.
      if (isa<TransformHandleTypeInterface>(result.getType()))
        llvm::append_range(resTuple, state.getPayloadOps(yieldOperand));
      else if (isa<TransformValueHandleTypeInterface>(result.getType()))
        llvm::append_range(resTuple, state.getPayloadValues(yieldOperand));
      else if (isa<TransformParamTypeInterface>(result.getType()))
        llvm::append_range(resTuple, state.getParams(yieldOperand));
      else
        assert(false && "unhandled handle type");
  }

  // Associate the accumulated result payloads to the op's actual results.
  for (auto &&[result, resPayload] : zip_equal(getResults(), zippedResults))
    results.setMappedValues(llvm::cast<OpResult>(result), resPayload);

  return DiagnosedSilenceableFailure::success();
}

void transform::ForeachOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // NB: this `zip` should be `zip_equal` - while this op's verifier catches
  // arity errors, this method might get called before/in absence of `verify()`.
  for (auto &&[target, blockArg] :
       llvm::zip(getTargetsMutable(), getBody().front().getArguments())) {
    BlockArgument blockArgument = blockArg;
    if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
          return isHandleConsumed(blockArgument,
                                  cast<TransformOpInterface>(&op));
        })) {
      consumesHandle(target, effects);
    } else {
      onlyReadsHandle(target, effects);
    }
  }

  if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
        return doesModifyPayload(cast<TransformOpInterface>(&op));
      })) {
    modifiesPayload(effects);
  } else if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
               return doesReadPayload(cast<TransformOpInterface>(&op));
             })) {
    onlyReadsPayload(effects);
  }

  producesHandle(getOperation()->getOpResults(), effects);
}

void transform::ForeachOp::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
  Region *bodyRegion = &getBody();
  if (point.isParent()) {
    regions.emplace_back(bodyRegion, bodyRegion->getArguments());
    return;
  }

  // Branch back to the region or the parent.
  assert(point == getBody() && "unexpected region index");
  regions.emplace_back(bodyRegion, bodyRegion->getArguments());
  regions.emplace_back();
}

OperandRange
transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) {
  // Each block argument handle is mapped to a subset (one op to be precise)
  // of the payload of the corresponding `targets` operand of ForeachOp.
  assert(point == getBody() && "unexpected region index");
  return getOperation()->getOperands();
}

transform::YieldOp transform::ForeachOp::getYieldOp() {
  return cast<transform::YieldOp>(getBody().front().getTerminator());
}

LogicalResult transform::ForeachOp::verify() {
  for (auto [targetOpt, bodyArgOpt] :
       llvm::zip_longest(getTargets(), getBody().front().getArguments())) {
    if (!targetOpt || !bodyArgOpt)
      return emitOpError() << "expects the same number of targets as the body "
                              "has block arguments";
    if (targetOpt.value().getType() != bodyArgOpt.value().getType())
      return emitOpError(
          "expects co-indexed targets and the body's "
          "block arguments to have the same op/value/param type");
  }

  for (auto [resultOpt, yieldOperandOpt] :
       llvm::zip_longest(getResults(), getYieldOp().getOperands())) {
    if (!resultOpt || !yieldOperandOpt)
      return emitOpError() << "expects the same number of results as the "
                              "yield terminator has operands";
    if (resultOpt.value().getType() != yieldOperandOpt.value().getType())
      return emitOpError("expects co-indexed results and yield "
                         "operands to have the same op/value/param type");
  }

  return success();
}

//===----------------------------------------------------------------------===//
// GetParentOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::GetParentOp::apply(transform::TransformRewriter &rewriter,
                              transform::TransformResults &results,
                              transform::TransformState &state) {
  SmallVector<Operation *> parents;
  DenseSet<Operation *> resultSet;
  for (Operation *target : state.getPayloadOps(getTarget())) {
    Operation *parent = target;
    for (int64_t i = 0, e = getNthParent(); i < e; ++i) {
      parent = parent->getParentOp();
      while (parent) {
        bool checkIsolatedFromAbove =
            !getIsolatedFromAbove() ||
            parent->hasTrait<OpTrait::IsIsolatedFromAbove>();
        bool checkOpName = !getOpName().has_value() ||
                           parent->getName().getStringRef() == *getOpName();
        if (checkIsolatedFromAbove && checkOpName)
          break;
        parent = parent->getParentOp();
      }
      if (!parent) {
        if (getAllowEmptyResults()) {
          results.set(llvm::cast<OpResult>(getResult()), parents);
          return DiagnosedSilenceableFailure::success();
        }
        DiagnosedSilenceableFailure diag =
            emitSilenceableError()
            << "could not find a parent op that matches all requirements";
        diag.attachNote(target->getLoc()) << "target op";
        return diag;
      }
    }
    if (getDeduplicate()) {
      if (!resultSet.contains(parent)) {
        parents.push_back(parent);
        resultSet.insert(parent);
      }
    } else {
      parents.push_back(parent);
    }
  }
  results.set(llvm::cast<OpResult>(getResult()), parents);
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// GetConsumersOfResult
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::GetConsumersOfResult::apply(transform::TransformRewriter &rewriter,
                                       transform::TransformResults &results,
                                       transform::TransformState &state) {
  int64_t resultNumber = getResultNumber();
  auto payloadOps = state.getPayloadOps(getTarget());
  if (std::empty(payloadOps)) {
    results.set(cast<OpResult>(getResult()), {});
    return DiagnosedSilenceableFailure::success();
  }
  if (!llvm::hasSingleElement(payloadOps))
    return emitDefiniteFailure()
           << "handle must be mapped to exactly one payload op";

  Operation *target = *payloadOps.begin();
  if (target->getNumResults() <= resultNumber)
    return emitDefiniteFailure() << "result number overflow";
  results.set(llvm::cast<OpResult>(getResult()),
              llvm::to_vector(target->getResult(resultNumber).getUsers()));
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// GetDefiningOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::GetDefiningOp::apply(transform::TransformRewriter &rewriter,
                                transform::TransformResults &results,
                                transform::TransformState &state) {
  SmallVector<Operation *> definingOps;
  for (Value v : state.getPayloadValues(getTarget())) {
    if (llvm::isa<BlockArgument>(v)) {
      DiagnosedSilenceableFailure diag =
          emitSilenceableError() << "cannot get defining op of block argument";
      diag.attachNote(v.getLoc()) << "target value";
      return diag;
    }
    definingOps.push_back(v.getDefiningOp());
  }
  results.set(llvm::cast<OpResult>(getResult()), definingOps);
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// GetProducerOfOperand
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::GetProducerOfOperand::apply(transform::TransformRewriter &rewriter,
                                       transform::TransformResults &results,
                                       transform::TransformState &state) {
  int64_t operandNumber = getOperandNumber();
  SmallVector<Operation *> producers;
  for (Operation *target : state.getPayloadOps(getTarget())) {
    Operation *producer =
        target->getNumOperands() <= operandNumber
            ? nullptr
            : target->getOperand(operandNumber).getDefiningOp();
    if (!producer) {
      DiagnosedSilenceableFailure diag =
          emitSilenceableError()
          << "could not find a producer for operand number: " << operandNumber
          << " of " << *target;
      diag.attachNote(target->getLoc()) << "target op";
      return diag;
    }
    producers.push_back(producer);
  }
  results.set(llvm::cast<OpResult>(getResult()), producers);
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// GetOperandOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::GetOperandOp::apply(transform::TransformRewriter &rewriter,
                               transform::TransformResults &results,
                               transform::TransformState &state) {
  SmallVector<Value> operands;
  for (Operation *target : state.getPayloadOps(getTarget())) {
    SmallVector<int64_t> operandPositions;
    DiagnosedSilenceableFailure diag = expandTargetSpecification(
        getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
        target->getNumOperands(), operandPositions);
    if (diag.isSilenceableFailure()) {
      diag.attachNote(target->getLoc())
          << "while considering positions of this payload operation";
      return diag;
    }
    llvm::append_range(operands,
                       llvm::map_range(operandPositions, [&](int64_t pos) {
                         return target->getOperand(pos);
                       }));
  }
  results.setValues(cast<OpResult>(getResult()), operands);
  return DiagnosedSilenceableFailure::success();
}

LogicalResult transform::GetOperandOp::verify() {
  return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
                                    getIsInverted(), getIsAll());
}

//===----------------------------------------------------------------------===//
// GetResultOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::GetResultOp::apply(transform::TransformRewriter &rewriter,
                              transform::TransformResults &results,
                              transform::TransformState &state) {
  SmallVector<Value> opResults;
  for (Operation *target : state.getPayloadOps(getTarget())) {
    SmallVector<int64_t> resultPositions;
    DiagnosedSilenceableFailure diag = expandTargetSpecification(
        getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
        target->getNumResults(), resultPositions);
    if (diag.isSilenceableFailure()) {
      diag.attachNote(target->getLoc())
          << "while considering positions of this payload operation";
      return diag;
    }
    llvm::append_range(opResults,
                       llvm::map_range(resultPositions, [&](int64_t pos) {
                         return target->getResult(pos);
                       }));
  }
  results.setValues(cast<OpResult>(getResult()), opResults);
  return DiagnosedSilenceableFailure::success();
}

LogicalResult transform::GetResultOp::verify() {
  return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
                                    getIsInverted(), getIsAll());
}

//===----------------------------------------------------------------------===//
// GetTypeOp
//===----------------------------------------------------------------------===//

void transform::GetTypeOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  onlyReadsHandle(getValueMutable(), effects);
  producesHandle(getOperation()->getOpResults(), effects);
  onlyReadsPayload(effects);
}

DiagnosedSilenceableFailure
transform::GetTypeOp::apply(transform::TransformRewriter &rewriter,
                            transform::TransformResults &results,
                            transform::TransformState &state) {
  SmallVector<Attribute> params;
  for (Value value : state.getPayloadValues(getValue())) {
    Type type = value.getType();
    if (getElemental()) {
      if (auto shaped = dyn_cast<ShapedType>(type)) {
        type = shaped.getElementType();
      }
    }
    params.push_back(TypeAttr::get(type));
  }
  results.setParams(cast<OpResult>(getResult()), params);
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// IncludeOp
//===----------------------------------------------------------------------===//

/// Applies the transform ops contained in `block`. Maps `results` to the same
/// values as the operands of the block terminator.
static DiagnosedSilenceableFailure
applySequenceBlock(Block &block, transform::FailurePropagationMode mode,
                   transform::TransformState &state,
                   transform::TransformResults &results) {
  // Apply the sequenced ops one by one.
  for (Operation &transform : block.without_terminator()) {
    DiagnosedSilenceableFailure result =
        state.applyTransform(cast<transform::TransformOpInterface>(transform));
    if (result.isDefiniteFailure())
      return result;

    if (result.isSilenceableFailure()) {
      if (mode == transform::FailurePropagationMode::Propagate) {
        // Propagate empty results in case of early exit.
        forwardEmptyOperands(&block, state, results);
        return result;
      }
      (void)result.silence();
    }
  }

  // Forward the operation mapping for values yielded from the sequence to the
  // values produced by the sequence op.
  transform::detail::forwardTerminatorOperands(&block, state, results);
  return DiagnosedSilenceableFailure::success();
}

DiagnosedSilenceableFailure
transform::IncludeOp::apply(transform::TransformRewriter &rewriter,
                            transform::TransformResults &results,
                            transform::TransformState &state) {
  auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
      getOperation(), getTarget());
  assert(callee && "unverified reference to unknown symbol");

  if (callee.isExternal())
    return emitDefiniteFailure() << "unresolved external named sequence";

  // Map operands to block arguments.
  SmallVector<SmallVector<MappedValue>> mappings;
  detail::prepareValueMappings(mappings, getOperands(), state);
  auto scope = state.make_region_scope(callee.getBody());
  for (auto &&[arg, map] :
       llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) {
    if (failed(state.mapBlockArgument(arg, map)))
      return DiagnosedSilenceableFailure::definiteFailure();
  }

  DiagnosedSilenceableFailure result = applySequenceBlock(
      callee.getBody().front(), getFailurePropagationMode(), state, results);
  mappings.clear();
  detail::prepareValueMappings(
      mappings, callee.getBody().front().getTerminator()->getOperands(), state);
  for (auto &&[result, mapping] : llvm::zip_equal(getResults(), mappings))
    results.setMappedValues(result, mapping);
  return result;
}

static DiagnosedSilenceableFailure
verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings);

void transform::IncludeOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // Always mark as modifying the payload.
  // TODO: a mechanism to annotate effects on payload. Even when all handles are
  // only read, the payload may still be modified, so we currently stay on the
  // conservative side and always indicate modification. This may prevent some
  // code reordering.
  modifiesPayload(effects);

  // Results are always produced.
  producesHandle(getOperation()->getOpResults(), effects);

  // Adds default effects to operands and results. This will be added if
  // preconditions fail so the trait verifier doesn't complain about missing
  // effects and the real precondition failure is reported later on.
  auto defaultEffects = [&] {
    onlyReadsHandle(getOperation()->getOpOperands(), effects);
  };

  // Bail if the callee is unknown. This may run as part of the verification
  // process before we verified the validity of the callee or of this op.
  auto target =
      getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName());
  if (!target)
    return defaultEffects();
  auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
      getOperation(), getTarget());
  if (!callee)
    return defaultEffects();
  DiagnosedSilenceableFailure earlyVerifierResult =
      verifyNamedSequenceOp(callee, /*emitWarnings=*/false);
  if (!earlyVerifierResult.succeeded()) {
    (void)earlyVerifierResult.silence();
    return defaultEffects();
  }

  for (unsigned i = 0, e = getNumOperands(); i < e; ++i) {
    if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName))
      consumesHandle(getOperation()->getOpOperand(i), effects);
    else
      onlyReadsHandle(getOperation()->getOpOperand(i), effects);
  }
}

LogicalResult
transform::IncludeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // Access through indirection and do additional checking because this may be
  // running before the main op verifier.
  auto targetAttr = getOperation()->getAttrOfType<SymbolRefAttr>("target");
  if (!targetAttr)
    return emitOpError() << "expects a 'target' symbol reference attribute";

  auto target = symbolTable.lookupNearestSymbolFrom<transform::NamedSequenceOp>(
      *this, targetAttr);
  if (!target)
    return emitOpError() << "does not reference a named transform sequence";

  FunctionType fnType = target.getFunctionType();
  if (fnType.getNumInputs() != getNumOperands())
    return emitError("incorrect number of operands for callee");

  for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
    if (getOperand(i).getType() != fnType.getInput(i)) {
      return emitOpError("operand type mismatch: expected operand type ")
             << fnType.getInput(i) << ", but provided "
             << getOperand(i).getType() << " for operand number " << i;
    }
  }

  if (fnType.getNumResults() != getNumResults())
    return emitError("incorrect number of results for callee");

  for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
    Type resultType = getResult(i).getType();
    Type funcType = fnType.getResult(i);
    if (!implementSameTransformInterface(resultType, funcType)) {
      return emitOpError() << "type of result #" << i
                           << " must implement the same transform dialect "
                              "interface as the corresponding callee result";
    }
  }

  return verifyFunctionLikeConsumeAnnotations(
             cast<FunctionOpInterface>(*target), /*emitWarnings=*/false,
             /*alsoVerifyInternal=*/true)
      .checkAndReport();
}

//===----------------------------------------------------------------------===//
// MatchOperationEmptyOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure transform::MatchOperationEmptyOp::matchOperation(
    ::std::optional<::mlir::Operation *> maybeCurrent,
    transform::TransformResults &results, transform::TransformState &state) {
  if (!maybeCurrent.has_value()) {
    DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp success\n"; });
    return DiagnosedSilenceableFailure::success();
  }
  DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp failure\n"; });
  return emitSilenceableError() << "operation is not empty";
}

//===----------------------------------------------------------------------===//
// MatchOperationNameOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure transform::MatchOperationNameOp::matchOperation(
    Operation *current, transform::TransformResults &results,
    transform::TransformState &state) {
  StringRef currentOpName = current->getName().getStringRef();
  for (auto acceptedAttr : getOpNames().getAsRange<StringAttr>()) {
    if (acceptedAttr.getValue() == currentOpName)
      return DiagnosedSilenceableFailure::success();
  }
  return emitSilenceableError() << "wrong operation name";
}

//===----------------------------------------------------------------------===//
// MatchParamCmpIOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::MatchParamCmpIOp::apply(transform::TransformRewriter &rewriter,
                                   transform::TransformResults &results,
                                   transform::TransformState &state) {
  auto signedAPIntAsString = [&](const APInt &value) {
    std::string str;
    llvm::raw_string_ostream os(str);
    value.print(os, /*isSigned=*/true);
    return os.str();
  };

  ArrayRef<Attribute> params = state.getParams(getParam());
  ArrayRef<Attribute> references = state.getParams(getReference());

  if (params.size() != references.size()) {
    return emitSilenceableError()
           << "parameters have different payload lengths (" << params.size()
           << " vs " << references.size() << ")";
  }

  for (auto &&[i, param, reference] : llvm::enumerate(params, references)) {
    auto intAttr = llvm::dyn_cast<IntegerAttr>(param);
    auto refAttr = llvm::dyn_cast<IntegerAttr>(reference);
    if (!intAttr || !refAttr) {
      return emitDefiniteFailure()
             << "non-integer parameter value not expected";
    }
    if (intAttr.getType() != refAttr.getType()) {
      return emitDefiniteFailure()
             << "mismatching integer attribute types in parameter #" << i;
    }
    APInt value = intAttr.getValue();
    APInt refValue = refAttr.getValue();

    // TODO: this copy will not be necessary in C++20.
    int64_t position = i;
    auto reportError = [&](StringRef direction) {
      DiagnosedSilenceableFailure diag =
          emitSilenceableError() << "expected parameter to be " << direction
                                 << " " << signedAPIntAsString(refValue)
                                 << ", got " << signedAPIntAsString(value);
      diag.attachNote(getParam().getLoc())
          << "value # " << position
          << " associated with the parameter defined here";
      return diag;
    };

    switch (getPredicate()) {
    case MatchCmpIPredicate::eq:
      if (value.eq(refValue))
        break;
      return reportError("equal to");
    case MatchCmpIPredicate::ne:
      if (value.ne(refValue))
        break;
      return reportError("not equal to");
    case MatchCmpIPredicate::lt:
      if (value.slt(refValue))
        break;
      return reportError("less than");
    case MatchCmpIPredicate::le:
      if (value.sle(refValue))
        break;
      return reportError("less than or equal to");
    case MatchCmpIPredicate::gt:
      if (value.sgt(refValue))
        break;
      return reportError("greater than");
    case MatchCmpIPredicate::ge:
      if (value.sge(refValue))
        break;
      return reportError("greater than or equal to");
    }
  }
  return DiagnosedSilenceableFailure::success();
}

void transform::MatchParamCmpIOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  onlyReadsHandle(getParamMutable(), effects);
  onlyReadsHandle(getReferenceMutable(), effects);
}

//===----------------------------------------------------------------------===//
// ParamConstantOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::ParamConstantOp::apply(transform::TransformRewriter &rewriter,
                                  transform::TransformResults &results,
                                  transform::TransformState &state) {
  results.setParams(cast<OpResult>(getParam()), {getValue()});
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// MergeHandlesOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::MergeHandlesOp::apply(transform::TransformRewriter &rewriter,
                                 transform::TransformResults &results,
                                 transform::TransformState &state) {
  ValueRange handles = getHandles();
  if (isa<TransformHandleTypeInterface>(handles.front().getType())) {
    SmallVector<Operation *> operations;
    for (Value operand : handles)
      llvm::append_range(operations, state.getPayloadOps(operand));
    if (!getDeduplicate()) {
      results.set(llvm::cast<OpResult>(getResult()), operations);
      return DiagnosedSilenceableFailure::success();
    }

    SetVector<Operation *> uniqued(operations.begin(), operations.end());
    results.set(llvm::cast<OpResult>(getResult()), uniqued.getArrayRef());
    return DiagnosedSilenceableFailure::success();
  }

  if (llvm::isa<TransformParamTypeInterface>(handles.front().getType())) {
    SmallVector<Attribute> attrs;
    for (Value attribute : handles)
      llvm::append_range(attrs, state.getParams(attribute));
    if (!getDeduplicate()) {
      results.setParams(cast<OpResult>(getResult()), attrs);
      return DiagnosedSilenceableFailure::success();
    }

    SetVector<Attribute> uniqued(attrs.begin(), attrs.end());
    results.setParams(cast<OpResult>(getResult()), uniqued.getArrayRef());
    return DiagnosedSilenceableFailure::success();
  }

  assert(
      llvm::isa<TransformValueHandleTypeInterface>(handles.front().getType()) &&
      "expected value handle type");
  SmallVector<Value> payloadValues;
  for (Value value : handles)
    llvm::append_range(payloadValues, state.getPayloadValues(value));
  if (!getDeduplicate()) {
    results.setValues(cast<OpResult>(getResult()), payloadValues);
    return DiagnosedSilenceableFailure::success();
  }

  SetVector<Value> uniqued(payloadValues.begin(), payloadValues.end());
  results.setValues(cast<OpResult>(getResult()), uniqued.getArrayRef());
  return DiagnosedSilenceableFailure::success();
}

bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() {
  // Handles may be the same if deduplicating is enabled.
  return getDeduplicate();
}

void transform::MergeHandlesOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  onlyReadsHandle(getHandlesMutable(), effects);
  producesHandle(getOperation()->getOpResults(), effects);

  // There are no effects on the Payload IR as this is only a handle
  // manipulation.
}

OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) {
  if (getDeduplicate() || getHandles().size() != 1)
    return {};

  // If deduplication is not required and there is only one operand, it can be
  // used directly instead of merging.
  return getHandles().front();
}

//===----------------------------------------------------------------------===//
// NamedSequenceOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::NamedSequenceOp::apply(transform::TransformRewriter &rewriter,
                                  transform::TransformResults &results,
                                  transform::TransformState &state) {
  if (isExternal())
    return emitDefiniteFailure() << "unresolved external named sequence";

  // Map the entry block argument to the list of operations.
  // Note: this is the same implementation as PossibleTopLevelTransformOp but
  // without attaching the interface / trait since that is tailored to a
  // dangling top-level op that does not get "called".
  auto scope = state.make_region_scope(getBody());
  if (failed(detail::mapPossibleTopLevelTransformOpBlockArguments(
          state, this->getOperation(), getBody())))
    return DiagnosedSilenceableFailure::definiteFailure();

  return applySequenceBlock(getBody().front(),
                            FailurePropagationMode::Propagate, state, results);
}

void transform::NamedSequenceOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}

ParseResult transform::NamedSequenceOp::parse(OpAsmParser &parser,
                                              OperationState &result) {
  return function_interface_impl::parseFunctionOp(
      parser, result, /*allowVariadic=*/false,
      getFunctionTypeAttrName(result.name),
      [](Builder &builder, ArrayRef<Type> inputs, ArrayRef<Type> results,
         function_interface_impl::VariadicFlag,
         std::string &) { return builder.getFunctionType(inputs, results); },
      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
}

void transform::NamedSequenceOp::print(OpAsmPrinter &printer) {
  function_interface_impl::printFunctionOp(
      printer, cast<FunctionOpInterface>(getOperation()), /*isVariadic=*/false,
      getFunctionTypeAttrName().getValue(), getArgAttrsAttrName(),
      getResAttrsAttrName());
}

/// Verifies that a symbol function-like transform dialect operation has the
/// signature and the terminator that have conforming types, i.e., types
/// implementing the same transform dialect type interface. If `allowExternal`
/// is set, allow external symbols (declarations) and don't check the terminator
/// as it may not exist.
static DiagnosedSilenceableFailure
verifyYieldingSingleBlockOp(FunctionOpInterface op, bool allowExternal) {
  if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableFailure(op)
        << "cannot be defined inside another transform op";
    diag.attachNote(parent.getLoc()) << "ancestor transform op";
    return diag;
  }

  if (op.isExternal() || op.getFunctionBody().empty()) {
    if (allowExternal)
      return DiagnosedSilenceableFailure::success();

    return emitSilenceableFailure(op) << "cannot be external";
  }

  if (op.getFunctionBody().front().empty())
    return emitSilenceableFailure(op) << "expected a non-empty body block";

  Operation *terminator = &op.getFunctionBody().front().back();
  if (!isa<transform::YieldOp>(terminator)) {
    DiagnosedSilenceableFailure diag = emitSilenceableFailure(op)
                                       << "expected '"
                                       << transform::YieldOp::getOperationName()
                                       << "' as terminator";
    diag.attachNote(terminator->getLoc()) << "terminator";
    return diag;
  }

  if (terminator->getNumOperands() != op.getResultTypes().size()) {
    return emitSilenceableFailure(terminator)
           << "expected terminator to have as many operands as the parent op "
              "has results";
  }
  for (auto [i, operandType, resultType] : llvm::zip_equal(
           llvm::seq<unsigned>(0, terminator->getNumOperands()),
           terminator->getOperands().getType(), op.getResultTypes())) {
    if (operandType == resultType)
      continue;
    return emitSilenceableFailure(terminator)
           << "the type of the terminator operand #" << i
           << " must match the type of the corresponding parent op result ("
           << operandType << " vs " << resultType << ")";
  }

  return DiagnosedSilenceableFailure::success();
}

/// Verification of a NamedSequenceOp. This does not report the error
/// immediately, so it can be used to check for op's well-formedness before the
/// verifier runs, e.g., during trait verification.
static DiagnosedSilenceableFailure
verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings) {
  if (Operation *parent = op->getParentWithTrait<OpTrait::SymbolTable>()) {
    if (!parent->getAttr(
            transform::TransformDialect::kWithNamedSequenceAttrName)) {
      DiagnosedSilenceableFailure diag =
          emitSilenceableFailure(op)
          << "expects the parent symbol table to have the '"
          << transform::TransformDialect::kWithNamedSequenceAttrName
          << "' attribute";
      diag.attachNote(parent->getLoc()) << "symbol table operation";
      return diag;
    }
  }

  if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableFailure(op)
        << "cannot be defined inside another transform op";
    diag.attachNote(parent.getLoc()) << "ancestor transform op";
    return diag;
  }

  if (op.isExternal() || op.getBody().empty())
    return verifyFunctionLikeConsumeAnnotations(cast<FunctionOpInterface>(*op),
                                                emitWarnings);

  if (op.getBody().front().empty())
    return emitSilenceableFailure(op) << "expected a non-empty body block";

  Operation *terminator = &op.getBody().front().back();
  if (!isa<transform::YieldOp>(terminator)) {
    DiagnosedSilenceableFailure diag = emitSilenceableFailure(op)
                                       << "expected '"
                                       << transform::YieldOp::getOperationName()
                                       << "' as terminator";
    diag.attachNote(terminator->getLoc()) << "terminator";
    return diag;
  }

  if (terminator->getNumOperands() != op.getFunctionType().getNumResults()) {
    return emitSilenceableFailure(terminator)
           << "expected terminator to have as many operands as the parent op "
              "has results";
  }
  for (auto [i, operandType, resultType] :
       llvm::zip_equal(llvm::seq<unsigned>(0, terminator->getNumOperands()),
                       terminator->getOperands().getType(),
                       op.getFunctionType().getResults())) {
    if (operandType == resultType)
      continue;
    return emitSilenceableFailure(terminator)
           << "the type of the terminator operand #" << i
           << " must match the type of the corresponding parent op result ("
           << operandType << " vs " << resultType << ")";
  }

  auto funcOp = cast<FunctionOpInterface>(*op);
  DiagnosedSilenceableFailure diag =
      verifyFunctionLikeConsumeAnnotations(funcOp, emitWarnings);
  if (!diag.succeeded())
    return diag;

  return verifyYieldingSingleBlockOp(funcOp,
                                     /*allowExternal=*/true);
}

LogicalResult transform::NamedSequenceOp::verify() {
  // Actual verification happens in a separate function for reusability.
  return verifyNamedSequenceOp(*this, /*emitWarnings=*/true).checkAndReport();
}

template <typename FnTy>
static void buildSequenceBody(OpBuilder &builder, OperationState &state,
                              Type bbArgType, TypeRange extraBindingTypes,
                              FnTy bodyBuilder) {
  SmallVector<Type> types;
  types.reserve(1 + extraBindingTypes.size());
  types.push_back(bbArgType);
  llvm::append_range(types, extraBindingTypes);

  OpBuilder::InsertionGuard guard(builder);
  Region *region = state.regions.back().get();
  Block *bodyBlock =
      builder.createBlock(region, region->begin(), types,
                          SmallVector<Location>(types.size(), state.location));

  // Populate body.
  builder.setInsertionPointToStart(bodyBlock);
  if constexpr (llvm::function_traits<FnTy>::num_args == 3) {
    bodyBuilder(builder, state.location, bodyBlock->getArgument(0));
  } else {
    bodyBuilder(builder, state.location, bodyBlock->getArgument(0),
                bodyBlock->getArguments().drop_front());
  }
}

void transform::NamedSequenceOp::build(OpBuilder &builder,
                                       OperationState &state, StringRef symName,
                                       Type rootType, TypeRange resultTypes,
                                       SequenceBodyBuilderFn bodyBuilder,
                                       ArrayRef<NamedAttribute> attrs,
                                       ArrayRef<DictionaryAttr> argAttrs) {
  state.addAttribute(SymbolTable::getSymbolAttrName(),
                     builder.getStringAttr(symName));
  state.addAttribute(getFunctionTypeAttrName(state.name),
                     TypeAttr::get(FunctionType::get(builder.getContext(),
                                                     rootType, resultTypes)));
  state.attributes.append(attrs.begin(), attrs.end());
  state.addRegion();

  buildSequenceBody(builder, state, rootType,
                    /*extraBindingTypes=*/TypeRange(), bodyBuilder);
}

//===----------------------------------------------------------------------===//
// NumAssociationsOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter,
                                    transform::TransformResults &results,
                                    transform::TransformState &state) {
  size_t numAssociations =
      llvm::TypeSwitch<Type, size_t>(getHandle().getType())
          .Case([&](TransformHandleTypeInterface opHandle) {
            return llvm::range_size(state.getPayloadOps(getHandle()));
          })
          .Case([&](TransformValueHandleTypeInterface valueHandle) {
            return llvm::range_size(state.getPayloadValues(getHandle()));
          })
          .Case([&](TransformParamTypeInterface param) {
            return llvm::range_size(state.getParams(getHandle()));
          })
          .Default([](Type) {
            llvm_unreachable("unknown kind of transform dialect type");
            return 0;
          });
  results.setParams(cast<OpResult>(getNum()),
                    rewriter.getI64IntegerAttr(numAssociations));
  return DiagnosedSilenceableFailure::success();
}

LogicalResult transform::NumAssociationsOp::verify() {
  // Verify that the result type accepts an i64 attribute as payload.
  auto resultType = cast<TransformParamTypeInterface>(getNum().getType());
  return resultType
      .checkPayload(getLoc(), {Builder(getContext()).getI64IntegerAttr(0)})
      .checkAndReport();
}

//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::SelectOp::apply(transform::TransformRewriter &rewriter,
                           transform::TransformResults &results,
                           transform::TransformState &state) {
  SmallVector<Operation *> result;
  auto payloadOps = state.getPayloadOps(getTarget());
  for (Operation *op : payloadOps) {
    if (op->getName().getStringRef() == getOpName())
      result.push_back(op);
  }
  results.set(cast<OpResult>(getResult()), result);
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// SplitHandleOp
//===----------------------------------------------------------------------===//

void transform::SplitHandleOp::build(OpBuilder &builder, OperationState &result,
                                     Value target, int64_t numResultHandles) {
  result.addOperands(target);
  result.addTypes(SmallVector<Type>(numResultHandles, target.getType()));
}

DiagnosedSilenceableFailure
transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter,
                                transform::TransformResults &results,
                                transform::TransformState &state) {
  int64_t numPayloadOps = llvm::range_size(state.getPayloadOps(getHandle()));
  auto produceNumOpsError = [&]() {
    return emitSilenceableError()
           << getHandle() << " expected to contain " << this->getNumResults()
           << " payload ops but it contains " << numPayloadOps
           << " payload ops";
  };

  // Fail if there are more payload ops than results and no overflow result was
  // specified.
  if (numPayloadOps > getNumResults() && !getOverflowResult().has_value())
    return produceNumOpsError();

  // Fail if there are more results than payload ops. Unless:
  // - "fail_on_payload_too_small" is set to "false", or
  // - "pass_through_empty_handle" is set to "true" and there are 0 payload ops.
  if (numPayloadOps < getNumResults() && getFailOnPayloadTooSmall() &&
      (numPayloadOps != 0 || !getPassThroughEmptyHandle()))
    return produceNumOpsError();

  // Distribute payload ops.
  SmallVector<SmallVector<Operation *, 1>> resultHandles(getNumResults(), {});
  if (getOverflowResult())
    resultHandles[*getOverflowResult()].reserve(numPayloadOps -
                                                getNumResults());
  for (auto &&en : llvm::enumerate(state.getPayloadOps(getHandle()))) {
    int64_t resultNum = en.index();
    if (resultNum >= getNumResults())
      resultNum = *getOverflowResult();
    resultHandles[resultNum].push_back(en.value());
  }

  // Set transform op results.
  for (auto &&it : llvm::enumerate(resultHandles))
    results.set(llvm::cast<OpResult>(getResult(it.index())), it.value());

  return DiagnosedSilenceableFailure::success();
}

void transform::SplitHandleOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  onlyReadsHandle(getHandleMutable(), effects);
  producesHandle(getOperation()->getOpResults(), effects);
  // There are no effects on the Payload IR as this is only a handle
  // manipulation.
}

LogicalResult transform::SplitHandleOp::verify() {
  if (getOverflowResult().has_value() &&
      !(*getOverflowResult() < getNumResults()))
    return emitOpError("overflow_result is not a valid result index");
  return success();
}

//===----------------------------------------------------------------------===//
// ReplicateOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::ReplicateOp::apply(transform::TransformRewriter &rewriter,
                              transform::TransformResults &results,
                              transform::TransformState &state) {
  unsigned numRepetitions = llvm::range_size(state.getPayloadOps(getPattern()));
  for (const auto &en : llvm::enumerate(getHandles())) {
    Value handle = en.value();
    if (isa<TransformHandleTypeInterface>(handle.getType())) {
      SmallVector<Operation *> current =
          llvm::to_vector(state.getPayloadOps(handle));
      SmallVector<Operation *> payload;
      payload.reserve(numRepetitions * current.size());
      for (unsigned i = 0; i < numRepetitions; ++i)
        llvm::append_range(payload, current);
      results.set(llvm::cast<OpResult>(getReplicated()[en.index()]), payload);
    } else {
      assert(llvm::isa<TransformParamTypeInterface>(handle.getType()) &&
             "expected param type");
      ArrayRef<Attribute> current = state.getParams(handle);
      SmallVector<Attribute> params;
      params.reserve(numRepetitions * current.size());
      for (unsigned i = 0; i < numRepetitions; ++i)
        llvm::append_range(params, current);
      results.setParams(llvm::cast<OpResult>(getReplicated()[en.index()]),
                        params);
    }
  }
  return DiagnosedSilenceableFailure::success();
}

void transform::ReplicateOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  onlyReadsHandle(getPatternMutable(), effects);
  onlyReadsHandle(getHandlesMutable(), effects);
  producesHandle(getOperation()->getOpResults(), effects);
}

//===----------------------------------------------------------------------===//
// SequenceOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::SequenceOp::apply(transform::TransformRewriter &rewriter,
                             transform::TransformResults &results,
                             transform::TransformState &state) {
  // Map the entry block argument to the list of operations.
  auto scope = state.make_region_scope(*getBodyBlock()->getParent());
  if (failed(mapBlockArguments(state)))
    return DiagnosedSilenceableFailure::definiteFailure();

  return applySequenceBlock(*getBodyBlock(), getFailurePropagationMode(), state,
                            results);
}

static ParseResult parseSequenceOpOperands(
    OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
    Type &rootType,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
    SmallVectorImpl<Type> &extraBindingTypes) {
  OpAsmParser::UnresolvedOperand rootOperand;
  OptionalParseResult hasRoot = parser.parseOptionalOperand(rootOperand);
  if (!hasRoot.has_value()) {
    root = std::nullopt;
    return success();
  }
  if (failed(hasRoot.value()))
    return failure();
  root = rootOperand;

  if (succeeded(parser.parseOptionalComma())) {
    if (failed(parser.parseOperandList(extraBindings)))
      return failure();
  }
  if (failed(parser.parseColon()))
    return failure();

  // The paren is truly optional.
  (void)parser.parseOptionalLParen();

  if (failed(parser.parseType(rootType))) {
    return failure();
  }

  if (!extraBindings.empty()) {
    if (parser.parseComma() || parser.parseTypeList(extraBindingTypes))
      return failure();
  }

  if (extraBindingTypes.size() != extraBindings.size()) {
    return parser.emitError(parser.getNameLoc(),
                            "expected types to be provided for all operands");
  }

  // The paren is truly optional.
  (void)parser.parseOptionalRParen();
  return success();
}

static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op,
                                    Value root, Type rootType,
                                    ValueRange extraBindings,
                                    TypeRange extraBindingTypes) {
  if (!root)
    return;

  printer << root;
  bool hasExtras = !extraBindings.empty();
  if (hasExtras) {
    printer << ", ";
    printer.printOperands(extraBindings);
  }

  printer << " : ";
  if (hasExtras)
    printer << "(";

  printer << rootType;
  if (hasExtras) {
    printer << ", ";
    llvm::interleaveComma(extraBindingTypes, printer.getStream());
    printer << ")";
  }
}

/// Returns `true` if the given op operand may be consuming the handle value in
/// the Transform IR. That is, if it may have a Free effect on it.
static bool isValueUsePotentialConsumer(OpOperand &use) {
  // Conservatively assume the effect being present in absence of the interface.
  auto iface = dyn_cast<transform::TransformOpInterface>(use.getOwner());
  if (!iface)
    return true;

  return isHandleConsumed(use.get(), iface);
}

LogicalResult
checkDoubleConsume(Value value,
                   function_ref<InFlightDiagnostic()> reportError) {
  OpOperand *potentialConsumer = nullptr;
  for (OpOperand &use : value.getUses()) {
    if (!isValueUsePotentialConsumer(use))
      continue;

    if (!potentialConsumer) {
      potentialConsumer = &use;
      continue;
    }

    InFlightDiagnostic diag = reportError()
                              << " has more than one potential consumer";
    diag.attachNote(potentialConsumer->getOwner()->getLoc())
        << "used here as operand #" << potentialConsumer->getOperandNumber();
    diag.attachNote(use.getOwner()->getLoc())
        << "used here as operand #" << use.getOperandNumber();
    return diag;
  }

  return success();
}

LogicalResult transform::SequenceOp::verify() {
  assert(getBodyBlock()->getNumArguments() >= 1 &&
         "the number of arguments must have been verified to be more than 1 by "
         "PossibleTopLevelTransformOpTrait");

  if (!getRoot() && !getExtraBindings().empty()) {
    return emitOpError()
           << "does not expect extra operands when used as top-level";
  }

  // Check if a block argument has more than one consuming use.
  for (BlockArgument arg : getBodyBlock()->getArguments()) {
    if (failed(checkDoubleConsume(arg, [this, arg]() {
          return (emitOpError() << "block argument #" << arg.getArgNumber());
        }))) {
      return failure();
    }
  }

  // Check properties of the nested operations they cannot check themselves.
  for (Operation &child : *getBodyBlock()) {
    if (!isa<TransformOpInterface>(child) &&
        &child != &getBodyBlock()->back()) {
      InFlightDiagnostic diag =
          emitOpError()
          << "expected children ops to implement TransformOpInterface";
      diag.attachNote(child.getLoc()) << "op without interface";
      return diag;
    }

    for (OpResult result : child.getResults()) {
      auto report = [&]() {
        return (child.emitError() << "result #" << result.getResultNumber());
      };
      if (failed(checkDoubleConsume(result, report)))
        return failure();
    }
  }

  if (!getBodyBlock()->mightHaveTerminator())
    return emitOpError() << "expects to have a terminator in the body";

  if (getBodyBlock()->getTerminator()->getOperandTypes() !=
      getOperation()->getResultTypes()) {
    InFlightDiagnostic diag = emitOpError()
                              << "expects the types of the terminator operands "
                                 "to match the types of the result";
    diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator";
    return diag;
  }
  return success();
}

void transform::SequenceOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  getPotentialTopLevelEffects(effects);
}

OperandRange
transform::SequenceOp::getEntrySuccessorOperands(RegionBranchPoint point) {
  assert(point == getBody() && "unexpected region index");
  if (getOperation()->getNumOperands() > 0)
    return getOperation()->getOperands();
  return OperandRange(getOperation()->operand_end(),
                      getOperation()->operand_end());
}

void transform::SequenceOp::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
  if (point.isParent()) {
    Region *bodyRegion = &getBody();
    regions.emplace_back(bodyRegion, getNumOperands() != 0
                                         ? bodyRegion->getArguments()
                                         : Block::BlockArgListType());
    return;
  }

  assert(point == getBody() && "unexpected region index");
  regions.emplace_back(getOperation()->getResults());
}

void transform::SequenceOp::getRegionInvocationBounds(
    ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
  (void)operands;
  bounds.emplace_back(1, 1);
}

void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
                                  TypeRange resultTypes,
                                  FailurePropagationMode failurePropagationMode,
                                  Value root,
                                  SequenceBodyBuilderFn bodyBuilder) {
  build(builder, state, resultTypes, failurePropagationMode, root,
        /*extra_bindings=*/ValueRange());
  Type bbArgType = root.getType();
  buildSequenceBody(builder, state, bbArgType,
                    /*extraBindingTypes=*/TypeRange(), bodyBuilder);
}

void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
                                  TypeRange resultTypes,
                                  FailurePropagationMode failurePropagationMode,
                                  Value root, ValueRange extraBindings,
                                  SequenceBodyBuilderArgsFn bodyBuilder) {
  build(builder, state, resultTypes, failurePropagationMode, root,
        extraBindings);
  buildSequenceBody(builder, state, root.getType(), extraBindings.getTypes(),
                    bodyBuilder);
}

void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
                                  TypeRange resultTypes,
                                  FailurePropagationMode failurePropagationMode,
                                  Type bbArgType,
                                  SequenceBodyBuilderFn bodyBuilder) {
  build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
        /*extra_bindings=*/ValueRange());
  buildSequenceBody(builder, state, bbArgType,
                    /*extraBindingTypes=*/TypeRange(), bodyBuilder);
}

void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
                                  TypeRange resultTypes,
                                  FailurePropagationMode failurePropagationMode,
                                  Type bbArgType, TypeRange extraBindingTypes,
                                  SequenceBodyBuilderArgsFn bodyBuilder) {
  build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
        /*extra_bindings=*/ValueRange());
  buildSequenceBody(builder, state, bbArgType, extraBindingTypes, bodyBuilder);
}

//===----------------------------------------------------------------------===//
// PrintOp
//===----------------------------------------------------------------------===//

void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
                               StringRef name) {
  if (!name.empty())
    result.getOrAddProperties<Properties>().name = builder.getStringAttr(name);
}

void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
                               Value target, StringRef name) {
  result.addOperands({target});
  build(builder, result, name);
}

DiagnosedSilenceableFailure
transform::PrintOp::apply(transform::TransformRewriter &rewriter,
                          transform::TransformResults &results,
                          transform::TransformState &state) {
  llvm::outs() << "[[[ IR printer: ";
  if (getName().has_value())
    llvm::outs() << *getName() << " ";

  OpPrintingFlags printFlags;
  if (getAssumeVerified().value_or(false))
    printFlags.assumeVerified();
  if (getUseLocalScope().value_or(false))
    printFlags.useLocalScope();
  if (getSkipRegions().value_or(false))
    printFlags.skipRegions();

  if (!getTarget()) {
    llvm::outs() << "top-level ]]]\n";
    state.getTopLevel()->print(llvm::outs(), printFlags);
    llvm::outs() << "\n";
    return DiagnosedSilenceableFailure::success();
  }

  llvm::outs() << "]]]\n";
  for (Operation *target : state.getPayloadOps(getTarget())) {
    target->print(llvm::outs(), printFlags);
    llvm::outs() << "\n";
  }

  return DiagnosedSilenceableFailure::success();
}

void transform::PrintOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // We don't really care about mutability here, but `getTarget` now
  // unconditionally casts to a specific type before verification could run
  // here.
  if (!getTargetMutable().empty())
    onlyReadsHandle(getTargetMutable()[0], effects);
  onlyReadsPayload(effects);

  // There is no resource for stderr file descriptor, so just declare print
  // writes into the default resource.
  effects.emplace_back(MemoryEffects::Write::get());
}

//===----------------------------------------------------------------------===//
// VerifyOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::VerifyOp::applyToOne(transform::TransformRewriter &rewriter,
                                Operation *target,
                                transform::ApplyToEachResultList &results,
                                transform::TransformState &state) {
  if (failed(::mlir::verify(target))) {
    DiagnosedDefiniteFailure diag = emitDefiniteFailure()
                                    << "failed to verify payload op";
    diag.attachNote(target->getLoc()) << "payload op";
    return diag;
  }
  return DiagnosedSilenceableFailure::success();
}

void transform::VerifyOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  transform::onlyReadsHandle(getTargetMutable(), effects);
}

//===----------------------------------------------------------------------===//
// YieldOp
//===----------------------------------------------------------------------===//

void transform::YieldOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  onlyReadsHandle(getOperandsMutable(), effects);
}