Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 75 additions & 3 deletions redisvl/redis/connection.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from typing import Any, Dict, List, Optional, Type
from typing import Any, Dict, List, Optional, Type, TypeVar, Union, overload
from urllib.parse import urlparse
from warnings import warn

from redis import Redis, RedisCluster
Expand All @@ -11,6 +12,7 @@
from redis.asyncio.connection import SSLConnection as AsyncSSLConnection
from redis.connection import SSLConnection
from redis.exceptions import ResponseError
from redis.sentinel import Sentinel

from redisvl import __version__
from redisvl.redis.constants import REDIS_URL_ENV_VAR
Expand Down Expand Up @@ -192,6 +194,9 @@ def parse_attrs(attrs):
}


T = TypeVar("T", Redis, AsyncRedis)


class RedisConnectionFactory:
"""Builds connections to a Redis database, supporting both synchronous and
asynchronous clients.
Expand Down Expand Up @@ -253,7 +258,9 @@ def get_redis_connection(
variable is not set.
"""
url = redis_url or get_address_from_env()
if is_cluster_url(url, **kwargs):
if url.startswith("redis+sentinel"):
client = RedisConnectionFactory._redis_sentinel_client(url, Redis, **kwargs)
elif is_cluster_url(url, **kwargs):
client = RedisCluster.from_url(url, **kwargs)
else:
client = Redis.from_url(url, **kwargs)
Expand Down Expand Up @@ -293,7 +300,11 @@ async def _get_aredis_connection(
"""
url = url or get_address_from_env()

if is_cluster_url(url, **kwargs):
if url.startswith("redis+sentinel"):
client = RedisConnectionFactory._redis_sentinel_client(
url, AsyncRedis, **kwargs
)
elif is_cluster_url(url, **kwargs):
client = AsyncRedisCluster.from_url(url, **kwargs)
else:
client = AsyncRedis.from_url(url, **kwargs)
Expand Down Expand Up @@ -334,6 +345,10 @@ def get_async_redis_connection(
DeprecationWarning,
)
url = url or get_address_from_env()
if url.startswith("redis+sentinel"):
return RedisConnectionFactory._redis_sentinel_client(
url, AsyncRedis, **kwargs
)
return AsyncRedis.from_url(url, **kwargs)

@staticmethod
Expand Down Expand Up @@ -440,3 +455,60 @@ async def validate_async_redis(
await redis_client.echo(_lib_name)

# Module validation removed - operations will fail naturally if modules are missing

@staticmethod
@overload
def _redis_sentinel_client(
redis_url: str, redis_class: type[Redis], **kwargs: Any
) -> Redis: ...

@staticmethod
@overload
def _redis_sentinel_client(
redis_url: str, redis_class: type[AsyncRedis], **kwargs: Any
) -> AsyncRedis: ...

@staticmethod
def _redis_sentinel_client(
redis_url: str, redis_class: Union[type[Redis], type[AsyncRedis]], **kwargs: Any
) -> Union[Redis, AsyncRedis]:
sentinel_list, service_name, db, username, password = (
RedisConnectionFactory._parse_sentinel_url(redis_url)
)

sentinel_kwargs = {}
if username:
sentinel_kwargs["username"] = username
kwargs["username"] = username
if password:
sentinel_kwargs["password"] = password
kwargs["password"] = password
if db:
kwargs["db"] = db

sentinel = Sentinel(sentinel_list, sentinel_kwargs=sentinel_kwargs, **kwargs)
return sentinel.master_for(service_name, redis_class=redis_class, **kwargs)

@staticmethod
def _parse_sentinel_url(url: str) -> tuple:
parsed_url = urlparse(url)
hosts_part = parsed_url.netloc.split("@")[-1]
sentinel_hosts = hosts_part.split(",")

sentinel_list = []
for host in sentinel_hosts:
host_parts = host.split(":")
if len(host_parts) == 2:
sentinel_list.append((host_parts[0], int(host_parts[1])))
else:
sentinel_list.append((host_parts[0], 26379))

service_name = "mymaster"
db = None
if parsed_url.path:
path_parts = parsed_url.path.split("/")
service_name = path_parts[1] or "mymaster"
if len(path_parts) > 2:
db = path_parts[2]

return sentinel_list, service_name, db, parsed_url.username, parsed_url.password
84 changes: 84 additions & 0 deletions tests/unit/test_sentinel_url.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from unittest.mock import MagicMock, patch

import pytest
from redis.exceptions import ConnectionError

from redisvl.redis.connection import RedisConnectionFactory


@pytest.mark.parametrize("use_async", [False, True])
def test_sentinel_url_connection(use_async):
sentinel_url = (
"redis+sentinel://username:password@host1:26379,host2:26380/mymaster/0"
)

with patch("redisvl.redis.connection.Sentinel") as mock_sentinel:
mock_master = MagicMock()
mock_sentinel.return_value.master_for.return_value = mock_master

if use_async:
client = RedisConnectionFactory.get_async_redis_connection(sentinel_url)
else:
client = RedisConnectionFactory.get_redis_connection(sentinel_url)

mock_sentinel.assert_called_once()
call_args = mock_sentinel.call_args
assert call_args[0][0] == [("host1", 26379), ("host2", 26380)]
assert call_args[1]["sentinel_kwargs"] == {
"username": "username",
"password": "password",
}

mock_sentinel.return_value.master_for.assert_called_once()
master_for_args = mock_sentinel.return_value.master_for.call_args
assert master_for_args[0][0] == "mymaster"
assert master_for_args[1]["db"] == "0"

assert client == mock_master


@pytest.mark.parametrize("use_async", [False, True])
def test_sentinel_url_connection_no_auth_no_db(use_async):
sentinel_url = "redis+sentinel://host1:26379,host2:26380/mymaster"

with patch("redisvl.redis.connection.Sentinel") as mock_sentinel:
mock_master = MagicMock()
mock_sentinel.return_value.master_for.return_value = mock_master

if use_async:
client = RedisConnectionFactory.get_async_redis_connection(sentinel_url)
else:
client = RedisConnectionFactory.get_redis_connection(sentinel_url)

mock_sentinel.assert_called_once()
call_args = mock_sentinel.call_args
assert call_args[0][0] == [("host1", 26379), ("host2", 26380)]
assert (
"sentinel_kwargs" not in call_args[1]
or call_args[1]["sentinel_kwargs"] == {}
)

mock_sentinel.return_value.master_for.assert_called_once()
master_for_args = mock_sentinel.return_value.master_for.call_args
assert master_for_args[0][0] == "mymaster"
assert "db" not in master_for_args[1]

assert client == mock_master


@pytest.mark.parametrize("use_async", [False, True])
def test_sentinel_url_connection_error(use_async):
sentinel_url = "redis+sentinel://host1:26379,host2:26380/mymaster"

with patch("redisvl.redis.connection.Sentinel") as mock_sentinel:
mock_sentinel.return_value.master_for.side_effect = ConnectionError(
"Test connection error"
)

with pytest.raises(ConnectionError):
if use_async:
RedisConnectionFactory.get_async_redis_connection(sentinel_url)
else:
RedisConnectionFactory.get_redis_connection(sentinel_url)

mock_sentinel.assert_called_once()
Loading