#include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
#include "PredicateTree.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTPDLTOPDLINTERP
#include "mlir/Conversion/Passes.h.inc"
}
using namespace mlir;
using namespace mlir::pdl_to_pdl_interp;
namespace {
struct PatternLowering {
public:
PatternLowering(pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule,
DenseMap<Operation *, PDLPatternConfigSet *> *configMap);
void lower(ModuleOp module);
private:
using ValueMap = llvm::ScopedHashTable<Position *, Value>;
using ValueMapScope = llvm::ScopedHashTableScope<Position *, Value>;
Block *generateMatcher(MatcherNode &node, Region ®ion,
Block *block = nullptr);
Value getValueAt(Block *¤tBlock, Position *pos);
void generate(BoolNode *boolNode, Block *¤tBlock, Value val);
void generate(SwitchNode *switchNode, Block *currentBlock, Value val);
void generate(SuccessNode *successNode, Block *¤tBlock);
SymbolRefAttr generateRewriter(pdl::PatternOp pattern,
SmallVectorImpl<Position *> &usedMatchValues);
void generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp,
DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue);
void generateRewriter(pdl::AttributeOp attrOp,
DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue);
void generateRewriter(pdl::EraseOp eraseOp,
DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue);
void generateRewriter(pdl::OperationOp operationOp,
DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue);
void generateRewriter(pdl::RangeOp rangeOp,
DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue);
void generateRewriter(pdl::ReplaceOp replaceOp,
DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue);
void generateRewriter(pdl::ResultOp resultOp,
DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue);
void generateRewriter(pdl::ResultsOp resultOp,
DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue);
void generateRewriter(pdl::TypeOp typeOp,
DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue);
void generateRewriter(pdl::TypesOp typeOp,
DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue);
void generateOperationResultTypeRewriter(
pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue,
SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues,
bool &hasInferredResultTypes);
OpBuilder builder;
pdl_interp::FuncOp matcherFunc;
ModuleOp rewriterModule;
SymbolTable rewriterSymbolTable;
ValueMap values;
SmallVector<Block *, 8> failureBlockStack;
DenseMap<Value, Position *> valueToPosition;
SetVector<Value> locOps;
DenseMap<Operation *, PDLPatternConfigSet *> *configMap;
DenseMap<ConstraintQuestion *, pdl_interp::ApplyConstraintOp> constraintOpMap;
};
}
/// Initialize the lowering state. The builder uses the matcher function's
/// context, and the rewriter symbol table is built over `rewriters`.
PatternLowering::PatternLowering(
    pdl_interp::FuncOp func, ModuleOp rewriters,
    DenseMap<Operation *, PDLPatternConfigSet *> *patternConfigs)
    : builder(func.getContext()), matcherFunc(func),
      rewriterModule(rewriters), rewriterSymbolTable(rewriters),
      configMap(patternConfigs) {}
/// Build the predicate tree for every pattern in `module` and emit the
/// resulting matcher into the entry block of the matcher function.
void PatternLowering::lower(ModuleOp module) {
  PredicateUniquer uniquer;
  PredicateBuilder predBuilder(uniquer, module.getContext());

  // Outermost scope of the position -> value cache. The root position is the
  // matcher function's first argument.
  ValueMapScope rootScope(values);
  Block *entryBlock = &matcherFunc.front();
  values.insert(predBuilder.getRoot(), entryBlock->getArgument(0));

  std::unique_ptr<MatcherNode> tree =
      MatcherNode::generateMatcherTree(module, predBuilder, valueToPosition);

  // The matcher is generated into a fresh block of the body; its operations
  // are then moved into the entry block and the fresh block is discarded.
  Block *generated = generateMatcher(*tree, matcherFunc.getBody());
  assert(failureBlockStack.empty() && "failed to empty the stack");
  entryBlock->getOperations().splice(entryBlock->end(),
                                     generated->getOperations());
  generated->erase();
}
/// Generate the matcher code for `node` and its subtrees.
///
/// Code is emitted into `block`, or into a new block appended to `region` when
/// `block` is null; that block is returned. If the node has a failure subtree,
/// it is generated first into its own block of `region` and pushed onto
/// `failureBlockStack` for the duration of this node.
Block *PatternLowering::generateMatcher(MatcherNode &node, Region &region,
                                        Block *block) {
  if (!block)
    block = &region.emplaceBlock();

  // Values materialized for this node are only cached within this scope.
  ValueMapScope scope(values);

  // An exit node just finalizes the matcher.
  if (isa<ExitNode>(node)) {
    builder.setInsertionPointToEnd(block);
    builder.create<pdl_interp::FinalizeOp>(matcherFunc.getLoc());
    return block;
  }

  // Determine the failure destination before fetching this node's value:
  // `getValueAt` may emit a `pdl_interp.foreach` that uses the current top of
  // `failureBlockStack` as its successor.
  std::unique_ptr<MatcherNode> &failureNode = node.getFailureNode();
  Block *failureBlock;
  if (failureNode) {
    failureBlock = generateMatcher(*failureNode, region);
    failureBlockStack.push_back(failureBlock);
  } else {
    assert(!failureBlockStack.empty() && "expected valid failure block");
    failureBlock = failureBlockStack.back();
  }

  // Fetch the value for the node's position, if it has one. This may move
  // `currentBlock` into a nested region.
  Block *currentBlock = block;
  Position *position = node.getPosition();
  Value val = position ? getValueAt(currentBlock, position) : Value();

  // Operation values in scope contribute to the recorded match locations.
  bool isOperationValue = val && isa<pdl::OperationType>(val.getType());
  if (isOperationValue)
    locOps.insert(val);

  // Dispatch on the concrete node kind.
  TypeSwitch<MatcherNode *>(&node)
      .Case<BoolNode, SwitchNode>([&](auto *derivedNode) {
        this->generate(derivedNode, currentBlock, val);
      })
      .Case([&](SuccessNode *successNode) {
        generate(successNode, currentBlock);
      });

  // Pop any continuation blocks pushed by `getValueAt` for nested
  // `pdl_interp.foreach` loops, down to this node's failure block.
  while (failureBlockStack.back() != failureBlock) {
    failureBlockStack.pop_back();
    assert(!failureBlockStack.empty() && "unable to locate failure block");
  }

  // Pop the failure block pushed by this node.
  if (failureNode)
    failureBlockStack.pop_back();

  if (isOperationValue)
    locOps.remove(val);

  return block;
}
/// Return the matcher value for position `pos`.
///
/// If the value is not already cached, its parent position is resolved first
/// (recursively), and the accessor operations are emitted at the end of
/// `currentBlock`. The result is then cached in the current scope.
///
/// For a `ForEachPos`, a `pdl_interp.foreach` is emitted, a continuation block
/// is pushed onto `failureBlockStack`, and `currentBlock` is updated to the
/// loop's entry block so that callers keep emitting code inside the loop.
Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
  if (Value val = values.lookup(pos))
    return val;

  // Resolve the parent position; its value feeds the accessor for `pos`.
  Value parentVal;
  if (Position *parent = pos->getParent())
    parentVal = getValueAt(currentBlock, parent);

  // Reuse the parent's location when there is one.
  Location loc = parentVal ? parentVal.getLoc() : builder.getUnknownLoc();
  builder.setInsertionPointToEnd(currentBlock);
  Value value;
  switch (pos->getKind()) {
  case Predicates::OperationPos: {
    auto *operationPos = cast<OperationPosition>(pos);
    // Either follow an operand to its defining operation, or forward the
    // parent value unchanged.
    if (operationPos->isOperandDefiningOp())
      value = builder.create<pdl_interp::GetDefiningOpOp>(
          loc, builder.getType<pdl::OperationType>(), parentVal);
    else
      value = parentVal;
    break;
  }
  case Predicates::UsersPos: {
    auto *usersPos = cast<UsersPosition>(pos);
    // For a range parent with a representative requested, query the users
    // of the range's first element.
    if (isa<pdl::RangeType>(parentVal.getType()) &&
        usersPos->useRepresentative())
      value = builder.create<pdl_interp::ExtractOp>(loc, parentVal, 0);
    else
      value = parentVal;
    value = builder.create<pdl_interp::GetUsersOp>(loc, value);
    break;
  }
  case Predicates::ForEachPos: {
    assert(!failureBlockStack.empty() && "expected valid failure block");
    // Iterate over the parent range, with the current failure block as the
    // loop's successor.
    auto foreach = builder.create<pdl_interp::ForEachOp>(
        loc, parentVal, failureBlockStack.back(), true);
    value = foreach.getLoopVariable();

    // Add a block that continues to the next iteration; it becomes the
    // failure destination for code nested inside the loop.
    Block *continueBlock = builder.createBlock(&foreach.getRegion());
    builder.create<pdl_interp::ContinueOp>(loc);
    failureBlockStack.push_back(continueBlock);

    // Subsequent code is emitted into the loop's entry block.
    currentBlock = &foreach.getRegion().front();
    break;
  }
  case Predicates::OperandPos: {
    auto *operandPos = cast<OperandPosition>(pos);
    value = builder.create<pdl_interp::GetOperandOp>(
        loc, builder.getType<pdl::ValueType>(), parentVal,
        operandPos->getOperandNumber());
    break;
  }
  case Predicates::OperandGroupPos: {
    auto *operandPos = cast<OperandGroupPosition>(pos);
    Type valueTy = builder.getType<pdl::ValueType>();
    value = builder.create<pdl_interp::GetOperandsOp>(
        loc, operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
        parentVal, operandPos->getOperandGroupNumber());
    break;
  }
  case Predicates::AttributePos: {
    auto *attrPos = cast<AttributePosition>(pos);
    value = builder.create<pdl_interp::GetAttributeOp>(
        loc, builder.getType<pdl::AttributeType>(), parentVal,
        attrPos->getName().strref());
    break;
  }
  case Predicates::TypePos: {
    // The type of an attribute or of a value, depending on the parent.
    if (isa<pdl::AttributeType>(parentVal.getType()))
      value = builder.create<pdl_interp::GetAttributeTypeOp>(loc, parentVal);
    else
      value = builder.create<pdl_interp::GetValueTypeOp>(loc, parentVal);
    break;
  }
  case Predicates::ResultPos: {
    auto *resPos = cast<ResultPosition>(pos);
    value = builder.create<pdl_interp::GetResultOp>(
        loc, builder.getType<pdl::ValueType>(), parentVal,
        resPos->getResultNumber());
    break;
  }
  case Predicates::ResultGroupPos: {
    auto *resPos = cast<ResultGroupPosition>(pos);
    Type valueTy = builder.getType<pdl::ValueType>();
    value = builder.create<pdl_interp::GetResultsOp>(
        loc, resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
        parentVal, resPos->getResultGroupNumber());
    break;
  }
  case Predicates::AttributeLiteralPos: {
    auto *attrPos = cast<AttributeLiteralPosition>(pos);
    value =
        builder.create<pdl_interp::CreateAttributeOp>(loc, attrPos->getValue());
    break;
  }
  case Predicates::TypeLiteralPos: {
    // A literal is either a single type or an array of types.
    auto *typePos = cast<TypeLiteralPosition>(pos);
    Attribute rawTypeAttr = typePos->getValue();
    if (TypeAttr typeAttr = dyn_cast<TypeAttr>(rawTypeAttr))
      value = builder.create<pdl_interp::CreateTypeOp>(loc, typeAttr);
    else
      value = builder.create<pdl_interp::CreateTypesOp>(
          loc, cast<ArrayAttr>(rawTypeAttr));
    break;
  }
  case Predicates::ConstraintResultPos: {
    // The `apply_constraint` op for this question must already have been
    // emitted (recorded in `constraintOpMap`); read the requested result.
    auto *constrResPos = cast<ConstraintPosition>(pos);
    auto i = constraintOpMap.find(constrResPos->getQuestion());
    assert(i != constraintOpMap.end());
    value = i->second->getResult(constrResPos->getIndex());
    break;
  }
  default:
    llvm_unreachable("Generating unknown Position getter");
    break;
  }

  values.insert(pos, value);
  return value;
}
void PatternLowering::generate(BoolNode *boolNode, Block *¤tBlock,
Value val) {
Location loc = val.getLoc();
Qualifier *question = boolNode->getQuestion();
Qualifier *answer = boolNode->getAnswer();
Region *region = currentBlock->getParent();
SmallVector<Value> args;
if (auto *equalToQuestion = dyn_cast<EqualToQuestion>(question)) {
args = {getValueAt(currentBlock, equalToQuestion->getValue())};
} else if (auto *cstQuestion = dyn_cast<ConstraintQuestion>(question)) {
for (Position *position : cstQuestion->getArgs())
args.push_back(getValueAt(currentBlock, position));
}
Block *success = ®ion->emplaceBlock();
Block *failure = failureBlockStack.back();
builder.setInsertionPointToEnd(currentBlock);
Predicates::Kind kind = question->getKind();
switch (kind) {
case Predicates::IsNotNullQuestion:
builder.create<pdl_interp::IsNotNullOp>(loc, val, success, failure);
break;
case Predicates::OperationNameQuestion: {
auto *opNameAnswer = cast<OperationNameAnswer>(answer);
builder.create<pdl_interp::CheckOperationNameOp>(
loc, val, opNameAnswer->getValue().getStringRef(), success, failure);
break;
}
case Predicates::TypeQuestion: {
auto *ans = cast<TypeAnswer>(answer);
if (isa<pdl::RangeType>(val.getType()))
builder.create<pdl_interp::CheckTypesOp>(
loc, val, llvm::cast<ArrayAttr>(ans->getValue()), success, failure);
else
builder.create<pdl_interp::CheckTypeOp>(
loc, val, llvm::cast<TypeAttr>(ans->getValue()), success, failure);
break;
}
case Predicates::AttributeQuestion: {
auto *ans = cast<AttributeAnswer>(answer);
builder.create<pdl_interp::CheckAttributeOp>(loc, val, ans->getValue(),
success, failure);
break;
}
case Predicates::OperandCountAtLeastQuestion:
case Predicates::OperandCountQuestion:
builder.create<pdl_interp::CheckOperandCountOp>(
loc, val, cast<UnsignedAnswer>(answer)->getValue(),
kind == Predicates::OperandCountAtLeastQuestion,
success, failure);
break;
case Predicates::ResultCountAtLeastQuestion:
case Predicates::ResultCountQuestion:
builder.create<pdl_interp::CheckResultCountOp>(
loc, val, cast<UnsignedAnswer>(answer)->getValue(),
kind == Predicates::ResultCountAtLeastQuestion,
success, failure);
break;
case Predicates::EqualToQuestion: {
bool trueAnswer = isa<TrueAnswer>(answer);
builder.create<pdl_interp::AreEqualOp>(loc, val, args.front(),
trueAnswer ? success : failure,
trueAnswer ? failure : success);
break;
}
case Predicates::ConstraintQuestion: {
auto *cstQuestion = cast<ConstraintQuestion>(question);
auto applyConstraintOp = builder.create<pdl_interp::ApplyConstraintOp>(
loc, cstQuestion->getResultTypes(), cstQuestion->getName(), args,
cstQuestion->getIsNegated(), success, failure);
constraintOpMap.insert({cstQuestion, applyConstraintOp});
break;
}
default:
llvm_unreachable("Generating unknown Predicate operation");
}
generateMatcher(*boolNode->getSuccessNode(), *region, success);
}
/// Emit a switch-style interpreter op `OpT` on `val`. Each entry of `dests`
/// contributes one case: the value of its `PredT` answer and its destination
/// block. `defaultDest` is taken when no case matches.
template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy>
static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder,
                           llvm::MapVector<Qualifier *, Block *> &dests) {
  std::vector<ValT> caseValues;
  std::vector<Block *> caseBlocks;
  caseValues.reserve(dests.size());
  caseBlocks.reserve(dests.size());
  for (const auto &[answer, dest] : dests) {
    caseValues.push_back(cast<PredT>(answer)->getValue());
    caseBlocks.push_back(dest);
  }
  builder.create<OpT>(val.getLoc(), val, caseValues, defaultDest, caseBlocks);
}
/// Emit dispatch for a switch node at the end of `currentBlock`. Each child
/// subtree is generated into its own block of `currentBlock`'s region.
/// Unmatched cases go to the failure block on top of `failureBlockStack`.
///
/// The "at least" operand/result count questions cannot be expressed as an
/// exact-value switch, so they are lowered to a chain of count checks instead.
void PatternLowering::generate(SwitchNode *switchNode, Block *currentBlock,
Value val) {
Qualifier *question = switchNode->getQuestion();
Region *region = currentBlock->getParent();
Block *defaultDest = failureBlockStack.back();
Predicates::Kind kind = question->getKind();
if (kind == Predicates::OperandCountAtLeastQuestion ||
kind == Predicates::ResultCountAtLeastQuestion) {
// Visit the children in decreasing order of their count answer.
SmallVector<unsigned> sortedChildren = llvm::to_vector<16>(
llvm::seq<unsigned>(0, switchNode->getChildren().size()));
llvm::sort(sortedChildren, [&](unsigned lhs, unsigned rhs) {
return cast<UnsignedAnswer>(switchNode->getChild(lhs).first)->getValue() >
cast<UnsignedAnswer>(switchNode->getChild(rhs).first)->getValue();
});
// The first (largest-count) child falls back to the default destination.
failureBlockStack.push_back(defaultDest);
Location loc = val.getLoc();
for (unsigned idx : sortedChildren) {
auto &child = switchNode->getChild(idx);
// Generate the child's matcher (failing to the current top of the stack).
// Then create a predicate block just before it, which checks
// "count >= ans" and branches to the child or to `defaultDest`.
Block *childBlock = generateMatcher(*child.second, *region);
Block *predicateBlock = builder.createBlock(childBlock);
builder.setInsertionPointToEnd(predicateBlock);
unsigned ans = cast<UnsignedAnswer>(child.first)->getValue();
switch (kind) {
case Predicates::OperandCountAtLeastQuestion:
builder.create<pdl_interp::CheckOperandCountOp>(
loc, val, ans, true, childBlock, defaultDest);
break;
case Predicates::ResultCountAtLeastQuestion:
builder.create<pdl_interp::CheckResultCountOp>(
loc, val, ans, true, childBlock, defaultDest);
break;
default:
llvm_unreachable("Generating invalid AtLeast operation");
}
// The next (smaller-count) child falls back to this predicate. A failed
// match for a smaller threshold therefore still tries the larger ones.
failureBlockStack.back() = predicateBlock;
}
// The last predicate generated (smallest count) starts the chain. Move its
// check into `currentBlock` and drop the now-empty block.
Block *firstPredicateBlock = failureBlockStack.pop_back_val();
currentBlock->getOperations().splice(currentBlock->end(),
firstPredicateBlock->getOperations());
firstPredicateBlock->erase();
return;
}
// Exact-value switch: generate every child's matcher keyed by its answer
// (MapVector keeps the children's order), then emit the switch op.
llvm::MapVector<Qualifier *, Block *> children;
for (auto &it : switchNode->getChildren())
children.insert({it.first, generateMatcher(*it.second, *region)});
builder.setInsertionPointToEnd(currentBlock);
switch (question->getKind()) {
case Predicates::OperandCountQuestion:
return createSwitchOp<pdl_interp::SwitchOperandCountOp, UnsignedAnswer,
int32_t>(val, defaultDest, builder, children);
case Predicates::ResultCountQuestion:
return createSwitchOp<pdl_interp::SwitchResultCountOp, UnsignedAnswer,
int32_t>(val, defaultDest, builder, children);
case Predicates::OperationNameQuestion:
return createSwitchOp<pdl_interp::SwitchOperationNameOp,
OperationNameAnswer>(val, defaultDest, builder,
children);
case Predicates::TypeQuestion:
// A range of types switches over type arrays; a single type over types.
if (isa<pdl::RangeType>(val.getType())) {
return createSwitchOp<pdl_interp::SwitchTypesOp, TypeAnswer>(
val, defaultDest, builder, children);
}
return createSwitchOp<pdl_interp::SwitchTypeOp, TypeAnswer>(
val, defaultDest, builder, children);
case Predicates::AttributeQuestion:
return createSwitchOp<pdl_interp::SwitchAttributeOp, AttributeAnswer>(
val, defaultDest, builder, children);
default:
llvm_unreachable("Generating unknown switch predicate.");
}
}
void PatternLowering::generate(SuccessNode *successNode, Block *¤tBlock) {
pdl::PatternOp pattern = successNode->getPattern();
Value root = successNode->getRoot();
SmallVector<Position *, 8> usedMatchValues;
SymbolRefAttr rewriterFuncRef = generateRewriter(pattern, usedMatchValues);
std::vector<Value> mappedMatchValues;
mappedMatchValues.reserve(usedMatchValues.size());
for (Position *position : usedMatchValues)
mappedMatchValues.push_back(getValueAt(currentBlock, position));
SmallVector<StringRef, 4> generatedOps;
for (auto op :
pattern.getRewriter().getBodyRegion().getOps<pdl::OperationOp>())
generatedOps.push_back(*op.getOpName());
ArrayAttr generatedOpsAttr;
if (!generatedOps.empty())
generatedOpsAttr = builder.getStrArrayAttr(generatedOps);
StringAttr rootKindAttr;
if (pdl::OperationOp rootOp = root.getDefiningOp<pdl::OperationOp>())
if (std::optional<StringRef> rootKind = rootOp.getOpName())
rootKindAttr = builder.getStringAttr(*rootKind);
builder.setInsertionPointToEnd(currentBlock);
auto matchOp = builder.create<pdl_interp::RecordMatchOp>(
pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(),
rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.getBenefitAttr(),
failureBlockStack.back());
if (configMap)
configMap->try_emplace(matchOp, configMap->lookup(pattern));
}
/// Generate a `pdl_interp.func` rewriter for `pattern` inside `rewriterModule`
/// and return a symbol reference to it (nested under the rewriter module name).
/// Every PDL value the rewriter needs from the matcher becomes a new argument
/// of the function, and its position is appended to `usedMatchValues` in
/// argument order.
SymbolRefAttr PatternLowering::generateRewriter(
pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) {
builder.setInsertionPointToEnd(rewriterModule.getBody());
auto rewriterFunc = builder.create<pdl_interp::FuncOp>(
pattern.getLoc(), "pdl_generated_rewriter",
builder.getFunctionType(std::nullopt, std::nullopt));
rewriterSymbolTable.insert(rewriterFunc);
builder.setInsertionPointToEnd(&rewriterFunc.front());
// Map from PDL values to the interpreter values generated for them.
DenseMap<Value, Value> rewriteValues;
// Return the interpreter value for `oldValue`, creating it on first use:
// constant attributes/types are materialized directly; any other value must
// come from the matcher and is added as a new function argument.
// NOTE(review): assumes `oldValue` always has a defining op (dyn_cast on a
// null op would assert).
auto mapRewriteValue = [&](Value oldValue) {
Value &newValue = rewriteValues[oldValue];
if (newValue)
return newValue;
Operation *oldOp = oldValue.getDefiningOp();
if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) {
if (Attribute value = attrOp.getValueAttr()) {
return newValue = builder.create<pdl_interp::CreateAttributeOp>(
attrOp.getLoc(), value);
}
} else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) {
if (TypeAttr type = typeOp.getConstantTypeAttr()) {
return newValue = builder.create<pdl_interp::CreateTypeOp>(
typeOp.getLoc(), type);
}
} else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) {
if (ArrayAttr type = typeOp.getConstantTypesAttr()) {
return newValue = builder.create<pdl_interp::CreateTypesOp>(
typeOp.getLoc(), typeOp.getType(), type);
}
}
Position *inputPos = valueToPosition.lookup(oldValue);
assert(inputPos && "expected value to be a pattern input");
usedMatchValues.push_back(inputPos);
return newValue = rewriterFunc.front().addArgument(oldValue.getType(),
oldValue.getLoc());
};
pdl::RewriteOp rewriter = pattern.getRewriter();
// A named rewrite is forwarded to an external `apply_rewrite` with the
// (optional) root followed by the external arguments.
if (StringAttr rewriteName = rewriter.getNameAttr()) {
SmallVector<Value> args;
if (rewriter.getRoot())
args.push_back(mapRewriteValue(rewriter.getRoot()));
auto mappedArgs =
llvm::map_range(rewriter.getExternalArgs(), mapRewriteValue);
args.append(mappedArgs.begin(), mappedArgs.end());
builder.create<pdl_interp::ApplyRewriteOp>(
rewriter.getLoc(), TypeRange(), rewriteName, args);
} else {
// Otherwise lower each supported PDL op of the rewrite body in order; ops
// not listed here are skipped by the TypeSwitch.
for (Operation &rewriteOp : *rewriter.getBody()) {
llvm::TypeSwitch<Operation *>(&rewriteOp)
.Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp,
pdl::OperationOp, pdl::RangeOp, pdl::ReplaceOp, pdl::ResultOp,
pdl::ResultsOp, pdl::TypeOp, pdl::TypesOp>([&](auto op) {
this->generateRewriter(op, rewriteValues, mapRewriteValue);
});
}
}
// The arguments are only known after lowering, so fix up the signature now.
rewriterFunc.setType(builder.getFunctionType(
llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()),
std::nullopt));
builder.create<pdl_interp::FinalizeOp>(rewriter.getLoc());
return SymbolRefAttr::get(
builder.getContext(),
pdl_interp::PDLInterpDialect::getRewriterModuleName(),
SymbolRefAttr::get(rewriterFunc));
}
/// Lower `pdl.apply_native_rewrite` to `pdl_interp.apply_rewrite` with the
/// mapped arguments. Each PDL result is mapped to the matching interpreter
/// result.
void PatternLowering::generateRewriter(
    pdl::ApplyNativeRewriteOp rewriteOp, DenseMap<Value, Value> &rewriteValues,
    function_ref<Value(Value)> mapRewriteValue) {
  auto mappedArgs = llvm::map_range(rewriteOp.getArgs(), mapRewriteValue);
  SmallVector<Value, 2> interpArgs(mappedArgs.begin(), mappedArgs.end());
  auto applyOp = builder.create<pdl_interp::ApplyRewriteOp>(
      rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.getNameAttr(),
      interpArgs);
  for (auto [pdlResult, interpResult] :
       llvm::zip(rewriteOp.getResults(), applyOp.getResults()))
    rewriteValues[pdlResult] = interpResult;
}
/// Lower `pdl.attribute` to `pdl_interp.create_attribute` with the same value
/// attribute, and record it as the attribute's rewrite value.
void PatternLowering::generateRewriter(
    pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues,
    function_ref<Value(Value)> mapRewriteValue) {
  rewriteValues[attrOp] = builder.create<pdl_interp::CreateAttributeOp>(
      attrOp.getLoc(), attrOp.getValueAttr());
}
/// Lower `pdl.erase` to `pdl_interp.erase` on the mapped operation value.
void PatternLowering::generateRewriter(
    pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues,
    function_ref<Value(Value)> mapRewriteValue) {
  Value erasedOp = mapRewriteValue(eraseOp.getOpValue());
  builder.create<pdl_interp::EraseOp>(eraseOp.getLoc(), erasedOp);
}
/// Lower `pdl.operation` in a rewriter to `pdl_interp.create_operation`.
/// Operands and attributes are mapped, and result types are computed by
/// `generateOperationResultTypeRewriter`. Afterwards, any result type value
/// not yet mapped is bound to the type of the corresponding result of the
/// created operation.
void PatternLowering::generateRewriter(
pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue) {
SmallVector<Value, 4> operands;
for (Value operand : operationOp.getOperandValues())
operands.push_back(mapRewriteValue(operand));
SmallVector<Value, 4> attributes;
for (Value attr : operationOp.getAttributeValues())
attributes.push_back(mapRewriteValue(attr));
bool hasInferredResultTypes = false;
SmallVector<Value, 2> types;
generateOperationResultTypeRewriter(operationOp, mapRewriteValue, types,
rewriteValues, hasInferredResultTypes);
Location loc = operationOp.getLoc();
// NOTE(review): dereferences the op name; assumes rewriter-side ops are named.
Value createdOp = builder.create<pdl_interp::CreateOperationOp>(
loc, *operationOp.getOpName(), types, hasInferredResultTypes, operands,
attributes, operationOp.getAttributeValueNames());
rewriteValues[operationOp.getOp()] = createdOp;
// A single range of result types maps to the types of all results.
OperandRange resultTys = operationOp.getTypeValues();
if (resultTys.size() == 1 && isa<pdl::RangeType>(resultTys[0].getType())) {
Value &type = rewriteValues[resultTys[0]];
if (!type) {
auto results = builder.create<pdl_interp::GetResultsOp>(loc, createdOp);
type = builder.create<pdl_interp::GetValueTypeOp>(loc, results);
}
return;
}
// Otherwise map each unmapped result type individually. Once a variadic
// result has been seen, later indices are accessed as result groups.
// NOTE(review): `seenVariableLength` is only updated for types that were not
// already mapped (the `continue` precedes the update). Confirm whether a
// pre-mapped variadic result type should still switch later results to group
// access.
bool seenVariableLength = false;
Type valueTy = builder.getType<pdl::ValueType>();
Type valueRangeTy = pdl::RangeType::get(valueTy);
for (const auto &it : llvm::enumerate(resultTys)) {
Value &type = rewriteValues[it.value()];
if (type)
continue;
bool isVariadic = isa<pdl::RangeType>(it.value().getType());
seenVariableLength |= isVariadic;
Value resultVal;
if (seenVariableLength)
resultVal = builder.create<pdl_interp::GetResultsOp>(
loc, isVariadic ? valueRangeTy : valueTy, createdOp, it.index());
else
resultVal = builder.create<pdl_interp::GetResultOp>(
loc, valueTy, createdOp, it.index());
type = builder.create<pdl_interp::GetValueTypeOp>(loc, resultVal);
}
}
/// Lower `pdl.range` to `pdl_interp.create_range` over the mapped elements,
/// keeping the original range type.
void PatternLowering::generateRewriter(
    pdl::RangeOp rangeOp, DenseMap<Value, Value> &rewriteValues,
    function_ref<Value(Value)> mapRewriteValue) {
  auto mappedElements = llvm::map_range(rangeOp.getArguments(), mapRewriteValue);
  SmallVector<Value, 4> elements(mappedElements.begin(), mappedElements.end());
  Value range = builder.create<pdl_interp::CreateRangeOp>(
      rangeOp.getLoc(), rangeOp.getType(), elements);
  rewriteValues[rangeOp] = range;
}
/// Lower `pdl.replace`.
///
/// When replacing with another operation, the replacement values are that
/// operation's results. The exception is when the replaced value is a
/// `pdl.operation` with an empty result-type list; then nothing is collected.
/// Otherwise, the explicit replacement values are mapped. If there are no
/// replacement values, the op is erased instead of replaced.
void PatternLowering::generateRewriter(
    pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues,
    function_ref<Value(Value)> mapRewriteValue) {
  SmallVector<Value, 4> newValues;
  Value replacementOp = replaceOp.getReplOperation();
  if (!replacementOp) {
    for (Value replValue : replaceOp.getReplValues())
      newValues.push_back(mapRewriteValue(replValue));
  } else {
    auto replacedDef =
        replaceOp.getOpValue().getDefiningOp<pdl::OperationOp>();
    bool mayHaveResults = !replacedDef || !replacedDef.getTypeValues().empty();
    if (mayHaveResults)
      newValues.push_back(builder.create<pdl_interp::GetResultsOp>(
          replacementOp.getLoc(), mapRewriteValue(replacementOp)));
  }

  // The replaced op is mapped after the replacement values, as before.
  Location loc = replaceOp.getLoc();
  Value target = mapRewriteValue(replaceOp.getOpValue());
  if (newValues.empty())
    builder.create<pdl_interp::EraseOp>(loc, target);
  else
    builder.create<pdl_interp::ReplaceOp>(loc, target, newValues);
}
/// Lower `pdl.result` to `pdl_interp.get_result` on the mapped parent op.
void PatternLowering::generateRewriter(
    pdl::ResultOp resultOp, DenseMap<Value, Value> &rewriteValues,
    function_ref<Value(Value)> mapRewriteValue) {
  Value parentOp = mapRewriteValue(resultOp.getParent());
  Type valueType = builder.getType<pdl::ValueType>();
  rewriteValues[resultOp] = builder.create<pdl_interp::GetResultOp>(
      resultOp.getLoc(), valueType, parentOp, resultOp.getIndex());
}
/// Lower `pdl.results` to `pdl_interp.get_results` on the mapped parent op,
/// keeping the original result type.
void PatternLowering::generateRewriter(
    pdl::ResultsOp resultOp, DenseMap<Value, Value> &rewriteValues,
    function_ref<Value(Value)> mapRewriteValue) {
  Value parentOp = mapRewriteValue(resultOp.getParent());
  rewriteValues[resultOp] = builder.create<pdl_interp::GetResultsOp>(
      resultOp.getLoc(), resultOp.getType(), parentOp, resultOp.getIndex());
}
/// Materialize a constant `pdl.type` with `pdl_interp.create_type`. Nothing is
/// emitted here for non-constant types.
void PatternLowering::generateRewriter(
    pdl::TypeOp typeOp, DenseMap<Value, Value> &rewriteValues,
    function_ref<Value(Value)> mapRewriteValue) {
  TypeAttr constType = typeOp.getConstantTypeAttr();
  if (!constType)
    return;
  rewriteValues[typeOp] =
      builder.create<pdl_interp::CreateTypeOp>(typeOp.getLoc(), constType);
}
/// Materialize a constant `pdl.types` with `pdl_interp.create_types`. Nothing
/// is emitted here for non-constant type ranges.
void PatternLowering::generateRewriter(
    pdl::TypesOp typeOp, DenseMap<Value, Value> &rewriteValues,
    function_ref<Value(Value)> mapRewriteValue) {
  ArrayAttr constTypes = typeOp.getConstantTypesAttr();
  if (!constTypes)
    return;
  rewriteValues[typeOp] = builder.create<pdl_interp::CreateTypesOp>(
      typeOp.getLoc(), typeOp.getType(), constTypes);
}
/// Determine the result types of the operation created for `op`. The
/// strategies are tried in order:
///  1. Resolve each explicit result type that is already mapped or that comes
///     from the matcher (is defined outside the rewriter block).
///  2. If `op` reports type inference support, set `hasInferredResultTypes`.
///  3. If `op` replaces an operation defined earlier (or in the matcher), use
///     that operation's result types.
///  4. With no explicit result types, leave `types` empty (no results).
/// Otherwise an error is emitted and the code is marked unreachable.
void PatternLowering::generateOperationResultTypeRewriter(
pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue,
SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues,
bool &hasInferredResultTypes) {
Block *rewriterBlock = op->getBlock();
OperandRange resultTypeValues = op.getTypeValues();
// Strategy 1. On failure, `types` is cleared so the later strategies start
// from an empty list.
auto tryResolveResultTypes = [&] {
types.reserve(resultTypeValues.size());
for (const auto &it : llvm::enumerate(resultTypeValues)) {
Value resultType = it.value();
// Already translated.
if (Value existingRewriteValue = rewriteValues.lookup(resultType)) {
types.push_back(existingRewriteValue);
continue;
}
// Defined outside the rewriter block, i.e. an input from the matcher.
if (resultType.getDefiningOp()->getBlock() != rewriterBlock) {
types.push_back(mapRewriteValue(resultType));
continue;
}
types.clear();
return failure();
}
return success();
};
if (!resultTypeValues.empty() && succeeded(tryResolveResultTypes()))
return;
// Strategy 2.
if (op.hasTypeInference()) {
hasInferredResultTypes = true;
return;
}
// Strategy 3: look for a `pdl.replace` that uses `op` as a replacement, i.e.
// not as operand 0 (presumably the op being replaced, `getOpValue()`).
for (OpOperand &use : op.getOp().getUses()) {
pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner());
if (!replOpUser || use.getOperandNumber() == 0)
continue;
Value replOpVal = replOpUser.getOpValue();
Operation *replacedOp = replOpVal.getDefiningOp();
// The replaced op must be usable here: defined before `op` when in the
// rewriter block, or defined outside it (in the matcher).
if (replacedOp->getBlock() == rewriterBlock &&
!replacedOp->isBeforeInBlock(op))
continue;
Value replacedOpResults = builder.create<pdl_interp::GetResultsOp>(
replacedOp->getLoc(), mapRewriteValue(replOpVal));
types.push_back(builder.create<pdl_interp::GetValueTypeOp>(
replacedOp->getLoc(), replacedOpResults));
return;
}
// Strategy 4: no explicit result types means the op has no results.
if (resultTypeValues.empty())
return;
op->emitOpError() << "unable to infer result type for operation";
llvm_unreachable("unable to infer result type for operation");
}
namespace {
/// Pass that lowers the `pdl.pattern` operations of a module into the
/// `pdl_interp` dialect.
struct PDLToPDLInterpPass
    : public impl::ConvertPDLToPDLInterpBase<PDLToPDLInterpPass> {
  PDLToPDLInterpPass() = default;
  PDLToPDLInterpPass(const PDLToPDLInterpPass &rhs) = default;

  /// Construct the pass with a map from pattern ops to their config sets.
  /// Only a pointer to `configMap` is stored, so the map must outlive the
  /// pass. Marked `explicit` to avoid accidental implicit conversion from a
  /// map.
  explicit PDLToPDLInterpPass(
      DenseMap<Operation *, PDLPatternConfigSet *> &configMap)
      : configMap(&configMap) {}

  void runOnOperation() final;

  /// Map from pattern (and generated match) ops to config sets, or null.
  DenseMap<Operation *, PDLPatternConfigSet *> *configMap = nullptr;
};
} // namespace
void PDLToPDLInterpPass::runOnOperation() {
ModuleOp module = getOperation();
OpBuilder builder = OpBuilder::atBlockBegin(module.getBody());
auto matcherFunc = builder.create<pdl_interp::FuncOp>(
module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(),
builder.getFunctionType(builder.getType<pdl::OperationType>(),
std::nullopt),
std::nullopt);
ModuleOp rewriterModule = builder.create<ModuleOp>(
module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName());
PatternLowering generator(matcherFunc, rewriterModule, configMap);
generator.lower(module);
for (pdl::PatternOp pattern :
llvm::make_early_inc_range(module.getOps<pdl::PatternOp>())) {
if (configMap)
configMap->erase(pattern);
pattern.erase();
}
}
/// Create the PDL-to-PDLInterp lowering pass without a config map.
std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass() {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<PDLToPDLInterpPass>();
  return pass;
}
/// Create the PDL-to-PDLInterp lowering pass. The pass stores a pointer to
/// `configMap`, so the map must outlive the pass.
std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass(
    DenseMap<Operation *, PDLPatternConfigSet *> &configMap) {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<PDLToPDLInterpPass>(configMap);
  return pass;
}