diff --git a/.gitignore b/.gitignore index 6df9633f..57805bec 100644 --- a/.gitignore +++ b/.gitignore @@ -27,3 +27,4 @@ dump.rdb .coverage coverage.xml cobertura.xml +CLAUDE.md diff --git a/.ruff.toml b/.ruff.toml index 0eb858db..a253c1c3 100644 --- a/.ruff.toml +++ b/.ruff.toml @@ -178,3 +178,6 @@ ban-relative-imports = "all" # pickle is used on purpose and its use is discouraged "django_redis/serializers/pickle.py" = ["S301"] + +# min/max are official Redis parameter names matching redis-py API +"django_redis/client/mixins/sorted_sets.py" = ["A002"] diff --git a/changelog.d/797.feature b/changelog.d/797.feature new file mode 100644 index 00000000..e42e4eff --- /dev/null +++ b/changelog.d/797.feature @@ -0,0 +1 @@ +Add sorted set operations (zadd, zrange, zrem, etc.) and mixins for RedisCache diff --git a/django_redis/cache.py b/django_redis/cache.py index 94881ac5..4e218be9 100644 --- a/django_redis/cache.py +++ b/django_redis/cache.py @@ -278,3 +278,60 @@ def hkeys(self, *args, **kwargs): @omit_exception def hexists(self, *args, **kwargs): return self.client.hexists(*args, **kwargs) + + # Sorted Set Operations + @omit_exception + def zadd(self, *args, **kwargs): + return self.client.zadd(*args, **kwargs) + + @omit_exception + def zcard(self, *args, **kwargs): + return self.client.zcard(*args, **kwargs) + + @omit_exception + def zcount(self, *args, **kwargs): + return self.client.zcount(*args, **kwargs) + + @omit_exception + def zincrby(self, *args, **kwargs): + return self.client.zincrby(*args, **kwargs) + + @omit_exception + def zpopmax(self, *args, **kwargs): + return self.client.zpopmax(*args, **kwargs) + + @omit_exception + def zpopmin(self, *args, **kwargs): + return self.client.zpopmin(*args, **kwargs) + + @omit_exception + def zrange(self, *args, **kwargs): + return self.client.zrange(*args, **kwargs) + + @omit_exception + def zrangebyscore(self, *args, **kwargs): + return self.client.zrangebyscore(*args, **kwargs) + + @omit_exception + def zrank(self, *args, **kwargs): + return self.client.zrank(*args, **kwargs) + + @omit_exception + def zrem(self, *args, **kwargs): + return self.client.zrem(*args, **kwargs) + + @omit_exception + def zremrangebyscore(self, *args, **kwargs): + return self.client.zremrangebyscore(*args, **kwargs) + + @omit_exception + def zrevrange(self, *args, **kwargs): + return self.client.zrevrange(*args, **kwargs) + + @omit_exception + def zrevrangebyscore(self, *args, **kwargs): + return self.client.zrevrangebyscore(*args, **kwargs) + + @omit_exception + def zscore(self, *args, **kwargs): + return self.client.zscore(*args, **kwargs) diff --git a/django_redis/client/default.py b/django_redis/client/default.py index a2833125..258b6b2c 100644 --- a/django_redis/client/default.py +++ b/django_redis/client/default.py @@ -23,6 +23,7 @@ from redis.typing import AbsExpiryT, EncodableT, ExpiryT, KeyT, PatternT from django_redis import pool +from django_redis.client.mixins import SortedSetMixin from django_redis.exceptions import CompressorError, ConnectionInterrupted from django_redis.util import CacheKey @@ -40,7 +41,7 @@ def glob_escape(s: str) -> str: return special_re.sub(r"[\1]", s) -class DefaultClient: +class DefaultClient(SortedSetMixin): def __init__(self, server, params: dict[str, Any], backend: BaseCache) -> None: self._backend = backend self._server = server diff --git a/django_redis/client/mixins/__init__.py b/django_redis/client/mixins/__init__.py new file mode 100644 index 00000000..4da0aea9 --- /dev/null +++ b/django_redis/client/mixins/__init__.py @@ -0,0 +1,4 @@ +from django_redis.client.mixins.protocols import ClientProtocol +from django_redis.client.mixins.sorted_sets import SortedSetMixin + +__all__ = ["ClientProtocol", "SortedSetMixin"] diff --git a/django_redis/client/mixins/protocols.py b/django_redis/client/mixins/protocols.py new file mode 100644 index 00000000..bcfcf9af --- /dev/null +++ b/django_redis/client/mixins/protocols.py @@ -0,0 +1,33 @@ +from typing import Any, Optional, Protocol, Union + +from redis import Redis +from redis.typing import KeyT + + +class ClientProtocol(Protocol): + """ + Protocol for client methods required by mixins. + + Any class using django-redis mixins must implement these methods. + """ + + def make_key( + self, + key: KeyT, + version: Optional[int] = None, + prefix: Optional[str] = None, + ) -> KeyT: + """Create a cache key with optional version and prefix.""" + ... + + def encode(self, value: Any) -> Union[bytes, int]: + """Encode a value for storage in Redis.""" + ... + + def decode(self, value: Union[bytes, int]) -> Any: + """Decode a value retrieved from Redis.""" + ... + + def get_client(self, write: bool = False) -> Redis: + """Get a Redis client instance for read or write operations.""" + ... diff --git a/django_redis/client/mixins/sorted_sets.py b/django_redis/client/mixins/sorted_sets.py new file mode 100644 index 00000000..928ae4b3 --- /dev/null +++ b/django_redis/client/mixins/sorted_sets.py @@ -0,0 +1,324 @@ +from typing import Any, Optional, Union + +from redis import Redis +from redis.typing import KeyT + +from django_redis.client.mixins.protocols import ClientProtocol + + +class SortedSetMixin(ClientProtocol): + """Mixin providing Redis sorted set (ZSET) operations.""" + + def zadd( + self, + name: KeyT, + mapping: dict[Any, float], + nx: bool = False, + xx: bool = False, + ch: bool = False, + incr: bool = False, + gt: bool = False, + lt: bool = False, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + """Add members with scores to sorted set.""" + if client is None: + client = self.get_client(write=True) + + name = self.make_key(name, version=version) + # Encode members but NOT scores (scores must remain as floats) + encoded_mapping = { + self.encode(member): score for member, score in mapping.items() + } + + return int( + client.zadd( + name, + encoded_mapping, # type: ignore[arg-type] + nx=nx, + xx=xx, + ch=ch, + incr=incr, + gt=gt, + lt=lt, + ), + ) + + def zcard( + self, + name: KeyT, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + """Get the number of members in sorted set.""" + if client is None: + client = self.get_client(write=False) + + name = self.make_key(name, version=version) + return int(client.zcard(name)) + + def zcount( + self, + name: KeyT, + min: Union[float, str], + max: Union[float, str], + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + """Count members in sorted set with scores between min and max.""" + if client is None: + client = self.get_client(write=False) + + name = self.make_key(name, version=version) + return int(client.zcount(name, min, max)) + + def zincrby( + self, + name: KeyT, + amount: float, + value: Any, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> float: + """Increment the score of member in sorted set by amount.""" + if client is None: + client = self.get_client(write=True) + + name = self.make_key(name, version=version) + value = self.encode(value) + return float(client.zincrby(name, amount, value)) + + def zpopmax( + self, + name: KeyT, + count: Optional[int] = None, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Union[list[tuple[Any, float]], tuple[Any, float], None]: + """Remove and return members with highest scores.""" + if client is None: + client = self.get_client(write=True) + + name = self.make_key(name, version=version) + result = client.zpopmax(name, count) + + if not result: + return None if count is None else [] + + decoded = [(self.decode(member), score) for member, score in result] + + if count is None: + return decoded[0] if decoded else None + + return decoded + + def zpopmin( + self, + name: KeyT, + count: Optional[int] = None, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Union[list[tuple[Any, float]], tuple[Any, float], None]: + """Remove and return members with lowest scores.""" + if client is None: + client = self.get_client(write=True) + + name = self.make_key(name, version=version) + result = client.zpopmin(name, count) + + if not result: + return None if count is None else [] + + decoded = [(self.decode(member), score) for member, score in result] + + if count is None: + return decoded[0] if decoded else None + + return decoded + + def zrange( + self, + name: KeyT, + start: int, + end: int, + desc: bool = False, + withscores: bool = False, + score_cast_func: type = float, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Union[list[Any], list[tuple[Any, float]]]: + """Return members in sorted set by index range.""" + if client is None: + client = self.get_client(write=False) + + name = self.make_key(name, version=version) + result = client.zrange( + name, + start, + end, + desc=desc, + withscores=withscores, + score_cast_func=score_cast_func, + ) + + if withscores: + return [(self.decode(member), score) for member, score in result] + + return [self.decode(member) for member in result] + + def zrangebyscore( + self, + name: KeyT, + min: Union[float, str], + max: Union[float, str], + start: Optional[int] = None, + num: Optional[int] = None, + withscores: bool = False, + score_cast_func: type = float, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Union[list[Any], list[tuple[Any, float]]]: + """Return members in sorted set by score range.""" + if client is None: + client = self.get_client(write=False) + + name = self.make_key(name, version=version) + result = client.zrangebyscore( + name, + min, + max, + start=start, + num=num, + withscores=withscores, + score_cast_func=score_cast_func, + ) + + if withscores: + return [(self.decode(member), score) for member, score in result] + + return [self.decode(member) for member in result] + + def zrank( + self, + name: KeyT, + value: Any, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Optional[int]: + """Get the rank (index) of member in sorted set, ordered low to high.""" + if client is None: + client = self.get_client(write=False) + + name = self.make_key(name, version=version) + value = self.encode(value) + rank = client.zrank(name, value) + + return int(rank) if rank is not None else None + + def zrem( + self, + name: KeyT, + *values: Any, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + """Remove members from sorted set.""" + if client is None: + client = self.get_client(write=True) + + name = self.make_key(name, version=version) + encoded_values = [self.encode(value) for value in values] + return int(client.zrem(name, *encoded_values)) + + def zremrangebyscore( + self, + name: KeyT, + min: Union[float, str], + max: Union[float, str], + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + """Remove members from sorted set with scores between min and max.""" + if client is None: + client = self.get_client(write=True) + + name = self.make_key(name, version=version) + return int(client.zremrangebyscore(name, min, max)) + + def zrevrange( + self, + name: KeyT, + start: int, + end: int, + withscores: bool = False, + score_cast_func: type = float, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Union[list[Any], list[tuple[Any, float]]]: + """Return members in sorted set by index range, ordered high to low.""" + if client is None: + client = self.get_client(write=False) + + name = self.make_key(name, version=version) + result = client.zrevrange( + name, + start, + end, + withscores=withscores, + score_cast_func=score_cast_func, + ) + + if withscores: + return [(self.decode(member), score) for member, score in result] + + return [self.decode(member) for member in result] + + def zrevrangebyscore( + self, + name: KeyT, + max: Union[float, str], + min: Union[float, str], + start: Optional[int] = None, + num: Optional[int] = None, + withscores: bool = False, + score_cast_func: type = float, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Union[list[Any], list[tuple[Any, float]]]: + """Return members in sorted set by score range, ordered high to low.""" + if client is None: + client = self.get_client(write=False) + + name = self.make_key(name, version=version) + result = client.zrevrangebyscore( + name, + max, + min, + start=start, + num=num, + withscores=withscores, + score_cast_func=score_cast_func, + ) + + if withscores: + return [(self.decode(member), score) for member, score in result] + + return [self.decode(member) for member in result] + + def zscore( + self, + name: KeyT, + value: Any, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Optional[float]: + """Get the score of member in sorted set.""" + if client is None: + client = self.get_client(write=False) + + name = self.make_key(name, version=version) + value = self.encode(value) + score = client.zscore(name, value) + + return float(score) if score is not None else None diff --git a/setup.cfg b/setup.cfg index 88262d8e..fa8af4fa 100644 --- a/setup.cfg +++ b/setup.cfg @@ -34,6 +34,7 @@ python_requires = >=3.9 packages = django_redis django_redis.client + django_redis.client.mixins django_redis.serializers django_redis.compressors install_requires = diff --git a/tests/test_backend_sorted_sets.py b/tests/test_backend_sorted_sets.py new file mode 100644 index 00000000..e0d9590f --- /dev/null +++ b/tests/test_backend_sorted_sets.py @@ -0,0 +1,217 @@ +from django_redis.cache import RedisCache + + +class TestSortedSetOperations: + """Tests for sorted set (ZSET) operations.""" + + def test_zadd_basic(self, cache: RedisCache): + """Test adding members to sorted set.""" + result = cache.zadd("scores", {"player1": 100.0, "player2": 200.0}) + assert result == 2 + assert cache.zcard("scores") == 2 + + def test_zadd_with_nx(self, cache: RedisCache): + """Test zadd with nx flag (only add new).""" + cache.zadd("scores", {"alice": 10.0}) + result = cache.zadd("scores", {"alice": 20.0}, nx=True) + assert result == 0 + assert cache.zscore("scores", "alice") == 10.0 + + def test_zadd_with_xx(self, cache: RedisCache): + """Test zadd with xx flag (only update existing).""" + cache.zadd("scores", {"bob": 15.0}) + result = cache.zadd("scores", {"bob": 25.0}, xx=True) + assert result == 0 # No new members added + assert cache.zscore("scores", "bob") == 25.0 + result = cache.zadd("scores", {"charlie": 30.0}, xx=True) + assert result == 0 + assert cache.zscore("scores", "charlie") is None + + def test_zadd_with_ch(self, cache: RedisCache): + """Test zadd with ch flag (return changed count).""" + cache.zadd("scores", {"player1": 100.0}) + result = cache.zadd("scores", {"player1": 150.0, "player2": 200.0}, ch=True) + assert result == 2 # 1 changed + 1 added + + def test_zcard(self, cache: RedisCache): + """Test getting sorted set cardinality.""" + cache.zadd("scores", {"a": 1.0, "b": 2.0, "c": 3.0}) + assert cache.zcard("scores") == 3 + assert cache.zcard("nonexistent") == 0 + + def test_zcount(self, cache: RedisCache): + """Test counting members in score range.""" + cache.zadd("scores", {"a": 1.0, "b": 2.0, "c": 3.0, "d": 4.0, "e": 5.0}) + assert cache.zcount("scores", 2.0, 4.0) == 3 # b, c, d + assert cache.zcount("scores", "-inf", "+inf") == 5 + assert cache.zcount("scores", 10.0, 20.0) == 0 + + def test_zincrby(self, cache: RedisCache): + """Test incrementing member score.""" + cache.zadd("scores", {"player1": 100.0}) + new_score = cache.zincrby("scores", 50.0, "player1") + assert new_score == 150.0 + assert cache.zscore("scores", "player1") == 150.0 + new_score = cache.zincrby("scores", 25.0, "player2") + assert new_score == 25.0 + + def test_zpopmax(self, cache: RedisCache): + """Test popping highest scored members.""" + cache.zadd("scores", {"a": 1.0, "b": 2.0, "c": 3.0}) + result = cache.zpopmax("scores") + assert result == ("c", 3.0) + assert cache.zcard("scores") == 2 + cache.zadd("scores", {"d": 4.0, "e": 5.0}) + result = cache.zpopmax("scores", count=2) + assert len(result) == 2 + assert result[0][0] == "e" and result[0][1] == 5.0 + assert result[1][0] == "d" and result[1][1] == 4.0 + + def test_zpopmin(self, cache: RedisCache): + """Test popping lowest scored members.""" + cache.zadd("scores", {"a": 1.0, "b": 2.0, "c": 3.0}) + result = cache.zpopmin("scores") + assert result == ("a", 1.0) + assert cache.zcard("scores") == 2 + cache.zadd("scores", {"d": 0.5, "e": 0.1}) + result = cache.zpopmin("scores", count=2) + assert len(result) == 2 + assert result[0][0] == "e" and result[0][1] == 0.1 + assert result[1][0] == "d" and result[1][1] == 0.5 + + def test_zrange_basic(self, cache: RedisCache): + """Test getting range of members by index.""" + cache.zadd("scores", {"alice": 10.0, "bob": 20.0, "charlie": 15.0}) + result = cache.zrange("scores", 0, -1) + assert result == ["alice", "charlie", "bob"] + result = cache.zrange("scores", 0, 1) + assert result == ["alice", "charlie"] + + def test_zrange_withscores(self, cache: RedisCache): + """Test zrange with scores.""" + cache.zadd("scores", {"alice": 10.5, "bob": 20.0, "charlie": 15.5}) + result = cache.zrange("scores", 0, -1, withscores=True) + assert result == [("alice", 10.5), ("charlie", 15.5), ("bob", 20.0)] + + def test_zrange_desc(self, cache: RedisCache): + """Test zrange in descending order.""" + cache.zadd("scores", {"a": 1.0, "b": 2.0, "c": 3.0}) + result = cache.zrange("scores", 0, -1, desc=True) + assert result == ["c", "b", "a"] + + def test_zrangebyscore(self, cache: RedisCache): + """Test getting members by score range.""" + cache.zadd("scores", {"a": 1.0, "b": 2.0, "c": 3.0, "d": 4.0, "e": 5.0}) + result = cache.zrangebyscore("scores", 2.0, 4.0) + assert result == ["b", "c", "d"] + result = cache.zrangebyscore("scores", "-inf", 2.0) + assert result == ["a", "b"] + + def test_zrangebyscore_withscores(self, cache: RedisCache): + """Test zrangebyscore with scores.""" + cache.zadd("scores", {"a": 1.0, "b": 2.0, "c": 3.0}) + result = cache.zrangebyscore("scores", 1.0, 2.0, withscores=True) + assert result == [("a", 1.0), ("b", 2.0)] + + def test_zrangebyscore_pagination(self, cache: RedisCache): + """Test zrangebyscore with pagination.""" + cache.zadd("scores", {"a": 1.0, "b": 2.0, "c": 3.0, "d": 4.0, "e": 5.0}) + result = cache.zrangebyscore("scores", "-inf", "+inf", start=1, num=2) + assert len(result) == 2 + assert result == ["b", "c"] + + def test_zrank(self, cache: RedisCache): + """Test getting member rank.""" + cache.zadd("scores", {"alice": 10.0, "bob": 20.0, "charlie": 15.0}) + assert cache.zrank("scores", "alice") == 0 # Lowest score + assert cache.zrank("scores", "charlie") == 1 + assert cache.zrank("scores", "bob") == 2 + assert cache.zrank("scores", "nonexistent") is None + + def test_zrem(self, cache: RedisCache): + """Test removing members from sorted set.""" + cache.zadd("scores", {"a": 1.0, "b": 2.0, "c": 3.0}) + result = cache.zrem("scores", "b") + assert result == 1 + assert cache.zcard("scores") == 2 + result = cache.zrem("scores", "a", "c") + assert result == 2 + assert cache.zcard("scores") == 0 + + def test_zremrangebyscore(self, cache: RedisCache): + """Test removing members by score range.""" + cache.zadd("scores", {"a": 1.0, "b": 2.0, "c": 3.0, "d": 4.0, "e": 5.0}) + result = cache.zremrangebyscore("scores", 2.0, 4.0) + assert result == 3 # b, c, d removed + assert cache.zcard("scores") == 2 + assert cache.zrange("scores", 0, -1) == ["a", "e"] + + def test_zrevrange(self, cache: RedisCache): + """Test getting reverse range (highest to lowest).""" + cache.zadd("scores", {"a": 1.0, "b": 2.0, "c": 3.0}) + result = cache.zrevrange("scores", 0, -1) + assert result == ["c", "b", "a"] + + def test_zrevrange_withscores(self, cache: RedisCache): + """Test zrevrange with scores.""" + cache.zadd("scores", {"a": 1.0, "b": 2.0, "c": 3.0}) + result = cache.zrevrange("scores", 0, -1, withscores=True) + assert result == [("c", 3.0), ("b", 2.0), ("a", 1.0)] + + def test_zrevrangebyscore(self, cache: RedisCache): + """Test getting reverse range by score.""" + cache.zadd("scores", {"a": 1.0, "b": 2.0, "c": 3.0, "d": 4.0, "e": 5.0}) + result = cache.zrevrangebyscore("scores", 4.0, 2.0) + assert result == ["d", "c", "b"] + + def test_zscore(self, cache: RedisCache): + """Test getting member score.""" + cache.zadd("scores", {"alice": 42.5, "bob": 100.0}) + assert cache.zscore("scores", "alice") == 42.5 + assert cache.zscore("scores", "bob") == 100.0 + assert cache.zscore("scores", "nonexistent") is None + + def test_sorted_set_serialization(self, cache: RedisCache): + """Test that complex objects serialize correctly as members.""" + cache.zadd("complex", {("tuple", "key"): 1.0, "string": 2.0}) + result = cache.zrange("complex", 0, -1) + assert ("tuple", "key") in result or ["tuple", "key"] in result + assert "string" in result + + def test_sorted_set_version_support(self, cache: RedisCache): + """Test version parameter works correctly.""" + cache.zadd("data", {"v1": 1.0}, version=1) + cache.zadd("data", {"v2": 2.0}, version=2) + + assert cache.zcard("data", version=1) == 1 + assert cache.zcard("data", version=2) == 1 + assert cache.zrange("data", 0, -1, version=1) == ["v1"] + assert cache.zrange("data", 0, -1, version=2) == ["v2"] + + def test_sorted_set_float_scores(self, cache: RedisCache): + """Test that float scores work correctly.""" + cache.zadd("precise", {"a": 1.1, "b": 1.2, "c": 1.15}) + result = cache.zrange("precise", 0, -1, withscores=True) + assert result[0] == ("a", 1.1) + assert result[1] == ("c", 1.15) + assert result[2] == ("b", 1.2) + + def test_sorted_set_negative_scores(self, cache: RedisCache): + """Test that negative scores work correctly.""" + cache.zadd("temps", {"freezing": -10.0, "cold": 0.0, "warm": 20.0}) + result = cache.zrange("temps", 0, -1) + assert result == ["freezing", "cold", "warm"] + + def test_zpopmin_empty_set(self, cache: RedisCache): + """Test zpopmin on empty sorted set.""" + result = cache.zpopmin("nonexistent") + assert result is None + result = cache.zpopmin("nonexistent", count=5) + assert result == [] + + def test_zpopmax_empty_set(self, cache: RedisCache): + """Test zpopmax on empty sorted set.""" + result = cache.zpopmax("nonexistent") + assert result is None + result = cache.zpopmax("nonexistent", count=5) + assert result == []