Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion docs/Usage/Model_Config.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,4 +105,18 @@ class Message(BaseModel):

Effect in swagger:

![](../assets/Snipaste_2023-06-02_11-08-40.png)
![](../assets/Snipaste_2023-06-02_11-08-40.png)


## by_alias

Sometimes you may not want to use aliases (such as in the responses model). In that case, `by_alias` will be convenient:

```python
class MessageResponse(BaseModel):
message: str = Field(..., description="The message")
metadata: Dict[str, str] = Field(alias="metadata_")

class Config:
by_alias = False
```
25 changes: 14 additions & 11 deletions flask_openapi3/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,18 +83,21 @@ def get_operation_id_for_path(*, name: str, path: str, method: str) -> str:
return operation_id


def get_schema(obj: Type[BaseModel]) -> dict:
def get_model_schema(model: Type[BaseModel]) -> dict:
"""Converts a Pydantic model to an OpenAPI schema."""

assert inspect.isclass(obj) and issubclass(obj, BaseModel), \
f"{obj} is invalid `pydantic.BaseModel`"
assert inspect.isclass(model) and issubclass(model, BaseModel), \
f"{model} is invalid `pydantic.BaseModel`"

return obj.schema(ref_template=OPENAPI3_REF_TEMPLATE)
model_config = model.Config
by_alias = getattr(model_config, "by_alias", True)

return model.schema(by_alias=by_alias, ref_template=OPENAPI3_REF_TEMPLATE)


def parse_header(header: Type[BaseModel]) -> Tuple[List[Parameter], dict]:
"""Parses a header model and returns a list of parameters and component schemas."""
schema = get_schema(header)
schema = get_model_schema(header)
parameters = []
components_schemas: Dict = dict()
properties = schema.get("properties", {})
Expand All @@ -121,7 +124,7 @@ def parse_header(header: Type[BaseModel]) -> Tuple[List[Parameter], dict]:

def parse_cookie(cookie: Type[BaseModel]) -> Tuple[List[Parameter], dict]:
"""Parses a cookie model and returns a list of parameters and component schemas."""
schema = get_schema(cookie)
schema = get_model_schema(cookie)
parameters = []
components_schemas: Dict = dict()
properties = schema.get("properties", {})
Expand All @@ -148,7 +151,7 @@ def parse_cookie(cookie: Type[BaseModel]) -> Tuple[List[Parameter], dict]:

def parse_path(path: Type[BaseModel]) -> Tuple[List[Parameter], dict]:
"""Parses a path model and returns a list of parameters and component schemas."""
schema = get_schema(path)
schema = get_model_schema(path)
parameters = []
components_schemas: Dict = dict()
properties = schema.get("properties", {})
Expand All @@ -175,7 +178,7 @@ def parse_path(path: Type[BaseModel]) -> Tuple[List[Parameter], dict]:

def parse_query(query: Type[BaseModel]) -> Tuple[List[Parameter], dict]:
"""Parses a query model and returns a list of parameters and component schemas."""
schema = get_schema(query)
schema = get_model_schema(query)
parameters = []
components_schemas: Dict = dict()
properties = schema.get("properties", {})
Expand Down Expand Up @@ -205,7 +208,7 @@ def parse_form(
extra_form: Optional[ExtraRequestBody] = None,
) -> Tuple[Dict[str, MediaType], dict]:
"""Parses a form model and returns a list of parameters and component schemas."""
schema = get_schema(form)
schema = get_model_schema(form)
components_schemas = dict()
properties = schema.get("properties", {})

Expand Down Expand Up @@ -250,7 +253,7 @@ def parse_body(
extra_body: Optional[ExtraRequestBody] = None,
) -> Tuple[Dict[str, MediaType], dict]:
"""Parses a body model and returns a list of parameters and component schemas."""
schema = get_schema(body)
schema = get_model_schema(body)
components_schemas = dict()

title = schema.get("title")
Expand Down Expand Up @@ -315,7 +318,7 @@ def get_responses(
if isinstance(response, dict):
_responses[key] = response # type: ignore
else:
schema = response.schema(ref_template=OPENAPI3_REF_TEMPLATE)
schema = get_model_schema(response)
_responses[key] = Response(
description=HTTP_STATUS.get(key, ""),
content={
Expand Down
151 changes: 151 additions & 0 deletions tests/test_model_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
# -*- coding: utf-8 -*-
# @Author : llc
# @Time : 2023/6/30 10:12
from typing import List, Dict

import pytest
from pydantic import BaseModel, Field

from flask_openapi3 import OpenAPI, FileStorage

app = OpenAPI(__name__)
app.config["TESTING"] = True


class UploadFilesForm(BaseModel):
file: FileStorage
str_list: List[str]

class Config:
openapi_extra = {
# "example": {"a": 123},
"examples": {
"Example 01": {
"summary": "An example",
"value": {
"file": "Example-01.jpg",
"str_list": ["a", "b", "c"]
}
},
"Example 02": {
"summary": "Another example",
"value": {
"str_list": ["1", "2", "3"]
}
}
}
}


class BookBody(BaseModel):
age: int
author: str

class Config:
openapi_extra = {
"description": "This is post RequestBody",
"example": {"age": 12, "author": "author1"},
"examples": {
"example1": {
"summary": "example summary1",
"description": "example description1",
"value": {
"age": 24,
"author": "author2"
}
},
"example2": {
"summary": "example summary2",
"description": "example description2",
"value": {
"age": 48,
"author": "author3"
}
}

}}


class MessageResponse(BaseModel):
message: str = Field(..., description="The message")
metadata: Dict[str, str] = Field(alias="metadata_")

class Config:
by_alias = False
openapi_extra = {
# "example": {"message": "aaa"},
"examples": {
"example1": {
"summary": "example1 summary",
"value": {
"message": "bbb"
}
},
"example2": {
"summary": "example2 summary",
"value": {
"message": "ccc"
}
}
}
}


@app.post("/form")
def api_form(form: UploadFilesForm):
print(form)
return {"code": 0, "message": "ok"}


@app.post("/body", responses={"200": MessageResponse})
def api_error_json(body: BookBody):
print(body)
return {"code": 0, "message": "ok"}


@pytest.fixture
def client():
client = app.test_client()

return client


def test_openapi(client):
resp = client.get("/openapi/openapi.json")
_json = resp.json
assert resp.status_code == 200
assert _json["paths"]["/form"]["post"]["requestBody"]["content"]["multipart/form-data"]["examples"] == \
{
"Example 01": {
"summary": "An example",
"value": {
"file": "Example-01.jpg",
"str_list": ["a", "b", "c"]
}
},
"Example 02": {
"summary": "Another example",
"value": {
"str_list": ["1", "2", "3"]
}
}
}
assert _json["paths"]["/body"]["post"]["requestBody"]["description"] == "This is post RequestBody"
assert _json["paths"]["/body"]["post"]["requestBody"]["content"]["application/json"]["example"] == \
{"age": 12, "author": "author1"}
assert _json["paths"]["/body"]["post"]["responses"]["200"]["content"]["application/json"]["examples"] == \
{
"example1": {
"summary": "example1 summary",
"value": {
"message": "bbb"
}
},
"example2": {
"summary": "example2 summary",
"value": {
"message": "ccc"
}
}
}
assert _json["components"]["schemas"]["MessageResponse"]["properties"].get("metadata") is not None