#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "LLVMInlining.h"
#include "TypeDetail.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/Bitcode/BitcodeReader.h"
#include "llvm/Bitcode/BitcodeWriter.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/Mutex.h"
#include "llvm/Support/SourceMgr.h"
#include <numeric>
#include <optional>
using namespace mlir;
using namespace mlir::LLVM;
using mlir::LLVM::cconv::getMaxEnumValForCConv;
using mlir::LLVM::linkage::getMaxEnumValForLinkage;
using mlir::LLVM::tailcallkind::getMaxEnumValForTailCallKind;
#include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc"
namespace mlir {
/// Property hook: wraps `flags` into an `IntegerOverflowFlagsAttr`.
static Attribute convertToAttribute(MLIRContext *ctx,
                                    IntegerOverflowFlags flags) {
  Attribute wrapped = IntegerOverflowFlagsAttr::get(ctx, flags);
  return wrapped;
}

/// Property hook, inverse of `convertToAttribute`: stores the value of `attr`
/// into `flags`. Fails with a diagnostic when `attr` is not an
/// `IntegerOverflowFlagsAttr`.
static LogicalResult
convertFromAttribute(IntegerOverflowFlags &flags, Attribute attr,
                     function_ref<InFlightDiagnostic()> emitError) {
  if (auto flagsAttr = dyn_cast<IntegerOverflowFlagsAttr>(attr)) {
    flags = flagsAttr.getValue();
    return success();
  }
  return emitError() << "expected 'overflowFlags' attribute to be an "
                        "IntegerOverflowFlagsAttr, but got "
                     << attr;
}
} // namespace mlir
/// Parses the optional overflow clause of integer arithmetic ops:
///   (`overflow` `<` flag (`,` flag)* `>`)?
/// When the `overflow` keyword is absent, `flags` is `none`. Otherwise every
/// listed keyword is symbolized and OR-ed into `flags`; an unknown keyword
/// produces an error located at that keyword.
static ParseResult parseOverflowFlags(AsmParser &p,
                                      IntegerOverflowFlags &flags) {
  // Reset the out-parameter on every path, so the accumulation below never
  // ORs into whatever value the caller's storage happened to hold.
  flags = IntegerOverflowFlags::none;
  if (failed(p.parseOptionalKeyword("overflow")))
    return success();
  if (p.parseLess())
    return failure();
  do {
    StringRef kw;
    SMLoc loc = p.getCurrentLocation();
    if (p.parseKeyword(&kw))
      return failure();
    std::optional<IntegerOverflowFlags> flag =
        symbolizeIntegerOverflowFlags(kw);
    if (!flag)
      return p.emitError(loc,
                         "invalid overflow flag: expected nsw, nuw, or none");
    flags = flags | *flag;
  } while (succeeded(p.parseOptionalComma()));
  return p.parseGreater();
}
/// Prints ` overflow<nsw, nuw>` (listing only the set flags, `nsw` first).
/// Prints nothing when `flags` is `none`.
static void printOverflowFlags(AsmPrinter &p, Operation *op,
                               IntegerOverflowFlags flags) {
  if (flags == IntegerOverflowFlags::none)
    return;
  const bool hasNsw = bitEnumContainsAny(flags, IntegerOverflowFlags::nsw);
  const bool hasNuw = bitEnumContainsAny(flags, IntegerOverflowFlags::nuw);
  p << " overflow<";
  if (hasNsw)
    p << "nsw";
  if (hasNsw && hasNuw)
    p << ", ";
  if (hasNuw)
    p << "nuw";
  p << ">";
}
/// Name of the attribute holding an op's element type; it is read/written by
/// the `alloca` printer/parser and the GEP builder in this file.
static constexpr const char kElemTypeAttrName[] = "elem_type";
/// Returns a copy of `attrs` without the `fastmathFlags` entry when that entry
/// holds the default (empty) `FastmathFlagsAttr`, so printers can elide it.
/// All other attributes are kept in their original order.
static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) {
  SmallVector<NamedAttribute, 8> filteredAttrs;
  for (NamedAttribute attr : attrs) {
    const bool isDefaultFMF =
        attr.getName() == "fastmathFlags" &&
        attr.getValue() ==
            FastmathFlagsAttr::get(attr.getValue().getContext(), {});
    if (!isDefaultFMF)
      filteredAttrs.push_back(attr);
  }
  return filteredAttrs;
}
/// Parses an optional trailing attribute dictionary into `result`
/// (counterpart of `printLLVMOpAttrs`).
static ParseResult parseLLVMOpAttrs(OpAsmParser &parser,
                                    NamedAttrList &result) {
  ParseResult parsed = parser.parseOptionalAttrDict(result);
  return parsed;
}
/// Prints the attribute dictionary of an LLVM op. A default `fastmathFlags`
/// entry is dropped, and for ops implementing IntegerOverflowFlagsInterface
/// the overflow-flags attribute is elided as well (presumably because such
/// ops print it through their own syntax).
static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op,
                             DictionaryAttr attrs) {
  auto filteredAttrs = processFMFAttr(attrs.getValue());
  auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op);
  if (!iface) {
    printer.printOptionalAttrDict(filteredAttrs);
    return;
  }
  printer.printOptionalAttrDict(filteredAttrs,
                                {iface.getOverflowFlagsAttrName()});
}
/// Checks that `symbol`, looked up from `op`, names an LLVMFuncOp that has a
/// body. Emits an op error otherwise.
static LogicalResult verifySymbolAttrUse(FlatSymbolRefAttr symbol,
                                         Operation *op,
                                         SymbolTableCollection &symbolTable) {
  StringRef name = symbol.getValue();
  auto func =
      symbolTable.lookupNearestSymbolFrom<LLVMFuncOp>(op, symbol.getAttr());
  if (func && !func.isExternal())
    return success();
  if (!func)
    return op->emitOpError("'")
           << name << "' does not reference a valid LLVM function";
  return op->emitOpError("'") << name << "' does not have a definition";
}
/// Returns `i1`, or an LLVM-compatible vector of `i1` with the same element
/// count when `type` is an LLVM-compatible vector type.
static Type getI1SameShape(Type type) {
  Type i1Type = IntegerType::get(type.getContext(), 1);
  if (!LLVM::isCompatibleVectorType(type))
    return i1Type;
  return LLVM::getVectorType(i1Type, LLVM::getVectorNumElements(type));
}
/// Tries each of `keywords` in order as an optional keyword. Returns the
/// position of the first one that parses, or -1 if none does.
static int parseOptionalKeywordAlternative(OpAsmParser &parser,
                                           ArrayRef<StringRef> keywords) {
  for (size_t pos = 0, end = keywords.size(); pos < end; ++pos)
    if (succeeded(parser.parseOptionalKeyword(keywords[pos])))
      return pos;
  return -1;
}
namespace {
/// Per-enum hooks used by `parseOptionalLLVMKeyword`. The primary template is
/// empty, so only enums registered with REGISTER_ENUM_TYPE below are usable.
template <typename Ty>
struct EnumTraits {};
/// Specializes EnumTraits<Ty> by forwarding to the generated free functions
/// `stringify<Ty>(value)` and `getMaxEnumValFor<Ty>()`.
#define REGISTER_ENUM_TYPE(Ty) \
template <> \
struct EnumTraits<Ty> { \
static StringRef stringify(Ty value) { return stringify##Ty(value); } \
static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); } \
}
REGISTER_ENUM_TYPE(Linkage);
REGISTER_ENUM_TYPE(UnnamedAddr);
REGISTER_ENUM_TYPE(CConv);
REGISTER_ENUM_TYPE(TailCallKind);
REGISTER_ENUM_TYPE(Visibility);
} // namespace
/// Parses an optional keyword naming a value of `EnumTy`. The candidates are
/// the stringified enum values 0..getMaxEnumVal(). Returns the matched value
/// (cast to `RetTy`), or `defaultValue` when no candidate keyword is present.
template <typename EnumTy, typename RetTy = EnumTy>
static RetTy parseOptionalLLVMKeyword(OpAsmParser &parser,
                                      OperationState &result,
                                      EnumTy defaultValue) {
  const unsigned maxValue = EnumTraits<EnumTy>::getMaxEnumVal();
  SmallVector<StringRef, 10> candidates;
  for (unsigned value = 0; value <= maxValue; ++value)
    candidates.push_back(
        EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(value)));
  const int matched = parseOptionalKeywordAlternative(parser, candidates);
  return matched == -1 ? static_cast<RetTy>(defaultValue)
                       : static_cast<RetTy>(matched);
}
/// Prints `"pred" %lhs, %rhs attr-dict : type`, with the predicate printed as
/// a quoted string instead of its integer attribute.
void ICmpOp::print(OpAsmPrinter &p) {
  StringRef predicateName = stringifyICmpPredicate(getPredicate());
  p << " \"" << predicateName << "\" ";
  p << getOperand(0) << ", " << getOperand(1);
  p.printOptionalAttrDict((*this)->getAttrs(), {"predicate"});
  p << " : " << getLhs().getType();
}
/// Prints `"pred" %lhs, %rhs attr-dict : type`. The predicate is printed as a
/// quoted string and a default `fastmathFlags` attribute is elided.
void FCmpOp::print(OpAsmPrinter &p) {
  StringRef predicateName = stringifyFCmpPredicate(getPredicate());
  p << " \"" << predicateName << "\" ";
  p << getOperand(0) << ", " << getOperand(1);
  auto attrsToPrint = processFMFAttr((*this)->getAttrs());
  p.printOptionalAttrDict(attrsToPrint, {"predicate"});
  p << " : " << getLhs().getType();
}
/// Shared parser for icmp/fcmp: `"pred" %lhs, %rhs attr-dict : type`.
/// The quoted predicate string is replaced by an i64 `predicate` attribute
/// holding the enum value. The result type is i1, or a vector of i1 shaped
/// like `type`.
template <typename CmpPredicateType>
static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
  StringAttr predicateAttr;
  OpAsmParser::UnresolvedOperand lhs, rhs;
  Type type;
  SMLoc predicateLoc, trailingTypeLoc;
  if (parser.getCurrentLocation(&predicateLoc) ||
      parser.parseAttribute(predicateAttr, "predicate", result.attributes) ||
      parser.parseOperand(lhs) || parser.parseComma() ||
      parser.parseOperand(rhs) ||
      parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
      parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) ||
      parser.resolveOperand(lhs, type, result.operands) ||
      parser.resolveOperand(rhs, type, result.operands))
    return failure();

  // Symbolize the predicate string for the requested predicate enum.
  StringRef predicateName = predicateAttr.getValue();
  std::optional<int64_t> predicateValue;
  if constexpr (std::is_same<CmpPredicateType, ICmpPredicate>::value) {
    if (std::optional<ICmpPredicate> pred =
            symbolizeICmpPredicate(predicateName))
      predicateValue = static_cast<int64_t>(*pred);
  } else {
    if (std::optional<FCmpPredicate> pred =
            symbolizeFCmpPredicate(predicateName))
      predicateValue = static_cast<int64_t>(*pred);
  }
  if (!predicateValue)
    return parser.emitError(predicateLoc)
           << "'" << predicateName
           << "' is an incorrect value of the 'predicate' attribute";
  result.attributes.set(
      "predicate", parser.getBuilder().getI64IntegerAttr(*predicateValue));

  if (!isCompatibleType(type))
    return parser.emitError(trailingTypeLoc,
                            "expected LLVM dialect-compatible type");
  result.addTypes(getI1SameShape(type));
  return success();
}
/// Custom assembly parser for icmp; delegates to the shared cmp parser.
ParseResult ICmpOp::parse(OpAsmParser &parser, OperationState &result) {
  ParseResult parsed = parseCmpOp<ICmpPredicate>(parser, result);
  return parsed;
}
/// Custom assembly parser for fcmp; delegates to the shared cmp parser.
ParseResult FCmpOp::parse(OpAsmParser &parser, OperationState &result) {
  ParseResult parsed = parseCmpOp<FCmpPredicate>(parser, result);
  return parsed;
}
/// Returns a BoolAttr holding `value`, or a splat DenseElementsAttr of it when
/// `type` is a ShapedType.
static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
  BoolAttr scalar = BoolAttr::get(ctx, value);
  if (ShapedType shapedType = dyn_cast<ShapedType>(type))
    return DenseElementsAttr::get(shapedType, scalar);
  return scalar;
}
/// Folds eq/ne comparisons:
///  - `x == x` / `x != x` fold to true / false;
///  - alloca vs. zero folds to false (eq) / true (ne);
///  - zero vs. alloca is canonicalized in place to alloca vs. zero.
OpFoldResult ICmpOp::fold(FoldAdaptor adaptor) {
  const ICmpPredicate pred = getPredicate();
  const bool isEq = pred == ICmpPredicate::eq;
  if (!isEq && pred != ICmpPredicate::ne)
    return {};

  Value lhs = getLhs();
  Value rhs = getRhs();

  if (lhs == rhs)
    return getBoolAttribute(getType(), getContext(), isEq);

  if (lhs.getDefiningOp<AllocaOp>() && rhs.getDefiningOp<ZeroOp>())
    return getBoolAttribute(getType(), getContext(), !isEq);

  if (lhs.getDefiningOp<ZeroOp>() && rhs.getDefiningOp<AllocaOp>()) {
    // Swap operands so the pattern above matches on the next fold.
    getLhsMutable().assign(rhs);
    getRhsMutable().assign(lhs);
    return getResult();
  }
  return {};
}
/// Prints `[inalloca] %size x elemType attr-dict : (sizeTy) -> resultTy`.
/// The element type and `inalloca` attributes are always elided from the
/// dictionary. The alignment is elided too when it is absent or zero.
void AllocaOp::print(OpAsmPrinter &p) {
  if (getInalloca())
    p << " inalloca";
  p << ' ' << getArraySize() << " x " << getElemType();

  const bool keepAlignment = getAlignment() && *getAlignment() != 0;
  SmallVector<StringRef, 3> elidedAttrs{kElemTypeAttrName,
                                        getInallocaAttrName()};
  if (!keepAlignment)
    elidedAttrs.push_back(getAlignmentAttrName());
  p.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);

  auto signature =
      FunctionType::get(getContext(), {getArraySize().getType()}, {getType()});
  p << " : " << signature;
}
/// Parses `[inalloca] %size x elemType attr-dict : (sizeTy) -> resultTy`.
/// An explicit zero `alignment` is dropped. The element type is stored as an
/// attribute when the result is an LLVM pointer.
ParseResult AllocaOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand arraySize;
  Type type, elemType;
  SMLoc trailingTypeLoc;

  if (succeeded(parser.parseOptionalKeyword("inalloca")))
    result.addAttribute(getInallocaAttrName(result.name),
                        UnitAttr::get(parser.getContext()));

  if (parser.parseOperand(arraySize) || parser.parseKeyword("x") ||
      parser.parseType(elemType) ||
      parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
      parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
    return failure();

  if (std::optional<NamedAttribute> alignmentAttr =
          result.attributes.getNamed("alignment")) {
    auto alignmentInt = llvm::dyn_cast<IntegerAttr>(alignmentAttr->getValue());
    if (!alignmentInt)
      return parser.emitError(parser.getNameLoc(),
                              "expected integer alignment");
    if (alignmentInt.getValue().isZero())
      result.attributes.erase("alignment");
  }

  auto funcType = llvm::dyn_cast<FunctionType>(type);
  const bool isUnarySignature = funcType && funcType.getNumInputs() == 1 &&
                                funcType.getNumResults() == 1;
  if (!isUnarySignature)
    return parser.emitError(
        trailingTypeLoc,
        "expected trailing function type with one argument and one result");

  if (parser.resolveOperand(arraySize, funcType.getInput(0), result.operands))
    return failure();

  Type resultType = funcType.getResult(0);
  if (llvm::isa<LLVMPointerType>(resultType))
    result.addAttribute(kElemTypeAttrName, TypeAttr::get(elemType));
  result.addTypes({resultType});
  return success();
}
/// Rejects target extension element types that do not support memory ops.
LogicalResult AllocaOp::verify() {
  auto targetExtType = dyn_cast<LLVMTargetExtType>(getElemType());
  if (!targetExtType || targetExtType.supportsMemOps())
    return success();
  return emitOpError()
         << "this target extension type cannot be used in alloca";
}
/// Reports the allocated element type as the pointee type of the result.
Type AllocaOp::getResultPtrElementType() {
  Type allocatedType = getElemType();
  return allocatedType;
}
/// `br` has a single successor (index 0) fed by the destination operands.
SuccessorOperands BrOp::getSuccessorOperands(unsigned index) {
  assert(index == 0 && "invalid successor index");
  MutableOperandRange destOperands = getDestOperandsMutable();
  return SuccessorOperands(destOperands);
}
/// Successor 0 is the true destination and successor 1 the false destination.
SuccessorOperands CondBrOp::getSuccessorOperands(unsigned index) {
  assert(index < getNumSuccessors() && "invalid successor index");
  if (index == 0)
    return SuccessorOperands(getTrueDestOperandsMutable());
  return SuccessorOperands(getFalseDestOperandsMutable());
}
/// Builds a conditional branch. Optional (true, false) `weights` are stored as
/// a two-element i32 array attribute; each uint32 is cast to int32.
void CondBrOp::build(OpBuilder &builder, OperationState &result,
                     Value condition, Block *trueDest, ValueRange trueOperands,
                     Block *falseDest, ValueRange falseOperands,
                     std::optional<std::pair<uint32_t, uint32_t>> weights) {
  DenseI32ArrayAttr weightsAttr;
  if (weights.has_value()) {
    auto [trueWeight, falseWeight] = *weights;
    weightsAttr = builder.getDenseI32ArrayAttr(
        {static_cast<int32_t>(trueWeight), static_cast<int32_t>(falseWeight)});
  }
  build(builder, result, condition, trueOperands, falseOperands, weightsAttr,
        {}, trueDest, falseDest);
}
/// Builds a switch from a prebuilt case-value attribute. Empty
/// `branchWeights` means no weights attribute.
void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
                     Block *defaultDestination, ValueRange defaultOperands,
                     DenseIntElementsAttr caseValues,
                     BlockRange caseDestinations,
                     ArrayRef<ValueRange> caseOperands,
                     ArrayRef<int32_t> branchWeights) {
  DenseI32ArrayAttr weightsAttr =
      branchWeights.empty() ? DenseI32ArrayAttr()
                            : builder.getDenseI32ArrayAttr(branchWeights);
  build(builder, result, value, defaultOperands, caseOperands, caseValues,
        weightsAttr, defaultDestination, caseDestinations);
}
/// Builds a switch from APInt case values. The values are packed into a 1-D
/// vector attribute whose element type is the condition's type; an empty
/// list produces no case-value attribute.
void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
                     Block *defaultDestination, ValueRange defaultOperands,
                     ArrayRef<APInt> caseValues, BlockRange caseDestinations,
                     ArrayRef<ValueRange> caseOperands,
                     ArrayRef<int32_t> branchWeights) {
  DenseIntElementsAttr packedCases;
  if (!caseValues.empty()) {
    auto numCases = static_cast<int64_t>(caseValues.size());
    ShapedType vecType = VectorType::get(numCases, value.getType());
    packedCases = DenseIntElementsAttr::get(vecType, caseValues);
  }
  build(builder, result, value, defaultDestination, defaultOperands,
        packedCases, caseDestinations, caseOperands, branchWeights);
}
/// Builds a switch from int32 case values.
///
/// The values are widened or narrowed (sign-preserving) to the bit width of
/// the condition type and forwarded to the APInt builder. Packing the raw
/// int32 array directly into a DenseIntElementsAttr whose element type is the
/// condition type only works for i32 conditions, because the dense storage
/// width must match the element type; any other integer width would assert.
void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
                     Block *defaultDestination, ValueRange defaultOperands,
                     ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
                     ArrayRef<ValueRange> caseOperands,
                     ArrayRef<int32_t> branchWeights) {
  SmallVector<APInt> wideCaseValues;
  if (!caseValues.empty()) {
    unsigned bitWidth = value.getType().getIntOrFloatBitWidth();
    wideCaseValues.reserve(caseValues.size());
    for (int32_t caseValue : caseValues)
      wideCaseValues.push_back(
          APInt(bitWidth, static_cast<uint64_t>(static_cast<int64_t>(caseValue)),
                /*isSigned=*/true));
  }
  build(builder, result, value, defaultDestination, defaultOperands,
        ArrayRef<APInt>(wideCaseValues), caseDestinations, caseOperands,
        branchWeights);
}
/// Parses the case list of `llvm.switch`:
///   `[` (int `:` ^dest (`(` operands `:` types `)`)?) (`,` ...)* `]`
/// An empty list `[]` leaves `caseValues` null. Otherwise the case values are
/// packed into a 1-D vector attribute of `flagType`.
///
/// Each literal is read as a signed 64-bit integer and converted to the width
/// of `flagType`. Negative literals must fit as signed values and
/// non-negative ones as unsigned values, so both `-1` and `255` are accepted
/// for i8. Literals that do not fit are reported instead of being silently
/// truncated or misinterpreted as unsigned (which APInt may assert on).
static ParseResult parseSwitchOpCases(
    OpAsmParser &parser, Type flagType, DenseIntElementsAttr &caseValues,
    SmallVectorImpl<Block *> &caseDestinations,
    SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands,
    SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
  if (failed(parser.parseLSquare()))
    return failure();
  if (succeeded(parser.parseOptionalRSquare()))
    return success();
  SmallVector<APInt> values;
  unsigned bitWidth = flagType.getIntOrFloatBitWidth();
  auto parseCase = [&]() -> ParseResult {
    int64_t value = 0;
    SMLoc valueLoc = parser.getCurrentLocation();
    if (failed(parser.parseInteger(value)))
      return failure();
    APInt literal(64, static_cast<uint64_t>(value), /*isSigned=*/true);
    const bool fits =
        value < 0 ? literal.isSignedIntN(bitWidth) : literal.isIntN(bitWidth);
    if (!fits)
      return parser.emitError(valueLoc)
             << "case value " << value << " does not fit in " << flagType;
    // Sign extension of a non-negative value equals zero extension, so
    // sextOrTrunc is correct for both signs (truncation is lossless here).
    values.push_back(literal.sextOrTrunc(bitWidth));

    Block *destination;
    SmallVector<OpAsmParser::UnresolvedOperand> operands;
    SmallVector<Type> operandTypes;
    if (parser.parseColon() || parser.parseSuccessor(destination))
      return failure();
    if (!parser.parseOptionalLParen()) {
      if (parser.parseOperandList(operands, OpAsmParser::Delimiter::None,
                                  /*allowResultNumber=*/false) ||
          parser.parseColonTypeList(operandTypes) || parser.parseRParen())
        return failure();
    }
    caseDestinations.push_back(destination);
    caseOperands.emplace_back(operands);
    caseOperandTypes.emplace_back(operandTypes);
    return success();
  };
  if (failed(parser.parseCommaSeparatedList(parseCase)))
    return failure();
  ShapedType caseValueType =
      VectorType::get(static_cast<int64_t>(values.size()), flagType);
  caseValues = DenseIntElementsAttr::get(caseValueType, values);
  return parser.parseRSquare();
}
/// Prints the case list of `llvm.switch` as
///   `[` newline ` value: ^dest(operands : types)` (`,` newline ...)* newline `]`
/// A null `caseValues` prints `[` newline `]`.
///
/// Case values are printed as signed decimals. The matching parser reads them
/// with `parseInteger` into an int64_t. Printing the unsigned value instead
/// would, e.g., turn an i64 `-1` into 18446744073709551615, which cannot be
/// parsed back.
static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType,
                               DenseIntElementsAttr caseValues,
                               SuccessorRange caseDestinations,
                               OperandRangeRange caseOperands,
                               const TypeRangeRange &caseOperandTypes) {
  p << '[';
  p.printNewline();
  if (!caseValues) {
    p << ']';
    return;
  }
  size_t index = 0;
  llvm::interleave(
      llvm::zip(caseValues, caseDestinations),
      [&](auto i) {
        p << " ";
        std::get<0>(i).print(p.getStream(), /*isSigned=*/true);
        p << ": ";
        p.printSuccessorAndUseList(std::get<1>(i), caseOperands[index++]);
      },
      [&] {
        p << ',';
        p.printNewline();
      });
  p.printNewline();
  p << ']';
}
/// Checks that case values and case destinations agree in count, that branch
/// weights (if any) cover every successor, and that the case-value element
/// type equals the condition type.
LogicalResult SwitchOp::verify() {
  auto caseValues = getCaseValues();
  const size_t numCaseDests = getCaseDestinations().size();
  const bool caseCountMismatch =
      caseValues ? caseValues->size() != static_cast<int64_t>(numCaseDests)
                 : numCaseDests != 0;
  if (caseCountMismatch)
    return emitOpError("expects number of case values to match number of "
                       "case destinations");

  auto branchWeights = getBranchWeights();
  if (branchWeights && branchWeights->size() != getNumSuccessors())
    return emitError("expects number of branch weights to match number of "
                     "successors: ")
           << branchWeights->size() << " vs " << getNumSuccessors();

  if (caseValues && getValue().getType() != caseValues->getElementType())
    return emitError("expects case value type to match condition value type");
  return success();
}
/// Successor 0 is the default destination. Successor i > 0 is case
/// destination i - 1.
SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
  assert(index < getNumSuccessors() && "invalid successor index");
  if (index == 0)
    return SuccessorOperands(getDefaultOperandsMutable());
  return SuccessorOperands(getCaseOperandsMutable(index - 1));
}
// Out-of-line definition of the static constexpr data member. Since C++17
// such members are implicitly inline, so this redeclaration is redundant
// there but harmless.
constexpr int32_t GEPOp::kDynamicIndex;
/// Returns a view that merges the raw constant index list with the dynamic
/// index operands.
GEPIndicesAdaptor<ValueRange> GEPOp::getIndices() {
  auto rawIndices = getRawConstantIndicesAttr();
  ValueRange dynamicIndices = getDynamicIndices();
  return GEPIndicesAdaptor<ValueRange>(rawIndices, dynamicIndices);
}
/// Returns the element type of a builtin, LLVM scalable, or LLVM fixed vector
/// type. Any other type is returned unchanged.
static Type extractVectorElementType(Type type) {
  return TypeSwitch<Type, Type>(type)
      .Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType>(
          [](auto vectorType) -> Type { return vectorType.getElementType(); })
      .Default([](Type other) { return other; });
}
/// Splits GEP `indices` into the raw constant-index list and the dynamic SSA
/// index operands. Each index appends one entry to `rawConstantIndices`:
/// either the constant itself or GEPOp::kDynamicIndex, in which case the
/// value is also appended to `dynamicIndices`. `currType` is walked alongside
/// the indices (starting from the GEP element type). While it is a struct,
/// SSA constants that fit in kGEPConstantBitWidth signed bits are folded into
/// constant entries.
static void destructureIndices(Type currType, ArrayRef<GEPArg> indices,
SmallVectorImpl<int32_t> &rawConstantIndices,
SmallVectorImpl<Value> &dynamicIndices) {
for (const GEPArg &iter : indices) {
// Only indices into a struct (never the first, pointer-offset index) must
// be constant; anything else may stay dynamic.
bool requiresConst = !rawConstantIndices.empty() &&
isa_and_nonnull<LLVMStructType>(currType);
if (Value val = llvm::dyn_cast_if_present<Value>(iter)) {
APInt intC;
if (requiresConst && matchPattern(val, m_ConstantInt(&intC)) &&
intC.isSignedIntN(kGEPConstantBitWidth)) {
rawConstantIndices.push_back(intC.getSExtValue());
} else {
// NOTE(review): a non-constant struct index is kept dynamic here;
// presumably the verifier reports it (see verifyStructIndices).
rawConstantIndices.push_back(GEPOp::kDynamicIndex);
dynamicIndices.push_back(val);
}
} else {
rawConstantIndices.push_back(iter.get<GEPConstantIndex>());
}
// The first index does not step into the element type, and once the type
// is unknown (null) there is nothing left to walk.
if (rawConstantIndices.size() == 1 || !currType)
continue;
// Step into the indexed element: containers yield their element type,
// structs yield the member selected by the just-recorded constant (or
// null when it is dynamic/out of range), anything else yields null.
currType =
TypeSwitch<Type, Type>(currType)
.Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType,
LLVMArrayType>([](auto containerType) {
return containerType.getElementType();
})
.Case([&](LLVMStructType structType) -> Type {
int64_t memberIndex = rawConstantIndices.back();
if (memberIndex >= 0 && static_cast<size_t>(memberIndex) <
structType.getBody().size())
return structType.getBody()[memberIndex];
return nullptr;
})
.Default(Type(nullptr));
}
}
/// Builds a GEP from mixed constant/dynamic `indices`. The indices are split
/// via destructureIndices (walking `elementType`). The caller's `attributes`
/// are added first, then the raw constant indices, the optional `inbounds`
/// unit attribute and the element-type attribute. Operands are `basePtr`
/// followed by the dynamic indices.
void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
Type elementType, Value basePtr, ArrayRef<GEPArg> indices,
bool inbounds, ArrayRef<NamedAttribute> attributes) {
SmallVector<int32_t> rawConstantIndices;
SmallVector<Value> dynamicIndices;
destructureIndices(elementType, indices, rawConstantIndices, dynamicIndices);
result.addTypes(resultType);
result.addAttributes(attributes);
result.addAttribute(getRawConstantIndicesAttrName(result.name),
builder.getDenseI32ArrayAttr(rawConstantIndices));
if (inbounds) {
result.addAttribute(getInboundsAttrName(result.name),
builder.getUnitAttr());
}
result.addAttribute(kElemTypeAttrName, TypeAttr::get(elementType));
result.addOperands(basePtr);
result.addOperands(dynamicIndices);
}
/// Convenience overload: treats every index as an SSA value and forwards to
/// the GEPArg-based builder.
void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
                  Type elementType, Value basePtr, ValueRange indices,
                  bool inbounds, ArrayRef<NamedAttribute> attributes) {
  SmallVector<GEPArg> gepArgs(indices.begin(), indices.end());
  build(builder, result, resultType, elementType, basePtr,
        ArrayRef<GEPArg>(gepArgs), inbounds, attributes);
}
/// Parses a comma-separated list where each entry is either an integer
/// literal (a constant index) or an SSA operand. The operand case records
/// GEPOp::kDynamicIndex in the raw constant list.
static ParseResult
parseGEPIndices(OpAsmParser &parser,
                SmallVectorImpl<OpAsmParser::UnresolvedOperand> &indices,
                DenseI32ArrayAttr &rawConstantIndices) {
  SmallVector<int32_t> constantIndices;
  auto parseOneIndex = [&]() -> ParseResult {
    int32_t constantIndex;
    OptionalParseResult parsedInteger =
        parser.parseOptionalInteger(constantIndex);
    if (!parsedInteger.has_value()) {
      constantIndices.push_back(LLVM::GEPOp::kDynamicIndex);
      return parser.parseOperand(indices.emplace_back());
    }
    if (failed(parsedInteger.value()))
      return failure();
    constantIndices.push_back(constantIndex);
    return success();
  };
  if (parser.parseCommaSeparatedList(parseOneIndex))
    return failure();
  rawConstantIndices =
      DenseI32ArrayAttr::get(parser.getContext(), constantIndices);
  return success();
}
/// Prints GEP indices comma-separated: dynamic indices as SSA operands and
/// constant indices as integers.
static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp,
                            OperandRange indices,
                            DenseI32ArrayAttr rawConstantIndices) {
  GEPIndicesAdaptor<OperandRange> merged(rawConstantIndices, indices);
  auto printEntry = [&](PointerUnion<IntegerAttr, Value> entry) {
    Value dynamicIndex = llvm::dyn_cast_if_present<Value>(entry);
    if (!dynamicIndex) {
      printer << entry.get<IntegerAttr>().getInt();
      return;
    }
    printer.printOperand(dynamicIndex);
  };
  llvm::interleaveComma(merged, printer, printEntry);
}
/// Recursively verifies the GEP indices from position `indexPos` onward
/// against `baseGEPType`. Struct indices must be constants within the struct
/// body. Vector/array types descend into their element type. Any other type
/// reached while indices remain is an error. Succeeds once all indices have
/// been consumed.
static LogicalResult
verifyStructIndices(Type baseGEPType, unsigned indexPos,
GEPIndicesAdaptor<ValueRange> indices,
function_ref<InFlightDiagnostic()> emitOpError) {
// No indices left to check.
if (indexPos >= indices.size())
return success();
return TypeSwitch<Type, LogicalResult>(baseGEPType)
.Case<LLVMStructType>([&](LLVMStructType structType) -> LogicalResult {
if (!indices[indexPos].is<IntegerAttr>())
return emitOpError() << "expected index " << indexPos
<< " indexing a struct to be constant";
int32_t gepIndex = indices[indexPos].get<IntegerAttr>().getInt();
ArrayRef<Type> elementTypes = structType.getBody();
if (gepIndex < 0 ||
static_cast<size_t>(gepIndex) >= elementTypes.size())
return emitOpError() << "index " << indexPos
<< " indexing a struct is out of bounds";
// Only the selected member is descended into.
return verifyStructIndices(elementTypes[gepIndex], indexPos + 1,
indices, emitOpError);
})
.Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType,
LLVMArrayType>([&](auto containerType) -> LogicalResult {
return verifyStructIndices(containerType.getElementType(), indexPos + 1,
indices, emitOpError);
})
.Default([&](auto otherType) -> LogicalResult {
return emitOpError()
<< "type " << otherType << " cannot be indexed (index #"
<< indexPos << ")";
});
}
/// Entry point for GEP index verification. Index #0 only offsets the base
/// pointer, so checking against the element type starts at position 1.
static LogicalResult
verifyStructIndices(Type baseGEPType, GEPIndicesAdaptor<ValueRange> indices,
                    function_ref<InFlightDiagnostic()> emitOpError) {
  constexpr unsigned kFirstAggregateIndex = 1;
  return verifyStructIndices(baseGEPType, kFirstAggregateIndex, indices,
                             emitOpError);
}
/// Checks that every kDynamicIndex marker has a matching dynamic operand, then
/// verifies the indices against the element type.
LogicalResult LLVM::GEPOp::verify() {
  const auto numDynamicMarkers = static_cast<size_t>(
      llvm::count(getRawConstantIndices(), kDynamicIndex));
  if (numDynamicMarkers != getDynamicIndices().size())
    return emitOpError("expected as many dynamic indices as specified in '")
           << getRawConstantIndicesAttrName().getValue() << "'";
  auto diagnose = [&] { return emitOpError(); };
  return verifyStructIndices(getElemType(), getIndices(), diagnose);
}
/// Returns the type of the value the GEP result points to. The element type
/// is walked with every index except the first, which only offsets the base
/// pointer. Arrays step into their element type, and destructurable types
/// (e.g. structs) step into the member selected by a constant index.
///
/// Returns a null Type when the walk cannot be completed: a member index that
/// is out of range (getTypeAtIndex yields null), a dynamic index into a
/// destructurable type, or a type that is neither an array nor
/// destructurable. Previously these cases reached an asserting `cast` on a
/// null or mismatching value.
Type GEPOp::getResultPtrElementType() {
  Type selectedType = getElemType();
  // Keep the adaptor alive for the duration of the loop.
  auto indices = getIndices();
  for (GEPIndicesAdaptor<ValueRange>::value_type index :
       llvm::drop_begin(indices)) {
    if (auto arrayType = dyn_cast_if_present<LLVMArrayType>(selectedType)) {
      selectedType = arrayType.getElementType();
      continue;
    }
    auto aggregate =
        dyn_cast_if_present<DestructurableTypeInterface>(selectedType);
    auto constantIndex = llvm::dyn_cast_if_present<IntegerAttr>(index);
    if (!aggregate || !constantIndex)
      return Type();
    selectedType = aggregate.getTypeAtIndex(constantIndex);
  }
  return selectedType;
}
/// A load reads its address operand. Volatile loads and atomic loads with an
/// ordering other than not_atomic/unordered additionally report unattributed
/// Write and Read effects.
void LoadOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  effects.emplace_back(MemoryEffects::Read::get(), &getAddrMutable());
  const AtomicOrdering ordering = getOrdering();
  const bool isOrderedAtomic = ordering != AtomicOrdering::not_atomic &&
                               ordering != AtomicOrdering::unordered;
  if (!getVolatile_() && !isOrderedAtomic)
    return;
  effects.emplace_back(MemoryEffects::Write::get());
  effects.emplace_back(MemoryEffects::Read::get());
}
/// Returns true when `type` is an integer, LLVM pointer, or LLVM-compatible
/// floating-point type whose fixed (non-scalable) size in bits is a power of
/// two of at least 8.
static bool isTypeCompatibleWithAtomicOp(Type type,
                                         const DataLayout &dataLayout) {
  const bool isSupportedKind = isa<IntegerType, LLVMPointerType>(type) ||
                               isCompatibleFloatingPointType(type);
  if (!isSupportedKind)
    return false;
  llvm::TypeSize size = dataLayout.getTypeSizeInBits(type);
  if (size.isScalable())
    return false;
  const uint64_t bits = size.getFixedValue();
  return bits >= 8 && (bits & (bits - 1)) == 0;
}
/// Shared verifier for atomic-capable memory ops.
/// Non-atomic ops must not carry a syncscope. Atomic ops need a supported
/// value type, an ordering not listed in `unsupportedOrderings`, and an
/// explicit alignment.
template <typename OpTy>
LogicalResult verifyAtomicMemOp(OpTy memOp, Type valueType,
                                ArrayRef<AtomicOrdering> unsupportedOrderings) {
  const AtomicOrdering ordering = memOp.getOrdering();
  if (ordering == AtomicOrdering::not_atomic) {
    if (!memOp.getSyncscope())
      return success();
    return memOp.emitOpError(
        "expected syncscope to be null for non-atomic access");
  }

  DataLayout dataLayout = DataLayout::closest(memOp);
  if (!isTypeCompatibleWithAtomicOp(valueType, dataLayout))
    return memOp.emitOpError("unsupported type ")
           << valueType << " for atomic access";
  if (llvm::is_contained(unsupportedOrderings, ordering))
    return memOp.emitOpError("unsupported ordering '")
           << stringifyAtomicOrdering(ordering) << "'";
  if (!memOp.getAlignment())
    return memOp.emitOpError("expected alignment for atomic access");
  return success();
}
/// Atomic-access checks for loads; `release` and `acq_rel` are rejected.
LogicalResult LoadOp::verify() {
  const AtomicOrdering rejected[] = {AtomicOrdering::release,
                                     AtomicOrdering::acq_rel};
  return verifyAtomicMemOp(*this, getResult().getType(), rejected);
}
/// Convenience builder taking plain C++ values. A zero `alignment` and an
/// empty `syncscope` become null (absent) attributes; the trailing optional
/// attributes of the generated builder are all left null.
void LoadOp::build(OpBuilder &builder, OperationState &state, Type type,
Value addr, unsigned alignment, bool isVolatile,
bool isNonTemporal, bool isInvariant,
AtomicOrdering ordering, StringRef syncscope) {
build(builder, state, type, addr,
alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile,
isNonTemporal, isInvariant, ordering,
syncscope.empty() ? nullptr : builder.getStringAttr(syncscope),
nullptr,
nullptr, nullptr,
nullptr);
}
/// A store writes its address operand. Volatile stores and atomic stores with
/// an ordering other than not_atomic/unordered additionally report
/// unattributed Write and Read effects.
void StoreOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  effects.emplace_back(MemoryEffects::Write::get(), &getAddrMutable());
  const AtomicOrdering ordering = getOrdering();
  const bool isOrderedAtomic = ordering != AtomicOrdering::not_atomic &&
                               ordering != AtomicOrdering::unordered;
  if (!getVolatile_() && !isOrderedAtomic)
    return;
  effects.emplace_back(MemoryEffects::Write::get());
  effects.emplace_back(MemoryEffects::Read::get());
}
/// Atomic-access checks for stores; `acquire` and `acq_rel` are rejected.
LogicalResult StoreOp::verify() {
  const AtomicOrdering rejected[] = {AtomicOrdering::acquire,
                                     AtomicOrdering::acq_rel};
  return verifyAtomicMemOp(*this, getValue().getType(), rejected);
}
/// Convenience builder taking plain C++ values. A zero `alignment` and an
/// empty `syncscope` become null (absent) attributes; the trailing optional
/// attributes of the generated builder are all left null.
void StoreOp::build(OpBuilder &builder, OperationState &state, Value value,
Value addr, unsigned alignment, bool isVolatile,
bool isNonTemporal, AtomicOrdering ordering,
StringRef syncscope) {
build(builder, state, value, addr,
alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile,
isNonTemporal, ordering,
syncscope.empty() ? nullptr : builder.getStringAttr(syncscope),
nullptr,
nullptr, nullptr, nullptr);
}
/// Result types of a call to `calleeType`: none for a void return, otherwise
/// the single return type.
static SmallVector<Type, 1> getCallOpResultTypes(LLVMFunctionType calleeType) {
  Type returnType = calleeType.getReturnType();
  if (isa<LLVM::LLVMVoidType>(returnType))
    return {};
  return {returnType};
}
/// Returns `calleeType` as a TypeAttr when it is variadic, else a null attr.
static TypeAttr getCallOpVarCalleeType(LLVMFunctionType calleeType) {
  if (!calleeType.isVarArg())
    return nullptr;
  return TypeAttr::get(calleeType);
}
/// Builds a non-variadic LLVM function type from the argument types of `args`.
/// The return type is the first of `results`, or void if there are none.
static LLVMFunctionType getLLVMFuncType(MLIRContext *context, TypeRange results,
                                        ValueRange args) {
  Type returnType = results.empty() ? Type(LLVMVoidType::get(context))
                                    : results.front();
  auto argTypes = llvm::to_vector(args.getTypes());
  return LLVMFunctionType::get(returnType, argTypes, false);
}
/// Direct-call builder taking the callee name as a string.
void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
                   StringRef callee, ValueRange args) {
  StringAttr calleeName = builder.getStringAttr(callee);
  build(builder, state, results, calleeName, args);
}
/// Direct-call builder taking the callee name as a StringAttr.
void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
                   StringAttr callee, ValueRange args) {
  FlatSymbolRefAttr calleeRef = SymbolRefAttr::get(callee);
  build(builder, state, results, calleeRef, args);
}
/// Direct-call builder with explicit result types. Forwards to the generated
/// builder with no variadic callee type (the null before `callee`), and with
/// every optional attribute after `args` left null.
void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
FlatSymbolRefAttr callee, ValueRange args) {
assert(callee && "expected non-null callee in direct call builder");
build(builder, state, results,
nullptr, callee, args, nullptr,
nullptr,
nullptr, nullptr,
nullptr, nullptr,
nullptr, nullptr);
}
/// Direct-call builder deriving results from `calleeType`; callee given as a
/// string.
void CallOp::build(OpBuilder &builder, OperationState &state,
                   LLVMFunctionType calleeType, StringRef callee,
                   ValueRange args) {
  StringAttr calleeName = builder.getStringAttr(callee);
  build(builder, state, calleeType, calleeName, args);
}
/// Direct-call builder deriving results from `calleeType`; callee given as a
/// StringAttr.
void CallOp::build(OpBuilder &builder, OperationState &state,
                   LLVMFunctionType calleeType, StringAttr callee,
                   ValueRange args) {
  FlatSymbolRefAttr calleeRef = SymbolRefAttr::get(callee);
  build(builder, state, calleeType, calleeRef, args);
}
/// Direct-call builder deriving the result types and (for variadic callees)
/// the var_callee_type from `calleeType`. All optional attributes after
/// `args` are left null.
void CallOp::build(OpBuilder &builder, OperationState &state,
LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
ValueRange args) {
build(builder, state, getCallOpResultTypes(calleeType),
getCallOpVarCalleeType(calleeType), callee, args,
nullptr,
nullptr, nullptr,
nullptr, nullptr,
nullptr, nullptr, nullptr);
}
/// Indirect-call builder: the callee symbol is null, so the callee pointer is
/// expected to be the first element of `args`. Result types and
/// var_callee_type come from `calleeType`; remaining optional attributes are
/// left null.
void CallOp::build(OpBuilder &builder, OperationState &state,
LLVMFunctionType calleeType, ValueRange args) {
build(builder, state, getCallOpResultTypes(calleeType),
getCallOpVarCalleeType(calleeType),
nullptr, args,
nullptr, nullptr,
nullptr, nullptr,
nullptr, nullptr,
nullptr, nullptr);
}
/// Direct call to `func`, taking result types and var_callee_type from its
/// function type. Remaining optional attributes are left null.
void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
ValueRange args) {
auto calleeType = func.getFunctionType();
build(builder, state, getCallOpResultTypes(calleeType),
getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), args,
nullptr, nullptr,
nullptr, nullptr,
nullptr, nullptr,
nullptr, nullptr);
}
/// Direct calls return the callee symbol. Indirect calls return operand #0,
/// the callee pointer.
CallInterfaceCallable CallOp::getCallableForCallee() {
  FlatSymbolRefAttr calleeAttr = getCalleeAttr();
  if (!calleeAttr)
    return getOperand(0);
  return calleeAttr;
}
/// Replaces the callee. Direct calls take a (flat) symbol reference; indirect
/// calls take a Value that replaces operand #0.
void CallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
  if (!getCalleeAttr()) {
    setOperand(0, callee.get<Value>());
    return;
  }
  setCalleeAttr(cast<FlatSymbolRefAttr>(callee.get<SymbolRefAttr>()));
}
/// Call arguments: all operands, minus the leading callee pointer for
/// indirect calls.
Operation::operand_range CallOp::getArgOperands() {
  const unsigned numLeadingCallee = getCallee().has_value() ? 0 : 1;
  return getOperands().drop_front(numLeadingCallee);
}
/// Mutable view of the call arguments: the callee operands, minus the leading
/// callee pointer for indirect calls.
///
/// For an indirect call the range starts at operand 1, so its length must be
/// one less than the callee-operand count. Using the full count overran the
/// operand list by one and violated MutableOperandRange's bounds check.
MutableOperandRange CallOp::getArgOperandsMutable() {
  const unsigned start = getCallee().has_value() ? 0 : 1;
  const unsigned numCalleeOperands = getCalleeOperands().size();
  return MutableOperandRange(*this, start, numCalleeOperands - start);
}
/// When both the calling function and the (defined) callee carry a
/// DISubprogram fused location, the call itself must not have an UnknownLoc.
/// Calls to external callees, and calls outside a function, are not checked.
static LogicalResult verifyCallOpDebugInfo(CallOp callOp, LLVMFuncOp callee) {
  if (callee.isExternal())
    return success();
  auto parentFunc = callOp->getParentOfType<FunctionOpInterface>();
  if (!parentFunc)
    return success();

  auto hasSubprogram = [](Operation *op) {
    return op->getLoc()
               ->findInstanceOf<FusedLocWith<LLVM::DISubprogramAttr>>() !=
           nullptr;
  };
  const bool bothHaveSubprogram =
      hasSubprogram(parentFunc) && hasSubprogram(callee);
  if (!bothHaveSubprogram || !isa<UnknownLoc>(callOp->getLoc()))
    return success();
  return callOp.emitError()
         << "inlinable function call in a function with a DISubprogram "
            "location must have a debug location";
}
/// Verifies the optional var_callee_type of a call-like op. It must be
/// variadic and must not have more parameters than the op has argument
/// operands. Its parameter types must match the leading argument types, and
/// its return type must match the op's result (void when there is none).
template <typename OpTy>
LogicalResult verifyCallOpVarCalleeType(OpTy callOp) {
  std::optional<LLVMFunctionType> varCalleeType = callOp.getVarCalleeType();
  if (!varCalleeType.has_value())
    return success();
  LLVMFunctionType calleeTy = *varCalleeType;

  if (!calleeTy.isVarArg())
    return callOp.emitOpError(
        "expected var_callee_type to be a variadic function type");

  auto argOperands = callOp.getArgOperands();
  if (calleeTy.getNumParams() > argOperands.size())
    return callOp.emitOpError("expected var_callee_type to have at most ")
           << argOperands.size() << " parameters";

  for (auto entry : llvm::zip(calleeTy.getParams(), argOperands)) {
    Type paramType = std::get<0>(entry);
    Type argType = std::get<1>(entry).getType();
    if (paramType != argType)
      return callOp.emitOpError()
             << "var_callee_type parameter type mismatch: " << paramType
             << " != " << argType;
  }

  Type declaredReturn = calleeTy.getReturnType();
  if (callOp.getNumResults() == 0) {
    if (isa<LLVMVoidType>(declaredReturn))
      return success();
    return callOp.emitOpError("expected var_callee_type to return void");
  }
  if (callOp.getResult().getType() == declaredReturn)
    return success();
  return callOp.emitOpError("var_callee_type return type mismatch: ")
         << declaredReturn << " != " << callOp.getResult().getType();
}
/// Verifies the call against its callee.
/// Indirect calls (no `callee` attribute) only check that operand #0 exists
/// and is an LLVM pointer, then succeed. Direct calls must reference an
/// LLVMFuncOp and pass verifyCallOpDebugInfo. For direct calls, operand
/// count/types and the result are then checked against the callee's function
/// type; variadic callees accept extra trailing operands.
LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
if (failed(verifyCallOpVarCalleeType(*this)))
return failure();
Type fnType;
bool isIndirect = false;
FlatSymbolRefAttr calleeName = getCalleeAttr();
if (!calleeName) {
isIndirect = true;
if (!getNumOperands())
return emitOpError(
"must have either a `callee` attribute or at least an operand");
auto ptrType = llvm::dyn_cast<LLVMPointerType>(getOperand(0).getType());
if (!ptrType)
return emitOpError("indirect call expects a pointer as callee: ")
<< getOperand(0).getType();
// NOTE(review): indirect calls stop here, so `isIndirect` is always false
// in the checks below and the `- isIndirect` adjustments are no-ops.
return success();
} else {
Operation *callee =
symbolTable.lookupNearestSymbolFrom(*this, calleeName.getAttr());
if (!callee)
return emitOpError()
<< "'" << calleeName.getValue()
<< "' does not reference a symbol in the current scope";
auto fn = dyn_cast<LLVMFuncOp>(callee);
if (!fn)
return emitOpError() << "'" << calleeName.getValue()
<< "' does not reference a valid LLVM function";
if (failed(verifyCallOpDebugInfo(*this, fn)))
return failure();
fnType = fn.getFunctionType();
}
LLVMFunctionType funcType = llvm::dyn_cast<LLVMFunctionType>(fnType);
if (!funcType)
return emitOpError("callee does not have a functional type: ") << fnType;
// Variadic callees must be called with an explicit var_callee_type.
if (funcType.isVarArg() && !getVarCalleeType())
return emitOpError() << "missing var_callee_type attribute for vararg call";
// Non-variadic: exact arity; variadic: at least the fixed parameters.
if (!funcType.isVarArg() &&
funcType.getNumParams() != (getNumOperands() - isIndirect))
return emitOpError() << "incorrect number of operands ("
<< (getNumOperands() - isIndirect)
<< ") for callee (expecting: "
<< funcType.getNumParams() << ")";
if (funcType.getNumParams() > (getNumOperands() - isIndirect))
return emitOpError() << "incorrect number of operands ("
<< (getNumOperands() - isIndirect)
<< ") for varargs callee (expecting at least: "
<< funcType.getNumParams() << ")";
// Fixed parameters must match operand types exactly.
for (unsigned i = 0, e = funcType.getNumParams(); i != e; ++i)
if (getOperand(i + isIndirect).getType() != funcType.getParamType(i))
return emitOpError() << "operand type mismatch for operand " << i << ": "
<< getOperand(i + isIndirect).getType()
<< " != " << funcType.getParamType(i);
// Result checks: void callee <=> no results; otherwise a single result of
// the callee's return type.
if (getNumResults() == 0 &&
!llvm::isa<LLVM::LLVMVoidType>(funcType.getReturnType()))
return emitOpError() << "expected function call to produce a value";
if (getNumResults() != 0 &&
llvm::isa<LLVM::LLVMVoidType>(funcType.getReturnType()))
return emitOpError()
<< "calling function with void result must not produce values";
if (getNumResults() > 1)
return emitOpError()
<< "expected LLVM function call to produce 0 or 1 result";
if (getNumResults() && getResult().getType() != funcType.getReturnType())
return emitOpError() << "result type mismatch: " << getResult().getType()
<< " != " << funcType.getReturnType();
return success();
}
/// Prints `llvm.call` in its custom form:
///   [cconv] [tailcallkind] (@callee | %ptr)(args) [vararg(type)] attrs
///     : [ptr-type,] (arg-types) -> result-types
void CallOp::print(OpAsmPrinter &p) {
  auto callee = getCallee();
  const bool isIndirect = !callee.has_value();
  // For indirect calls the first operand is the function pointer, not an arg.
  auto argOperands = getOperands().drop_front(isIndirect ? 1 : 0);

  p << ' ';
  // Non-default calling convention / tail-call kind are printed as keywords.
  if (auto cconv = getCConv(); cconv != LLVM::CConv::C)
    p << stringifyCConv(cconv) << ' ';
  if (auto tailKind = getTailCallKind(); tailKind != LLVM::TailCallKind::None)
    p << tailcallkind::stringifyTailCallKind(tailKind) << ' ';

  if (isIndirect)
    p << getOperand(0);
  else
    p.printSymbolName(*callee);
  p << '(' << argOperands << ')';

  if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
    p << " vararg(" << *varCalleeType << ")";

  // Attributes already encoded in the custom syntax are elided.
  p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
                          {getCalleeAttrName(), getTailCallKindAttrName(),
                           getVarCalleeTypeAttrName(), getCConvAttrName()});

  p << " : ";
  if (isIndirect)
    p << getOperand(0).getType() << ", ";
  p.printFunctionalType(argOperands.getTypes(), getResultTypes());
}
/// Parses the trailing `: [ptr-type,] function-type` of a call/invoke and
/// resolves `operands` against it.
///
/// Direct calls carry exactly one trailing type (the function type). Indirect
/// calls carry two: the callee pointer type followed by the function type. The
/// function may have at most one, non-void, result, which is added to `result`.
static ParseResult parseCallTypeAndResolveOperands(
    OpAsmParser &parser, OperationState &result, bool isDirect,
    ArrayRef<OpAsmParser::UnresolvedOperand> operands) {
  SMLoc typesLoc = parser.getCurrentLocation();
  SmallVector<Type> trailingTypes;
  if (parser.parseColonTypeList(trailingTypes))
    return failure();

  const size_t expectedCount = isDirect ? 1 : 2;
  if (trailingTypes.size() != expectedCount)
    return parser.emitError(
        typesLoc, isDirect ? "expected direct call to have 1 trailing type"
                           : "expected indirect call to have 2 trailing types");

  auto fnTy = llvm::dyn_cast<FunctionType>(trailingTypes.back());
  trailingTypes.pop_back();
  if (!fnTy)
    return parser.emitError(typesLoc, "expected trailing function type");

  unsigned numResults = fnTy.getNumResults();
  if (numResults > 1)
    return parser.emitError(typesLoc, "expected function with 0 or 1 result");
  if (numResults == 1 && llvm::isa<LLVM::LLVMVoidType>(fnTy.getResult(0)))
    return parser.emitError(typesLoc, "expected a non-void result type");

  // Operand types: [callee pointer type,] followed by the function inputs.
  llvm::append_range(trailingTypes, fnTy.getInputs());
  if (parser.resolveOperands(operands, trailingTypes, parser.getNameLoc(),
                             result.operands))
    return failure();
  result.addTypes(fnTy.getResults());
  return success();
}
/// Parses an optional SSA operand naming the callee function pointer of an
/// indirect call/invoke and, if present, appends it to `operands`. Fails only
/// if an operand was started but could not be parsed.
static ParseResult parseOptionalCallFuncPtr(
    OpAsmParser &parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands) {
  OpAsmParser::UnresolvedOperand calleePtr;
  OptionalParseResult parsed = parser.parseOptionalOperand(calleePtr);
  // No operand at all: nothing to record (direct call).
  if (!parsed.has_value())
    return success();
  if (failed(*parsed))
    return failure();
  operands.push_back(calleePtr);
  return success();
}
/// Parses the custom form of `llvm.call` (see CallOp::print for the syntax).
ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
  MLIRContext *ctx = parser.getContext();
  SymbolRefAttr calleeAttr;
  TypeAttr varCalleeTypeAttr;
  SmallVector<OpAsmParser::UnresolvedOperand> operands;

  // Optional keywords, defaulting to the C convention and no tail-call kind.
  auto cconv = parseOptionalLLVMKeyword<CConv>(parser, result, LLVM::CConv::C);
  result.addAttribute(getCConvAttrName(result.name),
                      CConvAttr::get(ctx, cconv));
  auto tailCallKind = parseOptionalLLVMKeyword<TailCallKind>(
      parser, result, LLVM::TailCallKind::None);
  result.addAttribute(getTailCallKindAttrName(result.name),
                      TailCallKindAttr::get(ctx, tailCallKind));

  // Callee: an SSA function pointer (indirect) or a symbol (direct).
  if (parseOptionalCallFuncPtr(parser, operands))
    return failure();
  const bool isDirect = operands.empty();
  if (isDirect &&
      parser.parseAttribute(calleeAttr, "callee", result.attributes))
    return failure();

  if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren))
    return failure();

  // Optional `vararg(<function type>)` clause.
  if (parser.parseOptionalKeyword("vararg").succeeded()) {
    StringAttr attrName = CallOp::getVarCalleeTypeAttrName(result.name);
    if (parser.parseLParen().failed() ||
        parser.parseAttribute(varCalleeTypeAttr, attrName, result.attributes)
            .failed() ||
        parser.parseRParen().failed())
      return failure();
  }

  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return parseCallTypeAndResolveOperands(parser, result, isDirect, operands);
}
/// Returns the callee's function type: the explicit `var_callee_type` when
/// present, otherwise a type rebuilt from the result types and arguments.
LLVMFunctionType CallOp::getCalleeFunctionType() {
  std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType();
  if (!varCalleeType)
    return getLLVMFuncType(getContext(), getResultTypes(), getArgOperands());
  return *varCalleeType;
}
/// Builds a direct invoke of `func`. Result types and the optional
/// `var_callee_type` are derived from the function's type.
void InvokeOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
                     ValueRange ops, Block *normal, ValueRange normalOps,
                     Block *unwind, ValueRange unwindOps) {
  auto fnType = func.getFunctionType();
  auto resultTypes = getCallOpResultTypes(fnType);
  auto varCalleeType = getCallOpVarCalleeType(fnType);
  build(builder, state, resultTypes, varCalleeType, SymbolRefAttr::get(func),
        ops, normalOps, unwindOps, nullptr, nullptr, normal, unwind);
}
/// Builds an invoke of the symbol `callee` with explicit result types and no
/// `var_callee_type`.
void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
                     FlatSymbolRefAttr callee, ValueRange ops, Block *normal,
                     ValueRange normalOps, Block *unwind,
                     ValueRange unwindOps) {
  build(builder, state, tys, /*var_callee_type=*/nullptr, callee, ops,
        normalOps, unwindOps, nullptr, nullptr, normal, unwind);
}
/// Builds an invoke of the symbol `callee`, deriving result types and the
/// optional `var_callee_type` from `calleeType`.
void InvokeOp::build(OpBuilder &builder, OperationState &state,
                     LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
                     ValueRange ops, Block *normal, ValueRange normalOps,
                     Block *unwind, ValueRange unwindOps) {
  auto resultTypes = getCallOpResultTypes(calleeType);
  auto varCalleeType = getCallOpVarCalleeType(calleeType);
  build(builder, state, resultTypes, varCalleeType, callee, ops, normalOps,
        unwindOps, nullptr, nullptr, normal, unwind);
}
/// Returns the operands forwarded to successor `index`: 0 is the normal
/// destination, 1 the unwind destination.
SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) {
  assert(index < getNumSuccessors() && "invalid successor index");
  if (index == 0)
    return SuccessorOperands(getNormalDestOperandsMutable());
  return SuccessorOperands(getUnwindDestOperandsMutable());
}
/// Returns the callee symbol for direct invokes, or the function-pointer
/// operand (operand 0) for indirect ones.
CallInterfaceCallable InvokeOp::getCallableForCallee() {
  FlatSymbolRefAttr calleeAttr = getCalleeAttr();
  if (!calleeAttr)
    return getOperand(0);
  return calleeAttr;
}
/// Replaces the callee. The new callable's kind must match the current one:
/// a symbol for direct invokes, a Value for indirect ones.
void InvokeOp::setCalleeFromCallable(CallInterfaceCallable callee) {
  // Indirect invoke: the callee lives in operand 0.
  if (!getCalleeAttr())
    return setOperand(0, callee.get<Value>());
  setCalleeAttr(cast<FlatSymbolRefAttr>(callee.get<SymbolRefAttr>()));
}
/// Returns the arguments passed to the callee: the callee operand segment
/// without the leading function pointer of an indirect invoke.
///
/// Only the callee segment is used. `getOperands()` would also include the
/// normal/unwind destination operands, which are not call arguments.
Operation::operand_range InvokeOp::getArgOperands() {
  return getCalleeOperands().drop_front(getCallee().has_value() ? 0 : 1);
}
/// Returns a mutable view of the callee arguments, i.e. the callee operand
/// segment minus the leading function pointer of an indirect invoke.
///
/// The callee segment occupies operands [0, getCalleeOperands().size()). The
/// argument range therefore starts after the optional function pointer, and
/// its length is the segment size minus that offset. Using the full segment
/// size as the length would extend one operand into the successor operands
/// for indirect invokes.
MutableOperandRange InvokeOp::getArgOperandsMutable() {
  unsigned firstArg = getCallee().has_value() ? 0 : 1;
  unsigned numArgs = getCalleeOperands().size() - firstArg;
  return MutableOperandRange(*this, firstArg, numArgs);
}
/// Verifies `var_callee_type` consistency and that the unwind destination
/// begins with an `llvm.landingpad`.
LogicalResult InvokeOp::verify() {
  if (failed(verifyCallOpVarCalleeType(*this)))
    return failure();

  Block *unwindDest = getUnwindDest();
  if (unwindDest->empty())
    return emitError("must have at least one operation in unwind destination");
  if (isa<LandingpadOp>(unwindDest->front()))
    return success();
  return emitError("first operation in unwind destination should be a "
                   "llvm.landingpad operation");
}
/// Prints the custom form of `llvm.invoke`:
///   [cconv] (@callee | %ptr)(args) to ^normal(ops) unwind ^unwind(ops)
///     [vararg(type)] attrs : [ptr-type,] (arg-types) -> result-types
///
/// Only the callee operand segment goes into the argument list and the
/// functional type. Successor operands are printed with their destinations.
/// Printing all operands here would make the parser treat the successor
/// operands as extra call arguments, duplicating them on every round-trip.
void InvokeOp::print(OpAsmPrinter &p) {
  auto callee = getCallee();
  bool isDirect = callee.has_value();
  // Call arguments, excluding the function pointer of an indirect invoke.
  auto argOperands = getCalleeOperands().drop_front(isDirect ? 0 : 1);

  p << ' ';
  if (getCConv() != LLVM::CConv::C)
    p << stringifyCConv(getCConv()) << ' ';
  if (isDirect)
    p.printSymbolName(callee.value());
  else
    p << getOperand(0);
  p << '(' << argOperands << ')';
  p << " to ";
  p.printSuccessorAndUseList(getNormalDest(), getNormalDestOperands());
  p << " unwind ";
  p.printSuccessorAndUseList(getUnwindDest(), getUnwindDestOperands());
  if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
    p << " vararg(" << *varCalleeType << ")";
  p.printOptionalAttrDict((*this)->getAttrs(),
                          {getCalleeAttrName(), getOperandSegmentSizeAttr(),
                           getCConvAttrName(), getVarCalleeTypeAttrName()});
  p << " : ";
  if (!isDirect)
    p << getOperand(0).getType() << ", ";
  p.printFunctionalType(argOperands.getTypes(), getResultTypes());
}
/// Parses the custom form of `llvm.invoke`:
///   [cconv] (@callee | %ptr)(args) to ^normal(ops) unwind ^unwind(ops)
///     [vararg(type)] attrs : [ptr-type,] (arg-types) -> result-types
/// Operands are recorded as three segments: callee operands (including the
/// function pointer of an indirect invoke), normal-destination operands and
/// unwind-destination operands.
ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 8> operands;
  SymbolRefAttr funcAttr;
  TypeAttr varCalleeType;
  Block *normalDest, *unwindDest;
  SmallVector<Value, 4> normalOperands, unwindOperands;
  Builder &builder = parser.getBuilder();
  // Optional calling-convention keyword; defaults to the C convention.
  result.addAttribute(
      getCConvAttrName(result.name),
      CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>(
                                              parser, result, LLVM::CConv::C)));
  // A leading SSA operand marks an indirect invoke; otherwise a symbol names
  // the callee.
  if (parseOptionalCallFuncPtr(parser, operands))
    return failure();
  bool isDirect = operands.empty();
  if (isDirect && parser.parseAttribute(funcAttr, "callee", result.attributes))
    return failure();
  if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
      parser.parseKeyword("to") ||
      parser.parseSuccessorAndUseList(normalDest, normalOperands) ||
      parser.parseKeyword("unwind") ||
      parser.parseSuccessorAndUseList(unwindDest, unwindOperands))
    return failure();
  // Optional `vararg(<function type>)` clause, stored as var_callee_type.
  bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded();
  if (isVarArg) {
    StringAttr varCalleeTypeAttrName =
        InvokeOp::getVarCalleeTypeAttrName(result.name);
    if (parser.parseLParen().failed() ||
        parser
            .parseAttribute(varCalleeType, varCalleeTypeAttrName,
                            result.attributes)
            .failed() ||
        parser.parseRParen().failed())
      return failure();
  }
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  // Resolves the callee operands and adds them (plus the result type).
  if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands))
    return failure();
  // Successor operands are already resolved Values; they are appended after
  // the callee operands, matching the segment order recorded below.
  result.addSuccessors({normalDest, unwindDest});
  result.addOperands(normalOperands);
  result.addOperands(unwindOperands);
  result.addAttribute(InvokeOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(
                          {static_cast<int32_t>(operands.size()),
                           static_cast<int32_t>(normalOperands.size()),
                           static_cast<int32_t>(unwindOperands.size())}));
  return success();
}
/// Returns the callee's function type: the explicit `var_callee_type` when
/// present, otherwise a type rebuilt from the result types and arguments.
LLVMFunctionType InvokeOp::getCalleeFunctionType() {
  std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType();
  if (!varCalleeType)
    return getLLVMFuncType(getContext(), getResultTypes(), getArgOperands());
  return *varCalleeType;
}
/// Verifies `llvm.landingpad`.
///
/// An enclosing `llvm.func` must have a personality. The op must have at
/// least one clause or the `cleanup` flag. Every non-array (catch) clause must
/// be a known constant: a bitcast of an addressof, a `ZeroOp` (null), or an
/// addressof. Array-typed (filter) clauses are accepted as-is.
LogicalResult LandingpadOp::verify() {
  if (auto func = (*this)->getParentOfType<LLVMFuncOp>();
      func && !func.getPersonality())
    return emitError(
        "llvm.landingpad needs to be in a function with a personality");

  if (!getCleanup() && getOperands().empty())
    return emitError("landingpad instruction expects at least one clause or "
                     "cleanup attribute");

  unsigned numClauses = getNumOperands();
  for (unsigned idx = 0; idx != numClauses; ++idx) {
    Value clause = getOperand(idx);
    if (llvm::isa<LLVMArrayType>(clause.getType()))
      continue;

    if (auto bitcast = clause.getDefiningOp<BitcastOp>()) {
      if (bitcast.getArg().getDefiningOp<AddressOfOp>())
        continue;
      return emitError("constant clauses expected")
                 .attachNote(bitcast.getLoc())
             << "global addresses expected as operand to "
                "bitcast used in clauses for landingpad";
    }

    if (clause.getDefiningOp<ZeroOp>() || clause.getDefiningOp<AddressOfOp>())
      continue;
    return emitError("clause #")
           << idx << " is not a known constant - null, addressof, bitcast";
  }
  return success();
}
/// Prints `[cleanup] ((filter|catch) %v : type)* attrs : result-type`. The
/// clause keyword is derived from the operand type: array types are filters.
void LandingpadOp::print(OpAsmPrinter &p) {
  p << ' ';
  if (getCleanup())
    p << "cleanup ";
  for (Value clause : getOperands()) {
    StringRef kind =
        llvm::isa<LLVMArrayType>(clause.getType()) ? "filter " : "catch ";
    p << '(' << kind << clause << " : " << clause.getType() << ") ";
  }
  p.printOptionalAttrDict((*this)->getAttrs(), {"cleanup"});
  p << ": " << getType();
}
/// Parses `[cleanup] ((catch|filter) %v : type)* : result-type`.
/// The `catch`/`filter` keyword is consumed but not stored; the printer
/// re-derives it from whether the clause type is an LLVM array.
/// NOTE(review): if a `(` is not followed by `catch` or `filter`, the loop
/// exits with that `(` already consumed, so the error is reported by the
/// following colon parse.
ParseResult LandingpadOp::parse(OpAsmParser &parser, OperationState &result) {
  if (succeeded(parser.parseOptionalKeyword("cleanup")))
    result.addAttribute("cleanup", parser.getBuilder().getUnitAttr());
  while (succeeded(parser.parseOptionalLParen()) &&
         (succeeded(parser.parseOptionalKeyword("filter")) ||
          succeeded(parser.parseOptionalKeyword("catch")))) {
    OpAsmParser::UnresolvedOperand operand;
    Type ty;
    if (parser.parseOperand(operand) || parser.parseColon() ||
        parser.parseType(ty) ||
        parser.resolveOperand(operand, ty, result.operands) ||
        parser.parseRParen())
      return failure();
  }
  // Trailing result type.
  Type type;
  if (parser.parseColon() || parser.parseType(type))
    return failure();
  result.addTypes(type);
  return success();
}
/// Computes the type of the element addressed by `position` inside
/// `containerType`, stepping through nested LLVM arrays and structs one index
/// at a time.
///
/// Emits a diagnostic through `emitError` and returns a null type if the
/// container is not LLVM-compatible, if an index steps into a non-aggregate
/// type, or if an index is out of bounds.
static Type getInsertExtractValueElementType(
    function_ref<InFlightDiagnostic(StringRef)> emitError, Type containerType,
    ArrayRef<int64_t> position) {
  Type llvmType = containerType;
  if (!isCompatibleType(containerType)) {
    emitError("expected LLVM IR Dialect type, got ") << containerType;
    return {};
  }
  // Compare in 64 bits after rejecting negatives. Truncating the index to
  // `unsigned` would let indices >= 2^32 wrap into the valid range, and the
  // struct element lookup below would then read out of bounds.
  auto isInBounds = [](int64_t idx, uint64_t size) {
    return idx >= 0 && static_cast<uint64_t>(idx) < size;
  };
  for (int64_t idx : position) {
    if (auto arrayType = llvm::dyn_cast<LLVMArrayType>(llvmType)) {
      if (!isInBounds(idx, arrayType.getNumElements())) {
        emitError("position out of bounds: ") << idx;
        return {};
      }
      llvmType = arrayType.getElementType();
    } else if (auto structType = llvm::dyn_cast<LLVMStructType>(llvmType)) {
      ArrayRef<Type> body = structType.getBody();
      if (!isInBounds(idx, body.size())) {
        emitError("position out of bounds: ") << idx;
        return {};
      }
      llvmType = body[idx];
    } else {
      emitError("expected LLVM IR structure/array type, got: ") << llvmType;
      return {};
    }
  }
  return llvmType;
}
/// Non-diagnosing variant: walks `position` through nested structs/arrays
/// without bounds or kind checks (`llvm::cast` asserts on a non-array).
/// Presumably only called with positions that are already known to be valid
/// for `llvmType`; confirm at call sites.
static Type getInsertExtractValueElementType(Type llvmType,
                                             ArrayRef<int64_t> position) {
  Type current = llvmType;
  for (int64_t idx : position) {
    auto structType = llvm::dyn_cast<LLVMStructType>(current);
    current = structType
                  ? structType.getBody()[idx]
                  : llvm::cast<LLVMArrayType>(current).getElementType();
  }
  return current;
}
/// Folds `llvm.extractvalue` through the chain of `llvm.insertvalue` ops that
/// produce its container.
///
/// For each insert on the chain:
///  - same position as the extract: fold to the inserted value;
///  - one position is a strict prefix of the other (partial overlap): stop;
///  - otherwise the insert cannot affect the extracted element, so this op's
///    container operand is rewired in place to the insert's own container and
///    the walk continues.
/// Returns this op's own result when it was updated in place, a null result
/// when nothing changed.
OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
  auto insertValueOp = getContainer().getDefiningOp<InsertValueOp>();
  OpFoldResult result = {};
  while (insertValueOp) {
    if (getPosition() == insertValueOp.getPosition())
      return insertValueOp.getValue();
    unsigned min =
        std::min(getPosition().size(), insertValueOp.getPosition().size());
    // Positions overlap (one is a prefix of the other): cannot skip the insert.
    if (getPosition().take_front(min) ==
        insertValueOp.getPosition().take_front(min))
      return result;
    // Disjoint positions: bypass this insert (in-place update).
    getContainerMutable().assign(insertValueOp.getContainer());
    result = getResult();
    insertValueOp = insertValueOp.getContainer().getDefiningOp<InsertValueOp>();
  }
  return result;
}
/// Verifies that `position` is valid for the container type and that the
/// result type equals the addressed element type.
LogicalResult ExtractValueOp::verify() {
  Type containerType = getContainer().getType();
  Type expectedType = getInsertExtractValueElementType(
      [this](StringRef msg) { return emitOpError(msg); }, containerType,
      getPosition());
  if (!expectedType)
    return failure();

  Type actualType = getRes().getType();
  if (actualType == expectedType)
    return success();
  return emitOpError() << "Type mismatch: extracting from " << containerType
                       << " should produce " << expectedType
                       << " but this op returns " << actualType;
}
/// Builds an extract whose result type is inferred from the container type and
/// `position`.
void ExtractValueOp::build(OpBuilder &builder, OperationState &state,
                           Value container, ArrayRef<int64_t> position) {
  Type resultType =
      getInsertExtractValueElementType(container.getType(), position);
  auto positionAttr = builder.getAttr<DenseI64ArrayAttr>(position);
  build(builder, state, resultType, container, positionAttr);
}
/// Custom-directive parser: infers `valueType` from `containerType` and
/// `position` instead of reading it from the input. Diagnostics are reported
/// at the current parser location.
static ParseResult
parseInsertExtractValueElementType(AsmParser &parser, Type &valueType,
                                   Type containerType,
                                   DenseI64ArrayAttr position) {
  auto reportAtCurrentLoc = [&](StringRef msg) {
    return parser.emitError(parser.getCurrentLocation(), msg);
  };
  valueType = getInsertExtractValueElementType(
      reportAtCurrentLoc, containerType, position.asArrayRef());
  return valueType ? success() : failure();
}
/// Custom-directive printer counterpart of parseInsertExtractValueElementType.
/// Intentionally prints nothing: the parser recomputes the element type from
/// the container type and position, so it never appears in the textual form.
static void printInsertExtractValueElementType(AsmPrinter &printer,
                                               Operation *op, Type valueType,
                                               Type containerType,
                                               DenseI64ArrayAttr position) {}
/// Verifies that `position` is valid for the container type and that the
/// inserted value has the addressed element type.
LogicalResult InsertValueOp::verify() {
  Type containerType = getContainer().getType();
  Type slotType = getInsertExtractValueElementType(
      [this](StringRef msg) { return emitOpError(msg); }, containerType,
      getPosition());
  if (!slotType)
    return failure();

  Type insertedType = getValue().getType();
  if (insertedType == slotType)
    return success();
  return emitOpError() << "Type mismatch: cannot insert " << insertedType
                       << " into " << containerType;
}
/// Verifies `llvm.return` against the enclosing `llvm.func`'s return type.
/// Returns without an enclosing llvm.func are not checked here. Each
/// diagnostic carries a note pointing at the function.
LogicalResult ReturnOp::verify() {
  auto parent = (*this)->getParentOfType<LLVMFuncOp>();
  if (!parent)
    return success();

  auto failWithNote = [&](StringRef msg) -> LogicalResult {
    InFlightDiagnostic diag = emitOpError(msg);
    diag.attachNote(parent->getLoc()) << "when returning from function";
    return diag;
  };

  Type expectedType = parent.getFunctionType().getReturnType();
  // Void functions must return nothing.
  if (llvm::isa<LLVMVoidType>(expectedType))
    return getArg() ? failWithNote("expected no operands") : success();

  // Non-void functions must return exactly one value of the right type.
  if (!getArg())
    return failWithNote("expected 1 operand");
  if (getArg().getType() != expectedType)
    return failWithNote("mismatching result types");
  return success();
}
/// Returns the closest proper ancestor of `op` that satisfies
/// `satisfiesLLVMModule`. Asserts if there is none (returns null in release).
static Operation *parentLLVMModule(Operation *op) {
  Operation *module = nullptr;
  for (Operation *ancestor = op->getParentOp(); ancestor && !module;
       ancestor = ancestor->getParentOp()) {
    if (satisfiesLLVMModule(ancestor))
      module = ancestor;
  }
  assert(module && "unexpected operation outside of a module");
  return module;
}
/// Returns the referenced `llvm.mlir.global`, or null if the symbol is missing
/// or is not a global.
GlobalOp AddressOfOp::getGlobal(SymbolTableCollection &symbolTable) {
  Operation *module = parentLLVMModule(*this);
  Operation *symbol = symbolTable.lookupSymbolIn(module, getGlobalNameAttr());
  return dyn_cast_or_null<GlobalOp>(symbol);
}
/// Returns the referenced `llvm.func`, or null if the symbol is missing or is
/// not a function.
LLVMFuncOp AddressOfOp::getFunction(SymbolTableCollection &symbolTable) {
  Operation *module = parentLLVMModule(*this);
  Operation *symbol = symbolTable.lookupSymbolIn(module, getGlobalNameAttr());
  return dyn_cast_or_null<LLVMFuncOp>(symbol);
}
/// Verifies that the symbol names a global or function in the enclosing LLVM
/// module, and that the pointer address space matches a referenced global's.
LogicalResult
AddressOfOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  Operation *symbol =
      symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr());
  if (!isa_and_nonnull<GlobalOp, LLVMFuncOp>(symbol))
    return emitOpError(
        "must reference a global defined by 'llvm.mlir.global' or 'llvm.func'");

  // Only globals carry an address space that must agree with the result type.
  if (auto global = dyn_cast<GlobalOp>(symbol);
      global && global.getAddrSpace() != getType().getAddressSpace())
    return emitOpError("pointer address space must match address space of the "
                       "referenced global");
  return success();
}
/// Folds to the referenced symbol-name attribute.
OpFoldResult LLVM::AddressOfOp::fold(FoldAdaptor) {
  Attribute symbolName = getGlobalNameAttr();
  return symbolName;
}
/// Builds a comdat op named `symName` whose body region holds one empty block.
void ComdatOp::build(OpBuilder &builder, OperationState &result,
                     StringRef symName) {
  StringAttr nameAttr = builder.getStringAttr(symName);
  result.addAttribute(getSymNameAttrName(result.name), nameAttr);
  result.addRegion()->emplaceBlock();
}
/// Verifies that the comdat body contains only comdat selector ops.
LogicalResult ComdatOp::verifyRegions() {
  for (Operation &op : getBody().getOps()) {
    if (isa<ComdatSelectorOp>(op))
      continue;
    return op.emitError(
        "only comdat selector symbols can appear in a comdat region");
  }
  return success();
}
/// Builds an `llvm.mlir.global`.
///
/// Name, type and linkage are always set. `constant`, `dso_local` and
/// thread_local are added only when true. `value`, `comdat` and `dbgExpr` are
/// added only when non-null, alignment and address space only when non-zero.
/// `attrs` are appended verbatim after the attributes above. An empty
/// initializer region is always created.
void GlobalOp::build(OpBuilder &builder, OperationState &result, Type type,
                     bool isConstant, Linkage linkage, StringRef name,
                     Attribute value, uint64_t alignment, unsigned addrSpace,
                     bool dsoLocal, bool threadLocal, SymbolRefAttr comdat,
                     ArrayRef<NamedAttribute> attrs,
                     DIGlobalVariableExpressionAttr dbgExpr) {
  result.addAttribute(getSymNameAttrName(result.name),
                      builder.getStringAttr(name));
  result.addAttribute(getGlobalTypeAttrName(result.name), TypeAttr::get(type));
  if (isConstant)
    result.addAttribute(getConstantAttrName(result.name),
                        builder.getUnitAttr());
  if (value)
    result.addAttribute(getValueAttrName(result.name), value);
  if (dsoLocal)
    result.addAttribute(getDsoLocalAttrName(result.name),
                        builder.getUnitAttr());
  if (threadLocal)
    result.addAttribute(getThreadLocal_AttrName(result.name),
                        builder.getUnitAttr());
  if (comdat)
    result.addAttribute(getComdatAttrName(result.name), comdat);
  // Zero means "no explicit alignment"; no attribute is recorded.
  if (alignment != 0)
    result.addAttribute(getAlignmentAttrName(result.name),
                        builder.getI64IntegerAttr(alignment));
  result.addAttribute(getLinkageAttrName(result.name),
                      LinkageAttr::get(builder.getContext(), linkage));
  // Address space 0 is left implicit.
  if (addrSpace != 0)
    result.addAttribute(getAddrSpaceAttrName(result.name),
                        builder.getI32IntegerAttr(addrSpace));
  result.attributes.append(attrs.begin(), attrs.end());
  if (dbgExpr)
    result.addAttribute(getDbgExprAttrName(result.name), dbgExpr);
  // Initializer region; left empty here.
  result.addRegion();
}
/// Prints the custom form of `llvm.mlir.global`:
///   linkage [visibility] [thread_local] [unnamed_addr] [constant]
///   @name([value]) [comdat(@sym)] attrs [: type [region]]
/// The type and region are omitted for string-valued globals.
void GlobalOp::print(OpAsmPrinter &p) {
  p << ' ' << stringifyLinkage(getLinkage()) << ' ';
  if (StringRef visibility = stringifyVisibility(getVisibility_());
      !visibility.empty())
    p << visibility << ' ';
  if (getThreadLocal_())
    p << "thread_local ";
  if (auto unnamedAddr = getUnnamedAddr()) {
    if (StringRef str = stringifyUnnamedAddr(*unnamedAddr); !str.empty())
      p << str << ' ';
  }
  if (getConstant())
    p << "constant ";

  p.printSymbolName(getSymName());
  Attribute value = getValueOrNull();
  p << '(';
  if (value)
    p.printAttribute(value);
  p << ')';
  if (auto comdat = getComdat())
    p << " comdat(" << *comdat << ')';

  // Attributes already encoded in the custom syntax above are elided.
  p.printOptionalAttrDict((*this)->getAttrs(),
                          {SymbolTable::getSymbolAttrName(),
                           getGlobalTypeAttrName(), getConstantAttrName(),
                           getValueAttrName(), getLinkageAttrName(),
                           getUnnamedAddrAttrName(), getThreadLocal_AttrName(),
                           getVisibility_AttrName(), getComdatAttrName()});

  // String globals: the parser infers the type from the string.
  if (llvm::isa_and_nonnull<StringAttr>(value))
    return;

  p << " : " << getType();
  Region &initializer = getInitializerRegion();
  if (initializer.empty())
    return;
  p << ' ';
  p.printRegion(initializer, /*printEntryBlockArgs=*/false);
}
/// Verifies that an optional comdat reference resolves (from `op`) to a
/// comdat selector op. An absent reference is accepted.
static LogicalResult verifyComdat(Operation *op,
                                  std::optional<SymbolRefAttr> attr) {
  if (!attr.has_value())
    return success();
  Operation *selector = SymbolTable::lookupNearestSymbolFrom(op, *attr);
  if (isa_and_nonnull<ComdatSelectorOp>(selector))
    return success();
  return op->emitError() << "expected comdat symbol";
}
/// Parses the custom form of `llvm.mlir.global`:
///   [linkage] [visibility] [unnamed_addr] [thread_local] [constant]
///   @name([value]) [comdat(@sym)] attrs [: type [initializer-region]]
/// The type may be omitted only when the value is a string attribute. In that
/// case it is inferred as an i8 array of the string's length. The initializer
/// region is only parsed when an explicit type is given.
ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) {
  MLIRContext *ctx = parser.getContext();
  // Optional keywords with their defaults. Visibility and unnamed_addr are
  // stored as i64 integer attributes.
  result.addAttribute(getLinkageAttrName(result.name),
                      LLVM::LinkageAttr::get(
                          ctx, parseOptionalLLVMKeyword<Linkage>(
                                   parser, result, LLVM::Linkage::External)));
  result.addAttribute(getVisibility_AttrName(result.name),
                      parser.getBuilder().getI64IntegerAttr(
                          parseOptionalLLVMKeyword<LLVM::Visibility, int64_t>(
                              parser, result, LLVM::Visibility::Default)));
  result.addAttribute(getUnnamedAddrAttrName(result.name),
                      parser.getBuilder().getI64IntegerAttr(
                          parseOptionalLLVMKeyword<UnnamedAddr, int64_t>(
                              parser, result, LLVM::UnnamedAddr::None)));
  if (succeeded(parser.parseOptionalKeyword("thread_local")))
    result.addAttribute(getThreadLocal_AttrName(result.name),
                        parser.getBuilder().getUnitAttr());
  if (succeeded(parser.parseOptionalKeyword("constant")))
    result.addAttribute(getConstantAttrName(result.name),
                        parser.getBuilder().getUnitAttr());
  StringAttr name;
  if (parser.parseSymbolName(name, getSymNameAttrName(result.name),
                             result.attributes) ||
      parser.parseLParen())
    return failure();
  // `()` means no value; otherwise parse the value attribute before `)`.
  Attribute value;
  if (parser.parseOptionalRParen()) {
    if (parser.parseAttribute(value, getValueAttrName(result.name),
                              result.attributes) ||
        parser.parseRParen())
      return failure();
  }
  if (succeeded(parser.parseOptionalKeyword("comdat"))) {
    SymbolRefAttr comdat;
    if (parser.parseLParen() || parser.parseAttribute(comdat) ||
        parser.parseRParen())
      return failure();
    result.addAttribute(getComdatAttrName(result.name), comdat);
  }
  SmallVector<Type, 1> types;
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseOptionalColonTypeList(types))
    return failure();
  if (types.size() > 1)
    return parser.emitError(parser.getNameLoc(), "expected zero or one type");
  Region &initRegion = *result.addRegion();
  if (types.empty()) {
    // Type omitted: only allowed for string values (inferred i8 array).
    if (auto strAttr = llvm::dyn_cast_or_null<StringAttr>(value)) {
      MLIRContext *context = parser.getContext();
      auto arrayType = LLVM::LLVMArrayType::get(IntegerType::get(context, 8),
                                                strAttr.getValue().size());
      types.push_back(arrayType);
    } else {
      return parser.emitError(parser.getNameLoc(),
                              "type can only be omitted for string globals");
    }
  } else {
    // Explicit type: an initializer region may follow.
    OptionalParseResult parseResult =
        parser.parseOptionalRegion(initRegion, {},
                                   {});
    if (parseResult.has_value() && failed(*parseResult))
      return failure();
  }
  result.addAttribute(getGlobalTypeAttrName(result.name),
                      TypeAttr::get(types[0]));
  return success();
}
/// Returns true if `value` is an all-zero constant. Accepted forms are an
/// integer zero, a *positive* floating-point zero, or an elements/array
/// attribute whose every element is such a zero.
///
/// Negative zero is rejected on purpose. LLVM's `Constant::isNullValue` only
/// treats +0.0 as a null FP constant, so a `-0.0` value is not a valid zero
/// initializer (e.g. for `common` linkage) and would fail LLVM verification
/// after translation.
static bool isZeroAttribute(Attribute value) {
  if (auto intValue = llvm::dyn_cast<IntegerAttr>(value))
    return intValue.getValue().isZero();
  if (auto fpValue = llvm::dyn_cast<FloatAttr>(value))
    return fpValue.getValue().isPosZero();
  // Splats need only their single value checked.
  if (auto splatValue = llvm::dyn_cast<SplatElementsAttr>(value))
    return isZeroAttribute(splatValue.getSplatValue<Attribute>());
  if (auto elementsValue = llvm::dyn_cast<ElementsAttr>(value))
    return llvm::all_of(elementsValue.getValues<Attribute>(), isZeroAttribute);
  if (auto arrayValue = llvm::dyn_cast<ArrayAttr>(value))
    return llvm::all_of(arrayValue.getValue(), isZeroAttribute);
  return false;
}
/// Verifies `llvm.mlir.global` constraints that do not involve its region.
/// Checks, in order: the global type, placement at module level, string
/// initializers, target-extension types, `common` and `appending` linkage
/// requirements, the comdat reference, and alignment.
LogicalResult GlobalOp::verify() {
  // LLVM-compatible types are valid unless void/token/metadata/label; other
  // types must implement PointerElementTypeInterface.
  bool validType = isCompatibleOuterType(getType())
                       ? !llvm::isa<LLVMVoidType, LLVMTokenType,
                                    LLVMMetadataType, LLVMLabelType>(getType())
                       : llvm::isa<PointerElementTypeInterface>(getType());
  if (!validType)
    return emitOpError(
        "expects type to be a valid element type for an LLVM global");
  if ((*this)->getParentOp() && !satisfiesLLVMModule((*this)->getParentOp()))
    return emitOpError("must appear at the module level");
  // A string value requires an i8 array whose length equals the string's.
  if (auto strAttr = llvm::dyn_cast_or_null<StringAttr>(getValueOrNull())) {
    auto type = llvm::dyn_cast<LLVMArrayType>(getType());
    IntegerType elementType =
        type ? llvm::dyn_cast<IntegerType>(type.getElementType()) : nullptr;
    if (!elementType || elementType.getWidth() != 8 ||
        type.getNumElements() != strAttr.getValue().size())
      return emitOpError(
          "requires an i8 array type of the length equal to that of the string "
          "attribute");
  }
  if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) {
    if (!targetExtType.hasProperty(LLVMTargetExtType::CanBeGlobal))
      return emitOpError()
             << "this target extension type cannot be used in a global";
    // NOTE(review): any value attribute is rejected here, including a zero
    // one, despite the diagnostic's wording.
    if (Attribute value = getValueOrNull())
      return emitOpError() << "global with target extension type can only be "
                              "initialized with zero-initializer";
  }
  // `common` linkage: a value, if present, must be all zeros.
  if (getLinkage() == Linkage::Common) {
    if (Attribute value = getValueOrNull()) {
      if (!isZeroAttribute(value)) {
        return emitOpError()
               << "expected zero value for '"
               << stringifyLinkage(Linkage::Common) << "' linkage";
      }
    }
  }
  // `appending` linkage: the global must have an array type.
  if (getLinkage() == Linkage::Appending) {
    if (!llvm::isa<LLVMArrayType>(getType())) {
      return emitOpError() << "expected array type for '"
                           << stringifyLinkage(Linkage::Appending)
                           << "' linkage";
    }
  }
  if (failed(verifyComdat(*this, getComdat())))
    return failure();
  // Explicit alignment must be a power of two.
  // NOTE(review): uses emitError rather than emitOpError like the checks above.
  std::optional<uint64_t> alignAttr = getAlignment();
  if (alignAttr.has_value()) {
    uint64_t value = alignAttr.value();
    if (!llvm::isPowerOf2_64(value))
      return emitError() << "alignment attribute is not a power of 2";
  }
  return success();
}
/// Verifies the optional initializer region. It must return one value of the
/// global's type and contain only side-effect-free ops (ops without
/// MemoryEffectOpInterface count as having effects). It must also not coexist
/// with a value attribute.
LogicalResult GlobalOp::verifyRegions() {
  Block *initBlock = getInitializerBlock();
  if (!initBlock)
    return success();

  ReturnOp ret = cast<ReturnOp>(initBlock->getTerminator());
  if (ret->getNumOperands() == 0)
    return emitOpError("initializer region cannot return void");
  Type returnedType = ret->getOperand(0).getType();
  if (returnedType != getType())
    return emitOpError("initializer region type ")
           << returnedType << " does not match global type " << getType();

  auto hasSideEffects = [](Operation &op) {
    auto iface = dyn_cast<MemoryEffectOpInterface>(op);
    return !iface || !iface.hasNoEffect();
  };
  for (Operation &op : *initBlock) {
    if (hasSideEffects(op))
      return op.emitError()
             << "ops with side effects not allowed in global initializers";
  }

  if (getValueOrNull())
    return emitOpError("cannot have both initializer value and region");
  return success();
}
/// Verifies that every ctor entry is a flat symbol reference usable from this
/// op. Stops at the first failure.
LogicalResult
GlobalCtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  bool anyInvalid = llvm::any_of(getCtors(), [&](Attribute ctor) {
    return failed(verifySymbolAttrUse(llvm::cast<FlatSymbolRefAttr>(ctor),
                                      *this, symbolTable));
  });
  return failure(anyInvalid);
}
/// Verifies that ctors and priorities have the same length (they are
/// parallel arrays).
LogicalResult GlobalCtorsOp::verify() {
  if (getCtors().size() == getPriorities().size())
    return success();
  return emitError(
      "mismatch between the number of ctors and the number of priorities");
}
/// Verifies that every dtor entry is a flat symbol reference usable from this
/// op. Stops at the first failure.
LogicalResult
GlobalDtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  bool anyInvalid = llvm::any_of(getDtors(), [&](Attribute dtor) {
    return failed(verifySymbolAttrUse(llvm::cast<FlatSymbolRefAttr>(dtor),
                                      *this, symbolTable));
  });
  return failure(anyInvalid);
}
/// Verifies that dtors and priorities have the same length (they are
/// parallel arrays).
LogicalResult GlobalDtorsOp::verify() {
  if (getDtors().size() == getPriorities().size())
    return success();
  return emitError(
      "mismatch between the number of dtors and the number of priorities");
}
/// Builds a shufflevector whose result has one lane per mask entry, with the
/// element type and scalability of `v1`. Extra `attrs` are appended.
void ShuffleVectorOp::build(OpBuilder &builder, OperationState &state, Value v1,
                            Value v2, DenseI32ArrayAttr mask,
                            ArrayRef<NamedAttribute> attrs) {
  Type v1Type = v1.getType();
  Type elementType = LLVM::getVectorElementType(v1Type);
  bool scalable = LLVM::isScalableVectorType(v1Type);
  auto resultType = LLVM::getVectorType(elementType, mask.size(), scalable);
  build(builder, state, resultType, v1, v2, mask);
  state.addAttributes(attrs);
}
/// Convenience overload taking the mask as a plain integer list.
void ShuffleVectorOp::build(OpBuilder &builder, OperationState &state, Value v1,
                            Value v2, ArrayRef<int32_t> mask) {
  auto maskAttr = builder.getDenseI32ArrayAttr(mask);
  build(builder, state, v1, v2, maskAttr);
}
/// Custom-directive parser: infers the shuffle result type from `v1Type` and
/// the mask length instead of reading it from the input.
static ParseResult parseShuffleType(AsmParser &parser, Type v1Type,
                                    Type &resType, DenseI32ArrayAttr mask) {
  if (LLVM::isCompatibleVectorType(v1Type)) {
    resType =
        LLVM::getVectorType(LLVM::getVectorElementType(v1Type), mask.size(),
                            LLVM::isScalableVectorType(v1Type));
    return success();
  }
  return parser.emitError(parser.getCurrentLocation(),
                          "expected an LLVM compatible vector type");
}
/// Custom-directive printer counterpart of parseShuffleType. Intentionally
/// prints nothing: the parser re-derives the result type from the first
/// operand's type and the mask length.
static void printShuffleType(AsmPrinter &printer, Operation *op, Type v1Type,
                             Type resType, DenseI32ArrayAttr mask) {}
/// Verifies that shuffles of scalable vectors use an all-zero (splat) mask.
LogicalResult ShuffleVectorOp::verify() {
  if (!LLVM::isScalableVectorType(getV1().getType()))
    return success();
  bool isSplatMask =
      llvm::all_of(getMask(), [](int32_t lane) { return lane == 0; });
  if (isSplatMask)
    return success();
  return emitOpError("expected a splat operation for scalable vectors");
}
/// Creates the entry block of an empty function, with one block argument per
/// parameter (all located at the function). The builder's insertion point is
/// restored on return.
Block *LLVMFuncOp::addEntryBlock(OpBuilder &builder) {
  assert(empty() && "function already has an entry block");
  OpBuilder::InsertionGuard guard(builder);
  Block *entry = builder.createBlock(&getBody());
  Location loc = getLoc();
  for (Type paramType : getFunctionType().getParams())
    entry->addArgument(paramType, loc);
  return entry;
}
/// Builds an `llvm.func` with an empty body region.
///
/// Name, function type, linkage and calling convention are always set, and
/// `attrs` are appended verbatim. `dso_local`, `comdat` and the function entry
/// count are added only when set. If `argAttrs` is non-empty it must have one
/// dictionary per parameter of `type` (which is cast to LLVMFunctionType).
void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
                       StringRef name, Type type, LLVM::Linkage linkage,
                       bool dsoLocal, CConv cconv, SymbolRefAttr comdat,
                       ArrayRef<NamedAttribute> attrs,
                       ArrayRef<DictionaryAttr> argAttrs,
                       std::optional<uint64_t> functionEntryCount) {
  result.addRegion();
  result.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(name));
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));
  result.addAttribute(getLinkageAttrName(result.name),
                      LinkageAttr::get(builder.getContext(), linkage));
  result.addAttribute(getCConvAttrName(result.name),
                      CConvAttr::get(builder.getContext(), cconv));
  result.attributes.append(attrs.begin(), attrs.end());
  if (dsoLocal)
    result.addAttribute(getDsoLocalAttrName(result.name),
                        builder.getUnitAttr());
  if (comdat)
    result.addAttribute(getComdatAttrName(result.name), comdat);
  if (functionEntryCount)
    result.addAttribute(getFunctionEntryCountAttrName(result.name),
                        builder.getI64IntegerAttr(functionEntryCount.value()));
  // No per-argument attributes requested.
  if (argAttrs.empty())
    return;
  assert(llvm::cast<LLVMFunctionType>(type).getNumParams() == argAttrs.size() &&
         "expected as many argument attribute lists as arguments");
  // Only argument attributes are provided; result attributes are left unset.
  function_interface_impl::addArgAndResultAttrs(
      builder, result, argAttrs, std::nullopt,
      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
}
/// Builds the LLVM function type for a parsed `llvm.func` signature.
///
/// Emits a diagnostic at `loc` and returns a null type in three cases: there
/// is more than one result, an argument type is not LLVM-compatible, or the
/// result type is not LLVM-compatible. An empty result list maps to
/// `LLVMVoidType`.
static Type
buildLLVMFunctionType(OpAsmParser &parser, SMLoc loc, ArrayRef<Type> inputs,
                      ArrayRef<Type> outputs,
                      function_interface_impl::VariadicFlag variadicFlag) {
  Builder &b = parser.getBuilder();
  if (outputs.size() > 1) {
    parser.emitError(loc, "failed to construct function type: expected zero or "
                          "one function result");
    return {};
  }

  // The parsed argument types are used as-is, so they only need validating.
  if (!llvm::all_of(inputs, [](Type t) { return isCompatibleType(t); })) {
    parser.emitError(loc, "failed to construct function type: expected LLVM "
                          "type for function arguments");
    return {};
  }

  Type llvmOutput =
      outputs.empty() ? LLVMVoidType::get(b.getContext()) : outputs.front();
  if (!isCompatibleType(llvmOutput)) {
    // Separate the offending type from the message text so the diagnostic
    // does not read "...resultsi32".
    parser.emitError(loc, "failed to construct function type: expected LLVM "
                          "type for function results: ")
        << llvmOutput;
    return {};
  }
  return LLVMFunctionType::get(llvmOutput, inputs, variadicFlag.isVariadic());
}
/// Parses an `llvm.func` operation.
///
/// Elements are parsed in this order:
///  - optional linkage, visibility, unnamed_addr and calling-convention
///    keywords, each falling back to a default when absent;
///  - the symbol name and the function signature (variadic signatures are
///    allowed);
///  - an optional `vscale_range(min, max)` clause;
///  - an optional `comdat(@symbol)` clause;
///  - an optional `attributes {...}` dictionary;
///  - an optional body region whose entry arguments are the signature
///    arguments.
ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) {
// Linkage defaults to External when no keyword is present.
result.addAttribute(
getLinkageAttrName(result.name),
LinkageAttr::get(parser.getContext(),
parseOptionalLLVMKeyword<Linkage>(
parser, result, LLVM::Linkage::External)));
// Visibility and unnamed_addr are stored as i64 integer attributes.
result.addAttribute(getVisibility_AttrName(result.name),
parser.getBuilder().getI64IntegerAttr(
parseOptionalLLVMKeyword<LLVM::Visibility, int64_t>(
parser, result, LLVM::Visibility::Default)));
result.addAttribute(getUnnamedAddrAttrName(result.name),
parser.getBuilder().getI64IntegerAttr(
parseOptionalLLVMKeyword<UnnamedAddr, int64_t>(
parser, result, LLVM::UnnamedAddr::None)));
// Calling convention defaults to C.
result.addAttribute(
getCConvAttrName(result.name),
CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>(
parser, result, LLVM::CConv::C)));
StringAttr nameAttr;
SmallVector<OpAsmParser::Argument> entryArgs;
SmallVector<DictionaryAttr> resultAttrs;
SmallVector<Type> resultTypes;
bool isVariadic;
// Diagnostics about the signature's types are reported at this location.
auto signatureLocation = parser.getCurrentLocation();
if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
result.attributes) ||
function_interface_impl::parseFunctionSignature(
parser, /*allowVariadic=*/true, entryArgs, isVariadic, resultTypes,
resultAttrs))
return failure();
SmallVector<Type> argTypes;
for (auto &arg : entryArgs)
argTypes.push_back(arg.type);
// Fold the parsed argument/result types into a single LLVMFunctionType.
auto type =
buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes,
function_interface_impl::VariadicFlag(isVariadic));
if (!type)
return failure();
result.addAttribute(getFunctionTypeAttrName(result.name),
TypeAttr::get(type));
// Optional `vscale_range(min, max)`; both bounds are stored as i32 attributes.
if (succeeded(parser.parseOptionalKeyword("vscale_range"))) {
int64_t minRange, maxRange;
if (parser.parseLParen() || parser.parseInteger(minRange) ||
parser.parseComma() || parser.parseInteger(maxRange) ||
parser.parseRParen())
return failure();
auto intTy = IntegerType::get(parser.getContext(), 32);
result.addAttribute(
getVscaleRangeAttrName(result.name),
LLVM::VScaleRangeAttr::get(parser.getContext(),
IntegerAttr::get(intTy, minRange),
IntegerAttr::get(intTy, maxRange)));
}
// Optional `comdat(@symbol)`.
if (succeeded(parser.parseOptionalKeyword("comdat"))) {
SymbolRefAttr comdat;
if (parser.parseLParen() || parser.parseAttribute(comdat) ||
parser.parseRParen())
return failure();
result.addAttribute(getComdatAttrName(result.name), comdat);
}
if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
return failure();
function_interface_impl::addArgAndResultAttrs(
parser.getBuilder(), result, entryArgs, resultAttrs,
getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
// The body is optional: a missing region is not an error, while a present
// region that fails to parse is.
auto *body = result.addRegion();
OptionalParseResult parseResult =
parser.parseOptionalRegion(*body, entryArgs);
return failure(parseResult.has_value() && failed(*parseResult));
}
/// Prints an `llvm.func` operation in its custom assembly form.
///
/// The output mirrors the layout read by LLVMFuncOp::parse:
///  - linkage, visibility, unnamed_addr and calling-convention keywords, each
///    only when not at its default;
///  - the symbol name and signature;
///  - optional `vscale_range` and `comdat` clauses;
///  - the remaining attributes;
///  - the body, if present.
void LLVMFuncOp::print(OpAsmPrinter &p) {
p << ' ';
if (getLinkage() != LLVM::Linkage::External)
p << stringifyLinkage(getLinkage()) << ' ';
// Keywords whose stringified form is empty are not printed.
StringRef visibility = stringifyVisibility(getVisibility_());
if (!visibility.empty())
p << visibility << ' ';
if (auto unnamedAddr = getUnnamedAddr()) {
StringRef str = stringifyUnnamedAddr(*unnamedAddr);
if (!str.empty())
p << str << ' ';
}
if (getCConv() != LLVM::CConv::C)
p << stringifyCConv(getCConv()) << ' ';
p.printSymbolName(getName());
LLVMFunctionType fnType = getFunctionType();
SmallVector<Type, 8> argTypes;
SmallVector<Type, 1> resTypes;
argTypes.reserve(fnType.getNumParams());
for (unsigned i = 0, e = fnType.getNumParams(); i < e; ++i)
argTypes.push_back(fnType.getParamType(i));
// A void return type is printed as an empty result list.
Type returnType = fnType.getReturnType();
if (!llvm::isa<LLVMVoidType>(returnType))
resTypes.push_back(returnType);
function_interface_impl::printFunctionSignature(p, *this, argTypes,
isVarArg(), resTypes);
if (std::optional<VScaleRangeAttr> vscale = getVscaleRange())
p << " vscale_range(" << vscale->getMinRange().getInt() << ", "
<< vscale->getMaxRange().getInt() << ')';
if (auto comdat = getComdat())
p << " comdat(" << *comdat << ')';
// Attributes already printed above, or encoded in the signature, are elided
// from the trailing attribute dictionary.
function_interface_impl::printFunctionAttributes(
p, *this,
{getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
getLinkageAttrName(), getCConvAttrName(), getVisibility_AttrName(),
getComdatAttrName(), getUnnamedAddrAttrName(),
getVscaleRangeAttrName()});
Region &body = getBody();
if (!body.empty()) {
p << ' ';
p.printRegion(body, /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/true);
}
}
/// Verifies an `llvm.func` operation.
///
/// Checks that apply to every function:
///  - `Common` linkage is rejected;
///  - the optional comdat is checked with verifyComdat.
///
/// Body-less (external) functions must additionally have `External` or
/// `ExternWeak` linkage, and need no further checks.
///
/// Functions with a body must also satisfy:
///  - no_inline and always_inline are mutually exclusive;
///  - optimize_none requires no_inline;
///  - every `llvm.landingpad` result and `llvm.resume` operand in the body
///    shares a single type.
LogicalResult LLVMFuncOp::verify() {
if (getLinkage() == LLVM::Linkage::Common)
return emitOpError() << "functions cannot have '"
<< stringifyLinkage(LLVM::Linkage::Common)
<< "' linkage";
if (failed(verifyComdat(*this, getComdat())))
return failure();
if (isExternal()) {
if (getLinkage() != LLVM::Linkage::External &&
getLinkage() != LLVM::Linkage::ExternWeak)
return emitOpError() << "external functions must have '"
<< stringifyLinkage(LLVM::Linkage::External)
<< "' or '"
<< stringifyLinkage(LLVM::Linkage::ExternWeak)
<< "' linkage";
return success();
}
if (isNoInline() && isAlwaysInline())
return emitError("no_inline and always_inline attributes are incompatible");
if (isOptimizeNone() && !isNoInline())
return emitOpError("with optimize_none must also be no_inline");
// The first landingpad/resume type encountered becomes the reference type.
// The walk is interrupted at the first mismatch, and the message to report
// is recorded in `diagnosticMessage`.
Type landingpadResultTy;
StringRef diagnosticMessage;
bool isLandingpadTypeConsistent =
!walk([&](Operation *op) {
const auto checkType = [&](Type type, StringRef errorMessage) {
if (!landingpadResultTy) {
landingpadResultTy = type;
return WalkResult::advance();
}
if (landingpadResultTy != type) {
diagnosticMessage = errorMessage;
return WalkResult::interrupt();
}
return WalkResult::advance();
};
return TypeSwitch<Operation *, WalkResult>(op)
.Case<LandingpadOp>([&](auto landingpad) {
constexpr StringLiteral errorMessage =
"'llvm.landingpad' should have a consistent result type "
"inside a function";
return checkType(landingpad.getType(), errorMessage);
})
.Case<ResumeOp>([&](auto resume) {
constexpr StringLiteral errorMessage =
"'llvm.resume' should have a consistent input type inside a "
"function";
return checkType(resume.getValue().getType(), errorMessage);
})
// Other operations do not take part in this check.
.Default([](auto) { return WalkResult::skip(); });
}).wasInterrupted();
if (!isLandingpadTypeConsistent) {
assert(!diagnosticMessage.empty() &&
"Expecting a non-empty diagnostic message");
return emitError(diagnosticMessage);
}
return success();
}
/// Verifies the body of a defined `llvm.func`: each of the first
/// `getNumParams()` entry block arguments must have an LLVM-compatible type.
/// External functions have no body and always pass.
LogicalResult LLVMFuncOp::verifyRegions() {
  if (isExternal())
    return success();

  unsigned numParams = getFunctionType().getNumParams();
  Block &entry = front();
  for (unsigned argIdx = 0; argIdx != numParams; ++argIdx) {
    if (isCompatibleType(entry.getArgument(argIdx).getType()))
      continue;
    return emitOpError("entry block argument #")
           << argIdx << " is not of LLVM type";
  }
  return success();
}
/// Returns the function body, or null for an external (body-less) function.
Region *LLVMFuncOp::getCallableRegion() {
  return isExternal() ? nullptr : &getBody();
}
/// `llvm.mlir.undef` unconditionally folds to an `LLVM::UndefAttr`.
OpFoldResult LLVM::UndefOp::fold(FoldAdaptor) {
  MLIRContext *ctx = getContext();
  return LLVM::UndefAttr::get(ctx);
}

/// `llvm.mlir.poison` unconditionally folds to an `LLVM::PoisonAttr`.
OpFoldResult LLVM::PoisonOp::fold(FoldAdaptor) {
  MLIRContext *ctx = getContext();
  return LLVM::PoisonAttr::get(ctx);
}
/// Verifies `llvm.mlir.zero`. A target extension result type must declare
/// the `HasZeroInit` property; every other result type is accepted.
LogicalResult LLVM::ZeroOp::verify() {
  auto extType = dyn_cast<LLVMTargetExtType>(getType());
  if (!extType || extType.hasProperty(LLVM::LLVMTargetExtType::HasZeroInit))
    return success();
  return emitOpError()
         << "target extension type does not support zero-initializer";
}
/// Folds `llvm.mlir.zero` to the builtin zero attribute for its type when one
/// exists, and to `LLVM::ZeroAttr` otherwise.
OpFoldResult LLVM::ZeroOp::fold(FoldAdaptor) {
  MLIRContext *ctx = getContext();
  if (auto builtinZero = Builder(ctx).getZeroAttr(getType()))
    return builtinZero;
  return LLVM::ZeroAttr::get(ctx);
}
/// Verifies `llvm.mlir.constant`.
///
///  - String attributes require an array type with exactly one i8 element per
///    character.
///  - Struct result types denote complex constants: the struct must have two
///    identical element types (integer, f16, f32 or f64), and the value must
///    be a two-element array of equally-typed attributes.
///  - Target extension result types are rejected.
///  - Otherwise, the value must be an integer, array, float or elements
///    attribute, with additional checks:
///      * integer attributes need an integer result type;
///      * float attributes must match the result width (for float or integer
///        results);
///      * splat elements need a vector or array result type.
LogicalResult LLVM::ConstantOp::verify() {
  if (StringAttr sAttr = llvm::dyn_cast<StringAttr>(getValue())) {
    auto arrayType = llvm::dyn_cast<LLVMArrayType>(getType());
    if (!arrayType || arrayType.getNumElements() != sAttr.getValue().size() ||
        !arrayType.getElementType().isInteger(8)) {
      return emitOpError() << "expected array type of "
                           << sAttr.getValue().size()
                           << " i8 elements for the string constant";
    }
    return success();
  }

  if (auto structType = llvm::dyn_cast<LLVMStructType>(getType())) {
    if (structType.getBody().size() != 2 ||
        structType.getBody()[0] != structType.getBody()[1]) {
      return emitError() << "expected struct type with two elements of the "
                            "same type, the type of a complex constant";
    }

    auto arrayAttr = llvm::dyn_cast<ArrayAttr>(getValue());
    if (!arrayAttr || arrayAttr.size() != 2) {
      return emitOpError() << "expected array attribute with two elements, "
                              "representing a complex constant";
    }
    auto re = llvm::dyn_cast<TypedAttr>(arrayAttr[0]);
    auto im = llvm::dyn_cast<TypedAttr>(arrayAttr[1]);
    if (!re || !im || re.getType() != im.getType()) {
      return emitOpError()
             << "expected array attribute with two elements of the same type";
    }

    Type elementType = structType.getBody()[0];
    if (!llvm::isa<IntegerType, Float16Type, Float32Type, Float64Type>(
            elementType)) {
      return emitError()
             << "expected struct element types to be floating point type or "
                "integer type";
    }
    return success();
  }

  if (isa<LLVMTargetExtType>(getType()))
    return emitOpError() << "does not support target extension type.";

  if (!llvm::isa<IntegerAttr, ArrayAttr, FloatAttr, ElementsAttr>(getValue()))
    return emitOpError()
           << "only supports integer, float, string or elements attributes";

  if (isa<IntegerAttr>(getValue()) && !llvm::isa<IntegerType>(getType()))
    return emitOpError() << "expected integer type";

  if (auto floatAttr = dyn_cast<FloatAttr>(getValue())) {
    const llvm::fltSemantics &sem = floatAttr.getValue().getSemantics();
    unsigned floatWidth = APFloat::getSizeInBits(sem);
    if (auto floatTy = dyn_cast<FloatType>(getType())) {
      if (floatTy.getWidth() != floatWidth)
        return emitOpError() << "expected float type of width " << floatWidth;
    }
    // A float value may also initialize an integer result of the same width.
    if (isa<IntegerType>(getType()) && !getType().isInteger(floatWidth))
      return emitOpError() << "expected integer type of width " << floatWidth;
  }

  if (isa<SplatElementsAttr>(getValue()) &&
      !isa<VectorType, LLVM::LLVMArrayType, LLVM::LLVMFixedVectorType,
           LLVM::LLVMScalableVectorType>(getType()))
    return emitOpError() << "expected vector or array type";

  return success();
}
/// Returns true if an `llvm.mlir.constant` can be built from `value` with
/// result type `type`. All of the following must hold:
///  - `value` is typed with exactly `type`;
///  - `type` is LLVM-compatible;
///  - `value` is an integer, float or elements attribute.
bool LLVM::ConstantOp::isBuildableWith(Attribute value, Type type) {
  auto typedAttr = dyn_cast<TypedAttr>(value);
  // A single compatibility check suffices; it was previously repeated.
  if (!typedAttr || typedAttr.getType() != type || !isCompatibleType(type))
    return false;
  return isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
}
/// Creates an `llvm.mlir.constant` for `value` at `loc` when isBuildableWith
/// accepts the value/type pair. Otherwise returns a null op.
ConstantOp LLVM::ConstantOp::materialize(OpBuilder &builder, Attribute value,
                                         Type type, Location loc) {
  if (!isBuildableWith(value, type))
    return nullptr;
  return builder.create<LLVM::ConstantOp>(loc, cast<TypedAttr>(value));
}

/// A constant always folds to its own value attribute.
OpFoldResult LLVM::ConstantOp::fold(FoldAdaptor) {
  Attribute value = getValue();
  return value;
}
/// Convenience builder for `llvm.atomicrmw`.
///
/// The result type is the type of `val`. An empty `syncscope` and a zero
/// `alignment` leave the corresponding attributes unset. No alias-analysis
/// metadata is attached.
void AtomicRMWOp::build(OpBuilder &builder, OperationState &state,
                        AtomicBinOp binOp, Value ptr, Value val,
                        AtomicOrdering ordering, StringRef syncscope,
                        unsigned alignment, bool isVolatile) {
  StringAttr syncscopeAttr =
      syncscope.empty() ? StringAttr() : builder.getStringAttr(syncscope);
  IntegerAttr alignmentAttr =
      alignment == 0 ? IntegerAttr() : builder.getI64IntegerAttr(alignment);
  build(builder, state, val.getType(), binOp, ptr, val, ordering,
        syncscopeAttr, alignmentAttr, isVolatile,
        /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
        /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
}
/// Verifies `llvm.atomicrmw`.
///
/// The value type depends on the operation:
///  - fadd/fsub/fmin/fmax need an LLVM floating-point type;
///  - xchg needs a type usable with atomics under the closest data layout;
///  - every other operation needs an i8, i16, i32 or i64.
///
/// The ordering must be at least 'monotonic'.
LogicalResult AtomicRMWOp::verify() {
  Type valType = getVal().getType();
  switch (getBinOp()) {
  case AtomicBinOp::fadd:
  case AtomicBinOp::fsub:
  case AtomicBinOp::fmin:
  case AtomicBinOp::fmax:
    if (!mlir::LLVM::isCompatibleFloatingPointType(valType))
      return emitOpError("expected LLVM IR floating point type");
    break;
  case AtomicBinOp::xchg: {
    DataLayout dataLayout = DataLayout::closest(*this);
    if (!isTypeCompatibleWithAtomicOp(valType, dataLayout))
      return emitOpError("unexpected LLVM IR type for 'xchg' bin_op");
    break;
  }
  default: {
    auto intType = llvm::dyn_cast<IntegerType>(valType);
    unsigned width = intType ? intType.getWidth() : 0;
    bool isSupportedWidth =
        width == 8 || width == 16 || width == 32 || width == 64;
    if (!isSupportedWidth)
      return emitOpError("expected LLVM IR integer type");
    break;
  }
  }

  if (getOrdering() < AtomicOrdering::monotonic)
    return emitOpError() << "expected at least '"
                         << stringifyAtomicOrdering(AtomicOrdering::monotonic)
                         << "' ordering";
  return success();
}
/// Returns the literal struct type `{valType, i1}`. It is used as the result
/// type when building AtomicCmpXchgOp.
static LLVMStructType getValAndBoolStructType(Type valType) {
  MLIRContext *ctx = valType.getContext();
  Type fields[] = {valType, IntegerType::get(ctx, 1)};
  return LLVMStructType::getLiteral(ctx, fields);
}
/// Convenience builder for `llvm.cmpxchg`.
///
/// The result is the `{type(val), i1}` struct. An empty `syncscope` and a zero
/// `alignment` leave the corresponding attributes unset. No alias-analysis
/// metadata is attached.
void AtomicCmpXchgOp::build(OpBuilder &builder, OperationState &state,
                            Value ptr, Value cmp, Value val,
                            AtomicOrdering successOrdering,
                            AtomicOrdering failureOrdering, StringRef syncscope,
                            unsigned alignment, bool isWeak, bool isVolatile) {
  LLVMStructType resultType = getValAndBoolStructType(val.getType());
  StringAttr syncscopeAttr =
      syncscope.empty() ? StringAttr() : builder.getStringAttr(syncscope);
  IntegerAttr alignmentAttr =
      alignment == 0 ? IntegerAttr() : builder.getI64IntegerAttr(alignment);
  build(builder, state, resultType, ptr, cmp, val, successOrdering,
        failureOrdering, syncscopeAttr, alignmentAttr, isWeak, isVolatile,
        /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
        /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
}
/// Verifies `llvm.cmpxchg`.
///  - The pointer operand must be an LLVM pointer.
///  - The value type must be usable with atomics under the closest data
///    layout.
///  - Both orderings must be at least 'monotonic'.
///  - The failure ordering may not be 'release' or 'acq_rel'.
LogicalResult AtomicCmpXchgOp::verify() {
  // `llvm::cast` never yields null (it asserts instead), which made the
  // previous null check dead. Test the type so the diagnostic can fire.
  if (!llvm::isa<LLVM::LLVMPointerType>(getPtr().getType()))
    return emitOpError("expected LLVM IR pointer type for operand #0");
  auto valType = getVal().getType();
  DataLayout dataLayout = DataLayout::closest(*this);
  if (!isTypeCompatibleWithAtomicOp(valType, dataLayout))
    return emitOpError("unexpected LLVM IR type");
  if (getSuccessOrdering() < AtomicOrdering::monotonic ||
      getFailureOrdering() < AtomicOrdering::monotonic)
    return emitOpError("ordering must be at least 'monotonic'");
  if (getFailureOrdering() == AtomicOrdering::release ||
      getFailureOrdering() == AtomicOrdering::acq_rel)
    return emitOpError("failure ordering cannot be 'release' or 'acq_rel'");
  return success();
}
/// Convenience builder for `llvm.fence`. An empty `syncscope` leaves the
/// syncscope attribute unset.
void FenceOp::build(OpBuilder &builder, OperationState &state,
                    AtomicOrdering ordering, StringRef syncscope) {
  StringAttr syncscopeAttr;
  if (!syncscope.empty())
    syncscopeAttr = builder.getStringAttr(syncscope);
  build(builder, state, ordering, syncscopeAttr);
}
/// Verifies `llvm.fence`. The orderings not_atomic, unordered and monotonic
/// are rejected; every other ordering is accepted.
LogicalResult FenceOp::verify() {
  switch (getOrdering()) {
  case AtomicOrdering::not_atomic:
  case AtomicOrdering::unordered:
  case AtomicOrdering::monotonic:
    return emitOpError("can be given only acquire, release, acq_rel, "
                       "and seq_cst orderings");
  default:
    return success();
  }
}
/// Shared verifier for the integer extension ops (used by ZExtOp and SExtOp).
///
/// The argument and result must either both be integers, or both be vectors
/// with the same number of elements. The result's integer width must be
/// strictly greater than the argument's.
template <class ExtOp>
static LogicalResult verifyExtOp(ExtOp op) {
  Type argTy = op.getArg().getType();
  Type resTy = op.getResult().getType();
  IntegerType inputType, outputType;
  if (!isCompatibleVectorType(argTy)) {
    inputType = cast<IntegerType>(argTy);
    outputType = dyn_cast<IntegerType>(resTy);
    if (!outputType)
      return op.emitError(
          "input type is an integer but output type is a vector");
  } else {
    if (!isCompatibleVectorType(resTy))
      return op.emitError(
          "input type is a vector but output type is an integer");
    if (getVectorNumElements(argTy) != getVectorNumElements(resTy))
      return op.emitError("input and output vectors are of incompatible shape");
    inputType = cast<IntegerType>(getVectorElementType(argTy));
    outputType = cast<IntegerType>(getVectorElementType(resTy));
  }

  if (outputType.getWidth() > inputType.getWidth())
    return success();
  return op.emitError("integer width of the output type is smaller or "
                      "equal to the integer width of the input type");
}
/// Verifies `llvm.zext` with the shared extension-op checks.
LogicalResult ZExtOp::verify() { return verifyExtOp(*this); }

/// Folds `llvm.zext` of a constant scalar integer by zero-extending it to the
/// result width. Non-constant and non-scalar operands are not folded.
OpFoldResult LLVM::ZExtOp::fold(FoldAdaptor adaptor) {
  if (auto cst = dyn_cast_or_null<IntegerAttr>(adaptor.getArg())) {
    unsigned resultWidth = cast<IntegerType>(getType()).getWidth();
    return IntegerAttr::get(getType(), cst.getValue().zext(resultWidth));
  }
  return {};
}
/// Verifies `llvm.sext` with the shared extension-op checks.
LogicalResult SExtOp::verify() {
  return verifyExtOp(*this);
}
/// Folds a cast op of kind `T` whose casts can be composed:
///  - a cast whose argument already has the result type folds to the
///    argument;
///  - cast(cast(x)), where x already has the outer result type, folds to x;
///  - otherwise cast(cast(x)) is rewritten in place to cast(x), and the op's
///    own result is returned to signal the in-place update.
template <typename T>
static OpFoldResult foldChainableCast(T castOp,
                                      typename T::FoldAdaptor adaptor) {
  Value input = castOp.getArg();
  Type resultType = castOp.getType();
  if (input.getType() == resultType)
    return input;

  auto producer = input.template getDefiningOp<T>();
  if (!producer)
    return {};

  Value origin = producer.getArg();
  if (origin.getType() == resultType)
    return origin;
  castOp.getArgMutable().set(origin);
  return Value{castOp};
}
/// Removes no-op bitcasts and collapses chains of bitcasts.
OpFoldResult LLVM::BitcastOp::fold(FoldAdaptor adaptor) {
  return foldChainableCast<BitcastOp>(*this, adaptor);
}
/// Verifies `llvm.bitcast` when pointers are involved.
///
/// A pointer (or vector of pointers) can only be cast to another pointer (or
/// vector of pointers). Scalar and vector forms cannot be mixed, and both
/// sides must use the same address space. Casts without pointers are not
/// checked here.
LogicalResult LLVM::BitcastOp::verify() {
  Type resTy = getResult().getType();
  Type argTy = getArg().getType();
  auto resultPtrTy =
      llvm::dyn_cast<LLVMPointerType>(extractVectorElementType(resTy));
  auto sourcePtrTy =
      llvm::dyn_cast<LLVMPointerType>(extractVectorElementType(argTy));

  if (!resultPtrTy && !sourcePtrTy)
    return success();
  if (!resultPtrTy || !sourcePtrTy)
    return emitOpError("can only cast pointers from and to pointers");

  auto isVector =
      llvm::IsaPred<VectorType, LLVMScalableVectorType, LLVMFixedVectorType>;
  bool resultIsVector = isVector(resTy);
  bool sourceIsVector = isVector(argTy);
  if (resultIsVector && !sourceIsVector)
    return emitOpError("cannot cast pointer to vector of pointers");
  if (!resultIsVector && sourceIsVector)
    return emitOpError("cannot cast vector of pointers to pointer");

  if (resultPtrTy.getAddressSpace() == sourcePtrTy.getAddressSpace())
    return success();
  return emitOpError("cannot cast pointers of different address spaces, "
                     "use 'llvm.addrspacecast' instead");
}
/// Removes no-op address space casts and collapses chains of them.
OpFoldResult LLVM::AddrSpaceCastOp::fold(FoldAdaptor adaptor) {
  return foldChainableCast<AddrSpaceCastOp>(*this, adaptor);
}
/// Folds `llvm.getelementptr`.
///
/// A GEP whose result type equals its base type and that has a single index
/// known to be zero folds to its base pointer. Otherwise, dynamic indices
/// whose constant value fits in `kGEPConstantBitWidth` signed bits are moved
/// into the raw constant index list in place. In that case the op itself is
/// returned to signal the in-place update.
OpFoldResult LLVM::GEPOp::fold(FoldAdaptor adaptor) {
  GEPIndicesAdaptor<ArrayRef<Attribute>> indices(getRawConstantIndicesAttr(),
                                                 adaptor.getDynamicIndices());

  if (getBase().getType() == getType() && indices.size() == 1)
    if (auto integer = llvm::dyn_cast_or_null<IntegerAttr>(indices[0]))
      if (integer.getValue().isZero())
        return getBase();

  bool changed = false;
  SmallVector<GEPArg> gepArgs;
  for (auto iter : llvm::enumerate(indices)) {
    auto integer = llvm::dyn_cast_or_null<IntegerAttr>(iter.value());
    // Keep the existing index when it is already a constant index, when its
    // value is unknown, or when it does not fit the constant index width.
    if (!indices.isDynamicIndex(iter.index()) || !integer ||
        !integer.getValue().isSignedIntN(kGEPConstantBitWidth)) {
      PointerUnion<IntegerAttr, Value> existing = getIndices()[iter.index()];
      if (Value val = llvm::dyn_cast_if_present<Value>(existing))
        gepArgs.emplace_back(val);
      else
        // `PointerUnion::get<T>()` is deprecated in favor of `llvm::cast<T>`.
        gepArgs.emplace_back(llvm::cast<IntegerAttr>(existing).getInt());
      continue;
    }
    changed = true;
    gepArgs.emplace_back(integer.getInt());
  }

  if (changed) {
    SmallVector<int32_t> rawConstantIndices;
    SmallVector<Value> dynamicIndices;
    destructureIndices(getElemType(), gepArgs, rawConstantIndices,
                       dynamicIndices);
    getDynamicIndicesMutable().assign(dynamicIndices);
    setRawConstantIndices(rawConstantIndices);
    return Value{*this};
  }
  return {};
}
/// Folds `llvm.shl` when both operands are constant integers.
///
/// Returns an empty result when either operand is not a constant integer, or
/// when the shift amount is greater than or equal to the bit width of the
/// left-hand side. LLVM IR defines such shifts to produce poison.
OpFoldResult LLVM::ShlOp::fold(FoldAdaptor adaptor) {
  auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs());
  if (!rhs)
    return {};
  // Compare in APInt space: getZExtValue() asserts when the shift amount
  // needs more than 64 bits, which is possible for integers wider than i64.
  if (rhs.getValue().uge(getLhs().getType().getIntOrFloatBitWidth()))
    return {};
  auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs());
  if (!lhs)
    return {};
  return IntegerAttr::get(getType(), lhs.getValue().shl(rhs.getValue()));
}
/// Folds `llvm.or` of two constant integers to their bitwise OR.
OpFoldResult LLVM::OrOp::fold(FoldAdaptor adaptor) {
  auto lhsCst = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs());
  auto rhsCst = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs());
  if (!lhsCst || !rhsCst)
    return {};
  return IntegerAttr::get(getType(), lhsCst.getValue() | rhsCst.getValue());
}
/// Verifies `llvm.call_intrinsic`: the intrinsic name must carry the `llvm.`
/// prefix.
LogicalResult CallIntrinsicOp::verify() {
  StringRef intrinsicName = getIntrin();
  if (intrinsicName.starts_with("llvm."))
    return success();
  return emitOpError() << "intrinsic name must start with 'llvm.'";
}
namespace {
/// Assembly-printing dialect interface for the LLVM dialect.
///
/// It supplies aliases for the listed attributes: access groups, alias
/// scopes, debug info, loop annotations and TBAA.
struct LLVMOpAsmDialectInterface : public OpAsmDialectInterface {
using OpAsmDialectInterface::OpAsmDialectInterface;
/// For each listed attribute kind, writes the attribute's mnemonic to `os` as
/// the alias name and returns OverridableAlias. Every other attribute gets
/// NoAlias.
AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
return TypeSwitch<Attribute, AliasResult>(attr)
.Case<AccessGroupAttr, AliasScopeAttr, AliasScopeDomainAttr,
DIBasicTypeAttr, DICompileUnitAttr, DICompositeTypeAttr,
DIDerivedTypeAttr, DIFileAttr, DIGlobalVariableAttr,
DIGlobalVariableExpressionAttr, DILabelAttr, DILexicalBlockAttr,
DILexicalBlockFileAttr, DILocalVariableAttr, DIModuleAttr,
DINamespaceAttr, DINullTypeAttr, DIStringTypeAttr,
DISubprogramAttr, DISubroutineTypeAttr, LoopAnnotationAttr,
LoopVectorizeAttr, LoopInterleaveAttr, LoopUnrollAttr,
LoopUnrollAndJamAttr, LoopLICMAttr, LoopDistributeAttr,
LoopPipelineAttr, LoopPeeledAttr, LoopUnswitchAttr, TBAARootAttr,
TBAATagAttr, TBAATypeDescriptorAttr>([&](auto attr) {
os << decltype(attr)::getMnemonic();
return AliasResult::OverridableAlias;
})
.Default([](Attribute) { return AliasResult::NoAlias; });
}
};
} // namespace
/// Verifies `llvm.linker_options`. When the op has a parent, that parent must
/// satisfy satisfiesLLVMModule.
LogicalResult LinkerOptionsOp::verify() {
  Operation *parent = (*this)->getParentOp();
  if (!parent || satisfiesLLVMModule(parent))
    return success();
  return emitOpError("must appear at the module level");
}
/// Reports the memory effects of `llvm.inline_asm`. An asm block flagged as
/// having side effects is conservatively modeled as writing and reading
/// memory. Otherwise no effects are reported.
void InlineAsmOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (!getHasSideEffects())
    return;
  effects.emplace_back(MemoryEffects::Write::get());
  effects.emplace_back(MemoryEffects::Read::get());
}
/// Verifies `llvm.intr.masked.gather`. The pointer-vector operand must have
/// as many elements as the result vector.
LogicalResult LLVM::masked_gather::verify() {
  Type ptrsType = getPtrs().getType();
  auto resultCount = LLVM::getVectorNumElements(getRes().getType());
  Type expectedType =
      LLVM::getVectorType(extractVectorElementType(ptrsType), resultCount);
  if (ptrsType == expectedType)
    return success();
  return emitOpError("expected operand #1 type to be ") << expectedType;
}
/// Verifies `llvm.intr.masked.scatter`. The pointer-vector operand must have
/// as many elements as the stored value vector.
LogicalResult LLVM::masked_scatter::verify() {
  Type ptrsType = getPtrs().getType();
  auto valueCount = LLVM::getVectorNumElements(getValue().getType());
  Type expectedType =
      LLVM::getVectorType(extractVectorElementType(ptrsType), valueCount);
  if (ptrsType == expectedType)
    return success();
  return emitOpError("expected operand #2 type to be ") << expectedType;
}
/// Registers the LLVM dialect's attributes, types, operations (core ops and
/// intrinsic ops) and interfaces.
void LLVMDialect::initialize() {
registerAttributes();
// These types are added explicitly; registerTypes() below registers the
// remaining ones (presumably the TableGen-defined types - confirm).
addTypes<LLVMVoidType,
LLVMPPCFP128Type,
LLVMX86MMXType,
LLVMTokenType,
LLVMLabelType,
LLVMMetadataType,
LLVMStructType>();
registerTypes();
// The op list concatenates the two generated GET_OP_LIST expansions; the
// lone comma separates them.
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
,
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/LLVMIntrinsicOps.cpp.inc"
>();
// Operations in this dialect's namespace that are not registered are still
// accepted.
allowUnknownOperations();
addInterfaces<LLVMOpAsmDialectInterface>();
detail::addLLVMInlinerInterface(this);
}
#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/LLVMIntrinsicOps.cpp.inc"
/// Checks that `descr` parses as an LLVM data layout string.
///
/// On failure, every parse error is collected into a single message, which is
/// passed to `reportError` with an "invalid data layout descriptor: " prefix.
LogicalResult LLVMDialect::verifyDataLayoutString(
    StringRef descr, llvm::function_ref<void(const Twine &)> reportError) {
  llvm::Expected<llvm::DataLayout> parsed = llvm::DataLayout::parse(descr);
  if (!parsed) {
    std::string errorText;
    llvm::raw_string_ostream errorStream(errorText);
    llvm::logAllUnhandledErrors(parsed.takeError(), errorStream);
    reportError("invalid data layout descriptor: " + errorStream.str());
    return failure();
  }
  return success();
}
/// Verifies LLVM dialect attributes attached to arbitrary operations.
///
/// Only the data layout attribute is checked. It must be a string that parses
/// as an LLVM data layout descriptor; every other attribute is accepted.
LogicalResult LLVMDialect::verifyOperationAttribute(Operation *op,
                                                    NamedAttribute attr) {
  if (attr.getName() != LLVM::LLVMDialect::getDataLayoutAttrName())
    return success();
  if (auto stringAttr = llvm::dyn_cast<StringAttr>(attr.getValue()))
    return verifyDataLayoutString(
        stringAttr.getValue(),
        [op](const Twine &message) { op->emitOpError() << message.str(); });
  return op->emitOpError() << "expected '"
                           << LLVM::LLVMDialect::getDataLayoutAttrName()
                           << "' to be a string attribute";
}
/// Verifies one LLVM parameter attribute `paramAttr` that annotates a value of
/// type `paramType` (a function argument or result) on `op`.
///
/// The kind of the attribute's value (unit, type or integer attribute) is
/// always checked. The kind of the annotated type (pointer or integer) is
/// checked only when `paramType` is LLVM-compatible. Attributes not listed
/// below are accepted without checks.
LogicalResult LLVMDialect::verifyParameterAttribute(Operation *op,
                                                    Type paramType,
                                                    NamedAttribute paramAttr) {
  bool verifyValueType = isCompatibleType(paramType);
  StringAttr name = paramAttr.getName();

  auto checkUnitAttrType = [&]() -> LogicalResult {
    if (!llvm::isa<UnitAttr>(paramAttr.getValue()))
      return op->emitError() << name << " should be a unit attribute";
    return success();
  };
  auto checkTypeAttrType = [&]() -> LogicalResult {
    if (!llvm::isa<TypeAttr>(paramAttr.getValue()))
      return op->emitError() << name << " should be a type attribute";
    return success();
  };
  auto checkIntegerAttrType = [&]() -> LogicalResult {
    if (!llvm::isa<IntegerAttr>(paramAttr.getValue()))
      return op->emitError() << name << " should be an integer attribute";
    return success();
  };
  auto checkPointerType = [&]() -> LogicalResult {
    if (!llvm::isa<LLVMPointerType>(paramType))
      return op->emitError()
             << name << " attribute attached to non-pointer LLVM type";
    return success();
  };
  auto checkIntegerType = [&]() -> LogicalResult {
    if (!llvm::isa<IntegerType>(paramType))
      return op->emitError()
             << name << " attribute attached to non-integer LLVM type";
    return success();
  };

  // Unit attributes that may only annotate pointers.
  if (name == LLVMDialect::getNoAliasAttrName() ||
      name == LLVMDialect::getReadonlyAttrName() ||
      name == LLVMDialect::getReadnoneAttrName() ||
      name == LLVMDialect::getWriteOnlyAttrName() ||
      name == LLVMDialect::getNestAttrName() ||
      name == LLVMDialect::getNoCaptureAttrName() ||
      name == LLVMDialect::getNoFreeAttrName() ||
      name == LLVMDialect::getNonNullAttrName()) {
    if (failed(checkUnitAttrType()))
      return failure();
    if (verifyValueType && failed(checkPointerType()))
      return failure();
    return success();
  }

  // Type attributes that may only annotate pointers. The attached type is
  // not compared against `paramType`; only pointer-ness is checked. (This
  // previously went through a wrapper lambda that just forwarded to
  // checkPointerType.)
  if (name == LLVMDialect::getStructRetAttrName() ||
      name == LLVMDialect::getByValAttrName() ||
      name == LLVMDialect::getByRefAttrName() ||
      name == LLVMDialect::getInAllocaAttrName() ||
      name == LLVMDialect::getPreallocatedAttrName()) {
    if (failed(checkTypeAttrType()))
      return failure();
    if (verifyValueType && failed(checkPointerType()))
      return failure();
    return success();
  }

  // Unit attributes that may only annotate integers.
  if (name == LLVMDialect::getSExtAttrName() ||
      name == LLVMDialect::getZExtAttrName()) {
    if (failed(checkUnitAttrType()))
      return failure();
    if (verifyValueType && failed(checkIntegerType()))
      return failure();
    return success();
  }

  // Integer attributes that may only annotate pointers.
  if (name == LLVMDialect::getAlignAttrName() ||
      name == LLVMDialect::getDereferenceableAttrName() ||
      name == LLVMDialect::getDereferenceableOrNullAttrName() ||
      name == LLVMDialect::getStackAlignmentAttrName()) {
    if (failed(checkIntegerAttrType()))
      return failure();
    if (verifyValueType && failed(checkPointerType()))
      return failure();
    return success();
  }

  // Unit attributes allowed on any type.
  if (name == LLVMDialect::getNoUndefAttrName() ||
      name == LLVMDialect::getInRegAttrName() ||
      name == LLVMDialect::getReturnedAttrName())
    return checkUnitAttrType();

  return success();
}
/// Verifies an LLVM dialect attribute on a region argument. Only ops that
/// implement FunctionOpInterface are checked; for them, the attribute is
/// validated against the type of argument `argIdx`.
LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op,
                                                    unsigned regionIdx,
                                                    unsigned argIdx,
                                                    NamedAttribute argAttr) {
  if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
    Type argType = funcOp.getArgumentTypes()[argIdx];
    return verifyParameterAttribute(op, argType, argAttr);
  }
  return success();
}
/// Verifies an LLVM dialect attribute on a function result. Only ops that
/// implement FunctionOpInterface are checked.
///
/// The checks are applied in this order:
///  - a void result cannot carry attributes;
///  - the attributes listed below are rejected on results;
///  - anything else goes through verifyParameterAttribute.
LogicalResult LLVMDialect::verifyRegionResultAttribute(Operation *op,
                                                       unsigned regionIdx,
                                                       unsigned resIdx,
                                                       NamedAttribute resAttr) {
  auto funcOp = dyn_cast<FunctionOpInterface>(op);
  if (!funcOp)
    return success();

  Type resType = funcOp.getResultTypes()[resIdx];
  if (llvm::isa<LLVMVoidType>(resType))
    return op->emitError() << "cannot attach result attributes to functions "
                              "with a void return";

  // Attributes that are rejected on function results.
  StringRef disallowedOnResults[] = {
      LLVMDialect::getAllocAlignAttrName(),
      LLVMDialect::getAllocatedPointerAttrName(),
      LLVMDialect::getByValAttrName(),
      LLVMDialect::getByRefAttrName(),
      LLVMDialect::getInAllocaAttrName(),
      LLVMDialect::getNestAttrName(),
      LLVMDialect::getNoCaptureAttrName(),
      LLVMDialect::getNoFreeAttrName(),
      LLVMDialect::getPreallocatedAttrName(),
      LLVMDialect::getReadnoneAttrName(),
      LLVMDialect::getReadonlyAttrName(),
      LLVMDialect::getReturnedAttrName(),
      LLVMDialect::getStackAlignmentAttrName(),
      LLVMDialect::getStructRetAttrName(),
      LLVMDialect::getWriteOnlyAttrName()};
  StringAttr name = resAttr.getName();
  if (llvm::is_contained(disallowedOnResults, name.getValue()))
    return op->emitError() << name << " is not a valid result attribute";

  return verifyParameterAttribute(op, resType, resAttr);
}
/// Materializes a constant op for a folded attribute.
///
///  - A flat symbol reference with a pointer type becomes `llvm.mlir.addressof`.
///  - Undef, poison and zero attributes become their dedicated ops.
///  - Everything else is delegated to ConstantOp::materialize, which may
///    return null.
Operation *LLVMDialect::materializeConstant(OpBuilder &builder, Attribute value,
                                            Type type, Location loc) {
  if (isa<LLVM::LLVMPointerType>(type))
    if (auto symbolRef = dyn_cast<FlatSymbolRefAttr>(value))
      return builder.create<LLVM::AddressOfOp>(loc, type, symbolRef);

  if (isa<LLVM::UndefAttr>(value))
    return builder.create<LLVM::UndefOp>(loc, type);
  if (isa<LLVM::PoisonAttr>(value))
    return builder.create<LLVM::PoisonOp>(loc, type);
  if (isa<LLVM::ZeroAttr>(value))
    return builder.create<LLVM::ZeroOp>(loc, type);

  return LLVM::ConstantOp::materialize(builder, value, type, loc);
}
/// Creates a constant global named `name` in the enclosing module that holds
/// `value` as an array of exactly `value.size()` i8 elements. No terminator is
/// added here.
///
/// At the builder's current insertion point, it then emits an address-of the
/// global followed by a GEP with indices {0, 0}, and returns that pointer to
/// the first element.
Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder,
                                     StringRef name, StringRef value,
                                     LLVM::Linkage linkage) {
  Block *insertBlock = builder.getInsertionBlock();
  assert(insertBlock && insertBlock->getParentOp() &&
         "expected builder to point to a block constrained in an op");
  auto module = insertBlock->getParentOp()->getParentOfType<ModuleOp>();
  assert(module && "builder points to an op outside of a module");

  // The global is created through a separate builder on the module body
  // region; the caller's listener is forwarded to it.
  OpBuilder moduleBuilder(module.getBodyRegion(), builder.getListener());
  MLIRContext *ctx = builder.getContext();
  auto arrayType =
      LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), value.size());
  auto global = moduleBuilder.create<LLVM::GlobalOp>(
      loc, arrayType, /*isConstant=*/true, linkage, name,
      builder.getStringAttr(value), /*alignment=*/0);

  auto ptrType = LLVMPointerType::get(ctx);
  Value globalAddr =
      builder.create<LLVM::AddressOfOp>(loc, ptrType, global.getSymNameAttr());
  return builder.create<LLVM::GEPOp>(loc, ptrType, arrayType, globalAddr,
                                     ArrayRef<GEPArg>{0, 0});
}
/// Returns true if `op` can act as an LLVM module, i.e. it has both the
/// SymbolTable and IsIsolatedFromAbove traits.
bool mlir::LLVM::satisfiesLLVMModule(Operation *op) {
  if (!op->hasTrait<OpTrait::SymbolTable>())
    return false;
  return op->hasTrait<OpTrait::IsIsolatedFromAbove>();
}