Skip to content
This repository was archived by the owner on Aug 29, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions oauth2_provider/oauth2_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,9 @@ def create_authorization_response(self, uri, request, scopes, credentials, body,

# add current user to credentials. this will be used by OAUTH2_VALIDATOR_CLASS
credentials["user"] = request.user

headers, body, status = self.server.create_authorization_response(
uri=uri, scopes=scopes, credentials=credentials, body=body)
uri=uri, scopes=scopes, credentials=credentials, body=body, headers=request.META
)
redirect_uri = headers.get("Location", None)

return redirect_uri, headers, body, status
Expand Down
40 changes: 40 additions & 0 deletions tests/test_authorization_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import json
import re
from urllib.parse import parse_qs, urlparse
from jwcrypto import jwk, jwt

from django.contrib.auth import get_user_model
from django.test import RequestFactory, TestCase
Expand Down Expand Up @@ -1253,6 +1254,45 @@ def test_id_token_public(self):
content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS
)

def test_id_token_public_oidc_capable(self):
"""
Check that the id token includes our custom iss
"""
iss_entity = "http://testserver/o"
oauth2_settings.OIDC_ISS_ENDPOINT = None
oauth2_settings.OIDC_USERINFO_ENDPOINT = None

self.client.login(username="test_user", password="123456")

self.application.client_type = Application.CLIENT_PUBLIC
self.application.save()
authorization_code = self.get_auth(scope="openid")

token_request_data = {
"grant_type": "authorization_code",
"code": authorization_code,
"redirect_uri": "http://example.org",
"client_id": self.application.client_id,
"scope": "openid",
}

response = self.client.post(
reverse("oauth2_provider:token"), data=token_request_data
)
self.assertEqual(response.status_code, 200)

# Unload and decode the jwt and check the iss entity matches
content = json.loads(response.content.decode("utf-8"))
key = jwk.JWK.from_pem(oauth2_settings.OIDC_RSA_PRIVATE_KEY.encode("utf8"))
jwt_token = jwt.JWT(key=key, jwt=content["id_token"])

# Find our testserver iss entity
self.assertIn(iss_entity, jwt_token.claims)

# Turn back on the OIDC specific endpoints
oauth2_settings.OIDC_ISS_ENDPOINT = "http://localhost"
oauth2_settings.OIDC_USERINFO_ENDPOINT = "http://localhost/userinfo/"

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You might want to move this logic into before/after functions. If the test fails before this then these lines won’t be called and the settings won’t get restored.

The other thing I’d consider doing is storing the initial values in variables and then setting to those variables after instead of hardcoding the settings here.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yah I was pondering along similar lines but I wanted to keep with the same style that are in the other tests inside test_oidc_views.

class TestConnectDiscoveryInfoView(TestCase):
    def test_get_connect_discovery_info(self):
        expected_response = {
            "issuer": "http://localhost",
            "authorization_endpoint": "http://localhost/o/authorize/",
            "token_endpoint": "http://localhost/o/token/",
            "userinfo_endpoint": "http://localhost/userinfo/",
            "jwks_uri": "http://localhost/o/jwks/",
            "response_types_supported": [
                "code",
                "token",
                "id_token",
                "id_token token",
                "code token",
                "code id_token",
                "code id_token token"
            ],
            "subject_types_supported": ["public"],
            "id_token_signing_alg_values_supported": ["RS256", "HS256"],
            "token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"]
        }
        response = self.client.get(reverse("oauth2_provider:oidc-connect-discovery-info"))
        self.assertEqual(response.status_code, 200)
        assert response.json() == expected_response

    def test_get_connect_discovery_info_without_issuer_url(self):
        oauth2_settings.OIDC_ISS_ENDPOINT = None
        oauth2_settings.OIDC_USERINFO_ENDPOINT = None
        expected_response = {
            "issuer": "http://testserver/o",
            "authorization_endpoint": "http://testserver/o/authorize/",
            "token_endpoint": "http://testserver/o/token/",
            "userinfo_endpoint": "http://testserver/o/userinfo/",
            "jwks_uri": "http://testserver/o/jwks/",
            "response_types_supported": [
                "code",
                "token",
                "id_token",
                "id_token token",
                "code token",
                "code id_token",
                "code id_token token"
            ],
            "subject_types_supported": ["public"],
            "id_token_signing_alg_values_supported": ["RS256", "HS256"],
            "token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"]
        }
        response = self.client.get(reverse("oauth2_provider:oidc-connect-discovery-info"))
        self.assertEqual(response.status_code, 200)
        assert response.json() == expected_response
        oauth2_settings.OIDC_ISS_ENDPOINT = "http://localhost"
        oauth2_settings.OIDC_USERINFO_ENDPOINT = "http://localhost/userinfo/"


class TestJwksInfoView(TestCase):
    def test_get_jwks_info(self):
        expected_response = {
            "keys": [{
                "alg": "RS256",
                "use": "sig",
                "kid": "s4a1o8mFEd1tATAIH96caMlu4hOxzBUaI2QTqbYNBHs",
                "e": "AQAB",
                "kty": "RSA",
                "n": "mwmIeYdjZkLgalTuhvvwjvnB5vVQc7G9DHgOm20Hw524bLVTk49IXJ2Scw42HOmowWWX-oMVT_ca3ZvVIeffVSN1-TxVy2zB65s0wDMwhiMoPv35z9IKHGMZgl9vlyso_2b7daVF_FQDdgIayUn8TQylBxEU1RFfW0QSYOBdAt8"  # noqa
            }]
        }
        response = self.client.get(reverse("oauth2_provider:jwks-info"))
        self.assertEqual(response.status_code, 200)
        assert response.json() == expected_response


def test_public_pkce_S256_authorize_get(self):
"""
Request an access token using client_type: public
Expand Down