Skip to content

Commit 57535b4

Browse files
committed
OpenID: Transform get_additional_claims into two forms
See documentation change for rationale.
1 parent 1f85553 commit 57535b4

File tree

4 files changed

+48
-14
lines changed

4 files changed

+48
-14
lines changed

docs/oidc.rst

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -245,23 +245,41 @@ required claims, eg ``iss``, ``aud``, ``exp``, ``iat``, ``auth_time`` etc),
245245
and the ``sub`` claim will use the primary key of the user as the value.
246246
You'll probably want to customize this and add additional claims or change
247247
what is sent for the ``sub`` claim. To do so, you will need to add a method to
248-
our custom validator. It should return a dictionary mapping a claim name to
249-
either the claim data, or a callable that will be called with the request to
250-
produce the claim data.
251-
Standard claim ``sub`` is included by default, for remove it override ``get_claim_list``::
248+
our custom validator. It takes one of two forms:
249+
250+
The first form gets passed a request object, and should return a dictionary
251+
mapping a claim name to claim data::
252252
class CustomOAuth2Validator(OAuth2Validator):
253253
def get_additional_claims(self, request):
254+
claims = {}
255+
claims["email"] = request.user.get_user_email()
256+
claims["username"] = request.user.get_full_name()
257+
258+
return claims
259+
260+
The second form gets no request object, and should return a dictionary
261+
mapping a claim name to a callable, accepting a request and producing
262+
the claim data::
263+
class CustomOAuth2Validator(OAuth2Validator):
264+
def get_additional_claims(self):
254265
def get_user_email(request):
255266
return request.user.get_user_email()
256267

257268
claims = {}
258-
# Element name, callback to obtain data
259269
claims["email"] = get_user_email
260-
# Element name, plain data returned
261-
claims["username"] = request.user.get_full_name()
270+
claims["username"] = lambda r: r.user.get_full_name()
262271

263272
return claims
264273

274+
Standard claim ``sub`` is included by default, for remove it override ``get_claim_list``.
275+
276+
In order to help lcients discover claims early, they can be advertised in the discovery
277+
info, under the ``claims_supported`` key. In order for the discovery info view to automatically
278+
add all claims your validator returns, you need to use the second form (producing callables),
279+
because the discovery info views are requested with an unauthenticated request, so directly
280+
producing claim data would fail. If you use the first form, producing claim data directly,
281+
your claims will not be added to discovery info.
282+
265283
.. note::
266284
This ``request`` object is not a ``django.http.Request`` object, but an
267285
``oauthlib.common.Request`` object. This has a number of attributes that

oauth2_provider/oauth2_validators.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import base64
22
import binascii
33
import http.client
4+
import inspect
45
import json
56
import logging
67
import uuid
@@ -737,21 +738,34 @@ def _save_id_token(self, jti, request, expires, *args, **kwargs):
737738
)
738739
return id_token
739740

741+
@classmethod
742+
def _get_additional_claims_is_request_agnostic(cls):
743+
return len(inspect.signature(cls.get_additional_claims).parameters) == 1
744+
740745
def get_jwt_bearer_token(self, token, token_handler, request):
741746
return self.get_id_token(token, token_handler, request)
742747

743748
def get_claim_dict(self, request):
744-
def get_sub_code(inner_request):
745-
return str(inner_request.user.id)
746-
747-
claims = {"sub": get_sub_code}
749+
if self._get_additional_claims_is_request_agnostic():
750+
claims = {"sub": lambda r: str(r.user.id)}
751+
else:
752+
claims = {"sub": str(request.user.id)}
748753

749754
# https://openid.net/specs/openid-connect-core-1_0.html#StandardClaims
750-
add = self.get_additional_claims(request)
755+
if self._get_additional_claims_is_request_agnostic():
756+
add = self.get_additional_claims()
757+
else:
758+
add = self.get_additional_claims(request)
751759
claims.update(add)
752760

753761
return claims
754762

763+
def get_discovery_claims(self, request):
764+
claims = ["sub"]
765+
if self._get_additional_claims_is_request_agnostic():
766+
claims += list(self.get_claim_dict(request).keys())
767+
return claims
768+
755769
def get_oidc_claims(self, token, token_handler, request):
756770
data = self.get_claim_dict(request)
757771
claims = {}

oauth2_provider/views/oidc.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ def get(self, request, *args, **kwargs):
4848

4949
validator_class = oauth2_settings.OAUTH2_VALIDATOR_CLASS
5050
validator = validator_class()
51-
oidc_claims = list(validator.get_claim_dict(request).keys())
51+
oidc_claims = validator.get_discovery_claims(request)
52+
if "sub" not in oidc_claims:
53+
oidc_claims.append("sub")
5254

5355
data = {
5456
"issuer": issuer_url,

tests/test_oidc_views.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def claim_user_email(request):
155155
@pytest.mark.django_db
156156
def test_userinfo_endpoint_custom_claims_callable(oidc_tokens, client, oauth2_settings):
157157
class CustomValidator(OAuth2Validator):
158-
def get_additional_claims(self, request):
158+
def get_additional_claims(self):
159159
return {
160160
"username": claim_user_email,
161161
"email": claim_user_email,

0 commit comments

Comments
 (0)