#include "clang/Tooling/Refactoring/Extract/SourceExtraction.h"
#include "clang/AST/Stmt.h"
#include "clang/AST/StmtCXX.h"
#include "clang/AST/StmtObjC.h"
#include "clang/Basic/SourceManager.h"
#include "clang/Lex/Lexer.h"
#include <optional>
using namespace clang;
namespace {
bool isSemicolonAtLocation(SourceLocation TokenLoc, const SourceManager &SM,
const LangOptions &LangOpts) {
return Lexer::getSourceText(
CharSourceRange::getTokenRange(TokenLoc, TokenLoc), SM,
LangOpts) == ";";
}
bool isSemicolonRequiredAfter(const Stmt *S) {
if (isa<CompoundStmt>(S))
return false;
if (const auto *If = dyn_cast<IfStmt>(S))
return isSemicolonRequiredAfter(If->getElse() ? If->getElse()
: If->getThen());
if (const auto *While = dyn_cast<WhileStmt>(S))
return isSemicolonRequiredAfter(While->getBody());
if (const auto *For = dyn_cast<ForStmt>(S))
return isSemicolonRequiredAfter(For->getBody());
if (const auto *CXXFor = dyn_cast<CXXForRangeStmt>(S))
return isSemicolonRequiredAfter(CXXFor->getBody());
if (const auto *ObjCFor = dyn_cast<ObjCForCollectionStmt>(S))
return isSemicolonRequiredAfter(ObjCFor->getBody());
if(const auto *Switch = dyn_cast<SwitchStmt>(S))
return isSemicolonRequiredAfter(Switch->getBody());
if(const auto *Case = dyn_cast<SwitchCase>(S))
return isSemicolonRequiredAfter(Case->getSubStmt());
switch (S->getStmtClass()) {
case Stmt::DeclStmtClass:
case Stmt::CXXTryStmtClass:
case Stmt::ObjCAtSynchronizedStmtClass:
case Stmt::ObjCAutoreleasePoolStmtClass:
case Stmt::ObjCAtTryStmtClass:
return false;
default:
return true;
}
}
bool areOnSameLine(SourceLocation Loc1, SourceLocation Loc2,
const SourceManager &SM) {
return !Loc1.isMacroID() && !Loc2.isMacroID() &&
SM.getSpellingLineNumber(Loc1) == SM.getSpellingLineNumber(Loc2);
}
}
namespace clang {
namespace tooling {

/// Computes where a terminating ';' is needed when \p S is extracted.
///
/// The first constructor flag of the result says a ';' is needed in the
/// extracted function, the second says one is needed in the original
/// function. When a ';' directly follows the extracted range on the same
/// line, \p ExtractedRange is extended to include it.
ExtractionSemicolonPolicy
ExtractionSemicolonPolicy::compute(const Stmt *S, SourceRange &ExtractedRange,
                                   const SourceManager &SM,
                                   const LangOptions &LangOpts) {
  // Extracted expressions: ';' only in the extracted function.
  if (isa<Expr>(S))
    return ExtractionSemicolonPolicy(true, false);

  // Every other statement needs a ';' in the original function; the only
  // open question is whether the extracted function needs one as well.
  bool NeededInExtracted = false;

  if (isSemicolonRequiredAfter(S)) {
    const SourceLocation RangeEnd = ExtractedRange.getEnd();
    // If the range already ends on a ';', nothing more is needed.
    if (!isSemicolonAtLocation(RangeEnd, SM, LangOpts)) {
      const std::optional<Token> Following =
          Lexer::findNextToken(RangeEnd, SM, LangOpts);
      const bool SemiFollowsOnSameLine =
          Following && Following->is(tok::semi) &&
          areOnSameLine(Following->getLocation(), RangeEnd, SM);
      if (SemiFollowsOnSameLine) {
        // Pull the trailing ';' into the extracted range.
        ExtractedRange.setEnd(Following->getLocation());
      } else {
        NeededInExtracted = true;
      }
    }
  }

  return ExtractionSemicolonPolicy(NeededInExtracted, true);
}

} // end namespace tooling
} // end namespace clang