From 3bde632d5722f1f85ffcd8277504955321f00fff Mon Sep 17 00:00:00 2001 From: Alan Crosswell Date: Tue, 6 Oct 2020 15:07:37 -0400 Subject: [PATCH 01/53] Revert "Openid Connect Core support - Round 2 (#859)" (#877) This reverts commit 4655c030be15616ba6e0872253a2c15a897d9701. --- .gitignore | 2 +- oauth2_provider/admin.py | 9 +- oauth2_provider/forms.py | 1 - .../migrations/0002_auto_20190406_1805.py | 2 + .../migrations/0003_auto_20200902_2022.py | 48 - oauth2_provider/models.py | 114 -- oauth2_provider/oauth2_backends.py | 26 +- oauth2_provider/oauth2_validators.py | 376 +---- oauth2_provider/settings.py | 62 +- oauth2_provider/urls.py | 9 +- oauth2_provider/views/__init__.py | 16 +- oauth2_provider/views/application.py | 4 +- oauth2_provider/views/base.py | 74 +- oauth2_provider/views/introspect.py | 2 +- oauth2_provider/views/mixins.py | 31 +- oauth2_provider/views/oidc.py | 95 -- setup.cfg | 1 - tests/migrations/0001_initial.py | 7 +- tests/settings.py | 27 - tests/test_application_views.py | 1 - tests/test_authorization_code.py | 682 ++------- tests/test_hybrid.py | 1264 ----------------- tests/test_implicit.py | 198 +-- tests/test_oauth2_backends.py | 4 +- tests/test_oauth2_validators.py | 7 - tests/test_oidc_views.py | 77 - tests/urls.py | 8 +- tox.ini | 9 +- 28 files changed, 259 insertions(+), 2897 deletions(-) delete mode 100644 oauth2_provider/migrations/0003_auto_20200902_2022.py delete mode 100644 oauth2_provider/views/oidc.py delete mode 100644 tests/test_hybrid.py delete mode 100644 tests/test_oidc_views.py diff --git a/.gitignore b/.gitignore index c22ef00fa..af644d1e3 100644 --- a/.gitignore +++ b/.gitignore @@ -25,7 +25,7 @@ __pycache__ pip-log.txt # Unit test / coverage reports -.pytest_cache +.cache .coverage .tox .pytest_cache/ diff --git a/oauth2_provider/admin.py b/oauth2_provider/admin.py index a8d69e623..8b963d981 100644 --- a/oauth2_provider/admin.py +++ b/oauth2_provider/admin.py @@ -2,7 +2,7 @@ from .models import ( get_access_token_model, get_application_model, - get_grant_model, get_id_token_model, get_refresh_token_model + get_grant_model, get_refresh_token_model ) @@ -26,11 +26,6 @@ class AccessTokenAdmin(admin.ModelAdmin): raw_id_fields = ("user", "source_refresh_token") -class IDTokenAdmin(admin.ModelAdmin): - list_display = ("token", "user", "application", "expires") - raw_id_fields = ("user", ) - - class RefreshTokenAdmin(admin.ModelAdmin): list_display = ("token", "user", "application") raw_id_fields = ("user", "access_token") @@ -39,11 +34,9 @@ class RefreshTokenAdmin(admin.ModelAdmin): Application = get_application_model() Grant = get_grant_model() AccessToken = get_access_token_model() -IDToken = get_id_token_model() RefreshToken = get_refresh_token_model() admin.site.register(Application, ApplicationAdmin) admin.site.register(Grant, GrantAdmin) admin.site.register(AccessToken, AccessTokenAdmin) -admin.site.register(IDToken, IDTokenAdmin) admin.site.register(RefreshToken, RefreshTokenAdmin) diff --git a/oauth2_provider/forms.py b/oauth2_provider/forms.py index 41129c449..2e465959a 100644 --- a/oauth2_provider/forms.py +++ b/oauth2_provider/forms.py @@ -5,7 +5,6 @@ class AllowForm(forms.Form): allow = forms.BooleanField(required=False) redirect_uri = forms.CharField(widget=forms.HiddenInput()) scope = forms.CharField(widget=forms.HiddenInput()) - nonce = forms.CharField(required=False, widget=forms.HiddenInput()) client_id = forms.CharField(widget=forms.HiddenInput()) state = forms.CharField(required=False, widget=forms.HiddenInput()) response_type = forms.CharField(widget=forms.HiddenInput()) diff --git a/oauth2_provider/migrations/0002_auto_20190406_1805.py b/oauth2_provider/migrations/0002_auto_20190406_1805.py index bcacc23ce..8ca177abf 100644 --- a/oauth2_provider/migrations/0002_auto_20190406_1805.py +++ b/oauth2_provider/migrations/0002_auto_20190406_1805.py @@ -1,3 +1,5 @@ +# Generated by Django 2.2 on 2019-04-06 18:05 + from django.db import migrations, models diff --git a/oauth2_provider/migrations/0003_auto_20200902_2022.py b/oauth2_provider/migrations/0003_auto_20200902_2022.py deleted file mode 100644 index 684949c9d..000000000 --- a/oauth2_provider/migrations/0003_auto_20200902_2022.py +++ /dev/null @@ -1,48 +0,0 @@ -from django.conf import settings -from django.db import migrations, models -import django.db.models.deletion - -from oauth2_provider.settings import oauth2_settings - - -class Migration(migrations.Migration): - - dependencies = [ - migrations.swappable_dependency(settings.AUTH_USER_MODEL), - ('oauth2_provider', '0002_auto_20190406_1805'), - ] - - operations = [ - migrations.AddField( - model_name='application', - name='algorithm', - field=models.CharField(choices=[('RS256', 'RSA with SHA-2 256'), ('HS256', 'HMAC with SHA-2 256')], default='RS256', max_length=5), - ), - migrations.AlterField( - model_name='application', - name='authorization_grant_type', - field=models.CharField(choices=[('authorization-code', 'Authorization code'), ('implicit', 'Implicit'), ('password', 'Resource owner password-based'), ('client-credentials', 'Client credentials'), ('openid-hybrid', 'OpenID connect hybrid')], max_length=32), - ), - migrations.CreateModel( - name='IDToken', - fields=[ - ('id', models.BigAutoField(primary_key=True, serialize=False)), - ('token', models.TextField(unique=True)), - ('expires', models.DateTimeField()), - ('scope', models.TextField(blank=True)), - ('created', models.DateTimeField(auto_now_add=True)), - ('updated', models.DateTimeField(auto_now=True)), - ('application', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to=oauth2_settings.APPLICATION_MODEL)), - ('user', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='oauth2_provider_idtoken', to=settings.AUTH_USER_MODEL)), - ], - options={ - 'abstract': False, - 'swappable': 'OAUTH2_PROVIDER_ID_TOKEN_MODEL', - }, - ), - migrations.AddField( - model_name='accesstoken', - name='id_token', - field=models.OneToOneField(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='access_token', to=oauth2_settings.ID_TOKEN_MODEL), - ), - ] diff --git a/oauth2_provider/models.py b/oauth2_provider/models.py index 7135192db..5676bc0c5 100644 --- a/oauth2_provider/models.py +++ b/oauth2_provider/models.py @@ -1,4 +1,3 @@ -import json import logging from datetime import timedelta from urllib.parse import parse_qsl, urlparse @@ -10,7 +9,6 @@ from django.urls import reverse from django.utils import timezone from django.utils.translation import gettext_lazy as _ -from jwcrypto import jwk, jwt from .generators import generate_client_id, generate_client_secret from .scopes import get_scopes_backend @@ -52,20 +50,11 @@ class AbstractApplication(models.Model): GRANT_IMPLICIT = "implicit" GRANT_PASSWORD = "password" GRANT_CLIENT_CREDENTIALS = "client-credentials" - GRANT_OPENID_HYBRID = "openid-hybrid" GRANT_TYPES = ( (GRANT_AUTHORIZATION_CODE, _("Authorization code")), (GRANT_IMPLICIT, _("Implicit")), (GRANT_PASSWORD, _("Resource owner password-based")), (GRANT_CLIENT_CREDENTIALS, _("Client credentials")), - (GRANT_OPENID_HYBRID, _("OpenID connect hybrid")), - ) - - RS256_ALGORITHM = "RS256" - HS256_ALGORITHM = "HS256" - ALGORITHM_TYPES = ( - (RS256_ALGORITHM, _("RSA with SHA-2 256")), - (HS256_ALGORITHM, _("HMAC with SHA-2 256")), ) id = models.BigAutoField(primary_key=True) @@ -93,7 +82,6 @@ class AbstractApplication(models.Model): created = models.DateTimeField(auto_now_add=True) updated = models.DateTimeField(auto_now=True) - algorithm = models.CharField(max_length=5, choices=ALGORITHM_TYPES, default=RS256_ALGORITHM) class Meta: abstract = True @@ -294,10 +282,6 @@ class AbstractAccessToken(models.Model): related_name="refreshed_access_token" ) token = models.CharField(max_length=255, unique=True, ) - id_token = models.OneToOneField( - oauth2_settings.ID_TOKEN_MODEL, on_delete=models.CASCADE, blank=True, null=True, - related_name="access_token" - ) application = models.ForeignKey( oauth2_settings.APPLICATION_MODEL, on_delete=models.CASCADE, blank=True, null=True, ) @@ -431,99 +415,6 @@ class Meta(AbstractRefreshToken.Meta): swappable = "OAUTH2_PROVIDER_REFRESH_TOKEN_MODEL" -class AbstractIDToken(models.Model): - """ - An IDToken instance represents the actual token to - access user's resources, as in :openid:`2`. - - Fields: - - * :attr:`user` The Django user representing resources' owner - * :attr:`token` ID token - * :attr:`application` Application instance - * :attr:`expires` Date and time of token expiration, in DateTime format - * :attr:`scope` Allowed scopes - """ - id = models.BigAutoField(primary_key=True) - user = models.ForeignKey( - settings.AUTH_USER_MODEL, on_delete=models.CASCADE, blank=True, null=True, - related_name="%(app_label)s_%(class)s" - ) - token = models.TextField(unique=True) - application = models.ForeignKey( - oauth2_settings.APPLICATION_MODEL, on_delete=models.CASCADE, blank=True, null=True, - ) - expires = models.DateTimeField() - scope = models.TextField(blank=True) - - created = models.DateTimeField(auto_now_add=True) - updated = models.DateTimeField(auto_now=True) - - def is_valid(self, scopes=None): - """ - Checks if the access token is valid. - - :param scopes: An iterable containing the scopes to check or None - """ - return not self.is_expired() and self.allow_scopes(scopes) - - def is_expired(self): - """ - Check token expiration with timezone awareness - """ - if not self.expires: - return True - - return timezone.now() >= self.expires - - def allow_scopes(self, scopes): - """ - Check if the token allows the provided scopes - - :param scopes: An iterable containing the scopes to check - """ - if not scopes: - return True - - provided_scopes = set(self.scope.split()) - resource_scopes = set(scopes) - - return resource_scopes.issubset(provided_scopes) - - def revoke(self): - """ - Convenience method to uniform tokens' interface, for now - simply remove this token from the database in order to revoke it. - """ - self.delete() - - @property - def scopes(self): - """ - Returns a dictionary of allowed scope names (as keys) with their descriptions (as values) - """ - all_scopes = get_scopes_backend().get_all_scopes() - token_scopes = self.scope.split() - return {name: desc for name, desc in all_scopes.items() if name in token_scopes} - - @property - def claims(self): - key = jwk.JWK.from_pem(oauth2_settings.OIDC_RSA_PRIVATE_KEY.encode("utf8")) - jwt_token = jwt.JWT(key=key, jwt=self.token) - return json.loads(jwt_token.claims) - - def __str__(self): - return self.token - - class Meta: - abstract = True - - -class IDToken(AbstractIDToken): - class Meta(AbstractIDToken.Meta): - swappable = "OAUTH2_PROVIDER_ID_TOKEN_MODEL" - - def get_application_model(): """ Return the Application model that is active in this project. """ return apps.get_model(oauth2_settings.APPLICATION_MODEL) @@ -539,11 +430,6 @@ def get_access_token_model(): return apps.get_model(oauth2_settings.ACCESS_TOKEN_MODEL) -def get_id_token_model(): - """ Return the AccessToken model that is active in this project. """ - return apps.get_model(oauth2_settings.ID_TOKEN_MODEL) - - def get_refresh_token_model(): """ Return the RefreshToken model that is active in this project. """ return apps.get_model(oauth2_settings.REFRESH_TOKEN_MODEL) diff --git a/oauth2_provider/oauth2_backends.py b/oauth2_provider/oauth2_backends.py index 404add70e..6d8e68a2c 100644 --- a/oauth2_provider/oauth2_backends.py +++ b/oauth2_provider/oauth2_backends.py @@ -104,7 +104,7 @@ def validate_authorization_request(self, request): except oauth2.OAuth2Error as error: raise OAuthToolkitError(error=error) - def create_authorization_response(self, uri, request, scopes, credentials, body, allow): + def create_authorization_response(self, request, scopes, credentials, allow): """ A wrapper method that calls create_authorization_response on `server_class` instance. @@ -112,8 +112,7 @@ def create_authorization_response(self, uri, request, scopes, credentials, body, :param request: The current django.http.HttpRequest object :param scopes: A list of provided scopes :param credentials: Authorization credentials dictionary containing - `client_id`, `state`, `redirect_uri` and `response_type` - :param body: Other body parameters not used in credentials dictionary + `client_id`, `state`, `redirect_uri`, `response_type` :param allow: True if the user authorize the client, otherwise False """ try: @@ -125,10 +124,10 @@ def create_authorization_response(self, uri, request, scopes, credentials, body, credentials["user"] = request.user headers, body, status = self.server.create_authorization_response( - uri=uri, scopes=scopes, credentials=credentials, body=body) - redirect_uri = headers.get("Location", None) + uri=credentials["redirect_uri"], scopes=scopes, credentials=credentials) + uri = headers.get("Location", None) - return redirect_uri, headers, body, status + return uri, headers, body, status except oauth2.FatalClientError as error: raise FatalClientError( @@ -167,21 +166,6 @@ def create_revocation_response(self, request): return uri, headers, body, status - def create_userinfo_response(self, request): - """ - A wrapper method that calls create_userinfo_response on a - `server_class` instance. - - :param request: The current django.http.HttpRequest object - """ - uri, http_method, body, headers = self._extract_params(request) - headers, body, status = self.server.create_userinfo_response( - uri, http_method, body, headers - ) - uri = headers.get("Location", None) - - return uri, headers, body, status - def verify_request(self, request, scopes): """ A wrapper method that calls verify_request on `server_class` instance. diff --git a/oauth2_provider/oauth2_validators.py b/oauth2_provider/oauth2_validators.py index e7fb860b3..515353d6f 100644 --- a/oauth2_provider/oauth2_validators.py +++ b/oauth2_provider/oauth2_validators.py @@ -1,8 +1,6 @@ import base64 import binascii -import hashlib import http.client -import json import logging from collections import OrderedDict from datetime import datetime, timedelta @@ -14,21 +12,15 @@ from django.core.exceptions import ObjectDoesNotExist from django.db import transaction from django.db.models import Q -from django.http import HttpRequest -from django.urls import reverse -from django.utils import dateformat, timezone +from django.utils import timezone from django.utils.timezone import make_aware from django.utils.translation import gettext_lazy as _ -from jwcrypto import jwk, jwt -from jwcrypto.common import JWException -from jwcrypto.jwt import JWTExpired from oauthlib.oauth2 import RequestValidator -from oauthlib.oauth2.rfc6749 import utils from .exceptions import FatalClientError from .models import ( - AbstractApplication, get_access_token_model, get_application_model, - get_grant_model, get_id_token_model, get_refresh_token_model + AbstractApplication, get_access_token_model, + get_application_model, get_grant_model, get_refresh_token_model ) from .scopes import get_scopes_backend from .settings import oauth2_settings @@ -37,23 +29,18 @@ log = logging.getLogger("oauth2_provider") GRANT_TYPE_MAPPING = { - "authorization_code": ( - AbstractApplication.GRANT_AUTHORIZATION_CODE, - AbstractApplication.GRANT_OPENID_HYBRID, - ), - "password": (AbstractApplication.GRANT_PASSWORD,), - "client_credentials": (AbstractApplication.GRANT_CLIENT_CREDENTIALS,), + "authorization_code": (AbstractApplication.GRANT_AUTHORIZATION_CODE, ), + "password": (AbstractApplication.GRANT_PASSWORD, ), + "client_credentials": (AbstractApplication.GRANT_CLIENT_CREDENTIALS, ), "refresh_token": ( AbstractApplication.GRANT_AUTHORIZATION_CODE, AbstractApplication.GRANT_PASSWORD, AbstractApplication.GRANT_CLIENT_CREDENTIALS, - AbstractApplication.GRANT_OPENID_HYBRID, - ), + ) } Application = get_application_model() AccessToken = get_access_token_model() -IDToken = get_id_token_model() Grant = get_grant_model() RefreshToken = get_refresh_token_model() UserModel = get_user_model() @@ -106,15 +93,12 @@ def _authenticate_basic_auth(self, request): except UnicodeDecodeError: log.debug( "Failed basic auth: %r can't be decoded as unicode by %r", - auth_string, - encoding, + auth_string, encoding ) return False try: - client_id, client_secret = map( - unquote_plus, auth_string_decoded.split(":", 1) - ) + client_id, client_secret = map(unquote_plus, auth_string_decoded.split(":", 1)) except ValueError: log.debug("Failed basic auth, Invalid base64 encoding.") return False @@ -163,54 +147,35 @@ def _load_application(self, client_id, request): """ # we want to be sure that request has the client attribute! - assert hasattr( - request, "client" - ), '"request" instance has no "client" attribute' + assert hasattr(request, "client"), '"request" instance has no "client" attribute' try: - request.client = request.client or Application.objects.get( - client_id=client_id - ) + request.client = request.client or Application.objects.get(client_id=client_id) # Check that the application can be used (defaults to always True) if not request.client.is_usable(request): - log.debug( - "Failed body authentication: Application %r is disabled" - % (client_id) - ) + log.debug("Failed body authentication: Application %r is disabled" % (client_id)) return None return request.client except Application.DoesNotExist: - log.debug( - "Failed body authentication: Application %r does not exist" - % (client_id) - ) + log.debug("Failed body authentication: Application %r does not exist" % (client_id)) return None def _set_oauth2_error_on_request(self, request, access_token, scopes): if access_token is None: - error = OrderedDict( - [ - ("error", "invalid_token",), - ("error_description", _("The access token is invalid."),), - ] - ) + error = OrderedDict([ + ("error", "invalid_token", ), + ("error_description", _("The access token is invalid."), ), + ]) elif access_token.is_expired(): - error = OrderedDict( - [ - ("error", "invalid_token",), - ("error_description", _("The access token has expired."),), - ] - ) + error = OrderedDict([ + ("error", "invalid_token", ), + ("error_description", _("The access token has expired."), ), + ]) elif not access_token.allow_scopes(scopes): - error = OrderedDict( - [ - ("error", "insufficient_scope",), - ( - "error_description", - _("The access token is valid but does not have enough scope."), - ), - ] - ) + error = OrderedDict([ + ("error", "insufficient_scope", ), + ("error_description", _("The access token is valid but does not have enough scope."), ), + ]) else: log.warning("OAuth2 access token is invalid for an unknown reason.") error = OrderedDict([ @@ -276,15 +241,11 @@ def authenticate_client_id(self, client_id, request, *args, **kwargs): proceed only if the client exists and is not of type "Confidential". """ if self._load_application(client_id, request) is not None: - log.debug( - "Application %r has type %r" % (client_id, request.client.client_type) - ) + log.debug("Application %r has type %r" % (client_id, request.client.client_type)) return request.client.client_type != AbstractApplication.CLIENT_CONFIDENTIAL return False - def confirm_redirect_uri( - self, client_id, code, redirect_uri, client, *args, **kwargs - ): + def confirm_redirect_uri(self, client_id, code, redirect_uri, client, *args, **kwargs): """ Ensure the redirect_uri is listed in the Application instance redirect_uris field """ @@ -309,7 +270,7 @@ def get_default_redirect_uri(self, client_id, request, *args, **kwargs): return request.client.default_redirect_uri def _get_token_from_authentication_server( - self, token, introspection_url, introspection_token, introspection_credentials + self, token, introspection_url, introspection_token, introspection_credentials ): """Use external introspection endpoint to "crack open" the token. :param introspection_url: introspection endpoint URL @@ -337,12 +298,11 @@ def _get_token_from_authentication_server( try: response = requests.post( - introspection_url, data={"token": token}, headers=headers + introspection_url, + data={"token": token}, headers=headers ) except requests.exceptions.RequestException: - log.exception( - "Introspection: Failed POST to %r in token lookup", introspection_url - ) + log.exception("Introspection: Failed POST to %r in token lookup", introspection_url) return None # Log an exception when response from auth server is not successful @@ -388,8 +348,7 @@ def _get_token_from_authentication_server( "application": None, "scope": scope, "expires": expires, - }, - ) + }) return access_token @@ -402,14 +361,10 @@ def validate_bearer_token(self, token, scopes, request): introspection_url = oauth2_settings.RESOURCE_SERVER_INTROSPECTION_URL introspection_token = oauth2_settings.RESOURCE_SERVER_AUTH_TOKEN - introspection_credentials = ( - oauth2_settings.RESOURCE_SERVER_INTROSPECTION_CREDENTIALS - ) + introspection_credentials = oauth2_settings.RESOURCE_SERVER_INTROSPECTION_CREDENTIALS try: - access_token = AccessToken.objects.select_related( - "application", "user" - ).get(token=token) + access_token = AccessToken.objects.select_related("application", "user").get(token=token) except AccessToken.DoesNotExist: access_token = None @@ -420,7 +375,7 @@ def validate_bearer_token(self, token, scopes, request): token, introspection_url, introspection_token, - introspection_credentials, + introspection_credentials ) if access_token and access_token.is_valid(scopes): @@ -447,38 +402,22 @@ def validate_code(self, client_id, code, client, request, *args, **kwargs): except Grant.DoesNotExist: return False - def validate_grant_type( - self, client_id, grant_type, client, request, *args, **kwargs - ): + def validate_grant_type(self, client_id, grant_type, client, request, *args, **kwargs): """ Validate both grant_type is a valid string and grant_type is allowed for current workflow """ - assert grant_type in GRANT_TYPE_MAPPING # mapping misconfiguration + assert(grant_type in GRANT_TYPE_MAPPING) # mapping misconfiguration return request.client.allows_grant_type(*GRANT_TYPE_MAPPING[grant_type]) - def validate_response_type( - self, client_id, response_type, client, request, *args, **kwargs - ): + def validate_response_type(self, client_id, response_type, client, request, *args, **kwargs): """ We currently do not support the Authorization Endpoint Response Types registry as in rfc:`8.4`, so validate the response_type only if it matches "code" or "token" """ if response_type == "code": - return client.allows_grant_type( - AbstractApplication.GRANT_AUTHORIZATION_CODE - ) + return client.allows_grant_type(AbstractApplication.GRANT_AUTHORIZATION_CODE) elif response_type == "token": return client.allows_grant_type(AbstractApplication.GRANT_IMPLICIT) - elif response_type == "id_token": - return client.allows_grant_type(AbstractApplication.GRANT_IMPLICIT) - elif response_type == "id_token token": - return client.allows_grant_type(AbstractApplication.GRANT_IMPLICIT) - elif response_type == "code id_token": - return client.allows_grant_type(AbstractApplication.GRANT_OPENID_HYBRID) - elif response_type == "code token": - return client.allows_grant_type(AbstractApplication.GRANT_OPENID_HYBRID) - elif response_type == "code id_token token": - return client.allows_grant_type(AbstractApplication.GRANT_OPENID_HYBRID) else: return False @@ -486,15 +425,11 @@ def validate_scopes(self, client_id, scopes, client, request, *args, **kwargs): """ Ensure required scopes are permitted (as specified in the settings file) """ - available_scopes = get_scopes_backend().get_available_scopes( - application=client, request=request - ) + available_scopes = get_scopes_backend().get_available_scopes(application=client, request=request) return set(scopes).issubset(set(available_scopes)) def get_default_scopes(self, client_id, request, *args, **kwargs): - default_scopes = get_scopes_backend().get_default_scopes( - application=request.client, request=request - ) + default_scopes = get_scopes_backend().get_default_scopes(application=request.client, request=request) return default_scopes def validate_redirect_uri(self, client_id, redirect_uri, request, *args, **kwargs): @@ -522,24 +457,6 @@ def get_code_challenge_method(self, code, request): def save_authorization_code(self, client_id, code, request, *args, **kwargs): self._create_authorization_code(request, code) - def get_authorization_code_scopes(self, client_id, code, redirect_uri, request): - scopes = [] - fields = { - "code": code, - } - - if client_id: - fields["application__client_id"] = client_id - - if redirect_uri: - fields["redirect_uri"] = redirect_uri - - grant = Grant.objects.filter(**fields).values() - if grant.exists(): - grant_dict = dict(grant[0]) - scopes = utils.scope_to_list(grant_dict["scope"]) - return scopes - def rotate_refresh_token(self, request): """ Checks if rotate refresh token is enabled @@ -580,11 +497,9 @@ def save_bearer_token(self, token, request, *args, **kwargs): refresh_token_instance = getattr(request, "refresh_token_instance", None) # If we are to reuse tokens, and we can: do so - if ( - not self.rotate_refresh_token(request) - and isinstance(refresh_token_instance, RefreshToken) - and refresh_token_instance.access_token - ): + if not self.rotate_refresh_token(request) and \ + isinstance(refresh_token_instance, RefreshToken) and \ + refresh_token_instance.access_token: access_token = AccessToken.objects.select_for_update().get( pk=refresh_token_instance.access_token.pk @@ -631,18 +546,14 @@ def save_bearer_token(self, token, request, *args, **kwargs): source_refresh_token=refresh_token_instance, ) - self._create_refresh_token( - request, refresh_token_code, access_token - ) + self._create_refresh_token(request, refresh_token_code, access_token) else: # make sure that the token data we're returning matches # the existing token token["access_token"] = previous_access_token.token - token["refresh_token"] = ( - RefreshToken.objects.filter(access_token=previous_access_token) - .first() - .token - ) + token["refresh_token"] = RefreshToken.objects.filter( + access_token=previous_access_token + ).first().token token["scope"] = previous_access_token.scope # No refresh token should be created, just access token @@ -650,15 +561,11 @@ def save_bearer_token(self, token, request, *args, **kwargs): self._create_access_token(expires, request, token) def _create_access_token(self, expires, request, token, source_refresh_token=None): - id_token = token.get("id_token", None) - if id_token: - id_token = IDToken.objects.get(token=id_token) return AccessToken.objects.create( user=request.user, scope=token["scope"], expires=expires, token=token["access_token"], - id_token=id_token, application=request.client, source_refresh_token=source_refresh_token, ) @@ -683,7 +590,7 @@ def _create_refresh_token(self, request, refresh_token_code, access_token): user=request.user, token=refresh_token_code, application=request.client, - access_token=access_token, + access_token=access_token ) def revoke_token(self, token, token_type_hint, request, *args, **kwargs): @@ -736,8 +643,9 @@ def validate_refresh_token(self, refresh_token, client, request, *args, **kwargs """ null_or_recent = Q(revoked__isnull=True) | Q( - revoked__gt=timezone.now() - - timedelta(seconds=oauth2_settings.REFRESH_TOKEN_GRACE_PERIOD_SECONDS) + revoked__gt=timezone.now() - timedelta( + seconds=oauth2_settings.REFRESH_TOKEN_GRACE_PERIOD_SECONDS + ) ) rt = RefreshToken.objects.filter(null_or_recent, token=refresh_token).select_related( "access_token" @@ -751,183 +659,3 @@ def validate_refresh_token(self, refresh_token, client, request, *args, **kwargs # Temporary store RefreshToken instance to be reused by get_original_scopes and save_bearer_token. request.refresh_token_instance = rt return rt.application == client - - @transaction.atomic - def _save_id_token(self, token, request, expires, *args, **kwargs): - - scopes = request.scope or " ".join(request.scopes) - - if request.grant_type == "client_credentials": - request.user = None - - id_token = IDToken.objects.create( - user=request.user, - scope=scopes, - expires=expires, - token=token.serialize(), - application=request.client, - ) - return id_token - - def get_jwt_bearer_token(self, token, token_handler, request): - return self.get_id_token(token, token_handler, request) - - def get_oidc_claims(self, token, token_handler, request): - # Required OIDC claims - claims = { - "sub": str(request.user.id), - } - - # https://openid.net/specs/openid-connect-core-1_0.html#StandardClaims - claims.update(**self.get_additional_claims(request)) - - return claims - - def get_id_token_dictionary(self, token, token_handler, request): - # TODO: http://openid.net/specs/openid-connect-core-1_0.html#HybridIDToken2 - # Save the id_token on database bound to code when the request come to - # Authorization Endpoint and return the same one when request come to - # Token Endpoint - - # TODO: Check if at this point this request parameters are alredy validated - claims = self.get_oidc_claims(token, token_handler, request) - - expiration_time = timezone.now() + timedelta( - seconds=oauth2_settings.ID_TOKEN_EXPIRE_SECONDS - ) - # Required ID Token claims - claims.update(**{ - "iss": self.get_oidc_issuer_endpoint(request), - "aud": request.client_id, - "exp": int(dateformat.format(expiration_time, "U")), - "iat": int(dateformat.format(datetime.utcnow(), "U")), - "auth_time": int(dateformat.format(request.user.last_login, "U")), - }) - - nonce = getattr(request, "nonce", None) - if nonce: - claims["nonce"] = nonce - - # TODO: create a function to check if we should add at_hash - # http://openid.net/specs/openid-connect-core-1_0.html#CodeIDToken - # http://openid.net/specs/openid-connect-core-1_0.html#ImplicitIDToken - # if request.grant_type in 'authorization_code' and 'access_token' in token: - if ( - (request.grant_type == "authorization_code" and "access_token" in token) - or request.response_type == "code id_token token" - or (request.response_type == "id_token token" and "access_token" in token) - ): - acess_token = token["access_token"] - at_hash = self.generate_at_hash(acess_token) - claims["at_hash"] = at_hash - - # TODO: create a function to check if we should include c_hash - # http://openid.net/specs/openid-connect-core-1_0.html#HybridIDToken - if request.response_type in ("code id_token", "code id_token token"): - code = token["code"] - sha256 = hashlib.sha256(code.encode("ascii")) - bits256 = sha256.hexdigest()[:32] - c_hash = base64.urlsafe_b64encode(bits256.encode("ascii")) - claims["c_hash"] = c_hash.decode("utf8") - - return claims, expiration_time - - def get_oidc_issuer_endpoint(self, request): - if oauth2_settings.OIDC_ISS_ENDPOINT: - return oauth2_settings.OIDC_ISS_ENDPOINT - - # generate it based on known URL - django_request = HttpRequest() - django_request.META = request.headers - - abs_url = django_request.build_absolute_uri(reverse("oauth2_provider:oidc-connect-discovery-info")) - base_url = abs_url[:-len("/.well-known/openid-configuration/")] - return base_url - - def generate_at_hash(self, access_token): - sha256 = hashlib.sha256(access_token.encode("ascii")) - bits128 = sha256.digest()[:16] - at_hash = base64.urlsafe_b64encode(bits128).decode("utf8").rstrip("=") - return at_hash - - def get_id_token(self, token, token_handler, request): - key = jwk.JWK.from_pem(oauth2_settings.OIDC_RSA_PRIVATE_KEY.encode("utf8")) - - claims, expiration_time = self.get_id_token_dictionary(token, token_handler, request) - - jwt_token = jwt.JWT( - header=json.dumps({"alg": "RS256"}, default=str), - claims=json.dumps(claims, default=str), - ) - jwt_token.make_signed_token(key) - - id_token = self._save_id_token(jwt_token, request, expiration_time) - # this is needed by django rest framework - request.access_token = id_token - request.id_token = id_token - return jwt_token.serialize() - - def validate_jwt_bearer_token(self, token, scopes, request): - return self.validate_id_token(token, scopes, request) - - def validate_id_token(self, token, scopes, request): - """ - When users try to access resources, check that provided id_token is valid - """ - if not token: - return False - - key = jwk.JWK.from_pem(oauth2_settings.OIDC_RSA_PRIVATE_KEY.encode("utf8")) - - try: - jwt_token = jwt.JWT(key=key, jwt=token) - id_token = IDToken.objects.get(token=jwt_token.serialize()) - request.client = id_token.application - request.user = id_token.user - request.scopes = scopes - # this is needed by django rest framework - request.access_token = id_token - return True - except (JWException, JWTExpired): - # TODO: This is the base exception of all jwcrypto - return False - - return False - - def validate_user_match(self, id_token_hint, scopes, claims, request): - # TODO: Fix to validate when necessary acording - # https://github.com/idan/oauthlib/blob/master/oauthlib/oauth2/rfc6749/request_validator.py#L556 - # http://openid.net/specs/openid-connect-core-1_0.html#AuthRequest id_token_hint section - return True - - def get_authorization_code_nonce(self, client_id, code, redirect_uri, request): - """ Extracts nonce from saved authorization code. - If present in the Authentication Request, Authorization - Servers MUST include a nonce Claim in the ID Token with the - Claim Value being the nonce value sent in the Authentication - Request. Authorization Servers SHOULD perform no other - processing on nonce values used. The nonce value is a - case-sensitive string. - Only code param should be sufficient to retrieve grant code from - any storage you are using. However, `client_id` and `redirect_uri` - have been validated and can be used also. - :param client_id: Unicode client identifier - :param code: Unicode authorization code grant - :param redirect_uri: Unicode absolute URI - :return: Unicode nonce - Method is used by: - - Authorization Token Grant Dispatcher - """ - # TODO: Fix this ;) - return "" - - def get_userinfo_claims(self, request): - """ - Generates and saves a new JWT for this request, and returns it as the - current user's claims. - - """ - return self.get_oidc_claims(None, None, request) - - def get_additional_claims(self, request): - return {} diff --git a/oauth2_provider/settings.py b/oauth2_provider/settings.py index d3d60801e..0135da8b7 100644 --- a/oauth2_provider/settings.py +++ b/oauth2_provider/settings.py @@ -23,19 +23,10 @@ USER_SETTINGS = getattr(settings, "OAUTH2_PROVIDER", None) -APPLICATION_MODEL = getattr( - settings, "OAUTH2_PROVIDER_APPLICATION_MODEL", "oauth2_provider.Application" -) -ACCESS_TOKEN_MODEL = getattr( - settings, "OAUTH2_PROVIDER_ACCESS_TOKEN_MODEL", "oauth2_provider.AccessToken" -) -ID_TOKEN_MODEL = getattr( - settings, "OAUTH2_PROVIDER_ID_TOKEN_MODEL", "oauth2_provider.IDToken" -) +APPLICATION_MODEL = getattr(settings, "OAUTH2_PROVIDER_APPLICATION_MODEL", "oauth2_provider.Application") +ACCESS_TOKEN_MODEL = getattr(settings, "OAUTH2_PROVIDER_ACCESS_TOKEN_MODEL", "oauth2_provider.AccessToken") GRANT_MODEL = getattr(settings, "OAUTH2_PROVIDER_GRANT_MODEL", "oauth2_provider.Grant") -REFRESH_TOKEN_MODEL = getattr( - settings, "OAUTH2_PROVIDER_REFRESH_TOKEN_MODEL", "oauth2_provider.RefreshToken" -) +REFRESH_TOKEN_MODEL = getattr(settings, "OAUTH2_PROVIDER_REFRESH_TOKEN_MODEL", "oauth2_provider.RefreshToken") DEFAULTS = { "CLIENT_ID_GENERATOR_CLASS": "oauth2_provider.generators.ClientIdGenerator", @@ -44,7 +35,7 @@ "ACCESS_TOKEN_GENERATOR": None, "REFRESH_TOKEN_GENERATOR": None, "EXTRA_SERVER_KWARGS": {}, - "OAUTH2_SERVER_CLASS": "oauthlib.openid.connect.core.endpoints.pre_configured.Server", + "OAUTH2_SERVER_CLASS": "oauthlib.oauth2.Server", "OAUTH2_VALIDATOR_CLASS": "oauth2_provider.oauth2_validators.OAuth2Validator", "OAUTH2_BACKEND_CLASS": "oauth2_provider.oauth2_backends.OAuthLibCore", "SCOPES": {"read": "Reading scope", "write": "Writing scope"}, @@ -54,46 +45,29 @@ "WRITE_SCOPE": "write", "AUTHORIZATION_CODE_EXPIRE_SECONDS": 60, "ACCESS_TOKEN_EXPIRE_SECONDS": 36000, - "ID_TOKEN_EXPIRE_SECONDS": 36000, "REFRESH_TOKEN_EXPIRE_SECONDS": None, "REFRESH_TOKEN_GRACE_PERIOD_SECONDS": 0, "ROTATE_REFRESH_TOKEN": True, "ERROR_RESPONSE_WITH_SCOPES": False, "APPLICATION_MODEL": APPLICATION_MODEL, "ACCESS_TOKEN_MODEL": ACCESS_TOKEN_MODEL, - "ID_TOKEN_MODEL": ID_TOKEN_MODEL, "GRANT_MODEL": GRANT_MODEL, "REFRESH_TOKEN_MODEL": REFRESH_TOKEN_MODEL, "REQUEST_APPROVAL_PROMPT": "force", "ALLOWED_REDIRECT_URI_SCHEMES": ["http", "https"], - "OIDC_ISS_ENDPOINT": "", - "OIDC_USERINFO_ENDPOINT": "", - "OIDC_RSA_PRIVATE_KEY": "", - "OIDC_RESPONSE_TYPES_SUPPORTED": [ - "code", - "token", - "id_token", - "id_token token", - "code token", - "code id_token", - "code id_token token", - ], - "OIDC_SUBJECT_TYPES_SUPPORTED": ["public"], - "OIDC_ID_TOKEN_SIGNING_ALG_VALUES_SUPPORTED": ["RS256", "HS256"], - "OIDC_TOKEN_ENDPOINT_AUTH_METHODS_SUPPORTED": [ - "client_secret_post", - "client_secret_basic", - ], + # Special settings that will be evaluated at runtime "_SCOPES": [], "_DEFAULT_SCOPES": [], + # Resource Server with Token Introspection "RESOURCE_SERVER_INTROSPECTION_URL": None, "RESOURCE_SERVER_AUTH_TOKEN": None, "RESOURCE_SERVER_INTROSPECTION_CREDENTIALS": None, "RESOURCE_SERVER_TOKEN_CACHING_SECONDS": 36000, + # Whether or not PKCE is required - "PKCE_REQUIRED": False, + "PKCE_REQUIRED": False } # List of settings that cannot be empty @@ -105,11 +79,6 @@ "OAUTH2_BACKEND_CLASS", "SCOPES", "ALLOWED_REDIRECT_URI_SCHEMES", - "OIDC_RSA_PRIVATE_KEY", - "OIDC_RESPONSE_TYPES_SUPPORTED", - "OIDC_SUBJECT_TYPES_SUPPORTED", - "OIDC_ID_TOKEN_SIGNING_ALG_VALUES_SUPPORTED", - "OIDC_TOKEN_ENDPOINT_AUTH_METHODS_SUPPORTED", ) # List of settings that may be in string import notation. @@ -148,12 +117,7 @@ def import_from_string(val, setting_name): module = importlib.import_module(module_path) return getattr(module, class_name) except ImportError as e: - msg = "Could not import %r for setting %r. %s: %s." % ( - val, - setting_name, - e.__class__.__name__, - e, - ) + msg = "Could not import %r for setting %r. %s: %s." % (val, setting_name, e.__class__.__name__, e) raise ImportError(msg) @@ -165,9 +129,7 @@ class OAuth2ProviderSettings: and return the class, rather than the string literal. """ - def __init__( - self, user_settings=None, defaults=None, import_strings=None, mandatory=None - ): + def __init__(self, user_settings=None, defaults=None, import_strings=None, mandatory=None): self.user_settings = user_settings or {} self.defaults = defaults or {} self.import_strings = import_strings or () @@ -202,9 +164,7 @@ def __getattr__(self, attr): if scope in self._SCOPES: val.append(scope) else: - raise ImproperlyConfigured( - "Defined DEFAULT_SCOPES not present in SCOPES" - ) + raise ImproperlyConfigured("Defined DEFAULT_SCOPES not present in SCOPES") self.validate_setting(attr, val) diff --git a/oauth2_provider/urls.py b/oauth2_provider/urls.py index f2f04d853..4cf6d4c6d 100644 --- a/oauth2_provider/urls.py +++ b/oauth2_provider/urls.py @@ -27,12 +27,5 @@ name="authorized-token-delete"), ] -oidc_urlpatterns = [ - re_path(r"^\.well-known/openid-configuration/$", views.ConnectDiscoveryInfoView.as_view(), - name="oidc-connect-discovery-info"), - re_path(r"^jwks/$", views.JwksInfoView.as_view(), name="jwks-info"), - re_path(r"^userinfo/$", views.UserInfoView.as_view(), name="user-info") -] - -urlpatterns = base_urlpatterns + management_urlpatterns + oidc_urlpatterns +urlpatterns = base_urlpatterns + management_urlpatterns diff --git a/oauth2_provider/views/__init__.py b/oauth2_provider/views/__init__.py index 9f2ac4ff7..7636bd9c7 100644 --- a/oauth2_provider/views/__init__.py +++ b/oauth2_provider/views/__init__.py @@ -1,13 +1,9 @@ # flake8: noqa -from .application import ( - ApplicationDelete, ApplicationDetail, ApplicationList, - ApplicationRegistration, ApplicationUpdate -) -from .base import AuthorizationView, RevokeTokenView, TokenView +from .base import AuthorizationView, TokenView, RevokeTokenView +from .application import ApplicationRegistration, ApplicationDetail, ApplicationList, \ + ApplicationDelete, ApplicationUpdate from .generic import ( - ProtectedResourceView, ReadWriteScopedResourceView, - ScopedProtectedResourceView -) + ProtectedResourceView, ScopedProtectedResourceView, ReadWriteScopedResourceView, + ClientProtectedResourceView, ClientProtectedScopedResourceView) +from .token import AuthorizedTokensListView, AuthorizedTokenDeleteView from .introspect import IntrospectTokenView -from .oidc import ConnectDiscoveryInfoView, JwksInfoView, UserInfoView -from .token import AuthorizedTokenDeleteView, AuthorizedTokensListView diff --git a/oauth2_provider/views/application.py b/oauth2_provider/views/application.py index b38c907ab..c925493f5 100644 --- a/oauth2_provider/views/application.py +++ b/oauth2_provider/views/application.py @@ -32,7 +32,7 @@ def get_form_class(self): get_application_model(), fields=( "name", "client_id", "client_secret", "client_type", - "authorization_grant_type", "redirect_uris", "algorithm", + "authorization_grant_type", "redirect_uris" ) ) @@ -81,6 +81,6 @@ def get_form_class(self): get_application_model(), fields=( "name", "client_id", "client_secret", "client_type", - "authorization_grant_type", "redirect_uris", "algorithm", + "authorization_grant_type", "redirect_uris" ) ) diff --git a/oauth2_provider/views/base.py b/oauth2_provider/views/base.py index eb825c307..b9b6ed7f9 100644 --- a/oauth2_provider/views/base.py +++ b/oauth2_provider/views/base.py @@ -86,7 +86,6 @@ class AuthorizationView(BaseAuthorizationView, FormView): * Authorization code * Implicit grant """ - template_name = "oauth2_provider/authorize.html" form_class = AllowForm @@ -102,14 +101,11 @@ def get_initial(self): initial_data = { "redirect_uri": self.oauth2_data.get("redirect_uri", None), "scope": " ".join(scopes), - "nonce": self.oauth2_data.get("nonce", None), "client_id": self.oauth2_data.get("client_id", None), "state": self.oauth2_data.get("state", None), "response_type": self.oauth2_data.get("response_type", None), "code_challenge": self.oauth2_data.get("code_challenge", None), - "code_challenge_method": self.oauth2_data.get( - "code_challenge_method", None - ), + "code_challenge_method": self.oauth2_data.get("code_challenge_method", None), } return initial_data @@ -120,27 +116,18 @@ def form_valid(self, form): "client_id": form.cleaned_data.get("client_id"), "redirect_uri": form.cleaned_data.get("redirect_uri"), "response_type": form.cleaned_data.get("response_type", None), - "state": form.cleaned_data.get("state", None), + "state": form.cleaned_data.get("state", None) } if form.cleaned_data.get("code_challenge", False): credentials["code_challenge"] = form.cleaned_data.get("code_challenge") if form.cleaned_data.get("code_challenge_method", False): - credentials["code_challenge_method"] = form.cleaned_data.get( - "code_challenge_method" - ) - - body = {"nonce": form.cleaned_data.get("nonce")} + credentials["code_challenge_method"] = form.cleaned_data.get("code_challenge_method") scopes = form.cleaned_data.get("scope") allow = form.cleaned_data.get("allow") try: uri, headers, body, status = self.create_authorization_response( - self.request.get_raw_uri(), - request=self.request, - scopes=scopes, - credentials=credentials, - body=body, - allow=allow, + request=self.request, scopes=scopes, credentials=credentials, allow=allow ) except OAuthToolkitError as error: return self.error_response(error, application) @@ -162,21 +149,13 @@ def get(self, request, *args, **kwargs): # at this point we know an Application instance with such client_id exists in the database # TODO: Cache this! - application = get_application_model().objects.get( - client_id=credentials["client_id"] - ) - - uri_query = urllib.parse.urlparse(self.request.get_raw_uri()).query - uri_query_params = dict( - urllib.parse.parse_qsl(uri_query, keep_blank_values=True, strict_parsing=True) - ) + application = get_application_model().objects.get(client_id=credentials["client_id"]) kwargs["application"] = application kwargs["client_id"] = credentials["client_id"] kwargs["redirect_uri"] = credentials["redirect_uri"] kwargs["response_type"] = credentials["response_type"] kwargs["state"] = credentials["state"] - kwargs["nonce"] = uri_query_params.get("nonce", None) self.oauth2_data = kwargs # following two loc are here only because of https://code.djangoproject.com/ticket/17795 @@ -185,9 +164,7 @@ def get(self, request, *args, **kwargs): # Check to see if the user has already granted access and return # a successful response depending on "approval_prompt" url parameter - require_approval = request.GET.get( - "approval_prompt", oauth2_settings.REQUEST_APPROVAL_PROMPT - ) + require_approval = request.GET.get("approval_prompt", oauth2_settings.REQUEST_APPROVAL_PROMPT) try: # If skip_authorization field is True, skip the authorization screen even @@ -196,36 +173,26 @@ def get(self, request, *args, **kwargs): # are already approved. if application.skip_authorization: uri, headers, body, status = self.create_authorization_response( - self.request.get_raw_uri(), - request=self.request, - scopes=" ".join(scopes), - credentials=credentials, - allow=True, + request=self.request, scopes=" ".join(scopes), + credentials=credentials, allow=True ) return self.redirect(uri, application) elif require_approval == "auto": - tokens = ( - get_access_token_model() - .objects.filter( - user=request.user, - application=kwargs["application"], - expires__gt=timezone.now(), - ) - .all() - ) + tokens = get_access_token_model().objects.filter( + user=request.user, + application=kwargs["application"], + expires__gt=timezone.now() + ).all() # check past authorizations regarded the same scopes as the current one for token in tokens: if token.allow_scopes(scopes): uri, headers, body, status = self.create_authorization_response( - self.request.get_raw_uri(), - request=self.request, - scopes=" ".join(scopes), - credentials=credentials, - allow=True, + request=self.request, scopes=" ".join(scopes), + credentials=credentials, allow=True ) - return self.redirect(uri, application) + return self.redirect(uri, application, token) except OAuthToolkitError as error: return self.error_response(error, application) @@ -272,7 +239,6 @@ class TokenView(OAuthLibMixin, View): * Password * Client credentials """ - server_class = oauth2_settings.OAUTH2_SERVER_CLASS validator_class = oauth2_settings.OAUTH2_VALIDATOR_CLASS oauthlib_backend_class = oauth2_settings.OAUTH2_BACKEND_CLASS @@ -283,8 +249,11 @@ def post(self, request, *args, **kwargs): if status == 200: access_token = json.loads(body).get("access_token") if access_token is not None: - token = get_access_token_model().objects.get(token=access_token) - app_authorized.send(sender=self, request=request, token=token) + token = get_access_token_model().objects.get( + token=access_token) + app_authorized.send( + sender=self, request=request, + token=token) response = HttpResponse(content=body, status=status) for k, v in headers.items(): @@ -297,7 +266,6 @@ class RevokeTokenView(OAuthLibMixin, View): """ Implements an endpoint to revoke access or refresh tokens """ - server_class = oauth2_settings.OAUTH2_SERVER_CLASS validator_class = oauth2_settings.OAUTH2_VALIDATOR_CLASS oauthlib_backend_class = oauth2_settings.OAUTH2_BACKEND_CLASS diff --git a/oauth2_provider/views/introspect.py b/oauth2_provider/views/introspect.py index 460a1395d..7d4381179 100644 --- a/oauth2_provider/views/introspect.py +++ b/oauth2_provider/views/introspect.py @@ -7,7 +7,7 @@ from django.views.decorators.csrf import csrf_exempt from oauth2_provider.models import get_access_token_model -from oauth2_provider.views.generic import ClientProtectedScopedResourceView +from oauth2_provider.views import ClientProtectedScopedResourceView @method_decorator(csrf_exempt, name="dispatch") diff --git a/oauth2_provider/views/mixins.py b/oauth2_provider/views/mixins.py index 0b7e02c7a..b5d0d4145 100644 --- a/oauth2_provider/views/mixins.py +++ b/oauth2_provider/views/mixins.py @@ -97,7 +97,7 @@ def validate_authorization_request(self, request): core = self.get_oauthlib_core() return core.validate_authorization_request(request) - def create_authorization_response(self, uri, request, scopes, credentials, allow, body=None): + def create_authorization_response(self, request, scopes, credentials, allow): """ A wrapper method that calls create_authorization_response on `server_class` instance. @@ -105,15 +105,14 @@ def create_authorization_response(self, uri, request, scopes, credentials, allow :param request: The current django.http.HttpRequest object :param scopes: A space-separated string of provided scopes :param credentials: Authorization credentials dictionary containing - `client_id`, `state`, `redirect_uri` and `response_type` + `client_id`, `state`, `redirect_uri`, `response_type` :param allow: True if the user authorize the client, otherwise False - :param body: Other body parameters not used in credentials dictionary """ # TODO: move this scopes conversion from and to string into a utils function scopes = scopes.split(" ") if scopes else [] core = self.get_oauthlib_core() - return core.create_authorization_response(uri, request, scopes, credentials, body, allow) + return core.create_authorization_response(request, scopes, credentials, allow) def create_token_response(self, request): """ @@ -134,16 +133,6 @@ def create_revocation_response(self, request): core = self.get_oauthlib_core() return core.create_revocation_response(request) - def create_userinfo_response(self, request): - """ - A wrapper method that calls create_userinfo_response on the - `server_class` instance. - - :param request: The current django.http.HttpRequest object - """ - core = self.get_oauthlib_core() - return core.create_userinfo_response(request) - def verify_request(self, request): """ A wrapper method that calls verify_request on `server_class` instance. @@ -288,13 +277,11 @@ def dispatch(self, request, *args, **kwargs): if not valid: # Alternatively allow access tokens # check if the request is valid and the protected resource may be accessed - try: - valid, r = self.verify_request(request) - if valid: - request.resource_owner = r.user - return super().dispatch(request, *args, **kwargs) - except ValueError: - pass - return HttpResponseForbidden() + valid, r = self.verify_request(request) + if valid: + request.resource_owner = r.user + return super().dispatch(request, *args, **kwargs) + else: + return HttpResponseForbidden() else: return super().dispatch(request, *args, **kwargs) diff --git a/oauth2_provider/views/oidc.py b/oauth2_provider/views/oidc.py deleted file mode 100644 index d7ffe4670..000000000 --- a/oauth2_provider/views/oidc.py +++ /dev/null @@ -1,95 +0,0 @@ -from __future__ import absolute_import, unicode_literals - -import json - -from django.http import HttpResponse, JsonResponse -from django.urls import reverse, reverse_lazy -from django.utils.decorators import method_decorator -from django.views.decorators.csrf import csrf_exempt -from django.views.generic import View -from jwcrypto import jwk - -from ..settings import oauth2_settings -from .mixins import OAuthLibMixin - - -class ConnectDiscoveryInfoView(View): - """ - View used to show oidc provider configuration information - """ - def get(self, request, *args, **kwargs): - issuer_url = oauth2_settings.OIDC_ISS_ENDPOINT - - if not issuer_url: - abs_url = request.build_absolute_uri(reverse("oauth2_provider:oidc-connect-discovery-info")) - issuer_url = abs_url[:-len("/.well-known/openid-configuration/")] - - authorization_endpoint = request.build_absolute_uri(reverse("oauth2_provider:authorize")) - token_endpoint = request.build_absolute_uri(reverse("oauth2_provider:token")) - userinfo_endpoint = ( - oauth2_settings.OIDC_USERINFO_ENDPOINT or - request.build_absolute_uri(reverse("oauth2_provider:user-info")) - ) - jwks_uri = request.build_absolute_uri(reverse("oauth2_provider:jwks-info")) - else: - authorization_endpoint = "{}{}".format(issuer_url, reverse_lazy("oauth2_provider:authorize")) - token_endpoint = "{}{}".format(issuer_url, reverse_lazy("oauth2_provider:token")) - userinfo_endpoint = ( - oauth2_settings.OIDC_USERINFO_ENDPOINT or - "{}{}".format(issuer_url, reverse_lazy("oauth2_provider:user-info")) - ) - jwks_uri = "{}{}".format(issuer_url, reverse_lazy("oauth2_provider:jwks-info")) - - data = { - "issuer": issuer_url, - "authorization_endpoint": authorization_endpoint, - "token_endpoint": token_endpoint, - "userinfo_endpoint": userinfo_endpoint, - "jwks_uri": jwks_uri, - "response_types_supported": oauth2_settings.OIDC_RESPONSE_TYPES_SUPPORTED, - "subject_types_supported": oauth2_settings.OIDC_SUBJECT_TYPES_SUPPORTED, - "id_token_signing_alg_values_supported": - oauth2_settings.OIDC_ID_TOKEN_SIGNING_ALG_VALUES_SUPPORTED, - "token_endpoint_auth_methods_supported": - oauth2_settings.OIDC_TOKEN_ENDPOINT_AUTH_METHODS_SUPPORTED, - } - response = JsonResponse(data) - response["Access-Control-Allow-Origin"] = "*" - return response - - -class JwksInfoView(View): - """ - View used to show oidc json web key set document - """ - def get(self, request, *args, **kwargs): - key = jwk.JWK.from_pem(oauth2_settings.OIDC_RSA_PRIVATE_KEY.encode("utf8")) - data = { - "keys": [{ - "alg": "RS256", - "use": "sig", - "kid": key.thumbprint() - }] - } - data["keys"][0].update(json.loads(key.export_public())) - response = JsonResponse(data) - response["Access-Control-Allow-Origin"] = "*" - return response - - -@method_decorator(csrf_exempt, name="dispatch") -class UserInfoView(OAuthLibMixin, View): - """ - View used to show Claims about the authenticated End-User - """ - server_class = oauth2_settings.OAUTH2_SERVER_CLASS - validator_class = oauth2_settings.OAUTH2_VALIDATOR_CLASS - oauthlib_backend_class = oauth2_settings.OAUTH2_BACKEND_CLASS - - def get(self, request, *args, **kwargs): - url, headers, body, status = self.create_userinfo_response(request) - response = HttpResponse(content=body or "", status=status) - - for k, v in headers.items(): - response[k] = v - return response diff --git a/setup.cfg b/setup.cfg index fb060f88e..3c4e0badc 100644 --- a/setup.cfg +++ b/setup.cfg @@ -34,7 +34,6 @@ install_requires = django >= 2.1 requests >= 2.13.0 oauthlib >= 3.1.0 - jwcrypto >= 0.4.2 [options.packages.find] exclude = tests diff --git a/tests/migrations/0001_initial.py b/tests/migrations/0001_initial.py index eef6dbab5..60b17f2ae 100644 --- a/tests/migrations/0001_initial.py +++ b/tests/migrations/0001_initial.py @@ -45,7 +45,7 @@ class Migration(migrations.Migration): ('client_id', models.CharField(db_index=True, default=oauth2_provider.generators.generate_client_id, max_length=100, unique=True)), ('redirect_uris', models.TextField(blank=True, help_text='Allowed URIs list, space separated')), ('client_type', models.CharField(choices=[('confidential', 'Confidential'), ('public', 'Public')], max_length=32)), - ('authorization_grant_type', models.CharField(choices=[('authorization-code', 'Authorization code'), ('implicit', 'Implicit'), ('password', 'Resource owner password-based'), ('client-credentials', 'Client credentials'), ('openid-hybrid', 'OpenID connect hybrid')], max_length=32)), + ('authorization_grant_type', models.CharField(choices=[('authorization-code', 'Authorization code'), ('implicit', 'Implicit'), ('password', 'Resource owner password-based'), ('client-credentials', 'Client credentials')], max_length=32)), ('client_secret', models.CharField(blank=True, db_index=True, default=oauth2_provider.generators.generate_client_secret, max_length=255)), ('name', models.CharField(blank=True, max_length=255)), ('skip_authorization', models.BooleanField(default=False)), @@ -53,7 +53,6 @@ class Migration(migrations.Migration): ('updated', models.DateTimeField(auto_now=True)), ('custom_field', models.CharField(max_length=255)), ('user', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='tests_sampleapplication', to=settings.AUTH_USER_MODEL)), - ('algorithm', models.CharField(max_length=5, choices=[('RS256', 'RSA with SHA-2 256'), ('HS256', 'HMAC with SHA-2 256')], default='RS256')), ], options={ 'abstract': False, @@ -72,7 +71,6 @@ class Migration(migrations.Migration): ('application', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to=settings.OAUTH2_PROVIDER_APPLICATION_MODEL)), ('source_refresh_token', models.OneToOneField(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='s_refreshed_access_token', to=settings.OAUTH2_PROVIDER_REFRESH_TOKEN_MODEL)), ('user', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='tests_sampleaccesstoken', to=settings.AUTH_USER_MODEL)), - ('id_token', models.OneToOneField(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='access_token', to=settings.OAUTH2_PROVIDER_ID_TOKEN_MODEL)), ], options={ 'abstract': False, @@ -85,7 +83,7 @@ class Migration(migrations.Migration): ('client_id', models.CharField(db_index=True, default=oauth2_provider.generators.generate_client_id, max_length=100, unique=True)), ('redirect_uris', models.TextField(blank=True, help_text='Allowed URIs list, space separated')), ('client_type', models.CharField(choices=[('confidential', 'Confidential'), ('public', 'Public')], max_length=32)), - ('authorization_grant_type', models.CharField(choices=[('authorization-code', 'Authorization code'), ('implicit', 'Implicit'), ('password', 'Resource owner password-based'), ('client-credentials', 'Client credentials'), ('openid-hybrid', 'OpenID connect hybrid')], max_length=32)), + ('authorization_grant_type', models.CharField(choices=[('authorization-code', 'Authorization code'), ('implicit', 'Implicit'), ('password', 'Resource owner password-based'), ('client-credentials', 'Client credentials')], max_length=32)), ('client_secret', models.CharField(blank=True, db_index=True, default=oauth2_provider.generators.generate_client_secret, max_length=255)), ('name', models.CharField(blank=True, max_length=255)), ('skip_authorization', models.BooleanField(default=False)), @@ -93,7 +91,6 @@ class Migration(migrations.Migration): ('updated', models.DateTimeField(auto_now=True)), ('allowed_schemes', models.TextField(blank=True)), ('user', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='tests_basetestapplication', to=settings.AUTH_USER_MODEL)), - ('algorithm', models.CharField(max_length=5, choices=[('RS256', 'RSA with SHA-2 256'), ('HS256', 'HMAC with SHA-2 256')], default='RS256')), ], options={ 'abstract': False, diff --git a/tests/settings.py b/tests/settings.py index edd1ae679..40eef5ebd 100644 --- a/tests/settings.py +++ b/tests/settings.py @@ -130,30 +130,3 @@ }, } } - -OIDC_RSA_PRIVATE_KEY = """-----BEGIN RSA PRIVATE KEY----- -MIICXQIBAAKBgQCbCYh5h2NmQuBqVO6G+/CO+cHm9VBzsb0MeA6bbQfDnbhstVOT -j0hcnZJzDjYc6ajBZZf6gxVP9xrdm9Uh599VI3X5PFXLbMHrmzTAMzCGIyg+/fnP -0gocYxmCX2+XKyj/Zvt1pUX8VAN2AhrJSfxNDKUHERTVEV9bRBJg4F0C3wIDAQAB -AoGAP+i4nNw+Ec/8oWh8YSFm4xE6qKG0NdTtSMAOyWwy+KTB+vHuT1QPsLn1vj77 -+IQrX/moogg6F1oV9YdA3vat3U7rwt1sBGsRrLhA+Spp9WEQtglguNo4+QfVo2ju -YBa2rG+h75qjiA3xnU//F3rvwnAsOWv0NUVdVeguyR+u6okCQQDBUmgWeH2WHmUn -2nLNCz+9wj28rqhfOr9Ptem2gqk+ywJmuIr4Y5S1OdavOr2UZxOcEwncJ/MLVYQq -MH+x4V5HAkEAzU2GMR5OdVLcxfVTjzuIC76paoHVWnLibd1cdANpPmE6SM+pf5el -fVSwuH9Fmlizu8GiPCxbJUoXB/J1tGEKqQJBALhClEU+qOzpoZ6/voYi/6kdN3zc -uEy0EN6n09AKb8gS9QH1STgAqh+ltjMkeMe3C2DKYK5/QU9/Pc58lWl1FkcCQG67 -ZamQgxjcvJ85FvymS1aqW45KwNysIlzHjFo2jMlMf7dN6kobbPMQftDENLJvLWIT -qoFyGycdsxZiPAIyZSECQQCZFn3Dl6hnJxWZH8Fsa9hj79kZ/WVkIXGmtdgt0fNr -dTnvCVtA59ne4LEVie/PMH/odQWY0SxVm/76uBZv/1vY ------END RSA PRIVATE KEY-----""" - -OAUTH2_PROVIDER = { - "OIDC_ISS_ENDPOINT": "http://localhost", - "OIDC_USERINFO_ENDPOINT": "http://localhost/userinfo/", - "OIDC_RSA_PRIVATE_KEY": OIDC_RSA_PRIVATE_KEY, -} - -OAUTH2_PROVIDER_ACCESS_TOKEN_MODEL = "oauth2_provider.AccessToken" -OAUTH2_PROVIDER_APPLICATION_MODEL = "oauth2_provider.Application" -OAUTH2_PROVIDER_REFRESH_TOKEN_MODEL = "oauth2_provider.RefreshToken" -OAUTH2_PROVIDER_ID_TOKEN_MODEL = "oauth2_provider.IDToken" diff --git a/tests/test_application_views.py b/tests/test_application_views.py index 64e112da3..6130876ce 100644 --- a/tests/test_application_views.py +++ b/tests/test_application_views.py @@ -50,7 +50,6 @@ def test_application_registration_user(self): "client_type": Application.CLIENT_CONFIDENTIAL, "redirect_uris": "http://example.com", "authorization_grant_type": Application.GRANT_AUTHORIZATION_CODE, - "algorithm": "RS256", } response = self.client.post(reverse("oauth2_provider:register"), form_data) diff --git a/tests/test_authorization_code.py b/tests/test_authorization_code.py index e4eb8ae81..e98f5b041 100644 --- a/tests/test_authorization_code.py +++ b/tests/test_authorization_code.py @@ -41,12 +41,8 @@ def get(self, request, *args, **kwargs): class BaseTest(TestCase): def setUp(self): self.factory = RequestFactory() - self.test_user = UserModel.objects.create_user( - "test_user", "test@example.com", "123456" - ) - self.dev_user = UserModel.objects.create_user( - "dev_user", "dev@example.com", "123456" - ) + self.test_user = UserModel.objects.create_user("test_user", "test@example.com", "123456") + self.dev_user = UserModel.objects.create_user("dev_user", "dev@example.com", "123456") oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES = ["http", "custom-scheme"] @@ -61,13 +57,8 @@ def setUp(self): authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE, ) - oauth2_settings._SCOPES = ["read", "write", "openid"] + oauth2_settings._SCOPES = ["read", "write"] oauth2_settings._DEFAULT_SCOPES = ["read", "write"] - oauth2_settings.SCOPES = { - "read": "Reading scope", - "write": "Writing scope", - "openid": "OpenID connect", - } def tearDown(self): self.application.delete() @@ -112,25 +103,6 @@ def test_skip_authorization_completely(self): }) self.assertEqual(response.status_code, 302) - def test_id_token_skip_authorization_completely(self): - """ - If application.skip_authorization = True, should skip the authorization page. - """ - self.client.login(username="test_user", password="123456") - self.application.skip_authorization = True - self.application.save() - - query_data = { - "client_id": self.application.client_id, - "response_type": "code", - "state": "random_state_string", - "scope": "openid", - "redirect_uri": "http://example.org", - } - - response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) - self.assertEqual(response.status_code, 302) - def test_pre_auth_invalid_client(self): """ Test error for an invalid client_id with response_type: code @@ -175,32 +147,6 @@ def test_pre_auth_valid_client(self): self.assertEqual(form["scope"].value(), "read write") self.assertEqual(form["client_id"].value(), self.application.client_id) - def test_id_token_pre_auth_valid_client(self): - """ - Test response for a valid client_id with response_type: code - """ - self.client.login(username="test_user", password="123456") - - query_data = { - "client_id": self.application.client_id, - "response_type": "code", - "state": "random_state_string", - "scope": "openid", - "redirect_uri": "http://example.org", - } - - response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) - self.assertEqual(response.status_code, 200) - - # check form is in context and form params are valid - self.assertIn("form", response.context) - - form = response.context["form"] - self.assertEqual(form["redirect_uri"].value(), "http://example.org") - self.assertEqual(form["state"].value(), "random_state_string") - self.assertEqual(form["scope"].value(), "openid") - self.assertEqual(form["client_id"].value(), self.application.client_id) - def test_pre_auth_valid_client_custom_redirect_uri_scheme(self): """ Test response for a valid client_id with response_type: code @@ -230,11 +176,10 @@ def test_pre_auth_valid_client_custom_redirect_uri_scheme(self): def test_pre_auth_approval_prompt(self): tok = AccessToken.objects.create( - user=self.test_user, - token="1234567890", + user=self.test_user, token="1234567890", application=self.application, expires=timezone.now() + datetime.timedelta(days=1), - scope="read write", + scope="read write" ) self.client.login(username="test_user", password="123456") @@ -259,11 +204,10 @@ def test_pre_auth_approval_prompt_default(self): self.assertEqual(oauth2_settings.REQUEST_APPROVAL_PROMPT, "force") AccessToken.objects.create( - user=self.test_user, - token="1234567890", + user=self.test_user, token="1234567890", application=self.application, expires=timezone.now() + datetime.timedelta(days=1), - scope="read write", + scope="read write" ) self.client.login(username="test_user", password="123456") query_data = { @@ -280,11 +224,10 @@ def test_pre_auth_approval_prompt_default_override(self): oauth2_settings.REQUEST_APPROVAL_PROMPT = "auto" AccessToken.objects.create( - user=self.test_user, - token="1234567890", + user=self.test_user, token="1234567890", application=self.application, expires=timezone.now() + datetime.timedelta(days=1), - scope="read write", + scope="read write" ) self.client.login(username="test_user", password="123456") query_data = { @@ -359,32 +302,7 @@ def test_code_post_auth_allow(self): "allow": True, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=form_data - ) - self.assertEqual(response.status_code, 302) - self.assertIn("http://example.org?", response["Location"]) - self.assertIn("state=random_state_string", response["Location"]) - self.assertIn("code=", response["Location"]) - - def test_id_token_code_post_auth_allow(self): - """ - Test authorization code is given for an allowed request with response_type: code - """ - self.client.login(username="test_user", password="123456") - - form_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "openid", - "redirect_uri": "http://example.org", - "response_type": "code", - "allow": True, - } - - response = self.client.post( - reverse("oauth2_provider:authorize"), data=form_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) self.assertEqual(response.status_code, 302) self.assertIn("http://example.org?", response["Location"]) self.assertIn("state=random_state_string", response["Location"]) @@ -405,9 +323,7 @@ def test_code_post_auth_deny(self): "allow": False, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=form_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) self.assertEqual(response.status_code, 302) self.assertIn("error=access_denied", response["Location"]) self.assertIn("state=random_state_string", response["Location"]) @@ -426,9 +342,7 @@ def test_code_post_auth_deny_no_state(self): "allow": False, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=form_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) self.assertEqual(response.status_code, 302) self.assertIn("error=access_denied", response["Location"]) self.assertNotIn("state", response["Location"]) @@ -448,9 +362,7 @@ def test_code_post_auth_bad_responsetype(self): "allow": True, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=form_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) self.assertEqual(response.status_code, 302) self.assertIn("http://example.org?error", response["Location"]) @@ -469,9 +381,7 @@ def test_code_post_auth_forbidden_redirect_uri(self): "allow": True, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=form_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) self.assertEqual(response.status_code, 400) def test_code_post_auth_malicious_redirect_uri(self): @@ -489,9 +399,7 @@ def test_code_post_auth_malicious_redirect_uri(self): "allow": True, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=form_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) self.assertEqual(response.status_code, 400) def test_code_post_auth_allow_custom_redirect_uri_scheme(self): @@ -510,9 +418,7 @@ def test_code_post_auth_allow_custom_redirect_uri_scheme(self): "allow": True, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=form_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) self.assertEqual(response.status_code, 302) self.assertIn("custom-scheme://example.com?", response["Location"]) self.assertIn("state=random_state_string", response["Location"]) @@ -534,9 +440,7 @@ def test_code_post_auth_deny_custom_redirect_uri_scheme(self): "allow": False, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=form_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) self.assertEqual(response.status_code, 302) self.assertIn("custom-scheme://example.com?", response["Location"]) self.assertIn("error=access_denied", response["Location"]) @@ -559,9 +463,7 @@ def test_code_post_auth_redirection_uri_with_querystring(self): "allow": True, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=form_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) self.assertEqual(response.status_code, 302) self.assertIn("http://example.com?foo=bar", response["Location"]) self.assertIn("code=", response["Location"]) @@ -584,9 +486,7 @@ def test_code_post_auth_failing_redirection_uri_with_querystring(self): "allow": False, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=form_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) self.assertEqual(response.status_code, 302) self.assertIn("http://example.com?", response["Location"]) self.assertIn("error=access_denied", response["Location"]) @@ -608,29 +508,25 @@ def test_code_post_auth_fails_when_redirect_uri_path_is_invalid(self): "allow": True, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=form_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) self.assertEqual(response.status_code, 400) class TestAuthorizationCodeTokenView(BaseTest): - def get_auth(self, scope="read write"): + def get_auth(self): """ Helper method to retrieve a valid authorization code """ authcode_data = { "client_id": self.application.client_id, "state": "random_state_string", - "scope": scope, + "scope": "read write", "redirect_uri": "http://example.org", "response_type": "code", "allow": True, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=authcode_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) query_dict = parse_qs(urlparse(response["Location"]).query) return query_dict["code"].pop() @@ -640,13 +536,9 @@ def generate_pkce_codes(self, algorithm, length=43): """ code_verifier = get_random_string(length) if algorithm == "S256": - code_challenge = ( - base64.urlsafe_b64encode( - hashlib.sha256(code_verifier.encode()).digest() - ) - .decode() - .rstrip("=") - ) + code_challenge = base64.urlsafe_b64encode( + hashlib.sha256(code_verifier.encode()).digest() + ).decode().rstrip("=") else: code_challenge = code_verifier return code_verifier, code_challenge @@ -667,9 +559,7 @@ def get_pkce_auth(self, code_challenge, code_challenge_method): "code_challenge_method": code_challenge_method, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=authcode_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) query_dict = parse_qs(urlparse(response["Location"]).query) oauth2_settings.PKCE_REQUIRED = False return query_dict["code"].pop() @@ -684,23 +574,17 @@ def test_basic_auth(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org", + "redirect_uri": "http://example.org" } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual( - content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS - ) + self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_refresh(self): """ @@ -712,15 +596,11 @@ def test_refresh(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org", + "redirect_uri": "http://example.org" } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) content = json.loads(response.content.decode("utf-8")) self.assertTrue("refresh_token" in content) @@ -729,29 +609,23 @@ def test_refresh(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org", + "redirect_uri": "http://example.org" } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) token_request_data = { "grant_type": "refresh_token", "refresh_token": content["refresh_token"], "scope": content["scope"], } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) self.assertTrue("access_token" in content) # check refresh token cannot be used twice - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 400) content = json.loads(response.content.decode("utf-8")) self.assertTrue("invalid_grant" in content.values()) @@ -767,15 +641,11 @@ def test_refresh_with_grace_period(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org", + "redirect_uri": "http://example.org" } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) content = json.loads(response.content.decode("utf-8")) self.assertTrue("refresh_token" in content) @@ -784,11 +654,9 @@ def test_refresh_with_grace_period(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org", + "redirect_uri": "http://example.org" } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) token_request_data = { "grant_type": "refresh_token", @@ -796,9 +664,7 @@ def test_refresh_with_grace_period(self): "scope": content["scope"], } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) @@ -807,9 +673,7 @@ def test_refresh_with_grace_period(self): first_refresh_token = content["refresh_token"] # check access token returns same data if used twice, see #497 - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) self.assertTrue("access_token" in content) @@ -829,15 +693,11 @@ def test_refresh_invalidates_old_tokens(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org", + "redirect_uri": "http://example.org" } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) content = json.loads(response.content.decode("utf-8")) rt = content["refresh_token"] @@ -848,9 +708,7 @@ def test_refresh_invalidates_old_tokens(self): "refresh_token": rt, "scope": content["scope"], } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) refresh_token = RefreshToken.objects.filter(token=rt).first() @@ -867,15 +725,11 @@ def test_refresh_no_scopes(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org", + "redirect_uri": "http://example.org" } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) content = json.loads(response.content.decode("utf-8")) self.assertTrue("refresh_token" in content) @@ -883,9 +737,7 @@ def test_refresh_no_scopes(self): "grant_type": "refresh_token", "refresh_token": content["refresh_token"], } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) @@ -901,15 +753,11 @@ def test_refresh_bad_scopes(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org", + "redirect_uri": "http://example.org" } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) content = json.loads(response.content.decode("utf-8")) self.assertTrue("refresh_token" in content) @@ -918,9 +766,7 @@ def test_refresh_bad_scopes(self): "refresh_token": content["refresh_token"], "scope": "read write nuke", } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 400) def test_refresh_fail_repeating_requests(self): @@ -933,15 +779,11 @@ def test_refresh_fail_repeating_requests(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org", + "redirect_uri": "http://example.org" } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) content = json.loads(response.content.decode("utf-8")) self.assertTrue("refresh_token" in content) @@ -950,13 +792,9 @@ def test_refresh_fail_repeating_requests(self): "refresh_token": content["refresh_token"], "scope": content["scope"], } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 400) def test_refresh_repeating_requests(self): @@ -971,15 +809,11 @@ def test_refresh_repeating_requests(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org", + "redirect_uri": "http://example.org" } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) content = json.loads(response.content.decode("utf-8")) self.assertTrue("refresh_token" in content) @@ -988,26 +822,18 @@ def test_refresh_repeating_requests(self): "refresh_token": content["refresh_token"], "scope": content["scope"], } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) # try refreshing outside the refresh window, see #497 rt = RefreshToken.objects.get(token=content["refresh_token"]) self.assertIsNotNone(rt.revoked) - rt.revoked = timezone.now() - datetime.timedelta( - minutes=10 - ) # instead of mocking out datetime + rt.revoked = timezone.now() - datetime.timedelta(minutes=10) # instead of mocking out datetime rt.save() - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 400) oauth2_settings.REFRESH_TOKEN_GRACE_PERIOD_SECONDS = 0 @@ -1021,15 +847,11 @@ def test_refresh_repeating_requests_non_rotating_tokens(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org", + "redirect_uri": "http://example.org" } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) content = json.loads(response.content.decode("utf-8")) self.assertTrue("refresh_token" in content) @@ -1040,13 +862,9 @@ def test_refresh_repeating_requests_non_rotating_tokens(self): } oauth2_settings.ROTATE_REFRESH_TOKEN = False - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) oauth2_settings.ROTATE_REFRESH_TOKEN = True @@ -1060,15 +878,11 @@ def test_basic_auth_bad_authcode(self): token_request_data = { "grant_type": "authorization_code", "code": "BLAH", - "redirect_uri": "http://example.org", + "redirect_uri": "http://example.org" } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 400) def test_basic_auth_bad_granttype(self): @@ -1080,15 +894,11 @@ def test_basic_auth_bad_granttype(self): token_request_data = { "grant_type": "UNKNOWN", "code": "BLAH", - "redirect_uri": "http://example.org", + "redirect_uri": "http://example.org" } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 400) def test_basic_auth_grant_expired(self): @@ -1097,27 +907,18 @@ def test_basic_auth_grant_expired(self): """ self.client.login(username="test_user", password="123456") g = Grant( - application=self.application, - user=self.test_user, - code="BLAH", - expires=timezone.now(), - redirect_uri="", - scope="", - ) + application=self.application, user=self.test_user, code="BLAH", + expires=timezone.now(), redirect_uri="", scope="") g.save() token_request_data = { "grant_type": "authorization_code", "code": "BLAH", - "redirect_uri": "http://example.org", + "redirect_uri": "http://example.org" } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 400) def test_basic_auth_bad_secret(self): @@ -1130,13 +931,11 @@ def test_basic_auth_bad_secret(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org", + "redirect_uri": "http://example.org" } auth_headers = get_basic_auth_header(self.application.client_id, "BOOM!") - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 401) def test_basic_auth_wrong_auth_type(self): @@ -1149,20 +948,16 @@ def test_basic_auth_wrong_auth_type(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org", + "redirect_uri": "http://example.org" } - user_pass = "{0}:{1}".format( - self.application.client_id, self.application.client_secret - ) + user_pass = "{0}:{1}".format(self.application.client_id, self.application.client_secret) auth_string = base64.b64encode(user_pass.encode("utf-8")) auth_headers = { "HTTP_AUTHORIZATION": "Wrong " + auth_string.decode("utf-8"), } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 401) def test_request_body_params(self): @@ -1180,17 +975,13 @@ def test_request_body_params(self): "client_secret": self.application.client_secret, } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual( - content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS - ) + self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_public(self): """ @@ -1206,52 +997,16 @@ def test_public(self): "grant_type": "authorization_code", "code": authorization_code, "redirect_uri": "http://example.org", - "client_id": self.application.client_id, + "client_id": self.application.client_id } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual( - content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS - ) - - def test_id_token_public(self): - """ - Request an access token using client_type: public - """ - self.client.login(username="test_user", password="123456") - - self.application.client_type = Application.CLIENT_PUBLIC - self.application.save() - authorization_code = self.get_auth(scope="openid") - - token_request_data = { - "grant_type": "authorization_code", - "code": authorization_code, - "redirect_uri": "http://example.org", - "client_id": self.application.client_id, - "scope": "openid", - } - - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data - ) - self.assertEqual(response.status_code, 200) - - content = json.loads(response.content.decode("utf-8")) - self.assertEqual(content["token_type"], "Bearer") - self.assertEqual(content["scope"], "openid") - self.assertIn("access_token", content) - self.assertIn("id_token", content) - self.assertEqual( - content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS - ) + self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_public_pkce_S256_authorize_get(self): """ @@ -1327,20 +1082,16 @@ def test_public_pkce_S256(self): "code": authorization_code, "redirect_uri": "http://example.org", "client_id": self.application.client_id, - "code_verifier": code_verifier, + "code_verifier": code_verifier } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual( - content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS - ) + self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) oauth2_settings.PKCE_REQUIRED = False def test_public_pkce_plain(self): @@ -1361,20 +1112,16 @@ def test_public_pkce_plain(self): "code": authorization_code, "redirect_uri": "http://example.org", "client_id": self.application.client_id, - "code_verifier": code_verifier, + "code_verifier": code_verifier } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual( - content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS - ) + self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) oauth2_settings.PKCE_REQUIRED = False def test_public_pkce_invalid_algorithm(self): @@ -1477,12 +1224,10 @@ def test_public_pkce_S256_invalid_code_verifier(self): "code": authorization_code, "redirect_uri": "http://example.org", "client_id": self.application.client_id, - "code_verifier": "invalid", + "code_verifier": "invalid" } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) self.assertEqual(response.status_code, 400) oauth2_settings.PKCE_REQUIRED = False @@ -1504,12 +1249,10 @@ def test_public_pkce_plain_invalid_code_verifier(self): "code": authorization_code, "redirect_uri": "http://example.org", "client_id": self.application.client_id, - "code_verifier": "invalid", + "code_verifier": "invalid" } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) self.assertEqual(response.status_code, 400) oauth2_settings.PKCE_REQUIRED = False @@ -1530,12 +1273,10 @@ def test_public_pkce_S256_missing_code_verifier(self): "grant_type": "authorization_code", "code": authorization_code, "redirect_uri": "http://example.org", - "client_id": self.application.client_id, + "client_id": self.application.client_id } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) self.assertEqual(response.status_code, 400) oauth2_settings.PKCE_REQUIRED = False @@ -1556,12 +1297,10 @@ def test_public_pkce_plain_missing_code_verifier(self): "grant_type": "authorization_code", "code": authorization_code, "redirect_uri": "http://example.org", - "client_id": self.application.client_id, + "client_id": self.application.client_id } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) self.assertEqual(response.status_code, 400) oauth2_settings.PKCE_REQUIRED = False @@ -1580,19 +1319,14 @@ def test_malicious_redirect_uri(self): "grant_type": "authorization_code", "code": authorization_code, "redirect_uri": "/../", - "client_id": self.application.client_id, + "client_id": self.application.client_id } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) self.assertEqual(response.status_code, 400) data = response.json() self.assertEqual(data["error"], "invalid_request") - self.assertEqual( - data["error_description"], - oauthlib_errors.MismatchingRedirectURIError.description, - ) + self.assertEqual(data["error_description"], oauthlib_errors.MismatchingRedirectURIError.description) def test_code_exchange_succeed_when_redirect_uri_match(self): """ @@ -1609,9 +1343,7 @@ def test_code_exchange_succeed_when_redirect_uri_match(self): "response_type": "code", "allow": True, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=authcode_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) query_dict = parse_qs(urlparse(response["Location"]).query) authorization_code = query_dict["code"].pop() @@ -1619,23 +1351,17 @@ def test_code_exchange_succeed_when_redirect_uri_match(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org?foo=bar", + "redirect_uri": "http://example.org?foo=bar" } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual( - content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS - ) + self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_code_exchange_fails_when_redirect_uri_does_not_match(self): """ @@ -1652,9 +1378,7 @@ def test_code_exchange_fails_when_redirect_uri_does_not_match(self): "response_type": "code", "allow": True, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=authcode_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) query_dict = parse_qs(urlparse(response["Location"]).query) authorization_code = query_dict["code"].pop() @@ -1662,26 +1386,17 @@ def test_code_exchange_fails_when_redirect_uri_does_not_match(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org?foo=baraa", + "redirect_uri": "http://example.org?foo=baraa" } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 400) data = response.json() self.assertEqual(data["error"], "invalid_request") - self.assertEqual( - data["error_description"], - oauthlib_errors.MismatchingRedirectURIError.description, - ) + self.assertEqual(data["error_description"], oauthlib_errors.MismatchingRedirectURIError.description) - def test_code_exchange_succeed_when_redirect_uri_match_with_multiple_query_params( - self, - ): + def test_code_exchange_succeed_when_redirect_uri_match_with_multiple_query_params(self): """ Tests code exchange succeed when redirect uri matches the one used for code request """ @@ -1698,9 +1413,7 @@ def test_code_exchange_succeed_when_redirect_uri_match_with_multiple_query_param "response_type": "code", "allow": True, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=authcode_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) query_dict = parse_qs(urlparse(response["Location"]).query) authorization_code = query_dict["code"].pop() @@ -1708,72 +1421,17 @@ def test_code_exchange_succeed_when_redirect_uri_match_with_multiple_query_param token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.com?bar=baz&foo=bar", + "redirect_uri": "http://example.com?bar=baz&foo=bar" } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual( - content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS - ) - - def test_id_token_code_exchange_succeed_when_redirect_uri_match_with_multiple_query_params( - self, - ): - """ - Tests code exchange succeed when redirect uri matches the one used for code request - """ - self.client.login(username="test_user", password="123456") - self.application.redirect_uris = "http://localhost http://example.com?foo=bar" - self.application.save() - - # retrieve a valid authorization code - authcode_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "openid", - "redirect_uri": "http://example.com?bar=baz&foo=bar", - "response_type": "code", - "allow": True, - } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=authcode_data - ) - query_dict = parse_qs(urlparse(response["Location"]).query) - authorization_code = query_dict["code"].pop() - - # exchange authorization code for a valid access token - token_request_data = { - "grant_type": "authorization_code", - "code": authorization_code, - "redirect_uri": "http://example.com?bar=baz&foo=bar", - } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) - - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) - self.assertEqual(response.status_code, 200) - - content = json.loads(response.content.decode("utf-8")) - self.assertEqual(content["token_type"], "Bearer") - self.assertEqual(content["scope"], "openid") - self.assertIn("access_token", content) - self.assertIn("id_token", content) - self.assertEqual( - content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS - ) + self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_oob_as_html(self): """ @@ -1836,9 +1494,7 @@ def test_oob_as_json(self): "allow": True, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=authcode_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) self.assertEqual(response.status_code, 200) self.assertRegex(response["Content-Type"], "^application/json") @@ -1855,17 +1511,13 @@ def test_oob_as_json(self): "client_secret": self.application.client_secret, } - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual( - content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS - ) + self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) class TestAuthorizationCodeProtectedResource(BaseTest): @@ -1881,54 +1533,7 @@ def test_resource_access_allowed(self): "response_type": "code", "allow": True, } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=authcode_data - ) - query_dict = parse_qs(urlparse(response["Location"]).query) - authorization_code = query_dict["code"].pop() - - # exchange authorization code for a valid access token - token_request_data = { - "grant_type": "authorization_code", - "code": authorization_code, - "redirect_uri": "http://example.org", - } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) - - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) - content = json.loads(response.content.decode("utf-8")) - access_token = content["access_token"] - - # use token to access the resource - auth_headers = { - "HTTP_AUTHORIZATION": "Bearer " + access_token, - } - request = self.factory.get("/fake-resource", **auth_headers) - request.user = self.test_user - - view = ResourceView.as_view() - response = view(request) - self.assertEqual(response, "This is a protected resource") - - def test_id_token_resource_access_allowed(self): - self.client.login(username="test_user", password="123456") - - # retrieve a valid authorization code - authcode_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "openid", - "redirect_uri": "http://example.org", - "response_type": "code", - "allow": True, - } - response = self.client.post( - reverse("oauth2_provider:authorize"), data=authcode_data - ) + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) query_dict = parse_qs(urlparse(response["Location"]).query) authorization_code = query_dict["code"].pop() @@ -1936,18 +1541,13 @@ def test_id_token_resource_access_allowed(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org", + "redirect_uri": "http://example.org" } - auth_headers = get_basic_auth_header( - self.application.client_id, self.application.client_secret - ) + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data, **auth_headers - ) + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) content = json.loads(response.content.decode("utf-8")) access_token = content["access_token"] - id_token = content["id_token"] # use token to access the resource auth_headers = { @@ -1960,17 +1560,6 @@ def test_id_token_resource_access_allowed(self): response = view(request) self.assertEqual(response, "This is a protected resource") - # use id_token to access the resource - auth_headers = { - "HTTP_AUTHORIZATION": "Bearer " + id_token, - } - request = self.factory.get("/fake-resource", **auth_headers) - request.user = self.test_user - - view = ResourceView.as_view() - response = view(request) - self.assertEqual(response, "This is a protected resource") - def test_resource_access_deny(self): auth_headers = { "HTTP_AUTHORIZATION": "Bearer " + "faketoken", @@ -1984,6 +1573,7 @@ def test_resource_access_deny(self): class TestDefaultScopes(BaseTest): + def test_pre_auth_default_scopes(self): """ Test response for a valid client_id with response_type: code using default scopes diff --git a/tests/test_hybrid.py b/tests/test_hybrid.py deleted file mode 100644 index 1f45aeeec..000000000 --- a/tests/test_hybrid.py +++ /dev/null @@ -1,1264 +0,0 @@ -import base64 -import datetime -import json -from urllib.parse import parse_qs, urlencode, urlparse - -from django.contrib.auth import get_user_model -from django.test import RequestFactory, TestCase -from django.urls import reverse -from django.utils import timezone -from oauthlib.oauth2.rfc6749 import errors as oauthlib_errors - -from oauth2_provider.models import ( - get_access_token_model, get_application_model, - get_grant_model, get_refresh_token_model -) -from oauth2_provider.settings import oauth2_settings -from oauth2_provider.views import ProtectedResourceView - -from .utils import get_basic_auth_header - - -Application = get_application_model() -AccessToken = get_access_token_model() -Grant = get_grant_model() -RefreshToken = get_refresh_token_model() -UserModel = get_user_model() - - -# mocking a protected resource view -class ResourceView(ProtectedResourceView): - def get(self, request, *args, **kwargs): - return "This is a protected resource" - - -class BaseTest(TestCase): - def setUp(self): - self.factory = RequestFactory() - self.hy_test_user = UserModel.objects.create_user("hy_test_user", "test_hy@example.com", "123456") - self.hy_dev_user = UserModel.objects.create_user("hy_dev_user", "dev_hy@example.com", "123456") - - oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES = ["http", "custom-scheme"] - - self.application = Application( - name="Hybrid Test Application", - redirect_uris=( - "http://localhost http://example.com http://example.org custom-scheme://example.com" - ), - user=self.hy_dev_user, - client_type=Application.CLIENT_CONFIDENTIAL, - authorization_grant_type=Application.GRANT_OPENID_HYBRID, - ) - self.application.save() - - oauth2_settings._SCOPES = ["read", "write", "openid"] - oauth2_settings._DEFAULT_SCOPES = ["read", "write"] - oauth2_settings.SCOPES = { - "read": "Reading scope", - "write": "Writing scope", - "openid": "OpenID connect" - } - - def tearDown(self): - self.application.delete() - self.hy_test_user.delete() - self.hy_dev_user.delete() - - -class TestRegressionIssue315Hybrid(BaseTest): - """ - Test to avoid regression for the issue 315: request object - was being reassigned when getting AuthorizationView - """ - - def test_request_is_not_overwritten_code_token(self): - self.client.login(username="hy_test_user", password="123456") - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code token", - "state": "random_state_string", - "scope": "openid read write", - "redirect_uri": "http://example.org", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) - - response = self.client.get(url) - self.assertEqual(response.status_code, 200) - assert "request" not in response.context_data - - def test_request_is_not_overwritten_code_id_token(self): - self.client.login(username="hy_test_user", password="123456") - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code id_token", - "state": "random_state_string", - "scope": "openid read write", - "redirect_uri": "http://example.org", - "nonce": "nonce", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) - - response = self.client.get(url) - self.assertEqual(response.status_code, 200) - assert "request" not in response.context_data - - def test_request_is_not_overwritten_code_id_token_token(self): - self.client.login(username="hy_test_user", password="123456") - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code id_token token", - "state": "random_state_string", - "scope": "openid read write", - "redirect_uri": "http://example.org", - "nonce": "nonce", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) - - response = self.client.get(url) - self.assertEqual(response.status_code, 200) - assert "request" not in response.context_data - - -class TestHybridView(BaseTest): - def test_skip_authorization_completely(self): - """ - If application.skip_authorization = True, should skip the authorization page. - """ - self.client.login(username="hy_test_user", password="123456") - self.application.skip_authorization = True - self.application.save() - - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code", - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "http://example.org", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) - - response = self.client.get(url) - self.assertEqual(response.status_code, 302) - - def test_id_token_skip_authorization_completely(self): - """ - If application.skip_authorization = True, should skip the authorization page. - """ - self.client.login(username="hy_test_user", password="123456") - self.application.skip_authorization = True - self.application.save() - - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code", - "state": "random_state_string", - "scope": "openid", - "redirect_uri": "http://example.org", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) - - response = self.client.get(url) - self.assertEqual(response.status_code, 302) - - def test_pre_auth_invalid_client(self): - """ - Test error for an invalid client_id with response_type: code - """ - self.client.login(username="hy_test_user", password="123456") - - query_string = urlencode({ - "client_id": "fakeclientid", - "response_type": "code", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) - - response = self.client.get(url) - self.assertEqual(response.status_code, 400) - self.assertEqual( - response.context_data["url"], - "?error=invalid_request&error_description=Invalid+client_id+parameter+value." - ) - - def test_pre_auth_valid_client(self): - """ - Test response for a valid client_id with response_type: code - """ - self.client.login(username="hy_test_user", password="123456") - - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code id_token", - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "http://example.org", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) - - response = self.client.get(url) - self.assertEqual(response.status_code, 200) - - # check form is in context and form params are valid - self.assertIn("form", response.context) - - form = response.context["form"] - self.assertEqual(form["redirect_uri"].value(), "http://example.org") - self.assertEqual(form["state"].value(), "random_state_string") - self.assertEqual(form["scope"].value(), "read write") - self.assertEqual(form["client_id"].value(), self.application.client_id) - - def test_id_token_pre_auth_valid_client(self): - """ - Test response for a valid client_id with response_type: code - """ - self.client.login(username="hy_test_user", password="123456") - - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code id_token", - "state": "random_state_string", - "scope": "openid", - "redirect_uri": "http://example.org", - "nonce": "nonce", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) - - response = self.client.get(url) - self.assertEqual(response.status_code, 200) - - # check form is in context and form params are valid - self.assertIn("form", response.context) - - form = response.context["form"] - self.assertEqual(form["redirect_uri"].value(), "http://example.org") - self.assertEqual(form["state"].value(), "random_state_string") - self.assertEqual(form["scope"].value(), "openid") - self.assertEqual(form["client_id"].value(), self.application.client_id) - - def test_pre_auth_valid_client_custom_redirect_uri_scheme(self): - """ - Test response for a valid client_id with response_type: code - using a non-standard, but allowed, redirect_uri scheme. - """ - self.client.login(username="hy_test_user", password="123456") - - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code id_token", - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "custom-scheme://example.com", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) - - response = self.client.get(url) - self.assertEqual(response.status_code, 200) - - # check form is in context and form params are valid - self.assertIn("form", response.context) - - form = response.context["form"] - self.assertEqual(form["redirect_uri"].value(), "custom-scheme://example.com") - self.assertEqual(form["state"].value(), "random_state_string") - self.assertEqual(form["scope"].value(), "read write") - self.assertEqual(form["client_id"].value(), self.application.client_id) - - def test_pre_auth_approval_prompt(self): - tok = AccessToken.objects.create( - user=self.hy_test_user, token="1234567890", - application=self.application, - expires=timezone.now() + datetime.timedelta(days=1), - scope="read write" - ) - self.client.login(username="hy_test_user", password="123456") - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code id_token", - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "http://example.org", - "approval_prompt": "auto", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) - response = self.client.get(url) - self.assertEqual(response.status_code, 302) - # user already authorized the application, but with different scopes: prompt them. - tok.scope = "read" - tok.save() - response = self.client.get(url) - self.assertEqual(response.status_code, 200) - - def test_pre_auth_approval_prompt_default(self): - oauth2_settings.REQUEST_APPROVAL_PROMPT = "force" - self.assertEqual(oauth2_settings.REQUEST_APPROVAL_PROMPT, "force") - - AccessToken.objects.create( - user=self.hy_test_user, token="1234567890", - application=self.application, - expires=timezone.now() + datetime.timedelta(days=1), - scope="read write" - ) - self.client.login(username="hy_test_user", password="123456") - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code id_token", - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "http://example.org", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) - response = self.client.get(url) - self.assertEqual(response.status_code, 200) - - def test_pre_auth_approval_prompt_default_override(self): - oauth2_settings.REQUEST_APPROVAL_PROMPT = "auto" - - AccessToken.objects.create( - user=self.hy_test_user, token="1234567890", - application=self.application, - expires=timezone.now() + datetime.timedelta(days=1), - scope="read write" - ) - self.client.login(username="hy_test_user", password="123456") - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code", - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "http://example.org", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) - response = self.client.get(url) - self.assertEqual(response.status_code, 302) - - def test_pre_auth_default_redirect(self): - """ - Test for default redirect uri if omitted from query string with response_type: code - """ - self.client.login(username="hy_test_user", password="123456") - - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code id_token", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) - - response = self.client.get(url) - self.assertEqual(response.status_code, 200) - - form = response.context["form"] - self.assertEqual(form["redirect_uri"].value(), "http://localhost") - - def test_pre_auth_forbibben_redirect(self): - """ - Test error when passing a forbidden redirect_uri in query string with response_type: code - """ - self.client.login(username="hy_test_user", password="123456") - - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code", - "redirect_uri": "http://forbidden.it", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) - - response = self.client.get(url) - self.assertEqual(response.status_code, 400) - - def test_pre_auth_wrong_response_type(self): - """ - Test error when passing a wrong response_type in query string - """ - self.client.login(username="hy_test_user", password="123456") - - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "WRONG", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) - - response = self.client.get(url) - self.assertEqual(response.status_code, 302) - self.assertIn("error=unsupported_response_type", response["Location"]) - - def test_code_post_auth_allow_code_token(self): - """ - Test authorization code is given for an allowed request with response_type: code - """ - self.client.login(username="hy_test_user", password="123456") - - form_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "openid read write", - "redirect_uri": "http://example.org", - "response_type": "code token", - "allow": True, - } - - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) - self.assertEqual(response.status_code, 302) - self.assertIn("http://example.org", response["Location"]) - self.assertIn("state=random_state_string", response["Location"]) - self.assertIn("code=", response["Location"]) - self.assertIn("access_token=", response["Location"]) - - def test_code_post_auth_allow_code_id_token(self): - """ - Test authorization code is given for an allowed request with response_type: code - """ - self.client.login(username="hy_test_user", password="123456") - - form_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "openid read write", - "redirect_uri": "http://example.org", - "response_type": "code id_token", - "allow": True, - "nonce": "nonce", - } - - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) - self.assertEqual(response.status_code, 302) - self.assertIn("http://example.org", response["Location"]) - self.assertIn("state=random_state_string", response["Location"]) - self.assertIn("code=", response["Location"]) - self.assertIn("id_token=", response["Location"]) - - def test_code_post_auth_allow_code_id_token_token(self): - """ - Test authorization code is given for an allowed request with response_type: code - """ - self.client.login(username="hy_test_user", password="123456") - - form_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "openid read write", - "redirect_uri": "http://example.org", - "response_type": "code id_token token", - "allow": True, - "nonce": "nonce", - } - - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) - self.assertEqual(response.status_code, 302) - self.assertIn("http://example.org", response["Location"]) - self.assertIn("state=random_state_string", response["Location"]) - self.assertIn("code=", response["Location"]) - self.assertIn("id_token=", response["Location"]) - self.assertIn("access_token=", response["Location"]) - - def test_id_token_code_post_auth_allow(self): - """ - Test authorization code is given for an allowed request with response_type: code - """ - self.client.login(username="hy_test_user", password="123456") - - form_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "openid", - "redirect_uri": "http://example.org", - "response_type": "code id_token", - "allow": True, - "nonce": "nonce", - } - - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) - self.assertEqual(response.status_code, 302) - self.assertIn("http://example.org", response["Location"]) - self.assertIn("state=random_state_string", response["Location"]) - self.assertIn("code=", response["Location"]) - self.assertIn("id_token=", response["Location"]) - - def test_code_post_auth_deny(self): - """ - Test error when resource owner deny access - """ - self.client.login(username="hy_test_user", password="123456") - - form_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "http://example.org", - "response_type": "code", - "allow": False, - } - - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) - self.assertEqual(response.status_code, 302) - self.assertIn("error=access_denied", response["Location"]) - - def test_code_post_auth_bad_responsetype(self): - """ - Test authorization code is given for an allowed request with a response_type not supported - """ - self.client.login(username="hy_test_user", password="123456") - - form_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "http://example.org", - "response_type": "UNKNOWN", - "allow": True, - } - - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) - self.assertEqual(response.status_code, 302) - self.assertIn("http://example.org?error", response["Location"]) - - def test_code_post_auth_forbidden_redirect_uri(self): - """ - Test authorization code is given for an allowed request with a forbidden redirect_uri - """ - self.client.login(username="hy_test_user", password="123456") - - form_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "http://forbidden.it", - "response_type": "code", - "allow": True, - } - - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) - self.assertEqual(response.status_code, 400) - - def test_code_post_auth_malicious_redirect_uri(self): - """ - Test validation of a malicious redirect_uri - """ - self.client.login(username="hy_test_user", password="123456") - - form_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "/../", - "response_type": "code", - "allow": True, - } - - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) - self.assertEqual(response.status_code, 400) - - def test_code_post_auth_allow_custom_redirect_uri_scheme_code_token(self): - """ - Test authorization code is given for an allowed request with response_type: code - using a non-standard, but allowed, redirect_uri scheme. - """ - self.client.login(username="hy_test_user", password="123456") - - form_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "openid read write", - "redirect_uri": "custom-scheme://example.com", - "response_type": "code token", - "allow": True, - } - - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) - self.assertEqual(response.status_code, 302) - self.assertIn("custom-scheme://example.com", response["Location"]) - self.assertIn("state=random_state_string", response["Location"]) - self.assertIn("code=", response["Location"]) - self.assertIn("access_token=", response["Location"]) - - def test_code_post_auth_allow_custom_redirect_uri_scheme_code_id_token(self): - """ - Test authorization code is given for an allowed request with response_type: code - using a non-standard, but allowed, redirect_uri scheme. - """ - self.client.login(username="hy_test_user", password="123456") - - form_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "openid read write", - "redirect_uri": "custom-scheme://example.com", - "response_type": "code id_token", - "allow": True, - "nonce": "nonce", - } - - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) - self.assertEqual(response.status_code, 302) - self.assertIn("custom-scheme://example.com", response["Location"]) - self.assertIn("state=random_state_string", response["Location"]) - self.assertIn("code=", response["Location"]) - self.assertIn("id_token=", response["Location"]) - - def test_code_post_auth_allow_custom_redirect_uri_scheme_code_id_token_token(self): - """ - Test authorization code is given for an allowed request with response_type: code - using a non-standard, but allowed, redirect_uri scheme. - """ - self.client.login(username="hy_test_user", password="123456") - - form_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "openid read write", - "redirect_uri": "custom-scheme://example.com", - "response_type": "code id_token token", - "allow": True, - "nonce": "nonce", - } - - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) - self.assertEqual(response.status_code, 302) - self.assertIn("custom-scheme://example.com", response["Location"]) - self.assertIn("state=random_state_string", response["Location"]) - self.assertIn("code=", response["Location"]) - self.assertIn("id_token=", response["Location"]) - self.assertIn("access_token=", response["Location"]) - - def test_code_post_auth_deny_custom_redirect_uri_scheme(self): - """ - Test error when resource owner deny access - using a non-standard, but allowed, redirect_uri scheme. - """ - self.client.login(username="hy_test_user", password="123456") - - form_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "custom-scheme://example.com", - "response_type": "code", - "allow": False, - } - - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) - self.assertEqual(response.status_code, 302) - self.assertIn("custom-scheme://example.com?", response["Location"]) - self.assertIn("error=access_denied", response["Location"]) - - def test_code_post_auth_redirection_uri_with_querystring_code_token(self): - """ - Tests that a redirection uri with query string is allowed - and query string is retained on redirection. - See http://tools.ietf.org/html/rfc6749#section-3.1.2 - """ - self.client.login(username="hy_test_user", password="123456") - - form_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "openid read write", - "redirect_uri": "http://example.com?foo=bar", - "response_type": "code token", - "allow": True, - } - - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) - self.assertEqual(response.status_code, 302) - self.assertIn("http://example.com?foo=bar", response["Location"]) - self.assertIn("code=", response["Location"]) - self.assertIn("access_token=", response["Location"]) - - def test_code_post_auth_redirection_uri_with_querystring_code_id_token(self): - """ - Tests that a redirection uri with query string is allowed - and query string is retained on redirection. - See http://tools.ietf.org/html/rfc6749#section-3.1.2 - """ - self.client.login(username="hy_test_user", password="123456") - - form_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "openid read write", - "redirect_uri": "http://example.com?foo=bar", - "response_type": "code id_token", - "allow": True, - "nonce": "nonce", - } - - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) - self.assertEqual(response.status_code, 302) - self.assertIn("http://example.com?foo=bar", response["Location"]) - self.assertIn("code=", response["Location"]) - self.assertIn("id_token=", response["Location"]) - - def test_code_post_auth_redirection_uri_with_querystring_code_id_token_token(self): - """ - Tests that a redirection uri with query string is allowed - and query string is retained on redirection. - See http://tools.ietf.org/html/rfc6749#section-3.1.2 - """ - self.client.login(username="hy_test_user", password="123456") - - form_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "openid read write", - "redirect_uri": "http://example.com?foo=bar", - "response_type": "code id_token token", - "allow": True, - "nonce": "nonce", - } - - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) - self.assertEqual(response.status_code, 302) - self.assertIn("http://example.com?foo=bar", response["Location"]) - self.assertIn("code=", response["Location"]) - self.assertIn("id_token=", response["Location"]) - self.assertIn("access_token=", response["Location"]) - - def test_code_post_auth_failing_redirection_uri_with_querystring(self): - """ - Test that in case of error the querystring of the redirection uri is preserved - - See https://github.com/evonove/django-oauth-toolkit/issues/238 - """ - self.client.login(username="hy_test_user", password="123456") - - form_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "http://example.com?foo=bar", - "response_type": "code", - "allow": False, - } - - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) - self.assertEqual(response.status_code, 302) - self.assertEqual( - "http://example.com?foo=bar&error=access_denied&state=random_state_string", response["Location"] - ) - - def test_code_post_auth_fails_when_redirect_uri_path_is_invalid(self): - """ - Tests that a redirection uri is matched using scheme + netloc + path - """ - self.client.login(username="hy_test_user", password="123456") - - form_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "http://example.com/a?foo=bar", - "response_type": "code", - "allow": True, - } - - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) - self.assertEqual(response.status_code, 400) - - -class TestHybridTokenView(BaseTest): - def get_auth(self, scope="read write"): - """ - Helper method to retrieve a valid authorization code - """ - authcode_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": scope, - "redirect_uri": "http://example.org", - "response_type": "code id_token", - "allow": True, - "nonce": "nonce", - } - - response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) - fragment_dict = parse_qs(urlparse(response["Location"]).fragment) - return fragment_dict["code"].pop() - - def test_basic_auth(self): - """ - Request an access token using basic authentication for client authentication - """ - self.client.login(username="hy_test_user", password="123456") - authorization_code = self.get_auth() - - token_request_data = { - "grant_type": "authorization_code", - "code": authorization_code, - "redirect_uri": "http://example.org" - } - auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) - self.assertEqual(response.status_code, 200) - - content = json.loads(response.content.decode("utf-8")) - self.assertEqual(content["token_type"], "Bearer") - self.assertEqual(content["scope"], "read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) - - def test_basic_auth_bad_authcode(self): - """ - Request an access token using a bad authorization code - """ - self.client.login(username="hy_test_user", password="123456") - - token_request_data = { - "grant_type": "authorization_code", - "code": "BLAH", - "redirect_uri": "http://example.org" - } - auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) - self.assertEqual(response.status_code, 400) - - def test_basic_auth_bad_granttype(self): - """ - Request an access token using a bad grant_type string - """ - self.client.login(username="hy_test_user", password="123456") - - token_request_data = { - "grant_type": "UNKNOWN", - "code": "BLAH", - "redirect_uri": "http://example.org" - } - auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) - self.assertEqual(response.status_code, 400) - - def test_basic_auth_grant_expired(self): - """ - Request an access token using an expired grant token - """ - self.client.login(username="hy_test_user", password="123456") - g = Grant( - application=self.application, user=self.hy_test_user, code="BLAH", - expires=timezone.now(), redirect_uri="", scope="") - g.save() - - token_request_data = { - "grant_type": "authorization_code", - "code": "BLAH", - "redirect_uri": "http://example.org" - } - auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) - self.assertEqual(response.status_code, 400) - - def test_basic_auth_bad_secret(self): - """ - Request an access token using basic authentication for client authentication - """ - self.client.login(username="hy_test_user", password="123456") - authorization_code = self.get_auth() - - token_request_data = { - "grant_type": "authorization_code", - "code": authorization_code, - "redirect_uri": "http://example.org" - } - auth_headers = get_basic_auth_header(self.application.client_id, "BOOM!") - - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) - self.assertEqual(response.status_code, 401) - - def test_basic_auth_wrong_auth_type(self): - """ - Request an access token using basic authentication for client authentication - """ - self.client.login(username="hy_test_user", password="123456") - authorization_code = self.get_auth() - - token_request_data = { - "grant_type": "authorization_code", - "code": authorization_code, - "redirect_uri": "http://example.org" - } - - user_pass = "{0}:{1}".format(self.application.client_id, self.application.client_secret) - auth_string = base64.b64encode(user_pass.encode("utf-8")) - auth_headers = { - "HTTP_AUTHORIZATION": "Wrong " + auth_string.decode("utf-8"), - } - - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) - self.assertEqual(response.status_code, 401) - - def test_request_body_params(self): - """ - Request an access token using client_type: public - """ - self.client.login(username="hy_test_user", password="123456") - authorization_code = self.get_auth() - - token_request_data = { - "grant_type": "authorization_code", - "code": authorization_code, - "redirect_uri": "http://example.org", - "client_id": self.application.client_id, - "client_secret": self.application.client_secret, - } - - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) - self.assertEqual(response.status_code, 200) - - content = json.loads(response.content.decode("utf-8")) - self.assertEqual(content["token_type"], "Bearer") - self.assertEqual(content["scope"], "read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) - - def test_public(self): - """ - Request an access token using client_type: public - """ - self.client.login(username="hy_test_user", password="123456") - - self.application.client_type = Application.CLIENT_PUBLIC - self.application.save() - authorization_code = self.get_auth() - - token_request_data = { - "grant_type": "authorization_code", - "code": authorization_code, - "redirect_uri": "http://example.org", - "client_id": self.application.client_id - } - - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) - self.assertEqual(response.status_code, 200) - - content = json.loads(response.content.decode("utf-8")) - self.assertEqual(content["token_type"], "Bearer") - self.assertEqual(content["scope"], "read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) - - def test_id_token_public(self): - """ - Request an access token using client_type: public - """ - self.client.login(username="hy_test_user", password="123456") - - self.application.client_type = Application.CLIENT_PUBLIC - self.application.save() - authorization_code = self.get_auth(scope="openid") - - token_request_data = { - "grant_type": "authorization_code", - "code": authorization_code, - "redirect_uri": "http://example.org", - "client_id": self.application.client_id, - "scope": "openid", - } - - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) - self.assertEqual(response.status_code, 200) - - content = json.loads(response.content.decode("utf-8")) - self.assertEqual(content["token_type"], "Bearer") - self.assertEqual(content["scope"], "openid") - self.assertIn("access_token", content) - self.assertIn("id_token", content) - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) - - def test_malicious_redirect_uri(self): - """ - Request an access token using client_type: public and ensure redirect_uri is - properly validated. - """ - self.client.login(username="hy_test_user", password="123456") - - self.application.client_type = Application.CLIENT_PUBLIC - self.application.save() - authorization_code = self.get_auth() - - token_request_data = { - "grant_type": "authorization_code", - "code": authorization_code, - "redirect_uri": "/../", - "client_id": self.application.client_id - } - - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) - self.assertEqual(response.status_code, 400) - data = response.json() - self.assertEqual(data["error"], "invalid_request") - self.assertEqual(data["error_description"], oauthlib_errors.MismatchingRedirectURIError.description) - - def test_code_exchange_succeed_when_redirect_uri_match(self): - """ - Tests code exchange succeed when redirect uri matches the one used for code request - """ - self.client.login(username="hy_test_user", password="123456") - - # retrieve a valid authorization code - authcode_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "openid read write", - "redirect_uri": "http://example.org?foo=bar", - "response_type": "code token", - "allow": True, - } - response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) - fragment_dict = parse_qs(urlparse(response["Location"]).fragment) - authorization_code = fragment_dict["code"].pop() - - # exchange authorization code for a valid access token - token_request_data = { - "grant_type": "authorization_code", - "code": authorization_code, - "redirect_uri": "http://example.org?foo=bar" - } - auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) - self.assertEqual(response.status_code, 200) - - content = json.loads(response.content.decode("utf-8")) - self.assertEqual(content["token_type"], "Bearer") - self.assertEqual(content["scope"], "openid read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) - - def test_code_exchange_fails_when_redirect_uri_does_not_match(self): - """ - Tests code exchange fails when redirect uri does not match the one used for code request - """ - self.client.login(username="hy_test_user", password="123456") - - # retrieve a valid authorization code - authcode_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "openid read write", - "redirect_uri": "http://example.org?foo=bar", - "response_type": "code token", - "allow": True, - } - response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) - query_dict = parse_qs(urlparse(response["Location"]).fragment) - authorization_code = query_dict["code"].pop() - - # exchange authorization code for a valid access token - token_request_data = { - "grant_type": "authorization_code", - "code": authorization_code, - "redirect_uri": "http://example.org?foo=baraa" - } - auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) - self.assertEqual(response.status_code, 400) - data = response.json() - self.assertEqual(data["error"], "invalid_request") - self.assertEqual(data["error_description"], oauthlib_errors.MismatchingRedirectURIError.description) - - def test_code_exchange_succeed_when_redirect_uri_match_with_multiple_query_params(self): - """ - Tests code exchange succeed when redirect uri matches the one used for code request - """ - self.client.login(username="hy_test_user", password="123456") - self.application.redirect_uris = "http://localhost http://example.com?foo=bar" - self.application.save() - - # retrieve a valid authorization code - authcode_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "openid read write", - "redirect_uri": "http://example.com?bar=baz&foo=bar", - "response_type": "code token", - "allow": True, - } - response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) - fragment_dict = parse_qs(urlparse(response["Location"]).fragment) - authorization_code = fragment_dict["code"].pop() - - # exchange authorization code for a valid access token - token_request_data = { - "grant_type": "authorization_code", - "code": authorization_code, - "redirect_uri": "http://example.com?bar=baz&foo=bar" - } - auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) - self.assertEqual(response.status_code, 200) - - content = json.loads(response.content.decode("utf-8")) - self.assertEqual(content["token_type"], "Bearer") - self.assertEqual(content["scope"], "openid read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) - - def test_id_token_code_exchange_succeed_when_redirect_uri_match_with_multiple_query_params(self): - """ - Tests code exchange succeed when redirect uri matches the one used for code request - """ - self.client.login(username="hy_test_user", password="123456") - self.application.redirect_uris = "http://localhost http://example.com?foo=bar" - self.application.save() - - # retrieve a valid authorization code - authcode_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "openid", - "redirect_uri": "http://example.com?bar=baz&foo=bar", - "response_type": "code token", - "allow": True, - } - response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) - fragment_dict = parse_qs(urlparse(response["Location"]).fragment) - authorization_code = fragment_dict["code"].pop() - - # exchange authorization code for a valid access token - token_request_data = { - "grant_type": "authorization_code", - "code": authorization_code, - "redirect_uri": "http://example.com?bar=baz&foo=bar", - } - auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) - self.assertEqual(response.status_code, 200) - - content = json.loads(response.content.decode("utf-8")) - self.assertEqual(content["token_type"], "Bearer") - self.assertEqual(content["scope"], "openid") - self.assertIn("access_token", content) - self.assertIn("id_token", content) - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) - - -class TestHybridProtectedResource(BaseTest): - def test_resource_access_allowed(self): - self.client.login(username="hy_test_user", password="123456") - - # retrieve a valid authorization code - authcode_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "openid read write", - "redirect_uri": "http://example.org", - "response_type": "code token", - "allow": True, - } - response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) - fragment_dict = parse_qs(urlparse(response["Location"]).fragment) - authorization_code = fragment_dict["code"].pop() - - # exchange authorization code for a valid access token - token_request_data = { - "grant_type": "authorization_code", - "code": authorization_code, - "redirect_uri": "http://example.org" - } - auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) - content = json.loads(response.content.decode("utf-8")) - access_token = content["access_token"] - - # use token to access the resource - auth_headers = { - "HTTP_AUTHORIZATION": "Bearer " + access_token, - } - request = self.factory.get("/fake-resource", **auth_headers) - request.user = self.hy_test_user - - view = ResourceView.as_view() - response = view(request) - self.assertEqual(response, "This is a protected resource") - - def test_id_token_resource_access_allowed(self): - self.client.login(username="hy_test_user", password="123456") - - # retrieve a valid authorization code - authcode_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "openid", - "redirect_uri": "http://example.org", - "response_type": "code token", - "allow": True, - } - response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) - fragment_dict = parse_qs(urlparse(response["Location"]).fragment) - authorization_code = fragment_dict["code"].pop() - - # exchange authorization code for a valid access token - token_request_data = { - "grant_type": "authorization_code", - "code": authorization_code, - "redirect_uri": "http://example.org", - } - auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) - - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) - content = json.loads(response.content.decode("utf-8")) - access_token = content["access_token"] - id_token = content["id_token"] - - # use token to access the resource - auth_headers = { - "HTTP_AUTHORIZATION": "Bearer " + access_token, - } - request = self.factory.get("/fake-resource", **auth_headers) - request.user = self.hy_test_user - - view = ResourceView.as_view() - response = view(request) - self.assertEqual(response, "This is a protected resource") - - # use id_token to access the resource - auth_headers = { - "HTTP_AUTHORIZATION": "Bearer " + id_token, - } - request = self.factory.get("/fake-resource", **auth_headers) - request.user = self.hy_test_user - - view = ResourceView.as_view() - response = view(request) - self.assertEqual(response, "This is a protected resource") - - def test_resource_access_deny(self): - auth_headers = { - "HTTP_AUTHORIZATION": "Bearer " + "faketoken", - } - request = self.factory.get("/fake-resource", **auth_headers) - request.user = self.hy_test_user - - view = ResourceView.as_view() - response = view(request) - self.assertEqual(response.status_code, 403) - - -class TestDefaultScopesHybrid(BaseTest): - - def test_pre_auth_default_scopes(self): - """ - Test response for a valid client_id with response_type: code using default scopes - """ - self.client.login(username="hy_test_user", password="123456") - oauth2_settings._DEFAULT_SCOPES = ["read"] - - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code token", - "state": "random_state_string", - "redirect_uri": "http://example.org", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) - - response = self.client.get(url) - self.assertEqual(response.status_code, 200) - - # check form is in context and form params are valid - self.assertIn("form", response.context) - - form = response.context["form"] - self.assertEqual(form["redirect_uri"].value(), "http://example.org") - self.assertEqual(form["state"].value(), "random_state_string") - self.assertEqual(form["scope"].value(), "read") - self.assertEqual(form["client_id"].value(), self.application.client_id) - oauth2_settings._DEFAULT_SCOPES = ["read", "write"] diff --git a/tests/test_implicit.py b/tests/test_implicit.py index 15ac7469d..b51d0e1da 100644 --- a/tests/test_implicit.py +++ b/tests/test_implicit.py @@ -1,10 +1,8 @@ -import json from urllib.parse import parse_qs, urlparse from django.contrib.auth import get_user_model from django.test import RequestFactory, TestCase from django.urls import reverse -from jwcrypto import jwk, jwt from oauth2_provider.models import get_application_model from oauth2_provider.settings import oauth2_settings @@ -35,14 +33,8 @@ def setUp(self): authorization_grant_type=Application.GRANT_IMPLICIT, ) - oauth2_settings._SCOPES = ["read", "write", "openid"] + oauth2_settings._SCOPES = ["read", "write"] oauth2_settings._DEFAULT_SCOPES = ["read"] - oauth2_settings.SCOPES = { - "read": "Reading scope", - "write": "Writing scope", - "openid": "OpenID connect" - } - self.key = jwk.JWK.from_pem(oauth2_settings.OIDC_RSA_PRIVATE_KEY.encode("utf8")) def tearDown(self): self.application.delete() @@ -273,191 +265,3 @@ def test_resource_access_allowed(self): view = ResourceView.as_view() response = view(request) self.assertEqual(response, "This is a protected resource") - - -class TestOpenIDConnectImplicitFlow(BaseTest): - def test_id_token_post_auth_allow(self): - """ - Test authorization code is given for an allowed request with response_type: id_token - """ - self.client.login(username="test_user", password="123456") - - form_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "nonce": "random_nonce_string", - "scope": "openid", - "redirect_uri": "http://example.org", - "response_type": "id_token", - "allow": True, - } - - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) - self.assertEqual(response.status_code, 302) - self.assertIn("http://example.org#", response["Location"]) - self.assertNotIn("access_token=", response["Location"]) - self.assertIn("id_token=", response["Location"]) - self.assertIn("state=random_state_string", response["Location"]) - - uri_query = urlparse(response["Location"]).fragment - uri_query_params = dict(parse_qs(uri_query, keep_blank_values=True, strict_parsing=True)) - id_token = uri_query_params["id_token"][0] - jwt_token = jwt.JWT(key=self.key, jwt=id_token) - claims = json.loads(jwt_token.claims) - self.assertIn("nonce", claims) - self.assertNotIn("at_hash", claims) - - def test_id_token_skip_authorization_completely(self): - """ - If application.skip_authorization = True, should skip the authorization page. - """ - self.client.login(username="test_user", password="123456") - self.application.skip_authorization = True - self.application.save() - - query_data = { - "client_id": self.application.client_id, - "response_type": "id_token", - "state": "random_state_string", - "nonce": "random_nonce_string", - "scope": "openid", - "redirect_uri": "http://example.org", - } - - response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) - self.assertEqual(response.status_code, 302) - self.assertIn("http://example.org#", response["Location"]) - self.assertNotIn("access_token=", response["Location"]) - self.assertIn("id_token=", response["Location"]) - self.assertIn("state=random_state_string", response["Location"]) - - uri_query = urlparse(response["Location"]).fragment - uri_query_params = dict(parse_qs(uri_query, keep_blank_values=True, strict_parsing=True)) - id_token = uri_query_params["id_token"][0] - jwt_token = jwt.JWT(key=self.key, jwt=id_token) - claims = json.loads(jwt_token.claims) - self.assertIn("nonce", claims) - self.assertNotIn("at_hash", claims) - - def test_id_token_skip_authorization_completely_missing_nonce(self): - """ - If application.skip_authorization = True, should skip the authorization page. - """ - self.client.login(username="test_user", password="123456") - self.application.skip_authorization = True - self.application.save() - - query_data = { - "client_id": self.application.client_id, - "response_type": "id_token", - "state": "random_state_string", - "scope": "openid", - "redirect_uri": "http://example.org", - } - - response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) - self.assertEqual(response.status_code, 302) - self.assertIn("error=invalid_request", response["Location"]) - self.assertIn("error_description=Request+is+missing+mandatory+nonce+paramete", response["Location"]) - - def test_id_token_post_auth_deny(self): - """ - Test error when resource owner deny access - """ - self.client.login(username="test_user", password="123456") - - form_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "nonce": "random_nonce_string", - "scope": "openid", - "redirect_uri": "http://example.org", - "response_type": "id_token", - "allow": False, - } - - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) - self.assertEqual(response.status_code, 302) - self.assertIn("error=access_denied", response["Location"]) - - def test_access_token_and_id_token_post_auth_allow(self): - """ - Test authorization code is given for an allowed request with response_type: token - """ - self.client.login(username="test_user", password="123456") - - form_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "nonce": "random_nonce_string", - "scope": "openid", - "redirect_uri": "http://example.org", - "response_type": "id_token token", - "allow": True, - } - - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) - self.assertEqual(response.status_code, 302) - self.assertIn("http://example.org#", response["Location"]) - self.assertIn("access_token=", response["Location"]) - self.assertIn("id_token=", response["Location"]) - self.assertIn("state=random_state_string", response["Location"]) - - uri_query = urlparse(response["Location"]).fragment - uri_query_params = dict(parse_qs(uri_query, keep_blank_values=True, strict_parsing=True)) - id_token = uri_query_params["id_token"][0] - jwt_token = jwt.JWT(key=self.key, jwt=id_token) - claims = json.loads(jwt_token.claims) - self.assertIn("nonce", claims) - self.assertIn("at_hash", claims) - - def test_access_token_and_id_token_skip_authorization_completely(self): - """ - If application.skip_authorization = True, should skip the authorization page. - """ - self.client.login(username="test_user", password="123456") - self.application.skip_authorization = True - self.application.save() - - query_data = { - "client_id": self.application.client_id, - "response_type": "id_token token", - "state": "random_state_string", - "nonce": "random_nonce_string", - "scope": "openid", - "redirect_uri": "http://example.org", - } - - response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) - self.assertEqual(response.status_code, 302) - self.assertIn("http://example.org#", response["Location"]) - self.assertIn("access_token=", response["Location"]) - self.assertIn("id_token=", response["Location"]) - self.assertIn("state=random_state_string", response["Location"]) - - uri_query = urlparse(response["Location"]).fragment - uri_query_params = dict(parse_qs(uri_query, keep_blank_values=True, strict_parsing=True)) - id_token = uri_query_params["id_token"][0] - jwt_token = jwt.JWT(key=self.key, jwt=id_token) - claims = json.loads(jwt_token.claims) - self.assertIn("nonce", claims) - self.assertIn("at_hash", claims) - - def test_access_token_and_id_token_post_auth_deny(self): - """ - Test error when resource owner deny access - """ - self.client.login(username="test_user", password="123456") - - form_data = { - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "openid", - "redirect_uri": "http://example.org", - "response_type": "id_token token", - "allow": False, - } - - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) - self.assertEqual(response.status_code, 302) - self.assertIn("error=access_denied", response["Location"]) diff --git a/tests/test_oauth2_backends.py b/tests/test_oauth2_backends.py index 0d98dad8b..d844da5f4 100644 --- a/tests/test_oauth2_backends.py +++ b/tests/test_oauth2_backends.py @@ -65,9 +65,7 @@ def test_create_token_response_gets_extra_credentials(self): payload = "grant_type=password&username=john&password=123456" request = self.factory.post("/o/token/", payload, content_type="application/x-www-form-urlencoded") - with mock.patch( - "oauthlib.openid.connect.core.endpoints.pre_configured.Server.create_token_response" - ) as create_token_response: + with mock.patch("oauthlib.oauth2.Server.create_token_response") as create_token_response: mocked = mock.MagicMock() create_token_response.return_value = mocked, mocked, mocked core = self.MyOAuthLibCore() diff --git a/tests/test_oauth2_validators.py b/tests/test_oauth2_validators.py index 1a0926988..7821148d5 100644 --- a/tests/test_oauth2_validators.py +++ b/tests/test_oauth2_validators.py @@ -287,13 +287,6 @@ def test_save_bearer_token__with_new_token__calls_methods_to_create_access_and_r assert create_access_token_mock.call_count == 1 assert create_refresh_token_mock.call_count == 1 - def test_generate_at_hash(self): - # Values taken from spec, https://openid.net/specs/openid-connect-core-1_0.html#id_token-tokenExample - access_token = "jHkWEdUXMU1BwAsC4vtUsZwnNvTIxEl0z9K3vx5KF0Y" - at_hash = self.validator.generate_at_hash(access_token) - - assert at_hash == "77QmUPtjPfzWtF2AnpK9RQ" - class TestOAuth2ValidatorProvidesErrorData(TransactionTestCase): """These test cases check that the recommended error codes are returned diff --git a/tests/test_oidc_views.py b/tests/test_oidc_views.py deleted file mode 100644 index 71f41d7eb..000000000 --- a/tests/test_oidc_views.py +++ /dev/null @@ -1,77 +0,0 @@ -from __future__ import unicode_literals - -from django.test import TestCase -from django.urls import reverse - -from oauth2_provider.settings import oauth2_settings - - -class TestConnectDiscoveryInfoView(TestCase): - def test_get_connect_discovery_info(self): - expected_response = { - "issuer": "http://localhost", - "authorization_endpoint": "http://localhost/o/authorize/", - "token_endpoint": "http://localhost/o/token/", - "userinfo_endpoint": "http://localhost/userinfo/", - "jwks_uri": "http://localhost/o/jwks/", - "response_types_supported": [ - "code", - "token", - "id_token", - "id_token token", - "code token", - "code id_token", - "code id_token token" - ], - "subject_types_supported": ["public"], - "id_token_signing_alg_values_supported": ["RS256", "HS256"], - "token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"] - } - response = self.client.get(reverse("oauth2_provider:oidc-connect-discovery-info")) - self.assertEqual(response.status_code, 200) - assert response.json() == expected_response - - def test_get_connect_discovery_info_without_issuer_url(self): - oauth2_settings.OIDC_ISS_ENDPOINT = None - oauth2_settings.OIDC_USERINFO_ENDPOINT = None - expected_response = { - "issuer": "http://testserver/o", - "authorization_endpoint": "http://testserver/o/authorize/", - "token_endpoint": "http://testserver/o/token/", - "userinfo_endpoint": "http://testserver/o/userinfo/", - "jwks_uri": "http://testserver/o/jwks/", - "response_types_supported": [ - "code", - "token", - "id_token", - "id_token token", - "code token", - "code id_token", - "code id_token token" - ], - "subject_types_supported": ["public"], - "id_token_signing_alg_values_supported": ["RS256", "HS256"], - "token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"] - } - response = self.client.get(reverse("oauth2_provider:oidc-connect-discovery-info")) - self.assertEqual(response.status_code, 200) - assert response.json() == expected_response - oauth2_settings.OIDC_ISS_ENDPOINT = "http://localhost" - oauth2_settings.OIDC_USERINFO_ENDPOINT = "http://localhost/userinfo/" - - -class TestJwksInfoView(TestCase): - def test_get_jwks_info(self): - expected_response = { - "keys": [{ - "alg": "RS256", - "use": "sig", - "kid": "s4a1o8mFEd1tATAIH96caMlu4hOxzBUaI2QTqbYNBHs", - "e": "AQAB", - "kty": "RSA", - "n": "mwmIeYdjZkLgalTuhvvwjvnB5vVQc7G9DHgOm20Hw524bLVTk49IXJ2Scw42HOmowWWX-oMVT_ca3ZvVIeffVSN1-TxVy2zB65s0wDMwhiMoPv35z9IKHGMZgl9vlyso_2b7daVF_FQDdgIayUn8TQylBxEU1RFfW0QSYOBdAt8" # noqa - }] - } - response = self.client.get(reverse("oauth2_provider:jwks-info")) - self.assertEqual(response.status_code, 200) - assert response.json() == expected_response diff --git a/tests/urls.py b/tests/urls.py index c7fa9a101..16dcf6ded 100644 --- a/tests/urls.py +++ b/tests/urls.py @@ -1,11 +1,13 @@ +from django.conf.urls import include, url from django.contrib import admin -from django.urls import include, re_path admin.autodiscover() urlpatterns = [ - re_path(r"^o/", include("oauth2_provider.urls", namespace="oauth2_provider")), - re_path(r"^admin/", admin.site.urls), + url(r"^o/", include("oauth2_provider.urls", namespace="oauth2_provider")), ] + + +urlpatterns += [url(r"^admin/", admin.site.urls)] diff --git a/tox.ini b/tox.ini index 686bf366a..c984f8b99 100644 --- a/tox.ini +++ b/tox.ini @@ -14,8 +14,7 @@ envlist = django_find_project = false [testenv] -commands = - pytest --cov=oauth2_provider --cov-report= --cov-append {posargs} -s +commands = pytest --cov=oauth2_provider --cov-report= --cov-append {posargs} setenv = DJANGO_SETTINGS_MODULE = tests.settings PYTHONPATH = {toxinidir} @@ -27,7 +26,6 @@ deps = djangomaster: https://github.com/django/django/archive/master.tar.gz djangorestframework oauthlib>=3.1.0 - jwcrypto coverage pytest pytest-cov @@ -44,7 +42,6 @@ commands = make html deps = sphinx<3 oauthlib>=3.1.0 m2r>=0.2.1 - jwcrypto [testenv:py37-flake8] skip_install = True @@ -70,9 +67,7 @@ commands = [coverage:run] source = oauth2_provider -omit = - */migrations/* - oauth2_provider/settings.py +omit = */migrations/* [flake8] max-line-length = 110 From 1b2d73da1e3d904142bcf6b4eb773c00a280d980 Mon Sep 17 00:00:00 2001 From: David Smith Date: Sat, 10 Oct 2020 06:28:53 +0100 Subject: [PATCH 02/53] Updated url() to path() url() is deprecated in Django 3.1. Path is available in all supported versions of Django. --- tests/urls.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/urls.py b/tests/urls.py index 16dcf6ded..f4b22a4d4 100644 --- a/tests/urls.py +++ b/tests/urls.py @@ -1,13 +1,13 @@ -from django.conf.urls import include, url from django.contrib import admin +from django.urls import include, path admin.autodiscover() urlpatterns = [ - url(r"^o/", include("oauth2_provider.urls", namespace="oauth2_provider")), + path("o/", include("oauth2_provider.urls", namespace="oauth2_provider")), ] -urlpatterns += [url(r"^admin/", admin.site.urls)] +urlpatterns += [path("admin/", admin.site.urls)] From 342a63488fee02c86b1b3e5f399ce00a4f6765d5 Mon Sep 17 00:00:00 2001 From: Mattia Procopio Date: Fri, 16 Oct 2020 19:07:55 +0200 Subject: [PATCH 03/53] Update changelog for 1.3.3 release (#889) --- CHANGELOG.md | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f32d2eb3e..2f48ba0c9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,14 +14,22 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Security --> -## [unreleased] +## [unreleased] + +## [1.3.3] 2020-10-16 -### added +### Added * added `select_related` in intospect view for better query performance +* #831 Authorization token creation now can receive an expire date +* #831 Added a method to override Grant creation +* #825 Bump oauthlib to 3.1.0 to introduce PKCE ### Fixed * #847: Fix inappropriate message when response from authentication server is not OK. +### Changed +* few smaller improvements to remove older django version compatibility #830, #861, #862, #863 + ## [1.3.2] 2020-03-24 ### Fixed From 02a872c20d641e6731d3df2095c30a620eb3d8a9 Mon Sep 17 00:00:00 2001 From: Alan Crosswell Date: Tue, 20 Oct 2020 01:39:33 -0400 Subject: [PATCH 04/53] release 1.3.3 (#890) --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 3c4e0badc..696e45ff7 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = django-oauth-toolkit -version = 1.3.2 +version = 1.3.3 description = OAuth2 Provider for Django long_description = file: README.rst long_description_content_type = text/x-rst From 0a62a9767f31b42aff00ebbe62ff31ffed8fcad2 Mon Sep 17 00:00:00 2001 From: Alan Crosswell Date: Tue, 20 Oct 2020 10:45:46 -0400 Subject: [PATCH 05/53] improve contributing docs (#891) --- docs/contributing.rst | 72 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 71 insertions(+), 1 deletion(-) diff --git a/docs/contributing.rst b/docs/contributing.rst index 021895e38..5d36149b0 100644 --- a/docs/contributing.rst +++ b/docs/contributing.rst @@ -2,6 +2,13 @@ Contributing ============ +.. image:: https://jazzband.co/static/img/jazzband.svg + :target: https://jazzband.co/ + :alt: Jazzband + +This is a `Jazzband `_ project. By contributing you agree to abide by the `Contributor Code of Conduct `_ and follow the `guidelines `_. + + Setup ===== @@ -70,7 +77,7 @@ When you begin your PR, you'll be asked to provide the following: JazzBand security team ``. Do not file an issue on the tracker or submit a PR until directed to do so.) -* Make sure your name is in `AUTHORS`. +* Make sure your name is in `AUTHORS`. We want to give credit to all contrbutors! If your PR is not yet ready to be merged mark it as a Work-in-Progress By prepending `WIP:` to the PR title so that it doesn't get inadvertently approved and merged. @@ -106,6 +113,29 @@ How to get your pull request accepted We really want your code, so please follow these simple guidelines to make the process as smooth as possible. +The Checklist +------------- + +A checklist template is automatically added to your PR when you create it. Make sure you've done all the +applicable steps and check them off to indicate you have done so. This is +what you'll see when creating your PR: + + Fixes # + + ## Description of the Change + + ## Checklist + + - [ ] PR only contains one change (considered splitting up PR) + - [ ] unit-test added + - [ ] documentation updated + - [ ] `CHANGELOG.md` updated (only for user relevant changes) + - [ ] author name in `AUTHORS` + +Any PRs that are missing checklist items will not be merged and may be reverted if they are merged by +mistake. + + Run the tests! -------------- @@ -142,5 +172,45 @@ Try reading our code and grasp the overall philosophy regarding method and varia the sake of readability, keep in mind that *simple is better than complex*. If you feel the code is not straightforward, add a comment. If you think a function is not trivial, add a docstrings. +To see if your code formatting will pass muster use: `tox -e py37-flake8` + The contents of this page are heavily based on the docs from `django-admin2 `_ + +Maintainer Checklist +==================== +The following notes are to remind the project maintainers and leads of the steps required to +review and merge PRs and to publish a new release. + +Reviewing and Merging PRs +------------------------ + +- Make sure the PR description includes the `pull request template + `_ +- Confirm that all required checklist items from the PR template are both indicated as done in the + PR description and are actually done. +- Perform a careful review and ask for any needed changes. +- Make sure any PRs only ever improve code coverage percentage. +- All PRs should be be reviewed by one individual (not the submitter) and merged by another. + +PRs that are incorrectly merged may (reluctantly) be reverted by the Project Leads. + + +Publishing a Release +-------------------- + +Only Project Leads can publish a release to pypi.org and rtfd.io. This checklist is a reminder +of steps. + +- When planning a new release, create a `milestone + `_ + and assign issues, PRs, etc. to that milestone. +- Review all commits since the last release and confirm that they are properly + documented in the CHANGELOG. (Unfortunately, this has not always been the case + so you may be stuck documenting things that should have been documented as part of their PRs.) +- Make a final PR for the release that updates: + + - CHANGELOG to show the release date. + - setup.cfg to set `version = ...` + +- Once the final PR is committed push the new release to pypi and rtfd.io. From 6f08e3bfa5e87fed0025ccea1f6befee5dc90bb2 Mon Sep 17 00:00:00 2001 From: David Smith <39445562+smithdc1@users.noreply.github.com> Date: Fri, 23 Oct 2020 07:42:10 +0100 Subject: [PATCH 06/53] Make calls to super() more python3 (#881) --- tests/models.py | 2 +- tests/test_application_views.py | 4 ++-- tests/test_auth_backends.py | 2 +- tests/test_client_credential.py | 2 +- tests/test_decorators.py | 2 +- tests/test_mixins.py | 2 +- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/models.py b/tests/models.py index 7ca0c57c5..ad3575844 100644 --- a/tests/models.py +++ b/tests/models.py @@ -13,7 +13,7 @@ class BaseTestApplication(AbstractApplication): def get_allowed_schemes(self): if self.allowed_schemes: return self.allowed_schemes.split() - return super(BaseTestApplication, self).get_allowed_schemes() + return super().get_allowed_schemes() class SampleApplication(AbstractApplication): diff --git a/tests/test_application_views.py b/tests/test_application_views.py index 6130876ce..8f281611b 100644 --- a/tests/test_application_views.py +++ b/tests/test_application_views.py @@ -70,7 +70,7 @@ def _create_application(self, name, user): return app def setUp(self): - super(TestApplicationViews, self).setUp() + super().setUp() self.app_foo_1 = self._create_application("app foo_user 1", self.foo_user) self.app_foo_2 = self._create_application("app foo_user 2", self.foo_user) self.app_foo_3 = self._create_application("app foo_user 3", self.foo_user) @@ -79,7 +79,7 @@ def setUp(self): self.app_bar_2 = self._create_application("app bar_user 2", self.bar_user) def tearDown(self): - super(TestApplicationViews, self).tearDown() + super().tearDown() get_application_model().objects.all().delete() def test_application_list(self): diff --git a/tests/test_auth_backends.py b/tests/test_auth_backends.py index baf82169c..1e1cbb544 100644 --- a/tests/test_auth_backends.py +++ b/tests/test_auth_backends.py @@ -85,7 +85,7 @@ def test_get_user(self): class TestOAuth2Middleware(BaseTest): def setUp(self): - super(TestOAuth2Middleware, self).setUp() + super().setUp() self.anon_user = AnonymousUser() def dummy_get_response(request): diff --git a/tests/test_client_credential.py b/tests/test_client_credential.py index 09401cf0e..0f3756358 100644 --- a/tests/test_client_credential.py +++ b/tests/test_client_credential.py @@ -105,7 +105,7 @@ class TestExtendedRequest(BaseTest): @classmethod def setUpClass(cls): cls.request_factory = RequestFactory() - super(TestExtendedRequest, cls).setUpClass() + super().setUpClass() def test_extended_request(self): class TestView(OAuthLibMixin, View): diff --git a/tests/test_decorators.py b/tests/test_decorators.py index 0732b2920..80d2ae1a2 100644 --- a/tests/test_decorators.py +++ b/tests/test_decorators.py @@ -18,7 +18,7 @@ class TestProtectedResourceDecorator(TestCase): @classmethod def setUpClass(cls): cls.request_factory = RequestFactory() - super(TestProtectedResourceDecorator, cls).setUpClass() + super().setUpClass() def setUp(self): self.user = UserModel.objects.create_user("test_user", "test@example.com", "123456") diff --git a/tests/test_mixins.py b/tests/test_mixins.py index b8aa9ac4d..5a4531596 100644 --- a/tests/test_mixins.py +++ b/tests/test_mixins.py @@ -14,7 +14,7 @@ class BaseTest(TestCase): @classmethod def setUpClass(cls): cls.request_factory = RequestFactory() - super(BaseTest, cls).setUpClass() + super().setUpClass() class TestOAuthLibMixin(BaseTest): From a3c085e60da809dfd3ec1e16e049899a6c8f6922 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Skar=C5=BCy=C5=84ski?= Date: Thu, 12 Nov 2020 10:52:32 +0100 Subject: [PATCH 07/53] pass PKCE fields to AuthorizationView form (#896) * add tests for issue of PKCE authorization code GET request * pass PKCE fields to AuthorizationView form Pass code_challenge and code_challenge_method from query string to AuthorizationView form in get(). Without this, it was impossible to use authorization code grant flow with GET, because code_challenge and code_challenge_method data were never passed to form, so they weren't in form.cleaned_data, which causes creating Grant with always empty code_challenge and code_challenge_method. This issue was quite hard bug to discover because there are already few tests for authorization code flow pkce, however, they weren't checking form rendering in GET request, but only response.status_code, I have added asserts for these 2 values, please look at the changes in test_public_pkce_plain_authorize_get and test_public_pkce_S256_authorize_get tests in test_authorization_code.py. --- AUTHORS | 3 ++- oauth2_provider/views/base.py | 4 ++++ tests/test_authorization_code.py | 10 ++++++---- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/AUTHORS b/AUTHORS index 611a0e62b..ef1708d5c 100644 --- a/AUTHORS +++ b/AUTHORS @@ -30,4 +30,5 @@ Rodney Richardson Silvano Cerza Stéphane Raimbault Jun Zhou -David Smith \ No newline at end of file +David Smith +Łukasz Skarżyński diff --git a/oauth2_provider/views/base.py b/oauth2_provider/views/base.py index b9b6ed7f9..f9a28cfaa 100644 --- a/oauth2_provider/views/base.py +++ b/oauth2_provider/views/base.py @@ -156,6 +156,10 @@ def get(self, request, *args, **kwargs): kwargs["redirect_uri"] = credentials["redirect_uri"] kwargs["response_type"] = credentials["response_type"] kwargs["state"] = credentials["state"] + if "code_challenge" in credentials: + kwargs["code_challenge"] = credentials["code_challenge"] + if "code_challenge_method" in credentials: + kwargs["code_challenge_method"] = credentials["code_challenge_method"] self.oauth2_data = kwargs # following two loc are here only because of https://code.djangoproject.com/ticket/17795 diff --git a/tests/test_authorization_code.py b/tests/test_authorization_code.py index e98f5b041..a80a54490 100644 --- a/tests/test_authorization_code.py +++ b/tests/test_authorization_code.py @@ -1012,7 +1012,7 @@ def test_public_pkce_S256_authorize_get(self): """ Request an access token using client_type: public and PKCE enabled. Tests if the authorize get is successfull - for the S256 algorithm + for the S256 algorithm and form data are properly passed. """ self.client.login(username="test_user", password="123456") @@ -1033,14 +1033,15 @@ def test_public_pkce_S256_authorize_get(self): } response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) - self.assertEqual(response.status_code, 200) + self.assertContains(response, 'value="S256"', count=1, status_code=200) + self.assertContains(response, 'value="{0}"'.format(code_challenge), count=1, status_code=200) oauth2_settings.PKCE_REQUIRED = False def test_public_pkce_plain_authorize_get(self): """ Request an access token using client_type: public and PKCE enabled. Tests if the authorize get is successfull - for the plain algorithm + for the plain algorithm and form data are properly passed. """ self.client.login(username="test_user", password="123456") @@ -1061,7 +1062,8 @@ def test_public_pkce_plain_authorize_get(self): } response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) - self.assertEqual(response.status_code, 200) + self.assertContains(response, 'value="plain"', count=1, status_code=200) + self.assertContains(response, 'value="{0}"'.format(code_challenge), count=1, status_code=200) oauth2_settings.PKCE_REQUIRED = False def test_public_pkce_S256(self): From 6dc1ca20b78569fdfdaa66bb53558e0ed818257c Mon Sep 17 00:00:00 2001 From: Timm Simpkins Date: Fri, 13 Nov 2020 09:01:28 -0500 Subject: [PATCH 08/53] Fixed some grammar and spelling mistakes in the docs. (#895) --- docs/getting_started.rst | 26 +++++++++++++------------- docs/index.rst | 2 +- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/docs/getting_started.rst b/docs/getting_started.rst index c5b5ec51c..427195ae9 100644 --- a/docs/getting_started.rst +++ b/docs/getting_started.rst @@ -82,7 +82,7 @@ That’ll create a directory :file:`users`, which is laid out like this:: If you’re starting a new project, it’s highly recommended to set up a custom user model, even if the default `User`_ model is sufficient for you. This model behaves identically to the default user model, but you’ll be able to customize it in the future if the need arises. -- `Django documentation`_ -Edit :file:`users/models.py` adding the code bellow: +Edit :file:`users/models.py` adding the code below: .. code-block:: python @@ -105,7 +105,7 @@ Change :file:`iam/settings.py` to add ``users`` application to ``INSTALLED_APPS` 'users', ] -Configure ``users.User`` to be the model used for the ``auth`` application adding ``AUTH_USER_MODEL`` to :file:`iam/settings.py`: +Configure ``users.User`` to be the model used for the ``auth`` application by adding ``AUTH_USER_MODEL`` to :file:`iam/settings.py`: .. code-block:: python @@ -152,7 +152,7 @@ The ``migrate`` output:: Django OAuth Toolkit -------------------- -Django OAuth Toolkit can help you providing out of the box all the endpoints, data and logic needed to add OAuth2 capabilities to your Django projects. +Django OAuth Toolkit can help you by providing, out of the box, all the endpoints, data, and logic needed to add OAuth2 capabilities to your Django projects. Install Django OAuth Toolkit:: @@ -231,12 +231,12 @@ We will start by given a try to the grant types listed below: * Authorization code * Client credential -This two grant types cover the most initially used uses cases. +These two grant types cover the most initially used use cases. Authorization Code ------------------ -The Authorization Code flow is best used in web and mobile apps. This is the flow used for third party integration, the user authorize your partner to access its products in your APIs. +The Authorization Code flow is best used in web and mobile apps. This is the flow used for third party integration, the user authorizes your partner to access its products in your APIs. Start the development server:: @@ -256,7 +256,7 @@ Export ``Client id`` and ``Client secret`` values as environment variable: export ID=vW1RcAl7Mb0d5gyHNQIAcH110lWoOW2BmWJIero8 export SECRET=DZFpuNjRdt5xUEzxXovAp40bU3lQvoMvF3awEStn61RXWE0Ses4RgzHWKJKTvUCHfRkhcBi3ebsEfSjfEO96vo2Sh6pZlxJ6f7KcUbhvqMMPoVxRwv4vfdWEoWMGPeIO -To start the Authorization code flow got to this `URL`_ with is the same as show bellow:: +To start the Authorization code flow go to this `URL`_ which is the same as shown below:: http://127.0.0.1:8000/o/authorize/?response_type=code&client_id=vW1RcAl7Mb0d5gyHNQIAcH110lWoOW2BmWJIero8&redirect_uri=http://127.0.0.1:8000/noexist/callback @@ -273,13 +273,13 @@ Go ahead and authorize the ``web-app`` .. image:: _images/application-authorize-web-app.png :alt: Authorization code authorize web-app -Remenber we used ``http://127.0.0.1:8000/noexist/callback`` as ``redirect_uri`` you will get a **Page not found (404)** but it worked if you get a url like:: +Remember we used ``http://127.0.0.1:8000/noexist/callback`` as ``redirect_uri`` you will get a **Page not found (404)** but it worked if you get a url like:: http://127.0.0.1:8000/noexist/callback?code=uVqLxiHDKIirldDZQfSnDsmYW1Abj2 -This is the OAuth2 provider trying to give you a ``code`` in this case ``uVqLxiHDKIirldDZQfSnDsmYW1Abj2``. +This is the OAuth2 provider trying to give you a ``code``. in this case ``uVqLxiHDKIirldDZQfSnDsmYW1Abj2``. -Export it as environment variable: +Export it as an environment variable: .. code-block:: sh @@ -326,7 +326,7 @@ The Client Credential grant is suitable for machine-to-machine authentication. Y Point your browser to http://127.0.0.1:8000/o/applications/register/ lets create an application. -Fill the form as show in the screenshot bellow and before save take note of ``Client id`` and ``Client secret`` we will use it in a minute. +Fill the form as show in the screenshot below, and before saving take note of ``Client id`` and ``Client secret`` we will use it in a minute. .. image:: _images/application-register-client-credential.png :alt: Client credential application registration @@ -352,7 +352,7 @@ We need to encode ``client_id`` and ``client_secret`` as HTTP base authenticatio b'YXhYU1NCVnV2T3lHVnpoNFB1cnZLYXE1TUhYTW03RnRySGdETWk0dToxZnV2NVdWZlI3QTVCbEYwbzE1NUg3czViTGdYbHdXTGhpM1k3cGRKOWFKdUNkbDBYVjVDeGdkMHRyaTduU3pDODBxeXJvdmg4cUZYRkhnRkFBYzBsZFBObjVaWUxhbnhTbTFTSTFyeGxScldVUDU5MXdwSERHYTNwU3BCNmRDWg==' >>> -Export the credential as environment variable +Export the credential as an environment variable .. code-block:: sh @@ -362,7 +362,7 @@ To start the Client Credential flow you call ``/token/`` endpoint direct:: curl -X POST -H "Authorization: Basic ${CREDENTIAL}" -H "Cache-Control: no-cache" -H "Content-Type: application/x-www-form-urlencoded" "http://127.0.0.1:8000/o/token/" -d "grant_type=client_credentials" -To be more easy to visualize:: +To be easier to visualize:: curl -X POST \ -H "Authorization: Basic ${CREDENTIAL}" \ @@ -371,7 +371,7 @@ To be more easy to visualize:: "http://127.0.0.1:8000/o/token/" \ -d "grant_type=client_credentials" -The OAuth2 provider will return the follow response: +The OAuth2 provider will return the following response: .. code-block:: javascript diff --git a/docs/index.rst b/docs/index.rst index 635837832..51696a6f4 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -6,7 +6,7 @@ Welcome to Django OAuth Toolkit Documentation ============================================= -Django OAuth Toolkit can help you providing out of the box all the endpoints, data and logic needed to add OAuth2 +Django OAuth Toolkit can help you by providing, out of the box, all the endpoints, data, and logic needed to add OAuth2 capabilities to your Django projects. Django OAuth Toolkit makes extensive use of the excellent `OAuthLib `_, so that everything is `rfc-compliant `_. From 29bed25680ba863524e467129b21ce6f45631499 Mon Sep 17 00:00:00 2001 From: David Smith <39445562+smithdc1@users.noreply.github.com> Date: Sat, 14 Nov 2020 17:16:13 +0000 Subject: [PATCH 09/53] Added Python 3.9 to test matrix for djangomaster (#884) * Added Python 3.9 to test matrix for djangomaster * Added Python3.9 support for Django 2.2 and 3.0 --- .travis.yml | 8 ++++++++ CHANGELOG.md | 3 +++ setup.cfg | 1 + tox.ini | 2 ++ 4 files changed, 14 insertions(+) diff --git a/.travis.yml b/.travis.yml index 2aef56d6f..65284a65f 100644 --- a/.travis.yml +++ b/.travis.yml @@ -14,6 +14,7 @@ matrix: - env: TOXENV=py36-djangomaster - env: TOXENV=py37-djangomaster - env: TOXENV=py38-djangomaster + - env: TOXENV=py39-djangomaster include: - python: 3.7 @@ -21,6 +22,13 @@ matrix: - python: 3.7 env: TOXENV=py37-docs + - python: 3.9 + env: TOXENV=py39-djangomaster + - python: 3.9 + env: TOXENV=py39-django30 + - python: 3.9 + env: TOXENV=py39-django22 + - python: 3.8 env: TOXENV=py38-django30 - python: 3.8 diff --git a/CHANGELOG.md b/CHANGELOG.md index 2f48ba0c9..3fce3f882 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [unreleased] +### Added +* #884 Added support for Python 3.9 + ## [1.3.3] 2020-10-16 ### Added diff --git a/setup.cfg b/setup.cfg index 696e45ff7..6c2012991 100644 --- a/setup.cfg +++ b/setup.cfg @@ -24,6 +24,7 @@ classifiers = Programming Language :: Python :: 3.6 Programming Language :: Python :: 3.7 Programming Language :: Python :: 3.8 + Programming Language :: Python :: 3.9 Topic :: Internet :: WWW/HTTP [options] diff --git a/tox.ini b/tox.ini index c984f8b99..d3218b19f 100644 --- a/tox.ini +++ b/tox.ini @@ -6,6 +6,8 @@ envlist = py37-django{30,22,21}, py36-django{22,21}, py35-django{22,21}, + py39-django{22,30} + py39-djangomaster, py38-djangomaster, py37-djangomaster, py36-djangomaster, From afd651c8f1e160b608af807acf7c60029d465b32 Mon Sep 17 00:00:00 2001 From: David Smith Date: Sat, 14 Nov 2020 19:57:07 +0000 Subject: [PATCH 10/53] Updated supported Django versions added Support for 3.1 Removed support for 2.1 --- .travis.yml | 16 ++++++++-------- CHANGELOG.md | 1 + docs/index.rst | 2 +- setup.cfg | 4 ++-- tox.ini | 12 ++++++------ 5 files changed, 18 insertions(+), 17 deletions(-) diff --git a/.travis.yml b/.travis.yml index 65284a65f..1505d8cf3 100644 --- a/.travis.yml +++ b/.travis.yml @@ -29,33 +29,33 @@ matrix: - python: 3.9 env: TOXENV=py39-django22 + - python: 3.8 + env: TOXENV=py38-django31 - python: 3.8 env: TOXENV=py38-django30 - python: 3.8 env: TOXENV=py38-django22 - - python: 3.8 - env: TOXENV=py38-django21 - python: 3.8 env: TOXENV=py38-djangomaster + - python: 3.7 + env: TOXENV=py37-django31 - python: 3.7 env: TOXENV=py37-django30 - python: 3.7 env: TOXENV=py37-django22 - - python: 3.7 - env: TOXENV=py37-django21 - python: 3.7 env: TOXENV=py37-djangomaster - python: 3.6 - env: TOXENV=py36-django22 + env: TOXENV=py36-django31 + - python: 3.6 + env: TOXENV=py36-django30 - python: 3.6 - env: TOXENV=py36-django21 + env: TOXENV=py36-django22 - python: 3.5 env: TOXENV=py35-django22 - - python: 3.5 - env: TOXENV=py35-django21 install: - pip install coveralls tox tox-travis diff --git a/CHANGELOG.md b/CHANGELOG.md index 3fce3f882..b10ebfe30 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * #831 Authorization token creation now can receive an expire date * #831 Added a method to override Grant creation * #825 Bump oauthlib to 3.1.0 to introduce PKCE +* Support for Django 3.1 ### Fixed * #847: Fix inappropriate message when response from authentication server is not OK. diff --git a/docs/index.rst b/docs/index.rst index 51696a6f4..75ed1afcf 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -22,7 +22,7 @@ Requirements ------------ * Python 3.5+ -* Django 2.1+ +* Django 2.2+ * oauthlib 3.1+ Index diff --git a/setup.cfg b/setup.cfg index 6c2012991..df6db19d4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -13,9 +13,9 @@ classifiers = Development Status :: 5 - Production/Stable Environment :: Web Environment Framework :: Django - Framework :: Django :: 2.1 Framework :: Django :: 2.2 Framework :: Django :: 3.0 + Framework :: Django :: 3.1 Intended Audience :: Developers License :: OSI Approved :: BSD License Operating System :: OS Independent @@ -32,7 +32,7 @@ packages = find: include_package_data = True zip_safe = False install_requires = - django >= 2.1 + django >= 2.2 requests >= 2.13.0 oauthlib >= 3.1.0 diff --git a/tox.ini b/tox.ini index d3218b19f..ab677a738 100644 --- a/tox.ini +++ b/tox.ini @@ -2,11 +2,11 @@ envlist = py37-flake8, py37-docs, - py38-django{30,22,21}, - py37-django{30,22,21}, - py36-django{22,21}, - py35-django{22,21}, - py39-django{22,30} + py39-django{31,30,22}, + py38-django{31,30,22}, + py37-django{31,30,22}, + py36-django{31,30,22}, + py35-django{22}, py39-djangomaster, py38-djangomaster, py37-djangomaster, @@ -22,9 +22,9 @@ setenv = PYTHONPATH = {toxinidir} PYTHONWARNINGS = all deps = - django21: Django>=2.1,<2.2 django22: Django>=2.2,<3 django30: Django>=3.0,<3.1 + django31: Django>=3.1,<3.2 djangomaster: https://github.com/django/django/archive/master.tar.gz djangorestframework oauthlib>=3.1.0 From c2f379d103624b8eb0474524b6a656ea5870d0b9 Mon Sep 17 00:00:00 2001 From: David Smith <39445562+smithdc1@users.noreply.github.com> Date: Sat, 14 Nov 2020 19:59:36 +0000 Subject: [PATCH 11/53] Removed universal wheels (py2) --- setup.cfg | 3 --- 1 file changed, 3 deletions(-) diff --git a/setup.cfg b/setup.cfg index df6db19d4..98ef302b8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -38,6 +38,3 @@ install_requires = [options.packages.find] exclude = tests - -[bdist_wheel] -universal = 1 From 5cb5398905e5ed6bf071e3aca14b9e05aafc4026 Mon Sep 17 00:00:00 2001 From: Tom Evans Date: Tue, 17 Nov 2020 08:42:03 +0000 Subject: [PATCH 12/53] Use black for formatting the code (#887) * Add black, isort and pre-commit hooks * Add black configuration * Run black as part of flake8 testenv in tox * Add editorconfig to ensure indent style in tox.ini/setup.cfg * Add pre-commit hooks to check flake8, black, isort and common errors * Update isort configuration to be black-compatible * Update contributing documentation * Add myself to AUTHORS * Skip migrations in black/isort/pre-commit * Run black over the source tree This is the result of running `black .` over the repository. By-hand improvements of the blackened code will be in follow up commits, to make it easier to reapply this commit to future updates, if necessary - IE to remove this commit and re-run black over a fresh tree, rather than trying to merge new changes in to this commit. * Hand tweak some of black's autoformatting Some minor hand tweaks: oauth2_provider/contrib/rest_framework/authentication.py oauth2_provider/oauth2_validators.py Construct OrderedDict in a clearer, still black compliant way (one line per dict entry) tests/test_token_revocation.py Remove empty method docstrings * Apply isort over codebase Co-authored-by: Tom Evans --- .editorconfig | 15 ++ .pre-commit-config.yaml | 27 ++ AUTHORS | 1 + docs/conf.py | 151 ++++++----- docs/contributing.rst | 28 ++ oauth2_provider/admin.py | 9 +- .../contrib/rest_framework/__init__.py | 7 +- .../contrib/rest_framework/authentication.py | 14 +- .../contrib/rest_framework/permissions.py | 43 +-- oauth2_provider/decorators.py | 7 +- oauth2_provider/exceptions.py | 2 + oauth2_provider/generators.py | 1 + oauth2_provider/http.py | 5 +- .../management/commands/createapplication.py | 13 +- oauth2_provider/models.py | 112 ++++---- oauth2_provider/oauth2_backends.py | 30 +-- oauth2_provider/oauth2_validators.py | 126 ++++----- oauth2_provider/settings.py | 5 +- oauth2_provider/urls.py | 7 +- oauth2_provider/validators.py | 11 +- oauth2_provider/views/__init__.py | 21 +- oauth2_provider/views/application.py | 30 ++- oauth2_provider/views/base.py | 54 ++-- oauth2_provider/views/generic.py | 16 +- oauth2_provider/views/introspect.py | 23 +- oauth2_provider/views/mixins.py | 21 +- oauth2_provider/views/token.py | 6 +- pyproject.toml | 10 + tests/models.py | 20 +- tests/settings.py | 25 +- tests/test_application_views.py | 6 +- tests/test_auth_backends.py | 8 +- tests/test_authorization_code.py | 141 +++++----- tests/test_client_credential.py | 6 +- tests/test_commands.py | 1 - tests/test_decorators.py | 2 +- tests/test_generator.py | 7 +- tests/test_introspection_auth.py | 42 +-- tests/test_introspection_view.py | 251 ++++++++++-------- tests/test_mixins.py | 8 +- tests/test_models.py | 55 ++-- tests/test_oauth2_backends.py | 26 +- tests/test_oauth2_validators.py | 135 ++++++---- tests/test_rest_framework.py | 16 +- tests/test_scopes.py | 20 +- tests/test_token_revocation.py | 58 ++-- tests/test_token_view.py | 48 ++-- tox.ini | 36 ++- 48 files changed, 957 insertions(+), 749 deletions(-) create mode 100644 .editorconfig create mode 100644 .pre-commit-config.yaml create mode 100644 pyproject.toml diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 000000000..2ca598bbd --- /dev/null +++ b/.editorconfig @@ -0,0 +1,15 @@ +root = true + +[*] +charset = utf-8 +end_of_line = lf +indent_style = space +indent_size = 4 +insert_final_newline = true +trim_trailing_whitespace = true + +[{Makefile,tox.ini,setup.cfg}] +indent_style = tab + +[*.{yml,yaml}] +indent_size = 2 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000..323a7fcff --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,27 @@ +repos: + - repo: https://github.com/ambv/black + rev: 20.8b1 + hooks: + - id: black + exclude: ^(oauth2_provider/migrations/|tests/migrations/) + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v3.2.0 + hooks: + - id: check-ast + - id: trailing-whitespace + - id: check-merge-conflict + - id: check-json + - id: check-xml + - id: check-yaml + - id: mixed-line-ending + args: ['--fix=lf'] + - repo: https://github.com/PyCQA/isort + rev: 5.6.3 + hooks: + - id: isort + exclude: ^(oauth2_provider/migrations/|tests/migrations/) + - repo: https://gitlab.com/pycqa/flake8 + rev: 3.8.4 + hooks: + - id: flake8 + exclude: ^(oauth2_provider/migrations/|tests/migrations/) diff --git a/AUTHORS b/AUTHORS index ef1708d5c..4f9cd850b 100644 --- a/AUTHORS +++ b/AUTHORS @@ -32,3 +32,4 @@ Stéphane Raimbault Jun Zhou David Smith Łukasz Skarżyński +Tom Evans diff --git a/docs/conf.py b/docs/conf.py index 628fb4bed..fefcff4dc 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -32,27 +32,33 @@ # -- General configuration ----------------------------------------------------- # If your documentation needs a minimal Sphinx version, state it here. -#needs_sphinx = '1.0' +# needs_sphinx = '1.0' # Add any Sphinx extension module names here, as strings. They can be extensions # coming with Sphinx (named 'sphinx.ext.*') or your custom ones. -extensions = ['sphinx.ext.autodoc', 'sphinx.ext.todo', 'sphinx.ext.coverage', 'rfc', 'm2r',] +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.todo", + "sphinx.ext.coverage", + "rfc", + "m2r", +] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix of source filenames. -source_suffix = '.rst' +source_suffix = ".rst" # The encoding of source files. -#source_encoding = 'utf-8-sig' +# source_encoding = 'utf-8-sig' # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. -project = u'Django OAuth Toolkit' -copyright = u'2013, Evonove' +project = "Django OAuth Toolkit" +copyright = "2013, Evonove" # The version info for the project you're documenting, acts as replacement for @@ -66,181 +72,176 @@ # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. -#language = None +# language = None # There are two options for replacing |today|: either, you set today to some # non-false value, then it is used: -#today = '' +# today = '' # Else, today_fmt is used as the format for a strftime call. -#today_fmt = '%B %d, %Y' +# today_fmt = '%B %d, %Y' # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. -exclude_patterns = ['_build'] +exclude_patterns = ["_build"] # The reST default role (used for this markup: `text`) to use for all documents. -#default_role = None +# default_role = None # If true, '()' will be appended to :func: etc. cross-reference text. -#add_function_parentheses = True +# add_function_parentheses = True # If true, the current module name will be prepended to all description # unit titles (such as .. function::). -#add_module_names = True +# add_module_names = True # If true, sectionauthor and moduleauthor directives will be shown in the # output. They are ignored by default. -#show_authors = False +# show_authors = False # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # A list of ignored prefixes for module index sorting. -#modindex_common_prefix = [] +# modindex_common_prefix = [] # If true, keep warnings as "system message" paragraphs in the built documents. -#keep_warnings = False +# keep_warnings = False # http://www.sphinx-doc.org/en/1.5.1/ext/intersphinx.html -extensions.append('sphinx.ext.intersphinx') -intersphinx_mapping = {'python3': ('https://docs.python.org/3.6', None), - 'django': ('http://django.readthedocs.org/en/latest/', None)} - +extensions.append("sphinx.ext.intersphinx") +intersphinx_mapping = { + "python3": ("https://docs.python.org/3.6", None), + "django": ("http://django.readthedocs.org/en/latest/", None), +} # -- Options for HTML output --------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. -#html_theme = 'classic' +# html_theme = 'classic' # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. -#html_theme_options = {} +# html_theme_options = {} # Add any paths that contain custom themes here, relative to this directory. -#html_theme_path = [] +# html_theme_path = [] # The name for this set of Sphinx documents. If None, it defaults to # " v documentation". -#html_title = None +# html_title = None # A shorter title for the navigation bar. Default is the same as html_title. -#html_short_title = None +# html_short_title = None # The name of an image file (relative to this directory) to place at the top # of the sidebar. -#html_logo = None +# html_logo = None # The name of an image file (within the static path) to use as favicon of the # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 # pixels large. -#html_favicon = None +# html_favicon = None # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -#html_static_path = ['_static'] +# html_static_path = ['_static'] # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, # using the given strftime format. -#html_last_updated_fmt = '%b %d, %Y' +# html_last_updated_fmt = '%b %d, %Y' # If true, SmartyPants will be used to convert quotes and dashes to # typographically correct entities. -#html_use_smartypants = True +# html_use_smartypants = True # Custom sidebar templates, maps document names to template names. -#html_sidebars = {} +# html_sidebars = {} # Additional templates that should be rendered to pages, maps page names to # template names. -#html_additional_pages = {} +# html_additional_pages = {} # If false, no module index is generated. -#html_domain_indices = True +# html_domain_indices = True # If false, no index is generated. -#html_use_index = True +# html_use_index = True # If true, the index is split into individual pages for each letter. -#html_split_index = False +# html_split_index = False # If true, links to the reST sources are added to the pages. -#html_show_sourcelink = True +# html_show_sourcelink = True # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. -#html_show_sphinx = True +# html_show_sphinx = True # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. -#html_show_copyright = True +# html_show_copyright = True # If true, an OpenSearch description file will be output, and all pages will # contain a tag referring to it. The value of this option must be the # base URL from which the finished HTML is served. -#html_use_opensearch = '' +# html_use_opensearch = '' # This is the file name suffix for HTML files (e.g. ".xhtml"). -#html_file_suffix = None +# html_file_suffix = None # Output file base name for HTML help builder. -htmlhelp_basename = 'DjangoOAuthToolkitdoc' +htmlhelp_basename = "DjangoOAuthToolkitdoc" # -- Options for LaTeX output -------------------------------------------------- latex_elements = { -# The paper size ('letterpaper' or 'a4paper'). -#'papersize': 'letterpaper', - -# The font size ('10pt', '11pt' or '12pt'). -#'pointsize': '10pt', - -# Additional stuff for the LaTeX preamble. -#'preamble': '', + # The paper size ('letterpaper' or 'a4paper'). + #'papersize': 'letterpaper', + # The font size ('10pt', '11pt' or '12pt'). + #'pointsize': '10pt', + # Additional stuff for the LaTeX preamble. + #'preamble': '', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, author, documentclass [howto/manual]). latex_documents = [ - ('index', 'DjangoOAuthToolkit.tex', u'Django OAuth Toolkit Documentation', - u'Evonove', 'manual'), + ("index", "DjangoOAuthToolkit.tex", "Django OAuth Toolkit Documentation", "Evonove", "manual"), ] # The name of an image file (relative to this directory) to place at the top of # the title page. -#latex_logo = None +# latex_logo = None # For "manual" documents, if this is true, then toplevel headings are parts, # not chapters. -#latex_use_parts = False +# latex_use_parts = False # If true, show page references after internal links. -#latex_show_pagerefs = False +# latex_show_pagerefs = False # If true, show URL addresses after external links. -#latex_show_urls = False +# latex_show_urls = False # Documents to append as an appendix to all manuals. -#latex_appendices = [] +# latex_appendices = [] # If false, no module index is generated. -#latex_domain_indices = True +# latex_domain_indices = True # -- Options for manual page output -------------------------------------------- # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [ - ('index', 'djangooauthtoolkit', u'Django OAuth Toolkit Documentation', - [u'Evonove'], 1) -] +man_pages = [("index", "djangooauthtoolkit", "Django OAuth Toolkit Documentation", ["Evonove"], 1)] # If true, show URL addresses after external links. -#man_show_urls = False +# man_show_urls = False # -- Options for Texinfo output ------------------------------------------------ @@ -249,19 +250,25 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - ('index', 'DjangoOAuthToolkit', u'Django OAuth Toolkit Documentation', - u'Evonove', 'DjangoOAuthToolkit', 'One line description of project.', - 'Miscellaneous'), + ( + "index", + "DjangoOAuthToolkit", + "Django OAuth Toolkit Documentation", + "Evonove", + "DjangoOAuthToolkit", + "One line description of project.", + "Miscellaneous", + ), ] # Documents to append as an appendix to all manuals. -#texinfo_appendices = [] +# texinfo_appendices = [] # If false, no module index is generated. -#texinfo_domain_indices = True +# texinfo_domain_indices = True # How to display URL addresses: 'footnote', 'no', or 'inline'. -#texinfo_show_urls = 'footnote' +# texinfo_show_urls = 'footnote' # If true, do not generate a @detailmenu in the "Top" node's menu. -#texinfo_no_detailmenu = False +# texinfo_no_detailmenu = False diff --git a/docs/contributing.rst b/docs/contributing.rst index 5d36149b0..39ed1a427 100644 --- a/docs/contributing.rst +++ b/docs/contributing.rst @@ -24,6 +24,34 @@ You can find the list of bugs, enhancements and feature requests on the `issue tracker `_. If you want to fix an issue, pick up one and add a comment stating you're working on it. +Code Style +========== + +The project uses `flake8 `_ for linting, +`black `_ for formatting the code, +`isort `_ for formatting and sorting imports, +and `pre-commit `_ for checking/fixing commits for +correctness before they are made. + +You will need to install ``pre-commit`` yourself, and then ``pre-commit`` will +take care of installing ``flake8``, ``black`` and ``isort``. + +After cloning your repository, go into it and run:: + + pre-commit install + +to install the hooks. On the next commit that you make, ``pre-commit`` will +download and install the necessary hooks (a one off task). If anything in the +commit would fail the hooks, the commit will be abandoned. For ``black`` and +``isort``, any necessary changes will be made automatically, but not staged. +Review the changes, and then re-stage and commit again. + +Using ``pre-commit`` ensures that code that would fail in QA does not make it +into a commit in the first place, and will save you time in the long run. You +can also (largely) stop worrying about code style, although you should always +check how the code looks after ``black`` has formatted it, and think if there +is a better way to structure the code so that it is more readable. + Pull requests ============= diff --git a/oauth2_provider/admin.py b/oauth2_provider/admin.py index 8b963d981..a2ec8501a 100644 --- a/oauth2_provider/admin.py +++ b/oauth2_provider/admin.py @@ -1,9 +1,6 @@ from django.contrib import admin -from .models import ( - get_access_token_model, get_application_model, - get_grant_model, get_refresh_token_model -) +from .models import get_access_token_model, get_application_model, get_grant_model, get_refresh_token_model class ApplicationAdmin(admin.ModelAdmin): @@ -13,12 +10,12 @@ class ApplicationAdmin(admin.ModelAdmin): "client_type": admin.HORIZONTAL, "authorization_grant_type": admin.VERTICAL, } - raw_id_fields = ("user", ) + raw_id_fields = ("user",) class GrantAdmin(admin.ModelAdmin): list_display = ("code", "application", "user", "expires") - raw_id_fields = ("user", ) + raw_id_fields = ("user",) class AccessTokenAdmin(admin.ModelAdmin): diff --git a/oauth2_provider/contrib/rest_framework/__init__.py b/oauth2_provider/contrib/rest_framework/__init__.py index a004c1872..b54f42220 100644 --- a/oauth2_provider/contrib/rest_framework/__init__.py +++ b/oauth2_provider/contrib/rest_framework/__init__.py @@ -1,6 +1,9 @@ # flake8: noqa from .authentication import OAuth2Authentication from .permissions import ( - TokenHasScope, TokenHasReadWriteScope, TokenMatchesOASRequirements, - TokenHasResourceScope, IsAuthenticatedOrTokenHasScope + IsAuthenticatedOrTokenHasScope, + TokenHasReadWriteScope, + TokenHasResourceScope, + TokenHasScope, + TokenMatchesOASRequirements, ) diff --git a/oauth2_provider/contrib/rest_framework/authentication.py b/oauth2_provider/contrib/rest_framework/authentication.py index 228361967..53087f756 100644 --- a/oauth2_provider/contrib/rest_framework/authentication.py +++ b/oauth2_provider/contrib/rest_framework/authentication.py @@ -9,16 +9,14 @@ class OAuth2Authentication(BaseAuthentication): """ OAuth 2 authentication backend using `django-oauth-toolkit` """ + www_authenticate_realm = "api" def _dict_to_string(self, my_dict): """ Return a string of comma-separated key-value pairs (e.g. k="v",k2="v2"). """ - return ",".join([ - '{k}="{v}"'.format(k=k, v=v) - for k, v in my_dict.items() - ]) + return ",".join(['{k}="{v}"'.format(k=k, v=v) for k, v in my_dict.items()]) def authenticate(self, request): """ @@ -36,9 +34,11 @@ def authenticate_header(self, request): """ Bearer is the only finalized type currently """ - www_authenticate_attributes = OrderedDict([ - ("realm", self.www_authenticate_realm,), - ]) + www_authenticate_attributes = OrderedDict( + [ + ("realm", self.www_authenticate_realm), + ] + ) oauth2_error = getattr(request, "oauth2_error", {}) www_authenticate_attributes.update(oauth2_error) return "Bearer {attributes}".format( diff --git a/oauth2_provider/contrib/rest_framework/permissions.py b/oauth2_provider/contrib/rest_framework/permissions.py index 7ba1c5c71..1050bf751 100644 --- a/oauth2_provider/contrib/rest_framework/permissions.py +++ b/oauth2_provider/contrib/rest_framework/permissions.py @@ -2,9 +2,7 @@ from django.core.exceptions import ImproperlyConfigured from rest_framework.exceptions import PermissionDenied -from rest_framework.permissions import ( - SAFE_METHODS, BasePermission, IsAuthenticated -) +from rest_framework.permissions import SAFE_METHODS, BasePermission, IsAuthenticated from ...settings import oauth2_settings from .authentication import OAuth2Authentication @@ -33,10 +31,10 @@ def has_permission(self, request, view): # Provide information about required scope? include_required_scope = ( - oauth2_settings.ERROR_RESPONSE_WITH_SCOPES and - required_scopes and - not token.is_expired() and - not token.allow_scopes(required_scopes) + oauth2_settings.ERROR_RESPONSE_WITH_SCOPES + and required_scopes + and not token.is_expired() + and not token.allow_scopes(required_scopes) ) if include_required_scope: @@ -47,9 +45,11 @@ def has_permission(self, request, view): return False - assert False, ("TokenHasScope requires the" - "`oauth2_provider.rest_framework.OAuth2Authentication` authentication " - "class to be used.") + assert False, ( + "TokenHasScope requires the" + "`oauth2_provider.rest_framework.OAuth2Authentication` authentication " + "class to be used." + ) def get_scopes(self, request, view): try: @@ -96,9 +96,7 @@ def get_scopes(self, request, view): else: scope_type = oauth2_settings.WRITE_SCOPE - required_scopes = [ - "{}:{}".format(scope, scope_type) for scope in view_scopes - ] + required_scopes = ["{}:{}".format(scope, scope_type) for scope in view_scopes] return required_scopes @@ -113,6 +111,7 @@ class IsAuthenticatedOrTokenHasScope(BasePermission): the browsable api's if they log in using the a non token bassed middleware, and let them access the api's using a rest client with a token """ + def has_permission(self, request, view): is_authenticated = IsAuthenticated().has_permission(request, view) oauth2authenticated = False @@ -155,8 +154,11 @@ def has_permission(self, request, view): m = request.method.upper() if m in required_alternate_scopes: - log.debug("Required scopes alternatives to access resource: {0}" - .format(required_alternate_scopes[m])) + log.debug( + "Required scopes alternatives to access resource: {0}".format( + required_alternate_scopes[m] + ) + ) for alt in required_alternate_scopes[m]: if token.is_valid(alt): return True @@ -165,9 +167,11 @@ def has_permission(self, request, view): log.warning("no scope alternates defined for method {0}".format(m)) return False - assert False, ("TokenMatchesOASRequirements requires the" - "`oauth2_provider.rest_framework.OAuth2Authentication` authentication " - "class to be used.") + assert False, ( + "TokenMatchesOASRequirements requires the" + "`oauth2_provider.rest_framework.OAuth2Authentication` authentication " + "class to be used." + ) def get_required_alternate_scopes(self, request, view): try: @@ -175,4 +179,5 @@ def get_required_alternate_scopes(self, request, view): except AttributeError: raise ImproperlyConfigured( "TokenMatchesOASRequirements requires the view to" - " define the required_alternate_scopes attribute") + " define the required_alternate_scopes attribute" + ) diff --git a/oauth2_provider/decorators.py b/oauth2_provider/decorators.py index d4b7085aa..0ab26ddb4 100644 --- a/oauth2_provider/decorators.py +++ b/oauth2_provider/decorators.py @@ -33,7 +33,9 @@ def _validate(request, *args, **kwargs): request.resource_owner = oauthlib_req.user return view_func(request, *args, **kwargs) return HttpResponseForbidden() + return _validate + return decorator @@ -62,8 +64,7 @@ def _validate(request, *args, **kwargs): if not set(read_write_scopes).issubset(set(provided_scopes)): raise ImproperlyConfigured( "rw_protected_resource decorator requires following scopes {0}" - " to be in OAUTH2_PROVIDER['SCOPES'] list in settings".format( - read_write_scopes) + " to be in OAUTH2_PROVIDER['SCOPES'] list in settings".format(read_write_scopes) ) # Check if method is safe @@ -80,5 +81,7 @@ def _validate(request, *args, **kwargs): request.resource_owner = oauthlib_req.user return view_func(request, *args, **kwargs) return HttpResponseForbidden() + return _validate + return decorator diff --git a/oauth2_provider/exceptions.py b/oauth2_provider/exceptions.py index 215515500..c4208488d 100644 --- a/oauth2_provider/exceptions.py +++ b/oauth2_provider/exceptions.py @@ -2,6 +2,7 @@ class OAuthToolkitError(Exception): """ Base class for exceptions """ + def __init__(self, error=None, redirect_uri=None, *args, **kwargs): super().__init__(*args, **kwargs) self.oauthlib_error = error @@ -14,4 +15,5 @@ class FatalClientError(OAuthToolkitError): """ Class for critical errors """ + pass diff --git a/oauth2_provider/generators.py b/oauth2_provider/generators.py index ab5d25a7a..f72bc6e7a 100644 --- a/oauth2_provider/generators.py +++ b/oauth2_provider/generators.py @@ -8,6 +8,7 @@ class BaseHashGenerator: """ All generators should extend this class overriding `.hash()` method. """ + def hash(self): raise NotImplementedError() diff --git a/oauth2_provider/http.py b/oauth2_provider/http.py index 980cb7bd4..274ed81af 100644 --- a/oauth2_provider/http.py +++ b/oauth2_provider/http.py @@ -11,6 +11,7 @@ class OAuth2ResponseRedirect(HttpResponse): Works like django.http.HttpResponseRedirect but we customize it to give us more flexibility on allowed scheme validation. """ + status_code = 302 def __init__(self, redirect_to, allowed_schemes, *args, **kwargs): @@ -28,6 +29,4 @@ def validate_redirect(self, redirect_to): if not parsed.scheme: raise DisallowedRedirect("OAuth2 redirects require a URI scheme.") if parsed.scheme not in self.allowed_schemes: - raise DisallowedRedirect( - "Redirect to scheme {!r} is not permitted".format(parsed.scheme) - ) + raise DisallowedRedirect("Redirect to scheme {!r} is not permitted".format(parsed.scheme)) diff --git a/oauth2_provider/management/commands/createapplication.py b/oauth2_provider/management/commands/createapplication.py index 95cb2d865..92c4ae46b 100644 --- a/oauth2_provider/management/commands/createapplication.py +++ b/oauth2_provider/management/commands/createapplication.py @@ -72,15 +72,10 @@ def handle(self, *args, **options): try: new_application.full_clean() except ValidationError as exc: - errors = "\n ".join(["- " + err_key + ": " + str(err_value) for err_key, - err_value in exc.message_dict.items()]) - self.stdout.write( - self.style.ERROR( - "Please correct the following errors:\n %s" % errors - ) + errors = "\n ".join( + ["- " + err_key + ": " + str(err_value) for err_key, err_value in exc.message_dict.items()] ) + self.stdout.write(self.style.ERROR("Please correct the following errors:\n %s" % errors)) else: new_application.save() - self.stdout.write( - self.style.SUCCESS("New application created successfully") - ) + self.stdout.write(self.style.SUCCESS("New application created successfully")) diff --git a/oauth2_provider/models.py b/oauth2_provider/models.py index 5676bc0c5..77542d35f 100644 --- a/oauth2_provider/models.py +++ b/oauth2_provider/models.py @@ -39,6 +39,7 @@ class AbstractApplication(models.Model): the registration process as described in :rfc:`2.2` * :attr:`name` Friendly name for the Application """ + CLIENT_CONFIDENTIAL = "confidential" CLIENT_PUBLIC = "public" CLIENT_TYPES = ( @@ -58,22 +59,21 @@ class AbstractApplication(models.Model): ) id = models.BigAutoField(primary_key=True) - client_id = models.CharField( - max_length=100, unique=True, default=generate_client_id, db_index=True - ) + client_id = models.CharField(max_length=100, unique=True, default=generate_client_id, db_index=True) user = models.ForeignKey( settings.AUTH_USER_MODEL, related_name="%(app_label)s_%(class)s", - null=True, blank=True, on_delete=models.CASCADE + null=True, + blank=True, + on_delete=models.CASCADE, ) redirect_uris = models.TextField( - blank=True, help_text=_("Allowed URIs list, space separated"), + blank=True, + help_text=_("Allowed URIs list, space separated"), ) client_type = models.CharField(max_length=32, choices=CLIENT_TYPES) - authorization_grant_type = models.CharField( - max_length=32, choices=GRANT_TYPES - ) + authorization_grant_type = models.CharField(max_length=32, choices=GRANT_TYPES) client_secret = models.CharField( max_length=255, blank=True, default=generate_client_secret, db_index=True ) @@ -115,9 +115,11 @@ def redirect_uri_allowed(self, uri): for allowed_uri in self.redirect_uris.split(): parsed_allowed_uri = urlparse(allowed_uri) - if (parsed_allowed_uri.scheme == parsed_uri.scheme and - parsed_allowed_uri.netloc == parsed_uri.netloc and - parsed_allowed_uri.path == parsed_uri.path): + if ( + parsed_allowed_uri.scheme == parsed_uri.scheme + and parsed_allowed_uri.netloc == parsed_uri.netloc + and parsed_allowed_uri.path == parsed_uri.path + ): aqs_set = set(parse_qsl(parsed_allowed_uri.query)) @@ -143,14 +145,14 @@ def clean(self): validator(uri) scheme = urlparse(uri).scheme if scheme not in allowed_schemes: - raise ValidationError(_( - "Unauthorized redirect scheme: {scheme}" - ).format(scheme=scheme)) + raise ValidationError(_("Unauthorized redirect scheme: {scheme}").format(scheme=scheme)) elif self.authorization_grant_type in grant_types: - raise ValidationError(_( - "redirect_uris cannot be empty with grant_type {grant_type}" - ).format(grant_type=self.authorization_grant_type)) + raise ValidationError( + _("redirect_uris cannot be empty with grant_type {grant_type}").format( + grant_type=self.authorization_grant_type + ) + ) def get_absolute_url(self): return reverse("oauth2_provider:detail", args=[str(self.id)]) @@ -206,22 +208,17 @@ class AbstractGrant(models.Model): * :attr:`code_challenge` PKCE code challenge * :attr:`code_challenge_method` PKCE code challenge transform algorithm """ + CODE_CHALLENGE_PLAIN = "plain" CODE_CHALLENGE_S256 = "S256" - CODE_CHALLENGE_METHODS = ( - (CODE_CHALLENGE_PLAIN, "plain"), - (CODE_CHALLENGE_S256, "S256") - ) + CODE_CHALLENGE_METHODS = ((CODE_CHALLENGE_PLAIN, "plain"), (CODE_CHALLENGE_S256, "S256")) id = models.BigAutoField(primary_key=True) user = models.ForeignKey( - settings.AUTH_USER_MODEL, on_delete=models.CASCADE, - related_name="%(app_label)s_%(class)s" + settings.AUTH_USER_MODEL, on_delete=models.CASCADE, related_name="%(app_label)s_%(class)s" ) code = models.CharField(max_length=255, unique=True) # code comes from oauthlib - application = models.ForeignKey( - oauth2_settings.APPLICATION_MODEL, on_delete=models.CASCADE - ) + application = models.ForeignKey(oauth2_settings.APPLICATION_MODEL, on_delete=models.CASCADE) expires = models.DateTimeField() redirect_uri = models.CharField(max_length=255) scope = models.TextField(blank=True) @@ -231,7 +228,8 @@ class AbstractGrant(models.Model): code_challenge = models.CharField(max_length=128, blank=True, default="") code_challenge_method = models.CharField( - max_length=10, blank=True, default="", choices=CODE_CHALLENGE_METHODS) + max_length=10, blank=True, default="", choices=CODE_CHALLENGE_METHODS + ) def is_expired(self): """ @@ -271,19 +269,32 @@ class AbstractAccessToken(models.Model): * :attr:`expires` Date and time of token expiration, in DateTime format * :attr:`scope` Allowed scopes """ + id = models.BigAutoField(primary_key=True) user = models.ForeignKey( - settings.AUTH_USER_MODEL, on_delete=models.CASCADE, blank=True, null=True, - related_name="%(app_label)s_%(class)s" + settings.AUTH_USER_MODEL, + on_delete=models.CASCADE, + blank=True, + null=True, + related_name="%(app_label)s_%(class)s", ) source_refresh_token = models.OneToOneField( # unique=True implied by the OneToOneField - oauth2_settings.REFRESH_TOKEN_MODEL, on_delete=models.SET_NULL, blank=True, null=True, - related_name="refreshed_access_token" + oauth2_settings.REFRESH_TOKEN_MODEL, + on_delete=models.SET_NULL, + blank=True, + null=True, + related_name="refreshed_access_token", + ) + token = models.CharField( + max_length=255, + unique=True, ) - token = models.CharField(max_length=255, unique=True, ) application = models.ForeignKey( - oauth2_settings.APPLICATION_MODEL, on_delete=models.CASCADE, blank=True, null=True, + oauth2_settings.APPLICATION_MODEL, + on_delete=models.CASCADE, + blank=True, + null=True, ) expires = models.DateTimeField() scope = models.TextField(blank=True) @@ -364,17 +375,19 @@ class AbstractRefreshToken(models.Model): bounded to * :attr:`revoked` Timestamp of when this refresh token was revoked """ + id = models.BigAutoField(primary_key=True) user = models.ForeignKey( - settings.AUTH_USER_MODEL, on_delete=models.CASCADE, - related_name="%(app_label)s_%(class)s" + settings.AUTH_USER_MODEL, on_delete=models.CASCADE, related_name="%(app_label)s_%(class)s" ) token = models.CharField(max_length=255) - application = models.ForeignKey( - oauth2_settings.APPLICATION_MODEL, on_delete=models.CASCADE) + application = models.ForeignKey(oauth2_settings.APPLICATION_MODEL, on_delete=models.CASCADE) access_token = models.OneToOneField( - oauth2_settings.ACCESS_TOKEN_MODEL, on_delete=models.SET_NULL, blank=True, null=True, - related_name="refresh_token" + oauth2_settings.ACCESS_TOKEN_MODEL, + on_delete=models.SET_NULL, + blank=True, + null=True, + related_name="refresh_token", ) created = models.DateTimeField(auto_now_add=True) @@ -388,9 +401,11 @@ def revoke(self): access_token_model = get_access_token_model() refresh_token_model = get_refresh_token_model() with transaction.atomic(): - self = refresh_token_model.objects.filter( - pk=self.pk, revoked__isnull=True - ).select_for_update().first() + self = ( + refresh_token_model.objects.filter(pk=self.pk, revoked__isnull=True) + .select_for_update() + .first() + ) if not self: return @@ -407,7 +422,10 @@ def __str__(self): class Meta: abstract = True - unique_together = ("token", "revoked",) + unique_together = ( + "token", + "revoked", + ) class RefreshToken(AbstractRefreshToken): @@ -466,13 +484,9 @@ def clear_expired(): revoked.delete() expired.delete() else: - logger.info("refresh_expire_at is %s. No refresh tokens deleted.", - refresh_expire_at) + logger.info("refresh_expire_at is %s. No refresh tokens deleted.", refresh_expire_at) - access_tokens = access_token_model.objects.filter( - refresh_token__isnull=True, - expires__lt=now - ) + access_tokens = access_token_model.objects.filter(refresh_token__isnull=True, expires__lt=now) grants = grant_model.objects.filter(expires__lt=now) logger.info("%s Expired access tokens to be deleted", access_tokens.count()) diff --git a/oauth2_provider/oauth2_backends.py b/oauth2_provider/oauth2_backends.py index 6d8e68a2c..34b1c62cd 100644 --- a/oauth2_provider/oauth2_backends.py +++ b/oauth2_provider/oauth2_backends.py @@ -24,9 +24,7 @@ def __init__(self, server=None): validator_class = oauth2_settings.OAUTH2_VALIDATOR_CLASS validator = validator_class() server_kwargs = oauth2_settings.server_kwargs - self.server = server or oauth2_settings.OAUTH2_SERVER_CLASS( - validator, **server_kwargs - ) + self.server = server or oauth2_settings.OAUTH2_SERVER_CLASS(validator, **server_kwargs) def _get_escaped_full_path(self, request): """ @@ -96,7 +94,8 @@ def validate_authorization_request(self, request): try: uri, http_method, body, headers = self._extract_params(request) scopes, credentials = self.server.validate_authorization_request( - uri, http_method=http_method, body=body, headers=headers) + uri, http_method=http_method, body=body, headers=headers + ) return scopes, credentials except oauth2.FatalClientError as error: @@ -117,24 +116,22 @@ def create_authorization_response(self, request, scopes, credentials, allow): """ try: if not allow: - raise oauth2.AccessDeniedError( - state=credentials.get("state", None)) + raise oauth2.AccessDeniedError(state=credentials.get("state", None)) # add current user to credentials. this will be used by OAUTH2_VALIDATOR_CLASS credentials["user"] = request.user headers, body, status = self.server.create_authorization_response( - uri=credentials["redirect_uri"], scopes=scopes, credentials=credentials) + uri=credentials["redirect_uri"], scopes=scopes, credentials=credentials + ) uri = headers.get("Location", None) return uri, headers, body, status except oauth2.FatalClientError as error: - raise FatalClientError( - error=error, redirect_uri=credentials["redirect_uri"]) + raise FatalClientError(error=error, redirect_uri=credentials["redirect_uri"]) except oauth2.OAuth2Error as error: - raise OAuthToolkitError( - error=error, redirect_uri=credentials["redirect_uri"]) + raise OAuthToolkitError(error=error, redirect_uri=credentials["redirect_uri"]) def create_token_response(self, request): """ @@ -145,8 +142,9 @@ def create_token_response(self, request): uri, http_method, body, headers = self._extract_params(request) extra_credentials = self._get_extra_credentials(request) - headers, body, status = self.server.create_token_response(uri, http_method, body, - headers, extra_credentials) + headers, body, status = self.server.create_token_response( + uri, http_method, body, headers, extra_credentials + ) uri = headers.get("Location", None) return uri, headers, body, status @@ -160,8 +158,7 @@ def create_revocation_response(self, request): """ uri, http_method, body, headers = self._extract_params(request) - headers, body, status = self.server.create_revocation_response( - uri, http_method, body, headers) + headers, body, status = self.server.create_revocation_response(uri, http_method, body, headers) uri = headers.get("Location", None) return uri, headers, body, status @@ -175,8 +172,7 @@ def verify_request(self, request, scopes): """ uri, http_method, body, headers = self._extract_params(request) - valid, r = self.server.verify_request( - uri, http_method, body, headers, scopes=scopes) + valid, r = self.server.verify_request(uri, http_method, body, headers, scopes=scopes) return valid, r def authenticate_client(self, request): diff --git a/oauth2_provider/oauth2_validators.py b/oauth2_provider/oauth2_validators.py index 515353d6f..de707bb21 100644 --- a/oauth2_provider/oauth2_validators.py +++ b/oauth2_provider/oauth2_validators.py @@ -19,8 +19,11 @@ from .exceptions import FatalClientError from .models import ( - AbstractApplication, get_access_token_model, - get_application_model, get_grant_model, get_refresh_token_model + AbstractApplication, + get_access_token_model, + get_application_model, + get_grant_model, + get_refresh_token_model, ) from .scopes import get_scopes_backend from .settings import oauth2_settings @@ -29,14 +32,14 @@ log = logging.getLogger("oauth2_provider") GRANT_TYPE_MAPPING = { - "authorization_code": (AbstractApplication.GRANT_AUTHORIZATION_CODE, ), - "password": (AbstractApplication.GRANT_PASSWORD, ), - "client_credentials": (AbstractApplication.GRANT_CLIENT_CREDENTIALS, ), + "authorization_code": (AbstractApplication.GRANT_AUTHORIZATION_CODE,), + "password": (AbstractApplication.GRANT_PASSWORD,), + "client_credentials": (AbstractApplication.GRANT_CLIENT_CREDENTIALS,), "refresh_token": ( AbstractApplication.GRANT_AUTHORIZATION_CODE, AbstractApplication.GRANT_PASSWORD, AbstractApplication.GRANT_CLIENT_CREDENTIALS, - ) + ), } Application = get_application_model() @@ -91,10 +94,7 @@ def _authenticate_basic_auth(self, request): try: auth_string_decoded = b64_decoded.decode(encoding) except UnicodeDecodeError: - log.debug( - "Failed basic auth: %r can't be decoded as unicode by %r", - auth_string, encoding - ) + log.debug("Failed basic auth: %r can't be decoded as unicode by %r", auth_string, encoding) return False try: @@ -162,25 +162,33 @@ def _load_application(self, client_id, request): def _set_oauth2_error_on_request(self, request, access_token, scopes): if access_token is None: - error = OrderedDict([ - ("error", "invalid_token", ), - ("error_description", _("The access token is invalid."), ), - ]) + error = OrderedDict( + [ + ("error", "invalid_token"), + ("error_description", _("The access token is invalid.")), + ] + ) elif access_token.is_expired(): - error = OrderedDict([ - ("error", "invalid_token", ), - ("error_description", _("The access token has expired."), ), - ]) + error = OrderedDict( + [ + ("error", "invalid_token"), + ("error_description", _("The access token has expired.")), + ] + ) elif not access_token.allow_scopes(scopes): - error = OrderedDict([ - ("error", "insufficient_scope", ), - ("error_description", _("The access token is valid but does not have enough scope."), ), - ]) + error = OrderedDict( + [ + ("error", "insufficient_scope"), + ("error_description", _("The access token is valid but does not have enough scope.")), + ] + ) else: log.warning("OAuth2 access token is invalid for an unknown reason.") - error = OrderedDict([ - ("error", "invalid_token", ), - ]) + error = OrderedDict( + [ + ("error", "invalid_token"), + ] + ) request.oauth2_error = error return request @@ -270,7 +278,7 @@ def get_default_redirect_uri(self, client_id, request, *args, **kwargs): return request.client.default_redirect_uri def _get_token_from_authentication_server( - self, token, introspection_url, introspection_token, introspection_credentials + self, token, introspection_url, introspection_token, introspection_credentials ): """Use external introspection endpoint to "crack open" the token. :param introspection_url: introspection endpoint URL @@ -297,20 +305,18 @@ def _get_token_from_authentication_server( headers = {"Authorization": "Basic {}".format(basic_auth.decode("utf-8"))} try: - response = requests.post( - introspection_url, - data={"token": token}, headers=headers - ) + response = requests.post(introspection_url, data={"token": token}, headers=headers) except requests.exceptions.RequestException: log.exception("Introspection: Failed POST to %r in token lookup", introspection_url) return None # Log an exception when response from auth server is not successful if response.status_code != http.client.OK: - log.exception("Introspection: Failed to get a valid response " - "from authentication server. Status code: {}, " - "Reason: {}.".format(response.status_code, - response.reason)) + log.exception( + "Introspection: Failed to get a valid response " + "from authentication server. Status code: {}, " + "Reason: {}.".format(response.status_code, response.reason) + ) return None try: @@ -348,7 +354,8 @@ def _get_token_from_authentication_server( "application": None, "scope": scope, "expires": expires, - }) + }, + ) return access_token @@ -372,10 +379,7 @@ def validate_bearer_token(self, token, scopes, request): if not access_token or not access_token.is_valid(scopes): if introspection_url and (introspection_token or introspection_credentials): access_token = self._get_token_from_authentication_server( - token, - introspection_url, - introspection_token, - introspection_credentials + token, introspection_url, introspection_token, introspection_credentials ) if access_token and access_token.is_valid(scopes): @@ -406,7 +410,7 @@ def validate_grant_type(self, client_id, grant_type, client, request, *args, **k """ Validate both grant_type is a valid string and grant_type is allowed for current workflow """ - assert(grant_type in GRANT_TYPE_MAPPING) # mapping misconfiguration + assert grant_type in GRANT_TYPE_MAPPING # mapping misconfiguration return request.client.allows_grant_type(*GRANT_TYPE_MAPPING[grant_type]) def validate_response_type(self, client_id, response_type, client, request, *args, **kwargs): @@ -477,9 +481,12 @@ def save_bearer_token(self, token, request, *args, **kwargs): # expires_in is passed to Server on initialization # custom server class can have logic to override this - expires = timezone.now() + timedelta(seconds=token.get( - "expires_in", oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS, - )) + expires = timezone.now() + timedelta( + seconds=token.get( + "expires_in", + oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS, + ) + ) if request.grant_type == "client_credentials": request.user = None @@ -497,9 +504,11 @@ def save_bearer_token(self, token, request, *args, **kwargs): refresh_token_instance = getattr(request, "refresh_token_instance", None) # If we are to reuse tokens, and we can: do so - if not self.rotate_refresh_token(request) and \ - isinstance(refresh_token_instance, RefreshToken) and \ - refresh_token_instance.access_token: + if ( + not self.rotate_refresh_token(request) + and isinstance(refresh_token_instance, RefreshToken) + and refresh_token_instance.access_token + ): access_token = AccessToken.objects.select_for_update().get( pk=refresh_token_instance.access_token.pk @@ -551,9 +560,9 @@ def save_bearer_token(self, token, request, *args, **kwargs): # make sure that the token data we're returning matches # the existing token token["access_token"] = previous_access_token.token - token["refresh_token"] = RefreshToken.objects.filter( - access_token=previous_access_token - ).first().token + token["refresh_token"] = ( + RefreshToken.objects.filter(access_token=previous_access_token).first().token + ) token["scope"] = previous_access_token.scope # No refresh token should be created, just access token @@ -582,15 +591,12 @@ def _create_authorization_code(self, request, code, expires=None): redirect_uri=request.redirect_uri, scope=" ".join(request.scopes), code_challenge=request.code_challenge or "", - code_challenge_method=request.code_challenge_method or "" + code_challenge_method=request.code_challenge_method or "", ) def _create_refresh_token(self, request, refresh_token_code, access_token): return RefreshToken.objects.create( - user=request.user, - token=refresh_token_code, - application=request.client, - access_token=access_token + user=request.user, token=refresh_token_code, application=request.client, access_token=access_token ) def revoke_token(self, token, token_type_hint, request, *args, **kwargs): @@ -643,13 +649,13 @@ def validate_refresh_token(self, refresh_token, client, request, *args, **kwargs """ null_or_recent = Q(revoked__isnull=True) | Q( - revoked__gt=timezone.now() - timedelta( - seconds=oauth2_settings.REFRESH_TOKEN_GRACE_PERIOD_SECONDS - ) + revoked__gt=timezone.now() - timedelta(seconds=oauth2_settings.REFRESH_TOKEN_GRACE_PERIOD_SECONDS) + ) + rt = ( + RefreshToken.objects.filter(null_or_recent, token=refresh_token) + .select_related("access_token") + .first() ) - rt = RefreshToken.objects.filter(null_or_recent, token=refresh_token).select_related( - "access_token" - ).first() if not rt: return False diff --git a/oauth2_provider/settings.py b/oauth2_provider/settings.py index 0135da8b7..42c08b676 100644 --- a/oauth2_provider/settings.py +++ b/oauth2_provider/settings.py @@ -55,19 +55,16 @@ "REFRESH_TOKEN_MODEL": REFRESH_TOKEN_MODEL, "REQUEST_APPROVAL_PROMPT": "force", "ALLOWED_REDIRECT_URI_SCHEMES": ["http", "https"], - # Special settings that will be evaluated at runtime "_SCOPES": [], "_DEFAULT_SCOPES": [], - # Resource Server with Token Introspection "RESOURCE_SERVER_INTROSPECTION_URL": None, "RESOURCE_SERVER_AUTH_TOKEN": None, "RESOURCE_SERVER_INTROSPECTION_CREDENTIALS": None, "RESOURCE_SERVER_TOKEN_CACHING_SECONDS": 36000, - # Whether or not PKCE is required - "PKCE_REQUIRED": False + "PKCE_REQUIRED": False, } # List of settings that cannot be empty diff --git a/oauth2_provider/urls.py b/oauth2_provider/urls.py index 4cf6d4c6d..c7ae526f0 100644 --- a/oauth2_provider/urls.py +++ b/oauth2_provider/urls.py @@ -23,8 +23,11 @@ re_path(r"^applications/(?P[\w-]+)/update/$", views.ApplicationUpdate.as_view(), name="update"), # Token management views re_path(r"^authorized_tokens/$", views.AuthorizedTokensListView.as_view(), name="authorized-token-list"), - re_path(r"^authorized_tokens/(?P[\w-]+)/delete/$", views.AuthorizedTokenDeleteView.as_view(), - name="authorized-token-delete"), + re_path( + r"^authorized_tokens/(?P[\w-]+)/delete/$", + views.AuthorizedTokenDeleteView.as_view(), + name="authorized-token-delete", + ), ] diff --git a/oauth2_provider/validators.py b/oauth2_provider/validators.py index f3f82102c..6c8fa3839 100644 --- a/oauth2_provider/validators.py +++ b/oauth2_provider/validators.py @@ -10,12 +10,9 @@ class URIValidator(URLValidator): scheme_re = r"^(?:[a-z][a-z0-9\.\-\+]*)://" dotless_domain_re = r"(?!-)[A-Z\d-]{1,63}(?=3.1.0 - m2r>=0.2.1 +deps = + sphinx<3 + oauthlib>=3.1.0 + m2r>=0.2.1 [testenv:py37-flake8] skip_install = True @@ -53,18 +56,19 @@ deps = flake8 flake8-isort flake8-quotes + flake8-black [testenv:install] deps = - twine - setuptools>=39.0 - wheel + twine + setuptools>=39.0 + wheel whitelist_externals= - rm + rm commands = - rm -rf dist - python setup.py sdist bdist_wheel - twine upload dist/* + rm -rf dist + python setup.py sdist bdist_wheel + twine upload dist/* [coverage:run] @@ -76,12 +80,16 @@ max-line-length = 110 exclude = docs/, oauth2_provider/migrations/, tests/migrations/, .tox/ application-import-names = oauth2_provider inline-quotes = double +extend-ignore = E203, W503 [isort] -balanced_wrapping = True default_section = THIRDPARTY known_first_party = oauth2_provider -line_length = 80 +line_length = 110 lines_after_imports = 2 -multi_line_output = 5 -skip = oauth2_provider/migrations/, .tox/ +multi_line_output = 3 +include_trailing_comma = True +force_grid_wrap = 0 +use_parentheses = True +ensure_newline_before_comments = True +skip = oauth2_provider/migrations/, .tox/, tests/migrations/ From 86e78b93a86f37c6731c23bfd21f340c91954136 Mon Sep 17 00:00:00 2001 From: Vaskevich Aleksander Date: Wed, 16 Dec 2020 12:01:07 +0300 Subject: [PATCH 13/53] #898 Added the ability to customize classes for django admin (#904) --- AUTHORS | 1 + CHANGELOG.md | 4 +- docs/settings.rst | 24 ++++++++++ oauth2_provider/admin.py | 42 +++++++++++------ oauth2_provider/models.py | 24 ++++++++++ oauth2_provider/settings.py | 63 +++++++++++++++++++------- tests/admin.py | 17 +++++++ tests/test_settings.py | 90 +++++++++++++++++++++++++++++++++++++ 8 files changed, 234 insertions(+), 31 deletions(-) create mode 100644 tests/admin.py create mode 100644 tests/test_settings.py diff --git a/AUTHORS b/AUTHORS index 4f9cd850b..7e03b37ed 100644 --- a/AUTHORS +++ b/AUTHORS @@ -9,6 +9,7 @@ Contributors Abhishek Patel Alessandro De Angelis +Aleksander Vaskevich Alan Crosswell Anvesh Agarwal Asif Saif Uddin diff --git a/CHANGELOG.md b/CHANGELOG.md index b10ebfe30..1cb02280a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,7 +4,7 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). - -## [unreleased] +## [1.5.0] 2021-03-18 ### Added * #915 Add optional OpenID Connect support. -## [1.4.1] +### Changed +* #942 Help via defunct Google group replaced with using GitHub issues + +## [1.4.1] 2021-03-12 ### Changed * #925 OAuth2TokenMiddleware converted to new style middleware, and no longer extends MiddlewareMixin. diff --git a/setup.cfg b/setup.cfg index 03d614a7f..13d6cd0f9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = django-oauth-toolkit -version = 1.4.1 +version = 1.5.0 description = OAuth2 Provider for Django long_description = file: README.rst long_description_content_type = text/x-rst From 4bd9edcc5579b568d3465cce4e6d33ba304c8107 Mon Sep 17 00:00:00 2001 From: Ryan Luker Date: Thu, 30 Mar 2023 10:28:42 -0700 Subject: [PATCH 37/53] Fixup Readme for the short term local dev --- README.rst | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/README.rst b/README.rst index dc6d79ad0..1fc8a275b 100644 --- a/README.rst +++ b/README.rst @@ -88,16 +88,27 @@ Developer Setup --------------- To get started with running or building the module you will need to install a virtual environment:: - - python36 -m venv ~/venv/python36-django-oauth-toolkit - source ~/venv/python36-django-oauth-toolkit +``` + python39 -m venv ~/venv/python39-django-oauth-toolkit + source ~/venv/python39-django-oauth-toolkit +``` +or +``` + pyenv install 3.9.16 + pyenv virtualenv 3.9.16-toolkit + pyenv activate 3.9.16-toolkit +``` +then +``` python setup.py install python setup.py build +``` To run the tox tests for our specifically supported build use the following:: - +``` pip install tox - tox -e py36-django22 + tox -e py39-django32 +``` License ------- From 28506db8e4cdeb11a5669baa1c92f2dc9023cc6b Mon Sep 17 00:00:00 2001 From: Ryan Luker Date: Thu, 30 Mar 2023 10:31:50 -0700 Subject: [PATCH 38/53] Only test python 3.9 --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 7d257b465..cf4b545e2 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -9,7 +9,7 @@ jobs: fail-fast: false max-parallel: 5 matrix: - python-version: ['3.6', '3.7', '3.8', '3.9'] + python-version: ['3.9'] steps: - uses: actions/checkout@v2 From 77bf66d3dd183766ef9c24e459c87052dfb1531a Mon Sep 17 00:00:00 2001 From: Mackenzie Salloum Date: Thu, 30 Mar 2023 11:12:30 -0700 Subject: [PATCH 39/53] Use self with settings to match rest of tests --- tests/test_authorization_code.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_authorization_code.py b/tests/test_authorization_code.py index 9469b4248..e5e333dff 100644 --- a/tests/test_authorization_code.py +++ b/tests/test_authorization_code.py @@ -1094,8 +1094,8 @@ def test_id_token_public_oidc_capable(self): Check that the id token includes our custom iss """ iss_entity = "http://testserver/o" - oauth2_settings.OIDC_ISS_ENDPOINT = None - oauth2_settings.OIDC_USERINFO_ENDPOINT = None + self.oauth2_settings.OIDC_ISS_ENDPOINT = None + self.oauth2_settings.OIDC_USERINFO_ENDPOINT = None self.client.login(username="test_user", password="123456") @@ -1118,15 +1118,15 @@ def test_id_token_public_oidc_capable(self): # Unload and decode the jwt and check the iss entity matches content = json.loads(response.content.decode("utf-8")) - key = jwk.JWK.from_pem(oauth2_settings.OIDC_RSA_PRIVATE_KEY.encode("utf8")) + key = jwk.JWK.from_pem(self.oauth2_settings.OIDC_RSA_PRIVATE_KEY.encode("utf8")) jwt_token = jwt.JWT(key=key, jwt=content["id_token"]) # Find our testserver iss entity self.assertIn(iss_entity, jwt_token.claims) # Turn back on the OIDC specific endpoints - oauth2_settings.OIDC_ISS_ENDPOINT = "http://localhost" - oauth2_settings.OIDC_USERINFO_ENDPOINT = "http://localhost/userinfo/" + self.oauth2_settings.OIDC_ISS_ENDPOINT = "http://localhost" + self.oauth2_settings.OIDC_USERINFO_ENDPOINT = "http://localhost/userinfo/" def test_public_pkce_S256_authorize_get(self): """ From 347724a9af36adb656738e7191901658f15078c3 Mon Sep 17 00:00:00 2001 From: Mackenzie Salloum Date: Thu, 30 Mar 2023 11:13:00 -0700 Subject: [PATCH 40/53] Remove test since generate_at_hash no longer exists --- tests/test_oauth2_validators.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/tests/test_oauth2_validators.py b/tests/test_oauth2_validators.py index 52a784c01..b2631411a 100644 --- a/tests/test_oauth2_validators.py +++ b/tests/test_oauth2_validators.py @@ -287,13 +287,6 @@ def test_save_bearer_token__with_new_token__calls_methods_to_create_access_and_r assert create_access_token_mock.call_count == 1 assert create_refresh_token_mock.call_count == 1 - def test_generate_at_hash(self): - # Values taken from spec, https://openid.net/specs/openid-connect-core-1_0.html#id_token-tokenExample - access_token = "jHkWEdUXMU1BwAsC4vtUsZwnNvTIxEl0z9K3vx5KF0Y" - at_hash = self.validator.generate_at_hash(access_token) - - assert at_hash == "77QmUPtjPfzWtF2AnpK9RQ" - @override_settings(TIME_ZONE="US/Eastern") def test_iat_timezone_awareness(self): # get_id_token_dictionary requires these fields to be set From ee9054a8c78d74f1e3420d12dfd52c045ffd73e9 Mon Sep 17 00:00:00 2001 From: Mackenzie Salloum Date: Thu, 30 Mar 2023 11:57:26 -0700 Subject: [PATCH 41/53] Move out test so we can use the request factory --- tests/test_oauth2_validators.py | 32 ++++++++++++++++++++------------ 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/tests/test_oauth2_validators.py b/tests/test_oauth2_validators.py index b2631411a..57ed7c9e2 100644 --- a/tests/test_oauth2_validators.py +++ b/tests/test_oauth2_validators.py @@ -287,19 +287,7 @@ def test_save_bearer_token__with_new_token__calls_methods_to_create_access_and_r assert create_access_token_mock.call_count == 1 assert create_refresh_token_mock.call_count == 1 - @override_settings(TIME_ZONE="US/Eastern") - def test_iat_timezone_awareness(self): - # get_id_token_dictionary requires these fields to be set - self.request.client_id = self.application.client_id - self.request.user.last_login = timezone.now() - self.request.response_type = None - claims, __ = self.validator.get_id_token_dictionary(None, None, self.request) - # Remove several sig figs to improve test resilience - expected_time = int(calendar.timegm(timezone.now().timetuple()) / 1000) * 1000 - actual_time = int(claims["iat"] / 1000) * 1000 - - assert actual_time == expected_time class TestOAuth2ValidatorProvidesErrorData(TransactionTestCase): @@ -535,3 +523,23 @@ def test_validate_id_token_bad_token_no_aud(oauth2_settings, mocker, oidc_key): validator = OAuth2Validator() status = validator.validate_id_token(token.serialize(), ["openid"], mocker.sentinel.request) assert status is False + +@override_settings(TIME_ZONE="US/Eastern") +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +def test_iat_timezone_awareness(oauth2_settings, rf): + # oauth2_settings.OIDC_ISS_ENDPOINT = "" + django_request = rf.get("/") + request = Request("/", headers=django_request.META) + user = mock.MagicMock() + user.last_login = timezone.now() + request.user = user + request.grant_type = "not client" + validator = OAuth2Validator() + request.client_id = "client_id" + # get_id_token_dictionary requires these fields to be set + claims, __ = validator.get_id_token_dictionary(None, None, request) + # Remove several sig figs to improve test resilience + expected_time = int(calendar.timegm(timezone.now().timetuple()) / 1000) * 1000 + actual_time = int(claims["iat"] / 1000) * 1000 + + assert actual_time == expected_time From 87bb44b71d2551f85f8f872c38afc8b2328229ef Mon Sep 17 00:00:00 2001 From: Ryan Luker Date: Thu, 30 Mar 2023 12:47:02 -0700 Subject: [PATCH 42/53] Fix test --- .gitignore | 1 + tests/test_authorization_code.py | 78 ++++++++++++++++---------------- 2 files changed, 40 insertions(+), 39 deletions(-) diff --git a/.gitignore b/.gitignore index 3643335d4..e1090f48d 100644 --- a/.gitignore +++ b/.gitignore @@ -28,6 +28,7 @@ pip-log.txt .cache .pytest_cache .coverage +coverage.xml .tox .pytest_cache/ nosetests.xml diff --git a/tests/test_authorization_code.py b/tests/test_authorization_code.py index e5e333dff..9619c580a 100644 --- a/tests/test_authorization_code.py +++ b/tests/test_authorization_code.py @@ -1089,45 +1089,6 @@ def test_public(self): self.assertEqual(content["scope"], "read write") self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) - def test_id_token_public_oidc_capable(self): - """ - Check that the id token includes our custom iss - """ - iss_entity = "http://testserver/o" - self.oauth2_settings.OIDC_ISS_ENDPOINT = None - self.oauth2_settings.OIDC_USERINFO_ENDPOINT = None - - self.client.login(username="test_user", password="123456") - - self.application.client_type = Application.CLIENT_PUBLIC - self.application.save() - authorization_code = self.get_auth(scope="openid") - - token_request_data = { - "grant_type": "authorization_code", - "code": authorization_code, - "redirect_uri": "http://example.org", - "client_id": self.application.client_id, - "scope": "openid", - } - - response = self.client.post( - reverse("oauth2_provider:token"), data=token_request_data - ) - self.assertEqual(response.status_code, 200) - - # Unload and decode the jwt and check the iss entity matches - content = json.loads(response.content.decode("utf-8")) - key = jwk.JWK.from_pem(self.oauth2_settings.OIDC_RSA_PRIVATE_KEY.encode("utf8")) - jwt_token = jwt.JWT(key=key, jwt=content["id_token"]) - - # Find our testserver iss entity - self.assertIn(iss_entity, jwt_token.claims) - - # Turn back on the OIDC specific endpoints - self.oauth2_settings.OIDC_ISS_ENDPOINT = "http://localhost" - self.oauth2_settings.OIDC_USERINFO_ENDPOINT = "http://localhost/userinfo/" - def test_public_pkce_S256_authorize_get(self): """ Request an access token using client_type: public @@ -1646,6 +1607,45 @@ def setUp(self): self.application.algorithm = Application.RS256_ALGORITHM self.application.save() + def test_id_token_public_oidc_capable(self): + """ + Check that the id token includes our custom iss + """ + iss_entity = "http://testserver/o" + self.oauth2_settings.OIDC_ISS_ENDPOINT = None + self.oauth2_settings.OIDC_USERINFO_ENDPOINT = None + + self.client.login(username="test_user", password="123456") + + self.application.client_type = Application.CLIENT_PUBLIC + self.application.save() + authorization_code = self.get_auth(scope="openid") + + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org", + "client_id": self.application.client_id, + "scope": "openid", + } + + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data + ) + self.assertEqual(response.status_code, 200) + + # Unload and decode the jwt and check the iss entity matches + content = json.loads(response.content.decode("utf-8")) + key = jwk.JWK.from_pem(self.oauth2_settings.OIDC_RSA_PRIVATE_KEY.encode("utf8")) + jwt_token = jwt.JWT(key=key, jwt=content["id_token"]) + + # Find our testserver iss entity + self.assertIn(iss_entity, jwt_token.claims) + + # Turn back on the OIDC specific endpoints + self.oauth2_settings.OIDC_ISS_ENDPOINT = "http://localhost" + self.oauth2_settings.OIDC_USERINFO_ENDPOINT = "http://localhost/userinfo/" + def test_id_token_public(self): """ Request an access token using client_type: public From 3622206e49b42f91aeb8de91d44f4c4f46a13afa Mon Sep 17 00:00:00 2001 From: Ryan Luker Date: Thu, 30 Mar 2023 14:58:48 -0700 Subject: [PATCH 43/53] Add back old migration auto3 --- .../migrations/0003_auto_20200902_2022.py | 48 +++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 oauth2_provider/migrations/0003_auto_20200902_2022.py diff --git a/oauth2_provider/migrations/0003_auto_20200902_2022.py b/oauth2_provider/migrations/0003_auto_20200902_2022.py new file mode 100644 index 000000000..684949c9d --- /dev/null +++ b/oauth2_provider/migrations/0003_auto_20200902_2022.py @@ -0,0 +1,48 @@ +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + +from oauth2_provider.settings import oauth2_settings + + +class Migration(migrations.Migration): + + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ('oauth2_provider', '0002_auto_20190406_1805'), + ] + + operations = [ + migrations.AddField( + model_name='application', + name='algorithm', + field=models.CharField(choices=[('RS256', 'RSA with SHA-2 256'), ('HS256', 'HMAC with SHA-2 256')], default='RS256', max_length=5), + ), + migrations.AlterField( + model_name='application', + name='authorization_grant_type', + field=models.CharField(choices=[('authorization-code', 'Authorization code'), ('implicit', 'Implicit'), ('password', 'Resource owner password-based'), ('client-credentials', 'Client credentials'), ('openid-hybrid', 'OpenID connect hybrid')], max_length=32), + ), + migrations.CreateModel( + name='IDToken', + fields=[ + ('id', models.BigAutoField(primary_key=True, serialize=False)), + ('token', models.TextField(unique=True)), + ('expires', models.DateTimeField()), + ('scope', models.TextField(blank=True)), + ('created', models.DateTimeField(auto_now_add=True)), + ('updated', models.DateTimeField(auto_now=True)), + ('application', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to=oauth2_settings.APPLICATION_MODEL)), + ('user', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='oauth2_provider_idtoken', to=settings.AUTH_USER_MODEL)), + ], + options={ + 'abstract': False, + 'swappable': 'OAUTH2_PROVIDER_ID_TOKEN_MODEL', + }, + ), + migrations.AddField( + model_name='accesstoken', + name='id_token', + field=models.OneToOneField(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='access_token', to=oauth2_settings.ID_TOKEN_MODEL), + ), + ] From e4d5e1b76af0d3a5fc6ad4bb71a264329ab57e2f Mon Sep 17 00:00:00 2001 From: Ryan Luker Date: Thu, 30 Mar 2023 15:30:08 -0700 Subject: [PATCH 44/53] Add merge migration --- ...3_auto_20200902_2022_0004_auto_20200902_2022.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) create mode 100644 oauth2_provider/migrations/0005_merge_0003_auto_20200902_2022_0004_auto_20200902_2022.py diff --git a/oauth2_provider/migrations/0005_merge_0003_auto_20200902_2022_0004_auto_20200902_2022.py b/oauth2_provider/migrations/0005_merge_0003_auto_20200902_2022_0004_auto_20200902_2022.py new file mode 100644 index 000000000..aa9ddce40 --- /dev/null +++ b/oauth2_provider/migrations/0005_merge_0003_auto_20200902_2022_0004_auto_20200902_2022.py @@ -0,0 +1,14 @@ +# Generated by Django 3.2.15 on 2023-03-30 22:18 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ('oauth2_provider', '0003_auto_20200902_2022'), + ('oauth2_provider', '0004_auto_20200902_2022'), + ] + + operations = [ + ] From 32f00516bfd316c9ee7a52879486483fe9fbcb9a Mon Sep 17 00:00:00 2001 From: Ryan Luker Date: Thu, 30 Mar 2023 15:36:06 -0700 Subject: [PATCH 45/53] Add back token on id token --- oauth2_provider/models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/oauth2_provider/models.py b/oauth2_provider/models.py index a21cb868b..ee1463569 100644 --- a/oauth2_provider/models.py +++ b/oauth2_provider/models.py @@ -506,6 +506,7 @@ class AbstractIDToken(models.Model): null=True, related_name="%(app_label)s_%(class)s", ) + token = models.TextField(unique=True) jti = models.UUIDField(unique=True, default=uuid.uuid4, editable=False, verbose_name="JWT Token ID") application = models.ForeignKey( oauth2_settings.APPLICATION_MODEL, From dc0dd81c45ef748d0aa0ba95ccb2b6d99aa71303 Mon Sep 17 00:00:00 2001 From: Ryan Luker Date: Thu, 30 Mar 2023 15:39:05 -0700 Subject: [PATCH 46/53] Add migration for new values --- .../migrations/0006_auto_20230330_1837.py | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 oauth2_provider/migrations/0006_auto_20230330_1837.py diff --git a/oauth2_provider/migrations/0006_auto_20230330_1837.py b/oauth2_provider/migrations/0006_auto_20230330_1837.py new file mode 100644 index 000000000..ce6aa367e --- /dev/null +++ b/oauth2_provider/migrations/0006_auto_20230330_1837.py @@ -0,0 +1,24 @@ +# Generated by Django 3.2.15 on 2023-03-30 22:37 + +from django.db import migrations, models +import uuid + + +class Migration(migrations.Migration): + + dependencies = [ + ('oauth2_provider', '0005_merge_0003_auto_20200902_2022_0004_auto_20200902_2022'), + ] + + operations = [ + migrations.AddField( + model_name='idtoken', + name='jti', + field=models.UUIDField(default=uuid.uuid4, editable=False, unique=True), + ), + migrations.AlterField( + model_name='application', + name='algorithm', + field=models.CharField(blank=True, choices=[('', 'No OIDC support'), ('RS256', 'RSA with SHA-2 256'), ('HS256', 'HMAC with SHA-2 256')], default='', max_length=5), + ), + ] From 40ace7198618b3cee6bdc54af1e0f5bb71011023 Mon Sep 17 00:00:00 2001 From: Ryan Luker Date: Thu, 30 Mar 2023 15:46:03 -0700 Subject: [PATCH 47/53] Rewrite migration 0004 to remove already applied --- .../migrations/0004_auto_20200902_2022.py | 66 ++++++++++--------- 1 file changed, 35 insertions(+), 31 deletions(-) diff --git a/oauth2_provider/migrations/0004_auto_20200902_2022.py b/oauth2_provider/migrations/0004_auto_20200902_2022.py index 81dd20d04..00f68fe8b 100644 --- a/oauth2_provider/migrations/0004_auto_20200902_2022.py +++ b/oauth2_provider/migrations/0004_auto_20200902_2022.py @@ -15,37 +15,41 @@ class Migration(migrations.Migration): ] operations = [ - migrations.AddField( - model_name='application', - name='algorithm', - field=models.CharField(blank=True, choices=[("", "No OIDC support"), ('RS256', 'RSA with SHA-2 256'), ('HS256', 'HMAC with SHA-2 256')], default='', max_length=5), - ), - migrations.AlterField( - model_name='application', - name='authorization_grant_type', - field=models.CharField(choices=[('authorization-code', 'Authorization code'), ('implicit', 'Implicit'), ('password', 'Resource owner password-based'), ('client-credentials', 'Client credentials'), ('openid-hybrid', 'OpenID connect hybrid')], max_length=32), - ), - migrations.CreateModel( - name='IDToken', - fields=[ - ('id', models.BigAutoField(primary_key=True, serialize=False)), - ("jti", models.UUIDField(unique=True, default=uuid.uuid4, editable=False, verbose_name="JWT Token ID")), - ('expires', models.DateTimeField()), - ('scope', models.TextField(blank=True)), - ('created', models.DateTimeField(auto_now_add=True)), - ('updated', models.DateTimeField(auto_now=True)), - ('application', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to=oauth2_settings.APPLICATION_MODEL)), - ('user', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='oauth2_provider_idtoken', to=settings.AUTH_USER_MODEL)), - ], - options={ - 'abstract': False, - 'swappable': 'OAUTH2_PROVIDER_ID_TOKEN_MODEL', - }, - ), - migrations.AddField( - model_name='accesstoken', - name='id_token', - field=models.OneToOneField(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='access_token', to=oauth2_settings.ID_TOKEN_MODEL), + migrations.SeparateDatabaseAndState( + state_operations=[ + migrations.AddField( + model_name='application', + name='algorithm', + field=models.CharField(blank=True, choices=[("", "No OIDC support"), ('RS256', 'RSA with SHA-2 256'), ('HS256', 'HMAC with SHA-2 256')], default='', max_length=5), + ), + migrations.AlterField( + model_name='application', + name='authorization_grant_type', + field=models.CharField(choices=[('authorization-code', 'Authorization code'), ('implicit', 'Implicit'), ('password', 'Resource owner password-based'), ('client-credentials', 'Client credentials'), ('openid-hybrid', 'OpenID connect hybrid')], max_length=32), + ), + migrations.CreateModel( + name='IDToken', + fields=[ + ('id', models.BigAutoField(primary_key=True, serialize=False)), + ("jti", models.UUIDField(unique=True, default=uuid.uuid4, editable=False, verbose_name="JWT Token ID")), + ('expires', models.DateTimeField()), + ('scope', models.TextField(blank=True)), + ('created', models.DateTimeField(auto_now_add=True)), + ('updated', models.DateTimeField(auto_now=True)), + ('application', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to=oauth2_settings.APPLICATION_MODEL)), + ('user', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='oauth2_provider_idtoken', to=settings.AUTH_USER_MODEL)), + ], + options={ + 'abstract': False, + 'swappable': 'OAUTH2_PROVIDER_ID_TOKEN_MODEL', + }, + ), + migrations.AddField( + model_name='accesstoken', + name='id_token', + field=models.OneToOneField(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='access_token', to=oauth2_settings.ID_TOKEN_MODEL), + ), + ] ), migrations.AddField( model_name="grant", From af476e94972bb4bffbf9be7116741d0951d7d431 Mon Sep 17 00:00:00 2001 From: Mackenzie Salloum Date: Thu, 30 Mar 2023 16:50:45 -0700 Subject: [PATCH 48/53] Can remove custom iat because oauthlib uses time.time() --- oauth2_provider/oauth2_validators.py | 2 -- tests/test_oauth2_validators.py | 20 -------------------- 2 files changed, 22 deletions(-) diff --git a/oauth2_provider/oauth2_validators.py b/oauth2_provider/oauth2_validators.py index 639993452..9739c52ca 100644 --- a/oauth2_provider/oauth2_validators.py +++ b/oauth2_provider/oauth2_validators.py @@ -744,13 +744,11 @@ def get_id_token_dictionary(self, token, token_handler, request): expiration_time = timezone.now() + timedelta(seconds=oauth2_settings.ID_TOKEN_EXPIRE_SECONDS) # Required ID Token claims - # TODO sc-77176: Should we keep our aud + iat even though they are added by authlib now? claims.update( **{ "iss": self.get_oidc_issuer_endpoint(request), "aud": request.client_id, "exp": int(dateformat.format(expiration_time, "U")), - "iat": int(dateformat.format(timezone.now(), "U")), "auth_time": int(dateformat.format(request.user.last_login, "U")), "jti": str(uuid.uuid4()), } diff --git a/tests/test_oauth2_validators.py b/tests/test_oauth2_validators.py index 57ed7c9e2..539ca9178 100644 --- a/tests/test_oauth2_validators.py +++ b/tests/test_oauth2_validators.py @@ -523,23 +523,3 @@ def test_validate_id_token_bad_token_no_aud(oauth2_settings, mocker, oidc_key): validator = OAuth2Validator() status = validator.validate_id_token(token.serialize(), ["openid"], mocker.sentinel.request) assert status is False - -@override_settings(TIME_ZONE="US/Eastern") -@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) -def test_iat_timezone_awareness(oauth2_settings, rf): - # oauth2_settings.OIDC_ISS_ENDPOINT = "" - django_request = rf.get("/") - request = Request("/", headers=django_request.META) - user = mock.MagicMock() - user.last_login = timezone.now() - request.user = user - request.grant_type = "not client" - validator = OAuth2Validator() - request.client_id = "client_id" - # get_id_token_dictionary requires these fields to be set - claims, __ = validator.get_id_token_dictionary(None, None, request) - # Remove several sig figs to improve test resilience - expected_time = int(calendar.timegm(timezone.now().timetuple()) / 1000) * 1000 - actual_time = int(claims["iat"] / 1000) * 1000 - - assert actual_time == expected_time From 3ffb051bc2d558c575f3e9cb8c9f76dcfa099db2 Mon Sep 17 00:00:00 2001 From: Mackenzie Salloum Date: Thu, 30 Mar 2023 17:11:11 -0700 Subject: [PATCH 49/53] Remove unneeded aud claim, its provided by oauthlib --- oauth2_provider/oauth2_validators.py | 1 - 1 file changed, 1 deletion(-) diff --git a/oauth2_provider/oauth2_validators.py b/oauth2_provider/oauth2_validators.py index 9739c52ca..f91c06011 100644 --- a/oauth2_provider/oauth2_validators.py +++ b/oauth2_provider/oauth2_validators.py @@ -747,7 +747,6 @@ def get_id_token_dictionary(self, token, token_handler, request): claims.update( **{ "iss": self.get_oidc_issuer_endpoint(request), - "aud": request.client_id, "exp": int(dateformat.format(expiration_time, "U")), "auth_time": int(dateformat.format(request.user.last_login, "U")), "jti": str(uuid.uuid4()), From f87d071dbe8126179e04926c0317a7f609e0eda6 Mon Sep 17 00:00:00 2001 From: Mackenzie Salloum Date: Thu, 30 Mar 2023 17:11:29 -0700 Subject: [PATCH 50/53] Double write token and jti fields --- oauth2_provider/oauth2_validators.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/oauth2_provider/oauth2_validators.py b/oauth2_provider/oauth2_validators.py index f91c06011..9fab8c07d 100644 --- a/oauth2_provider/oauth2_validators.py +++ b/oauth2_provider/oauth2_validators.py @@ -704,14 +704,15 @@ def validate_refresh_token(self, refresh_token, client, request, *args, **kwargs return rt.application == client @transaction.atomic - def _save_id_token(self, jti, request, expires, *args, **kwargs): + def _save_id_token(self, id_token, request, expires, *args, **kwargs): scopes = request.scope or " ".join(request.scopes) id_token = IDToken.objects.create( + token=id_token, user=request.user, scope=scopes, expires=expires, - jti=jti, + jti=id_token["jti"], application=request.client, ) return id_token @@ -779,7 +780,7 @@ def finalize_id_token(self, id_token, token, token_handler, request): claims=json.dumps(id_token, default=str), ) jwt_token.make_signed_token(request.client.jwk_key) - id_token = self._save_id_token(id_token["jti"], request, expiration_time) + id_token = self._save_id_token(id_token, request, expiration_time) # this is needed by django rest framework request.access_token = id_token request.id_token = id_token From 2e734e58201c75850aeb141869ad6f175b957de3 Mon Sep 17 00:00:00 2001 From: Mackenzie Salloum Date: Fri, 31 Mar 2023 11:43:58 -0700 Subject: [PATCH 51/53] Lookup using token for now --- oauth2_provider/oauth2_validators.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/oauth2_provider/oauth2_validators.py b/oauth2_provider/oauth2_validators.py index 9fab8c07d..8aac3d0af 100644 --- a/oauth2_provider/oauth2_validators.py +++ b/oauth2_provider/oauth2_validators.py @@ -816,8 +816,10 @@ def _load_id_token(self, token): return None try: jwt_token = jwt.JWT(key=key, jwt=token) - claims = json.loads(jwt_token.claims) - return IDToken.objects.get(jti=claims["jti"]) + # TODO: Once double write is live, we can read from jti + # claims = json.loads(jwt_token.claims) + # return IDToken.objects.get(jti=claims["jti"]) + return IDToken.objects.get(token=jwt_token.serialize()) except (JWException, JWTExpired, IDToken.DoesNotExist): return None From 38974367daaf59aafdd8c7c599453c42d2b42907 Mon Sep 17 00:00:00 2001 From: Mackenzie Salloum Date: Fri, 31 Mar 2023 11:56:38 -0700 Subject: [PATCH 52/53] Use correct token --- oauth2_provider/oauth2_validators.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/oauth2_provider/oauth2_validators.py b/oauth2_provider/oauth2_validators.py index 8aac3d0af..58d56760c 100644 --- a/oauth2_provider/oauth2_validators.py +++ b/oauth2_provider/oauth2_validators.py @@ -704,11 +704,12 @@ def validate_refresh_token(self, refresh_token, client, request, *args, **kwargs return rt.application == client @transaction.atomic - def _save_id_token(self, id_token, request, expires, *args, **kwargs): + def _save_id_token(self, id_token, jwt_token, request, expires, *args, **kwargs): scopes = request.scope or " ".join(request.scopes) id_token = IDToken.objects.create( - token=id_token, + # TODO 2: Once reading from jti is live, we can write to token + token=jwt_token.serialize(), user=request.user, scope=scopes, expires=expires, @@ -780,7 +781,7 @@ def finalize_id_token(self, id_token, token, token_handler, request): claims=json.dumps(id_token, default=str), ) jwt_token.make_signed_token(request.client.jwk_key) - id_token = self._save_id_token(id_token, request, expiration_time) + id_token = self._save_id_token(id_token, jwt_token, request, expiration_time) # this is needed by django rest framework request.access_token = id_token request.id_token = id_token @@ -816,7 +817,7 @@ def _load_id_token(self, token): return None try: jwt_token = jwt.JWT(key=key, jwt=token) - # TODO: Once double write is live, we can read from jti + # TODO 1: Once double write is live, we can read from jti # claims = json.loads(jwt_token.claims) # return IDToken.objects.get(jti=claims["jti"]) return IDToken.objects.get(token=jwt_token.serialize()) From 2cf5391cd4971261689ae3ad76068e7e2f531cf7 Mon Sep 17 00:00:00 2001 From: Mackenzie Salloum Date: Fri, 31 Mar 2023 12:17:58 -0700 Subject: [PATCH 53/53] Update todos to refer to tickets --- oauth2_provider/oauth2_validators.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/oauth2_provider/oauth2_validators.py b/oauth2_provider/oauth2_validators.py index 58d56760c..f42cdf57d 100644 --- a/oauth2_provider/oauth2_validators.py +++ b/oauth2_provider/oauth2_validators.py @@ -708,7 +708,7 @@ def _save_id_token(self, id_token, jwt_token, request, expires, *args, **kwargs) scopes = request.scope or " ".join(request.scopes) id_token = IDToken.objects.create( - # TODO 2: Once reading from jti is live, we can write to token + # TODO sc-77179: Once reading from jti is live, we can stop writing to token. token=jwt_token.serialize(), user=request.user, scope=scopes, @@ -817,7 +817,7 @@ def _load_id_token(self, token): return None try: jwt_token = jwt.JWT(key=key, jwt=token) - # TODO 1: Once double write is live, we can read from jti + # TODO sc-77179: Once double write is live, we can read from jti # claims = json.loads(jwt_token.claims) # return IDToken.objects.get(jti=claims["jti"]) return IDToken.objects.get(token=jwt_token.serialize())