diff --git a/clang/include/clang/Analysis/Analyses/UnsafeBufferUsageGadgets.def b/clang/include/clang/Analysis/Analyses/UnsafeBufferUsageGadgets.def index 3bf40f29b33c6..6af6ea2c2c338 100644 --- a/clang/include/clang/Analysis/Analyses/UnsafeBufferUsageGadgets.def +++ b/clang/include/clang/Analysis/Analyses/UnsafeBufferUsageGadgets.def @@ -24,6 +24,12 @@ #define WARNING_OPTIONAL_GADGET(name) WARNING_GADGET(name) #endif +/// A `WARNING_GADGET` subset, each of which corresponds to an unsafe +/// interaction with bounds-attributed constructs +#ifndef WARNING_BOUNDS_SAFETY_GADGET +#define WARNING_BOUNDS_SAFETY_GADGET(name) WARNING_GADGET(name) +#endif + /// Safe gadgets correspond to code patterns that aren't unsafe but need to be /// properly recognized in order to emit correct warnings and fixes over unsafe /// gadgets. @@ -39,8 +45,8 @@ WARNING_GADGET(UnsafeBufferUsageAttr) WARNING_GADGET(UnsafeBufferUsageCtorAttr) WARNING_GADGET(DataInvocation) // TO_UPSTREAM(BoundsSafety) ON -WARNING_GADGET(CountAttributedPointerArgument) -WARNING_GADGET(SinglePointerArgument) +WARNING_BOUNDS_SAFETY_GADGET(CountAttributedPointerArgument) +WARNING_BOUNDS_SAFETY_GADGET(SinglePointerArgument) // TO_UPSTREAM(BoundsSafety) OFF WARNING_OPTIONAL_GADGET(UnsafeLibcFunctionCall) WARNING_OPTIONAL_GADGET(SpanTwoParamConstructor) // Uses of `std::span(arg0, arg1)` @@ -58,4 +64,5 @@ FIXABLE_GADGET(PointerInit) #undef FIXABLE_GADGET #undef WARNING_GADGET #undef WARNING_OPTIONAL_GADGET +#undef WARNING_BOUNDS_SAFETY_GADGET #undef GADGET diff --git a/clang/lib/Analysis/UnsafeBufferUsage.cpp b/clang/lib/Analysis/UnsafeBufferUsage.cpp index 660a1300fe4f8..403eed83aef20 100644 --- a/clang/lib/Analysis/UnsafeBufferUsage.cpp +++ b/clang/lib/Analysis/UnsafeBufferUsage.cpp @@ -9,34 +9,36 @@ #include "clang/Analysis/Analyses/UnsafeBufferUsage.h" #include "clang/AST/APValue.h" #include "clang/AST/ASTContext.h" +#include "clang/AST/ASTTypeTraits.h" #include "clang/AST/Attr.h" #include "clang/AST/Decl.h" +#include "clang/AST/DeclCXX.h" #include "clang/AST/DynamicRecursiveASTVisitor.h" #include "clang/AST/Expr.h" +#include "clang/AST/ExprCXX.h" #include "clang/AST/FormatString.h" #include "clang/AST/OperationKinds.h" -#include "clang/AST/RecursiveASTVisitor.h" +#include "clang/AST/ParentMapContext.h" #include "clang/AST/Stmt.h" #include "clang/AST/StmtVisitor.h" #include "clang/AST/Type.h" -#include "clang/ASTMatchers/ASTMatchFinder.h" -#include "clang/ASTMatchers/ASTMatchers.h" #include "clang/Basic/SourceLocation.h" #include "clang/Lex/Lexer.h" #include "clang/Lex/Preprocessor.h" #include "llvm/ADT/APSInt.h" +#include "llvm/ADT/STLFunctionalExtras.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" #include -#include +#include #include -#include #include +#include using namespace llvm; using namespace clang; -using namespace ast_matchers; #ifndef NDEBUG namespace { @@ -73,7 +75,7 @@ static std::string getDREAncestorString(const DeclRefExpr *DRE, if (StParents.size() > 1) return "unavailable due to multiple parents"; - if (StParents.size() == 0) + if (StParents.empty()) break; St = StParents.begin()->get(); if (St) @@ -81,10 +83,47 @@ static std::string getDREAncestorString(const DeclRefExpr *DRE, } while (St); return SS.str(); } + } // namespace #endif /* NDEBUG */ -namespace clang::ast_matchers { +namespace { +// Using a custom `FastMatcher` instead of ASTMatchers to achieve better +// performance. FastMatcher uses simple function `matches` to find if a node +// is a match, avoiding the dependency on the ASTMatchers framework which +// provide a nice abstraction, but incur big performance costs. +class FastMatcher { +public: + virtual bool matches(const DynTypedNode &DynNode, ASTContext &Ctx, + const UnsafeBufferUsageHandler &Handler) = 0; + virtual ~FastMatcher() = default; +}; + +class MatchResult { + +public: + MatchResult() = default; + + MatchResult(StringRef ID, const DynTypedNode &Node) { Nodes[ID] = Node; } + + template const T *getNodeAs(StringRef ID) const { + auto It = Nodes.find(ID); + if (It == Nodes.end()) { + return nullptr; + } + return It->second.get(); + } + + void addNode(StringRef ID, const DynTypedNode &Node) { Nodes[ID] = Node; } + +private: + llvm::StringMap Nodes; +}; + +using MatchResults = SmallVector; + +} // namespace + // A `RecursiveASTVisitor` that traverses all descendants of a given node "n" // except for those belonging to a different callable of "n". class MatchDescendantVisitor : public DynamicRecursiveASTVisitor { @@ -92,13 +131,12 @@ class MatchDescendantVisitor : public DynamicRecursiveASTVisitor { // Creates an AST visitor that matches `Matcher` on all // descendants of a given node "n" except for the ones // belonging to a different callable of "n". - MatchDescendantVisitor(const internal::DynTypedMatcher *Matcher, - internal::ASTMatchFinder *Finder, - internal::BoundNodesTreeBuilder *Builder, - internal::ASTMatchFinder::BindKind Bind, - const bool ignoreUnevaluatedContext) - : Matcher(Matcher), Finder(Finder), Builder(Builder), Bind(Bind), - Matches(false), ignoreUnevaluatedContext(ignoreUnevaluatedContext) { + MatchDescendantVisitor(ASTContext &Context, FastMatcher &Matcher, + bool FindAll, bool ignoreUnevaluatedContext, + const UnsafeBufferUsageHandler &NewHandler) + : Matcher(&Matcher), FindAll(FindAll), Matches(false), + ignoreUnevaluatedContext(ignoreUnevaluatedContext), + ActiveASTContext(&Context), Handler(&NewHandler) { ShouldVisitTemplateInstantiations = true; ShouldVisitImplicitCode = false; // TODO: let's ignore implicit code for now } @@ -109,7 +147,6 @@ class MatchDescendantVisitor : public DynamicRecursiveASTVisitor { Matches = false; if (const Stmt *StmtNode = DynNode.get()) { TraverseStmt(const_cast(StmtNode)); - *Builder = ResultBindings; return Matches; } return false; @@ -197,108 +234,174 @@ class MatchDescendantVisitor : public DynamicRecursiveASTVisitor { // Returns 'true' if traversal should continue after this function // returns, i.e. if no match is found or 'Bind' is 'BK_All'. template bool match(const T &Node) { - internal::BoundNodesTreeBuilder RecursiveBuilder(*Builder); - - if (Matcher->matches(DynTypedNode::create(Node), Finder, - &RecursiveBuilder)) { - ResultBindings.addMatch(RecursiveBuilder); + if (Matcher->matches(DynTypedNode::create(Node), *ActiveASTContext, + *Handler)) { Matches = true; - if (Bind != internal::ASTMatchFinder::BK_All) + if (!FindAll) return false; // Abort as soon as a match is found. } return true; } - const internal::DynTypedMatcher *const Matcher; - internal::ASTMatchFinder *const Finder; - internal::BoundNodesTreeBuilder *const Builder; - internal::BoundNodesTreeBuilder ResultBindings; - const internal::ASTMatchFinder::BindKind Bind; + FastMatcher *const Matcher; + // When true, finds all matches. When false, finds the first match and stops. + const bool FindAll; bool Matches; bool ignoreUnevaluatedContext; + ASTContext *ActiveASTContext; + const UnsafeBufferUsageHandler *Handler; }; // Because we're dealing with raw pointers, let's define what we mean by that. -static auto hasPointerType() { - return hasType(hasCanonicalType(pointerType())); +static bool hasPointerType(const Expr &E) { + return isa(E.getType().getCanonicalType()); } -static auto hasArrayType() { return hasType(hasCanonicalType(arrayType())); } - -AST_MATCHER(QualType, isCountAttributedType) { - return Node->isCountAttributedType(); +static bool isSinglePointerType(QualType Ty) { + return Ty->isSinglePointerType(); } -AST_MATCHER(QualType, isSinglePointerType) { - return Node->isSinglePointerType(); +static bool hasArrayType(const Expr &E) { + return isa(E.getType().getCanonicalType()); } -AST_MATCHER_P(Stmt, forEachDescendantEvaluatedStmt, internal::Matcher, - innerMatcher) { - const DynTypedMatcher &DTM = static_cast(innerMatcher); - - MatchDescendantVisitor Visitor(&DTM, Finder, Builder, ASTMatchFinder::BK_All, - true); - return Visitor.findMatch(DynTypedNode::create(Node)); +static void +forEachDescendantEvaluatedStmt(const Stmt *S, ASTContext &Ctx, + const UnsafeBufferUsageHandler &Handler, + FastMatcher &Matcher) { + MatchDescendantVisitor Visitor(Ctx, Matcher, /*FindAll=*/true, + /*ignoreUnevaluatedContext=*/true, Handler); + Visitor.findMatch(DynTypedNode::create(*S)); } -AST_MATCHER_P(Stmt, forEachDescendantStmt, internal::Matcher, - innerMatcher) { - const DynTypedMatcher &DTM = static_cast(innerMatcher); - - MatchDescendantVisitor Visitor(&DTM, Finder, Builder, ASTMatchFinder::BK_All, - false); - return Visitor.findMatch(DynTypedNode::create(Node)); +static void forEachDescendantStmt(const Stmt *S, ASTContext &Ctx, + const UnsafeBufferUsageHandler &Handler, + FastMatcher &Matcher) { + MatchDescendantVisitor Visitor(Ctx, Matcher, /*FindAll=*/true, + /*ignoreUnevaluatedContext=*/false, Handler); + Visitor.findMatch(DynTypedNode::create(*S)); } // Matches a `Stmt` node iff the node is in a safe-buffer opt-out region -AST_MATCHER_P(Stmt, notInSafeBufferOptOut, const UnsafeBufferUsageHandler *, - Handler) { +static bool notInSafeBufferOptOut(const Stmt &Node, + const UnsafeBufferUsageHandler *Handler) { return !Handler->isSafeBufferOptOut(Node.getBeginLoc()); } -AST_MATCHER_P(Stmt, ignoreUnsafeBufferInContainer, - const UnsafeBufferUsageHandler *, Handler) { +static bool +ignoreUnsafeBufferInContainer(const Stmt &Node, + const UnsafeBufferUsageHandler *Handler) { return Handler->ignoreUnsafeBufferInContainer(Node.getBeginLoc()); } -AST_MATCHER_P(Stmt, ignoreUnsafeLibcCall, const UnsafeBufferUsageHandler *, - Handler) { - if (Finder->getASTContext().getLangOpts().CPlusPlus) - return Handler->ignoreUnsafeBufferInLibcCall(Node.getBeginLoc()); - return true; /* Only warn about libc calls for C++ */ +// Finds any expression 'e' such that `OnResult` +// matches 'e' and 'e' is in an Unspecified Lvalue Context. +static void findStmtsInUnspecifiedLvalueContext( + const Stmt *S, const llvm::function_ref OnResult) { + if (const auto *CE = dyn_cast(S); + CE && CE->getCastKind() == CastKind::CK_LValueToRValue) + OnResult(CE->getSubExpr()); + if (const auto *BO = dyn_cast(S); + BO && BO->getOpcode() == BO_Assign) + OnResult(BO->getLHS()); } -AST_MATCHER_P(CastExpr, castSubExpr, internal::Matcher, innerMatcher) { - return innerMatcher.matches(*Node.getSubExpr(), Finder, Builder); -} +/// Note: Copied and modified from ASTMatchers. +/// Matches all arguments and their respective types for a \c CallExpr. +/// It is very similar to \c forEachArgumentWithParam but +/// it works on calls through function pointers as well. +/// +/// The difference is, that function pointers do not provide access to a +/// \c ParmVarDecl, but only the \c QualType for each argument. +/// +/// Given +/// \code +/// void f(int i); +/// int y; +/// f(y); +/// void (*f_ptr)(int) = f; +/// f_ptr(y); +/// \endcode +/// callExpr( +/// forEachArgumentWithParamType( +/// declRefExpr(to(varDecl(hasName("y")))), +/// qualType(isInteger()).bind("type) +/// )) +/// matches f(y) and f_ptr(y) +/// with declRefExpr(...) +/// matching int y +/// and qualType(...) +/// matching int +static void forEachArgumentWithParamType( + const CallExpr &Node, + const llvm::function_ref + OnParamAndArg) { + // The first argument of an overloaded member operator is the implicit object + // argument of the method which should not be matched against a parameter, so + // we skip over it here. + unsigned ArgIndex = 0; + if (const auto *CE = dyn_cast(&Node)) { + const auto *MD = dyn_cast_or_null(CE->getDirectCallee()); + if (MD && !MD->isExplicitObjectMemberFunction()) { + // This is an overloaded operator call. + // We need to skip the first argument, which is the implicit object + // argument of the method which should not be matched against a + // parameter. + ++ArgIndex; + } + } -// Matches a `UnaryOperator` whose operator is pre-increment: -AST_MATCHER(UnaryOperator, isPreInc) { - return Node.getOpcode() == UnaryOperator::Opcode::UO_PreInc; -} + const FunctionProtoType *FProto = nullptr; -// Returns a matcher that matches any expression 'e' such that `innerMatcher` -// matches 'e' and 'e' is in an Unspecified Lvalue Context. -static auto isInUnspecifiedLvalueContext(internal::Matcher innerMatcher) { - // clang-format off - return - expr(anyOf( - implicitCastExpr( - hasCastKind(CastKind::CK_LValueToRValue), - castSubExpr(innerMatcher)), - binaryOperator( - hasAnyOperatorName("="), - hasLHS(innerMatcher) - ) - )); - // clang-format on + if (const auto *Call = dyn_cast(&Node)) { + if (const auto *Value = + dyn_cast_or_null(Call->getCalleeDecl())) { + QualType QT = Value->getType().getCanonicalType(); + + // This does not necessarily lead to a `FunctionProtoType`, + // e.g. K&R functions do not have a function prototype. + if (QT->isFunctionPointerType()) + FProto = QT->getPointeeType()->getAs(); + + if (QT->isMemberFunctionPointerType()) { + const auto *MP = QT->getAs(); + assert(MP && "Must be member-pointer if its a memberfunctionpointer"); + FProto = MP->getPointeeType()->getAs(); + assert(FProto && + "The call must have happened through a member function " + "pointer"); + } + } + } + + unsigned ParamIndex = 0; + unsigned NumArgs = Node.getNumArgs(); + if (FProto && FProto->isVariadic()) + NumArgs = std::min(NumArgs, FProto->getNumParams()); + + const auto GetParamType = + [&FProto, &Node](unsigned int ParamIndex) -> std::optional { + if (FProto && FProto->getNumParams() > ParamIndex) { + return FProto->getParamType(ParamIndex); + } + const auto *FD = Node.getDirectCallee(); + if (FD && FD->getNumParams() > ParamIndex) { + return FD->getParamDecl(ParamIndex)->getType(); + } + return std::nullopt; + }; + + for (; ArgIndex < NumArgs; ++ArgIndex, ++ParamIndex) { + auto ParamType = GetParamType(ParamIndex); + if (ParamType) + OnParamAndArg(*ParamType, Node.getArg(ArgIndex)->IgnoreParenCasts()); + } } -// Returns a matcher that matches any expression `e` such that `InnerMatcher` -// matches `e` and `e` is in an Unspecified Pointer Context (UPC). -static internal::Matcher -isInUnspecifiedPointerContext(internal::Matcher InnerMatcher) { +// Finds any expression `e` such that `InnerMatcher` matches `e` and +// `e` is in an Unspecified Pointer Context (UPC). +static void findStmtsInUnspecifiedPointerContext( + const Stmt *S, llvm::function_ref InnerMatcher) { // A UPC can be // 1. an argument of a function call (except the callee has [[unsafe_...]] // attribute), or @@ -307,45 +410,58 @@ isInUnspecifiedPointerContext(internal::Matcher InnerMatcher) { // 4. the operand of a pointer subtraction operation // (i.e., computing the distance between two pointers); or ... - // clang-format off - auto CallArgMatcher = callExpr( + if (auto *CE = dyn_cast(S)) { + if (const auto *FnDecl = CE->getDirectCallee(); + FnDecl && FnDecl->hasAttr()) + return; forEachArgumentWithParamType( - InnerMatcher, - isAnyPointer() /* array also decays to pointer type*/), - unless(callee( - functionDecl(hasAttr(attr::UnsafeBufferUsage))))); - - auto CastOperandMatcher = - castExpr(anyOf(hasCastKind(CastKind::CK_PointerToIntegral), - hasCastKind(CastKind::CK_PointerToBoolean)), - castSubExpr(allOf(hasPointerType(), InnerMatcher))); - - auto CompOperandMatcher = - binaryOperator(hasAnyOperatorName("!=", "==", "<", "<=", ">", ">="), - eachOf(hasLHS(allOf(hasPointerType(), InnerMatcher)), - hasRHS(allOf(hasPointerType(), InnerMatcher)))); - - // A matcher that matches pointer subtractions: - auto PtrSubtractionMatcher = - binaryOperator(hasOperatorName("-"), - // Note that here we need both LHS and RHS to be - // pointer. Then the inner matcher can match any of - // them: - allOf(hasLHS(hasPointerType()), - hasRHS(hasPointerType())), - eachOf(hasLHS(InnerMatcher), - hasRHS(InnerMatcher))); - // clang-format on - - return stmt(anyOf(CallArgMatcher, CastOperandMatcher, CompOperandMatcher, - PtrSubtractionMatcher)); - // FIXME: any more cases? (UPC excludes the RHS of an assignment. For now we - // don't have to check that.) + *CE, [&InnerMatcher](QualType Type, const Expr *Arg) { + if (Type->isAnyPointerType()) + InnerMatcher(Arg); + }); + } + + if (auto *CE = dyn_cast(S)) { + if (CE->getCastKind() != CastKind::CK_PointerToIntegral && + CE->getCastKind() != CastKind::CK_PointerToBoolean) + return; + if (!hasPointerType(*CE->getSubExpr())) + return; + InnerMatcher(CE->getSubExpr()); + } + + // Pointer comparison operator. + if (const auto *BO = dyn_cast(S); + BO && (BO->getOpcode() == BO_EQ || BO->getOpcode() == BO_NE || + BO->getOpcode() == BO_LT || BO->getOpcode() == BO_LE || + BO->getOpcode() == BO_GT || BO->getOpcode() == BO_GE)) { + auto *LHS = BO->getLHS(); + if (hasPointerType(*LHS)) + InnerMatcher(LHS); + + auto *RHS = BO->getRHS(); + if (hasPointerType(*RHS)) + InnerMatcher(RHS); + } + + // Pointer subtractions. + if (const auto *BO = dyn_cast(S); + BO && BO->getOpcode() == BO_Sub && hasPointerType(*BO->getLHS()) && + hasPointerType(*BO->getRHS())) { + // Note that here we need both LHS and RHS to be + // pointer. Then the inner matcher can match any of + // them: + InnerMatcher(BO->getLHS()); + InnerMatcher(BO->getRHS()); + } + // FIXME: any more cases? (UPC excludes the RHS of an assignment. For now + // we don't have to check that.) } -// Returns a matcher that matches any expression 'e' such that `innerMatcher` -// matches 'e' and 'e' is in an unspecified untyped context (i.e the expression -// 'e' isn't evaluated to an RValue). For example, consider the following code: +// Finds statements in unspecified untyped context i.e. any expression 'e' such +// that `InnerMatcher` matches 'e' and 'e' is in an unspecified untyped context +// (i.e the expression 'e' isn't evaluated to an RValue). For example, consider +// the following code: // int *p = new int[4]; // int *q = new int[4]; // if ((p = q)) {} @@ -353,17 +469,23 @@ isInUnspecifiedPointerContext(internal::Matcher InnerMatcher) { // The expression `p = q` in the conditional of the `if` statement // `if ((p = q))` is evaluated as an RValue, whereas the expression `p = q;` // in the assignment statement is in an untyped context. -static internal::Matcher -isInUnspecifiedUntypedContext(internal::Matcher InnerMatcher) { +static void findStmtsInUnspecifiedUntypedContext( + const Stmt *S, llvm::function_ref InnerMatcher) { // An unspecified context can be // 1. A compound statement, // 2. The body of an if statement // 3. Body of a loop - auto CompStmt = compoundStmt(forEach(InnerMatcher)); - auto IfStmtThen = ifStmt(hasThen(InnerMatcher)); - auto IfStmtElse = ifStmt(hasElse(InnerMatcher)); + if (auto *CS = dyn_cast(S)) { + for (auto *Child : CS->body()) + InnerMatcher(Child); + } + if (auto *IfS = dyn_cast(S)) { + if (IfS->getThen()) + InnerMatcher(IfS->getThen()); + if (IfS->getElse()) + InnerMatcher(IfS->getElse()); + } // FIXME: Handle loop bodies. - return stmt(anyOf(CompStmt, IfStmtThen, IfStmtElse)); } namespace { @@ -695,46 +817,67 @@ bool isCompatibleWithCountExpr(const Expr *E, const Expr *ExpectedCountExpr, return Visitor.Visit(ExpectedCountExpr, E, /* hasBeenSubstituted*/ false); } +// Returns true iff `C` is a C++ nclass method call to the function +// `'ClassName'::'MethodName'` or `'ClassName'::Operator'MethodName'`: +static bool matchCXXMethodByName(const CallExpr *C, StringRef ClassName, + StringRef MethodName) { + const Decl *Callee = C->getCalleeDecl(); + + if (const auto *MethodDecl = dyn_cast(Callee)) { + if (!MethodDecl->getDeclName().isIdentifier() && + !MethodDecl->isOverloadedOperator()) + return false; + + if (MethodDecl->getDeclName().isIdentifier() && + MethodDecl->getName() != MethodName) + return false; + + if (MethodDecl->isOverloadedOperator()) { + StringRef Spelling = + getOperatorSpelling(MethodDecl->getOverloadedOperator()); + + if (Spelling != MethodName) + return false; + } + + if (const auto *RD = dyn_cast(MethodDecl->getParent())) { + if (!RD->getDeclName().isIdentifier()) + return false; + return RD->getQualifiedNameAsString() == ClassName; + } + } + return false; +} + // Returns if a pair of expressions contain method calls to .data()/.c_str() and // .size()/.size_bytes()/.length() that form a valid range. -bool isValidContainerRange(ASTContext &Context, const Expr *Data, - const Expr *Size, bool ArgInBytes, - bool ParamInBytes) { - auto MethodMatcher = [](StringRef ClassName, StringRef MethodName) { - return callee( - cxxMethodDecl(hasName(MethodName), ofClass(hasName(ClassName)))); - }; - - const auto *DataCall = selectFirst( - "e", - match(expr(ignoringParenImpCasts( - cxxMemberCallExpr( - anyOf(MethodMatcher("::std::array", "data"), - MethodMatcher("::std::basic_string", "c_str"), - MethodMatcher("::std::basic_string", "data"), - MethodMatcher("::std::basic_string_view", "data"), - MethodMatcher("::std::span", "data"), - MethodMatcher("::std::vector", "data"))) - .bind("e"))), - *Data, Context)); - if (!DataCall) +static bool isValidContainerRange(ASTContext &Context, const Expr *Data, + const Expr *Size, bool ArgInBytes, + bool ParamInBytes) { + const auto *DataCall = + dyn_cast(Data->IgnoreParenImpCasts()); + + if (!(DataCall && + (matchCXXMethodByName(DataCall, "std::array", "data") || + matchCXXMethodByName(DataCall, "std::basic_string", "c_str") || + matchCXXMethodByName(DataCall, "std::basic_string", "data") || + matchCXXMethodByName(DataCall, "std::basic_string_view", "data") || + matchCXXMethodByName(DataCall, "std::span", "data") || + matchCXXMethodByName(DataCall, "std::vector", "data")))) return false; - const auto *SizeCall = selectFirst( - "e", - match(expr(ignoringParenImpCasts( - cxxMemberCallExpr( - anyOf(MethodMatcher("::std::array", "size"), - MethodMatcher("::std::basic_string", "length"), - MethodMatcher("::std::basic_string", "size"), - MethodMatcher("::std::basic_string_view", "length"), - MethodMatcher("::std::basic_string_view", "size"), - MethodMatcher("::std::span", "size"), - MethodMatcher("::std::span", "size_bytes"), - MethodMatcher("::std::vector", "size"))) - .bind("e"))), - *Size, Context)); - if (!SizeCall) + const auto *SizeCall = + dyn_cast(Size->IgnoreParenImpCasts()); + + if (!(SizeCall && + (matchCXXMethodByName(SizeCall, "std::array", "size") || + matchCXXMethodByName(SizeCall, "std::basic_string", "length") || + matchCXXMethodByName(SizeCall, "std::basic_string", "size") || + matchCXXMethodByName(SizeCall, "std::basic_string_view", "length") || + matchCXXMethodByName(SizeCall, "std::basic_string_view", "size") || + matchCXXMethodByName(SizeCall, "std::span", "size") || + matchCXXMethodByName(SizeCall, "std::span", "size_bytes") || + matchCXXMethodByName(SizeCall, "std::vector", "size")))) return false; const Expr *DataObj = DataCall->getImplicitObjectArgument(); @@ -759,20 +902,28 @@ bool isValidContainerRange(ASTContext &Context, const Expr *Data, // Extract the extent `X` from `sp.first(X).data()` and friends. const Expr *extractExtentFromSubviewDataCall(ASTContext &Context, const Expr *E) { - auto ExtentMatcher = [](StringRef Name, unsigned N) { - return cxxMemberCallExpr( - callee(cxxMethodDecl(hasName(Name), ofClass(hasName("::std::span")))), - hasArgument(N, expr().bind("extent"))); + auto ExtentMatcher = [](const CXXMemberCallExpr *MCE, StringRef Name, + unsigned N) -> const Expr * { + if (matchCXXMethodByName(MCE, "std::span", Name) && MCE->getNumArgs() > N) + return MCE->getArg(N); + return nullptr; }; - auto SpanSubviewMatcher = - anyOf(ExtentMatcher("first", 0), ExtentMatcher("last", 0), - ExtentMatcher("subspan", 1)); - auto SpanDataMatcher = cxxMemberCallExpr( - callee(cxxMethodDecl(hasName("data"), ofClass(hasName("::std::span")))), - on(SpanSubviewMatcher)); - return selectFirst( - "extent", - match(expr(ignoringParenImpCasts(SpanDataMatcher)), *E, Context)); + + if (const auto *MCE = dyn_cast(E->IgnoreParenImpCasts())) { + if (!matchCXXMethodByName(MCE, "std::span", "data")) + return nullptr; + if (const auto *DataObj = MCE->getImplicitObjectArgument()) + if (const auto *DataObjMCE = + dyn_cast(DataObj->IgnoreParenImpCasts())) { + if (const auto *Extent = ExtentMatcher(DataObjMCE, "first", 0)) + return Extent; + if (const auto *Extent = ExtentMatcher(DataObjMCE, "last", 0)) + return Extent; + if (const auto *Extent = ExtentMatcher(DataObjMCE, "subspan", 1)) + return Extent; + } + } + return nullptr; } // Returns true iff `E` evaluates to `Val`. @@ -787,6 +938,16 @@ static bool hasIntegeralConstant(const Expr *E, uint64_t Val, ASTContext &Ctx) { return false; } +// Return `DRE` if `E` matches the form `&DRE`: +static const DeclRefExpr *tryGetAddressofDRE(const Expr *E) { + if (const auto *UO = dyn_cast(E->IgnoreParenImpCasts())) { + if (UO->getOpcode() != UnaryOperator::Opcode::UO_AddrOf) + return nullptr; + return dyn_cast(UO->getSubExpr()->IgnoreParenImpCasts()); + } + return nullptr; +} + // Checks if the argument passed to count-attributed pointer is one of the // following forms: // 0. `NULL/nullptr`, if the argument to dependent count/size is `0`. @@ -857,12 +1018,7 @@ static bool isCountAttributedPointerArgumentSafeImpl( } // check form 1-2: - auto AddressofDRE = expr(unaryOperator( - hasOperatorName("&"), - hasUnaryOperand(ignoringParenImpCasts(declRefExpr().bind("VarIdent"))))); - - if (auto *DRE = selectFirst( - "VarIdent", match(AddressofDRE, *PtrArgNoImp, Context))) { + if (auto *DRE = tryGetAddressofDRE(PtrArgNoImp)) { if (CountArg) { if (!isSizedBy) // form 1.a.: return hasIntegeralConstant(CountArg, 1, Context); @@ -996,29 +1152,25 @@ bool isSinglePointerArgumentSafe(ASTContext &Context, const Expr *Arg) { // Check form 1: { - auto AddrOfDREMatcher = expr( - unaryOperator(hasOperatorName("&"), - hasUnaryOperand(ignoringParenImpCasts(declRefExpr())))); - bool Matches = !match(AddrOfDREMatcher, *ArgNoImp, Context).empty(); - if (Matches) + if (tryGetAddressofDRE(ArgNoImp)) return true; } // Check form 2: { - // TODO: Add more classes. - auto HardenedClassNameMatcher = - anyOf(hasName("::std::array"), hasName("::std::basic_string"), - hasName("::std::basic_string_view"), hasName("::std::span"), - hasName("::std::vector")); - auto SubscriptOpMatcher = cxxOperatorCallExpr(callee(cxxMethodDecl( - hasName("operator[]"), ofClass(HardenedClassNameMatcher)))); - auto AddrOfMatcher = expr(unaryOperator( - hasOperatorName("&"), - hasUnaryOperand(ignoringParenImpCasts(SubscriptOpMatcher)))); - bool Matches = !match(AddrOfMatcher, *ArgNoImp, Context).empty(); - if (Matches) - return true; + if (const auto *UO = dyn_cast(ArgNoImp)) + if (UO->getOpcode() == UnaryOperator::Opcode::UO_AddrOf) { + const Expr *Operand = UO->getSubExpr()->IgnoreParenImpCasts(); + + if (const auto *OPCall = dyn_cast(Operand)) + // TODO: Add more classes. + if (matchCXXMethodByName(OPCall, "std::array", "[]") || + matchCXXMethodByName(OPCall, "std::basic_string", "[]") || + matchCXXMethodByName(OPCall, "std::basic_string_view", "[]") || + matchCXXMethodByName(OPCall, "std::span", "[]") || + matchCXXMethodByName(OPCall, "std::vector", "[]")) + return true; + } } // Check form 3: @@ -1074,8 +1226,7 @@ static bool areEqualIntegers(const Expr *E1, const Expr *E2, ASTContext &Ctx) { Expr::EvalResult ER1, ER2; // If both are constants: - if (E1->EvaluateAsInt(ER1, Ctx) && - E2->EvaluateAsInt(ER2, Ctx)) + if (E1->EvaluateAsInt(ER1, Ctx) && E2->EvaluateAsInt(ER2, Ctx)) return ER1.Val.getInt() == ER2.Val.getInt(); // Otherwise, they should have identical stmt kind: @@ -1116,14 +1267,13 @@ static bool areEqualIntegers(const Expr *E1, const Expr *E2, ASTContext &Ctx) { // pointer OR `std::span{(char*)p, n}`, where `p` is a // __sized_by(`n`) pointer. (This pattern is not in upstream, so try it // last to avoid possible conflicts.) -// // TO_UPSTREAM(BoundsSafetyInterop) OFF -AST_MATCHER(CXXConstructExpr, isSafeSpanTwoParamConstruct) { +// TO_UPSTREAM(BoundsSafetyInterop) OFF +static bool isSafeSpanTwoParamConstruct(const CXXConstructExpr &Node, + ASTContext &Ctx) { assert(Node.getNumArgs() == 2 && "expecting a two-parameter std::span constructor"); const Expr *Arg0 = Node.getArg(0)->IgnoreParenImpCasts(); const Expr *Arg1 = Node.getArg(1)->IgnoreParenImpCasts(); - ASTContext &Ctx = Finder->getASTContext(); - auto HaveEqualConstantValues = [&Ctx](const Expr *E0, const Expr *E1) { if (auto E0CV = E0->getIntegerConstantExpr(Ctx)) if (auto E1CV = E1->getIntegerConstantExpr(Ctx)) { @@ -1237,8 +1387,8 @@ AST_MATCHER(CXXConstructExpr, isSafeSpanTwoParamConstruct) { bool isArg0CastToBytePtrType = false; if (auto *CE = dyn_cast(Arg0)) { - if (auto DestTySize = Finder->getASTContext().getTypeSizeInCharsIfKnown( - Arg0Ty->getPointeeType())) { + if (auto DestTySize = + Ctx.getTypeSizeInCharsIfKnown(Arg0Ty->getPointeeType())) { if (!DestTySize->isOne()) return false; // If the destination pointee type is NOT of one byte // size, pattern match fails. @@ -1257,8 +1407,7 @@ AST_MATCHER(CXXConstructExpr, isSafeSpanTwoParamConstruct) { // must be one byte: if (CAT->isCountInBytes() && !isArg0CastToBytePtrType) { std::optional SizeOpt = - Finder->getASTContext().getTypeSizeInCharsIfKnown( - CAT->getPointeeType()); + Ctx.getTypeSizeInCharsIfKnown(CAT->getPointeeType()); if (!SizeOpt.has_value() || !SizeOpt->isOne()) return false; } @@ -1272,18 +1421,18 @@ AST_MATCHER(CXXConstructExpr, isSafeSpanTwoParamConstruct) { if (!ValuesOpt.has_value()) return false; return isCompatibleWithCountExpr(Arg1, CAT->getCountExpr(), MemberBase, - &*ValuesOpt, Finder->getASTContext()); + &*ValuesOpt, Ctx); } return isCompatibleWithCountExpr(Arg1, CAT->getCountExpr(), MemberBase, - /*DependentValues=*/nullptr, - Finder->getASTContext()); + /*DependentValues=*/nullptr, Ctx); } /* TO_UPSTREAM(BoundsSafetyInterop) OFF */ return false; } -AST_MATCHER(ArraySubscriptExpr, isSafeArraySubscript) { +static bool isSafeArraySubscript(const ArraySubscriptExpr &Node, + const ASTContext &Ctx) { // FIXME: Proper solution: // - refactor Sema::CheckArrayAccess // - split safe/OOB/unknown decision logic from diagnostics emitting code @@ -1298,7 +1447,7 @@ AST_MATCHER(ArraySubscriptExpr, isSafeArraySubscript) { ->getType() ->getUnqualifiedDesugaredType())) { limit = CATy->getLimitedSize(); - } else if (const auto *SLiteral = dyn_cast( + } else if (const auto *SLiteral = dyn_cast( Node.getBase()->IgnoreParenImpCasts())) { limit = SLiteral->getLength() + 1; } else { @@ -1308,7 +1457,7 @@ AST_MATCHER(ArraySubscriptExpr, isSafeArraySubscript) { Expr::EvalResult EVResult; const Expr *IndexExpr = Node.getIdx(); if (!IndexExpr->isValueDependent() && - IndexExpr->EvaluateAsInt(EVResult, Finder->getASTContext())) { + IndexExpr->EvaluateAsInt(EVResult, Ctx)) { llvm::APSInt ArrIdx = EVResult.Val.getInt(); // FIXME: ArrIdx.isNegative() we could immediately emit an error as that's a // bug @@ -1324,10 +1473,9 @@ AST_MATCHER(ArraySubscriptExpr, isSafeArraySubscript) { const Expr *RHS = BE->getRHS(); if ((!LHS->isValueDependent() && - LHS->EvaluateAsInt(EVResult, - Finder->getASTContext())) || // case: `n & e` + LHS->EvaluateAsInt(EVResult, Ctx)) || // case: `n & e` (!RHS->isValueDependent() && - RHS->EvaluateAsInt(EVResult, Finder->getASTContext()))) { // `e & n` + RHS->EvaluateAsInt(EVResult, Ctx))) { // `e & n` llvm::APSInt result = EVResult.Val.getInt(); if (result.isNonNegative() && result.getLimitedValue() < limit) return true; @@ -1337,22 +1485,43 @@ AST_MATCHER(ArraySubscriptExpr, isSafeArraySubscript) { return false; } -AST_MATCHER_P(CallExpr, hasNumArgs, unsigned, Num) { - return Node.getNumArgs() == Num; +static bool hasNumArgs(const CallExpr *CE, unsigned NumArgs) { + return CE->getNumArgs() == NumArgs; } -// Matches a function declaration if any parameter type or return type has -// bounds attributes. -AST_MATCHER(FunctionDecl, hasAnyBoundsAttributes) { - bool RetTyHasBoundsAttr = Node.getReturnType()->isBoundsAttributedType() || - Node.getReturnType()->isValueTerminatedType(); +static bool hasAnyBoundsAttributes(const FunctionDecl *FD) { + bool RetTyHasBoundsAttr = FD->getReturnType()->isBoundsAttributedType() || + FD->getReturnType()->isValueTerminatedType(); return RetTyHasBoundsAttr || - llvm::any_of(Node.parameters(), [](const ParmVarDecl *PVD) { + llvm::any_of(FD->parameters(), [](const ParmVarDecl *PVD) { return PVD->getType()->isBoundsAttributedType() || PVD->getType()->isValueTerminatedType(); }); } +// A pointer type expression is known to be null-terminated, if +// 1. it is a string literal or `PredefinedExpr` (e.g., `__func__`); +// 2. it has the form: E.c_str(), for any expression E of `std::string` type; +// 3. it has `__null_terminated` type +static bool isNullTermPointer(const Expr *Ptr, ASTContext &Ctx) { + if (isa(Ptr->IgnoreParenImpCasts())) + return true; + if (isa(Ptr->IgnoreParenImpCasts())) + return true; + if (auto *MCE = dyn_cast(Ptr->IgnoreParenImpCasts())) { + const CXXMethodDecl *MD = MCE->getMethodDecl(); + const CXXRecordDecl *RD = MCE->getRecordDecl()->getCanonicalDecl(); + + if (MD && RD && RD->isInStdNamespace()) + if (MD->getName() == "c_str" && RD->getName() == "basic_string") + return true; + } + if (auto *VTT = Ptr->getType().getTypePtr()->getAs()) { + return VTT->getTerminatorValue(Ctx).isZero(); + } + return false; +} + namespace libc_func_matchers { // Under `libc_func_matchers`, define a set of matchers that match unsafe // functions in libc and unsafe calls to them. @@ -1398,33 +1567,6 @@ struct LibcFunNamePrefixSuffixParser { } }; -// A pointer type expression is known to be null-terminated, if -// 1. it is a string literal or `PredefinedExpr` (e.g., `__func__`); -// 2. it has the form: E.c_str(), for any expression E of `std::string` type; -// 3. it has `__null_terminated` type -static bool isNullTermPointer(const Expr *Ptr, ASTContext &Ctx) { - if (isa(Ptr->IgnoreParenImpCasts())) - return true; - if (isa(Ptr->IgnoreParenImpCasts())) - return true; - if (auto *MCE = dyn_cast(Ptr->IgnoreParenImpCasts())) { - const CXXMethodDecl *MD = MCE->getMethodDecl(); - const CXXRecordDecl *RD = MCE->getRecordDecl()->getCanonicalDecl(); - - if (MD && RD && RD->isInStdNamespace()) - if (MD->getName() == "c_str" && RD->getName() == "basic_string") - return true; - } - if (auto *VTT = Ptr->getType().getTypePtr()->getAs()) { - return VTT->getTerminatorValue(Ctx).isZero(); - } - return false; -} - -AST_MATCHER(Expr, isNullTermPointer) { - return isNullTermPointer(&Node, Finder->getASTContext()); -} - // Return true iff at least one of following cases holds: // 1. Format string is a literal and there is an unsafe pointer argument // corresponding to an `s` specifier; @@ -1469,7 +1611,7 @@ static bool hasUnsafeFormatOrSArg(const CallExpr *Call, const Expr *&UnsafeArg, const Expr *Fmt = Call->getArg(FmtArgIdx); - if (auto *SL = dyn_cast(Fmt->IgnoreParenImpCasts())) { + if (auto *SL = dyn_cast(Fmt->IgnoreParenImpCasts())) { StringRef FmtStr; if (SL->getCharByteWidth() == 1) @@ -1509,7 +1651,7 @@ static bool hasUnsafeFormatOrSArg(const CallExpr *Call, const Expr *&UnsafeArg, // Note: For predefined prefix and suffix, see `LibcFunNamePrefixSuffixParser`. // The notation `CoreName[str/wcs]` means a new name obtained from replace // string "wcs" with "str" in `CoreName`. -AST_MATCHER(FunctionDecl, isPredefinedUnsafeLibcFunc) { +static bool isPredefinedUnsafeLibcFunc(const FunctionDecl &Node) { static std::unique_ptr> PredefinedNames = nullptr; if (!PredefinedNames) PredefinedNames = @@ -1616,7 +1758,7 @@ AST_MATCHER(FunctionDecl, isPredefinedUnsafeLibcFunc) { // Match a call to one of the `v*printf` functions taking `va_list`. We cannot // check safety for these functions so they should be changed to their // non-va_list versions. -AST_MATCHER(FunctionDecl, isUnsafeVaListPrintfFunc) { +static bool isUnsafeVaListPrintfFunc(const FunctionDecl &Node) { auto *II = Node.getIdentifier(); if (!II) @@ -1632,7 +1774,7 @@ AST_MATCHER(FunctionDecl, isUnsafeVaListPrintfFunc) { // Matches a call to one of the `sprintf` functions as they are always unsafe // and should be changed to `snprintf`. -AST_MATCHER(FunctionDecl, isUnsafeSprintfFunc) { +static bool isUnsafeSprintfFunc(const FunctionDecl &Node) { auto *II = Node.getIdentifier(); if (!II) @@ -1656,7 +1798,7 @@ AST_MATCHER(FunctionDecl, isUnsafeSprintfFunc) { // Match function declarations of `printf`, `fprintf`, `snprintf` and their wide // character versions. Calls to these functions can be safe if their arguments // are carefully made safe. -AST_MATCHER(FunctionDecl, isNormalPrintfFunc) { +static bool isNormalPrintfFunc(const FunctionDecl &Node) { auto *II = Node.getIdentifier(); if (!II) @@ -1680,9 +1822,8 @@ AST_MATCHER(FunctionDecl, isNormalPrintfFunc) { // Then if the format string is a string literal, this matcher matches when at // least one string argument is unsafe. If the format is not a string literal, // this matcher matches when at least one pointer type argument is unsafe. -AST_MATCHER_P(CallExpr, hasUnsafePrintfStringArg, - clang::ast_matchers::internal::Matcher, - UnsafeStringArgMatcher) { +static bool hasUnsafePrintfStringArg(const CallExpr &Node, ASTContext &Ctx, + MatchResult &Result, llvm::StringRef Tag) { // Determine what printf it is by examining formal parameters: const FunctionDecl *FD = Node.getDirectCallee(); @@ -1693,7 +1834,6 @@ AST_MATCHER_P(CallExpr, hasUnsafePrintfStringArg, if (NumParms < 1) return false; // possibly some user-defined printf function - ASTContext &Ctx = Finder->getASTContext(); QualType FirstParmTy = FD->getParamDecl(0)->getType(); if (!FirstParmTy->isPointerType()) @@ -1707,8 +1847,10 @@ AST_MATCHER_P(CallExpr, hasUnsafePrintfStringArg, // It is a fprintf: const Expr *UnsafeArg; - if (hasUnsafeFormatOrSArg(&Node, UnsafeArg, 1, Ctx, false)) - return UnsafeStringArgMatcher.matches(*UnsafeArg, Finder, Builder); + if (hasUnsafeFormatOrSArg(&Node, UnsafeArg, 1, Ctx, false)) { + Result.addNode(Tag, DynTypedNode::create(*UnsafeArg)); + return true; + } return false; } @@ -1719,8 +1861,10 @@ AST_MATCHER_P(CallExpr, hasUnsafePrintfStringArg, if (auto *II = FD->getIdentifier()) isKprintf = II->getName() == "kprintf"; - if (hasUnsafeFormatOrSArg(&Node, UnsafeArg, 0, Ctx, isKprintf)) - return UnsafeStringArgMatcher.matches(*UnsafeArg, Finder, Builder); + if (hasUnsafeFormatOrSArg(&Node, UnsafeArg, 0, Ctx, isKprintf)) { + Result.addNode(Tag, DynTypedNode::create(*UnsafeArg)); + return true; + } return false; } @@ -1732,17 +1876,20 @@ AST_MATCHER_P(CallExpr, hasUnsafePrintfStringArg, // second is an integer, it is a snprintf: const Expr *UnsafeArg; - if (hasUnsafeFormatOrSArg(&Node, UnsafeArg, 2, Ctx, false)) - return UnsafeStringArgMatcher.matches(*UnsafeArg, Finder, Builder); + if (hasUnsafeFormatOrSArg(&Node, UnsafeArg, 2, Ctx, false)) { + Result.addNode(Tag, DynTypedNode::create(*UnsafeArg)); + return true; + } return false; } } // We don't really recognize this "normal" printf, the only thing we // can do is to require all pointers to be null-terminated: - for (auto *Arg : Node.arguments()) - if (Arg->getType()->isPointerType() && !isNullTermPointer(Arg, Ctx)) - if (UnsafeStringArgMatcher.matches(*Arg, Finder, Builder)) - return true; + for (const auto *Arg : Node.arguments()) + if (Arg->getType()->isPointerType() && !isNullTermPointer(Arg, Ctx)) { + Result.addNode(Tag, DynTypedNode::create(*Arg)); + return true; + } return false; } @@ -1762,7 +1909,7 @@ AST_MATCHER_P(CallExpr, hasUnsafePrintfStringArg, // ptr := Constant-Array-DRE; // size:= any expression that has compile-time constant value equivalent to // sizeof (Constant-Array-DRE) -AST_MATCHER(CallExpr, hasUnsafeSnprintfBuffer) { +static bool hasUnsafeSnprintfBuffer(const CallExpr &Node, ASTContext &Ctx) { const FunctionDecl *FD = Node.getDirectCallee(); assert(FD && "It should have been checked that FD is non-null."); @@ -1783,12 +1930,11 @@ AST_MATCHER(CallExpr, hasUnsafeSnprintfBuffer) { return false; // not an snprintf call if (std::optional NumChars = - Finder->getASTContext().getTypeSizeInCharsIfKnown(FirstPteTy)) { + Ctx.getTypeSizeInCharsIfKnown(FirstPteTy)) { Buf = Buf->IgnoreParenImpCasts(); Size = Size->IgnoreParenImpCasts(); return !isHardcodedCountedByPointerArgumentSafe( - Finder->getASTContext(), Buf, Size, FirstParmTy.getTypePtr(), - NumChars->isOne(), false); + Ctx, Buf, Size, FirstParmTy.getTypePtr(), NumChars->isOne(), false); } return false; } @@ -1798,11 +1944,11 @@ AST_MATCHER(CallExpr, hasUnsafeSnprintfBuffer) { // unsafe, i.e. that is NOT in one of the following forms: // 1. `sp.data()` if the argument to dependent count is `sp.size()`. // 2. `sp.first(extent).data()` if `extent` is compatible with the `count`. -AST_MATCHER_P(CallExpr, forEachUnsafeCountAttributedPointerArgument, - internal::Matcher, ArgMatcher) { - ASTContext &Context = Finder->getASTContext(); - const CallExpr *Call = &Node; - +static bool forEachUnsafeCountAttributedPointerArgument(const CallExpr *Call, + MatchResults &Result, + StringRef ArgTag, + StringRef CallTag, + ASTContext &Ctx) { const FunctionDecl *FD = Call->getDirectCallee(); if (!FD) return false; @@ -1820,26 +1966,17 @@ AST_MATCHER_P(CallExpr, forEachUnsafeCountAttributedPointerArgument, continue; const Expr *Arg = Call->getArg(I); - if (!isCountAttributedPointerArgumentSafe(Context, CAT, Call, Arg)) { - BoundNodesTreeBuilder ArgMatches(*Builder); - if (ArgMatcher.matches(*Arg, Finder, &ArgMatches)) { - Builder->addMatch(ArgMatches); - Matched = true; - } + if (!isCountAttributedPointerArgumentSafe(Ctx, CAT, Call, Arg)) { + MatchResult &MR = Result.emplace_back(ArgTag, DynTypedNode::create(*Arg)); + + MR.addNode(CallTag, DynTypedNode::create(*Call)); + Matched = true; } } return Matched; } -// Matches iff the argument passed to __single pointer type is safe. -AST_MATCHER(Expr, isSinglePointerArgumentSafe) { - ASTContext &Context = Finder->getASTContext(); - return isSinglePointerArgumentSafe(Context, &Node); -} - -} // namespace clang::ast_matchers - namespace { // Because the analysis revolves around variables and their types, we'll need to // track uses of variables (aka DeclRefExprs). @@ -1865,11 +2002,6 @@ class Gadget { #include "clang/Analysis/Analyses/UnsafeBufferUsageGadgets.def" }; - /// Common type of ASTMatchers used for discovering gadgets. - /// Useful for implementing the static matcher() methods - /// that are expected from all non-abstract subclasses. - using Matcher = decltype(stmt()); - Gadget(Kind K) : K(K) {} Kind getKind() const { return K; } @@ -1946,7 +2078,10 @@ class FixableGadget : public Gadget { } }; -static auto toSupportedVariable() { return to(varDecl()); } +static bool isSupportedVariable(const DeclRefExpr &Node) { + const Decl *D = Node.getDecl(); + return D != nullptr && isa(D); +} using FixableGadgetList = std::vector>; using WarningGadgetList = std::vector>; @@ -1958,19 +2093,23 @@ class IncrementGadget : public WarningGadget { const UnaryOperator *Op; public: - IncrementGadget(const MatchFinder::MatchResult &Result) + IncrementGadget(const MatchResult &Result) : WarningGadget(Kind::Increment), - Op(Result.Nodes.getNodeAs(OpTag)) {} + Op(Result.getNodeAs(OpTag)) {} static bool classof(const Gadget *G) { return G->getKind() == Kind::Increment; } - static Matcher matcher() { - return stmt( - unaryOperator(hasOperatorName("++"), - hasUnaryOperand(ignoringParenImpCasts(hasPointerType()))) - .bind(OpTag)); + static bool matches(const Stmt *S, const ASTContext &Ctx, + MatchResult &Result) { + const auto *UO = dyn_cast(S); + if (!UO || !UO->isIncrementOp()) + return false; + if (!hasPointerType(*UO->getSubExpr()->IgnoreParenImpCasts())) + return false; + Result.addNode(OpTag, DynTypedNode::create(*UO)); + return true; } void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler, @@ -1998,19 +2137,23 @@ class DecrementGadget : public WarningGadget { const UnaryOperator *Op; public: - DecrementGadget(const MatchFinder::MatchResult &Result) + DecrementGadget(const MatchResult &Result) : WarningGadget(Kind::Decrement), - Op(Result.Nodes.getNodeAs(OpTag)) {} + Op(Result.getNodeAs(OpTag)) {} static bool classof(const Gadget *G) { return G->getKind() == Kind::Decrement; } - static Matcher matcher() { - return stmt( - unaryOperator(hasOperatorName("--"), - hasUnaryOperand(ignoringParenImpCasts(hasPointerType()))) - .bind(OpTag)); + static bool matches(const Stmt *S, const ASTContext &Ctx, + MatchResult &Result) { + const auto *UO = dyn_cast(S); + if (!UO || !UO->isDecrementOp()) + return false; + if (!hasPointerType(*UO->getSubExpr()->IgnoreParenImpCasts())) + return false; + Result.addNode(OpTag, DynTypedNode::create(*UO)); + return true; } void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler, @@ -2037,26 +2180,29 @@ class ArraySubscriptGadget : public WarningGadget { const ArraySubscriptExpr *ASE; public: - ArraySubscriptGadget(const MatchFinder::MatchResult &Result) + ArraySubscriptGadget(const MatchResult &Result) : WarningGadget(Kind::ArraySubscript), - ASE(Result.Nodes.getNodeAs(ArraySubscrTag)) {} + ASE(Result.getNodeAs(ArraySubscrTag)) {} static bool classof(const Gadget *G) { return G->getKind() == Kind::ArraySubscript; } - static Matcher matcher() { - // clang-format off - return stmt(arraySubscriptExpr( - hasBase(ignoringParenImpCasts( - anyOf(hasPointerType(), hasArrayType()))), - unless(anyOf( - isSafeArraySubscript(), - hasIndex( - anyOf(integerLiteral(equals(0)), arrayInitIndexExpr()) - ) - ))).bind(ArraySubscrTag)); - // clang-format on + static bool matches(const Stmt *S, const ASTContext &Ctx, + MatchResult &Result) { + const auto *ASE = dyn_cast(S); + if (!ASE) + return false; + const auto *const Base = ASE->getBase()->IgnoreParenImpCasts(); + if (!hasPointerType(*Base) && !hasArrayType(*Base)) + return false; + const auto *Idx = dyn_cast(ASE->getIdx()); + bool IsSafeIndex = (Idx && Idx->getValue().isZero()) || + isa(ASE->getIdx()); + if (IsSafeIndex || isSafeArraySubscript(*ASE, Ctx)) + return false; + Result.addNode(ArraySubscrTag, DynTypedNode::create(*ASE)); + return true; } void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler, @@ -2087,29 +2233,40 @@ class PointerArithmeticGadget : public WarningGadget { const Expr *Ptr; // the pointer expression in `PA` public: - PointerArithmeticGadget(const MatchFinder::MatchResult &Result) + PointerArithmeticGadget(const MatchResult &Result) : WarningGadget(Kind::PointerArithmetic), - PA(Result.Nodes.getNodeAs(PointerArithmeticTag)), - Ptr(Result.Nodes.getNodeAs(PointerArithmeticPointerTag)) {} + PA(Result.getNodeAs(PointerArithmeticTag)), + Ptr(Result.getNodeAs(PointerArithmeticPointerTag)) {} static bool classof(const Gadget *G) { return G->getKind() == Kind::PointerArithmetic; } - static Matcher matcher() { - auto HasIntegerType = anyOf(hasType(isInteger()), hasType(enumType())); - auto PtrAtRight = - allOf(hasOperatorName("+"), - hasRHS(expr(hasPointerType()).bind(PointerArithmeticPointerTag)), - hasLHS(HasIntegerType)); - auto PtrAtLeft = - allOf(anyOf(hasOperatorName("+"), hasOperatorName("-"), - hasOperatorName("+="), hasOperatorName("-=")), - hasLHS(expr(hasPointerType()).bind(PointerArithmeticPointerTag)), - hasRHS(HasIntegerType)); - - return stmt(binaryOperator(anyOf(PtrAtLeft, PtrAtRight)) - .bind(PointerArithmeticTag)); + static bool matches(const Stmt *S, const ASTContext &Ctx, + MatchResult &Result) { + const auto *BO = dyn_cast(S); + if (!BO) + return false; + const auto *LHS = BO->getLHS(); + const auto *RHS = BO->getRHS(); + // ptr at left + if (BO->getOpcode() == BO_Add || BO->getOpcode() == BO_Sub || + BO->getOpcode() == BO_AddAssign || BO->getOpcode() == BO_SubAssign) { + if (hasPointerType(*LHS) && (RHS->getType()->isIntegerType() || + RHS->getType()->isEnumeralType())) { + Result.addNode(PointerArithmeticPointerTag, DynTypedNode::create(*LHS)); + Result.addNode(PointerArithmeticTag, DynTypedNode::create(*BO)); + return true; + } + } + // ptr at right + if (BO->getOpcode() == BO_Add && hasPointerType(*RHS) && + (LHS->getType()->isIntegerType() || LHS->getType()->isEnumeralType())) { + Result.addNode(PointerArithmeticPointerTag, DynTypedNode::create(*RHS)); + Result.addNode(PointerArithmeticTag, DynTypedNode::create(*BO)); + return true; + } + return false; } void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler, @@ -2136,27 +2293,35 @@ class SpanTwoParamConstructorGadget : public WarningGadget { const CXXConstructExpr *Ctor; // the span constructor expression public: - SpanTwoParamConstructorGadget(const MatchFinder::MatchResult &Result) + SpanTwoParamConstructorGadget(const MatchResult &Result) : WarningGadget(Kind::SpanTwoParamConstructor), - Ctor(Result.Nodes.getNodeAs( - SpanTwoParamConstructorTag)) {} + Ctor(Result.getNodeAs(SpanTwoParamConstructorTag)) {} static bool classof(const Gadget *G) { return G->getKind() == Kind::SpanTwoParamConstructor; } - static Matcher matcher() { - auto HasTwoParamSpanCtorDecl = hasDeclaration( - cxxConstructorDecl(hasDeclContext(isInStdNamespace()), hasName("span"), - parameterCountIs(2))); - - return stmt(cxxConstructExpr(HasTwoParamSpanCtorDecl, - unless(isSafeSpanTwoParamConstruct())) - .bind(SpanTwoParamConstructorTag)); + static bool matches(const Stmt *S, ASTContext &Ctx, MatchResult &Result) { + const auto *CE = dyn_cast(S); + if (!CE) + return false; + const auto *CDecl = CE->getConstructor(); + const auto *CRecordDecl = CDecl->getParent(); + auto HasTwoParamSpanCtorDecl = + CRecordDecl->isInStdNamespace() && + CDecl->getDeclName().getAsString() == "span" && CE->getNumArgs() == 2; + if (!HasTwoParamSpanCtorDecl || isSafeSpanTwoParamConstruct(*CE, Ctx)) + return false; + Result.addNode(SpanTwoParamConstructorTag, DynTypedNode::create(*CE)); + return true; } - static Matcher matcher(const UnsafeBufferUsageHandler *Handler) { - return stmt(unless(ignoreUnsafeBufferInContainer(Handler)), matcher()); + static bool matches(const Stmt *S, ASTContext &Ctx, + const UnsafeBufferUsageHandler *Handler, + MatchResult &Result) { + if (ignoreUnsafeBufferInContainer(*S, Handler)) + return false; + return matches(S, Ctx, Result); } void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler, @@ -2189,23 +2354,35 @@ class PointerInitGadget : public FixableGadget { const DeclRefExpr *PtrInitRHS; // the RHS pointer expression in `PI` public: - PointerInitGadget(const MatchFinder::MatchResult &Result) + PointerInitGadget(const MatchResult &Result) : FixableGadget(Kind::PointerInit), - PtrInitLHS(Result.Nodes.getNodeAs(PointerInitLHSTag)), - PtrInitRHS(Result.Nodes.getNodeAs(PointerInitRHSTag)) {} + PtrInitLHS(Result.getNodeAs(PointerInitLHSTag)), + PtrInitRHS(Result.getNodeAs(PointerInitRHSTag)) {} static bool classof(const Gadget *G) { return G->getKind() == Kind::PointerInit; } - static Matcher matcher() { - auto PtrInitStmt = declStmt(hasSingleDecl( - varDecl(hasInitializer(ignoringImpCasts( - declRefExpr(hasPointerType(), toSupportedVariable()) - .bind(PointerInitRHSTag)))) - .bind(PointerInitLHSTag))); - - return stmt(PtrInitStmt); + static bool matches(const Stmt *S, + llvm::SmallVectorImpl &Results) { + const DeclStmt *DS = dyn_cast(S); + if (!DS || !DS->isSingleDecl()) + return false; + const VarDecl *VD = dyn_cast(DS->getSingleDecl()); + if (!VD) + return false; + const Expr *Init = VD->getAnyInitializer(); + if (!Init) + return false; + const auto *DRE = dyn_cast(Init->IgnoreImpCasts()); + if (!DRE || !hasPointerType(*DRE) || !isSupportedVariable(*DRE)) { + return false; + } + MatchResult R; + R.addNode(PointerInitLHSTag, DynTypedNode::create(*VD)); + R.addNode(PointerInitRHSTag, DynTypedNode::create(*DRE)); + Results.emplace_back(std::move(R)); + return true; } virtual std::optional @@ -2237,25 +2414,40 @@ class PtrToPtrAssignmentGadget : public FixableGadget { const DeclRefExpr *PtrRHS; // the RHS pointer expression in `PA` public: - PtrToPtrAssignmentGadget(const MatchFinder::MatchResult &Result) + PtrToPtrAssignmentGadget(const MatchResult &Result) : FixableGadget(Kind::PtrToPtrAssignment), - PtrLHS(Result.Nodes.getNodeAs(PointerAssignLHSTag)), - PtrRHS(Result.Nodes.getNodeAs(PointerAssignRHSTag)) {} + PtrLHS(Result.getNodeAs(PointerAssignLHSTag)), + PtrRHS(Result.getNodeAs(PointerAssignRHSTag)) {} static bool classof(const Gadget *G) { return G->getKind() == Kind::PtrToPtrAssignment; } - static Matcher matcher() { - auto PtrAssignExpr = binaryOperator( - allOf(hasOperatorName("="), - hasRHS(ignoringParenImpCasts( - declRefExpr(hasPointerType(), toSupportedVariable()) - .bind(PointerAssignRHSTag))), - hasLHS(declRefExpr(hasPointerType(), toSupportedVariable()) - .bind(PointerAssignLHSTag)))); - - return stmt(isInUnspecifiedUntypedContext(PtrAssignExpr)); + static bool matches(const Stmt *S, + llvm::SmallVectorImpl &Results) { + size_t SizeBefore = Results.size(); + findStmtsInUnspecifiedUntypedContext(S, [&Results](const Stmt *S) { + const auto *BO = dyn_cast(S); + if (!BO || BO->getOpcode() != BO_Assign) + return; + const auto *RHS = BO->getRHS()->IgnoreParenImpCasts(); + if (const auto *RHSRef = dyn_cast(RHS); + !RHSRef || !hasPointerType(*RHSRef) || + !isSupportedVariable(*RHSRef)) { + return; + } + const auto *LHS = BO->getLHS(); + if (const auto *LHSRef = dyn_cast(LHS); + !LHSRef || !hasPointerType(*LHSRef) || + !isSupportedVariable(*LHSRef)) { + return; + } + MatchResult R; + R.addNode(PointerAssignLHSTag, DynTypedNode::create(*LHS)); + R.addNode(PointerAssignRHSTag, DynTypedNode::create(*RHS)); + Results.emplace_back(std::move(R)); + }); + return SizeBefore != Results.size(); } virtual std::optional @@ -2286,26 +2478,41 @@ class CArrayToPtrAssignmentGadget : public FixableGadget { const DeclRefExpr *PtrRHS; // the RHS pointer expression in `PA` public: - CArrayToPtrAssignmentGadget(const MatchFinder::MatchResult &Result) + CArrayToPtrAssignmentGadget(const MatchResult &Result) : FixableGadget(Kind::CArrayToPtrAssignment), - PtrLHS(Result.Nodes.getNodeAs(PointerAssignLHSTag)), - PtrRHS(Result.Nodes.getNodeAs(PointerAssignRHSTag)) {} + PtrLHS(Result.getNodeAs(PointerAssignLHSTag)), + PtrRHS(Result.getNodeAs(PointerAssignRHSTag)) {} static bool classof(const Gadget *G) { return G->getKind() == Kind::CArrayToPtrAssignment; } - static Matcher matcher() { - auto PtrAssignExpr = binaryOperator( - allOf(hasOperatorName("="), - hasRHS(ignoringParenImpCasts( - declRefExpr(hasType(hasCanonicalType(constantArrayType())), - toSupportedVariable()) - .bind(PointerAssignRHSTag))), - hasLHS(declRefExpr(hasPointerType(), toSupportedVariable()) - .bind(PointerAssignLHSTag)))); - - return stmt(isInUnspecifiedUntypedContext(PtrAssignExpr)); + static bool matches(const Stmt *S, + llvm::SmallVectorImpl &Results) { + size_t SizeBefore = Results.size(); + findStmtsInUnspecifiedUntypedContext(S, [&Results](const Stmt *S) { + const auto *BO = dyn_cast(S); + if (!BO || BO->getOpcode() != BO_Assign) + return; + const auto *RHS = BO->getRHS()->IgnoreParenImpCasts(); + if (const auto *RHSRef = dyn_cast(RHS); + !RHSRef || + !isa(RHSRef->getType().getCanonicalType()) || + !isSupportedVariable(*RHSRef)) { + return; + } + const auto *LHS = BO->getLHS(); + if (const auto *LHSRef = dyn_cast(LHS); + !LHSRef || !hasPointerType(*LHSRef) || + !isSupportedVariable(*LHSRef)) { + return; + } + MatchResult R; + R.addNode(PointerAssignLHSTag, DynTypedNode::create(*LHS)); + R.addNode(PointerAssignRHSTag, DynTypedNode::create(*RHS)); + Results.emplace_back(std::move(R)); + }); + return SizeBefore != Results.size(); } virtual std::optional @@ -2329,23 +2536,32 @@ class UnsafeBufferUsageAttrGadget : public WarningGadget { const Expr *Op; public: - UnsafeBufferUsageAttrGadget(const MatchFinder::MatchResult &Result) + UnsafeBufferUsageAttrGadget(const MatchResult &Result) : WarningGadget(Kind::UnsafeBufferUsageAttr), - Op(Result.Nodes.getNodeAs(OpTag)) {} + Op(Result.getNodeAs(OpTag)) {} static bool classof(const Gadget *G) { return G->getKind() == Kind::UnsafeBufferUsageAttr; } - static Matcher matcher() { - auto HasUnsafeFieldDecl = - member(fieldDecl(hasAttr(attr::UnsafeBufferUsage))); - - auto HasUnsafeFnDecl = - callee(functionDecl(hasAttr(attr::UnsafeBufferUsage))); - - return stmt(anyOf(callExpr(HasUnsafeFnDecl).bind(OpTag), - memberExpr(HasUnsafeFieldDecl).bind(OpTag))); + static bool matches(const Stmt *S, const ASTContext &Ctx, + MatchResult &Result) { + if (auto *CE = dyn_cast(S)) { + if (CE->getDirectCallee() && + CE->getDirectCallee()->hasAttr()) { + Result.addNode(OpTag, DynTypedNode::create(*CE)); + return true; + } + } + if (auto *ME = dyn_cast(S)) { + if (!isa(ME->getMemberDecl())) + return false; + if (ME->getMemberDecl()->hasAttr()) { + Result.addNode(OpTag, DynTypedNode::create(*ME)); + return true; + } + } + return false; } void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler, @@ -2366,22 +2582,24 @@ class UnsafeBufferUsageCtorAttrGadget : public WarningGadget { const CXXConstructExpr *Op; public: - UnsafeBufferUsageCtorAttrGadget(const MatchFinder::MatchResult &Result) + UnsafeBufferUsageCtorAttrGadget(const MatchResult &Result) : WarningGadget(Kind::UnsafeBufferUsageCtorAttr), - Op(Result.Nodes.getNodeAs(OpTag)) {} + Op(Result.getNodeAs(OpTag)) {} static bool classof(const Gadget *G) { return G->getKind() == Kind::UnsafeBufferUsageCtorAttr; } - static Matcher matcher() { - auto HasUnsafeCtorDecl = - hasDeclaration(cxxConstructorDecl(hasAttr(attr::UnsafeBufferUsage))); + static bool matches(const Stmt *S, ASTContext &Ctx, MatchResult &Result) { + const auto *CE = dyn_cast(S); + if (!CE || !CE->getConstructor()->hasAttr()) + return false; // std::span(ptr, size) ctor is handled by SpanTwoParamConstructorGadget. - auto HasTwoParamSpanCtorDecl = SpanTwoParamConstructorGadget::matcher(); - return stmt( - cxxConstructExpr(HasUnsafeCtorDecl, unless(HasTwoParamSpanCtorDecl)) - .bind(OpTag)); + MatchResult Tmp; + if (SpanTwoParamConstructorGadget::matches(CE, Ctx, Tmp)) + return false; + Result.addNode(OpTag, DynTypedNode::create(*CE)); + return true; } void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler, @@ -2403,23 +2621,34 @@ class DataInvocationGadget : public WarningGadget { const ExplicitCastExpr *Op; public: - DataInvocationGadget(const MatchFinder::MatchResult &Result) + DataInvocationGadget(const MatchResult &Result) : WarningGadget(Kind::DataInvocation), - Op(Result.Nodes.getNodeAs(OpTag)) {} + Op(Result.getNodeAs(OpTag)) {} static bool classof(const Gadget *G) { return G->getKind() == Kind::DataInvocation; } - static Matcher matcher() { - - Matcher callExpr = cxxMemberCallExpr(callee( - cxxMethodDecl(hasName("data"), - ofClass(anyOf(hasName("std::span"), hasName("std::array"), - hasName("std::vector")))))); - return stmt( - explicitCastExpr(anyOf(has(callExpr), has(parenExpr(has(callExpr))))) - .bind(OpTag)); + static bool matches(const Stmt *S, const ASTContext &Ctx, + MatchResult &Result) { + auto *CE = dyn_cast(S); + if (!CE) + return false; + for (auto *Child : CE->children()) { + if (auto *MCE = dyn_cast(Child); + MCE && isDataFunction(MCE)) { + Result.addNode(OpTag, DynTypedNode::create(*CE)); + return true; + } + if (auto *Paren = dyn_cast(Child)) { + if (auto *MCE = dyn_cast(Paren->getSubExpr()); + MCE && isDataFunction(MCE)) { + Result.addNode(OpTag, DynTypedNode::create(*CE)); + return true; + } + } + } + return false; } void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler, @@ -2430,6 +2659,23 @@ class DataInvocationGadget : public WarningGadget { SourceLocation getSourceLoc() const override { return Op->getBeginLoc(); } DeclUseList getClaimedVarUseSites() const override { return {}; } + +private: + static bool isDataFunction(const CXXMemberCallExpr *call) { + if (!call) + return false; + auto *callee = call->getDirectCallee(); + if (!callee || !isa(callee)) + return false; + auto *method = cast(callee); + if (method->getNameAsString() == "data" && + method->getParent()->isInStdNamespace() && + (method->getParent()->getName() == "span" || + method->getParent()->getName() == "array" || + method->getParent()->getName() == "vector")) + return true; + return false; + } }; class UnsafeLibcFunctionCallGadget : public WarningGadget { @@ -2459,22 +2705,24 @@ class UnsafeLibcFunctionCallGadget : public WarningGadget { } WarnedFunKind = OTHERS; public: - UnsafeLibcFunctionCallGadget(const MatchFinder::MatchResult &Result) + UnsafeLibcFunctionCallGadget(const MatchResult &Result) : WarningGadget(Kind::UnsafeLibcFunctionCall), - Call(Result.Nodes.getNodeAs(Tag)) { - if (Result.Nodes.getNodeAs(UnsafeSprintfTag)) + Call(Result.getNodeAs(Tag)) { + if (Result.getNodeAs(UnsafeSprintfTag)) WarnedFunKind = SPRINTF; - else if (auto *E = Result.Nodes.getNodeAs(UnsafeStringTag)) { + else if (auto *E = Result.getNodeAs(UnsafeStringTag)) { WarnedFunKind = STRING; UnsafeArg = E; - } else if (Result.Nodes.getNodeAs(UnsafeSizedByTag)) { + } else if (Result.getNodeAs(UnsafeSizedByTag)) { WarnedFunKind = SIZED_BY; UnsafeArg = Call->getArg(0); - } else if (Result.Nodes.getNodeAs(UnsafeVaListTag)) + } else if (Result.getNodeAs(UnsafeVaListTag)) WarnedFunKind = VA_LIST; } - static Matcher matcher(const UnsafeBufferUsageHandler *Handler) { + static bool matches(const Stmt *Stmt, ASTContext &Ctx, + const UnsafeBufferUsageHandler *Handler, + MatchResult &Result) { // When this warning interops with bounds attributes, we suppress the // warning for most of the libc functions except for // 1. "normal printf" (see `libc_func_matchers::isNormalPrintfFunc`), @@ -2482,46 +2730,66 @@ class UnsafeLibcFunctionCallGadget : public WarningGadget { // the '__null_terminated' attribute; // 2. `v*printf/sprintf` functions because these functions cannot be // completely safe even with bounds attributes - return stmt(unless(ignoreUnsafeLibcCall(Handler)), - anyOf( - callExpr( - callee(functionDecl(anyOf( - // Match a predefined unsafe libc - // function: - functionDecl(unless(hasAnyBoundsAttributes()), - libc_func_matchers::isPredefinedUnsafeLibcFunc()), - // Match a call to one of the `v*printf` functions - // taking va-list, which cannot be checked at - // compile-time: - functionDecl(libc_func_matchers::isUnsafeVaListPrintfFunc()) - .bind(UnsafeVaListTag), - // Match a call to a `sprintf` function, which is never - // safe: - functionDecl(libc_func_matchers::isUnsafeSprintfFunc()) - .bind(UnsafeSprintfTag)))), - // (unless the call has a sole null-terminated argument, e.g., strlen, printf, atoi): - unless( - allOf(hasArgument(0, - expr(libc_func_matchers::isNullTermPointer())), - hasNumArgs(1)))), - - // The following two cases require checking against actual - // arguments of the call: - - // Match a call to an `snprintf` function. And first two - // arguments of the call (that describe a buffer) are not in - // safe patterns: - callExpr(callee(functionDecl( - // we do not warn about the write buffer of snprintf if it has bounds attributes: - unless(hasAnyBoundsAttributes()), - libc_func_matchers::isNormalPrintfFunc())), - libc_func_matchers::hasUnsafeSnprintfBuffer()) - .bind(UnsafeSizedByTag), - // Match a call to a `printf` function, which can be safe if - // all arguments are null-terminated: - callExpr(callee(functionDecl(libc_func_matchers::isNormalPrintfFunc())), - libc_func_matchers::hasUnsafePrintfStringArg( - expr().bind(UnsafeStringTag))))); + if (!Ctx.getLangOpts().CPlusPlus /* Warn about libc ONLY in C++ */ || + Handler->ignoreUnsafeBufferInLibcCall(Stmt->getBeginLoc())) + return false; + + const CallExpr *Call = dyn_cast(Stmt); + + if (!Call) + return false; + + const FunctionDecl *CalleeDecl = + dyn_cast_or_null(Call->getCalleeDecl()); + + if (!CalleeDecl) + return false; + // If the call has a sole null-terminated argument, e.g., strlen, + // printf, atoi, we consider it safe: + if (hasNumArgs(Call, 1) && isNullTermPointer(Call->getArg(0), Ctx)) + return false; + if (!hasAnyBoundsAttributes(CalleeDecl)) { + // Match a predefined unsafe libc function: + if (libc_func_matchers::isPredefinedUnsafeLibcFunc(*CalleeDecl)) { + Result.addNode(Tag, DynTypedNode::create(*Call)); + return true; + } + } // v*printf and sprintf functions are always unsafe regardless of whether + // they have bounds annotations + + // Match a call to one of the `v*printf` functions taking va-list, which + // cannot be checked at compile-time: + if (libc_func_matchers::isUnsafeVaListPrintfFunc(*CalleeDecl)) { + Result.addNode(UnsafeVaListTag, DynTypedNode::create(*CalleeDecl)); + Result.addNode(Tag, DynTypedNode::create(*Call)); + return true; + } + + // Match a call to a `sprintf` function, which is never safe: + if (libc_func_matchers::isUnsafeSprintfFunc(*CalleeDecl)) { + Result.addNode(UnsafeSprintfTag, DynTypedNode::create(*CalleeDecl)); + Result.addNode(Tag, DynTypedNode::create(*Call)); + return true; + } + + if (libc_func_matchers::isNormalPrintfFunc(*CalleeDecl)) { + // Match a call to an `snprintf` function. And first two arguments of the + // call (that describe a buffer) are not in safe patterns: + if (!hasAnyBoundsAttributes(CalleeDecl) && + libc_func_matchers::hasUnsafeSnprintfBuffer(*Call, Ctx)) { + Result.addNode(UnsafeSizedByTag, DynTypedNode::create(*Call)); + Result.addNode(Tag, DynTypedNode::create(*Call)); + return true; + } + // Match a call to a `printf` function, which can be safe if + // all arguments are null-terminated: + if (libc_func_matchers::hasUnsafePrintfStringArg(*Call, Ctx, Result, + UnsafeStringTag)) { + Result.addNode(Tag, DynTypedNode::create(*Call)); + return true; + } + } + return false; } const Stmt *getBaseStmt() const { return Call; } @@ -2538,7 +2806,7 @@ class UnsafeLibcFunctionCallGadget : public WarningGadget { }; // Represents expressions of the form `DRE[*]` in the Unspecified Lvalue -// Context (see `isInUnspecifiedLvalueContext`). +// Context (see `findStmtsInUnspecifiedLvalueContext`). // Note here `[]` is the built-in subscript operator. class ULCArraySubscriptGadget : public FixableGadget { private: @@ -2547,9 +2815,9 @@ class ULCArraySubscriptGadget : public FixableGadget { const ArraySubscriptExpr *Node; public: - ULCArraySubscriptGadget(const MatchFinder::MatchResult &Result) + ULCArraySubscriptGadget(const MatchResult &Result) : FixableGadget(Kind::ULCArraySubscript), - Node(Result.Nodes.getNodeAs(ULCArraySubscriptTag)) { + Node(Result.getNodeAs(ULCArraySubscriptTag)) { assert(Node != nullptr && "Expecting a non-null matching result"); } @@ -2557,14 +2825,23 @@ class ULCArraySubscriptGadget : public FixableGadget { return G->getKind() == Kind::ULCArraySubscript; } - static Matcher matcher() { - auto ArrayOrPtr = anyOf(hasPointerType(), hasArrayType()); - auto BaseIsArrayOrPtrDRE = hasBase( - ignoringParenImpCasts(declRefExpr(ArrayOrPtr, toSupportedVariable()))); - auto Target = - arraySubscriptExpr(BaseIsArrayOrPtrDRE).bind(ULCArraySubscriptTag); - - return expr(isInUnspecifiedLvalueContext(Target)); + static bool matches(const Stmt *S, + llvm::SmallVectorImpl &Results) { + size_t SizeBefore = Results.size(); + findStmtsInUnspecifiedLvalueContext(S, [&Results](const Expr *E) { + const auto *ASE = dyn_cast(E); + if (!ASE) + return; + const auto *DRE = + dyn_cast(ASE->getBase()->IgnoreParenImpCasts()); + if (!DRE || !(hasPointerType(*DRE) || hasArrayType(*DRE)) || + !isSupportedVariable(*DRE)) + return; + MatchResult R; + R.addNode(ULCArraySubscriptTag, DynTypedNode::create(*ASE)); + Results.emplace_back(std::move(R)); + }); + return SizeBefore != Results.size(); } virtual std::optional @@ -2581,17 +2858,17 @@ class ULCArraySubscriptGadget : public FixableGadget { }; // Fixable gadget to handle stand alone pointers of the form `UPC(DRE)` in the -// unspecified pointer context (isInUnspecifiedPointerContext). The gadget emits -// fixit of the form `UPC(DRE.data())`. +// unspecified pointer context (findStmtsInUnspecifiedPointerContext). The +// gadget emits fixit of the form `UPC(DRE.data())`. class UPCStandalonePointerGadget : public FixableGadget { private: static constexpr const char *const DeclRefExprTag = "StandalonePointer"; const DeclRefExpr *Node; public: - UPCStandalonePointerGadget(const MatchFinder::MatchResult &Result) + UPCStandalonePointerGadget(const MatchResult &Result) : FixableGadget(Kind::UPCStandalonePointer), - Node(Result.Nodes.getNodeAs(DeclRefExprTag)) { + Node(Result.getNodeAs(DeclRefExprTag)) { assert(Node != nullptr && "Expecting a non-null matching result"); } @@ -2599,12 +2876,22 @@ class UPCStandalonePointerGadget : public FixableGadget { return G->getKind() == Kind::UPCStandalonePointer; } - static Matcher matcher() { - auto ArrayOrPtr = anyOf(hasPointerType(), hasArrayType()); - auto target = expr(ignoringParenImpCasts( - declRefExpr(allOf(ArrayOrPtr, toSupportedVariable())) - .bind(DeclRefExprTag))); - return stmt(isInUnspecifiedPointerContext(target)); + static bool matches(const Stmt *S, + llvm::SmallVectorImpl &Results) { + size_t SizeBefore = Results.size(); + findStmtsInUnspecifiedPointerContext(S, [&Results](const Stmt *S) { + auto *E = dyn_cast(S); + if (!E) + return; + const auto *DRE = dyn_cast(E->IgnoreParenImpCasts()); + if (!DRE || (!hasPointerType(*DRE) && !hasArrayType(*DRE)) || + !isSupportedVariable(*DRE)) + return; + MatchResult R; + R.addNode(DeclRefExprTag, DynTypedNode::create(*DRE)); + Results.emplace_back(std::move(R)); + }); + return SizeBefore != Results.size(); } virtual std::optional @@ -2622,25 +2909,35 @@ class PointerDereferenceGadget : public FixableGadget { const UnaryOperator *Op = nullptr; public: - PointerDereferenceGadget(const MatchFinder::MatchResult &Result) + PointerDereferenceGadget(const MatchResult &Result) : FixableGadget(Kind::PointerDereference), - BaseDeclRefExpr( - Result.Nodes.getNodeAs(BaseDeclRefExprTag)), - Op(Result.Nodes.getNodeAs(OperatorTag)) {} + BaseDeclRefExpr(Result.getNodeAs(BaseDeclRefExprTag)), + Op(Result.getNodeAs(OperatorTag)) {} static bool classof(const Gadget *G) { return G->getKind() == Kind::PointerDereference; } - static Matcher matcher() { - auto Target = - unaryOperator( - hasOperatorName("*"), - has(expr(ignoringParenImpCasts( - declRefExpr(toSupportedVariable()).bind(BaseDeclRefExprTag))))) - .bind(OperatorTag); - - return expr(isInUnspecifiedLvalueContext(Target)); + static bool matches(const Stmt *S, + llvm::SmallVectorImpl &Results) { + size_t SizeBefore = Results.size(); + findStmtsInUnspecifiedLvalueContext(S, [&Results](const Stmt *S) { + const auto *UO = dyn_cast(S); + if (!UO || UO->getOpcode() != UO_Deref) + return; + const auto *CE = dyn_cast(UO->getSubExpr()); + if (!CE) + return; + CE = CE->IgnoreParenImpCasts(); + const auto *DRE = dyn_cast(CE); + if (!DRE || !isSupportedVariable(*DRE)) + return; + MatchResult R; + R.addNode(BaseDeclRefExprTag, DynTypedNode::create(*DRE)); + R.addNode(OperatorTag, DynTypedNode::create(*UO)); + Results.emplace_back(std::move(R)); + }); + return SizeBefore != Results.size(); } DeclUseList getClaimedVarUseSites() const override { @@ -2653,7 +2950,7 @@ class PointerDereferenceGadget : public FixableGadget { }; // Represents expressions of the form `&DRE[any]` in the Unspecified Pointer -// Context (see `isInUnspecifiedPointerContext`). +// Context (see `findStmtsInUnspecifiedPointerContext`). // Note here `[]` is the built-in subscript operator. class UPCAddressofArraySubscriptGadget : public FixableGadget { private: @@ -2662,10 +2959,9 @@ class UPCAddressofArraySubscriptGadget : public FixableGadget { const UnaryOperator *Node; // the `&DRE[any]` node public: - UPCAddressofArraySubscriptGadget(const MatchFinder::MatchResult &Result) + UPCAddressofArraySubscriptGadget(const MatchResult &Result) : FixableGadget(Kind::ULCArraySubscript), - Node(Result.Nodes.getNodeAs( - UPCAddressofArraySubscriptTag)) { + Node(Result.getNodeAs(UPCAddressofArraySubscriptTag)) { assert(Node != nullptr && "Expecting a non-null matching result"); } @@ -2673,13 +2969,28 @@ class UPCAddressofArraySubscriptGadget : public FixableGadget { return G->getKind() == Kind::UPCAddressofArraySubscript; } - static Matcher matcher() { - return expr(isInUnspecifiedPointerContext(expr(ignoringImpCasts( - unaryOperator( - hasOperatorName("&"), - hasUnaryOperand(arraySubscriptExpr(hasBase( - ignoringParenImpCasts(declRefExpr(toSupportedVariable())))))) - .bind(UPCAddressofArraySubscriptTag))))); + static bool matches(const Stmt *S, + llvm::SmallVectorImpl &Results) { + size_t SizeBefore = Results.size(); + findStmtsInUnspecifiedPointerContext(S, [&Results](const Stmt *S) { + auto *E = dyn_cast(S); + if (!E) + return; + const auto *UO = dyn_cast(E->IgnoreImpCasts()); + if (!UO || UO->getOpcode() != UO_AddrOf) + return; + const auto *ASE = dyn_cast(UO->getSubExpr()); + if (!ASE) + return; + const auto *DRE = + dyn_cast(ASE->getBase()->IgnoreParenImpCasts()); + if (!DRE || !isSupportedVariable(*DRE)) + return; + MatchResult R; + R.addNode(UPCAddressofArraySubscriptTag, DynTypedNode::create(*UO)); + Results.emplace_back(std::move(R)); + }); + return SizeBefore != Results.size(); } virtual std::optional @@ -2770,9 +3081,9 @@ class UPCPreIncrementGadget : public FixableGadget { const UnaryOperator *Node; // the `++Ptr` node public: - UPCPreIncrementGadget(const MatchFinder::MatchResult &Result) + UPCPreIncrementGadget(const MatchResult &Result) : FixableGadget(Kind::UPCPreIncrement), - Node(Result.Nodes.getNodeAs(UPCPreIncrementTag)) { + Node(Result.getNodeAs(UPCPreIncrementTag)) { assert(Node != nullptr && "Expecting a non-null matching result"); } @@ -2780,15 +3091,28 @@ class UPCPreIncrementGadget : public FixableGadget { return G->getKind() == Kind::UPCPreIncrement; } - static Matcher matcher() { + static bool matches(const Stmt *S, + llvm::SmallVectorImpl &Results) { // Note here we match `++Ptr` for any expression `Ptr` of pointer type. // Although currently we can only provide fix-its when `Ptr` is a DRE, we // can have the matcher be general, so long as `getClaimedVarUseSites` does // things right. - return stmt(isInUnspecifiedPointerContext(expr(ignoringImpCasts( - unaryOperator(isPreInc(), - hasUnaryOperand(declRefExpr(toSupportedVariable()))) - .bind(UPCPreIncrementTag))))); + size_t SizeBefore = Results.size(); + findStmtsInUnspecifiedPointerContext(S, [&Results](const Stmt *S) { + auto *E = dyn_cast(S); + if (!E) + return; + const auto *UO = dyn_cast(E->IgnoreImpCasts()); + if (!UO || UO->getOpcode() != UO_PreInc) + return; + const auto *DRE = dyn_cast(UO->getSubExpr()); + if (!DRE || !isSupportedVariable(*DRE)) + return; + MatchResult R; + R.addNode(UPCPreIncrementTag, DynTypedNode::create(*UO)); + Results.emplace_back(std::move(R)); + }); + return SizeBefore != Results.size(); } virtual std::optional @@ -2812,10 +3136,10 @@ class UUCAddAssignGadget : public FixableGadget { const Expr *Offset = nullptr; public: - UUCAddAssignGadget(const MatchFinder::MatchResult &Result) + UUCAddAssignGadget(const MatchResult &Result) : FixableGadget(Kind::UUCAddAssign), - Node(Result.Nodes.getNodeAs(UUCAddAssignTag)), - Offset(Result.Nodes.getNodeAs(OffsetTag)) { + Node(Result.getNodeAs(UUCAddAssignTag)), + Offset(Result.getNodeAs(OffsetTag)) { assert(Node != nullptr && "Expecting a non-null matching result"); } @@ -2823,17 +3147,25 @@ class UUCAddAssignGadget : public FixableGadget { return G->getKind() == Kind::UUCAddAssign; } - static Matcher matcher() { - // clang-format off - return stmt(isInUnspecifiedUntypedContext(expr(ignoringImpCasts( - binaryOperator(hasOperatorName("+="), - hasLHS( - declRefExpr( - hasPointerType(), - toSupportedVariable())), - hasRHS(expr().bind(OffsetTag))) - .bind(UUCAddAssignTag))))); - // clang-format on + static bool matches(const Stmt *S, + llvm::SmallVectorImpl &Results) { + size_t SizeBefore = Results.size(); + findStmtsInUnspecifiedUntypedContext(S, [&Results](const Stmt *S) { + const auto *E = dyn_cast(S); + if (!E) + return; + const auto *BO = dyn_cast(E->IgnoreImpCasts()); + if (!BO || BO->getOpcode() != BO_AddAssign) + return; + const auto *DRE = dyn_cast(BO->getLHS()); + if (!DRE || !hasPointerType(*DRE) || !isSupportedVariable(*DRE)) + return; + MatchResult R; + R.addNode(UUCAddAssignTag, DynTypedNode::create(*BO)); + R.addNode(OffsetTag, DynTypedNode::create(*BO->getRHS())); + Results.emplace_back(std::move(R)); + }); + return SizeBefore != Results.size(); } virtual std::optional @@ -2859,31 +3191,60 @@ class DerefSimplePtrArithFixableGadget : public FixableGadget { const IntegerLiteral *Offset = nullptr; public: - DerefSimplePtrArithFixableGadget(const MatchFinder::MatchResult &Result) + DerefSimplePtrArithFixableGadget(const MatchResult &Result) : FixableGadget(Kind::DerefSimplePtrArithFixable), - BaseDeclRefExpr( - Result.Nodes.getNodeAs(BaseDeclRefExprTag)), - DerefOp(Result.Nodes.getNodeAs(DerefOpTag)), - AddOp(Result.Nodes.getNodeAs(AddOpTag)), - Offset(Result.Nodes.getNodeAs(OffsetTag)) {} - - static Matcher matcher() { - // clang-format off - auto ThePtr = expr(hasPointerType(), - ignoringImpCasts(declRefExpr(toSupportedVariable()). - bind(BaseDeclRefExprTag))); - auto PlusOverPtrAndInteger = expr(anyOf( - binaryOperator(hasOperatorName("+"), hasLHS(ThePtr), - hasRHS(integerLiteral().bind(OffsetTag))) - .bind(AddOpTag), - binaryOperator(hasOperatorName("+"), hasRHS(ThePtr), - hasLHS(integerLiteral().bind(OffsetTag))) - .bind(AddOpTag))); - return isInUnspecifiedLvalueContext(unaryOperator( - hasOperatorName("*"), - hasUnaryOperand(ignoringParens(PlusOverPtrAndInteger))) - .bind(DerefOpTag)); - // clang-format on + BaseDeclRefExpr(Result.getNodeAs(BaseDeclRefExprTag)), + DerefOp(Result.getNodeAs(DerefOpTag)), + AddOp(Result.getNodeAs(AddOpTag)), + Offset(Result.getNodeAs(OffsetTag)) {} + + static bool matches(const Stmt *S, + llvm::SmallVectorImpl &Results) { + auto IsPtr = [](const Expr *E, MatchResult &R) { + if (!E || !hasPointerType(*E)) + return false; + const auto *DRE = dyn_cast(E->IgnoreImpCasts()); + if (!DRE || !isSupportedVariable(*DRE)) + return false; + R.addNode(BaseDeclRefExprTag, DynTypedNode::create(*DRE)); + return true; + }; + const auto IsPlusOverPtrAndInteger = [&IsPtr](const Expr *E, + MatchResult &R) { + const auto *BO = dyn_cast(E); + if (!BO || BO->getOpcode() != BO_Add) + return false; + + const auto *LHS = BO->getLHS(); + const auto *RHS = BO->getRHS(); + if (isa(RHS) && IsPtr(LHS, R)) { + R.addNode(OffsetTag, DynTypedNode::create(*RHS)); + R.addNode(AddOpTag, DynTypedNode::create(*BO)); + return true; + } + if (isa(LHS) && IsPtr(RHS, R)) { + R.addNode(OffsetTag, DynTypedNode::create(*LHS)); + R.addNode(AddOpTag, DynTypedNode::create(*BO)); + return true; + } + return false; + }; + size_t SizeBefore = Results.size(); + const auto InnerMatcher = [&IsPlusOverPtrAndInteger, + &Results](const Expr *E) { + const auto *UO = dyn_cast(E); + if (!UO || UO->getOpcode() != UO_Deref) + return; + + const auto *Operand = UO->getSubExpr()->IgnoreParens(); + MatchResult R; + if (IsPlusOverPtrAndInteger(Operand, R)) { + R.addNode(DerefOpTag, DynTypedNode::create(*UO)); + Results.emplace_back(std::move(R)); + } + }; + findStmtsInUnspecifiedLvalueContext(S, InnerMatcher); + return SizeBefore != Results.size(); } virtual std::optional @@ -2908,11 +3269,10 @@ class CountAttributedPointerArgumentGadget : public WarningGadget { const Expr *Arg; public: - explicit CountAttributedPointerArgumentGadget( - const MatchFinder::MatchResult &Result) + explicit CountAttributedPointerArgumentGadget(const MatchResult &Result) : WarningGadget(Kind::CountAttributedPointerArgument), - Call(Result.Nodes.getNodeAs(CallTag)), - Arg(Result.Nodes.getNodeAs(ArgTag)) { + Call(Result.getNodeAs(CallTag)), + Arg(Result.getNodeAs(ArgTag)) { assert(Call != nullptr && "Expecting a non-null matching result"); assert(Arg != nullptr && "Expecting a non-null matching result"); } @@ -2921,10 +3281,13 @@ class CountAttributedPointerArgumentGadget : public WarningGadget { return G->getKind() == Kind::CountAttributedPointerArgument; } - static Matcher matcher() { - return stmt(callExpr(forEachUnsafeCountAttributedPointerArgument( - expr().bind(ArgTag))) - .bind(CallTag)); + static bool matches(const Stmt *Stmt, ASTContext &Ctx, + MatchResults &Results) { + if (const auto *Call = dyn_cast(Stmt)) + if (forEachUnsafeCountAttributedPointerArgument(Call, Results, ArgTag, + CallTag, Ctx)) + return true; + return false; } void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler, @@ -2951,9 +3314,9 @@ class SinglePointerArgumentGadget : public WarningGadget { const Expr *Arg; public: - explicit SinglePointerArgumentGadget(const MatchFinder::MatchResult &Result) + explicit SinglePointerArgumentGadget(const MatchResult &Result) : WarningGadget(Kind::SinglePointerArgument), - Arg(Result.Nodes.getNodeAs(ArgTag)) { + Arg(Result.getNodeAs(ArgTag)) { assert(Arg != nullptr && "Expecting a non-null matching result"); } @@ -2961,10 +3324,20 @@ class SinglePointerArgumentGadget : public WarningGadget { return G->getKind() == Kind::SinglePointerArgument; } - static Matcher matcher() { - return stmt(callExpr(forEachArgumentWithParamType( - expr(unless(isSinglePointerArgumentSafe())).bind(ArgTag), - isSinglePointerType()))); + static bool matches(const Stmt *Stmt, ASTContext &Ctx, + MatchResults &Results) { + bool Found = false; + + if (const auto *Call = dyn_cast(Stmt)) { + forEachArgumentWithParamType(*Call, [&Results, &Ctx, &Found]( + QualType QT, const Expr *Arg) { + if (isSinglePointerType(QT) && !isSinglePointerArgumentSafe(Ctx, Arg)) { + Results.emplace_back(ArgTag, DynTypedNode::create(*Arg)); + Found = true; + } + }); + } + return Found; } void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler, @@ -2984,111 +3357,122 @@ class SinglePointerArgumentGadget : public WarningGadget { }; /// Scan the function and return a list of gadgets found with provided kits. -static void findGadgets(const Stmt *S, ASTContext &Ctx, - const UnsafeBufferUsageHandler &Handler, - bool EmitSuggestions, FixableGadgetList &FixableGadgets, - WarningGadgetList &WarningGadgets, - DeclUseTracker &Tracker) { - - struct GadgetFinderCallback : MatchFinder::MatchCallback { - GadgetFinderCallback(FixableGadgetList &FixableGadgets, - WarningGadgetList &WarningGadgets, - DeclUseTracker &Tracker) - : FixableGadgets(FixableGadgets), WarningGadgets(WarningGadgets), - Tracker(Tracker) {} - - void run(const MatchFinder::MatchResult &Result) override { - // In debug mode, assert that we've found exactly one gadget. - // This helps us avoid conflicts in .bind() tags. -#if NDEBUG -#define NEXT return -#else - [[maybe_unused]] int numFound = 0; -#define NEXT ++numFound -#endif +class WarningGadgetMatcher : public FastMatcher { - if (const auto *DRE = Result.Nodes.getNodeAs("any_dre")) { - Tracker.discoverUse(DRE); - NEXT; - } +public: + WarningGadgetMatcher(WarningGadgetList &WarningGadgets) + : WarningGadgets(WarningGadgets) {} - if (const auto *DS = Result.Nodes.getNodeAs("any_ds")) { - Tracker.discoverDecl(DS); - NEXT; - } + bool matches(const DynTypedNode &DynNode, ASTContext &Ctx, + const UnsafeBufferUsageHandler &Handler) override { + const Stmt *S = DynNode.get(); + if (!S) + return false; - // Figure out which matcher we've found, and call the appropriate - // subclass constructor. - // FIXME: Can we do this more logarithmically? -#define FIXABLE_GADGET(name) \ - if (Result.Nodes.getNodeAs(#name)) { \ - FixableGadgets.push_back(std::make_unique(Result)); \ - NEXT; \ - } -#include "clang/Analysis/Analyses/UnsafeBufferUsageGadgets.def" + MatchResult Result; #define WARNING_GADGET(name) \ - if (Result.Nodes.getNodeAs(#name)) { \ + if (name##Gadget::matches(S, Ctx, Result) && \ + notInSafeBufferOptOut(*S, &Handler)) { \ + WarningGadgets.push_back(std::make_unique(Result)); \ + return true; \ + } +#define WARNING_BOUNDS_SAFETY_GADGET(name) \ + { \ + MatchResults MultiResults; \ + if (name##Gadget::matches(S, Ctx, MultiResults) && \ + notInSafeBufferOptOut(*S, &Handler)) { \ + for (auto &Result : MultiResults) \ + WarningGadgets.push_back(std::make_unique(Result)); \ + return true; \ + } \ + } +#define WARNING_OPTIONAL_GADGET(name) \ + if (name##Gadget::matches(S, Ctx, &Handler, Result) && \ + notInSafeBufferOptOut(*S, &Handler)) { \ WarningGadgets.push_back(std::make_unique(Result)); \ - NEXT; \ + return true; \ } #include "clang/Analysis/Analyses/UnsafeBufferUsageGadgets.def" + return false; + } - assert(numFound >= 1 && "Gadgets not found in match result!"); - assert(numFound <= 1 && "Conflicting bind tags in gadgets!"); - } +private: + WarningGadgetList &WarningGadgets; +}; - FixableGadgetList &FixableGadgets; - WarningGadgetList &WarningGadgets; - DeclUseTracker &Tracker; - }; +class FixableGadgetMatcher : public FastMatcher { - MatchFinder M; - GadgetFinderCallback CB{FixableGadgets, WarningGadgets, Tracker}; - - // clang-format off - M.addMatcher( - stmt( - forEachDescendantEvaluatedStmt(stmt(anyOf( - // Add Gadget::matcher() for every gadget in the registry. -#define WARNING_GADGET(x) \ - allOf(x ## Gadget::matcher().bind(#x), \ - notInSafeBufferOptOut(&Handler)), -#define WARNING_OPTIONAL_GADGET(x) \ - allOf(x ## Gadget::matcher(&Handler).bind(#x), \ - notInSafeBufferOptOut(&Handler)), +public: + FixableGadgetMatcher(FixableGadgetList &FixableGadgets, + DeclUseTracker &Tracker) + : FixableGadgets(FixableGadgets), Tracker(Tracker) {} + + bool matches(const DynTypedNode &DynNode, ASTContext &Ctx, + const UnsafeBufferUsageHandler &Handler) override { + bool matchFound = false; + const Stmt *S = DynNode.get(); + if (!S) { + return matchFound; + } + + llvm::SmallVector Results; +#define FIXABLE_GADGET(name) \ + if (name##Gadget::matches(S, Results)) { \ + for (const auto &R : Results) { \ + FixableGadgets.push_back(std::make_unique(R)); \ + matchFound = true; \ + } \ + Results = {}; \ + } #include "clang/Analysis/Analyses/UnsafeBufferUsageGadgets.def" - // Avoid a hanging comma. - unless(stmt()) - ))) - ), - &CB - ); - // clang-format on + // In parallel, match all DeclRefExprs so that to find out + // whether there are any uncovered by gadgets. + if (auto *DRE = findDeclRefExpr(S); DRE) { + Tracker.discoverUse(DRE); + matchFound = true; + } + // Also match DeclStmts because we'll need them when fixing + // their underlying VarDecls that otherwise don't have + // any backreferences to DeclStmts. + if (auto *DS = findDeclStmt(S); DS) { + Tracker.discoverDecl(DS); + matchFound = true; + } + return matchFound; + } +private: + const DeclRefExpr *findDeclRefExpr(const Stmt *S) { + const auto *DRE = dyn_cast(S); + if (!DRE || (!hasPointerType(*DRE) && !hasArrayType(*DRE))) + return nullptr; + const Decl *D = DRE->getDecl(); + if (!D || (!isa(D) && !isa(D))) + return nullptr; + return DRE; + } + const DeclStmt *findDeclStmt(const Stmt *S) { + const auto *DS = dyn_cast(S); + if (!DS) + return nullptr; + return DS; + } + FixableGadgetList &FixableGadgets; + DeclUseTracker &Tracker; +}; + +// Scan the function and return a list of gadgets found with provided kits. +static void findGadgets(const Stmt *S, ASTContext &Ctx, + const UnsafeBufferUsageHandler &Handler, + bool EmitSuggestions, FixableGadgetList &FixableGadgets, + WarningGadgetList &WarningGadgets, + DeclUseTracker &Tracker) { + WarningGadgetMatcher WMatcher{WarningGadgets}; + forEachDescendantEvaluatedStmt(S, Ctx, Handler, WMatcher); if (EmitSuggestions) { - // clang-format off - M.addMatcher( - stmt( - forEachDescendantStmt(stmt(eachOf( -#define FIXABLE_GADGET(x) \ - x ## Gadget::matcher().bind(#x), -#include "clang/Analysis/Analyses/UnsafeBufferUsageGadgets.def" - // In parallel, match all DeclRefExprs so that to find out - // whether there are any uncovered by gadgets. - declRefExpr(anyOf(hasPointerType(), hasArrayType()), - to(anyOf(varDecl(), bindingDecl()))).bind("any_dre"), - // Also match DeclStmts because we'll need them when fixing - // their underlying VarDecls that otherwise don't have - // any backreferences to DeclStmts. - declStmt().bind("any_ds") - ))) - ), - &CB - ); - // clang-format on - } - - M.match(*S, Ctx); + FixableGadgetMatcher FMatcher{FixableGadgets, Tracker}; + forEachDescendantStmt(S, Ctx, Handler, FMatcher); + } } // Compares AST nodes by source locations. @@ -4634,9 +5018,11 @@ class VariableGroupsManagerImpl : public VariableGroupsManager { } }; -void applyGadgets(const Decl *D, FixableGadgetList FixableGadgets, - WarningGadgetList WarningGadgets, DeclUseTracker Tracker, - UnsafeBufferUsageHandler &Handler, bool EmitSuggestions) { +static void applyGadgets(const Decl *D, FixableGadgetList FixableGadgets, + WarningGadgetList WarningGadgets, + DeclUseTracker Tracker, + UnsafeBufferUsageHandler &Handler, + bool EmitSuggestions) { if (!EmitSuggestions) { // Our job is very easy without suggestions. Just warn about // every problematic operation and consider it done. No need to deal @@ -4648,7 +5034,7 @@ void applyGadgets(const Decl *D, FixableGadgetList FixableGadgets, // This return guarantees that most of the machine doesn't run when // suggestions aren't requested. - assert(FixableGadgets.size() == 0 && + assert(FixableGadgets.empty() && "Fixable gadgets found but suggestions not requested!"); return; } @@ -4747,7 +5133,7 @@ void applyGadgets(const Decl *D, FixableGadgetList FixableGadgets, DepMapTy DependenciesMap{}; DepMapTy PtrAssignmentGraph{}; - for (auto it : FixablesForAllVars.byVar) { + for (const auto &it : FixablesForAllVars.byVar) { for (const FixableGadget *fixable : it.second) { std::optional> ImplPair = fixable->getStrategyImplications();