diff --git a/oauth2_provider/migrations/0005_auto_20161221_1906.py b/oauth2_provider/migrations/0005_auto_20161221_1906.py new file mode 100644 index 000000000..e594228c8 --- /dev/null +++ b/oauth2_provider/migrations/0005_auto_20161221_1906.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.10.4 on 2016-12-21 19:06 +from __future__ import unicode_literals + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion +from oauth2_provider.settings import oauth2_settings + + +class Migration(migrations.Migration): + + dependencies = [ + ('oauth2_provider', '0004_auto_20160525_1623'), + migrations.swappable_dependency(oauth2_settings.RESOURCE_OWNER_MODEL), + ] + + operations = [ + migrations.AlterField( + model_name='accesstoken', + name='user', + field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to=oauth2_settings.RESOURCE_OWNER_MODEL), + ), + migrations.AlterField( + model_name='grant', + name='user', + field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=oauth2_settings.RESOURCE_OWNER_MODEL), + ), + migrations.AlterField( + model_name='refreshtoken', + name='user', + field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=oauth2_settings.RESOURCE_OWNER_MODEL), + ), + ] diff --git a/oauth2_provider/models.py b/oauth2_provider/models.py index 6a0103cdd..eb554ec9a 100644 --- a/oauth2_provider/models.py +++ b/oauth2_provider/models.py @@ -142,7 +142,6 @@ class Application(AbstractApplication): class Meta(AbstractApplication.Meta): swappable = 'OAUTH2_PROVIDER_APPLICATION_MODEL' - @python_2_unicode_compatible class Grant(models.Model): """ @@ -159,7 +158,7 @@ class Grant(models.Model): * :attr:`redirect_uri` Self explained * :attr:`scope` Required scopes, optional """ - user = models.ForeignKey(settings.AUTH_USER_MODEL, on_delete=models.CASCADE) + user = models.ForeignKey(oauth2_settings.RESOURCE_OWNER_MODEL, on_delete=models.CASCADE) code = models.CharField(max_length=255, unique=True) # code comes from oauthlib application = models.ForeignKey(oauth2_settings.APPLICATION_MODEL, on_delete=models.CASCADE) @@ -197,7 +196,7 @@ class AccessToken(models.Model): * :attr:`expires` Date and time of token expiration, in DateTime format * :attr:`scope` Allowed scopes """ - user = models.ForeignKey(settings.AUTH_USER_MODEL, blank=True, null=True, + user = models.ForeignKey(oauth2_settings.RESOURCE_OWNER_MODEL, blank=True, null=True, on_delete=models.CASCADE) token = models.CharField(max_length=255, unique=True) application = models.ForeignKey(oauth2_settings.APPLICATION_MODEL, @@ -270,7 +269,7 @@ class RefreshToken(models.Model): * :attr:`access_token` AccessToken instance this refresh token is bounded to """ - user = models.ForeignKey(settings.AUTH_USER_MODEL, on_delete=models.CASCADE) + user = models.ForeignKey(oauth2_settings.RESOURCE_OWNER_MODEL, on_delete=models.CASCADE) token = models.CharField(max_length=255, unique=True) application = models.ForeignKey(oauth2_settings.APPLICATION_MODEL, on_delete=models.CASCADE) @@ -289,6 +288,20 @@ def __str__(self): return self.token +def get_resource_owner_model(): + """ Return the Resource Owner model that is active in this project. """ + try: + app_label, model_name = oauth2_settings.RESOURCE_OWNER_MODEL.split('.') + except ValueError: + e = "RESOURCE_OWNER_MODEL must be of the form 'app_label.model_name'" + raise ImproperlyConfigured(e) + app_model = apps.get_model(app_label, model_name) + if app_model is None: + e = "RESOURCE_OWNER_MODEL refers to model {0} that has not been installed" + raise ImproperlyConfigured(e.format(oauth2_settings.RESOURCE_OWNER_MODEL)) + return app_model + + def get_application_model(): """ Return the Application model that is active in this project. """ try: diff --git a/oauth2_provider/oauth2_backends.py b/oauth2_provider/oauth2_backends.py index 3a829b550..921e0490b 100644 --- a/oauth2_provider/oauth2_backends.py +++ b/oauth2_provider/oauth2_backends.py @@ -32,6 +32,15 @@ def _get_escaped_full_path(self, request): return urlunparse(parsed) + def _extract_resource_owner(self, request): + """ + Extracts the resource owner object from the Django request object + :param request: The current django.http.HttpRequest object + :return: the Resource Owner object + """ + return request.user + + def _get_extra_credentials(self, request): """ Produce extra credentials for token response. This dictionary will be @@ -106,14 +115,14 @@ def create_authorization_response(self, request, scopes, credentials, allow): :param scopes: A list of provided scopes :param credentials: Authorization credentials dictionary containing `client_id`, `state`, `redirect_uri`, `response_type` - :param allow: True if the user authorize the client, otherwise False + :param allow: True if the resource owner authorize the client, otherwise False """ try: if not allow: raise oauth2.AccessDeniedError() - # add current user to credentials. this will be used by OAUTH2_VALIDATOR_CLASS - credentials['user'] = request.user + # add current resource owner to credentials. this will be used by OAUTH2_VALIDATOR_CLASS + credentials['user'] = self._extract_resource_owner(request) headers, body, status = self.server.create_authorization_response( uri=credentials['redirect_uri'], scopes=scopes, credentials=credentials) diff --git a/oauth2_provider/oauth2_validators.py b/oauth2_provider/oauth2_validators.py index 5ed033a20..e4c303a52 100644 --- a/oauth2_provider/oauth2_validators.py +++ b/oauth2_provider/oauth2_validators.py @@ -31,6 +31,13 @@ class OAuth2Validator(RequestValidator): + + def _extract_token_user(self, token): + return token.user + + def _extract_resource_owner(self, request): + return request.user + def _extract_basic_auth(self, request): """ Return authentication string if request contains basic auth credentials, @@ -238,7 +245,7 @@ def validate_bearer_token(self, token, scopes, request): token=token) if access_token.is_valid(scopes): request.client = access_token.application - request.user = access_token.user + request.user = self._extract_token_user(access_token) request.scopes = scopes # this is needed by django rest framework @@ -253,7 +260,7 @@ def validate_code(self, client_id, code, client, request, *args, **kwargs): grant = Grant.objects.get(code=code, application=client) if not grant.is_expired(): request.scopes = grant.scope.split(' ') - request.user = grant.user + request.user = self._extract_token_user(grant) return True return False @@ -296,7 +303,7 @@ def validate_redirect_uri(self, client_id, redirect_uri, request, *args, **kwarg def save_authorization_code(self, client_id, code, request, *args, **kwargs): expires = timezone.now() + timedelta( seconds=oauth2_settings.AUTHORIZATION_CODE_EXPIRE_SECONDS) - g = Grant(application=request.client, user=request.user, code=code['code'], + g = Grant(application=request.client, user=self._extract_resource_owner(request), code=code['code'], expires=expires, redirect_uri=request.redirect_uri, scope=' '.join(request.scopes)) g.save() @@ -344,7 +351,7 @@ def save_bearer_token(self, token, request, *args, **kwargs): access_token = AccessToken.objects.select_for_update().get( pk=refresh_token_instance.access_token.pk ) - access_token.user = request.user + access_token.user = self._extract_resource_owner(request) access_token.scope = token['scope'] access_token.expires = expires access_token.token = token['access_token'] @@ -365,7 +372,7 @@ def save_bearer_token(self, token, request, *args, **kwargs): access_token = self._create_access_token(expires, request, token) refresh_token = RefreshToken( - user=request.user, + user=self._extract_resource_owner(request), token=refresh_token_code, application=request.client, access_token=access_token @@ -381,7 +388,7 @@ def save_bearer_token(self, token, request, *args, **kwargs): def _create_access_token(self, expires, request, token): access_token = AccessToken( - user=request.user, + user=self._extract_resource_owner(request), scope=token['scope'], expires=expires, token=token['access_token'], @@ -437,7 +444,7 @@ def validate_refresh_token(self, refresh_token, client, request, *args, **kwargs """ try: rt = RefreshToken.objects.get(token=refresh_token) - request.user = rt.user + request.user = self._extract_token_user(rt) request.refresh_token = rt.token # Temporary store RefreshToken instance to be reused by get_original_scopes. request.refresh_token_instance = rt diff --git a/oauth2_provider/settings.py b/oauth2_provider/settings.py index bab3626c8..ad08eeb97 100644 --- a/oauth2_provider/settings.py +++ b/oauth2_provider/settings.py @@ -43,6 +43,7 @@ 'REFRESH_TOKEN_EXPIRE_SECONDS': None, 'ROTATE_REFRESH_TOKEN': True, 'APPLICATION_MODEL': getattr(settings, 'OAUTH2_PROVIDER_APPLICATION_MODEL', 'oauth2_provider.Application'), + 'RESOURCE_OWNER_MODEL': getattr(settings, 'OAUTH2_PROVIDER_RESOURCE_OWNER_MODEL', settings.AUTH_USER_MODEL), 'REQUEST_APPROVAL_PROMPT': 'force', 'ALLOWED_REDIRECT_URI_SCHEMES': ['http', 'https'], diff --git a/oauth2_provider/tests/models.py b/oauth2_provider/tests/models.py index 27b01d66f..492b39546 100644 --- a/oauth2_provider/tests/models.py +++ b/oauth2_provider/tests/models.py @@ -1,6 +1,10 @@ +from django.conf import settings from django.db import models from oauth2_provider.models import AbstractApplication class TestApplication(AbstractApplication): custom_field = models.CharField(max_length=255) + +class TestResourceOwner(models.Model): + user = models.ForeignKey(settings.AUTH_USER_MODEL, related_name="resource_owners") diff --git a/oauth2_provider/tests/settings.py b/oauth2_provider/tests/settings.py index a9aa0b4e1..c95fc9312 100644 --- a/oauth2_provider/tests/settings.py +++ b/oauth2_provider/tests/settings.py @@ -126,3 +126,5 @@ }, } } + +OAUTH2_PROVIDER_RESOURCE_OWNER_MODEL = 'auth.User' diff --git a/oauth2_provider/tests/test_models.py b/oauth2_provider/tests/test_models.py index 022beefa7..22669ba1e 100644 --- a/oauth2_provider/tests/test_models.py +++ b/oauth2_provider/tests/test_models.py @@ -7,16 +7,20 @@ from django.test.utils import override_settings from django.utils import timezone -from ..models import get_application_model, Grant, AccessToken, RefreshToken +from ..models import get_application_model, get_resource_owner_model +from ..models import Grant, AccessToken, RefreshToken Application = get_application_model() +ResourceOwnerModel = get_resource_owner_model() UserModel = get_user_model() + class TestModels(TestCase): def setUp(self): self.user = UserModel.objects.create_user("test_user", "test@user.com", "123456") + self.resource_owner = self.user def test_allow_scopes(self): self.client.login(username="test_user", password="123456") @@ -29,7 +33,7 @@ def test_allow_scopes(self): ) access_token = AccessToken( - user=self.user, + user=self.resource_owner, scope='read write', expires=0, token='', @@ -88,7 +92,7 @@ def test_scopes_property(self): ) access_token = AccessToken( - user=self.user, + user=self.resource_owner, scope='read write', expires=0, token='', @@ -96,7 +100,7 @@ def test_scopes_property(self): ) access_token2 = AccessToken( - user=self.user, + user=self.resource_owner, scope='write', expires=0, token='', @@ -150,6 +154,7 @@ def test_expires_can_be_none(self): class TestAccessTokenModel(TestCase): def setUp(self): self.user = UserModel.objects.create_user("test_user", "test@user.com", "123456") + self.resource_owner = self.user def test_str(self): access_token = AccessToken(token="test_token") @@ -177,3 +182,12 @@ class TestRefreshTokenModel(TestCase): def test_str(self): refresh_token = RefreshToken(token="test_token") self.assertEqual("%s" % refresh_token, refresh_token.token) + + +@override_settings(OAUTH2_PROVIDER={'RESOURCE_OWNER_MODEL':'tests.TestResourceOwner'}) +class TestCustomResourceOwnerModel(TestCase): + def setUp(self): + self.user = UserModel.objects.create_user("test_user", "test@user.com", "123456") + + def test_model(self): + self.user.resource_owners.all() diff --git a/oauth2_provider/tests/test_oauth2_backends.py b/oauth2_provider/tests/test_oauth2_backends.py index d0b8a766b..5997d667b 100644 --- a/oauth2_provider/tests/test_oauth2_backends.py +++ b/oauth2_provider/tests/test_oauth2_backends.py @@ -50,6 +50,9 @@ class MyOAuthLibCore(OAuthLibCore): def _get_extra_credentials(self, request): return 1 + def extract_resource_owner(self, request): + return request.organization_user + def setUp(self): self.factory = RequestFactory() diff --git a/oauth2_provider/views/token.py b/oauth2_provider/views/token.py index ef8b9799f..9a248cf3a 100644 --- a/oauth2_provider/views/token.py +++ b/oauth2_provider/views/token.py @@ -14,13 +14,15 @@ class AuthorizedTokensListView(LoginRequiredMixin, ListView): context_object_name = 'authorized_tokens' template_name = 'oauth2_provider/authorized-tokens.html' model = AccessToken + user_lookup_attr = 'user' def get_queryset(self): """ Show only user's tokens """ return super(AuthorizedTokensListView, self).get_queryset()\ - .select_related('application').filter(user=self.request.user) + .select_related('application')\ + .filter(**{self.user_lookup_attr:self.request.user}) class AuthorizedTokenDeleteView(LoginRequiredMixin, DeleteView): @@ -30,6 +32,8 @@ class AuthorizedTokenDeleteView(LoginRequiredMixin, DeleteView): template_name = 'oauth2_provider/authorized-token-delete.html' success_url = reverse_lazy('oauth2_provider:authorized-token-list') model = AccessToken + user_lookup_attr = 'user' def get_queryset(self): - return super(AuthorizedTokenDeleteView, self).get_queryset().filter(user=self.request.user) + return super(AuthorizedTokenDeleteView, self).get_queryset()\ + .filter(**{self.user_lookup_attr:self.request.user})