#include "AST.h"
#include "FindTarget.h"
#include "Selection.h"
#include "SourceCode.h"
#include "XRefs.h"
#include "refactor/Tweak.h"
#include "support/Logger.h"
#include "clang/AST/ASTContext.h"
#include "clang/AST/ASTTypeTraits.h"
#include "clang/AST/Decl.h"
#include "clang/AST/DeclBase.h"
#include "clang/AST/DeclCXX.h"
#include "clang/AST/DeclTemplate.h"
#include "clang/AST/NestedNameSpecifier.h"
#include "clang/AST/Stmt.h"
#include "clang/Basic/LangOptions.h"
#include "clang/Basic/SourceLocation.h"
#include "clang/Basic/SourceManager.h"
#include "clang/Basic/TokenKinds.h"
#include "clang/Lex/Lexer.h"
#include "clang/Lex/Token.h"
#include "clang/Sema/Lookup.h"
#include "clang/Sema/Sema.h"
#include "clang/Tooling/Core/Replacement.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/raw_ostream.h"
#include <cstddef>
#include <optional>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
namespace clang {
namespace clangd {
namespace {
std::optional<SourceLocation> getSemicolonForDecl(const FunctionDecl *FD) {
const SourceManager &SM = FD->getASTContext().getSourceManager();
const LangOptions &LangOpts = FD->getASTContext().getLangOpts();
SourceLocation CurLoc = FD->getEndLoc();
auto NextTok = Lexer::findNextToken(CurLoc, SM, LangOpts);
if (!NextTok || !NextTok->is(tok::semi))
return std::nullopt;
return NextTok->getLocation();
}
/// Maps a selection node to the function it designates, or nullptr.
///
/// Two shapes are accepted:
///   - the node itself is a FunctionDecl, or
///   - the node is a completely selected CompoundStmt whose parent node is a
///     FunctionDecl (i.e. the whole function body was selected).
const FunctionDecl *getSelectedFunction(const SelectionTree::Node *SelNode) {
  const DynTypedNode &Node = SelNode->ASTNode;
  if (const auto *Func = Node.get<FunctionDecl>())
    return Func;
  const bool WholeBodySelected = Node.get<CompoundStmt>() != nullptr &&
                                 SelNode->Selected == SelectionTree::Complete;
  if (!WholeBodySelected)
    return nullptr;
  const SelectionTree::Node *Parent = SelNode->Parent;
  return Parent ? Parent->ASTNode.get<FunctionDecl>() : nullptr;
}
/// Returns true if every declaration in \p DeclRefs can be named at the
/// location of \p Target.
///
/// Each referenced declaration is compared through its canonical declaration
/// and is accepted when:
///   - it is \p Target itself, or
///   - it is written in the same file as \p Target and either appears before
///     \p Target in the translation unit, or \p Target is a method and the
///     declaration is a method or field of that same class (class members are
///     usable regardless of declaration order).
/// Declarations written in a different file always make the check fail.
bool checkDeclsAreVisible(const llvm::DenseSet<const Decl *> &DeclRefs,
                          const FunctionDecl *Target, const SourceManager &SM) {
  const SourceLocation TargetLoc = Target->getLocation();
  const auto *TargetMethod = llvm::dyn_cast<CXXMethodDecl>(Target);
  const RecordDecl *TargetClass =
      TargetMethod ? TargetMethod->getParent() : nullptr;

  // The record that directly owns a method or field; nullptr for other decls.
  auto EnclosingRecord = [](const Decl *D) -> const RecordDecl * {
    if (const auto *Method = llvm::dyn_cast<CXXMethodDecl>(D))
      return Method->getParent();
    if (const auto *Field = llvm::dyn_cast<FieldDecl>(D))
      return Field->getParent();
    return nullptr;
  };

  for (const Decl *Ref : DeclRefs) {
    const Decl *Canon = Ref->getCanonicalDecl();
    if (Canon == Target)
      continue;
    const SourceLocation Loc = Canon->getLocation();
    // References to declarations from other files are not supported.
    if (!SM.isWrittenInSameFile(Loc, TargetLoc))
      return false;
    if (SM.isBeforeInTranslationUnit(Loc, TargetLoc))
      continue;
    // Declared after Target: only acceptable as a member of Target's class.
    if (TargetClass == nullptr || EnclosingRecord(Canon) != TargetClass)
      return false;
  }
  return true;
}
/// Re-spells the names used in the body of \p FD so the body stays valid when
/// it is moved next to \p Target.
///
/// Every reference in FD's body that is unqualified, has at least one target,
/// does not come from a macro expansion, and resolves to a namespace-scope
/// declaration gets a qualifier prepended. The qualifier is computed by
/// getQualification() for Target's lexical context and begin location. Only
/// the first component of a name is qualified, which is why already-qualified
/// references are skipped. Declarations in function or class scopes are never
/// qualified.
///
/// \returns the rewritten text of FD's body (its full source range), or an
/// error when:
///   - a reference resolves to candidates in different declaration contexts,
///   - a replacement cannot be added, or
///   - the body's file range cannot be computed.
llvm::Expected<std::string> qualifyAllDecls(const FunctionDecl *FD,
                                            const FunctionDecl *Target,
                                            const HeuristicResolver *Resolver) {
  auto *TargetContext = Target->getLexicalDeclContext();
  const SourceManager &SM = FD->getASTContext().getSourceManager();
  tooling::Replacements Replacements;
  bool HadErrors = false;
  findExplicitReferences(
      FD->getBody(),
      [&](ReferenceLoc Ref) {
        // Only the leading component of a name is qualified; names that
        // already carry a qualifier are left as written.
        if (Ref.Qualifier)
          return;
        // Nothing is known about the referenced entity; nothing to do.
        if (Ref.Targets.empty())
          return;
        // Names introduced by macro expansions are not rewritten.
        if (Ref.NameLoc.isMacroID())
          return;
        // A single qualifier can only be chosen when all candidate targets
        // live in the same declaration context.
        for (const NamedDecl *ND : Ref.Targets) {
          if (ND->getDeclContext() != Ref.Targets.front()->getDeclContext()) {
            elog("define inline: Targets from multiple contexts: {0}, {1}",
                 printQualifiedName(*Ref.Targets.front()),
                 printQualifiedName(*ND));
            HadErrors = true;
            return;
          }
        }
        const NamedDecl *ND = Ref.Targets.front();
        // Only namespace-scope declarations receive a qualifier.
        if (!ND->getDeclContext()->isNamespace())
          return;
        const std::string Qualifier = getQualification(
            FD->getASTContext(), TargetContext, Target->getBeginLoc(), ND);
        if (auto Err = Replacements.add(
                tooling::Replacement(SM, Ref.NameLoc, 0, Qualifier))) {
          HadErrors = true;
          elog("define inline: Failed to add quals: {0}", std::move(Err));
        }
      },
      Resolver);
  if (HadErrors)
    return error(
        "define inline: Failed to compute qualifiers. See logs for details.");

  auto OrigBodyRange = toHalfOpenFileRange(
      SM, FD->getASTContext().getLangOpts(), FD->getBody()->getSourceRange());
  if (!OrigBodyRange)
    return error("Couldn't get range func body.");

  // OrigBodyRange is half-open: its end offset is one past the body's last
  // character. All insertions happen at reference locations inside the body,
  // so the begin offset is unchanged and only the end needs shifting.
  unsigned BodyBegin = SM.getFileOffset(OrigBodyRange->getBegin());
  unsigned BodyEnd = Replacements.getShiftedCodePosition(
      SM.getFileOffset(OrigBodyRange->getEnd()));
  auto QualifiedFunc = tooling::applyAllReplacements(
      SM.getBufferData(SM.getFileID(OrigBodyRange->getBegin())), Replacements);
  if (!QualifiedFunc)
    return QualifiedFunc.takeError();
  // [BodyBegin, BodyEnd) is exactly the body. Adding one to the length would
  // also copy the character after the closing brace (e.g. a newline or the
  // start of the next declaration) into the target.
  return QualifiedFunc->substr(BodyBegin, BodyEnd - BodyBegin);
}
/// Builds the edits that rename the parameters of \p Dest (the declaration the
/// body is moved into) to the names used by \p Source (the definition whose
/// body is moved), so the moved body still refers to the right parameters.
///
/// Function parameters are matched by position and, for function templates,
/// so are template parameters. Every reference to a renamed parameter that
/// findExplicitReferences reports within \p Dest (or its described function
/// template) is rewritten. An unnamed parameter in \p Dest gets the new name
/// inserted at its location. If the \p Source parameter is unnamed while the
/// \p Dest one is named, the \p Dest name is replaced with empty text.
///
/// \returns the replacements, or an error when a reference inside a macro
/// expansion cannot be mapped back to a file range, or when two replacements
/// conflict.
llvm::Expected<tooling::Replacements>
renameParameters(const FunctionDecl *Dest, const FunctionDecl *Source,
                 const HeuristicResolver *Resolver) {
  // Canonical parameter decl in Dest -> text written in place of its name.
  llvm::DenseMap<const Decl *, std::string> ParamToNewName;
  // Parameter decl in Dest -> every location whose spelling must be replaced.
  llvm::DenseMap<const NamedDecl *, std::vector<SourceLocation>> RefLocs;
  auto HandleParam = [&](const NamedDecl *DestParam,
                         const NamedDecl *SourceParam) {
    // Nothing to do when both sides already use the same name.
    if (DestParam->getName() == SourceParam->getName())
      return;
    std::string NewName;
    // Unnamed parameter in Dest: record its own location as the insertion
    // point. NOTE(review): presumably findExplicitReferences reports nothing
    // for an unnamed decl, hence the explicit entry. The leading space is
    // presumably there so the inserted name does not fuse with the preceding
    // token (e.g. `int` + `x` -> `intx`).
    if (DestParam->getName().empty()) {
      RefLocs[DestParam].push_back(DestParam->getLocation());
      NewName = " ";
    }
    NewName.append(std::string(SourceParam->getName()));
    ParamToNewName[DestParam->getCanonicalDecl()] = std::move(NewName);
  };
  // Template parameters: Dest and Source must agree on being templates and on
  // the number of template parameters.
  auto *DestTempl = Dest->getDescribedFunctionTemplate();
  auto *SourceTempl = Source->getDescribedFunctionTemplate();
  assert(bool(DestTempl) == bool(SourceTempl));
  if (DestTempl) {
    const auto *DestTPL = DestTempl->getTemplateParameters();
    const auto *SourceTPL = SourceTempl->getTemplateParameters();
    assert(DestTPL->size() == SourceTPL->size());
    for (size_t I = 0, EP = DestTPL->size(); I != EP; ++I)
      HandleParam(DestTPL->getParam(I), SourceTPL->getParam(I));
  }
  // Function parameters, matched by position.
  assert(Dest->param_size() == Source->param_size());
  for (size_t I = 0, E = Dest->param_size(); I != E; ++I)
    HandleParam(Dest->getParamDecl(I), Source->getParamDecl(I));
  const SourceManager &SM = Dest->getASTContext().getSourceManager();
  const LangOptions &LangOpts = Dest->getASTContext().getLangOpts();
  // Collect references to the renamed parameters inside Dest's declaration.
  // The template decl is visited when present so that references to template
  // parameters are found as well.
  findExplicitReferences(
      DestTempl ? llvm::dyn_cast<Decl>(DestTempl) : llvm::dyn_cast<Decl>(Dest),
      [&](ReferenceLoc Ref) {
        // Ambiguous references (zero or several targets) are ignored.
        if (Ref.Targets.size() != 1)
          return;
        const auto *Target =
            llvm::cast<NamedDecl>(Ref.Targets.front()->getCanonicalDecl());
        auto It = ParamToNewName.find(Target);
        if (It == ParamToNewName.end())
          return;
        RefLocs[Target].push_back(Ref.NameLoc);
      },
      Resolver);
  tooling::Replacements Replacements;
  for (auto &Entry : RefLocs) {
    const auto *OldDecl = Entry.first;
    llvm::StringRef OldName = OldDecl->getName();
    llvm::StringRef NewName = ParamToNewName[OldDecl];
    for (SourceLocation RefLoc : Entry.second) {
      CharSourceRange ReplaceRange;
      // Unnamed parameter: empty char range, i.e. a pure insertion.
      // Named parameter: replace the name token at RefLoc.
      if (OldName.empty())
        ReplaceRange = CharSourceRange::getCharRange(RefLoc, RefLoc);
      else
        ReplaceRange = CharSourceRange::getTokenRange(RefLoc, RefLoc);
      // Map locations inside macro expansions back to a file range; give up
      // when that is impossible (e.g. the spelling lives in a macro body).
      if (RefLoc.isMacroID()) {
        ReplaceRange = Lexer::makeFileCharRange(ReplaceRange, SM, LangOpts);
        if (ReplaceRange.isInvalid()) {
          auto Err = error("Cant rename parameter inside macro body.");
          elog("define inline: {0}", Err);
          return std::move(Err);
        }
      }
      if (auto Err = Replacements.add(
              tooling::Replacement(SM, ReplaceRange, NewName))) {
        elog("define inline: Couldn't replace parameter name for {0} to {1}: "
             "{2}",
             OldName, NewName, Err);
        return std::move(Err);
      }
    }
  }
  return Replacements;
}
/// Returns the declaration that the body of \p FD should be moved into.
///
/// Normally this is FD's canonical declaration. For a function template
/// specialization whose canonical declaration is a different decl (per the
/// assertion below, the template declaration), it is instead the declaration
/// in FD's redeclaration chain that immediately follows the canonical one.
///
/// If the chain ends without reaching the canonical declaration, that is an
/// invariant violation and the assertion fires in debug builds. Release builds
/// return FD itself rather than dereferencing a null pointer; the tweak's
/// prepare() treats a target equal to the selected function as
/// "nothing to inline into".
const FunctionDecl *findTarget(const FunctionDecl *FD) {
  const FunctionDecl *CanonDecl = FD->getCanonicalDecl();
  if (!FD->isFunctionTemplateSpecialization() || CanonDecl == FD)
    return CanonDecl;
  // Walk back along the redeclaration chain until the next step would reach
  // the canonical declaration; the current decl is then the target.
  const FunctionDecl *Cur = FD;
  while (const FunctionDecl *Prev = Cur->getPreviousDecl()) {
    if (Prev == CanonDecl)
      return Cur;
    Cur = Prev;
  }
  assert(false && "Found specialization without template decl");
  return FD;
}
/// Returns where the declaration of \p FD starts.
///
/// For a function that describes a function template, this is the begin
/// location of that FunctionTemplateDecl (presumably covering the
/// `template <...>` header). Otherwise it is FD's own begin location.
/// Returned by plain value: a top-level `const` on a by-value return type has
/// no effect for callers.
SourceLocation getBeginLoc(const FunctionDecl *FD) {
  if (const auto *FTD = FD->getDescribedFunctionTemplate())
    return FTD->getBeginLoc();
  return FD->getBeginLoc();
}
std::optional<tooling::Replacement>
addInlineIfInHeader(const FunctionDecl *FD) {
if (FD->isInlined() || llvm::isa<CXXMethodDecl>(FD))
return std::nullopt;
if (FD->isTemplated() && !FD->isFunctionTemplateSpecialization())
return std::nullopt;
const SourceManager &SM = FD->getASTContext().getSourceManager();
llvm::StringRef FileName = SM.getFilename(FD->getLocation());
if (!isHeaderFile(FileName, FD->getASTContext().getLangOpts()))
return std::nullopt;
return tooling::Replacement(SM, FD->getInnerLocStart(), 0, "inline ");
}
/// Code action that moves the body of the selected function definition into
/// another declaration of the same function (the "target", see findTarget)
/// and deletes the original definition.
///
/// While doing so it:
///   - qualifies namespace-scope names used in the body for the target's
///     context,
///   - renames the target's parameters to the definition's names, and
///   - inserts `inline` when the target is declared in a header.
class DefineInline : public Tweak {
public:
  const char *id() const final;
  llvm::StringLiteral kind() const override {
    return CodeAction::REFACTOR_KIND;
  }
  std::string title() const override {
    return "Move function body to declaration";
  }

  /// Decides availability and records Source/Target for apply().
  ///
  /// All of the following must hold:
  ///   - the selection designates a function, or its fully selected body,
  ///     that has a body;
  ///   - it is not a method of a templated class;
  ///   - its body does not start or end inside a macro expansion;
  ///   - a target declaration distinct from it exists;
  ///   - every declaration returned by getNonLocalDeclRefs for it is visible
  ///     at the target.
  bool prepare(const Selection &Sel) override {
    const SelectionTree::Node *SelNode = Sel.ASTSelection.commonAncestor();
    if (!SelNode)
      return false;
    Source = getSelectedFunction(SelNode);
    if (!Source || !Source->hasBody())
      return false;
    // Methods of templated classes are not supported.
    // NOTE(review): presumably because the enclosing class's template
    // parameters cannot be reconciled by renameParameters -- confirm.
    if (auto *MD = llvm::dyn_cast<CXXMethodDecl>(Source)) {
      if (MD->getParent()->isTemplated())
        return false;
    }
    // Refuse to move a body that begins or ends inside a macro expansion.
    if (Source->getBody()->getBeginLoc().isMacroID() ||
        Source->getBody()->getEndLoc().isMacroID())
      return false;
    Target = findTarget(Source);
    // Source is the only candidate, so there is no other declaration to move
    // the body into.
    if (Target == Source) {
      return false;
    }
    // Everything the body refers to must already be usable at the target.
    if (!checkDeclsAreVisible(getNonLocalDeclRefs(*Sel.AST, Source), Target,
                              Sel.AST->getSourceManager()))
      return false;
    return true;
  }

  /// Builds the edits.
  ///
  /// In the target's file:
  ///   - the target's `;` is replaced by the qualified body,
  ///   - parameters are renamed, and
  ///   - `inline` is added if needed.
  ///
  /// The original definition (starting at its template header, if any) is
  /// deleted. The deletion goes into the same edit when it is written in the
  /// target's file, otherwise into a separate edit for the cursor's file.
  Expected<Effect> apply(const Selection &Sel) override {
    const auto &AST = Sel.AST->getASTContext();
    const auto &SM = AST.getSourceManager();
    auto Semicolon = getSemicolonForDecl(Target);
    if (!Semicolon)
      return error("Couldn't find semicolon for target declaration.");
    auto AddInlineIfNecessary = addInlineIfInHeader(Target);
    auto ParamReplacements =
        renameParameters(Target, Source, Sel.AST->getHeuristicResolver());
    if (!ParamReplacements)
      return ParamReplacements.takeError();
    auto QualifiedBody =
        qualifyAllDecls(Source, Target, Sel.AST->getHeuristicResolver());
    if (!QualifiedBody)
      return QualifiedBody.takeError();
    // Target-file edits: `;` -> body, parameter renames, optional `inline`.
    const tooling::Replacement SemicolonToFuncBody(SM, *Semicolon, 1,
                                                   *QualifiedBody);
    tooling::Replacements TargetFileReplacements(SemicolonToFuncBody);
    // NOTE(review): Replacements::merge interprets its argument as applying
    // to the code *after* the existing replacements. This is presumably fine
    // because every parameter rename lies before the semicolon -- confirm.
    TargetFileReplacements = TargetFileReplacements.merge(*ParamReplacements);
    if (AddInlineIfNecessary) {
      if (auto Err = TargetFileReplacements.add(*AddInlineIfNecessary))
        return std::move(Err);
    }
    // File range of the whole definition, starting at getBeginLoc(Source)
    // (the template header for function templates), taken through its
    // macro expansion range.
    auto DefRange = toHalfOpenFileRange(
        SM, AST.getLangOpts(),
        SM.getExpansionRange(CharSourceRange::getCharRange(getBeginLoc(Source),
                                                           Source->getEndLoc()))
            .getAsRange());
    if (!DefRange)
      return error("Couldn't get range for the source.");
    unsigned int SourceLen = SM.getFileOffset(DefRange->getEnd()) -
                             SM.getFileOffset(DefRange->getBegin());
    const tooling::Replacement DeleteFuncBody(SM, DefRange->getBegin(),
                                              SourceLen, "");
    llvm::SmallVector<std::pair<std::string, Edit>> Edits;
    // First edit: the file containing the target's semicolon.
    auto FE = Effect::fileEdit(SM, SM.getFileID(*Semicolon),
                               std::move(TargetFileReplacements));
    if (!FE)
      return FE.takeError();
    Edits.push_back(std::move(*FE));
    if (!SM.isWrittenInSameFile(DefRange->getBegin(),
                                SM.getExpansionLoc(Target->getBeginLoc()))) {
      // Definition is in a different file: delete it with a separate edit
      // for the file containing the cursor.
      auto FE = Effect::fileEdit(SM, SM.getFileID(Sel.Cursor),
                                 tooling::Replacements(DeleteFuncBody));
      if (!FE)
        return FE.takeError();
      Edits.push_back(std::move(*FE));
    } else {
      // Same file: fold the deletion into the target's edit.
      if (auto Err = Edits.front().second.Replacements.add(DeleteFuncBody))
        return std::move(Err);
    }
    Effect E;
    for (auto &Pair : Edits)
      E.ApplyEdits.try_emplace(std::move(Pair.first), std::move(Pair.second));
    return E;
  }

private:
  /// The selected definition whose body is moved.
  const FunctionDecl *Source = nullptr;
  /// The declaration the body is moved into (computed by findTarget).
  const FunctionDecl *Target = nullptr;
};
REGISTER_TWEAK(DefineInline)
}
}
}