Skip to content

Commit 488b341

Browse files
author
sean
committed
Add support for aiohttp.web
1 parent 203b9af commit 488b341

File tree

8 files changed

+798
-3
lines changed

8 files changed

+798
-3
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from openapi_core.contrib.aiohttp.requests import AIOHTTPOpenAPIWebRequest
2+
from openapi_core.contrib.aiohttp.responses import AIOHTTPOpenAPIWebResponse
3+
4+
__all__ = [
5+
"AIOHTTPOpenAPIWebRequest",
6+
"AIOHTTPOpenAPIWebResponse",
7+
]
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
"""OpenAPI core contrib aiohttp requests module"""
2+
from __future__ import annotations
3+
4+
from typing import cast
5+
6+
from aiohttp import web
7+
from asgiref.sync import AsyncToSync
8+
9+
from openapi_core.datatypes import RequestParameters
10+
11+
12+
class Empty:
13+
...
14+
15+
16+
_empty = Empty()
17+
18+
19+
class AIOHTTPOpenAPIWebRequest:
20+
__slots__ = ("request", "parameters", "_get_body", "_body")
21+
22+
def __init__(self, request: web.Request, *, body: str | None):
23+
if not isinstance(request, web.Request):
24+
raise TypeError(
25+
f"'request' argument is not type of {web.Request.__qualname__!r}"
26+
)
27+
self.request = request
28+
self.parameters = RequestParameters(
29+
query=self.request.query,
30+
header=self.request.headers,
31+
cookie=self.request.cookies,
32+
)
33+
self._body = body
34+
35+
@property
36+
def host_url(self) -> str:
37+
return self.request.url.host or ""
38+
39+
@property
40+
def path(self) -> str:
41+
return self.request.url.path
42+
43+
@property
44+
def method(self) -> str:
45+
return self.request.method.lower()
46+
47+
@property
48+
def body(self) -> str | None:
49+
return self._body
50+
51+
@property
52+
def mimetype(self) -> str:
53+
return self.request.content_type
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
"""OpenAPI core contrib aiohttp responses module"""
2+
3+
import multidict
4+
from aiohttp import web
5+
6+
7+
class AIOHTTPOpenAPIWebResponse:
8+
def __init__(self, response: web.Response):
9+
if not isinstance(response, web.Response):
10+
raise TypeError(
11+
f"'response' argument is not type of {web.Response.__qualname__!r}"
12+
)
13+
self.response = response
14+
15+
@property
16+
def data(self) -> str:
17+
if isinstance(self.response.body, bytes):
18+
return self.response.body.decode("utf-8")
19+
assert isinstance(self.response.body, str)
20+
return self.response.body
21+
22+
@property
23+
def status_code(self) -> int:
24+
return self.response.status
25+
26+
@property
27+
def mimetype(self) -> str:
28+
return self.response.content_type or ""
29+
30+
@property
31+
def headers(self) -> multidict.CIMultiDict[str]:
32+
return self.response.headers

poetry.lock

Lines changed: 453 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ pathable = "^0.4.0"
6060
django = {version = ">=3.0", optional = true}
6161
falcon = {version = ">=3.0", optional = true}
6262
flask = {version = "*", optional = true}
63+
aiohttp = {version = ">=3.0", optional = true}
6364
isodate = "*"
6465
more-itertools = "*"
6566
parse = "*"
@@ -80,8 +81,9 @@ falcon = ["falcon"]
8081
flask = ["flask"]
8182
requests = ["requests"]
8283
starlette = ["starlette", "httpx"]
84+
aiohttp = ["aiohttp"]
8385

84-
[tool.poetry.dev-dependencies]
86+
[tool.poetry.group.dev.dependencies]
8587
black = "^23.3.0"
8688
django = ">=3.0"
8789
djangorestframework = "^3.11.2"
@@ -98,6 +100,8 @@ webob = "*"
98100
mypy = "^1.2"
99101
starlette = "^0.26.1"
100102
httpx = "^0.24.0"
103+
aiohttp = "^3.8.4"
104+
pytest-aiohttp = "^1.0.4"
101105

102106
[tool.poetry.group.docs.dependencies]
103107
sphinx = "^5.3.0"
@@ -113,6 +117,7 @@ addopts = """
113117
--cov-report=term-missing
114118
--cov-report=xml
115119
"""
120+
asyncio_mode = "auto"
116121

117122
[tool.black]
118123
line-length = 79
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
import asyncio
2+
import pathlib
3+
from typing import Any
4+
from unittest import mock
5+
6+
import pytest
7+
from aiohttp import web
8+
from aiohttp.test_utils import TestClient
9+
10+
from openapi_core import openapi_request_validator
11+
from openapi_core import openapi_response_validator
12+
from openapi_core.contrib.aiohttp import AIOHTTPOpenAPIWebRequest
13+
from openapi_core.contrib.aiohttp import AIOHTTPOpenAPIWebResponse
14+
15+
16+
@pytest.fixture
17+
def spec(factory):
18+
directory = pathlib.Path(__file__).parent
19+
specfile = directory / "data" / "v3.0" / "aiohttp_factory.yaml"
20+
return factory.spec_from_file(str(specfile))
21+
22+
23+
@pytest.fixture
24+
def response_getter() -> mock.MagicMock:
25+
return mock.MagicMock(return_value={"data": "data"})
26+
27+
28+
@pytest.fixture
29+
def no_validation(response_getter):
30+
async def test_route(request: web.Request) -> web.Response:
31+
await asyncio.sleep(0)
32+
response = web.json_response(
33+
response_getter(),
34+
headers={"X-Rate-Limit": "12"},
35+
status=200,
36+
)
37+
return response
38+
39+
return test_route
40+
41+
42+
@pytest.fixture
43+
def request_validation(spec, response_getter):
44+
async def test_route(request: web.Request) -> web.Response:
45+
request_body = await request.text()
46+
openapi_request = AIOHTTPOpenAPIWebRequest(request, body=request_body)
47+
result = openapi_request_validator.validate(spec, openapi_request)
48+
response: dict[str, Any] = response_getter()
49+
status = 200
50+
if result.errors:
51+
status = 400
52+
response = {"errors": [{"message": str(e) for e in result.errors}]}
53+
return web.json_response(
54+
response,
55+
headers={"X-Rate-Limit": "12"},
56+
status=status,
57+
)
58+
59+
return test_route
60+
61+
62+
@pytest.fixture
63+
def response_validation(spec, response_getter):
64+
async def test_route(request: web.Request) -> web.Response:
65+
request_body = await request.text()
66+
openapi_request = AIOHTTPOpenAPIWebRequest(request, body=request_body)
67+
response_body = response_getter()
68+
response = web.json_response(
69+
response_body,
70+
headers={"X-Rate-Limit": "12"},
71+
status=200,
72+
)
73+
openapi_response = AIOHTTPOpenAPIWebResponse(response)
74+
result = openapi_response_validator.validate(
75+
spec, openapi_request, openapi_response
76+
)
77+
if result.errors:
78+
response = web.json_response(
79+
{"errors": [{"message": str(e) for e in result.errors}]},
80+
headers={"X-Rate-Limit": "12"},
81+
status=400,
82+
)
83+
return response
84+
85+
return test_route
86+
87+
88+
@pytest.fixture(
89+
params=["no_validation", "request_validation", "response_validation"]
90+
)
91+
def router(
92+
request,
93+
no_validation,
94+
request_validation,
95+
response_validation,
96+
) -> web.RouteTableDef:
97+
test_routes = dict(
98+
no_validation=no_validation,
99+
request_validation=request_validation,
100+
response_validation=response_validation,
101+
)
102+
router_ = web.RouteTableDef()
103+
handler = test_routes[request.param]
104+
route = router_.post("/browse/{id}/")(handler)
105+
return router_
106+
107+
108+
@pytest.fixture
109+
def app(router):
110+
app = web.Application()
111+
app.add_routes(router)
112+
113+
return app
114+
115+
116+
@pytest.fixture
117+
async def client(app, aiohttp_client) -> TestClient:
118+
return await aiohttp_client(app)
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
openapi: "3.0.0"
2+
info:
3+
title: Basic OpenAPI specification used with starlette integration tests
4+
version: "0.1"
5+
servers:
6+
- url: '/'
7+
description: 'testing'
8+
paths:
9+
'/browse/{id}/':
10+
parameters:
11+
- name: id
12+
in: path
13+
required: true
14+
description: the ID of the resource to retrieve
15+
schema:
16+
type: integer
17+
- name: q
18+
in: query
19+
required: true
20+
description: query key
21+
schema:
22+
type: string
23+
post:
24+
requestBody:
25+
description: request data
26+
required: True
27+
content:
28+
application/json:
29+
schema:
30+
type: object
31+
required:
32+
- param1
33+
properties:
34+
param1:
35+
type: integer
36+
responses:
37+
200:
38+
description: Return the resource.
39+
content:
40+
application/json:
41+
schema:
42+
type: object
43+
required:
44+
- data
45+
properties:
46+
data:
47+
type: string
48+
headers:
49+
X-Rate-Limit:
50+
description: Rate limit
51+
schema:
52+
type: integer
53+
required: true
54+
default:
55+
description: Return errors.
56+
content:
57+
application/json:
58+
schema:
59+
type: object
60+
required:
61+
- errors
62+
properties:
63+
errors:
64+
type: array
65+
items:
66+
type: object
67+
properties:
68+
title:
69+
type: string
70+
code:
71+
type: string
72+
message:
73+
type: string
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
from unittest import mock
5+
6+
import pytest
7+
8+
if TYPE_CHECKING:
9+
from aiohttp.test_utils import TestClient
10+
11+
12+
async def test_aiohttp_integration_valid_input(client: TestClient):
13+
# Given
14+
given_query_string = {
15+
"q": "string",
16+
}
17+
given_headers = {"content-type": "application/json"}
18+
given_data = {"param1": 1}
19+
expected_status_code = 200
20+
expected_response_data = {"data": "data"}
21+
# When
22+
response = await client.post(
23+
"/browse/12/",
24+
params=given_query_string,
25+
json=given_data,
26+
headers=given_headers,
27+
)
28+
response_data = await response.json()
29+
assert response.status == expected_status_code
30+
assert response_data == expected_response_data
31+
32+
33+
async def test_aiohttp_integration_invalid_input(
34+
client: TestClient, response_getter, request
35+
):
36+
if "no_validation" in request.node.name:
37+
pytest.skip("No validation for given handler.")
38+
# Given
39+
given_query_string = {
40+
"q": "string",
41+
}
42+
given_headers = {"content-type": "application/json"}
43+
given_data = {"param1": "string"}
44+
response_getter.return_value = {"data": 1}
45+
expected_status_code = 400
46+
expected_response_data = {"errors": [{"message": mock.ANY}]}
47+
# When
48+
response = await client.post(
49+
"/browse/12/",
50+
params=given_query_string,
51+
json=given_data,
52+
headers=given_headers,
53+
)
54+
response_data = await response.json()
55+
assert response.status == expected_status_code
56+
assert response_data == expected_response_data

0 commit comments

Comments
 (0)