Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ Stéphane Raimbault
Emanuele Palazzetti
David Fischer
Ash Christopher
Rodney Richardson
2 changes: 1 addition & 1 deletion oauth2_provider/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = '0.7.2'
__version__ = '0.7.3'

__author__ = "Massimiliano Pippi & Federico Frenguelli"

Expand Down
9 changes: 9 additions & 0 deletions oauth2_provider/http.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from django.http import HttpResponseRedirect

from .settings import oauth2_settings


class HttpResponseUriRedirect(HttpResponseRedirect):
def __init__(self, redirect_to, *args, **kwargs):
self.allowed_schemes = oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES
super(HttpResponseUriRedirect, self).__init__(redirect_to, *args, **kwargs)
2 changes: 2 additions & 0 deletions oauth2_provider/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
'ACCESS_TOKEN_EXPIRE_SECONDS': 36000,
'APPLICATION_MODEL': getattr(settings, 'OAUTH2_PROVIDER_APPLICATION_MODEL', 'oauth2_provider.Application'),
'REQUEST_APPROVAL_PROMPT': 'force',
'ALLOWED_REDIRECT_URI_SCHEMES': ['http', 'https'],

# Special settings that will be evaluated at runtime
'_SCOPES': [],
Expand All @@ -52,6 +53,7 @@
'CLIENT_SECRET_GENERATOR_CLASS',
'OAUTH2_VALIDATOR_CLASS',
'SCOPES',
'ALLOWED_REDIRECT_URI_SCHEMES',
)

# List of settings that may be in string import notation.
Expand Down
75 changes: 74 additions & 1 deletion oauth2_provider/tests/test_authorization_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,11 @@ def setUp(self):
self.test_user = UserModel.objects.create_user("test_user", "[email protected]", "123456")
self.dev_user = UserModel.objects.create_user("dev_user", "[email protected]", "123456")

oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES = ['http', 'custom-scheme']

self.application = Application(
name="Test Application",
redirect_uris="http://localhost http://example.com http://example.it",
redirect_uris="http://localhost http://example.com http://example.it custom-scheme://example.com",
user=self.dev_user,
client_type=Application.CLIENT_CONFIDENTIAL,
authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE,
Expand Down Expand Up @@ -92,6 +94,34 @@ def test_pre_auth_valid_client(self):
self.assertEqual(form['scope'].value(), "read write")
self.assertEqual(form['client_id'].value(), self.application.client_id)

def test_pre_auth_valid_client_custom_redirect_uri_scheme(self):
"""
Test response for a valid client_id with response_type: code
using a non-standard, but allowed, redirect_uri scheme.
"""
self.client.login(username="test_user", password="123456")

query_string = urlencode({
'client_id': self.application.client_id,
'response_type': 'code',
'state': 'random_state_string',
'scope': 'read write',
'redirect_uri': 'custom-scheme://example.com',
})
url = "{url}?{qs}".format(url=reverse('oauth2_provider:authorize'), qs=query_string)

response = self.client.get(url)
self.assertEqual(response.status_code, 200)

# check form is in context and form params are valid
self.assertIn("form", response.context)

form = response.context["form"]
self.assertEqual(form['redirect_uri'].value(), "custom-scheme://example.com")
self.assertEqual(form['state'].value(), "random_state_string")
self.assertEqual(form['scope'].value(), "read write")
self.assertEqual(form['client_id'].value(), self.application.client_id)

def test_pre_auth_approval_prompt(self):
"""

Expand Down Expand Up @@ -307,6 +337,49 @@ def test_code_post_auth_malicious_redirect_uri(self):
response = self.client.post(reverse('oauth2_provider:authorize'), data=form_data)
self.assertEqual(response.status_code, 400)

def test_code_post_auth_allow_custom_redirect_uri_scheme(self):
"""
Test authorization code is given for an allowed request with response_type: code
using a non-standard, but allowed, redirect_uri scheme.
"""
self.client.login(username="test_user", password="123456")

form_data = {
'client_id': self.application.client_id,
'state': 'random_state_string',
'scope': 'read write',
'redirect_uri': 'custom-scheme://example.com',
'response_type': 'code',
'allow': True,
}

response = self.client.post(reverse('oauth2_provider:authorize'), data=form_data)
self.assertEqual(response.status_code, 302)
self.assertIn('custom-scheme://example.com?', response['Location'])
self.assertIn('state=random_state_string', response['Location'])
self.assertIn('code=', response['Location'])

def test_code_post_auth_deny_custom_redirect_uri_scheme(self):
"""
Test error when resource owner deny access
using a non-standard, but allowed, redirect_uri scheme.
"""
self.client.login(username="test_user", password="123456")

form_data = {
'client_id': self.application.client_id,
'state': 'random_state_string',
'scope': 'read write',
'redirect_uri': 'custom-scheme://example.com',
'response_type': 'code',
'allow': False,
}

response = self.client.post(reverse('oauth2_provider:authorize'), data=form_data)
self.assertEqual(response.status_code, 302)
self.assertIn('custom-scheme://example.com?', response['Location'])
self.assertIn("error=access_denied", response['Location'])


class TestAuthorizationCodeTokenView(BaseTest):
def get_auth(self):
Expand Down
30 changes: 24 additions & 6 deletions oauth2_provider/tests/test_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,35 @@
from django.test import TestCase
from django.core.validators import ValidationError

from ..settings import oauth2_settings
from ..validators import validate_uris


class TestValidators(TestCase):
def test_validate_good_uris(self):
good_urls = 'http://example.com/ http://example.it/?key=val http://example'
good_uris = 'http://example.com/ http://example.it/?key=val http://example'
# Check ValidationError not thrown
validate_uris(good_urls)
validate_uris(good_uris)

def test_validate_custom_uri_scheme(self):
oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES = ['my-scheme', 'http']
good_uris = 'my-scheme://example.com http://example.com'
# Check ValidationError not thrown
validate_uris(good_uris)

def test_validate_whitespace_separators(self):
# Check that whitespace can be used as a separator
good_uris = 'http://example\r\nhttp://example\thttp://example'
# Check ValidationError not thrown
validate_uris(good_uris)

def test_validate_bad_uris(self):
bad_url = 'http://example.com/#fragment'
self.assertRaises(ValidationError, validate_uris, bad_url)
bad_url = 'http:/example.com'
self.assertRaises(ValidationError, validate_uris, bad_url)
bad_uri = 'http://example.com/#fragment'
self.assertRaises(ValidationError, validate_uris, bad_uri)
bad_uri = 'http:/example.com'
self.assertRaises(ValidationError, validate_uris, bad_uri)
bad_uri = 'my-scheme://example.com'
self.assertRaises(ValidationError, validate_uris, bad_uri)
bad_uri = 'sdklfsjlfjljdflksjlkfjsdkl'
self.assertRaises(ValidationError, validate_uris, bad_uri)

14 changes: 11 additions & 3 deletions oauth2_provider/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
from django.utils.six.moves.urllib.parse import urlsplit, urlunsplit
from django.core.validators import RegexValidator

from .settings import oauth2_settings

class URIValidator(RegexValidator):
regex = re.compile(
r'^(?:[a-z0-9\.\-]*)s?://' # http:// or https://
r'^(?:[a-z][a-z0-9\.\-\+]*)://' # scheme...
r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|' # domain...
r'(?!-)[A-Z\d-]{1,63}(?<!-)|' # also cover non-dotted domain
r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}|' # ...or ipv4

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not intimately familiar with this function, but is there a reason for using a custom regex instead of a custom validator based on urlparse.urlparse?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's not my code - that's already in the master (I just fixed the scheme to match the RFC standard).

The call seems to handle the case where the network location has been internationalized (although there are no tests for that), so that may have something to do with it.

We DO need a RedirectURIValidator (which inherits from this), because the standard says that Redirect URIs cannot contain fragments.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at urlparse documentation, I can't see why this (or the safe equivalent for Python 2 and 3 compatibility) couldn't be used instead of a RegexValidator. Perhaps you might add this as a separate issue?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, urlparse is what I used in my implementation (https://github.com/Locu/djoauth2/blob/9df7c3661e0a4c3585d3a333a63d5ed74472083c/djoauth2/authorization.py#L185):

if urlparse(redirect_uri).fragment:
  raise InvalidRequest('"redirect_uri" must not contain a fragment')

I'll take a look at this later today / tomorrow when I have the time.

Expand Down Expand Up @@ -41,16 +42,23 @@ def __call__(self, value):


class RedirectURIValidator(URIValidator):
def __init__(self, allowed_schemes):
self.allowed_schemes = allowed_schemes

def __call__(self, value):
super(RedirectURIValidator, self).__call__(value)
value = force_text(value)
if len(value.split('#')) > 1:
raise ValidationError('Redirect URIs must not contain fragments')
scheme, netloc, path, query, fragment = urlsplit(value)
if scheme.lower() not in self.allowed_schemes:
raise ValidationError('Redirect URI scheme is not allowed.')


def validate_uris(value):
"""
This validator ensures that `value` contains valid blank-separated urls"
This validator ensures that `value` contains valid blank-separated URIs"
"""
v = RedirectURIValidator()
v = RedirectURIValidator(oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES)
for uri in value.split():
v(uri)
9 changes: 5 additions & 4 deletions oauth2_provider/views/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging

from django.http import HttpResponse, HttpResponseRedirect
from django.http import HttpResponse
from django.views.generic import View, FormView
from django.utils import timezone

Expand All @@ -11,6 +11,7 @@
from ..settings import oauth2_settings
from ..exceptions import OAuthToolkitError
from ..forms import AllowForm
from ..http import HttpResponseUriRedirect
from ..models import get_application_model
from .mixins import OAuthLibMixin

Expand Down Expand Up @@ -41,7 +42,7 @@ def error_response(self, error, **kwargs):
redirect, error_response = super(BaseAuthorizationView, self).error_response(error, **kwargs)

if redirect:
return HttpResponseRedirect(error_response['url'])
return HttpResponseUriRedirect(error_response['url'])

status = error_response['error'].status_code
return self.render_to_response(error_response, status=status)
Expand Down Expand Up @@ -100,7 +101,7 @@ def form_valid(self, form):
request=self.request, scopes=scopes, credentials=credentials, allow=allow)
self.success_url = uri
log.debug("Success url for the request: {0}".format(self.success_url))
return super(AuthorizationView, self).form_valid(form)
return HttpResponseUriRedirect(self.success_url)

except OAuthToolkitError as error:
return self.error_response(error)
Expand Down Expand Up @@ -130,7 +131,7 @@ def get(self, request, *args, **kwargs):
uri, headers, body, status = self.create_authorization_response(
request=self.request, scopes=" ".join(scopes),
credentials=credentials, allow=True)
return HttpResponseRedirect(uri)
return HttpResponseUriRedirect(uri)

return self.render_to_response(self.get_context_data(**kwargs))

Expand Down