diff --git a/oauth2_provider/oauth2_backends.py b/oauth2_provider/oauth2_backends.py index 404add70e..d4d564661 100644 --- a/oauth2_provider/oauth2_backends.py +++ b/oauth2_provider/oauth2_backends.py @@ -123,9 +123,9 @@ def create_authorization_response(self, uri, request, scopes, credentials, body, # add current user to credentials. this will be used by OAUTH2_VALIDATOR_CLASS credentials["user"] = request.user - headers, body, status = self.server.create_authorization_response( - uri=uri, scopes=scopes, credentials=credentials, body=body) + uri=uri, scopes=scopes, credentials=credentials, body=body, headers=request.META + ) redirect_uri = headers.get("Location", None) return redirect_uri, headers, body, status diff --git a/tests/test_authorization_code.py b/tests/test_authorization_code.py index e4eb8ae81..b4cf04caa 100644 --- a/tests/test_authorization_code.py +++ b/tests/test_authorization_code.py @@ -4,6 +4,7 @@ import json import re from urllib.parse import parse_qs, urlparse +from jwcrypto import jwk, jwt from django.contrib.auth import get_user_model from django.test import RequestFactory, TestCase @@ -1253,6 +1254,45 @@ def test_id_token_public(self): content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS ) + def test_id_token_public_oidc_capable(self): + """ + Check that the id token includes our custom iss + """ + iss_entity = "http://testserver/o" + oauth2_settings.OIDC_ISS_ENDPOINT = None + oauth2_settings.OIDC_USERINFO_ENDPOINT = None + + self.client.login(username="test_user", password="123456") + + self.application.client_type = Application.CLIENT_PUBLIC + self.application.save() + authorization_code = self.get_auth(scope="openid") + + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org", + "client_id": self.application.client_id, + "scope": "openid", + } + + response = self.client.post( + reverse("oauth2_provider:token"), data=token_request_data + ) + self.assertEqual(response.status_code, 200) + + # Unload and decode the jwt and check the iss entity matches + content = json.loads(response.content.decode("utf-8")) + key = jwk.JWK.from_pem(oauth2_settings.OIDC_RSA_PRIVATE_KEY.encode("utf8")) + jwt_token = jwt.JWT(key=key, jwt=content["id_token"]) + + # Find our testserver iss entity + self.assertIn(iss_entity, jwt_token.claims) + + # Turn back on the OIDC specific endpoints + oauth2_settings.OIDC_ISS_ENDPOINT = "http://localhost" + oauth2_settings.OIDC_USERINFO_ENDPOINT = "http://localhost/userinfo/" + def test_public_pkce_S256_authorize_get(self): """ Request an access token using client_type: public