Skip to content

Commit d99550e

Browse files
committed
Add Open API get_schema().
1 parent 740424e commit d99550e

File tree

2 files changed

+38
-7
lines changed

2 files changed

+38
-7
lines changed

rest_framework/schemas/generators.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -283,9 +283,7 @@ def get_schema(self, request=None, public=False):
283283
"""
284284
Generate a `coreapi.Document` representing the API schema.
285285
"""
286-
if self.endpoints is None:
287-
inspector = self.endpoint_inspector_cls(self.patterns, self.urlconf)
288-
self.endpoints = inspector.get_api_endpoints()
286+
self._initialise_endpoints()
289287

290288
links = self.get_links(None if public else request)
291289
if not links:
@@ -301,6 +299,11 @@ def get_schema(self, request=None, public=False):
301299
url=url, content=links
302300
)
303301

302+
def _initialise_endpoints(self):
303+
if self.endpoints is None:
304+
inspector = self.endpoint_inspector_cls(self.patterns, self.urlconf)
305+
self.endpoints = inspector.get_api_endpoints()
306+
304307
def get_links(self, request=None):
305308
"""
306309
Return a dictionary containing all the links that should be
@@ -491,3 +494,20 @@ def get_paths(self, request=None):
491494
result[subpath][method.lower()] = operation
492495

493496
return result
497+
498+
def get_schema(self, request=None, public=False):
499+
"""
500+
Generate a `coreapi.Document` representing the API schema.
501+
"""
502+
self._initialise_endpoints()
503+
504+
paths = self.get_paths(None if public else request)
505+
if not paths:
506+
return None
507+
508+
schema = {
509+
'basePath': self.url,
510+
'paths': paths,
511+
}
512+
513+
return schema

tests/schemas/test_openapi.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,14 +47,12 @@ def test_override_settings(self):
4747
assert isinstance(views.ExampleListView.schema, OpenAPIAutoSchema)
4848

4949
def test_paths_construction(self):
50+
"""Construction of the `paths` key."""
5051
patterns = [
5152
url(r'^example/?$', views.ExampleListView.as_view()),
5253
]
5354
generator = OpenAPISchemaGenerator(patterns=patterns)
54-
55-
# This happens in get_schema()
56-
inspector = generator.endpoint_inspector_cls(generator.patterns, generator.urlconf)
57-
generator.endpoints = inspector.get_api_endpoints()
55+
generator._initialise_endpoints()
5856

5957
paths = generator.get_paths()
6058

@@ -63,3 +61,16 @@ def test_paths_construction(self):
6361
assert len(example_operations) == 2
6462
assert 'get' in example_operations
6563
assert 'post' in example_operations
64+
65+
def test_schema_construction(self):
66+
"""Construction of the top level dictionary."""
67+
patterns = [
68+
url(r'^example/?$', views.ExampleListView.as_view()),
69+
]
70+
generator = OpenAPISchemaGenerator(patterns=patterns)
71+
72+
request = create_request('/')
73+
schema = generator.get_schema(request=request)
74+
75+
assert 'basePath' in schema
76+
assert 'paths' in schema

0 commit comments

Comments
 (0)