diff --git a/.gitignore b/.gitignore index af644d1e3..c22ef00fa 100644 --- a/.gitignore +++ b/.gitignore @@ -25,7 +25,7 @@ __pycache__ pip-log.txt # Unit test / coverage reports -.cache +.pytest_cache .coverage .tox .pytest_cache/ diff --git a/oauth2_provider/admin.py b/oauth2_provider/admin.py index 8b963d981..a8d69e623 100644 --- a/oauth2_provider/admin.py +++ b/oauth2_provider/admin.py @@ -2,7 +2,7 @@ from .models import ( get_access_token_model, get_application_model, - get_grant_model, get_refresh_token_model + get_grant_model, get_id_token_model, get_refresh_token_model ) @@ -26,6 +26,11 @@ class AccessTokenAdmin(admin.ModelAdmin): raw_id_fields = ("user", "source_refresh_token") +class IDTokenAdmin(admin.ModelAdmin): + list_display = ("token", "user", "application", "expires") + raw_id_fields = ("user", ) + + class RefreshTokenAdmin(admin.ModelAdmin): list_display = ("token", "user", "application") raw_id_fields = ("user", "access_token") @@ -34,9 +39,11 @@ class RefreshTokenAdmin(admin.ModelAdmin): Application = get_application_model() Grant = get_grant_model() AccessToken = get_access_token_model() +IDToken = get_id_token_model() RefreshToken = get_refresh_token_model() admin.site.register(Application, ApplicationAdmin) admin.site.register(Grant, GrantAdmin) admin.site.register(AccessToken, AccessTokenAdmin) +admin.site.register(IDToken, IDTokenAdmin) admin.site.register(RefreshToken, RefreshTokenAdmin) diff --git a/oauth2_provider/forms.py b/oauth2_provider/forms.py index 2e465959a..41129c449 100644 --- a/oauth2_provider/forms.py +++ b/oauth2_provider/forms.py @@ -5,6 +5,7 @@ class AllowForm(forms.Form): allow = forms.BooleanField(required=False) redirect_uri = forms.CharField(widget=forms.HiddenInput()) scope = forms.CharField(widget=forms.HiddenInput()) + nonce = forms.CharField(required=False, widget=forms.HiddenInput()) client_id = forms.CharField(widget=forms.HiddenInput()) state = forms.CharField(required=False, widget=forms.HiddenInput()) response_type = forms.CharField(widget=forms.HiddenInput()) diff --git a/oauth2_provider/migrations/0002_auto_20190406_1805.py b/oauth2_provider/migrations/0002_auto_20190406_1805.py index 8ca177abf..bcacc23ce 100644 --- a/oauth2_provider/migrations/0002_auto_20190406_1805.py +++ b/oauth2_provider/migrations/0002_auto_20190406_1805.py @@ -1,5 +1,3 @@ -# Generated by Django 2.2 on 2019-04-06 18:05 - from django.db import migrations, models diff --git a/oauth2_provider/migrations/0003_auto_20190413_2007.py b/oauth2_provider/migrations/0003_auto_20190413_2007.py new file mode 100644 index 000000000..b27bd4ebb --- /dev/null +++ b/oauth2_provider/migrations/0003_auto_20190413_2007.py @@ -0,0 +1,21 @@ +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('oauth2_provider', '0002_auto_20190406_1805'), + ] + + operations = [ + migrations.AddField( + model_name='application', + name='algorithm', + field=models.CharField(choices=[('RS256', 'RSA with SHA-2 256'), ('HS256', 'HMAC with SHA-2 256')], default='RS256', max_length=5), + ), + migrations.AlterField( + model_name='application', + name='authorization_grant_type', + field=models.CharField(choices=[('authorization-code', 'Authorization code'), ('implicit', 'Implicit'), ('password', 'Resource owner password-based'), ('client-credentials', 'Client credentials'), ('openid-hybrid', 'OpenID connect hybrid')], max_length=32), + ), + ] diff --git a/oauth2_provider/migrations/0004_idtoken.py b/oauth2_provider/migrations/0004_idtoken.py new file mode 100644 index 000000000..853a7089f --- /dev/null +++ b/oauth2_provider/migrations/0004_idtoken.py @@ -0,0 +1,31 @@ +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ('oauth2_provider', '0003_auto_20190413_2007'), + ] + + operations = [ + migrations.CreateModel( + name='IDToken', + fields=[ + ('id', models.BigAutoField(primary_key=True, serialize=False)), + ('token', models.TextField(unique=True)), + ('expires', models.DateTimeField()), + ('scope', models.TextField(blank=True)), + ('created', models.DateTimeField(auto_now_add=True)), + ('updated', models.DateTimeField(auto_now=True)), + ('application', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to=settings.OAUTH2_PROVIDER_APPLICATION_MODEL)), + ('user', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='oauth2_provider_idtoken', to=settings.AUTH_USER_MODEL)), + ], + options={ + 'abstract': False, + 'swappable': 'OAUTH2_PROVIDER_ID_TOKEN_MODEL', + }, + ), + ] diff --git a/oauth2_provider/migrations/0005_accesstoken_id_token.py b/oauth2_provider/migrations/0005_accesstoken_id_token.py new file mode 100644 index 000000000..0a14a058c --- /dev/null +++ b/oauth2_provider/migrations/0005_accesstoken_id_token.py @@ -0,0 +1,18 @@ +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('oauth2_provider', '0004_idtoken'), + ] + + operations = [ + migrations.AddField( + model_name='accesstoken', + name='id_token', + field=models.OneToOneField(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='access_token', to=settings.OAUTH2_PROVIDER_ID_TOKEN_MODEL), + ), + ] diff --git a/oauth2_provider/models.py b/oauth2_provider/models.py index f87a51691..1421c89eb 100644 --- a/oauth2_provider/models.py +++ b/oauth2_provider/models.py @@ -1,3 +1,4 @@ +import json import logging from datetime import timedelta from urllib.parse import parse_qsl, urlparse @@ -9,6 +10,7 @@ from django.urls import reverse from django.utils import timezone from django.utils.translation import gettext_lazy as _ +from jwcrypto import jwk, jwt from .generators import generate_client_id, generate_client_secret from .scopes import get_scopes_backend @@ -50,11 +52,20 @@ class AbstractApplication(models.Model): GRANT_IMPLICIT = "implicit" GRANT_PASSWORD = "password" GRANT_CLIENT_CREDENTIALS = "client-credentials" + GRANT_OPENID_HYBRID = "openid-hybrid" GRANT_TYPES = ( (GRANT_AUTHORIZATION_CODE, _("Authorization code")), (GRANT_IMPLICIT, _("Implicit")), (GRANT_PASSWORD, _("Resource owner password-based")), (GRANT_CLIENT_CREDENTIALS, _("Client credentials")), + (GRANT_OPENID_HYBRID, _("OpenID connect hybrid")), + ) + + RS256_ALGORITHM = "RS256" + HS256_ALGORITHM = "HS256" + ALGORITHM_TYPES = ( + (RS256_ALGORITHM, _("RSA with SHA-2 256")), + (HS256_ALGORITHM, _("HMAC with SHA-2 256")), ) id = models.BigAutoField(primary_key=True) @@ -82,6 +93,7 @@ class AbstractApplication(models.Model): created = models.DateTimeField(auto_now_add=True) updated = models.DateTimeField(auto_now=True) + algorithm = models.CharField(max_length=5, choices=ALGORITHM_TYPES, default=RS256_ALGORITHM) class Meta: abstract = True @@ -282,6 +294,10 @@ class AbstractAccessToken(models.Model): related_name="refreshed_access_token" ) token = models.CharField(max_length=255, unique=True, ) + id_token = models.OneToOneField( + oauth2_settings.ID_TOKEN_MODEL, on_delete=models.CASCADE, blank=True, null=True, + related_name="access_token" + ) application = models.ForeignKey( oauth2_settings.APPLICATION_MODEL, on_delete=models.CASCADE, blank=True, null=True, ) @@ -415,6 +431,99 @@ class Meta(AbstractRefreshToken.Meta): swappable = "OAUTH2_PROVIDER_REFRESH_TOKEN_MODEL" +class AbstractIDToken(models.Model): + """ + An IDToken instance represents the actual token to + access user's resources, as in :openid:`2`. + + Fields: + + * :attr:`user` The Django user representing resources' owner + * :attr:`token` ID token + * :attr:`application` Application instance + * :attr:`expires` Date and time of token expiration, in DateTime format + * :attr:`scope` Allowed scopes + """ + id = models.BigAutoField(primary_key=True) + user = models.ForeignKey( + settings.AUTH_USER_MODEL, on_delete=models.CASCADE, blank=True, null=True, + related_name="%(app_label)s_%(class)s" + ) + token = models.TextField(unique=True) + application = models.ForeignKey( + oauth2_settings.APPLICATION_MODEL, on_delete=models.CASCADE, blank=True, null=True, + ) + expires = models.DateTimeField() + scope = models.TextField(blank=True) + + created = models.DateTimeField(auto_now_add=True) + updated = models.DateTimeField(auto_now=True) + + def is_valid(self, scopes=None): + """ + Checks if the access token is valid. + + :param scopes: An iterable containing the scopes to check or None + """ + return not self.is_expired() and self.allow_scopes(scopes) + + def is_expired(self): + """ + Check token expiration with timezone awareness + """ + if not self.expires: + return True + + return timezone.now() >= self.expires + + def allow_scopes(self, scopes): + """ + Check if the token allows the provided scopes + + :param scopes: An iterable containing the scopes to check + """ + if not scopes: + return True + + provided_scopes = set(self.scope.split()) + resource_scopes = set(scopes) + + return resource_scopes.issubset(provided_scopes) + + def revoke(self): + """ + Convenience method to uniform tokens' interface, for now + simply remove this token from the database in order to revoke it. + """ + self.delete() + + @property + def scopes(self): + """ + Returns a dictionary of allowed scope names (as keys) with their descriptions (as values) + """ + all_scopes = get_scopes_backend().get_all_scopes() + token_scopes = self.scope.split() + return {name: desc for name, desc in all_scopes.items() if name in token_scopes} + + @property + def claims(self): + key = jwk.JWK.from_pem(oauth2_settings.OIDC_RSA_PRIVATE_KEY.encode("utf8")) + jwt_token = jwt.JWT(key=key, jwt=self.token) + return json.loads(jwt_token.claims) + + def __str__(self): + return self.token + + class Meta: + abstract = True + + +class IDToken(AbstractIDToken): + class Meta(AbstractIDToken.Meta): + swappable = "OAUTH2_PROVIDER_ID_TOKEN_MODEL" + + def get_application_model(): """ Return the Application model that is active in this project. """ return apps.get_model(oauth2_settings.APPLICATION_MODEL) @@ -430,6 +539,11 @@ def get_access_token_model(): return apps.get_model(oauth2_settings.ACCESS_TOKEN_MODEL) +def get_id_token_model(): + """ Return the AccessToken model that is active in this project. """ + return apps.get_model(oauth2_settings.ID_TOKEN_MODEL) + + def get_refresh_token_model(): """ Return the RefreshToken model that is active in this project. """ return apps.get_model(oauth2_settings.REFRESH_TOKEN_MODEL) diff --git a/oauth2_provider/oauth2_backends.py b/oauth2_provider/oauth2_backends.py index f8710fdb0..5bcdd7db8 100644 --- a/oauth2_provider/oauth2_backends.py +++ b/oauth2_provider/oauth2_backends.py @@ -102,7 +102,7 @@ def validate_authorization_request(self, request): except oauth2.OAuth2Error as error: raise OAuthToolkitError(error=error) - def create_authorization_response(self, request, scopes, credentials, allow): + def create_authorization_response(self, uri, request, scopes, credentials, body, allow): """ A wrapper method that calls create_authorization_response on `server_class` instance. @@ -110,7 +110,8 @@ def create_authorization_response(self, request, scopes, credentials, allow): :param request: The current django.http.HttpRequest object :param scopes: A list of provided scopes :param credentials: Authorization credentials dictionary containing - `client_id`, `state`, `redirect_uri`, `response_type` + `client_id`, `state`, `redirect_uri` and `response_type` + :param body: Other body parameters not used in credentials dictionary :param allow: True if the user authorize the client, otherwise False """ try: @@ -122,10 +123,10 @@ def create_authorization_response(self, request, scopes, credentials, allow): credentials["user"] = request.user headers, body, status = self.server.create_authorization_response( - uri=credentials["redirect_uri"], scopes=scopes, credentials=credentials) - uri = headers.get("Location", None) + uri=uri, scopes=scopes, credentials=credentials, body=body) + redirect_uri = headers.get("Location", None) - return uri, headers, body, status + return redirect_uri, headers, body, status except oauth2.FatalClientError as error: raise FatalClientError(error=error, redirect_uri=credentials["redirect_uri"]) diff --git a/oauth2_provider/oauth2_validators.py b/oauth2_provider/oauth2_validators.py index 162112d21..f6ede19ac 100644 --- a/oauth2_provider/oauth2_validators.py +++ b/oauth2_provider/oauth2_validators.py @@ -1,5 +1,7 @@ import base64 import binascii +import hashlib +import json import logging from collections import OrderedDict from datetime import datetime, timedelta @@ -11,15 +13,19 @@ from django.core.exceptions import ObjectDoesNotExist from django.db import transaction from django.db.models import Q -from django.utils import timezone +from django.utils import dateformat, timezone from django.utils.timezone import make_aware from django.utils.translation import gettext_lazy as _ +from jwcrypto import jwk, jwt +from jwcrypto.common import JWException +from jwcrypto.jwt import JWTExpired from oauthlib.oauth2 import RequestValidator +from oauthlib.oauth2.rfc6749 import utils from .exceptions import FatalClientError from .models import ( - AbstractApplication, get_access_token_model, - get_application_model, get_grant_model, get_refresh_token_model + AbstractApplication, get_access_token_model, get_application_model, + get_grant_model, get_id_token_model, get_refresh_token_model ) from .scopes import get_scopes_backend from .settings import oauth2_settings @@ -28,18 +34,22 @@ log = logging.getLogger("oauth2_provider") GRANT_TYPE_MAPPING = { - "authorization_code": (AbstractApplication.GRANT_AUTHORIZATION_CODE, ), - "password": (AbstractApplication.GRANT_PASSWORD, ), - "client_credentials": (AbstractApplication.GRANT_CLIENT_CREDENTIALS, ), + "authorization_code": ( + AbstractApplication.GRANT_AUTHORIZATION_CODE, + AbstractApplication.GRANT_OPENID_HYBRID, + ), + "password": (AbstractApplication.GRANT_PASSWORD,), + "client_credentials": (AbstractApplication.GRANT_CLIENT_CREDENTIALS,), "refresh_token": ( AbstractApplication.GRANT_AUTHORIZATION_CODE, AbstractApplication.GRANT_PASSWORD, AbstractApplication.GRANT_CLIENT_CREDENTIALS, - ) + ), } Application = get_application_model() AccessToken = get_access_token_model() +IDToken = get_id_token_model() Grant = get_grant_model() RefreshToken = get_refresh_token_model() UserModel = get_user_model() @@ -92,12 +102,15 @@ def _authenticate_basic_auth(self, request): except UnicodeDecodeError: log.debug( "Failed basic auth: %r can't be decoded as unicode by %r", - auth_string, encoding + auth_string, + encoding, ) return False try: - client_id, client_secret = map(unquote_plus, auth_string_decoded.split(":", 1)) + client_id, client_secret = map( + unquote_plus, auth_string_decoded.split(":", 1) + ) except ValueError: log.debug("Failed basic auth, Invalid base64 encoding.") return False @@ -146,40 +159,57 @@ def _load_application(self, client_id, request): """ # we want to be sure that request has the client attribute! - assert hasattr(request, "client"), '"request" instance has no "client" attribute' + assert hasattr( + request, "client" + ), '"request" instance has no "client" attribute' try: - request.client = request.client or Application.objects.get(client_id=client_id) + request.client = request.client or Application.objects.get( + client_id=client_id + ) # Check that the application can be used (defaults to always True) if not request.client.is_usable(request): - log.debug("Failed body authentication: Application %r is disabled" % (client_id)) + log.debug( + "Failed body authentication: Application %r is disabled" + % (client_id) + ) return None return request.client except Application.DoesNotExist: - log.debug("Failed body authentication: Application %r does not exist" % (client_id)) + log.debug( + "Failed body authentication: Application %r does not exist" + % (client_id) + ) return None def _set_oauth2_error_on_request(self, request, access_token, scopes): if access_token is None: - error = OrderedDict([ - ("error", "invalid_token", ), - ("error_description", _("The access token is invalid."), ), - ]) + error = OrderedDict( + [ + ("error", "invalid_token",), + ("error_description", _("The access token is invalid."),), + ] + ) elif access_token.is_expired(): - error = OrderedDict([ - ("error", "invalid_token", ), - ("error_description", _("The access token has expired."), ), - ]) + error = OrderedDict( + [ + ("error", "invalid_token",), + ("error_description", _("The access token has expired."),), + ] + ) elif not access_token.allow_scopes(scopes): - error = OrderedDict([ - ("error", "insufficient_scope", ), - ("error_description", _("The access token is valid but does not have enough scope."), ), - ]) + error = OrderedDict( + [ + ("error", "insufficient_scope",), + ( + "error_description", + _("The access token is valid but does not have enough scope."), + ), + ] + ) else: log.warning("OAuth2 access token is invalid for an unknown reason.") - error = OrderedDict([ - ("error", "invalid_token", ), - ]) + error = OrderedDict([("error", "invalid_token",), ]) request.oauth2_error = error return request @@ -240,11 +270,15 @@ def authenticate_client_id(self, client_id, request, *args, **kwargs): proceed only if the client exists and is not of type "Confidential". """ if self._load_application(client_id, request) is not None: - log.debug("Application %r has type %r" % (client_id, request.client.client_type)) + log.debug( + "Application %r has type %r" % (client_id, request.client.client_type) + ) return request.client.client_type != AbstractApplication.CLIENT_CONFIDENTIAL return False - def confirm_redirect_uri(self, client_id, code, redirect_uri, client, *args, **kwargs): + def confirm_redirect_uri( + self, client_id, code, redirect_uri, client, *args, **kwargs + ): """ Ensure the redirect_uri is listed in the Application instance redirect_uris field """ @@ -269,7 +303,7 @@ def get_default_redirect_uri(self, client_id, request, *args, **kwargs): return request.client.default_redirect_uri def _get_token_from_authentication_server( - self, token, introspection_url, introspection_token, introspection_credentials + self, token, introspection_url, introspection_token, introspection_credentials ): """Use external introspection endpoint to "crack open" the token. :param introspection_url: introspection endpoint URL @@ -297,11 +331,12 @@ def _get_token_from_authentication_server( try: response = requests.post( - introspection_url, - data={"token": token}, headers=headers + introspection_url, data={"token": token}, headers=headers ) except requests.exceptions.RequestException: - log.exception("Introspection: Failed POST to %r in token lookup", introspection_url) + log.exception( + "Introspection: Failed POST to %r in token lookup", introspection_url + ) return None try: @@ -339,7 +374,8 @@ def _get_token_from_authentication_server( "application": None, "scope": scope, "expires": expires, - }) + }, + ) return access_token @@ -352,10 +388,14 @@ def validate_bearer_token(self, token, scopes, request): introspection_url = oauth2_settings.RESOURCE_SERVER_INTROSPECTION_URL introspection_token = oauth2_settings.RESOURCE_SERVER_AUTH_TOKEN - introspection_credentials = oauth2_settings.RESOURCE_SERVER_INTROSPECTION_CREDENTIALS + introspection_credentials = ( + oauth2_settings.RESOURCE_SERVER_INTROSPECTION_CREDENTIALS + ) try: - access_token = AccessToken.objects.select_related("application", "user").get(token=token) + access_token = AccessToken.objects.select_related( + "application", "user" + ).get(token=token) except AccessToken.DoesNotExist: access_token = None @@ -366,7 +406,7 @@ def validate_bearer_token(self, token, scopes, request): token, introspection_url, introspection_token, - introspection_credentials + introspection_credentials, ) if access_token and access_token.is_valid(scopes): @@ -393,22 +433,38 @@ def validate_code(self, client_id, code, client, request, *args, **kwargs): except Grant.DoesNotExist: return False - def validate_grant_type(self, client_id, grant_type, client, request, *args, **kwargs): + def validate_grant_type( + self, client_id, grant_type, client, request, *args, **kwargs + ): """ Validate both grant_type is a valid string and grant_type is allowed for current workflow """ - assert(grant_type in GRANT_TYPE_MAPPING) # mapping misconfiguration + assert grant_type in GRANT_TYPE_MAPPING # mapping misconfiguration return request.client.allows_grant_type(*GRANT_TYPE_MAPPING[grant_type]) - def validate_response_type(self, client_id, response_type, client, request, *args, **kwargs): + def validate_response_type( + self, client_id, response_type, client, request, *args, **kwargs + ): """ We currently do not support the Authorization Endpoint Response Types registry as in rfc:`8.4`, so validate the response_type only if it matches "code" or "token" """ if response_type == "code": - return client.allows_grant_type(AbstractApplication.GRANT_AUTHORIZATION_CODE) + return client.allows_grant_type( + AbstractApplication.GRANT_AUTHORIZATION_CODE + ) elif response_type == "token": return client.allows_grant_type(AbstractApplication.GRANT_IMPLICIT) + elif response_type == "id_token": + return client.allows_grant_type(AbstractApplication.GRANT_IMPLICIT) + elif response_type == "id_token token": + return client.allows_grant_type(AbstractApplication.GRANT_IMPLICIT) + elif response_type == "code id_token": + return client.allows_grant_type(AbstractApplication.GRANT_OPENID_HYBRID) + elif response_type == "code token": + return client.allows_grant_type(AbstractApplication.GRANT_OPENID_HYBRID) + elif response_type == "code id_token token": + return client.allows_grant_type(AbstractApplication.GRANT_OPENID_HYBRID) else: return False @@ -416,11 +472,15 @@ def validate_scopes(self, client_id, scopes, client, request, *args, **kwargs): """ Ensure required scopes are permitted (as specified in the settings file) """ - available_scopes = get_scopes_backend().get_available_scopes(application=client, request=request) + available_scopes = get_scopes_backend().get_available_scopes( + application=client, request=request + ) return set(scopes).issubset(set(available_scopes)) def get_default_scopes(self, client_id, request, *args, **kwargs): - default_scopes = get_scopes_backend().get_default_scopes(application=request.client, request=request) + default_scopes = get_scopes_backend().get_default_scopes( + application=request.client, request=request + ) return default_scopes def validate_redirect_uri(self, client_id, redirect_uri, request, *args, **kwargs): @@ -447,7 +507,8 @@ def get_code_challenge_method(self, code, request): def save_authorization_code(self, client_id, code, request, *args, **kwargs): expires = timezone.now() + timedelta( - seconds=oauth2_settings.AUTHORIZATION_CODE_EXPIRE_SECONDS) + seconds=oauth2_settings.AUTHORIZATION_CODE_EXPIRE_SECONDS + ) Grant.objects.create( application=request.client, user=request.user, @@ -456,9 +517,27 @@ def save_authorization_code(self, client_id, code, request, *args, **kwargs): redirect_uri=request.redirect_uri, scope=" ".join(request.scopes), code_challenge=request.code_challenge or "", - code_challenge_method=request.code_challenge_method or "" + code_challenge_method=request.code_challenge_method or "", ) + def get_authorization_code_scopes(self, client_id, code, redirect_uri, request): + scopes = [] + fields = { + "code": code, + } + + if client_id: + fields["application__client_id"] = client_id + + if redirect_uri: + fields["redirect_uri"] = redirect_uri + + grant = Grant.objects.filter(**fields).values() + if grant.exists(): + grant_dict = dict(grant[0]) + scopes = utils.scope_to_list(grant_dict["scope"]) + return scopes + def rotate_refresh_token(self, request): """ Checks if rotate refresh token is enabled @@ -499,9 +578,11 @@ def save_bearer_token(self, token, request, *args, **kwargs): refresh_token_instance = getattr(request, "refresh_token_instance", None) # If we are to reuse tokens, and we can: do so - if not self.rotate_refresh_token(request) and \ - isinstance(refresh_token_instance, RefreshToken) and \ - refresh_token_instance.access_token: + if ( + not self.rotate_refresh_token(request) + and isinstance(refresh_token_instance, RefreshToken) + and refresh_token_instance.access_token + ): access_token = AccessToken.objects.select_for_update().get( pk=refresh_token_instance.access_token.pk @@ -548,14 +629,18 @@ def save_bearer_token(self, token, request, *args, **kwargs): source_refresh_token=refresh_token_instance, ) - self._create_refresh_token(request, refresh_token_code, access_token) + self._create_refresh_token( + request, refresh_token_code, access_token + ) else: # make sure that the token data we're returning matches # the existing token token["access_token"] = previous_access_token.token - token["refresh_token"] = RefreshToken.objects.filter( - access_token=previous_access_token - ).first().token + token["refresh_token"] = ( + RefreshToken.objects.filter(access_token=previous_access_token) + .first() + .token + ) token["scope"] = previous_access_token.scope # No refresh token should be created, just access token @@ -563,11 +648,15 @@ def save_bearer_token(self, token, request, *args, **kwargs): self._create_access_token(expires, request, token) def _create_access_token(self, expires, request, token, source_refresh_token=None): + id_token = token.get("id_token", None) + if id_token: + id_token = IDToken.objects.get(token=id_token) return AccessToken.objects.create( user=request.user, scope=token["scope"], expires=expires, token=token["access_token"], + id_token=id_token, application=request.client, source_refresh_token=source_refresh_token, ) @@ -577,7 +666,7 @@ def _create_refresh_token(self, request, refresh_token_code, access_token): user=request.user, token=refresh_token_code, application=request.client, - access_token=access_token + access_token=access_token, ) def revoke_token(self, token, token_type_hint, request, *args, **kwargs): @@ -630,9 +719,8 @@ def validate_refresh_token(self, refresh_token, client, request, *args, **kwargs """ null_or_recent = Q(revoked__isnull=True) | Q( - revoked__gt=timezone.now() - timedelta( - seconds=oauth2_settings.REFRESH_TOKEN_GRACE_PERIOD_SECONDS - ) + revoked__gt=timezone.now() + - timedelta(seconds=oauth2_settings.REFRESH_TOKEN_GRACE_PERIOD_SECONDS) ) rt = RefreshToken.objects.filter(null_or_recent, token=refresh_token).select_related( "access_token" @@ -646,3 +734,141 @@ def validate_refresh_token(self, refresh_token, client, request, *args, **kwargs # Temporary store RefreshToken instance to be reused by get_original_scopes and save_bearer_token. request.refresh_token_instance = rt return rt.application == client + + @transaction.atomic + def _save_id_token(self, token, request, expires, *args, **kwargs): + + scopes = request.scope or " ".join(request.scopes) + + if request.grant_type == "client_credentials": + request.user = None + + id_token = IDToken.objects.create( + user=request.user, + scope=scopes, + expires=expires, + token=token.serialize(), + application=request.client, + ) + return id_token + + def get_jwt_bearer_token(self, token, token_handler, request): + return self.get_id_token(token, token_handler, request) + + def get_id_token(self, token, token_handler, request): + + key = jwk.JWK.from_pem(oauth2_settings.OIDC_RSA_PRIVATE_KEY.encode("utf8")) + + # TODO: http://openid.net/specs/openid-connect-core-1_0.html#HybridIDToken2 + # Save the id_token on database bound to code when the request come to + # Authorization Endpoint and return the same one when request come to + # Token Endpoint + + # TODO: Check if at this point this request parameters are alredy validated + + expiration_time = timezone.now() + timedelta( + seconds=oauth2_settings.ID_TOKEN_EXPIRE_SECONDS + ) + # Required ID Token claims + claims = { + "iss": oauth2_settings.OIDC_ISS_ENDPOINT, + "sub": str(request.user.id), + "aud": request.client_id, + "exp": int(dateformat.format(expiration_time, "U")), + "iat": int(dateformat.format(datetime.utcnow(), "U")), + "auth_time": int(dateformat.format(request.user.last_login, "U")), + } + + nonce = getattr(request, "nonce", None) + if nonce: + claims["nonce"] = nonce + + # TODO: create a function to check if we should add at_hash + # http://openid.net/specs/openid-connect-core-1_0.html#CodeIDToken + # http://openid.net/specs/openid-connect-core-1_0.html#ImplicitIDToken + # if request.grant_type in 'authorization_code' and 'access_token' in token: + if ( + (request.grant_type == "authorization_code" and "access_token" in token) + or request.response_type == "code id_token token" + or (request.response_type == "id_token token" and "access_token" in token) + ): + acess_token = token["access_token"] + sha256 = hashlib.sha256(acess_token.encode("ascii")) + bits128 = sha256.hexdigest()[:16] + at_hash = base64.urlsafe_b64encode(bits128.encode("ascii")) + claims["at_hash"] = at_hash.decode("utf8") + + # TODO: create a function to check if we should include c_hash + # http://openid.net/specs/openid-connect-core-1_0.html#HybridIDToken + if request.response_type in ("code id_token", "code id_token token"): + code = token["code"] + sha256 = hashlib.sha256(code.encode("ascii")) + bits256 = sha256.hexdigest()[:32] + c_hash = base64.urlsafe_b64encode(bits256.encode("ascii")) + claims["c_hash"] = c_hash.decode("utf8") + + jwt_token = jwt.JWT( + header=json.dumps({"alg": "RS256"}, default=str), + claims=json.dumps(claims, default=str), + ) + jwt_token.make_signed_token(key) + + id_token = self._save_id_token(jwt_token, request, expiration_time) + # this is needed by django rest framework + request.access_token = id_token + request.id_token = id_token + return jwt_token.serialize() + + def validate_jwt_bearer_token(self, token, scopes, request): + return self.validate_id_token(token, scopes, request) + + def validate_id_token(self, token, scopes, request): + """ + When users try to access resources, check that provided id_token is valid + """ + if not token: + return False + + key = jwk.JWK.from_pem(oauth2_settings.OIDC_RSA_PRIVATE_KEY.encode("utf8")) + + try: + jwt_token = jwt.JWT(key=key, jwt=token) + id_token = IDToken.objects.get(token=jwt_token.serialize()) + request.client = id_token.application + request.user = id_token.user + request.scopes = scopes + # this is needed by django rest framework + request.access_token = id_token + return True + except (JWException, JWTExpired): + # TODO: This is the base exception of all jwcrypto + return False + + return False + + def validate_user_match(self, id_token_hint, scopes, claims, request): + # TODO: Fix to validate when necessary acording + # https://github.com/idan/oauthlib/blob/master/oauthlib/oauth2/rfc6749/request_validator.py#L556 + # http://openid.net/specs/openid-connect-core-1_0.html#AuthRequest id_token_hint section + return True + + def get_authorization_code_nonce(self, client_id, code, redirect_uri, request): + """ Extracts nonce from saved authorization code. + If present in the Authentication Request, Authorization + Servers MUST include a nonce Claim in the ID Token with the + Claim Value being the nonce value sent in the Authentication + Request. Authorization Servers SHOULD perform no other + processing on nonce values used. The nonce value is a + case-sensitive string. + Only code param should be sufficient to retrieve grant code from + any storage you are using. However, `client_id` and `redirect_uri` + have been validated and can be used also. + :param client_id: Unicode client identifier + :param code: Unicode authorization code grant + :param redirect_uri: Unicode absolute URI + :return: Unicode nonce + Method is used by: + - Authorization Token Grant Dispatcher + """ + # TODO: Fix this ;) + return "" diff --git a/oauth2_provider/settings.py b/oauth2_provider/settings.py index 858efdbe7..d770cbd56 100644 --- a/oauth2_provider/settings.py +++ b/oauth2_provider/settings.py @@ -23,10 +23,19 @@ USER_SETTINGS = getattr(settings, "OAUTH2_PROVIDER", None) -APPLICATION_MODEL = getattr(settings, "OAUTH2_PROVIDER_APPLICATION_MODEL", "oauth2_provider.Application") -ACCESS_TOKEN_MODEL = getattr(settings, "OAUTH2_PROVIDER_ACCESS_TOKEN_MODEL", "oauth2_provider.AccessToken") +APPLICATION_MODEL = getattr( + settings, "OAUTH2_PROVIDER_APPLICATION_MODEL", "oauth2_provider.Application" +) +ACCESS_TOKEN_MODEL = getattr( + settings, "OAUTH2_PROVIDER_ACCESS_TOKEN_MODEL", "oauth2_provider.AccessToken" +) +ID_TOKEN_MODEL = getattr( + settings, "OAUTH2_PROVIDER_ID_TOKEN_MODEL", "oauth2_provider.IDToken" +) GRANT_MODEL = getattr(settings, "OAUTH2_PROVIDER_GRANT_MODEL", "oauth2_provider.Grant") -REFRESH_TOKEN_MODEL = getattr(settings, "OAUTH2_PROVIDER_REFRESH_TOKEN_MODEL", "oauth2_provider.RefreshToken") +REFRESH_TOKEN_MODEL = getattr( + settings, "OAUTH2_PROVIDER_REFRESH_TOKEN_MODEL", "oauth2_provider.RefreshToken" +) DEFAULTS = { "CLIENT_ID_GENERATOR_CLASS": "oauth2_provider.generators.ClientIdGenerator", @@ -35,7 +44,7 @@ "ACCESS_TOKEN_GENERATOR": None, "REFRESH_TOKEN_GENERATOR": None, "EXTRA_SERVER_KWARGS": {}, - "OAUTH2_SERVER_CLASS": "oauthlib.oauth2.Server", + "OAUTH2_SERVER_CLASS": "oauthlib.openid.connect.core.endpoints.pre_configured.Server", "OAUTH2_VALIDATOR_CLASS": "oauth2_provider.oauth2_validators.OAuth2Validator", "OAUTH2_BACKEND_CLASS": "oauth2_provider.oauth2_backends.OAuthLibCore", "SCOPES": {"read": "Reading scope", "write": "Writing scope"}, @@ -45,29 +54,46 @@ "WRITE_SCOPE": "write", "AUTHORIZATION_CODE_EXPIRE_SECONDS": 60, "ACCESS_TOKEN_EXPIRE_SECONDS": 36000, + "ID_TOKEN_EXPIRE_SECONDS": 36000, "REFRESH_TOKEN_EXPIRE_SECONDS": None, "REFRESH_TOKEN_GRACE_PERIOD_SECONDS": 0, "ROTATE_REFRESH_TOKEN": True, "ERROR_RESPONSE_WITH_SCOPES": False, "APPLICATION_MODEL": APPLICATION_MODEL, "ACCESS_TOKEN_MODEL": ACCESS_TOKEN_MODEL, + "ID_TOKEN_MODEL": ID_TOKEN_MODEL, "GRANT_MODEL": GRANT_MODEL, "REFRESH_TOKEN_MODEL": REFRESH_TOKEN_MODEL, "REQUEST_APPROVAL_PROMPT": "force", "ALLOWED_REDIRECT_URI_SCHEMES": ["http", "https"], - + "OIDC_ISS_ENDPOINT": "", + "OIDC_USERINFO_ENDPOINT": "", + "OIDC_RSA_PRIVATE_KEY": "", + "OIDC_RESPONSE_TYPES_SUPPORTED": [ + "code", + "token", + "id_token", + "id_token token", + "code token", + "code id_token", + "code id_token token", + ], + "OIDC_SUBJECT_TYPES_SUPPORTED": ["public"], + "OIDC_ID_TOKEN_SIGNING_ALG_VALUES_SUPPORTED": ["RS256", "HS256"], + "OIDC_TOKEN_ENDPOINT_AUTH_METHODS_SUPPORTED": [ + "client_secret_post", + "client_secret_basic", + ], # Special settings that will be evaluated at runtime "_SCOPES": [], "_DEFAULT_SCOPES": [], - # Resource Server with Token Introspection "RESOURCE_SERVER_INTROSPECTION_URL": None, "RESOURCE_SERVER_AUTH_TOKEN": None, "RESOURCE_SERVER_INTROSPECTION_CREDENTIALS": None, "RESOURCE_SERVER_TOKEN_CACHING_SECONDS": 36000, - # Whether or not PKCE is required - "PKCE_REQUIRED": False + "PKCE_REQUIRED": False, } # List of settings that cannot be empty @@ -79,6 +105,13 @@ "OAUTH2_BACKEND_CLASS", "SCOPES", "ALLOWED_REDIRECT_URI_SCHEMES", + "OIDC_ISS_ENDPOINT", + "OIDC_USERINFO_ENDPOINT", + "OIDC_RSA_PRIVATE_KEY", + "OIDC_RESPONSE_TYPES_SUPPORTED", + "OIDC_SUBJECT_TYPES_SUPPORTED", + "OIDC_ID_TOKEN_SIGNING_ALG_VALUES_SUPPORTED", + "OIDC_TOKEN_ENDPOINT_AUTH_METHODS_SUPPORTED", ) # List of settings that may be in string import notation. @@ -117,7 +150,12 @@ def import_from_string(val, setting_name): module = importlib.import_module(module_path) return getattr(module, class_name) except ImportError as e: - msg = "Could not import %r for setting %r. %s: %s." % (val, setting_name, e.__class__.__name__, e) + msg = "Could not import %r for setting %r. %s: %s." % ( + val, + setting_name, + e.__class__.__name__, + e, + ) raise ImportError(msg) @@ -129,7 +167,9 @@ class OAuth2ProviderSettings(object): and return the class, rather than the string literal. """ - def __init__(self, user_settings=None, defaults=None, import_strings=None, mandatory=None): + def __init__( + self, user_settings=None, defaults=None, import_strings=None, mandatory=None + ): self.user_settings = user_settings or {} self.defaults = defaults or {} self.import_strings = import_strings or () @@ -164,7 +204,9 @@ def __getattr__(self, attr): if scope in self._SCOPES: val.append(scope) else: - raise ImproperlyConfigured("Defined DEFAULT_SCOPES not present in SCOPES") + raise ImproperlyConfigured( + "Defined DEFAULT_SCOPES not present in SCOPES" + ) self.validate_setting(attr, val) diff --git a/oauth2_provider/urls.py b/oauth2_provider/urls.py index 86d97d053..4baef4704 100644 --- a/oauth2_provider/urls.py +++ b/oauth2_provider/urls.py @@ -27,5 +27,12 @@ name="authorized-token-delete"), ] +oidc_urlpatterns = [ + url(r"^\.well-known/openid-configuration/$", views.ConnectDiscoveryInfoView.as_view(), + name="oidc-connect-discovery-info"), + url(r"^jwks/$", views.JwksInfoView.as_view(), name="jwks-info"), + url(r"^userinfo/$", views.UserInfoView.as_view(), name="user-info") +] + -urlpatterns = base_urlpatterns + management_urlpatterns +urlpatterns = base_urlpatterns + management_urlpatterns + oidc_urlpatterns diff --git a/oauth2_provider/views/__init__.py b/oauth2_provider/views/__init__.py index 7bf60cece..9f2ac4ff7 100644 --- a/oauth2_provider/views/__init__.py +++ b/oauth2_provider/views/__init__.py @@ -1,7 +1,13 @@ # flake8: noqa -from .base import AuthorizationView, TokenView, RevokeTokenView -from .application import ApplicationRegistration, ApplicationDetail, ApplicationList, \ - ApplicationDelete, ApplicationUpdate -from .generic import ProtectedResourceView, ScopedProtectedResourceView, ReadWriteScopedResourceView -from .token import AuthorizedTokensListView, AuthorizedTokenDeleteView +from .application import ( + ApplicationDelete, ApplicationDetail, ApplicationList, + ApplicationRegistration, ApplicationUpdate +) +from .base import AuthorizationView, RevokeTokenView, TokenView +from .generic import ( + ProtectedResourceView, ReadWriteScopedResourceView, + ScopedProtectedResourceView +) from .introspect import IntrospectTokenView +from .oidc import ConnectDiscoveryInfoView, JwksInfoView, UserInfoView +from .token import AuthorizedTokenDeleteView, AuthorizedTokensListView diff --git a/oauth2_provider/views/application.py b/oauth2_provider/views/application.py index c925493f5..b38c907ab 100644 --- a/oauth2_provider/views/application.py +++ b/oauth2_provider/views/application.py @@ -32,7 +32,7 @@ def get_form_class(self): get_application_model(), fields=( "name", "client_id", "client_secret", "client_type", - "authorization_grant_type", "redirect_uris" + "authorization_grant_type", "redirect_uris", "algorithm", ) ) @@ -81,6 +81,6 @@ def get_form_class(self): get_application_model(), fields=( "name", "client_id", "client_secret", "client_type", - "authorization_grant_type", "redirect_uris" + "authorization_grant_type", "redirect_uris", "algorithm", ) ) diff --git a/oauth2_provider/views/base.py b/oauth2_provider/views/base.py index 8a3a59c25..f1ad6e544 100644 --- a/oauth2_provider/views/base.py +++ b/oauth2_provider/views/base.py @@ -86,6 +86,7 @@ class AuthorizationView(BaseAuthorizationView, FormView): * Authorization code * Implicit grant """ + template_name = "oauth2_provider/authorize.html" form_class = AllowForm @@ -101,11 +102,14 @@ def get_initial(self): initial_data = { "redirect_uri": self.oauth2_data.get("redirect_uri", None), "scope": " ".join(scopes), + "nonce": self.oauth2_data.get("nonce", None), "client_id": self.oauth2_data.get("client_id", None), "state": self.oauth2_data.get("state", None), "response_type": self.oauth2_data.get("response_type", None), "code_challenge": self.oauth2_data.get("code_challenge", None), - "code_challenge_method": self.oauth2_data.get("code_challenge_method", None), + "code_challenge_method": self.oauth2_data.get( + "code_challenge_method", None + ), } return initial_data @@ -116,18 +120,27 @@ def form_valid(self, form): "client_id": form.cleaned_data.get("client_id"), "redirect_uri": form.cleaned_data.get("redirect_uri"), "response_type": form.cleaned_data.get("response_type", None), - "state": form.cleaned_data.get("state", None) + "state": form.cleaned_data.get("state", None), } if form.cleaned_data.get("code_challenge", False): credentials["code_challenge"] = form.cleaned_data.get("code_challenge") if form.cleaned_data.get("code_challenge_method", False): - credentials["code_challenge_method"] = form.cleaned_data.get("code_challenge_method") + credentials["code_challenge_method"] = form.cleaned_data.get( + "code_challenge_method" + ) + + body = {"nonce": form.cleaned_data.get("nonce")} scopes = form.cleaned_data.get("scope") allow = form.cleaned_data.get("allow") try: uri, headers, body, status = self.create_authorization_response( - request=self.request, scopes=scopes, credentials=credentials, allow=allow + self.request.get_raw_uri(), + request=self.request, + scopes=scopes, + credentials=credentials, + body=body, + allow=allow, ) except OAuthToolkitError as error: return self.error_response(error, application) @@ -142,12 +155,10 @@ def get(self, request, *args, **kwargs): # TODO: Remove the two following lines after oauthlib updates its implementation # https://github.com/jazzband/django-oauth-toolkit/pull/707#issuecomment-485011945 credentials["code_challenge"] = credentials.get( - "code_challenge", - request.GET.get("code_challenge", None) + "code_challenge", request.GET.get("code_challenge", None) ) credentials["code_challenge_method"] = credentials.get( - "code_challenge_method", - request.GET.get("code_challenge_method", None) + "code_challenge_method", request.GET.get("code_challenge_method", None) ) except OAuthToolkitError as error: # Application is not available at this time. @@ -159,7 +170,14 @@ def get(self, request, *args, **kwargs): # at this point we know an Application instance with such client_id exists in the database # TODO: Cache this! - application = get_application_model().objects.get(client_id=credentials["client_id"]) + application = get_application_model().objects.get( + client_id=credentials["client_id"] + ) + + uri_query = urllib.parse.urlparse(self.request.get_raw_uri()).query + uri_query_params = dict( + urllib.parse.parse_qsl(uri_query, keep_blank_values=True, strict_parsing=True) + ) kwargs["application"] = application kwargs["client_id"] = credentials["client_id"] @@ -168,6 +186,7 @@ def get(self, request, *args, **kwargs): kwargs["state"] = credentials["state"] kwargs["code_challenge"] = credentials["code_challenge"] kwargs["code_challenge_method"] = credentials["code_challenge_method"] + kwargs["nonce"] = uri_query_params.get("nonce", None) self.oauth2_data = kwargs # following two loc are here only because of https://code.djangoproject.com/ticket/17795 @@ -176,7 +195,9 @@ def get(self, request, *args, **kwargs): # Check to see if the user has already granted access and return # a successful response depending on "approval_prompt" url parameter - require_approval = request.GET.get("approval_prompt", oauth2_settings.REQUEST_APPROVAL_PROMPT) + require_approval = request.GET.get( + "approval_prompt", oauth2_settings.REQUEST_APPROVAL_PROMPT + ) try: # If skip_authorization field is True, skip the authorization screen even @@ -185,26 +206,36 @@ def get(self, request, *args, **kwargs): # are already approved. if application.skip_authorization: uri, headers, body, status = self.create_authorization_response( - request=self.request, scopes=" ".join(scopes), - credentials=credentials, allow=True + self.request.get_raw_uri(), + request=self.request, + scopes=" ".join(scopes), + credentials=credentials, + allow=True, ) return self.redirect(uri, application) elif require_approval == "auto": - tokens = get_access_token_model().objects.filter( - user=request.user, - application=kwargs["application"], - expires__gt=timezone.now() - ).all() + tokens = ( + get_access_token_model() + .objects.filter( + user=request.user, + application=kwargs["application"], + expires__gt=timezone.now(), + ) + .all() + ) # check past authorizations regarded the same scopes as the current one for token in tokens: if token.allow_scopes(scopes): uri, headers, body, status = self.create_authorization_response( - request=self.request, scopes=" ".join(scopes), - credentials=credentials, allow=True + self.request.get_raw_uri(), + request=self.request, + scopes=" ".join(scopes), + credentials=credentials, + allow=True, ) - return self.redirect(uri, application, token) + return self.redirect(uri, application) except OAuthToolkitError as error: return self.error_response(error, application) @@ -251,6 +282,7 @@ class TokenView(OAuthLibMixin, View): * Password * Client credentials """ + server_class = oauth2_settings.OAUTH2_SERVER_CLASS validator_class = oauth2_settings.OAUTH2_VALIDATOR_CLASS oauthlib_backend_class = oauth2_settings.OAUTH2_BACKEND_CLASS @@ -261,11 +293,8 @@ def post(self, request, *args, **kwargs): if status == 200: access_token = json.loads(body).get("access_token") if access_token is not None: - token = get_access_token_model().objects.get( - token=access_token) - app_authorized.send( - sender=self, request=request, - token=token) + token = get_access_token_model().objects.get(token=access_token) + app_authorized.send(sender=self, request=request, token=token) response = HttpResponse(content=body, status=status) for k, v in headers.items(): @@ -278,6 +307,7 @@ class RevokeTokenView(OAuthLibMixin, View): """ Implements an endpoint to revoke access or refresh tokens """ + server_class = oauth2_settings.OAUTH2_SERVER_CLASS validator_class = oauth2_settings.OAUTH2_VALIDATOR_CLASS oauthlib_backend_class = oauth2_settings.OAUTH2_BACKEND_CLASS diff --git a/oauth2_provider/views/mixins.py b/oauth2_provider/views/mixins.py index 0cc9bd589..5c596761f 100644 --- a/oauth2_provider/views/mixins.py +++ b/oauth2_provider/views/mixins.py @@ -97,7 +97,7 @@ def validate_authorization_request(self, request): core = self.get_oauthlib_core() return core.validate_authorization_request(request) - def create_authorization_response(self, request, scopes, credentials, allow): + def create_authorization_response(self, uri, request, scopes, credentials, allow, body=None): """ A wrapper method that calls create_authorization_response on `server_class` instance. @@ -105,14 +105,15 @@ def create_authorization_response(self, request, scopes, credentials, allow): :param request: The current django.http.HttpRequest object :param scopes: A space-separated string of provided scopes :param credentials: Authorization credentials dictionary containing - `client_id`, `state`, `redirect_uri`, `response_type` + `client_id`, `state`, `redirect_uri` and `response_type` :param allow: True if the user authorize the client, otherwise False + :param body: Other body parameters not used in credentials dictionary """ # TODO: move this scopes conversion from and to string into a utils function scopes = scopes.split(" ") if scopes else [] core = self.get_oauthlib_core() - return core.create_authorization_response(request, scopes, credentials, allow) + return core.create_authorization_response(uri, request, scopes, credentials, body, allow) def create_token_response(self, request): """ diff --git a/oauth2_provider/views/oidc.py b/oauth2_provider/views/oidc.py new file mode 100644 index 000000000..732965a5d --- /dev/null +++ b/oauth2_provider/views/oidc.py @@ -0,0 +1,64 @@ +from __future__ import absolute_import, unicode_literals + +import json + +from django.http import JsonResponse +from django.urls import reverse_lazy +from django.views.generic import View +from jwcrypto import jwk +from rest_framework.views import APIView + +from ..settings import oauth2_settings + + +class ConnectDiscoveryInfoView(View): + """ + View used to show oidc provider configuration information + """ + def get(self, request, *args, **kwargs): + issuer_url = oauth2_settings.OIDC_ISS_ENDPOINT + data = { + "issuer": issuer_url, + "authorization_endpoint": "{}{}".format(issuer_url, reverse_lazy("oauth2_provider:authorize")), + "token_endpoint": "{}{}".format(issuer_url, reverse_lazy("oauth2_provider:token")), + "userinfo_endpoint": oauth2_settings.OIDC_USERINFO_ENDPOINT, + "jwks_uri": "{}{}".format(issuer_url, reverse_lazy("oauth2_provider:jwks-info")), + "response_types_supported": oauth2_settings.OIDC_RESPONSE_TYPES_SUPPORTED, + "subject_types_supported": oauth2_settings.OIDC_SUBJECT_TYPES_SUPPORTED, + "id_token_signing_alg_values_supported": + oauth2_settings.OIDC_ID_TOKEN_SIGNING_ALG_VALUES_SUPPORTED, + "token_endpoint_auth_methods_supported": + oauth2_settings.OIDC_TOKEN_ENDPOINT_AUTH_METHODS_SUPPORTED, + } + response = JsonResponse(data) + response["Access-Control-Allow-Origin"] = "*" + return response + + +class JwksInfoView(View): + """ + View used to show oidc json web key set document + """ + def get(self, request, *args, **kwargs): + key = jwk.JWK.from_pem(oauth2_settings.OIDC_RSA_PRIVATE_KEY.encode("utf8")) + data = { + "keys": [{ + "alg": "RS256", + "use": "sig", + "kid": key.thumbprint() + }] + } + data["keys"][0].update(json.loads(key.export_public())) + response = JsonResponse(data) + response["Access-Control-Allow-Origin"] = "*" + return response + + +class UserInfoView(APIView): + """ + View used to show Claims about the authenticated End-User + """ + def get(self, request, *args, **kwargs): + response = JsonResponse(request.auth.id_token.claims) + response["Access-Control-Allow-Origin"] = "*" + return response diff --git a/setup.cfg b/setup.cfg index 71a69a99d..40a68012b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -34,6 +34,7 @@ install_requires = django >= 2.1 requests >= 2.13.0 oauthlib >= 3.0.1 + jwcrypto >= 0.4.2 [options.packages.find] exclude = tests diff --git a/tests/migrations/0001_initial.py b/tests/migrations/0001_initial.py index 60b17f2ae..eef6dbab5 100644 --- a/tests/migrations/0001_initial.py +++ b/tests/migrations/0001_initial.py @@ -45,7 +45,7 @@ class Migration(migrations.Migration): ('client_id', models.CharField(db_index=True, default=oauth2_provider.generators.generate_client_id, max_length=100, unique=True)), ('redirect_uris', models.TextField(blank=True, help_text='Allowed URIs list, space separated')), ('client_type', models.CharField(choices=[('confidential', 'Confidential'), ('public', 'Public')], max_length=32)), - ('authorization_grant_type', models.CharField(choices=[('authorization-code', 'Authorization code'), ('implicit', 'Implicit'), ('password', 'Resource owner password-based'), ('client-credentials', 'Client credentials')], max_length=32)), + ('authorization_grant_type', models.CharField(choices=[('authorization-code', 'Authorization code'), ('implicit', 'Implicit'), ('password', 'Resource owner password-based'), ('client-credentials', 'Client credentials'), ('openid-hybrid', 'OpenID connect hybrid')], max_length=32)), ('client_secret', models.CharField(blank=True, db_index=True, default=oauth2_provider.generators.generate_client_secret, max_length=255)), ('name', models.CharField(blank=True, max_length=255)), ('skip_authorization', models.BooleanField(default=False)), @@ -53,6 +53,7 @@ class Migration(migrations.Migration): ('updated', models.DateTimeField(auto_now=True)), ('custom_field', models.CharField(max_length=255)), ('user', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='tests_sampleapplication', to=settings.AUTH_USER_MODEL)), + ('algorithm', models.CharField(max_length=5, choices=[('RS256', 'RSA with SHA-2 256'), ('HS256', 'HMAC with SHA-2 256')], default='RS256')), ], options={ 'abstract': False, @@ -71,6 +72,7 @@ class Migration(migrations.Migration): ('application', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to=settings.OAUTH2_PROVIDER_APPLICATION_MODEL)), ('source_refresh_token', models.OneToOneField(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='s_refreshed_access_token', to=settings.OAUTH2_PROVIDER_REFRESH_TOKEN_MODEL)), ('user', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='tests_sampleaccesstoken', to=settings.AUTH_USER_MODEL)), + ('id_token', models.OneToOneField(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='access_token', to=settings.OAUTH2_PROVIDER_ID_TOKEN_MODEL)), ], options={ 'abstract': False, @@ -83,7 +85,7 @@ class Migration(migrations.Migration): ('client_id', models.CharField(db_index=True, default=oauth2_provider.generators.generate_client_id, max_length=100, unique=True)), ('redirect_uris', models.TextField(blank=True, help_text='Allowed URIs list, space separated')), ('client_type', models.CharField(choices=[('confidential', 'Confidential'), ('public', 'Public')], max_length=32)), - ('authorization_grant_type', models.CharField(choices=[('authorization-code', 'Authorization code'), ('implicit', 'Implicit'), ('password', 'Resource owner password-based'), ('client-credentials', 'Client credentials')], max_length=32)), + ('authorization_grant_type', models.CharField(choices=[('authorization-code', 'Authorization code'), ('implicit', 'Implicit'), ('password', 'Resource owner password-based'), ('client-credentials', 'Client credentials'), ('openid-hybrid', 'OpenID connect hybrid')], max_length=32)), ('client_secret', models.CharField(blank=True, db_index=True, default=oauth2_provider.generators.generate_client_secret, max_length=255)), ('name', models.CharField(blank=True, max_length=255)), ('skip_authorization', models.BooleanField(default=False)), @@ -91,6 +93,7 @@ class Migration(migrations.Migration): ('updated', models.DateTimeField(auto_now=True)), ('allowed_schemes', models.TextField(blank=True)), ('user', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='tests_basetestapplication', to=settings.AUTH_USER_MODEL)), + ('algorithm', models.CharField(max_length=5, choices=[('RS256', 'RSA with SHA-2 256'), ('HS256', 'HMAC with SHA-2 256')], default='RS256')), ], options={ 'abstract': False, diff --git a/tests/settings.py b/tests/settings.py index 40eef5ebd..edd1ae679 100644 --- a/tests/settings.py +++ b/tests/settings.py @@ -130,3 +130,30 @@ }, } } + +OIDC_RSA_PRIVATE_KEY = """-----BEGIN RSA PRIVATE KEY----- +MIICXQIBAAKBgQCbCYh5h2NmQuBqVO6G+/CO+cHm9VBzsb0MeA6bbQfDnbhstVOT +j0hcnZJzDjYc6ajBZZf6gxVP9xrdm9Uh599VI3X5PFXLbMHrmzTAMzCGIyg+/fnP +0gocYxmCX2+XKyj/Zvt1pUX8VAN2AhrJSfxNDKUHERTVEV9bRBJg4F0C3wIDAQAB +AoGAP+i4nNw+Ec/8oWh8YSFm4xE6qKG0NdTtSMAOyWwy+KTB+vHuT1QPsLn1vj77 ++IQrX/moogg6F1oV9YdA3vat3U7rwt1sBGsRrLhA+Spp9WEQtglguNo4+QfVo2ju +YBa2rG+h75qjiA3xnU//F3rvwnAsOWv0NUVdVeguyR+u6okCQQDBUmgWeH2WHmUn +2nLNCz+9wj28rqhfOr9Ptem2gqk+ywJmuIr4Y5S1OdavOr2UZxOcEwncJ/MLVYQq +MH+x4V5HAkEAzU2GMR5OdVLcxfVTjzuIC76paoHVWnLibd1cdANpPmE6SM+pf5el +fVSwuH9Fmlizu8GiPCxbJUoXB/J1tGEKqQJBALhClEU+qOzpoZ6/voYi/6kdN3zc +uEy0EN6n09AKb8gS9QH1STgAqh+ltjMkeMe3C2DKYK5/QU9/Pc58lWl1FkcCQG67 +ZamQgxjcvJ85FvymS1aqW45KwNysIlzHjFo2jMlMf7dN6kobbPMQftDENLJvLWIT +qoFyGycdsxZiPAIyZSECQQCZFn3Dl6hnJxWZH8Fsa9hj79kZ/WVkIXGmtdgt0fNr +dTnvCVtA59ne4LEVie/PMH/odQWY0SxVm/76uBZv/1vY +-----END RSA PRIVATE KEY-----""" + +OAUTH2_PROVIDER = { + "OIDC_ISS_ENDPOINT": "http://localhost", + "OIDC_USERINFO_ENDPOINT": "http://localhost/userinfo/", + "OIDC_RSA_PRIVATE_KEY": OIDC_RSA_PRIVATE_KEY, +} + +OAUTH2_PROVIDER_ACCESS_TOKEN_MODEL = "oauth2_provider.AccessToken" +OAUTH2_PROVIDER_APPLICATION_MODEL = "oauth2_provider.Application" +OAUTH2_PROVIDER_REFRESH_TOKEN_MODEL = "oauth2_provider.RefreshToken" +OAUTH2_PROVIDER_ID_TOKEN_MODEL = "oauth2_provider.IDToken" diff --git a/tests/test_application_views.py b/tests/test_application_views.py index 6130876ce..64e112da3 100644 --- a/tests/test_application_views.py +++ b/tests/test_application_views.py @@ -50,6 +50,7 @@ def test_application_registration_user(self): "client_type": Application.CLIENT_CONFIDENTIAL, "redirect_uris": "http://example.com", "authorization_grant_type": Application.GRANT_AUTHORIZATION_CODE, + "algorithm": "RS256", } response = self.client.post(reverse("oauth2_provider:register"), form_data) diff --git a/tests/test_authorization_code.py b/tests/test_authorization_code.py index 9a95bc269..0c6c71705 100644 --- a/tests/test_authorization_code.py +++ b/tests/test_authorization_code.py @@ -41,8 +41,12 @@ def get(self, request, *args, **kwargs): class BaseTest(TestCase): def setUp(self): self.factory = RequestFactory() - self.test_user = UserModel.objects.create_user("test_user", "test@example.com", "123456") - self.dev_user = UserModel.objects.create_user("dev_user", "dev@example.com", "123456") + self.test_user = UserModel.objects.create_user( + "test_user", "test@example.com", "123456" + ) + self.dev_user = UserModel.objects.create_user( + "dev_user", "dev@example.com", "123456" + ) oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES = ["http", "custom-scheme"] @@ -57,8 +61,13 @@ def setUp(self): authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE, ) - oauth2_settings._SCOPES = ["read", "write"] + oauth2_settings._SCOPES = ["read", "write", "openid"] oauth2_settings._DEFAULT_SCOPES = ["read", "write"] + oauth2_settings.SCOPES = { + "read": "Reading scope", + "write": "Writing scope", + "openid": "OpenID connect", + } def tearDown(self): self.application.delete() @@ -74,14 +83,18 @@ class TestRegressionIssue315(BaseTest): def test_request_is_not_overwritten(self): self.client.login(username="test_user", password="123456") - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code", - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "http://example.org", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + query_string = urlencode( + { + "client_id": self.application.client_id, + "response_type": "code", + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + } + ) + url = "{url}?{qs}".format( + url=reverse("oauth2_provider:authorize"), qs=query_string + ) response = self.client.get(url) self.assertEqual(response.status_code, 200) @@ -97,14 +110,42 @@ def test_skip_authorization_completely(self): self.application.skip_authorization = True self.application.save() - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code", - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "http://example.org", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + query_string = urlencode( + { + "client_id": self.application.client_id, + "response_type": "code", + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + } + ) + url = "{url}?{qs}".format( + url=reverse("oauth2_provider:authorize"), qs=query_string + ) + + response = self.client.get(url) + self.assertEqual(response.status_code, 302) + + def test_id_token_skip_authorization_completely(self): + """ + If application.skip_authorization = True, should skip the authorization page. + """ + self.client.login(username="test_user", password="123456") + self.application.skip_authorization = True + self.application.save() + + query_string = urlencode( + { + "client_id": self.application.client_id, + "response_type": "code", + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.org", + } + ) + url = "{url}?{qs}".format( + url=reverse("oauth2_provider:authorize"), qs=query_string + ) response = self.client.get(url) self.assertEqual(response.status_code, 302) @@ -115,17 +156,18 @@ def test_pre_auth_invalid_client(self): """ self.client.login(username="test_user", password="123456") - query_string = urlencode({ - "client_id": "fakeclientid", - "response_type": "code", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + query_string = urlencode( + {"client_id": "fakeclientid", "response_type": "code", } + ) + url = "{url}?{qs}".format( + url=reverse("oauth2_provider:authorize"), qs=query_string + ) response = self.client.get(url) self.assertEqual(response.status_code, 400) self.assertEqual( response.context_data["url"], - "?error=invalid_request&error_description=Invalid+client_id+parameter+value." + "?error=invalid_request&error_description=Invalid+client_id+parameter+value.", ) def test_pre_auth_valid_client(self): @@ -134,14 +176,18 @@ def test_pre_auth_valid_client(self): """ self.client.login(username="test_user", password="123456") - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code", - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "http://example.org", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + query_string = urlencode( + { + "client_id": self.application.client_id, + "response_type": "code", + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + } + ) + url = "{url}?{qs}".format( + url=reverse("oauth2_provider:authorize"), qs=query_string + ) response = self.client.get(url) self.assertEqual(response.status_code, 200) @@ -155,6 +201,37 @@ def test_pre_auth_valid_client(self): self.assertEqual(form["scope"].value(), "read write") self.assertEqual(form["client_id"].value(), self.application.client_id) + def test_id_token_pre_auth_valid_client(self): + """ + Test response for a valid client_id with response_type: code + """ + self.client.login(username="test_user", password="123456") + + query_string = urlencode( + { + "client_id": self.application.client_id, + "response_type": "code", + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.org", + } + ) + url = "{url}?{qs}".format( + url=reverse("oauth2_provider:authorize"), qs=query_string + ) + + response = self.client.get(url) + self.assertEqual(response.status_code, 200) + + # check form is in context and form params are valid + self.assertIn("form", response.context) + + form = response.context["form"] + self.assertEqual(form["redirect_uri"].value(), "http://example.org") + self.assertEqual(form["state"].value(), "random_state_string") + self.assertEqual(form["scope"].value(), "openid") + self.assertEqual(form["client_id"].value(), self.application.client_id) + def test_pre_auth_valid_client_custom_redirect_uri_scheme(self): """ Test response for a valid client_id with response_type: code @@ -162,14 +239,18 @@ def test_pre_auth_valid_client_custom_redirect_uri_scheme(self): """ self.client.login(username="test_user", password="123456") - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code", - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "custom-scheme://example.com", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + query_string = urlencode( + { + "client_id": self.application.client_id, + "response_type": "code", + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "custom-scheme://example.com", + } + ) + url = "{url}?{qs}".format( + url=reverse("oauth2_provider:authorize"), qs=query_string + ) response = self.client.get(url) self.assertEqual(response.status_code, 200) @@ -185,21 +266,26 @@ def test_pre_auth_valid_client_custom_redirect_uri_scheme(self): def test_pre_auth_approval_prompt(self): tok = AccessToken.objects.create( - user=self.test_user, token="1234567890", + user=self.test_user, + token="1234567890", application=self.application, expires=timezone.now() + datetime.timedelta(days=1), - scope="read write" + scope="read write", ) self.client.login(username="test_user", password="123456") - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code", - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "http://example.org", - "approval_prompt": "auto", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + query_string = urlencode( + { + "client_id": self.application.client_id, + "response_type": "code", + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + "approval_prompt": "auto", + } + ) + url = "{url}?{qs}".format( + url=reverse("oauth2_provider:authorize"), qs=query_string + ) response = self.client.get(url) self.assertEqual(response.status_code, 302) # user already authorized the application, but with different scopes: prompt them. @@ -212,20 +298,25 @@ def test_pre_auth_approval_prompt_default(self): self.assertEqual(oauth2_settings.REQUEST_APPROVAL_PROMPT, "force") AccessToken.objects.create( - user=self.test_user, token="1234567890", + user=self.test_user, + token="1234567890", application=self.application, expires=timezone.now() + datetime.timedelta(days=1), - scope="read write" + scope="read write", ) self.client.login(username="test_user", password="123456") - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code", - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "http://example.org", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + query_string = urlencode( + { + "client_id": self.application.client_id, + "response_type": "code", + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + } + ) + url = "{url}?{qs}".format( + url=reverse("oauth2_provider:authorize"), qs=query_string + ) response = self.client.get(url) self.assertEqual(response.status_code, 200) @@ -233,20 +324,25 @@ def test_pre_auth_approval_prompt_default_override(self): oauth2_settings.REQUEST_APPROVAL_PROMPT = "auto" AccessToken.objects.create( - user=self.test_user, token="1234567890", + user=self.test_user, + token="1234567890", application=self.application, expires=timezone.now() + datetime.timedelta(days=1), - scope="read write" + scope="read write", ) self.client.login(username="test_user", password="123456") - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code", - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "http://example.org", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + query_string = urlencode( + { + "client_id": self.application.client_id, + "response_type": "code", + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + } + ) + url = "{url}?{qs}".format( + url=reverse("oauth2_provider:authorize"), qs=query_string + ) response = self.client.get(url) self.assertEqual(response.status_code, 302) @@ -256,11 +352,12 @@ def test_pre_auth_default_redirect(self): """ self.client.login(username="test_user", password="123456") - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + query_string = urlencode( + {"client_id": self.application.client_id, "response_type": "code", } + ) + url = "{url}?{qs}".format( + url=reverse("oauth2_provider:authorize"), qs=query_string + ) response = self.client.get(url) self.assertEqual(response.status_code, 200) @@ -274,12 +371,16 @@ def test_pre_auth_forbibben_redirect(self): """ self.client.login(username="test_user", password="123456") - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code", - "redirect_uri": "http://forbidden.it", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + query_string = urlencode( + { + "client_id": self.application.client_id, + "response_type": "code", + "redirect_uri": "http://forbidden.it", + } + ) + url = "{url}?{qs}".format( + url=reverse("oauth2_provider:authorize"), qs=query_string + ) response = self.client.get(url) self.assertEqual(response.status_code, 400) @@ -290,11 +391,12 @@ def test_pre_auth_wrong_response_type(self): """ self.client.login(username="test_user", password="123456") - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "WRONG", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + query_string = urlencode( + {"client_id": self.application.client_id, "response_type": "WRONG", } + ) + url = "{url}?{qs}".format( + url=reverse("oauth2_provider:authorize"), qs=query_string + ) response = self.client.get(url) self.assertEqual(response.status_code, 302) @@ -315,7 +417,32 @@ def test_code_post_auth_allow(self): "allow": True, } - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + response = self.client.post( + reverse("oauth2_provider:authorize"), data=form_data + ) + self.assertEqual(response.status_code, 302) + self.assertIn("http://example.org?", response["Location"]) + self.assertIn("state=random_state_string", response["Location"]) + self.assertIn("code=", response["Location"]) + + def test_id_token_code_post_auth_allow(self): + """ + Test authorization code is given for an allowed request with response_type: code + """ + self.client.login(username="test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.org", + "response_type": "code", + "allow": True, + } + + response = self.client.post( + reverse("oauth2_provider:authorize"), data=form_data + ) self.assertEqual(response.status_code, 302) self.assertIn("http://example.org?", response["Location"]) self.assertIn("state=random_state_string", response["Location"]) @@ -336,7 +463,9 @@ def test_code_post_auth_deny(self): "allow": False, } - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + response = self.client.post( + reverse("oauth2_provider:authorize"), data=form_data + ) self.assertEqual(response.status_code, 302) self.assertIn("error=access_denied", response["Location"]) self.assertIn("state=random_state_string", response["Location"]) @@ -355,7 +484,9 @@ def test_code_post_auth_deny_no_state(self): "allow": False, } - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + response = self.client.post( + reverse("oauth2_provider:authorize"), data=form_data + ) self.assertEqual(response.status_code, 302) self.assertIn("error=access_denied", response["Location"]) self.assertNotIn("state", response["Location"]) @@ -375,7 +506,9 @@ def test_code_post_auth_bad_responsetype(self): "allow": True, } - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + response = self.client.post( + reverse("oauth2_provider:authorize"), data=form_data + ) self.assertEqual(response.status_code, 302) self.assertIn("http://example.org?error", response["Location"]) @@ -394,7 +527,9 @@ def test_code_post_auth_forbidden_redirect_uri(self): "allow": True, } - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + response = self.client.post( + reverse("oauth2_provider:authorize"), data=form_data + ) self.assertEqual(response.status_code, 400) def test_code_post_auth_malicious_redirect_uri(self): @@ -412,7 +547,9 @@ def test_code_post_auth_malicious_redirect_uri(self): "allow": True, } - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + response = self.client.post( + reverse("oauth2_provider:authorize"), data=form_data + ) self.assertEqual(response.status_code, 400) def test_code_post_auth_allow_custom_redirect_uri_scheme(self): @@ -431,7 +568,9 @@ def test_code_post_auth_allow_custom_redirect_uri_scheme(self): "allow": True, } - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + response = self.client.post( + reverse("oauth2_provider:authorize"), data=form_data + ) self.assertEqual(response.status_code, 302) self.assertIn("custom-scheme://example.com?", response["Location"]) self.assertIn("state=random_state_string", response["Location"]) @@ -453,7 +592,9 @@ def test_code_post_auth_deny_custom_redirect_uri_scheme(self): "allow": False, } - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + response = self.client.post( + reverse("oauth2_provider:authorize"), data=form_data + ) self.assertEqual(response.status_code, 302) self.assertIn("custom-scheme://example.com?", response["Location"]) self.assertIn("error=access_denied", response["Location"]) @@ -476,7 +617,9 @@ def test_code_post_auth_redirection_uri_with_querystring(self): "allow": True, } - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + response = self.client.post( + reverse("oauth2_provider:authorize"), data=form_data + ) self.assertEqual(response.status_code, 302) self.assertIn("http://example.com?foo=bar", response["Location"]) self.assertIn("code=", response["Location"]) @@ -499,7 +642,9 @@ def test_code_post_auth_failing_redirection_uri_with_querystring(self): "allow": False, } - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + response = self.client.post( + reverse("oauth2_provider:authorize"), data=form_data + ) self.assertEqual(response.status_code, 302) self.assertIn("http://example.com?", response["Location"]) self.assertIn("error=access_denied", response["Location"]) @@ -521,25 +666,29 @@ def test_code_post_auth_fails_when_redirect_uri_path_is_invalid(self): "allow": True, } - response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + response = self.client.post( + reverse("oauth2_provider:authorize"), data=form_data + ) self.assertEqual(response.status_code, 400) class TestAuthorizationCodeTokenView(BaseTest): - def get_auth(self): + def get_auth(self, scope="read write"): """ Helper method to retrieve a valid authorization code """ authcode_data = { "client_id": self.application.client_id, "state": "random_state_string", - "scope": "read write", + "scope": scope, "redirect_uri": "http://example.org", "response_type": "code", "allow": True, } - response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) + response = self.client.post( + reverse("oauth2_provider:authorize"), data=authcode_data + ) query_dict = parse_qs(urlparse(response["Location"]).query) return query_dict["code"].pop() @@ -549,9 +698,13 @@ def generate_pkce_codes(self, algorithm, length=43): """ code_verifier = get_random_string(length) if algorithm == "S256": - code_challenge = base64.urlsafe_b64encode( - hashlib.sha256(code_verifier.encode()).digest() - ).decode().rstrip("=") + code_challenge = ( + base64.urlsafe_b64encode( + hashlib.sha256(code_verifier.encode()).digest() + ) + .decode() + .rstrip("=") + ) else: code_challenge = code_verifier return code_verifier, code_challenge @@ -572,7 +725,9 @@ def get_pkce_auth(self, code_challenge, code_challenge_method): "code_challenge_method": code_challenge_method, } - response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) + response = self.client.post( + reverse("oauth2_provider:authorize"), data=authcode_data + ) query_dict = parse_qs(urlparse(response["Location"]).query) oauth2_settings.PKCE_REQUIRED = False return query_dict["code"].pop() @@ -587,17 +742,23 @@ def test_basic_auth(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org" + "redirect_uri": "http://example.org", } - auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + auth_headers = get_basic_auth_header( + self.application.client_id, self.application.client_secret + ) - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data, **auth_headers + ) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + self.assertEqual( + content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS + ) def test_refresh(self): """ @@ -609,11 +770,15 @@ def test_refresh(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org" + "redirect_uri": "http://example.org", } - auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + auth_headers = get_basic_auth_header( + self.application.client_id, self.application.client_secret + ) - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data, **auth_headers + ) content = json.loads(response.content.decode("utf-8")) self.assertTrue("refresh_token" in content) @@ -622,23 +787,29 @@ def test_refresh(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org" + "redirect_uri": "http://example.org", } - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data, **auth_headers + ) token_request_data = { "grant_type": "refresh_token", "refresh_token": content["refresh_token"], "scope": content["scope"], } - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data, **auth_headers + ) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) self.assertTrue("access_token" in content) # check refresh token cannot be used twice - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data, **auth_headers + ) self.assertEqual(response.status_code, 400) content = json.loads(response.content.decode("utf-8")) self.assertTrue("invalid_grant" in content.values()) @@ -654,11 +825,15 @@ def test_refresh_with_grace_period(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org" + "redirect_uri": "http://example.org", } - auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + auth_headers = get_basic_auth_header( + self.application.client_id, self.application.client_secret + ) - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data, **auth_headers + ) content = json.loads(response.content.decode("utf-8")) self.assertTrue("refresh_token" in content) @@ -667,9 +842,11 @@ def test_refresh_with_grace_period(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org" + "redirect_uri": "http://example.org", } - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data, **auth_headers + ) token_request_data = { "grant_type": "refresh_token", @@ -677,7 +854,9 @@ def test_refresh_with_grace_period(self): "scope": content["scope"], } - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data, **auth_headers + ) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) @@ -686,7 +865,9 @@ def test_refresh_with_grace_period(self): first_refresh_token = content["refresh_token"] # check access token returns same data if used twice, see #497 - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data, **auth_headers + ) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) self.assertTrue("access_token" in content) @@ -706,11 +887,15 @@ def test_refresh_invalidates_old_tokens(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org" + "redirect_uri": "http://example.org", } - auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + auth_headers = get_basic_auth_header( + self.application.client_id, self.application.client_secret + ) - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data, **auth_headers + ) content = json.loads(response.content.decode("utf-8")) rt = content["refresh_token"] @@ -721,7 +906,9 @@ def test_refresh_invalidates_old_tokens(self): "refresh_token": rt, "scope": content["scope"], } - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data, **auth_headers + ) self.assertEqual(response.status_code, 200) refresh_token = RefreshToken.objects.filter(token=rt).first() @@ -738,11 +925,15 @@ def test_refresh_no_scopes(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org" + "redirect_uri": "http://example.org", } - auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + auth_headers = get_basic_auth_header( + self.application.client_id, self.application.client_secret + ) - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data, **auth_headers + ) content = json.loads(response.content.decode("utf-8")) self.assertTrue("refresh_token" in content) @@ -750,7 +941,9 @@ def test_refresh_no_scopes(self): "grant_type": "refresh_token", "refresh_token": content["refresh_token"], } - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data, **auth_headers + ) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) @@ -766,11 +959,15 @@ def test_refresh_bad_scopes(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org" + "redirect_uri": "http://example.org", } - auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + auth_headers = get_basic_auth_header( + self.application.client_id, self.application.client_secret + ) - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data, **auth_headers + ) content = json.loads(response.content.decode("utf-8")) self.assertTrue("refresh_token" in content) @@ -779,7 +976,9 @@ def test_refresh_bad_scopes(self): "refresh_token": content["refresh_token"], "scope": "read write nuke", } - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data, **auth_headers + ) self.assertEqual(response.status_code, 400) def test_refresh_fail_repeating_requests(self): @@ -792,11 +991,15 @@ def test_refresh_fail_repeating_requests(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org" + "redirect_uri": "http://example.org", } - auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + auth_headers = get_basic_auth_header( + self.application.client_id, self.application.client_secret + ) - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data, **auth_headers + ) content = json.loads(response.content.decode("utf-8")) self.assertTrue("refresh_token" in content) @@ -805,9 +1008,13 @@ def test_refresh_fail_repeating_requests(self): "refresh_token": content["refresh_token"], "scope": content["scope"], } - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data, **auth_headers + ) self.assertEqual(response.status_code, 200) - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data, **auth_headers + ) self.assertEqual(response.status_code, 400) def test_refresh_repeating_requests(self): @@ -822,11 +1029,15 @@ def test_refresh_repeating_requests(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org" + "redirect_uri": "http://example.org", } - auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + auth_headers = get_basic_auth_header( + self.application.client_id, self.application.client_secret + ) - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data, **auth_headers + ) content = json.loads(response.content.decode("utf-8")) self.assertTrue("refresh_token" in content) @@ -835,18 +1046,26 @@ def test_refresh_repeating_requests(self): "refresh_token": content["refresh_token"], "scope": content["scope"], } - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data, **auth_headers + ) self.assertEqual(response.status_code, 200) - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data, **auth_headers + ) self.assertEqual(response.status_code, 200) # try refreshing outside the refresh window, see #497 rt = RefreshToken.objects.get(token=content["refresh_token"]) self.assertIsNotNone(rt.revoked) - rt.revoked = timezone.now() - datetime.timedelta(minutes=10) # instead of mocking out datetime + rt.revoked = timezone.now() - datetime.timedelta( + minutes=10 + ) # instead of mocking out datetime rt.save() - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data, **auth_headers + ) self.assertEqual(response.status_code, 400) oauth2_settings.REFRESH_TOKEN_GRACE_PERIOD_SECONDS = 0 @@ -860,11 +1079,15 @@ def test_refresh_repeating_requests_non_rotating_tokens(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org" + "redirect_uri": "http://example.org", } - auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + auth_headers = get_basic_auth_header( + self.application.client_id, self.application.client_secret + ) - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data, **auth_headers + ) content = json.loads(response.content.decode("utf-8")) self.assertTrue("refresh_token" in content) @@ -875,9 +1098,13 @@ def test_refresh_repeating_requests_non_rotating_tokens(self): } oauth2_settings.ROTATE_REFRESH_TOKEN = False - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data, **auth_headers + ) self.assertEqual(response.status_code, 200) - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data, **auth_headers + ) self.assertEqual(response.status_code, 200) oauth2_settings.ROTATE_REFRESH_TOKEN = True @@ -891,11 +1118,15 @@ def test_basic_auth_bad_authcode(self): token_request_data = { "grant_type": "authorization_code", "code": "BLAH", - "redirect_uri": "http://example.org" + "redirect_uri": "http://example.org", } - auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + auth_headers = get_basic_auth_header( + self.application.client_id, self.application.client_secret + ) - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data, **auth_headers + ) self.assertEqual(response.status_code, 400) def test_basic_auth_bad_granttype(self): @@ -907,11 +1138,15 @@ def test_basic_auth_bad_granttype(self): token_request_data = { "grant_type": "UNKNOWN", "code": "BLAH", - "redirect_uri": "http://example.org" + "redirect_uri": "http://example.org", } - auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + auth_headers = get_basic_auth_header( + self.application.client_id, self.application.client_secret + ) - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data, **auth_headers + ) self.assertEqual(response.status_code, 400) def test_basic_auth_grant_expired(self): @@ -920,18 +1155,27 @@ def test_basic_auth_grant_expired(self): """ self.client.login(username="test_user", password="123456") g = Grant( - application=self.application, user=self.test_user, code="BLAH", - expires=timezone.now(), redirect_uri="", scope="") + application=self.application, + user=self.test_user, + code="BLAH", + expires=timezone.now(), + redirect_uri="", + scope="", + ) g.save() token_request_data = { "grant_type": "authorization_code", "code": "BLAH", - "redirect_uri": "http://example.org" + "redirect_uri": "http://example.org", } - auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + auth_headers = get_basic_auth_header( + self.application.client_id, self.application.client_secret + ) - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data, **auth_headers + ) self.assertEqual(response.status_code, 400) def test_basic_auth_bad_secret(self): @@ -944,11 +1188,13 @@ def test_basic_auth_bad_secret(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org" + "redirect_uri": "http://example.org", } auth_headers = get_basic_auth_header(self.application.client_id, "BOOM!") - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data, **auth_headers + ) self.assertEqual(response.status_code, 401) def test_basic_auth_wrong_auth_type(self): @@ -961,16 +1207,20 @@ def test_basic_auth_wrong_auth_type(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org" + "redirect_uri": "http://example.org", } - user_pass = "{0}:{1}".format(self.application.client_id, self.application.client_secret) + user_pass = "{0}:{1}".format( + self.application.client_id, self.application.client_secret + ) auth_string = base64.b64encode(user_pass.encode("utf-8")) auth_headers = { "HTTP_AUTHORIZATION": "Wrong " + auth_string.decode("utf-8"), } - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data, **auth_headers + ) self.assertEqual(response.status_code, 401) def test_request_body_params(self): @@ -988,13 +1238,17 @@ def test_request_body_params(self): "client_secret": self.application.client_secret, } - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data + ) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + self.assertEqual( + content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS + ) def test_public(self): """ @@ -1010,16 +1264,52 @@ def test_public(self): "grant_type": "authorization_code", "code": authorization_code, "redirect_uri": "http://example.org", - "client_id": self.application.client_id + "client_id": self.application.client_id, } - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data + ) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + self.assertEqual( + content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS + ) + + def test_id_token_public(self): + """ + Request an access token using client_type: public + """ + self.client.login(username="test_user", password="123456") + + self.application.client_type = Application.CLIENT_PUBLIC + self.application.save() + authorization_code = self.get_auth(scope="openid") + + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org", + "client_id": self.application.client_id, + "scope": "openid", + } + + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data + ) + self.assertEqual(response.status_code, 200) + + content = json.loads(response.content.decode("utf-8")) + self.assertEqual(content["token_type"], "Bearer") + self.assertEqual(content["scope"], "openid") + self.assertIn("access_token", content) + self.assertIn("id_token", content) + self.assertEqual( + content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS + ) def test_public_pkce_S256_authorize_get(self): """ @@ -1034,17 +1324,21 @@ def test_public_pkce_S256_authorize_get(self): code_verifier, code_challenge = self.generate_pkce_codes("S256") oauth2_settings.PKCE_REQUIRED = True - query_string = urlencode({ - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "http://example.org", - "response_type": "code", - "allow": True, - "code_challenge": code_challenge, - "code_challenge_method": "S256" - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + query_string = urlencode( + { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + "response_type": "code", + "allow": True, + "code_challenge": code_challenge, + "code_challenge_method": "S256", + } + ) + url = "{url}?{qs}".format( + url=reverse("oauth2_provider:authorize"), qs=query_string + ) response = self.client.get(url) self.assertEqual(response.status_code, 200) @@ -1063,17 +1357,21 @@ def test_public_pkce_plain_authorize_get(self): code_verifier, code_challenge = self.generate_pkce_codes("plain") oauth2_settings.PKCE_REQUIRED = True - query_string = urlencode({ - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "http://example.org", - "response_type": "code", - "allow": True, - "code_challenge": code_challenge, - "code_challenge_method": "plain" - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + query_string = urlencode( + { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + "response_type": "code", + "allow": True, + "code_challenge": code_challenge, + "code_challenge_method": "plain", + } + ) + url = "{url}?{qs}".format( + url=reverse("oauth2_provider:authorize"), qs=query_string + ) response = self.client.get(url) self.assertEqual(response.status_code, 200) @@ -1097,16 +1395,20 @@ def test_public_pkce_S256(self): "code": authorization_code, "redirect_uri": "http://example.org", "client_id": self.application.client_id, - "code_verifier": code_verifier + "code_verifier": code_verifier, } - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data + ) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + self.assertEqual( + content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS + ) oauth2_settings.PKCE_REQUIRED = False def test_public_pkce_plain(self): @@ -1127,16 +1429,20 @@ def test_public_pkce_plain(self): "code": authorization_code, "redirect_uri": "http://example.org", "client_id": self.application.client_id, - "code_verifier": code_verifier + "code_verifier": code_verifier, } - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data + ) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + self.assertEqual( + content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS + ) oauth2_settings.PKCE_REQUIRED = False def test_public_pkce_invalid_algorithm(self): @@ -1151,17 +1457,21 @@ def test_public_pkce_invalid_algorithm(self): code_verifier, code_challenge = self.generate_pkce_codes("invalid") oauth2_settings.PKCE_REQUIRED = True - query_string = urlencode({ - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "http://example.org", - "response_type": "code", - "allow": True, - "code_challenge": code_challenge, - "code_challenge_method": "invalid", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + query_string = urlencode( + { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + "response_type": "code", + "allow": True, + "code_challenge": code_challenge, + "code_challenge_method": "invalid", + } + ) + url = "{url}?{qs}".format( + url=reverse("oauth2_provider:authorize"), qs=query_string + ) response = self.client.get(url) self.assertEqual(response.status_code, 302) @@ -1181,16 +1491,20 @@ def test_public_pkce_missing_code_challenge(self): code_verifier, code_challenge = self.generate_pkce_codes("S256") oauth2_settings.PKCE_REQUIRED = True - query_string = urlencode({ - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "http://example.org", - "response_type": "code", - "allow": True, - "code_challenge_method": "S256" - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + query_string = urlencode( + { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + "response_type": "code", + "allow": True, + "code_challenge_method": "S256", + } + ) + url = "{url}?{qs}".format( + url=reverse("oauth2_provider:authorize"), qs=query_string + ) response = self.client.get(url) self.assertEqual(response.status_code, 302) @@ -1209,16 +1523,20 @@ def test_public_pkce_missing_code_challenge_method(self): code_verifier, code_challenge = self.generate_pkce_codes("S256") oauth2_settings.PKCE_REQUIRED = True - query_string = urlencode({ - "client_id": self.application.client_id, - "state": "random_state_string", - "scope": "read write", - "redirect_uri": "http://example.org", - "response_type": "code", - "allow": True, - "code_challenge": code_challenge - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + query_string = urlencode( + { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + "response_type": "code", + "allow": True, + "code_challenge": code_challenge, + } + ) + url = "{url}?{qs}".format( + url=reverse("oauth2_provider:authorize"), qs=query_string + ) response = self.client.get(url) self.assertEqual(response.status_code, 200) @@ -1242,10 +1560,12 @@ def test_public_pkce_S256_invalid_code_verifier(self): "code": authorization_code, "redirect_uri": "http://example.org", "client_id": self.application.client_id, - "code_verifier": "invalid" + "code_verifier": "invalid", } - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data + ) self.assertEqual(response.status_code, 400) oauth2_settings.PKCE_REQUIRED = False @@ -1267,10 +1587,12 @@ def test_public_pkce_plain_invalid_code_verifier(self): "code": authorization_code, "redirect_uri": "http://example.org", "client_id": self.application.client_id, - "code_verifier": "invalid" + "code_verifier": "invalid", } - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data + ) self.assertEqual(response.status_code, 400) oauth2_settings.PKCE_REQUIRED = False @@ -1291,10 +1613,12 @@ def test_public_pkce_S256_missing_code_verifier(self): "grant_type": "authorization_code", "code": authorization_code, "redirect_uri": "http://example.org", - "client_id": self.application.client_id + "client_id": self.application.client_id, } - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data + ) self.assertEqual(response.status_code, 400) oauth2_settings.PKCE_REQUIRED = False @@ -1315,10 +1639,12 @@ def test_public_pkce_plain_missing_code_verifier(self): "grant_type": "authorization_code", "code": authorization_code, "redirect_uri": "http://example.org", - "client_id": self.application.client_id + "client_id": self.application.client_id, } - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data + ) self.assertEqual(response.status_code, 400) oauth2_settings.PKCE_REQUIRED = False @@ -1337,14 +1663,19 @@ def test_malicious_redirect_uri(self): "grant_type": "authorization_code", "code": authorization_code, "redirect_uri": "/../", - "client_id": self.application.client_id + "client_id": self.application.client_id, } - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data + ) self.assertEqual(response.status_code, 400) data = response.json() self.assertEqual(data["error"], "invalid_request") - self.assertEqual(data["error_description"], oauthlib_errors.MismatchingRedirectURIError.description) + self.assertEqual( + data["error_description"], + oauthlib_errors.MismatchingRedirectURIError.description, + ) def test_code_exchange_succeed_when_redirect_uri_match(self): """ @@ -1361,7 +1692,9 @@ def test_code_exchange_succeed_when_redirect_uri_match(self): "response_type": "code", "allow": True, } - response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) + response = self.client.post( + reverse("oauth2_provider:authorize"), data=authcode_data + ) query_dict = parse_qs(urlparse(response["Location"]).query) authorization_code = query_dict["code"].pop() @@ -1369,17 +1702,23 @@ def test_code_exchange_succeed_when_redirect_uri_match(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org?foo=bar" + "redirect_uri": "http://example.org?foo=bar", } - auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + auth_headers = get_basic_auth_header( + self.application.client_id, self.application.client_secret + ) - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data, **auth_headers + ) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + self.assertEqual( + content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS + ) def test_code_exchange_fails_when_redirect_uri_does_not_match(self): """ @@ -1396,7 +1735,9 @@ def test_code_exchange_fails_when_redirect_uri_does_not_match(self): "response_type": "code", "allow": True, } - response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) + response = self.client.post( + reverse("oauth2_provider:authorize"), data=authcode_data + ) query_dict = parse_qs(urlparse(response["Location"]).query) authorization_code = query_dict["code"].pop() @@ -1404,17 +1745,26 @@ def test_code_exchange_fails_when_redirect_uri_does_not_match(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org?foo=baraa" + "redirect_uri": "http://example.org?foo=baraa", } - auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + auth_headers = get_basic_auth_header( + self.application.client_id, self.application.client_secret + ) - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data, **auth_headers + ) self.assertEqual(response.status_code, 400) data = response.json() self.assertEqual(data["error"], "invalid_request") - self.assertEqual(data["error_description"], oauthlib_errors.MismatchingRedirectURIError.description) + self.assertEqual( + data["error_description"], + oauthlib_errors.MismatchingRedirectURIError.description, + ) - def test_code_exchange_succeed_when_redirect_uri_match_with_multiple_query_params(self): + def test_code_exchange_succeed_when_redirect_uri_match_with_multiple_query_params( + self, + ): """ Tests code exchange succeed when redirect uri matches the one used for code request """ @@ -1431,7 +1781,9 @@ def test_code_exchange_succeed_when_redirect_uri_match_with_multiple_query_param "response_type": "code", "allow": True, } - response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) + response = self.client.post( + reverse("oauth2_provider:authorize"), data=authcode_data + ) query_dict = parse_qs(urlparse(response["Location"]).query) authorization_code = query_dict["code"].pop() @@ -1439,17 +1791,72 @@ def test_code_exchange_succeed_when_redirect_uri_match_with_multiple_query_param token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.com?bar=baz&foo=bar" + "redirect_uri": "http://example.com?bar=baz&foo=bar", } - auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + auth_headers = get_basic_auth_header( + self.application.client_id, self.application.client_secret + ) - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data, **auth_headers + ) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + self.assertEqual( + content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS + ) + + def test_id_token_code_exchange_succeed_when_redirect_uri_match_with_multiple_query_params( + self, + ): + """ + Tests code exchange succeed when redirect uri matches the one used for code request + """ + self.client.login(username="test_user", password="123456") + self.application.redirect_uris = "http://localhost http://example.com?foo=bar" + self.application.save() + + # retrieve a valid authorization code + authcode_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.com?bar=baz&foo=bar", + "response_type": "code", + "allow": True, + } + response = self.client.post( + reverse("oauth2_provider:authorize"), data=authcode_data + ) + query_dict = parse_qs(urlparse(response["Location"]).query) + authorization_code = query_dict["code"].pop() + + # exchange authorization code for a valid access token + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.com?bar=baz&foo=bar", + } + auth_headers = get_basic_auth_header( + self.application.client_id, self.application.client_secret + ) + + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data, **auth_headers + ) + self.assertEqual(response.status_code, 200) + + content = json.loads(response.content.decode("utf-8")) + self.assertEqual(content["token_type"], "Bearer") + self.assertEqual(content["scope"], "openid") + self.assertIn("access_token", content) + self.assertIn("id_token", content) + self.assertEqual( + content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS + ) def test_oob_as_html(self): """ @@ -1512,7 +1919,9 @@ def test_oob_as_json(self): "allow": True, } - response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) + response = self.client.post( + reverse("oauth2_provider:authorize"), data=authcode_data + ) self.assertEqual(response.status_code, 200) self.assertRegex(response["Content-Type"], "^application/json") @@ -1529,13 +1938,17 @@ def test_oob_as_json(self): "client_secret": self.application.client_secret, } - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data + ) self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + self.assertEqual( + content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS + ) class TestAuthorizationCodeProtectedResource(BaseTest): @@ -1551,7 +1964,9 @@ def test_resource_access_allowed(self): "response_type": "code", "allow": True, } - response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) + response = self.client.post( + reverse("oauth2_provider:authorize"), data=authcode_data + ) query_dict = parse_qs(urlparse(response["Location"]).query) authorization_code = query_dict["code"].pop() @@ -1559,11 +1974,15 @@ def test_resource_access_allowed(self): token_request_data = { "grant_type": "authorization_code", "code": authorization_code, - "redirect_uri": "http://example.org" + "redirect_uri": "http://example.org", } - auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + auth_headers = get_basic_auth_header( + self.application.client_id, self.application.client_secret + ) - response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data, **auth_headers + ) content = json.loads(response.content.decode("utf-8")) access_token = content["access_token"] @@ -1578,6 +1997,63 @@ def test_resource_access_allowed(self): response = view(request) self.assertEqual(response, "This is a protected resource") + def test_id_token_resource_access_allowed(self): + self.client.login(username="test_user", password="123456") + + # retrieve a valid authorization code + authcode_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.org", + "response_type": "code", + "allow": True, + } + response = self.client.post( + reverse("oauth2_provider:authorize"), data=authcode_data + ) + query_dict = parse_qs(urlparse(response["Location"]).query) + authorization_code = query_dict["code"].pop() + + # exchange authorization code for a valid access token + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org", + } + auth_headers = get_basic_auth_header( + self.application.client_id, self.application.client_secret + ) + + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data, **auth_headers + ) + content = json.loads(response.content.decode("utf-8")) + access_token = content["access_token"] + id_token = content["id_token"] + + # use token to access the resource + auth_headers = { + "HTTP_AUTHORIZATION": "Bearer " + access_token, + } + request = self.factory.get("/fake-resource", **auth_headers) + request.user = self.test_user + + view = ResourceView.as_view() + response = view(request) + self.assertEqual(response, "This is a protected resource") + + # use id_token to access the resource + auth_headers = { + "HTTP_AUTHORIZATION": "Bearer " + id_token, + } + request = self.factory.get("/fake-resource", **auth_headers) + request.user = self.test_user + + view = ResourceView.as_view() + response = view(request) + self.assertEqual(response, "This is a protected resource") + def test_resource_access_deny(self): auth_headers = { "HTTP_AUTHORIZATION": "Bearer " + "faketoken", @@ -1591,7 +2067,6 @@ def test_resource_access_deny(self): class TestDefaultScopes(BaseTest): - def test_pre_auth_default_scopes(self): """ Test response for a valid client_id with response_type: code using default scopes @@ -1599,13 +2074,17 @@ def test_pre_auth_default_scopes(self): self.client.login(username="test_user", password="123456") oauth2_settings._DEFAULT_SCOPES = ["read"] - query_string = urlencode({ - "client_id": self.application.client_id, - "response_type": "code", - "state": "random_state_string", - "redirect_uri": "http://example.org", - }) - url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + query_string = urlencode( + { + "client_id": self.application.client_id, + "response_type": "code", + "state": "random_state_string", + "redirect_uri": "http://example.org", + } + ) + url = "{url}?{qs}".format( + url=reverse("oauth2_provider:authorize"), qs=query_string + ) response = self.client.get(url) self.assertEqual(response.status_code, 200) diff --git a/tests/test_hybrid.py b/tests/test_hybrid.py new file mode 100644 index 000000000..1f45aeeec --- /dev/null +++ b/tests/test_hybrid.py @@ -0,0 +1,1264 @@ +import base64 +import datetime +import json +from urllib.parse import parse_qs, urlencode, urlparse + +from django.contrib.auth import get_user_model +from django.test import RequestFactory, TestCase +from django.urls import reverse +from django.utils import timezone +from oauthlib.oauth2.rfc6749 import errors as oauthlib_errors + +from oauth2_provider.models import ( + get_access_token_model, get_application_model, + get_grant_model, get_refresh_token_model +) +from oauth2_provider.settings import oauth2_settings +from oauth2_provider.views import ProtectedResourceView + +from .utils import get_basic_auth_header + + +Application = get_application_model() +AccessToken = get_access_token_model() +Grant = get_grant_model() +RefreshToken = get_refresh_token_model() +UserModel = get_user_model() + + +# mocking a protected resource view +class ResourceView(ProtectedResourceView): + def get(self, request, *args, **kwargs): + return "This is a protected resource" + + +class BaseTest(TestCase): + def setUp(self): + self.factory = RequestFactory() + self.hy_test_user = UserModel.objects.create_user("hy_test_user", "test_hy@example.com", "123456") + self.hy_dev_user = UserModel.objects.create_user("hy_dev_user", "dev_hy@example.com", "123456") + + oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES = ["http", "custom-scheme"] + + self.application = Application( + name="Hybrid Test Application", + redirect_uris=( + "http://localhost http://example.com http://example.org custom-scheme://example.com" + ), + user=self.hy_dev_user, + client_type=Application.CLIENT_CONFIDENTIAL, + authorization_grant_type=Application.GRANT_OPENID_HYBRID, + ) + self.application.save() + + oauth2_settings._SCOPES = ["read", "write", "openid"] + oauth2_settings._DEFAULT_SCOPES = ["read", "write"] + oauth2_settings.SCOPES = { + "read": "Reading scope", + "write": "Writing scope", + "openid": "OpenID connect" + } + + def tearDown(self): + self.application.delete() + self.hy_test_user.delete() + self.hy_dev_user.delete() + + +class TestRegressionIssue315Hybrid(BaseTest): + """ + Test to avoid regression for the issue 315: request object + was being reassigned when getting AuthorizationView + """ + + def test_request_is_not_overwritten_code_token(self): + self.client.login(username="hy_test_user", password="123456") + query_string = urlencode({ + "client_id": self.application.client_id, + "response_type": "code token", + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "http://example.org", + }) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 200) + assert "request" not in response.context_data + + def test_request_is_not_overwritten_code_id_token(self): + self.client.login(username="hy_test_user", password="123456") + query_string = urlencode({ + "client_id": self.application.client_id, + "response_type": "code id_token", + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "http://example.org", + "nonce": "nonce", + }) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 200) + assert "request" not in response.context_data + + def test_request_is_not_overwritten_code_id_token_token(self): + self.client.login(username="hy_test_user", password="123456") + query_string = urlencode({ + "client_id": self.application.client_id, + "response_type": "code id_token token", + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "http://example.org", + "nonce": "nonce", + }) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 200) + assert "request" not in response.context_data + + +class TestHybridView(BaseTest): + def test_skip_authorization_completely(self): + """ + If application.skip_authorization = True, should skip the authorization page. + """ + self.client.login(username="hy_test_user", password="123456") + self.application.skip_authorization = True + self.application.save() + + query_string = urlencode({ + "client_id": self.application.client_id, + "response_type": "code", + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + }) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 302) + + def test_id_token_skip_authorization_completely(self): + """ + If application.skip_authorization = True, should skip the authorization page. + """ + self.client.login(username="hy_test_user", password="123456") + self.application.skip_authorization = True + self.application.save() + + query_string = urlencode({ + "client_id": self.application.client_id, + "response_type": "code", + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.org", + }) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 302) + + def test_pre_auth_invalid_client(self): + """ + Test error for an invalid client_id with response_type: code + """ + self.client.login(username="hy_test_user", password="123456") + + query_string = urlencode({ + "client_id": "fakeclientid", + "response_type": "code", + }) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 400) + self.assertEqual( + response.context_data["url"], + "?error=invalid_request&error_description=Invalid+client_id+parameter+value." + ) + + def test_pre_auth_valid_client(self): + """ + Test response for a valid client_id with response_type: code + """ + self.client.login(username="hy_test_user", password="123456") + + query_string = urlencode({ + "client_id": self.application.client_id, + "response_type": "code id_token", + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + }) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 200) + + # check form is in context and form params are valid + self.assertIn("form", response.context) + + form = response.context["form"] + self.assertEqual(form["redirect_uri"].value(), "http://example.org") + self.assertEqual(form["state"].value(), "random_state_string") + self.assertEqual(form["scope"].value(), "read write") + self.assertEqual(form["client_id"].value(), self.application.client_id) + + def test_id_token_pre_auth_valid_client(self): + """ + Test response for a valid client_id with response_type: code + """ + self.client.login(username="hy_test_user", password="123456") + + query_string = urlencode({ + "client_id": self.application.client_id, + "response_type": "code id_token", + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.org", + "nonce": "nonce", + }) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 200) + + # check form is in context and form params are valid + self.assertIn("form", response.context) + + form = response.context["form"] + self.assertEqual(form["redirect_uri"].value(), "http://example.org") + self.assertEqual(form["state"].value(), "random_state_string") + self.assertEqual(form["scope"].value(), "openid") + self.assertEqual(form["client_id"].value(), self.application.client_id) + + def test_pre_auth_valid_client_custom_redirect_uri_scheme(self): + """ + Test response for a valid client_id with response_type: code + using a non-standard, but allowed, redirect_uri scheme. + """ + self.client.login(username="hy_test_user", password="123456") + + query_string = urlencode({ + "client_id": self.application.client_id, + "response_type": "code id_token", + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "custom-scheme://example.com", + }) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 200) + + # check form is in context and form params are valid + self.assertIn("form", response.context) + + form = response.context["form"] + self.assertEqual(form["redirect_uri"].value(), "custom-scheme://example.com") + self.assertEqual(form["state"].value(), "random_state_string") + self.assertEqual(form["scope"].value(), "read write") + self.assertEqual(form["client_id"].value(), self.application.client_id) + + def test_pre_auth_approval_prompt(self): + tok = AccessToken.objects.create( + user=self.hy_test_user, token="1234567890", + application=self.application, + expires=timezone.now() + datetime.timedelta(days=1), + scope="read write" + ) + self.client.login(username="hy_test_user", password="123456") + query_string = urlencode({ + "client_id": self.application.client_id, + "response_type": "code id_token", + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + "approval_prompt": "auto", + }) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + response = self.client.get(url) + self.assertEqual(response.status_code, 302) + # user already authorized the application, but with different scopes: prompt them. + tok.scope = "read" + tok.save() + response = self.client.get(url) + self.assertEqual(response.status_code, 200) + + def test_pre_auth_approval_prompt_default(self): + oauth2_settings.REQUEST_APPROVAL_PROMPT = "force" + self.assertEqual(oauth2_settings.REQUEST_APPROVAL_PROMPT, "force") + + AccessToken.objects.create( + user=self.hy_test_user, token="1234567890", + application=self.application, + expires=timezone.now() + datetime.timedelta(days=1), + scope="read write" + ) + self.client.login(username="hy_test_user", password="123456") + query_string = urlencode({ + "client_id": self.application.client_id, + "response_type": "code id_token", + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + }) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + response = self.client.get(url) + self.assertEqual(response.status_code, 200) + + def test_pre_auth_approval_prompt_default_override(self): + oauth2_settings.REQUEST_APPROVAL_PROMPT = "auto" + + AccessToken.objects.create( + user=self.hy_test_user, token="1234567890", + application=self.application, + expires=timezone.now() + datetime.timedelta(days=1), + scope="read write" + ) + self.client.login(username="hy_test_user", password="123456") + query_string = urlencode({ + "client_id": self.application.client_id, + "response_type": "code", + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + }) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + response = self.client.get(url) + self.assertEqual(response.status_code, 302) + + def test_pre_auth_default_redirect(self): + """ + Test for default redirect uri if omitted from query string with response_type: code + """ + self.client.login(username="hy_test_user", password="123456") + + query_string = urlencode({ + "client_id": self.application.client_id, + "response_type": "code id_token", + }) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 200) + + form = response.context["form"] + self.assertEqual(form["redirect_uri"].value(), "http://localhost") + + def test_pre_auth_forbibben_redirect(self): + """ + Test error when passing a forbidden redirect_uri in query string with response_type: code + """ + self.client.login(username="hy_test_user", password="123456") + + query_string = urlencode({ + "client_id": self.application.client_id, + "response_type": "code", + "redirect_uri": "http://forbidden.it", + }) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 400) + + def test_pre_auth_wrong_response_type(self): + """ + Test error when passing a wrong response_type in query string + """ + self.client.login(username="hy_test_user", password="123456") + + query_string = urlencode({ + "client_id": self.application.client_id, + "response_type": "WRONG", + }) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 302) + self.assertIn("error=unsupported_response_type", response["Location"]) + + def test_code_post_auth_allow_code_token(self): + """ + Test authorization code is given for an allowed request with response_type: code + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "http://example.org", + "response_type": "code token", + "allow": True, + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("http://example.org", response["Location"]) + self.assertIn("state=random_state_string", response["Location"]) + self.assertIn("code=", response["Location"]) + self.assertIn("access_token=", response["Location"]) + + def test_code_post_auth_allow_code_id_token(self): + """ + Test authorization code is given for an allowed request with response_type: code + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "http://example.org", + "response_type": "code id_token", + "allow": True, + "nonce": "nonce", + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("http://example.org", response["Location"]) + self.assertIn("state=random_state_string", response["Location"]) + self.assertIn("code=", response["Location"]) + self.assertIn("id_token=", response["Location"]) + + def test_code_post_auth_allow_code_id_token_token(self): + """ + Test authorization code is given for an allowed request with response_type: code + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "http://example.org", + "response_type": "code id_token token", + "allow": True, + "nonce": "nonce", + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("http://example.org", response["Location"]) + self.assertIn("state=random_state_string", response["Location"]) + self.assertIn("code=", response["Location"]) + self.assertIn("id_token=", response["Location"]) + self.assertIn("access_token=", response["Location"]) + + def test_id_token_code_post_auth_allow(self): + """ + Test authorization code is given for an allowed request with response_type: code + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.org", + "response_type": "code id_token", + "allow": True, + "nonce": "nonce", + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("http://example.org", response["Location"]) + self.assertIn("state=random_state_string", response["Location"]) + self.assertIn("code=", response["Location"]) + self.assertIn("id_token=", response["Location"]) + + def test_code_post_auth_deny(self): + """ + Test error when resource owner deny access + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + "response_type": "code", + "allow": False, + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("error=access_denied", response["Location"]) + + def test_code_post_auth_bad_responsetype(self): + """ + Test authorization code is given for an allowed request with a response_type not supported + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + "response_type": "UNKNOWN", + "allow": True, + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("http://example.org?error", response["Location"]) + + def test_code_post_auth_forbidden_redirect_uri(self): + """ + Test authorization code is given for an allowed request with a forbidden redirect_uri + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://forbidden.it", + "response_type": "code", + "allow": True, + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 400) + + def test_code_post_auth_malicious_redirect_uri(self): + """ + Test validation of a malicious redirect_uri + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "/../", + "response_type": "code", + "allow": True, + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 400) + + def test_code_post_auth_allow_custom_redirect_uri_scheme_code_token(self): + """ + Test authorization code is given for an allowed request with response_type: code + using a non-standard, but allowed, redirect_uri scheme. + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "custom-scheme://example.com", + "response_type": "code token", + "allow": True, + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("custom-scheme://example.com", response["Location"]) + self.assertIn("state=random_state_string", response["Location"]) + self.assertIn("code=", response["Location"]) + self.assertIn("access_token=", response["Location"]) + + def test_code_post_auth_allow_custom_redirect_uri_scheme_code_id_token(self): + """ + Test authorization code is given for an allowed request with response_type: code + using a non-standard, but allowed, redirect_uri scheme. + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "custom-scheme://example.com", + "response_type": "code id_token", + "allow": True, + "nonce": "nonce", + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("custom-scheme://example.com", response["Location"]) + self.assertIn("state=random_state_string", response["Location"]) + self.assertIn("code=", response["Location"]) + self.assertIn("id_token=", response["Location"]) + + def test_code_post_auth_allow_custom_redirect_uri_scheme_code_id_token_token(self): + """ + Test authorization code is given for an allowed request with response_type: code + using a non-standard, but allowed, redirect_uri scheme. + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "custom-scheme://example.com", + "response_type": "code id_token token", + "allow": True, + "nonce": "nonce", + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("custom-scheme://example.com", response["Location"]) + self.assertIn("state=random_state_string", response["Location"]) + self.assertIn("code=", response["Location"]) + self.assertIn("id_token=", response["Location"]) + self.assertIn("access_token=", response["Location"]) + + def test_code_post_auth_deny_custom_redirect_uri_scheme(self): + """ + Test error when resource owner deny access + using a non-standard, but allowed, redirect_uri scheme. + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "custom-scheme://example.com", + "response_type": "code", + "allow": False, + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("custom-scheme://example.com?", response["Location"]) + self.assertIn("error=access_denied", response["Location"]) + + def test_code_post_auth_redirection_uri_with_querystring_code_token(self): + """ + Tests that a redirection uri with query string is allowed + and query string is retained on redirection. + See http://tools.ietf.org/html/rfc6749#section-3.1.2 + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "http://example.com?foo=bar", + "response_type": "code token", + "allow": True, + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("http://example.com?foo=bar", response["Location"]) + self.assertIn("code=", response["Location"]) + self.assertIn("access_token=", response["Location"]) + + def test_code_post_auth_redirection_uri_with_querystring_code_id_token(self): + """ + Tests that a redirection uri with query string is allowed + and query string is retained on redirection. + See http://tools.ietf.org/html/rfc6749#section-3.1.2 + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "http://example.com?foo=bar", + "response_type": "code id_token", + "allow": True, + "nonce": "nonce", + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("http://example.com?foo=bar", response["Location"]) + self.assertIn("code=", response["Location"]) + self.assertIn("id_token=", response["Location"]) + + def test_code_post_auth_redirection_uri_with_querystring_code_id_token_token(self): + """ + Tests that a redirection uri with query string is allowed + and query string is retained on redirection. + See http://tools.ietf.org/html/rfc6749#section-3.1.2 + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "http://example.com?foo=bar", + "response_type": "code id_token token", + "allow": True, + "nonce": "nonce", + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("http://example.com?foo=bar", response["Location"]) + self.assertIn("code=", response["Location"]) + self.assertIn("id_token=", response["Location"]) + self.assertIn("access_token=", response["Location"]) + + def test_code_post_auth_failing_redirection_uri_with_querystring(self): + """ + Test that in case of error the querystring of the redirection uri is preserved + + See https://github.com/evonove/django-oauth-toolkit/issues/238 + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.com?foo=bar", + "response_type": "code", + "allow": False, + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertEqual( + "http://example.com?foo=bar&error=access_denied&state=random_state_string", response["Location"] + ) + + def test_code_post_auth_fails_when_redirect_uri_path_is_invalid(self): + """ + Tests that a redirection uri is matched using scheme + netloc + path + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.com/a?foo=bar", + "response_type": "code", + "allow": True, + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 400) + + +class TestHybridTokenView(BaseTest): + def get_auth(self, scope="read write"): + """ + Helper method to retrieve a valid authorization code + """ + authcode_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": scope, + "redirect_uri": "http://example.org", + "response_type": "code id_token", + "allow": True, + "nonce": "nonce", + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) + fragment_dict = parse_qs(urlparse(response["Location"]).fragment) + return fragment_dict["code"].pop() + + def test_basic_auth(self): + """ + Request an access token using basic authentication for client authentication + """ + self.client.login(username="hy_test_user", password="123456") + authorization_code = self.get_auth() + + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org" + } + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + self.assertEqual(response.status_code, 200) + + content = json.loads(response.content.decode("utf-8")) + self.assertEqual(content["token_type"], "Bearer") + self.assertEqual(content["scope"], "read write") + self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + + def test_basic_auth_bad_authcode(self): + """ + Request an access token using a bad authorization code + """ + self.client.login(username="hy_test_user", password="123456") + + token_request_data = { + "grant_type": "authorization_code", + "code": "BLAH", + "redirect_uri": "http://example.org" + } + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + self.assertEqual(response.status_code, 400) + + def test_basic_auth_bad_granttype(self): + """ + Request an access token using a bad grant_type string + """ + self.client.login(username="hy_test_user", password="123456") + + token_request_data = { + "grant_type": "UNKNOWN", + "code": "BLAH", + "redirect_uri": "http://example.org" + } + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + self.assertEqual(response.status_code, 400) + + def test_basic_auth_grant_expired(self): + """ + Request an access token using an expired grant token + """ + self.client.login(username="hy_test_user", password="123456") + g = Grant( + application=self.application, user=self.hy_test_user, code="BLAH", + expires=timezone.now(), redirect_uri="", scope="") + g.save() + + token_request_data = { + "grant_type": "authorization_code", + "code": "BLAH", + "redirect_uri": "http://example.org" + } + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + self.assertEqual(response.status_code, 400) + + def test_basic_auth_bad_secret(self): + """ + Request an access token using basic authentication for client authentication + """ + self.client.login(username="hy_test_user", password="123456") + authorization_code = self.get_auth() + + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org" + } + auth_headers = get_basic_auth_header(self.application.client_id, "BOOM!") + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + self.assertEqual(response.status_code, 401) + + def test_basic_auth_wrong_auth_type(self): + """ + Request an access token using basic authentication for client authentication + """ + self.client.login(username="hy_test_user", password="123456") + authorization_code = self.get_auth() + + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org" + } + + user_pass = "{0}:{1}".format(self.application.client_id, self.application.client_secret) + auth_string = base64.b64encode(user_pass.encode("utf-8")) + auth_headers = { + "HTTP_AUTHORIZATION": "Wrong " + auth_string.decode("utf-8"), + } + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + self.assertEqual(response.status_code, 401) + + def test_request_body_params(self): + """ + Request an access token using client_type: public + """ + self.client.login(username="hy_test_user", password="123456") + authorization_code = self.get_auth() + + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org", + "client_id": self.application.client_id, + "client_secret": self.application.client_secret, + } + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) + self.assertEqual(response.status_code, 200) + + content = json.loads(response.content.decode("utf-8")) + self.assertEqual(content["token_type"], "Bearer") + self.assertEqual(content["scope"], "read write") + self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + + def test_public(self): + """ + Request an access token using client_type: public + """ + self.client.login(username="hy_test_user", password="123456") + + self.application.client_type = Application.CLIENT_PUBLIC + self.application.save() + authorization_code = self.get_auth() + + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org", + "client_id": self.application.client_id + } + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) + self.assertEqual(response.status_code, 200) + + content = json.loads(response.content.decode("utf-8")) + self.assertEqual(content["token_type"], "Bearer") + self.assertEqual(content["scope"], "read write") + self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + + def test_id_token_public(self): + """ + Request an access token using client_type: public + """ + self.client.login(username="hy_test_user", password="123456") + + self.application.client_type = Application.CLIENT_PUBLIC + self.application.save() + authorization_code = self.get_auth(scope="openid") + + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org", + "client_id": self.application.client_id, + "scope": "openid", + } + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) + self.assertEqual(response.status_code, 200) + + content = json.loads(response.content.decode("utf-8")) + self.assertEqual(content["token_type"], "Bearer") + self.assertEqual(content["scope"], "openid") + self.assertIn("access_token", content) + self.assertIn("id_token", content) + self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + + def test_malicious_redirect_uri(self): + """ + Request an access token using client_type: public and ensure redirect_uri is + properly validated. + """ + self.client.login(username="hy_test_user", password="123456") + + self.application.client_type = Application.CLIENT_PUBLIC + self.application.save() + authorization_code = self.get_auth() + + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "/../", + "client_id": self.application.client_id + } + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) + self.assertEqual(response.status_code, 400) + data = response.json() + self.assertEqual(data["error"], "invalid_request") + self.assertEqual(data["error_description"], oauthlib_errors.MismatchingRedirectURIError.description) + + def test_code_exchange_succeed_when_redirect_uri_match(self): + """ + Tests code exchange succeed when redirect uri matches the one used for code request + """ + self.client.login(username="hy_test_user", password="123456") + + # retrieve a valid authorization code + authcode_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "http://example.org?foo=bar", + "response_type": "code token", + "allow": True, + } + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) + fragment_dict = parse_qs(urlparse(response["Location"]).fragment) + authorization_code = fragment_dict["code"].pop() + + # exchange authorization code for a valid access token + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org?foo=bar" + } + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + self.assertEqual(response.status_code, 200) + + content = json.loads(response.content.decode("utf-8")) + self.assertEqual(content["token_type"], "Bearer") + self.assertEqual(content["scope"], "openid read write") + self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + + def test_code_exchange_fails_when_redirect_uri_does_not_match(self): + """ + Tests code exchange fails when redirect uri does not match the one used for code request + """ + self.client.login(username="hy_test_user", password="123456") + + # retrieve a valid authorization code + authcode_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "http://example.org?foo=bar", + "response_type": "code token", + "allow": True, + } + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) + query_dict = parse_qs(urlparse(response["Location"]).fragment) + authorization_code = query_dict["code"].pop() + + # exchange authorization code for a valid access token + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org?foo=baraa" + } + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + self.assertEqual(response.status_code, 400) + data = response.json() + self.assertEqual(data["error"], "invalid_request") + self.assertEqual(data["error_description"], oauthlib_errors.MismatchingRedirectURIError.description) + + def test_code_exchange_succeed_when_redirect_uri_match_with_multiple_query_params(self): + """ + Tests code exchange succeed when redirect uri matches the one used for code request + """ + self.client.login(username="hy_test_user", password="123456") + self.application.redirect_uris = "http://localhost http://example.com?foo=bar" + self.application.save() + + # retrieve a valid authorization code + authcode_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "http://example.com?bar=baz&foo=bar", + "response_type": "code token", + "allow": True, + } + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) + fragment_dict = parse_qs(urlparse(response["Location"]).fragment) + authorization_code = fragment_dict["code"].pop() + + # exchange authorization code for a valid access token + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.com?bar=baz&foo=bar" + } + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + self.assertEqual(response.status_code, 200) + + content = json.loads(response.content.decode("utf-8")) + self.assertEqual(content["token_type"], "Bearer") + self.assertEqual(content["scope"], "openid read write") + self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + + def test_id_token_code_exchange_succeed_when_redirect_uri_match_with_multiple_query_params(self): + """ + Tests code exchange succeed when redirect uri matches the one used for code request + """ + self.client.login(username="hy_test_user", password="123456") + self.application.redirect_uris = "http://localhost http://example.com?foo=bar" + self.application.save() + + # retrieve a valid authorization code + authcode_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.com?bar=baz&foo=bar", + "response_type": "code token", + "allow": True, + } + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) + fragment_dict = parse_qs(urlparse(response["Location"]).fragment) + authorization_code = fragment_dict["code"].pop() + + # exchange authorization code for a valid access token + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.com?bar=baz&foo=bar", + } + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + self.assertEqual(response.status_code, 200) + + content = json.loads(response.content.decode("utf-8")) + self.assertEqual(content["token_type"], "Bearer") + self.assertEqual(content["scope"], "openid") + self.assertIn("access_token", content) + self.assertIn("id_token", content) + self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + + +class TestHybridProtectedResource(BaseTest): + def test_resource_access_allowed(self): + self.client.login(username="hy_test_user", password="123456") + + # retrieve a valid authorization code + authcode_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "http://example.org", + "response_type": "code token", + "allow": True, + } + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) + fragment_dict = parse_qs(urlparse(response["Location"]).fragment) + authorization_code = fragment_dict["code"].pop() + + # exchange authorization code for a valid access token + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org" + } + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + content = json.loads(response.content.decode("utf-8")) + access_token = content["access_token"] + + # use token to access the resource + auth_headers = { + "HTTP_AUTHORIZATION": "Bearer " + access_token, + } + request = self.factory.get("/fake-resource", **auth_headers) + request.user = self.hy_test_user + + view = ResourceView.as_view() + response = view(request) + self.assertEqual(response, "This is a protected resource") + + def test_id_token_resource_access_allowed(self): + self.client.login(username="hy_test_user", password="123456") + + # retrieve a valid authorization code + authcode_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.org", + "response_type": "code token", + "allow": True, + } + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) + fragment_dict = parse_qs(urlparse(response["Location"]).fragment) + authorization_code = fragment_dict["code"].pop() + + # exchange authorization code for a valid access token + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org", + } + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + content = json.loads(response.content.decode("utf-8")) + access_token = content["access_token"] + id_token = content["id_token"] + + # use token to access the resource + auth_headers = { + "HTTP_AUTHORIZATION": "Bearer " + access_token, + } + request = self.factory.get("/fake-resource", **auth_headers) + request.user = self.hy_test_user + + view = ResourceView.as_view() + response = view(request) + self.assertEqual(response, "This is a protected resource") + + # use id_token to access the resource + auth_headers = { + "HTTP_AUTHORIZATION": "Bearer " + id_token, + } + request = self.factory.get("/fake-resource", **auth_headers) + request.user = self.hy_test_user + + view = ResourceView.as_view() + response = view(request) + self.assertEqual(response, "This is a protected resource") + + def test_resource_access_deny(self): + auth_headers = { + "HTTP_AUTHORIZATION": "Bearer " + "faketoken", + } + request = self.factory.get("/fake-resource", **auth_headers) + request.user = self.hy_test_user + + view = ResourceView.as_view() + response = view(request) + self.assertEqual(response.status_code, 403) + + +class TestDefaultScopesHybrid(BaseTest): + + def test_pre_auth_default_scopes(self): + """ + Test response for a valid client_id with response_type: code using default scopes + """ + self.client.login(username="hy_test_user", password="123456") + oauth2_settings._DEFAULT_SCOPES = ["read"] + + query_string = urlencode({ + "client_id": self.application.client_id, + "response_type": "code token", + "state": "random_state_string", + "redirect_uri": "http://example.org", + }) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 200) + + # check form is in context and form params are valid + self.assertIn("form", response.context) + + form = response.context["form"] + self.assertEqual(form["redirect_uri"].value(), "http://example.org") + self.assertEqual(form["state"].value(), "random_state_string") + self.assertEqual(form["scope"].value(), "read") + self.assertEqual(form["client_id"].value(), self.application.client_id) + oauth2_settings._DEFAULT_SCOPES = ["read", "write"] diff --git a/tests/test_implicit.py b/tests/test_implicit.py index c2fd83a5a..4e8879a9e 100644 --- a/tests/test_implicit.py +++ b/tests/test_implicit.py @@ -1,8 +1,10 @@ +import json from urllib.parse import parse_qs, urlencode, urlparse from django.contrib.auth import get_user_model from django.test import RequestFactory, TestCase from django.urls import reverse +from jwcrypto import jwk, jwt from oauth2_provider.models import get_application_model from oauth2_provider.settings import oauth2_settings @@ -33,8 +35,14 @@ def setUp(self): authorization_grant_type=Application.GRANT_IMPLICIT, ) - oauth2_settings._SCOPES = ["read", "write"] + oauth2_settings._SCOPES = ["read", "write", "openid"] oauth2_settings._DEFAULT_SCOPES = ["read"] + oauth2_settings.SCOPES = { + "read": "Reading scope", + "write": "Writing scope", + "openid": "OpenID connect" + } + self.key = jwk.JWK.from_pem(oauth2_settings.OIDC_RSA_PRIVATE_KEY.encode("utf8")) def tearDown(self): self.application.delete() @@ -272,3 +280,197 @@ def test_resource_access_allowed(self): view = ResourceView.as_view() response = view(request) self.assertEqual(response, "This is a protected resource") + + +class TestOpenIDConnectImplicitFlow(BaseTest): + def test_id_token_post_auth_allow(self): + """ + Test authorization code is given for an allowed request with response_type: id_token + """ + self.client.login(username="test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "nonce": "random_nonce_string", + "scope": "openid", + "redirect_uri": "http://example.org", + "response_type": "id_token", + "allow": True, + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("http://example.org#", response["Location"]) + self.assertNotIn("access_token=", response["Location"]) + self.assertIn("id_token=", response["Location"]) + self.assertIn("state=random_state_string", response["Location"]) + + uri_query = urlparse(response["Location"]).fragment + uri_query_params = dict(parse_qs(uri_query, keep_blank_values=True, strict_parsing=True)) + id_token = uri_query_params["id_token"][0] + jwt_token = jwt.JWT(key=self.key, jwt=id_token) + claims = json.loads(jwt_token.claims) + self.assertIn("nonce", claims) + self.assertNotIn("at_hash", claims) + + def test_id_token_skip_authorization_completely(self): + """ + If application.skip_authorization = True, should skip the authorization page. + """ + self.client.login(username="test_user", password="123456") + self.application.skip_authorization = True + self.application.save() + + query_string = urlencode({ + "client_id": self.application.client_id, + "response_type": "id_token", + "state": "random_state_string", + "nonce": "random_nonce_string", + "scope": "openid", + "redirect_uri": "http://example.org", + }) + + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 302) + self.assertIn("http://example.org#", response["Location"]) + self.assertNotIn("access_token=", response["Location"]) + self.assertIn("id_token=", response["Location"]) + self.assertIn("state=random_state_string", response["Location"]) + + uri_query = urlparse(response["Location"]).fragment + uri_query_params = dict(parse_qs(uri_query, keep_blank_values=True, strict_parsing=True)) + id_token = uri_query_params["id_token"][0] + jwt_token = jwt.JWT(key=self.key, jwt=id_token) + claims = json.loads(jwt_token.claims) + self.assertIn("nonce", claims) + self.assertNotIn("at_hash", claims) + + def test_id_token_skip_authorization_completely_missing_nonce(self): + """ + If application.skip_authorization = True, should skip the authorization page. + """ + self.client.login(username="test_user", password="123456") + self.application.skip_authorization = True + self.application.save() + + query_string = urlencode({ + "client_id": self.application.client_id, + "response_type": "id_token", + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.org", + }) + + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 302) + self.assertIn("error=invalid_request", response["Location"]) + self.assertIn("error_description=Request+is+missing+mandatory+nonce+paramete", response["Location"]) + + def test_id_token_post_auth_deny(self): + """ + Test error when resource owner deny access + """ + self.client.login(username="test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "nonce": "random_nonce_string", + "scope": "openid", + "redirect_uri": "http://example.org", + "response_type": "id_token", + "allow": False, + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("error=access_denied", response["Location"]) + + def test_access_token_and_id_token_post_auth_allow(self): + """ + Test authorization code is given for an allowed request with response_type: token + """ + self.client.login(username="test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "nonce": "random_nonce_string", + "scope": "openid", + "redirect_uri": "http://example.org", + "response_type": "id_token token", + "allow": True, + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("http://example.org#", response["Location"]) + self.assertIn("access_token=", response["Location"]) + self.assertIn("id_token=", response["Location"]) + self.assertIn("state=random_state_string", response["Location"]) + + uri_query = urlparse(response["Location"]).fragment + uri_query_params = dict(parse_qs(uri_query, keep_blank_values=True, strict_parsing=True)) + id_token = uri_query_params["id_token"][0] + jwt_token = jwt.JWT(key=self.key, jwt=id_token) + claims = json.loads(jwt_token.claims) + self.assertIn("nonce", claims) + self.assertIn("at_hash", claims) + + def test_access_token_and_id_token_skip_authorization_completely(self): + """ + If application.skip_authorization = True, should skip the authorization page. + """ + self.client.login(username="test_user", password="123456") + self.application.skip_authorization = True + self.application.save() + + query_string = urlencode({ + "client_id": self.application.client_id, + "response_type": "id_token token", + "state": "random_state_string", + "nonce": "random_nonce_string", + "scope": "openid", + "redirect_uri": "http://example.org", + }) + + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 302) + self.assertIn("http://example.org#", response["Location"]) + self.assertIn("access_token=", response["Location"]) + self.assertIn("id_token=", response["Location"]) + self.assertIn("state=random_state_string", response["Location"]) + + uri_query = urlparse(response["Location"]).fragment + uri_query_params = dict(parse_qs(uri_query, keep_blank_values=True, strict_parsing=True)) + id_token = uri_query_params["id_token"][0] + jwt_token = jwt.JWT(key=self.key, jwt=id_token) + claims = json.loads(jwt_token.claims) + self.assertIn("nonce", claims) + self.assertIn("at_hash", claims) + + def test_access_token_and_id_token_post_auth_deny(self): + """ + Test error when resource owner deny access + """ + self.client.login(username="test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.org", + "response_type": "id_token token", + "allow": False, + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("error=access_denied", response["Location"]) diff --git a/tests/test_oauth2_backends.py b/tests/test_oauth2_backends.py index d844da5f4..0d98dad8b 100644 --- a/tests/test_oauth2_backends.py +++ b/tests/test_oauth2_backends.py @@ -65,7 +65,9 @@ def test_create_token_response_gets_extra_credentials(self): payload = "grant_type=password&username=john&password=123456" request = self.factory.post("/o/token/", payload, content_type="application/x-www-form-urlencoded") - with mock.patch("oauthlib.oauth2.Server.create_token_response") as create_token_response: + with mock.patch( + "oauthlib.openid.connect.core.endpoints.pre_configured.Server.create_token_response" + ) as create_token_response: mocked = mock.MagicMock() create_token_response.return_value = mocked, mocked, mocked core = self.MyOAuthLibCore() diff --git a/tests/test_oidc_views.py b/tests/test_oidc_views.py new file mode 100644 index 000000000..43e46d297 --- /dev/null +++ b/tests/test_oidc_views.py @@ -0,0 +1,47 @@ +from __future__ import unicode_literals + +from django.test import TestCase +from django.urls import reverse + + +class TestConnectDiscoveryInfoView(TestCase): + def test_get_connect_discovery_info(self): + expected_response = { + "issuer": "http://localhost", + "authorization_endpoint": "http://localhost/o/authorize/", + "token_endpoint": "http://localhost/o/token/", + "userinfo_endpoint": "http://localhost/userinfo/", + "jwks_uri": "http://localhost/o/jwks/", + "response_types_supported": [ + "code", + "token", + "id_token", + "id_token token", + "code token", + "code id_token", + "code id_token token" + ], + "subject_types_supported": ["public"], + "id_token_signing_alg_values_supported": ["RS256", "HS256"], + "token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"] + } + response = self.client.get(reverse("oauth2_provider:oidc-connect-discovery-info")) + self.assertEqual(response.status_code, 200) + assert response.json() == expected_response + + +class TestJwksInfoView(TestCase): + def test_get_jwks_info(self): + expected_response = { + "keys": [{ + "alg": "RS256", + "use": "sig", + "kid": "s4a1o8mFEd1tATAIH96caMlu4hOxzBUaI2QTqbYNBHs", + "e": "AQAB", + "kty": "RSA", + "n": "mwmIeYdjZkLgalTuhvvwjvnB5vVQc7G9DHgOm20Hw524bLVTk49IXJ2Scw42HOmowWWX-oMVT_ca3ZvVIeffVSN1-TxVy2zB65s0wDMwhiMoPv35z9IKHGMZgl9vlyso_2b7daVF_FQDdgIayUn8TQylBxEU1RFfW0QSYOBdAt8" # noqa + }] + } + response = self.client.get(reverse("oauth2_provider:jwks-info")) + self.assertEqual(response.status_code, 200) + assert response.json() == expected_response diff --git a/tox.ini b/tox.ini index 210106f57..7d2de237e 100644 --- a/tox.ini +++ b/tox.ini @@ -14,7 +14,8 @@ envlist = django_find_project = false [testenv] -commands = pytest --cov=oauth2_provider --cov-report= --cov-append {posargs} +commands = + pytest --cov=oauth2_provider --cov-report= --cov-append {posargs} -s setenv = DJANGO_SETTINGS_MODULE = tests.settings PYTHONPATH = {toxinidir} @@ -33,15 +34,17 @@ deps = pytest-xdist py27: mock requests + jwcrypto [testenv:py37-docs] basepython = python changedir = docs whitelist_externals = make commands = make html -deps = sphinx +deps = sphinx<3 oauthlib>=3.0.1 m2r>=0.2.1 + jwcrypto [testenv:py37-flake8] skip_install = True @@ -67,7 +70,9 @@ commands = [coverage:run] source = oauth2_provider -omit = */migrations/* +omit = + */migrations/* + oauth2_provider/settings.py [flake8] max-line-length = 110