#include "CheckTraceVisitor.h"
#include <vector>
#include "Config.h"
#include "Edge.h"
#include "RecordInfo.h"
using namespace clang;
CheckTraceVisitor::CheckTraceVisitor(CXXMethodDecl* trace,
RecordInfo* info,
RecordCache* cache)
: trace_(trace), info_(info), cache_(cache) {}
// In weak-callback mode, any field accessed through a member expression is
// considered traced (subject to MarkTraced's weak-member filter). In normal
// trace-method mode this visitor ignores member expressions.
// Always returns true so traversal continues.
bool CheckTraceVisitor::VisitMemberExpr(MemberExpr* member) {
  if (!IsWeakCallback())
    return true;
  FieldDecl* accessed = dyn_cast<FieldDecl>(member->getMemberDecl());
  if (accessed != nullptr)
    FoundField(accessed, false);
  return true;
}
// Inspects every call inside the trace method body and records which bases
// and fields of info_ the call traces.
//
// Dispatch order matters: dependent-scope callees and qualified
// implicit-cast callees are handled first. All remaining checks only apply
// to calls with exactly one or two arguments. Always returns true so the
// RecursiveASTVisitor keeps traversing.
bool CheckTraceVisitor::VisitCallExpr(CallExpr* call) {
// Calls are not inspected in weak-callback mode; VisitMemberExpr handles
// that mode instead.
if (IsWeakCallback())
return true;
Expr* callee = call->getCallee();
// Qualified call whose qualifier depends on a template parameter.
if (DependentScopeDeclRefExpr* expr =
dyn_cast<DependentScopeDeclRefExpr>(callee)) {
CheckDependentScopeDeclRefExpr(call, expr);
return true;
}
// Resolved qualified call reached through an implicit cast of the callee;
// fall through if it is not the pattern CheckImplicitCastExpr recognizes.
if (ImplicitCastExpr* expr = dyn_cast<ImplicitCastExpr>(callee)) {
if (CheckImplicitCastExpr(call, expr))
return true;
}
// Every pattern below takes one argument (trace/base call) or two
// (e.g. ephemeron tracing).
if ((call->getNumArgs() != 1) && (call->getNumArgs() != 2)) {
return true;
}
Expr* arg = call->getArg(0);
// Unresolved member call (dependent context): try a base-class trace call
// first, then the RegisterWeakMembers shortcut, then a field-trace call on
// a pointer-typed object whose record is checked by CheckTraceFieldCall.
if (UnresolvedMemberExpr* expr = dyn_cast<UnresolvedMemberExpr>(callee)) {
if (CheckTraceBaseCall(call))
return true;
// Registering weak members counts as tracing every weak member; the
// function does not return here, so the name check below still runs.
if (expr->getMemberName().getAsString() == kRegisterWeakMembersName)
MarkAllWeakMembersTraced();
QualType base = expr->getBaseType();
if (!base->isPointerType())
return true;
CXXRecordDecl* decl = base->getPointeeType()->getAsCXXRecordDecl();
if (decl)
CheckTraceFieldCall(expr->getMemberName().getAsString(), decl, arg,
call->getNumArgs() > 1 ? call->getArg(1) : nullptr);
return true;
}
// Resolved member call: a field-trace call or a RegisterWeakMembers call.
if (CXXMemberCallExpr* expr = dyn_cast<CXXMemberCallExpr>(call)) {
if (CheckTraceFieldMemberCall(expr) || CheckRegisterWeakMembers(expr))
return true;
}
// Anything else may still be a qualified call to a base's trace method.
CheckTraceBaseCall(call);
return true;
}
// Returns true when |name| equals the name of the trace method being checked.
// Must not be called in weak-callback mode, where trace_ is null.
bool CheckTraceVisitor::IsTraceCallName(const std::string& name) {
  const auto checked_method_name = trace_->getName();
  return name == checked_method_name;
}
// Resolves the qualifier of a dependent-scope reference (the `X` in
// `X::name`) to the record of the template it names.
// Returns nullptr when the reference has no qualifier, the qualifier is not
// a type, or RecordInfo::GetDependentTemplatedDecl cannot resolve that type.
CXXRecordDecl* CheckTraceVisitor::GetDependentTemplatedDecl(
    DependentScopeDeclRefExpr* expr) {
  NestedNameSpecifier qual = expr->getQualifier();
  if (!qual)
    return nullptr;
  const Type* type = qual.getAsType();
  if (!type)
    return nullptr;
  return RecordInfo::GetDependentTemplatedDecl(*type);
}
namespace {

// AST visitor that records the first field referenced through a MemberExpr
// and then aborts the traversal.
class FindFieldVisitor : public RecursiveASTVisitor<FindFieldVisitor> {
 public:
  FindFieldVisitor() = default;

  // The field that was found, or nullptr if none was encountered.
  FieldDecl* field() const { return field_; }

  // Replaces the default MemberExpr traversal. Returning false stops the whole
  // traversal once a field is found. Because the base-class traversal is not
  // invoked, the children of a MemberExpr are never visited, whether or not
  // it names a field.
  bool TraverseMemberExpr(MemberExpr* member);

 private:
  FieldDecl* field_ = nullptr;
};

bool FindFieldVisitor::TraverseMemberExpr(MemberExpr* member) {
  FieldDecl* candidate = dyn_cast<FieldDecl>(member->getMemberDecl());
  if (candidate == nullptr)
    return true;
  field_ = candidate;
  return false;
}

}  // namespace
// Handles a call whose callee is a DependentScopeDeclRefExpr, i.e. a
// qualified name whose qualifier depends on a template parameter and so
// could not be resolved.
//
// Recognized patterns:
//  - Qualifier is a bare template type parameter: the parameter's index is
//    used as an index into info_'s base list, and that base is marked traced.
//    NOTE(review): this assumes template parameter order matches base order.
//    An out-of-range index returns early, which also skips the checks below.
//  - A one-argument call named like the trace method on a resolvable
//    dependent template: every base whose name equals the template's name is
//    marked traced.
//  - A two-argument call named kTraceName on a template named
//    kTraceIfNeededName: the first field found in the second argument is
//    marked traced-if-needed.
void CheckTraceVisitor::CheckDependentScopeDeclRefExpr(
CallExpr* call,
DependentScopeDeclRefExpr* expr) {
std::string fn_name = expr->getDeclName().getAsString();
if (NestedNameSpecifier qual = expr->getQualifier()) {
if (const Type* type = qual.getAsType()) {
if (const TemplateTypeParmType* tmpl_parm_type =
type->getAs<TemplateTypeParmType>()) {
const unsigned param_index = tmpl_parm_type->getIndex();
if (param_index >= info_->GetBases().size())
return;
info_->GetBases()[param_index].second.MarkTraced();
}
}
}
// The remaining patterns need the qualifier to resolve to a template record.
CXXRecordDecl* tmpl = GetDependentTemplatedDecl(expr);
if (!tmpl)
return;
// Dependent base trace call: mark bases matched by name, since the exact
// specialization is not known here.
if (call->getNumArgs() == 1 && IsTraceCallName(fn_name)) {
RecordInfo::Bases::iterator it = info_->GetBases().begin();
for (; it != info_->GetBases().end(); ++it) {
if (it->first->getName() == tmpl->getName())
it->second.MarkTraced();
}
}
// Trace call on the trace-if-needed helper; the traced field is in arg 1.
if (call->getNumArgs() == 2 && fn_name == kTraceName &&
tmpl->getName() == kTraceIfNeededName) {
FindFieldVisitor finder;
finder.TraverseStmt(call->getArg(1));
if (finder.field())
FoundField(finder.field(), true);
}
}
// Checks whether |call| is a call to a base class's trace method and, if so,
// marks the matching entry in info_'s base list as traced.
//
// Two callee shapes are recognized:
//  - A resolved, qualified MemberExpr whose member is a trace method (per
//    Config::IsTraceMethod). The record comes from the qualifier's type.
//  - An UnresolvedMemberExpr (dependent context) with at least one candidate
//    trace method, called with exactly one argument that is a DeclRefExpr
//    named kVisitorVarName. The record is that candidate's parent class.
// In both cases the called name must match the trace method being checked.
//
// Returns true if a base was marked traced, false otherwise.
bool CheckTraceVisitor::CheckTraceBaseCall(CallExpr* call) {
CXXRecordDecl* callee_record = nullptr;
std::string func_name;
if (MemberExpr* callee = dyn_cast<MemberExpr>(call->getCallee())) {
// Only qualified calls (Base::name(...)) are treated as base trace calls.
if (!callee->hasQualifier())
return false;
FunctionDecl* trace_decl =
dyn_cast<FunctionDecl>(callee->getMemberDecl());
if (!trace_decl || !Config::IsTraceMethod(trace_decl))
return false;
const Type* type = callee->getQualifier().getAsType();
if (!type)
return false;
callee_record = type->getAsCXXRecordDecl();
func_name = std::string(trace_decl->getName());
} else if (UnresolvedMemberExpr* callee =
dyn_cast<UnresolvedMemberExpr>(call->getCallee())) {
// Pick the first overload candidate that is a trace method.
CXXMethodDecl* trace_decl = nullptr;
for (NamedDecl* named_decl : callee->decls()) {
if (CXXMethodDecl* method_decl = dyn_cast<CXXMethodDecl>(named_decl)) {
if (Config::IsTraceMethod(method_decl)) {
trace_decl = method_decl;
break;
}
}
}
if (!trace_decl)
return false;
// Require the single argument to be the variable named kVisitorVarName.
if (call->getNumArgs() != 1)
return false;
DeclRefExpr* arg = dyn_cast<DeclRefExpr>(call->getArg(0));
if (!arg || arg->getNameInfo().getAsString() != kVisitorVarName)
return false;
callee_record = trace_decl->getParent();
func_name = callee->getMemberName().getAsString();
}
if (!callee_record)
return false;
if (!IsTraceCallName(func_name))
return false;
// Match callee_record against each direct base. A base that does not itself
// require a trace method (per its RecordInfo in cache_) is transparent, so
// its own bases are searched too. Calling a more distant ancestor's trace
// method therefore marks the direct base. The first match marks that
// direct base traced and returns.
for (auto& base : info_->GetBases()) {
std::vector<CXXRecordDecl*> base_records;
base_records.push_back(base.first);
while (!base_records.empty()) {
CXXRecordDecl* base_record = base_records.back();
base_records.pop_back();
if (base_record == callee_record) {
base.second.MarkTraced();
return true;
}
if (RecordInfo* base_info = cache_->Lookup(base_record)) {
if (!base_info->RequiresTraceMethod()) {
for (auto& inner_base : base_info->GetBases())
base_records.push_back(inner_base.first);
}
}
}
}
return false;
}
// Checks a resolved member call on a visitor object for a field-tracing
// pattern by forwarding its method name, object record, and first one or two
// arguments to CheckTraceFieldCall.
//
// Returns false, without inspecting anything, when the call has no
// resolvable method or object record. CXXMemberCallExpr::getMethodDecl() and
// getRecordDecl() both return null for calls through a pointer-to-member.
// Also returns false when the call has no arguments.
bool CheckTraceVisitor::CheckTraceFieldMemberCall(CXXMemberCallExpr* call) {
  CXXMethodDecl* method = call->getMethodDecl();
  CXXRecordDecl* record = call->getRecordDecl();
  if (!method || !record || call->getNumArgs() == 0)
    return false;
  return CheckTraceFieldCall(
      method->getNameAsString(), record, call->getArg(0),
      call->getNumArgs() > 1 ? call->getArg(1) : nullptr);
}
bool CheckTraceVisitor::CheckTraceFieldCall(const std::string& name,
CXXRecordDecl* callee,
Expr* arg1,
Expr* arg2) {
if (!Config::IsVisitor(callee->getName())) {
return false;
}
if (name == kTraceName || name == kTraceMultipleName) {
FindFieldVisitor finder;
finder.TraverseStmt(arg1);
if (finder.field()) {
FoundField(finder.field(), false);
}
return true;
}
if (name == kTraceEphemeronName) {
FindFieldVisitor finder1;
finder1.TraverseStmt(arg1);
if (finder1.field()) {
FoundField(finder1.field(), false);
}
assert(arg2);
FindFieldVisitor finder2;
finder2.TraverseStmt(arg2);
if (finder2.field()) {
FoundField(finder2.field(), false);
}
return true;
}
return false;
}
// Recognizes a resolved call to the kRegisterWeakMembersName method.
//
// Returns false when |call| is not such a call, so the caller can try other
// interpretations. Otherwise returns true. In that case, if the callee is an
// instantiated function-template specialization whose second template
// argument is a declaration of a function with a body, that body is scanned
// by a nested visitor in weak-callback mode (null trace method). Weak members
// it accesses are then marked traced.
bool CheckTraceVisitor::CheckRegisterWeakMembers(CXXMemberCallExpr* call) {
  CXXMethodDecl* fn = call->getMethodDecl();
  // getMethodDecl() is null for calls through a pointer-to-member.
  // NamedDecl::getName() requires a plain identifier (it asserts on operators
  // and conversion functions), so reject both cases before comparing names.
  if (!fn || !fn->getIdentifier() || fn->getName() != kRegisterWeakMembersName)
    return false;
  if (fn->isTemplateInstantiation()) {
    // getTemplateSpecializationArgs() is null unless fn is a function-template
    // specialization, e.g. for a plain member of an instantiated class
    // template. Such a function carries no callback template argument.
    if (const TemplateArgumentList* args =
            fn->getTemplateSpecializationArgs()) {
      if (args->size() > 1 &&
          args->get(1).getKind() == TemplateArgument::Declaration) {
        if (FunctionDecl* callback =
                dyn_cast<FunctionDecl>(args->get(1).getAsDecl())) {
          if (callback->hasBody()) {
            CheckTraceVisitor nested_visitor(nullptr, info_, nullptr);
            nested_visitor.TraverseStmt(callback->getBody());
          }
        }
      }
    }
  }
  return true;
}
// True when this visitor was created without a trace method, which is how
// weak-callback bodies are checked.
bool CheckTraceVisitor::IsWeakCallback() const {
  return trace_ == nullptr;
}
// Marks the field at |it| as traced. In weak-callback mode only weak members
// may be marked; other fields are left untouched.
void CheckTraceVisitor::MarkTraced(RecordInfo::Fields::iterator it) {
  const bool may_mark =
      !IsWeakCallback() || it->second.edge()->IsWeakMember();
  if (may_mark)
    it->second.MarkTraced();
}
// Marks the field at |it| as traced-if-needed. As with MarkTraced, a weak
// callback may only mark weak members.
void CheckTraceVisitor::MarkTracedIfNeeded(RecordInfo::Fields::iterator it) {
  const bool may_mark =
      !IsWeakCallback() || it->second.edge()->IsWeakMember();
  if (may_mark)
    it->second.MarkTracedIfNeeded();
}
namespace {

// Looks up |field| among |info|'s fields, returning the fields' end()
// iterator when it is absent. For template instantiations the FieldDecl seen
// in the trace body may be a different declaration of the same member than
// the one stored in |info|. The match is therefore done by name; otherwise
// the map is searched by FieldDecl pointer.
RecordInfo::Fields::iterator FindField(RecordInfo* info, FieldDecl* field) {
  if (!Config::IsTemplateInstantiation(info->record()))
    return info->GetFields().find(field);

  const std::string wanted_name = field->getNameAsString();
  RecordInfo::Fields::iterator cursor = info->GetFields().begin();
  while (cursor != info->GetFields().end() &&
         cursor->first->getNameAsString() != wanted_name) {
    ++cursor;
  }
  return cursor;
}

}  // namespace
// Records that |field| was traced by the code being checked. Fields that do
// not belong to info_ are ignored. |is_trace_if_needed| selects
// traced-if-needed marking instead of plain traced.
void CheckTraceVisitor::FoundField(FieldDecl* field, bool is_trace_if_needed) {
  RecordInfo::Fields::iterator match = FindField(info_, field);
  if (match == info_->GetFields().end())
    return;
  if (!is_trace_if_needed) {
    MarkTraced(match);
    return;
  }
  MarkTracedIfNeeded(match);
}
// Marks every weak-member field of info_ as traced. Used when a
// RegisterWeakMembers call is seen in a dependent context.
void CheckTraceVisitor::MarkAllWeakMembersTraced() {
  for (auto it = info_->GetFields().begin(); it != info_->GetFields().end();
       ++it) {
    auto& point = it->second;
    if (point.edge()->IsWeakMember())
      point.MarkTraced();
  }
}
// Recognizes a resolved, record-qualified call to kTraceName on a class named
// kTraceIfNeededName with two arguments, where the callee is an implicit cast
// of a DeclRefExpr. The first field found in the second argument is marked
// traced-if-needed.
// Returns true if the pattern matched, false otherwise.
bool CheckTraceVisitor::CheckImplicitCastExpr(CallExpr* call,
                                              ImplicitCastExpr* expr) {
  DeclRefExpr* callee_ref = dyn_cast<DeclRefExpr>(expr->getSubExpr());
  if (callee_ref == nullptr)
    return false;

  NestedNameSpecifier qualifier = callee_ref->getQualifier();
  if (!qualifier)
    return false;

  CXXRecordDecl* owner = qualifier.getAsRecordDecl();
  if (owner == nullptr)
    return false;

  const std::string callee_name =
      callee_ref->getFoundDecl()->getNameAsString();
  const bool is_trace_if_needed_call = call->getNumArgs() == 2 &&
                                       callee_name == kTraceName &&
                                       owner->getName() == kTraceIfNeededName;
  if (!is_trace_if_needed_call)
    return false;

  FindFieldVisitor finder;
  finder.TraverseStmt(call->getArg(1));
  if (FieldDecl* found = finder.field())
    FoundField(found, true);
  return true;
}
namespace {

// Returns the field that a range-based for loop iterates over when the
// loop's range variable is initialized directly from a member access to a
// field, e.g. `for (auto& x : field_)`. Returns nullptr otherwise.
FieldDecl* GetRangeField(CXXForRangeStmt* for_range_stmt) {
  // The range statement declares the loop's implicit range variable, whose
  // initializer is the range expression.
  DeclStmt* decl_stmt = for_range_stmt->getRangeStmt();
  if (!decl_stmt->isSingleDecl())
    return nullptr;
  VarDecl* var_decl = dyn_cast<VarDecl>(decl_stmt->getSingleDecl());
  if (!var_decl)
    return nullptr;
  // VarDecl::getInit() may return null, and dyn_cast<> must not be given a
  // null pointer.
  Expr* init = var_decl->getInit();
  if (!init)
    return nullptr;
  MemberExpr* member_expr = dyn_cast<MemberExpr>(init);
  if (!member_expr)
    return nullptr;
  return dyn_cast<FieldDecl>(member_expr->getMemberDecl());
}

}  // namespace
// Treats a range-based for loop over an array-like field of info_ as tracing
// that field. This covers array edges and std::array collections. Always
// returns true so traversal continues.
bool CheckTraceVisitor::VisitStmt(Stmt* stmt) {
  CXXForRangeStmt* loop = dyn_cast<CXXForRangeStmt>(stmt);
  if (loop == nullptr)
    return true;

  FieldDecl* iterated = GetRangeField(loop);
  if (iterated == nullptr)
    return true;

  RecordInfo::Fields::iterator entry = FindField(info_, iterated);
  if (entry == info_->GetFields().end())
    return true;

  Edge* edge = entry->second.edge();
  if (edge->IsArray())
    MarkTraced(entry);

  if (edge->IsCollection()) {
    Collection* as_collection = static_cast<Collection*>(edge);
    const bool is_std_array = as_collection->IsSTDCollection() &&
                              (as_collection->GetCollectionName() == "array");
    if (is_std_array)
      MarkTraced(entry);
  }
  return true;
}