Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions rest_framework/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,47 +288,47 @@ def request(self, **kwargs):
def get(self, path, data=None, follow=False, **extra):
response = super().get(path, data=data, **extra)
if follow:
response = self._handle_redirects(response, **extra)
response = self._handle_redirects(response, data=data, **extra)
return response

def post(self, path, data=None, format=None, content_type=None,
follow=False, **extra):
response = super().post(
path, data=data, format=format, content_type=content_type, **extra)
if follow:
response = self._handle_redirects(response, **extra)
response = self._handle_redirects(response, data=data, format=format, content_type=content_type, **extra)
return response

def put(self, path, data=None, format=None, content_type=None,
follow=False, **extra):
response = super().put(
path, data=data, format=format, content_type=content_type, **extra)
if follow:
response = self._handle_redirects(response, **extra)
response = self._handle_redirects(response, data=data, format=format, content_type=content_type, **extra)
return response

def patch(self, path, data=None, format=None, content_type=None,
follow=False, **extra):
response = super().patch(
path, data=data, format=format, content_type=content_type, **extra)
if follow:
response = self._handle_redirects(response, **extra)
response = self._handle_redirects(response, data=data, format=format, content_type=content_type, **extra)
return response

def delete(self, path, data=None, format=None, content_type=None,
follow=False, **extra):
response = super().delete(
path, data=data, format=format, content_type=content_type, **extra)
if follow:
response = self._handle_redirects(response, **extra)
response = self._handle_redirects(response, data=data, format=format, content_type=content_type, **extra)
return response

def options(self, path, data=None, format=None, content_type=None,
follow=False, **extra):
response = super().options(
path, data=data, format=format, content_type=content_type, **extra)
if follow:
response = self._handle_redirects(response, **extra)
response = self._handle_redirects(response, data=data, format=format, content_type=content_type, **extra)
return response

def logout(self):
Expand Down
72 changes: 36 additions & 36 deletions tests/test_testing.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import itertools
from io import BytesIO
from unittest.mock import patch

import django
from django.contrib.auth.models import User
from django.http import HttpResponseRedirect
from django.shortcuts import redirect
from django.test import TestCase, override_settings
from django.urls import path
Expand All @@ -14,7 +17,7 @@
)


@api_view(['GET', 'POST'])
@api_view(['GET', 'POST', 'PUT', 'PATCH', 'DELETE', 'OPTIONS'])
def view(request):
return Response({
'auth': request.META.get('HTTP_AUTHORIZATION', b''),
Expand All @@ -36,6 +39,11 @@ def redirect_view(request):
return redirect('/view/')


@api_view(['GET', 'POST', 'PUT', 'PATCH', 'DELETE', 'OPTIONS'])
def redirect_307_308_view(request, code):
return HttpResponseRedirect('/view/', status=code)


class BasicSerializer(serializers.Serializer):
flag = fields.BooleanField(default=lambda: True)

Expand All @@ -51,6 +59,7 @@ def post_view(request):
path('view/', view),
path('session-view/', session_view),
path('redirect-view/', redirect_view),
path('redirect-view/<int:code>/', redirect_307_308_view),
path('post-view/', post_view)
]

Expand Down Expand Up @@ -146,41 +155,32 @@ def test_follow_redirect(self):
"""
Follow redirect by setting follow argument.
"""
response = self.client.get('/redirect-view/')
assert response.status_code == 302
response = self.client.get('/redirect-view/', follow=True)
assert response.redirect_chain is not None
assert response.status_code == 200

response = self.client.post('/redirect-view/')
assert response.status_code == 302
response = self.client.post('/redirect-view/', follow=True)
assert response.redirect_chain is not None
assert response.status_code == 200

response = self.client.put('/redirect-view/')
assert response.status_code == 302
response = self.client.put('/redirect-view/', follow=True)
assert response.redirect_chain is not None
assert response.status_code == 200

response = self.client.patch('/redirect-view/')
assert response.status_code == 302
response = self.client.patch('/redirect-view/', follow=True)
assert response.redirect_chain is not None
assert response.status_code == 200

response = self.client.delete('/redirect-view/')
assert response.status_code == 302
response = self.client.delete('/redirect-view/', follow=True)
assert response.redirect_chain is not None
assert response.status_code == 200

response = self.client.options('/redirect-view/')
assert response.status_code == 302
response = self.client.options('/redirect-view/', follow=True)
assert response.redirect_chain is not None
assert response.status_code == 200
for method in ('get', 'post', 'put', 'patch', 'delete', 'options'):
with self.subTest(method=method):
req_method = getattr(self.client, method)
response = req_method('/redirect-view/')
assert response.status_code == 302
response = req_method('/redirect-view/', follow=True)
assert response.redirect_chain is not None
assert response.status_code == 200

def test_follow_307_308_preserve_kwargs(self, *mocked_methods):
"""
Follow redirect by setting follow argument, and make sure the following
method called with appropriate kwargs.
"""
methods = ('get', 'post', 'put', 'patch', 'delete', 'options')
codes = (307, 308)
for method, code in itertools.product(methods, codes):
subtest_ctx = self.subTest(method=method, code=code)
patch_ctx = patch.object(self.client, method, side_effect=getattr(self.client, method))
with subtest_ctx, patch_ctx as req_method:
kwargs = {'data': {'example': 'test'}, 'format': 'json'}
response = req_method('/redirect-view/%s/' % code, follow=True, **kwargs)
assert response.redirect_chain is not None
assert response.status_code == 200
for _, call_args, call_kwargs in req_method.mock_calls:
assert all(call_kwargs[k] == kwargs[k] for k in kwargs if k in call_kwargs)

def test_invalid_multipart_data(self):
"""
Expand Down