@@ -178,7 +178,7 @@ struct TMTensorBufferizePass
}
if (isa<TensorType>(inputs[0].getType())) {
// Tensor to MemRef cast.
- return builder.create<bufferization::ToBufferOp>(loc, type, inputs[0]);
+ return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]);
}
llvm_unreachable("only tensor/memref input types supported");
});
@@ -110,7 +110,8 @@ struct TMTensorToLoopsPass : public TMTensorToLoopsBase<TMTensorToLoopsPass> {
RewritePatternSet patterns(context);
patterns.insert<ScalarLoopOpInterfaceLowerToLoopsPattern>(context);
- if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+ if (failed(applyPatternsAndFoldGreedily(getOperation(),
+ std::move(patterns)))) {
return signalPassFailure();
}
}
@@ -164,49 +164,45 @@ class AdjustCallingConventionForReturn
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor,
+ matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Value> newOperands;
- for (const auto &vals : adaptor.getOperands()) {
- if (vals.size() == 1) {
- if (isa<Torch::NoneType>(vals[0].getType()))
- continue;
- newOperands.push_back(vals[0]);
- } else if (vals.size() > 1) {
- // The dialect conversion framework inserts unrealized conversion casts
- // to materialize legal types from illegal types. For example, for input
- // IR like
- // %1 = torch.prim.TupleConstruct %arg0, %arg1 : !torch.tensor,
- // torch.tensor -> !torch.tuple<tensor, tensor>
- // return %1 : !torch.tuple<tensor, tensor>
- // at this stage in the conversion process we'll have something like
- // %1 = torch.prim.TupleConstruct %arg0, %arg1 : !torch.tensor,
- // !torch.tensor -> !torch.tuple<tensor, tensor>
- // %2 = builtin.unrealized_conversion_cast %1 :
- // !torch.tuple<tensor, tensor> to !torch.tensor
- // %3 = builtin.unrealized_conversion_cast %1 :
- // !torch.tuple<tensor, tensor> to !torch.tensor
- // return %2, %3 : !torch.tensor, !torch.tensor
- //
- // Given (%2, %3) as operands, here we map back to the original
- // torch.prim.TupleConstruct.
- if (vals[0].getDefiningOp() &&
- isa<mlir::UnrealizedConversionCastOp>(vals[0].getDefiningOp())) {
- Value operand = vals[0].getDefiningOp()->getOperand(0);
- if (auto tuple = dyn_cast<Torch::TupleType>(operand.getType())) {
- Location loc = op.getLoc();
- for (auto en : llvm::enumerate(tuple.getContainedTypes())) {
- auto i = rewriter.create<ConstantIntOp>(
- loc, rewriter.getI64IntegerAttr(en.index()));
- newOperands.push_back(rewriter.create<PrimTupleIndexOp>(
- loc, en.value(), operand, i));
- }
- continue;
+ for (auto operand : adaptor.getOperands()) {
+ if (isa<Torch::NoneType>(operand.getType()))
+ continue;
+ // The dialect conversion framework inserts unrealized conversion casts
+ // to materialize legal types from illegal types. For example, for input
+ // IR like
+ // %1 = torch.prim.TupleConstruct %arg0, %arg1 : !torch.tensor,
+ // torch.tensor -> !torch.tuple<tensor, tensor>
+ // return %1 : !torch.tuple<tensor, tensor>
+ // at this stage in the conversion process we'll have something like
+ // %1 = torch.prim.TupleConstruct %arg0, %arg1 : !torch.tensor,
+ // !torch.tensor -> !torch.tuple<tensor, tensor>
+ // %2 = builtin.unrealized_conversion_cast %1 :
+ // !torch.tuple<tensor, tensor> to !torch.tensor
+ // %3 = builtin.unrealized_conversion_cast %1 :
+ // !torch.tuple<tensor, tensor> to !torch.tensor
+ // return %2, %3 : !torch.tensor, !torch.tensor
+ //
+ // Given (%2, %3) as operands, here we map back to the original
+ // torch.prim.TupleConstruct.
+ if (operand.getDefiningOp() &&
+ isa<mlir::UnrealizedConversionCastOp>(operand.getDefiningOp())) {
+ Value originalOperand = operand.getDefiningOp()->getOperand(0);
+ if (auto tuple =
+ dyn_cast<Torch::TupleType>(originalOperand.getType())) {
+ Location loc = op.getLoc();
+ for (auto en : llvm::enumerate(tuple.getContainedTypes())) {
+ auto i = rewriter.create<ConstantIntOp>(
+ loc, rewriter.getI64IntegerAttr(en.index()));
+ newOperands.push_back(rewriter.create<PrimTupleIndexOp>(
+ loc, en.value(), originalOperand, i));
}
+ continue;
}
-
- llvm::append_range(newOperands, vals);
}
+ newOperands.push_back(operand);
}
rewriter.replaceOpWithNewOp<func::ReturnOp>(op, newOperands);
@@ -13099,11 +13099,11 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenAsStridedOp>(patterns);
GreedyRewriteConfig config;
- config.setUseTopDownTraversal(true);
- config.setMaxIterations(GreedyRewriteConfig::kNoLimit);
+ config.useTopDownTraversal = true;
+ config.maxIterations = GreedyRewriteConfig::kNoLimit;
- if (failed(applyPatternsGreedily(getOperation(), std::move(patterns),
- config))) {
+ if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
+ config))) {
return signalPassFailure();
}
}
@@ -457,8 +457,8 @@ public:
context);
GreedyRewriteConfig config;
- if (failed(applyPatternsGreedily(getOperation(), std::move(patterns),
- config))) {
+ if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
+ config))) {
return signalPassFailure();
}
}
@@ -564,8 +564,7 @@ static LogicalResult rewriteMonomorphizedFuncClone(
argsToErase.set(type.index());
}
}
- if (failed(func.eraseArguments(argsToErase)))
- return failure();
+ func.eraseArguments(argsToErase);
return success(!walkResult.wasInterrupted());
}
@@ -50,7 +50,7 @@ using namespace mlir::torch::Torch;
/// would probably want to go through the effort to indirect through the symbol
/// tables to make things clearer.
class FlatSymbolRefLatticeAnchor
- : public GenericLatticeAnchorBase<FlatSymbolRefLatticeAnchor, Operation *> {
+ : public GenericProgramPointBase<FlatSymbolRefLatticeAnchor, Operation *> {
public:
using Base::Base;
void print(raw_ostream &os) const override {
@@ -92,7 +92,7 @@ static bool isUseTreatedWithValueSemantics(OpOperand &use) {
/// unsafe
class InlineGlobalSlotsAnalysisState : public AnalysisState {
public:
- InlineGlobalSlotsAnalysisState(LatticeAnchor point) : AnalysisState(point) {
+ InlineGlobalSlotsAnalysisState(ProgramPoint point) : AnalysisState(point) {
(void)setSafe();
}
@@ -132,7 +132,7 @@ class InlineGlobalSlotsAnalysis : public DataFlowAnalysis {
public:
InlineGlobalSlotsAnalysis(DataFlowSolver &solver);
LogicalResult initialize(Operation *top) override;
- LogicalResult visit(ProgramPoint *point) override;
+ LogicalResult visit(ProgramPoint point) override;
private:
/// The local transfer function determining the safety of `value`.
@@ -146,14 +146,14 @@ private:
InlineGlobalSlotsAnalysis::InlineGlobalSlotsAnalysis(DataFlowSolver &solver)
: DataFlowAnalysis(solver) {
- registerAnchorKind<FlatSymbolRefLatticeAnchor>();
+ registerPointKind<FlatSymbolRefLatticeAnchor>();
}
LogicalResult InlineGlobalSlotsAnalysis::initialize(Operation *top) {
auto walkResult = top->walk([this](Operation *op) {
if (auto globalSlot = dyn_cast<Torch::GlobalSlotOp>(op)) {
auto *state = getOrCreate<InlineGlobalSlotsAnalysisState>(
- getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot));
+ getProgramPoint<FlatSymbolRefLatticeAnchor>(globalSlot));
propagateIfChanged(state,
state->setSafe(globalSlot.getVisibility() !=
SymbolTable::Visibility::Public));
@@ -163,14 +163,14 @@ LogicalResult InlineGlobalSlotsAnalysis::initialize(Operation *top) {
globalSlotSet, globalSlotSet.getSlotAttr());
auto *state = getOrCreate<InlineGlobalSlotsAnalysisState>(
- getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot));
+ getProgramPoint<FlatSymbolRefLatticeAnchor>(globalSlot));
propagateIfChanged(state, state->setSafe(false));
}
// Save the InitializeGlobalSlotsOp for later referencee
if (auto initialize = dyn_cast<Torch::InitializeGlobalSlotsOp>(op)) {
initializeGlobalSlotsOp = initialize;
}
- if (failed(visit(getProgramPointAfter(op))))
+ if (failed(visit(ProgramPoint(op))))
return WalkResult::interrupt();
return WalkResult::advance();
@@ -180,11 +180,11 @@ LogicalResult InlineGlobalSlotsAnalysis::initialize(Operation *top) {
return success();
}
-LogicalResult InlineGlobalSlotsAnalysis::visit(ProgramPoint *point) {
- if (point->isBlockStart())
+LogicalResult InlineGlobalSlotsAnalysis::visit(ProgramPoint point) {
+ if (llvm::dyn_cast_if_present<Block *>(point))
return success();
- if (auto op = point->getPrevOp()) {
+ if (auto *op = llvm::dyn_cast_if_present<Operation *>(point)) {
for (auto value : op->getResults()) {
bool isSafe = isValueSafeTransferFunction(value);
auto *state = getOrCreate<InlineGlobalSlotsAnalysisState>(value);
@@ -197,9 +197,9 @@ LogicalResult InlineGlobalSlotsAnalysis::visit(ProgramPoint *point) {
auto globalSlot = SymbolTable::lookupNearestSymbolFrom<GlobalSlotOp>(
globalSlotGet, globalSlotGet.getSlotAttr());
auto *flatSymbolRefPoint =
- getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot);
+ getProgramPoint<FlatSymbolRefLatticeAnchor>(globalSlot);
auto *valueState = getOrCreateFor<InlineGlobalSlotsAnalysisState>(
- getProgramPointAfter(globalSlot), globalSlotGet.getResult());
+ point, globalSlotGet.getResult());
auto *globalState =
getOrCreate<InlineGlobalSlotsAnalysisState>(flatSymbolRefPoint);
propagateIfChanged(globalState,
@@ -226,7 +226,7 @@ bool InlineGlobalSlotsAnalysis::isValueSafeTransferFunction(Value value) {
if ((op->hasTrait<Torch::OpTrait::ReadOnly>() || isMemoryEffectFree(op)) &&
llvm::all_of(op->getResults(), [&](Value result) {
auto *state = getOrCreateFor<InlineGlobalSlotsAnalysisState>(
- getProgramPointAfter(value.getDefiningOp()), result);
+ ProgramPoint(value.getDefiningOp()), result);
return state->isSafe;
}))
continue;
@@ -236,9 +236,9 @@ bool InlineGlobalSlotsAnalysis::isValueSafeTransferFunction(Value value) {
auto globalSlot =
SymbolTable::lookupNearestSymbolFrom<GlobalSlotOp>(op, symName);
- auto *state = getOrCreateFor<InlineGlobalSlotsAnalysisState>(
- getProgramPointAfter(value.getDefiningOp()),
- getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot));
+ auto *state = getOrCreateFor<InlineGlobalSlotsAnalysisState>(
+ ProgramPoint(value.getDefiningOp()),
+ getProgramPoint<FlatSymbolRefLatticeAnchor>(globalSlot));
if (state->isSafe)
continue;
}
@@ -251,9 +251,7 @@ bool InlineGlobalSlotsAnalysis::isValueSafeTransferFunction(Value value) {
SmallVector<Operation *> getBackwardSliceIncludingRoot(Value initialValue) {
SetVector<Operation *> sliceSet;
- [[maybe_unused]] LogicalResult result =
- getBackwardSlice(initialValue, &sliceSet);
- assert(result.succeeded() && "expected a backward slice");
+ getBackwardSlice(initialValue, &sliceSet);
SmallVector<Operation *> slice;
llvm::append_range(slice, sliceSet);
slice.push_back(initialValue.getDefiningOp());
@@ -289,7 +287,7 @@ class InlineGlobalSlotsPass
module->walk([&](Operation *op) {
if (auto globalSlot = dyn_cast<Torch::GlobalSlotOp>(op)) {
auto *state = solver.lookupState<InlineGlobalSlotsAnalysisState>(
- solver.getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot));
+ solver.getProgramPoint<FlatSymbolRefLatticeAnchor>(globalSlot));
state->print(llvm::dbgs());
llvm::dbgs() << ": "
<< FlatSymbolRefAttr::get(globalSlot.getSymNameAttr())
@@ -327,7 +325,7 @@ class InlineGlobalSlotsPass
initialize, slotSymName);
auto symbolRefPoint =
- solver.getLatticeAnchor<FlatSymbolRefLatticeAnchor>(globalSlot);
+ solver.getProgramPoint<FlatSymbolRefLatticeAnchor>(globalSlot);
auto *state =
solver.lookupState<InlineGlobalSlotsAnalysisState>(symbolRefPoint);
// We roll the analysis of whether a slot is set or public into the
@@ -122,8 +122,8 @@ public:
patterns.insert<MatchQuantizeOperator>(context);
GreedyRewriteConfig config;
- if (failed(
- applyPatternsGreedily(getOperation(), std::move(patterns), config)))
+ if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
+ config)))
return signalPassFailure();
}
};
@@ -372,7 +372,7 @@ class MaximizeValueSemanticsPass
RewritePatternSet patterns(context);
patterns.insert<AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock,
RewriteViewLikeSubgraph>(context);
- (void)applyPatternsGreedily(func, std::move(patterns));
+ (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}
};
@@ -75,13 +75,14 @@ class PrepareForGlobalizeObjectGraphPass
func::CallIndirectOp::getCanonicalizationPatterns(patterns, context);
patterns.add<EraseUnusedConstantOp>(context);
- // Use applyPatternsGreedily because the CallIndirectOp folding
+ // Use applyPatternsAndFoldGreedily because the CallIndirectOp folding
// makes the ConstantOp unused, which does not work with the visitation
// order of the dialect conversion infrastructure.
// TODO: Do this with the dialect conversion infrastructure to avoid doing
// folding as part of this. Or avoid folding during greedy pattern
// application. See: https://llvm.org/PR49502
- if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+ if (failed(applyPatternsAndFoldGreedily(getOperation(),
+ std::move(patterns)))) {
return signalPassFailure();
}
@@ -820,11 +820,11 @@ public:
patterns.add<RecomposeMeshgridIndexingListUnpack>(context);
GreedyRewriteConfig config;
- config.setUseTopDownTraversal(true);
- config.setMaxIterations(GreedyRewriteConfig::kNoLimit);
+ config.useTopDownTraversal = true;
+ config.maxIterations = GreedyRewriteConfig::kNoLimit;
- if (failed(applyPatternsGreedily(getOperation(), std::move(patterns),
- config))) {
+ if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
+ config))) {
return signalPassFailure();
}
}
@@ -261,10 +261,10 @@ public:
// TODO: Debug visitation order to make this more efficient.
// A single linear scan should suffice.
GreedyRewriteConfig config;
- config.setUseTopDownTraversal(true);
- config.setMaxIterations(GreedyRewriteConfig::kNoLimit);
- if (failed(applyPatternsGreedily(getOperation(), std::move(patterns),
- config))) {
+ config.useTopDownTraversal = true;
+ config.maxIterations = GreedyRewriteConfig::kNoLimit;
+ if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
+ config))) {
return signalPassFailure();
}
}
@@ -1601,9 +1601,10 @@ public:
// When propagating, we need to go back and clean up aten.Tensor ops that
// have been futher propagated. It is also necessary to add newly created
// ops for custom folding after scalarizing a where.self op.
- config.setStrictness(GreedyRewriteStrictness::ExistingAndNewOps);
- if (failed(applyOpPatternsGreedily(shapeCalculationOps.getArrayRef(),
- std::move(patterns), config))) {
+ config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
+ FrozenRewritePatternSet frozenPatterns(std::move(patterns));
+ if (failed(applyOpPatternsAndFold(shapeCalculationOps.getArrayRef(),
+ frozenPatterns, config))) {
return signalPassFailure();
}
@@ -211,10 +211,10 @@ class SimplifyDtypeCalculationsPass
// TODO: Debug visitation order to make this more efficient.
// A single linear scan should suffice.
GreedyRewriteConfig config;
- config.setUseTopDownTraversal(true);
- config.setMaxIterations(GreedyRewriteConfig::kNoLimit);
- if (failed(applyPatternsGreedily(getOperation(), std::move(patterns),
- config))) {
+ config.useTopDownTraversal = true;
+ config.maxIterations = GreedyRewriteConfig::kNoLimit;
+ if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
+ config))) {
return signalPassFailure();
}
}
@@ -207,10 +207,10 @@ class SimplifyShapeCalculationsPass
// TODO: Debug visitation order to make this more efficient.
// A single linear scan should suffice.
GreedyRewriteConfig config;
- config.setUseTopDownTraversal(true);
- config.setMaxIterations(GreedyRewriteConfig::kNoLimit);
- if (failed(applyPatternsGreedily(getOperation(), std::move(patterns),
- config))) {
+ config.useTopDownTraversal = true;
+ config.maxIterations = GreedyRewriteConfig::kNoLimit;
+ if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
+ config))) {
return signalPassFailure();
}
}
@@ -232,7 +232,7 @@ struct FinalizingBackendTypeConversionPass
RewritePatternSet greedyPatterns(context);
greedyPatterns.insert<ExtFTruncFPattern>(context);
- if (failed(applyPatternsGreedily(func, std::move(greedyPatterns))))
+ if (failed(applyPatternsAndFoldGreedily(func, std::move(greedyPatterns))))
signalPassFailure();
// Drop attributes that are no longer used after conversion out of Torch.
@@ -100,12 +100,12 @@ public:
std::vector<APInt> newData(data.size() * packRatio,
APInt(unpackedBitWidth, 0));
for (int i = 0, e = data.size(); i < e; ++i) {
- auto el = data[i];
+ auto el = static_cast<uint8_t>(data[i]);
uint8_t mask = (1 << unpackedBitWidth) - 1;
for (int b = 0; b < packRatio; b++) {
- newData[i * packRatio + b] =
- APInt(unpackedBitWidth, (el & mask) >> (unpackedBitWidth * b),
- /*isSigned=*/false, /*implicitTrunc=*/true);
+ uint64_t value = (static_cast<uint64_t>(el & mask) >>
+ (unpackedBitWidth * b));
+ newData[i * packRatio + b] = APInt(unpackedBitWidth, value);
mask = mask << unpackedBitWidth;
}
}
@@ -131,7 +131,8 @@ class UnpackQuantTensorPass
RewritePatternSet patterns(context);
patterns.add<UnpackQuantizedMatmulWeights>(context);
- if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+ if (failed(
+ applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
signalPassFailure();
}
};
@@ -425,7 +425,8 @@ class MungeMemrefCopy : public MungeMemrefCopyBase<MungeMemrefCopy> {
MLIRContext *context = &getContext();
RewritePatternSet patterns(&getContext());
patterns.insert<MemrefCopyOpToLinalg>(context);
- if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+ if (failed(applyPatternsAndFoldGreedily(getOperation(),
+ std::move(patterns)))) {
return signalPassFailure();
}
}
@@ -447,7 +448,8 @@ class GeneralizeTensorConcat
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
tensor::populateDecomposeTensorConcatPatterns(patterns);
- if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+ if (failed(applyPatternsAndFoldGreedily(getOperation(),
+ std::move(patterns)))) {
return signalPassFailure();
}
}
@@ -470,7 +472,8 @@ class GeneralizeTensorPad
MLIRContext *context = &getContext();
RewritePatternSet patterns(&getContext());
patterns.insert<linalg::DecomposePadOpPattern>(context);
- if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+ if (failed(applyPatternsAndFoldGreedily(getOperation(),
+ std::move(patterns)))) {
return signalPassFailure();
}
}