|
21 | 21 | import random |
22 | 22 | import socket |
23 | 23 | import sys |
24 | | -import time |
| 24 | +import time as time # noqa: PLC0414 # needed in sync version |
25 | 25 | from typing import ( |
26 | 26 | Any, |
27 | 27 | Callable, |
28 | 28 | TypeVar, |
29 | 29 | cast, |
30 | 30 | ) |
31 | 31 |
|
| 32 | +from pymongo import _csot |
32 | 33 | from pymongo.errors import ( |
33 | 34 | OperationFailure, |
34 | 35 | PyMongoError, |
35 | 36 | ) |
36 | 37 | from pymongo.helpers_shared import _REAUTHENTICATION_REQUIRED_CODE |
| 38 | +from pymongo.lock import _async_create_lock |
37 | 39 |
|
38 | 40 | _IS_SYNC = False |
39 | 41 |
|
@@ -78,34 +80,115 @@ async def inner(*args: Any, **kwargs: Any) -> Any: |
78 | 80 | _MAX_RETRIES = 3 |
79 | 81 | _BACKOFF_INITIAL = 0.05 |
80 | 82 | _BACKOFF_MAX = 10 |
81 | | -_TIME = time |
| 83 | +# DRIVERS-3240 will determine these defaults. |
| 84 | +DEFAULT_RETRY_TOKEN_CAPACITY = 1000.0 |
| 85 | +DEFAULT_RETRY_TOKEN_RETURN = 0.1 |
82 | 86 |
|
83 | 87 |
|
84 | | -async def _backoff( |
| 88 | +def _backoff( |
85 | 89 | attempt: int, initial_delay: float = _BACKOFF_INITIAL, max_delay: float = _BACKOFF_MAX |
86 | | -) -> None: |
| 90 | +) -> float: |
87 | 91 | jitter = random.random() # noqa: S311 |
88 | | - backoff = jitter * min(initial_delay * (2**attempt), max_delay) |
89 | | - await asyncio.sleep(backoff) |
| 92 | + return jitter * min(initial_delay * (2**attempt), max_delay) |
| 93 | + |
| 94 | + |
| 95 | +class _TokenBucket: |
| 96 | + """A token bucket implementation for rate limiting.""" |
| 97 | + |
| 98 | + def __init__( |
| 99 | + self, |
| 100 | + capacity: float = DEFAULT_RETRY_TOKEN_CAPACITY, |
| 101 | + return_rate: float = DEFAULT_RETRY_TOKEN_RETURN, |
| 102 | + ): |
| 103 | + self.lock = _async_create_lock() |
| 104 | + self.capacity = capacity |
| 105 | + # DRIVERS-3240 will determine how full the bucket should start. |
| 106 | + self.tokens = capacity |
| 107 | + self.return_rate = return_rate |
| 108 | + |
| 109 | + async def consume(self) -> bool: |
| 110 | + """Consume a token from the bucket if available.""" |
| 111 | + async with self.lock: |
| 112 | + if self.tokens >= 1: |
| 113 | + self.tokens -= 1 |
| 114 | + return True |
| 115 | + return False |
| 116 | + |
| 117 | + async def deposit(self, retry: bool = False) -> None: |
| 118 | + """Deposit a token back into the bucket.""" |
| 119 | + retry_token = 1 if retry else 0 |
| 120 | + async with self.lock: |
| 121 | + self.tokens = min(self.capacity, self.tokens + retry_token + self.return_rate) |
| 122 | + |
| 123 | + |
| 124 | +class _RetryPolicy: |
| 125 | + """A retry limiter that performs exponential backoff with jitter. |
| 126 | +
|
| 127 | + Retry attempts are limited by a token bucket to prevent overwhelming the server during |
| 128 | + a prolonged outage or high load. |
| 129 | + """ |
| 130 | + |
| 131 | + def __init__( |
| 132 | + self, |
| 133 | + token_bucket: _TokenBucket, |
| 134 | + attempts: int = _MAX_RETRIES, |
| 135 | + backoff_initial: float = _BACKOFF_INITIAL, |
| 136 | + backoff_max: float = _BACKOFF_MAX, |
| 137 | + ): |
| 138 | + self.token_bucket = token_bucket |
| 139 | + self.attempts = attempts |
| 140 | + self.backoff_initial = backoff_initial |
| 141 | + self.backoff_max = backoff_max |
| 142 | + |
| 143 | + async def record_success(self, retry: bool) -> None: |
| 144 | + """Record a successful operation.""" |
| 145 | + await self.token_bucket.deposit(retry) |
| 146 | + |
| 147 | + def backoff(self, attempt: int) -> float: |
| 148 | + """Return the backoff duration for the given .""" |
| 149 | + return _backoff(max(0, attempt - 1), self.backoff_initial, self.backoff_max) |
| 150 | + |
| 151 | + async def should_retry(self, attempt: int, delay: float) -> bool: |
| 152 | + """Return if we have budget to retry and how long to backoff.""" |
| 153 | + if attempt > self.attempts: |
| 154 | + return False |
| 155 | + |
| 156 | + # If the delay would exceed the deadline, bail early before consuming a token. |
| 157 | + if _csot.get_timeout(): |
| 158 | + if time.monotonic() + delay > _csot.get_deadline(): |
| 159 | + return False |
| 160 | + |
| 161 | + # Check token bucket last since we only want to consume a token if we actually retry. |
| 162 | + if not await self.token_bucket.consume(): |
| 163 | + # DRIVERS-3246 Improve diagnostics when this case happens. |
| 164 | + # We could add info to the exception and log. |
| 165 | + return False |
| 166 | + return True |
90 | 167 |
|
91 | 168 |
|
92 | 169 | def _retry_overload(func: F) -> F: |
93 | 170 | @functools.wraps(func) |
94 | | - async def inner(*args: Any, **kwargs: Any) -> Any: |
| 171 | + async def inner(self: Any, *args: Any, **kwargs: Any) -> Any: |
| 172 | + retry_policy = self._retry_policy |
95 | 173 | attempt = 0 |
96 | 174 | while True: |
97 | 175 | try: |
98 | | - return await func(*args, **kwargs) |
| 176 | + res = await func(self, *args, **kwargs) |
| 177 | + await retry_policy.record_success(retry=attempt > 0) |
| 178 | + return res |
99 | 179 | except PyMongoError as exc: |
100 | 180 | if not exc.has_error_label("Retryable"): |
101 | 181 | raise |
102 | 182 | attempt += 1 |
103 | | - if attempt > _MAX_RETRIES: |
| 183 | + delay = 0 |
| 184 | + if exc.has_error_label("SystemOverloaded"): |
| 185 | + delay = retry_policy.backoff(attempt) |
| 186 | + if not await retry_policy.should_retry(attempt, delay): |
104 | 187 | raise |
105 | 188 |
|
106 | 189 | # Implement exponential backoff on retry. |
107 | | - if exc.has_error_label("SystemOverloaded"): |
108 | | - await _backoff(attempt) |
| 190 | + if delay: |
| 191 | + await asyncio.sleep(delay) |
109 | 192 | continue |
110 | 193 |
|
111 | 194 | return cast(F, inner) |
|
0 commit comments