Skip to content

[Clang] Make enums trivially equality comparable #133587

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 48 additions & 35 deletions clang/lib/Sema/SemaExprCXX.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5174,6 +5174,43 @@ static bool HasNoThrowOperator(const RecordType *RT, OverloadedOperatorKind Op,
return false;
}

static bool EqualityComparisonIsDefaulted(Sema &S, const TypeDecl *Decl,
SourceLocation KeyLoc) {
EnterExpressionEvaluationContext UnevaluatedContext(
S, Sema::ExpressionEvaluationContext::Unevaluated);
Sema::SFINAETrap SFINAE(S, /*AccessCheckingSFINAE=*/true);
Sema::ContextRAII TUContext(S, S.Context.getTranslationUnitDecl());

// const ClassT& obj;
OpaqueValueExpr Operand(
KeyLoc, Decl->getTypeForDecl()->getCanonicalTypeUnqualified().withConst(),
ExprValueKind::VK_LValue);
UnresolvedSet<16> Functions;
// obj == obj;
S.LookupBinOp(S.TUScope, {}, BinaryOperatorKind::BO_EQ, Functions);

auto Result = S.CreateOverloadedBinOp(KeyLoc, BinaryOperatorKind::BO_EQ,
Functions, &Operand, &Operand);
if (Result.isInvalid() || SFINAE.hasErrorOccurred())
return false;

const auto *CallExpr = dyn_cast<CXXOperatorCallExpr>(Result.get());
if (!CallExpr)
return isa<EnumDecl>(Decl);
const auto *Callee = CallExpr->getDirectCallee();
auto ParamT = Callee->getParamDecl(0)->getType();
if (!Callee->isDefaulted())
return false;
if (!ParamT->isReferenceType()) {
if (const CXXRecordDecl * RD = dyn_cast<CXXRecordDecl>(Decl); !RD->isTriviallyCopyable())
return false;
}
if (ParamT.getNonReferenceType()->getUnqualifiedDesugaredType() !=
Decl->getTypeForDecl())
return false;
return true;
}

static bool HasNonDeletedDefaultedEqualityComparison(Sema &S,
const CXXRecordDecl *Decl,
SourceLocation KeyLoc) {
Expand All @@ -5182,39 +5219,8 @@ static bool HasNonDeletedDefaultedEqualityComparison(Sema &S,
if (Decl->isLambda())
return Decl->isCapturelessLambda();

{
EnterExpressionEvaluationContext UnevaluatedContext(
S, Sema::ExpressionEvaluationContext::Unevaluated);
Sema::SFINAETrap SFINAE(S, /*AccessCheckingSFINAE=*/true);
Sema::ContextRAII TUContext(S, S.Context.getTranslationUnitDecl());

// const ClassT& obj;
OpaqueValueExpr Operand(
KeyLoc,
Decl->getTypeForDecl()->getCanonicalTypeUnqualified().withConst(),
ExprValueKind::VK_LValue);
UnresolvedSet<16> Functions;
// obj == obj;
S.LookupBinOp(S.TUScope, {}, BinaryOperatorKind::BO_EQ, Functions);

auto Result = S.CreateOverloadedBinOp(KeyLoc, BinaryOperatorKind::BO_EQ,
Functions, &Operand, &Operand);
if (Result.isInvalid() || SFINAE.hasErrorOccurred())
return false;

const auto *CallExpr = dyn_cast<CXXOperatorCallExpr>(Result.get());
if (!CallExpr)
return false;
const auto *Callee = CallExpr->getDirectCallee();
auto ParamT = Callee->getParamDecl(0)->getType();
if (!Callee->isDefaulted())
return false;
if (!ParamT->isReferenceType() && !Decl->isTriviallyCopyable())
return false;
if (ParamT.getNonReferenceType()->getUnqualifiedDesugaredType() !=
Decl->getTypeForDecl())
return false;
}
if (!EqualityComparisonIsDefaulted(S, Decl, KeyLoc))
return false;

return llvm::all_of(Decl->bases(),
[&](const CXXBaseSpecifier &BS) {
Expand All @@ -5229,7 +5235,10 @@ static bool HasNonDeletedDefaultedEqualityComparison(Sema &S,
Type = Type->getBaseElementTypeUnsafe()
->getCanonicalTypeUnqualified();

if (Type->isReferenceType() || Type->isEnumeralType())
if (Type->isReferenceType() ||
(Type->isEnumeralType() &&
!EqualityComparisonIsDefaulted(
S, cast<EnumDecl>(Type->getAsTagDecl()), KeyLoc)))
return false;
if (const auto *RD = Type->getAsCXXRecordDecl())
return HasNonDeletedDefaultedEqualityComparison(S, RD, KeyLoc);
Expand All @@ -5240,9 +5249,13 @@ static bool HasNonDeletedDefaultedEqualityComparison(Sema &S,
static bool isTriviallyEqualityComparableType(Sema &S, QualType Type, SourceLocation KeyLoc) {
QualType CanonicalType = Type.getCanonicalType();
if (CanonicalType->isIncompleteType() || CanonicalType->isDependentType() ||
CanonicalType->isEnumeralType() || CanonicalType->isArrayType())
CanonicalType->isArrayType())
return false;

if (CanonicalType->isEnumeralType())
return EqualityComparisonIsDefaulted(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like you are currently only testing if it is defaulted, we should be testing all code paths if possible to insure this does what we expect and to prevent future regressions.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure what you're referring to. If you mean I should test a non-trivially equality comparable enum, that's already there.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am probably being dense here but I thought for the both the test cases below EqualityComparisonIsDefaulted would return true so I was asking if there is a test case in which it returns false.

S, cast<EnumDecl>(CanonicalType->getAsTagDecl()), KeyLoc);

if (const auto *RD = CanonicalType->getAsCXXRecordDecl()) {
if (!HasNonDeletedDefaultedEqualityComparison(S, RD, KeyLoc))
return false;
Expand Down
12 changes: 12 additions & 0 deletions clang/test/SemaCXX/type-traits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3873,6 +3873,11 @@ static_assert(!__is_trivially_equality_comparable(NonTriviallyEqualityComparable

#if __cplusplus >= 202002L

enum TriviallyEqualityComparableEnum {
x, y
};
static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableEnum));

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is very nice, thank you!

At line 3853 add:

static_assert(__is_trivially_equality_comparable(Enum));
static_assert(__is_trivially_equality_comparable(SignedEnum));
static_assert(__is_trivially_equality_comparable(UnsignedEnum));
static_assert(__is_trivially_equality_comparable(EnumClass));
static_assert(__is_trivially_equality_comparable(SignedEnumClass));
static_assert(__is_trivially_equality_comparable(UnsignedEnumClass));

It would be beneficial also to test that the compiler does not crash in this case:

enum E { e };
static_assert(__is_trivially_equality_comparable(E));
bool operator==(E, E);
static_assert(!__is_trivially_equality_comparable(E));

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not convinced that it's a good idea to test behaviour that's UB. AFAICT we don't have a test like this for any other trait.

struct TriviallyEqualityComparable {
int i;
int j;
Expand All @@ -3891,6 +3896,13 @@ struct TriviallyEqualityComparableContainsArray {
};
static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableContainsArray));

struct TriviallyEqualityComparableContainsEnum {
TriviallyEqualityComparableEnum e;

bool operator==(const TriviallyEqualityComparableContainsEnum&) const = default;
};
static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableContainsEnum));

struct TriviallyEqualityComparableContainsMultiDimensionArray {
int a[4][4];

Expand Down
Loading