diff --git a/.editorconfig b/.editorconfig index 2ca598bbd..5a7ffef02 100644 --- a/.editorconfig +++ b/.editorconfig @@ -8,7 +8,7 @@ indent_size = 4 insert_final_newline = true trim_trailing_whitespace = true -[{Makefile,tox.ini,setup.cfg}] +[{Makefile,setup.cfg}] indent_style = tab [*.{yml,yaml}] diff --git a/.gitignore b/.gitignore index af644d1e3..3643335d4 100644 --- a/.gitignore +++ b/.gitignore @@ -26,6 +26,7 @@ pip-log.txt # Unit test / coverage reports .cache +.pytest_cache .coverage .tox .pytest_cache/ diff --git a/CHANGELOG.md b/CHANGELOG.md index 01e45bb33..58f279398 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [unreleased] +### Added +* #915 Add optional OpenID Connect support. + ## [1.4.1] ### Changed diff --git a/docs/index.rst b/docs/index.rst index f4add1bdd..4f83249f0 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -40,6 +40,7 @@ Index views/details models advanced_topics + oidc signals settings resource_server diff --git a/docs/oidc.rst b/docs/oidc.rst new file mode 100644 index 000000000..29c9406bd --- /dev/null +++ b/docs/oidc.rst @@ -0,0 +1,308 @@ +OpenID Connect +++++++++++++++ + +OpenID Connect support +====================== + +``django-oauth-toolkit`` supports OpenID Connect (OIDC), which standardizes +authentication flows and provides a plug and play integration with other +systems. OIDC is built on top of OAuth 2.0 to provide: + +* Generating ID tokens as part of the login process. These are JWT that + describe the user, and can be used to authenticate them to your application. +* Metadata based auto-configuration for providers +* A user info endpoint, which applications can query to get more information + about a user. + +Enabling OIDC doesn't affect your existing OAuth 2.0 flows, these will +continue to work alongside OIDC. + +We support: + +* OpenID Connect Authorization Code Flow +* OpenID Connect Implicit Flow +* OpenID Connect Hybrid Flow + + +Configuration +============= + +OIDC is not enabled by default because it requires additional configuration +that must be provided. ``django-oauth-toolkit`` supports two different +algorithms for signing JWT tokens, ``RS256``, which uses asymmetric RSA keys (a +public key and a private key), and ``HS256``, which uses a symmetric key. + +It is preferrable to use ``RS256``, because this produces a token that can be +verified by anyone using the public key (which is made available and +discoverable by OIDC service auto-discovery, included with +``django-oauth-toolkit``). ``HS256`` on the other hand uses the +``client_secret`` in order to verify keys. This is simpler to implement, but +makes it harder to safely verify tokens. + +Using ``HS256`` also means that you cannot use the Implicit or Hybrid flows, +or verify the tokens in public clients, because you cannot disclose the +``client_secret`` to a public client. If you are using a public client, you +must use ``RS256``. + + +Creating RSA private key +~~~~~~~~~~~~~~~~~~~~~~~~ + +To use ``RS256`` requires an RSA private key, which is used for signing JWT. You +can generate this using the `openssl`_ tool:: + + openssl genrsa -out oidc.key 4096 + +This will generate a 4096-bit RSA key, which will be sufficient for our needs. + +.. _openssl: https://www.openssl.org + +.. warning:: + The contents of this key *must* be kept a secret. Don't put it in your + settings and commit it to version control! + + If the key is ever accidentally disclosed, an attacker could use it to + forge JWT tokens that verify as issued by your OAuth provider, which is + very bad! + + If it is ever disclosed, you should immediately replace the key. + + Safe ways to handle it would be: + + * Store it in a secure system like `Hashicorp Vault`_, and inject it in to + your environment when running your server. + * Store it in a secure file on your server, and use your initialization + scripts to inject it in to your environment. + +.. _Hashicorp Vault: https://www.hashicorp.com/products/vault + +Now we need to add this key to our settings and allow the ``openid`` scope to +be used. Assuming we have set an environment variable called +``OIDC_RSA_PRIVATE_KEY``, we can make changes to our ``settings.py``:: + + import os.environ + + OAUTH2_PROVIDER = { + "OIDC_ENABLED": True, + "OIDC_RSA_PRIVATE_KEY": os.environ.get("OIDC_RSA_PRIVATE_KEY"), + "SCOPES": { + "openid": "OpenID Connect scope", + # ... any other scopes that you use + }, + # ... any other settings you want + } + +If you are adding OIDC support to an existing OAuth 2.0 provider site, and you +are currently using a custom class for ``OAUTH2_SERVER_CLASS``, you must +change this class to derive from ``oauthlib.openid.Server`` instead of +``oauthlib.oauth2.Server``. + +With ``RSA`` key-pairs, the public key can be generated from the private key, +so there is no need to add a setting for the public key. + +Using ``HS256`` keys +~~~~~~~~~~~~~~~~~~~~ + +If you would prefer to use just ``HS256`` keys, you don't need to create any +additional keys, ``django-oauth-toolkit`` will just use the application's +``client_secret`` to sign the JWT token. + +In this case, you just need to enable OIDC and add ``openid`` to your list of +scopes in your ``settings.py``:: + + OAUTH2_PROVIDER = { + "OIDC_ENABLED": True, + "SCOPES": { + "openid": "OpenID Connect scope", + # ... any other scopes that you use + }, + # ... any other settings you want + } + +.. info:: + If you want to enable ``RS256`` at a later date, you can do so - just add + the private key as described above. + +Setting up OIDC enabled clients +=============================== + +Setting up an OIDC client in ``django-oauth-toolkit`` is simple - in fact, all +existing OAuth 2.0 Authorization Code Flow and Implicit Flow applications that +are already configured can be easily updated to use OIDC by setting the +appropriate algorithm for them to use. + +You can also switch existing apps to use OIDC Hybrid Flow by changing their +Authorization Grant Type and selecting a signing algorithm to use. + +You can read about the pros and cons of the different flows in `this excellent +article`_ from Robert Broeckelmann. + +.. _this excellent article: https://medium.com/@robert.broeckelmann/when-to-use-which-oauth2-grants-and-oidc-flows-ec6a5c00d864 + +OIDC Authorization Code Flow +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +To create an OIDC Authorization Code Flow client, create an ``Application`` +with the grant type ``Authorization code`` and select your desired signing +algorithm. + +When making an authorization request, be sure to include ``openid`` as a +scope. When the code is exchanged for the access token, the response will +also contain an ID token JWT. + +If the ``openid`` scope is not requested, authorization requests will be +treated as standard OAuth 2.0 Authorization Code Grant requests. + +With ``PKCE`` enabled, even public clients can use this flow, and it is the most +secure and recommended flow. + +OIDC Implicit Flow +~~~~~~~~~~~~~~~~~~ + +OIDC Implicit Flow is very similar to OAuth 2.0 Implicit Grant, except that +the client can request a ``response_type`` of ``id_token`` or ``id_token +token``. Requesting just ``token`` is also possible, but it would make it not +an OIDC flow and would fall back to being the same as OAuth 2.0 Implicit +Grant. + +To setup an OIDC Implicit Flow client, simply create an ``Application`` with +the a grant type of ``Implicit`` and select your desired signing algorithm, +and configure the client to request the ``openid`` scope and an OIDC +``response_type`` (``id_token`` or ``id_token token``). + + +OIDC Hybrid Flow +~~~~~~~~~~~~~~~~ + +OIDC Hybrid Flow is a mixture of the previous two flows. It allows the ID +token and an access token to be returned to the frontend, whilst also +allowing the backend to retrieve the ID token and an access token (not +necessarily the same access token) on the backend. + +To setup an OIDC Hybrid Flow application, create an ``Application`` with a +grant type of ``OpenID connect hybrid`` and select your desired signing +algorithm. + + +Customizing the OIDC responses +============================== + +This basic configuration will give you a basic working OIDC setup, but your +ID tokens will have very few claims in them, and the ``UserInfo`` service will +just return the same claims as the ID token. + +To configure all of these things we need to customize the +``OAUTH2_VALIDATOR_CLASS`` in ``django-oauth-toolkit``. Create a new file in +our project, eg ``my_project/oauth_validator.py``:: + + from oauth2_provider.oauth2_validators import OAuth2Validator + + + class CustomOAuth2Validator(OAuth2Validator): + pass + + +and then configure our site to use this in our ``settings.py``:: + + OAUTH2_PROVIDER = { + "OAUTH2_VALIDATOR_CLASS": "my_project.oauth_validators.CustomOAuth2Validator", + # ... other settings + } + +Now we can customize the tokens and the responses that are produced by adding +methods to our custom validator. + + +Adding claims to the ID token +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +By default the ID token will just have a ``sub`` claim (in addition to the +required claims, eg ``iss``, ``aud``, ``exp``, ``iat``, ``auth_time`` etc), +and the ``sub`` claim will use the primary key of the user as the value. +You'll probably want to customize this and add additional claims or change +what is sent for the ``sub`` claim. To do so, you will need to add a method to +our custom validator:: + + class CustomOAuth2Validator(OAuth2Validator): + + def get_additional_claims(self, request): + return { + "sub": request.user.email, + "first_name": request.user.first_name, + "last_name": request.user.last_name, + } + +.. note:: + This ``request`` object is not a ``django.http.Request`` object, but an + ``oauthlib.common.Request`` object. This has a number of attributes that + you can use to decide what claims to put in to the ID token: + + * ``request.scopes`` - a list of the scopes requested by the client when + making an authorization request. + * ``request.claims`` - a dictionary of the requested claims, using the + `OIDC claims requesting system`_. These must be requested by the client + when making an authorization request. + * ``request.user`` - the django user object. + +.. _OIDC claims requesting system: https://openid.net/specs/openid-connect-core-1_0.html#ClaimsParameter + +What claims you decide to put in to the token is up to you to determine based +upon what the scopes and / or claims means to your provider. + + +Adding information to the ``UserInfo`` service +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The ``UserInfo`` service is supplied as part of the OIDC service, and is used +to retrieve more information about the user than was supplied in the ID token +when the user logged in to the OIDC client application. It is optional to use +the service. The service is accessed by making a request to the +``UserInfo`` endpoint, eg ``/o/userinfo/`` and supplying the access token +retrieved at login as a ``Bearer`` token. + +Again, to modify the content delivered, we need to add a function to our +custom validator. The default implementation adds the claims from the ID +token, so you will probably want to re-use that:: + + class CustomOAuth2Validator(OAuth2Validator): + + def get_userinfo_claims(self, request): + claims = super().get_userinfo_claims() + claims["color_scheme"] = get_color_scheme(request.user) + return claims + + +OIDC Views +========== + +Enabling OIDC support adds three views to ``django-oauth-toolkit``. When OIDC +is not enabled, these views will log that OIDC support is not enabled, and +return a ``404`` response, or if ``DEBUG`` is enabled, raise an +``ImproperlyConfigured`` exception. + +In the docs below, it assumes that you have mounted the +``django-oauth-toolkit`` at ``/o/``. If you have mounted it elsewhere, adjust +the URLs accordingly. + + +ConnectDiscoveryInfoView +~~~~~~~~~~~~~~~~~~~~~~~~ + +Available at ``/o/.well-known/openid-configuration/``, this view provides auto +discovery information to OIDC clients, telling them the JWT issuer to use, the +location of the JWKs to verify JWTs with, the token and userinfo endpoints to +query, and other details. + + +JwksInfoView +~~~~~~~~~~~~ + +Available at ``/o/.well-known/jwks.json``, this view provides details of the key used to sign +the JWTs generated for ID tokens, so that clients are able to verify them. + + +UserInfoView +~~~~~~~~~~~~ + +Available at ``/o/userinfo/``, this view provides extra user details. You can +customize the details included in the response as described above. diff --git a/docs/settings.rst b/docs/settings.rst index be06e83ca..afca76e01 100644 --- a/docs/settings.rst +++ b/docs/settings.rst @@ -124,7 +124,9 @@ Overwrite this value if you wrote your own implementation (subclass of OAUTH2_SERVER_CLASS ~~~~~~~~~~~~~~~~~~~ The import string for the ``server_class`` (or ``oauthlib.oauth2.Server`` subclass) -used in the ``OAuthLibMixin`` that implements OAuth2 grant types. +used in the ``OAuthLibMixin`` that implements OAuth2 grant types. It defaults +to ``oauthlib.oauth2.Server``, except when OIDC support is enabled, when the +default is ``oauthlib.openid.Server``. OAUTH2_VALIDATOR_CLASS ~~~~~~~~~~~~~~~~~~~~~~ @@ -247,3 +249,64 @@ PKCE_REQUIRED Default: ``False`` Whether or not PKCE is required. Can be either a bool or a callable that takes a client id and returns a bool. + + +OIDC_RSA_PRIVATE_KEY +~~~~~~~~~~~~~~~~~~~~ +Default: ``""`` + +The RSA private key used to sign OIDC ID tokens. If not set, OIDC is disabled. + + +OIDC_USERINFO_ENDPOINT +~~~~~~~~~~~~~~~~~~~~~~ +Default: ``""`` + +The url of the userinfo endpoint. Used to advertise the location of the +endpoint in the OIDC discovery metadata. Changing this does not change the URL +that ``django-oauth-toolkit`` adds for the userinfo endpoint, so if you change +this you must also provide the service at that endpoint. + +If unset, the default location is used, eg if ``django-oauth-toolkit`` is +mounted at ``/o/``, it will be ``/o/userinfo/``. + +OIDC_ISS_ENDPOINT +~~~~~~~~~~~~~~~~~ +Default: ``""`` + +The URL of the issuer that is used in the ID token JWT and advertised in the +OIDC discovery metadata. Clients use this location to retrieve the OIDC +discovery metadata from ``OIDC_ISS_ENDPOINT`` + +``/.well-known/openid-configuration/``. + +If unset, the default location is used, eg if ``django-oauth-toolkit`` is +mounted at ``/o``, it will be ``/o``. + +OIDC_RESPONSE_TYPES_SUPPORTED +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Default:: + + [ + "code", + "token", + "id_token", + "id_token token", + "code token", + "code id_token", + "code id_token token", + ] + + +The response types that are advertised to be supported by this server. + +OIDC_SUBJECT_TYPES_SUPPORTED +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Default: ``["public"]`` + +The subject types that are advertised to be supported by this server. + +OIDC_TOKEN_ENDPOINT_AUTH_METHODS_SUPPORTED +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Default: ``["client_secret_post", "client_secret_basic"]`` + +The authentication methods that are advertised to be supported by this server. diff --git a/oauth2_provider/admin.py b/oauth2_provider/admin.py index ed835cd16..79bcf7702 100644 --- a/oauth2_provider/admin.py +++ b/oauth2_provider/admin.py @@ -7,6 +7,8 @@ get_application_model, get_grant_admin_class, get_grant_model, + get_id_token_admin_class, + get_id_token_model, get_refresh_token_admin_class, get_refresh_token_model, ) @@ -32,6 +34,11 @@ class GrantAdmin(admin.ModelAdmin): raw_id_fields = ("user",) +class IDTokenAdmin(admin.ModelAdmin): + list_display = ("jti", "user", "application", "expires") + raw_id_fields = ("user",) + + class RefreshTokenAdmin(admin.ModelAdmin): list_display = ("token", "user", "application") raw_id_fields = ("user", "access_token") @@ -40,14 +47,17 @@ class RefreshTokenAdmin(admin.ModelAdmin): application_model = get_application_model() access_token_model = get_access_token_model() grant_model = get_grant_model() +id_token_model = get_id_token_model() refresh_token_model = get_refresh_token_model() application_admin_class = get_application_admin_class() access_token_admin_class = get_access_token_admin_class() grant_admin_class = get_grant_admin_class() +id_token_admin_class = get_id_token_admin_class() refresh_token_admin_class = get_refresh_token_admin_class() admin.site.register(application_model, application_admin_class) admin.site.register(access_token_model, access_token_admin_class) admin.site.register(grant_model, grant_admin_class) +admin.site.register(id_token_model, id_token_admin_class) admin.site.register(refresh_token_model, refresh_token_admin_class) diff --git a/oauth2_provider/forms.py b/oauth2_provider/forms.py index 2e465959a..876213626 100644 --- a/oauth2_provider/forms.py +++ b/oauth2_provider/forms.py @@ -5,8 +5,10 @@ class AllowForm(forms.Form): allow = forms.BooleanField(required=False) redirect_uri = forms.CharField(widget=forms.HiddenInput()) scope = forms.CharField(widget=forms.HiddenInput()) + nonce = forms.CharField(required=False, widget=forms.HiddenInput()) client_id = forms.CharField(widget=forms.HiddenInput()) state = forms.CharField(required=False, widget=forms.HiddenInput()) response_type = forms.CharField(widget=forms.HiddenInput()) code_challenge = forms.CharField(required=False, widget=forms.HiddenInput()) code_challenge_method = forms.CharField(required=False, widget=forms.HiddenInput()) + claims = forms.CharField(required=False, widget=forms.HiddenInput()) diff --git a/oauth2_provider/migrations/0004_auto_20200902_2022.py b/oauth2_provider/migrations/0004_auto_20200902_2022.py new file mode 100644 index 000000000..81dd20d04 --- /dev/null +++ b/oauth2_provider/migrations/0004_auto_20200902_2022.py @@ -0,0 +1,60 @@ +import uuid + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + +from oauth2_provider.settings import oauth2_settings + + +class Migration(migrations.Migration): + + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ('oauth2_provider', '0003_auto_20201211_1314'), + ] + + operations = [ + migrations.AddField( + model_name='application', + name='algorithm', + field=models.CharField(blank=True, choices=[("", "No OIDC support"), ('RS256', 'RSA with SHA-2 256'), ('HS256', 'HMAC with SHA-2 256')], default='', max_length=5), + ), + migrations.AlterField( + model_name='application', + name='authorization_grant_type', + field=models.CharField(choices=[('authorization-code', 'Authorization code'), ('implicit', 'Implicit'), ('password', 'Resource owner password-based'), ('client-credentials', 'Client credentials'), ('openid-hybrid', 'OpenID connect hybrid')], max_length=32), + ), + migrations.CreateModel( + name='IDToken', + fields=[ + ('id', models.BigAutoField(primary_key=True, serialize=False)), + ("jti", models.UUIDField(unique=True, default=uuid.uuid4, editable=False, verbose_name="JWT Token ID")), + ('expires', models.DateTimeField()), + ('scope', models.TextField(blank=True)), + ('created', models.DateTimeField(auto_now_add=True)), + ('updated', models.DateTimeField(auto_now=True)), + ('application', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to=oauth2_settings.APPLICATION_MODEL)), + ('user', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='oauth2_provider_idtoken', to=settings.AUTH_USER_MODEL)), + ], + options={ + 'abstract': False, + 'swappable': 'OAUTH2_PROVIDER_ID_TOKEN_MODEL', + }, + ), + migrations.AddField( + model_name='accesstoken', + name='id_token', + field=models.OneToOneField(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='access_token', to=oauth2_settings.ID_TOKEN_MODEL), + ), + migrations.AddField( + model_name="grant", + name="nonce", + field=models.CharField(blank=True, max_length=255, default=""), + ), + migrations.AddField( + model_name="grant", + name="claims", + field=models.TextField(blank=True), + ), + ] diff --git a/oauth2_provider/models.py b/oauth2_provider/models.py index 835fe24b2..a21cb868b 100644 --- a/oauth2_provider/models.py +++ b/oauth2_provider/models.py @@ -1,4 +1,5 @@ import logging +import uuid from datetime import timedelta from urllib.parse import parse_qsl, urlparse @@ -9,6 +10,8 @@ from django.urls import reverse from django.utils import timezone from django.utils.translation import gettext_lazy as _ +from jwcrypto import jwk +from jwcrypto.common import base64url_encode from .generators import generate_client_id, generate_client_secret from .scopes import get_scopes_backend @@ -51,11 +54,22 @@ class AbstractApplication(models.Model): GRANT_IMPLICIT = "implicit" GRANT_PASSWORD = "password" GRANT_CLIENT_CREDENTIALS = "client-credentials" + GRANT_OPENID_HYBRID = "openid-hybrid" GRANT_TYPES = ( (GRANT_AUTHORIZATION_CODE, _("Authorization code")), (GRANT_IMPLICIT, _("Implicit")), (GRANT_PASSWORD, _("Resource owner password-based")), (GRANT_CLIENT_CREDENTIALS, _("Client credentials")), + (GRANT_OPENID_HYBRID, _("OpenID connect hybrid")), + ) + + NO_ALGORITHM = "" + RS256_ALGORITHM = "RS256" + HS256_ALGORITHM = "HS256" + ALGORITHM_TYPES = ( + (NO_ALGORITHM, _("No OIDC support")), + (RS256_ALGORITHM, _("RSA with SHA-2 256")), + (HS256_ALGORITHM, _("HMAC with SHA-2 256")), ) id = models.BigAutoField(primary_key=True) @@ -82,6 +96,7 @@ class AbstractApplication(models.Model): created = models.DateTimeField(auto_now_add=True) updated = models.DateTimeField(auto_now=True) + algorithm = models.CharField(max_length=5, choices=ALGORITHM_TYPES, default=NO_ALGORITHM, blank=True) class Meta: abstract = True @@ -134,6 +149,11 @@ def clean(self): grant_types = ( AbstractApplication.GRANT_AUTHORIZATION_CODE, AbstractApplication.GRANT_IMPLICIT, + AbstractApplication.GRANT_OPENID_HYBRID, + ) + hs_forbidden_grant_types = ( + AbstractApplication.GRANT_IMPLICIT, + AbstractApplication.GRANT_OPENID_HYBRID, ) redirect_uris = self.redirect_uris.strip().split() @@ -153,6 +173,18 @@ def clean(self): grant_type=self.authorization_grant_type ) ) + if self.algorithm == AbstractApplication.RS256_ALGORITHM: + if not oauth2_settings.OIDC_RSA_PRIVATE_KEY: + raise ValidationError(_("You must set OIDC_RSA_PRIVATE_KEY to use RSA algorithm")) + + if self.algorithm == AbstractApplication.HS256_ALGORITHM: + if any( + ( + self.authorization_grant_type in hs_forbidden_grant_types, + self.client_type == Application.CLIENT_PUBLIC, + ) + ): + raise ValidationError(_("You cannot use HS256 with public grants or clients")) def get_absolute_url(self): return reverse("oauth2_provider:detail", args=[str(self.id)]) @@ -175,6 +207,16 @@ def is_usable(self, request): """ return True + @property + def jwk_key(self): + if self.algorithm == AbstractApplication.RS256_ALGORITHM: + if not oauth2_settings.OIDC_RSA_PRIVATE_KEY: + raise ImproperlyConfigured("You must set OIDC_RSA_PRIVATE_KEY to use RSA algorithm") + return jwk.JWK.from_pem(oauth2_settings.OIDC_RSA_PRIVATE_KEY.encode("utf8")) + elif self.algorithm == AbstractApplication.HS256_ALGORITHM: + return jwk.JWK(kty="oct", k=base64url_encode(self.client_secret)) + raise ImproperlyConfigured("This application does not support signed tokens") + class ApplicationManager(models.Manager): def get_by_natural_key(self, client_id): @@ -231,6 +273,9 @@ class AbstractGrant(models.Model): max_length=10, blank=True, default="", choices=CODE_CHALLENGE_METHODS ) + nonce = models.CharField(max_length=255, blank=True, default="") + claims = models.TextField(blank=True) + def is_expired(self): """ Check token expiration with timezone awareness @@ -290,6 +335,13 @@ class AbstractAccessToken(models.Model): max_length=255, unique=True, ) + id_token = models.OneToOneField( + oauth2_settings.ID_TOKEN_MODEL, + on_delete=models.CASCADE, + blank=True, + null=True, + related_name="access_token", + ) application = models.ForeignKey( oauth2_settings.APPLICATION_MODEL, on_delete=models.CASCADE, @@ -430,6 +482,102 @@ class Meta(AbstractRefreshToken.Meta): swappable = "OAUTH2_PROVIDER_REFRESH_TOKEN_MODEL" +class AbstractIDToken(models.Model): + """ + An IDToken instance represents the actual token to + access user's resources, as in :openid:`2`. + + Fields: + + * :attr:`user` The Django user representing resources' owner + * :attr:`jti` ID token JWT Token ID, to identify an individual token + * :attr:`application` Application instance + * :attr:`expires` Date and time of token expiration, in DateTime format + * :attr:`scope` Allowed scopes + * :attr:`created` Date and time of token creation, in DateTime format + * :attr:`updated` Date and time of token update, in DateTime format + """ + + id = models.BigAutoField(primary_key=True) + user = models.ForeignKey( + settings.AUTH_USER_MODEL, + on_delete=models.CASCADE, + blank=True, + null=True, + related_name="%(app_label)s_%(class)s", + ) + jti = models.UUIDField(unique=True, default=uuid.uuid4, editable=False, verbose_name="JWT Token ID") + application = models.ForeignKey( + oauth2_settings.APPLICATION_MODEL, + on_delete=models.CASCADE, + blank=True, + null=True, + ) + expires = models.DateTimeField() + scope = models.TextField(blank=True) + + created = models.DateTimeField(auto_now_add=True) + updated = models.DateTimeField(auto_now=True) + + def is_valid(self, scopes=None): + """ + Checks if the access token is valid. + + :param scopes: An iterable containing the scopes to check or None + """ + return not self.is_expired() and self.allow_scopes(scopes) + + def is_expired(self): + """ + Check token expiration with timezone awareness + """ + if not self.expires: + return True + + return timezone.now() >= self.expires + + def allow_scopes(self, scopes): + """ + Check if the token allows the provided scopes + + :param scopes: An iterable containing the scopes to check + """ + if not scopes: + return True + + provided_scopes = set(self.scope.split()) + resource_scopes = set(scopes) + + return resource_scopes.issubset(provided_scopes) + + def revoke(self): + """ + Convenience method to uniform tokens' interface, for now + simply remove this token from the database in order to revoke it. + """ + self.delete() + + @property + def scopes(self): + """ + Returns a dictionary of allowed scope names (as keys) with their descriptions (as values) + """ + all_scopes = get_scopes_backend().get_all_scopes() + token_scopes = self.scope.split() + return {name: desc for name, desc in all_scopes.items() if name in token_scopes} + + def __str__(self): + return "JTI: {self.jti} User: {self.user_id}".format(self=self) + + class Meta: + abstract = True + + +class IDToken(AbstractIDToken): + class Meta(AbstractIDToken.Meta): + swappable = "OAUTH2_PROVIDER_ID_TOKEN_MODEL" + + def get_application_model(): """ Return the Application model that is active in this project. """ return apps.get_model(oauth2_settings.APPLICATION_MODEL) @@ -445,6 +593,11 @@ def get_access_token_model(): return apps.get_model(oauth2_settings.ACCESS_TOKEN_MODEL) +def get_id_token_model(): + """ Return the AccessToken model that is active in this project. """ + return apps.get_model(oauth2_settings.ID_TOKEN_MODEL) + + def get_refresh_token_model(): """ Return the RefreshToken model that is active in this project. """ return apps.get_model(oauth2_settings.REFRESH_TOKEN_MODEL) @@ -468,6 +621,12 @@ def get_grant_admin_class(): return grant_admin_class +def get_id_token_admin_class(): + """ Return the IDToken admin class that is active in this project. """ + id_token_admin_class = oauth2_settings.ID_TOKEN_ADMIN_CLASS + return id_token_admin_class + + def get_refresh_token_admin_class(): """ Return the RefreshToken admin class that is active in this project. """ refresh_token_admin_class = oauth2_settings.REFRESH_TOKEN_ADMIN_CLASS diff --git a/oauth2_provider/oauth2_backends.py b/oauth2_provider/oauth2_backends.py index 34b1c62cd..dbebd3a8e 100644 --- a/oauth2_provider/oauth2_backends.py +++ b/oauth2_provider/oauth2_backends.py @@ -4,6 +4,7 @@ from oauthlib import oauth2 from oauthlib.common import Request as OauthlibRequest from oauthlib.common import quote, urlencode, urlencoded +from oauthlib.oauth2 import OAuth2Error from .exceptions import FatalClientError, OAuthToolkitError from .settings import oauth2_settings @@ -74,6 +75,10 @@ def extract_headers(self, request): del headers["wsgi.errors"] if "HTTP_AUTHORIZATION" in headers: headers["Authorization"] = headers["HTTP_AUTHORIZATION"] + if request.is_secure(): + headers["X_DJANGO_OAUTH_TOOLKIT_SECURE"] = "1" + elif "X_DJANGO_OAUTH_TOOLKIT_SECURE" in headers: + del headers["X_DJANGO_OAUTH_TOOLKIT_SECURE"] return headers @@ -120,9 +125,14 @@ def create_authorization_response(self, request, scopes, credentials, allow): # add current user to credentials. this will be used by OAUTH2_VALIDATOR_CLASS credentials["user"] = request.user + request_uri, http_method, _, request_headers = self._extract_params(request) headers, body, status = self.server.create_authorization_response( - uri=credentials["redirect_uri"], scopes=scopes, credentials=credentials + uri=request_uri, + http_method=http_method, + headers=request_headers, + scopes=scopes, + credentials=credentials, ) uri = headers.get("Location", None) @@ -163,6 +173,21 @@ def create_revocation_response(self, request): return uri, headers, body, status + def create_userinfo_response(self, request): + """ + A wrapper method that calls create_userinfo_response on a + `server_class` instance. + + :param request: The current django.http.HttpRequest object + """ + uri, http_method, body, headers = self._extract_params(request) + try: + headers, body, status = self.server.create_userinfo_response(uri, http_method, body, headers) + uri = headers.get("Location", None) + return uri, headers, body, status + except OAuth2Error as exc: + return None, exc.headers, exc.json, exc.status_code + def verify_request(self, request, scopes): """ A wrapper method that calls verify_request on `server_class` instance. diff --git a/oauth2_provider/oauth2_validators.py b/oauth2_provider/oauth2_validators.py index de707bb21..f91c06011 100644 --- a/oauth2_provider/oauth2_validators.py +++ b/oauth2_provider/oauth2_validators.py @@ -1,7 +1,9 @@ import base64 import binascii import http.client +import json import logging +import uuid from collections import OrderedDict from datetime import datetime, timedelta from urllib.parse import unquote_plus @@ -12,10 +14,14 @@ from django.core.exceptions import ObjectDoesNotExist from django.db import transaction from django.db.models import Q -from django.utils import timezone +from django.utils import dateformat, timezone from django.utils.timezone import make_aware from django.utils.translation import gettext_lazy as _ -from oauthlib.oauth2 import RequestValidator +from jwcrypto import jws, jwt +from jwcrypto.common import JWException +from jwcrypto.jwt import JWTExpired +from oauthlib.oauth2.rfc6749 import utils +from oauthlib.openid import RequestValidator from .exceptions import FatalClientError from .models import ( @@ -23,6 +29,7 @@ get_access_token_model, get_application_model, get_grant_model, + get_id_token_model, get_refresh_token_model, ) from .scopes import get_scopes_backend @@ -32,18 +39,23 @@ log = logging.getLogger("oauth2_provider") GRANT_TYPE_MAPPING = { - "authorization_code": (AbstractApplication.GRANT_AUTHORIZATION_CODE,), + "authorization_code": ( + AbstractApplication.GRANT_AUTHORIZATION_CODE, + AbstractApplication.GRANT_OPENID_HYBRID, + ), "password": (AbstractApplication.GRANT_PASSWORD,), "client_credentials": (AbstractApplication.GRANT_CLIENT_CREDENTIALS,), "refresh_token": ( AbstractApplication.GRANT_AUTHORIZATION_CODE, AbstractApplication.GRANT_PASSWORD, AbstractApplication.GRANT_CLIENT_CREDENTIALS, + AbstractApplication.GRANT_OPENID_HYBRID, ), } Application = get_application_model() AccessToken = get_access_token_model() +IDToken = get_id_token_model() Grant = get_grant_model() RefreshToken = get_refresh_token_model() UserModel = get_user_model() @@ -370,10 +382,7 @@ def validate_bearer_token(self, token, scopes, request): introspection_token = oauth2_settings.RESOURCE_SERVER_AUTH_TOKEN introspection_credentials = oauth2_settings.RESOURCE_SERVER_INTROSPECTION_CREDENTIALS - try: - access_token = AccessToken.objects.select_related("application", "user").get(token=token) - except AccessToken.DoesNotExist: - access_token = None + access_token = self._load_access_token(token) # if there is no token or it's invalid then introspect the token if there's an external OAuth server if not access_token or not access_token.is_valid(scopes): @@ -394,12 +403,19 @@ def validate_bearer_token(self, token, scopes, request): self._set_oauth2_error_on_request(request, access_token, scopes) return False + def _load_access_token(self, token): + return AccessToken.objects.select_related("application", "user").filter(token=token).first() + def validate_code(self, client_id, code, client, request, *args, **kwargs): try: grant = Grant.objects.get(code=code, application=client) if not grant.is_expired(): request.scopes = grant.scope.split(" ") request.user = grant.user + if grant.nonce: + request.nonce = grant.nonce + if grant.claims: + request.claims = json.loads(grant.claims) return True return False @@ -422,6 +438,16 @@ def validate_response_type(self, client_id, response_type, client, request, *arg return client.allows_grant_type(AbstractApplication.GRANT_AUTHORIZATION_CODE) elif response_type == "token": return client.allows_grant_type(AbstractApplication.GRANT_IMPLICIT) + elif response_type == "id_token": + return client.allows_grant_type(AbstractApplication.GRANT_IMPLICIT) + elif response_type == "id_token token": + return client.allows_grant_type(AbstractApplication.GRANT_IMPLICIT) + elif response_type == "code id_token": + return client.allows_grant_type(AbstractApplication.GRANT_OPENID_HYBRID) + elif response_type == "code token": + return client.allows_grant_type(AbstractApplication.GRANT_OPENID_HYBRID) + elif response_type == "code id_token token": + return client.allows_grant_type(AbstractApplication.GRANT_OPENID_HYBRID) else: return False @@ -461,6 +487,12 @@ def get_code_challenge_method(self, code, request): def save_authorization_code(self, client_id, code, request, *args, **kwargs): self._create_authorization_code(request, code) + def get_authorization_code_scopes(self, client_id, code, redirect_uri, request): + scopes = Grant.objects.filter(code=code).values_list("scope", flat=True).first() + if scopes: + return utils.scope_to_list(scopes) + return [] + def rotate_refresh_token(self, request): """ Checks if rotate refresh token is enabled @@ -570,11 +602,15 @@ def save_bearer_token(self, token, request, *args, **kwargs): self._create_access_token(expires, request, token) def _create_access_token(self, expires, request, token, source_refresh_token=None): + id_token = token.get("id_token", None) + if id_token: + id_token = self._load_id_token(id_token) return AccessToken.objects.create( user=request.user, scope=token["scope"], expires=expires, token=token["access_token"], + id_token=id_token, application=request.client, source_refresh_token=source_refresh_token, ) @@ -582,7 +618,6 @@ def _create_access_token(self, expires, request, token, source_refresh_token=Non def _create_authorization_code(self, request, code, expires=None): if not expires: expires = timezone.now() + timedelta(seconds=oauth2_settings.AUTHORIZATION_CODE_EXPIRE_SECONDS) - return Grant.objects.create( application=request.client, user=request.user, @@ -592,6 +627,8 @@ def _create_authorization_code(self, request, code, expires=None): scope=" ".join(request.scopes), code_challenge=request.code_challenge or "", code_challenge_method=request.code_challenge_method or "", + nonce=request.nonce or "", + claims=json.dumps(request.claims or {}), ) def _create_refresh_token(self, request, refresh_token_code, access_token): @@ -665,3 +702,183 @@ def validate_refresh_token(self, refresh_token, client, request, *args, **kwargs # Temporary store RefreshToken instance to be reused by get_original_scopes and save_bearer_token. request.refresh_token_instance = rt return rt.application == client + + @transaction.atomic + def _save_id_token(self, jti, request, expires, *args, **kwargs): + scopes = request.scope or " ".join(request.scopes) + + id_token = IDToken.objects.create( + user=request.user, + scope=scopes, + expires=expires, + jti=jti, + application=request.client, + ) + return id_token + + def get_jwt_bearer_token(self, token, token_handler, request): + return self.get_id_token(token, token_handler, request) + + def get_oidc_claims(self, token, token_handler, request): + # Required OIDC claims + claims = { + "sub": str(request.user.id), + } + + # https://openid.net/specs/openid-connect-core-1_0.html#StandardClaims + claims.update(**self.get_additional_claims(request)) + + return claims + + def get_id_token_dictionary(self, token, token_handler, request): + """ + Get the claims to put in the ID Token. + + These claims are in addition to the claims automatically added by + ``oauthlib`` - aud, iat, nonce, at_hash, c_hash. + + This function adds in iss, exp and auth_time, plus any claims added from + calling ``get_oidc_claims()`` + """ + claims = self.get_oidc_claims(token, token_handler, request) + + expiration_time = timezone.now() + timedelta(seconds=oauth2_settings.ID_TOKEN_EXPIRE_SECONDS) + # Required ID Token claims + claims.update( + **{ + "iss": self.get_oidc_issuer_endpoint(request), + "exp": int(dateformat.format(expiration_time, "U")), + "auth_time": int(dateformat.format(request.user.last_login, "U")), + "jti": str(uuid.uuid4()), + } + ) + + return claims, expiration_time + + def get_oidc_issuer_endpoint(self, request): + return oauth2_settings.oidc_issuer(request) + + def finalize_id_token(self, id_token, token, token_handler, request): + claims, expiration_time = self.get_id_token_dictionary(token, token_handler, request) + id_token.update(**claims) + # Workaround for oauthlib bug #746 + # https://github.com/oauthlib/oauthlib/issues/746 + if "nonce" not in id_token and request.nonce: + id_token["nonce"] = request.nonce + + header = { + "typ": "JWT", + "alg": request.client.algorithm, + } + # RS256 consumers expect a kid in the header for verifying the token + if request.client.algorithm == AbstractApplication.RS256_ALGORITHM: + header["kid"] = request.client.jwk_key.thumbprint() + + jwt_token = jwt.JWT( + header=json.dumps(header, default=str), + claims=json.dumps(id_token, default=str), + ) + jwt_token.make_signed_token(request.client.jwk_key) + id_token = self._save_id_token(id_token["jti"], request, expiration_time) + # this is needed by django rest framework + request.access_token = id_token + request.id_token = id_token + return jwt_token.serialize() + + def validate_jwt_bearer_token(self, token, scopes, request): + return self.validate_id_token(token, scopes, request) + + def validate_id_token(self, token, scopes, request): + """ + When users try to access resources, check that provided id_token is valid + """ + if not token: + return False + + id_token = self._load_id_token(token) + if not id_token: + return False + + if not id_token.allow_scopes(scopes): + return False + + request.client = id_token.application + request.user = id_token.user + request.scopes = scopes + # this is needed by django rest framework + request.access_token = id_token + return True + + def _load_id_token(self, token): + key = self._get_key_for_token(token) + if not key: + return None + try: + jwt_token = jwt.JWT(key=key, jwt=token) + claims = json.loads(jwt_token.claims) + return IDToken.objects.get(jti=claims["jti"]) + except (JWException, JWTExpired, IDToken.DoesNotExist): + return None + + def _get_key_for_token(self, token): + """ + Peek at the unvalidated token to discover who it was issued for + and then use that to load that application and its key. + """ + unverified_token = jws.JWS() + unverified_token.deserialize(token) + claims = json.loads(unverified_token.objects["payload"].decode("utf-8")) + if "aud" not in claims: + return None + application = self._get_client_by_audience(claims["aud"]) + if application: + return application.jwk_key + + def _get_client_by_audience(self, audience): + """ + Load a client by the aud claim in a JWT. + aud may be multi-valued, if your provider makes it so. + This function is separate to allow further customization. + """ + if isinstance(audience, str): + audience = [audience] + return Application.objects.filter(client_id__in=audience).first() + + def validate_user_match(self, id_token_hint, scopes, claims, request): + # TODO: Fix to validate when necessary acording + # https://github.com/idan/oauthlib/blob/master/oauthlib/oauth2/rfc6749/request_validator.py#L556 + # http://openid.net/specs/openid-connect-core-1_0.html#AuthRequest id_token_hint section + return True + + def get_authorization_code_nonce(self, client_id, code, redirect_uri, request): + """Extracts nonce from saved authorization code. + If present in the Authentication Request, Authorization + Servers MUST include a nonce Claim in the ID Token with the + Claim Value being the nonce value sent in the Authentication + Request. Authorization Servers SHOULD perform no other + processing on nonce values used. The nonce value is a + case-sensitive string. + Only code param should be sufficient to retrieve grant code from + any storage you are using. However, `client_id` and `redirect_uri` + have been validated and can be used also. + :param client_id: Unicode client identifier + :param code: Unicode authorization code grant + :param redirect_uri: Unicode absolute URI + :return: Unicode nonce + Method is used by: + - Authorization Token Grant Dispatcher + """ + nonce = Grant.objects.filter(code=code).values_list("nonce", flat=True).first() + if nonce: + return nonce + + def get_userinfo_claims(self, request): + """ + Generates and saves a new JWT for this request, and returns it as the + current user's claims. + + """ + return self.get_oidc_claims(None, None, request) + + def get_additional_claims(self, request): + return {} diff --git a/oauth2_provider/settings.py b/oauth2_provider/settings.py index 5d81a05ef..b862fca7a 100644 --- a/oauth2_provider/settings.py +++ b/oauth2_provider/settings.py @@ -18,14 +18,18 @@ from django.conf import settings from django.core.exceptions import ImproperlyConfigured +from django.http import HttpRequest from django.test.signals import setting_changed +from django.urls import reverse from django.utils.module_loading import import_string +from oauthlib.common import Request USER_SETTINGS = getattr(settings, "OAUTH2_PROVIDER", None) APPLICATION_MODEL = getattr(settings, "OAUTH2_PROVIDER_APPLICATION_MODEL", "oauth2_provider.Application") ACCESS_TOKEN_MODEL = getattr(settings, "OAUTH2_PROVIDER_ACCESS_TOKEN_MODEL", "oauth2_provider.AccessToken") +ID_TOKEN_MODEL = getattr(settings, "OAUTH2_PROVIDER_ID_TOKEN_MODEL", "oauth2_provider.IDToken") GRANT_MODEL = getattr(settings, "OAUTH2_PROVIDER_GRANT_MODEL", "oauth2_provider.Grant") REFRESH_TOKEN_MODEL = getattr(settings, "OAUTH2_PROVIDER_REFRESH_TOKEN_MODEL", "oauth2_provider.RefreshToken") @@ -37,6 +41,7 @@ "REFRESH_TOKEN_GENERATOR": None, "EXTRA_SERVER_KWARGS": {}, "OAUTH2_SERVER_CLASS": "oauthlib.oauth2.Server", + "OIDC_SERVER_CLASS": "oauthlib.openid.Server", "OAUTH2_VALIDATOR_CLASS": "oauth2_provider.oauth2_validators.OAuth2Validator", "OAUTH2_BACKEND_CLASS": "oauth2_provider.oauth2_backends.OAuthLibCore", "SCOPES": {"read": "Reading scope", "write": "Writing scope"}, @@ -46,20 +51,41 @@ "WRITE_SCOPE": "write", "AUTHORIZATION_CODE_EXPIRE_SECONDS": 60, "ACCESS_TOKEN_EXPIRE_SECONDS": 36000, + "ID_TOKEN_EXPIRE_SECONDS": 36000, "REFRESH_TOKEN_EXPIRE_SECONDS": None, "REFRESH_TOKEN_GRACE_PERIOD_SECONDS": 0, "ROTATE_REFRESH_TOKEN": True, "ERROR_RESPONSE_WITH_SCOPES": False, "APPLICATION_MODEL": APPLICATION_MODEL, "ACCESS_TOKEN_MODEL": ACCESS_TOKEN_MODEL, + "ID_TOKEN_MODEL": ID_TOKEN_MODEL, "GRANT_MODEL": GRANT_MODEL, "REFRESH_TOKEN_MODEL": REFRESH_TOKEN_MODEL, "APPLICATION_ADMIN_CLASS": "oauth2_provider.admin.ApplicationAdmin", "ACCESS_TOKEN_ADMIN_CLASS": "oauth2_provider.admin.AccessTokenAdmin", "GRANT_ADMIN_CLASS": "oauth2_provider.admin.GrantAdmin", + "ID_TOKEN_ADMIN_CLASS": "oauth2_provider.admin.IDTokenAdmin", "REFRESH_TOKEN_ADMIN_CLASS": "oauth2_provider.admin.RefreshTokenAdmin", "REQUEST_APPROVAL_PROMPT": "force", "ALLOWED_REDIRECT_URI_SCHEMES": ["http", "https"], + "OIDC_ENABLED": False, + "OIDC_ISS_ENDPOINT": "", + "OIDC_USERINFO_ENDPOINT": "", + "OIDC_RSA_PRIVATE_KEY": "", + "OIDC_RESPONSE_TYPES_SUPPORTED": [ + "code", + "token", + "id_token", + "id_token token", + "code token", + "code id_token", + "code id_token token", + ], + "OIDC_SUBJECT_TYPES_SUPPORTED": ["public"], + "OIDC_TOKEN_ENDPOINT_AUTH_METHODS_SUPPORTED": [ + "client_secret_post", + "client_secret_basic", + ], # Special settings that will be evaluated at runtime "_SCOPES": [], "_DEFAULT_SCOPES": [], @@ -70,6 +96,9 @@ "RESOURCE_SERVER_TOKEN_CACHING_SECONDS": 36000, # Whether or not PKCE is required "PKCE_REQUIRED": False, + # Whether to re-create OAuthlibCore on every request. + # Should only be required in testing. + "ALWAYS_RELOAD_OAUTHLIB_CORE": False, } # List of settings that cannot be empty @@ -81,6 +110,9 @@ "OAUTH2_BACKEND_CLASS", "SCOPES", "ALLOWED_REDIRECT_URI_SCHEMES", + "OIDC_RESPONSE_TYPES_SUPPORTED", + "OIDC_SUBJECT_TYPES_SUPPORTED", + "OIDC_TOKEN_ENDPOINT_AUTH_METHODS_SUPPORTED", ) # List of settings that may be in string import notation. @@ -96,6 +128,7 @@ "APPLICATION_ADMIN_CLASS", "ACCESS_TOKEN_ADMIN_CLASS", "GRANT_ADMIN_CLASS", + "ID_TOKEN_ADMIN_CLASS", "REFRESH_TOKEN_ADMIN_CLASS", ) @@ -125,6 +158,13 @@ def import_from_string(val, setting_name): raise ImportError(msg) +class _PhonyHttpRequest(HttpRequest): + _scheme = "http" + + def _get_scheme(self): + return self._scheme + + class OAuth2ProviderSettings: """ A settings object, that allows OAuth2 Provider settings to be accessed as properties. @@ -149,13 +189,17 @@ def user_settings(self): def __getattr__(self, attr): if attr not in self.defaults: raise AttributeError("Invalid OAuth2Provider setting: %s" % attr) - try: # Check if present in user settings val = self.user_settings[attr] except KeyError: # Fall back to defaults - val = self.defaults[attr] + # Special case OAUTH2_SERVER_CLASS - if not specified, and OIDC is + # enabled, use the OIDC_SERVER_CLASS setting instead + if attr == "OAUTH2_SERVER_CLASS" and self.OIDC_ENABLED: + val = self.defaults["OIDC_SERVER_CLASS"] + else: + val = self.defaults[attr] # Coerce import strings into classes if val and attr in self.import_strings: @@ -221,6 +265,28 @@ def reload(self): if hasattr(self, "_user_settings"): delattr(self, "_user_settings") + def oidc_issuer(self, request): + """ + Helper function to get the OIDC issuer URL, either from the settings + or constructing it from the passed request. + + If only an oauthlib request is available, a dummy django request is + built from that and used to generate the URL. + """ + if self.OIDC_ISS_ENDPOINT: + return self.OIDC_ISS_ENDPOINT + if isinstance(request, HttpRequest): + django_request = request + elif isinstance(request, Request): + django_request = _PhonyHttpRequest() + django_request.META = request.headers + if request.headers.get("X_DJANGO_OAUTH_TOOLKIT_SECURE", False): + django_request._scheme = "https" + else: + raise TypeError("request must be a django or oauthlib request: got %r" % request) + abs_url = django_request.build_absolute_uri(reverse("oauth2_provider:oidc-connect-discovery-info")) + return abs_url[: -len("/.well-known/openid-configuration/")] + oauth2_settings = OAuth2ProviderSettings(USER_SETTINGS, DEFAULTS, IMPORT_STRINGS, MANDATORY) diff --git a/oauth2_provider/urls.py b/oauth2_provider/urls.py index c7ae526f0..508f97c96 100644 --- a/oauth2_provider/urls.py +++ b/oauth2_provider/urls.py @@ -30,5 +30,15 @@ ), ] +oidc_urlpatterns = [ + re_path( + r"^\.well-known/openid-configuration/$", + views.ConnectDiscoveryInfoView.as_view(), + name="oidc-connect-discovery-info", + ), + re_path(r"^\.well-known/jwks.json$", views.JwksInfoView.as_view(), name="jwks-info"), + re_path(r"^userinfo/$", views.UserInfoView.as_view(), name="user-info"), +] + -urlpatterns = base_urlpatterns + management_urlpatterns +urlpatterns = base_urlpatterns + management_urlpatterns + oidc_urlpatterns diff --git a/oauth2_provider/views/__init__.py b/oauth2_provider/views/__init__.py index 6d5d74c67..0720c1aa2 100644 --- a/oauth2_provider/views/__init__.py +++ b/oauth2_provider/views/__init__.py @@ -15,4 +15,5 @@ ScopedProtectedResourceView, ) from .introspect import IntrospectTokenView +from .oidc import ConnectDiscoveryInfoView, JwksInfoView, UserInfoView from .token import AuthorizedTokenDeleteView, AuthorizedTokensListView diff --git a/oauth2_provider/views/application.py b/oauth2_provider/views/application.py index 186097ae4..e9a21a99f 100644 --- a/oauth2_provider/views/application.py +++ b/oauth2_provider/views/application.py @@ -37,6 +37,7 @@ def get_form_class(self): "client_type", "authorization_grant_type", "redirect_uris", + "algorithm", ), ) @@ -94,5 +95,6 @@ def get_form_class(self): "client_type", "authorization_grant_type", "redirect_uris", + "algorithm", ), ) diff --git a/oauth2_provider/views/base.py b/oauth2_provider/views/base.py index 104413787..e46a49d10 100644 --- a/oauth2_provider/views/base.py +++ b/oauth2_provider/views/base.py @@ -90,10 +90,6 @@ class AuthorizationView(BaseAuthorizationView, FormView): template_name = "oauth2_provider/authorize.html" form_class = AllowForm - server_class = oauth2_settings.OAUTH2_SERVER_CLASS - validator_class = oauth2_settings.OAUTH2_VALIDATOR_CLASS - oauthlib_backend_class = oauth2_settings.OAUTH2_BACKEND_CLASS - skip_authorization_completely = False def get_initial(self): @@ -102,11 +98,13 @@ def get_initial(self): initial_data = { "redirect_uri": self.oauth2_data.get("redirect_uri", None), "scope": " ".join(scopes), + "nonce": self.oauth2_data.get("nonce", None), "client_id": self.oauth2_data.get("client_id", None), "state": self.oauth2_data.get("state", None), "response_type": self.oauth2_data.get("response_type", None), "code_challenge": self.oauth2_data.get("code_challenge", None), "code_challenge_method": self.oauth2_data.get("code_challenge_method", None), + "claims": self.oauth2_data.get("claims", None), } return initial_data @@ -123,6 +121,11 @@ def form_valid(self, form): credentials["code_challenge"] = form.cleaned_data.get("code_challenge") if form.cleaned_data.get("code_challenge_method", False): credentials["code_challenge_method"] = form.cleaned_data.get("code_challenge_method") + if form.cleaned_data.get("nonce", False): + credentials["nonce"] = form.cleaned_data.get("nonce") + if form.cleaned_data.get("claims", False): + credentials["claims"] = form.cleaned_data.get("claims") + scopes = form.cleaned_data.get("scope") allow = form.cleaned_data.get("allow") @@ -161,6 +164,10 @@ def get(self, request, *args, **kwargs): kwargs["code_challenge"] = credentials["code_challenge"] if "code_challenge_method" in credentials: kwargs["code_challenge_method"] = credentials["code_challenge_method"] + if "nonce" in credentials: + kwargs["nonce"] = credentials["nonce"] + if "claims" in credentials: + kwargs["claims"] = json.dumps(credentials["claims"]) self.oauth2_data = kwargs # following two loc are here only because of https://code.djangoproject.com/ticket/17795 @@ -195,7 +202,10 @@ def get(self, request, *args, **kwargs): for token in tokens: if token.allow_scopes(scopes): uri, headers, body, status = self.create_authorization_response( - request=self.request, scopes=" ".join(scopes), credentials=credentials, allow=True + request=self.request, + scopes=" ".join(scopes), + credentials=credentials, + allow=True, ) return self.redirect(uri, application, token) @@ -245,10 +255,6 @@ class TokenView(OAuthLibMixin, View): * Client credentials """ - server_class = oauth2_settings.OAUTH2_SERVER_CLASS - validator_class = oauth2_settings.OAUTH2_VALIDATOR_CLASS - oauthlib_backend_class = oauth2_settings.OAUTH2_BACKEND_CLASS - @method_decorator(sensitive_post_parameters("password")) def post(self, request, *args, **kwargs): url, headers, body, status = self.create_token_response(request) @@ -270,10 +276,6 @@ class RevokeTokenView(OAuthLibMixin, View): Implements an endpoint to revoke access or refresh tokens """ - server_class = oauth2_settings.OAUTH2_SERVER_CLASS - validator_class = oauth2_settings.OAUTH2_VALIDATOR_CLASS - oauthlib_backend_class = oauth2_settings.OAUTH2_BACKEND_CLASS - def post(self, request, *args, **kwargs): url, headers, body, status = self.create_revocation_response(request) response = HttpResponse(content=body or "", status=status) diff --git a/oauth2_provider/views/generic.py b/oauth2_provider/views/generic.py index 10e84d59f..da675eac4 100644 --- a/oauth2_provider/views/generic.py +++ b/oauth2_provider/views/generic.py @@ -1,6 +1,5 @@ from django.views.generic import View -from ..settings import oauth2_settings from .mixins import ( ClientProtectedResourceMixin, OAuthLibMixin, @@ -10,16 +9,7 @@ ) -class InitializationMixin(OAuthLibMixin): - - """Initializer for OauthLibMixin""" - - server_class = oauth2_settings.OAUTH2_SERVER_CLASS - validator_class = oauth2_settings.OAUTH2_VALIDATOR_CLASS - oauthlib_backend_class = oauth2_settings.OAUTH2_BACKEND_CLASS - - -class ProtectedResourceView(ProtectedResourceMixin, InitializationMixin, View): +class ProtectedResourceView(ProtectedResourceMixin, OAuthLibMixin, View): """ Generic view protecting resources by providing OAuth2 authentication out of the box """ @@ -45,7 +35,7 @@ class ReadWriteScopedResourceView(ReadWriteScopedResourceMixin, ProtectedResourc pass -class ClientProtectedResourceView(ClientProtectedResourceMixin, InitializationMixin, View): +class ClientProtectedResourceView(ClientProtectedResourceMixin, OAuthLibMixin, View): """View for protecting a resource with client-credentials method. This involves allowing access tokens, Basic Auth and plain credentials in request body. diff --git a/oauth2_provider/views/introspect.py b/oauth2_provider/views/introspect.py index d29605097..afb8ac627 100644 --- a/oauth2_provider/views/introspect.py +++ b/oauth2_provider/views/introspect.py @@ -7,7 +7,7 @@ from django.views.decorators.csrf import csrf_exempt from oauth2_provider.models import get_access_token_model -from oauth2_provider.views import ClientProtectedScopedResourceView +from oauth2_provider.views.generic import ClientProtectedScopedResourceView @method_decorator(csrf_exempt, name="dispatch") diff --git a/oauth2_provider/views/mixins.py b/oauth2_provider/views/mixins.py index 0a0c66ea9..477d24e24 100644 --- a/oauth2_provider/views/mixins.py +++ b/oauth2_provider/views/mixins.py @@ -1,7 +1,8 @@ import logging +from django.conf import settings from django.core.exceptions import ImproperlyConfigured -from django.http import HttpResponseForbidden +from django.http import HttpResponseForbidden, HttpResponseNotFound from ..exceptions import FatalClientError from ..scopes import get_scopes_backend @@ -25,6 +26,9 @@ class OAuthLibMixin: * validator_class * oauthlib_backend_class + If these class variables are not set, it will fall back to using the classes + specified in oauth2_settings (OAUTH2_SERVER_CLASS, OAUTH2_VALIDATOR_CLASS + and OAUTH2_BACKEND_CLASS). """ server_class = None @@ -37,10 +41,7 @@ def get_server_class(cls): Return the OAuthlib server class to use """ if cls.server_class is None: - raise ImproperlyConfigured( - "OAuthLibMixin requires either a definition of 'server_class'" - " or an implementation of 'get_server_class()'" - ) + return oauth2_settings.OAUTH2_SERVER_CLASS else: return cls.server_class @@ -50,10 +51,7 @@ def get_validator_class(cls): Return the RequestValidator implementation class to use """ if cls.validator_class is None: - raise ImproperlyConfigured( - "OAuthLibMixin requires either a definition of 'validator_class'" - " or an implementation of 'get_validator_class()'" - ) + return oauth2_settings.OAUTH2_VALIDATOR_CLASS else: return cls.validator_class @@ -63,10 +61,7 @@ def get_oauthlib_backend_class(cls): Return the OAuthLibCore implementation class to use """ if cls.oauthlib_backend_class is None: - raise ImproperlyConfigured( - "OAuthLibMixin requires either a definition of 'oauthlib_backend_class'" - " or an implementation of 'get_oauthlib_backend_class()'" - ) + return oauth2_settings.OAUTH2_BACKEND_CLASS else: return cls.oauthlib_backend_class @@ -85,8 +80,9 @@ def get_server(cls): def get_oauthlib_core(cls): """ Cache and return `OAuthlibCore` instance so it will be created only on first request + unless ALWAYS_RELOAD_OAUTHLIB_CORE is True. """ - if not hasattr(cls, "_oauthlib_core"): + if not hasattr(cls, "_oauthlib_core") or oauth2_settings.ALWAYS_RELOAD_OAUTHLIB_CORE: server = cls.get_server() core_class = cls.get_oauthlib_backend_class() cls._oauthlib_core = core_class(server) @@ -109,7 +105,7 @@ def create_authorization_response(self, request, scopes, credentials, allow): :param request: The current django.http.HttpRequest object :param scopes: A space-separated string of provided scopes :param credentials: Authorization credentials dictionary containing - `client_id`, `state`, `redirect_uri`, `response_type` + `client_id`, `state`, `redirect_uri` and `response_type` :param allow: True if the user authorize the client, otherwise False """ # TODO: move this scopes conversion from and to string into a utils function @@ -137,6 +133,16 @@ def create_revocation_response(self, request): core = self.get_oauthlib_core() return core.create_revocation_response(request) + def create_userinfo_response(self, request): + """ + A wrapper method that calls create_userinfo_response on the + `server_class` instance. + + :param request: The current django.http.HttpRequest object + """ + core = self.get_oauthlib_core() + return core.create_userinfo_response(request) + def verify_request(self, request): """ A wrapper method that calls verify_request on `server_class` instance. @@ -286,7 +292,30 @@ def dispatch(self, request, *args, **kwargs): if valid: request.resource_owner = r.user return super().dispatch(request, *args, **kwargs) - else: - return HttpResponseForbidden() + return HttpResponseForbidden() else: return super().dispatch(request, *args, **kwargs) + + +class OIDCOnlyMixin: + """ + Mixin for views that should only be accessible when OIDC is enabled. + + If OIDC is not enabled: + + * if DEBUG is True, raises an ImproperlyConfigured exception explaining why + * otherwise, returns a 404 response, logging the same warning + """ + + debug_error_message = ( + "django-oauth-toolkit OIDC views are not enabled unless you " + "have configured OIDC_ENABLED in the settings" + ) + + def dispatch(self, *args, **kwargs): + if not oauth2_settings.OIDC_ENABLED: + if settings.DEBUG: + raise ImproperlyConfigured(self.debug_error_message) + log.warning(self.debug_error_message) + return HttpResponseNotFound() + return super().dispatch(*args, **kwargs) diff --git a/oauth2_provider/views/oidc.py b/oauth2_provider/views/oidc.py new file mode 100644 index 000000000..ac3a2a172 --- /dev/null +++ b/oauth2_provider/views/oidc.py @@ -0,0 +1,97 @@ +import json + +from django.http import HttpResponse, JsonResponse +from django.urls import reverse +from django.utils.decorators import method_decorator +from django.views.decorators.csrf import csrf_exempt +from django.views.generic import View +from jwcrypto import jwk + +from ..models import get_application_model +from ..settings import oauth2_settings +from .mixins import OAuthLibMixin, OIDCOnlyMixin + + +Application = get_application_model() + + +class ConnectDiscoveryInfoView(OIDCOnlyMixin, View): + """ + View used to show oidc provider configuration information + """ + + def get(self, request, *args, **kwargs): + issuer_url = oauth2_settings.OIDC_ISS_ENDPOINT + + if not issuer_url: + issuer_url = oauth2_settings.oidc_issuer(request) + authorization_endpoint = request.build_absolute_uri(reverse("oauth2_provider:authorize")) + token_endpoint = request.build_absolute_uri(reverse("oauth2_provider:token")) + userinfo_endpoint = oauth2_settings.OIDC_USERINFO_ENDPOINT or request.build_absolute_uri( + reverse("oauth2_provider:user-info") + ) + jwks_uri = request.build_absolute_uri(reverse("oauth2_provider:jwks-info")) + else: + authorization_endpoint = "{}{}".format(issuer_url, reverse("oauth2_provider:authorize")) + token_endpoint = "{}{}".format(issuer_url, reverse("oauth2_provider:token")) + userinfo_endpoint = oauth2_settings.OIDC_USERINFO_ENDPOINT or "{}{}".format( + issuer_url, reverse("oauth2_provider:user-info") + ) + jwks_uri = "{}{}".format(issuer_url, reverse("oauth2_provider:jwks-info")) + signing_algorithms = [Application.HS256_ALGORITHM] + if oauth2_settings.OIDC_RSA_PRIVATE_KEY: + signing_algorithms = [Application.RS256_ALGORITHM, Application.HS256_ALGORITHM] + data = { + "issuer": issuer_url, + "authorization_endpoint": authorization_endpoint, + "token_endpoint": token_endpoint, + "userinfo_endpoint": userinfo_endpoint, + "jwks_uri": jwks_uri, + "response_types_supported": oauth2_settings.OIDC_RESPONSE_TYPES_SUPPORTED, + "subject_types_supported": oauth2_settings.OIDC_SUBJECT_TYPES_SUPPORTED, + "id_token_signing_alg_values_supported": signing_algorithms, + "token_endpoint_auth_methods_supported": ( + oauth2_settings.OIDC_TOKEN_ENDPOINT_AUTH_METHODS_SUPPORTED + ), + } + response = JsonResponse(data) + response["Access-Control-Allow-Origin"] = "*" + return response + + +class JwksInfoView(OIDCOnlyMixin, View): + """ + View used to show oidc json web key set document + """ + + def get(self, request, *args, **kwargs): + keys = [] + if oauth2_settings.OIDC_RSA_PRIVATE_KEY: + key = jwk.JWK.from_pem(oauth2_settings.OIDC_RSA_PRIVATE_KEY.encode("utf8")) + data = {"alg": "RS256", "use": "sig", "kid": key.thumbprint()} + data.update(json.loads(key.export_public())) + keys.append(data) + response = JsonResponse({"keys": keys}) + response["Access-Control-Allow-Origin"] = "*" + return response + + +@method_decorator(csrf_exempt, name="dispatch") +class UserInfoView(OIDCOnlyMixin, OAuthLibMixin, View): + """ + View used to show Claims about the authenticated End-User + """ + + def get(self, request, *args, **kwargs): + return self._create_userinfo_response(request) + + def post(self, request, *args, **kwargs): + return self._create_userinfo_response(request) + + def _create_userinfo_response(self, request): + url, headers, body, status = self.create_userinfo_response(request) + response = HttpResponse(content=body or "", status=status) + + for k, v in headers.items(): + response[k] = v + return response diff --git a/setup.cfg b/setup.cfg index 22e81675e..03d614a7f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -29,10 +29,14 @@ classifiers = packages = find: include_package_data = True zip_safe = False +# jwcrypto has a direct dependency on six, but does not list it yet in a release +# Previously, cryptography also depended on six, so this was unnoticed install_requires = django >= 2.2 requests >= 2.13.0 oauthlib >= 3.1.0 + jwcrypto >= 0.8.0 + six [options.packages.find] exclude = tests diff --git a/tests/admin.py b/tests/admin.py index 557434250..f071769ee 100644 --- a/tests/admin.py +++ b/tests/admin.py @@ -13,5 +13,9 @@ class CustomGrantAdmin(admin.ModelAdmin): list_display = ("id",) +class CustomIDTokenAdmin(admin.ModelAdmin): + list_display = ("id",) + + class CustomRefreshTokenAdmin(admin.ModelAdmin): list_display = ("id",) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 000000000..a3274aa33 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,156 @@ +from types import SimpleNamespace +from urllib.parse import parse_qs, urlparse + +import pytest +from django.conf import settings as test_settings +from django.contrib.auth import get_user_model +from django.urls import reverse +from jwcrypto import jwk + +from oauth2_provider.models import get_application_model +from oauth2_provider.settings import oauth2_settings as _oauth2_settings + +from . import presets + + +Application = get_application_model() +UserModel = get_user_model() + + +class OAuthSettingsWrapper: + """ + A wrapper around oauth2_settings to ensure that when an overridden value is + set, it also records it in _cached_attrs, so that the settings can be reset. + """ + + def __init__(self, settings, user_settings): + self.settings = settings + if not user_settings: + user_settings = {} + self.update(user_settings) + + def update(self, user_settings): + self.settings.OAUTH2_PROVIDER = user_settings + _oauth2_settings.reload() + # Reload OAuthlibCore for every view request during tests + self.ALWAYS_RELOAD_OAUTHLIB_CORE = True + + def __setattr__(self, attr, value): + if attr == "settings": + super().__setattr__(attr, value) + else: + setattr(_oauth2_settings, attr, value) + _oauth2_settings._cached_attrs.add(attr) + + def __delattr__(self, attr): + delattr(_oauth2_settings, attr) + if attr in _oauth2_settings._cached_attrs: + _oauth2_settings._cached_attrs.remove(attr) + + def __getattr__(self, attr): + return getattr(_oauth2_settings, attr) + + def finalize(self): + self.settings.finalize() + _oauth2_settings.reload() + + +@pytest.fixture +def oauth2_settings(request, settings): + """ + A fixture that provides a simple way to override OAUTH2_PROVIDER settings. + + It can be used two ways - either setting things on the fly, or by reading + configuration data from the pytest marker oauth2_settings. + + If used on a standard pytest function, you can use argument dependency + injection to get the wrapper. If used on a unittest.TestCase, the wrapper + is made available on the class instance, as `oauth2_settings`. + + Anything overridden will be restored at the end of the test case, ensuring + that there is no configuration leakage between test cases. + """ + marker = request.node.get_closest_marker("oauth2_settings") + user_settings = {} + if marker is not None: + user_settings = marker.args[0] + wrapper = OAuthSettingsWrapper(settings, user_settings) + if request.instance is not None: + request.instance.oauth2_settings = wrapper + yield wrapper + wrapper.finalize() + + +@pytest.fixture(scope="session") +def oidc_key_(): + return jwk.JWK.from_pem(test_settings.OIDC_RSA_PRIVATE_KEY.encode("utf8")) + + +@pytest.fixture +def oidc_key(request, oidc_key_): + if request.instance is not None: + request.instance.key = oidc_key_ + return oidc_key_ + + +@pytest.fixture +def application(): + return Application.objects.create( + name="Test Application", + redirect_uris="http://example.org", + client_type=Application.CLIENT_CONFIDENTIAL, + authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE, + algorithm=Application.RS256_ALGORITHM, + ) + + +@pytest.fixture +def hybrid_application(application): + application.authorization_grant_type = application.GRANT_OPENID_HYBRID + application.save() + return application + + +@pytest.fixture +def test_user(): + return UserModel.objects.create_user("test_user", "test@example.com", "123456") + + +@pytest.fixture +def oidc_tokens(oauth2_settings, application, test_user, client): + oauth2_settings.update(presets.OIDC_SETTINGS_RW) + client.force_login(test_user) + auth_rsp = client.post( + reverse("oauth2_provider:authorize"), + data={ + "client_id": application.client_id, + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.org", + "response_type": "code", + "allow": True, + }, + ) + assert auth_rsp.status_code == 302 + code = parse_qs(urlparse(auth_rsp["Location"]).query)["code"] + client.logout() + token_rsp = client.post( + reverse("oauth2_provider:token"), + data={ + "grant_type": "authorization_code", + "code": code, + "redirect_uri": "http://example.org", + "client_id": application.client_id, + "client_secret": application.client_secret, + "scope": "openid", + }, + ) + assert token_rsp.status_code == 200 + token_data = token_rsp.json() + return SimpleNamespace( + user=test_user, + application=application, + access_token=token_data["access_token"], + id_token=token_data["id_token"], + oauth2_settings=oauth2_settings, + ) diff --git a/tests/migrations/0001_initial.py b/tests/migrations/0001_initial.py index 60b17f2ae..8903a5a96 100644 --- a/tests/migrations/0001_initial.py +++ b/tests/migrations/0001_initial.py @@ -33,6 +33,8 @@ class Migration(migrations.Migration): ('custom_field', models.CharField(max_length=255)), ('application', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.OAUTH2_PROVIDER_APPLICATION_MODEL)), ('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='tests_samplegrant', to=settings.AUTH_USER_MODEL)), + ("nonce", models.CharField(blank=True, max_length=255, default="")), + ("claims", models.TextField(blank=True)), ], options={ 'abstract': False, @@ -45,7 +47,7 @@ class Migration(migrations.Migration): ('client_id', models.CharField(db_index=True, default=oauth2_provider.generators.generate_client_id, max_length=100, unique=True)), ('redirect_uris', models.TextField(blank=True, help_text='Allowed URIs list, space separated')), ('client_type', models.CharField(choices=[('confidential', 'Confidential'), ('public', 'Public')], max_length=32)), - ('authorization_grant_type', models.CharField(choices=[('authorization-code', 'Authorization code'), ('implicit', 'Implicit'), ('password', 'Resource owner password-based'), ('client-credentials', 'Client credentials')], max_length=32)), + ('authorization_grant_type', models.CharField(choices=[('authorization-code', 'Authorization code'), ('implicit', 'Implicit'), ('password', 'Resource owner password-based'), ('client-credentials', 'Client credentials'), ('openid-hybrid', 'OpenID connect hybrid')], max_length=32)), ('client_secret', models.CharField(blank=True, db_index=True, default=oauth2_provider.generators.generate_client_secret, max_length=255)), ('name', models.CharField(blank=True, max_length=255)), ('skip_authorization', models.BooleanField(default=False)), @@ -53,6 +55,7 @@ class Migration(migrations.Migration): ('updated', models.DateTimeField(auto_now=True)), ('custom_field', models.CharField(max_length=255)), ('user', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='tests_sampleapplication', to=settings.AUTH_USER_MODEL)), + ('algorithm', models.CharField(max_length=5, choices=[('RS256', 'RSA with SHA-2 256'), ('HS256', 'HMAC with SHA-2 256')], default='RS256')), ], options={ 'abstract': False, @@ -71,6 +74,7 @@ class Migration(migrations.Migration): ('application', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to=settings.OAUTH2_PROVIDER_APPLICATION_MODEL)), ('source_refresh_token', models.OneToOneField(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='s_refreshed_access_token', to=settings.OAUTH2_PROVIDER_REFRESH_TOKEN_MODEL)), ('user', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='tests_sampleaccesstoken', to=settings.AUTH_USER_MODEL)), + ('id_token', models.OneToOneField(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='access_token', to=settings.OAUTH2_PROVIDER_ID_TOKEN_MODEL)), ], options={ 'abstract': False, @@ -83,7 +87,7 @@ class Migration(migrations.Migration): ('client_id', models.CharField(db_index=True, default=oauth2_provider.generators.generate_client_id, max_length=100, unique=True)), ('redirect_uris', models.TextField(blank=True, help_text='Allowed URIs list, space separated')), ('client_type', models.CharField(choices=[('confidential', 'Confidential'), ('public', 'Public')], max_length=32)), - ('authorization_grant_type', models.CharField(choices=[('authorization-code', 'Authorization code'), ('implicit', 'Implicit'), ('password', 'Resource owner password-based'), ('client-credentials', 'Client credentials')], max_length=32)), + ('authorization_grant_type', models.CharField(choices=[('authorization-code', 'Authorization code'), ('implicit', 'Implicit'), ('password', 'Resource owner password-based'), ('client-credentials', 'Client credentials'), ('openid-hybrid', 'OpenID connect hybrid')], max_length=32)), ('client_secret', models.CharField(blank=True, db_index=True, default=oauth2_provider.generators.generate_client_secret, max_length=255)), ('name', models.CharField(blank=True, max_length=255)), ('skip_authorization', models.BooleanField(default=False)), @@ -91,6 +95,7 @@ class Migration(migrations.Migration): ('updated', models.DateTimeField(auto_now=True)), ('allowed_schemes', models.TextField(blank=True)), ('user', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='tests_basetestapplication', to=settings.AUTH_USER_MODEL)), + ('algorithm', models.CharField(max_length=5, choices=[('RS256', 'RSA with SHA-2 256'), ('HS256', 'HMAC with SHA-2 256')], default='RS256')), ], options={ 'abstract': False, diff --git a/tests/presets.py b/tests/presets.py new file mode 100644 index 000000000..da1577bf4 --- /dev/null +++ b/tests/presets.py @@ -0,0 +1,45 @@ +from copy import deepcopy + +from django.conf import settings + + +# A set of OAUTH2_PROVIDER settings dicts that can be used in tests + +DEFAULT_SCOPES_RW = {"DEFAULT_SCOPES": ["read", "write"]} +DEFAULT_SCOPES_RO = {"DEFAULT_SCOPES": ["read"]} +OIDC_SETTINGS_RW = { + "OIDC_ENABLED": True, + "OIDC_ISS_ENDPOINT": "http://localhost", + "OIDC_USERINFO_ENDPOINT": "http://localhost/userinfo/", + "OIDC_RSA_PRIVATE_KEY": settings.OIDC_RSA_PRIVATE_KEY, + "SCOPES": { + "read": "Reading scope", + "write": "Writing scope", + "openid": "OpenID connect", + }, + "DEFAULT_SCOPES": ["read", "write"], +} +OIDC_SETTINGS_RO = deepcopy(OIDC_SETTINGS_RW) +OIDC_SETTINGS_RO["DEFAULT_SCOPES"] = ["read"] +OIDC_SETTINGS_HS256_ONLY = deepcopy(OIDC_SETTINGS_RW) +del OIDC_SETTINGS_HS256_ONLY["OIDC_RSA_PRIVATE_KEY"] +REST_FRAMEWORK_SCOPES = { + "SCOPES": { + "read": "Read scope", + "write": "Write scope", + "scope1": "Scope 1", + "scope2": "Scope 2", + "resource1": "Resource 1", + }, +} +INTROSPECTION_SETTINGS = { + "SCOPES": { + "read": "Read scope", + "write": "Write scope", + "introspection": "Introspection scope", + "dolphin": "eek eek eek scope", + }, + "RESOURCE_SERVER_INTROSPECTION_URL": "http://example.org/introspection", + "READ_SCOPE": "read", + "WRITE_SCOPE": "write", +} diff --git a/tests/settings.py b/tests/settings.py index 536762c43..1d295982e 100644 --- a/tests/settings.py +++ b/tests/settings.py @@ -117,3 +117,24 @@ }, }, } + +OIDC_RSA_PRIVATE_KEY = """-----BEGIN RSA PRIVATE KEY----- +MIICXQIBAAKBgQCbCYh5h2NmQuBqVO6G+/CO+cHm9VBzsb0MeA6bbQfDnbhstVOT +j0hcnZJzDjYc6ajBZZf6gxVP9xrdm9Uh599VI3X5PFXLbMHrmzTAMzCGIyg+/fnP +0gocYxmCX2+XKyj/Zvt1pUX8VAN2AhrJSfxNDKUHERTVEV9bRBJg4F0C3wIDAQAB +AoGAP+i4nNw+Ec/8oWh8YSFm4xE6qKG0NdTtSMAOyWwy+KTB+vHuT1QPsLn1vj77 ++IQrX/moogg6F1oV9YdA3vat3U7rwt1sBGsRrLhA+Spp9WEQtglguNo4+QfVo2ju +YBa2rG+h75qjiA3xnU//F3rvwnAsOWv0NUVdVeguyR+u6okCQQDBUmgWeH2WHmUn +2nLNCz+9wj28rqhfOr9Ptem2gqk+ywJmuIr4Y5S1OdavOr2UZxOcEwncJ/MLVYQq +MH+x4V5HAkEAzU2GMR5OdVLcxfVTjzuIC76paoHVWnLibd1cdANpPmE6SM+pf5el +fVSwuH9Fmlizu8GiPCxbJUoXB/J1tGEKqQJBALhClEU+qOzpoZ6/voYi/6kdN3zc +uEy0EN6n09AKb8gS9QH1STgAqh+ltjMkeMe3C2DKYK5/QU9/Pc58lWl1FkcCQG67 +ZamQgxjcvJ85FvymS1aqW45KwNysIlzHjFo2jMlMf7dN6kobbPMQftDENLJvLWIT +qoFyGycdsxZiPAIyZSECQQCZFn3Dl6hnJxWZH8Fsa9hj79kZ/WVkIXGmtdgt0fNr +dTnvCVtA59ne4LEVie/PMH/odQWY0SxVm/76uBZv/1vY +-----END RSA PRIVATE KEY-----""" + +OAUTH2_PROVIDER_ACCESS_TOKEN_MODEL = "oauth2_provider.AccessToken" +OAUTH2_PROVIDER_APPLICATION_MODEL = "oauth2_provider.Application" +OAUTH2_PROVIDER_REFRESH_TOKEN_MODEL = "oauth2_provider.RefreshToken" +OAUTH2_PROVIDER_ID_TOKEN_MODEL = "oauth2_provider.IDToken" diff --git a/tests/test_application_views.py b/tests/test_application_views.py index 0e476054a..42eb17fd0 100644 --- a/tests/test_application_views.py +++ b/tests/test_application_views.py @@ -1,9 +1,9 @@ +import pytest from django.contrib.auth import get_user_model from django.test import TestCase from django.urls import reverse from oauth2_provider.models import get_application_model -from oauth2_provider.settings import oauth2_settings from oauth2_provider.views.application import ApplicationRegistration from .models import SampleApplication @@ -23,21 +23,19 @@ def tearDown(self): self.bar_user.delete() +@pytest.mark.usefixtures("oauth2_settings") class TestApplicationRegistrationView(BaseTest): + @pytest.mark.oauth2_settings({"APPLICATION_MODEL": "tests.SampleApplication"}) def test_get_form_class(self): """ Tests that the form class returned by the "get_form_class" method is bound to custom application model defined in the "OAUTH2_PROVIDER_APPLICATION_MODEL" setting. """ - # Patch oauth2 settings to use a custom Application model - oauth2_settings.APPLICATION_MODEL = "tests.SampleApplication" # Create a registration view and tests that the model form is bound # to the custom Application model application_form_class = ApplicationRegistration().get_form_class() self.assertEqual(SampleApplication, application_form_class._meta.model) - # Revert oauth2 settings - oauth2_settings.APPLICATION_MODEL = "oauth2_provider.Application" def test_application_registration_user(self): self.client.login(username="foo_user", password="123456") @@ -49,6 +47,7 @@ def test_application_registration_user(self): "client_type": Application.CLIENT_CONFIDENTIAL, "redirect_uris": "http://example.com", "authorization_grant_type": Application.GRANT_AUTHORIZATION_CODE, + "algorithm": "", } response = self.client.post(reverse("oauth2_provider:register"), form_data) diff --git a/tests/test_authorization_code.py b/tests/test_authorization_code.py index 44c474380..ea1bee86d 100644 --- a/tests/test_authorization_code.py +++ b/tests/test_authorization_code.py @@ -5,11 +5,13 @@ import re from urllib.parse import parse_qs, urlparse +import pytest from django.contrib.auth import get_user_model from django.test import RequestFactory, TestCase from django.urls import reverse from django.utils import timezone from django.utils.crypto import get_random_string +from jwcrypto import jwt from oauthlib.oauth2.rfc6749 import errors as oauthlib_errors from oauth2_provider.models import ( @@ -18,9 +20,9 @@ get_grant_model, get_refresh_token_model, ) -from oauth2_provider.settings import oauth2_settings from oauth2_provider.views import ProtectedResourceView +from . import presets from .utils import get_basic_auth_header @@ -40,13 +42,14 @@ def get(self, request, *args, **kwargs): return "This is a protected resource" +@pytest.mark.usefixtures("oauth2_settings") class BaseTest(TestCase): def setUp(self): self.factory = RequestFactory() self.test_user = UserModel.objects.create_user("test_user", "test@example.com", "123456") self.dev_user = UserModel.objects.create_user("dev_user", "dev@example.com", "123456") - oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES = ["http", "custom-scheme"] + self.oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES = ["http", "custom-scheme"] self.application = Application.objects.create( name="Test Application", @@ -59,9 +62,6 @@ def setUp(self): authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE, ) - oauth2_settings._SCOPES = ["read", "write"] - oauth2_settings._DEFAULT_SCOPES = ["read", "write"] - def tearDown(self): self.application.delete() self.test_user.delete() @@ -90,6 +90,7 @@ def test_request_is_not_overwritten(self): assert "request" not in response.context_data +@pytest.mark.oauth2_settings(presets.DEFAULT_SCOPES_RW) class TestAuthorizationCodeView(BaseTest): def test_skip_authorization_completely(self): """ @@ -210,7 +211,7 @@ def test_pre_auth_approval_prompt(self): self.assertEqual(response.status_code, 200) def test_pre_auth_approval_prompt_default(self): - self.assertEqual(oauth2_settings.REQUEST_APPROVAL_PROMPT, "force") + self.assertEqual(self.oauth2_settings.REQUEST_APPROVAL_PROMPT, "force") AccessToken.objects.create( user=self.test_user, @@ -231,7 +232,7 @@ def test_pre_auth_approval_prompt_default(self): self.assertEqual(response.status_code, 200) def test_pre_auth_approval_prompt_default_override(self): - oauth2_settings.REQUEST_APPROVAL_PROMPT = "auto" + self.oauth2_settings.REQUEST_APPROVAL_PROMPT = "auto" AccessToken.objects.create( user=self.test_user, @@ -523,15 +524,84 @@ def test_code_post_auth_fails_when_redirect_uri_path_is_invalid(self): self.assertEqual(response.status_code, 400) -class TestAuthorizationCodeTokenView(BaseTest): - def get_auth(self): +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +class TestOIDCAuthorizationCodeView(BaseTest): + def test_id_token_skip_authorization_completely(self): + """ + If application.skip_authorization = True, should skip the authorization page. + """ + self.client.login(username="test_user", password="123456") + self.application.skip_authorization = True + self.application.save() + + query_data = { + "client_id": self.application.client_id, + "response_type": "code", + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.org", + } + + response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) + self.assertEqual(response.status_code, 302) + + def test_id_token_pre_auth_valid_client(self): + """ + Test response for a valid client_id with response_type: code + """ + self.client.login(username="test_user", password="123456") + + query_data = { + "client_id": self.application.client_id, + "response_type": "code", + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.org", + } + + response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) + self.assertEqual(response.status_code, 200) + + # check form is in context and form params are valid + self.assertIn("form", response.context) + + form = response.context["form"] + self.assertEqual(form["redirect_uri"].value(), "http://example.org") + self.assertEqual(form["state"].value(), "random_state_string") + self.assertEqual(form["scope"].value(), "openid") + self.assertEqual(form["client_id"].value(), self.application.client_id) + + def test_id_token_code_post_auth_allow(self): + """ + Test authorization code is given for an allowed request with response_type: code + """ + self.client.login(username="test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.org", + "response_type": "code", + "allow": True, + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("http://example.org?", response["Location"]) + self.assertIn("state=random_state_string", response["Location"]) + self.assertIn("code=", response["Location"]) + + +class BaseAuthorizationCodeTokenView(BaseTest): + def get_auth(self, scope="read write"): """ Helper method to retrieve a valid authorization code """ authcode_data = { "client_id": self.application.client_id, "state": "random_state_string", - "scope": "read write", + "scope": scope, "redirect_uri": "http://example.org", "response_type": "code", "allow": True, @@ -558,7 +628,7 @@ def get_pkce_auth(self, code_challenge, code_challenge_method): """ Helper method to retrieve a valid authorization code using pkce """ - oauth2_settings.PKCE_REQUIRED = True + self.oauth2_settings.PKCE_REQUIRED = True authcode_data = { "client_id": self.application.client_id, "state": "random_state_string", @@ -572,9 +642,11 @@ def get_pkce_auth(self, code_challenge, code_challenge_method): response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) query_dict = parse_qs(urlparse(response["Location"]).query) - oauth2_settings.PKCE_REQUIRED = False return query_dict["code"].pop() + +@pytest.mark.oauth2_settings(presets.DEFAULT_SCOPES_RW) +class TestAuthorizationCodeTokenView(BaseAuthorizationCodeTokenView): def test_basic_auth(self): """ Request an access token using basic authentication for client authentication @@ -595,7 +667,7 @@ def test_basic_auth(self): content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_refresh(self): """ @@ -645,7 +717,7 @@ def test_refresh_with_grace_period(self): """ Request an access token using a refresh token """ - oauth2_settings.REFRESH_TOKEN_GRACE_PERIOD_SECONDS = 120 + self.oauth2_settings.REFRESH_TOKEN_GRACE_PERIOD_SECONDS = 120 self.client.login(username="test_user", password="123456") authorization_code = self.get_auth() @@ -692,7 +764,6 @@ def test_refresh_with_grace_period(self): # refresh token should be the same as well self.assertTrue("refresh_token" in content) self.assertEqual(content["refresh_token"], first_refresh_token) - oauth2_settings.REFRESH_TOKEN_GRACE_PERIOD_SECONDS = 0 def test_refresh_invalidates_old_tokens(self): """ @@ -813,7 +884,7 @@ def test_refresh_repeating_requests(self): Trying to refresh an access token with the same refresh token more than once succeeds in the grace period and fails outside """ - oauth2_settings.REFRESH_TOKEN_GRACE_PERIOD_SECONDS = 120 + self.oauth2_settings.REFRESH_TOKEN_GRACE_PERIOD_SECONDS = 120 self.client.login(username="test_user", password="123456") authorization_code = self.get_auth() @@ -846,7 +917,6 @@ def test_refresh_repeating_requests(self): response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 400) - oauth2_settings.REFRESH_TOKEN_GRACE_PERIOD_SECONDS = 0 def test_refresh_repeating_requests_non_rotating_tokens(self): """ @@ -871,15 +941,13 @@ def test_refresh_repeating_requests_non_rotating_tokens(self): "refresh_token": content["refresh_token"], "scope": content["scope"], } - oauth2_settings.ROTATE_REFRESH_TOKEN = False + self.oauth2_settings.ROTATE_REFRESH_TOKEN = False response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) self.assertEqual(response.status_code, 200) - oauth2_settings.ROTATE_REFRESH_TOKEN = True - def test_basic_auth_bad_authcode(self): """ Request an access token using a bad authorization code @@ -993,7 +1061,7 @@ def test_request_body_params(self): content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_public(self): """ @@ -1018,7 +1086,7 @@ def test_public(self): content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_public_pkce_S256_authorize_get(self): """ @@ -1031,7 +1099,7 @@ def test_public_pkce_S256_authorize_get(self): self.application.client_type = Application.CLIENT_PUBLIC self.application.save() code_verifier, code_challenge = self.generate_pkce_codes("S256") - oauth2_settings.PKCE_REQUIRED = True + self.oauth2_settings.PKCE_REQUIRED = True query_data = { "client_id": self.application.client_id, @@ -1047,7 +1115,6 @@ def test_public_pkce_S256_authorize_get(self): response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) self.assertContains(response, 'value="S256"', count=1, status_code=200) self.assertContains(response, 'value="{0}"'.format(code_challenge), count=1, status_code=200) - oauth2_settings.PKCE_REQUIRED = False def test_public_pkce_plain_authorize_get(self): """ @@ -1060,7 +1127,7 @@ def test_public_pkce_plain_authorize_get(self): self.application.client_type = Application.CLIENT_PUBLIC self.application.save() code_verifier, code_challenge = self.generate_pkce_codes("plain") - oauth2_settings.PKCE_REQUIRED = True + self.oauth2_settings.PKCE_REQUIRED = True query_data = { "client_id": self.application.client_id, @@ -1076,7 +1143,6 @@ def test_public_pkce_plain_authorize_get(self): response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) self.assertContains(response, 'value="plain"', count=1, status_code=200) self.assertContains(response, 'value="{0}"'.format(code_challenge), count=1, status_code=200) - oauth2_settings.PKCE_REQUIRED = False def test_public_pkce_S256(self): """ @@ -1089,7 +1155,7 @@ def test_public_pkce_S256(self): self.application.save() code_verifier, code_challenge = self.generate_pkce_codes("S256") authorization_code = self.get_pkce_auth(code_challenge, "S256") - oauth2_settings.PKCE_REQUIRED = True + self.oauth2_settings.PKCE_REQUIRED = True token_request_data = { "grant_type": "authorization_code", @@ -1105,8 +1171,7 @@ def test_public_pkce_S256(self): content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) - oauth2_settings.PKCE_REQUIRED = False + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_public_pkce_plain(self): """ @@ -1119,7 +1184,7 @@ def test_public_pkce_plain(self): self.application.save() code_verifier, code_challenge = self.generate_pkce_codes("plain") authorization_code = self.get_pkce_auth(code_challenge, "plain") - oauth2_settings.PKCE_REQUIRED = True + self.oauth2_settings.PKCE_REQUIRED = True token_request_data = { "grant_type": "authorization_code", @@ -1135,8 +1200,7 @@ def test_public_pkce_plain(self): content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) - oauth2_settings.PKCE_REQUIRED = False + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_public_pkce_invalid_algorithm(self): """ @@ -1148,7 +1212,7 @@ def test_public_pkce_invalid_algorithm(self): self.application.client_type = Application.CLIENT_PUBLIC self.application.save() code_verifier, code_challenge = self.generate_pkce_codes("invalid") - oauth2_settings.PKCE_REQUIRED = True + self.oauth2_settings.PKCE_REQUIRED = True query_data = { "client_id": self.application.client_id, @@ -1164,7 +1228,6 @@ def test_public_pkce_invalid_algorithm(self): response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) self.assertEqual(response.status_code, 302) self.assertIn("error=invalid_request", response["Location"]) - oauth2_settings.PKCE_REQUIRED = False def test_public_pkce_missing_code_challenge(self): """ @@ -1177,7 +1240,7 @@ def test_public_pkce_missing_code_challenge(self): self.application.skip_authorization = True self.application.save() code_verifier, code_challenge = self.generate_pkce_codes("S256") - oauth2_settings.PKCE_REQUIRED = True + self.oauth2_settings.PKCE_REQUIRED = True query_data = { "client_id": self.application.client_id, @@ -1192,7 +1255,6 @@ def test_public_pkce_missing_code_challenge(self): response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) self.assertEqual(response.status_code, 302) self.assertIn("error=invalid_request", response["Location"]) - oauth2_settings.PKCE_REQUIRED = False def test_public_pkce_missing_code_challenge_method(self): """ @@ -1204,7 +1266,7 @@ def test_public_pkce_missing_code_challenge_method(self): self.application.client_type = Application.CLIENT_PUBLIC self.application.save() code_verifier, code_challenge = self.generate_pkce_codes("S256") - oauth2_settings.PKCE_REQUIRED = True + self.oauth2_settings.PKCE_REQUIRED = True query_data = { "client_id": self.application.client_id, @@ -1218,7 +1280,6 @@ def test_public_pkce_missing_code_challenge_method(self): response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) self.assertEqual(response.status_code, 200) - oauth2_settings.PKCE_REQUIRED = False def test_public_pkce_S256_invalid_code_verifier(self): """ @@ -1231,7 +1292,7 @@ def test_public_pkce_S256_invalid_code_verifier(self): self.application.save() code_verifier, code_challenge = self.generate_pkce_codes("S256") authorization_code = self.get_pkce_auth(code_challenge, "S256") - oauth2_settings.PKCE_REQUIRED = True + self.oauth2_settings.PKCE_REQUIRED = True token_request_data = { "grant_type": "authorization_code", @@ -1243,7 +1304,6 @@ def test_public_pkce_S256_invalid_code_verifier(self): response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) self.assertEqual(response.status_code, 400) - oauth2_settings.PKCE_REQUIRED = False def test_public_pkce_plain_invalid_code_verifier(self): """ @@ -1256,7 +1316,7 @@ def test_public_pkce_plain_invalid_code_verifier(self): self.application.save() code_verifier, code_challenge = self.generate_pkce_codes("plain") authorization_code = self.get_pkce_auth(code_challenge, "plain") - oauth2_settings.PKCE_REQUIRED = True + self.oauth2_settings.PKCE_REQUIRED = True token_request_data = { "grant_type": "authorization_code", @@ -1268,7 +1328,6 @@ def test_public_pkce_plain_invalid_code_verifier(self): response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) self.assertEqual(response.status_code, 400) - oauth2_settings.PKCE_REQUIRED = False def test_public_pkce_S256_missing_code_verifier(self): """ @@ -1281,7 +1340,7 @@ def test_public_pkce_S256_missing_code_verifier(self): self.application.save() code_verifier, code_challenge = self.generate_pkce_codes("S256") authorization_code = self.get_pkce_auth(code_challenge, "S256") - oauth2_settings.PKCE_REQUIRED = True + self.oauth2_settings.PKCE_REQUIRED = True token_request_data = { "grant_type": "authorization_code", @@ -1292,7 +1351,6 @@ def test_public_pkce_S256_missing_code_verifier(self): response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) self.assertEqual(response.status_code, 400) - oauth2_settings.PKCE_REQUIRED = False def test_public_pkce_plain_missing_code_verifier(self): """ @@ -1305,7 +1363,7 @@ def test_public_pkce_plain_missing_code_verifier(self): self.application.save() code_verifier, code_challenge = self.generate_pkce_codes("plain") authorization_code = self.get_pkce_auth(code_challenge, "plain") - oauth2_settings.PKCE_REQUIRED = True + self.oauth2_settings.PKCE_REQUIRED = True token_request_data = { "grant_type": "authorization_code", @@ -1316,7 +1374,6 @@ def test_public_pkce_plain_missing_code_verifier(self): response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) self.assertEqual(response.status_code, 400) - oauth2_settings.PKCE_REQUIRED = False def test_malicious_redirect_uri(self): """ @@ -1340,7 +1397,10 @@ def test_malicious_redirect_uri(self): self.assertEqual(response.status_code, 400) data = response.json() self.assertEqual(data["error"], "invalid_request") - self.assertEqual(data["error_description"], oauthlib_errors.MismatchingRedirectURIError.description) + self.assertEqual( + data["error_description"], + oauthlib_errors.MismatchingRedirectURIError.description, + ) def test_code_exchange_succeed_when_redirect_uri_match(self): """ @@ -1375,7 +1435,7 @@ def test_code_exchange_succeed_when_redirect_uri_match(self): content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_code_exchange_fails_when_redirect_uri_does_not_match(self): """ @@ -1408,9 +1468,14 @@ def test_code_exchange_fails_when_redirect_uri_does_not_match(self): self.assertEqual(response.status_code, 400) data = response.json() self.assertEqual(data["error"], "invalid_request") - self.assertEqual(data["error_description"], oauthlib_errors.MismatchingRedirectURIError.description) + self.assertEqual( + data["error_description"], + oauthlib_errors.MismatchingRedirectURIError.description, + ) - def test_code_exchange_succeed_when_redirect_uri_match_with_multiple_query_params(self): + def test_code_exchange_succeed_when_redirect_uri_match_with_multiple_query_params( + self, + ): """ Tests code exchange succeed when redirect uri matches the one used for code request """ @@ -1445,7 +1510,7 @@ def test_code_exchange_succeed_when_redirect_uri_match_with_multiple_query_param content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_oob_as_html(self): """ @@ -1491,7 +1556,7 @@ def test_oob_as_html(self): content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_oob_as_json(self): """ @@ -1531,9 +1596,130 @@ def test_oob_as_json(self): content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") self.assertEqual(content["scope"], "read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + + +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +class TestOIDCAuthorizationCodeTokenView(BaseAuthorizationCodeTokenView): + def setUp(self): + super().setUp() + self.application.algorithm = Application.RS256_ALGORITHM + self.application.save() + + def test_id_token_public(self): + """ + Request an access token using client_type: public + """ + self.client.login(username="test_user", password="123456") + self.application.client_type = Application.CLIENT_PUBLIC + self.application.save() + authorization_code = self.get_auth(scope="openid") + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org", + "client_id": self.application.client_id, + "scope": "openid", + } + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) + self.assertEqual(response.status_code, 200) + + content = json.loads(response.content.decode("utf-8")) + self.assertEqual(content["token_type"], "Bearer") + self.assertEqual(content["scope"], "openid") + self.assertIn("access_token", content) + self.assertIn("id_token", content) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + + def test_id_token_code_exchange_succeed_when_redirect_uri_match_with_multiple_query_params( + self, + ): + """ + Tests code exchange succeed when redirect uri matches the one used for code request + """ + self.client.login(username="test_user", password="123456") + self.application.redirect_uris = "http://localhost http://example.com?foo=bar" + self.application.save() + + # retrieve a valid authorization code + authcode_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.com?bar=baz&foo=bar", + "response_type": "code", + "allow": True, + } + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) + query_dict = parse_qs(urlparse(response["Location"]).query) + authorization_code = query_dict["code"].pop() + + # exchange authorization code for a valid access token + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.com?bar=baz&foo=bar", + } + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + self.assertEqual(response.status_code, 200) + + content = json.loads(response.content.decode("utf-8")) + self.assertEqual(content["token_type"], "Bearer") + self.assertEqual(content["scope"], "openid") + self.assertIn("access_token", content) + self.assertIn("id_token", content) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + + +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +class TestOIDCAuthorizationCodeHSAlgorithm(BaseAuthorizationCodeTokenView): + def setUp(self): + super().setUp() + self.oauth2_settings.OIDC_RSA_PRIVATE_KEY = None + self.application.algorithm = Application.HS256_ALGORITHM + self.application.save() + + def test_id_token(self): + """ + Request an access token using an HS256 application + """ + self.client.login(username="test_user", password="123456") + + authorization_code = self.get_auth(scope="openid") + + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org", + "client_id": self.application.client_id, + "client_secret": self.application.client_secret, + "scope": "openid", + } + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) + self.assertEqual(response.status_code, 200) + + content = response.json() + self.assertEqual(content["token_type"], "Bearer") + self.assertEqual(content["scope"], "openid") + self.assertIn("access_token", content) + self.assertIn("id_token", content) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + + # Check decoding JWT using HS256 + key = self.application.jwk_key + assert key.key_type == "oct" + jwt_token = jwt.JWT(key=key, jwt=content["id_token"]) + claims = json.loads(jwt_token.claims) + assert claims["sub"] == "1" + + +@pytest.mark.oauth2_settings(presets.DEFAULT_SCOPES_RW) class TestAuthorizationCodeProtectedResource(BaseTest): def test_resource_access_allowed(self): self.client.login(username="test_user", password="123456") @@ -1586,13 +1772,72 @@ def test_resource_access_deny(self): self.assertEqual(response.status_code, 403) +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +class TestOIDCAuthorizationCodeProtectedResource(BaseTest): + def setUp(self): + super().setUp() + self.application.algorithm = Application.RS256_ALGORITHM + self.application.save() + + def test_id_token_resource_access_allowed(self): + self.client.login(username="test_user", password="123456") + + # retrieve a valid authorization code + authcode_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.org", + "response_type": "code", + "allow": True, + } + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) + query_dict = parse_qs(urlparse(response["Location"]).query) + authorization_code = query_dict["code"].pop() + + # exchange authorization code for a valid access token + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org", + } + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + content = json.loads(response.content.decode("utf-8")) + access_token = content["access_token"] + id_token = content["id_token"] + + # use token to access the resource + auth_headers = { + "HTTP_AUTHORIZATION": "Bearer " + access_token, + } + request = self.factory.get("/fake-resource", **auth_headers) + request.user = self.test_user + + view = ResourceView.as_view() + response = view(request) + self.assertEqual(response, "This is a protected resource") + + # use id_token to access the resource + auth_headers = { + "HTTP_AUTHORIZATION": "Bearer " + id_token, + } + request = self.factory.get("/fake-resource", **auth_headers) + request.user = self.test_user + + view = ResourceView.as_view() + response = view(request) + self.assertEqual(response, "This is a protected resource") + + +@pytest.mark.oauth2_settings(presets.DEFAULT_SCOPES_RO) class TestDefaultScopes(BaseTest): def test_pre_auth_default_scopes(self): """ Test response for a valid client_id with response_type: code using default scopes """ self.client.login(username="test_user", password="123456") - oauth2_settings._DEFAULT_SCOPES = ["read"] query_data = { "client_id": self.application.client_id, @@ -1612,4 +1857,3 @@ def test_pre_auth_default_scopes(self): self.assertEqual(form["state"].value(), "random_state_string") self.assertEqual(form["scope"].value(), "read") self.assertEqual(form["client_id"].value(), self.application.client_id) - oauth2_settings._DEFAULT_SCOPES = ["read", "write"] diff --git a/tests/test_client_credential.py b/tests/test_client_credential.py index 966eb826b..8b9aa3bc2 100644 --- a/tests/test_client_credential.py +++ b/tests/test_client_credential.py @@ -1,6 +1,7 @@ import json from urllib.parse import quote_plus +import pytest from django.contrib.auth import get_user_model from django.test import RequestFactory, TestCase from django.urls import reverse @@ -10,10 +11,10 @@ from oauth2_provider.models import get_access_token_model, get_application_model from oauth2_provider.oauth2_backends import OAuthLibCore from oauth2_provider.oauth2_validators import OAuth2Validator -from oauth2_provider.settings import oauth2_settings from oauth2_provider.views import ProtectedResourceView from oauth2_provider.views.mixins import OAuthLibMixin +from . import presets from .utils import get_basic_auth_header @@ -28,6 +29,8 @@ def get(self, request, *args, **kwargs): return "This is a protected resource" +@pytest.mark.usefixtures("oauth2_settings") +@pytest.mark.oauth2_settings(presets.DEFAULT_SCOPES_RW) class BaseTest(TestCase): def setUp(self): self.factory = RequestFactory() @@ -41,9 +44,6 @@ def setUp(self): authorization_grant_type=Application.GRANT_CLIENT_CREDENTIALS, ) - oauth2_settings._SCOPES = ["read", "write"] - oauth2_settings._DEFAULT_SCOPES = ["read", "write"] - def tearDown(self): self.application.delete() self.test_user.delete() diff --git a/tests/test_decorators.py b/tests/test_decorators.py index 22ce48e76..ce17a891a 100644 --- a/tests/test_decorators.py +++ b/tests/test_decorators.py @@ -6,7 +6,6 @@ from oauth2_provider.decorators import protected_resource, rw_protected_resource from oauth2_provider.models import get_access_token_model, get_application_model -from oauth2_provider.settings import oauth2_settings Application = get_application_model() @@ -37,8 +36,6 @@ def setUp(self): application=self.application, ) - oauth2_settings._SCOPES = ["read", "write"] - def test_access_denied(self): @protected_resource() def view(request, *args, **kwargs): diff --git a/tests/test_generator.py b/tests/test_generator.py index 670ac9ea1..cc7928017 100644 --- a/tests/test_generator.py +++ b/tests/test_generator.py @@ -1,13 +1,7 @@ +import pytest from django.test import TestCase -from oauth2_provider.generators import ( - BaseHashGenerator, - ClientIdGenerator, - ClientSecretGenerator, - generate_client_id, - generate_client_secret, -) -from oauth2_provider.settings import oauth2_settings +from oauth2_provider.generators import BaseHashGenerator, generate_client_id, generate_client_secret class MockHashGenerator(BaseHashGenerator): @@ -15,23 +9,20 @@ def hash(self): return 42 +@pytest.mark.usefixtures("oauth2_settings") class TestGenerators(TestCase): - def tearDown(self): - oauth2_settings.CLIENT_ID_GENERATOR_CLASS = ClientIdGenerator - oauth2_settings.CLIENT_SECRET_GENERATOR_CLASS = ClientSecretGenerator - def test_generate_client_id(self): - g = oauth2_settings.CLIENT_ID_GENERATOR_CLASS() + g = self.oauth2_settings.CLIENT_ID_GENERATOR_CLASS() self.assertEqual(len(g.hash()), 40) - oauth2_settings.CLIENT_ID_GENERATOR_CLASS = MockHashGenerator + self.oauth2_settings.CLIENT_ID_GENERATOR_CLASS = MockHashGenerator self.assertEqual(generate_client_id(), 42) def test_generate_secret_id(self): - g = oauth2_settings.CLIENT_SECRET_GENERATOR_CLASS() + g = self.oauth2_settings.CLIENT_SECRET_GENERATOR_CLASS() self.assertEqual(len(g.hash()), 128) - oauth2_settings.CLIENT_SECRET_GENERATOR_CLASS = MockHashGenerator + self.oauth2_settings.CLIENT_SECRET_GENERATOR_CLASS = MockHashGenerator self.assertEqual(generate_client_secret(), 42) def test_basegen_misuse(self): diff --git a/tests/test_hybrid.py b/tests/test_hybrid.py new file mode 100644 index 000000000..d198988f6 --- /dev/null +++ b/tests/test_hybrid.py @@ -0,0 +1,1431 @@ +import base64 +import datetime +import json +from urllib.parse import parse_qs, urlencode, urlparse + +import pytest +from django.contrib.auth import get_user_model +from django.test import RequestFactory, TestCase +from django.urls import reverse +from django.utils import timezone +from jwcrypto import jwt +from oauthlib.oauth2.rfc6749 import errors as oauthlib_errors + +from oauth2_provider.models import ( + get_access_token_model, + get_application_model, + get_grant_model, + get_refresh_token_model, +) +from oauth2_provider.oauth2_validators import OAuth2Validator +from oauth2_provider.views import ProtectedResourceView, ScopedProtectedResourceView + +from . import presets +from .utils import get_basic_auth_header, spy_on + + +Application = get_application_model() +AccessToken = get_access_token_model() +Grant = get_grant_model() +RefreshToken = get_refresh_token_model() +UserModel = get_user_model() + + +# mocking a protected resource view +class ResourceView(ProtectedResourceView): + def get(self, request, *args, **kwargs): + return "This is a protected resource" + + +class ScopedResourceView(ScopedProtectedResourceView): + required_scopes = ["read"] + + def get(self, request, *args, **kwargs): + return "This is a protected resource" + + +@pytest.mark.usefixtures("oauth2_settings") +class BaseTest(TestCase): + def setUp(self): + self.factory = RequestFactory() + self.hy_test_user = UserModel.objects.create_user("hy_test_user", "test_hy@example.com", "123456") + self.hy_dev_user = UserModel.objects.create_user("hy_dev_user", "dev_hy@example.com", "123456") + + self.oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES = ["http", "custom-scheme"] + + self.application = Application( + name="Hybrid Test Application", + redirect_uris=( + "http://localhost http://example.com http://example.org custom-scheme://example.com" + ), + user=self.hy_dev_user, + client_type=Application.CLIENT_CONFIDENTIAL, + authorization_grant_type=Application.GRANT_OPENID_HYBRID, + algorithm=Application.RS256_ALGORITHM, + ) + self.application.save() + + def tearDown(self): + self.application.delete() + self.hy_test_user.delete() + self.hy_dev_user.delete() + + +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +class TestRegressionIssue315Hybrid(BaseTest): + """ + Test to avoid regression for the issue 315: request object + was being reassigned when getting AuthorizationView + """ + + def test_request_is_not_overwritten_code_token(self): + self.client.login(username="hy_test_user", password="123456") + query_string = urlencode( + { + "client_id": self.application.client_id, + "response_type": "code token", + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "http://example.org", + } + ) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 200) + assert "request" not in response.context_data + + def test_request_is_not_overwritten_code_id_token(self): + self.client.login(username="hy_test_user", password="123456") + query_string = urlencode( + { + "client_id": self.application.client_id, + "response_type": "code id_token", + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "http://example.org", + "nonce": "nonce", + } + ) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 200) + assert "request" not in response.context_data + + def test_request_is_not_overwritten_code_id_token_token(self): + self.client.login(username="hy_test_user", password="123456") + query_string = urlencode( + { + "client_id": self.application.client_id, + "response_type": "code id_token token", + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "http://example.org", + "nonce": "nonce", + } + ) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 200) + assert "request" not in response.context_data + + +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +class TestHybridView(BaseTest): + def test_skip_authorization_completely(self): + """ + If application.skip_authorization = True, should skip the authorization page. + """ + self.client.login(username="hy_test_user", password="123456") + self.application.skip_authorization = True + self.application.save() + + query_string = urlencode( + { + "client_id": self.application.client_id, + "response_type": "code", + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + } + ) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 302) + + def test_id_token_skip_authorization_completely(self): + """ + If application.skip_authorization = True, should skip the authorization page. + """ + self.client.login(username="hy_test_user", password="123456") + self.application.skip_authorization = True + self.application.save() + + query_string = urlencode( + { + "client_id": self.application.client_id, + "response_type": "code", + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.org", + } + ) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 302) + + def test_pre_auth_invalid_client(self): + """ + Test error for an invalid client_id with response_type: code + """ + self.client.login(username="hy_test_user", password="123456") + + query_string = urlencode( + { + "client_id": "fakeclientid", + "response_type": "code", + } + ) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 400) + self.assertEqual( + response.context_data["url"], + "?error=invalid_request&error_description=Invalid+client_id+parameter+value.", + ) + + def test_pre_auth_valid_client(self): + """ + Test response for a valid client_id with response_type: code + """ + self.client.login(username="hy_test_user", password="123456") + + query_string = urlencode( + { + "client_id": self.application.client_id, + "response_type": "code id_token", + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + } + ) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 200) + + # check form is in context and form params are valid + self.assertIn("form", response.context) + + form = response.context["form"] + self.assertEqual(form["redirect_uri"].value(), "http://example.org") + self.assertEqual(form["state"].value(), "random_state_string") + self.assertEqual(form["scope"].value(), "read write") + self.assertEqual(form["client_id"].value(), self.application.client_id) + + def test_id_token_pre_auth_valid_client(self): + """ + Test response for a valid client_id with response_type: code + """ + self.client.login(username="hy_test_user", password="123456") + + query_string = urlencode( + { + "client_id": self.application.client_id, + "response_type": "code id_token", + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.org", + "nonce": "nonce", + } + ) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 200) + + # check form is in context and form params are valid + self.assertIn("form", response.context) + + form = response.context["form"] + self.assertEqual(form["redirect_uri"].value(), "http://example.org") + self.assertEqual(form["state"].value(), "random_state_string") + self.assertEqual(form["scope"].value(), "openid") + self.assertEqual(form["client_id"].value(), self.application.client_id) + + def test_pre_auth_valid_client_custom_redirect_uri_scheme(self): + """ + Test response for a valid client_id with response_type: code + using a non-standard, but allowed, redirect_uri scheme. + """ + self.client.login(username="hy_test_user", password="123456") + + query_string = urlencode( + { + "client_id": self.application.client_id, + "response_type": "code id_token", + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "custom-scheme://example.com", + } + ) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 200) + + # check form is in context and form params are valid + self.assertIn("form", response.context) + + form = response.context["form"] + self.assertEqual(form["redirect_uri"].value(), "custom-scheme://example.com") + self.assertEqual(form["state"].value(), "random_state_string") + self.assertEqual(form["scope"].value(), "read write") + self.assertEqual(form["client_id"].value(), self.application.client_id) + + def test_pre_auth_approval_prompt(self): + tok = AccessToken.objects.create( + user=self.hy_test_user, + token="1234567890", + application=self.application, + expires=timezone.now() + datetime.timedelta(days=1), + scope="read write", + ) + self.client.login(username="hy_test_user", password="123456") + query_string = urlencode( + { + "client_id": self.application.client_id, + "response_type": "code id_token", + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + "approval_prompt": "auto", + } + ) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + response = self.client.get(url) + self.assertEqual(response.status_code, 302) + # user already authorized the application, but with different scopes: prompt them. + tok.scope = "read" + tok.save() + response = self.client.get(url) + self.assertEqual(response.status_code, 200) + + def test_pre_auth_approval_prompt_default(self): + self.oauth2_settings.REQUEST_APPROVAL_PROMPT = "force" + self.assertEqual(self.oauth2_settings.REQUEST_APPROVAL_PROMPT, "force") + + AccessToken.objects.create( + user=self.hy_test_user, + token="1234567890", + application=self.application, + expires=timezone.now() + datetime.timedelta(days=1), + scope="read write", + ) + self.client.login(username="hy_test_user", password="123456") + query_string = urlencode( + { + "client_id": self.application.client_id, + "response_type": "code id_token", + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + } + ) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + response = self.client.get(url) + self.assertEqual(response.status_code, 200) + + def test_pre_auth_approval_prompt_default_override(self): + self.oauth2_settings.REQUEST_APPROVAL_PROMPT = "auto" + + AccessToken.objects.create( + user=self.hy_test_user, + token="1234567890", + application=self.application, + expires=timezone.now() + datetime.timedelta(days=1), + scope="read write", + ) + self.client.login(username="hy_test_user", password="123456") + query_string = urlencode( + { + "client_id": self.application.client_id, + "response_type": "code", + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + } + ) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + response = self.client.get(url) + self.assertEqual(response.status_code, 302) + + def test_pre_auth_default_redirect(self): + """ + Test for default redirect uri if omitted from query string with response_type: code + """ + self.client.login(username="hy_test_user", password="123456") + + query_string = urlencode( + { + "client_id": self.application.client_id, + "response_type": "code id_token", + } + ) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 200) + + form = response.context["form"] + self.assertEqual(form["redirect_uri"].value(), "http://localhost") + + def test_pre_auth_forbibben_redirect(self): + """ + Test error when passing a forbidden redirect_uri in query string with response_type: code + """ + self.client.login(username="hy_test_user", password="123456") + + query_string = urlencode( + { + "client_id": self.application.client_id, + "response_type": "code", + "redirect_uri": "http://forbidden.it", + } + ) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 400) + + def test_pre_auth_wrong_response_type(self): + """ + Test error when passing a wrong response_type in query string + """ + self.client.login(username="hy_test_user", password="123456") + + query_string = urlencode( + { + "client_id": self.application.client_id, + "response_type": "WRONG", + } + ) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 302) + self.assertIn("error=unsupported_response_type", response["Location"]) + + def test_code_post_auth_allow_code_token(self): + """ + Test authorization code is given for an allowed request with response_type: code + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "http://example.org", + "response_type": "code token", + "allow": True, + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("http://example.org", response["Location"]) + self.assertIn("state=random_state_string", response["Location"]) + self.assertIn("code=", response["Location"]) + self.assertIn("access_token=", response["Location"]) + + def test_code_post_auth_allow_code_id_token(self): + """ + Test authorization code is given for an allowed request with response_type: code + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "http://example.org", + "response_type": "code id_token", + "allow": True, + "nonce": "nonce", + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("http://example.org", response["Location"]) + self.assertIn("state=random_state_string", response["Location"]) + self.assertIn("code=", response["Location"]) + self.assertIn("id_token=", response["Location"]) + + def test_code_post_auth_allow_code_id_token_token(self): + """ + Test authorization code is given for an allowed request with response_type: code + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "http://example.org", + "response_type": "code id_token token", + "allow": True, + "nonce": "nonce", + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("http://example.org", response["Location"]) + self.assertIn("state=random_state_string", response["Location"]) + self.assertIn("code=", response["Location"]) + self.assertIn("id_token=", response["Location"]) + self.assertIn("access_token=", response["Location"]) + + def test_id_token_code_post_auth_allow(self): + """ + Test authorization code is given for an allowed request with response_type: code + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.org", + "response_type": "code id_token", + "allow": True, + "nonce": "nonce", + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("http://example.org", response["Location"]) + self.assertIn("state=random_state_string", response["Location"]) + self.assertIn("code=", response["Location"]) + self.assertIn("id_token=", response["Location"]) + + def test_code_post_auth_deny(self): + """ + Test error when resource owner deny access + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + "response_type": "code", + "allow": False, + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("error=access_denied", response["Location"]) + + def test_code_post_auth_bad_responsetype(self): + """ + Test authorization code is given for an allowed request with a response_type not supported + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + "response_type": "UNKNOWN", + "allow": True, + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("http://example.org?error", response["Location"]) + + def test_code_post_auth_forbidden_redirect_uri(self): + """ + Test authorization code is given for an allowed request with a forbidden redirect_uri + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://forbidden.it", + "response_type": "code", + "allow": True, + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 400) + + def test_code_post_auth_malicious_redirect_uri(self): + """ + Test validation of a malicious redirect_uri + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "/../", + "response_type": "code", + "allow": True, + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 400) + + def test_code_post_auth_allow_custom_redirect_uri_scheme_code_token(self): + """ + Test authorization code is given for an allowed request with response_type: code + using a non-standard, but allowed, redirect_uri scheme. + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "custom-scheme://example.com", + "response_type": "code token", + "allow": True, + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("custom-scheme://example.com", response["Location"]) + self.assertIn("state=random_state_string", response["Location"]) + self.assertIn("code=", response["Location"]) + self.assertIn("access_token=", response["Location"]) + + def test_code_post_auth_allow_custom_redirect_uri_scheme_code_id_token(self): + """ + Test authorization code is given for an allowed request with response_type: code + using a non-standard, but allowed, redirect_uri scheme. + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "custom-scheme://example.com", + "response_type": "code id_token", + "allow": True, + "nonce": "nonce", + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("custom-scheme://example.com", response["Location"]) + self.assertIn("state=random_state_string", response["Location"]) + self.assertIn("code=", response["Location"]) + self.assertIn("id_token=", response["Location"]) + + def test_code_post_auth_allow_custom_redirect_uri_scheme_code_id_token_token(self): + """ + Test authorization code is given for an allowed request with response_type: code + using a non-standard, but allowed, redirect_uri scheme. + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "custom-scheme://example.com", + "response_type": "code id_token token", + "allow": True, + "nonce": "nonce", + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("custom-scheme://example.com", response["Location"]) + self.assertIn("state=random_state_string", response["Location"]) + self.assertIn("code=", response["Location"]) + self.assertIn("id_token=", response["Location"]) + self.assertIn("access_token=", response["Location"]) + + def test_code_post_auth_deny_custom_redirect_uri_scheme(self): + """ + Test error when resource owner deny access + using a non-standard, but allowed, redirect_uri scheme. + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "custom-scheme://example.com", + "response_type": "code", + "allow": False, + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("custom-scheme://example.com?", response["Location"]) + self.assertIn("error=access_denied", response["Location"]) + + def test_code_post_auth_redirection_uri_with_querystring_code_token(self): + """ + Tests that a redirection uri with query string is allowed + and query string is retained on redirection. + See http://tools.ietf.org/html/rfc6749#section-3.1.2 + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "http://example.com?foo=bar", + "response_type": "code token", + "allow": True, + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("http://example.com?foo=bar", response["Location"]) + self.assertIn("code=", response["Location"]) + self.assertIn("access_token=", response["Location"]) + + def test_code_post_auth_redirection_uri_with_querystring_code_id_token(self): + """ + Tests that a redirection uri with query string is allowed + and query string is retained on redirection. + See http://tools.ietf.org/html/rfc6749#section-3.1.2 + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "http://example.com?foo=bar", + "response_type": "code id_token", + "allow": True, + "nonce": "nonce", + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("http://example.com?foo=bar", response["Location"]) + self.assertIn("code=", response["Location"]) + self.assertIn("id_token=", response["Location"]) + + def test_code_post_auth_redirection_uri_with_querystring_code_id_token_token(self): + """ + Tests that a redirection uri with query string is allowed + and query string is retained on redirection. + See http://tools.ietf.org/html/rfc6749#section-3.1.2 + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "http://example.com?foo=bar", + "response_type": "code id_token token", + "allow": True, + "nonce": "nonce", + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("http://example.com?foo=bar", response["Location"]) + self.assertIn("code=", response["Location"]) + self.assertIn("id_token=", response["Location"]) + self.assertIn("access_token=", response["Location"]) + + def test_code_post_auth_failing_redirection_uri_with_querystring(self): + """ + Test that in case of error the querystring of the redirection uri is preserved + + See https://github.com/evonove/django-oauth-toolkit/issues/238 + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.com?foo=bar", + "response_type": "code", + "allow": False, + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertEqual( + "http://example.com?foo=bar&error=access_denied&state=random_state_string", response["Location"] + ) + + def test_code_post_auth_fails_when_redirect_uri_path_is_invalid(self): + """ + Tests that a redirection uri is matched using scheme + netloc + path + """ + self.client.login(username="hy_test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.com/a?foo=bar", + "response_type": "code", + "allow": True, + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 400) + + +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +class TestHybridTokenView(BaseTest): + def get_auth(self, scope="read write"): + """ + Helper method to retrieve a valid authorization code + """ + authcode_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": scope, + "redirect_uri": "http://example.org", + "response_type": "code id_token", + "allow": True, + "nonce": "nonce", + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) + fragment_dict = parse_qs(urlparse(response["Location"]).fragment) + return fragment_dict["code"].pop() + + def test_basic_auth(self): + """ + Request an access token using basic authentication for client authentication + """ + self.client.login(username="hy_test_user", password="123456") + authorization_code = self.get_auth() + + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org", + } + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + self.assertEqual(response.status_code, 200) + + content = json.loads(response.content.decode("utf-8")) + self.assertEqual(content["token_type"], "Bearer") + self.assertEqual(content["scope"], "read write") + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + + def test_basic_auth_bad_authcode(self): + """ + Request an access token using a bad authorization code + """ + self.client.login(username="hy_test_user", password="123456") + + token_request_data = { + "grant_type": "authorization_code", + "code": "BLAH", + "redirect_uri": "http://example.org", + } + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + self.assertEqual(response.status_code, 400) + + def test_basic_auth_bad_granttype(self): + """ + Request an access token using a bad grant_type string + """ + self.client.login(username="hy_test_user", password="123456") + + token_request_data = {"grant_type": "UNKNOWN", "code": "BLAH", "redirect_uri": "http://example.org"} + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + self.assertEqual(response.status_code, 400) + + def test_basic_auth_grant_expired(self): + """ + Request an access token using an expired grant token + """ + self.client.login(username="hy_test_user", password="123456") + g = Grant( + application=self.application, + user=self.hy_test_user, + code="BLAH", + expires=timezone.now(), + redirect_uri="", + scope="", + ) + g.save() + + token_request_data = { + "grant_type": "authorization_code", + "code": "BLAH", + "redirect_uri": "http://example.org", + } + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + self.assertEqual(response.status_code, 400) + + def test_basic_auth_bad_secret(self): + """ + Request an access token using basic authentication for client authentication + """ + self.client.login(username="hy_test_user", password="123456") + authorization_code = self.get_auth() + + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org", + } + auth_headers = get_basic_auth_header(self.application.client_id, "BOOM!") + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + self.assertEqual(response.status_code, 401) + + def test_basic_auth_wrong_auth_type(self): + """ + Request an access token using basic authentication for client authentication + """ + self.client.login(username="hy_test_user", password="123456") + authorization_code = self.get_auth() + + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org", + } + + user_pass = "{0}:{1}".format(self.application.client_id, self.application.client_secret) + auth_string = base64.b64encode(user_pass.encode("utf-8")) + auth_headers = { + "HTTP_AUTHORIZATION": "Wrong " + auth_string.decode("utf-8"), + } + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + self.assertEqual(response.status_code, 401) + + def test_request_body_params(self): + """ + Request an access token using client_type: public + """ + self.client.login(username="hy_test_user", password="123456") + authorization_code = self.get_auth() + + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org", + "client_id": self.application.client_id, + "client_secret": self.application.client_secret, + } + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) + self.assertEqual(response.status_code, 200) + + content = json.loads(response.content.decode("utf-8")) + self.assertEqual(content["token_type"], "Bearer") + self.assertEqual(content["scope"], "read write") + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + + def test_public(self): + """ + Request an access token using client_type: public + """ + self.client.login(username="hy_test_user", password="123456") + + self.application.client_type = Application.CLIENT_PUBLIC + self.application.save() + authorization_code = self.get_auth() + + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org", + "client_id": self.application.client_id, + } + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) + self.assertEqual(response.status_code, 200) + + content = json.loads(response.content.decode("utf-8")) + self.assertEqual(content["token_type"], "Bearer") + self.assertEqual(content["scope"], "read write") + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + + def test_id_token_public(self): + """ + Request an access token using client_type: public + """ + self.client.login(username="hy_test_user", password="123456") + + self.application.client_type = Application.CLIENT_PUBLIC + self.application.save() + authorization_code = self.get_auth(scope="openid") + + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org", + "client_id": self.application.client_id, + "scope": "openid", + } + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) + self.assertEqual(response.status_code, 200) + + content = json.loads(response.content.decode("utf-8")) + self.assertEqual(content["token_type"], "Bearer") + self.assertEqual(content["scope"], "openid") + self.assertIn("access_token", content) + self.assertIn("id_token", content) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + + def test_malicious_redirect_uri(self): + """ + Request an access token using client_type: public and ensure redirect_uri is + properly validated. + """ + self.client.login(username="hy_test_user", password="123456") + + self.application.client_type = Application.CLIENT_PUBLIC + self.application.save() + authorization_code = self.get_auth() + + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "/../", + "client_id": self.application.client_id, + } + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) + self.assertEqual(response.status_code, 400) + data = response.json() + self.assertEqual(data["error"], "invalid_request") + self.assertEqual(data["error_description"], oauthlib_errors.MismatchingRedirectURIError.description) + + def test_code_exchange_succeed_when_redirect_uri_match(self): + """ + Tests code exchange succeed when redirect uri matches the one used for code request + """ + self.client.login(username="hy_test_user", password="123456") + + # retrieve a valid authorization code + authcode_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "http://example.org?foo=bar", + "response_type": "code token", + "allow": True, + } + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) + fragment_dict = parse_qs(urlparse(response["Location"]).fragment) + authorization_code = fragment_dict["code"].pop() + + # exchange authorization code for a valid access token + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org?foo=bar", + } + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + self.assertEqual(response.status_code, 200) + + content = json.loads(response.content.decode("utf-8")) + self.assertEqual(content["token_type"], "Bearer") + self.assertEqual(content["scope"], "openid read write") + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + + def test_code_exchange_fails_when_redirect_uri_does_not_match(self): + """ + Tests code exchange fails when redirect uri does not match the one used for code request + """ + self.client.login(username="hy_test_user", password="123456") + + # retrieve a valid authorization code + authcode_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "http://example.org?foo=bar", + "response_type": "code token", + "allow": True, + } + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) + query_dict = parse_qs(urlparse(response["Location"]).fragment) + authorization_code = query_dict["code"].pop() + + # exchange authorization code for a valid access token + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org?foo=baraa", + } + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + self.assertEqual(response.status_code, 400) + data = response.json() + self.assertEqual(data["error"], "invalid_request") + self.assertEqual(data["error_description"], oauthlib_errors.MismatchingRedirectURIError.description) + + def test_code_exchange_succeed_when_redirect_uri_match_with_multiple_query_params(self): + """ + Tests code exchange succeed when redirect uri matches the one used for code request + """ + self.client.login(username="hy_test_user", password="123456") + self.application.redirect_uris = "http://localhost http://example.com?foo=bar" + self.application.save() + + # retrieve a valid authorization code + authcode_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "http://example.com?bar=baz&foo=bar", + "response_type": "code token", + "allow": True, + } + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) + fragment_dict = parse_qs(urlparse(response["Location"]).fragment) + authorization_code = fragment_dict["code"].pop() + + # exchange authorization code for a valid access token + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.com?bar=baz&foo=bar", + } + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + self.assertEqual(response.status_code, 200) + + content = json.loads(response.content.decode("utf-8")) + self.assertEqual(content["token_type"], "Bearer") + self.assertEqual(content["scope"], "openid read write") + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + + def test_id_token_code_exchange_succeed_when_redirect_uri_match_with_multiple_query_params(self): + """ + Tests code exchange succeed when redirect uri matches the one used for code request + """ + self.client.login(username="hy_test_user", password="123456") + self.application.redirect_uris = "http://localhost http://example.com?foo=bar" + self.application.save() + + # retrieve a valid authorization code + authcode_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.com?bar=baz&foo=bar", + "response_type": "code token", + "allow": True, + } + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) + fragment_dict = parse_qs(urlparse(response["Location"]).fragment) + authorization_code = fragment_dict["code"].pop() + + # exchange authorization code for a valid access token + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.com?bar=baz&foo=bar", + } + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + self.assertEqual(response.status_code, 200) + + content = json.loads(response.content.decode("utf-8")) + self.assertEqual(content["token_type"], "Bearer") + self.assertEqual(content["scope"], "openid") + self.assertIn("access_token", content) + self.assertIn("id_token", content) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + + +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +class TestHybridProtectedResource(BaseTest): + def test_resource_access_allowed(self): + self.client.login(username="hy_test_user", password="123456") + + # retrieve a valid authorization code + authcode_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid read write", + "redirect_uri": "http://example.org", + "response_type": "code token", + "allow": True, + } + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) + fragment_dict = parse_qs(urlparse(response["Location"]).fragment) + authorization_code = fragment_dict["code"].pop() + + # exchange authorization code for a valid access token + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org", + } + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + content = json.loads(response.content.decode("utf-8")) + access_token = content["access_token"] + + # use token to access the resource + auth_headers = { + "HTTP_AUTHORIZATION": "Bearer " + access_token, + } + request = self.factory.get("/fake-resource", **auth_headers) + request.user = self.hy_test_user + + view = ResourceView.as_view() + response = view(request) + self.assertEqual(response, "This is a protected resource") + + def test_id_token_resource_access_allowed(self): + self.client.login(username="hy_test_user", password="123456") + + # retrieve a valid authorization code + authcode_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.org", + "response_type": "code token", + "allow": True, + } + response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data) + fragment_dict = parse_qs(urlparse(response["Location"]).fragment) + authorization_code = fragment_dict["code"].pop() + + # exchange authorization code for a valid access token + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org", + } + auth_headers = get_basic_auth_header(self.application.client_id, self.application.client_secret) + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers) + content = json.loads(response.content.decode("utf-8")) + access_token = content["access_token"] + id_token = content["id_token"] + + # use token to access the resource + auth_headers = { + "HTTP_AUTHORIZATION": "Bearer " + access_token, + } + request = self.factory.get("/fake-resource", **auth_headers) + request.user = self.hy_test_user + + view = ResourceView.as_view() + response = view(request) + self.assertEqual(response, "This is a protected resource") + + # use id_token to access the resource + auth_headers = { + "HTTP_AUTHORIZATION": "Bearer " + id_token, + } + request = self.factory.get("/fake-resource", **auth_headers) + request.user = self.hy_test_user + + view = ResourceView.as_view() + response = view(request) + self.assertEqual(response, "This is a protected resource") + + # If the resource requires more scopes than we requested, we should get an error + view = ScopedResourceView.as_view() + response = view(request) + self.assertEqual(response.status_code, 403) + + def test_resource_access_deny(self): + auth_headers = { + "HTTP_AUTHORIZATION": "Bearer " + "faketoken", + } + request = self.factory.get("/fake-resource", **auth_headers) + request.user = self.hy_test_user + + view = ResourceView.as_view() + response = view(request) + self.assertEqual(response.status_code, 403) + + +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RO) +class TestDefaultScopesHybrid(BaseTest): + def test_pre_auth_default_scopes(self): + """ + Test response for a valid client_id with response_type: code using default scopes + """ + self.client.login(username="hy_test_user", password="123456") + + query_string = urlencode( + { + "client_id": self.application.client_id, + "response_type": "code token", + "state": "random_state_string", + "redirect_uri": "http://example.org", + } + ) + url = "{url}?{qs}".format(url=reverse("oauth2_provider:authorize"), qs=query_string) + + response = self.client.get(url) + self.assertEqual(response.status_code, 200) + + # check form is in context and form params are valid + self.assertIn("form", response.context) + + form = response.context["form"] + self.assertEqual(form["redirect_uri"].value(), "http://example.org") + self.assertEqual(form["state"].value(), "random_state_string") + self.assertEqual(form["scope"].value(), "read") + self.assertEqual(form["client_id"].value(), self.application.client_id) + + +@pytest.mark.django_db +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +def test_id_token_nonce_in_token_response(oauth2_settings, test_user, hybrid_application, client, oidc_key): + client.force_login(test_user) + auth_rsp = client.post( + reverse("oauth2_provider:authorize"), + data={ + "client_id": hybrid_application.client_id, + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.org", + "response_type": "code id_token", + "nonce": "random_nonce_string", + "allow": True, + }, + ) + assert auth_rsp.status_code == 302 + auth_data = parse_qs(urlparse(auth_rsp["Location"]).fragment) + assert "code" in auth_data + assert "id_token" in auth_data + # Decode the id token - is the nonce correct + jwt_token = jwt.JWT(key=oidc_key, jwt=auth_data["id_token"][0]) + claims = json.loads(jwt_token.claims) + assert "nonce" in claims + assert claims["nonce"] == "random_nonce_string" + code = auth_data["code"][0] + client.logout() + # Get the token response using the code + token_rsp = client.post( + reverse("oauth2_provider:token"), + data={ + "grant_type": "authorization_code", + "code": code, + "redirect_uri": "http://example.org", + "client_id": hybrid_application.client_id, + "client_secret": hybrid_application.client_secret, + "scope": "openid", + }, + ) + assert token_rsp.status_code == 200 + token_data = token_rsp.json() + assert "id_token" in token_data + # The nonce should be present in this id token also + jwt_token = jwt.JWT(key=oidc_key, jwt=token_data["id_token"]) + claims = json.loads(jwt_token.claims) + assert "nonce" in claims + assert claims["nonce"] == "random_nonce_string" + + +@pytest.mark.django_db +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +def test_claims_passed_to_code_generation( + oauth2_settings, test_user, hybrid_application, client, mocker, oidc_key +): + # Add a spy on to OAuth2Validator.finalize_id_token + mocker.patch.object( + OAuth2Validator, + "finalize_id_token", + spy_on(OAuth2Validator.finalize_id_token), + ) + claims = {"id_token": {"email": {"essential": True}}} + client.force_login(test_user) + auth_form_rsp = client.get( + reverse("oauth2_provider:authorize"), + data={ + "client_id": hybrid_application.client_id, + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.org", + "response_type": "code id_token", + "nonce": "random_nonce_string", + "claims": json.dumps(claims), + }, + ) + # Check that claims has made it in to the form to be submitted + assert auth_form_rsp.status_code == 200 + form_initial_data = auth_form_rsp.context_data["form"].initial + assert "claims" in form_initial_data + assert json.loads(form_initial_data["claims"]) == claims + # Filter out not specified values + form_data = {key: value for key, value in form_initial_data.items() if value is not None} + # Now submitting the form (with allow=True) should persist requested claims + auth_rsp = client.post( + reverse("oauth2_provider:authorize"), + data={"allow": True, **form_data}, + ) + assert auth_rsp.status_code == 302 + auth_data = parse_qs(urlparse(auth_rsp["Location"]).fragment) + assert "code" in auth_data + assert "id_token" in auth_data + assert OAuth2Validator.finalize_id_token.spy.call_count == 1 + oauthlib_request = OAuth2Validator.finalize_id_token.spy.call_args[0][4] + assert oauthlib_request.claims == claims + assert Grant.objects.get().claims == json.dumps(claims) + OAuth2Validator.finalize_id_token.spy.reset_mock() + + # Get the token response using the code + client.logout() + code = auth_data["code"][0] + token_rsp = client.post( + reverse("oauth2_provider:token"), + data={ + "grant_type": "authorization_code", + "code": code, + "redirect_uri": "http://example.org", + "client_id": hybrid_application.client_id, + "client_secret": hybrid_application.client_secret, + "scope": "openid", + }, + ) + assert token_rsp.status_code == 200 + token_data = token_rsp.json() + assert "id_token" in token_data + assert OAuth2Validator.finalize_id_token.spy.call_count == 1 + oauthlib_request = OAuth2Validator.finalize_id_token.spy.call_args[0][4] + assert oauthlib_request.claims == claims diff --git a/tests/test_implicit.py b/tests/test_implicit.py index b51d0e1da..a5863401c 100644 --- a/tests/test_implicit.py +++ b/tests/test_implicit.py @@ -1,13 +1,17 @@ +import json from urllib.parse import parse_qs, urlparse +import pytest from django.contrib.auth import get_user_model from django.test import RequestFactory, TestCase from django.urls import reverse +from jwcrypto import jwt from oauth2_provider.models import get_application_model -from oauth2_provider.settings import oauth2_settings from oauth2_provider.views import ProtectedResourceView +from . import presets + Application = get_application_model() UserModel = get_user_model() @@ -19,6 +23,7 @@ def get(self, request, *args, **kwargs): return "This is a protected resource" +@pytest.mark.usefixtures("oauth2_settings") class BaseTest(TestCase): def setUp(self): self.factory = RequestFactory() @@ -33,15 +38,13 @@ def setUp(self): authorization_grant_type=Application.GRANT_IMPLICIT, ) - oauth2_settings._SCOPES = ["read", "write"] - oauth2_settings._DEFAULT_SCOPES = ["read"] - def tearDown(self): self.application.delete() self.test_user.delete() self.dev_user.delete() +@pytest.mark.oauth2_settings(presets.DEFAULT_SCOPES_RO) class TestImplicitAuthorizationCodeView(BaseTest): def test_pre_auth_valid_client_default_scopes(self): """ @@ -237,6 +240,7 @@ def test_implicit_fails_when_redirect_uri_path_is_invalid(self): self.assertEqual(response.status_code, 400) +@pytest.mark.oauth2_settings(presets.DEFAULT_SCOPES_RO) class TestImplicitTokenView(BaseTest): def test_resource_access_allowed(self): self.client.login(username="test_user", password="123456") @@ -265,3 +269,198 @@ def test_resource_access_allowed(self): view = ResourceView.as_view() response = view(request) self.assertEqual(response, "This is a protected resource") + + +@pytest.mark.usefixtures("oidc_key") +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +class TestOpenIDConnectImplicitFlow(BaseTest): + def setUp(self): + super().setUp() + self.application.algorithm = Application.RS256_ALGORITHM + self.application.save() + + def test_id_token_post_auth_allow(self): + """ + Test authorization code is given for an allowed request with response_type: id_token + """ + self.client.login(username="test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "nonce": "random_nonce_string", + "scope": "openid", + "redirect_uri": "http://example.org", + "response_type": "id_token", + "allow": True, + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("http://example.org#", response["Location"]) + self.assertNotIn("access_token=", response["Location"]) + self.assertIn("id_token=", response["Location"]) + self.assertIn("state=random_state_string", response["Location"]) + + uri_query = urlparse(response["Location"]).fragment + uri_query_params = dict(parse_qs(uri_query, keep_blank_values=True, strict_parsing=True)) + id_token = uri_query_params["id_token"][0] + jwt_token = jwt.JWT(key=self.key, jwt=id_token) + claims = json.loads(jwt_token.claims) + self.assertIn("nonce", claims) + self.assertNotIn("at_hash", claims) + + def test_id_token_skip_authorization_completely(self): + """ + If application.skip_authorization = True, should skip the authorization page. + """ + self.client.login(username="test_user", password="123456") + self.application.skip_authorization = True + self.application.save() + + query_data = { + "client_id": self.application.client_id, + "response_type": "id_token", + "state": "random_state_string", + "nonce": "random_nonce_string", + "scope": "openid", + "redirect_uri": "http://example.org", + } + + response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) + self.assertEqual(response.status_code, 302) + self.assertIn("http://example.org#", response["Location"]) + self.assertNotIn("access_token=", response["Location"]) + self.assertIn("id_token=", response["Location"]) + self.assertIn("state=random_state_string", response["Location"]) + + uri_query = urlparse(response["Location"]).fragment + uri_query_params = dict(parse_qs(uri_query, keep_blank_values=True, strict_parsing=True)) + id_token = uri_query_params["id_token"][0] + jwt_token = jwt.JWT(key=self.key, jwt=id_token) + claims = json.loads(jwt_token.claims) + self.assertIn("nonce", claims) + self.assertNotIn("at_hash", claims) + + def test_id_token_skip_authorization_completely_missing_nonce(self): + """ + If application.skip_authorization = True, should skip the authorization page. + """ + self.client.login(username="test_user", password="123456") + self.application.skip_authorization = True + self.application.save() + + query_data = { + "client_id": self.application.client_id, + "response_type": "id_token", + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.org", + } + + response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) + self.assertEqual(response.status_code, 302) + self.assertIn("error=invalid_request", response["Location"]) + self.assertIn("error_description=Request+is+missing+mandatory+nonce+paramete", response["Location"]) + + def test_id_token_post_auth_deny(self): + """ + Test error when resource owner deny access + """ + self.client.login(username="test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "nonce": "random_nonce_string", + "scope": "openid", + "redirect_uri": "http://example.org", + "response_type": "id_token", + "allow": False, + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("error=access_denied", response["Location"]) + + def test_access_token_and_id_token_post_auth_allow(self): + """ + Test authorization code is given for an allowed request with response_type: token + """ + self.client.login(username="test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "nonce": "random_nonce_string", + "scope": "openid", + "redirect_uri": "http://example.org", + "response_type": "id_token token", + "allow": True, + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("http://example.org#", response["Location"]) + self.assertIn("access_token=", response["Location"]) + self.assertIn("id_token=", response["Location"]) + self.assertIn("state=random_state_string", response["Location"]) + + uri_query = urlparse(response["Location"]).fragment + uri_query_params = dict(parse_qs(uri_query, keep_blank_values=True, strict_parsing=True)) + id_token = uri_query_params["id_token"][0] + jwt_token = jwt.JWT(key=self.key, jwt=id_token) + claims = json.loads(jwt_token.claims) + self.assertIn("nonce", claims) + self.assertIn("at_hash", claims) + + def test_access_token_and_id_token_skip_authorization_completely(self): + """ + If application.skip_authorization = True, should skip the authorization page. + """ + self.client.login(username="test_user", password="123456") + self.application.skip_authorization = True + self.application.save() + + query_data = { + "client_id": self.application.client_id, + "response_type": "id_token token", + "state": "random_state_string", + "nonce": "random_nonce_string", + "scope": "openid", + "redirect_uri": "http://example.org", + } + + response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) + self.assertEqual(response.status_code, 302) + self.assertIn("http://example.org#", response["Location"]) + self.assertIn("access_token=", response["Location"]) + self.assertIn("id_token=", response["Location"]) + self.assertIn("state=random_state_string", response["Location"]) + + uri_query = urlparse(response["Location"]).fragment + uri_query_params = dict(parse_qs(uri_query, keep_blank_values=True, strict_parsing=True)) + id_token = uri_query_params["id_token"][0] + jwt_token = jwt.JWT(key=self.key, jwt=id_token) + claims = json.loads(jwt_token.claims) + self.assertIn("nonce", claims) + self.assertIn("at_hash", claims) + + def test_access_token_and_id_token_post_auth_deny(self): + """ + Test error when resource owner deny access + """ + self.client.login(username="test_user", password="123456") + + form_data = { + "client_id": self.application.client_id, + "state": "random_state_string", + "scope": "openid", + "redirect_uri": "http://example.org", + "response_type": "id_token token", + "allow": False, + } + + response = self.client.post(reverse("oauth2_provider:authorize"), data=form_data) + self.assertEqual(response.status_code, 302) + self.assertIn("error=access_denied", response["Location"]) diff --git a/tests/test_introspection_auth.py b/tests/test_introspection_auth.py index 5fc12b6b1..9f871cdea 100644 --- a/tests/test_introspection_auth.py +++ b/tests/test_introspection_auth.py @@ -1,6 +1,7 @@ import calendar import datetime +import pytest from django.conf.urls import include from django.contrib.auth import get_user_model from django.http import HttpResponse @@ -11,9 +12,10 @@ from oauth2_provider.models import get_access_token_model, get_application_model from oauth2_provider.oauth2_validators import OAuth2Validator -from oauth2_provider.settings import oauth2_settings from oauth2_provider.views import ScopedProtectedResourceView +from . import presets + try: from unittest import mock @@ -78,6 +80,8 @@ def json(self): @override_settings(ROOT_URLCONF=__name__) +@pytest.mark.usefixtures("oauth2_settings") +@pytest.mark.oauth2_settings(presets.INTROSPECTION_SETTINGS) class TestTokenIntrospectionAuth(TestCase): """ Tests for Authorization through token introspection @@ -114,16 +118,9 @@ def setUp(self): scope="read write dolphin", ) - oauth2_settings._SCOPES = ["read", "write", "introspection", "dolphin"] - oauth2_settings.RESOURCE_SERVER_INTROSPECTION_URL = "http://example.org/introspection" - oauth2_settings.RESOURCE_SERVER_AUTH_TOKEN = self.resource_server_token.token - oauth2_settings.READ_SCOPE = "read" - oauth2_settings.WRITE_SCOPE = "write" + self.oauth2_settings.RESOURCE_SERVER_AUTH_TOKEN = self.resource_server_token.token def tearDown(self): - oauth2_settings._SCOPES = ["read", "write"] - oauth2_settings.RESOURCE_SERVER_INTROSPECTION_URL = None - oauth2_settings.RESOURCE_SERVER_AUTH_TOKEN = None self.resource_server_token.delete() self.application.delete() AccessToken.objects.all().delete() @@ -136,9 +133,9 @@ def test_get_token_from_authentication_server_not_existing_token(self, mock_get) """ token = self.validator._get_token_from_authentication_server( self.resource_server_token.token, - oauth2_settings.RESOURCE_SERVER_INTROSPECTION_URL, - oauth2_settings.RESOURCE_SERVER_AUTH_TOKEN, - oauth2_settings.RESOURCE_SERVER_INTROSPECTION_CREDENTIALS, + self.oauth2_settings.RESOURCE_SERVER_INTROSPECTION_URL, + self.oauth2_settings.RESOURCE_SERVER_AUTH_TOKEN, + self.oauth2_settings.RESOURCE_SERVER_INTROSPECTION_CREDENTIALS, ) self.assertIsNone(token) @@ -149,9 +146,9 @@ def test_get_token_from_authentication_server_existing_token(self, mock_get): """ token = self.validator._get_token_from_authentication_server( "foo", - oauth2_settings.RESOURCE_SERVER_INTROSPECTION_URL, - oauth2_settings.RESOURCE_SERVER_AUTH_TOKEN, - oauth2_settings.RESOURCE_SERVER_INTROSPECTION_CREDENTIALS, + self.oauth2_settings.RESOURCE_SERVER_INTROSPECTION_URL, + self.oauth2_settings.RESOURCE_SERVER_AUTH_TOKEN, + self.oauth2_settings.RESOURCE_SERVER_INTROSPECTION_CREDENTIALS, ) self.assertIsInstance(token, AccessToken) self.assertEqual(token.user.username, "foo_user") diff --git a/tests/test_introspection_view.py b/tests/test_introspection_view.py index 5b3fc58f8..0f68320ca 100644 --- a/tests/test_introspection_view.py +++ b/tests/test_introspection_view.py @@ -1,14 +1,15 @@ import calendar import datetime +import pytest from django.contrib.auth import get_user_model from django.test import TestCase from django.urls import reverse from django.utils import timezone from oauth2_provider.models import get_access_token_model, get_application_model -from oauth2_provider.settings import oauth2_settings +from . import presets from .utils import get_basic_auth_header @@ -17,6 +18,8 @@ UserModel = get_user_model() +@pytest.mark.usefixtures("oauth2_settings") +@pytest.mark.oauth2_settings(presets.INTROSPECTION_SETTINGS) class TestTokenIntrospectionViews(TestCase): """ Tests for Authorized Token Introspection Views @@ -74,12 +77,7 @@ def setUp(self): scope="read write dolphin", ) - oauth2_settings._SCOPES = ["read", "write", "introspection", "dolphin"] - oauth2_settings.READ_SCOPE = "read" - oauth2_settings.WRITE_SCOPE = "write" - def tearDown(self): - oauth2_settings._SCOPES = ["read", "write"] AccessToken.objects.all().delete() Application.objects.all().delete() UserModel.objects.all().delete() diff --git a/tests/test_mixins.py b/tests/test_mixins.py index 793a5b4b4..1294b75cb 100644 --- a/tests/test_mixins.py +++ b/tests/test_mixins.py @@ -1,13 +1,25 @@ +import logging + +import pytest from django.core.exceptions import ImproperlyConfigured +from django.http import HttpResponse from django.test import RequestFactory, TestCase from django.views.generic import View from oauthlib.oauth2 import Server from oauth2_provider.oauth2_backends import OAuthLibCore from oauth2_provider.oauth2_validators import OAuth2Validator -from oauth2_provider.views.mixins import OAuthLibMixin, ProtectedResourceMixin, ScopedResourceMixin +from oauth2_provider.views.mixins import ( + OAuthLibMixin, + OIDCOnlyMixin, + ProtectedResourceMixin, + ScopedResourceMixin, +) + +from . import presets +@pytest.mark.usefixtures("oauth2_settings") class BaseTest(TestCase): @classmethod def setUpClass(cls): @@ -16,32 +28,55 @@ def setUpClass(cls): class TestOAuthLibMixin(BaseTest): - def test_missing_oauthlib_backend_class(self): + def test_missing_oauthlib_backend_class_uses_fallback(self): + class CustomOauthLibBackend: + def __init__(self, *args, **kwargs): + pass + + self.oauth2_settings.OAUTH2_BACKEND_CLASS = CustomOauthLibBackend + class TestView(OAuthLibMixin, View): server_class = Server validator_class = OAuth2Validator test_view = TestView() - self.assertRaises(ImproperlyConfigured, test_view.get_oauthlib_backend_class) + self.assertEqual(CustomOauthLibBackend, test_view.get_oauthlib_backend_class()) + core = test_view.get_oauthlib_core() + self.assertTrue(isinstance(core, CustomOauthLibBackend)) + + def test_missing_server_class_uses_fallback(self): + class CustomServer: + def __init__(self, *args, **kwargs): + pass + + self.oauth2_settings.OAUTH2_SERVER_CLASS = CustomServer - def test_missing_server_class(self): class TestView(OAuthLibMixin, View): validator_class = OAuth2Validator oauthlib_backend_class = OAuthLibCore test_view = TestView() - self.assertRaises(ImproperlyConfigured, test_view.get_server) + self.assertEqual(CustomServer, test_view.get_server_class()) + core = test_view.get_oauthlib_core() + self.assertTrue(isinstance(core.server, CustomServer)) + + def test_missing_validator_class_uses_fallback(self): + class CustomValidator: + pass + + self.oauth2_settings.OAUTH2_VALIDATOR_CLASS = CustomValidator - def test_missing_validator_class(self): class TestView(OAuthLibMixin, View): server_class = Server oauthlib_backend_class = OAuthLibCore test_view = TestView() - self.assertRaises(ImproperlyConfigured, test_view.get_server) + self.assertEqual(CustomValidator, test_view.get_validator_class()) + core = test_view.get_oauthlib_core() + self.assertTrue(isinstance(core.server.request_validator, CustomValidator)) def test_correct_server(self): class TestView(OAuthLibMixin, View): @@ -99,3 +134,38 @@ class TestView(ProtectedResourceMixin, View): view = TestView.as_view() response = view(request) self.assertEqual(response.status_code, 200) + + +@pytest.fixture +def oidc_only_view(): + class TView(OIDCOnlyMixin, View): + def get(self, *args, **kwargs): + return HttpResponse("OK") + + return TView.as_view() + + +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +def test_oidc_only_mixin_oidc_enabled(oauth2_settings, rf, oidc_only_view): + assert oauth2_settings.OIDC_ENABLED + rsp = oidc_only_view(rf.get("/")) + assert rsp.status_code == 200 + assert rsp.content.decode("utf-8") == "OK" + + +def test_oidc_only_mixin_oidc_disabled_debug(oauth2_settings, rf, settings, oidc_only_view): + assert oauth2_settings.OIDC_ENABLED is False + settings.DEBUG = True + with pytest.raises(ImproperlyConfigured) as exc: + oidc_only_view(rf.get("/")) + assert "OIDC views are not enabled" in str(exc.value) + + +def test_oidc_only_mixin_oidc_disabled_no_debug(oauth2_settings, rf, settings, oidc_only_view, caplog): + assert oauth2_settings.OIDC_ENABLED is False + settings.DEBUG = False + with caplog.at_level(logging.WARNING, logger="oauth2_provider"): + rsp = oidc_only_view(rf.get("/")) + assert rsp.status_code == 404 + assert len(caplog.records) == 1 + assert "OIDC views are not enabled" in caplog.records[0].message diff --git a/tests/test_models.py b/tests/test_models.py index afcd6b419..7b37486ca 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -10,9 +10,11 @@ get_access_token_model, get_application_model, get_grant_model, + get_id_token_model, get_refresh_token_model, ) -from oauth2_provider.settings import oauth2_settings + +from . import presets Application = get_application_model() @@ -20,6 +22,7 @@ AccessToken = get_access_token_model() RefreshToken = get_refresh_token_model() UserModel = get_user_model() +IDToken = get_id_token_model() class BaseTestModels(TestCase): @@ -108,6 +111,7 @@ def test_scopes_property(self): OAUTH2_PROVIDER_REFRESH_TOKEN_MODEL="tests.SampleRefreshToken", OAUTH2_PROVIDER_GRANT_MODEL="tests.SampleGrant", ) +@pytest.mark.usefixtures("oauth2_settings") class TestCustomModels(BaseTestModels): def test_custom_application_model(self): """ @@ -126,22 +130,16 @@ def test_custom_application_model(self): def test_custom_application_model_incorrect_format(self): # Patch oauth2 settings to use a custom Application model - oauth2_settings.APPLICATION_MODEL = "IncorrectApplicationFormat" + self.oauth2_settings.APPLICATION_MODEL = "IncorrectApplicationFormat" self.assertRaises(ValueError, get_application_model) - # Revert oauth2 settings - oauth2_settings.APPLICATION_MODEL = "oauth2_provider.Application" - def test_custom_application_model_not_installed(self): # Patch oauth2 settings to use a custom Application model - oauth2_settings.APPLICATION_MODEL = "tests.ApplicationNotInstalled" + self.oauth2_settings.APPLICATION_MODEL = "tests.ApplicationNotInstalled" self.assertRaises(LookupError, get_application_model) - # Revert oauth2 settings - oauth2_settings.APPLICATION_MODEL = "oauth2_provider.Application" - def test_custom_access_token_model(self): """ If a custom access token model is installed, it should be present in @@ -158,22 +156,16 @@ def test_custom_access_token_model(self): def test_custom_access_token_model_incorrect_format(self): # Patch oauth2 settings to use a custom AccessToken model - oauth2_settings.ACCESS_TOKEN_MODEL = "IncorrectAccessTokenFormat" + self.oauth2_settings.ACCESS_TOKEN_MODEL = "IncorrectAccessTokenFormat" self.assertRaises(ValueError, get_access_token_model) - # Revert oauth2 settings - oauth2_settings.ACCESS_TOKEN_MODEL = "oauth2_provider.AccessToken" - def test_custom_access_token_model_not_installed(self): # Patch oauth2 settings to use a custom AccessToken model - oauth2_settings.ACCESS_TOKEN_MODEL = "tests.AccessTokenNotInstalled" + self.oauth2_settings.ACCESS_TOKEN_MODEL = "tests.AccessTokenNotInstalled" self.assertRaises(LookupError, get_access_token_model) - # Revert oauth2 settings - oauth2_settings.ACCESS_TOKEN_MODEL = "oauth2_provider.AccessToken" - def test_custom_refresh_token_model(self): """ If a custom refresh token model is installed, it should be present in @@ -190,22 +182,16 @@ def test_custom_refresh_token_model(self): def test_custom_refresh_token_model_incorrect_format(self): # Patch oauth2 settings to use a custom RefreshToken model - oauth2_settings.REFRESH_TOKEN_MODEL = "IncorrectRefreshTokenFormat" + self.oauth2_settings.REFRESH_TOKEN_MODEL = "IncorrectRefreshTokenFormat" self.assertRaises(ValueError, get_refresh_token_model) - # Revert oauth2 settings - oauth2_settings.REFRESH_TOKEN_MODEL = "oauth2_provider.RefreshToken" - def test_custom_refresh_token_model_not_installed(self): # Patch oauth2 settings to use a custom AccessToken model - oauth2_settings.REFRESH_TOKEN_MODEL = "tests.RefreshTokenNotInstalled" + self.oauth2_settings.REFRESH_TOKEN_MODEL = "tests.RefreshTokenNotInstalled" self.assertRaises(LookupError, get_refresh_token_model) - # Revert oauth2 settings - oauth2_settings.REFRESH_TOKEN_MODEL = "oauth2_provider.RefreshToken" - def test_custom_grant_model(self): """ If a custom grant model is installed, it should be present in @@ -222,22 +208,16 @@ def test_custom_grant_model(self): def test_custom_grant_model_incorrect_format(self): # Patch oauth2 settings to use a custom Grant model - oauth2_settings.GRANT_MODEL = "IncorrectGrantFormat" + self.oauth2_settings.GRANT_MODEL = "IncorrectGrantFormat" self.assertRaises(ValueError, get_grant_model) - # Revert oauth2 settings - oauth2_settings.GRANT_MODEL = "oauth2_provider.Grant" - def test_custom_grant_model_not_installed(self): # Patch oauth2 settings to use a custom AccessToken model - oauth2_settings.GRANT_MODEL = "tests.GrantNotInstalled" + self.oauth2_settings.GRANT_MODEL = "tests.GrantNotInstalled" self.assertRaises(LookupError, get_grant_model) - # Revert oauth2 settings - oauth2_settings.GRANT_MODEL = "oauth2_provider.Grant" - class TestGrantModel(BaseTestModels): def setUp(self): @@ -310,6 +290,7 @@ def test_str(self): self.assertEqual("%s" % refresh_token, refresh_token.token) +@pytest.mark.usefixtures("oauth2_settings") class TestClearExpired(BaseTestModels): def setUp(self): super().setUp() @@ -341,11 +322,11 @@ def setUp(self): ) def test_clear_expired_tokens(self): - oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS = 60 + self.oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS = 60 assert clear_expired() is None def test_clear_expired_tokens_incorect_timetype(self): - oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS = "A" + self.oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS = "A" with pytest.raises(ImproperlyConfigured) as excinfo: clear_expired() result = excinfo.value.__class__.__name__ @@ -353,7 +334,7 @@ def test_clear_expired_tokens_incorect_timetype(self): def test_clear_expired_tokens_with_tokens(self): self.client.login(username="test_user", password="123456") - oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS = 0 + self.oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS = 0 ttokens = AccessToken.objects.count() expiredt = AccessToken.objects.filter(expires__lte=timezone.now()).count() assert ttokens == 2 @@ -361,3 +342,93 @@ def test_clear_expired_tokens_with_tokens(self): clear_expired() expiredt = AccessToken.objects.filter(expires__lte=timezone.now()).count() assert expiredt == 0 + + +@pytest.mark.django_db +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +def test_id_token_methods(oidc_tokens, rf): + id_token = IDToken.objects.get() + + # Token was just created, so should be valid + assert id_token.is_valid() + + # if expires is None, it should always be expired + # the column is NOT NULL, but could be NULL in sub-classes + id_token.expires = None + assert id_token.is_expired() + + # if no scopes are passed, they should be valid + assert id_token.allow_scopes(None) + + # if the requested scopes are in the token, they should be valid + assert id_token.allow_scopes(["openid"]) + + # if the requested scopes are not in the token, they should not be valid + assert id_token.allow_scopes(["fizzbuzz"]) is False + + # we should be able to get a list of the scopes on the token + assert id_token.scopes == {"openid": "OpenID connect"} + + # the id token should stringify as the JWT token + id_token_str = str(id_token) + assert str(id_token.jti) in id_token_str + assert id_token_str.endswith(str(id_token.user_id)) + + # revoking the token should delete it + id_token.revoke() + assert IDToken.objects.filter(jti=id_token.jti).count() == 0 + + +@pytest.mark.django_db +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +def test_application_key(oauth2_settings, application): + # RS256 key + key = application.jwk_key + assert key.key_type == "RSA" + + # RS256 key, but not configured + oauth2_settings.OIDC_RSA_PRIVATE_KEY = None + with pytest.raises(ImproperlyConfigured) as exc: + application.jwk_key + assert "You must set OIDC_RSA_PRIVATE_KEY" in str(exc.value) + + # HS256 key + application.algorithm = Application.HS256_ALGORITHM + key = application.jwk_key + assert key.key_type == "oct" + + # No algorithm + application.algorithm = Application.NO_ALGORITHM + with pytest.raises(ImproperlyConfigured) as exc: + application.jwk_key + assert "This application does not support signed tokens" == str(exc.value) + + +@pytest.mark.django_db +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +def test_application_clean(oauth2_settings, application): + # RS256, RSA key is configured + application.clean() + + # RS256, RSA key is not configured + oauth2_settings.OIDC_RSA_PRIVATE_KEY = None + with pytest.raises(ValidationError) as exc: + application.clean() + assert "You must set OIDC_RSA_PRIVATE_KEY" in str(exc.value) + + # HS256 algorithm, auth code + confidential -> allowed + application.algorithm = Application.HS256_ALGORITHM + application.clean() + + # HS256, auth code + public -> forbidden + application.client_type = Application.CLIENT_PUBLIC + with pytest.raises(ValidationError) as exc: + application.clean() + assert "You cannot use HS256" in str(exc.value) + + # HS256, hybrid + confidential -> forbidden + application.client_type = Application.CLIENT_CONFIDENTIAL + application.authorization_grant_type = Application.GRANT_OPENID_HYBRID + with pytest.raises(ValidationError) as exc: + application.clean() + assert "You cannot use HS256" in str(exc.value) diff --git a/tests/test_oauth2_backends.py b/tests/test_oauth2_backends.py index f318ccde1..860cbb461 100644 --- a/tests/test_oauth2_backends.py +++ b/tests/test_oauth2_backends.py @@ -1,5 +1,6 @@ import json +import pytest from django.test import RequestFactory, TestCase from oauth2_provider.backends import get_oauthlib_core @@ -12,15 +13,16 @@ import mock +@pytest.mark.usefixtures("oauth2_settings") class TestOAuthLibCoreBackend(TestCase): def setUp(self): self.factory = RequestFactory() self.oauthlib_core = OAuthLibCore() def test_swappable_server_class(self): - with mock.patch("oauth2_provider.oauth2_backends.oauth2_settings.OAUTH2_SERVER_CLASS"): - oauthlib_core = OAuthLibCore() - self.assertTrue(isinstance(oauthlib_core.server, mock.MagicMock)) + self.oauth2_settings.OAUTH2_SERVER_CLASS = mock.MagicMock + oauthlib_core = OAuthLibCore() + self.assertTrue(isinstance(oauthlib_core.server, mock.MagicMock)) def test_form_urlencoded_extract_params(self): payload = "grant_type=password&username=john&password=123456" diff --git a/tests/test_oauth2_validators.py b/tests/test_oauth2_validators.py index 21b0fcfa2..7997d3bca 100644 --- a/tests/test_oauth2_validators.py +++ b/tests/test_oauth2_validators.py @@ -1,15 +1,21 @@ import contextlib import datetime +import json +import pytest from django.contrib.auth import get_user_model from django.test import TestCase, TransactionTestCase from django.utils import timezone +from jwcrypto import jwt from oauthlib.common import Request from oauth2_provider.exceptions import FatalClientError from oauth2_provider.models import get_access_token_model, get_application_model, get_refresh_token_model +from oauth2_provider.oauth2_backends import get_oauthlib_core from oauth2_provider.oauth2_validators import OAuth2Validator +from . import presets + try: from unittest import mock @@ -440,3 +446,77 @@ def test_response_when_auth_server_response_return_404(self): "Not Found.\nNoneType: None", mock_log.output, ) + + +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +def test_oidc_endpoint_generation(oauth2_settings, rf): + oauth2_settings.OIDC_ISS_ENDPOINT = "" + django_request = rf.get("/") + request = Request("/", headers=django_request.META) + validator = OAuth2Validator() + oidc_issuer_endpoint = validator.get_oidc_issuer_endpoint(request) + assert oidc_issuer_endpoint == "http://testserver/o" + + +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +def test_oidc_endpoint_generation_ssl(oauth2_settings, rf, settings): + oauth2_settings.OIDC_ISS_ENDPOINT = "" + django_request = rf.get("/", secure=True) + # Calling the settings method with a django https request should generate a https url + oidc_issuer_endpoint = oauth2_settings.oidc_issuer(django_request) + assert oidc_issuer_endpoint == "https://testserver/o" + + # Should also work with an oauthlib request (via validator) + core = get_oauthlib_core() + uri, http_method, body, headers = core._extract_params(django_request) + request = Request(uri=uri, http_method=http_method, body=body, headers=headers) + validator = OAuth2Validator() + oidc_issuer_endpoint = validator.get_oidc_issuer_endpoint(request) + assert oidc_issuer_endpoint == "https://testserver/o" + + +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +def test_get_jwt_bearer_token(oauth2_settings, mocker): + # oauthlib instructs us to make get_jwt_bearer_token call get_id_token + request = mocker.MagicMock(wraps=Request) + validator = OAuth2Validator() + mock_get_id_token = mocker.patch.object(validator, "get_id_token") + validator.get_jwt_bearer_token(None, None, request) + assert mock_get_id_token.call_count == 1 + assert mock_get_id_token.call_args[0] == (None, None, request) + assert mock_get_id_token.call_args[1] == {} + + +@pytest.mark.django_db +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +def test_validate_id_token_expired_jwt(oauth2_settings, mocker, oidc_tokens): + mocker.patch("oauth2_provider.oauth2_validators.jwt.JWT", side_effect=jwt.JWTExpired) + validator = OAuth2Validator() + status = validator.validate_id_token(oidc_tokens.id_token, ["openid"], mocker.sentinel.request) + assert status is False + + +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +def test_validate_id_token_no_token(oauth2_settings, mocker): + validator = OAuth2Validator() + status = validator.validate_id_token("", ["openid"], mocker.sentinel.request) + assert status is False + + +@pytest.mark.django_db +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +def test_validate_id_token_app_removed(oauth2_settings, mocker, oidc_tokens): + oidc_tokens.application.delete() + validator = OAuth2Validator() + status = validator.validate_id_token(oidc_tokens.id_token, ["openid"], mocker.sentinel.request) + assert status is False + + +@pytest.mark.django_db +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +def test_validate_id_token_bad_token_no_aud(oauth2_settings, mocker, oidc_key): + token = jwt.JWT(header=json.dumps({"alg": "RS256"}), claims=json.dumps({"bad": "token"})) + token.make_signed_token(oidc_key) + validator = OAuth2Validator() + status = validator.validate_id_token(token.serialize(), ["openid"], mocker.sentinel.request) + assert status is False diff --git a/tests/test_oidc_views.py b/tests/test_oidc_views.py new file mode 100644 index 000000000..3e3a5538c --- /dev/null +++ b/tests/test_oidc_views.py @@ -0,0 +1,139 @@ +import pytest +from django.test import TestCase +from django.urls import reverse + +from oauth2_provider.oauth2_validators import OAuth2Validator + +from . import presets + + +@pytest.mark.usefixtures("oauth2_settings") +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +class TestConnectDiscoveryInfoView(TestCase): + def test_get_connect_discovery_info(self): + expected_response = { + "issuer": "http://localhost", + "authorization_endpoint": "http://localhost/o/authorize/", + "token_endpoint": "http://localhost/o/token/", + "userinfo_endpoint": "http://localhost/userinfo/", + "jwks_uri": "http://localhost/o/.well-known/jwks.json", + "response_types_supported": [ + "code", + "token", + "id_token", + "id_token token", + "code token", + "code id_token", + "code id_token token", + ], + "subject_types_supported": ["public"], + "id_token_signing_alg_values_supported": ["RS256", "HS256"], + "token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"], + } + response = self.client.get(reverse("oauth2_provider:oidc-connect-discovery-info")) + self.assertEqual(response.status_code, 200) + assert response.json() == expected_response + + def test_get_connect_discovery_info_without_issuer_url(self): + self.oauth2_settings.OIDC_ISS_ENDPOINT = None + self.oauth2_settings.OIDC_USERINFO_ENDPOINT = None + expected_response = { + "issuer": "http://testserver/o", + "authorization_endpoint": "http://testserver/o/authorize/", + "token_endpoint": "http://testserver/o/token/", + "userinfo_endpoint": "http://testserver/o/userinfo/", + "jwks_uri": "http://testserver/o/.well-known/jwks.json", + "response_types_supported": [ + "code", + "token", + "id_token", + "id_token token", + "code token", + "code id_token", + "code id_token token", + ], + "subject_types_supported": ["public"], + "id_token_signing_alg_values_supported": ["RS256", "HS256"], + "token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"], + } + response = self.client.get(reverse("oauth2_provider:oidc-connect-discovery-info")) + self.assertEqual(response.status_code, 200) + assert response.json() == expected_response + + def test_get_connect_discovery_info_without_rsa_key(self): + self.oauth2_settings.OIDC_RSA_PRIVATE_KEY = None + response = self.client.get(reverse("oauth2_provider:oidc-connect-discovery-info")) + self.assertEqual(response.status_code, 200) + assert response.json()["id_token_signing_alg_values_supported"] == ["HS256"] + + +@pytest.mark.usefixtures("oauth2_settings") +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +class TestJwksInfoView(TestCase): + def test_get_jwks_info(self): + expected_response = { + "keys": [ + { + "alg": "RS256", + "use": "sig", + "kid": "s4a1o8mFEd1tATAIH96caMlu4hOxzBUaI2QTqbYNBHs", + "e": "AQAB", + "kty": "RSA", + "n": "mwmIeYdjZkLgalTuhvvwjvnB5vVQc7G9DHgOm20Hw524bLVTk49IXJ2Scw42HOmowWWX-oMVT_ca3ZvVIeffVSN1-TxVy2zB65s0wDMwhiMoPv35z9IKHGMZgl9vlyso_2b7daVF_FQDdgIayUn8TQylBxEU1RFfW0QSYOBdAt8", # noqa + } + ] + } + response = self.client.get(reverse("oauth2_provider:jwks-info")) + self.assertEqual(response.status_code, 200) + assert response.json() == expected_response + + def test_get_jwks_info_no_rsa_key(self): + self.oauth2_settings.OIDC_RSA_PRIVATE_KEY = None + response = self.client.get(reverse("oauth2_provider:jwks-info")) + self.assertEqual(response.status_code, 200) + assert response.json() == {"keys": []} + + +@pytest.mark.django_db +@pytest.mark.parametrize("method", ["get", "post"]) +def test_userinfo_endpoint(oidc_tokens, client, method): + auth_header = "Bearer %s" % oidc_tokens.access_token + rsp = getattr(client, method)( + reverse("oauth2_provider:user-info"), + HTTP_AUTHORIZATION=auth_header, + ) + data = rsp.json() + assert "sub" in data + assert data["sub"] == str(oidc_tokens.user.pk) + + +@pytest.mark.django_db +def test_userinfo_endpoint_bad_token(oidc_tokens, client): + # No access token + rsp = client.get(reverse("oauth2_provider:user-info")) + assert rsp.status_code == 401 + # Bad access token + rsp = client.get( + reverse("oauth2_provider:user-info"), + HTTP_AUTHORIZATION="Bearer not-a-real-token", + ) + assert rsp.status_code == 401 + + +@pytest.mark.django_db +def test_userinfo_endpoint_custom_claims(oidc_tokens, client, oauth2_settings): + class CustomValidator(OAuth2Validator): + def get_additional_claims(self, request): + return {"state": "very nice"} + + oidc_tokens.oauth2_settings.OAUTH2_VALIDATOR_CLASS = CustomValidator + auth_header = "Bearer %s" % oidc_tokens.access_token + rsp = client.get( + reverse("oauth2_provider:user-info"), + HTTP_AUTHORIZATION=auth_header, + ) + data = rsp.json() + assert "sub" in data + assert data["sub"] == str(oidc_tokens.user.pk) + assert "state" in data + assert data["state"] == "very nice" diff --git a/tests/test_password.py b/tests/test_password.py index f50404f9f..953b076e2 100644 --- a/tests/test_password.py +++ b/tests/test_password.py @@ -1,11 +1,11 @@ import json +import pytest from django.contrib.auth import get_user_model from django.test import RequestFactory, TestCase from django.urls import reverse from oauth2_provider.models import get_application_model -from oauth2_provider.settings import oauth2_settings from oauth2_provider.views import ProtectedResourceView from .utils import get_basic_auth_header @@ -21,6 +21,7 @@ def get(self, request, *args, **kwargs): return "This is a protected resource" +@pytest.mark.usefixtures("oauth2_settings") class BaseTest(TestCase): def setUp(self): self.factory = RequestFactory() @@ -34,9 +35,6 @@ def setUp(self): authorization_grant_type=Application.GRANT_PASSWORD, ) - oauth2_settings._SCOPES = ["read", "write"] - oauth2_settings._DEFAULT_SCOPES = ["read", "write"] - def tearDown(self): self.application.delete() self.test_user.delete() @@ -60,8 +58,8 @@ def test_get_token(self): content = json.loads(response.content.decode("utf-8")) self.assertEqual(content["token_type"], "Bearer") - self.assertEqual(content["scope"], "read write") - self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + self.assertEqual(set(content["scope"].split()), {"read", "write"}) + self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) def test_bad_credentials(self): """ diff --git a/tests/test_rest_framework.py b/tests/test_rest_framework.py index f23891dca..a25611b93 100644 --- a/tests/test_rest_framework.py +++ b/tests/test_rest_framework.py @@ -1,5 +1,6 @@ from datetime import timedelta +import pytest from django.conf.urls import include from django.contrib.auth import get_user_model from django.core.exceptions import ImproperlyConfigured @@ -22,13 +23,8 @@ TokenMatchesOASRequirements, ) from oauth2_provider.models import get_access_token_model, get_application_model -from oauth2_provider.settings import oauth2_settings - -try: - from unittest import mock -except ImportError: - import mock +from . import presets Application = get_application_model() @@ -131,10 +127,10 @@ class AuthenticationNoneOAuth2View(MockView): @override_settings(ROOT_URLCONF=__name__) +@pytest.mark.usefixtures("oauth2_settings") +@pytest.mark.oauth2_settings(presets.REST_FRAMEWORK_SCOPES) class TestOAuth2Authentication(TestCase): def setUp(self): - oauth2_settings._SCOPES = ["read", "write", "scope1", "scope2", "resource1"] - self.test_user = UserModel.objects.create_user("test_user", "test@example.com", "123456") self.dev_user = UserModel.objects.create_user("dev_user", "dev@example.com", "123456") @@ -154,9 +150,6 @@ def setUp(self): application=self.application, ) - def tearDown(self): - oauth2_settings._SCOPES = ["read", "write"] - def _create_authorization_header(self, token): return "Bearer {0}".format(token) @@ -311,8 +304,8 @@ def test_resource_scoped_permission_post_denied(self): response = self.client.post("/oauth2-resource-scoped-test/", HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 403) - @mock.patch.object(oauth2_settings, "ERROR_RESPONSE_WITH_SCOPES", new=True) def test_required_scope_in_response(self): + self.oauth2_settings.ERROR_RESPONSE_WITH_SCOPES = True self.access_token.scope = "scope2" self.access_token.save() diff --git a/tests/test_scopes.py b/tests/test_scopes.py index d2efa5856..a310e223a 100644 --- a/tests/test_scopes.py +++ b/tests/test_scopes.py @@ -1,13 +1,13 @@ import json from urllib.parse import parse_qs, urlparse +import pytest from django.contrib.auth import get_user_model from django.core.exceptions import ImproperlyConfigured from django.test import RequestFactory, TestCase from django.urls import reverse from oauth2_provider.models import get_access_token_model, get_application_model, get_grant_model -from oauth2_provider.settings import oauth2_settings from oauth2_provider.views import ReadWriteScopedResourceView, ScopedProtectedResourceView from .utils import get_basic_auth_header @@ -42,6 +42,19 @@ def post(self, request, *args, **kwargs): return "This is a write protected resource" +SCOPE_SETTINGS = { + "SCOPES": { + "read": "Read scope", + "write": "Write scope", + "scope1": "Custom scope 1", + "scope2": "Custom scope 2", + "scope3": "Custom scope 3", + }, +} + + +@pytest.mark.usefixtures("oauth2_settings") +@pytest.mark.oauth2_settings(SCOPE_SETTINGS) class BaseTest(TestCase): def setUp(self): self.factory = RequestFactory() @@ -56,12 +69,7 @@ def setUp(self): authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE, ) - oauth2_settings._SCOPES = ["read", "write", "scope1", "scope2", "scope3"] - oauth2_settings.READ_SCOPE = "read" - oauth2_settings.WRITE_SCOPE = "write" - def tearDown(self): - oauth2_settings._SCOPES = ["read", "write"] self.application.delete() self.test_user.delete() self.dev_user.delete() @@ -325,27 +333,27 @@ def get_access_token(self, scopes): return content["access_token"] def test_improperly_configured(self): - oauth2_settings.SCOPES = {"scope1": "Scope 1"} + self.oauth2_settings.SCOPES = {"scope1": "Scope 1"} request = self.factory.get("/fake") view = ReadWriteResourceView.as_view() self.assertRaises(ImproperlyConfigured, view, request) - oauth2_settings.SCOPES = {"read": "Read Scope", "write": "Write Scope"} - oauth2_settings.READ_SCOPE = "ciccia" + self.oauth2_settings.SCOPES = {"read": "Read Scope", "write": "Write Scope"} + self.oauth2_settings.READ_SCOPE = "ciccia" view = ReadWriteResourceView.as_view() self.assertRaises(ImproperlyConfigured, view, request) def test_properly_configured(self): - oauth2_settings.SCOPES = {"scope1": "Scope 1"} + self.oauth2_settings.SCOPES = {"scope1": "Scope 1"} request = self.factory.get("/fake") view = ReadWriteResourceView.as_view() self.assertRaises(ImproperlyConfigured, view, request) - oauth2_settings.SCOPES = {"read": "Read Scope", "write": "Write Scope"} - oauth2_settings.READ_SCOPE = "ciccia" + self.oauth2_settings.SCOPES = {"read": "Read Scope", "write": "Write Scope"} + self.oauth2_settings.READ_SCOPE = "ciccia" view = ReadWriteResourceView.as_view() self.assertRaises(ImproperlyConfigured, view, request) diff --git a/tests/test_scopes_backend.py b/tests/test_scopes_backend.py index 5f629613e..925a4e3c5 100644 --- a/tests/test_scopes_backend.py +++ b/tests/test_scopes_backend.py @@ -3,9 +3,9 @@ def test_settings_scopes_get_available_scopes(): scopes = SettingsScopes() - assert scopes.get_available_scopes() == ["read", "write"] + assert set(scopes.get_available_scopes()) == {"read", "write"} def test_settings_scopes_get_default_scopes(): scopes = SettingsScopes() - assert scopes.get_default_scopes() == ["read", "write"] + assert set(scopes.get_default_scopes()) == {"read", "write"} diff --git a/tests/test_settings.py b/tests/test_settings.py index 379d12c2e..52bdafe03 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -1,20 +1,27 @@ +import pytest +from django.core.exceptions import ImproperlyConfigured from django.test import TestCase from django.test.utils import override_settings +from oauthlib.common import Request from oauth2_provider.admin import ( get_access_token_admin_class, get_application_admin_class, get_grant_admin_class, + get_id_token_admin_class, get_refresh_token_admin_class, ) -from oauth2_provider.settings import OAuth2ProviderSettings, oauth2_settings +from oauth2_provider.settings import OAuth2ProviderSettings, oauth2_settings, perform_import from tests.admin import ( CustomAccessTokenAdmin, CustomApplicationAdmin, CustomGrantAdmin, + CustomIDTokenAdmin, CustomRefreshTokenAdmin, ) +from . import presets + class TestAdminClass(TestCase): def test_import_error_message_maintained(self): @@ -47,7 +54,15 @@ def test_get_grant_admin_class(self): """ grant_admin_class = get_grant_admin_class() default_grant_admin_class = oauth2_settings.GRANT_ADMIN_CLASS - assert grant_admin_class, default_grant_admin_class + assert grant_admin_class == default_grant_admin_class + + def test_get_id_token_admin_class(self): + """ + Test for getting class for ID token admin. + """ + id_token_admin_class = get_id_token_admin_class() + default_id_token_admin_class = oauth2_settings.ID_TOKEN_ADMIN_CLASS + assert id_token_admin_class == default_id_token_admin_class def test_get_refresh_token_admin_class(self): """ @@ -81,6 +96,14 @@ def test_get_custom_grant_admin_class(self): grant_admin_class = get_grant_admin_class() assert grant_admin_class == CustomGrantAdmin + @override_settings(OAUTH2_PROVIDER={"ID_TOKEN_ADMIN_CLASS": "tests.admin.CustomIDTokenAdmin"}) + def test_get_custom_id_token_admin_class(self): + """ + Test for getting custom class for ID token admin. + """ + id_token_admin_class = get_id_token_admin_class() + assert id_token_admin_class == CustomIDTokenAdmin + @override_settings(OAUTH2_PROVIDER={"REFRESH_TOKEN_ADMIN_CLASS": "tests.admin.CustomRefreshTokenAdmin"}) def test_get_custom_refresh_token_admin_class(self): """ @@ -88,3 +111,59 @@ def test_get_custom_refresh_token_admin_class(self): """ refresh_token_admin_class = get_refresh_token_admin_class() assert refresh_token_admin_class == CustomRefreshTokenAdmin + + +def test_perform_import_when_none(): + assert perform_import(None, "REFRESH_TOKEN_ADMIN_CLASS") is None + + +def test_perform_import_list(): + imports = ["tests.admin.CustomIDTokenAdmin", "tests.admin.CustomGrantAdmin"] + assert perform_import(imports, "SOME_CLASSES") == [CustomIDTokenAdmin, CustomGrantAdmin] + + +def test_perform_import_already_imported(): + cls = perform_import(CustomRefreshTokenAdmin, "REFRESH_TOKEN_ADMIN_CLASS") + assert cls == CustomRefreshTokenAdmin + + +def test_invalid_scopes_raises_error(): + settings = OAuth2ProviderSettings( + { + "SCOPES": {"foo": "foo scope"}, + "DEFAULT_SCOPES": ["bar"], + } + ) + with pytest.raises(ImproperlyConfigured) as exc: + settings._DEFAULT_SCOPES + assert str(exc.value) == "Defined DEFAULT_SCOPES not present in SCOPES" + + +def test_missing_mandatory_setting_raises_error(): + settings = OAuth2ProviderSettings( + user_settings={}, defaults={"very_important": None}, mandatory=["very_important"] + ) + with pytest.raises(AttributeError) as exc: + settings.very_important + assert str(exc.value) == "OAuth2Provider setting: very_important is mandatory" + + +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +@pytest.mark.parametrize("issuer_setting", ["http://foo.com/", None]) +@pytest.mark.parametrize("request_type", ["django", "oauthlib"]) +def test_generating_iss_endpoint(oauth2_settings, issuer_setting, request_type, rf): + oauth2_settings.OIDC_ISS_ENDPOINT = issuer_setting + if request_type == "django": + request = rf.get("/") + elif request_type == "oauthlib": + request = Request("/", headers=rf.get("/").META) + expected = issuer_setting or "http://testserver/o" + assert oauth2_settings.oidc_issuer(request) == expected + + +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) +def test_generating_iss_endpoint_type_error(oauth2_settings): + oauth2_settings.OIDC_ISS_ENDPOINT = None + with pytest.raises(TypeError) as exc: + oauth2_settings.oidc_issuer(None) + assert str(exc.value) == "request must be a django or oauthlib request: got None" diff --git a/tests/test_token_revocation.py b/tests/test_token_revocation.py index 5274ee13e..1ed1c9119 100644 --- a/tests/test_token_revocation.py +++ b/tests/test_token_revocation.py @@ -6,7 +6,6 @@ from django.utils import timezone from oauth2_provider.models import get_access_token_model, get_application_model, get_refresh_token_model -from oauth2_provider.settings import oauth2_settings Application = get_application_model() @@ -29,8 +28,6 @@ def setUp(self): authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE, ) - oauth2_settings._SCOPES = ["read", "write"] - def tearDown(self): self.application.delete() self.test_user.delete() diff --git a/tests/test_validators.py b/tests/test_validators.py index 82930a9d7..0760e0290 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -1,10 +1,11 @@ +import pytest from django.core.validators import ValidationError from django.test import TestCase -from oauth2_provider.settings import oauth2_settings from oauth2_provider.validators import RedirectURIValidator +@pytest.mark.usefixtures("oauth2_settings") class TestValidators(TestCase): def test_validate_good_uris(self): validator = RedirectURIValidator(allowed_schemes=["https"]) @@ -37,7 +38,7 @@ def test_validate_custom_uri_scheme(self): def test_validate_bad_uris(self): validator = RedirectURIValidator(allowed_schemes=["https"]) - oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES = ["https", "good"] + self.oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES = ["https", "good"] bad_uris = [ "http:/example.com", "HTTP://localhost", diff --git a/tests/urls.py b/tests/urls.py index f4b22a4d4..0661a9336 100644 --- a/tests/urls.py +++ b/tests/urls.py @@ -7,7 +7,5 @@ urlpatterns = [ path("o/", include("oauth2_provider.urls", namespace="oauth2_provider")), + path("admin/", admin.site.urls), ] - - -urlpatterns += [path("admin/", admin.site.urls)] diff --git a/tests/utils.py b/tests/utils.py index ec2590512..b7dc2001a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,4 +1,5 @@ import base64 +from unittest import mock def get_basic_auth_header(user, password): @@ -13,3 +14,19 @@ def get_basic_auth_header(user, password): } return auth_headers + + +def spy_on(meth): + """ + Util function to add a spy onto a method of a class. + """ + spy = mock.MagicMock() + + def wrapper(self, *args, **kwargs): + spy(self, *args, **kwargs) + return_value = meth(self, *args, **kwargs) + spy.returned = return_value + return return_value + + wrapper.spy = spy + return wrapper diff --git a/tox.ini b/tox.ini index 8d0611633..3016d024c 100644 --- a/tox.ini +++ b/tox.ini @@ -14,10 +14,17 @@ python = [pytest] django_find_project = false +addopts = + --cov=oauth2_provider + --cov-report= + --cov-append + -s +markers = + oauth2_settings: Custom OAuth2 settings to use - use with oauth2_settings fixture [testenv] commands = - pytest --cov=oauth2_provider --cov-report= --cov-append {posargs} + pytest {posargs} coverage report coverage xml setenv = @@ -31,14 +38,16 @@ deps = djmain: https://github.com/django/django/archive/main.tar.gz djangorestframework oauthlib>=3.1.0 + jwcrypto coverage pytest pytest-cov pytest-django pytest-xdist + pytest-mock requests passenv = - PYTEST_ADDOPTS + PYTEST_ADDOPTS [testenv:py{38,39}-djmain] ignore_errors = true @@ -57,6 +66,7 @@ deps = m2r>=0.2.1 sphinx-rtd-theme livedocs: sphinx-autobuild + jwcrypto [testenv:flake8] basepython = python3.8 @@ -84,6 +94,9 @@ commands = source = oauth2_provider omit = */migrations/* +[coverage:report] +show_missing = True + [flake8] max-line-length = 110 exclude = docs/, oauth2_provider/migrations/, tests/migrations/, .tox/