From 6832e665102d31e4c209b80b9caf4ceb42572293 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Thu, 1 Aug 2024 14:03:32 +0200 Subject: [PATCH] Refine TLS certificate handling * Improve API docs. * Document missing classes from the mTLS feature in preview. * Improve wording and references. * Made `neo4j.auth_management.RotatingClientCertificateProvider` an abstract class. This means it can no longer be instantiated directly. Please use the provided factory method `neo4j.auth_management.RotatingClientCertificateProvider.rotating` instead. * Analogously for the async APIs. * Fix missing type hint for parameter of `TrustCustomCAs.__init__`. --- CHANGELOG.md | 7 +++- docs/source/api.rst | 9 +++++ docs/source/async_api.rst | 8 +++++ src/neo4j/_async/auth_management.py | 44 ++++++++++++++++------- src/neo4j/_auth_management.py | 33 ++++++++++++----- src/neo4j/_conf.py | 4 +-- src/neo4j/_sync/auth_management.py | 44 ++++++++++++++++------- tests/unit/async_/test_auth_management.py | 27 ++------------ tests/unit/sync/test_auth_management.py | 27 ++------------ 9 files changed, 117 insertions(+), 86 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 943a1529c..b02bb8599 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,7 +3,12 @@ See also https://github.com/neo4j/neo4j-python-driver/wiki for a full changelog. ## NEXT RELEASE -- No breaking or major changes. +- Made `neo4j.auth_management.RotatingClientCertificateProvider` and + `...AsyncRotatingClientCertificateProvider` (in preview) + abstract classes, meaning they can no longer be instantiated directly. + Please use the provided factory methods instead: + `neo4j.auth_management.RotatingClientCertificateProvider.rotating` and + `....AsyncRotatingClientCertificateProvider.rotating()` respectively. ## Version 5.23 diff --git a/docs/source/api.rst b/docs/source/api.rst index e44448f09..28be087a6 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -664,8 +664,17 @@ https://github.com/neo4j/neo4j-python-driver/wiki/preview-features .. versionadded:: 5.19 .. autoclass:: neo4j.auth_management.ClientCertificate + :members: .. autoclass:: neo4j.auth_management.ClientCertificateProvider + :members: + +.. autoclass:: neo4j.auth_management.ClientCertificateProviders + :members: + +.. autoclass:: neo4j.auth_management.RotatingClientCertificateProvider + :show-inheritance: + :members: .. _user-agent-ref: diff --git a/docs/source/async_api.rst b/docs/source/async_api.rst index 8c666caae..d84938b99 100644 --- a/docs/source/async_api.rst +++ b/docs/source/async_api.rst @@ -457,6 +457,14 @@ https://github.com/neo4j/neo4j-python-driver/wiki/preview-features .. versionadded:: 5.19 .. autoclass:: neo4j.auth_management.AsyncClientCertificateProvider + :members: + +.. autoclass:: neo4j.auth_management.AsyncClientCertificateProviders + :members: + +.. autoclass:: neo4j.auth_management.AsyncRotatingClientCertificateProvider + :show-inheritance: + :members: diff --git a/src/neo4j/_async/auth_management.py b/src/neo4j/_async/auth_management.py index be4d309ff..20ceed681 100644 --- a/src/neo4j/_async/auth_management.py +++ b/src/neo4j/_async/auth_management.py @@ -16,6 +16,7 @@ from __future__ import annotations +import abc import typing as t from logging import getLogger @@ -110,7 +111,8 @@ async def handle_security_exception( class AsyncAuthManagers: - """A collection of :class:`.AsyncAuthManager` factories. + """ + A collection of :class:`.AsyncAuthManager` factories. .. versionadded:: 5.8 @@ -124,7 +126,8 @@ class AsyncAuthManagers: @staticmethod def static(auth: _TAuth) -> AsyncAuthManager: - """Create a static auth manager. + """ + Create a static auth manager. The manager will always return the auth info provided at its creation. @@ -163,7 +166,8 @@ def static(auth: _TAuth) -> AsyncAuthManager: def basic( provider: t.Callable[[], t.Awaitable[_TAuth]] ) -> AsyncAuthManager: - """Create an auth manager handling basic auth password rotation. + """ + Create an auth manager handling basic auth password rotation. This factory wraps the provider function in an auth manager implementation that caches the provided auth info until the server @@ -230,7 +234,8 @@ async def wrapped_provider() -> ExpiringAuth: def bearer( provider: t.Callable[[], t.Awaitable[ExpiringAuth]] ) -> AsyncAuthManager: - """Create an auth manager for potentially expiring bearer auth tokens. + """ + Create an auth manager for potentially expiring bearer auth tokens. This factory wraps the provider function in an auth manager implementation that caches the provided auth info until either the @@ -310,10 +315,9 @@ async def get_certificate(self) -> t.Optional[ClientCertificate]: return cert -@preview("Mutual TLS is a preview feature.") class AsyncRotatingClientCertificateProvider(AsyncClientCertificateProvider): """ - Implementation of a certificate provider that can rotate certificates. + Abstract base class for certificate providers that can rotate certificates. The provider will make the driver use the initial certificate for all connections until the certificate is updated using the @@ -367,10 +371,26 @@ class AsyncRotatingClientCertificateProvider(AsyncClientCertificateProvider): # rotated again ... - :param initial_cert: The certificate to use initially. - .. versionadded:: 5.19 + + .. versionchanged:: 5.24 + + Turned this class into an abstract class to make the actual + implementation internal. This entails removing the possibility to + directly instantiate this class. Please use the factory method + :meth:`.AsyncClientCertificateProviders.rotating` instead. """ + + @abc.abstractmethod + async def update_certificate(self, cert: ClientCertificate) -> None: + """ + Update the certificate to use for new connections. + """ + + +class _AsyncNeo4jRotatingClientCertificateProvider( + AsyncRotatingClientCertificateProvider +): def __init__(self, initial_cert: ClientCertificate) -> None: self._cert: t.Optional[ClientCertificate] = initial_cert self._lock = AsyncCooperativeLock() @@ -381,15 +401,13 @@ async def get_certificate(self) -> t.Optional[ClientCertificate]: return cert async def update_certificate(self, cert: ClientCertificate) -> None: - """ - Update the certificate to use for new connections. - """ async with self._lock: self._cert = cert class AsyncClientCertificateProviders: - """A collection of :class:`.AsyncClientCertificateProvider` factories. + """ + A collection of :class:`.AsyncClientCertificateProvider` factories. **This is a preview** (see :ref:`filter-warnings-ref`). It might be changed without following the deprecation policy. @@ -419,4 +437,4 @@ def rotating( .. seealso:: :class:`.AsyncRotatingClientCertificateProvider` """ - return AsyncRotatingClientCertificateProvider(initial_cert) + return _AsyncNeo4jRotatingClientCertificateProvider(initial_cert) diff --git a/src/neo4j/_auth_management.py b/src/neo4j/_auth_management.py index b4411a042..e9708e986 100644 --- a/src/neo4j/_auth_management.py +++ b/src/neo4j/_auth_management.py @@ -37,7 +37,8 @@ @dataclass class ExpiringAuth: - """Represents potentially expiring authentication information. + """ + Represents potentially expiring authentication information. This class is used with :meth:`.AuthManagers.bearer` and :meth:`.AsyncAuthManagers.bearer`. @@ -68,7 +69,8 @@ class ExpiringAuth: expires_at: t.Optional[float] = None def expires_in(self, seconds: float) -> ExpiringAuth: - """Return a (flat) copy of this object with a new expiration time. + """ + Return a (flat) copy of this object with a new expiration time. This is a convenience method for creating an :class:`.ExpiringAuth` for a relative expiration time ("expires in" instead of "expires at"). @@ -96,7 +98,8 @@ def expiring_auth_has_expired(auth: ExpiringAuth) -> bool: class AuthManager(metaclass=abc.ABCMeta): - """Baseclass for authentication information managers. + """ + Abstract base class for authentication information managers. The driver provides some default implementations of this class in :class:`.AuthManagers` for convenience. @@ -132,7 +135,8 @@ class AuthManager(metaclass=abc.ABCMeta): @abc.abstractmethod def get_auth(self) -> _TAuth: - """Return the current authentication information. + """ + Return the current authentication information. The driver will call this method very frequently. It is recommended to implement some form of caching to avoid unnecessary overhead. @@ -151,7 +155,8 @@ def get_auth(self) -> _TAuth: def handle_security_exception( self, auth: _TAuth, error: Neo4jError ) -> bool: - """Handle the server indicating authentication failure. + """ + Handle the server indicating authentication failure. The driver will call this method when the server returns any `Neo.ClientError.Security.*` error. The error will then be processed @@ -174,7 +179,8 @@ def handle_security_exception( class AsyncAuthManager(_Protocol, metaclass=abc.ABCMeta): - """Async version of :class:`.AuthManager`. + """ + Async version of :class:`.AuthManager`. .. seealso:: :class:`.AuthManager` @@ -234,10 +240,14 @@ class ClientCertificate: class ClientCertificateProvider(_Protocol, metaclass=abc.ABCMeta): """ - Provides a client certificate to the driver for mutual TLS. + Interface for providing a client certificate to the driver for mutual TLS. + + This is an abstract base class (:class:`abc.ABC`) as well as a protocol + (:class:`typing.Protocol`). Meaning you can either inherit from it or just + implement all required method on a class to satisfy the type constraints. The package provides some default implementations of this class in - :class:`.AsyncClientCertificateProviders` for convenience. + :class:`.ClientCertificateProviders` for convenience. The driver will call :meth:`.get_certificate` to check if the client wants the driver to use as new certificate for mutual TLS. @@ -286,12 +296,17 @@ class AsyncClientCertificateProvider(_Protocol, metaclass=abc.ABCMeta): """ Async version of :class:`.ClientCertificateProvider`. + The package provides some default implementations of this class in + :class:`.AsyncClientCertificateProviders` for convenience. + **This is a preview** (see :ref:`filter-warnings-ref`). It might be changed without following the deprecation policy. See also https://github.com/neo4j/neo4j-python-driver/wiki/preview-features - .. seealso:: :class:`.ClientCertificateProvider` + .. seealso:: + :class:`.ClientCertificateProvider`, + :class:`.AsyncClientCertificateProviders` .. versionadded:: 5.19 """ diff --git a/src/neo4j/_conf.py b/src/neo4j/_conf.py index dde0e0afa..1818211a2 100644 --- a/src/neo4j/_conf.py +++ b/src/neo4j/_conf.py @@ -102,7 +102,7 @@ class TrustCustomCAs(TrustStore): authority at the specified paths. This option is primarily intended for self-signed and custom certificates. - :param certificates (str): paths to the certificates to trust. + :param certificates: paths to the certificates to trust. Those are not the certificates you expect to see from the server but the CA certificates you expect to be used to sign the server's certificate. @@ -118,7 +118,7 @@ class TrustCustomCAs(TrustStore): ) ) """ - def __init__(self, *certificates): + def __init__(self, *certificates: str): self.certs = certificates diff --git a/src/neo4j/_sync/auth_management.py b/src/neo4j/_sync/auth_management.py index 665fbdb07..534c42337 100644 --- a/src/neo4j/_sync/auth_management.py +++ b/src/neo4j/_sync/auth_management.py @@ -16,6 +16,7 @@ from __future__ import annotations +import abc import typing as t from logging import getLogger @@ -110,7 +111,8 @@ def handle_security_exception( class AuthManagers: - """A collection of :class:`.AuthManager` factories. + """ + A collection of :class:`.AuthManager` factories. .. versionadded:: 5.8 @@ -124,7 +126,8 @@ class AuthManagers: @staticmethod def static(auth: _TAuth) -> AuthManager: - """Create a static auth manager. + """ + Create a static auth manager. The manager will always return the auth info provided at its creation. @@ -163,7 +166,8 @@ def static(auth: _TAuth) -> AuthManager: def basic( provider: t.Callable[[], t.Union[_TAuth]] ) -> AuthManager: - """Create an auth manager handling basic auth password rotation. + """ + Create an auth manager handling basic auth password rotation. This factory wraps the provider function in an auth manager implementation that caches the provided auth info until the server @@ -230,7 +234,8 @@ def wrapped_provider() -> ExpiringAuth: def bearer( provider: t.Callable[[], t.Union[ExpiringAuth]] ) -> AuthManager: - """Create an auth manager for potentially expiring bearer auth tokens. + """ + Create an auth manager for potentially expiring bearer auth tokens. This factory wraps the provider function in an auth manager implementation that caches the provided auth info until either the @@ -310,10 +315,9 @@ def get_certificate(self) -> t.Optional[ClientCertificate]: return cert -@preview("Mutual TLS is a preview feature.") class RotatingClientCertificateProvider(ClientCertificateProvider): """ - Implementation of a certificate provider that can rotate certificates. + Abstract base class for certificate providers that can rotate certificates. The provider will make the driver use the initial certificate for all connections until the certificate is updated using the @@ -367,10 +371,26 @@ class RotatingClientCertificateProvider(ClientCertificateProvider): # rotated again ... - :param initial_cert: The certificate to use initially. - .. versionadded:: 5.19 + + .. versionchanged:: 5.24 + + Turned this class into an abstract class to make the actual + implementation internal. This entails removing the possibility to + directly instantiate this class. Please use the factory method + :meth:`.ClientCertificateProviders.rotating` instead. """ + + @abc.abstractmethod + def update_certificate(self, cert: ClientCertificate) -> None: + """ + Update the certificate to use for new connections. + """ + + +class _Neo4jRotatingClientCertificateProvider( + RotatingClientCertificateProvider +): def __init__(self, initial_cert: ClientCertificate) -> None: self._cert: t.Optional[ClientCertificate] = initial_cert self._lock = CooperativeLock() @@ -381,15 +401,13 @@ def get_certificate(self) -> t.Optional[ClientCertificate]: return cert def update_certificate(self, cert: ClientCertificate) -> None: - """ - Update the certificate to use for new connections. - """ with self._lock: self._cert = cert class ClientCertificateProviders: - """A collection of :class:`.ClientCertificateProvider` factories. + """ + A collection of :class:`.ClientCertificateProvider` factories. **This is a preview** (see :ref:`filter-warnings-ref`). It might be changed without following the deprecation policy. @@ -419,4 +437,4 @@ def rotating( .. seealso:: :class:`.RotatingClientCertificateProvider` """ - return RotatingClientCertificateProvider(initial_cert) + return _Neo4jRotatingClientCertificateProvider(initial_cert) diff --git a/tests/unit/async_/test_auth_management.py b/tests/unit/async_/test_auth_management.py index ac68e58b1..3e36b8259 100644 --- a/tests/unit/async_/test_auth_management.py +++ b/tests/unit/async_/test_auth_management.py @@ -260,12 +260,6 @@ def static_cert_provider(*args, **kwargs): return AsyncClientCertificateProviders.static(*args, **kwargs) -@copy_signature(AsyncRotatingClientCertificateProvider) -def rotating_cert_provider_direct(*args, **kwargs): - with pytest.warns(PreviewWarning, match="Mutual TLS"): - return AsyncRotatingClientCertificateProvider(*args, **kwargs) - - @copy_signature(AsyncClientCertificateProviders.rotating) def rotating_cert_provider(*args, **kwargs): with pytest.warns(PreviewWarning, match="Mutual TLS"): @@ -285,15 +279,6 @@ async def test_static_client_cert_provider(client_cert_factory) -> None: if t.TYPE_CHECKING: # Tests for type checker only. No need to run the test. - async def test_rotating_client_cert_provider_type_init( - client_cert_factory - ) -> None: - cert1: ClientCertificate = client_cert_factory() - provider: AsyncRotatingClientCertificateProvider = \ - rotating_cert_provider_direct(cert1) - _: AsyncClientCertificateProvider = provider - - async def test_rotating_client_cert_provider_type_factory( client_cert_factory ) -> None: @@ -303,19 +288,13 @@ async def test_rotating_client_cert_provider_type_factory( _: AsyncClientCertificateProvider = provider -@pytest.mark.parametrize( - "factory", (rotating_cert_provider, rotating_cert_provider_direct) -) @mark_async_test -async def test_rotating_client_cert_provider( - factory: t.Callable[[ClientCertificate], - AsyncRotatingClientCertificateProvider], - client_cert_factory -) -> None: +async def test_rotating_client_cert_provider(client_cert_factory) -> None: cert1: ClientCertificate = client_cert_factory() cert2: ClientCertificate = client_cert_factory() assert cert1 is not cert2 # sanity check - provider: AsyncRotatingClientCertificateProvider = factory(cert1) + provider: AsyncRotatingClientCertificateProvider = \ + rotating_cert_provider(cert1) assert await provider.get_certificate() is cert1 for _ in range(10): diff --git a/tests/unit/sync/test_auth_management.py b/tests/unit/sync/test_auth_management.py index 598877daf..6d2d36752 100644 --- a/tests/unit/sync/test_auth_management.py +++ b/tests/unit/sync/test_auth_management.py @@ -260,12 +260,6 @@ def static_cert_provider(*args, **kwargs): return ClientCertificateProviders.static(*args, **kwargs) -@copy_signature(RotatingClientCertificateProvider) -def rotating_cert_provider_direct(*args, **kwargs): - with pytest.warns(PreviewWarning, match="Mutual TLS"): - return RotatingClientCertificateProvider(*args, **kwargs) - - @copy_signature(ClientCertificateProviders.rotating) def rotating_cert_provider(*args, **kwargs): with pytest.warns(PreviewWarning, match="Mutual TLS"): @@ -285,15 +279,6 @@ def test_static_client_cert_provider(client_cert_factory) -> None: if t.TYPE_CHECKING: # Tests for type checker only. No need to run the test. - def test_rotating_client_cert_provider_type_init( - client_cert_factory - ) -> None: - cert1: ClientCertificate = client_cert_factory() - provider: RotatingClientCertificateProvider = \ - rotating_cert_provider_direct(cert1) - _: ClientCertificateProvider = provider - - def test_rotating_client_cert_provider_type_factory( client_cert_factory ) -> None: @@ -303,19 +288,13 @@ def test_rotating_client_cert_provider_type_factory( _: ClientCertificateProvider = provider -@pytest.mark.parametrize( - "factory", (rotating_cert_provider, rotating_cert_provider_direct) -) @mark_sync_test -def test_rotating_client_cert_provider( - factory: t.Callable[[ClientCertificate], - RotatingClientCertificateProvider], - client_cert_factory -) -> None: +def test_rotating_client_cert_provider(client_cert_factory) -> None: cert1: ClientCertificate = client_cert_factory() cert2: ClientCertificate = client_cert_factory() assert cert1 is not cert2 # sanity check - provider: RotatingClientCertificateProvider = factory(cert1) + provider: RotatingClientCertificateProvider = \ + rotating_cert_provider(cert1) assert provider.get_certificate() is cert1 for _ in range(10):