#include "AST.h"
#include "FindTarget.h"
#include "ParsedAST.h"
#include "Selection.h"
#include "SourceCode.h"
#include "refactor/Tweak.h"
#include "support/Logger.h"
#include "clang/AST/ASTContext.h"
#include "clang/AST/Decl.h"
#include "clang/AST/DeclBase.h"
#include "clang/AST/NestedNameSpecifier.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/AST/Stmt.h"
#include "clang/Basic/LangOptions.h"
#include "clang/Basic/SourceLocation.h"
#include "clang/Basic/SourceManager.h"
#include "clang/Tooling/Core/Replacement.h"
#include "clang/Tooling/Refactoring/Extract/SourceExtraction.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/raw_os_ostream.h"
#include <optional>
namespace clang {
namespace clangd {
namespace {
// Shorthand for the selection-tree node type used throughout this tweak.
using Node = SelectionTree::Node;
// Position of a declaration or reference relative to the extraction zone:
// before it, inside it, or after it (within the enclosing function), or
// outside the enclosing function altogether.
enum class ZoneRelative {
Before,
Inside,
After,
OutsideFunc
};
// Which form of the new function NewFunction::renderDeclaration() produces.
enum FunctionDeclKind {
// A definition emitted without a separate forward declaration.
InlineDefinition,
// A declaration only, terminated by ";".
ForwardDeclaration,
// A definition whose name carries the enclosing function's qualifier and
// which omits "static".
OutOfLineDefinition
};
// Returns true if N can be a "root statement" of the extraction zone. N must
// be a statement that is not partially selected, and it must be either
// selected, or an unselected DeclStmt. DeclStmts are accepted while
// unselected, presumably because their VarDecls claim the selection.
bool isRootStmt(const Node *N) {
  const bool IsStatement = N->ASTNode.get<Stmt>() != nullptr;
  if (!IsStatement || N->Selected == SelectionTree::Partial)
    return false;
  return N->Selected != SelectionTree::Unselected ||
         N->ASTNode.get<DeclStmt>() != nullptr;
}
// Returns the node whose children are the root statements of the selection,
// or nullptr if the selection cannot be extracted.
//
// - Unselected common ancestor (e.g. a block whose braces are outside the
//   selection): the ancestor itself is the parent.
// - Partially selected common ancestor: rejected.
// - Completely selected common ancestor: it is itself a root statement, so
//   its parent is used. A DeclStmt parent is skipped, because DeclStmts are
//   accepted as root statements even though they are reported as unselected.
//
// Every child of the returned node must satisfy isRootStmt(). Missing
// parents (top of the tree) yield nullptr instead of a null dereference.
const Node *getParentOfRootStmts(const Node *CommonAnc) {
  if (!CommonAnc)
    return nullptr;
  const Node *Parent = nullptr;
  switch (CommonAnc->Selected) {
  case SelectionTree::Selection::Unselected:
    Parent = CommonAnc;
    break;
  case SelectionTree::Selection::Partial:
    return nullptr;
  case SelectionTree::Selection::Complete:
    Parent = CommonAnc->Parent;
    if (Parent && Parent->ASTNode.get<DeclStmt>())
      Parent = Parent->Parent;
    break;
  }
  if (!Parent)
    return nullptr;
  return llvm::all_of(Parent->Children, isRootStmt) ? Parent : nullptr;
}
// The region of the enclosing function that was selected for extraction.
struct ExtractionZone {
  // Node whose children are the root statements being extracted.
  const Node *Parent = nullptr;
  // File range covering the root statements (built by findZoneRange and
  // possibly adjusted later by the semicolon policy).
  SourceRange ZoneRange;
  // Function that contains the zone.
  const FunctionDecl *EnclosingFunction = nullptr;
  // File range of EnclosingFunction.
  SourceRange EnclosingFuncRange;
  // The statements selected for extraction (the children of Parent).
  llvm::DenseSet<const Stmt *> RootStmts;

  // The new function is inserted at the start of the enclosing function.
  SourceLocation getInsertionPoint() const {
    return EnclosingFuncRange.getBegin();
  }
  bool isRootStmt(const Stmt *S) const;
  // Requires Parent to be non-null and to have at least one child.
  const Node *getLastRootStmt() const { return Parent->Children.back(); }

  // Returns true if a declaration introduced inside the zone is referenced
  // after the end of the zone. Such a declaration would have to be hoisted
  // out of the extracted function, which is not supported.
  //
  // Traverses every top-level statement of the enclosing function's body
  // that does not end before the zone does, so it may be expensive.
  bool requiresHoisting(const SourceManager &SM,
                        const HeuristicResolver *Resolver) const {
    // Collect declarations introduced by the root statements.
    llvm::SmallSet<const Decl *, 1> DeclsInExtZone;
    for (auto *RootStmt : RootStmts) {
      findExplicitReferences(
          RootStmt,
          [&DeclsInExtZone](const ReferenceLoc &Loc) {
            // Guard against declaration references without targets;
            // front() on an empty Targets vector would be undefined.
            if (!Loc.IsDecl || Loc.Targets.empty())
              return;
            DeclsInExtZone.insert(Loc.Targets.front());
          },
          Resolver);
    }
    // Nothing declared in the zone: skip the expensive traversal below.
    if (DeclsInExtZone.empty())
      return false;
    for (const auto *S : EnclosingFunction->getBody()->children()) {
      // Statements that end before the zone ends cannot hold a post-zone use.
      if (SM.isBeforeInTranslationUnit(S->getSourceRange().getEnd(),
                                       ZoneRange.getEnd()))
        continue;
      bool HasPostUse = false;
      findExplicitReferences(
          S,
          [&](const ReferenceLoc &Loc) {
            // Ignore references located before the end of the zone.
            if (HasPostUse ||
                SM.isBeforeInTranslationUnit(Loc.NameLoc, ZoneRange.getEnd()))
              return;
            HasPostUse = llvm::any_of(Loc.Targets,
                                      [&DeclsInExtZone](const Decl *Target) {
                                        return DeclsInExtZone.contains(Target);
                                      });
          },
          Resolver);
      if (HasPostUse)
        return true;
    }
    return false;
  }
};
bool alwaysReturns(const ExtractionZone &EZ) {
const Stmt *Last = EZ.getLastRootStmt()->ASTNode.get<Stmt>();
while (const auto *CS = llvm::dyn_cast<CompoundStmt>(Last)) {
if (CS->body_empty())
return false;
Last = CS->body_back();
}
return llvm::isa<ReturnStmt>(Last);
}
// True if S is one of the statements selected for extraction.
bool ExtractionZone::isRootStmt(const Stmt *S) const {
  return RootStmts.count(S) != 0;
}
// Walks up from CommonAnc to the nearest enclosing FunctionDecl and returns
// it if extraction from it is supported. Returns nullptr in these cases:
//   - a LambdaExpr is met first;
//   - the function is templated;
//   - the function has no body;
//   - any top-level statement of the body is null.
const FunctionDecl *findEnclosingFunction(const Node *CommonAnc) {
  const Node *Cur = CommonAnc;
  while (Cur) {
    if (Cur->ASTNode.get<LambdaExpr>())
      return nullptr;
    const FunctionDecl *Func = Cur->ASTNode.get<FunctionDecl>();
    if (!Func) {
      Cur = Cur->Parent;
      continue;
    }
    if (Func->isTemplated())
      return nullptr;
    const Stmt *Body = Func->getBody();
    if (!Body)
      return nullptr;
    const bool HasNullChild = llvm::any_of(
        Body->children(), [](const Stmt *Child) { return Child == nullptr; });
    return HasNullChild ? nullptr : Func;
  }
  return nullptr;
}
// Returns the file range from the start of Parent's first child to the end of
// Parent's last child. Returns std::nullopt if either child's range cannot be
// mapped by toHalfOpenFileRange.
std::optional<SourceRange> findZoneRange(const Node *Parent,
                                         const SourceManager &SM,
                                         const LangOptions &LangOpts) {
  auto FirstChild = toHalfOpenFileRange(
      SM, LangOpts, Parent->Children.front()->ASTNode.getSourceRange());
  if (!FirstChild)
    return std::nullopt;
  auto LastChild = toHalfOpenFileRange(
      SM, LangOpts, Parent->Children.back()->ASTNode.getSourceRange());
  if (!LastChild)
    return std::nullopt;
  return SourceRange(FirstChild->getBegin(), LastChild->getEnd());
}
std::optional<SourceRange>
computeEnclosingFuncRange(const FunctionDecl *EnclosingFunction,
const SourceManager &SM,
const LangOptions &LangOpts) {
return toHalfOpenFileRange(SM, LangOpts, EnclosingFunction->getSourceRange());
}
bool validSingleChild(const Node *Child, const FunctionDecl *EnclosingFunc) {
if (Child->ASTNode.get<Expr>())
return false;
assert(EnclosingFunc->hasBody() &&
"We should always be extracting from a function body.");
if (Child->ASTNode.get<Stmt>() == EnclosingFunc->getBody())
return false;
return true;
}
std::optional<ExtractionZone> findExtractionZone(const Node *CommonAnc,
const SourceManager &SM,
const LangOptions &LangOpts) {
ExtractionZone ExtZone;
ExtZone.Parent = getParentOfRootStmts(CommonAnc);
if (!ExtZone.Parent || ExtZone.Parent->Children.empty())
return std::nullopt;
ExtZone.EnclosingFunction = findEnclosingFunction(ExtZone.Parent);
if (!ExtZone.EnclosingFunction)
return std::nullopt;
if (ExtZone.Parent->Children.size() == 1 &&
!validSingleChild(ExtZone.getLastRootStmt(), ExtZone.EnclosingFunction))
return std::nullopt;
if (auto FuncRange =
computeEnclosingFuncRange(ExtZone.EnclosingFunction, SM, LangOpts))
ExtZone.EnclosingFuncRange = *FuncRange;
if (auto ZoneRange = findZoneRange(ExtZone.Parent, SM, LangOpts))
ExtZone.ZoneRange = *ZoneRange;
if (ExtZone.EnclosingFuncRange.isInvalid() || ExtZone.ZoneRange.isInvalid())
return std::nullopt;
for (const Node *Child : ExtZone.Parent->Children)
ExtZone.RootStmts.insert(Child->ASTNode.get<Stmt>());
return ExtZone;
}
// Everything needed to render the extracted function: its signature, the
// range of its body, where to insert its definition and, when the enclosing
// function is out-of-line, where to insert a forward declaration.
struct NewFunction {
// A parameter of the new function. Parameters are sorted by OrderPriority
// (see operator<).
struct Parameter {
std::string Name;
QualType TypeInfo;
// When true the parameter is rendered as "Type &Name", otherwise as
// "Type Name".
bool PassByReference;
unsigned OrderPriority;
std::string render(const DeclContext *Context) const;
bool operator<(const Parameter &Other) const {
return OrderPriority < Other.OrderPriority;
}
};
std::string Name = "extracted";
QualType ReturnType;
std::vector<Parameter> Parameters;
// Range whose source text becomes the new body; it is also the range
// replaced by the call.
SourceRange BodyRange;
// Where the new definition is inserted.
SourceLocation DefinitionPoint;
// Set only when a forward declaration must be emitted; where to insert it.
std::optional<SourceLocation> ForwardDeclarationPoint;
// NOTE(review): assigned in captureMethodInfo() but not read in the code
// shown here.
const CXXRecordDecl *EnclosingClass = nullptr;
// Qualifier printed before Name in out-of-line definitions.
const NestedNameSpecifier *DefinitionQualifier = nullptr;
// Contexts for printing types. Parameter types are printed relative to the
// semantic context; the return type is printed relative to the syntactic
// context of the emitted declaration.
const DeclContext *SemanticDC = nullptr;
const DeclContext *SyntacticDC = nullptr;
const DeclContext *ForwardDeclarationSyntacticDC = nullptr;
// When true, the call is rendered as "return <call>".
bool CallerReturnsValue = false;
// Copied from the enclosing method; renders "static" (not on out-of-line
// definitions).
bool Static = false;
ConstexprSpecKind Constexpr = ConstexprSpecKind::Unspecified;
// Copied from the enclosing method; renders a trailing "const".
bool Const = false;
tooling::ExtractionSemicolonPolicy SemicolonPolicy;
// Used by renderDeclarationName() to print DefinitionQualifier.
const LangOptions *LangOpts;
NewFunction(tooling::ExtractionSemicolonPolicy SemicolonPolicy,
const LangOptions *LangOpts)
: SemicolonPolicy(SemicolonPolicy), LangOpts(LangOpts) {}
// Text replacing the extracted code, e.g. "extracted(a, b);".
std::string renderCall() const;
// Full text of a declaration or definition of kind K.
std::string renderDeclaration(FunctionDeclKind K,
const DeclContext &SemanticDC,
const DeclContext &SyntacticDC,
const SourceManager &SM) const;
private:
std::string
renderParametersForDeclaration(const DeclContext &Enclosing) const;
std::string renderParametersForCall() const;
std::string renderSpecifiers(FunctionDeclKind K) const;
std::string renderQualifiers() const;
std::string renderDeclarationName(FunctionDeclKind K) const;
std::string getFuncBody(const SourceManager &SM) const;
};
// Comma-separated parameter declarations (each rendered by
// Parameter::render), with types printed relative to Enclosing.
std::string NewFunction::renderParametersForDeclaration(
    const DeclContext &Enclosing) const {
  std::string Rendered;
  const char *Separator = "";
  for (const Parameter &Param : Parameters) {
    Rendered += Separator;
    Rendered += Param.render(&Enclosing);
    Separator = ", ";
  }
  return Rendered;
}
// Comma-separated parameter names, used as the call's argument list.
std::string NewFunction::renderParametersForCall() const {
  std::string Arguments;
  const char *Separator = "";
  for (const Parameter &Param : Parameters) {
    Arguments += Separator;
    Arguments += Param.Name;
    Separator = ", ";
  }
  return Arguments;
}
// Specifiers emitted before the return type:
//   - "static " when Static is set, except on out-of-line definitions;
//   - then "constexpr " or "consteval " according to Constexpr.
// Unspecified and Constinit produce nothing.
std::string NewFunction::renderSpecifiers(FunctionDeclKind K) const {
  const bool EmitStatic =
      Static && K != FunctionDeclKind::OutOfLineDefinition;
  std::string Specifiers = EmitStatic ? "static " : "";
  if (Constexpr == ConstexprSpecKind::Constexpr)
    Specifiers += "constexpr ";
  else if (Constexpr == ConstexprSpecKind::Consteval)
    Specifiers += "consteval ";
  return Specifiers;
}
// Trailing function qualifiers: " const" for const methods, else empty.
std::string NewFunction::renderQualifiers() const {
  return Const ? " const" : "";
}
// Name used in a declaration. For out-of-line definitions, the enclosing
// function's qualifier (if any) is printed before Name.
std::string NewFunction::renderDeclarationName(FunctionDeclKind K) const {
  const bool NeedsQualifier =
      DefinitionQualifier != nullptr && K == OutOfLineDefinition;
  if (!NeedsQualifier)
    return Name;
  std::string Qualified;
  {
    // Scoped so the stream is flushed into Qualified before it is returned.
    llvm::raw_string_ostream OS(Qualified);
    DefinitionQualifier->print(OS, *LangOpts);
    OS << Name;
  }
  return Qualified;
}
std::string NewFunction::renderCall() const {
return std::string(
llvm::formatv("{0}{1}({2}){3}", CallerReturnsValue ? "return " : "", Name,
renderParametersForCall(),
(SemicolonPolicy.isNeededInOriginalFunction() ? ";" : "")));
}
// Renders the new function as a declaration or definition of kind K.
// The return type is printed relative to SyntacticDC and parameter types
// relative to SemanticDC.
//   - ForwardDeclaration: "<signature>;\n"
//   - OutOfLineDefinition / InlineDefinition:
//     "<signature> {\n<body>\n}\n"
// (In the formatv pattern below, the lone '{' and '}' are literal braces.)
std::string NewFunction::renderDeclaration(FunctionDeclKind K,
                                           const DeclContext &SemanticDC,
                                           const DeclContext &SyntacticDC,
                                           const SourceManager &SM) const {
  std::string Declaration = std::string(llvm::formatv(
      "{0}{1} {2}({3}){4}", renderSpecifiers(K),
      printType(ReturnType, SyntacticDC), renderDeclarationName(K),
      renderParametersForDeclaration(SemanticDC), renderQualifiers()));
  switch (K) {
  case ForwardDeclaration:
    return std::string(llvm::formatv("{0};\n", Declaration));
  case OutOfLineDefinition:
  case InlineDefinition:
    return std::string(
        llvm::formatv("{0} {\n{1}\n}\n", Declaration, getFuncBody(SM)));
  }
  llvm_unreachable("Unsupported FunctionDeclKind enum");
}
// Source text of BodyRange, plus a trailing ";" when the semicolon policy
// requires one in the extracted function.
std::string NewFunction::getFuncBody(const SourceManager &SM) const {
  std::string Body = toSourceCode(SM, BodyRange).str();
  if (SemicolonPolicy.isNeededInExtractedFunction())
    Body += ';';
  return Body;
}
// Renders "Type &Name" or "Type Name", with Type printed relative to Context.
std::string NewFunction::Parameter::render(const DeclContext *Context) const {
  std::string Text = printType(TypeInfo, *Context);
  Text += PassByReference ? " &" : " ";
  Text += Name;
  return Text;
}
// Result of traversing the enclosing function. Holds location and usage
// information for each declaration, plus control-flow facts about the
// extraction zone.
struct CapturedZoneInfo {
struct DeclInformation {
const Decl *TheDecl;
// Where TheDecl was declared relative to the zone.
ZoneRelative DeclaredIn;
// Order in which TheDecl was first recorded; used to order parameters.
unsigned DeclIndex;
bool IsReferencedInZone = false;
bool IsReferencedInPostZone = false;
DeclInformation(const Decl *TheDecl, ZoneRelative DeclaredIn,
unsigned DeclIndex)
: TheDecl(TheDecl), DeclaredIn(DeclaredIn), DeclIndex(DeclIndex){};
// Records a reference at ReferenceLoc; only Inside and After are tracked.
void markOccurence(ZoneRelative ReferenceLoc);
};
llvm::DenseMap<const Decl *, DeclInformation> DeclInfoMap;
// A ReturnStmt was visited inside the zone.
bool HasReturnStmt = false;
// The zone ends in a return statement (see alwaysReturns()).
bool AlwaysReturns = false;
// A break/continue inside the zone is not nested in a loop/switch that is
// itself inside the zone.
bool BrokenControlFlow = false;
// Adds an entry for D, or returns the existing entry unchanged.
DeclInformation *createDeclInfo(const Decl *D, ZoneRelative RelativeLoc);
// Returns the entry for D, or nullptr if D has not been recorded.
DeclInformation *getDeclInfoFor(const Decl *D);
};
// Records D with location RelativeLoc and returns its entry. A new entry gets
// the map's size before insertion as its DeclIndex, so indices follow
// recording order. If D is already present, its entry is returned unchanged.
CapturedZoneInfo::DeclInformation *
CapturedZoneInfo::createDeclInfo(const Decl *D, ZoneRelative RelativeLoc) {
  const unsigned NextIndex = DeclInfoMap.size();
  auto Slot =
      DeclInfoMap.insert({D, DeclInformation(D, RelativeLoc, NextIndex)}).first;
  return &Slot->second;
}
// Looks up the entry recorded for D; nullptr if there is none.
CapturedZoneInfo::DeclInformation *
CapturedZoneInfo::getDeclInfoFor(const Decl *D) {
  auto Found = DeclInfoMap.find(D);
  return Found != DeclInfoMap.end() ? &Found->second : nullptr;
}
// Notes a reference to this declaration. A reference Inside the zone or
// After it sets the matching flag; other locations are ignored.
void CapturedZoneInfo::DeclInformation::markOccurence(
    ZoneRelative ReferenceLoc) {
  if (ReferenceLoc == ZoneRelative::Inside)
    IsReferencedInZone = true;
  else if (ReferenceLoc == ZoneRelative::After)
    IsReferencedInPostZone = true;
}
// True for for, do, while and range-based for statements.
bool isLoop(const Stmt *S) {
  return isa<ForStmt, DoStmt, WhileStmt, CXXForRangeStmt>(S);
}
// Traverses the enclosing function of ExtZone and collects:
//   - where each declaration lives relative to the zone;
//   - where each declaration is referenced;
//   - whether the zone contains returns or break/continue statements whose
//     target is outside the zone.
CapturedZoneInfo captureZoneInfo(const ExtractionZone &ExtZone) {
class ExtractionZoneVisitor
: public clang::RecursiveASTVisitor<ExtractionZoneVisitor> {
public:
// Constructing the visitor traverses the whole enclosing function.
ExtractionZoneVisitor(const ExtractionZone &ExtZone) : ExtZone(ExtZone) {
TraverseDecl(const_cast<FunctionDecl *>(ExtZone.EnclosingFunction));
}
// Tracks the current location. It is Inside while a root statement is
// being traversed, and After once one has finished; the next root statement
// switches it back to Inside.
bool TraverseStmt(Stmt *S) {
if (!S)
return true;
bool IsRootStmt = ExtZone.isRootStmt(const_cast<const Stmt *>(S));
if (IsRootStmt)
CurrentLocation = ZoneRelative::Inside;
addToLoopSwitchCounters(S, 1);
RecursiveASTVisitor::TraverseStmt(S);
addToLoopSwitchCounters(S, -1);
if (IsRootStmt)
CurrentLocation = ZoneRelative::After;
return true;
}
// Maintains the nesting depth of loops/switches, counted only inside the
// zone, so break/continue targets can be checked.
void addToLoopSwitchCounters(Stmt *S, int Increment) {
if (CurrentLocation != ZoneRelative::Inside)
return;
if (isLoop(S))
CurNumberOfNestedLoops += Increment;
else if (isa<SwitchStmt>(S))
CurNumberOfSwitch += Increment;
}
// Records every declaration with the location at which it is visited.
bool VisitDecl(Decl *D) {
Info.createDeclInfo(D, CurrentLocation);
return true;
}
// Records a reference. A declaration not visited during the traversal is
// presumably declared outside the function, so it is tagged OutsideFunc.
bool VisitDeclRefExpr(DeclRefExpr *DRE) {
const Decl *D = DRE->getDecl();
auto *DeclInfo = Info.getDeclInfoFor(D);
if (!DeclInfo)
DeclInfo = Info.createDeclInfo(D, ZoneRelative::OutsideFunc);
DeclInfo->markOccurence(CurrentLocation);
return true;
}
// NOTE(review): RecursiveASTVisitor presumably also descends into lambda
// bodies, so a return inside a lambda in the zone likely counts here too —
// confirm.
bool VisitReturnStmt(ReturnStmt *Return) {
if (CurrentLocation == ZoneRelative::Inside)
Info.HasReturnStmt = true;
return true;
}
// A break is fine only inside a loop or switch that is also in the zone.
bool VisitBreakStmt(BreakStmt *Break) {
if (CurrentLocation == ZoneRelative::Inside &&
!(CurNumberOfNestedLoops || CurNumberOfSwitch))
Info.BrokenControlFlow = true;
return true;
}
// A continue is fine only inside a loop that is also in the zone.
bool VisitContinueStmt(ContinueStmt *Continue) {
if (CurrentLocation == ZoneRelative::Inside && !CurNumberOfNestedLoops)
Info.BrokenControlFlow = true;
return true;
}
CapturedZoneInfo Info;
const ExtractionZone &ExtZone;
ZoneRelative CurrentLocation = ZoneRelative::Before;
unsigned CurNumberOfNestedLoops = 0;
unsigned CurNumberOfSwitch = 0;
};
ExtractionZoneVisitor Visitor(ExtZone);
CapturedZoneInfo Result = std::move(Visitor.Info);
Result.AlwaysReturns = alwaysReturns(ExtZone);
return Result;
}
// Builds ExtractedFunc.Parameters from declarations that are referenced
// inside the zone but declared in the enclosing function outside the zone.
// Parameters are ordered by declaration index.
//
// Returns false (not extractable) in two cases:
//   - a declaration inside the zone is used after the zone (hoisting is not
//     supported);
//   - a needed declaration is not a non-function ValueDecl.
bool createParameters(NewFunction &ExtractedFunc,
                      const CapturedZoneInfo &CapturedInfo) {
  for (const auto &Entry : CapturedInfo.DeclInfoMap) {
    const auto &Info = Entry.second;
    const bool DeclaredInside = Info.DeclaredIn == ZoneRelative::Inside;
    // Declared in the zone but used after it: would require hoisting.
    if (DeclaredInside && Info.IsReferencedInPostZone)
      return false;
    // Unused in the zone, or declared inside it / outside the function:
    // no parameter needed.
    const bool NoParamNeeded =
        DeclaredInside || Info.DeclaredIn == ZoneRelative::OutsideFunc;
    if (!Info.IsReferencedInZone || NoParamNeeded)
      continue;
    const auto *Value = dyn_cast_or_null<ValueDecl>(Info.TheDecl);
    if (!Value || isa<FunctionDecl>(Info.TheDecl))
      return false;
    // Every parameter is currently passed by (non-const) reference.
    ExtractedFunc.Parameters.push_back(
        {std::string(Value->getName()), Value->getType().getNonReferenceType(),
         /*PassByReference=*/true, Info.DeclIndex});
  }
  llvm::sort(ExtractedFunc.Parameters);
  return true;
}
// Computes where a ";" is needed, based on the zone's last root statement.
// ExtractionSemicolonPolicy::compute receives a range that ends one character
// before the zone's end; its end is read back afterwards (presumably compute
// may adjust it). ExtZone.ZoneRange's end is then set one character past
// that end.
tooling::ExtractionSemicolonPolicy
getSemicolonPolicy(ExtractionZone &ExtZone, const SourceManager &SM,
                   const LangOptions &LangOpts) {
  const Stmt *LastStmt = ExtZone.getLastRootStmt()->ASTNode.get<Stmt>();
  SourceRange InclusiveRange(ExtZone.ZoneRange.getBegin(),
                             ExtZone.ZoneRange.getEnd().getLocWithOffset(-1));
  auto Policy = tooling::ExtractionSemicolonPolicy::compute(
      LastStmt, InclusiveRange, SM, LangOpts);
  ExtZone.ZoneRange.setEnd(InclusiveRange.getEnd().getLocWithOffset(1));
  return Policy;
}
// Sets ExtractedFunc.ReturnType.
//   - No return in the zone: the new function returns void.
//   - Otherwise the zone must always return, and the new function takes the
//     enclosing function's return type, which must not be dependent.
// Returns false when neither case applies.
bool generateReturnProperties(NewFunction &ExtractedFunc,
                              const FunctionDecl &EnclosingFunc,
                              const CapturedZoneInfo &CapturedInfo) {
  if (!CapturedInfo.HasReturnStmt) {
    ExtractedFunc.ReturnType = EnclosingFunc.getParentASTContext().VoidTy;
    return true;
  }
  if (!CapturedInfo.AlwaysReturns)
    return false;
  const QualType EnclosingReturn = EnclosingFunc.getReturnType();
  if (EnclosingReturn->isDependentType())
    return false;
  ExtractedFunc.ReturnType = EnclosingReturn;
  return true;
}
// Copies static/const-ness and the owning class from the enclosing method.
void captureMethodInfo(NewFunction &ExtractedFunc,
                       const CXXMethodDecl *Method) {
  ExtractedFunc.EnclosingClass = Method->getParent();
  ExtractedFunc.Const = Method->isConst();
  ExtractedFunc.Static = Method->isStatic();
}
// Computes every property of the function to extract from ExtZone.
// getSemicolonPolicy() updates ExtZone.ZoneRange, so BodyRange is read after
// that call.
//
// Returns an error in these cases:
//   - a break/continue in the zone targets a statement outside it;
//   - the enclosing function is out-of-line and its canonical declaration
//     cannot be mapped to a file range;
//   - parameters or the return type cannot be determined.
llvm::Expected<NewFunction> getExtractedFunction(ExtractionZone &ExtZone,
const SourceManager &SM,
const LangOptions &LangOpts) {
CapturedZoneInfo CapturedInfo = captureZoneInfo(ExtZone);
if (CapturedInfo.BrokenControlFlow)
return error("Cannot extract break/continue without corresponding "
"loop/switch statement.");
NewFunction ExtractedFunc(getSemicolonPolicy(ExtZone, SM, LangOpts),
&LangOpts);
ExtractedFunc.SyntacticDC =
ExtZone.EnclosingFunction->getLexicalDeclContext();
ExtractedFunc.SemanticDC = ExtZone.EnclosingFunction->getDeclContext();
ExtractedFunc.DefinitionQualifier = ExtZone.EnclosingFunction->getQualifier();
ExtractedFunc.Constexpr = ExtZone.EnclosingFunction->getConstexprKind();
if (const auto *Method =
llvm::dyn_cast<CXXMethodDecl>(ExtZone.EnclosingFunction))
captureMethodInfo(ExtractedFunc, Method);
// For an out-of-line enclosing function, a forward declaration of the new
// function is placed before the enclosing function's canonical declaration.
if (ExtZone.EnclosingFunction->isOutOfLine()) {
const auto *FirstOriginalDecl =
ExtZone.EnclosingFunction->getCanonicalDecl();
auto DeclPos =
toHalfOpenFileRange(SM, LangOpts, FirstOriginalDecl->getSourceRange());
if (!DeclPos)
return error("Declaration is inside a macro");
ExtractedFunc.ForwardDeclarationPoint = DeclPos->getBegin();
ExtractedFunc.ForwardDeclarationSyntacticDC = ExtractedFunc.SemanticDC;
}
// Must follow getSemicolonPolicy(), which adjusts ZoneRange.
ExtractedFunc.BodyRange = ExtZone.ZoneRange;
ExtractedFunc.DefinitionPoint = ExtZone.getInsertionPoint();
ExtractedFunc.CallerReturnsValue = CapturedInfo.AlwaysReturns;
if (!createParameters(ExtractedFunc, CapturedInfo) ||
!generateReturnProperties(ExtractedFunc, *ExtZone.EnclosingFunction,
CapturedInfo))
return error("Too complex to extract.");
return ExtractedFunc;
}
// Code action that moves the selected statements into a new function. The
// new function is inserted before the enclosing function, and the selection
// is replaced with a call to it.
class ExtractFunction : public Tweak {
public:
const char *id() const final;
bool prepare(const Selection &Inputs) override;
Expected<Effect> apply(const Selection &Inputs) override;
std::string title() const override { return "Extract to function"; }
llvm::StringLiteral kind() const override {
return CodeAction::REFACTOR_KIND;
}
private:
// Zone computed by prepare() and consumed by apply().
ExtractionZone ExtZone;
};
REGISTER_TWEAK(ExtractFunction)
// Replacement that swaps the extracted body range (a character range) for a
// call to the new function.
tooling::Replacement replaceWithFuncCall(const NewFunction &ExtractedFunc,
                                         const SourceManager &SM,
                                         const LangOptions &LangOpts) {
  const CharSourceRange Target =
      CharSourceRange::getCharRange(ExtractedFunc.BodyRange);
  const std::string CallText = ExtractedFunc.renderCall();
  return tooling::Replacement(SM, Target, CallText, LangOpts);
}
// Inserts the new function's definition at DefinitionPoint. The definition
// is out-of-line when a forward declaration is also emitted, inline otherwise.
tooling::Replacement createFunctionDefinition(const NewFunction &ExtractedFunc,
                                              const SourceManager &SM) {
  const FunctionDeclKind Kind = ExtractedFunc.ForwardDeclarationPoint
                                    ? OutOfLineDefinition
                                    : InlineDefinition;
  const std::string Definition = ExtractedFunc.renderDeclaration(
      Kind, *ExtractedFunc.SemanticDC, *ExtractedFunc.SyntacticDC, SM);
  return tooling::Replacement(SM, ExtractedFunc.DefinitionPoint, 0,
                              Definition);
}
// Inserts a forward declaration at ForwardDeclarationPoint, which must be
// set. The return type is printed relative to ForwardDeclarationSyntacticDC.
tooling::Replacement createForwardDeclaration(const NewFunction &ExtractedFunc,
                                              const SourceManager &SM) {
  const SourceLocation InsertAt = *ExtractedFunc.ForwardDeclarationPoint;
  const std::string Declaration = ExtractedFunc.renderDeclaration(
      ForwardDeclaration, *ExtractedFunc.SemanticDC,
      *ExtractedFunc.ForwardDeclarationSyntacticDC, SM);
  return tooling::Replacement(SM, InsertAt, 0, Declaration);
}
// True if any root statement of the zone contains a ReturnStmt.
bool hasReturnStmt(const ExtractionZone &ExtZone) {
  class ReturnFinder : public clang::RecursiveASTVisitor<ReturnFinder> {
  public:
    bool VisitReturnStmt(ReturnStmt *) {
      Seen = true;
      return false; // Abort the traversal at the first return.
    }
    bool Seen = false;
  };
  for (const Stmt *Root : ExtZone.RootStmts) {
    ReturnFinder Finder;
    Finder.TraverseStmt(const_cast<Stmt *>(Root));
    if (Finder.Seen)
      return true;
  }
  return false;
}
bool ExtractFunction::prepare(const Selection &Inputs) {
const LangOptions &LangOpts = Inputs.AST->getLangOpts();
if (!LangOpts.CPlusPlus)
return false;
const Node *CommonAnc = Inputs.ASTSelection.commonAncestor();
const SourceManager &SM = Inputs.AST->getSourceManager();
auto MaybeExtZone = findExtractionZone(CommonAnc, SM, LangOpts);
if (!MaybeExtZone ||
(hasReturnStmt(*MaybeExtZone) && !alwaysReturns(*MaybeExtZone)))
return false;
if (MaybeExtZone->requiresHoisting(SM, Inputs.AST->getHeuristicResolver()))
return false;
ExtZone = std::move(*MaybeExtZone);
return true;
}
// Builds the edits: insert the new definition and replace the zone with a
// call. If a forward declaration is needed, it is added to the same edit when
// it belongs in the same file as the definition. Otherwise the result is a
// multi-file effect with a separate edit for the declaration's file.
Expected<Tweak::Effect> ExtractFunction::apply(const Selection &Inputs) {
const SourceManager &SM = Inputs.AST->getSourceManager();
const LangOptions &LangOpts = Inputs.AST->getLangOpts();
auto ExtractedFunc = getExtractedFunction(ExtZone, SM, LangOpts);
if (!ExtractedFunc)
return ExtractedFunc.takeError();
tooling::Replacements Edit;
if (auto Err = Edit.add(createFunctionDefinition(*ExtractedFunc, SM)))
return std::move(Err);
if (auto Err = Edit.add(replaceWithFuncCall(*ExtractedFunc, SM, LangOpts)))
return std::move(Err);
if (auto FwdLoc = ExtractedFunc->ForwardDeclarationPoint) {
if (SM.isWrittenInSameFile(ExtractedFunc->DefinitionPoint, *FwdLoc)) {
if (auto Err = Edit.add(createForwardDeclaration(*ExtractedFunc, SM)))
return std::move(Err);
} else {
// The forward declaration lives in another file (e.g. a header).
auto MultiFileEffect = Effect::mainFileEdit(SM, std::move(Edit));
if (!MultiFileEffect)
return MultiFileEffect.takeError();
tooling::Replacements OtherEdit(
createForwardDeclaration(*ExtractedFunc, SM));
if (auto PathAndEdit = Tweak::Effect::fileEdit(SM, SM.getFileID(*FwdLoc),
OtherEdit))
MultiFileEffect->ApplyEdits.try_emplace(PathAndEdit->first,
PathAndEdit->second);
else
return PathAndEdit.takeError();
return MultiFileEffect;
}
}
return Effect::mainFileEdit(SM, std::move(Edit));
}
}
}
}