@@ -137,6 +137,19 @@ static auto valueOperatorCall() {
137137 isStatusOrOperatorCallWithName (" ->" )));
138138}
139139
140+ static clang::ast_matchers::TypeMatcher statusType () {
141+ using namespace ::clang::ast_matchers; // NOLINT: Too many names
142+ return hasCanonicalType (qualType (hasDeclaration (statusClass ())));
143+ }
144+
145+ static auto isComparisonOperatorCall (llvm::StringRef operator_name) {
146+ using namespace ::clang::ast_matchers; // NOLINT: Too many names
147+ return cxxOperatorCallExpr (
148+ hasOverloadedOperatorName (operator_name), argumentCountIs (2 ),
149+ hasArgument (0 , anyOf (hasType (statusType ()), hasType (statusOrType ()))),
150+ hasArgument (1 , anyOf (hasType (statusType ()), hasType (statusOrType ()))));
151+ }
152+
140153static auto
141154buildDiagnoseMatchSwitch (const UncheckedStatusOrAccessModelOptions &Options) {
142155 return CFGMatchSwitchBuilder<const Environment,
@@ -312,6 +325,101 @@ static void transferStatusUpdateCall(const CXXMemberCallExpr *Expr,
312325 State.Env .setValue (locForOk (*ThisLoc), NewVal);
313326}
314327
328+ static BoolValue *evaluateStatusEquality (RecordStorageLocation &LhsStatusLoc,
329+ RecordStorageLocation &RhsStatusLoc,
330+ Environment &Env) {
331+ auto &A = Env.arena ();
332+ // Logically, a Status object is composed of an error code that could take one
333+ // of multiple possible values, including the "ok" value. We track whether a
334+ // Status object has an "ok" value and represent this as an `ok` bit. Equality
335+ // of Status objects compares their error codes. Therefore, merely comparing
336+ // the `ok` bits isn't sufficient: when two Status objects are assigned non-ok
337+ // error codes the equality of their respective error codes matters. Since we
338+ // only track the `ok` bits, we can't make any conclusions about equality when
339+ // we know that two Status objects have non-ok values.
340+
341+ auto &LhsOkVal = valForOk (LhsStatusLoc, Env);
342+ auto &RhsOkVal = valForOk (RhsStatusLoc, Env);
343+
344+ auto &Res = Env.makeAtomicBoolValue ();
345+
346+ // lhs && rhs => res (a.k.a. !res => !lhs || !rhs)
347+ Env.assume (A.makeImplies (A.makeAnd (LhsOkVal.formula (), RhsOkVal.formula ()),
348+ Res.formula ()));
349+ // res => (lhs == rhs)
350+ Env.assume (A.makeImplies (
351+ Res.formula (), A.makeEquals (LhsOkVal.formula (), RhsOkVal.formula ())));
352+
353+ return &Res;
354+ }
355+
356+ static BoolValue *
357+ evaluateStatusOrEquality (RecordStorageLocation &LhsStatusOrLoc,
358+ RecordStorageLocation &RhsStatusOrLoc,
359+ Environment &Env) {
360+ auto &A = Env.arena ();
361+ // Logically, a StatusOr<T> object is composed of two values - a Status and a
362+ // value of type T. Equality of StatusOr objects compares both values.
363+ // Therefore, merely comparing the `ok` bits of the Status values isn't
364+ // sufficient. When two StatusOr objects are engaged, the equality of their
365+ // respective values of type T matters. Similarly, when two StatusOr objects
366+ // have Status values that have non-ok error codes, the equality of the error
367+ // codes matters. Since we only track the `ok` bits of the Status values, we
368+ // can't make any conclusions about equality when we know that two StatusOr
369+ // objects are engaged or when their Status values contain non-ok error codes.
370+ auto &LhsOkVal = valForOk (locForStatus (LhsStatusOrLoc), Env);
371+ auto &RhsOkVal = valForOk (locForStatus (RhsStatusOrLoc), Env);
372+ auto &res = Env.makeAtomicBoolValue ();
373+
374+ // res => (lhs == rhs)
375+ Env.assume (A.makeImplies (
376+ res.formula (), A.makeEquals (LhsOkVal.formula (), RhsOkVal.formula ())));
377+ return &res;
378+ }
379+
380+ static BoolValue *evaluateEquality (const Expr *LhsExpr, const Expr *RhsExpr,
381+ Environment &Env) {
382+ // Check the type of both sides in case an operator== is added that admits
383+ // different types.
384+ if (isStatusOrType (LhsExpr->getType ()) &&
385+ isStatusOrType (RhsExpr->getType ())) {
386+ auto *LhsStatusOrLoc = Env.get <RecordStorageLocation>(*LhsExpr);
387+ if (LhsStatusOrLoc == nullptr )
388+ return nullptr ;
389+ auto *RhsStatusOrLoc = Env.get <RecordStorageLocation>(*RhsExpr);
390+ if (RhsStatusOrLoc == nullptr )
391+ return nullptr ;
392+
393+ return evaluateStatusOrEquality (*LhsStatusOrLoc, *RhsStatusOrLoc, Env);
394+ }
395+ if (isStatusType (LhsExpr->getType ()) && isStatusType (RhsExpr->getType ())) {
396+ auto *LhsStatusLoc = Env.get <RecordStorageLocation>(*LhsExpr);
397+ if (LhsStatusLoc == nullptr )
398+ return nullptr ;
399+
400+ auto *RhsStatusLoc = Env.get <RecordStorageLocation>(*RhsExpr);
401+ if (RhsStatusLoc == nullptr )
402+ return nullptr ;
403+
404+ return evaluateStatusEquality (*LhsStatusLoc, *RhsStatusLoc, Env);
405+ }
406+ return nullptr ;
407+ }
408+
409+ static void transferComparisonOperator (const CXXOperatorCallExpr *Expr,
410+ LatticeTransferState &State,
411+ bool IsNegative) {
412+ auto *LhsAndRhsVal =
413+ evaluateEquality (Expr->getArg (0 ), Expr->getArg (1 ), State.Env );
414+ if (LhsAndRhsVal == nullptr )
415+ return ;
416+
417+ if (IsNegative)
418+ State.Env .setValue (*Expr, State.Env .makeNot (*LhsAndRhsVal));
419+ else
420+ State.Env .setValue (*Expr, *LhsAndRhsVal);
421+ }
422+
315423CFGMatchSwitch<LatticeTransferState>
316424buildTransferMatchSwitch (ASTContext &Ctx,
317425 CFGMatchSwitchBuilder<LatticeTransferState> Builder) {
@@ -325,6 +433,20 @@ buildTransferMatchSwitch(ASTContext &Ctx,
325433 transferStatusOkCall)
326434 .CaseOfCFGStmt <CXXMemberCallExpr>(isStatusMemberCallWithName (" Update" ),
327435 transferStatusUpdateCall)
436+ .CaseOfCFGStmt <CXXOperatorCallExpr>(
437+ isComparisonOperatorCall (" ==" ),
438+ [](const CXXOperatorCallExpr *Expr, const MatchFinder::MatchResult &,
439+ LatticeTransferState &State) {
440+ transferComparisonOperator (Expr, State,
441+ /* IsNegative=*/ false );
442+ })
443+ .CaseOfCFGStmt <CXXOperatorCallExpr>(
444+ isComparisonOperatorCall (" !=" ),
445+ [](const CXXOperatorCallExpr *Expr, const MatchFinder::MatchResult &,
446+ LatticeTransferState &State) {
447+ transferComparisonOperator (Expr, State,
448+ /* IsNegative=*/ true );
449+ })
328450 .Build ();
329451}
330452
0 commit comments