|
15 | 15 | """Test Client Backpressure spec.""" |
16 | 16 | from __future__ import annotations |
17 | 17 |
|
| 18 | +import asyncio |
18 | 19 | import sys |
19 | 20 |
|
| 21 | +import pymongo |
| 22 | + |
20 | 23 | sys.path[0:0] = [""] |
21 | 24 |
|
22 | 25 | from test.asynchronous import ( |
@@ -187,31 +190,41 @@ async def test_retry_policy(self): |
187 | 190 | self.assertEqual(retry_policy.backoff_initial, helpers._BACKOFF_INITIAL) |
188 | 191 | self.assertEqual(retry_policy.backoff_max, helpers._BACKOFF_MAX) |
189 | 192 | for i in range(1, helpers._MAX_RETRIES + 1): |
190 | | - self.assertTrue(await retry_policy.should_retry(i)) |
191 | | - self.assertFalse(await retry_policy.should_retry(helpers._MAX_RETRIES + 1)) |
| 193 | + self.assertTrue(await retry_policy.should_retry(i, 0)) |
| 194 | + self.assertFalse(await retry_policy.should_retry(helpers._MAX_RETRIES + 1, 0)) |
192 | 195 | for i in range(capacity - helpers._MAX_RETRIES): |
193 | | - self.assertTrue(await retry_policy.should_retry(1)) |
| 196 | + self.assertTrue(await retry_policy.should_retry(1, 0)) |
194 | 197 | # No tokens left, should not retry. |
195 | | - self.assertFalse(await retry_policy.should_retry(1)) |
| 198 | + self.assertFalse(await retry_policy.should_retry(1, 0)) |
196 | 199 | self.assertEqual(retry_policy.token_bucket.tokens, 0) |
197 | 200 |
|
198 | 201 | # record_success should generate tokens. |
199 | 202 | for _ in range(int(2 / helpers.DEFAULT_RETRY_TOKEN_RETURN)): |
200 | 203 | await retry_policy.record_success(retry=False) |
201 | 204 | self.assertAlmostEqual(retry_policy.token_bucket.tokens, 2) |
202 | 205 | for i in range(2): |
203 | | - self.assertTrue(await retry_policy.should_retry(1)) |
204 | | - self.assertFalse(await retry_policy.should_retry(1)) |
| 206 | + self.assertTrue(await retry_policy.should_retry(1, 0)) |
| 207 | + self.assertFalse(await retry_policy.should_retry(1, 0)) |
205 | 208 |
|
206 | 209 | # Recording a successful retry should return 1 additional token. |
207 | 210 | await retry_policy.record_success(retry=True) |
208 | 211 | self.assertAlmostEqual( |
209 | 212 | retry_policy.token_bucket.tokens, 1 + helpers.DEFAULT_RETRY_TOKEN_RETURN |
210 | 213 | ) |
211 | | - self.assertTrue(await retry_policy.should_retry(1)) |
212 | | - self.assertFalse(await retry_policy.should_retry(1)) |
| 214 | + self.assertTrue(await retry_policy.should_retry(1, 0)) |
| 215 | + self.assertFalse(await retry_policy.should_retry(1, 0)) |
213 | 216 | self.assertAlmostEqual(retry_policy.token_bucket.tokens, helpers.DEFAULT_RETRY_TOKEN_RETURN) |
214 | 217 |
|
| 218 | + async def test_retry_policy_csot(self): |
| 219 | + retry_policy = _RetryPolicy(_TokenBucket()) |
| 220 | + self.assertTrue(await retry_policy.should_retry(1, 0.5)) |
| 221 | + with pymongo.timeout(0.5): |
| 222 | + self.assertTrue(await retry_policy.should_retry(1, 0)) |
| 223 | + self.assertTrue(await retry_policy.should_retry(1, 0.1)) |
| 224 | + # Would exceed the timeout, should not retry. |
| 225 | + self.assertFalse(await retry_policy.should_retry(1, 1.0)) |
| 226 | + self.assertTrue(await retry_policy.should_retry(1, 1.0)) |
| 227 | + |
215 | 228 |
|
216 | 229 | if __name__ == "__main__": |
217 | 230 | unittest.main() |
0 commit comments