#include "Parser.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/TensorEncoding.h"
using namespace mlir;
using namespace mlir::detail;
/// Try to parse a type if the current token can start one.
///
/// Returns `llvm::None` (consuming nothing) when the current token cannot
/// begin a type; otherwise parses a type into `type` and returns success or
/// failure depending on whether a non-null type was produced.
OptionalParseResult Parser::parseOptionalType(Type &type) {
  bool canStartType = getToken().isAny(
      Token::l_paren, Token::kw_memref, Token::kw_tensor, Token::kw_complex,
      Token::kw_tuple, Token::kw_vector, Token::inttype, Token::kw_bf16,
      Token::kw_f16, Token::kw_f32, Token::kw_f64, Token::kw_f80,
      Token::kw_f128, Token::kw_index, Token::kw_none,
      Token::exclamation_identifier);
  if (!canStartType)
    return llvm::None;

  type = parseType();
  return failure(!type);
}
/// Parse any type. A leading '(' selects the function-type grammar;
/// everything else is delegated to parseNonFunctionType, which reports
/// unexpected tokens. Returns a null type on failure.
Type Parser::parseType() {
  return getToken().is(Token::l_paren) ? parseFunctionType()
                                       : parseNonFunctionType();
}
/// Parse the result side of a function type into `elements`.
///
/// Accepts either a parenthesized (possibly empty) type list, or a single
/// bare non-function type.
ParseResult Parser::parseFunctionResultTypes(SmallVectorImpl<Type> &elements) {
  // Parenthesized form: `()` or `(t0, t1, ...)`.
  if (getToken().is(Token::l_paren))
    return parseTypeListParens(elements);

  // Bare form: exactly one non-function type.
  if (Type result = parseNonFunctionType()) {
    elements.push_back(result);
    return success();
  }
  return failure();
}
/// Parse a comma-separated list of one or more types (no surrounding
/// delimiters), appending each to `elements`.
///
/// Note: the element is appended before it is checked, so on failure the
/// last entry of `elements` is a null Type.
ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) {
  return parseCommaSeparatedList([&]() -> ParseResult {
    Type parsed = parseType();
    elements.push_back(parsed);
    if (!parsed)
      return failure();
    return success();
  });
}
/// Parse a parenthesized, comma-separated type list: `()` or `(t0, ...)`.
/// Parsed types are appended to `elements`.
ParseResult Parser::parseTypeListParens(SmallVectorImpl<Type> &elements) {
  if (parseToken(Token::l_paren, "expected '('"))
    return failure();

  // Empty list.
  if (consumeIf(Token::r_paren))
    return success();

  if (parseTypeListNoParens(elements))
    return failure();
  if (parseToken(Token::r_paren, "expected ')'"))
    return failure();
  return success();
}
/// Parse `complex<element-type>`; the current token is the `complex`
/// keyword. Only integer and floating-point element types are accepted.
/// Returns a null type on failure.
Type Parser::parseComplexType() {
  consumeToken(Token::kw_complex);

  if (parseToken(Token::less, "expected '<' in complex type"))
    return nullptr;

  SMLoc elementTypeLoc = getToken().getLoc();
  Type elementType = parseType();
  if (!elementType)
    return nullptr;
  if (parseToken(Token::greater, "expected '>' in complex type"))
    return nullptr;

  bool isScalar =
      elementType.isa<FloatType>() || elementType.isa<IntegerType>();
  if (!isScalar) {
    emitError(elementTypeLoc, "invalid element type for complex");
    return nullptr;
  }
  return ComplexType::get(elementType);
}
/// Parse `(inputs) -> results`; the current token must be '('.
/// Returns a null type on failure.
Type Parser::parseFunctionType() {
  assert(getToken().is(Token::l_paren));

  SmallVector<Type, 4> inputs;
  if (parseTypeListParens(inputs))
    return nullptr;

  if (parseToken(Token::arrow, "expected '->' in function type"))
    return nullptr;

  SmallVector<Type, 4> outputs;
  if (parseFunctionResultTypes(outputs))
    return nullptr;

  return builder.getFunctionType(inputs, outputs);
}
/// Parse a strided layout specification of the form
///   `offset: (integer | ?), strides: [stride-list]`
/// The current token must be the `offset` keyword.
///
/// On success, `offset` holds the parsed offset, or
/// MemRefType::getDynamicStrideOrOffset() for `?`, and `strides` holds the
/// parsed stride list.
ParseResult Parser::parseStridedLayout(int64_t &offset,
                                       SmallVectorImpl<int64_t> &strides) {
  consumeToken(Token::kw_offset);
  if (parseToken(Token::colon, "expected colon after `offset` keyword"))
    return failure();

  // Read the offset as a 64-bit value bounded to int64_t, so offsets cover the
  // same range parseStrideList accepts for strides. A 32-bit `unsigned` read
  // would reject valid offsets >= 2^32.
  Optional<uint64_t> maybeOffset = getToken().getUInt64IntegerValue();
  if (maybeOffset &&
      *maybeOffset >
          static_cast<uint64_t>(std::numeric_limits<int64_t>::max()))
    maybeOffset = llvm::None;
  bool question = getToken().is(Token::question);
  if (!maybeOffset && !question)
    return emitWrongTokenError("invalid offset");

  // `?` denotes a dynamic offset.
  offset = maybeOffset ? static_cast<int64_t>(*maybeOffset)
                       : MemRefType::getDynamicStrideOrOffset();
  consumeToken();

  if (parseToken(Token::comma, "expected comma after offset value") ||
      parseToken(Token::kw_strides,
                 "expected `strides` keyword after offset specification") ||
      parseToken(Token::colon, "expected colon after `strides` keyword") ||
      parseStrideList(strides))
    return failure();
  return success();
}
/// Parse a memref type; the current token is the `memref` keyword.
///
///   memref<* x element-type (, memory-space)?>            (unranked)
///   memref<dim-list element-type (, layout)? (, memory-space)?> (ranked)
///
/// A layout is either an `offset: ..., strides: [...]` specification or an
/// attribute implementing MemRefLayoutAttrInterface. Any other attribute is
/// taken as the memory space. Returns a null type on failure.
Type Parser::parseMemRefType() {
  // Remember the `memref` keyword location for getChecked diagnostics below.
  SMLoc loc = getToken().getLoc();
  consumeToken(Token::kw_memref);
  if (parseToken(Token::less, "expected '<' in memref type"))
    return nullptr;
  bool isUnranked;
  SmallVector<int64_t, 4> dimensions;
  // `*x` introduces an unranked memref; otherwise parse `d0 x d1 x ... x`.
  if (consumeIf(Token::star)) {
    isUnranked = true;
    if (parseXInDimensionList())
      return nullptr;
  } else {
    isUnranked = false;
    if (parseDimensionListRanked(dimensions))
      return nullptr;
  }
  auto typeLoc = getToken().getLoc();
  auto elementType = parseType();
  if (!elementType)
    return nullptr;
  if (!BaseMemRefType::isValidElementType(elementType))
    return emitError(typeLoc, "invalid memref element type"), nullptr;
  MemRefLayoutAttrInterface layout;
  Attribute memorySpace;
  // Parses one `, entry` after the element type. The rules are:
  //  - a layout (strided spec or layout attribute) is rejected on unranked
  //    memrefs and must come before the memory space;
  //  - at most one non-layout attribute (memory space) may appear.
  // NOTE(review): a second layout entry silently overwrites the first.
  auto parseElt = [&]() -> ParseResult {
    if (getToken().is(Token::kw_offset)) {
      // `offset: ..., strides: [...]` is lowered to an affine-map layout.
      int64_t offset;
      SmallVector<int64_t, 4> strides;
      if (failed(parseStridedLayout(offset, strides)))
        return failure();
      AffineMap map = makeStridedLinearLayoutMap(strides, offset, getContext());
      layout = AffineMapAttr::get(map);
    } else {
      Attribute attr = parseAttribute();
      if (!attr)
        return failure();
      if (attr.isa<MemRefLayoutAttrInterface>()) {
        layout = attr.cast<MemRefLayoutAttrInterface>();
      } else if (memorySpace) {
        return emitError("multiple memory spaces specified in memref type");
      } else {
        // A memory space needs none of the layout checks below.
        memorySpace = attr;
        return success();
      }
    }
    // Only layout entries reach this point.
    if (isUnranked)
      return emitError("cannot have affine map for unranked memref type");
    if (memorySpace)
      return emitError("expected memory space to be last in memref type");
    return success();
  };
  // Either the type closes right after the element type, or a ',' starts the
  // trailing entry list. The `false` argument presumably disallows an empty
  // list after the ',' — confirm against the declaration.
  if (!consumeIf(Token::greater)) {
    if (parseToken(Token::comma, "expected ',' or '>' in memref type") ||
        parseCommaSeparatedListUntil(Token::greater, parseElt,
                                     false)) {
      return nullptr;
    }
  }
  // getChecked presumably verifies the parameters and reports errors at `loc`.
  if (isUnranked)
    return getChecked<UnrankedMemRefType>(loc, elementType, memorySpace);
  return getChecked<MemRefType>(loc, dimensions, elementType, layout,
                                memorySpace);
}
/// Parse any type other than a function type, dispatching on the current
/// token. Returns a null type (after emitting a diagnostic) when the token
/// cannot start a non-function type.
Type Parser::parseNonFunctionType() {
  switch (getToken().getKind()) {
  // Parameterized builtin types with dedicated sub-parsers.
  case Token::kw_memref:
    return parseMemRefType();
  case Token::kw_tensor:
    return parseTensorType();
  case Token::kw_complex:
    return parseComplexType();
  case Token::kw_tuple:
    return parseTupleType();
  case Token::kw_vector:
    return parseVectorType();

  // Integer types: validate the width before consuming the token so that
  // diagnostics point at it.
  case Token::inttype: {
    auto bitwidth = getToken().getIntTypeBitwidth();
    if (!bitwidth.has_value())
      return (emitError("invalid integer width"), nullptr);
    if (bitwidth.value() > IntegerType::kMaxWidth) {
      emitError(getToken().getLoc(), "integer bitwidth is limited to ")
          << IntegerType::kMaxWidth << " bits";
      return nullptr;
    }
    // No signedness information means signless.
    IntegerType::SignednessSemantics semantics = IntegerType::Signless;
    Optional<bool> isSigned = getToken().getIntTypeSignedness();
    if (isSigned.has_value())
      semantics =
          isSigned.value() ? IntegerType::Signed : IntegerType::Unsigned;
    consumeToken(Token::inttype);
    return IntegerType::get(getContext(), bitwidth.value(), semantics);
  }

  // Keyword-only builtin types.
  case Token::kw_bf16:
    consumeToken(Token::kw_bf16);
    return builder.getBF16Type();
  case Token::kw_f16:
    consumeToken(Token::kw_f16);
    return builder.getF16Type();
  case Token::kw_f32:
    consumeToken(Token::kw_f32);
    return builder.getF32Type();
  case Token::kw_f64:
    consumeToken(Token::kw_f64);
    return builder.getF64Type();
  case Token::kw_f80:
    consumeToken(Token::kw_f80);
    return builder.getF80Type();
  case Token::kw_f128:
    consumeToken(Token::kw_f128);
    return builder.getF128Type();
  case Token::kw_index:
    consumeToken(Token::kw_index);
    return builder.getIndexType();
  case Token::kw_none:
    consumeToken(Token::kw_none);
    return builder.getNoneType();

  // Types spelled with a leading `!` are handled by parseExtendedType.
  case Token::exclamation_identifier:
    return parseExtendedType();

  // Code completion: complete an extended type if one was being typed,
  // otherwise offer type completions.
  case Token::code_complete:
    if (getToken().isCodeCompletionFor(Token::exclamation_identifier))
      return parseExtendedType();
    return codeCompleteType();

  default:
    return (emitWrongTokenError("expected non-function type"), nullptr);
  }
}
/// Parse a tensor type; the current token is the `tensor` keyword.
///
///   tensor<* x element-type>                       (unranked)
///   tensor<dim-list element-type (, encoding)?>    (ranked)
///
/// An encoding implementing VerifiableTensorEncoding is verified against the
/// parsed shape and element type. Returns a null type on failure.
Type Parser::parseTensorType() {
  consumeToken(Token::kw_tensor);
  if (parseToken(Token::less, "expected '<' in tensor type"))
    return nullptr;

  // `*x` introduces an unranked tensor; otherwise parse `d0 x d1 x ... x`.
  bool isUnranked;
  SmallVector<int64_t, 4> dimensions;
  if (consumeIf(Token::star)) {
    isUnranked = true;
    if (parseXInDimensionList())
      return nullptr;
  } else {
    isUnranked = false;
    if (parseDimensionListRanked(dimensions))
      return nullptr;
  }

  // Bail out as soon as the element type fails to parse. parseType has
  // already reported the error, and the encoding verification below must
  // never see a null element type.
  auto elementTypeLoc = getToken().getLoc();
  auto elementType = parseType();
  if (!elementType)
    return nullptr;

  // Optional encoding attribute.
  Attribute encoding;
  if (consumeIf(Token::comma)) {
    encoding = parseAttribute();
    // parseAttribute has already emitted a diagnostic on failure.
    if (!encoding)
      return nullptr;
    // Unranked tensors reject encodings below with a dedicated diagnostic, so
    // only verify the encoding against a ranked shape.
    if (!isUnranked) {
      if (auto v = encoding.dyn_cast<VerifiableTensorEncoding>()) {
        if (failed(v.verifyEncoding(dimensions, elementType,
                                    [&] { return emitError(); })))
          return nullptr;
      }
    }
  }

  if (parseToken(Token::greater, "expected '>' in tensor type"))
    return nullptr;
  if (!TensorType::isValidElementType(elementType))
    return emitError(elementTypeLoc, "invalid tensor element type"), nullptr;

  if (isUnranked) {
    if (encoding)
      return emitError("cannot apply encoding to unranked tensor"), nullptr;
    return UnrankedTensorType::get(elementType);
  }
  return RankedTensorType::get(dimensions, elementType, encoding);
}
/// Parse `tuple<>` or `tuple<t0, t1, ...>`; the current token is the `tuple`
/// keyword. Returns a null type on failure.
Type Parser::parseTupleType() {
  consumeToken(Token::kw_tuple);

  if (parseToken(Token::less, "expected '<' in tuple type"))
    return nullptr;

  // Empty tuple.
  if (consumeIf(Token::greater))
    return TupleType::get(getContext());

  SmallVector<Type, 4> elementTypes;
  bool malformed =
      failed(parseTypeListNoParens(elementTypes)) ||
      failed(parseToken(Token::greater, "expected '>' in tuple type"));
  if (malformed)
    return nullptr;

  return TupleType::get(getContext(), elementTypes);
}
/// Parse a vector type; the current token is the `vector` keyword.
///
///   vector<d0 x ... x [s0 x ...] x element-type>
///
/// Every dimension size must be positive, and the element type must satisfy
/// VectorType::isValidElementType. Returns a null type on failure.
VectorType Parser::parseVectorType() {
  consumeToken(Token::kw_vector);
  if (parseToken(Token::less, "expected '<' in vector type"))
    return nullptr;

  SmallVector<int64_t, 4> dimensions;
  unsigned numScalableDims;
  if (parseVectorDimensionList(dimensions, numScalableDims))
    return nullptr;

  bool hasNonPositiveDim =
      any_of(dimensions, [](int64_t size) { return size <= 0; });
  if (hasNonPositiveDim) {
    emitError(getToken().getLoc(),
              "vector types must have positive constant sizes");
    return nullptr;
  }

  SMLoc elementTypeLoc = getToken().getLoc();
  Type elementType = parseType();
  if (!elementType)
    return nullptr;
  if (parseToken(Token::greater, "expected '>' in vector type"))
    return nullptr;

  if (!VectorType::isValidElementType(elementType)) {
    emitError(elementTypeLoc, "vector elements must be int/index/float type");
    return nullptr;
  }
  return VectorType::get(dimensions, elementType, numScalableDims);
}
/// Parse the dimension list of a vector type: a run of fixed dimensions
/// `d0 x d1 x ...`, optionally followed by one bracketed group of scalable
/// dimensions `[s0 x s1] x`. All sizes are appended to `dimensions`.
/// `numScalableDims` receives how many of them were inside the brackets.
ParseResult
Parser::parseVectorDimensionList(SmallVectorImpl<int64_t> &dimensions,
                                 unsigned &numScalableDims) {
  numScalableDims = 0;

  // Fixed-size dimensions: each integer is followed by an 'x'.
  while (getToken().is(Token::integer)) {
    int64_t size;
    if (parseIntegerInDimensionList(size))
      return failure();
    dimensions.push_back(size);
    if (parseXInDimensionList())
      return failure();
  }

  // No scalable group.
  if (!consumeIf(Token::l_square))
    return success();

  // Scalable dimensions: integers separated by 'x', closed by `]` and a
  // trailing 'x'.
  while (getToken().is(Token::integer)) {
    int64_t size;
    if (parseIntegerInDimensionList(size))
      return failure();
    dimensions.push_back(size);
    ++numScalableDims;
    if (consumeIf(Token::r_square))
      return parseXInDimensionList();
    if (parseXInDimensionList())
      return failure();
  }
  return emitWrongTokenError("missing ']' closing set of scalable dimensions");
}
/// Parse a ranked dimension list into `dimensions`.
///
/// Each dimension is an integer or `?`. A `?` is recorded as -1 and is only
/// accepted when `allowDynamic` is set. With `withTrailingX`, every dimension
/// must be followed by 'x' (`2x?x`). Otherwise 'x' only separates dimensions
/// (`2x?`). An empty list is accepted in both forms.
ParseResult
Parser::parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions,
                                 bool allowDynamic, bool withTrailingX) {
  auto parseDim = [&]() -> LogicalResult {
    SMLoc dimLoc = getToken().getLoc();
    if (!consumeIf(Token::question)) {
      int64_t size;
      if (failed(parseIntegerInDimensionList(size)))
        return failure();
      dimensions.push_back(size);
      return success();
    }
    if (!allowDynamic)
      return emitError(dimLoc, "expected static shape");
    // NOTE(review): -1 presumably matches ShapedType's dynamic-size sentinel.
    dimensions.push_back(-1);
    return success();
  };
  auto atDimension = [&] {
    return getToken().isAny(Token::integer, Token::question);
  };

  if (withTrailingX) {
    while (atDimension()) {
      if (failed(parseDim()) || failed(parseXInDimensionList()))
        return failure();
    }
    return success();
  }

  // Separator-only form.
  if (!atDimension())
    return success();
  if (failed(parseDim()))
    return failure();
  while (getToken().is(Token::bare_identifier) &&
         getTokenSpelling()[0] == 'x') {
    if (failed(parseXInDimensionList()) || failed(parseDim()))
      return failure();
  }
  return success();
}
/// Parse one integer dimension size into `value`.
///
/// Only integer literals are accepted. Any other token is reported as an
/// invalid dimension. This matters because some callers (the separator-only
/// form of parseDimensionListRanked) reach here right after an 'x' without
/// checking the token kind. Without the check, a token such as the
/// identifier `axb` would hit the hex-splitting branch below and be misread
/// as the dimension 0.
ParseResult Parser::parseIntegerInDimensionList(int64_t &value) {
  if (getToken().isNot(Token::integer))
    return emitError("invalid dimension");

  StringRef spelling = getTokenSpelling();

  // Hexadecimal literals are not meaningful inside a dimension list, so a
  // token like `0xf32` is really `0`, `x`, `f32`. Take the `0` and rewind the
  // lexer so that the 'x' is lexed again as the next token.
  if (spelling.size() > 1 && spelling[1] == 'x') {
    assert(spelling[0] == '0' && "invalid integer literal");
    value = 0;
    state.lex.resetPointer(spelling.data() + 1);
    consumeToken();
    return success();
  }

  // Decimal literal: it must fit in a (non-negative) int64_t.
  Optional<uint64_t> dimension = getToken().getUInt64IntegerValue();
  if (!dimension ||
      *dimension > static_cast<uint64_t>(std::numeric_limits<int64_t>::max()))
    return emitError("invalid dimension");
  value = static_cast<int64_t>(*dimension);
  consumeToken(Token::integer);
  return success();
}
/// Consume the 'x' separator of a dimension list.
///
/// The lexer yields the separator glued to what follows it as a single bare
/// identifier (e.g. `xf32`, `x4`). So any bare identifier starting with 'x'
/// is accepted. When it is longer than one character, the lexer is rewound
/// to just past the 'x' so the remainder is lexed as the next token.
ParseResult Parser::parseXInDimensionList() {
  StringRef spelling = getTokenSpelling();
  bool isSeparator =
      getToken().is(Token::bare_identifier) && spelling[0] == 'x';
  if (!isSeparator)
    return emitWrongTokenError("expected 'x' in dimension list");

  if (spelling.size() > 1)
    state.lex.resetPointer(spelling.data() + 1);
  consumeToken(Token::bare_identifier);
  return success();
}
/// Parse a square-bracketed, comma-separated stride list into `dimensions`.
///
/// Each stride is either `?`, recorded as
/// MemRefType::getDynamicStrideOrOffset(), or a non-zero decimal integer.
/// A zero stride is rejected as an invalid memref stride.
ParseResult Parser::parseStrideList(SmallVectorImpl<int64_t> &dimensions) {
  auto parseStride = [&]() -> ParseResult {
    if (consumeIf(Token::question)) {
      dimensions.push_back(MemRefType::getDynamicStrideOrOffset());
      return success();
    }

    int64_t stride;
    if (getToken().getSpelling().getAsInteger(10, stride))
      return emitError("invalid integer value: ")
             << getToken().getSpelling();
    if (ShapedType::isDynamic(stride))
      return emitError("invalid integer value: ")
             << getToken().getSpelling()
             << ", use `?` to specify a dynamic dimension";
    if (stride == 0)
      return emitError("invalid memref stride");

    dimensions.push_back(stride);
    consumeToken(Token::integer);
    return success();
  };
  return parseCommaSeparatedList(Delimiter::Square, parseStride,
                                 " in stride list");
}