#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_PARSER_H
#define MLIR_TOOLS_MLIRQUERY_MATCHER_PARSER_H
#include "Diagnostics.h"
#include "RegistryManager.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
namespace mlir::query::matcher::internal {
// Parser for matcher expressions.
//
// All public entry points are static member functions; the constructor is
// private, so instances can only be created from within the class itself.
// The member function definitions, as well as CodeTokenizer,
// ScopedContextEntry and TokenInfo, live outside this header.
class Parser {
public:
  // Categories of tokens handled by the parser. The tokenizer that produces
  // them (CodeTokenizer) is only forward-declared here.
  enum class TokenKind {
    Eof,
    NewLine,
    OpenParen,
    CloseParen,
    Comma,
    Period,
    Literal,
    Ident,
    InvalidChar,
    CodeCompletion,
    Error
  };

  // Abstract interface between the parser and the source of matchers.
  //
  // actOnMatcherExpression() and lookupMatcherCtor() are pure virtual and
  // must be provided by every implementation. The two completion hooks are
  // virtual but not pure; their base implementations are not visible in this
  // header.
  class Sema {
  public:
    virtual ~Sema();

    // Handles a matcher expression made of the constructor `ctor`, the
    // matcher name location `nameRange`, the name `functionName` and the
    // arguments `args`. Returns the resulting VariantMatcher.
    // NOTE(review): `error` is presumably where failures are reported and the
    // returned matcher is presumably null/empty on failure -- confirm against
    // the implementation.
    virtual VariantMatcher actOnMatcherExpression(
        MatcherCtor ctor, SourceRange nameRange, llvm::StringRef functionName,
        llvm::ArrayRef<ParserValue> args, Diagnostics *error) = 0;

    // Looks up the matcher constructor registered under `matcherName`.
    // The std::optional return type allows "no such matcher" to be expressed
    // without a sentinel value.
    virtual std::optional<MatcherCtor>
    lookupMatcherCtor(llvm::StringRef matcherName) = 0;

    // Returns the argument kinds accepted for completion given `context`, a
    // sequence of (matcher constructor, unsigned) pairs.
    // NOTE(review): the unsigned is presumably the argument index within that
    // matcher -- confirm against the implementation.
    virtual std::vector<ArgKind> getAcceptedCompletionTypes(
        llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context);

    // Returns the matcher completions corresponding to `acceptedTypes`.
    virtual std::vector<MatcherCompletion>
    getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes);
  };

  // Sema implementation bound to a Registry.
  //
  // Only a reference to the registry is stored, so the Registry passed to the
  // constructor must outlive this object.
  class RegistrySema : public Parser::Sema {
  public:
    // NOTE(review): this single-argument constructor is not `explicit`;
    // consider marking it so if no caller relies on the implicit conversion
    // from Registry.
    RegistrySema(const Registry &matcherRegistry)
        : matcherRegistry(matcherRegistry) {}
    ~RegistrySema() override;

    std::optional<MatcherCtor>
    lookupMatcherCtor(llvm::StringRef matcherName) override;

    VariantMatcher actOnMatcherExpression(MatcherCtor ctor,
                                          SourceRange nameRange,
                                          llvm::StringRef functionName,
                                          llvm::ArrayRef<ParserValue> args,
                                          Diagnostics *error) override;

    std::vector<ArgKind> getAcceptedCompletionTypes(
        llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context) override;

    std::vector<MatcherCompletion>
    getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes) override;

  private:
    // Registry this Sema is bound to (not owned).
    const Registry &matcherRegistry;
  };

  // Map from a name to the VariantValue bound to it.
  using NamedValueMap = llvm::StringMap<VariantValue>;

  // Parses `matcherCode` into a matcher.
  //
  // `matcherCode` is taken by non-const reference, so the callee is able to
  // modify the caller's StringRef. `namedValues` may be null (the overload
  // below passes nullptr).
  // NOTE(review): presumably returns std::nullopt and fills `error` on
  // failure -- confirm against the implementation.
  static std::optional<DynMatcher>
  parseMatcherExpression(llvm::StringRef &matcherCode,
                         const Registry &matcherRegistry,
                         const NamedValueMap *namedValues, Diagnostics *error);

  // Convenience overload: same as above with no named values.
  static std::optional<DynMatcher>
  parseMatcherExpression(llvm::StringRef &matcherCode,
                         const Registry &matcherRegistry, Diagnostics *error) {
    return parseMatcherExpression(matcherCode, matcherRegistry, nullptr, error);
  }

  // Parses `code` as an expression and stores the result through `value`.
  //
  // `code` is taken by non-const reference, so the callee is able to modify
  // the caller's StringRef. `namedValues` may be null (the overload below
  // passes nullptr).
  // NOTE(review): the bool presumably signals success -- confirm against the
  // implementation.
  static bool parseExpression(llvm::StringRef &code,
                              const Registry &matcherRegistry,
                              const NamedValueMap *namedValues,
                              VariantValue *value, Diagnostics *error);

  // Convenience overload: same as above with no named values.
  static bool parseExpression(llvm::StringRef &code,
                              const Registry &matcherRegistry,
                              VariantValue *value, Diagnostics *error) {
    return parseExpression(code, matcherRegistry, nullptr, value, error);
  }

  // Returns completions for `code` given `completionOffset`.
  //
  // `namedValues` may be null (the overload below passes nullptr).
  // NOTE(review): `completionOffset` is presumably a character offset into
  // `code` -- confirm against the implementation.
  static std::vector<MatcherCompletion>
  completeExpression(llvm::StringRef &code, unsigned completionOffset,
                     const Registry &matcherRegistry,
                     const NamedValueMap *namedValues);

  // Convenience overload: same as above with no named values.
  static std::vector<MatcherCompletion>
  completeExpression(llvm::StringRef &code, unsigned completionOffset,
                     const Registry &matcherRegistry) {
    return completeExpression(code, completionOffset, matcherRegistry, nullptr);
  }

private:
  // Helper types defined outside this header.
  class CodeTokenizer;
  struct ScopedContextEntry;
  struct TokenInfo;

  // Private: parser instances are created only from within this class.
  Parser(CodeTokenizer *tokenizer, const Registry &matcherRegistry,
         const NamedValueMap *namedValues, Diagnostics *error);

  // Parsing helpers. Each returns bool and reports its result through the
  // reference/pointer out-parameter.
  // NOTE(review): true presumably means success -- confirm against the
  // implementation.
  bool parseChainedExpression(std::string &argument);

  bool parseExpressionImpl(VariantValue *value);

  bool parseMatcherArgs(std::vector<ParserValue> &args, MatcherCtor ctor,
                        const TokenInfo &nameToken, TokenInfo &endToken);

  bool parseMatcherExpressionImpl(const TokenInfo &nameToken,
                                  const TokenInfo &openToken,
                                  std::optional<MatcherCtor> ctor,
                                  VariantValue *value);

  bool parseIdentifierPrefixImpl(VariantValue *value);

  // Completion helpers; see `completions` below.
  void addCompletion(const TokenInfo &compToken,
                     const MatcherCompletion &completion);

  void addExpressionCompletions();

  std::vector<MatcherCompletion>
  getNamedValueCompletions(llvm::ArrayRef<ArgKind> acceptedTypes);

  // Token source. Raw pointer fixed at construction; ownership is not
  // expressed by the type.
  CodeTokenizer *const tokenizer;

  // Sema owned by this parser.
  std::unique_ptr<RegistrySema> sema;

  // Optional named values; the public entry points pass nullptr when the
  // caller supplies none.
  const NamedValueMap *const namedValues;

  // Diagnostics sink fixed at construction.
  Diagnostics *const error;

  // Stack of (matcher constructor, unsigned) pairs; the element type matches
  // the `context` parameter of Sema::getAcceptedCompletionTypes().
  using ContextStackTy = std::vector<std::pair<MatcherCtor, unsigned>>;
  ContextStackTy contextStack;

  // Completions collected so far.
  std::vector<MatcherCompletion> completions;
};
}
#endif