#include "Lexer.h"
#include "mlir/Tools/PDLL/AST/Diagnostic.h"
#include "mlir/Tools/PDLL/Parser/CodeComplete.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/SourceMgr.h"
using namespace mlir;
using namespace mlir::pdll;
/// Return the decoded contents of a string-like token (`string`,
/// `string_block`, or `code_complete_string`).
///
/// The surrounding delimiters are removed (`"` for plain strings, `[{`/`}]`
/// for string blocks) and escape sequences are expanded: `\"`, `\\`, `\n`,
/// `\t`, and a two-hex-digit byte escape. The lexer is expected to have
/// rejected any malformed escape before this is called.
std::string Token::getStringValue() const {
  assert(getKind() == string || getKind() == string_block ||
         getKind() == code_complete_string);

  // Strip delimiters. Code-completion strings are formed by the lexer starting
  // after the opening delimiter and ending at the completion point, so they
  // have nothing to strip.
  StringRef bytes = getSpelling();
  if (is(string))
    bytes = bytes.drop_front().drop_back();
  else if (is(string_block))
    bytes = bytes.drop_front(2).drop_back(2);

  std::string decoded;
  decoded.reserve(bytes.size());

  const unsigned end = bytes.size();
  unsigned pos = 0;
  while (pos != end) {
    char ch = bytes[pos++];

    // Ordinary characters are copied through unchanged.
    if (ch != '\\') {
      decoded.push_back(ch);
      continue;
    }

    // Single-character escapes.
    assert(pos + 1 <= end && "invalid string should be caught by lexer");
    char first = bytes[pos++];
    if (first == '"' || first == '\\') {
      decoded.push_back(first);
      continue;
    }
    if (first == 'n') {
      decoded.push_back('\n');
      continue;
    }
    if (first == 't') {
      decoded.push_back('\t');
      continue;
    }

    // Anything else must be a two-digit hex escape encoding a single byte.
    assert(pos + 1 <= end && "invalid string should be caught by lexer");
    char second = bytes[pos++];
    assert(llvm::isHexDigit(first) && llvm::isHexDigit(second) &&
           "invalid escape");
    unsigned byteValue =
        (llvm::hexDigitValue(first) << 4) | llvm::hexDigitValue(second);
    decoded.push_back(static_cast<char>(byteValue));
  }
  return decoded;
}
/// Construct a lexer positioned at the start of `mgr`'s main file.
///
/// If `codeCompleteContext` is non-null, the lexer records its completion
/// location so that a code-completion token is produced when lexing reaches
/// that point. If `diagEngine` has no handler installed yet, a handler that
/// prints diagnostics (and their notes) through `mgr` is installed. It is
/// removed again by the destructor.
Lexer::Lexer(llvm::SourceMgr &mgr, ast::DiagnosticEngine &diagEngine,
             CodeCompleteContext *codeCompleteContext)
    : srcMgr(mgr), diagEngine(diagEngine), addedHandlerToDiagEngine(false),
      codeCompletionLocation(nullptr) {
  curBufferID = mgr.getMainFileID();
  curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer();
  curPtr = curBuffer.begin();

  // Remember where code completion was requested, if anywhere.
  if (codeCompleteContext) {
    codeCompletionLocation =
        codeCompleteContext->getCodeCompleteLoc().getPointer();
  }

  // Install a default printing handler only when none is present. The lambda
  // captures by reference (implicitly `this`, to reach `srcMgr`), so it must
  // not outlive this lexer; the destructor uninstalls it for that reason.
  if (!diagEngine.getHandlerFn()) {
    diagEngine.setHandlerFn([&](const ast::Diagnostic &diag) {
      srcMgr.PrintMessage(diag.getLocation().Start, diag.getSeverity(),
                          diag.getMessage());
      for (const ast::Diagnostic &note : diag.getNotes())
        srcMgr.PrintMessage(note.getLocation().Start, note.getSeverity(),
                            note.getMessage());
    });
    addedHandlerToDiagEngine = true;
  }
}
/// Destroy the lexer. If the constructor installed the default diagnostic
/// handler, clear it, because that handler refers to this lexer.
Lexer::~Lexer() {
  if (!addedHandlerToDiagEngine)
    return;
  diagEngine.setHandlerFn(nullptr);
}
/// Add `filename` as an include of the current buffer (at `includeLoc.End`)
/// and continue lexing from the start of the included buffer.
///
/// Returns failure, without changing the lexer position, when the source
/// manager returns buffer ID 0.
LogicalResult Lexer::pushInclude(StringRef filename, SMRange includeLoc) {
  // Out-parameter of AddIncludeFile; its value is not used afterwards.
  std::string resolvedPath;
  int newBufferID =
      srcMgr.AddIncludeFile(filename.str(), includeLoc.End, resolvedPath);
  if (newBufferID == 0)
    return failure();

  // Redirect lexing to the beginning of the included buffer.
  curBufferID = newBufferID;
  StringRef includedContents = srcMgr.getMemoryBuffer(curBufferID)->getBuffer();
  curBuffer = includedContents;
  curPtr = curBuffer.begin();
  return success();
}
/// Report `msg` as an error over `loc` and return an error token that starts
/// at the beginning of `loc`.
Token Lexer::emitError(SMRange loc, const Twine &msg) {
  diagEngine.emitError(loc, msg);
  const char *errorStart = loc.Start.getPointer();
  return formToken(Token::error, errorStart);
}
/// Report `msg` as an error over `loc` with an attached note `note` located at
/// `noteLoc`, and return an error token starting at the beginning of `loc`.
Token Lexer::emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc,
                              const Twine &note) {
  diagEngine.emitError(loc, msg)->attachNote(note, noteLoc);
  return formToken(Token::error, loc.Start.getPointer());
}
/// Report `msg` as an error covering the single character at `loc`.
Token Lexer::emitError(const char *loc, const Twine &msg) {
  SMLoc rangeBegin = SMLoc::getFromPointer(loc);
  SMLoc rangeEnd = SMLoc::getFromPointer(loc + 1);
  return emitError(SMRange(rangeBegin, rangeEnd), msg);
}
/// Consume and return the next character of the current buffer.
///
/// Returns EOF (without advancing) when the NUL at the end of the buffer is
/// reached; an embedded NUL elsewhere is returned as 0. Any of `\n`, `\r`,
/// `\r\n`, or `\n\r` is consumed as a single newline and reported as '\n'.
/// Other characters are returned as their unsigned-char value.
int Lexer::getNextChar() {
  const char ch = *curPtr++;

  if (ch == 0) {
    // Distinguish the end-of-buffer terminator from a NUL inside the file.
    if (curPtr - 1 == curBuffer.end()) {
      --curPtr;
      return EOF;
    }
    return 0;
  }

  if (ch == '\n' || ch == '\r') {
    // Fold a mixed two-character line ending into one newline, but never two
    // identical ones (those are two separate line breaks).
    const char following = *curPtr;
    if ((following == '\n' || following == '\r') && following != ch)
      ++curPtr;
    return '\n';
  }

  return static_cast<unsigned char>(ch);
}
/// Lex and return the next token.
///
/// Whitespace (space, tab, newline, embedded NUL) and `//` comments are
/// skipped. When the code-completion location is reached, a `code_complete`
/// token is returned. At the end of an included buffer, an `eof` token is
/// returned and the lexer is repositioned at the include location in the
/// parent buffer.
Token Lexer::lexToken() {
  while (true) {
    const char *tokStart = curPtr;

    // Stop at the code-completion point before consuming anything.
    if (tokStart == codeCompletionLocation)
      return formToken(Token::code_complete, tokStart);

    int curChar = getNextChar();
    switch (curChar) {
    default:
      // Identifiers start with a letter or underscore.
      if (isalpha(curChar) || curChar == '_')
        return lexIdentifier(tokStart);
      return emitError(tokStart, "unexpected character");
    case EOF: {
      Token eof = formToken(Token::eof, tokStart);

      // If this buffer was included, resume lexing in the parent buffer right
      // at the include location.
      SMLoc parentIncludeLoc = srcMgr.getParentIncludeLoc(curBufferID);
      if (parentIncludeLoc.isValid()) {
        curBufferID = srcMgr.FindBufferContainingLoc(parentIncludeLoc);
        curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer();
        curPtr = parentIncludeLoc.getPointer();
      }
      return eof;
    }
    case '-':
      if (*curPtr == '>') {
        ++curPtr;
        return formToken(Token::arrow, tokStart);
      }
      return emitError(tokStart, "unexpected character");
    case ':':
      return formToken(Token::colon, tokStart);
    case ',':
      return formToken(Token::comma, tokStart);
    case '.':
      return formToken(Token::dot, tokStart);
    case '=':
      if (*curPtr == '>') {
        ++curPtr;
        return formToken(Token::equal_arrow, tokStart);
      }
      return formToken(Token::equal, tokStart);
    case ';':
      return formToken(Token::semicolon, tokStart);
    case '[':
      // `[{` opens a string block.
      if (*curPtr == '{') {
        ++curPtr;
        return lexString(tokStart, true);
      }
      return formToken(Token::l_square, tokStart);
    case ']':
      return formToken(Token::r_square, tokStart);
    case '<':
      return formToken(Token::less, tokStart);
    case '>':
      return formToken(Token::greater, tokStart);
    case '{':
      return formToken(Token::l_brace, tokStart);
    case '}':
      return formToken(Token::r_brace, tokStart);
    case '(':
      return formToken(Token::l_paren, tokStart);
    case ')':
      return formToken(Token::r_paren, tokStart);
    case '/':
      if (*curPtr == '/') {
        lexComment();
        continue;
      }
      return emitError(tokStart, "unexpected character");
    case 0:
    case ' ':
    case '\t':
    case '\n':
      // Skip whitespace by restarting the loop instead of recursing, so long
      // runs of whitespace do not grow the call stack.
      continue;
    case '#':
      return lexDirective(tokStart);
    case '"':
      return lexString(tokStart, false);
    case '0':
    case '1':
    case '2':
    case '3':
    case '4':
    case '5':
    case '6':
    case '7':
    case '8':
    case '9':
      return lexNumber(tokStart);
    }
  }
}
/// Skip a `//` comment. On entry `curPtr` points at the second '/'.
///
/// Consumes through the first `\n` or `\r`, or stops on the end-of-buffer NUL
/// without consuming it. Embedded NULs inside the comment are skipped like any
/// other character.
void Lexer::lexComment() {
  assert(*curPtr == '/');
  ++curPtr;

  for (;;) {
    const char ch = *curPtr++;
    if (ch == '\n' || ch == '\r')
      return;
    if (ch == 0 && curPtr - 1 == curBuffer.end()) {
      // Leave curPtr on the terminator so the caller sees end-of-file.
      --curPtr;
      return;
    }
  }
}
/// Lex a directive such as `#include`. On entry the leading '#' (at
/// `tokStart`) has been consumed; the rest matches `[0-9a-zA-Z_]*`.
Token Lexer::lexDirective(const char *tokStart) {
  // Use llvm::isAlnum instead of <cctype> isalnum: the latter is undefined for
  // negative `char` values (e.g. UTF-8 bytes on signed-char targets) and
  // depends on the current locale.
  while (llvm::isAlnum(*curPtr) || *curPtr == '_')
    ++curPtr;

  StringRef str(tokStart, curPtr - tokStart);
  return Token(Token::directive, str);
}
/// Lex an identifier or keyword. On entry the first character (a letter or
/// '_', at `tokStart`) has been consumed; the rest matches `[0-9a-zA-Z_]*`.
/// Reserved words map to their `kw_*` kinds, a lone `_` maps to `underscore`,
/// and anything else becomes a plain `identifier`.
Token Lexer::lexIdentifier(const char *tokStart) {
  // Use llvm::isAlnum instead of <cctype> isalnum: the latter is undefined for
  // negative `char` values (e.g. UTF-8 bytes on signed-char targets) and
  // depends on the current locale.
  while (llvm::isAlnum(*curPtr) || *curPtr == '_')
    ++curPtr;

  StringRef str(tokStart, curPtr - tokStart);
  Token::Kind kind = StringSwitch<Token::Kind>(str)
                         .Case("attr", Token::kw_attr)
                         .Case("Attr", Token::kw_Attr)
                         .Case("erase", Token::kw_erase)
                         .Case("let", Token::kw_let)
                         .Case("Constraint", Token::kw_Constraint)
                         .Case("not", Token::kw_not)
                         .Case("op", Token::kw_op)
                         .Case("Op", Token::kw_Op)
                         .Case("OpName", Token::kw_OpName)
                         .Case("Pattern", Token::kw_Pattern)
                         .Case("replace", Token::kw_replace)
                         .Case("return", Token::kw_return)
                         .Case("rewrite", Token::kw_rewrite)
                         .Case("Rewrite", Token::kw_Rewrite)
                         .Case("type", Token::kw_type)
                         .Case("Type", Token::kw_Type)
                         .Case("TypeRange", Token::kw_TypeRange)
                         .Case("Value", Token::kw_Value)
                         .Case("ValueRange", Token::kw_ValueRange)
                         .Case("with", Token::kw_with)
                         .Case("_", Token::underscore)
                         .Default(Token::identifier);
  return Token(kind, str);
}
/// Lex a decimal integer literal. On entry the first digit (at `tokStart`)
/// has been consumed; all following decimal digits are added to the token.
Token Lexer::lexNumber(const char *tokStart) {
  // llvm::isDigit is well-defined for every `char` value, unlike <cctype>
  // isdigit, which is undefined for negative values.
  assert(llvm::isDigit(curPtr[-1]));
  while (llvm::isDigit(*curPtr))
    ++curPtr;
  return formToken(Token::integer, tokStart);
}
/// Lex the rest of a string literal whose opening delimiter has already been
/// consumed. `tokStart` points at that delimiter: `"` for a plain string, or
/// `[{` when `isStringBlock` is true.
///
/// Returns:
///  - a `string` / `string_block` token when the closing `"` / `}]` is found;
///  - a `code_complete_string` token, spelled from just past the opening
///    delimiter, if the code-completion location is reached first;
///  - an error token for an unterminated literal, for `\n`, `\v` or `\f`
///    inside a plain string, or for an escape other than `\"`, `\\`, `\n`,
///    `\t`, or two hex digits (the set Token::getStringValue decodes).
Token Lexer::lexString(const char *tokStart, bool isStringBlock) {
  while (true) {
    // Stop at the completion point; skip the 1-char (`"`) or 2-char (`[{`)
    // opening delimiter in the resulting token.
    if (curPtr == codeCompletionLocation) {
      return formToken(Token::code_complete_string,
                       tokStart + (isStringBlock ? 2 : 1));
    }
    switch (*curPtr++) {
    case '"':
      // A quote ends a plain string but is ordinary text inside a block.
      if (!isStringBlock)
        return formToken(Token::string, tokStart);
      continue;
    case '}':
      // Only `}]` ends a string block; a lone `}` is ordinary text.
      if (!isStringBlock || *curPtr != ']')
        continue;
      ++curPtr;
      return formToken(Token::string_block, tokStart);
    case 0: {
      // An embedded NUL is kept as content; only the end-of-buffer NUL means
      // the literal was never terminated.
      if (curPtr - 1 != curBuffer.end())
        continue;
      // Leave curPtr on the end-of-buffer terminator.
      --curPtr;
      StringRef expectedEndStr = isStringBlock ? "}]" : "\"";
      return emitError(curPtr - 1,
                       "expected '" + expectedEndStr + "' in string literal");
    }
    case '\n':
    case '\v':
    case '\f':
      // Plain strings may not span lines; string blocks may.
      if (!isStringBlock)
        return emitError(curPtr - 1, "expected '\"' in string literal");
      continue;
    case '\\':
      // Validate the escape here so getStringValue can decode it without
      // further checks.
      // NOTE(review): the escape's trailing character(s) are consumed without
      // re-checking codeCompletionLocation, so a completion point placed
      // directly after the backslash (or between two hex digits) appears to
      // be stepped over. Confirm whether that is intended.
      if (*curPtr == '"' || *curPtr == '\\' || *curPtr == 'n' ||
          *curPtr == 't') {
        ++curPtr;
      } else if (llvm::isHexDigit(*curPtr) && llvm::isHexDigit(curPtr[1])) {
        curPtr += 2;
      } else {
        return emitError(curPtr - 1, "unknown escape in string literal");
      }
      continue;
    default:
      continue;
    }
  }
}