Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion redis/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,11 @@
)

if TYPE_CHECKING and SSL_AVAILABLE:
from ssl import TLSVersion, VerifyMode
from ssl import TLSVersion, VerifyFlags, VerifyMode
else:
TLSVersion = None
VerifyMode = None
VerifyFlags = None

PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]]
_KeyT = TypeVar("_KeyT", bound=KeyT)
Expand Down Expand Up @@ -238,6 +239,8 @@ def __init__(
ssl_keyfile: Optional[str] = None,
ssl_certfile: Optional[str] = None,
ssl_cert_reqs: Union[str, VerifyMode] = "required",
ssl_include_verify_flags: Optional[List[VerifyFlags]] = None,
ssl_exclude_verify_flags: Optional[List[VerifyFlags]] = None,
ssl_ca_certs: Optional[str] = None,
ssl_ca_data: Optional[str] = None,
ssl_check_hostname: bool = True,
Expand Down Expand Up @@ -347,6 +350,8 @@ def __init__(
"ssl_keyfile": ssl_keyfile,
"ssl_certfile": ssl_certfile,
"ssl_cert_reqs": ssl_cert_reqs,
"ssl_include_verify_flags": ssl_include_verify_flags,
"ssl_exclude_verify_flags": ssl_exclude_verify_flags,
"ssl_ca_certs": ssl_ca_certs,
"ssl_ca_data": ssl_ca_data,
"ssl_check_hostname": ssl_check_hostname,
Expand Down
7 changes: 6 additions & 1 deletion redis/asyncio/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,11 @@
)

if SSL_AVAILABLE:
from ssl import TLSVersion, VerifyMode
from ssl import TLSVersion, VerifyFlags, VerifyMode
else:
TLSVersion = None
VerifyMode = None
VerifyFlags = None

TargetNodesT = TypeVar(
"TargetNodesT", str, "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"]
Expand Down Expand Up @@ -299,6 +300,8 @@ def __init__(
ssl_ca_certs: Optional[str] = None,
ssl_ca_data: Optional[str] = None,
ssl_cert_reqs: Union[str, VerifyMode] = "required",
ssl_include_verify_flags: Optional[List[VerifyFlags]] = None,
ssl_exclude_verify_flags: Optional[List[VerifyFlags]] = None,
ssl_certfile: Optional[str] = None,
ssl_check_hostname: bool = True,
ssl_keyfile: Optional[str] = None,
Expand Down Expand Up @@ -358,6 +361,8 @@ def __init__(
"ssl_ca_certs": ssl_ca_certs,
"ssl_ca_data": ssl_ca_data,
"ssl_cert_reqs": ssl_cert_reqs,
"ssl_include_verify_flags": ssl_include_verify_flags,
"ssl_exclude_verify_flags": ssl_exclude_verify_flags,
"ssl_certfile": ssl_certfile,
"ssl_check_hostname": ssl_check_hostname,
"ssl_keyfile": ssl_keyfile,
Expand Down
44 changes: 43 additions & 1 deletion redis/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,12 @@

if SSL_AVAILABLE:
import ssl
from ssl import SSLContext, TLSVersion
from ssl import SSLContext, TLSVersion, VerifyFlags
else:
ssl = None
TLSVersion = None
SSLContext = None
VerifyFlags = None

from ..auth.token import TokenInterface
from ..event import AsyncAfterConnectionReleasedEvent, EventDispatcher
Expand Down Expand Up @@ -793,6 +794,8 @@ def __init__(
ssl_keyfile: Optional[str] = None,
ssl_certfile: Optional[str] = None,
ssl_cert_reqs: Union[str, ssl.VerifyMode] = "required",
ssl_include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
ssl_exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
ssl_ca_certs: Optional[str] = None,
ssl_ca_data: Optional[str] = None,
ssl_check_hostname: bool = True,
Expand All @@ -807,6 +810,8 @@ def __init__(
keyfile=ssl_keyfile,
certfile=ssl_certfile,
cert_reqs=ssl_cert_reqs,
include_verify_flags=ssl_include_verify_flags,
exclude_verify_flags=ssl_exclude_verify_flags,
ca_certs=ssl_ca_certs,
ca_data=ssl_ca_data,
check_hostname=ssl_check_hostname,
Expand All @@ -832,6 +837,14 @@ def certfile(self):
def cert_reqs(self):
return self.ssl_context.cert_reqs

@property
def include_verify_flags(self):
return self.ssl_context.include_verify_flags

@property
def exclude_verify_flags(self):
return self.ssl_context.exclude_verify_flags

@property
def ca_certs(self):
return self.ssl_context.ca_certs
Expand All @@ -854,6 +867,8 @@ class RedisSSLContext:
"keyfile",
"certfile",
"cert_reqs",
"include_verify_flags",
"exclude_verify_flags",
"ca_certs",
"ca_data",
"context",
Expand All @@ -867,6 +882,8 @@ def __init__(
keyfile: Optional[str] = None,
certfile: Optional[str] = None,
cert_reqs: Optional[Union[str, ssl.VerifyMode]] = None,
include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
ca_certs: Optional[str] = None,
ca_data: Optional[str] = None,
check_hostname: bool = False,
Expand All @@ -892,6 +909,8 @@ def __init__(
)
cert_reqs = CERT_REQS[cert_reqs]
self.cert_reqs = cert_reqs
self.include_verify_flags = include_verify_flags
self.exclude_verify_flags = exclude_verify_flags
self.ca_certs = ca_certs
self.ca_data = ca_data
self.check_hostname = (
Expand All @@ -906,6 +925,12 @@ def get(self) -> SSLContext:
context = ssl.create_default_context()
context.check_hostname = self.check_hostname
context.verify_mode = self.cert_reqs
if self.include_verify_flags:
for flag in self.include_verify_flags:
context.verify_flags |= flag
if self.exclude_verify_flags:
for flag in self.exclude_verify_flags:
context.verify_flags &= ~flag
if self.certfile and self.keyfile:
context.load_cert_chain(certfile=self.certfile, keyfile=self.keyfile)
if self.ca_certs or self.ca_data:
Expand Down Expand Up @@ -953,6 +978,20 @@ def to_bool(value) -> Optional[bool]:
return bool(value)


def parse_ssl_verify_flags(value):
# flags are passed in as a string representation of a list,
# e.g. VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN
verify_flags_str = value.replace("[", "").replace("]", "")

verify_flags = []
for flag in verify_flags_str.split(","):
flag = flag.strip()
if not hasattr(VerifyFlags, flag):
raise ValueError(f"Invalid ssl verify flag: {flag}")
verify_flags.append(getattr(VerifyFlags, flag))
return verify_flags


URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyType(
{
"db": int,
Expand All @@ -963,6 +1002,8 @@ def to_bool(value) -> Optional[bool]:
"max_connections": int,
"health_check_interval": int,
"ssl_check_hostname": to_bool,
"ssl_include_verify_flags": parse_ssl_verify_flags,
"ssl_exclude_verify_flags": parse_ssl_verify_flags,
"timeout": float,
}
)
Expand Down Expand Up @@ -1021,6 +1062,7 @@ def parse_url(url: str) -> ConnectKwargs:

if parsed.scheme == "rediss":
kwargs["connection_class"] = SSLConnection

else:
valid_schemes = "redis://, rediss://, unix://"
raise ValueError(
Expand Down
4 changes: 4 additions & 0 deletions redis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,8 @@ def __init__(
ssl_keyfile: Optional[str] = None,
ssl_certfile: Optional[str] = None,
ssl_cert_reqs: Union[str, "ssl.VerifyMode"] = "required",
ssl_include_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
ssl_exclude_verify_flags: Optional[List["ssl.VerifyFlags"]] = None,
ssl_ca_certs: Optional[str] = None,
ssl_ca_path: Optional[str] = None,
ssl_ca_data: Optional[str] = None,
Expand Down Expand Up @@ -330,6 +332,8 @@ def __init__(
"ssl_keyfile": ssl_keyfile,
"ssl_certfile": ssl_certfile,
"ssl_cert_reqs": ssl_cert_reqs,
"ssl_include_verify_flags": ssl_include_verify_flags,
"ssl_exclude_verify_flags": ssl_exclude_verify_flags,
"ssl_ca_certs": ssl_ca_certs,
"ssl_ca_data": ssl_ca_data,
"ssl_check_hostname": ssl_check_hostname,
Expand Down
2 changes: 2 additions & 0 deletions redis/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,8 @@ def parse_cluster_myshardid(resp, **options):
"ssl_ca_data",
"ssl_certfile",
"ssl_cert_reqs",
"ssl_include_verify_flags",
"ssl_exclude_verify_flags",
"ssl_keyfile",
"ssl_password",
"ssl_check_hostname",
Expand Down
33 changes: 32 additions & 1 deletion redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,10 @@

if SSL_AVAILABLE:
import ssl
from ssl import VerifyFlags
else:
ssl = None
VerifyFlags = None

if HIREDIS_AVAILABLE:
import hiredis
Expand Down Expand Up @@ -1360,6 +1362,8 @@ def __init__(
ssl_keyfile=None,
ssl_certfile=None,
ssl_cert_reqs="required",
ssl_include_verify_flags: Optional[List["VerifyFlags"]] = None,
ssl_exclude_verify_flags: Optional[List["VerifyFlags"]] = None,
ssl_ca_certs=None,
ssl_ca_data=None,
ssl_check_hostname=True,
Expand All @@ -1378,7 +1382,10 @@ def __init__(
Args:
ssl_keyfile: Path to an ssl private key. Defaults to None.
ssl_certfile: Path to an ssl certificate. Defaults to None.
ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required), or an ssl.VerifyMode. Defaults to "required".
ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required),
or an ssl.VerifyMode. Defaults to "required".
ssl_include_verify_flags: A list of flags to be included in the SSLContext.verify_flags. Defaults to None.
ssl_exclude_verify_flags: A list of flags to be excluded from the SSLContext.verify_flags. Defaults to None.
ssl_ca_certs: The path to a file of concatenated CA certificates in PEM format. Defaults to None.
ssl_ca_data: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates.
ssl_check_hostname: If set, match the hostname during the SSL handshake. Defaults to True.
Expand Down Expand Up @@ -1414,6 +1421,8 @@ def __init__(
)
ssl_cert_reqs = CERT_REQS[ssl_cert_reqs]
self.cert_reqs = ssl_cert_reqs
self.ssl_include_verify_flags = ssl_include_verify_flags
self.ssl_exclude_verify_flags = ssl_exclude_verify_flags
self.ca_certs = ssl_ca_certs
self.ca_data = ssl_ca_data
self.ca_path = ssl_ca_path
Expand Down Expand Up @@ -1453,6 +1462,12 @@ def _wrap_socket_with_ssl(self, sock):
context = ssl.create_default_context()
context.check_hostname = self.check_hostname
context.verify_mode = self.cert_reqs
if self.ssl_include_verify_flags:
for flag in self.ssl_include_verify_flags:
context.verify_flags |= flag
if self.ssl_exclude_verify_flags:
for flag in self.ssl_exclude_verify_flags:
context.verify_flags &= ~flag
if self.certfile or self.keyfile:
context.load_cert_chain(
certfile=self.certfile,
Expand Down Expand Up @@ -1566,6 +1581,20 @@ def to_bool(value):
return bool(value)


def parse_ssl_verify_flags(value):
# flags are passed in as a string representation of a list,
# e.g. VERIFY_X509_STRICT, VERIFY_X509_PARTIAL_CHAIN
verify_flags_str = value.replace("[", "").replace("]", "")

verify_flags = []
for flag in verify_flags_str.split(","):
flag = flag.strip()
if not hasattr(VerifyFlags, flag):
raise ValueError(f"Invalid ssl verify flag: {flag}")
verify_flags.append(getattr(VerifyFlags, flag))
return verify_flags


URL_QUERY_ARGUMENT_PARSERS = {
"db": int,
"socket_timeout": float,
Expand All @@ -1576,6 +1605,8 @@ def to_bool(value):
"max_connections": int,
"health_check_interval": int,
"ssl_check_hostname": to_bool,
"ssl_include_verify_flags": parse_ssl_verify_flags,
"ssl_exclude_verify_flags": parse_ssl_verify_flags,
"timeout": float,
}

Expand Down
87 changes: 87 additions & 0 deletions tests/test_asyncio/test_ssl.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import ssl
import unittest.mock
from urllib.parse import urlparse
import pytest
import pytest_asyncio
Expand Down Expand Up @@ -54,3 +56,88 @@ async def test_cert_reqs_none_with_check_hostname(self, request):
assert conn.check_hostname is False
finally:
await r.aclose()

async def test_ssl_flags_applied_to_context(self, request):
"""
Test that ssl_include_verify_flags and ssl_exclude_verify_flags
are properly applied to the SSL context
"""
ssl_url = request.config.option.redis_ssl_url
parsed_url = urlparse(ssl_url)

# Test with specific SSL verify flags
ssl_include_verify_flags = [
ssl.VerifyFlags.VERIFY_CRL_CHECK_LEAF, # Disable strict verification
ssl.VerifyFlags.VERIFY_CRL_CHECK_CHAIN, # Enable partial chain
]

ssl_exclude_verify_flags = [
ssl.VerifyFlags.VERIFY_X509_STRICT, # Disable trusted first
]

r = redis.Redis(
host=parsed_url.hostname,
port=parsed_url.port,
ssl=True,
ssl_cert_reqs="none",
ssl_include_verify_flags=ssl_include_verify_flags,
ssl_exclude_verify_flags=ssl_exclude_verify_flags,
)

try:
# Get the connection to trigger SSL context creation
conn = r.connection_pool.make_connection()
assert isinstance(conn, redis.SSLConnection)

# Verify the flags were processed by checking they're stored in connection
assert conn.include_verify_flags is not None
assert len(conn.include_verify_flags) == 2

assert conn.exclude_verify_flags is not None
assert len(conn.exclude_verify_flags) == 1

# Check each flag individually
for flag in ssl_include_verify_flags:
assert flag in conn.include_verify_flags, (
f"Flag {flag} not found in stored ssl_include_verify_flags"
)
for flag in ssl_exclude_verify_flags:
assert flag in conn.exclude_verify_flags, (
f"Flag {flag} not found in stored ssl_exclude_verify_flags"
)

# Test the actual SSL context created by the connection's RedisSSLContext
# We need to mock the ssl.create_default_context to capture the context
captured_context = None
original_create_default_context = ssl.create_default_context

def capture_context_create_default():
nonlocal captured_context
captured_context = original_create_default_context()
return captured_context

with unittest.mock.patch(
"ssl.create_default_context", capture_context_create_default
):
# Trigger SSL context creation by calling get() on the RedisSSLContext
ssl_context = conn.ssl_context.get()

# Validate that we captured a context and it has the correct flags applied
assert captured_context is not None, "SSL context was not captured"
assert ssl_context is captured_context, (
"Returned context should be the captured one"
)

# Verify that VERIFY_X509_STRICT was disabled (bit cleared)
assert not (
captured_context.verify_flags & ssl.VerifyFlags.VERIFY_X509_STRICT
), "VERIFY_X509_STRICT should be disabled but is enabled"

# Verify that VERIFY_CRL_CHECK_CHAIN was enabled (bit set)
assert (
captured_context.verify_flags
& ssl.VerifyFlags.VERIFY_CRL_CHECK_CHAIN
), "VERIFY_CRL_CHECK_CHAIN should be enabled but is disabled"

finally:
await r.aclose()
Loading