From 0e7bdce1bf13b0146873f3fc6bac640b749a74eb Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Tue, 27 May 2025 17:05:33 +0200 Subject: [PATCH 1/2] Fix: make unsupported comparisons return `NotImplemented` Make all comparator magic methods return `NotImplemented` instead of `False` (or raising `TypeError` in some instances) if the other operand is not of a supported type. This means that when comparing a driver type with another type is doesn't support, the other type get the chance to handle the comparison. Affected types: * `neo4j.Record` * `neo4j.graph.Node`, `neo4j.graph.Relationship`, `neo4j.graph.Path` * `neo4j.time.Date`, `neo4j.time.Time`, `neo4j.time.DateTime` * `neo4j.spatial.Point` (and subclasses) --- CHANGELOG.md | 9 ++ .../_codec/packstream/_python/_common.py | 5 +- src/neo4j/_data.py | 5 +- src/neo4j/_io/__init__.py | 8 +- src/neo4j/api.py | 2 +- src/neo4j/graph/__init__.py | 12 +- src/neo4j/spatial/__init__.py | 5 +- src/neo4j/time/__init__.py | 152 ++++++++---------- 8 files changed, 84 insertions(+), 114 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 99a018564..76912d1cb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -122,6 +122,15 @@ See also https://github.com/neo4j/neo4j-python-driver/wiki for a full changelog. should be treated as immutable. - Graph type sets (`neo4j.graph.EntitySetView`) can no longer by indexed by legacy `id` (`int`, e.g., `graph.nodes[0]`). Use the `element_id` instead (`str`, e.g., `graph.nodes["..."]`). +- Make all comparator magic methods return `NotImplemented` instead of `False` (or raising `TypeError` in some + instances) if the other operand is not of a supported type. + This means that when comparing a driver type with another type is doesn't support, the other type get the chance to + handle the comparison. + Affected types: + - `neo4j.Record` + - `neo4j.graph.Node`, `neo4j.graph.Relationship`, `neo4j.graph.Path` + - `neo4j.time.Date`, `neo4j.time.Time`, `neo4j.time.DateTime` + - `neo4j.spatial.Point` (and subclasses) ## Version 5.28 diff --git a/src/neo4j/_codec/packstream/_python/_common.py b/src/neo4j/_codec/packstream/_python/_common.py index 3cb230838..f2fa6bb01 100644 --- a/src/neo4j/_codec/packstream/_python/_common.py +++ b/src/neo4j/_codec/packstream/_python/_common.py @@ -28,10 +28,7 @@ def __eq__(self, other): try: return self.tag == other.tag and self.fields == other.fields except AttributeError: - return False - - def __ne__(self, other): - return not self.__eq__(other) + return NotImplementedError def __len__(self): return len(self.fields) diff --git a/src/neo4j/_data.py b/src/neo4j/_data.py index be961ddb5..2e1b622c5 100644 --- a/src/neo4j/_data.py +++ b/src/neo4j/_data.py @@ -115,10 +115,7 @@ def __eq__(self, other: object) -> bool: other = t.cast(t.Mapping, other) return dict(self) == dict(other) else: - return False - - def __ne__(self, other: object) -> bool: - return not self.__eq__(other) + return NotImplemented def __hash__(self): return reduce(xor_operator, map(hash, self.items())) diff --git a/src/neo4j/_io/__init__.py b/src/neo4j/_io/__init__.py index b5486a72e..f90680cea 100644 --- a/src/neo4j/_io/__init__.py +++ b/src/neo4j/_io/__init__.py @@ -61,28 +61,28 @@ def __ne__(self, other: object) -> bool: return self.version != other return NotImplemented - def __lt__(self, other: object) -> bool: + def __lt__(self, other: BoltProtocolVersion | tuple) -> bool: if isinstance(other, BoltProtocolVersion): return self.version < other.version if isinstance(other, tuple): return self.version < other return NotImplemented - def __le__(self, other: object) -> bool: + def __le__(self, other: BoltProtocolVersion | tuple) -> bool: if isinstance(other, BoltProtocolVersion): return self.version <= other.version if isinstance(other, tuple): return self.version <= other return NotImplemented - def __gt__(self, other: object) -> bool: + def __gt__(self, other: BoltProtocolVersion | tuple) -> bool: if isinstance(other, BoltProtocolVersion): return self.version > other.version if isinstance(other, tuple): return self.version > other return NotImplemented - def __ge__(self, other: object) -> bool: + def __ge__(self, other: BoltProtocolVersion | tuple) -> bool: if isinstance(other, BoltProtocolVersion): return self.version >= other.version if isinstance(other, tuple): diff --git a/src/neo4j/api.py b/src/neo4j/api.py index 10663f2b0..e2809054c 100644 --- a/src/neo4j/api.py +++ b/src/neo4j/api.py @@ -112,7 +112,7 @@ def __init__( if parameters: self.parameters = parameters - def __eq__(self, other: t.Any) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, Auth): return NotImplemented return vars(self) == vars(other) diff --git a/src/neo4j/graph/__init__.py b/src/neo4j/graph/__init__.py index 0de193275..279157325 100644 --- a/src/neo4j/graph/__init__.py +++ b/src/neo4j/graph/__init__.py @@ -116,7 +116,6 @@ def __init__( } def __eq__(self, other: t.Any) -> bool: - # TODO: 6.0 - return NotImplemented on type mismatch instead of False try: return ( type(self) is type(other) @@ -124,10 +123,7 @@ def __eq__(self, other: t.Any) -> bool: and self.element_id == other.element_id ) except AttributeError: - return False - - def __ne__(self, other: object) -> bool: - return not self.__eq__(other) + return NotImplemented def __hash__(self): return hash(self._element_id) @@ -324,17 +320,13 @@ def __repr__(self) -> str: ) def __eq__(self, other: t.Any) -> bool: - # TODO: 6.0 - return NotImplemented on type mismatch instead of False try: return ( self.start_node == other.start_node and self.relationships == other.relationships ) except AttributeError: - return False - - def __ne__(self, other: object) -> bool: - return not self.__eq__(other) + return NotImplemented def __hash__(self): value = hash(self._nodes[0]) diff --git a/src/neo4j/spatial/__init__.py b/src/neo4j/spatial/__init__.py index 0140528b7..0529d18bb 100644 --- a/src/neo4j/spatial/__init__.py +++ b/src/neo4j/spatial/__init__.py @@ -76,10 +76,7 @@ def __eq__(self, other: object) -> bool: _t.cast(Point, other) ) except (AttributeError, TypeError): - return False - - def __ne__(self, other: object) -> bool: - return not self.__eq__(other) + return NotImplemented def __hash__(self): return hash(type(self)) ^ hash(tuple(self)) diff --git a/src/neo4j/time/__init__.py b/src/neo4j/time/__init__.py index 362753a69..8b6e44e86 100644 --- a/src/neo4j/time/__init__.py +++ b/src/neo4j/time/__init__.py @@ -1188,60 +1188,31 @@ def __hash__(self): def __eq__(self, other: object) -> bool: """``==`` comparison with :class:`.Date` or :class:`datetime.date`.""" if not isinstance(other, (Date, date)): - # TODO: 6.0 - return NotImplemented for non-Date objects - # return NotImplemented - return False + return NotImplemented return self.toordinal() == other.toordinal() - def __ne__(self, other: object) -> bool: - """``!=`` comparison with :class:`.Date` or :class:`datetime.date`.""" - # TODO: 6.0 - return NotImplemented for non-Date objects - # if not isinstance(other, (Date, date)): - # return NotImplemented - return not self.__eq__(other) - def __lt__(self, other: Date | date) -> bool: """``<`` comparison with :class:`.Date` or :class:`datetime.date`.""" if not isinstance(other, (Date, date)): - # TODO: 6.0 - return NotImplemented for non-Date objects - # return NotImplemented - raise TypeError( - "'<' not supported between instances of 'Date' and " - f"{type(other).__name__!r}" - ) + return NotImplemented return self.toordinal() < other.toordinal() def __le__(self, other: Date | date) -> bool: """``<=`` comparison with :class:`.Date` or :class:`datetime.date`.""" if not isinstance(other, (Date, date)): - # TODO: 6.0 - return NotImplemented for non-Date objects - # return NotImplemented - raise TypeError( - "'<=' not supported between instances of 'Date' and " - f"{type(other).__name__!r}" - ) + return NotImplemented return self.toordinal() <= other.toordinal() def __ge__(self, other: Date | date) -> bool: """``>=`` comparison with :class:`.Date` or :class:`datetime.date`.""" if not isinstance(other, (Date, date)): - # TODO: 6.0 - return NotImplemented for non-Date objects - # return NotImplemented - raise TypeError( - "'>=' not supported between instances of 'Date' and " - f"{type(other).__name__!r}" - ) + return NotImplemented return self.toordinal() >= other.toordinal() def __gt__(self, other: Date | date) -> bool: """``>`` comparison with :class:`.Date` or :class:`datetime.date`.""" if not isinstance(other, (Date, date)): - # TODO: 6.0 - return NotImplemented for non-Date objects - # return NotImplemented - raise TypeError( - "'>' not supported between instances of 'Date' and " - f"{type(other).__name__!r}" - ) + return NotImplemented return self.toordinal() > other.toordinal() def __add__(self, other: Duration) -> Date: # type: ignore[override] @@ -1857,29 +1828,29 @@ def tzinfo(self) -> _tzinfo | None: # OPERATIONS # - def _get_both_normalized_ticks(self, other: object, strict=True): - if isinstance(other, (time, Time)) and ( - (self.utc_offset() is None) ^ (other.utcoffset() is None) - ): + def _get_both_normalized_ticks( + self, other: object, strict: bool = True + ) -> tuple[int, int] | None: + if not isinstance(other, (Time, time)): + return None + if (self.utc_offset() is None) ^ (other.utcoffset() is None): if strict: raise TypeError( "can't compare offset-naive and offset-aware times" ) else: - return None, None + return None other_ticks: int if isinstance(other, Time): other_ticks = other.__ticks - elif isinstance(other, time): + else: + assert isinstance(other, time) other_ticks = int( 3600000000000 * other.hour + 60000000000 * other.minute + NANO_SECONDS * other.second + 1000 * other.microsecond ) - else: - return None, None - assert isinstance(other, (Time, time)) utc_offset: timedelta | None = other.utcoffset() if utc_offset is not None: other_ticks -= int(utc_offset.total_seconds() * NANO_SECONDS) @@ -1899,43 +1870,52 @@ def __hash__(self): def __eq__(self, other: object) -> bool: """`==` comparison with :class:`.Time` or :class:`datetime.time`.""" - self_ticks, other_ticks = self._get_both_normalized_ticks( - other, strict=False - ) - if self_ticks is None: + if not isinstance(other, (Time, time)): + return NotImplemented + ticks = self._get_both_normalized_ticks(other, strict=False) + if ticks is None: return False + self_ticks, other_ticks = ticks return self_ticks == other_ticks - def __ne__(self, other: object) -> bool: - """`!=` comparison with :class:`.Time` or :class:`datetime.time`.""" - return not self.__eq__(other) - def __lt__(self, other: Time | time) -> bool: """`<` comparison with :class:`.Time` or :class:`datetime.time`.""" - self_ticks, other_ticks = self._get_both_normalized_ticks(other) - if self_ticks is None: + if not isinstance(other, (Time, time)): return NotImplemented + ticks = self._get_both_normalized_ticks(other) + if ticks is None: + return False + self_ticks, other_ticks = ticks return self_ticks < other_ticks def __le__(self, other: Time | time) -> bool: """`<=` comparison with :class:`.Time` or :class:`datetime.time`.""" - self_ticks, other_ticks = self._get_both_normalized_ticks(other) - if self_ticks is None: + if not isinstance(other, (Time, time)): return NotImplemented + ticks = self._get_both_normalized_ticks(other) + if ticks is None: + return False + self_ticks, other_ticks = ticks return self_ticks <= other_ticks def __ge__(self, other: Time | time) -> bool: """`>=` comparison with :class:`.Time` or :class:`datetime.time`.""" - self_ticks, other_ticks = self._get_both_normalized_ticks(other) - if self_ticks is None: + if not isinstance(other, (Time, time)): return NotImplemented + ticks = self._get_both_normalized_ticks(other) + if ticks is None: + return False + self_ticks, other_ticks = ticks return self_ticks >= other_ticks def __gt__(self, other: Time | time) -> bool: """`>` comparison with :class:`.Time` or :class:`datetime.time`.""" - self_ticks, other_ticks = self._get_both_normalized_ticks(other) - if self_ticks is None: + if not isinstance(other, (Time, time)): return NotImplemented + ticks = self._get_both_normalized_ticks(other) + if ticks is None: + return False + self_ticks, other_ticks = ticks return self_ticks > other_ticks # INSTANCE METHODS # @@ -2510,29 +2490,28 @@ def hour_minute_second_nanosecond(self) -> tuple[int, int, int, int]: # OPERATIONS # - def _get_both_normalized(self, other, strict=True): - if isinstance(other, (datetime, DateTime)) and ( - (self.utc_offset() is None) ^ (other.utcoffset() is None) - ): + def _get_both_normalized( + self, other: object, strict: bool = True + ) -> tuple[DateTime, DateTime | datetime] | None: + if not isinstance(other, (datetime, DateTime)): + return None + if (self.utc_offset() is None) ^ (other.utcoffset() is None): if strict: raise TypeError( "can't compare offset-naive and offset-aware datetimes" ) else: - return None, None + return None self_norm = self utc_offset = self.utc_offset() if utc_offset is not None: self_norm -= utc_offset self_norm = self_norm.replace(tzinfo=None) other_norm = other - if isinstance(other, (datetime, DateTime)): - utc_offset = other.utcoffset() - if utc_offset is not None: - other_norm -= utc_offset - other_norm = other_norm.replace(tzinfo=None) - else: - return None, None + utc_offset = other.utcoffset() + if utc_offset is not None: + other_norm -= utc_offset + other_norm = other_norm.replace(tzinfo=None) return self_norm, other_norm def __hash__(self): @@ -2554,21 +2533,12 @@ def __eq__(self, other: object) -> bool: return NotImplemented if self.utc_offset() == other.utcoffset(): return self.date() == other.date() and self.time() == other.time() - self_norm, other_norm = self._get_both_normalized(other, strict=False) - if self_norm is None: + normalized = self._get_both_normalized(other, strict=False) + if normalized is None: return False + self_norm, other_norm = normalized return self_norm == other_norm - def __ne__(self, other: object) -> bool: - """ - ``!=`` comparison with another datetime. - - Accepts :class:`.DateTime` and :class:`datetime.datetime`. - """ - if not isinstance(other, (DateTime, datetime)): - return NotImplemented - return not self.__eq__(other) - def __lt__( # type: ignore[override] self, other: datetime | DateTime ) -> bool: @@ -2583,7 +2553,9 @@ def __lt__( # type: ignore[override] if self.date() == other.date(): return self.time() < other.time() return self.date() < other.date() - self_norm, other_norm = self._get_both_normalized(other) + normalized = self._get_both_normalized(other) + assert normalized is not None, "checked for correct type above" + self_norm, other_norm = normalized return ( self_norm.date() < other_norm.date() or self_norm.time() < other_norm.time() @@ -2603,7 +2575,9 @@ def __le__( # type: ignore[override] if self.date() == other.date(): return self.time() <= other.time() return self.date() <= other.date() - self_norm, other_norm = self._get_both_normalized(other) + normalized = self._get_both_normalized(other) + assert normalized is not None, "checked for correct type above" + self_norm, other_norm = normalized return self_norm <= other_norm def __ge__( # type: ignore[override] @@ -2620,7 +2594,9 @@ def __ge__( # type: ignore[override] if self.date() == other.date(): return self.time() >= other.time() return self.date() >= other.date() - self_norm, other_norm = self._get_both_normalized(other) + normalized = self._get_both_normalized(other) + assert normalized is not None, "checked for correct type above" + self_norm, other_norm = normalized return self_norm >= other_norm def __gt__( # type: ignore[override] @@ -2637,7 +2613,9 @@ def __gt__( # type: ignore[override] if self.date() == other.date(): return self.time() > other.time() return self.date() > other.date() - self_norm, other_norm = self._get_both_normalized(other) + normalized = self._get_both_normalized(other) + assert normalized is not None, "checked for correct type above" + self_norm, other_norm = normalized return ( self_norm.date() > other_norm.date() or self_norm.time() > other_norm.time() From 9358a36177673999206e30c5dd44121471806a1c Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Fri, 30 May 2025 09:57:42 +0200 Subject: [PATCH 2/2] Simplify code --- src/neo4j/time/__init__.py | 64 ++++++++++++------------- tests/unit/common/time/test_datetime.py | 56 ++++++++++++++++++++++ tests/unit/common/time/test_time.py | 56 ++++++++++++++++++++++ 3 files changed, 142 insertions(+), 34 deletions(-) diff --git a/src/neo4j/time/__init__.py b/src/neo4j/time/__init__.py index 7e26bfec4..2c7afa93d 100644 --- a/src/neo4j/time/__init__.py +++ b/src/neo4j/time/__init__.py @@ -1888,11 +1888,19 @@ def tzinfo(self) -> _tzinfo | None: # OPERATIONS # + @_t.overload + def _get_both_normalized_ticks( + self, other: Time | _time, strict: _t.Literal[True] = True + ) -> tuple[int, int]: ... + + @_t.overload def _get_both_normalized_ticks( - self, other: object, strict: bool = True + self, other: Time | _time, strict: _t.Literal[False] + ) -> tuple[int, int] | None: ... + + def _get_both_normalized_ticks( + self, other: Time | _time, strict: bool = True ) -> tuple[int, int] | None: - if not isinstance(other, (Time, _time)): - return None if (self.utc_offset() is None) ^ (other.utcoffset() is None): if strict: raise TypeError( @@ -1942,40 +1950,28 @@ def __lt__(self, other: Time | _time) -> bool: """`<` comparison with :class:`.Time` or :class:`datetime.time`.""" if not isinstance(other, (Time, _time)): return NotImplemented - ticks = self._get_both_normalized_ticks(other) - if ticks is None: - return False - self_ticks, other_ticks = ticks + self_ticks, other_ticks = self._get_both_normalized_ticks(other) return self_ticks < other_ticks def __le__(self, other: Time | _time) -> bool: """`<=` comparison with :class:`.Time` or :class:`datetime.time`.""" if not isinstance(other, (Time, _time)): return NotImplemented - ticks = self._get_both_normalized_ticks(other) - if ticks is None: - return False - self_ticks, other_ticks = ticks + self_ticks, other_ticks = self._get_both_normalized_ticks(other) return self_ticks <= other_ticks def __ge__(self, other: Time | _time) -> bool: """`>=` comparison with :class:`.Time` or :class:`datetime.time`.""" if not isinstance(other, (Time, _time)): return NotImplemented - ticks = self._get_both_normalized_ticks(other) - if ticks is None: - return False - self_ticks, other_ticks = ticks + self_ticks, other_ticks = self._get_both_normalized_ticks(other) return self_ticks >= other_ticks def __gt__(self, other: Time | _time) -> bool: """`>` comparison with :class:`.Time` or :class:`datetime.time`.""" if not isinstance(other, (Time, _time)): return NotImplemented - ticks = self._get_both_normalized_ticks(other) - if ticks is None: - return False - self_ticks, other_ticks = ticks + self_ticks, other_ticks = self._get_both_normalized_ticks(other) return self_ticks > other_ticks # INSTANCE METHODS # @@ -2583,11 +2579,19 @@ def hour_minute_second_nanosecond(self) -> tuple[int, int, int, int]: # OPERATIONS # + @_t.overload + def _get_both_normalized( + self, other: _datetime | DateTime, strict: _t.Literal[True] = True + ) -> tuple[DateTime, DateTime | _datetime]: ... + + @_t.overload + def _get_both_normalized( + self, other: _datetime | DateTime, strict: _t.Literal[False] + ) -> tuple[DateTime, DateTime | _datetime] | None: ... + def _get_both_normalized( - self, other: object, strict: bool = True + self, other: _datetime | DateTime, strict: bool = True ) -> tuple[DateTime, DateTime | _datetime] | None: - if not isinstance(other, (_datetime, DateTime)): - return None if (self.utc_offset() is None) ^ (other.utcoffset() is None): if strict: raise TypeError( @@ -2646,9 +2650,7 @@ def __lt__( # type: ignore[override] if self.date() == other.date(): return self.time() < other.time() return self.date() < other.date() - normalized = self._get_both_normalized(other) - assert normalized is not None, "checked for correct type above" - self_norm, other_norm = normalized + self_norm, other_norm = self._get_both_normalized(other) return ( self_norm.date() < other_norm.date() or self_norm.time() < other_norm.time() @@ -2668,9 +2670,7 @@ def __le__( # type: ignore[override] if self.date() == other.date(): return self.time() <= other.time() return self.date() <= other.date() - normalized = self._get_both_normalized(other) - assert normalized is not None, "checked for correct type above" - self_norm, other_norm = normalized + self_norm, other_norm = self._get_both_normalized(other) return self_norm <= other_norm def __ge__( # type: ignore[override] @@ -2687,9 +2687,7 @@ def __ge__( # type: ignore[override] if self.date() == other.date(): return self.time() >= other.time() return self.date() >= other.date() - normalized = self._get_both_normalized(other) - assert normalized is not None, "checked for correct type above" - self_norm, other_norm = normalized + self_norm, other_norm = self._get_both_normalized(other) return self_norm >= other_norm def __gt__( # type: ignore[override] @@ -2706,9 +2704,7 @@ def __gt__( # type: ignore[override] if self.date() == other.date(): return self.time() > other.time() return self.date() > other.date() - normalized = self._get_both_normalized(other) - assert normalized is not None, "checked for correct type above" - self_norm, other_norm = normalized + self_norm, other_norm = self._get_both_normalized(other) return ( self_norm.date() > other_norm.date() or self_norm.time() > other_norm.time() diff --git a/tests/unit/common/time/test_datetime.py b/tests/unit/common/time/test_datetime.py index 3a7740ef2..5610f393b 100644 --- a/tests/unit/common/time/test_datetime.py +++ b/tests/unit/common/time/test_datetime.py @@ -1190,6 +1190,62 @@ def test_comparison(dt1, dt2) -> None: assert not dt1 >= dt2 +@pytest.mark.parametrize( + ("dt1_args", "dt2_args"), + ( + ( + (2022, 11, 25, 12, 34, 56, 789124), + (2022, 11, 25, 12, 34, 56, 789124), + ), + ( + (2022, 11, 25, 12, 33, 56, 789124), + (2022, 11, 25, 12, 34, 56, 789124), + ), + ( + (2022, 11, 25, 12, 34, 56, 789124), + (2022, 11, 25, 12, 35, 56, 789124), + ), + ( + (2022, 11, 25, 12, 32, 56, 789124), + (2022, 11, 25, 12, 34, 56, 789124), + ), + ( + (2022, 11, 25, 12, 34, 56, 789124), + (2022, 11, 25, 12, 36, 56, 789124), + ), + ), +) +@pytest.mark.parametrize("dt1_cls", (DateTime, datetime)) +@pytest.mark.parametrize("dt2_cls", (DateTime, datetime)) +@pytest.mark.parametrize( + "tz", + (FixedOffset(0), FixedOffset(1), FixedOffset(-1), utc, timezone_berlin), +) +def test_comparison_only_one_with_tzinfo( + dt1_args, dt1_cls, dt2_args, dt2_cls, tz +) -> None: + dt1 = dt1_cls(*dt1_args) + dt2 = dt2_cls(*dt2_args, tzinfo=None) + err_msg = "can't compare offset-naive and offset-aware" + dt2 = dt2.replace(tzinfo=tz) + with pytest.raises(TypeError, match=err_msg): + assert not dt1 < dt2 + with pytest.raises(TypeError, match=err_msg): + assert not dt2 < dt1 + with pytest.raises(TypeError, match=err_msg): + assert not dt1 <= dt2 + with pytest.raises(TypeError, match=err_msg): + assert not dt2 <= dt1 + with pytest.raises(TypeError, match=err_msg): + assert not dt1 > dt2 + with pytest.raises(TypeError, match=err_msg): + assert not dt2 > dt1 + with pytest.raises(TypeError, match=err_msg): + assert not dt1 <= dt2 + with pytest.raises(TypeError, match=err_msg): + assert not dt2 <= dt1 + + def test_str() -> None: dt = DateTime(2018, 4, 26, 23, 0, 17, 914390409) assert str(dt) == "2018-04-26T23:00:17.914390409" diff --git a/tests/unit/common/time/test_time.py b/tests/unit/common/time/test_time.py index 4fa414375..2dc7a342c 100644 --- a/tests/unit/common/time/test_time.py +++ b/tests/unit/common/time/test_time.py @@ -577,6 +577,62 @@ def test_pickle(self, expected): assert expected.foo is not actual.foo +@pytest.mark.parametrize( + ("t1_args", "t2_args"), + ( + ( + (12, 34, 56, 789124), + (12, 34, 56, 789124), + ), + ( + (12, 33, 56, 789124), + (12, 34, 56, 789124), + ), + ( + (12, 34, 56, 789124), + (12, 35, 56, 789124), + ), + ( + (12, 32, 56, 789124), + (12, 34, 56, 789124), + ), + ( + (12, 34, 56, 789124), + (12, 36, 56, 789124), + ), + ), +) +@pytest.mark.parametrize("t1_cls", (Time, time)) +@pytest.mark.parametrize("t2_cls", (Time, time)) +@pytest.mark.parametrize( + "tz", + (FixedOffset(0), FixedOffset(1), FixedOffset(-1), utc), +) +def test_comparison_only_one_with_tzinfo( + t1_args, t1_cls, t2_args, t2_cls, tz +) -> None: + t1 = t1_cls(*t1_args) + t2 = t2_cls(*t2_args, tzinfo=None) + err_msg = "can't compare offset-naive and offset-aware" + t2 = t2.replace(tzinfo=tz) + with pytest.raises(TypeError, match=err_msg): + assert not t1 < t2 + with pytest.raises(TypeError, match=err_msg): + assert not t2 < t1 + with pytest.raises(TypeError, match=err_msg): + assert not t1 <= t2 + with pytest.raises(TypeError, match=err_msg): + assert not t2 <= t1 + with pytest.raises(TypeError, match=err_msg): + assert not t1 > t2 + with pytest.raises(TypeError, match=err_msg): + assert not t2 > t1 + with pytest.raises(TypeError, match=err_msg): + assert not t1 <= t2 + with pytest.raises(TypeError, match=err_msg): + assert not t2 <= t1 + + def test_str() -> None: t = Time(12, 34, 56, 789123001) assert str(t) == "12:34:56.789123001"