#include <assert.h>
#include <algorithm>
#include <filesystem>
#include <limits>
#include <memory>
#include <optional>
#include <regex>
#include <string>
#include <vector>
#include "RawPtrHelpers.h"
#include "RawPtrManualPathsToIgnore.h"
#include "SeparateRepositoryPaths.h"
#include "clang/AST/ASTContext.h"
#include "clang/ASTMatchers/ASTMatchFinder.h"
#include "clang/ASTMatchers/ASTMatchers.h"
#include "clang/ASTMatchers/ASTMatchersMacros.h"
#include "clang/Basic/CharInfo.h"
#include "clang/Basic/SourceLocation.h"
#include "clang/Basic/SourceManager.h"
#include "clang/Frontend/CompilerInstance.h"
#include "clang/Frontend/FrontendActions.h"
#include "clang/Lex/Lexer.h"
#include "clang/Lex/MacroArgs.h"
#include "clang/Lex/PPCallbacks.h"
#include "clang/Lex/Preprocessor.h"
#include "clang/Tooling/CommonOptionsParser.h"
#include "clang/Tooling/Refactoring.h"
#include "clang/Tooling/Tooling.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorOr.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/Path.h"
#include "llvm/Support/TargetSelect.h"
using namespace clang::ast_matchers;
namespace {
// Header paths handed to the rewriters below; each one is emitted in an
// "include-user-header" directive for files in which a field gets rewritten.
constexpr char kRawPtrIncludePath[] = "base/memory/raw_ptr.h";
constexpr char kRawRefIncludePath[] = "base/memory/raw_ref.h";
constexpr char kRawSpanIncludePath[] = "base/memory/raw_span.h";

// Names of tool parameters.
// NOTE(review): their consumers are not visible in this part of the file.
constexpr char kExcludeFieldsParamName[] = "exclude-fields";
constexpr char kOverrideExcludePathsParamName[] = "override-exclude-paths";
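// The tool prints its results in sections delimited by "==== BEGIN <name> ===="
// and "==== END <name> ====" lines. A single section can be extracted from the
// captured tool output with a pipeline along these lines (DELIM is the section
// name):
//     $ cat <captured tool output> \
// | sed '/^==== BEGIN $DELIM ====$/,/^==== END $DELIM ====$/{//!b};d' \
// | sort | uniq > ~/scratch/some-out-of-band-output.txt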
class OutputSectionHelper {
public:
explicit OutputSectionHelper(llvm::StringRef output_delimiter)
: output_delimiter_(output_delimiter.str()) {}
OutputSectionHelper(const OutputSectionHelper&) = delete;
OutputSectionHelper& operator=(const OutputSectionHelper&) = delete;
void Add(llvm::StringRef output_line,
llvm::StringRef tag = "",
llvm::StringRef loc = "") {
llvm::StringSet<>& tags = output_line_to_tags_[output_line];
if (!tag.empty()) {
tags.insert(tag);
}
llvm::StringSet<>& locs = output_line_to_locs_[output_line];
if (!loc.empty()) {
locs.insert(loc);
}
}
void Emit() {
if (output_line_to_tags_.empty())
return;
llvm::outs() << "==== BEGIN " << output_delimiter_ << " ====\n";
for (const llvm::StringRef& output_line :
GetSortedKeys(output_line_to_tags_)) {
llvm::outs() << output_line;
const llvm::StringSet<>& locs = output_line_to_locs_[output_line];
if (!locs.empty()) {
std::vector<llvm::StringRef> sorted_locs = GetSortedKeys(locs);
std::string locs_comment =
llvm::join(sorted_locs.begin(), sorted_locs.end(), ", ");
llvm::outs() << " @ " << locs_comment;
}
const llvm::StringSet<>& tags = output_line_to_tags_[output_line];
if (!tags.empty()) {
std::vector<llvm::StringRef> sorted_tags = GetSortedKeys(tags);
std::string tags_comment =
llvm::join(sorted_tags.begin(), sorted_tags.end(), ", ");
llvm::outs() << " # " << tags_comment;
}
llvm::outs() << "\n";
}
llvm::outs() << "==== END " << output_delimiter_ << " ====\n";
}
private:
template <typename TValue>
static std::vector<llvm::StringRef> GetSortedKeys(
const llvm::StringMap<TValue>& map) {
std::vector<llvm::StringRef> sorted(map.keys().begin(), map.keys().end());
std::sort(sorted.begin(), sorted.end());
return sorted;
}
std::string output_delimiter_;
llvm::StringMap<llvm::StringSet<>> output_line_to_tags_;
llvm::StringMap<llvm::StringSet<>> output_line_to_locs_;
};
// Collects everything the tool wants to report for one source file (edit
// directives and "filtered field" records) and prints it when clang finishes
// processing that file.
//
// Output is suppressed entirely for inputs that are not C++ / Objective-C++
// (see ShouldSuppressOutput()).
class OutputHelper : public clang::tooling::SourceFileCallbacks {
 public:
  OutputHelper()
      : edits_helper_("EDITS"), field_decl_filter_helper_("FIELD FILTERS") {}
  ~OutputHelper() = default;

  OutputHelper(const OutputHelper&) = delete;
  OutputHelper& operator=(const OutputHelper&) = delete;

  // Records a replacement directive of the form
  //   r:::<file path>:::<offset>:::<length>:::<new text>
  // for the character range |replacement_range|. When |include_path| is
  // non-null, an additional
  //   include-user-header:::<file path>:::-1:::-1:::<include path>
  // directive is recorded for the same file.
  //
  // Nothing is recorded when the replacement has no associated file path.
  void AddReplacement(const clang::SourceManager& source_manager,
                      const clang::SourceRange& replacement_range,
                      std::string replacement_text,
                      const char* include_path = nullptr) {
    // |replacement| is only used to compute the file path, offset and length.
    clang::tooling::Replacement replacement(
        source_manager, clang::CharSourceRange::getCharRange(replacement_range),
        replacement_text);

    // Explicit .string(): the implicit path -> std::string conversion only
    // exists on platforms where path::string_type is std::string.
    std::string file_path =
        std::filesystem::proximate(replacement.getFilePath().str()).string();
    if (file_path.empty())
      return;

    // Newlines are replaced with NUL characters so that the directive
    // occupies a single line of output.
    std::replace(replacement_text.begin(), replacement_text.end(), '\n', '\0');
    std::string replacement_directive = llvm::formatv(
        "r:::{0}:::{1}:::{2}:::{3}", file_path, replacement.getOffset(),
        replacement.getLength(), replacement_text);
    edits_helper_.Add(replacement_directive);

    if (include_path) {
      std::string include_directive = llvm::formatv(
          "include-user-header:::{0}:::-1:::-1:::{1}", file_path, include_path);
      edits_helper_.Add(include_directive);
    }
  }

  // Records |field_decl| (by qualified name) in the "FIELD FILTERS" section,
  // tagged with |filter_tag| and annotated with "<file>:<line>:<column>" of
  // the declaration's begin location when a file name is available.
  void AddFilteredField(const clang::SourceManager& source_manager,
                        const clang::FieldDecl& field_decl,
                        llvm::StringRef filter_tag) {
    std::string qualified_name = field_decl.getQualifiedNameAsString();

    clang::SourceLocation loc = field_decl.getBeginLoc();
    // Explicit .string() for the same portability reason as in
    // AddReplacement().
    std::string loc_str =
        std::filesystem::proximate(source_manager.getFilename(loc).str())
            .string();
    if (!loc_str.empty()) {
      loc_str +=
          ":" + std::to_string(source_manager.getSpellingLineNumber(loc));
      loc_str +=
          ":" + std::to_string(source_manager.getSpellingColumnNumber(loc));
    }

    field_decl_filter_helper_.Add(qualified_name, filter_tag, loc_str);
  }

 private:
  // clang::tooling::SourceFileCallbacks override: remembers the language of
  // the (single) input file so that handleEndSource() can decide whether to
  // print anything.
  bool handleBeginSource(clang::CompilerInstance& compiler) override {
    const clang::FrontendOptions& frontend_options = compiler.getFrontendOpts();

    assert((frontend_options.Inputs.size() == 1) &&
           "run_tool.py should invoke the rewriter one file at a time");
    const clang::FrontendInputFile& input_file = frontend_options.Inputs[0];
    assert(input_file.isFile() &&
           "run_tool.py should invoke the rewriter on actual files");

    current_language_ = input_file.getKind().getLanguage();

    return true;  // Report that |handleBeginSource| succeeded.
  }

  // clang::tooling::SourceFileCallbacks override: prints both output sections
  // unless output is suppressed for the current input language.
  void handleEndSource() override {
    if (ShouldSuppressOutput())
      return;

    edits_helper_.Emit();
    field_decl_filter_helper_.Emit();
  }

  // Returns true when nothing should be printed for the current input file.
  // Only C++-family inputs (CXX, OpenCLCXX, ObjCXX) produce output.
  bool ShouldSuppressOutput() {
    switch (current_language_) {
      case clang::Language::Unknown:
      case clang::Language::Asm:
      case clang::Language::LLVM_IR:
      case clang::Language::OpenCL:
      case clang::Language::CUDA:
      case clang::Language::RenderScript:
      case clang::Language::HIP:
      case clang::Language::HLSL:
        // Languages this rewriter does not produce edits for.
        return true;

      case clang::Language::C:
      case clang::Language::ObjC:
        // Not C++, so output is suppressed as well.
        return true;

      case clang::Language::CXX:
      case clang::Language::OpenCLCXX:
      case clang::Language::ObjCXX:
        return false;
    }

    // No |default:| above, so this is only reached for a language value that
    // is not one of the listed enumerators.
    assert(false && "Unrecognized clang::Language");
    return true;
  }

  OutputSectionHelper edits_helper_;
  OutputSectionHelper field_decl_filter_helper_;
  clang::Language current_language_ = clang::Language::Unknown;
};
// Matches a C++ record declaration for which clang::CXXRecordDecl::isTrivial()
// returns true.
AST_MATCHER(clang::CXXRecordDecl, isTrivial) {
return Node.isTrivial();
}
// Returns true when source ranges |a| and |b| share at least one file offset.
//
// Ranges are only compared when all four end points are file locations that
// live in the same FileID; in every other case the result is false.
bool IsOverlapping(const clang::SourceManager& source_manager,
                   const clang::SourceRange& a,
                   const clang::SourceRange& b) {
  const clang::FullSourceLoc a_begin(a.getBegin(), source_manager);
  const clang::FullSourceLoc a_end(a.getEnd(), source_manager);
  const clang::FullSourceLoc b_begin(b.getBegin(), source_manager);
  const clang::FullSourceLoc b_end(b.getEnd(), source_manager);

  const bool all_file_locations = a_begin.isFileID() && a_end.isFileID() &&
                                  b_begin.isFileID() && b_end.isFileID();
  if (!all_file_locations)
    return false;

  const bool all_in_same_file = a_begin.getFileID() == a_end.getFileID() &&
                                a_end.getFileID() == b_begin.getFileID() &&
                                b_begin.getFileID() == b_end.getFileID();
  if (!all_in_same_file)
    return false;

  const auto a_lo = a_begin.getFileOffset();
  const auto a_hi = a_end.getFileOffset();
  const auto b_lo = b_begin.getFileOffset();
  const auto b_hi = b_end.getFileOffset();

  // Two closed intervals overlap exactly when one of them starts inside the
  // other.
  const bool b_starts_within_a = a_lo <= b_lo && b_lo <= a_hi;
  const bool a_starts_within_b = b_lo <= a_lo && a_lo <= b_hi;
  return b_starts_within_a || a_starts_within_b;
}
// Matches a field declaration whose source range (begin loc .. end loc)
// overlaps the source range of any other declaration in the same record, as
// decided by IsOverlapping().
AST_MATCHER(clang::FieldDecl, overlapsOtherDeclsWithinRecordDecl) {
  const clang::RecordDecl* parent_record = Node.getParent();
  if (!parent_record)
    return false;

  const clang::SourceManager& source_manager =
      Finder->getASTContext().getSourceManager();
  const clang::SourceRange field_range(Node.getBeginLoc(), Node.getEndLoc());

  for (const clang::Decl* sibling : parent_record->decls()) {
    // The field trivially overlaps itself; only other declarations count.
    if (sibling == &Node)
      continue;

    const clang::SourceRange sibling_range(sibling->getBeginLoc(),
                                           sibling->getEndLoc());
    if (IsOverlapping(source_manager, field_range, sibling_range))
      return true;
  }
  return false;
}
// Matches a type that (transitively) embeds a field matched by |InnerMatcher|:
// - a C++ record type that has a field which either matches |InnerMatcher|
//   directly or whose own type recursively matches this matcher, or
// - an array type whose element type recursively matches this matcher.
// Any other type (or a null type) does not match.
AST_MATCHER_P(clang::QualType,
typeWithEmbeddedFieldDecl,
clang::ast_matchers::internal::Matcher<clang::FieldDecl>,
InnerMatcher) {
// Desugar first so that the record / array checks below look at the
// desugared type.
const clang::Type* type =
Node.getDesugaredType(Finder->getASTContext()).getTypePtrOrNull();
if (!type)
return false;
if (const clang::CXXRecordDecl* record_decl = type->getAsCXXRecordDecl()) {
// Recurse into the types of the record's fields.
auto matcher =
recordDecl(forEach(fieldDecl(raw_ptr_plugin::hasExplicitFieldDecl(anyOf(
InnerMatcher, hasType(typeWithEmbeddedFieldDecl(InnerMatcher)))))));
return matcher.matches(*record_decl, Finder, Builder);
}
if (type->isArrayType()) {
// Recurse into the array's element type.
const clang::ArrayType* array_type =
Finder->getASTContext().getAsArrayType(Node);
auto matcher = typeWithEmbeddedFieldDecl(InnerMatcher);
return matcher.matches(array_type->getElementType(), Finder, Builder);
}
return false;
}
// Match callback that rewrites the type portion of a field declaration bound
// as "affectedFieldDecl".
//
// The replacement covers the characters from the beginning of the declaration
// up to (not including) the field's name and is produced by substituting the
// printed pointee type into |format_string| (e.g. "raw_ptr<{0}> ").
// |include_path| is forwarded to OutputHelper::AddReplacement() so that the
// matching header gets requested for the edited file.
class FieldDeclRewriter : public MatchFinder::MatchCallback {
 public:
  explicit FieldDeclRewriter(OutputHelper* output_helper,
                             const char* format_string,
                             const char* include_path)
      : output_helper_(output_helper),
        format_string_(format_string),
        include_path_(include_path) {}

  FieldDeclRewriter(const FieldDeclRewriter&) = delete;
  FieldDeclRewriter& operator=(const FieldDeclRewriter&) = delete;

  // Hook for subclasses: returning true makes run() ignore |result|.
  virtual bool earlyExit(const MatchFinder::MatchResult& result) const = 0;

  void run(const MatchFinder::MatchResult& result) override {
    if (earlyExit(result)) {
      return;
    }

    const clang::ASTContext& ast_context = *result.Context;
    const clang::SourceManager& source_manager = *result.SourceManager;

    const clang::FieldDecl* field_decl =
        result.Nodes.getNodeAs<clang::FieldDecl>("affectedFieldDecl");
    assert(field_decl && "matcher should bind 'affectedFieldDecl'");

    const clang::TypeSourceInfo* type_source_info =
        field_decl->getTypeSourceInfo();
    // Synthesized Objective-C ivars are skipped; the assert below documents
    // that they are expected to have no type source info.
    if (auto* ivar_decl = clang::dyn_cast<clang::ObjCIvarDecl>(field_decl)) {
      if (ivar_decl->getSynthesize()) {
        assert(!type_source_info);
        return;
      }
    }
    assert(type_source_info && "assuming |type_source_info| is always present");

    clang::QualType pointer_type = type_source_info->getType();

    // From the start of the declaration up to the field name. The end is
    // exclusive because OutputHelper::AddReplacement() treats the range as a
    // character range.
    clang::SourceRange replacement_range(field_decl->getBeginLoc(),
                                         field_decl->getLocation());

    std::string replacement_text = GenerateNewText(ast_context, pointer_type);
    // |mutable| lies inside the replaced range, so it has to be re-added.
    if (field_decl->isMutable())
      replacement_text.insert(0, "mutable ");

    output_helper_->AddReplacement(source_manager, replacement_range,
                                   replacement_text, include_path_);
  }

 private:
  // Builds the new type text: the const/volatile qualifiers of |pointer_type|
  // itself, followed by |format_string_| with {0} replaced by the pointee type
  // printed without scope qualifiers.
  std::string GenerateNewText(const clang::ASTContext& ast_context,
                              const clang::QualType& pointer_type) {
    std::string result;

    clang::QualType pointee_type = pointer_type->getPointeeType();

    // Preserve qualifiers.
    assert(
        !pointer_type.isRestrictQualified() &&
        "|restrict| is a C-only qualifier and raw_ptr<T>/raw_ref<T> need C++");
    if (pointer_type.isConstQualified())
      result += "const ";
    if (pointer_type.isVolatileQualified())
      result += "volatile ";

    // SuppressScope makes the pointee type print without its enclosing
    // namespace / class qualifiers.
    clang::PrintingPolicy printing_policy(ast_context.getLangOpts());
    printing_policy.SuppressScope = 1;
    std::string pointee_type_as_string =
        pointee_type.getAsString(printing_policy);
    result += llvm::formatv(format_string_, pointee_type_as_string);

    return result;
  }

  OutputHelper* const output_helper_;
  const char* format_string_;
  const char* include_path_;
};
// Match callback that asks a caller-supplied function for the source range
// and the replacement text of a match, and records the resulting edit in the
// OutputHelper (without requesting any include).
class AffectedExprRewriter : public MatchFinder::MatchCallback {
 public:
  explicit AffectedExprRewriter(
      OutputHelper* output_helper,
      std::function<std::pair<clang::SourceRange, std::string>(
          const MatchFinder::MatchResult&)> fct)
      : output_helper_(output_helper), getRangeAndText_(fct) {}

  AffectedExprRewriter(const AffectedExprRewriter&) = delete;
  AffectedExprRewriter& operator=(const AffectedExprRewriter&) = delete;

  void run(const MatchFinder::MatchResult& result) override {
    const std::pair<clang::SourceRange, std::string> range_and_text =
        getRangeAndText_(result);
    const clang::SourceManager& source_manager = *result.SourceManager;
    output_helper_->AddReplacement(source_manager, range_and_text.first,
                                   range_and_text.second.c_str());
  }

 private:
  OutputHelper* const output_helper_;
  // Computes (range to replace, new text) for a match result.
  std::function<std::pair<clang::SourceRange, std::string>(
      const MatchFinder::MatchResult&)>
      getRangeAndText_;
};
// Match callback that reports the field bound as "affectedFieldDecl" to
// OutputHelper::AddFilteredField(), labelled with a fixed filter tag.
//
// |filter_tag| is stored as a non-owning llvm::StringRef, so the string it
// refers to has to outlive this object.
class FilteredExprWriter : public MatchFinder::MatchCallback {
 public:
  FilteredExprWriter(OutputHelper* output_helper, llvm::StringRef filter_tag)
      : output_helper_(output_helper), filter_tag_(filter_tag) {}

  FilteredExprWriter(const FilteredExprWriter&) = delete;
  FilteredExprWriter& operator=(const FilteredExprWriter&) = delete;

  void run(const MatchFinder::MatchResult& result) override {
    const auto* matched_field =
        result.Nodes.getNodeAs<clang::FieldDecl>("affectedFieldDecl");
    assert(matched_field && "matcher should bind 'affectedFieldDecl'");

    const clang::SourceManager& source_manager = *result.SourceManager;
    output_helper_->AddFilteredField(source_manager, *matched_field,
                                     filter_tag_);
  }

 private:
  OutputHelper* const output_helper_;
  llvm::StringRef filter_tag_;
};
// Owns the match callbacks for the raw-pointer-field rewrite and registers
// the corresponding AST matchers on a MatchFinder.
//
// Three kinds of output are produced through the OutputHelper:
//  - field declarations are rewritten using the "raw_ptr<{0}> " format,
//  - selected uses of affected fields get ".get()" appended,
//  - fields involved in certain patterns (address-of, in-out reference
//    arguments, overlapping declarations, macros, global-scope variables,
//    unions, reinterpret_cast to trivial types) are reported as filtered
//    fields under a tag naming the pattern.
//
// The referenced MatchFinder must outlive this object; the callbacks
// registered by addMatchers() are members of this object, so it has to stay
// alive while the finder is in use.
class RawPtrRewriter {
 public:
  RawPtrRewriter(
      OutputHelper* output_helper,
      MatchFinder& finder,
      const raw_ptr_plugin::RawPtrAndRefExclusionsOptions& exclusion_options)
      : match_finder(finder),
        field_decl_rewriter(output_helper, "raw_ptr<{0}> ", kRawPtrIncludePath),
        affected_expr_rewriter(output_helper, getRangeAndText_),
        filtered_addr_of_expr_writer(output_helper, "addr-of"),
        filtered_in_out_ref_arg_writer(output_helper, "in-out-param-ref"),
        overlapping_field_decl_writer(output_helper, "overlapping"),
        macro_field_decl_writer(output_helper, "macro"),
        global_scope_rewriter(output_helper, "global-scope"),
        union_field_decl_writer(output_helper, "union"),
        reinterpret_cast_struct_writer(output_helper,
                                       "reinterpret-cast-trivial-type"),
        exclusion_options_(exclusion_options) {}

  // Registers every matcher of the raw_ptr rewrite on |match_finder|.
  void addMatchers() {
    // Field declarations to rewrite.
    auto field_decl_matcher = AffectedRawPtrFieldDecl(exclusion_options_);
    match_finder.addMatcher(field_decl_matcher, &field_decl_rewriter);

    // Member expressions that refer to an affected field.
    auto affected_member_expr_matcher =
        memberExpr(member(fieldDecl(raw_ptr_plugin::hasExplicitFieldDecl(
                       field_decl_matcher))))
            .bind("affectedMemberExpr");
    auto affected_expr_matcher = ignoringImplicit(affected_member_expr_matcher);

    // ".get()" is appended when the field is used directly as an argument of
    // a variadic function call, or as the operand of const_cast /
    // reinterpret_cast.
    auto affected_expr_that_needs_fixing_matcher = expr(allOf(
        affected_expr_matcher,
        hasParent(expr(anyOf(callExpr(callee(functionDecl(isVariadic()))),
                             cxxConstCastExpr(), cxxReinterpretCastExpr())))));
    match_finder.addMatcher(affected_expr_that_needs_fixing_matcher,
                            &affected_expr_rewriter);

    // ... when it is the true or false branch of a ternary operator.
    auto affected_ternary_operator_arg_matcher =
        conditionalOperator(eachOf(hasTrueExpression(affected_expr_matcher),
                                   hasFalseExpression(affected_expr_matcher)));
    match_finder.addMatcher(affected_ternary_operator_arg_matcher,
                            &affected_expr_rewriter);

    // ... when it is an argument of an overloaded +/comparison operator that
    // also has a std::basic_string argument.
    auto std_string_expr_matcher =
        expr(hasType(cxxRecordDecl(hasName("::std::basic_string"))));
    auto affected_string_binary_operator_arg_matcher = cxxOperatorCallExpr(
        hasAnyOverloadedOperatorName("+", "==", "!=", "<", "<=", ">", ">="),
        hasAnyArgument(std_string_expr_matcher),
        forEachArgumentWithParam(affected_expr_matcher, parmVarDecl()));
    match_finder.addMatcher(affected_string_binary_operator_arg_matcher,
                            &affected_expr_rewriter);

    // ... when it is passed for a non-reference parameter whose type involves
    // a substituted template type parameter (calls and constructor calls),
    // except inside functions named "Unretained".
    auto templated_function_arg_matcher = forEachArgumentWithParam(
        affected_expr_matcher,
        parmVarDecl(allOf(
            hasType(
                qualType(allOf(findAll(qualType(substTemplateTypeParmType())),
                               unless(referenceType())))),
            unless(hasAncestor(functionDecl(hasName("Unretained")))))));
    match_finder.addMatcher(callExpr(templated_function_arg_matcher),
                            &affected_expr_rewriter);
    match_finder.addMatcher(
        traverse(clang::TraversalKind::TK_AsIs,
                 cxxConstructExpr(templated_function_arg_matcher)),
        &affected_expr_rewriter);

    // ... when it is the argument of an implicit, single-parameter,
    // non-explicit constructor call.
    auto implicit_ctor_expr_matcher = cxxConstructExpr(
        allOf(anyOf(hasParent(materializeTemporaryExpr()),
                    hasParent(implicitCastExpr())),
              hasDeclaration(cxxConstructorDecl(
                  allOf(parameterCountIs(1), unless(isExplicit())))),
              forEachArgumentWithParam(affected_expr_matcher, parmVarDecl())));
    match_finder.addMatcher(implicit_ctor_expr_matcher,
                            &affected_expr_rewriter);

    // ... when it initializes an |auto*| variable (directly or as the first
    // element of an initializer list).
    auto auto_var_decl_matcher = declStmt(forEach(
        varDecl(allOf(hasType(pointerType(pointee(autoType()))),
                      hasInitializer(anyOf(
                          affected_expr_matcher,
                          initListExpr(hasInit(0, affected_expr_matcher))))))));
    match_finder.addMatcher(auto_var_decl_matcher, &affected_expr_rewriter);

    // Filter: the address of the field is taken.
    auto affected_addr_of_expr_matcher = expr(allOf(
        affected_expr_matcher, hasParent(unaryOperator(hasOperatorName("&")))));
    match_finder.addMatcher(affected_addr_of_expr_matcher,
                            &filtered_addr_of_expr_writer);

    // Filter: the field is passed for a non-rvalue reference-to-pointer
    // parameter.
    auto affected_in_out_ref_arg_matcher = callExpr(forEachArgumentWithParam(
        affected_expr_matcher,
        raw_ptr_plugin::hasExplicitParmVarDecl(
            hasType(qualType(allOf(referenceType(pointee(pointerType())),
                                   unless(rValueReferenceType())))))));
    match_finder.addMatcher(affected_in_out_ref_arg_matcher,
                            &filtered_in_out_ref_arg_writer);

    // Filter: the field declaration overlaps another declaration of the same
    // record.
    auto overlapping_field_decl_matcher = fieldDecl(
        allOf(field_decl_matcher, overlapsOtherDeclsWithinRecordDecl()));
    match_finder.addMatcher(overlapping_field_decl_matcher,
                            &overlapping_field_decl_writer);

    // Filter: the field declaration is in a macro location.
    auto macro_field_decl_matcher = fieldDecl(
        allOf(field_decl_matcher, raw_ptr_plugin::isInMacroLocation()));
    match_finder.addMatcher(macro_field_decl_matcher, &macro_field_decl_writer);

    // Filter: a variable with global storage whose type embeds an affected
    // field.
    auto global_scope_matcher =
        varDecl(allOf(hasGlobalStorage(),
                      hasType(typeWithEmbeddedFieldDecl(field_decl_matcher))));
    match_finder.addMatcher(global_scope_matcher, &global_scope_rewriter);

    // Filter: a union (outside of the listed files) with a field that is
    // affected or whose type embeds an affected field.
    files_with_audited_unions =
        std::make_unique<raw_ptr_plugin::FilterFile>(std::vector<std::string>{
            "third_party/libc++/src/include/optional",
            "third_party/abseil-cpp/absl/types/internal/variant.h",
        });
    auto union_field_decl_matcher = recordDecl(allOf(
        isUnion(),
        unless(isInLocationListedInFilterFile(files_with_audited_unions.get())),
        forEach(fieldDecl(
            anyOf(field_decl_matcher,
                  hasType(typeWithEmbeddedFieldDecl(field_decl_matcher)))))));
    match_finder.addMatcher(union_field_decl_matcher, &union_field_decl_writer);

    // Filter: reinterpret_cast to a pointer to a trivial record that has an
    // affected field.
    auto reinterpret_cast_struct_matcher =
        cxxReinterpretCastExpr(hasDestinationType(pointerType(pointee(
            hasUnqualifiedDesugaredType(recordType(hasDeclaration(cxxRecordDecl(
                allOf(forEach(field_decl_matcher), isTrivial())))))))));
    match_finder.addMatcher(reinterpret_cast_struct_matcher,
                            &reinterpret_cast_struct_writer);
  }

 private:
  // FieldDeclRewriter that processes every match.
  class RawPtrFieldDeclRewriter : public FieldDeclRewriter {
   public:
    explicit RawPtrFieldDeclRewriter(OutputHelper* output_helper,
                                     const char* format_string,
                                     const char* include_path)
        : FieldDeclRewriter(output_helper, format_string, include_path) {}

    bool earlyExit(const MatchFinder::MatchResult& result) const override {
      return false;
    }
  };

  // Computes an insertion of ".get()" immediately after the member name of
  // the expression bound as "affectedMemberExpr" (empty range at that spot).
  //
  // NOTE: declared before |affected_expr_rewriter| on purpose - members are
  // initialized in declaration order and the rewriter copies this function in
  // the constructor's initializer list.
  std::function<std::pair<clang::SourceRange, std::string>(
      const MatchFinder::MatchResult&)>
      getRangeAndText_ = [](const MatchFinder::MatchResult& result)
      -> std::pair<clang::SourceRange, std::string> {
    const clang::MemberExpr* member_expr =
        result.Nodes.getNodeAs<clang::MemberExpr>("affectedMemberExpr");
    assert(member_expr && "matcher should bind 'affectedMemberExpr'");

    clang::SourceLocation member_name_start = member_expr->getMemberLoc();
    size_t member_name_length = member_expr->getMemberDecl()->getName().size();
    clang::SourceLocation insertion_loc =
        member_name_start.getLocWithOffset(member_name_length);

    clang::SourceRange replacement_range(insertion_loc, insertion_loc);

    return {replacement_range, ".get()"};
  };

  MatchFinder& match_finder;
  RawPtrFieldDeclRewriter field_decl_rewriter;
  AffectedExprRewriter affected_expr_rewriter;
  FilteredExprWriter filtered_addr_of_expr_writer;
  FilteredExprWriter filtered_in_out_ref_arg_writer;
  FilteredExprWriter overlapping_field_decl_writer;
  FilteredExprWriter macro_field_decl_writer;
  FilteredExprWriter global_scope_rewriter;
  FilteredExprWriter union_field_decl_writer;
  FilteredExprWriter reinterpret_cast_struct_writer;
  // Created in addMatchers(); the union matcher holds a raw pointer to it.
  std::unique_ptr<raw_ptr_plugin::FilterFile> files_with_audited_unions;
  const raw_ptr_plugin::RawPtrAndRefExclusionsOptions exclusion_options_;
};
// Owns the match callbacks for the reference-field rewrite and registers the
// corresponding AST matchers on a MatchFinder.
//
// Field declarations are rewritten using the "const raw_ref<{0}> " format.
// Uses of affected fields are adjusted as follows:
//  - member access through the field: the access operator is replaced with
//    "->",
//  - plain uses: "*" is inserted in front of the expression,
//  - uses as operand of an overloaded operator, ++/--, array subscript or
//    call: the expression is wrapped as "(*expr)",
//  - initializer-list initializers of the field: wrapped as "raw_ref(expr)".
// Fields in overlapping declarations, in macros, or embedded in global-scope
// variables are reported as filtered fields.
//
// The referenced MatchFinder must outlive this object; the callbacks
// registered by addMatchers() are members of this object, so it has to stay
// alive while the finder is in use.
class RawRefRewriter {
 public:
  RawRefRewriter(
      OutputHelper* output_helper,
      MatchFinder& finder,
      const raw_ptr_plugin::RawPtrAndRefExclusionsOptions& exclusion_options)
      : match_finder(finder),
        field_decl_rewriter(output_helper,
                            "const raw_ref<{0}> ",
                            kRawRefIncludePath),
        affected_expr_operator_rewriter(output_helper,
                                        affectedMemberExprOperatorFct_),
        affected_expr_rewriter(output_helper, affectedMemberExprFct_),
        affected_expr_rewriter_with_parentheses(
            output_helper,
            affectedMemberExprWithParenFct_),
        affected_initializer_expr_rewriter(output_helper,
                                           affectedInitializerExprFct_),
        global_scope_rewriter(output_helper, "global-scope"),
        overlapping_field_decl_writer(output_helper, "overlapping"),
        macro_field_decl_writer(output_helper, "macro"),
        exclusion_options_(exclusion_options) {}

  // Registers every matcher of the raw_ref rewrite on |match_finder|.
  void addMatchers() {
    // Field declarations to rewrite.
    auto field_decl_matcher = AffectedRawRefFieldDecl(exclusion_options_);
    match_finder.addMatcher(field_decl_matcher, &field_decl_rewriter);

    // Member access whose base is an affected field (directly, through an
    // implicit cast, or in a dependent context): operator becomes "->".
    auto affected_member_expr_operator_matcher =
        expr(anyOf(memberExpr(has(memberExpr(
                       member(fieldDecl(raw_ptr_plugin::hasExplicitFieldDecl(
                           field_decl_matcher)))))),
                   memberExpr(has(implicitCastExpr(has(memberExpr(
                       member(fieldDecl(raw_ptr_plugin::hasExplicitFieldDecl(
                           field_decl_matcher)))))))),
                   cxxDependentScopeMemberExpr(has(memberExpr(
                       member(fieldDecl(raw_ptr_plugin::hasExplicitFieldDecl(
                           field_decl_matcher))))))))
            .bind("affectedMemberExprOperator");
    match_finder.addMatcher(affected_member_expr_operator_matcher,
                            &affected_expr_operator_rewriter);

    // Plain uses of an affected field that get a leading "*". Uses handled by
    // the other matchers of this function (member access, parenthesized
    // cases), variable initializers other than |auto&| / range-for variables,
    // defaulted constructors, and initializations of another affected field
    // are excluded.
    auto affected_member_expr = memberExpr(
        memberExpr(
            member(fieldDecl(
                raw_ptr_plugin::hasExplicitFieldDecl(field_decl_matcher))),
            unless(
                anyOf(hasParent(memberExpr()),
                      hasParent(implicitCastExpr(hasParent(memberExpr()))),
                      hasParent(cxxDependentScopeMemberExpr()),
                      hasParent(varDecl(unless(anyOf(
                          hasType(referenceType(pointee(autoType()))),
                          hasParent(declStmt(hasParent(cxxForRangeStmt()))))))),
                      hasAncestor(cxxConstructorDecl(isDefaulted())),
                      hasParent(cxxOperatorCallExpr()),
                      hasParent(unaryOperator(
                          anyOf(hasOperatorName("--"), hasOperatorName("++")))),
                      hasParent(arraySubscriptExpr()),
                      hasParent(callExpr(
                          callee(fieldDecl(raw_ptr_plugin::hasExplicitFieldDecl(
                              field_decl_matcher))))))))
            .bind("affectedMemberExpr"),
        unless(anyOf(
            hasParent(cxxConstructorDecl(hasAnyConstructorInitializer(
                allOf(withInitializer(
                          memberExpr(equalsBoundNode("affectedMemberExpr"))),
                      forField(fieldDecl(raw_ptr_plugin::hasExplicitFieldDecl(
                          field_decl_matcher))))))),
            hasParent(initListExpr(raw_ptr_plugin::forEachInitExprWithFieldDecl(
                memberExpr(equalsBoundNode("affectedMemberExpr")),
                raw_ptr_plugin::hasExplicitFieldDecl(field_decl_matcher)))))));
    match_finder.addMatcher(affected_member_expr, &affected_expr_rewriter);

    auto affected_member_expr_matcher =
        memberExpr(member(fieldDecl(raw_ptr_plugin::hasExplicitFieldDecl(
                       field_decl_matcher))))
            .bind("affectedMemberExpr");

    // Leading "*" when the field is the argument of an implicit,
    // single-parameter, non-explicit constructor call.
    auto implicit_ctor_expr_matcher = cxxConstructExpr(allOf(
        anyOf(hasParent(materializeTemporaryExpr()),
              hasParent(implicitCastExpr())),
        hasDeclaration(cxxConstructorDecl(
            allOf(parameterCountIs(1), unless(isExplicit())))),
        forEachArgumentWithParam(affected_member_expr_matcher, parmVarDecl())));
    match_finder.addMatcher(implicit_ctor_expr_matcher,
                            &affected_expr_rewriter);

    // Leading "*" when the field initializes an |auto&| variable (directly or
    // as the first element of an initializer list).
    auto auto_var_decl_matcher = declStmt(forEach(varDecl(
        allOf(hasType(referenceType(pointee(autoType()))),
              hasInitializer(anyOf(
                  affected_member_expr_matcher,
                  initListExpr(hasInit(0, affected_member_expr_matcher))))))));
    match_finder.addMatcher(auto_var_decl_matcher, &affected_expr_rewriter);

    // Uses that are wrapped as "(*expr)".
    auto affected_member_expr_with_parentheses =
        memberExpr(member(fieldDecl(raw_ptr_plugin::hasExplicitFieldDecl(
                       field_decl_matcher))),
                   anyOf(hasParent(cxxOperatorCallExpr()),
                         hasParent(unaryOperator(anyOf(hasOperatorName("--"),
                                                       hasOperatorName("++")))),
                         hasParent(arraySubscriptExpr()),
                         hasParent(callExpr(callee(
                             fieldDecl(raw_ptr_plugin::hasExplicitFieldDecl(
                                 field_decl_matcher)))))))
            .bind("affectedMemberExprWithParentheses");
    match_finder.addMatcher(affected_member_expr_with_parentheses,
                            &affected_expr_rewriter_with_parentheses);

    // Initializer-list initializers of an affected field that are neither a
    // materialized temporary nor another affected field: wrapped as
    // "raw_ref(expr)". Init lists nested in a constructor call are skipped.
    auto init_list_expr_with_raw_ref = initListExpr(
        raw_ptr_plugin::forEachInitExprWithFieldDecl(
            expr(unless(anyOf(
                     materializeTemporaryExpr(),
                     memberExpr(
                         member(fieldDecl(raw_ptr_plugin::hasExplicitFieldDecl(
                             field_decl_matcher)))))))
                .bind("initializer_expr"),
            raw_ptr_plugin::hasExplicitFieldDecl(field_decl_matcher)),
        unless(hasParent(cxxConstructExpr())));
    match_finder.addMatcher(init_list_expr_with_raw_ref,
                            &affected_initializer_expr_rewriter);

    // Filter: the field declaration overlaps another declaration of the same
    // record.
    auto overlapping_field_decl_matcher = fieldDecl(
        allOf(field_decl_matcher, overlapsOtherDeclsWithinRecordDecl()));
    match_finder.addMatcher(overlapping_field_decl_matcher,
                            &overlapping_field_decl_writer);

    // Filter: the field declaration is in a macro location.
    auto macro_field_decl_matcher = fieldDecl(
        allOf(field_decl_matcher, raw_ptr_plugin::isInMacroLocation()));
    match_finder.addMatcher(macro_field_decl_matcher, &macro_field_decl_writer);

    // Filter: a variable with global storage whose type embeds an affected
    // field.
    auto global_scope_matcher =
        varDecl(allOf(hasGlobalStorage(),
                      hasType(typeWithEmbeddedFieldDecl(field_decl_matcher))));
    match_finder.addMatcher(global_scope_matcher, &global_scope_rewriter);
  }

 private:
  // FieldDeclRewriter that ignores matches which did not bind an lvalue
  // reference type loc as "affectedFieldDeclType".
  class RawRefFieldDeclRewriter : public FieldDeclRewriter {
   public:
    explicit RawRefFieldDeclRewriter(OutputHelper* output_helper,
                                     const char* format_string,
                                     const char* include_path)
        : FieldDeclRewriter(output_helper, format_string, include_path) {}

    bool earlyExit(const MatchFinder::MatchResult& result) const override {
      auto* type = result.Nodes.getNodeAs<clang::LValueReferenceTypeLoc>(
          "affectedFieldDeclType");
      return !type;
    }
  };

  // NOTE: the four callbacks below are declared before the rewriter members
  // on purpose - members are initialized in declaration order and the
  // rewriters copy these functions in the constructor's initializer list.

  // Inserts "*" at the beginning of the expression bound as
  // "affectedMemberExpr".
  std::function<std::pair<clang::SourceRange, std::string>(
      const MatchFinder::MatchResult&)>
      affectedMemberExprFct_ = [](const MatchFinder::MatchResult& result)
      -> std::pair<clang::SourceRange, std::string> {
    const clang::MemberExpr* member_expr =
        result.Nodes.getNodeAs<clang::MemberExpr>("affectedMemberExpr");
    assert(member_expr && "matcher should bind 'affectedMemberExpr'");

    clang::SourceRange replacement_range(member_expr->getBeginLoc(),
                                         member_expr->getBeginLoc());

    return {replacement_range, "*"};
  };

  // Replaces the expression bound as "affectedMemberExprWithParentheses"
  // (from its beginning to the end of the member name) with "(*<expr>)".
  std::function<std::pair<clang::SourceRange, std::string>(
      const MatchFinder::MatchResult&)>
      affectedMemberExprWithParenFct_ =
          [](const MatchFinder::MatchResult& result)
      -> std::pair<clang::SourceRange, std::string> {
    const clang::SourceManager& source_manager = *result.SourceManager;
    const clang::MemberExpr* member_expr =
        result.Nodes.getNodeAs<clang::MemberExpr>(
            "affectedMemberExprWithParentheses");
    assert(member_expr &&
           "matcher should bind 'affectedMemberExprWithParentheses'");

    clang::SourceLocation member_name_start = member_expr->getMemberLoc();
    clang::SourceLocation endLoc = member_name_start.getLocWithOffset(
        member_expr->getMemberDecl()->getName().size());

    clang::SourceRange replacement_range(member_expr->getBeginLoc(), endLoc);

    auto source_text = clang::Lexer::getSourceText(
        clang::CharSourceRange::getTokenRange(member_expr->getSourceRange()),
        source_manager, result.Context->getLangOpts());

    return {replacement_range,
            llvm::formatv("(*{0})",
                          std::string(source_text.begin(), source_text.end()))};
  };

  // Replaces the member-access operator of the expression bound as
  // "affectedMemberExprOperator" (everything from the operator up to the
  // member name) with "->". The bound node is either a MemberExpr or a
  // CXXDependentScopeMemberExpr.
  std::function<std::pair<clang::SourceRange, std::string>(
      const MatchFinder::MatchResult&)>
      affectedMemberExprOperatorFct_ =
          [](const MatchFinder::MatchResult& result)
      -> std::pair<clang::SourceRange, std::string> {
    const clang::MemberExpr* member_expr =
        result.Nodes.getNodeAs<clang::MemberExpr>("affectedMemberExprOperator");
    const clang::CXXDependentScopeMemberExpr* cxx_dependent_scope_member_expr =
        result.Nodes.getNodeAs<clang::CXXDependentScopeMemberExpr>(
            "affectedMemberExprOperator");
    assert((member_expr || cxx_dependent_scope_member_expr) &&
           "matcher should bind 'affectedMemberExprOperator'");

    if (member_expr) {
      clang::SourceRange replacement_range(member_expr->getOperatorLoc(),
                                           member_expr->getMemberLoc());
      return {replacement_range, "->"};
    }

    clang::SourceRange replacement_range(
        cxx_dependent_scope_member_expr->getOperatorLoc(),
        cxx_dependent_scope_member_expr->getMemberLoc());
    return {replacement_range, "->"};
  };

  // Replaces the expression bound as "initializer_expr" with
  // "raw_ref(<expr>)".
  std::function<std::pair<clang::SourceRange, std::string>(
      const MatchFinder::MatchResult&)>
      affectedInitializerExprFct_ = [](const MatchFinder::MatchResult& result)
      -> std::pair<clang::SourceRange, std::string> {
    const clang::SourceManager& source_manager = *result.SourceManager;
    const clang::Expr* initializer_expr =
        result.Nodes.getNodeAs<clang::Expr>("initializer_expr");
    assert(initializer_expr && "matcher should bind 'initializer_expr'");

    auto source_text = clang::Lexer::getSourceText(
        clang::CharSourceRange::getTokenRange(
            initializer_expr->getSourceRange()),
        source_manager, result.Context->getLangOpts());

    // The replaced range spans exactly the characters of |source_text|.
    clang::SourceLocation endLoc =
        initializer_expr->getBeginLoc().getLocWithOffset(source_text.size());

    clang::SourceRange replacement_range(initializer_expr->getBeginLoc(),
                                         endLoc);

    return {replacement_range,
            llvm::formatv("raw_ref({0})",
                          std::string(source_text.begin(), source_text.end()))};
  };

  MatchFinder& match_finder;
  RawRefFieldDeclRewriter field_decl_rewriter;
  AffectedExprRewriter affected_expr_operator_rewriter;
  AffectedExprRewriter affected_expr_rewriter;
  AffectedExprRewriter affected_expr_rewriter_with_parentheses;
  AffectedExprRewriter affected_initializer_expr_rewriter;
  FilteredExprWriter global_scope_rewriter;
  FilteredExprWriter overlapping_field_decl_writer;
  FilteredExprWriter macro_field_decl_writer;
  const raw_ptr_plugin::RawPtrAndRefExclusionsOptions exclusion_options_;
};
class SpanFieldDeclRewriter : public MatchFinder::MatchCallback {
public:
explicit SpanFieldDeclRewriter(OutputHelper* output_helper,
const char* include_path)
: output_helper_(output_helper), include_path_(include_path) {}
SpanFieldDeclRewriter(const SpanFieldDeclRewriter&) = delete;
SpanFieldDeclRewriter& operator=(const SpanFieldDeclRewriter&) = delete;
void run(const MatchFinder::MatchResult& result) override {
const clang::ASTContext& ast_context = *result.Context;
const clang::SourceManager& source_manager = *result.SourceManager;
const auto& lang_opts = ast_context.getLangOpts();
const clang::FieldDecl* field_decl =
result.Nodes.getNodeAs<clang::FieldDecl>("affectedFieldDecl");
assert(field_decl && "matcher should bind 'fieldDecl'");
const clang::TypeSourceInfo* type_source_info =
field_decl->getTypeSourceInfo();
if (auto* ivar_decl = clang::dyn_cast<clang::ObjCIvarDecl>(field_decl)) {
if (ivar_decl->getSynthesize()) {
assert(!type_source_info);
return;
}
}
assert(type_source_info && "assuming |type_source_info| is always present");
if (result.Nodes.getNodeAs<clang::QualType>("container_type")) {
HandleContainerArguments(field_decl, result);
return;
}
clang::SourceRange replacement_range(field_decl->getBeginLoc(),
field_decl->getLocation());
GenerateReplacement(replacement_range, source_manager, lang_opts);
}
private:
clang::SourceRange GetTemplateArgumentSourceRange(
const clang::TemplateSpecializationTypeLoc& tst_tl,
unsigned i) {
if (i == (tst_tl.getNumArgs() - 1)) {
return clang::SourceRange(tst_tl.getArgLoc(i).getLocation(),
tst_tl.getRAngleLoc());
}
return tst_tl.getArgLoc(i).getSourceRange();
}
std::optional<clang::TemplateSpecializationTypeLoc>
GetTemplateSpecializationTypeLoc(clang::TypeLoc loc) {
if (auto specialization =
loc.getAs<clang::TemplateSpecializationTypeLoc>()) {
return specialization;
}
if (auto elaborated = loc.getAs<clang::ElaboratedTypeLoc>()) {
if (auto specialization =
elaborated.getNamedTypeLoc()
.getAs<clang::TemplateSpecializationTypeLoc>()) {
return specialization;
}
}
return {};
}
// Emits replacements for the template arguments of a container-typed field.
// |decl| is the matched field and |result| carries the match bindings. For
// each of the "template_arg0" / "template_arg1" bindings present in |result|,
// the text of the corresponding (0th / 1st) written template argument is
// handed to GenerateReplacement(). Does nothing when the field's type loc
// does not expose a template specialization.
void HandleContainerArguments(const clang::FieldDecl* decl,
                              const MatchFinder::MatchResult& result) {
  const auto specialization_loc = GetTemplateSpecializationTypeLoc(
      decl->getTypeSourceInfo()->getTypeLoc());
  if (!specialization_loc) {
    return;
  }

  const clang::SourceManager& source_manager = *result.SourceManager;
  const clang::LangOptions& lang_opts = result.Context->getLangOpts();

  // The position of each id in this table is the template argument index it
  // was bound for.
  static constexpr const char* kBoundArgumentIds[] = {"template_arg0",
                                                      "template_arg1"};
  unsigned argument_index = 0;
  for (const char* bound_id : kBoundArgumentIds) {
    if (result.Nodes.getNodeAs<clang::TemplateArgument>(bound_id)) {
      GenerateReplacement(
          GetTemplateArgumentSourceRange(*specialization_loc, argument_index),
          source_manager, lang_opts);
    }
    ++argument_index;
  }
}
// Reads the source text covered by |source_range| (interpreted as a character
// range), rewrites every `span<` in it to `raw_span<` (keeping an optional
// leading `<` or `base::`), and records the result through |output_helper_|
// together with |include_path_|. Nothing is recorded when the rewritten text
// is empty or identical to the original text.
void GenerateReplacement(const clang::SourceRange& source_range,
                         const clang::SourceManager& source_manager,
                         const clang::LangOptions& lang_opts) {
  // Compiling a std::regex is expensive and the pattern never changes, so it
  // is built once instead of on every call.
  static const std::regex kSpanPattern("(<|base::)?(span<)");

  const std::string initial_text =
      clang::Lexer::getSourceText(
          clang::CharSourceRange::getCharRange(source_range), source_manager,
          lang_opts)
          .str();
  const std::string replacement_text =
      std::regex_replace(initial_text, kSpanPattern, "$1raw_$2");
  if (replacement_text.empty() || (initial_text == replacement_text)) {
    return;
  }
  output_helper_->AddReplacement(source_manager, source_range,
                                 replacement_text, include_path_);
}
// Receives the replacements produced by GenerateReplacement().
// NOTE(review): looks like a non-owning pointer - confirm against the
// constructor and the caller.
OutputHelper* const output_helper_;
// Include path forwarded to OutputHelper::AddReplacement() alongside each
// replacement.
const char* include_path_;
};
class SpanRewriter {
public:
SpanRewriter(
OutputHelper* output_helper,
MatchFinder& finder,
const raw_ptr_plugin::RawPtrAndRefExclusionsOptions& exclusion_options)
: match_finder(finder),
field_decl_rewriter(output_helper, kRawSpanIncludePath),
global_scope_rewriter(output_helper, "global-scope"),
overlapping_field_decl_writer(output_helper, "overlapping"),
macro_field_decl_writer(output_helper, "macro"),
exclusion_options_(exclusion_options) {}
void addMatchers() {
auto raw_span = hasTemplateArgument(
2, refersToType(qualType(hasCanonicalType(qualType(hasDeclaration(
mapAnyOf(classTemplateSpecializationDecl, classTemplateDecl)
.with(hasName("raw_ptr"))))))));
auto string_literals_span = hasTemplateArgument(
0, refersToType(qualType(hasCanonicalType(
anyOf(asString("const char"), asString("const wchar_t"),
asString("const char8_t"), asString("const char16_t"),
asString("const char32_t"))))));
auto excluded_spans = anyOf(raw_span, string_literals_span);
auto span_type = anyOf(
qualType(hasCanonicalType(
qualType(hasDeclaration(classTemplateSpecializationDecl(
hasName("base::span"), unless(excluded_spans)))))),
qualType(hasCanonicalType(qualType(type(templateSpecializationType(
hasDeclaration(classTemplateDecl(hasName("base::span"))),
unless(excluded_spans)))))));
auto optional_span_type = anyOf(
qualType(
hasCanonicalType(hasDeclaration(classTemplateSpecializationDecl(
hasName("optional"),
hasTemplateArgument(0, refersToType(span_type)))))),
qualType(hasCanonicalType(qualType(type(templateSpecializationType(
hasDeclaration(classTemplateDecl(hasName("optional"))),
hasTemplateArgument(0, refersToType(span_type))))))));
auto container_methods =
anyOf(allOf(hasMethod(hasName("push_back")),
hasMethod(hasName("pop_back")), hasMethod(hasName("size"))),
allOf(hasMethod(hasName("insert")), hasMethod(hasName("erase")),
hasMethod(hasName("size"))),
allOf(hasMethod(hasName("push")), hasMethod(hasName("pop")),
hasMethod(hasName("size"))));
auto template_arg0 = hasTemplateArgument(
0, templateArgument(refersToType(anyOf(span_type, optional_span_type)))
.bind("template_arg0"));
auto template_arg1 = hasTemplateArgument(
1, templateArgument(refersToType(anyOf(span_type, optional_span_type)))
.bind("template_arg1"));
auto template_arguments = anyOf(allOf(template_arg0, template_arg1),
template_arg0, template_arg1);
auto container_of_span_type =
qualType(hasCanonicalType(anyOf(
qualType(hasDeclaration(classTemplateSpecializationDecl(
container_methods, template_arguments))),
qualType(type(templateSpecializationType(
hasDeclaration(classTemplateDecl(
has(cxxRecordDecl(container_methods)))),
template_arguments))))))
.bind("container_type");
auto field_decl_matcher =
traverse(clang::TK_IgnoreUnlessSpelledInSource,
fieldDecl(hasType(qualType(anyOf(span_type, optional_span_type,
container_of_span_type))),
unless(PtrAndRefExclusions(exclusion_options_)))
.bind("affectedFieldDecl"));
match_finder.addMatcher(field_decl_matcher, &field_decl_rewriter);
auto global_scope_matcher =
varDecl(allOf(hasGlobalStorage(),
hasType(typeWithEmbeddedFieldDecl(field_decl_matcher))));
match_finder.addMatcher(global_scope_matcher, &global_scope_rewriter);
auto macro_field_decl_matcher = fieldDecl(
allOf(field_decl_matcher, raw_ptr_plugin::isInMacroLocation()));
match_finder.addMatcher(macro_field_decl_matcher, ¯o_field_decl_writer);
}
private:
MatchFinder& match_finder;
SpanFieldDeclRewriter field_decl_rewriter;
FilteredExprWriter global_scope_rewriter;
FilteredExprWriter overlapping_field_decl_writer;
FilteredExprWriter macro_field_decl_writer;
const raw_ptr_plugin::RawPtrAndRefExclusionsOptions exclusion_options_;
};
}
// Tool entry point. Parses the command line, builds the exclusion filters,
// registers the raw_ptr / raw_ref / span matchers and runs the tool over the
// requested sources. Returns 1 when the command line cannot be parsed,
// otherwise the value returned by ClangTool::run().
int main(int argc, const char* argv[]) {
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmParser();

  llvm::cl::OptionCategory category(
      "rewrite_raw_ptr_fields: changes |T* field_| to |raw_ptr<T> field_|.");
  llvm::cl::opt<std::string> exclude_fields_param(
      kExcludeFieldsParamName, llvm::cl::value_desc("filepath"),
      llvm::cl::desc("file listing fields to be blocked (not rewritten)"));
  llvm::cl::opt<std::string> override_exclude_paths_param(
      kOverrideExcludePathsParamName, llvm::cl::value_desc("filepath"),
      llvm::cl::desc(
          "override file listing paths to be blocked (not rewritten)"));
  llvm::cl::opt<bool> enable_raw_ref_rewrite(
      "enable_raw_ref_rewrite", llvm::cl::init(false),
      llvm::cl::desc("Rewrite T& into const raw_ref<T>"));
  llvm::cl::opt<bool> enable_raw_ptr_rewrite(
      "enable_raw_ptr_rewrite", llvm::cl::init(false),
      llvm::cl::desc("Rewrite T* into raw_ptr<T>"));
  llvm::cl::opt<bool> exclude_stack_allocated(
      "exclude_stack_allocated", llvm::cl::init(true),
      llvm::cl::desc("Exclude pointers/references to `STACK_ALLOCATED` objects "
                     "from the rewrite"));

  llvm::Expected<clang::tooling::CommonOptionsParser> options =
      clang::tooling::CommonOptionsParser::create(argc, argv, category);
  // An assert() would vanish under NDEBUG and leave the failed Expected to be
  // dereferenced below, so the error is checked and reported explicitly.
  if (!options) {
    llvm::errs() << llvm::toString(options.takeError()) << "\n";
    return 1;
  }
  clang::tooling::ClangTool tool(options->getCompilations(),
                                 options->getSourcePathList());

  // When neither rewrite is explicitly requested, both are performed.
  const bool rewrite_raw_ref_and_ptr =
      !enable_raw_ref_rewrite && !enable_raw_ptr_rewrite;

  MatchFinder match_finder;
  OutputHelper output_helper;

  raw_ptr_plugin::FilterFile fields_to_exclude(
      exclude_fields_param, exclude_fields_param.ArgStr.str());

  // Without an override file, the excluded paths come from the two built-in
  // lists; otherwise they are read from the override file.
  std::unique_ptr<raw_ptr_plugin::FilterFile> paths_to_exclude;
  if (override_exclude_paths_param.empty()) {
    std::vector<std::string> paths_to_exclude_lines;
    for (auto* const line : kRawPtrManualPathsToIgnore) {
      paths_to_exclude_lines.push_back(line);
    }
    for (auto* const line : kSeparateRepositoryPaths) {
      paths_to_exclude_lines.push_back(line);
    }
    paths_to_exclude =
        std::make_unique<raw_ptr_plugin::FilterFile>(paths_to_exclude_lines);
  } else {
    paths_to_exclude = std::make_unique<raw_ptr_plugin::FilterFile>(
        override_exclude_paths_param,
        override_exclude_paths_param.ArgStr.str());
  }

  raw_ptr_plugin::StackAllocatedPredicate stack_allocated_checker;
  // NOTE(review): the meaning of the trailing `true` is defined by
  // RawPtrAndRefExclusionsOptions, which is not visible here - confirm.
  raw_ptr_plugin::RawPtrAndRefExclusionsOptions exclusion_options{
      &fields_to_exclude, paths_to_exclude.get(), exclude_stack_allocated,
      &stack_allocated_checker, true};

  RawPtrRewriter raw_ptr_rewriter(&output_helper, match_finder,
                                  exclusion_options);
  if (rewrite_raw_ref_and_ptr || enable_raw_ptr_rewrite) {
    raw_ptr_rewriter.addMatchers();
  }

  RawRefRewriter raw_ref_rewriter(&output_helper, match_finder,
                                  exclusion_options);
  if (rewrite_raw_ref_and_ptr || enable_raw_ref_rewrite) {
    raw_ref_rewriter.addMatchers();
  }

  // The span matchers are always registered, regardless of the flags above.
  SpanRewriter span_rewriter(&output_helper, match_finder, exclusion_options);
  span_rewriter.addMatchers();

  std::unique_ptr<clang::tooling::FrontendActionFactory> factory =
      clang::tooling::newFrontendActionFactory(&match_finder, &output_helper);
  return tool.run(factory.get());
}