#include "RawPtrHelpers.h"
#include "StackAllocatedChecker.h"
#include "clang/ASTMatchers/ASTMatchers.h"
#include "llvm/Support/LineIterator.h"
#include "llvm/Support/MemoryBuffer.h"
using namespace clang::ast_matchers;
namespace raw_ptr_plugin {
// Builds a filter directly from in-memory entries. Each entry is stored
// verbatim: unlike ParseInputFile(), no trimming or comment stripping is done.
FilterFile::FilterFile(const std::vector<std::string>& lines) {
  for (const std::string& entry : lines)
    file_lines_.insert(entry);
}
// Returns true iff |line| is stored in the filter exactly (whole-line match,
// no substring or pattern semantics).
bool FilterFile::ContainsLine(llvm::StringRef line) const {
  return file_lines_.count(line) != 0;
}
// Returns true if |string_to_match| contains at least one inclusion entry as a
// substring and none of the exclusion entries. Entries starting with '!' are
// exclusions (the '!' itself is dropped); every other entry is an inclusion.
// Entries are matched literally: they are regex-escaped before being joined
// into one alternation per list.
//
// The two regexes are built lazily on the first call and cached in the
// (mutable) optional members.
bool FilterFile::ContainsSubstringOf(llvm::StringRef string_to_match) const {
  if (!inclusion_substring_regex_.has_value()) {
    std::vector<std::string> regex_escaped_inclusion_file_lines;
    std::vector<std::string> regex_escaped_exclusion_file_lines;
    regex_escaped_inclusion_file_lines.reserve(file_lines_.size());
    for (const llvm::StringRef& file_line : file_lines_.keys()) {
      const bool is_exclusion = file_line.starts_with("!");
      llvm::StringRef pattern = is_exclusion ? file_line.substr(1) : file_line;
      // Skip empty patterns (e.g. a line consisting only of "!"). Joined with
      // '|', an empty pattern becomes an empty alternative ("a||b"), which the
      // POSIX ERE engine behind llvm::Regex is expected to reject (REG_EMPTY).
      // That would invalidate the whole regex and silently disable every other
      // entry in the same list.
      if (pattern.empty()) {
        continue;
      }
      if (is_exclusion) {
        regex_escaped_exclusion_file_lines.push_back(
            llvm::Regex::escape(pattern));
      } else {
        regex_escaped_inclusion_file_lines.push_back(
            llvm::Regex::escape(pattern));
      }
    }
    std::string inclusion_substring_regex_pattern =
        llvm::join(regex_escaped_inclusion_file_lines.begin(),
                   regex_escaped_inclusion_file_lines.end(), "|");
    inclusion_substring_regex_.emplace(inclusion_substring_regex_pattern);
    std::string exclusion_substring_regex_pattern =
        llvm::join(regex_escaped_exclusion_file_lines.begin(),
                   regex_escaped_exclusion_file_lines.end(), "|");
    exclusion_substring_regex_.emplace(exclusion_substring_regex_pattern);
  }
  // NOTE(review): an empty list joins to the pattern "", which llvm::Regex is
  // expected to treat as invalid; match() on an invalid regex returns false.
  // Net effect: no inclusions => nothing matches; no exclusions => nothing is
  // excluded. Confirm against the LLVM version in use.
  return inclusion_substring_regex_->match(string_to_match) &&
         !exclusion_substring_regex_->match(string_to_match);
}
// Reads filter entries from |filepath| and adds them to this filter.
//
// File format, one entry per line:
//   * blank lines and lines whose first character is '#' are skipped;
//   * everything from the first '@' or '#' onwards is discarded;
//   * the remainder is whitespace-trimmed, and empty results are ignored.
//
// |arg_name| is only used to name the offending flag in the error message.
// An empty |filepath| leaves the filter unchanged.
void FilterFile::ParseInputFile(const std::string& filepath,
                                const std::string& arg_name) {
  if (filepath.empty()) {
    return;
  }
  llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> file_or_err =
      llvm::MemoryBuffer::getFile(filepath);
  if (std::error_code err = file_or_err.getError()) {
    llvm::errs() << "ERROR: Cannot open the file specified in --" << arg_name
                 << " argument: " << filepath << ": " << err.message() << "\n";
    // Fatal in debug builds; with NDEBUG the filter is simply left unchanged.
    assert(false);
    return;
  }
  llvm::line_iterator it(**file_or_err, /*SkipBlanks=*/true, '#');
  for (; !it.is_at_eof(); ++it) {
    llvm::StringRef line = *it;
    // Cut at whichever comes first: '@' (start of trailing location info) or
    // '#' (start of a trailing comment). Taking the earliest of the two means
    // a '@' inside a comment, or a '#' after location info, is dropped too.
    size_t cut_pos = line.find_first_of("@#");
    if (cut_pos != llvm::StringRef::npos) {
      line = line.substr(0, cut_pos);
    }
    line = line.trim();
    if (line.empty()) {
      continue;
    }
    file_lines_.insert(line);
  }
}
// Matches a field whose immediate parent record is a lambda class, an implicit
// class template specialization, or any record nested (at any depth) inside
// an implicit class or function template specialization.
clang::ast_matchers::internal::Matcher<clang::Decl> ImplicitFieldDeclaration() {
  auto implicit_record_instantiation =
      classTemplateSpecializationDecl(isImplicitClassTemplateSpecialization());
  auto implicit_function_instantiation =
      functionDecl(isImplicitFunctionTemplateSpecialization());
  auto inside_implicit_instantiation = hasAncestor(decl(
      anyOf(implicit_record_instantiation, implicit_function_instantiation)));
  auto implicit_parent_record =
      cxxRecordDecl(anyOf(isLambda(), implicit_record_instantiation,
                          inside_implicit_instantiation));
  return fieldDecl(hasParent(implicit_parent_record));
}
// Matches a record type whose C++ declaration is classified as stack-allocated
// by |checker|, binding the matched type as "pointeeQualType".
// |checker| is dereferenced here and must be non-null.
clang::ast_matchers::internal::Matcher<clang::QualType> StackAllocatedQualType(
    const raw_ptr_plugin::StackAllocatedPredicate* checker) {
  auto stack_allocated_record = cxxRecordDecl(isStackAllocated(*checker));
  auto record_with_that_decl =
      recordType(hasDeclaration(stack_allocated_record));
  return qualType(record_with_that_decl).bind("pointeeQualType");
}
// Matches declarations that the raw_ptr/raw_ref field matchers must skip:
// system headers, extern "C" contexts, annotated exclusions, third-party or
// generated code, decls not spelled in source, filter-file listed paths and
// fields, implicit fields (see ImplicitFieldDeclaration) and ObjC synthesized
// decls.
//
// When |options.should_exclude_stack_allocated_records| is set, it
// additionally excludes decls with a stack-allocated record type descendant or
// that are declared inside a stack-allocated record; in that case
// |options.stack_allocated_predicate| must be non-null.
clang::ast_matchers::internal::Matcher<clang::NamedDecl> PtrAndRefExclusions(
    const RawPtrAndRefExclusionsOptions& options) {
  // Built once so both configurations share a single, identical list.
  clang::ast_matchers::internal::Matcher<clang::NamedDecl> common_exclusions =
      anyOf(isSpellingInSystemHeader(), isInExternCContext(),
            isRawPtrExclusionAnnotated(), isInThirdPartyLocation(),
            isInGeneratedLocation(), isNotSpelledInSource(),
            isInLocationListedInFilterFile(options.paths_to_exclude),
            isFieldDeclListedInFilterFile(options.fields_to_exclude),
            ImplicitFieldDeclaration(), isObjCSynthesize());
  if (!options.should_exclude_stack_allocated_records) {
    return common_exclusions;
  }
  // Dereferenced below; a null predicate is a configuration error.
  assert(options.stack_allocated_predicate != nullptr);
  return anyOf(
      common_exclusions,
      hasDescendant(StackAllocatedQualType(options.stack_allocated_predicate)),
      isDeclaredInStackAllocated(*options.stack_allocated_predicate));
}
// Matches type locations that are spelled in a system header or located in
// third-party code.
clang::ast_matchers::internal::Matcher<clang::TypeLoc>
PtrAndRefTypeLocExclusions() {
  auto spelled_in_system_header = isSpellingInSystemHeader();
  auto in_third_party_code = isInThirdPartyLocation();
  return anyOf(spelled_in_system_header, in_third_party_code);
}
// Pointee kinds, compared after stripping type sugar, that the pointer and
// reference field matchers in this file treat as unsupported: function types,
// member pointer types and array types.
static const auto unsupported_pointee_types =
    pointee(hasUnqualifiedDesugaredType(
        anyOf(functionType(), memberPointerType(), arrayType())));
// Matches a pointer type whose pointee is none of unsupported_pointee_types.
clang::ast_matchers::internal::Matcher<clang::Type> supported_pointer_type() {
  auto pointee_is_supported = unless(unsupported_pointee_types);
  return pointerType(pointee_is_supported);
}
// Matches a pointer to const character data.
//
// If |should_rewrite_non_string_literals| is false, any const-qualified
// pointee whose desugared type is a character type matches. If it is true,
// only pointees whose canonical type is exactly const char, const wchar_t,
// const char8_t, const char16_t or const char32_t match.
clang::ast_matchers::internal::Matcher<clang::Type> const_char_pointer_type(
    bool should_rewrite_non_string_literals) {
  if (!should_rewrite_non_string_literals) {
    auto const_any_char =
        allOf(isConstQualified(), hasUnqualifiedDesugaredType(anyCharType()));
    return pointerType(pointee(qualType(const_any_char)));
  }
  auto canonical_const_char_kind = hasCanonicalType(
      anyOf(asString("const char"), asString("const wchar_t"),
            asString("const char8_t"), asString("const char16_t"),
            asString("const char32_t")));
  return pointerType(pointee(qualType(canonical_const_char_kind)));
}
// Matches a field of supported pointer type (see supported_pointer_type) that
// is neither a const char pointer (per const_char_pointer_type) nor excluded
// by PtrAndRefExclusions. The field is bound as "affectedFieldDecl".
clang::ast_matchers::internal::Matcher<clang::Decl> AffectedRawPtrFieldDecl(
    const RawPtrAndRefExclusionsOptions& options) {
  auto is_const_char_pointer_field = fieldDecl(hasType(
      const_char_pointer_type(options.should_rewrite_non_string_literals)));
  auto is_skipped =
      anyOf(is_const_char_pointer_field, PtrAndRefExclusions(options));
  return fieldDecl(allOf(hasType(supported_pointer_type()), unless(is_skipped)))
      .bind("affectedFieldDecl");
}
// Matches a field of reference type whose pointee is not one of
// unsupported_pointee_types and which is not excluded by PtrAndRefExclusions.
// Binds the field as "affectedFieldDecl" and its reference TypeLoc as
// "affectedFieldDeclType".
clang::ast_matchers::internal::Matcher<clang::Decl> AffectedRawRefFieldDecl(
    const RawPtrAndRefExclusionsOptions& options) {
  auto has_reference_type_loc =
      has(referenceTypeLoc().bind("affectedFieldDeclType"));
  auto has_supported_reference_type =
      hasType(referenceType(unless(unsupported_pointee_types)));
  return fieldDecl(allOf(has_reference_type_loc, has_supported_reference_type,
                         unless(PtrAndRefExclusions(options))))
      .bind("affectedFieldDecl");
}
// Matches a base::raw_ptr<T> / base::raw_ref<T> template specialization type
// location whose first template argument T is a stack-allocated record (per
// |predicate|), unless the location is excluded by PtrAndRefTypeLocExclusions.
//
// Bound nodes: "stackAllocatedRawPtrTypeLoc", "pointerRecordDecl", and
// "pointeeQualType" (bound by StackAllocatedQualType).
clang::ast_matchers::internal::Matcher<clang::TypeLoc>
RawPtrToStackAllocatedTypeLoc(
    const raw_ptr_plugin::StackAllocatedPredicate* predicate) {
  auto raw_ptr_or_ref_record =
      cxxRecordDecl(hasAnyName("base::raw_ptr", "base::raw_ref"))
          .bind("pointerRecordDecl");
  auto first_arg_is_stack_allocated = classTemplateSpecializationDecl(
      hasTemplateArgument(0, refersToType(StackAllocatedQualType(predicate))));
  auto raw_ptr_specialization_type = templateSpecializationType(hasDeclaration(
      allOf(raw_ptr_or_ref_record, first_arg_is_stack_allocated)));
  return templateSpecializationTypeLoc(
             allOf(unless(PtrAndRefTypeLocExclusions()),
                   loc(raw_ptr_specialization_type)))
      .bind("stackAllocatedRawPtrTypeLoc");
}
// Matches a cast of kind BitCast, LValueBitCast, LValueToRValueBitCast,
// PointerToIntegral or IntegralToPointer where the source expression's type
// or the destination type is flagged by |casting_unsafe_predicate|.
//
// Such a cast is reported if it is located where isInStdBitCastHeader()
// holds, or otherwise if none of the exclusions below apply. Exclusions cover
// system/third-party/unspelled code, paths listed in |exclude_files|, implicit
// casts feeding a comparison, implicit casts whose parent is a call to a decl
// listed in |exclude_functions|, casts to a pointer to a const builtin type,
// code where isInRawPtrCastHeader() holds, and implicit casts in template
// instantiations whose parent is an invocation.
//
// Bound nodes: "castExpr"; "srcType" or "dstType" depending on which
// alternative matched; optionally "enclosingCastExpr".
clang::ast_matchers::internal::Matcher<clang::Stmt> BadRawPtrCastExpr(
    const CastingUnsafePredicate& casting_unsafe_predicate,
    const FilterFile& exclude_files,
    const FilterFile& exclude_functions) {
  // A casting-unsafe type, bound under |id|.
  auto casting_unsafe_type = [&casting_unsafe_predicate](llvm::StringRef id) {
    return type(isCastingUnsafe(casting_unsafe_predicate)).bind(id);
  };
  auto src_type = casting_unsafe_type("srcType");
  auto dst_type = casting_unsafe_type("dstType");

  auto reinterpreting_cast_kind =
      castExpr(anyOf(hasCastKind(clang::CK_BitCast),
                     hasCastKind(clang::CK_LValueBitCast),
                     hasCastKind(clang::CK_LValueToRValueBitCast),
                     hasCastKind(clang::CK_PointerToIntegral),
                     hasCastKind(clang::CK_IntegralToPointer)));

  // --- Exclusions ---
  auto implicit_cast_in_comparison =
      implicitCastExpr(hasParent(binaryOperator(isComparisonOperator())));
  auto implicit_cast_in_allowlisted_call =
      implicitCastExpr(hasParent(invocation(hasDeclaration(
          namedDecl(isFieldDeclListedInFilterFile(&exclude_functions))))));
  auto pointer_to_const_builtin =
      type(hasUnqualifiedDesugaredType(pointerType(
          pointee(qualType(allOf(isConstQualified(), builtinType()))))));
  auto cast_to_pointer_to_const_builtin = anyOf(
      implicitCastExpr(hasImplicitDestinationType(pointer_to_const_builtin)),
      explicitCastExpr(hasDestinationType(pointer_to_const_builtin)));
  auto implicit_cast_in_template_invocation = implicitCastExpr(
      allOf(isInTemplateInstantiation(), hasParent(invocation())));
  auto excluded = anyOf(
      isSpellingInSystemHeader(), isInThirdPartyLocation(),
      isNotSpelledInSource(), isInLocationListedInFilterFile(&exclude_files),
      implicit_cast_in_comparison, implicit_cast_in_allowlisted_call,
      cast_to_pointer_to_const_builtin, isInRawPtrCastHeader(),
      implicit_cast_in_template_invocation);

  // --- Final matcher ---
  // Alternative order is kept as-is: anyOf stops at the first match, which
  // decides whether "srcType" or "dstType" gets bound.
  auto involves_unsafe_type =
      anyOf(hasSourceExpression(hasType(src_type)),
            implicitCastExpr(hasImplicitDestinationType(dst_type)),
            explicitCastExpr(hasDestinationType(dst_type)));
  auto enclosing_explicit_cast = hasEnclosingExplicitCastExpr(
      explicitCastExpr().bind("enclosingCastExpr"));
  return castExpr(allOf(involves_unsafe_type, reinterpreting_cast_kind,
                        optionally(enclosing_explicit_cast),
                        anyOf(isInStdBitCastHeader(), unless(excluded))))
      .bind("castExpr");
}
// For a field of an implicitly instantiated class template specialization,
// returns the field of the same name declared in the template instantiation
// pattern. In every other case, including when the pattern lookup does not
// yield a FieldDecl, |field_decl| itself is returned.
const clang::FieldDecl* GetExplicitDecl(const clang::FieldDecl* field_decl) {
  // A field that represents an anonymous struct/union is not mapped.
  if (field_decl->isAnonymousStructOrUnion()) {
    return field_decl;
  }
  const clang::CXXRecordDecl* record_decl =
      clang::dyn_cast<clang::CXXRecordDecl>(field_decl->getParent());
  if (!record_decl) {
    return field_decl;
  }
  const clang::CXXRecordDecl* pattern_decl =
      record_decl->getTemplateInstantiationPattern();
  if (!pattern_decl) {
    return field_decl;
  }
  // Only implicit instantiations are mapped back to the pattern.
  if (record_decl->getTemplateSpecializationKind() !=
      clang::TemplateSpecializationKind::TSK_ImplicitInstantiation) {
    return field_decl;
  }
  clang::DeclContextLookupResult lookup_result =
      pattern_decl->lookup(field_decl->getDeclName());
  // The asserts flag unexpected AST shapes in debug builds. With NDEBUG, fall
  // back to the input field instead of calling front() on an empty result or
  // returning nullptr.
  assert(!lookup_result.empty());
  if (lookup_result.empty()) {
    return field_decl;
  }
  const clang::FieldDecl* pattern_field =
      clang::dyn_cast_or_null<clang::FieldDecl>(lookup_result.front());
  assert(pattern_field);
  if (!pattern_field) {
    return field_decl;
  }
  return pattern_field;
}
// Maps a parameter of a function template instantiation to the corresponding
// parameter of the instantiation pattern.
//
// Returns:
//   * |original_param| when its function is not a template instantiation;
//   * nullptr when its decl context is not a FunctionDecl, when the pattern
//     has more than one function parameter pack, or when the computed pattern
//     index is out of range;
//   * otherwise the pattern's parameter. Every parameter produced by
//     expanding the pattern's (single) parameter pack maps to that pack.
const clang::ParmVarDecl* GetExplicitDecl(
    const clang::ParmVarDecl* original_param) {
  const clang::FunctionDecl* original_func =
      clang::dyn_cast<clang::FunctionDecl>(original_param->getDeclContext());
  if (!original_func) {
    return nullptr;
  }
  const clang::FunctionDecl* pattern_func =
      original_func->getTemplateInstantiationPattern();
  if (!pattern_func) {
    return original_param;
  }
  const unsigned int pattern_param_count = pattern_func->getNumParams();

  // Locate the pattern's function parameter pack; give up on more than one.
  bool has_param_pack = false;
  unsigned int index_of_param_pack = std::numeric_limits<unsigned int>::max();
  for (unsigned int i = 0; i < pattern_param_count; i++) {
    if (!pattern_func->getParamDecl(i)->isParameterPack()) {
      continue;
    }
    if (has_param_pack) {
      return nullptr;
    }
    has_param_pack = true;
    index_of_param_pack = i;
  }

  const unsigned int original_index = original_param->getFunctionScopeIndex();
  // Parameters before the pack (or all of them, without a pack) keep their
  // index.
  unsigned int pattern_index = original_index;
  if (has_param_pack && original_index >= index_of_param_pack) {
    // Number of instantiated parameters the pack expanded into (may be 0).
    const unsigned int pack_expansion_num =
        original_func->getNumParams() - pattern_param_count + 1;
    if (original_index - index_of_param_pack < pack_expansion_num) {
      pattern_index = index_of_param_pack;
    } else {
      // Trailing parameters after the pack: shift back past the expansion.
      pattern_index = original_index - pack_expansion_num + 1;
    }
  }
  // Unexpected in practice; in NDEBUG builds report failure instead of
  // indexing out of bounds.
  assert(pattern_index < pattern_param_count);
  if (pattern_index >= pattern_param_count) {
    return nullptr;
  }
  return pattern_func->getParamDecl(pattern_index);
}
}