2121# make sure TAuth is resolved in the docs, else they're pretty useless
2222
2323
24- import time
2524import typing as t
25+ import warnings
2626from logging import getLogger
2727
2828from .._async_compat .concurrency import AsyncLock
3131 expiring_auth_has_expired ,
3232 ExpiringAuth ,
3333)
34- from .._meta import preview
34+ from .._meta import (
35+ preview ,
36+ PreviewWarning ,
37+ )
3538
3639# work around for https://github.com/sphinx-doc/sphinx/pull/10880
3740# make sure TAuth is resolved in the docs, else they're pretty useless
3841# if t.TYPE_CHECKING:
3942from ..api import _TAuth
43+ from ..exceptions import Neo4jError
4044
4145
4246log = getLogger ("neo4j" )
@@ -51,21 +55,25 @@ def __init__(self, auth: _TAuth) -> None:
5155 async def get_auth (self ) -> _TAuth :
5256 return self ._auth
5357
54- async def on_auth_expired (self , auth : _TAuth ) -> None :
55- pass
58+ async def handle_security_exception (
59+ self , auth : _TAuth , error : Neo4jError
60+ ) -> bool :
61+ return False
5662
5763
58- class AsyncExpirationBasedAuthManager (AsyncAuthManager ):
64+ class Neo4jAuthTokenManager (AsyncAuthManager ):
5965 _current_auth : t .Optional [ExpiringAuth ]
6066 _provider : t .Callable [[], t .Awaitable [ExpiringAuth ]]
67+ _handled_codes : t .FrozenSet [str ]
6168 _lock : AsyncLock
6269
63-
6470 def __init__ (
6571 self ,
66- provider : t .Callable [[], t .Awaitable [ExpiringAuth ]]
72+ provider : t .Callable [[], t .Awaitable [ExpiringAuth ]],
73+ handled_codes : t .FrozenSet [str ]
6774 ) -> None :
6875 self ._provider = provider
76+ self ._handled_codes = handled_codes
6977 self ._current_auth = None
7078 self ._lock = AsyncLock ()
7179
@@ -81,18 +89,25 @@ async def get_auth(self) -> _TAuth:
8189 async with self ._lock :
8290 auth = self ._current_auth
8391 if auth is None or expiring_auth_has_expired (auth ):
84- log .debug ("[ ] _: <TEMPORAL AUTH> refreshing (time out)" )
92+ log .debug ("[ ] _: <AUTH MANAGER> refreshing (%s)" ,
93+ "init" if auth is None else "time out" )
8594 await self ._refresh_auth ()
8695 auth = self ._current_auth
8796 assert auth is not None
8897 return auth .auth
8998
90- async def on_auth_expired (self , auth : _TAuth ) -> None :
99+ async def handle_security_exception (
100+ self , auth : _TAuth , error : Neo4jError
101+ ) -> bool :
102+ if error .code not in self ._handled_codes :
103+ return False
91104 async with self ._lock :
92105 cur_auth = self ._current_auth
93106 if cur_auth is not None and cur_auth .auth == auth :
94- log .debug ("[ ] _: <TEMPORAL AUTH> refreshing (error)" )
107+ log .debug ("[ ] _: <AUTH MANAGER> refreshing (error %s)" ,
108+ error .code )
95109 await self ._refresh_auth ()
110+ return True
96111
97112
98113class AsyncAuthManagers :
@@ -103,6 +118,11 @@ class AsyncAuthManagers:
103118 See also https://github.com/neo4j/neo4j-python-driver/wiki/preview-features
104119
105120 .. versionadded:: 5.8
121+
122+ .. versionchanged:: 5.12
123+
124+ * Method ``expiration_based()`` was renamed to :meth:`bearer`.
125+ * Added :meth:`basic`.
106126 """
107127
108128 @staticmethod
@@ -139,10 +159,72 @@ def static(auth: _TAuth) -> AsyncAuthManager:
139159
140160 @staticmethod
141161 @preview ("Auth managers are a preview feature." )
142- def expiration_based (
162+ def basic (
163+ provider : t .Callable [[], t .Awaitable [_TAuth ]]
164+ ) -> AsyncAuthManager :
165+ """Create an auth manager handling basic auth password rotation.
166+
167+ .. warning::
168+
169+ The provider function **must not** interact with the driver in any
170+ way as this can cause deadlocks and undefined behaviour.
171+
172+ The provider function must only ever return auth information
173+ belonging to the same identity.
174+ Switching identities is undefined behavior.
175+ You may use session-level authentication for such use-cases
176+ :ref:`session-auth-ref`.
177+
178+ Example::
179+
180+ import neo4j
181+ from neo4j.auth_management import (
182+ AsyncAuthManagers,
183+ ExpiringAuth,
184+ )
185+
186+
187+ async def auth_provider():
188+ # some way of getting a token
189+ user, password = await get_current_auth()
190+ return (user, password)
191+
192+
193+ with neo4j.GraphDatabase.driver(
194+ "neo4j://example.com:7687",
195+ auth=AsyncAuthManagers.basic(auth_provider)
196+ ) as driver:
197+ ... # do stuff
198+
199+ :param provider:
200+ A callable that provides a :class:`.ExpiringAuth` instance.
201+
202+ :returns:
203+ An instance of an implementation of :class:`.AsyncAuthManager` that
204+ returns auth info from the given provider and refreshes it, calling
205+ the provider again, when the auth info expires (either because it's
206+ reached its expiry time or because the server flagged it as
207+ expired).
208+
209+ .. versionadded:: 5.12
210+ """
211+ handled_codes = frozenset (("Neo.ClientError.Security.Unauthorized" ,))
212+
213+ async def wrapped_provider () -> ExpiringAuth :
214+ with warnings .catch_warnings ():
215+ warnings .filterwarnings ("ignore" ,
216+ message = r"^Auth managers\b.*" ,
217+ category = PreviewWarning )
218+ return ExpiringAuth (await provider ())
219+
220+ return Neo4jAuthTokenManager (wrapped_provider , handled_codes )
221+
222+ @staticmethod
223+ @preview ("Auth managers are a preview feature." )
224+ def bearer (
143225 provider : t .Callable [[], t .Awaitable [ExpiringAuth ]]
144226 ) -> AsyncAuthManager :
145- """Create an auth manager for potentially expiring auth info .
227+ """Create an auth manager for potentially expiring bearer auth tokens .
146228
147229 .. warning::
148230
@@ -165,7 +247,7 @@ def expiration_based(
165247
166248
167249 async def auth_provider():
168- # some way to getting a token
250+ # some way of getting a token
169251 sso_token = await get_sso_token()
170252 # assume we know our tokens expire every 60 seconds
171253 expires_in = 60
@@ -180,7 +262,7 @@ async def auth_provider():
180262
181263 with neo4j.GraphDatabase.driver(
182264 "neo4j://example.com:7687",
183- auth=AsyncAuthManagers.temporal (auth_provider)
265+ auth=AsyncAuthManagers.bearer (auth_provider)
184266 ) as driver:
185267 ... # do stuff
186268
@@ -194,6 +276,10 @@ async def auth_provider():
194276 reached its expiry time or because the server flagged it as
195277 expired).
196278
197-
279+ .. versionadded:: 5.12
198280 """
199- return AsyncExpirationBasedAuthManager (provider )
281+ handled_codes = frozenset ((
282+ "Neo.ClientError.Security.TokenExpired" ,
283+ "Neo.ClientError.Security.Unauthorized" ,
284+ ))
285+ return Neo4jAuthTokenManager (provider , handled_codes )
0 commit comments