#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/Interfaces/FunctionImplementation.h"
using namespace mlir;
using namespace mlir::pdl_interp;
#include "mlir/Dialect/PDLInterp/IR/PDLInterpOpsDialect.cpp.inc"
/// Registers all operations of the PDLInterp dialect with this dialect.
///
/// The template argument list of `addOperations` is not written by hand:
/// defining GET_OP_LIST and then including the generated
/// `PDLInterpOps.cpp.inc` expands, in place, to the list of op classes.
void PDLInterpDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc"
>();
}
/// Shared verifier for the pdl_interp switch operations.
///
/// A switch op exposes a list of case destinations (`getCases()`) and a list
/// of values to compare against (`getCaseValues()`). The two lists must be the
/// same length; otherwise an op error is emitted reporting both counts.
template <typename OpT>
static LogicalResult verifySwitchOp(OpT op) {
  const size_t destCount = op.getCases().size();
  const size_t caseValueCount = op.getCaseValues().size();
  if (destCount == caseValueCount)
    return success();

  return op.emitOpError("expected number of cases to match the number of case "
                        "values, got ")
         << destCount << " but expected " << caseValueCount;
}
/// Verifies the result-type configuration of `pdl_interp.create_operation`.
///
/// Nothing is checked unless result types are marked as inferred. In that
/// case, no explicit result types may also be given, and the operation being
/// created must implement `InferTypeOpInterface`. An unregistered operation
/// name also fails this check, as the diagnostic notes.
LogicalResult CreateOperationOp::verify() {
  if (getInferredResultTypes()) {
    // Inferred and explicit result types are mutually exclusive.
    if (!getInputResultTypes().empty())
      return emitOpError("with inferred results cannot also have "
                         "explicit result types");

    // Inference is only possible if the created op can infer its own types.
    OperationName createdOpName(getName(), getContext());
    if (!createdOpName.hasInterface<InferTypeOpInterface>())
      return emitOpError()
             << "has inferred results, but the created operation '"
             << createdOpName
             << "' does not support result type inference (or is not "
                "registered)";
  }
  return success();
}
/// Parses the optional attribute list of `pdl_interp.create_operation`:
///
///   `{` name `=` operand (`,` name `=` operand)* `}`
///
/// Each parsed name is collected into `attrNamesAttr`, which is always
/// assigned (to an empty ArrayAttr when the list is absent). Each value
/// operand is appended to `attrOperands` in the same order.
static ParseResult parseCreateOperationOpAttributes(
    OpAsmParser &p,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &attrOperands,
    ArrayAttr &attrNamesAttr) {
  SmallVector<Attribute, 4> names;

  // Parses one `name = %operand` entry of the list.
  auto parseEntry = [&]() {
    StringAttr name;
    OpAsmParser::UnresolvedOperand value;
    if (failed(p.parseAttribute(name)) || failed(p.parseEqual()) ||
        failed(p.parseOperand(value)))
      return failure();
    names.push_back(name);
    attrOperands.push_back(value);
    return success();
  };

  const bool hasList = succeeded(p.parseOptionalLBrace());
  if (hasList && (p.parseCommaSeparatedList(parseEntry) || p.parseRBrace()))
    return failure();

  attrNamesAttr = p.getBuilder().getArrayAttr(names);
  return success();
}
/// Prints the attribute list of `pdl_interp.create_operation` as
/// ` {name = %operand, name = %operand}`, pairing `attrNames[i]` with
/// `attrArgs[i]`. Nothing is printed when `attrNames` is empty.
/// `op` is not referenced here.
static void printCreateOperationOpAttributes(OpAsmPrinter &p,
                                             CreateOperationOp op,
                                             OperandRange attrArgs,
                                             ArrayAttr attrNames) {
  const size_t numAttrs = attrNames.size();
  if (numAttrs == 0)
    return;

  p << " {";
  for (size_t idx = 0; idx < numAttrs; ++idx) {
    // Comma-separate entries, with no separator before the first one.
    if (idx != 0)
      p << ", ";
    p << attrNames[idx] << " = " << attrArgs[idx];
  }
  p << '}';
}
/// Parses the optional result clause of `pdl_interp.create_operation`, which
/// is one of the following:
///
///   (nothing)                        -- no result clause
///   `->` `<` `inferred` `>`          -- sets `inferredResultTypes`
///   `->` `(` operands `:` types `)`  -- fills `resultOperands`/`resultTypes`
static ParseResult parseCreateOperationOpResults(
    OpAsmParser &p,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &resultOperands,
    SmallVectorImpl<Type> &resultTypes, UnitAttr &inferredResultTypes) {
  // Without a leading arrow there is no result clause at all.
  if (succeeded(p.parseOptionalArrow())) {
    // No `<` means the explicit, parenthesized form.
    if (failed(p.parseOptionalLess()))
      return failure(p.parseLParen() || p.parseOperandList(resultOperands) ||
                     p.parseColonTypeList(resultTypes) || p.parseRParen());

    // `<inferred>` form.
    if (p.parseKeyword("inferred") || p.parseGreater())
      return failure();
    inferredResultTypes = p.getBuilder().getUnitAttr();
  }
  return success();
}
/// Prints the result clause of `pdl_interp.create_operation`. It prints
/// ` -> <inferred>` when `inferredResultTypes` is set. Otherwise it prints
/// ` -> (operands : types)`, or nothing when there are no result types.
/// `op` is not referenced here.
static void printCreateOperationOpResults(OpAsmPrinter &p, CreateOperationOp op,
                                          OperandRange resultOperands,
                                          TypeRange resultTypes,
                                          UnitAttr inferredResultTypes) {
  if (!inferredResultTypes) {
    // Explicit results: omit the clause entirely when there are none.
    if (resultTypes.empty())
      return;
    p << " -> (" << resultOperands << " : " << resultTypes << ")";
    return;
  }
  p << " -> <inferred>";
}
/// Builds a `pdl_interp.foreach` over `range` with `successor` as its
/// successor block.
///
/// When `initLoop` is true, the body region also gets an entry block with a
/// single argument, the loop variable. Its type is the element type of
/// `range`, whose type must be a `pdl::RangeType` (it is `llvm::cast`).
void ForEachOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
                      Value range, Block *successor, bool initLoop) {
  build(builder, state, range, successor);
  if (!initLoop)
    return;

  // The loop variable reuses the op's own location.
  Type elementType =
      llvm::cast<pdl::RangeType>(range.getType()).getElementType();
  Region &body = *state.regions.front();
  body.emplaceBlock();
  body.addArgument(elementType, state.location);
}
/// Parses the custom form of `pdl_interp.foreach`:
///
///   %var `:` type `in` %range region attr-dict `->` ^successor
///
/// The range operand is resolved with type `!pdl.range<type>`, where `type`
/// is the type written on the loop variable.
ParseResult ForEachOp::parse(OpAsmParser &parser, OperationState &result) {
  // Header: loop variable with its type, the `in` keyword, the range operand.
  OpAsmParser::Argument loopVar;
  OpAsmParser::UnresolvedOperand rangeOperand;
  const bool headerFailed =
      parser.parseArgument(loopVar, /*allowType=*/true) ||
      parser.parseKeyword("in", " after loop variable") ||
      parser.parseOperand(rangeOperand);
  if (headerFailed)
    return failure();

  // The range operand's type is derived from the loop variable's type.
  if (parser.resolveOperand(rangeOperand, pdl::RangeType::get(loopVar.type),
                            result.operands))
    return failure();

  // Body region (taking the loop variable as its argument), optional
  // attribute dictionary, then `->` and the successor block.
  Region &bodyRegion = *result.addRegion();
  Block *successorBlock = nullptr;
  if (parser.parseRegion(bodyRegion, loopVar) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseArrow() || parser.parseSuccessor(successorBlock))
    return failure();

  result.addSuccessors(successorBlock);
  return success();
}
/// Prints `pdl_interp.foreach` in the form accepted by ForEachOp::parse:
/// ` %var : type in %range region attr-dict -> ^successor`.
void ForEachOp::print(OpAsmPrinter &p) {
  BlockArgument loopVar = getLoopVariable();
  p << ' ' << loopVar << " : " << loopVar.getType();
  p << " in " << getValues() << ' ';

  // The loop variable was already printed above, so suppress the region's
  // entry block argument list.
  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict((*this)->getAttrs());
  p << " -> ";
  p.printSuccessor(getSuccessor());
}
/// Verifies `pdl_interp.foreach`. The body region must have exactly one
/// argument (the loop variable), and the iterated operand must have type
/// `!pdl.range<T>`, where `T` is the loop variable's type.
LogicalResult ForEachOp::verify() {
  Region &body = getRegion();
  if (body.getNumArguments() != 1)
    return emitOpError("requires exactly one argument");

  Type expectedRangeType = pdl::RangeType::get(getLoopVariable().getType());
  if (expectedRangeType != getValues().getType())
    return emitOpError("operand must be a range of loop variable type");
  return success();
}
/// Builds a `pdl_interp.func` named `name` with signature `type` and the extra
/// attributes `attrs`. The entry block's argument types are `type`'s inputs.
void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
                   FunctionType type, ArrayRef<NamedAttribute> attrs) {
  ArrayRef<Type> entryArgTypes = type.getInputs();
  buildWithEntryBlock(builder, state, name, type, attrs, entryArgTypes);
}
/// Parses `pdl_interp.func` using the shared function-op parser. Variadic
/// signatures are not allowed. The function type is built directly from the
/// parsed argument and result types.
ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
  // The variadic flag and the error-message out-parameter are unused because
  // variadic signatures are rejected up front.
  auto makeFunctionType = [](Builder &builder, ArrayRef<Type> inputs,
                             ArrayRef<Type> outputs,
                             function_interface_impl::VariadicFlag,
                             std::string &) {
    return builder.getFunctionType(inputs, outputs);
  };

  OperationName opName = result.name;
  StringAttr typeAttrName = getFunctionTypeAttrName(opName);
  StringAttr argAttrsName = getArgAttrsAttrName(opName);
  StringAttr resAttrsName = getResAttrsAttrName(opName);
  return function_interface_impl::parseFunctionOp(
      parser, result, /*allowVariadic=*/false, typeAttrName, makeFunctionType,
      argAttrsName, resAttrsName);
}
/// Prints `pdl_interp.func` using the shared function-op printer, as a
/// non-variadic function whose type and argument/result attributes are stored
/// under this op's attribute names.
void FuncOp::print(OpAsmPrinter &p) {
  StringAttr typeAttrName = getFunctionTypeAttrName();
  StringAttr argAttrsName = getArgAttrsAttrName();
  StringAttr resAttrsName = getResAttrsAttrName();
  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
                                           typeAttrName, argAttrsName,
                                           resAttrsName);
}
/// Returns the PDL value type corresponding to `type`. The result is
/// `!pdl.range<!pdl.value>` when `type` is a `pdl::RangeType`, and plain
/// `!pdl.value` otherwise.
static Type getGetValueTypeOpValueType(Type type) {
  Type valueTy = pdl::ValueType::get(type.getContext());
  if (llvm::isa<pdl::RangeType>(type))
    return pdl::RangeType::get(valueTy);
  return valueTy;
}
/// Custom-directive parser for a range result type.
///
/// With no operands, the type must be spelled out as `: type`. With operands,
/// the type is inferred instead: it is a `pdl::RangeType` over
/// `pdl::getRangeElementTypeOrSelf` of the first operand's type.
static ParseResult parseRangeType(OpAsmParser &p, TypeRange argumentTypes,
                                  Type &resultType) {
  if (argumentTypes.empty())
    return p.parseColonType(resultType);

  Type elementType = pdl::getRangeElementTypeOrSelf(argumentTypes.front());
  resultType = pdl::RangeType::get(elementType);
  return success();
}
/// Custom-directive printer paired with parseRangeType. It prints `: type`
/// only when there are no operands to infer the type from. `op` is not
/// referenced here.
static void printRangeType(OpAsmPrinter &p, CreateRangeOp op,
                           TypeRange argumentTypes, Type resultType) {
  if (!argumentTypes.empty())
    return;
  p << ": " << resultType;
}
/// Verifies `pdl_interp.create_range`. Each operand, after
/// `pdl::getRangeElementTypeOrSelf`, must match the result range's element
/// type. The first mismatching operand is reported.
LogicalResult CreateRangeOp::verify() {
  Type expectedElementType = getType().getElementType();
  for (Type operandType : getOperandTypes()) {
    Type elementType = pdl::getRangeElementTypeOrSelf(operandType);
    if (elementType == expectedElementType)
      continue;
    return emitOpError("expected operand to have element type ")
           << expectedElementType << ", but got " << elementType;
  }
  return success();
}
// All switch operations share one verifier, which checks that the number of
// case destinations matches the number of case values.
LogicalResult SwitchAttributeOp::verify() {
  return verifySwitchOp<SwitchAttributeOp>(*this);
}
LogicalResult SwitchOperandCountOp::verify() {
  return verifySwitchOp<SwitchOperandCountOp>(*this);
}
LogicalResult SwitchOperationNameOp::verify() {
  return verifySwitchOp<SwitchOperationNameOp>(*this);
}
LogicalResult SwitchResultCountOp::verify() {
  return verifySwitchOp<SwitchResultCountOp>(*this);
}
LogicalResult SwitchTypeOp::verify() {
  return verifySwitchOp<SwitchTypeOp>(*this);
}
LogicalResult SwitchTypesOp::verify() {
  return verifySwitchOp<SwitchTypesOp>(*this);
}
#define GET_OP_CLASSES
#include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc"