Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions docs/Reference/Models.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,4 @@
::: flask_openapi3.models.path.Operation
::: flask_openapi3.models.path.PathItem


::: flask_openapi3.models.validation_error.UnprocessableEntity
::: flask_openapi3.models.validation_error.ValidationErrorModel
43 changes: 42 additions & 1 deletion docs/Usage/Specification.md
Original file line number Diff line number Diff line change
Expand Up @@ -375,4 +375,45 @@ def hello():

if __name__ == "__main__":
app.run(debug=True)
```
```


## validation error

*new in v2.5.0*

You can override validation error response use `validation_error_status`, `validation_error_model`
and `validation_error_callback`.


- validation_error_status: HTTP Status of the response given when a validation error is detected by pydantic.
Defaults to 422.
- validation_error_model: Validation error response model for OpenAPI Specification.
- validation_error_callback: Validation error response callback, the return format corresponds to
the validation_error_model. Receive `ValidationError` and return `Flask Response`.


```python
from flask.wrappers import Response as FlaskResponse
from pydantic import BaseModel, ValidationError

class ValidationErrorModel(BaseModel):
code: str
message: str


def validation_error_callback(e: ValidationError) -> FlaskResponse:
validation_error_object = ValidationErrorModel(code="400", message=e.json())
response = make_response(validation_error_object.json())
response.headers["Content-Type"] = "application/json"
response.status_code = getattr(current_app, "validation_error_status", 422)
return response


app = OpenAPI(
__name__,
validation_error_status=400,
validation_error_model=ValidationErrorModel,
validation_error_callback=validation_error_callback
)
```
1 change: 1 addition & 0 deletions flask_openapi3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,6 @@
from .models.server import ServerVariable
from .models.tag import Tag
from .models.validation_error import UnprocessableEntity
from .models.validation_error import ValidationErrorModel
from .openapi import OpenAPI
from .view import APIView
8 changes: 6 additions & 2 deletions flask_openapi3/models/validation_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from pydantic import BaseModel, Field


class UnprocessableEntity(BaseModel):
# More information: https://pydantic-docs.helpmanual.io/usage/models/#error-handling
class ValidationErrorModel(BaseModel):
# More information: https://docs.pydantic.dev/1.10/usage/models/#error-handling
loc: Optional[List[str]] = Field(None, title="Location", description="the error's location as a list. ")
msg: Optional[str] = Field(None, title="Message", description="a computer-readable identifier of the error type.")
type_: Optional[str] = Field(None, title="Error Type", description="a human readable explanation of the error.")
Expand All @@ -16,3 +16,7 @@ class UnprocessableEntity(BaseModel):
title="Error context",
description="an optional object which contains values required to render the error message."
)


# backward compatibility
UnprocessableEntity = ValidationErrorModel
58 changes: 50 additions & 8 deletions flask_openapi3/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,16 @@

from .blueprint import APIBlueprint
from .commands import openapi_command
from .http import HTTPMethod
from .models import Info, APISpec, Tag, Components, Server
from .models.common import ExternalDocumentation, ExtraRequestBody
from .http import HTTPMethod, HTTP_STATUS
from .models import Info, APISpec, Tag, Components, Server, OPENAPI3_REF_PREFIX
from .models.common import ExternalDocumentation, ExtraRequestBody, Schema
from .models.oauth import OAuthConfig
from .models.security import SecurityScheme
from .models.validation_error import ValidationErrorModel
from .scaffold import APIScaffold
from .templates import openapi_html_string, redoc_html_string, rapidoc_html_string, swagger_html_string
from .utils import get_operation, get_responses, parse_and_store_tags, parse_parameters, parse_method, \
get_operation_id_for_path
get_operation_id_for_path, make_validation_error_response
from .view import APIView


Expand All @@ -45,6 +46,9 @@ def __init__(
external_docs: Optional[ExternalDocumentation] = None,
operation_id_callback: Callable = get_operation_id_for_path,
openapi_extensions: Optional[Dict[str, Any]] = None,
validation_error_status: Union[str, int] = 422,
validation_error_model: Type[BaseModel] = ValidationErrorModel,
validation_error_callback: Callable = make_validation_error_response,
**kwargs: Any
) -> None:
"""
Expand Down Expand Up @@ -77,6 +81,11 @@ def __init__(
Default to `get_operation_id_for_path` from utils.
openapi_extensions: Extensions to the OpenAPI Schema.
See https://spec.openapis.org/oas/v3.0.3#specification-extensions.
validation_error_status: HTTP Status of the response given when a validation error is detected by pydantic.
Defaults to 422.
validation_error_model: Validation error response model for OpenAPI Specification.
validation_error_callback: Validation error response callback, the return format corresponds to
the validation_error_model.
**kwargs: Additional kwargs to be passed to Flask.
"""
super(OpenAPI, self).__init__(import_name, **kwargs)
Expand Down Expand Up @@ -120,13 +129,21 @@ def __init__(
# Set OpenAPI extensions
self.openapi_extensions = openapi_extensions or dict()

# Set HTTP Response of validation errors within OpenAPI
self.validation_error_status = str(validation_error_status)
self.validation_error_model = validation_error_model
self.validation_error_callback = validation_error_callback

# Initialize the OpenAPI documentation UI
if doc_ui:
self._init_doc()

# Add the OpenAPI command
self.cli.add_command(openapi_command)

# Initialize specification JSON
self.spec_json: Dict = dict()

def _init_doc(self) -> None:
"""
Provide Swagger UI, Redoc, and Rapidoc
Expand Down Expand Up @@ -198,6 +215,9 @@ def api_doc(self) -> Dict:
The OpenAPI specification JSON as a dictionary.

"""
if self.spec_json:
return self.spec_json

spec = APISpec(
openapi=self.openapi_version,
info=self.info,
Expand All @@ -210,18 +230,40 @@ def api_doc(self) -> Dict:
# Set paths
spec.paths = self.paths

# Add ValidationErrorModel to components schemas
self.components_schemas[self.validation_error_model.__name__] = Schema(**self.validation_error_model.schema())

# Set components
self.components.schemas = self.components_schemas
self.components.securitySchemes = self.security_schemes
spec.components = self.components

# Convert spec to JSON
spec_json = json.loads(spec.json(by_alias=True, exclude_none=True))
self.spec_json = json.loads(spec.json(by_alias=True, exclude_none=True))

# Update with OpenAPI extensions
spec_json.update(**self.openapi_extensions)

return spec_json
self.spec_json.update(**self.openapi_extensions)

# Handle validation error response
for rule, path_item in self.spec_json["paths"].items():
for http_method, operation in path_item.items():
if not operation.get("responses"):
operation["responses"] = {}
if operation["responses"].get(self.validation_error_status):
continue
operation["responses"][self.validation_error_status] = {
"description": HTTP_STATUS[self.validation_error_status],
"content": {
"application/json": {
"schema": {
"type": "array",
"items": {"$ref": f"{OPENAPI3_REF_PREFIX}/{self.validation_error_model.__name__}"}
}
}
}
}

return self.spec_json

def register_api(self, api: APIBlueprint) -> None:
"""
Expand Down
14 changes: 6 additions & 8 deletions flask_openapi3/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from json import JSONDecodeError
from typing import Any, Type, Optional, Dict, Union

from flask import request, make_response
from flask.wrappers import Response
from flask import request, current_app
from flask.wrappers import Response as FlaskResponse
from pydantic import ValidationError, BaseModel
from pydantic.error_wrappers import ErrorWrapper

Expand Down Expand Up @@ -94,7 +94,7 @@ def _do_request(
form: Optional[Type[BaseModel]] = None,
body: Optional[Type[BaseModel]] = None,
path_kwargs: Optional[Dict[Any, Any]] = None
) -> Union[Response, Dict]:
) -> Union[FlaskResponse, Dict]:
"""
Validate requests and responses.

Expand Down Expand Up @@ -132,10 +132,8 @@ def _do_request(
if body:
_do_body(body, func_kwargs)
except ValidationError as e:
# Create a JSON response with validation error details
response = make_response(e.json())
response.headers["Content-Type"] = "application/json"
response.status_code = 422
return response
# Create a response with validation error details
validation_error_callback = getattr(current_app, "validation_error_callback")
return validation_error_callback(e)

return func_kwargs
22 changes: 11 additions & 11 deletions flask_openapi3/scaffold.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import Callable, List, Optional, Dict, Type, Any, Tuple, Union

from flask.scaffold import Scaffold
from flask.wrappers import Response
from flask.wrappers import Response as FlaskResponse
from pydantic import BaseModel

from .http import HTTPMethod
Expand Down Expand Up @@ -77,7 +77,7 @@ def create_view_func(
is_coroutine_function = iscoroutinefunction(func)
if is_coroutine_function:
@wraps(func)
async def view_func(**kwargs) -> Response:
async def view_func(**kwargs) -> FlaskResponse:
func_kwargs = _do_request(
header=header,
cookie=cookie,
Expand All @@ -87,8 +87,8 @@ async def view_func(**kwargs) -> Response:
body=body,
path_kwargs=kwargs
)
if isinstance(func_kwargs, Response):
# 422
if isinstance(func_kwargs, FlaskResponse):
# Validation error response
return func_kwargs
# handle async request
if view_class:
Expand All @@ -104,8 +104,8 @@ async def view_func(**kwargs) -> Response:
return response
else:
@wraps(func)
def view_func(**kwargs) -> Response:
result = _do_request(
def view_func(**kwargs) -> FlaskResponse:
func_kwargs = _do_request(
header=header,
cookie=cookie,
path=path,
Expand All @@ -114,9 +114,9 @@ def view_func(**kwargs) -> Response:
body=body,
path_kwargs=kwargs
)
if isinstance(result, Response):
# 422
return result
if isinstance(func_kwargs, FlaskResponse):
# Validation error response
return func_kwargs
# handle request
if view_class:
signature = inspect.signature(view_class.__init__)
Expand All @@ -125,9 +125,9 @@ def view_func(**kwargs) -> Response:
view_object = view_class(view_kwargs=view_kwargs)
else:
view_object = view_class()
response = func(view_object, **result)
response = func(view_object, **func_kwargs)
else:
response = func(**result)
response = func(**func_kwargs)
return response

if not hasattr(func, "view"):
Expand Down
47 changes: 22 additions & 25 deletions flask_openapi3/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
import re
from typing import get_type_hints, Dict, Type, Callable, List, Tuple, Optional, Any, Union

from pydantic import BaseModel
from flask import make_response, current_app
from flask.wrappers import Response as FlaskResponse
from pydantic import BaseModel, ValidationError

from .http import HTTP_STATUS, HTTPMethod
from .models import OPENAPI3_REF_TEMPLATE, OPENAPI3_REF_PREFIX, Tag
from .models.common import Schema, MediaType, Encoding, ExtraRequestBody
from .models.path import Operation, RequestBody, PathItem, Response
from .models.path import ParameterInType, Parameter
from .models.validation_error import UnprocessableEntity


def get_operation(
Expand Down Expand Up @@ -283,39 +284,19 @@ def parse_body(


def get_responses(
responses: Optional[Dict[str, Union[Type[BaseModel], Dict[Any, Any], None]]],
responses: Dict[str, Union[Type[BaseModel], Dict[Any, Any], None]],
extra_responses: Dict[str, dict],
components_schemas: dict,
operation: Operation
) -> None:
if responses is None:
responses = {}
_responses = {}
_schemas = {}
if not responses.get("422"):
# Handle response 422 for Unprocessable Entity
_responses["422"] = Response(
description=HTTP_STATUS["422"],
content={
"application/json": MediaType(
**{
"schema": Schema(
**{
"type": "array",
"items": {"$ref": f"{OPENAPI3_REF_PREFIX}/{UnprocessableEntity.__name__}"}
}
)
}
)
}
)
_schemas[UnprocessableEntity.__name__] = Schema(**UnprocessableEntity.schema())

for key, response in responses.items():
if response is None:
# If the response is None, it means HTTP status code "204" (No Content)
_responses[key] = Response(description=HTTP_STATUS.get(key, ""))
continue
if isinstance(response, dict):
elif isinstance(response, dict):
_responses[key] = response # type: ignore
else:
schema = get_model_schema(response)
Expand Down Expand Up @@ -560,6 +541,22 @@ def parse_method(uri: str, method: str, paths: dict, operation: Operation) -> No
paths[uri].delete = operation


def make_validation_error_response(e: ValidationError) -> FlaskResponse:
"""
Create a Flask response for a validation error.

Args:
e: The ValidationError object containing the details of the error.

Returns:
FlaskResponse: A Flask Response object with the JSON representation of the error.
"""
response = make_response(e.json())
response.headers["Content-Type"] = "application/json"
response.status_code = getattr(current_app, "validation_error_status", 422)
return response


def parse_rule(rule: str, url_prefix=None) -> str:
trail_slash = rule.endswith("/")

Expand Down
Loading