|
15 | 15 | from django.contrib.auth import authenticate, get_user_model
|
16 | 16 | from django.contrib.auth.hashers import check_password, identify_hasher
|
17 | 17 | from django.core.exceptions import ObjectDoesNotExist
|
18 |
| -from django.db import transaction |
| 18 | +from django.db import router, transaction |
19 | 19 | from django.http import HttpRequest
|
20 | 20 | from django.utils import dateformat, timezone
|
21 | 21 | from django.utils.crypto import constant_time_compare
|
@@ -567,11 +567,23 @@ def rotate_refresh_token(self, request):
|
567 | 567 | """
|
568 | 568 | return oauth2_settings.ROTATE_REFRESH_TOKEN
|
569 | 569 |
|
570 |
| - @transaction.atomic |
571 | 570 | def save_bearer_token(self, token, request, *args, **kwargs):
|
572 | 571 | """
|
573 |
| - Save access and refresh token, If refresh token is issued, remove or |
574 |
| - reuse old refresh token as in rfc:`6` |
| 572 | + Save access and refresh token. |
| 573 | +
|
| 574 | + Override _save_bearer_token and not this function when adding custom logic |
| 575 | + for the storing of these token. This allows the transaction logic to be |
| 576 | + separate from the token handling. |
| 577 | + """ |
| 578 | + # Use the AccessToken's database instead of making the assumption it is in 'default'. |
| 579 | + with transaction.atomic(using=router.db_for_write(AccessToken)): |
| 580 | + return self._save_bearer_token(token, request, *args, **kwargs) |
| 581 | + |
| 582 | + def _save_bearer_token(self, token, request, *args, **kwargs): |
| 583 | + """ |
| 584 | + Save access and refresh token. |
| 585 | +
|
| 586 | + If refresh token is issued, remove or reuse old refresh token as in rfc:`6`. |
575 | 587 |
|
576 | 588 | @see: https://rfc-editor.org/rfc/rfc6749.html#section-6
|
577 | 589 | """
|
@@ -793,7 +805,6 @@ def validate_refresh_token(self, refresh_token, client, request, *args, **kwargs
|
793 | 805 |
|
794 | 806 | return rt.application == client
|
795 | 807 |
|
796 |
| - @transaction.atomic |
797 | 808 | def _save_id_token(self, jti, request, expires, *args, **kwargs):
|
798 | 809 | scopes = request.scope or " ".join(request.scopes)
|
799 | 810 |
|
@@ -894,7 +905,9 @@ def finalize_id_token(self, id_token, token, token_handler, request):
|
894 | 905 | claims=json.dumps(id_token, default=str),
|
895 | 906 | )
|
896 | 907 | jwt_token.make_signed_token(request.client.jwk_key)
|
897 |
| - id_token = self._save_id_token(id_token["jti"], request, expiration_time) |
| 908 | + # Use the IDToken's database instead of making the assumption it is in 'default'. |
| 909 | + with transaction.atomic(using=router.db_for_write(IDToken)): |
| 910 | + id_token = self._save_id_token(id_token["jti"], request, expiration_time) |
898 | 911 | # this is needed by django rest framework
|
899 | 912 | request.access_token = id_token
|
900 | 913 | request.id_token = id_token
|
|
0 commit comments