diff --git a/docs/community/3.10-announcement.md b/docs/community/3.10-announcement.md new file mode 100644 index 0000000000..52c3c0628c --- /dev/null +++ b/docs/community/3.10-announcement.md @@ -0,0 +1,44 @@ + +# Django REST framework 3.10 + + +* Reworked OpenAPI schema generation. +* Python 3 only. + + +## OpenAPI Schema Generation. + +Since we first introduced schema support in Django REST Framework 3.5, OpenAPI has emerged as the widely adopted standard for modelling Web APIs. + +This release deprecates the old CoreAPI based schema generation, and introduces improved OpenAPI schema generation in its place. + +### Deprecating CoreAPI Schema Generation. + +The in-built docs that were introduced in Django REST Framework v3.5 were built on CoreAPI. These are now deprecated. You may continue to use them but they will be **removed in Django REST Framework v 3.12**. + +You should migrate to using the new OpenAPI based schema generation as soon as you can. + + +We have removed the old documentation for the CoreAPI based schema generation. +You may view them from here: + +TOOD: Insert links: + +* Tutorial 7: Schemas & client libraries +* Section from Documenting your API topic page. + * Includes API reference for `include_docs_urls()` and supporting functions. +* Schemas API Guide reference. + +---- + +**TODO**: +- Currently, `SchemaView` will keep working. + - Hardcoded generator class. +- `generateschema` won't work... Do we need to fix that? + - Only introduced in v3.9. Not really useful for `coreapi`? +- Note on (≈private) import changes. + - Documented utility imports from `schemas` should be maintained. + - Sub-imports can move for the code changes. + +---- + diff --git a/mkdocs.yml b/mkdocs.yml index 83b3c8141e..f3059a6b79 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -65,6 +65,7 @@ pages: - 'Contributing to REST framework': 'community/contributing.md' - 'Project management': 'community/project-management.md' - 'Release Notes': 'community/release-notes.md' + - '3.10 Announcement': 'community/3.10-announcement.md' - '3.9 Announcement': 'community/3.9-announcement.md' - '3.8 Announcement': 'community/3.8-announcement.md' - '3.7 Announcement': 'community/3.7-announcement.md' diff --git a/rest_framework/filters.py b/rest_framework/filters.py index bb1b86586c..53b77ff395 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -40,6 +40,9 @@ def get_schema_fields(self, view): assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`' return [] + def get_schema_operation_parameters(self, view): + return [] + class SearchFilter(BaseFilterBackend): # The URL query parameter used for the search. @@ -159,6 +162,19 @@ def get_schema_fields(self, view): ) ] + def get_schema_operation_parameters(self, view): + return [ + { + 'name': self.search_param, + 'required': False, + 'in': 'query', + 'description': force_text(self.search_description), + 'schema': { + 'type': 'string', + }, + }, + ] + class OrderingFilter(BaseFilterBackend): # The URL query parameter used for the ordering. @@ -290,6 +306,19 @@ def get_schema_fields(self, view): ) ] + def get_schema_operation_parameters(self, view): + return [ + { + 'name': self.ordering_param, + 'required': False, + 'in': 'query', + 'description': force_text(self.ordering_description), + 'schema': { + 'type': 'string', + }, + }, + ] + class DjangoObjectPermissionsFilter(BaseFilterBackend): """ diff --git a/rest_framework/management/commands/generateschema.py b/rest_framework/management/commands/generateschema.py index 591073ba03..55e27ea8fb 100644 --- a/rest_framework/management/commands/generateschema.py +++ b/rest_framework/management/commands/generateschema.py @@ -1,24 +1,20 @@ from django.core.management.base import BaseCommand -from rest_framework.compat import coreapi -from rest_framework.renderers import ( - CoreJSONRenderer, JSONOpenAPIRenderer, OpenAPIRenderer -) -from rest_framework.schemas.generators import SchemaGenerator +from rest_framework.compat import yaml +from rest_framework.schemas.openapi import SchemaGenerator +from rest_framework.utils import json class Command(BaseCommand): help = "Generates configured API schema for project." def add_arguments(self, parser): - parser.add_argument('--title', dest="title", default=None, type=str) + parser.add_argument('--title', dest="title", default='', type=str) parser.add_argument('--url', dest="url", default=None, type=str) parser.add_argument('--description', dest="description", default=None, type=str) - parser.add_argument('--format', dest="format", choices=['openapi', 'openapi-json', 'corejson'], default='openapi', type=str) + parser.add_argument('--format', dest="format", choices=['openapi', 'openapi-json'], default='openapi', type=str) def handle(self, *args, **options): - assert coreapi is not None, 'coreapi must be installed.' - generator = SchemaGenerator( url=options['url'], title=options['title'], @@ -27,15 +23,10 @@ def handle(self, *args, **options): schema = generator.get_schema(request=None, public=True) - renderer = self.get_renderer(options['format']) - output = renderer.render(schema, renderer_context={}) - self.stdout.write(output.decode('utf-8')) - - def get_renderer(self, format): - renderer_cls = { - 'corejson': CoreJSONRenderer, - 'openapi': OpenAPIRenderer, - 'openapi-json': JSONOpenAPIRenderer, - }[format] + # TODO: Handle via renderer? More options? + if options['format'] == 'openapi': + output = yaml.dump(schema, default_flow_style=False) + else: + output = json.dumps(schema, indent=2) - return renderer_cls() + self.stdout.write(output) diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py index b11d7cdf3a..e93095d10e 100644 --- a/rest_framework/pagination.py +++ b/rest_framework/pagination.py @@ -152,6 +152,9 @@ def get_schema_fields(self, view): assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' return [] + def get_schema_operation_parameters(self, view): + return [] + class PageNumberPagination(BasePagination): """ @@ -305,6 +308,32 @@ def get_schema_fields(self, view): ) return fields + def get_schema_operation_parameters(self, view): + parameters = [ + { + 'name': self.page_query_param, + 'required': False, + 'in': 'query', + 'description': force_text(self.page_query_description), + 'schema': { + 'type': 'integer', + }, + }, + ] + if self.page_size_query_param is not None: + parameters.append( + { + 'name': self.page_size_query_param, + 'required': False, + 'in': 'query', + 'description': force_text(self.page_size_query_description), + 'schema': { + 'type': 'integer', + }, + }, + ) + return parameters + class LimitOffsetPagination(BasePagination): """ @@ -434,6 +463,15 @@ def to_html(self): context = self.get_html_context() return template.render(context) + def get_count(self, queryset): + """ + Determine an object count, supporting either querysets or regular lists. + """ + try: + return queryset.count() + except (AttributeError, TypeError): + return len(queryset) + def get_schema_fields(self, view): assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`' @@ -458,14 +496,28 @@ def get_schema_fields(self, view): ) ] - def get_count(self, queryset): - """ - Determine an object count, supporting either querysets or regular lists. - """ - try: - return queryset.count() - except (AttributeError, TypeError): - return len(queryset) + def get_schema_operation_parameters(self, view): + parameters = [ + { + 'name': self.limit_query_param, + 'required': False, + 'in': 'query', + 'description': force_text(self.limit_query_description), + 'schema': { + 'type': 'integer', + }, + }, + { + 'name': self.offset_query_param, + 'required': False, + 'in': 'query', + 'description': force_text(self.offset_query_description), + 'schema': { + 'type': 'integer', + }, + }, + ] + return parameters class CursorPagination(BasePagination): @@ -820,3 +872,29 @@ def get_schema_fields(self, view): ) ) return fields + + def get_schema_operation_parameters(self, view): + parameters = [ + { + 'name': self.cursor_query_param, + 'required': False, + 'in': 'query', + 'description': force_text(self.cursor_query_description), + 'schema': { + 'type': 'integer', + }, + } + ] + if self.page_size_query_param is not None: + parameters.append( + { + 'name': self.page_size_query_param, + 'required': False, + 'in': 'query', + 'description': force_text(self.page_size_query_description), + 'schema': { + 'type': 'integer', + }, + } + ) + return parameters diff --git a/rest_framework/schemas/__init__.py b/rest_framework/schemas/__init__.py index ba0ec65369..b1f37987a7 100644 --- a/rest_framework/schemas/__init__.py +++ b/rest_framework/schemas/__init__.py @@ -22,8 +22,8 @@ """ from rest_framework.settings import api_settings -from .generators import SchemaGenerator -from .inspectors import AutoSchema, DefaultSchema, ManualSchema # noqa +from .inspectors import DefaultSchema # noqa +from .coreapi import AutoSchema, ManualSchema, SchemaGenerator # noqa def get_schema_view( diff --git a/rest_framework/schemas/coreapi.py b/rest_framework/schemas/coreapi.py new file mode 100644 index 0000000000..895ea0efd1 --- /dev/null +++ b/rest_framework/schemas/coreapi.py @@ -0,0 +1,611 @@ +import re +import warnings +from collections import Counter, OrderedDict + +from django.db import models +from django.utils.encoding import force_text, smart_text +from django.utils.six.moves.urllib import parse as urlparse + +from rest_framework import exceptions, serializers +from rest_framework.compat import coreapi, coreschema, uritemplate +from rest_framework.settings import api_settings +from rest_framework.utils import formatting + +from .generators import BaseSchemaGenerator +from .inspectors import ViewInspector +from .utils import get_pk_description, is_list_view + +# Used in _get_description_section() +# TODO: ???: move up to base. +header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:') + +# Generator # +# TODO: Pull some of this into base. + + +def is_custom_action(action): + return action not in { + 'retrieve', 'list', 'create', 'update', 'partial_update', 'destroy' + } + + +def distribute_links(obj): + for key, value in obj.items(): + distribute_links(value) + + for preferred_key, link in obj.links: + key = obj.get_available_key(preferred_key) + obj[key] = link + + +INSERT_INTO_COLLISION_FMT = """ +Schema Naming Collision. + +coreapi.Link for URL path {value_url} cannot be inserted into schema. +Position conflicts with coreapi.Link for URL path {target_url}. + +Attempted to insert link with keys: {keys}. + +Adjust URLs to avoid naming collision or override `SchemaGenerator.get_keys()` +to customise schema structure. +""" + + +class LinkNode(OrderedDict): + def __init__(self): + self.links = [] + self.methods_counter = Counter() + super(LinkNode, self).__init__() + + def get_available_key(self, preferred_key): + if preferred_key not in self: + return preferred_key + + while True: + current_val = self.methods_counter[preferred_key] + self.methods_counter[preferred_key] += 1 + + key = '{}_{}'.format(preferred_key, current_val) + if key not in self: + return key + + +def insert_into(target, keys, value): + """ + Nested dictionary insertion. + + >>> example = {} + >>> insert_into(example, ['a', 'b', 'c'], 123) + >>> example + LinkNode({'a': LinkNode({'b': LinkNode({'c': LinkNode(links=[123])}}}))) + """ + for key in keys[:-1]: + if key not in target: + target[key] = LinkNode() + target = target[key] + + try: + target.links.append((keys[-1], value)) + except TypeError: + msg = INSERT_INTO_COLLISION_FMT.format( + value_url=value.url, + target_url=target.url, + keys=keys + ) + raise ValueError(msg) + + +class SchemaGenerator(BaseSchemaGenerator): + """ + Original CoreAPI version. + """ + # Map HTTP methods onto actions. + default_mapping = { + 'get': 'retrieve', + 'post': 'create', + 'put': 'update', + 'patch': 'partial_update', + 'delete': 'destroy', + } + + # Map the method names we use for viewset actions onto external schema names. + # These give us names that are more suitable for the external representation. + # Set by 'SCHEMA_COERCE_METHOD_NAMES'. + coerce_method_names = None + + def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None): + assert coreapi, '`coreapi` must be installed for schema support.' + assert coreschema, '`coreschema` must be installed for schema support.' + + super(SchemaGenerator, self).__init__(title, url, description, patterns, urlconf) + self.coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES + + def get_links(self, request=None): + """ + Return a dictionary containing all the links that should be + included in the API schema. + """ + links = LinkNode() + + paths, view_endpoints = self._get_paths_and_endpoints(request) + + # Only generate the path prefix for paths that will be included + if not paths: + return None + prefix = self.determine_path_prefix(paths) + + for path, method, view in view_endpoints: + if not self.has_view_permissions(path, method, view): + continue + link = view.schema.get_link(path, method, base_url=self.url) + subpath = path[len(prefix):] + keys = self.get_keys(subpath, method, view) + insert_into(links, keys, link) + + return links + + def get_schema(self, request=None, public=False): + """ + Generate a `coreapi.Document` representing the API schema. + """ + self._initialise_endpoints() + + links = self.get_links(None if public else request) + if not links: + return None + + url = self.url + if not url and request is not None: + url = request.build_absolute_uri() + + distribute_links(links) + return coreapi.Document( + title=self.title, description=self.description, + url=url, content=links + ) + + # Method for generating the link layout.... + def get_keys(self, subpath, method, view): + """ + Return a list of keys that should be used to layout a link within + the schema document. + + /users/ ("users", "list"), ("users", "create") + /users/{pk}/ ("users", "read"), ("users", "update"), ("users", "delete") + /users/enabled/ ("users", "enabled") # custom viewset list action + /users/{pk}/star/ ("users", "star") # custom viewset detail action + /users/{pk}/groups/ ("users", "groups", "list"), ("users", "groups", "create") + /users/{pk}/groups/{pk}/ ("users", "groups", "read"), ("users", "groups", "update"), ("users", "groups", "delete") + """ + if hasattr(view, 'action'): + # Viewsets have explicitly named actions. + action = view.action + else: + # Views have no associated action, so we determine one from the method. + if is_list_view(subpath, method, view): + action = 'list' + else: + action = self.default_mapping[method.lower()] + + named_path_components = [ + component for component + in subpath.strip('/').split('/') + if '{' not in component + ] + + if is_custom_action(action): + # Custom action, eg "/users/{pk}/activate/", "/users/active/" + if len(view.action_map) > 1: + action = self.default_mapping[method.lower()] + if action in self.coerce_method_names: + action = self.coerce_method_names[action] + return named_path_components + [action] + else: + return named_path_components[:-1] + [action] + + if action in self.coerce_method_names: + action = self.coerce_method_names[action] + + # Default action, eg "/users/", "/users/{pk}/" + return named_path_components + [action] + +# View Inspectors # + + +def field_to_schema(field): + title = force_text(field.label) if field.label else '' + description = force_text(field.help_text) if field.help_text else '' + + if isinstance(field, (serializers.ListSerializer, serializers.ListField)): + child_schema = field_to_schema(field.child) + return coreschema.Array( + items=child_schema, + title=title, + description=description + ) + elif isinstance(field, serializers.DictField): + return coreschema.Object( + title=title, + description=description + ) + elif isinstance(field, serializers.Serializer): + return coreschema.Object( + properties=OrderedDict([ + (key, field_to_schema(value)) + for key, value + in field.fields.items() + ]), + title=title, + description=description + ) + elif isinstance(field, serializers.ManyRelatedField): + related_field_schema = field_to_schema(field.child_relation) + + return coreschema.Array( + items=related_field_schema, + title=title, + description=description + ) + elif isinstance(field, serializers.PrimaryKeyRelatedField): + schema_cls = coreschema.String + model = getattr(field.queryset, 'model', None) + if model is not None: + model_field = model._meta.pk + if isinstance(model_field, models.AutoField): + schema_cls = coreschema.Integer + return schema_cls(title=title, description=description) + elif isinstance(field, serializers.RelatedField): + return coreschema.String(title=title, description=description) + elif isinstance(field, serializers.MultipleChoiceField): + return coreschema.Array( + items=coreschema.Enum(enum=list(field.choices)), + title=title, + description=description + ) + elif isinstance(field, serializers.ChoiceField): + return coreschema.Enum( + enum=list(field.choices), + title=title, + description=description + ) + elif isinstance(field, serializers.BooleanField): + return coreschema.Boolean(title=title, description=description) + elif isinstance(field, (serializers.DecimalField, serializers.FloatField)): + return coreschema.Number(title=title, description=description) + elif isinstance(field, serializers.IntegerField): + return coreschema.Integer(title=title, description=description) + elif isinstance(field, serializers.DateField): + return coreschema.String( + title=title, + description=description, + format='date' + ) + elif isinstance(field, serializers.DateTimeField): + return coreschema.String( + title=title, + description=description, + format='date-time' + ) + elif isinstance(field, serializers.JSONField): + return coreschema.Object(title=title, description=description) + + if field.style.get('base_template') == 'textarea.html': + return coreschema.String( + title=title, + description=description, + format='textarea' + ) + + return coreschema.String(title=title, description=description) + + +class AutoSchema(ViewInspector): + """ + Default inspector for APIView + + Responsible for per-view introspection and schema generation. + """ + def __init__(self, manual_fields=None): + """ + Parameters: + + * `manual_fields`: list of `coreapi.Field` instances that + will be added to auto-generated fields, overwriting on `Field.name` + """ + super(AutoSchema, self).__init__() + if manual_fields is None: + manual_fields = [] + self._manual_fields = manual_fields + + def get_link(self, path, method, base_url): + """ + Generate `coreapi.Link` for self.view, path and method. + + This is the main _public_ access point. + + Parameters: + + * path: Route path for view from URLConf. + * method: The HTTP request method. + * base_url: The project "mount point" as given to SchemaGenerator + """ + fields = self.get_path_fields(path, method) + fields += self.get_serializer_fields(path, method) + fields += self.get_pagination_fields(path, method) + fields += self.get_filter_fields(path, method) + + manual_fields = self.get_manual_fields(path, method) + fields = self.update_fields(fields, manual_fields) + + if fields and any([field.location in ('form', 'body') for field in fields]): + encoding = self.get_encoding(path, method) + else: + encoding = None + + description = self.get_description(path, method) + + if base_url and path.startswith('/'): + path = path[1:] + + return coreapi.Link( + url=urlparse.urljoin(base_url, path), + action=method.lower(), + encoding=encoding, + fields=fields, + description=description + ) + + def get_description(self, path, method): + """ + Determine a link description. + + This will be based on the method docstring if one exists, + or else the class docstring. + """ + view = self.view + + method_name = getattr(view, 'action', method.lower()) + method_docstring = getattr(view, method_name, None).__doc__ + if method_docstring: + # An explicit docstring on the method or action. + return self._get_description_section(view, method.lower(), formatting.dedent(smart_text(method_docstring))) + else: + return self._get_description_section(view, getattr(view, 'action', method.lower()), view.get_view_description()) + + def _get_description_section(self, view, header, description): + lines = [line for line in description.splitlines()] + current_section = '' + sections = {'': ''} + + for line in lines: + if header_regex.match(line): + current_section, seperator, lead = line.partition(':') + sections[current_section] = lead.strip() + else: + sections[current_section] += '\n' + line + + # TODO: SCHEMA_COERCE_METHOD_NAMES appears here and in `SchemaGenerator.get_keys` + coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES + if header in sections: + return sections[header].strip() + if header in coerce_method_names: + if coerce_method_names[header] in sections: + return sections[coerce_method_names[header]].strip() + return sections[''].strip() + + def get_path_fields(self, path, method): + """ + Return a list of `coreapi.Field` instances corresponding to any + templated path variables. + """ + view = self.view + model = getattr(getattr(view, 'queryset', None), 'model', None) + fields = [] + + for variable in uritemplate.variables(path): + title = '' + description = '' + schema_cls = coreschema.String + kwargs = {} + if model is not None: + # Attempt to infer a field description if possible. + try: + model_field = model._meta.get_field(variable) + except Exception: + model_field = None + + if model_field is not None and model_field.verbose_name: + title = force_text(model_field.verbose_name) + + if model_field is not None and model_field.help_text: + description = force_text(model_field.help_text) + elif model_field is not None and model_field.primary_key: + description = get_pk_description(model, model_field) + + if hasattr(view, 'lookup_value_regex') and view.lookup_field == variable: + kwargs['pattern'] = view.lookup_value_regex + elif isinstance(model_field, models.AutoField): + schema_cls = coreschema.Integer + + field = coreapi.Field( + name=variable, + location='path', + required=True, + schema=schema_cls(title=title, description=description, **kwargs) + ) + fields.append(field) + + return fields + + def get_serializer_fields(self, path, method): + """ + Return a list of `coreapi.Field` instances corresponding to any + request body input, as determined by the serializer class. + """ + view = self.view + + if method not in ('PUT', 'PATCH', 'POST'): + return [] + + if not hasattr(view, 'get_serializer'): + return [] + + try: + serializer = view.get_serializer() + except exceptions.APIException: + serializer = None + warnings.warn('{}.get_serializer() raised an exception during ' + 'schema generation. Serializer fields will not be ' + 'generated for {} {}.' + .format(view.__class__.__name__, method, path)) + + if isinstance(serializer, serializers.ListSerializer): + return [ + coreapi.Field( + name='data', + location='body', + required=True, + schema=coreschema.Array() + ) + ] + + if not isinstance(serializer, serializers.Serializer): + return [] + + fields = [] + for field in serializer.fields.values(): + if field.read_only or isinstance(field, serializers.HiddenField): + continue + + required = field.required and method != 'PATCH' + field = coreapi.Field( + name=field.field_name, + location='form', + required=required, + schema=field_to_schema(field) + ) + fields.append(field) + + return fields + + def get_pagination_fields(self, path, method): + view = self.view + + if not is_list_view(path, method, view): + return [] + + pagination = getattr(view, 'pagination_class', None) + if not pagination: + return [] + + paginator = view.pagination_class() + return paginator.get_schema_fields(view) + + def _allows_filters(self, path, method): + """ + Determine whether to include filter Fields in schema. + + Default implementation looks for ModelViewSet or GenericAPIView + actions/methods that cause filtering on the default implementation. + + Override to adjust behaviour for your view. + + Note: Introduced in v3.7: Initially "private" (i.e. with leading underscore) + to allow changes based on user experience. + """ + if getattr(self.view, 'filter_backends', None) is None: + return False + + if hasattr(self.view, 'action'): + return self.view.action in ["list", "retrieve", "update", "partial_update", "destroy"] + + return method.lower() in ["get", "put", "patch", "delete"] + + def get_filter_fields(self, path, method): + if not self._allows_filters(path, method): + return [] + + fields = [] + for filter_backend in self.view.filter_backends: + fields += filter_backend().get_schema_fields(self.view) + return fields + + def get_manual_fields(self, path, method): + return self._manual_fields + + @staticmethod + def update_fields(fields, update_with): + """ + Update list of coreapi.Field instances, overwriting on `Field.name`. + + Utility function to handle replacing coreapi.Field fields + from a list by name. Used to handle `manual_fields`. + + Parameters: + + * `fields`: list of `coreapi.Field` instances to update + * `update_with: list of `coreapi.Field` instances to add or replace. + """ + if not update_with: + return fields + + by_name = OrderedDict((f.name, f) for f in fields) + for f in update_with: + by_name[f.name] = f + fields = list(by_name.values()) + return fields + + def get_encoding(self, path, method): + """ + Return the 'encoding' parameter to use for a given endpoint. + """ + view = self.view + + # Core API supports the following request encodings over HTTP... + supported_media_types = { + 'application/json', + 'application/x-www-form-urlencoded', + 'multipart/form-data', + } + parser_classes = getattr(view, 'parser_classes', []) + for parser_class in parser_classes: + media_type = getattr(parser_class, 'media_type', None) + if media_type in supported_media_types: + return media_type + # Raw binary uploads are supported with "application/octet-stream" + if media_type == '*/*': + return 'application/octet-stream' + + return None + + +class ManualSchema(ViewInspector): + """ + Allows providing a list of coreapi.Fields, + plus an optional description. + """ + def __init__(self, fields, description='', encoding=None): + """ + Parameters: + + * `fields`: list of `coreapi.Field` instances. + * `description`: String description for view. Optional. + """ + super(ManualSchema, self).__init__() + assert all(isinstance(f, coreapi.Field) for f in fields), "`fields` must be a list of coreapi.Field instances" + self._fields = fields + self._description = description + self._encoding = encoding + + def get_link(self, path, method, base_url): + + if base_url and path.startswith('/'): + path = path[1:] + + return coreapi.Link( + url=urlparse.urljoin(base_url, path), + action=method.lower(), + encoding=self._encoding, + fields=self._fields, + description=self._description + ) diff --git a/rest_framework/schemas/generators.py b/rest_framework/schemas/generators.py index db226a6c16..3455ad1e09 100644 --- a/rest_framework/schemas/generators.py +++ b/rest_framework/schemas/generators.py @@ -4,7 +4,6 @@ See schemas.__init__.py for package overview. """ import re -from collections import Counter, OrderedDict from importlib import import_module from django.conf import settings @@ -14,15 +13,11 @@ from django.utils import six from rest_framework import exceptions -from rest_framework.compat import ( - URLPattern, URLResolver, coreapi, coreschema, get_original_route -) +from rest_framework.compat import URLPattern, URLResolver, get_original_route from rest_framework.request import clone_request from rest_framework.settings import api_settings from rest_framework.utils.model_meta import _get_pk -from .utils import is_list_view - def common_path(paths): split_paths = [path.strip('/').split('/') for path in paths] @@ -51,78 +46,6 @@ def is_api_view(callback): return (cls is not None) and issubclass(cls, APIView) -INSERT_INTO_COLLISION_FMT = """ -Schema Naming Collision. - -coreapi.Link for URL path {value_url} cannot be inserted into schema. -Position conflicts with coreapi.Link for URL path {target_url}. - -Attempted to insert link with keys: {keys}. - -Adjust URLs to avoid naming collision or override `SchemaGenerator.get_keys()` -to customise schema structure. -""" - - -class LinkNode(OrderedDict): - def __init__(self): - self.links = [] - self.methods_counter = Counter() - super(LinkNode, self).__init__() - - def get_available_key(self, preferred_key): - if preferred_key not in self: - return preferred_key - - while True: - current_val = self.methods_counter[preferred_key] - self.methods_counter[preferred_key] += 1 - - key = '{}_{}'.format(preferred_key, current_val) - if key not in self: - return key - - -def insert_into(target, keys, value): - """ - Nested dictionary insertion. - - >>> example = {} - >>> insert_into(example, ['a', 'b', 'c'], 123) - >>> example - LinkNode({'a': LinkNode({'b': LinkNode({'c': LinkNode(links=[123])}}}))) - """ - for key in keys[:-1]: - if key not in target: - target[key] = LinkNode() - target = target[key] - - try: - target.links.append((keys[-1], value)) - except TypeError: - msg = INSERT_INTO_COLLISION_FMT.format( - value_url=value.url, - target_url=target.url, - keys=keys - ) - raise ValueError(msg) - - -def distribute_links(obj): - for key, value in obj.items(): - distribute_links(value) - - for preferred_key, link in obj.links: - key = obj.get_available_key(preferred_key) - obj[key] = link - - -def is_custom_action(action): - return action not in { - 'retrieve', 'list', 'create', 'update', 'partial_update', 'destroy' - } - - def endpoint_ordering(endpoint): path, method, callback = endpoint method_priority = { @@ -193,6 +116,10 @@ def get_path_from_regex(self, path_regex): """ Given a URL conf regex, return a URI template string. """ + # ???: Would it be feasible to adjust this such that we generate the + # path, plus the kwargs, plus the type from the convertor, such that we + # could feed that straight into the parameter schema object? + path = simplify_regex(path_regex) # Strip Django 2.0 convertors as they are incompatible with uritemplate format @@ -232,35 +159,18 @@ def get_allowed_methods(self, callback): return [method for method in methods if method not in ('OPTIONS', 'HEAD')] -class SchemaGenerator(object): - # Map HTTP methods onto actions. - default_mapping = { - 'get': 'retrieve', - 'post': 'create', - 'put': 'update', - 'patch': 'partial_update', - 'delete': 'destroy', - } +class BaseSchemaGenerator(object): endpoint_inspector_cls = EndpointEnumerator - # Map the method names we use for viewset actions onto external schema names. - # These give us names that are more suitable for the external representation. - # Set by 'SCHEMA_COERCE_METHOD_NAMES'. - coerce_method_names = None - # 'pk' isn't great as an externally exposed name for an identifier, # so by default we prefer to use the actual model field name for schemas. # Set by 'SCHEMA_COERCE_PATH_PK'. coerce_path_pk = None def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None): - assert coreapi, '`coreapi` must be installed for schema support.' - assert coreschema, '`coreschema` must be installed for schema support.' - if url and not url.endswith('/'): url += '/' - self.coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES self.coerce_path_pk = api_settings.SCHEMA_COERCE_PATH_PK self.patterns = patterns @@ -270,36 +180,15 @@ def __init__(self, title=None, url=None, description=None, patterns=None, urlcon self.url = url self.endpoints = None - def get_schema(self, request=None, public=False): - """ - Generate a `coreapi.Document` representing the API schema. - """ + def _initialise_endpoints(self): if self.endpoints is None: inspector = self.endpoint_inspector_cls(self.patterns, self.urlconf) self.endpoints = inspector.get_api_endpoints() - links = self.get_links(None if public else request) - if not links: - return None - - url = self.url - if not url and request is not None: - url = request.build_absolute_uri() - - distribute_links(links) - return coreapi.Document( - title=self.title, description=self.description, - url=url, content=links - ) - - def get_links(self, request=None): + def _get_paths_and_endpoints(self, request): """ - Return a dictionary containing all the links that should be - included in the API schema. + Generate (path, method, view) given (path, method, callback) for paths. """ - links = LinkNode() - - # Generate (path, method, view) given (path, method, callback). paths = [] view_endpoints = [] for path, method, callback in self.endpoints: @@ -308,22 +197,48 @@ def get_links(self, request=None): paths.append(path) view_endpoints.append((path, method, view)) - # Only generate the path prefix for paths that will be included - if not paths: - return None - prefix = self.determine_path_prefix(paths) + return paths, view_endpoints - for path, method, view in view_endpoints: - if not self.has_view_permissions(path, method, view): - continue - link = view.schema.get_link(path, method, base_url=self.url) - subpath = path[len(prefix):] - keys = self.get_keys(subpath, method, view) - insert_into(links, keys, link) + def create_view(self, callback, method, request=None): + """ + Given a callback, return an actual view instance. + """ + view = callback.cls(**getattr(callback, 'initkwargs', {})) + view.args = () + view.kwargs = {} + view.format_kwarg = None + view.request = None + view.action_map = getattr(callback, 'actions', None) - return links + actions = getattr(callback, 'actions', None) + if actions is not None: + if method == 'OPTIONS': + view.action = 'metadata' + else: + view.action = actions.get(method.lower()) + + if request is not None: + view.request = clone_request(request, method) + + return view + + def coerce_path(self, path, method, view): + """ + Coerce {pk} path arguments into the name of the model field, + where possible. This is cleaner for an external representation. + (Ie. "this is an identifier", not "this is a database primary key") + """ + if not self.coerce_path_pk or '{pk}' not in path: + return path + model = getattr(getattr(view, 'queryset', None), 'model', None) + if model: + field_name = get_pk_name(model) + else: + field_name = 'id' + return path.replace('{pk}', '{%s}' % field_name) - # Methods used when we generate a view instance from the raw callback... + def get_schema(self, request=None, public=False): + raise NotImplementedError(".get_schema() must be implemented in subclasses.") def determine_path_prefix(self, paths): """ @@ -356,29 +271,6 @@ def determine_path_prefix(self, paths): prefixes.append('/' + prefix + '/') return common_path(prefixes) - def create_view(self, callback, method, request=None): - """ - Given a callback, return an actual view instance. - """ - view = callback.cls(**getattr(callback, 'initkwargs', {})) - view.args = () - view.kwargs = {} - view.format_kwarg = None - view.request = None - view.action_map = getattr(callback, 'actions', None) - - actions = getattr(callback, 'actions', None) - if actions is not None: - if method == 'OPTIONS': - view.action = 'metadata' - else: - view.action = actions.get(method.lower()) - - if request is not None: - view.request = clone_request(request, method) - - return view - def has_view_permissions(self, path, method, view): """ Return `True` if the incoming request has the correct view permissions. @@ -391,64 +283,3 @@ def has_view_permissions(self, path, method, view): except (exceptions.APIException, Http404, PermissionDenied): return False return True - - def coerce_path(self, path, method, view): - """ - Coerce {pk} path arguments into the name of the model field, - where possible. This is cleaner for an external representation. - (Ie. "this is an identifier", not "this is a database primary key") - """ - if not self.coerce_path_pk or '{pk}' not in path: - return path - model = getattr(getattr(view, 'queryset', None), 'model', None) - if model: - field_name = get_pk_name(model) - else: - field_name = 'id' - return path.replace('{pk}', '{%s}' % field_name) - - # Method for generating the link layout.... - - def get_keys(self, subpath, method, view): - """ - Return a list of keys that should be used to layout a link within - the schema document. - - /users/ ("users", "list"), ("users", "create") - /users/{pk}/ ("users", "read"), ("users", "update"), ("users", "delete") - /users/enabled/ ("users", "enabled") # custom viewset list action - /users/{pk}/star/ ("users", "star") # custom viewset detail action - /users/{pk}/groups/ ("users", "groups", "list"), ("users", "groups", "create") - /users/{pk}/groups/{pk}/ ("users", "groups", "read"), ("users", "groups", "update"), ("users", "groups", "delete") - """ - if hasattr(view, 'action'): - # Viewsets have explicitly named actions. - action = view.action - else: - # Views have no associated action, so we determine one from the method. - if is_list_view(subpath, method, view): - action = 'list' - else: - action = self.default_mapping[method.lower()] - - named_path_components = [ - component for component - in subpath.strip('/').split('/') - if '{' not in component - ] - - if is_custom_action(action): - # Custom action, eg "/users/{pk}/activate/", "/users/active/" - if len(view.action_map) > 1: - action = self.default_mapping[method.lower()] - if action in self.coerce_method_names: - action = self.coerce_method_names[action] - return named_path_components + [action] - else: - return named_path_components[:-1] + [action] - - if action in self.coerce_method_names: - action = self.coerce_method_names[action] - - # Default action, eg "/users/", "/users/{pk}/" - return named_path_components + [action] diff --git a/rest_framework/schemas/inspectors.py b/rest_framework/schemas/inspectors.py index 85142edce4..e1f955f930 100644 --- a/rest_framework/schemas/inspectors.py +++ b/rest_framework/schemas/inspectors.py @@ -4,125 +4,9 @@ See schemas.__init__.py for package overview. """ -import re -import warnings -from collections import OrderedDict from weakref import WeakKeyDictionary -from django.db import models -from django.utils.encoding import force_text, smart_text -from django.utils.six.moves.urllib import parse as urlparse -from django.utils.translation import ugettext_lazy as _ - -from rest_framework import exceptions, serializers -from rest_framework.compat import coreapi, coreschema, uritemplate from rest_framework.settings import api_settings -from rest_framework.utils import formatting - -from .utils import is_list_view - -header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:') - - -def field_to_schema(field): - title = force_text(field.label) if field.label else '' - description = force_text(field.help_text) if field.help_text else '' - - if isinstance(field, (serializers.ListSerializer, serializers.ListField)): - child_schema = field_to_schema(field.child) - return coreschema.Array( - items=child_schema, - title=title, - description=description - ) - elif isinstance(field, serializers.DictField): - return coreschema.Object( - title=title, - description=description - ) - elif isinstance(field, serializers.Serializer): - return coreschema.Object( - properties=OrderedDict([ - (key, field_to_schema(value)) - for key, value - in field.fields.items() - ]), - title=title, - description=description - ) - elif isinstance(field, serializers.ManyRelatedField): - related_field_schema = field_to_schema(field.child_relation) - - return coreschema.Array( - items=related_field_schema, - title=title, - description=description - ) - elif isinstance(field, serializers.PrimaryKeyRelatedField): - schema_cls = coreschema.String - model = getattr(field.queryset, 'model', None) - if model is not None: - model_field = model._meta.pk - if isinstance(model_field, models.AutoField): - schema_cls = coreschema.Integer - return schema_cls(title=title, description=description) - elif isinstance(field, serializers.RelatedField): - return coreschema.String(title=title, description=description) - elif isinstance(field, serializers.MultipleChoiceField): - return coreschema.Array( - items=coreschema.Enum(enum=list(field.choices)), - title=title, - description=description - ) - elif isinstance(field, serializers.ChoiceField): - return coreschema.Enum( - enum=list(field.choices), - title=title, - description=description - ) - elif isinstance(field, serializers.BooleanField): - return coreschema.Boolean(title=title, description=description) - elif isinstance(field, (serializers.DecimalField, serializers.FloatField)): - return coreschema.Number(title=title, description=description) - elif isinstance(field, serializers.IntegerField): - return coreschema.Integer(title=title, description=description) - elif isinstance(field, serializers.DateField): - return coreschema.String( - title=title, - description=description, - format='date' - ) - elif isinstance(field, serializers.DateTimeField): - return coreschema.String( - title=title, - description=description, - format='date-time' - ) - elif isinstance(field, serializers.JSONField): - return coreschema.Object(title=title, description=description) - - if field.style.get('base_template') == 'textarea.html': - return coreschema.String( - title=title, - description=description, - format='textarea' - ) - - return coreschema.String(title=title, description=description) - - -def get_pk_description(model, model_field): - if isinstance(model_field, models.AutoField): - value_type = _('unique integer value') - elif isinstance(model_field, models.UUIDField): - value_type = _('UUID string') - else: - value_type = _('unique value') - - return _('A {value_type} identifying this {name}.').format( - value_type=value_type, - name=model._meta.verbose_name, - ) class ViewInspector(object): @@ -179,321 +63,6 @@ def view(self, value): def view(self): self._view = None - def get_link(self, path, method, base_url): - """ - Generate `coreapi.Link` for self.view, path and method. - - This is the main _public_ access point. - - Parameters: - - * path: Route path for view from URLConf. - * method: The HTTP request method. - * base_url: The project "mount point" as given to SchemaGenerator - """ - raise NotImplementedError(".get_link() must be overridden.") - - -class AutoSchema(ViewInspector): - """ - Default inspector for APIView - - Responsible for per-view introspection and schema generation. - """ - def __init__(self, manual_fields=None): - """ - Parameters: - - * `manual_fields`: list of `coreapi.Field` instances that - will be added to auto-generated fields, overwriting on `Field.name` - """ - super(AutoSchema, self).__init__() - if manual_fields is None: - manual_fields = [] - self._manual_fields = manual_fields - - def get_link(self, path, method, base_url): - fields = self.get_path_fields(path, method) - fields += self.get_serializer_fields(path, method) - fields += self.get_pagination_fields(path, method) - fields += self.get_filter_fields(path, method) - - manual_fields = self.get_manual_fields(path, method) - fields = self.update_fields(fields, manual_fields) - - if fields and any([field.location in ('form', 'body') for field in fields]): - encoding = self.get_encoding(path, method) - else: - encoding = None - - description = self.get_description(path, method) - - if base_url and path.startswith('/'): - path = path[1:] - - return coreapi.Link( - url=urlparse.urljoin(base_url, path), - action=method.lower(), - encoding=encoding, - fields=fields, - description=description - ) - - def get_description(self, path, method): - """ - Determine a link description. - - This will be based on the method docstring if one exists, - or else the class docstring. - """ - view = self.view - - method_name = getattr(view, 'action', method.lower()) - method_docstring = getattr(view, method_name, None).__doc__ - if method_docstring: - # An explicit docstring on the method or action. - return self._get_description_section(view, method.lower(), formatting.dedent(smart_text(method_docstring))) - else: - return self._get_description_section(view, getattr(view, 'action', method.lower()), view.get_view_description()) - - def _get_description_section(self, view, header, description): - lines = [line for line in description.splitlines()] - current_section = '' - sections = {'': ''} - - for line in lines: - if header_regex.match(line): - current_section, seperator, lead = line.partition(':') - sections[current_section] = lead.strip() - else: - sections[current_section] += '\n' + line - - # TODO: SCHEMA_COERCE_METHOD_NAMES appears here and in `SchemaGenerator.get_keys` - coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES - if header in sections: - return sections[header].strip() - if header in coerce_method_names: - if coerce_method_names[header] in sections: - return sections[coerce_method_names[header]].strip() - return sections[''].strip() - - def get_path_fields(self, path, method): - """ - Return a list of `coreapi.Field` instances corresponding to any - templated path variables. - """ - view = self.view - model = getattr(getattr(view, 'queryset', None), 'model', None) - fields = [] - - for variable in uritemplate.variables(path): - title = '' - description = '' - schema_cls = coreschema.String - kwargs = {} - if model is not None: - # Attempt to infer a field description if possible. - try: - model_field = model._meta.get_field(variable) - except Exception: - model_field = None - - if model_field is not None and model_field.verbose_name: - title = force_text(model_field.verbose_name) - - if model_field is not None and model_field.help_text: - description = force_text(model_field.help_text) - elif model_field is not None and model_field.primary_key: - description = get_pk_description(model, model_field) - - if hasattr(view, 'lookup_value_regex') and view.lookup_field == variable: - kwargs['pattern'] = view.lookup_value_regex - elif isinstance(model_field, models.AutoField): - schema_cls = coreschema.Integer - - field = coreapi.Field( - name=variable, - location='path', - required=True, - schema=schema_cls(title=title, description=description, **kwargs) - ) - fields.append(field) - - return fields - - def get_serializer_fields(self, path, method): - """ - Return a list of `coreapi.Field` instances corresponding to any - request body input, as determined by the serializer class. - """ - view = self.view - - if method not in ('PUT', 'PATCH', 'POST'): - return [] - - if not hasattr(view, 'get_serializer'): - return [] - - try: - serializer = view.get_serializer() - except exceptions.APIException: - serializer = None - warnings.warn('{}.get_serializer() raised an exception during ' - 'schema generation. Serializer fields will not be ' - 'generated for {} {}.' - .format(view.__class__.__name__, method, path)) - - if isinstance(serializer, serializers.ListSerializer): - return [ - coreapi.Field( - name='data', - location='body', - required=True, - schema=coreschema.Array() - ) - ] - - if not isinstance(serializer, serializers.Serializer): - return [] - - fields = [] - for field in serializer.fields.values(): - if field.read_only or isinstance(field, serializers.HiddenField): - continue - - required = field.required and method != 'PATCH' - field = coreapi.Field( - name=field.field_name, - location='form', - required=required, - schema=field_to_schema(field) - ) - fields.append(field) - - return fields - - def get_pagination_fields(self, path, method): - view = self.view - - if not is_list_view(path, method, view): - return [] - - pagination = getattr(view, 'pagination_class', None) - if not pagination: - return [] - - paginator = view.pagination_class() - return paginator.get_schema_fields(view) - - def _allows_filters(self, path, method): - """ - Determine whether to include filter Fields in schema. - - Default implementation looks for ModelViewSet or GenericAPIView - actions/methods that cause filtering on the default implementation. - - Override to adjust behaviour for your view. - - Note: Introduced in v3.7: Initially "private" (i.e. with leading underscore) - to allow changes based on user experience. - """ - if getattr(self.view, 'filter_backends', None) is None: - return False - - if hasattr(self.view, 'action'): - return self.view.action in ["list", "retrieve", "update", "partial_update", "destroy"] - - return method.lower() in ["get", "put", "patch", "delete"] - - def get_filter_fields(self, path, method): - if not self._allows_filters(path, method): - return [] - - fields = [] - for filter_backend in self.view.filter_backends: - fields += filter_backend().get_schema_fields(self.view) - return fields - - def get_manual_fields(self, path, method): - return self._manual_fields - - @staticmethod - def update_fields(fields, update_with): - """ - Update list of coreapi.Field instances, overwriting on `Field.name`. - - Utility function to handle replacing coreapi.Field fields - from a list by name. Used to handle `manual_fields`. - - Parameters: - - * `fields`: list of `coreapi.Field` instances to update - * `update_with: list of `coreapi.Field` instances to add or replace. - """ - if not update_with: - return fields - - by_name = OrderedDict((f.name, f) for f in fields) - for f in update_with: - by_name[f.name] = f - fields = list(by_name.values()) - return fields - - def get_encoding(self, path, method): - """ - Return the 'encoding' parameter to use for a given endpoint. - """ - view = self.view - - # Core API supports the following request encodings over HTTP... - supported_media_types = { - 'application/json', - 'application/x-www-form-urlencoded', - 'multipart/form-data', - } - parser_classes = getattr(view, 'parser_classes', []) - for parser_class in parser_classes: - media_type = getattr(parser_class, 'media_type', None) - if media_type in supported_media_types: - return media_type - # Raw binary uploads are supported with "application/octet-stream" - if media_type == '*/*': - return 'application/octet-stream' - - return None - - -class ManualSchema(ViewInspector): - """ - Allows providing a list of coreapi.Fields, - plus an optional description. - """ - def __init__(self, fields, description='', encoding=None): - """ - Parameters: - - * `fields`: list of `coreapi.Field` instances. - * `description`: String description for view. Optional. - """ - super(ManualSchema, self).__init__() - assert all(isinstance(f, coreapi.Field) for f in fields), "`fields` must be a list of coreapi.Field instances" - self._fields = fields - self._description = description - self._encoding = encoding - - def get_link(self, path, method, base_url): - - if base_url and path.startswith('/'): - path = path[1:] - - return coreapi.Link( - url=urlparse.urljoin(base_url, path), - action=method.lower(), - encoding=self._encoding, - fields=self._fields, - description=self._description - ) - class DefaultSchema(ViewInspector): """Allows overriding AutoSchema using DEFAULT_SCHEMA_CLASS setting""" diff --git a/rest_framework/schemas/openapi.py b/rest_framework/schemas/openapi.py new file mode 100644 index 0000000000..85d296c5a3 --- /dev/null +++ b/rest_framework/schemas/openapi.py @@ -0,0 +1,366 @@ +import warnings + +from django.db import models +from django.utils.encoding import force_text + +from rest_framework import exceptions, serializers +from rest_framework.compat import uritemplate + +from .generators import BaseSchemaGenerator +from .inspectors import ViewInspector +from .utils import get_pk_description, is_list_view + +# Generator + + +class SchemaGenerator(BaseSchemaGenerator): + + def get_info(self): + info = { + 'title': self.title, + 'version': 'TODO', + } + + if self.description is not None: + info['description'] = self.description + + return info + + def get_paths(self, request=None): + result = {} + + paths, view_endpoints = self._get_paths_and_endpoints(request) + + # Only generate the path prefix for paths that will be included + if not paths: + return None + prefix = self.determine_path_prefix(paths) + + for path, method, view in view_endpoints: + if not self.has_view_permissions(path, method, view): + continue + operation = view.schema.get_operation(path, method) + subpath = '/' + path[len(prefix):] + result.setdefault(subpath, {}) + result[subpath][method.lower()] = operation + + return result + + def get_schema(self, request=None, public=False): + """ + Generate a OpenAPI schema. + """ + self._initialise_endpoints() + + paths = self.get_paths(None if public else request) + if not paths: + return None + + schema = { + 'openapi': '3.0.2', + 'info': self.get_info(), + 'paths': paths, + } + + return schema + +# View Inspectors + + +class AutoSchema(ViewInspector): + + content_types = ['application/json'] + method_mapping = { + 'get': 'Retrieve', + 'post': 'Create', + 'put': 'Update', + 'patch': 'PartialUpdate', + 'delete': 'Destroy', + } + + def get_operation(self, path, method): + operation = {} + + operation['operationId'] = self._get_operation_id(path, method) + + parameters = [] + parameters += self._get_path_parameters(path, method) + parameters += self._get_pagination_parameters(path, method) + parameters += self._get_filter_parameters(path, method) + operation['parameters'] = parameters + + request_body = self._get_request_body(path, method) + if request_body: + operation['requestBody'] = request_body + operation['responses'] = self._get_responses(path, method) + + return operation + + def _get_operation_id(self, path, method): + """ + Compute an operation ID from the model, serializer or view name. + """ + # TODO: Allow an attribute/method on the view to change that ID? + # Avoid cyclic imports + from rest_framework.generics import GenericAPIView + + if is_list_view(path, method, self.view): + action = 'List' + else: + action = self.method_mapping[method.lower()] + + # Try to deduce the ID from the view's model + model = getattr(getattr(self.view, 'queryset', None), 'model', None) + if model is not None: + name = model.__name__ + + # Try with the serializer class name + elif isinstance(self.view, GenericAPIView): + name = self.view.get_serializer_class().__name__ + if name.endswith('Serializer'): + name = name[:-10] + + # Fallback to the view name + else: + name = self.view.__class__.__name__ + if name.endswith('APIView'): + name = name[:-7] + elif name.endswith('View'): + name = name[:-4] + if name.endswith(action): # ListView, UpdateAPIView, ThingDelete ... + name = name[:-len(action)] + + if action == 'List' and not name.endswith('s'): # ListThings instead of ListThing + name += 's' + + return action + name + + def _get_path_parameters(self, path, method): + """ + Return a list of parameters from templated path variables. + """ + assert uritemplate, '`uritemplate` must be installed for OpenAPI schema support.' + + model = getattr(getattr(self.view, 'queryset', None), 'model', None) + parameters = [] + + for variable in uritemplate.variables(path): + description = '' + if model is not None: # TODO: test this. + # Attempt to infer a field description if possible. + try: + model_field = model._meta.get_field(variable) + except Exception: + model_field = None + + if model_field is not None and model_field.help_text: + description = force_text(model_field.help_text) + elif model_field is not None and model_field.primary_key: + description = get_pk_description(model, model_field) + + parameter = { + "name": variable, + "in": "path", + "required": True, + "description": description, + 'schema': { + 'type': 'string', # TODO: integer, pattern, ... + }, + } + parameters.append(parameter) + + return parameters + + def _get_filter_parameters(self, path, method): + if not self._allows_filters(path, method): + return [] + parameters = [] + for filter_backend in self.view.filter_backends: + parameters += filter_backend().get_schema_operation_parameters(self.view) + return parameters + + def _allows_filters(self, path, method): + """ + Determine whether to include filter Fields in schema. + + Default implementation looks for ModelViewSet or GenericAPIView + actions/methods that cause filtering on the default implementation. + """ + if getattr(self.view, 'filter_backends', None) is None: + return False + if hasattr(self.view, 'action'): + return self.view.action in ["list", "retrieve", "update", "partial_update", "destroy"] + return method.lower() in ["get", "put", "patch", "delete"] + + def _get_pagination_parameters(self, path, method): + view = self.view + + if not is_list_view(path, method, view): + return [] + + pagination = getattr(view, 'pagination_class', None) + if not pagination: + return [] + + paginator = view.pagination_class() + return paginator.get_schema_operation_parameters(view) + + def _map_field(self, field): + + # Nested Serializers, `many` or not. + if isinstance(field, serializers.ListSerializer): + return { + 'type': 'array', + 'items': self._map_serializer(field.child) + } + if isinstance(field, serializers.Serializer): + data = self._map_serializer(field) + data['type'] = 'object' + return data + + # Related fields. + if isinstance(field, serializers.ManyRelatedField): + return { + 'type': 'array', + 'items': self._map_field(field.child_relation) + } + if isinstance(field, serializers.PrimaryKeyRelatedField): + model = getattr(field.queryset, 'model', None) + if model is not None: + model_field = model._meta.pk + if isinstance(model_field, models.AutoField): + return {'type': 'integer'} + + # ChoiceFields (single and multiple). + # Q: + # - Is 'type' required? + # - can we determine the TYPE of a choicefield? + if isinstance(field, serializers.MultipleChoiceField): + return { + 'type': 'array', + 'items': { + 'enum': list(field.choices) + }, + } + + if isinstance(field, serializers.ChoiceField): + return { + 'enum': list(field.choices), + } + + # ListField. + if isinstance(field, serializers.ListField): + return { + 'type': 'array', + } + + # Simplest cases, default to 'string' type: + FIELD_CLASS_SCHEMA_TYPE = { + serializers.BooleanField: 'boolean', + serializers.DecimalField: 'number', + serializers.FloatField: 'number', + serializers.IntegerField: 'integer', + serializers.DateField: 'date', + serializers.DateTimeField: 'date-time', + serializers.JSONField: 'object', + serializers.DictField: 'object', + } + return {'type': FIELD_CLASS_SCHEMA_TYPE.get(field.__class__, 'string')} + + def _map_serializer(self, serializer): + # Assuming we have a valid serializer instance. + # TODO: + # - field is Nested or List serializer. + # - Handle read_only/write_only for request/response differences. + # - could do this with readOnly/writeOnly and then filter dict. + required = [] + properties = {} + + for field in serializer.fields.values(): + if isinstance(field, serializers.HiddenField): + continue + + if field.required: + required.append(field.field_name) + + schema = self._map_field(field) + if field.read_only: + schema['readOnly'] = True + if field.write_only: + schema['writeOnly'] = True + if field.allow_null: + schema['nullable'] = True + + properties[field.field_name] = schema + return { + 'required': required, + 'properties': properties, + } + + def _get_request_body(self, path, method): + view = self.view + + if method not in ('PUT', 'PATCH', 'POST'): + return {} + + if not hasattr(view, 'get_serializer'): + return {} + + try: + serializer = view.get_serializer() + except exceptions.APIException: + serializer = None + warnings.warn('{}.get_serializer() raised an exception during ' + 'schema generation. Serializer fields will not be ' + 'generated for {} {}.' + .format(view.__class__.__name__, method, path)) + + if not isinstance(serializer, serializers.Serializer): + return {} + + content = self._map_serializer(serializer) + # No required fields for PATCH + if method == 'PATCH': + del content['required'] + # No read_only fields for request. + for name, schema in content['properties'].copy().items(): + if 'readOnly' in schema: + del content['properties'][name] + + return { + 'content': { + ct: {'schema': content} + for ct in self.content_types + } + } + + def _get_responses(self, path, method): + # TODO: Handle multiple codes. + content = {} + view = self.view + if hasattr(view, 'get_serializer'): + try: + serializer = view.get_serializer() + except exceptions.APIException: + serializer = None + warnings.warn('{}.get_serializer() raised an exception during ' + 'schema generation. Serializer fields will not be ' + 'generated for {} {}.' + .format(view.__class__.__name__, method, path)) + + if isinstance(serializer, serializers.Serializer): + content = self._map_serializer(serializer) + # No write_only fields for response. + for name, schema in content['properties'].copy().items(): + if 'writeOnly' in schema: + del content['properties'][name] + content['required'] = [f for f in content['required'] if f != name] + + return { + '200': { + 'content': { + ct: {'schema': content} + for ct in self.content_types + } + } + } diff --git a/rest_framework/schemas/utils.py b/rest_framework/schemas/utils.py index 76437a20a6..6724eb4289 100644 --- a/rest_framework/schemas/utils.py +++ b/rest_framework/schemas/utils.py @@ -3,6 +3,9 @@ See schemas.__init__.py for package overview. """ +from django.db import models +from django.utils.translation import ugettext_lazy as _ + from rest_framework.mixins import RetrieveModelMixin @@ -22,3 +25,17 @@ def is_list_view(path, method, view): if path_components and '{' in path_components[-1]: return False return True + + +def get_pk_description(model, model_field): + if isinstance(model_field, models.AutoField): + value_type = _('unique integer value') + elif isinstance(model_field, models.UUIDField): + value_type = _('UUID string') + else: + value_type = _('unique value') + + return _('A {value_type} identifying this {name}.').format( + value_type=value_type, + name=model._meta.verbose_name, + ) diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 8db9c81eda..b18dc1457f 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -56,7 +56,7 @@ 'DEFAULT_FILTER_BACKENDS': (), # Schema - 'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema', + 'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.openapi.AutoSchema', # Throttling 'DEFAULT_THROTTLE_RATES': { diff --git a/tests/schemas/__init__.py b/tests/schemas/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_schemas.py b/tests/schemas/test_coreapi.py similarity index 94% rename from tests/test_schemas.py rename to tests/schemas/test_coreapi.py index 3cb9e0cda8..db0d5e4c8f 100644 --- a/tests/test_schemas.py +++ b/tests/schemas/test_coreapi.py @@ -16,15 +16,16 @@ from rest_framework.schemas import ( AutoSchema, ManualSchema, SchemaGenerator, get_schema_view ) +from rest_framework.schemas.coreapi import field_to_schema from rest_framework.schemas.generators import EndpointEnumerator -from rest_framework.schemas.inspectors import field_to_schema from rest_framework.schemas.utils import is_list_view from rest_framework.test import APIClient, APIRequestFactory from rest_framework.utils import formatting from rest_framework.views import APIView from rest_framework.viewsets import GenericViewSet, ModelViewSet -from .models import BasicModel, ForeignKeySource, ManyToManySource +from . import views +from ..models import BasicModel, ForeignKeySource, ManyToManySource factory = APIRequestFactory() @@ -148,7 +149,7 @@ def schema_view(request): @unittest.skipUnless(coreapi, 'coreapi is not installed') -@override_settings(ROOT_URLCONF='tests.test_schemas') +@override_settings(ROOT_URLCONF=__name__, REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) class TestRouterGeneratedSchema(TestCase): def test_anonymous_request(self): client = APIClient() @@ -382,30 +383,14 @@ class MethodLimitedViewSet(ExampleViewSet): http_method_names = ['get', 'head', 'options'] -class ExampleListView(APIView): - permission_classes = [permissions.IsAuthenticatedOrReadOnly] - - def get(self, *args, **kwargs): - pass - - def post(self, request, *args, **kwargs): - pass - - -class ExampleDetailView(APIView): - permission_classes = [permissions.IsAuthenticatedOrReadOnly] - - def get(self, *args, **kwargs): - pass - - @unittest.skipUnless(coreapi, 'coreapi is not installed') +@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) class TestSchemaGenerator(TestCase): def setUp(self): self.patterns = [ - url(r'^example/?$', ExampleListView.as_view()), - url(r'^example/(?P\d+)/?$', ExampleDetailView.as_view()), - url(r'^example/(?P\d+)/sub/?$', ExampleDetailView.as_view()), + url(r'^example/?$', views.ExampleListView.as_view()), + url(r'^example/(?P\d+)/?$', views.ExampleDetailView.as_view()), + url(r'^example/(?P\d+)/sub/?$', views.ExampleDetailView.as_view()), ] def test_schema_for_regular_views(self): @@ -453,12 +438,13 @@ def test_schema_for_regular_views(self): @unittest.skipUnless(coreapi, 'coreapi is not installed') @unittest.skipUnless(path, 'needs Django 2') +@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) class TestSchemaGeneratorDjango2(TestCase): def setUp(self): self.patterns = [ - path('example/', ExampleListView.as_view()), - path('example//', ExampleDetailView.as_view()), - path('example//sub/', ExampleDetailView.as_view()), + path('example/', views.ExampleListView.as_view()), + path('example//', views.ExampleDetailView.as_view()), + path('example//sub/', views.ExampleDetailView.as_view()), ] def test_schema_for_regular_views(self): @@ -505,12 +491,13 @@ def test_schema_for_regular_views(self): @unittest.skipUnless(coreapi, 'coreapi is not installed') +@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) class TestSchemaGeneratorNotAtRoot(TestCase): def setUp(self): self.patterns = [ - url(r'^api/v1/example/?$', ExampleListView.as_view()), - url(r'^api/v1/example/(?P\d+)/?$', ExampleDetailView.as_view()), - url(r'^api/v1/example/(?P\d+)/sub/?$', ExampleDetailView.as_view()), + url(r'^api/v1/example/?$', views.ExampleListView.as_view()), + url(r'^api/v1/example/(?P\d+)/?$', views.ExampleDetailView.as_view()), + url(r'^api/v1/example/(?P\d+)/sub/?$', views.ExampleDetailView.as_view()), ] def test_schema_for_regular_views(self): @@ -558,6 +545,7 @@ def test_schema_for_regular_views(self): @unittest.skipUnless(coreapi, 'coreapi is not installed') +@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) class TestSchemaGeneratorWithMethodLimitedViewSets(TestCase): def setUp(self): router = DefaultRouter() @@ -622,13 +610,14 @@ def test_schema_for_regular_views(self): @unittest.skipUnless(coreapi, 'coreapi is not installed') +@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) class TestSchemaGeneratorWithRestrictedViewSets(TestCase): def setUp(self): router = DefaultRouter() router.register('example1', Http404ExampleViewSet, basename='example1') router.register('example2', PermissionDeniedExampleViewSet, basename='example2') self.patterns = [ - url('^example/?$', ExampleListView.as_view()), + url('^example/?$', views.ExampleListView.as_view()), url(r'^', include(router.urls)) ] @@ -668,6 +657,7 @@ class ForeignKeySourceView(generics.CreateAPIView): @unittest.skipUnless(coreapi, 'coreapi is not installed') +@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) class TestSchemaGeneratorWithForeignKey(TestCase): def setUp(self): self.patterns = [ @@ -713,6 +703,7 @@ class ManyToManySourceView(generics.CreateAPIView): @unittest.skipUnless(coreapi, 'coreapi is not installed') +@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) class TestSchemaGeneratorWithManyToMany(TestCase): def setUp(self): self.patterns = [ @@ -747,6 +738,7 @@ def test_schema_for_regular_views(self): @unittest.skipUnless(coreapi, 'coreapi is not installed') +@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) class Test4605Regression(TestCase): def test_4605_regression(self): generator = SchemaGenerator() @@ -762,6 +754,7 @@ class CustomViewInspector(AutoSchema): pass +@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) class TestAutoSchema(TestCase): def test_apiview_schema_descriptor(self): @@ -777,7 +770,7 @@ class CustomView(APIView): assert isinstance(view.schema, CustomViewInspector) def test_set_custom_inspector_class_via_settings(self): - with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'tests.test_schemas.CustomViewInspector'}): + with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'tests.schemas.test_coreapi.CustomViewInspector'}): view = APIView() assert isinstance(view.schema, CustomViewInspector) @@ -971,6 +964,7 @@ def test_field_to_schema(self): self.assertEqual(field_to_schema(case[0]), case[1]) +@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) def test_docstring_is_not_stripped_by_get_description(): class ExampleDocstringAPIView(APIView): """ @@ -1014,20 +1008,19 @@ def get(self, request, *args, **kwargs): pass -@api_view(['GET']) -@schema(None) -def excluded_fbv(request): - pass - - -@api_view(['GET']) -def included_fbv(request): - pass - - @unittest.skipUnless(coreapi, 'coreapi is not installed') +@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) class SchemaGenerationExclusionTests(TestCase): def setUp(self): + @api_view(['GET']) + @schema(None) + def excluded_fbv(request): + pass + + @api_view(['GET']) + def included_fbv(request): + pass + self.patterns = [ url('^excluded-cbv/$', ExcludedAPIView.as_view()), url('^excluded-fbv/$', excluded_fbv), @@ -1078,11 +1071,6 @@ def test_should_include_endpoint_excludes_correctly(self): assert should_include == expected -@api_view(["GET"]) -def simple_fbv(request): - pass - - class BasicModelSerializer(serializers.ModelSerializer): class Meta: model = BasicModel @@ -1118,11 +1106,16 @@ def detail_export(self, request): @pytest.mark.skipif(not coreapi, reason='coreapi is not installed') +@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) class TestURLNamingCollisions(TestCase): """ Ref: https://github.com/encode/django-rest-framework/issues/4704 """ def test_manually_routing_nested_routes(self): + @api_view(["GET"]) + def simple_fbv(request): + pass + patterns = [ url(r'^test', simple_fbv), url(r'^test/list/', simple_fbv), @@ -1228,6 +1221,10 @@ def test_url_under_same_key_not_replaced(self): def test_url_under_same_key_not_replaced_another(self): + @api_view(["GET"]) + def simple_fbv(request): + pass + patterns = [ url(r'^test/list/', simple_fbv), url(r'^test/(?P\d+)/list/', simple_fbv), @@ -1303,7 +1300,8 @@ def custom_action(self, request, pk): @pytest.mark.skipif(not coreapi, reason='coreapi is not installed') -class TestAutoSchemaAllowsFilters(object): +@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) +class TestAutoSchemaAllowsFilters(TestCase): class MockAPIView(APIView): filter_backends = [filters.OrderingFilter] diff --git a/tests/schemas/test_managementcommand.py b/tests/schemas/test_managementcommand.py new file mode 100644 index 0000000000..a4d038e55a --- /dev/null +++ b/tests/schemas/test_managementcommand.py @@ -0,0 +1,51 @@ +from __future__ import unicode_literals + +import pytest +from django.conf.urls import url +from django.core.management import call_command +from django.test import TestCase +from django.test.utils import override_settings +from django.utils import six + +from rest_framework.compat import uritemplate, yaml +from rest_framework.utils import json +from rest_framework.views import APIView + + +class FooView(APIView): + def get(self, request): + pass + + +urlpatterns = [ + url(r'^$', FooView.as_view()) +] + + +@override_settings(ROOT_URLCONF=__name__) +@pytest.mark.skipif(not uritemplate, reason='uritemplate is not installed') +class GenerateSchemaTests(TestCase): + """Tests for management command generateschema.""" + + def setUp(self): + self.out = six.StringIO() + + @pytest.mark.skipif(six.PY2, reason='PyYAML unicode output is malformed on PY2.') + @pytest.mark.skipif(yaml is None, reason='PyYAML is required.') + def test_renders_default_schema_with_custom_title_url_and_description(self): + call_command('generateschema', + '--title=SampleAPI', + '--url=http://api.sample.com', + '--description=Sample description', + stdout=self.out) + # Check valid YAML was output. + schema = yaml.load(self.out.getvalue()) + assert schema['openapi'] == '3.0.2' + + def test_renders_openapi_json_schema(self): + call_command('generateschema', + '--format=openapi-json', + stdout=self.out) + # Check valid JSON was output. + out_json = json.loads(self.out.getvalue()) + assert out_json['openapi'] == '3.0.2' diff --git a/tests/schemas/test_openapi.py b/tests/schemas/test_openapi.py new file mode 100644 index 0000000000..479b6405f5 --- /dev/null +++ b/tests/schemas/test_openapi.py @@ -0,0 +1,198 @@ +import pytest +from django.conf.urls import url +from django.test import RequestFactory, TestCase, override_settings + +from rest_framework import filters, generics, pagination, serializers +from rest_framework.compat import uritemplate +from rest_framework.request import Request +from rest_framework.schemas.openapi import AutoSchema, SchemaGenerator + +from . import views + + +def create_request(path): + factory = RequestFactory() + request = Request(factory.get(path)) + return request + + +def create_view(view_cls, method, request): + generator = SchemaGenerator() + view = generator.create_view(view_cls.as_view(), method, request) + return view + + +class TestBasics(TestCase): + def dummy_view(request): + pass + + def test_filters(self): + classes = [filters.SearchFilter, filters.OrderingFilter] + for c in classes: + f = c() + assert f.get_schema_operation_parameters(self.dummy_view) + + def test_pagination(self): + classes = [pagination.PageNumberPagination, pagination.LimitOffsetPagination, pagination.CursorPagination] + for c in classes: + f = c() + assert f.get_schema_operation_parameters(self.dummy_view) + + +@pytest.mark.skipif(uritemplate is None, reason='uritemplate not installed.') +class TestOperationIntrospection(TestCase): + + def test_path_without_parameters(self): + path = '/example/' + method = 'GET' + + view = create_view( + views.ExampleListView, + method, + create_request(path) + ) + inspector = AutoSchema() + inspector.view = view + + operation = inspector.get_operation(path, method) + assert operation == { + 'operationId': 'ListExamples', + 'parameters': [], + 'responses': {'200': {'content': {'application/json': {'schema': {}}}}}, + } + + def test_path_with_id_parameter(self): + path = '/example/{id}/' + method = 'GET' + + view = create_view( + views.ExampleDetailView, + method, + create_request(path) + ) + inspector = AutoSchema() + inspector.view = view + + parameters = inspector._get_path_parameters(path, method) + assert parameters == [{ + 'description': '', + 'in': 'path', + 'name': 'id', + 'required': True, + 'schema': { + 'type': 'string', + }, + }] + + def test_request_body(self): + path = '/' + method = 'POST' + + class Serializer(serializers.Serializer): + text = serializers.CharField() + read_only = serializers.CharField(read_only=True) + + class View(generics.GenericAPIView): + serializer_class = Serializer + + view = create_view( + View, + method, + create_request(path) + ) + inspector = AutoSchema() + inspector.view = view + + request_body = inspector._get_request_body(path, method) + assert request_body['content']['application/json']['schema']['required'] == ['text'] + assert list(request_body['content']['application/json']['schema']['properties'].keys()) == ['text'] + + def test_response_body_generation(self): + path = '/' + method = 'POST' + + class Serializer(serializers.Serializer): + text = serializers.CharField() + write_only = serializers.CharField(write_only=True) + + class View(generics.GenericAPIView): + serializer_class = Serializer + + view = create_view( + View, + method, + create_request(path) + ) + inspector = AutoSchema() + inspector.view = view + + responses = inspector._get_responses(path, method) + assert responses['200']['content']['application/json']['schema']['required'] == ['text'] + assert list(responses['200']['content']['application/json']['schema']['properties'].keys()) == ['text'] + + def test_response_body_nested_serializer(self): + path = '/' + method = 'POST' + + class NestedSerializer(serializers.Serializer): + number = serializers.IntegerField() + + class Serializer(serializers.Serializer): + text = serializers.CharField() + nested = NestedSerializer() + + class View(generics.GenericAPIView): + serializer_class = Serializer + + view = create_view( + View, + method, + create_request(path), + ) + inspector = AutoSchema() + inspector.view = view + + responses = inspector._get_responses(path, method) + schema = responses['200']['content']['application/json']['schema'] + assert sorted(schema['required']) == ['nested', 'text'] + assert sorted(list(schema['properties'].keys())) == ['nested', 'text'] + assert schema['properties']['nested']['type'] == 'object' + assert list(schema['properties']['nested']['properties'].keys()) == ['number'] + assert schema['properties']['nested']['required'] == ['number'] + + +@pytest.mark.skipif(uritemplate is None, reason='uritemplate not installed.') +@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.openapi.AutoSchema'}) +class TestGenerator(TestCase): + + def test_override_settings(self): + assert isinstance(views.ExampleListView.schema, AutoSchema) + + def test_paths_construction(self): + """Construction of the `paths` key.""" + patterns = [ + url(r'^example/?$', views.ExampleListView.as_view()), + ] + generator = SchemaGenerator(patterns=patterns) + generator._initialise_endpoints() + + paths = generator.get_paths() + + assert '/example/' in paths + example_operations = paths['/example/'] + assert len(example_operations) == 2 + assert 'get' in example_operations + assert 'post' in example_operations + + def test_schema_construction(self): + """Construction of the top level dictionary.""" + patterns = [ + url(r'^example/?$', views.ExampleListView.as_view()), + ] + generator = SchemaGenerator(patterns=patterns) + + request = create_request('/') + schema = generator.get_schema(request=request) + + assert 'openapi' in schema + assert 'paths' in schema diff --git a/tests/schemas/views.py b/tests/schemas/views.py new file mode 100644 index 0000000000..c368ba7e56 --- /dev/null +++ b/tests/schemas/views.py @@ -0,0 +1,19 @@ +from rest_framework import permissions +from rest_framework.views import APIView + + +class ExampleListView(APIView): + permission_classes = [permissions.IsAuthenticatedOrReadOnly] + + def get(self, *args, **kwargs): + pass + + def post(self, request, *args, **kwargs): + pass + + +class ExampleDetailView(APIView): + permission_classes = [permissions.IsAuthenticatedOrReadOnly] + + def get(self, *args, **kwargs): + pass diff --git a/tests/test_generateschema.py b/tests/test_generateschema.py deleted file mode 100644 index 915c6ea059..0000000000 --- a/tests/test_generateschema.py +++ /dev/null @@ -1,88 +0,0 @@ -from __future__ import unicode_literals - -import pytest -from django.conf.urls import url -from django.core.management import call_command -from django.test import TestCase -from django.test.utils import override_settings -from django.utils import six - -from rest_framework.compat import coreapi -from rest_framework.utils import formatting, json -from rest_framework.views import APIView - - -class FooView(APIView): - def get(self, request): - pass - - -urlpatterns = [ - url(r'^$', FooView.as_view()) -] - - -@override_settings(ROOT_URLCONF='tests.test_generateschema') -@pytest.mark.skipif(not coreapi, reason='coreapi is not installed') -class GenerateSchemaTests(TestCase): - """Tests for management command generateschema.""" - - def setUp(self): - self.out = six.StringIO() - - @pytest.mark.skipif(six.PY2, reason='PyYAML unicode output is malformed on PY2.') - def test_renders_default_schema_with_custom_title_url_and_description(self): - expected_out = """info: - description: Sample description - title: SampleAPI - version: '' - openapi: 3.0.0 - paths: - /: - get: - operationId: list - servers: - - url: http://api.sample.com/ - """ - call_command('generateschema', - '--title=SampleAPI', - '--url=http://api.sample.com', - '--description=Sample description', - stdout=self.out) - - self.assertIn(formatting.dedent(expected_out), self.out.getvalue()) - - def test_renders_openapi_json_schema(self): - expected_out = { - "openapi": "3.0.0", - "info": { - "version": "", - "title": "", - "description": "" - }, - "servers": [ - { - "url": "" - } - ], - "paths": { - "/": { - "get": { - "operationId": "list" - } - } - } - } - call_command('generateschema', - '--format=openapi-json', - stdout=self.out) - out_json = json.loads(self.out.getvalue()) - - self.assertDictEqual(out_json, expected_out) - - def test_renders_corejson_schema(self): - expected_out = """{"_type":"document","":{"list":{"_type":"link","url":"/","action":"get"}}}""" - call_command('generateschema', - '--format=corejson', - stdout=self.out) - self.assertIn(expected_out, self.out.getvalue())