Skip to content

Commit cfde0f2

Browse files
committed
Made sync lock consistent and added types to it
1 parent e5e265d commit cfde0f2

File tree

4 files changed

+76
-27
lines changed

4 files changed

+76
-27
lines changed

redis/client.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1089,6 +1089,7 @@ def lock(
10891089
name,
10901090
timeout=None,
10911091
sleep=0.1,
1092+
blocking=True,
10921093
blocking_timeout=None,
10931094
lock_class=None,
10941095
thread_local=True,
@@ -1104,6 +1105,13 @@ def lock(
11041105
when the lock is in blocking mode and another client is currently
11051106
holding the lock.
11061107
1108+
``blocking`` indicates whether calling ``acquire`` should block until
1109+
the lock has been acquired or to fail immediately, causing ``acquire``
1110+
to return False and the lock not being acquired. Defaults to True.
1111+
Note this value can be overridden by passing a ``blocking``
1112+
argument to ``acquire``.
1113+
1114+
11071115
``blocking_timeout`` indicates the maximum amount of time in seconds to
11081116
spend trying to acquire the lock. A value of ``None`` indicates
11091117
continue trying forever. ``blocking_timeout`` can be specified as a
@@ -1146,6 +1154,7 @@ def lock(
11461154
name,
11471155
timeout=timeout,
11481156
sleep=sleep,
1157+
blocking=blocking,
11491158
blocking_timeout=blocking_timeout,
11501159
thread_local=thread_local,
11511160
)

redis/lock.py

Lines changed: 46 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1+
from __future__ import annotations
2+
13
import threading
24
import time as mod_time
35
import uuid
4-
from types import SimpleNamespace
6+
from types import SimpleNamespace, TracebackType
7+
from typing import Optional, Type
58

69
from redis.exceptions import LockError, LockNotOwnedError
10+
from redis.typing import Number
711

812

913
class Lock:
@@ -73,13 +77,14 @@ class Lock:
7377

7478
def __init__(
7579
self,
76-
redis,
77-
name,
78-
timeout=None,
79-
sleep=0.1,
80-
blocking=True,
81-
blocking_timeout=None,
82-
thread_local=True,
80+
redis: "Redis",
81+
name: str,
82+
*,
83+
timeout: Optional[Number] = None,
84+
sleep: Number = 0.1,
85+
blocking: bool = True,
86+
blocking_timeout: Optional[Number] = None,
87+
thread_local: bool = True,
8388
):
8489
"""
8590
Create a new Lock instance named ``name`` using the Redis client
@@ -142,7 +147,7 @@ def __init__(
142147
self.local.token = None
143148
self.register_scripts()
144149

145-
def register_scripts(self):
150+
def register_scripts(self) -> None:
146151
cls = self.__class__
147152
client = self.redis
148153
if cls.lua_release is None:
@@ -152,15 +157,27 @@ def register_scripts(self):
152157
if cls.lua_reacquire is None:
153158
cls.lua_reacquire = client.register_script(cls.LUA_REACQUIRE_SCRIPT)
154159

155-
def __enter__(self):
160+
def __enter__(self) -> "Lock":
156161
if self.acquire():
157162
return self
158163
raise LockError("Unable to acquire lock within the time specified")
159164

160-
def __exit__(self, exc_type, exc_value, traceback):
165+
def __exit__(
166+
self,
167+
exc_type: Optional[Type[BaseException]],
168+
exc_value: Optional[BaseException],
169+
traceback: Optional[TracebackType]
170+
) -> None:
161171
self.release()
162172

163-
def acquire(self, blocking=None, blocking_timeout=None, token=None):
173+
def acquire(
174+
self,
175+
*,
176+
sleep: Optional[Number] = None,
177+
blocking: Optional[bool] = None,
178+
blocking_timeout: Optional[Number] = None,
179+
token: Optional[str] = None
180+
):
164181
"""
165182
Use Redis to hold a shared, distributed lock named ``name``.
166183
Returns True once the lock is acquired.
@@ -176,7 +193,8 @@ def acquire(self, blocking=None, blocking_timeout=None, token=None):
176193
object with the default encoding. If a token isn't specified, a UUID
177194
will be generated.
178195
"""
179-
sleep = self.sleep
196+
if sleep is None:
197+
sleep = self.sleep
180198
if token is None:
181199
token = uuid.uuid1().hex.encode()
182200
else:
@@ -200,7 +218,7 @@ def acquire(self, blocking=None, blocking_timeout=None, token=None):
200218
return False
201219
mod_time.sleep(sleep)
202220

203-
def do_acquire(self, token):
221+
def do_acquire(self, token: str) -> bool:
204222
if self.timeout:
205223
# convert to milliseconds
206224
timeout = int(self.timeout * 1000)
@@ -210,13 +228,13 @@ def do_acquire(self, token):
210228
return True
211229
return False
212230

213-
def locked(self):
231+
def locked(self) -> bool:
214232
"""
215233
Returns True if this key is locked by any process, otherwise False.
216234
"""
217235
return self.redis.get(self.name) is not None
218236

219-
def owned(self):
237+
def owned(self) -> bool:
220238
"""
221239
Returns True if this key is locked by this lock, otherwise False.
222240
"""
@@ -228,21 +246,23 @@ def owned(self):
228246
stored_token = encoder.encode(stored_token)
229247
return self.local.token is not None and stored_token == self.local.token
230248

231-
def release(self):
232-
"Releases the already acquired lock"
249+
def release(self) -> None:
250+
"""
251+
Releases the already acquired lock
252+
"""
233253
expected_token = self.local.token
234254
if expected_token is None:
235255
raise LockError("Cannot release an unlocked lock")
236256
self.local.token = None
237257
self.do_release(expected_token)
238258

239-
def do_release(self, expected_token):
259+
def do_release(self, expected_token: str) -> None:
240260
if not bool(
241261
self.lua_release(keys=[self.name], args=[expected_token], client=self.redis)
242262
):
243263
raise LockNotOwnedError("Cannot release a lock" " that's no longer owned")
244264

245-
def extend(self, additional_time, replace_ttl=False):
265+
def extend(self, additional_time: int, replace_ttl: bool = False) -> bool:
246266
"""
247267
Adds more time to an already acquired lock.
248268
@@ -259,19 +279,19 @@ def extend(self, additional_time, replace_ttl=False):
259279
raise LockError("Cannot extend a lock with no timeout")
260280
return self.do_extend(additional_time, replace_ttl)
261281

262-
def do_extend(self, additional_time, replace_ttl):
282+
def do_extend(self, additional_time: int, replace_ttl: bool) -> bool:
263283
additional_time = int(additional_time * 1000)
264284
if not bool(
265285
self.lua_extend(
266286
keys=[self.name],
267-
args=[self.local.token, additional_time, replace_ttl and "1" or "0"],
287+
args=[self.local.token, additional_time, "1" if replace_ttl else "0"],
268288
client=self.redis,
269289
)
270290
):
271-
raise LockNotOwnedError("Cannot extend a lock that's" " no longer owned")
291+
raise LockNotOwnedError("Cannot extend a lock that's no longer owned")
272292
return True
273293

274-
def reacquire(self):
294+
def reacquire(self) -> bool:
275295
"""
276296
Resets a TTL of an already acquired lock back to a timeout value.
277297
"""
@@ -281,12 +301,12 @@ def reacquire(self):
281301
raise LockError("Cannot reacquire a lock with no timeout")
282302
return self.do_reacquire()
283303

284-
def do_reacquire(self):
304+
def do_reacquire(self) -> bool:
285305
timeout = int(self.timeout * 1000)
286306
if not bool(
287307
self.lua_reacquire(
288308
keys=[self.name], args=[self.local.token, timeout], client=self.redis
289309
)
290310
):
291-
raise LockNotOwnedError("Cannot reacquire a lock that's" " no longer owned")
311+
raise LockNotOwnedError("Cannot reacquire a lock that's no longer owned")
292312
return True

redis/typing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from redis.connection import ConnectionPool
1111

1212

13+
Number = Union[int, float]
1314
EncodedT = Union[bytes, memoryview]
1415
DecodedT = Union[str, int, float]
1516
EncodableT = Union[EncodedT, DecodedT]
@@ -37,7 +38,6 @@
3738
AnyFieldT = TypeVar("AnyFieldT", bytes, str, memoryview)
3839
AnyChannelT = TypeVar("AnyChannelT", bytes, str, memoryview)
3940

40-
4141
class CommandsProtocol(Protocol):
4242
connection_pool: Union["AsyncConnectionPool", "ConnectionPool"]
4343

tests/test_lock.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,16 @@ def test_context_manager(self, r):
116116
assert r.get("foo") == lock.local.token
117117
assert r.get("foo") is None
118118

119+
def test_context_manager_blocking_timeout(self, r):
120+
with self.get_lock(r, "foo", blocking=False) as lock1:
121+
bt = 0.4
122+
sleep = 0.05
123+
lock2 = self.get_lock(r, "foo", sleep=sleep, blocking_timeout=bt)
124+
start = time.monotonic()
125+
assert not lock2.acquire()
126+
# The elapsed duration should be less than the total blocking_timeout
127+
assert bt > (time.monotonic() - start) > bt - sleep
128+
119129
def test_context_manager_raises_when_locked_not_acquired(self, r):
120130
r.set("foo", "bar")
121131
with pytest.raises(LockError):
@@ -221,6 +231,16 @@ def test_reacquiring_lock_no_longer_owned_raises_error(self, r):
221231
with pytest.raises(LockNotOwnedError):
222232
lock.reacquire()
223233

234+
def test_context_manager_reacquiring_lock_with_no_timeout_raises_error(self, r):
235+
with self.get_lock(r, "foo", timeout=None, blocking=False) as lock:
236+
with pytest.raises(LockError):
237+
lock.reacquire()
238+
239+
def test_context_manager_reacquiring_lock_no_longer_owned_raises_error(self, r):
240+
with pytest.raises(LockNotOwnedError):
241+
with self.get_lock(r, "foo", timeout=10, blocking=False):
242+
r.set("foo", "a")
243+
224244

225245
class TestLockClassSelection:
226246
def test_lock_class_argument(self, r):

0 commit comments

Comments
 (0)