From 0d0b8f6ce8afd0e64fa8a2530e850999b086b179 Mon Sep 17 00:00:00 2001 From: Sergey Vilgelm Date: Sat, 15 Feb 2020 19:09:42 -0600 Subject: [PATCH 1/2] Remove flake8 validation and add black instead Add a `checks` job in workflow to do a preliminary checks --- .github/workflows/test.yml | 30 +++++++++++++++++++++++++++--- CONTRIBUTING.rst | 2 +- README.rst | 7 +++++-- requirements/develop.pip | 5 +---- requirements/test.pip | 4 ++++ setup.cfg | 5 ----- tasks.py | 10 ++-------- 7 files changed, 40 insertions(+), 23 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d5605758..3a8611c5 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -9,7 +9,31 @@ on: schedule: - cron: "0 1 * * *" jobs: - test: + checks: + runs-on: ubuntu-latest + steps: + - name: Set up Python 3.8 + uses: actions/setup-python@v1 + with: + python-version: 3.8 + - name: Checkout code + uses: actions/checkout@v2 + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install ".[dev]" + - name: Format with black + run: | + if ! RES=$(black --check $(git diff --name-only --diff-filter=AM master -- "*.py") 2>&1); then + RES="${RES//'%'/'%25'}" + RES="${RES//$'\n'/'%0A'}" + RES="${RES//$'\r'/'%0D'}" + echo "::error ::${RES}" + exit 1 + fi + echo ${RES} + unit-tests: + needs: checks runs-on: ubuntu-latest strategy: matrix: @@ -24,7 +48,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install ".[dev]" + pip install ".[test]" - name: Test with inv run: inv cover qa - name: Coveralls @@ -35,7 +59,7 @@ jobs: env: COVERALLS_REPO_TOKEN: ${{ secrets.COVERALLS_REPO_TOKEN }} bench: - needs: test + needs: unit-tests runs-on: ubuntu-latest if: github.event_name == 'pull_request' steps: diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index ddcdc1e9..3bd682c4 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -41,7 +41,7 @@ There are some rules to follow: - your contribution should be documented (if needed) - your contribution should be tested and the test suite should pass successfully -- your code should be mostly PEP8 compatible with a 120 characters line length +- your code should be properly formatted (use ``black .`` to format) - your contribution should support both Python 2 and 3 (use ``tox`` to test) You need to install some dependencies to develop on flask-restx: diff --git a/README.rst b/README.rst index ae73c7f9..67e8348e 100644 --- a/README.rst +++ b/README.rst @@ -18,8 +18,11 @@ Flask RESTX :target: https://pypi.org/project/flask-restx :alt: Supported Python versions .. image:: https://badges.gitter.im/Join%20Chat.svg - :alt: Join the chat at https://gitter.im/python-restx - :target: https://gitter.im/python-restx?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge + :target: https://gitter.im/python-restx?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge + :alt: Join the chat at https://gitter.im/python-restx +.. image:: https://img.shields.io/badge/code%20style-black-000000.svg + :target: https://github.com/psf/black + :alt: Code style: black Flask-RESTX is a community driven fork of `Flask-RESTPlus `_. diff --git a/requirements/develop.pip b/requirements/develop.pip index d9a65bab..eafc9a28 100644 --- a/requirements/develop.pip +++ b/requirements/develop.pip @@ -1,5 +1,2 @@ -invoke==1.3.0 -flake8==3.7.8 -readme-renderer==24.0 tox==3.13.2 -twine==1.15.0 +black==19.10b0; python_version >= '3.6' diff --git a/requirements/test.pip b/requirements/test.pip index 2fe326ad..07caade6 100644 --- a/requirements/test.pip +++ b/requirements/test.pip @@ -10,3 +10,7 @@ pytest-mock==1.10.4 pytest-profiling==1.7.0 pytest-sugar==0.9.2 tzlocal + +invoke==1.3.0 +readme-renderer==24.0 +twine==1.15.0 diff --git a/setup.cfg b/setup.cfg index 7e0aca7c..682d1e35 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,8 +1,3 @@ -[flake8] -ignore = E124,E128 -max-line-length = 120 -exclude = doc,.git - [bdist_wheel] universal = 1 diff --git a/tasks.py b/tasks.py index 2bd8a794..24efaa7b 100644 --- a/tasks.py +++ b/tasks.py @@ -150,12 +150,6 @@ def qa(ctx): '''Run a quality report''' header(qa.__doc__) with ctx.cd(ROOT): - info('Python Static Analysis') - flake8_results = ctx.run('flake8 flask_restx tests', pty=True, warn=True) - if flake8_results.failed: - error('There is some lints to fix') - else: - success('No linter errors') info('Ensure PyPI can render README and CHANGELOG') info('Building dist package') dist = ctx.run('python setup.py sdist', pty=True, warn=False, hide=True) @@ -168,8 +162,8 @@ def qa(ctx): error('README and/or CHANGELOG is not renderable by PyPI') else: success('README and CHANGELOG are renderable by PyPI') - if flake8_results.failed or readme_results.failed: - exit('Quality check failed', flake8_results.return_code or readme_results.return_code) + if readme_results.failed: + exit('Quality check failed', readme_results.return_code) success('Quality check OK') From b6b92d55da1ae7bbe828d353d2875dd9c0ab8c96 Mon Sep 17 00:00:00 2001 From: Sergey Vilgelm Date: Sat, 15 Feb 2020 15:29:33 -0600 Subject: [PATCH 2/2] Using black formatting Run `black .` to format the current code with black --- doc/conf.py | 258 +- examples/todo.py | 72 +- examples/todo_blueprint.py | 74 +- examples/todo_simple.py | 9 +- examples/todomvc.py | 65 +- examples/xml_representation.py | 20 +- examples/zoo/__init__.py | 6 +- examples/zoo/cat.py | 33 +- examples/zoo/dog.py | 33 +- flask_restx/__about__.py | 6 +- flask_restx/__init__.py | 44 +- flask_restx/_http.py | 245 +- flask_restx/api.py | 420 +-- flask_restx/apidoc.py | 22 +- flask_restx/cors.py | 40 +- flask_restx/errors.py | 23 +- flask_restx/fields.py | 399 ++- flask_restx/inputs.py | 266 +- flask_restx/marshalling.py | 44 +- flask_restx/mask.py | 82 +- flask_restx/model.py | 123 +- flask_restx/namespace.py | 182 +- flask_restx/postman.py | 169 +- flask_restx/representations.py | 6 +- flask_restx/reqparse.py | 221 +- flask_restx/resource.py | 24 +- flask_restx/schemas/__init__.py | 45 +- flask_restx/swagger.py | 522 ++-- flask_restx/utils.py | 43 +- setup.py | 105 +- tasks.py | 174 +- tests/benchmarks/bench_marshalling.py | 28 +- tests/benchmarks/bench_swagger.py | 58 +- tests/conftest.py | 40 +- tests/legacy/test_api_legacy.py | 201 +- tests/legacy/test_api_with_blueprint.py | 127 +- tests/test_accept.py | 102 +- tests/test_api.py | 227 +- tests/test_apidoc.py | 120 +- tests/test_cors.py | 39 +- tests/test_errors.py | 373 ++- tests/test_fields.py | 1162 ++++---- tests/test_fields_mask.py | 1058 +++---- tests/test_inputs.py | 1251 ++++---- tests/test_logging.py | 32 +- tests/test_marshalling.py | 571 ++-- tests/test_model.py | 757 ++--- tests/test_namespace.py | 142 +- tests/test_payload.py | 285 +- tests/test_postman.py | 313 +- tests/test_reqparse.py | 915 +++--- tests/test_schemas.py | 32 +- tests/test_swagger.py | 3622 +++++++++++------------ tests/test_swagger_utils.py | 173 +- tests/test_utils.py | 114 +- 55 files changed, 7819 insertions(+), 7698 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index 2e8c49d5..d243374b 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -19,269 +19,274 @@ # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. -sys.path.insert(0, os.path.abspath('..')) +sys.path.insert(0, os.path.abspath("..")) # -- General configuration ------------------------------------------------ # If your documentation needs a minimal Sphinx version, state it here. -#needs_sphinx = '1.0' +# needs_sphinx = '1.0' # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.viewcode', - 'sphinx.ext.intersphinx', - 'sphinx.ext.todo', - 'sphinx_issues', - 'alabaster', + "sphinx.ext.autodoc", + "sphinx.ext.viewcode", + "sphinx.ext.intersphinx", + "sphinx.ext.todo", + "sphinx_issues", + "alabaster", ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix of source filenames. -source_suffix = '.rst' +source_suffix = ".rst" # The encoding of source files. -#source_encoding = 'utf-8-sig' +# source_encoding = 'utf-8-sig' # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. -project = u'Flask-RESTX' -copyright = u'2014, Axel Haustant' +project = u"Flask-RESTX" +copyright = u"2014, Axel Haustant" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the # built documents. # # The full version, including alpha/beta/rc tags. -release = __import__('flask_restx').__version__ +release = __import__("flask_restx").__version__ # The short X.Y version. -version = '.'.join(release.split('.')[:1]) +version = ".".join(release.split(".")[:1]) # Github repo -issues_github_path = 'python-restx/flask-restx' +issues_github_path = "python-restx/flask-restx" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. -#language = None +# language = None # There are two options for replacing |today|: either, you set today to some # non-false value, then it is used: -#today = '' +# today = '' # Else, today_fmt is used as the format for a strftime call. -#today_fmt = '%B %d, %Y' +# today_fmt = '%B %d, %Y' # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. -exclude_patterns = ['_build'] +exclude_patterns = ["_build"] # The reST default role (used for this markup: `text`) to use for all # documents. -#default_role = None +# default_role = None # If true, '()' will be appended to :func: etc. cross-reference text. -#add_function_parentheses = True +# add_function_parentheses = True # If true, the current module name will be prepended to all description # unit titles (such as .. function::). -#add_module_names = True +# add_module_names = True # If true, sectionauthor and moduleauthor directives will be shown in the # output. They are ignored by default. -#show_authors = False +# show_authors = False # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # A list of ignored prefixes for module index sorting. -#modindex_common_prefix = [] +# modindex_common_prefix = [] # If true, keep warnings as "system message" paragraphs in the built documents. -#keep_warnings = False +# keep_warnings = False # -- Options for HTML output ---------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. -html_theme = 'restx' +html_theme = "restx" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. html_theme_options = { - 'logo': 'logo-512.png', - 'logo_name': True, - 'touch_icon': 'apple-180.png', - 'github_user': 'python-restx', - 'github_repo': 'flask-restx', - 'github_banner': True, - 'show_related': True, - 'page_width': '1000px', - 'sidebar_width': '260px', - 'favicons': { - 64: 'favicon-64.png', - 128: 'favicon-128.png', - 196: 'favicon-196.png', - }, - 'badges': [( - # Gitter.im - 'https://badges.gitter.im/Join%20Chat.svg', - 'https://gitter.im/python-restx', - 'Join the chat at https://gitter.im/python-restx' - ), ( - # Github Fork - 'https://img.shields.io/github/forks/python-restx/flask-restx.svg?style=social&label=Fork', - 'https://github.com/python-restx/flask-restx', - 'Github repository', - ), ( - # Github issues - 'https://img.shields.io/github/issues-raw/python-restx/flask-restx.svg', - 'https://github.com/python-restx/flask-restx/issues', - 'Github repository', - ), ( - # License - 'https://img.shields.io/github/license/python-restx/flask-restx.svg', - 'https://github.com/python-restx/flask-restx', - 'License', - ), ( - # PyPI - 'https://img.shields.io/pypi/v/flask-restx.svg', - 'https://pypi.python.org/pypi/flask-restx', - 'Latest version on PyPI' - )] + "logo": "logo-512.png", + "logo_name": True, + "touch_icon": "apple-180.png", + "github_user": "python-restx", + "github_repo": "flask-restx", + "github_banner": True, + "show_related": True, + "page_width": "1000px", + "sidebar_width": "260px", + "favicons": {64: "favicon-64.png", 128: "favicon-128.png", 196: "favicon-196.png",}, + "badges": [ + ( + # Gitter.im + "https://badges.gitter.im/Join%20Chat.svg", + "https://gitter.im/python-restx", + "Join the chat at https://gitter.im/python-restx", + ), + ( + # Github Fork + "https://img.shields.io/github/forks/python-restx/flask-restx.svg?style=social&label=Fork", + "https://github.com/python-restx/flask-restx", + "Github repository", + ), + ( + # Github issues + "https://img.shields.io/github/issues-raw/python-restx/flask-restx.svg", + "https://github.com/python-restx/flask-restx/issues", + "Github repository", + ), + ( + # License + "https://img.shields.io/github/license/python-restx/flask-restx.svg", + "https://github.com/python-restx/flask-restx", + "License", + ), + ( + # PyPI + "https://img.shields.io/pypi/v/flask-restx.svg", + "https://pypi.python.org/pypi/flask-restx", + "Latest version on PyPI", + ), + ], } # Add any paths that contain custom themes here, relative to this directory. -html_theme_path = [alabaster.get_path(), '_themes'] +html_theme_path = [alabaster.get_path(), "_themes"] html_context = {} # The name for this set of Sphinx documents. If None, it defaults to # " v documentation". -#html_title = None +# html_title = None # A shorter title for the navigation bar. Default is the same as html_title. -#html_short_title = None +# html_short_title = None # The name of an image file (relative to this directory) to place at the top # of the sidebar. -#html_logo = None +# html_logo = None # The name of an image file (within the static path) to use as favicon of the # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 # pixels large. -html_favicon = '_static/favicon.ico' +html_favicon = "_static/favicon.ico" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # Add any extra paths that contain custom files (such as robots.txt or # .htaccess) here, relative to this directory. These files are copied # directly to the root of the documentation. -#html_extra_path = [] +# html_extra_path = [] # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, # using the given strftime format. -#html_last_updated_fmt = '%b %d, %Y' +# html_last_updated_fmt = '%b %d, %Y' # If true, SmartyPants will be used to convert quotes and dashes to # typographically correct entities. -#html_use_smartypants = True +# html_use_smartypants = True # Custom sidebar templates, maps document names to template names. html_sidebars = { - '**': [ - 'about.html', - 'navigation.html', - 'relations.html', - 'searchbox.html', - 'donate.html', - 'badges.html', + "**": [ + "about.html", + "navigation.html", + "relations.html", + "searchbox.html", + "donate.html", + "badges.html", ] } # Additional templates that should be rendered to pages, maps page names to # template names. -#html_additional_pages = {} +# html_additional_pages = {} # If false, no module index is generated. -#html_domain_indices = True +# html_domain_indices = True # If false, no index is generated. -#html_use_index = True +# html_use_index = True # If true, the index is split into individual pages for each letter. -#html_split_index = False +# html_split_index = False # If true, links to the reST sources are added to the pages. -#html_show_sourcelink = True +# html_show_sourcelink = True # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. -#html_show_sphinx = True +# html_show_sphinx = True # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. -#html_show_copyright = True +# html_show_copyright = True # If true, an OpenSearch description file will be output, and all pages will # contain a tag referring to it. The value of this option must be the # base URL from which the finished HTML is served. -#html_use_opensearch = '' +# html_use_opensearch = '' # This is the file name suffix for HTML files (e.g. ".xhtml"). -#html_file_suffix = None +# html_file_suffix = None # Output file base name for HTML help builder. -htmlhelp_basename = 'Flask-RESTXdoc' +htmlhelp_basename = "Flask-RESTXdoc" # -- Options for LaTeX output --------------------------------------------- latex_elements = { -# The paper size ('letterpaper' or 'a4paper'). -#'papersize': 'letterpaper', - -# The font size ('10pt', '11pt' or '12pt'). -#'pointsize': '10pt', - -# Additional stuff for the LaTeX preamble. -#'preamble': '', + # The paper size ('letterpaper' or 'a4paper'). + #'papersize': 'letterpaper', + # The font size ('10pt', '11pt' or '12pt'). + #'pointsize': '10pt', + # Additional stuff for the LaTeX preamble. + #'preamble': '', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - ('index', 'Flask-RESTX.tex', u'Flask-RESTX Documentation', - u'Axel Haustant', 'manual'), + ( + "index", + "Flask-RESTX.tex", + u"Flask-RESTX Documentation", + u"Axel Haustant", + "manual", + ), ] # The name of an image file (relative to this directory) to place at the top of # the title page. -#latex_logo = None +# latex_logo = None # For "manual" documents, if this is true, then toplevel headings are parts, # not chapters. -#latex_use_parts = False +# latex_use_parts = False # If true, show page references after internal links. -#latex_show_pagerefs = False +# latex_show_pagerefs = False # If true, show URL addresses after external links. -#latex_show_urls = False +# latex_show_urls = False # Documents to append as an appendix to all manuals. -#latex_appendices = [] +# latex_appendices = [] # If false, no module index is generated. -#latex_domain_indices = True +# latex_domain_indices = True # -- Options for manual page output --------------------------------------- @@ -289,12 +294,11 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [ - ('index', 'flask-restx', u'Flask-RESTX Documentation', - [u'Axel Haustant'], 1) + ("index", "flask-restx", u"Flask-RESTX Documentation", [u"Axel Haustant"], 1) ] # If true, show URL addresses after external links. -#man_show_urls = False +# man_show_urls = False # -- Options for Texinfo output ------------------------------------------- @@ -303,26 +307,32 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - ('index', 'Flask-RESTX', u'Flask-RESTX Documentation', - u'Axel Haustant', 'Flask-RESTX', 'One line description of project.', - 'Miscellaneous'), + ( + "index", + "Flask-RESTX", + u"Flask-RESTX Documentation", + u"Axel Haustant", + "Flask-RESTX", + "One line description of project.", + "Miscellaneous", + ), ] # Documents to append as an appendix to all manuals. -#texinfo_appendices = [] +# texinfo_appendices = [] # If false, no module index is generated. -#texinfo_domain_indices = True +# texinfo_domain_indices = True # How to display URL addresses: 'footnote', 'no', or 'inline'. -#texinfo_show_urls = 'footnote' +# texinfo_show_urls = 'footnote' # If true, do not generate a @detailmenu in the "Top" node's menu. -#texinfo_no_detailmenu = False +# texinfo_no_detailmenu = False intersphinx_mapping = { - 'flask': ('http://flask.pocoo.org/docs/', None), - 'python': ('http://docs.python.org/', None), - 'werkzeug': ('http://werkzeug.pocoo.org/docs/', None), + "flask": ("http://flask.pocoo.org/docs/", None), + "python": ("http://docs.python.org/", None), + "werkzeug": ("http://werkzeug.pocoo.org/docs/", None), } diff --git a/examples/todo.py b/examples/todo.py index f85f0ddf..d0edbeb7 100644 --- a/examples/todo.py +++ b/examples/todo.py @@ -4,81 +4,87 @@ app = Flask(__name__) app.wsgi_app = ProxyFix(app.wsgi_app) -api = Api(app, version='1.0', title='Todo API', - description='A simple TODO API', -) +api = Api(app, version="1.0", title="Todo API", description="A simple TODO API",) -ns = api.namespace('todos', description='TODO operations') +ns = api.namespace("todos", description="TODO operations") TODOS = { - 'todo1': {'task': 'build an API'}, - 'todo2': {'task': '?????'}, - 'todo3': {'task': 'profit!'}, + "todo1": {"task": "build an API"}, + "todo2": {"task": "?????"}, + "todo3": {"task": "profit!"}, } -todo = api.model('Todo', { - 'task': fields.String(required=True, description='The task details') -}) +todo = api.model( + "Todo", {"task": fields.String(required=True, description="The task details")} +) -listed_todo = api.model('ListedTodo', { - 'id': fields.String(required=True, description='The todo ID'), - 'todo': fields.Nested(todo, description='The Todo') -}) +listed_todo = api.model( + "ListedTodo", + { + "id": fields.String(required=True, description="The todo ID"), + "todo": fields.Nested(todo, description="The Todo"), + }, +) def abort_if_todo_doesnt_exist(todo_id): if todo_id not in TODOS: api.abort(404, "Todo {} doesn't exist".format(todo_id)) + parser = api.parser() -parser.add_argument('task', type=str, required=True, help='The task details', location='form') +parser.add_argument( + "task", type=str, required=True, help="The task details", location="form" +) -@ns.route('/') -@api.doc(responses={404: 'Todo not found'}, params={'todo_id': 'The Todo ID'}) +@ns.route("/") +@api.doc(responses={404: "Todo not found"}, params={"todo_id": "The Todo ID"}) class Todo(Resource): - '''Show a single todo item and lets you delete them''' - @api.doc(description='todo_id should be in {0}'.format(', '.join(TODOS.keys()))) + """Show a single todo item and lets you delete them""" + + @api.doc(description="todo_id should be in {0}".format(", ".join(TODOS.keys()))) @api.marshal_with(todo) def get(self, todo_id): - '''Fetch a given resource''' + """Fetch a given resource""" abort_if_todo_doesnt_exist(todo_id) return TODOS[todo_id] - @api.doc(responses={204: 'Todo deleted'}) + @api.doc(responses={204: "Todo deleted"}) def delete(self, todo_id): - '''Delete a given resource''' + """Delete a given resource""" abort_if_todo_doesnt_exist(todo_id) del TODOS[todo_id] - return '', 204 + return "", 204 @api.doc(parser=parser) @api.marshal_with(todo) def put(self, todo_id): - '''Update a given resource''' + """Update a given resource""" args = parser.parse_args() - task = {'task': args['task']} + task = {"task": args["task"]} TODOS[todo_id] = task return task -@ns.route('/') +@ns.route("/") class TodoList(Resource): - '''Shows a list of all todos, and lets you POST to add new tasks''' + """Shows a list of all todos, and lets you POST to add new tasks""" + @api.marshal_list_with(listed_todo) def get(self): - '''List all todos''' - return [{'id': id, 'todo': todo} for id, todo in TODOS.items()] + """List all todos""" + return [{"id": id, "todo": todo} for id, todo in TODOS.items()] @api.doc(parser=parser) @api.marshal_with(todo, code=201) def post(self): - '''Create a todo''' + """Create a todo""" args = parser.parse_args() - todo_id = 'todo%d' % (len(TODOS) + 1) - TODOS[todo_id] = {'task': args['task']} + todo_id = "todo%d" % (len(TODOS) + 1) + TODOS[todo_id] = {"task": args["task"]} return TODOS[todo_id], 201 -if __name__ == '__main__': +if __name__ == "__main__": app.run(debug=True) diff --git a/examples/todo_blueprint.py b/examples/todo_blueprint.py index 1096346f..ce860366 100644 --- a/examples/todo_blueprint.py +++ b/examples/todo_blueprint.py @@ -1,85 +1,91 @@ from flask import Flask, Blueprint from flask_restx import Api, Resource, fields -api_v1 = Blueprint('api', __name__, url_prefix='/api/1') +api_v1 = Blueprint("api", __name__, url_prefix="/api/1") -api = Api(api_v1, version='1.0', title='Todo API', - description='A simple TODO API', -) +api = Api(api_v1, version="1.0", title="Todo API", description="A simple TODO API",) -ns = api.namespace('todos', description='TODO operations') +ns = api.namespace("todos", description="TODO operations") TODOS = { - 'todo1': {'task': 'build an API'}, - 'todo2': {'task': '?????'}, - 'todo3': {'task': 'profit!'}, + "todo1": {"task": "build an API"}, + "todo2": {"task": "?????"}, + "todo3": {"task": "profit!"}, } -todo = api.model('Todo', { - 'task': fields.String(required=True, description='The task details') -}) +todo = api.model( + "Todo", {"task": fields.String(required=True, description="The task details")} +) -listed_todo = api.model('ListedTodo', { - 'id': fields.String(required=True, description='The todo ID'), - 'todo': fields.Nested(todo, description='The Todo') -}) +listed_todo = api.model( + "ListedTodo", + { + "id": fields.String(required=True, description="The todo ID"), + "todo": fields.Nested(todo, description="The Todo"), + }, +) def abort_if_todo_doesnt_exist(todo_id): if todo_id not in TODOS: api.abort(404, "Todo {} doesn't exist".format(todo_id)) + parser = api.parser() -parser.add_argument('task', type=str, required=True, help='The task details', location='form') +parser.add_argument( + "task", type=str, required=True, help="The task details", location="form" +) -@ns.route('/') -@api.doc(responses={404: 'Todo not found'}, params={'todo_id': 'The Todo ID'}) +@ns.route("/") +@api.doc(responses={404: "Todo not found"}, params={"todo_id": "The Todo ID"}) class Todo(Resource): - '''Show a single todo item and lets you delete them''' - @api.doc(description='todo_id should be in {0}'.format(', '.join(TODOS.keys()))) + """Show a single todo item and lets you delete them""" + + @api.doc(description="todo_id should be in {0}".format(", ".join(TODOS.keys()))) @api.marshal_with(todo) def get(self, todo_id): - '''Fetch a given resource''' + """Fetch a given resource""" abort_if_todo_doesnt_exist(todo_id) return TODOS[todo_id] - @api.doc(responses={204: 'Todo deleted'}) + @api.doc(responses={204: "Todo deleted"}) def delete(self, todo_id): - '''Delete a given resource''' + """Delete a given resource""" abort_if_todo_doesnt_exist(todo_id) del TODOS[todo_id] - return '', 204 + return "", 204 @api.doc(parser=parser) @api.marshal_with(todo) def put(self, todo_id): - '''Update a given resource''' + """Update a given resource""" args = parser.parse_args() - task = {'task': args['task']} + task = {"task": args["task"]} TODOS[todo_id] = task return task -@ns.route('/') +@ns.route("/") class TodoList(Resource): - '''Shows a list of all todos, and lets you POST to add new tasks''' + """Shows a list of all todos, and lets you POST to add new tasks""" + @api.marshal_list_with(listed_todo) def get(self): - '''List all todos''' - return [{'id': id, 'todo': todo} for id, todo in TODOS.items()] + """List all todos""" + return [{"id": id, "todo": todo} for id, todo in TODOS.items()] @api.doc(parser=parser) @api.marshal_with(todo, code=201) def post(self): - '''Create a todo''' + """Create a todo""" args = parser.parse_args() - todo_id = 'todo%d' % (len(TODOS) + 1) - TODOS[todo_id] = {'task': args['task']} + todo_id = "todo%d" % (len(TODOS) + 1) + TODOS[todo_id] = {"task": args["task"]} return TODOS[todo_id], 201 -if __name__ == '__main__': +if __name__ == "__main__": app = Flask(__name__) app.register_blueprint(api_v1) app.run(debug=True) diff --git a/examples/todo_simple.py b/examples/todo_simple.py index 9bb217ee..07e673e9 100644 --- a/examples/todo_simple.py +++ b/examples/todo_simple.py @@ -7,7 +7,7 @@ todos = {} -@api.route('/') +@api.route("/") class TodoSimple(Resource): """ You can try this example as follow: @@ -30,15 +30,14 @@ class TodoSimple(Resource): {u'todo2': u'Change my breakpads'} """ + def get(self, todo_id): return {todo_id: todos[todo_id]} def put(self, todo_id): - todos[todo_id] = request.form['data'] + todos[todo_id] = request.form["data"] return {todo_id: todos[todo_id]} -if __name__ == '__main__': +if __name__ == "__main__": app.run(debug=False) - - diff --git a/examples/todomvc.py b/examples/todomvc.py index 077a6f55..c9d94856 100644 --- a/examples/todomvc.py +++ b/examples/todomvc.py @@ -4,16 +4,17 @@ app = Flask(__name__) app.wsgi_app = ProxyFix(app.wsgi_app) -api = Api(app, version='1.0', title='TodoMVC API', - description='A simple TodoMVC API', -) +api = Api(app, version="1.0", title="TodoMVC API", description="A simple TodoMVC API",) -ns = api.namespace('todos', description='TODO operations') +ns = api.namespace("todos", description="TODO operations") -todo = api.model('Todo', { - 'id': fields.Integer(readonly=True, description='The task unique identifier'), - 'task': fields.String(required=True, description='The task details') -}) +todo = api.model( + "Todo", + { + "id": fields.Integer(readonly=True, description="The task unique identifier"), + "task": fields.String(required=True, description="The task details"), + }, +) class TodoDAO(object): @@ -23,13 +24,13 @@ def __init__(self): def get(self, id): for todo in self.todos: - if todo['id'] == id: + if todo["id"] == id: return todo api.abort(404, "Todo {} doesn't exist".format(id)) def create(self, data): todo = data - todo['id'] = self.counter = self.counter + 1 + todo["id"] = self.counter = self.counter + 1 self.todos.append(todo) return todo @@ -44,52 +45,54 @@ def delete(self, id): DAO = TodoDAO() -DAO.create({'task': 'Build an API'}) -DAO.create({'task': '?????'}) -DAO.create({'task': 'profit!'}) +DAO.create({"task": "Build an API"}) +DAO.create({"task": "?????"}) +DAO.create({"task": "profit!"}) -@ns.route('/') +@ns.route("/") class TodoList(Resource): - '''Shows a list of all todos, and lets you POST to add new tasks''' - @ns.doc('list_todos') + """Shows a list of all todos, and lets you POST to add new tasks""" + + @ns.doc("list_todos") @ns.marshal_list_with(todo) def get(self): - '''List all tasks''' + """List all tasks""" return DAO.todos - @ns.doc('create_todo') + @ns.doc("create_todo") @ns.expect(todo) @ns.marshal_with(todo, code=201) def post(self): - '''Create a new task''' + """Create a new task""" return DAO.create(api.payload), 201 -@ns.route('/') -@ns.response(404, 'Todo not found') -@ns.param('id', 'The task identifier') +@ns.route("/") +@ns.response(404, "Todo not found") +@ns.param("id", "The task identifier") class Todo(Resource): - '''Show a single todo item and lets you delete them''' - @ns.doc('get_todo') + """Show a single todo item and lets you delete them""" + + @ns.doc("get_todo") @ns.marshal_with(todo) def get(self, id): - '''Fetch a given resource''' + """Fetch a given resource""" return DAO.get(id) - @ns.doc('delete_todo') - @ns.response(204, 'Todo deleted') + @ns.doc("delete_todo") + @ns.response(204, "Todo deleted") def delete(self, id): - '''Delete a task given its identifier''' + """Delete a task given its identifier""" DAO.delete(id) - return '', 204 + return "", 204 @ns.expect(todo) @ns.marshal_with(todo) def put(self, id): - '''Update a task given its identifier''' + """Update a task given its identifier""" return DAO.update(id, api.payload) -if __name__ == '__main__': +if __name__ == "__main__": app.run(debug=True) diff --git a/examples/xml_representation.py b/examples/xml_representation.py index 25b27ac9..19999b4e 100644 --- a/examples/xml_representation.py +++ b/examples/xml_representation.py @@ -6,18 +6,19 @@ def output_xml(data, code, headers=None): """Makes a Flask response with a XML encoded body""" - resp = make_response(dumps({'response': data}), code) + resp = make_response(dumps({"response": data}), code) resp.headers.extend(headers or {}) return resp + app = Flask(__name__) -api = Api(app, default_mediatype='application/xml') -api.representations['application/xml'] = output_xml +api = Api(app, default_mediatype="application/xml") +api.representations["application/xml"] = output_xml -hello_fields = api.model('Hello', {'entry': fields.String}) +hello_fields = api.model("Hello", {"entry": fields.String}) -@api.route('/') +@api.route("/") class Hello(Resource): """ # you need requests @@ -29,11 +30,12 @@ class Hello(Resource): >>> get('http://localhost:5000/me', headers={"accept":"application/xml"}).content 'me' """ - @api.doc(model=hello_fields, params={'entry': 'The entry to wrap'}) + + @api.doc(model=hello_fields, params={"entry": "The entry to wrap"}) def get(self, entry): - '''Get a wrapped entry''' - return {'hello': entry} + """Get a wrapped entry""" + return {"hello": entry} -if __name__ == '__main__': +if __name__ == "__main__": app.run(debug=True) diff --git a/examples/zoo/__init__.py b/examples/zoo/__init__.py index d066087a..b38600d2 100644 --- a/examples/zoo/__init__.py +++ b/examples/zoo/__init__.py @@ -3,11 +3,7 @@ from .cat import api as cat_api from .dog import api as dog_api -api = Api( - title='Zoo API', - version='1.0', - description='A simple demo API', -) +api = Api(title="Zoo API", version="1.0", description="A simple demo API",) api.add_namespace(cat_api) api.add_namespace(dog_api) diff --git a/examples/zoo/cat.py b/examples/zoo/cat.py index a6e39c5a..be1fc624 100644 --- a/examples/zoo/cat.py +++ b/examples/zoo/cat.py @@ -1,35 +1,38 @@ from flask_restx import Namespace, Resource, fields -api = Namespace('cats', description='Cats related operations') +api = Namespace("cats", description="Cats related operations") -cat = api.model('Cat', { - 'id': fields.String(required=True, description='The cat identifier'), - 'name': fields.String(required=True, description='The cat name'), -}) +cat = api.model( + "Cat", + { + "id": fields.String(required=True, description="The cat identifier"), + "name": fields.String(required=True, description="The cat name"), + }, +) CATS = [ - {'id': 'felix', 'name': 'Felix'}, + {"id": "felix", "name": "Felix"}, ] -@api.route('/') +@api.route("/") class CatList(Resource): - @api.doc('list_cats') + @api.doc("list_cats") @api.marshal_list_with(cat) def get(self): - '''List all cats''' + """List all cats""" return CATS -@api.route('/') -@api.param('id', 'The cat identifier') -@api.response(404, 'Cat not found') +@api.route("/") +@api.param("id", "The cat identifier") +@api.response(404, "Cat not found") class Cat(Resource): - @api.doc('get_cat') + @api.doc("get_cat") @api.marshal_with(cat) def get(self, id): - '''Fetch a cat given its identifier''' + """Fetch a cat given its identifier""" for cat in CATS: - if cat['id'] == id: + if cat["id"] == id: return cat api.abort(404) diff --git a/examples/zoo/dog.py b/examples/zoo/dog.py index d4c6bb4c..4666cbf7 100644 --- a/examples/zoo/dog.py +++ b/examples/zoo/dog.py @@ -1,35 +1,38 @@ from flask_restx import Namespace, Resource, fields -api = Namespace('dogs', description='Dogs related operations') +api = Namespace("dogs", description="Dogs related operations") -dog = api.model('Dog', { - 'id': fields.String(required=True, description='The dog identifier'), - 'name': fields.String(required=True, description='The dog name'), -}) +dog = api.model( + "Dog", + { + "id": fields.String(required=True, description="The dog identifier"), + "name": fields.String(required=True, description="The dog name"), + }, +) DOGS = [ - {'id': 'medor', 'name': 'Medor'}, + {"id": "medor", "name": "Medor"}, ] -@api.route('/') +@api.route("/") class DogList(Resource): - @api.doc('list_dogs') + @api.doc("list_dogs") @api.marshal_list_with(dog) def get(self): - '''List all dogs''' + """List all dogs""" return DOGS -@api.route('/') -@api.param('id', 'The dog identifier') -@api.response(404, 'Dog not found') +@api.route("/") +@api.param("id", "The dog identifier") +@api.response(404, "Dog not found") class Dog(Resource): - @api.doc('get_dog') + @api.doc("get_dog") @api.marshal_with(dog) def get(self, id): - '''Fetch a dog given its identifier''' + """Fetch a dog given its identifier""" for dog in DOGS: - if dog['id'] == id: + if dog["id"] == id: return dog api.abort(404) diff --git a/flask_restx/__about__.py b/flask_restx/__about__.py index e554d4ba..58e64ec3 100644 --- a/flask_restx/__about__.py +++ b/flask_restx/__about__.py @@ -1,3 +1,5 @@ # -*- coding: utf-8 -*- -__version__ = '0.1.1' -__description__ = 'Fully featured framework for fast, easy and documented API development with Flask' +__version__ = "0.1.1" +__description__ = ( + "Fully featured framework for fast, easy and documented API development with Flask" +) diff --git a/flask_restx/__init__.py b/flask_restx/__init__.py index 4929bb1a..c02b76a3 100644 --- a/flask_restx/__init__.py +++ b/flask_restx/__init__.py @@ -13,26 +13,26 @@ from .__about__ import __version__, __description__ __all__ = ( - '__version__', - '__description__', - 'Api', - 'Resource', - 'apidoc', - 'marshal', - 'marshal_with', - 'marshal_with_field', - 'Mask', - 'Model', - 'Namespace', - 'OrderedModel', - 'SchemaModel', - 'abort', - 'cors', - 'fields', - 'inputs', - 'reqparse', - 'RestError', - 'SpecsError', - 'Swagger', - 'ValidationError', + "__version__", + "__description__", + "Api", + "Resource", + "apidoc", + "marshal", + "marshal_with", + "marshal_with_field", + "Mask", + "Model", + "Namespace", + "OrderedModel", + "SchemaModel", + "abort", + "cors", + "fields", + "inputs", + "reqparse", + "RestError", + "SpecsError", + "Swagger", + "ValidationError", ) diff --git a/flask_restx/_http.py b/flask_restx/_http.py index b86714c1..800b7660 100644 --- a/flask_restx/_http.py +++ b/flask_restx/_http.py @@ -20,7 +20,8 @@ class HTTPStatus(IntEnum): * RFC 2295: Transparent Content Negotiation in HTTP * RFC 2774: An HTTP Extension Framework """ - def __new__(cls, value, phrase, description=''): + + def __new__(cls, value, phrase, description=""): obj = int.__new__(cls, value) obj._value_ = value @@ -32,110 +33,154 @@ def __str__(self): return str(self.value) # informational - CONTINUE = 100, 'Continue', 'Request received, please continue' - SWITCHING_PROTOCOLS = (101, 'Switching Protocols', - 'Switching to new protocol; obey Upgrade header') - PROCESSING = 102, 'Processing' + CONTINUE = 100, "Continue", "Request received, please continue" + SWITCHING_PROTOCOLS = ( + 101, + "Switching Protocols", + "Switching to new protocol; obey Upgrade header", + ) + PROCESSING = 102, "Processing" # success - OK = 200, 'OK', 'Request fulfilled, document follows' - CREATED = 201, 'Created', 'Document created, URL follows' - ACCEPTED = (202, 'Accepted', - 'Request accepted, processing continues off-line') - NON_AUTHORITATIVE_INFORMATION = (203, - 'Non-Authoritative Information', 'Request fulfilled from cache') - NO_CONTENT = 204, 'No Content', 'Request fulfilled, nothing follows' - RESET_CONTENT = 205, 'Reset Content', 'Clear input form for further input' - PARTIAL_CONTENT = 206, 'Partial Content', 'Partial content follows' - MULTI_STATUS = 207, 'Multi-Status' - ALREADY_REPORTED = 208, 'Already Reported' - IM_USED = 226, 'IM Used' + OK = 200, "OK", "Request fulfilled, document follows" + CREATED = 201, "Created", "Document created, URL follows" + ACCEPTED = (202, "Accepted", "Request accepted, processing continues off-line") + NON_AUTHORITATIVE_INFORMATION = ( + 203, + "Non-Authoritative Information", + "Request fulfilled from cache", + ) + NO_CONTENT = 204, "No Content", "Request fulfilled, nothing follows" + RESET_CONTENT = 205, "Reset Content", "Clear input form for further input" + PARTIAL_CONTENT = 206, "Partial Content", "Partial content follows" + MULTI_STATUS = 207, "Multi-Status" + ALREADY_REPORTED = 208, "Already Reported" + IM_USED = 226, "IM Used" # redirection - MULTIPLE_CHOICES = (300, 'Multiple Choices', - 'Object has several resources -- see URI list') - MOVED_PERMANENTLY = (301, 'Moved Permanently', - 'Object moved permanently -- see URI list') - FOUND = 302, 'Found', 'Object moved temporarily -- see URI list' - SEE_OTHER = 303, 'See Other', 'Object moved -- see Method and URL list' - NOT_MODIFIED = (304, 'Not Modified', - 'Document has not changed since given time') - USE_PROXY = (305, 'Use Proxy', - 'You must use proxy specified in Location to access this resource') - TEMPORARY_REDIRECT = (307, 'Temporary Redirect', - 'Object moved temporarily -- see URI list') - PERMANENT_REDIRECT = (308, 'Permanent Redirect', - 'Object moved temporarily -- see URI list') + MULTIPLE_CHOICES = ( + 300, + "Multiple Choices", + "Object has several resources -- see URI list", + ) + MOVED_PERMANENTLY = ( + 301, + "Moved Permanently", + "Object moved permanently -- see URI list", + ) + FOUND = 302, "Found", "Object moved temporarily -- see URI list" + SEE_OTHER = 303, "See Other", "Object moved -- see Method and URL list" + NOT_MODIFIED = (304, "Not Modified", "Document has not changed since given time") + USE_PROXY = ( + 305, + "Use Proxy", + "You must use proxy specified in Location to access this resource", + ) + TEMPORARY_REDIRECT = ( + 307, + "Temporary Redirect", + "Object moved temporarily -- see URI list", + ) + PERMANENT_REDIRECT = ( + 308, + "Permanent Redirect", + "Object moved temporarily -- see URI list", + ) # client error - BAD_REQUEST = (400, 'Bad Request', - 'Bad request syntax or unsupported method') - UNAUTHORIZED = (401, 'Unauthorized', - 'No permission -- see authorization schemes') - PAYMENT_REQUIRED = (402, 'Payment Required', - 'No payment -- see charging schemes') - FORBIDDEN = (403, 'Forbidden', - 'Request forbidden -- authorization will not help') - NOT_FOUND = (404, 'Not Found', - 'Nothing matches the given URI') - METHOD_NOT_ALLOWED = (405, 'Method Not Allowed', - 'Specified method is invalid for this resource') - NOT_ACCEPTABLE = (406, 'Not Acceptable', - 'URI not available in preferred format') - PROXY_AUTHENTICATION_REQUIRED = (407, - 'Proxy Authentication Required', - 'You must authenticate with this proxy before proceeding') - REQUEST_TIMEOUT = (408, 'Request Timeout', - 'Request timed out; try again later') - CONFLICT = 409, 'Conflict', 'Request conflict' - GONE = (410, 'Gone', - 'URI no longer exists and has been permanently removed') - LENGTH_REQUIRED = (411, 'Length Required', - 'Client must specify Content-Length') - PRECONDITION_FAILED = (412, 'Precondition Failed', - 'Precondition in headers is false') - REQUEST_ENTITY_TOO_LARGE = (413, 'Request Entity Too Large', - 'Entity is too large') - REQUEST_URI_TOO_LONG = (414, 'Request-URI Too Long', - 'URI is too long') - UNSUPPORTED_MEDIA_TYPE = (415, 'Unsupported Media Type', - 'Entity body in unsupported format') - REQUESTED_RANGE_NOT_SATISFIABLE = (416, - 'Requested Range Not Satisfiable', - 'Cannot satisfy request range') - EXPECTATION_FAILED = (417, 'Expectation Failed', - 'Expect condition could not be satisfied') - UNPROCESSABLE_ENTITY = 422, 'Unprocessable Entity' - LOCKED = 423, 'Locked' - FAILED_DEPENDENCY = 424, 'Failed Dependency' - UPGRADE_REQUIRED = 426, 'Upgrade Required' - PRECONDITION_REQUIRED = (428, 'Precondition Required', - 'The origin server requires the request to be conditional') - TOO_MANY_REQUESTS = (429, 'Too Many Requests', - 'The user has sent too many requests in ' - 'a given amount of time ("rate limiting")') - REQUEST_HEADER_FIELDS_TOO_LARGE = (431, - 'Request Header Fields Too Large', - 'The server is unwilling to process the request because its header ' - 'fields are too large') + BAD_REQUEST = (400, "Bad Request", "Bad request syntax or unsupported method") + UNAUTHORIZED = (401, "Unauthorized", "No permission -- see authorization schemes") + PAYMENT_REQUIRED = (402, "Payment Required", "No payment -- see charging schemes") + FORBIDDEN = (403, "Forbidden", "Request forbidden -- authorization will not help") + NOT_FOUND = (404, "Not Found", "Nothing matches the given URI") + METHOD_NOT_ALLOWED = ( + 405, + "Method Not Allowed", + "Specified method is invalid for this resource", + ) + NOT_ACCEPTABLE = (406, "Not Acceptable", "URI not available in preferred format") + PROXY_AUTHENTICATION_REQUIRED = ( + 407, + "Proxy Authentication Required", + "You must authenticate with this proxy before proceeding", + ) + REQUEST_TIMEOUT = (408, "Request Timeout", "Request timed out; try again later") + CONFLICT = 409, "Conflict", "Request conflict" + GONE = (410, "Gone", "URI no longer exists and has been permanently removed") + LENGTH_REQUIRED = (411, "Length Required", "Client must specify Content-Length") + PRECONDITION_FAILED = ( + 412, + "Precondition Failed", + "Precondition in headers is false", + ) + REQUEST_ENTITY_TOO_LARGE = (413, "Request Entity Too Large", "Entity is too large") + REQUEST_URI_TOO_LONG = (414, "Request-URI Too Long", "URI is too long") + UNSUPPORTED_MEDIA_TYPE = ( + 415, + "Unsupported Media Type", + "Entity body in unsupported format", + ) + REQUESTED_RANGE_NOT_SATISFIABLE = ( + 416, + "Requested Range Not Satisfiable", + "Cannot satisfy request range", + ) + EXPECTATION_FAILED = ( + 417, + "Expectation Failed", + "Expect condition could not be satisfied", + ) + UNPROCESSABLE_ENTITY = 422, "Unprocessable Entity" + LOCKED = 423, "Locked" + FAILED_DEPENDENCY = 424, "Failed Dependency" + UPGRADE_REQUIRED = 426, "Upgrade Required" + PRECONDITION_REQUIRED = ( + 428, + "Precondition Required", + "The origin server requires the request to be conditional", + ) + TOO_MANY_REQUESTS = ( + 429, + "Too Many Requests", + "The user has sent too many requests in " + 'a given amount of time ("rate limiting")', + ) + REQUEST_HEADER_FIELDS_TOO_LARGE = ( + 431, + "Request Header Fields Too Large", + "The server is unwilling to process the request because its header " + "fields are too large", + ) # server errors - INTERNAL_SERVER_ERROR = (500, 'Internal Server Error', - 'Server got itself in trouble') - NOT_IMPLEMENTED = (501, 'Not Implemented', - 'Server does not support this operation') - BAD_GATEWAY = (502, 'Bad Gateway', - 'Invalid responses from another server/proxy') - SERVICE_UNAVAILABLE = (503, 'Service Unavailable', - 'The server cannot process the request due to a high load') - GATEWAY_TIMEOUT = (504, 'Gateway Timeout', - 'The gateway server did not receive a timely response') - HTTP_VERSION_NOT_SUPPORTED = (505, 'HTTP Version Not Supported', - 'Cannot fulfill request') - VARIANT_ALSO_NEGOTIATES = 506, 'Variant Also Negotiates' - INSUFFICIENT_STORAGE = 507, 'Insufficient Storage' - LOOP_DETECTED = 508, 'Loop Detected' - NOT_EXTENDED = 510, 'Not Extended' - NETWORK_AUTHENTICATION_REQUIRED = (511, - 'Network Authentication Required', - 'The client needs to authenticate to gain network access') + INTERNAL_SERVER_ERROR = ( + 500, + "Internal Server Error", + "Server got itself in trouble", + ) + NOT_IMPLEMENTED = (501, "Not Implemented", "Server does not support this operation") + BAD_GATEWAY = (502, "Bad Gateway", "Invalid responses from another server/proxy") + SERVICE_UNAVAILABLE = ( + 503, + "Service Unavailable", + "The server cannot process the request due to a high load", + ) + GATEWAY_TIMEOUT = ( + 504, + "Gateway Timeout", + "The gateway server did not receive a timely response", + ) + HTTP_VERSION_NOT_SUPPORTED = ( + 505, + "HTTP Version Not Supported", + "Cannot fulfill request", + ) + VARIANT_ALSO_NEGOTIATES = 506, "Variant Also Negotiates" + INSUFFICIENT_STORAGE = 507, "Insufficient Storage" + LOOP_DETECTED = 508, "Loop Detected" + NOT_EXTENDED = 510, "Not Extended" + NETWORK_AUTHENTICATION_REQUIRED = ( + 511, + "Network Authentication Required", + "The client needs to authenticate to gain network access", + ) diff --git a/flask_restx/api.py b/flask_restx/api.py index 89a566b5..f9b098bb 100644 --- a/flask_restx/api.py +++ b/flask_restx/api.py @@ -23,7 +23,13 @@ from werkzeug.utils import cached_property from werkzeug.datastructures import Headers -from werkzeug.exceptions import HTTPException, MethodNotAllowed, NotFound, NotAcceptable, InternalServerError +from werkzeug.exceptions import ( + HTTPException, + MethodNotAllowed, + NotFound, + NotAcceptable, + InternalServerError, +) from werkzeug.wrappers import BaseResponse from . import apidoc @@ -36,18 +42,18 @@ from .representations import output_json from ._http import HTTPStatus -RE_RULES = re.compile('(<.*>)') +RE_RULES = re.compile("(<.*>)") # List headers that should never be handled by Flask-RESTX -HEADERS_BLACKLIST = ('Content-Length',) +HEADERS_BLACKLIST = ("Content-Length",) -DEFAULT_REPRESENTATIONS = [('application/json', output_json)] +DEFAULT_REPRESENTATIONS = [("application/json", output_json)] log = logging.getLogger(__name__) class Api(object): - ''' + """ The main entry point for the application. You need to initialize it with a Flask Application: :: @@ -87,19 +93,39 @@ class Api(object): :param FormatChecker format_checker: A jsonschema.FormatChecker object that is hooked into the Model validator. A default or a custom FormatChecker can be provided (e.g., with custom checkers), otherwise the default action is to not enforce any format validation. - ''' - - def __init__(self, app=None, version='1.0', title=None, description=None, - terms_url=None, license=None, license_url=None, - contact=None, contact_url=None, contact_email=None, - authorizations=None, security=None, doc='/', default_id=default_id, - default='default', default_label='Default namespace', validate=None, - tags=None, prefix='', ordered=False, - default_mediatype='application/json', decorators=None, - catch_all_404s=False, serve_challenge_on_401=False, format_checker=None, - **kwargs): + """ + + def __init__( + self, + app=None, + version="1.0", + title=None, + description=None, + terms_url=None, + license=None, + license_url=None, + contact=None, + contact_url=None, + contact_email=None, + authorizations=None, + security=None, + doc="/", + default_id=default_id, + default="default", + default_label="Default namespace", + validate=None, + tags=None, + prefix="", + ordered=False, + default_mediatype="application/json", + decorators=None, + catch_all_404s=False, + serve_challenge_on_401=False, + format_checker=None, + **kwargs + ): self.version = version - self.title = title or 'API' + self.title = title or "API" self.description = description self.terms_url = terms_url self.contact = contact @@ -143,11 +169,13 @@ def __init__(self, app=None, version='1.0', title=None, description=None, self.blueprint = None # must come after self.app initialisation to prevent __getattr__ recursion # in self._configure_namespace_logger - self.default_namespace = self.namespace(default, default_label, - endpoint='{0}-declaration'.format(default), + self.default_namespace = self.namespace( + default, + default_label, + endpoint="{0}-declaration".format(default), validate=validate, api=self, - path='/', + path="/", ) if app is not None: self.app = app @@ -155,7 +183,7 @@ def __init__(self, app=None, version='1.0', title=None, description=None, # super(Api, self).__init__(app, **kwargs) def init_app(self, app, **kwargs): - ''' + """ Allow to lazy register the API on a Flask application:: >>> app = Flask(__name__) @@ -170,17 +198,17 @@ def init_app(self, app, **kwargs): :param str license: The license associated to the API (used in Swagger documentation) :param str license_url: The license page URL (used in Swagger documentation) - ''' + """ self.app = app - self.title = kwargs.get('title', self.title) - self.description = kwargs.get('description', self.description) - self.terms_url = kwargs.get('terms_url', self.terms_url) - self.contact = kwargs.get('contact', self.contact) - self.contact_url = kwargs.get('contact_url', self.contact_url) - self.contact_email = kwargs.get('contact_email', self.contact_email) - self.license = kwargs.get('license', self.license) - self.license_url = kwargs.get('license_url', self.license_url) - self._add_specs = kwargs.get('add_specs', True) + self.title = kwargs.get("title", self.title) + self.description = kwargs.get("description", self.description) + self.terms_url = kwargs.get("terms_url", self.terms_url) + self.contact = kwargs.get("contact", self.contact) + self.contact_url = kwargs.get("contact_url", self.contact_url) + self.contact_email = kwargs.get("contact_email", self.contact_email) + self.license = kwargs.get("license", self.license) + self.license_url = kwargs.get("license_url", self.license_url) + self._add_specs = kwargs.get("add_specs", True) # If app is a blueprint, defer the initialization try: @@ -192,16 +220,18 @@ def init_app(self, app, **kwargs): self.blueprint = app def _init_app(self, app): - ''' + """ Perform initialization actions with the given :class:`flask.Flask` object. :param flask.Flask app: The flask application object - ''' + """ self._register_specs(self.blueprint or app) self._register_doc(self.blueprint or app) app.handle_exception = partial(self.error_router, app.handle_exception) - app.handle_user_exception = partial(self.error_router, app.handle_user_exception) + app.handle_user_exception = partial( + self.error_router, app.handle_user_exception + ) if len(self.resources) > 0: for resource, namespace, urls, kwargs in self.resources: @@ -211,58 +241,62 @@ def _init_app(self, app): self._configure_namespace_logger(app, ns) self._register_apidoc(app) - self._validate = self._validate if self._validate is not None else app.config.get('RESTX_VALIDATE', False) - app.config.setdefault('RESTX_MASK_HEADER', 'X-Fields') - app.config.setdefault('RESTX_MASK_SWAGGER', True) + self._validate = ( + self._validate + if self._validate is not None + else app.config.get("RESTX_VALIDATE", False) + ) + app.config.setdefault("RESTX_MASK_HEADER", "X-Fields") + app.config.setdefault("RESTX_MASK_SWAGGER", True) def __getattr__(self, name): try: return getattr(self.default_namespace, name) except AttributeError: - raise AttributeError('Api does not have {0} attribute'.format(name)) + raise AttributeError("Api does not have {0} attribute".format(name)) def _complete_url(self, url_part, registration_prefix): - ''' + """ This method is used to defer the construction of the final url in the case that the Api is created with a Blueprint. :param url_part: The part of the url the endpoint is registered with :param registration_prefix: The part of the url contributed by the blueprint. Generally speaking, BlueprintSetupState.url_prefix - ''' + """ parts = (registration_prefix, self.prefix, url_part) - return ''.join(part for part in parts if part) + return "".join(part for part in parts if part) def _register_apidoc(self, app): - conf = app.extensions.setdefault('restx', {}) - if not conf.get('apidoc_registered', False): + conf = app.extensions.setdefault("restx", {}) + if not conf.get("apidoc_registered", False): app.register_blueprint(apidoc.apidoc) - conf['apidoc_registered'] = True + conf["apidoc_registered"] = True def _register_specs(self, app_or_blueprint): if self._add_specs: - endpoint = str('specs') + endpoint = str("specs") self._register_view( app_or_blueprint, SwaggerView, self.default_namespace, - '/swagger.json', + "/swagger.json", endpoint=endpoint, - resource_class_args=(self, ) + resource_class_args=(self,), ) self.endpoints.add(endpoint) def _register_doc(self, app_or_blueprint): if self._add_specs and self._doc: # Register documentation before root if enabled - app_or_blueprint.add_url_rule(self._doc, 'doc', self.render_doc) - app_or_blueprint.add_url_rule(self.prefix or '/', 'root', self.render_root) + app_or_blueprint.add_url_rule(self._doc, "doc", self.render_doc) + app_or_blueprint.add_url_rule(self.prefix or "/", "root", self.render_root) def register_resource(self, namespace, resource, *urls, **kwargs): - endpoint = kwargs.pop('endpoint', None) + endpoint = kwargs.pop("endpoint", None) endpoint = str(endpoint or self.default_endpoint(resource, namespace)) - kwargs['endpoint'] = endpoint + kwargs["endpoint"] = endpoint self.endpoints.add(endpoint) if self.app is not None: @@ -277,25 +311,28 @@ def _configure_namespace_logger(self, app, namespace): namespace.logger.setLevel(app.logger.level) def _register_view(self, app, resource, namespace, *urls, **kwargs): - endpoint = kwargs.pop('endpoint', None) or camel_to_dash(resource.__name__) - resource_class_args = kwargs.pop('resource_class_args', ()) - resource_class_kwargs = kwargs.pop('resource_class_kwargs', {}) + endpoint = kwargs.pop("endpoint", None) or camel_to_dash(resource.__name__) + resource_class_args = kwargs.pop("resource_class_args", ()) + resource_class_kwargs = kwargs.pop("resource_class_kwargs", {}) # NOTE: 'view_functions' is cleaned up from Blueprint class in Flask 1.0 - if endpoint in getattr(app, 'view_functions', {}): - previous_view_class = app.view_functions[endpoint].__dict__['view_class'] + if endpoint in getattr(app, "view_functions", {}): + previous_view_class = app.view_functions[endpoint].__dict__["view_class"] # if you override the endpoint with a different class, avoid the # collision by raising an exception if previous_view_class != resource: - msg = 'This endpoint (%s) is already set to the class %s.' + msg = "This endpoint (%s) is already set to the class %s." raise ValueError(msg % (endpoint, previous_view_class.__name__)) resource.mediatypes = self.mediatypes_method() # Hacky resource.endpoint = endpoint - resource_func = self.output(resource.as_view(endpoint, self, *resource_class_args, - **resource_class_kwargs)) + resource_func = self.output( + resource.as_view( + endpoint, self, *resource_class_args, **resource_class_kwargs + ) + ) # Apply Namespace and Api decorators to a resource for decorator in chain(namespace.decorators, self.decorators): @@ -308,7 +345,9 @@ def _register_view(self, app, resource, namespace, *urls, **kwargs): if self.blueprint_setup: # Set the rule to a string directly, as the blueprint is already # set up. - self.blueprint_setup.add_url_rule(url, view_func=resource_func, **kwargs) + self.blueprint_setup.add_url_rule( + url, view_func=resource_func, **kwargs + ) continue else: # Set the rule to a function that expects the blueprint prefix @@ -319,17 +358,18 @@ def _register_view(self, app, resource, namespace, *urls, **kwargs): rule = partial(self._complete_url, url) else: # If we've got no Blueprint, just build a url with no prefix - rule = self._complete_url(url, '') + rule = self._complete_url(url, "") # Add the url to the application or blueprint app.add_url_rule(rule, view_func=resource_func, **kwargs) def output(self, resource): - ''' + """ Wraps a resource (as a flask view function), for cases where the resource does not directly return a response object :param resource: The resource as a flask view function - ''' + """ + @wraps(resource) def wrapper(*args, **kwargs): resp = resource(*args, **kwargs) @@ -337,10 +377,11 @@ def wrapper(*args, **kwargs): return resp data, code, headers = unpack(resp) return self.make_response(data, code, headers=headers) + return wrapper def make_response(self, data, *args, **kwargs): - ''' + """ Looks up the representation transformer for the requested media type, invoking the transformer to create a response object. This defaults to default_mediatype if no transformer is found for the @@ -348,27 +389,28 @@ def make_response(self, data, *args, **kwargs): Acceptable response will be sent as per RFC 2616 section 14.1 :param data: Python object containing response data to be transformed - ''' - default_mediatype = kwargs.pop('fallback_mediatype', None) or self.default_mediatype + """ + default_mediatype = ( + kwargs.pop("fallback_mediatype", None) or self.default_mediatype + ) mediatype = request.accept_mimetypes.best_match( - self.representations, - default=default_mediatype, + self.representations, default=default_mediatype, ) if mediatype is None: raise NotAcceptable() if mediatype in self.representations: resp = self.representations[mediatype](data, *args, **kwargs) - resp.headers['Content-Type'] = mediatype + resp.headers["Content-Type"] = mediatype return resp - elif mediatype == 'text/plain': + elif mediatype == "text/plain": resp = original_flask_make_response(str(data), *args, **kwargs) - resp.headers['Content-Type'] = 'text/plain' + resp.headers["Content-Type"] = "text/plain" return resp else: raise InternalServerError() def documentation(self, func): - '''A decorator to specify a view function for the documentation''' + """A decorator to specify a view function for the documentation""" self._doc_view = func return func @@ -376,7 +418,7 @@ def render_root(self): self.abort(HTTPStatus.NOT_FOUND) def render_doc(self): - '''Override this method to customize the documentation page''' + """Override this method to customize the documentation page""" if self._doc_view: return self._doc_view() elif not self._doc: @@ -384,7 +426,7 @@ def render_doc(self): return apidoc.ui_for(self) def default_endpoint(self, resource, namespace): - ''' + """ Provide a default endpoint for a resource on a given namespace. Endpoints are ensured not to collide. @@ -394,14 +436,14 @@ def default_endpoint(self, resource, namespace): :param Resource resource: the resource for which we want an endpoint :param Namespace namespace: the namespace holding the resource :returns str: An endpoint name - ''' + """ endpoint = camel_to_dash(resource.__name__) if namespace is not self.default_namespace: - endpoint = '{ns.name}_{endpoint}'.format(ns=namespace, endpoint=endpoint) + endpoint = "{ns.name}_{endpoint}".format(ns=namespace, endpoint=endpoint) if endpoint in self.endpoints: suffix = 2 while True: - new_endpoint = '{base}_{suffix}'.format(base=endpoint, suffix=suffix) + new_endpoint = "{base}_{suffix}".format(base=endpoint, suffix=suffix) if new_endpoint not in self.endpoints: endpoint = new_endpoint break @@ -416,13 +458,13 @@ def ns_urls(self, ns, urls): return [path + url for url in urls] def add_namespace(self, ns, path=None): - ''' + """ This method registers resources from namespace for current instance of api. You can use argument path for definition custom prefix url for namespace. :param Namespace ns: the namespace :param path: registration prefix of namespace - ''' + """ if ns not in self.namespaces: self.namespaces.append(ns) if self not in ns.apis: @@ -441,65 +483,65 @@ def add_namespace(self, ns, path=None): self._configure_namespace_logger(self.app, ns) def namespace(self, *args, **kwargs): - ''' + """ A namespace factory. :returns Namespace: a new namespace instance - ''' - kwargs['ordered'] = kwargs.get('ordered', self.ordered) + """ + kwargs["ordered"] = kwargs.get("ordered", self.ordered) ns = Namespace(*args, **kwargs) self.add_namespace(ns) return ns def endpoint(self, name): if self.blueprint: - return '{0}.{1}'.format(self.blueprint.name, name) + return "{0}.{1}".format(self.blueprint.name, name) else: return name @property def specs_url(self): - ''' + """ The Swagger specifications absolute url (ie. `swagger.json`) :rtype: str - ''' - return url_for(self.endpoint('specs'), _external=True) + """ + return url_for(self.endpoint("specs"), _external=True) @property def base_url(self): - ''' + """ The API base absolute url :rtype: str - ''' - return url_for(self.endpoint('root'), _external=True) + """ + return url_for(self.endpoint("root"), _external=True) @property def base_path(self): - ''' + """ The API path :rtype: str - ''' - return url_for(self.endpoint('root'), _external=False) + """ + return url_for(self.endpoint("root"), _external=False) @cached_property def __schema__(self): - ''' + """ The Swagger specifications/schema for this API :returns dict: the schema as a serializable dict - ''' + """ if not self._schema: try: self._schema = Swagger(self).as_dict() except Exception: # Log the source exception for debugging purpose # and return an error message - msg = 'Unable to render schema' + msg = "Unable to render schema" log.exception(msg) # This will provide a full traceback - return {'error': msg} + return {"error": msg} return self._schema @property @@ -512,12 +554,13 @@ def _own_and_child_error_handlers(self): return rv def errorhandler(self, exception): - '''A decorator to register an error handler for a given exception''' + """A decorator to register an error handler for a given exception""" if inspect.isclass(exception) and issubclass(exception, Exception): # Register an error handler for a given exception def wrapper(func): self.error_handlers[exception] = func return func + return wrapper else: # Register the default error handler @@ -525,23 +568,23 @@ def wrapper(func): return exception def owns_endpoint(self, endpoint): - ''' + """ Tests if an endpoint name (not path) belongs to this Api. Takes into account the Blueprint name part of the endpoint name. :param str endpoint: The name of the endpoint being checked :return: bool - ''' + """ if self.blueprint: if endpoint.startswith(self.blueprint.name): - endpoint = endpoint.split(self.blueprint.name + '.', 1)[-1] + endpoint = endpoint.split(self.blueprint.name + ".", 1)[-1] else: return False return endpoint in self.endpoints def _should_use_fr_error_handler(self): - ''' + """ Determine if error should be handled with FR or default Flask The goal is to return Flask error handlers for non-FR-related routes, @@ -549,7 +592,7 @@ def _should_use_fr_error_handler(self): method currently handles 404 and 405 errors. :return: bool - ''' + """ adapter = current_app.create_url_adapter(request) try: @@ -566,7 +609,7 @@ def _should_use_fr_error_handler(self): pass def _has_fr_route(self): - '''Encapsulating the rules for whether the request was to a Flask endpoint''' + """Encapsulating the rules for whether the request was to a Flask endpoint""" # 404's, 405's, which might not have a url_rule if self._should_use_fr_error_handler(): return True @@ -576,7 +619,7 @@ def _has_fr_route(self): return self.owns_endpoint(request.url_rule.endpoint) def error_router(self, original_handler, e): - ''' + """ This function decides whether the error occurred in a flask-restx endpoint or not. If it happened in a flask-restx endpoint, our handler will be dispatched. If it happened in an unrelated view, the @@ -587,7 +630,7 @@ def error_router(self, original_handler, e): :param function original_handler: the original Flask error handler for the app :param Exception e: the exception raised while handling the request - ''' + """ if self._has_fr_route(): try: return self.handle_error(e) @@ -596,20 +639,22 @@ def error_router(self, original_handler, e): return original_handler(e) def handle_error(self, e): - ''' + """ Error handler for the API transforms a raised exception into a Flask response, with the appropriate HTTP status code and body. :param Exception e: the raised Exception object - ''' + """ got_request_exception.send(current_app._get_current_object(), exception=e) # When propagate_exceptions is set, do not return the exception to the # client if a handler is configured for the exception. - if not isinstance(e, HTTPException) and \ - current_app.propagate_exceptions and \ - not isinstance(e, tuple(self.error_handlers.keys())): + if ( + not isinstance(e, HTTPException) + and current_app.propagate_exceptions + and not isinstance(e, tuple(self.error_handlers.keys())) + ): exc_type, exc_value, tb = sys.exc_info() if exc_value is e: @@ -617,7 +662,9 @@ def handle_error(self, e): else: raise e - include_message_in_response = current_app.config.get("ERROR_INCLUDE_MESSAGE", True) + include_message_in_response = current_app.config.get( + "ERROR_INCLUDE_MESSAGE", True + ) default_data = {} headers = Headers() @@ -625,30 +672,32 @@ def handle_error(self, e): for typecheck, handler in six.iteritems(self._own_and_child_error_handlers): if isinstance(e, typecheck): result = handler(e) - default_data, code, headers = unpack(result, HTTPStatus.INTERNAL_SERVER_ERROR) + default_data, code, headers = unpack( + result, HTTPStatus.INTERNAL_SERVER_ERROR + ) break else: if isinstance(e, HTTPException): code = HTTPStatus(e.code) if include_message_in_response: - default_data = { - 'message': getattr(e, 'description', code.phrase) - } + default_data = {"message": getattr(e, "description", code.phrase)} headers = e.get_response().headers elif self._default_error_handler: result = self._default_error_handler(e) - default_data, code, headers = unpack(result, HTTPStatus.INTERNAL_SERVER_ERROR) + default_data, code, headers = unpack( + result, HTTPStatus.INTERNAL_SERVER_ERROR + ) else: code = HTTPStatus.INTERNAL_SERVER_ERROR if include_message_in_response: default_data = { - 'message': code.phrase, + "message": code.phrase, } if include_message_in_response: - default_data['message'] = default_data.get('message', str(e)) + default_data["message"] = default_data.get("message", str(e)) - data = getattr(e, 'data', default_data) + data = getattr(e, "data", default_data) fallback_mediatype = None if code >= HTTPStatus.INTERNAL_SERVER_ERROR: @@ -657,9 +706,12 @@ def handle_error(self, e): exc_info = None current_app.log_exception(exc_info) - elif code == HTTPStatus.NOT_FOUND and current_app.config.get("ERROR_404_HELP", True) \ - and include_message_in_response: - data['message'] = self._help_on_404(data.get('message', None)) + elif ( + code == HTTPStatus.NOT_FOUND + and current_app.config.get("ERROR_404_HELP", True) + and include_message_in_response + ): + data["message"] = self._help_on_404(data.get("message", None)) elif code == HTTPStatus.NOT_ACCEPTABLE and self.default_mediatype is None: # if we are handling NotAcceptable (406), make sure that @@ -667,47 +719,57 @@ def handle_error(self, e): # default mediatype (so that make_response doesn't throw # another NotAcceptable error). supported_mediatypes = list(self.representations.keys()) - fallback_mediatype = supported_mediatypes[0] if supported_mediatypes else "text/plain" + fallback_mediatype = ( + supported_mediatypes[0] if supported_mediatypes else "text/plain" + ) # Remove blacklisted headers for header in HEADERS_BLACKLIST: headers.pop(header, None) - resp = self.make_response(data, code, headers, fallback_mediatype=fallback_mediatype) + resp = self.make_response( + data, code, headers, fallback_mediatype=fallback_mediatype + ) if code == HTTPStatus.UNAUTHORIZED: resp = self.unauthorized(resp) return resp def _help_on_404(self, message=None): - rules = dict([(RE_RULES.sub('', rule.rule), rule.rule) - for rule in current_app.url_map.iter_rules()]) + rules = dict( + [ + (RE_RULES.sub("", rule.rule), rule.rule) + for rule in current_app.url_map.iter_rules() + ] + ) close_matches = difflib.get_close_matches(request.path, rules.keys()) if close_matches: # If we already have a message, add punctuation and continue it. - message = ''.join(( - (message.rstrip('.') + '. ') if message else '', - 'You have requested this URI [', - request.path, - '] but did you mean ', - ' or '.join((rules[match] for match in close_matches)), - ' ?', - )) + message = "".join( + ( + (message.rstrip(".") + ". ") if message else "", + "You have requested this URI [", + request.path, + "] but did you mean ", + " or ".join((rules[match] for match in close_matches)), + " ?", + ) + ) return message def as_postman(self, urlvars=False, swagger=False): - ''' + """ Serialize the API as Postman collection (v1) :param bool urlvars: whether to include or not placeholders for query strings :param bool swagger: whether to include or not the swagger.json specifications - ''' + """ return PostmanCollectionV1(self, swagger=swagger).as_dict(urlvars=urlvars) @property def payload(self): - '''Store the input payload in the current request context''' + """Store the input payload in the current request context""" return request.get_json() @property @@ -717,8 +779,10 @@ def refresolver(self): return self._refresolver @staticmethod - def _blueprint_setup_add_url_rule_patch(blueprint_setup, rule, endpoint=None, view_func=None, **options): - ''' + def _blueprint_setup_add_url_rule_patch( + blueprint_setup, rule, endpoint=None, view_func=None, **options + ): + """ Method used to patch BlueprintSetupState.add_url_rule for setup state instance corresponding to this Api instance. Exists primarily to enable _complete_url's function. @@ -730,23 +794,28 @@ def _blueprint_setup_add_url_rule_patch(blueprint_setup, rule, endpoint=None, vi :param endpoint: See BlueprintSetupState.add_url_rule :param view_func: See BlueprintSetupState.add_url_rule :param **options: See BlueprintSetupState.add_url_rule - ''' + """ if callable(rule): rule = rule(blueprint_setup.url_prefix) elif blueprint_setup.url_prefix: rule = blueprint_setup.url_prefix + rule - options.setdefault('subdomain', blueprint_setup.subdomain) + options.setdefault("subdomain", blueprint_setup.subdomain) if endpoint is None: endpoint = _endpoint_from_view_func(view_func) defaults = blueprint_setup.url_defaults - if 'defaults' in options: - defaults = dict(defaults, **options.pop('defaults')) - blueprint_setup.app.add_url_rule(rule, '%s.%s' % (blueprint_setup.blueprint.name, endpoint), - view_func, defaults=defaults, **options) + if "defaults" in options: + defaults = dict(defaults, **options.pop("defaults")) + blueprint_setup.app.add_url_rule( + rule, + "%s.%s" % (blueprint_setup.blueprint.name, endpoint), + view_func, + defaults=defaults, + **options + ) def _deferred_blueprint_init(self, setup_state): - ''' + """ Synchronize prefix between blueprint/api and registration options, then perform initialization with setup_state.app :class:`flask.Flask` object. When a :class:`flask_restx.Api` object is initialized with a blueprint, @@ -758,28 +827,33 @@ def _deferred_blueprint_init(self, setup_state): during blueprint registration :type setup_state: flask.blueprints.BlueprintSetupState - ''' + """ self.blueprint_setup = setup_state - if setup_state.add_url_rule.__name__ != '_blueprint_setup_add_url_rule_patch': + if setup_state.add_url_rule.__name__ != "_blueprint_setup_add_url_rule_patch": setup_state._original_add_url_rule = setup_state.add_url_rule - setup_state.add_url_rule = MethodType(Api._blueprint_setup_add_url_rule_patch, - setup_state) + setup_state.add_url_rule = MethodType( + Api._blueprint_setup_add_url_rule_patch, setup_state + ) if not setup_state.first_registration: - raise ValueError('flask-restx blueprints can only be registered once.') + raise ValueError("flask-restx blueprints can only be registered once.") self._init_app(setup_state.app) def mediatypes_method(self): - '''Return a method that returns a list of mediatypes''' + """Return a method that returns a list of mediatypes""" return lambda resource_cls: self.mediatypes() + [self.default_mediatype] def mediatypes(self): - '''Returns a list of requested mediatypes sent in the Accept header''' - return [h for h, q in sorted(request.accept_mimetypes, - key=operator.itemgetter(1), reverse=True)] + """Returns a list of requested mediatypes sent in the Accept header""" + return [ + h + for h, q in sorted( + request.accept_mimetypes, key=operator.itemgetter(1), reverse=True + ) + ] def representation(self, mediatype): - ''' + """ Allows additional representation transformers to be declared for the api. Transformers are functions that must be decorated with this method, passing the mediatype the transformer represents. Three @@ -799,49 +873,55 @@ def xml(data, code, headers): resp = make_response(convert_data_to_xml(data), code) resp.headers.extend(headers) return resp - ''' + """ + def wrapper(func): self.representations[mediatype] = func return func + return wrapper def unauthorized(self, response): - '''Given a response, change it to ask for credentials''' + """Given a response, change it to ask for credentials""" if self.serve_challenge_on_401: realm = current_app.config.get("HTTP_BASIC_AUTH_REALM", "flask-restx") - challenge = u"{0} realm=\"{1}\"".format("Basic", realm) + challenge = '{0} realm="{1}"'.format("Basic", realm) - response.headers['WWW-Authenticate'] = challenge + response.headers["WWW-Authenticate"] = challenge return response def url_for(self, resource, **values): - ''' + """ Generates a URL to the given resource. Works like :func:`flask.url_for`. - ''' + """ endpoint = resource.endpoint if self.blueprint: - endpoint = '{0}.{1}'.format(self.blueprint.name, endpoint) + endpoint = "{0}.{1}".format(self.blueprint.name, endpoint) return url_for(endpoint, **values) class SwaggerView(Resource): - '''Render the Swagger specifications as JSON''' + """Render the Swagger specifications as JSON""" + def get(self): schema = self.api.__schema__ - return schema, HTTPStatus.INTERNAL_SERVER_ERROR if 'error' in schema else HTTPStatus.OK + return ( + schema, + HTTPStatus.INTERNAL_SERVER_ERROR if "error" in schema else HTTPStatus.OK, + ) def mediatypes(self): - return ['application/json'] + return ["application/json"] def mask_parse_error_handler(error): - '''When a mask can't be parsed''' - return {'message': 'Mask parse error: {0}'.format(error)}, HTTPStatus.BAD_REQUEST + """When a mask can't be parsed""" + return {"message": "Mask parse error: {0}".format(error)}, HTTPStatus.BAD_REQUEST def mask_error_handler(error): - '''When any error occurs on mask''' - return {'message': 'Mask error: {0}'.format(error)}, HTTPStatus.BAD_REQUEST + """When any error occurs on mask""" + return {"message": "Mask error: {0}".format(error)}, HTTPStatus.BAD_REQUEST diff --git a/flask_restx/apidoc.py b/flask_restx/apidoc.py index 8b31585d..23c753ec 100644 --- a/flask_restx/apidoc.py +++ b/flask_restx/apidoc.py @@ -5,10 +5,11 @@ class Apidoc(Blueprint): - ''' + """ Allow to know if the blueprint has already been registered until https://github.com/mitsuhiko/flask/pull/1301 is merged - ''' + """ + def __init__(self, *args, **kwargs): self.registered = False super(Apidoc, self).__init__(*args, **kwargs) @@ -18,19 +19,20 @@ def register(self, *args, **kwargs): self.registered = True -apidoc = Apidoc('restx_doc', __name__, - template_folder='templates', - static_folder='static', - static_url_path='/swaggerui', +apidoc = Apidoc( + "restx_doc", + __name__, + template_folder="templates", + static_folder="static", + static_url_path="/swaggerui", ) @apidoc.add_app_template_global def swagger_static(filename): - return url_for('restx_doc.static', filename=filename) + return url_for("restx_doc.static", filename=filename) def ui_for(api): - '''Render a SwaggerUI for a given API''' - return render_template('swagger-ui.html', title=api.title, - specs_url=api.specs_url) + """Render a SwaggerUI for a given API""" + return render_template("swagger-ui.html", title=api.title, specs_url=api.specs_url) diff --git a/flask_restx/cors.py b/flask_restx/cors.py index d2d0a2e8..1b0faa57 100644 --- a/flask_restx/cors.py +++ b/flask_restx/cors.py @@ -6,20 +6,27 @@ from functools import update_wrapper -def crossdomain(origin=None, methods=None, headers=None, expose_headers=None, - max_age=21600, attach_to_all=True, - automatic_options=True, credentials=False): +def crossdomain( + origin=None, + methods=None, + headers=None, + expose_headers=None, + max_age=21600, + attach_to_all=True, + automatic_options=True, + credentials=False, +): """ http://flask.pocoo.org/snippets/56/ """ if methods is not None: - methods = ', '.join(sorted(x.upper() for x in methods)) + methods = ", ".join(sorted(x.upper() for x in methods)) if headers is not None and not isinstance(headers, str): - headers = ', '.join(x.upper() for x in headers) + headers = ", ".join(x.upper() for x in headers) if expose_headers is not None and not isinstance(expose_headers, str): - expose_headers = ', '.join(x.upper() for x in expose_headers) + expose_headers = ", ".join(x.upper() for x in expose_headers) if not isinstance(origin, str): - origin = ', '.join(origin) + origin = ", ".join(origin) if isinstance(max_age, timedelta): max_age = max_age.total_seconds() @@ -28,30 +35,31 @@ def get_methods(): return methods options_resp = current_app.make_default_options_response() - return options_resp.headers['allow'] + return options_resp.headers["allow"] def decorator(f): def wrapped_function(*args, **kwargs): - if automatic_options and request.method == 'OPTIONS': + if automatic_options and request.method == "OPTIONS": resp = current_app.make_default_options_response() else: resp = make_response(f(*args, **kwargs)) - if not attach_to_all and request.method != 'OPTIONS': + if not attach_to_all and request.method != "OPTIONS": return resp h = resp.headers - h['Access-Control-Allow-Origin'] = origin - h['Access-Control-Allow-Methods'] = get_methods() - h['Access-Control-Max-Age'] = str(max_age) + h["Access-Control-Allow-Origin"] = origin + h["Access-Control-Allow-Methods"] = get_methods() + h["Access-Control-Max-Age"] = str(max_age) if credentials: - h['Access-Control-Allow-Credentials'] = 'true' + h["Access-Control-Allow-Credentials"] = "true" if headers is not None: - h['Access-Control-Allow-Headers'] = headers + h["Access-Control-Allow-Headers"] = headers if expose_headers is not None: - h['Access-Control-Expose-Headers'] = expose_headers + h["Access-Control-Expose-Headers"] = expose_headers return resp f.provide_automatic_options = False return update_wrapper(wrapped_function, f) + return decorator diff --git a/flask_restx/errors.py b/flask_restx/errors.py index c544b7e1..fbe6e21a 100644 --- a/flask_restx/errors.py +++ b/flask_restx/errors.py @@ -8,15 +8,15 @@ from ._http import HTTPStatus __all__ = ( - 'abort', - 'RestError', - 'ValidationError', - 'SpecsError', + "abort", + "RestError", + "ValidationError", + "SpecsError", ) def abort(code=HTTPStatus.INTERNAL_SERVER_ERROR, message=None, **kwargs): - ''' + """ Properly abort the current request. Raise a `HTTPException` for the given status `code`. @@ -26,19 +26,20 @@ def abort(code=HTTPStatus.INTERNAL_SERVER_ERROR, message=None, **kwargs): :param str message: An optional details message :param kwargs: Any additional data to pass to the error payload :raise HTTPException: - ''' + """ try: flask.abort(code) except HTTPException as e: if message: - kwargs['message'] = str(message) + kwargs["message"] = str(message) if kwargs: e.data = kwargs raise class RestError(Exception): - '''Base class for all Flask-RESTX Errors''' + """Base class for all Flask-RESTX Errors""" + def __init__(self, msg): self.msg = msg @@ -47,10 +48,12 @@ def __str__(self): class ValidationError(RestError): - '''A helper class for validation errors.''' + """A helper class for validation errors.""" + pass class SpecsError(RestError): - '''A helper class for incoherent specifications.''' + """A helper class for incoherent specifications.""" + pass diff --git a/flask_restx/fields.py b/flask_restx/fields.py index 6b7fdc36..a9da9b93 100644 --- a/flask_restx/fields.py +++ b/flask_restx/fields.py @@ -16,22 +16,46 @@ from flask import url_for, request from werkzeug.utils import cached_property -from .inputs import date_from_iso8601, datetime_from_iso8601, datetime_from_rfc822, boolean +from .inputs import ( + date_from_iso8601, + datetime_from_iso8601, + datetime_from_rfc822, + boolean, +) from .errors import RestError from .marshalling import marshal from .utils import camel_to_dash, not_none -__all__ = ('Raw', 'String', 'FormattedString', 'Url', 'DateTime', 'Date', - 'Boolean', 'Integer', 'Float', 'Arbitrary', 'Fixed', - 'Nested', 'List', 'ClassName', 'Polymorph', 'Wildcard', - 'StringMixin', 'MinMaxMixin', 'NumberMixin', 'MarshallingError') +__all__ = ( + "Raw", + "String", + "FormattedString", + "Url", + "DateTime", + "Date", + "Boolean", + "Integer", + "Float", + "Arbitrary", + "Fixed", + "Nested", + "List", + "ClassName", + "Polymorph", + "Wildcard", + "StringMixin", + "MinMaxMixin", + "NumberMixin", + "MarshallingError", +) class MarshallingError(RestError): - ''' + """ This is an encapsulating Exception in case of marshalling error. - ''' + """ + def __init__(self, underlying_exception): # just put the contextual representation of the error to hint on what # went wrong without exposing internals @@ -43,13 +67,13 @@ def is_indexable_but_not_string(obj): def get_value(key, obj, default=None): - '''Helper for pulling a keyed value off various types of objects''' + """Helper for pulling a keyed value off various types of objects""" if isinstance(key, int): return _get_value_for_key(key, obj, default) elif callable(key): return key(obj) else: - return _get_value_for_keys(key.split('.'), obj, default) + return _get_value_for_keys(key.split("."), obj, default) def _get_value_for_keys(keys, obj, default): @@ -57,7 +81,8 @@ def _get_value_for_keys(keys, obj, default): return _get_value_for_key(keys[0], obj, default) else: return _get_value_for_keys( - keys[1:], _get_value_for_key(keys[0], obj, default), default) + keys[1:], _get_value_for_key(keys[0], obj, default), default + ) def _get_value_for_key(key, obj, default): @@ -70,24 +95,24 @@ def _get_value_for_key(key, obj, default): def to_marshallable_type(obj): - ''' + """ Helper for converting an object to a dictionary only if it is not dictionary already or an indexable object nor a simple type - ''' + """ if obj is None: return None # make it idempotent for None - if hasattr(obj, '__marshallable__'): + if hasattr(obj, "__marshallable__"): return obj.__marshallable__() - if hasattr(obj, '__getitem__'): + if hasattr(obj, "__getitem__"): return obj # it is indexable it is ok return dict(obj.__dict__) class Raw(object): - ''' + """ Raw provides a base field class from which others should extend. It applies no formatting by default, and should only be used in cases where data does not need to be formatted before being serialized. Fields should @@ -104,16 +129,27 @@ class Raw(object): :param bool readonly: Is the field read only ? (for documentation purpose) :param example: An optional data example (for documentation purpose) :param callable mask: An optional mask function to be applied to output - ''' + """ + #: The JSON/Swagger schema type - __schema_type__ = 'object' + __schema_type__ = "object" #: The JSON/Swagger schema format __schema_format__ = None #: An optional JSON/Swagger schema example __schema_example__ = None - def __init__(self, default=None, attribute=None, title=None, description=None, - required=None, readonly=None, example=None, mask=None, **kwargs): + def __init__( + self, + default=None, + attribute=None, + title=None, + description=None, + required=None, + readonly=None, + example=None, + mask=None, + **kwargs + ): self.attribute = attribute self.default = default self.title = title @@ -124,7 +160,7 @@ def __init__(self, default=None, attribute=None, title=None, description=None, self.mask = mask def format(self, value): - ''' + """ Formats a field's value. No-op by default - field classes that modify how the value of existing object keys should be presented should override this and apply the appropriate formatting. @@ -137,11 +173,11 @@ def format(self, value): class TitleCase(Raw): def format(self, value): return unicode(value).title() - ''' + """ return value def output(self, key, obj, **kwargs): - ''' + """ Pulls the value for the given key from the object, applies the field's formatting and returns the result. If the key is not found in the object, returns the default value. Field classes that create @@ -149,23 +185,25 @@ def output(self, key, obj, **kwargs): should override this and return the desired value. :raises MarshallingError: In case of formatting problem - ''' + """ value = get_value(key if self.attribute is None else self.attribute, obj) if value is None: - default = self._v('default') + default = self._v("default") return self.format(default) if default else default try: data = self.format(value) except MarshallingError as e: - msg = 'Unable to marshal field "{0}" value "{1}": {2}'.format(key, value, str(e)) + msg = 'Unable to marshal field "{0}" value "{1}": {2}'.format( + key, value, str(e) + ) raise MarshallingError(msg) return self.mask.apply(data) if self.mask else data def _v(self, key): - '''Helper for getting a value from attribute allowing callable''' + """Helper for getting a value from attribute allowing callable""" value = getattr(self, key) return value() if callable(value) else value @@ -175,18 +213,18 @@ def __schema__(self): def schema(self): return { - 'type': self.__schema_type__, - 'format': self.__schema_format__, - 'title': self.title, - 'description': self.description, - 'readOnly': self.readonly, - 'default': self._v('default'), - 'example': self.example, + "type": self.__schema_type__, + "format": self.__schema_format__, + "title": self.title, + "description": self.description, + "readOnly": self.readonly, + "default": self._v("default"), + "example": self.example, } class Nested(Raw): - ''' + """ Allows you to nest one set of fields inside another. See :ref:`nested-field` for more information @@ -200,10 +238,13 @@ class Nested(Raw): dictionary will be marshaled as its value if nested dictionary is all-null keys (e.g. lets you return an empty JSON object instead of null) - ''' + """ + __schema_type__ = None - def __init__(self, model, allow_null=False, skip_none=False, as_list=False, **kwargs): + def __init__( + self, model, allow_null=False, skip_none=False, as_list=False, **kwargs + ): self.model = model self.as_list = as_list self.allow_null = allow_null @@ -212,7 +253,7 @@ def __init__(self, model, allow_null=False, skip_none=False, as_list=False, **kw @property def nested(self): - return getattr(self.model, 'resolved', self.model) + return getattr(self.model, "resolved", self.model) def output(self, key, obj, ordered=False, **kwargs): value = get_value(key if self.attribute is None else self.attribute, obj) @@ -226,42 +267,43 @@ def output(self, key, obj, ordered=False, **kwargs): def schema(self): schema = super(Nested, self).schema() - ref = '#/definitions/{0}'.format(self.nested.name) + ref = "#/definitions/{0}".format(self.nested.name) if self.as_list: - schema['type'] = 'array' - schema['items'] = {'$ref': ref} + schema["type"] = "array" + schema["items"] = {"$ref": ref} elif any(schema.values()): # There is already some properties in the schema - allOf = schema.get('allOf', []) - allOf.append({'$ref': ref}) - schema['allOf'] = allOf + allOf = schema.get("allOf", []) + allOf.append({"$ref": ref}) + schema["allOf"] = allOf else: - schema['$ref'] = ref + schema["$ref"] = ref return schema def clone(self, mask=None): kwargs = self.__dict__.copy() - model = kwargs.pop('model') + model = kwargs.pop("model") if mask: - model = mask.apply(model.resolved if hasattr(model, 'resolved') else model) + model = mask.apply(model.resolved if hasattr(model, "resolved") else model) return self.__class__(model, **kwargs) class List(Raw): - ''' + """ Field for marshalling lists of other fields. See :ref:`list-field` for more information. :param cls_or_instance: The field type the list will contain. - ''' + """ + def __init__(self, cls_or_instance, **kwargs): - self.min_items = kwargs.pop('min_items', None) - self.max_items = kwargs.pop('max_items', None) - self.unique = kwargs.pop('unique', None) + self.min_items = kwargs.pop("min_items", None) + self.max_items = kwargs.pop("max_items", None) + self.unique = kwargs.pop("unique", None) super(List, self).__init__(**kwargs) - error_msg = 'The type of the list elements must be a subclass of fields.Raw' + error_msg = "The type of the list elements must be a subclass of fields.Raw" if isinstance(cls_or_instance, type): if not issubclass(cls_or_instance, Raw): raise MarshallingError(error_msg) @@ -284,8 +326,12 @@ def is_attr(val): if value is None: return [] return [ - self.container.output(idx, - val if (isinstance(val, dict) or is_attr(val)) and not is_nested else value) + self.container.output( + idx, + val + if (isinstance(val, dict) or is_attr(val)) and not is_nested + else value, + ) for idx, val in enumerate(value) ] @@ -296,83 +342,90 @@ def output(self, key, data, ordered=False, **kwargs): return self.format(value) if value is None: - return self._v('default') + return self._v("default") return [marshal(value, self.container.nested)] def schema(self): schema = super(List, self).schema() - schema.update(minItems=self._v('min_items'), - maxItems=self._v('max_items'), - uniqueItems=self._v('unique')) - schema['type'] = 'array' - schema['items'] = self.container.__schema__ + schema.update( + minItems=self._v("min_items"), + maxItems=self._v("max_items"), + uniqueItems=self._v("unique"), + ) + schema["type"] = "array" + schema["items"] = self.container.__schema__ return schema def clone(self, mask=None): kwargs = self.__dict__.copy() - model = kwargs.pop('container') + model = kwargs.pop("container") if mask: model = mask.apply(model) return self.__class__(model, **kwargs) class StringMixin(object): - __schema_type__ = 'string' + __schema_type__ = "string" def __init__(self, *args, **kwargs): - self.min_length = kwargs.pop('min_length', None) - self.max_length = kwargs.pop('max_length', None) - self.pattern = kwargs.pop('pattern', None) + self.min_length = kwargs.pop("min_length", None) + self.max_length = kwargs.pop("max_length", None) + self.pattern = kwargs.pop("pattern", None) super(StringMixin, self).__init__(*args, **kwargs) def schema(self): schema = super(StringMixin, self).schema() - schema.update(minLength=self._v('min_length'), - maxLength=self._v('max_length'), - pattern=self._v('pattern')) + schema.update( + minLength=self._v("min_length"), + maxLength=self._v("max_length"), + pattern=self._v("pattern"), + ) return schema class MinMaxMixin(object): def __init__(self, *args, **kwargs): - self.minimum = kwargs.pop('min', None) - self.exclusiveMinimum = kwargs.pop('exclusiveMin', None) - self.maximum = kwargs.pop('max', None) - self.exclusiveMaximum = kwargs.pop('exclusiveMax', None) + self.minimum = kwargs.pop("min", None) + self.exclusiveMinimum = kwargs.pop("exclusiveMin", None) + self.maximum = kwargs.pop("max", None) + self.exclusiveMaximum = kwargs.pop("exclusiveMax", None) super(MinMaxMixin, self).__init__(*args, **kwargs) def schema(self): schema = super(MinMaxMixin, self).schema() - schema.update(minimum=self._v('minimum'), - exclusiveMinimum=self._v('exclusiveMinimum'), - maximum=self._v('maximum'), - exclusiveMaximum=self._v('exclusiveMaximum')) + schema.update( + minimum=self._v("minimum"), + exclusiveMinimum=self._v("exclusiveMinimum"), + maximum=self._v("maximum"), + exclusiveMaximum=self._v("exclusiveMaximum"), + ) return schema class NumberMixin(MinMaxMixin): - __schema_type__ = 'number' + __schema_type__ = "number" def __init__(self, *args, **kwargs): - self.multiple = kwargs.pop('multiple', None) + self.multiple = kwargs.pop("multiple", None) super(NumberMixin, self).__init__(*args, **kwargs) def schema(self): schema = super(NumberMixin, self).schema() - schema.update(multipleOf=self._v('multiple')) + schema.update(multipleOf=self._v("multiple")) return schema class String(StringMixin, Raw): - ''' + """ Marshal a value as a string. Uses ``six.text_type`` so values will be converted to :class:`unicode` in python2 and :class:`str` in python3. - ''' + """ + def __init__(self, *args, **kwargs): - self.enum = kwargs.pop('enum', None) - self.discriminator = kwargs.pop('discriminator', None) + self.enum = kwargs.pop("enum", None) + self.discriminator = kwargs.pop("discriminator", None) super(String, self).__init__(*args, **kwargs) self.required = self.discriminator or self.required @@ -383,22 +436,23 @@ def format(self, value): raise MarshallingError(ve) def schema(self): - enum = self._v('enum') + enum = self._v("enum") schema = super(String, self).schema() if enum: schema.update(enum=enum) - if enum and schema['example'] is None: - schema['example'] = enum[0] + if enum and schema["example"] is None: + schema["example"] = enum[0] return schema class Integer(NumberMixin, Raw): - ''' + """ Field for outputting an integer value. :param int default: The default value for the field, if no value is specified. - ''' - __schema_type__ = 'integer' + """ + + __schema_type__ = "integer" def format(self, value): try: @@ -410,11 +464,11 @@ def format(self, value): class Float(NumberMixin, Raw): - ''' + """ A double as IEEE-754 double precision. ex : 3.141592653589793 3.1415926535897933e-06 3.141592653589793e+24 nan inf -inf - ''' + """ def format(self, value): try: @@ -424,11 +478,11 @@ def format(self, value): class Arbitrary(NumberMixin, Raw): - ''' + """ A floating point number with an arbitrary precision. ex: 634271127864378216478362784632784678324.23432 - ''' + """ def format(self, value): return text_type(Decimal(value)) @@ -438,34 +492,36 @@ def format(self, value): class Fixed(NumberMixin, Raw): - ''' + """ A decimal number with a fixed precision. - ''' + """ + def __init__(self, decimals=5, **kwargs): super(Fixed, self).__init__(**kwargs) - self.precision = Decimal('0.' + '0' * (decimals - 1) + '1') + self.precision = Decimal("0." + "0" * (decimals - 1) + "1") def format(self, value): dvalue = Decimal(value) if not dvalue.is_normal() and dvalue != ZERO: - raise MarshallingError('Invalid Fixed precision number.') + raise MarshallingError("Invalid Fixed precision number.") return text_type(dvalue.quantize(self.precision, rounding=ROUND_HALF_EVEN)) class Boolean(Raw): - ''' + """ Field for outputting a boolean value. Empty collections such as ``""``, ``{}``, ``[]``, etc. will be converted to ``False``. - ''' - __schema_type__ = 'boolean' + """ + + __schema_type__ = "boolean" def format(self, value): return boolean(value) class DateTime(MinMaxMixin, Raw): - ''' + """ Return a formatted datetime string in UTC. Supported formats are RFC 822 and ISO 8601. See :func:`email.utils.formatdate` for more info on the RFC 822 format. @@ -473,11 +529,12 @@ class DateTime(MinMaxMixin, Raw): See :meth:`datetime.datetime.isoformat` for more info on the ISO 8601 format. :param str dt_format: ``rfc822`` or ``iso8601`` - ''' - __schema_type__ = 'string' - __schema_format__ = 'date-time' + """ - def __init__(self, dt_format='iso8601', **kwargs): + __schema_type__ = "string" + __schema_format__ = "date-time" + + def __init__(self, dt_format="iso8601", **kwargs): super(DateTime, self).__init__(**kwargs) self.dt_format = dt_format @@ -485,45 +542,47 @@ def parse(self, value): if value is None: return None elif isinstance(value, string_types): - parser = datetime_from_iso8601 if self.dt_format == 'iso8601' else datetime_from_rfc822 + parser = ( + datetime_from_iso8601 + if self.dt_format == "iso8601" + else datetime_from_rfc822 + ) return parser(value) elif isinstance(value, datetime): return value elif isinstance(value, date): return datetime(value.year, value.month, value.day) else: - raise ValueError('Unsupported DateTime format') + raise ValueError("Unsupported DateTime format") def format(self, value): try: value = self.parse(value) - if self.dt_format == 'iso8601': + if self.dt_format == "iso8601": return self.format_iso8601(value) - elif self.dt_format == 'rfc822': + elif self.dt_format == "rfc822": return self.format_rfc822(value) else: - raise MarshallingError( - 'Unsupported date format %s' % self.dt_format - ) + raise MarshallingError("Unsupported date format %s" % self.dt_format) except (AttributeError, ValueError) as e: raise MarshallingError(e) def format_rfc822(self, dt): - ''' + """ Turn a datetime object into a formatted date. :param datetime dt: The datetime to transform :return: A RFC 822 formatted date string - ''' + """ return formatdate(timegm(dt.utctimetuple())) def format_iso8601(self, dt): - ''' + """ Turn a datetime object into an ISO8601 formatted date. :param datetime dt: The datetime to transform :return: A ISO 8601 formatted date string - ''' + """ return dt.isoformat() def _for_schema(self, name): @@ -532,23 +591,24 @@ def _for_schema(self, name): def schema(self): schema = super(DateTime, self).schema() - schema['default'] = self._for_schema('default') - schema['minimum'] = self._for_schema('minimum') - schema['maximum'] = self._for_schema('maximum') + schema["default"] = self._for_schema("default") + schema["minimum"] = self._for_schema("minimum") + schema["maximum"] = self._for_schema("maximum") return schema class Date(DateTime): - ''' + """ Return a formatted date string in UTC in ISO 8601. See :meth:`datetime.date.isoformat` for more info on the ISO 8601 format. - ''' - __schema_format__ = 'date' + """ + + __schema_format__ = "date" def __init__(self, **kwargs): - kwargs.pop('dt_format', None) - super(Date, self).__init__(dt_format='iso8601', **kwargs) + kwargs.pop("dt_format", None) + super(Date, self).__init__(dt_format="iso8601", **kwargs) def parse(self, value): if value is None: @@ -560,17 +620,18 @@ def parse(self, value): elif isinstance(value, date): return value else: - raise ValueError('Unsupported Date format') + raise ValueError("Unsupported Date format") class Url(StringMixin, Raw): - ''' + """ A string representation of a Url :param str endpoint: Endpoint name. If endpoint is ``None``, ``request.endpoint`` is used instead :param bool absolute: If ``True``, ensures that the generated urls will have the hostname included :param str scheme: URL scheme specifier (e.g. ``http``, ``https``) - ''' + """ + def __init__(self, endpoint=None, absolute=False, scheme=None, **kwargs): super(Url, self).__init__(**kwargs) self.endpoint = endpoint @@ -591,7 +652,7 @@ def output(self, key, obj, **kwargs): class FormattedString(StringMixin, Raw): - ''' + """ FormattedString is used to interpolate other values from the response into this field. The syntax for the source string is the same as the string :meth:`~str.format` method from the python @@ -609,7 +670,8 @@ class FormattedString(StringMixin, Raw): marshal(data, fields) :param str src_str: the string to format with the other values from the response. - ''' + """ + def __init__(self, src_str, **kwargs): super(FormattedString, self).__init__(**kwargs) self.src_str = text_type(src_str) @@ -623,24 +685,25 @@ def output(self, key, obj, **kwargs): class ClassName(String): - ''' + """ Return the serialized object class name as string. :param bool dash: If `True`, transform CamelCase to kebab_case. - ''' + """ + def __init__(self, dash=False, **kwargs): super(ClassName, self).__init__(**kwargs) self.dash = dash def output(self, key, obj, **kwargs): classname = obj.__class__.__name__ - if classname == 'dict': - return 'object' + if classname == "dict": + return "object" return camel_to_dash(classname) if self.dash else classname class Polymorph(Nested): - ''' + """ A Nested field handling inheritance. Allows you to specify a mapping between Python classes and fields specifications. @@ -657,7 +720,8 @@ class Polymorph(Nested): }) :param dict mapping: Maps classes to their model/fields representation - ''' + """ + def __init__(self, mapping, required=False, **kwargs): self.mapping = mapping parent = self.resolve_ancestor(list(itervalues(mapping))) @@ -673,49 +737,58 @@ def output(self, key, obj, ordered=False, **kwargs): return self.default # Handle mappings - if not hasattr(value, '__class__'): - raise ValueError('Polymorph field only accept class instances') + if not hasattr(value, "__class__"): + raise ValueError("Polymorph field only accept class instances") - candidates = [fields for cls, fields in iteritems(self.mapping) if type(value) == cls] + candidates = [ + fields for cls, fields in iteritems(self.mapping) if type(value) == cls + ] if len(candidates) <= 0: - raise ValueError('Unknown class: ' + value.__class__.__name__) + raise ValueError("Unknown class: " + value.__class__.__name__) elif len(candidates) > 1: - raise ValueError('Unable to determine a candidate for: ' + value.__class__.__name__) + raise ValueError( + "Unable to determine a candidate for: " + value.__class__.__name__ + ) else: - return marshal(value, candidates[0].resolved, mask=self.mask, ordered=ordered) + return marshal( + value, candidates[0].resolved, mask=self.mask, ordered=ordered + ) def resolve_ancestor(self, models): - ''' + """ Resolve the common ancestor for all models. Assume there is only one common ancestor. - ''' + """ ancestors = [m.ancestors for m in models] candidates = set.intersection(*ancestors) if len(candidates) != 1: field_names = [f.name for f in models] - raise ValueError('Unable to determine the common ancestor for: ' + ', '.join(field_names)) + raise ValueError( + "Unable to determine the common ancestor for: " + ", ".join(field_names) + ) parent_name = candidates.pop() return models[0].get_parent(parent_name) def clone(self, mask=None): data = self.__dict__.copy() - mapping = data.pop('mapping') - for field in ('allow_null', 'model'): + mapping = data.pop("mapping") + for field in ("allow_null", "model"): data.pop(field, None) - data['mask'] = mask + data["mask"] = mask return Polymorph(mapping, **data) class Wildcard(Raw): - ''' + """ Field for marshalling list of "unkown" fields. :param cls_or_instance: The field type the list will contain. - ''' + """ + exclude = set() # cache the flat object _flat = None @@ -725,7 +798,7 @@ class Wildcard(Raw): def __init__(self, cls_or_instance, **kwargs): super(Wildcard, self).__init__(**kwargs) - error_msg = 'The type of the wildcard elements must be a subclass of fields.Raw' + error_msg = "The type of the wildcard elements must be a subclass of fields.Raw" if isinstance(cls_or_instance, type): if not issubclass(cls_or_instance, Raw): raise MarshallingError(error_msg) @@ -746,8 +819,9 @@ def _flatten(self, obj): def __match_attributes(attribute): attr_name, attr_obj = attribute - if inspect.isroutine(attr_obj) or \ - (attr_name.startswith('__') and attr_name.endswith('__')): + if inspect.isroutine(attr_obj) or ( + attr_name.startswith("__") and attr_name.endswith("__") + ): return False return True @@ -780,9 +854,11 @@ def output(self, key, obj, ordered=False): # loop over the whole object every time dropping the # complexity to O(n) (objkey, val) = self._flat.pop() - if objkey not in self._cache and \ - objkey not in self.exclude and \ - re.match(reg, objkey, re.IGNORECASE): + if ( + objkey not in self._cache + and objkey not in self.exclude + and re.match(reg, objkey, re.IGNORECASE) + ): value = val self._cache.add(objkey) self._last = objkey @@ -796,16 +872,21 @@ def output(self, key, obj, ordered=False): return None if isinstance(self.container, Nested): - return marshal(value, self.container.nested, skip_none=self.container.skip_none, ordered=ordered) + return marshal( + value, + self.container.nested, + skip_none=self.container.skip_none, + ordered=ordered, + ) return self.container.format(value) def schema(self): schema = super(Wildcard, self).schema() - schema['type'] = 'object' - schema['additionalProperties'] = self.container.__schema__ + schema["type"] = "object" + schema["additionalProperties"] = self.container.__schema__ return schema def clone(self): kwargs = self.__dict__.copy() - model = kwargs.pop('container') + model = kwargs.pop("container") return self.__class__(model, **kwargs) diff --git a/flask_restx/inputs.py b/flask_restx/inputs.py index 5a9fa6ca..b05532f3 100644 --- a/flask_restx/inputs.py +++ b/flask_restx/inputs.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -''' +""" This module provide some helpers for advanced types parsing. You can define you own parser using the same pattern: @@ -15,7 +15,7 @@ def my_type(value): my_type.__schema__ = {'type': 'string', 'format': 'my-custom-format'} The last line allows you to document properly the type in the Swagger documentation. -''' +""" from __future__ import unicode_literals import re @@ -34,55 +34,55 @@ def my_type(value): netloc_regex = re.compile( - r'(?:(?P[^:@]+?(?::[^:@]*?)?)@)?' # basic auth - r'(?:' - r'(?Plocalhost)|' # localhost... - r'(?P\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})|' # ...or ipv4 - r'(?:\[?(?P[A-F0-9]*:[A-F0-9:]+)\]?)|' # ...or ipv6 - r'(?P(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?))' # domain... - r')' - r'(?::(?P\d+))?' # optional port - r'$', re.IGNORECASE) + r"(?:(?P[^:@]+?(?::[^:@]*?)?)@)?" # basic auth + r"(?:" + r"(?Plocalhost)|" # localhost... + r"(?P\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})|" # ...or ipv4 + r"(?:\[?(?P[A-F0-9]*:[A-F0-9:]+)\]?)|" # ...or ipv6 + r"(?P(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?))" # domain... + r")" + r"(?::(?P\d+))?" # optional port + r"$", + re.IGNORECASE, +) email_regex = re.compile( - r'^' - '(?P[^@]*[^@.])' - r'@' - r'(?P[^@]+(?:\.[^@]+)*)' - r'$', re.IGNORECASE) + r"^" "(?P[^@]*[^@.])" r"@" r"(?P[^@]+(?:\.[^@]+)*)" r"$", + re.IGNORECASE, +) -time_regex = re.compile(r'\d{2}:\d{2}') +time_regex = re.compile(r"\d{2}:\d{2}") def ipv4(value): - '''Validate an IPv4 address''' + """Validate an IPv4 address""" try: socket.inet_aton(value) - if value.count('.') == 3: + if value.count(".") == 3: return value except socket.error: pass - raise ValueError('{0} is not a valid ipv4 address'.format(value)) + raise ValueError("{0} is not a valid ipv4 address".format(value)) -ipv4.__schema__ = {'type': 'string', 'format': 'ipv4'} +ipv4.__schema__ = {"type": "string", "format": "ipv4"} def ipv6(value): - '''Validate an IPv6 address''' + """Validate an IPv6 address""" try: socket.inet_pton(socket.AF_INET6, value) return value except socket.error: - raise ValueError('{0} is not a valid ipv4 address'.format(value)) + raise ValueError("{0} is not a valid ipv4 address".format(value)) -ipv6.__schema__ = {'type': 'string', 'format': 'ipv6'} +ipv6.__schema__ = {"type": "string", "format": "ipv6"} def ip(value): - '''Validate an IP address (both IPv4 and IPv6)''' + """Validate an IP address (both IPv4 and IPv6)""" try: return ipv4(value) except ValueError: @@ -90,14 +90,14 @@ def ip(value): try: return ipv6(value) except ValueError: - raise ValueError('{0} is not a valid ip'.format(value)) + raise ValueError("{0} is not a valid ip".format(value)) -ip.__schema__ = {'type': 'string', 'format': 'ip'} +ip.__schema__ = {"type": "string", "format": "ip"} class URL(object): - ''' + """ Validate an URL. Example:: @@ -117,9 +117,19 @@ class URL(object): :param list|tuple schemes: Restrict valid schemes to this list :param list|tuple domains: Restrict valid domains to this list :param list|tuple exclude: Exclude some domains - ''' - def __init__(self, check=False, ip=False, local=False, port=False, auth=False, - schemes=None, domains=None, exclude=None): + """ + + def __init__( + self, + check=False, + ip=False, + local=False, + port=False, + auth=False, + schemes=None, + domains=None, + exclude=None, + ): self.check = check self.ip = ip self.local = local @@ -130,66 +140,68 @@ def __init__(self, check=False, ip=False, local=False, port=False, auth=False, self.exclude = exclude def error(self, value, details=None): - msg = '{0} is not a valid URL' + msg = "{0} is not a valid URL" if details: - msg = '. '.join((msg, details)) + msg = ". ".join((msg, details)) raise ValueError(msg.format(value)) def __call__(self, value): parsed = urlparse(value) netloc_match = netloc_regex.match(parsed.netloc) if not all((parsed.scheme, parsed.netloc)): - if netloc_regex.match(parsed.netloc or parsed.path.split('/', 1)[0].split('?', 1)[0]): - self.error(value, 'Did you mean: http://{0}') + if netloc_regex.match( + parsed.netloc or parsed.path.split("/", 1)[0].split("?", 1)[0] + ): + self.error(value, "Did you mean: http://{0}") self.error(value) if parsed.scheme and self.schemes and parsed.scheme not in self.schemes: - self.error(value, 'Protocol is not allowed') + self.error(value, "Protocol is not allowed") if not netloc_match: self.error(value) data = netloc_match.groupdict() - if data['ipv4'] or data['ipv6']: + if data["ipv4"] or data["ipv6"]: if not self.ip: - self.error(value, 'IP is not allowed') + self.error(value, "IP is not allowed") else: try: - ip(data['ipv4'] or data['ipv6']) + ip(data["ipv4"] or data["ipv6"]) except ValueError as e: self.error(value, str(e)) if not self.local: - if data['ipv4'] and data['ipv4'].startswith('127.'): - self.error(value, 'Localhost is not allowed') - elif data['ipv6'] == '::1': - self.error(value, 'Localhost is not allowed') + if data["ipv4"] and data["ipv4"].startswith("127."): + self.error(value, "Localhost is not allowed") + elif data["ipv6"] == "::1": + self.error(value, "Localhost is not allowed") if self.check: pass - if data['auth'] and not self.auth: - self.error(value, 'Authentication is not allowed') - if data['localhost'] and not self.local: - self.error(value, 'Localhost is not allowed') - if data['port']: + if data["auth"] and not self.auth: + self.error(value, "Authentication is not allowed") + if data["localhost"] and not self.local: + self.error(value, "Localhost is not allowed") + if data["port"]: if not self.port: - self.error(value, 'Custom port is not allowed') + self.error(value, "Custom port is not allowed") else: - port = int(data['port']) + port = int(data["port"]) if not 0 < port < 65535: - self.error(value, 'Port is out of range') - if data['domain']: - if self.domains and data['domain'] not in self.domains: - self.error(value, 'Domain is not allowed') - elif self.exclude and data['domain'] in self.exclude: - self.error(value, 'Domain is not allowed') + self.error(value, "Port is out of range") + if data["domain"]: + if self.domains and data["domain"] not in self.domains: + self.error(value, "Domain is not allowed") + elif self.exclude and data["domain"] in self.exclude: + self.error(value, "Domain is not allowed") if self.check: try: - socket.getaddrinfo(data['domain'], None) + socket.getaddrinfo(data["domain"], None) except socket.error: - self.error(value, 'Domain does not exists') + self.error(value, "Domain does not exists") return value @property def __schema__(self): return { - 'type': 'string', - 'format': 'url', + "type": "string", + "format": "url", } @@ -197,11 +209,13 @@ def __schema__(self): #: #: Legacy validator, allows, auth, port, ip and local #: Only allows schemes 'http', 'https', 'ftp' and 'ftps' -url = URL(ip=True, auth=True, port=True, local=True, schemes=('http', 'https', 'ftp', 'ftps')) +url = URL( + ip=True, auth=True, port=True, local=True, schemes=("http", "https", "ftp", "ftps") +) class email(object): - ''' + """ Validate an email. Example:: @@ -217,7 +231,8 @@ class email(object): :param bool local: Allow localhost (both string or ip) as domain :param list|tuple domains: Restrict valid domains to this list :param list|tuple exclude: Exclude some domains - ''' + """ + def __init__(self, check=False, ip=False, local=False, domains=None, exclude=None): self.check = check self.ip = ip @@ -226,7 +241,7 @@ def __init__(self, check=False, ip=False, local=False, domains=None, exclude=Non self.exclude = exclude def error(self, value, msg=None): - msg = msg or '{0} is not a valid email' + msg = msg or "{0} is not a valid email" raise ValueError(msg.format(value)) def is_ip(self, value): @@ -238,19 +253,21 @@ def is_ip(self, value): def __call__(self, value): match = email_regex.match(value) - if not match or '..' in value: + if not match or ".." in value: self.error(value) - server = match.group('server') + server = match.group("server") if self.check: try: socket.getaddrinfo(server, None) except socket.error: self.error(value) if self.domains and server not in self.domains: - self.error(value, '{0} does not belong to the authorized domains') + self.error(value, "{0} does not belong to the authorized domains") if self.exclude and server in self.exclude: - self.error(value, '{0} belongs to a forbidden domain') - if not self.local and (server in ('localhost', '::1') or server.startswith('127.')): + self.error(value, "{0} belongs to a forbidden domain") + if not self.local and ( + server in ("localhost", "::1") or server.startswith("127.") + ): self.error(value) if self.is_ip(server) and not self.ip: self.error(value) @@ -259,13 +276,13 @@ def __call__(self, value): @property def __schema__(self): return { - 'type': 'string', - 'format': 'email', + "type": "string", + "format": "email", } class regex(object): - ''' + """ Validate a string based on a regular expression. Example:: @@ -277,7 +294,7 @@ class regex(object): but numbers. :param str pattern: The regular expression the input must match - ''' + """ def __init__(self, pattern): self.pattern = pattern @@ -295,13 +312,13 @@ def __deepcopy__(self, memo): @property def __schema__(self): return { - 'type': 'string', - 'pattern': self.pattern, + "type": "string", + "pattern": self.pattern, } def _normalize_interval(start, end, value): - ''' + """ Normalize datetime intervals. Given a pair of datetime.date or datetime.datetime objects, @@ -317,7 +334,7 @@ def _normalize_interval(start, end, value): Params: - start: A date or datetime - end: A date or datetime - ''' + """ if not isinstance(start, datetime): start = datetime.combine(start, START_OF_DAY) end = datetime.combine(end, START_OF_DAY) @@ -340,9 +357,9 @@ def _expand_datetime(start, value): else: # Expand a datetime based on the finest resolution provided # in the original input string. - time = value.split('T')[1] - time_without_offset = re.sub('[+-].+', '', time) - num_separators = time_without_offset.count(':') + time = value.split("T")[1] + time_without_offset = re.sub("[+-].+", "", time) + num_separators = time_without_offset.count(":") if num_separators == 0: # Hour resolution end = start + timedelta(hours=1) @@ -357,10 +374,10 @@ def _expand_datetime(start, value): def _parse_interval(value): - ''' + """ Do some nasty try/except voodoo to get some sort of datetime object(s) out of the string. - ''' + """ try: return sorted(aniso8601.parse_interval(value)) except ValueError: @@ -370,8 +387,8 @@ def _parse_interval(value): return aniso8601.parse_date(value), None -def iso8601interval(value, argument='argument'): - ''' +def iso8601interval(value, argument="argument"): + """ Parses ISO 8601-formatted datetime intervals into tuples of datetimes. Accepts both a single date(time) or a full interval using either start/end @@ -397,9 +414,9 @@ def iso8601interval(value, argument='argument'): :return: Two UTC datetimes, the start and the end of the specified interval :rtype: A tuple (datetime, datetime) :raises ValueError: if the interval is invalid. - ''' + """ if not value: - raise ValueError('Expected a valid ISO8601 date/time interval.') + raise ValueError("Expected a valid ISO8601 date/time interval.") try: start, end = _parse_interval(value) @@ -410,58 +427,61 @@ def iso8601interval(value, argument='argument'): start, end = _normalize_interval(start, end, value) except ValueError: - msg = 'Invalid {arg}: {value}. {arg} must be a valid ISO8601 date/time interval.' + msg = ( + "Invalid {arg}: {value}. {arg} must be a valid ISO8601 date/time interval." + ) raise ValueError(msg.format(arg=argument, value=value)) return start, end -iso8601interval.__schema__ = {'type': 'string', 'format': 'iso8601-interval'} +iso8601interval.__schema__ = {"type": "string", "format": "iso8601-interval"} def date(value): - '''Parse a valid looking date in the format YYYY-mm-dd''' + """Parse a valid looking date in the format YYYY-mm-dd""" date = datetime.strptime(value, "%Y-%m-%d") return date -date.__schema__ = {'type': 'string', 'format': 'date'} +date.__schema__ = {"type": "string", "format": "date"} def _get_integer(value): try: return int(value) except (TypeError, ValueError): - raise ValueError('{0} is not a valid integer'.format(value)) + raise ValueError("{0} is not a valid integer".format(value)) -def natural(value, argument='argument'): - '''Restrict input type to the natural numbers (0, 1, 2, 3...)''' +def natural(value, argument="argument"): + """Restrict input type to the natural numbers (0, 1, 2, 3...)""" value = _get_integer(value) if value < 0: - msg = 'Invalid {arg}: {value}. {arg} must be a non-negative integer' + msg = "Invalid {arg}: {value}. {arg} must be a non-negative integer" raise ValueError(msg.format(arg=argument, value=value)) return value -natural.__schema__ = {'type': 'integer', 'minimum': 0} +natural.__schema__ = {"type": "integer", "minimum": 0} -def positive(value, argument='argument'): - '''Restrict input type to the positive integers (1, 2, 3...)''' +def positive(value, argument="argument"): + """Restrict input type to the positive integers (1, 2, 3...)""" value = _get_integer(value) if value < 1: - msg = 'Invalid {arg}: {value}. {arg} must be a positive integer' + msg = "Invalid {arg}: {value}. {arg} must be a positive integer" raise ValueError(msg.format(arg=argument, value=value)) return value -positive.__schema__ = {'type': 'integer', 'minimum': 0, 'exclusiveMinimum': True} +positive.__schema__ = {"type": "integer", "minimum": 0, "exclusiveMinimum": True} class int_range(object): - '''Restrict input to an integer in a range (inclusive)''' - def __init__(self, low, high, argument='argument'): + """Restrict input to an integer in a range (inclusive)""" + + def __init__(self, low, high, argument="argument"): self.low = low self.high = high self.argument = argument @@ -469,21 +489,23 @@ def __init__(self, low, high, argument='argument'): def __call__(self, value): value = _get_integer(value) if value < self.low or value > self.high: - msg = 'Invalid {arg}: {val}. {arg} must be within the range {lo} - {hi}' - raise ValueError(msg.format(arg=self.argument, val=value, lo=self.low, hi=self.high)) + msg = "Invalid {arg}: {val}. {arg} must be within the range {lo} - {hi}" + raise ValueError( + msg.format(arg=self.argument, val=value, lo=self.low, hi=self.high) + ) return value @property def __schema__(self): return { - 'type': 'integer', - 'minimum': self.low, - 'maximum': self.high, + "type": "integer", + "minimum": self.low, + "maximum": self.high, } def boolean(value): - ''' + """ Parse the string ``"true"`` or ``"false"`` as a boolean (case insensitive). Also accepts ``"1"`` and ``"0"`` as ``True``/``False`` (respectively). @@ -492,27 +514,27 @@ def boolean(value): and will be passed through without further parsing. :raises ValueError: if the boolean value is invalid - ''' + """ if isinstance(value, bool): return value if value is None: - raise ValueError('boolean type must be non-null') + raise ValueError("boolean type must be non-null") elif not value: return False value = str(value).lower() - if value in ('true', '1', 'on',): + if value in ("true", "1", "on",): return True - if value in ('false', '0',): + if value in ("false", "0",): return False - raise ValueError('Invalid literal for boolean(): {0}'.format(value)) + raise ValueError("Invalid literal for boolean(): {0}".format(value)) -boolean.__schema__ = {'type': 'boolean'} +boolean.__schema__ = {"type": "boolean"} def datetime_from_rfc822(value): - ''' + """ Turns an RFC822 formatted date into a datetime object. Example:: @@ -524,10 +546,10 @@ def datetime_from_rfc822(value): :rtype: datetime :raises ValueError: if value is an invalid date literal - ''' + """ raw = value if not time_regex.search(value): - value = ' '.join((value, '00:00:00')) + value = " ".join((value, "00:00:00")) try: timetuple = parsedate_tz(value) timestamp = mktime_tz(timetuple) @@ -540,7 +562,7 @@ def datetime_from_rfc822(value): def datetime_from_iso8601(value): - ''' + """ Turns an ISO8601 formatted date into a datetime object. Example:: @@ -552,7 +574,7 @@ def datetime_from_iso8601(value): :rtype: datetime :raises ValueError: if value is an invalid date literal - ''' + """ try: try: return aniso8601.parse_datetime(value) @@ -563,11 +585,11 @@ def datetime_from_iso8601(value): raise ValueError('Invalid date literal "{0}"'.format(value)) -datetime_from_iso8601.__schema__ = {'type': 'string', 'format': 'date-time'} +datetime_from_iso8601.__schema__ = {"type": "string", "format": "date-time"} def date_from_iso8601(value): - ''' + """ Turns an ISO8601 formatted date into a date object. Example:: @@ -581,8 +603,8 @@ def date_from_iso8601(value): :rtype: date :raises ValueError: if value is an invalid date literal - ''' + """ return datetime_from_iso8601(value).date() -date_from_iso8601.__schema__ = {'type': 'string', 'format': 'date'} +date_from_iso8601.__schema__ = {"type": "string", "format": "date"} diff --git a/flask_restx/marshalling.py b/flask_restx/marshalling.py index 3d435d0b..1c1e9c7c 100644 --- a/flask_restx/marshalling.py +++ b/flask_restx/marshalling.py @@ -88,8 +88,9 @@ def _append(k, v): _append(key, value) while True: value = field.output(dkey, data, ordered=ordered) - if value is None or \ - value == field.container.format(field.default): + if value is None or value == field.container.format( + field.default + ): break key = field.key _append(key, value) @@ -153,8 +154,8 @@ def _marshal(data, fields, envelope=None, skip_none=False, mask=None, ordered=Fa # ugly local import to avoid dependency loop from .fields import Wildcard - mask = mask or getattr(fields, '__mask__', None) - fields = getattr(fields, 'resolved', fields) + mask = mask or getattr(fields, "__mask__", None) + fields = getattr(fields, "resolved", fields) if mask: fields = apply_mask(fields, mask, skip=True) @@ -164,12 +165,12 @@ def _marshal(data, fields, envelope=None, skip_none=False, mask=None, ordered=Fa out = OrderedDict([(envelope, out)]) if ordered else {envelope: out} return out, False - has_wildcards = {'present': False} + has_wildcards = {"present": False} def __format_field(key, val): field = make(val) if isinstance(field, Wildcard): - has_wildcards['present'] = True + has_wildcards["present"] = True value = field.output(key, data, ordered=ordered) return (key, value) @@ -181,15 +182,16 @@ def __format_field(key, val): ) if skip_none: - items = ((k, v) for k, v in items - if v is not None and v != OrderedDict() and v != {}) + items = ( + (k, v) for k, v in items if v is not None and v != OrderedDict() and v != {} + ) out = OrderedDict(items) if ordered else dict(items) if envelope: out = OrderedDict([(envelope, out)]) if ordered else {envelope: out} - return out, has_wildcards['present'] + return out, has_wildcards["present"] class marshal_with(object): @@ -224,7 +226,10 @@ class marshal_with(object): see :meth:`flask_restx.marshal` """ - def __init__(self, fields, envelope=None, skip_none=False, mask=None, ordered=False): + + def __init__( + self, fields, envelope=None, skip_none=False, mask=None, ordered=False + ): """ :param fields: a dict of whose keys will make up the final serialized response output @@ -243,17 +248,27 @@ def wrapper(*args, **kwargs): resp = f(*args, **kwargs) mask = self.mask if has_app_context(): - mask_header = current_app.config['RESTX_MASK_HEADER'] + mask_header = current_app.config["RESTX_MASK_HEADER"] mask = request.headers.get(mask_header) or mask if isinstance(resp, tuple): data, code, headers = unpack(resp) return ( - marshal(data, self.fields, self.envelope, self.skip_none, mask, self.ordered), + marshal( + data, + self.fields, + self.envelope, + self.skip_none, + mask, + self.ordered, + ), code, - headers + headers, ) else: - return marshal(resp, self.fields, self.envelope, self.skip_none, mask, self.ordered) + return marshal( + resp, self.fields, self.envelope, self.skip_none, mask, self.ordered + ) + return wrapper @@ -271,6 +286,7 @@ class marshal_with_field(object): see :meth:`flask_restx.marshal_with` """ + def __init__(self, field): """ :param field: a single field with which to marshal the output. diff --git a/flask_restx/mask.py b/flask_restx/mask.py index e4c1670c..1784d4ec 100644 --- a/flask_restx/mask.py +++ b/flask_restx/mask.py @@ -12,26 +12,29 @@ log = logging.getLogger(__name__) -LEXER = re.compile(r'\{|\}|\,|[\w_:\-\*]+') +LEXER = re.compile(r"\{|\}|\,|[\w_:\-\*]+") class MaskError(RestError): - '''Raised when an error occurs on mask''' + """Raised when an error occurs on mask""" + pass class ParseError(MaskError): - '''Raised when the mask parsing failed''' + """Raised when the mask parsing failed""" + pass class Mask(OrderedDict): - ''' + """ Hold a parsed mask. :param str|dict|Mask mask: A mask, parsed or not :param bool skip: If ``True``, missing fields won't appear in result - ''' + """ + def __init__(self, mask=None, skip=False, **kwargs): self.skip = skip if isinstance(mask, six.string_types): @@ -44,7 +47,7 @@ def __init__(self, mask=None, skip=False, **kwargs): super(Mask, self).__init__(**kwargs) def parse(self, mask): - ''' + """ Parse a fields mask. Expect something in the form:: @@ -59,7 +62,7 @@ def parse(self, mask): :param str mask: the mask string to parse :raises ParseError: when a mask is unparseable/invalid - ''' + """ if not mask: return @@ -69,46 +72,47 @@ def parse(self, mask): stack = [] for token in LEXER.findall(mask): - if token == '{': + if token == "{": if previous not in fields: - raise ParseError('Unexpected opening bracket') + raise ParseError("Unexpected opening bracket") fields[previous] = Mask(skip=self.skip) stack.append(fields) fields = fields[previous] - elif token == '}': + elif token == "}": if not stack: - raise ParseError('Unexpected closing bracket') + raise ParseError("Unexpected closing bracket") fields = stack.pop() - elif token == ',': - if previous in (',', '{', None): - raise ParseError('Unexpected comma') + elif token == ",": + if previous in (",", "{", None): + raise ParseError("Unexpected comma") else: fields[token] = True previous = token if stack: - raise ParseError('Missing closing bracket') + raise ParseError("Missing closing bracket") def clean(self, mask): - '''Remove unnecessary characters''' - mask = mask.replace('\n', '').strip() + """Remove unnecessary characters""" + mask = mask.replace("\n", "").strip() # External brackets are optional - if mask[0] == '{': - if mask[-1] != '}': - raise ParseError('Missing closing bracket') + if mask[0] == "{": + if mask[-1] != "}": + raise ParseError("Missing closing bracket") mask = mask[1:-1] return mask def apply(self, data): - ''' + """ Apply a fields mask to the data. :param data: The data or model to apply mask on :raises MaskError: when unable to apply the mask - ''' + """ from . import fields + # Should handle lists if isinstance(data, (list, tuple, set)): return [self.apply(d) for d in data] @@ -118,27 +122,31 @@ def apply(self, data): return fields.Raw(default=data.default, attribute=data.attribute, mask=self) elif data == fields.Raw: return fields.Raw(mask=self) - elif isinstance(data, fields.Raw) or isclass(data) and issubclass(data, fields.Raw): + elif ( + isinstance(data, fields.Raw) + or isclass(data) + and issubclass(data, fields.Raw) + ): # Not possible to apply a mask on these remaining fields types - raise MaskError('Mask is inconsistent with model') + raise MaskError("Mask is inconsistent with model") # Should handle objects - elif (not isinstance(data, (dict, OrderedDict)) and hasattr(data, '__dict__')): + elif not isinstance(data, (dict, OrderedDict)) and hasattr(data, "__dict__"): data = data.__dict__ return self.filter_data(data) def filter_data(self, data): - ''' + """ Handle the data filtering given a parsed mask :param dict data: the raw data to filter :param list mask: a parsed mask to filter against :param bool skip: whether or not to skip missing fields - ''' + """ out = {} for field, content in six.iteritems(self): - if field == '*': + if field == "*": continue elif isinstance(content, Mask): nested = data.get(field, None) @@ -153,21 +161,25 @@ def filter_data(self, data): else: out[field] = data.get(field, None) - if '*' in self.keys(): + if "*" in self.keys(): for key, value in six.iteritems(data): if key not in out: out[key] = value return out def __str__(self): - return '{{{0}}}'.format(','.join([ - ''.join((k, str(v))) if isinstance(v, Mask) else k - for k, v in six.iteritems(self) - ])) + return "{{{0}}}".format( + ",".join( + [ + "".join((k, str(v))) if isinstance(v, Mask) else k + for k, v in six.iteritems(self) + ] + ) + ) def apply(data, mask, skip=False): - ''' + """ Apply a fields mask to the data. :param data: The data or model to apply mask on @@ -175,5 +187,5 @@ def apply(data, mask, skip=False): :param bool skip: If rue, missing field won't appear in result :raises MaskError: when unable to apply the mask - ''' + """ return Mask(mask, skip).apply(data) diff --git a/flask_restx/model.py b/flask_restx/model.py index 6da9dbcd..b298c2c5 100644 --- a/flask_restx/model.py +++ b/flask_restx/model.py @@ -6,6 +6,7 @@ import warnings from collections import OrderedDict + try: from collections.abc import MutableMapping except ImportError: @@ -24,7 +25,7 @@ from ._http import HTTPStatus -RE_REQUIRED = re.compile(r'u?\'(?P.*)\' is a required property', re.I | re.U) +RE_REQUIRED = re.compile(r"u?\'(?P.*)\' is a required property", re.I | re.U) def instance(cls): @@ -34,18 +35,16 @@ def instance(cls): class ModelBase(object): - ''' + """ Handles validation and swagger style inheritance for both subclasses. Subclass must define `schema` attribute. :param str name: The model public name - ''' + """ def __init__(self, name, *args, **kwargs): super(ModelBase, self).__init__(*args, **kwargs) - self.__apidoc__ = { - 'name': name - } + self.__apidoc__ = {"name": name} self.name = name self.__parents__ = [] @@ -56,9 +55,9 @@ def instance_inherit(name, *parents): @property def ancestors(self): - ''' + """ Return the ancestors tree - ''' + """ ancestors = [p.ancestors for p in self.__parents__] return set.union(set([self.name]), *ancestors) @@ -70,7 +69,7 @@ def get_parent(self, name): found = parent.get_parent(name) if found: return found - raise ValueError('Parent ' + name + ' not found') + raise ValueError("Parent " + name + " not found") @property def __schema__(self): @@ -78,68 +77,74 @@ def __schema__(self): if self.__parents__: refs = [ - {'$ref': '#/definitions/{0}'.format(parent.name)} + {"$ref": "#/definitions/{0}".format(parent.name)} for parent in self.__parents__ ] - return { - 'allOf': refs + [schema] - } + return {"allOf": refs + [schema]} else: return schema @classmethod def inherit(cls, name, *parents): - ''' + """ Inherit this model (use the Swagger composition pattern aka. allOf) :param str name: The new model name :param dict fields: The new model extra fields - ''' + """ model = cls(name, parents[-1]) model.__parents__ = parents[:-1] return model def validate(self, data, resolver=None, format_checker=None): - validator = Draft4Validator(self.__schema__, resolver=resolver, format_checker=format_checker) + validator = Draft4Validator( + self.__schema__, resolver=resolver, format_checker=format_checker + ) try: validator.validate(data) except ValidationError: - abort(HTTPStatus.BAD_REQUEST, message='Input payload validation failed', - errors=dict(self.format_error(e) for e in validator.iter_errors(data))) + abort( + HTTPStatus.BAD_REQUEST, + message="Input payload validation failed", + errors=dict(self.format_error(e) for e in validator.iter_errors(data)), + ) def format_error(self, error): path = list(error.path) - if error.validator == 'required': - name = RE_REQUIRED.match(error.message).group('name') + if error.validator == "required": + name = RE_REQUIRED.match(error.message).group("name") path.append(name) - key = '.'.join(str(p) for p in path) + key = ".".join(str(p) for p in path) return key, error.message def __unicode__(self): - return 'Model({name},{{{fields}}})'.format(name=self.name, fields=','.join(self.keys())) + return "Model({name},{{{fields}}})".format( + name=self.name, fields=",".join(self.keys()) + ) __str__ = __unicode__ class RawModel(ModelBase): - ''' + """ A thin wrapper on ordered fields dict to store API doc metadata. Can also be used for response marshalling. :param str name: The model public name :param str mask: an optional default model mask - ''' + """ wrapper = dict def __init__(self, name, *args, **kwargs): - self.__mask__ = kwargs.pop('mask', None) + self.__mask__ = kwargs.pop("mask", None) if self.__mask__ and not isinstance(self.__mask__, Mask): self.__mask__ = Mask(self.__mask__) super(RawModel, self).__init__(name, *args, **kwargs) def instance_clone(name, *parents): return self.__class__.clone(name, self, *parents) + self.clone = instance_clone @property @@ -152,22 +157,24 @@ def _schema(self): properties[name] = field.__schema__ if field.required: required.add(name) - if getattr(field, 'discriminator', False): + if getattr(field, "discriminator", False): discriminator = name - return not_none({ - 'required': sorted(list(required)) or None, - 'properties': properties, - 'discriminator': discriminator, - 'x-mask': str(self.__mask__) if self.__mask__ else None, - 'type': 'object', - }) + return not_none( + { + "required": sorted(list(required)) or None, + "properties": properties, + "discriminator": discriminator, + "x-mask": str(self.__mask__) if self.__mask__ else None, + "type": "object", + } + ) @cached_property def resolved(self): - ''' + """ Resolve real fields before submitting them to marshal - ''' + """ # Duplicate fields resolved = copy.deepcopy(self) @@ -176,10 +183,12 @@ def resolved(self): resolved.update(parent.resolved) # Handle discriminator - candidates = [f for f in itervalues(resolved) if getattr(f, 'discriminator', None)] + candidates = [ + f for f in itervalues(resolved) if getattr(f, "discriminator", None) + ] # Ensure the is only one discriminator if len(candidates) > 1: - raise ValueError('There can only be one discriminator by schema') + raise ValueError("There can only be one discriminator by schema") # Ensure discriminator always output the model name elif len(candidates) == 1: candidates[0].default = self.name @@ -187,15 +196,19 @@ def resolved(self): return resolved def extend(self, name, fields): - ''' + """ Extend this model (Duplicate all fields) :param str name: The new model name :param dict fields: The new model extra fields :deprecated: since 0.9. Use :meth:`clone` instead. - ''' - warnings.warn('extend is is deprecated, use clone instead', DeprecationWarning, stacklevel=2) + """ + warnings.warn( + "extend is is deprecated, use clone instead", + DeprecationWarning, + stacklevel=2, + ) if isinstance(fields, (list, tuple)): return self.clone(name, *fields) else: @@ -203,7 +216,7 @@ def extend(self, name, fields): @classmethod def clone(cls, name, *parents): - ''' + """ Clone these models (Duplicate all fields) It can be used from the class @@ -216,55 +229,61 @@ def clone(cls, name, *parents): :param str name: The new model name :param dict parents: The new model extra fields - ''' + """ fields = cls.wrapper() for parent in parents: fields.update(copy.deepcopy(parent)) return cls(name, fields) def __deepcopy__(self, memo): - obj = self.__class__(self.name, - [(key, copy.deepcopy(value, memo)) for key, value in iteritems(self)], - mask=self.__mask__) + obj = self.__class__( + self.name, + [(key, copy.deepcopy(value, memo)) for key, value in iteritems(self)], + mask=self.__mask__, + ) obj.__parents__ = self.__parents__ return obj class Model(RawModel, dict, MutableMapping): - ''' + """ A thin wrapper on fields dict to store API doc metadata. Can also be used for response marshalling. :param str name: The model public name :param str mask: an optional default model mask - ''' + """ + pass class OrderedModel(RawModel, OrderedDict, MutableMapping): - ''' + """ A thin wrapper on ordered fields dict to store API doc metadata. Can also be used for response marshalling. :param str name: The model public name :param str mask: an optional default model mask - ''' + """ + wrapper = OrderedDict class SchemaModel(ModelBase): - ''' + """ Stores API doc metadata based on a json schema. :param str name: The model public name :param dict schema: The json schema we are documenting - ''' + """ def __init__(self, name, schema=None): super(SchemaModel, self).__init__(name) self._schema = schema or {} def __unicode__(self): - return 'SchemaModel({name},{schema})'.format(name=self.name, schema=self._schema) + return "SchemaModel({name},{schema})".format( + name=self.name, schema=self._schema + ) __str__ = __unicode__ diff --git a/flask_restx/namespace.py b/flask_restx/namespace.py index 57b1528e..9e316476 100644 --- a/flask_restx/namespace.py +++ b/flask_restx/namespace.py @@ -22,7 +22,7 @@ class Namespace(object): - ''' + """ Group resources together. Namespace is to API what :class:`flask:flask.Blueprint` is for :class:`flask:flask.Flask`. @@ -34,9 +34,19 @@ class Namespace(object): :param bool validate: Whether or not to perform validation on this namespace :param bool ordered: Whether or not to preserve order on models and marshalling :param Api api: an optional API to attache to the namespace - ''' - def __init__(self, name, description=None, path=None, decorators=None, validate=None, - authorizations=None, ordered=False, **kwargs): + """ + + def __init__( + self, + name, + description=None, + path=None, + decorators=None, + validate=None, + authorizations=None, + ordered=False, + **kwargs + ): self.name = name self.description = description self._path = path @@ -52,16 +62,16 @@ def __init__(self, name, description=None, path=None, decorators=None, validate= self.authorizations = authorizations self.ordered = ordered self.apis = [] - if 'api' in kwargs: - self.apis.append(kwargs['api']) + if "api" in kwargs: + self.apis.append(kwargs["api"]) self.logger = logging.getLogger(__name__ + "." + self.name) @property def path(self): - return (self._path or ('/' + self.name)).rstrip('/') + return (self._path or ("/" + self.name)).rstrip("/") def add_resource(self, resource, *urls, **kwargs): - ''' + """ Register a Resource for a given API Namespace :param Resource resource: the resource ro register @@ -81,24 +91,26 @@ def add_resource(self, resource, *urls, **kwargs): namespace.add_resource(HelloWorld, '/', '/hello') namespace.add_resource(Foo, '/foo', endpoint="foo") namespace.add_resource(FooSpecial, '/special/foo', endpoint="foo") - ''' - route_doc = kwargs.pop('route_doc', {}) + """ + route_doc = kwargs.pop("route_doc", {}) self.resources.append(ResourceRoute(resource, urls, route_doc, kwargs)) for api in self.apis: ns_urls = api.ns_urls(self, urls) api.register_resource(self, resource, *ns_urls, **kwargs) def route(self, *urls, **kwargs): - ''' + """ A decorator to route resources. - ''' + """ + def wrapper(cls): - doc = kwargs.pop('doc', None) + doc = kwargs.pop("doc", None) if doc is not None: # build api doc intended only for this route - kwargs['route_doc'] = self._build_doc(cls, doc) + kwargs["route_doc"] = self._build_doc(cls, doc) self.add_resource(cls, *urls, **kwargs) return cls + return wrapper def _build_doc(self, cls, doc): @@ -112,34 +124,36 @@ def _build_doc(self, cls, doc): continue unshortcut_params_description(doc[http_method]) handle_deprecations(doc[http_method]) - if 'expect' in doc[http_method] and not isinstance(doc[http_method]['expect'], (list, tuple)): - doc[http_method]['expect'] = [doc[http_method]['expect']] - return merge(getattr(cls, '__apidoc__', {}), doc) + if "expect" in doc[http_method] and not isinstance( + doc[http_method]["expect"], (list, tuple) + ): + doc[http_method]["expect"] = [doc[http_method]["expect"]] + return merge(getattr(cls, "__apidoc__", {}), doc) def doc(self, shortcut=None, **kwargs): - '''A decorator to add some api documentation to the decorated object''' + """A decorator to add some api documentation to the decorated object""" if isinstance(shortcut, six.text_type): - kwargs['id'] = shortcut + kwargs["id"] = shortcut show = shortcut if isinstance(shortcut, bool) else True def wrapper(documented): documented.__apidoc__ = self._build_doc( - documented, - kwargs if show else False + documented, kwargs if show else False ) return documented + return wrapper def hide(self, func): - '''A decorator to hide a resource or a method from specifications''' + """A decorator to hide a resource or a method from specifications""" return self.doc(False)(func) def abort(self, *args, **kwargs): - ''' + """ Properly abort the current request See: :func:`~flask_restx.errors.abort` - ''' + """ abort(*args, **kwargs) def add_model(self, name, definition): @@ -149,31 +163,31 @@ def add_model(self, name, definition): return definition def model(self, name=None, model=None, mask=None, **kwargs): - ''' + """ Register a model .. seealso:: :class:`Model` - ''' + """ cls = OrderedModel if self.ordered else Model model = cls(name, model, mask=mask) model.__apidoc__.update(kwargs) return self.add_model(name, model) def schema_model(self, name=None, schema=None): - ''' + """ Register a model .. seealso:: :class:`Model` - ''' + """ model = SchemaModel(name, schema) return self.add_model(name, model) def extend(self, name, parent, fields): - ''' + """ Extend a model (Duplicate all fields) :deprecated: since 0.9. Use :meth:`clone` instead - ''' + """ if isinstance(parent, list): parents = parent + [fields] model = Model.extend(name, *parents) @@ -182,7 +196,7 @@ def extend(self, name, parent, fields): return self.add_model(name, model) def clone(self, name, *specs): - ''' + """ Clone a model (Duplicate all fields) :param str name: the resulting model name @@ -190,159 +204,173 @@ def clone(self, name, *specs): .. seealso:: :meth:`Model.clone` - ''' + """ model = Model.clone(name, *specs) return self.add_model(name, model) def inherit(self, name, *specs): - ''' + """ Inherit a model (use the Swagger composition pattern aka. allOf) .. seealso:: :meth:`Model.inherit` - ''' + """ model = Model.inherit(name, *specs) return self.add_model(name, model) def expect(self, *inputs, **kwargs): - ''' + """ A decorator to Specify the expected input model :param ModelBase|Parse inputs: An expect model or request parser :param bool validate: whether to perform validation or not - ''' + """ expect = [] - params = { - 'validate': kwargs.get('validate', self._validate), - 'expect': expect - } + params = {"validate": kwargs.get("validate", self._validate), "expect": expect} for param in inputs: expect.append(param) return self.doc(**params) def parser(self): - '''Instanciate a :class:`~RequestParser`''' + """Instanciate a :class:`~RequestParser`""" return RequestParser() def as_list(self, field): - '''Allow to specify nested lists for documentation''' - field.__apidoc__ = merge(getattr(field, '__apidoc__', {}), {'as_list': True}) + """Allow to specify nested lists for documentation""" + field.__apidoc__ = merge(getattr(field, "__apidoc__", {}), {"as_list": True}) return field - def marshal_with(self, fields, as_list=False, code=HTTPStatus.OK, description=None, **kwargs): - ''' + def marshal_with( + self, fields, as_list=False, code=HTTPStatus.OK, description=None, **kwargs + ): + """ A decorator specifying the fields to use for serialization. :param bool as_list: Indicate that the return type is a list (for the documentation) :param int code: Optionally give the expected HTTP response code if its different from 200 - ''' + """ + def wrapper(func): doc = { - 'responses': { - str(code): (description, [fields], kwargs) if as_list else (description, fields, kwargs) + "responses": { + str(code): (description, [fields], kwargs) + if as_list + else (description, fields, kwargs) }, - '__mask__': kwargs.get('mask', True), # Mask values can't be determined outside app context + "__mask__": kwargs.get( + "mask", True + ), # Mask values can't be determined outside app context } - func.__apidoc__ = merge(getattr(func, '__apidoc__', {}), doc) + func.__apidoc__ = merge(getattr(func, "__apidoc__", {}), doc) return marshal_with(fields, ordered=self.ordered, **kwargs)(func) + return wrapper def marshal_list_with(self, fields, **kwargs): - '''A shortcut decorator for :meth:`~Api.marshal_with` with ``as_list=True``''' + """A shortcut decorator for :meth:`~Api.marshal_with` with ``as_list=True``""" return self.marshal_with(fields, True, **kwargs) def marshal(self, *args, **kwargs): - '''A shortcut to the :func:`marshal` helper''' + """A shortcut to the :func:`marshal` helper""" return marshal(*args, **kwargs) def errorhandler(self, exception): - '''A decorator to register an error handler for a given exception''' + """A decorator to register an error handler for a given exception""" if inspect.isclass(exception) and issubclass(exception, Exception): # Register an error handler for a given exception def wrapper(func): self.error_handlers[exception] = func return func + return wrapper else: # Register the default error handler self.default_error_handler = exception return exception - def param(self, name, description=None, _in='query', **kwargs): - ''' + def param(self, name, description=None, _in="query", **kwargs): + """ A decorator to specify one of the expected parameters :param str name: the parameter name :param str description: a small description :param str _in: the parameter location `(query|header|formData|body|cookie)` - ''' + """ param = kwargs - param['in'] = _in - param['description'] = description + param["in"] = _in + param["description"] = description return self.doc(params={name: param}) def response(self, code, description, model=None, **kwargs): - ''' + """ A decorator to specify one of the expected responses :param int code: the HTTP status code :param str description: a small description about the response :param ModelBase model: an optional response model - ''' + """ return self.doc(responses={str(code): (description, model, kwargs)}) def header(self, name, description=None, **kwargs): - ''' + """ A decorator to specify one of the expected headers :param str name: the HTTP header name :param str description: a description about the header - ''' - header = {'description': description} + """ + header = {"description": description} header.update(kwargs) return self.doc(headers={name: header}) def produces(self, mimetypes): - '''A decorator to specify the MIME types the API can produce''' + """A decorator to specify the MIME types the API can produce""" return self.doc(produces=mimetypes) def deprecated(self, func): - '''A decorator to mark a resource or a method as deprecated''' + """A decorator to mark a resource or a method as deprecated""" return self.doc(deprecated=True)(func) def vendor(self, *args, **kwargs): - ''' + """ A decorator to expose vendor extensions. Extensions can be submitted as dict or kwargs. The ``x-`` prefix is optionnal and will be added if missing. See: http://swagger.io/specification/#specification-extensions-128 - ''' + """ for arg in args: kwargs.update(arg) return self.doc(vendor=kwargs) @property def payload(self): - '''Store the input payload in the current request context''' + """Store the input payload in the current request context""" return request.get_json() def unshortcut_params_description(data): - if 'params' in data: - for name, description in six.iteritems(data['params']): + if "params" in data: + for name, description in six.iteritems(data["params"]): if isinstance(description, six.string_types): - data['params'][name] = {'description': description} + data["params"][name] = {"description": description} def handle_deprecations(doc): - if 'parser' in doc: - warnings.warn('The parser attribute is deprecated, use expect instead', DeprecationWarning, stacklevel=2) - doc['expect'] = doc.get('expect', []) + [doc.pop('parser')] - if 'body' in doc: - warnings.warn('The body attribute is deprecated, use expect instead', DeprecationWarning, stacklevel=2) - doc['expect'] = doc.get('expect', []) + [doc.pop('body')] + if "parser" in doc: + warnings.warn( + "The parser attribute is deprecated, use expect instead", + DeprecationWarning, + stacklevel=2, + ) + doc["expect"] = doc.get("expect", []) + [doc.pop("parser")] + if "body" in doc: + warnings.warn( + "The body attribute is deprecated, use expect instead", + DeprecationWarning, + stacklevel=2, + ) + doc["expect"] = doc.get("expect", []) + [doc.pop("body")] diff --git a/flask_restx/postman.py b/flask_restx/postman.py index 0af8ee2d..644922af 100644 --- a/flask_restx/postman.py +++ b/flask_restx/postman.py @@ -9,19 +9,20 @@ def clean(data): - '''Remove all keys where value is None''' + """Remove all keys where value is None""" return dict((k, v) for k, v in iteritems(data) if v is not None) DEFAULT_VARS = { - 'string': '', - 'integer': 0, - 'number': 0, + "string": "", + "integer": 0, + "number": 0, } class Request(object): - '''Wraps a Swagger operation into a Postman Request''' + """Wraps a Swagger operation into a Postman Request""" + def __init__(self, collection, path, params, method, operation): self.collection = collection self.path = path @@ -31,90 +32,94 @@ def __init__(self, collection, path, params, method, operation): @property def id(self): - seed = str(' '.join((self.method, self.url))) + seed = str(" ".join((self.method, self.url))) return str(uuid5(self.collection.uuid, seed)) @property def url(self): - return self.collection.api.base_url.rstrip('/') + self.path + return self.collection.api.base_url.rstrip("/") + self.path @property def headers(self): headers = {} # Handle content-type - if self.method != 'GET': - consumes = self.collection.api.__schema__.get('consumes', []) - consumes = self.operation.get('consumes', consumes) + if self.method != "GET": + consumes = self.collection.api.__schema__.get("consumes", []) + consumes = self.operation.get("consumes", consumes) if len(consumes): - headers['Content-Type'] = consumes[-1] + headers["Content-Type"] = consumes[-1] # Add all parameters headers - for param in self.operation.get('parameters', []): - if param['in'] == 'header': - headers[param['name']] = param.get('default', '') + for param in self.operation.get("parameters", []): + if param["in"] == "header": + headers[param["name"]] = param.get("default", "") # Add security headers if needed (global then local) - for security in self.collection.api.__schema__.get('security', []): + for security in self.collection.api.__schema__.get("security", []): for key, header in iteritems(self.collection.apikeys): if key in security: - headers[header] = '' - for security in self.operation.get('security', []): + headers[header] = "" + for security in self.operation.get("security", []): for key, header in iteritems(self.collection.apikeys): if key in security: - headers[header] = '' + headers[header] = "" - lines = [':'.join(line) for line in iteritems(headers)] - return '\n'.join(lines) + lines = [":".join(line) for line in iteritems(headers)] + return "\n".join(lines) @property def folder(self): - if 'tags' not in self.operation or len(self.operation['tags']) == 0: + if "tags" not in self.operation or len(self.operation["tags"]) == 0: return - tag = self.operation['tags'][0] + tag = self.operation["tags"][0] for folder in self.collection.folders: if folder.tag == tag: return folder.id def as_dict(self, urlvars=False): url, variables = self.process_url(urlvars) - return clean({ - 'id': self.id, - 'method': self.method, - 'name': self.operation['operationId'], - 'description': self.operation.get('summary'), - 'url': url, - 'headers': self.headers, - 'collectionId': self.collection.id, - 'folder': self.folder, - 'pathVariables': variables, - 'time': int(time()), - }) + return clean( + { + "id": self.id, + "method": self.method, + "name": self.operation["operationId"], + "description": self.operation.get("summary"), + "url": url, + "headers": self.headers, + "collectionId": self.collection.id, + "folder": self.folder, + "pathVariables": variables, + "time": int(time()), + } + ) def process_url(self, urlvars=False): url = self.url path_vars = {} url_vars = {} - params = dict((p['name'], p) for p in self.params) - params.update(dict((p['name'], p) for p in self.operation.get('parameters', []))) + params = dict((p["name"], p) for p in self.params) + params.update( + dict((p["name"], p) for p in self.operation.get("parameters", [])) + ) if not params: return url, None for name, param in iteritems(params): - if param['in'] == 'path': - url = url.replace('{%s}' % name, ':%s' % name) - path_vars[name] = DEFAULT_VARS.get(param['type'], '') - elif param['in'] == 'query' and urlvars: - default = DEFAULT_VARS.get(param['type'], '') - url_vars[name] = param.get('default', default) + if param["in"] == "path": + url = url.replace("{%s}" % name, ":%s" % name) + path_vars[name] = DEFAULT_VARS.get(param["type"], "") + elif param["in"] == "query" and urlvars: + default = DEFAULT_VARS.get(param["type"], "") + url_vars[name] = param.get("default", default) if url_vars: - url = '?'.join((url, urlencode(url_vars))) + url = "?".join((url, urlencode(url_vars))) return url, path_vars class Folder(object): def __init__(self, collection, tag): self.collection = collection - self.tag = tag['name'] - self.description = tag['description'] + self.tag = tag["name"] + self.description = tag["description"] @property def id(self): @@ -122,23 +127,23 @@ def id(self): @property def order(self): - return [ - r.id for r in self.collection.requests - if r.folder == self.id - ] + return [r.id for r in self.collection.requests if r.folder == self.id] def as_dict(self): - return clean({ - 'id': self.id, - 'name': self.tag, - 'description': self.description, - 'order': self.order, - 'collectionId': self.collection.id - }) + return clean( + { + "id": self.id, + "name": self.tag, + "description": self.description, + "order": self.order, + "collectionId": self.collection.id, + } + ) class PostmanCollectionV1(object): - '''Postman Collection (V1 format) serializer''' + """Postman Collection (V1 format) serializer""" + def __init__(self, api, swagger=False): self.api = api self.swagger = swagger @@ -155,38 +160,48 @@ def id(self): def requests(self): if self.swagger: # First request is Swagger specifications - yield Request(self, '/swagger.json', {}, 'get', { - 'operationId': 'Swagger specifications', - 'summary': 'The API Swagger specifications as JSON', - }) + yield Request( + self, + "/swagger.json", + {}, + "get", + { + "operationId": "Swagger specifications", + "summary": "The API Swagger specifications as JSON", + }, + ) # Then iter over API paths and methods - for path, operations in iteritems(self.api.__schema__['paths']): - path_params = operations.get('parameters', []) + for path, operations in iteritems(self.api.__schema__["paths"]): + path_params = operations.get("parameters", []) for method, operation in iteritems(operations): - if method != 'parameters': + if method != "parameters": yield Request(self, path, path_params, method, operation) @property def folders(self): - for tag in self.api.__schema__['tags']: + for tag in self.api.__schema__["tags"]: yield Folder(self, tag) @property def apikeys(self): return dict( - (name, secdef['name']) - for name, secdef in iteritems(self.api.__schema__.get('securityDefinitions')) - if secdef.get('in') == 'header' and secdef.get('type') == 'apiKey' + (name, secdef["name"]) + for name, secdef in iteritems( + self.api.__schema__.get("securityDefinitions") + ) + if secdef.get("in") == "header" and secdef.get("type") == "apiKey" ) def as_dict(self, urlvars=False): - return clean({ - 'id': self.id, - 'name': ' '.join((self.api.title, self.api.version)), - 'description': self.api.description, - 'order': [r.id for r in self.requests if not r.folder], - 'requests': [r.as_dict(urlvars=urlvars) for r in self.requests], - 'folders': [f.as_dict() for f in self.folders], - 'timestamp': int(time()), - }) + return clean( + { + "id": self.id, + "name": " ".join((self.api.title, self.api.version)), + "description": self.api.description, + "order": [r.id for r in self.requests if not r.folder], + "requests": [r.as_dict(urlvars=urlvars) for r in self.requests], + "folders": [f.as_dict() for f in self.folders], + "timestamp": int(time()), + } + ) diff --git a/flask_restx/representations.py b/flask_restx/representations.py index 0aaece0c..f250d0f6 100644 --- a/flask_restx/representations.py +++ b/flask_restx/representations.py @@ -10,15 +10,15 @@ def output_json(data, code, headers=None): - '''Makes a Flask response with a JSON encoded body''' + """Makes a Flask response with a JSON encoded body""" - settings = current_app.config.get('RESTX_JSON', {}) + settings = current_app.config.get("RESTX_JSON", {}) # If we're in debug mode, and the indent is not set, we set it to a # reasonable value here. Note that this won't override any existing value # that was set. if current_app.debug: - settings.setdefault('indent', 4) + settings.setdefault("indent", 4) # always end the json dumps with a new line # see https://github.com/mitsuhiko/flask/pull/1262 diff --git a/flask_restx/reqparse.py b/flask_restx/reqparse.py index c840aab3..b36118d5 100644 --- a/flask_restx/reqparse.py +++ b/flask_restx/reqparse.py @@ -21,9 +21,10 @@ class ParseResult(dict): - ''' + """ The default result container as an Object dict. - ''' + """ + def __getattr__(self, name): try: return self[name] @@ -35,41 +36,41 @@ def __setattr__(self, name, value): _friendly_location = { - 'json': 'the JSON body', - 'form': 'the post body', - 'args': 'the query string', - 'values': 'the post body or the query string', - 'headers': 'the HTTP headers', - 'cookies': 'the request\'s cookies', - 'files': 'an uploaded file', + "json": "the JSON body", + "form": "the post body", + "args": "the query string", + "values": "the post body or the query string", + "headers": "the HTTP headers", + "cookies": "the request's cookies", + "files": "an uploaded file", } #: Maps Flask-RESTX RequestParser locations to Swagger ones LOCATIONS = { - 'args': 'query', - 'form': 'formData', - 'headers': 'header', - 'json': 'body', - 'values': 'query', - 'files': 'formData', + "args": "query", + "form": "formData", + "headers": "header", + "json": "body", + "values": "query", + "files": "formData", } #: Maps Python primitives types to Swagger ones PY_TYPES = { - int: 'integer', - str: 'string', - bool: 'boolean', - float: 'number', - None: 'void' + int: "integer", + str: "string", + bool: "boolean", + float: "number", + None: "void", } -SPLIT_CHAR = ',' +SPLIT_CHAR = "," text_type = lambda x: six.text_type(x) # noqa class Argument(object): - ''' + """ :param name: Either a name or a list of option strings, e.g. foo or -f, --foo. :param default: The value produced if the argument is absent from the request. :param dest: The name of the attribute to be added to the object @@ -95,13 +96,26 @@ class Argument(object): be stored if the argument is missing from the request. :param bool trim: If enabled, trims whitespace around the argument. :param bool nullable: If enabled, allows null value in argument. - ''' - - def __init__(self, name, default=None, dest=None, required=False, - ignore=False, type=text_type, location=('json', 'values',), - choices=(), action='store', help=None, operators=('=',), - case_sensitive=True, store_missing=True, trim=False, - nullable=True): + """ + + def __init__( + self, + name, + default=None, + dest=None, + required=False, + ignore=False, + type=text_type, + location=("json", "values",), + choices=(), + action="store", + help=None, + operators=("=",), + case_sensitive=True, + store_missing=True, + trim=False, + nullable=True, + ): self.name = name self.default = default self.dest = dest @@ -119,10 +133,10 @@ def __init__(self, name, default=None, dest=None, required=False, self.nullable = nullable def source(self, request): - ''' + """ Pulls values off the request in the provided location :param request: The flask request object to parse arguments from - ''' + """ if isinstance(self.location, six.string_types): value = getattr(request, self.location, MultiDict()) if callable(value): @@ -145,7 +159,7 @@ def convert(self, value, op): # Don't cast None if value is None: if not self.nullable: - raise ValueError('Must not be null!') + raise ValueError("Must not be null!") return None elif isinstance(self.type, Model) and isinstance(value, dict): @@ -168,7 +182,7 @@ def convert(self, value, op): return self.type(value) def handle_validation_error(self, error, bundle_errors): - ''' + """ Called when an error is raised while parsing. Aborts the request with a 400 status and an error message @@ -176,17 +190,19 @@ def handle_validation_error(self, error, bundle_errors): :param bool bundle_errors: do not abort when first error occurs, return a dict with the name of the argument and the error message to be bundled - ''' + """ error_str = six.text_type(error) - error_msg = ' '.join([six.text_type(self.help), error_str]) if self.help else error_str + error_msg = ( + " ".join([six.text_type(self.help), error_str]) if self.help else error_str + ) errors = {self.name: error_msg} if bundle_errors: return ValueError(error), errors - abort(HTTPStatus.BAD_REQUEST, 'Input payload validation failed', errors=errors) + abort(HTTPStatus.BAD_REQUEST, "Input payload validation failed", errors=errors) def parse(self, request, bundle_errors=False): - ''' + """ Parses argument value(s) from the request, converting according to the argument's type. @@ -194,8 +210,8 @@ def parse(self, request, bundle_errors=False): :param bool bundle_errors: do not abort when first error occurs, return a dict with the name of the argument and the error message to be bundled - ''' - bundle_errors = current_app.config.get('BUNDLE_ERRORS', False) or bundle_errors + """ + bundle_errors = current_app.config.get("BUNDLE_ERRORS", False) or bundle_errors source = self.source(request) results = [] @@ -205,26 +221,29 @@ def parse(self, request, bundle_errors=False): _found = True for operator in self.operators: - name = self.name + operator.replace('=', '', 1) + name = self.name + operator.replace("=", "", 1) if name in source: # Account for MultiDict and regular dict - if hasattr(source, 'getlist'): + if hasattr(source, "getlist"): values = source.getlist(name) else: values = [source.get(name)] for value in values: - if hasattr(value, 'strip') and self.trim: + if hasattr(value, "strip") and self.trim: value = value.strip() - if hasattr(value, 'lower') and not self.case_sensitive: + if hasattr(value, "lower") and not self.case_sensitive: value = value.lower() - if hasattr(self.choices, '__iter__'): + if hasattr(self.choices, "__iter__"): self.choices = [choice.lower() for choice in self.choices] try: - if self.action == 'split': - value = [self.convert(v, operator) for v in value.split(SPLIT_CHAR)] + if self.action == "split": + value = [ + self.convert(v, operator) + for v in value.split(SPLIT_CHAR) + ] else: value = self.convert(value, operator) except Exception as error: @@ -233,7 +252,9 @@ def parse(self, request, bundle_errors=False): return self.handle_validation_error(error, bundle_errors) if self.choices and value not in self.choices: - msg = 'The value \'{0}\' is not a valid choice for \'{1}\'.'.format(value, name) + msg = "The value '{0}' is not a valid choice for '{1}'.".format( + value, name + ) return self.handle_validation_error(msg, bundle_errors) if name in request.unparsed_arguments: @@ -245,8 +266,8 @@ def parse(self, request, bundle_errors=False): location = _friendly_location.get(self.location, self.location) else: locations = [_friendly_location.get(loc, loc) for loc in self.location] - location = ' or '.join(locations) - error_msg = 'Missing required parameter in {0}'.format(location) + location = " or ".join(locations) + error_msg = "Missing required parameter in {0}".format(location) return self.handle_validation_error(error_msg, bundle_errors) if not results: @@ -255,44 +276,43 @@ def parse(self, request, bundle_errors=False): else: return self.default, _not_found - if self.action == 'append': + if self.action == "append": return results, _found - if self.action == 'store' or len(results) == 1: + if self.action == "store" or len(results) == 1: return results[0], _found return results, _found @property def __schema__(self): - if self.location == 'cookie': + if self.location == "cookie": return - param = { - 'name': self.name, - 'in': LOCATIONS.get(self.location, 'query') - } + param = {"name": self.name, "in": LOCATIONS.get(self.location, "query")} _handle_arg_type(self, param) if self.required: - param['required'] = True + param["required"] = True if self.help: - param['description'] = self.help + param["description"] = self.help if self.default is not None: - param['default'] = self.default() if callable(self.default) else self.default - if self.action == 'append': - param['items'] = {'type': param['type']} - param['type'] = 'array' - param['collectionFormat'] = 'multi' - if self.action == 'split': - param['items'] = {'type': param['type']} - param['type'] = 'array' - param['collectionFormat'] = 'csv' + param["default"] = ( + self.default() if callable(self.default) else self.default + ) + if self.action == "append": + param["items"] = {"type": param["type"]} + param["type"] = "array" + param["collectionFormat"] = "multi" + if self.action == "split": + param["items"] = {"type": param["type"]} + param["type"] = "array" + param["collectionFormat"] = "csv" if self.choices: - param['enum'] = self.choices - param['collectionFormat'] = 'multi' + param["enum"] = self.choices + param["collectionFormat"] = "multi" return param class RequestParser(object): - ''' + """ Enables adding and parsing of multiple arguments in the context of a single request. Ex:: @@ -307,10 +327,15 @@ class RequestParser(object): :param bool bundle_errors: If enabled, do not abort when first error occurs, return a dict with the name of the argument and the error message to be bundled and return all validation errors - ''' - - def __init__(self, argument_class=Argument, result_class=ParseResult, - trim=False, bundle_errors=False): + """ + + def __init__( + self, + argument_class=Argument, + result_class=ParseResult, + trim=False, + bundle_errors=False, + ): self.args = [] self.argument_class = argument_class self.result_class = result_class @@ -318,14 +343,14 @@ def __init__(self, argument_class=Argument, result_class=ParseResult, self.bundle_errors = bundle_errors def add_argument(self, *args, **kwargs): - ''' + """ Adds an argument to be parsed. Accepts either a single instance of Argument or arguments to be passed into :class:`Argument`'s constructor. See :class:`Argument`'s constructor for documentation on the available options. - ''' + """ if len(args) == 1 and isinstance(args[0], self.argument_class): self.args.append(args[0]) @@ -335,18 +360,18 @@ def add_argument(self, *args, **kwargs): # Do not know what other argument classes are out there if self.trim and self.argument_class is Argument: # enable trim for appended element - self.args[-1].trim = kwargs.get('trim', self.trim) + self.args[-1].trim = kwargs.get("trim", self.trim) return self def parse_args(self, req=None, strict=False): - ''' + """ Parse all arguments from the provided request and return the results as a ParseResult :param bool strict: if req includes args not in parser, throw 400 BadRequest exception :return: the parsed results as :class:`ParseResult` (or any class defined as :attr:`result_class`) :rtype: ParseResult - ''' + """ if req is None: req = request @@ -354,7 +379,9 @@ def parse_args(self, req=None, strict=False): # A record of arguments not yet parsed; as each is found # among self.args, it will be popped out - req.unparsed_arguments = dict(self.argument_class('').source(req)) if strict else {} + req.unparsed_arguments = ( + dict(self.argument_class("").source(req)) if strict else {} + ) errors = {} for arg in self.args: value, found = arg.parse(req, self.bundle_errors) @@ -364,17 +391,19 @@ def parse_args(self, req=None, strict=False): if found or arg.store_missing: result[arg.dest or arg.name] = value if errors: - abort(HTTPStatus.BAD_REQUEST, 'Input payload validation failed', errors=errors) + abort( + HTTPStatus.BAD_REQUEST, "Input payload validation failed", errors=errors + ) if strict and req.unparsed_arguments: - arguments = ', '.join(req.unparsed_arguments.keys()) - msg = 'Unknown arguments: {0}'.format(arguments) + arguments = ", ".join(req.unparsed_arguments.keys()) + msg = "Unknown arguments: {0}".format(arguments) raise exceptions.BadRequest(msg) return result def copy(self): - '''Creates a copy of this RequestParser with the same set of arguments''' + """Creates a copy of this RequestParser with the same set of arguments""" parser_copy = self.__class__(self.argument_class, self.result_class) parser_copy.args = deepcopy(self.args) parser_copy.trim = self.trim @@ -382,7 +411,7 @@ def copy(self): return parser_copy def replace_argument(self, name, *args, **kwargs): - '''Replace the argument matching the given name with a new version.''' + """Replace the argument matching the given name with a new version.""" new_arg = self.argument_class(name, *args, **kwargs) for index, arg in enumerate(self.args[:]): if new_arg.name == arg.name: @@ -392,7 +421,7 @@ def replace_argument(self, name, *args, **kwargs): return self def remove_argument(self, name): - '''Remove the argument matching the given name.''' + """Remove the argument matching the given name.""" for index, arg in enumerate(self.args[:]): if name == arg.name: del self.args[index] @@ -407,21 +436,21 @@ def __schema__(self): param = arg.__schema__ if param: params.append(param) - locations.add(param['in']) - if 'body' in locations and 'formData' in locations: + locations.add(param["in"]) + if "body" in locations and "formData" in locations: raise SpecsError("Can't use formData and body at the same time") return params def _handle_arg_type(arg, param): if isinstance(arg.type, Hashable) and arg.type in PY_TYPES: - param['type'] = PY_TYPES[arg.type] - elif hasattr(arg.type, '__apidoc__'): - param['type'] = arg.type.__apidoc__['name'] - param['in'] = 'body' - elif hasattr(arg.type, '__schema__'): + param["type"] = PY_TYPES[arg.type] + elif hasattr(arg.type, "__apidoc__"): + param["type"] = arg.type.__apidoc__["name"] + param["in"] = "body" + elif hasattr(arg.type, "__schema__"): param.update(arg.type.__schema__) - elif arg.location == 'files': - param['type'] = 'file' + elif arg.location == "files": + param["type"] = "file" else: - param['type'] = 'string' + param["type"] = "string" diff --git a/flask_restx/resource.py b/flask_restx/resource.py index 776d2abe..991b7a68 100644 --- a/flask_restx/resource.py +++ b/flask_restx/resource.py @@ -11,7 +11,7 @@ class Resource(MethodView): - ''' + """ Represents an abstract RESTX resource. Concrete resources should extend from this class @@ -21,7 +21,7 @@ class Resource(MethodView): Otherwise the appropriate method is called and passed all arguments from the url rule used when adding the resource to an Api instance. See :meth:`~flask_restx.Api.add_resource` for details. - ''' + """ representations = None method_decorators = [] @@ -32,9 +32,9 @@ def __init__(self, api=None, *args, **kwargs): def dispatch_request(self, *args, **kwargs): # Taken from flask meth = getattr(self, request.method.lower(), None) - if meth is None and request.method == 'HEAD': - meth = getattr(self, 'get', None) - assert meth is not None, 'Unimplemented method %r' % request.method + if meth is None and request.method == "HEAD": + meth = getattr(self, "get", None) + assert meth is not None, "Unimplemented method %r" % request.method for decorator in self.method_decorators: meth = decorator(meth) @@ -52,17 +52,17 @@ def dispatch_request(self, *args, **kwargs): if mediatype in representations: data, code, headers = unpack(resp) resp = representations[mediatype](data, code, headers) - resp.headers['Content-Type'] = mediatype + resp.headers["Content-Type"] = mediatype return resp return resp def __validate_payload(self, expect, collection=False): - ''' + """ :param ModelBase expect: the expected model for the input payload :param bool collection: False if a single object of a resource is expected, True if a collection of objects of a resource is expected. - ''' + """ # TODO: proper content negotiation data = request.get_json() if collection: @@ -73,13 +73,13 @@ def __validate_payload(self, expect, collection=False): expect.validate(data, self.api.refresolver, self.api.format_checker) def validate_payload(self, func): - '''Perform a payload validation on expected model if necessary''' - if getattr(func, '__apidoc__', False) is not False: + """Perform a payload validation on expected model if necessary""" + if getattr(func, "__apidoc__", False) is not False: doc = func.__apidoc__ - validate = doc.get('validate', None) + validate = doc.get("validate", None) validate = validate if validate is not None else self.api._validate if validate: - for expect in doc.get('expect', []): + for expect in doc.get("expect", []): # TODO: handle third party handlers if isinstance(expect, list) and len(expect) == 1: if isinstance(expect[0], ModelBase): diff --git a/flask_restx/schemas/__init__.py b/flask_restx/schemas/__init__.py index 93e92506..d6dc2ac0 100644 --- a/flask_restx/schemas/__init__.py +++ b/flask_restx/schemas/__init__.py @@ -1,10 +1,10 @@ # -*- coding: utf-8 -*- -''' +""" This module give access to OpenAPI specifications schemas and allows to validate specs against them. .. versionadded:: 0.12.1 -''' +""" from __future__ import unicode_literals import io @@ -22,11 +22,12 @@ class SchemaValidationError(errors.ValidationError): - ''' + """ Raised when specification is not valid .. versionadded:: 0.12.1 - ''' + """ + def __init__(self, msg, errors=None): super(SchemaValidationError, self).__init__(msg) self.errors = errors @@ -34,25 +35,26 @@ def __init__(self, msg, errors=None): def __str__(self): msg = [self.msg] for error in sorted(self.errors, key=lambda e: e.path): - path = '.'.join(error.path) - msg.append('- {}: {}'.format(path, error.message)) + path = ".".join(error.path) + msg.append("- {}: {}".format(path, error.message)) for suberror in sorted(error.context, key=lambda e: e.schema_path): - path = '.'.join(suberror.schema_path) - msg.append(' - {}: {}'.format(path, suberror.message)) - return '\n'.join(msg) + path = ".".join(suberror.schema_path) + msg.append(" - {}: {}".format(path, suberror.message)) + return "\n".join(msg) __unicode__ = __str__ class LazySchema(Mapping): - ''' + """ A thin wrapper around schema file lazy loading the data on first access :param filename str: The package relative json schema filename :param validator: The jsonschema validator class version .. versionadded:: 0.12.1 - ''' + """ + def __init__(self, filename, validator=Draft4Validator): super(LazySchema, self).__init__() self.filename = filename @@ -79,21 +81,21 @@ def __len__(self): @property def validator(self): - '''The jsonschema validator to validate against''' + """The jsonschema validator to validate against""" return self._validator(self) #: OpenAPI 2.0 specification schema -OAS_20 = LazySchema('oas-2.0.json') +OAS_20 = LazySchema("oas-2.0.json") #: Map supported OpenAPI versions to their JSON schema VERSIONS = { - '2.0': OAS_20, + "2.0": OAS_20, } def validate(data): - ''' + """ Validate an OpenAPI specification. Supported OpenAPI versions: 2.0 @@ -105,11 +107,11 @@ def validate(data): the schema to validate against .. versionadded:: 0.12.1 - ''' - if 'swagger' not in data: - raise errors.SpecsError('Unable to determinate OpenAPI schema version') + """ + if "swagger" not in data: + raise errors.SpecsError("Unable to determinate OpenAPI schema version") - version = data['swagger'] + version = data["swagger"] if version not in VERSIONS: raise errors.SpecsError('Unknown OpenAPI schema version "{}"'.format(version)) @@ -117,6 +119,7 @@ def validate(data): validation_errors = list(validator.iter_errors(data)) if validation_errors: - raise SchemaValidationError('OpenAPI {} validation failed'.format(version), - errors=validation_errors) + raise SchemaValidationError( + "OpenAPI {} validation failed".format(version), errors=validation_errors + ) return True diff --git a/flask_restx/swagger.py b/flask_restx/swagger.py index 4f9fb93f..439d6380 100644 --- a/flask_restx/swagger.py +++ b/flask_restx/swagger.py @@ -6,6 +6,7 @@ from inspect import isclass, getdoc from collections import OrderedDict + try: from collections.abc import Hashable except ImportError: @@ -29,119 +30,121 @@ #: Maps Flask/Werkzeug rooting types to Swagger ones PATH_TYPES = { - 'int': 'integer', - 'float': 'number', - 'string': 'string', - 'default': 'string', + "int": "integer", + "float": "number", + "string": "string", + "default": "string", } #: Maps Python primitives types to Swagger ones PY_TYPES = { - int: 'integer', - float: 'number', - str: 'string', - bool: 'boolean', - None: 'void' + int: "integer", + float: "number", + str: "string", + bool: "boolean", + None: "void", } -RE_URL = re.compile(r'<(?:[^:<>]+:)?([^<>]+)>') +RE_URL = re.compile(r"<(?:[^:<>]+:)?([^<>]+)>") -DEFAULT_RESPONSE_DESCRIPTION = 'Success' -DEFAULT_RESPONSE = {'description': DEFAULT_RESPONSE_DESCRIPTION} +DEFAULT_RESPONSE_DESCRIPTION = "Success" +DEFAULT_RESPONSE = {"description": DEFAULT_RESPONSE_DESCRIPTION} -RE_RAISES = re.compile(r'^:raises\s+(?P[\w\d_]+)\s*:\s*(?P.*)$', re.MULTILINE) +RE_RAISES = re.compile( + r"^:raises\s+(?P[\w\d_]+)\s*:\s*(?P.*)$", re.MULTILINE +) def ref(model): - '''Return a reference to model in definitions''' + """Return a reference to model in definitions""" name = model.name if isinstance(model, ModelBase) else model - return {'$ref': '#/definitions/{0}'.format(quote(name, safe=''))} + return {"$ref": "#/definitions/{0}".format(quote(name, safe=""))} def _v(value): - '''Dereference values (callable)''' + """Dereference values (callable)""" return value() if callable(value) else value def extract_path(path): - ''' + """ Transform a Flask/Werkzeug URL pattern in a Swagger one. - ''' - return RE_URL.sub(r'{\1}', path) + """ + return RE_URL.sub(r"{\1}", path) def extract_path_params(path): - ''' + """ Extract Flask-style parameters from an URL pattern as Swagger ones. - ''' + """ params = OrderedDict() for converter, arguments, variable in parse_rule(path): if not converter: continue - param = { - 'name': variable, - 'in': 'path', - 'required': True - } + param = {"name": variable, "in": "path", "required": True} if converter in PATH_TYPES: - param['type'] = PATH_TYPES[converter] + param["type"] = PATH_TYPES[converter] elif converter in current_app.url_map.converters: - param['type'] = 'string' + param["type"] = "string" else: - raise ValueError('Unsupported type converter: %s' % converter) + raise ValueError("Unsupported type converter: %s" % converter) params[variable] = param return params def _param_to_header(param): - param.pop('in', None) - param.pop('name', None) + param.pop("in", None) + param.pop("name", None) return _clean_header(param) def _clean_header(header): if isinstance(header, string_types): - header = {'description': header} - typedef = header.get('type', 'string') + header = {"description": header} + typedef = header.get("type", "string") if isinstance(typedef, Hashable) and typedef in PY_TYPES: - header['type'] = PY_TYPES[typedef] - elif isinstance(typedef, (list, tuple)) and len(typedef) == 1 and typedef[0] in PY_TYPES: - header['type'] = 'array' - header['items'] = {'type': PY_TYPES[typedef[0]]} - elif hasattr(typedef, '__schema__'): + header["type"] = PY_TYPES[typedef] + elif ( + isinstance(typedef, (list, tuple)) + and len(typedef) == 1 + and typedef[0] in PY_TYPES + ): + header["type"] = "array" + header["items"] = {"type": PY_TYPES[typedef[0]]} + elif hasattr(typedef, "__schema__"): header.update(typedef.__schema__) else: - header['type'] = typedef + header["type"] = typedef return not_none(header) def parse_docstring(obj): raw = getdoc(obj) - summary = raw.strip(' \n').split('\n')[0].split('.')[0] if raw else None + summary = raw.strip(" \n").split("\n")[0].split(".")[0] if raw else None raises = {} - details = raw.replace(summary, '').lstrip('. \n').strip(' \n') if raw else None - for match in RE_RAISES.finditer(raw or ''): - raises[match.group('name')] = match.group('description') + details = raw.replace(summary, "").lstrip(". \n").strip(" \n") if raw else None + for match in RE_RAISES.finditer(raw or ""): + raises[match.group("name")] = match.group("description") if details: - details = details.replace(match.group(0), '') + details = details.replace(match.group(0), "") parsed = { - 'raw': raw, - 'summary': summary or None, - 'details': details or None, - 'returns': None, - 'params': [], - 'raises': raises, + "raw": raw, + "summary": summary or None, + "details": details or None, + "returns": None, + "params": [], + "raises": raises, } return parsed def is_hidden(resource, route_doc=None): - ''' + """ Determine whether a Resource has been hidden from Swagger documentation i.e. by using Api.doc(False) decorator - ''' + """ if route_doc is False: return True else: @@ -175,57 +178,53 @@ def build_request_body_parameters_schema(body_params): properties = {} for param in body_params: - properties[param['name']] = { - 'type': param.get('type', 'string') - } + properties[param["name"]] = {"type": param.get("type", "string")} return { - 'name': 'payload', - 'required': True, - 'in': 'body', - 'schema': { - 'type': 'object', - 'properties': properties - } + "name": "payload", + "required": True, + "in": "body", + "schema": {"type": "object", "properties": properties}, } class Swagger(object): - ''' + """ A Swagger documentation wrapper for an API instance. - ''' + """ + def __init__(self, api): self.api = api self._registered_models = {} def as_dict(self): - ''' + """ Output the specification as a serializable ``dict``. :returns: the full Swagger specification in a serializable format :rtype: dict - ''' + """ basepath = self.api.base_path - if len(basepath) > 1 and basepath.endswith('/'): + if len(basepath) > 1 and basepath.endswith("/"): basepath = basepath[:-1] infos = { - 'title': _v(self.api.title), - 'version': _v(self.api.version), + "title": _v(self.api.title), + "version": _v(self.api.version), } if self.api.description: - infos['description'] = _v(self.api.description) + infos["description"] = _v(self.api.description) if self.api.terms_url: - infos['termsOfService'] = _v(self.api.terms_url) + infos["termsOfService"] = _v(self.api.terms_url) if self.api.contact and (self.api.contact_email or self.api.contact_url): - infos['contact'] = { - 'name': _v(self.api.contact), - 'email': _v(self.api.contact_email), - 'url': _v(self.api.contact_url), + infos["contact"] = { + "name": _v(self.api.contact), + "email": _v(self.api.contact_email), + "url": _v(self.api.contact_url), } if self.api.license: - infos['license'] = {'name': _v(self.api.license)} + infos["license"] = {"name": _v(self.api.license)} if self.api.license_url: - infos['license']['url'] = _v(self.api.license_url) + infos["license"]["url"] = _v(self.api.license_url) paths = {} tags = self.extract_tags(self.api) @@ -238,11 +237,7 @@ def as_dict(self): for url in self.api.ns_urls(ns, urls): path = extract_path(url) serialized = self.serialize_resource( - ns, - resource, - url, - route_doc=route_doc, - **kwargs + ns, resource, url, route_doc=route_doc, **kwargs ) paths[path] = serialized @@ -251,28 +246,30 @@ def as_dict(self): if ns.authorizations: if self.api.authorizations is None: self.api.authorizations = {} - self.api.authorizations = merge(self.api.authorizations, ns.authorizations) + self.api.authorizations = merge( + self.api.authorizations, ns.authorizations + ) specs = { - 'swagger': '2.0', - 'basePath': basepath, - 'paths': not_none_sorted(paths), - 'info': infos, - 'produces': list(iterkeys(self.api.representations)), - 'consumes': ['application/json'], - 'securityDefinitions': self.api.authorizations or None, - 'security': self.security_requirements(self.api.security) or None, - 'tags': tags, - 'definitions': self.serialize_definitions() or None, - 'responses': responses or None, - 'host': self.get_host(), + "swagger": "2.0", + "basePath": basepath, + "paths": not_none_sorted(paths), + "info": infos, + "produces": list(iterkeys(self.api.representations)), + "consumes": ["application/json"], + "securityDefinitions": self.api.authorizations or None, + "security": self.security_requirements(self.api.security) or None, + "tags": tags, + "definitions": self.serialize_definitions() or None, + "responses": responses or None, + "host": self.get_host(), } return not_none(specs) def get_host(self): - hostname = current_app.config.get('SERVER_NAME', None) or None + hostname = current_app.config.get("SERVER_NAME", None) or None if hostname and self.api.blueprint and self.api.blueprint.subdomain: - hostname = '.'.join((self.api.blueprint.subdomain, hostname)) + hostname = ".".join((self.api.blueprint.subdomain, hostname)) return hostname def extract_tags(self, api): @@ -280,72 +277,72 @@ def extract_tags(self, api): by_name = {} for tag in api.tags: if isinstance(tag, string_types): - tag = {'name': tag} + tag = {"name": tag} elif isinstance(tag, (list, tuple)): - tag = {'name': tag[0], 'description': tag[1]} - elif isinstance(tag, dict) and 'name' in tag: + tag = {"name": tag[0], "description": tag[1]} + elif isinstance(tag, dict) and "name" in tag: pass else: - raise ValueError('Unsupported tag format for {0}'.format(tag)) + raise ValueError("Unsupported tag format for {0}".format(tag)) tags.append(tag) - by_name[tag['name']] = tag + by_name[tag["name"]] = tag for ns in api.namespaces: # hide namespaces without any Resources if not ns.resources: continue # hide namespaces with all Resources hidden from Swagger documentation - if all( - is_hidden(r.resource, route_doc=r.route_doc) - for r in ns.resources - ): + if all(is_hidden(r.resource, route_doc=r.route_doc) for r in ns.resources): continue if ns.name not in by_name: - tags.append({ - 'name': ns.name, - 'description': ns.description - } if ns.description else {'name': ns.name}) + tags.append( + {"name": ns.name, "description": ns.description} + if ns.description + else {"name": ns.name} + ) elif ns.description: - by_name[ns.name]['description'] = ns.description + by_name[ns.name]["description"] = ns.description return tags def extract_resource_doc(self, resource, url, route_doc=None): route_doc = {} if route_doc is None else route_doc if route_doc is False: return False - doc = merge(getattr(resource, '__apidoc__', {}), route_doc) + doc = merge(getattr(resource, "__apidoc__", {}), route_doc) if doc is False: return False # ensure unique names for multiple routes to the same resource # provides different Swagger operationId's doc["name"] = ( - "{}_{}".format(resource.__name__, url) - if route_doc - else resource.__name__ + "{}_{}".format(resource.__name__, url) if route_doc else resource.__name__ ) - params = merge(self.expected_params(doc), doc.get('params', OrderedDict())) + params = merge(self.expected_params(doc), doc.get("params", OrderedDict())) params = merge(params, extract_path_params(url)) # Track parameters for late deduplication - up_params = {(n, p.get('in', 'query')): p for n, p in params.items()} + up_params = {(n, p.get("in", "query")): p for n, p in params.items()} need_to_go_down = set() methods = [m.lower() for m in resource.methods or []] for method in methods: method_doc = doc.get(method, OrderedDict()) method_impl = getattr(resource, method) - if hasattr(method_impl, 'im_func'): + if hasattr(method_impl, "im_func"): method_impl = method_impl.im_func - elif hasattr(method_impl, '__func__'): + elif hasattr(method_impl, "__func__"): method_impl = method_impl.__func__ - method_doc = merge(method_doc, getattr(method_impl, '__apidoc__', OrderedDict())) + method_doc = merge( + method_doc, getattr(method_impl, "__apidoc__", OrderedDict()) + ) if method_doc is not False: - method_doc['docstring'] = parse_docstring(method_impl) + method_doc["docstring"] = parse_docstring(method_impl) method_params = self.expected_params(method_doc) - method_params = merge(method_params, method_doc.get('params', {})) - inherited_params = OrderedDict((k, v) for k, v in iteritems(params) if k in method_params) - method_doc['params'] = merge(inherited_params, method_params) - for name, param in method_doc['params'].items(): - key = (name, param.get('in', 'query')) + method_params = merge(method_params, method_doc.get("params", {})) + inherited_params = OrderedDict( + (k, v) for k, v in iteritems(params) if k in method_params + ) + method_doc["params"] = merge(inherited_params, method_params) + for name, param in method_doc["params"].items(): + key = (name, param.get("in", "query")) if key in up_params: need_to_go_down.add(key) doc[method] = method_doc @@ -358,69 +355,77 @@ def extract_resource_doc(self, resource, url, route_doc=None): if not method_doc: continue params = { - (n, p.get('in', 'query')): p - for n, p in (method_doc['params'] or {}).items() + (n, p.get("in", "query")): p + for n, p in (method_doc["params"] or {}).items() } for key in need_to_go_down: if key not in params: - method_doc['params'][key[0]] = up_params[key] - doc['params'] = OrderedDict( + method_doc["params"][key[0]] = up_params[key] + doc["params"] = OrderedDict( (k[0], p) for k, p in up_params.items() if k not in need_to_go_down ) return doc def expected_params(self, doc): params = OrderedDict() - if 'expect' not in doc: + if "expect" not in doc: return params - for expect in doc.get('expect', []): + for expect in doc.get("expect", []): if isinstance(expect, RequestParser): - parser_params = OrderedDict((p['name'], p) for p in expect.__schema__ if p['in'] != 'body') + parser_params = OrderedDict( + (p["name"], p) for p in expect.__schema__ if p["in"] != "body" + ) params.update(parser_params) - body_params = [p for p in expect.__schema__ if p['in'] == 'body'] + body_params = [p for p in expect.__schema__ if p["in"] == "body"] if body_params: - params['payload'] = build_request_body_parameters_schema(body_params) + params["payload"] = build_request_body_parameters_schema( + body_params + ) elif isinstance(expect, ModelBase): - params['payload'] = not_none({ - 'name': 'payload', - 'required': True, - 'in': 'body', - 'schema': self.serialize_schema(expect), - }) + params["payload"] = not_none( + { + "name": "payload", + "required": True, + "in": "body", + "schema": self.serialize_schema(expect), + } + ) elif isinstance(expect, (list, tuple)): if len(expect) == 2: # this is (payload, description) shortcut model, description = expect - params['payload'] = not_none({ - 'name': 'payload', - 'required': True, - 'in': 'body', - 'schema': self.serialize_schema(model), - 'description': description - }) + params["payload"] = not_none( + { + "name": "payload", + "required": True, + "in": "body", + "schema": self.serialize_schema(model), + "description": description, + } + ) else: - params['payload'] = not_none({ - 'name': 'payload', - 'required': True, - 'in': 'body', - 'schema': self.serialize_schema(expect), - }) + params["payload"] = not_none( + { + "name": "payload", + "required": True, + "in": "body", + "schema": self.serialize_schema(expect), + } + ) return params def register_errors(self): responses = {} for exception, handler in iteritems(self.api.error_handlers): doc = parse_docstring(handler) - response = { - 'description': doc['summary'] - } - apidoc = getattr(handler, '__apidoc__', {}) + response = {"description": doc["summary"]} + apidoc = getattr(handler, "__apidoc__", {}) self.process_headers(response, apidoc) - if 'responses' in apidoc: - _, model, _ = list(apidoc['responses'].values())[0] - response['schema'] = self.serialize_schema(model) + if "responses" in apidoc: + _, model, _ = list(apidoc["responses"].values())[0] + response["schema"] = self.serialize_schema(model) responses[exception.__name__] = not_none(response) return responses @@ -428,103 +433,108 @@ def serialize_resource(self, ns, resource, url, route_doc=None, **kwargs): doc = self.extract_resource_doc(resource, url, route_doc=route_doc) if doc is False: return - path = { - 'parameters': self.parameters_for(doc) or None - } + path = {"parameters": self.parameters_for(doc) or None} for method in [m.lower() for m in resource.methods or []]: - methods = [m.lower() for m in kwargs.get('methods', [])] + methods = [m.lower() for m in kwargs.get("methods", [])] if doc[method] is False or methods and method not in methods: continue path[method] = self.serialize_operation(doc, method) - path[method]['tags'] = [ns.name] + path[method]["tags"] = [ns.name] return not_none(path) def serialize_operation(self, doc, method): operation = { - 'responses': self.responses_for(doc, method) or None, - 'summary': doc[method]['docstring']['summary'], - 'description': self.description_for(doc, method) or None, - 'operationId': self.operation_id_for(doc, method), - 'parameters': self.parameters_for(doc[method]) or None, - 'security': self.security_for(doc, method), + "responses": self.responses_for(doc, method) or None, + "summary": doc[method]["docstring"]["summary"], + "description": self.description_for(doc, method) or None, + "operationId": self.operation_id_for(doc, method), + "parameters": self.parameters_for(doc[method]) or None, + "security": self.security_for(doc, method), } # Handle 'produces' mimetypes documentation - if 'produces' in doc[method]: - operation['produces'] = doc[method]['produces'] + if "produces" in doc[method]: + operation["produces"] = doc[method]["produces"] # Handle deprecated annotation - if doc.get('deprecated') or doc[method].get('deprecated'): - operation['deprecated'] = True + if doc.get("deprecated") or doc[method].get("deprecated"): + operation["deprecated"] = True # Handle form exceptions: - doc_params = list(doc.get('params', {}).values()) - all_params = doc_params + (operation['parameters'] or []) - if all_params and any(p['in'] == 'formData' for p in all_params): - if any(p['type'] == 'file' for p in all_params): - operation['consumes'] = ['multipart/form-data'] + doc_params = list(doc.get("params", {}).values()) + all_params = doc_params + (operation["parameters"] or []) + if all_params and any(p["in"] == "formData" for p in all_params): + if any(p["type"] == "file" for p in all_params): + operation["consumes"] = ["multipart/form-data"] else: - operation['consumes'] = ['application/x-www-form-urlencoded', 'multipart/form-data'] + operation["consumes"] = [ + "application/x-www-form-urlencoded", + "multipart/form-data", + ] operation.update(self.vendor_fields(doc, method)) return not_none(operation) def vendor_fields(self, doc, method): - ''' + """ Extract custom 3rd party Vendor fields prefixed with ``x-`` See: http://swagger.io/specification/#specification-extensions-128 - ''' + """ return dict( - (k if k.startswith('x-') else 'x-{0}'.format(k), v) - for k, v in iteritems(doc[method].get('vendor', {})) + (k if k.startswith("x-") else "x-{0}".format(k), v) + for k, v in iteritems(doc[method].get("vendor", {})) ) def description_for(self, doc, method): - '''Extract the description metadata and fallback on the whole docstring''' + """Extract the description metadata and fallback on the whole docstring""" parts = [] - if 'description' in doc: - parts.append(doc['description'] or "") - if method in doc and 'description' in doc[method]: - parts.append(doc[method]['description']) - if doc[method]['docstring']['details']: - parts.append(doc[method]['docstring']['details']) + if "description" in doc: + parts.append(doc["description"] or "") + if method in doc and "description" in doc[method]: + parts.append(doc[method]["description"]) + if doc[method]["docstring"]["details"]: + parts.append(doc[method]["docstring"]["details"]) - return '\n'.join(parts).strip() + return "\n".join(parts).strip() def operation_id_for(self, doc, method): - '''Extract the operation id''' - return doc[method]['id'] if 'id' in doc[method] else self.api.default_id(doc['name'], method) + """Extract the operation id""" + return ( + doc[method]["id"] + if "id" in doc[method] + else self.api.default_id(doc["name"], method) + ) def parameters_for(self, doc): params = [] - for name, param in iteritems(doc['params']): - param['name'] = name - if 'type' not in param and 'schema' not in param: - param['type'] = 'string' - if 'in' not in param: - param['in'] = 'query' - - if 'type' in param and 'schema' not in param: - ptype = param.get('type', None) + for name, param in iteritems(doc["params"]): + param["name"] = name + if "type" not in param and "schema" not in param: + param["type"] = "string" + if "in" not in param: + param["in"] = "query" + + if "type" in param and "schema" not in param: + ptype = param.get("type", None) if isinstance(ptype, (list, tuple)): typ = ptype[0] - param['type'] = 'array' - param['items'] = {'type': PY_TYPES.get(typ, typ)} + param["type"] = "array" + param["items"] = {"type": PY_TYPES.get(typ, typ)} elif isinstance(ptype, (type, type(None))) and ptype in PY_TYPES: - param['type'] = PY_TYPES[ptype] + param["type"] = PY_TYPES[ptype] params.append(param) # Handle fields mask - mask = doc.get('__mask__') - if (mask and current_app.config['RESTX_MASK_SWAGGER']): + mask = doc.get("__mask__") + if mask and current_app.config["RESTX_MASK_SWAGGER"]: param = { - 'name': current_app.config['RESTX_MASK_HEADER'], - 'in': 'header', - 'type': 'string', - 'format': 'mask', - 'description': 'An optional fields mask', + "name": current_app.config["RESTX_MASK_HEADER"], + "in": "header", + "type": "string", + "format": "mask", + "description": "An optional fields mask", } if isinstance(mask, string_types): - param['default'] = mask + param["default"] = mask params.append(param) return params @@ -534,8 +544,8 @@ def responses_for(self, doc, method): responses = {} for d in doc, doc[method]: - if 'responses' in d: - for code, response in iteritems(d['responses']): + if "responses" in d: + for code, response in iteritems(d["responses"]): code = str(code) if isinstance(response, string_types): description = response @@ -547,47 +557,59 @@ def responses_for(self, doc, method): description, model = response kwargs = {} else: - raise ValueError('Unsupported response specification') + raise ValueError("Unsupported response specification") description = description or DEFAULT_RESPONSE_DESCRIPTION if code in responses: responses[code].update(description=description) else: - responses[code] = {'description': description} + responses[code] = {"description": description} if model: schema = self.serialize_schema(model) - envelope = kwargs.get('envelope') + envelope = kwargs.get("envelope") if envelope: - schema = {'properties': {envelope: schema}} - responses[code]['schema'] = schema - self.process_headers(responses[code], doc, method, kwargs.get('headers')) - if 'model' in d: - code = str(d.get('default_code', HTTPStatus.OK)) + schema = {"properties": {envelope: schema}} + responses[code]["schema"] = schema + self.process_headers( + responses[code], doc, method, kwargs.get("headers") + ) + if "model" in d: + code = str(d.get("default_code", HTTPStatus.OK)) if code not in responses: - responses[code] = self.process_headers(DEFAULT_RESPONSE.copy(), doc, method) - responses[code]['schema'] = self.serialize_schema(d['model']) + responses[code] = self.process_headers( + DEFAULT_RESPONSE.copy(), doc, method + ) + responses[code]["schema"] = self.serialize_schema(d["model"]) - if 'docstring' in d: - for name, description in iteritems(d['docstring']['raises']): + if "docstring" in d: + for name, description in iteritems(d["docstring"]["raises"]): for exception, handler in iteritems(self.api.error_handlers): - error_responses = getattr(handler, '__apidoc__', {}).get('responses', {}) - code = str(list(error_responses.keys())[0]) if error_responses else None + error_responses = getattr(handler, "__apidoc__", {}).get( + "responses", {} + ) + code = ( + str(list(error_responses.keys())[0]) + if error_responses + else None + ) if code and exception.__name__ == name: - responses[code] = {'$ref': '#/responses/{0}'.format(name)} + responses[code] = {"$ref": "#/responses/{0}".format(name)} break if not responses: - responses[str(HTTPStatus.OK.value)] = self.process_headers(DEFAULT_RESPONSE.copy(), doc, method) + responses[str(HTTPStatus.OK.value)] = self.process_headers( + DEFAULT_RESPONSE.copy(), doc, method + ) return responses def process_headers(self, response, doc, method=None, headers=None): method_doc = doc.get(method, {}) - if 'headers' in doc or 'headers' in method_doc or headers: - response['headers'] = dict( - (k, _clean_header(v)) for k, v - in itertools.chain( - iteritems(doc.get('headers', {})), - iteritems(method_doc.get('headers', {})), - iteritems(headers or {}) + if "headers" in doc or "headers" in method_doc or headers: + response["headers"] = dict( + (k, _clean_header(v)) + for k, v in itertools.chain( + iteritems(doc.get("headers", {})), + iteritems(method_doc.get("headers", {})), + iteritems(headers or {}), ) ) return response @@ -602,8 +624,8 @@ def serialize_schema(self, model): if isinstance(model, (list, tuple)): model = model[0] return { - 'type': 'array', - 'items': self.serialize_schema(model), + "type": "array", + "items": self.serialize_schema(model), } elif isinstance(model, ModelBase): @@ -621,14 +643,14 @@ def serialize_schema(self, model): return model.__schema__ elif isinstance(model, (type, type(None))) and model in PY_TYPES: - return {'type': PY_TYPES[model]} + return {"type": PY_TYPES[model]} - raise ValueError('Model {0} not registered'.format(model)) + raise ValueError("Model {0} not registered".format(model)) def register_model(self, model): name = model.name if isinstance(model, ModelBase) else model if name not in self.api.models: - raise ValueError('Model {0} not registered'.format(name)) + raise ValueError("Model {0} not registered".format(name)) specs = self.api.models[name] self._registered_models[name] = specs if isinstance(specs, ModelBase): @@ -650,12 +672,12 @@ def register_field(self, field): def security_for(self, doc, method): security = None - if 'security' in doc: - auth = doc['security'] + if "security" in doc: + auth = doc["security"] security = self.security_requirements(auth) - if 'security' in doc[method]: - auth = doc[method]['security'] + if "security" in doc[method]: + auth = doc[method]["security"] security = self.security_requirements(auth) return security diff --git a/flask_restx/utils.py b/flask_restx/utils.py index 9ec17721..5ba79f7b 100644 --- a/flask_restx/utils.py +++ b/flask_restx/utils.py @@ -10,15 +10,22 @@ from ._http import HTTPStatus -FIRST_CAP_RE = re.compile('(.)([A-Z][a-z]+)') -ALL_CAP_RE = re.compile('([a-z0-9])([A-Z])') +FIRST_CAP_RE = re.compile("(.)([A-Z][a-z]+)") +ALL_CAP_RE = re.compile("([a-z0-9])([A-Z])") -__all__ = ('merge', 'camel_to_dash', 'default_id', 'not_none', 'not_none_sorted', 'unpack') +__all__ = ( + "merge", + "camel_to_dash", + "default_id", + "not_none", + "not_none_sorted", + "unpack", +) def merge(first, second): - ''' + """ Recursively merges two dictionaries. Second dictionary values will take precedence over those from the first one. @@ -28,7 +35,7 @@ def merge(first, second): :param dict second: The second dictionary :return: the resulting merged dictionary :rtype: dict - ''' + """ if not isinstance(second, dict): return second result = deepcopy(first) @@ -41,46 +48,46 @@ def merge(first, second): def camel_to_dash(value): - ''' + """ Transform a CamelCase string into a low_dashed one :param str value: a CamelCase string to transform :return: the low_dashed string :rtype: str - ''' - first_cap = FIRST_CAP_RE.sub(r'\1_\2', value) - return ALL_CAP_RE.sub(r'\1_\2', first_cap).lower() + """ + first_cap = FIRST_CAP_RE.sub(r"\1_\2", value) + return ALL_CAP_RE.sub(r"\1_\2", first_cap).lower() def default_id(resource, method): - '''Default operation ID generator''' - return '{0}_{1}'.format(method, camel_to_dash(resource)) + """Default operation ID generator""" + return "{0}_{1}".format(method, camel_to_dash(resource)) def not_none(data): - ''' + """ Remove all keys where value is None :param dict data: A dictionary with potentially some values set to None :return: The same dictionary without the keys with values to ``None`` :rtype: dict - ''' + """ return dict((k, v) for k, v in iteritems(data) if v is not None) def not_none_sorted(data): - ''' + """ Remove all keys where value is None :param OrderedDict data: A dictionary with potentially some values set to None :return: The same dictionary without the keys with values to ``None`` :rtype: OrderedDict - ''' + """ return OrderedDict((k, v) for k, v in sorted(iteritems(data)) if v is not None) def unpack(response, default_code=HTTPStatus.OK): - ''' + """ Unpack a Flask standard response. Flask response can be: @@ -98,7 +105,7 @@ def unpack(response, default_code=HTTPStatus.OK): :return: a 3-tuple ``(data, code, headers)`` :rtype: tuple :raise ValueError: if the response does not have one of the expected format - ''' + """ if not isinstance(response, tuple): # data only return response, default_code, {} @@ -114,4 +121,4 @@ def unpack(response, default_code=HTTPStatus.OK): data, code, headers = response return data, code or default_code, headers else: - raise ValueError('Too many response values') + raise ValueError("Too many response values") diff --git a/setup.py b/setup.py index 33ab6234..0039c9d3 100644 --- a/setup.py +++ b/setup.py @@ -9,100 +9,103 @@ from setuptools import setup, find_packages -RE_REQUIREMENT = re.compile(r'^\s*-r\s*(?P.*)$') +RE_REQUIREMENT = re.compile(r"^\s*-r\s*(?P.*)$") PYPI_RST_FILTERS = ( # Replace Python crossreferences by simple monospace - (r':(?:class|func|meth|mod|attr|obj|exc|data|const):`~(?:\w+\.)*(\w+)`', r'``\1``'), - (r':(?:class|func|meth|mod|attr|obj|exc|data|const):`([^`]+)`', r'``\1``'), + (r":(?:class|func|meth|mod|attr|obj|exc|data|const):`~(?:\w+\.)*(\w+)`", r"``\1``"), + (r":(?:class|func|meth|mod|attr|obj|exc|data|const):`([^`]+)`", r"``\1``"), # replace doc references - (r':doc:`(.+) <(.*)>`', r'`\1 `_'), + ( + r":doc:`(.+) <(.*)>`", + r"`\1 `_", + ), # replace issues references - (r':issue:`(.+?)`', r'`#\1 `_'), + ( + r":issue:`(.+?)`", + r"`#\1 `_", + ), # replace pr references - (r':pr:`(.+?)`', r'`#\1 `_'), + (r":pr:`(.+?)`", r"`#\1 `_"), # replace commit references - (r':commit:`(.+?)`', r'`#\1 `_'), + ( + r":commit:`(.+?)`", + r"`#\1 `_", + ), # Drop unrecognized currentmodule - (r'\.\. currentmodule:: .*', ''), + (r"\.\. currentmodule:: .*", ""), ) def rst(filename): - ''' + """ Load rst file and sanitize it for PyPI. Remove unsupported github tags: - code-block directive - all badges - ''' + """ content = io.open(filename).read() for regex, replacement in PYPI_RST_FILTERS: content = re.sub(regex, replacement, content) return content - def pip(filename): - '''Parse pip reqs file and transform it to setuptools requirements.''' + """Parse pip reqs file and transform it to setuptools requirements.""" requirements = [] - for line in io.open(os.path.join('requirements', '{0}.pip'.format(filename))): + for line in io.open(os.path.join("requirements", "{0}.pip".format(filename))): line = line.strip() - if not line or '://' in line or line.startswith('#'): + if not line or "://" in line or line.startswith("#"): continue requirements.append(line) return requirements -long_description = '\n'.join(( - rst('README.rst'), - '' -)) +long_description = "\n".join((rst("README.rst"), "")) -exec(compile(open('flask_restx/__about__.py').read(), 'flask_restx/__about__.py', 'exec')) +exec( + compile(open("flask_restx/__about__.py").read(), "flask_restx/__about__.py", "exec") +) -install_requires = pip('install') -doc_require = pip('doc') -tests_require = pip('test') -dev_require = tests_require + pip('develop') +install_requires = pip("install") +doc_require = pip("doc") +tests_require = pip("test") +dev_require = tests_require + pip("develop") setup( - name='flask-restx', + name="flask-restx", version=__version__, description=__description__, long_description=long_description, - url='https://github.com/python-restx/flask-restx', - author='python-restx Authors', - packages=find_packages(exclude=['tests', 'tests.*']), + url="https://github.com/python-restx/flask-restx", + author="python-restx Authors", + packages=find_packages(exclude=["tests", "tests.*"]), include_package_data=True, install_requires=install_requires, tests_require=tests_require, dev_require=dev_require, - extras_require={ - 'test': tests_require, - 'doc': doc_require, - 'dev': dev_require, - }, - license='BSD-3-Clause', + extras_require={"test": tests_require, "doc": doc_require, "dev": dev_require,}, + license="BSD-3-Clause", zip_safe=False, - keywords='flask restx rest api swagger openapi', + keywords="flask restx rest api swagger openapi", classifiers=[ - 'Development Status :: 3 - Alpha', - 'Programming Language :: Python', - 'Environment :: Web Environment', - 'Operating System :: OS Independent', - 'Intended Audience :: Developers', - 'Topic :: System :: Software Distribution', - 'Programming Language :: Python', - 'Programming Language :: Python :: 2', - 'Programming Language :: Python :: 2.7', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.5', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: Implementation :: PyPy', - 'Topic :: Software Development :: Libraries :: Python Modules', - 'License :: OSI Approved :: BSD License', + "Development Status :: 3 - Alpha", + "Programming Language :: Python", + "Environment :: Web Environment", + "Operating System :: OS Independent", + "Intended Audience :: Developers", + "Topic :: System :: Software Distribution", + "Programming Language :: Python", + "Programming Language :: Python :: 2", + "Programming Language :: Python :: 2.7", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.5", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: Implementation :: PyPy", + "Topic :: Software Development :: Libraries :: Python Modules", + "License :: OSI Approved :: BSD License", ], ) diff --git a/tasks.py b/tasks.py index 24efaa7b..ad89522e 100644 --- a/tasks.py +++ b/tasks.py @@ -11,53 +11,53 @@ ROOT = os.path.dirname(__file__) CLEAN_PATTERNS = [ - 'build', - 'dist', - 'cover', - 'docs/_build', - '**/*.pyc', - '.tox', - '**/__pycache__', - 'reports', - '*.egg-info', + "build", + "dist", + "cover", + "docs/_build", + "**/*.pyc", + ".tox", + "**/__pycache__", + "reports", + "*.egg-info", ] def color(code): - '''A simple ANSI color wrapper factory''' - return lambda t: '\033[{0}{1}\033[0;m'.format(code, t) + """A simple ANSI color wrapper factory""" + return lambda t: "\033[{0}{1}\033[0;m".format(code, t) -green = color('1;32m') -red = color('1;31m') -blue = color('1;30m') -cyan = color('1;36m') -purple = color('1;35m') -white = color('1;39m') +green = color("1;32m") +red = color("1;31m") +blue = color("1;30m") +cyan = color("1;36m") +purple = color("1;35m") +white = color("1;39m") def header(text): - '''Display an header''' - print(' '.join((blue('>>'), cyan(text)))) + """Display an header""" + print(" ".join((blue(">>"), cyan(text)))) sys.stdout.flush() def info(text, *args, **kwargs): - '''Display informations''' + """Display informations""" text = text.format(*args, **kwargs) - print(' '.join((purple('>>>'), text))) + print(" ".join((purple(">>>"), text))) sys.stdout.flush() def success(text): - '''Display a success message''' - print(' '.join((green('>>'), white(text)))) + """Display a success message""" + print(" ".join((green(">>"), white(text)))) sys.stdout.flush() def error(text): - '''Display an error message''' - print(red('✘ {0}'.format(text))) + """Display an error message""" + print(red("✘ {0}".format(text))) sys.stdout.flush() @@ -68,135 +68,153 @@ def exit(text=None, code=-1): def build_args(*args): - return ' '.join(a for a in args if a) + return " ".join(a for a in args if a) @task def clean(ctx): - '''Cleanup all build artifacts''' + """Cleanup all build artifacts""" header(clean.__doc__) with ctx.cd(ROOT): for pattern in CLEAN_PATTERNS: - info('Removing {0}', pattern) - ctx.run('rm -rf {0}'.format(pattern)) + info("Removing {0}", pattern) + ctx.run("rm -rf {0}".format(pattern)) @task def deps(ctx): - '''Install or update development dependencies''' + """Install or update development dependencies""" header(deps.__doc__) with ctx.cd(ROOT): - ctx.run('pip install -r requirements/develop.pip -r requirements/doc.pip', pty=True) + ctx.run( + "pip install -r requirements/develop.pip -r requirements/doc.pip", pty=True + ) @task def demo(ctx): - '''Run the demo''' + """Run the demo""" header(demo.__doc__) with ctx.cd(ROOT): - ctx.run('python examples/todo.py') + ctx.run("python examples/todo.py") @task def test(ctx, profile=False): - '''Run tests suite''' + """Run tests suite""" header(test.__doc__) - kwargs = build_args( - '--benchmark-skip', - '--profile' if profile else None, - ) + kwargs = build_args("--benchmark-skip", "--profile" if profile else None,) with ctx.cd(ROOT): - ctx.run('pytest {0}'.format(kwargs), pty=True) + ctx.run("pytest {0}".format(kwargs), pty=True) @task -def benchmark(ctx, max_time=2, save=False, compare=False, histogram=False, profile=False, tox=False): - '''Run benchmarks''' +def benchmark( + ctx, + max_time=2, + save=False, + compare=False, + histogram=False, + profile=False, + tox=False, +): + """Run benchmarks""" header(benchmark.__doc__) ts = datetime.now() kwargs = build_args( - '--benchmark-max-time={0}'.format(max_time), - '--benchmark-autosave' if save else None, - '--benchmark-compare' if compare else None, - '--benchmark-histogram=histograms/{0:%Y%m%d-%H%M%S}'.format(ts) if histogram else None, - '--benchmark-cprofile=tottime' if profile else None, + "--benchmark-max-time={0}".format(max_time), + "--benchmark-autosave" if save else None, + "--benchmark-compare" if compare else None, + "--benchmark-histogram=histograms/{0:%Y%m%d-%H%M%S}".format(ts) + if histogram + else None, + "--benchmark-cprofile=tottime" if profile else None, ) - cmd = 'pytest tests/benchmarks {0}'.format(kwargs) + cmd = "pytest tests/benchmarks {0}".format(kwargs) if tox: - envs = ctx.run('tox -l', hide=True).stdout.splitlines() - envs = ','.join(e for e in envs if e != 'doc') - cmd = 'tox -e {envs} -- {cmd}'.format(envs=envs, cmd=cmd) + envs = ctx.run("tox -l", hide=True).stdout.splitlines() + envs = ",".join(e for e in envs if e != "doc") + cmd = "tox -e {envs} -- {cmd}".format(envs=envs, cmd=cmd) ctx.run(cmd, pty=True) @task def cover(ctx, html=False): - '''Run tests suite with coverage''' + """Run tests suite with coverage""" header(cover.__doc__) - extra = '--cov-report html' if html else '' + extra = "--cov-report html" if html else "" with ctx.cd(ROOT): - ctx.run('pytest --benchmark-skip --cov flask_restx --cov-report term {0}'.format(extra), pty=True) + ctx.run( + "pytest --benchmark-skip --cov flask_restx --cov-report term {0}".format( + extra + ), + pty=True, + ) @task def tox(ctx): - '''Run tests against Python versions''' + """Run tests against Python versions""" header(tox.__doc__) - ctx.run('tox', pty=True) + ctx.run("tox", pty=True) @task def qa(ctx): - '''Run a quality report''' + """Run a quality report""" header(qa.__doc__) with ctx.cd(ROOT): - info('Ensure PyPI can render README and CHANGELOG') - info('Building dist package') - dist = ctx.run('python setup.py sdist', pty=True, warn=False, hide=True) + info("Ensure PyPI can render README and CHANGELOG") + info("Building dist package") + dist = ctx.run("python setup.py sdist", pty=True, warn=False, hide=True) if dist.failed: - error('Unable to build sdist package') - exit('Quality check failed', dist.return_code) - readme_results = ctx.run('twine check dist/*', pty=True, warn=True, hide=True) + error("Unable to build sdist package") + exit("Quality check failed", dist.return_code) + readme_results = ctx.run("twine check dist/*", pty=True, warn=True, hide=True) if readme_results.failed: print(readme_results.stdout) - error('README and/or CHANGELOG is not renderable by PyPI') + error("README and/or CHANGELOG is not renderable by PyPI") else: - success('README and CHANGELOG are renderable by PyPI') + success("README and CHANGELOG are renderable by PyPI") if readme_results.failed: - exit('Quality check failed', readme_results.return_code) - success('Quality check OK') + exit("Quality check failed", readme_results.return_code) + success("Quality check OK") @task def doc(ctx): - '''Build the documentation''' + """Build the documentation""" header(doc.__doc__) - with ctx.cd(os.path.join(ROOT, 'doc')): - ctx.run('make html', pty=True) + with ctx.cd(os.path.join(ROOT, "doc")): + ctx.run("make html", pty=True) @task def assets(ctx): - '''Fetch web assets''' + """Fetch web assets""" header(assets.__doc__) with ctx.cd(ROOT): - ctx.run('npm install') - ctx.run('mkdir -p flask_restx/static') - ctx.run('cp node_modules/swagger-ui-dist/{swagger-ui*.{css,js}{,.map},favicon*.png,oauth2-redirect.html} flask_restx/static') + ctx.run("npm install") + ctx.run("mkdir -p flask_restx/static") + ctx.run( + "cp node_modules/swagger-ui-dist/{swagger-ui*.{css,js}{,.map},favicon*.png,oauth2-redirect.html} flask_restx/static" + ) # Until next release we need to install droid sans separately - ctx.run('cp node_modules/typeface-droid-sans/index.css flask_restx/static/droid-sans.css') - ctx.run('cp -R node_modules/typeface-droid-sans/files flask_restx/static/') + ctx.run( + "cp node_modules/typeface-droid-sans/index.css flask_restx/static/droid-sans.css" + ) + ctx.run("cp -R node_modules/typeface-droid-sans/files flask_restx/static/") @task def dist(ctx): - '''Package for distribution''' + """Package for distribution""" header(dist.__doc__) with ctx.cd(ROOT): - ctx.run('python setup.py bdist_wheel', pty=True) + ctx.run("python setup.py bdist_wheel", pty=True) @task(clean, deps, test, doc, qa, assets, dist, default=True) def all(ctx): - '''Run tests, reports and packaging''' + """Run tests, reports and packaging""" pass diff --git a/tests/benchmarks/bench_marshalling.py b/tests/benchmarks/bench_marshalling.py index 9b505e04..9cf979b1 100644 --- a/tests/benchmarks/bench_marshalling.py +++ b/tests/benchmarks/bench_marshalling.py @@ -6,31 +6,21 @@ fake = Faker() -person_fields = { - 'name': fields.String, - 'age': fields.Integer -} +person_fields = {"name": fields.String, "age": fields.Integer} family_fields = { - 'father': fields.Nested(person_fields), - 'mother': fields.Nested(person_fields), - 'children': fields.List(fields.Nested(person_fields)) + "father": fields.Nested(person_fields), + "mother": fields.Nested(person_fields), + "children": fields.List(fields.Nested(person_fields)), } def person(): - return { - 'name': fake.name(), - 'age': fake.pyint() - } + return {"name": fake.name(), "age": fake.pyint()} def family(): - return { - 'father': person(), - 'mother': person(), - 'children': [person(), person()] - } + return {"father": person(), "mother": person(), "children": [person(), person()]} def marshal_simple(): @@ -42,16 +32,16 @@ def marshal_nested(): def marshal_simple_with_mask(app): - with app.test_request_context('/', headers={'X-Fields': 'name'}): + with app.test_request_context("/", headers={"X-Fields": "name"}): return marshal(person(), person_fields) def marshal_nested_with_mask(app): - with app.test_request_context('/', headers={'X-Fields': 'father,children{name}'}): + with app.test_request_context("/", headers={"X-Fields": "father,children{name}"}): return marshal(family(), family_fields) -@pytest.mark.benchmark(group='marshalling') +@pytest.mark.benchmark(group="marshalling") class MarshallingBenchmark(object): def bench_marshal_simple(self, benchmark): benchmark(marshal_simple) diff --git a/tests/benchmarks/bench_swagger.py b/tests/benchmarks/bench_swagger.py index 9d3509e6..da52b04c 100644 --- a/tests/benchmarks/bench_swagger.py +++ b/tests/benchmarks/bench_swagger.py @@ -5,86 +5,86 @@ api = Api() -person = api.model('Person', { - 'name': fields.String, - 'age': fields.Integer -}) +person = api.model("Person", {"name": fields.String, "age": fields.Integer}) -family = api.model('Family', { - 'name': fields.String, - 'father': fields.Nested(person), - 'mother': fields.Nested(person), - 'children': fields.List(fields.Nested(person)) -}) +family = api.model( + "Family", + { + "name": fields.String, + "father": fields.Nested(person), + "mother": fields.Nested(person), + "children": fields.List(fields.Nested(person)), + }, +) -@api.route('/families', endpoint='families') +@api.route("/families", endpoint="families") class Families(Resource): @api.marshal_with(family) def get(self): - '''List all families''' + """List all families""" pass @api.marshal_with(family) - @api.response(201, 'Family created') + @api.response(201, "Family created") def post(self): - '''Create a new family''' + """Create a new family""" pass -@api.route('/families//', endpoint='family') -@api.response(404, 'Family not found') +@api.route("/families//", endpoint="family") +@api.response(404, "Family not found") class Family(Resource): @api.marshal_with(family) def get(self): - '''Get a family given its name''' + """Get a family given its name""" pass @api.marshal_with(family) def put(self): - '''Update a family given its name''' + """Update a family given its name""" pass -@api.route('/persons', endpoint='persons') +@api.route("/persons", endpoint="persons") class Persons(Resource): @api.marshal_with(person) def get(self): - '''List all persons''' + """List all persons""" pass @api.marshal_with(person) - @api.response(201, 'Person created') + @api.response(201, "Person created") def post(self): - '''Create a new person''' + """Create a new person""" pass -@api.route('/persons//', endpoint='person') -@api.response(404, 'Person not found') +@api.route("/persons//", endpoint="person") +@api.response(404, "Person not found") class Person(Resource): @api.marshal_with(person) def get(self): - '''Get a person given its name''' + """Get a person given its name""" pass @api.marshal_with(person) def put(self): - '''Update a person given its name''' + """Update a person given its name""" pass def swagger_specs(app): - with app.test_request_context('/'): + with app.test_request_context("/"): return Swagger(api).as_dict() def swagger_specs_cached(app): - with app.test_request_context('/'): + with app.test_request_context("/"): return api.__schema__ -@pytest.mark.benchmark(group='swagger') +@pytest.mark.benchmark(group="swagger") class SwaggerBenchmark(object): @pytest.fixture(autouse=True) def register(self, app): diff --git a/tests/conftest.py b/tests/conftest.py index 7b3f4172..220c7b96 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -26,19 +26,23 @@ def dict_raise_on_duplicates(self, ordered_pairs): def get_json(self, url, status=200, **kwargs): response = self.get(url, **kwargs) assert response.status_code == status - assert response.content_type == 'application/json' - return json.loads(response.data.decode('utf8'), object_pairs_hook=self.dict_raise_on_duplicates) + assert response.content_type == "application/json" + return json.loads( + response.data.decode("utf8"), + object_pairs_hook=self.dict_raise_on_duplicates, + ) def post_json(self, url, data, status=200, **kwargs): - response = self.post(url, data=json.dumps(data), - headers={'content-type': 'application/json'}) + response = self.post( + url, data=json.dumps(data), headers={"content-type": "application/json"} + ) assert response.status_code == status - assert response.content_type == 'application/json' - return json.loads(response.data.decode('utf8')) + assert response.content_type == "application/json" + return json.loads(response.data.decode("utf8")) - def get_specs(self, prefix='', status=200, **kwargs): - '''Get a Swagger specification for a RESTX API''' - return self.get_json('{0}/swagger.json'.format(prefix), status=status, **kwargs) + def get_specs(self, prefix="", status=200, **kwargs): + """Get a Swagger specification for a RESTX API""" + return self.get_json("{0}/swagger.json".format(prefix), status=status, **kwargs) @pytest.fixture @@ -50,16 +54,16 @@ def app(): @pytest.fixture def api(request, app): - marker = request.node.get_closest_marker('api') + marker = request.node.get_closest_marker("api") bpkwargs = {} kwargs = {} if marker: - if 'prefix' in marker.kwargs: - bpkwargs['url_prefix'] = marker.kwargs.pop('prefix') - if 'subdomain' in marker.kwargs: - bpkwargs['subdomain'] = marker.kwargs.pop('subdomain') + if "prefix" in marker.kwargs: + bpkwargs["url_prefix"] = marker.kwargs.pop("prefix") + if "subdomain" in marker.kwargs: + bpkwargs["subdomain"] = marker.kwargs.pop("subdomain") kwargs = marker.kwargs - blueprint = Blueprint('api', __name__, **bpkwargs) + blueprint = Blueprint("api", __name__, **bpkwargs) api = restx.Api(blueprint, **kwargs) app.register_blueprint(blueprint) yield api @@ -70,7 +74,7 @@ def mock_app(mocker): app = mocker.Mock(Flask) # mock Flask app object doesn't have any real loggers -> mock logging # set up on Api object - mocker.patch.object(restx.Api, '_configure_namespace_logger') + mocker.patch.object(restx.Api, "_configure_namespace_logger") app.view_functions = {} app.extensions = {} app.config = {} @@ -79,8 +83,8 @@ def mock_app(mocker): @pytest.fixture(autouse=True) def _push_custom_request_context(request): - app = request.getfixturevalue('app') - options = request.node.get_closest_marker('request_context') + app = request.getfixturevalue("app") + options = request.node.get_closest_marker("request_context") if options is None: return diff --git a/tests/legacy/test_api_legacy.py b/tests/legacy/test_api_legacy.py index 8b43c22a..5d6649c8 100644 --- a/tests/legacy/test_api_legacy.py +++ b/tests/legacy/test_api_legacy.py @@ -24,107 +24,108 @@ def test_unauthorized_no_challenge_by_default(self, api, mocker): response = mocker.Mock() response.headers = {} response = api.unauthorized(response) - assert 'WWW-Authenticate' not in response.headers + assert "WWW-Authenticate" not in response.headers @pytest.mark.api(serve_challenge_on_401=True) def test_unauthorized(self, api, mocker): response = mocker.Mock() response.headers = {} response = api.unauthorized(response) - assert response.headers['WWW-Authenticate'] == 'Basic realm="flask-restx"' + assert response.headers["WWW-Authenticate"] == 'Basic realm="flask-restx"' - @pytest.mark.options(HTTP_BASIC_AUTH_REALM='Foo') + @pytest.mark.options(HTTP_BASIC_AUTH_REALM="Foo") @pytest.mark.api(serve_challenge_on_401=True) def test_unauthorized_custom_realm(self, api, mocker): response = mocker.Mock() response.headers = {} response = api.unauthorized(response) - assert response.headers['WWW-Authenticate'] == 'Basic realm="Foo"' + assert response.headers["WWW-Authenticate"] == 'Basic realm="Foo"' def test_handle_error_401_no_challenge_by_default(self, api): resp = api.handle_error(Unauthorized()) assert resp.status_code == 401 - assert 'WWW-Autheneticate' not in resp.headers + assert "WWW-Autheneticate" not in resp.headers @pytest.mark.api(serve_challenge_on_401=True) def test_handle_error_401_sends_challege_default_realm(self, api): exception = HTTPException() exception.code = 401 - exception.data = {'foo': 'bar'} + exception.data = {"foo": "bar"} resp = api.handle_error(exception) assert resp.status_code == 401 - assert resp.headers['WWW-Authenticate'] == 'Basic realm="flask-restx"' + assert resp.headers["WWW-Authenticate"] == 'Basic realm="flask-restx"' @pytest.mark.api(serve_challenge_on_401=True) - @pytest.mark.options(HTTP_BASIC_AUTH_REALM='test-realm') + @pytest.mark.options(HTTP_BASIC_AUTH_REALM="test-realm") def test_handle_error_401_sends_challege_configured_realm(self, api): resp = api.handle_error(Unauthorized()) assert resp.status_code == 401 - assert resp.headers['WWW-Authenticate'] == 'Basic realm="test-realm"' + assert resp.headers["WWW-Authenticate"] == 'Basic realm="test-realm"' def test_handle_error_does_not_swallow_exceptions(self, api): - exception = BadRequest('x') + exception = BadRequest("x") resp = api.handle_error(exception) assert resp.status_code == 400 assert resp.get_data() == b'{"message": "x"}\n' def test_api_representation(self, api): - @api.representation('foo') + @api.representation("foo") def foo(): pass - assert api.representations['foo'] == foo + assert api.representations["foo"] == foo def test_api_base(self, app): api = restx.Api(app) assert api.urls == {} - assert api.prefix == '' - assert api.default_mediatype == 'application/json' + assert api.prefix == "" + assert api.default_mediatype == "application/json" def test_api_delayed_initialization(self, app, client): api = restx.Api() - api.add_resource(HelloWorld, '/', endpoint="hello") + api.add_resource(HelloWorld, "/", endpoint="hello") api.init_app(app) - assert client.get('/').status_code == 200 + assert client.get("/").status_code == 200 def test_api_prefix(self, app): - api = restx.Api(app, prefix='/foo') - assert api.prefix == '/foo' + api = restx.Api(app, prefix="/foo") + assert api.prefix == "/foo" @pytest.mark.api(serve_challenge_on_401=True) def test_handle_auth(self, api): resp = api.handle_error(Unauthorized()) assert resp.status_code == 401 - expected_data = dumps({'message': Unauthorized.description}) + "\n" + expected_data = dumps({"message": Unauthorized.description}) + "\n" assert resp.data.decode() == expected_data - assert 'WWW-Authenticate' in resp.headers + assert "WWW-Authenticate" in resp.headers def test_media_types(self, app): api = restx.Api(app) - with app.test_request_context("/foo", headers={ - 'Accept': 'application/json' - }): - assert api.mediatypes() == ['application/json'] + with app.test_request_context("/foo", headers={"Accept": "application/json"}): + assert api.mediatypes() == ["application/json"] def test_media_types_method(self, app, mocker): api = restx.Api(app) - with app.test_request_context("/foo", headers={ - 'Accept': 'application/xml; q=.5' - }): - assert api.mediatypes_method()(mocker.Mock()) == ['application/xml', 'application/json'] + with app.test_request_context( + "/foo", headers={"Accept": "application/xml; q=.5"} + ): + assert api.mediatypes_method()(mocker.Mock()) == [ + "application/xml", + "application/json", + ] def test_media_types_q(self, app): api = restx.Api(app) - with app.test_request_context("/foo", headers={ - 'Accept': 'application/json; q=1, application/xml; q=.5' - }): - assert api.mediatypes() == ['application/json', 'application/xml'] + with app.test_request_context( + "/foo", headers={"Accept": "application/json; q=1, application/xml; q=.5"} + ): + assert api.mediatypes() == ["application/json", "application/xml"] def test_decorator(self, mocker, mock_app): def return_zero(func): @@ -134,65 +135,64 @@ def return_zero(func): api = restx.Api(mock_app) api.decorators.append(return_zero) api.output = mocker.Mock() - api.add_resource(view, '/foo', endpoint='bar') + api.add_resource(view, "/foo", endpoint="bar") - mock_app.add_url_rule.assert_called_with('/foo', view_func=0) + mock_app.add_url_rule.assert_called_with("/foo", view_func=0) def test_add_resource_endpoint(self, app, mocker): - view = mocker.Mock(**{'as_view.return_value.__name__': str('test_view')}) + view = mocker.Mock(**{"as_view.return_value.__name__": str("test_view")}) api = restx.Api(app) - api.add_resource(view, '/foo', endpoint='bar') + api.add_resource(view, "/foo", endpoint="bar") - view.as_view.assert_called_with('bar', api) + view.as_view.assert_called_with("bar", api) def test_add_two_conflicting_resources_on_same_endpoint(self, app): api = restx.Api(app) class Foo1(restx.Resource): def get(self): - return 'foo1' + return "foo1" class Foo2(restx.Resource): def get(self): - return 'foo2' + return "foo2" - api.add_resource(Foo1, '/foo', endpoint='bar') + api.add_resource(Foo1, "/foo", endpoint="bar") with pytest.raises(ValueError): - api.add_resource(Foo2, '/foo/toto', endpoint='bar') + api.add_resource(Foo2, "/foo/toto", endpoint="bar") def test_add_the_same_resource_on_same_endpoint(self, app): api = restx.Api(app) class Foo1(restx.Resource): def get(self): - return 'foo1' + return "foo1" - api.add_resource(Foo1, '/foo', endpoint='bar') - api.add_resource(Foo1, '/foo/toto', endpoint='blah') + api.add_resource(Foo1, "/foo", endpoint="bar") + api.add_resource(Foo1, "/foo/toto", endpoint="blah") with app.test_client() as client: - foo1 = client.get('/foo') + foo1 = client.get("/foo") assert foo1.data == b'"foo1"\n' - foo2 = client.get('/foo/toto') + foo2 = client.get("/foo/toto") assert foo2.data == b'"foo1"\n' def test_add_resource(self, mocker, mock_app): api = restx.Api(mock_app) api.output = mocker.Mock() - api.add_resource(views.MethodView, '/foo') + api.add_resource(views.MethodView, "/foo") - mock_app.add_url_rule.assert_called_with('/foo', - view_func=api.output()) + mock_app.add_url_rule.assert_called_with("/foo", view_func=api.output()) def test_add_resource_kwargs(self, mocker, mock_app): api = restx.Api(mock_app) api.output = mocker.Mock() - api.add_resource(views.MethodView, '/foo', defaults={"bar": "baz"}) + api.add_resource(views.MethodView, "/foo", defaults={"bar": "baz"}) - mock_app.add_url_rule.assert_called_with('/foo', - view_func=api.output(), - defaults={"bar": "baz"}) + mock_app.add_url_rule.assert_called_with( + "/foo", view_func=api.output(), defaults={"bar": "baz"} + ) def test_add_resource_forward_resource_class_parameters(self, app, client): api = restx.Api(app) @@ -200,23 +200,25 @@ def test_add_resource_forward_resource_class_parameters(self, app, client): class Foo(restx.Resource): def __init__(self, api, *args, **kwargs): self.one = args[0] - self.two = kwargs['secret_state'] + self.two = kwargs["secret_state"] super(Foo, self).__init__(api, *args, **kwargs) def get(self): return "{0} {1}".format(self.one, self.two) - api.add_resource(Foo, '/foo', - resource_class_args=('wonderful',), - resource_class_kwargs={'secret_state': 'slurm'}) + api.add_resource( + Foo, + "/foo", + resource_class_args=("wonderful",), + resource_class_kwargs={"secret_state": "slurm"}, + ) - foo = client.get('/foo') + foo = client.get("/foo") assert foo.data == b'"wonderful slurm"\n' def test_output_unpack(self, app): - def make_empty_response(): - return {'foo': 'bar'} + return {"foo": "bar"} api = restx.Api(app) @@ -227,9 +229,8 @@ def make_empty_response(): assert resp.data.decode() == '{"foo": "bar"}\n' def test_output_func(self, app): - def make_empty_resposne(): - return flask.make_response('') + return flask.make_response("") api = restx.Api(app) @@ -237,7 +238,7 @@ def make_empty_resposne(): wrapper = api.output(make_empty_resposne) resp = wrapper() assert resp.status_code == 200 - assert resp.data.decode() == '' + assert resp.data.decode() == "" def test_resource(self, app, mocker): resource = restx.Resource() @@ -249,7 +250,7 @@ def test_resource_resp(self, app, mocker): resource = restx.Resource() resource.get = mocker.Mock() with app.test_request_context("/foo"): - resource.get.return_value = flask.make_response('') + resource.get.return_value = flask.make_response("") resource.dispatch_request() def test_resource_text_plain(self, app): @@ -258,24 +259,24 @@ def text(data, code, headers=None): class Foo(restx.Resource): representations = { - 'text/plain': text, + "text/plain": text, } def get(self): - return 'hello' + return "hello" - with app.test_request_context("/foo", headers={'Accept': 'text/plain'}): + with app.test_request_context("/foo", headers={"Accept": "text/plain"}): resource = Foo(None) resp = resource.dispatch_request() - assert resp.data.decode() == 'hello' + assert resp.data.decode() == "hello" - @pytest.mark.request_context('/foo') + @pytest.mark.request_context("/foo") def test_resource_error(self, app): resource = restx.Resource() with pytest.raises(AssertionError): resource.dispatch_request() - @pytest.mark.request_context('/foo', method='HEAD') + @pytest.mark.request_context("/foo", method="HEAD") def test_resource_head(self, app): resource = restx.Resource() with pytest.raises(AssertionError): @@ -283,29 +284,29 @@ def test_resource_head(self, app): def test_endpoints(self, app): api = restx.Api(app) - api.add_resource(HelloWorld, '/ids/', endpoint="hello") - with app.test_request_context('/foo'): + api.add_resource(HelloWorld, "/ids/", endpoint="hello") + with app.test_request_context("/foo"): assert api._has_fr_route() is False - with app.test_request_context('/ids/3'): + with app.test_request_context("/ids/3"): assert api._has_fr_route() is True def test_url_for(self, app): api = restx.Api(app) - api.add_resource(HelloWorld, '/ids/') - with app.test_request_context('/foo'): - assert api.url_for(HelloWorld, id=123) == '/ids/123' + api.add_resource(HelloWorld, "/ids/") + with app.test_request_context("/foo"): + assert api.url_for(HelloWorld, id=123) == "/ids/123" def test_url_for_with_blueprint(self, app): """Verify that url_for works when an Api object is mounted on a Blueprint. """ - api_bp = Blueprint('api', __name__) + api_bp = Blueprint("api", __name__) api = restx.Api(api_bp) - api.add_resource(HelloWorld, '/foo/') + api.add_resource(HelloWorld, "/foo/") app.register_blueprint(api_bp) - with app.test_request_context('/foo'): - assert api.url_for(HelloWorld, bar='baz') == '/foo/baz' + with app.test_request_context("/foo"): + assert api.url_for(HelloWorld, bar="baz") == "/foo/baz" def test_exception_header_forwarding_doesnt_duplicate_headers(self, api): """Test that HTTPException's headers do not add a duplicate @@ -314,26 +315,22 @@ def test_exception_header_forwarding_doesnt_duplicate_headers(self, api): https://github.com/flask-restful/flask-restful/issues/534 """ r = api.handle_error(BadRequest()) - assert len(r.headers.getlist('Content-Length')) == 1 + assert len(r.headers.getlist("Content-Length")) == 1 def test_read_json_settings_from_config(self, app, client): class TestConfig(object): - RESTX_JSON = { - 'indent': 2, - 'sort_keys': True, - 'separators': (', ', ': ') - } + RESTX_JSON = {"indent": 2, "sort_keys": True, "separators": (", ", ": ")} app.config.from_object(TestConfig) api = restx.Api(app) class Foo(restx.Resource): def get(self): - return {'foo': 'bar', 'baz': 'qux'} + return {"foo": "bar", "baz": "qux"} - api.add_resource(Foo, '/foo') + api.add_resource(Foo, "/foo") - data = client.get('/foo').data + data = client.get("/foo").data expected = b'{\n "baz": "qux", \n "foo": "bar"\n}\n' @@ -342,21 +339,21 @@ def get(self): def test_use_custom_jsonencoder(self, app, client): class CabageEncoder(JSONEncoder): def default(self, obj): - return 'cabbage' + return "cabbage" class TestConfig(object): - RESTX_JSON = {'cls': CabageEncoder} + RESTX_JSON = {"cls": CabageEncoder} app.config.from_object(TestConfig) api = restx.Api(app) class Cabbage(restx.Resource): def get(self): - return {'frob': object()} + return {"frob": object()} - api.add_resource(Cabbage, '/cabbage') + api.add_resource(Cabbage, "/cabbage") - data = client.get('/cabbage').data + data = client.get("/cabbage").data expected = b'{"frob": "cabbage"}\n' assert data == expected @@ -364,11 +361,11 @@ def get(self): def test_json_with_no_settings(self, api, client): class Foo(restx.Resource): def get(self): - return {'foo': 'bar'} + return {"foo": "bar"} - api.add_resource(Foo, '/foo') + api.add_resource(Foo, "/foo") - data = client.get('/foo').data + data = client.get("/foo").data expected = b'{"foo": "bar"}\n' assert data == expected @@ -376,17 +373,17 @@ def get(self): def test_redirect(self, api, client): class FooResource(restx.Resource): def get(self): - return redirect('/') + return redirect("/") - api.add_resource(FooResource, '/api') + api.add_resource(FooResource, "/api") - resp = client.get('/api') + resp = client.get("/api") assert resp.status_code == 302 - assert resp.headers['Location'] == 'http://localhost/' + assert resp.headers["Location"] == "http://localhost/" def test_calling_owns_endpoint_before_api_init(self): api = restx.Api() - api.owns_endpoint('endpoint') + api.owns_endpoint("endpoint") # with pytest.raises(AttributeError): # try: # except AttributeError as ae: diff --git a/tests/legacy/test_api_with_blueprint.py b/tests/legacy/test_api_with_blueprint.py index 7916a6be..a59ab54e 100644 --- a/tests/legacy/test_api_with_blueprint.py +++ b/tests/legacy/test_api_with_blueprint.py @@ -24,139 +24,140 @@ def get(self): class APIWithBlueprintTest(object): def test_api_base(self, app): - blueprint = Blueprint('test', __name__) + blueprint = Blueprint("test", __name__) api = restx.Api(blueprint) app.register_blueprint(blueprint) assert api.urls == {} - assert api.prefix == '' - assert api.default_mediatype == 'application/json' + assert api.prefix == "" + assert api.default_mediatype == "application/json" def test_api_delayed_initialization(self, app): - blueprint = Blueprint('test', __name__) + blueprint = Blueprint("test", __name__) api = restx.Api() api.init_app(blueprint) app.register_blueprint(blueprint) - api.add_resource(HelloWorld, '/', endpoint="hello") + api.add_resource(HelloWorld, "/", endpoint="hello") def test_add_resource_endpoint(self, app, mocker): - blueprint = Blueprint('test', __name__) + blueprint = Blueprint("test", __name__) api = restx.Api(blueprint) - view = mocker.Mock(**{'as_view.return_value.__name__': str('test_view')}) - api.add_resource(view, '/foo', endpoint='bar') + view = mocker.Mock(**{"as_view.return_value.__name__": str("test_view")}) + api.add_resource(view, "/foo", endpoint="bar") app.register_blueprint(blueprint) - view.as_view.assert_called_with('bar', api) + view.as_view.assert_called_with("bar", api) def test_add_resource_endpoint_after_registration(self, app, mocker): - blueprint = Blueprint('test', __name__) + blueprint = Blueprint("test", __name__) api = restx.Api(blueprint) app.register_blueprint(blueprint) - view = mocker.Mock(**{'as_view.return_value.__name__': str('test_view')}) - api.add_resource(view, '/foo', endpoint='bar') - view.as_view.assert_called_with('bar', api) + view = mocker.Mock(**{"as_view.return_value.__name__": str("test_view")}) + api.add_resource(view, "/foo", endpoint="bar") + view.as_view.assert_called_with("bar", api) def test_url_with_api_prefix(self, app): - blueprint = Blueprint('test', __name__) - api = restx.Api(blueprint, prefix='/api') - api.add_resource(HelloWorld, '/hi', endpoint='hello') + blueprint = Blueprint("test", __name__) + api = restx.Api(blueprint, prefix="/api") + api.add_resource(HelloWorld, "/hi", endpoint="hello") app.register_blueprint(blueprint) - with app.test_request_context('/api/hi'): - assert request.endpoint == 'test.hello' + with app.test_request_context("/api/hi"): + assert request.endpoint == "test.hello" def test_url_with_blueprint_prefix(self, app): - blueprint = Blueprint('test', __name__, url_prefix='/bp') + blueprint = Blueprint("test", __name__, url_prefix="/bp") api = restx.Api(blueprint) - api.add_resource(HelloWorld, '/hi', endpoint='hello') + api.add_resource(HelloWorld, "/hi", endpoint="hello") app.register_blueprint(blueprint) - with app.test_request_context('/bp/hi'): - assert request.endpoint == 'test.hello' + with app.test_request_context("/bp/hi"): + assert request.endpoint == "test.hello" def test_url_with_registration_prefix(self, app): - blueprint = Blueprint('test', __name__) + blueprint = Blueprint("test", __name__) api = restx.Api(blueprint) - api.add_resource(HelloWorld, '/hi', endpoint='hello') - app.register_blueprint(blueprint, url_prefix='/reg') - with app.test_request_context('/reg/hi'): - assert request.endpoint == 'test.hello' + api.add_resource(HelloWorld, "/hi", endpoint="hello") + app.register_blueprint(blueprint, url_prefix="/reg") + with app.test_request_context("/reg/hi"): + assert request.endpoint == "test.hello" def test_registration_prefix_overrides_blueprint_prefix(self, app): - blueprint = Blueprint('test', __name__, url_prefix='/bp') + blueprint = Blueprint("test", __name__, url_prefix="/bp") api = restx.Api(blueprint) - api.add_resource(HelloWorld, '/hi', endpoint='hello') - app.register_blueprint(blueprint, url_prefix='/reg') - with app.test_request_context('/reg/hi'): - assert request.endpoint == 'test.hello' + api.add_resource(HelloWorld, "/hi", endpoint="hello") + app.register_blueprint(blueprint, url_prefix="/reg") + with app.test_request_context("/reg/hi"): + assert request.endpoint == "test.hello" def test_url_with_api_and_blueprint_prefix(self, app): - blueprint = Blueprint('test', __name__, url_prefix='/bp') - api = restx.Api(blueprint, prefix='/api') - api.add_resource(HelloWorld, '/hi', endpoint='hello') + blueprint = Blueprint("test", __name__, url_prefix="/bp") + api = restx.Api(blueprint, prefix="/api") + api.add_resource(HelloWorld, "/hi", endpoint="hello") app.register_blueprint(blueprint) - with app.test_request_context('/bp/api/hi'): - assert request.endpoint == 'test.hello' + with app.test_request_context("/bp/api/hi"): + assert request.endpoint == "test.hello" def test_error_routing(self, app, mocker): - blueprint = Blueprint('test', __name__) + blueprint = Blueprint("test", __name__) api = restx.Api(blueprint) - api.add_resource(HelloWorld, '/hi', endpoint="hello") - api.add_resource(GoodbyeWorld(404), '/bye', endpoint="bye") + api.add_resource(HelloWorld, "/hi", endpoint="hello") + api.add_resource(GoodbyeWorld(404), "/bye", endpoint="bye") app.register_blueprint(blueprint) - with app.test_request_context('/hi', method='POST'): + with app.test_request_context("/hi", method="POST"): assert api._should_use_fr_error_handler() is True assert api._has_fr_route() is True - with app.test_request_context('/bye'): + with app.test_request_context("/bye"): api._should_use_fr_error_handler = mocker.Mock(return_value=False) assert api._has_fr_route() is True def test_non_blueprint_rest_error_routing(self, app, mocker): - blueprint = Blueprint('test', __name__) + blueprint = Blueprint("test", __name__) api = restx.Api(blueprint) - api.add_resource(HelloWorld, '/hi', endpoint="hello") - api.add_resource(GoodbyeWorld(404), '/bye', endpoint="bye") - app.register_blueprint(blueprint, url_prefix='/blueprint') + api.add_resource(HelloWorld, "/hi", endpoint="hello") + api.add_resource(GoodbyeWorld(404), "/bye", endpoint="bye") + app.register_blueprint(blueprint, url_prefix="/blueprint") api2 = restx.Api(app) - api2.add_resource(HelloWorld(api), '/hi', endpoint="hello") - api2.add_resource(GoodbyeWorld(404), '/bye', endpoint="bye") - with app.test_request_context('/hi', method='POST'): + api2.add_resource(HelloWorld(api), "/hi", endpoint="hello") + api2.add_resource(GoodbyeWorld(404), "/bye", endpoint="bye") + with app.test_request_context("/hi", method="POST"): assert api._should_use_fr_error_handler() is False assert api2._should_use_fr_error_handler() is True assert api._has_fr_route() is False assert api2._has_fr_route() is True - with app.test_request_context('/blueprint/hi', method='POST'): + with app.test_request_context("/blueprint/hi", method="POST"): assert api._should_use_fr_error_handler() is True assert api2._should_use_fr_error_handler() is False assert api._has_fr_route() is True assert api2._has_fr_route() is False api._should_use_fr_error_handler = mocker.Mock(return_value=False) api2._should_use_fr_error_handler = mocker.Mock(return_value=False) - with app.test_request_context('/bye'): + with app.test_request_context("/bye"): assert api._has_fr_route() is False assert api2._has_fr_route() is True - with app.test_request_context('/blueprint/bye'): + with app.test_request_context("/blueprint/bye"): assert api._has_fr_route() is True assert api2._has_fr_route() is False def test_non_blueprint_non_rest_error_routing(self, app, mocker): - blueprint = Blueprint('test', __name__) + blueprint = Blueprint("test", __name__) api = restx.Api(blueprint) - api.add_resource(HelloWorld, '/hi', endpoint="hello") - api.add_resource(GoodbyeWorld(404), '/bye', endpoint="bye") - app.register_blueprint(blueprint, url_prefix='/blueprint') + api.add_resource(HelloWorld, "/hi", endpoint="hello") + api.add_resource(GoodbyeWorld(404), "/bye", endpoint="bye") + app.register_blueprint(blueprint, url_prefix="/blueprint") - @app.route('/hi') + @app.route("/hi") def hi(): - return 'hi' + return "hi" - @app.route('/bye') + @app.route("/bye") def bye(): flask.abort(404) - with app.test_request_context('/hi', method='POST'): + + with app.test_request_context("/hi", method="POST"): assert api._should_use_fr_error_handler() is False assert api._has_fr_route() is False - with app.test_request_context('/blueprint/hi', method='POST'): + with app.test_request_context("/blueprint/hi", method="POST"): assert api._should_use_fr_error_handler() is True assert api._has_fr_route() is True api._should_use_fr_error_handler = mocker.Mock(return_value=False) - with app.test_request_context('/bye'): + with app.test_request_context("/bye"): assert api._has_fr_route() is False - with app.test_request_context('/blueprint/bye'): + with app.test_request_context("/blueprint/bye"): assert api._has_fr_route() is True diff --git a/tests/test_accept.py b/tests/test_accept.py index 784677db..69578366 100644 --- a/tests/test_accept.py +++ b/tests/test_accept.py @@ -12,87 +12,87 @@ def get(self): class ErrorsTest(object): def test_accept_default_application_json(self, app, client): api = restx.Api(app) - api.add_resource(Foo, '/test/') + api.add_resource(Foo, "/test/") - res = client.get('/test/', headers={'Accept': None}) + res = client.get("/test/", headers={"Accept": None}) assert res.status_code == 200 - assert res.content_type == 'application/json' + assert res.content_type == "application/json" def test_accept_application_json_by_default(self, app, client): api = restx.Api(app) - api.add_resource(Foo, '/test/') + api.add_resource(Foo, "/test/") - res = client.get('/test/', headers=[('Accept', 'application/json')]) + res = client.get("/test/", headers=[("Accept", "application/json")]) assert res.status_code == 200 - assert res.content_type == 'application/json' + assert res.content_type == "application/json" def test_accept_no_default_match_acceptable(self, app, client): api = restx.Api(app, default_mediatype=None) - api.add_resource(Foo, '/test/') + api.add_resource(Foo, "/test/") - res = client.get('/test/', headers=[('Accept', 'application/json')]) + res = client.get("/test/", headers=[("Accept", "application/json")]) assert res.status_code == 200 - assert res.content_type == 'application/json' + assert res.content_type == "application/json" def test_accept_default_override_accept(self, app, client): api = restx.Api(app) - api.add_resource(Foo, '/test/') + api.add_resource(Foo, "/test/") - res = client.get('/test/', headers=[('Accept', 'text/plain')]) + res = client.get("/test/", headers=[("Accept", "text/plain")]) assert res.status_code == 200 - assert res.content_type == 'application/json' + assert res.content_type == "application/json" def test_accept_default_any_pick_first(self, app, client): api = restx.Api(app) - @api.representation('text/plain') + @api.representation("text/plain") def text_rep(data, status_code, headers=None): resp = app.make_response((str(data), status_code, headers)) return resp - api.add_resource(Foo, '/test/') + api.add_resource(Foo, "/test/") - res = client.get('/test/', headers=[('Accept', '*/*')]) + res = client.get("/test/", headers=[("Accept", "*/*")]) assert res.status_code == 200 - assert res.content_type == 'application/json' + assert res.content_type == "application/json" def test_accept_no_default_no_match_not_acceptable(self, app, client): api = restx.Api(app, default_mediatype=None) - api.add_resource(Foo, '/test/') + api.add_resource(Foo, "/test/") - res = client.get('/test/', headers=[('Accept', 'text/plain')]) + res = client.get("/test/", headers=[("Accept", "text/plain")]) assert res.status_code == 406 - assert res.content_type == 'application/json' + assert res.content_type == "application/json" def test_accept_no_default_custom_repr_match(self, app, client): api = restx.Api(app, default_mediatype=None) api.representations = {} - @api.representation('text/plain') + @api.representation("text/plain") def text_rep(data, status_code, headers=None): resp = app.make_response((str(data), status_code, headers)) return resp - api.add_resource(Foo, '/test/') + api.add_resource(Foo, "/test/") - res = client.get('/test/', headers=[('Accept', 'text/plain')]) + res = client.get("/test/", headers=[("Accept", "text/plain")]) assert res.status_code == 200 - assert res.content_type == 'text/plain' + assert res.content_type == "text/plain" def test_accept_no_default_custom_repr_not_acceptable(self, app, client): api = restx.Api(app, default_mediatype=None) api.representations = {} - @api.representation('text/plain') + @api.representation("text/plain") def text_rep(data, status_code, headers=None): resp = app.make_response((str(data), status_code, headers)) return resp - api.add_resource(Foo, '/test/') + api.add_resource(Foo, "/test/") - res = client.get('/test/', headers=[('Accept', 'application/json')]) + res = client.get("/test/", headers=[("Accept", "application/json")]) assert res.status_code == 406 - assert res.content_type == 'text/plain' + assert res.content_type == "text/plain" def test_accept_no_default_match_q0_not_acceptable(self, app, client): """ @@ -101,56 +101,66 @@ def test_accept_no_default_match_q0_not_acceptable(self, app, client): so this test is expected to fail until we depend on werkzeug >= 1.0 """ api = restx.Api(app, default_mediatype=None) - api.add_resource(Foo, '/test/') + api.add_resource(Foo, "/test/") - res = client.get('/test/', headers=[('Accept', 'application/json; q=0')]) + res = client.get("/test/", headers=[("Accept", "application/json; q=0")]) assert res.status_code == 406 - assert res.content_type == 'application/json' + assert res.content_type == "application/json" def test_accept_no_default_accept_highest_quality_of_two(self, app, client): api = restx.Api(app, default_mediatype=None) - @api.representation('text/plain') + @api.representation("text/plain") def text_rep(data, status_code, headers=None): resp = app.make_response((str(data), status_code, headers)) return resp - api.add_resource(Foo, '/test/') + api.add_resource(Foo, "/test/") - res = client.get('/test/', headers=[('Accept', 'application/json; q=0.1, text/plain; q=1.0')]) + res = client.get( + "/test/", headers=[("Accept", "application/json; q=0.1, text/plain; q=1.0")] + ) assert res.status_code == 200 - assert res.content_type == 'text/plain' + assert res.content_type == "text/plain" def test_accept_no_default_accept_highest_quality_of_three(self, app, client): api = restx.Api(app, default_mediatype=None) - @api.representation('text/html') - @api.representation('text/plain') + @api.representation("text/html") + @api.representation("text/plain") def text_rep(data, status_code, headers=None): resp = app.make_response((str(data), status_code, headers)) return resp - api.add_resource(Foo, '/test/') - - res = client.get('/test/', headers=[('Accept', 'application/json; q=0.1, text/plain; q=0.3, text/html; q=0.2')]) + api.add_resource(Foo, "/test/") + + res = client.get( + "/test/", + headers=[ + ( + "Accept", + "application/json; q=0.1, text/plain; q=0.3, text/html; q=0.2", + ) + ], + ) assert res.status_code == 200 - assert res.content_type == 'text/plain' + assert res.content_type == "text/plain" def test_accept_no_default_no_representations(self, app, client): api = restx.Api(app, default_mediatype=None) api.representations = {} - api.add_resource(Foo, '/test/') + api.add_resource(Foo, "/test/") - res = client.get('/test/', headers=[('Accept', 'text/plain')]) + res = client.get("/test/", headers=[("Accept", "text/plain")]) assert res.status_code == 406 - assert res.content_type == 'text/plain' + assert res.content_type == "text/plain" def test_accept_invalid_default_no_representations(self, app, client): - api = restx.Api(app, default_mediatype='nonexistant/mediatype') + api = restx.Api(app, default_mediatype="nonexistant/mediatype") api.representations = {} - api.add_resource(Foo, '/test/') + api.add_resource(Foo, "/test/") - res = client.get('/test/', headers=[('Accept', 'text/plain')]) + res = client.get("/test/", headers=[("Accept", "text/plain")]) assert res.status_code == 500 diff --git a/tests/test_api.py b/tests/test_api.py index 30be4a54..fb0b9ada 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -10,61 +10,70 @@ class APITest(object): def test_root_endpoint(self, app): - api = restx.Api(app, version='1.0') + api = restx.Api(app, version="1.0") with app.test_request_context(): - url = url_for('root') - assert url == '/' - assert api.base_url == 'http://localhost/' + url = url_for("root") + assert url == "/" + assert api.base_url == "http://localhost/" def test_root_endpoint_lazy(self, app): - api = restx.Api(version='1.0') + api = restx.Api(version="1.0") api.init_app(app) with app.test_request_context(): - url = url_for('root') - assert url == '/' - assert api.base_url == 'http://localhost/' + url = url_for("root") + assert url == "/" + assert api.base_url == "http://localhost/" def test_root_endpoint_with_blueprint(self, app): - blueprint = Blueprint('api', __name__, url_prefix='/api') - api = restx.Api(blueprint, version='1.0') + blueprint = Blueprint("api", __name__, url_prefix="/api") + api = restx.Api(blueprint, version="1.0") app.register_blueprint(blueprint) with app.test_request_context(): - url = url_for('api.root') - assert url == '/api/' - assert api.base_url == 'http://localhost/api/' + url = url_for("api.root") + assert url == "/api/" + assert api.base_url == "http://localhost/api/" def test_root_endpoint_with_blueprint_with_subdomain(self, app): - blueprint = Blueprint('api', __name__, subdomain='api', url_prefix='/api') - api = restx.Api(blueprint, version='1.0') + blueprint = Blueprint("api", __name__, subdomain="api", url_prefix="/api") + api = restx.Api(blueprint, version="1.0") app.register_blueprint(blueprint) with app.test_request_context(): - url = url_for('api.root') - assert url == 'http://api.localhost/api/' - assert api.base_url == 'http://api.localhost/api/' + url = url_for("api.root") + assert url == "http://api.localhost/api/" + assert api.base_url == "http://api.localhost/api/" def test_parser(self): api = restx.Api() assert isinstance(api.parser(), restx.reqparse.RequestParser) def test_doc_decorator(self, app): - api = restx.Api(app, prefix='/api', version='1.0') - params = {'q': {'description': 'some description'}} + api = restx.Api(app, prefix="/api", version="1.0") + params = {"q": {"description": "some description"}} @api.doc(params=params) class TestResource(restx.Resource): pass - assert hasattr(TestResource, '__apidoc__') - assert TestResource.__apidoc__ == {'params': params} + assert hasattr(TestResource, "__apidoc__") + assert TestResource.__apidoc__ == {"params": params} def test_doc_with_inheritance(self, app): - api = restx.Api(app, prefix='/api', version='1.0') - base_params = {'q': {'description': 'some description', 'type': 'string', 'paramType': 'query'}} - child_params = {'q': {'description': 'some new description'}, 'other': {'description': 'another param'}} + api = restx.Api(app, prefix="/api", version="1.0") + base_params = { + "q": { + "description": "some description", + "type": "string", + "paramType": "query", + } + } + child_params = { + "q": {"description": "some new description"}, + "other": {"description": "another param"}, + } @api.doc(params=base_params) class BaseResource(restx.Resource): @@ -74,232 +83,228 @@ class BaseResource(restx.Resource): class TestResource(BaseResource): pass - assert TestResource.__apidoc__ == {'params': { - 'q': { - 'description': 'some new description', - 'type': 'string', - 'paramType': 'query' - }, - 'other': {'description': 'another param'}, - }} + assert TestResource.__apidoc__ == { + "params": { + "q": { + "description": "some new description", + "type": "string", + "paramType": "query", + }, + "other": {"description": "another param"}, + } + } def test_specs_endpoint_not_added(self, app): api = restx.Api() api.init_app(app, add_specs=False) - assert 'specs' not in api.endpoints - assert 'specs' not in app.view_functions + assert "specs" not in api.endpoints + assert "specs" not in app.view_functions def test_specs_endpoint_not_found_if_not_added(self, app, client): api = restx.Api() api.init_app(app, add_specs=False) - resp = client.get('/swagger.json') + resp = client.get("/swagger.json") assert resp.status_code == 404 def test_default_endpoint(self, app): api = restx.Api(app) - @api.route('/test/') + @api.route("/test/") class TestResource(restx.Resource): pass with app.test_request_context(): - assert url_for('test_resource') == '/test/' + assert url_for("test_resource") == "/test/" def test_default_endpoint_lazy(self, app): api = restx.Api() - @api.route('/test/') + @api.route("/test/") class TestResource(restx.Resource): pass api.init_app(app) with app.test_request_context(): - assert url_for('test_resource') == '/test/' + assert url_for("test_resource") == "/test/" def test_default_endpoint_with_blueprint(self, app): - blueprint = Blueprint('api', __name__, url_prefix='/api') + blueprint = Blueprint("api", __name__, url_prefix="/api") api = restx.Api(blueprint) app.register_blueprint(blueprint) - @api.route('/test/') + @api.route("/test/") class TestResource(restx.Resource): pass with app.test_request_context(): - assert url_for('api.test_resource') == '/api/test/' + assert url_for("api.test_resource") == "/api/test/" def test_default_endpoint_with_blueprint_with_subdomain(self, app): - blueprint = Blueprint('api', __name__, subdomain='api', url_prefix='/api') + blueprint = Blueprint("api", __name__, subdomain="api", url_prefix="/api") api = restx.Api(blueprint) app.register_blueprint(blueprint) - @api.route('/test/') + @api.route("/test/") class TestResource(restx.Resource): pass with app.test_request_context(): - assert url_for('api.test_resource') == 'http://api.localhost/api/test/' + assert url_for("api.test_resource") == "http://api.localhost/api/test/" def test_default_endpoint_for_namespace(self, app): api = restx.Api(app) - ns = api.namespace('ns', 'Test namespace') + ns = api.namespace("ns", "Test namespace") - @ns.route('/test/') + @ns.route("/test/") class TestResource(restx.Resource): pass with app.test_request_context(): - assert url_for('ns_test_resource') == '/ns/test/' + assert url_for("ns_test_resource") == "/ns/test/" def test_default_endpoint_lazy_for_namespace(self, app): api = restx.Api() - ns = api.namespace('ns', 'Test namespace') + ns = api.namespace("ns", "Test namespace") - @ns.route('/test/') + @ns.route("/test/") class TestResource(restx.Resource): pass api.init_app(app) with app.test_request_context(): - assert url_for('ns_test_resource') == '/ns/test/' + assert url_for("ns_test_resource") == "/ns/test/" def test_default_endpoint_for_namespace_with_blueprint(self, app): - blueprint = Blueprint('api', __name__, url_prefix='/api') + blueprint = Blueprint("api", __name__, url_prefix="/api") api = restx.Api(blueprint) - ns = api.namespace('ns', 'Test namespace') + ns = api.namespace("ns", "Test namespace") - @ns.route('/test/') + @ns.route("/test/") class TestResource(restx.Resource): pass app.register_blueprint(blueprint) with app.test_request_context(): - assert url_for('api.ns_test_resource') == '/api/ns/test/' + assert url_for("api.ns_test_resource") == "/api/ns/test/" def test_multiple_default_endpoint(self, app): api = restx.Api(app) - @api.route('/test2/') - @api.route('/test/') + @api.route("/test2/") + @api.route("/test/") class TestResource(restx.Resource): pass with app.test_request_context(): - assert url_for('test_resource') == '/test/' - assert url_for('test_resource_2') == '/test2/' + assert url_for("test_resource") == "/test/" + assert url_for("test_resource_2") == "/test2/" def test_multiple_default_endpoint_lazy(self, app): api = restx.Api() - @api.route('/test2/') - @api.route('/test/') + @api.route("/test2/") + @api.route("/test/") class TestResource(restx.Resource): pass api.init_app(app) with app.test_request_context(): - assert url_for('test_resource') == '/test/' - assert url_for('test_resource_2') == '/test2/' + assert url_for("test_resource") == "/test/" + assert url_for("test_resource_2") == "/test2/" def test_multiple_default_endpoint_for_namespace(self, app): api = restx.Api(app) - ns = api.namespace('ns', 'Test namespace') + ns = api.namespace("ns", "Test namespace") - @ns.route('/test2/') - @ns.route('/test/') + @ns.route("/test2/") + @ns.route("/test/") class TestResource(restx.Resource): pass with app.test_request_context(): - assert url_for('ns_test_resource') == '/ns/test/' - assert url_for('ns_test_resource_2') == '/ns/test2/' + assert url_for("ns_test_resource") == "/ns/test/" + assert url_for("ns_test_resource_2") == "/ns/test2/" def test_multiple_default_endpoint_lazy_for_namespace(self, app): api = restx.Api() - ns = api.namespace('ns', 'Test namespace') + ns = api.namespace("ns", "Test namespace") - @ns.route('/test2/') - @ns.route('/test/') + @ns.route("/test2/") + @ns.route("/test/") class TestResource(restx.Resource): pass api.init_app(app) with app.test_request_context(): - assert url_for('ns_test_resource') == '/ns/test/' - assert url_for('ns_test_resource_2') == '/ns/test2/' + assert url_for("ns_test_resource") == "/ns/test/" + assert url_for("ns_test_resource_2") == "/ns/test2/" def test_multiple_default_endpoint_for_namespace_with_blueprint(self, app): - blueprint = Blueprint('api', __name__, url_prefix='/api') + blueprint = Blueprint("api", __name__, url_prefix="/api") api = restx.Api(blueprint) - ns = api.namespace('ns', 'Test namespace') + ns = api.namespace("ns", "Test namespace") - @ns.route('/test2/') - @ns.route('/test/') + @ns.route("/test2/") + @ns.route("/test/") class TestResource(restx.Resource): pass app.register_blueprint(blueprint) with app.test_request_context(): - assert url_for('api.ns_test_resource') == '/api/ns/test/' - assert url_for('api.ns_test_resource_2') == '/api/ns/test2/' + assert url_for("api.ns_test_resource") == "/api/ns/test/" + assert url_for("api.ns_test_resource_2") == "/api/ns/test2/" def test_ns_path_prefixes(self, app): api = restx.Api() - ns = restx.Namespace('test_ns', description='Test namespace') + ns = restx.Namespace("test_ns", description="Test namespace") - @ns.route('/test/', endpoint='test_resource') + @ns.route("/test/", endpoint="test_resource") class TestResource(restx.Resource): pass - api.add_namespace(ns, '/api_test') + api.add_namespace(ns, "/api_test") api.init_app(app) with app.test_request_context(): - assert url_for('test_resource') == '/api_test/test/' + assert url_for("test_resource") == "/api_test/test/" def test_multiple_ns_with_authorizations(self, app): api = restx.Api() - a1 = { - 'apikey': { - 'type': 'apiKey', - 'in': 'header', - 'name': 'X-API' - } - } + a1 = {"apikey": {"type": "apiKey", "in": "header", "name": "X-API"}} a2 = { - 'oauth2': { - 'type': 'oauth2', - 'flow': 'accessCode', - 'tokenUrl': 'https://somewhere.com/token', - 'scopes': { - 'read': 'Grant read-only access', - 'write': 'Grant read-write access', - } + "oauth2": { + "type": "oauth2", + "flow": "accessCode", + "tokenUrl": "https://somewhere.com/token", + "scopes": { + "read": "Grant read-only access", + "write": "Grant read-write access", + }, } } - ns1 = restx.Namespace('ns1', authorizations=a1) - ns2 = restx.Namespace('ns2', authorizations=a2) + ns1 = restx.Namespace("ns1", authorizations=a1) + ns2 = restx.Namespace("ns2", authorizations=a2) - @ns1.route('/') + @ns1.route("/") class Ns1(restx.Resource): - @ns1.doc(security='apikey') + @ns1.doc(security="apikey") def get(self): pass - @ns2.route('/') + @ns2.route("/") class Ns2(restx.Resource): - @ns1.doc(security='oauth2') + @ns1.doc(security="oauth2") def post(self): pass - api.add_namespace(ns1, path='/ns1') - api.add_namespace(ns2, path='/ns2') + api.add_namespace(ns1, path="/ns1") + api.add_namespace(ns2, path="/ns2") api.init_app(app) assert {"apikey": []} in api.__schema__["paths"]["/ns1/"]["get"]["security"] @@ -310,13 +315,13 @@ def post(self): def test_non_ordered_namespace(self, app): api = restx.Api(app) - ns = api.namespace('ns', 'Test namespace') + ns = api.namespace("ns", "Test namespace") assert not ns.ordered def test_ordered_namespace(self, app): api = restx.Api(app, ordered=True) - ns = api.namespace('ns', 'Test namespace') + ns = api.namespace("ns", "Test namespace") assert ns.ordered @@ -329,9 +334,9 @@ class TestResource(restx.Resource): method_decorators = [] api = restx.Api(decorators=[decorator1]) - ns = api.namespace('test_ns', decorators=[decorator2, decorator3]) + ns = api.namespace("test_ns", decorators=[decorator2, decorator3]) - ns.add_resource(TestResource, '/test', endpoint='test') + ns.add_resource(TestResource, "/test", endpoint="test") api.init_app(app) assert decorator1.called is True diff --git a/tests/test_apidoc.py b/tests/test_apidoc.py index 8fb995fb..b1e7bc03 100644 --- a/tests/test_apidoc.py +++ b/tests/test_apidoc.py @@ -11,124 +11,126 @@ class APIDocTest(object): def test_default_apidoc_on_root(self, app, client): - restx.Api(app, version='1.0') + restx.Api(app, version="1.0") - assert url_for('doc') == url_for('root') + assert url_for("doc") == url_for("root") - response = client.get(url_for('doc')) + response = client.get(url_for("doc")) assert response.status_code == 200 - assert response.content_type == 'text/html; charset=utf-8' + assert response.content_type == "text/html; charset=utf-8" def test_default_apidoc_on_root_lazy(self, app, client): - api = restx.Api(version='1.0') + api = restx.Api(version="1.0") api.init_app(app) - assert url_for('doc') == url_for('root') + assert url_for("doc") == url_for("root") - response = client.get(url_for('doc')) + response = client.get(url_for("doc")) assert response.status_code == 200 - assert response.content_type == 'text/html; charset=utf-8' + assert response.content_type == "text/html; charset=utf-8" def test_default_apidoc_on_root_with_blueprint(self, app, client): - blueprint = Blueprint('api', __name__, url_prefix='/api') - restx.Api(blueprint, version='1.0') + blueprint = Blueprint("api", __name__, url_prefix="/api") + restx.Api(blueprint, version="1.0") app.register_blueprint(blueprint) - assert url_for('api.doc') == url_for('api.root') + assert url_for("api.doc") == url_for("api.root") - response = client.get(url_for('api.doc')) + response = client.get(url_for("api.doc")) assert response.status_code == 200 - assert response.content_type == 'text/html; charset=utf-8' + assert response.content_type == "text/html; charset=utf-8" def test_apidoc_with_custom_validator(self, app, client): - app.config['SWAGGER_VALIDATOR_URL'] = 'http://somewhere.com/validator' - restx.Api(app, version='1.0') + app.config["SWAGGER_VALIDATOR_URL"] = "http://somewhere.com/validator" + restx.Api(app, version="1.0") - response = client.get(url_for('doc')) + response = client.get(url_for("doc")) assert response.status_code == 200 - assert response.content_type == 'text/html; charset=utf-8' - assert 'validatorUrl: "http://somewhere.com/validator" || null,' in str(response.data) + assert response.content_type == "text/html; charset=utf-8" + assert 'validatorUrl: "http://somewhere.com/validator" || null,' in str( + response.data + ) def test_apidoc_doc_expansion_parameter(self, app, client): restx.Api(app) - response = client.get(url_for('doc')) + response = client.get(url_for("doc")) assert 'docExpansion: "none"' in str(response.data) - app.config['SWAGGER_UI_DOC_EXPANSION'] = 'list' - response = client.get(url_for('doc')) + app.config["SWAGGER_UI_DOC_EXPANSION"] = "list" + response = client.get(url_for("doc")) assert 'docExpansion: "list"' in str(response.data) - app.config['SWAGGER_UI_DOC_EXPANSION'] = 'full' - response = client.get(url_for('doc')) + app.config["SWAGGER_UI_DOC_EXPANSION"] = "full" + response = client.get(url_for("doc")) assert 'docExpansion: "full"' in str(response.data) def test_apidoc_doc_display_operation_id(self, app, client): restx.Api(app) - response = client.get(url_for('doc')) - assert 'displayOperationId: false' in str(response.data) + response = client.get(url_for("doc")) + assert "displayOperationId: false" in str(response.data) - app.config['SWAGGER_UI_OPERATION_ID'] = False - response = client.get(url_for('doc')) - assert 'displayOperationId: false' in str(response.data) + app.config["SWAGGER_UI_OPERATION_ID"] = False + response = client.get(url_for("doc")) + assert "displayOperationId: false" in str(response.data) - app.config['SWAGGER_UI_OPERATION_ID'] = True - response = client.get(url_for('doc')) - assert 'displayOperationId: true' in str(response.data) + app.config["SWAGGER_UI_OPERATION_ID"] = True + response = client.get(url_for("doc")) + assert "displayOperationId: true" in str(response.data) def test_apidoc_doc_display_request_duration(self, app, client): restx.Api(app) - response = client.get(url_for('doc')) - assert 'displayRequestDuration: false' in str(response.data) + response = client.get(url_for("doc")) + assert "displayRequestDuration: false" in str(response.data) - app.config['SWAGGER_UI_REQUEST_DURATION'] = False - response = client.get(url_for('doc')) - assert 'displayRequestDuration: false' in str(response.data) + app.config["SWAGGER_UI_REQUEST_DURATION"] = False + response = client.get(url_for("doc")) + assert "displayRequestDuration: false" in str(response.data) - app.config['SWAGGER_UI_REQUEST_DURATION'] = True - response = client.get(url_for('doc')) - assert 'displayRequestDuration: true' in str(response.data) + app.config["SWAGGER_UI_REQUEST_DURATION"] = True + response = client.get(url_for("doc")) + assert "displayRequestDuration: true" in str(response.data) def test_custom_apidoc_url(self, app, client): - restx.Api(app, version='1.0', doc='/doc/') + restx.Api(app, version="1.0", doc="/doc/") - doc_url = url_for('doc') - root_url = url_for('root') + doc_url = url_for("doc") + root_url = url_for("root") assert doc_url != root_url response = client.get(root_url) assert response.status_code == 404 - assert doc_url == '/doc/' + assert doc_url == "/doc/" response = client.get(doc_url) assert response.status_code == 200 - assert response.content_type == 'text/html; charset=utf-8' + assert response.content_type == "text/html; charset=utf-8" def test_custom_api_prefix(self, app, client): - prefix = '/api' + prefix = "/api" api = restx.Api(app, prefix=prefix) - api.namespace('resource') - assert url_for('root') == prefix + api.namespace("resource") + assert url_for("root") == prefix def test_custom_apidoc_page(self, app, client): - api = restx.Api(app, version='1.0') - content = 'My Custom API Doc' + api = restx.Api(app, version="1.0") + content = "My Custom API Doc" @api.documentation def api_doc(): return content - response = client.get(url_for('doc')) + response = client.get(url_for("doc")) assert response.status_code == 200 - assert response.data.decode('utf8') == content + assert response.data.decode("utf8") == content def test_custom_apidoc_page_lazy(self, app, client): - blueprint = Blueprint('api', __name__, url_prefix='/api') - api = restx.Api(blueprint, version='1.0') - content = 'My Custom API Doc' + blueprint = Blueprint("api", __name__, url_prefix="/api") + api = restx.Api(blueprint, version="1.0") + content = "My Custom API Doc" @api.documentation def api_doc(): @@ -136,15 +138,15 @@ def api_doc(): app.register_blueprint(blueprint) - response = client.get(url_for('api.doc')) + response = client.get(url_for("api.doc")) assert response.status_code == 200 - assert response.data.decode('utf8') == content + assert response.data.decode("utf8") == content def test_disabled_apidoc(self, app, client): - restx.Api(app, version='1.0', doc=False) + restx.Api(app, version="1.0", doc=False) with pytest.raises(BuildError): - url_for('doc') + url_for("doc") - response = client.get(url_for('root')) + response = client.get(url_for("root")) assert response.status_code == 404 diff --git a/tests/test_cors.py b/tests/test_cors.py index 3d5e49fe..898daedf 100644 --- a/tests/test_cors.py +++ b/tests/test_cors.py @@ -7,35 +7,36 @@ class ErrorsTest(object): def test_crossdomain(self, app, client): class Foo(Resource): - @cors.crossdomain(origin='*') + @cors.crossdomain(origin="*") def get(self): return "data" api = Api(app) - api.add_resource(Foo, '/test/') + api.add_resource(Foo, "/test/") - res = client.get('/test/') + res = client.get("/test/") assert res.status_code == 200 - assert res.headers['Access-Control-Allow-Origin'] == '*' - assert res.headers['Access-Control-Max-Age'] == '21600' - assert 'HEAD' in res.headers['Access-Control-Allow-Methods'] - assert 'OPTIONS' in res.headers['Access-Control-Allow-Methods'] - assert 'GET' in res.headers['Access-Control-Allow-Methods'] + assert res.headers["Access-Control-Allow-Origin"] == "*" + assert res.headers["Access-Control-Max-Age"] == "21600" + assert "HEAD" in res.headers["Access-Control-Allow-Methods"] + assert "OPTIONS" in res.headers["Access-Control-Allow-Methods"] + assert "GET" in res.headers["Access-Control-Allow-Methods"] def test_access_control_expose_headers(self, app, client): class Foo(Resource): - @cors.crossdomain(origin='*', - expose_headers=['X-My-Header', 'X-Another-Header']) + @cors.crossdomain( + origin="*", expose_headers=["X-My-Header", "X-Another-Header"] + ) def get(self): return "data" api = Api(app) - api.add_resource(Foo, '/test/') + api.add_resource(Foo, "/test/") - res = client.get('/test/') + res = client.get("/test/") assert res.status_code == 200 - assert 'X-MY-HEADER' in res.headers['Access-Control-Expose-Headers'] - assert 'X-ANOTHER-HEADER' in res.headers['Access-Control-Expose-Headers'] + assert "X-MY-HEADER" in res.headers["Access-Control-Expose-Headers"] + assert "X-ANOTHER-HEADER" in res.headers["Access-Control-Expose-Headers"] def test_no_crossdomain(self, app, client): class Foo(Resource): @@ -43,10 +44,10 @@ def get(self): return "data" api = Api(app) - api.add_resource(Foo, '/test/') + api.add_resource(Foo, "/test/") - res = client.get('/test/') + res = client.get("/test/") assert res.status_code == 200 - assert 'Access-Control-Allow-Origin' not in res.headers - assert 'Access-Control-Allow-Methods' not in res.headers - assert 'Access-Control-Max-Age' not in res.headers + assert "Access-Control-Allow-Origin" not in res.headers + assert "Access-Control-Allow-Methods" not in res.headers + assert "Access-Control-Max-Age" not in res.headers diff --git a/tests/test_errors.py b/tests/test_errors.py index 79cbb2b9..98e24d15 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -22,97 +22,97 @@ def test_abort_type(self): def test_abort_data(self): with pytest.raises(HTTPException) as cm: - restx.abort(404, foo='bar') - assert cm.value.data == {'foo': 'bar'} + restx.abort(404, foo="bar") + assert cm.value.data == {"foo": "bar"} def test_abort_no_data(self): with pytest.raises(HTTPException) as cm: restx.abort(404) - assert not hasattr(cm.value, 'data') + assert not hasattr(cm.value, "data") def test_abort_custom_message(self): with pytest.raises(HTTPException) as cm: - restx.abort(404, 'My message') - assert cm.value.data['message'] == 'My message' + restx.abort(404, "My message") + assert cm.value.data["message"] == "My message" def test_abort_code_only_with_defaults(self, app, client): api = restx.Api(app) - @api.route('/test/', endpoint='test') + @api.route("/test/", endpoint="test") class TestResource(restx.Resource): def get(self): api.abort(403) - response = client.get('/test/') + response = client.get("/test/") assert response.status_code == 403 - assert response.content_type == 'application/json' + assert response.content_type == "application/json" - data = json.loads(response.data.decode('utf8')) - assert 'message' in data + data = json.loads(response.data.decode("utf8")) + assert "message" in data def test_abort_with_message(self, app, client): api = restx.Api(app) - @api.route('/test/', endpoint='test') + @api.route("/test/", endpoint="test") class TestResource(restx.Resource): def get(self): - api.abort(403, 'A message') + api.abort(403, "A message") - response = client.get('/test/') + response = client.get("/test/") assert response.status_code == 403 - assert response.content_type == 'application/json' + assert response.content_type == "application/json" - data = json.loads(response.data.decode('utf8')) - assert data['message'] == 'A message' + data = json.loads(response.data.decode("utf8")) + assert data["message"] == "A message" def test_abort_with_lazy_init(self, app, client): api = restx.Api() - @api.route('/test/', endpoint='test') + @api.route("/test/", endpoint="test") class TestResource(restx.Resource): def get(self): api.abort(403) api.init_app(app) - response = client.get('/test/') + response = client.get("/test/") assert response.status_code == 403 - assert response.content_type == 'application/json' + assert response.content_type == "application/json" - data = json.loads(response.data.decode('utf8')) - assert 'message' in data + data = json.loads(response.data.decode("utf8")) + assert "message" in data def test_abort_on_exception(self, app, client): api = restx.Api(app) - @api.route('/test/', endpoint='test') + @api.route("/test/", endpoint="test") class TestResource(restx.Resource): def get(self): raise ValueError() - response = client.get('/test/') + response = client.get("/test/") assert response.status_code == 500 - assert response.content_type == 'application/json' + assert response.content_type == "application/json" - data = json.loads(response.data.decode('utf8')) - assert 'message' in data + data = json.loads(response.data.decode("utf8")) + assert "message" in data def test_abort_on_exception_with_lazy_init(self, app, client): api = restx.Api() - @api.route('/test/', endpoint='test') + @api.route("/test/", endpoint="test") class TestResource(restx.Resource): def get(self): raise ValueError() api.init_app(app) - response = client.get('/test/') + response = client.get("/test/") assert response.status_code == 500 - assert response.content_type == 'application/json' + assert response.content_type == "application/json" - data = json.loads(response.data.decode('utf8')) - assert 'message' in data + data = json.loads(response.data.decode("utf8")) + assert "message" in data def test_errorhandler_for_exception_inheritance(self, app, client): api = restx.Api(app) @@ -120,23 +120,23 @@ def test_errorhandler_for_exception_inheritance(self, app, client): class CustomException(RuntimeError): pass - @api.route('/test/', endpoint='test') + @api.route("/test/", endpoint="test") class TestResource(restx.Resource): def get(self): - raise CustomException('error') + raise CustomException("error") @api.errorhandler(RuntimeError) def handle_custom_exception(error): - return {'message': str(error), 'test': 'value'}, 400 + return {"message": str(error), "test": "value"}, 400 - response = client.get('/test/') + response = client.get("/test/") assert response.status_code == 400 - assert response.content_type == 'application/json' + assert response.content_type == "application/json" - data = json.loads(response.data.decode('utf8')) + data = json.loads(response.data.decode("utf8")) assert data == { - 'message': 'error', - 'test': 'value', + "message": "error", + "test": "value", } def test_errorhandler_for_custom_exception(self, app, client): @@ -145,26 +145,28 @@ def test_errorhandler_for_custom_exception(self, app, client): class CustomException(RuntimeError): pass - @api.route('/test/', endpoint='test') + @api.route("/test/", endpoint="test") class TestResource(restx.Resource): def get(self): - raise CustomException('error') + raise CustomException("error") @api.errorhandler(CustomException) def handle_custom_exception(error): - return {'message': str(error), 'test': 'value'}, 400 + return {"message": str(error), "test": "value"}, 400 - response = client.get('/test/') + response = client.get("/test/") assert response.status_code == 400 - assert response.content_type == 'application/json' + assert response.content_type == "application/json" - data = json.loads(response.data.decode('utf8')) + data = json.loads(response.data.decode("utf8")) assert data == { - 'message': 'error', - 'test': 'value', + "message": "error", + "test": "value", } - def test_blunder_in_errorhandler_is_not_suppressed_in_logs(self, app, client, caplog): + def test_blunder_in_errorhandler_is_not_suppressed_in_logs( + self, app, client, caplog + ): api = restx.Api(app) @@ -174,17 +176,19 @@ class CustomException(RuntimeError): class ProgrammingBlunder(Exception): pass - @api.route('/test/', endpoint="test") + @api.route("/test/", endpoint="test") class TestResource(restx.Resource): def get(self): - raise CustomException('error') + raise CustomException("error") @api.errorhandler(CustomException) def handle_custom_exception(error): - raise ProgrammingBlunder("This exception needs to be logged, not suppressed, then cause 500") + raise ProgrammingBlunder( + "This exception needs to be logged, not suppressed, then cause 500" + ) with caplog.at_level(logging.ERROR): - response = client.get('/test/') + response = client.get("/test/") exc_type, value, traceback = caplog.records[0].exc_info assert exc_type is ProgrammingBlunder assert response.status_code == 500 @@ -195,43 +199,43 @@ def test_errorhandler_for_custom_exception_with_headers(self, app, client): class CustomException(RuntimeError): pass - @api.route('/test/', endpoint='test') + @api.route("/test/", endpoint="test") class TestResource(restx.Resource): def get(self): - raise CustomException('error') + raise CustomException("error") @api.errorhandler(CustomException) def handle_custom_exception(error): - return {'message': 'some maintenance'}, 503, {'Retry-After': 120} + return {"message": "some maintenance"}, 503, {"Retry-After": 120} - response = client.get('/test/') + response = client.get("/test/") assert response.status_code == 503 - assert response.content_type == 'application/json' + assert response.content_type == "application/json" - data = json.loads(response.data.decode('utf8')) - assert data == {'message': 'some maintenance'} - assert response.headers['Retry-After'] == '120' + data = json.loads(response.data.decode("utf8")) + assert data == {"message": "some maintenance"} + assert response.headers["Retry-After"] == "120" def test_errorhandler_for_httpexception(self, app, client): api = restx.Api(app) - @api.route('/test/', endpoint='test') + @api.route("/test/", endpoint="test") class TestResource(restx.Resource): def get(self): raise BadRequest() @api.errorhandler(BadRequest) def handle_badrequest_exception(error): - return {'message': str(error), 'test': 'value'}, 400 + return {"message": str(error), "test": "value"}, 400 - response = client.get('/test/') + response = client.get("/test/") assert response.status_code == 400 - assert response.content_type == 'application/json' + assert response.content_type == "application/json" - data = json.loads(response.data.decode('utf8')) + data = json.loads(response.data.decode("utf8")) assert data == { - 'message': str(BadRequest()), - 'test': 'value', + "message": str(BadRequest()), + "test": "value", } def test_errorhandler_with_namespace(self, app, client): @@ -242,25 +246,25 @@ def test_errorhandler_with_namespace(self, app, client): class CustomException(RuntimeError): pass - @ns.route('/test/', endpoint='test') + @ns.route("/test/", endpoint="test") class TestResource(restx.Resource): def get(self): - raise CustomException('error') + raise CustomException("error") @ns.errorhandler(CustomException) def handle_custom_exception(error): - return {'message': str(error), 'test': 'value'}, 400 + return {"message": str(error), "test": "value"}, 400 api.add_namespace(ns) - response = client.get('/test/') + response = client.get("/test/") assert response.status_code == 400 - assert response.content_type == 'application/json' + assert response.content_type == "application/json" - data = json.loads(response.data.decode('utf8')) + data = json.loads(response.data.decode("utf8")) assert data == { - 'message': 'error', - 'test': 'value', + "message": "error", + "test": "value", } def test_errorhandler_with_namespace_from_api(self, app, client): @@ -271,101 +275,101 @@ def test_errorhandler_with_namespace_from_api(self, app, client): class CustomException(RuntimeError): pass - @ns.route('/test/', endpoint='test') + @ns.route("/test/", endpoint="test") class TestResource(restx.Resource): def get(self): - raise CustomException('error') + raise CustomException("error") @ns.errorhandler(CustomException) def handle_custom_exception(error): - return {'message': str(error), 'test': 'value'}, 400 + return {"message": str(error), "test": "value"}, 400 - response = client.get('/test/') + response = client.get("/test/") assert response.status_code == 400 - assert response.content_type == 'application/json' + assert response.content_type == "application/json" - data = json.loads(response.data.decode('utf8')) + data = json.loads(response.data.decode("utf8")) assert data == { - 'message': 'error', - 'test': 'value', + "message": "error", + "test": "value", } def test_default_errorhandler(self, app, client): api = restx.Api(app) - @api.route('/test/') + @api.route("/test/") class TestResource(restx.Resource): def get(self): - raise Exception('error') + raise Exception("error") - response = client.get('/test/') + response = client.get("/test/") assert response.status_code == 500 - assert response.content_type == 'application/json' + assert response.content_type == "application/json" - data = json.loads(response.data.decode('utf8')) - assert 'message' in data + data = json.loads(response.data.decode("utf8")) + assert "message" in data def test_default_errorhandler_with_propagate_true(self, app, client): - blueprint = Blueprint('api', __name__, url_prefix='/api') + blueprint = Blueprint("api", __name__, url_prefix="/api") api = restx.Api(blueprint) - @api.route('/test/') + @api.route("/test/") class TestResource(restx.Resource): def get(self): - raise Exception('error') + raise Exception("error") app.register_blueprint(blueprint) - app.config['PROPAGATE_EXCEPTIONS'] = True + app.config["PROPAGATE_EXCEPTIONS"] = True # From the Flask docs: # PROPAGATE_EXCEPTIONS # Exceptions are re-raised rather than being handled by the app’s error handlers. # If not set, this is implicitly true if TESTING or DEBUG is enabled. with pytest.raises(Exception): - client.get('/api/test/') + client.get("/api/test/") def test_custom_default_errorhandler(self, app, client): api = restx.Api(app) - @api.route('/test/', endpoint='test') + @api.route("/test/", endpoint="test") class TestResource(restx.Resource): def get(self): - raise Exception('error') + raise Exception("error") @api.errorhandler def default_error_handler(error): - return {'message': str(error), 'test': 'value'}, 500 + return {"message": str(error), "test": "value"}, 500 - response = client.get('/test/') + response = client.get("/test/") assert response.status_code == 500 - assert response.content_type == 'application/json' + assert response.content_type == "application/json" - data = json.loads(response.data.decode('utf8')) + data = json.loads(response.data.decode("utf8")) assert data == { - 'message': 'error', - 'test': 'value', + "message": "error", + "test": "value", } def test_custom_default_errorhandler_with_headers(self, app, client): api = restx.Api(app) - @api.route('/test/', endpoint='test') + @api.route("/test/", endpoint="test") class TestResource(restx.Resource): def get(self): - raise Exception('error') + raise Exception("error") @api.errorhandler def default_error_handler(error): - return {'message': 'some maintenance'}, 503, {'Retry-After': 120} + return {"message": "some maintenance"}, 503, {"Retry-After": 120} - response = client.get('/test/') + response = client.get("/test/") assert response.status_code == 503 - assert response.content_type == 'application/json' + assert response.content_type == "application/json" - data = json.loads(response.data.decode('utf8')) - assert data == {'message': 'some maintenance'} - assert response.headers['Retry-After'] == '120' + data = json.loads(response.data.decode("utf8")) + assert data == {"message": "some maintenance"} + assert response.headers["Retry-After"] == "120" def test_errorhandler_lazy(self, app, client): api = restx.Api() @@ -373,53 +377,53 @@ def test_errorhandler_lazy(self, app, client): class CustomException(RuntimeError): pass - @api.route('/test/', endpoint='test') + @api.route("/test/", endpoint="test") class TestResource(restx.Resource): def get(self): - raise CustomException('error') + raise CustomException("error") @api.errorhandler(CustomException) def handle_custom_exception(error): - return {'message': str(error), 'test': 'value'}, 400 + return {"message": str(error), "test": "value"}, 400 api.init_app(app) - response = client.get('/test/') + response = client.get("/test/") assert response.status_code == 400 - assert response.content_type == 'application/json' + assert response.content_type == "application/json" - data = json.loads(response.data.decode('utf8')) + data = json.loads(response.data.decode("utf8")) assert data == { - 'message': 'error', - 'test': 'value', + "message": "error", + "test": "value", } def test_handle_api_error(self, app, client): api = restx.Api(app) - @api.route('/api', endpoint='api') + @api.route("/api", endpoint="api") class Test(restx.Resource): def get(self): abort(404) response = client.get("/api") assert response.status_code == 404 - assert response.headers['Content-Type'] == 'application/json' + assert response.headers["Content-Type"] == "application/json" data = json.loads(response.data.decode()) - assert 'message' in data + assert "message" in data def test_handle_non_api_error(self, app, client): restx.Api(app) response = client.get("/foo") assert response.status_code == 404 - assert response.headers['Content-Type'] == 'text/html; charset=utf-8' + assert response.headers["Content-Type"] == "text/html; charset=utf-8" def test_non_api_error_404_catchall(self, app, client): api = restx.Api(app, catch_all_404s=True) response = client.get("/foo") - assert response.headers['Content-Type'] == api.default_mediatype + assert response.headers["Content-Type"] == api.default_mediatype def test_handle_error_signal(self, app): api = restx.Api(app) @@ -446,7 +450,7 @@ def test_handle_error(self, app): response = api.handle_error(BadRequest()) assert response.status_code == 400 assert json.loads(response.data.decode()) == { - 'message': BadRequest.description, + "message": BadRequest.description, } def test_handle_error_does_not_duplicate_content_length(self, app): @@ -454,57 +458,55 @@ def test_handle_error_does_not_duplicate_content_length(self, app): # with self.app.test_request_context("/foo"): response = api.handle_error(BadRequest()) - assert len(response.headers.getlist('Content-Length')) == 1 + assert len(response.headers.getlist("Content-Length")) == 1 def test_handle_smart_errors(self, app): api = restx.Api(app) view = restx.Resource - api.add_resource(view, '/foo', endpoint='bor') - api.add_resource(view, '/fee', endpoint='bir') - api.add_resource(view, '/fii', endpoint='ber') + api.add_resource(view, "/foo", endpoint="bor") + api.add_resource(view, "/fee", endpoint="bir") + api.add_resource(view, "/fii", endpoint="ber") with app.test_request_context("/faaaaa"): response = api.handle_error(NotFound()) assert response.status_code == 404 assert json.loads(response.data.decode()) == { - 'message': NotFound.description, + "message": NotFound.description, } with app.test_request_context("/fOo"): response = api.handle_error(NotFound()) assert response.status_code == 404 - assert 'did you mean /foo ?' in response.data.decode() + assert "did you mean /foo ?" in response.data.decode() - app.config['ERROR_404_HELP'] = False + app.config["ERROR_404_HELP"] = False response = api.handle_error(NotFound()) assert response.status_code == 404 - assert json.loads(response.data.decode()) == { - 'message': NotFound.description - } + assert json.loads(response.data.decode()) == {"message": NotFound.description} def test_handle_include_error_message(self, app): api = restx.Api(app) view = restx.Resource - api.add_resource(view, '/foo', endpoint='bor') + api.add_resource(view, "/foo", endpoint="bor") with app.test_request_context("/faaaaa"): response = api.handle_error(NotFound()) - assert 'message' in json.loads(response.data.decode()) + assert "message" in json.loads(response.data.decode()) def test_handle_not_include_error_message(self, app): - app.config['ERROR_INCLUDE_MESSAGE'] = False + app.config["ERROR_INCLUDE_MESSAGE"] = False api = restx.Api(app) view = restx.Resource - api.add_resource(view, '/foo', endpoint='bor') + api.add_resource(view, "/foo", endpoint="bor") with app.test_request_context("/faaaaa"): response = api.handle_error(NotFound()) - assert 'message' not in json.loads(response.data.decode()) + assert "message" not in json.loads(response.data.decode()) def test_error_router_falls_back_to_original(self, app, mocker): class ProgrammingBlunder(Exception): @@ -528,22 +530,22 @@ def raise_blunder(arg): def test_fr_405(self, app, client): api = restx.Api(app) - @api.route('/ids/', endpoint='hello') + @api.route("/ids/", endpoint="hello") class HelloWorld(restx.Resource): def get(self): return {} - response = client.post('/ids/3') + response = client.post("/ids/3") assert response.status_code == 405 assert response.content_type == api.default_mediatype # Allow can be of the form 'GET, PUT, POST' - allow = ', '.join(set(response.headers.get_all('Allow'))) - allow = set(method.strip() for method in allow.split(',')) - assert allow == set(['HEAD', 'OPTIONS', 'GET']) + allow = ", ".join(set(response.headers.get_all("Allow"))) + allow = set(method.strip() for method in allow.split(",")) + assert allow == set(["HEAD", "OPTIONS", "GET"]) @pytest.mark.options(debug=True) def test_exception_header_forwarded(self, app, client): - '''Ensure that HTTPException's headers are extended properly''' + """Ensure that HTTPException's headers are extended properly""" api = restx.Api(app) class NotModified(HTTPException): @@ -554,33 +556,31 @@ def __init__(self, etag, *args, **kwargs): self.etag = quote_etag(etag) def get_headers(self, *args, **kwargs): - return [('ETag', self.etag)] + return [("ETag", self.etag)] custom_abort = Aborter(mapping={304: NotModified}) - @api.route('/foo') + @api.route("/foo") class Foo1(restx.Resource): def get(self): - custom_abort(304, etag='myETag') + custom_abort(304, etag="myETag") - foo = client.get('/foo') - assert foo.get_etag() == unquote_etag(quote_etag('myETag')) + foo = client.get("/foo") + assert foo.get_etag() == unquote_etag(quote_etag("myETag")) def test_handle_server_error(self, app): api = restx.Api(app) resp = api.handle_error(Exception()) assert resp.status_code == 500 - assert json.loads(resp.data.decode()) == { - 'message': "Internal Server Error" - } + assert json.loads(resp.data.decode()) == {"message": "Internal Server Error"} def test_handle_error_with_code(self, app): api = restx.Api(app, serve_challenge_on_401=True) exception = Exception() exception.code = "Not an integer" - exception.data = {'foo': 'bar'} + exception.data = {"foo": "bar"} response = api.handle_error(exception) assert response.status_code == 500 @@ -592,73 +592,64 @@ def test_errorhandler_swagger_doc(self, app, client): class CustomException(RuntimeError): pass - error = api.model('Error', { - 'message': restx.fields.String() - }) + error = api.model("Error", {"message": restx.fields.String()}) - @api.route('/test/', endpoint='test') + @api.route("/test/", endpoint="test") class TestResource(restx.Resource): def get(self): - ''' + """ Do something :raises CustomException: In case of something - ''' + """ pass @api.errorhandler(CustomException) - @api.header('Custom-Header', 'Some custom header') + @api.header("Custom-Header", "Some custom header") @api.marshal_with(error, code=503) def handle_custom_exception(error): - '''Some description''' + """Some description""" pass specs = client.get_specs() - assert 'Error' in specs['definitions'] - assert 'CustomException' in specs['responses'] + assert "Error" in specs["definitions"] + assert "CustomException" in specs["responses"] - response = specs['responses']['CustomException'] - assert response['description'] == 'Some description' - assert response['schema'] == { - '$ref': '#/definitions/Error' - } - assert response['headers'] == { - 'Custom-Header': { - 'description': 'Some custom header', - 'type': 'string' - } + response = specs["responses"]["CustomException"] + assert response["description"] == "Some description" + assert response["schema"] == {"$ref": "#/definitions/Error"} + assert response["headers"] == { + "Custom-Header": {"description": "Some custom header", "type": "string"} } - operation = specs['paths']['/test/']['get'] - assert 'responses' in operation - assert operation['responses'] == { - '503': { - '$ref': '#/responses/CustomException' - } + operation = specs["paths"]["/test/"]["get"] + assert "responses" in operation + assert operation["responses"] == { + "503": {"$ref": "#/responses/CustomException"} } def test_errorhandler_with_propagate_true(self, app, client): - '''Exceptions with errorhandler should not be returned to client, even - if PROPAGATE_EXCEPTIONS is set.''' - app.config['PROPAGATE_EXCEPTIONS'] = True + """Exceptions with errorhandler should not be returned to client, even + if PROPAGATE_EXCEPTIONS is set.""" + app.config["PROPAGATE_EXCEPTIONS"] = True api = restx.Api(app) - @api.route('/test/', endpoint='test') + @api.route("/test/", endpoint="test") class TestResource(restx.Resource): def get(self): - raise RuntimeError('error') + raise RuntimeError("error") @api.errorhandler(RuntimeError) def handle_custom_exception(error): - return {'message': str(error), 'test': 'value'}, 400 + return {"message": str(error), "test": "value"}, 400 - response = client.get('/test/') + response = client.get("/test/") assert response.status_code == 400 - assert response.content_type == 'application/json' + assert response.content_type == "application/json" - data = json.loads(response.data.decode('utf8')) + data = json.loads(response.data.decode("utf8")) assert data == { - 'message': 'error', - 'test': 'value', + "message": "error", + "test": "value", } diff --git a/tests/test_fields.py b/tests/test_fields.py index 9b7f9912..c154a972 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -18,29 +18,29 @@ class FieldTestCase(object): @pytest.fixture def api(self, app): - blueprint = Blueprint('api', __name__) + blueprint = Blueprint("api", __name__) api = Api(blueprint) app.register_blueprint(blueprint) yield api def assert_field(self, field, value, expected): - assert field.output('foo', {'foo': value}) == expected + assert field.output("foo", {"foo": value}) == expected def assert_field_raises(self, field, value): with pytest.raises(fields.MarshallingError): - field.output('foo', {'foo': value}) + field.output("foo", {"foo": value}) class BaseFieldTestMixin(object): def test_description(self): - field = self.field_class(description='A description') - assert 'description' in field.__schema__ - assert field.__schema__['description'] == 'A description' + field = self.field_class(description="A description") + assert "description" in field.__schema__ + assert field.__schema__["description"] == "A description" def test_title(self): - field = self.field_class(title='A title') - assert 'title' in field.__schema__ - assert field.__schema__['title'] == 'A title' + field = self.field_class(title="A title") + assert "title" in field.__schema__ + assert field.__schema__["title"] == "A title" def test_required(self): field = self.field_class(required=True) @@ -48,128 +48,129 @@ def test_required(self): def test_readonly(self): field = self.field_class(readonly=True) - assert 'readOnly' in field.__schema__ - assert field.__schema__['readOnly'] + assert "readOnly" in field.__schema__ + assert field.__schema__["readOnly"] class NumberTestMixin(object): def test_min(self): field = self.field_class(min=0) - assert 'minimum' in field.__schema__ - assert field.__schema__['minimum'] == 0 - assert 'exclusiveMinimum' not in field.__schema__ + assert "minimum" in field.__schema__ + assert field.__schema__["minimum"] == 0 + assert "exclusiveMinimum" not in field.__schema__ def test_min_as_callable(self): field = self.field_class(min=lambda: 0) - assert 'minimum' in field.__schema__ - assert field.__schema__['minimum'] == 0 + assert "minimum" in field.__schema__ + assert field.__schema__["minimum"] == 0 def test_min_exlusive(self): field = self.field_class(min=0, exclusiveMin=True) - assert 'minimum' in field.__schema__ - assert field.__schema__['minimum'] == 0 - assert 'exclusiveMinimum' in field.__schema__ - assert field.__schema__['exclusiveMinimum'] is True + assert "minimum" in field.__schema__ + assert field.__schema__["minimum"] == 0 + assert "exclusiveMinimum" in field.__schema__ + assert field.__schema__["exclusiveMinimum"] is True def test_max(self): field = self.field_class(max=42) - assert 'maximum' in field.__schema__ - assert field.__schema__['maximum'] == 42 - assert 'exclusiveMaximum' not in field.__schema__ + assert "maximum" in field.__schema__ + assert field.__schema__["maximum"] == 42 + assert "exclusiveMaximum" not in field.__schema__ def test_max_as_callable(self): field = self.field_class(max=lambda: 42) - assert 'maximum' in field.__schema__ - assert field.__schema__['maximum'] == 42 + assert "maximum" in field.__schema__ + assert field.__schema__["maximum"] == 42 def test_max_exclusive(self): field = self.field_class(max=42, exclusiveMax=True) - assert 'maximum' in field.__schema__ - assert field.__schema__['maximum'] == 42 - assert 'exclusiveMaximum' in field.__schema__ - assert field.__schema__['exclusiveMaximum'] is True + assert "maximum" in field.__schema__ + assert field.__schema__["maximum"] == 42 + assert "exclusiveMaximum" in field.__schema__ + assert field.__schema__["exclusiveMaximum"] is True def test_mulitple_of(self): field = self.field_class(multiple=5) - assert 'multipleOf' in field.__schema__ - assert field.__schema__['multipleOf'] == 5 + assert "multipleOf" in field.__schema__ + assert field.__schema__["multipleOf"] == 5 class StringTestMixin(object): def test_min_length(self): field = self.field_class(min_length=1) - assert 'minLength' in field.__schema__ - assert field.__schema__['minLength'] == 1 + assert "minLength" in field.__schema__ + assert field.__schema__["minLength"] == 1 def test_min_length_as_callable(self): field = self.field_class(min_length=lambda: 1) - assert 'minLength' in field.__schema__ - assert field.__schema__['minLength'] == 1 + assert "minLength" in field.__schema__ + assert field.__schema__["minLength"] == 1 def test_max_length(self): field = self.field_class(max_length=42) - assert 'maxLength' in field.__schema__ - assert field.__schema__['maxLength'] == 42 + assert "maxLength" in field.__schema__ + assert field.__schema__["maxLength"] == 42 def test_max_length_as_callable(self): field = self.field_class(max_length=lambda: 42) - assert 'maxLength' in field.__schema__ - assert field.__schema__['maxLength'] == 42 + assert "maxLength" in field.__schema__ + assert field.__schema__["maxLength"] == 42 def test_pattern(self): - field = self.field_class(pattern='[a-z]') - assert 'pattern' in field.__schema__ - assert field.__schema__['pattern'] == '[a-z]' + field = self.field_class(pattern="[a-z]") + assert "pattern" in field.__schema__ + assert field.__schema__["pattern"] == "[a-z]" class RawFieldTest(BaseFieldTestMixin, FieldTestCase): - ''' Test Raw field AND some common behaviors''' + """ Test Raw field AND some common behaviors""" + field_class = fields.Raw def test_type(self): field = fields.Raw() - assert field.__schema__['type'] == 'object' + assert field.__schema__["type"] == "object" def test_default(self): - field = fields.Raw(default='aaa') - assert field.__schema__['default'] == 'aaa' - self.assert_field(field, None, 'aaa') + field = fields.Raw(default="aaa") + assert field.__schema__["default"] == "aaa" + self.assert_field(field, None, "aaa") def test_default_as_callable(self): - field = fields.Raw(default=lambda: 'aaa') - assert field.__schema__['default'] == 'aaa' - self.assert_field(field, None, 'aaa') + field = fields.Raw(default=lambda: "aaa") + assert field.__schema__["default"] == "aaa" + self.assert_field(field, None, "aaa") def test_with_attribute(self): - field = fields.Raw(attribute='bar') - assert field.output('foo', {'bar': 42}) == 42 + field = fields.Raw(attribute="bar") + assert field.output("foo", {"bar": 42}) == 42 def test_with_lambda_attribute(self, mocker): obj = mocker.Mock() obj.value = 42 field = fields.Raw(attribute=lambda x: x.value) - assert field.output('foo', obj) == 42 + assert field.output("foo", obj) == 42 def test_with_partial_attribute(self, mocker): def f(x, suffix): - return '{0}-{1}'.format(x.value, suffix) + return "{0}-{1}".format(x.value, suffix) obj = mocker.Mock() obj.value = 42 - p = partial(f, suffix='whatever') + p = partial(f, suffix="whatever") field = fields.Raw(attribute=p) - assert field.output('foo', obj) == '42-whatever' + assert field.output("foo", obj) == "42-whatever" def test_attribute_not_found(self): field = fields.Raw() - assert field.output('foo', {'bar': 42}) is None + assert field.output("foo", {"bar": 42}) is None def test_object(self, mocker): obj = mocker.Mock() obj.foo = 42 field = fields.Raw() - assert field.output('foo', obj) == 42 + assert field.output("foo", obj) == 42 def test_nested_object(self, mocker): foo = mocker.Mock() @@ -177,7 +178,7 @@ def test_nested_object(self, mocker): bar.value = 42 foo.bar = bar field = fields.Raw() - assert field.output('bar.value', foo) == 42 + assert field.output("bar.value", foo) == 42 class StringFieldTest(StringTestMixin, BaseFieldTestMixin, FieldTestCase): @@ -187,68 +188,70 @@ def test_defaults(self): field = fields.String() assert not field.required assert not field.discriminator - assert field.__schema__ == {'type': 'string'} + assert field.__schema__ == {"type": "string"} def test_with_enum(self): - enum = ['A', 'B', 'C'] + enum = ["A", "B", "C"] field = fields.String(enum=enum) assert not field.required - assert field.__schema__ == {'type': 'string', 'enum': enum, 'example': enum[0]} + assert field.__schema__ == {"type": "string", "enum": enum, "example": enum[0]} def test_with_empty_enum(self): field = fields.String(enum=[]) assert not field.required - assert field.__schema__ == {'type': 'string'} + assert field.__schema__ == {"type": "string"} def test_with_callable_enum(self): - enum = lambda: ['A', 'B', 'C'] # noqa + enum = lambda: ["A", "B", "C"] # noqa field = fields.String(enum=enum) assert not field.required - assert field.__schema__ == {'type': 'string', 'enum': ['A', 'B', 'C'], 'example': 'A'} + assert field.__schema__ == { + "type": "string", + "enum": ["A", "B", "C"], + "example": "A", + } def test_with_empty_callable_enum(self): enum = lambda: [] # noqa field = fields.String(enum=enum) assert not field.required - assert field.__schema__ == {'type': 'string'} + assert field.__schema__ == {"type": "string"} def test_with_default(self): - field = fields.String(default='aaa') - assert field.__schema__ == {'type': 'string', 'default': 'aaa'} + field = fields.String(default="aaa") + assert field.__schema__ == {"type": "string", "default": "aaa"} def test_string_field_with_discriminator(self): field = fields.String(discriminator=True) assert field.discriminator assert field.required - assert field.__schema__ == {'type': 'string'} + assert field.__schema__ == {"type": "string"} def test_string_field_with_discriminator_override_require(self): field = fields.String(discriminator=True, required=False) assert field.discriminator assert field.required - assert field.__schema__ == {'type': 'string'} + assert field.__schema__ == {"type": "string"} def test_discriminator_output(self, api): - model = api.model('Test', { - 'name': fields.String(discriminator=True), - }) + model = api.model("Test", {"name": fields.String(discriminator=True),}) data = api.marshal({}, model) - assert data == {'name': 'Test'} + assert data == {"name": "Test"} def test_multiple_discriminator_field(self, api): - model = api.model('Test', { - 'name': fields.String(discriminator=True), - 'name2': fields.String(discriminator=True), - }) + model = api.model( + "Test", + { + "name": fields.String(discriminator=True), + "name2": fields.String(discriminator=True), + }, + ) with pytest.raises(ValueError): api.marshal(object(), model) - @pytest.mark.parametrize('value,expected', [ - ('string', 'string'), - (42, '42'), - ]) + @pytest.mark.parametrize("value,expected", [("string", "string"), (42, "42"),]) def test_values(self, value, expected): self.assert_field(fields.String(), value, expected) @@ -259,27 +262,23 @@ class IntegerFieldTest(BaseFieldTestMixin, NumberTestMixin, FieldTestCase): def test_defaults(self): field = fields.Integer() assert not field.required - assert field.__schema__ == {'type': 'integer'} + assert field.__schema__ == {"type": "integer"} def test_with_default(self): field = fields.Integer(default=42) assert not field.required - assert field.__schema__ == {'type': 'integer', 'default': 42} + assert field.__schema__ == {"type": "integer", "default": 42} self.assert_field(field, None, 42) - @pytest.mark.parametrize('value,expected', [ - (0, 0), - (42, 42), - ('42', 42), - (None, None), - (66.6, 66), - ]) + @pytest.mark.parametrize( + "value,expected", [(0, 0), (42, 42), ("42", 42), (None, None), (66.6, 66),] + ) def test_value(self, value, expected): self.assert_field(fields.Integer(), value, expected) def test_decode_error(self): field = fields.Integer() - self.assert_field_raises(field, 'an int') + self.assert_field_raises(field, "an int") class BooleanFieldTest(BaseFieldTestMixin, FieldTestCase): @@ -288,20 +287,23 @@ class BooleanFieldTest(BaseFieldTestMixin, FieldTestCase): def test_defaults(self): field = fields.Boolean() assert not field.required - assert field.__schema__ == {'type': 'boolean'} + assert field.__schema__ == {"type": "boolean"} def test_with_default(self): field = fields.Boolean(default=True) assert not field.required - assert field.__schema__ == {'type': 'boolean', 'default': True} - - @pytest.mark.parametrize('value,expected', [ - (True, True), - (False, False), - ({}, False), - ('false', False), # These consistent with inputs.boolean - ('0', False), - ]) + assert field.__schema__ == {"type": "boolean", "default": True} + + @pytest.mark.parametrize( + "value,expected", + [ + (True, True), + (False, False), + ({}, False), + ("false", False), # These consistent with inputs.boolean + ("0", False), + ], + ) def test_value(self, value, expected): self.assert_field(fields.Boolean(), value, expected) @@ -312,32 +314,32 @@ class FloatFieldTest(BaseFieldTestMixin, NumberTestMixin, FieldTestCase): def test_defaults(self): field = fields.Float() assert not field.required - assert field.__schema__ == {'type': 'number'} + assert field.__schema__ == {"type": "number"} def test_with_default(self): field = fields.Float(default=0.5) assert not field.required - assert field.__schema__ == {'type': 'number', 'default': 0.5} + assert field.__schema__ == {"type": "number", "default": 0.5} - @pytest.mark.parametrize('value,expected', [ - ('-3.13', -3.13), - (str(-3.13), -3.13), - (3, 3.0), - ]) + @pytest.mark.parametrize( + "value,expected", [("-3.13", -3.13), (str(-3.13), -3.13), (3, 3.0),] + ) def test_value(self, value, expected): self.assert_field(fields.Float(), value, expected) def test_raises(self): - self.assert_field_raises(fields.Float(), 'bar') + self.assert_field_raises(fields.Float(), "bar") def test_decode_error(self): field = fields.Float() - self.assert_field_raises(field, 'not a float') + self.assert_field_raises(field, "not a float") -PI_STR = ('3.141592653589793238462643383279502884197169399375105820974944592307816406286208998628034825342117' - '06798214808651328230664709384460955058223172535940812848111745028410270193852110555964462294895493' - '038196442881097566593344612847564823378678316527120190914564856692346034861') +PI_STR = ( + "3.141592653589793238462643383279502884197169399375105820974944592307816406286208998628034825342117" + "06798214808651328230664709384460955058223172535940812848111745028410270193852110555964462294895493" + "038196442881097566593344612847564823378678316527120190914564856692346034861" +) PI = Decimal(PI_STR) @@ -348,34 +350,34 @@ class FixedFieldTest(BaseFieldTestMixin, NumberTestMixin, FieldTestCase): def test_defaults(self): field = fields.Fixed() assert not field.required - assert field.__schema__ == {'type': 'number'} + assert field.__schema__ == {"type": "number"} def test_with_default(self): field = fields.Fixed(default=0.5) assert not field.required - assert field.__schema__ == {'type': 'number', 'default': 0.5} + assert field.__schema__ == {"type": "number", "default": 0.5} def test_fixed(self): field5 = fields.Fixed(5) field4 = fields.Fixed(4) - self.assert_field(field5, PI, '3.14159') - self.assert_field(field4, PI, '3.1416') - self.assert_field(field4, 3, '3.0000') - self.assert_field(field4, '03', '3.0000') - self.assert_field(field4, '03.0', '3.0000') + self.assert_field(field5, PI, "3.14159") + self.assert_field(field4, PI, "3.1416") + self.assert_field(field4, 3, "3.0000") + self.assert_field(field4, "03", "3.0000") + self.assert_field(field4, "03.0", "3.0000") def test_zero(self): - self.assert_field(fields.Fixed(), '0', '0.00000') + self.assert_field(fields.Fixed(), "0", "0.00000") def test_infinite(self): field = fields.Fixed() - self.assert_field_raises(field, '+inf') - self.assert_field_raises(field, '-inf') + self.assert_field_raises(field, "+inf") + self.assert_field_raises(field, "-inf") def test_nan(self): field = fields.Fixed() - self.assert_field_raises(field, 'NaN') + self.assert_field_raises(field, "NaN") class ArbitraryFieldTest(BaseFieldTestMixin, NumberTestMixin, FieldTestCase): @@ -384,16 +386,13 @@ class ArbitraryFieldTest(BaseFieldTestMixin, NumberTestMixin, FieldTestCase): def test_defaults(self): field = fields.Arbitrary() assert not field.required - assert field.__schema__ == {'type': 'number'} + assert field.__schema__ == {"type": "number"} def test_with_default(self): field = fields.Arbitrary(default=0.5) - assert field.__schema__ == {'type': 'number', 'default': 0.5} + assert field.__schema__ == {"type": "number", "default": 0.5} - @pytest.mark.parametrize('value,expected', [ - (PI_STR, PI_STR), - (PI, PI_STR), - ]) + @pytest.mark.parametrize("value,expected", [(PI_STR, PI_STR), (PI, PI_STR),]) def test_value(self, value, expected): self.assert_field(fields.Arbitrary(), value, expected) @@ -404,111 +403,136 @@ class DatetimeFieldTest(BaseFieldTestMixin, FieldTestCase): def test_defaults(self): field = fields.DateTime() assert not field.required - assert field.__schema__ == {'type': 'string', 'format': 'date-time'} + assert field.__schema__ == {"type": "string", "format": "date-time"} self.assert_field(field, None, None) def test_with_default(self): - field = fields.DateTime(default='2014-08-25') - assert field.__schema__ == {'type': 'string', 'format': 'date-time', 'default': '2014-08-25T00:00:00'} - self.assert_field(field, None, '2014-08-25T00:00:00') + field = fields.DateTime(default="2014-08-25") + assert field.__schema__ == { + "type": "string", + "format": "date-time", + "default": "2014-08-25T00:00:00", + } + self.assert_field(field, None, "2014-08-25T00:00:00") def test_with_default_as_datetime(self): field = fields.DateTime(default=datetime(2014, 8, 25)) - assert field.__schema__ == {'type': 'string', 'format': 'date-time', 'default': '2014-08-25T00:00:00'} - self.assert_field(field, None, '2014-08-25T00:00:00') + assert field.__schema__ == { + "type": "string", + "format": "date-time", + "default": "2014-08-25T00:00:00", + } + self.assert_field(field, None, "2014-08-25T00:00:00") def test_with_default_as_date(self): field = fields.DateTime(default=date(2014, 8, 25)) - assert field.__schema__ == {'type': 'string', 'format': 'date-time', 'default': '2014-08-25T00:00:00'} - self.assert_field(field, None, '2014-08-25T00:00:00') + assert field.__schema__ == { + "type": "string", + "format": "date-time", + "default": "2014-08-25T00:00:00", + } + self.assert_field(field, None, "2014-08-25T00:00:00") def test_min(self): - field = fields.DateTime(min='1984-06-07T00:00:00') - assert 'minimum' in field.__schema__ - assert field.__schema__['minimum'] == '1984-06-07T00:00:00' - assert 'exclusiveMinimum' not in field.__schema__ + field = fields.DateTime(min="1984-06-07T00:00:00") + assert "minimum" in field.__schema__ + assert field.__schema__["minimum"] == "1984-06-07T00:00:00" + assert "exclusiveMinimum" not in field.__schema__ def test_min_as_date(self): field = fields.DateTime(min=date(1984, 6, 7)) - assert 'minimum' in field.__schema__ - assert field.__schema__['minimum'] == '1984-06-07T00:00:00' - assert 'exclusiveMinimum' not in field.__schema__ + assert "minimum" in field.__schema__ + assert field.__schema__["minimum"] == "1984-06-07T00:00:00" + assert "exclusiveMinimum" not in field.__schema__ def test_min_as_datetime(self): field = fields.DateTime(min=datetime(1984, 6, 7, 1, 2, 0)) - assert 'minimum' in field.__schema__ - assert field.__schema__['minimum'] == '1984-06-07T01:02:00' - assert 'exclusiveMinimum' not in field.__schema__ + assert "minimum" in field.__schema__ + assert field.__schema__["minimum"] == "1984-06-07T01:02:00" + assert "exclusiveMinimum" not in field.__schema__ def test_min_exlusive(self): - field = fields.DateTime(min='1984-06-07T00:00:00', exclusiveMin=True) - assert 'minimum' in field.__schema__ - assert field.__schema__['minimum'] == '1984-06-07T00:00:00' - assert 'exclusiveMinimum' in field.__schema__ - assert field.__schema__['exclusiveMinimum'] is True + field = fields.DateTime(min="1984-06-07T00:00:00", exclusiveMin=True) + assert "minimum" in field.__schema__ + assert field.__schema__["minimum"] == "1984-06-07T00:00:00" + assert "exclusiveMinimum" in field.__schema__ + assert field.__schema__["exclusiveMinimum"] is True def test_max(self): - field = fields.DateTime(max='1984-06-07T00:00:00') - assert 'maximum' in field.__schema__ - assert field.__schema__['maximum'] == '1984-06-07T00:00:00' - assert 'exclusiveMaximum' not in field.__schema__ + field = fields.DateTime(max="1984-06-07T00:00:00") + assert "maximum" in field.__schema__ + assert field.__schema__["maximum"] == "1984-06-07T00:00:00" + assert "exclusiveMaximum" not in field.__schema__ def test_max_as_date(self): field = fields.DateTime(max=date(1984, 6, 7)) - assert 'maximum' in field.__schema__ - assert field.__schema__['maximum'] == '1984-06-07T00:00:00' - assert 'exclusiveMaximum' not in field.__schema__ + assert "maximum" in field.__schema__ + assert field.__schema__["maximum"] == "1984-06-07T00:00:00" + assert "exclusiveMaximum" not in field.__schema__ def test_max_as_datetime(self): field = fields.DateTime(max=datetime(1984, 6, 7, 1, 2, 0)) - assert 'maximum' in field.__schema__ - assert field.__schema__['maximum'] == '1984-06-07T01:02:00' - assert 'exclusiveMaximum' not in field.__schema__ + assert "maximum" in field.__schema__ + assert field.__schema__["maximum"] == "1984-06-07T01:02:00" + assert "exclusiveMaximum" not in field.__schema__ def test_max_exclusive(self): - field = fields.DateTime(max='1984-06-07T00:00:00', exclusiveMax=True) - assert 'maximum' in field.__schema__ - assert field.__schema__['maximum'] == '1984-06-07T00:00:00' - assert 'exclusiveMaximum' in field.__schema__ - assert field.__schema__['exclusiveMaximum'] is True - - @pytest.mark.parametrize('value,expected', [ - (date(2011, 1, 1), 'Sat, 01 Jan 2011 00:00:00 -0000'), - (datetime(2011, 1, 1), 'Sat, 01 Jan 2011 00:00:00 -0000'), - (datetime(2011, 1, 1, 23, 59, 59), - 'Sat, 01 Jan 2011 23:59:59 -0000'), - (datetime(2011, 1, 1, 23, 59, 59, tzinfo=pytz.utc), - 'Sat, 01 Jan 2011 23:59:59 -0000'), - (datetime(2011, 1, 1, 23, 59, 59, tzinfo=pytz.timezone('CET')), - 'Sat, 01 Jan 2011 22:59:59 -0000') - ]) + field = fields.DateTime(max="1984-06-07T00:00:00", exclusiveMax=True) + assert "maximum" in field.__schema__ + assert field.__schema__["maximum"] == "1984-06-07T00:00:00" + assert "exclusiveMaximum" in field.__schema__ + assert field.__schema__["exclusiveMaximum"] is True + + @pytest.mark.parametrize( + "value,expected", + [ + (date(2011, 1, 1), "Sat, 01 Jan 2011 00:00:00 -0000"), + (datetime(2011, 1, 1), "Sat, 01 Jan 2011 00:00:00 -0000"), + (datetime(2011, 1, 1, 23, 59, 59), "Sat, 01 Jan 2011 23:59:59 -0000"), + ( + datetime(2011, 1, 1, 23, 59, 59, tzinfo=pytz.utc), + "Sat, 01 Jan 2011 23:59:59 -0000", + ), + ( + datetime(2011, 1, 1, 23, 59, 59, tzinfo=pytz.timezone("CET")), + "Sat, 01 Jan 2011 22:59:59 -0000", + ), + ], + ) def test_rfc822_value(self, value, expected): - self.assert_field(fields.DateTime(dt_format='rfc822'), value, expected) - - @pytest.mark.parametrize('value,expected', [ - (date(2011, 1, 1), '2011-01-01T00:00:00'), - (datetime(2011, 1, 1), '2011-01-01T00:00:00'), - (datetime(2011, 1, 1, 23, 59, 59), - '2011-01-01T23:59:59'), - (datetime(2011, 1, 1, 23, 59, 59, 1000), - '2011-01-01T23:59:59.001000'), - (datetime(2011, 1, 1, 23, 59, 59, tzinfo=pytz.utc), - '2011-01-01T23:59:59+00:00'), - (datetime(2011, 1, 1, 23, 59, 59, 1000, tzinfo=pytz.utc), - '2011-01-01T23:59:59.001000+00:00'), - (datetime(2011, 1, 1, 23, 59, 59, tzinfo=pytz.timezone('CET')), - '2011-01-01T23:59:59+01:00') - ]) + self.assert_field(fields.DateTime(dt_format="rfc822"), value, expected) + + @pytest.mark.parametrize( + "value,expected", + [ + (date(2011, 1, 1), "2011-01-01T00:00:00"), + (datetime(2011, 1, 1), "2011-01-01T00:00:00"), + (datetime(2011, 1, 1, 23, 59, 59), "2011-01-01T23:59:59"), + (datetime(2011, 1, 1, 23, 59, 59, 1000), "2011-01-01T23:59:59.001000"), + ( + datetime(2011, 1, 1, 23, 59, 59, tzinfo=pytz.utc), + "2011-01-01T23:59:59+00:00", + ), + ( + datetime(2011, 1, 1, 23, 59, 59, 1000, tzinfo=pytz.utc), + "2011-01-01T23:59:59.001000+00:00", + ), + ( + datetime(2011, 1, 1, 23, 59, 59, tzinfo=pytz.timezone("CET")), + "2011-01-01T23:59:59+01:00", + ), + ], + ) def test_iso8601_value(self, value, expected): - self.assert_field(fields.DateTime(dt_format='iso8601'), value, expected) + self.assert_field(fields.DateTime(dt_format="iso8601"), value, expected) def test_unsupported_format(self): - field = fields.DateTime(dt_format='raw') + field = fields.DateTime(dt_format="raw") self.assert_field_raises(field, datetime.now()) def test_unsupported_value_format(self): - field = fields.DateTime(dt_format='raw') - self.assert_field_raises(field, 'xxx') + field = fields.DateTime(dt_format="raw") + self.assert_field_raises(field, "xxx") class DateFieldTest(BaseFieldTestMixin, FieldTestCase): @@ -517,306 +541,330 @@ class DateFieldTest(BaseFieldTestMixin, FieldTestCase): def test_defaults(self): field = fields.Date() assert not field.required - assert field.__schema__ == {'type': 'string', 'format': 'date'} + assert field.__schema__ == {"type": "string", "format": "date"} def test_with_default(self): - field = fields.Date(default='2014-08-25') - assert field.__schema__ == {'type': 'string', 'format': 'date', 'default': '2014-08-25'} - self.assert_field(field, None, '2014-08-25') + field = fields.Date(default="2014-08-25") + assert field.__schema__ == { + "type": "string", + "format": "date", + "default": "2014-08-25", + } + self.assert_field(field, None, "2014-08-25") def test_with_default_as_date(self): field = fields.Date(default=date(2014, 8, 25)) - assert field.__schema__ == {'type': 'string', 'format': 'date', 'default': '2014-08-25'} + assert field.__schema__ == { + "type": "string", + "format": "date", + "default": "2014-08-25", + } def test_with_default_as_datetime(self): field = fields.Date(default=datetime(2014, 8, 25)) - assert field.__schema__ == {'type': 'string', 'format': 'date', 'default': '2014-08-25'} + assert field.__schema__ == { + "type": "string", + "format": "date", + "default": "2014-08-25", + } def test_min(self): - field = fields.Date(min='1984-06-07') - assert 'minimum' in field.__schema__ - assert field.__schema__['minimum'] == '1984-06-07' - assert 'exclusiveMinimum' not in field.__schema__ + field = fields.Date(min="1984-06-07") + assert "minimum" in field.__schema__ + assert field.__schema__["minimum"] == "1984-06-07" + assert "exclusiveMinimum" not in field.__schema__ def test_min_as_date(self): field = fields.Date(min=date(1984, 6, 7)) - assert 'minimum' in field.__schema__ - assert field.__schema__['minimum'] == '1984-06-07' - assert 'exclusiveMinimum' not in field.__schema__ + assert "minimum" in field.__schema__ + assert field.__schema__["minimum"] == "1984-06-07" + assert "exclusiveMinimum" not in field.__schema__ def test_min_as_datetime(self): field = fields.Date(min=datetime(1984, 6, 7, 1, 2, 0)) - assert 'minimum' in field.__schema__ - assert field.__schema__['minimum'] == '1984-06-07' - assert 'exclusiveMinimum' not in field.__schema__ + assert "minimum" in field.__schema__ + assert field.__schema__["minimum"] == "1984-06-07" + assert "exclusiveMinimum" not in field.__schema__ def test_min_exlusive(self): - field = fields.Date(min='1984-06-07', exclusiveMin=True) - assert 'minimum' in field.__schema__ - assert field.__schema__['minimum'] == '1984-06-07' - assert 'exclusiveMinimum' in field.__schema__ - assert field.__schema__['exclusiveMinimum'] is True + field = fields.Date(min="1984-06-07", exclusiveMin=True) + assert "minimum" in field.__schema__ + assert field.__schema__["minimum"] == "1984-06-07" + assert "exclusiveMinimum" in field.__schema__ + assert field.__schema__["exclusiveMinimum"] is True def test_max(self): - field = fields.Date(max='1984-06-07') - assert 'maximum' in field.__schema__ - assert field.__schema__['maximum'] == '1984-06-07' - assert 'exclusiveMaximum' not in field.__schema__ + field = fields.Date(max="1984-06-07") + assert "maximum" in field.__schema__ + assert field.__schema__["maximum"] == "1984-06-07" + assert "exclusiveMaximum" not in field.__schema__ def test_max_as_date(self): field = fields.Date(max=date(1984, 6, 7)) - assert 'maximum' in field.__schema__ - assert field.__schema__['maximum'] == '1984-06-07' - assert 'exclusiveMaximum' not in field.__schema__ + assert "maximum" in field.__schema__ + assert field.__schema__["maximum"] == "1984-06-07" + assert "exclusiveMaximum" not in field.__schema__ def test_max_as_datetime(self): field = fields.Date(max=datetime(1984, 6, 7, 1, 2, 0)) - assert 'maximum' in field.__schema__ - assert field.__schema__['maximum'] == '1984-06-07' - assert 'exclusiveMaximum' not in field.__schema__ + assert "maximum" in field.__schema__ + assert field.__schema__["maximum"] == "1984-06-07" + assert "exclusiveMaximum" not in field.__schema__ def test_max_exclusive(self): - field = fields.Date(max='1984-06-07', exclusiveMax=True) - assert 'maximum' in field.__schema__ - assert field.__schema__['maximum'] == '1984-06-07' - assert 'exclusiveMaximum' in field.__schema__ - assert field.__schema__['exclusiveMaximum'] is True - - @pytest.mark.parametrize('value,expected', [ - (date(2011, 1, 1), '2011-01-01'), - (datetime(2011, 1, 1), '2011-01-01'), - (datetime(2011, 1, 1, 23, 59, 59), '2011-01-01'), - (datetime(2011, 1, 1, 23, 59, 59, 1000), '2011-01-01'), - (datetime(2011, 1, 1, 23, 59, 59, tzinfo=pytz.utc), '2011-01-01'), - (datetime(2011, 1, 1, 23, 59, 59, 1000, tzinfo=pytz.utc), '2011-01-01'), - (datetime(2011, 1, 1, 23, 59, 59, tzinfo=pytz.timezone('CET')), '2011-01-01') - ]) + field = fields.Date(max="1984-06-07", exclusiveMax=True) + assert "maximum" in field.__schema__ + assert field.__schema__["maximum"] == "1984-06-07" + assert "exclusiveMaximum" in field.__schema__ + assert field.__schema__["exclusiveMaximum"] is True + + @pytest.mark.parametrize( + "value,expected", + [ + (date(2011, 1, 1), "2011-01-01"), + (datetime(2011, 1, 1), "2011-01-01"), + (datetime(2011, 1, 1, 23, 59, 59), "2011-01-01"), + (datetime(2011, 1, 1, 23, 59, 59, 1000), "2011-01-01"), + (datetime(2011, 1, 1, 23, 59, 59, tzinfo=pytz.utc), "2011-01-01"), + (datetime(2011, 1, 1, 23, 59, 59, 1000, tzinfo=pytz.utc), "2011-01-01"), + ( + datetime(2011, 1, 1, 23, 59, 59, tzinfo=pytz.timezone("CET")), + "2011-01-01", + ), + ], + ) def test_value(self, value, expected): self.assert_field(fields.Date(), value, expected) def test_unsupported_value_format(self): - self.assert_field_raises(fields.Date(), 'xxx') + self.assert_field_raises(fields.Date(), "xxx") class FormatedStringFieldTest(StringTestMixin, BaseFieldTestMixin, FieldTestCase): - field_class = partial(fields.FormattedString, 'Hello {name}') + field_class = partial(fields.FormattedString, "Hello {name}") def test_defaults(self): - field = fields.FormattedString('Hello {name}') + field = fields.FormattedString("Hello {name}") assert not field.required - assert field.__schema__ == {'type': 'string'} + assert field.__schema__ == {"type": "string"} def test_dict(self): data = { - 'sid': 3, - 'account_sid': 4, + "sid": 3, + "account_sid": 4, } - field = fields.FormattedString('/foo/{account_sid}/{sid}/') - assert field.output('foo', data) == '/foo/4/3/' + field = fields.FormattedString("/foo/{account_sid}/{sid}/") + assert field.output("foo", data) == "/foo/4/3/" def test_object(self, mocker): obj = mocker.Mock() obj.sid = 3 obj.account_sid = 4 - field = fields.FormattedString('/foo/{account_sid}/{sid}/') - assert field.output('foo', obj) == '/foo/4/3/' + field = fields.FormattedString("/foo/{account_sid}/{sid}/") + assert field.output("foo", obj) == "/foo/4/3/" def test_none(self): - field = fields.FormattedString('{foo}') + field = fields.FormattedString("{foo}") # self.assert_field_raises(field, None) with pytest.raises(fields.MarshallingError): - field.output('foo', None) + field.output("foo", None) def test_invalid_object(self): - field = fields.FormattedString('/foo/{0[account_sid]}/{0[sid]}/') + field = fields.FormattedString("/foo/{0[account_sid]}/{0[sid]}/") self.assert_field_raises(field, {}) def test_tuple(self): - field = fields.FormattedString('/foo/{0[account_sid]}/{0[sid]}/') + field = fields.FormattedString("/foo/{0[account_sid]}/{0[sid]}/") self.assert_field_raises(field, (3, 4)) class UrlFieldTest(StringTestMixin, BaseFieldTestMixin, FieldTestCase): - field_class = partial(fields.Url, 'endpoint') + field_class = partial(fields.Url, "endpoint") def test_defaults(self): - field = fields.Url('endpoint') + field = fields.Url("endpoint") assert not field.required - assert field.__schema__ == {'type': 'string'} + assert field.__schema__ == {"type": "string"} def test_invalid_object(self, app): - app.add_url_rule('/', 'foobar', view_func=lambda x: x) - field = fields.Url('foobar') + app.add_url_rule("/", "foobar", view_func=lambda x: x) + field = fields.Url("foobar") - with app.test_request_context('/'): + with app.test_request_context("/"): with pytest.raises(fields.MarshallingError): - field.output('foo', None) + field.output("foo", None) def test_simple(self, app, mocker): - app.add_url_rule('/', 'foobar', view_func=lambda x: x) - field = fields.Url('foobar') + app.add_url_rule("/", "foobar", view_func=lambda x: x) + field = fields.Url("foobar") obj = mocker.Mock(foo=42) - with app.test_request_context('/'): - assert '/42' == field.output('foo', obj) + with app.test_request_context("/"): + assert "/42" == field.output("foo", obj) def test_absolute(self, app, mocker): - app.add_url_rule('/', 'foobar', view_func=lambda x: x) - field = fields.Url('foobar', absolute=True) + app.add_url_rule("/", "foobar", view_func=lambda x: x) + field = fields.Url("foobar", absolute=True) obj = mocker.Mock(foo=42) - with app.test_request_context('/'): - assert 'http://localhost/42' == field.output('foo', obj) + with app.test_request_context("/"): + assert "http://localhost/42" == field.output("foo", obj) def test_absolute_scheme(self, app, mocker): - '''Url.scheme should override current_request.scheme''' - app.add_url_rule('/', 'foobar', view_func=lambda x: x) - field = fields.Url('foobar', absolute=True, scheme='https') + """Url.scheme should override current_request.scheme""" + app.add_url_rule("/", "foobar", view_func=lambda x: x) + field = fields.Url("foobar", absolute=True, scheme="https") obj = mocker.Mock(foo=42) - with app.test_request_context('/', base_url='http://localhost'): - assert 'https://localhost/42' == field.output('foo', obj) + with app.test_request_context("/", base_url="http://localhost"): + assert "https://localhost/42" == field.output("foo", obj) def test_without_endpoint_invalid_object(self, app): - app.add_url_rule('/', 'foobar', view_func=lambda x: x) + app.add_url_rule("/", "foobar", view_func=lambda x: x) field = fields.Url() - with app.test_request_context('/foo'): + with app.test_request_context("/foo"): with pytest.raises(fields.MarshallingError): - field.output('foo', None) + field.output("foo", None) def test_without_endpoint(self, app, mocker): - app.add_url_rule('/', 'foobar', view_func=lambda x: x) + app.add_url_rule("/", "foobar", view_func=lambda x: x) field = fields.Url() obj = mocker.Mock(foo=42) - with app.test_request_context('/foo'): - assert '/42' == field.output('foo', obj) + with app.test_request_context("/foo"): + assert "/42" == field.output("foo", obj) def test_without_endpoint_absolute(self, app, mocker): - app.add_url_rule('/', 'foobar', view_func=lambda x: x) + app.add_url_rule("/", "foobar", view_func=lambda x: x) field = fields.Url(absolute=True) obj = mocker.Mock(foo=42) - with app.test_request_context('/foo'): - assert 'http://localhost/42' == field.output('foo', obj) + with app.test_request_context("/foo"): + assert "http://localhost/42" == field.output("foo", obj) def test_without_endpoint_absolute_scheme(self, app, mocker): - app.add_url_rule('/', 'foobar', view_func=lambda x: x) - field = fields.Url(absolute=True, scheme='https') + app.add_url_rule("/", "foobar", view_func=lambda x: x) + field = fields.Url(absolute=True, scheme="https") obj = mocker.Mock(foo=42) - with app.test_request_context('/foo', base_url='http://localhost'): - assert 'https://localhost/42' == field.output('foo', obj) + with app.test_request_context("/foo", base_url="http://localhost"): + assert "https://localhost/42" == field.output("foo", obj) def test_with_blueprint_invalid_object(self, app): - bp = Blueprint('foo', __name__, url_prefix='/foo') - bp.add_url_rule('/', 'foobar', view_func=lambda x: x) + bp = Blueprint("foo", __name__, url_prefix="/foo") + bp.add_url_rule("/", "foobar", view_func=lambda x: x) app.register_blueprint(bp) field = fields.Url() - with app.test_request_context('/foo/foo'): + with app.test_request_context("/foo/foo"): with pytest.raises(fields.MarshallingError): - field.output('foo', None) + field.output("foo", None) def test_with_blueprint(self, app, mocker): - bp = Blueprint('foo', __name__, url_prefix='/foo') - bp.add_url_rule('/', 'foobar', view_func=lambda x: x) + bp = Blueprint("foo", __name__, url_prefix="/foo") + bp.add_url_rule("/", "foobar", view_func=lambda x: x) app.register_blueprint(bp) field = fields.Url() obj = mocker.Mock(foo=42) - with app.test_request_context('/foo/foo'): - assert '/foo/42' == field.output('foo', obj) + with app.test_request_context("/foo/foo"): + assert "/foo/42" == field.output("foo", obj) def test_with_blueprint_absolute(self, app, mocker): - bp = Blueprint('foo', __name__, url_prefix='/foo') - bp.add_url_rule('/', 'foobar', view_func=lambda x: x) + bp = Blueprint("foo", __name__, url_prefix="/foo") + bp.add_url_rule("/", "foobar", view_func=lambda x: x) app.register_blueprint(bp) field = fields.Url(absolute=True) obj = mocker.Mock(foo=42) - with app.test_request_context('/foo/foo'): - assert 'http://localhost/foo/42' == field.output('foo', obj) + with app.test_request_context("/foo/foo"): + assert "http://localhost/foo/42" == field.output("foo", obj) def test_with_blueprint_absolute_scheme(self, app, mocker): - bp = Blueprint('foo', __name__, url_prefix='/foo') - bp.add_url_rule('/', 'foobar', view_func=lambda x: x) + bp = Blueprint("foo", __name__, url_prefix="/foo") + bp.add_url_rule("/", "foobar", view_func=lambda x: x) app.register_blueprint(bp) - field = fields.Url(absolute=True, scheme='https') + field = fields.Url(absolute=True, scheme="https") obj = mocker.Mock(foo=42) - with app.test_request_context('/foo/foo', base_url='http://localhost'): - assert 'https://localhost/foo/42' == field.output('foo', obj) + with app.test_request_context("/foo/foo", base_url="http://localhost"): + assert "https://localhost/foo/42" == field.output("foo", obj) class NestedFieldTest(FieldTestCase): def test_defaults(self, api): - nested_fields = api.model('NestedModel', {'name': fields.String}) + nested_fields = api.model("NestedModel", {"name": fields.String}) field = fields.Nested(nested_fields) assert not field.required - assert field.__schema__ == {'$ref': '#/definitions/NestedModel'} + assert field.__schema__ == {"$ref": "#/definitions/NestedModel"} def test_with_required(self, api): - nested_fields = api.model('NestedModel', {'name': fields.String}) + nested_fields = api.model("NestedModel", {"name": fields.String}) field = fields.Nested(nested_fields, required=True) assert field.required assert not field.allow_null - assert field.__schema__ == {'$ref': '#/definitions/NestedModel'} + assert field.__schema__ == {"$ref": "#/definitions/NestedModel"} def test_with_description(self, api): - nested_fields = api.model('NestedModel', {'name': fields.String}) - field = fields.Nested(nested_fields, description='A description') + nested_fields = api.model("NestedModel", {"name": fields.String}) + field = fields.Nested(nested_fields, description="A description") assert field.__schema__ == { - 'description': 'A description', - 'allOf': [{'$ref': '#/definitions/NestedModel'}] + "description": "A description", + "allOf": [{"$ref": "#/definitions/NestedModel"}], } def test_with_title(self, api): - nested_fields = api.model('NestedModel', {'name': fields.String}) - field = fields.Nested(nested_fields, title='A title') + nested_fields = api.model("NestedModel", {"name": fields.String}) + field = fields.Nested(nested_fields, title="A title") assert field.__schema__ == { - 'title': 'A title', - 'allOf': [{'$ref': '#/definitions/NestedModel'}] + "title": "A title", + "allOf": [{"$ref": "#/definitions/NestedModel"}], } def test_with_allow_null(self, api): - nested_fields = api.model('NestedModel', {'name': fields.String}) + nested_fields = api.model("NestedModel", {"name": fields.String}) field = fields.Nested(nested_fields, allow_null=True) assert not field.required assert field.allow_null - assert field.__schema__ == {'$ref': '#/definitions/NestedModel'} + assert field.__schema__ == {"$ref": "#/definitions/NestedModel"} def test_with_skip_none(self, api): - nested_fields = api.model('NestedModel', {'name': fields.String}) + nested_fields = api.model("NestedModel", {"name": fields.String}) field = fields.Nested(nested_fields, skip_none=True) assert not field.required assert field.skip_none - assert field.__schema__ == {'$ref': '#/definitions/NestedModel'} + assert field.__schema__ == {"$ref": "#/definitions/NestedModel"} def test_with_readonly(self, app): api = Api(app) - nested_fields = api.model('NestedModel', {'name': fields.String}) + nested_fields = api.model("NestedModel", {"name": fields.String}) field = fields.Nested(nested_fields, readonly=True) assert field.__schema__ == { - 'readOnly': True, - 'allOf': [{'$ref': '#/definitions/NestedModel'}] + "readOnly": True, + "allOf": [{"$ref": "#/definitions/NestedModel"}], } def test_as_list(self, api): - nested_fields = api.model('NestedModel', {'name': fields.String}) + nested_fields = api.model("NestedModel", {"name": fields.String}) field = fields.Nested(nested_fields, as_list=True) assert field.as_list - assert field.__schema__ == {'type': 'array', 'items': {'$ref': '#/definitions/NestedModel'}} + assert field.__schema__ == { + "type": "array", + "items": {"$ref": "#/definitions/NestedModel"}, + } def test_as_list_is_reusable(self, api): - nested_fields = api.model('NestedModel', {'name': fields.String}) + nested_fields = api.model("NestedModel", {"name": fields.String}) field = fields.Nested(nested_fields, as_list=True) - assert field.__schema__ == {'type': 'array', 'items': {'$ref': '#/definitions/NestedModel'}} + assert field.__schema__ == { + "type": "array", + "items": {"$ref": "#/definitions/NestedModel"}, + } field = fields.Nested(nested_fields) - assert field.__schema__ == {'$ref': '#/definitions/NestedModel'} + assert field.__schema__ == {"$ref": "#/definitions/NestedModel"} class ListFieldTest(BaseFieldTestMixin, FieldTestCase): @@ -825,46 +873,55 @@ class ListFieldTest(BaseFieldTestMixin, FieldTestCase): def test_defaults(self): field = fields.List(fields.String) assert not field.required - assert field.__schema__ == {'type': 'array', 'items': {'type': 'string'}} + assert field.__schema__ == {"type": "array", "items": {"type": "string"}} def test_with_nested_field(self, api): - nested_fields = api.model('NestedModel', {'name': fields.String}) + nested_fields = api.model("NestedModel", {"name": fields.String}) field = fields.List(fields.Nested(nested_fields)) - assert field.__schema__ == {'type': 'array', 'items': {'$ref': '#/definitions/NestedModel'}} + assert field.__schema__ == { + "type": "array", + "items": {"$ref": "#/definitions/NestedModel"}, + } - data = [{'name': 'John Doe', 'age': 42}, {'name': 'Jane Doe', 'age': 66}] - expected = [OrderedDict([('name', 'John Doe')]), OrderedDict([('name', 'Jane Doe')])] + data = [{"name": "John Doe", "age": 42}, {"name": "Jane Doe", "age": 66}] + expected = [ + OrderedDict([("name", "John Doe")]), + OrderedDict([("name", "Jane Doe")]), + ] self.assert_field(field, data, expected) def test_min_items(self): field = fields.List(fields.String, min_items=5) - assert 'minItems' in field.__schema__ - assert field.__schema__['minItems'] == 5 + assert "minItems" in field.__schema__ + assert field.__schema__["minItems"] == 5 def test_max_items(self): field = fields.List(fields.String, max_items=42) - assert 'maxItems' in field.__schema__ - assert field.__schema__['maxItems'] == 42 + assert "maxItems" in field.__schema__ + assert field.__schema__["maxItems"] == 42 def test_unique(self): field = fields.List(fields.String, unique=True) - assert 'uniqueItems' in field.__schema__ - assert field.__schema__['uniqueItems'] is True - - @pytest.mark.parametrize('value,expected', [ - (['a', 'b', 'c'], ['a', 'b', 'c']), - (['c', 'b', 'a'], ['c', 'b', 'a']), - (('a', 'b', 'c'), ['a', 'b', 'c']), - (['a'], ['a']), - (None, None), - ]) + assert "uniqueItems" in field.__schema__ + assert field.__schema__["uniqueItems"] is True + + @pytest.mark.parametrize( + "value,expected", + [ + (["a", "b", "c"], ["a", "b", "c"]), + (["c", "b", "a"], ["c", "b", "a"]), + (("a", "b", "c"), ["a", "b", "c"]), + (["a"], ["a"]), + (None, None), + ], + ) def test_value(self, value, expected): self.assert_field(fields.List(fields.String()), value, expected) def test_with_set(self): field = fields.List(fields.String) - value = set(['a', 'b', 'c']) - output = field.output('foo', {'foo': value}) + value = set(["a", "b", "c"]) + output = field.output("foo", {"foo": value}) assert set(output) == value def test_with_scoped_attribute_on_dict_or_obj(self): @@ -876,29 +933,31 @@ class Nested(object): def __init__(self, value): self.value = value - nesteds = [Nested(i) for i in ['a', 'b', 'c']] + nesteds = [Nested(i) for i in ["a", "b", "c"]] test_obj = Test(nesteds) - test_dict = {'data': [{'value': 'a'}, {'value': 'b'}, {'value': 'c'}]} + test_dict = {"data": [{"value": "a"}, {"value": "b"}, {"value": "c"}]} - field = fields.List(fields.String(attribute='value'), attribute='data') - assert ['a' == 'b', 'c'], field.output('whatever', test_obj) - assert ['a' == 'b', 'c'], field.output('whatever', test_dict) + field = fields.List(fields.String(attribute="value"), attribute="data") + assert ["a" == "b", "c"], field.output("whatever", test_obj) + assert ["a" == "b", "c"], field.output("whatever", test_dict) def test_with_attribute(self): - data = [{'a': 1, 'b': 1}, {'a': 2, 'b': 1}, {'a': 3, 'b': 1}] - field = fields.List(fields.Integer(attribute='a')) + data = [{"a": 1, "b": 1}, {"a": 2, "b": 1}, {"a": 3, "b": 1}] + field = fields.List(fields.Integer(attribute="a")) self.assert_field(field, data, [1, 2, 3]) def test_list_of_raw(self): field = fields.List(fields.Raw) - data = [{'a': 1, 'b': 1}, {'a': 2, 'b': 1}, {'a': 3, 'b': 1}] - expected = [OrderedDict([('a', 1), ('b', 1)]), - OrderedDict([('a', 2), ('b', 1)]), - OrderedDict([('a', 3), ('b', 1)])] + data = [{"a": 1, "b": 1}, {"a": 2, "b": 1}, {"a": 3, "b": 1}] + expected = [ + OrderedDict([("a", 1), ("b", 1)]), + OrderedDict([("a", 2), ("b", 1)]), + OrderedDict([("a", 3), ("b", 1)]), + ] self.assert_field(field, data, expected) - data = [1, 2, 'a'] + data = [1, 2, "a"] self.assert_field(field, data, data) @@ -907,8 +966,10 @@ class WildcardFieldTest(BaseFieldTestMixin, FieldTestCase): def test_types(self): with pytest.raises(fields.MarshallingError): + class WrongType: pass + x = WrongType() field1 = fields.Wildcard(WrongType) # noqa field2 = fields.Wildcard(x) # noqa @@ -916,7 +977,10 @@ class WrongType: def test_defaults(self): field = fields.Wildcard(fields.String) assert not field.required - assert field.__schema__ == {'type': 'object', 'additionalProperties': {'type': 'string'}} + assert field.__schema__ == { + "type": "object", + "additionalProperties": {"type": "string"}, + } def test_with_scoped_attribute_on_dict_or_obj(self): class Test(object): @@ -927,56 +991,64 @@ class Nested(object): def __init__(self, value): self.value = value - nesteds = [Nested(i) for i in ['a', 'b', 'c']] + nesteds = [Nested(i) for i in ["a", "b", "c"]] test_obj = Test(nesteds) - test_dict = {'data': [{'value': 'a'}, {'value': 'b'}, {'value': 'c'}]} + test_dict = {"data": [{"value": "a"}, {"value": "b"}, {"value": "c"}]} - field = fields.Wildcard(fields.String(attribute='value'), attribute='data') - assert ['a' == 'b', 'c'], field.output('whatever', test_obj) - assert ['a' == 'b', 'c'], field.output('whatever', test_dict) + field = fields.Wildcard(fields.String(attribute="value"), attribute="data") + assert ["a" == "b", "c"], field.output("whatever", test_obj) + assert ["a" == "b", "c"], field.output("whatever", test_dict) def test_list_of_raw(self): field = fields.Wildcard(fields.Raw) - data = [{'a': 1, 'b': 1}, {'a': 2, 'b': 1}, {'a': 3, 'b': 1}] - expected = [OrderedDict([('a', 1), ('b', 1)]), - OrderedDict([('a', 2), ('b', 1)]), - OrderedDict([('a', 3), ('b', 1)])] + data = [{"a": 1, "b": 1}, {"a": 2, "b": 1}, {"a": 3, "b": 1}] + expected = [ + OrderedDict([("a", 1), ("b", 1)]), + OrderedDict([("a", 2), ("b", 1)]), + OrderedDict([("a", 3), ("b", 1)]), + ] self.assert_field(field, data, expected) - data = [1, 2, 'a'] + data = [1, 2, "a"] self.assert_field(field, data, data) def test_wildcard(self, api): wild1 = fields.Wildcard(fields.String) wild2 = fields.Wildcard(fields.Integer) wild3 = fields.Wildcard(fields.String) - wild4 = fields.Wildcard(fields.String, default='x') + wild4 = fields.Wildcard(fields.String, default="x") wild5 = fields.Wildcard(fields.String) wild6 = fields.Wildcard(fields.Integer) wild7 = fields.Wildcard(fields.String) wild8 = fields.Wildcard(fields.String) mod5 = OrderedDict() - mod5['toto'] = fields.Integer - mod5['bob'] = fields.Integer - mod5['*'] = wild5 - - wild_fields1 = api.model('WildcardModel1', {'*': wild1}) - wild_fields2 = api.model('WildcardModel2', {'j*': wild2}) - wild_fields3 = api.model('WildcardModel3', {'*': wild3}) - wild_fields4 = api.model('WildcardModel4', {'*': wild4}) - wild_fields5 = api.model('WildcardModel5', mod5) - wild_fields6 = api.model('WildcardModel6', { - 'nested': {'f1': fields.String(default='12'), 'f2': fields.Integer(default=13)}, - 'a*': wild6 - }) - wild_fields7 = api.model('WildcardModel7', {'*': wild7}) - wild_fields8 = api.model('WildcardModel8', {'*': wild8}) + mod5["toto"] = fields.Integer + mod5["bob"] = fields.Integer + mod5["*"] = wild5 + + wild_fields1 = api.model("WildcardModel1", {"*": wild1}) + wild_fields2 = api.model("WildcardModel2", {"j*": wild2}) + wild_fields3 = api.model("WildcardModel3", {"*": wild3}) + wild_fields4 = api.model("WildcardModel4", {"*": wild4}) + wild_fields5 = api.model("WildcardModel5", mod5) + wild_fields6 = api.model( + "WildcardModel6", + { + "nested": { + "f1": fields.String(default="12"), + "f2": fields.Integer(default=13), + }, + "a*": wild6, + }, + ) + wild_fields7 = api.model("WildcardModel7", {"*": wild7}) + wild_fields8 = api.model("WildcardModel8", {"*": wild8}) class Dummy(object): john = 12 - bob = '42' + bob = "42" alice = None class Dummy2(object): @@ -986,19 +1058,19 @@ class Dummy3(object): a = None b = None - data = {'John': 12, 'bob': 42, 'Jane': '68'} + data = {"John": 12, "bob": 42, "Jane": "68"} data3 = Dummy() data4 = Dummy2() - data5 = {'John': 12, 'bob': 42, 'Jane': '68', 'toto': '72'} - data6 = {'nested': {'f1': 12, 'f2': 13}, 'alice': '14'} + data5 = {"John": 12, "bob": 42, "Jane": "68", "toto": "72"} + data6 = {"nested": {"f1": 12, "f2": 13}, "alice": "14"} data7 = Dummy3() data8 = None - expected1 = {'John': '12', 'bob': '42', 'Jane': '68'} - expected2 = {'John': 12, 'Jane': 68} - expected3 = {'john': '12', 'bob': '42'} - expected4 = {'*': 'x'} - expected5 = {'John': '12', 'bob': 42, 'Jane': '68', 'toto': 72} - expected6 = {'nested': {'f1': '12', 'f2': 13}, 'alice': 14} + expected1 = {"John": "12", "bob": "42", "Jane": "68"} + expected2 = {"John": 12, "Jane": 68} + expected3 = {"john": "12", "bob": "42"} + expected4 = {"*": "x"} + expected5 = {"John": "12", "bob": 42, "Jane": "68", "toto": 72} + expected6 = {"nested": {"f1": "12", "f2": 13}, "alice": 14} expected7 = {} expected8 = {} @@ -1024,11 +1096,11 @@ def test_clone(self, api): wild1 = fields.Wildcard(fields.String) wild2 = wild1.clone() - wild_fields1 = api.model('cloneWildcard1', {'*': wild1}) - wild_fields2 = api.model('cloneWildcard2', {'*': wild2}) + wild_fields1 = api.model("cloneWildcard1", {"*": wild1}) + wild_fields2 = api.model("cloneWildcard2", {"*": wild2}) - data = {'John': 12, 'bob': 42, 'Jane': '68'} - expected1 = {'John': '12', 'bob': '42', 'Jane': '68'} + data = {"John": 12, "bob": 42, "Jane": "68"} + expected1 = {"John": "12", "bob": "42", "Jane": "68"} result1 = api.marshal(data, wild_fields1) result2 = api.marshal(data, wild_fields2) @@ -1044,91 +1116,64 @@ def test_simple_string_field(self): field = fields.ClassName() assert not field.required assert not field.discriminator - assert field.__schema__ == {'type': 'string'} + assert field.__schema__ == {"type": "string"} def test_default_output_classname(self, api): - model = api.model('Test', { - 'name': fields.ClassName(), - }) + model = api.model("Test", {"name": fields.ClassName(),}) class FakeClass(object): pass data = api.marshal(FakeClass(), model) - assert data == {'name': 'FakeClass'} + assert data == {"name": "FakeClass"} def test_output_dash(self, api): - model = api.model('Test', { - 'name': fields.ClassName(dash=True), - }) + model = api.model("Test", {"name": fields.ClassName(dash=True),}) class FakeClass(object): pass data = api.marshal(FakeClass(), model) - assert data == {'name': 'fake_class'} + assert data == {"name": "fake_class"} def test_with_dict(self, api): - model = api.model('Test', { - 'name': fields.ClassName(), - }) + model = api.model("Test", {"name": fields.ClassName(),}) data = api.marshal({}, model) - assert data == {'name': 'object'} + assert data == {"name": "object"} class PolymorphTest(FieldTestCase): def test_polymorph_field(self, api): - parent = api.model('Person', { - 'name': fields.String, - }) + parent = api.model("Person", {"name": fields.String,}) - child1 = api.inherit('Child1', parent, { - 'extra1': fields.String, - }) + child1 = api.inherit("Child1", parent, {"extra1": fields.String,}) - child2 = api.inherit('Child2', parent, { - 'extra2': fields.String, - }) + child2 = api.inherit("Child2", parent, {"extra2": fields.String,}) class Child1(object): - name = 'child1' - extra1 = 'extra1' + name = "child1" + extra1 = "extra1" class Child2(object): - name = 'child2' - extra2 = 'extra2' + name = "child2" + extra2 = "extra2" - mapping = { - Child1: child1, - Child2: child2 - } + mapping = {Child1: child1, Child2: child2} - thing = api.model('Thing', { - 'owner': fields.Polymorph(mapping), - }) + thing = api.model("Thing", {"owner": fields.Polymorph(mapping),}) def data(cls): - return api.marshal({'owner': cls()}, thing) + return api.marshal({"owner": cls()}, thing) - assert data(Child1) == {'owner': { - 'name': 'child1', - 'extra1': 'extra1' - }} + assert data(Child1) == {"owner": {"name": "child1", "extra1": "extra1"}} - assert data(Child2) == {'owner': { - 'name': 'child2', - 'extra2': 'extra2' - }} + assert data(Child2) == {"owner": {"name": "child2", "extra2": "extra2"}} def test_polymorph_field_no_common_ancestor(self, api): - child1 = api.model('Child1', { - 'extra1': fields.String, - }) + child1 = api.model("Child1", {"extra1": fields.String,}) - child2 = api.model('Child2', { - 'extra2': fields.String, - }) + child2 = api.model("Child2", {"extra2": fields.String,}) class Child1(object): pass @@ -1136,230 +1181,179 @@ class Child1(object): class Child2(object): pass - mapping = { - Child1: child1, - Child2: child2 - } + mapping = {Child1: child1, Child2: child2} with pytest.raises(ValueError): fields.Polymorph(mapping) def test_polymorph_field_unknown_class(self, api): - parent = api.model('Person', { - 'name': fields.String, - }) + parent = api.model("Person", {"name": fields.String,}) - child1 = api.inherit('Child1', parent, { - 'extra1': fields.String, - }) + child1 = api.inherit("Child1", parent, {"extra1": fields.String,}) - child2 = api.inherit('Child2', parent, { - 'extra2': fields.String, - }) + child2 = api.inherit("Child2", parent, {"extra2": fields.String,}) class Child1(object): - name = 'child1' - extra1 = 'extra1' + name = "child1" + extra1 = "extra1" class Child2(object): - name = 'child2' - extra2 = 'extra2' + name = "child2" + extra2 = "extra2" - mapping = { - Child1: child1, - Child2: child2 - } + mapping = {Child1: child1, Child2: child2} - thing = api.model('Thing', { - 'owner': fields.Polymorph(mapping), - }) + thing = api.model("Thing", {"owner": fields.Polymorph(mapping),}) with pytest.raises(ValueError): - api.marshal({'owner': object()}, thing) + api.marshal({"owner": object()}, thing) def test_polymorph_field_does_not_have_ambiguous_mappings(self, api): """ Regression test for https://github.com/noirbizarre/flask-restx/pull/691 """ - parent = api.model('Parent', { - 'name': fields.String, - }) + parent = api.model("Parent", {"name": fields.String,}) - child = api.inherit('Child', parent, { - 'extra': fields.String, - }) + child = api.inherit("Child", parent, {"extra": fields.String,}) class Parent(object): - name = 'parent' + name = "parent" class Child(Parent): - extra = 'extra' + extra = "extra" - mapping = { - Parent: parent, - Child: child - } + mapping = {Parent: parent, Child: child} - thing = api.model('Thing', { - 'owner': fields.Polymorph(mapping), - }) + thing = api.model("Thing", {"owner": fields.Polymorph(mapping),}) - api.marshal({'owner': Child()}, thing) + api.marshal({"owner": Child()}, thing) def test_polymorph_field_required_default(self, api): - parent = api.model('Person', { - 'name': fields.String, - }) + parent = api.model("Person", {"name": fields.String,}) - child1 = api.inherit('Child1', parent, { - 'extra1': fields.String, - }) + child1 = api.inherit("Child1", parent, {"extra1": fields.String,}) - child2 = api.inherit('Child2', parent, { - 'extra2': fields.String, - }) + child2 = api.inherit("Child2", parent, {"extra2": fields.String,}) class Child1(object): - name = 'child1' - extra1 = 'extra1' + name = "child1" + extra1 = "extra1" class Child2(object): - name = 'child2' - extra2 = 'extra2' + name = "child2" + extra2 = "extra2" - mapping = { - Child1: child1, - Child2: child2 - } + mapping = {Child1: child1, Child2: child2} - thing = api.model('Thing', { - 'owner': fields.Polymorph(mapping, required=True, default={'name': 'default'}), - }) + thing = api.model( + "Thing", + { + "owner": fields.Polymorph( + mapping, required=True, default={"name": "default"} + ), + }, + ) data = api.marshal({}, thing) - assert data == {'owner': { - 'name': 'default' - }} + assert data == {"owner": {"name": "default"}} def test_polymorph_field_not_required(self, api): - parent = api.model('Person', { - 'name': fields.String, - }) + parent = api.model("Person", {"name": fields.String,}) - child1 = api.inherit('Child1', parent, { - 'extra1': fields.String, - }) + child1 = api.inherit("Child1", parent, {"extra1": fields.String,}) - child2 = api.inherit('Child2', parent, { - 'extra2': fields.String, - }) + child2 = api.inherit("Child2", parent, {"extra2": fields.String,}) class Child1(object): - name = 'child1' - extra1 = 'extra1' + name = "child1" + extra1 = "extra1" class Child2(object): - name = 'child2' - extra2 = 'extra2' + name = "child2" + extra2 = "extra2" - mapping = { - Child1: child1, - Child2: child2 - } + mapping = {Child1: child1, Child2: child2} - thing = api.model('Thing', { - 'owner': fields.Polymorph(mapping), - }) + thing = api.model("Thing", {"owner": fields.Polymorph(mapping),}) data = api.marshal({}, thing) - assert data == {'owner': None} + assert data == {"owner": None} def test_polymorph_with_discriminator(self, api): - parent = api.model('Person', { - 'name': fields.String, - 'model': fields.String(discriminator=True), - }) + parent = api.model( + "Person", + {"name": fields.String, "model": fields.String(discriminator=True),}, + ) - child1 = api.inherit('Child1', parent, { - 'extra1': fields.String, - }) + child1 = api.inherit("Child1", parent, {"extra1": fields.String,}) - child2 = api.inherit('Child2', parent, { - 'extra2': fields.String, - }) + child2 = api.inherit("Child2", parent, {"extra2": fields.String,}) class Child1(object): - name = 'child1' - extra1 = 'extra1' + name = "child1" + extra1 = "extra1" class Child2(object): - name = 'child2' - extra2 = 'extra2' + name = "child2" + extra2 = "extra2" - mapping = { - Child1: child1, - Child2: child2 - } + mapping = {Child1: child1, Child2: child2} - thing = api.model('Thing', { - 'owner': fields.Polymorph(mapping), - }) + thing = api.model("Thing", {"owner": fields.Polymorph(mapping),}) def data(cls): - return api.marshal({'owner': cls()}, thing) + return api.marshal({"owner": cls()}, thing) - assert data(Child1) == {'owner': { - 'name': 'child1', - 'model': 'Child1', - 'extra1': 'extra1' - }} + assert data(Child1) == { + "owner": {"name": "child1", "model": "Child1", "extra1": "extra1"} + } - assert data(Child2) == {'owner': { - 'name': 'child2', - 'model': 'Child2', - 'extra2': 'extra2' - }} + assert data(Child2) == { + "owner": {"name": "child2", "model": "Child2", "extra2": "extra2"} + } class CustomFieldTest(FieldTestCase): def test_custom_field(self): class CustomField(fields.Integer): - __schema_format__ = 'int64' + __schema_format__ = "int64" field = CustomField() - assert field.__schema__ == {'type': 'integer', 'format': 'int64'} + assert field.__schema__ == {"type": "integer", "format": "int64"} class FieldsHelpersTest(object): def test_to_dict(self): - expected = data = {'foo': 42} + expected = data = {"foo": 42} assert fields.to_marshallable_type(data) == expected def test_to_dict_obj(self): class Foo(object): def __init__(self): self.foo = 42 - expected = {'foo': 42} + + expected = {"foo": 42} assert fields.to_marshallable_type(Foo()) == expected def test_to_dict_custom_marshal(self): class Foo(object): def __marshallable__(self): - return {'foo': 42} - expected = {'foo': 42} + return {"foo": 42} + + expected = {"foo": 42} assert fields.to_marshallable_type(Foo()) == expected def test_get_value(self): - assert fields.get_value('foo', {'foo': 42}) == 42 + assert fields.get_value("foo", {"foo": 42}) == 42 def test_get_value_no_value(self): - assert fields.get_value("foo", {'foo': 42}) == 42 + assert fields.get_value("foo", {"foo": 42}) == 42 def test_get_value_obj(self, mocker): - assert fields.get_value('foo', mocker.Mock(foo=42)) == 42 + assert fields.get_value("foo", mocker.Mock(foo=42)) == 42 def test_get_value_indexable_object(self): class Test(object): @@ -1373,5 +1367,5 @@ def __getitem__(self, n): raise IndexError raise TypeError - obj = Test('hi') - assert fields.get_value('value', obj) == 'hi' + obj = Test("hi") + assert fields.get_value("value", obj) == "hi" diff --git a/tests/test_fields_mask.py b/tests/test_fields_mask.py index fdd19cd9..0dd79fa7 100644 --- a/tests/test_fields_mask.py +++ b/tests/test_fields_mask.py @@ -10,109 +10,89 @@ def assert_data(tested, expected): - '''Compare data without caring about order and type (dict vs. OrderedDict)''' + """Compare data without caring about order and type (dict vs. OrderedDict)""" tested = json.loads(json.dumps(tested)) expected = json.loads(json.dumps(expected)) assert tested == expected class MaskMixin(object): - def test_empty_mask(self): - assert Mask('') == {} + assert Mask("") == {} def test_one_field(self): - assert Mask('field_name') == {'field_name': True} + assert Mask("field_name") == {"field_name": True} def test_multiple_field(self): - mask = Mask('field1, field2, field3') - assert_data(mask, { - 'field1': True, - 'field2': True, - 'field3': True, - }) + mask = Mask("field1, field2, field3") + assert_data(mask, {"field1": True, "field2": True, "field3": True,}) def test_nested_fields(self): - parsed = Mask('nested{field1,field2}') - expected = { - 'nested': { - 'field1': True, - 'field2': True, - } - } + parsed = Mask("nested{field1,field2}") + expected = {"nested": {"field1": True, "field2": True,}} assert parsed == expected def test_complex(self): - parsed = Mask('field1, nested{field, sub{subfield}}, field2') + parsed = Mask("field1, nested{field, sub{subfield}}, field2") expected = { - 'field1': True, - 'nested': { - 'field': True, - 'sub': { - 'subfield': True, - } - }, - 'field2': True, + "field1": True, + "nested": {"field": True, "sub": {"subfield": True,}}, + "field2": True, } assert_data(parsed, expected) def test_star(self): - parsed = Mask('nested{field1,field2},*') + parsed = Mask("nested{field1,field2},*") expected = { - 'nested': { - 'field1': True, - 'field2': True, - }, - '*': True, + "nested": {"field1": True, "field2": True,}, + "*": True, } assert_data(parsed, expected) def test_order(self): - parsed = Mask('f_3, nested{f_1, f_2, f_3}, f_2, f_1') - expected = OrderedDict([ - ('f_3', True), - ('nested', OrderedDict([ - ('f_1', True), - ('f_2', True), - ('f_3', True), - ])), - ('f_2', True), - ('f_1', True), - ]) + parsed = Mask("f_3, nested{f_1, f_2, f_3}, f_2, f_1") + expected = OrderedDict( + [ + ("f_3", True), + ("nested", OrderedDict([("f_1", True), ("f_2", True), ("f_3", True),])), + ("f_2", True), + ("f_1", True), + ] + ) assert parsed == expected def test_missing_closing_bracket(self): with pytest.raises(mask.ParseError): - Mask('nested{') + Mask("nested{") def test_consecutive_coma(self): with pytest.raises(mask.ParseError): - Mask('field,,') + Mask("field,,") def test_coma_before_bracket(self): with pytest.raises(mask.ParseError): - Mask('field,{}') + Mask("field,{}") def test_coma_after_bracket(self): with pytest.raises(mask.ParseError): - Mask('nested{,}') + Mask("nested{,}") def test_unexpected_opening_bracket(self): with pytest.raises(mask.ParseError): - Mask('{{field}}') + Mask("{{field}}") def test_unexpected_closing_bracket(self): with pytest.raises(mask.ParseError): - Mask('{field}}') + Mask("{field}}") def test_support_colons(self): - assert Mask('field:name') == {'field:name': True} + assert Mask("field:name") == {"field:name": True} def test_support_dash(self): - assert Mask('field-name') == {'field-name': True} + assert Mask("field-name") == {"field-name": True} def test_support_underscore(self): - assert Mask('field_name') == {'field_name': True} + assert Mask("field_name") == {"field_name": True} class MaskUnwrappedTest(MaskMixin): @@ -122,135 +102,109 @@ def parse(self, value): class MaskWrappedTest(MaskMixin): def parse(self, value): - return Mask('{' + value + '}') + return Mask("{" + value + "}") class DObject(object): - '''A dead simple object built from a dictionnary (no recursion)''' + """A dead simple object built from a dictionnary (no recursion)""" + def __init__(self, data): self.__dict__.update(data) -person_fields = { - 'name': fields.String, - 'age': fields.Integer -} +person_fields = {"name": fields.String, "age": fields.Integer} class ApplyMaskTest(object): def test_empty(self): data = { - 'integer': 42, - 'string': 'a string', - 'boolean': True, + "integer": 42, + "string": "a string", + "boolean": True, } - result = mask.apply(data, '{}') + result = mask.apply(data, "{}") assert result == {} def test_single_field(self): data = { - 'integer': 42, - 'string': 'a string', - 'boolean': True, + "integer": 42, + "string": "a string", + "boolean": True, } - result = mask.apply(data, '{integer}') - assert result == {'integer': 42} + result = mask.apply(data, "{integer}") + assert result == {"integer": 42} def test_multiple_fields(self): data = { - 'integer': 42, - 'string': 'a string', - 'boolean': True, + "integer": 42, + "string": "a string", + "boolean": True, } - result = mask.apply(data, '{integer, string}') - assert result == {'integer': 42, 'string': 'a string'} + result = mask.apply(data, "{integer, string}") + assert result == {"integer": 42, "string": "a string"} def test_star_only(self): data = { - 'integer': 42, - 'string': 'a string', - 'boolean': True, + "integer": 42, + "string": "a string", + "boolean": True, } - result = mask.apply(data, '*') + result = mask.apply(data, "*") assert result == data def test_with_objects(self): - data = DObject({ - 'integer': 42, - 'string': 'a string', - 'boolean': True, - }) - result = mask.apply(data, '{integer, string}') - assert result == {'integer': 42, 'string': 'a string'} + data = DObject({"integer": 42, "string": "a string", "boolean": True,}) + result = mask.apply(data, "{integer, string}") + assert result == {"integer": 42, "string": "a string"} def test_with_ordered_dict(self): - data = OrderedDict({ - 'integer': 42, - 'string': 'a string', - 'boolean': True, - }) - result = mask.apply(data, '{integer, string}') - assert result == {'integer': 42, 'string': 'a string'} + data = OrderedDict({"integer": 42, "string": "a string", "boolean": True,}) + result = mask.apply(data, "{integer, string}") + assert result == {"integer": 42, "string": "a string"} def test_nested_field(self): data = { - 'integer': 42, - 'string': 'a string', - 'boolean': True, - 'nested': { - 'integer': 42, - 'string': 'a string', - 'boolean': True, - } + "integer": 42, + "string": "a string", + "boolean": True, + "nested": {"integer": 42, "string": "a string", "boolean": True,}, + } + result = mask.apply(data, "{nested}") + assert result == { + "nested": {"integer": 42, "string": "a string", "boolean": True,} } - result = mask.apply(data, '{nested}') - assert result == {'nested': { - 'integer': 42, - 'string': 'a string', - 'boolean': True, - }} def test_nested_fields(self): - data = { - 'nested': { - 'integer': 42, - 'string': 'a string', - 'boolean': True, - } - } - result = mask.apply(data, '{nested{integer}}') - assert result == {'nested': {'integer': 42}} + data = {"nested": {"integer": 42, "string": "a string", "boolean": True,}} + result = mask.apply(data, "{nested{integer}}") + assert result == {"nested": {"integer": 42}} def test_nested_with_start(self): data = { - 'nested': { - 'integer': 42, - 'string': 'a string', - 'boolean': True, - }, - 'other': 'value', + "nested": {"integer": 42, "string": "a string", "boolean": True,}, + "other": "value", } - result = mask.apply(data, '{nested{integer},*}') - assert result == {'nested': {'integer': 42}, 'other': 'value'} + result = mask.apply(data, "{nested{integer},*}") + assert result == {"nested": {"integer": 42}, "other": "value"} def test_nested_fields_when_none(self): - data = {'nested': None} - result = mask.apply(data, '{nested{integer}}') - assert result == {'nested': None} + data = {"nested": None} + result = mask.apply(data, "{nested{integer}}") + assert result == {"nested": None} def test_raw_api_fields(self): family_fields = { - 'father': fields.Raw, - 'mother': fields.Raw, + "father": fields.Raw, + "mother": fields.Raw, } - result = mask.apply(family_fields, 'father{name},mother{age}') + result = mask.apply(family_fields, "father{name},mother{age}") data = { - 'father': {'name': 'John', 'age': 42}, - 'mother': {'name': 'Jane', 'age': 42}, + "father": {"name": "John", "age": 42}, + "mother": {"name": "Jane", "age": 42}, } - expected = {'father': {'name': 'John'}, 'mother': {'age': 42}} + expected = {"father": {"name": "John"}, "mother": {"age": 42}} assert_data(marshal(data, result), expected) # Should leave th original mask untouched @@ -258,90 +212,70 @@ def test_raw_api_fields(self): def test_nested_api_fields(self): family_fields = { - 'father': fields.Nested(person_fields), - 'mother': fields.Nested(person_fields), + "father": fields.Nested(person_fields), + "mother": fields.Nested(person_fields), } - result = mask.apply(family_fields, 'father{name},mother{age}') - assert set(result.keys()) == set(['father', 'mother']) - assert isinstance(result['father'], fields.Nested) - assert set(result['father'].nested.keys()) == set(['name']) - assert isinstance(result['mother'], fields.Nested) - assert set(result['mother'].nested.keys()) == set(['age']) + result = mask.apply(family_fields, "father{name},mother{age}") + assert set(result.keys()) == set(["father", "mother"]) + assert isinstance(result["father"], fields.Nested) + assert set(result["father"].nested.keys()) == set(["name"]) + assert isinstance(result["mother"], fields.Nested) + assert set(result["mother"].nested.keys()) == set(["age"]) data = { - 'father': {'name': 'John', 'age': 42}, - 'mother': {'name': 'Jane', 'age': 42}, + "father": {"name": "John", "age": 42}, + "mother": {"name": "Jane", "age": 42}, } - expected = {'father': {'name': 'John'}, 'mother': {'age': 42}} + expected = {"father": {"name": "John"}, "mother": {"age": 42}} assert_data(marshal(data, result), expected) # Should leave th original mask untouched assert_data(marshal(data, family_fields), data) def test_multiple_nested_api_fields(self): - level_2 = {'nested_2': fields.Nested(person_fields)} - level_1 = {'nested_1': fields.Nested(level_2)} - root = {'nested': fields.Nested(level_1)} + level_2 = {"nested_2": fields.Nested(person_fields)} + level_1 = {"nested_1": fields.Nested(level_2)} + root = {"nested": fields.Nested(level_1)} - result = mask.apply(root, 'nested{nested_1{nested_2{name}}}') - assert set(result.keys()) == set(['nested']) - assert isinstance(result['nested'], fields.Nested) - assert set(result['nested'].nested.keys()) == set(['nested_1']) + result = mask.apply(root, "nested{nested_1{nested_2{name}}}") + assert set(result.keys()) == set(["nested"]) + assert isinstance(result["nested"], fields.Nested) + assert set(result["nested"].nested.keys()) == set(["nested_1"]) - data = { - 'nested': { - 'nested_1': { - 'nested_2': {'name': 'John', 'age': 42} - } - } - } - expected = { - 'nested': { - 'nested_1': { - 'nested_2': {'name': 'John'} - } - } - } + data = {"nested": {"nested_1": {"nested_2": {"name": "John", "age": 42}}}} + expected = {"nested": {"nested_1": {"nested_2": {"name": "John"}}}} assert_data(marshal(data, result), expected) # Should leave th original mask untouched assert_data(marshal(data, root), data) def test_list_fields_with_simple_field(self): - family_fields = { - 'name': fields.String, - 'members': fields.List(fields.String) - } + family_fields = {"name": fields.String, "members": fields.List(fields.String)} - result = mask.apply(family_fields, 'members') - assert set(result.keys()) == set(['members']) - assert isinstance(result['members'], fields.List) - assert isinstance(result['members'].container, fields.String) + result = mask.apply(family_fields, "members") + assert set(result.keys()) == set(["members"]) + assert isinstance(result["members"], fields.List) + assert isinstance(result["members"].container, fields.String) - data = {'name': 'Doe', 'members': ['John', 'Jane']} - expected = {'members': ['John', 'Jane']} + data = {"name": "Doe", "members": ["John", "Jane"]} + expected = {"members": ["John", "Jane"]} assert_data(marshal(data, result), expected) # Should leave th original mask untouched assert_data(marshal(data, family_fields), data) def test_list_fields_with_nested(self): - family_fields = { - 'members': fields.List(fields.Nested(person_fields)) - } + family_fields = {"members": fields.List(fields.Nested(person_fields))} - result = mask.apply(family_fields, 'members{name}') - assert set(result.keys()) == set(['members']) - assert isinstance(result['members'], fields.List) - assert isinstance(result['members'].container, fields.Nested) - assert set(result['members'].container.nested.keys()) == set(['name']) + result = mask.apply(family_fields, "members{name}") + assert set(result.keys()) == set(["members"]) + assert isinstance(result["members"], fields.List) + assert isinstance(result["members"].container, fields.Nested) + assert set(result["members"].container.nested.keys()) == set(["name"]) - data = {'members': [ - {'name': 'John', 'age': 42}, - {'name': 'Jane', 'age': 42}, - ]} - expected = {'members': [{'name': 'John'}, {'name': 'Jane'}]} + data = {"members": [{"name": "John", "age": 42}, {"name": "Jane", "age": 42},]} + expected = {"members": [{"name": "John"}, {"name": "Jane"}]} assert_data(marshal(data, result), expected) # Should leave th original mask untouched @@ -350,422 +284,359 @@ def test_list_fields_with_nested(self): def test_list_fields_with_nested_inherited(self, app): api = Api(app) - person = api.model('Person', { - 'name': fields.String, - 'age': fields.Integer - }) - child = api.inherit('Child', person, { - 'attr': fields.String - }) - - family = api.model('Family', { - 'children': fields.List(fields.Nested(child)) - }) - - result = mask.apply(family.resolved, 'children{name,attr}') - - data = {'children': [ - {'name': 'John', 'age': 5, 'attr': 'value-john'}, - {'name': 'Jane', 'age': 42, 'attr': 'value-jane'}, - ]} - expected = {'children': [ - {'name': 'John', 'attr': 'value-john'}, - {'name': 'Jane', 'attr': 'value-jane'}, - ]} + person = api.model("Person", {"name": fields.String, "age": fields.Integer}) + child = api.inherit("Child", person, {"attr": fields.String}) + + family = api.model("Family", {"children": fields.List(fields.Nested(child))}) + + result = mask.apply(family.resolved, "children{name,attr}") + + data = { + "children": [ + {"name": "John", "age": 5, "attr": "value-john"}, + {"name": "Jane", "age": 42, "attr": "value-jane"}, + ] + } + expected = { + "children": [ + {"name": "John", "attr": "value-john"}, + {"name": "Jane", "attr": "value-jane"}, + ] + } assert_data(marshal(data, result), expected) # Should leave th original mask untouched assert_data(marshal(data, family), data) def test_list_fields_with_raw(self): - family_fields = { - 'members': fields.List(fields.Raw) - } + family_fields = {"members": fields.List(fields.Raw)} - result = mask.apply(family_fields, 'members{name}') + result = mask.apply(family_fields, "members{name}") - data = {'members': [ - {'name': 'John', 'age': 42}, - {'name': 'Jane', 'age': 42}, - ]} - expected = {'members': [{'name': 'John'}, {'name': 'Jane'}]} + data = {"members": [{"name": "John", "age": 42}, {"name": "Jane", "age": 42},]} + expected = {"members": [{"name": "John"}, {"name": "Jane"}]} assert_data(marshal(data, result), expected) # Should leave th original mask untouched assert_data(marshal(data, family_fields), data) def test_list(self): - data = [{ - 'integer': 42, - 'string': 'a string', - 'boolean': True, - }, { - 'integer': 404, - 'string': 'another string', - 'boolean': False, - }] - result = mask.apply(data, '{integer, string}') + data = [ + {"integer": 42, "string": "a string", "boolean": True,}, + {"integer": 404, "string": "another string", "boolean": False,}, + ] + result = mask.apply(data, "{integer, string}") assert result == [ - {'integer': 42, 'string': 'a string'}, - {'integer': 404, 'string': 'another string'} + {"integer": 42, "string": "a string"}, + {"integer": 404, "string": "another string"}, ] def test_nested_list(self): data = { - 'integer': 42, - 'list': [{ - 'integer': 42, - 'string': 'a string', - }, { - 'integer': 404, - 'string': 'another string', - }] + "integer": 42, + "list": [ + {"integer": 42, "string": "a string",}, + {"integer": 404, "string": "another string",}, + ], + } + result = mask.apply(data, "{list}") + assert result == { + "list": [ + {"integer": 42, "string": "a string",}, + {"integer": 404, "string": "another string",}, + ] } - result = mask.apply(data, '{list}') - assert result == {'list': [{ - 'integer': 42, - 'string': 'a string', - }, { - 'integer': 404, - 'string': 'another string', - }]} def test_nested_list_fields(self): data = { - 'list': [{ - 'integer': 42, - 'string': 'a string', - }, { - 'integer': 404, - 'string': 'another string', - }] + "list": [ + {"integer": 42, "string": "a string",}, + {"integer": 404, "string": "another string",}, + ] } - result = mask.apply(data, '{list{integer}}') - assert result == {'list': [{'integer': 42}, {'integer': 404}]} + result = mask.apply(data, "{list{integer}}") + assert result == {"list": [{"integer": 42}, {"integer": 404}]} def test_missing_field_none_by_default(self): - result = mask.apply({}, '{integer}') - assert result == {'integer': None} + result = mask.apply({}, "{integer}") + assert result == {"integer": None} def test_missing_field_skipped(self): - result = mask.apply({}, '{integer}', skip=True) + result = mask.apply({}, "{integer}", skip=True) assert result == {} def test_missing_nested_field_skipped(self): - result = mask.apply({}, 'nested{integer}', skip=True) + result = mask.apply({}, "nested{integer}", skip=True) assert result == {} def test_mask_error_on_simple_fields(self): model = { - 'name': fields.String, + "name": fields.String, } with pytest.raises(mask.MaskError): - mask.apply(model, 'name{notpossible}') + mask.apply(model, "name{notpossible}") def test_mask_error_on_list_field(self): - model = { - 'nested': fields.List(fields.String) - } + model = {"nested": fields.List(fields.String)} with pytest.raises(mask.MaskError): - mask.apply(model, 'nested{notpossible}') + mask.apply(model, "nested{notpossible}") class MaskAPI(object): def test_marshal_with_honour_field_mask_header(self, app, client): api = Api(app) - model = api.model('Test', { - 'name': fields.String, - 'age': fields.Integer, - 'boolean': fields.Boolean, - }) + model = api.model( + "Test", + {"name": fields.String, "age": fields.Integer, "boolean": fields.Boolean,}, + ) - @api.route('/test/') + @api.route("/test/") class TestResource(Resource): @api.marshal_with(model) def get(self): - return { - 'name': 'John Doe', - 'age': 42, - 'boolean': True - } + return {"name": "John Doe", "age": 42, "boolean": True} - data = client.get_json('/test/', headers={ - 'X-Fields': '{name,age}' - }) - assert data == {'name': 'John Doe', 'age': 42} + data = client.get_json("/test/", headers={"X-Fields": "{name,age}"}) + assert data == {"name": "John Doe", "age": 42} def test_marshal_with_honour_field_mask_list(self, app, client): api = Api(app) - model = api.model('Test', { - 'name': fields.String, - 'age': fields.Integer, - 'boolean': fields.Boolean, - }) + model = api.model( + "Test", + {"name": fields.String, "age": fields.Integer, "boolean": fields.Boolean,}, + ) - @api.route('/test/') + @api.route("/test/") class TestResource(Resource): @api.marshal_with(model) def get(self): - return [{ - 'name': 'John Doe', - 'age': 42, - 'boolean': True - }, { - 'name': 'Jane Doe', - 'age': 33, - 'boolean': False - }] - - data = client.get_json('/test/', headers={'X-Fields': '{name,age}'}) - assert data == [{ - 'name': 'John Doe', - 'age': 42, - }, { - 'name': 'Jane Doe', - 'age': 33, - }] + return [ + {"name": "John Doe", "age": 42, "boolean": True}, + {"name": "Jane Doe", "age": 33, "boolean": False}, + ] + + data = client.get_json("/test/", headers={"X-Fields": "{name,age}"}) + assert data == [ + {"name": "John Doe", "age": 42,}, + {"name": "Jane Doe", "age": 33,}, + ] def test_marshal_with_honour_complex_field_mask_header(self, app, client): api = Api(app) - person = api.model('Person', person_fields) - child = api.inherit('Child', person, { - 'attr': fields.String - }) + person = api.model("Person", person_fields) + child = api.inherit("Child", person, {"attr": fields.String}) - family = api.model('Family', { - 'father': fields.Nested(person), - 'mother': fields.Nested(person), - 'children': fields.List(fields.Nested(child)), - 'free': fields.List(fields.Raw), - }) + family = api.model( + "Family", + { + "father": fields.Nested(person), + "mother": fields.Nested(person), + "children": fields.List(fields.Nested(child)), + "free": fields.List(fields.Raw), + }, + ) - house = api.model('House', { - 'family': fields.Nested(family, attribute='people') - }) + house = api.model( + "House", {"family": fields.Nested(family, attribute="people")} + ) - @api.route('/test/') + @api.route("/test/") class TestResource(Resource): @api.marshal_with(house) def get(self): - return {'people': { - 'father': {'name': 'John', 'age': 42}, - 'mother': {'name': 'Jane', 'age': 42}, - 'children': [ - {'name': 'Jack', 'age': 5, 'attr': 'value-1'}, - {'name': 'Julie', 'age': 7, 'attr': 'value-2'}, - ], - 'free': [ - {'key-1': '1-1', 'key-2': '1-2'}, - {'key-1': '2-1', 'key-2': '2-2'}, - ] - }} - - data = client.get_json('/test/', headers={ - 'X-Fields': 'family{father{name},mother{age},children{name,attr},free{key-2}}' - }) - assert data == {'family': { - 'father': {'name': 'John'}, - 'mother': {'age': 42}, - 'children': [{'name': 'Jack', 'attr': 'value-1'}, {'name': 'Julie', 'attr': 'value-2'}], - 'free': [{'key-2': '1-2'}, {'key-2': '2-2'}] - }} + return { + "people": { + "father": {"name": "John", "age": 42}, + "mother": {"name": "Jane", "age": 42}, + "children": [ + {"name": "Jack", "age": 5, "attr": "value-1"}, + {"name": "Julie", "age": 7, "attr": "value-2"}, + ], + "free": [ + {"key-1": "1-1", "key-2": "1-2"}, + {"key-1": "2-1", "key-2": "2-2"}, + ], + } + } + + data = client.get_json( + "/test/", + headers={ + "X-Fields": "family{father{name},mother{age},children{name,attr},free{key-2}}" + }, + ) + assert data == { + "family": { + "father": {"name": "John"}, + "mother": {"age": 42}, + "children": [ + {"name": "Jack", "attr": "value-1"}, + {"name": "Julie", "attr": "value-2"}, + ], + "free": [{"key-2": "1-2"}, {"key-2": "2-2"}], + } + } def test_marshal_honour_field_mask(self, app): api = Api(app) - model = api.model('Test', { - 'name': fields.String, - 'age': fields.Integer, - 'boolean': fields.Boolean, - }) + model = api.model( + "Test", + {"name": fields.String, "age": fields.Integer, "boolean": fields.Boolean,}, + ) - data = { - 'name': 'John Doe', - 'age': 42, - 'boolean': True - } + data = {"name": "John Doe", "age": 42, "boolean": True} - result = api.marshal(data, model, mask='{name,age}') + result = api.marshal(data, model, mask="{name,age}") assert result == { - 'name': 'John Doe', - 'age': 42, + "name": "John Doe", + "age": 42, } def test_marshal_with_honour_default_mask(self, app, client): api = Api(app) - model = api.model('Test', { - 'name': fields.String, - 'age': fields.Integer, - 'boolean': fields.Boolean, - }) + model = api.model( + "Test", + {"name": fields.String, "age": fields.Integer, "boolean": fields.Boolean,}, + ) - @api.route('/test/') + @api.route("/test/") class TestResource(Resource): - @api.marshal_with(model, mask='{name,age}') + @api.marshal_with(model, mask="{name,age}") def get(self): - return { - 'name': 'John Doe', - 'age': 42, - 'boolean': True - } + return {"name": "John Doe", "age": 42, "boolean": True} - data = self.get_json('/test/') - self.assertEqual(data, { - 'name': 'John Doe', - 'age': 42, - }) + data = self.get_json("/test/") + self.assertEqual(data, {"name": "John Doe", "age": 42,}) def test_marshal_with_honour_default_model_mask(self, app, client): api = Api(app) - model = api.model('Test', { - 'name': fields.String, - 'age': fields.Integer, - 'boolean': fields.Boolean, - }, mask='{name,age}') + model = api.model( + "Test", + {"name": fields.String, "age": fields.Integer, "boolean": fields.Boolean,}, + mask="{name,age}", + ) - @api.route('/test/') + @api.route("/test/") class TestResource(Resource): @api.marshal_with(model) def get(self): - return { - 'name': 'John Doe', - 'age': 42, - 'boolean': True - } + return {"name": "John Doe", "age": 42, "boolean": True} - data = client.get_json('/test/') - assert data == {'name': 'John Doe', 'age': 42} + data = client.get_json("/test/") + assert data == {"name": "John Doe", "age": 42} - def test_marshal_with_honour_header_field_mask_with_default_model_mask(self, app, client): + def test_marshal_with_honour_header_field_mask_with_default_model_mask( + self, app, client + ): api = Api(app) - model = api.model('Test', { - 'name': fields.String, - 'age': fields.Integer, - 'boolean': fields.Boolean, - }, mask='{name,age}') + model = api.model( + "Test", + {"name": fields.String, "age": fields.Integer, "boolean": fields.Boolean,}, + mask="{name,age}", + ) - @api.route('/test/') + @api.route("/test/") class TestResource(Resource): @api.marshal_with(model) def get(self): - return { - 'name': 'John Doe', - 'age': 42, - 'boolean': True - } + return {"name": "John Doe", "age": 42, "boolean": True} - data = client.get_json('/test/', headers={ - 'X-Fields': '{name}' - }) - assert data == {'name': 'John Doe'} + data = client.get_json("/test/", headers={"X-Fields": "{name}"}) + assert data == {"name": "John Doe"} - def test_marshal_with_honour_header_default_mask_with_default_model_mask(self, app, client): + def test_marshal_with_honour_header_default_mask_with_default_model_mask( + self, app, client + ): api = Api(app) - model = api.model('Test', { - 'name': fields.String, - 'age': fields.Integer, - 'boolean': fields.Boolean, - }, mask='{name,boolean}') + model = api.model( + "Test", + {"name": fields.String, "age": fields.Integer, "boolean": fields.Boolean,}, + mask="{name,boolean}", + ) - @api.route('/test/') + @api.route("/test/") class TestResource(Resource): - @api.marshal_with(model, mask='{name}') + @api.marshal_with(model, mask="{name}") def get(self): - return { - 'name': 'John Doe', - 'age': 42, - 'boolean': True - } + return {"name": "John Doe", "age": 42, "boolean": True} - data = client.get_json('/test/') - assert data == {'name': 'John Doe'} + data = client.get_json("/test/") + assert data == {"name": "John Doe"} def test_marshal_with_honour_header_field_mask_with_default_mask(self, app, client): api = Api(app) - model = api.model('Test', { - 'name': fields.String, - 'age': fields.Integer, - 'boolean': fields.Boolean, - }) + model = api.model( + "Test", + {"name": fields.String, "age": fields.Integer, "boolean": fields.Boolean,}, + ) - @api.route('/test/') + @api.route("/test/") class TestResource(Resource): - @api.marshal_with(model, mask='{name,age}') + @api.marshal_with(model, mask="{name,age}") def get(self): - return { - 'name': 'John Doe', - 'age': 42, - 'boolean': True - } + return {"name": "John Doe", "age": 42, "boolean": True} - data = client.get_json('/test/', headers={'X-Fields': '{name}'}) - assert data == {'name': 'John Doe'} + data = client.get_json("/test/", headers={"X-Fields": "{name}"}) + assert data == {"name": "John Doe"} - def test_marshal_with_honour_header_field_mask_with_default_mask_and_default_model_mask(self, app, client): + def test_marshal_with_honour_header_field_mask_with_default_mask_and_default_model_mask( + self, app, client + ): api = Api(app) - model = api.model('Test', { - 'name': fields.String, - 'age': fields.Integer, - 'boolean': fields.Boolean, - }, mask='{name,boolean}') + model = api.model( + "Test", + {"name": fields.String, "age": fields.Integer, "boolean": fields.Boolean,}, + mask="{name,boolean}", + ) - @api.route('/test/') + @api.route("/test/") class TestResource(Resource): - @api.marshal_with(model, mask='{name,age}') + @api.marshal_with(model, mask="{name,age}") def get(self): - return { - 'name': 'John Doe', - 'age': 42, - 'boolean': True - } + return {"name": "John Doe", "age": 42, "boolean": True} - data = client.get_json('/test/', headers={'X-Fields': '{name}'}) - assert data == {'name': 'John Doe'} + data = client.get_json("/test/", headers={"X-Fields": "{name}"}) + assert data == {"name": "John Doe"} def test_marshal_with_honour_custom_field_mask(self, app, client): api = Api(app) - model = api.model('Test', { - 'name': fields.String, - 'age': fields.Integer, - 'boolean': fields.Boolean, - }) + model = api.model( + "Test", + {"name": fields.String, "age": fields.Integer, "boolean": fields.Boolean,}, + ) - @api.route('/test/') + @api.route("/test/") class TestResource(Resource): @api.marshal_with(model) def get(self): - return { - 'name': 'John Doe', - 'age': 42, - 'boolean': True - } + return {"name": "John Doe", "age": 42, "boolean": True} - app.config['RESTX_MASK_HEADER'] = 'X-Mask' - data = client.get_json('/test/', headers={'X-Mask': '{name,age}'}) + app.config["RESTX_MASK_HEADER"] = "X-Mask" + data = client.get_json("/test/", headers={"X-Mask": "{name,age}"}) - assert data == {'name': 'John Doe', 'age': 42} + assert data == {"name": "John Doe", "age": 42} def test_marshal_does_not_hit_unrequired_attributes(self, app, client): api = Api(app) - model = api.model('Person', { - 'name': fields.String, - 'age': fields.Integer, - 'boolean': fields.Boolean, - }) + model = api.model( + "Person", + {"name": fields.String, "age": fields.Integer, "boolean": fields.Boolean,}, + ) class Person(object): def __init__(self, name, age): @@ -776,57 +647,45 @@ def __init__(self, name, age): def boolean(self): raise Exception() - @api.route('/test/') + @api.route("/test/") class TestResource(Resource): @api.marshal_with(model) def get(self): - return Person('John Doe', 42) + return Person("John Doe", 42) - data = client.get_json('/test/', headers={'X-Fields': '{name,age}'}) - assert data == {'name': 'John Doe', 'age': 42} + data = client.get_json("/test/", headers={"X-Fields": "{name,age}"}) + assert data == {"name": "John Doe", "age": 42} def test_marshal_with_skip_missing_fields(self, app, client): api = Api(app) - model = api.model('Test', { - 'name': fields.String, - 'age': fields.Integer, - }) + model = api.model("Test", {"name": fields.String, "age": fields.Integer,}) - @api.route('/test/') + @api.route("/test/") class TestResource(Resource): @api.marshal_with(model) def get(self): return { - 'name': 'John Doe', - 'age': 42, + "name": "John Doe", + "age": 42, } - data = client.get_json('/test/', headers={'X-Fields': '{name,missing}'}) - assert data == {'name': 'John Doe'} + data = client.get_json("/test/", headers={"X-Fields": "{name,missing}"}) + assert data == {"name": "John Doe"} def test_marshal_handle_inheritance(self, app): api = Api(app) - person = api.model('Person', { - 'name': fields.String, - 'age': fields.Integer, - }) + person = api.model("Person", {"name": fields.String, "age": fields.Integer,}) - child = api.inherit('Child', person, { - 'extra': fields.String, - }) + child = api.inherit("Child", person, {"extra": fields.String,}) - data = { - 'name': 'John Doe', - 'age': 42, - 'extra': 'extra' - } + data = {"name": "John Doe", "age": 42, "extra": "extra"} values = ( - ('name', {'name': 'John Doe'}), - ('name,extra', {'name': 'John Doe', 'extra': 'extra'}), - ('extra', {'extra': 'extra'}), + ("name", {"name": "John Doe"}), + ("name,extra", {"name": "John Doe", "extra": "extra"}), + ("extra", {"extra": "extra"}), ) for value, expected in values: @@ -836,234 +695,201 @@ def test_marshal_handle_inheritance(self, app): def test_marshal_with_handle_polymorph(self, app, client): api = Api(app) - parent = api.model('Person', { - 'name': fields.String, - }) + parent = api.model("Person", {"name": fields.String,}) - child1 = api.inherit('Child1', parent, { - 'extra1': fields.String, - }) + child1 = api.inherit("Child1", parent, {"extra1": fields.String,}) - child2 = api.inherit('Child2', parent, { - 'extra2': fields.String, - }) + child2 = api.inherit("Child2", parent, {"extra2": fields.String,}) class Child1(object): - name = 'child1' - extra1 = 'extra1' + name = "child1" + extra1 = "extra1" class Child2(object): - name = 'child2' - extra2 = 'extra2' + name = "child2" + extra2 = "extra2" - mapping = { - Child1: child1, - Child2: child2 - } + mapping = {Child1: child1, Child2: child2} - thing = api.model('Thing', { - 'owner': fields.Polymorph(mapping), - }) + thing = api.model("Thing", {"owner": fields.Polymorph(mapping),}) - @api.route('/thing-1/') + @api.route("/thing-1/") class Thing1Resource(Resource): @api.marshal_with(thing) def get(self): - return {'owner': Child1()} + return {"owner": Child1()} - @api.route('/thing-2/') + @api.route("/thing-2/") class Thing2Resource(Resource): @api.marshal_with(thing) def get(self): - return {'owner': Child2()} + return {"owner": Child2()} - data = client.get_json('/thing-1/', headers={'X-Fields': 'owner{name}'}) - assert data == {'owner': {'name': 'child1'}} + data = client.get_json("/thing-1/", headers={"X-Fields": "owner{name}"}) + assert data == {"owner": {"name": "child1"}} - data = client.get_json('/thing-1/', headers={'X-Fields': 'owner{extra1}'}) - assert data == {'owner': {'extra1': 'extra1'}} + data = client.get_json("/thing-1/", headers={"X-Fields": "owner{extra1}"}) + assert data == {"owner": {"extra1": "extra1"}} - data = client.get_json('/thing-2/', headers={'X-Fields': 'owner{name}'}) - assert data == {'owner': {'name': 'child2'}} + data = client.get_json("/thing-2/", headers={"X-Fields": "owner{name}"}) + assert data == {"owner": {"name": "child2"}} def test_raise_400_on_invalid_mask(self, app, client): api = Api(app) - model = api.model('Test', { - 'name': fields.String, - 'age': fields.Integer, - }) + model = api.model("Test", {"name": fields.String, "age": fields.Integer,}) - @api.route('/test/') + @api.route("/test/") class TestResource(Resource): @api.marshal_with(model) def get(self): pass - response = client.get('/test/', headers={'X-Fields': 'name{,missing}'}) + response = client.get("/test/", headers={"X-Fields": "name{,missing}"}) assert response.status_code == 400 - assert response.content_type == 'application/json' + assert response.content_type == "application/json" class SwaggerMaskHeaderTest(object): def test_marshal_with_expose_mask_header(self, app, client): api = Api(app) - model = api.model('Test', { - 'name': fields.String, - 'age': fields.Integer, - 'boolean': fields.Boolean, - }) + model = api.model( + "Test", + {"name": fields.String, "age": fields.Integer, "boolean": fields.Boolean,}, + ) - @api.route('/test/') + @api.route("/test/") class TestResource(Resource): @api.marshal_with(model) def get(self): - return { - 'name': 'John Doe', - 'age': 42, - 'boolean': True - } + return {"name": "John Doe", "age": 42, "boolean": True} specs = client.get_specs() - op = specs['paths']['/test/']['get'] + op = specs["paths"]["/test/"]["get"] - assert 'parameters' in op - assert len(op['parameters']) == 1 + assert "parameters" in op + assert len(op["parameters"]) == 1 - param = op['parameters'][0] + param = op["parameters"][0] - assert param['name'] == 'X-Fields' - assert param['type'] == 'string' - assert param['format'] == 'mask' - assert param['in'] == 'header' - assert 'required' not in param - assert 'default' not in param + assert param["name"] == "X-Fields" + assert param["type"] == "string" + assert param["format"] == "mask" + assert param["in"] == "header" + assert "required" not in param + assert "default" not in param def test_marshal_with_expose_custom_mask_header(self, app, client): api = Api(app) - model = api.model('Test', { - 'name': fields.String, - 'age': fields.Integer, - 'boolean': fields.Boolean, - }) + model = api.model( + "Test", + {"name": fields.String, "age": fields.Integer, "boolean": fields.Boolean,}, + ) - @api.route('/test/') + @api.route("/test/") class TestResource(Resource): @api.marshal_with(model) def get(self): - return { - 'name': 'John Doe', - 'age': 42, - 'boolean': True - } + return {"name": "John Doe", "age": 42, "boolean": True} - app.config['RESTX_MASK_HEADER'] = 'X-Mask' + app.config["RESTX_MASK_HEADER"] = "X-Mask" specs = client.get_specs() - op = specs['paths']['/test/']['get'] - assert 'parameters' in op - assert len(op['parameters']) == 1 + op = specs["paths"]["/test/"]["get"] + assert "parameters" in op + assert len(op["parameters"]) == 1 - param = op['parameters'][0] - assert param['name'] == 'X-Mask' + param = op["parameters"][0] + assert param["name"] == "X-Mask" def test_marshal_with_disabling_mask_header(self, app, client): api = Api(app) - model = api.model('Test', { - 'name': fields.String, - 'age': fields.Integer, - 'boolean': fields.Boolean, - }) + model = api.model( + "Test", + {"name": fields.String, "age": fields.Integer, "boolean": fields.Boolean,}, + ) - @api.route('/test/') + @api.route("/test/") class TestResource(Resource): @api.marshal_with(model) def get(self): - return { - 'name': 'John Doe', - 'age': 42, - 'boolean': True - } + return {"name": "John Doe", "age": 42, "boolean": True} - app.config['RESTX_MASK_SWAGGER'] = False + app.config["RESTX_MASK_SWAGGER"] = False specs = client.get_specs() - op = specs['paths']['/test/']['get'] + op = specs["paths"]["/test/"]["get"] - assert 'parameters' not in op + assert "parameters" not in op def test_is_only_exposed_on_marshal_with(self, app, client): api = Api(app) - model = api.model('Test', { - 'name': fields.String, - 'age': fields.Integer, - 'boolean': fields.Boolean, - }) + model = api.model( + "Test", + {"name": fields.String, "age": fields.Integer, "boolean": fields.Boolean,}, + ) - @api.route('/test/') + @api.route("/test/") class TestResource(Resource): def get(self): - return api.marshal({ - 'name': 'John Doe', - 'age': 42, - 'boolean': True - }, model) + return api.marshal( + {"name": "John Doe", "age": 42, "boolean": True}, model + ) specs = client.get_specs() - op = specs['paths']['/test/']['get'] + op = specs["paths"]["/test/"]["get"] - assert 'parameters' not in op + assert "parameters" not in op def test_marshal_with_expose_default_mask_header(self, app, client): api = Api(app) - model = api.model('Test', { - 'name': fields.String, - 'age': fields.Integer, - 'boolean': fields.Boolean, - }) + model = api.model( + "Test", + {"name": fields.String, "age": fields.Integer, "boolean": fields.Boolean,}, + ) - @api.route('/test/') + @api.route("/test/") class TestResource(Resource): - @api.marshal_with(model, mask='{name,age}') + @api.marshal_with(model, mask="{name,age}") def get(self): pass specs = client.get_specs() - op = specs['paths']['/test/']['get'] + op = specs["paths"]["/test/"]["get"] - assert 'parameters' in op - assert len(op['parameters']) == 1 + assert "parameters" in op + assert len(op["parameters"]) == 1 - param = op['parameters'][0] + param = op["parameters"][0] - assert param['name'] == 'X-Fields' - assert param['type'] == 'string' - assert param['format'] == 'mask' - assert param['default'] == '{name,age}' - assert param['in'] == 'header' - assert 'required' not in param + assert param["name"] == "X-Fields" + assert param["type"] == "string" + assert param["format"] == "mask" + assert param["default"] == "{name,age}" + assert param["in"] == "header" + assert "required" not in param def test_marshal_with_expose_default_model_mask_header(self, app, client): api = Api(app) - model = api.model('Test', { - 'name': fields.String, - 'age': fields.Integer, - 'boolean': fields.Boolean, - }, mask='{name,age}') + model = api.model( + "Test", + {"name": fields.String, "age": fields.Integer, "boolean": fields.Boolean,}, + mask="{name,age}", + ) - @api.route('/test/') + @api.route("/test/") class TestResource(Resource): @api.marshal_with(model) def get(self): pass specs = client.get_specs() - definition = specs['definitions']['Test'] - assert 'x-mask' in definition - assert definition['x-mask'] == '{name,age}' + definition = specs["definitions"]["Test"] + assert "x-mask" in definition + assert definition["x-mask"] == "{name,age}" diff --git a/tests/test_inputs.py b/tests/test_inputs.py index ce3ccd3d..5a0133bc 100644 --- a/tests/test_inputs.py +++ b/tests/test_inputs.py @@ -12,132 +12,187 @@ class Iso8601DateTest(object): - @pytest.mark.parametrize('value,expected', [ - ('2011-01-01', date(2011, 1, 1)), - ('2011-01-01T00:00:00+00:00', date(2011, 1, 1)), - ('2011-01-01T23:59:59+00:00', date(2011, 1, 1)), - ('2011-01-01T23:59:59.001000+00:00', date(2011, 1, 1)), - ('2011-01-01T23:59:59+02:00', date(2011, 1, 1)), - ]) + @pytest.mark.parametrize( + "value,expected", + [ + ("2011-01-01", date(2011, 1, 1)), + ("2011-01-01T00:00:00+00:00", date(2011, 1, 1)), + ("2011-01-01T23:59:59+00:00", date(2011, 1, 1)), + ("2011-01-01T23:59:59.001000+00:00", date(2011, 1, 1)), + ("2011-01-01T23:59:59+02:00", date(2011, 1, 1)), + ], + ) def test_valid_values(self, value, expected): assert inputs.date_from_iso8601(value) == expected def test_error(self): with pytest.raises(ValueError): - inputs.date_from_iso8601('2008-13-13') + inputs.date_from_iso8601("2008-13-13") def test_schema(self): - assert inputs.date_from_iso8601.__schema__ == {'type': 'string', 'format': 'date'} + assert inputs.date_from_iso8601.__schema__ == { + "type": "string", + "format": "date", + } class Iso8601DatetimeTest(object): - @pytest.mark.parametrize('value,expected', [ - ('2011-01-01', datetime(2011, 1, 1)), - ('2011-01-01T00:00:00+00:00', datetime(2011, 1, 1, tzinfo=pytz.utc)), - ('2011-01-01T23:59:59+00:00', datetime(2011, 1, 1, 23, 59, 59, tzinfo=pytz.utc)), - ('2011-01-01T23:59:59.001000+00:00', datetime(2011, 1, 1, 23, 59, 59, 1000, tzinfo=pytz.utc)), - ('2011-01-01T23:59:59+02:00', datetime(2011, 1, 1, 21, 59, 59, tzinfo=pytz.utc)), - ]) + @pytest.mark.parametrize( + "value,expected", + [ + ("2011-01-01", datetime(2011, 1, 1)), + ("2011-01-01T00:00:00+00:00", datetime(2011, 1, 1, tzinfo=pytz.utc)), + ( + "2011-01-01T23:59:59+00:00", + datetime(2011, 1, 1, 23, 59, 59, tzinfo=pytz.utc), + ), + ( + "2011-01-01T23:59:59.001000+00:00", + datetime(2011, 1, 1, 23, 59, 59, 1000, tzinfo=pytz.utc), + ), + ( + "2011-01-01T23:59:59+02:00", + datetime(2011, 1, 1, 21, 59, 59, tzinfo=pytz.utc), + ), + ], + ) def test_valid_values(self, value, expected): assert inputs.datetime_from_iso8601(value) == expected def test_error(self): with pytest.raises(ValueError): - inputs.datetime_from_iso8601('2008-13-13') + inputs.datetime_from_iso8601("2008-13-13") def test_schema(self): - assert inputs.datetime_from_iso8601.__schema__ == {'type': 'string', 'format': 'date-time'} + assert inputs.datetime_from_iso8601.__schema__ == { + "type": "string", + "format": "date-time", + } class Rfc822DatetimeTest(object): - @pytest.mark.parametrize('value,expected', [ - ('Sat, 01 Jan 2011', datetime(2011, 1, 1, tzinfo=pytz.utc)), - ('Sat, 01 Jan 2011 00:00', datetime(2011, 1, 1, tzinfo=pytz.utc)), - ('Sat, 01 Jan 2011 00:00:00', datetime(2011, 1, 1, tzinfo=pytz.utc)), - ('Sat, 01 Jan 2011 00:00:00 +0000', datetime(2011, 1, 1, tzinfo=pytz.utc)), - ('Sat, 01 Jan 2011 00:00:00 -0000', datetime(2011, 1, 1, tzinfo=pytz.utc)), - ('Sat, 01 Jan 2011 23:59:59 -0000', datetime(2011, 1, 1, 23, 59, 59, tzinfo=pytz.utc)), - ('Sat, 01 Jan 2011 21:00:00 +0200', datetime(2011, 1, 1, 19, 0, 0, tzinfo=pytz.utc)), - ('Sat, 01 Jan 2011 21:00:00 -0200', datetime(2011, 1, 1, 23, 0, 0, tzinfo=pytz.utc)), - ]) + @pytest.mark.parametrize( + "value,expected", + [ + ("Sat, 01 Jan 2011", datetime(2011, 1, 1, tzinfo=pytz.utc)), + ("Sat, 01 Jan 2011 00:00", datetime(2011, 1, 1, tzinfo=pytz.utc)), + ("Sat, 01 Jan 2011 00:00:00", datetime(2011, 1, 1, tzinfo=pytz.utc)), + ("Sat, 01 Jan 2011 00:00:00 +0000", datetime(2011, 1, 1, tzinfo=pytz.utc)), + ("Sat, 01 Jan 2011 00:00:00 -0000", datetime(2011, 1, 1, tzinfo=pytz.utc)), + ( + "Sat, 01 Jan 2011 23:59:59 -0000", + datetime(2011, 1, 1, 23, 59, 59, tzinfo=pytz.utc), + ), + ( + "Sat, 01 Jan 2011 21:00:00 +0200", + datetime(2011, 1, 1, 19, 0, 0, tzinfo=pytz.utc), + ), + ( + "Sat, 01 Jan 2011 21:00:00 -0200", + datetime(2011, 1, 1, 23, 0, 0, tzinfo=pytz.utc), + ), + ], + ) def test_valid_values(self, value, expected): assert inputs.datetime_from_rfc822(value) == expected def test_error(self): with pytest.raises(ValueError): - inputs.datetime_from_rfc822('Fake, 01 XXX 2011') + inputs.datetime_from_rfc822("Fake, 01 XXX 2011") class NetlocRegexpTest(object): - @pytest.mark.parametrize('netloc,kwargs', [ - ('localhost', {'localhost': 'localhost'}), - ('example.com', {'domain': 'example.com'}), - ('www.example.com', {'domain': 'www.example.com'}), - ('www.example.com:8000', {'domain': 'www.example.com', 'port': '8000'}), - ('valid-with-hyphens.com', {'domain': 'valid-with-hyphens.com'}), - ('subdomain.example.com', {'domain': 'subdomain.example.com'}), - ('200.8.9.10', {'ipv4': '200.8.9.10'}), - ('200.8.9.10:8000', {'ipv4': '200.8.9.10', 'port': '8000'}), - ('valid-----hyphens.com', {'domain': 'valid-----hyphens.com'}), - ('foo:bar@example.com', {'auth': 'foo:bar', 'domain': 'example.com'}), - ('foo:@example.com', {'auth': 'foo:', 'domain': 'example.com'}), - ('foo@example.com', {'auth': 'foo', 'domain': 'example.com'}), - ('foo:@2001:db8:85a3::8a2e:370:7334', {'auth': 'foo:', 'ipv6': '2001:db8:85a3::8a2e:370:7334'}), - ('[1fff:0:a88:85a3::ac1f]:8001', {'ipv6': '1fff:0:a88:85a3::ac1f', 'port': '8001'}), - ('foo2:qd1%r@example.com', {'auth': 'foo2:qd1%r', 'domain': 'example.com'}), - ]) + @pytest.mark.parametrize( + "netloc,kwargs", + [ + ("localhost", {"localhost": "localhost"}), + ("example.com", {"domain": "example.com"}), + ("www.example.com", {"domain": "www.example.com"}), + ("www.example.com:8000", {"domain": "www.example.com", "port": "8000"}), + ("valid-with-hyphens.com", {"domain": "valid-with-hyphens.com"}), + ("subdomain.example.com", {"domain": "subdomain.example.com"}), + ("200.8.9.10", {"ipv4": "200.8.9.10"}), + ("200.8.9.10:8000", {"ipv4": "200.8.9.10", "port": "8000"}), + ("valid-----hyphens.com", {"domain": "valid-----hyphens.com"}), + ("foo:bar@example.com", {"auth": "foo:bar", "domain": "example.com"}), + ("foo:@example.com", {"auth": "foo:", "domain": "example.com"}), + ("foo@example.com", {"auth": "foo", "domain": "example.com"}), + ( + "foo:@2001:db8:85a3::8a2e:370:7334", + {"auth": "foo:", "ipv6": "2001:db8:85a3::8a2e:370:7334"}, + ), + ( + "[1fff:0:a88:85a3::ac1f]:8001", + {"ipv6": "1fff:0:a88:85a3::ac1f", "port": "8001"}, + ), + ("foo2:qd1%r@example.com", {"auth": "foo2:qd1%r", "domain": "example.com"}), + ], + ) def test_match(self, netloc, kwargs): match = inputs.netloc_regex.match(netloc) - assert match, 'Should match {0}'.format(netloc) - expected = {'auth': None, 'domain': None, 'ipv4': None, 'ipv6': None, 'localhost': None, 'port': None} + assert match, "Should match {0}".format(netloc) + expected = { + "auth": None, + "domain": None, + "ipv4": None, + "ipv6": None, + "localhost": None, + "port": None, + } expected.update(kwargs) assert match.groupdict() == expected class URLTest(object): def assert_bad_url(self, validator, value, details=None): - msg = '{0} is not a valid URL' + msg = "{0} is not a valid URL" with pytest.raises(ValueError) as cm: validator(value) if details: - assert text_type(cm.value) == '. '.join((msg, details)).format(value) + assert text_type(cm.value) == ". ".join((msg, details)).format(value) else: assert text_type(cm.value).startswith(msg.format(value)) - @pytest.mark.parametrize('url', [ - 'http://www.djangoproject.com/', - 'http://example.com/', - 'http://www.example.com/', - 'http://www.example.com/test', - 'http://valid-with-hyphens.com/', - 'http://subdomain.example.com/', - 'http://valid-----hyphens.com/', - 'http://example.com?something=value', - 'http://example.com/index.php?something=value&another=value2', - ]) + @pytest.mark.parametrize( + "url", + [ + "http://www.djangoproject.com/", + "http://example.com/", + "http://www.example.com/", + "http://www.example.com/test", + "http://valid-with-hyphens.com/", + "http://subdomain.example.com/", + "http://valid-----hyphens.com/", + "http://example.com?something=value", + "http://example.com/index.php?something=value&another=value2", + ], + ) def test_valid_values_default(self, url): validator = inputs.URL() assert validator(url) == url - @pytest.mark.parametrize('url', [ - 'foo', - 'http://', - 'http://example', - 'http://example.', - 'http://.com', - 'http://invalid-.com', - 'http://-invalid.com', - 'http://inv-.alid-.com', - 'http://inv-.-alid.com', - 'foo bar baz', - 'foo \u2713', - 'http://@foo:bar@example.com', - 'http://:bar@example.com', - 'http://bar:bar:bar@example.com', - 'http://300:300:300:300', - 'http://example.com:70000', - 'http://example.com:0000', - ]) + @pytest.mark.parametrize( + "url", + [ + "foo", + "http://", + "http://example", + "http://example.", + "http://.com", + "http://invalid-.com", + "http://-invalid.com", + "http://inv-.alid-.com", + "http://inv-.-alid.com", + "foo bar baz", + "foo \u2713", + "http://@foo:bar@example.com", + "http://:bar@example.com", + "http://bar:bar:bar@example.com", + "http://300:300:300:300", + "http://example.com:70000", + "http://example.com:0000", + ], + ) def test_bad_urls(self, url): # Test with everything enabled to ensure bad URL are really detected validator = inputs.URL(ip=True, auth=True, port=True) @@ -147,552 +202,599 @@ def test_bad_urls(self, url): # validator(url) # assert text_type(cm.exception).startswith(msg) - @pytest.mark.parametrize('url', [ - 'google.com', - 'domain.google.com', - 'kevin:pass@google.com/path?query', - 'google.com/path?\u2713', - ]) + @pytest.mark.parametrize( + "url", + [ + "google.com", + "domain.google.com", + "kevin:pass@google.com/path?query", + "google.com/path?\u2713", + ], + ) def test_bad_urls_with_suggestion(self, url): validator = inputs.URL() - self.assert_bad_url(validator, url, 'Did you mean: http://{0}') - - @pytest.mark.parametrize('url', [ - 'http://200.8.9.10/', - 'http://foo:bar@200.8.9.10/', - 'http://200.8.9.10:8000/test', - 'http://2001:db8:85a3::8a2e:370:7334', - 'http://[1fff:0:a88:85a3::ac1f]:8001' - ]) + self.assert_bad_url(validator, url, "Did you mean: http://{0}") + + @pytest.mark.parametrize( + "url", + [ + "http://200.8.9.10/", + "http://foo:bar@200.8.9.10/", + "http://200.8.9.10:8000/test", + "http://2001:db8:85a3::8a2e:370:7334", + "http://[1fff:0:a88:85a3::ac1f]:8001", + ], + ) def test_reject_ip(self, url): validator = inputs.URL() - self.assert_bad_url(validator, url, 'IP is not allowed') - - @pytest.mark.parametrize('url', [ - 'http://200.8.9.10/', - 'http://200.8.9.10/test', - 'http://2001:db8:85a3::8a2e:370:7334', - 'http://[1fff:0:a88:85a3::ac1f]', - ]) + self.assert_bad_url(validator, url, "IP is not allowed") + + @pytest.mark.parametrize( + "url", + [ + "http://200.8.9.10/", + "http://200.8.9.10/test", + "http://2001:db8:85a3::8a2e:370:7334", + "http://[1fff:0:a88:85a3::ac1f]", + ], + ) def test_allow_ip(self, url): validator = inputs.URL(ip=True) assert validator(url) == url - @pytest.mark.parametrize('url', [ - 'http://foo:bar@200.8.9.10/', - 'http://foo:@2001:db8:85a3::8a2e:370:7334', - 'http://foo:bar@[1fff:0:a88:85a3::ac1f]:8001', - 'http://foo:@2001:db8:85a3::8a2e:370:7334', - 'http://foo2:qd1%r@example.com', - ]) + @pytest.mark.parametrize( + "url", + [ + "http://foo:bar@200.8.9.10/", + "http://foo:@2001:db8:85a3::8a2e:370:7334", + "http://foo:bar@[1fff:0:a88:85a3::ac1f]:8001", + "http://foo:@2001:db8:85a3::8a2e:370:7334", + "http://foo2:qd1%r@example.com", + ], + ) def test_reject_auth(self, url): # Test with IP and port to ensure only auth is rejected validator = inputs.URL(ip=True, port=True) - self.assert_bad_url(validator, url, 'Authentication is not allowed') - - @pytest.mark.parametrize('url', [ - 'http://foo:bar@example.com', - 'http://foo:@example.com', - 'http://foo@example.com', - ]) + self.assert_bad_url(validator, url, "Authentication is not allowed") + + @pytest.mark.parametrize( + "url", + [ + "http://foo:bar@example.com", + "http://foo:@example.com", + "http://foo@example.com", + ], + ) def test_allow_auth(self, url): validator = inputs.URL(auth=True) assert validator(url) == url - @pytest.mark.parametrize('url', [ - 'http://localhost', - 'http://127.0.0.1', - 'http://127.0.1.1', - 'http://::1', - ]) + @pytest.mark.parametrize( + "url", + ["http://localhost", "http://127.0.0.1", "http://127.0.1.1", "http://::1",], + ) def test_reject_local(self, url): # Test with IP and port to ensure only auth is rejected validator = inputs.URL(ip=True) - self.assert_bad_url(validator, url, 'Localhost is not allowed') - - @pytest.mark.parametrize('url', [ - 'http://localhost', - 'http://127.0.0.1', - 'http://127.0.1.1', - 'http://::1', - ]) + self.assert_bad_url(validator, url, "Localhost is not allowed") + + @pytest.mark.parametrize( + "url", + ["http://localhost", "http://127.0.0.1", "http://127.0.1.1", "http://::1",], + ) def test_allow_local(self, url): validator = inputs.URL(ip=True, local=True) assert validator(url) == url - @pytest.mark.parametrize('url', [ - 'http://200.8.9.10:8080/', - 'http://foo:bar@200.8.9.10:8080/', - 'http://foo:bar@[1fff:0:a88:85a3::ac1f]:8001' - ]) + @pytest.mark.parametrize( + "url", + [ + "http://200.8.9.10:8080/", + "http://foo:bar@200.8.9.10:8080/", + "http://foo:bar@[1fff:0:a88:85a3::ac1f]:8001", + ], + ) def test_reject_port(self, url): # Test with auth and port to ensure only port is rejected validator = inputs.URL(ip=True, auth=True) - self.assert_bad_url(validator, url, 'Custom port is not allowed') - - @pytest.mark.parametrize('url', [ - 'http://example.com:80', - 'http://example.com:8080', - 'http://www.example.com:8000/test', - ]) + self.assert_bad_url(validator, url, "Custom port is not allowed") + + @pytest.mark.parametrize( + "url", + [ + "http://example.com:80", + "http://example.com:8080", + "http://www.example.com:8000/test", + ], + ) def test_allow_port(self, url): validator = inputs.URL(port=True) assert validator(url) == url - @pytest.mark.parametrize('url', [ - 'sip://somewhere.com', - 'irc://somewhere.com', - ]) + @pytest.mark.parametrize("url", ["sip://somewhere.com", "irc://somewhere.com",]) def test_valid_restricted_schemes(self, url): - validator = inputs.URL(schemes=('sip', 'irc')) + validator = inputs.URL(schemes=("sip", "irc")) assert validator(url) == url - @pytest.mark.parametrize('url', [ - 'http://somewhere.com', - 'https://somewhere.com', - ]) + @pytest.mark.parametrize("url", ["http://somewhere.com", "https://somewhere.com",]) def test_invalid_restricted_schemes(self, url): - validator = inputs.URL(schemes=('sip', 'irc')) - self.assert_bad_url(validator, url, 'Protocol is not allowed') - - @pytest.mark.parametrize('url', [ - 'http://example.com', - 'http://example.com/test/', - 'http://www.example.com/', - 'http://www.example.com/test', - ]) + validator = inputs.URL(schemes=("sip", "irc")) + self.assert_bad_url(validator, url, "Protocol is not allowed") + + @pytest.mark.parametrize( + "url", + [ + "http://example.com", + "http://example.com/test/", + "http://www.example.com/", + "http://www.example.com/test", + ], + ) def test_valid_restricted_domains(self, url): - validator = inputs.URL(domains=['example.com', 'www.example.com']) + validator = inputs.URL(domains=["example.com", "www.example.com"]) assert validator(url) == url - @pytest.mark.parametrize('url', [ - 'http://somewhere.com', - 'https://somewhere.com', - ]) + @pytest.mark.parametrize("url", ["http://somewhere.com", "https://somewhere.com",]) def test_invalid_restricted_domains(self, url): - validator = inputs.URL(domains=['example.com', 'www.example.com']) - self.assert_bad_url(validator, url, 'Domain is not allowed') + validator = inputs.URL(domains=["example.com", "www.example.com"]) + self.assert_bad_url(validator, url, "Domain is not allowed") - @pytest.mark.parametrize('url', [ - 'http://somewhere.com', - 'https://somewhere.com', - ]) + @pytest.mark.parametrize("url", ["http://somewhere.com", "https://somewhere.com",]) def test_valid_excluded_domains(self, url): - validator = inputs.URL(exclude=['example.com', 'www.example.com']) + validator = inputs.URL(exclude=["example.com", "www.example.com"]) assert validator(url) == url - @pytest.mark.parametrize('url', [ - 'http://example.com', - 'http://example.com/test/', - 'http://www.example.com/', - 'http://www.example.com/test', - ]) + @pytest.mark.parametrize( + "url", + [ + "http://example.com", + "http://example.com/test/", + "http://www.example.com/", + "http://www.example.com/test", + ], + ) def test_excluded_domains(self, url): - validator = inputs.URL(exclude=['example.com', 'www.example.com']) - self.assert_bad_url(validator, url, 'Domain is not allowed') + validator = inputs.URL(exclude=["example.com", "www.example.com"]) + self.assert_bad_url(validator, url, "Domain is not allowed") def test_check(self): validator = inputs.URL(check=True, ip=True) - assert validator('http://www.google.com') == 'http://www.google.com', 'Should check domain' + assert ( + validator("http://www.google.com") == "http://www.google.com" + ), "Should check domain" # This test will fail on a network where this address is defined - self.assert_bad_url(validator, 'http://this-domain-should-not-exist.com', 'Domain does not exists') + self.assert_bad_url( + validator, + "http://this-domain-should-not-exist.com", + "Domain does not exists", + ) def test_schema(self): - assert inputs.URL().__schema__ == {'type': 'string', 'format': 'url'} + assert inputs.URL().__schema__ == {"type": "string", "format": "url"} class UrlTest(object): - @pytest.mark.parametrize('url', [ - 'http://www.djangoproject.com/', - 'http://localhost/', - 'http://example.com/', - 'http://www.example.com/', - 'http://www.example.com:8000/test', - 'http://valid-with-hyphens.com/', - 'http://subdomain.example.com/', - 'http://200.8.9.10/', - 'http://200.8.9.10:8000/test', - 'http://valid-----hyphens.com/', - 'http://example.com?something=value', - 'http://example.com/index.php?something=value&another=value2', - 'http://foo:bar@example.com', - 'http://foo:@example.com', - 'http://foo@example.com', - 'http://foo:@2001:db8:85a3::8a2e:370:7334', - 'http://foo2:qd1%r@example.com', - ]) + @pytest.mark.parametrize( + "url", + [ + "http://www.djangoproject.com/", + "http://localhost/", + "http://example.com/", + "http://www.example.com/", + "http://www.example.com:8000/test", + "http://valid-with-hyphens.com/", + "http://subdomain.example.com/", + "http://200.8.9.10/", + "http://200.8.9.10:8000/test", + "http://valid-----hyphens.com/", + "http://example.com?something=value", + "http://example.com/index.php?something=value&another=value2", + "http://foo:bar@example.com", + "http://foo:@example.com", + "http://foo@example.com", + "http://foo:@2001:db8:85a3::8a2e:370:7334", + "http://foo2:qd1%r@example.com", + ], + ) def test_valid_url(self, url): assert inputs.url(url) == url - @pytest.mark.parametrize('url', [ - 'foo', - 'http://', - 'http://example', - 'http://example.', - 'http://.com', - 'http://invalid-.com', - 'http://-invalid.com', - 'http://inv-.alid-.com', - 'http://inv-.-alid.com', - 'foo bar baz', - 'foo \u2713', - 'http://@foo:bar@example.com', - 'http://:bar@example.com', - 'http://bar:bar:bar@example.com', - 'http://300:300:300:300', - 'http://example.com:70000', - ]) + @pytest.mark.parametrize( + "url", + [ + "foo", + "http://", + "http://example", + "http://example.", + "http://.com", + "http://invalid-.com", + "http://-invalid.com", + "http://inv-.alid-.com", + "http://inv-.-alid.com", + "foo bar baz", + "foo \u2713", + "http://@foo:bar@example.com", + "http://:bar@example.com", + "http://bar:bar:bar@example.com", + "http://300:300:300:300", + "http://example.com:70000", + ], + ) def test_bad_url(self, url): with pytest.raises(ValueError) as cm: inputs.url(url) - assert text_type(cm.value).startswith('{0} is not a valid URL'.format(url)) - - @pytest.mark.parametrize('url', [ - 'google.com', - 'domain.google.com', - 'kevin:pass@google.com/path?query', - 'google.com/path?\u2713', - ]) + assert text_type(cm.value).startswith("{0} is not a valid URL".format(url)) + + @pytest.mark.parametrize( + "url", + [ + "google.com", + "domain.google.com", + "kevin:pass@google.com/path?query", + "google.com/path?\u2713", + ], + ) def test_bad_url_with_suggestion(self, url): with pytest.raises(ValueError) as cm: inputs.url(url) - assert text_type(cm.value) == '{0} is not a valid URL. Did you mean: http://{0}'.format(url) + assert text_type( + cm.value + ) == "{0} is not a valid URL. Did you mean: http://{0}".format(url) def test_schema(self): - assert inputs.url.__schema__ == {'type': 'string', 'format': 'url'} + assert inputs.url.__schema__ == {"type": "string", "format": "url"} class IPTest(object): - @pytest.mark.parametrize('value', [ - '200.8.9.10', - '127.0.0.1', - '2001:db8:85a3::8a2e:370:7334', - '::1', - ]) + @pytest.mark.parametrize( + "value", ["200.8.9.10", "127.0.0.1", "2001:db8:85a3::8a2e:370:7334", "::1",] + ) def test_valid_value(self, value): assert inputs.ip(value) == value - @pytest.mark.parametrize('value', [ - 'foo', - 'http://', - 'http://example', - 'http://example.', - 'http://.com', - 'http://invalid-.com', - 'http://-invalid.com', - 'http://inv-.alid-.com', - 'http://inv-.-alid.com', - 'foo bar baz', - 'foo \u2713', - 'http://@foo:bar@example.com', - 'http://:bar@example.com', - 'http://bar:bar:bar@example.com', - '127.0' - ]) + @pytest.mark.parametrize( + "value", + [ + "foo", + "http://", + "http://example", + "http://example.", + "http://.com", + "http://invalid-.com", + "http://-invalid.com", + "http://inv-.alid-.com", + "http://inv-.-alid.com", + "foo bar baz", + "foo \u2713", + "http://@foo:bar@example.com", + "http://:bar@example.com", + "http://bar:bar:bar@example.com", + "127.0", + ], + ) def test_bad_value(self, value): with pytest.raises(ValueError): inputs.ip(value) def test_schema(self): - assert inputs.ip.__schema__ == {'type': 'string', 'format': 'ip'} + assert inputs.ip.__schema__ == {"type": "string", "format": "ip"} class IPv4Test(object): - @pytest.mark.parametrize('value', [ - '200.8.9.10', - '127.0.0.1', - ]) + @pytest.mark.parametrize("value", ["200.8.9.10", "127.0.0.1",]) def test_valid_value(self, value): assert inputs.ipv4(value) == value - @pytest.mark.parametrize('value', [ - '2001:db8:85a3::8a2e:370:7334', - '::1', - 'foo', - 'http://', - 'http://example', - 'http://example.', - 'http://.com', - 'http://invalid-.com', - 'http://-invalid.com', - 'http://inv-.alid-.com', - 'http://inv-.-alid.com', - 'foo bar baz', - 'foo \u2713', - 'http://@foo:bar@example.com', - 'http://:bar@example.com', - 'http://bar:bar:bar@example.com', - '127.0' - ]) + @pytest.mark.parametrize( + "value", + [ + "2001:db8:85a3::8a2e:370:7334", + "::1", + "foo", + "http://", + "http://example", + "http://example.", + "http://.com", + "http://invalid-.com", + "http://-invalid.com", + "http://inv-.alid-.com", + "http://inv-.-alid.com", + "foo bar baz", + "foo \u2713", + "http://@foo:bar@example.com", + "http://:bar@example.com", + "http://bar:bar:bar@example.com", + "127.0", + ], + ) def test_bad_value(self, value): with pytest.raises(ValueError): inputs.ipv4(value) def test_schema(self): - assert inputs.ipv4.__schema__ == {'type': 'string', 'format': 'ipv4'} + assert inputs.ipv4.__schema__ == {"type": "string", "format": "ipv4"} class IPv6Test(object): - @pytest.mark.parametrize('value', [ - '2001:db8:85a3::8a2e:370:7334', - '::1', - ]) + @pytest.mark.parametrize("value", ["2001:db8:85a3::8a2e:370:7334", "::1",]) def test_valid_value(self, value): assert inputs.ipv6(value) == value - @pytest.mark.parametrize('value', [ - '200.8.9.10', - '127.0.0.1', - 'foo', - 'http://', - 'http://example', - 'http://example.', - 'http://.com', - 'http://invalid-.com', - 'http://-invalid.com', - 'http://inv-.alid-.com', - 'http://inv-.-alid.com', - 'foo bar baz', - 'foo \u2713', - 'http://@foo:bar@example.com', - 'http://:bar@example.com', - 'http://bar:bar:bar@example.com', - '127.0' - ]) + @pytest.mark.parametrize( + "value", + [ + "200.8.9.10", + "127.0.0.1", + "foo", + "http://", + "http://example", + "http://example.", + "http://.com", + "http://invalid-.com", + "http://-invalid.com", + "http://inv-.alid-.com", + "http://inv-.-alid.com", + "foo bar baz", + "foo \u2713", + "http://@foo:bar@example.com", + "http://:bar@example.com", + "http://bar:bar:bar@example.com", + "127.0", + ], + ) def test_bad_value(self, value): with pytest.raises(ValueError): inputs.ipv6(value) def test_schema(self): - assert inputs.ipv6.__schema__ == {'type': 'string', 'format': 'ipv6'} + assert inputs.ipv6.__schema__ == {"type": "string", "format": "ipv6"} class EmailTest(object): - def assert_bad_email(self, validator, value, msg=None): - msg = msg or '{0} is not a valid email' + msg = msg or "{0} is not a valid email" with pytest.raises(ValueError) as cm: validator(value) assert str(cm.value) == msg.format(value) - @pytest.mark.parametrize('value', [ - 'test@gmail.com', - 'coucou@cmoi.fr', - 'coucou+another@cmoi.fr', - 'Coucou@cmoi.fr', - 'me@valid-with-hyphens.com', - 'me@subdomain.example.com', - 'me@sub.subdomain.example.com', - 'Loïc.Accentué@voilà.fr', - ]) + @pytest.mark.parametrize( + "value", + [ + "test@gmail.com", + "coucou@cmoi.fr", + "coucou+another@cmoi.fr", + "Coucou@cmoi.fr", + "me@valid-with-hyphens.com", + "me@subdomain.example.com", + "me@sub.subdomain.example.com", + "Loïc.Accentué@voilà.fr", + ], + ) def test_valid_value_default(self, value): validator = inputs.email() assert validator(value) == value - @pytest.mark.parametrize('value', [ - 'me@localhost', - 'me@127.0.0.1', - 'me@127.1.2.3', - 'me@::1', - 'me@200.8.9.10', - 'me@2001:db8:85a3::8a2e:370:7334', - ]) + @pytest.mark.parametrize( + "value", + [ + "me@localhost", + "me@127.0.0.1", + "me@127.1.2.3", + "me@::1", + "me@200.8.9.10", + "me@2001:db8:85a3::8a2e:370:7334", + ], + ) def test_invalid_value_default(self, value): self.assert_bad_email(inputs.email(), value) - @pytest.mark.parametrize('value', [ - 'test@gmail.com', - 'test@live.com', - ]) + @pytest.mark.parametrize("value", ["test@gmail.com", "test@live.com",]) def test_valid_value_check(self, value): email = inputs.email(check=True) assert email(value) == value - @pytest.mark.parametrize('value', [ - 'coucou@not-found.fr', - 'me@localhost', - 'me@127.0.0.1', - 'me@127.1.2.3', - 'me@::1', - 'me@200.8.9.10', - 'me@2001:db8:85a3::8a2e:370:7334', - ]) + @pytest.mark.parametrize( + "value", + [ + "coucou@not-found.fr", + "me@localhost", + "me@127.0.0.1", + "me@127.1.2.3", + "me@::1", + "me@200.8.9.10", + "me@2001:db8:85a3::8a2e:370:7334", + ], + ) def test_invalid_values_check(self, value): email = inputs.email(check=True) self.assert_bad_email(email, value) - @pytest.mark.parametrize('value', [ - 'test@gmail.com', - 'coucou@cmoi.fr', - 'coucou+another@cmoi.fr', - 'Coucou@cmoi.fr', - 'me@valid-with-hyphens.com', - 'me@subdomain.example.com', - 'me@200.8.9.10', - 'me@2001:db8:85a3::8a2e:370:7334', - ]) + @pytest.mark.parametrize( + "value", + [ + "test@gmail.com", + "coucou@cmoi.fr", + "coucou+another@cmoi.fr", + "Coucou@cmoi.fr", + "me@valid-with-hyphens.com", + "me@subdomain.example.com", + "me@200.8.9.10", + "me@2001:db8:85a3::8a2e:370:7334", + ], + ) def test_valid_value_ip(self, value): email = inputs.email(ip=True) assert email(value) == value - @pytest.mark.parametrize('value', [ - 'me@localhost', - 'me@127.0.0.1', - 'me@127.1.2.3', - 'me@::1', - ]) + @pytest.mark.parametrize( + "value", ["me@localhost", "me@127.0.0.1", "me@127.1.2.3", "me@::1",] + ) def test_invalid_value_ip(self, value): email = inputs.email(ip=True) self.assert_bad_email(email, value) - @pytest.mark.parametrize('value', [ - 'test@gmail.com', - 'coucou@cmoi.fr', - 'coucou+another@cmoi.fr', - 'Coucou@cmoi.fr', - 'coucou@localhost', - 'me@valid-with-hyphens.com', - 'me@subdomain.example.com', - 'me@localhost', - ]) + @pytest.mark.parametrize( + "value", + [ + "test@gmail.com", + "coucou@cmoi.fr", + "coucou+another@cmoi.fr", + "Coucou@cmoi.fr", + "coucou@localhost", + "me@valid-with-hyphens.com", + "me@subdomain.example.com", + "me@localhost", + ], + ) def test_valid_value_local(self, value): email = inputs.email(local=True) assert email(value) == value - @pytest.mark.parametrize('value', [ - 'me@127.0.0.1', - 'me@127.1.2.3', - 'me@::1', - 'me@200.8.9.10', - 'me@2001:db8:85a3::8a2e:370:7334', - ]) + @pytest.mark.parametrize( + "value", + [ + "me@127.0.0.1", + "me@127.1.2.3", + "me@::1", + "me@200.8.9.10", + "me@2001:db8:85a3::8a2e:370:7334", + ], + ) def test_invalid_value_local(self, value): email = inputs.email(local=True) self.assert_bad_email(email, value) - @pytest.mark.parametrize('value', [ - 'test@gmail.com', - 'coucou@cmoi.fr', - 'coucou+another@cmoi.fr', - 'Coucou@cmoi.fr', - 'coucou@localhost', - 'me@valid-with-hyphens.com', - 'me@subdomain.example.com', - 'me@200.8.9.10', - 'me@2001:db8:85a3::8a2e:370:7334', - 'me@localhost', - 'me@127.0.0.1', - 'me@127.1.2.3', - 'me@::1', - ]) + @pytest.mark.parametrize( + "value", + [ + "test@gmail.com", + "coucou@cmoi.fr", + "coucou+another@cmoi.fr", + "Coucou@cmoi.fr", + "coucou@localhost", + "me@valid-with-hyphens.com", + "me@subdomain.example.com", + "me@200.8.9.10", + "me@2001:db8:85a3::8a2e:370:7334", + "me@localhost", + "me@127.0.0.1", + "me@127.1.2.3", + "me@::1", + ], + ) def test_valid_value_ip_and_local(self, value): email = inputs.email(ip=True, local=True) assert email(value) == value - @pytest.mark.parametrize('value', [ - 'test@gmail.com', - 'coucou@cmoi.fr', - 'coucou+another@cmoi.fr', - 'Coucou@cmoi.fr', - ]) + @pytest.mark.parametrize( + "value", + [ + "test@gmail.com", + "coucou@cmoi.fr", + "coucou+another@cmoi.fr", + "Coucou@cmoi.fr", + ], + ) def test_valid_value_domains(self, value): - email = inputs.email(domains=('gmail.com', 'cmoi.fr')) + email = inputs.email(domains=("gmail.com", "cmoi.fr")) assert email(value) == value - @pytest.mark.parametrize('value', [ - 'me@valid-with-hyphens.com', - 'me@subdomain.example.com', - 'me@localhost', - 'me@127.0.0.1', - 'me@127.1.2.3', - 'me@::1', - 'me@200.8.9.10', - 'me@2001:db8:85a3::8a2e:370:7334', - ]) + @pytest.mark.parametrize( + "value", + [ + "me@valid-with-hyphens.com", + "me@subdomain.example.com", + "me@localhost", + "me@127.0.0.1", + "me@127.1.2.3", + "me@::1", + "me@200.8.9.10", + "me@2001:db8:85a3::8a2e:370:7334", + ], + ) def test_invalid_value_domains(self, value): - email = inputs.email(domains=('gmail.com', 'cmoi.fr')) - self.assert_bad_email(email, value, '{0} does not belong to the authorized domains') - - @pytest.mark.parametrize('value', [ - 'test@gmail.com', - 'coucou@cmoi.fr', - 'coucou+another@cmoi.fr', - 'Coucou@cmoi.fr', - ]) + email = inputs.email(domains=("gmail.com", "cmoi.fr")) + self.assert_bad_email( + email, value, "{0} does not belong to the authorized domains" + ) + + @pytest.mark.parametrize( + "value", + [ + "test@gmail.com", + "coucou@cmoi.fr", + "coucou+another@cmoi.fr", + "Coucou@cmoi.fr", + ], + ) def test_valid_value_exclude(self, value): - email = inputs.email(exclude=('somewhere.com', 'foo.bar')) + email = inputs.email(exclude=("somewhere.com", "foo.bar")) assert email(value) == value - @pytest.mark.parametrize('value', [ - 'me@somewhere.com', - 'me@foo.bar', - ]) + @pytest.mark.parametrize("value", ["me@somewhere.com", "me@foo.bar",]) def test_invalid_value_exclude(self, value): - email = inputs.email(exclude=('somewhere.com', 'foo.bar')) - self.assert_bad_email(email, value, '{0} belongs to a forbidden domain') - - @pytest.mark.parametrize('value', [ - 'someone@', - '@somewhere', - 'email.somewhere.com', - '[invalid!email]', - 'me.@somewhere', - 'me..something@somewhere', - ]) + email = inputs.email(exclude=("somewhere.com", "foo.bar")) + self.assert_bad_email(email, value, "{0} belongs to a forbidden domain") + + @pytest.mark.parametrize( + "value", + [ + "someone@", + "@somewhere", + "email.somewhere.com", + "[invalid!email]", + "me.@somewhere", + "me..something@somewhere", + ], + ) def test_bad_email(self, value): email = inputs.email() self.assert_bad_email(email, value) def test_schema(self): - assert inputs.email().__schema__ == {'type': 'string', 'format': 'email'} + assert inputs.email().__schema__ == {"type": "string", "format": "email"} class RegexTest(object): - @pytest.mark.parametrize('value', [ - '123', - '1234567890', - '00000', - ]) + @pytest.mark.parametrize("value", ["123", "1234567890", "00000",]) def test_valid_input(self, value): - num_only = inputs.regex(r'^[0-9]+$') + num_only = inputs.regex(r"^[0-9]+$") assert num_only(value) == value - @pytest.mark.parametrize('value', [ - 'abc', - '123abc', - 'abc123', - '', - ]) + @pytest.mark.parametrize("value", ["abc", "123abc", "abc123", "",]) def test_bad_input(self, value): - num_only = inputs.regex(r'^[0-9]+$') + num_only = inputs.regex(r"^[0-9]+$") with pytest.raises(ValueError): num_only(value) def test_bad_pattern(self): with pytest.raises(re.error): - inputs.regex('[') + inputs.regex("[") def test_schema(self): - assert inputs.regex(r'^[0-9]+$').__schema__ == {'type': 'string', 'pattern': '^[0-9]+$'} + assert inputs.regex(r"^[0-9]+$").__schema__ == { + "type": "string", + "pattern": "^[0-9]+$", + } class BooleanTest(object): def test_false(self): - assert inputs.boolean('False') is False + assert inputs.boolean("False") is False def test_0(self): - assert inputs.boolean('0') is False + assert inputs.boolean("0") is False def test_true(self): - assert inputs.boolean('true') is True + assert inputs.boolean("true") is True def test_1(self): - assert inputs.boolean('1') is True + assert inputs.boolean("1") is True def test_case(self): - assert inputs.boolean('FaLSE') is False - assert inputs.boolean('FaLSE') is False + assert inputs.boolean("FaLSE") is False + assert inputs.boolean("FaLSE") is False def test_python_bool(self): assert inputs.boolean(True) is True @@ -700,12 +802,12 @@ def test_python_bool(self): def test_bad_boolean(self): with pytest.raises(ValueError): - inputs.boolean('blah') + inputs.boolean("blah") with pytest.raises(ValueError): inputs.boolean(None) def test_checkbox(self): - assert inputs.boolean('on') is True + assert inputs.boolean("on") is True def test_non_strings(self): assert inputs.boolean(0) is False @@ -713,22 +815,22 @@ def test_non_strings(self): assert inputs.boolean([]) is False def test_schema(self): - assert inputs.boolean.__schema__ == {'type': 'boolean'} + assert inputs.boolean.__schema__ == {"type": "boolean"} class DateTest(object): def test_later_than_1900(self): - assert inputs.date('1900-01-01') == datetime(1900, 1, 1) + assert inputs.date("1900-01-01") == datetime(1900, 1, 1) def test_error(self): with pytest.raises(ValueError): - inputs.date('2008-13-13') + inputs.date("2008-13-13") def test_default(self): - assert inputs.date('2008-08-01') == datetime(2008, 8, 1) + assert inputs.date("2008-08-01") == datetime(2008, 8, 1) def test_schema(self): - assert inputs.date.__schema__ == {'type': 'string', 'format': 'date'} + assert inputs.date.__schema__ == {"type": "string", "format": "date"} class NaturalTest(object): @@ -741,10 +843,10 @@ def test_default(self): def test_string(self): with pytest.raises(ValueError): - inputs.natural('foo') + inputs.natural("foo") def test_schema(self): - assert inputs.natural.__schema__ == {'type': 'integer', 'minimum': 0} + assert inputs.natural.__schema__ == {"type": "integer", "minimum": 0} class PositiveTest(object): @@ -761,7 +863,11 @@ def test_negative(self): inputs.positive(-1) def test_schema(self): - assert inputs.positive.__schema__ == {'type': 'integer', 'minimum': 0, 'exclusiveMinimum': True} + assert inputs.positive.__schema__ == { + "type": "integer", + "minimum": 0, + "exclusiveMinimum": True, + } class IntRangeTest(object): @@ -784,159 +890,180 @@ def test_higher(self): int_range(6) def test_schema(self): - assert inputs.int_range(1, 5).__schema__ == {'type': 'integer', 'minimum': 1, 'maximum': 5} + assert inputs.int_range(1, 5).__schema__ == { + "type": "integer", + "minimum": 1, + "maximum": 5, + } -interval_test_values = [( - # Full precision with explicit UTC. - '2013-01-01T12:30:00Z/P1Y2M3DT4H5M6S', +interval_test_values = [ ( - datetime(2013, 1, 1, 12, 30, 0, tzinfo=pytz.utc), - datetime(2014, 3, 5, 16, 35, 6, tzinfo=pytz.utc), + # Full precision with explicit UTC. + "2013-01-01T12:30:00Z/P1Y2M3DT4H5M6S", + ( + datetime(2013, 1, 1, 12, 30, 0, tzinfo=pytz.utc), + datetime(2014, 3, 5, 16, 35, 6, tzinfo=pytz.utc), + ), ), -), ( - # Full precision with alternate UTC indication - '2013-01-01T12:30+00:00/P2D', ( - datetime(2013, 1, 1, 12, 30, 0, tzinfo=pytz.utc), - datetime(2013, 1, 3, 12, 30, 0, tzinfo=pytz.utc), + # Full precision with alternate UTC indication + "2013-01-01T12:30+00:00/P2D", + ( + datetime(2013, 1, 1, 12, 30, 0, tzinfo=pytz.utc), + datetime(2013, 1, 3, 12, 30, 0, tzinfo=pytz.utc), + ), ), -), ( - # Implicit UTC with time - '2013-01-01T15:00/P1M', ( - datetime(2013, 1, 1, 15, 0, 0, tzinfo=pytz.utc), - datetime(2013, 1, 31, 15, 0, 0, tzinfo=pytz.utc), + # Implicit UTC with time + "2013-01-01T15:00/P1M", + ( + datetime(2013, 1, 1, 15, 0, 0, tzinfo=pytz.utc), + datetime(2013, 1, 31, 15, 0, 0, tzinfo=pytz.utc), + ), ), -), ( - # TZ conversion - '2013-01-01T17:00-05:00/P2W', ( - datetime(2013, 1, 1, 22, 0, 0, tzinfo=pytz.utc), - datetime(2013, 1, 15, 22, 0, 0, tzinfo=pytz.utc), + # TZ conversion + "2013-01-01T17:00-05:00/P2W", + ( + datetime(2013, 1, 1, 22, 0, 0, tzinfo=pytz.utc), + datetime(2013, 1, 15, 22, 0, 0, tzinfo=pytz.utc), + ), ), -), ( - # Date upgrade to midnight-midnight period - '2013-01-01/P3D', ( - datetime(2013, 1, 1, 0, 0, 0, tzinfo=pytz.utc), - datetime(2013, 1, 4, 0, 0, 0, 0, tzinfo=pytz.utc), + # Date upgrade to midnight-midnight period + "2013-01-01/P3D", + ( + datetime(2013, 1, 1, 0, 0, 0, tzinfo=pytz.utc), + datetime(2013, 1, 4, 0, 0, 0, 0, tzinfo=pytz.utc), + ), ), -), ( - # Start/end with UTC - '2013-01-01T12:00:00Z/2013-02-01T12:00:00Z', ( - datetime(2013, 1, 1, 12, 0, 0, tzinfo=pytz.utc), - datetime(2013, 2, 1, 12, 0, 0, tzinfo=pytz.utc), + # Start/end with UTC + "2013-01-01T12:00:00Z/2013-02-01T12:00:00Z", + ( + datetime(2013, 1, 1, 12, 0, 0, tzinfo=pytz.utc), + datetime(2013, 2, 1, 12, 0, 0, tzinfo=pytz.utc), + ), ), -), ( - # Start/end with time upgrade - '2013-01-01/2013-06-30', ( - datetime(2013, 1, 1, tzinfo=pytz.utc), - datetime(2013, 6, 30, tzinfo=pytz.utc), + # Start/end with time upgrade + "2013-01-01/2013-06-30", + ( + datetime(2013, 1, 1, tzinfo=pytz.utc), + datetime(2013, 6, 30, tzinfo=pytz.utc), + ), ), -), ( - # Start/end with TZ conversion - '2013-02-17T12:00:00-07:00/2013-02-28T15:00:00-07:00', ( - datetime(2013, 2, 17, 19, 0, 0, tzinfo=pytz.utc), - datetime(2013, 2, 28, 22, 0, 0, tzinfo=pytz.utc), + # Start/end with TZ conversion + "2013-02-17T12:00:00-07:00/2013-02-28T15:00:00-07:00", + ( + datetime(2013, 2, 17, 19, 0, 0, tzinfo=pytz.utc), + datetime(2013, 2, 28, 22, 0, 0, tzinfo=pytz.utc), + ), ), -), ( # Resolution expansion for single date(time) - # Second with UTC - '2013-01-01T12:30:45Z', - ( - datetime(2013, 1, 1, 12, 30, 45, tzinfo=pytz.utc), - datetime(2013, 1, 1, 12, 30, 46, tzinfo=pytz.utc), + ( # Resolution expansion for single date(time) + # Second with UTC + "2013-01-01T12:30:45Z", + ( + datetime(2013, 1, 1, 12, 30, 45, tzinfo=pytz.utc), + datetime(2013, 1, 1, 12, 30, 46, tzinfo=pytz.utc), + ), ), -), ( - # Second with tz conversion - '2013-01-01T12:30:45+02:00', ( - datetime(2013, 1, 1, 10, 30, 45, tzinfo=pytz.utc), - datetime(2013, 1, 1, 10, 30, 46, tzinfo=pytz.utc), + # Second with tz conversion + "2013-01-01T12:30:45+02:00", + ( + datetime(2013, 1, 1, 10, 30, 45, tzinfo=pytz.utc), + datetime(2013, 1, 1, 10, 30, 46, tzinfo=pytz.utc), + ), ), -), ( - # Second with implicit UTC - '2013-01-01T12:30:45', ( - datetime(2013, 1, 1, 12, 30, 45, tzinfo=pytz.utc), - datetime(2013, 1, 1, 12, 30, 46, tzinfo=pytz.utc), + # Second with implicit UTC + "2013-01-01T12:30:45", + ( + datetime(2013, 1, 1, 12, 30, 45, tzinfo=pytz.utc), + datetime(2013, 1, 1, 12, 30, 46, tzinfo=pytz.utc), + ), ), -), ( - # Minute with UTC - '2013-01-01T12:30+00:00', ( - datetime(2013, 1, 1, 12, 30, tzinfo=pytz.utc), - datetime(2013, 1, 1, 12, 31, tzinfo=pytz.utc), + # Minute with UTC + "2013-01-01T12:30+00:00", + ( + datetime(2013, 1, 1, 12, 30, tzinfo=pytz.utc), + datetime(2013, 1, 1, 12, 31, tzinfo=pytz.utc), + ), ), -), ( - # Minute with conversion - '2013-01-01T12:30+04:00', ( - datetime(2013, 1, 1, 8, 30, tzinfo=pytz.utc), - datetime(2013, 1, 1, 8, 31, tzinfo=pytz.utc), + # Minute with conversion + "2013-01-01T12:30+04:00", + ( + datetime(2013, 1, 1, 8, 30, tzinfo=pytz.utc), + datetime(2013, 1, 1, 8, 31, tzinfo=pytz.utc), + ), ), -), ( - # Minute with implicit UTC - '2013-01-01T12:30', ( - datetime(2013, 1, 1, 12, 30, tzinfo=pytz.utc), - datetime(2013, 1, 1, 12, 31, tzinfo=pytz.utc), + # Minute with implicit UTC + "2013-01-01T12:30", + ( + datetime(2013, 1, 1, 12, 30, tzinfo=pytz.utc), + datetime(2013, 1, 1, 12, 31, tzinfo=pytz.utc), + ), ), -), ( - # Hour, explicit UTC - '2013-01-01T12Z', ( - datetime(2013, 1, 1, 12, tzinfo=pytz.utc), - datetime(2013, 1, 1, 13, tzinfo=pytz.utc), + # Hour, explicit UTC + "2013-01-01T12Z", + ( + datetime(2013, 1, 1, 12, tzinfo=pytz.utc), + datetime(2013, 1, 1, 13, tzinfo=pytz.utc), + ), ), -), ( - # Hour with offset - '2013-01-01T12-07:00', ( - datetime(2013, 1, 1, 19, tzinfo=pytz.utc), - datetime(2013, 1, 1, 20, tzinfo=pytz.utc), + # Hour with offset + "2013-01-01T12-07:00", + ( + datetime(2013, 1, 1, 19, tzinfo=pytz.utc), + datetime(2013, 1, 1, 20, tzinfo=pytz.utc), + ), ), -), ( - # Hour with implicit UTC - '2013-01-01T12', ( - datetime(2013, 1, 1, 12, tzinfo=pytz.utc), - datetime(2013, 1, 1, 13, tzinfo=pytz.utc), + # Hour with implicit UTC + "2013-01-01T12", + ( + datetime(2013, 1, 1, 12, tzinfo=pytz.utc), + datetime(2013, 1, 1, 13, tzinfo=pytz.utc), + ), ), -), ( - # Interval with trailing zero fractional seconds should - # be accepted. - '2013-01-01T12:00:00.0/2013-01-01T12:30:00.000000', ( - datetime(2013, 1, 1, 12, tzinfo=pytz.utc), - datetime(2013, 1, 1, 12, 30, tzinfo=pytz.utc), + # Interval with trailing zero fractional seconds should + # be accepted. + "2013-01-01T12:00:00.0/2013-01-01T12:30:00.000000", + ( + datetime(2013, 1, 1, 12, tzinfo=pytz.utc), + datetime(2013, 1, 1, 12, 30, tzinfo=pytz.utc), + ), ), -)] +] class IsoIntervalTest(object): - @pytest.mark.parametrize('value,expected', interval_test_values) + @pytest.mark.parametrize("value,expected", interval_test_values) def test_valid_value(self, value, expected): assert inputs.iso8601interval(value) == expected def test_error_message(self): with pytest.raises(ValueError) as cm: - inputs.iso8601interval('2013-01-01/blah') - expected = 'Invalid argument: 2013-01-01/blah. argument must be a valid ISO8601 date/time interval.' + inputs.iso8601interval("2013-01-01/blah") + expected = "Invalid argument: 2013-01-01/blah. argument must be a valid ISO8601 date/time interval." assert str(cm.value) == expected - @pytest.mark.parametrize('value', [ - '2013-01T14:', - '', - 'asdf', - '01/01/2013', - ]) + @pytest.mark.parametrize("value", ["2013-01T14:", "", "asdf", "01/01/2013",]) def test_bad_values(self, value): with pytest.raises(ValueError): inputs.iso8601interval(value) def test_schema(self): - assert inputs.iso8601interval.__schema__ == {'type': 'string', 'format': 'iso8601-interval'} + assert inputs.iso8601interval.__schema__ == { + "type": "string", + "format": "iso8601-interval", + } diff --git a/tests/test_logging.py b/tests/test_logging.py index 2db298af..4b366d91 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -9,16 +9,16 @@ def test_namespace_loggers_log_to_flask_app_logger(self, app, client, caplog): caplog.set_level(logging.INFO, logger=app.logger.name) api = restx.Api(app) - ns1 = api.namespace('ns1', path='/ns1') - ns2 = api.namespace('ns2', path='/ns2') + ns1 = api.namespace("ns1", path="/ns1") + ns2 = api.namespace("ns2", path="/ns2") - @ns1.route('/') + @ns1.route("/") class Ns1(restx.Resource): def get(self): ns1.logger.info("hello from ns1") pass - @ns2.route('/') + @ns2.route("/") class Ns2(restx.Resource): def get(self): ns2.logger.info("hello from ns2") @@ -38,16 +38,16 @@ def test_defaults_to_app_level(self, app, client, caplog): caplog.set_level(logging.INFO, logger=app.logger.name) api = restx.Api(app) - ns1 = api.namespace('ns1', path='/ns1') - ns2 = api.namespace('ns2', path='/ns2') + ns1 = api.namespace("ns1", path="/ns1") + ns2 = api.namespace("ns2", path="/ns2") - @ns1.route('/') + @ns1.route("/") class Ns1(restx.Resource): def get(self): ns1.logger.debug("hello from ns1") pass - @ns2.route('/') + @ns2.route("/") class Ns2(restx.Resource): def get(self): ns2.logger.info("hello from ns2") @@ -67,17 +67,17 @@ def test_override_app_level(self, app, client, caplog): caplog.set_level(logging.INFO, logger=app.logger.name) api = restx.Api(app) - ns1 = api.namespace('ns1', path='/ns1') + ns1 = api.namespace("ns1", path="/ns1") ns1.logger.setLevel(logging.DEBUG) - ns2 = api.namespace('ns2', path='/ns2') + ns2 = api.namespace("ns2", path="/ns2") - @ns1.route('/') + @ns1.route("/") class Ns1(restx.Resource): def get(self): ns1.logger.debug("hello from ns1") pass - @ns2.route('/') + @ns2.route("/") class Ns2(restx.Resource): def get(self): ns2.logger.debug("hello from ns2") @@ -98,21 +98,21 @@ def test_namespace_additional_handler(self, app, client, caplog, tmp_path): log_file = tmp_path / "v1.log" api = restx.Api(app) - ns1 = api.namespace('ns1', path='/ns1') + ns1 = api.namespace("ns1", path="/ns1") # set up a file handler for ns1 only # FileHandler only supports Path object on Python >= 3.6 -> cast to str fh = logging.FileHandler(str(log_file)) ns1.logger.addHandler(fh) - ns2 = api.namespace('ns2', path='/ns2') + ns2 = api.namespace("ns2", path="/ns2") - @ns1.route('/') + @ns1.route("/") class Ns1(restx.Resource): def get(self): ns1.logger.info("hello from ns1") pass - @ns2.route('/') + @ns2.route("/") class Ns2(restx.Resource): def get(self): ns2.logger.info("hello from ns2") diff --git a/tests/test_marshalling.py b/tests/test_marshalling.py index 3ff8737a..41460345 100644 --- a/tests/test_marshalling.py +++ b/tests/test_marshalling.py @@ -3,9 +3,7 @@ import pytest -from flask_restx import ( - marshal, marshal_with, marshal_with_field, fields, Api, Resource -) +from flask_restx import marshal, marshal_with, marshal_with_field, fields, Api, Resource from collections import OrderedDict @@ -18,179 +16,205 @@ def get(self): class MarshallingTest(object): def test_marshal(self): - model = OrderedDict([('foo', fields.Raw)]) - marshal_dict = OrderedDict([('foo', 'bar'), ('bat', 'baz')]) + model = OrderedDict([("foo", fields.Raw)]) + marshal_dict = OrderedDict([("foo", "bar"), ("bat", "baz")]) output = marshal(marshal_dict, model) assert isinstance(output, dict) assert not isinstance(output, OrderedDict) - assert output == {'foo': 'bar'} + assert output == {"foo": "bar"} def test_marshal_wildcard_nested(self): - nest = fields.Nested(OrderedDict([('thumbnail', fields.String), ('video', fields.String)])) + nest = fields.Nested( + OrderedDict([("thumbnail", fields.String), ("video", fields.String)]) + ) wild = fields.Wildcard(nest) - wildcard_fields = OrderedDict([('*', wild)]) - model = OrderedDict([('preview', fields.Nested(wildcard_fields))]) - sub_dict = OrderedDict([ - ('9:16', {'thumbnail': 24, 'video': 12}), - ('16:9', {'thumbnail': 25, 'video': 11}), - ('1:1', {'thumbnail': 26, 'video': 10}) - ]) - marshal_dict = OrderedDict([('preview', sub_dict)]) + wildcard_fields = OrderedDict([("*", wild)]) + model = OrderedDict([("preview", fields.Nested(wildcard_fields))]) + sub_dict = OrderedDict( + [ + ("9:16", {"thumbnail": 24, "video": 12}), + ("16:9", {"thumbnail": 25, "video": 11}), + ("1:1", {"thumbnail": 26, "video": 10}), + ] + ) + marshal_dict = OrderedDict([("preview", sub_dict)]) output = marshal(marshal_dict, model) - assert output == {'preview': {'1:1': {'thumbnail': '26', 'video': '10'}, - '16:9': {'thumbnail': '25', 'video': '11'}, - '9:16': {'thumbnail': '24', 'video': '12'}}} + assert output == { + "preview": { + "1:1": {"thumbnail": "26", "video": "10"}, + "16:9": {"thumbnail": "25", "video": "11"}, + "9:16": {"thumbnail": "24", "video": "12"}, + } + } def test_marshal_wildcard_list(self): wild = fields.Wildcard(fields.List(fields.String)) - wildcard_fields = OrderedDict([('*', wild)]) - model = OrderedDict([('preview', fields.Nested(wildcard_fields))]) - sub_dict = OrderedDict([ - ('1:1', [1, 2, 3]), - ('16:9', [4, 5, 6]), - ('9:16', [7, 8, 9]) - ]) - marshal_dict = OrderedDict([('preview', sub_dict)]) + wildcard_fields = OrderedDict([("*", wild)]) + model = OrderedDict([("preview", fields.Nested(wildcard_fields))]) + sub_dict = OrderedDict( + [("1:1", [1, 2, 3]), ("16:9", [4, 5, 6]), ("9:16", [7, 8, 9])] + ) + marshal_dict = OrderedDict([("preview", sub_dict)]) output = marshal(marshal_dict, model) - assert output == {'preview': {'9:16': ['7', '8', '9'], - '16:9': ['4', '5', '6'], - '1:1': ['1', '2', '3']}} + assert output == { + "preview": { + "9:16": ["7", "8", "9"], + "16:9": ["4", "5", "6"], + "1:1": ["1", "2", "3"], + } + } def test_marshal_with_envelope(self): - model = OrderedDict([('foo', fields.Raw)]) - marshal_dict = OrderedDict([('foo', 'bar'), ('bat', 'baz')]) - output = marshal(marshal_dict, model, envelope='hey') - assert output == {'hey': {'foo': 'bar'}} + model = OrderedDict([("foo", fields.Raw)]) + marshal_dict = OrderedDict([("foo", "bar"), ("bat", "baz")]) + output = marshal(marshal_dict, model, envelope="hey") + assert output == {"hey": {"foo": "bar"}} def test_marshal_wildcard_with_envelope(self): wild = fields.Wildcard(fields.String) - model = OrderedDict([('foo', fields.Raw), ('*', wild)]) - marshal_dict = OrderedDict([('foo', {'bat': 'baz'}), ('a', 'toto'), ('b', 'tata')]) - output = marshal(marshal_dict, model, envelope='hey') - assert output == {'hey': {'a': 'toto', 'b': 'tata', 'foo': {'bat': 'baz'}}} + model = OrderedDict([("foo", fields.Raw), ("*", wild)]) + marshal_dict = OrderedDict( + [("foo", {"bat": "baz"}), ("a", "toto"), ("b", "tata")] + ) + output = marshal(marshal_dict, model, envelope="hey") + assert output == {"hey": {"a": "toto", "b": "tata", "foo": {"bat": "baz"}}} def test_marshal_with_skip_none(self): - model = OrderedDict([('foo', fields.Raw), ('bat', fields.Raw), ('qux', fields.Raw)]) - marshal_dict = OrderedDict([('foo', 'bar'), ('bat', None)]) + model = OrderedDict( + [("foo", fields.Raw), ("bat", fields.Raw), ("qux", fields.Raw)] + ) + marshal_dict = OrderedDict([("foo", "bar"), ("bat", None)]) output = marshal(marshal_dict, model, skip_none=True) - assert output == {'foo': 'bar'} + assert output == {"foo": "bar"} def test_marshal_wildcard_with_skip_none(self): wild = fields.Wildcard(fields.String) - model = OrderedDict([('foo', fields.Raw), ('*', wild)]) - marshal_dict = OrderedDict([('foo', None), ('bat', None), ('baz', 'biz'), ('bar', None)]) + model = OrderedDict([("foo", fields.Raw), ("*", wild)]) + marshal_dict = OrderedDict( + [("foo", None), ("bat", None), ("baz", "biz"), ("bar", None)] + ) output = marshal(marshal_dict, model, skip_none=True) - assert output == {'baz': 'biz'} + assert output == {"baz": "biz"} def test_marshal_decorator(self): - model = OrderedDict([('foo', fields.Raw)]) + model = OrderedDict([("foo", fields.Raw)]) @marshal_with(model) def try_me(): - return OrderedDict([('foo', 'bar'), ('bat', 'baz')]) - assert try_me() == {'foo': 'bar'} + return OrderedDict([("foo", "bar"), ("bat", "baz")]) + + assert try_me() == {"foo": "bar"} def test_marshal_decorator_with_envelope(self): - model = OrderedDict([('foo', fields.Raw)]) + model = OrderedDict([("foo", fields.Raw)]) - @marshal_with(model, envelope='hey') + @marshal_with(model, envelope="hey") def try_me(): - return OrderedDict([('foo', 'bar'), ('bat', 'baz')]) + return OrderedDict([("foo", "bar"), ("bat", "baz")]) - assert try_me() == {'hey': {'foo': 'bar'}} + assert try_me() == {"hey": {"foo": "bar"}} def test_marshal_decorator_with_skip_none(self): - model = OrderedDict([('foo', fields.Raw), ('bat', fields.Raw), ('qux', fields.Raw)]) + model = OrderedDict( + [("foo", fields.Raw), ("bat", fields.Raw), ("qux", fields.Raw)] + ) @marshal_with(model, skip_none=True) def try_me(): - return OrderedDict([('foo', 'bar'), ('bat', None)]) + return OrderedDict([("foo", "bar"), ("bat", None)]) - assert try_me() == {'foo': 'bar'} + assert try_me() == {"foo": "bar"} def test_marshal_decorator_tuple(self): - model = OrderedDict([('foo', fields.Raw)]) + model = OrderedDict([("foo", fields.Raw)]) @marshal_with(model) def try_me(): - headers = {'X-test': 123} - return OrderedDict([('foo', 'bar'), ('bat', 'baz')]), 200, headers - assert try_me() == ({'foo': 'bar'}, 200, {'X-test': 123}) + headers = {"X-test": 123} + return OrderedDict([("foo", "bar"), ("bat", "baz")]), 200, headers + + assert try_me() == ({"foo": "bar"}, 200, {"X-test": 123}) def test_marshal_decorator_tuple_with_envelope(self): - model = OrderedDict([('foo', fields.Raw)]) + model = OrderedDict([("foo", fields.Raw)]) - @marshal_with(model, envelope='hey') + @marshal_with(model, envelope="hey") def try_me(): - headers = {'X-test': 123} - return OrderedDict([('foo', 'bar'), ('bat', 'baz')]), 200, headers + headers = {"X-test": 123} + return OrderedDict([("foo", "bar"), ("bat", "baz")]), 200, headers - assert try_me() == ({'hey': {'foo': 'bar'}}, 200, {'X-test': 123}) + assert try_me() == ({"hey": {"foo": "bar"}}, 200, {"X-test": 123}) def test_marshal_decorator_tuple_with_skip_none(self): - model = OrderedDict([('foo', fields.Raw), ('bat', fields.Raw), ('qux', fields.Raw)]) + model = OrderedDict( + [("foo", fields.Raw), ("bat", fields.Raw), ("qux", fields.Raw)] + ) @marshal_with(model, skip_none=True) def try_me(): - headers = {'X-test': 123} - return OrderedDict([('foo', 'bar'), ('bat', None)]), 200, headers + headers = {"X-test": 123} + return OrderedDict([("foo", "bar"), ("bat", None)]), 200, headers - assert try_me() == ({'foo': 'bar'}, 200, {'X-test': 123}) + assert try_me() == ({"foo": "bar"}, 200, {"X-test": 123}) def test_marshal_field_decorator(self): model = fields.Raw @marshal_with_field(model) def try_me(): - return 'foo' - assert try_me() == 'foo' + return "foo" + + assert try_me() == "foo" def test_marshal_field_decorator_tuple(self): model = fields.Raw @marshal_with_field(model) def try_me(): - return 'foo', 200, {'X-test': 123} - assert try_me() == ('foo', 200, {'X-test': 123}) + return "foo", 200, {"X-test": 123} + + assert try_me() == ("foo", 200, {"X-test": 123}) def test_marshal_field(self): - model = OrderedDict({'foo': fields.Raw()}) - marshal_fields = OrderedDict([('foo', 'bar'), ('bat', 'baz')]) + model = OrderedDict({"foo": fields.Raw()}) + marshal_fields = OrderedDict([("foo", "bar"), ("bat", "baz")]) output = marshal(marshal_fields, model) - assert output == {'foo': 'bar'} + assert output == {"foo": "bar"} def test_marshal_tuple(self): - model = OrderedDict({'foo': fields.Raw}) - marshal_fields = OrderedDict([('foo', 'bar'), ('bat', 'baz')]) + model = OrderedDict({"foo": fields.Raw}) + marshal_fields = OrderedDict([("foo", "bar"), ("bat", "baz")]) output = marshal((marshal_fields,), model) - assert output == [{'foo': 'bar'}] + assert output == [{"foo": "bar"}] def test_marshal_tuple_with_envelope(self): - model = OrderedDict({'foo': fields.Raw}) - marshal_fields = OrderedDict([('foo', 'bar'), ('bat', 'baz')]) - output = marshal((marshal_fields,), model, envelope='hey') - assert output == {'hey': [{'foo': 'bar'}]} + model = OrderedDict({"foo": fields.Raw}) + marshal_fields = OrderedDict([("foo", "bar"), ("bat", "baz")]) + output = marshal((marshal_fields,), model, envelope="hey") + assert output == {"hey": [{"foo": "bar"}]} def test_marshal_tuple_with_skip_none(self): - model = OrderedDict([('foo', fields.Raw), ('bat', fields.Raw), ('qux', fields.Raw)]) - marshal_fields = OrderedDict([('foo', 'bar'), ('bat', None)]) + model = OrderedDict( + [("foo", fields.Raw), ("bat", fields.Raw), ("qux", fields.Raw)] + ) + marshal_fields = OrderedDict([("foo", "bar"), ("bat", None)]) output = marshal((marshal_fields,), model, skip_none=True) - assert output == [{'foo': 'bar'}] + assert output == [{"foo": "bar"}] def test_marshal_nested(self): model = { - 'foo': fields.Raw, - 'fee': fields.Nested({'fye': fields.String}), + "foo": fields.Raw, + "fee": fields.Nested({"fye": fields.String}), } marshal_fields = { - 'foo': 'bar', - 'bat': 'baz', - 'fee': {'fye': 'fum'}, + "foo": "bar", + "bat": "baz", + "fee": {"fye": "fum"}, } expected = { - 'foo': 'bar', - 'fee': {'fye': 'fum'}, + "foo": "bar", + "fee": {"fye": "fum"}, } output = marshal(marshal_fields, model) @@ -198,225 +222,252 @@ def test_marshal_nested(self): assert output == expected def test_marshal_nested_ordered(self): - model = OrderedDict([ - ('foo', fields.Raw), - ('fee', fields.Nested({ - 'fye': fields.String, - })) - ]) + model = OrderedDict( + [("foo", fields.Raw), ("fee", fields.Nested({"fye": fields.String,}))] + ) marshal_fields = { - 'foo': 'bar', - 'bat': 'baz', - 'fee': {'fye': 'fum'}, + "foo": "bar", + "bat": "baz", + "fee": {"fye": "fum"}, } - expected = OrderedDict([ - ('foo', 'bar'), - ('fee', OrderedDict([('fye', 'fum')])) - ]) + expected = OrderedDict([("foo", "bar"), ("fee", OrderedDict([("fye", "fum")]))]) output = marshal(marshal_fields, model, ordered=True) assert isinstance(output, OrderedDict) assert output == expected - assert isinstance(output['fee'], OrderedDict) + assert isinstance(output["fee"], OrderedDict) def test_marshal_nested_with_non_null(self): - model = OrderedDict([ - ('foo', fields.Raw), - ('fee', fields.Nested( - OrderedDict([ - ('fye', fields.String), - ('blah', fields.String) - ]), allow_null=False)) - ]) - marshal_fields = [OrderedDict([('foo', 'bar'), - ('bat', 'baz'), - ('fee', None)])] + model = OrderedDict( + [ + ("foo", fields.Raw), + ( + "fee", + fields.Nested( + OrderedDict([("fye", fields.String), ("blah", fields.String)]), + allow_null=False, + ), + ), + ] + ) + marshal_fields = [OrderedDict([("foo", "bar"), ("bat", "baz"), ("fee", None)])] output = marshal(marshal_fields, model) - expected = [OrderedDict([ - ('foo', 'bar'), - ('fee', OrderedDict([('fye', None), ('blah', None)])) - ])] + expected = [ + OrderedDict( + [("foo", "bar"), ("fee", OrderedDict([("fye", None), ("blah", None)]))] + ) + ] assert output == expected def test_marshal_nested_with_null(self): - model = OrderedDict([ - ('foo', fields.Raw), - ('fee', fields.Nested( - OrderedDict([ - ('fye', fields.String), - ('blah', fields.String) - ]), allow_null=True)) - ]) - marshal_fields = OrderedDict([('foo', 'bar'), - ('bat', 'baz'), - ('fee', None)]) + model = OrderedDict( + [ + ("foo", fields.Raw), + ( + "fee", + fields.Nested( + OrderedDict([("fye", fields.String), ("blah", fields.String)]), + allow_null=True, + ), + ), + ] + ) + marshal_fields = OrderedDict([("foo", "bar"), ("bat", "baz"), ("fee", None)]) output = marshal(marshal_fields, model) - expected = OrderedDict([('foo', 'bar'), ('fee', None)]) + expected = OrderedDict([("foo", "bar"), ("fee", None)]) assert output == expected def test_marshal_nested_with_skip_none(self): - model = OrderedDict([ - ('foo', fields.Raw), - ('fee', fields.Nested( - OrderedDict([ - ('fye', fields.String) - ]), skip_none=True)) - ]) - marshal_fields = OrderedDict([('foo', 'bar'), - ('bat', 'baz'), - ('fee', None)]) + model = OrderedDict( + [ + ("foo", fields.Raw), + ( + "fee", + fields.Nested( + OrderedDict([("fye", fields.String)]), skip_none=True + ), + ), + ] + ) + marshal_fields = OrderedDict([("foo", "bar"), ("bat", "baz"), ("fee", None)]) output = marshal(marshal_fields, model, skip_none=True) - expected = OrderedDict([('foo', 'bar')]) + expected = OrderedDict([("foo", "bar")]) assert output == expected def test_allow_null_presents_data(self): - model = OrderedDict([ - ('foo', fields.Raw), - ('fee', fields.Nested( - OrderedDict([ - ('fye', fields.String), - ('blah', fields.String) - ]), allow_null=True)) - ]) - marshal_fields = OrderedDict([('foo', 'bar'), - ('bat', 'baz'), - ('fee', {'blah': 'cool'})]) + model = OrderedDict( + [ + ("foo", fields.Raw), + ( + "fee", + fields.Nested( + OrderedDict([("fye", fields.String), ("blah", fields.String)]), + allow_null=True, + ), + ), + ] + ) + marshal_fields = OrderedDict( + [("foo", "bar"), ("bat", "baz"), ("fee", {"blah": "cool"})] + ) output = marshal(marshal_fields, model) - expected = OrderedDict([ - ('foo', 'bar'), - ('fee', OrderedDict([('fye', None), ('blah', 'cool')])) - ]) + expected = OrderedDict( + [("foo", "bar"), ("fee", OrderedDict([("fye", None), ("blah", "cool")]))] + ) assert output == expected def test_skip_none_presents_data(self): - model = OrderedDict([ - ('foo', fields.Raw), - ('fee', fields.Nested( - OrderedDict([ - ('fye', fields.String), - ('blah', fields.String), - ('foe', fields.String) - ]), skip_none=True)) - ]) - marshal_fields = OrderedDict([('foo', 'bar'), - ('bat', 'baz'), - ('fee', {'blah': 'cool', 'foe': None})]) + model = OrderedDict( + [ + ("foo", fields.Raw), + ( + "fee", + fields.Nested( + OrderedDict( + [ + ("fye", fields.String), + ("blah", fields.String), + ("foe", fields.String), + ] + ), + skip_none=True, + ), + ), + ] + ) + marshal_fields = OrderedDict( + [("foo", "bar"), ("bat", "baz"), ("fee", {"blah": "cool", "foe": None})] + ) output = marshal(marshal_fields, model) - expected = OrderedDict([ - ('foo', 'bar'), - ('fee', OrderedDict([('blah', 'cool')])) - ]) + expected = OrderedDict( + [("foo", "bar"), ("fee", OrderedDict([("blah", "cool")]))] + ) assert output == expected def test_marshal_nested_property(self): class TestObject(object): @property def fee(self): - return {'blah': 'cool'} - model = OrderedDict([ - ('foo', fields.Raw), - ('fee', fields.Nested( - OrderedDict([ - ('fye', fields.String), - ('blah', fields.String) - ]), allow_null=True)) - ]) + return {"blah": "cool"} + + model = OrderedDict( + [ + ("foo", fields.Raw), + ( + "fee", + fields.Nested( + OrderedDict([("fye", fields.String), ("blah", fields.String)]), + allow_null=True, + ), + ), + ] + ) obj = TestObject() - obj.foo = 'bar' - obj.bat = 'baz' + obj.foo = "bar" + obj.bat = "baz" output = marshal([obj], model) - expected = [OrderedDict([ - ('foo', 'bar'), - ('fee', OrderedDict([ - ('fye', None), - ('blah', 'cool') - ])) - ])] + expected = [ + OrderedDict( + [ + ("foo", "bar"), + ("fee", OrderedDict([("fye", None), ("blah", "cool")])), + ] + ) + ] assert output == expected def test_marshal_nested_property_with_skip_none(self): class TestObject(object): @property def fee(self): - return {'blah': 'cool', 'foe': None} - model = OrderedDict([ - ('foo', fields.Raw), - ('fee', fields.Nested( - OrderedDict([ - ('fye', fields.String), - ('blah', fields.String), - ('foe', fields.String) - ]), skip_none=True)) - ]) + return {"blah": "cool", "foe": None} + + model = OrderedDict( + [ + ("foo", fields.Raw), + ( + "fee", + fields.Nested( + OrderedDict( + [ + ("fye", fields.String), + ("blah", fields.String), + ("foe", fields.String), + ] + ), + skip_none=True, + ), + ), + ] + ) obj = TestObject() - obj.foo = 'bar' - obj.bat = 'baz' + obj.foo = "bar" + obj.bat = "baz" output = marshal([obj], model) - expected = [OrderedDict([ - ('foo', 'bar'), - ('fee', OrderedDict([ - ('blah', 'cool') - ])) - ])] + expected = [ + OrderedDict([("foo", "bar"), ("fee", OrderedDict([("blah", "cool")]))]) + ] assert output == expected def test_marshal_list(self): - model = OrderedDict([ - ('foo', fields.Raw), - ('fee', fields.List(fields.String)) - ]) - marshal_fields = OrderedDict([('foo', 'bar'), - ('bat', 'baz'), - ('fee', ['fye', 'fum'])]) + model = OrderedDict([("foo", fields.Raw), ("fee", fields.List(fields.String))]) + marshal_fields = OrderedDict( + [("foo", "bar"), ("bat", "baz"), ("fee", ["fye", "fum"])] + ) output = marshal(marshal_fields, model) - expected = OrderedDict([('foo', 'bar'), ('fee', (['fye', 'fum']))]) + expected = OrderedDict([("foo", "bar"), ("fee", (["fye", "fum"]))]) assert output == expected def test_marshal_list_of_nesteds(self): - model = OrderedDict([ - ('foo', fields.Raw), - ('fee', fields.List(fields.Nested({ - 'fye': fields.String - }))) - ]) - marshal_fields = OrderedDict([('foo', 'bar'), - ('bat', 'baz'), - ('fee', {'fye': 'fum'})]) + model = OrderedDict( + [ + ("foo", fields.Raw), + ("fee", fields.List(fields.Nested({"fye": fields.String}))), + ] + ) + marshal_fields = OrderedDict( + [("foo", "bar"), ("bat", "baz"), ("fee", {"fye": "fum"})] + ) output = marshal(marshal_fields, model) - expected = OrderedDict([('foo', 'bar'), - ('fee', [OrderedDict([('fye', 'fum')])])]) + expected = OrderedDict( + [("foo", "bar"), ("fee", [OrderedDict([("fye", "fum")])])] + ) assert output == expected def test_marshal_list_of_lists(self): - model = OrderedDict([ - ('foo', fields.Raw), - ('fee', fields.List(fields.List( - fields.String))) - ]) - marshal_fields = OrderedDict([('foo', 'bar'), - ('bat', 'baz'), - ('fee', [['fye'], ['fum']])]) + model = OrderedDict( + [("foo", fields.Raw), ("fee", fields.List(fields.List(fields.String)))] + ) + marshal_fields = OrderedDict( + [("foo", "bar"), ("bat", "baz"), ("fee", [["fye"], ["fum"]])] + ) output = marshal(marshal_fields, model) - expected = OrderedDict([('foo', 'bar'), ('fee', [['fye'], ['fum']])]) + expected = OrderedDict([("foo", "bar"), ("fee", [["fye"], ["fum"]])]) assert output == expected def test_marshal_nested_dict(self): - model = OrderedDict([ - ('foo', fields.Raw), - ('bar', OrderedDict([ - ('a', fields.Raw), - ('b', fields.Raw), - ])), - ]) - marshal_fields = OrderedDict([('foo', 'foo-val'), - ('bar', 'bar-val'), - ('bat', 'bat-val'), - ('a', 1), ('b', 2), ('c', 3)]) + model = OrderedDict( + [ + ("foo", fields.Raw), + ("bar", OrderedDict([("a", fields.Raw), ("b", fields.Raw),])), + ] + ) + marshal_fields = OrderedDict( + [ + ("foo", "foo-val"), + ("bar", "bar-val"), + ("bat", "bat-val"), + ("a", 1), + ("b", 2), + ("c", 3), + ] + ) output = marshal(marshal_fields, model) - expected = OrderedDict([('foo', 'foo-val'), - ('bar', OrderedDict([('a', 1), ('b', 2)]))]) + expected = OrderedDict( + [("foo", "foo-val"), ("bar", OrderedDict([("a", 1), ("b", 2)]))] + ) assert output == expected @pytest.mark.options(debug=True) @@ -425,11 +476,11 @@ def test_will_prettyprint_json_in_debug_mode(self, app, client): class Foo1(Resource): def get(self): - return {'foo': 'bar', 'baz': 'asdf'} + return {"foo": "bar", "baz": "asdf"} - api.add_resource(Foo1, '/foo', endpoint='bar') + api.add_resource(Foo1, "/foo", endpoint="bar") - foo = client.get('/foo') + foo = client.get("/foo") # Python's dictionaries have random order (as of "new" Pythons, # anyway), so we can't verify the actual output here. We just @@ -437,24 +488,24 @@ def get(self): lines = foo.data.splitlines() lines = [line.decode() for line in lines] assert "{" == lines[0] - assert lines[1].startswith(' ') - assert lines[2].startswith(' ') + assert lines[1].startswith(" ") + assert lines[2].startswith(" ") assert "}" == lines[3] # Assert our trailing newline. - assert foo.data.endswith(b'\n') + assert foo.data.endswith(b"\n") def test_json_float_marshalled(self, app, client): api = Api(app) class FooResource(Resource): - fields = {'foo': fields.Float} + fields = {"foo": fields.Float} def get(self): return marshal({"foo": 3.0}, self.fields) - api.add_resource(FooResource, '/api') + api.add_resource(FooResource, "/api") - resp = client.get('/api') + resp = client.get("/api") assert resp.status_code == 200 - assert resp.data.decode('utf-8') == '{"foo": 3.0}\n' + assert resp.data.decode("utf-8") == '{"foo": 3.0}\n' diff --git a/tests/test_model.py b/tests/test_model.py index 52a69343..e72b8ec1 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -11,272 +11,212 @@ class ModelTest(object): def test_model_as_flat_dict(self): - model = Model('Person', { - 'name': fields.String, - 'age': fields.Integer, - 'birthdate': fields.DateTime, - }) + model = Model( + "Person", + { + "name": fields.String, + "age": fields.Integer, + "birthdate": fields.DateTime, + }, + ) assert isinstance(model, dict) assert not isinstance(model, OrderedDict) assert model.__schema__ == { - 'properties': { - 'name': { - 'type': 'string' - }, - 'age': { - 'type': 'integer' - }, - 'birthdate': { - 'type': 'string', - 'format': 'date-time' - } + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + "birthdate": {"type": "string", "format": "date-time"}, }, - 'type': 'object' + "type": "object", } def test_model_as_ordered_dict(self): - model = OrderedModel('Person', [ - ('name', fields.String), - ('age', fields.Integer), - ('birthdate', fields.DateTime), - ]) + model = OrderedModel( + "Person", + [ + ("name", fields.String), + ("age", fields.Integer), + ("birthdate", fields.DateTime), + ], + ) assert isinstance(model, OrderedDict) assert model.__schema__ == { - 'type': 'object', - 'properties': { - 'name': { - 'type': 'string' - }, - 'age': { - 'type': 'integer' - }, - 'birthdate': { - 'type': 'string', - 'format': 'date-time' - } - } + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + "birthdate": {"type": "string", "format": "date-time"}, + }, } def test_model_as_nested_dict(self): - address = Model('Address', { - 'road': fields.String, - }) - - person = Model('Person', { - 'name': fields.String, - 'age': fields.Integer, - 'birthdate': fields.DateTime, - 'address': fields.Nested(address) - }) + address = Model("Address", {"road": fields.String,}) + + person = Model( + "Person", + { + "name": fields.String, + "age": fields.Integer, + "birthdate": fields.DateTime, + "address": fields.Nested(address), + }, + ) assert person.__schema__ == { # 'required': ['address'], - 'properties': { - 'name': { - 'type': 'string' - }, - 'age': { - 'type': 'integer' - }, - 'birthdate': { - 'type': 'string', - 'format': 'date-time' - }, - 'address': { - '$ref': '#/definitions/Address', - } + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + "birthdate": {"type": "string", "format": "date-time"}, + "address": {"$ref": "#/definitions/Address",}, }, - 'type': 'object' + "type": "object", } assert address.__schema__ == { - 'properties': { - 'road': { - 'type': 'string' - }, - }, - 'type': 'object' + "properties": {"road": {"type": "string"},}, + "type": "object", } def test_model_as_dict_with_list(self): - model = Model('Person', { - 'name': fields.String, - 'age': fields.Integer, - 'tags': fields.List(fields.String), - }) + model = Model( + "Person", + { + "name": fields.String, + "age": fields.Integer, + "tags": fields.List(fields.String), + }, + ) assert model.__schema__ == { - 'properties': { - 'name': { - 'type': 'string' - }, - 'age': { - 'type': 'integer' - }, - 'tags': { - 'type': 'array', - 'items': { - 'type': 'string' - } - } + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + "tags": {"type": "array", "items": {"type": "string"}}, }, - 'type': 'object' + "type": "object", } def test_model_as_nested_dict_with_list(self): - address = Model('Address', { - 'road': fields.String, - }) - - person = Model('Person', { - 'name': fields.String, - 'age': fields.Integer, - 'birthdate': fields.DateTime, - 'addresses': fields.List(fields.Nested(address)) - }) + address = Model("Address", {"road": fields.String,}) + + person = Model( + "Person", + { + "name": fields.String, + "age": fields.Integer, + "birthdate": fields.DateTime, + "addresses": fields.List(fields.Nested(address)), + }, + ) assert person.__schema__ == { # 'required': ['address'], - 'properties': { - 'name': { - 'type': 'string' - }, - 'age': { - 'type': 'integer' + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + "birthdate": {"type": "string", "format": "date-time"}, + "addresses": { + "type": "array", + "items": {"$ref": "#/definitions/Address",}, }, - 'birthdate': { - 'type': 'string', - 'format': 'date-time' - }, - 'addresses': { - 'type': 'array', - 'items': { - '$ref': '#/definitions/Address', - } - } }, - 'type': 'object' + "type": "object", } assert address.__schema__ == { - 'properties': { - 'road': { - 'type': 'string' - }, - }, - 'type': 'object' + "properties": {"road": {"type": "string"},}, + "type": "object", } def test_model_with_required(self): - model = Model('Person', { - 'name': fields.String(required=True), - 'age': fields.Integer, - 'birthdate': fields.DateTime(required=True), - }) + model = Model( + "Person", + { + "name": fields.String(required=True), + "age": fields.Integer, + "birthdate": fields.DateTime(required=True), + }, + ) assert model.__schema__ == { - 'properties': { - 'name': { - 'type': 'string' - }, - 'age': { - 'type': 'integer' - }, - 'birthdate': { - 'type': 'string', - 'format': 'date-time' - } + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + "birthdate": {"type": "string", "format": "date-time"}, }, - 'required': ['birthdate', 'name'], - 'type': 'object' + "required": ["birthdate", "name"], + "type": "object", } def test_model_as_nested_dict_and_required(self): - address = Model('Address', { - 'road': fields.String, - }) - - person = Model('Person', { - 'name': fields.String, - 'age': fields.Integer, - 'birthdate': fields.DateTime, - 'address': fields.Nested(address, required=True) - }) + address = Model("Address", {"road": fields.String,}) + + person = Model( + "Person", + { + "name": fields.String, + "age": fields.Integer, + "birthdate": fields.DateTime, + "address": fields.Nested(address, required=True), + }, + ) assert person.__schema__ == { - 'required': ['address'], - 'properties': { - 'name': { - 'type': 'string' - }, - 'age': { - 'type': 'integer' - }, - 'birthdate': { - 'type': 'string', - 'format': 'date-time' - }, - 'address': { - '$ref': '#/definitions/Address', - } + "required": ["address"], + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + "birthdate": {"type": "string", "format": "date-time"}, + "address": {"$ref": "#/definitions/Address",}, }, - 'type': 'object' + "type": "object", } assert address.__schema__ == { - 'properties': { - 'road': { - 'type': 'string' - }, - }, - 'type': 'object' + "properties": {"road": {"type": "string"},}, + "type": "object", } def test_model_with_discriminator(self): - model = Model('Person', { - 'name': fields.String(discriminator=True), - 'age': fields.Integer, - }) + model = Model( + "Person", + {"name": fields.String(discriminator=True), "age": fields.Integer,}, + ) assert model.__schema__ == { - 'properties': { - 'name': {'type': 'string'}, - 'age': {'type': 'integer'}, - }, - 'discriminator': 'name', - 'required': ['name'], - 'type': 'object' + "properties": {"name": {"type": "string"}, "age": {"type": "integer"},}, + "discriminator": "name", + "required": ["name"], + "type": "object", } def test_model_with_discriminator_override_require(self): - model = Model('Person', { - 'name': fields.String(discriminator=True, required=False), - 'age': fields.Integer, - }) + model = Model( + "Person", + { + "name": fields.String(discriminator=True, required=False), + "age": fields.Integer, + }, + ) assert model.__schema__ == { - 'properties': { - 'name': {'type': 'string'}, - 'age': {'type': 'integer'}, - }, - 'discriminator': 'name', - 'required': ['name'], - 'type': 'object' + "properties": {"name": {"type": "string"}, "age": {"type": "integer"},}, + "discriminator": "name", + "required": ["name"], + "type": "object", } def test_model_deepcopy(self): - parent = Model('Person', { - 'name': fields.String, - 'age': fields.Integer(description="foo"), - }) + parent = Model( + "Person", {"name": fields.String, "age": fields.Integer(description="foo"),} + ) - child = parent.inherit('Child', { - 'extra': fields.String, - }) + child = parent.inherit("Child", {"extra": fields.String,}) parent_copy = copy.deepcopy(parent) @@ -287,256 +227,164 @@ def test_model_deepcopy(self): assert parent["age"].description == "foo" assert parent_copy["age"].description == "bar" - child = parent.inherit('Child', { - 'extra': fields.String, - }) + child = parent.inherit("Child", {"extra": fields.String,}) child_copy = copy.deepcopy(child) assert child_copy.__parents__[0] == parent def test_clone_from_instance(self): - parent = Model('Parent', { - 'name': fields.String, - 'age': fields.Integer, - 'birthdate': fields.DateTime, - }) + parent = Model( + "Parent", + { + "name": fields.String, + "age": fields.Integer, + "birthdate": fields.DateTime, + }, + ) - child = parent.clone('Child', { - 'extra': fields.String, - }) + child = parent.clone("Child", {"extra": fields.String,}) assert child.__schema__ == { - 'properties': { - 'name': { - 'type': 'string' - }, - 'age': { - 'type': 'integer' - }, - 'birthdate': { - 'type': 'string', - 'format': 'date-time' - }, - 'extra': { - 'type': 'string' - } + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + "birthdate": {"type": "string", "format": "date-time"}, + "extra": {"type": "string"}, }, - 'type': 'object' + "type": "object", } def test_clone_from_class(self): - parent = Model('Parent', { - 'name': fields.String, - 'age': fields.Integer, - 'birthdate': fields.DateTime, - }) + parent = Model( + "Parent", + { + "name": fields.String, + "age": fields.Integer, + "birthdate": fields.DateTime, + }, + ) - child = Model.clone('Child', parent, { - 'extra': fields.String, - }) + child = Model.clone("Child", parent, {"extra": fields.String,}) assert child.__schema__ == { - 'properties': { - 'name': { - 'type': 'string' - }, - 'age': { - 'type': 'integer' - }, - 'birthdate': { - 'type': 'string', - 'format': 'date-time' - }, - 'extra': { - 'type': 'string' - } + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + "birthdate": {"type": "string", "format": "date-time"}, + "extra": {"type": "string"}, }, - 'type': 'object' + "type": "object", } def test_clone_from_instance_with_multiple_parents(self): - grand_parent = Model('GrandParent', { - 'grand_parent': fields.String, - }) - - parent = Model('Parent', { - 'name': fields.String, - 'age': fields.Integer, - 'birthdate': fields.DateTime, - }) + grand_parent = Model("GrandParent", {"grand_parent": fields.String,}) + + parent = Model( + "Parent", + { + "name": fields.String, + "age": fields.Integer, + "birthdate": fields.DateTime, + }, + ) - child = grand_parent.clone('Child', parent, { - 'extra': fields.String, - }) + child = grand_parent.clone("Child", parent, {"extra": fields.String,}) assert child.__schema__ == { - 'properties': { - 'grand_parent': { - 'type': 'string' - }, - 'name': { - 'type': 'string' - }, - 'age': { - 'type': 'integer' - }, - 'birthdate': { - 'type': 'string', - 'format': 'date-time' - }, - 'extra': { - 'type': 'string' - } + "properties": { + "grand_parent": {"type": "string"}, + "name": {"type": "string"}, + "age": {"type": "integer"}, + "birthdate": {"type": "string", "format": "date-time"}, + "extra": {"type": "string"}, }, - 'type': 'object' + "type": "object", } def test_clone_from_class_with_multiple_parents(self): - grand_parent = Model('GrandParent', { - 'grand_parent': fields.String, - }) - - parent = Model('Parent', { - 'name': fields.String, - 'age': fields.Integer, - 'birthdate': fields.DateTime, - }) + grand_parent = Model("GrandParent", {"grand_parent": fields.String,}) + + parent = Model( + "Parent", + { + "name": fields.String, + "age": fields.Integer, + "birthdate": fields.DateTime, + }, + ) - child = Model.clone('Child', grand_parent, parent, { - 'extra': fields.String, - }) + child = Model.clone("Child", grand_parent, parent, {"extra": fields.String,}) assert child.__schema__ == { - 'properties': { - 'grand_parent': { - 'type': 'string' - }, - 'name': { - 'type': 'string' - }, - 'age': { - 'type': 'integer' - }, - 'birthdate': { - 'type': 'string', - 'format': 'date-time' - }, - 'extra': { - 'type': 'string' - } + "properties": { + "grand_parent": {"type": "string"}, + "name": {"type": "string"}, + "age": {"type": "integer"}, + "birthdate": {"type": "string", "format": "date-time"}, + "extra": {"type": "string"}, }, - 'type': 'object' + "type": "object", } def test_inherit_from_instance(self): - parent = Model('Parent', { - 'name': fields.String, - 'age': fields.Integer, - }) + parent = Model("Parent", {"name": fields.String, "age": fields.Integer,}) - child = parent.inherit('Child', { - 'extra': fields.String, - }) + child = parent.inherit("Child", {"extra": fields.String,}) assert parent.__schema__ == { - 'properties': { - 'name': {'type': 'string'}, - 'age': {'type': 'integer'}, - }, - 'type': 'object' + "properties": {"name": {"type": "string"}, "age": {"type": "integer"},}, + "type": "object", } assert child.__schema__ == { - 'allOf': [ - {'$ref': '#/definitions/Parent'}, - { - 'properties': { - 'extra': {'type': 'string'} - }, - 'type': 'object' - } + "allOf": [ + {"$ref": "#/definitions/Parent"}, + {"properties": {"extra": {"type": "string"}}, "type": "object"}, ] } def test_inherit_from_class(self): - parent = Model('Parent', { - 'name': fields.String, - 'age': fields.Integer, - }) + parent = Model("Parent", {"name": fields.String, "age": fields.Integer,}) - child = Model.inherit('Child', parent, { - 'extra': fields.String, - }) + child = Model.inherit("Child", parent, {"extra": fields.String,}) assert parent.__schema__ == { - 'properties': { - 'name': {'type': 'string'}, - 'age': {'type': 'integer'}, - }, - 'type': 'object' + "properties": {"name": {"type": "string"}, "age": {"type": "integer"},}, + "type": "object", } assert child.__schema__ == { - 'allOf': [ - {'$ref': '#/definitions/Parent'}, - { - 'properties': { - 'extra': {'type': 'string'} - }, - 'type': 'object' - } + "allOf": [ + {"$ref": "#/definitions/Parent"}, + {"properties": {"extra": {"type": "string"}}, "type": "object"}, ] } def test_inherit_from_class_from_multiple_parents(self): - grand_parent = Model('GrandParent', { - 'grand_parent': fields.String, - }) + grand_parent = Model("GrandParent", {"grand_parent": fields.String,}) - parent = Model('Parent', { - 'name': fields.String, - 'age': fields.Integer, - }) + parent = Model("Parent", {"name": fields.String, "age": fields.Integer,}) - child = Model.inherit('Child', grand_parent, parent, { - 'extra': fields.String, - }) + child = Model.inherit("Child", grand_parent, parent, {"extra": fields.String,}) assert child.__schema__ == { - 'allOf': [ - {'$ref': '#/definitions/GrandParent'}, - {'$ref': '#/definitions/Parent'}, - { - 'properties': { - 'extra': {'type': 'string'} - }, - 'type': 'object' - } + "allOf": [ + {"$ref": "#/definitions/GrandParent"}, + {"$ref": "#/definitions/Parent"}, + {"properties": {"extra": {"type": "string"}}, "type": "object"}, ] } def test_inherit_from_instance_from_multiple_parents(self): - grand_parent = Model('GrandParent', { - 'grand_parent': fields.String, - }) + grand_parent = Model("GrandParent", {"grand_parent": fields.String,}) - parent = Model('Parent', { - 'name': fields.String, - 'age': fields.Integer, - }) + parent = Model("Parent", {"name": fields.String, "age": fields.Integer,}) - child = grand_parent.inherit('Child', parent, { - 'extra': fields.String, - }) + child = grand_parent.inherit("Child", parent, {"extra": fields.String,}) assert child.__schema__ == { - 'allOf': [ - {'$ref': '#/definitions/GrandParent'}, - {'$ref': '#/definitions/Parent'}, - { - 'properties': { - 'extra': {'type': 'string'} - }, - 'type': 'object' - } + "allOf": [ + {"$ref": "#/definitions/GrandParent"}, + {"$ref": "#/definitions/Parent"}, + {"properties": {"extra": {"type": "string"}}, "type": "object"}, ] } @@ -565,34 +413,23 @@ class Child1: class Child2: pass - parent = Model('Person', { - 'name': fields.String, - 'age': fields.Integer, - }) + parent = Model("Person", {"name": fields.String, "age": fields.Integer,}) - child1 = parent.inherit('Child1', { - 'extra1': fields.String, - }) + child1 = parent.inherit("Child1", {"extra1": fields.String,}) - child2 = parent.inherit('Child2', { - 'extra2': fields.String, - }) + child2 = parent.inherit("Child2", {"extra2": fields.String,}) mapping = { Child1: child1, Child2: child2, } - output = Model('Output', { - 'child': fields.Polymorph(mapping) - }) + output = Model("Output", {"child": fields.Polymorph(mapping)}) # Should use the common ancestor assert output.__schema__ == { - 'properties': { - 'child': {'$ref': '#/definitions/Person'}, - }, - 'type': 'object' + "properties": {"child": {"$ref": "#/definitions/Person"},}, + "type": "object", } def test_validate(self): @@ -600,11 +437,11 @@ def test_validate(self): from werkzeug.exceptions import BadRequest class IPAddress(fields.Raw): - __schema_type__ = 'string' - __schema_format__ = 'ipv4' + __schema_type__ = "string" + __schema_format__ = "ipv4" - data = {'ip': '192.168.1'} - model = Model('MyModel', {'ip': IPAddress()}) + data = {"ip": "192.168.1"} + model = Model("MyModel", {"ip": IPAddress()}) # Test that validate without a FormatChecker does not check if a # primitive type conforms to the defined format property @@ -618,93 +455,61 @@ class IPAddress(fields.Raw): class ModelSchemaTestCase(object): def test_model_schema(self): - address = SchemaModel('Address', { - 'properties': { - 'road': { - 'type': 'string' - }, - }, - 'type': 'object' - }) - - person = SchemaModel('Person', { - # 'required': ['address'], - 'properties': { - 'name': { - 'type': 'string' - }, - 'age': { - 'type': 'integer' - }, - 'birthdate': { - 'type': 'string', - 'format': 'date-time' - }, - 'address': { - '$ref': '#/definitions/Address', - } + address = SchemaModel( + "Address", {"properties": {"road": {"type": "string"},}, "type": "object"} + ) + + person = SchemaModel( + "Person", + { + # 'required': ['address'], + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + "birthdate": {"type": "string", "format": "date-time"}, + "address": {"$ref": "#/definitions/Address",}, + }, + "type": "object", }, - 'type': 'object' - }) + ) assert person.__schema__ == { # 'required': ['address'], - 'properties': { - 'name': { - 'type': 'string' - }, - 'age': { - 'type': 'integer' - }, - 'birthdate': { - 'type': 'string', - 'format': 'date-time' - }, - 'address': { - '$ref': '#/definitions/Address', - } + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + "birthdate": {"type": "string", "format": "date-time"}, + "address": {"$ref": "#/definitions/Address",}, }, - 'type': 'object' + "type": "object", } assert address.__schema__ == { - 'properties': { - 'road': { - 'type': 'string' - }, - }, - 'type': 'object' + "properties": {"road": {"type": "string"},}, + "type": "object", } class ModelDeprecattionsTest(object): def test_extend_is_deprecated(self): - parent = Model('Parent', { - 'name': fields.String, - 'age': fields.Integer, - 'birthdate': fields.DateTime, - }) + parent = Model( + "Parent", + { + "name": fields.String, + "age": fields.Integer, + "birthdate": fields.DateTime, + }, + ) with pytest.warns(DeprecationWarning): - child = parent.extend('Child', { - 'extra': fields.String, - }) + child = parent.extend("Child", {"extra": fields.String,}) assert child.__schema__ == { - 'properties': { - 'name': { - 'type': 'string' - }, - 'age': { - 'type': 'integer' - }, - 'birthdate': { - 'type': 'string', - 'format': 'date-time' - }, - 'extra': { - 'type': 'string' - } + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + "birthdate": {"type": "string", "format": "date-time"}, + "extra": {"type": "string"}, }, - 'type': 'object' + "type": "object", } diff --git a/tests/test_namespace.py b/tests/test_namespace.py index f02cfb84..93644a7e 100644 --- a/tests/test_namespace.py +++ b/tests/test_namespace.py @@ -8,24 +8,33 @@ class NamespaceTest(object): def test_parser(self): - api = Namespace('test') + api = Namespace("test") assert isinstance(api.parser(), restx.reqparse.RequestParser) def test_doc_decorator(self): - api = Namespace('test') - params = {'q': {'description': 'some description'}} + api = Namespace("test") + params = {"q": {"description": "some description"}} @api.doc(params=params) class TestResource(restx.Resource): pass - assert hasattr(TestResource, '__apidoc__') - assert TestResource.__apidoc__ == {'params': params} + assert hasattr(TestResource, "__apidoc__") + assert TestResource.__apidoc__ == {"params": params} def test_doc_with_inheritance(self): - api = Namespace('test') - base_params = {'q': {'description': 'some description', 'type': 'string', 'paramType': 'query'}} - child_params = {'q': {'description': 'some new description'}, 'other': {'description': 'another param'}} + api = Namespace("test") + base_params = { + "q": { + "description": "some description", + "type": "string", + "paramType": "query", + } + } + child_params = { + "q": {"description": "some new description"}, + "other": {"description": "another param"}, + } @api.doc(params=base_params) class BaseResource(restx.Resource): @@ -35,88 +44,89 @@ class BaseResource(restx.Resource): class TestResource(BaseResource): pass - assert TestResource.__apidoc__ == {'params': { - 'q': { - 'description': 'some new description', - 'type': 'string', - 'paramType': 'query' - }, - 'other': {'description': 'another param'}, - }} + assert TestResource.__apidoc__ == { + "params": { + "q": { + "description": "some new description", + "type": "string", + "paramType": "query", + }, + "other": {"description": "another param"}, + } + } def test_model(self): - api = Namespace('test') - api.model('Person', {}) - assert 'Person' in api.models - assert isinstance(api.models['Person'], Model) + api = Namespace("test") + api.model("Person", {}) + assert "Person" in api.models + assert isinstance(api.models["Person"], Model) def test_ordered_model(self): - api = Namespace('test', ordered=True) - api.model('Person', {}) - assert 'Person' in api.models - assert isinstance(api.models['Person'], OrderedModel) + api = Namespace("test", ordered=True) + api.model("Person", {}) + assert "Person" in api.models + assert isinstance(api.models["Person"], OrderedModel) def test_schema_model(self): - api = Namespace('test') - api.schema_model('Person', {}) - assert 'Person' in api.models + api = Namespace("test") + api.schema_model("Person", {}) + assert "Person" in api.models def test_clone(self): - api = Namespace('test') - parent = api.model('Parent', {}) - api.clone('Child', parent, {}) + api = Namespace("test") + parent = api.model("Parent", {}) + api.clone("Child", parent, {}) - assert 'Child' in api.models - assert 'Parent' in api.models + assert "Child" in api.models + assert "Parent" in api.models def test_clone_with_multiple_parents(self): - api = Namespace('test') - grand_parent = api.model('GrandParent', {}) - parent = api.model('Parent', {}) - api.clone('Child', grand_parent, parent, {}) + api = Namespace("test") + grand_parent = api.model("GrandParent", {}) + parent = api.model("Parent", {}) + api.clone("Child", grand_parent, parent, {}) - assert 'Child' in api.models - assert 'Parent' in api.models - assert 'GrandParent' in api.models + assert "Child" in api.models + assert "Parent" in api.models + assert "GrandParent" in api.models def test_inherit(self): authorizations = { - 'apikey': { - 'type': 'apiKey', - 'in': 'header', - 'name': 'X-API-KEY' - } + "apikey": {"type": "apiKey", "in": "header", "name": "X-API-KEY"} } - api = Namespace('test', authorizations=authorizations) - parent = api.model('Parent', {}) - api.inherit('Child', parent, {}) + api = Namespace("test", authorizations=authorizations) + parent = api.model("Parent", {}) + api.inherit("Child", parent, {}) - assert 'Parent' in api.models - assert 'Child' in api.models + assert "Parent" in api.models + assert "Child" in api.models assert api.authorizations == authorizations def test_inherit_from_multiple_parents(self): - api = Namespace('test') - grand_parent = api.model('GrandParent', {}) - parent = api.model('Parent', {}) - api.inherit('Child', grand_parent, parent, {}) + api = Namespace("test") + grand_parent = api.model("GrandParent", {}) + parent = api.model("Parent", {}) + api.inherit("Child", grand_parent, parent, {}) - assert 'GrandParent' in api.models - assert 'Parent' in api.models - assert 'Child' in api.models + assert "GrandParent" in api.models + assert "Parent" in api.models + assert "Child" in api.models def test_api_payload(self, app, client): api = restx.Api(app, validate=True) - ns = restx.Namespace('apples') + ns = restx.Namespace("apples") api.add_namespace(ns) - fields = ns.model('Person', { - 'name': restx.fields.String(required=True), - 'age': restx.fields.Integer, - 'birthdate': restx.fields.DateTime, - }) + fields = ns.model( + "Person", + { + "name": restx.fields.String(required=True), + "age": restx.fields.Integer, + "birthdate": restx.fields.DateTime, + }, + ) - @ns.route('/validation/') + @ns.route("/validation/") class Payload(restx.Resource): payload = None @@ -126,10 +136,10 @@ def post(self): return {} data = { - 'name': 'John Doe', - 'age': 15, + "name": "John Doe", + "age": 15, } - client.post_json('/apples/validation/', data) + client.post_json("/apples/validation/", data) assert Payload.payload == data diff --git a/tests/test_payload.py b/tests/test_payload.py index 20789d54..bebae6d0 100644 --- a/tests/test_payload.py +++ b/tests/test_payload.py @@ -7,90 +7,102 @@ class PayloadTest(object): def assert_errors(self, client, url, data, *errors): out = client.post_json(url, data, status=400) - assert 'message' in out - assert 'errors' in out + assert "message" in out + assert "errors" in out for error in errors: - assert error in out['errors'] + assert error in out["errors"] def test_validation_false_on_constructor(self, app, client): api = restx.Api(app, validate=False) - fields = api.model('Person', { - 'name': restx.fields.String(required=True), - 'age': restx.fields.Integer, - 'birthdate': restx.fields.DateTime, - }) + fields = api.model( + "Person", + { + "name": restx.fields.String(required=True), + "age": restx.fields.Integer, + "birthdate": restx.fields.DateTime, + }, + ) - @api.route('/validation/') + @api.route("/validation/") class ValidationOff(restx.Resource): @api.expect(fields) def post(self): return {} - data = client.post_json('/validation/', {}) + data = client.post_json("/validation/", {}) assert data == {} def test_validation_false_on_constructor_with_override(self, app, client): api = restx.Api(app, validate=False) - fields = api.model('Person', { - 'name': restx.fields.String(required=True), - 'age': restx.fields.Integer, - 'birthdate': restx.fields.DateTime, - }) + fields = api.model( + "Person", + { + "name": restx.fields.String(required=True), + "age": restx.fields.Integer, + "birthdate": restx.fields.DateTime, + }, + ) - @api.route('/validation/') + @api.route("/validation/") class ValidationOn(restx.Resource): @api.expect(fields, validate=True) def post(self): return {} - self.assert_errors(client, '/validation/', {}, 'name') + self.assert_errors(client, "/validation/", {}, "name") def test_validation_true_on_constructor(self, app, client): api = restx.Api(app, validate=True) - fields = api.model('Person', { - 'name': restx.fields.String(required=True), - 'age': restx.fields.Integer, - 'birthdate': restx.fields.DateTime, - }) + fields = api.model( + "Person", + { + "name": restx.fields.String(required=True), + "age": restx.fields.Integer, + "birthdate": restx.fields.DateTime, + }, + ) - @api.route('/validation/') + @api.route("/validation/") class ValidationOff(restx.Resource): @api.expect(fields) def post(self): return {} - self.assert_errors(client, '/validation/', {}, 'name') + self.assert_errors(client, "/validation/", {}, "name") def test_validation_true_on_constructor_with_override(self, app, client): api = restx.Api(app, validate=True) - fields = api.model('Person', { - 'name': restx.fields.String(required=True), - 'age': restx.fields.Integer, - 'birthdate': restx.fields.DateTime, - }) + fields = api.model( + "Person", + { + "name": restx.fields.String(required=True), + "age": restx.fields.Integer, + "birthdate": restx.fields.DateTime, + }, + ) - @api.route('/validation/') + @api.route("/validation/") class ValidationOff(restx.Resource): @api.expect(fields, validate=False) def post(self): return {} - data = client.post_json('/validation/', {}) + data = client.post_json("/validation/", {}) assert data == {} def _setup_api_format_checker_tests(self, app, format_checker=None): class IPAddress(restx.fields.Raw): - __schema_type__ = 'string' - __schema_format__ = 'ipv4' + __schema_type__ = "string" + __schema_format__ = "ipv4" api = restx.Api(app, format_checker=format_checker) - model = api.model('MyModel', {'ip': IPAddress(required=True)}) + model = api.model("MyModel", {"ip": IPAddress(required=True)}) - @api.route('/format_checker/') + @api.route("/format_checker/") class TestResource(restx.Resource): @api.expect(model, validate=True) def post(self): @@ -99,66 +111,76 @@ def post(self): def test_format_checker_none_on_constructor(self, app, client): self._setup_api_format_checker_tests(app) - out = client.post_json('/format_checker/', {'ip': '192.168.1'}) + out = client.post_json("/format_checker/", {"ip": "192.168.1"}) assert out == {} def test_format_checker_object_on_constructor(self, app, client): from jsonschema import FormatChecker + self._setup_api_format_checker_tests(app, format_checker=FormatChecker()) - out = client.post_json('/format_checker/', {'ip': '192.168.1'}, status=400) - assert 'ipv4' in out['errors']['ip'] + out = client.post_json("/format_checker/", {"ip": "192.168.1"}, status=400) + assert "ipv4" in out["errors"]["ip"] def test_validation_false_in_config(self, app, client): - app.config['RESTX_VALIDATE'] = False + app.config["RESTX_VALIDATE"] = False api = restx.Api(app) - fields = api.model('Person', { - 'name': restx.fields.String(required=True), - 'age': restx.fields.Integer, - 'birthdate': restx.fields.DateTime, - }) + fields = api.model( + "Person", + { + "name": restx.fields.String(required=True), + "age": restx.fields.Integer, + "birthdate": restx.fields.DateTime, + }, + ) - @api.route('/validation/') + @api.route("/validation/") class ValidationOff(restx.Resource): @api.expect(fields) def post(self): return {} - out = client.post_json('/validation/', {}) + out = client.post_json("/validation/", {}) # assert response.status_code == 200 # out = json.loads(response.data.decode('utf8')) assert out == {} def test_validation_in_config(self, app, client): - app.config['RESTX_VALIDATE'] = True + app.config["RESTX_VALIDATE"] = True api = restx.Api(app) - fields = api.model('Person', { - 'name': restx.fields.String(required=True), - 'age': restx.fields.Integer, - 'birthdate': restx.fields.DateTime, - }) + fields = api.model( + "Person", + { + "name": restx.fields.String(required=True), + "age": restx.fields.Integer, + "birthdate": restx.fields.DateTime, + }, + ) - @api.route('/validation/') + @api.route("/validation/") class ValidationOn(restx.Resource): @api.expect(fields) def post(self): return {} - self.assert_errors(client, '/validation/', {}, 'name') + self.assert_errors(client, "/validation/", {}, "name") def test_api_payload(self, app, client): api = restx.Api(app, validate=True) - fields = api.model('Person', { - 'name': restx.fields.String(required=True), - 'age': restx.fields.Integer, - 'birthdate': restx.fields.DateTime, - }) + fields = api.model( + "Person", + { + "name": restx.fields.String(required=True), + "age": restx.fields.Integer, + "birthdate": restx.fields.DateTime, + }, + ) - @api.route('/validation/') + @api.route("/validation/") class Payload(restx.Resource): payload = None @@ -168,76 +190,74 @@ def post(self): return {} data = { - 'name': 'John Doe', - 'age': 15, + "name": "John Doe", + "age": 15, } - client.post_json('/validation/', data) + client.post_json("/validation/", data) assert Payload.payload == data def test_validation_with_inheritance(self, app, client): - '''It should perform validation with inheritance (allOf/$ref)''' + """It should perform validation with inheritance (allOf/$ref)""" api = restx.Api(app, validate=True) - fields = api.model('Parent', { - 'name': restx.fields.String(required=True), - }) + fields = api.model("Parent", {"name": restx.fields.String(required=True),}) - child_fields = api.inherit('Child', fields, { - 'age': restx.fields.Integer, - }) + child_fields = api.inherit("Child", fields, {"age": restx.fields.Integer,}) - @api.route('/validation/') + @api.route("/validation/") class Inheritance(restx.Resource): @api.expect(child_fields) def post(self): return {} - client.post_json('/validation/', { - 'name': 'John Doe', - 'age': 15, - }) + client.post_json("/validation/", {"name": "John Doe", "age": 15,}) - self.assert_errors(client, '/validation/', { - 'age': '15', - }, 'name', 'age') + self.assert_errors(client, "/validation/", {"age": "15",}, "name", "age") def test_validation_on_list(self, app, client): - '''It should perform validation on lists''' + """It should perform validation on lists""" api = restx.Api(app, validate=True) - person = api.model('Person', { - 'name': restx.fields.String(required=True), - 'age': restx.fields.Integer(required=True), - }) - - family = api.model('Family', { - 'name': restx.fields.String(required=True), - 'members': restx.fields.List(restx.fields.Nested(person)) - }) - - @api.route('/validation/') + person = api.model( + "Person", + { + "name": restx.fields.String(required=True), + "age": restx.fields.Integer(required=True), + }, + ) + + family = api.model( + "Family", + { + "name": restx.fields.String(required=True), + "members": restx.fields.List(restx.fields.Nested(person)), + }, + ) + + @api.route("/validation/") class List(restx.Resource): @api.expect(family) def post(self): return {} - self.assert_errors(client, '/validation/', { - 'name': 'Doe', - 'members': [{'name': 'Jonn'}, {'age': 42}] - }, 'members.0.age', 'members.1.name') + self.assert_errors( + client, + "/validation/", + {"name": "Doe", "members": [{"name": "Jonn"}, {"age": 42}]}, + "members.0.age", + "members.1.name", + ) def _setup_expect_validation_single_resource_tests(self, app): # Setup a minimal Api with endpoint that expects in input payload # a single object of a resource api = restx.Api(app, validate=True) - user = api.model('User', { - 'username': restx.fields.String() - }) + user = api.model("User", {"username": restx.fields.String()}) - @api.route('/validation/') + @api.route("/validation/") class Users(restx.Resource): @api.expect(user) def post(self): @@ -248,11 +268,9 @@ def _setup_expect_validation_collection_resource_tests(self, app): # one or more objects of a resource api = restx.Api(app, validate=True) - user = api.model('User', { - 'username': restx.fields.String() - }) + user = api.model("User", {"username": restx.fields.String()}) - @api.route('/validation/') + @api.route("/validation/") class Users(restx.Resource): @api.expect([user]) def post(self): @@ -262,82 +280,77 @@ def test_expect_validation_single_resource_success(self, app, client): self._setup_expect_validation_single_resource_tests(app) # Input payload is a valid JSON object - out = client.post_json('/validation/', { - 'username': 'alice' - }) + out = client.post_json("/validation/", {"username": "alice"}) assert {} == out def test_expect_validation_single_resource_error(self, app, client): self._setup_expect_validation_single_resource_tests(app) # Input payload is an invalid JSON object - self.assert_errors(client, '/validation/', { - 'username': 123 - }, 'username') + self.assert_errors(client, "/validation/", {"username": 123}, "username") # Input payload is a JSON array (expected JSON object) - self.assert_errors(client, '/validation/', [{ - 'username': 123 - }], '') + self.assert_errors(client, "/validation/", [{"username": 123}], "") def test_expect_validation_collection_resource_success(self, app, client): self._setup_expect_validation_collection_resource_tests(app) # Input payload is a valid JSON object - out = client.post_json('/validation/', { - 'username': 'alice' - }) + out = client.post_json("/validation/", {"username": "alice"}) assert {} == out # Input payload is a JSON array with valid JSON objects - out = client.post_json('/validation/', [ - {'username': 'alice'}, - {'username': 'bob'} - ]) + out = client.post_json( + "/validation/", [{"username": "alice"}, {"username": "bob"}] + ) assert {} == out def test_expect_validation_collection_resource_error(self, app, client): self._setup_expect_validation_collection_resource_tests(app) # Input payload is an invalid JSON object - self.assert_errors(client, '/validation/', { - 'username': 123 - }, 'username') + self.assert_errors(client, "/validation/", {"username": 123}, "username") # Input payload is a JSON array but with an invalid JSON object - self.assert_errors(client, '/validation/', [ - {'username': 'alice'}, - {'username': 123} - ], 'username') + self.assert_errors( + client, + "/validation/", + [{"username": "alice"}, {"username": 123}], + "username", + ) def test_validation_with_propagate(self, app, client): - app.config['PROPAGATE_EXCEPTIONS'] = True + app.config["PROPAGATE_EXCEPTIONS"] = True api = restx.Api(app, validate=True) - fields = api.model('Person', { - 'name': restx.fields.String(required=True), - 'age': restx.fields.Integer, - 'birthdate': restx.fields.DateTime, - }) + fields = api.model( + "Person", + { + "name": restx.fields.String(required=True), + "age": restx.fields.Integer, + "birthdate": restx.fields.DateTime, + }, + ) - @api.route('/validation/') + @api.route("/validation/") class ValidationOff(restx.Resource): @api.expect(fields) def post(self): return {} - self.assert_errors(client, '/validation/', {}, 'name') + self.assert_errors(client, "/validation/", {}, "name") def test_empty_payload(self, app, client): api = restx.Api(app, validate=True) - @api.route('/empty/') + @api.route("/empty/") class Payload(restx.Resource): def post(self): return {} - response = client.post('/empty/', data='', - headers={'content-type': 'application/json'}) + response = client.post( + "/empty/", data="", headers={"content-type": "application/json"} + ) assert response.status_code == 200 diff --git a/tests/test_postman.py b/tests/test_postman.py index b9e98625..d854e775 100644 --- a/tests/test_postman.py +++ b/tests/test_postman.py @@ -13,7 +13,7 @@ from six.moves.urllib.parse import parse_qs, urlparse -with open(join(dirname(__file__), 'postman-v1.schema.json')) as f: +with open(join(dirname(__file__), "postman-v1.schema.json")) as f: schema = json.load(f) @@ -25,73 +25,72 @@ def test_basic_export(self, app): validate(data, schema) - assert len(data['requests']) == 0 + assert len(data["requests"]) == 0 def test_export_infos(self, app): - api = restx.Api(app, version='1.0', - title='My API', - description='This is a testing API', + api = restx.Api( + app, version="1.0", title="My API", description="This is a testing API", ) data = api.as_postman() validate(data, schema) - assert data['name'] == 'My API 1.0' - assert data['description'] == 'This is a testing API' + assert data["name"] == "My API 1.0" + assert data["description"] == "This is a testing API" def test_export_with_one_entry(self, app): api = restx.Api(app) - @api.route('/test') + @api.route("/test") class Test(restx.Resource): - @api.doc('test_post') + @api.doc("test_post") def post(self): - '''A test post''' + """A test post""" pass data = api.as_postman() validate(data, schema) - assert len(data['requests']) == 1 - request = data['requests'][0] - assert request['name'] == 'test_post' - assert request['description'] == 'A test post' + assert len(data["requests"]) == 1 + request = data["requests"][0] + assert request["name"] == "test_post" + assert request["description"] == "A test post" - assert len(data['folders']) == 1 - folder = data['folders'][0] - assert folder['name'] == 'default' - assert folder['description'] == 'Default namespace' + assert len(data["folders"]) == 1 + folder = data["folders"][0] + assert folder["name"] == "default" + assert folder["description"] == "Default namespace" - assert request['folder'] == folder['id'] + assert request["folder"] == folder["id"] def test_export_with_namespace(self, app): api = restx.Api(app) - ns = api.namespace('test', 'A test namespace') + ns = api.namespace("test", "A test namespace") - @ns.route('/test') + @ns.route("/test") class Test(restx.Resource): - @api.doc('test_post') + @api.doc("test_post") def post(self): - '''A test post''' + """A test post""" pass data = api.as_postman() validate(data, schema) - assert len(data['requests']) == 1 - request = data['requests'][0] - assert request['name'] == 'test_post' - assert request['description'] == 'A test post' + assert len(data["requests"]) == 1 + request = data["requests"][0] + assert request["name"] == "test_post" + assert request["description"] == "A test post" - assert len(data['folders']) == 1 - folder = data['folders'][0] - assert folder['name'] == 'test' - assert folder['description'] == 'A test namespace' + assert len(data["folders"]) == 1 + folder = data["folders"][0] + assert folder["name"] == "test" + assert folder["description"] == "A test namespace" - assert request['folder'] == folder['id'] + assert request["folder"] == folder["id"] def test_id_is_the_same(self, app): api = restx.Api(app) @@ -100,28 +99,28 @@ def test_id_is_the_same(self, app): second = api.as_postman() - assert first['id'] == second['id'] + assert first["id"] == second["id"] def test_resources_order_in_folder(self, app): - '''It should preserve resources order''' + """It should preserve resources order""" api = restx.Api(app) - ns = api.namespace('test', 'A test namespace') + ns = api.namespace("test", "A test namespace") - @ns.route('/test1') + @ns.route("/test1") class Test1(restx.Resource): - @api.doc('test_post_z') + @api.doc("test_post_z") def post(self): pass - @ns.route('/test2') + @ns.route("/test2") class Test2(restx.Resource): - @api.doc('test_post_y') + @api.doc("test_post_y") def post(self): pass - @ns.route('/test3') + @ns.route("/test3") class Test3(restx.Resource): - @api.doc('test_post_x') + @api.doc("test_post_x") def post(self): pass @@ -129,25 +128,25 @@ def post(self): validate(data, schema) - assert len(data['requests']) == 3 + assert len(data["requests"]) == 3 - assert len(data['folders']) == 1 - folder = data['folders'][0] - assert folder['name'] == 'test' + assert len(data["folders"]) == 1 + folder = data["folders"][0] + assert folder["name"] == "test" - expected_order = ('test_post_z', 'test_post_y', 'test_post_x') - assert len(folder['order']) == len(expected_order) + expected_order = ("test_post_z", "test_post_y", "test_post_x") + assert len(folder["order"]) == len(expected_order) - for request_id, expected in zip(folder['order'], expected_order): - request = list(filter(lambda r: r['id'] == request_id, data['requests']))[0] - assert request['name'] == expected + for request_id, expected in zip(folder["order"], expected_order): + request = list(filter(lambda r: r["id"] == request_id, data["requests"]))[0] + assert request["name"] == expected def test_prefix_with_trailing_slash(self, app): - api = restx.Api(app, prefix='/prefix/') + api = restx.Api(app, prefix="/prefix/") - @api.route('/test/') + @api.route("/test/") class Test(restx.Resource): - @api.doc('test_post') + @api.doc("test_post") def post(self): pass @@ -155,16 +154,16 @@ def post(self): validate(data, schema) - assert len(data['requests']) == 1 - request = data['requests'][0] - assert request['url'] == 'http://localhost/prefix/test/' + assert len(data["requests"]) == 1 + request = data["requests"][0] + assert request["url"] == "http://localhost/prefix/test/" def test_prefix_without_trailing_slash(self, app): - api = restx.Api(app, prefix='/prefix') + api = restx.Api(app, prefix="/prefix") - @api.route('/test/') + @api.route("/test/") class Test(restx.Resource): - @api.doc('test_post') + @api.doc("test_post") def post(self): pass @@ -172,16 +171,16 @@ def post(self): validate(data, schema) - assert len(data['requests']) == 1 - request = data['requests'][0] - assert request['url'] == 'http://localhost/prefix/test/' + assert len(data["requests"]) == 1 + request = data["requests"][0] + assert request["url"] == "http://localhost/prefix/test/" def test_path_variables(self, app): api = restx.Api(app) - @api.route('/test////') + @api.route("/test////") class Test(restx.Resource): - @api.doc('test_post') + @api.doc("test_post") def post(self): pass @@ -189,24 +188,24 @@ def post(self): validate(data, schema) - assert len(data['requests']) == 1 - request = data['requests'][0] - assert request['url'] == 'http://localhost/test/:id/:integer/:number/' - assert request['pathVariables'] == { - 'id': '', - 'integer': 0, - 'number': 0, + assert len(data["requests"]) == 1 + request = data["requests"][0] + assert request["url"] == "http://localhost/test/:id/:integer/:number/" + assert request["pathVariables"] == { + "id": "", + "integer": 0, + "number": 0, } def test_url_variables_disabled(self, app): api = restx.Api(app) parser = api.parser() - parser.add_argument('int', type=int) - parser.add_argument('default', type=int, default=5) - parser.add_argument('str', type=str) + parser.add_argument("int", type=int) + parser.add_argument("default", type=int, default=5) + parser.add_argument("str", type=str) - @api.route('/test/') + @api.route("/test/") class Test(restx.Resource): @api.expect(parser) def get(self): @@ -216,19 +215,19 @@ def get(self): validate(data, schema) - assert len(data['requests']) == 1 - request = data['requests'][0] - assert request['url'] == 'http://localhost/test/' + assert len(data["requests"]) == 1 + request = data["requests"][0] + assert request["url"] == "http://localhost/test/" def test_url_variables_enabled(self, app): api = restx.Api(app) parser = api.parser() - parser.add_argument('int', type=int) - parser.add_argument('default', type=int, default=5) - parser.add_argument('str', type=str) + parser.add_argument("int", type=int) + parser.add_argument("default", type=int, default=5) + parser.add_argument("str", type=str) - @api.route('/test/') + @api.route("/test/") class Test(restx.Resource): @api.expect(parser) def get(self): @@ -238,29 +237,29 @@ def get(self): validate(data, schema) - assert len(data['requests']) == 1 - request = data['requests'][0] - qs = parse_qs(urlparse(request['url']).query, keep_blank_values=True) + assert len(data["requests"]) == 1 + request = data["requests"][0] + qs = parse_qs(urlparse(request["url"]).query, keep_blank_values=True) - assert 'int' in qs - assert qs['int'][0] == '0' + assert "int" in qs + assert qs["int"][0] == "0" - assert 'default' in qs - assert qs['default'][0] == '5' + assert "default" in qs + assert qs["default"][0] == "5" - assert 'str' in qs - assert qs['str'][0] == '' + assert "str" in qs + assert qs["str"][0] == "" def test_headers(self, app): api = restx.Api(app) parser = api.parser() - parser.add_argument('X-Header-1', location='headers', default='xxx') - parser.add_argument('X-Header-2', location='headers', required=True) + parser.add_argument("X-Header-1", location="headers", default="xxx") + parser.add_argument("X-Header-2", location="headers", required=True) - @api.route('/headers/') + @api.route("/headers/") class TestHeaders(restx.Resource): - @api.doc('headers') + @api.doc("headers") @api.expect(parser) def get(self): pass @@ -268,97 +267,96 @@ def get(self): data = api.as_postman(urlvars=True) validate(data, schema) - request = data['requests'][0] - headers = dict(r.split(':') for r in request['headers'].splitlines()) + request = data["requests"][0] + headers = dict(r.split(":") for r in request["headers"].splitlines()) - assert headers['X-Header-1'] == 'xxx' - assert headers['X-Header-2'] == '' + assert headers["X-Header-1"] == "xxx" + assert headers["X-Header-2"] == "" def test_content_type_header(self, app): api = restx.Api(app) form_parser = api.parser() - form_parser.add_argument('param', type=int, help='Some param', location='form') + form_parser.add_argument("param", type=int, help="Some param", location="form") file_parser = api.parser() - file_parser.add_argument('in_files', type=FileStorage, location='files') + file_parser.add_argument("in_files", type=FileStorage, location="files") - @api.route('/json/') + @api.route("/json/") class TestJson(restx.Resource): - @api.doc('json') + @api.doc("json") def post(self): pass - @api.route('/form/') + @api.route("/form/") class TestForm(restx.Resource): - @api.doc('form') + @api.doc("form") @api.expect(form_parser) def post(self): pass - @api.route('/file/') + @api.route("/file/") class TestFile(restx.Resource): - @api.doc('file') + @api.doc("file") @api.expect(file_parser) def post(self): pass - @api.route('/get/') + @api.route("/get/") class TestGet(restx.Resource): - @api.doc('get') + @api.doc("get") def get(self): pass data = api.as_postman(urlvars=True) validate(data, schema) - requests = dict((r['name'], r['headers']) for r in data['requests']) + requests = dict((r["name"], r["headers"]) for r in data["requests"]) - assert requests['json'] == 'Content-Type:application/json' - assert requests['form'] == 'Content-Type:multipart/form-data' - assert requests['file'] == 'Content-Type:multipart/form-data' + assert requests["json"] == "Content-Type:application/json" + assert requests["form"] == "Content-Type:multipart/form-data" + assert requests["file"] == "Content-Type:multipart/form-data" # No content-type on get - assert requests['get'] == '' + assert requests["get"] == "" def test_method_security_headers(self, app): - api = restx.Api(app, authorizations={ - 'apikey': { - 'type': 'apiKey', - 'in': 'header', - 'name': 'X-API' - } - }) - - @api.route('/secure/') + api = restx.Api( + app, + authorizations={ + "apikey": {"type": "apiKey", "in": "header", "name": "X-API"} + }, + ) + + @api.route("/secure/") class Secure(restx.Resource): - @api.doc('secure', security='apikey') + @api.doc("secure", security="apikey") def get(self): pass - @api.route('/unsecure/') + @api.route("/unsecure/") class Unsecure(restx.Resource): - @api.doc('unsecure') + @api.doc("unsecure") def get(self): pass data = api.as_postman() validate(data, schema) - requests = dict((r['name'], r['headers']) for r in data['requests']) + requests = dict((r["name"], r["headers"]) for r in data["requests"]) - assert requests['unsecure'] == '' - assert requests['secure'] == 'X-API:' + assert requests["unsecure"] == "" + assert requests["secure"] == "X-API:" def test_global_security_headers(self, app): - api = restx.Api(app, security='apikey', authorizations={ - 'apikey': { - 'type': 'apiKey', - 'in': 'header', - 'name': 'X-API' - } - }) - - @api.route('/test/') + api = restx.Api( + app, + security="apikey", + authorizations={ + "apikey": {"type": "apiKey", "in": "header", "name": "X-API"} + }, + ) + + @api.route("/test/") class Test(restx.Resource): def get(self): pass @@ -366,25 +364,26 @@ def get(self): data = api.as_postman() validate(data, schema) - request = data['requests'][0] - headers = dict(r.split(':') for r in request['headers'].splitlines()) + request = data["requests"][0] + headers = dict(r.split(":") for r in request["headers"].splitlines()) - assert headers['X-API'] == '' + assert headers["X-API"] == "" def test_oauth_security_headers(self, app): - api = restx.Api(app, security='oauth', authorizations={ - 'oauth': { - 'type': 'oauth2', - 'authorizationUrl': 'https://somewhere.com/oauth/authorize', - 'flow': 'implicit', - 'scopes': { - 'read': 'Can read', - 'write': 'Can write' + api = restx.Api( + app, + security="oauth", + authorizations={ + "oauth": { + "type": "oauth2", + "authorizationUrl": "https://somewhere.com/oauth/authorize", + "flow": "implicit", + "scopes": {"read": "Can read", "write": "Can write"}, } - } - }) + }, + ) - @api.route('/test/') + @api.route("/test/") class Test(restx.Resource): def get(self): pass @@ -404,8 +403,8 @@ def test_export_with_swagger(self, app): validate(data, schema) - assert len(data['requests']) == 1 - request = data['requests'][0] - assert request['name'] == 'Swagger specifications' - assert request['description'] == 'The API Swagger specifications as JSON' - assert request['url'] == 'http://localhost/swagger.json' + assert len(data["requests"]) == 1 + request = data["requests"][0] + assert request["name"] == "Swagger specifications" + assert request["description"] == "The API Swagger specifications as JSON" + assert request["url"] == "http://localhost/swagger.json" diff --git a/tests/test_reqparse.py b/tests/test_reqparse.py index 31849bfe..c50a5f1d 100644 --- a/tests/test_reqparse.py +++ b/tests/test_reqparse.py @@ -22,420 +22,435 @@ def test_api_shortcut(self, app): assert isinstance(parser, RequestParser) def test_parse_model(self, app): - model = Model('Todo', { - 'task': fields.String(required=True) - }) + model = Model("Todo", {"task": fields.String(required=True)}) parser = RequestParser() - parser.add_argument('todo', type=model, required=True) + parser.add_argument("todo", type=model, required=True) - data = {'todo': {'task': 'aaa'}} + data = {"todo": {"task": "aaa"}} - with app.test_request_context('/', method='post', - data=json.dumps(data), - content_type='application/json'): + with app.test_request_context( + "/", method="post", data=json.dumps(data), content_type="application/json" + ): args = parser.parse_args() - assert args['todo'] == {'task': 'aaa'} + assert args["todo"] == {"task": "aaa"} def test_help(self, app, mocker): - abort = mocker.patch('flask_restx.reqparse.abort', - side_effect=BadRequest('Bad Request')) + abort = mocker.patch( + "flask_restx.reqparse.abort", side_effect=BadRequest("Bad Request") + ) parser = RequestParser() - parser.add_argument('foo', choices=('one', 'two'), help='Bad choice.') - req = mocker.Mock(['values']) - req.values = MultiDict([('foo', 'three')]) + parser.add_argument("foo", choices=("one", "two"), help="Bad choice.") + req = mocker.Mock(["values"]) + req.values = MultiDict([("foo", "three")]) with pytest.raises(BadRequest): parser.parse_args(req) - expected = {'foo': 'Bad choice. The value \'three\' is not a valid choice for \'foo\'.'} - abort.assert_called_with(400, 'Input payload validation failed', errors=expected) + expected = { + "foo": "Bad choice. The value 'three' is not a valid choice for 'foo'." + } + abort.assert_called_with( + 400, "Input payload validation failed", errors=expected + ) def test_no_help(self, app, mocker): - abort = mocker.patch('flask_restx.reqparse.abort', - side_effect=BadRequest('Bad Request')) + abort = mocker.patch( + "flask_restx.reqparse.abort", side_effect=BadRequest("Bad Request") + ) parser = RequestParser() - parser.add_argument('foo', choices=['one', 'two']) - req = mocker.Mock(['values']) - req.values = MultiDict([('foo', 'three')]) + parser.add_argument("foo", choices=["one", "two"]) + req = mocker.Mock(["values"]) + req.values = MultiDict([("foo", "three")]) with pytest.raises(BadRequest): parser.parse_args(req) - expected = {'foo': 'The value \'three\' is not a valid choice for \'foo\'.'} - abort.assert_called_with(400, 'Input payload validation failed', errors=expected) + expected = {"foo": "The value 'three' is not a valid choice for 'foo'."} + abort.assert_called_with( + 400, "Input payload validation failed", errors=expected + ) @pytest.mark.request_context() def test_viewargs(self, mocker): req = Request.from_values() - req.view_args = {'foo': 'bar'} + req.view_args = {"foo": "bar"} parser = RequestParser() - parser.add_argument('foo', location=['view_args']) + parser.add_argument("foo", location=["view_args"]) args = parser.parse_args(req) - assert args['foo'] == 'bar' + assert args["foo"] == "bar" req = mocker.Mock() req.values = () req.json = None - req.view_args = {'foo': 'bar'} + req.view_args = {"foo": "bar"} parser = RequestParser() - parser.add_argument('foo', store_missing=True) + parser.add_argument("foo", store_missing=True) args = parser.parse_args(req) - assert args['foo'] is None + assert args["foo"] is None def test_parse_unicode(self, app): - req = Request.from_values('/bubble?foo=barß') + req = Request.from_values("/bubble?foo=barß") parser = RequestParser() - parser.add_argument('foo') + parser.add_argument("foo") args = parser.parse_args(req) - assert args['foo'] == 'barß' + assert args["foo"] == "barß" def test_parse_unicode_app(self, app): parser = RequestParser() - parser.add_argument('foo') + parser.add_argument("foo") - with app.test_request_context('/bubble?foo=barß'): + with app.test_request_context("/bubble?foo=barß"): args = parser.parse_args() - assert args['foo'] == 'barß' + assert args["foo"] == "barß" - @pytest.mark.request_context('/bubble', method='post') + @pytest.mark.request_context("/bubble", method="post") def test_json_location(self): parser = RequestParser() - parser.add_argument('foo', location='json', store_missing=True) + parser.add_argument("foo", location="json", store_missing=True) args = parser.parse_args() - assert args['foo'] is None - - @pytest.mark.request_context('/bubble', method='post', - data=json.dumps({'foo': 'bar'}), - content_type='application/json') + assert args["foo"] is None + + @pytest.mark.request_context( + "/bubble", + method="post", + data=json.dumps({"foo": "bar"}), + content_type="application/json", + ) def test_get_json_location(self): parser = RequestParser() - parser.add_argument('foo', location='json') + parser.add_argument("foo", location="json") args = parser.parse_args() - assert args['foo'] == 'bar' + assert args["foo"] == "bar" - @pytest.mark.request_context('/bubble?foo=bar') + @pytest.mark.request_context("/bubble?foo=bar") def test_parse_append_ignore(self, app): parser = RequestParser() - parser.add_argument('foo', ignore=True, type=int, action='append', - store_missing=True), + parser.add_argument( + "foo", ignore=True, type=int, action="append", store_missing=True + ), args = parser.parse_args() - assert args['foo'] is None + assert args["foo"] is None - @pytest.mark.request_context('/bubble?') + @pytest.mark.request_context("/bubble?") def test_parse_append_default(self): parser = RequestParser() - parser.add_argument('foo', action='append', store_missing=True), + parser.add_argument("foo", action="append", store_missing=True), args = parser.parse_args() - assert args['foo'] is None + assert args["foo"] is None - @pytest.mark.request_context('/bubble?foo=bar&foo=bat') + @pytest.mark.request_context("/bubble?foo=bar&foo=bat") def test_parse_append(self): parser = RequestParser() - parser.add_argument('foo', action='append'), + parser.add_argument("foo", action="append"), args = parser.parse_args() - assert args['foo'] == ['bar', 'bat'] + assert args["foo"] == ["bar", "bat"] - @pytest.mark.request_context('/bubble?foo=bar') + @pytest.mark.request_context("/bubble?foo=bar") def test_parse_append_single(self): parser = RequestParser() - parser.add_argument('foo', action='append'), + parser.add_argument("foo", action="append"), args = parser.parse_args() - assert args['foo'] == ['bar'] + assert args["foo"] == ["bar"] - @pytest.mark.request_context('/bubble?foo=bar') + @pytest.mark.request_context("/bubble?foo=bar") def test_split_single(self): parser = RequestParser() - parser.add_argument('foo', action='split'), + parser.add_argument("foo", action="split"), args = parser.parse_args() - assert args['foo'] == ['bar'] + assert args["foo"] == ["bar"] - @pytest.mark.request_context('/bubble?foo=bar,bat') + @pytest.mark.request_context("/bubble?foo=bar,bat") def test_split_multiple(self): parser = RequestParser() - parser.add_argument('foo', action='split'), + parser.add_argument("foo", action="split"), args = parser.parse_args() - assert args['foo'] == ['bar', 'bat'] + assert args["foo"] == ["bar", "bat"] - @pytest.mark.request_context('/bubble?foo=1,2,3') + @pytest.mark.request_context("/bubble?foo=1,2,3") def test_split_multiple_cast(self): parser = RequestParser() - parser.add_argument('foo', type=int, action='split') + parser.add_argument("foo", type=int, action="split") args = parser.parse_args() - assert args['foo'] == [1, 2, 3] + assert args["foo"] == [1, 2, 3] - @pytest.mark.request_context('/bubble?foo=bar') + @pytest.mark.request_context("/bubble?foo=bar") def test_parse_dest(self): parser = RequestParser() - parser.add_argument('foo', dest='bat') + parser.add_argument("foo", dest="bat") args = parser.parse_args() - assert args['bat'] == 'bar' + assert args["bat"] == "bar" - @pytest.mark.request_context('/bubble?foo>=bar&foo<=bat&foo=foo') + @pytest.mark.request_context("/bubble?foo>=bar&foo<=bat&foo=foo") def test_parse_gte_lte_eq(self): parser = RequestParser() - parser.add_argument('foo', operators=['>=', '<=', '='], action='append'), + parser.add_argument("foo", operators=[">=", "<=", "="], action="append"), args = parser.parse_args() - assert args['foo'] == ['bar', 'bat', 'foo'] + assert args["foo"] == ["bar", "bat", "foo"] - @pytest.mark.request_context('/bubble?foo>=bar') + @pytest.mark.request_context("/bubble?foo>=bar") def test_parse_gte(self): parser = RequestParser() - parser.add_argument('foo', operators=['>=']) + parser.add_argument("foo", operators=[">="]) args = parser.parse_args() - assert args['foo'] == 'bar' + assert args["foo"] == "bar" - @pytest.mark.request_context('/bubble?foo=bar') + @pytest.mark.request_context("/bubble?foo=bar") def test_parse_foo_operators_four_hunderd(self): parser = RequestParser() - parser.add_argument('foo', type=int), + parser.add_argument("foo", type=int), with pytest.raises(BadRequest): parser.parse_args() - @pytest.mark.request_context('/bubble') + @pytest.mark.request_context("/bubble") def test_parse_foo_operators_ignore(self): parser = RequestParser() - parser.add_argument('foo', ignore=True, store_missing=True) + parser.add_argument("foo", ignore=True, store_missing=True) args = parser.parse_args() - assert args['foo'] is None + assert args["foo"] is None - @pytest.mark.request_context('/bubble?foo<=bar') + @pytest.mark.request_context("/bubble?foo<=bar") def test_parse_lte_gte_mock(self, mocker): mock_type = mocker.Mock() parser = RequestParser() - parser.add_argument('foo', type=mock_type, operators=['<=']) + parser.add_argument("foo", type=mock_type, operators=["<="]) parser.parse_args() - mock_type.assert_called_with('bar', 'foo', '<=') + mock_type.assert_called_with("bar", "foo", "<=") - @pytest.mark.request_context('/bubble?foo<=bar') + @pytest.mark.request_context("/bubble?foo<=bar") def test_parse_lte_gte_append(self): parser = RequestParser() - parser.add_argument('foo', operators=['<=', '='], action='append') + parser.add_argument("foo", operators=["<=", "="], action="append") args = parser.parse_args() - assert args['foo'] == ['bar'] + assert args["foo"] == ["bar"] - @pytest.mark.request_context('/bubble?foo<=bar') + @pytest.mark.request_context("/bubble?foo<=bar") def test_parse_lte_gte_missing(self): parser = RequestParser() - parser.add_argument('foo', operators=['<=', '=']) + parser.add_argument("foo", operators=["<=", "="]) args = parser.parse_args() - assert args['foo'] == 'bar' + assert args["foo"] == "bar" - @pytest.mark.request_context('/bubble?foo=bar&foo=bat') + @pytest.mark.request_context("/bubble?foo=bar&foo=bat") def test_parse_eq_other(self): parser = RequestParser() - parser.add_argument('foo'), + parser.add_argument("foo"), args = parser.parse_args() - assert args['foo'] == 'bar' + assert args["foo"] == "bar" - @pytest.mark.request_context('/bubble?foo=bar') + @pytest.mark.request_context("/bubble?foo=bar") def test_parse_eq(self): parser = RequestParser() - parser.add_argument('foo'), + parser.add_argument("foo"), args = parser.parse_args() - assert args['foo'] == 'bar' + assert args["foo"] == "bar" - @pytest.mark.request_context('/bubble?foo<=bar') + @pytest.mark.request_context("/bubble?foo<=bar") def test_parse_lte(self): parser = RequestParser() - parser.add_argument('foo', operators=['<=']) + parser.add_argument("foo", operators=["<="]) args = parser.parse_args() - assert args['foo'] == 'bar' + assert args["foo"] == "bar" - @pytest.mark.request_context('/bubble') + @pytest.mark.request_context("/bubble") def test_parse_required(self, app): parser = RequestParser() - parser.add_argument('foo', required=True, location='values') + parser.add_argument("foo", required=True, location="values") expected = { - 'foo': 'Missing required parameter in the post body or the query string' + "foo": "Missing required parameter in the post body or the query string" } with pytest.raises(BadRequest) as cm: parser.parse_args() - assert cm.value.data['message'] == 'Input payload validation failed' - assert cm.value.data['errors'] == expected + assert cm.value.data["message"] == "Input payload validation failed" + assert cm.value.data["errors"] == expected parser = RequestParser() - parser.add_argument('bar', required=True, location=['values', 'cookies']) + parser.add_argument("bar", required=True, location=["values", "cookies"]) expected = { - 'bar': ("Missing required parameter in the post body or the query " - "string or the request's cookies") + "bar": ( + "Missing required parameter in the post body or the query " + "string or the request's cookies" + ) } with pytest.raises(BadRequest) as cm: parser.parse_args() - assert cm.value.data['message'] == 'Input payload validation failed' - assert cm.value.data['errors'] == expected + assert cm.value.data["message"] == "Input payload validation failed" + assert cm.value.data["errors"] == expected - @pytest.mark.request_context('/bubble') + @pytest.mark.request_context("/bubble") @pytest.mark.options(bundle_errors=True) def test_parse_error_bundling(self, app): parser = RequestParser() - parser.add_argument('foo', required=True, location='values') - parser.add_argument('bar', required=True, location=['values', 'cookies']) + parser.add_argument("foo", required=True, location="values") + parser.add_argument("bar", required=True, location=["values", "cookies"]) with pytest.raises(BadRequest) as cm: parser.parse_args() - assert cm.value.data['message'] == 'Input payload validation failed' - assert cm.value.data['errors'] == { - 'foo': 'Missing required parameter in the post body or the query string', - 'bar': ("Missing required parameter in the post body or the query string " - "or the request's cookies") + assert cm.value.data["message"] == "Input payload validation failed" + assert cm.value.data["errors"] == { + "foo": "Missing required parameter in the post body or the query string", + "bar": ( + "Missing required parameter in the post body or the query string " + "or the request's cookies" + ), } - @pytest.mark.request_context('/bubble') + @pytest.mark.request_context("/bubble") @pytest.mark.options(bundle_errors=False) def test_parse_error_bundling_w_parser_arg(self, app): parser = RequestParser(bundle_errors=True) - parser.add_argument('foo', required=True, location='values') - parser.add_argument('bar', required=True, location=['values', 'cookies']) + parser.add_argument("foo", required=True, location="values") + parser.add_argument("bar", required=True, location=["values", "cookies"]) with pytest.raises(BadRequest) as cm: parser.parse_args() - assert cm.value.data['message'] == 'Input payload validation failed' - assert cm.value.data['errors'] == { - 'foo': 'Missing required parameter in the post body or the query string', - 'bar': ("Missing required parameter in the post body or the query string " - "or the request's cookies") + assert cm.value.data["message"] == "Input payload validation failed" + assert cm.value.data["errors"] == { + "foo": "Missing required parameter in the post body or the query string", + "bar": ( + "Missing required parameter in the post body or the query string " + "or the request's cookies" + ), } - @pytest.mark.request_context('/bubble') + @pytest.mark.request_context("/bubble") def test_parse_default_append(self): parser = RequestParser() - parser.add_argument('foo', default='bar', action='append', - store_missing=True) + parser.add_argument("foo", default="bar", action="append", store_missing=True) args = parser.parse_args() - assert args['foo'] == 'bar' + assert args["foo"] == "bar" - @pytest.mark.request_context('/bubble') + @pytest.mark.request_context("/bubble") def test_parse_default(self): parser = RequestParser() - parser.add_argument('foo', default='bar', store_missing=True) + parser.add_argument("foo", default="bar", store_missing=True) args = parser.parse_args() - assert args['foo'] == 'bar' + assert args["foo"] == "bar" - @pytest.mark.request_context('/bubble') + @pytest.mark.request_context("/bubble") def test_parse_callable_default(self): parser = RequestParser() - parser.add_argument('foo', default=lambda: 'bar', store_missing=True) + parser.add_argument("foo", default=lambda: "bar", store_missing=True) args = parser.parse_args() - assert args['foo'] == 'bar' + assert args["foo"] == "bar" - @pytest.mark.request_context('/bubble?foo=bar') + @pytest.mark.request_context("/bubble?foo=bar") def test_parse(self): parser = RequestParser() - parser.add_argument('foo'), + parser.add_argument("foo"), args = parser.parse_args() - assert args['foo'] == 'bar' + assert args["foo"] == "bar" - @pytest.mark.request_context('/bubble') + @pytest.mark.request_context("/bubble") def test_parse_none(self): parser = RequestParser() - parser.add_argument('foo') + parser.add_argument("foo") args = parser.parse_args() - assert args['foo'] is None + assert args["foo"] is None def test_parse_store_missing(self, app): - req = Request.from_values('/bubble') + req = Request.from_values("/bubble") parser = RequestParser() - parser.add_argument('foo', store_missing=False) + parser.add_argument("foo", store_missing=False) args = parser.parse_args(req) - assert 'foo' not in args + assert "foo" not in args def test_parse_choices_correct(self, app): - req = Request.from_values('/bubble?foo=bat') + req = Request.from_values("/bubble?foo=bat") parser = RequestParser() - parser.add_argument('foo', choices=['bat']), + parser.add_argument("foo", choices=["bat"]), args = parser.parse_args(req) - assert args['foo'] == 'bat' + assert args["foo"] == "bat" def test_parse_choices(self, app): - req = Request.from_values('/bubble?foo=bar') + req = Request.from_values("/bubble?foo=bar") parser = RequestParser() - parser.add_argument('foo', choices=['bat']), + parser.add_argument("foo", choices=["bat"]), with pytest.raises(BadRequest): parser.parse_args(req) def test_parse_choices_sensitive(self, app): - req = Request.from_values('/bubble?foo=BAT') + req = Request.from_values("/bubble?foo=BAT") parser = RequestParser() - parser.add_argument('foo', choices=['bat'], case_sensitive=True), + parser.add_argument("foo", choices=["bat"], case_sensitive=True), with pytest.raises(BadRequest): parser.parse_args(req) def test_parse_choices_insensitive(self, app): - req = Request.from_values('/bubble?foo=BAT') + req = Request.from_values("/bubble?foo=BAT") parser = RequestParser() - parser.add_argument('foo', choices=['bat'], case_sensitive=False), + parser.add_argument("foo", choices=["bat"], case_sensitive=False), args = parser.parse_args(req) - assert 'bat' == args.get('foo') + assert "bat" == args.get("foo") # both choices and args are case_insensitive - req = Request.from_values('/bubble?foo=bat') + req = Request.from_values("/bubble?foo=bat") parser = RequestParser() - parser.add_argument('foo', choices=['BAT'], case_sensitive=False), + parser.add_argument("foo", choices=["BAT"], case_sensitive=False), args = parser.parse_args(req) - assert 'bat' == args.get('foo') + assert "bat" == args.get("foo") def test_parse_ignore(self, app): - req = Request.from_values('/bubble?foo=bar') + req = Request.from_values("/bubble?foo=bar") parser = RequestParser() - parser.add_argument('foo', type=int, ignore=True, store_missing=True), + parser.add_argument("foo", type=int, ignore=True, store_missing=True), args = parser.parse_args(req) - assert args['foo'] is None + assert args["foo"] is None def test_chaining(self): parser = RequestParser() - assert parser is parser.add_argument('foo') + assert parser is parser.add_argument("foo") def test_result_existence(self): result = ParseResult() - result.foo = 'bar' - result['bar'] = 'baz' - assert result['foo'] == 'bar' - assert result.bar == 'baz' + result.foo = "bar" + result["bar"] = "baz" + assert result["foo"] == "bar" + assert result.bar == "baz" def test_result_missing(self): result = ParseResult() pytest.raises(AttributeError, lambda: result.spam) - pytest.raises(KeyError, lambda: result['eggs']) + pytest.raises(KeyError, lambda: result["eggs"]) def test_result_configurability(self): req = Request.from_values() @@ -444,105 +459,124 @@ def test_result_configurability(self): def test_none_argument(self, app): parser = RequestParser() - parser.add_argument('foo', location='json') - with app.test_request_context('/bubble', method='post', - data=json.dumps({'foo': None}), - content_type='application/json'): + parser.add_argument("foo", location="json") + with app.test_request_context( + "/bubble", + method="post", + data=json.dumps({"foo": None}), + content_type="application/json", + ): args = parser.parse_args() - assert args['foo'] is None + assert args["foo"] is None def test_type_callable(self, app): - req = Request.from_values('/bubble?foo=1') + req = Request.from_values("/bubble?foo=1") parser = RequestParser() - parser.add_argument('foo', type=lambda x: x, required=False), + parser.add_argument("foo", type=lambda x: x, required=False), args = parser.parse_args(req) - assert args['foo'] == '1' + assert args["foo"] == "1" def test_type_callable_none(self, app): parser = RequestParser() - parser.add_argument('foo', type=lambda x: x, location='json', required=False), + parser.add_argument("foo", type=lambda x: x, location="json", required=False), - with app.test_request_context('/bubble', method='post', - data=json.dumps({'foo': None}), - content_type='application/json'): + with app.test_request_context( + "/bubble", + method="post", + data=json.dumps({"foo": None}), + content_type="application/json", + ): args = parser.parse_args() - assert args['foo'] is None + assert args["foo"] is None def test_type_decimal(self, app): parser = RequestParser() - parser.add_argument('foo', type=decimal.Decimal, location='json') + parser.add_argument("foo", type=decimal.Decimal, location="json") - with app.test_request_context('/bubble', method='post', - data=json.dumps({'foo': '1.0025'}), - content_type='application/json'): + with app.test_request_context( + "/bubble", + method="post", + data=json.dumps({"foo": "1.0025"}), + content_type="application/json", + ): args = parser.parse_args() - assert args['foo'] == decimal.Decimal('1.0025') + assert args["foo"] == decimal.Decimal("1.0025") def test_type_filestorage(self, app): parser = RequestParser() - parser.add_argument('foo', type=FileStorage, location='files') + parser.add_argument("foo", type=FileStorage, location="files") - fdata = 'foo bar baz qux'.encode('utf-8') - with app.test_request_context('/bubble', method='POST', - data={'foo': (six.BytesIO(fdata), 'baz.txt')}): + fdata = "foo bar baz qux".encode("utf-8") + with app.test_request_context( + "/bubble", method="POST", data={"foo": (six.BytesIO(fdata), "baz.txt")} + ): args = parser.parse_args() - assert args['foo'].name == 'foo' - assert args['foo'].filename == 'baz.txt' - assert args['foo'].read() == fdata + assert args["foo"].name == "foo" + assert args["foo"].filename == "baz.txt" + assert args["foo"].read() == fdata def test_filestorage_custom_type(self, app): def _custom_type(f): - return FileStorage(stream=f.stream, - filename='{0}aaaa'.format(f.filename), - name='{0}aaaa'.format(f.name)) + return FileStorage( + stream=f.stream, + filename="{0}aaaa".format(f.filename), + name="{0}aaaa".format(f.name), + ) parser = RequestParser() - parser.add_argument('foo', type=_custom_type, location='files') + parser.add_argument("foo", type=_custom_type, location="files") - fdata = 'foo bar baz qux'.encode('utf-8') - with app.test_request_context('/bubble', method='POST', - data={'foo': (six.BytesIO(fdata), 'baz.txt')}): + fdata = "foo bar baz qux".encode("utf-8") + with app.test_request_context( + "/bubble", method="POST", data={"foo": (six.BytesIO(fdata), "baz.txt")} + ): args = parser.parse_args() - assert args['foo'].name == 'fooaaaa' - assert args['foo'].filename == 'baz.txtaaaa' - assert args['foo'].read() == fdata + assert args["foo"].name == "fooaaaa" + assert args["foo"].filename == "baz.txtaaaa" + assert args["foo"].read() == fdata def test_passing_arguments_object(self, app): - req = Request.from_values('/bubble?foo=bar') + req = Request.from_values("/bubble?foo=bar") parser = RequestParser() - parser.add_argument(Argument('foo')) + parser.add_argument(Argument("foo")) args = parser.parse_args(req) - assert args['foo'] == 'bar' + assert args["foo"] == "bar" def test_int_choice_types(self, app): parser = RequestParser() - parser.add_argument('foo', type=int, choices=[1, 2, 3], location='json') + parser.add_argument("foo", type=int, choices=[1, 2, 3], location="json") - with app.test_request_context('/bubble', method='post', - data=json.dumps({'foo': 5}), - content_type='application/json'): + with app.test_request_context( + "/bubble", + method="post", + data=json.dumps({"foo": 5}), + content_type="application/json", + ): with pytest.raises(BadRequest): parser.parse_args() def test_int_range_choice_types(self, app): parser = RequestParser() - parser.add_argument('foo', type=int, choices=range(100), location='json') + parser.add_argument("foo", type=int, choices=range(100), location="json") - with app.test_request_context('/bubble', method='post', - data=json.dumps({'foo': 101}), - content_type='application/json'): + with app.test_request_context( + "/bubble", + method="post", + data=json.dumps({"foo": 101}), + content_type="application/json", + ): with pytest.raises(BadRequest): parser.parse_args() def test_request_parser_copy(self, app): - req = Request.from_values('/bubble?foo=101&bar=baz') + req = Request.from_values("/bubble?foo=101&bar=baz") parser = RequestParser() - foo_arg = Argument('foo', type=int) + foo_arg = Argument("foo", type=int) parser.args.append(foo_arg) parser_copy = parser.copy() @@ -551,13 +585,13 @@ def test_request_parser_copy(self, app): assert foo_arg not in parser_copy.args # Args added to new parser should not be added to the original - bar_arg = Argument('bar') + bar_arg = Argument("bar") parser_copy.args.append(bar_arg) assert bar_arg not in parser.args args = parser_copy.parse_args(req) - assert args['foo'] == 101 - assert args['bar'] == 'baz' + assert args["foo"] == 101 + assert args["bar"] == "baz" def test_request_parse_copy_including_settings(self): parser = RequestParser(trim=True, bundle_errors=True) @@ -567,269 +601,276 @@ def test_request_parse_copy_including_settings(self): assert parser.bundle_errors == parser_copy.bundle_errors def test_request_parser_replace_argument(self, app): - req = Request.from_values('/bubble?foo=baz') + req = Request.from_values("/bubble?foo=baz") parser = RequestParser() - parser.add_argument('foo', type=int) + parser.add_argument("foo", type=int) parser_copy = parser.copy() - parser_copy.replace_argument('foo') + parser_copy.replace_argument("foo") args = parser_copy.parse_args(req) - assert args['foo'] == 'baz' + assert args["foo"] == "baz" def test_both_json_and_values_location(self, app): parser = RequestParser() - parser.add_argument('foo', type=int) - parser.add_argument('baz', type=int) - with app.test_request_context('/bubble?foo=1', method='post', - data=json.dumps({'baz': 2}), - content_type='application/json'): + parser.add_argument("foo", type=int) + parser.add_argument("baz", type=int) + with app.test_request_context( + "/bubble?foo=1", + method="post", + data=json.dumps({"baz": 2}), + content_type="application/json", + ): args = parser.parse_args() - assert args['foo'] == 1 - assert args['baz'] == 2 + assert args["foo"] == 1 + assert args["baz"] == 2 def test_not_json_location_and_content_type_json(self, app): parser = RequestParser() - parser.add_argument('foo', location='args') + parser.add_argument("foo", location="args") - with app.test_request_context('/bubble', method='get', - content_type='application/json'): + with app.test_request_context( + "/bubble", method="get", content_type="application/json" + ): parser.parse_args() # Should not raise a 400: BadRequest def test_request_parser_remove_argument(self): - req = Request.from_values('/bubble?foo=baz') + req = Request.from_values("/bubble?foo=baz") parser = RequestParser() - parser.add_argument('foo', type=int) + parser.add_argument("foo", type=int) parser_copy = parser.copy() - parser_copy.remove_argument('foo') + parser_copy.remove_argument("foo") args = parser_copy.parse_args(req) assert args == {} def test_strict_parsing_off(self): - req = Request.from_values('/bubble?foo=baz') + req = Request.from_values("/bubble?foo=baz") parser = RequestParser() args = parser.parse_args(req) assert args == {} def test_strict_parsing_on(self): - req = Request.from_values('/bubble?foo=baz') + req = Request.from_values("/bubble?foo=baz") parser = RequestParser() with pytest.raises(BadRequest): parser.parse_args(req, strict=True) def test_strict_parsing_off_partial_hit(self, app): - req = Request.from_values('/bubble?foo=1&bar=bees&n=22') + req = Request.from_values("/bubble?foo=1&bar=bees&n=22") parser = RequestParser() - parser.add_argument('foo', type=int) + parser.add_argument("foo", type=int) args = parser.parse_args(req) - assert args['foo'] == 1 + assert args["foo"] == 1 def test_strict_parsing_on_partial_hit(self, app): - req = Request.from_values('/bubble?foo=1&bar=bees&n=22') + req = Request.from_values("/bubble?foo=1&bar=bees&n=22") parser = RequestParser() - parser.add_argument('foo', type=int) + parser.add_argument("foo", type=int) with pytest.raises(BadRequest): parser.parse_args(req, strict=True) def test_trim_argument(self, app): - req = Request.from_values('/bubble?foo= 1 &bar=bees&n=22') + req = Request.from_values("/bubble?foo= 1 &bar=bees&n=22") parser = RequestParser() - parser.add_argument('foo') + parser.add_argument("foo") args = parser.parse_args(req) - assert args['foo'] == ' 1 ' + assert args["foo"] == " 1 " parser = RequestParser() - parser.add_argument('foo', trim=True) + parser.add_argument("foo", trim=True) args = parser.parse_args(req) - assert args['foo'] == '1' + assert args["foo"] == "1" parser = RequestParser() - parser.add_argument('foo', trim=True, type=int) + parser.add_argument("foo", trim=True, type=int) args = parser.parse_args(req) - assert args['foo'] == 1 + assert args["foo"] == 1 def test_trim_request_parser(self, app): - req = Request.from_values('/bubble?foo= 1 &bar=bees&n=22') + req = Request.from_values("/bubble?foo= 1 &bar=bees&n=22") parser = RequestParser(trim=False) - parser.add_argument('foo') + parser.add_argument("foo") args = parser.parse_args(req) - assert args['foo'] == ' 1 ' + assert args["foo"] == " 1 " parser = RequestParser(trim=True) - parser.add_argument('foo') + parser.add_argument("foo") args = parser.parse_args(req) - assert args['foo'] == '1' + assert args["foo"] == "1" parser = RequestParser(trim=True) - parser.add_argument('foo', type=int) + parser.add_argument("foo", type=int) args = parser.parse_args(req) - assert args['foo'] == 1 + assert args["foo"] == 1 def test_trim_request_parser_override_by_argument(self): parser = RequestParser(trim=True) - parser.add_argument('foo', trim=False) + parser.add_argument("foo", trim=False) assert parser.args[0].trim is False def test_trim_request_parser_json(self, app): parser = RequestParser(trim=True) - parser.add_argument('foo', location='json') - parser.add_argument('int1', location='json', type=int) - parser.add_argument('int2', location='json', type=int) - - with app.test_request_context('/bubble', method='post', - data=json.dumps({'foo': ' bar ', 'int1': 1, 'int2': ' 2 '}), - content_type='application/json'): + parser.add_argument("foo", location="json") + parser.add_argument("int1", location="json", type=int) + parser.add_argument("int2", location="json", type=int) + + with app.test_request_context( + "/bubble", + method="post", + data=json.dumps({"foo": " bar ", "int1": 1, "int2": " 2 "}), + content_type="application/json", + ): args = parser.parse_args() - assert args['foo'] == 'bar' - assert args['int1'] == 1 - assert args['int2'] == 2 + assert args["foo"] == "bar" + assert args["int1"] == 1 + assert args["int2"] == 2 class ArgumentTest(object): def test_name(self): - arg = Argument('foo') - assert arg.name == 'foo' + arg = Argument("foo") + assert arg.name == "foo" def test_dest(self): - arg = Argument('foo', dest='foobar') - assert arg.dest == 'foobar' + arg = Argument("foo", dest="foobar") + assert arg.dest == "foobar" def test_location_url(self): - arg = Argument('foo', location='url') - assert arg.location == 'url' + arg = Argument("foo", location="url") + assert arg.location == "url" def test_location_url_list(self): - arg = Argument('foo', location=['url']) - assert arg.location == ['url'] + arg = Argument("foo", location=["url"]) + assert arg.location == ["url"] def test_location_header(self): - arg = Argument('foo', location='headers') - assert arg.location == 'headers' + arg = Argument("foo", location="headers") + assert arg.location == "headers" def test_location_json(self): - arg = Argument('foo', location='json') - assert arg.location == 'json' + arg = Argument("foo", location="json") + assert arg.location == "json" def test_location_get_json(self): - arg = Argument('foo', location='get_json') - assert arg.location == 'get_json' + arg = Argument("foo", location="get_json") + assert arg.location == "get_json" def test_location_header_list(self): - arg = Argument('foo', location=['headers']) - assert arg.location == ['headers'] + arg = Argument("foo", location=["headers"]) + assert arg.location == ["headers"] def test_type(self): - arg = Argument('foo', type=int) + arg = Argument("foo", type=int) assert arg.type == int def test_default(self): - arg = Argument('foo', default=True) + arg = Argument("foo", default=True) assert arg.default is True def test_default_help(self): - arg = Argument('foo') + arg = Argument("foo") assert arg.help is None def test_required(self): - arg = Argument('foo', required=True) + arg = Argument("foo", required=True) assert arg.required is True def test_ignore(self): - arg = Argument('foo', ignore=True) + arg = Argument("foo", ignore=True) assert arg.ignore is True def test_operator(self): - arg = Argument('foo', operators=['>=', '<=', '=']) - assert arg.operators == ['>=', '<=', '='] + arg = Argument("foo", operators=[">=", "<=", "="]) + assert arg.operators == [">=", "<=", "="] def test_action_filter(self): - arg = Argument('foo', action='filter') - assert arg.action == 'filter' + arg = Argument("foo", action="filter") + assert arg.action == "filter" def test_action(self): - arg = Argument('foo', action='append') - assert arg.action == 'append' + arg = Argument("foo", action="append") + assert arg.action == "append" def test_choices(self): - arg = Argument('foo', choices=[1, 2]) + arg = Argument("foo", choices=[1, 2]) assert arg.choices == [1, 2] def test_default_dest(self): - arg = Argument('foo') + arg = Argument("foo") assert arg.dest is None def test_default_operators(self): - arg = Argument('foo') - assert arg.operators[0] == '=' + arg = Argument("foo") + assert arg.operators[0] == "=" assert len(arg.operators) == 1 def test_default_type(self, mocker): - mock_six = mocker.patch('flask_restx.reqparse.six') - arg = Argument('foo') + mock_six = mocker.patch("flask_restx.reqparse.six") + arg = Argument("foo") sentinel = object() arg.type(sentinel) mock_six.text_type.assert_called_with(sentinel) def test_default_default(self): - arg = Argument('foo') + arg = Argument("foo") assert arg.default is None def test_required_default(self): - arg = Argument('foo') + arg = Argument("foo") assert arg.required is False def test_ignore_default(self): - arg = Argument('foo') + arg = Argument("foo") assert arg.ignore is False def test_action_default(self): - arg = Argument('foo') - assert arg.action == 'store' + arg = Argument("foo") + assert arg.action == "store" def test_choices_default(self): - arg = Argument('foo') + arg = Argument("foo") assert len(arg.choices) == 0 def test_source(self, mocker): - req = mocker.Mock(['args', 'headers', 'values']) - req.args = {'foo': 'bar'} - req.headers = {'baz': 'bat'} - arg = Argument('foo', location=['args']) + req = mocker.Mock(["args", "headers", "values"]) + req.args = {"foo": "bar"} + req.headers = {"baz": "bat"} + arg = Argument("foo", location=["args"]) assert arg.source(req) == MultiDict(req.args) - arg = Argument('foo', location=['headers']) + arg = Argument("foo", location=["headers"]) assert arg.source(req) == MultiDict(req.headers) def test_convert_default_type_with_null_input(self): - arg = Argument('foo') + arg = Argument("foo") assert arg.convert(None, None) is None def test_convert_with_null_input_when_not_nullable(self): - arg = Argument('foo', nullable=False) + arg = Argument("foo", nullable=False) pytest.raises(ValueError, lambda: arg.convert(None, None)) def test_source_bad_location(self, mocker): - req = mocker.Mock(['values']) - arg = Argument('foo', location=['foo']) + req = mocker.Mock(["values"]) + arg = Argument("foo", location=["foo"]) assert len(arg.source(req)) == 0 # yes, basically you don't find it def test_source_default_location(self, mocker): - req = mocker.Mock(['values']) + req = mocker.Mock(["values"]) req._get_child_mock = lambda **kwargs: MultiDict() - arg = Argument('foo') + arg = Argument("foo") assert arg.source(req) == req.values def test_option_case_sensitive(self): - arg = Argument('foo', choices=['bar', 'baz'], case_sensitive=True) + arg = Argument("foo", choices=["bar", "baz"], case_sensitive=True) assert arg.case_sensitive is True # Insensitive - arg = Argument('foo', choices=['bar', 'baz'], case_sensitive=False) + arg = Argument("foo", choices=["bar", "baz"], case_sensitive=False) assert arg.case_sensitive is False # Default - arg = Argument('foo', choices=['bar', 'baz']) + arg = Argument("foo", choices=["bar", "baz"]) assert arg.case_sensitive is True @@ -840,136 +881,112 @@ def test_empty_parser(self): def test_primitive_types(self): parser = RequestParser() - parser.add_argument('int', type=int, help='Some integer') - parser.add_argument('str', type=str, help='Some string') - parser.add_argument('float', type=float, help='Some float') + parser.add_argument("int", type=int, help="Some integer") + parser.add_argument("str", type=str, help="Some string") + parser.add_argument("float", type=float, help="Some float") assert parser.__schema__ == [ { "description": "Some integer", "type": "integer", "name": "int", - "in": "query" - }, { + "in": "query", + }, + { "description": "Some string", "type": "string", "name": "str", - "in": "query" - }, { + "in": "query", + }, + { "description": "Some float", "type": "number", "name": "float", - "in": "query" - } + "in": "query", + }, ] def test_unknown_type(self): parser = RequestParser() - parser.add_argument('unknown', type=lambda v: v) - assert parser.__schema__ == [{ - 'name': 'unknown', - 'type': 'string', - 'in': 'query', - }] + parser.add_argument("unknown", type=lambda v: v) + assert parser.__schema__ == [ + {"name": "unknown", "type": "string", "in": "query",} + ] def test_required(self): parser = RequestParser() - parser.add_argument('int', type=int, required=True) - assert parser.__schema__ == [{ - 'name': 'int', - 'type': 'integer', - 'in': 'query', - 'required': True, - }] + parser.add_argument("int", type=int, required=True) + assert parser.__schema__ == [ + {"name": "int", "type": "integer", "in": "query", "required": True,} + ] def test_default(self): parser = RequestParser() - parser.add_argument('int', type=int, default=5) - assert parser.__schema__ == [{ - 'name': 'int', - 'type': 'integer', - 'in': 'query', - 'default': 5, - }] + parser.add_argument("int", type=int, default=5) + assert parser.__schema__ == [ + {"name": "int", "type": "integer", "in": "query", "default": 5,} + ] def test_default_as_false(self): parser = RequestParser() - parser.add_argument('bool', type=inputs.boolean, default=False) - assert parser.__schema__ == [{ - 'name': 'bool', - 'type': 'boolean', - 'in': 'query', - 'default': False, - }] + parser.add_argument("bool", type=inputs.boolean, default=False) + assert parser.__schema__ == [ + {"name": "bool", "type": "boolean", "in": "query", "default": False,} + ] def test_choices(self): parser = RequestParser() - parser.add_argument('string', type=str, choices=['a', 'b']) - assert parser.__schema__ == [{ - 'name': 'string', - 'type': 'string', - 'in': 'query', - 'enum': ['a', 'b'], - 'collectionFormat': 'multi', - }] + parser.add_argument("string", type=str, choices=["a", "b"]) + assert parser.__schema__ == [ + { + "name": "string", + "type": "string", + "in": "query", + "enum": ["a", "b"], + "collectionFormat": "multi", + } + ] def test_location(self): parser = RequestParser() - parser.add_argument('default', type=int) - parser.add_argument('in_values', type=int, location='values') - parser.add_argument('in_query', type=int, location='args') - parser.add_argument('in_headers', type=int, location='headers') - parser.add_argument('in_cookie', type=int, location='cookie') - assert parser.__schema__ == [{ - 'name': 'default', - 'type': 'integer', - 'in': 'query', - }, { - 'name': 'in_values', - 'type': 'integer', - 'in': 'query', - }, { - 'name': 'in_query', - 'type': 'integer', - 'in': 'query', - }, { - 'name': 'in_headers', - 'type': 'integer', - 'in': 'header', - }] + parser.add_argument("default", type=int) + parser.add_argument("in_values", type=int, location="values") + parser.add_argument("in_query", type=int, location="args") + parser.add_argument("in_headers", type=int, location="headers") + parser.add_argument("in_cookie", type=int, location="cookie") + assert parser.__schema__ == [ + {"name": "default", "type": "integer", "in": "query",}, + {"name": "in_values", "type": "integer", "in": "query",}, + {"name": "in_query", "type": "integer", "in": "query",}, + {"name": "in_headers", "type": "integer", "in": "header",}, + ] def test_location_json(self): parser = RequestParser() - parser.add_argument('in_json', type=str, location='json') - assert parser.__schema__ == [{ - 'name': 'in_json', - 'type': 'string', - 'in': 'body', - }] + parser.add_argument("in_json", type=str, location="json") + assert parser.__schema__ == [ + {"name": "in_json", "type": "string", "in": "body",} + ] def test_location_form(self): parser = RequestParser() - parser.add_argument('in_form', type=int, location='form') - assert parser.__schema__ == [{ - 'name': 'in_form', - 'type': 'integer', - 'in': 'formData', - }] + parser.add_argument("in_form", type=int, location="form") + assert parser.__schema__ == [ + {"name": "in_form", "type": "integer", "in": "formData",} + ] def test_location_files(self): parser = RequestParser() - parser.add_argument('in_files', type=FileStorage, location='files') - assert parser.__schema__ == [{ - 'name': 'in_files', - 'type': 'file', - 'in': 'formData', - }] + parser.add_argument("in_files", type=FileStorage, location="files") + assert parser.__schema__ == [ + {"name": "in_files", "type": "file", "in": "formData",} + ] def test_form_and_body_location(self): parser = RequestParser() - parser.add_argument('default', type=int) - parser.add_argument('in_form', type=int, location='form') - parser.add_argument('in_json', type=str, location='json') + parser.add_argument("default", type=int) + parser.add_argument("in_form", type=int, location="form") + parser.add_argument("in_json", type=str, location="json") with pytest.raises(SpecsError) as cm: parser.__schema__ @@ -977,73 +994,73 @@ def test_form_and_body_location(self): def test_files_and_body_location(self): parser = RequestParser() - parser.add_argument('default', type=int) - parser.add_argument('in_files', type=FileStorage, location='files') - parser.add_argument('in_json', type=str, location='json') + parser.add_argument("default", type=int) + parser.add_argument("in_files", type=FileStorage, location="files") + parser.add_argument("in_json", type=str, location="json") with pytest.raises(SpecsError) as cm: parser.__schema__ assert cm.value.msg == "Can't use formData and body at the same time" def test_models(self): - todo_fields = Model('Todo', { - 'task': fields.String(required=True, description='The task details') - }) + todo_fields = Model( + "Todo", + {"task": fields.String(required=True, description="The task details")}, + ) parser = RequestParser() - parser.add_argument('todo', type=todo_fields) - assert parser.__schema__ == [{ - 'name': 'todo', - 'type': 'Todo', - 'in': 'body', - }] + parser.add_argument("todo", type=todo_fields) + assert parser.__schema__ == [{"name": "todo", "type": "Todo", "in": "body",}] def test_lists(self): parser = RequestParser() - parser.add_argument('int', type=int, action='append') - assert parser.__schema__ == [{ - 'name': 'int', - 'in': 'query', - 'type': 'array', - 'collectionFormat': 'multi', - 'items': {'type': 'integer'} - }] + parser.add_argument("int", type=int, action="append") + assert parser.__schema__ == [ + { + "name": "int", + "in": "query", + "type": "array", + "collectionFormat": "multi", + "items": {"type": "integer"}, + } + ] def test_split_lists(self): parser = RequestParser() - parser.add_argument('int', type=int, action='split') - assert parser.__schema__ == [{ - 'name': 'int', - 'in': 'query', - 'type': 'array', - 'collectionFormat': 'csv', - 'items': {'type': 'integer'} - }] + parser.add_argument("int", type=int, action="split") + assert parser.__schema__ == [ + { + "name": "int", + "in": "query", + "type": "array", + "collectionFormat": "csv", + "items": {"type": "integer"}, + } + ] def test_schema_interface(self): def custom(value): pass custom.__schema__ = { - 'type': 'string', - 'format': 'custom-format', + "type": "string", + "format": "custom-format", } parser = RequestParser() - parser.add_argument('custom', type=custom) + parser.add_argument("custom", type=custom) - assert parser.__schema__ == [{ - 'name': 'custom', - 'in': 'query', - 'type': 'string', - 'format': 'custom-format', - }] + assert parser.__schema__ == [ + { + "name": "custom", + "in": "query", + "type": "string", + "format": "custom-format", + } + ] def test_callable_default(self): parser = RequestParser() - parser.add_argument('int', type=int, default=lambda: 5) - assert parser.__schema__ == [{ - 'name': 'int', - 'type': 'integer', - 'in': 'query', - 'default': 5, - }] + parser.add_argument("int", type=int, default=lambda: 5) + assert parser.__schema__ == [ + {"name": "int", "type": "integer", "in": "query", "default": 5,} + ] diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 586453e0..c12782c6 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -10,42 +10,40 @@ class SchemasTest: def test_lazyness(self): - schema = schemas.LazySchema('oas-2.0.json') + schema = schemas.LazySchema("oas-2.0.json") assert schema._schema is None - '' in schema # Trigger load + "" in schema # Trigger load assert schema._schema is not None assert isinstance(schema._schema, dict) def test_oas2_schema_is_present(self): - assert hasattr(schemas, 'OAS_20') + assert hasattr(schemas, "OAS_20") assert isinstance(schemas.OAS_20, schemas.LazySchema) class ValidationTest: def test_oas_20_valid(self): - assert schemas.validate({ - 'swagger': '2.0', - 'info': { - 'title': 'An empty minimal specification', - 'version': '1.0', - }, - 'paths': {}, - }) + assert schemas.validate( + { + "swagger": "2.0", + "info": {"title": "An empty minimal specification", "version": "1.0",}, + "paths": {}, + } + ) def test_oas_20_invalid(self): with pytest.raises(schemas.SchemaValidationError) as excinfo: - schemas.validate({ - 'swagger': '2.0', - 'should': 'not be here', - }) + schemas.validate( + {"swagger": "2.0", "should": "not be here",} + ) for error in excinfo.value.errors: assert isinstance(error, ValidationError) def test_unknown_schema(self): with pytest.raises(errors.SpecsError): - schemas.validate({'valid': 'no'}) + schemas.validate({"valid": "no"}) def test_unknown_version(self): with pytest.raises(errors.SpecsError): - schemas.validate({'swagger': '42.0'}) + schemas.validate({"swagger": "42.0"}) diff --git a/tests/test_swagger.py b/tests/test_swagger.py index 5c1609de..ed42e638 100644 --- a/tests/test_swagger.py +++ b/tests/test_swagger.py @@ -15,448 +15,420 @@ class SwaggerTest(object): def test_specs_endpoint(self, api, client): - data = client.get_specs('') - assert data['swagger'] == '2.0' - assert data['basePath'] == '/' - assert data['produces'] == ['application/json'] - assert data['consumes'] == ['application/json'] - assert data['paths'] == {} - assert 'info' in data - - @pytest.mark.api(prefix='/api') + data = client.get_specs("") + assert data["swagger"] == "2.0" + assert data["basePath"] == "/" + assert data["produces"] == ["application/json"] + assert data["consumes"] == ["application/json"] + assert data["paths"] == {} + assert "info" in data + + @pytest.mark.api(prefix="/api") def test_specs_endpoint_with_prefix(self, api, client): - data = client.get_specs('/api') - assert data['swagger'] == '2.0' - assert data['basePath'] == '/api' - assert data['produces'] == ['application/json'] - assert data['consumes'] == ['application/json'] - assert data['paths'] == {} - assert 'info' in data + data = client.get_specs("/api") + assert data["swagger"] == "2.0" + assert data["basePath"] == "/api" + assert data["produces"] == ["application/json"] + assert data["consumes"] == ["application/json"] + assert data["paths"] == {} + assert "info" in data def test_specs_endpoint_produces(self, api, client): def output_xml(data, code, headers=None): pass - api.representations['application/xml'] = output_xml + api.representations["application/xml"] = output_xml data = client.get_specs() - assert len(data['produces']) == 2 - assert 'application/json' in data['produces'] - assert 'application/xml' in data['produces'] + assert len(data["produces"]) == 2 + assert "application/json" in data["produces"] + assert "application/xml" in data["produces"] def test_specs_endpoint_info(self, app, client): - api = restx.Api(version='1.0', - title='My API', - description='This is a testing API', - terms_url='http://somewhere.com/terms/', - contact='Support', - contact_url='http://support.somewhere.com', - contact_email='contact@somewhere.com', - license='Apache 2.0', - license_url='http://www.apache.org/licenses/LICENSE-2.0.html' + api = restx.Api( + version="1.0", + title="My API", + description="This is a testing API", + terms_url="http://somewhere.com/terms/", + contact="Support", + contact_url="http://support.somewhere.com", + contact_email="contact@somewhere.com", + license="Apache 2.0", + license_url="http://www.apache.org/licenses/LICENSE-2.0.html", ) api.init_app(app) data = client.get_specs() - assert data['swagger'] == '2.0' - assert data['basePath'] == '/' - assert data['produces'] == ['application/json'] - assert data['paths'] == {} - - assert 'info' in data - assert data['info']['title'] == 'My API' - assert data['info']['version'] == '1.0' - assert data['info']['description'] == 'This is a testing API' - assert data['info']['termsOfService'] == 'http://somewhere.com/terms/' - assert data['info']['contact'] == { - 'name': 'Support', - 'url': 'http://support.somewhere.com', - 'email': 'contact@somewhere.com', + assert data["swagger"] == "2.0" + assert data["basePath"] == "/" + assert data["produces"] == ["application/json"] + assert data["paths"] == {} + + assert "info" in data + assert data["info"]["title"] == "My API" + assert data["info"]["version"] == "1.0" + assert data["info"]["description"] == "This is a testing API" + assert data["info"]["termsOfService"] == "http://somewhere.com/terms/" + assert data["info"]["contact"] == { + "name": "Support", + "url": "http://support.somewhere.com", + "email": "contact@somewhere.com", } - assert data['info']['license'] == { - 'name': 'Apache 2.0', - 'url': 'http://www.apache.org/licenses/LICENSE-2.0.html', + assert data["info"]["license"] == { + "name": "Apache 2.0", + "url": "http://www.apache.org/licenses/LICENSE-2.0.html", } def test_specs_endpoint_info_delayed(self, app, client): - api = restx.Api(version='1.0') - api.init_app(app, - title='My API', - description='This is a testing API', - terms_url='http://somewhere.com/terms/', - contact='Support', - contact_url='http://support.somewhere.com', - contact_email='contact@somewhere.com', - license='Apache 2.0', - license_url='http://www.apache.org/licenses/LICENSE-2.0.html' - ) - - data = client.get_specs() - - assert data['swagger'] == '2.0' - assert data['basePath'] == '/' - assert data['produces'] == ['application/json'] - assert data['paths'] == {} - - assert 'info' in data - assert data['info']['title'] == 'My API' - assert data['info']['version'] == '1.0' - assert data['info']['description'] == 'This is a testing API' - assert data['info']['termsOfService'] == 'http://somewhere.com/terms/' - assert data['info']['contact'] == { - 'name': 'Support', - 'url': 'http://support.somewhere.com', - 'email': 'contact@somewhere.com', + api = restx.Api(version="1.0") + api.init_app( + app, + title="My API", + description="This is a testing API", + terms_url="http://somewhere.com/terms/", + contact="Support", + contact_url="http://support.somewhere.com", + contact_email="contact@somewhere.com", + license="Apache 2.0", + license_url="http://www.apache.org/licenses/LICENSE-2.0.html", + ) + + data = client.get_specs() + + assert data["swagger"] == "2.0" + assert data["basePath"] == "/" + assert data["produces"] == ["application/json"] + assert data["paths"] == {} + + assert "info" in data + assert data["info"]["title"] == "My API" + assert data["info"]["version"] == "1.0" + assert data["info"]["description"] == "This is a testing API" + assert data["info"]["termsOfService"] == "http://somewhere.com/terms/" + assert data["info"]["contact"] == { + "name": "Support", + "url": "http://support.somewhere.com", + "email": "contact@somewhere.com", } - assert data['info']['license'] == { - 'name': 'Apache 2.0', - 'url': 'http://www.apache.org/licenses/LICENSE-2.0.html', + assert data["info"]["license"] == { + "name": "Apache 2.0", + "url": "http://www.apache.org/licenses/LICENSE-2.0.html", } def test_specs_endpoint_info_callable(self, app, client): - api = restx.Api(version=lambda: '1.0', - title=lambda: 'My API', - description=lambda: 'This is a testing API', - terms_url=lambda: 'http://somewhere.com/terms/', - contact=lambda: 'Support', - contact_url=lambda: 'http://support.somewhere.com', - contact_email=lambda: 'contact@somewhere.com', - license=lambda: 'Apache 2.0', - license_url=lambda: 'http://www.apache.org/licenses/LICENSE-2.0.html' + api = restx.Api( + version=lambda: "1.0", + title=lambda: "My API", + description=lambda: "This is a testing API", + terms_url=lambda: "http://somewhere.com/terms/", + contact=lambda: "Support", + contact_url=lambda: "http://support.somewhere.com", + contact_email=lambda: "contact@somewhere.com", + license=lambda: "Apache 2.0", + license_url=lambda: "http://www.apache.org/licenses/LICENSE-2.0.html", ) api.init_app(app) data = client.get_specs() - assert data['swagger'] == '2.0' - assert data['basePath'] == '/' - assert data['produces'] == ['application/json'] - assert data['paths'] == {} - - assert 'info' in data - assert data['info']['title'] == 'My API' - assert data['info']['version'] == '1.0' - assert data['info']['description'] == 'This is a testing API' - assert data['info']['termsOfService'] == 'http://somewhere.com/terms/' - assert data['info']['contact'] == { - 'name': 'Support', - 'url': 'http://support.somewhere.com', - 'email': 'contact@somewhere.com', + assert data["swagger"] == "2.0" + assert data["basePath"] == "/" + assert data["produces"] == ["application/json"] + assert data["paths"] == {} + + assert "info" in data + assert data["info"]["title"] == "My API" + assert data["info"]["version"] == "1.0" + assert data["info"]["description"] == "This is a testing API" + assert data["info"]["termsOfService"] == "http://somewhere.com/terms/" + assert data["info"]["contact"] == { + "name": "Support", + "url": "http://support.somewhere.com", + "email": "contact@somewhere.com", } - assert data['info']['license'] == { - 'name': 'Apache 2.0', - 'url': 'http://www.apache.org/licenses/LICENSE-2.0.html', + assert data["info"]["license"] == { + "name": "Apache 2.0", + "url": "http://www.apache.org/licenses/LICENSE-2.0.html", } def test_specs_endpoint_no_host(self, app, client): restx.Api(app) - data = client.get_specs('') - assert 'host' not in data - assert data['basePath'] == '/' + data = client.get_specs("") + assert "host" not in data + assert data["basePath"] == "/" - @pytest.mark.options(server_name='api.restx.org') + @pytest.mark.options(server_name="api.restx.org") def test_specs_endpoint_host(self, app, client): # app.config['SERVER_NAME'] = 'api.restx.org' restx.Api(app) - data = client.get_specs('') - assert data['host'] == 'api.restx.org' - assert data['basePath'] == '/' + data = client.get_specs("") + assert data["host"] == "api.restx.org" + assert data["basePath"] == "/" - @pytest.mark.options(server_name='api.restx.org') + @pytest.mark.options(server_name="api.restx.org") def test_specs_endpoint_host_with_url_prefix(self, app, client): - blueprint = Blueprint('api', __name__, url_prefix='/api/1') + blueprint = Blueprint("api", __name__, url_prefix="/api/1") restx.Api(blueprint) app.register_blueprint(blueprint) - data = client.get_specs('/api/1') - assert data['host'] == 'api.restx.org' - assert data['basePath'] == '/api/1' + data = client.get_specs("/api/1") + assert data["host"] == "api.restx.org" + assert data["basePath"] == "/api/1" - @pytest.mark.options(server_name='restx.org') + @pytest.mark.options(server_name="restx.org") def test_specs_endpoint_host_and_subdomain(self, app, client): - blueprint = Blueprint('api', __name__, subdomain='api') + blueprint = Blueprint("api", __name__, subdomain="api") restx.Api(blueprint) app.register_blueprint(blueprint) - data = client.get_specs(base_url='http://api.restx.org') - assert data['host'] == 'api.restx.org' - assert data['basePath'] == '/' + data = client.get_specs(base_url="http://api.restx.org") + assert data["host"] == "api.restx.org" + assert data["basePath"] == "/" def test_specs_endpoint_tags_short(self, app, client): - restx.Api(app, tags=['tag-1', 'tag-2', 'tag-3']) + restx.Api(app, tags=["tag-1", "tag-2", "tag-3"]) - data = client.get_specs('') - assert data['tags'] == [ - {'name': 'tag-1'}, - {'name': 'tag-2'}, - {'name': 'tag-3'} - ] + data = client.get_specs("") + assert data["tags"] == [{"name": "tag-1"}, {"name": "tag-2"}, {"name": "tag-3"}] def test_specs_endpoint_tags_tuple(self, app, client): - restx.Api(app, tags=[ - ('tag-1', 'Tag 1'), - ('tag-2', 'Tag 2'), - ('tag-3', 'Tag 3'), - ]) - - data = client.get_specs('') - assert data['tags'] == [ - {'name': 'tag-1', 'description': 'Tag 1'}, - {'name': 'tag-2', 'description': 'Tag 2'}, - {'name': 'tag-3', 'description': 'Tag 3'} + restx.Api( + app, tags=[("tag-1", "Tag 1"), ("tag-2", "Tag 2"), ("tag-3", "Tag 3"),] + ) + + data = client.get_specs("") + assert data["tags"] == [ + {"name": "tag-1", "description": "Tag 1"}, + {"name": "tag-2", "description": "Tag 2"}, + {"name": "tag-3", "description": "Tag 3"}, ] def test_specs_endpoint_tags_dict(self, app, client): - restx.Api(app, tags=[ - {'name': 'tag-1', 'description': 'Tag 1'}, - {'name': 'tag-2', 'description': 'Tag 2'}, - {'name': 'tag-3', 'description': 'Tag 3'}, - ]) - - data = client.get_specs('') - assert data['tags'] == [ - {'name': 'tag-1', 'description': 'Tag 1'}, - {'name': 'tag-2', 'description': 'Tag 2'}, - {'name': 'tag-3', 'description': 'Tag 3'} + restx.Api( + app, + tags=[ + {"name": "tag-1", "description": "Tag 1"}, + {"name": "tag-2", "description": "Tag 2"}, + {"name": "tag-3", "description": "Tag 3"}, + ], + ) + + data = client.get_specs("") + assert data["tags"] == [ + {"name": "tag-1", "description": "Tag 1"}, + {"name": "tag-2", "description": "Tag 2"}, + {"name": "tag-3", "description": "Tag 3"}, ] - @pytest.mark.api(tags=['ns', 'tag']) + @pytest.mark.api(tags=["ns", "tag"]) def test_specs_endpoint_tags_namespaces(self, api, client): - api.namespace('ns', 'Description') + api.namespace("ns", "Description") - data = client.get_specs('') - assert data['tags'] == [{'name': 'ns'}, {'name': 'tag'}] + data = client.get_specs("") + assert data["tags"] == [{"name": "ns"}, {"name": "tag"}] def test_specs_endpoint_invalid_tags(self, app, client): - api = restx.Api(app, tags=[ - {'description': 'Tag 1'} - ]) + api = restx.Api(app, tags=[{"description": "Tag 1"}]) - client.get_specs('', status=500) + client.get_specs("", status=500) - assert list(api.__schema__.keys()) == ['error'] + assert list(api.__schema__.keys()) == ["error"] def test_specs_endpoint_default_ns_with_resources(self, app, client): restx.Api(app) - data = client.get_specs('') - assert data['tags'] == [] + data = client.get_specs("") + assert data["tags"] == [] def test_specs_endpoint_default_ns_without_resources(self, app, client): api = restx.Api(app) - @api.route('/test', endpoint='test') + @api.route("/test", endpoint="test") class TestResource(restx.Resource): def get(self): return {} - data = client.get_specs('') - assert data['tags'] == [ - {'name': 'default', 'description': 'Default namespace'} - ] + data = client.get_specs("") + assert data["tags"] == [{"name": "default", "description": "Default namespace"}] def test_specs_endpoint_default_ns_with_specified_ns(self, app, client): api = restx.Api(app) - ns = api.namespace('ns', 'Test namespace') + ns = api.namespace("ns", "Test namespace") - @ns.route('/test2', endpoint='test2') - @api.route('/test', endpoint='test') + @ns.route("/test2", endpoint="test2") + @api.route("/test", endpoint="test") class TestResource(restx.Resource): def get(self): return {} - data = client.get_specs('') - assert data['tags'] == [ - {'name': 'default', 'description': 'Default namespace'}, - {'name': 'ns', 'description': 'Test namespace'} + data = client.get_specs("") + assert data["tags"] == [ + {"name": "default", "description": "Default namespace"}, + {"name": "ns", "description": "Test namespace"}, ] def test_specs_endpoint_specified_ns_without_default_ns(self, app, client): api = restx.Api(app) - ns = api.namespace('ns', 'Test namespace') + ns = api.namespace("ns", "Test namespace") - @ns.route('/', endpoint='test2') + @ns.route("/", endpoint="test2") class TestResource(restx.Resource): def get(self): return {} - data = client.get_specs('') - assert data['tags'] == [ - {'name': 'ns', 'description': 'Test namespace'} - ] + data = client.get_specs("") + assert data["tags"] == [{"name": "ns", "description": "Test namespace"}] def test_specs_endpoint_namespace_without_description(self, app, client): api = restx.Api(app) - ns = api.namespace('ns') + ns = api.namespace("ns") - @ns.route('/test', endpoint='test') + @ns.route("/test", endpoint="test") class TestResource(restx.Resource): def get(self): return {} - data = client.get_specs('') - assert data['tags'] == [{'name': 'ns'}] + data = client.get_specs("") + assert data["tags"] == [{"name": "ns"}] def test_specs_endpoint_namespace_all_resources_hidden(self, app, client): api = restx.Api(app) - ns = api.namespace('ns') + ns = api.namespace("ns") - @ns.route('/test', endpoint='test', doc=False) + @ns.route("/test", endpoint="test", doc=False) class TestResource(restx.Resource): def get(self): return {} - @ns.route('/test2', endpoint='test2') + @ns.route("/test2", endpoint="test2") @ns.hide class TestResource2(restx.Resource): def get(self): return {} - @ns.route('/test3', endpoint='test3') + @ns.route("/test3", endpoint="test3") @ns.doc(False) class TestResource3(restx.Resource): def get(self): return {} - data = client.get_specs('') - assert data['tags'] == [] + data = client.get_specs("") + assert data["tags"] == [] def test_specs_authorizations(self, app, client): - authorizations = { - 'apikey': { - 'type': 'apiKey', - 'in': 'header', - 'name': 'X-API' - } - } + authorizations = {"apikey": {"type": "apiKey", "in": "header", "name": "X-API"}} restx.Api(app, authorizations=authorizations) data = client.get_specs() - assert 'securityDefinitions' in data - assert data['securityDefinitions'] == authorizations + assert "securityDefinitions" in data + assert data["securityDefinitions"] == authorizations - @pytest.mark.api(prefix='/api') + @pytest.mark.api(prefix="/api") def test_minimal_documentation(self, api, client): - ns = api.namespace('ns', 'Test namespace') + ns = api.namespace("ns", "Test namespace") - @ns.route('/', endpoint='test') + @ns.route("/", endpoint="test") class TestResource(restx.Resource): def get(self): return {} - data = client.get_specs('/api') - paths = data['paths'] + data = client.get_specs("/api") + paths = data["paths"] assert len(paths.keys()) == 1 - assert '/ns/' in paths - assert 'get' in paths['/ns/'] - op = paths['/ns/']['get'] - assert op['tags'] == ['ns'] - assert op['operationId'] == 'get_test_resource' - assert 'parameters' not in op - assert 'summary' not in op - assert 'description' not in op - assert op['responses'] == { - '200': { - 'description': 'Success', - } - } + assert "/ns/" in paths + assert "get" in paths["/ns/"] + op = paths["/ns/"]["get"] + assert op["tags"] == ["ns"] + assert op["operationId"] == "get_test_resource" + assert "parameters" not in op + assert "summary" not in op + assert "description" not in op + assert op["responses"] == {"200": {"description": "Success",}} - assert url_for('api.test') == '/api/ns/' + assert url_for("api.test") == "/api/ns/" - @pytest.mark.api(prefix='/api', version='1.0') + @pytest.mark.api(prefix="/api", version="1.0") def test_default_ns_resource_documentation(self, api, client): - @api.route('/test/', endpoint='test') + @api.route("/test/", endpoint="test") class TestResource(restx.Resource): def get(self): return {} - data = client.get_specs('/api') - paths = data['paths'] + data = client.get_specs("/api") + paths = data["paths"] assert len(paths.keys()) == 1 - assert '/test/' in paths - assert 'get' in paths['/test/'] - op = paths['/test/']['get'] - assert op['tags'] == ['default'] - assert op['responses'] == { - '200': { - 'description': 'Success', - } - } + assert "/test/" in paths + assert "get" in paths["/test/"] + op = paths["/test/"]["get"] + assert op["tags"] == ["default"] + assert op["responses"] == {"200": {"description": "Success",}} - assert len(data['tags']) == 1 - tag = data['tags'][0] - assert tag['name'] == 'default' - assert tag['description'] == 'Default namespace' + assert len(data["tags"]) == 1 + tag = data["tags"][0] + assert tag["name"] == "default" + assert tag["description"] == "Default namespace" - assert url_for('api.test') == '/api/test/' + assert url_for("api.test") == "/api/test/" - @pytest.mark.api(default='site', default_label='Site namespace') + @pytest.mark.api(default="site", default_label="Site namespace") def test_default_ns_resource_documentation_with_override(self, api, client): - @api.route('/test/', endpoint='test') + @api.route("/test/", endpoint="test") class TestResource(restx.Resource): def get(self): return {} data = client.get_specs() - paths = data['paths'] + paths = data["paths"] assert len(paths.keys()) == 1 - assert '/test/' in paths - assert 'get' in paths['/test/'] - op = paths['/test/']['get'] - assert op['tags'] == ['site'] - assert op['responses'] == { - '200': { - 'description': 'Success', - } - } + assert "/test/" in paths + assert "get" in paths["/test/"] + op = paths["/test/"]["get"] + assert op["tags"] == ["site"] + assert op["responses"] == {"200": {"description": "Success",}} - assert len(data['tags']) == 1 - tag = data['tags'][0] - assert tag['name'] == 'site' - assert tag['description'] == 'Site namespace' + assert len(data["tags"]) == 1 + tag = data["tags"][0] + assert tag["name"] == "site" + assert tag["description"] == "Site namespace" - assert url_for('api.test') == '/test/' + assert url_for("api.test") == "/test/" - @pytest.mark.api(prefix='/api') + @pytest.mark.api(prefix="/api") def test_ns_resource_documentation(self, api, client): - ns = api.namespace('ns', 'Test namespace') + ns = api.namespace("ns", "Test namespace") - @ns.route('/', endpoint='test') + @ns.route("/", endpoint="test") class TestResource(restx.Resource): def get(self): return {} - data = client.get_specs('/api') - paths = data['paths'] + data = client.get_specs("/api") + paths = data["paths"] assert len(paths.keys()) == 1 - assert '/ns/' in paths - assert 'get' in paths['/ns/'] - op = paths['/ns/']['get'] - assert op['tags'] == ['ns'] - assert op['responses'] == { - '200': { - 'description': 'Success', - } - } - assert 'parameters' not in op + assert "/ns/" in paths + assert "get" in paths["/ns/"] + op = paths["/ns/"]["get"] + assert op["tags"] == ["ns"] + assert op["responses"] == {"200": {"description": "Success",}} + assert "parameters" not in op - assert len(data['tags']) == 1 - tag = data['tags'][-1] - assert tag['name'] == 'ns' - assert tag['description'] == 'Test namespace' + assert len(data["tags"]) == 1 + tag = data["tags"][-1] + assert tag["name"] == "ns" + assert tag["description"] == "Test namespace" - assert url_for('api.test') == '/api/ns/' + assert url_for("api.test") == "/api/ns/" def test_ns_resource_documentation_lazy(self, app, client): api = restx.Api() - ns = api.namespace('ns', 'Test namespace') + ns = api.namespace("ns", "Test namespace") - @ns.route('/', endpoint='test') + @ns.route("/", endpoint="test") class TestResource(restx.Resource): def get(self): return {} @@ -464,219 +436,217 @@ def get(self): api.init_app(app) data = client.get_specs() - paths = data['paths'] + paths = data["paths"] assert len(paths.keys()) == 1 - assert '/ns/' in paths - assert 'get' in paths['/ns/'] - op = paths['/ns/']['get'] - assert op['tags'] == ['ns'] - assert op['responses'] == { - '200': { - 'description': 'Success', - } - } + assert "/ns/" in paths + assert "get" in paths["/ns/"] + op = paths["/ns/"]["get"] + assert op["tags"] == ["ns"] + assert op["responses"] == {"200": {"description": "Success",}} - assert len(data['tags']) == 1 - tag = data['tags'][-1] - assert tag['name'] == 'ns' - assert tag['description'] == 'Test namespace' + assert len(data["tags"]) == 1 + tag = data["tags"][-1] + assert tag["name"] == "ns" + assert tag["description"] == "Test namespace" - assert url_for('test') == '/ns/' + assert url_for("test") == "/ns/" def test_methods_docstring_to_summary(self, api, client): - @api.route('/test/', endpoint='test') + @api.route("/test/", endpoint="test") class TestResource(restx.Resource): def get(self): - ''' + """ GET operation - ''' + """ return {} def post(self): - '''POST operation. + """POST operation. Should be ignored - ''' + """ return {} def put(self): - '''PUT operation. Should be ignored''' + """PUT operation. Should be ignored""" return {} def delete(self): - ''' + """ DELETE operation. Should be ignored. - ''' + """ return {} data = client.get_specs() - path = data['paths']['/test/'] + path = data["paths"]["/test/"] assert len(path.keys()) == 4 for method in path.keys(): operation = path[method] - assert method in ('get', 'post', 'put', 'delete') - assert operation['summary'] == '{0} operation'.format(method.upper()) - assert operation['operationId'] == '{0}_test_resource'.format(method.lower()) + assert method in ("get", "post", "put", "delete") + assert operation["summary"] == "{0} operation".format(method.upper()) + assert operation["operationId"] == "{0}_test_resource".format( + method.lower() + ) # assert operation['parameters'] == [] def test_path_parameter_no_type(self, api, client): - @api.route('/id//', endpoint='by-id') + @api.route("/id//", endpoint="by-id") class ByIdResource(restx.Resource): def get(self, id): return {} data = client.get_specs() - assert '/id/{id}/' in data['paths'] + assert "/id/{id}/" in data["paths"] - path = data['paths']['/id/{id}/'] - assert len(path['parameters']) == 1 + path = data["paths"]["/id/{id}/"] + assert len(path["parameters"]) == 1 - parameter = path['parameters'][0] - assert parameter['name'] == 'id' - assert parameter['type'] == 'string' - assert parameter['in'] == 'path' - assert parameter['required'] is True + parameter = path["parameters"][0] + assert parameter["name"] == "id" + assert parameter["type"] == "string" + assert parameter["in"] == "path" + assert parameter["required"] is True def test_path_parameter_with_type(self, api, client): - @api.route('/name//', endpoint='by-name') + @api.route("/name//", endpoint="by-name") class ByNameResource(restx.Resource): def get(self, age): return {} data = client.get_specs() - assert '/name/{age}/' in data['paths'] + assert "/name/{age}/" in data["paths"] - path = data['paths']['/name/{age}/'] - assert len(path['parameters']) == 1 + path = data["paths"]["/name/{age}/"] + assert len(path["parameters"]) == 1 - parameter = path['parameters'][0] - assert parameter['name'] == 'age' - assert parameter['type'] == 'integer' - assert parameter['in'] == 'path' - assert parameter['required'] is True + parameter = path["parameters"][0] + assert parameter["name"] == "age" + assert parameter["type"] == "integer" + assert parameter["in"] == "path" + assert parameter["required"] is True def test_path_parameter_with_type_with_argument(self, api, client): - @api.route('/name//', endpoint='by-name') + @api.route("/name//", endpoint="by-name") class ByNameResource(restx.Resource): def get(self, id): return {} data = client.get_specs() - assert '/name/{id}/' in data['paths'] + assert "/name/{id}/" in data["paths"] - path = data['paths']['/name/{id}/'] - assert len(path['parameters']) == 1 + path = data["paths"]["/name/{id}/"] + assert len(path["parameters"]) == 1 - parameter = path['parameters'][0] - assert parameter['name'] == 'id' - assert parameter['type'] == 'string' - assert parameter['in'] == 'path' - assert parameter['required'] is True + parameter = path["parameters"][0] + assert parameter["name"] == "id" + assert parameter["type"] == "string" + assert parameter["in"] == "path" + assert parameter["required"] is True def test_path_parameter_with_explicit_details(self, api, client): - @api.route('/name//', endpoint='by-name', doc={ - 'params': { - 'age': {'description': 'An age'} - } - }) + @api.route( + "/name//", + endpoint="by-name", + doc={"params": {"age": {"description": "An age"}}}, + ) class ByNameResource(restx.Resource): def get(self, age): return {} data = client.get_specs() - assert '/name/{age}/' in data['paths'] + assert "/name/{age}/" in data["paths"] - path = data['paths']['/name/{age}/'] - assert len(path['parameters']) == 1 + path = data["paths"]["/name/{age}/"] + assert len(path["parameters"]) == 1 - parameter = path['parameters'][0] - assert parameter['name'] == 'age' - assert parameter['type'] == 'integer' - assert parameter['in'] == 'path' - assert parameter['required'] is True - assert parameter['description'] == 'An age' + parameter = path["parameters"][0] + assert parameter["name"] == "age" + assert parameter["type"] == "integer" + assert parameter["in"] == "path" + assert parameter["required"] is True + assert parameter["description"] == "An age" def test_path_parameter_with_decorator_details(self, api, client): - @api.route('/name//') - @api.param('age', 'An age') + @api.route("/name//") + @api.param("age", "An age") class ByNameResource(restx.Resource): def get(self, age): return {} data = client.get_specs() - assert '/name/{age}/' in data['paths'] + assert "/name/{age}/" in data["paths"] - path = data['paths']['/name/{age}/'] - assert len(path['parameters']) == 1 + path = data["paths"]["/name/{age}/"] + assert len(path["parameters"]) == 1 - parameter = path['parameters'][0] - assert parameter['name'] == 'age' - assert parameter['type'] == 'integer' - assert parameter['in'] == 'path' - assert parameter['required'] is True - assert parameter['description'] == 'An age' + parameter = path["parameters"][0] + assert parameter["name"] == "age" + assert parameter["type"] == "integer" + assert parameter["in"] == "path" + assert parameter["required"] is True + assert parameter["description"] == "An age" def test_expect_parser(self, api, client): parser = api.parser() - parser.add_argument('param', type=int, help='Some param') - parser.add_argument('jsonparam', type=str, location='json', help='Some param') + parser.add_argument("param", type=int, help="Some param") + parser.add_argument("jsonparam", type=str, location="json", help="Some param") - @api.route('/with-parser/', endpoint='with-parser') + @api.route("/with-parser/", endpoint="with-parser") class WithParserResource(restx.Resource): @api.expect(parser) def get(self): return {} data = client.get_specs() - assert '/with-parser/' in data['paths'] + assert "/with-parser/" in data["paths"] - op = data['paths']['/with-parser/']['get'] - assert len(op['parameters']) == 2 + op = data["paths"]["/with-parser/"]["get"] + assert len(op["parameters"]) == 2 - parameter = [o for o in op['parameters'] if o['in'] == 'query'][0] - assert parameter['name'] == 'param' - assert parameter['type'] == 'integer' - assert parameter['in'] == 'query' - assert parameter['description'] == 'Some param' + parameter = [o for o in op["parameters"] if o["in"] == "query"][0] + assert parameter["name"] == "param" + assert parameter["type"] == "integer" + assert parameter["in"] == "query" + assert parameter["description"] == "Some param" - parameter = [o for o in op['parameters'] if o['in'] == 'body'][0] - assert parameter['name'] == 'payload' - assert parameter['required'] - assert parameter['in'] == 'body' - assert parameter['schema']['properties']['jsonparam']['type'] == 'string' + parameter = [o for o in op["parameters"] if o["in"] == "body"][0] + assert parameter["name"] == "payload" + assert parameter["required"] + assert parameter["in"] == "body" + assert parameter["schema"]["properties"]["jsonparam"]["type"] == "string" def test_expect_parser_on_class(self, api, client): parser = api.parser() - parser.add_argument('param', type=int, help='Some param') + parser.add_argument("param", type=int, help="Some param") - @api.route('/with-parser/', endpoint='with-parser') + @api.route("/with-parser/", endpoint="with-parser") @api.expect(parser) class WithParserResource(restx.Resource): def get(self): return {} data = client.get_specs() - assert '/with-parser/' in data['paths'] + assert "/with-parser/" in data["paths"] - path = data['paths']['/with-parser/'] - assert len(path['parameters']) == 1 + path = data["paths"]["/with-parser/"] + assert len(path["parameters"]) == 1 - parameter = path['parameters'][0] - assert parameter['name'] == 'param' - assert parameter['type'] == 'integer' - assert parameter['in'] == 'query' - assert parameter['description'] == 'Some param' + parameter = path["parameters"][0] + assert parameter["name"] == "param" + assert parameter["type"] == "integer" + assert parameter["in"] == "query" + assert parameter["description"] == "Some param" def test_method_parser_on_class(self, api, client): parser = api.parser() - parser.add_argument('param', type=int, help='Some param') + parser.add_argument("param", type=int, help="Some param") - @api.route('/with-parser/', endpoint='with-parser') - @api.doc(get={'expect': parser}) + @api.route("/with-parser/", endpoint="with-parser") + @api.doc(get={"expect": parser}) class WithParserResource(restx.Resource): def get(self): return {} @@ -685,219 +655,230 @@ def post(self): return {} data = client.get_specs() - assert '/with-parser/' in data['paths'] + assert "/with-parser/" in data["paths"] - op = data['paths']['/with-parser/']['get'] - assert len(op['parameters']) == 1 + op = data["paths"]["/with-parser/"]["get"] + assert len(op["parameters"]) == 1 - parameter = op['parameters'][0] - assert parameter['name'] == 'param' - assert parameter['type'] == 'integer' - assert parameter['in'] == 'query' - assert parameter['description'] == 'Some param' + parameter = op["parameters"][0] + assert parameter["name"] == "param" + assert parameter["type"] == "integer" + assert parameter["in"] == "query" + assert parameter["description"] == "Some param" - op = data['paths']['/with-parser/']['post'] - assert 'parameters' not in op + op = data["paths"]["/with-parser/"]["post"] + assert "parameters" not in op def test_parser_parameters_override(self, api, client): parser = api.parser() - parser.add_argument('param', type=int, help='Some param') + parser.add_argument("param", type=int, help="Some param") - @api.route('/with-parser/', endpoint='with-parser') + @api.route("/with-parser/", endpoint="with-parser") class WithParserResource(restx.Resource): @api.expect(parser) - @api.doc(params={'param': {'description': 'New description'}}) + @api.doc(params={"param": {"description": "New description"}}) def get(self): return {} data = client.get_specs() - assert '/with-parser/' in data['paths'] + assert "/with-parser/" in data["paths"] - op = data['paths']['/with-parser/']['get'] - assert len(op['parameters']) == 1 + op = data["paths"]["/with-parser/"]["get"] + assert len(op["parameters"]) == 1 - parameter = op['parameters'][0] - assert parameter['name'] == 'param' - assert parameter['type'] == 'integer' - assert parameter['in'] == 'query' - assert parameter['description'] == 'New description' + parameter = op["parameters"][0] + assert parameter["name"] == "param" + assert parameter["type"] == "integer" + assert parameter["in"] == "query" + assert parameter["description"] == "New description" def test_parser_parameter_in_form(self, api, client): parser = api.parser() - parser.add_argument('param', type=int, help='Some param', location='form') + parser.add_argument("param", type=int, help="Some param", location="form") - @api.route('/with-parser/', endpoint='with-parser') + @api.route("/with-parser/", endpoint="with-parser") class WithParserResource(restx.Resource): @api.expect(parser) def get(self): return {} data = client.get_specs() - assert '/with-parser/' in data['paths'] + assert "/with-parser/" in data["paths"] - op = data['paths']['/with-parser/']['get'] - assert len(op['parameters']) == 1 + op = data["paths"]["/with-parser/"]["get"] + assert len(op["parameters"]) == 1 - parameter = op['parameters'][0] - assert parameter['name'] == 'param' - assert parameter['type'] == 'integer' - assert parameter['in'] == 'formData' - assert parameter['description'] == 'Some param' + parameter = op["parameters"][0] + assert parameter["name"] == "param" + assert parameter["type"] == "integer" + assert parameter["in"] == "formData" + assert parameter["description"] == "Some param" - assert op['consumes'] == ['application/x-www-form-urlencoded', 'multipart/form-data'] + assert op["consumes"] == [ + "application/x-www-form-urlencoded", + "multipart/form-data", + ] def test_parser_parameter_in_files(self, api, client): parser = api.parser() - parser.add_argument('in_files', type=FileStorage, location='files') + parser.add_argument("in_files", type=FileStorage, location="files") - @api.route('/with-parser/', endpoint='with-parser') + @api.route("/with-parser/", endpoint="with-parser") class WithParserResource(restx.Resource): @api.expect(parser) def get(self): return {} data = client.get_specs() - assert '/with-parser/' in data['paths'] + assert "/with-parser/" in data["paths"] - op = data['paths']['/with-parser/']['get'] - assert len(op['parameters']) == 1 + op = data["paths"]["/with-parser/"]["get"] + assert len(op["parameters"]) == 1 - parameter = op['parameters'][0] - assert parameter['name'] == 'in_files' - assert parameter['type'] == 'file' - assert parameter['in'] == 'formData' + parameter = op["parameters"][0] + assert parameter["name"] == "in_files" + assert parameter["type"] == "file" + assert parameter["in"] == "formData" - assert op['consumes'] == ['multipart/form-data'] + assert op["consumes"] == ["multipart/form-data"] def test_parser_parameter_in_files_on_class(self, api, client): parser = api.parser() - parser.add_argument('in_files', type=FileStorage, location='files') + parser.add_argument("in_files", type=FileStorage, location="files") - @api.route('/with-parser/', endpoint='with-parser') + @api.route("/with-parser/", endpoint="with-parser") @api.expect(parser) class WithParserResource(restx.Resource): def get(self): return {} data = client.get_specs() - assert '/with-parser/' in data['paths'] + assert "/with-parser/" in data["paths"] - path = data['paths']['/with-parser/'] - assert len(path['parameters']) == 1 + path = data["paths"]["/with-parser/"] + assert len(path["parameters"]) == 1 - parameter = path['parameters'][0] - assert parameter['name'] == 'in_files' - assert parameter['type'] == 'file' - assert parameter['in'] == 'formData' + parameter = path["parameters"][0] + assert parameter["name"] == "in_files" + assert parameter["type"] == "file" + assert parameter["in"] == "formData" - assert 'consumes' not in path + assert "consumes" not in path - op = path['get'] - assert 'consumes' in op - assert op['consumes'] == ['multipart/form-data'] + op = path["get"] + assert "consumes" in op + assert op["consumes"] == ["multipart/form-data"] def test_explicit_parameters(self, api, client): - @api.route('/name//', endpoint='by-name') + @api.route("/name//", endpoint="by-name") class ByNameResource(restx.Resource): - @api.doc(params={ - 'q': { - 'type': 'string', - 'in': 'query', - 'description': 'A query string', + @api.doc( + params={ + "q": { + "type": "string", + "in": "query", + "description": "A query string", + } } - }) + ) def get(self, age): return {} data = client.get_specs() - assert '/name/{age}/' in data['paths'] + assert "/name/{age}/" in data["paths"] - path = data['paths']['/name/{age}/'] - assert len(path['parameters']) == 1 + path = data["paths"]["/name/{age}/"] + assert len(path["parameters"]) == 1 - parameter = path['parameters'][0] - assert parameter['name'] == 'age' - assert parameter['type'] == 'integer' - assert parameter['in'] == 'path' - assert parameter['required'] is True + parameter = path["parameters"][0] + assert parameter["name"] == "age" + assert parameter["type"] == "integer" + assert parameter["in"] == "path" + assert parameter["required"] is True - op = path['get'] - assert len(op['parameters']) == 1 + op = path["get"] + assert len(op["parameters"]) == 1 - parameter = op['parameters'][0] - assert parameter['name'] == 'q' - assert parameter['type'] == 'string' - assert parameter['in'] == 'query' - assert parameter['description'] == 'A query string' + parameter = op["parameters"][0] + assert parameter["name"] == "q" + assert parameter["type"] == "string" + assert parameter["in"] == "query" + assert parameter["description"] == "A query string" def test_explicit_parameters_with_decorator(self, api, client): - @api.route('/name/') + @api.route("/name/") class ByNameResource(restx.Resource): - @api.param('q', 'A query string', type='string', _in='formData') + @api.param("q", "A query string", type="string", _in="formData") def get(self, age): return {} data = client.get_specs() - assert '/name/' in data['paths'] + assert "/name/" in data["paths"] - op = data['paths']['/name/']['get'] - assert len(op['parameters']) == 1 + op = data["paths"]["/name/"]["get"] + assert len(op["parameters"]) == 1 - parameter = op['parameters'][0] - assert parameter['name'] == 'q' - assert parameter['type'] == 'string' - assert parameter['in'] == 'formData' - assert parameter['description'] == 'A query string' + parameter = op["parameters"][0] + assert parameter["name"] == "q" + assert parameter["type"] == "string" + assert parameter["in"] == "formData" + assert parameter["description"] == "A query string" def test_class_explicit_parameters(self, api, client): - @api.route('/name//', endpoint='by-name', doc={ - 'params': { - 'q': { - 'type': 'string', - 'in': 'query', - 'description': 'A query string', + @api.route( + "/name//", + endpoint="by-name", + doc={ + "params": { + "q": { + "type": "string", + "in": "query", + "description": "A query string", + } } - } - }) + }, + ) class ByNameResource(restx.Resource): def get(self, age): return {} data = client.get_specs() - assert '/name/{age}/' in data['paths'] + assert "/name/{age}/" in data["paths"] - path = data['paths']['/name/{age}/'] - assert len(path['parameters']) == 2 + path = data["paths"]["/name/{age}/"] + assert len(path["parameters"]) == 2 - by_name = dict((p['name'], p) for p in path['parameters']) + by_name = dict((p["name"], p) for p in path["parameters"]) - parameter = by_name['age'] - assert parameter['name'] == 'age' - assert parameter['type'] == 'integer' - assert parameter['in'] == 'path' - assert parameter['required'] is True + parameter = by_name["age"] + assert parameter["name"] == "age" + assert parameter["type"] == "integer" + assert parameter["in"] == "path" + assert parameter["required"] is True - parameter = by_name['q'] - assert parameter['name'] == 'q' - assert parameter['type'] == 'string' - assert parameter['in'] == 'query' - assert parameter['description'] == 'A query string' + parameter = by_name["q"] + assert parameter["name"] == "q" + assert parameter["type"] == "string" + assert parameter["in"] == "query" + assert parameter["description"] == "A query string" def test_explicit_parameters_override(self, api, client): - @api.route('/name//', endpoint='by-name', doc={ - 'params': { - 'q': { - 'type': 'string', - 'in': 'query', - 'description': 'Overriden description', - }, - 'age': { - 'description': 'An age' + @api.route( + "/name//", + endpoint="by-name", + doc={ + "params": { + "q": { + "type": "string", + "in": "query", + "description": "Overriden description", + }, + "age": {"description": "An age"}, } - } - }) + }, + ) class ByNameResource(restx.Resource): - @api.doc(params={'q': {'description': 'A query string'}}) + @api.doc(params={"q": {"description": "A query string"}}) def get(self, age): return {} @@ -905,60 +886,60 @@ def post(self, age): pass data = client.get_specs() - assert '/name/{age}/' in data['paths'] + assert "/name/{age}/" in data["paths"] - path = data['paths']['/name/{age}/'] - assert len(path['parameters']) == 1 + path = data["paths"]["/name/{age}/"] + assert len(path["parameters"]) == 1 - by_name = dict((p['name'], p) for p in path['parameters']) + by_name = dict((p["name"], p) for p in path["parameters"]) - parameter = by_name['age'] - assert parameter['name'] == 'age' - assert parameter['type'] == 'integer' - assert parameter['in'] == 'path' - assert parameter['required'] is True - assert parameter['description'] == 'An age' + parameter = by_name["age"] + assert parameter["name"] == "age" + assert parameter["type"] == "integer" + assert parameter["in"] == "path" + assert parameter["required"] is True + assert parameter["description"] == "An age" # Don't duplicate parameters - assert 'q' not in by_name + assert "q" not in by_name - get = data['paths']['/name/{age}/']['get'] - assert len(get['parameters']) == 1 + get = data["paths"]["/name/{age}/"]["get"] + assert len(get["parameters"]) == 1 - parameter = get['parameters'][0] - assert parameter['name'] == 'q' - assert parameter['type'] == 'string' - assert parameter['in'] == 'query' - assert parameter['description'] == 'A query string' + parameter = get["parameters"][0] + assert parameter["name"] == "q" + assert parameter["type"] == "string" + assert parameter["in"] == "query" + assert parameter["description"] == "A query string" - post = data['paths']['/name/{age}/']['post'] - assert len(post['parameters']) == 1 + post = data["paths"]["/name/{age}/"]["post"] + assert len(post["parameters"]) == 1 - parameter = post['parameters'][0] - assert parameter['name'] == 'q' - assert parameter['type'] == 'string' - assert parameter['in'] == 'query' - assert parameter['description'] == 'Overriden description' + parameter = post["parameters"][0] + assert parameter["name"] == "q" + assert parameter["type"] == "string" + assert parameter["in"] == "query" + assert parameter["description"] == "Overriden description" def test_explicit_parameters_override_by_method(self, api, client): - @api.route('/name//', endpoint='by-name', doc={ - 'get': { - 'params': { - 'q': { - 'type': 'string', - 'in': 'query', - 'description': 'A query string', + @api.route( + "/name//", + endpoint="by-name", + doc={ + "get": { + "params": { + "q": { + "type": "string", + "in": "query", + "description": "A query string", + } } - } + }, + "params": {"age": {"description": "An age"}}, }, - 'params': { - 'age': { - 'description': 'An age' - } - } - }) + ) class ByNameResource(restx.Resource): - @api.doc(params={'age': {'description': 'Overriden'}}) + @api.doc(params={"age": {"description": "Overriden"}}) def get(self, age): return {} @@ -966,60 +947,60 @@ def post(self, age): return {} data = client.get_specs() - assert '/name/{age}/' in data['paths'] + assert "/name/{age}/" in data["paths"] - path = data['paths']['/name/{age}/'] - assert 'parameters' not in path + path = data["paths"]["/name/{age}/"] + assert "parameters" not in path - get = path['get'] - assert len(get['parameters']) == 2 + get = path["get"] + assert len(get["parameters"]) == 2 - by_name = dict((p['name'], p) for p in get['parameters']) + by_name = dict((p["name"], p) for p in get["parameters"]) - parameter = by_name['age'] - assert parameter['name'] == 'age' - assert parameter['type'] == 'integer' - assert parameter['in'] == 'path' - assert parameter['required'] is True - assert parameter['description'] == 'Overriden' + parameter = by_name["age"] + assert parameter["name"] == "age" + assert parameter["type"] == "integer" + assert parameter["in"] == "path" + assert parameter["required"] is True + assert parameter["description"] == "Overriden" - parameter = by_name['q'] - assert parameter['name'] == 'q' - assert parameter['type'] == 'string' - assert parameter['in'] == 'query' - assert parameter['description'] == 'A query string' + parameter = by_name["q"] + assert parameter["name"] == "q" + assert parameter["type"] == "string" + assert parameter["in"] == "query" + assert parameter["description"] == "A query string" - post = path['post'] - assert len(post['parameters']) == 1 + post = path["post"] + assert len(post["parameters"]) == 1 - by_name = dict((p['name'], p) for p in post['parameters']) + by_name = dict((p["name"], p) for p in post["parameters"]) - parameter = by_name['age'] - assert parameter['name'] == 'age' - assert parameter['type'] == 'integer' - assert parameter['in'] == 'path' - assert parameter['required'] is True - assert parameter['description'] == 'An age' + parameter = by_name["age"] + assert parameter["name"] == "age" + assert parameter["type"] == "integer" + assert parameter["in"] == "path" + assert parameter["required"] is True + assert parameter["description"] == "An age" def test_parameters_cascading_with_apidoc_false(self, api, client): - @api.route('/name//', endpoint='by-name', doc={ - 'get': { - 'params': { - 'q': { - 'type': 'string', - 'in': 'query', - 'description': 'A query string', + @api.route( + "/name//", + endpoint="by-name", + doc={ + "get": { + "params": { + "q": { + "type": "string", + "in": "query", + "description": "A query string", + } } - } + }, + "params": {"age": {"description": "An age"}}, }, - 'params': { - 'age': { - 'description': 'An age' - } - } - }) + ) class ByNameResource(restx.Resource): - @api.doc(params={'age': {'description': 'Overriden'}}) + @api.doc(params={"age": {"description": "Overriden"}}) def get(self, age): return {} @@ -1028,33 +1009,31 @@ def post(self, age): return {} data = client.get_specs() - assert '/name/{age}/' in data['paths'] + assert "/name/{age}/" in data["paths"] - path = data['paths']['/name/{age}/'] - assert 'parameters' not in path + path = data["paths"]["/name/{age}/"] + assert "parameters" not in path - get = path['get'] - assert len(get['parameters']) == 2 + get = path["get"] + assert len(get["parameters"]) == 2 - by_name = dict((p['name'], p) for p in get['parameters']) - assert 'age' in by_name - assert 'q' in by_name + by_name = dict((p["name"], p) for p in get["parameters"]) + assert "age" in by_name + assert "q" in by_name - assert 'post' not in path + assert "post" not in path def test_explicit_parameters_desription_shortcut(self, api, client): - @api.route('/name//', endpoint='by-name', doc={ - 'get': { - 'params': { - 'q': 'A query string', - } + @api.route( + "/name//", + endpoint="by-name", + doc={ + "get": {"params": {"q": "A query string",}}, + "params": {"age": "An age"}, }, - 'params': { - 'age': 'An age' - } - }) + ) class ByNameResource(restx.Resource): - @api.doc(params={'age': 'Overriden'}) + @api.doc(params={"age": "Overriden"}) def get(self, age): return {} @@ -1062,395 +1041,360 @@ def post(self, age): return {} data = client.get_specs() - assert '/name/{age}/' in data['paths'] + assert "/name/{age}/" in data["paths"] - path = data['paths']['/name/{age}/'] - assert 'parameters' not in path + path = data["paths"]["/name/{age}/"] + assert "parameters" not in path - get = path['get'] - assert len(get['parameters']) == 2 + get = path["get"] + assert len(get["parameters"]) == 2 - by_name = dict((p['name'], p) for p in get['parameters']) + by_name = dict((p["name"], p) for p in get["parameters"]) - parameter = by_name['age'] - assert parameter['name'] == 'age' - assert parameter['type'] == 'integer' - assert parameter['in'] == 'path' - assert parameter['required'] is True - assert parameter['description'] == 'Overriden' + parameter = by_name["age"] + assert parameter["name"] == "age" + assert parameter["type"] == "integer" + assert parameter["in"] == "path" + assert parameter["required"] is True + assert parameter["description"] == "Overriden" - parameter = by_name['q'] - assert parameter['name'] == 'q' - assert parameter['type'] == 'string' - assert parameter['in'] == 'query' - assert parameter['description'] == 'A query string' + parameter = by_name["q"] + assert parameter["name"] == "q" + assert parameter["type"] == "string" + assert parameter["in"] == "query" + assert parameter["description"] == "A query string" - post = path['post'] - assert len(post['parameters']) == 1 + post = path["post"] + assert len(post["parameters"]) == 1 - by_name = dict((p['name'], p) for p in post['parameters']) + by_name = dict((p["name"], p) for p in post["parameters"]) - parameter = by_name['age'] - assert parameter['name'] == 'age' - assert parameter['type'] == 'integer' - assert parameter['in'] == 'path' - assert parameter['required'] is True - assert parameter['description'] == 'An age' + parameter = by_name["age"] + assert parameter["name"] == "age" + assert parameter["type"] == "integer" + assert parameter["in"] == "path" + assert parameter["required"] is True + assert parameter["description"] == "An age" - assert 'q' not in by_name + assert "q" not in by_name def test_explicit_parameters_native_types(self, api, client): - @api.route('/types/', endpoint='native') + @api.route("/types/", endpoint="native") class NativeTypesResource(restx.Resource): - @api.doc(params={ - 'int': { - 'type': int, - 'in': 'query', - }, - 'float': { - 'type': float, - 'in': 'query', - }, - 'bool': { - 'type': bool, - 'in': 'query', - }, - 'str': { - 'type': str, - 'in': 'query', - }, - 'int-array': { - 'type': [int], - 'in': 'query', - }, - 'float-array': { - 'type': [float], - 'in': 'query', - }, - 'bool-array': { - 'type': [bool], - 'in': 'query', - }, - 'str-array': { - 'type': [str], - 'in': 'query', + @api.doc( + params={ + "int": {"type": int, "in": "query",}, + "float": {"type": float, "in": "query",}, + "bool": {"type": bool, "in": "query",}, + "str": {"type": str, "in": "query",}, + "int-array": {"type": [int], "in": "query",}, + "float-array": {"type": [float], "in": "query",}, + "bool-array": {"type": [bool], "in": "query",}, + "str-array": {"type": [str], "in": "query",}, } - }) + ) def get(self, age): return {} data = client.get_specs() - op = data['paths']['/types/']['get'] + op = data["paths"]["/types/"]["get"] - parameters = dict((p['name'], p) for p in op['parameters']) + parameters = dict((p["name"], p) for p in op["parameters"]) - assert parameters['int']['type'] == 'integer' - assert parameters['float']['type'] == 'number' - assert parameters['str']['type'] == 'string' - assert parameters['bool']['type'] == 'boolean' + assert parameters["int"]["type"] == "integer" + assert parameters["float"]["type"] == "number" + assert parameters["str"]["type"] == "string" + assert parameters["bool"]["type"] == "boolean" - assert parameters['int-array']['type'] == 'array' - assert parameters['int-array']['items']['type'] == 'integer' - assert parameters['float-array']['type'] == 'array' - assert parameters['float-array']['items']['type'] == 'number' - assert parameters['str-array']['type'] == 'array' - assert parameters['str-array']['items']['type'] == 'string' - assert parameters['bool-array']['type'] == 'array' - assert parameters['bool-array']['items']['type'] == 'boolean' + assert parameters["int-array"]["type"] == "array" + assert parameters["int-array"]["items"]["type"] == "integer" + assert parameters["float-array"]["type"] == "array" + assert parameters["float-array"]["items"]["type"] == "number" + assert parameters["str-array"]["type"] == "array" + assert parameters["str-array"]["items"]["type"] == "string" + assert parameters["bool-array"]["type"] == "array" + assert parameters["bool-array"]["items"]["type"] == "boolean" def test_response_on_method(self, api, client): - api.model('ErrorModel', { - 'message': restx.fields.String, - }) + api.model("ErrorModel", {"message": restx.fields.String,}) - @api.route('/test/') + @api.route("/test/") class ByNameResource(restx.Resource): - @api.doc(responses={ - 404: 'Not found', - 405: ('Some message', 'ErrorModel'), - }) + @api.doc( + responses={404: "Not found", 405: ("Some message", "ErrorModel"),} + ) def get(self): return {} - data = client.get_specs('') - paths = data['paths'] + data = client.get_specs("") + paths = data["paths"] assert len(paths.keys()) == 1 - op = paths['/test/']['get'] - assert op['tags'] == ['default'] - assert op['responses'] == { - '404': { - 'description': 'Not found', + op = paths["/test/"]["get"] + assert op["tags"] == ["default"] + assert op["responses"] == { + "404": {"description": "Not found",}, + "405": { + "description": "Some message", + "schema": {"$ref": "#/definitions/ErrorModel",}, }, - '405': { - 'description': 'Some message', - 'schema': { - '$ref': '#/definitions/ErrorModel', - } - } } - assert 'definitions' in data - assert 'ErrorModel' in data['definitions'] + assert "definitions" in data + assert "ErrorModel" in data["definitions"] def test_api_response(self, api, client): - @api.route('/test/') + @api.route("/test/") class TestResource(restx.Resource): - - @api.response(200, 'Success') + @api.response(200, "Success") def get(self): pass - data = client.get_specs('') - paths = data['paths'] + data = client.get_specs("") + paths = data["paths"] - op = paths['/test/']['get'] - assert op['responses'] == { - '200': { - 'description': 'Success', - } - } + op = paths["/test/"]["get"] + assert op["responses"] == {"200": {"description": "Success",}} def test_api_response_multiple(self, api, client): - @api.route('/test/') + @api.route("/test/") class TestResource(restx.Resource): - - @api.response(200, 'Success') - @api.response(400, 'Validation error') + @api.response(200, "Success") + @api.response(400, "Validation error") def get(self): pass - data = client.get_specs('') - paths = data['paths'] + data = client.get_specs("") + paths = data["paths"] - op = paths['/test/']['get'] - assert op['responses'] == { - '200': { - 'description': 'Success', - }, - '400': { - 'description': 'Validation error', - } + op = paths["/test/"]["get"] + assert op["responses"] == { + "200": {"description": "Success",}, + "400": {"description": "Validation error",}, } def test_api_response_with_model(self, api, client): - model = api.model('SomeModel', { - 'message': restx.fields.String, - }) + model = api.model("SomeModel", {"message": restx.fields.String,}) - @api.route('/test/') + @api.route("/test/") class TestResource(restx.Resource): - - @api.response(200, 'Success', model) + @api.response(200, "Success", model) def get(self): pass - data = client.get_specs('') - paths = data['paths'] + data = client.get_specs("") + paths = data["paths"] - op = paths['/test/']['get'] - assert op['responses'] == { - '200': { - 'description': 'Success', - 'schema': { - '$ref': '#/definitions/SomeModel', - } + op = paths["/test/"]["get"] + assert op["responses"] == { + "200": { + "description": "Success", + "schema": {"$ref": "#/definitions/SomeModel",}, } } - assert 'SomeModel' in data['definitions'] + assert "SomeModel" in data["definitions"] def test_api_response_default(self, api, client): - @api.route('/test/') + @api.route("/test/") class TestResource(restx.Resource): - - @api.response('default', 'Error') + @api.response("default", "Error") def get(self): pass - data = client.get_specs('') - paths = data['paths'] + data = client.get_specs("") + paths = data["paths"] - op = paths['/test/']['get'] - assert op['responses'] == { - 'default': { - 'description': 'Error', - } - } + op = paths["/test/"]["get"] + assert op["responses"] == {"default": {"description": "Error",}} def test_api_header(self, api, client): - @api.route('/test/') - @api.header('X-HEADER', 'A class header') + @api.route("/test/") + @api.header("X-HEADER", "A class header") class TestResource(restx.Resource): - - @api.header('X-HEADER-2', 'Another header', type=[int], collectionFormat='csv') - @api.header('X-HEADER-3', type=int) - @api.header('X-HEADER-4', type='boolean') + @api.header( + "X-HEADER-2", "Another header", type=[int], collectionFormat="csv" + ) + @api.header("X-HEADER-3", type=int) + @api.header("X-HEADER-4", type="boolean") def get(self): pass - data = client.get_specs('') - headers = data['paths']['/test/']['get']['responses']['200']['headers'] + data = client.get_specs("") + headers = data["paths"]["/test/"]["get"]["responses"]["200"]["headers"] - assert 'X-HEADER' in headers - assert headers['X-HEADER'] == { - 'type': 'string', - 'description': 'A class header', + assert "X-HEADER" in headers + assert headers["X-HEADER"] == { + "type": "string", + "description": "A class header", } - assert 'X-HEADER-2' in headers - assert headers['X-HEADER-2'] == { - 'type': 'array', - 'items': {'type': 'integer'}, - 'description': 'Another header', - 'collectionFormat': 'csv', + assert "X-HEADER-2" in headers + assert headers["X-HEADER-2"] == { + "type": "array", + "items": {"type": "integer"}, + "description": "Another header", + "collectionFormat": "csv", } - assert 'X-HEADER-3' in headers - assert headers['X-HEADER-3'] == {'type': 'integer'} + assert "X-HEADER-3" in headers + assert headers["X-HEADER-3"] == {"type": "integer"} - assert 'X-HEADER-4' in headers - assert headers['X-HEADER-4'] == {'type': 'boolean'} + assert "X-HEADER-4" in headers + assert headers["X-HEADER-4"] == {"type": "boolean"} def test_response_header(self, api, client): - @api.route('/test/') + @api.route("/test/") class TestResource(restx.Resource): - @api.response(200, 'Success') - @api.response(400, 'Validation', headers={'X-HEADER': 'An header'}) + @api.response(200, "Success") + @api.response(400, "Validation", headers={"X-HEADER": "An header"}) def get(self): pass - data = client.get_specs('') - headers = data['paths']['/test/']['get']['responses']['400']['headers'] + data = client.get_specs("") + headers = data["paths"]["/test/"]["get"]["responses"]["400"]["headers"] - assert 'X-HEADER' in headers - assert headers['X-HEADER'] == { - 'type': 'string', - 'description': 'An header', + assert "X-HEADER" in headers + assert headers["X-HEADER"] == { + "type": "string", + "description": "An header", } def test_api_and_response_header(self, api, client): - @api.route('/test/') - @api.header('X-HEADER', 'A class header') + @api.route("/test/") + @api.header("X-HEADER", "A class header") class TestResource(restx.Resource): - - @api.header('X-HEADER-2', type=int) - @api.response(200, 'Success') - @api.response(400, 'Validation', headers={'X-ERROR': 'An error header'}) + @api.header("X-HEADER-2", type=int) + @api.response(200, "Success") + @api.response(400, "Validation", headers={"X-ERROR": "An error header"}) def get(self): pass - data = client.get_specs('') - headers200 = data['paths']['/test/']['get']['responses']['200']['headers'] - headers400 = data['paths']['/test/']['get']['responses']['400']['headers'] + data = client.get_specs("") + headers200 = data["paths"]["/test/"]["get"]["responses"]["200"]["headers"] + headers400 = data["paths"]["/test/"]["get"]["responses"]["400"]["headers"] for headers in (headers200, headers400): - assert 'X-HEADER' in headers - assert 'X-HEADER-2' in headers + assert "X-HEADER" in headers + assert "X-HEADER-2" in headers - assert 'X-ERROR' in headers400 - assert 'X-ERROR' not in headers200 + assert "X-ERROR" in headers400 + assert "X-ERROR" not in headers200 def test_expect_header(self, api, client): parser = api.parser() - parser.add_argument('X-Header', location='headers', required=True, help='A required header') - parser.add_argument('X-Header-2', location='headers', type=int, action='split', help='Another header') - parser.add_argument('X-Header-3', location='headers', type=int) - parser.add_argument('X-Header-4', location='headers', type=inputs.boolean) + parser.add_argument( + "X-Header", location="headers", required=True, help="A required header" + ) + parser.add_argument( + "X-Header-2", + location="headers", + type=int, + action="split", + help="Another header", + ) + parser.add_argument("X-Header-3", location="headers", type=int) + parser.add_argument("X-Header-4", location="headers", type=inputs.boolean) - @api.route('/test/') + @api.route("/test/") class TestResource(restx.Resource): - @api.expect(parser) def get(self): pass - data = client.get_specs('') - parameters = data['paths']['/test/']['get']['parameters'] + data = client.get_specs("") + parameters = data["paths"]["/test/"]["get"]["parameters"] def get_param(name): - candidates = [p for p in parameters if p['name'] == name] - assert len(candidates) == 1, 'parameter {0} not found'.format(name) + candidates = [p for p in parameters if p["name"] == name] + assert len(candidates) == 1, "parameter {0} not found".format(name) return candidates[0] - parameter = get_param('X-Header') - assert parameter['type'] == 'string' - assert parameter['in'] == 'header' - assert parameter['required'] is True - assert parameter['description'] == 'A required header' + parameter = get_param("X-Header") + assert parameter["type"] == "string" + assert parameter["in"] == "header" + assert parameter["required"] is True + assert parameter["description"] == "A required header" - parameter = get_param('X-Header-2') - assert parameter['type'] == 'array' - assert parameter['in'] == 'header' - assert parameter['items']['type'] == 'integer' - assert parameter['description'] == 'Another header' - assert parameter['collectionFormat'] == 'csv' + parameter = get_param("X-Header-2") + assert parameter["type"] == "array" + assert parameter["in"] == "header" + assert parameter["items"]["type"] == "integer" + assert parameter["description"] == "Another header" + assert parameter["collectionFormat"] == "csv" - parameter = get_param('X-Header-3') - assert parameter['type'] == 'integer' - assert parameter['in'] == 'header' + parameter = get_param("X-Header-3") + assert parameter["type"] == "integer" + assert parameter["in"] == "header" - parameter = get_param('X-Header-4') - assert parameter['type'] == 'boolean' - assert parameter['in'] == 'header' + parameter = get_param("X-Header-4") + assert parameter["type"] == "boolean" + assert parameter["in"] == "header" def test_description(self, api, client): - @api.route('/description/', endpoint='description', doc={ - 'description': 'Parent description.', - 'delete': {'description': 'A delete operation'}, - }) + @api.route( + "/description/", + endpoint="description", + doc={ + "description": "Parent description.", + "delete": {"description": "A delete operation"}, + }, + ) class ResourceWithDescription(restx.Resource): - @api.doc(description='Some details') + @api.doc(description="Some details") def get(self): return {} def post(self): - ''' + """ Do something. Extra description - ''' + """ return {} def put(self): - '''No description (only summary)''' + """No description (only summary)""" def delete(self): - '''No description (only summary)''' + """No description (only summary)""" - @api.route('/descriptionless/', endpoint='descriptionless') + @api.route("/descriptionless/", endpoint="descriptionless") class ResourceWithoutDescription(restx.Resource): def get(self): - '''No description (only summary)''' + """No description (only summary)""" return {} data = client.get_specs() - description = lambda m: data['paths']['/description/'][m]['description'] # noqa + description = lambda m: data["paths"]["/description/"][m]["description"] # noqa - assert description('get') == dedent('''\ + assert description("get") == dedent( + """\ Parent description. - Some details''' + Some details""" ) - assert description('post') == dedent('''\ + assert description("post") == dedent( + """\ Parent description. - Extra description''' + Extra description""" ) - assert description('delete') == dedent('''\ + assert description("delete") == dedent( + """\ Parent description. - A delete operation''' + A delete operation""" ) - assert description('put') == 'Parent description.' - assert 'description' not in data['paths']['/descriptionless/']['get'] + assert description("put") == "Parent description." + assert "description" not in data["paths"]["/descriptionless/"]["get"] def test_operation_id(self, api, client): - @api.route('/test/', endpoint='test') + @api.route("/test/", endpoint="test") class TestResource(restx.Resource): - @api.doc(id='get_objects') + @api.doc(id="get_objects") def get(self): return {} @@ -1458,32 +1402,32 @@ def post(self): return {} data = client.get_specs() - path = data['paths']['/test/'] + path = data["paths"]["/test/"] - assert path['get']['operationId'] == 'get_objects' - assert path['post']['operationId'] == 'post_test_resource' + assert path["get"]["operationId"] == "get_objects" + assert path["post"]["operationId"] == "post_test_resource" def test_operation_id_shortcut(self, api, client): - @api.route('/test/', endpoint='test') + @api.route("/test/", endpoint="test") class TestResource(restx.Resource): - @api.doc('get_objects') + @api.doc("get_objects") def get(self): return {} data = client.get_specs() - path = data['paths']['/test/'] + path = data["paths"]["/test/"] - assert path['get']['operationId'] == 'get_objects' + assert path["get"]["operationId"] == "get_objects" def test_custom_default_operation_id(self, app, client): def default_id(resource, method): - return '{0}{1}'.format(method, resource) + return "{0}{1}".format(method, resource) api = restx.Api(app, default_id=default_id) - @api.route('/test/', endpoint='test') + @api.route("/test/", endpoint="test") class TestResource(restx.Resource): - @api.doc(id='get_objects') + @api.doc(id="get_objects") def get(self): return {} @@ -1491,16 +1435,16 @@ def post(self): return {} data = client.get_specs() - path = data['paths']['/test/'] + path = data["paths"]["/test/"] - assert path['get']['operationId'] == 'get_objects' - assert path['post']['operationId'] == 'postTestResource' + assert path["get"]["operationId"] == "get_objects" + assert path["post"]["operationId"] == "postTestResource" - @pytest.mark.api(default_id=lambda r, m: '{0}{1}'.format(m, r)) + @pytest.mark.api(default_id=lambda r, m: "{0}{1}".format(m, r)) def test_custom_default_operation_id_blueprint(self, api, client): - @api.route('/test/', endpoint='test') + @api.route("/test/", endpoint="test") class TestResource(restx.Resource): - @api.doc(id='get_objects') + @api.doc(id="get_objects") def get(self): return {} @@ -1508,13 +1452,13 @@ def post(self): return {} data = client.get_specs() - path = data['paths']['/test/'] + path = data["paths"]["/test/"] - assert path['get']['operationId'] == 'get_objects' - assert path['post']['operationId'] == 'postTestResource' + assert path["get"]["operationId"] == "get_objects" + assert path["post"]["operationId"] == "postTestResource" def test_model_primitive_types(self, api, client): - @api.route('/model-int/') + @api.route("/model-int/") class ModelInt(restx.Resource): @api.doc(model=int) def get(self): @@ -1522,140 +1466,136 @@ def get(self): data = client.get_specs() - assert 'definitions' not in data - assert data['paths']['/model-int/']['get']['responses'] == { - '200': { - 'description': 'Success', - 'schema': { - 'type': 'integer' - } - } + assert "definitions" not in data + assert data["paths"]["/model-int/"]["get"]["responses"] == { + "200": {"description": "Success", "schema": {"type": "integer"}} } def test_model_as_flat_dict(self, api, client): - fields = api.model('Person', { - 'name': restx.fields.String, - 'age': restx.fields.Integer, - 'birthdate': restx.fields.DateTime, - }) + fields = api.model( + "Person", + { + "name": restx.fields.String, + "age": restx.fields.Integer, + "birthdate": restx.fields.DateTime, + }, + ) - @api.route('/model-as-dict/') + @api.route("/model-as-dict/") class ModelAsDict(restx.Resource): @api.doc(model=fields) def get(self): return {} - @api.doc(model='Person') + @api.doc(model="Person") def post(self): return {} data = client.get_specs() - assert 'definitions' in data - assert 'Person' in data['definitions'] + assert "definitions" in data + assert "Person" in data["definitions"] - path = data['paths']['/model-as-dict/'] - assert path['get']['responses']['200']['schema']['$ref'] == '#/definitions/Person' - assert path['post']['responses']['200']['schema']['$ref'] == '#/definitions/Person' + path = data["paths"]["/model-as-dict/"] + assert ( + path["get"]["responses"]["200"]["schema"]["$ref"] == "#/definitions/Person" + ) + assert ( + path["post"]["responses"]["200"]["schema"]["$ref"] == "#/definitions/Person" + ) def test_model_as_nested_dict(self, api, client): - address_fields = api.model('Address', { - 'road': restx.fields.String, - }) + address_fields = api.model("Address", {"road": restx.fields.String,}) - fields = api.model('Person', { - 'address': restx.fields.Nested(address_fields) - }) + fields = api.model("Person", {"address": restx.fields.Nested(address_fields)}) - @api.route('/model-as-dict/') + @api.route("/model-as-dict/") class ModelAsDict(restx.Resource): @api.doc(model=fields) def get(self): return {} - @api.doc(model='Person') + @api.doc(model="Person") def post(self): return {} data = client.get_specs() - assert 'definitions' in data - assert 'Person' in data['definitions'] - assert data['definitions']['Person'] == { - 'properties': { - 'address': { - '$ref': '#/definitions/Address' - }, - }, - 'type': 'object' + assert "definitions" in data + assert "Person" in data["definitions"] + assert data["definitions"]["Person"] == { + "properties": {"address": {"$ref": "#/definitions/Address"},}, + "type": "object", } - assert 'Address' in data['definitions'] - assert data['definitions']['Address'] == { - 'properties': { - 'road': { - 'type': 'string' - }, - }, - 'type': 'object' + assert "Address" in data["definitions"] + assert data["definitions"]["Address"] == { + "properties": {"road": {"type": "string"},}, + "type": "object", } - path = data['paths']['/model-as-dict/'] - assert path['get']['responses']['200']['schema']['$ref'] == '#/definitions/Person' - assert path['post']['responses']['200']['schema']['$ref'] == '#/definitions/Person' + path = data["paths"]["/model-as-dict/"] + assert ( + path["get"]["responses"]["200"]["schema"]["$ref"] == "#/definitions/Person" + ) + assert ( + path["post"]["responses"]["200"]["schema"]["$ref"] == "#/definitions/Person" + ) def test_model_as_nested_dict_with_details(self, api, client): - address_fields = api.model('Address', { - 'road': restx.fields.String, - }) - - fields = api.model('Person', { - 'address': restx.fields.Nested(address_fields, description='description', readonly=True) - }) + address_fields = api.model("Address", {"road": restx.fields.String,}) + + fields = api.model( + "Person", + { + "address": restx.fields.Nested( + address_fields, description="description", readonly=True + ) + }, + ) - @api.route('/model-as-dict/') + @api.route("/model-as-dict/") class ModelAsDict(restx.Resource): @api.doc(model=fields) def get(self): return {} - @api.doc(model='Person') + @api.doc(model="Person") def post(self): return {} data = client.get_specs() - assert 'definitions' in data - assert 'Person' in data['definitions'] - assert data['definitions']['Person'] == { - 'properties': { - 'address': { - 'description': 'description', - 'readOnly': True, - 'allOf': [{'$ref': '#/definitions/Address'}] + assert "definitions" in data + assert "Person" in data["definitions"] + assert data["definitions"]["Person"] == { + "properties": { + "address": { + "description": "description", + "readOnly": True, + "allOf": [{"$ref": "#/definitions/Address"}], }, }, - 'type': 'object' + "type": "object", } - assert 'Address' in data['definitions'] - assert data['definitions']['Address'] == { - 'properties': { - 'road': { - 'type': 'string' - }, - }, - 'type': 'object' + assert "Address" in data["definitions"] + assert data["definitions"]["Address"] == { + "properties": {"road": {"type": "string"},}, + "type": "object", } def test_model_as_flat_dict_with_marchal_decorator(self, api, client): - fields = api.model('Person', { - 'name': restx.fields.String, - 'age': restx.fields.Integer, - 'birthdate': restx.fields.DateTime, - }) + fields = api.model( + "Person", + { + "name": restx.fields.String, + "age": restx.fields.Integer, + "birthdate": restx.fields.DateTime, + }, + ) - @api.route('/model-as-dict/') + @api.route("/model-as-dict/") class ModelAsDict(restx.Resource): @api.marshal_with(fields) def get(self): @@ -1663,26 +1603,23 @@ def get(self): data = client.get_specs() - assert 'definitions' in data - assert 'Person' in data['definitions'] + assert "definitions" in data + assert "Person" in data["definitions"] - responses = data['paths']['/model-as-dict/']['get']['responses'] + responses = data["paths"]["/model-as-dict/"]["get"]["responses"] assert responses == { - '200': { - 'description': 'Success', - 'schema': { - '$ref': '#/definitions/Person' - } + "200": { + "description": "Success", + "schema": {"$ref": "#/definitions/Person"}, } } def test_model_with_non_uri_chars_in_name(self, api, client): # name will be encoded as 'Person%2F%2F%3Flots%7B%7D%20of%20%26illegals%40%60' - name = 'Person//?lots{} of &illegals@`' - fields = api.model(name, { - }) + name = "Person//?lots{} of &illegals@`" + fields = api.model(name, {}) - @api.route('/model-bad-uri/') + @api.route("/model-bad-uri/") class ModelBadUri(restx.Resource): @api.doc(model=fields) def get(self): @@ -1694,23 +1631,30 @@ def post(self): data = client.get_specs() - assert 'definitions' in data - assert name in data['definitions'] + assert "definitions" in data + assert name in data["definitions"] - path = data['paths']['/model-bad-uri/'] - assert path['get']['responses']['200']['schema']['$ref'] == \ - '#/definitions/Person%2F%2F%3Flots%7B%7D%20of%20%26illegals%40%60' - assert path['post']['responses']['201']['schema']['$ref'] == \ - '#/definitions/Person%2F%2F%3Flots%7B%7D%20of%20%26illegals%40%60' + path = data["paths"]["/model-bad-uri/"] + assert ( + path["get"]["responses"]["200"]["schema"]["$ref"] + == "#/definitions/Person%2F%2F%3Flots%7B%7D%20of%20%26illegals%40%60" + ) + assert ( + path["post"]["responses"]["201"]["schema"]["$ref"] + == "#/definitions/Person%2F%2F%3Flots%7B%7D%20of%20%26illegals%40%60" + ) def test_marchal_decorator_with_code(self, api, client): - fields = api.model('Person', { - 'name': restx.fields.String, - 'age': restx.fields.Integer, - 'birthdate': restx.fields.DateTime, - }) + fields = api.model( + "Person", + { + "name": restx.fields.String, + "age": restx.fields.Integer, + "birthdate": restx.fields.DateTime, + }, + ) - @api.route('/model-as-dict/') + @api.route("/model-as-dict/") class ModelAsDict(restx.Resource): @api.marshal_with(fields, code=204) def delete(self): @@ -1718,87 +1662,86 @@ def delete(self): data = client.get_specs() - assert 'definitions' in data - assert 'Person' in data['definitions'] + assert "definitions" in data + assert "Person" in data["definitions"] - responses = data['paths']['/model-as-dict/']['delete']['responses'] + responses = data["paths"]["/model-as-dict/"]["delete"]["responses"] assert responses == { - '204': { - 'description': 'Success', - 'schema': { - '$ref': '#/definitions/Person' - } + "204": { + "description": "Success", + "schema": {"$ref": "#/definitions/Person"}, } } def test_marchal_decorator_with_description(self, api, client): - person = api.model('Person', { - 'name': restx.fields.String, - 'age': restx.fields.Integer, - 'birthdate': restx.fields.DateTime, - }) + person = api.model( + "Person", + { + "name": restx.fields.String, + "age": restx.fields.Integer, + "birthdate": restx.fields.DateTime, + }, + ) - @api.route('/model-as-dict/') + @api.route("/model-as-dict/") class ModelAsDict(restx.Resource): - @api.marshal_with(person, description='Some details') + @api.marshal_with(person, description="Some details") def get(self): return {} data = client.get_specs() - assert 'definitions' in data - assert 'Person' in data['definitions'] + assert "definitions" in data + assert "Person" in data["definitions"] - responses = data['paths']['/model-as-dict/']['get']['responses'] + responses = data["paths"]["/model-as-dict/"]["get"]["responses"] assert responses == { - '200': { - 'description': 'Some details', - 'schema': { - '$ref': '#/definitions/Person' - } + "200": { + "description": "Some details", + "schema": {"$ref": "#/definitions/Person"}, } } def test_marhsal_decorator_with_envelope(self, api, client): - person = api.model('Person', { - 'name': restx.fields.String, - 'age': restx.fields.Integer, - 'birthdate': restx.fields.DateTime, - }) + person = api.model( + "Person", + { + "name": restx.fields.String, + "age": restx.fields.Integer, + "birthdate": restx.fields.DateTime, + }, + ) - @api.route('/model-as-dict/') + @api.route("/model-as-dict/") class ModelAsDict(restx.Resource): - @api.marshal_with(person, envelope='person') + @api.marshal_with(person, envelope="person") def get(self): return {} data = client.get_specs() - assert 'definitions' in data - assert 'Person' in data['definitions'] + assert "definitions" in data + assert "Person" in data["definitions"] - responses = data['paths']['/model-as-dict/']['get']['responses'] + responses = data["paths"]["/model-as-dict/"]["get"]["responses"] assert responses == { - '200': { - 'description': 'Success', - 'schema': { - 'properties': { - 'person': { - '$ref': '#/definitions/Person' - } - } - } + "200": { + "description": "Success", + "schema": {"properties": {"person": {"$ref": "#/definitions/Person"}}}, } } def test_model_as_flat_dict_with_marchal_decorator_list(self, api, client): - fields = api.model('Person', { - 'name': restx.fields.String, - 'age': restx.fields.Integer, - 'birthdate': restx.fields.DateTime, - }) + fields = api.model( + "Person", + { + "name": restx.fields.String, + "age": restx.fields.Integer, + "birthdate": restx.fields.DateTime, + }, + ) - @api.route('/model-as-dict/') + @api.route("/model-as-dict/") class ModelAsDict(restx.Resource): @api.marshal_with(fields, as_list=True) def get(self): @@ -1806,38 +1749,34 @@ def get(self): data = client.get_specs() - assert 'definitions' in data - assert 'Person' in data['definitions'] - assert data['definitions']['Person'] == { - 'properties': { - 'name': { - 'type': 'string' - }, - 'age': { - 'type': 'integer' - }, - 'birthdate': { - 'type': 'string', - 'format': 'date-time' - } + assert "definitions" in data + assert "Person" in data["definitions"] + assert data["definitions"]["Person"] == { + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + "birthdate": {"type": "string", "format": "date-time"}, }, - 'type': 'object' + "type": "object", } - path = data['paths']['/model-as-dict/'] - assert path['get']['responses']['200']['schema'] == { - 'type': 'array', - 'items': {'$ref': '#/definitions/Person'}, + path = data["paths"]["/model-as-dict/"] + assert path["get"]["responses"]["200"]["schema"] == { + "type": "array", + "items": {"$ref": "#/definitions/Person"}, } def test_model_as_flat_dict_with_marchal_decorator_list_alt(self, api, client): - fields = api.model('Person', { - 'name': restx.fields.String, - 'age': restx.fields.Integer, - 'birthdate': restx.fields.DateTime, - }) + fields = api.model( + "Person", + { + "name": restx.fields.String, + "age": restx.fields.Integer, + "birthdate": restx.fields.DateTime, + }, + ) - @api.route('/model-as-dict/') + @api.route("/model-as-dict/") class ModelAsDict(restx.Resource): @api.marshal_list_with(fields) def get(self): @@ -1845,52 +1784,55 @@ def get(self): data = client.get_specs() - assert 'definitions' in data - assert 'Person' in data['definitions'] + assert "definitions" in data + assert "Person" in data["definitions"] - path = data['paths']['/model-as-dict/'] - assert path['get']['responses']['200']['schema'] == { - 'type': 'array', - 'items': {'$ref': '#/definitions/Person'}, + path = data["paths"]["/model-as-dict/"] + assert path["get"]["responses"]["200"]["schema"] == { + "type": "array", + "items": {"$ref": "#/definitions/Person"}, } def test_model_as_flat_dict_with_marchal_decorator_list_kwargs(self, api, client): - fields = api.model('Person', { - 'name': restx.fields.String, - 'age': restx.fields.Integer, - 'birthdate': restx.fields.DateTime, - }) + fields = api.model( + "Person", + { + "name": restx.fields.String, + "age": restx.fields.Integer, + "birthdate": restx.fields.DateTime, + }, + ) - @api.route('/model-as-dict/') + @api.route("/model-as-dict/") class ModelAsDict(restx.Resource): - @api.marshal_list_with(fields, code=201, description='Some details') + @api.marshal_list_with(fields, code=201, description="Some details") def get(self): return {} data = client.get_specs() - assert 'definitions' in data - assert 'Person' in data['definitions'] + assert "definitions" in data + assert "Person" in data["definitions"] - path = data['paths']['/model-as-dict/'] - assert path['get']['responses'] == { - '201': { - 'description': 'Some details', - 'schema': { - 'type': 'array', - 'items': {'$ref': '#/definitions/Person'}, - } + path = data["paths"]["/model-as-dict/"] + assert path["get"]["responses"] == { + "201": { + "description": "Some details", + "schema": {"type": "array", "items": {"$ref": "#/definitions/Person"},}, } } def test_model_as_dict_with_list(self, api, client): - fields = api.model('Person', { - 'name': restx.fields.String, - 'age': restx.fields.Integer, - 'tags': restx.fields.List(restx.fields.String), - }) + fields = api.model( + "Person", + { + "name": restx.fields.String, + "age": restx.fields.Integer, + "tags": restx.fields.List(restx.fields.String), + }, + ) - @api.route('/model-with-list/') + @api.route("/model-with-list/") class ModelAsDict(restx.Resource): @api.doc(model=fields) def get(self): @@ -1898,42 +1840,36 @@ def get(self): data = client.get_specs() - assert 'definitions' in data - assert 'Person' in data['definitions'] - assert data['definitions']['Person'] == { - 'properties': { - 'name': { - 'type': 'string' - }, - 'age': { - 'type': 'integer' - }, - 'tags': { - 'type': 'array', - 'items': { - 'type': 'string' - } - } + assert "definitions" in data + assert "Person" in data["definitions"] + assert data["definitions"]["Person"] == { + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + "tags": {"type": "array", "items": {"type": "string"}}, }, - 'type': 'object' + "type": "object", } - path = data['paths']['/model-with-list/'] - assert path['get']['responses']['200']['schema'] == {'$ref': '#/definitions/Person'} + path = data["paths"]["/model-with-list/"] + assert path["get"]["responses"]["200"]["schema"] == { + "$ref": "#/definitions/Person" + } def test_model_as_nested_dict_with_list(self, api, client): - address = api.model('Address', { - 'road': restx.fields.String, - }) - - person = api.model('Person', { - 'name': restx.fields.String, - 'age': restx.fields.Integer, - 'birthdate': restx.fields.DateTime, - 'addresses': restx.fields.List(restx.fields.Nested(address)) - }) - - @api.route('/model-with-list/') + address = api.model("Address", {"road": restx.fields.String,}) + + person = api.model( + "Person", + { + "name": restx.fields.String, + "age": restx.fields.Integer, + "birthdate": restx.fields.DateTime, + "addresses": restx.fields.List(restx.fields.Nested(address)), + }, + ) + + @api.route("/model-with-list/") class ModelAsDict(restx.Resource): @api.doc(model=person) def get(self): @@ -1941,12 +1877,12 @@ def get(self): data = client.get_specs() - assert 'definitions' in data - assert 'Person' in data['definitions'] - assert 'Address' in data['definitions'] + assert "definitions" in data + assert "Person" in data["definitions"] + assert "Address" in data["definitions"] def test_model_list_of_primitive_types(self, api, client): - @api.route('/model-list/') + @api.route("/model-list/") class ModelAsDict(restx.Resource): @api.doc(model=[int]) def get(self): @@ -1958,55 +1894,61 @@ def post(self): data = client.get_specs() - assert 'definitions' not in data + assert "definitions" not in data - path = data['paths']['/model-list/'] - assert path['get']['responses']['200']['schema'] == { - 'type': 'array', - 'items': {'type': 'integer'}, + path = data["paths"]["/model-list/"] + assert path["get"]["responses"]["200"]["schema"] == { + "type": "array", + "items": {"type": "integer"}, } - assert path['post']['responses']['200']['schema'] == { - 'type': 'array', - 'items': {'type': 'string'}, + assert path["post"]["responses"]["200"]["schema"] == { + "type": "array", + "items": {"type": "string"}, } def test_model_list_as_flat_dict(self, api, client): - fields = api.model('Person', { - 'name': restx.fields.String, - 'age': restx.fields.Integer, - 'birthdate': restx.fields.DateTime, - }) + fields = api.model( + "Person", + { + "name": restx.fields.String, + "age": restx.fields.Integer, + "birthdate": restx.fields.DateTime, + }, + ) - @api.route('/model-as-dict/') + @api.route("/model-as-dict/") class ModelAsDict(restx.Resource): @api.doc(model=[fields]) def get(self): return {} - @api.doc(model=['Person']) + @api.doc(model=["Person"]) def post(self): return {} data = client.get_specs() - assert 'definitions' in data - assert 'Person' in data['definitions'] + assert "definitions" in data + assert "Person" in data["definitions"] - path = data['paths']['/model-as-dict/'] - for method in 'get', 'post': - assert path[method]['responses']['200']['schema'] == { - 'type': 'array', - 'items': {'$ref': '#/definitions/Person'}, + path = data["paths"]["/model-as-dict/"] + for method in "get", "post": + assert path[method]["responses"]["200"]["schema"] == { + "type": "array", + "items": {"$ref": "#/definitions/Person"}, } def test_model_doc_on_class(self, api, client): - fields = api.model('Person', { - 'name': restx.fields.String, - 'age': restx.fields.Integer, - 'birthdate': restx.fields.DateTime, - }) + fields = api.model( + "Person", + { + "name": restx.fields.String, + "age": restx.fields.Integer, + "birthdate": restx.fields.DateTime, + }, + ) - @api.route('/model-as-dict/') + @api.route("/model-as-dict/") @api.doc(model=fields) class ModelAsDict(restx.Resource): def get(self): @@ -2016,22 +1958,27 @@ def post(self): return {} data = client.get_specs() - assert 'definitions' in data - assert 'Person' in data['definitions'] + assert "definitions" in data + assert "Person" in data["definitions"] - path = data['paths']['/model-as-dict/'] - for method in 'get', 'post': - assert path[method]['responses']['200']['schema'] == {'$ref': '#/definitions/Person'} + path = data["paths"]["/model-as-dict/"] + for method in "get", "post": + assert path[method]["responses"]["200"]["schema"] == { + "$ref": "#/definitions/Person" + } def test_model_doc_for_method_on_class(self, api, client): - fields = api.model('Person', { - 'name': restx.fields.String, - 'age': restx.fields.Integer, - 'birthdate': restx.fields.DateTime, - }) - - @api.route('/model-as-dict/') - @api.doc(get={'model': fields}) + fields = api.model( + "Person", + { + "name": restx.fields.String, + "age": restx.fields.Integer, + "birthdate": restx.fields.DateTime, + }, + ) + + @api.route("/model-as-dict/") + @api.doc(get={"model": fields}) class ModelAsDict(restx.Resource): def get(self): return {} @@ -2040,20 +1987,25 @@ def post(self): return {} data = client.get_specs() - assert 'definitions' in data - assert 'Person' in data['definitions'] + assert "definitions" in data + assert "Person" in data["definitions"] - path = data['paths']['/model-as-dict/'] - assert path['get']['responses']['200']['schema'] == {'$ref': '#/definitions/Person'} - assert 'schema' not in path['post']['responses']['200'] + path = data["paths"]["/model-as-dict/"] + assert path["get"]["responses"]["200"]["schema"] == { + "$ref": "#/definitions/Person" + } + assert "schema" not in path["post"]["responses"]["200"] def test_model_with_discriminator(self, api, client): - fields = api.model('Person', { - 'name': restx.fields.String(discriminator=True), - 'age': restx.fields.Integer, - }) + fields = api.model( + "Person", + { + "name": restx.fields.String(discriminator=True), + "age": restx.fields.Integer, + }, + ) - @api.route('/model-with-discriminator/') + @api.route("/model-with-discriminator/") class ModelAsDict(restx.Resource): @api.marshal_with(fields) def get(self): @@ -2061,25 +2013,25 @@ def get(self): data = client.get_specs() - assert 'definitions' in data - assert 'Person' in data['definitions'] - assert data['definitions']['Person'] == { - 'properties': { - 'name': {'type': 'string'}, - 'age': {'type': 'integer'}, - }, - 'discriminator': 'name', - 'required': ['name'], - 'type': 'object' + assert "definitions" in data + assert "Person" in data["definitions"] + assert data["definitions"]["Person"] == { + "properties": {"name": {"type": "string"}, "age": {"type": "integer"},}, + "discriminator": "name", + "required": ["name"], + "type": "object", } def test_model_with_discriminator_override_require(self, api, client): - fields = api.model('Person', { - 'name': restx.fields.String(discriminator=True, required=False), - 'age': restx.fields.Integer, - }) + fields = api.model( + "Person", + { + "name": restx.fields.String(discriminator=True, required=False), + "age": restx.fields.Integer, + }, + ) - @api.route('/model-with-discriminator/') + @api.route("/model-with-discriminator/") class ModelAsDict(restx.Resource): @api.marshal_with(fields) def get(self): @@ -2087,215 +2039,187 @@ def get(self): data = client.get_specs() - assert 'definitions' in data - assert 'Person' in data['definitions'] - assert data['definitions']['Person'] == { - 'properties': { - 'name': {'type': 'string'}, - 'age': {'type': 'integer'}, - }, - 'discriminator': 'name', - 'required': ['name'], - 'type': 'object' + assert "definitions" in data + assert "Person" in data["definitions"] + assert data["definitions"]["Person"] == { + "properties": {"name": {"type": "string"}, "age": {"type": "integer"},}, + "discriminator": "name", + "required": ["name"], + "type": "object", } def test_model_not_found(self, api, client): - @api.route('/model-not-found/') + @api.route("/model-not-found/") class ModelAsDict(restx.Resource): - @api.doc(model='NotFound') + @api.doc(model="NotFound") def get(self): return {} client.get_specs(status=500) def test_specs_no_duplicate_response_keys(self, api, client): - ''' + """ This tests that the swagger.json document will not be written with duplicate object keys due to the coercion of dict keys to string. The last @api.response should win. - ''' + """ # Note the use of a strings '404' and '200' in class decorators as opposed to ints in method decorators. - @api.response('404', 'Not Found') + @api.response("404", "Not Found") class BaseResource(restx.Resource): def get(self): pass - model = api.model('SomeModel', { - 'message': restx.fields.String, - }) + model = api.model("SomeModel", {"message": restx.fields.String,}) - @api.route('/test/') - @api.response('200', 'Success') + @api.route("/test/") + @api.response("200", "Success") class TestResource(BaseResource): # @api.marshal_with also yields a response - @api.marshal_with(model, code=200, description='Success on method') - @api.response(404, 'Not Found on method') + @api.marshal_with(model, code=200, description="Success on method") + @api.response(404, "Not Found on method") def get(self): {} - data = client.get_specs('') - paths = data['paths'] + data = client.get_specs("") + paths = data["paths"] - op = paths['/test/']['get'] - print(op['responses']) - assert op['responses'] == { - '200': { - 'description': 'Success on method', - 'schema': { - '$ref': '#/definitions/SomeModel' - } + op = paths["/test/"]["get"] + print(op["responses"]) + assert op["responses"] == { + "200": { + "description": "Success on method", + "schema": {"$ref": "#/definitions/SomeModel"}, }, - '404': { - 'description': 'Not Found on method', - } + "404": {"description": "Not Found on method",}, } def test_clone(self, api, client): - parent = api.model('Person', { - 'name': restx.fields.String, - 'age': restx.fields.Integer, - 'birthdate': restx.fields.DateTime, - }) + parent = api.model( + "Person", + { + "name": restx.fields.String, + "age": restx.fields.Integer, + "birthdate": restx.fields.DateTime, + }, + ) - child = api.clone('Child', parent, { - 'extra': restx.fields.String, - }) + child = api.clone("Child", parent, {"extra": restx.fields.String,}) - @api.route('/extend/') + @api.route("/extend/") class ModelAsDict(restx.Resource): @api.doc(model=child) def get(self): return {} - @api.doc(model='Child') + @api.doc(model="Child") def post(self): return {} data = client.get_specs() - assert 'definitions' in data - assert 'Person' not in data['definitions'] - assert 'Child' in data['definitions'] + assert "definitions" in data + assert "Person" not in data["definitions"] + assert "Child" in data["definitions"] - path = data['paths']['/extend/'] - assert path['get']['responses']['200']['schema']['$ref'] == '#/definitions/Child' - assert path['post']['responses']['200']['schema']['$ref'] == '#/definitions/Child' + path = data["paths"]["/extend/"] + assert ( + path["get"]["responses"]["200"]["schema"]["$ref"] == "#/definitions/Child" + ) + assert ( + path["post"]["responses"]["200"]["schema"]["$ref"] == "#/definitions/Child" + ) def test_inherit(self, api, client): - parent = api.model('Person', { - 'name': restx.fields.String, - 'age': restx.fields.Integer, - }) + parent = api.model( + "Person", {"name": restx.fields.String, "age": restx.fields.Integer,} + ) - child = api.inherit('Child', parent, { - 'extra': restx.fields.String, - }) + child = api.inherit("Child", parent, {"extra": restx.fields.String,}) - @api.route('/inherit/') + @api.route("/inherit/") class ModelAsDict(restx.Resource): @api.marshal_with(child) def get(self): return { - 'name': 'John', - 'age': 42, - 'extra': 'test', + "name": "John", + "age": 42, + "extra": "test", } - @api.doc(model='Child') + @api.doc(model="Child") def post(self): return {} data = client.get_specs() - assert 'definitions' in data - assert 'Person' in data['definitions'] - assert 'Child' in data['definitions'] - assert data['definitions']['Person'] == { - 'properties': { - 'name': {'type': 'string'}, - 'age': {'type': 'integer'}, - }, - 'type': 'object' + assert "definitions" in data + assert "Person" in data["definitions"] + assert "Child" in data["definitions"] + assert data["definitions"]["Person"] == { + "properties": {"name": {"type": "string"}, "age": {"type": "integer"},}, + "type": "object", } - assert data['definitions']['Child'] == { - 'allOf': [{ - '$ref': '#/definitions/Person' - }, { - 'properties': { - 'extra': {'type': 'string'} - }, - 'type': 'object' - }] + assert data["definitions"]["Child"] == { + "allOf": [ + {"$ref": "#/definitions/Person"}, + {"properties": {"extra": {"type": "string"}}, "type": "object"}, + ] } - path = data['paths']['/inherit/'] - assert path['get']['responses']['200']['schema']['$ref'] == '#/definitions/Child' - assert path['post']['responses']['200']['schema']['$ref'] == '#/definitions/Child' + path = data["paths"]["/inherit/"] + assert ( + path["get"]["responses"]["200"]["schema"]["$ref"] == "#/definitions/Child" + ) + assert ( + path["post"]["responses"]["200"]["schema"]["$ref"] == "#/definitions/Child" + ) - data = client.get_json('/inherit/') + data = client.get_json("/inherit/") assert data == { - 'name': 'John', - 'age': 42, - 'extra': 'test', + "name": "John", + "age": 42, + "extra": "test", } def test_inherit_inline(self, api, client): - parent = api.model('Person', { - 'name': restx.fields.String, - 'age': restx.fields.Integer, - }) + parent = api.model( + "Person", {"name": restx.fields.String, "age": restx.fields.Integer,} + ) - child = api.inherit('Child', parent, { - 'extra': restx.fields.String, - }) + child = api.inherit("Child", parent, {"extra": restx.fields.String,}) - output = api.model('Output', { - 'child': restx.fields.Nested(child), - 'children': restx.fields.List(restx.fields.Nested(child)) - }) + output = api.model( + "Output", + { + "child": restx.fields.Nested(child), + "children": restx.fields.List(restx.fields.Nested(child)), + }, + ) - @api.route('/inherit/') + @api.route("/inherit/") class ModelAsDict(restx.Resource): @api.marshal_with(output) def get(self): return { - 'child': { - 'name': 'John', - 'age': 42, - 'extra': 'test', - }, - 'children': [{ - 'name': 'John', - 'age': 42, - 'extra': 'test', - }, { - 'name': 'Doe', - 'age': 33, - 'extra': 'test2', - }] + "child": {"name": "John", "age": 42, "extra": "test",}, + "children": [ + {"name": "John", "age": 42, "extra": "test",}, + {"name": "Doe", "age": 33, "extra": "test2",}, + ], } data = client.get_specs() - assert 'definitions' in data - assert 'Person' in data['definitions'] - assert 'Child' in data['definitions'] + assert "definitions" in data + assert "Person" in data["definitions"] + assert "Child" in data["definitions"] - data = client.get_json('/inherit/') + data = client.get_json("/inherit/") assert data == { - 'child': { - 'name': 'John', - 'age': 42, - 'extra': 'test', - }, - 'children': [{ - 'name': 'John', - 'age': 42, - 'extra': 'test', - }, { - 'name': 'Doe', - 'age': 33, - 'extra': 'test2', - }] + "child": {"name": "John", "age": 42, "extra": "test",}, + "children": [ + {"name": "John", "age": 42, "extra": "test",}, + {"name": "Doe", "age": 33, "extra": "test2",}, + ], } def test_polymorph_inherit(self, api, client): @@ -2305,29 +2229,22 @@ class Child1: class Child2: pass - parent = api.model('Person', { - 'name': restx.fields.String, - 'age': restx.fields.Integer, - }) + parent = api.model( + "Person", {"name": restx.fields.String, "age": restx.fields.Integer,} + ) - child1 = api.inherit('Child1', parent, { - 'extra1': restx.fields.String, - }) + child1 = api.inherit("Child1", parent, {"extra1": restx.fields.String,}) - child2 = api.inherit('Child2', parent, { - 'extra2': restx.fields.String, - }) + child2 = api.inherit("Child2", parent, {"extra2": restx.fields.String,}) mapping = { Child1: child1, Child2: child2, } - output = api.model('Output', { - 'child': restx.fields.Polymorph(mapping) - }) + output = api.model("Output", {"child": restx.fields.Polymorph(mapping)}) - @api.route('/polymorph/') + @api.route("/polymorph/") class ModelAsDict(restx.Resource): @api.marshal_with(output) def get(self): @@ -2335,83 +2252,79 @@ def get(self): data = client.get_specs() - assert 'definitions' in data - assert 'Person' in data['definitions'] - assert 'Child1' in data['definitions'] - assert 'Child2' in data['definitions'] - assert 'Output' in data['definitions'] + assert "definitions" in data + assert "Person" in data["definitions"] + assert "Child1" in data["definitions"] + assert "Child2" in data["definitions"] + assert "Output" in data["definitions"] - path = data['paths']['/polymorph/'] - assert path['get']['responses']['200']['schema']['$ref'] == '#/definitions/Output' + path = data["paths"]["/polymorph/"] + assert ( + path["get"]["responses"]["200"]["schema"]["$ref"] == "#/definitions/Output" + ) def test_polymorph_inherit_list(self, api, client): class Child1(object): - name = 'Child1' - extra1 = 'extra1' + name = "Child1" + extra1 = "extra1" class Child2(object): - name = 'Child2' - extra2 = 'extra2' + name = "Child2" + extra2 = "extra2" - parent = api.model('Person', { - 'name': restx.fields.String, - }) + parent = api.model("Person", {"name": restx.fields.String,}) - child1 = api.inherit('Child1', parent, { - 'extra1': restx.fields.String, - }) + child1 = api.inherit("Child1", parent, {"extra1": restx.fields.String,}) - child2 = api.inherit('Child2', parent, { - 'extra2': restx.fields.String, - }) + child2 = api.inherit("Child2", parent, {"extra2": restx.fields.String,}) mapping = { Child1: child1, Child2: child2, } - output = api.model('Output', { - 'children': restx.fields.List(restx.fields.Polymorph(mapping)) - }) + output = api.model( + "Output", {"children": restx.fields.List(restx.fields.Polymorph(mapping))} + ) - @api.route('/polymorph/') + @api.route("/polymorph/") class ModelAsDict(restx.Resource): @api.marshal_with(output) def get(self): - return { - 'children': [Child1(), Child2()] - } + return {"children": [Child1(), Child2()]} data = client.get_specs() - assert 'definitions' in data - assert 'Person' in data['definitions'] - assert 'Child1' in data['definitions'] - assert 'Child2' in data['definitions'] - assert 'Output' in data['definitions'] + assert "definitions" in data + assert "Person" in data["definitions"] + assert "Child1" in data["definitions"] + assert "Child2" in data["definitions"] + assert "Output" in data["definitions"] - path = data['paths']['/polymorph/'] - assert path['get']['responses']['200']['schema']['$ref'] == '#/definitions/Output' + path = data["paths"]["/polymorph/"] + assert ( + path["get"]["responses"]["200"]["schema"]["$ref"] == "#/definitions/Output" + ) - data = client.get_json('/polymorph/') + data = client.get_json("/polymorph/") assert data == { - 'children': [{ - 'name': 'Child1', - 'extra1': 'extra1', - }, { - 'name': 'Child2', - 'extra2': 'extra2', - }] + "children": [ + {"name": "Child1", "extra1": "extra1",}, + {"name": "Child2", "extra2": "extra2",}, + ] } def test_expect_model(self, api, client): - person = api.model('Person', { - 'name': restx.fields.String, - 'age': restx.fields.Integer, - 'birthdate': restx.fields.DateTime, - }) + person = api.model( + "Person", + { + "name": restx.fields.String, + "age": restx.fields.Integer, + "birthdate": restx.fields.DateTime, + }, + ) - @api.route('/model-as-dict/') + @api.route("/model-as-dict/") class ModelAsDict(restx.Resource): @api.expect(person) def post(self): @@ -2419,96 +2332,84 @@ def post(self): data = client.get_specs() - assert 'definitions' in data - assert 'Person' in data['definitions'] - assert data['definitions']['Person'] == { - 'properties': { - 'name': { - 'type': 'string' - }, - 'age': { - 'type': 'integer' - }, - 'birthdate': { - 'type': 'string', - 'format': 'date-time' - } + assert "definitions" in data + assert "Person" in data["definitions"] + assert data["definitions"]["Person"] == { + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + "birthdate": {"type": "string", "format": "date-time"}, }, - 'type': 'object' + "type": "object", } - op = data['paths']['/model-as-dict/']['post'] - assert len(op['parameters']) == 1 + op = data["paths"]["/model-as-dict/"]["post"] + assert len(op["parameters"]) == 1 - parameter = op['parameters'][0] + parameter = op["parameters"][0] assert parameter == { - 'name': 'payload', - 'in': 'body', - 'required': True, - 'schema': { - '$ref': '#/definitions/Person' - } + "name": "payload", + "in": "body", + "required": True, + "schema": {"$ref": "#/definitions/Person"}, } - assert 'description' not in parameter + assert "description" not in parameter def test_body_model_shortcut(self, api, client): - fields = api.model('Person', { - 'name': restx.fields.String, - 'age': restx.fields.Integer, - 'birthdate': restx.fields.DateTime, - }) + fields = api.model( + "Person", + { + "name": restx.fields.String, + "age": restx.fields.Integer, + "birthdate": restx.fields.DateTime, + }, + ) - @api.route('/model-as-dict/') + @api.route("/model-as-dict/") class ModelAsDict(restx.Resource): - @api.doc(model='Person') + @api.doc(model="Person") @api.expect(fields) def post(self): return {} data = client.get_specs() - assert 'definitions' in data - assert 'Person' in data['definitions'] - assert data['definitions']['Person'] == { - 'properties': { - 'name': { - 'type': 'string' - }, - 'age': { - 'type': 'integer' - }, - 'birthdate': { - 'type': 'string', - 'format': 'date-time' - } + assert "definitions" in data + assert "Person" in data["definitions"] + assert data["definitions"]["Person"] == { + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + "birthdate": {"type": "string", "format": "date-time"}, }, - 'type': 'object' + "type": "object", } - op = data['paths']['/model-as-dict/']['post'] - assert op['responses']['200']['schema']['$ref'] == '#/definitions/Person' + op = data["paths"]["/model-as-dict/"]["post"] + assert op["responses"]["200"]["schema"]["$ref"] == "#/definitions/Person" - assert len(op['parameters']) == 1 + assert len(op["parameters"]) == 1 - parameter = op['parameters'][0] + parameter = op["parameters"][0] assert parameter == { - 'name': 'payload', - 'in': 'body', - 'required': True, - 'schema': { - '$ref': '#/definitions/Person' - } + "name": "payload", + "in": "body", + "required": True, + "schema": {"$ref": "#/definitions/Person"}, } - assert 'description' not in parameter + assert "description" not in parameter def test_expect_model_list(self, api, client): - model = api.model('Person', { - 'name': restx.fields.String, - 'age': restx.fields.Integer, - 'birthdate': restx.fields.DateTime, - }) + model = api.model( + "Person", + { + "name": restx.fields.String, + "age": restx.fields.Integer, + "birthdate": restx.fields.DateTime, + }, + ) - @api.route('/model-list/') + @api.route("/model-list/") class ModelAsDict(restx.Resource): @api.expect([model]) def post(self): @@ -2516,48 +2417,41 @@ def post(self): data = client.get_specs() - assert 'definitions' in data - assert 'Person' in data['definitions'] - assert data['definitions']['Person'] == { - 'properties': { - 'name': { - 'type': 'string' - }, - 'age': { - 'type': 'integer' - }, - 'birthdate': { - 'type': 'string', - 'format': 'date-time' - } + assert "definitions" in data + assert "Person" in data["definitions"] + assert data["definitions"]["Person"] == { + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + "birthdate": {"type": "string", "format": "date-time"}, }, - 'type': 'object' + "type": "object", } - op = data['paths']['/model-list/']['post'] - parameter = op['parameters'][0] + op = data["paths"]["/model-list/"]["post"] + parameter = op["parameters"][0] assert parameter == { - 'name': 'payload', - 'in': 'body', - 'required': True, - 'schema': { - 'type': 'array', - 'items': {'$ref': '#/definitions/Person'}, - } + "name": "payload", + "in": "body", + "required": True, + "schema": {"type": "array", "items": {"$ref": "#/definitions/Person"},}, } def test_both_model_and_parser_from_expect(self, api, client): parser = api.parser() - parser.add_argument('param', type=int, help='Some param') - - person = api.model('Person', { - 'name': restx.fields.String, - 'age': restx.fields.Integer, - 'birthdate': restx.fields.DateTime, - }) + parser.add_argument("param", type=int, help="Some param") + + person = api.model( + "Person", + { + "name": restx.fields.String, + "age": restx.fields.Integer, + "birthdate": restx.fields.DateTime, + }, + ) - @api.route('/with-parser/', endpoint='with-parser') + @api.route("/with-parser/", endpoint="with-parser") class WithParserResource(restx.Resource): @api.expect(parser, person) def get(self): @@ -2565,49 +2459,40 @@ def get(self): data = client.get_specs() - assert 'definitions' in data - assert 'Person' in data['definitions'] - assert data['definitions']['Person'] == { - 'properties': { - 'name': { - 'type': 'string' - }, - 'age': { - 'type': 'integer' - }, - 'birthdate': { - 'type': 'string', - 'format': 'date-time' - } + assert "definitions" in data + assert "Person" in data["definitions"] + assert data["definitions"]["Person"] == { + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + "birthdate": {"type": "string", "format": "date-time"}, }, - 'type': 'object' + "type": "object", } - assert '/with-parser/' in data['paths'] + assert "/with-parser/" in data["paths"] - op = data['paths']['/with-parser/']['get'] - assert len(op['parameters']) == 2 + op = data["paths"]["/with-parser/"]["get"] + assert len(op["parameters"]) == 2 - parameters = dict((p['in'], p) for p in op['parameters']) + parameters = dict((p["in"], p) for p in op["parameters"]) - parameter = parameters['query'] - assert parameter['name'] == 'param' - assert parameter['type'] == 'integer' - assert parameter['in'] == 'query' - assert parameter['description'] == 'Some param' + parameter = parameters["query"] + assert parameter["name"] == "param" + assert parameter["type"] == "integer" + assert parameter["in"] == "query" + assert parameter["description"] == "Some param" - parameter = parameters['body'] + parameter = parameters["body"] assert parameter == { - 'name': 'payload', - 'in': 'body', - 'required': True, - 'schema': { - '$ref': '#/definitions/Person' - } + "name": "payload", + "in": "body", + "required": True, + "schema": {"$ref": "#/definitions/Person"}, } def test_expect_primitive_list(self, api, client): - @api.route('/model-list/') + @api.route("/model-list/") class ModelAsDict(restx.Resource): @api.expect([restx.fields.String]) def post(self): @@ -2615,26 +2500,26 @@ def post(self): data = client.get_specs() - op = data['paths']['/model-list/']['post'] - parameter = op['parameters'][0] + op = data["paths"]["/model-list/"]["post"] + parameter = op["parameters"][0] assert parameter == { - 'name': 'payload', - 'in': 'body', - 'required': True, - 'schema': { - 'type': 'array', - 'items': {'type': 'string'}, - } + "name": "payload", + "in": "body", + "required": True, + "schema": {"type": "array", "items": {"type": "string"},}, } def test_body_model_list(self, api, client): - fields = api.model('Person', { - 'name': restx.fields.String, - 'age': restx.fields.Integer, - 'birthdate': restx.fields.DateTime, - }) + fields = api.model( + "Person", + { + "name": restx.fields.String, + "age": restx.fields.Integer, + "birthdate": restx.fields.DateTime, + }, + ) - @api.route('/model-list/') + @api.route("/model-list/") class ModelAsDict(restx.Resource): @api.expect([fields]) def post(self): @@ -2642,93 +2527,76 @@ def post(self): data = client.get_specs() - assert 'definitions' in data - assert 'Person' in data['definitions'] - assert data['definitions']['Person'] == { - 'properties': { - 'name': { - 'type': 'string' - }, - 'age': { - 'type': 'integer' - }, - 'birthdate': { - 'type': 'string', - 'format': 'date-time' - } + assert "definitions" in data + assert "Person" in data["definitions"] + assert data["definitions"]["Person"] == { + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + "birthdate": {"type": "string", "format": "date-time"}, }, - 'type': 'object' + "type": "object", } - op = data['paths']['/model-list/']['post'] - parameter = op['parameters'][0] + op = data["paths"]["/model-list/"]["post"] + parameter = op["parameters"][0] assert parameter == { - 'name': 'payload', - 'in': 'body', - 'required': True, - 'schema': { - 'type': 'array', - 'items': {'$ref': '#/definitions/Person'}, - } + "name": "payload", + "in": "body", + "required": True, + "schema": {"type": "array", "items": {"$ref": "#/definitions/Person"},}, } def test_expect_model_with_description(self, api, client): - person = api.model('Person', { - 'name': restx.fields.String, - 'age': restx.fields.Integer, - 'birthdate': restx.fields.DateTime, - }) + person = api.model( + "Person", + { + "name": restx.fields.String, + "age": restx.fields.Integer, + "birthdate": restx.fields.DateTime, + }, + ) - @api.route('/model-as-dict/') + @api.route("/model-as-dict/") class ModelAsDict(restx.Resource): - @api.expect((person, 'Body description')) + @api.expect((person, "Body description")) def post(self): return {} data = client.get_specs() - assert 'definitions' in data - assert 'Person' in data['definitions'] - assert data['definitions']['Person'] == { - 'properties': { - 'name': { - 'type': 'string' - }, - 'age': { - 'type': 'integer' - }, - 'birthdate': { - 'type': 'string', - 'format': 'date-time' - } + assert "definitions" in data + assert "Person" in data["definitions"] + assert data["definitions"]["Person"] == { + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + "birthdate": {"type": "string", "format": "date-time"}, }, - 'type': 'object' + "type": "object", } - op = data['paths']['/model-as-dict/']['post'] - assert len(op['parameters']) == 1 + op = data["paths"]["/model-as-dict/"]["post"] + assert len(op["parameters"]) == 1 - parameter = op['parameters'][0] + parameter = op["parameters"][0] assert parameter == { - 'name': 'payload', - 'in': 'body', - 'required': True, - 'description': 'Body description', - 'schema': { - '$ref': '#/definitions/Person' - } + "name": "payload", + "in": "body", + "required": True, + "description": "Body description", + "schema": {"$ref": "#/definitions/Person"}, } def test_authorizations(self, app, client): - restx.Api(app, authorizations={ - 'apikey': { - 'type': 'apiKey', - 'in': 'header', - 'name': 'X-API' - } - }) + restx.Api( + app, + authorizations={ + "apikey": {"type": "apiKey", "in": "header", "name": "X-API"} + }, + ) # @api.route('/authorizations/') # class ModelAsDict(restx.Resource): @@ -2739,200 +2607,186 @@ def test_authorizations(self, app, client): # return {} data = client.get_specs() - assert 'securityDefinitions' in data - assert 'security' not in data + assert "securityDefinitions" in data + assert "security" not in data # path = data['paths']['/authorizations/'] # assert 'security' not in path['get'] # assert path['post']['security'] == {'apikey': []} def test_single_root_security_string(self, app, client): - api = restx.Api(app, security='apikey', authorizations={ - 'apikey': { - 'type': 'apiKey', - 'in': 'header', - 'name': 'X-API' - } - }) + api = restx.Api( + app, + security="apikey", + authorizations={ + "apikey": {"type": "apiKey", "in": "header", "name": "X-API"} + }, + ) - @api.route('/authorizations/') + @api.route("/authorizations/") class ModelAsDict(restx.Resource): def post(self): return {} data = client.get_specs() - assert data['securityDefinitions'] == { - 'apikey': { - 'type': 'apiKey', - 'in': 'header', - 'name': 'X-API' - } + assert data["securityDefinitions"] == { + "apikey": {"type": "apiKey", "in": "header", "name": "X-API"} } - assert data['security'] == [{'apikey': []}] + assert data["security"] == [{"apikey": []}] - op = data['paths']['/authorizations/']['post'] - assert 'security' not in op + op = data["paths"]["/authorizations/"]["post"] + assert "security" not in op def test_single_root_security_object(self, app, client): security_definitions = { - 'oauth2': { - 'type': 'oauth2', - 'flow': 'accessCode', - 'tokenUrl': 'https://somewhere.com/token', - 'scopes': { - 'read': 'Grant read-only access', - 'write': 'Grant read-write access', - } + "oauth2": { + "type": "oauth2", + "flow": "accessCode", + "tokenUrl": "https://somewhere.com/token", + "scopes": { + "read": "Grant read-only access", + "write": "Grant read-write access", + }, + }, + "implicit": { + "type": "oauth2", + "flow": "implicit", + "tokenUrl": "https://somewhere.com/token", + "scopes": { + "read": "Grant read-only access", + "write": "Grant read-write access", + }, }, - 'implicit': { - 'type': 'oauth2', - 'flow': 'implicit', - 'tokenUrl': 'https://somewhere.com/token', - 'scopes': { - 'read': 'Grant read-only access', - 'write': 'Grant read-write access', - } - } } - api = restx.Api(app, - security={ - 'oauth2': 'read', - 'implicit': ['read', 'write'] - }, - authorizations=security_definitions + api = restx.Api( + app, + security={"oauth2": "read", "implicit": ["read", "write"]}, + authorizations=security_definitions, ) - @api.route('/authorizations/') + @api.route("/authorizations/") class ModelAsDict(restx.Resource): def post(self): return {} data = client.get_specs() - assert data['securityDefinitions'] == security_definitions - assert data['security'] == [{ - 'oauth2': ['read'], - 'implicit': ['read', 'write'] - }] + assert data["securityDefinitions"] == security_definitions + assert data["security"] == [{"oauth2": ["read"], "implicit": ["read", "write"]}] - op = data['paths']['/authorizations/']['post'] - assert 'security' not in op + op = data["paths"]["/authorizations/"]["post"] + assert "security" not in op def test_root_security_as_list(self, app, client): security_definitions = { - 'apikey': { - 'type': 'apiKey', - 'in': 'header', - 'name': 'X-API' + "apikey": {"type": "apiKey", "in": "header", "name": "X-API"}, + "oauth2": { + "type": "oauth2", + "flow": "accessCode", + "tokenUrl": "https://somewhere.com/token", + "scopes": { + "read": "Grant read-only access", + "write": "Grant read-write access", + }, }, - 'oauth2': { - 'type': 'oauth2', - 'flow': 'accessCode', - 'tokenUrl': 'https://somewhere.com/token', - 'scopes': { - 'read': 'Grant read-only access', - 'write': 'Grant read-write access', - } - } } - api = restx.Api(app, security=['apikey', {'oauth2': 'read'}], authorizations=security_definitions) + api = restx.Api( + app, + security=["apikey", {"oauth2": "read"}], + authorizations=security_definitions, + ) - @api.route('/authorizations/') + @api.route("/authorizations/") class ModelAsDict(restx.Resource): def post(self): return {} data = client.get_specs() - assert data['securityDefinitions'] == security_definitions - assert data['security'] == [{'apikey': []}, {'oauth2': ['read']}] + assert data["securityDefinitions"] == security_definitions + assert data["security"] == [{"apikey": []}, {"oauth2": ["read"]}] - op = data['paths']['/authorizations/']['post'] - assert 'security' not in op + op = data["paths"]["/authorizations/"]["post"] + assert "security" not in op def test_method_security(self, app, client): - api = restx.Api(app, authorizations={ - 'apikey': { - 'type': 'apiKey', - 'in': 'header', - 'name': 'X-API' - } - }) + api = restx.Api( + app, + authorizations={ + "apikey": {"type": "apiKey", "in": "header", "name": "X-API"} + }, + ) - @api.route('/authorizations/') + @api.route("/authorizations/") class ModelAsDict(restx.Resource): - @api.doc(security=['apikey']) + @api.doc(security=["apikey"]) def get(self): return {} - @api.doc(security='apikey') + @api.doc(security="apikey") def post(self): return {} data = client.get_specs() - assert data['securityDefinitions'] == { - 'apikey': { - 'type': 'apiKey', - 'in': 'header', - 'name': 'X-API' - } + assert data["securityDefinitions"] == { + "apikey": {"type": "apiKey", "in": "header", "name": "X-API"} } - assert 'security' not in data + assert "security" not in data - path = data['paths']['/authorizations/'] - for method in 'get', 'post': - assert path[method]['security'] == [{'apikey': []}] + path = data["paths"]["/authorizations/"] + for method in "get", "post": + assert path[method]["security"] == [{"apikey": []}] def test_security_override(self, app, client): security_definitions = { - 'apikey': { - 'type': 'apiKey', - 'in': 'header', - 'name': 'X-API' + "apikey": {"type": "apiKey", "in": "header", "name": "X-API"}, + "oauth2": { + "type": "oauth2", + "flow": "accessCode", + "tokenUrl": "https://somewhere.com/token", + "scopes": { + "read": "Grant read-only access", + "write": "Grant read-write access", + }, }, - 'oauth2': { - 'type': 'oauth2', - 'flow': 'accessCode', - 'tokenUrl': 'https://somewhere.com/token', - 'scopes': { - 'read': 'Grant read-only access', - 'write': 'Grant read-write access', - } - } } - api = restx.Api(app, security=['apikey', {'oauth2': 'read'}], authorizations=security_definitions) + api = restx.Api( + app, + security=["apikey", {"oauth2": "read"}], + authorizations=security_definitions, + ) - @api.route('/authorizations/') + @api.route("/authorizations/") class ModelAsDict(restx.Resource): - @api.doc(security=[{'oauth2': ['read', 'write']}]) + @api.doc(security=[{"oauth2": ["read", "write"]}]) def get(self): return {} data = client.get_specs() - assert data['securityDefinitions'] == security_definitions + assert data["securityDefinitions"] == security_definitions - op = data['paths']['/authorizations/']['get'] - assert op['security'] == [{'oauth2': ['read', 'write']}] + op = data["paths"]["/authorizations/"]["get"] + assert op["security"] == [{"oauth2": ["read", "write"]}] def test_security_nullify(self, app, client): security_definitions = { - 'apikey': { - 'type': 'apiKey', - 'in': 'header', - 'name': 'X-API' + "apikey": {"type": "apiKey", "in": "header", "name": "X-API"}, + "oauth2": { + "type": "oauth2", + "flow": "accessCode", + "tokenUrl": "https://somewhere.com/token", + "scopes": { + "read": "Grant read-only access", + "write": "Grant read-write access", + }, }, - 'oauth2': { - 'type': 'oauth2', - 'flow': 'accessCode', - 'tokenUrl': 'https://somewhere.com/token', - 'scopes': { - 'read': 'Grant read-only access', - 'write': 'Grant read-write access', - } - } } - api = restx.Api(app, security=['apikey', {'oauth2': 'read'}], authorizations=security_definitions) + api = restx.Api( + app, + security=["apikey", {"oauth2": "read"}], + authorizations=security_definitions, + ) - @api.route('/authorizations/') + @api.route("/authorizations/") class ModelAsDict(restx.Resource): @api.doc(security=[]) def get(self): @@ -2943,122 +2797,122 @@ def post(self): return {} data = client.get_specs() - assert data['securityDefinitions'] == security_definitions + assert data["securityDefinitions"] == security_definitions - path = data['paths']['/authorizations/'] - for method in 'get', 'post': - assert path[method]['security'] == [] + path = data["paths"]["/authorizations/"] + for method in "get", "post": + assert path[method]["security"] == [] def test_hidden_resource(self, api, client): - @api.route('/test/', endpoint='test', doc=False) + @api.route("/test/", endpoint="test", doc=False) class TestResource(restx.Resource): def get(self): - ''' + """ GET operation - ''' + """ return {} @api.hide - @api.route('/test2/', endpoint='test2') + @api.route("/test2/", endpoint="test2") class TestResource2(restx.Resource): def get(self): - ''' + """ GET operation - ''' + """ return {} @api.doc(False) - @api.route('/test3/', endpoint='test3') + @api.route("/test3/", endpoint="test3") class TestResource3(restx.Resource): def get(self): - ''' + """ GET operation - ''' + """ return {} data = client.get_specs() - for path in '/test/', '/test2/', '/test3/': - assert path not in data['paths'] + for path in "/test/", "/test2/", "/test3/": + assert path not in data["paths"] resp = client.get(path) assert resp.status_code == 200 def test_hidden_resource_from_namespace(self, api, client): - ns = api.namespace('ns') + ns = api.namespace("ns") - @ns.route('/test/', endpoint='test', doc=False) + @ns.route("/test/", endpoint="test", doc=False) class TestResource(restx.Resource): def get(self): - ''' + """ GET operation - ''' + """ return {} data = client.get_specs() - assert '/ns/test/' not in data['paths'] + assert "/ns/test/" not in data["paths"] - resp = client.get('/ns/test/') + resp = client.get("/ns/test/") assert resp.status_code == 200 def test_hidden_methods(self, api, client): - @api.route('/test/', endpoint='test') + @api.route("/test/", endpoint="test") @api.doc(delete=False) class TestResource(restx.Resource): def get(self): - ''' + """ GET operation - ''' + """ return {} @api.doc(False) def post(self): - '''POST operation. + """POST operation. Should be ignored - ''' + """ return {} @api.hide def put(self): - '''PUT operation. Should be ignored''' + """PUT operation. Should be ignored""" return {} def delete(self): return {} data = client.get_specs() - path = data['paths']['/test/'] + path = data["paths"]["/test/"] - assert 'get' in path - assert 'post' not in path - assert 'put' not in path + assert "get" in path + assert "post" not in path + assert "put" not in path - for method in 'GET', 'POST', 'PUT': - resp = client.open('/test/', method=method) + for method in "GET", "POST", "PUT": + resp = client.open("/test/", method=method) assert resp.status_code == 200 def test_produces_method(self, api, client): - @api.route('/test/', endpoint='test') + @api.route("/test/", endpoint="test") class TestResource(restx.Resource): def get(self): pass - @api.produces(['application/octet-stream']) + @api.produces(["application/octet-stream"]) def post(self): pass data = client.get_specs() - get_operation = data['paths']['/test/']['get'] - assert 'produces' not in get_operation + get_operation = data["paths"]["/test/"]["get"] + assert "produces" not in get_operation - post_operation = data['paths']['/test/']['post'] - assert 'produces' in post_operation - assert post_operation['produces'] == ['application/octet-stream'] + post_operation = data["paths"]["/test/"]["post"] + assert "produces" in post_operation + assert post_operation["produces"] == ["application/octet-stream"] def test_deprecated_resource(self, api, client): @api.deprecated - @api.route('/test/', endpoint='test') + @api.route("/test/", endpoint="test") class TestResource(restx.Resource): def get(self): pass @@ -3067,13 +2921,13 @@ def post(self): pass data = client.get_specs() - resource = data['paths']['/test/'] + resource = data["paths"]["/test/"] for operation in resource.values(): - assert 'deprecated' in operation - assert operation['deprecated'] is True + assert "deprecated" in operation + assert operation["deprecated"] is True def test_deprecated_method(self, api, client): - @api.route('/test/', endpoint='test') + @api.route("/test/", endpoint="test") class TestResource(restx.Resource): def get(self): pass @@ -3084,61 +2938,60 @@ def post(self): data = client.get_specs() - get_operation = data['paths']['/test/']['get'] - assert 'deprecated' not in get_operation + get_operation = data["paths"]["/test/"]["get"] + assert "deprecated" not in get_operation - post_operation = data['paths']['/test/']['post'] - assert 'deprecated' in post_operation - assert post_operation['deprecated'] is True + post_operation = data["paths"]["/test/"]["post"] + assert "deprecated" in post_operation + assert post_operation["deprecated"] is True def test_vendor_as_kwargs(self, api, client): - @api.route('/vendor_fields', endpoint='vendor_fields') + @api.route("/vendor_fields", endpoint="vendor_fields") class TestResource(restx.Resource): - @api.vendor(integration={'integration1': '1'}) + @api.vendor(integration={"integration1": "1"}) def get(self): return {} data = client.get_specs() - assert '/vendor_fields' in data['paths'] + assert "/vendor_fields" in data["paths"] - path = data['paths']['/vendor_fields']['get'] + path = data["paths"]["/vendor_fields"]["get"] - assert 'x-integration' in path + assert "x-integration" in path - assert path['x-integration'] == {'integration1': '1'} + assert path["x-integration"] == {"integration1": "1"} def test_vendor_as_dict(self, api, client): - @api.route('/vendor_fields', endpoint='vendor_fields') + @api.route("/vendor_fields", endpoint="vendor_fields") class TestResource(restx.Resource): - @api.vendor({ - 'x-some-integration': { - 'integration1': '1' + @api.vendor( + { + "x-some-integration": {"integration1": "1"}, + "another-integration": True, }, - 'another-integration': True - }, { - 'third-integration': True - }) + {"third-integration": True}, + ) def get(self, age): return {} data = client.get_specs() - assert '/vendor_fields' in data['paths'] + assert "/vendor_fields" in data["paths"] - path = data['paths']['/vendor_fields']['get'] - assert 'x-some-integration' in path - assert path['x-some-integration'] == {'integration1': '1'} + path = data["paths"]["/vendor_fields"]["get"] + assert "x-some-integration" in path + assert path["x-some-integration"] == {"integration1": "1"} - assert 'x-another-integration' in path - assert path['x-another-integration'] is True + assert "x-another-integration" in path + assert path["x-another-integration"] is True - assert 'x-third-integration' in path - assert path['x-third-integration'] is True + assert "x-third-integration" in path + assert path["x-third-integration"] is True def test_method_restrictions(self, api, client): - @api.route('/foo/bar', endpoint='foo') - @api.route('/bar', methods=['GET'], endpoint='bar') + @api.route("/foo/bar", endpoint="foo") + @api.route("/bar", methods=["GET"], endpoint="bar") class TestResource(restx.Resource): def get(self): pass @@ -3148,82 +3001,81 @@ def post(self): data = client.get_specs() - path = data['paths']['/foo/bar'] - assert 'get' in path - assert 'post' in path + path = data["paths"]["/foo/bar"] + assert "get" in path + assert "post" in path - path = data['paths']['/bar'] - assert 'get' in path - assert 'post' not in path + path = data["paths"]["/bar"] + assert "get" in path + assert "post" not in path def test_multiple_routes_inherit_doc(self, api, client): - @api.route('/foo/bar') - @api.route('/bar') - @api.doc(description='an endpoint') + @api.route("/foo/bar") + @api.route("/bar") + @api.doc(description="an endpoint") class TestResource(restx.Resource): def get(self): pass data = client.get_specs() - path = data['paths']['/foo/bar'] - assert path['get']['description'] == 'an endpoint' + path = data["paths"]["/foo/bar"] + assert path["get"]["description"] == "an endpoint" - path = data['paths']['/bar'] - assert path['get']['description'] == 'an endpoint' + path = data["paths"]["/bar"] + assert path["get"]["description"] == "an endpoint" def test_multiple_routes_individual_doc(self, api, client): - @api.route('/foo/bar', doc={'description': 'the same endpoint'}) - @api.route('/bar', doc={'description': 'an endpoint'}) + @api.route("/foo/bar", doc={"description": "the same endpoint"}) + @api.route("/bar", doc={"description": "an endpoint"}) class TestResource(restx.Resource): def get(self): pass data = client.get_specs() - path = data['paths']['/foo/bar'] - assert path['get']['description'] == 'the same endpoint' + path = data["paths"]["/foo/bar"] + assert path["get"]["description"] == "the same endpoint" - path = data['paths']['/bar'] - assert path['get']['description'] == 'an endpoint' + path = data["paths"]["/bar"] + assert path["get"]["description"] == "an endpoint" def test_multiple_routes_override_doc(self, api, client): - @api.route('/foo/bar', doc={'description': 'the same endpoint'}) - @api.route('/bar') - @api.doc(description='an endpoint') + @api.route("/foo/bar", doc={"description": "the same endpoint"}) + @api.route("/bar") + @api.doc(description="an endpoint") class TestResource(restx.Resource): def get(self): pass data = client.get_specs() - path = data['paths']['/foo/bar'] - assert path['get']['description'] == 'the same endpoint' + path = data["paths"]["/foo/bar"] + assert path["get"]["description"] == "the same endpoint" - path = data['paths']['/bar'] - assert path['get']['description'] == 'an endpoint' + path = data["paths"]["/bar"] + assert path["get"]["description"] == "an endpoint" def test_multiple_routes_no_doc_same_operationIds(self, api, client): - @api.route('/foo/bar') - @api.route('/bar') + @api.route("/foo/bar") + @api.route("/bar") class TestResource(restx.Resource): def get(self): pass data = client.get_specs() - expected_operation_id = 'get_test_resource' + expected_operation_id = "get_test_resource" - path = data['paths']['/foo/bar'] - assert path['get']['operationId'] == expected_operation_id + path = data["paths"]["/foo/bar"] + assert path["get"]["operationId"] == expected_operation_id - path = data['paths']['/bar'] - assert path['get']['operationId'] == expected_operation_id + path = data["paths"]["/bar"] + assert path["get"]["operationId"] == expected_operation_id def test_multiple_routes_with_doc_unique_operationIds(self, api, client): @api.route( - "/foo/bar", - doc={"description": "I should be treated separately"}, + "/foo/bar", doc={"description": "I should be treated separately"}, ) @api.route("/bar") class TestResource(restx.Resource): @@ -3232,46 +3084,46 @@ def get(self): data = client.get_specs() - path = data['paths']['/foo/bar'] - assert path['get']['operationId'] == 'get_test_resource_/foo/bar' + path = data["paths"]["/foo/bar"] + assert path["get"]["operationId"] == "get_test_resource_/foo/bar" - path = data['paths']['/bar'] - assert path['get']['operationId'] == 'get_test_resource' + path = data["paths"]["/bar"] + assert path["get"]["operationId"] == "get_test_resource" def test_mutltiple_routes_merge_doc(self, api, client): - @api.route('/foo/bar', doc={'description': 'the same endpoint'}) - @api.route('/bar', doc={'description': False}) - @api.doc(security=[{'oauth2': ['read', 'write']}]) + @api.route("/foo/bar", doc={"description": "the same endpoint"}) + @api.route("/bar", doc={"description": False}) + @api.doc(security=[{"oauth2": ["read", "write"]}]) class TestResource(restx.Resource): def get(self): pass data = client.get_specs() - path = data['paths']['/foo/bar'] - assert path['get']['description'] == 'the same endpoint' - assert path['get']['security'] == [{'oauth2': ['read', 'write']}] + path = data["paths"]["/foo/bar"] + assert path["get"]["description"] == "the same endpoint" + assert path["get"]["security"] == [{"oauth2": ["read", "write"]}] - path = data['paths']['/bar'] - assert 'description' not in path['get'] - assert path['get']['security'] == [{'oauth2': ['read', 'write']}] + path = data["paths"]["/bar"] + assert "description" not in path["get"] + assert path["get"]["security"] == [{"oauth2": ["read", "write"]}] def test_multiple_routes_deprecation(self, api, client): - @api.route('/foo/bar', doc={'deprecated': True}) - @api.route('/bar') + @api.route("/foo/bar", doc={"deprecated": True}) + @api.route("/bar") class TestResource(restx.Resource): def get(self): pass data = client.get_specs() - path = data['paths']['/foo/bar'] - assert path['get']['deprecated'] is True + path = data["paths"]["/foo/bar"] + assert path["get"]["deprecated"] is True - path = data['paths']['/bar'] - assert 'deprecated' not in path['get'] + path = data["paths"]["/bar"] + assert "deprecated" not in path["get"] - @pytest.mark.parametrize('path_name', ['/name/{age}/', '/first-name/{age}/']) + @pytest.mark.parametrize("path_name", ["/name/{age}/", "/first-name/{age}/"]) def test_multiple_routes_explicit_parameters_override(self, path_name, api, client): @api.route("/name//", endpoint="by-name") @api.route("/first-name//") @@ -3286,9 +3138,7 @@ def test_multiple_routes_explicit_parameters_override(self, path_name, api, clie } ) class ByNameResource(restx.Resource): - @api.doc( - params={"q": {"description": "A query string"}} - ) + @api.doc(params={"q": {"description": "A query string"}}) def get(self, age): return {} @@ -3296,66 +3146,68 @@ def post(self, age): pass data = client.get_specs() - assert path_name in data['paths'] + assert path_name in data["paths"] - path = data['paths'][path_name] - assert len(path['parameters']) == 1 + path = data["paths"][path_name] + assert len(path["parameters"]) == 1 - by_name = dict((p['name'], p) for p in path['parameters']) + by_name = dict((p["name"], p) for p in path["parameters"]) - parameter = by_name['age'] - assert parameter['name'] == 'age' - assert parameter['type'] == 'integer' - assert parameter['in'] == 'path' - assert parameter['required'] is True - assert parameter['description'] == 'An age' + parameter = by_name["age"] + assert parameter["name"] == "age" + assert parameter["type"] == "integer" + assert parameter["in"] == "path" + assert parameter["required"] is True + assert parameter["description"] == "An age" # Don't duplicate parameters - assert 'q' not in by_name + assert "q" not in by_name - get = path['get'] - assert len(get['parameters']) == 1 + get = path["get"] + assert len(get["parameters"]) == 1 - parameter = get['parameters'][0] - assert parameter['name'] == 'q' - assert parameter['type'] == 'string' - assert parameter['in'] == 'query' - assert parameter['description'] == 'A query string' + parameter = get["parameters"][0] + assert parameter["name"] == "q" + assert parameter["type"] == "string" + assert parameter["in"] == "query" + assert parameter["description"] == "A query string" - post = path['post'] - assert len(post['parameters']) == 1 + post = path["post"] + assert len(post["parameters"]) == 1 - parameter = post['parameters'][0] - assert parameter['name'] == 'q' - assert parameter['type'] == 'string' - assert parameter['in'] == 'query' - assert parameter['description'] == 'Overriden description' + parameter = post["parameters"][0] + assert parameter["name"] == "q" + assert parameter["type"] == "string" + assert parameter["in"] == "query" + assert parameter["description"] == "Overriden description" class SwaggerDeprecatedTest(object): def test_doc_parser_parameters(self, api): parser = api.parser() - parser.add_argument('param', type=int, help='Some param') + parser.add_argument("param", type=int, help="Some param") with pytest.warns(DeprecationWarning): - @api.route('/with-parser/') + + @api.route("/with-parser/") class WithParserResource(restx.Resource): @api.doc(parser=parser) def get(self): return {} - assert 'parser' not in WithParserResource.get.__apidoc__ - assert 'expect' in WithParserResource.get.__apidoc__ - doc_parser = WithParserResource.get.__apidoc__['expect'][0] + assert "parser" not in WithParserResource.get.__apidoc__ + assert "expect" in WithParserResource.get.__apidoc__ + doc_parser = WithParserResource.get.__apidoc__["expect"][0] assert doc_parser.__schema__ == parser.__schema__ def test_doc_method_parser_on_class(self, api): parser = api.parser() - parser.add_argument('param', type=int, help='Some param') + parser.add_argument("param", type=int, help="Some param") with pytest.warns(DeprecationWarning): - @api.route('/with-parser/') - @api.doc(get={'parser': parser}) + + @api.route("/with-parser/") + @api.doc(get={"parser": parser}) class WithParserResource(restx.Resource): def get(self): return {} @@ -3363,40 +3215,44 @@ def get(self): def post(self): return {} - assert 'parser' not in WithParserResource.__apidoc__['get'] - assert 'expect' in WithParserResource.__apidoc__['get'] - doc_parser = WithParserResource.__apidoc__['get']['expect'][0] + assert "parser" not in WithParserResource.__apidoc__["get"] + assert "expect" in WithParserResource.__apidoc__["get"] + doc_parser = WithParserResource.__apidoc__["get"]["expect"][0] assert doc_parser.__schema__ == parser.__schema__ def test_doc_body_as_tuple(self, api): - fields = api.model('Person', { - 'name': restx.fields.String, - 'age': restx.fields.Integer, - 'birthdate': restx.fields.DateTime, - }) + fields = api.model( + "Person", + { + "name": restx.fields.String, + "age": restx.fields.Integer, + "birthdate": restx.fields.DateTime, + }, + ) with pytest.warns(DeprecationWarning): - @api.route('/model-as-dict/') + + @api.route("/model-as-dict/") class ModelAsDict(restx.Resource): - @api.doc(body=(fields, 'Body description')) + @api.doc(body=(fields, "Body description")) def post(self): return {} - assert 'body' not in ModelAsDict.post.__apidoc__ - assert ModelAsDict.post.__apidoc__['expect'] == [(fields, 'Body description')] + assert "body" not in ModelAsDict.post.__apidoc__ + assert ModelAsDict.post.__apidoc__["expect"] == [(fields, "Body description")] def test_build_request_body_parameters_schema(self): parser = restx.reqparse.RequestParser() - parser.add_argument('test', type=int, location='headers') - parser.add_argument('test1', type=int, location='json') - parser.add_argument('test2', location='json') + parser.add_argument("test", type=int, location="headers") + parser.add_argument("test1", type=int, location="json") + parser.add_argument("test2", location="json") - body_params = [p for p in parser.__schema__ if p['in'] == 'body'] + body_params = [p for p in parser.__schema__ if p["in"] == "body"] result = restx.swagger.build_request_body_parameters_schema(body_params) - assert result['name'] == 'payload' - assert result['required'] - assert result['in'] == 'body' - assert result['schema']['type'] == 'object' - assert result['schema']['properties']['test1']['type'] == 'integer' - assert result['schema']['properties']['test2']['type'] == 'string' + assert result["name"] == "payload" + assert result["required"] + assert result["in"] == "body" + assert result["schema"]["type"] == "object" + assert result["schema"]["properties"]["test1"]["type"] == "integer" + assert result["schema"]["properties"]["test2"]["type"] == "string" diff --git a/tests/test_swagger_utils.py b/tests/test_swagger_utils.py index bec6eeaa..eda8cfbb 100644 --- a/tests/test_swagger_utils.py +++ b/tests/test_swagger_utils.py @@ -6,89 +6,89 @@ class ExtractPathTest(object): def test_extract_static_path(self): - path = '/test' - assert extract_path(path) == '/test' + path = "/test" + assert extract_path(path) == "/test" def test_extract_path_with_a_single_simple_parameter(self): - path = '/test/' - assert extract_path(path) == '/test/{parameter}' + path = "/test/" + assert extract_path(path) == "/test/{parameter}" def test_extract_path_with_a_single_typed_parameter(self): - path = '/test/' - assert extract_path(path) == '/test/{parameter}' + path = "/test/" + assert extract_path(path) == "/test/{parameter}" def test_extract_path_with_a_single_typed_parameter_with_arguments(self): - path = '/test/' - assert extract_path(path) == '/test/{parameter}' + path = "/test/" + assert extract_path(path) == "/test/{parameter}" def test_extract_path_with_multiple_parameters(self): - path = '/test///' - assert extract_path(path) == '/test/{parameter}/{other}/' + path = "/test///" + assert extract_path(path) == "/test/{parameter}/{other}/" class ExtractPathParamsTestCase(object): def test_extract_static_path(self): - path = '/test' + path = "/test" assert extract_path_params(path) == {} def test_extract_single_simple_parameter(self): - path = '/test/' + path = "/test/" assert extract_path_params(path) == { - 'parameter': { - 'name': 'parameter', - 'type': 'string', - 'in': 'path', - 'required': True + "parameter": { + "name": "parameter", + "type": "string", + "in": "path", + "required": True, } } def test_single_int_parameter(self): - path = '/test/' + path = "/test/" assert extract_path_params(path) == { - 'parameter': { - 'name': 'parameter', - 'type': 'integer', - 'in': 'path', - 'required': True + "parameter": { + "name": "parameter", + "type": "integer", + "in": "path", + "required": True, } } def test_single_float_parameter(self): - path = '/test/' + path = "/test/" assert extract_path_params(path) == { - 'parameter': { - 'name': 'parameter', - 'type': 'number', - 'in': 'path', - 'required': True + "parameter": { + "name": "parameter", + "type": "number", + "in": "path", + "required": True, } } def test_extract_path_with_multiple_parameters(self): - path = '/test///' + path = "/test///" assert extract_path_params(path) == { - 'parameter': { - 'name': 'parameter', - 'type': 'string', - 'in': 'path', - 'required': True + "parameter": { + "name": "parameter", + "type": "string", + "in": "path", + "required": True, + }, + "other": { + "name": "other", + "type": "integer", + "in": "path", + "required": True, }, - 'other': { - 'name': 'other', - 'type': 'integer', - 'in': 'path', - 'required': True - } } def test_extract_parameter_with_arguments(self): - path = '/test/' + path = "/test/" assert extract_path_params(path) == { - 'parameter': { - 'name': 'parameter', - 'type': 'string', - 'in': 'path', - 'required': True + "parameter": { + "name": "parameter", + "type": "string", + "in": "path", + "required": True, } } @@ -119,76 +119,77 @@ def without_doc(): parsed = parse_docstring(without_doc) - assert parsed['raw'] is None - assert parsed['summary'] is None - assert parsed['details'] is None - assert parsed['returns'] is None - assert parsed['raises'] == {} - assert parsed['params'] == [] + assert parsed["raw"] is None + assert parsed["summary"] is None + assert parsed["details"] is None + assert parsed["returns"] is None + assert parsed["raises"] == {} + assert parsed["params"] == [] def test_single_line(self): def func(): - '''Some summary''' + """Some summary""" pass parsed = parse_docstring(func) - assert parsed['raw'] == 'Some summary' - assert parsed['summary'] == 'Some summary' - assert parsed['details'] is None - assert parsed['returns'] is None - assert parsed['raises'] == {} - assert parsed['params'] == [] + assert parsed["raw"] == "Some summary" + assert parsed["summary"] == "Some summary" + assert parsed["details"] is None + assert parsed["returns"] is None + assert parsed["raises"] == {} + assert parsed["params"] == [] def test_multi_line(self): def func(): - ''' + """ Some summary Some details - ''' + """ pass parsed = parse_docstring(func) - assert parsed['raw'] == 'Some summary\nSome details' - assert parsed['summary'] == 'Some summary' - assert parsed['details'] == 'Some details' - assert parsed['returns'] is None - assert parsed['raises'] == {} - assert parsed['params'] == [] + assert parsed["raw"] == "Some summary\nSome details" + assert parsed["summary"] == "Some summary" + assert parsed["details"] == "Some details" + assert parsed["returns"] is None + assert parsed["raises"] == {} + assert parsed["params"] == [] def test_multi_line_and_dot(self): def func(): - ''' + """ Some summary. bla bla Some details - ''' + """ pass parsed = parse_docstring(func) - assert parsed['raw'] == 'Some summary. bla bla\nSome details' - assert parsed['summary'] == 'Some summary' - assert parsed['details'] == 'bla bla\nSome details' - assert parsed['returns'] is None - assert parsed['raises'] == {} - assert parsed['params'] == [] + assert parsed["raw"] == "Some summary. bla bla\nSome details" + assert parsed["summary"] == "Some summary" + assert parsed["details"] == "bla bla\nSome details" + assert parsed["returns"] is None + assert parsed["raises"] == {} + assert parsed["params"] == [] def test_raises(self): def func(): - ''' + """ Some summary. :raises SomeException: in case of something - ''' + """ pass parsed = parse_docstring(func) - assert parsed['raw'] == 'Some summary.\n:raises SomeException: in case of something' - assert parsed['summary'] == 'Some summary' - assert parsed['details'] is None - assert parsed['returns'] is None - assert parsed['params'] == [] - assert parsed['raises'] == { - 'SomeException': 'in case of something' - } + assert ( + parsed["raw"] + == "Some summary.\n:raises SomeException: in case of something" + ) + assert parsed["summary"] == "Some summary" + assert parsed["details"] is None + assert parsed["returns"] is None + assert parsed["params"] == [] + assert parsed["raises"] == {"SomeException": "in case of something"} diff --git a/tests/test_utils.py b/tests/test_utils.py index 9c6d3e5e..20ec93a3 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -8,115 +8,95 @@ class MergeTestCase(object): def test_merge_simple_dicts_without_precedence(self): - a = {'a': 'value'} - b = {'b': 'other value'} - assert utils.merge(a, b) == {'a': 'value', 'b': 'other value'} + a = {"a": "value"} + b = {"b": "other value"} + assert utils.merge(a, b) == {"a": "value", "b": "other value"} def test_merge_simple_dicts_with_precedence(self): - a = {'a': 'value', 'ab': 'overwritten'} - b = {'b': 'other value', 'ab': 'keep'} - assert utils.merge(a, b) == {'a': 'value', 'b': 'other value', 'ab': 'keep'} + a = {"a": "value", "ab": "overwritten"} + b = {"b": "other value", "ab": "keep"} + assert utils.merge(a, b) == {"a": "value", "b": "other value", "ab": "keep"} def test_recursions(self): a = { - 'a': 'value', - 'ab': 'overwritten', - 'nested_a': { - 'a': 'nested' - }, - 'nested_a_b': { - 'a': 'a only', - 'ab': 'overwritten' - } + "a": "value", + "ab": "overwritten", + "nested_a": {"a": "nested"}, + "nested_a_b": {"a": "a only", "ab": "overwritten"}, } b = { - 'b': 'other value', - 'ab': 'keep', - 'nested_b': { - 'b': 'nested' - }, - 'nested_a_b': { - 'b': 'b only', - 'ab': 'keep' - } + "b": "other value", + "ab": "keep", + "nested_b": {"b": "nested"}, + "nested_a_b": {"b": "b only", "ab": "keep"}, } assert utils.merge(a, b) == { - 'a': 'value', - 'b': 'other value', - 'ab': 'keep', - 'nested_a': { - 'a': 'nested' - }, - 'nested_b': { - 'b': 'nested' - }, - 'nested_a_b': { - 'a': 'a only', - 'b': 'b only', - 'ab': 'keep' - } + "a": "value", + "b": "other value", + "ab": "keep", + "nested_a": {"a": "nested"}, + "nested_b": {"b": "nested"}, + "nested_a_b": {"a": "a only", "b": "b only", "ab": "keep"}, } def test_recursions_with_empty(self): a = {} b = { - 'b': 'other value', - 'ab': 'keep', - 'nested_b': { - 'b': 'nested' - }, - 'nested_a_b': { - 'b': 'b only', - 'ab': 'keep' - } + "b": "other value", + "ab": "keep", + "nested_b": {"b": "nested"}, + "nested_a_b": {"b": "b only", "ab": "keep"}, } assert utils.merge(a, b) == b class CamelToDashTestCase(object): def test_no_transform(self): - assert utils.camel_to_dash('test') == 'test' - - @pytest.mark.parametrize('value,expected', [ - ('aValue', 'a_value'), - ('aLongValue', 'a_long_value'), - ('Upper', 'upper'), - ('UpperCase', 'upper_case'), - ]) + assert utils.camel_to_dash("test") == "test" + + @pytest.mark.parametrize( + "value,expected", + [ + ("aValue", "a_value"), + ("aLongValue", "a_long_value"), + ("Upper", "upper"), + ("UpperCase", "upper_case"), + ], + ) def test_transform(self, value, expected): assert utils.camel_to_dash(value) == expected class UnpackTest(object): def test_single_value(self): - data, code, headers = utils.unpack('test') - assert data == 'test' + data, code, headers = utils.unpack("test") + assert data == "test" assert code == 200 assert headers == {} def test_single_value_with_default_code(self): - data, code, headers = utils.unpack('test', 500) - assert data == 'test' + data, code, headers = utils.unpack("test", 500) + assert data == "test" assert code == 500 assert headers == {} def test_value_code(self): - data, code, headers = utils.unpack(('test', 201)) - assert data == 'test' + data, code, headers = utils.unpack(("test", 201)) + assert data == "test" assert code == 201 assert headers == {} def test_value_code_headers(self): - data, code, headers = utils.unpack(('test', 201, {'Header': 'value'})) - assert data == 'test' + data, code, headers = utils.unpack(("test", 201, {"Header": "value"})) + assert data == "test" assert code == 201 - assert headers == {'Header': 'value'} + assert headers == {"Header": "value"} def test_value_headers_default_code(self): - data, code, headers = utils.unpack(('test', None, {'Header': 'value'})) - assert data == 'test' + data, code, headers = utils.unpack(("test", None, {"Header": "value"})) + assert data == "test" assert code == 200 - assert headers == {'Header': 'value'} + assert headers == {"Header": "value"} def test_too_many_values(self): with pytest.raises(ValueError):