diff --git a/AUTHORS b/AUTHORS index 5345c4869..68960486a 100644 --- a/AUTHORS +++ b/AUTHORS @@ -60,4 +60,6 @@ Jadiel Teófilo pySilver Łukasz Skarżyński Shaheed Haque +Andrea Greco Vinay Karanam + diff --git a/docs/oidc.rst b/docs/oidc.rst index ba69e984f..eae9a67d4 100644 --- a/docs/oidc.rst +++ b/docs/oidc.rst @@ -245,16 +245,17 @@ required claims, eg ``iss``, ``aud``, ``exp``, ``iat``, ``auth_time`` etc), and the ``sub`` claim will use the primary key of the user as the value. You'll probably want to customize this and add additional claims or change what is sent for the ``sub`` claim. To do so, you will need to add a method to -our custom validator:: - +our custom validator. +Standard claim ``sub`` is included by default, for remove it override ``get_claim_list``:: class CustomOAuth2Validator(OAuth2Validator): - - def get_additional_claims(self, request): - return { - "sub": request.user.email, - "first_name": request.user.first_name, - "last_name": request.user.last_name, - } + def get_additional_claims(self): + def get_user_email(request): + return request.user.get_full_name() + + # Element name, callback to obtain data + claims_list = [ ("email", get_sub_cod), + ("username", get_user_email) ] + return claims_list .. note:: This ``request`` object is not a ``django.http.Request`` object, but an diff --git a/oauth2_provider/oauth2_validators.py b/oauth2_provider/oauth2_validators.py index f3a24e258..461c40d53 100644 --- a/oauth2_provider/oauth2_validators.py +++ b/oauth2_provider/oauth2_validators.py @@ -728,15 +728,24 @@ def _save_id_token(self, jti, request, expires, *args, **kwargs): def get_jwt_bearer_token(self, token, token_handler, request): return self.get_id_token(token, token_handler, request) - def get_oidc_claims(self, token, token_handler, request): - # Required OIDC claims - claims = { - "sub": str(request.user.id), - } + def get_claim_list(self): + def get_sub_code(request): + return str(request.user.id) + + list = [("sub", get_sub_code)] # https://openid.net/specs/openid-connect-core-1_0.html#StandardClaims - claims.update(**self.get_additional_claims(request)) + add = self.get_additional_claims() + list.extend(add) + + return list + def get_oidc_claims(self, token, token_handler, request): + data = self.get_claim_list() + claims = {} + + for k, call in data: + claims[k] = call(request) return claims def get_id_token_dictionary(self, token, token_handler, request): @@ -889,5 +898,5 @@ def get_userinfo_claims(self, request): """ return self.get_oidc_claims(None, None, request) - def get_additional_claims(self, request): - return {} + def get_additional_claims(self): + return [] diff --git a/oauth2_provider/views/oidc.py b/oauth2_provider/views/oidc.py index b4bb8869b..0cd24fc85 100644 --- a/oauth2_provider/views/oidc.py +++ b/oauth2_provider/views/oidc.py @@ -45,6 +45,13 @@ def get(self, request, *args, **kwargs): signing_algorithms = [Application.HS256_ALGORITHM] if oauth2_settings.OIDC_RSA_PRIVATE_KEY: signing_algorithms = [Application.RS256_ALGORITHM, Application.HS256_ALGORITHM] + + validator_class = oauth2_settings.OAUTH2_VALIDATOR_CLASS + validator = validator_class() + oidc_claims = [] + for el, _ in validator.get_claim_list(): + oidc_claims.append(el) + data = { "issuer": issuer_url, "authorization_endpoint": authorization_endpoint, @@ -57,6 +64,7 @@ def get(self, request, *args, **kwargs): "token_endpoint_auth_methods_supported": ( oauth2_settings.OIDC_TOKEN_ENDPOINT_AUTH_METHODS_SUPPORTED ), + "claims_supported": oidc_claims, } response = JsonResponse(data) response["Access-Control-Allow-Origin"] = "*" diff --git a/tests/test_oidc_views.py b/tests/test_oidc_views.py index 46040f86d..719d10e98 100644 --- a/tests/test_oidc_views.py +++ b/tests/test_oidc_views.py @@ -29,6 +29,7 @@ def test_get_connect_discovery_info(self): "subject_types_supported": ["public"], "id_token_signing_alg_values_supported": ["RS256", "HS256"], "token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"], + "claims_supported": ["sub"], } response = self.client.get(reverse("oauth2_provider:oidc-connect-discovery-info")) self.assertEqual(response.status_code, 200) @@ -55,6 +56,7 @@ def test_get_connect_discovery_info_without_issuer_url(self): "subject_types_supported": ["public"], "id_token_signing_alg_values_supported": ["RS256", "HS256"], "token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"], + "claims_supported": ["sub"], } response = self.client.get(reverse("oauth2_provider:oidc-connect-discovery-info")) self.assertEqual(response.status_code, 200) @@ -146,11 +148,21 @@ def test_userinfo_endpoint_bad_token(oidc_tokens, client): assert rsp.status_code == 401 +EXAMPLE_EMAIL = "example.email@example.com" + + +def claim_user_email(request): + return EXAMPLE_EMAIL + + @pytest.mark.django_db def test_userinfo_endpoint_custom_claims(oidc_tokens, client, oauth2_settings): class CustomValidator(OAuth2Validator): - def get_additional_claims(self, request): - return {"state": "very nice"} + def get_additional_claims(self): + return [ + ("username", claim_user_email), + ("email", claim_user_email), + ] oidc_tokens.oauth2_settings.OAUTH2_VALIDATOR_CLASS = CustomValidator auth_header = "Bearer %s" % oidc_tokens.access_token @@ -161,5 +173,9 @@ def get_additional_claims(self, request): data = rsp.json() assert "sub" in data assert data["sub"] == str(oidc_tokens.user.pk) - assert "state" in data - assert data["state"] == "very nice" + + assert "username" in data + assert data["username"] == EXAMPLE_EMAIL + + assert "email" in data + assert data["email"] == EXAMPLE_EMAIL