#include <assert.h>
#include <algorithm>
#include <array>
#include <cstdlib>
#include <map>
#include <optional>
#include <set>
#include <sstream>
#include <string>
#include <string_view>
#include <variant>
#include <vector>
#include "RawPtrHelpers.h"
#include "SeparateRepositoryPaths.h"
#include "SpanifyManualPathsToIgnore.h"
#include "clang/AST/ASTContext.h"
#include "clang/ASTMatchers/ASTMatchFinder.h"
#include "clang/Basic/SourceLocation.h"
#include "clang/Basic/SourceManager.h"
#include "clang/Rewrite/Core/Rewriter.h"
#include "clang/Tooling/CommonOptionsParser.h"
#include "clang/Tooling/Refactoring.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/TargetSelect.h"
using namespace clang::ast_matchers;
namespace {
// Mode argument for EmitContainerPointerRewrites(), which is defined elsewhere.
// AppendDataCall() passes kDontWrapWithBaseSpan.
// NOTE(review): judging by the names, this selects whether the rewritten
// expression is wrapped in `base::span<...>(...)`. Confirm against
// EmitContainerPointerRewrites() before relying on it.
enum class ContainerPointerRewritesMode {
kWrapWithBaseSpan,
kDontWrapWithBaseSpan,
};
// Forward declarations of helpers defined later in the file.
//
// GetArraySize(): returns the size of the array type at `array_type_loc` as a
// string. getNodeFromFunctionArrayParameter() treats an empty result as "no
// known extent" and then emits `base::span<T>` without a size argument.
std::string GetArraySize(const clang::ArrayTypeLoc& array_type_loc,
const clang::SourceManager& source_manager,
const clang::ASTContext& ast_context);
// EmitContainerPointerRewrites(): emits rewrites under `key` and returns a
// location. AppendDataCall() uses that location as the insertion point for
// its `.data()` replacement.
clang::SourceLocation EmitContainerPointerRewrites(
const MatchFinder::MatchResult& result,
std::string_view key,
ContainerPointerRewritesMode mode);
// Debugging aid. Prints the ID of every node bound in `result`, then a full
// AST dump of each bound node, all to llvm::errs().
void DumpMatchResult(const MatchFinder::MatchResult& result) {
  const auto& bound_nodes = result.Nodes.getMap();
  llvm::errs() << "Matched nodes:\n";
  for (const auto& entry : bound_nodes) {
    llvm::errs() << " - " << entry.first << ":\n";
  }
  for (const auto& [node_id, node] : bound_nodes) {
    llvm::errs() << "\nDump for node " << node_id << ":\n";
    node.dump(llvm::errs(), *result.Context);
  }
}
// Header paths written into include directives. GetIncludeDirective() uses
// kBaseSpanIncludePath by default.
constexpr char kBaseSpanIncludePath[] = "base/containers/span.h";
constexpr char kBaseRawSpanIncludePath[] = "base/memory/raw_span.h";
constexpr char kBaseAutoSpanificationHelperIncludePath[] =
    "base/containers/auto_spanification_helper.h";
// Standard-library headers. These are presumably emitted as system includes.
constexpr char kArrayIncludePath[] = "array";
constexpr char kStringViewIncludePath[] = "string_view";
// Precedence values written as the 4th field of the directives built by
// GetReplacementDirective(). Enumerators are ordinal, so their order here
// defines their numeric values; do not reorder.
// Call sites pass either the positive or the negated value, depending on the
// edit. For example, AppendDataCall() uses +kAppendDataCallPrecedence for the
// opening "(" and -kAppendDataCallPrecedence for the closing ").data()".
// NOTE(review): how the consumer of these directives uses precedence to order
// replacements is not visible in this file.
enum Precedence {
kNeutralPrecedence = 0,
kAppendDataCallPrecedence,
kDecaySpanToPointerPrecedence,
kAdaptBinaryOperationPrecedence,
kEmitSingleVariableSpanPrecedence,
kAdaptBinaryPlusEqOperationPrecedence,
kRewriteUnaryOperationPrecedence,
};
// Returns true when `loc` comes from a macro expansion that must not be
// rewritten.
//  - Locations outside any macro are never excluded.
//  - Locations inside a macro body (not just a macro argument) are always
//    excluded.
//  - Locations inside macro arguments are excluded unless the last macro seen
//    while walking the argument expansions is an assertion-style macro
//    (CHECK, DCHECK, CHECK_*, DCHECK_*, ASSERT_*, EXPECT_*).
bool IsInExcludedMacro(clang::SourceLocation loc,
                       const clang::ASTContext& ast_context,
                       const clang::SourceManager& source_manager) {
  if (!loc.isMacroID()) [[likely]] {
    return false;
  }
  // Step out of macro-argument expansions one level at a time. The name
  // recorded at the final step is kept.
  std::string outermost_macro_name;
  for (; source_manager.isMacroArgExpansion(loc);
       loc = source_manager.getImmediateSpellingLoc(loc)) {
    outermost_macro_name = std::string(clang::Lexer::getImmediateMacroName(
        loc, source_manager, ast_context.getLangOpts()));
  }
  // Still inside a macro: the token is spelled in a macro body.
  if (loc.isMacroID()) {
    return true;
  }
  const std::string_view name = outermost_macro_name;
  const bool is_assertion_macro =
      name == "CHECK" || name == "DCHECK" || name.starts_with("ASSERT_") ||
      name.starts_with("CHECK_") || name.starts_with("DCHECK_") ||
      name.starts_with("EXPECT_");
  return !is_assertion_macro;
}
// Matches a Decl, Stmt or TypeLoc whose begin location is rejected by
// IsInExcludedMacro().
AST_POLYMORPHIC_MATCHER(isInExcludedMacroLocation,
                        AST_POLYMORPHIC_SUPPORTED_TYPES(clang::Decl,
                                                        clang::Stmt,
                                                        clang::TypeLoc)) {
  const clang::ASTContext& context = Finder->getASTContext();
  return IsInExcludedMacro(Node.getBeginLoc(), context,
                           context.getSourceManager());
}
// Matches a FunctionDecl when at least one of its parameters matches
// `parm_var_decl_matcher`. Every matching parameter contributes its own set of
// bound nodes, so the callback runs once per matching parameter, as with the
// built-in `forEach*` matchers.
AST_MATCHER_P(clang::FunctionDecl,
              forEachParmVarDecl,
              clang::ast_matchers::internal::Matcher<clang::ParmVarDecl>,
              parm_var_decl_matcher) {
  const clang::FunctionDecl& function_decl = Node;
  bool is_matching = false;
  clang::ast_matchers::internal::BoundNodesTreeBuilder result;
  for (const clang::ParmVarDecl* param : function_decl.parameters()) {
    // Start every parameter from a fresh copy of the incoming bindings so that
    // bindings made for one parameter don't leak into another's match.
    clang::ast_matchers::internal::BoundNodesTreeBuilder param_matches(
        *Builder);
    if (parm_var_decl_matcher.matches(*param, Finder, &param_matches)) {
      is_matching = true;
      result.addMatch(param_matches);
    }
  }
  *Builder = std::move(result);
  return is_matching;
}
// Matches variables for which clang::VarDecl::hasExternalStorage() is true,
// e.g. `extern int x;`.
AST_MATCHER(clang::VarDecl, hasExternalStorage) {
  const clang::VarDecl& var_decl = Node;
  return var_decl.hasExternalStorage();
}
// Determines the element count of `expr` when it is statically known.
// The count is known when either:
//  - the desugared type of `expr` is a constant-size array, or
//  - `expr` is a string literal (its length plus the terminating NUL).
// On success, writes the count to `*output_size` and returns true. Otherwise
// leaves `*output_size` untouched and returns false.
bool ArraySize(const clang::Expr* expr, uint64_t* output_size) {
  const clang::Type* desugared =
      expr->getType()->getUnqualifiedDesugaredType();
  const auto* constant_array =
      clang::dyn_cast<clang::ConstantArrayType>(desugared);
  if (constant_array != nullptr) {
    *output_size = constant_array->getSize().getLimitedValue();
    return true;
  }
  const auto* string_literal = clang::dyn_cast<clang::StringLiteral>(expr);
  if (string_literal == nullptr) {
    return false;
  }
  *output_size = string_literal->getLength() + 1;
  return true;
}
// Matches `array[index]` when all of the following hold:
//  - the array's size is statically known (see ArraySize()),
//  - the index is a compile-time constant,
//  - the index lies in [0, size).
// A constant index that is out of bounds produces a warning on llvm::errs()
// and does not match.
AST_MATCHER(clang::ArraySubscriptExpr, isSafeArraySubscript) {
  uint64_t size = 0;
  if (!ArraySize(Node.getBase()->IgnoreParenImpCasts(), &size)) {
    return false;
  }
  const clang::Expr* index_expr = Node.getIdx();
  // Value-dependent expressions (e.g. in templates) can't be evaluated.
  if (index_expr->isValueDependent()) {
    return false;
  }
  clang::Expr::EvalResult eval_index;
  if (!index_expr->EvaluateAsInt(eval_index, Finder->getASTContext())) {
    return false;
  }
  // Keep the APSInt rather than slicing it to an APInt, so that signedness is
  // preserved: an unsigned index with its top bit set is a large positive
  // value, not a negative one. The value is streamed directly into formatv so
  // it prints with the correct signedness and at any bit width.
  const auto& index_value = eval_index.Val.getInt();
  const clang::SourceManager& source_manager =
      Finder->getASTContext().getSourceManager();
  if (index_value.isNegative()) {
    llvm::errs() << llvm::formatv(
        "{0}:{1}: Warning: array subscript out of bounds: {2} < 0\n",
        source_manager.getFilename(Node.getExprLoc()),
        source_manager.getSpellingLineNumber(Node.getExprLoc()),
        index_value);
    return false;
  }
  if (index_value.uge(size)) {
    llvm::errs() << llvm::formatv(
        "{0}:{1}: Warning: array subscript out of bounds: {2} >= {3}\n",
        source_manager.getFilename(Node.getExprLoc()),
        source_manager.getSpellingLineNumber(Node.getExprLoc()),
        index_value, size);
    return false;
  }
  return true;
}
// One entry of the table in FindUnsafeFreeFuncToBeRewrittenToMacro(). Field
// order matters: the table is built with positional aggregate initializers.
struct UnsafeFreeFuncToMacro {
// Fully-qualified name of the free function, compared against
// FunctionDecl::getQualifiedNameAsString().
const std::string_view function_name;
// Name of the UNSAFE_* macro that replaces it.
const std::string_view macro_name;
};
// Looks up `function_decl` by its fully-qualified name in the table of free
// functions that are rewritten into UNSAFE_* macros. Returns the matching
// entry, or std::nullopt when the function is not listed.
std::optional<UnsafeFreeFuncToMacro> FindUnsafeFreeFuncToBeRewrittenToMacro(
    const clang::FunctionDecl* function_decl) {
  static constexpr UnsafeFreeFuncToMacro kUnsafeFreeFunctions[] = {
      {"CRYPTO_BUFFER_data", "UNSAFE_CRYPTO_BUFFER_DATA"},
      {"hb_buffer_get_glyph_infos", "UNSAFE_HB_BUFFER_GET_GLYPH_INFOS"},
      {"hb_buffer_get_glyph_positions",
       "UNSAFE_HB_BUFFER_GET_GLYPH_POSITIONS"},
      {"g_get_system_data_dirs", "UNSAFE_G_GET_SYSTEM_DATA_DIRS"},
  };
  const std::string qualified_name =
      function_decl->getQualifiedNameAsString();
  const auto* const entry = std::find_if(
      std::begin(kUnsafeFreeFunctions), std::end(kUnsafeFreeFunctions),
      [&qualified_name](const UnsafeFreeFuncToMacro& candidate) {
        return candidate.function_name == qualified_name;
      });
  if (entry == std::end(kUnsafeFreeFunctions)) {
    return std::nullopt;
  }
  return *entry;
}
// One entry of the table in FindUnsafeCxxMethodToBeRewrittenToMacro(). Field
// order matters: the table is built with positional aggregate initializers.
struct UnsafeCxxMethodToMacro {
// Fully-qualified name of the class declaring the method.
const std::string_view class_name;
// Unqualified method name.
const std::string_view method_name;
// Name of the UNSAFE_* macro that replaces calls to the method.
const std::string_view macro_name;
};
// Looks up `method_decl` in the table of C++ methods that are rewritten into
// UNSAFE_* macros. An entry matches when the unqualified method name and the
// fully-qualified name of the declaring class both match. Returns the entry,
// or std::nullopt.
std::optional<UnsafeCxxMethodToMacro> FindUnsafeCxxMethodToBeRewrittenToMacro(
    const clang::CXXMethodDecl* method_decl) {
  static constexpr UnsafeCxxMethodToMacro kUnsafeMethods[] = {
      {"SkBitmap", "NoArgForTesting", "UNSAFE_SKBITMAP_NOARGFORTESTING"},
      {"SkBitmap", "getAddr32", "UNSAFE_SKBITMAP_GETADDR32"},
  };
  const std::string method_name = method_decl->getNameAsString();
  const std::string class_name =
      method_decl->getParent()->getQualifiedNameAsString();
  const auto* const entry = std::find_if(
      std::begin(kUnsafeMethods), std::end(kUnsafeMethods),
      [&](const UnsafeCxxMethodToMacro& candidate) {
        return candidate.method_name == method_name &&
               candidate.class_name == class_name;
      });
  if (entry == std::end(kUnsafeMethods)) {
    return std::nullopt;
  }
  return *entry;
}
// Matches functions listed in either unsafe-function table. Methods are looked
// up only in the C++ method table; everything else only in the free-function
// table.
AST_MATCHER(clang::FunctionDecl, unsafeFunctionToBeRewrittenToMacro) {
  const auto* method_decl = clang::dyn_cast<clang::CXXMethodDecl>(&Node);
  return method_decl != nullptr
             ? FindUnsafeCxxMethodToBeRewrittenToMacro(method_decl).has_value()
             : FindUnsafeFreeFuncToBeRewrittenToMacro(&Node).has_value();
}
// Formats `value` in decimal, left-padded with '0' to at least `padding`
// characters, e.g. ToStringWithPadding(42, 5) == "00042".
// A value with more digits than `padding` is returned unpadded. Previously
// `padding - str.size()` underflowed in that case once the assert was
// compiled out (NDEBUG), requesting an enormous string.
std::string ToStringWithPadding(size_t value, size_t padding) {
  std::string str = std::to_string(value);
  if (str.size() >= padding) {
    return str;
  }
  return std::string(padding - str.size(), '0') + str;
}
// Hashes `input` with std::hash and encodes the low-order digits of the hash
// as `output_size` characters of a 64-symbol alphabet: digits, upper case,
// lower case, '-', '_'. The least significant digit comes first.
std::string HashBase64(const std::string& input, size_t output_size = 4) {
  constexpr std::string_view kAlphabet =
      "0123456789"
      "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
      "abcdefghijklmnopqrstuvwxyz"
      "-_";
  static_assert(kAlphabet.size() == 64);
  size_t remaining = std::hash<std::string>{}(input);
  std::string encoded;
  encoded.reserve(output_size);
  while (encoded.size() < output_size) {
    encoded.push_back(kAlphabet[remaining % kAlphabet.size()]);
    remaining /= kAlphabet.size();
  }
  return encoded;
}
// Builds the key identifying the text covered by `range`.
//
// With `human_readable`, the key has the form
//   <offset:7 digits>:<hash(path + seed):4>:<file name>:<line>:<column>:<length>
// Otherwise it has the form
//   <offset:7 digits>:<8-char hash of the human-readable key>
// `optional_seed` lets callers distinguish keys that share a range.
template <bool human_readable = false>
std::string NodeKeyFromRange(const clang::SourceRange& range,
                             const clang::SourceManager& source_manager,
                             const std::string& optional_seed = "") {
  const clang::tooling::Replacement replacement(
      source_manager, clang::CharSourceRange::getCharRange(range), "");
  const std::string padded_offset =
      ToStringWithPadding(replacement.getOffset(), 7);
  if constexpr (human_readable) {
    const llvm::StringRef path = replacement.getFilePath();
    return llvm::formatv(
        "{0}:{1}:{2}:{3}:{4}:{5}", padded_offset,
        HashBase64(path.str() + optional_seed),
        llvm::sys::path::filename(path),
        source_manager.getSpellingLineNumber(range.getBegin()),
        source_manager.getSpellingColumnNumber(range.getBegin()),
        replacement.getLength());
  } else {
    const std::string readable_key =
        NodeKeyFromRange<true>(range, source_manager, optional_seed);
    return llvm::formatv("{0}:{1}", padded_offset,
                         HashBase64(readable_key, 8));
  }
}
// Key for the full source range of `node`. Any node type exposing
// getSourceRange() is accepted. See NodeKeyFromRange().
template <typename T>
std::string NodeKey(const T* node,
                    const clang::SourceManager& source_manager,
                    const std::string& optional_seed = "") {
  const clang::SourceRange node_range = node->getSourceRange();
  return NodeKeyFromRange(node_range, source_manager, optional_seed);
}
// Forward declarations; both are defined later in the file.
// GetRHS() returns the node key that the rewrite functions below pass as the
// `node` argument of EmitReplacement().
// NOTE(review): GetLHS() presumably returns the corresponding key for the
// left-hand side; confirm at its definition.
std::string GetRHS(const MatchFinder::MatchResult& result);
std::string GetLHS(const MatchFinder::MatchResult& result);
// Writes `line` to llvm::outs() the first time it is seen. Later identical
// lines are dropped. The set of seen lines lives for the whole process and is
// not synchronized.
void Emit(const std::string& line) {
  static std::set<std::string> emitted;
  // insert() already reports whether the line was new, which avoids the
  // separate count() lookup.
  if (emitted.insert(line).second) {
    llvm::outs() << line;
  }
}
// Emits the record `r <node> <replacement>`, attaching a replacement
// directive to the node key `node`.
void EmitReplacement(std::string_view node, std::string_view replacement) {
  const std::string record =
      llvm::formatv("r {0} {1}\n", node, replacement).str();
  Emit(record);
}
// Emits the record `e <lhs> <rhs>`, linking two node keys.
void EmitEdge(const std::string& lhs, const std::string& rhs) {
  const std::string record = llvm::formatv("e {0} {1}\n", lhs, rhs).str();
  Emit(record);
}
// Emits the record `s <node>`, marking `node` as a source.
void EmitSource(const std::string& node) {
  const std::string record = llvm::formatv("s {0}\n", node).str();
  Emit(record);
}
// Emits the record `i <node>`, marking `node` as a sink.
void EmitSink(const std::string& node) {
  const std::string record = llvm::formatv("i {0}\n", node).str();
  Emit(record);
}
// Emits the record `f <lhs_key> <rhs_key> <replacement>`.
void EmitFrontier(const std::string& lhs_key,
                  const std::string& rhs_key,
                  const std::string& replacement) {
  const std::string record =
      llvm::formatv("f {0} {1} {2}\n", lhs_key, rhs_key, replacement).str();
  Emit(record);
}
// Serializes one text replacement as
//   r:::<file>:::<offset>:::<length>:::<precedence>:::<text>
// The range is interpreted as a half-open character range. Newlines in the
// replacement text are encoded as NUL characters so the directive stays on a
// single line. Asserts that the range resolves to a file.
std::string GetReplacementDirective(const clang::SourceRange& replacement_range,
                                    std::string replacement_text,
                                    const clang::SourceManager& source_manager,
                                    int precedence = kNeutralPrecedence) {
  const clang::tooling::Replacement replacement(
      source_manager, clang::CharSourceRange::getCharRange(replacement_range),
      replacement_text);
  const llvm::StringRef file_path = replacement.getFilePath();
  assert(!file_path.empty() && "Replacement file path is empty.");
  for (char& c : replacement_text) {
    if (c == '\n') {
      c = '\0';
    }
  }
  return llvm::formatv("r:::{0}:::{1}:::{2}:::{3}:::{4}", file_path,
                       replacement.getOffset(), replacement.getLength(),
                       precedence, replacement_text);
}
// Serializes a request to add `#include <include_path>` (system header) or
// `#include "include_path"` (user header) to the file containing the spelling
// location of `replacement_range`'s begin. The directive has the form
//   include-{system,user}-header:::<file>:::-1:::-1:::<include_path>
std::string GetIncludeDirective(const clang::SourceRange replacement_range,
                                const clang::SourceManager& source_manager,
                                const char* include_path = kBaseSpanIncludePath,
                                bool is_system_include_path = false) {
  const char* directive_kind = is_system_include_path
                                   ? "include-system-header"
                                   : "include-user-header";
  const auto target_file =
      GetFilename(source_manager, replacement_range.getBegin(),
                  raw_ptr_plugin::FilenameLocationType::kSpellingLoc);
  return llvm::formatv("{0}:::{1}:::-1:::-1:::{2}", directive_kind,
                       target_file, include_path);
}
// Returns the node bound to `id` in `result`, cast to `T`.
// If no such node is bound, prints `assert_message` and a dump of all bound
// nodes, then terminates. The function never returns null, so callers may
// dereference the result unconditionally.
template <typename T>
const T* GetNodeOrCrash(const MatchFinder::MatchResult& result,
                        std::string_view id,
                        std::string_view assert_message) {
  const T* node = result.Nodes.getNodeAs<T>(id);
  if (!node) {
    llvm::errs() << "\nError: no node for `" << id << "` (" << assert_message
                 << ")\n";
    DumpMatchResult(result);
    assert(false && "`GetNodeOrCrash()`");
    // `assert` is compiled out under NDEBUG. Abort explicitly so that release
    // builds also stop here instead of handing null to callers.
    std::abort();
  }
  return node;
}
// Returns a callable that maps a location through macro-argument expansions
// to where its token is spelled.
//  - Non-macro locations are returned as-is.
//  - If the walk does not end at a valid file location, the input location is
//    returned unchanged.
// The callable refers to `source_manager` and must not outlive it.
// `lang_opts` is not used by the returned callable.
std::function<clang::SourceLocation(clang::SourceLocation)> GetSpellingLocFunc(
    const clang::SourceManager& source_manager [[clang::lifetimebound]],
    const clang::LangOptions& lang_opts [[clang::lifetimebound]]) {
  return [&source_manager](
             clang::SourceLocation original_loc) -> clang::SourceLocation {
    if (!original_loc.isMacroID()) [[likely]] {
      return original_loc;
    }
    clang::SourceLocation spelling_loc = original_loc;
    while (source_manager.isMacroArgExpansion(spelling_loc)) {
      spelling_loc = source_manager.getImmediateSpellingLoc(spelling_loc);
    }
    const bool is_file_location =
        spelling_loc.isValid() && spelling_loc.isFileID();
    return is_file_location ? spelling_loc : original_loc;
  };
}
// Returns the source range of `expr` as a half-open character range: the end
// is one past the last character. Locations are mapped through
// macro-argument expansions via GetSpellingLocFunc(). Several expression
// kinds are special-cased:
//  - MemberExpr: only the member name is covered, not the base expression.
//  - DeclRefExpr: the referenced name, measured by the length of its printed
//    name.
//  - CallExpr: from the expression's begin location through the closing `)`.
//  - BinaryOperator: from its begin location to the end of its RHS,
//    computed recursively.
//  - sizeof: through the end of its last token.
// Any other expression is expected to be a single token. If it is not, an
// error is printed and an assertion fires; under NDEBUG the single-token
// range is still returned.
clang::SourceRange GetExprRange(const clang::Expr& expr,
const clang::SourceManager& source_manager,
const clang::LangOptions& lang_opts) {
auto ToSpellingLoc = GetSpellingLocFunc(source_manager, lang_opts);
if (const auto* member_expr = clang::dyn_cast<clang::MemberExpr>(&expr)) {
clang::SourceLocation member_loc =
ToSpellingLoc(member_expr->getMemberLoc());
size_t member_name_length = member_expr->getMemberDecl()->getName().size();
return {member_loc, member_loc.getLocWithOffset(member_name_length)};
}
if (const auto* decl_ref = clang::dyn_cast<clang::DeclRefExpr>(&expr)) {
// Only unqualified, single-token references are expected here.
assert(decl_ref->getBeginLoc() == decl_ref->getEndLoc() &&
"DeclRefExpr doesn't have the expected end loc.");
clang::SourceLocation begin_loc = ToSpellingLoc(decl_ref->getBeginLoc());
auto name = decl_ref->getNameInfo().getName().getAsString();
return {begin_loc, begin_loc.getLocWithOffset(name.size())};
}
if (const auto* call_expr = clang::dyn_cast<clang::CallExpr>(&expr)) {
// The +1 steps past the `)` itself.
return {ToSpellingLoc(call_expr->getBeginLoc()),
ToSpellingLoc(call_expr->getRParenLoc()).getLocWithOffset(1)};
}
if (auto* binary_op = clang::dyn_cast<clang::BinaryOperator>(&expr)) {
return {
ToSpellingLoc(expr.getBeginLoc()),
GetExprRange(*binary_op->getRHS(), source_manager, lang_opts).getEnd()};
}
if (auto* uett_expr =
clang::dyn_cast<clang::UnaryExprOrTypeTraitExpr>(&expr)) {
// Only `sizeof` is special-cased; other traits (e.g. alignof) fall through
// to the single-token handling below.
if (uett_expr->getKind() == clang::UETT_SizeOf) {
assert(expr.getBeginLoc() != expr.getEndLoc());
clang::SourceLocation begin_loc = ToSpellingLoc(expr.getBeginLoc());
clang::SourceLocation end_loc = ToSpellingLoc(expr.getEndLoc());
// getEndLoc() points at the start of the last token, so add its length.
size_t token_length =
clang::Lexer::MeasureTokenLength(end_loc, source_manager, lang_opts);
return {begin_loc, end_loc.getLocWithOffset(token_length)};
}
}
const clang::SourceLocation begin_location = expr.getBeginLoc();
const clang::SourceLocation end_location = expr.getEndLoc();
if (begin_location != end_location) {
llvm::errs() << "Error: expected token with unhelpful `SourceLocation`s, "
"but got:\n "
<< begin_location.printToString(source_manager) << "\nand\n "
<< end_location.printToString(source_manager) << "\n";
assert(false && "Defaults to a single token expr.");
}
// Single-token fallback: the range covers exactly the token at the begin
// location.
clang::SourceLocation begin_loc = ToSpellingLoc(expr.getBeginLoc());
size_t token_length =
clang::Lexer::MeasureTokenLength(begin_loc, source_manager, lang_opts);
return {begin_loc, begin_loc.getLocWithOffset(token_length)};
}
// Prints `qual_type` the way the rewriter writes types into replacements.
// The type keeps its written sugar (not canonicalized) and its explicit scope
// qualifiers and tag keywords. Unwritten scopes, inline namespaces and default
// template arguments are omitted.
std::string GetTypeAsString(const clang::QualType& qual_type,
                            const clang::ASTContext& ast_context) {
  clang::PrintingPolicy policy(ast_context.getLangOpts());
  policy.PrintAsCanonical = 0;
  policy.SuppressScope = 0;
  policy.SuppressTagKeyword = 0;
  policy.SuppressUnwrittenScope = 1;
  policy.SuppressInlineNamespace = 1;
  policy.SuppressDefaultTemplateArgs = 1;
  return qual_type.getAsString(policy);
}
// Returns the source range that locates the current rewrite. The range is
// chosen from the first of these node IDs that is bound in `result`:
// "unaryOperator", "binaryOperator", "raw_ptr_operator++", "rhs_expr",
// "size_node". Any other combination of bindings indicates a matcher bug: an
// error and a dump of the bound nodes are printed, then the process aborts.
clang::SourceRange getSourceRange(const MatchFinder::MatchResult& result) {
  const clang::SourceManager& source_manager = *result.SourceManager;
  const clang::LangOptions& lang_opts = result.Context->getLangOpts();
  auto ToSpellingLoc = GetSpellingLocFunc(source_manager, lang_opts);
  if (auto* op =
          result.Nodes.getNodeAs<clang::UnaryOperator>("unaryOperator")) {
    if (op->isPostfix()) {
      // Postfix `x++` / `x--`: getEndLoc() is the operator, which is two
      // characters wide.
      return {ToSpellingLoc(op->getBeginLoc()),
              ToSpellingLoc(op->getEndLoc()).getLocWithOffset(2)};
    }
    auto* expr = result.Nodes.getNodeAs<clang::Expr>("rhs_expr");
    return {ToSpellingLoc(op->getBeginLoc()),
            GetExprRange(*expr, source_manager, lang_opts).getEnd()};
  }
  if (auto* op = result.Nodes.getNodeAs<clang::Expr>("binaryOperator")) {
    auto* sub_expr = result.Nodes.getNodeAs<clang::Expr>("binary_op_rhs");
    auto end_loc = GetExprRange(*sub_expr, source_manager, lang_opts).getEnd();
    return {ToSpellingLoc(op->getBeginLoc()), end_loc};
  }
  if (auto* op = result.Nodes.getNodeAs<clang::CXXOperatorCallExpr>(
          "raw_ptr_operator++")) {
    auto* callee = op->getDirectCallee();
    // A member operator++ with no parameters is the prefix form; the postfix
    // form takes a dummy `int`.
    if (callee->getNumParams() == 0) {
      auto* expr = result.Nodes.getNodeAs<clang::Expr>("rhs_expr");
      return clang::SourceRange(
          GetExprRange(*expr, source_manager, lang_opts).getEnd());
    }
    return clang::SourceRange(
        ToSpellingLoc(op->getEndLoc()).getLocWithOffset(2));
  }
  if (auto* expr = result.Nodes.getNodeAs<clang::Expr>("rhs_expr")) {
    return clang::SourceRange(
        GetExprRange(*expr, source_manager, lang_opts).getEnd());
  }
  if (auto* size_expr = result.Nodes.getNodeAs<clang::Expr>("size_node")) {
    return clang::SourceRange(
        GetExprRange(*size_expr, source_manager, lang_opts).getEnd());
  }
  llvm::errs() << "\n"
                  "Error: getSourceRange() encountered an unexpected match.\n"
                  "Expected one of : \n"
                  " - unaryOperator\n"
                  " - binaryOperator\n"
                  " - raw_ptr_operator++\n"
                  " - rhs_expr\n"
                  " - size_node\n"
                  "\n";
  DumpMatchResult(result);
  assert(false && "Unexpected match in getSourceRange()");
  // With NDEBUG the assert is a no-op, and flowing off the end of this
  // non-void function would be undefined behavior, so terminate explicitly.
  std::abort();
}
// Follows a chain of typedef/alias TypeLocs and returns the TypeLoc written in
// the innermost typedef's declaration. Non-typedef TypeLocs are returned
// unchanged.
clang::TypeLoc UnwrapTypedefTypeLoc(clang::TypeLoc type_loc) {
  for (auto typedef_loc = type_loc.getAs<clang::TypedefTypeLoc>(); typedef_loc;
       typedef_loc = type_loc.getAs<clang::TypedefTypeLoc>()) {
    type_loc = typedef_loc.getDecl()->getTypeSourceInfo()->getTypeLoc();
  }
  return type_loc;
}
// Rewrites the pointer type at `type_loc` (e.g. `int*`) into
// `base::span<...>` (e.g. `base::span<int>`). The replacement and the span
// include are emitted under the pointer's node key, which is returned.
// If the bound "qualified_type_loc" is const-qualified and the token right
// before the pointer TypeLoc is `const`, that token is moved inside the span's
// template argument.
std::string getNodeFromPointerTypeLoc(const clang::PointerTypeLoc* type_loc,
                                      const MatchFinder::MatchResult& result) {
  const clang::SourceManager& source_manager = *result.SourceManager;
  const clang::ASTContext& ast_context = *result.Context;
  const auto& lang_opts = ast_context.getLangOpts();
  const clang::SourceRange replacement_range = [type_loc, &result,
                                                &source_manager,
                                                &lang_opts]() {
    const auto* qualified_type_loc =
        result.Nodes.getNodeAs<clang::QualifiedTypeLoc>("qualified_type_loc");
    // Pointee type through the `*`: getEndLoc() is the `*`, and +1 includes
    // it. The local is not named `result`, which would shadow the captured
    // MatchResult.
    clang::SourceRange range = {type_loc->getBeginLoc(),
                                type_loc->getEndLoc().getLocWithOffset(1)};
    if (!qualified_type_loc ||
        !qualified_type_loc->getType().isConstQualified()) {
      return range;
    }
    std::optional<clang::Token> previous_token =
        clang::Lexer::findPreviousToken(type_loc->getBeginLoc(), source_manager,
                                        lang_opts, false);
    if (!previous_token.has_value()) {
      return range;
    }
    std::string_view hopefully_const_qualifier = clang::Lexer::getSourceText(
        clang::CharSourceRange::getCharRange(
            {previous_token->getLocation(), previous_token->getEndLoc()}),
        source_manager, lang_opts);
    if (hopefully_const_qualifier != "const") {
      llvm::errs() << "WARNING: `getNodeFromPointerTypeLoc()` expected "
                      "`const`, but got: "
                   << hopefully_const_qualifier << " instead.\n";
      return range;
    }
    range.setBegin(previous_token->getLocation());
    return range;
  }();
  std::string initial_text =
      clang::Lexer::getSourceText(
          clang::CharSourceRange::getCharRange(replacement_range),
          source_manager, lang_opts)
          .str();
  // Drop the trailing `*`. getSourceText() can yield an empty string when the
  // range can't be read, and pop_back() on an empty string is undefined
  // behavior.
  if (!initial_text.empty()) {
    initial_text.pop_back();
  }
  const std::string replacement_text = "base::span<" + initial_text + ">";
  const std::string key = NodeKey(type_loc, source_manager);
  EmitReplacement(key,
                  GetReplacementDirective(replacement_range, replacement_text,
                                          source_manager));
  EmitReplacement(key, GetIncludeDirective(replacement_range, source_manager));
  return key;
}
// Replaces the template name of the specialization at `raw_ptr_type_loc`
// (the text from its start up to, but not including, the `<`) with
// `base::raw_span`, and requests the raw_span include. Both are emitted under
// the type loc's node key, which is returned.
std::string getNodeFromRawPtrTypeLoc(
    const clang::TemplateSpecializationTypeLoc* raw_ptr_type_loc,
    const MatchFinder::MatchResult& result) {
  const clang::SourceManager& sm = *result.SourceManager;
  const std::string key = NodeKey(raw_ptr_type_loc, sm);
  const clang::SourceRange template_name_range(
      raw_ptr_type_loc->getBeginLoc(), raw_ptr_type_loc->getLAngleLoc());
  const std::string rename_directive =
      GetReplacementDirective(template_name_range, "base::raw_span", sm);
  const std::string include_directive =
      GetIncludeDirective(template_name_range, sm, kBaseRawSpanIncludePath);
  EmitReplacement(key, rename_directive);
  EmitReplacement(key, include_directive);
  return key;
}
// Rewrites an array-typed function parameter (e.g. `int buf[4]`) into a span
// parameter (e.g. `base::span<int, 4> buf`).
//  - When GetArraySize() returns an empty string, no size argument is
//    emitted.
//  - Qualifiers on the parameter's own (decayed pointer) type are kept in
//    front of the span.
// The replacement covers the parameter from its start through the `]`. It is
// emitted under a key seeded with the element type, and that key is returned.
std::string getNodeFromFunctionArrayParameter(
    const clang::TypeLoc* type_loc,
    const clang::ParmVarDecl* param_decl,
    const MatchFinder::MatchResult& result) {
  clang::SourceManager& source_manager = *result.SourceManager;
  const clang::ASTContext& ast_context = *result.Context;
  const clang::QualType param_type = param_decl->getType();
  std::string qualifiers;
  if (param_type.isConstQualified()) {
    qualifiers += "const ";
  }
  if (param_type.isVolatileQualified()) {
    qualifiers += "volatile ";
  }
  const std::string element_type =
      GetTypeAsString(param_type->getPointeeType(), ast_context);
  const clang::ArrayTypeLoc array_type_loc =
      type_loc->getUnqualifiedLoc().getAs<clang::ArrayTypeLoc>();
  assert(!array_type_loc.isNull());
  const std::string extent =
      GetArraySize(array_type_loc, source_manager, ast_context);
  const std::string span_type =
      extent.empty()
          ? llvm::formatv("base::span<{0}> ", element_type).str()
          : llvm::formatv("base::span<{0}, {1}> ", element_type, extent)
                .str();
  const clang::SourceRange replacement_range{
      param_decl->getBeginLoc(),
      array_type_loc.getRBracketLoc().getLocWithOffset(1)};
  const std::string replacement_text =
      qualifiers + span_type + param_decl->getNameAsString();
  const std::string key =
      NodeKeyFromRange(replacement_range, source_manager, element_type);
  EmitReplacement(key,
                  GetReplacementDirective(replacement_range, replacement_text,
                                          source_manager));
  EmitReplacement(key, GetIncludeDirective(replacement_range, source_manager));
  return key;
}
// Rewrites the declared type of the pointer declaration `decl` into
// `base::span<pointee>`, keeping const/volatile qualifiers on the declared
// type. The text from the start of the declaration up to (not including) the
// declared name is replaced. The replacement and span include are emitted
// under a key seeded with the pointee type, and that key is returned.
std::string getNodeFromDecl(const clang::DeclaratorDecl* decl,
                            const MatchFinder::MatchResult& result) {
  clang::SourceManager& source_manager = *result.SourceManager;
  const clang::ASTContext& ast_context = *result.Context;
  const clang::SourceRange replacement_range{decl->getBeginLoc(),
                                             decl->getLocation()};
  const clang::QualType decl_type = decl->getType();
  std::string replacement_text;
  if (decl_type.isConstQualified()) {
    replacement_text += "const ";
  }
  if (decl_type.isVolatileQualified()) {
    replacement_text += "volatile ";
  }
  const std::string pointee_type =
      GetTypeAsString(decl_type->getPointeeType(), ast_context);
  replacement_text += llvm::formatv("base::span<{0}>", pointee_type).str();
  const std::string key =
      NodeKeyFromRange(replacement_range, source_manager, pointee_type);
  EmitReplacement(key,
                  GetReplacementDirective(replacement_range, replacement_text,
                                          source_manager));
  EmitReplacement(key, GetIncludeDirective(replacement_range, source_manager));
  return key;
}
// Rewrites a dereference into an explicit element access, e.g. `*p` becomes
// ` p[0]`.
//  - The first character at "deref_expr" (presumably the `*`) is replaced by
//    " ", or by "(" when "unaryOperator" is bound.
//  - "[0]", or ")[0]" when "unaryOperator" is bound, is inserted at the end
//    of getSourceRange(result).
// Both replacements are emitted under GetRHS(result), with opposite-sign
// precedences.
// NOTE(review): GetRHS(result) is evaluated once per EmitReplacement() call;
// both calls are assumed to return the same key.
void DecaySpanToPointer(const MatchFinder::MatchResult& result) {
const clang::Expr* deref_expr =
result.Nodes.getNodeAs<clang::Expr>("deref_expr");
const clang::SourceManager& source_manager = *result.SourceManager;
auto begin_range = clang::SourceRange(
deref_expr->getBeginLoc(), deref_expr->getBeginLoc().getLocWithOffset(1));
auto end_range = clang::SourceRange(getSourceRange(result).getEnd());
std::string begin_replacement_text = " ";
std::string end_replacement_text = "[0]";
// With a unary operator involved, parenthesize before indexing.
if (result.Nodes.getNodeAs<clang::Expr>("unaryOperator")) {
begin_replacement_text = "(";
end_replacement_text = ")[0]";
}
EmitReplacement(
GetRHS(result),
GetReplacementDirective(begin_range, begin_replacement_text,
source_manager, -kDecaySpanToPointerPrecedence));
EmitReplacement(
GetRHS(result),
GetReplacementDirective(end_range, end_replacement_text, source_manager,
kDecaySpanToPointerPrecedence));
}
// Returns the location of the operator token of a binary operation. The
// operation may be a builtin clang::BinaryOperator, an overloaded
// clang::CXXOperatorCallExpr, or a clang::CXXRewrittenBinaryOperator.
// `result` is used only for diagnostics. Any other node, including null,
// prints an error and aborts.
clang::SourceLocation GetBinaryOperationOperatorLoc(
    const clang::Expr* expr,
    const MatchFinder::MatchResult& result) {
  if (auto* binary_op = clang::dyn_cast_or_null<clang::BinaryOperator>(expr)) {
    return binary_op->getOperatorLoc();
  }
  if (auto* binary_op =
          clang::dyn_cast_or_null<clang::CXXOperatorCallExpr>(expr)) {
    return binary_op->getOperatorLoc();
  }
  if (auto* binary_op =
          clang::dyn_cast_or_null<clang::CXXRewrittenBinaryOperator>(expr)) {
    return binary_op->getOperatorLoc();
  }
  llvm::errs()
      << "\n"
         "Error: GetBinaryOperationOperatorLoc() encountered an unexpected "
         "expression.\n"
         "Expected one of clang::BinaryOperator, clang::CXXOperatorCallExpr, "
         "clang::CXXRewrittenBinaryOperator \n";
  DumpMatchResult(result);
  assert(false && "Unexpected binaryOperation Node");
  // With NDEBUG the assert is a no-op, and flowing off the end of this
  // non-void function would be undefined behavior, so terminate explicitly.
  std::abort();
}
// A single text edit: replace the characters in `range` with `text`.
// Field order matters: instances are built with designated initializers
// (`.range = ..., .text = ...`).
struct RangedReplacement {
clang::SourceRange range;
std::string text;
};
// A pair of edits wrapping an expression: `opener` is inserted at the
// expression's begin (e.g. "base::checked_cast<size_t>(") and `closer` at its
// end (e.g. ")"). Field order matters for designated initializers.
struct CheckedCastReplacement {
RangedReplacement opener;
RangedReplacement closer;
};
// What GetSubspanExprReplacement() decided for a `.subspan()` argument:
//  - std::monostate: no change needed (already an unsigned type no wider
//    than size_t).
//  - RangedReplacement: append a `u` suffix to an integer literal.
//  - CheckedCastReplacement: wrap in `base::checked_cast<size_t>(...)`.
using SubspanExprReplacement =
std::variant<std::monostate, RangedReplacement, CheckedCastReplacement>;
// Decides how `expr` must be adjusted to be passed as a `.subspan()` argument
// (a size_t):
//  - If its type is unsigned and no wider than size_t, nothing changes
//    (std::monostate).
//  - If it is an integer literal, a `u` suffix is appended at its end
//    (RangedReplacement).
//  - Otherwise it is wrapped in `base::checked_cast<size_t>(...)`
//    (CheckedCastReplacement).
// Side effect: in the checked_cast case, include directives for
// base/numerics/safe_conversions.h and <cstdint> are emitted under `key`
// before returning. The caller is responsible for emitting the returned edits.
SubspanExprReplacement GetSubspanExprReplacement(
const clang::Expr* expr,
const MatchFinder::MatchResult& result,
std::string_view key) {
clang::QualType type = expr->getType();
const clang::ASTContext& ast_context = *result.Context;
const uint64_t size_t_bits =
ast_context.getTypeSize(ast_context.getSizeType());
// A type equal to its own unsigned counterpart is already unsigned.
const bool is_unsigned_type =
type == ast_context.getCorrespondingUnsignedType(type);
if (is_unsigned_type && ast_context.getTypeSize(type) <= size_t_bits) {
return {};
}
const clang::SourceManager& source_manager = *result.SourceManager;
const clang::SourceRange range =
GetExprRange(*expr, source_manager, result.Context->getLangOpts());
// A negative literal is a UnaryOperator around the IntegerLiteral, so a bare
// literal is never negative.
if (const auto* integer_literal =
clang::dyn_cast<clang::IntegerLiteral>(expr)) {
assert(integer_literal->getValue().isNonNegative());
return RangedReplacement{.range = range.getEnd(), .text = "u"};
}
EmitReplacement(key, GetIncludeDirective(range, source_manager,
"base/numerics/safe_conversions.h"));
EmitReplacement(key, GetIncludeDirective(range, source_manager, "cstdint",
true));
return CheckedCastReplacement{
.opener = {.range = range.getBegin(),
.text = "base::checked_cast<size_t>("},
.closer = {.range = range.getEnd(), .text = ")"}};
}
// Fallback for a binary operation that sits inside a macro expansion, where
// the operator itself can't be rewritten. Only a bound "declRefExpr" is
// supported; otherwise an error is printed and nothing is emitted. When
// supported, the following are emitted under `key`:
//  - `.data()` appended after the referenced name;
//  - the macro expansion enclosing that name wrapped in `UNSAFE_TODO(...)`.
// NOTE(review): the closing `)` is placed one character past the start of the
// expansion range's last token. This assumes that token is one character wide
// (e.g. the `)` of a function-like macro invocation).
void AdaptBinaryOpInMacro(const MatchFinder::MatchResult& result,
const std::string& key) {
const clang::SourceManager& source_manager = *result.SourceManager;
const clang::ASTContext& ast_context = *result.Context;
const auto& lang_opts = ast_context.getLangOpts();
const auto* decl_ref =
result.Nodes.getNodeAs<clang::DeclRefExpr>("declRefExpr");
if (!decl_ref) {
llvm::errs()
<< "\n"
"Error: In case of a binary operation in a macro expansion, "
"only `declRefExpr` is supported for now.\n";
DumpMatchResult(result);
return;
}
EmitReplacement(
key, GetReplacementDirective(
GetExprRange(*decl_ref, source_manager, lang_opts).getEnd(),
".data()", source_manager));
clang::CharSourceRange macro_range =
source_manager.getExpansionRange(decl_ref->getBeginLoc());
EmitReplacement(key, GetReplacementDirective(macro_range.getBegin(),
"UNSAFE_TODO(", source_manager));
EmitReplacement(
key, GetReplacementDirective(macro_range.getEnd().getLocWithOffset(1),
")", source_manager));
}
// Builds the text that opens a subspan call: `<prefix>.subspan(`. When the
// argument needs a checked cast, the cast's opener is appended, giving
// `<prefix>.subspan(base::checked_cast<size_t>(`.
std::string CreateSubspanOpener(
    std::string_view prefix,
    const SubspanExprReplacement* subspan_expr_replacement) {
  std::string opener(prefix);
  opener += ".subspan(";
  if (const auto* checked_cast =
          std::get_if<CheckedCastReplacement>(subspan_expr_replacement)) {
    opener += checked_cast->opener.text;
  }
  return opener;
}
// Builds the text that closes a subspan call. It consists of whatever must
// follow the argument, followed by `)`:
//  - the literal suffix (e.g. `u`) for a RangedReplacement;
//  - the cast closer for a CheckedCastReplacement;
//  - nothing otherwise.
std::string CreateSubspanCloser(
    const SubspanExprReplacement* subspan_expr_replacement) {
  std::string closer;
  if (const auto* literal_suffix =
          std::get_if<RangedReplacement>(subspan_expr_replacement)) {
    closer = literal_suffix->text;
  } else if (const auto* checked_cast = std::get_if<CheckedCastReplacement>(
                 subspan_expr_replacement)) {
    closer = checked_cast->closer.text;
  }
  closer.push_back(')');
  return closer;
}
// Rewrites pointer arithmetic `lhs <op> rhs` into `lhs.subspan(rhs)`, e.g.
// `p + n` becomes `p.subspan(n)`. When the LHS is a C array
// ("rhs_array_type_loc" bound), it is first wrapped in `base::span<T>(...)`,
// e.g. `arr + n` becomes `base::span<T>(arr).subspan(n)`. The argument is
// adjusted per GetSubspanExprReplacement().
// If both the operation and "rhs_expr" start inside excluded macros, this
// delegates to AdaptBinaryOpInMacro() instead. All edits go under
// GetRHS(result).
// NOTE(review): exactly one character at the operator location is replaced,
// which assumes a one-character operator such as `+`.
void AdaptBinaryOperation(const MatchFinder::MatchResult& result) {
const clang::ASTContext& ast_context = *result.Context;
const clang::SourceManager& source_manager = *result.SourceManager;
const auto* binary_operation =
GetNodeOrCrash<clang::Expr>(result, "binary_operation", __FUNCTION__);
const auto* rhs_expr =
GetNodeOrCrash<clang::Expr>(result, "rhs_expr", __FUNCTION__);
const std::string key = GetRHS(result);
if (IsInExcludedMacro(binary_operation->getBeginLoc(), ast_context,
source_manager) &&
IsInExcludedMacro(rhs_expr->getBeginLoc(), ast_context, source_manager)) {
AdaptBinaryOpInMacro(result, key);
return;
}
const auto* rhs_array_type =
result.Nodes.getNodeAs<clang::ArrayTypeLoc>("rhs_array_type_loc");
if (rhs_array_type) {
// Open `base::span<T>(` before the array operand; the matching `)` is the
// prefix of the subspan opener below.
const auto* concrete_binary_operation =
GetNodeOrCrash<clang::BinaryOperator>(
result, "binary_operation",
"C-style array should not involve `CXXOperatorCallExpr` or "
"`CXXRewrittenBinaryOperator`");
EmitReplacement(
key, GetReplacementDirective(
concrete_binary_operation->getLHS()->getBeginLoc(),
llvm::formatv("base::span<{0}>(",
GetTypeAsString(rhs_array_type->getInnerType(),
*result.Context)),
source_manager, kAdaptBinaryOperationPrecedence));
}
const auto* binary_op_RHS =
GetNodeOrCrash<clang::Expr>(result, "binary_op_rhs", __FUNCTION__);
const auto subspan_expr_replacement =
GetSubspanExprReplacement(binary_op_RHS, result, key);
std::string_view prefix = rhs_array_type ? ")" : "";
std::string subspan_opener =
CreateSubspanOpener(prefix, &subspan_expr_replacement);
const clang::SourceLocation binary_operator_begin =
GetBinaryOperationOperatorLoc(binary_operation, result);
// Replace the operator character with `.subspan(`.
EmitReplacement(
key,
GetReplacementDirective(
{binary_operator_begin, binary_operator_begin.getLocWithOffset(1)},
subspan_opener, source_manager, -kAdaptBinaryOperationPrecedence));
const clang::SourceRange operator_rhs_range = GetExprRange(
*binary_op_RHS, source_manager, result.Context->getLangOpts());
std::string subspan_closer = CreateSubspanCloser(&subspan_expr_replacement);
EmitReplacement(key, GetReplacementDirective(
operator_rhs_range.getEnd(), subspan_closer,
source_manager, -kAdaptBinaryOperationPrecedence));
EmitReplacement(key, GetIncludeDirective(binary_operation->getBeginLoc(),
source_manager));
}
// Rewrites `p += n` into `p = p.subspan(n)`:
//  - The text between the end of the target expression (bound as "rhs_expr")
//    and the start of the added operand ("binary_op_RHS"), i.e. ` += `, is
//    replaced by `= <target text>.subspan(`.
//  - The subspan closer is appended after the operand.
// The operand is adjusted per GetSubspanExprReplacement(). All edits go under
// GetRHS(result).
// NOTE(review): this reads "binary_op_RHS" while AdaptBinaryOperation() and
// getSourceRange() read "binary_op_rhs". Confirm the matcher binds this
// spelling. Neither node is null-checked.
void AdaptBinaryPlusEqOperation(const MatchFinder::MatchResult& result) {
const clang::SourceManager& source_manager = *result.SourceManager;
const clang::ASTContext& ast_context = *result.Context;
const auto& lang_opts = ast_context.getLangOpts();
auto* lhs_expr = result.Nodes.getNodeAs<clang::Expr>("rhs_expr");
auto* binary_op_RHS = result.Nodes.getNodeAs<clang::Expr>("binary_op_RHS");
auto lhs_expr_range = GetExprRange(*lhs_expr, source_manager, lang_opts);
auto binary_op_rhs_range =
GetExprRange(*binary_op_RHS, source_manager, lang_opts);
auto source_range = clang::SourceRange(lhs_expr_range.getEnd(),
binary_op_rhs_range.getBegin());
const std::string& key = GetRHS(result);
auto subspan_arg_fixup =
GetSubspanExprReplacement(binary_op_RHS, result, key);
// The target's own source text is repeated as the receiver of `.subspan(`.
std::string lhs_expr_text =
clang::Lexer::getSourceText(
clang::CharSourceRange::getCharRange(lhs_expr_range), source_manager,
lang_opts)
.str();
EmitReplacement(key,
GetReplacementDirective(
source_range,
CreateSubspanOpener(
std::string(llvm::formatv("= {0}", lhs_expr_text)),
&subspan_arg_fixup),
source_manager, kAdaptBinaryPlusEqOperationPrecedence));
std::string subspan_closer = CreateSubspanCloser(&subspan_arg_fixup);
EmitReplacement(
key, GetReplacementDirective(
clang::SourceRange(binary_op_rhs_range.getEnd()), subspan_closer,
source_manager, -kAdaptBinaryPlusEqOperationPrecedence));
}
// Turns a boolean test of the matched expression into an explicit emptiness
// check:
//  - `!x` becomes `x.empty()`: the `!` of "logical_not_op" is removed.
//  - `x` becomes `!x.empty()`: a `!` is inserted before
//    "boolean_op_operand".
// In both cases `.empty()` is appended at the end of getSourceRange(result).
// All edits go under GetRHS(result).
void DecaySpanToBooleanOp(const MatchFinder::MatchResult& result) {
  const clang::SourceManager& sm = *result.SourceManager;
  const std::string& key = GetRHS(result);
  const auto* logical_not_op =
      result.Nodes.getNodeAs<clang::UnaryOperator>("logical_not_op");
  if (logical_not_op == nullptr) {
    const auto* operand =
        result.Nodes.getNodeAs<clang::Expr>("boolean_op_operand");
    EmitReplacement(key,
                    GetReplacementDirective(operand->getBeginLoc(), "!", sm));
  } else {
    const clang::SourceLocation not_loc = logical_not_op->getBeginLoc();
    EmitReplacement(key, GetReplacementDirective(
                             {not_loc, not_loc.getLocWithOffset(1)}, "", sm));
  }
  const clang::SourceLocation end_loc = getSourceRange(result).getEnd();
  EmitReplacement(key, GetReplacementDirective(end_loc, ".empty()", sm));
}
// Deletes a member call spelled `.name()` or `->name()` after its base
// expression. The code assumes the call is written exactly as
// `<op><name>()`, with no whitespace.
// For `->`, a `*` is inserted at the start of the base, so `p->name()`
// becomes `*p`. Edits are emitted under `node`.
void EraseMemberCall(const std::string& node,
                     const clang::MemberExpr* member_expr,
                     const clang::SourceManager& source_manager) {
  const bool is_arrow = member_expr->isArrow();
  if (is_arrow) {
    // From the base's begin to the member expression's begin; normally an
    // empty range, so this is an insertion.
    EmitReplacement(
        node, GetReplacementDirective({member_expr->getBase()->getBeginLoc(),
                                       member_expr->getBeginLoc()},
                                      "*", source_manager));
  }
  const clang::SourceLocation member_loc = member_expr->getMemberLoc();
  const int operator_width = is_arrow ? 2 : 1;
  const size_t name_length = member_expr->getMemberDecl()->getName().size();
  // The member-access operator, the member name, and the two characters after
  // it (expected to be `()`).
  const clang::SourceRange erased_range(
      member_loc.getLocWithOffset(-operator_width),
      member_loc.getLocWithOffset(name_length + 2));
  EmitReplacement(node,
                  GetReplacementDirective(erased_range, "", source_manager));
}
// Appends `.data()` to the matched expression so a span decays to a pointer.
// - If an "unaryOperator" is bound together with "container_buff_address",
//   the `&container[i]` form is rewritten by EmitContainerPointerRewrites()
//   (without a base::span wrapper). `.data()` then goes at the location it
//   returns.
// - If only "unaryOperator" is bound, the whole expression is parenthesized
//   first, giving `(expr).data()`.
void AppendDataCall(const MatchFinder::MatchResult& result) {
  const clang::SourceManager& sm = *result.SourceManager;
  const std::string key = GetRHS(result);
  clang::SourceRange data_call_range(getSourceRange(result).getEnd());
  std::string data_call_text = ".data()";

  const bool has_unary_op =
      result.Nodes.getNodeAs<clang::Expr>("unaryOperator") != nullptr;
  const bool has_container_address =
      has_unary_op &&
      result.Nodes.getNodeAs<clang::Expr>("container_buff_address") != nullptr;

  if (has_container_address) {
    data_call_range = EmitContainerPointerRewrites(
        result, key, ContainerPointerRewritesMode::kDontWrapWithBaseSpan);
  } else if (has_unary_op) {
    const clang::SourceRange open_paren_at(
        getSourceRange(result).getBegin());
    EmitReplacement(key, GetReplacementDirective(open_paren_at, "(", sm,
                                                 kAppendDataCallPrecedence));
    data_call_text = ").data()";
  }

  EmitReplacement(key,
                  GetReplacementDirective(data_call_range, data_call_text, sm,
                                          -kAppendDataCallPrecedence));
}
// Emits the edits computed by GetSubspanExprReplacement() for `expr`, an
// index expression being turned into a `.subspan()` argument.
// - RangedReplacement: one replacement of `range` by `text` (for example a
//   `u` suffix).
// - CheckedCastReplacement: an opener and a closer replacement wrapped around
//   the expression.
// - std::monostate: nothing needs to change.
// Any other alternative is reported on stderr together with the match dump.
void RewriteExprForSubspan(const clang::Expr* expr,
                           const MatchFinder::MatchResult& result,
                           std::string_view key) {
  const clang::SourceManager& source_manager = *result.SourceManager;
  const auto replacement = GetSubspanExprReplacement(expr, result, key);

  if (const auto* ranged = std::get_if<RangedReplacement>(&replacement)) {
    EmitReplacement(key, GetReplacementDirective(ranged->range, ranged->text,
                                                 source_manager));
    return;
  }

  if (const auto* checked_cast =
          std::get_if<CheckedCastReplacement>(&replacement)) {
    const auto& [opener, closer] = *checked_cast;
    EmitReplacement(key, GetReplacementDirective(opener.range, opener.text,
                                                 source_manager));
    EmitReplacement(key, GetReplacementDirective(closer.range, closer.text,
                                                 source_manager));
    return;
  }

  if (!std::holds_alternative<std::monostate>(replacement)) {
    // Terminate the line so the message does not run into the
    // "Matched nodes:" header printed by DumpMatchResult().
    llvm::errs() << "Unexpected variant in `RewriteExprForSubspan()`.\n";
    DumpMatchResult(result);
  }
}
// Returns the location of the closing `]` of a subscript expression.
// Handles both a built-in ArraySubscriptExpr and an overloaded `operator[]`
// call (CXXOperatorCallExpr), whose RParenLoc is the `]`.
// For any other node kind it reports the match on stderr, asserts in debug
// builds, and returns an invalid SourceLocation. Previously it fell off the
// end of the function in release builds, which is undefined behavior.
clang::SourceLocation FindRightBracket(const MatchFinder::MatchResult& result,
                                       const clang::Expr* subscript_expr) {
  if (const auto* array_subscript_expr =
          clang::dyn_cast<clang::ArraySubscriptExpr>(subscript_expr)) {
    return array_subscript_expr->getRBracketLoc();
  }
  if (const auto* operator_subscript_expr =
          clang::dyn_cast<clang::CXXOperatorCallExpr>(subscript_expr)) {
    return operator_subscript_expr->getRParenLoc();
  }
  llvm::errs() << "Error: no matching cast for `subscript_expr` in "
               << __FUNCTION__ << "\n";
  DumpMatchResult(result);
  assert(false);
  return clang::SourceLocation();
}
// Returns the index operand of a subscript expression.
// - Built-in ArraySubscriptExpr: returns getIdx().
// - Overloaded `operator[]` call: argument 0 is the object and argument 1 is
//   the index, which is returned with implicit casts stripped.
// For any other node kind it reports the match on stderr, asserts in debug
// builds, and returns nullptr. Previously it fell off the end of the function
// in release builds, which is undefined behavior.
const clang::Expr* GetIndexExprForSubspan(
    const MatchFinder::MatchResult& result,
    const clang::Expr* subscript_expr) {
  if (const auto* array_subscript_expr =
          clang::dyn_cast<clang::ArraySubscriptExpr>(subscript_expr)) {
    return array_subscript_expr->getIdx();
  }
  if (const auto* operator_subscript_expr =
          clang::dyn_cast<clang::CXXOperatorCallExpr>(subscript_expr)) {
    assert(operator_subscript_expr->getNumArgs() == 2u);
    return operator_subscript_expr->getArg(1u)->IgnoreImpCasts();
  }
  llvm::errs() << "Error: no matching cast for `subscript_expr` in "
               << __FUNCTION__ << "\n";
  DumpMatchResult(result);
  assert(false);
  return nullptr;
}
// Rewrites an address-of-element expression of the form `&container[index]`
// into a subspan expression, emitting every edit under `key`.
//
// Bound nodes read: "unaryOperator" (the `&`), "subscript_expr" (the `[...]`
// expression), and the container expression: "container_decl_ref" in
// kWrapWithBaseSpan mode, "rhs_expr" in kDontWrapWithBaseSpan mode.
// Optionally read: "zero_container_offset" and "contained_type".
//
// Resulting shapes:
//   - "zero_container_offset" bound: `&` and `[0]` are both erased -> `c`.
//   - kWrapWithBaseSpan: `&` -> `base::span<T>(`, `[` -> `).subspan(`,
//     `]` -> `)`, giving `base::span<T>(c).subspan(i)`.
//   - kDontWrapWithBaseSpan: `&` -> ` `, `[` -> `.subspan(`, `]` -> `)`,
//     giving ` c.subspan(i)`.
// The index expression may additionally be adjusted by RewriteExprForSubspan().
//
// Returns the location one character past the closing `]`, where callers may
// append further text (e.g. `.data()`).
clang::SourceLocation EmitContainerPointerRewrites(
const MatchFinder::MatchResult& result,
std::string_view key,
ContainerPointerRewritesMode mode) {
const clang::SourceManager& source_manager = *result.SourceManager;
const clang::LangOptions& lang_opts = result.Context->getLangOpts();
// One-character range covering the `&` token of the unary operator.
auto replacement_range = GetNodeOrCrash<clang::UnaryOperator>(
result, "unaryOperator", __FUNCTION__)
->getSourceRange();
replacement_range.setEnd(replacement_range.getBegin().getLocWithOffset(1));
const auto* subscript_expr =
GetNodeOrCrash<clang::Expr>(result, "subscript_expr", __FUNCTION__);
// The two modes differ in which node names the container and in whether the
// `[` must also close the `base::span<T>(` opened in place of `&`.
std::string_view declref_bind_name = "container_decl_ref";
std::string_view subspan_opener = ").subspan(";
if (mode == ContainerPointerRewritesMode::kDontWrapWithBaseSpan) {
declref_bind_name = "rhs_expr";
subspan_opener = ".subspan(";
}
const auto& container_decl_ref =
*GetNodeOrCrash<clang::Expr>(result, declref_bind_name, __FUNCTION__);
// NOTE(review): assumes GetExprRange() ends exactly at the `[` that follows
// the container expression (no intervening whitespace) — confirm.
const clang::SourceLocation left_bracket =
GetExprRange(container_decl_ref, source_manager, lang_opts).getEnd();
clang::SourceLocation right_bracket =
FindRightBracket(result, subscript_expr);
// `&c[0]`: drop the `&` and the whole `[0]`, leaving just the container.
if (result.Nodes.getNodeAs<clang::IntegerLiteral>("zero_container_offset")) {
EmitReplacement(key, GetReplacementDirective(replacement_range, "",
*result.SourceManager));
replacement_range = {left_bracket, right_bracket.getLocWithOffset(1)};
EmitReplacement(key, GetReplacementDirective(replacement_range, "",
*result.SourceManager));
return right_bracket.getLocWithOffset(1);
}
// Replace the `&`.
if (mode == ContainerPointerRewritesMode::kWrapWithBaseSpan) {
const auto& contained_type = *GetNodeOrCrash<clang::QualType>(
result, "contained_type", __FUNCTION__);
EmitReplacement(
key,
GetReplacementDirective(
replacement_range,
llvm::formatv("base::span<{0}>(",
GetTypeAsString(contained_type, *result.Context)),
source_manager));
} else {
EmitReplacement(key,
GetReplacementDirective(
replacement_range,
" ", source_manager));
}
// `[` -> subspan opener.
EmitReplacement(key, GetReplacementDirective(
{left_bracket, left_bracket.getLocWithOffset(1)},
std::string(subspan_opener), source_manager));
const clang::Expr* index = GetIndexExprForSubspan(result, subscript_expr);
assert(index);
RewriteExprForSubspan(index, result, key);
// `]` -> `)` closing the `.subspan(` call.
EmitReplacement(key, GetReplacementDirective(
{right_bracket, right_bracket.getLocWithOffset(1)},
")", source_manager));
return right_bracket.getLocWithOffset(1);
}
// Rewrites `&var` (bound as "address_expr", with operand
// "address_expr_operand") into `base::span_from_ref(var)`.
// - The `&` token is replaced by `base::span_from_ref(`.
// - `)` is inserted at the end of the operand's range.
// Precedences are ±kEmitSingleVariableSpanPrecedence.
// On an unexpected match (either node missing) it reports to stderr, asserts
// in debug builds, and returns without emitting anything. Previously release
// builds went on to dereference the null node.
void EmitSingleVariableSpan(const std::string& key,
                            const MatchFinder::MatchResult& result) {
  const clang::SourceManager& source_manager = *result.SourceManager;
  const auto& lang_opts = result.Context->getLangOpts();
  const auto* expr =
      result.Nodes.getNodeAs<clang::UnaryOperator>("address_expr");
  const auto* operand_expr =
      result.Nodes.getNodeAs<clang::Expr>("address_expr_operand");
  if (!expr || !operand_expr) {
    llvm::errs()
        << "\n"
           "Error: EmitSingleVariableSpan() encountered an unexpected match.\n";
    DumpMatchResult(result);
    assert(false && "Unexpected match in EmitSingleVariableSpan()");
    return;
  }
  const clang::SourceLocation ampersand_loc = expr->getOperatorLoc();
  const clang::SourceRange ampersand_range = {
      ampersand_loc, clang::Lexer::getLocForEndOfToken(
                         ampersand_loc, 0u, source_manager, lang_opts)};
  EmitReplacement(key, GetReplacementDirective(
                           ampersand_range, "base::span_from_ref(",
                           source_manager, kEmitSingleVariableSpanPrecedence));
  EmitReplacement(
      key, GetReplacementDirective(
               GetExprRange(*operand_expr, source_manager, lang_opts).getEnd(),
               ")", source_manager, -kEmitSingleVariableSpanPrecedence));
}
// Rewrites a call to a method listed by
// FindUnsafeCxxMethodToBeRewrittenToMacro() into the corresponding macro call:
//   obj.method(a, b)  ->  MACRO(obj, a, b)
//   obj.method()      ->  MACRO(obj)
// The macro name and `(` are inserted before the implicit object argument.
// The text from the member operator (`.`/`->`) up to the first argument, or up
// to the closing parenthesis when there are no arguments, is replaced by
// ", " or "" respectively. The auto-spanification helper header is requested.
void EmitUnsafeCxxMethodCall(const std::string& key,
                             const clang::CXXMemberCallExpr* member_call_expr,
                             const MatchFinder::MatchResult& result) {
  const clang::SourceManager& sm = *result.SourceManager;
  const auto* unsafe_method = GetNodeOrCrash<clang::CXXMethodDecl>(
      result, "unsafe_function_decl",
      "`unsafe_function_call_expr` in clang::CXXMemberCallExpr implies "
      "`unsafe_function_decl` in clang::CXXMethodDecl");
  const UnsafeCxxMethodToMacro entry =
      FindUnsafeCxxMethodToBeRewrittenToMacro(unsafe_method).value();
  const auto* callee =
      clang::dyn_cast<clang::MemberExpr>(member_call_expr->getCallee());
  assert(callee);

  // Open the macro call in front of the object expression.
  const clang::SourceLocation object_begin =
      member_call_expr->getImplicitObjectArgument()->getBeginLoc();
  EmitReplacement(key,
                  GetReplacementDirective(
                      object_begin, llvm::formatv("{0}(", entry.macro_name),
                      sm));

  // Turn `.method(` into the separator between the object and the arguments.
  const bool takes_args = member_call_expr->getNumArgs() != 0;
  const clang::SourceLocation splice_end =
      takes_args ? member_call_expr->getArg(0)->getBeginLoc()
                 : member_call_expr->getRParenLoc();
  const char* splice_text = takes_args ? ", " : "";
  EmitReplacement(key, GetReplacementDirective(
                           clang::SourceRange(callee->getOperatorLoc(),
                                              splice_end),
                           splice_text, sm));

  EmitReplacement(key, GetIncludeDirective(
                           member_call_expr->getSourceRange(), sm,
                           kBaseAutoSpanificationHelperIncludePath));
}
// Replaces the callee name of a free-function call listed by
// FindUnsafeFreeFuncToBeRewrittenToMacro() with the corresponding macro name,
// and requests the auto-spanification helper header. The replaced span starts
// at the callee's begin location and is `entry.function_name.length()`
// characters long.
void EmitUnsafeFreeFuncCall(const std::string& key,
                            const clang::CallExpr* call_expr,
                            const MatchFinder::MatchResult& result) {
  const clang::SourceManager& sm = *result.SourceManager;
  const auto* unsafe_function = GetNodeOrCrash<clang::FunctionDecl>(
      result, "unsafe_function_decl",
      "`unsafe_function_call_expr` implies `unsafe_function_decl`");
  const UnsafeFreeFuncToMacro entry =
      FindUnsafeFreeFuncToBeRewrittenToMacro(unsafe_function).value();

  const clang::SourceLocation name_begin =
      call_expr->getCallee()->getBeginLoc();
  const clang::SourceLocation name_end =
      name_begin.getLocWithOffset(entry.function_name.length());
  EmitReplacement(key, GetReplacementDirective(
                           clang::SourceRange(name_begin, name_end),
                           std::string(entry.macro_name), sm));

  EmitReplacement(key,
                  GetIncludeDirective(call_expr->getSourceRange(), sm,
                                      kBaseAutoSpanificationHelperIncludePath));
}
// Dispatches an unsafe call to the appropriate rewriter. Member calls go to
// EmitUnsafeCxxMethodCall(); everything else is treated as a free-function
// call.
void EmitUnsafeFunctionCall(const std::string& key,
                            const clang::CallExpr* call_expr,
                            const MatchFinder::MatchResult& result) {
  const auto* as_member_call =
      clang::dyn_cast<clang::CXXMemberCallExpr>(call_expr);
  if (as_member_call == nullptr) {
    EmitUnsafeFreeFuncCall(key, call_expr, result);
  } else {
    EmitUnsafeCxxMethodCall(key, as_member_call, result);
  }
}
// Replaces a `std::begin`/`std::end`/`std::cbegin`/`std::cend` callee, applied
// to a C array, with the matching `base::SpanificationArray*` helper. The
// auto-spanification helper header is requested. The replaced range spans the
// whole callee expression, up to the end of its last token.
void EmitCArrayIterCallExpr(const std::string& key,
                            const clang::CallExpr* call_expr,
                            const MatchFinder::MatchResult& result) {
  struct IterFunctionRename {
    std::string_view from;
    std::string_view to;
  };
  static constexpr IterFunctionRename kRenames[] = {
      {"std::begin", "base::SpanificationArrayBegin"},
      {"std::end", "base::SpanificationArrayEnd"},
      {"std::cbegin", "base::SpanificationArrayCBegin"},
      {"std::cend", "base::SpanificationArrayCEnd"},
  };

  const clang::SourceManager& sm = *result.SourceManager;
  const clang::LangOptions& lang_opts = result.Context->getLangOpts();
  const auto* callee_decl =
      clang::dyn_cast<clang::FunctionDecl>(call_expr->getCalleeDecl());
  assert(callee_decl);
  const std::string qualified_name = callee_decl->getQualifiedNameAsString();

  const auto* rename = std::find_if(
      std::begin(kRenames), std::end(kRenames),
      [&qualified_name](const IterFunctionRename& candidate) {
        return candidate.from == qualified_name;
      });
  const std::string replacement_function_name =
      rename == std::end(kRenames) ? std::string() : std::string(rename->to);
  assert(!replacement_function_name.empty());

  const clang::Expr* callee = call_expr->getCallee();
  const clang::SourceRange callee_range(
      callee->getBeginLoc(),
      clang::Lexer::getLocForEndOfToken(callee->getEndLoc(), 0u, sm,
                                        lang_opts));
  EmitReplacement(key, GetReplacementDirective(
                           callee_range, replacement_function_name, sm));
  EmitReplacement(key,
                  GetIncludeDirective(callee_range, sm,
                                      kBaseAutoSpanificationHelperIncludePath));
}
// Builds the graph node for a size/buffer expression and emits the rewrites
// that the bound sub-matches require:
//   - "unsafe_function_call_expr": rewrite the unsafe call to its macro form.
//   - "c_array_iter_call_expr": rewrite std::c?{begin,end} on a C array.
//   - "nullptr_expr": replace the 7-character `nullptr` token with `{}`.
//   - otherwise "address_expr": wrap `&var` with base::span_from_ref().
//   - "container_buff_address": rewrite `&c[i]` into a base::span subspan.
// An include directive anchored at the expression's begin location and a sink
// are then emitted. Returns the node key.
std::string GetNodeFromSizeExpr(const clang::Expr* size_expr,
                                const MatchFinder::MatchResult& result) {
  const clang::SourceManager& sm = *result.SourceManager;
  const auto& nodes = result.Nodes;
  const std::string key = NodeKey(size_expr, sm);

  if (const auto* unsafe_call =
          nodes.getNodeAs<clang::CallExpr>("unsafe_function_call_expr")) {
    EmitUnsafeFunctionCall(key, unsafe_call, result);
  }
  if (const auto* iter_call =
          nodes.getNodeAs<clang::CallExpr>("c_array_iter_call_expr")) {
    EmitCArrayIterCallExpr(key, iter_call, result);
  }

  const clang::SourceLocation size_begin =
      size_expr->getSourceRange().getBegin();
  const clang::SourceRange include_anchor(size_begin, size_begin);

  const auto* null_literal =
      nodes.getNodeAs<clang::CXXNullPtrLiteralExpr>("nullptr_expr");
  if (null_literal != nullptr) {
    const clang::SourceLocation null_begin = null_literal->getBeginLoc();
    const clang::SourceRange null_token(null_begin,
                                        null_begin.getLocWithOffset(7));
    EmitReplacement(key, GetReplacementDirective(null_token, "{}", sm));
  } else if (nodes.getNodeAs<clang::Expr>("address_expr") != nullptr) {
    EmitSingleVariableSpan(key, result);
  }

  if (nodes.getNodeAs<clang::UnaryOperator>("container_buff_address") !=
      nullptr) {
    EmitContainerPointerRewrites(
        result, key, ContainerPointerRewritesMode::kWrapWithBaseSpan);
  }

  EmitReplacement(key, GetIncludeDirective(include_anchor, sm));
  EmitSink(key);
  return key;
}
// Rewrites an increment applied to a (now span-typed) pointer:
//   ++p  ->  base::PreIncrementSpan(p)
//   p++  ->  base::PostIncrementSpan(p)
// Either a built-in UnaryOperator ("unaryOperator") or an overloaded member
// `operator++` call ("raw_ptr_operator++") is accepted. The operator token is
// replaced by the opener (prefix) or by `)` (postfix), and the other half is
// inserted at the operand's end (prefix) or begin (postfix). The
// auto-spanification helper header is requested.
void RewriteUnaryOperation(const MatchFinder::MatchResult& result) {
  const clang::SourceManager& sm = *result.SourceManager;
  const auto& lang_opts = result.Context->getLangOpts();

  const clang::Expr* operand = nullptr;
  bool is_prefix = false;
  clang::SourceLocation op_loc;
  if (const auto* builtin_op =
          result.Nodes.getNodeAs<clang::UnaryOperator>("unaryOperator")) {
    operand = builtin_op->getSubExpr();
    is_prefix = builtin_op->isPrefix();
    op_loc = builtin_op->getOperatorLoc();
  } else if (const auto* overloaded_op =
                 result.Nodes.getNodeAs<clang::CXXOperatorCallExpr>(
                     "raw_ptr_operator++")) {
    operand = overloaded_op->getArg(0);
    const auto* method =
        clang::dyn_cast<clang::CXXMethodDecl>(overloaded_op->getCalleeDecl());
    assert(method);
    // A member prefix operator++ takes no parameters; the postfix form takes
    // a dummy `int`.
    is_prefix = method->getNumParams() == 0;
    op_loc = overloaded_op->getOperatorLoc();
  }

  if (operand == nullptr) {
    llvm::errs()
        << "\n"
        << "Error: RewriteUnaryOperation() encountered an unexpected match.\n"
        << "Expected a unaryOperator or raw_ptr_operator++ to be bound.\n";
    DumpMatchResult(result);
    assert(false && "Unexpected match in RewriteUnaryOperation()");
    return;
  }
  assert(op_loc.isValid());

  const clang::SourceRange operand_range =
      GetExprRange(*operand->IgnoreParenImpCasts(), sm, lang_opts);
  assert(operand_range.isValid());
  const clang::SourceLocation op_end =
      clang::Lexer::getLocForEndOfToken(op_loc, 0, sm, lang_opts);
  assert(op_end.isValid());
  const clang::SourceRange op_token_range(op_loc, op_end);

  const std::string begin_insert_text =
      is_prefix ? "base::PreIncrementSpan(" : "base::PostIncrementSpan(";
  const clang::SourceRange begin_replacement_range =
      is_prefix ? op_token_range
                : clang::SourceRange(operand_range.getBegin(),
                                     operand_range.getBegin());
  const clang::SourceRange end_replacement_range =
      is_prefix
          ? clang::SourceRange(operand_range.getEnd(), operand_range.getEnd())
          : op_token_range;
  assert(begin_replacement_range.isValid());
  assert(end_replacement_range.isValid());

  const std::string key = GetRHS(result);
  EmitReplacement(key, GetReplacementDirective(
                           begin_replacement_range, begin_insert_text, sm,
                           kRewriteUnaryOperationPrecedence));
  EmitReplacement(key, GetReplacementDirective(
                           end_replacement_range, ")", sm,
                           -kRewriteUnaryOperationPrecedence));
  EmitReplacement(key,
                  GetIncludeDirective(operand_range, sm,
                                      kBaseAutoSpanificationHelperIncludePath));
}
// Replaces `sizeof(arr)` / `sizeof arr` (bound as "sizeof_expr") with
// `base::SpanificationSizeofForStdArray(arr)`, where `arr` is the name of the
// "rhs_begin" declaration. Also requests the auto-spanification helper header.
//
// The replaced span ends 1 character past the sizeof's end location, or past
// it by the identifier's length when the argument expression is a bare
// DeclRefExpr. NOTE(review): presumably that end location is the `)` in the
// parenthesized form and the identifier's start in the unparenthesized form —
// confirm.
void RewriteArraySizeof(const MatchFinder::MatchResult& result) {
  clang::SourceManager& source_manager = *result.SourceManager;
  const auto* sizeof_expr =
      result.Nodes.getNodeAs<clang::UnaryExprOrTypeTraitExpr>("sizeof_expr");
  const std::string array_name =
      result.Nodes.getNodeAs<clang::DeclaratorDecl>("rhs_begin")
          ->getNameAsString();

  int end_offset = 1;
  const auto* arg_ref = clang::dyn_cast_or_null<clang::DeclRefExpr>(
      sizeof_expr->getArgumentExpr());
  if (arg_ref != nullptr) {
    end_offset = arg_ref->getNameInfo().getName().getAsString().length();
  }

  const std::string& key = GetRHS(result);
  const clang::SourceRange replacement_range(
      sizeof_expr->getBeginLoc(),
      sizeof_expr->getEndLoc().getLocWithOffset(end_offset));
  const std::string replacement =
      llvm::formatv("base::SpanificationSizeofForStdArray({0})", array_name)
          .str();
  EmitReplacement(key, GetReplacementDirective(replacement_range, replacement,
                                               source_manager));
  EmitReplacement(key,
                  GetIncludeDirective(replacement_range, source_manager,
                                      kBaseAutoSpanificationHelperIncludePath));
}
// Emits a frontier edit between `lhs_key` and `rhs_key` that appends `.data()`
// at the end of getSourceRange(result). If "unaryOperator" is bound, the
// expression is parenthesized first, giving `(expr).data()`. This mirrors
// AppendDataCall() but goes through EmitFrontier().
//
// The original computed an `initial_text` string via Lexer::getSourceText()
// that was never read. That dead lexer call, and the locals only it used,
// have been removed.
void AddSpanFrontierChange(const std::string& lhs_key,
                           const std::string& rhs_key,
                           const MatchFinder::MatchResult& result) {
  const clang::SourceManager& source_manager = *result.SourceManager;
  const auto rep_range = clang::SourceRange(getSourceRange(result).getEnd());
  std::string replacement_text = ".data()";
  if (result.Nodes.getNodeAs<clang::Expr>("unaryOperator")) {
    const auto begin_range =
        clang::SourceRange(getSourceRange(result).getBegin());
    EmitFrontier(lhs_key, rhs_key,
                 GetReplacementDirective(begin_range, "(", source_manager,
                                         kAppendDataCallPrecedence));
    replacement_text = ").data()";
  }
  EmitFrontier(
      lhs_key, rhs_key,
      GetReplacementDirective(rep_range, replacement_text, source_manager,
                              -kAppendDataCallPrecedence));
}
// Derives a PascalCase type name from a variable name.
// - A `kConstant`-style name ('k', then an uppercase letter, longer than two
//   characters, no underscores) loses its leading 'k'.
// - The first character and every character following an underscore are
//   upper-cased (ASCII a-z only).
// - All underscores are dropped.
// Examples: "kFooBar" -> "FooBar", "foo_bar" -> "FooBar".
std::string GenerateClassName(std::string var_name) {
  const bool looks_like_constant =
      var_name.size() > 2 && var_name.front() == 'k' &&
      std::isupper(static_cast<unsigned char>(var_name[1])) &&
      var_name.find('_') == std::string::npos;

  std::string class_name;
  class_name.reserve(var_name.size());
  bool capitalize_next = true;
  for (std::string::size_type i = looks_like_constant ? 1 : 0;
       i < var_name.size(); ++i) {
    char ch = var_name[i];
    if (ch == '_') {
      capitalize_next = true;
      continue;
    }
    if (capitalize_next && ch >= 'a' && ch <= 'z') {
      ch = static_cast<char>(ch - 'a' + 'A');
    }
    capitalize_next = false;
    class_name.push_back(ch);
  }
  return class_name;
}
// For an array whose element type is a record, computes what is needed to
// hoist that record out of the array declaration (e.g. for
// `struct { int x; } kArr[] = {...};`).
//
// Returns {class_name, class_definition}:
//   - class_name: a name generated from `array_variable_as_string` via
//     GenerateClassName() when the record is unnamed; empty otherwise.
//   - class_definition: only when the record's brace range lies inside
//     `array_decl`'s source range. It is `<keyword> <class_name> {body};\n`
//     for an unnamed record, or `<type> {body};\n` for a named one.
// Both are empty when `element_type` is not a record type.
//
// The previous version temporarily renamed the RecordDecl in the shared AST
// (through a const ASTContext) and then restored it. No code between the
// rename and the restore read the name, so the AST mutation was removed. The
// unreachable `isEnum()` branch was dropped too, since a RecordDecl is never
// an enum.
std::pair<std::string, std::string> maybeGetUnnamedAndDefinition(
    const clang::QualType element_type,
    const clang::DeclaratorDecl* array_decl,
    const std::string& array_variable_as_string,
    const clang::ASTContext& ast_context) {
  std::string new_class_name_string;
  std::string class_definition;
  const clang::RecordDecl* record_decl = element_type->getAsRecordDecl();
  if (!record_decl) {
    return std::make_pair(new_class_name_string, class_definition);
  }

  const bool is_unnamed = record_decl->getDeclName().isEmpty();
  if (is_unnamed) {
    new_class_name_string = GenerateClassName(array_variable_as_string);
  }

  // Only re-emit the body when the definition is written inline in the
  // array declaration itself.
  const bool has_definition = array_decl->getSourceRange().fullyContains(
      record_decl->getBraceRange());
  if (has_definition) {
    const clang::SourceManager& source_manager =
        ast_context.getSourceManager();
    const llvm::StringRef struct_body_with_braces =
        clang::Lexer::getSourceText(
            clang::CharSourceRange::getTokenRange(record_decl->getBraceRange()),
            source_manager, ast_context.getLangOpts());
    if (is_unnamed) {
      std::string type_keyword;
      if (record_decl->isClass()) {
        type_keyword = "class";
      } else if (record_decl->isUnion()) {
        type_keyword = "union";
      } else {
        assert(record_decl->isStruct());
        type_keyword = "struct";
      }
      class_definition = type_keyword + " " + new_class_name_string + " " +
                         struct_body_with_braces.str() + ";\n";
    } else {
      const std::string unqualified_type_str =
          element_type.getUnqualifiedType().getAsString();
      class_definition =
          unqualified_type_str + " " + struct_body_with_braces.str() + ";\n";
    }
  }
  return std::make_pair(new_class_name_string, class_definition);
}
// Returns the array bound exactly as written in the source: the text strictly
// between the `[` and `]` of `array_type_loc`. This is empty for `T x[]`.
std::string GetArraySize(const clang::ArrayTypeLoc& array_type_loc,
                         const clang::SourceManager& source_manager,
                         const clang::ASTContext& ast_context) {
  assert(!array_type_loc.isNull());
  const clang::SourceLocation after_lbracket =
      array_type_loc.getLBracketLoc().getLocWithOffset(1);
  const clang::CharSourceRange bound_text_range =
      clang::CharSourceRange::getCharRange(after_lbracket,
                                           array_type_loc.getRBracketLoc());
  const llvm::StringRef bound_text = clang::Lexer::getSourceText(
      bound_text_range, source_manager, ast_context.getLangOpts());
  return bound_text.str();
}
// Spells `type` with every C-array dimension turned into std::array. For
// example `int[2][3]` becomes `std::array<std::array<int, 3>, 2>`. Bounds are
// copied verbatim from the source via GetArraySize(). A non-array type is
// returned as GetTypeAsString().
std::string RewriteCArrayToStdArray(const clang::QualType& type,
                                    const clang::TypeLoc& type_loc,
                                    const clang::SourceManager& source_manager,
                                    const clang::ASTContext& ast_context) {
  const clang::ArrayType* as_array = ast_context.getAsArrayType(type);
  if (as_array == nullptr) {
    return GetTypeAsString(type, ast_context);
  }
  const clang::ArrayTypeLoc array_loc =
      type_loc.getUnqualifiedLoc().getAs<clang::ArrayTypeLoc>();
  assert(!array_loc.isNull());

  // Recurse first so inner dimensions become nested std::array types.
  const std::string inner = RewriteCArrayToStdArray(
      as_array->getElementType(), array_loc.getElementLoc(), source_manager,
      ast_context);
  return "std::array<" + inner + ", " +
         GetArraySize(array_loc, source_manager, ast_context) + ">";
}
// Returns the initializer of `decl`. That is the variable's initializer for a
// VarDecl, or the in-class initializer for a FieldDecl. Returns nullptr for
// any other declaration, a null `decl`, or no initializer.
const clang::Expr* GetInitExpr(const clang::DeclaratorDecl* decl) {
  if (const auto* var = clang::dyn_cast_or_null<clang::VarDecl>(decl)) {
    return var->getInit();
  }
  if (const auto* field = clang::dyn_cast_or_null<clang::FieldDecl>(decl)) {
    return field->getInClassInitializer();
  }
  return nullptr;
}
// Returns the brace initializer list of `decl`, if there is one. It is found
// either directly or as the first child of an ExprWithCleanups wrapper.
// Returns nullptr in every other case.
const clang::InitListExpr* GetArrayInitList(const clang::DeclaratorDecl* decl) {
  const clang::Expr* init = GetInitExpr(decl);
  if (init == nullptr) {
    return nullptr;
  }
  if (const auto* list = clang::dyn_cast<clang::InitListExpr>(init)) {
    return list;
  }
  const auto* cleanups = clang::dyn_cast<clang::ExprWithCleanups>(init);
  const bool has_child =
      cleanups != nullptr && cleanups->child_begin() != cleanups->child_end();
  return has_child
             ? clang::dyn_cast_or_null<clang::InitListExpr>(
                   *cleanups->child_begin())
             : nullptr;
}
// Maps a character element type to its string_view spelling:
//   char -> std::string_view, wchar_t -> std::wstring_view,
//   char8_t/char16_t/char32_t -> std::u8/u16/u32string_view.
// Any other type becomes std::basic_string_view<T>, with T's qualifiers
// stripped.
std::string GetStringViewType(const clang::QualType element_type,
                              const clang::ASTContext& ast_context) {
  const clang::Type* type = element_type.getTypePtr();
  const char* known_view = nullptr;
  if (type->isCharType()) {
    known_view = "std::string_view";
  } else if (type->isWideCharType()) {
    known_view = "std::wstring_view";
  } else if (type->isChar8Type()) {
    known_view = "std::u8string_view";
  } else if (type->isChar16Type()) {
    known_view = "std::u16string_view";
  } else if (type->isChar32Type()) {
    known_view = "std::u32string_view";
  }
  if (known_view != nullptr) {
    return known_view;
  }
  const clang::QualType unqualified(type, 0);
  return llvm::formatv("std::basic_string_view<{0}>",
                       GetTypeAsString(unqualified, ast_context))
      .str();
}
// Decides whether a trailing comma should be added after the last element of
// `init_list_expr` when it is rewritten.
// Returns false when any of these hold:
//   - the list has fewer than 3 elements;
//   - its braces are less than 40 characters apart;
//   - a `,` already appears between the last element and the `}`.
//
// `lang_opts` is used to find the end of the last element's final token. It
// defaults to a default-constructed clang::LangOptions so existing callers keep
// compiling; passing the translation unit's options is more precise.
bool ShouldInsertTrailingComma(
    const clang::InitListExpr* init_list_expr,
    const clang::SourceManager& source_manager,
    const clang::LangOptions& lang_opts = clang::LangOptions()) {
  const clang::SourceLocation rbrace_loc = init_list_expr->getRBraceLoc();
  const int length = source_manager.getFileOffset(rbrace_loc) -
                     source_manager.getFileOffset(init_list_expr->getLBraceLoc());
  if (init_list_expr->getNumInits() < 3 || length < 40) {
    return false;
  }
  const clang::Expr* last_element =
      init_list_expr->getInit(init_list_expr->getNumInits() - 1);
  // getEndLoc() is the *start* of the element's last token. Skip that whole
  // token before scanning. Otherwise a comma inside it (the char literal ','
  // or a string literal like "a,b") is mistaken for a trailing comma.
  clang::SourceLocation scan_begin = clang::Lexer::getLocForEndOfToken(
      last_element->getEndLoc(), 0u, source_manager, lang_opts);
  if (scan_begin.isInvalid()) {
    // E.g. the token comes from a macro expansion: keep the previous
    // one-character heuristic.
    scan_begin = last_element->getEndLoc().getLocWithOffset(1);
  }
  for (auto loc = scan_begin; loc != rbrace_loc;
       loc = loc.getLocWithOffset(1)) {
    if (source_manager.getCharacterData(loc)[0] == ',') {
      return false;
    }
  }
  return true;
}
// Returns true unless some element of `init_list_expr` is itself written
// starting with `{`. Only in that case can the std::array replacement rely on
// brace elision instead of an extra pair of braces.
bool CanElideBracesForStdArrayInitialization(
    const clang::InitListExpr* init_list_expr,
    const clang::SourceManager& source_manager) {
  const auto elements = init_list_expr->inits();
  const auto starts_with_brace =
      [&source_manager](const clang::Expr* element) {
        return *source_manager.getCharacterData(element->getBeginLoc()) == '{';
      };
  return std::none_of(elements.begin(), elements.end(), starts_with_brace);
}
// Produces the rewrite of a C-array declaration that has a brace initializer.
// The result is {declaration prefix, closing replacement directive}:
//   - no explicit size: `auto var = std::to_array<type>(` + `})` at the `}`;
//   - braces can be elided: `std::array<type, size> var = ` + no directive;
//   - otherwise: `std::array<type, size> var = {` + `}}` at the `}`.
// A `,` is prepended to the closer when ShouldInsertTrailingComma() says so.
// A warning is printed to stderr when a non-empty initializer list's length
// differs from a ConstantArrayType's declared size.
std::pair<std::string, std::string> RewriteStdArrayWithInitList(
    const clang::ArrayType* array_type,
    const std::string& type,
    const std::string& var,
    const std::string& size,
    const clang::InitListExpr* init_list_expr,
    const clang::SourceManager& source_manager,
    const clang::ASTContext& ast_context) {
  const bool add_trailing_comma =
      ShouldInsertTrailingComma(init_list_expr, source_manager);
  const clang::SourceLocation rbrace = init_list_expr->getSourceRange().getEnd();
  const clang::SourceRange rbrace_range(rbrace, rbrace.getLocWithOffset(1));

  if (size.empty()) {
    const auto closer = GetReplacementDirective(
        rbrace_range, add_trailing_comma ? ",})" : "})", source_manager);
    return {llvm::formatv("auto {0} = std::to_array<{1}>(", var, type).str(),
            closer};
  }

  if (const auto* constant_array =
          llvm::dyn_cast<clang::ConstantArrayType>(array_type)) {
    const unsigned num_inits = init_list_expr->getNumInits();
    if (num_inits != 0 &&
        constant_array->getSize().getZExtValue() != num_inits) {
      const clang::SourceLocation& location = init_list_expr->getBeginLoc();
      llvm::errs() << "Array and initializer list size mismatch in file "
                   << source_manager.getFilename(location) << ":"
                   << source_manager.getSpellingLineNumber(location) << "\n";
    }
  }

  if (CanElideBracesForStdArrayInitialization(init_list_expr,
                                              source_manager)) {
    return {llvm::formatv("std::array<{0}, {1}> {2} = ", type, size, var).str(),
            std::string()};
  }

  const auto closer = GetReplacementDirective(
      rbrace_range, add_trailing_comma ? ",}}" : "}}", source_manager);
  // `{{` is formatv's escape for a literal `{`.
  return {llvm::formatv("std::array<{0}, {1}> {2} = {{", type, size, var).str(),
          closer};
}
// True only for a FieldDecl declared `mutable`.
bool IsMutable(const clang::DeclaratorDecl* decl) {
  const auto* field = clang::dyn_cast_or_null<clang::FieldDecl>(decl);
  return field != nullptr && field->isMutable();
}
// True only for a VarDecl declared `constexpr`.
bool IsConstexpr(const clang::DeclaratorDecl* decl) {
  const auto* var = clang::dyn_cast_or_null<clang::VarDecl>(decl);
  return var != nullptr && var->isConstexpr();
}
// True only for a VarDecl whose declaration spells `inline`.
bool IsInlineVarDecl(const clang::DeclaratorDecl* decl) {
  const auto* var = clang::dyn_cast_or_null<clang::VarDecl>(decl);
  return var != nullptr && var->isInlineSpecified();
}
// True for a VarDecl that is a static local or has the `static` storage class.
// False for every other declaration kind.
bool IsStaticLocalOrStaticStorageClass(const clang::DeclaratorDecl* decl) {
  const auto* var = clang::dyn_cast_or_null<clang::VarDecl>(decl);
  if (var == nullptr) {
    return false;
  }
  return var->isStaticLocal() || var->getStorageClass() == clang::SC_Static;
}
// Creates the graph node for a C-array declaration and, when the match also
// bound "unsafe_buffer_access", emits the edits rewriting it to std::array
// (or to a string_view for const/constexpr string-literal initialized arrays).
//
// A proxy node keyed at the declaration's begin location is always emitted as
// a sink and returned. When rewriting, the declaration's own node becomes a
// source with an edge to the proxy. Emitted edits:
//   - one replacement from the start of the declaration up to the initializer
//     (or one past the end of the type location when there is no init list);
//   - optional closing-brace replacements;
//   - an include for <array> or <string_view>.
std::string getNodeFromArrayDecl(const clang::TypeLoc* type_loc,
const clang::DeclaratorDecl* array_decl,
const clang::ArrayType* array_type,
const MatchFinder::MatchResult& result) {
clang::SourceManager& source_manager = *result.SourceManager;
const clang::ASTContext& ast_context = *result.Context;
auto proxy_node = NodeKeyFromRange(
clang::SourceRange(array_decl->getBeginLoc(), array_decl->getBeginLoc()),
*result.SourceManager);
EmitSink(proxy_node);
// Without an unsafe buffer access there is nothing to rewrite.
if (!result.Nodes.getNodeAs<clang::Expr>("unsafe_buffer_access")) {
return proxy_node;
}
std::string key = NodeKey(array_decl, *result.SourceManager);
EmitEdge(key, proxy_node);
EmitSource(key);
const clang::ArrayTypeLoc& array_type_loc =
type_loc->getUnqualifiedLoc().getAs<clang::ArrayTypeLoc>();
assert(!array_type_loc.isNull());
const std::string& array_variable_as_string = array_decl->getNameAsString();
// Bound text as written between `[` and `]`; empty for `T x[] = ...`.
const std::string& array_size_as_string =
GetArraySize(array_type_loc, source_manager, ast_context);
const clang::QualType& original_element_type = array_type->getElementType();
// Declaration specifiers to re-emit in front of the new declaration.
std::stringstream qualifier_string;
if (IsInlineVarDecl(array_decl)) {
qualifier_string << "inline ";
}
if (IsMutable(array_decl)) {
qualifier_string << "mutable ";
}
if (IsStaticLocalOrStaticStorageClass(array_decl)) {
qualifier_string << "static ";
}
if (IsConstexpr(array_decl)) {
qualifier_string << "constexpr ";
}
// The element type is spelled without its local `const`. Constness is
// instead expressed as a leading `const` on the whole declaration (omitted
// when `constexpr` already implies it).
clang::QualType new_element_type = original_element_type;
new_element_type.removeLocalConst();
if (original_element_type.isConstant(ast_context) &&
!IsConstexpr(array_decl)) {
qualifier_string << "const ";
}
// Spell the element type: a generated name for an unnamed record, a
// canonical spelling for elaborated type specifiers, or the (possibly nested)
// std::array spelling otherwise.
std::string element_type_as_string;
const auto& [unnamed_class, class_definition] = maybeGetUnnamedAndDefinition(
new_element_type, array_decl, array_variable_as_string, ast_context);
if (!unnamed_class.empty()) {
element_type_as_string = unnamed_class;
} else if (original_element_type->isElaboratedTypeSpecifier()) {
clang::PrintingPolicy printing_policy(ast_context.getLangOpts());
printing_policy.SuppressTagKeyword = 1;
printing_policy.SuppressUnwrittenScope = 1;
printing_policy.SuppressInlineNamespace = 1;
printing_policy.SuppressDefaultTemplateArgs = 1;
printing_policy.PrintAsCanonical = 1;
element_type_as_string = new_element_type.getAsString(printing_policy);
} else {
element_type_as_string = RewriteCArrayToStdArray(
new_element_type, array_type_loc.getElementLoc(), source_manager,
ast_context);
}
const clang::InitListExpr* init_list_expr = GetArrayInitList(array_decl);
const clang::StringLiteral* init_string_literal =
clang::dyn_cast_or_null<clang::StringLiteral>(GetInitExpr(array_decl));
// The declaration text is replaced up to the init list's `{` when present;
// otherwise up to one past the end of the array type location.
clang::SourceRange replacement_range = {
array_decl->getSourceRange().getBegin(),
init_list_expr ? init_list_expr->getBeginLoc()
: type_loc->getSourceRange().getEnd().getLocWithOffset(1)};
const char* include_path = kArrayIncludePath;
std::string replacement_text;
// NOTE(review): `additional_replacement` is never used.
std::string additional_replacement;
if (init_string_literal) {
assert(original_element_type->isAnyCharacterType());
// Const/constexpr character arrays become a string_view of the literal.
if (original_element_type.isConstant(ast_context) ||
IsConstexpr(array_decl)) {
replacement_text = llvm::formatv(
"{0} {1}", GetStringViewType(new_element_type, ast_context),
array_variable_as_string);
include_path = kStringViewIncludePath;
} else {
// Mutable character arrays become `std::array<T, N> var{"..."}`. With no
// written bound, N is the literal's length plus one for the terminator.
replacement_range.setEnd(init_string_literal->getBeginLoc());
replacement_text = llvm::formatv(
"std::array<{0}, {1}> {2}{{", element_type_as_string,
!array_size_as_string.empty()
? array_size_as_string
: llvm::formatv("{0}", init_string_literal->getLength() +
1 ),
array_variable_as_string);
// Insert the matching `}` one past the location of byte getByteLength()
// of the literal.
const clang::SourceLocation& end_of_string_literal =
init_string_literal
->getLocationOfByte(init_string_literal->getByteLength(),
source_manager, ast_context.getLangOpts(),
ast_context.getTargetInfo())
.getLocWithOffset(1);
EmitReplacement(key, GetReplacementDirective(
clang::SourceRange(end_of_string_literal), "}",
source_manager));
}
} else if (init_list_expr) {
auto replacements = RewriteStdArrayWithInitList(
array_type, element_type_as_string, array_variable_as_string,
array_size_as_string, init_list_expr, source_manager, ast_context);
replacement_text = replacements.first;
if (!replacements.second.empty()) {
EmitReplacement(key, replacements.second);
}
} else {
replacement_text =
llvm::formatv("std::array<{0}, {1}> {2}", element_type_as_string,
array_size_as_string, array_variable_as_string);
}
// A hoisted record definition (if any) precedes the specifiers and the new
// declaration.
replacement_text =
class_definition + qualifier_string.str() + replacement_text;
EmitReplacement(key,
GetReplacementDirective(replacement_range, replacement_text,
source_manager));
EmitReplacement(
key, GetIncludeDirective(replacement_range, source_manager, include_path,
true));
return proxy_node;
}
// Returns the graph node key for the C array bound on the left-hand side
// (`is_lhs`: "lhs_*" names) or right-hand side ("rhs_*" names) of the match.
// - A function parameter goes to getNodeFromFunctionArrayParameter().
// - Any other DeclaratorDecl goes to getNodeFromArrayDecl().
// If neither is bound, it reports to stderr, asserts in debug builds, and
// returns an empty key. Previously it fell off the end of this non-void
// function in release builds, which is undefined behavior.
std::string getArrayNode(bool is_lhs, const MatchFinder::MatchResult& result) {
  const std::string array_type_loc_id =
      is_lhs ? "lhs_array_type_loc" : "rhs_array_type_loc";
  const std::string begin_id = is_lhs ? "lhs_begin" : "rhs_begin";
  const std::string array_type_id =
      is_lhs ? "lhs_array_type" : "rhs_array_type";
  const auto* type_loc =
      result.Nodes.getNodeAs<clang::TypeLoc>(array_type_loc_id);
  if (const auto* array_param =
          result.Nodes.getNodeAs<clang::ParmVarDecl>(begin_id)) {
    return getNodeFromFunctionArrayParameter(type_loc, array_param, result);
  }
  const auto* array_decl =
      result.Nodes.getNodeAs<clang::DeclaratorDecl>(begin_id);
  const auto* array_type =
      result.Nodes.getNodeAs<clang::ArrayType>(array_type_id);
  if (array_decl) {
    return getNodeFromArrayDecl(type_loc, array_decl, array_type, result);
  }
  llvm::errs() << "\n"
                  "Error: getArrayNode() encountered an unexpected match.\n"
                  "Expected a clang::DeclaratorDecl \n";
  DumpMatchResult(result);
  assert(false && "Unexpected match in getArrayNode()");
  return std::string();
}
// Handles a comparison against `std::c?{begin,end}(c_array)`. The iterator
// call is rewritten via EmitCArrayIterCallExpr() under its own node key, and
// edges are emitted in both directions between that node and the
// comparison's LHS node.
void RewriteComparisonWithCArrayIter(const MatchFinder::MatchResult& result) {
  const clang::CallExpr* iter_call = GetNodeOrCrash<clang::CallExpr>(
      result, "c_array_iter_call_expr",
      "std::c?{begin,end} for a C array is expected");
  const std::string lhs_key = GetLHS(result);
  const std::string rhs_key = NodeKey(iter_call, *result.SourceManager);
  EmitCArrayIterCallExpr(rhs_key, iter_call, result);
  EmitEdge(lhs_key, rhs_key);
  EmitEdge(rhs_key, lhs_key);
}
// Connects a rewritten function parameter (or return type) to the matching
// slot of a function-pointer variable that the function is stored into.
//
// Bound nodes:
//   - "lhs_funcptrvardecl": the function-pointer variable (required).
//   - "rhs_begin": when bound, the function parameter being rewritten. The
//     parameter at the same index of the function-pointer type is linked.
//     When not bound, the function's return type is the rewrite target and
//     the function-pointer type's return type is linked instead.
//
// Variables declared with `auto` / `decltype(...)` are skipped because no
// function-pointer type is spelled out for them.
//
// The function-reference matcher ignores explicit casts, so the
// function-pointer type may not structurally agree with the function.
// Unexpected shapes still assert in debug builds. In NDEBUG builds the
// function now returns without emitting edges, instead of dereferencing a
// null TypeLoc/ParmVarDecl, indexing past the parameter list, or emitting
// edges for an empty key.
void RewriteFunctionPointerType(const MatchFinder::MatchResult& result) {
  const clang::VarDecl* lhs_var_decl = GetNodeOrCrash<clang::VarDecl>(
      result, "lhs_funcptrvardecl",
      "The rewriting target variable of function pointer type must be bound.");
  clang::FunctionProtoTypeLoc lhs_func_proto_type_loc;
  {
    const clang::TypeLoc var_type_loc =
        UnwrapTypedefTypeLoc(lhs_var_decl->getTypeSourceInfo()->getTypeLoc());
    if (var_type_loc.getAs<clang::AutoTypeLoc>() ||
        var_type_loc.getAs<clang::DecltypeTypeLoc>()) {
      return;
    }
    const clang::PointerTypeLoc pointer_type_loc =
        var_type_loc.getAs<clang::PointerTypeLoc>();
    if (!pointer_type_loc) {
      assert(false && "Failed to get a PointerTypeLoc.");
      return;
    }
    // Skip parentheses between the pointer and the function type, as in
    // `void (*fp)(int*)`.
    clang::TypeLoc pointee_type_loc = pointer_type_loc.getPointeeLoc();
    while (clang::ParenTypeLoc paren_type_loc =
               pointee_type_loc.getAs<clang::ParenTypeLoc>()) {
      pointee_type_loc = paren_type_loc.getInnerLoc();
    }
    lhs_func_proto_type_loc =
        pointee_type_loc.getAs<clang::FunctionProtoTypeLoc>();
  }
  if (!lhs_func_proto_type_loc) {
    assert(false && "Failed to get a FunctionProtoTypeLoc.");
    return;
  }
  const std::string& rhs_key = GetRHS(result);
  std::string lhs_key;
  if (const clang::ParmVarDecl* rhs_parm_var_decl =
          result.Nodes.getNodeAs<clang::ParmVarDecl>("rhs_begin")) {
    const unsigned parm_index = rhs_parm_var_decl->getFunctionScopeIndex();
    // A cast function may have a different arity than the pointer type.
    if (parm_index >= lhs_func_proto_type_loc.getNumParams()) {
      return;
    }
    const clang::ParmVarDecl* lhs_parm_var_decl =
        lhs_func_proto_type_loc.getParam(parm_index);
    if (!lhs_parm_var_decl) {
      // The type loc carries no ParmVarDecl for this parameter.
      return;
    }
    const clang::TypeLoc lhs_parm_type_loc = UnwrapTypedefTypeLoc(
        lhs_parm_var_decl->getTypeSourceInfo()->getTypeLoc());
    if (lhs_parm_type_loc.getAs<clang::ArrayTypeLoc>()) {
      lhs_key = getNodeFromFunctionArrayParameter(&lhs_parm_type_loc,
                                                  lhs_parm_var_decl, result);
    } else if (lhs_parm_type_loc.getAs<clang::PointerTypeLoc>()) {
      lhs_key = getNodeFromDecl(lhs_parm_var_decl, result);
    } else if (const clang::TemplateSpecializationTypeLoc lhs_raw_ptr_type_loc =
                   lhs_parm_type_loc
                       .getAs<clang::TemplateSpecializationTypeLoc>()) {
      lhs_key = getNodeFromRawPtrTypeLoc(&lhs_raw_ptr_type_loc, result);
    } else {
      assert(false && "Unknown kind of clang::TypeLoc at `lhs_parm_type_loc`");
      // Do not emit edges for an empty key.
      return;
    }
  } else {
    // No parameter bound: the function's return type is being rewritten.
    const clang::PointerTypeLoc lhs_return_type_loc =
        lhs_func_proto_type_loc.getReturnLoc().getAs<clang::PointerTypeLoc>();
    if (!lhs_return_type_loc) {
      assert(false && "Expected a pointer return type in the function pointer.");
      return;
    }
    lhs_key = getNodeFromPointerTypeLoc(&lhs_return_type_loc, result);
  }
  EmitEdge(lhs_key, rhs_key);
  EmitEdge(rhs_key, lhs_key);
}
// Ties a function's parameter (or return) type slot to every declaration that
// must be rewritten together with it.
//
// The slot node is keyed by the function declaration plus a suffix naming the
// slot: "<i>-th parm type" when "rhs_begin" is bound, otherwise "return type".
// That node is connected:
//   - in both directions to the node returned by GetRHS;
//   - to the same slot of the previous redeclaration, if any;
//   - to the same slot of every method this one overrides.
// For the last two, only the edge from this declaration to the other one is
// emitted when the other declaration is in a third-party location. Otherwise
// edges go both ways.
void RewriteFunctionParamAndReturnType(const MatchFinder::MatchResult& result) {
  const clang::SourceManager& sm = *result.SourceManager;
  const clang::FunctionDecl* fct_decl =
      result.Nodes.getNodeAs<clang::FunctionDecl>("fct_decl");
  const std::string rhs_key = GetRHS(result);

  const clang::ParmVarDecl* bound_parm =
      result.Nodes.getNodeAs<clang::ParmVarDecl>("rhs_begin");
  const std::string slot_id =
      bound_parm ? llvm::formatv("{0}-th parm type",
                                 bound_parm->getFunctionScopeIndex())
                       .str()
                 : std::string("return type");

  const std::string own_key = NodeKey(fct_decl, sm, slot_id);
  EmitEdge(own_key, rhs_key);
  EmitEdge(rhs_key, own_key);

  // Generic so that NodeKey / isNodeInThirdPartyLocation see the same static
  // declaration types as at each call site below.
  auto link_with = [&](const auto* other_decl) {
    const std::string other_key = NodeKey(other_decl, sm, slot_id);
    const bool other_is_third_party =
        raw_ptr_plugin::isNodeInThirdPartyLocation(*other_decl, sm);
    EmitEdge(own_key, other_key);
    if (!other_is_third_party) {
      EmitEdge(other_key, own_key);
    }
  };

  if (const clang::Decl* previous_decl = fct_decl->getPreviousDecl()) {
    link_with(previous_decl);
  }

  const clang::CXXMethodDecl* method_decl =
      clang::dyn_cast<clang::CXXMethodDecl>(fct_decl);
  if (method_decl == nullptr) {
    return;
  }
  for (const auto* overridden : method_decl->overridden_methods()) {
    link_with(overridden);
  }
}
// Returns the graph-node key for the LHS side of the current match.
//
// The bound ids are checked in this order, and the first one present wins:
//   - "lhs_type_loc"          (PointerTypeLoc)
//   - "lhs_raw_ptr_type_loc"  (TemplateSpecializationTypeLoc)
//   - "lhs_array_type_loc"    (array TypeLoc, resolved by getArrayNode)
//   - "lhs_begin"             (DeclaratorDecl)
//
// If none is bound, the match is dumped to stderr and the assertion fires in
// debug builds. In NDEBUG builds an empty key is returned instead of falling
// off the end of a non-void function, which would be undefined behavior.
std::string GetLHS(const MatchFinder::MatchResult& result) {
  if (auto* type_loc =
          result.Nodes.getNodeAs<clang::PointerTypeLoc>("lhs_type_loc")) {
    return getNodeFromPointerTypeLoc(type_loc, result);
  }
  if (auto* raw_ptr_type_loc =
          result.Nodes.getNodeAs<clang::TemplateSpecializationTypeLoc>(
              "lhs_raw_ptr_type_loc")) {
    return getNodeFromRawPtrTypeLoc(raw_ptr_type_loc, result);
  }
  if (result.Nodes.getNodeAs<clang::TypeLoc>("lhs_array_type_loc")) {
    return getArrayNode(true, result);
  }
  if (auto* lhs_begin =
          result.Nodes.getNodeAs<clang::DeclaratorDecl>("lhs_begin")) {
    return getNodeFromDecl(lhs_begin, result);
  }
  llvm::errs() << "\n"
                  "Error: getLHS() encountered an unexpected match.\n"
                  "Expected one of : \n"
                  "  - lhs_type_loc\n"
                  "  - lhs_raw_ptr_type_loc\n"
                  "  - lhs_array_type_loc\n"
                  "  - lhs_begin\n"
                  "\n";
  DumpMatchResult(result);
  assert(false && "Unexpected match in getLHS()");
  // Reached only when assertions are disabled. Return a well-defined value.
  return std::string();
}
// Rewrites a surrounding `reinterpret_cast` to a byte-like pointer type into a
// span helper, attaching the replacement to `node_key`.
//
// This applies only when both "reinterpret_cast" and
// "reinterpret_cast_to_bytes" are bound. The source text from the start of
// the cast through its closing `>` is replaced with `base::as_byte_span` when
// "target_type" is const-qualified, and with `base::as_writable_byte_span`
// otherwise. Casts bound only as "reinterpret_cast_to_integral_type", and
// matches without a cast, are left untouched.
void RemoveReinterpretCastExpr(const MatchFinder::MatchResult& result,
                               std::string_view node_key) {
  const auto* cast_expr =
      result.Nodes.getNodeAs<clang::CXXReinterpretCastExpr>("reinterpret_cast");
  if (cast_expr == nullptr) {
    return;
  }
  if (!result.Nodes.getNodeAs<clang::QualType>("reinterpret_cast_to_bytes")) {
    return;
  }

  const clang::QualType* target_type = GetNodeOrCrash<clang::QualType>(
      result, "target_type", "`reinterpret_cast` implies `target_type`");
  std::string helper_name;
  if (target_type->isConstQualified()) {
    helper_name = "base::as_byte_span";
  } else {
    helper_name = "base::as_writable_byte_span";
  }

  // Covers `reinterpret_cast<T*>` up to and including the closing `>`.
  const clang::SourceRange cast_prefix(
      cast_expr->getBeginLoc(),
      cast_expr->getAngleBrackets().getEnd().getLocWithOffset(1u));
  EmitReplacement(node_key,
                  GetReplacementDirective(cast_prefix, helper_name,
                                          *result.SourceManager));
}
// Returns the graph-node key for the RHS side of the current match.
//
// The bound ids are checked in this order, and the first one present wins:
//   - "rhs_type_loc", "rhs_raw_ptr_type_loc", "rhs_array_type_loc",
//     "rhs_begin": resolved like their LHS counterparts in GetLHS;
//   - "member_data_call": the bound "data_member_expr" becomes a sink node,
//     and EraseMemberCall is applied to it (presumably dropping the `.data()`
//     call);
//   - "size_node": resolved by GetNodeFromSizeExpr.
//
// If none is bound, the match is dumped to stderr and the assertion fires in
// debug builds. In NDEBUG builds an empty key is returned instead of falling
// off the end of a non-void function, which would be undefined behavior.
std::string GetRHSImpl(const MatchFinder::MatchResult& result) {
  if (auto* type_loc =
          result.Nodes.getNodeAs<clang::PointerTypeLoc>("rhs_type_loc")) {
    return getNodeFromPointerTypeLoc(type_loc, result);
  }
  if (auto* raw_ptr_type_loc =
          result.Nodes.getNodeAs<clang::TemplateSpecializationTypeLoc>(
              "rhs_raw_ptr_type_loc")) {
    return getNodeFromRawPtrTypeLoc(raw_ptr_type_loc, result);
  }
  if (result.Nodes.getNodeAs<clang::TypeLoc>("rhs_array_type_loc")) {
    return getArrayNode(false, result);
  }
  if (auto* rhs_begin =
          result.Nodes.getNodeAs<clang::DeclaratorDecl>("rhs_begin")) {
    return getNodeFromDecl(rhs_begin, result);
  }
  if (result.Nodes.getNodeAs<clang::CXXMemberCallExpr>("member_data_call")) {
    clang::SourceManager& source_manager = *result.SourceManager;
    const clang::MemberExpr* data_member_expr =
        result.Nodes.getNodeAs<clang::MemberExpr>("data_member_expr");
    const std::string key = NodeKey(data_member_expr, source_manager);
    EmitSink(key);
    EraseMemberCall(key, data_member_expr, source_manager);
    return key;
  }
  if (const clang::Expr* size_expr =
          result.Nodes.getNodeAs<clang::Expr>("size_node")) {
    return GetNodeFromSizeExpr(size_expr, result);
  }
  llvm::errs() << "\n"
                  "Error: "
               << __FUNCTION__
               << " encountered an unexpected match.\n"
                  "Expected one of : \n"
                  "  - rhs_type_loc\n"
                  "  - rhs_raw_ptr_type_loc\n"
                  "  - rhs_array_type_loc\n"
                  "  - rhs_begin\n"
                  "  - member_data_call\n"
                  "  - size_node\n"
                  "\n";
  DumpMatchResult(result);
  assert(false && "Unexpected match in GetRHSImpl()");
  // Reached only when assertions are disabled. Return a well-defined value.
  return std::string();
}
// Computes the RHS node key (GetRHSImpl) and then lets
// RemoveReinterpretCastExpr attach a `reinterpret_cast` rewrite to that key,
// if one applies. Returns the key.
std::string GetRHS(const MatchFinder::MatchResult& result) {
  std::string rhs_key = GetRHSImpl(result);
  RemoveReinterpretCastExpr(result, rhs_key);
  return rhs_key;
}
// Generic callback that relates a rewrite target (LHS) to a value flowing into
// it (RHS) by emitting one directed edge LHS -> RHS.
//
// When the RHS was matched as a "span_frontier" expression, the frontier
// change is registered through AddSpanFrontierChange before the edge is
// emitted.
void MatchAdjacency(const MatchFinder::MatchResult& result) {
  // LHS first, then RHS. GetRHS may also emit a reinterpret_cast replacement.
  std::string lhs_key = GetLHS(result);
  std::string rhs_key = GetRHS(result);
  const bool at_span_frontier =
      result.Nodes.getNodeAs<clang::Expr>("span_frontier") != nullptr;
  if (at_span_frontier) {
    AddSpanFrontierChange(lhs_key, rhs_key, result);
  }
  EmitEdge(lhs_key, rhs_key);
}
// Builds the filter of paths the spanifier must not rewrite: the entries of
// kSpanifyManualPathsToIgnore followed by those of kSeparateRepositoryPaths.
raw_ptr_plugin::FilterFile PathsToExclude() {
  std::vector<std::string> excluded_paths(kSpanifyManualPathsToIgnore.begin(),
                                          kSpanifyManualPathsToIgnore.end());
  excluded_paths.insert(excluded_paths.end(), kSeparateRepositoryPaths.begin(),
                        kSeparateRepositoryPaths.end());
  return raw_ptr_plugin::FilterFile(excluded_paths);
}
// BoundNodesTreeBuilder visitor that records the expression bound under the
// id "LHS". At most one such match is expected (asserted in debug builds).
class ExprVisitor
    : public clang::ast_matchers::internal::BoundNodesTreeBuilder::Visitor {
 public:
  void visitMatch(const clang::ast_matchers::BoundNodes& bound_nodes) override {
    assert(expr_ == nullptr &&
           "Encountered more than one expression with match id 'LHS'.");
    expr_ = bound_nodes.getNodeAs<clang::Expr>("LHS");
  }

  // The recorded "LHS" expression; stays nullptr until a match is visited.
  const clang::Expr* expr_ = nullptr;
};
// Returns the expression bound as "LHS" among the matches recorded in
// `matches`, or nullptr when none was bound.
const clang::Expr* FindLHSExpr(
    clang::ast_matchers::internal::BoundNodesTreeBuilder& matches) {
  ExprVisitor lhs_collector;
  matches.visitMatches(&lhs_collector);
  return lhs_collector.expr_;
}
// Matches a chain of `+` / `-` operations by applying `InnerMatcher` to the
// left-most binary operation of the chain.
//
// If the node (after ignoring parens and casts) is a `+`/`-` operation whose
// LHS is itself a `+`/`-` operation, the matcher recurses into that LHS.
// Otherwise `InnerMatcher` is applied to the node. Bindings produced while
// probing for the inner "LHS" operation are kept in a local builder and
// discarded. Only the final InnerMatcher application writes to `Builder`.
AST_MATCHER_P(clang::Expr,
              binary_plus_or_minus_operation,
              clang::ast_matchers::internal::Matcher<clang::Expr>,
              InnerMatcher) {
  auto chained_operation = expr(ignoringParenCasts(
      binaryOperation(anyOf(hasOperatorName("+"), hasOperatorName("-")),
                      hasLHS(expr(binaryOperation(anyOf(hasOperatorName("+"),
                                                        hasOperatorName("-"))))
                                 .bind("LHS")))));
  clang::ast_matchers::internal::BoundNodesTreeBuilder probe_bindings;
  if (!chained_operation.matches(Node, Finder, &probe_bindings)) {
    // Not part of a longer chain: this node is the left-most operation.
    return InnerMatcher.matches(Node, Finder, Builder);
  }
  const clang::Expr* inner_operation = FindLHSExpr(probe_bindings);
  return binary_plus_or_minus_operation(InnerMatcher)
      .matches(*inner_operation, Finder, Builder);
}
// Registers every AST matcher used by the spanifier on a MatchFinder.
//
// The constructor first builds the matcher vocabulary: which pointer, raw_ptr
// and array declarations are rewrite candidates, what counts as an LHS or RHS
// expression, and what counts as a "size node". It then registers, through
// Match(), one callback per relationship of interest: unsafe buffer accesses,
// dereferences, boolean conversions, calls into excluded code, arithmetic,
// assignments, initializations, returns, argument passing, function pointers
// and function redeclarations.
//
// The MatchFinder stores only raw pointers to the callback objects, which are
// owned by this Spanifier (`match_callbacks_`). A Spanifier must therefore
// outlive every use of the MatchFinder it was constructed with.
class Spanifier {
 public:
  explicit Spanifier(MatchFinder& finder) : match_finder_(finder) {
    // Locations that must never be rewritten: system headers, extern "C"
    // contexts, third-party or generated code, implicit field declarations,
    // excluded macro locations, and paths listed in `paths_to_exclude_`.
    auto frontier_exclusions = anyOf(
        isExpansionInSystemHeader(), raw_ptr_plugin::isInExternCContext(),
        raw_ptr_plugin::isInThirdPartyLocation(),
        raw_ptr_plugin::isInGeneratedLocation(),
        raw_ptr_plugin::ImplicitFieldDeclaration(), isInExcludedMacroLocation(),
        raw_ptr_plugin::isInLocationListedInFilterFile(&paths_to_exclude_));
    // Additionally exclude anything nested inside a class named `raw_ptr` or
    // `span`.
    auto exclusions = anyOf(
        frontier_exclusions,
        hasAncestor(cxxRecordDecl(anyOf(hasName("raw_ptr"), hasName("span")))));
    // Pointer types eligible for rewriting. Pointers to anonymous
    // structs/unions, function types, member pointers, void, and const
    // character types are excluded.
    auto non_auto_pointer_type = pointerType(pointee(qualType(unless(
        anyOf(qualType(hasDeclaration(
                  cxxRecordDecl(raw_ptr_plugin::isAnonymousStructOrUnion()))),
              hasUnqualifiedDesugaredType(
                  anyOf(functionType(), memberPointerType(), voidType())),
              hasCanonicalType(
                  anyOf(asString("const char"), asString("const wchar_t"),
                        asString("const char8_t"), asString("const char16_t"),
                        asString("const char32_t"))))))));
    // Also accept `auto` and `decltype(...)` types that resolve to such a
    // pointer.
    auto pointer_type = type(anyOf(
        non_auto_pointer_type,
        autoType(hasDeducedType(anyOf(
            qualType(non_auto_pointer_type),
            decltypeType(hasUnderlyingType(qualType(non_auto_pointer_type)))))),
        decltypeType(hasUnderlyingType(qualType(non_auto_pointer_type)))));
    // A pointer TypeLoc, binding a qualified pointee as "qualified_type_loc"
    // when there is one.
    auto pointer_type_loc = pointerTypeLoc(optionally(
        hasPointeeLoc(qualifiedTypeLoc().bind("qualified_type_loc"))));
    auto raw_ptr_type = qualType(
        hasDeclaration(classTemplateSpecializationDecl(hasName("raw_ptr"))));
    auto raw_ptr_type_loc = templateSpecializationTypeLoc(loc(raw_ptr_type));
    // Types of rewrite candidates: pointers, raw_ptr<T> (binding its TypeLoc)
    // or arrays (binding the array TypeLoc). Only the LHS variant also binds
    // the ArrayType ("lhs_array_type").
    auto lhs_type_loc = anyOf(
        hasType(pointer_type),
        allOf(hasType(raw_ptr_type),
              hasDescendant(raw_ptr_type_loc.bind("lhs_raw_ptr_type_loc"))),
        hasTypeLoc(loc(qualType(arrayType().bind("lhs_array_type")))
                       .bind("lhs_array_type_loc")));
    auto rhs_type_loc = anyOf(
        hasType(pointer_type),
        allOf(hasType(raw_ptr_type),
              hasDescendant(raw_ptr_type_loc.bind("rhs_raw_ptr_type_loc"))),
        hasTypeLoc(loc(qualType(arrayType())).bind("rhs_array_type_loc")));
    // Candidate declarations, bound as "lhs_begin" / "rhs_begin". Fields
    // declared directly in `raw_ptr` and variables with external storage are
    // skipped.
    auto lhs_field =
        fieldDecl(raw_ptr_plugin::hasExplicitFieldDecl(lhs_type_loc),
                  unless(exclusions),
                  unless(hasParent(cxxRecordDecl(hasName("raw_ptr")))))
            .bind("lhs_begin");
    auto rhs_field =
        fieldDecl(raw_ptr_plugin::hasExplicitFieldDecl(rhs_type_loc),
                  unless(exclusions),
                  unless(hasParent(cxxRecordDecl(hasName("raw_ptr")))))
            .bind("rhs_begin");
    auto lhs_var =
        varDecl(lhs_type_loc, unless(anyOf(exclusions, hasExternalStorage())))
            .bind("lhs_begin");
    auto rhs_var =
        varDecl(rhs_type_loc, unless(anyOf(exclusions, hasExternalStorage())))
            .bind("rhs_begin");
    auto lhs_param =
        parmVarDecl(lhs_type_loc, unless(exclusions)).bind("lhs_begin");
    auto rhs_param =
        parmVarDecl(rhs_type_loc, unless(exclusions)).bind("rhs_begin");
    // Functions returning a pointer to a const character type are skipped.
    auto exclude_literal_strings =
        unless(returns(qualType(pointsTo(qualType(hasCanonicalType(
            anyOf(asString("const char"), asString("const wchar_t"),
                  asString("const char8_t"), asString("const char16_t"),
                  asString("const char32_t"))))))));
    // Calls to non-excluded functions returning a pointer. The return TypeLoc
    // is bound as "rhs_type_loc" / "lhs_type_loc".
    auto rhs_call_expr = callExpr(callee(
        functionDecl(hasReturnTypeLoc(pointer_type_loc.bind("rhs_type_loc")),
                     exclude_literal_strings, unless(exclusions))));
    auto lhs_call_expr = callExpr(callee(
        functionDecl(hasReturnTypeLoc(pointer_type_loc.bind("lhs_type_loc")),
                     exclude_literal_strings, unless(exclusions))));
    auto lhs_expr = expr(anyOf(declRefExpr(to(anyOf(lhs_var, lhs_param))),
                               memberExpr(member(lhs_field)), lhs_call_expr));
    // `&container[i]`, where `container` is either a variable of a class
    // template specialization whose class declares operator[] and a `size`
    // method, or a C array. Binds "contained_type", "container_decl_ref",
    // "subscript_expr", "unaryOperator" and, optionally,
    // "zero_container_offset".
    // NOTE(review): in the operator[] branch, `hasDescendant` binds
    // "zero_container_offset" for a literal 0 anywhere in the subscript
    // expression (e.g. `&v[f(0)]`), not only for the index itself. The C-array
    // branch uses `hasIndex`. Confirm this is intended.
    auto buff_address_from_container =
        unaryOperator(
            hasOperatorName("&"),
            hasUnaryOperand(anyOf(
                cxxOperatorCallExpr(
                    callee(functionDecl(
                        hasName("operator[]"),
                        hasParent(cxxRecordDecl(hasMethod(hasName("size")))))),
                    hasDescendant(
                        declRefExpr(
                            to(varDecl(hasType(classTemplateSpecializationDecl(
                                hasTemplateArgument(
                                    0, refersToType(qualType().bind(
                                           "contained_type"))))))))
                            .bind("container_decl_ref")),
                    optionally(
                        hasDescendant(integerLiteral(equals(0u))
                                          .bind("zero_container_offset"))))
                    .bind("subscript_expr"),
                arraySubscriptExpr(
                    hasBase(
                        declRefExpr(to(varDecl(hasType(arrayType(hasElementType(
                                        qualType().bind("contained_type")))))))
                            .bind("container_decl_ref")),
                    optionally(hasIndex(integerLiteral(equals(0u))
                                            .bind("zero_container_offset"))))
                    .bind("subscript_expr"))))
            .bind("unaryOperator");
    // `container.data()` on a class that also declares a `size` method.
    auto member_data_call =
        cxxMemberCallExpr(
            callee(functionDecl(
                hasName("data"),
                hasParent(cxxRecordDecl(hasMethod(hasName("size")))))),
            has(memberExpr().bind("data_member_expr")))
            .bind("member_data_call");
    auto has_std_array_type = hasType(hasCanonicalType(hasDeclaration(
        classTemplateSpecializationDecl(hasName("::std::array")))));
    // `&var` / `&obj.field` for a non-array, non-function, non-std::array
    // operand.
    auto single_var_span_exclusions =
        unless(anyOf(exclusions, hasType(arrayType()), hasType(functionType()),
                     has_std_array_type));
    auto buff_address_from_single_var =
        unaryOperator(
            hasOperatorName("&"),
            hasUnaryOperand(anyOf(
                declRefExpr(to(anyOf(varDecl(single_var_span_exclusions),
                                     parmVarDecl(single_var_span_exclusions))))
                    .bind("address_expr_operand"),
                memberExpr(member(fieldDecl(single_var_span_exclusions)))
                    .bind("address_expr_operand"))))
            .bind("address_expr");
    // `std::begin`/`std::end` (and the `c`-prefixed forms) on a C array.
    // NOTE(review): the regex is unanchored, so it matches any qualified name
    // that contains e.g. "std::begin".
    auto c_array_iter_call_expr =
        callExpr(callee(functionDecl(matchesName("std::c?(begin|end)"))),
                 hasArgument(0, hasType(arrayType())))
            .bind("c_array_iter_call_expr");
    // Optionally binds an enclosing `reinterpret_cast` to a pointer to
    // uint8_t / a character type ("reinterpret_cast_to_bytes") or to another
    // integer type ("reinterpret_cast_to_integral_type").
    const auto reinterpret_cast_wrapper = optionally(hasParent(
        cxxReinterpretCastExpr(
            hasDestinationType(qualType(pointsTo(
                qualType(anyOf(qualType(asString("uint8_t"))
                                   .bind("reinterpret_cast_to_bytes"),
                               qualType(isAnyCharacter())
                                   .bind("reinterpret_cast_to_bytes"),
                               qualType(isInteger())
                                   .bind("reinterpret_cast_to_integral_type")))
                    .bind("target_type")))),
            unless(isInExcludedMacroLocation()))
            .bind("reinterpret_cast")));
    // RHS expressions whose buffer size can be determined at the expression
    // itself, bound as "size_node" (except `.data()` calls).
    auto size_node_matcher = expr(
        anyOf(
            member_data_call,
            expr(anyOf(callExpr(callee(functionDecl(
                                           unsafeFunctionToBeRewrittenToMacro())
                                           .bind("unsafe_function_decl")))
                           .bind("unsafe_function_call_expr"),
                       c_array_iter_call_expr,
                       callExpr(callee(functionDecl(
                           hasReturnTypeLoc(pointer_type_loc),
                           anyOf(raw_ptr_plugin::isInThirdPartyLocation(),
                                 isExpansionInSystemHeader(),
                                 raw_ptr_plugin::isInExternCContext())))),
                       cxxNullPtrLiteralExpr().bind("nullptr_expr"),
                       cxxNewExpr(),
                       expr(buff_address_from_container)
                           .bind("container_buff_address"),
                       buff_address_from_single_var))
                .bind("size_node")),
        reinterpret_cast_wrapper);
    auto rhs_expr =
        expr(ignoringParenCasts(anyOf(
                 declRefExpr(to(anyOf(rhs_var, rhs_param))).bind("declRefExpr"),
                 memberExpr(member(rhs_field)).bind("memberExpr"),
                 rhs_call_expr.bind("callExpr"))))
            .bind("rhs_expr");
    auto get_calls_on_raw_ptr = cxxMemberCallExpr(
        callee(cxxMethodDecl(hasName("get"), ofClass(hasName("raw_ptr")))),
        has(memberExpr(has(rhs_expr))));
    // `&ptr[i]` on an RHS variable, through a built-in subscript or a raw_ptr
    // operator[].
    auto index_into_pointer =
        unaryOperator(
            hasOperatorName("&"),
            hasUnaryOperand(anyOf(
                arraySubscriptExpr(
                    hasBase(declRefExpr(to(rhs_var)).bind("rhs_expr")),
                    optionally(hasIndex(integerLiteral(equals(0u))
                                            .bind("zero_container_offset"))))
                    .bind("subscript_expr"),
                cxxOperatorCallExpr(
                    callee(functionDecl(hasName("operator[]"))),
                    hasArgument(0, hasType(raw_ptr_type)),
                    hasDescendant(declRefExpr(to(rhs_var)).bind("rhs_expr")),
                    optionally(
                        hasDescendant(integerLiteral(equals(0u))
                                          .bind("zero_container_offset"))))
                    .bind("subscript_expr"))))
            .bind("unaryOperator");
    // RHS expressions that stay buffer-typed after rewriting: plain
    // references, `p + n` chains, `++p`, raw_ptr `++`, `raw_ptr::get()` and
    // `&p[i]`. The whole expression is bound as "span_frontier".
    auto rhs_exprs_without_size_nodes =
        expr(ignoringParenCasts(anyOf(
                 rhs_expr,
                 binaryOperation(
                     binary_plus_or_minus_operation(
                         binaryOperation(hasLHS(rhs_expr), hasOperatorName("+"),
                                         unless(isInExcludedMacroLocation()))),
                     hasRHS(expr(hasType(isInteger())).bind("binary_op_rhs")),
                     unless(hasParent(binaryOperation(
                         anyOf(hasOperatorName("+"), hasOperatorName("-"))))))
                     .bind("binaryOperator"),
                 unaryOperator(hasOperatorName("++"), hasUnaryOperand(rhs_expr))
                     .bind("unaryOperator"),
                 cxxOperatorCallExpr(
                     callee(cxxMethodDecl(ofClass(hasName("raw_ptr")))),
                     hasOperatorName("++"), hasArgument(0, rhs_expr))
                     .bind("raw_ptr_operator++"),
                 get_calls_on_raw_ptr,
                 expr(index_into_pointer).bind("container_buff_address"))),
             reinterpret_cast_wrapper)
            .bind("span_frontier");
    auto rhs_expr_variations = expr(ignoringParenCasts(
        anyOf(size_node_matcher, rhs_exprs_without_size_nodes)));
    auto lhs_expr_variations = expr(ignoringParenCasts(lhs_expr));

    // Unsafe buffer usage (`p[i]` outside safe subscripts, `p + n`, `p += n`,
    // `++p`, overloaded `[]` / `++`) marks the LHS as a source node.
    auto unsafe_buffer_access = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        expr(ignoringParenCasts(anyOf(
                 arraySubscriptExpr(hasLHS(lhs_expr_variations),
                                    unless(isSafeArraySubscript())),
                 binaryOperation(
                     anyOf(hasOperatorName("+="), hasOperatorName("+")),
                     hasLHS(lhs_expr_variations),
                     hasRHS(expr(hasType(isInteger())))),
                 unaryOperator(hasOperatorName("++"),
                               hasUnaryOperand(lhs_expr_variations)),
                 cxxOperatorCallExpr(anyOf(hasOverloadedOperatorName("[]"),
                                           hasOperatorName("++")),
                                     hasArgument(0, lhs_expr_variations)))))
            .bind("unsafe_buffer_access"));
    Match(unsafe_buffer_access, [](const auto& result) {
      EmitSource(GetLHS(result));
    });
    auto sizeof_array_expr = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        sizeOfExpr(has(rhs_exprs_without_size_nodes)).bind("sizeof_expr"));
    Match(sizeof_array_expr, RewriteArraySizeof);
    // Dereferences of a buffer: `*p` or overloaded `operator*`.
    auto deref_expression = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        expr(anyOf(unaryOperator(hasOperatorName("*"),
                                 hasUnaryOperand(rhs_exprs_without_size_nodes)),
                   cxxOperatorCallExpr(
                       hasOverloadedOperatorName("*"),
                       hasArgument(0, rhs_exprs_without_size_nodes))),
             unless(isInExcludedMacroLocation()))
            .bind("deref_expr"));
    Match(deref_expression, DecaySpanToPointer);
    // Pointer-to-bool conversions (and raw_ptr `operator bool`), optionally
    // under a logical `!`. Uses TK_AsIs so the implicit casts are visible.
    auto boolean_op_operand =
        traverse(clang::TK_IgnoreUnlessSpelledInSource,
                 expr(rhs_exprs_without_size_nodes).bind("boolean_op_operand"));
    auto raw_ptr_op_bool_call_expr =
        cxxMemberCallExpr(on(boolean_op_operand),
                          callee(cxxMethodDecl(hasName("operator bool"),
                                               ofClass(hasName("raw_ptr")))));
    auto boolean_op = traverse(
        clang::TK_AsIs,
        expr(anyOf(implicitCastExpr(
                       hasCastKind(clang::CastKind::CK_PointerToBoolean),
                       hasSourceExpression(boolean_op_operand)),
                   implicitCastExpr(has(raw_ptr_op_bool_call_expr))),
             optionally(hasParent(
                 unaryOperator(hasOperatorName("!")).bind("logical_not_op")))));
    Match(boolean_op, DecaySpanToBooleanOp);
    // `raw_ptr::get()` calls on a rewrite candidate: the `.get()` member call
    // is handed to EraseMemberCall.
    auto raw_ptr_get_call = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        cxxMemberCallExpr(
            callee(cxxMethodDecl(hasName("get"), ofClass(hasName("raw_ptr")))),
            has(memberExpr(has(rhs_expr)).bind("get_member_expr"))));
    Match(raw_ptr_get_call, [](const MatchFinder::MatchResult& result) {
      clang::SourceManager& source_manager = *result.SourceManager;
      EraseMemberCall(
          GetRHS(result),
          result.Nodes.getNodeAs<clang::MemberExpr>("get_member_expr"),
          source_manager);
    });
    // Buffers passed to functions or constructors in excluded locations
    // (except a few std:: helpers) are handed to AppendDataCall.
    auto buffer_to_external_func = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        expr(anyOf(
            callExpr(
                callee(functionDecl(
                    frontier_exclusions,
                    unless(matchesName(
                        "std::(size|c?r?begin|c?r?end|empty|swap|ranges::)")))),
                forEachArgumentWithParam(expr(rhs_exprs_without_size_nodes),
                                         parmVarDecl())),
            cxxConstructExpr(
                hasDeclaration(cxxConstructorDecl(frontier_exclusions)),
                forEachArgumentWithParam(expr(rhs_exprs_without_size_nodes),
                                         parmVarDecl())))));
    Match(buffer_to_external_func, AppendDataCall);
    // `++p` on a rewrite candidate.
    auto unary_op = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        expr(ignoringParenCasts(anyOf(
                 unaryOperator(hasOperatorName("++"), hasUnaryOperand(rhs_expr))
                     .bind("unaryOperator"),
                 cxxOperatorCallExpr(
                     callee(cxxMethodDecl(ofClass(hasName("raw_ptr")))),
                     hasOperatorName("++"), hasArgument(0, rhs_expr))
                     .bind("raw_ptr_operator++"))))
            .bind("unary_op"));
    Match(unary_op, RewriteUnaryOperation);
    // Outermost `p + n (+/- ...)` chain starting at a rewrite candidate.
    auto binary_op =
        traverse(clang::TK_IgnoreUnlessSpelledInSource,
                 expr(ignoringParenCasts(binaryOperation(
                     binary_plus_or_minus_operation(
                         binaryOperation(hasLHS(rhs_expr), hasOperatorName("+"),
                                         hasRHS(expr(hasType(isInteger()))))
                             .bind("binary_operation")),
                     hasRHS(expr().bind("binary_op_rhs")),
                     unless(hasParent(binaryOperation(anyOf(
                         hasOperatorName("+"), hasOperatorName("-")))))))));
    Match(binary_op, AdaptBinaryOperation);
    auto binary_plus_eq_op = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        expr(ignoringParenCasts(binaryOperation(
                 hasLHS(rhs_expr), hasOperatorName("+="),
                 hasRHS(expr(hasType(isInteger())).bind("binary_op_RHS")))))
            .bind("binary_plus_eq_op"));
    Match(binary_plus_eq_op, AdaptBinaryPlusEqOperation);
    // Assignments `lhs = rhs`. hasOperands accepts either operand order. The
    // true and false arms of a conditional RHS are matched separately.
    auto assignment_relationship = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        binaryOperation(hasOperatorName("="),
                        hasOperands(lhs_expr_variations,
                                    anyOf(rhs_expr_variations,
                                          conditionalOperator(hasTrueExpression(
                                              rhs_expr_variations)))),
                        unless(isExpansionInSystemHeader())));
    Match(assignment_relationship, MatchAdjacency);
    auto assignment_relationship2 = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        binaryOperation(hasOperatorName("="),
                        hasOperands(lhs_expr_variations,
                                    conditionalOperator(hasFalseExpression(
                                        rhs_expr_variations))),
                        unless(isExpansionInSystemHeader())));
    Match(assignment_relationship2, MatchAdjacency);
    // Variable initialization from an RHS (directly, via a conditional, or
    // via a constructor argument).
    auto var_construction = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        varDecl(
            lhs_var,
            has(expr(anyOf(
                rhs_expr_variations,
                conditionalOperator(hasTrueExpression(rhs_expr_variations)),
                cxxConstructExpr(has(expr(anyOf(
                    rhs_expr_variations, conditionalOperator(hasTrueExpression(
                                             rhs_expr_variations))))))))),
            unless(isExpansionInSystemHeader())));
    Match(var_construction, MatchAdjacency);
    auto var_construction2 = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        varDecl(
            lhs_var,
            has(expr(anyOf(
                conditionalOperator(hasFalseExpression(rhs_expr_variations)),
                cxxConstructExpr(has(expr(conditionalOperator(
                    hasFalseExpression(rhs_expr_variations)))))))),
            unless(isExpansionInSystemHeader())));
    Match(var_construction2, MatchAdjacency);
    // `lhs == std::begin(c_array)` / `!=`.
    auto equality_op =
        traverse(clang::TK_IgnoreUnlessSpelledInSource,
                 binaryOperation(
                     anyOf(hasOperatorName("=="), hasOperatorName("!=")),
                     hasOperands(ignoringParenCasts(lhs_expr_variations),
                                 ignoringParenCasts(c_array_iter_call_expr))));
    Match(equality_op, RewriteComparisonWithCArrayIter);
    // `return rhs;` from a function whose pointer return type is the LHS.
    auto returned_var_or_member = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        returnStmt(
            hasReturnValue(expr(anyOf(
                rhs_expr_variations,
                conditionalOperator(hasTrueExpression(rhs_expr_variations))))),
            unless(isExpansionInSystemHeader()),
            forFunction(functionDecl(
                hasReturnTypeLoc(pointer_type_loc.bind("lhs_type_loc")),
                unless(exclusions))))
            .bind("lhs_stmt"));
    Match(returned_var_or_member, MatchAdjacency);
    auto returned_var_or_member2 = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        returnStmt(hasReturnValue(conditionalOperator(
                       hasFalseExpression(rhs_expr_variations))),
                   unless(isExpansionInSystemHeader()),
                   forFunction(functionDecl(
                       hasReturnTypeLoc(pointer_type_loc.bind("lhs_type_loc")),
                       unless(exclusions))))
            .bind("lhs_stmt"));
    Match(returned_var_or_member2, MatchAdjacency);
    // Constructor member initializers for candidate fields.
    auto ctor_initializer = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        cxxCtorInitializer(withInitializer(anyOf(
                               cxxConstructExpr(has(expr(rhs_expr_variations))),
                               rhs_expr_variations)),
                           forField(lhs_field)));
    Match(ctor_initializer, MatchAdjacency);
    // Arguments passed to constructor parameters that are candidates.
    auto var_passed_in_constructor = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        cxxConstructExpr(forEachArgumentWithParam(
            expr(anyOf(
                rhs_expr_variations,
                conditionalOperator(hasTrueExpression(rhs_expr_variations)))),
            lhs_param)));
    Match(var_passed_in_constructor, MatchAdjacency);
    auto var_passed_in_constructor2 = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        cxxConstructExpr(forEachArgumentWithParam(
            expr(conditionalOperator(hasFalseExpression(rhs_expr_variations))),
            lhs_param)));
    Match(var_passed_in_constructor2, MatchAdjacency);
    // Default member initializers.
    auto field_init = fieldDecl(lhs_field, has(rhs_expr_variations),
                                unless(isExpansionInSystemHeader()));
    Match(field_init, MatchAdjacency);
    // Aggregate initialization of candidate fields.
    auto var_passed_in_initlistExpr = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        initListExpr(raw_ptr_plugin::forEachInitExprWithFieldDecl(
            expr(anyOf(
                rhs_expr_variations,
                conditionalOperator(hasTrueExpression(rhs_expr_variations)))),
            lhs_field)));
    Match(var_passed_in_initlistExpr, MatchAdjacency);
    auto var_passed_in_initlistExpr2 = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        initListExpr(raw_ptr_plugin::forEachInitExprWithFieldDecl(
            expr(conditionalOperator(hasFalseExpression(rhs_expr_variations))),
            lhs_field)));
    Match(var_passed_in_initlistExpr2, MatchAdjacency);
    // Function-call arguments bound to candidate parameters (operator= calls
    // are excluded).
    auto call_expr = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        callExpr(forEachArgumentWithParam(
                     expr(anyOf(rhs_expr_variations,
                                conditionalOperator(
                                    hasTrueExpression(rhs_expr_variations)))),
                     lhs_param),
                 unless(isExpansionInSystemHeader()),
                 unless(cxxOperatorCallExpr(hasOperatorName("=")))));
    Match(call_expr, MatchAdjacency);
    // Functions with a candidate parameter or pointer return type, and the
    // function-pointer variables they are stored into.
    auto fct_ptr_type = type(hasUnqualifiedDesugaredType(
        pointerType(pointee(ignoringParens(functionProtoType())))));
    auto fct_decl =
        functionDecl(
            eachOf(forEachParmVarDecl(rhs_param),
                   hasReturnTypeLoc(pointer_type_loc.bind("rhs_type_loc"))),
            unless(exclusions))
            .bind("fct_decl");
    auto fct_decl_expr = expr(ignoringParenCasts(declRefExpr(to(fct_decl))));
    auto fct_ptr_var_construction = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        varDecl(hasType(fct_ptr_type), has(fct_decl_expr), unless(exclusions))
            .bind("lhs_funcptrvardecl"));
    Match(fct_ptr_var_construction, RewriteFunctionPointerType);
    auto fct_ptr_var_assignment = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        binaryOperator(hasOperatorName("="),
                       hasLHS(declRefExpr(
                           to(varDecl(hasType(fct_ptr_type), unless(exclusions))
                                  .bind("lhs_funcptrvardecl")))),
                       hasRHS(fct_decl_expr)));
    Match(fct_ptr_var_assignment, RewriteFunctionPointerType);
    // Keep redeclarations and overrides of candidate functions in sync.
    auto fct_decls = traverse(clang::TK_IgnoreUnlessSpelledInSource, fct_decl);
    Match(fct_decls, RewriteFunctionParamAndReturnType);
  }

 private:
  // Adapts a std::function to the MatchFinder::MatchCallback interface.
  class MatchCallback : public MatchFinder::MatchCallback {
   public:
    explicit MatchCallback(
        std::function<void(const MatchFinder::MatchResult&)> callback)
        : callback_(std::move(callback)) {}
    void run(const MatchFinder::MatchResult& result) override {
      callback_(result);
    }

   private:
    std::function<void(const MatchFinder::MatchResult&)> callback_;
  };

  // Registers `matcher` on the MatchFinder so that `fn` runs for each match.
  // The callback object is owned here; the MatchFinder keeps a raw pointer.
  template <typename Matcher>
  void Match(const Matcher& matcher,
             std::function<void(const MatchFinder::MatchResult&)> fn) {
    auto match_callback = std::make_unique<MatchCallback>(std::move(fn));
    match_finder_.addMatcher(matcher, match_callback.get());
    match_callbacks_.push_back(std::move(match_callback));
  }

  // Declared before `match_finder_` and initialized before the constructor
  // body runs. The exclusion matchers reference it by address.
  raw_ptr_plugin::FilterFile paths_to_exclude_ = PathsToExclude();
  MatchFinder& match_finder_;
  std::vector<std::unique_ptr<MatchCallback>> match_callbacks_;
};
}
// Entry point of the spanifier tool.
//
// Parses the standard clang-tooling command line, registers all Spanifier
// matchers on a MatchFinder and runs them over the requested source files.
// Returns the result of ClangTool::run(), or 1 if the command line could not
// be parsed.
int main(int argc, const char* argv[]) {
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmParser();
  llvm::cl::OptionCategory category(
      "spanifier: changes"
      " 1- |T* var| to |base::span<T> var|."
      " 2- |raw_ptr<T> var| to |base::raw_span<T> var|");
  llvm::Expected<clang::tooling::CommonOptionsParser> options =
      clang::tooling::CommonOptionsParser::create(argc, argv, category);
  if (!options) {
    // An assert alone is compiled out under NDEBUG, after which `options->`
    // would access an Expected that holds an error. Report the error (which
    // also marks it as handled) and fail instead.
    llvm::errs() << llvm::toString(options.takeError()) << "\n";
    return 1;
  }
  clang::tooling::ClangTool tool(options->getCompilations(),
                                 options->getSourcePathList());
  // `rewriter` owns the callbacks registered on `match_finder`, so it must
  // stay alive until tool.run() returns.
  MatchFinder match_finder;
  Spanifier rewriter(match_finder);
  std::unique_ptr<clang::tooling::FrontendActionFactory> factory =
      clang::tooling::newFrontendActionFactory(&match_finder);
  int result = tool.run(factory.get());
  return result;
}