11import warnings
2+ from enum import Enum
23from operator import attrgetter
34from urllib .parse import urljoin
45
@@ -39,12 +40,18 @@ def get_paths(self, request=None):
3940
4041 # Only generate the path prefix for paths that will be included
4142 if not paths :
42- return None
43+ return None , None
44+
45+ components_schemas = {}
4346
4447 for path , method , view in view_endpoints :
4548 if not self .has_view_permissions (path , method , view ):
4649 continue
47- operation = view .schema .get_operation (path , method )
50+ operation , operation_schema = view .schema .get_operation (path , method )
51+
52+ if operation_schema is not None :
53+ components_schemas .update (operation_schema )
54+
4855 # Normalise path for any provided mount url.
4956 if path .startswith ('/' ):
5057 path = path [1 :]
@@ -53,24 +60,29 @@ def get_paths(self, request=None):
5360 result .setdefault (path , {})
5461 result [path ][method .lower ()] = operation
5562
56- return result
63+ return result , components_schemas
5764
5865 def get_schema (self , request = None , public = False ):
5966 """
6067 Generate a OpenAPI schema.
6168 """
6269 self ._initialise_endpoints ()
6370
64- paths = self .get_paths (None if public else request )
71+ paths , components_schemas = self .get_paths (None if public else request )
6572 if not paths :
6673 return None
6774
6875 schema = {
6976 'openapi' : '3.0.2' ,
7077 'info' : self .get_info (),
71- 'paths' : paths ,
78+ 'paths' : paths
7279 }
7380
81+ if len (components_schemas ) > 0 :
82+ schema ['components' ] = {
83+ 'schemas' : components_schemas
84+ }
85+
7486 return schema
7587
7688# View Inspectors
@@ -106,7 +118,9 @@ def get_operation(self, path, method):
106118 operation ['requestBody' ] = request_body
107119 operation ['responses' ] = self ._get_responses (path , method )
108120
109- return operation
121+ component_schema = self ._get_component_schema (path , method )
122+
123+ return operation , component_schema
110124
111125 def _get_operation_id (self , path , method ):
112126 """
@@ -479,29 +493,67 @@ def _get_serializer(self, method, path):
479493 .format (view .__class__ .__name__ , method , path ))
480494 return None
481495
482- def _get_request_body (self , path , method ):
483- if method not in ('PUT' , 'PATCH' , 'POST' ):
496+ class SchemaMode (Enum ):
497+ RESPONSE = 1
498+ BODY = 2
499+
500+ def _get_item_schema (self , serializer , schema_mode , method ):
501+ if not isinstance (serializer , serializers .Serializer ):
484502 return {}
485503
486- self .request_media_types = self .map_parsers (path , method )
504+ # If the serializer uses a model, we should use a reference
505+ if hasattr (serializer , 'Meta' ) and hasattr (serializer .Meta , 'model' ):
506+ model_name = serializer .Meta .model .__name__
507+ return {'$ref' : '#/components/schemas/{}' .format (model_name )}
508+
509+ # There is no model, we'll map the serializer's fields
510+ item_schema = self ._map_serializer (serializer )
487511
512+ if schema_mode == self .SchemaMode .RESPONSE :
513+ # No write_only fields for response.
514+ for name , schema in item_schema ['properties' ].copy ().items ():
515+ if 'writeOnly' in schema :
516+ del item_schema ['properties' ][name ]
517+ if 'required' in item_schema :
518+ item_schema ['required' ] = [f for f in item_schema ['required' ] if f != name ]
519+
520+ elif schema_mode == self .SchemaMode .BODY :
521+ # No required fields for PATCH
522+ if method == 'PATCH' :
523+ item_schema .pop ('required' , None )
524+ # No read_only fields for request.
525+ for name , schema in item_schema ['properties' ].copy ().items ():
526+ if 'readOnly' in schema :
527+ del item_schema ['properties' ][name ]
528+
529+ return item_schema
530+
531+ def _get_component_schema (self , path , method ):
488532 serializer = self ._get_serializer (path , method )
489533
490534 if not isinstance (serializer , serializers .Serializer ):
491- return {}
535+ return None
536+
537+ # If the model has no model, then the serializer will be inlined
538+ if not hasattr (serializer , 'Meta' ) or not hasattr (serializer .Meta , 'model' ):
539+ return None
492540
541+ model_name = serializer .Meta .model .__name__
493542 content = self ._map_serializer (serializer )
494- # No required fields for PATCH
495- if method == 'PATCH' :
496- content .pop ('required' , None )
497- # No read_only fields for request.
498- for name , schema in content ['properties' ].copy ().items ():
499- if 'readOnly' in schema :
500- del content ['properties' ][name ]
543+
544+ return {model_name : content }
545+
546+ def _get_request_body (self , path , method ):
547+ if method not in ('PUT' , 'PATCH' , 'POST' ):
548+ return {}
549+
550+ self .request_media_types = self .map_parsers (path , method )
551+
552+ serializer = self ._get_serializer (path , method )
501553
502554 return {
503555 'content' : {
504- ct : {'schema' : content }
556+ ct : {'schema' : self . _get_item_schema ( serializer , self . SchemaMode . BODY , method ) }
505557 for ct in self .request_media_types
506558 }
507559 }
@@ -517,17 +569,8 @@ def _get_responses(self, path, method):
517569
518570 self .response_media_types = self .map_renderers (path , method )
519571
520- item_schema = {}
521572 serializer = self ._get_serializer (path , method )
522-
523- if isinstance (serializer , serializers .Serializer ):
524- item_schema = self ._map_serializer (serializer )
525- # No write_only fields for response.
526- for name , schema in item_schema ['properties' ].copy ().items ():
527- if 'writeOnly' in schema :
528- del item_schema ['properties' ][name ]
529- if 'required' in item_schema :
530- item_schema ['required' ] = [f for f in item_schema ['required' ] if f != name ]
573+ item_schema = self ._get_item_schema (serializer , self .SchemaMode .RESPONSE , method )
531574
532575 if is_list_view (path , method , self .view ):
533576 response_schema = {
0 commit comments