#include "Parser.h"
#include "AsmParserImpl.h"
#include "mlir/AsmParser/AsmParserState.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/IntegerSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Endian.h"
#include <optional>
using namespace mlir;
using namespace mlir::detail;
/// Parse an attribute, dispatching on the kind of the current token.
///
/// `type` is an optional expected type. When non-null it is forwarded to the
/// sub-parsers that accept one (dense/dense_resource/array/sparse/extended/
/// numeric/distinct forms) and is used as the type of a string literal.
/// It is not forwarded to array (`[...]`) or dictionary (`{...}`) elements.
/// Returns a null Attribute on failure.
Attribute Parser::parseAttribute(Type type) {
switch (getToken().getKind()) {
// `affine_map<...>`
case Token::kw_affine_map: {
consumeToken(Token::kw_affine_map);
AffineMap map;
if (parseToken(Token::less, "expected '<' in affine map") ||
parseAffineMapReference(map) ||
parseToken(Token::greater, "expected '>' in affine map"))
return Attribute();
return AffineMapAttr::get(map);
}
// `affine_set<...>`
case Token::kw_affine_set: {
consumeToken(Token::kw_affine_set);
IntegerSet set;
if (parseToken(Token::less, "expected '<' in integer set") ||
parseIntegerSetReference(set) ||
parseToken(Token::greater, "expected '>' in integer set"))
return Attribute();
return IntegerSetAttr::get(set);
}
// Array attribute `[a, b, ...]`; elements are parsed without an expected type.
case Token::l_square: {
consumeToken(Token::l_square);
SmallVector<Attribute, 4> elements;
auto parseElt = [&]() -> ParseResult {
elements.push_back(parseAttribute());
return elements.back() ? success() : failure();
};
if (parseCommaSeparatedListUntil(Token::r_square, parseElt))
return nullptr;
return builder.getArrayAttr(elements);
}
case Token::kw_false:
consumeToken(Token::kw_false);
return builder.getBoolAttr(false);
case Token::kw_true:
consumeToken(Token::kw_true);
return builder.getBoolAttr(true);
case Token::kw_dense:
return parseDenseElementsAttr(type);
case Token::kw_dense_resource:
return parseDenseResourceElementsAttr(type);
case Token::kw_array:
return parseDenseArrayAttr(type);
// Dictionary attribute `{name = value, ...}`.
case Token::l_brace: {
NamedAttrList elements;
if (parseAttributeDict(elements))
return nullptr;
return elements.getDictionary(getContext());
}
// `#dialect.attr` / `#alias` forms.
case Token::hash_identifier:
return parseExtendedAttr(type);
case Token::floatliteral:
return parseFloatAttr(type, false);
case Token::integer:
return parseDecOrHexAttr(type, false);
// A leading '-' is only valid in front of a numeric literal; the numeric
// parsers are told the value is negative.
case Token::minus: {
consumeToken(Token::minus);
if (getToken().is(Token::integer))
return parseDecOrHexAttr(type, true);
if (getToken().is(Token::floatliteral))
return parseFloatAttr(type, true);
return (emitWrongTokenError(
"expected constant integer or floating point value"),
nullptr);
}
// Inline location `loc(...)`.
case Token::kw_loc: {
consumeToken(Token::kw_loc);
LocationAttr locAttr;
if (parseToken(Token::l_paren, "expected '(' in inline location") ||
parseLocationInstance(locAttr) ||
parseToken(Token::r_paren, "expected ')' in inline location"))
return Attribute();
return locAttr;
}
case Token::kw_sparse:
return parseSparseElementsAttr(type);
case Token::kw_strided:
return parseStridedLayoutAttr();
case Token::kw_distinct:
return parseDistinctAttr(type);
// String attribute. Without an expected type, an optional `: type` suffix
// gives the string a type; otherwise an untyped StringAttr is built.
case Token::string: {
auto val = getToken().getStringValue();
consumeToken(Token::string);
if (!type && consumeIf(Token::colon) && !(type = parseType()))
return Attribute();
return type ? StringAttr::get(val, type)
: StringAttr::get(getContext(), val);
}
// Symbol reference `@root` optionally followed by `::@nested` parts.
case Token::at_identifier: {
// Source ranges of each referenced symbol, recorded only when an
// AsmParserState is attached.
SmallVector<SMRange> referenceLocations;
if (state.asmState)
referenceLocations.push_back(getToken().getLocRange());
std::string nameStr = getToken().getSymbolReference();
consumeToken(Token::at_identifier);
std::vector<FlatSymbolRefAttr> nestedRefs;
// `::` is lexed as two separate colon tokens.
while (getToken().is(Token::colon)) {
const char *curPointer = getToken().getLoc().getPointer();
consumeToken(Token::colon);
if (!consumeIf(Token::colon)) {
// A lone ':' does not belong to the symbol reference (for example a
// trailing `: type`). Rewind the lexer to that colon and re-lex it so
// it becomes the current token again for the caller.
if (getToken().isNot(Token::eof, Token::error)) {
state.lex.resetPointer(curPointer);
consumeToken();
}
break;
}
auto curLoc = getToken().getLoc();
if (getToken().isNot(Token::at_identifier)) {
emitError(curLoc, "expected nested symbol reference identifier");
return Attribute();
}
if (state.asmState)
referenceLocations.push_back(getToken().getLocRange());
std::string nameStr = getToken().getSymbolReference();
consumeToken(Token::at_identifier);
nestedRefs.push_back(SymbolRefAttr::get(getContext(), nameStr));
}
SymbolRefAttr symbolRefAttr =
SymbolRefAttr::get(getContext(), nameStr, nestedRefs);
if (state.asmState)
state.asmState->addUses(symbolRefAttr, referenceLocations);
return symbolRefAttr;
}
case Token::kw_unit:
consumeToken(Token::kw_unit);
return builder.getUnitAttr();
// Code completion: complete an extended attribute when the completion point
// is on a `#...` token, otherwise offer generic attribute completions.
case Token::code_complete:
if (getToken().isCodeCompletionFor(Token::hash_identifier))
return parseExtendedAttr(type);
return codeCompleteAttribute();
// Anything else may be a type, which is wrapped in a TypeAttr.
// NOTE: this local `type` shadows the `type` parameter.
default:
Type type;
OptionalParseResult result = parseOptionalType(type);
if (!result.has_value())
return emitWrongTokenError("expected attribute value"), Attribute();
return failed(*result) ? Attribute() : TypeAttr::get(type);
}
}
/// Parse an attribute if the current token can start one.
///
/// Tokens that begin an attribute form handled by parseAttribute() are
/// parsed with it: the result is success if an attribute was produced and
/// failure otherwise. Any other token is tried as a type via
/// parseOptionalType(); a parsed type is wrapped in a TypeAttr, and that
/// function's result (including "no value") is returned as-is.
/// `type` is forwarded to parseAttribute() as the optional expected type.
OptionalParseResult Parser::parseOptionalAttribute(Attribute &attribute,
                                                   Type type) {
  switch (getToken().getKind()) {
  // Keep this list in sync with the keyword/punctuation cases of
  // parseAttribute(). `array`, `distinct` and `strided` were previously
  // missing, so those attributes fell into the type path below and were
  // reported as absent.
  case Token::at_identifier:
  case Token::floatliteral:
  case Token::integer:
  case Token::hash_identifier:
  case Token::kw_affine_map:
  case Token::kw_affine_set:
  case Token::kw_array:
  case Token::kw_dense:
  case Token::kw_dense_resource:
  case Token::kw_distinct:
  case Token::kw_false:
  case Token::kw_loc:
  case Token::kw_sparse:
  case Token::kw_strided:
  case Token::kw_true:
  case Token::kw_unit:
  case Token::l_brace:
  case Token::l_square:
  case Token::minus:
  case Token::string:
    attribute = parseAttribute(type);
    return success(attribute != nullptr);
  default: {
    // Not an attribute-specific token: the attribute may still be a type.
    Type parsedType;
    OptionalParseResult result = parseOptionalType(parsedType);
    if (result.has_value() && succeeded(*result))
      attribute = TypeAttr::get(parsedType);
    return result;
  }
  }
}
/// Parse an optional ArrayAttr. Delegates to parseOptionalAttributeWithToken
/// with '[' as the token that introduces the attribute.
OptionalParseResult Parser::parseOptionalAttribute(ArrayAttr &attribute,
                                                   Type type) {
  constexpr Token::Kind introducer = Token::l_square;
  return parseOptionalAttributeWithToken(introducer, attribute, type);
}

/// Parse an optional StringAttr. Delegates to parseOptionalAttributeWithToken
/// with a string literal as the token that introduces the attribute.
OptionalParseResult Parser::parseOptionalAttribute(StringAttr &attribute,
                                                   Type type) {
  constexpr Token::Kind introducer = Token::string;
  return parseOptionalAttributeWithToken(introducer, attribute, type);
}

/// Parse an optional SymbolRefAttr. Delegates to
/// parseOptionalAttributeWithToken with `@identifier` as the token that
/// introduces the attribute.
OptionalParseResult Parser::parseOptionalAttribute(SymbolRefAttr &result,
                                                   Type type) {
  constexpr Token::Kind introducer = Token::at_identifier;
  return parseOptionalAttributeWithToken(introducer, result, type);
}
/// Parse a brace-delimited attribute dictionary, appending entries to
/// `attributes`:
///
///   `{` entry (`,` entry)* `}`   where   entry ::= name (`=` attribute)?
///
/// A name may be a string literal, a bare identifier, an integer-type token,
/// or a keyword. An entry without `= attribute` is stored with a UnitAttr.
/// Empty names and names repeated within the same dictionary are errors. For
/// a dotted name `prefix.rest`, the context is asked to get-or-load the
/// dialect called `prefix`.
ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
  llvm::SmallDenseSet<StringAttr> seenKeys;

  auto parseEntry = [&]() -> ParseResult {
    // Work out the entry's key from the current token.
    StringAttr key;
    const Token &tok = getToken();
    if (tok.is(Token::string)) {
      key = builder.getStringAttr(tok.getStringValue());
    } else if (tok.isAny(Token::bare_identifier, Token::inttype) ||
               tok.isKeyword()) {
      key = builder.getStringAttr(getTokenSpelling());
    } else {
      return emitWrongTokenError("expected attribute name");
    }

    if (key.empty())
      return emitError("expected valid attribute name");
    const bool isNewKey = seenKeys.insert(key).second;
    if (!isNewKey)
      return emitError("duplicate key '")
             << key.getValue() << "' in dictionary attribute";
    consumeToken();

    // A dialect-prefixed name may refer to a dialect that is not loaded yet.
    auto [dialectPrefix, remainder] = key.strref().split('.');
    if (!remainder.empty())
      getContext()->getOrLoadDialect(dialectPrefix);

    Attribute value;
    if (consumeIf(Token::equal)) {
      value = parseAttribute();
      if (!value)
        return failure();
    } else {
      value = builder.getUnitAttr();
    }
    attributes.push_back({key, value});
    return success();
  };

  return parseCommaSeparatedList(Delimiter::Braces, parseEntry,
                                 " in attribute dictionary");
}
/// Parse the current floating-point literal token as a FloatAttr.
///
/// `isNegative` is true when a leading '-' was already consumed by the
/// caller. Without an explicit `type`, an optional `: type` suffix is parsed
/// and f64 is used when it is absent. The final type must be a FloatType.
Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
  std::optional<double> magnitude = getToken().getFloatingPointValue();
  if (!magnitude) {
    emitError("floating point value too large for attribute");
    return nullptr;
  }
  consumeToken(Token::floatliteral);

  if (!type) {
    if (consumeIf(Token::colon)) {
      type = parseType();
      if (!type)
        return nullptr;
    } else {
      type = builder.getF64Type();
    }
  }

  if (!isa<FloatType>(type)) {
    emitError("floating point value not valid for specified type");
    return nullptr;
  }

  const double signedValue = isNegative ? -*magnitude : *magnitude;
  return FloatAttr::get(type, signedValue);
}
/// Build the APInt for integer literal `spelling` as a value of the integer
/// or index type `type`, negated when `isNegative` is set.
///
/// `spelling` is parsed as decimal, or with radix auto-detection when its
/// second character is 'x' (hex). The value is resized to the type's width:
/// the internal storage width for index, otherwise the integer bit width.
/// Returns std::nullopt when the spelling is not a valid integer or when the
/// value does not fit:
///   * truncating to the target width would drop set bits;
///   * a negated non-zero value does not end up with its sign bit set;
///   * a non-negated value of a signed or index type has its sign bit set.
static std::optional<APInt> buildAttributeAPInt(Type type, bool isNegative,
                                                StringRef spelling) {
  APInt result;
  bool isHex = spelling.size() > 1 && spelling[1] == 'x';
  if (spelling.getAsInteger(isHex ? 0 : 10, result))
    return std::nullopt;

  unsigned width = type.isIndex() ? IndexType::kInternalStorageBitWidth
                                  : type.getIntOrFloatBitWidth();
  if (width > result.getBitWidth()) {
    result = result.zext(width);
  } else if (width < result.getBitWidth()) {
    // getAsInteger may return a wider value than needed, padded with leading
    // zeros. That is fine; dropping set bits on truncation is not.
    if (result.countl_zero() < result.getBitWidth() - width)
      return std::nullopt;
    result = result.trunc(width);
  }

  // Zero needs no sign handling. `-0` denotes the same value as `0`, and was
  // previously rejected because negating zero leaves the sign bit clear.
  // This also covers every (necessarily zero) value of a 0-bit integer,
  // whose sign bit must not be manipulated.
  if (result.isZero())
    return result;

  if (isNegative) {
    // A negated in-range magnitude must have its sign bit set; otherwise the
    // magnitude exceeded the signed range of the width.
    result.negate();
    if (!result.isSignBitSet())
      return std::nullopt;
  } else if ((type.isSignedInteger() || type.isIndex()) &&
             result.isSignBitSet()) {
    // A positive signed/index value overflowed into the sign bit.
    return std::nullopt;
  }
  return result;
}
/// Parse the current decimal or hexadecimal integer literal token as an
/// attribute; `isNegative` is set when a leading '-' was already consumed.
///
/// Without an explicit `type`, an optional `: type` suffix is parsed and i64
/// is used when it is absent. A float type yields a FloatAttr, built through
/// parseFloatFromIntegerLiteral. Integer and index types yield an IntegerAttr.
/// Any other type, a negative literal for an unsigned type, or an
/// out-of-range value is diagnosed at the literal.
Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
  Token literal = getToken();
  StringRef digits = literal.getSpelling();
  SMLoc literalLoc = literal.getLoc();
  consumeToken(Token::integer);

  if (!type) {
    if (consumeIf(Token::colon)) {
      type = parseType();
      if (!type)
        return nullptr;
    } else {
      type = builder.getIntegerType(64);
    }
  }

  if (auto floatType = dyn_cast<FloatType>(type)) {
    std::optional<APFloat> converted;
    if (failed(parseFloatFromIntegerLiteral(converted, literal, isNegative,
                                            floatType.getFloatSemantics(),
                                            floatType.getWidth())))
      return Attribute();
    return FloatAttr::get(floatType, *converted);
  }

  if (!isa<IntegerType, IndexType>(type)) {
    emitError(literalLoc, "integer literal not valid for specified type");
    return nullptr;
  }

  if (isNegative && type.isUnsignedInteger()) {
    emitError(literalLoc,
              "negative integer literal not valid for unsigned integer type");
    return nullptr;
  }

  std::optional<APInt> value = buildAttributeAPInt(type, isNegative, digits);
  if (!value) {
    emitError(literalLoc, "integer constant out of range for attribute");
    return nullptr;
  }
  return builder.getIntegerAttr(type, *value);
}
/// Decode the hex string token `tok` (expected to start with `0x`) into
/// `result`. On failure, report an error at `tok` through `parser` and leave
/// `result` untouched.
static ParseResult parseElementAttrHexValues(Parser &parser, Token tok,
                                             std::string &result) {
  std::optional<std::string> decoded = tok.getHexStringValue();
  if (!decoded)
    return parser.emitError(
        tok.getLoc(),
        "expected string containing hex digits starting with `0x`");
  result = std::move(*decoded);
  return success();
}
namespace {
/// Parses the literal payload of a `dense<...>` attribute (and the index and
/// value payloads of `sparse<...>`), then builds a DenseElementsAttr of a
/// caller-supplied shaped type from it.
class TensorLiteralParser {
public:
  // `explicit`: a Parser should never silently convert into a literal parser.
  explicit TensorLiteralParser(Parser &p) : p(p) {}

  /// Parse the literal. When `allowHex` is set, a string token is kept as a
  /// hex payload; otherwise a `[`-nested list or a single element is parsed.
  ParseResult parse(bool allowHex);

  /// Build the attribute of `type` from the parsed literal. Returns null
  /// after emitting a diagnostic (at `loc` or at an offending element) if the
  /// literal does not fit `type`.
  DenseElementsAttr getAttr(SMLoc loc, ShapedType type);

  /// Shape inferred from nested lists; empty when no list was parsed.
  ArrayRef<int64_t> getShape() const { return shape; }

private:
  /// Convert the parsed scalars to integers of element type `eltTy`.
  ParseResult getIntAttrElements(SMLoc loc, Type eltTy,
                                 std::vector<APInt> &intValues);

  /// Convert the parsed scalars to floats of element type `eltTy`.
  ParseResult getFloatAttrElements(SMLoc loc, FloatType eltTy,
                                   std::vector<APFloat> &floatValues);

  /// Build a DenseStringElementsAttr from the parsed tokens.
  DenseElementsAttr getStringAttr(SMLoc loc, ShapedType type, Type eltTy);

  /// Build an attribute from the hex payload in `hexStorage`.
  DenseElementsAttr getHexAttr(SMLoc loc, ShapedType type);

  /// Parse one element (scalar, string, or `(a, b)` pair) into `storage`.
  ParseResult parseElement();

  /// Parse a `[`-delimited list, writing its inferred dimensions to `dims`.
  ParseResult parseList(SmallVectorImpl<int64_t> &dims);

  ParseResult parseHexElements();

  Parser &p;

  /// Dimensions inferred from nested lists; empty for a bare element.
  SmallVector<int64_t, 4> shape;

  /// Parsed scalar tokens in source order, each paired with whether a '-'
  /// preceded it. A parenthesized pair contributes two consecutive entries.
  std::vector<std::pair<bool, Token>> storage;

  /// The string token when the literal was given as a hex payload.
  std::optional<Token> hexStorage;
};
} // namespace
/// Parse a tensor literal. When `allowHex` is set, a string token is stored
/// verbatim in `hexStorage` and decoded later. Otherwise a '[' starts a
/// nested list (recording the inferred shape), and anything else is parsed
/// as a single element.
ParseResult TensorLiteralParser::parse(bool allowHex) {
  const Token &current = p.getToken();
  if (current.is(Token::string) && allowHex) {
    hexStorage = current;
    p.consumeToken(Token::string);
    return success();
  }
  return current.is(Token::l_square) ? parseList(shape) : parseElement();
}
/// Build a DenseElementsAttr of `type` from the parsed literal.
///
/// A hex payload is decoded directly for integer, index, float, and complex
/// element types. Otherwise the parsed scalars are validated against `type`
/// before conversion:
///   * a nested-list literal must have exactly `type`'s shape and supply one
///     scalar per element (two, real then imaginary, for complex elements);
///   * a bare literal is a splat: one scalar, or one pair for complex
///     elements. A pair written for a non-complex type is accepted only when
///     it provides exactly one value per element.
/// Other element types are handled as strings. Diagnostics are reported at
/// `loc` (or at the offending element); returns null on error.
DenseElementsAttr TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) {
  Type eltType = type.getElementType();

  if (hexStorage &&
      (eltType.isIntOrIndexOrFloat() || isa<ComplexType>(eltType)))
    return getHexAttr(loc, type);

  if (!shape.empty() && getShape() != type.getShape()) {
    p.emitError(loc) << "inferred shape of elements literal ([" << getShape()
                     << "]) does not match type ([" << type.getShape() << "])";
    return nullptr;
  }

  if (!hexStorage && storage.empty() && type.getNumElements()) {
    p.emitError(loc) << "parsed zero elements, but type (" << type
                     << ") expected at least 1";
    return nullptr;
  }

  // Complex elements are stored as (real, imag) scalar pairs; the per-kind
  // conversions below operate on the underlying scalar type.
  bool isComplex = false;
  if (ComplexType complexTy = dyn_cast<ComplexType>(eltType)) {
    eltType = complexTy.getElementType();
    isComplex = true;
  }

  // Validate the scalar count here so malformed literals get a diagnostic
  // instead of reaching DenseElementsAttr::get with a count that fits
  // neither the type nor a splat (e.g. `dense<1> : tensor<complex<f32>>`,
  // `dense<[(1, 2)]> : tensor<1xi32>`), and so that
  // `dense<[1, 2]> : tensor<2xcomplex<i32>>` is not silently read as a splat.
  if (!hexStorage && !storage.empty()) {
    const size_t valuesPerElement = isComplex ? 2 : 1;
    const size_t numElements = static_cast<size_t>(type.getNumElements());
    const size_t expectedTotal = numElements * valuesPerElement;
    if (!shape.empty()) {
      if (storage.size() != expectedTotal) {
        p.emitError(loc) << "parsed " << storage.size()
                         << " scalar values, but type (" << type
                         << ") expected " << expectedTotal;
        return nullptr;
      }
    } else if (storage.size() != valuesPerElement &&
               storage.size() != expectedTotal) {
      p.emitError(loc) << "parsed " << storage.size()
                       << " scalar values, but a splat of type (" << type
                       << ") expected " << valuesPerElement;
      return nullptr;
    }
  }

  if (eltType.isIntOrIndex()) {
    std::vector<APInt> intValues;
    if (failed(getIntAttrElements(loc, eltType, intValues)))
      return nullptr;
    if (isComplex) {
      // Reinterpret consecutive (real, imag) APInts as complex values.
      auto complexData = llvm::ArrayRef(
          reinterpret_cast<std::complex<APInt> *>(intValues.data()),
          intValues.size() / 2);
      return DenseElementsAttr::get(type, complexData);
    }
    return DenseElementsAttr::get(type, intValues);
  }

  if (FloatType floatTy = dyn_cast<FloatType>(eltType)) {
    std::vector<APFloat> floatValues;
    if (failed(getFloatAttrElements(loc, floatTy, floatValues)))
      return nullptr;
    if (isComplex) {
      // Reinterpret consecutive (real, imag) APFloats as complex values.
      auto complexData = llvm::ArrayRef(
          reinterpret_cast<std::complex<APFloat> *>(floatValues.data()),
          floatValues.size() / 2);
      return DenseElementsAttr::get(type, complexData);
    }
    return DenseElementsAttr::get(type, floatValues);
  }

  return getStringAttr(loc, type, type.getElementType());
}
/// Convert every parsed scalar to an APInt of integer/index type `eltTy`,
/// appending to `intValues`.
///
/// Accepts integer literals (negated ones only for non-unsigned types) and,
/// for i1 only, `true`/`false`. Each rejected element is diagnosed at its
/// own token. `loc` is currently unused.
ParseResult
TensorLiteralParser::getIntAttrElements(SMLoc loc, Type eltTy,
                                        std::vector<APInt> &intValues) {
  intValues.reserve(storage.size());
  const bool isUintType = eltTy.isUnsignedInteger();
  for (const auto &[isNegative, token] : storage) {
    SMLoc tokenLoc = token.getLoc();
    if (isNegative && isUintType)
      return p.emitError(tokenLoc)
             << "expected unsigned integer elements, but parsed negative value";

    if (token.is(Token::floatliteral))
      return p.emitError(tokenLoc)
             << "expected integer elements, but parsed floating-point";

    // parseElement() also records string literals (e.g.
    // `dense<["a"]> : tensor<1xi32>`). Diagnose them instead of asserting.
    if (!token.isAny(Token::integer, Token::kw_true, Token::kw_false))
      return p.emitError(tokenLoc)
             << "expected integer elements, but parsed non-integer literal";

    if (token.isAny(Token::kw_true, Token::kw_false)) {
      if (!eltTy.isInteger(1))
        return p.emitError(tokenLoc)
               << "expected i1 type for 'true' or 'false' values";
      intValues.push_back(APInt(1, token.is(Token::kw_true), false));
      continue;
    }

    std::optional<APInt> apInt =
        buildAttributeAPInt(eltTy, isNegative, token.getSpelling());
    if (!apInt)
      return p.emitError(tokenLoc, "integer constant out of range for type");
    intValues.push_back(*apInt);
  }
  return success();
}
/// Convert every parsed scalar to an APFloat of type `eltTy`, appending to
/// `floatValues`.
///
/// Hex integer literals (`0x...`) are reinterpreted as bit patterns through
/// Parser::parseFloatFromIntegerLiteral. Float literals are parsed as double
/// and, unless `eltTy` is f64, converted with round-to-nearest-even. Errors
/// are reported at the offending element's token. `loc` is currently unused.
ParseResult
TensorLiteralParser::getFloatAttrElements(SMLoc loc, FloatType eltTy,
                                          std::vector<APFloat> &floatValues) {
  floatValues.reserve(storage.size());
  for (const auto &[isNegative, token] : storage) {
    SMLoc tokenLoc = token.getLoc();

    if (token.is(Token::integer) && token.getSpelling().starts_with("0x")) {
      std::optional<APFloat> result;
      if (failed(p.parseFloatFromIntegerLiteral(result, token, isNegative,
                                                eltTy.getFloatSemantics(),
                                                eltTy.getWidth())))
        return failure();
      floatValues.push_back(*result);
      continue;
    }

    // Report at the element itself rather than at whatever token follows the
    // literal's type.
    if (token.is(Token::string))
      return p.emitError(tokenLoc)
             << "expected floating-point elements, but parsed string";
    if (!token.is(Token::floatliteral))
      return p.emitError(tokenLoc)
             << "expected floating-point elements, but parsed integer";

    std::optional<double> val = token.getFloatingPointValue();
    if (!val)
      return p.emitError(tokenLoc,
                         "floating point value too large for attribute");

    APFloat apVal(isNegative ? -*val : *val);
    if (!eltTy.isF64()) {
      bool losesInfo;
      apVal.convert(eltTy.getFloatSemantics(), APFloat::rmNearestTiesToEven,
                    &losesInfo);
    }
    floatValues.push_back(apVal);
  }
  return success();
}
/// Build a DenseStringElementsAttr of `type` from the parsed literal.
///
/// A hex payload string is used as-is for a single value. Otherwise every
/// parsed token must be a string literal. Numeric or boolean tokens are
/// diagnosed instead of being passed to Token::getStringValue, which does
/// not handle them. `eltTy` is only used in the diagnostic; `loc` is unused.
DenseElementsAttr TensorLiteralParser::getStringAttr(SMLoc loc, ShapedType type,
                                                     Type eltTy) {
  if (hexStorage.has_value()) {
    auto stringValue = hexStorage->getStringValue();
    return DenseStringElementsAttr::get(type, {stringValue});
  }

  std::vector<std::string> stringValues;
  stringValues.reserve(storage.size());
  for (const auto &signAndToken : storage) {
    const Token &token = signAndToken.second;
    if (!token.is(Token::string)) {
      p.emitError(token.getLoc())
          << "expected string elements for element type " << eltTy;
      return nullptr;
    }
    stringValues.push_back(token.getStringValue());
  }

  // Create the views only after every string is stored, so no later growth
  // of `stringValues` can invalidate them.
  std::vector<StringRef> stringRefValues(stringValues.begin(),
                                         stringValues.end());
  return DenseStringElementsAttr::get(type, stringRefValues);
}
/// Build a DenseElementsAttr of `type` from the hex payload in `hexStorage`.
///
/// The element type must be integer, index, float, or complex, and the
/// decoded byte count must pass DenseElementsAttr::isValidRawBuffer. On
/// big-endian hosts the bytes are first passed through
/// convertEndianOfArrayRefForBEmachine (presumably because the textual
/// payload uses little-endian order; confirm against the printer).
DenseElementsAttr TensorLiteralParser::getHexAttr(SMLoc loc, ShapedType type) {
  Type elementType = type.getElementType();
  const bool hasSupportedElementType =
      elementType.isIntOrIndexOrFloat() || isa<ComplexType>(elementType);
  if (!hasSupportedElementType) {
    p.emitError(loc)
        << "expected floating-point, integer, or complex element type, got "
        << elementType;
    return nullptr;
  }

  std::string bytes;
  if (parseElementAttrHexValues(p, *hexStorage, bytes))
    return nullptr;

  ArrayRef<char> rawData(bytes.data(), bytes.size());
  bool detectedSplat = false;
  if (!DenseElementsAttr::isValidRawBuffer(type, rawData, detectedSplat)) {
    p.emitError(loc) << "elements hex data size is invalid for provided type: "
                     << type;
    return nullptr;
  }

  if (llvm::endianness::native != llvm::endianness::big)
    return DenseElementsAttr::getFromRawBuffer(type, rawData);

  SmallVector<char, 64> swappedBytes(rawData.size());
  MutableArrayRef<char> swappedView(swappedBytes);
  DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
      rawData, swappedView, type);
  return DenseElementsAttr::getFromRawBuffer(type, swappedView);
}
/// Parse one literal element into `storage`:
///   * `true`, `false`, a float or integer literal, or a string;
///   * `-` followed by a float or integer literal (recorded as negative);
///   * `(a, b)`, which parses two elements back to back (complex values).
ParseResult TensorLiteralParser::parseElement() {
  switch (p.getToken().getKind()) {
  case Token::kw_true:
  case Token::kw_false:
  case Token::floatliteral:
  case Token::integer:
  case Token::string:
    storage.emplace_back(/*isNegative=*/false, p.getToken());
    p.consumeToken();
    return success();

  case Token::minus:
    p.consumeToken(Token::minus);
    if (!p.getToken().isAny(Token::floatliteral, Token::integer))
      return p.emitError("expected integer or floating point literal");
    storage.emplace_back(/*isNegative=*/true, p.getToken());
    p.consumeToken();
    return success();

  case Token::l_paren: {
    p.consumeToken(Token::l_paren);
    const bool malformed =
        parseElement() ||
        p.parseToken(Token::comma, "expected ',' between complex elements") ||
        parseElement() ||
        p.parseToken(Token::r_paren, "expected ')' after complex elements");
    return malformed ? failure() : success();
  }

  default:
    return p.emitError("expected element literal of primitive type");
  }
}
/// Parse a `[`-delimited, comma-separated list whose items are nested lists
/// or single elements. On success `dims` is set to the number of items
/// followed by the dimensions of the first item. Every item must have the
/// same nested dimensions as the first one.
ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) {
  SmallVector<int64_t, 4> innerDims;
  bool haveInnerDims = false;
  unsigned itemCount = 0;

  auto parseItem = [&]() -> ParseResult {
    SmallVector<int64_t, 4> itemDims;
    ParseResult parsed = p.getToken().is(Token::l_square) ? parseList(itemDims)
                                                          : parseElement();
    if (failed(parsed))
      return failure();
    ++itemCount;

    // The first item fixes the expected nested dimensions.
    if (!haveInnerDims) {
      innerDims = itemDims;
      haveInnerDims = true;
      return success();
    }
    if (innerDims != itemDims)
      return p.emitError("tensor literal is invalid; ranks are not consistent "
                         "between elements");
    return success();
  };

  if (failed(p.parseCommaSeparatedList(Parser::Delimiter::Square, parseItem)))
    return failure();

  dims.clear();
  dims.push_back(itemCount);
  dims.append(innerDims.begin(), innerDims.end());
  return success();
}
namespace {
/// Accumulates the elements of an `array<type: ...>` literal as raw bytes of
/// element type `type` and builds the resulting DenseArrayAttr.
class DenseArrayElementParser {
public:
explicit DenseArrayElementParser(Type type) : type(type) {}
/// Parse one integer (or `true`/`false` for i1) element and append it.
ParseResult parseIntegerElement(Parser &p);
/// Parse one floating-point element (float or hex-integer literal) and
/// append it.
ParseResult parseFloatElement(Parser &p);
/// Build the attribute from the elements parsed so far.
DenseArrayAttr getAttr() { return DenseArrayAttr::get(type, size, rawData); }
private:
/// Append the bytes of `data` and count it as one element.
void append(const APInt &data);
/// Element type of the array.
Type type;
/// Concatenated element bytes.
std::vector<char> rawData;
/// Number of elements appended so far.
int64_t size = 0;
};
}
/// Append `data` to the raw element buffer and count it as one element.
/// The bit width must be a multiple of 8. A zero-width value adds no bytes
/// but is still counted.
void DenseArrayElementParser::append(const APInt &data) {
  const unsigned bitWidth = data.getBitWidth();
  if (bitWidth != 0) {
    assert(bitWidth % 8 == 0);
    const unsigned numBytes = bitWidth / 8;
    const size_t start = rawData.size();
    rawData.resize(start + numBytes, 0);
    llvm::StoreIntToMemory(data, reinterpret_cast<uint8_t *>(&rawData[start]),
                           numBytes);
  }
  ++size;
}
/// Parse one element of an integer dense array and append it.
///
/// Accepts an optionally negated integer literal, or `true`/`false` when the
/// element type is i1 (stored as 8-bit values). As in the other integer
/// literal paths of this parser, a negative literal is rejected for an
/// unsigned element type, and a '-' in front of `true`/`false` is an error
/// rather than being silently dropped. Range errors are reported at the
/// literal, not at the token after it.
ParseResult DenseArrayElementParser::parseIntegerElement(Parser &p) {
  bool isNegative = p.consumeIf(Token::minus);
  Token literal = p.getToken();
  SMLoc literalLoc = literal.getLoc();

  std::optional<APInt> value;
  if (literal.isAny(Token::kw_true, Token::kw_false)) {
    if (!type.isInteger(1))
      return p.emitError("expected i1 type for 'true' or 'false' values");
    if (isNegative)
      return p.emitError(literalLoc, "expected integer literal after '-'");
    value = APInt(8, literal.is(Token::kw_true), !type.isUnsignedInteger());
    p.consumeToken();
  } else if (p.consumeIf(Token::integer)) {
    if (isNegative && type.isUnsignedInteger())
      return p.emitError(
          literalLoc,
          "negative integer literal not valid for unsigned integer type");
    value = buildAttributeAPInt(type, isNegative, literal.getSpelling());
    if (!value)
      return p.emitError(literalLoc, "integer constant out of range");
  } else {
    return p.emitError("expected integer literal");
  }
  append(*value);
  return success();
}
/// Parse one element of a floating-point dense array and append its bits.
///
/// Accepts an optionally negated integer literal, which is interpreted
/// through Parser::parseFloatFromIntegerLiteral, or a float literal. A float
/// literal is parsed as double and, unless the element type is f64,
/// converted with round-to-nearest-even.
ParseResult DenseArrayElementParser::parseFloatElement(Parser &p) {
  bool isNegative = p.consumeIf(Token::minus);
  Token token = p.getToken();
  std::optional<APFloat> result;
  auto floatType = cast<FloatType>(type);
  if (p.consumeIf(Token::integer)) {
    if (p.parseFloatFromIntegerLiteral(result, token, isNegative,
                                       floatType.getFloatSemantics(),
                                       floatType.getWidth()))
      return failure();
  } else if (p.consumeIf(Token::floatliteral)) {
    std::optional<double> val = token.getFloatingPointValue();
    // Previously this returned failure without any diagnostic, so parsing
    // failed silently. Report it the same way parseFloatAttr does.
    if (!val)
      return p.emitError(token.getLoc(),
                         "floating point value too large for attribute");
    result = APFloat(isNegative ? -*val : *val);
    if (!type.isF64()) {
      bool losesInfo;
      result->convert(floatType.getFloatSemantics(),
                      APFloat::rmNearestTiesToEven, &losesInfo);
    }
  } else {
    return p.emitError("expected integer or floating point literal");
  }
  append(result->bitcastToAPInt());
  return success();
}
/// Parse a dense array attribute:
///
///   `array` `<` element-type (`:` element (`,` element)*)? `>`
///
/// The element type must be an integer or float type whose bit width is a
/// multiple of 8, or i1. `attrType` is not used by this parser.
Attribute Parser::parseDenseArrayAttr(Type attrType) {
  consumeToken(Token::kw_array);
  if (parseToken(Token::less, "expected '<' after 'array'"))
    return {};

  SMLoc typeLoc = getToken().getLoc();
  Type eltType = parseType();
  if (!eltType) {
    emitError(typeLoc, "expected an integer or floating point type");
    return {};
  }
  // `index` used to pass an isIntOrIndexOrFloat() check and then reach
  // getIntOrFloatBitWidth(), which is only valid for integer and float types.
  // Reject it here with a diagnostic instead.
  if (!eltType.isIntOrFloat()) {
    emitError(typeLoc, "expected integer or float type, got: ") << eltType;
    return {};
  }
  if (!eltType.isInteger(1) && eltType.getIntOrFloatBitWidth() % 8 != 0) {
    emitError(typeLoc, "element type bitwidth must be a multiple of 8");
    return {};
  }

  // `array<type>` is the empty array.
  if (consumeIf(Token::greater))
    return DenseArrayAttr::get(eltType, 0, {});
  if (parseToken(Token::colon, "expected ':' after dense array type"))
    return {};

  DenseArrayElementParser eltParser(eltType);
  if (isa<IntegerType>(eltType)) {
    if (parseCommaSeparatedList(
            [&] { return eltParser.parseIntegerElement(*this); }))
      return {};
  } else {
    if (parseCommaSeparatedList(
            [&] { return eltParser.parseFloatElement(*this); }))
      return {};
  }
  if (parseToken(Token::greater, "expected '>' to close an array attribute"))
    return {};
  return eltParser.getAttr();
}
/// Parse `dense<literal>` followed, unless `attrType` is given, by
/// `: shaped-type`. `dense<>` with no literal is accepted.
///
/// Errors from building the attribute are reported at the `dense` keyword
/// when the caller supplied the type, and otherwise at the token following
/// the closing '>' (the start of the `: type` suffix).
Attribute Parser::parseDenseElementsAttr(Type attrType) {
  SMLoc keywordLoc = getToken().getLoc();
  consumeToken(Token::kw_dense);
  if (parseToken(Token::less, "expected '<' after 'dense'"))
    return nullptr;

  TensorLiteralParser literalParser(*this);
  const bool emptyLiteral = consumeIf(Token::greater);
  if (!emptyLiteral && (literalParser.parse(/*allowHex=*/true) ||
                        parseToken(Token::greater, "expected '>'")))
    return nullptr;

  SMLoc diagLoc = attrType ? keywordLoc : getToken().getLoc();
  ShapedType type = parseElementsLiteralType(attrType);
  if (!type)
    return nullptr;
  return literalParser.getAttr(diagLoc, type);
}
/// Parse `dense_resource<handle>` followed, unless `attrType` is given, by
/// `: type`. The handle is resolved through the builtin dialect and must be
/// a DenseResourceElementsHandle; the type must be a ShapedType.
Attribute Parser::parseDenseResourceElementsAttr(Type attrType) {
  SMLoc keywordLoc = getToken().getLoc();
  consumeToken(Token::kw_dense_resource);
  if (parseToken(Token::less, "expected '<' after 'dense_resource'"))
    return nullptr;

  auto *builtinDialect = getContext()->getLoadedDialect<BuiltinDialect>();
  FailureOr<AsmDialectResourceHandle> rawHandle =
      parseResourceHandle(builtinDialect);
  if (failed(rawHandle) || parseToken(Token::greater, "expected '>'"))
    return nullptr;

  auto *handle = dyn_cast<DenseResourceElementsHandle>(&*rawHandle);
  if (!handle) {
    emitError(keywordLoc, "invalid `dense_resource` handle type");
    return nullptr;
  }

  // Type errors point at the keyword for a caller-supplied type, or at the
  // ':' of an explicit suffix.
  SMLoc typeLoc = keywordLoc;
  if (!attrType) {
    typeLoc = getToken().getLoc();
    if (parseToken(Token::colon, "expected ':'"))
      return nullptr;
    attrType = parseType();
    if (!attrType)
      return nullptr;
  }

  auto shapedType = dyn_cast<ShapedType>(attrType);
  if (!shapedType) {
    emitError(typeLoc, "`dense_resource` expected a shaped type");
    return nullptr;
  }
  return DenseResourceElementsAttr::get(shapedType, *handle);
}
/// Resolve the type of an elements literal. When `type` is null, parse it
/// from a `: type` suffix. The result must be a ShapedType with a static
/// shape; otherwise an error is emitted and null is returned.
ShapedType Parser::parseElementsLiteralType(Type type) {
  if (!type &&
      (parseToken(Token::colon, "expected ':'") || !(type = parseType())))
    return nullptr;

  auto shapedType = dyn_cast<ShapedType>(type);
  if (!shapedType) {
    emitError("elements literal must be a shaped type");
    return nullptr;
  }
  if (!shapedType.hasStaticShape()) {
    emitError("elements literal type must have static shape");
    return nullptr;
  }
  return shapedType;
}
/// Parse `sparse<indices, values>` (or the empty form `sparse<>`) followed,
/// unless `attrType` is given, by `: shaped-type`.
///
/// Indices are i64. A bare (non-list) indices literal is treated as a single
/// index of shape [1, rank]. A bare values literal is a splat over the number
/// of indices. Returns null after a diagnostic if either payload does not
/// fit its inferred type.
Attribute Parser::parseSparseElementsAttr(Type attrType) {
  SMLoc loc = getToken().getLoc();
  consumeToken(Token::kw_sparse);
  if (parseToken(Token::less, "Expected '<' after 'sparse'"))
    return nullptr;

  Type indiceEltType = builder.getIntegerType(64);

  // `sparse<>`: no indices and no values.
  if (consumeIf(Token::greater)) {
    ShapedType type = parseElementsLiteralType(attrType);
    if (!type)
      return nullptr;
    ShapedType indicesType =
        RankedTensorType::get({0, type.getRank()}, indiceEltType);
    ShapedType valuesType = RankedTensorType::get({0}, type.getElementType());
    return getChecked<SparseElementsAttr>(
        loc, type, DenseElementsAttr::get(indicesType, ArrayRef<Attribute>()),
        DenseElementsAttr::get(valuesType, ArrayRef<Attribute>()));
  }

  auto indicesLoc = getToken().getLoc();
  TensorLiteralParser indiceParser(*this);
  if (indiceParser.parse(/*allowHex=*/false))
    return nullptr;
  if (parseToken(Token::comma, "expected ','"))
    return nullptr;

  auto valuesLoc = getToken().getLoc();
  TensorLiteralParser valuesParser(*this);
  if (valuesParser.parse(/*allowHex=*/true))
    return nullptr;
  if (parseToken(Token::greater, "expected '>'"))
    return nullptr;

  auto type = parseElementsLiteralType(attrType);
  if (!type)
    return nullptr;

  // A bare indices literal denotes exactly one index with one coordinate
  // per dimension of `type`; otherwise use the parsed shape.
  ShapedType indicesType;
  if (indiceParser.getShape().empty()) {
    indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType);
  } else {
    indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType);
  }
  auto indices = indiceParser.getAttr(indicesLoc, indicesType);
  // getAttr has already reported the problem; do not pass a null attribute
  // on to SparseElementsAttr.
  if (!indices)
    return nullptr;

  // A bare values literal is a splat with one entry per index.
  auto valuesEltType = type.getElementType();
  ShapedType valuesType =
      valuesParser.getShape().empty()
          ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType)
          : RankedTensorType::get(valuesParser.getShape(), valuesEltType);
  auto values = valuesParser.getAttr(valuesLoc, valuesType);
  if (!values)
    return nullptr;

  return getChecked<SparseElementsAttr>(loc, type, indices, values);
}
/// Parse a strided layout:
///
///   `strided` `<` `[` (value (`,` value)*)? `]` (`,` `offset` `:` value)? `>`
///
/// Each value is `?` (ShapedType::kDynamic) or an optionally negated integer
/// no larger in magnitude than INT64_MAX. The offset defaults to 0. The
/// result is checked with StridedLayoutAttr::verify, whose errors are
/// reported at the `strided` keyword.
Attribute Parser::parseStridedLayoutAttr() {
  llvm::SMLoc keywordLoc = getToken().getLoc();
  auto errorEmitter = [&] { return emitError(keywordLoc); };

  consumeToken(Token::kw_strided);
  if (failed(parseToken(Token::less, "expected '<' after 'strided'")) ||
      failed(parseToken(Token::l_square, "expected '['")))
    return nullptr;

  auto parseStrideOrOffset = [&]() -> std::optional<int64_t> {
    if (consumeIf(Token::question))
      return ShapedType::kDynamic;

    SMLoc valueLoc = getToken().getLoc();
    auto reportInvalid = [&]() -> std::optional<int64_t> {
      emitError(valueLoc, "expected a 64-bit signed integer or '?'");
      return std::nullopt;
    };

    const bool negative = consumeIf(Token::minus);
    if (getToken().isNot(Token::integer))
      return reportInvalid();

    std::optional<uint64_t> magnitude = getToken().getUInt64IntegerValue();
    constexpr uint64_t kMaxMagnitude =
        static_cast<uint64_t>(std::numeric_limits<int64_t>::max());
    if (!magnitude || *magnitude > kMaxMagnitude)
      return reportInvalid();
    consumeToken();

    const auto value = static_cast<int64_t>(*magnitude);
    return negative ? -value : value;
  };

  SmallVector<int64_t> strides;
  if (getToken().isNot(Token::r_square)) {
    while (true) {
      std::optional<int64_t> stride = parseStrideOrOffset();
      if (!stride)
        return nullptr;
      strides.push_back(*stride);
      if (!consumeIf(Token::comma))
        break;
    }
  }
  if (failed(parseToken(Token::r_square, "expected ']'")))
    return nullptr;

  int64_t offset = 0;
  if (!consumeIf(Token::greater)) {
    if (failed(parseToken(Token::comma, "expected ','")) ||
        failed(parseToken(Token::kw_offset, "expected 'offset' after comma")) ||
        failed(parseToken(Token::colon, "expected ':' after 'offset'")))
      return nullptr;
    std::optional<int64_t> parsedOffset = parseStrideOrOffset();
    if (!parsedOffset || failed(parseToken(Token::greater, "expected '>'")))
      return nullptr;
    offset = *parsedOffset;
  }

  if (failed(StridedLayoutAttr::verify(errorEmitter, offset, strides)))
    return nullptr;
  return StridedLayoutAttr::get(getContext(), offset, strides);
}
/// Parse `distinct[ID]<attr>` or `distinct[ID]<>`. An empty body references
/// a UnitAttr.
///
/// IDs are looked up in `state.symbols.distinctAttributes`. The first use of
/// an ID creates a new DistinctAttr. Later uses return the same DistinctAttr,
/// but must reference an identical attribute. `type` is forwarded to the
/// referenced attribute's parser.
Attribute Parser::parseDistinctAttr(Type type) {
  SMLoc loc = getToken().getLoc();
  consumeToken(Token::kw_distinct);
  if (parseToken(Token::l_square, "expected '[' after 'distinct'"))
    return {};

  Token token = getToken();
  if (parseToken(Token::integer, "expected distinct ID"))
    return {};
  std::optional<uint64_t> value = token.getUInt64IntegerValue();
  if (!value) {
    // The ID token has already been consumed; point the error at the ID
    // itself rather than at the token that follows it.
    emitError(token.getLoc(), "expected an unsigned 64-bit integer");
    return {};
  }
  if (parseToken(Token::r_square, "expected ']' to close distinct ID") ||
      parseToken(Token::less, "expected '<' after distinct ID"))
    return {};

  Attribute referencedAttr;
  if (getToken().is(Token::greater)) {
    consumeToken();
    referencedAttr = builder.getUnitAttr();
  } else {
    referencedAttr = parseAttribute(type);
    if (!referencedAttr) {
      emitError("expected attribute");
      return {};
    }
    if (parseToken(Token::greater, "expected '>' to close distinct attribute"))
      return {};
  }

  DenseMap<uint64_t, DistinctAttr> &distinctAttrs =
      state.symbols.distinctAttributes;
  auto it = distinctAttrs.find(*value);
  if (it == distinctAttrs.end()) {
    DistinctAttr distinctAttr = DistinctAttr::create(referencedAttr);
    it = distinctAttrs.try_emplace(*value, distinctAttr).first;
  } else if (it->getSecond().getReferencedAttr() != referencedAttr) {
    emitError(loc, "referenced attribute does not match previous definition: ")
        << it->getSecond().getReferencedAttr();
    return {};
  }
  return it->getSecond();
}