#include "IncrementalParser.h"
#include "clang/AST/DeclContextInternals.h"
#include "clang/CodeGen/BackendUtil.h"
#include "clang/CodeGen/CodeGenAction.h"
#include "clang/CodeGen/ModuleBuilder.h"
#include "clang/Frontend/CompilerInstance.h"
#include "clang/Frontend/FrontendAction.h"
#include "clang/FrontendTool/Utils.h"
#include "clang/Interpreter/Interpreter.h"
#include "clang/Parse/Parser.h"
#include "clang/Sema/Sema.h"
#include "llvm/Option/ArgList.h"
#include "llvm/Support/CrashRecoveryContext.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/Timer.h"
#include <sstream>
namespace clang {
/// ASTConsumer adaptor installed by the incremental parser.
///
/// Every callback is forwarded to the wrapped consumer. The only callback that
/// does extra work is HandleTopLevelDecl(), which lets the Interpreter replace
/// top-level statements that were written without a trailing semicolon.
class IncrementalASTConsumer final : public ASTConsumer {
public:
  IncrementalASTConsumer(Interpreter &InterpRef, std::unique_ptr<ASTConsumer> C)
      : Interp(InterpRef), Consumer(std::move(C)) {}

  /// Rewrites semicolon-less top-level statements via
  /// Interpreter::SynthesizeExpr, then forwards the group to the wrapped
  /// consumer. Empty groups, or a missing wrapped consumer, are accepted
  /// without doing anything.
  bool HandleTopLevelDecl(DeclGroupRef DGR) override final {
    if (DGR.isNull() || !Consumer)
      return true;

    for (Decl *D : DGR) {
      auto *TSD = llvm::dyn_cast<TopLevelStmtDecl>(D);
      if (!TSD || !TSD->isSemiMissing())
        continue;
      // The statement of a semicolon-less TopLevelStmtDecl is treated as an
      // expression and replaced by whatever the interpreter synthesizes for it.
      Expr *Original = cast<Expr>(TSD->getStmt());
      TSD->setStmt(Interp.SynthesizeExpr(Original));
    }

    return Consumer->HandleTopLevelDecl(DGR);
  }

  // ---- Pure forwarding below. Unlike HandleTopLevelDecl(), these callbacks
  // ---- assume the wrapped consumer is non-null.

  void HandleTranslationUnit(ASTContext &Ctx) override final {
    Consumer->HandleTranslationUnit(Ctx);
  }
  void HandleInlineFunctionDefinition(FunctionDecl *D) override final {
    Consumer->HandleInlineFunctionDefinition(D);
  }
  void HandleInterestingDecl(DeclGroupRef D) override final {
    Consumer->HandleInterestingDecl(D);
  }
  void HandleTagDeclDefinition(TagDecl *D) override final {
    Consumer->HandleTagDeclDefinition(D);
  }
  void HandleTagDeclRequiredDefinition(const TagDecl *D) override final {
    Consumer->HandleTagDeclRequiredDefinition(D);
  }
  void HandleCXXImplicitFunctionInstantiation(FunctionDecl *D) override final {
    Consumer->HandleCXXImplicitFunctionInstantiation(D);
  }
  void HandleTopLevelDeclInObjCContainer(DeclGroupRef D) override final {
    Consumer->HandleTopLevelDeclInObjCContainer(D);
  }
  void HandleImplicitImportDecl(ImportDecl *D) override final {
    Consumer->HandleImplicitImportDecl(D);
  }
  void CompleteTentativeDefinition(VarDecl *D) override final {
    Consumer->CompleteTentativeDefinition(D);
  }
  void CompleteExternalDeclaration(DeclaratorDecl *D) override final {
    Consumer->CompleteExternalDeclaration(D);
  }
  void AssignInheritanceModel(CXXRecordDecl *RD) override final {
    Consumer->AssignInheritanceModel(RD);
  }
  void HandleCXXStaticMemberVarInstantiation(VarDecl *D) override final {
    Consumer->HandleCXXStaticMemberVarInstantiation(D);
  }
  void HandleVTable(CXXRecordDecl *RD) override final {
    Consumer->HandleVTable(RD);
  }

  ASTMutationListener *GetASTMutationListener() override final {
    return Consumer->GetASTMutationListener();
  }
  ASTDeserializationListener *GetASTDeserializationListener() override final {
    return Consumer->GetASTDeserializationListener();
  }

  void PrintStats() override final { Consumer->PrintStats(); }

  bool shouldSkipFunctionBody(Decl *D) override final {
    return Consumer->shouldSkipFunctionBody(D);
  }

  static bool classof(const clang::ASTConsumer *) { return true; }

private:
  // Interpreter used to synthesize replacement expressions.
  Interpreter &Interp;
  // The consumer that receives all forwarded callbacks.
  std::unique_ptr<ASTConsumer> Consumer;
};
class IncrementalAction : public WrapperFrontendAction {
private:
bool IsTerminating = false;
public:
IncrementalAction(CompilerInstance &CI, llvm::LLVMContext &LLVMCtx,
llvm::Error &Err)
: WrapperFrontendAction([&]() {
llvm::ErrorAsOutParameter EAO(&Err);
std::unique_ptr<FrontendAction> Act;
switch (CI.getFrontendOpts().ProgramAction) {
default:
Err = llvm::createStringError(
std::errc::state_not_recoverable,
"Driver initialization failed. "
"Incremental mode for action %d is not supported",
CI.getFrontendOpts().ProgramAction);
return Act;
case frontend::ASTDump:
[[fallthrough]];
case frontend::ASTPrint:
[[fallthrough]];
case frontend::ParseSyntaxOnly:
Act = CreateFrontendAction(CI);
break;
case frontend::PluginAction:
[[fallthrough]];
case frontend::EmitAssembly:
[[fallthrough]];
case frontend::EmitBC:
[[fallthrough]];
case frontend::EmitObj:
[[fallthrough]];
case frontend::PrintPreprocessedInput:
[[fallthrough]];
case frontend::EmitLLVMOnly:
Act.reset(new EmitLLVMOnlyAction(&LLVMCtx));
break;
}
return Act;
}()) {}
FrontendAction *getWrapped() const { return WrappedAction.get(); }
TranslationUnitKind getTranslationUnitKind() override {
return TU_Incremental;
}
void ExecuteAction() override {
CompilerInstance &CI = getCompilerInstance();
assert(CI.hasPreprocessor() && "No PP!");
CodeCompleteConsumer *CompletionConsumer = nullptr;
if (CI.hasCodeCompletionConsumer())
CompletionConsumer = &CI.getCodeCompletionConsumer();
Preprocessor &PP = CI.getPreprocessor();
PP.EnterMainSourceFile();
if (!CI.hasSema())
CI.createSema(getTranslationUnitKind(), CompletionConsumer);
}
void EndSourceFile() override {
if (IsTerminating && getWrapped())
WrapperFrontendAction::EndSourceFile();
}
void FinalizeAction() {
assert(!IsTerminating && "Already finalized!");
IsTerminating = true;
EndSourceFile();
}
};
/// Returns the CodeGenerator of the wrapped frontend action, or nullptr when
/// there is none.
///
/// The result is nullptr when:
///  - the parser was default-constructed and has no action;
///  - creating the wrapped action failed, leaving it null;
///  - the wrapped action does not report IR support (e.g. syntax-only runs).
CodeGenerator *IncrementalParser::getCodeGen() const {
  if (!Act)
    return nullptr;
  FrontendAction *WrappedAct = Act->getWrapped();
  if (!WrappedAct || !WrappedAct->hasIRSupport())
    return nullptr;
  // Assumes every IR-capable wrapped action is a CodeGenAction. The only one
  // this file creates is EmitLLVMOnlyAction.
  return static_cast<CodeGenAction *>(WrappedAct)->getCodeGenerator();
}
// Default construction leaves every member default-initialized (no action,
// CompilerInstance or parser); intended for subclasses that set things up
// themselves.
IncrementalParser::IncrementalParser() = default;
/// Sets up incremental parsing on top of \p Instance.
///
/// Steps:
///  1. Create the IncrementalAction, which may report an unsupported program
///     action through \p Err.
///  2. Execute it: the main source file is entered and Sema is created.
///  3. Cache the initial CodeGen module when CodeGen is active. GenModule()
///     asserts this cached module stays empty.
///  4. Wrap the CompilerInstance's consumer in an IncrementalASTConsumer.
///  5. Build the Parser and parse the already-entered input as the first PTU.
IncrementalParser::IncrementalParser(Interpreter &Interp,
                                     std::unique_ptr<CompilerInstance> Instance,
                                     llvm::LLVMContext &LLVMCtx,
                                     llvm::Error &Err)
    : CI(std::move(Instance)) {
  llvm::ErrorAsOutParameter EAO(&Err);
  Act = std::make_unique<IncrementalAction>(*CI, LLVMCtx, Err);
  if (Err)
    return;
  CI->ExecuteAction(*Act);

  if (getCodeGen())
    CachedInCodeGenModule = GenModule();

  std::unique_ptr<ASTConsumer> IncrConsumer =
      std::make_unique<IncrementalASTConsumer>(Interp, CI->takeASTConsumer());
  CI->setASTConsumer(std::move(IncrConsumer));
  Consumer = &CI->getASTConsumer();

  P = std::make_unique<Parser>(CI->getPreprocessor(), CI->getSema(),
                               /*SkipFunctionBodies=*/false);
  P->Initialize();

  // Parse whatever the preprocessor already holds as the first PTU.
  auto PTU = ParseOrWrapTopLevelDecl();
  if (auto E = PTU.takeError()) {
    // NOTE(review): the failure is discarded and Err is left untouched, so the
    // caller cannot observe it; consider reporting it through Err.
    consumeError(std::move(E));
    return;
  }

  if (getCodeGen()) {
    PTU->TheModule = GenModule();
    assert(PTU->TheModule && "Failed to create initial PTU");
  }
}
/// Tears down the parser, then finalizes the frontend action.
///
/// The Parser goes first because it was built on CI's Preprocessor and Sema.
/// Act is null for a default-constructed IncrementalParser, so finalization
/// is skipped in that case instead of dereferencing a null pointer.
IncrementalParser::~IncrementalParser() {
  P.reset();
  if (Act)
    Act->FinalizeAction();
}
/// Parses everything the preprocessor currently provides into a new
/// PartialTranslationUnit appended to PTUs.
///
/// Each call creates a fresh TranslationUnitDecl (C.addTranslationUnitDecl())
/// and records it as the new PTU's TUPart. Flow:
///  - every parsed top-level declaration group is handed to Consumer;
///  - the decls in Sema's WeakTopLevelDecls() are then forwarded;
///  - pending local and global eager instantiations are performed;
///  - Consumer->HandleTranslationUnit() is called.
///
/// \returns a reference to the newly appended PTU. Returns an error instead if
/// the consumer rejects a declaration group, or if any diagnostic error was
/// emitted. In the latter case the new TU's declarations are cleaned up with
/// CleanUpPTU(), and the diagnostics engine and its client are reset.
/// NOTE(review): on both error paths the freshly appended entry stays in
/// PTUs; confirm callers account for that.
llvm::Expected<PartialTranslationUnit &>
IncrementalParser::ParseOrWrapTopLevelDecl() {
Sema &S = CI->getSema();
// Register Sema with LLVM's crash-recovery cleanup for the duration of this
// call.
llvm::CrashRecoveryContextCleanupRegistrar<Sema> CleanupSema(&S);
// Collect eager instantiations while parsing; they are performed explicitly
// near the end, after a successful parse.
Sema::GlobalEagerInstantiationScope GlobalInstantiations(S, true);
Sema::LocalEagerInstantiationScope LocalInstantiations(S);
PTUs.emplace_back(PartialTranslationUnit());
PartialTranslationUnit &LastPTU = PTUs.back();
ASTContext &C = S.getASTContext();
// Each input gets its own TranslationUnitDecl.
C.addTranslationUnitDecl();
LastPTU.TUPart = C.getTranslationUnitDecl();
// If the previous input left its end-of-input annotation as the current
// token, consume it. Then rebuild the translation-unit scope so parsing starts
// from a fresh scope bound to the new TU.
if (P->getCurToken().is(tok::annot_repl_input_end)) {
P->ConsumeAnyToken();
P->ExitScope();
S.CurContext = nullptr;
P->EnterScope(Scope::DeclScope);
S.ActOnTranslationUnitScope(P->getCurScope());
}
Parser::DeclGroupPtrTy ADecl;
Sema::ModuleImportState ImportState;
// Parse top-level declarations until the parser reports end of input.
for (bool AtEOF = P->ParseFirstTopLevelDecl(ADecl, ImportState); !AtEOF;
AtEOF = P->ParseTopLevelDecl(ADecl, ImportState)) {
if (ADecl && !Consumer->HandleTopLevelDecl(ADecl.get()))
return llvm::make_error<llvm::StringError>("Parsing failed. "
"The consumer rejected a decl",
std::error_code());
}
DiagnosticsEngine &Diags = getCI()->getDiagnostics();
// Any error diagnostic invalidates this input. Remove the new TU's
// declarations from lookup (a temporary PTU with no module is enough for
// CleanUpPTU), then soft-reset diagnostics so the next input starts clean.
if (Diags.hasErrorOccurred()) {
PartialTranslationUnit MostRecentPTU = {C.getTranslationUnitDecl(),
nullptr};
CleanUpPTU(MostRecentPTU);
Diags.Reset(true);
Diags.getClient()->clear();
return llvm::make_error<llvm::StringError>("Parsing failed.",
std::error_code());
}
// Forward the decls Sema collected in WeakTopLevelDecls() (presumably from
// #pragma weak) one group at a time; the consumer's result is ignored here.
for (Decl *D : S.WeakTopLevelDecls()) {
DeclGroupRef DGR(D);
Consumer->HandleTopLevelDecl(DGR);
}
LocalInstantiations.perform();
GlobalInstantiations.perform();
Consumer->HandleTranslationUnit(C);
return LastPTU;
}
/// Parses \p input as the next chunk of incremental input.
///
/// Flow:
///  - The text is copied into a new buffer named "input_line_<N>", where N is
///    the running input counter, with a trailing '\n' appended.
///  - The buffer is entered into the preprocessor as a new file located at
///    the start of the main file.
///  - It is parsed with ParseOrWrapTopLevelDecl().
///  - When CodeGen is active, the module produced for this input is attached
///    to the returned PTU.
///
/// \returns the new PartialTranslationUnit. Returns an error if the buffer
/// cannot be allocated, the file cannot be entered, or parsing fails.
llvm::Expected<PartialTranslationUnit &>
IncrementalParser::Parse(llvm::StringRef input) {
  Preprocessor &PP = CI->getPreprocessor();
  assert(PP.isIncrementalProcessingEnabled() && "Not in incremental mode!?");

  std::string SourceName = "input_line_" + std::to_string(InputCount++);

  // Keep the writable type so the buffer can be filled without a const_cast.
  size_t InputSize = input.size();
  std::unique_ptr<llvm::WritableMemoryBuffer> MB =
      llvm::WritableMemoryBuffer::getNewUninitMemBuffer(InputSize + 1,
                                                        SourceName);
  // The allocation is done with nothrow new and may yield a null buffer.
  if (!MB)
    return llvm::make_error<llvm::StringError>("Parsing failed. "
                                               "Cannot allocate input buffer.",
                                               std::error_code());
  char *MBStart = MB->getBufferStart();
  // An empty StringRef may have a null data(); memcpy with a null source is
  // undefined even for zero bytes.
  if (!input.empty())
    memcpy(MBStart, input.data(), InputSize);
  MBStart[InputSize] = '\n';

  SourceManager &SM = CI->getSourceManager();
  SourceLocation NewLoc = SM.getLocForStartOfFile(SM.getMainFileID());
  FileID FID = SM.createFileID(std::move(MB), SrcMgr::C_User, /*LoadedID=*/0,
                               /*LoadedOffset=*/0, NewLoc);

  if (PP.EnterSourceFile(FID, nullptr, NewLoc))
    return llvm::make_error<llvm::StringError>("Parsing failed. "
                                               "Cannot enter source file.",
                                               std::error_code());

  auto PTU = ParseOrWrapTopLevelDecl();
  if (!PTU)
    return PTU.takeError();

  if (PP.getLangOpts().DelayedTemplateParsing) {
    // Tokens may remain unconsumed in this mode; drain them up to the
    // end-of-input annotation.
    Token Tok;
    do {
      PP.Lex(Tok);
    } while (Tok.isNot(tok::annot_repl_input_end));
  } else {
    Token AssertTok;
    PP.Lex(AssertTok);
    assert(AssertTok.is(tok::annot_repl_input_end) &&
           "Lexer must be EOF when starting incremental parse!");
  }

  if (std::unique_ptr<llvm::Module> M = GenModule())
    PTU->TheModule = std::move(M);

  return PTU;
}
/// Takes ownership of the module CodeGen has been filling and immediately
/// starts a fresh one named "incr_module_<ID>".
///
/// The ID counter is a function-local static, shared by all parsers.
/// \returns the released module, or nullptr when CodeGen is not active.
std::unique_ptr<llvm::Module> IncrementalParser::GenModule() {
  static unsigned ID = 0;

  CodeGenerator *CG = getCodeGen();
  if (!CG)
    return nullptr;

  // The module cached at construction time must never receive any content.
  assert((!CachedInCodeGenModule ||
          (CachedInCodeGenModule->empty() &&
           CachedInCodeGenModule->global_empty() &&
           CachedInCodeGenModule->alias_empty() &&
           CachedInCodeGenModule->ifunc_empty())) &&
         "CodeGen wrote to a readonly module");

  std::unique_ptr<llvm::Module> Released(CG->ReleaseModule());
  const std::string NextName = "incr_module_" + std::to_string(ID++);
  CG->StartModule(NextName, Released->getContext());
  return Released;
}
/// Removes the declarations introduced by \p PTU from name lookup, so later
/// inputs no longer find them.
///
/// Two structures are cleaned:
///  1. The lookup table (StoredDeclsMap) of PTU.TUPart's primary context.
///     - That table is presumably shared with earlier partial TUs, so only
///       decls whose TranslationUnitDecl is PTU.TUPart are removed.
///     - An entry whose decls all belong to that TU is erased outright. An
///       entry with an empty lookup result is erased too.
///  2. Sema's IdResolver chain, for C only (neither ObjC nor C++): named decls
///     whose name carries front-end token info are removed.
///
/// NOTE(review): the TranslationUnitDecl itself is not deallocated here.
void IncrementalParser::CleanUpPTU(PartialTranslationUnit &PTU) {
TranslationUnitDecl *MostRecentTU = PTU.TUPart;
if (StoredDeclsMap *Map = MostRecentTU->getPrimaryContext()->getLookupPtr()) {
// NOTE(review): entries are erased from *Map while range-iterating over it;
// this relies on StoredDeclsMap's erase leaving the loop iterator usable.
// Key and List must not be touched after the erase.
for (auto &&[Key, List] : *Map) {
DeclContextLookupResult R = List.getLookupResult();
std::vector<NamedDecl *> NamedDeclsToRemove;
bool RemoveAll = true;
// Split this name's decls into those from MostRecentTU (to remove) and
// those from earlier TUs (to keep).
for (NamedDecl *D : R) {
if (D->getTranslationUnitDecl() == MostRecentTU)
NamedDeclsToRemove.push_back(D);
else
RemoveAll = false;
}
if (LLVM_LIKELY(RemoveAll)) {
Map->erase(Key);
} else {
for (NamedDecl *D : NamedDeclsToRemove)
List.remove(D);
}
}
}
// Remove the TU's named decls from the IdResolver chain. This is done only
// when neither ObjC nor C++ is enabled; the rationale for that restriction is
// not visible here.
for (Decl *D : MostRecentTU->decls()) {
auto *ND = dyn_cast<NamedDecl>(D);
if (!ND)
continue;
if (ND->getDeclName().getFETokenInfo() && !D->getLangOpts().ObjC &&
!D->getLangOpts().CPlusPlus)
getCI()->getSema().IdResolver.RemoveDecl(ND);
}
}
/// Returns CodeGen's mangled name for \p GD.
///
/// Only valid while a CodeGenerator is available (see getCodeGen()); this is
/// asserted in debug builds.
llvm::StringRef IncrementalParser::GetMangledName(GlobalDecl GD) const {
  CodeGenerator *CG = getCodeGen();
  assert(CG);
  llvm::StringRef Mangled = CG->GetMangledName(GD);
  return Mangled;
}
}