#include "AsmParserImpl.h"
#include "Parser.h"
#include "mlir/AsmParser/AsmParserState.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SourceMgr.h"
#include <cassert>
#include <cstddef>
#include <utility>
using namespace mlir;
using namespace mlir::detail;
using llvm::MemoryBuffer;
using llvm::SourceMgr;
namespace {
/// DialectAsmParser handed to a dialect's parseAttribute/parseType hook.
///
/// All generic parsing behavior comes from AsmParserImpl, which drives the
/// shared `Parser`. This class only adds the raw text of the symbol being
/// parsed, exposed through getFullSymbolSpec().
class CustomDialectAsmParser : public AsmParserImpl<DialectAsmParser> {
public:
  /// Create a parser anchored at `parser`'s current token location.
  /// `symbolSpec` is the raw symbol text reported by getFullSymbolSpec().
  CustomDialectAsmParser(StringRef symbolSpec, Parser &parser)
      : AsmParserImpl<DialectAsmParser>(parser.getToken().getLoc(), parser),
        symbolSpec(symbolSpec) {}

  ~CustomDialectAsmParser() override = default;

  /// Return the raw text of the symbol currently being parsed.
  StringRef getFullSymbolSpec() const override { return symbolSpec; }

private:
  /// Non-owning view of the symbol text; the referenced characters must
  /// outlive this parser.
  StringRef symbolSpec;
};
} // namespace
/// Scan the raw text of a dialect symbol body that begins at the current `<`
/// token. Characters are consumed until every opening `<`, `[`, `(` and `{`
/// has been closed by its matching counterpart.
///
/// On success:
///  - `body` keeps its original start pointer and is extended to end just past
///    the last scanned character (the final closing `>`, or the
///    code-completion point);
///  - the lexer is reset to resume right after that character.
/// Callers in this file pass a `body` whose start is at or before the
/// opening `<` in the same buffer.
///
/// `isCodeCompletion` is only ever set to true here. That happens when the
/// scan reaches the lexer's code-completion location, either directly or
/// inside a string. The caller must initialize it.
///
/// Returns failure in three cases:
///  - a mismatched closing character;
///  - a NUL character before the body is closed;
///  - a `"` that does not lex as a string token.
ParseResult Parser::parseDialectSymbolBody(StringRef &body,
bool &isCodeCompletion) {
// Work directly on the source buffer, starting at the `<` token.
const char *curPtr = getTokenSpelling().data();
assert(*curPtr == '<');
// Stack of currently open punctuation characters, innermost last.
SmallVector<char, 8> nestedPunctuation;
const char *codeCompleteLoc = state.lex.getCodeCompleteLoc();
// Report the innermost still-open punctuation as unbalanced.
auto emitPunctError = [&] {
return emitError() << "unbalanced '" << nestedPunctuation.back()
<< "' character in pretty dialect name";
};
// Pop the stack if `expectedToken` is the innermost open punctuation;
// otherwise emit the unbalanced-punctuation error.
auto checkNestedPunctuation = [&](char expectedToken) -> ParseResult {
if (nestedPunctuation.back() != expectedToken)
return emitPunctError();
nestedPunctuation.pop_back();
return success();
};
// The first iteration always pushes the leading '<', so the loop runs until
// that bracket (and everything nested in it) is closed.
do {
// Stop at the code-completion point. Clearing the stack makes the loop
// condition end the scan.
if (curPtr == codeCompleteLoc) {
isCodeCompletion = true;
nestedPunctuation.clear();
break;
}
char c = *curPtr++;
switch (c) {
case '\0':
// A NUL character ends the scan with an error; the diagnostic text
// covers EOF as well.
if (!nestedPunctuation.empty())
return emitPunctError();
return emitError("unexpected nul or EOF in pretty dialect name");
case '<':
case '[':
case '(':
case '{':
nestedPunctuation.push_back(c);
continue;
case '-':
// Treat `->` as a unit so its `>` does not close an open `<`.
if (*curPtr == '>')
++curPtr;
continue;
case '>':
if (failed(checkNestedPunctuation('<')))
return failure();
break;
case ']':
if (failed(checkNestedPunctuation('[')))
return failure();
break;
case ')':
if (failed(checkNestedPunctuation('(')))
return failure();
break;
case '}':
if (failed(checkNestedPunctuation('{')))
return failure();
break;
case '"': {
// Let the real lexer skip over the string literal so that punctuation
// inside quotes does not affect nesting.
resetToken(curPtr - 1);
curPtr = state.curToken.getEndLoc().getPointer();
// A code completion inside the string also ends the scan; clearing the
// stack makes the loop condition exit.
if (state.curToken.isCodeCompletion()) {
isCodeCompletion = true;
nestedPunctuation.clear();
break;
}
// NOTE(review): no diagnostic is emitted here; presumably the lexer
// already reported the malformed string when it was lexed. Confirm.
if (state.curToken.isNot(Token::string))
return failure();
break;
}
default:
// Any other character is simply part of the body.
continue;
}
} while (!nestedPunctuation.empty());
// Resume lexing right after the scanned text and extend `body` to cover it.
resetToken(curPtr);
unsigned length = curPtr - body.begin();
body = StringRef(body.data(), length);
return success();
}
/// Parse an extended symbol whose current token is a prefixed identifier
/// (e.g. `!foo...` for types or `#foo...` for attributes).
///
/// The identifier (the token spelling without its first character) takes one
/// of three forms:
///  - `alias`: no `.` and no immediately-adjacent `<`. It is looked up in
///    `aliases`; if found, its use is recorded in `asmState` (when non-null)
///    and the aliased value is returned.
///  - `dialect<body>`: the verbose form. The bracketed body is scanned and
///    the outer `<`/`>` are stripped. The closing `>` is kept if a code
///    completion cut the body short.
///  - `dialect.body`, optionally followed directly by `<...>`: the pretty
///    form. The reported location becomes the start of `body`.
///
/// For the two dialect forms, the result of
/// `createSymbol(dialectName, symbolData, loc)` is returned. A null symbol is
/// returned on any error.
template <typename Symbol, typename SymbolAliasMap, typename CreateFn>
static Symbol parseExtendedSymbol(Parser &p, AsmParserState *asmState,
                                  SymbolAliasMap &aliases,
                                  CreateFn &&createSymbol) {
  Token leadTok = p.getToken();

  // Drop the leading sigil to get the identifier text.
  StringRef identifier = leadTok.getSpelling().drop_front();

  // A lone sigil under code completion: offer symbol completions instead.
  if (leadTok.isCodeCompletion() && identifier.empty())
    return p.codeCompleteDialectSymbol(aliases);

  SMRange identRange = leadTok.getLocRange();
  SMLoc symbolLoc = leadTok.getLoc();
  p.consumeToken();

  auto [dialectName, symbolData] = identifier.split('.');
  bool isPrettyName = !symbolData.empty() || identifier.back() == '.';

  // Trailing data means a `<` token that starts exactly where the identifier
  // ends, with nothing in between.
  bool hasTrailingData =
      p.getToken().is(Token::less) &&
      identifier.bytes_end() == p.getTokenSpelling().bytes_begin();

  // Neither a dotted name nor an attached `<...>`: this names an alias.
  bool isAliasReference = !isPrettyName && !hasTrailingData;
  if (isAliasReference) {
    auto aliasIt = aliases.find(identifier);
    if (aliasIt == aliases.end()) {
      p.emitWrongTokenError("undefined symbol alias id '" + identifier + "'");
      return nullptr;
    }
    if (asmState) {
      if constexpr (std::is_same_v<Symbol, Type>)
        asmState->addTypeAliasUses(identifier, identRange);
      else
        asmState->addAttrAliasUses(identifier, identRange);
    }
    return aliasIt->second;
  }

  if (isPrettyName) {
    // Pretty form: report the location of the body rather than the sigil, and
    // absorb an immediately following `<...>` into the body text.
    symbolLoc = SMLoc::getFromPointer(symbolData.data());
    if (hasTrailingData && p.parseDialectSymbolBody(symbolData))
      return nullptr;
  } else {
    // Verbose form: the body starts right after the dialect name, at `<`.
    symbolData = StringRef(dialectName.end(), 0);
    bool sawCodeCompletion = false;
    if (p.parseDialectSymbolBody(symbolData, sawCodeCompletion))
      return nullptr;
    // Strip the opening `<`. The closing `>` is stripped too, unless a code
    // completion ended the body before it.
    symbolData = symbolData.drop_front();
    if (!sawCodeCompletion)
      symbolData = symbolData.drop_back();
  }

  return createSymbol(dialectName, symbolData, symbolLoc);
}
/// Parse an extended (`#`-prefixed) attribute, or an attribute alias.
///
/// A dialect attribute may be followed by `: type`, which overrides `type` as
/// the type handed to the dialect (or stored in the opaque fallback).
///  - If the dialect is available through getOrLoadDialect, its
///    parseAttribute hook parses the body text.
///  - Otherwise an OpaqueAttr is built from the raw body, typed with the
///    resolved type or NoneType.
///
/// When `type` is non-null and the result is a TypedAttr of a different type,
/// an error is emitted and null is returned.
Attribute Parser::parseExtendedAttr(Type type) {
  MLIRContext *ctx = getContext();

  // Builds the attribute once the dialect namespace and body text are known.
  auto buildAttr = [&](StringRef dialectName, StringRef symbolData,
                       SMLoc loc) -> Attribute {
    // An optional trailing `: type` overrides the expected type.
    Type attrType = type;
    if (consumeIf(Token::colon)) {
      attrType = parseType();
      if (!attrType)
        return Attribute();
    }

    Dialect *dialect = builder.getContext()->getOrLoadDialect(dialectName);
    if (!dialect) {
      // No dialect to defer to: keep the body verbatim as an opaque attribute.
      return OpaqueAttr::getChecked(
          [&] { return emitError(loc); }, StringAttr::get(ctx, dialectName),
          symbolData, attrType ? attrType : NoneType::get(ctx));
    }

    // Temporarily point the lexer at the body so the dialect hook parses it,
    // then put the lexer back where the outer parser left off.
    const char *resumePos = getToken().getLoc().getPointer();
    resetToken(symbolData.data());
    CustomDialectAsmParser customParser(symbolData, *this);
    Attribute parsed = dialect->parseAttribute(customParser, attrType);
    resetToken(resumePos);
    return parsed;
  };

  Attribute attr = parseExtendedSymbol<Attribute>(
      *this, state.asmState, state.symbols.attributeAliasDefinitions,
      buildAttr);

  // Without an expected type there is nothing further to check.
  if (!type)
    return attr;

  auto typedAttr = dyn_cast_or_null<TypedAttr>(attr);
  if (typedAttr && typedAttr.getType() != type) {
    emitError("attribute type different than expected: expected ")
        << type << ", but got " << typedAttr.getType();
    return nullptr;
  }
  return attr;
}
/// Parse an extended (`!`-prefixed) type, or a type alias.
///
/// If the named dialect is available through getOrLoadDialect, its parseType
/// hook parses the body text. Otherwise an OpaqueType holding the raw body is
/// returned. Returns null on failure.
Type Parser::parseExtendedType() {
  MLIRContext *ctx = getContext();

  // Builds the type once the dialect namespace and body text are known.
  auto buildType = [&](StringRef dialectName, StringRef symbolData,
                       SMLoc loc) -> Type {
    auto *dialect = ctx->getOrLoadDialect(dialectName);
    if (!dialect) {
      // No dialect to defer to: keep the body verbatim as an opaque type.
      return OpaqueType::getChecked([&] { return emitError(loc); },
                                    StringAttr::get(ctx, dialectName),
                                    symbolData);
    }

    // Temporarily point the lexer at the body so the dialect hook parses it,
    // then put the lexer back where the outer parser left off.
    const char *resumePos = getToken().getLoc().getPointer();
    resetToken(symbolData.data());
    CustomDialectAsmParser customParser(symbolData, *this);
    Type parsed = dialect->parseType(customParser);
    resetToken(resumePos);
    return parsed;
  };

  return parseExtendedSymbol<Type>(*this, state.asmState,
                                   state.symbols.typeAliasDefinitions,
                                   buildType);
}
/// Parse a single symbol (attribute or type) from `inputStr` with `parserFn`,
/// using a throwaway parser state bound to `context`.
///
/// `inputStr` is also used as the buffer name, so it appears in diagnostics.
/// When `isKnownNullTerminated` is true, the string is lexed in place;
/// otherwise it is first copied into an owned buffer.
///
/// If `numReadOut` is non-null, it receives the number of characters from the
/// start of `inputStr` up to the first token after the parsed symbol. If it is
/// null, that count must equal `inputStr.size()`; otherwise a
/// "found trailing characters" error is emitted and a null `T` is returned.
///
/// Returns a null `T` on any failure.
template <typename T, typename ParserFn>
static T parseSymbol(StringRef inputStr, MLIRContext *context,
                     size_t *numReadOut, bool isKnownNullTerminated,
                     ParserFn &&parserFn) {
  auto memBuffer =
      isKnownNullTerminated
          ? MemoryBuffer::getMemBuffer(inputStr,
                                       /*BufferName=*/inputStr)
          : MemoryBuffer::getMemBufferCopy(inputStr, /*BufferName=*/inputStr);
  // Remember where the lexed text begins before ownership moves into the
  // SourceMgr. Token locations point into this buffer, which is a copy of
  // `inputStr` when the input is not known to be null-terminated, so offsets
  // must be taken relative to it rather than to `inputStr.data()`.
  const char *bufferStart = memBuffer->getBufferStart();
  SourceMgr sourceMgr;
  sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
  SymbolState aliasState;
  ParserConfig config(context);
  ParserState state(sourceMgr, config, aliasState, /*asmState=*/nullptr,
                    /*codeCompleteContext=*/nullptr);
  Parser parser(state);

  T symbol = parserFn(parser);
  if (!symbol)
    return T();

  // Measure from the start of the buffer, not from the first token. Anything
  // the lexer skipped before the first token (e.g. leading whitespace) was
  // also consumed. Measuring from the first token under-counted it, which
  // made inputs like " i32" fail with a bogus "trailing characters" error and
  // reported too small a count through `numReadOut`.
  Token endTok = parser.getToken();
  size_t numRead = endTok.getLoc().getPointer() - bufferStart;
  if (numReadOut) {
    *numReadOut = numRead;
  } else if (numRead != inputStr.size()) {
    parser.emitError(endTok.getLoc()) << "found trailing characters: '"
                                      << inputStr.drop_front(numRead) << "'";
    return T();
  }
  return symbol;
}
/// Parse `attrStr` as a single attribute in `context`, passing `type` through
/// to the attribute parser. If `numRead` is non-null, it receives the number
/// of characters consumed; otherwise the whole string must be consumed.
/// Returns null on failure.
Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context,
                               Type type, size_t *numRead,
                               bool isKnownNullTerminated) {
  auto parseOne = [type](Parser &parser) -> Attribute {
    return parser.parseAttribute(type);
  };
  return parseSymbol<Attribute>(attrStr, context, numRead,
                                isKnownNullTerminated, parseOne);
}
/// Parse `typeStr` as a single type in `context`. If `numRead` is non-null,
/// it receives the number of characters consumed; otherwise the whole string
/// must be consumed. Returns null on failure.
Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t *numRead,
                     bool isKnownNullTerminated) {
  auto parseOne = [](Parser &parser) -> Type { return parser.parseType(); };
  return parseSymbol<Type>(typeStr, context, numRead, isKnownNullTerminated,
                           parseOne);
}