Skip to content

Commit c49bb59

Browse files
committed
Allow SchemaView to handle both CoreAPI & OpenAPI.
1 parent bb0db35 commit c49bb59

File tree

6 files changed

+52
-17
lines changed

6 files changed

+52
-17
lines changed

rest_framework/management/commands/generateschema.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from rest_framework import renderers
44
from rest_framework.schemas import coreapi
55
from rest_framework.schemas.openapi import SchemaGenerator
6-
from rest_framework.settings import api_settings
76

87
OPENAPI_MODE = 'openapi'
98
COREAPI_MODE = 'coreapi'
@@ -13,10 +12,7 @@ class Command(BaseCommand):
1312
help = "Generates configured API schema for project."
1413

1514
def get_mode(self):
16-
default_schema_class = api_settings.DEFAULT_SCHEMA_CLASS
17-
if issubclass(default_schema_class, coreapi.AutoSchema):
18-
return COREAPI_MODE
19-
return OPENAPI_MODE
15+
return COREAPI_MODE if coreapi.is_enabled() else OPENAPI_MODE
2016

2117
def add_arguments(self, parser):
2218
parser.add_argument('--title', dest="title", default='', type=str)

rest_framework/schemas/__init__.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,24 +22,32 @@
2222
"""
2323
from rest_framework.settings import api_settings
2424

25+
from . import coreapi, openapi
2526
from .inspectors import DefaultSchema # noqa
2627
from .coreapi import AutoSchema, ManualSchema, SchemaGenerator # noqa
2728

2829

2930
def get_schema_view(
3031
title=None, url=None, description=None, urlconf=None, renderer_classes=None,
31-
public=False, patterns=None, generator_class=SchemaGenerator,
32+
public=False, patterns=None, generator_class=None,
3233
authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES,
3334
permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES):
3435
"""
3536
Return a schema view.
3637
"""
37-
# Avoid import cycle on APIView
38-
from .views import SchemaView
38+
if generator_class is None:
39+
if coreapi.is_enabled():
40+
generator_class = coreapi.SchemaGenerator
41+
else:
42+
generator_class = openapi.SchemaGenerator
43+
3944
generator = generator_class(
4045
title=title, url=url, description=description,
4146
urlconf=urlconf, patterns=patterns,
4247
)
48+
49+
# Avoid import cycle on APIView
50+
from .views import SchemaView
4351
return SchemaView.as_view(
4452
renderer_classes=renderer_classes,
4553
schema_generator=generator,

rest_framework/schemas/coreapi.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -609,3 +609,8 @@ def get_link(self, path, method, base_url):
609609
fields=self._fields,
610610
description=self._description
611611
)
612+
613+
614+
def is_enabled():
615+
"""Is CoreAPI Mode enabled?"""
616+
return issubclass(api_settings.DEFAULT_SCHEMA_CLASS, AutoSchema)

rest_framework/schemas/views.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"""
66
from rest_framework import exceptions, renderers
77
from rest_framework.response import Response
8+
from rest_framework.schemas import coreapi
89
from rest_framework.settings import api_settings
910
from rest_framework.views import APIView
1011

@@ -19,10 +20,16 @@ class SchemaView(APIView):
1920
def __init__(self, *args, **kwargs):
2021
super(SchemaView, self).__init__(*args, **kwargs)
2122
if self.renderer_classes is None:
22-
self.renderer_classes = [
23-
renderers.CoreAPIOpenAPIRenderer,
24-
renderers.CoreJSONRenderer
25-
]
23+
if coreapi.is_enabled():
24+
self.renderer_classes = [
25+
renderers.CoreAPIOpenAPIRenderer,
26+
renderers.CoreJSONRenderer
27+
]
28+
else:
29+
self.renderer_classes = [
30+
renderers.OpenAPIRenderer,
31+
renderers.JSONOpenAPIRenderer,
32+
]
2633
if renderers.BrowsableAPIRenderer in api_settings.DEFAULT_RENDERER_CLASSES:
2734
self.renderer_classes += [renderers.BrowsableAPIRenderer]
2835

tests/schemas/test_coreapi.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -134,11 +134,12 @@ def put_documented_custom_action(self, request, *args, **kwargs):
134134
pass
135135

136136

137-
if coreapi:
138-
schema_view = get_schema_view(title='Example API')
139-
else:
140-
def schema_view(request):
141-
pass
137+
with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}):
138+
if coreapi:
139+
schema_view = get_schema_view(title='Example API')
140+
else:
141+
def schema_view(request):
142+
pass
142143

143144
router = DefaultRouter()
144145
router.register('example', ExampleViewSet, basename='example')
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from django.test import TestCase, override_settings
2+
3+
from rest_framework import renderers
4+
from rest_framework.schemas import coreapi, get_schema_view, openapi
5+
6+
7+
class GetSchemaViewTests(TestCase):
8+
"""For the get_schema_view() helper."""
9+
def test_openapi(self):
10+
schema_view = get_schema_view(title="With OpenAPI")
11+
assert isinstance(schema_view.initkwargs['schema_generator'], openapi.SchemaGenerator)
12+
assert renderers.OpenAPIRenderer in schema_view.cls().renderer_classes
13+
14+
def test_coreapi(self):
15+
with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}):
16+
schema_view = get_schema_view(title="With CoreAPI")
17+
assert isinstance(schema_view.initkwargs['schema_generator'], coreapi.SchemaGenerator)
18+
assert renderers.CoreAPIOpenAPIRenderer in schema_view.cls().renderer_classes

0 commit comments

Comments
 (0)