diff --git a/django_redis/client/default.py b/django_redis/client/default.py index 940e1a0e..f86578db 100644 --- a/django_redis/client/default.py +++ b/django_redis/client/default.py @@ -317,6 +317,10 @@ def expire( timeout: ExpiryT, version: Optional[int] = None, client: Optional[Redis] = None, + nx: bool = False, + xx: bool = False, + gt: bool = False, + lt: bool = False, ) -> bool: if timeout is DEFAULT_TIMEOUT: timeout = self._backend.default_timeout # type: ignore @@ -326,7 +330,7 @@ def expire( key = self.make_key(key, version=version) - return client.expire(key, timeout) + return client.expire(key, timeout, nx, xx, gt, lt) def pexpire( self, @@ -334,6 +338,10 @@ def pexpire( timeout: ExpiryT, version: Optional[int] = None, client: Optional[Redis] = None, + nx: bool = False, + xx: bool = False, + gt: bool = False, + lt: bool = False, ) -> bool: if timeout is DEFAULT_TIMEOUT: timeout = self._backend.default_timeout # type: ignore @@ -343,7 +351,7 @@ def pexpire( key = self.make_key(key, version=version) - return bool(client.pexpire(key, timeout)) + return bool(client.pexpire(key, timeout, nx, xx, gt, lt)) def pexpire_at( self, diff --git a/django_redis/client/sharded.py b/django_redis/client/sharded.py index 5e2eec90..a8858ad6 100644 --- a/django_redis/client/sharded.py +++ b/django_redis/client/sharded.py @@ -171,19 +171,57 @@ def persist(self, key, version=None, client=None): return super().persist(key=key, version=version, client=client) - def expire(self, key, timeout, version=None, client=None): + def expire( + self, + key, + timeout, + version=None, + client=None, + nx=False, + xx=False, + gt=False, + lt=False, + ): if client is None: key = self.make_key(key, version=version) client = self.get_server(key) - return super().expire(key=key, timeout=timeout, version=version, client=client) + return super().expire( + key=key, + timeout=timeout, + version=version, + client=client, + nx=nx, + xx=xx, + gt=gt, + lt=lt, + ) - def pexpire(self, key, timeout, version=None, client=None): + def pexpire( + self, + key, + timeout, + version=None, + client=None, + nx=False, + xx=False, + gt=False, + lt=False, + ): if client is None: key = self.make_key(key, version=version) client = self.get_server(key) - return super().pexpire(key=key, timeout=timeout, version=version, client=client) + return super().pexpire( + key=key, + timeout=timeout, + version=version, + client=client, + nx=nx, + xx=xx, + gt=gt, + lt=lt, + ) def pexpire_at(self, key, when: Union[datetime, int], version=None, client=None): """ diff --git a/tests/test_backend.py b/tests/test_backend.py index e5c54e18..a6a27985 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -11,6 +11,7 @@ from django.test import override_settings from pytest_django.fixtures import SettingsWrapper from pytest_mock import MockerFixture +from redis.exceptions import ResponseError from django_redis.cache import RedisCache from django_redis.client import ShardClient, herd @@ -600,30 +601,165 @@ def test_persist(self, cache: RedisCache): assert ttl is None assert cache.persist("not-existent-key") is False - def test_expire(self, cache: RedisCache): - cache.set("foo", "bar", timeout=None) - assert cache.expire("foo", 20) is True - ttl = cache.ttl("foo") - assert pytest.approx(ttl) == 20 - assert cache.expire("not-existent-key", 20) is False + @pytest.mark.parametrize( + "initial_timeout, new_timeout, params, expected_result, expected_ttl", + [ + # Basic expire functionality (existing test case) + (None, 20, {}, True, 20), + # NX tests (only set if key has no expiry) + # Should work - key has no expiry + (None, 30, {"nx": True}, True, 30), + # Should fail - key already has expiry + (20, 30, {"nx": True}, False, 20), + # XX tests (only set if key has existing expiry) + # Should work - key has expiry + (20, 30, {"xx": True}, True, 30), + # Should fail - key has no expiry + (None, 30, {"xx": True}, False, None), + # GT tests (only set if new expiry is greater than current) + # Should work - new timeout > current + (20, 30, {"gt": True}, True, 30), + # Should fail - new timeout < current + (30, 20, {"gt": True}, False, 30), + # LT tests (only set if new expiry is less than current) + # Should work - new timeout < current + (30, 20, {"lt": True}, True, 20), + # Should fail - new timeout > current + (20, 30, {"lt": True}, False, 20), + ], + ) + def test_expire_with_conditions( + self, + cache: RedisCache, + initial_timeout, + new_timeout, + params, + expected_result, + expected_ttl, + ): + cache.set("foo", "bar", timeout=initial_timeout) + result = cache.expire("foo", new_timeout, **params) + + assert result is expected_result + if expected_ttl is not None: + assert pytest.approx(cache.ttl("foo")) == expected_ttl - def test_expire_with_default_timeout(self, cache: RedisCache): + def test_expire_combinations(self, cache: RedisCache): + # Test that incompatible combinations raise redis.exceptions.ResponseError + cache.set("foo", "bar", timeout=20) + + # NX and XX are mutually exclusive + with pytest.raises(ResponseError): + cache.expire("foo", 30, nx=True, xx=True) + + # GT and LT are mutually exclusive) + with pytest.raises(ResponseError): + cache.expire("foo", 30, gt=True, lt=True) + + def test_expire_on_non_existent_key(self, cache: RedisCache): + # Test expire with conditions on non-existent key + assert cache.expire("non-existent", 20) is False + assert cache.expire("non-existent", 20, nx=True) is False + assert cache.expire("non-existent", 20, xx=True) is False + assert cache.expire("non-existent", 20, gt=True) is False + assert cache.expire("non-existent", 20, lt=True) is False + + def test_expire_with_default_timeout_and_conditions(self, cache: RedisCache): cache.set("foo", "bar", timeout=None) - assert cache.expire("foo", DEFAULT_TIMEOUT) is True - assert cache.expire("not-existent-key", DEFAULT_TIMEOUT) is False + assert cache.expire("foo", DEFAULT_TIMEOUT, nx=True) is True + + cache.set("foo2", "bar", timeout=20) + assert cache.expire("foo2", DEFAULT_TIMEOUT, xx=True) is True + + @pytest.mark.parametrize( + "initial_timeout, new_timeout, params, expected_result, expected_pttl", + [ + # Basic pexpire functionality (existing test case) + (None, 20500, {}, True, 20500), + # NX tests (only set if key has no expiry) + # Should work - key has no expiry + (None, 30500, {"nx": True}, True, 30500), + # Should fail - key already has expiry + (20500, 30500, {"nx": True}, False, 20500), + # XX tests (only set if key has existing expiry) + # Should work - key has expiry + (20500, 30500, {"xx": True}, True, 30500), + # Should fail - key has no expiry + (None, 30500, {"xx": True}, False, None), + # GT tests (only set if new expiry is greater than current) + # Should work - new timeout > current + (20500, 30500, {"gt": True}, True, 30500), + # Should fail - new timeout < current + (30500, 20500, {"gt": True}, False, 30500), + # LT tests (only set if new expiry is less than current) + # Should work - new timeout < current + (30500, 20500, {"lt": True}, True, 20500), + # Should fail - new timeout > current + (20500, 30500, {"lt": True}, False, 20500), + ], + ) + def test_pexpire_with_conditions( + self, + cache: RedisCache, + initial_timeout, + new_timeout, + params, + expected_result, + expected_pttl, + ): + cache.set( + "foo", "bar", timeout=initial_timeout / 1000 if initial_timeout else None + ) + result = cache.pexpire("foo", new_timeout, **params) + + assert result is expected_result + if expected_pttl is not None: + # Using a delta of 10ms for approximate comparison due to timing precision + assert pytest.approx(cache.pttl("foo"), abs=10) == expected_pttl + + def test_pexpire_combinations(self, cache: RedisCache): + # Test that incompatible combinations raise redis.exceptions.ResponseError + cache.set("foo", "bar", timeout=20) + + # NX and XX are mutually exclusive + with pytest.raises(ResponseError): + cache.pexpire("foo", 30500, nx=True, xx=True) + + # GT and LT are mutually exclusive + with pytest.raises(ResponseError): + cache.pexpire("foo", 30500, gt=True, lt=True) - def test_pexpire(self, cache: RedisCache): + def test_pexpire_on_non_existent_key(self, cache: RedisCache): + # Test pexpire with conditions on non-existent key + assert cache.pexpire("non-existent", 20500) is False + assert cache.pexpire("non-existent", 20500, nx=True) is False + assert cache.pexpire("non-existent", 20500, xx=True) is False + assert cache.pexpire("non-existent", 20500, gt=True) is False + assert cache.pexpire("non-existent", 20500, lt=True) is False + + def test_pexpire_with_default_timeout_and_conditions(self, cache: RedisCache): + # Test with DEFAULT_TIMEOUT cache.set("foo", "bar", timeout=None) - assert cache.pexpire("foo", 20500) is True - ttl = cache.pttl("foo") - # delta is set to 10 as precision error causes tests to fail - assert pytest.approx(ttl, 10) == 20500 - assert cache.pexpire("not-existent-key", 20500) is False + assert cache.pexpire("foo", DEFAULT_TIMEOUT, nx=True) is True - def test_pexpire_with_default_timeout(self, cache: RedisCache): + # Set a specific timeout and test with XX condition + cache.set("foo2", "bar", timeout=20) + assert cache.pexpire("foo2", DEFAULT_TIMEOUT, xx=True) is True + + def test_pexpire_precision(self, cache: RedisCache): + # Test precision with very small and large millisecond values cache.set("foo", "bar", timeout=None) - assert cache.pexpire("foo", DEFAULT_TIMEOUT) is True - assert cache.pexpire("not-existent-key", DEFAULT_TIMEOUT) is False + + # Test with small millisecond value + assert cache.pexpire("foo", 100) is True # 100ms + pttl = cache.pttl("foo") + assert pytest.approx(pttl, abs=10) == 100 + + # Test with large millisecond value + cache.set("foo2", "bar", timeout=None) + assert cache.pexpire("foo2", 3600000) is True # 1 hour in ms + pttl = cache.pttl("foo2") + assert pytest.approx(pttl, abs=10) == 3600000 def test_pexpire_at(self, cache: RedisCache): # Test settings expiration time 1 hour ahead by datetime.