From 26e2d0d8dad3e6b5c15524aa9410b50d22a8d03a Mon Sep 17 00:00:00 2001 From: Chad Furman Date: Sun, 12 Sep 2021 22:09:40 -0400 Subject: [PATCH 01/68] chore: registered new pytest mark --- tox.ini | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 19231ed7..47dcd9e5 100644 --- a/tox.ini +++ b/tox.ini @@ -4,7 +4,9 @@ envlist = py{36,37,38,39,py3} [pytest] -markers = asyncio +markers = + asyncio: Tests belonging to the asyncio transport + appsyncwebsocket: Tests belonging to the AwsAppSyncWebsocket transport [gh-actions] python = From d4c8a7b06d39ae44645994de3e66db5ad3c50b07 Mon Sep 17 00:00:00 2001 From: Chad Furman Date: Sun, 12 Sep 2021 22:09:51 -0400 Subject: [PATCH 02/68] chore: added botocore dependency --- setup.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 248099ab..248dbcf4 100644 --- a/setup.py +++ b/setup.py @@ -44,8 +44,12 @@ "websockets>=9,<10", ] +install_aws_requires = [ + "botocore>=1.21,<1.22", +] + install_websockets_requires + install_all_requires = ( - install_aiohttp_requires + install_requests_requires + install_websockets_requires + install_aiohttp_requires + install_requests_requires + install_websockets_requires + install_aws_requires ) # Get version from __version__.py file @@ -87,6 +91,7 @@ "aiohttp": install_aiohttp_requires, "requests": install_requests_requires, "websockets": install_websockets_requires, + "aws": install_aws_requires, }, include_package_data=True, zip_safe=False, From 8535f5e311524ae6f393d3274a9b389363ddd66f Mon Sep 17 00:00:00 2001 From: Chad Furman Date: Sun, 12 Sep 2021 22:10:22 -0400 Subject: [PATCH 03/68] test: initial test for appsyncwebsocket before I started coyp/pasting lots of stuff and not writing tests in the process :shamecube: --- tests/test_appsyncwebsocket.py | 70 ++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 tests/test_appsyncwebsocket.py diff --git a/tests/test_appsyncwebsocket.py b/tests/test_appsyncwebsocket.py new file mode 100644 index 00000000..66dd5ec8 --- /dev/null +++ b/tests/test_appsyncwebsocket.py @@ -0,0 +1,70 @@ +import pytest + +from gql.transport.awsappsyncwebsocket import AppSyncWebsocketsTransport, AppSyncIAMAuthorization, AppSyncOIDCAuthorization, AppSyncApiKeyAuthorization + +# TODO +# from gql.transport.exceptions import ( +# TransportAlreadyConnected, +# TransportClosed, +# TransportProtocolError, +# TransportQueryError, +# TransportServerError, +# ) + +# TODO: +# from .conftest import TemporaryFile + +# TODO: +# query1_str = """ +# query getContinents { +# continents { +# code +# name +# } +# } +# """ + +# TODO: +# query1_server_answer_data = ( +# '{"continents":[' +# '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' +# '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' +# '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' +# '{"code":"SA","name":"South America"}]}' +# ) + + +# TODO: +# query1_server_answer = f'{{"data":{query1_server_answer_data}}}' + +# Marking all tests in this file with the appsyncwebsocket marker +pytestmark = pytest.mark.appsyncwebsocket +mock_transport_url = "https://appsyncapp.awsgateway.com.example.org" + +@pytest.mark.appsyncwebsocket +def test_appsyncwebsocket_init_with_minimal_args(): + sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url) + assert isinstance(sample_transport.authorization, AppSyncIAMAuthorization) + assert sample_transport.connect_timeout == 10 + assert sample_transport.close_timeout == 10 + assert sample_transport.ack_timeout == 10 + assert sample_transport.ssl == False + assert sample_transport.connect_args == {} + +def test_appsyncwebsocket_init_with_oidc_auth(): + authorization = AppSyncOIDCAuthorization() + sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, authorization=authorization) + assert sample_transport.authorization is authorization + + +def test_appsyncwebsocket_init_with_apikey_auth(): + authorization = AppSyncApiKeyAuthorization() + sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, authorization=authorization) + assert sample_transport.authorization is authorization + + +def test_appsyncwebsocket_init_with_iam_auth(): + authorization = AppSyncIAMAuthorization() + sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, authorization=authorization) + assert sample_transport.authorization is authorization + From 149d78808dd747f779535244dbd4fe3815bb70e4 Mon Sep 17 00:00:00 2001 From: Chad Furman Date: Sun, 12 Sep 2021 22:11:28 -0400 Subject: [PATCH 04/68] feat: wip aws appsync websocket code inspired / copied extensively from work by @joseph-wortmann here: https://github.com/graphql-python/gql/issues/125#issuecomment-907827947 --- gql/transport/awsappsyncwebsocket.py | 145 +++++++++++++++++++++++++++ 1 file changed, 145 insertions(+) create mode 100644 gql/transport/awsappsyncwebsocket.py diff --git a/gql/transport/awsappsyncwebsocket.py b/gql/transport/awsappsyncwebsocket.py new file mode 100644 index 00000000..81616b35 --- /dev/null +++ b/gql/transport/awsappsyncwebsocket.py @@ -0,0 +1,145 @@ +from asyncio import wait_for, ensure_future + +from graphql import DocumentNode, print_ast + +from transport.exceptions import TransportProtocolError +from transport.websockets import WebsocketsTransport +from ssl import SSLContext +from typing import Any, Dict, Union, Optional +from abc import ABC, abstractmethod +from base64 import b64encode +from botocore.awsrequest import AWSRequest, create_request_object +from botocore.session import Session +from botocore.auth import SigV4Auth +import json + + +class AppSyncAuthorization(ABC): + def on_connect(self) -> str: + return b64encode( + json.dumps(self.get_headers(), separators=(",", ":")).encode() + ).decode() + + @abstractmethod + def get_headers(self, data: Optional[str] = None) -> Dict: + raise NotImplementedError() + + +class AppSyncApiKeyAuthorization(AppSyncAuthorization): + def __init__(self, host: str, api_key: str) -> None: + self.host = host + self.api_key = api_key + + def get_headers(self, data: Optional[str] = None) -> Dict: + return {"host": self.host, "x-api-key": self.api_key} + + +class AppSyncOIDCAuthorization(AppSyncAuthorization): + def __init__(self, host: str, jwt: str) -> None: + self.host = host + self.jwt = jwt + + def get_headers(self, data: Optional[str] = None) -> Dict: + return {"host": self.host, "Authorization": self.jwt} + + +class AppSyncCognitoUserPoolAuthorization(AppSyncOIDCAuthorization): + """Alias for AppSyncOIDCAuthorization""" + pass + + +class AppSyncIAMAuthorization(AppSyncAuthorization): + def __init__(self, host: str, region_name=None, profile=None) -> None: + self._host = host + self._session = Session(profile=profile) + self._credentials = self._session.get_credentials() + self._region_name = self._session._resolve_region_name(region_name, self._session.get_default_client_config()) + self._service_name = "appsync" + self._signer = SigV4Auth(self._credentials, self._service_name, self._region_name) + + def get_headers(self, data: Optional[str] = None) -> Dict: + request = create_request_object({ + 'method': 'GET', + 'url': self._host, + 'body': data, + }) + self._signer.add_auth(request) + return request.headers + + +class AppSyncWebsocketsTransport(WebsocketsTransport): + def __init__( + self, + url: str, + authorization: AppSyncAuthorization = None, + ssl: Union[SSLContext, bool] = False, + connect_timeout: int = 10, + close_timeout: int = 10, + ack_timeout: int = 10, + connect_args: Dict[str, Any] = {}, + ) -> None: + if authorization: + self.authorization = authorization + else: + self.authorization = AppSyncIAMAuthorization() + super().__init__( + url, + ssl=ssl, + connect_timeout=connect_timeout, + close_timeout=close_timeout, + ack_timeout=ack_timeout, + connect_args=connect_args, + ) + + async def _wait_start_ack(self) -> None: + """Wait for the start_ack message. Keep alive messages are ignored""" + + while True: + answer_type = str(json.loads(await self._receive()).get("type")) + + if answer_type == "start_ack": + return + + if answer_type != "ka": + raise TransportProtocolError( + "AppSync server did not return a start ack" + ) + + async def _send_query( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + ) -> int: + query_id = self.next_query_id + self.next_query_id += 1 + + data = {"query": print_ast(document)} + if variable_values: + data["variables"] = variable_values + if operation_name: + data["operationName"] = operation_name + + data["extensions"] = { + "authorization": self.authorization.get_headers(data) + }, + data = json.dumps(data, separators=(",", ":")) + + await self._send( + json.dumps( + { + "id": str(query_id), + "type": "start", + "payload": data, + }, + separators=(",", ":"), + ) + ) + + # Wait for the connection_ack message or raise a TimeoutError + await wait_for(self._wait_start_ack(), self.ack_timeout) + + # Create a task to listen to the incoming websocket messages + self.receive_data_task = ensure_future(self._receive_data_loop()) + + return query_id From 4d099e725d882deee0d94067ce910bd4556b86ad Mon Sep 17 00:00:00 2001 From: Chad Furman Date: Sun, 19 Sep 2021 19:08:29 -0400 Subject: [PATCH 05/68] fix: added url munging and cleaned up _send_query to avoid having to overload connect() and subscribe() --- gql/transport/awsappsyncwebsocket.py | 32 +++++++++++++++++----------- 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/gql/transport/awsappsyncwebsocket.py b/gql/transport/awsappsyncwebsocket.py index 81616b35..e5942472 100644 --- a/gql/transport/awsappsyncwebsocket.py +++ b/gql/transport/awsappsyncwebsocket.py @@ -82,6 +82,7 @@ def __init__( self.authorization = authorization else: self.authorization = AppSyncIAMAuthorization() + url = self._munge_url_for_appsync_auth(url) super().__init__( url, ssl=ssl, @@ -91,6 +92,19 @@ def __init__( connect_args=connect_args, ) + def _munge_url_for_appsync_auth(self, url: str) -> str: + """Munge URL For Appsync Auth + + :param url: The original URL where we replace 'https' and 'appsync-api' and append auth headers + :return: a new url used to establish websocket connections to the appsync-realtime-api + """ + url_after_replacements=url.replace("https", "wss").replace("appsync-api", "appsync-realtime-api") + headers_from_auth=self.authorization.get_headers() + return '{url}?header={headers}&payload=e30='.format( + url=url_after_replacements, + headers=headers_from_auth + ) + async def _wait_start_ack(self) -> None: """Wait for the start_ack message. Keep alive messages are ignored""" @@ -120,26 +134,20 @@ async def _send_query( if operation_name: data["operationName"] = operation_name - data["extensions"] = { - "authorization": self.authorization.get_headers(data) - }, - data = json.dumps(data, separators=(",", ":")) - await self._send( json.dumps( { "id": str(query_id), "type": "start", - "payload": data, + "payload": { + "data": data, + "extensions": { + "authorization": self.authorization.get_headers(data) + } + } }, separators=(",", ":"), ) ) - # Wait for the connection_ack message or raise a TimeoutError - await wait_for(self._wait_start_ack(), self.ack_timeout) - - # Create a task to listen to the incoming websocket messages - self.receive_data_task = ensure_future(self._receive_data_loop()) - return query_id From 5cb3c33b52964cca52a18b881c275cf4b90a7754 Mon Sep 17 00:00:00 2001 From: Chad Furman Date: Sun, 19 Sep 2021 19:09:13 -0400 Subject: [PATCH 06/68] test: wip test coverage. Added coverage for munge_url during init and started adding some dependency injection as well as some fakes. --- gql/transport/awsappsyncwebsocket.py | 9 +++++---- tests/fixtures/aws/fake_signer.py | 20 ++++++++++++++++++++ tests/test_appsyncwebsocket.py | 11 +++++++++++ 3 files changed, 36 insertions(+), 4 deletions(-) create mode 100644 tests/fixtures/aws/fake_signer.py diff --git a/gql/transport/awsappsyncwebsocket.py b/gql/transport/awsappsyncwebsocket.py index e5942472..2ffb2488 100644 --- a/gql/transport/awsappsyncwebsocket.py +++ b/gql/transport/awsappsyncwebsocket.py @@ -49,16 +49,17 @@ class AppSyncCognitoUserPoolAuthorization(AppSyncOIDCAuthorization): class AppSyncIAMAuthorization(AppSyncAuthorization): - def __init__(self, host: str, region_name=None, profile=None) -> None: + def __init__(self, host: str, region_name=None, profile=None, signer=None, request_creator=None) -> None: self._host = host self._session = Session(profile=profile) self._credentials = self._session.get_credentials() self._region_name = self._session._resolve_region_name(region_name, self._session.get_default_client_config()) self._service_name = "appsync" - self._signer = SigV4Auth(self._credentials, self._service_name, self._region_name) + self._signer = signer if signer else SigV4Auth(self._credentials, self._service_name, self._region_name) + self._request_creator = request_creator if request_creator else create_request_object - def get_headers(self, data: Optional[str] = None) -> Dict: - request = create_request_object({ + def get_headers(self, data: Optional[str] = None, request_creator: callable = None) -> Dict: + request = self._request_creator({ 'method': 'GET', 'url': self._host, 'body': data, diff --git a/tests/fixtures/aws/fake_signer.py b/tests/fixtures/aws/fake_signer.py new file mode 100644 index 00000000..d5ffd5d8 --- /dev/null +++ b/tests/fixtures/aws/fake_signer.py @@ -0,0 +1,20 @@ +def fake_request_creator(): + return FakeRequest() + +class FakeRequest(object): + headers = None + +class FakeSigner(object): + def __init__(self, request=None) -> None: + self.request = request if request else FakeRequest() + + def add_auth(self, request) -> None: + """ + A fake for getting a request object that + :return: + """ + request.headers = {"FakeAuthorization": "a", "FakeTime": "today"} + + def get_headers(self): + self.add_auth(self.request) + return self.request.headers \ No newline at end of file diff --git a/tests/test_appsyncwebsocket.py b/tests/test_appsyncwebsocket.py index 66dd5ec8..4813f509 100644 --- a/tests/test_appsyncwebsocket.py +++ b/tests/test_appsyncwebsocket.py @@ -68,3 +68,14 @@ def test_appsyncwebsocket_init_with_iam_auth(): sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, authorization=authorization) assert sample_transport.authorization is authorization + +@pytest.fixture('aws.fake_request_creator') +@pytest.fixture('aws.FakeSigner') +def test_munge_url(fake_signer,fake_request_creator): + authorization = AppSyncIAMAuthorization(signer=fake_signer, request_creator=fake_request_creator) + test_url = 'https://appsync-api.aws.example.org/some-other-params' + expected_url = 'wss://appsync-realtime-api.aws.example.org/some-other-params?header={headers}&payload=e30='.format(headers=authorization.on_connect()) + + sample_transport = AppSyncWebsocketsTransport(url=test_url, authorization=authorization) + + assert sample_transport.url == expected_url From 0fddb4fb5106c538a8d2e72b22a4266fa8064723 Mon Sep 17 00:00:00 2001 From: Chad Furman Date: Sun, 26 Sep 2021 20:50:00 -0400 Subject: [PATCH 07/68] test: started fixing unit tests. Looks like credentials are not being passed to add_auth. Gotta find where they should be coming from... --- gql/transport/awsappsyncwebsocket.py | 56 ++++++++++++++-------------- tests/fixtures/aws/__init__.py | 0 tests/fixtures/aws/fake_signer.py | 19 +++++++++- tests/test_appsyncwebsocket.py | 6 +-- 4 files changed, 47 insertions(+), 34 deletions(-) create mode 100644 tests/fixtures/aws/__init__.py diff --git a/gql/transport/awsappsyncwebsocket.py b/gql/transport/awsappsyncwebsocket.py index 2ffb2488..3ef14dff 100644 --- a/gql/transport/awsappsyncwebsocket.py +++ b/gql/transport/awsappsyncwebsocket.py @@ -2,8 +2,8 @@ from graphql import DocumentNode, print_ast -from transport.exceptions import TransportProtocolError -from transport.websockets import WebsocketsTransport +from .exceptions import TransportProtocolError +from .websockets import WebsocketsTransport from ssl import SSLContext from typing import Any, Dict, Union, Optional from abc import ABC, abstractmethod @@ -15,10 +15,23 @@ class AppSyncAuthorization(ABC): - def on_connect(self) -> str: - return b64encode( - json.dumps(self.get_headers(), separators=(",", ":")).encode() + def __init__(self, host: str): + self._host = host + + def host_to_auth_url(self) -> str: + """Munge Host For Appsync Auth + + :return: a url used to establish websocket connections to the appsync-realtime-api + """ + url_after_replacements=self._host.replace("https", "wss").replace("appsync-api", "appsync-realtime-api") + headers_from_auth=self.get_headers() + encoded_headers = b64encode( + json.dumps(headers_from_auth, separators=(",", ":")).encode() ).decode() + return '{url}?header={headers}&payload=e30='.format( + url=url_after_replacements, + headers=encoded_headers + ) @abstractmethod def get_headers(self, data: Optional[str] = None) -> Dict: @@ -27,20 +40,21 @@ def get_headers(self, data: Optional[str] = None) -> Dict: class AppSyncApiKeyAuthorization(AppSyncAuthorization): def __init__(self, host: str, api_key: str) -> None: - self.host = host + super().__init__(host) self.api_key = api_key def get_headers(self, data: Optional[str] = None) -> Dict: - return {"host": self.host, "x-api-key": self.api_key} + return {"host": self._host, "x-api-key": self.api_key} class AppSyncOIDCAuthorization(AppSyncAuthorization): def __init__(self, host: str, jwt: str) -> None: - self.host = host + super().__init__(host) self.jwt = jwt def get_headers(self, data: Optional[str] = None) -> Dict: - return {"host": self.host, "Authorization": self.jwt} + return {"host": self._host, "Authorization": self.jwt} + class AppSyncCognitoUserPoolAuthorization(AppSyncOIDCAuthorization): @@ -50,7 +64,7 @@ class AppSyncCognitoUserPoolAuthorization(AppSyncOIDCAuthorization): class AppSyncIAMAuthorization(AppSyncAuthorization): def __init__(self, host: str, region_name=None, profile=None, signer=None, request_creator=None) -> None: - self._host = host + super().__init__(host) self._session = Session(profile=profile) self._credentials = self._session.get_credentials() self._region_name = self._session._resolve_region_name(region_name, self._session.get_default_client_config()) @@ -62,6 +76,8 @@ def get_headers(self, data: Optional[str] = None, request_creator: callable = No request = self._request_creator({ 'method': 'GET', 'url': self._host, + 'headers': {}, + 'context': {}, 'body': data, }) self._signer.add_auth(request) @@ -79,11 +95,8 @@ def __init__( ack_timeout: int = 10, connect_args: Dict[str, Any] = {}, ) -> None: - if authorization: - self.authorization = authorization - else: - self.authorization = AppSyncIAMAuthorization() - url = self._munge_url_for_appsync_auth(url) + self.authorization = authorization if authorization else AppSyncIAMAuthorization(host=url) + url = self.authorization.host_to_auth_url() super().__init__( url, ssl=ssl, @@ -93,19 +106,6 @@ def __init__( connect_args=connect_args, ) - def _munge_url_for_appsync_auth(self, url: str) -> str: - """Munge URL For Appsync Auth - - :param url: The original URL where we replace 'https' and 'appsync-api' and append auth headers - :return: a new url used to establish websocket connections to the appsync-realtime-api - """ - url_after_replacements=url.replace("https", "wss").replace("appsync-api", "appsync-realtime-api") - headers_from_auth=self.authorization.get_headers() - return '{url}?header={headers}&payload=e30='.format( - url=url_after_replacements, - headers=headers_from_auth - ) - async def _wait_start_ack(self) -> None: """Wait for the start_ack message. Keep alive messages are ignored""" diff --git a/tests/fixtures/aws/__init__.py b/tests/fixtures/aws/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/fixtures/aws/fake_signer.py b/tests/fixtures/aws/fake_signer.py index d5ffd5d8..e1fe61dc 100644 --- a/tests/fixtures/aws/fake_signer.py +++ b/tests/fixtures/aws/fake_signer.py @@ -1,9 +1,24 @@ -def fake_request_creator(): - return FakeRequest() +import pytest + + +@pytest.fixture +def fake_request_factory(): + def _fake_request_factory(): + return FakeRequest() + yield _fake_request_factory + + +@pytest.fixture +def fake_signer_factory(fake_request_factory): + def _fake_signer_factory(request=None): + return FakeSigner(request=request) + yield _fake_signer_factory + class FakeRequest(object): headers = None + class FakeSigner(object): def __init__(self, request=None) -> None: self.request = request if request else FakeRequest() diff --git a/tests/test_appsyncwebsocket.py b/tests/test_appsyncwebsocket.py index 4813f509..a38f6e8e 100644 --- a/tests/test_appsyncwebsocket.py +++ b/tests/test_appsyncwebsocket.py @@ -69,10 +69,8 @@ def test_appsyncwebsocket_init_with_iam_auth(): assert sample_transport.authorization is authorization -@pytest.fixture('aws.fake_request_creator') -@pytest.fixture('aws.FakeSigner') -def test_munge_url(fake_signer,fake_request_creator): - authorization = AppSyncIAMAuthorization(signer=fake_signer, request_creator=fake_request_creator) +def test_munge_url(fake_signer_factory, fake_request_factory): + authorization = AppSyncIAMAuthorization(signer=fake_signer_factory(), request_creator=fake_request_factory()) test_url = 'https://appsync-api.aws.example.org/some-other-params' expected_url = 'wss://appsync-realtime-api.aws.example.org/some-other-params?header={headers}&payload=e30='.format(headers=authorization.on_connect()) From 246d2cc56d810816ebb4fa35d0f7cc0271edd3aa Mon Sep 17 00:00:00 2001 From: Chad Furman Date: Mon, 27 Sep 2021 19:45:44 -0400 Subject: [PATCH 08/68] test: added more dependency injection and fakes for aws. New error, now, about request needing some additional structure. That's next. --- gql/transport/awsappsyncwebsocket.py | 12 +++++++----- tests/fixtures/aws/fake_credentials.py | 12 ++++++++++++ tests/fixtures/aws/fake_request.py | 12 ++++++++++++ tests/fixtures/aws/fake_session.py | 21 +++++++++++++++++++++ tests/fixtures/aws/fake_signer.py | 15 +++------------ tests/test_appsyncwebsocket.py | 4 ++-- 6 files changed, 57 insertions(+), 19 deletions(-) create mode 100644 tests/fixtures/aws/fake_credentials.py create mode 100644 tests/fixtures/aws/fake_request.py create mode 100644 tests/fixtures/aws/fake_session.py diff --git a/gql/transport/awsappsyncwebsocket.py b/gql/transport/awsappsyncwebsocket.py index 3ef14dff..cd080e69 100644 --- a/gql/transport/awsappsyncwebsocket.py +++ b/gql/transport/awsappsyncwebsocket.py @@ -1,5 +1,6 @@ from asyncio import wait_for, ensure_future +import botocore.session from graphql import DocumentNode, print_ast from .exceptions import TransportProtocolError @@ -9,7 +10,7 @@ from abc import ABC, abstractmethod from base64 import b64encode from botocore.awsrequest import AWSRequest, create_request_object -from botocore.session import Session +from botocore.session import get_session from botocore.auth import SigV4Auth import json @@ -63,10 +64,10 @@ class AppSyncCognitoUserPoolAuthorization(AppSyncOIDCAuthorization): class AppSyncIAMAuthorization(AppSyncAuthorization): - def __init__(self, host: str, region_name=None, profile=None, signer=None, request_creator=None) -> None: + def __init__(self, host: str, region_name=None, signer=None, request_creator=None, credentials=None, session=None) -> None: super().__init__(host) - self._session = Session(profile=profile) - self._credentials = self._session.get_credentials() + self._session = session if session else get_session() + self._credentials = credentials if credentials else self._session.get_credentials() self._region_name = self._session._resolve_region_name(region_name, self._session.get_default_client_config()) self._service_name = "appsync" self._signer = signer if signer else SigV4Auth(self._credentials, self._service_name, self._region_name) @@ -89,13 +90,14 @@ def __init__( self, url: str, authorization: AppSyncAuthorization = None, + session: botocore.session.Session = None, ssl: Union[SSLContext, bool] = False, connect_timeout: int = 10, close_timeout: int = 10, ack_timeout: int = 10, connect_args: Dict[str, Any] = {}, ) -> None: - self.authorization = authorization if authorization else AppSyncIAMAuthorization(host=url) + self.authorization = authorization if authorization else AppSyncIAMAuthorization(host=url, session=session) url = self.authorization.host_to_auth_url() super().__init__( url, diff --git a/tests/fixtures/aws/fake_credentials.py b/tests/fixtures/aws/fake_credentials.py new file mode 100644 index 00000000..102f5f49 --- /dev/null +++ b/tests/fixtures/aws/fake_credentials.py @@ -0,0 +1,12 @@ +import pytest + +@pytest.fixture +def fake_credentials_factory(): + def _fake_credentials_factory(access_key=None, secret_key=None, method=None, token=None): + return { + "access_key": access_key if access_key else "fake-access-key", + "secret_key": secret_key if secret_key else "fake-secret-key", + "method": method if method else "shared-credentials-file", + "token": token if token else "fake-token", + } + yield _fake_credentials_factory \ No newline at end of file diff --git a/tests/fixtures/aws/fake_request.py b/tests/fixtures/aws/fake_request.py new file mode 100644 index 00000000..bb1de4ed --- /dev/null +++ b/tests/fixtures/aws/fake_request.py @@ -0,0 +1,12 @@ +import pytest + + +class FakeRequest(object): + headers = None + + +@pytest.fixture +def fake_request_factory(): + def _fake_request_factory(): + return FakeRequest() + yield _fake_request_factory diff --git a/tests/fixtures/aws/fake_session.py b/tests/fixtures/aws/fake_session.py new file mode 100644 index 00000000..11f4afa8 --- /dev/null +++ b/tests/fixtures/aws/fake_session.py @@ -0,0 +1,21 @@ +import pytest + + +class FakeSession(object): + def __init__(self, credentials, region_name): + self._credentials = credentials + self._region_name = region_name + + def get_credentials(self): + return self._credentials + + def _resolve_region_name(self): + return self._region_name + + +@pytest.fixture +def fake_session_factory(fake_credentials_factory): + def _fake_session_factory(): + return FakeSession(credentials=fake_credentials_factory, region='fake-region') + + yield _fake_session_factory diff --git a/tests/fixtures/aws/fake_signer.py b/tests/fixtures/aws/fake_signer.py index e1fe61dc..e50824e1 100644 --- a/tests/fixtures/aws/fake_signer.py +++ b/tests/fixtures/aws/fake_signer.py @@ -1,27 +1,18 @@ import pytest -@pytest.fixture -def fake_request_factory(): - def _fake_request_factory(): - return FakeRequest() - yield _fake_request_factory - - @pytest.fixture def fake_signer_factory(fake_request_factory): def _fake_signer_factory(request=None): + if not request: + request=fake_request_factory() return FakeSigner(request=request) yield _fake_signer_factory -class FakeRequest(object): - headers = None - - class FakeSigner(object): def __init__(self, request=None) -> None: - self.request = request if request else FakeRequest() + self.request = request def add_auth(self, request) -> None: """ diff --git a/tests/test_appsyncwebsocket.py b/tests/test_appsyncwebsocket.py index a38f6e8e..1793156e 100644 --- a/tests/test_appsyncwebsocket.py +++ b/tests/test_appsyncwebsocket.py @@ -42,8 +42,8 @@ mock_transport_url = "https://appsyncapp.awsgateway.com.example.org" @pytest.mark.appsyncwebsocket -def test_appsyncwebsocket_init_with_minimal_args(): - sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url) +def test_appsyncwebsocket_init_with_minimal_args(fake_session_factory): + sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, session=fake_session_factory()) assert isinstance(sample_transport.authorization, AppSyncIAMAuthorization) assert sample_transport.connect_timeout == 10 assert sample_transport.close_timeout == 10 From 07c9ae4c6933758aad46b02ab98c18b36ae79332 Mon Sep 17 00:00:00 2001 From: Chad Furman Date: Tue, 28 Sep 2021 20:19:29 -0400 Subject: [PATCH 09/68] tests: cleaned up awsappsyncwebsocket and related tests so that they all pass. Added additional validation and error handling for various states of credentials. Added additional fakes as needed. --- gql/transport/awsappsyncwebsocket.py | 24 ++++++- tests/conftest.py | 8 +++ tests/fixtures/__init__.py | 0 tests/fixtures/aws/fake_credentials.py | 22 ++++--- tests/fixtures/aws/fake_request.py | 13 +++- tests/fixtures/aws/fake_session.py | 11 ++-- tests/fixtures/fake_logger.py | 18 ++++++ tests/test_appsyncwebsocket.py | 87 +++++++++++++------------- 8 files changed, 123 insertions(+), 60 deletions(-) create mode 100644 tests/fixtures/__init__.py create mode 100644 tests/fixtures/fake_logger.py diff --git a/gql/transport/awsappsyncwebsocket.py b/gql/transport/awsappsyncwebsocket.py index cd080e69..c95ed16e 100644 --- a/gql/transport/awsappsyncwebsocket.py +++ b/gql/transport/awsappsyncwebsocket.py @@ -1,4 +1,5 @@ from asyncio import wait_for, ensure_future +from logging import Logger import botocore.session from graphql import DocumentNode, print_ast @@ -12,6 +13,7 @@ from botocore.awsrequest import AWSRequest, create_request_object from botocore.session import get_session from botocore.auth import SigV4Auth +from botocore.exceptions import NoCredentialsError import json @@ -82,7 +84,7 @@ def get_headers(self, data: Optional[str] = None, request_creator: callable = No 'body': data, }) self._signer.add_auth(request) - return request.headers + return dict(request.headers) class AppSyncWebsocketsTransport(WebsocketsTransport): @@ -96,9 +98,21 @@ def __init__( close_timeout: int = 10, ack_timeout: int = 10, connect_args: Dict[str, Any] = {}, + logger: Logger = None, ) -> None: - self.authorization = authorization if authorization else AppSyncIAMAuthorization(host=url, session=session) - url = self.authorization.host_to_auth_url() + self.logger = logger if logger else Logger('debug') + try: + self.authorization = authorization if authorization else AppSyncIAMAuthorization(host=url, session=session) + url = self.authorization.host_to_auth_url() + except botocore.exceptions.NoCredentialsError as e: + self.authorization = None + self.logger.log(0, 'Credentials not found. Do you have default AWS credentials configured?') + raise e + except TypeError as e: + self.authorization = None + self.logger.log(0, 'A TypeError was raised. The most likely reason for this is that the AWS region is missing from the credentials.') + raise MissingRegionError + super().__init__( url, ssl=ssl, @@ -154,3 +168,7 @@ async def _send_query( ) return query_id + + +class MissingRegionError(Exception): + pass diff --git a/tests/conftest.py b/tests/conftest.py index df69c121..945e70ba 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -385,3 +385,11 @@ async def run_sync_test_inner(event_loop, server, test_function): await server.close() return run_sync_test_inner + +pytest_plugins = [ + "tests.fixtures.fake_logger", + "tests.fixtures.aws.fake_credentials", + "tests.fixtures.aws.fake_request", + "tests.fixtures.aws.fake_session", + "tests.fixtures.aws.fake_signer", +] diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/fixtures/aws/fake_credentials.py b/tests/fixtures/aws/fake_credentials.py index 102f5f49..738eeabd 100644 --- a/tests/fixtures/aws/fake_credentials.py +++ b/tests/fixtures/aws/fake_credentials.py @@ -1,12 +1,18 @@ import pytest + +class FakeCredentials(object): + def __init__(self, access_key=None, secret_key=None, method=None, token=None, region=None): + self.region = region if region else "us-east-1a" + self.access_key = access_key if access_key else "fake-access-key" + self.secret_key = secret_key if secret_key else "fake-secret-key" + self.method = method if method else "shared-credentials-file" + self.token = token if token else "fake-token" + + @pytest.fixture def fake_credentials_factory(): - def _fake_credentials_factory(access_key=None, secret_key=None, method=None, token=None): - return { - "access_key": access_key if access_key else "fake-access-key", - "secret_key": secret_key if secret_key else "fake-secret-key", - "method": method if method else "shared-credentials-file", - "token": token if token else "fake-token", - } - yield _fake_credentials_factory \ No newline at end of file + def _fake_credentials_factory(access_key=None, secret_key=None, method=None, token=None, region=None): + return FakeCredentials(access_key=access_key, secret_key=secret_key, method=method, token=token, region=region) + + yield _fake_credentials_factory diff --git a/tests/fixtures/aws/fake_request.py b/tests/fixtures/aws/fake_request.py index bb1de4ed..ffca22db 100644 --- a/tests/fixtures/aws/fake_request.py +++ b/tests/fixtures/aws/fake_request.py @@ -4,9 +4,18 @@ class FakeRequest(object): headers = None + def __init__(self, request_props=None): + if not isinstance(request_props, dict): + return + self.method = request_props.get('method') + self.url = request_props.get('url') + self.headers = request_props.get('headers') + self.context = request_props.get('context') + self.body = request_props.get('body') + @pytest.fixture def fake_request_factory(): - def _fake_request_factory(): - return FakeRequest() + def _fake_request_factory(request_props=None): + return FakeRequest(request_props=request_props) yield _fake_request_factory diff --git a/tests/fixtures/aws/fake_session.py b/tests/fixtures/aws/fake_session.py index 11f4afa8..6c217ea4 100644 --- a/tests/fixtures/aws/fake_session.py +++ b/tests/fixtures/aws/fake_session.py @@ -6,16 +6,19 @@ def __init__(self, credentials, region_name): self._credentials = credentials self._region_name = region_name + def get_default_client_config(self): + return + def get_credentials(self): return self._credentials - def _resolve_region_name(self): - return self._region_name + def _resolve_region_name(self, region_name, client_config): + return region_name if region_name else self._region_name @pytest.fixture def fake_session_factory(fake_credentials_factory): - def _fake_session_factory(): - return FakeSession(credentials=fake_credentials_factory, region='fake-region') + def _fake_session_factory(credentials=fake_credentials_factory()): + return FakeSession(credentials=credentials, region_name='fake-region') yield _fake_session_factory diff --git a/tests/fixtures/fake_logger.py b/tests/fixtures/fake_logger.py new file mode 100644 index 00000000..23efe333 --- /dev/null +++ b/tests/fixtures/fake_logger.py @@ -0,0 +1,18 @@ +import pytest + + +class FakeLogger(object): + def __init__(self, messages=None): + self._messages = messages if messages else [] + + def log(self, level, message): + self._messages.append("LEVEL {}: {}".format(level, message)) + + +@pytest.fixture +def fake_logger_factory(): + + def _fake_logger_factory(messages=None): + return FakeLogger(messages=messages) + + yield _fake_logger_factory diff --git a/tests/test_appsyncwebsocket.py b/tests/test_appsyncwebsocket.py index 1793156e..f0381949 100644 --- a/tests/test_appsyncwebsocket.py +++ b/tests/test_appsyncwebsocket.py @@ -1,47 +1,12 @@ +import botocore.exceptions import pytest +from gql.transport.awsappsyncwebsocket import MissingRegionError from gql.transport.awsappsyncwebsocket import AppSyncWebsocketsTransport, AppSyncIAMAuthorization, AppSyncOIDCAuthorization, AppSyncApiKeyAuthorization -# TODO -# from gql.transport.exceptions import ( -# TransportAlreadyConnected, -# TransportClosed, -# TransportProtocolError, -# TransportQueryError, -# TransportServerError, -# ) - -# TODO: -# from .conftest import TemporaryFile - -# TODO: -# query1_str = """ -# query getContinents { -# continents { -# code -# name -# } -# } -# """ - -# TODO: -# query1_server_answer_data = ( -# '{"continents":[' -# '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' -# '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' -# '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' -# '{"code":"SA","name":"South America"}]}' -# ) - - -# TODO: -# query1_server_answer = f'{{"data":{query1_server_answer_data}}}' - -# Marking all tests in this file with the appsyncwebsocket marker -pytestmark = pytest.mark.appsyncwebsocket mock_transport_url = "https://appsyncapp.awsgateway.com.example.org" -@pytest.mark.appsyncwebsocket + def test_appsyncwebsocket_init_with_minimal_args(fake_session_factory): sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, session=fake_session_factory()) assert isinstance(sample_transport.authorization, AppSyncIAMAuthorization) @@ -51,29 +16,65 @@ def test_appsyncwebsocket_init_with_minimal_args(fake_session_factory): assert sample_transport.ssl == False assert sample_transport.connect_args == {} + +def test_appsyncwebsocket_init_with_no_credentials(fake_session_factory, fake_logger_factory): + fake_logger = fake_logger_factory() + with pytest.raises(botocore.exceptions.NoCredentialsError): + sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, session=fake_session_factory(credentials=None), logger=fake_logger) + assert sample_transport.authorization is None + assert fake_logger._messages.length > 0 + assert "credentials" in fake_logger._messages[0].lower() + def test_appsyncwebsocket_init_with_oidc_auth(): - authorization = AppSyncOIDCAuthorization() + authorization = AppSyncOIDCAuthorization(host=mock_transport_url, jwt="some-jwt") sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, authorization=authorization) assert sample_transport.authorization is authorization def test_appsyncwebsocket_init_with_apikey_auth(): - authorization = AppSyncApiKeyAuthorization() + authorization = AppSyncApiKeyAuthorization(host=mock_transport_url, api_key="some-api-key") sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, authorization=authorization) assert sample_transport.authorization is authorization def test_appsyncwebsocket_init_with_iam_auth(): - authorization = AppSyncIAMAuthorization() + authorization = AppSyncIAMAuthorization(host=mock_transport_url) + sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, authorization=authorization) + assert sample_transport.authorization is authorization + + +def test_appsyncwebsocket_init_with_iam_auth(fake_credentials_factory): + authorization = AppSyncIAMAuthorization(host=mock_transport_url, credentials=fake_credentials_factory(), region_name="us-east-1") sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, authorization=authorization) assert sample_transport.authorization is authorization + +def test_appsyncwebsocket_init_with_iam_auth_and_no_region(fake_credentials_factory, fake_logger_factory): + fake_logger = fake_logger_factory() + with pytest.raises(MissingRegionError): + authorization = AppSyncIAMAuthorization(host=mock_transport_url, credentials=fake_credentials_factory()) + sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, authorization=authorization, logger=fake_logger) + assert sample_transport.authorization is None + assert fake_logger._messages.length > 0 + assert "credentials" in fake_logger._messages[0].lower() + + def test_munge_url(fake_signer_factory, fake_request_factory): - authorization = AppSyncIAMAuthorization(signer=fake_signer_factory(), request_creator=fake_request_factory()) test_url = 'https://appsync-api.aws.example.org/some-other-params' - expected_url = 'wss://appsync-realtime-api.aws.example.org/some-other-params?header={headers}&payload=e30='.format(headers=authorization.on_connect()) + authorization = AppSyncIAMAuthorization(host=test_url, signer=fake_signer_factory(), request_creator=fake_request_factory) sample_transport = AppSyncWebsocketsTransport(url=test_url, authorization=authorization) + expected_url = authorization.host_to_auth_url() assert sample_transport.url == expected_url + + +def test_munge_url_format(fake_signer_factory, fake_request_factory, fake_credentials_factory, fake_session_factory): + test_url = 'https://appsync-api.aws.example.org/some-other-params' + + authorization = AppSyncIAMAuthorization(host=test_url, signer=fake_signer_factory(), session=fake_session_factory(), request_creator=fake_request_factory, credentials=fake_credentials_factory()) + + expected_url = 'wss://appsync-realtime-api.aws.example.org/some-other-params?header=eyJGYWtlQXV0aG9yaXphdGlvbiI6ImEiLCJGYWtlVGltZSI6InRvZGF5In0=&payload=e30=' + assert authorization.host_to_auth_url() == expected_url + From 25cde9706275e10a877cfc34da77c198bf357622 Mon Sep 17 00:00:00 2001 From: Chad Furman Date: Sun, 3 Oct 2021 20:59:09 -0400 Subject: [PATCH 10/68] fix: make check applied, flake errors fixed --- gql/transport/awsappsyncwebsocket.py | 114 ++++++++++++++++--------- tests/fixtures/aws/fake_credentials.py | 16 +++- tests/fixtures/aws/fake_request.py | 11 +-- tests/fixtures/aws/fake_session.py | 2 +- tests/fixtures/aws/fake_signer.py | 5 +- tests/fixtures/fake_logger.py | 1 - tests/test_appsyncwebsocket.py | 107 +++++++++++++++++------ 7 files changed, 178 insertions(+), 78 deletions(-) diff --git a/gql/transport/awsappsyncwebsocket.py b/gql/transport/awsappsyncwebsocket.py index c95ed16e..cb0e9186 100644 --- a/gql/transport/awsappsyncwebsocket.py +++ b/gql/transport/awsappsyncwebsocket.py @@ -1,20 +1,19 @@ -from asyncio import wait_for, ensure_future +import json +from abc import ABC, abstractmethod +from base64 import b64encode from logging import Logger +from ssl import SSLContext +from typing import Any, Dict, Optional, Union import botocore.session +from botocore.auth import SigV4Auth +from botocore.awsrequest import create_request_object +from botocore.exceptions import NoCredentialsError +from botocore.session import get_session from graphql import DocumentNode, print_ast from .exceptions import TransportProtocolError from .websockets import WebsocketsTransport -from ssl import SSLContext -from typing import Any, Dict, Union, Optional -from abc import ABC, abstractmethod -from base64 import b64encode -from botocore.awsrequest import AWSRequest, create_request_object -from botocore.session import get_session -from botocore.auth import SigV4Auth -from botocore.exceptions import NoCredentialsError -import json class AppSyncAuthorization(ABC): @@ -24,16 +23,18 @@ def __init__(self, host: str): def host_to_auth_url(self) -> str: """Munge Host For Appsync Auth - :return: a url used to establish websocket connections to the appsync-realtime-api + :return: a url used to establish websocket connections + to the appsync-realtime-api """ - url_after_replacements=self._host.replace("https", "wss").replace("appsync-api", "appsync-realtime-api") - headers_from_auth=self.get_headers() + url_after_replacements = self._host.replace("https", "wss").replace( + "appsync-api", "appsync-realtime-api" + ) + headers_from_auth = self.get_headers() encoded_headers = b64encode( json.dumps(headers_from_auth, separators=(",", ":")).encode() ).decode() - return '{url}?header={headers}&payload=e30='.format( - url=url_after_replacements, - headers=encoded_headers + return "{url}?header={headers}&payload=e30=".format( + url=url_after_replacements, headers=encoded_headers ) @abstractmethod @@ -59,30 +60,52 @@ def get_headers(self, data: Optional[str] = None) -> Dict: return {"host": self._host, "Authorization": self.jwt} - class AppSyncCognitoUserPoolAuthorization(AppSyncOIDCAuthorization): """Alias for AppSyncOIDCAuthorization""" + pass class AppSyncIAMAuthorization(AppSyncAuthorization): - def __init__(self, host: str, region_name=None, signer=None, request_creator=None, credentials=None, session=None) -> None: + def __init__( + self, + host: str, + region_name=None, + signer=None, + request_creator=None, + credentials=None, + session=None, + ) -> None: super().__init__(host) self._session = session if session else get_session() - self._credentials = credentials if credentials else self._session.get_credentials() - self._region_name = self._session._resolve_region_name(region_name, self._session.get_default_client_config()) + self._credentials = ( + credentials if credentials else self._session.get_credentials() + ) + self._region_name = self._session._resolve_region_name( + region_name, self._session.get_default_client_config() + ) self._service_name = "appsync" - self._signer = signer if signer else SigV4Auth(self._credentials, self._service_name, self._region_name) - self._request_creator = request_creator if request_creator else create_request_object - - def get_headers(self, data: Optional[str] = None, request_creator: callable = None) -> Dict: - request = self._request_creator({ - 'method': 'GET', - 'url': self._host, - 'headers': {}, - 'context': {}, - 'body': data, - }) + self._signer = ( + signer + if signer + else SigV4Auth(self._credentials, self._service_name, self._region_name) + ) + self._request_creator = ( + request_creator if request_creator else create_request_object + ) + + def get_headers( + self, data: Optional[str] = None, request_creator: callable = None + ) -> Dict: + request = self._request_creator( + { + "method": "GET", + "url": self._host, + "headers": {}, + "context": {}, + "body": data, + } + ) self._signer.add_auth(request) return dict(request.headers) @@ -100,17 +123,30 @@ def __init__( connect_args: Dict[str, Any] = {}, logger: Logger = None, ) -> None: - self.logger = logger if logger else Logger('debug') + self.logger = logger if logger else Logger("debug") try: - self.authorization = authorization if authorization else AppSyncIAMAuthorization(host=url, session=session) + self.authorization = ( + authorization + if authorization + else AppSyncIAMAuthorization(host=url, session=session) + ) url = self.authorization.host_to_auth_url() - except botocore.exceptions.NoCredentialsError as e: + except NoCredentialsError as e: self.authorization = None - self.logger.log(0, 'Credentials not found. Do you have default AWS credentials configured?') + self.logger.log( + 0, + "Credentials not found. " + "Do you have default AWS credentials configured?", + ) raise e - except TypeError as e: + except TypeError: self.authorization = None - self.logger.log(0, 'A TypeError was raised. The most likely reason for this is that the AWS region is missing from the credentials.') + self.logger.log( + 0, + "A TypeError was raised. " + "The most likely reason for this is that the AWS " + "region is missing from the credentials.", + ) raise MissingRegionError super().__init__( @@ -160,8 +196,8 @@ async def _send_query( "data": data, "extensions": { "authorization": self.authorization.get_headers(data) - } - } + }, + }, }, separators=(",", ":"), ) diff --git a/tests/fixtures/aws/fake_credentials.py b/tests/fixtures/aws/fake_credentials.py index 738eeabd..d8eac834 100644 --- a/tests/fixtures/aws/fake_credentials.py +++ b/tests/fixtures/aws/fake_credentials.py @@ -2,7 +2,9 @@ class FakeCredentials(object): - def __init__(self, access_key=None, secret_key=None, method=None, token=None, region=None): + def __init__( + self, access_key=None, secret_key=None, method=None, token=None, region=None + ): self.region = region if region else "us-east-1a" self.access_key = access_key if access_key else "fake-access-key" self.secret_key = secret_key if secret_key else "fake-secret-key" @@ -12,7 +14,15 @@ def __init__(self, access_key=None, secret_key=None, method=None, token=None, re @pytest.fixture def fake_credentials_factory(): - def _fake_credentials_factory(access_key=None, secret_key=None, method=None, token=None, region=None): - return FakeCredentials(access_key=access_key, secret_key=secret_key, method=method, token=token, region=region) + def _fake_credentials_factory( + access_key=None, secret_key=None, method=None, token=None, region=None + ): + return FakeCredentials( + access_key=access_key, + secret_key=secret_key, + method=method, + token=token, + region=region, + ) yield _fake_credentials_factory diff --git a/tests/fixtures/aws/fake_request.py b/tests/fixtures/aws/fake_request.py index ffca22db..615bc095 100644 --- a/tests/fixtures/aws/fake_request.py +++ b/tests/fixtures/aws/fake_request.py @@ -7,15 +7,16 @@ class FakeRequest(object): def __init__(self, request_props=None): if not isinstance(request_props, dict): return - self.method = request_props.get('method') - self.url = request_props.get('url') - self.headers = request_props.get('headers') - self.context = request_props.get('context') - self.body = request_props.get('body') + self.method = request_props.get("method") + self.url = request_props.get("url") + self.headers = request_props.get("headers") + self.context = request_props.get("context") + self.body = request_props.get("body") @pytest.fixture def fake_request_factory(): def _fake_request_factory(request_props=None): return FakeRequest(request_props=request_props) + yield _fake_request_factory diff --git a/tests/fixtures/aws/fake_session.py b/tests/fixtures/aws/fake_session.py index 6c217ea4..78e1511a 100644 --- a/tests/fixtures/aws/fake_session.py +++ b/tests/fixtures/aws/fake_session.py @@ -19,6 +19,6 @@ def _resolve_region_name(self, region_name, client_config): @pytest.fixture def fake_session_factory(fake_credentials_factory): def _fake_session_factory(credentials=fake_credentials_factory()): - return FakeSession(credentials=credentials, region_name='fake-region') + return FakeSession(credentials=credentials, region_name="fake-region") yield _fake_session_factory diff --git a/tests/fixtures/aws/fake_signer.py b/tests/fixtures/aws/fake_signer.py index e50824e1..ff096745 100644 --- a/tests/fixtures/aws/fake_signer.py +++ b/tests/fixtures/aws/fake_signer.py @@ -5,8 +5,9 @@ def fake_signer_factory(fake_request_factory): def _fake_signer_factory(request=None): if not request: - request=fake_request_factory() + request = fake_request_factory() return FakeSigner(request=request) + yield _fake_signer_factory @@ -23,4 +24,4 @@ def add_auth(self, request) -> None: def get_headers(self): self.add_auth(self.request) - return self.request.headers \ No newline at end of file + return self.request.headers diff --git a/tests/fixtures/fake_logger.py b/tests/fixtures/fake_logger.py index 23efe333..5e03863b 100644 --- a/tests/fixtures/fake_logger.py +++ b/tests/fixtures/fake_logger.py @@ -11,7 +11,6 @@ def log(self, level, message): @pytest.fixture def fake_logger_factory(): - def _fake_logger_factory(messages=None): return FakeLogger(messages=messages) diff --git a/tests/test_appsyncwebsocket.py b/tests/test_appsyncwebsocket.py index f0381949..2b84e9a4 100644 --- a/tests/test_appsyncwebsocket.py +++ b/tests/test_appsyncwebsocket.py @@ -1,80 +1,133 @@ import botocore.exceptions import pytest -from gql.transport.awsappsyncwebsocket import MissingRegionError -from gql.transport.awsappsyncwebsocket import AppSyncWebsocketsTransport, AppSyncIAMAuthorization, AppSyncOIDCAuthorization, AppSyncApiKeyAuthorization +from gql.transport.awsappsyncwebsocket import ( + AppSyncApiKeyAuthorization, + AppSyncIAMAuthorization, + AppSyncOIDCAuthorization, + AppSyncWebsocketsTransport, + MissingRegionError, +) mock_transport_url = "https://appsyncapp.awsgateway.com.example.org" def test_appsyncwebsocket_init_with_minimal_args(fake_session_factory): - sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, session=fake_session_factory()) + sample_transport = AppSyncWebsocketsTransport( + url=mock_transport_url, session=fake_session_factory() + ) assert isinstance(sample_transport.authorization, AppSyncIAMAuthorization) assert sample_transport.connect_timeout == 10 assert sample_transport.close_timeout == 10 assert sample_transport.ack_timeout == 10 - assert sample_transport.ssl == False + assert sample_transport.ssl is False assert sample_transport.connect_args == {} -def test_appsyncwebsocket_init_with_no_credentials(fake_session_factory, fake_logger_factory): +def test_appsyncwebsocket_init_with_no_credentials( + fake_session_factory, fake_logger_factory +): fake_logger = fake_logger_factory() with pytest.raises(botocore.exceptions.NoCredentialsError): - sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, session=fake_session_factory(credentials=None), logger=fake_logger) + sample_transport = AppSyncWebsocketsTransport( + url=mock_transport_url, + session=fake_session_factory(credentials=None), + logger=fake_logger, + ) assert sample_transport.authorization is None assert fake_logger._messages.length > 0 assert "credentials" in fake_logger._messages[0].lower() + def test_appsyncwebsocket_init_with_oidc_auth(): authorization = AppSyncOIDCAuthorization(host=mock_transport_url, jwt="some-jwt") - sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, authorization=authorization) + sample_transport = AppSyncWebsocketsTransport( + url=mock_transport_url, authorization=authorization + ) assert sample_transport.authorization is authorization def test_appsyncwebsocket_init_with_apikey_auth(): - authorization = AppSyncApiKeyAuthorization(host=mock_transport_url, api_key="some-api-key") - sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, authorization=authorization) + authorization = AppSyncApiKeyAuthorization( + host=mock_transport_url, api_key="some-api-key" + ) + sample_transport = AppSyncWebsocketsTransport( + url=mock_transport_url, authorization=authorization + ) assert sample_transport.authorization is authorization def test_appsyncwebsocket_init_with_iam_auth(): authorization = AppSyncIAMAuthorization(host=mock_transport_url) - sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, authorization=authorization) + sample_transport = AppSyncWebsocketsTransport( + url=mock_transport_url, authorization=authorization + ) assert sample_transport.authorization is authorization -def test_appsyncwebsocket_init_with_iam_auth(fake_credentials_factory): - authorization = AppSyncIAMAuthorization(host=mock_transport_url, credentials=fake_credentials_factory(), region_name="us-east-1") - sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, authorization=authorization) +def test_appsyncwebsocket_init_with_iam_auth_with_creds(fake_credentials_factory): + authorization = AppSyncIAMAuthorization( + host=mock_transport_url, + credentials=fake_credentials_factory(), + region_name="us-east-1", + ) + sample_transport = AppSyncWebsocketsTransport( + url=mock_transport_url, authorization=authorization + ) assert sample_transport.authorization is authorization - -def test_appsyncwebsocket_init_with_iam_auth_and_no_region(fake_credentials_factory, fake_logger_factory): +def test_appsyncwebsocket_init_with_iam_auth_and_no_region( + fake_credentials_factory, fake_logger_factory +): fake_logger = fake_logger_factory() with pytest.raises(MissingRegionError): - authorization = AppSyncIAMAuthorization(host=mock_transport_url, credentials=fake_credentials_factory()) - sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, authorization=authorization, logger=fake_logger) + authorization = AppSyncIAMAuthorization( + host=mock_transport_url, credentials=fake_credentials_factory() + ) + sample_transport = AppSyncWebsocketsTransport( + url=mock_transport_url, authorization=authorization, logger=fake_logger + ) assert sample_transport.authorization is None assert fake_logger._messages.length > 0 assert "credentials" in fake_logger._messages[0].lower() def test_munge_url(fake_signer_factory, fake_request_factory): - test_url = 'https://appsync-api.aws.example.org/some-other-params' + test_url = "https://appsync-api.aws.example.org/some-other-params" - authorization = AppSyncIAMAuthorization(host=test_url, signer=fake_signer_factory(), request_creator=fake_request_factory) - sample_transport = AppSyncWebsocketsTransport(url=test_url, authorization=authorization) + authorization = AppSyncIAMAuthorization( + host=test_url, + signer=fake_signer_factory(), + request_creator=fake_request_factory, + ) + sample_transport = AppSyncWebsocketsTransport( + url=test_url, authorization=authorization + ) expected_url = authorization.host_to_auth_url() assert sample_transport.url == expected_url -def test_munge_url_format(fake_signer_factory, fake_request_factory, fake_credentials_factory, fake_session_factory): - test_url = 'https://appsync-api.aws.example.org/some-other-params' - - authorization = AppSyncIAMAuthorization(host=test_url, signer=fake_signer_factory(), session=fake_session_factory(), request_creator=fake_request_factory, credentials=fake_credentials_factory()) - - expected_url = 'wss://appsync-realtime-api.aws.example.org/some-other-params?header=eyJGYWtlQXV0aG9yaXphdGlvbiI6ImEiLCJGYWtlVGltZSI6InRvZGF5In0=&payload=e30=' +def test_munge_url_format( + fake_signer_factory, + fake_request_factory, + fake_credentials_factory, + fake_session_factory, +): + test_url = "https://appsync-api.aws.example.org/some-other-params" + + authorization = AppSyncIAMAuthorization( + host=test_url, + signer=fake_signer_factory(), + session=fake_session_factory(), + request_creator=fake_request_factory, + credentials=fake_credentials_factory(), + ) + + header_string = "eyJGYWtlQXV0aG9yaXphdGlvbiI6ImEiLCJGYWtlVGltZSI6InRvZGF5In0=" + expected_url = ( + f"wss://appsync-realtime-api.aws.example.org/" + f"some-other-params?header={header_string}&payload=e30=" + ) assert authorization.host_to_auth_url() == expected_url - From 6cb8c83d749b0b19a6beb355f589e266dde58cb0 Mon Sep 17 00:00:00 2001 From: Chad Furman Date: Sun, 3 Oct 2021 20:59:31 -0400 Subject: [PATCH 11/68] fix: added 'aws' as an install target for tests --- tests/conftest.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 945e70ba..d742d50b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,11 +14,7 @@ from gql import Client -all_transport_dependencies = [ - "aiohttp", - "requests", - "websockets", -] +all_transport_dependencies = ["aiohttp", "requests", "websockets", "aws"] def pytest_addoption(parser): @@ -386,6 +382,7 @@ async def run_sync_test_inner(event_loop, server, test_function): return run_sync_test_inner + pytest_plugins = [ "tests.fixtures.fake_logger", "tests.fixtures.aws.fake_credentials", From c8bafa431e59a53d5ec60e7fea4eea6e90474393 Mon Sep 17 00:00:00 2001 From: Chad Furman Date: Sun, 10 Oct 2021 19:23:26 -0400 Subject: [PATCH 12/68] fix: typehint issues addressed --- gql/transport/awsappsyncwebsocket.py | 53 ++++++++++++++-------------- 1 file changed, 27 insertions(+), 26 deletions(-) diff --git a/gql/transport/awsappsyncwebsocket.py b/gql/transport/awsappsyncwebsocket.py index cb0e9186..6bb4903b 100644 --- a/gql/transport/awsappsyncwebsocket.py +++ b/gql/transport/awsappsyncwebsocket.py @@ -3,11 +3,11 @@ from base64 import b64encode from logging import Logger from ssl import SSLContext -from typing import Any, Dict, Optional, Union +from typing import Any, Callable, Dict, Optional, Union import botocore.session from botocore.auth import SigV4Auth -from botocore.awsrequest import create_request_object +from botocore.awsrequest import AWSRequest, create_request_object from botocore.exceptions import NoCredentialsError from botocore.session import get_session from graphql import DocumentNode, print_ast @@ -38,7 +38,7 @@ def host_to_auth_url(self) -> str: ) @abstractmethod - def get_headers(self, data: Optional[str] = None) -> Dict: + def get_headers(self, data: Optional[dict] = None) -> Dict: raise NotImplementedError() @@ -47,7 +47,7 @@ def __init__(self, host: str, api_key: str) -> None: super().__init__(host) self.api_key = api_key - def get_headers(self, data: Optional[str] = None) -> Dict: + def get_headers(self, data: Optional[dict] = None) -> Dict: return {"host": self._host, "x-api-key": self.api_key} @@ -56,7 +56,7 @@ def __init__(self, host: str, jwt: str) -> None: super().__init__(host) self.jwt = jwt - def get_headers(self, data: Optional[str] = None) -> Dict: + def get_headers(self, data: Optional[dict] = None) -> Dict: return {"host": self._host, "Authorization": self.jwt} @@ -95,7 +95,9 @@ def __init__( ) def get_headers( - self, data: Optional[str] = None, request_creator: callable = None + self, + data: Optional[dict] = None, + request_creator: Callable[[dict], AWSRequest] = None, ) -> Dict: request = self._request_creator( { @@ -111,11 +113,13 @@ def get_headers( class AppSyncWebsocketsTransport(WebsocketsTransport): + authorization: Optional[AppSyncAuthorization] + def __init__( self, url: str, - authorization: AppSyncAuthorization = None, - session: botocore.session.Session = None, + authorization: Optional[AppSyncAuthorization] = None, + session: Optional[botocore.session.Session] = None, ssl: Union[SSLContext, bool] = False, connect_timeout: int = 10, close_timeout: int = 10, @@ -132,7 +136,7 @@ def __init__( ) url = self.authorization.host_to_auth_url() except NoCredentialsError as e: - self.authorization = None + del self.authorization self.logger.log( 0, "Credentials not found. " @@ -140,7 +144,7 @@ def __init__( ) raise e except TypeError: - self.authorization = None + del self.authorization self.logger.log( 0, "A TypeError was raised. " @@ -181,27 +185,24 @@ async def _send_query( query_id = self.next_query_id self.next_query_id += 1 - data = {"query": print_ast(document)} + data: Dict = {"query": print_ast(document)} if variable_values: data["variables"] = variable_values if operation_name: data["operationName"] = operation_name - await self._send( - json.dumps( - { - "id": str(query_id), - "type": "start", - "payload": { - "data": data, - "extensions": { - "authorization": self.authorization.get_headers(data) - }, - }, - }, - separators=(",", ":"), - ) - ) + message: Dict = { + "id": str(query_id), + "type": "start", + "payload": {"data": data}, + } + + if self.authorization: + message["payload"]["extensions"] = { + "authorization": self.authorization.get_headers(data) + } + + await self._send(json.dumps(message, separators=(",", ":"),)) return query_id From ee0496dafe45fdcb4aa8a906834d7bdc22f4378f Mon Sep 17 00:00:00 2001 From: Chad Furman Date: Sun, 10 Oct 2021 19:29:03 -0400 Subject: [PATCH 13/68] test: updated iam test without credentials to explicitly trigger credential error --- tests/test_appsyncwebsocket.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_appsyncwebsocket.py b/tests/test_appsyncwebsocket.py index 2b84e9a4..67423b10 100644 --- a/tests/test_appsyncwebsocket.py +++ b/tests/test_appsyncwebsocket.py @@ -57,12 +57,12 @@ def test_appsyncwebsocket_init_with_apikey_auth(): assert sample_transport.authorization is authorization -def test_appsyncwebsocket_init_with_iam_auth(): - authorization = AppSyncIAMAuthorization(host=mock_transport_url) - sample_transport = AppSyncWebsocketsTransport( - url=mock_transport_url, authorization=authorization - ) - assert sample_transport.authorization is authorization +def test_appsyncwebsocket_init_with_iam_auth_without_creds(): + authorization = AppSyncIAMAuthorization(host=mock_transport_url, credentials=None) + with pytest.raises(botocore.exceptions.NoCredentialsError): + sample_transport = AppSyncWebsocketsTransport( + url=mock_transport_url, authorization=authorization + ) def test_appsyncwebsocket_init_with_iam_auth_with_creds(fake_credentials_factory): From b9a61b79d1bafd11d080685ab5db83162c1ad3bb Mon Sep 17 00:00:00 2001 From: Chad Furman Date: Sun, 12 Sep 2021 22:09:40 -0400 Subject: [PATCH 14/68] chore: registered new pytest mark --- tox.ini | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 19231ed7..47dcd9e5 100644 --- a/tox.ini +++ b/tox.ini @@ -4,7 +4,9 @@ envlist = py{36,37,38,39,py3} [pytest] -markers = asyncio +markers = + asyncio: Tests belonging to the asyncio transport + appsyncwebsocket: Tests belonging to the AwsAppSyncWebsocket transport [gh-actions] python = From 45aa7691117e5275236a489950e67fd970c6459e Mon Sep 17 00:00:00 2001 From: Chad Furman Date: Sun, 12 Sep 2021 22:09:51 -0400 Subject: [PATCH 15/68] chore: added botocore dependency --- setup.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index ead75821..64018932 100644 --- a/setup.py +++ b/setup.py @@ -45,8 +45,12 @@ "websockets>=9,<10", ] +install_aws_requires = [ + "botocore>=1.21,<1.22", +] + install_websockets_requires + install_all_requires = ( - install_aiohttp_requires + install_requests_requires + install_websockets_requires + install_aiohttp_requires + install_requests_requires + install_websockets_requires + install_aws_requires ) # Get version from __version__.py file @@ -88,6 +92,7 @@ "aiohttp": install_aiohttp_requires, "requests": install_requests_requires, "websockets": install_websockets_requires, + "aws": install_aws_requires, }, include_package_data=True, zip_safe=False, From b1fe6b0b054bfb124d8b20aff3e5531397aac906 Mon Sep 17 00:00:00 2001 From: Chad Furman Date: Sun, 12 Sep 2021 22:10:22 -0400 Subject: [PATCH 16/68] test: initial test for appsyncwebsocket before I started coyp/pasting lots of stuff and not writing tests in the process :shamecube: --- tests/test_appsyncwebsocket.py | 70 ++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 tests/test_appsyncwebsocket.py diff --git a/tests/test_appsyncwebsocket.py b/tests/test_appsyncwebsocket.py new file mode 100644 index 00000000..66dd5ec8 --- /dev/null +++ b/tests/test_appsyncwebsocket.py @@ -0,0 +1,70 @@ +import pytest + +from gql.transport.awsappsyncwebsocket import AppSyncWebsocketsTransport, AppSyncIAMAuthorization, AppSyncOIDCAuthorization, AppSyncApiKeyAuthorization + +# TODO +# from gql.transport.exceptions import ( +# TransportAlreadyConnected, +# TransportClosed, +# TransportProtocolError, +# TransportQueryError, +# TransportServerError, +# ) + +# TODO: +# from .conftest import TemporaryFile + +# TODO: +# query1_str = """ +# query getContinents { +# continents { +# code +# name +# } +# } +# """ + +# TODO: +# query1_server_answer_data = ( +# '{"continents":[' +# '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' +# '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' +# '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' +# '{"code":"SA","name":"South America"}]}' +# ) + + +# TODO: +# query1_server_answer = f'{{"data":{query1_server_answer_data}}}' + +# Marking all tests in this file with the appsyncwebsocket marker +pytestmark = pytest.mark.appsyncwebsocket +mock_transport_url = "https://appsyncapp.awsgateway.com.example.org" + +@pytest.mark.appsyncwebsocket +def test_appsyncwebsocket_init_with_minimal_args(): + sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url) + assert isinstance(sample_transport.authorization, AppSyncIAMAuthorization) + assert sample_transport.connect_timeout == 10 + assert sample_transport.close_timeout == 10 + assert sample_transport.ack_timeout == 10 + assert sample_transport.ssl == False + assert sample_transport.connect_args == {} + +def test_appsyncwebsocket_init_with_oidc_auth(): + authorization = AppSyncOIDCAuthorization() + sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, authorization=authorization) + assert sample_transport.authorization is authorization + + +def test_appsyncwebsocket_init_with_apikey_auth(): + authorization = AppSyncApiKeyAuthorization() + sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, authorization=authorization) + assert sample_transport.authorization is authorization + + +def test_appsyncwebsocket_init_with_iam_auth(): + authorization = AppSyncIAMAuthorization() + sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, authorization=authorization) + assert sample_transport.authorization is authorization + From 3f28592d25166aaa83a2d49199fe7c0ea53360c7 Mon Sep 17 00:00:00 2001 From: Chad Furman Date: Sun, 12 Sep 2021 22:11:28 -0400 Subject: [PATCH 17/68] feat: wip aws appsync websocket code inspired / copied extensively from work by @joseph-wortmann here: https://github.com/graphql-python/gql/issues/125#issuecomment-907827947 --- gql/transport/awsappsyncwebsocket.py | 145 +++++++++++++++++++++++++++ 1 file changed, 145 insertions(+) create mode 100644 gql/transport/awsappsyncwebsocket.py diff --git a/gql/transport/awsappsyncwebsocket.py b/gql/transport/awsappsyncwebsocket.py new file mode 100644 index 00000000..81616b35 --- /dev/null +++ b/gql/transport/awsappsyncwebsocket.py @@ -0,0 +1,145 @@ +from asyncio import wait_for, ensure_future + +from graphql import DocumentNode, print_ast + +from transport.exceptions import TransportProtocolError +from transport.websockets import WebsocketsTransport +from ssl import SSLContext +from typing import Any, Dict, Union, Optional +from abc import ABC, abstractmethod +from base64 import b64encode +from botocore.awsrequest import AWSRequest, create_request_object +from botocore.session import Session +from botocore.auth import SigV4Auth +import json + + +class AppSyncAuthorization(ABC): + def on_connect(self) -> str: + return b64encode( + json.dumps(self.get_headers(), separators=(",", ":")).encode() + ).decode() + + @abstractmethod + def get_headers(self, data: Optional[str] = None) -> Dict: + raise NotImplementedError() + + +class AppSyncApiKeyAuthorization(AppSyncAuthorization): + def __init__(self, host: str, api_key: str) -> None: + self.host = host + self.api_key = api_key + + def get_headers(self, data: Optional[str] = None) -> Dict: + return {"host": self.host, "x-api-key": self.api_key} + + +class AppSyncOIDCAuthorization(AppSyncAuthorization): + def __init__(self, host: str, jwt: str) -> None: + self.host = host + self.jwt = jwt + + def get_headers(self, data: Optional[str] = None) -> Dict: + return {"host": self.host, "Authorization": self.jwt} + + +class AppSyncCognitoUserPoolAuthorization(AppSyncOIDCAuthorization): + """Alias for AppSyncOIDCAuthorization""" + pass + + +class AppSyncIAMAuthorization(AppSyncAuthorization): + def __init__(self, host: str, region_name=None, profile=None) -> None: + self._host = host + self._session = Session(profile=profile) + self._credentials = self._session.get_credentials() + self._region_name = self._session._resolve_region_name(region_name, self._session.get_default_client_config()) + self._service_name = "appsync" + self._signer = SigV4Auth(self._credentials, self._service_name, self._region_name) + + def get_headers(self, data: Optional[str] = None) -> Dict: + request = create_request_object({ + 'method': 'GET', + 'url': self._host, + 'body': data, + }) + self._signer.add_auth(request) + return request.headers + + +class AppSyncWebsocketsTransport(WebsocketsTransport): + def __init__( + self, + url: str, + authorization: AppSyncAuthorization = None, + ssl: Union[SSLContext, bool] = False, + connect_timeout: int = 10, + close_timeout: int = 10, + ack_timeout: int = 10, + connect_args: Dict[str, Any] = {}, + ) -> None: + if authorization: + self.authorization = authorization + else: + self.authorization = AppSyncIAMAuthorization() + super().__init__( + url, + ssl=ssl, + connect_timeout=connect_timeout, + close_timeout=close_timeout, + ack_timeout=ack_timeout, + connect_args=connect_args, + ) + + async def _wait_start_ack(self) -> None: + """Wait for the start_ack message. Keep alive messages are ignored""" + + while True: + answer_type = str(json.loads(await self._receive()).get("type")) + + if answer_type == "start_ack": + return + + if answer_type != "ka": + raise TransportProtocolError( + "AppSync server did not return a start ack" + ) + + async def _send_query( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + ) -> int: + query_id = self.next_query_id + self.next_query_id += 1 + + data = {"query": print_ast(document)} + if variable_values: + data["variables"] = variable_values + if operation_name: + data["operationName"] = operation_name + + data["extensions"] = { + "authorization": self.authorization.get_headers(data) + }, + data = json.dumps(data, separators=(",", ":")) + + await self._send( + json.dumps( + { + "id": str(query_id), + "type": "start", + "payload": data, + }, + separators=(",", ":"), + ) + ) + + # Wait for the connection_ack message or raise a TimeoutError + await wait_for(self._wait_start_ack(), self.ack_timeout) + + # Create a task to listen to the incoming websocket messages + self.receive_data_task = ensure_future(self._receive_data_loop()) + + return query_id From d37c769abdfb6956e83033fa89dc6cdf6490b720 Mon Sep 17 00:00:00 2001 From: Chad Furman Date: Sun, 19 Sep 2021 19:08:29 -0400 Subject: [PATCH 18/68] fix: added url munging and cleaned up _send_query to avoid having to overload connect() and subscribe() --- gql/transport/awsappsyncwebsocket.py | 32 +++++++++++++++++----------- 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/gql/transport/awsappsyncwebsocket.py b/gql/transport/awsappsyncwebsocket.py index 81616b35..e5942472 100644 --- a/gql/transport/awsappsyncwebsocket.py +++ b/gql/transport/awsappsyncwebsocket.py @@ -82,6 +82,7 @@ def __init__( self.authorization = authorization else: self.authorization = AppSyncIAMAuthorization() + url = self._munge_url_for_appsync_auth(url) super().__init__( url, ssl=ssl, @@ -91,6 +92,19 @@ def __init__( connect_args=connect_args, ) + def _munge_url_for_appsync_auth(self, url: str) -> str: + """Munge URL For Appsync Auth + + :param url: The original URL where we replace 'https' and 'appsync-api' and append auth headers + :return: a new url used to establish websocket connections to the appsync-realtime-api + """ + url_after_replacements=url.replace("https", "wss").replace("appsync-api", "appsync-realtime-api") + headers_from_auth=self.authorization.get_headers() + return '{url}?header={headers}&payload=e30='.format( + url=url_after_replacements, + headers=headers_from_auth + ) + async def _wait_start_ack(self) -> None: """Wait for the start_ack message. Keep alive messages are ignored""" @@ -120,26 +134,20 @@ async def _send_query( if operation_name: data["operationName"] = operation_name - data["extensions"] = { - "authorization": self.authorization.get_headers(data) - }, - data = json.dumps(data, separators=(",", ":")) - await self._send( json.dumps( { "id": str(query_id), "type": "start", - "payload": data, + "payload": { + "data": data, + "extensions": { + "authorization": self.authorization.get_headers(data) + } + } }, separators=(",", ":"), ) ) - # Wait for the connection_ack message or raise a TimeoutError - await wait_for(self._wait_start_ack(), self.ack_timeout) - - # Create a task to listen to the incoming websocket messages - self.receive_data_task = ensure_future(self._receive_data_loop()) - return query_id From 9bd8d93cfb5b91a450c0997f8deedaa6944fe9c6 Mon Sep 17 00:00:00 2001 From: Chad Furman Date: Sun, 19 Sep 2021 19:09:13 -0400 Subject: [PATCH 19/68] test: wip test coverage. Added coverage for munge_url during init and started adding some dependency injection as well as some fakes. --- gql/transport/awsappsyncwebsocket.py | 9 +++++---- tests/fixtures/aws/fake_signer.py | 20 ++++++++++++++++++++ tests/test_appsyncwebsocket.py | 11 +++++++++++ 3 files changed, 36 insertions(+), 4 deletions(-) create mode 100644 tests/fixtures/aws/fake_signer.py diff --git a/gql/transport/awsappsyncwebsocket.py b/gql/transport/awsappsyncwebsocket.py index e5942472..2ffb2488 100644 --- a/gql/transport/awsappsyncwebsocket.py +++ b/gql/transport/awsappsyncwebsocket.py @@ -49,16 +49,17 @@ class AppSyncCognitoUserPoolAuthorization(AppSyncOIDCAuthorization): class AppSyncIAMAuthorization(AppSyncAuthorization): - def __init__(self, host: str, region_name=None, profile=None) -> None: + def __init__(self, host: str, region_name=None, profile=None, signer=None, request_creator=None) -> None: self._host = host self._session = Session(profile=profile) self._credentials = self._session.get_credentials() self._region_name = self._session._resolve_region_name(region_name, self._session.get_default_client_config()) self._service_name = "appsync" - self._signer = SigV4Auth(self._credentials, self._service_name, self._region_name) + self._signer = signer if signer else SigV4Auth(self._credentials, self._service_name, self._region_name) + self._request_creator = request_creator if request_creator else create_request_object - def get_headers(self, data: Optional[str] = None) -> Dict: - request = create_request_object({ + def get_headers(self, data: Optional[str] = None, request_creator: callable = None) -> Dict: + request = self._request_creator({ 'method': 'GET', 'url': self._host, 'body': data, diff --git a/tests/fixtures/aws/fake_signer.py b/tests/fixtures/aws/fake_signer.py new file mode 100644 index 00000000..d5ffd5d8 --- /dev/null +++ b/tests/fixtures/aws/fake_signer.py @@ -0,0 +1,20 @@ +def fake_request_creator(): + return FakeRequest() + +class FakeRequest(object): + headers = None + +class FakeSigner(object): + def __init__(self, request=None) -> None: + self.request = request if request else FakeRequest() + + def add_auth(self, request) -> None: + """ + A fake for getting a request object that + :return: + """ + request.headers = {"FakeAuthorization": "a", "FakeTime": "today"} + + def get_headers(self): + self.add_auth(self.request) + return self.request.headers \ No newline at end of file diff --git a/tests/test_appsyncwebsocket.py b/tests/test_appsyncwebsocket.py index 66dd5ec8..4813f509 100644 --- a/tests/test_appsyncwebsocket.py +++ b/tests/test_appsyncwebsocket.py @@ -68,3 +68,14 @@ def test_appsyncwebsocket_init_with_iam_auth(): sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, authorization=authorization) assert sample_transport.authorization is authorization + +@pytest.fixture('aws.fake_request_creator') +@pytest.fixture('aws.FakeSigner') +def test_munge_url(fake_signer,fake_request_creator): + authorization = AppSyncIAMAuthorization(signer=fake_signer, request_creator=fake_request_creator) + test_url = 'https://appsync-api.aws.example.org/some-other-params' + expected_url = 'wss://appsync-realtime-api.aws.example.org/some-other-params?header={headers}&payload=e30='.format(headers=authorization.on_connect()) + + sample_transport = AppSyncWebsocketsTransport(url=test_url, authorization=authorization) + + assert sample_transport.url == expected_url From 5ea673b012341b0a1a818e6a61403e8313aeff29 Mon Sep 17 00:00:00 2001 From: Chad Furman Date: Sun, 26 Sep 2021 20:50:00 -0400 Subject: [PATCH 20/68] test: started fixing unit tests. Looks like credentials are not being passed to add_auth. Gotta find where they should be coming from... --- gql/transport/awsappsyncwebsocket.py | 56 ++++++++++++++-------------- tests/fixtures/aws/__init__.py | 0 tests/fixtures/aws/fake_signer.py | 19 +++++++++- tests/test_appsyncwebsocket.py | 6 +-- 4 files changed, 47 insertions(+), 34 deletions(-) create mode 100644 tests/fixtures/aws/__init__.py diff --git a/gql/transport/awsappsyncwebsocket.py b/gql/transport/awsappsyncwebsocket.py index 2ffb2488..3ef14dff 100644 --- a/gql/transport/awsappsyncwebsocket.py +++ b/gql/transport/awsappsyncwebsocket.py @@ -2,8 +2,8 @@ from graphql import DocumentNode, print_ast -from transport.exceptions import TransportProtocolError -from transport.websockets import WebsocketsTransport +from .exceptions import TransportProtocolError +from .websockets import WebsocketsTransport from ssl import SSLContext from typing import Any, Dict, Union, Optional from abc import ABC, abstractmethod @@ -15,10 +15,23 @@ class AppSyncAuthorization(ABC): - def on_connect(self) -> str: - return b64encode( - json.dumps(self.get_headers(), separators=(",", ":")).encode() + def __init__(self, host: str): + self._host = host + + def host_to_auth_url(self) -> str: + """Munge Host For Appsync Auth + + :return: a url used to establish websocket connections to the appsync-realtime-api + """ + url_after_replacements=self._host.replace("https", "wss").replace("appsync-api", "appsync-realtime-api") + headers_from_auth=self.get_headers() + encoded_headers = b64encode( + json.dumps(headers_from_auth, separators=(",", ":")).encode() ).decode() + return '{url}?header={headers}&payload=e30='.format( + url=url_after_replacements, + headers=encoded_headers + ) @abstractmethod def get_headers(self, data: Optional[str] = None) -> Dict: @@ -27,20 +40,21 @@ def get_headers(self, data: Optional[str] = None) -> Dict: class AppSyncApiKeyAuthorization(AppSyncAuthorization): def __init__(self, host: str, api_key: str) -> None: - self.host = host + super().__init__(host) self.api_key = api_key def get_headers(self, data: Optional[str] = None) -> Dict: - return {"host": self.host, "x-api-key": self.api_key} + return {"host": self._host, "x-api-key": self.api_key} class AppSyncOIDCAuthorization(AppSyncAuthorization): def __init__(self, host: str, jwt: str) -> None: - self.host = host + super().__init__(host) self.jwt = jwt def get_headers(self, data: Optional[str] = None) -> Dict: - return {"host": self.host, "Authorization": self.jwt} + return {"host": self._host, "Authorization": self.jwt} + class AppSyncCognitoUserPoolAuthorization(AppSyncOIDCAuthorization): @@ -50,7 +64,7 @@ class AppSyncCognitoUserPoolAuthorization(AppSyncOIDCAuthorization): class AppSyncIAMAuthorization(AppSyncAuthorization): def __init__(self, host: str, region_name=None, profile=None, signer=None, request_creator=None) -> None: - self._host = host + super().__init__(host) self._session = Session(profile=profile) self._credentials = self._session.get_credentials() self._region_name = self._session._resolve_region_name(region_name, self._session.get_default_client_config()) @@ -62,6 +76,8 @@ def get_headers(self, data: Optional[str] = None, request_creator: callable = No request = self._request_creator({ 'method': 'GET', 'url': self._host, + 'headers': {}, + 'context': {}, 'body': data, }) self._signer.add_auth(request) @@ -79,11 +95,8 @@ def __init__( ack_timeout: int = 10, connect_args: Dict[str, Any] = {}, ) -> None: - if authorization: - self.authorization = authorization - else: - self.authorization = AppSyncIAMAuthorization() - url = self._munge_url_for_appsync_auth(url) + self.authorization = authorization if authorization else AppSyncIAMAuthorization(host=url) + url = self.authorization.host_to_auth_url() super().__init__( url, ssl=ssl, @@ -93,19 +106,6 @@ def __init__( connect_args=connect_args, ) - def _munge_url_for_appsync_auth(self, url: str) -> str: - """Munge URL For Appsync Auth - - :param url: The original URL where we replace 'https' and 'appsync-api' and append auth headers - :return: a new url used to establish websocket connections to the appsync-realtime-api - """ - url_after_replacements=url.replace("https", "wss").replace("appsync-api", "appsync-realtime-api") - headers_from_auth=self.authorization.get_headers() - return '{url}?header={headers}&payload=e30='.format( - url=url_after_replacements, - headers=headers_from_auth - ) - async def _wait_start_ack(self) -> None: """Wait for the start_ack message. Keep alive messages are ignored""" diff --git a/tests/fixtures/aws/__init__.py b/tests/fixtures/aws/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/fixtures/aws/fake_signer.py b/tests/fixtures/aws/fake_signer.py index d5ffd5d8..e1fe61dc 100644 --- a/tests/fixtures/aws/fake_signer.py +++ b/tests/fixtures/aws/fake_signer.py @@ -1,9 +1,24 @@ -def fake_request_creator(): - return FakeRequest() +import pytest + + +@pytest.fixture +def fake_request_factory(): + def _fake_request_factory(): + return FakeRequest() + yield _fake_request_factory + + +@pytest.fixture +def fake_signer_factory(fake_request_factory): + def _fake_signer_factory(request=None): + return FakeSigner(request=request) + yield _fake_signer_factory + class FakeRequest(object): headers = None + class FakeSigner(object): def __init__(self, request=None) -> None: self.request = request if request else FakeRequest() diff --git a/tests/test_appsyncwebsocket.py b/tests/test_appsyncwebsocket.py index 4813f509..a38f6e8e 100644 --- a/tests/test_appsyncwebsocket.py +++ b/tests/test_appsyncwebsocket.py @@ -69,10 +69,8 @@ def test_appsyncwebsocket_init_with_iam_auth(): assert sample_transport.authorization is authorization -@pytest.fixture('aws.fake_request_creator') -@pytest.fixture('aws.FakeSigner') -def test_munge_url(fake_signer,fake_request_creator): - authorization = AppSyncIAMAuthorization(signer=fake_signer, request_creator=fake_request_creator) +def test_munge_url(fake_signer_factory, fake_request_factory): + authorization = AppSyncIAMAuthorization(signer=fake_signer_factory(), request_creator=fake_request_factory()) test_url = 'https://appsync-api.aws.example.org/some-other-params' expected_url = 'wss://appsync-realtime-api.aws.example.org/some-other-params?header={headers}&payload=e30='.format(headers=authorization.on_connect()) From 7ae6786d20af2838f55e2a1c1f943b4557ec0d9e Mon Sep 17 00:00:00 2001 From: Chad Furman Date: Mon, 27 Sep 2021 19:45:44 -0400 Subject: [PATCH 21/68] test: added more dependency injection and fakes for aws. New error, now, about request needing some additional structure. That's next. --- gql/transport/awsappsyncwebsocket.py | 12 +++++++----- tests/fixtures/aws/fake_credentials.py | 12 ++++++++++++ tests/fixtures/aws/fake_request.py | 12 ++++++++++++ tests/fixtures/aws/fake_session.py | 21 +++++++++++++++++++++ tests/fixtures/aws/fake_signer.py | 15 +++------------ tests/test_appsyncwebsocket.py | 4 ++-- 6 files changed, 57 insertions(+), 19 deletions(-) create mode 100644 tests/fixtures/aws/fake_credentials.py create mode 100644 tests/fixtures/aws/fake_request.py create mode 100644 tests/fixtures/aws/fake_session.py diff --git a/gql/transport/awsappsyncwebsocket.py b/gql/transport/awsappsyncwebsocket.py index 3ef14dff..cd080e69 100644 --- a/gql/transport/awsappsyncwebsocket.py +++ b/gql/transport/awsappsyncwebsocket.py @@ -1,5 +1,6 @@ from asyncio import wait_for, ensure_future +import botocore.session from graphql import DocumentNode, print_ast from .exceptions import TransportProtocolError @@ -9,7 +10,7 @@ from abc import ABC, abstractmethod from base64 import b64encode from botocore.awsrequest import AWSRequest, create_request_object -from botocore.session import Session +from botocore.session import get_session from botocore.auth import SigV4Auth import json @@ -63,10 +64,10 @@ class AppSyncCognitoUserPoolAuthorization(AppSyncOIDCAuthorization): class AppSyncIAMAuthorization(AppSyncAuthorization): - def __init__(self, host: str, region_name=None, profile=None, signer=None, request_creator=None) -> None: + def __init__(self, host: str, region_name=None, signer=None, request_creator=None, credentials=None, session=None) -> None: super().__init__(host) - self._session = Session(profile=profile) - self._credentials = self._session.get_credentials() + self._session = session if session else get_session() + self._credentials = credentials if credentials else self._session.get_credentials() self._region_name = self._session._resolve_region_name(region_name, self._session.get_default_client_config()) self._service_name = "appsync" self._signer = signer if signer else SigV4Auth(self._credentials, self._service_name, self._region_name) @@ -89,13 +90,14 @@ def __init__( self, url: str, authorization: AppSyncAuthorization = None, + session: botocore.session.Session = None, ssl: Union[SSLContext, bool] = False, connect_timeout: int = 10, close_timeout: int = 10, ack_timeout: int = 10, connect_args: Dict[str, Any] = {}, ) -> None: - self.authorization = authorization if authorization else AppSyncIAMAuthorization(host=url) + self.authorization = authorization if authorization else AppSyncIAMAuthorization(host=url, session=session) url = self.authorization.host_to_auth_url() super().__init__( url, diff --git a/tests/fixtures/aws/fake_credentials.py b/tests/fixtures/aws/fake_credentials.py new file mode 100644 index 00000000..102f5f49 --- /dev/null +++ b/tests/fixtures/aws/fake_credentials.py @@ -0,0 +1,12 @@ +import pytest + +@pytest.fixture +def fake_credentials_factory(): + def _fake_credentials_factory(access_key=None, secret_key=None, method=None, token=None): + return { + "access_key": access_key if access_key else "fake-access-key", + "secret_key": secret_key if secret_key else "fake-secret-key", + "method": method if method else "shared-credentials-file", + "token": token if token else "fake-token", + } + yield _fake_credentials_factory \ No newline at end of file diff --git a/tests/fixtures/aws/fake_request.py b/tests/fixtures/aws/fake_request.py new file mode 100644 index 00000000..bb1de4ed --- /dev/null +++ b/tests/fixtures/aws/fake_request.py @@ -0,0 +1,12 @@ +import pytest + + +class FakeRequest(object): + headers = None + + +@pytest.fixture +def fake_request_factory(): + def _fake_request_factory(): + return FakeRequest() + yield _fake_request_factory diff --git a/tests/fixtures/aws/fake_session.py b/tests/fixtures/aws/fake_session.py new file mode 100644 index 00000000..11f4afa8 --- /dev/null +++ b/tests/fixtures/aws/fake_session.py @@ -0,0 +1,21 @@ +import pytest + + +class FakeSession(object): + def __init__(self, credentials, region_name): + self._credentials = credentials + self._region_name = region_name + + def get_credentials(self): + return self._credentials + + def _resolve_region_name(self): + return self._region_name + + +@pytest.fixture +def fake_session_factory(fake_credentials_factory): + def _fake_session_factory(): + return FakeSession(credentials=fake_credentials_factory, region='fake-region') + + yield _fake_session_factory diff --git a/tests/fixtures/aws/fake_signer.py b/tests/fixtures/aws/fake_signer.py index e1fe61dc..e50824e1 100644 --- a/tests/fixtures/aws/fake_signer.py +++ b/tests/fixtures/aws/fake_signer.py @@ -1,27 +1,18 @@ import pytest -@pytest.fixture -def fake_request_factory(): - def _fake_request_factory(): - return FakeRequest() - yield _fake_request_factory - - @pytest.fixture def fake_signer_factory(fake_request_factory): def _fake_signer_factory(request=None): + if not request: + request=fake_request_factory() return FakeSigner(request=request) yield _fake_signer_factory -class FakeRequest(object): - headers = None - - class FakeSigner(object): def __init__(self, request=None) -> None: - self.request = request if request else FakeRequest() + self.request = request def add_auth(self, request) -> None: """ diff --git a/tests/test_appsyncwebsocket.py b/tests/test_appsyncwebsocket.py index a38f6e8e..1793156e 100644 --- a/tests/test_appsyncwebsocket.py +++ b/tests/test_appsyncwebsocket.py @@ -42,8 +42,8 @@ mock_transport_url = "https://appsyncapp.awsgateway.com.example.org" @pytest.mark.appsyncwebsocket -def test_appsyncwebsocket_init_with_minimal_args(): - sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url) +def test_appsyncwebsocket_init_with_minimal_args(fake_session_factory): + sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, session=fake_session_factory()) assert isinstance(sample_transport.authorization, AppSyncIAMAuthorization) assert sample_transport.connect_timeout == 10 assert sample_transport.close_timeout == 10 From 3a025d8fd82de3816e1dba4bd559939e87dcc984 Mon Sep 17 00:00:00 2001 From: Chad Furman Date: Tue, 28 Sep 2021 20:19:29 -0400 Subject: [PATCH 22/68] tests: cleaned up awsappsyncwebsocket and related tests so that they all pass. Added additional validation and error handling for various states of credentials. Added additional fakes as needed. --- gql/transport/awsappsyncwebsocket.py | 24 ++++++- tests/conftest.py | 8 +++ tests/fixtures/__init__.py | 0 tests/fixtures/aws/fake_credentials.py | 22 ++++--- tests/fixtures/aws/fake_request.py | 13 +++- tests/fixtures/aws/fake_session.py | 11 ++-- tests/fixtures/fake_logger.py | 18 ++++++ tests/test_appsyncwebsocket.py | 87 +++++++++++++------------- 8 files changed, 123 insertions(+), 60 deletions(-) create mode 100644 tests/fixtures/__init__.py create mode 100644 tests/fixtures/fake_logger.py diff --git a/gql/transport/awsappsyncwebsocket.py b/gql/transport/awsappsyncwebsocket.py index cd080e69..c95ed16e 100644 --- a/gql/transport/awsappsyncwebsocket.py +++ b/gql/transport/awsappsyncwebsocket.py @@ -1,4 +1,5 @@ from asyncio import wait_for, ensure_future +from logging import Logger import botocore.session from graphql import DocumentNode, print_ast @@ -12,6 +13,7 @@ from botocore.awsrequest import AWSRequest, create_request_object from botocore.session import get_session from botocore.auth import SigV4Auth +from botocore.exceptions import NoCredentialsError import json @@ -82,7 +84,7 @@ def get_headers(self, data: Optional[str] = None, request_creator: callable = No 'body': data, }) self._signer.add_auth(request) - return request.headers + return dict(request.headers) class AppSyncWebsocketsTransport(WebsocketsTransport): @@ -96,9 +98,21 @@ def __init__( close_timeout: int = 10, ack_timeout: int = 10, connect_args: Dict[str, Any] = {}, + logger: Logger = None, ) -> None: - self.authorization = authorization if authorization else AppSyncIAMAuthorization(host=url, session=session) - url = self.authorization.host_to_auth_url() + self.logger = logger if logger else Logger('debug') + try: + self.authorization = authorization if authorization else AppSyncIAMAuthorization(host=url, session=session) + url = self.authorization.host_to_auth_url() + except botocore.exceptions.NoCredentialsError as e: + self.authorization = None + self.logger.log(0, 'Credentials not found. Do you have default AWS credentials configured?') + raise e + except TypeError as e: + self.authorization = None + self.logger.log(0, 'A TypeError was raised. The most likely reason for this is that the AWS region is missing from the credentials.') + raise MissingRegionError + super().__init__( url, ssl=ssl, @@ -154,3 +168,7 @@ async def _send_query( ) return query_id + + +class MissingRegionError(Exception): + pass diff --git a/tests/conftest.py b/tests/conftest.py index df69c121..945e70ba 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -385,3 +385,11 @@ async def run_sync_test_inner(event_loop, server, test_function): await server.close() return run_sync_test_inner + +pytest_plugins = [ + "tests.fixtures.fake_logger", + "tests.fixtures.aws.fake_credentials", + "tests.fixtures.aws.fake_request", + "tests.fixtures.aws.fake_session", + "tests.fixtures.aws.fake_signer", +] diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/fixtures/aws/fake_credentials.py b/tests/fixtures/aws/fake_credentials.py index 102f5f49..738eeabd 100644 --- a/tests/fixtures/aws/fake_credentials.py +++ b/tests/fixtures/aws/fake_credentials.py @@ -1,12 +1,18 @@ import pytest + +class FakeCredentials(object): + def __init__(self, access_key=None, secret_key=None, method=None, token=None, region=None): + self.region = region if region else "us-east-1a" + self.access_key = access_key if access_key else "fake-access-key" + self.secret_key = secret_key if secret_key else "fake-secret-key" + self.method = method if method else "shared-credentials-file" + self.token = token if token else "fake-token" + + @pytest.fixture def fake_credentials_factory(): - def _fake_credentials_factory(access_key=None, secret_key=None, method=None, token=None): - return { - "access_key": access_key if access_key else "fake-access-key", - "secret_key": secret_key if secret_key else "fake-secret-key", - "method": method if method else "shared-credentials-file", - "token": token if token else "fake-token", - } - yield _fake_credentials_factory \ No newline at end of file + def _fake_credentials_factory(access_key=None, secret_key=None, method=None, token=None, region=None): + return FakeCredentials(access_key=access_key, secret_key=secret_key, method=method, token=token, region=region) + + yield _fake_credentials_factory diff --git a/tests/fixtures/aws/fake_request.py b/tests/fixtures/aws/fake_request.py index bb1de4ed..ffca22db 100644 --- a/tests/fixtures/aws/fake_request.py +++ b/tests/fixtures/aws/fake_request.py @@ -4,9 +4,18 @@ class FakeRequest(object): headers = None + def __init__(self, request_props=None): + if not isinstance(request_props, dict): + return + self.method = request_props.get('method') + self.url = request_props.get('url') + self.headers = request_props.get('headers') + self.context = request_props.get('context') + self.body = request_props.get('body') + @pytest.fixture def fake_request_factory(): - def _fake_request_factory(): - return FakeRequest() + def _fake_request_factory(request_props=None): + return FakeRequest(request_props=request_props) yield _fake_request_factory diff --git a/tests/fixtures/aws/fake_session.py b/tests/fixtures/aws/fake_session.py index 11f4afa8..6c217ea4 100644 --- a/tests/fixtures/aws/fake_session.py +++ b/tests/fixtures/aws/fake_session.py @@ -6,16 +6,19 @@ def __init__(self, credentials, region_name): self._credentials = credentials self._region_name = region_name + def get_default_client_config(self): + return + def get_credentials(self): return self._credentials - def _resolve_region_name(self): - return self._region_name + def _resolve_region_name(self, region_name, client_config): + return region_name if region_name else self._region_name @pytest.fixture def fake_session_factory(fake_credentials_factory): - def _fake_session_factory(): - return FakeSession(credentials=fake_credentials_factory, region='fake-region') + def _fake_session_factory(credentials=fake_credentials_factory()): + return FakeSession(credentials=credentials, region_name='fake-region') yield _fake_session_factory diff --git a/tests/fixtures/fake_logger.py b/tests/fixtures/fake_logger.py new file mode 100644 index 00000000..23efe333 --- /dev/null +++ b/tests/fixtures/fake_logger.py @@ -0,0 +1,18 @@ +import pytest + + +class FakeLogger(object): + def __init__(self, messages=None): + self._messages = messages if messages else [] + + def log(self, level, message): + self._messages.append("LEVEL {}: {}".format(level, message)) + + +@pytest.fixture +def fake_logger_factory(): + + def _fake_logger_factory(messages=None): + return FakeLogger(messages=messages) + + yield _fake_logger_factory diff --git a/tests/test_appsyncwebsocket.py b/tests/test_appsyncwebsocket.py index 1793156e..f0381949 100644 --- a/tests/test_appsyncwebsocket.py +++ b/tests/test_appsyncwebsocket.py @@ -1,47 +1,12 @@ +import botocore.exceptions import pytest +from gql.transport.awsappsyncwebsocket import MissingRegionError from gql.transport.awsappsyncwebsocket import AppSyncWebsocketsTransport, AppSyncIAMAuthorization, AppSyncOIDCAuthorization, AppSyncApiKeyAuthorization -# TODO -# from gql.transport.exceptions import ( -# TransportAlreadyConnected, -# TransportClosed, -# TransportProtocolError, -# TransportQueryError, -# TransportServerError, -# ) - -# TODO: -# from .conftest import TemporaryFile - -# TODO: -# query1_str = """ -# query getContinents { -# continents { -# code -# name -# } -# } -# """ - -# TODO: -# query1_server_answer_data = ( -# '{"continents":[' -# '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' -# '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' -# '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' -# '{"code":"SA","name":"South America"}]}' -# ) - - -# TODO: -# query1_server_answer = f'{{"data":{query1_server_answer_data}}}' - -# Marking all tests in this file with the appsyncwebsocket marker -pytestmark = pytest.mark.appsyncwebsocket mock_transport_url = "https://appsyncapp.awsgateway.com.example.org" -@pytest.mark.appsyncwebsocket + def test_appsyncwebsocket_init_with_minimal_args(fake_session_factory): sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, session=fake_session_factory()) assert isinstance(sample_transport.authorization, AppSyncIAMAuthorization) @@ -51,29 +16,65 @@ def test_appsyncwebsocket_init_with_minimal_args(fake_session_factory): assert sample_transport.ssl == False assert sample_transport.connect_args == {} + +def test_appsyncwebsocket_init_with_no_credentials(fake_session_factory, fake_logger_factory): + fake_logger = fake_logger_factory() + with pytest.raises(botocore.exceptions.NoCredentialsError): + sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, session=fake_session_factory(credentials=None), logger=fake_logger) + assert sample_transport.authorization is None + assert fake_logger._messages.length > 0 + assert "credentials" in fake_logger._messages[0].lower() + def test_appsyncwebsocket_init_with_oidc_auth(): - authorization = AppSyncOIDCAuthorization() + authorization = AppSyncOIDCAuthorization(host=mock_transport_url, jwt="some-jwt") sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, authorization=authorization) assert sample_transport.authorization is authorization def test_appsyncwebsocket_init_with_apikey_auth(): - authorization = AppSyncApiKeyAuthorization() + authorization = AppSyncApiKeyAuthorization(host=mock_transport_url, api_key="some-api-key") sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, authorization=authorization) assert sample_transport.authorization is authorization def test_appsyncwebsocket_init_with_iam_auth(): - authorization = AppSyncIAMAuthorization() + authorization = AppSyncIAMAuthorization(host=mock_transport_url) + sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, authorization=authorization) + assert sample_transport.authorization is authorization + + +def test_appsyncwebsocket_init_with_iam_auth(fake_credentials_factory): + authorization = AppSyncIAMAuthorization(host=mock_transport_url, credentials=fake_credentials_factory(), region_name="us-east-1") sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, authorization=authorization) assert sample_transport.authorization is authorization + +def test_appsyncwebsocket_init_with_iam_auth_and_no_region(fake_credentials_factory, fake_logger_factory): + fake_logger = fake_logger_factory() + with pytest.raises(MissingRegionError): + authorization = AppSyncIAMAuthorization(host=mock_transport_url, credentials=fake_credentials_factory()) + sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, authorization=authorization, logger=fake_logger) + assert sample_transport.authorization is None + assert fake_logger._messages.length > 0 + assert "credentials" in fake_logger._messages[0].lower() + + def test_munge_url(fake_signer_factory, fake_request_factory): - authorization = AppSyncIAMAuthorization(signer=fake_signer_factory(), request_creator=fake_request_factory()) test_url = 'https://appsync-api.aws.example.org/some-other-params' - expected_url = 'wss://appsync-realtime-api.aws.example.org/some-other-params?header={headers}&payload=e30='.format(headers=authorization.on_connect()) + authorization = AppSyncIAMAuthorization(host=test_url, signer=fake_signer_factory(), request_creator=fake_request_factory) sample_transport = AppSyncWebsocketsTransport(url=test_url, authorization=authorization) + expected_url = authorization.host_to_auth_url() assert sample_transport.url == expected_url + + +def test_munge_url_format(fake_signer_factory, fake_request_factory, fake_credentials_factory, fake_session_factory): + test_url = 'https://appsync-api.aws.example.org/some-other-params' + + authorization = AppSyncIAMAuthorization(host=test_url, signer=fake_signer_factory(), session=fake_session_factory(), request_creator=fake_request_factory, credentials=fake_credentials_factory()) + + expected_url = 'wss://appsync-realtime-api.aws.example.org/some-other-params?header=eyJGYWtlQXV0aG9yaXphdGlvbiI6ImEiLCJGYWtlVGltZSI6InRvZGF5In0=&payload=e30=' + assert authorization.host_to_auth_url() == expected_url + From 6dd9de90b9841fc11700b955f10ebb1ff5ca2619 Mon Sep 17 00:00:00 2001 From: Chad Furman Date: Sun, 3 Oct 2021 20:59:09 -0400 Subject: [PATCH 23/68] fix: make check applied, flake errors fixed --- gql/transport/awsappsyncwebsocket.py | 114 ++++++++++++++++--------- tests/fixtures/aws/fake_credentials.py | 16 +++- tests/fixtures/aws/fake_request.py | 11 +-- tests/fixtures/aws/fake_session.py | 2 +- tests/fixtures/aws/fake_signer.py | 5 +- tests/fixtures/fake_logger.py | 1 - tests/test_appsyncwebsocket.py | 107 +++++++++++++++++------ 7 files changed, 178 insertions(+), 78 deletions(-) diff --git a/gql/transport/awsappsyncwebsocket.py b/gql/transport/awsappsyncwebsocket.py index c95ed16e..cb0e9186 100644 --- a/gql/transport/awsappsyncwebsocket.py +++ b/gql/transport/awsappsyncwebsocket.py @@ -1,20 +1,19 @@ -from asyncio import wait_for, ensure_future +import json +from abc import ABC, abstractmethod +from base64 import b64encode from logging import Logger +from ssl import SSLContext +from typing import Any, Dict, Optional, Union import botocore.session +from botocore.auth import SigV4Auth +from botocore.awsrequest import create_request_object +from botocore.exceptions import NoCredentialsError +from botocore.session import get_session from graphql import DocumentNode, print_ast from .exceptions import TransportProtocolError from .websockets import WebsocketsTransport -from ssl import SSLContext -from typing import Any, Dict, Union, Optional -from abc import ABC, abstractmethod -from base64 import b64encode -from botocore.awsrequest import AWSRequest, create_request_object -from botocore.session import get_session -from botocore.auth import SigV4Auth -from botocore.exceptions import NoCredentialsError -import json class AppSyncAuthorization(ABC): @@ -24,16 +23,18 @@ def __init__(self, host: str): def host_to_auth_url(self) -> str: """Munge Host For Appsync Auth - :return: a url used to establish websocket connections to the appsync-realtime-api + :return: a url used to establish websocket connections + to the appsync-realtime-api """ - url_after_replacements=self._host.replace("https", "wss").replace("appsync-api", "appsync-realtime-api") - headers_from_auth=self.get_headers() + url_after_replacements = self._host.replace("https", "wss").replace( + "appsync-api", "appsync-realtime-api" + ) + headers_from_auth = self.get_headers() encoded_headers = b64encode( json.dumps(headers_from_auth, separators=(",", ":")).encode() ).decode() - return '{url}?header={headers}&payload=e30='.format( - url=url_after_replacements, - headers=encoded_headers + return "{url}?header={headers}&payload=e30=".format( + url=url_after_replacements, headers=encoded_headers ) @abstractmethod @@ -59,30 +60,52 @@ def get_headers(self, data: Optional[str] = None) -> Dict: return {"host": self._host, "Authorization": self.jwt} - class AppSyncCognitoUserPoolAuthorization(AppSyncOIDCAuthorization): """Alias for AppSyncOIDCAuthorization""" + pass class AppSyncIAMAuthorization(AppSyncAuthorization): - def __init__(self, host: str, region_name=None, signer=None, request_creator=None, credentials=None, session=None) -> None: + def __init__( + self, + host: str, + region_name=None, + signer=None, + request_creator=None, + credentials=None, + session=None, + ) -> None: super().__init__(host) self._session = session if session else get_session() - self._credentials = credentials if credentials else self._session.get_credentials() - self._region_name = self._session._resolve_region_name(region_name, self._session.get_default_client_config()) + self._credentials = ( + credentials if credentials else self._session.get_credentials() + ) + self._region_name = self._session._resolve_region_name( + region_name, self._session.get_default_client_config() + ) self._service_name = "appsync" - self._signer = signer if signer else SigV4Auth(self._credentials, self._service_name, self._region_name) - self._request_creator = request_creator if request_creator else create_request_object - - def get_headers(self, data: Optional[str] = None, request_creator: callable = None) -> Dict: - request = self._request_creator({ - 'method': 'GET', - 'url': self._host, - 'headers': {}, - 'context': {}, - 'body': data, - }) + self._signer = ( + signer + if signer + else SigV4Auth(self._credentials, self._service_name, self._region_name) + ) + self._request_creator = ( + request_creator if request_creator else create_request_object + ) + + def get_headers( + self, data: Optional[str] = None, request_creator: callable = None + ) -> Dict: + request = self._request_creator( + { + "method": "GET", + "url": self._host, + "headers": {}, + "context": {}, + "body": data, + } + ) self._signer.add_auth(request) return dict(request.headers) @@ -100,17 +123,30 @@ def __init__( connect_args: Dict[str, Any] = {}, logger: Logger = None, ) -> None: - self.logger = logger if logger else Logger('debug') + self.logger = logger if logger else Logger("debug") try: - self.authorization = authorization if authorization else AppSyncIAMAuthorization(host=url, session=session) + self.authorization = ( + authorization + if authorization + else AppSyncIAMAuthorization(host=url, session=session) + ) url = self.authorization.host_to_auth_url() - except botocore.exceptions.NoCredentialsError as e: + except NoCredentialsError as e: self.authorization = None - self.logger.log(0, 'Credentials not found. Do you have default AWS credentials configured?') + self.logger.log( + 0, + "Credentials not found. " + "Do you have default AWS credentials configured?", + ) raise e - except TypeError as e: + except TypeError: self.authorization = None - self.logger.log(0, 'A TypeError was raised. The most likely reason for this is that the AWS region is missing from the credentials.') + self.logger.log( + 0, + "A TypeError was raised. " + "The most likely reason for this is that the AWS " + "region is missing from the credentials.", + ) raise MissingRegionError super().__init__( @@ -160,8 +196,8 @@ async def _send_query( "data": data, "extensions": { "authorization": self.authorization.get_headers(data) - } - } + }, + }, }, separators=(",", ":"), ) diff --git a/tests/fixtures/aws/fake_credentials.py b/tests/fixtures/aws/fake_credentials.py index 738eeabd..d8eac834 100644 --- a/tests/fixtures/aws/fake_credentials.py +++ b/tests/fixtures/aws/fake_credentials.py @@ -2,7 +2,9 @@ class FakeCredentials(object): - def __init__(self, access_key=None, secret_key=None, method=None, token=None, region=None): + def __init__( + self, access_key=None, secret_key=None, method=None, token=None, region=None + ): self.region = region if region else "us-east-1a" self.access_key = access_key if access_key else "fake-access-key" self.secret_key = secret_key if secret_key else "fake-secret-key" @@ -12,7 +14,15 @@ def __init__(self, access_key=None, secret_key=None, method=None, token=None, re @pytest.fixture def fake_credentials_factory(): - def _fake_credentials_factory(access_key=None, secret_key=None, method=None, token=None, region=None): - return FakeCredentials(access_key=access_key, secret_key=secret_key, method=method, token=token, region=region) + def _fake_credentials_factory( + access_key=None, secret_key=None, method=None, token=None, region=None + ): + return FakeCredentials( + access_key=access_key, + secret_key=secret_key, + method=method, + token=token, + region=region, + ) yield _fake_credentials_factory diff --git a/tests/fixtures/aws/fake_request.py b/tests/fixtures/aws/fake_request.py index ffca22db..615bc095 100644 --- a/tests/fixtures/aws/fake_request.py +++ b/tests/fixtures/aws/fake_request.py @@ -7,15 +7,16 @@ class FakeRequest(object): def __init__(self, request_props=None): if not isinstance(request_props, dict): return - self.method = request_props.get('method') - self.url = request_props.get('url') - self.headers = request_props.get('headers') - self.context = request_props.get('context') - self.body = request_props.get('body') + self.method = request_props.get("method") + self.url = request_props.get("url") + self.headers = request_props.get("headers") + self.context = request_props.get("context") + self.body = request_props.get("body") @pytest.fixture def fake_request_factory(): def _fake_request_factory(request_props=None): return FakeRequest(request_props=request_props) + yield _fake_request_factory diff --git a/tests/fixtures/aws/fake_session.py b/tests/fixtures/aws/fake_session.py index 6c217ea4..78e1511a 100644 --- a/tests/fixtures/aws/fake_session.py +++ b/tests/fixtures/aws/fake_session.py @@ -19,6 +19,6 @@ def _resolve_region_name(self, region_name, client_config): @pytest.fixture def fake_session_factory(fake_credentials_factory): def _fake_session_factory(credentials=fake_credentials_factory()): - return FakeSession(credentials=credentials, region_name='fake-region') + return FakeSession(credentials=credentials, region_name="fake-region") yield _fake_session_factory diff --git a/tests/fixtures/aws/fake_signer.py b/tests/fixtures/aws/fake_signer.py index e50824e1..ff096745 100644 --- a/tests/fixtures/aws/fake_signer.py +++ b/tests/fixtures/aws/fake_signer.py @@ -5,8 +5,9 @@ def fake_signer_factory(fake_request_factory): def _fake_signer_factory(request=None): if not request: - request=fake_request_factory() + request = fake_request_factory() return FakeSigner(request=request) + yield _fake_signer_factory @@ -23,4 +24,4 @@ def add_auth(self, request) -> None: def get_headers(self): self.add_auth(self.request) - return self.request.headers \ No newline at end of file + return self.request.headers diff --git a/tests/fixtures/fake_logger.py b/tests/fixtures/fake_logger.py index 23efe333..5e03863b 100644 --- a/tests/fixtures/fake_logger.py +++ b/tests/fixtures/fake_logger.py @@ -11,7 +11,6 @@ def log(self, level, message): @pytest.fixture def fake_logger_factory(): - def _fake_logger_factory(messages=None): return FakeLogger(messages=messages) diff --git a/tests/test_appsyncwebsocket.py b/tests/test_appsyncwebsocket.py index f0381949..2b84e9a4 100644 --- a/tests/test_appsyncwebsocket.py +++ b/tests/test_appsyncwebsocket.py @@ -1,80 +1,133 @@ import botocore.exceptions import pytest -from gql.transport.awsappsyncwebsocket import MissingRegionError -from gql.transport.awsappsyncwebsocket import AppSyncWebsocketsTransport, AppSyncIAMAuthorization, AppSyncOIDCAuthorization, AppSyncApiKeyAuthorization +from gql.transport.awsappsyncwebsocket import ( + AppSyncApiKeyAuthorization, + AppSyncIAMAuthorization, + AppSyncOIDCAuthorization, + AppSyncWebsocketsTransport, + MissingRegionError, +) mock_transport_url = "https://appsyncapp.awsgateway.com.example.org" def test_appsyncwebsocket_init_with_minimal_args(fake_session_factory): - sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, session=fake_session_factory()) + sample_transport = AppSyncWebsocketsTransport( + url=mock_transport_url, session=fake_session_factory() + ) assert isinstance(sample_transport.authorization, AppSyncIAMAuthorization) assert sample_transport.connect_timeout == 10 assert sample_transport.close_timeout == 10 assert sample_transport.ack_timeout == 10 - assert sample_transport.ssl == False + assert sample_transport.ssl is False assert sample_transport.connect_args == {} -def test_appsyncwebsocket_init_with_no_credentials(fake_session_factory, fake_logger_factory): +def test_appsyncwebsocket_init_with_no_credentials( + fake_session_factory, fake_logger_factory +): fake_logger = fake_logger_factory() with pytest.raises(botocore.exceptions.NoCredentialsError): - sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, session=fake_session_factory(credentials=None), logger=fake_logger) + sample_transport = AppSyncWebsocketsTransport( + url=mock_transport_url, + session=fake_session_factory(credentials=None), + logger=fake_logger, + ) assert sample_transport.authorization is None assert fake_logger._messages.length > 0 assert "credentials" in fake_logger._messages[0].lower() + def test_appsyncwebsocket_init_with_oidc_auth(): authorization = AppSyncOIDCAuthorization(host=mock_transport_url, jwt="some-jwt") - sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, authorization=authorization) + sample_transport = AppSyncWebsocketsTransport( + url=mock_transport_url, authorization=authorization + ) assert sample_transport.authorization is authorization def test_appsyncwebsocket_init_with_apikey_auth(): - authorization = AppSyncApiKeyAuthorization(host=mock_transport_url, api_key="some-api-key") - sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, authorization=authorization) + authorization = AppSyncApiKeyAuthorization( + host=mock_transport_url, api_key="some-api-key" + ) + sample_transport = AppSyncWebsocketsTransport( + url=mock_transport_url, authorization=authorization + ) assert sample_transport.authorization is authorization def test_appsyncwebsocket_init_with_iam_auth(): authorization = AppSyncIAMAuthorization(host=mock_transport_url) - sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, authorization=authorization) + sample_transport = AppSyncWebsocketsTransport( + url=mock_transport_url, authorization=authorization + ) assert sample_transport.authorization is authorization -def test_appsyncwebsocket_init_with_iam_auth(fake_credentials_factory): - authorization = AppSyncIAMAuthorization(host=mock_transport_url, credentials=fake_credentials_factory(), region_name="us-east-1") - sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, authorization=authorization) +def test_appsyncwebsocket_init_with_iam_auth_with_creds(fake_credentials_factory): + authorization = AppSyncIAMAuthorization( + host=mock_transport_url, + credentials=fake_credentials_factory(), + region_name="us-east-1", + ) + sample_transport = AppSyncWebsocketsTransport( + url=mock_transport_url, authorization=authorization + ) assert sample_transport.authorization is authorization - -def test_appsyncwebsocket_init_with_iam_auth_and_no_region(fake_credentials_factory, fake_logger_factory): +def test_appsyncwebsocket_init_with_iam_auth_and_no_region( + fake_credentials_factory, fake_logger_factory +): fake_logger = fake_logger_factory() with pytest.raises(MissingRegionError): - authorization = AppSyncIAMAuthorization(host=mock_transport_url, credentials=fake_credentials_factory()) - sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, authorization=authorization, logger=fake_logger) + authorization = AppSyncIAMAuthorization( + host=mock_transport_url, credentials=fake_credentials_factory() + ) + sample_transport = AppSyncWebsocketsTransport( + url=mock_transport_url, authorization=authorization, logger=fake_logger + ) assert sample_transport.authorization is None assert fake_logger._messages.length > 0 assert "credentials" in fake_logger._messages[0].lower() def test_munge_url(fake_signer_factory, fake_request_factory): - test_url = 'https://appsync-api.aws.example.org/some-other-params' + test_url = "https://appsync-api.aws.example.org/some-other-params" - authorization = AppSyncIAMAuthorization(host=test_url, signer=fake_signer_factory(), request_creator=fake_request_factory) - sample_transport = AppSyncWebsocketsTransport(url=test_url, authorization=authorization) + authorization = AppSyncIAMAuthorization( + host=test_url, + signer=fake_signer_factory(), + request_creator=fake_request_factory, + ) + sample_transport = AppSyncWebsocketsTransport( + url=test_url, authorization=authorization + ) expected_url = authorization.host_to_auth_url() assert sample_transport.url == expected_url -def test_munge_url_format(fake_signer_factory, fake_request_factory, fake_credentials_factory, fake_session_factory): - test_url = 'https://appsync-api.aws.example.org/some-other-params' - - authorization = AppSyncIAMAuthorization(host=test_url, signer=fake_signer_factory(), session=fake_session_factory(), request_creator=fake_request_factory, credentials=fake_credentials_factory()) - - expected_url = 'wss://appsync-realtime-api.aws.example.org/some-other-params?header=eyJGYWtlQXV0aG9yaXphdGlvbiI6ImEiLCJGYWtlVGltZSI6InRvZGF5In0=&payload=e30=' +def test_munge_url_format( + fake_signer_factory, + fake_request_factory, + fake_credentials_factory, + fake_session_factory, +): + test_url = "https://appsync-api.aws.example.org/some-other-params" + + authorization = AppSyncIAMAuthorization( + host=test_url, + signer=fake_signer_factory(), + session=fake_session_factory(), + request_creator=fake_request_factory, + credentials=fake_credentials_factory(), + ) + + header_string = "eyJGYWtlQXV0aG9yaXphdGlvbiI6ImEiLCJGYWtlVGltZSI6InRvZGF5In0=" + expected_url = ( + f"wss://appsync-realtime-api.aws.example.org/" + f"some-other-params?header={header_string}&payload=e30=" + ) assert authorization.host_to_auth_url() == expected_url - From 887ed59d3384a7d03f7a3d850869598a5985465a Mon Sep 17 00:00:00 2001 From: Chad Furman Date: Sun, 3 Oct 2021 20:59:31 -0400 Subject: [PATCH 24/68] fix: added 'aws' as an install target for tests --- tests/conftest.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 945e70ba..d742d50b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,11 +14,7 @@ from gql import Client -all_transport_dependencies = [ - "aiohttp", - "requests", - "websockets", -] +all_transport_dependencies = ["aiohttp", "requests", "websockets", "aws"] def pytest_addoption(parser): @@ -386,6 +382,7 @@ async def run_sync_test_inner(event_loop, server, test_function): return run_sync_test_inner + pytest_plugins = [ "tests.fixtures.fake_logger", "tests.fixtures.aws.fake_credentials", From 90c121aa0abe634a7cca7af28f3bc68df6d8bc76 Mon Sep 17 00:00:00 2001 From: Chad Furman Date: Sun, 10 Oct 2021 19:23:26 -0400 Subject: [PATCH 25/68] fix: typehint issues addressed --- gql/transport/awsappsyncwebsocket.py | 53 ++++++++++++++-------------- 1 file changed, 27 insertions(+), 26 deletions(-) diff --git a/gql/transport/awsappsyncwebsocket.py b/gql/transport/awsappsyncwebsocket.py index cb0e9186..6bb4903b 100644 --- a/gql/transport/awsappsyncwebsocket.py +++ b/gql/transport/awsappsyncwebsocket.py @@ -3,11 +3,11 @@ from base64 import b64encode from logging import Logger from ssl import SSLContext -from typing import Any, Dict, Optional, Union +from typing import Any, Callable, Dict, Optional, Union import botocore.session from botocore.auth import SigV4Auth -from botocore.awsrequest import create_request_object +from botocore.awsrequest import AWSRequest, create_request_object from botocore.exceptions import NoCredentialsError from botocore.session import get_session from graphql import DocumentNode, print_ast @@ -38,7 +38,7 @@ def host_to_auth_url(self) -> str: ) @abstractmethod - def get_headers(self, data: Optional[str] = None) -> Dict: + def get_headers(self, data: Optional[dict] = None) -> Dict: raise NotImplementedError() @@ -47,7 +47,7 @@ def __init__(self, host: str, api_key: str) -> None: super().__init__(host) self.api_key = api_key - def get_headers(self, data: Optional[str] = None) -> Dict: + def get_headers(self, data: Optional[dict] = None) -> Dict: return {"host": self._host, "x-api-key": self.api_key} @@ -56,7 +56,7 @@ def __init__(self, host: str, jwt: str) -> None: super().__init__(host) self.jwt = jwt - def get_headers(self, data: Optional[str] = None) -> Dict: + def get_headers(self, data: Optional[dict] = None) -> Dict: return {"host": self._host, "Authorization": self.jwt} @@ -95,7 +95,9 @@ def __init__( ) def get_headers( - self, data: Optional[str] = None, request_creator: callable = None + self, + data: Optional[dict] = None, + request_creator: Callable[[dict], AWSRequest] = None, ) -> Dict: request = self._request_creator( { @@ -111,11 +113,13 @@ def get_headers( class AppSyncWebsocketsTransport(WebsocketsTransport): + authorization: Optional[AppSyncAuthorization] + def __init__( self, url: str, - authorization: AppSyncAuthorization = None, - session: botocore.session.Session = None, + authorization: Optional[AppSyncAuthorization] = None, + session: Optional[botocore.session.Session] = None, ssl: Union[SSLContext, bool] = False, connect_timeout: int = 10, close_timeout: int = 10, @@ -132,7 +136,7 @@ def __init__( ) url = self.authorization.host_to_auth_url() except NoCredentialsError as e: - self.authorization = None + del self.authorization self.logger.log( 0, "Credentials not found. " @@ -140,7 +144,7 @@ def __init__( ) raise e except TypeError: - self.authorization = None + del self.authorization self.logger.log( 0, "A TypeError was raised. " @@ -181,27 +185,24 @@ async def _send_query( query_id = self.next_query_id self.next_query_id += 1 - data = {"query": print_ast(document)} + data: Dict = {"query": print_ast(document)} if variable_values: data["variables"] = variable_values if operation_name: data["operationName"] = operation_name - await self._send( - json.dumps( - { - "id": str(query_id), - "type": "start", - "payload": { - "data": data, - "extensions": { - "authorization": self.authorization.get_headers(data) - }, - }, - }, - separators=(",", ":"), - ) - ) + message: Dict = { + "id": str(query_id), + "type": "start", + "payload": {"data": data}, + } + + if self.authorization: + message["payload"]["extensions"] = { + "authorization": self.authorization.get_headers(data) + } + + await self._send(json.dumps(message, separators=(",", ":"),)) return query_id From 1ac1350c5042acdf06fd86275bbe2ebf12be65cd Mon Sep 17 00:00:00 2001 From: Chad Furman Date: Sun, 10 Oct 2021 19:29:03 -0400 Subject: [PATCH 26/68] test: updated iam test without credentials to explicitly trigger credential error --- tests/test_appsyncwebsocket.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_appsyncwebsocket.py b/tests/test_appsyncwebsocket.py index 2b84e9a4..67423b10 100644 --- a/tests/test_appsyncwebsocket.py +++ b/tests/test_appsyncwebsocket.py @@ -57,12 +57,12 @@ def test_appsyncwebsocket_init_with_apikey_auth(): assert sample_transport.authorization is authorization -def test_appsyncwebsocket_init_with_iam_auth(): - authorization = AppSyncIAMAuthorization(host=mock_transport_url) - sample_transport = AppSyncWebsocketsTransport( - url=mock_transport_url, authorization=authorization - ) - assert sample_transport.authorization is authorization +def test_appsyncwebsocket_init_with_iam_auth_without_creds(): + authorization = AppSyncIAMAuthorization(host=mock_transport_url, credentials=None) + with pytest.raises(botocore.exceptions.NoCredentialsError): + sample_transport = AppSyncWebsocketsTransport( + url=mock_transport_url, authorization=authorization + ) def test_appsyncwebsocket_init_with_iam_auth_with_creds(fake_credentials_factory): From 7b75fd5bc3268cc315ea864becbdec19321fbf2e Mon Sep 17 00:00:00 2001 From: Chad Furman Date: Mon, 11 Oct 2021 18:09:36 -0400 Subject: [PATCH 27/68] fix: linting errors --- gql/transport/requests.py | 2 +- tests/test_appsyncwebsocket.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/gql/transport/requests.py b/gql/transport/requests.py index 68b4144b..b2f4e8d0 100644 --- a/gql/transport/requests.py +++ b/gql/transport/requests.py @@ -8,9 +8,9 @@ from requests.adapters import HTTPAdapter, Retry from requests.auth import AuthBase from requests.cookies import RequestsCookieJar -from requests_toolbelt.multipart.encoder import MultipartEncoder from gql.transport import Transport +from requests_toolbelt.multipart.encoder import MultipartEncoder from ..utils import extract_files from .exceptions import ( diff --git a/tests/test_appsyncwebsocket.py b/tests/test_appsyncwebsocket.py index 67423b10..17d00709 100644 --- a/tests/test_appsyncwebsocket.py +++ b/tests/test_appsyncwebsocket.py @@ -60,9 +60,7 @@ def test_appsyncwebsocket_init_with_apikey_auth(): def test_appsyncwebsocket_init_with_iam_auth_without_creds(): authorization = AppSyncIAMAuthorization(host=mock_transport_url, credentials=None) with pytest.raises(botocore.exceptions.NoCredentialsError): - sample_transport = AppSyncWebsocketsTransport( - url=mock_transport_url, authorization=authorization - ) + AppSyncWebsocketsTransport(url=mock_transport_url, authorization=authorization) def test_appsyncwebsocket_init_with_iam_auth_with_creds(fake_credentials_factory): From 86ddc13a065aad824c463783907a225a70329cf5 Mon Sep 17 00:00:00 2001 From: Chad Furman Date: Mon, 11 Oct 2021 19:09:24 -0400 Subject: [PATCH 28/68] fix: test marker adjusted, renamed --- setup.py | 6 +++--- tests/conftest.py | 2 +- tests/test_appsyncwebsocket.py | 3 +++ 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/setup.py b/setup.py index 64018932..da279d63 100644 --- a/setup.py +++ b/setup.py @@ -45,12 +45,12 @@ "websockets>=9,<10", ] -install_aws_requires = [ +install_appsyncwebsockets_requires = [ "botocore>=1.21,<1.22", ] + install_websockets_requires install_all_requires = ( - install_aiohttp_requires + install_requests_requires + install_websockets_requires + install_aws_requires + install_aiohttp_requires + install_requests_requires + install_websockets_requires + install_appsyncwebsockets_requires ) # Get version from __version__.py file @@ -92,7 +92,7 @@ "aiohttp": install_aiohttp_requires, "requests": install_requests_requires, "websockets": install_websockets_requires, - "aws": install_aws_requires, + "appsyncwebsockets": install_appsyncwebsockets_requires, }, include_package_data=True, zip_safe=False, diff --git a/tests/conftest.py b/tests/conftest.py index d742d50b..17bf7cc4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,7 +14,7 @@ from gql import Client -all_transport_dependencies = ["aiohttp", "requests", "websockets", "aws"] +all_transport_dependencies = ["aiohttp", "requests", "websockets", "appsyncwebsockets"] def pytest_addoption(parser): diff --git a/tests/test_appsyncwebsocket.py b/tests/test_appsyncwebsocket.py index 17d00709..a1d80f09 100644 --- a/tests/test_appsyncwebsocket.py +++ b/tests/test_appsyncwebsocket.py @@ -9,6 +9,9 @@ MissingRegionError, ) +# Marking all tests in this file with the appsyncwebsockets marker +pytestmark = pytest.mark.appsyncwebsockets + mock_transport_url = "https://appsyncapp.awsgateway.com.example.org" From d79800130608b71b71e45def341bf03328f0386e Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Tue, 26 Oct 2021 17:03:19 +0200 Subject: [PATCH 29/68] fix: running 'make check' to fix module sort order --- gql/transport/requests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gql/transport/requests.py b/gql/transport/requests.py index b2f4e8d0..68b4144b 100644 --- a/gql/transport/requests.py +++ b/gql/transport/requests.py @@ -8,9 +8,9 @@ from requests.adapters import HTTPAdapter, Retry from requests.auth import AuthBase from requests.cookies import RequestsCookieJar +from requests_toolbelt.multipart.encoder import MultipartEncoder from gql.transport import Transport -from requests_toolbelt.multipart.encoder import MultipartEncoder from ..utils import extract_files from .exceptions import ( From 2ef43eb27ed573d142f233b68626a7d14741a562 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Tue, 26 Oct 2021 17:21:20 +0200 Subject: [PATCH 30/68] fix: tests put imports inside methods to fix tests without botocore Now pytest tests --aiohttp-only will work without the botocore dependency --- tests/test_appsyncwebsocket.py | 51 ++++++++++++++++++++++++++++------ 1 file changed, 42 insertions(+), 9 deletions(-) diff --git a/tests/test_appsyncwebsocket.py b/tests/test_appsyncwebsocket.py index a1d80f09..13c4dc7c 100644 --- a/tests/test_appsyncwebsocket.py +++ b/tests/test_appsyncwebsocket.py @@ -1,14 +1,5 @@ -import botocore.exceptions import pytest -from gql.transport.awsappsyncwebsocket import ( - AppSyncApiKeyAuthorization, - AppSyncIAMAuthorization, - AppSyncOIDCAuthorization, - AppSyncWebsocketsTransport, - MissingRegionError, -) - # Marking all tests in this file with the appsyncwebsockets marker pytestmark = pytest.mark.appsyncwebsockets @@ -16,6 +7,11 @@ def test_appsyncwebsocket_init_with_minimal_args(fake_session_factory): + from gql.transport.awsappsyncwebsocket import ( + AppSyncIAMAuthorization, + AppSyncWebsocketsTransport, + ) + sample_transport = AppSyncWebsocketsTransport( url=mock_transport_url, session=fake_session_factory() ) @@ -30,6 +26,9 @@ def test_appsyncwebsocket_init_with_minimal_args(fake_session_factory): def test_appsyncwebsocket_init_with_no_credentials( fake_session_factory, fake_logger_factory ): + import botocore.exceptions + from gql.transport.awsappsyncwebsocket import AppSyncWebsocketsTransport + fake_logger = fake_logger_factory() with pytest.raises(botocore.exceptions.NoCredentialsError): sample_transport = AppSyncWebsocketsTransport( @@ -43,6 +42,11 @@ def test_appsyncwebsocket_init_with_no_credentials( def test_appsyncwebsocket_init_with_oidc_auth(): + from gql.transport.awsappsyncwebsocket import ( + AppSyncOIDCAuthorization, + AppSyncWebsocketsTransport, + ) + authorization = AppSyncOIDCAuthorization(host=mock_transport_url, jwt="some-jwt") sample_transport = AppSyncWebsocketsTransport( url=mock_transport_url, authorization=authorization @@ -51,6 +55,11 @@ def test_appsyncwebsocket_init_with_oidc_auth(): def test_appsyncwebsocket_init_with_apikey_auth(): + from gql.transport.awsappsyncwebsocket import ( + AppSyncApiKeyAuthorization, + AppSyncWebsocketsTransport, + ) + authorization = AppSyncApiKeyAuthorization( host=mock_transport_url, api_key="some-api-key" ) @@ -61,12 +70,23 @@ def test_appsyncwebsocket_init_with_apikey_auth(): def test_appsyncwebsocket_init_with_iam_auth_without_creds(): + import botocore.exceptions + from gql.transport.awsappsyncwebsocket import ( + AppSyncIAMAuthorization, + AppSyncWebsocketsTransport, + ) + authorization = AppSyncIAMAuthorization(host=mock_transport_url, credentials=None) with pytest.raises(botocore.exceptions.NoCredentialsError): AppSyncWebsocketsTransport(url=mock_transport_url, authorization=authorization) def test_appsyncwebsocket_init_with_iam_auth_with_creds(fake_credentials_factory): + from gql.transport.awsappsyncwebsocket import ( + AppSyncIAMAuthorization, + AppSyncWebsocketsTransport, + ) + authorization = AppSyncIAMAuthorization( host=mock_transport_url, credentials=fake_credentials_factory(), @@ -81,6 +101,12 @@ def test_appsyncwebsocket_init_with_iam_auth_with_creds(fake_credentials_factory def test_appsyncwebsocket_init_with_iam_auth_and_no_region( fake_credentials_factory, fake_logger_factory ): + from gql.transport.awsappsyncwebsocket import ( + AppSyncIAMAuthorization, + AppSyncWebsocketsTransport, + MissingRegionError, + ) + fake_logger = fake_logger_factory() with pytest.raises(MissingRegionError): authorization = AppSyncIAMAuthorization( @@ -95,6 +121,11 @@ def test_appsyncwebsocket_init_with_iam_auth_and_no_region( def test_munge_url(fake_signer_factory, fake_request_factory): + from gql.transport.awsappsyncwebsocket import ( + AppSyncIAMAuthorization, + AppSyncWebsocketsTransport, + ) + test_url = "https://appsync-api.aws.example.org/some-other-params" authorization = AppSyncIAMAuthorization( @@ -116,6 +147,8 @@ def test_munge_url_format( fake_credentials_factory, fake_session_factory, ): + from gql.transport.awsappsyncwebsocket import AppSyncIAMAuthorization + test_url = "https://appsync-api.aws.example.org/some-other-params" authorization = AppSyncIAMAuthorization( From aa9cb069d5c2e709a25c155a1ed8325649f62d4a Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 29 Nov 2021 21:41:43 +0100 Subject: [PATCH 31/68] Multiple fixes: serializing the data in the payload adding parse_answer method replace dict typing by Dict change log method to correspond to other transports rename host_to_auth method to get_auth_url and add an url argument this is needed as the connection url could be different than the host which is used in the signed requests Trying to make it work with IAM auth. Not working for now. Removing MissingRegionError Exception --- gql/transport/awsappsyncwebsocket.py | 168 +++++++++++++++++---------- gql/transport/websockets.py | 1 + tests/conftest.py | 4 +- tests/fixtures/fake_logger.py | 17 --- tests/test_appsyncwebsocket.py | 49 ++++---- 5 files changed, 139 insertions(+), 100 deletions(-) delete mode 100644 tests/fixtures/fake_logger.py diff --git a/gql/transport/awsappsyncwebsocket.py b/gql/transport/awsappsyncwebsocket.py index 6bb4903b..096da2c6 100644 --- a/gql/transport/awsappsyncwebsocket.py +++ b/gql/transport/awsappsyncwebsocket.py @@ -1,62 +1,61 @@ import json +import logging from abc import ABC, abstractmethod from base64 import b64encode -from logging import Logger from ssl import SSLContext -from typing import Any, Callable, Dict, Optional, Union +from typing import Any, Callable, Dict, Optional, Tuple, Union import botocore.session from botocore.auth import SigV4Auth from botocore.awsrequest import AWSRequest, create_request_object from botocore.exceptions import NoCredentialsError from botocore.session import get_session -from graphql import DocumentNode, print_ast +from graphql import DocumentNode, ExecutionResult, print_ast -from .exceptions import TransportProtocolError +from .exceptions import TransportProtocolError, TransportServerError from .websockets import WebsocketsTransport +log = logging.getLogger(__name__) -class AppSyncAuthorization(ABC): - def __init__(self, host: str): - self._host = host - - def host_to_auth_url(self) -> str: - """Munge Host For Appsync Auth +class AppSyncAuthorization(ABC): + def get_auth_url(self, url: str) -> str: + """ :return: a url used to establish websocket connections to the appsync-realtime-api """ - url_after_replacements = self._host.replace("https", "wss").replace( - "appsync-api", "appsync-realtime-api" - ) - headers_from_auth = self.get_headers() + headers = self.get_headers() + encoded_headers = b64encode( - json.dumps(headers_from_auth, separators=(",", ":")).encode() + json.dumps(headers, separators=(",", ":")).encode() ).decode() - return "{url}?header={headers}&payload=e30=".format( - url=url_after_replacements, headers=encoded_headers + + url_base = url.replace("https://", "wss://").replace( + "appsync-api", "appsync-realtime-api" ) + return f"{url_base}?header={encoded_headers}&payload=e30=" + @abstractmethod - def get_headers(self, data: Optional[dict] = None) -> Dict: + def get_headers(self, data: Optional[Dict] = None) -> Dict: raise NotImplementedError() class AppSyncApiKeyAuthorization(AppSyncAuthorization): def __init__(self, host: str, api_key: str) -> None: - super().__init__(host) + self._host = host self.api_key = api_key - def get_headers(self, data: Optional[dict] = None) -> Dict: + def get_headers(self, data: Optional[Dict] = None) -> Dict: return {"host": self._host, "x-api-key": self.api_key} class AppSyncOIDCAuthorization(AppSyncAuthorization): def __init__(self, host: str, jwt: str) -> None: - super().__init__(host) + self._host = host self.jwt = jwt - def get_headers(self, data: Optional[dict] = None) -> Dict: + def get_headers(self, data: Optional[Dict] = None) -> Dict: return {"host": self._host, "Authorization": self.jwt} @@ -76,7 +75,7 @@ def __init__( credentials=None, session=None, ) -> None: - super().__init__(host) + self._host = host self._session = session if session else get_session() self._credentials = ( credentials if credentials else self._session.get_credentials() @@ -96,20 +95,42 @@ def __init__( def get_headers( self, - data: Optional[dict] = None, - request_creator: Callable[[dict], AWSRequest] = None, + data: Optional[Dict] = None, + request_creator: Callable[[Dict], AWSRequest] = None, ) -> Dict: - request = self._request_creator( + + """ + Should we add other data in the headers field ? + utc_now = datetime.utcnow() + amz_date = utc_now.strftime("%Y%m%dT%H%M%SZ") + + headers = { + "accept": "application/json, text/javascript", + "content-encoding": "amz-1.0", + "content-type": "application/json; charset=UTF-8", + "host": self._host, + "x-amz-date": amz_date, + } + """ + headers: Dict[str, Any] = {} + + request: AWSRequest = self._request_creator( { - "method": "GET", + "method": "POST", "url": self._host, - "headers": {}, + "headers": headers, "context": {}, "body": data, } ) + self._signer.add_auth(request) - return dict(request.headers) + + headers = dict(request.headers) + + log.debug(f"\n\nSigned headers: {headers}\n\n") + + return headers class AppSyncWebsocketsTransport(WebsocketsTransport): @@ -125,33 +146,28 @@ def __init__( close_timeout: int = 10, ack_timeout: int = 10, connect_args: Dict[str, Any] = {}, - logger: Logger = None, ) -> None: - self.logger = logger if logger else Logger("debug") try: - self.authorization = ( - authorization - if authorization - else AppSyncIAMAuthorization(host=url, session=session) - ) - url = self.authorization.host_to_auth_url() - except NoCredentialsError as e: - del self.authorization - self.logger.log( - 0, + if not authorization: + authorization = AppSyncIAMAuthorization(host=url, session=session) + + self.authorization = authorization + + url = self.authorization.get_auth_url(url) + + except NoCredentialsError: + log.warning( "Credentials not found. " "Do you have default AWS credentials configured?", ) - raise e + raise except TypeError: - del self.authorization - self.logger.log( - 0, + log.warning( "A TypeError was raised. " "The most likely reason for this is that the AWS " "region is missing from the credentials.", ) - raise MissingRegionError + raise super().__init__( url, @@ -162,19 +178,55 @@ def __init__( connect_args=connect_args, ) - async def _wait_start_ack(self) -> None: - """Wait for the start_ack message. Keep alive messages are ignored""" + def _parse_answer( + self, answer: str + ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: + """Parse the answer received from the server. + + Difference between apollo protocol and aws protocol: + + - aws protocol can return an error without an id + - aws protocol will send start_ack messages + + Returns a list consisting of: + - the answer_type: + - 'connection_ack', + - 'connection_error', + - 'start_ack', + - 'ka', + - 'data', + - 'error', + - 'complete' + - the answer id (Integer) if received or None + - an execution Result if the answer_type is 'data' or None + """ - while True: - answer_type = str(json.loads(await self._receive()).get("type")) + answer_type: str = "" + answer_id: Optional[int] = None + execution_result: Optional[ExecutionResult] = None + + try: + json_answer = json.loads(answer) + + answer_type = str(json_answer.get("type")) if answer_type == "start_ack": - return + return ("start_ack", None, None) + + elif answer_type == "error" and id not in json_answer: + error_payload = json_answer.get("payload") + raise TransportServerError(f"Server error: '{error_payload!r}'") + + else: + + return self._parse_answer_apollo(answer) - if answer_type != "ka": - raise TransportProtocolError( - "AppSync server did not return a start ack" - ) + except ValueError as e: + raise TransportProtocolError( + f"Server did not return a GraphQL result: {answer}" + ) from e + + return answer_type, answer_id, execution_result async def _send_query( self, @@ -194,7 +246,7 @@ async def _send_query( message: Dict = { "id": str(query_id), "type": "start", - "payload": {"data": data}, + "payload": {"data": json.dumps(data, separators=(",", ":"))}, } if self.authorization: @@ -205,7 +257,3 @@ async def _send_query( await self._send(json.dumps(message, separators=(",", ":"),)) return query_id - - -class MissingRegionError(Exception): - pass diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 779a3608..29c381d5 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -133,6 +133,7 @@ def __init__( :param connect_args: Other parameters forwarded to websockets.connect """ + log.debug(f"WebsocketTransport url = {url}") self.url: str = url self.ssl: Union[SSLContext, bool] = ssl self.headers: Optional[HeadersLike] = headers diff --git a/tests/conftest.py b/tests/conftest.py index c9e23299..f8865972 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -112,10 +112,11 @@ async def ssl_aiohttp_server(): yield server -# Adding debug logs to websocket tests +# Adding debug logs for name in [ "websockets.legacy.server", "gql.transport.aiohttp", + # "gql.transport.awsappsyncwebsocket", "gql.transport.phoenix_channel_websockets", "gql.transport.requests", "gql.transport.websockets", @@ -491,7 +492,6 @@ async def run_sync_test_inner(event_loop, server, test_function): pytest_plugins = [ - "tests.fixtures.fake_logger", "tests.fixtures.aws.fake_credentials", "tests.fixtures.aws.fake_request", "tests.fixtures.aws.fake_session", diff --git a/tests/fixtures/fake_logger.py b/tests/fixtures/fake_logger.py deleted file mode 100644 index 5e03863b..00000000 --- a/tests/fixtures/fake_logger.py +++ /dev/null @@ -1,17 +0,0 @@ -import pytest - - -class FakeLogger(object): - def __init__(self, messages=None): - self._messages = messages if messages else [] - - def log(self, level, message): - self._messages.append("LEVEL {}: {}".format(level, message)) - - -@pytest.fixture -def fake_logger_factory(): - def _fake_logger_factory(messages=None): - return FakeLogger(messages=messages) - - yield _fake_logger_factory diff --git a/tests/test_appsyncwebsocket.py b/tests/test_appsyncwebsocket.py index 13c4dc7c..86844850 100644 --- a/tests/test_appsyncwebsocket.py +++ b/tests/test_appsyncwebsocket.py @@ -23,22 +23,21 @@ def test_appsyncwebsocket_init_with_minimal_args(fake_session_factory): assert sample_transport.connect_args == {} -def test_appsyncwebsocket_init_with_no_credentials( - fake_session_factory, fake_logger_factory -): +def test_appsyncwebsocket_init_with_no_credentials(caplog, fake_session_factory): import botocore.exceptions from gql.transport.awsappsyncwebsocket import AppSyncWebsocketsTransport - fake_logger = fake_logger_factory() with pytest.raises(botocore.exceptions.NoCredentialsError): sample_transport = AppSyncWebsocketsTransport( - url=mock_transport_url, - session=fake_session_factory(credentials=None), - logger=fake_logger, + url=mock_transport_url, session=fake_session_factory(credentials=None), ) assert sample_transport.authorization is None - assert fake_logger._messages.length > 0 - assert "credentials" in fake_logger._messages[0].lower() + + expected_error = "Credentials not found" + + print(f"Captured log: {caplog.text}") + + assert expected_error in caplog.text def test_appsyncwebsocket_init_with_oidc_auth(): @@ -69,14 +68,16 @@ def test_appsyncwebsocket_init_with_apikey_auth(): assert sample_transport.authorization is authorization -def test_appsyncwebsocket_init_with_iam_auth_without_creds(): +def test_appsyncwebsocket_init_with_iam_auth_without_creds(fake_session_factory): import botocore.exceptions from gql.transport.awsappsyncwebsocket import ( AppSyncIAMAuthorization, AppSyncWebsocketsTransport, ) - authorization = AppSyncIAMAuthorization(host=mock_transport_url, credentials=None) + authorization = AppSyncIAMAuthorization( + host=mock_transport_url, session=fake_session_factory(credentials=None), + ) with pytest.raises(botocore.exceptions.NoCredentialsError): AppSyncWebsocketsTransport(url=mock_transport_url, authorization=authorization) @@ -99,25 +100,27 @@ def test_appsyncwebsocket_init_with_iam_auth_with_creds(fake_credentials_factory def test_appsyncwebsocket_init_with_iam_auth_and_no_region( - fake_credentials_factory, fake_logger_factory + caplog, fake_credentials_factory ): from gql.transport.awsappsyncwebsocket import ( AppSyncIAMAuthorization, AppSyncWebsocketsTransport, - MissingRegionError, ) - fake_logger = fake_logger_factory() - with pytest.raises(MissingRegionError): + with pytest.raises(TypeError): authorization = AppSyncIAMAuthorization( host=mock_transport_url, credentials=fake_credentials_factory() ) sample_transport = AppSyncWebsocketsTransport( - url=mock_transport_url, authorization=authorization, logger=fake_logger + url=mock_transport_url, authorization=authorization ) assert sample_transport.authorization is None - assert fake_logger._messages.length > 0 - assert "credentials" in fake_logger._messages[0].lower() + + print(f"Captured: {caplog.text}") + + expected_error = "the AWS region is missing from the credentials" + + assert expected_error in caplog.text def test_munge_url(fake_signer_factory, fake_request_factory): @@ -137,7 +140,11 @@ def test_munge_url(fake_signer_factory, fake_request_factory): url=test_url, authorization=authorization ) - expected_url = authorization.host_to_auth_url() + header_string = "eyJGYWtlQXV0aG9yaXphdGlvbiI6ImEiLCJGYWtlVGltZSI6InRvZGF5In0=" + expected_url = ( + "wss://appsync-realtime-api.aws.example.org/" + f"some-other-params?header={header_string}&payload=e30=" + ) assert sample_transport.url == expected_url @@ -161,7 +168,7 @@ def test_munge_url_format( header_string = "eyJGYWtlQXV0aG9yaXphdGlvbiI6ImEiLCJGYWtlVGltZSI6InRvZGF5In0=" expected_url = ( - f"wss://appsync-realtime-api.aws.example.org/" + "wss://appsync-realtime-api.aws.example.org/" f"some-other-params?header={header_string}&payload=e30=" ) - assert authorization.host_to_auth_url() == expected_url + assert authorization.get_auth_url(test_url) == expected_url From adf5a385966abdaa1ec83405a9784c160e8f1a19 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Tue, 30 Nov 2021 22:43:30 +0100 Subject: [PATCH 32/68] Add AWS code examples --- docs/code_examples/aws_api_key_mutation.py | 53 +++++++++++++++++++ .../code_examples/aws_api_key_subscription.py | 52 ++++++++++++++++++ docs/code_examples/aws_iam_subscription.py | 42 +++++++++++++++ 3 files changed, 147 insertions(+) create mode 100644 docs/code_examples/aws_api_key_mutation.py create mode 100644 docs/code_examples/aws_api_key_subscription.py create mode 100644 docs/code_examples/aws_iam_subscription.py diff --git a/docs/code_examples/aws_api_key_mutation.py b/docs/code_examples/aws_api_key_mutation.py new file mode 100644 index 00000000..44fde329 --- /dev/null +++ b/docs/code_examples/aws_api_key_mutation.py @@ -0,0 +1,53 @@ +import asyncio +import logging +import os +import sys +from urllib.parse import urlparse + +from gql import Client, gql +from gql.transport.aiohttp import AIOHTTPTransport +from gql.transport.awsappsyncwebsocket import AppSyncApiKeyAuthorization + +logging.basicConfig(level=logging.DEBUG) + + +async def main(): + + # Should look like: + # https://XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com/graphql + url = os.environ.get("AWS_GRAPHQL_API_ENDPOINT") + api_key = os.environ.get("AWS_GRAPHQL_API_KEY") + + if url is None or api_key is None: + print("Missing environment variables") + sys.exit() + + # Extract host from url + host = str(urlparse(url).netloc) + + auth = AppSyncApiKeyAuthorization(host=host, api_key=api_key) + + transport = AIOHTTPTransport(url=url, headers=auth.get_headers()) + + async with Client( + transport=transport, fetch_schema_from_transport=False, + ) as session: + + query = gql( + """ +mutation createMessage($message: String!) { + createMessage(input: {message: $message}) { + id + message + createdAt + } +}""" + ) + + variable_values = {"message": "Hello world!"} + + result = await session.execute(query, variable_values=variable_values) + print(result) + + +asyncio.run(main()) diff --git a/docs/code_examples/aws_api_key_subscription.py b/docs/code_examples/aws_api_key_subscription.py new file mode 100644 index 00000000..6268e0a4 --- /dev/null +++ b/docs/code_examples/aws_api_key_subscription.py @@ -0,0 +1,52 @@ +import asyncio +import logging +import os +import sys +from urllib.parse import urlparse + +from gql import Client, gql +from gql.transport.awsappsyncwebsocket import ( + AppSyncApiKeyAuthorization, + AppSyncWebsocketsTransport, +) + +logging.basicConfig(level=logging.DEBUG) + + +async def main(): + + # Should look like: + # https://XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com/graphql + url = os.environ.get("AWS_GRAPHQL_API_ENDPOINT") + api_key = os.environ.get("AWS_GRAPHQL_API_KEY") + + if url is None or api_key is None: + print("Missing environment variables") + sys.exit() + + # Extract host from url + host = str(urlparse(url).netloc) + + print(f"Host: {host}") + + auth = AppSyncApiKeyAuthorization(host=host, api_key=api_key) + + transport = AppSyncWebsocketsTransport(url=url, authorization=auth) + + async with Client(transport=transport) as session: + + subscription = gql( + """ +subscription onCreateMessage { + onCreateMessage { + message + } +} +""" + ) + + async for result in session.subscribe(subscription): + print(result) + + +asyncio.run(main()) diff --git a/docs/code_examples/aws_iam_subscription.py b/docs/code_examples/aws_iam_subscription.py new file mode 100644 index 00000000..82394628 --- /dev/null +++ b/docs/code_examples/aws_iam_subscription.py @@ -0,0 +1,42 @@ +import asyncio +import logging +import os +import sys + +from gql import Client, gql +from gql.transport.awsappsyncwebsocket import AppSyncWebsocketsTransport + +logging.basicConfig(level=logging.DEBUG) + + +async def main(): + + # Should look like: + # https://XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com/graphql + url = os.environ.get("AWS_GRAPHQL_API_ENDPOINT") + api_key = os.environ.get("AWS_GRAPHQL_API_KEY") + + if url is None or api_key is None: + print("Missing environment variables") + sys.exit() + + # Using implicit auth (IAM) + transport = AppSyncWebsocketsTransport(url=url) + + async with Client(transport=transport) as session: + + subscription = gql( + """ +subscription onCreateMessage { + onCreateMessage { + message + } +} +""" + ) + + async for result in session.subscribe(subscription): + print(result) + + +asyncio.run(main()) From 8d07646ad9010980c30851d6217b4a1e756bf0b3 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Tue, 30 Nov 2021 23:13:50 +0100 Subject: [PATCH 33/68] Rename file awsappsyncwebsockets.py to awsappsync.py --- docs/code_examples/aws_api_key_mutation.py | 2 +- docs/code_examples/aws_api_key_subscription.py | 2 +- docs/code_examples/aws_iam_subscription.py | 2 +- .../{awsappsyncwebsocket.py => awsappsync.py} | 0 tests/conftest.py | 2 +- tests/test_appsyncwebsocket.py | 18 +++++++++--------- 6 files changed, 13 insertions(+), 13 deletions(-) rename gql/transport/{awsappsyncwebsocket.py => awsappsync.py} (100%) diff --git a/docs/code_examples/aws_api_key_mutation.py b/docs/code_examples/aws_api_key_mutation.py index 44fde329..7623eecf 100644 --- a/docs/code_examples/aws_api_key_mutation.py +++ b/docs/code_examples/aws_api_key_mutation.py @@ -6,7 +6,7 @@ from gql import Client, gql from gql.transport.aiohttp import AIOHTTPTransport -from gql.transport.awsappsyncwebsocket import AppSyncApiKeyAuthorization +from gql.transport.awsappsync import AppSyncApiKeyAuthorization logging.basicConfig(level=logging.DEBUG) diff --git a/docs/code_examples/aws_api_key_subscription.py b/docs/code_examples/aws_api_key_subscription.py index 6268e0a4..a9aa7b12 100644 --- a/docs/code_examples/aws_api_key_subscription.py +++ b/docs/code_examples/aws_api_key_subscription.py @@ -5,7 +5,7 @@ from urllib.parse import urlparse from gql import Client, gql -from gql.transport.awsappsyncwebsocket import ( +from gql.transport.awsappsync import ( AppSyncApiKeyAuthorization, AppSyncWebsocketsTransport, ) diff --git a/docs/code_examples/aws_iam_subscription.py b/docs/code_examples/aws_iam_subscription.py index 82394628..f0423d1c 100644 --- a/docs/code_examples/aws_iam_subscription.py +++ b/docs/code_examples/aws_iam_subscription.py @@ -4,7 +4,7 @@ import sys from gql import Client, gql -from gql.transport.awsappsyncwebsocket import AppSyncWebsocketsTransport +from gql.transport.awsappsync import AppSyncWebsocketsTransport logging.basicConfig(level=logging.DEBUG) diff --git a/gql/transport/awsappsyncwebsocket.py b/gql/transport/awsappsync.py similarity index 100% rename from gql/transport/awsappsyncwebsocket.py rename to gql/transport/awsappsync.py diff --git a/tests/conftest.py b/tests/conftest.py index f8865972..cef4eae6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -116,7 +116,7 @@ async def ssl_aiohttp_server(): for name in [ "websockets.legacy.server", "gql.transport.aiohttp", - # "gql.transport.awsappsyncwebsocket", + "gql.transport.awsappsync", "gql.transport.phoenix_channel_websockets", "gql.transport.requests", "gql.transport.websockets", diff --git a/tests/test_appsyncwebsocket.py b/tests/test_appsyncwebsocket.py index 86844850..eee5caf5 100644 --- a/tests/test_appsyncwebsocket.py +++ b/tests/test_appsyncwebsocket.py @@ -7,7 +7,7 @@ def test_appsyncwebsocket_init_with_minimal_args(fake_session_factory): - from gql.transport.awsappsyncwebsocket import ( + from gql.transport.awsappsync import ( AppSyncIAMAuthorization, AppSyncWebsocketsTransport, ) @@ -25,7 +25,7 @@ def test_appsyncwebsocket_init_with_minimal_args(fake_session_factory): def test_appsyncwebsocket_init_with_no_credentials(caplog, fake_session_factory): import botocore.exceptions - from gql.transport.awsappsyncwebsocket import AppSyncWebsocketsTransport + from gql.transport.awsappsync import AppSyncWebsocketsTransport with pytest.raises(botocore.exceptions.NoCredentialsError): sample_transport = AppSyncWebsocketsTransport( @@ -41,7 +41,7 @@ def test_appsyncwebsocket_init_with_no_credentials(caplog, fake_session_factory) def test_appsyncwebsocket_init_with_oidc_auth(): - from gql.transport.awsappsyncwebsocket import ( + from gql.transport.awsappsync import ( AppSyncOIDCAuthorization, AppSyncWebsocketsTransport, ) @@ -54,7 +54,7 @@ def test_appsyncwebsocket_init_with_oidc_auth(): def test_appsyncwebsocket_init_with_apikey_auth(): - from gql.transport.awsappsyncwebsocket import ( + from gql.transport.awsappsync import ( AppSyncApiKeyAuthorization, AppSyncWebsocketsTransport, ) @@ -70,7 +70,7 @@ def test_appsyncwebsocket_init_with_apikey_auth(): def test_appsyncwebsocket_init_with_iam_auth_without_creds(fake_session_factory): import botocore.exceptions - from gql.transport.awsappsyncwebsocket import ( + from gql.transport.awsappsync import ( AppSyncIAMAuthorization, AppSyncWebsocketsTransport, ) @@ -83,7 +83,7 @@ def test_appsyncwebsocket_init_with_iam_auth_without_creds(fake_session_factory) def test_appsyncwebsocket_init_with_iam_auth_with_creds(fake_credentials_factory): - from gql.transport.awsappsyncwebsocket import ( + from gql.transport.awsappsync import ( AppSyncIAMAuthorization, AppSyncWebsocketsTransport, ) @@ -102,7 +102,7 @@ def test_appsyncwebsocket_init_with_iam_auth_with_creds(fake_credentials_factory def test_appsyncwebsocket_init_with_iam_auth_and_no_region( caplog, fake_credentials_factory ): - from gql.transport.awsappsyncwebsocket import ( + from gql.transport.awsappsync import ( AppSyncIAMAuthorization, AppSyncWebsocketsTransport, ) @@ -124,7 +124,7 @@ def test_appsyncwebsocket_init_with_iam_auth_and_no_region( def test_munge_url(fake_signer_factory, fake_request_factory): - from gql.transport.awsappsyncwebsocket import ( + from gql.transport.awsappsync import ( AppSyncIAMAuthorization, AppSyncWebsocketsTransport, ) @@ -154,7 +154,7 @@ def test_munge_url_format( fake_credentials_factory, fake_session_factory, ): - from gql.transport.awsappsyncwebsocket import AppSyncIAMAuthorization + from gql.transport.awsappsync import AppSyncIAMAuthorization test_url = "https://appsync-api.aws.example.org/some-other-params" From 7552baceacbecbc948bcdffee705b48dff1ab103 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Tue, 30 Nov 2021 23:16:09 +0100 Subject: [PATCH 34/68] setup.py allow more recent versions of botocore --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 1d61f771..f30e230f 100644 --- a/setup.py +++ b/setup.py @@ -51,7 +51,7 @@ ] install_appsyncwebsockets_requires = [ - "botocore>=1.21,<1.22", + "botocore>=1.21,<2", ] + install_websockets_requires install_all_requires = ( From 590239d8a3b99ec9cf0d4d1a26b26ab4402d7f0a Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Wed, 1 Dec 2021 13:34:41 +0100 Subject: [PATCH 35/68] IAM auth: fix host instead of url and headers --- gql/transport/awsappsync.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/gql/transport/awsappsync.py b/gql/transport/awsappsync.py index 096da2c6..f324b8b1 100644 --- a/gql/transport/awsappsync.py +++ b/gql/transport/awsappsync.py @@ -2,8 +2,10 @@ import logging from abc import ABC, abstractmethod from base64 import b64encode +from datetime import datetime from ssl import SSLContext -from typing import Any, Callable, Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union +from urllib.parse import urlparse import botocore.session from botocore.auth import SigV4Auth @@ -93,14 +95,8 @@ def __init__( request_creator if request_creator else create_request_object ) - def get_headers( - self, - data: Optional[Dict] = None, - request_creator: Callable[[Dict], AWSRequest] = None, - ) -> Dict: + def get_headers(self, data: Optional[Dict] = None,) -> Dict: - """ - Should we add other data in the headers field ? utc_now = datetime.utcnow() amz_date = utc_now.strftime("%Y%m%dT%H%M%SZ") @@ -111,8 +107,6 @@ def get_headers( "host": self._host, "x-amz-date": amz_date, } - """ - headers: Dict[str, Any] = {} request: AWSRequest = self._request_creator( { @@ -149,7 +143,11 @@ def __init__( ) -> None: try: if not authorization: - authorization = AppSyncIAMAuthorization(host=url, session=session) + + # Extract host from url + host = str(urlparse(url).netloc) + + authorization = AppSyncIAMAuthorization(host=host, session=session) self.authorization = authorization From 9ce4fa6deb65241dd4790f4603fa10202b7d3a44 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Wed, 1 Dec 2021 15:18:59 +0100 Subject: [PATCH 36/68] Fix IAM auth --- gql/transport/awsappsync.py | 49 +++++++++++++++++++++++-------------- 1 file changed, 30 insertions(+), 19 deletions(-) diff --git a/gql/transport/awsappsync.py b/gql/transport/awsappsync.py index f324b8b1..f68f7610 100644 --- a/gql/transport/awsappsync.py +++ b/gql/transport/awsappsync.py @@ -2,7 +2,6 @@ import logging from abc import ABC, abstractmethod from base64 import b64encode -from datetime import datetime from ssl import SSLContext from typing import Any, Dict, Optional, Tuple, Union from urllib.parse import urlparse @@ -39,7 +38,7 @@ def get_auth_url(self, url: str) -> str: return f"{url_base}?header={encoded_headers}&payload=e30=" @abstractmethod - def get_headers(self, data: Optional[Dict] = None) -> Dict: + def get_headers(self, data: Optional[str] = None) -> Dict: raise NotImplementedError() @@ -48,7 +47,7 @@ def __init__(self, host: str, api_key: str) -> None: self._host = host self.api_key = api_key - def get_headers(self, data: Optional[Dict] = None) -> Dict: + def get_headers(self, data: Optional[str] = None) -> Dict: return {"host": self._host, "x-api-key": self.api_key} @@ -57,7 +56,7 @@ def __init__(self, host: str, jwt: str) -> None: self._host = host self.jwt = jwt - def get_headers(self, data: Optional[Dict] = None) -> Dict: + def get_headers(self, data: Optional[str] = None) -> Dict: return {"host": self._host, "Authorization": self.jwt} @@ -95,26 +94,21 @@ def __init__( request_creator if request_creator else create_request_object ) - def get_headers(self, data: Optional[Dict] = None,) -> Dict: - - utc_now = datetime.utcnow() - amz_date = utc_now.strftime("%Y%m%dT%H%M%SZ") + def get_headers(self, data: Optional[str] = None,) -> Dict: headers = { "accept": "application/json, text/javascript", "content-encoding": "amz-1.0", "content-type": "application/json; charset=UTF-8", - "host": self._host, - "x-amz-date": amz_date, } request: AWSRequest = self._request_creator( { "method": "POST", - "url": self._host, + "url": f"https://{self._host}/graphql{'' if data else '/connect'}", "headers": headers, "context": {}, - "body": data, + "body": data or "{}", } ) @@ -122,7 +116,15 @@ def get_headers(self, data: Optional[Dict] = None,) -> Dict: headers = dict(request.headers) - log.debug(f"\n\nSigned headers: {headers}\n\n") + headers["host"] = self._host + + if log.isEnabledFor(logging.DEBUG): + headers_log = [] + headers_log.append("\n\nSigned headers:") + for key, value in headers.items(): + headers_log.append(f" {key}: {value}") + headers_log.append("\n") + log.debug("\n".join(headers_log)) return headers @@ -232,26 +234,35 @@ async def _send_query( variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, ) -> int: + query_id = self.next_query_id + self.next_query_id += 1 data: Dict = {"query": print_ast(document)} + if variable_values: data["variables"] = variable_values + if operation_name: data["operationName"] = operation_name + serialized_data = json.dumps(data, separators=(",", ":")) + + payload = {"data": serialized_data} + message: Dict = { "id": str(query_id), "type": "start", - "payload": {"data": json.dumps(data, separators=(",", ":"))}, + "payload": payload, } - if self.authorization: - message["payload"]["extensions"] = { - "authorization": self.authorization.get_headers(data) - } + assert self.authorization is not None + + message["payload"]["extensions"] = { + "authorization": self.authorization.get_headers(serialized_data) + } - await self._send(json.dumps(message, separators=(",", ":"),)) + await self._send(json.dumps(message, separators=(",", ":"),)) return query_id From ad80541b23b2bb00be7f00f9a7a8c17b6a2946eb Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Wed, 1 Dec 2021 20:46:27 +0100 Subject: [PATCH 37/68] fix IAM tests --- tests/test_appsyncwebsocket.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/test_appsyncwebsocket.py b/tests/test_appsyncwebsocket.py index eee5caf5..d437baf0 100644 --- a/tests/test_appsyncwebsocket.py +++ b/tests/test_appsyncwebsocket.py @@ -140,7 +140,11 @@ def test_munge_url(fake_signer_factory, fake_request_factory): url=test_url, authorization=authorization ) - header_string = "eyJGYWtlQXV0aG9yaXphdGlvbiI6ImEiLCJGYWtlVGltZSI6InRvZGF5In0=" + header_string = ( + "eyJGYWtlQXV0aG9yaXphdGlvbiI6ImEiLCJGYWtlVGltZSI6InRvZGF5" + "IiwiaG9zdCI6Imh0dHBzOi8vYXBwc3luYy1hcGkuYXdzLmV4YW1wbGUu" + "b3JnL3NvbWUtb3RoZXItcGFyYW1zIn0=" + ) expected_url = ( "wss://appsync-realtime-api.aws.example.org/" f"some-other-params?header={header_string}&payload=e30=" @@ -166,7 +170,11 @@ def test_munge_url_format( credentials=fake_credentials_factory(), ) - header_string = "eyJGYWtlQXV0aG9yaXphdGlvbiI6ImEiLCJGYWtlVGltZSI6InRvZGF5In0=" + header_string = ( + "eyJGYWtlQXV0aG9yaXphdGlvbiI6ImEiLCJGYWtlVGltZSI6InRvZGF5" + "IiwiaG9zdCI6Imh0dHBzOi8vYXBwc3luYy1hcGkuYXdzLmV4YW1wbGUu" + "b3JnL3NvbWUtb3RoZXItcGFyYW1zIn0=" + ) expected_url = ( "wss://appsync-realtime-api.aws.example.org/" f"some-other-params?header={header_string}&payload=e30=" From dfee6db2ca7f115a41859f32156579084ca58843 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Thu, 2 Dec 2021 10:18:05 +0100 Subject: [PATCH 38/68] add typing to AppSyncIAMAuthorization --- gql/transport/awsappsync.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/gql/transport/awsappsync.py b/gql/transport/awsappsync.py index f68f7610..d3115345 100644 --- a/gql/transport/awsappsync.py +++ b/gql/transport/awsappsync.py @@ -3,12 +3,13 @@ from abc import ABC, abstractmethod from base64 import b64encode from ssl import SSLContext -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Callable, Dict, Optional, Tuple, Union from urllib.parse import urlparse import botocore.session -from botocore.auth import SigV4Auth +from botocore.auth import BaseSigner, SigV4Auth from botocore.awsrequest import AWSRequest, create_request_object +from botocore.credentials import Credentials from botocore.exceptions import NoCredentialsError from botocore.session import get_session from graphql import DocumentNode, ExecutionResult, print_ast @@ -70,11 +71,11 @@ class AppSyncIAMAuthorization(AppSyncAuthorization): def __init__( self, host: str, - region_name=None, - signer=None, - request_creator=None, - credentials=None, - session=None, + region_name: Optional[str] = None, + signer: Optional[BaseSigner] = None, + request_creator: Optional[Callable[[Dict[str, Any]], AWSRequest]] = None, + credentials: Optional[Credentials] = None, + session: Optional[botocore.session.Session] = None, ) -> None: self._host = host self._session = session if session else get_session() From 850ad91f9da89bb59ffdb3d9b50678023ca7473f Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Thu, 2 Dec 2021 10:54:03 +0100 Subject: [PATCH 39/68] Refactor parse_answer to have json.loads happening only once --- gql/transport/awsappsync.py | 8 ++++---- gql/transport/websockets.py | 23 +++++++++++++---------- 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/gql/transport/awsappsync.py b/gql/transport/awsappsync.py index d3115345..07f5dfb2 100644 --- a/gql/transport/awsappsync.py +++ b/gql/transport/awsappsync.py @@ -214,18 +214,18 @@ def _parse_answer( if answer_type == "start_ack": return ("start_ack", None, None) - elif answer_type == "error" and id not in json_answer: + elif answer_type == "error" and "id" not in json_answer: error_payload = json_answer.get("payload") raise TransportServerError(f"Server error: '{error_payload!r}'") else: - return self._parse_answer_apollo(answer) + return self._parse_answer_apollo(json_answer) - except ValueError as e: + except ValueError: raise TransportProtocolError( f"Server did not return a GraphQL result: {answer}" - ) from e + ) return answer_type, answer_id, execution_result diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 29c381d5..bdd485e2 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -379,7 +379,7 @@ async def _send_query( return query_id def _parse_answer_graphqlws( - self, answer: str + self, json_answer: Dict[str, Any] ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: """Parse the answer received from the server if the server supports the graphql-ws protocol. @@ -404,8 +404,6 @@ def _parse_answer_graphqlws( execution_result: Optional[ExecutionResult] = None try: - json_answer = json.loads(answer) - answer_type = str(json_answer.get("type")) if answer_type in ["next", "error", "complete"]: @@ -451,13 +449,13 @@ def _parse_answer_graphqlws( except ValueError as e: raise TransportProtocolError( - f"Server did not return a GraphQL result: {answer}" + f"Server did not return a GraphQL result: {json_answer}" ) from e return answer_type, answer_id, execution_result def _parse_answer_apollo( - self, answer: str + self, json_answer: Dict[str, Any] ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: """Parse the answer received from the server if the server supports the apollo websockets protocol. @@ -474,8 +472,6 @@ def _parse_answer_apollo( execution_result: Optional[ExecutionResult] = None try: - json_answer = json.loads(answer) - answer_type = str(json_answer.get("type")) if answer_type in ["data", "error", "complete"]: @@ -521,7 +517,7 @@ def _parse_answer_apollo( except ValueError as e: raise TransportProtocolError( - f"Server did not return a GraphQL result: {answer}" + f"Server did not return a GraphQL result: {json_answer}" ) from e return answer_type, answer_id, execution_result @@ -532,10 +528,17 @@ def _parse_answer( """Parse the answer received from the server depending on the detected subprotocol. """ + try: + json_answer = json.loads(answer) + except ValueError: + raise TransportProtocolError( + f"Server did not return a GraphQL result: {answer}" + ) + if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: - return self._parse_answer_graphqlws(answer) + return self._parse_answer_graphqlws(json_answer) - return self._parse_answer_apollo(answer) + return self._parse_answer_apollo(json_answer) async def _check_ws_liveness(self) -> None: """Coroutine which will periodically check the liveness of the connection From fe22631c0a883e4ac69a400b29453eb2e627e758 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Thu, 2 Dec 2021 14:34:56 +0100 Subject: [PATCH 40/68] Adding some documentation --- docs/modules/gql.rst | 5 ++ docs/modules/transport.rst | 10 +-- docs/modules/transport_aiohttp.rst | 7 ++ docs/modules/transport_appsync.rst | 7 ++ .../transport_phoenix_channel_websockets.rst | 7 ++ docs/modules/transport_requests.rst | 7 ++ docs/modules/transport_websockets.rst | 7 ++ docs/transports/appsync.rst | 58 ++++++++++++++++ docs/transports/async_transports.rst | 1 + gql/transport/awsappsync.py | 67 ++++++++++++++++++- 10 files changed, 164 insertions(+), 12 deletions(-) create mode 100644 docs/modules/transport_aiohttp.rst create mode 100644 docs/modules/transport_appsync.rst create mode 100644 docs/modules/transport_phoenix_channel_websockets.rst create mode 100644 docs/modules/transport_requests.rst create mode 100644 docs/modules/transport_websockets.rst create mode 100644 docs/transports/appsync.rst diff --git a/docs/modules/gql.rst b/docs/modules/gql.rst index 6730e07b..01cda657 100644 --- a/docs/modules/gql.rst +++ b/docs/modules/gql.rst @@ -20,6 +20,11 @@ Sub-Packages client transport + transport_aiohttp + transport_appsync transport_exceptions + transport_phoenix_channel_websockets + transport_requests + transport_websockets dsl utilities diff --git a/docs/modules/transport.rst b/docs/modules/transport.rst index 1b250d7a..d03dbf1f 100644 --- a/docs/modules/transport.rst +++ b/docs/modules/transport.rst @@ -5,14 +5,6 @@ gql.transport .. autoclass:: gql.transport.transport.Transport -.. autoclass:: gql.transport.local_schema.LocalSchemaTransport - -.. autoclass:: gql.transport.requests.RequestsHTTPTransport - .. autoclass:: gql.transport.async_transport.AsyncTransport -.. autoclass:: gql.transport.aiohttp.AIOHTTPTransport - -.. autoclass:: gql.transport.websockets.WebsocketsTransport - -.. autoclass:: gql.transport.phoenix_channel_websockets.PhoenixChannelWebsocketsTransport +.. autoclass:: gql.transport.local_schema.LocalSchemaTransport diff --git a/docs/modules/transport_aiohttp.rst b/docs/modules/transport_aiohttp.rst new file mode 100644 index 00000000..41cebd99 --- /dev/null +++ b/docs/modules/transport_aiohttp.rst @@ -0,0 +1,7 @@ +gql.transport.aiohttp +===================== + +.. currentmodule:: gql.transport.aiohttp + +.. automodule:: gql.transport.aiohttp + :member-order: bysource diff --git a/docs/modules/transport_appsync.rst b/docs/modules/transport_appsync.rst new file mode 100644 index 00000000..1369ef2c --- /dev/null +++ b/docs/modules/transport_appsync.rst @@ -0,0 +1,7 @@ +gql.transport.awsappsync +======================== + +.. currentmodule:: gql.transport.awsappsync + +.. automodule:: gql.transport.awsappsync + :member-order: bysource diff --git a/docs/modules/transport_phoenix_channel_websockets.rst b/docs/modules/transport_phoenix_channel_websockets.rst new file mode 100644 index 00000000..5f412a33 --- /dev/null +++ b/docs/modules/transport_phoenix_channel_websockets.rst @@ -0,0 +1,7 @@ +gql.transport.phoenix_channel_websockets +======================================== + +.. currentmodule:: gql.transport.phoenix_channel_websockets + +.. automodule:: gql.transport.phoenix_channel_websockets + :member-order: bysource diff --git a/docs/modules/transport_requests.rst b/docs/modules/transport_requests.rst new file mode 100644 index 00000000..78a07a02 --- /dev/null +++ b/docs/modules/transport_requests.rst @@ -0,0 +1,7 @@ +gql.transport.requests +====================== + +.. currentmodule:: gql.transport.requests + +.. automodule:: gql.transport.requests + :member-order: bysource diff --git a/docs/modules/transport_websockets.rst b/docs/modules/transport_websockets.rst new file mode 100644 index 00000000..9a924afd --- /dev/null +++ b/docs/modules/transport_websockets.rst @@ -0,0 +1,7 @@ +gql.transport.websockets +======================== + +.. currentmodule:: gql.transport.websockets + +.. automodule:: gql.transport.websockets + :member-order: bysource diff --git a/docs/transports/appsync.rst b/docs/transports/appsync.rst new file mode 100644 index 00000000..b41c1049 --- /dev/null +++ b/docs/transports/appsync.rst @@ -0,0 +1,58 @@ +.. _appsync_transport: + +AppSyncWebsocketsTransport +========================== + +AWS AppSync allows you to execute GraphQL subscriptions on its realtime GraphQL endpoint. + +See `Building a real-time websocket client`_ for an explanation. + +GQL provides the :code:`AppSyncWebsocketsTransport` transport which implements this +for you to allow you to execute subscriptions. + +.. note:: + It is only possible to execute subscriptions with this transport + +How to use it: + + * choose one :ref:`authentication method ` (API key, IAM, Cognito user pools or OIDC) + * instantiate a :code:`AppSyncWebsocketsTransport` with your GraphQL endpoint as url and your auth method + +.. note:: + It is also possible to instantiate the transport without an auth argument. In that case, + gql will use by default the :class:`IAM auth ` + which will try to authenticate with environment variables or from your aws credentials file. + +Full example with API key authentication from environment variables: + +.. literalinclude:: ../code_examples/aws_api_key_subscription.py + +Reference: :class:`gql.transport.awsappsync.AppSyncWebsocketsTransport` + + +.. _appsync_authentication_methods: + +Authentication methods +---------------------- + +API key +^^^^^^^ + +Reference: :class:`gql.transport.awsappsync.AppSyncApiKeyAuthorization` + +IAM +^^^ + +Reference: :class:`gql.transport.awsappsync.AppSyncIAMAuthorization` + +Amazon Cognito user pools +^^^^^^^^^^^^^^^^^^^^^^^^^ + +Reference: :class:`gql.transport.awsappsync.AppSyncOIDCAuthorization` + +OpenID Connect (OIDC) +^^^^^^^^^^^^^^^^^^^^^ + +Reference: :class:`gql.transport.awsappsync.AppSyncCognitoUserPoolAuthorization` + +.. _Building a real-time websocket client: https://docs.aws.amazon.com/appsync/latest/devguide/real-time-websocket-client.html diff --git a/docs/transports/async_transports.rst b/docs/transports/async_transports.rst index 9fb1b017..df8c23cf 100644 --- a/docs/transports/async_transports.rst +++ b/docs/transports/async_transports.rst @@ -12,3 +12,4 @@ Async transports are transports which are using an underlying async library. The aiohttp websockets phoenix + appsync diff --git a/gql/transport/awsappsync.py b/gql/transport/awsappsync.py index 07f5dfb2..055cf446 100644 --- a/gql/transport/awsappsync.py +++ b/gql/transport/awsappsync.py @@ -21,10 +21,16 @@ class AppSyncAuthorization(ABC): + """AWS authorization abstract base class + + All AWS authorization class should have a + :meth:`get_headers ` + method which defines the headers used in the authentication process.""" + def get_auth_url(self, url: str) -> str: """ - :return: a url used to establish websocket connections - to the appsync-realtime-api + :return: a url with base64 encoded headers used to establish + a websocket connection to the appsync-realtime-api. """ headers = self.get_headers() @@ -44,7 +50,14 @@ def get_headers(self, data: Optional[str] = None) -> Dict: class AppSyncApiKeyAuthorization(AppSyncAuthorization): + """AWS authorization class using an API key""" + def __init__(self, host: str, api_key: str) -> None: + """ + :param host: the host, something like: + XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com + :param api_key: the API key + """ self._host = host self.api_key = api_key @@ -53,7 +66,14 @@ def get_headers(self, data: Optional[str] = None) -> Dict: class AppSyncOIDCAuthorization(AppSyncAuthorization): + """AWS authorization class using an OpenID JWT access token""" + def __init__(self, host: str, jwt: str) -> None: + """ + :param host: the host, something like: + XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com + :param jwt: the JWT Access Token + """ self._host = host self.jwt = jwt @@ -62,12 +82,24 @@ def get_headers(self, data: Optional[str] = None) -> Dict: class AppSyncCognitoUserPoolAuthorization(AppSyncOIDCAuthorization): - """Alias for AppSyncOIDCAuthorization""" + """AWS authorization class using a Cognito user pools JWT access token""" pass class AppSyncIAMAuthorization(AppSyncAuthorization): + """AWS authorization class using IAM. + + .. note:: + There is no need for you to use this class directly, you could instead + intantiate the :class:`gql.transport.awsappsync.AppSyncWebsocketsTransport` + without an auth argument. + + During initialization, this class will use botocore to attempt to + find your IAM credentials, either from environment variables or + from your AWS credentials file. + """ + def __init__( self, host: str, @@ -77,6 +109,11 @@ def __init__( credentials: Optional[Credentials] = None, session: Optional[botocore.session.Session] = None, ) -> None: + """Initialize itself, saving the found credentials used + to sign the headers later. + + if no credentials are found, then a NoCredentialsError is raised. + """ self._host = host self._session = session if session else get_session() self._credentials = ( @@ -131,6 +168,13 @@ def get_headers(self, data: Optional[str] = None,) -> Dict: class AppSyncWebsocketsTransport(WebsocketsTransport): + """:ref:`Async Transport ` used to execute GraphQL subscription on + AWS appsync realtime endpoint. + + This transport uses asyncio and the websockets library in order to send requests + on a websocket connection. + """ + authorization: Optional[AppSyncAuthorization] def __init__( @@ -144,6 +188,23 @@ def __init__( ack_timeout: int = 10, connect_args: Dict[str, Any] = {}, ) -> None: + """Initialize the transport with the given parameters. + + :param url: The GraphQL endpoint URL. Example: + https://XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com/graphql + :param authorization: Optional AWS authorization class which will provide the + necessary headers to be correctly authenticated. If this + argument is not provided, then we will try to authenticate + using IAM. + :param ssl: ssl_context of the connection. + :param connect_timeout: Timeout in seconds for the establishment + of the websocket connection. If None is provided this will wait forever. + :param close_timeout: Timeout in seconds for the close. If None is provided + this will wait forever. + :param ack_timeout: Timeout in seconds to wait for the connection_ack message + from the server. If None is provided this will wait forever. + :param connect_args: Other parameters forwarded to websockets.connect + """ try: if not authorization: From 413c6512d4c6348f3befb1c660db7ef06dd535bb Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Thu, 2 Dec 2021 14:39:31 +0100 Subject: [PATCH 41/68] Rename awsappsync.py to appsync.py --- docs/code_examples/aws_api_key_mutation.py | 2 +- docs/code_examples/aws_api_key_subscription.py | 5 +---- docs/code_examples/aws_iam_subscription.py | 2 +- docs/modules/transport_appsync.rst | 8 ++++---- docs/transports/appsync.rst | 12 ++++++------ gql/transport/{awsappsync.py => appsync.py} | 4 ++-- tests/test_appsyncwebsocket.py | 18 +++++++++--------- 7 files changed, 24 insertions(+), 27 deletions(-) rename gql/transport/{awsappsync.py => appsync.py} (98%) diff --git a/docs/code_examples/aws_api_key_mutation.py b/docs/code_examples/aws_api_key_mutation.py index 7623eecf..9fa5f3cc 100644 --- a/docs/code_examples/aws_api_key_mutation.py +++ b/docs/code_examples/aws_api_key_mutation.py @@ -6,7 +6,7 @@ from gql import Client, gql from gql.transport.aiohttp import AIOHTTPTransport -from gql.transport.awsappsync import AppSyncApiKeyAuthorization +from gql.transport.appsync import AppSyncApiKeyAuthorization logging.basicConfig(level=logging.DEBUG) diff --git a/docs/code_examples/aws_api_key_subscription.py b/docs/code_examples/aws_api_key_subscription.py index a9aa7b12..3e994911 100644 --- a/docs/code_examples/aws_api_key_subscription.py +++ b/docs/code_examples/aws_api_key_subscription.py @@ -5,10 +5,7 @@ from urllib.parse import urlparse from gql import Client, gql -from gql.transport.awsappsync import ( - AppSyncApiKeyAuthorization, - AppSyncWebsocketsTransport, -) +from gql.transport.appsync import AppSyncApiKeyAuthorization, AppSyncWebsocketsTransport logging.basicConfig(level=logging.DEBUG) diff --git a/docs/code_examples/aws_iam_subscription.py b/docs/code_examples/aws_iam_subscription.py index f0423d1c..e891fc44 100644 --- a/docs/code_examples/aws_iam_subscription.py +++ b/docs/code_examples/aws_iam_subscription.py @@ -4,7 +4,7 @@ import sys from gql import Client, gql -from gql.transport.awsappsync import AppSyncWebsocketsTransport +from gql.transport.appsync import AppSyncWebsocketsTransport logging.basicConfig(level=logging.DEBUG) diff --git a/docs/modules/transport_appsync.rst b/docs/modules/transport_appsync.rst index 1369ef2c..f7360088 100644 --- a/docs/modules/transport_appsync.rst +++ b/docs/modules/transport_appsync.rst @@ -1,7 +1,7 @@ -gql.transport.awsappsync -======================== +gql.transport.appsync +===================== -.. currentmodule:: gql.transport.awsappsync +.. currentmodule:: gql.transport.appsync -.. automodule:: gql.transport.awsappsync +.. automodule:: gql.transport.appsync :member-order: bysource diff --git a/docs/transports/appsync.rst b/docs/transports/appsync.rst index b41c1049..41e15a35 100644 --- a/docs/transports/appsync.rst +++ b/docs/transports/appsync.rst @@ -20,14 +20,14 @@ How to use it: .. note:: It is also possible to instantiate the transport without an auth argument. In that case, - gql will use by default the :class:`IAM auth ` + gql will use by default the :class:`IAM auth ` which will try to authenticate with environment variables or from your aws credentials file. Full example with API key authentication from environment variables: .. literalinclude:: ../code_examples/aws_api_key_subscription.py -Reference: :class:`gql.transport.awsappsync.AppSyncWebsocketsTransport` +Reference: :class:`gql.transport.appsync.AppSyncWebsocketsTransport` .. _appsync_authentication_methods: @@ -38,21 +38,21 @@ Authentication methods API key ^^^^^^^ -Reference: :class:`gql.transport.awsappsync.AppSyncApiKeyAuthorization` +Reference: :class:`gql.transport.appsync.AppSyncApiKeyAuthorization` IAM ^^^ -Reference: :class:`gql.transport.awsappsync.AppSyncIAMAuthorization` +Reference: :class:`gql.transport.appsync.AppSyncIAMAuthorization` Amazon Cognito user pools ^^^^^^^^^^^^^^^^^^^^^^^^^ -Reference: :class:`gql.transport.awsappsync.AppSyncOIDCAuthorization` +Reference: :class:`gql.transport.appsync.AppSyncOIDCAuthorization` OpenID Connect (OIDC) ^^^^^^^^^^^^^^^^^^^^^ -Reference: :class:`gql.transport.awsappsync.AppSyncCognitoUserPoolAuthorization` +Reference: :class:`gql.transport.appsync.AppSyncCognitoUserPoolAuthorization` .. _Building a real-time websocket client: https://docs.aws.amazon.com/appsync/latest/devguide/real-time-websocket-client.html diff --git a/gql/transport/awsappsync.py b/gql/transport/appsync.py similarity index 98% rename from gql/transport/awsappsync.py rename to gql/transport/appsync.py index 055cf446..1befc928 100644 --- a/gql/transport/awsappsync.py +++ b/gql/transport/appsync.py @@ -24,7 +24,7 @@ class AppSyncAuthorization(ABC): """AWS authorization abstract base class All AWS authorization class should have a - :meth:`get_headers ` + :meth:`get_headers ` method which defines the headers used in the authentication process.""" def get_auth_url(self, url: str) -> str: @@ -92,7 +92,7 @@ class AppSyncIAMAuthorization(AppSyncAuthorization): .. note:: There is no need for you to use this class directly, you could instead - intantiate the :class:`gql.transport.awsappsync.AppSyncWebsocketsTransport` + intantiate the :class:`gql.transport.appsync.AppSyncWebsocketsTransport` without an auth argument. During initialization, this class will use botocore to attempt to diff --git a/tests/test_appsyncwebsocket.py b/tests/test_appsyncwebsocket.py index d437baf0..f5d03079 100644 --- a/tests/test_appsyncwebsocket.py +++ b/tests/test_appsyncwebsocket.py @@ -7,7 +7,7 @@ def test_appsyncwebsocket_init_with_minimal_args(fake_session_factory): - from gql.transport.awsappsync import ( + from gql.transport.appsync import ( AppSyncIAMAuthorization, AppSyncWebsocketsTransport, ) @@ -25,7 +25,7 @@ def test_appsyncwebsocket_init_with_minimal_args(fake_session_factory): def test_appsyncwebsocket_init_with_no_credentials(caplog, fake_session_factory): import botocore.exceptions - from gql.transport.awsappsync import AppSyncWebsocketsTransport + from gql.transport.appsync import AppSyncWebsocketsTransport with pytest.raises(botocore.exceptions.NoCredentialsError): sample_transport = AppSyncWebsocketsTransport( @@ -41,7 +41,7 @@ def test_appsyncwebsocket_init_with_no_credentials(caplog, fake_session_factory) def test_appsyncwebsocket_init_with_oidc_auth(): - from gql.transport.awsappsync import ( + from gql.transport.appsync import ( AppSyncOIDCAuthorization, AppSyncWebsocketsTransport, ) @@ -54,7 +54,7 @@ def test_appsyncwebsocket_init_with_oidc_auth(): def test_appsyncwebsocket_init_with_apikey_auth(): - from gql.transport.awsappsync import ( + from gql.transport.appsync import ( AppSyncApiKeyAuthorization, AppSyncWebsocketsTransport, ) @@ -70,7 +70,7 @@ def test_appsyncwebsocket_init_with_apikey_auth(): def test_appsyncwebsocket_init_with_iam_auth_without_creds(fake_session_factory): import botocore.exceptions - from gql.transport.awsappsync import ( + from gql.transport.appsync import ( AppSyncIAMAuthorization, AppSyncWebsocketsTransport, ) @@ -83,7 +83,7 @@ def test_appsyncwebsocket_init_with_iam_auth_without_creds(fake_session_factory) def test_appsyncwebsocket_init_with_iam_auth_with_creds(fake_credentials_factory): - from gql.transport.awsappsync import ( + from gql.transport.appsync import ( AppSyncIAMAuthorization, AppSyncWebsocketsTransport, ) @@ -102,7 +102,7 @@ def test_appsyncwebsocket_init_with_iam_auth_with_creds(fake_credentials_factory def test_appsyncwebsocket_init_with_iam_auth_and_no_region( caplog, fake_credentials_factory ): - from gql.transport.awsappsync import ( + from gql.transport.appsync import ( AppSyncIAMAuthorization, AppSyncWebsocketsTransport, ) @@ -124,7 +124,7 @@ def test_appsyncwebsocket_init_with_iam_auth_and_no_region( def test_munge_url(fake_signer_factory, fake_request_factory): - from gql.transport.awsappsync import ( + from gql.transport.appsync import ( AppSyncIAMAuthorization, AppSyncWebsocketsTransport, ) @@ -158,7 +158,7 @@ def test_munge_url_format( fake_credentials_factory, fake_session_factory, ): - from gql.transport.awsappsync import AppSyncIAMAuthorization + from gql.transport.appsync import AppSyncIAMAuthorization test_url = "https://appsync-api.aws.example.org/some-other-params" From 9aad6ff5b9bc314fe3b14d0d96a49712827612c4 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Thu, 2 Dec 2021 16:45:45 +0100 Subject: [PATCH 42/68] Split WebsocketsTransport into WebsocketsTransportBase and WebsocketsTransport --- gql/transport/websockets.py | 1360 +++++++++++++++++++---------------- 1 file changed, 749 insertions(+), 611 deletions(-) diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index bdd485e2..ebb777a4 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -2,9 +2,10 @@ import json import logging import warnings +from abc import abstractmethod from contextlib import suppress from ssl import SSLContext -from typing import Any, AsyncGenerator, Dict, Optional, Tuple, Union, cast +from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union, cast import websockets from graphql import DocumentNode, ExecutionResult, print_ast @@ -79,19 +80,14 @@ async def set_exception(self, exception: Exception) -> None: self._closed = True -class WebsocketsTransport(AsyncTransport): - """:ref:`Async Transport ` used to execute GraphQL queries on - remote servers with websocket connection. +class WebsocketsTransportBase(AsyncTransport): + """abstract :ref:`Async Transport ` used to implement + different websockets protocols. This transport uses asyncio and the websockets library in order to send requests on a websocket connection. """ - # This transport supports two subprotocols and will autodetect the - # subprotocol supported on the server - APOLLO_SUBPROTOCOL = cast(Subprotocol, "graphql-ws") - GRAPHQLWS_SUBPROTOCOL = cast(Subprotocol, "graphql-transport-ws") - def __init__( self, url: str, @@ -102,9 +98,6 @@ def __init__( close_timeout: Optional[Union[int, float]] = 10, ack_timeout: Optional[Union[int, float]] = 10, keep_alive_timeout: Optional[Union[int, float]] = None, - ping_interval: Optional[Union[int, float]] = None, - pong_timeout: Optional[Union[int, float]] = None, - answer_pings: bool = True, connect_args: Dict[str, Any] = {}, ) -> None: """Initialize the transport with the given parameters. @@ -121,37 +114,18 @@ def __init__( from the server. If None is provided this will wait forever. :param keep_alive_timeout: Optional Timeout in seconds to receive a sign of liveness from the server. - :param ping_interval: Delay in seconds between pings sent by the client to - the backend for the graphql-ws protocol. None (by default) means that - we don't send pings. - :param pong_timeout: Delay in seconds to receive a pong from the backend - after we sent a ping (only for the graphql-ws protocol). - By default equal to half of the ping_interval. - :param answer_pings: Whether the client answers the pings from the backend - (for the graphql-ws protocol). - By default: True :param connect_args: Other parameters forwarded to websockets.connect """ - log.debug(f"WebsocketTransport url = {url}") self.url: str = url - self.ssl: Union[SSLContext, bool] = ssl self.headers: Optional[HeadersLike] = headers + self.ssl: Union[SSLContext, bool] = ssl self.init_payload: Dict[str, Any] = init_payload self.connect_timeout: Optional[Union[int, float]] = connect_timeout self.close_timeout: Optional[Union[int, float]] = close_timeout self.ack_timeout: Optional[Union[int, float]] = ack_timeout self.keep_alive_timeout: Optional[Union[int, float]] = keep_alive_timeout - self.ping_interval: Optional[Union[int, float]] = ping_interval - self.pong_timeout: Optional[Union[int, float]] - self.answer_pings: bool = answer_pings - - if ping_interval is not None: - if pong_timeout is None: - self.pong_timeout = ping_interval / 2 - else: - self.pong_timeout = pong_timeout self.connect_args = connect_args @@ -161,7 +135,6 @@ def __init__( self.receive_data_task: Optional[asyncio.Future] = None self.check_keep_alive_task: Optional[asyncio.Future] = None - self.send_ping_task: Optional[asyncio.Future] = None self.close_task: Optional[asyncio.Future] = None # We need to set an event loop here if there is none @@ -186,27 +159,51 @@ def __init__( self._next_keep_alive_message: asyncio.Event = asyncio.Event() self._next_keep_alive_message.set() - self.ping_received: asyncio.Event = asyncio.Event() - """ping_received is an asyncio Event which will fire each time - a ping is received with the graphql-ws protocol""" - - self.pong_received: asyncio.Event = asyncio.Event() - """pong_received is an asyncio Event which will fire each time - a pong is received with the graphql-ws protocol""" - self.payloads: Dict[str, Any] = {} """payloads is a dict which will contain the payloads received - with the graphql-ws protocol. - Possible keys are: 'ping', 'pong', 'connection_ack'""" + for example with the graphql-ws protocol: 'ping', 'pong', 'connection_ack'""" self._connecting: bool = False self.close_exception: Optional[Exception] = None - self.supported_subprotocols = [ - self.APOLLO_SUBPROTOCOL, - self.GRAPHQLWS_SUBPROTOCOL, - ] + # The list of supported subprotocols should be defined in the subclass + self.supported_subprotocols: List[Subprotocol] = [] + + async def _initialize(self) -> None: + """Hook to send the initialization messages after the connection + and potentially wait for the backend ack. + """ + pass # pragma: no cover + + async def _stop_listener(self, query_id: int) -> None: + """Hook to stop to listen to a specific query. + Will send a stop message in some subclasses. + """ + pass # pragma: no cover + + async def _after_connect(self): + """Hook to add custom code for subclasses after the connection + has been established. + """ + pass # pragma: no cover + + async def _after_initialize(self): + """Hook to add custom code for subclasses after the initialization + has been done. + """ + pass # pragma: no cover + + async def _close_hook(self): + """Hook to add custom code for subclasses for the connection close + """ + pass # pragma: no cover + + async def _connection_terminate(self): + """Hook to add custom code for subclasses after the initialization + has been done. + """ + pass # pragma: no cover async def _send(self, message: str) -> None: """Send the provided message to the websocket connection and log the message""" @@ -244,767 +241,908 @@ async def _receive(self) -> str: return answer - async def _wait_ack(self) -> None: - """Wait for the connection_ack message. Keep alive messages are ignored""" - - while True: - init_answer = await self._receive() - - answer_type, answer_id, execution_result = self._parse_answer(init_answer) - - if answer_type == "connection_ack": - return - - if answer_type != "ka": - raise TransportProtocolError( - "Websocket server did not return a connection ack" - ) + @abstractmethod + async def _send_query( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + ) -> int: + raise NotImplementedError # pragma: no cover - async def _send_init_message_and_wait_ack(self) -> None: - """Send init message to the provided websocket and wait for the connection ACK. + @abstractmethod + def _parse_answer( + self, answer: str + ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: + raise NotImplementedError # pragma: no cover - If the answer is not a connection_ack message, we will return an Exception. + async def _check_ws_liveness(self) -> None: + """Coroutine which will periodically check the liveness of the connection + through keep-alive messages """ - init_message = json.dumps( - {"type": "connection_init", "payload": self.init_payload} - ) + try: + while True: + await asyncio.wait_for( + self._next_keep_alive_message.wait(), self.keep_alive_timeout + ) - await self._send(init_message) + # Reset for the next iteration + self._next_keep_alive_message.clear() - # Wait for the connection_ack message or raise a TimeoutError - await asyncio.wait_for(self._wait_ack(), self.ack_timeout) + except asyncio.TimeoutError: + # No keep-alive message in the appriopriate interval, close with error + # while trying to notify the server of a proper close (in case + # the keep-alive interval of the client or server was not aligned + # the connection still remains) - async def send_ping(self, payload: Optional[Any] = None) -> None: - """Send a ping message for the graphql-ws protocol - """ + # If the timeout happens during a close already in progress, do nothing + if self.close_task is None: + await self._fail( + TransportServerError( + "No keep-alive message has been received within " + "the expected interval ('keep_alive_timeout' parameter)" + ), + clean_close=False, + ) - ping_message = {"type": "ping"} + except asyncio.CancelledError: + # The client is probably closing, handle it properly + pass - if payload is not None: - ping_message["payload"] = payload + async def _receive_data_loop(self) -> None: + """Main asyncio task which will listen to the incoming messages and will + call the parse_answer and handle_answer methods of the subclass.""" + try: + while True: - await self._send(json.dumps(ping_message)) + # Wait the next answer from the websocket server + try: + answer = await self._receive() + except (ConnectionClosed, TransportProtocolError) as e: + await self._fail(e, clean_close=False) + break + except TransportClosed: + break - async def send_pong(self, payload: Optional[Any] = None) -> None: - """Send a pong message for the graphql-ws protocol - """ + # Parse the answer + try: + answer_type, answer_id, execution_result = self._parse_answer( + answer + ) + except TransportQueryError as e: + # Received an exception for a specific query + # ==> Add an exception to this query queue + # The exception is raised for this specific query, + # but the transport is not closed. + assert isinstance( + e.query_id, int + ), "TransportQueryError should have a query_id defined here" + try: + await self.listeners[e.query_id].set_exception(e) + except KeyError: + # Do nothing if no one is listening to this query_id + pass - pong_message = {"type": "pong"} + continue - if payload is not None: - pong_message["payload"] = payload + except (TransportServerError, TransportProtocolError) as e: + # Received a global exception for this transport + # ==> close the transport + # The exception will be raised for all current queries. + await self._fail(e, clean_close=False) + break - await self._send(json.dumps(pong_message)) + await self._handle_answer(answer_type, answer_id, execution_result) - async def _send_stop_message(self, query_id: int) -> None: - """Send stop message to the provided websocket connection and query_id. + finally: + log.debug("Exiting _receive_data_loop()") - The server should afterwards return a 'complete' message. - """ + async def _handle_answer( + self, + answer_type: str, + answer_id: Optional[int], + execution_result: Optional[ExecutionResult], + ) -> None: - stop_message = json.dumps({"id": str(query_id), "type": "stop"}) + try: + # Put the answer in the queue + if answer_id is not None: + await self.listeners[answer_id].put((answer_type, execution_result)) + except KeyError: + # Do nothing if no one is listening to this query_id. + pass - await self._send(stop_message) + async def subscribe( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + send_stop: Optional[bool] = True, + ) -> AsyncGenerator[ExecutionResult, None]: + """Send a query and receive the results using a python async generator. - async def _send_complete_message(self, query_id: int) -> None: - """Send a complete message for the provided query_id. + The query can be a graphql query, mutation or subscription. - This is only for the graphql-ws protocol. + The results are sent as an ExecutionResult object. """ - complete_message = json.dumps({"id": str(query_id), "type": "complete"}) + # Send the query and receive the id + query_id: int = await self._send_query( + document, variable_values, operation_name + ) - await self._send(complete_message) + # Create a queue to receive the answers for this query_id + listener = ListenerQueue(query_id, send_stop=(send_stop is True)) + self.listeners[query_id] = listener - async def _stop_listener(self, query_id: int) -> None: - """Stop the listener corresponding to the query_id depending on the - detected backend protocol. + # We will need to wait at close for this query to clean properly + self._no_more_listeners.clear() - For apollo: send a "stop" message - (a "complete" message will be sent from the backend) + try: + # Loop over the received answers + while True: - For graphql-ws: send a "complete" message and simulate the reception - of a "complete" message from the backend - """ - if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: - await self._send_complete_message(query_id) - await self.listeners[query_id].put(("complete", None)) - else: - await self._send_stop_message(query_id) + # Wait for the answer from the queue of this query_id + # This can raise a TransportError or ConnectionClosed exception. + answer_type, execution_result = await listener.get() - async def _send_connection_terminate_message(self) -> None: - """Send a connection_terminate message to the provided websocket connection. + # If the received answer contains data, + # Then we will yield the results back as an ExecutionResult object + if execution_result is not None: + yield execution_result - This message indicates that the connection will disconnect. - """ + # If we receive a 'complete' answer from the server, + # Then we will end this async generator output without errors + elif answer_type == "complete": + log.debug( + f"Complete received for query {query_id} --> exit without error" + ) + break - connection_terminate_message = json.dumps({"type": "connection_terminate"}) + except (asyncio.CancelledError, GeneratorExit) as e: + log.debug(f"Exception in subscribe: {e!r}") + if listener.send_stop: + await self._stop_listener(query_id) + listener.send_stop = False - await self._send(connection_terminate_message) + finally: + log.debug(f"In subscribe finally for query_id {query_id}") + self._remove_listener(query_id) - async def _send_query( + async def execute( self, document: DocumentNode, variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, - ) -> int: - """Send a query to the provided websocket connection. + ) -> ExecutionResult: + """Execute the provided document AST against the configured remote server + using the current session. - We use an incremented id to reference the query. + Send a query but close the async generator as soon as we have the first answer. - Returns the used id for this query. + The result is sent as an ExecutionResult object. """ + first_result = None - query_id = self.next_query_id - self.next_query_id += 1 - - payload: Dict[str, Any] = {"query": print_ast(document)} - if variable_values: - payload["variables"] = variable_values - if operation_name: - payload["operationName"] = operation_name + generator = self.subscribe( + document, variable_values, operation_name, send_stop=False + ) - query_type = "start" + async for result in generator: + first_result = result - if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: - query_type = "subscribe" + # Note: we need to run generator.aclose() here or the finally block in + # the subscribe will not be reached in pypy3 (python version 3.6.1) + await generator.aclose() - query_str = json.dumps( - {"id": str(query_id), "type": query_type, "payload": payload} - ) + break - await self._send(query_str) + if first_result is None: + raise TransportQueryError( + "Query completed without any answer received from the server" + ) - return query_id + return first_result - def _parse_answer_graphqlws( - self, json_answer: Dict[str, Any] - ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: - """Parse the answer received from the server if the server supports the - graphql-ws protocol. + async def connect(self) -> None: + """Coroutine which will: - Returns a list consisting of: - - the answer_type (between: - 'connection_ack', 'ping', 'pong', 'data', 'error', 'complete') - - the answer id (Integer) if received or None - - an execution Result if the answer_type is 'data' or None + - connect to the websocket address + - send the init message + - wait for the connection acknowledge from the server + - create an asyncio task which will be used to receive + and parse the websocket answers - Differences with the apollo websockets protocol (superclass): - - the "data" message is now called "next" - - the "stop" message is now called "complete" - - there is no connection_terminate or connection_error messages - - instead of a unidirectional keep-alive (ka) message from server to client, - there is now the possibility to send bidirectional ping/pong messages - - connection_ack has an optional payload + Should be cleaned with a call to the close coroutine """ - answer_type: str = "" - answer_id: Optional[int] = None - execution_result: Optional[ExecutionResult] = None + log.debug("connect: starting") - try: - answer_type = str(json_answer.get("type")) + if self.websocket is None and not self._connecting: - if answer_type in ["next", "error", "complete"]: - answer_id = int(str(json_answer.get("id"))) + # Set connecting to True to avoid a race condition if user is trying + # to connect twice using the same client at the same time + self._connecting = True - if answer_type == "next" or answer_type == "error": + # If the ssl parameter is not provided, + # generate the ssl value depending on the url + ssl: Optional[Union[SSLContext, bool]] + if self.ssl: + ssl = self.ssl + else: + ssl = True if self.url.startswith("wss") else None - payload = json_answer.get("payload") + # Set default arguments used in the websockets.connect call + connect_args: Dict[str, Any] = { + "ssl": ssl, + "extra_headers": self.headers, + "subprotocols": self.supported_subprotocols, + } - if not isinstance(payload, dict): - raise ValueError("payload is not a dict") + # Adding custom parameters passed from init + connect_args.update(self.connect_args) - if answer_type == "next": + # Connection to the specified url + # Generate a TimeoutError if taking more than connect_timeout seconds + # Set the _connecting flag to False after in all cases + try: + self.websocket = await asyncio.wait_for( + websockets.client.connect(self.url, **connect_args), + self.connect_timeout, + ) + finally: + self._connecting = False - if "errors" not in payload and "data" not in payload: - raise ValueError( - "payload does not contain 'data' or 'errors' fields" - ) + self.websocket = cast(WebSocketClientProtocol, self.websocket) - execution_result = ExecutionResult( - errors=payload.get("errors"), - data=payload.get("data"), - extensions=payload.get("extensions"), - ) + # Run the after_connect hook of the subclass + await self._after_connect() - # Saving answer_type as 'data' to be understood with superclass - answer_type = "data" + self.next_query_id = 1 + self.close_exception = None + self._wait_closed.clear() - elif answer_type == "error": + # Send the init message and wait for the ack from the server + # Note: This should generate a TimeoutError + # if no ACKs are received within the ack_timeout + try: + await self._initialize() + except ConnectionClosed as e: + raise e + except (TransportProtocolError, asyncio.TimeoutError) as e: + await self._fail(e, clean_close=False) + raise e - raise TransportQueryError( - str(payload), query_id=answer_id, errors=[payload] - ) + # Run the after_init hook of the subclass + await self._after_initialize() - elif answer_type in ["ping", "pong", "connection_ack"]: - self.payloads[answer_type] = json_answer.get("payload", None) + # If specified, create a task to check liveness of the connection + # through keep-alive messages + if self.keep_alive_timeout is not None: + self.check_keep_alive_task = asyncio.ensure_future( + self._check_ws_liveness() + ) - else: - raise ValueError + # Create a task to listen to the incoming websocket messages + self.receive_data_task = asyncio.ensure_future(self._receive_data_loop()) - if self.check_keep_alive_task is not None: - self._next_keep_alive_message.set() + else: + raise TransportAlreadyConnected("Transport is already connected") - except ValueError as e: - raise TransportProtocolError( - f"Server did not return a GraphQL result: {json_answer}" - ) from e + log.debug("connect: done") - return answer_type, answer_id, execution_result + def _remove_listener(self, query_id) -> None: + """After exiting from a subscription, remove the listener and + signal an event if this was the last listener for the client. + """ + if query_id in self.listeners: + del self.listeners[query_id] - def _parse_answer_apollo( - self, json_answer: Dict[str, Any] - ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: - """Parse the answer received from the server if the server supports the - apollo websockets protocol. + remaining = len(self.listeners) + log.debug(f"listener {query_id} deleted, {remaining} remaining") - Returns a list consisting of: - - the answer_type (between: - 'connection_ack', 'ka', 'connection_error', 'data', 'error', 'complete') - - the answer id (Integer) if received or None - - an execution Result if the answer_type is 'data' or None - """ + if remaining == 0: + self._no_more_listeners.set() - answer_type: str = "" - answer_id: Optional[int] = None - execution_result: Optional[ExecutionResult] = None + async def _clean_close(self, e: Exception) -> None: + """Coroutine which will: - try: - answer_type = str(json_answer.get("type")) + - send stop messages for each active subscription to the server + - send the connection terminate message + """ - if answer_type in ["data", "error", "complete"]: - answer_id = int(str(json_answer.get("id"))) + # Send 'stop' message for all current queries + for query_id, listener in self.listeners.items(): - if answer_type == "data" or answer_type == "error": + if listener.send_stop: + await self._stop_listener(query_id) + listener.send_stop = False - payload = json_answer.get("payload") + # Wait that there is no more listeners (we received 'complete' for all queries) + try: + await asyncio.wait_for(self._no_more_listeners.wait(), self.close_timeout) + except asyncio.TimeoutError: # pragma: no cover + log.debug("Timer close_timeout fired") - if not isinstance(payload, dict): - raise ValueError("payload is not a dict") + # Calling the subclass hook + await self._connection_terminate() - if answer_type == "data": + async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: + """Coroutine which will: - if "errors" not in payload and "data" not in payload: - raise ValueError( - "payload does not contain 'data' or 'errors' fields" - ) + - do a clean_close if possible: + - send stop messages for each active query to the server + - send the connection terminate message + - close the websocket connection + - send the exception to all the remaining listeners + """ - execution_result = ExecutionResult( - errors=payload.get("errors"), - data=payload.get("data"), - extensions=payload.get("extensions"), - ) + log.debug("_close_coro: starting") - elif answer_type == "error": + try: - raise TransportQueryError( - str(payload), query_id=answer_id, errors=[payload] - ) + # We should always have an active websocket connection here + assert self.websocket is not None - elif answer_type == "ka": - # Keep-alive message - if self.check_keep_alive_task is not None: - self._next_keep_alive_message.set() - elif answer_type == "connection_ack": - pass - elif answer_type == "connection_error": - error_payload = json_answer.get("payload") - raise TransportServerError(f"Server error: '{repr(error_payload)}'") - else: - raise ValueError + # Properly shut down liveness checker if enabled + if self.check_keep_alive_task is not None: + # More info: https://stackoverflow.com/a/43810272/1113207 + self.check_keep_alive_task.cancel() + with suppress(asyncio.CancelledError): + await self.check_keep_alive_task - except ValueError as e: - raise TransportProtocolError( - f"Server did not return a GraphQL result: {json_answer}" - ) from e + # Calling the subclass close hook + await self._close_hook() - return answer_type, answer_id, execution_result + # Saving exception to raise it later if trying to use the transport + # after it has already closed. + self.close_exception = e - def _parse_answer( - self, answer: str - ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: - """Parse the answer received from the server depending on - the detected subprotocol. - """ - try: - json_answer = json.loads(answer) - except ValueError: - raise TransportProtocolError( - f"Server did not return a GraphQL result: {answer}" - ) + if clean_close: + log.debug("_close_coro: starting clean_close") + try: + await self._clean_close(e) + except Exception as exc: # pragma: no cover + log.warning("Ignoring exception in _clean_close: " + repr(exc)) - if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: - return self._parse_answer_graphqlws(json_answer) + log.debug("_close_coro: sending exception to listeners") - return self._parse_answer_apollo(json_answer) + # Send an exception to all remaining listeners + for query_id, listener in self.listeners.items(): + await listener.set_exception(e) - async def _check_ws_liveness(self) -> None: - """Coroutine which will periodically check the liveness of the connection - through keep-alive messages - """ + log.debug("_close_coro: close websocket connection") - try: - while True: - await asyncio.wait_for( - self._next_keep_alive_message.wait(), self.keep_alive_timeout - ) + await self.websocket.close() - # Reset for the next iteration - self._next_keep_alive_message.clear() + log.debug("_close_coro: websocket connection closed") - except asyncio.TimeoutError: - # No keep-alive message in the appriopriate interval, close with error - # while trying to notify the server of a proper close (in case - # the keep-alive interval of the client or server was not aligned - # the connection still remains) + except Exception as exc: # pragma: no cover + log.warning("Exception catched in _close_coro: " + repr(exc)) - # If the timeout happens during a close already in progress, do nothing - if self.close_task is None: - await self._fail( - TransportServerError( - "No keep-alive message has been received within " - "the expected interval ('keep_alive_timeout' parameter)" - ), - clean_close=False, + finally: + + log.debug("_close_coro: start cleanup") + + self.websocket = None + self.close_task = None + self.check_keep_alive_task = None + self._wait_closed.set() + + log.debug("_close_coro: exiting") + + async def _fail(self, e: Exception, clean_close: bool = True) -> None: + log.debug("_fail: starting with exception: " + repr(e)) + + if self.close_task is None: + + if self.websocket is None: + log.debug("_fail started with self.websocket == None -> already closed") + else: + self.close_task = asyncio.shield( + asyncio.ensure_future(self._close_coro(e, clean_close=clean_close)) ) + else: + log.debug( + "close_task is not None in _fail. Previous exception is: " + + repr(self.close_exception) + + " New exception is: " + + repr(e) + ) - except asyncio.CancelledError: - # The client is probably closing, handle it properly - pass + async def close(self) -> None: + log.debug("close: starting") - async def _send_ping_coro(self) -> None: - """Coroutine to periodically send a ping from the client to the backend. + await self._fail(TransportClosed("Websocket GraphQL transport closed by user")) + await self.wait_closed() - Only used for the graphql-ws protocol. + log.debug("close: done") - Send a ping every ping_interval seconds. - Close the connection if a pong is not received within pong_timeout seconds. + async def wait_closed(self) -> None: + log.debug("wait_close: starting") + + await self._wait_closed.wait() + + log.debug("wait_close: done") + + +class WebsocketsTransport(WebsocketsTransportBase): + """:ref:`Async Transport ` used to execute GraphQL queries on + remote servers with websocket connection. + + This transport uses asyncio and the websockets library in order to send requests + on a websocket connection. + """ + + # This transport supports two subprotocols and will autodetect the + # subprotocol supported on the server + APOLLO_SUBPROTOCOL = cast(Subprotocol, "graphql-ws") + GRAPHQLWS_SUBPROTOCOL = cast(Subprotocol, "graphql-transport-ws") + + def __init__( + self, + url: str, + headers: Optional[HeadersLike] = None, + ssl: Union[SSLContext, bool] = False, + init_payload: Dict[str, Any] = {}, + connect_timeout: Optional[Union[int, float]] = 10, + close_timeout: Optional[Union[int, float]] = 10, + ack_timeout: Optional[Union[int, float]] = 10, + keep_alive_timeout: Optional[Union[int, float]] = None, + ping_interval: Optional[Union[int, float]] = None, + pong_timeout: Optional[Union[int, float]] = None, + answer_pings: bool = True, + connect_args: Dict[str, Any] = {}, + ) -> None: + """Initialize the transport with the given parameters. + + :param url: The GraphQL server URL. Example: 'wss://server.com:PORT/graphql'. + :param headers: Dict of HTTP Headers. + :param ssl: ssl_context of the connection. Use ssl=False to disable encryption + :param init_payload: Dict of the payload sent in the connection_init message. + :param connect_timeout: Timeout in seconds for the establishment + of the websocket connection. If None is provided this will wait forever. + :param close_timeout: Timeout in seconds for the close. If None is provided + this will wait forever. + :param ack_timeout: Timeout in seconds to wait for the connection_ack message + from the server. If None is provided this will wait forever. + :param keep_alive_timeout: Optional Timeout in seconds to receive + a sign of liveness from the server. + :param ping_interval: Delay in seconds between pings sent by the client to + the backend for the graphql-ws protocol. None (by default) means that + we don't send pings. + :param pong_timeout: Delay in seconds to receive a pong from the backend + after we sent a ping (only for the graphql-ws protocol). + By default equal to half of the ping_interval. + :param answer_pings: Whether the client answers the pings from the backend + (for the graphql-ws protocol). + By default: True + :param connect_args: Other parameters forwarded to websockets.connect """ - assert self.ping_interval is not None + super().__init__( + url, + headers, + ssl, + init_payload, + connect_timeout, + close_timeout, + ack_timeout, + keep_alive_timeout, + connect_args, + ) - try: - while True: - await asyncio.sleep(self.ping_interval) + self.ping_interval: Optional[Union[int, float]] = ping_interval + self.pong_timeout: Optional[Union[int, float]] + self.answer_pings: bool = answer_pings - await self.send_ping() + if ping_interval is not None: + if pong_timeout is None: + self.pong_timeout = ping_interval / 2 + else: + self.pong_timeout = pong_timeout - await asyncio.wait_for(self.pong_received.wait(), self.pong_timeout) + self.send_ping_task: Optional[asyncio.Future] = None - # Reset for the next iteration - self.pong_received.clear() + self.ping_received: asyncio.Event = asyncio.Event() + """ping_received is an asyncio Event which will fire each time + a ping is received with the graphql-ws protocol""" - except asyncio.TimeoutError: - # No pong received in the appriopriate time, close with error - # If the timeout happens during a close already in progress, do nothing - if self.close_task is None: - await self._fail( - TransportServerError( - f"No pong received after {self.pong_timeout!r} seconds" - ), - clean_close=False, + self.pong_received: asyncio.Event = asyncio.Event() + """pong_received is an asyncio Event which will fire each time + a pong is received with the graphql-ws protocol""" + + self.supported_subprotocols = [ + self.APOLLO_SUBPROTOCOL, + self.GRAPHQLWS_SUBPROTOCOL, + ] + + async def _wait_ack(self) -> None: + """Wait for the connection_ack message. Keep alive messages are ignored""" + + while True: + init_answer = await self._receive() + + answer_type, answer_id, execution_result = self._parse_answer(init_answer) + + if answer_type == "connection_ack": + return + + if answer_type != "ka": + raise TransportProtocolError( + "Websocket server did not return a connection ack" ) - async def _receive_data_loop(self) -> None: - try: - while True: + async def _send_init_message_and_wait_ack(self) -> None: + """Send init message to the provided websocket and wait for the connection ACK. - # Wait the next answer from the websocket server - try: - answer = await self._receive() - except (ConnectionClosed, TransportProtocolError) as e: - await self._fail(e, clean_close=False) - break - except TransportClosed: - break + If the answer is not a connection_ack message, we will return an Exception. + """ - # Parse the answer - try: - answer_type, answer_id, execution_result = self._parse_answer( - answer - ) - except TransportQueryError as e: - # Received an exception for a specific query - # ==> Add an exception to this query queue - # The exception is raised for this specific query, - # but the transport is not closed. - assert isinstance( - e.query_id, int - ), "TransportQueryError should have a query_id defined here" - try: - await self.listeners[e.query_id].set_exception(e) - except KeyError: - # Do nothing if no one is listening to this query_id - pass + init_message = json.dumps( + {"type": "connection_init", "payload": self.init_payload} + ) - continue + await self._send(init_message) - except (TransportServerError, TransportProtocolError) as e: - # Received a global exception for this transport - # ==> close the transport - # The exception will be raised for all current queries. - await self._fail(e, clean_close=False) - break + # Wait for the connection_ack message or raise a TimeoutError + await asyncio.wait_for(self._wait_ack(), self.ack_timeout) - await self._handle_answer(answer_type, answer_id, execution_result) + async def _initialize(self): + await self._send_init_message_and_wait_ack() - finally: - log.debug("Exiting _receive_data_loop()") + async def send_ping(self, payload: Optional[Any] = None) -> None: + """Send a ping message for the graphql-ws protocol + """ - async def _handle_answer( - self, - answer_type: str, - answer_id: Optional[int], - execution_result: Optional[ExecutionResult], - ) -> None: + ping_message = {"type": "ping"} - try: - # Put the answer in the queue - if answer_id is not None: - await self.listeners[answer_id].put((answer_type, execution_result)) - except KeyError: - # Do nothing if no one is listening to this query_id. - pass + if payload is not None: + ping_message["payload"] = payload - # Answer pong to ping for graphql-ws protocol - if answer_type == "ping": - self.ping_received.set() - if self.answer_pings: - await self.send_pong() + await self._send(json.dumps(ping_message)) - elif answer_type == "pong": - self.pong_received.set() + async def send_pong(self, payload: Optional[Any] = None) -> None: + """Send a pong message for the graphql-ws protocol + """ - async def subscribe( - self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, - send_stop: Optional[bool] = True, - ) -> AsyncGenerator[ExecutionResult, None]: - """Send a query and receive the results using a python async generator. + pong_message = {"type": "pong"} - The query can be a graphql query, mutation or subscription. + if payload is not None: + pong_message["payload"] = payload - The results are sent as an ExecutionResult object. + await self._send(json.dumps(pong_message)) + + async def _send_stop_message(self, query_id: int) -> None: + """Send stop message to the provided websocket connection and query_id. + + The server should afterwards return a 'complete' message. """ - # Send the query and receive the id - query_id: int = await self._send_query( - document, variable_values, operation_name - ) + stop_message = json.dumps({"id": str(query_id), "type": "stop"}) - # Create a queue to receive the answers for this query_id - listener = ListenerQueue(query_id, send_stop=(send_stop is True)) - self.listeners[query_id] = listener + await self._send(stop_message) - # We will need to wait at close for this query to clean properly - self._no_more_listeners.clear() + async def _send_complete_message(self, query_id: int) -> None: + """Send a complete message for the provided query_id. - try: - # Loop over the received answers - while True: + This is only for the graphql-ws protocol. + """ - # Wait for the answer from the queue of this query_id - # This can raise a TransportError or ConnectionClosed exception. - answer_type, execution_result = await listener.get() + complete_message = json.dumps({"id": str(query_id), "type": "complete"}) - # If the received answer contains data, - # Then we will yield the results back as an ExecutionResult object - if execution_result is not None: - yield execution_result + await self._send(complete_message) - # If we receive a 'complete' answer from the server, - # Then we will end this async generator output without errors - elif answer_type == "complete": - log.debug( - f"Complete received for query {query_id} --> exit without error" - ) - break + async def _stop_listener(self, query_id: int) -> None: + """Stop the listener corresponding to the query_id depending on the + detected backend protocol. - except (asyncio.CancelledError, GeneratorExit) as e: - log.debug(f"Exception in subscribe: {e!r}") - if listener.send_stop: - await self._stop_listener(query_id) - listener.send_stop = False + For apollo: send a "stop" message + (a "complete" message will be sent from the backend) - finally: - log.debug(f"In subscribe finally for query_id {query_id}") - self._remove_listener(query_id) + For graphql-ws: send a "complete" message and simulate the reception + of a "complete" message from the backend + """ + log.debug(f"stop listener {query_id}") - async def execute( + if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: + await self._send_complete_message(query_id) + await self.listeners[query_id].put(("complete", None)) + else: + await self._send_stop_message(query_id) + + async def _send_connection_terminate_message(self) -> None: + """Send a connection_terminate message to the provided websocket connection. + + This message indicates that the connection will disconnect. + """ + + connection_terminate_message = json.dumps({"type": "connection_terminate"}) + + await self._send(connection_terminate_message) + + async def _send_query( self, document: DocumentNode, variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, - ) -> ExecutionResult: - """Execute the provided document AST against the configured remote server - using the current session. + ) -> int: + """Send a query to the provided websocket connection. - Send a query but close the async generator as soon as we have the first answer. + We use an incremented id to reference the query. - The result is sent as an ExecutionResult object. + Returns the used id for this query. """ - first_result = None - generator = self.subscribe( - document, variable_values, operation_name, send_stop=False - ) + query_id = self.next_query_id + self.next_query_id += 1 - async for result in generator: - first_result = result + payload: Dict[str, Any] = {"query": print_ast(document)} + if variable_values: + payload["variables"] = variable_values + if operation_name: + payload["operationName"] = operation_name - # Note: we need to run generator.aclose() here or the finally block in - # the subscribe will not be reached in pypy3 (python version 3.6.1) - await generator.aclose() + query_type = "start" - break + if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: + query_type = "subscribe" - if first_result is None: - raise TransportQueryError( - "Query completed without any answer received from the server" - ) + query_str = json.dumps( + {"id": str(query_id), "type": query_type, "payload": payload} + ) - return first_result + await self._send(query_str) - async def connect(self) -> None: - """Coroutine which will: + return query_id - - connect to the websocket address - - send the init message - - wait for the connection acknowledge from the server - - create an asyncio task which will be used to receive - and parse the websocket answers + async def _connection_terminate(self): + if self.subprotocol == self.APOLLO_SUBPROTOCOL: + await self._send_connection_terminate_message() - Should be cleaned with a call to the close coroutine - """ + def _parse_answer_graphqlws( + self, json_answer: Dict[str, Any] + ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: + """Parse the answer received from the server if the server supports the + graphql-ws protocol. - log.debug("connect: starting") + Returns a list consisting of: + - the answer_type (between: + 'connection_ack', 'ping', 'pong', 'data', 'error', 'complete') + - the answer id (Integer) if received or None + - an execution Result if the answer_type is 'data' or None - if self.websocket is None and not self._connecting: + Differences with the apollo websockets protocol (superclass): + - the "data" message is now called "next" + - the "stop" message is now called "complete" + - there is no connection_terminate or connection_error messages + - instead of a unidirectional keep-alive (ka) message from server to client, + there is now the possibility to send bidirectional ping/pong messages + - connection_ack has an optional payload + """ - # Set connecting to True to avoid a race condition if user is trying - # to connect twice using the same client at the same time - self._connecting = True + answer_type: str = "" + answer_id: Optional[int] = None + execution_result: Optional[ExecutionResult] = None - # If the ssl parameter is not provided, - # generate the ssl value depending on the url - ssl: Optional[Union[SSLContext, bool]] - if self.ssl: - ssl = self.ssl - else: - ssl = True if self.url.startswith("wss") else None + try: + answer_type = str(json_answer.get("type")) - # Set default arguments used in the websockets.connect call - connect_args: Dict[str, Any] = { - "ssl": ssl, - "extra_headers": self.headers, - "subprotocols": self.supported_subprotocols, - } + if answer_type in ["next", "error", "complete"]: + answer_id = int(str(json_answer.get("id"))) - # Adding custom parameters passed from init - connect_args.update(self.connect_args) + if answer_type == "next" or answer_type == "error": - # Connection to the specified url - # Generate a TimeoutError if taking more than connect_timeout seconds - # Set the _connecting flag to False after in all cases - try: - self.websocket = await asyncio.wait_for( - websockets.client.connect(self.url, **connect_args), - self.connect_timeout, - ) - finally: - self._connecting = False + payload = json_answer.get("payload") + + if not isinstance(payload, dict): + raise ValueError("payload is not a dict") - self.websocket = cast(WebSocketClientProtocol, self.websocket) + if answer_type == "next": - # Find the backend subprotocol returned in the response headers - response_headers = self.websocket.response_headers - try: - self.subprotocol = response_headers["Sec-WebSocket-Protocol"] - except KeyError: - # If the server does not send the subprotocol header, using - # the apollo subprotocol by default - self.subprotocol = self.APOLLO_SUBPROTOCOL + if "errors" not in payload and "data" not in payload: + raise ValueError( + "payload does not contain 'data' or 'errors' fields" + ) - log.debug(f"backend subprotocol returned: {self.subprotocol!r}") + execution_result = ExecutionResult( + errors=payload.get("errors"), + data=payload.get("data"), + extensions=payload.get("extensions"), + ) - self.next_query_id = 1 - self.close_exception = None - self._wait_closed.clear() + # Saving answer_type as 'data' to be understood with superclass + answer_type = "data" - # Send the init message and wait for the ack from the server - # Note: This will generate a TimeoutError - # if no ACKs are received within the ack_timeout - try: - await self._send_init_message_and_wait_ack() - except ConnectionClosed as e: - raise e - except (TransportProtocolError, asyncio.TimeoutError) as e: - await self._fail(e, clean_close=False) - raise e + elif answer_type == "error": - # If specified, create a task to check liveness of the connection - # through keep-alive messages - if self.keep_alive_timeout is not None: - self.check_keep_alive_task = asyncio.ensure_future( - self._check_ws_liveness() - ) + raise TransportQueryError( + str(payload), query_id=answer_id, errors=[payload] + ) - # If requested, create a task to send periodic pings to the backend - if ( - self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL - and self.ping_interval is not None - ): + elif answer_type in ["ping", "pong", "connection_ack"]: + self.payloads[answer_type] = json_answer.get("payload", None) - self.send_ping_task = asyncio.ensure_future(self._send_ping_coro()) + else: + raise ValueError - # Create a task to listen to the incoming websocket messages - self.receive_data_task = asyncio.ensure_future(self._receive_data_loop()) + if self.check_keep_alive_task is not None: + self._next_keep_alive_message.set() - else: - raise TransportAlreadyConnected("Transport is already connected") + except ValueError as e: + raise TransportProtocolError( + f"Server did not return a GraphQL result: {json_answer}" + ) from e - log.debug("connect: done") + return answer_type, answer_id, execution_result - def _remove_listener(self, query_id) -> None: - """After exiting from a subscription, remove the listener and - signal an event if this was the last listener for the client. + def _parse_answer_apollo( + self, json_answer: Dict[str, Any] + ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: + """Parse the answer received from the server if the server supports the + apollo websockets protocol. + + Returns a list consisting of: + - the answer_type (between: + 'connection_ack', 'ka', 'connection_error', 'data', 'error', 'complete') + - the answer id (Integer) if received or None + - an execution Result if the answer_type is 'data' or None """ - if query_id in self.listeners: - del self.listeners[query_id] - remaining = len(self.listeners) - log.debug(f"listener {query_id} deleted, {remaining} remaining") + answer_type: str = "" + answer_id: Optional[int] = None + execution_result: Optional[ExecutionResult] = None - if remaining == 0: - self._no_more_listeners.set() + try: + answer_type = str(json_answer.get("type")) - async def _clean_close(self, e: Exception) -> None: - """Coroutine which will: + if answer_type in ["data", "error", "complete"]: + answer_id = int(str(json_answer.get("id"))) - - send stop messages for each active subscription to the server - - send the connection terminate message - """ + if answer_type == "data" or answer_type == "error": - # Send 'stop' message for all current queries - for query_id, listener in self.listeners.items(): + payload = json_answer.get("payload") - if listener.send_stop: - await self._stop_listener(query_id) - listener.send_stop = False + if not isinstance(payload, dict): + raise ValueError("payload is not a dict") - # Wait that there is no more listeners (we received 'complete' for all queries) - try: - await asyncio.wait_for(self._no_more_listeners.wait(), self.close_timeout) - except asyncio.TimeoutError: # pragma: no cover - log.debug("Timer close_timeout fired") + if answer_type == "data": - if self.subprotocol == self.APOLLO_SUBPROTOCOL: - # Finally send the 'connection_terminate' message - await self._send_connection_terminate_message() + if "errors" not in payload and "data" not in payload: + raise ValueError( + "payload does not contain 'data' or 'errors' fields" + ) - async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: - """Coroutine which will: + execution_result = ExecutionResult( + errors=payload.get("errors"), + data=payload.get("data"), + extensions=payload.get("extensions"), + ) - - do a clean_close if possible: - - send stop messages for each active query to the server - - send the connection terminate message - - close the websocket connection - - send the exception to all the remaining listeners - """ + elif answer_type == "error": - log.debug("_close_coro: starting") + raise TransportQueryError( + str(payload), query_id=answer_id, errors=[payload] + ) - try: + elif answer_type == "ka": + # Keep-alive message + if self.check_keep_alive_task is not None: + self._next_keep_alive_message.set() + elif answer_type == "connection_ack": + pass + elif answer_type == "connection_error": + error_payload = json_answer.get("payload") + raise TransportServerError(f"Server error: '{repr(error_payload)}'") + else: + raise ValueError - # We should always have an active websocket connection here - assert self.websocket is not None + except ValueError as e: + raise TransportProtocolError( + f"Server did not return a GraphQL result: {json_answer}" + ) from e - # Properly shut down liveness checker if enabled - if self.check_keep_alive_task is not None: - # More info: https://stackoverflow.com/a/43810272/1113207 - self.check_keep_alive_task.cancel() - with suppress(asyncio.CancelledError): - await self.check_keep_alive_task + return answer_type, answer_id, execution_result - # Properly shut down the send ping task if enabled - if self.send_ping_task is not None: - self.send_ping_task.cancel() - with suppress(asyncio.CancelledError): - await self.send_ping_task + def _parse_answer( + self, answer: str + ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: + """Parse the answer received from the server depending on + the detected subprotocol. + """ + try: + json_answer = json.loads(answer) + except ValueError: + raise TransportProtocolError( + f"Server did not return a GraphQL result: {answer}" + ) - # Saving exception to raise it later if trying to use the transport - # after it has already closed. - self.close_exception = e + if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: + return self._parse_answer_graphqlws(json_answer) - if clean_close: - log.debug("_close_coro: starting clean_close") - try: - await self._clean_close(e) - except Exception as exc: # pragma: no cover - log.warning("Ignoring exception in _clean_close: " + repr(exc)) + return self._parse_answer_apollo(json_answer) - log.debug("_close_coro: sending exception to listeners") + async def _send_ping_coro(self) -> None: + """Coroutine to periodically send a ping from the client to the backend. - # Send an exception to all remaining listeners - for query_id, listener in self.listeners.items(): - await listener.set_exception(e) + Only used for the graphql-ws protocol. - log.debug("_close_coro: close websocket connection") + Send a ping every ping_interval seconds. + Close the connection if a pong is not received within pong_timeout seconds. + """ - await self.websocket.close() + assert self.ping_interval is not None - log.debug("_close_coro: websocket connection closed") + try: + while True: + await asyncio.sleep(self.ping_interval) - except Exception as exc: # pragma: no cover - log.warning("Exception catched in _close_coro: " + repr(exc)) + await self.send_ping() - finally: + await asyncio.wait_for(self.pong_received.wait(), self.pong_timeout) - log.debug("_close_coro: start cleanup") + # Reset for the next iteration + self.pong_received.clear() - self.websocket = None - self.close_task = None - self.check_keep_alive_task = None - self.send_ping_task = None + except asyncio.TimeoutError: + # No pong received in the appriopriate time, close with error + # If the timeout happens during a close already in progress, do nothing + if self.close_task is None: + await self._fail( + TransportServerError( + f"No pong received after {self.pong_timeout!r} seconds" + ), + clean_close=False, + ) - self._wait_closed.set() + async def _handle_answer( + self, + answer_type: str, + answer_id: Optional[int], + execution_result: Optional[ExecutionResult], + ) -> None: - log.debug("_close_coro: exiting") + # Put the answer in the queue + await super()._handle_answer(answer_type, answer_id, execution_result) - async def _fail(self, e: Exception, clean_close: bool = True) -> None: - log.debug("_fail: starting with exception: " + repr(e)) + # Answer pong to ping for graphql-ws protocol + if answer_type == "ping": + self.ping_received.set() + if self.answer_pings: + await self.send_pong() - if self.close_task is None: + elif answer_type == "pong": + self.pong_received.set() - if self.websocket is None: - log.debug("_fail started with self.websocket == None -> already closed") - else: - self.close_task = asyncio.shield( - asyncio.ensure_future(self._close_coro(e, clean_close=clean_close)) - ) - else: - log.debug( - "close_task is not None in _fail. Previous exception is: " - + repr(self.close_exception) - + " New exception is: " - + repr(e) - ) + async def _after_connect(self): - async def close(self) -> None: - log.debug("close: starting") + # Find the backend subprotocol returned in the response headers + response_headers = self.websocket.response_headers + try: + self.subprotocol = response_headers["Sec-WebSocket-Protocol"] + except KeyError: + # If the server does not send the subprotocol header, using + # the apollo subprotocol by default + self.subprotocol = self.APOLLO_SUBPROTOCOL - await self._fail(TransportClosed("Websocket GraphQL transport closed by user")) - await self.wait_closed() + log.debug(f"backend subprotocol returned: {self.subprotocol!r}") - log.debug("close: done") + async def _after_initialize(self): - async def wait_closed(self) -> None: - log.debug("wait_close: starting") + # If requested, create a task to send periodic pings to the backend + if ( + self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL + and self.ping_interval is not None + ): - await self._wait_closed.wait() + self.send_ping_task = asyncio.ensure_future(self._send_ping_coro()) - log.debug("wait_close: done") + async def _close_hook(self): + + # Properly shut down the send ping task if enabled + if self.send_ping_task is not None: + self.send_ping_task.cancel() + with suppress(asyncio.CancelledError): + await self.send_ping_task + self.send_ping_task = None From 6b79fe4c0840e3c2cc462bff4d98b7f022411c41 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Thu, 2 Dec 2021 17:19:44 +0100 Subject: [PATCH 43/68] AppSyncWebsocketsTransport now inherits WebsocketsTransportBase --- gql/transport/appsync.py | 24 ++++++++++++++++++++---- gql/transport/websockets.py | 6 +++--- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/gql/transport/appsync.py b/gql/transport/appsync.py index 1befc928..8fac48c4 100644 --- a/gql/transport/appsync.py +++ b/gql/transport/appsync.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from base64 import b64encode from ssl import SSLContext -from typing import Any, Callable, Dict, Optional, Tuple, Union +from typing import Any, Callable, Coroutine, Dict, Optional, Tuple, Union, cast from urllib.parse import urlparse import botocore.session @@ -15,7 +15,7 @@ from graphql import DocumentNode, ExecutionResult, print_ast from .exceptions import TransportProtocolError, TransportServerError -from .websockets import WebsocketsTransport +from .websockets import WebsocketsTransport, WebsocketsTransportBase log = logging.getLogger(__name__) @@ -167,7 +167,7 @@ def get_headers(self, data: Optional[str] = None,) -> Dict: return headers -class AppSyncWebsocketsTransport(WebsocketsTransport): +class AppSyncWebsocketsTransport(WebsocketsTransportBase): """:ref:`Async Transport ` used to execute GraphQL subscription on AWS appsync realtime endpoint. @@ -240,6 +240,12 @@ def __init__( connect_args=connect_args, ) + # Using the same 'graphql-ws' protocol as the apollo protocol + self.supported_subprotocols = [ + WebsocketsTransport.APOLLO_SUBPROTOCOL, + ] + self.subprotocol = WebsocketsTransport.APOLLO_SUBPROTOCOL + def _parse_answer( self, answer: str ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: @@ -281,7 +287,9 @@ def _parse_answer( else: - return self._parse_answer_apollo(json_answer) + return WebsocketsTransport._parse_answer_apollo( + cast(WebsocketsTransport, self), json_answer + ) except ValueError: raise TransportProtocolError( @@ -327,4 +335,12 @@ async def _send_query( await self._send(json.dumps(message, separators=(",", ":"),)) + print(Coroutine) return query_id + + _initialize = WebsocketsTransport._initialize + _stop_listener = WebsocketsTransport._stop_listener + _send_init_message_and_wait_ack = ( + WebsocketsTransport._send_init_message_and_wait_ack + ) + _wait_ack = WebsocketsTransport._wait_ack diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index ebb777a4..80adb563 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -170,13 +170,13 @@ def __init__( # The list of supported subprotocols should be defined in the subclass self.supported_subprotocols: List[Subprotocol] = [] - async def _initialize(self) -> None: + async def _initialize(self): """Hook to send the initialization messages after the connection and potentially wait for the backend ack. """ pass # pragma: no cover - async def _stop_listener(self, query_id: int) -> None: + async def _stop_listener(self, query_id: int): """Hook to stop to listen to a specific query. Will send a stop message in some subclasses. """ @@ -834,7 +834,7 @@ async def _send_complete_message(self, query_id: int) -> None: await self._send(complete_message) - async def _stop_listener(self, query_id: int) -> None: + async def _stop_listener(self, query_id: int): """Stop the listener corresponding to the query_id depending on the detected backend protocol. From 23f72472dba7aac2eb375292fba25716dfd602ee Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Thu, 2 Dec 2021 17:30:54 +0100 Subject: [PATCH 44/68] Refactor: split websockets.py into websockets.py and websockets_base.py --- gql/transport/websockets.py | 654 +----------------------------- gql/transport/websockets_base.py | 666 +++++++++++++++++++++++++++++++ 2 files changed, 669 insertions(+), 651 deletions(-) create mode 100644 gql/transport/websockets_base.py diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 80adb563..41478daf 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -1,671 +1,23 @@ import asyncio import json import logging -import warnings -from abc import abstractmethod from contextlib import suppress from ssl import SSLContext -from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union, cast +from typing import Any, Dict, Optional, Tuple, Union, cast -import websockets from graphql import DocumentNode, ExecutionResult, print_ast -from websockets.client import WebSocketClientProtocol from websockets.datastructures import HeadersLike -from websockets.exceptions import ConnectionClosed -from websockets.typing import Data, Subprotocol +from websockets.typing import Subprotocol -from .async_transport import AsyncTransport from .exceptions import ( - TransportAlreadyConnected, - TransportClosed, TransportProtocolError, TransportQueryError, TransportServerError, ) +from .websockets_base import WebsocketsTransportBase log = logging.getLogger(__name__) -ParsedAnswer = Tuple[str, Optional[ExecutionResult]] - - -class ListenerQueue: - """Special queue used for each query waiting for server answers - - If the server is stopped while the listener is still waiting, - Then we send an exception to the queue and this exception will be raised - to the consumer once all the previous messages have been consumed from the queue - """ - - def __init__(self, query_id: int, send_stop: bool) -> None: - self.query_id: int = query_id - self.send_stop: bool = send_stop - self._queue: asyncio.Queue = asyncio.Queue() - self._closed: bool = False - - async def get(self) -> ParsedAnswer: - - try: - item = self._queue.get_nowait() - except asyncio.QueueEmpty: - item = await self._queue.get() - - self._queue.task_done() - - # If we receive an exception when reading the queue, we raise it - if isinstance(item, Exception): - self._closed = True - raise item - - # Don't need to save new answers or - # send the stop message if we already received the complete message - answer_type, execution_result = item - if answer_type == "complete": - self.send_stop = False - self._closed = True - - return item - - async def put(self, item: ParsedAnswer) -> None: - - if not self._closed: - await self._queue.put(item) - - async def set_exception(self, exception: Exception) -> None: - - # Put the exception in the queue - await self._queue.put(exception) - - # Don't need to send stop messages in case of error - self.send_stop = False - self._closed = True - - -class WebsocketsTransportBase(AsyncTransport): - """abstract :ref:`Async Transport ` used to implement - different websockets protocols. - - This transport uses asyncio and the websockets library in order to send requests - on a websocket connection. - """ - - def __init__( - self, - url: str, - headers: Optional[HeadersLike] = None, - ssl: Union[SSLContext, bool] = False, - init_payload: Dict[str, Any] = {}, - connect_timeout: Optional[Union[int, float]] = 10, - close_timeout: Optional[Union[int, float]] = 10, - ack_timeout: Optional[Union[int, float]] = 10, - keep_alive_timeout: Optional[Union[int, float]] = None, - connect_args: Dict[str, Any] = {}, - ) -> None: - """Initialize the transport with the given parameters. - - :param url: The GraphQL server URL. Example: 'wss://server.com:PORT/graphql'. - :param headers: Dict of HTTP Headers. - :param ssl: ssl_context of the connection. Use ssl=False to disable encryption - :param init_payload: Dict of the payload sent in the connection_init message. - :param connect_timeout: Timeout in seconds for the establishment - of the websocket connection. If None is provided this will wait forever. - :param close_timeout: Timeout in seconds for the close. If None is provided - this will wait forever. - :param ack_timeout: Timeout in seconds to wait for the connection_ack message - from the server. If None is provided this will wait forever. - :param keep_alive_timeout: Optional Timeout in seconds to receive - a sign of liveness from the server. - :param connect_args: Other parameters forwarded to websockets.connect - """ - - self.url: str = url - self.headers: Optional[HeadersLike] = headers - self.ssl: Union[SSLContext, bool] = ssl - self.init_payload: Dict[str, Any] = init_payload - - self.connect_timeout: Optional[Union[int, float]] = connect_timeout - self.close_timeout: Optional[Union[int, float]] = close_timeout - self.ack_timeout: Optional[Union[int, float]] = ack_timeout - self.keep_alive_timeout: Optional[Union[int, float]] = keep_alive_timeout - - self.connect_args = connect_args - - self.websocket: Optional[WebSocketClientProtocol] = None - self.next_query_id: int = 1 - self.listeners: Dict[int, ListenerQueue] = {} - - self.receive_data_task: Optional[asyncio.Future] = None - self.check_keep_alive_task: Optional[asyncio.Future] = None - self.close_task: Optional[asyncio.Future] = None - - # We need to set an event loop here if there is none - # Or else we will not be able to create an asyncio.Event() - try: - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", message="There is no current event loop" - ) - self._loop = asyncio.get_event_loop() - except RuntimeError: - self._loop = asyncio.new_event_loop() - asyncio.set_event_loop(self._loop) - - self._wait_closed: asyncio.Event = asyncio.Event() - self._wait_closed.set() - - self._no_more_listeners: asyncio.Event = asyncio.Event() - self._no_more_listeners.set() - - if self.keep_alive_timeout is not None: - self._next_keep_alive_message: asyncio.Event = asyncio.Event() - self._next_keep_alive_message.set() - - self.payloads: Dict[str, Any] = {} - """payloads is a dict which will contain the payloads received - for example with the graphql-ws protocol: 'ping', 'pong', 'connection_ack'""" - - self._connecting: bool = False - - self.close_exception: Optional[Exception] = None - - # The list of supported subprotocols should be defined in the subclass - self.supported_subprotocols: List[Subprotocol] = [] - - async def _initialize(self): - """Hook to send the initialization messages after the connection - and potentially wait for the backend ack. - """ - pass # pragma: no cover - - async def _stop_listener(self, query_id: int): - """Hook to stop to listen to a specific query. - Will send a stop message in some subclasses. - """ - pass # pragma: no cover - - async def _after_connect(self): - """Hook to add custom code for subclasses after the connection - has been established. - """ - pass # pragma: no cover - - async def _after_initialize(self): - """Hook to add custom code for subclasses after the initialization - has been done. - """ - pass # pragma: no cover - - async def _close_hook(self): - """Hook to add custom code for subclasses for the connection close - """ - pass # pragma: no cover - - async def _connection_terminate(self): - """Hook to add custom code for subclasses after the initialization - has been done. - """ - pass # pragma: no cover - - async def _send(self, message: str) -> None: - """Send the provided message to the websocket connection and log the message""" - - if not self.websocket: - raise TransportClosed( - "Transport is not connected" - ) from self.close_exception - - try: - await self.websocket.send(message) - log.info(">>> %s", message) - except ConnectionClosed as e: - await self._fail(e, clean_close=False) - raise e - - async def _receive(self) -> str: - """Wait the next message from the websocket connection and log the answer""" - - # It is possible that the websocket has been already closed in another task - if self.websocket is None: - raise TransportClosed("Transport is already closed") - - # Wait for the next websocket frame. Can raise ConnectionClosed - data: Data = await self.websocket.recv() - - # websocket.recv() can return either str or bytes - # In our case, we should receive only str here - if not isinstance(data, str): - raise TransportProtocolError("Binary data received in the websocket") - - answer: str = data - - log.info("<<< %s", answer) - - return answer - - @abstractmethod - async def _send_query( - self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, - ) -> int: - raise NotImplementedError # pragma: no cover - - @abstractmethod - def _parse_answer( - self, answer: str - ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: - raise NotImplementedError # pragma: no cover - - async def _check_ws_liveness(self) -> None: - """Coroutine which will periodically check the liveness of the connection - through keep-alive messages - """ - - try: - while True: - await asyncio.wait_for( - self._next_keep_alive_message.wait(), self.keep_alive_timeout - ) - - # Reset for the next iteration - self._next_keep_alive_message.clear() - - except asyncio.TimeoutError: - # No keep-alive message in the appriopriate interval, close with error - # while trying to notify the server of a proper close (in case - # the keep-alive interval of the client or server was not aligned - # the connection still remains) - - # If the timeout happens during a close already in progress, do nothing - if self.close_task is None: - await self._fail( - TransportServerError( - "No keep-alive message has been received within " - "the expected interval ('keep_alive_timeout' parameter)" - ), - clean_close=False, - ) - - except asyncio.CancelledError: - # The client is probably closing, handle it properly - pass - - async def _receive_data_loop(self) -> None: - """Main asyncio task which will listen to the incoming messages and will - call the parse_answer and handle_answer methods of the subclass.""" - try: - while True: - - # Wait the next answer from the websocket server - try: - answer = await self._receive() - except (ConnectionClosed, TransportProtocolError) as e: - await self._fail(e, clean_close=False) - break - except TransportClosed: - break - - # Parse the answer - try: - answer_type, answer_id, execution_result = self._parse_answer( - answer - ) - except TransportQueryError as e: - # Received an exception for a specific query - # ==> Add an exception to this query queue - # The exception is raised for this specific query, - # but the transport is not closed. - assert isinstance( - e.query_id, int - ), "TransportQueryError should have a query_id defined here" - try: - await self.listeners[e.query_id].set_exception(e) - except KeyError: - # Do nothing if no one is listening to this query_id - pass - - continue - - except (TransportServerError, TransportProtocolError) as e: - # Received a global exception for this transport - # ==> close the transport - # The exception will be raised for all current queries. - await self._fail(e, clean_close=False) - break - - await self._handle_answer(answer_type, answer_id, execution_result) - - finally: - log.debug("Exiting _receive_data_loop()") - - async def _handle_answer( - self, - answer_type: str, - answer_id: Optional[int], - execution_result: Optional[ExecutionResult], - ) -> None: - - try: - # Put the answer in the queue - if answer_id is not None: - await self.listeners[answer_id].put((answer_type, execution_result)) - except KeyError: - # Do nothing if no one is listening to this query_id. - pass - - async def subscribe( - self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, - send_stop: Optional[bool] = True, - ) -> AsyncGenerator[ExecutionResult, None]: - """Send a query and receive the results using a python async generator. - - The query can be a graphql query, mutation or subscription. - - The results are sent as an ExecutionResult object. - """ - - # Send the query and receive the id - query_id: int = await self._send_query( - document, variable_values, operation_name - ) - - # Create a queue to receive the answers for this query_id - listener = ListenerQueue(query_id, send_stop=(send_stop is True)) - self.listeners[query_id] = listener - - # We will need to wait at close for this query to clean properly - self._no_more_listeners.clear() - - try: - # Loop over the received answers - while True: - - # Wait for the answer from the queue of this query_id - # This can raise a TransportError or ConnectionClosed exception. - answer_type, execution_result = await listener.get() - - # If the received answer contains data, - # Then we will yield the results back as an ExecutionResult object - if execution_result is not None: - yield execution_result - - # If we receive a 'complete' answer from the server, - # Then we will end this async generator output without errors - elif answer_type == "complete": - log.debug( - f"Complete received for query {query_id} --> exit without error" - ) - break - - except (asyncio.CancelledError, GeneratorExit) as e: - log.debug(f"Exception in subscribe: {e!r}") - if listener.send_stop: - await self._stop_listener(query_id) - listener.send_stop = False - - finally: - log.debug(f"In subscribe finally for query_id {query_id}") - self._remove_listener(query_id) - - async def execute( - self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, - ) -> ExecutionResult: - """Execute the provided document AST against the configured remote server - using the current session. - - Send a query but close the async generator as soon as we have the first answer. - - The result is sent as an ExecutionResult object. - """ - first_result = None - - generator = self.subscribe( - document, variable_values, operation_name, send_stop=False - ) - - async for result in generator: - first_result = result - - # Note: we need to run generator.aclose() here or the finally block in - # the subscribe will not be reached in pypy3 (python version 3.6.1) - await generator.aclose() - - break - - if first_result is None: - raise TransportQueryError( - "Query completed without any answer received from the server" - ) - - return first_result - - async def connect(self) -> None: - """Coroutine which will: - - - connect to the websocket address - - send the init message - - wait for the connection acknowledge from the server - - create an asyncio task which will be used to receive - and parse the websocket answers - - Should be cleaned with a call to the close coroutine - """ - - log.debug("connect: starting") - - if self.websocket is None and not self._connecting: - - # Set connecting to True to avoid a race condition if user is trying - # to connect twice using the same client at the same time - self._connecting = True - - # If the ssl parameter is not provided, - # generate the ssl value depending on the url - ssl: Optional[Union[SSLContext, bool]] - if self.ssl: - ssl = self.ssl - else: - ssl = True if self.url.startswith("wss") else None - - # Set default arguments used in the websockets.connect call - connect_args: Dict[str, Any] = { - "ssl": ssl, - "extra_headers": self.headers, - "subprotocols": self.supported_subprotocols, - } - - # Adding custom parameters passed from init - connect_args.update(self.connect_args) - - # Connection to the specified url - # Generate a TimeoutError if taking more than connect_timeout seconds - # Set the _connecting flag to False after in all cases - try: - self.websocket = await asyncio.wait_for( - websockets.client.connect(self.url, **connect_args), - self.connect_timeout, - ) - finally: - self._connecting = False - - self.websocket = cast(WebSocketClientProtocol, self.websocket) - - # Run the after_connect hook of the subclass - await self._after_connect() - - self.next_query_id = 1 - self.close_exception = None - self._wait_closed.clear() - - # Send the init message and wait for the ack from the server - # Note: This should generate a TimeoutError - # if no ACKs are received within the ack_timeout - try: - await self._initialize() - except ConnectionClosed as e: - raise e - except (TransportProtocolError, asyncio.TimeoutError) as e: - await self._fail(e, clean_close=False) - raise e - - # Run the after_init hook of the subclass - await self._after_initialize() - - # If specified, create a task to check liveness of the connection - # through keep-alive messages - if self.keep_alive_timeout is not None: - self.check_keep_alive_task = asyncio.ensure_future( - self._check_ws_liveness() - ) - - # Create a task to listen to the incoming websocket messages - self.receive_data_task = asyncio.ensure_future(self._receive_data_loop()) - - else: - raise TransportAlreadyConnected("Transport is already connected") - - log.debug("connect: done") - - def _remove_listener(self, query_id) -> None: - """After exiting from a subscription, remove the listener and - signal an event if this was the last listener for the client. - """ - if query_id in self.listeners: - del self.listeners[query_id] - - remaining = len(self.listeners) - log.debug(f"listener {query_id} deleted, {remaining} remaining") - - if remaining == 0: - self._no_more_listeners.set() - - async def _clean_close(self, e: Exception) -> None: - """Coroutine which will: - - - send stop messages for each active subscription to the server - - send the connection terminate message - """ - - # Send 'stop' message for all current queries - for query_id, listener in self.listeners.items(): - - if listener.send_stop: - await self._stop_listener(query_id) - listener.send_stop = False - - # Wait that there is no more listeners (we received 'complete' for all queries) - try: - await asyncio.wait_for(self._no_more_listeners.wait(), self.close_timeout) - except asyncio.TimeoutError: # pragma: no cover - log.debug("Timer close_timeout fired") - - # Calling the subclass hook - await self._connection_terminate() - - async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: - """Coroutine which will: - - - do a clean_close if possible: - - send stop messages for each active query to the server - - send the connection terminate message - - close the websocket connection - - send the exception to all the remaining listeners - """ - - log.debug("_close_coro: starting") - - try: - - # We should always have an active websocket connection here - assert self.websocket is not None - - # Properly shut down liveness checker if enabled - if self.check_keep_alive_task is not None: - # More info: https://stackoverflow.com/a/43810272/1113207 - self.check_keep_alive_task.cancel() - with suppress(asyncio.CancelledError): - await self.check_keep_alive_task - - # Calling the subclass close hook - await self._close_hook() - - # Saving exception to raise it later if trying to use the transport - # after it has already closed. - self.close_exception = e - - if clean_close: - log.debug("_close_coro: starting clean_close") - try: - await self._clean_close(e) - except Exception as exc: # pragma: no cover - log.warning("Ignoring exception in _clean_close: " + repr(exc)) - - log.debug("_close_coro: sending exception to listeners") - - # Send an exception to all remaining listeners - for query_id, listener in self.listeners.items(): - await listener.set_exception(e) - - log.debug("_close_coro: close websocket connection") - - await self.websocket.close() - - log.debug("_close_coro: websocket connection closed") - - except Exception as exc: # pragma: no cover - log.warning("Exception catched in _close_coro: " + repr(exc)) - - finally: - - log.debug("_close_coro: start cleanup") - - self.websocket = None - self.close_task = None - self.check_keep_alive_task = None - self._wait_closed.set() - - log.debug("_close_coro: exiting") - - async def _fail(self, e: Exception, clean_close: bool = True) -> None: - log.debug("_fail: starting with exception: " + repr(e)) - - if self.close_task is None: - - if self.websocket is None: - log.debug("_fail started with self.websocket == None -> already closed") - else: - self.close_task = asyncio.shield( - asyncio.ensure_future(self._close_coro(e, clean_close=clean_close)) - ) - else: - log.debug( - "close_task is not None in _fail. Previous exception is: " - + repr(self.close_exception) - + " New exception is: " - + repr(e) - ) - - async def close(self) -> None: - log.debug("close: starting") - - await self._fail(TransportClosed("Websocket GraphQL transport closed by user")) - await self.wait_closed() - - log.debug("close: done") - - async def wait_closed(self) -> None: - log.debug("wait_close: starting") - - await self._wait_closed.wait() - - log.debug("wait_close: done") - class WebsocketsTransport(WebsocketsTransportBase): """:ref:`Async Transport ` used to execute GraphQL queries on diff --git a/gql/transport/websockets_base.py b/gql/transport/websockets_base.py new file mode 100644 index 00000000..151e444e --- /dev/null +++ b/gql/transport/websockets_base.py @@ -0,0 +1,666 @@ +import asyncio +import logging +import warnings +from abc import abstractmethod +from contextlib import suppress +from ssl import SSLContext +from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union, cast + +import websockets +from graphql import DocumentNode, ExecutionResult +from websockets.client import WebSocketClientProtocol +from websockets.datastructures import HeadersLike +from websockets.exceptions import ConnectionClosed +from websockets.typing import Data, Subprotocol + +from .async_transport import AsyncTransport +from .exceptions import ( + TransportAlreadyConnected, + TransportClosed, + TransportProtocolError, + TransportQueryError, + TransportServerError, +) + +log = logging.getLogger("gql.transport.websockets") + +ParsedAnswer = Tuple[str, Optional[ExecutionResult]] + + +class ListenerQueue: + """Special queue used for each query waiting for server answers + + If the server is stopped while the listener is still waiting, + Then we send an exception to the queue and this exception will be raised + to the consumer once all the previous messages have been consumed from the queue + """ + + def __init__(self, query_id: int, send_stop: bool) -> None: + self.query_id: int = query_id + self.send_stop: bool = send_stop + self._queue: asyncio.Queue = asyncio.Queue() + self._closed: bool = False + + async def get(self) -> ParsedAnswer: + + try: + item = self._queue.get_nowait() + except asyncio.QueueEmpty: + item = await self._queue.get() + + self._queue.task_done() + + # If we receive an exception when reading the queue, we raise it + if isinstance(item, Exception): + self._closed = True + raise item + + # Don't need to save new answers or + # send the stop message if we already received the complete message + answer_type, execution_result = item + if answer_type == "complete": + self.send_stop = False + self._closed = True + + return item + + async def put(self, item: ParsedAnswer) -> None: + + if not self._closed: + await self._queue.put(item) + + async def set_exception(self, exception: Exception) -> None: + + # Put the exception in the queue + await self._queue.put(exception) + + # Don't need to send stop messages in case of error + self.send_stop = False + self._closed = True + + +class WebsocketsTransportBase(AsyncTransport): + """abstract :ref:`Async Transport ` used to implement + different websockets protocols. + + This transport uses asyncio and the websockets library in order to send requests + on a websocket connection. + """ + + def __init__( + self, + url: str, + headers: Optional[HeadersLike] = None, + ssl: Union[SSLContext, bool] = False, + init_payload: Dict[str, Any] = {}, + connect_timeout: Optional[Union[int, float]] = 10, + close_timeout: Optional[Union[int, float]] = 10, + ack_timeout: Optional[Union[int, float]] = 10, + keep_alive_timeout: Optional[Union[int, float]] = None, + connect_args: Dict[str, Any] = {}, + ) -> None: + """Initialize the transport with the given parameters. + + :param url: The GraphQL server URL. Example: 'wss://server.com:PORT/graphql'. + :param headers: Dict of HTTP Headers. + :param ssl: ssl_context of the connection. Use ssl=False to disable encryption + :param init_payload: Dict of the payload sent in the connection_init message. + :param connect_timeout: Timeout in seconds for the establishment + of the websocket connection. If None is provided this will wait forever. + :param close_timeout: Timeout in seconds for the close. If None is provided + this will wait forever. + :param ack_timeout: Timeout in seconds to wait for the connection_ack message + from the server. If None is provided this will wait forever. + :param keep_alive_timeout: Optional Timeout in seconds to receive + a sign of liveness from the server. + :param connect_args: Other parameters forwarded to websockets.connect + """ + + self.url: str = url + self.headers: Optional[HeadersLike] = headers + self.ssl: Union[SSLContext, bool] = ssl + self.init_payload: Dict[str, Any] = init_payload + + self.connect_timeout: Optional[Union[int, float]] = connect_timeout + self.close_timeout: Optional[Union[int, float]] = close_timeout + self.ack_timeout: Optional[Union[int, float]] = ack_timeout + self.keep_alive_timeout: Optional[Union[int, float]] = keep_alive_timeout + + self.connect_args = connect_args + + self.websocket: Optional[WebSocketClientProtocol] = None + self.next_query_id: int = 1 + self.listeners: Dict[int, ListenerQueue] = {} + + self.receive_data_task: Optional[asyncio.Future] = None + self.check_keep_alive_task: Optional[asyncio.Future] = None + self.close_task: Optional[asyncio.Future] = None + + # We need to set an event loop here if there is none + # Or else we will not be able to create an asyncio.Event() + try: + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message="There is no current event loop" + ) + self._loop = asyncio.get_event_loop() + except RuntimeError: + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) + + self._wait_closed: asyncio.Event = asyncio.Event() + self._wait_closed.set() + + self._no_more_listeners: asyncio.Event = asyncio.Event() + self._no_more_listeners.set() + + if self.keep_alive_timeout is not None: + self._next_keep_alive_message: asyncio.Event = asyncio.Event() + self._next_keep_alive_message.set() + + self.payloads: Dict[str, Any] = {} + """payloads is a dict which will contain the payloads received + for example with the graphql-ws protocol: 'ping', 'pong', 'connection_ack'""" + + self._connecting: bool = False + + self.close_exception: Optional[Exception] = None + + # The list of supported subprotocols should be defined in the subclass + self.supported_subprotocols: List[Subprotocol] = [] + + async def _initialize(self): + """Hook to send the initialization messages after the connection + and potentially wait for the backend ack. + """ + pass # pragma: no cover + + async def _stop_listener(self, query_id: int): + """Hook to stop to listen to a specific query. + Will send a stop message in some subclasses. + """ + pass # pragma: no cover + + async def _after_connect(self): + """Hook to add custom code for subclasses after the connection + has been established. + """ + pass # pragma: no cover + + async def _after_initialize(self): + """Hook to add custom code for subclasses after the initialization + has been done. + """ + pass # pragma: no cover + + async def _close_hook(self): + """Hook to add custom code for subclasses for the connection close + """ + pass # pragma: no cover + + async def _connection_terminate(self): + """Hook to add custom code for subclasses after the initialization + has been done. + """ + pass # pragma: no cover + + async def _send(self, message: str) -> None: + """Send the provided message to the websocket connection and log the message""" + + if not self.websocket: + raise TransportClosed( + "Transport is not connected" + ) from self.close_exception + + try: + await self.websocket.send(message) + log.info(">>> %s", message) + except ConnectionClosed as e: + await self._fail(e, clean_close=False) + raise e + + async def _receive(self) -> str: + """Wait the next message from the websocket connection and log the answer""" + + # It is possible that the websocket has been already closed in another task + if self.websocket is None: + raise TransportClosed("Transport is already closed") + + # Wait for the next websocket frame. Can raise ConnectionClosed + data: Data = await self.websocket.recv() + + # websocket.recv() can return either str or bytes + # In our case, we should receive only str here + if not isinstance(data, str): + raise TransportProtocolError("Binary data received in the websocket") + + answer: str = data + + log.info("<<< %s", answer) + + return answer + + @abstractmethod + async def _send_query( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + ) -> int: + raise NotImplementedError # pragma: no cover + + @abstractmethod + def _parse_answer( + self, answer: str + ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: + raise NotImplementedError # pragma: no cover + + async def _check_ws_liveness(self) -> None: + """Coroutine which will periodically check the liveness of the connection + through keep-alive messages + """ + + try: + while True: + await asyncio.wait_for( + self._next_keep_alive_message.wait(), self.keep_alive_timeout + ) + + # Reset for the next iteration + self._next_keep_alive_message.clear() + + except asyncio.TimeoutError: + # No keep-alive message in the appriopriate interval, close with error + # while trying to notify the server of a proper close (in case + # the keep-alive interval of the client or server was not aligned + # the connection still remains) + + # If the timeout happens during a close already in progress, do nothing + if self.close_task is None: + await self._fail( + TransportServerError( + "No keep-alive message has been received within " + "the expected interval ('keep_alive_timeout' parameter)" + ), + clean_close=False, + ) + + except asyncio.CancelledError: + # The client is probably closing, handle it properly + pass + + async def _receive_data_loop(self) -> None: + """Main asyncio task which will listen to the incoming messages and will + call the parse_answer and handle_answer methods of the subclass.""" + try: + while True: + + # Wait the next answer from the websocket server + try: + answer = await self._receive() + except (ConnectionClosed, TransportProtocolError) as e: + await self._fail(e, clean_close=False) + break + except TransportClosed: + break + + # Parse the answer + try: + answer_type, answer_id, execution_result = self._parse_answer( + answer + ) + except TransportQueryError as e: + # Received an exception for a specific query + # ==> Add an exception to this query queue + # The exception is raised for this specific query, + # but the transport is not closed. + assert isinstance( + e.query_id, int + ), "TransportQueryError should have a query_id defined here" + try: + await self.listeners[e.query_id].set_exception(e) + except KeyError: + # Do nothing if no one is listening to this query_id + pass + + continue + + except (TransportServerError, TransportProtocolError) as e: + # Received a global exception for this transport + # ==> close the transport + # The exception will be raised for all current queries. + await self._fail(e, clean_close=False) + break + + await self._handle_answer(answer_type, answer_id, execution_result) + + finally: + log.debug("Exiting _receive_data_loop()") + + async def _handle_answer( + self, + answer_type: str, + answer_id: Optional[int], + execution_result: Optional[ExecutionResult], + ) -> None: + + try: + # Put the answer in the queue + if answer_id is not None: + await self.listeners[answer_id].put((answer_type, execution_result)) + except KeyError: + # Do nothing if no one is listening to this query_id. + pass + + async def subscribe( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + send_stop: Optional[bool] = True, + ) -> AsyncGenerator[ExecutionResult, None]: + """Send a query and receive the results using a python async generator. + + The query can be a graphql query, mutation or subscription. + + The results are sent as an ExecutionResult object. + """ + + # Send the query and receive the id + query_id: int = await self._send_query( + document, variable_values, operation_name + ) + + # Create a queue to receive the answers for this query_id + listener = ListenerQueue(query_id, send_stop=(send_stop is True)) + self.listeners[query_id] = listener + + # We will need to wait at close for this query to clean properly + self._no_more_listeners.clear() + + try: + # Loop over the received answers + while True: + + # Wait for the answer from the queue of this query_id + # This can raise a TransportError or ConnectionClosed exception. + answer_type, execution_result = await listener.get() + + # If the received answer contains data, + # Then we will yield the results back as an ExecutionResult object + if execution_result is not None: + yield execution_result + + # If we receive a 'complete' answer from the server, + # Then we will end this async generator output without errors + elif answer_type == "complete": + log.debug( + f"Complete received for query {query_id} --> exit without error" + ) + break + + except (asyncio.CancelledError, GeneratorExit) as e: + log.debug(f"Exception in subscribe: {e!r}") + if listener.send_stop: + await self._stop_listener(query_id) + listener.send_stop = False + + finally: + log.debug(f"In subscribe finally for query_id {query_id}") + self._remove_listener(query_id) + + async def execute( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + ) -> ExecutionResult: + """Execute the provided document AST against the configured remote server + using the current session. + + Send a query but close the async generator as soon as we have the first answer. + + The result is sent as an ExecutionResult object. + """ + first_result = None + + generator = self.subscribe( + document, variable_values, operation_name, send_stop=False + ) + + async for result in generator: + first_result = result + + # Note: we need to run generator.aclose() here or the finally block in + # the subscribe will not be reached in pypy3 (python version 3.6.1) + await generator.aclose() + + break + + if first_result is None: + raise TransportQueryError( + "Query completed without any answer received from the server" + ) + + return first_result + + async def connect(self) -> None: + """Coroutine which will: + + - connect to the websocket address + - send the init message + - wait for the connection acknowledge from the server + - create an asyncio task which will be used to receive + and parse the websocket answers + + Should be cleaned with a call to the close coroutine + """ + + log.debug("connect: starting") + + if self.websocket is None and not self._connecting: + + # Set connecting to True to avoid a race condition if user is trying + # to connect twice using the same client at the same time + self._connecting = True + + # If the ssl parameter is not provided, + # generate the ssl value depending on the url + ssl: Optional[Union[SSLContext, bool]] + if self.ssl: + ssl = self.ssl + else: + ssl = True if self.url.startswith("wss") else None + + # Set default arguments used in the websockets.connect call + connect_args: Dict[str, Any] = { + "ssl": ssl, + "extra_headers": self.headers, + "subprotocols": self.supported_subprotocols, + } + + # Adding custom parameters passed from init + connect_args.update(self.connect_args) + + # Connection to the specified url + # Generate a TimeoutError if taking more than connect_timeout seconds + # Set the _connecting flag to False after in all cases + try: + self.websocket = await asyncio.wait_for( + websockets.client.connect(self.url, **connect_args), + self.connect_timeout, + ) + finally: + self._connecting = False + + self.websocket = cast(WebSocketClientProtocol, self.websocket) + + # Run the after_connect hook of the subclass + await self._after_connect() + + self.next_query_id = 1 + self.close_exception = None + self._wait_closed.clear() + + # Send the init message and wait for the ack from the server + # Note: This should generate a TimeoutError + # if no ACKs are received within the ack_timeout + try: + await self._initialize() + except ConnectionClosed as e: + raise e + except (TransportProtocolError, asyncio.TimeoutError) as e: + await self._fail(e, clean_close=False) + raise e + + # Run the after_init hook of the subclass + await self._after_initialize() + + # If specified, create a task to check liveness of the connection + # through keep-alive messages + if self.keep_alive_timeout is not None: + self.check_keep_alive_task = asyncio.ensure_future( + self._check_ws_liveness() + ) + + # Create a task to listen to the incoming websocket messages + self.receive_data_task = asyncio.ensure_future(self._receive_data_loop()) + + else: + raise TransportAlreadyConnected("Transport is already connected") + + log.debug("connect: done") + + def _remove_listener(self, query_id) -> None: + """After exiting from a subscription, remove the listener and + signal an event if this was the last listener for the client. + """ + if query_id in self.listeners: + del self.listeners[query_id] + + remaining = len(self.listeners) + log.debug(f"listener {query_id} deleted, {remaining} remaining") + + if remaining == 0: + self._no_more_listeners.set() + + async def _clean_close(self, e: Exception) -> None: + """Coroutine which will: + + - send stop messages for each active subscription to the server + - send the connection terminate message + """ + + # Send 'stop' message for all current queries + for query_id, listener in self.listeners.items(): + + if listener.send_stop: + await self._stop_listener(query_id) + listener.send_stop = False + + # Wait that there is no more listeners (we received 'complete' for all queries) + try: + await asyncio.wait_for(self._no_more_listeners.wait(), self.close_timeout) + except asyncio.TimeoutError: # pragma: no cover + log.debug("Timer close_timeout fired") + + # Calling the subclass hook + await self._connection_terminate() + + async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: + """Coroutine which will: + + - do a clean_close if possible: + - send stop messages for each active query to the server + - send the connection terminate message + - close the websocket connection + - send the exception to all the remaining listeners + """ + + log.debug("_close_coro: starting") + + try: + + # We should always have an active websocket connection here + assert self.websocket is not None + + # Properly shut down liveness checker if enabled + if self.check_keep_alive_task is not None: + # More info: https://stackoverflow.com/a/43810272/1113207 + self.check_keep_alive_task.cancel() + with suppress(asyncio.CancelledError): + await self.check_keep_alive_task + + # Calling the subclass close hook + await self._close_hook() + + # Saving exception to raise it later if trying to use the transport + # after it has already closed. + self.close_exception = e + + if clean_close: + log.debug("_close_coro: starting clean_close") + try: + await self._clean_close(e) + except Exception as exc: # pragma: no cover + log.warning("Ignoring exception in _clean_close: " + repr(exc)) + + log.debug("_close_coro: sending exception to listeners") + + # Send an exception to all remaining listeners + for query_id, listener in self.listeners.items(): + await listener.set_exception(e) + + log.debug("_close_coro: close websocket connection") + + await self.websocket.close() + + log.debug("_close_coro: websocket connection closed") + + except Exception as exc: # pragma: no cover + log.warning("Exception catched in _close_coro: " + repr(exc)) + + finally: + + log.debug("_close_coro: start cleanup") + + self.websocket = None + self.close_task = None + self.check_keep_alive_task = None + self._wait_closed.set() + + log.debug("_close_coro: exiting") + + async def _fail(self, e: Exception, clean_close: bool = True) -> None: + log.debug("_fail: starting with exception: " + repr(e)) + + if self.close_task is None: + + if self.websocket is None: + log.debug("_fail started with self.websocket == None -> already closed") + else: + self.close_task = asyncio.shield( + asyncio.ensure_future(self._close_coro(e, clean_close=clean_close)) + ) + else: + log.debug( + "close_task is not None in _fail. Previous exception is: " + + repr(self.close_exception) + + " New exception is: " + + repr(e) + ) + + async def close(self) -> None: + log.debug("close: starting") + + await self._fail(TransportClosed("Websocket GraphQL transport closed by user")) + await self.wait_closed() + + log.debug("close: done") + + async def wait_closed(self) -> None: + log.debug("wait_close: starting") + + await self._wait_closed.wait() + + log.debug("wait_close: done") From 2f199f077679b9e8d5cfe39409c7e686b35e10f4 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Thu, 2 Dec 2021 18:16:20 +0100 Subject: [PATCH 45/68] Disallow execute on appsync transport + change subscribe doc message --- gql/transport/appsync.py | 31 ++++++++++++++++++++++++++++--- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/gql/transport/appsync.py b/gql/transport/appsync.py index 8fac48c4..8850f586 100644 --- a/gql/transport/appsync.py +++ b/gql/transport/appsync.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from base64 import b64encode from ssl import SSLContext -from typing import Any, Callable, Coroutine, Dict, Optional, Tuple, Union, cast +from typing import Any, Callable, Dict, Optional, Tuple, Union, cast from urllib.parse import urlparse import botocore.session @@ -335,11 +335,36 @@ async def _send_query( await self._send(json.dumps(message, separators=(",", ":"),)) - print(Coroutine) return query_id + subscribe = WebsocketsTransportBase.subscribe + """Send a subscription query and receive the results using + a python async generator. + + Only subscriptions are supported, queries and mutations are forbidden. + + The results are sent as an ExecutionResult object. + """ + + async def execute( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + ) -> ExecutionResult: + """This method is not available. + + Only subscriptions are supported on the AWS realtime endpoint. + + :raise: AssertionError""" + raise AssertionError( + "execute method is not allowed for AppSyncWebsocketsTransport " + "because only subscriptions are allowed on the realtime endpoint. " + "fetch_schema_from_transport should be set to False in the client!" + ) + _initialize = WebsocketsTransport._initialize - _stop_listener = WebsocketsTransport._stop_listener + _stop_listener = WebsocketsTransport._stop_listener # type: ignore _send_init_message_and_wait_ack = ( WebsocketsTransport._send_init_message_and_wait_ack ) From 9fd9fafe930d54bc6cd8f460d18f9a0c350f6320 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Thu, 2 Dec 2021 18:35:22 +0100 Subject: [PATCH 46/68] Rename extra dependency from appsyncwebsockets to appsync + add to installation docs --- docs/intro.rst | 24 +++++++++++++----------- setup.py | 6 +++--- tests/conftest.py | 2 +- tests/test_appsyncwebsocket.py | 4 ++-- 4 files changed, 19 insertions(+), 17 deletions(-) diff --git a/docs/intro.rst b/docs/intro.rst index e377c56e..00c6f87d 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -37,17 +37,19 @@ which needs the :code:`aiohttp` dependency, then you can install GQL with:: The corresponding between extra dependencies required and the GQL transports is: -+-------------------+----------------------------------------------------------------+ -| Extra dependency | Transports | -+===================+================================================================+ -| aiohttp | :ref:`AIOHTTPTransport ` | -+-------------------+----------------------------------------------------------------+ -| websockets | :ref:`WebsocketsTransport ` | -| | | -| | :ref:`PhoenixChannelWebsocketsTransport ` | -+-------------------+----------------------------------------------------------------+ -| requests | :ref:`RequestsHTTPTransport ` | -+-------------------+----------------------------------------------------------------+ ++---------------------+----------------------------------------------------------------+ +| Extra dependencies | Transports | ++=====================+================================================================+ +| aiohttp | :ref:`AIOHTTPTransport ` | ++---------------------+----------------------------------------------------------------+ +| websockets | :ref:`WebsocketsTransport ` | +| | | +| | :ref:`PhoenixChannelWebsocketsTransport ` | ++---------------------+----------------------------------------------------------------+ +| requests | :ref:`RequestsHTTPTransport ` | ++---------------------+----------------------------------------------------------------+ +| appsync, websockets | :ref:`AppSyncWebsocketsTransport ` | ++---------------------+----------------------------------------------------------------+ .. note:: diff --git a/setup.py b/setup.py index f30e230f..041881d3 100644 --- a/setup.py +++ b/setup.py @@ -50,12 +50,12 @@ "websockets>=10,<11;python_version>'3.6'", ] -install_appsyncwebsockets_requires = [ +install_appsync_requires = [ "botocore>=1.21,<2", ] + install_websockets_requires install_all_requires = ( - install_aiohttp_requires + install_requests_requires + install_websockets_requires + install_appsyncwebsockets_requires + install_aiohttp_requires + install_requests_requires + install_websockets_requires + install_appsync_requires ) # Get version from __version__.py file @@ -99,7 +99,7 @@ "aiohttp": install_aiohttp_requires, "requests": install_requests_requires, "websockets": install_websockets_requires, - "appsyncwebsockets": install_appsyncwebsockets_requires, + "appsync": install_appsync_requires, }, include_package_data=True, zip_safe=False, diff --git a/tests/conftest.py b/tests/conftest.py index cef4eae6..ccedd7aa 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,7 +14,7 @@ from gql import Client -all_transport_dependencies = ["aiohttp", "requests", "websockets", "appsyncwebsockets"] +all_transport_dependencies = ["aiohttp", "requests", "websockets", "appsync"] def pytest_addoption(parser): diff --git a/tests/test_appsyncwebsocket.py b/tests/test_appsyncwebsocket.py index f5d03079..9b2deb0d 100644 --- a/tests/test_appsyncwebsocket.py +++ b/tests/test_appsyncwebsocket.py @@ -1,7 +1,7 @@ import pytest -# Marking all tests in this file with the appsyncwebsockets marker -pytestmark = pytest.mark.appsyncwebsockets +# Marking all tests in this file with the appsync marker +pytestmark = pytest.mark.appsync mock_transport_url = "https://appsyncapp.awsgateway.com.example.org" From c5408bb1fa5cf0e2e322faf891a1ac0fc0538889 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Thu, 2 Dec 2021 19:11:18 +0100 Subject: [PATCH 47/68] Rename Authorization to Authentication --- docs/code_examples/aws_api_key_mutation.py | 4 +- .../code_examples/aws_api_key_subscription.py | 9 ++- docs/transports/appsync.rst | 10 +-- gql/transport/appsync.py | 48 ++++++------- tests/test_appsyncwebsocket.py | 68 ++++++++----------- 5 files changed, 65 insertions(+), 74 deletions(-) diff --git a/docs/code_examples/aws_api_key_mutation.py b/docs/code_examples/aws_api_key_mutation.py index 9fa5f3cc..a3fcc2f0 100644 --- a/docs/code_examples/aws_api_key_mutation.py +++ b/docs/code_examples/aws_api_key_mutation.py @@ -6,7 +6,7 @@ from gql import Client, gql from gql.transport.aiohttp import AIOHTTPTransport -from gql.transport.appsync import AppSyncApiKeyAuthorization +from gql.transport.appsync import AppSyncApiKeyAuthentication logging.basicConfig(level=logging.DEBUG) @@ -25,7 +25,7 @@ async def main(): # Extract host from url host = str(urlparse(url).netloc) - auth = AppSyncApiKeyAuthorization(host=host, api_key=api_key) + auth = AppSyncApiKeyAuthentication(host=host, api_key=api_key) transport = AIOHTTPTransport(url=url, headers=auth.get_headers()) diff --git a/docs/code_examples/aws_api_key_subscription.py b/docs/code_examples/aws_api_key_subscription.py index 3e994911..06d14b67 100644 --- a/docs/code_examples/aws_api_key_subscription.py +++ b/docs/code_examples/aws_api_key_subscription.py @@ -5,7 +5,10 @@ from urllib.parse import urlparse from gql import Client, gql -from gql.transport.appsync import AppSyncApiKeyAuthorization, AppSyncWebsocketsTransport +from gql.transport.appsync import ( + AppSyncApiKeyAuthentication, + AppSyncWebsocketsTransport, +) logging.basicConfig(level=logging.DEBUG) @@ -26,9 +29,9 @@ async def main(): print(f"Host: {host}") - auth = AppSyncApiKeyAuthorization(host=host, api_key=api_key) + auth = AppSyncApiKeyAuthentication(host=host, api_key=api_key) - transport = AppSyncWebsocketsTransport(url=url, authorization=auth) + transport = AppSyncWebsocketsTransport(url=url, auth=auth) async with Client(transport=transport) as session: diff --git a/docs/transports/appsync.rst b/docs/transports/appsync.rst index 41e15a35..8f48a232 100644 --- a/docs/transports/appsync.rst +++ b/docs/transports/appsync.rst @@ -20,7 +20,7 @@ How to use it: .. note:: It is also possible to instantiate the transport without an auth argument. In that case, - gql will use by default the :class:`IAM auth ` + gql will use by default the :class:`IAM auth ` which will try to authenticate with environment variables or from your aws credentials file. Full example with API key authentication from environment variables: @@ -38,21 +38,21 @@ Authentication methods API key ^^^^^^^ -Reference: :class:`gql.transport.appsync.AppSyncApiKeyAuthorization` +Reference: :class:`gql.transport.appsync.AppSyncApiKeyAuthentication` IAM ^^^ -Reference: :class:`gql.transport.appsync.AppSyncIAMAuthorization` +Reference: :class:`gql.transport.appsync.AppSyncIAMAuthentication` Amazon Cognito user pools ^^^^^^^^^^^^^^^^^^^^^^^^^ -Reference: :class:`gql.transport.appsync.AppSyncOIDCAuthorization` +Reference: :class:`gql.transport.appsync.AppSyncOIDCAuthentication` OpenID Connect (OIDC) ^^^^^^^^^^^^^^^^^^^^^ -Reference: :class:`gql.transport.appsync.AppSyncCognitoUserPoolAuthorization` +Reference: :class:`gql.transport.appsync.AppSyncCognitoUserPoolAuthentication` .. _Building a real-time websocket client: https://docs.aws.amazon.com/appsync/latest/devguide/real-time-websocket-client.html diff --git a/gql/transport/appsync.py b/gql/transport/appsync.py index 8850f586..ef73684e 100644 --- a/gql/transport/appsync.py +++ b/gql/transport/appsync.py @@ -20,11 +20,11 @@ log = logging.getLogger(__name__) -class AppSyncAuthorization(ABC): - """AWS authorization abstract base class +class AppSyncAuthentication(ABC): + """AWS authentication abstract base class - All AWS authorization class should have a - :meth:`get_headers ` + All AWS authentication class should have a + :meth:`get_headers ` method which defines the headers used in the authentication process.""" def get_auth_url(self, url: str) -> str: @@ -49,8 +49,8 @@ def get_headers(self, data: Optional[str] = None) -> Dict: raise NotImplementedError() -class AppSyncApiKeyAuthorization(AppSyncAuthorization): - """AWS authorization class using an API key""" +class AppSyncApiKeyAuthentication(AppSyncAuthentication): + """AWS authentication class using an API key""" def __init__(self, host: str, api_key: str) -> None: """ @@ -65,8 +65,8 @@ def get_headers(self, data: Optional[str] = None) -> Dict: return {"host": self._host, "x-api-key": self.api_key} -class AppSyncOIDCAuthorization(AppSyncAuthorization): - """AWS authorization class using an OpenID JWT access token""" +class AppSyncOIDCAuthentication(AppSyncAuthentication): + """AWS authentication class using an OpenID JWT access token""" def __init__(self, host: str, jwt: str) -> None: """ @@ -81,14 +81,14 @@ def get_headers(self, data: Optional[str] = None) -> Dict: return {"host": self._host, "Authorization": self.jwt} -class AppSyncCognitoUserPoolAuthorization(AppSyncOIDCAuthorization): - """AWS authorization class using a Cognito user pools JWT access token""" +class AppSyncCognitoUserPoolAuthentication(AppSyncOIDCAuthentication): + """AWS authentication class using a Cognito user pools JWT access token""" pass -class AppSyncIAMAuthorization(AppSyncAuthorization): - """AWS authorization class using IAM. +class AppSyncIAMAuthentication(AppSyncAuthentication): + """AWS authentication class using IAM. .. note:: There is no need for you to use this class directly, you could instead @@ -175,12 +175,12 @@ class AppSyncWebsocketsTransport(WebsocketsTransportBase): on a websocket connection. """ - authorization: Optional[AppSyncAuthorization] + auth: Optional[AppSyncAuthentication] def __init__( self, url: str, - authorization: Optional[AppSyncAuthorization] = None, + auth: Optional[AppSyncAuthentication] = None, session: Optional[botocore.session.Session] = None, ssl: Union[SSLContext, bool] = False, connect_timeout: int = 10, @@ -192,10 +192,10 @@ def __init__( :param url: The GraphQL endpoint URL. Example: https://XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com/graphql - :param authorization: Optional AWS authorization class which will provide the - necessary headers to be correctly authenticated. If this - argument is not provided, then we will try to authenticate - using IAM. + :param auth: Optional AWS authentication class which will provide the + necessary headers to be correctly authenticated. If this + argument is not provided, then we will try to authenticate + using IAM. :param ssl: ssl_context of the connection. :param connect_timeout: Timeout in seconds for the establishment of the websocket connection. If None is provided this will wait forever. @@ -206,16 +206,16 @@ def __init__( :param connect_args: Other parameters forwarded to websockets.connect """ try: - if not authorization: + if not auth: # Extract host from url host = str(urlparse(url).netloc) - authorization = AppSyncIAMAuthorization(host=host, session=session) + auth = AppSyncIAMAuthentication(host=host, session=session) - self.authorization = authorization + self.auth = auth - url = self.authorization.get_auth_url(url) + url = self.auth.get_auth_url(url) except NoCredentialsError: log.warning( @@ -327,10 +327,10 @@ async def _send_query( "payload": payload, } - assert self.authorization is not None + assert self.auth is not None message["payload"]["extensions"] = { - "authorization": self.authorization.get_headers(serialized_data) + "authorization": self.auth.get_headers(serialized_data) } await self._send(json.dumps(message, separators=(",", ":"),)) diff --git a/tests/test_appsyncwebsocket.py b/tests/test_appsyncwebsocket.py index 9b2deb0d..20a6f348 100644 --- a/tests/test_appsyncwebsocket.py +++ b/tests/test_appsyncwebsocket.py @@ -8,14 +8,14 @@ def test_appsyncwebsocket_init_with_minimal_args(fake_session_factory): from gql.transport.appsync import ( - AppSyncIAMAuthorization, + AppSyncIAMAuthentication, AppSyncWebsocketsTransport, ) sample_transport = AppSyncWebsocketsTransport( url=mock_transport_url, session=fake_session_factory() ) - assert isinstance(sample_transport.authorization, AppSyncIAMAuthorization) + assert isinstance(sample_transport.auth, AppSyncIAMAuthentication) assert sample_transport.connect_timeout == 10 assert sample_transport.close_timeout == 10 assert sample_transport.ack_timeout == 10 @@ -31,7 +31,7 @@ def test_appsyncwebsocket_init_with_no_credentials(caplog, fake_session_factory) sample_transport = AppSyncWebsocketsTransport( url=mock_transport_url, session=fake_session_factory(credentials=None), ) - assert sample_transport.authorization is None + assert sample_transport.auth is None expected_error = "Credentials not found" @@ -42,79 +42,69 @@ def test_appsyncwebsocket_init_with_no_credentials(caplog, fake_session_factory) def test_appsyncwebsocket_init_with_oidc_auth(): from gql.transport.appsync import ( - AppSyncOIDCAuthorization, + AppSyncOIDCAuthentication, AppSyncWebsocketsTransport, ) - authorization = AppSyncOIDCAuthorization(host=mock_transport_url, jwt="some-jwt") - sample_transport = AppSyncWebsocketsTransport( - url=mock_transport_url, authorization=authorization - ) - assert sample_transport.authorization is authorization + auth = AppSyncOIDCAuthentication(host=mock_transport_url, jwt="some-jwt") + sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, auth=auth) + assert sample_transport.auth is auth def test_appsyncwebsocket_init_with_apikey_auth(): from gql.transport.appsync import ( - AppSyncApiKeyAuthorization, + AppSyncApiKeyAuthentication, AppSyncWebsocketsTransport, ) - authorization = AppSyncApiKeyAuthorization( - host=mock_transport_url, api_key="some-api-key" - ) - sample_transport = AppSyncWebsocketsTransport( - url=mock_transport_url, authorization=authorization - ) - assert sample_transport.authorization is authorization + auth = AppSyncApiKeyAuthentication(host=mock_transport_url, api_key="some-api-key") + sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, auth=auth) + assert sample_transport.auth is auth def test_appsyncwebsocket_init_with_iam_auth_without_creds(fake_session_factory): import botocore.exceptions from gql.transport.appsync import ( - AppSyncIAMAuthorization, + AppSyncIAMAuthentication, AppSyncWebsocketsTransport, ) - authorization = AppSyncIAMAuthorization( + auth = AppSyncIAMAuthentication( host=mock_transport_url, session=fake_session_factory(credentials=None), ) with pytest.raises(botocore.exceptions.NoCredentialsError): - AppSyncWebsocketsTransport(url=mock_transport_url, authorization=authorization) + AppSyncWebsocketsTransport(url=mock_transport_url, auth=auth) def test_appsyncwebsocket_init_with_iam_auth_with_creds(fake_credentials_factory): from gql.transport.appsync import ( - AppSyncIAMAuthorization, + AppSyncIAMAuthentication, AppSyncWebsocketsTransport, ) - authorization = AppSyncIAMAuthorization( + auth = AppSyncIAMAuthentication( host=mock_transport_url, credentials=fake_credentials_factory(), region_name="us-east-1", ) - sample_transport = AppSyncWebsocketsTransport( - url=mock_transport_url, authorization=authorization - ) - assert sample_transport.authorization is authorization + sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, auth=auth) + assert sample_transport.auth is auth def test_appsyncwebsocket_init_with_iam_auth_and_no_region( caplog, fake_credentials_factory ): from gql.transport.appsync import ( - AppSyncIAMAuthorization, + AppSyncIAMAuthentication, AppSyncWebsocketsTransport, ) with pytest.raises(TypeError): - authorization = AppSyncIAMAuthorization( + auth = AppSyncIAMAuthentication( host=mock_transport_url, credentials=fake_credentials_factory() ) - sample_transport = AppSyncWebsocketsTransport( - url=mock_transport_url, authorization=authorization - ) - assert sample_transport.authorization is None + sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, auth=auth) + assert sample_transport.auth is None print(f"Captured: {caplog.text}") @@ -125,20 +115,18 @@ def test_appsyncwebsocket_init_with_iam_auth_and_no_region( def test_munge_url(fake_signer_factory, fake_request_factory): from gql.transport.appsync import ( - AppSyncIAMAuthorization, + AppSyncIAMAuthentication, AppSyncWebsocketsTransport, ) test_url = "https://appsync-api.aws.example.org/some-other-params" - authorization = AppSyncIAMAuthorization( + auth = AppSyncIAMAuthentication( host=test_url, signer=fake_signer_factory(), request_creator=fake_request_factory, ) - sample_transport = AppSyncWebsocketsTransport( - url=test_url, authorization=authorization - ) + sample_transport = AppSyncWebsocketsTransport(url=test_url, auth=auth) header_string = ( "eyJGYWtlQXV0aG9yaXphdGlvbiI6ImEiLCJGYWtlVGltZSI6InRvZGF5" @@ -158,11 +146,11 @@ def test_munge_url_format( fake_credentials_factory, fake_session_factory, ): - from gql.transport.appsync import AppSyncIAMAuthorization + from gql.transport.appsync import AppSyncIAMAuthentication test_url = "https://appsync-api.aws.example.org/some-other-params" - authorization = AppSyncIAMAuthorization( + auth = AppSyncIAMAuthentication( host=test_url, signer=fake_signer_factory(), session=fake_session_factory(), @@ -179,4 +167,4 @@ def test_munge_url_format( "wss://appsync-realtime-api.aws.example.org/" f"some-other-params?header={header_string}&payload=e30=" ) - assert authorization.get_auth_url(test_url) == expected_url + assert auth.get_auth_url(test_url) == expected_url From 3b181587ca8f56062fe4eb7b88ef8488593839aa Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Fri, 3 Dec 2021 17:51:55 +0100 Subject: [PATCH 48/68] Add simulated AppSync backend tests Add keep_alive_timeout parameter Add error message if user is trying to use the transport with fetch_schema_from_transport=True --- gql/client.py | 6 + gql/transport/appsync.py | 13 +- tests/conftest.py | 2 +- tests/test_appsync_subscription.py | 706 +++++++++++++++++++++++++++++ 4 files changed, 719 insertions(+), 8 deletions(-) create mode 100644 tests/test_appsync_subscription.py diff --git a/gql/client.py b/gql/client.py index 111a3dd7..1f57e18a 100644 --- a/gql/client.py +++ b/gql/client.py @@ -82,6 +82,12 @@ def __init__( not schema ), "Cannot fetch the schema from transport if is already provided." + assert not type(transport).__name__ == "AppSyncWebsocketsTransport", ( + "fetch_schema_from_transport=True is not allowed " + "for AppSyncWebsocketsTransport " + "because only subscriptions are allowed on the realtime endpoint." + ) + if schema and not transport: transport = LocalSchemaTransport(schema) diff --git a/gql/transport/appsync.py b/gql/transport/appsync.py index ef73684e..ea9da9d1 100644 --- a/gql/transport/appsync.py +++ b/gql/transport/appsync.py @@ -46,7 +46,7 @@ def get_auth_url(self, url: str) -> str: @abstractmethod def get_headers(self, data: Optional[str] = None) -> Dict: - raise NotImplementedError() + raise NotImplementedError() # pragma: no cover class AppSyncApiKeyAuthentication(AppSyncAuthentication): @@ -186,6 +186,7 @@ def __init__( connect_timeout: int = 10, close_timeout: int = 10, ack_timeout: int = 10, + keep_alive_timeout: Optional[Union[int, float]] = None, connect_args: Dict[str, Any] = {}, ) -> None: """Initialize the transport with the given parameters. @@ -203,6 +204,8 @@ def __init__( this will wait forever. :param ack_timeout: Timeout in seconds to wait for the connection_ack message from the server. If None is provided this will wait forever. + :param keep_alive_timeout: Optional Timeout in seconds to receive + a sign of liveness from the server. :param connect_args: Other parameters forwarded to websockets.connect """ try: @@ -237,6 +240,7 @@ def __init__( connect_timeout=connect_timeout, close_timeout=close_timeout, ack_timeout=ack_timeout, + keep_alive_timeout=keep_alive_timeout, connect_args=connect_args, ) @@ -270,8 +274,6 @@ def _parse_answer( """ answer_type: str = "" - answer_id: Optional[int] = None - execution_result: Optional[ExecutionResult] = None try: json_answer = json.loads(answer) @@ -296,8 +298,6 @@ def _parse_answer( f"Server did not return a GraphQL result: {answer}" ) - return answer_type, answer_id, execution_result - async def _send_query( self, document: DocumentNode, @@ -359,8 +359,7 @@ async def execute( :raise: AssertionError""" raise AssertionError( "execute method is not allowed for AppSyncWebsocketsTransport " - "because only subscriptions are allowed on the realtime endpoint. " - "fetch_schema_from_transport should be set to False in the client!" + "because only subscriptions are allowed on the realtime endpoint." ) _initialize = WebsocketsTransport._initialize diff --git a/tests/conftest.py b/tests/conftest.py index ccedd7aa..a2653027 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -116,7 +116,7 @@ async def ssl_aiohttp_server(): for name in [ "websockets.legacy.server", "gql.transport.aiohttp", - "gql.transport.awsappsync", + "gql.transport.appsync", "gql.transport.phoenix_channel_websockets", "gql.transport.requests", "gql.transport.websockets", diff --git a/tests/test_appsync_subscription.py b/tests/test_appsync_subscription.py new file mode 100644 index 00000000..9d897c12 --- /dev/null +++ b/tests/test_appsync_subscription.py @@ -0,0 +1,706 @@ +import asyncio +import json +from base64 import b64decode +from typing import List +from urllib import parse + +import pytest + +from gql import Client, gql + +from .conftest import MS, WebSocketServerHelper + +# Marking all tests in this file with the appsync marker +pytestmark = pytest.mark.appsync + +SEND_MESSAGE_DELAY = 20 * MS +NB_MESSAGES = 10 + +DUMMY_API_KEY = "da2-thisisadummyapikey01234567" +DUMMY_ACCESS_KEY_ID = "DUMMYACCESSKEYID0123" +DUMMY_ACCESS_KEY_ID_NOT_ALLOWED = "DUMMYACCESSKEYID!ALL" +DUMMY_ACCESS_KEY_IDS = [DUMMY_ACCESS_KEY_ID, DUMMY_ACCESS_KEY_ID_NOT_ALLOWED] +DUMMY_SECRET_ACCESS_KEY = "ThisIsADummySecret0123401234012340123401" +DUMMY_SECRET_SESSION_TOKEN = ( + "FwoREDACTEDzEREDACTED+YREDACTEDJLREDACTEDz2REDACTEDH5RE" + "DACTEDbVREDACTEDqwREDACTEDHJREDACTEDxFREDACTEDtMREDACTED5kREDACTEDSwREDACTED0BRED" + "ACTEDuDREDACTEDm4REDACTEDSBREDACTEDaoREDACTEDP2REDACTEDCBREDACTED0wREDACTEDmdREDA" + "CTEDyhREDACTEDSKREDACTEDYbREDACTEDfeREDACTED3UREDACTEDaKREDACTEDi1REDACTEDGEREDAC" + "TED4VREDACTEDjmREDACTEDYcREDACTEDkQREDACTEDyI=" +) + +# List which can used to store received messages by the server +logged_messages: List[str] = [] + + +def realtime_appsync_server_factory( + keepalive=False, not_json_answer=False, error_without_id=False +): + def verify_headers(headers, in_query=False): + """Returns an error or None if all is ok""" + + if "x-api-key" in headers: + print("API KEY Authentication detected!") + + if headers["x-api-key"] == DUMMY_API_KEY: + return None + + elif "Authorization" in headers: + if "X-Amz-Security-Token" in headers: + with_token = True + print("IAM Authentication with token detected!") + else: + with_token = False + print("IAM Authentication with token detected!") + print("IAM Authentication without token detected!") + + assert headers["accept"] == "application/json, text/javascript" + assert headers["content-encoding"] == "amz-1.0" + assert headers["content-type"] == "application/json; charset=UTF-8" + assert "X-Amz-Date" in headers + + authorization_fields = headers["Authorization"].split(" ") + + assert authorization_fields[0] == "AWS4-HMAC-SHA256" + + credential_field = authorization_fields[1][:-1].split("=") + assert credential_field[0] == "Credential" + credential_content = credential_field[1].split("/") + assert credential_content[0] in DUMMY_ACCESS_KEY_IDS + + if in_query: + if credential_content[0] == DUMMY_ACCESS_KEY_ID_NOT_ALLOWED: + return { + "errorType": "UnauthorizedException", + "message": "Permission denied", + } + + # assert credential_content[1]== date + # assert credential_content[2]== region + assert credential_content[3] == "appsync" + assert credential_content[4] == "aws4_request" + + signed_headers_field = authorization_fields[2][:-1].split("=") + + assert signed_headers_field[0] == "SignedHeaders" + signed_headers = signed_headers_field[1].split(";") + + assert "accept" in signed_headers + assert "content-encoding" in signed_headers + assert "content-type" in signed_headers + assert "host" in signed_headers + assert "x-amz-date" in signed_headers + + if with_token: + assert "x-amz-security-token" in signed_headers + + signature_field = authorization_fields[3].split("=") + + assert signature_field[0] == "Signature" + + return None + + return { + "errorType": "com.amazonaws.deepdish.graphql.auth#UnauthorizedException", + "message": "You are not authorized to make this call.", + "errorCode": 400, + } + + async def realtime_appsync_server_template(ws, path): + import websockets + + logged_messages.clear() + + try: + if not_json_answer: + await ws.send("Something not json") + return + + if error_without_id: + await ws.send( + json.dumps( + { + "type": "error", + "payload": { + "errors": [ + { + "errorType": "Error without id", + "message": ( + "Sometimes AppSync will send you " + "an error without an id" + ), + } + ] + }, + }, + separators=(",", ":"), + ) + ) + return + + print(f"path = {path}") + + path_base, parameters_str = path.split("?") + + assert path_base == "/graphql" + + parameters = parse.parse_qs(parameters_str) + + header_param = parameters["header"][0] + payload_param = parameters["payload"][0] + + assert payload_param == "e30=" + + headers = json.loads(b64decode(header_param).decode()) + + print("\nHeaders received in URL:") + for key, value in headers.items(): + print(f" {key}: {value}") + print("\n") + + error = verify_headers(headers) + + if error is not None: + await ws.send( + json.dumps( + {"payload": {"errors": [error]}, "type": "connection_error"}, + separators=(",", ":"), + ) + ) + return + + await WebSocketServerHelper.send_connection_ack( + ws, payload='{"connectionTimeoutMs":300000}' + ) + + result = await ws.recv() + logged_messages.append(result) + + json_result = json.loads(result) + + query_id = json_result["id"] + assert json_result["type"] == "start" + + payload = json_result["payload"] + + # With appsync, the data field is serialized to string + data_str = payload["data"] + extensions = payload["extensions"] + + data = json.loads(data_str) + + query = data["query"] + variables = data.get("variables", None) + operation_name = data.get("operationName", None) + print(f"Received query: {query}") + print(f"Received variables: {variables}") + print(f"Received operation_name: {operation_name}") + + authorization = extensions["authorization"] + print("\nHeaders received in the extensions of the query:") + for key, value in authorization.items(): + print(f" {key}: {value}") + print("\n") + + error = verify_headers(headers, in_query=True) + + if error is not None: + await ws.send( + json.dumps( + { + "id": str(query_id), + "type": "error", + "payload": {"errors": [error]}, + }, + separators=(",", ":"), + ) + ) + return + + await ws.send( + json.dumps( + {"id": str(query_id), "type": "start_ack"}, separators=(",", ":") + ) + ) + + async def send_message_coro(): + print(" Server: send message task started") + try: + for number in range(NB_MESSAGES): + payload = { + "data": { + "onCreateMessage": {"message": f"Hello world {number}!"} + } + } + + if operation_name or variables: + + payload["extensions"] = {} + + if operation_name: + payload["extensions"]["operation_name"] = operation_name + if variables: + payload["extensions"]["variables"] = variables + + await ws.send( + json.dumps( + { + "id": str(query_id), + "type": "data", + "payload": payload, + }, + separators=(",", ":"), + ) + ) + await asyncio.sleep(SEND_MESSAGE_DELAY) + finally: + print(" Server: send message task ended") + + print(" Server: starting send message task") + send_message_task = asyncio.ensure_future(send_message_coro()) + + async def keepalive_coro(): + while True: + await asyncio.sleep(5 * MS) + try: + await WebSocketServerHelper.send_keepalive(ws) + except websockets.exceptions.ConnectionClosed: + break + + if keepalive: + print(" Server: starting keepalive task") + keepalive_task = asyncio.ensure_future(keepalive_coro()) + + async def receiving_coro(): + print(" Server: receiving task started") + try: + nonlocal send_message_task + while True: + + try: + result = await ws.recv() + logged_messages.append(result) + except websockets.exceptions.ConnectionClosed: + break + + finally: + print(" Server: receiving task ended") + if keepalive: + keepalive_task.cancel() + + print(" Server: starting receiving task") + receiving_task = asyncio.ensure_future(receiving_coro()) + + try: + print( + " Server: waiting for sending message task to complete" + ) + await send_message_task + except asyncio.CancelledError: + print(" Server: Now sending message task is cancelled") + + print(" Server: sending complete message") + await WebSocketServerHelper.send_complete(ws, query_id) + + if keepalive: + print(" Server: cancelling keepalive task") + keepalive_task.cancel() + try: + await keepalive_task + except asyncio.CancelledError: + print(" Server: Now keepalive task is cancelled") + + print(" Server: waiting for client to close the connection") + try: + await asyncio.wait_for(receiving_task, 1000 * MS) + except asyncio.TimeoutError: + pass + + print(" Server: cancelling receiving task") + receiving_task.cancel() + + try: + await receiving_task + except asyncio.CancelledError: + print(" Server: Now receiving task is cancelled") + + except websockets.exceptions.ConnectionClosedOK: + pass + except AssertionError as e: + print(f"\n Server: Assertion failed: {e!s}\n") + except Exception as e: + print(f"\n Server: Exception received: {e!s}\n") + finally: + print(" Server: waiting for websocket connection to close") + await ws.wait_closed() + print(" Server: connection closed") + + return realtime_appsync_server_template + + +async def realtime_appsync_server(ws, path): + + server = realtime_appsync_server_factory() + await server(ws, path) + + +async def realtime_appsync_server_keepalive(ws, path): + + server = realtime_appsync_server_factory(keepalive=True) + await server(ws, path) + + +async def realtime_appsync_server_not_json_answer(ws, path): + + server = realtime_appsync_server_factory(not_json_answer=True) + await server(ws, path) + + +async def realtime_appsync_server_error_without_id(ws, path): + + server = realtime_appsync_server_factory(error_without_id=True) + await server(ws, path) + + +on_create_message_subscription_str = """ +subscription onCreateMessage { + onCreateMessage { + message + } +} +""" + + +async def default_transport_test(transport): + client = Client(transport=transport) + + expected_messages = [f"Hello world {number}!" for number in range(NB_MESSAGES)] + received_messages = [] + + async with client as session: + subscription = gql(on_create_message_subscription_str) + + async for result in session.subscribe(subscription): + + message = result["onCreateMessage"]["message"] + print(f"Message received: '{message}'") + + received_messages.append(message) + + assert expected_messages == received_messages + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [realtime_appsync_server_keepalive], indirect=True) +async def test_appsync_subscription_api_key(event_loop, server): + + from gql.transport.appsync import ( + AppSyncWebsocketsTransport, + AppSyncApiKeyAuthentication, + ) + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + + auth = AppSyncApiKeyAuthentication(host=server.hostname, api_key=DUMMY_API_KEY) + + transport = AppSyncWebsocketsTransport( + url=url, auth=auth, keep_alive_timeout=(5 * SEND_MESSAGE_DELAY) + ) + + await default_transport_test(transport) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) +async def test_appsync_subscription_iam_with_token(event_loop, server): + + from gql.transport.appsync import ( + AppSyncWebsocketsTransport, + AppSyncIAMAuthentication, + ) + from botocore.credentials import Credentials + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + + dummy_credentials = Credentials( + access_key=DUMMY_ACCESS_KEY_ID, + secret_key=DUMMY_SECRET_ACCESS_KEY, + token=DUMMY_SECRET_SESSION_TOKEN, + ) + + auth = AppSyncIAMAuthentication(host=server.hostname, credentials=dummy_credentials) + + transport = AppSyncWebsocketsTransport(url=url, auth=auth) + + await default_transport_test(transport) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) +async def test_appsync_subscription_iam_without_token(event_loop, server): + + from gql.transport.appsync import ( + AppSyncWebsocketsTransport, + AppSyncIAMAuthentication, + ) + from botocore.credentials import Credentials + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + + dummy_credentials = Credentials( + access_key=DUMMY_ACCESS_KEY_ID, secret_key=DUMMY_SECRET_ACCESS_KEY, + ) + + auth = AppSyncIAMAuthentication(host=server.hostname, credentials=dummy_credentials) + + transport = AppSyncWebsocketsTransport(url=url, auth=auth) + + await default_transport_test(transport) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) +async def test_appsync_execute_method_not_allowed(event_loop, server): + + from gql.transport.appsync import ( + AppSyncWebsocketsTransport, + AppSyncIAMAuthentication, + ) + from botocore.credentials import Credentials + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + + dummy_credentials = Credentials( + access_key=DUMMY_ACCESS_KEY_ID, secret_key=DUMMY_SECRET_ACCESS_KEY, + ) + + auth = AppSyncIAMAuthentication(host=server.hostname, credentials=dummy_credentials) + + transport = AppSyncWebsocketsTransport(url=url, auth=auth) + + client = Client(transport=transport) + + async with client as session: + query = gql( + """ +mutation createMessage($message: String!) { + createMessage(input: {message: $message}) { + id + message + createdAt + } +}""" + ) + + variable_values = {"message": "Hello world!"} + + with pytest.raises(AssertionError) as exc_info: + await session.execute(query, variable_values=variable_values) + + assert ( + "execute method is not allowed for AppSyncWebsocketsTransport " + "because only subscriptions are allowed on the realtime endpoint." + ) in str(exc_info) + + +@pytest.mark.asyncio +async def test_appsync_fetch_schema_from_transport_not_allowed(event_loop): + + from gql.transport.appsync import ( + AppSyncWebsocketsTransport, + AppSyncIAMAuthentication, + ) + from botocore.credentials import Credentials + + dummy_credentials = Credentials( + access_key=DUMMY_ACCESS_KEY_ID, secret_key=DUMMY_SECRET_ACCESS_KEY, + ) + + auth = AppSyncIAMAuthentication(host="something", credentials=dummy_credentials) + + transport = AppSyncWebsocketsTransport(url="https://something", auth=auth) + + with pytest.raises(AssertionError) as exc_info: + Client(transport=transport, fetch_schema_from_transport=True) + + assert ( + "fetch_schema_from_transport=True is not allowed for AppSyncWebsocketsTransport" + " because only subscriptions are allowed on the realtime endpoint." + ) in str(exc_info) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) +async def test_appsync_subscription_api_key_unauthorized(event_loop, server): + + from gql.transport.appsync import ( + AppSyncWebsocketsTransport, + AppSyncApiKeyAuthentication, + ) + from gql.transport.exceptions import TransportServerError + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + + auth = AppSyncApiKeyAuthentication(host=server.hostname, api_key="invalid") + + transport = AppSyncWebsocketsTransport(url=url, auth=auth) + + client = Client(transport=transport) + + with pytest.raises(TransportServerError) as exc_info: + async with client as _: + pass + + assert "You are not authorized to make this call." in str(exc_info) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) +async def test_appsync_subscription_api_key_not_allowed(event_loop, server): + + from gql.transport.appsync import ( + AppSyncWebsocketsTransport, + AppSyncIAMAuthentication, + ) + from gql.transport.exceptions import TransportQueryError + from botocore.credentials import Credentials + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + + dummy_credentials = Credentials( + access_key=DUMMY_ACCESS_KEY_ID_NOT_ALLOWED, + secret_key=DUMMY_SECRET_ACCESS_KEY, + token=DUMMY_SECRET_SESSION_TOKEN, + ) + + auth = AppSyncIAMAuthentication(host=server.hostname, credentials=dummy_credentials) + + transport = AppSyncWebsocketsTransport(url=url, auth=auth) + + client = Client(transport=transport) + + async with client as session: + subscription = gql(on_create_message_subscription_str) + + with pytest.raises(TransportQueryError) as exc_info: + + async for result in session.subscribe(subscription): + pass + + assert "Permission denied" in str(exc_info) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "server", [realtime_appsync_server_not_json_answer], indirect=True +) +async def test_appsync_subscription_server_sending_a_not_json_answer( + event_loop, server +): + + from gql.transport.appsync import ( + AppSyncWebsocketsTransport, + AppSyncApiKeyAuthentication, + ) + from gql.transport.exceptions import TransportProtocolError + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + + auth = AppSyncApiKeyAuthentication(host=server.hostname, api_key=DUMMY_API_KEY) + + transport = AppSyncWebsocketsTransport(url=url, auth=auth) + + client = Client(transport=transport) + + with pytest.raises(TransportProtocolError) as exc_info: + async with client as _: + pass + + assert "Server did not return a GraphQL result: Something not json" in str(exc_info) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "server", [realtime_appsync_server_error_without_id], indirect=True +) +async def test_appsync_subscription_server_sending_an_error_without_an_id( + event_loop, server +): + + from gql.transport.appsync import ( + AppSyncWebsocketsTransport, + AppSyncApiKeyAuthentication, + ) + from gql.transport.exceptions import TransportServerError + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + + auth = AppSyncApiKeyAuthentication(host=server.hostname, api_key=DUMMY_API_KEY) + + transport = AppSyncWebsocketsTransport(url=url, auth=auth) + + client = Client(transport=transport) + + with pytest.raises(TransportServerError) as exc_info: + async with client as _: + pass + + assert "Sometimes AppSync will send you an error without an id" in str(exc_info) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [realtime_appsync_server_keepalive], indirect=True) +async def test_appsync_subscription_variable_values_and_operation_name( + event_loop, server +): + + from gql.transport.appsync import ( + AppSyncWebsocketsTransport, + AppSyncApiKeyAuthentication, + ) + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + + auth = AppSyncApiKeyAuthentication(host=server.hostname, api_key=DUMMY_API_KEY) + + transport = AppSyncWebsocketsTransport( + url=url, auth=auth, keep_alive_timeout=(5 * SEND_MESSAGE_DELAY) + ) + + client = Client(transport=transport) + + expected_messages = [f"Hello world {number}!" for number in range(NB_MESSAGES)] + received_messages = [] + + async with client as session: + subscription = gql(on_create_message_subscription_str) + + async for execution_result in session.subscribe( + subscription, + operation_name="onCreateMessage", + variable_values={"key1": "val1"}, + get_execution_result=True, + ): + + result = execution_result.data + message = result["onCreateMessage"]["message"] + print(f"Message received: '{message}'") + + received_messages.append(message) + + print(f"extensions received: {execution_result.extensions}") + + assert execution_result.extensions["operation_name"] == "onCreateMessage" + variables = execution_result.extensions["variables"] + assert variables["key1"] == "val1" + + assert expected_messages == received_messages From 4cec1ddab4b5d4877340316dddb3cca869b3486d Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Fri, 3 Dec 2021 18:32:41 +0100 Subject: [PATCH 49/68] Only one auth class for JWT --- docs/transports/appsync.rst | 12 ++++++------ gql/transport/appsync.py | 15 +++++++-------- tests/test_appsyncwebsocket.py | 30 +++++++++++++++++++----------- 3 files changed, 32 insertions(+), 25 deletions(-) diff --git a/docs/transports/appsync.rst b/docs/transports/appsync.rst index 8f48a232..cae50e7a 100644 --- a/docs/transports/appsync.rst +++ b/docs/transports/appsync.rst @@ -45,14 +45,14 @@ IAM Reference: :class:`gql.transport.appsync.AppSyncIAMAuthentication` -Amazon Cognito user pools -^^^^^^^^^^^^^^^^^^^^^^^^^ +Json Web Tokens (jwt) +^^^^^^^^^^^^^^^^^^^^^ -Reference: :class:`gql.transport.appsync.AppSyncOIDCAuthentication` +AWS provides json web tokens (jwt) for the authentication methods: -OpenID Connect (OIDC) -^^^^^^^^^^^^^^^^^^^^^ +- Amazon Cognito user pools +- OpenID Connect (OIDC) -Reference: :class:`gql.transport.appsync.AppSyncCognitoUserPoolAuthentication` +Reference: :class:`gql.transport.appsync.AppSyncJWTAuthentication` .. _Building a real-time websocket client: https://docs.aws.amazon.com/appsync/latest/devguide/real-time-websocket-client.html diff --git a/gql/transport/appsync.py b/gql/transport/appsync.py index ea9da9d1..f5e673b9 100644 --- a/gql/transport/appsync.py +++ b/gql/transport/appsync.py @@ -65,8 +65,13 @@ def get_headers(self, data: Optional[str] = None) -> Dict: return {"host": self._host, "x-api-key": self.api_key} -class AppSyncOIDCAuthentication(AppSyncAuthentication): - """AWS authentication class using an OpenID JWT access token""" +class AppSyncJWTAuthentication(AppSyncAuthentication): + """AWS authentication class using a JWT access token. + + It can be used either for: + - Amazon Cognito user pools + - OpenID Connect (OIDC) + """ def __init__(self, host: str, jwt: str) -> None: """ @@ -81,12 +86,6 @@ def get_headers(self, data: Optional[str] = None) -> Dict: return {"host": self._host, "Authorization": self.jwt} -class AppSyncCognitoUserPoolAuthentication(AppSyncOIDCAuthentication): - """AWS authentication class using a Cognito user pools JWT access token""" - - pass - - class AppSyncIAMAuthentication(AppSyncAuthentication): """AWS authentication class using IAM. diff --git a/tests/test_appsyncwebsocket.py b/tests/test_appsyncwebsocket.py index 20a6f348..c08fab30 100644 --- a/tests/test_appsyncwebsocket.py +++ b/tests/test_appsyncwebsocket.py @@ -6,7 +6,7 @@ mock_transport_url = "https://appsyncapp.awsgateway.com.example.org" -def test_appsyncwebsocket_init_with_minimal_args(fake_session_factory): +def test_appsync_init_with_minimal_args(fake_session_factory): from gql.transport.appsync import ( AppSyncIAMAuthentication, AppSyncWebsocketsTransport, @@ -23,7 +23,7 @@ def test_appsyncwebsocket_init_with_minimal_args(fake_session_factory): assert sample_transport.connect_args == {} -def test_appsyncwebsocket_init_with_no_credentials(caplog, fake_session_factory): +def test_appsync_init_with_no_credentials(caplog, fake_session_factory): import botocore.exceptions from gql.transport.appsync import AppSyncWebsocketsTransport @@ -40,18 +40,23 @@ def test_appsyncwebsocket_init_with_no_credentials(caplog, fake_session_factory) assert expected_error in caplog.text -def test_appsyncwebsocket_init_with_oidc_auth(): +def test_appsync_init_with_jwt_auth(): from gql.transport.appsync import ( - AppSyncOIDCAuthentication, + AppSyncJWTAuthentication, AppSyncWebsocketsTransport, ) - auth = AppSyncOIDCAuthentication(host=mock_transport_url, jwt="some-jwt") + auth = AppSyncJWTAuthentication(host=mock_transport_url, jwt="some-jwt") sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, auth=auth) assert sample_transport.auth is auth + assert auth.get_headers() == { + "host": mock_transport_url, + "Authorization": "some-jwt", + } -def test_appsyncwebsocket_init_with_apikey_auth(): + +def test_appsync_init_with_apikey_auth(): from gql.transport.appsync import ( AppSyncApiKeyAuthentication, AppSyncWebsocketsTransport, @@ -61,8 +66,13 @@ def test_appsyncwebsocket_init_with_apikey_auth(): sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, auth=auth) assert sample_transport.auth is auth + assert auth.get_headers() == { + "host": mock_transport_url, + "x-api-key": "some-api-key", + } + -def test_appsyncwebsocket_init_with_iam_auth_without_creds(fake_session_factory): +def test_appsync_init_with_iam_auth_without_creds(fake_session_factory): import botocore.exceptions from gql.transport.appsync import ( AppSyncIAMAuthentication, @@ -76,7 +86,7 @@ def test_appsyncwebsocket_init_with_iam_auth_without_creds(fake_session_factory) AppSyncWebsocketsTransport(url=mock_transport_url, auth=auth) -def test_appsyncwebsocket_init_with_iam_auth_with_creds(fake_credentials_factory): +def test_appsync_init_with_iam_auth_with_creds(fake_credentials_factory): from gql.transport.appsync import ( AppSyncIAMAuthentication, AppSyncWebsocketsTransport, @@ -91,9 +101,7 @@ def test_appsyncwebsocket_init_with_iam_auth_with_creds(fake_credentials_factory assert sample_transport.auth is auth -def test_appsyncwebsocket_init_with_iam_auth_and_no_region( - caplog, fake_credentials_factory -): +def test_appsync_init_with_iam_auth_and_no_region(caplog, fake_credentials_factory): from gql.transport.appsync import ( AppSyncIAMAuthentication, AppSyncWebsocketsTransport, From 35af8ce548d06433e5e1ff430b69f504ab8e7829 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Fri, 3 Dec 2021 18:39:42 +0100 Subject: [PATCH 50/68] Fix tests host/url confusion --- tests/test_appsyncwebsocket.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/tests/test_appsyncwebsocket.py b/tests/test_appsyncwebsocket.py index c08fab30..40772d83 100644 --- a/tests/test_appsyncwebsocket.py +++ b/tests/test_appsyncwebsocket.py @@ -3,7 +3,8 @@ # Marking all tests in this file with the appsync marker pytestmark = pytest.mark.appsync -mock_transport_url = "https://appsyncapp.awsgateway.com.example.org" +mock_transport_host = "appsyncapp.awsgateway.com.example.org" +mock_transport_url = f"https://{mock_transport_host}/graphql" def test_appsync_init_with_minimal_args(fake_session_factory): @@ -46,12 +47,12 @@ def test_appsync_init_with_jwt_auth(): AppSyncWebsocketsTransport, ) - auth = AppSyncJWTAuthentication(host=mock_transport_url, jwt="some-jwt") + auth = AppSyncJWTAuthentication(host=mock_transport_host, jwt="some-jwt") sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, auth=auth) assert sample_transport.auth is auth assert auth.get_headers() == { - "host": mock_transport_url, + "host": mock_transport_host, "Authorization": "some-jwt", } @@ -62,12 +63,12 @@ def test_appsync_init_with_apikey_auth(): AppSyncWebsocketsTransport, ) - auth = AppSyncApiKeyAuthentication(host=mock_transport_url, api_key="some-api-key") + auth = AppSyncApiKeyAuthentication(host=mock_transport_host, api_key="some-api-key") sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, auth=auth) assert sample_transport.auth is auth assert auth.get_headers() == { - "host": mock_transport_url, + "host": mock_transport_host, "x-api-key": "some-api-key", } @@ -80,7 +81,7 @@ def test_appsync_init_with_iam_auth_without_creds(fake_session_factory): ) auth = AppSyncIAMAuthentication( - host=mock_transport_url, session=fake_session_factory(credentials=None), + host=mock_transport_host, session=fake_session_factory(credentials=None), ) with pytest.raises(botocore.exceptions.NoCredentialsError): AppSyncWebsocketsTransport(url=mock_transport_url, auth=auth) @@ -93,7 +94,7 @@ def test_appsync_init_with_iam_auth_with_creds(fake_credentials_factory): ) auth = AppSyncIAMAuthentication( - host=mock_transport_url, + host=mock_transport_host, credentials=fake_credentials_factory(), region_name="us-east-1", ) @@ -109,7 +110,7 @@ def test_appsync_init_with_iam_auth_and_no_region(caplog, fake_credentials_facto with pytest.raises(TypeError): auth = AppSyncIAMAuthentication( - host=mock_transport_url, credentials=fake_credentials_factory() + host=mock_transport_host, credentials=fake_credentials_factory() ) sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, auth=auth) assert sample_transport.auth is None From bfe99a56259ca5883e68993f633b1001d72f8c21 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Fri, 3 Dec 2021 18:41:30 +0100 Subject: [PATCH 51/68] rename test_appsyncwebsocket.py to test_appsync_auth.py --- tests/{test_appsyncwebsocket.py => test_appsync_auth.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/{test_appsyncwebsocket.py => test_appsync_auth.py} (100%) diff --git a/tests/test_appsyncwebsocket.py b/tests/test_appsync_auth.py similarity index 100% rename from tests/test_appsyncwebsocket.py rename to tests/test_appsync_auth.py From 5efc6609537ad59accab6d02f20869386644d6e9 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Fri, 3 Dec 2021 19:19:27 +0100 Subject: [PATCH 52/68] small modif: print Region found if region test fails --- tests/test_appsync_auth.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_appsync_auth.py b/tests/test_appsync_auth.py index 40772d83..046e3867 100644 --- a/tests/test_appsync_auth.py +++ b/tests/test_appsync_auth.py @@ -112,8 +112,9 @@ def test_appsync_init_with_iam_auth_and_no_region(caplog, fake_credentials_facto auth = AppSyncIAMAuthentication( host=mock_transport_host, credentials=fake_credentials_factory() ) - sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, auth=auth) - assert sample_transport.auth is None + AppSyncWebsocketsTransport(url=mock_transport_url, auth=auth) + + print(f"Region found: {auth._region_name}") print(f"Captured: {caplog.text}") From f3446e60c01f32f86aa367341abea05aa10f54fe Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Fri, 3 Dec 2021 20:24:21 +0100 Subject: [PATCH 53/68] Add region_name in tests --- tests/test_appsync_subscription.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/tests/test_appsync_subscription.py b/tests/test_appsync_subscription.py index 9d897c12..99a0ed17 100644 --- a/tests/test_appsync_subscription.py +++ b/tests/test_appsync_subscription.py @@ -28,6 +28,7 @@ "CTEDyhREDACTEDSKREDACTEDYbREDACTEDfeREDACTED3UREDACTEDaKREDACTEDi1REDACTEDGEREDAC" "TED4VREDACTEDjmREDACTEDYcREDACTEDkQREDACTEDyI=" ) +REGION_NAME = "eu-west-3" # List which can used to store received messages by the server logged_messages: List[str] = [] @@ -430,7 +431,9 @@ async def test_appsync_subscription_iam_with_token(event_loop, server): token=DUMMY_SECRET_SESSION_TOKEN, ) - auth = AppSyncIAMAuthentication(host=server.hostname, credentials=dummy_credentials) + auth = AppSyncIAMAuthentication( + host=server.hostname, credentials=dummy_credentials, region_name=REGION_NAME + ) transport = AppSyncWebsocketsTransport(url=url, auth=auth) @@ -454,7 +457,9 @@ async def test_appsync_subscription_iam_without_token(event_loop, server): access_key=DUMMY_ACCESS_KEY_ID, secret_key=DUMMY_SECRET_ACCESS_KEY, ) - auth = AppSyncIAMAuthentication(host=server.hostname, credentials=dummy_credentials) + auth = AppSyncIAMAuthentication( + host=server.hostname, credentials=dummy_credentials, region_name=REGION_NAME + ) transport = AppSyncWebsocketsTransport(url=url, auth=auth) @@ -478,7 +483,9 @@ async def test_appsync_execute_method_not_allowed(event_loop, server): access_key=DUMMY_ACCESS_KEY_ID, secret_key=DUMMY_SECRET_ACCESS_KEY, ) - auth = AppSyncIAMAuthentication(host=server.hostname, credentials=dummy_credentials) + auth = AppSyncIAMAuthentication( + host=server.hostname, credentials=dummy_credentials, region_name=REGION_NAME + ) transport = AppSyncWebsocketsTransport(url=url, auth=auth) @@ -520,7 +527,9 @@ async def test_appsync_fetch_schema_from_transport_not_allowed(event_loop): access_key=DUMMY_ACCESS_KEY_ID, secret_key=DUMMY_SECRET_ACCESS_KEY, ) - auth = AppSyncIAMAuthentication(host="something", credentials=dummy_credentials) + auth = AppSyncIAMAuthentication( + host="something", credentials=dummy_credentials, region_name=REGION_NAME + ) transport = AppSyncWebsocketsTransport(url="https://something", auth=auth) @@ -579,7 +588,9 @@ async def test_appsync_subscription_api_key_not_allowed(event_loop, server): token=DUMMY_SECRET_SESSION_TOKEN, ) - auth = AppSyncIAMAuthentication(host=server.hostname, credentials=dummy_credentials) + auth = AppSyncIAMAuthentication( + host=server.hostname, credentials=dummy_credentials, region_name=REGION_NAME + ) transport = AppSyncWebsocketsTransport(url=url, auth=auth) From 69147e66e61820c478a085ac35da2ec5acc25bb7 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sat, 4 Dec 2021 00:24:21 +0100 Subject: [PATCH 54/68] By default try to get the region from the host --- gql/transport/appsync.py | 48 +++++++++++++++++++++++++++++++------- tests/test_appsync_auth.py | 39 +++++++++++++++++++++---------- 2 files changed, 66 insertions(+), 21 deletions(-) diff --git a/gql/transport/appsync.py b/gql/transport/appsync.py index f5e673b9..e29dfcf7 100644 --- a/gql/transport/appsync.py +++ b/gql/transport/appsync.py @@ -1,5 +1,6 @@ import json import logging +import re from abc import ABC, abstractmethod from base64 import b64encode from ssl import SSLContext @@ -10,7 +11,7 @@ from botocore.auth import BaseSigner, SigV4Auth from botocore.awsrequest import AWSRequest, create_request_object from botocore.credentials import Credentials -from botocore.exceptions import NoCredentialsError +from botocore.exceptions import NoCredentialsError, NoRegionError from botocore.session import get_session from graphql import DocumentNode, ExecutionResult, print_ast @@ -118,10 +119,8 @@ def __init__( self._credentials = ( credentials if credentials else self._session.get_credentials() ) - self._region_name = self._session._resolve_region_name( - region_name, self._session.get_default_client_config() - ) self._service_name = "appsync" + self._region_name = region_name or self._detect_region_name() self._signer = ( signer if signer @@ -131,6 +130,37 @@ def __init__( request_creator if request_creator else create_request_object ) + def _detect_region_name(self): + """Try to detect the correct region_name. + + First try to extract the region_name from the host. + + If that does not work, then try to get the region_name from + the aws configuration (~/.aws/config file) or the AWS_DEFAULT_REGION + environment variable. + + If no region_name was found, then raise a NoRegionError exception.""" + + # Regular expression from botocore.utils.validate_region + m = re.search( + r"appsync-api\.((?![0-9]+$)(?!-)[a-zA-Z0-9-]{,63}(? Dict: headers = { @@ -221,15 +251,15 @@ def __init__( except NoCredentialsError: log.warning( - "Credentials not found. " + "Credentials not found. " "Do you have default AWS credentials configured?", ) raise - except TypeError: + except NoRegionError: log.warning( - "A TypeError was raised. " - "The most likely reason for this is that the AWS " - "region is missing from the credentials.", + "Region name not found. " + "It was not possible to detect your region either from the host " + "or from your default AWS configuration." ) raise diff --git a/tests/test_appsync_auth.py b/tests/test_appsync_auth.py index 046e3867..981d8f13 100644 --- a/tests/test_appsync_auth.py +++ b/tests/test_appsync_auth.py @@ -102,23 +102,38 @@ def test_appsync_init_with_iam_auth_with_creds(fake_credentials_factory): assert sample_transport.auth is auth -def test_appsync_init_with_iam_auth_and_no_region(caplog, fake_credentials_factory): - from gql.transport.appsync import ( - AppSyncIAMAuthentication, - AppSyncWebsocketsTransport, - ) +def test_appsync_init_with_iam_auth_and_no_region( + caplog, fake_credentials_factory, fake_session_factory +): + """ - with pytest.raises(TypeError): - auth = AppSyncIAMAuthentication( - host=mock_transport_host, credentials=fake_credentials_factory() - ) - AppSyncWebsocketsTransport(url=mock_transport_url, auth=auth) + WARNING: this test will fail if: + - you have a default region set in ~/.aws/config + - you have the AWS_DEFAULT_REGION environment variable set + + """ + from gql.transport.appsync import AppSyncWebsocketsTransport + from botocore.exceptions import NoRegionError + import logging - print(f"Region found: {auth._region_name}") + caplog.set_level(logging.WARNING) + + with pytest.raises(NoRegionError): + session = fake_session_factory(credentials=fake_credentials_factory()) + session._region_name = None + session._credentials.region = None + transport = AppSyncWebsocketsTransport(url=mock_transport_url, session=session) + + # prints the region name in case the test fails + print(f"Region found: {transport.auth._region_name}") print(f"Captured: {caplog.text}") - expected_error = "the AWS region is missing from the credentials" + expected_error = ( + "Region name not found. " + "It was not possible to detect your region either from the host " + "or from your default AWS configuration." + ) assert expected_error in caplog.text From b271a1cd109e47ef4711dca42f8186e6b0c22a10 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sat, 4 Dec 2021 01:09:59 +0100 Subject: [PATCH 55/68] Improve documentation --- docs/transports/appsync.rst | 85 ++++++++++++++++++++++++++++++++++++- 1 file changed, 84 insertions(+), 1 deletion(-) diff --git a/docs/transports/appsync.rst b/docs/transports/appsync.rst index cae50e7a..8da6ab93 100644 --- a/docs/transports/appsync.rst +++ b/docs/transports/appsync.rst @@ -29,6 +29,7 @@ Full example with API key authentication from environment variables: Reference: :class:`gql.transport.appsync.AppSyncWebsocketsTransport` +.. _Building a real-time websocket client: https://docs.aws.amazon.com/appsync/latest/devguide/real-time-websocket-client.html .. _appsync_authentication_methods: @@ -38,11 +39,70 @@ Authentication methods API key ^^^^^^^ +Use the :code:`AppSyncApiKeyAuthentication` class to provide your API key: + +.. code-block:: python + + auth = AppSyncApiKeyAuthentication( + host="XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com", + api_key="YOUR_API_KEY", + ) + + transport = AppSyncWebsocketsTransport( + url="https://XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com/graphql", + auth=auth, + ) + Reference: :class:`gql.transport.appsync.AppSyncApiKeyAuthentication` IAM ^^^ +For the IAM authentication, you can simply create your transport without +an auth argument. + +The region name will be autodetected from the url or from your AWS configuration +(:code:`.aws/config`) or the environment variable: + +- AWS_DEFAULT_REGION + +The credentials will be detected from your AWS configuration file +(:code:`.aws/credentials`) or from the environment variables: + +- AWS_ACCESS_KEY_ID +- AWS_SECRET_ACCESS_KEY +- AWS_SESSION_TOKEN (optional) + +.. code-block:: python + + transport = AppSyncWebsocketsTransport( + url="https://XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com/graphql", + ) + +OR You can also provide the credentials manually by creating the +:code:`AppSyncIAMAuthentication` class yourself: + +.. code-block:: python + + from botocore.credentials import Credentials + + credentials = Credentials( + access_key = os.environ.get("AWS_ACCESS_KEY_ID"), + secret_key= os.environ.get("AWS_SECRET_ACCESS_KEY"), + token=os.environ.get("AWS_SESSION_TOKEN", None), # Optional + ) + + auth = AppSyncIAMAuthentication( + host="XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com", + credentials=credentials, + region_name="your region" + ) + + transport = AppSyncWebsocketsTransport( + url="https://XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com/graphql", + auth=auth, + ) + Reference: :class:`gql.transport.appsync.AppSyncIAMAuthentication` Json Web Tokens (jwt) @@ -53,6 +113,29 @@ AWS provides json web tokens (jwt) for the authentication methods: - Amazon Cognito user pools - OpenID Connect (OIDC) +For these authentication methods, you can use the :code:`AppSyncJWTAuthentication` class: + +.. code-block:: python + + auth = AppSyncJWTAuthentication( + host="XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com", + jwt="YOUR_JWT_STRING", + ) + + transport = AppSyncWebsocketsTransport( + url="https://XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com/graphql", + auth=auth, + ) + Reference: :class:`gql.transport.appsync.AppSyncJWTAuthentication` -.. _Building a real-time websocket client: https://docs.aws.amazon.com/appsync/latest/devguide/real-time-websocket-client.html +AppSync GraphQL Queries and mutations +------------------------------------- + +Queries and mutations are not allowed on the realtime websockets endpoint. +But you can use the :ref:`AIOHTTPTransport ` to create +a normal http session and reuse the authentication classes to create the headers for you. + +Full example with API key authentication from environment variables: + +.. literalinclude:: ../code_examples/aws_api_key_mutation.py From 1f6f2c4bcfda940defeeb39a0015140af2d803c4 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sat, 4 Dec 2021 01:14:32 +0100 Subject: [PATCH 56/68] api_key is not needed in iam example --- docs/code_examples/aws_iam_subscription.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/code_examples/aws_iam_subscription.py b/docs/code_examples/aws_iam_subscription.py index e891fc44..a7abba19 100644 --- a/docs/code_examples/aws_iam_subscription.py +++ b/docs/code_examples/aws_iam_subscription.py @@ -14,9 +14,8 @@ async def main(): # Should look like: # https://XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com/graphql url = os.environ.get("AWS_GRAPHQL_API_ENDPOINT") - api_key = os.environ.get("AWS_GRAPHQL_API_KEY") - if url is None or api_key is None: + if url is None: print("Missing environment variables") sys.exit() From db4a710128ac7643f8ac76d5bc5801c5c9684a77 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sun, 5 Dec 2021 17:16:46 +0100 Subject: [PATCH 57/68] Fix AttributeError('AppSyncWebsocketsTransport' object has no attribute 'GRAPHQLWS_SUBPROTOCOL' --- gql/transport/appsync.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gql/transport/appsync.py b/gql/transport/appsync.py index e29dfcf7..c7796bbc 100644 --- a/gql/transport/appsync.py +++ b/gql/transport/appsync.py @@ -392,7 +392,7 @@ async def execute( ) _initialize = WebsocketsTransport._initialize - _stop_listener = WebsocketsTransport._stop_listener # type: ignore + _stop_listener = WebsocketsTransport._send_stop_message# type: ignore _send_init_message_and_wait_ack = ( WebsocketsTransport._send_init_message_and_wait_ack ) From a5874d436e2cf4bdd0e62d350bee962816d234c3 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sun, 5 Dec 2021 23:36:14 +0100 Subject: [PATCH 58/68] Split appsync.py into 2 files Modify AIOHTTPTransport to accept an AppSync auth class --- docs/code_examples/aws_api_key_mutation.py | 4 +- .../code_examples/aws_api_key_subscription.py | 6 +- docs/code_examples/aws_iam_mutation.py | 52 +++++ docs/code_examples/aws_iam_subscription.py | 2 +- docs/modules/gql.rst | 3 +- docs/modules/transport_appsync.rst | 7 - docs/modules/transport_appsync_auth.rst | 7 + docs/modules/transport_appsync_websockets.rst | 7 + docs/transports/appsync.rst | 10 +- gql/transport/aiohttp.py | 23 +- gql/transport/appsync_auth.py | 198 ++++++++++++++++++ .../{appsync.py => appsync_websockets.py} | 193 +---------------- tests/test_appsync_auth.py | 42 ++-- tests/test_appsync_subscription.py | 60 ++---- 14 files changed, 337 insertions(+), 277 deletions(-) create mode 100644 docs/code_examples/aws_iam_mutation.py delete mode 100644 docs/modules/transport_appsync.rst create mode 100644 docs/modules/transport_appsync_auth.rst create mode 100644 docs/modules/transport_appsync_websockets.rst create mode 100644 gql/transport/appsync_auth.py rename gql/transport/{appsync.py => appsync_websockets.py} (52%) diff --git a/docs/code_examples/aws_api_key_mutation.py b/docs/code_examples/aws_api_key_mutation.py index a3fcc2f0..563fd240 100644 --- a/docs/code_examples/aws_api_key_mutation.py +++ b/docs/code_examples/aws_api_key_mutation.py @@ -6,7 +6,7 @@ from gql import Client, gql from gql.transport.aiohttp import AIOHTTPTransport -from gql.transport.appsync import AppSyncApiKeyAuthentication +from gql.transport.appsync_auth import AppSyncApiKeyAuthentication logging.basicConfig(level=logging.DEBUG) @@ -27,7 +27,7 @@ async def main(): auth = AppSyncApiKeyAuthentication(host=host, api_key=api_key) - transport = AIOHTTPTransport(url=url, headers=auth.get_headers()) + transport = AIOHTTPTransport(url=url, auth=auth) async with Client( transport=transport, fetch_schema_from_transport=False, diff --git a/docs/code_examples/aws_api_key_subscription.py b/docs/code_examples/aws_api_key_subscription.py index 06d14b67..10ce272d 100644 --- a/docs/code_examples/aws_api_key_subscription.py +++ b/docs/code_examples/aws_api_key_subscription.py @@ -5,10 +5,8 @@ from urllib.parse import urlparse from gql import Client, gql -from gql.transport.appsync import ( - AppSyncApiKeyAuthentication, - AppSyncWebsocketsTransport, -) +from gql.transport.appsync_auth import AppSyncApiKeyAuthentication +from gql.transport.appsync_websockets import AppSyncWebsocketsTransport logging.basicConfig(level=logging.DEBUG) diff --git a/docs/code_examples/aws_iam_mutation.py b/docs/code_examples/aws_iam_mutation.py new file mode 100644 index 00000000..8a90598e --- /dev/null +++ b/docs/code_examples/aws_iam_mutation.py @@ -0,0 +1,52 @@ +import asyncio +import logging +import os +import sys +from urllib.parse import urlparse + +from gql import Client, gql +from gql.transport.aiohttp import AIOHTTPTransport +from gql.transport.appsync_auth import AppSyncIAMAuthentication + +logging.basicConfig(level=logging.DEBUG) + + +async def main(): + + # Should look like: + # https://XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com/graphql + url = os.environ.get("AWS_GRAPHQL_API_ENDPOINT") + + if url is None: + print("Missing environment variables") + sys.exit() + + # Extract host from url + host = str(urlparse(url).netloc) + + auth = AppSyncIAMAuthentication(host=host) + + transport = AIOHTTPTransport(url=url, auth=auth) + + async with Client( + transport=transport, fetch_schema_from_transport=False, + ) as session: + + query = gql( + """ +mutation createMessage($message: String!) { + createMessage(input: {message: $message}) { + id + message + createdAt + } +}""" + ) + + variable_values = {"message": "Hello world!"} + + result = await session.execute(query, variable_values=variable_values) + print(result) + + +asyncio.run(main()) diff --git a/docs/code_examples/aws_iam_subscription.py b/docs/code_examples/aws_iam_subscription.py index a7abba19..7e57318f 100644 --- a/docs/code_examples/aws_iam_subscription.py +++ b/docs/code_examples/aws_iam_subscription.py @@ -4,7 +4,7 @@ import sys from gql import Client, gql -from gql.transport.appsync import AppSyncWebsocketsTransport +from gql.transport.appsync_websockets import AppSyncWebsocketsTransport logging.basicConfig(level=logging.DEBUG) diff --git a/docs/modules/gql.rst b/docs/modules/gql.rst index 01cda657..a00c324f 100644 --- a/docs/modules/gql.rst +++ b/docs/modules/gql.rst @@ -21,7 +21,8 @@ Sub-Packages client transport transport_aiohttp - transport_appsync + transport_appsync_auth + transport_appsync_websockets transport_exceptions transport_phoenix_channel_websockets transport_requests diff --git a/docs/modules/transport_appsync.rst b/docs/modules/transport_appsync.rst deleted file mode 100644 index f7360088..00000000 --- a/docs/modules/transport_appsync.rst +++ /dev/null @@ -1,7 +0,0 @@ -gql.transport.appsync -===================== - -.. currentmodule:: gql.transport.appsync - -.. automodule:: gql.transport.appsync - :member-order: bysource diff --git a/docs/modules/transport_appsync_auth.rst b/docs/modules/transport_appsync_auth.rst new file mode 100644 index 00000000..b8ac42c0 --- /dev/null +++ b/docs/modules/transport_appsync_auth.rst @@ -0,0 +1,7 @@ +gql.transport.appsync_auth +========================== + +.. currentmodule:: gql.transport.appsync_auth + +.. automodule:: gql.transport.appsync_auth + :member-order: bysource diff --git a/docs/modules/transport_appsync_websockets.rst b/docs/modules/transport_appsync_websockets.rst new file mode 100644 index 00000000..f0d9523d --- /dev/null +++ b/docs/modules/transport_appsync_websockets.rst @@ -0,0 +1,7 @@ +gql.transport.appsync_websockets +================================ + +.. currentmodule:: gql.transport.appsync_websockets + +.. automodule:: gql.transport.appsync_websockets + :member-order: bysource diff --git a/docs/transports/appsync.rst b/docs/transports/appsync.rst index 8da6ab93..4f18f802 100644 --- a/docs/transports/appsync.rst +++ b/docs/transports/appsync.rst @@ -20,14 +20,14 @@ How to use it: .. note:: It is also possible to instantiate the transport without an auth argument. In that case, - gql will use by default the :class:`IAM auth ` + gql will use by default the :class:`IAM auth ` which will try to authenticate with environment variables or from your aws credentials file. Full example with API key authentication from environment variables: .. literalinclude:: ../code_examples/aws_api_key_subscription.py -Reference: :class:`gql.transport.appsync.AppSyncWebsocketsTransport` +Reference: :class:`gql.transport.appsync_websockets.AppSyncWebsocketsTransport` .. _Building a real-time websocket client: https://docs.aws.amazon.com/appsync/latest/devguide/real-time-websocket-client.html @@ -53,7 +53,7 @@ Use the :code:`AppSyncApiKeyAuthentication` class to provide your API key: auth=auth, ) -Reference: :class:`gql.transport.appsync.AppSyncApiKeyAuthentication` +Reference: :class:`gql.transport.appsync_auth.AppSyncApiKeyAuthentication` IAM ^^^ @@ -103,7 +103,7 @@ OR You can also provide the credentials manually by creating the auth=auth, ) -Reference: :class:`gql.transport.appsync.AppSyncIAMAuthentication` +Reference: :class:`gql.transport.appsync_auth.AppSyncIAMAuthentication` Json Web Tokens (jwt) ^^^^^^^^^^^^^^^^^^^^^ @@ -127,7 +127,7 @@ For these authentication methods, you can use the :code:`AppSyncJWTAuthenticatio auth=auth, ) -Reference: :class:`gql.transport.appsync.AppSyncJWTAuthentication` +Reference: :class:`gql.transport.appsync_auth.AppSyncJWTAuthentication` AppSync GraphQL Queries and mutations ------------------------------------- diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index f34a0066..6000612f 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -22,6 +22,11 @@ TransportServerError, ) +try: + from .appsync_auth import AppSyncAuthentication +except ImportError: + pass + log = logging.getLogger(__name__) @@ -43,7 +48,7 @@ def __init__( url: str, headers: Optional[LooseHeaders] = None, cookies: Optional[LooseCookies] = None, - auth: Optional[BasicAuth] = None, + auth: Optional[Union[BasicAuth, "AppSyncAuthentication"]] = None, ssl: Union[SSLContext, bool, Fingerprint] = False, timeout: Optional[int] = None, ssl_close_timeout: Optional[Union[int, float]] = 10, @@ -55,6 +60,7 @@ def __init__( :param headers: Dict of HTTP Headers. :param cookies: Dict of HTTP cookies. :param auth: BasicAuth object to enable Basic HTTP auth if needed + Or Appsync Authentication class :param ssl: ssl_context of the connection. Use ssl=False to disable encryption :param ssl_close_timeout: Timeout in seconds to wait for the ssl connection to close properly @@ -67,7 +73,7 @@ def __init__( self.url: str = url self.headers: Optional[LooseHeaders] = headers self.cookies: Optional[LooseCookies] = cookies - self.auth: Optional[BasicAuth] = auth + self.auth: Optional[Union[BasicAuth, "AppSyncAuthentication"]] = auth self.ssl: Union[SSLContext, bool, Fingerprint] = ssl self.timeout: Optional[int] = timeout self.ssl_close_timeout: Optional[Union[int, float]] = ssl_close_timeout @@ -89,9 +95,11 @@ async def connect(self) -> None: client_session_args: Dict[str, Any] = { "cookies": self.cookies, "headers": self.headers, - "auth": self.auth, } + if isinstance(self.auth, BasicAuth): + client_session_args["auth"] = self.auth + if self.timeout is not None: client_session_args["timeout"] = aiohttp.ClientTimeout( total=self.timeout @@ -266,6 +274,15 @@ async def execute( if extra_args: post_args.update(extra_args) + # Add headers for AppSync if requested + try: + if isinstance(self.auth, AppSyncAuthentication): + post_args["headers"] = self.auth.get_headers( + json.dumps(payload), {"content-type": "application/json"}, + ) + except NameError: + pass + if self.session is None: raise TransportClosed("Transport is not connected") diff --git a/gql/transport/appsync_auth.py b/gql/transport/appsync_auth.py new file mode 100644 index 00000000..af39341f --- /dev/null +++ b/gql/transport/appsync_auth.py @@ -0,0 +1,198 @@ +import json +import logging +import re +from abc import ABC, abstractmethod +from base64 import b64encode +from typing import Any, Callable, Dict, Optional + +from botocore.auth import BaseSigner, SigV4Auth +from botocore.awsrequest import AWSRequest, create_request_object +from botocore.credentials import Credentials +from botocore.exceptions import NoRegionError +from botocore.session import Session, get_session + +log = logging.getLogger("gql.transport.appsync") + + +class AppSyncAuthentication(ABC): + """AWS authentication abstract base class + + All AWS authentication class should have a + :meth:`get_headers ` + method which defines the headers used in the authentication process.""" + + def get_auth_url(self, url: str) -> str: + """ + :return: a url with base64 encoded headers used to establish + a websocket connection to the appsync-realtime-api. + """ + headers = self.get_headers() + + encoded_headers = b64encode( + json.dumps(headers, separators=(",", ":")).encode() + ).decode() + + url_base = url.replace("https://", "wss://").replace( + "appsync-api", "appsync-realtime-api" + ) + + return f"{url_base}?header={encoded_headers}&payload=e30=" + + @abstractmethod + def get_headers( + self, data: Optional[str] = None, headers: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + raise NotImplementedError() # pragma: no cover + + +class AppSyncApiKeyAuthentication(AppSyncAuthentication): + """AWS authentication class using an API key""" + + def __init__(self, host: str, api_key: str) -> None: + """ + :param host: the host, something like: + XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com + :param api_key: the API key + """ + self._host = host + self.api_key = api_key + + def get_headers( + self, data: Optional[str] = None, headers: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + return {"host": self._host, "x-api-key": self.api_key} + + +class AppSyncJWTAuthentication(AppSyncAuthentication): + """AWS authentication class using a JWT access token. + + It can be used either for: + - Amazon Cognito user pools + - OpenID Connect (OIDC) + """ + + def __init__(self, host: str, jwt: str) -> None: + """ + :param host: the host, something like: + XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com + :param jwt: the JWT Access Token + """ + self._host = host + self.jwt = jwt + + def get_headers( + self, data: Optional[str] = None, headers: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + return {"host": self._host, "Authorization": self.jwt} + + +class AppSyncIAMAuthentication(AppSyncAuthentication): + """AWS authentication class using IAM. + + .. note:: + There is no need for you to use this class directly, you could instead + intantiate the :class:`gql.transport.appsync.AppSyncWebsocketsTransport` + without an auth argument. + + During initialization, this class will use botocore to attempt to + find your IAM credentials, either from environment variables or + from your AWS credentials file. + """ + + def __init__( + self, + host: str, + region_name: Optional[str] = None, + signer: Optional[BaseSigner] = None, + request_creator: Optional[Callable[[Dict[str, Any]], AWSRequest]] = None, + credentials: Optional[Credentials] = None, + session: Optional[Session] = None, + ) -> None: + """Initialize itself, saving the found credentials used + to sign the headers later. + + if no credentials are found, then a NoCredentialsError is raised. + """ + self._host = host + self._session = session if session else get_session() + self._credentials = ( + credentials if credentials else self._session.get_credentials() + ) + self._service_name = "appsync" + self._region_name = region_name or self._detect_region_name() + self._signer = ( + signer + if signer + else SigV4Auth(self._credentials, self._service_name, self._region_name) + ) + self._request_creator = ( + request_creator if request_creator else create_request_object + ) + + def _detect_region_name(self): + """Try to detect the correct region_name. + + First try to extract the region_name from the host. + + If that does not work, then try to get the region_name from + the aws configuration (~/.aws/config file) or the AWS_DEFAULT_REGION + environment variable. + + If no region_name was found, then raise a NoRegionError exception.""" + + # Regular expression from botocore.utils.validate_region + m = re.search( + r"appsync-api\.((?![0-9]+$)(?!-)[a-zA-Z0-9-]{,63}(? Dict[str, Any]: + + # Default headers for a websocket connection + headers = headers or { + "accept": "application/json, text/javascript", + "content-encoding": "amz-1.0", + "content-type": "application/json; charset=UTF-8", + } + + request: AWSRequest = self._request_creator( + { + "method": "POST", + "url": f"https://{self._host}/graphql{'' if data else '/connect'}", + "headers": headers, + "context": {}, + "body": data or "{}", + } + ) + + self._signer.add_auth(request) + + headers = dict(request.headers) + + headers["host"] = self._host + + if log.isEnabledFor(logging.DEBUG): + headers_log = [] + headers_log.append("\n\nSigned headers:") + for key, value in headers.items(): + headers_log.append(f" {key}: {value}") + headers_log.append("\n") + log.debug("\n".join(headers_log)) + + return headers diff --git a/gql/transport/appsync.py b/gql/transport/appsync_websockets.py similarity index 52% rename from gql/transport/appsync.py rename to gql/transport/appsync_websockets.py index c7796bbc..7c460d19 100644 --- a/gql/transport/appsync.py +++ b/gql/transport/appsync_websockets.py @@ -1,199 +1,18 @@ import json import logging -import re -from abc import ABC, abstractmethod -from base64 import b64encode from ssl import SSLContext -from typing import Any, Callable, Dict, Optional, Tuple, Union, cast +from typing import Any, Dict, Optional, Tuple, Union, cast from urllib.parse import urlparse -import botocore.session -from botocore.auth import BaseSigner, SigV4Auth -from botocore.awsrequest import AWSRequest, create_request_object -from botocore.credentials import Credentials from botocore.exceptions import NoCredentialsError, NoRegionError -from botocore.session import get_session +from botocore.session import Session from graphql import DocumentNode, ExecutionResult, print_ast +from .appsync_auth import AppSyncAuthentication, AppSyncIAMAuthentication from .exceptions import TransportProtocolError, TransportServerError from .websockets import WebsocketsTransport, WebsocketsTransportBase -log = logging.getLogger(__name__) - - -class AppSyncAuthentication(ABC): - """AWS authentication abstract base class - - All AWS authentication class should have a - :meth:`get_headers ` - method which defines the headers used in the authentication process.""" - - def get_auth_url(self, url: str) -> str: - """ - :return: a url with base64 encoded headers used to establish - a websocket connection to the appsync-realtime-api. - """ - headers = self.get_headers() - - encoded_headers = b64encode( - json.dumps(headers, separators=(",", ":")).encode() - ).decode() - - url_base = url.replace("https://", "wss://").replace( - "appsync-api", "appsync-realtime-api" - ) - - return f"{url_base}?header={encoded_headers}&payload=e30=" - - @abstractmethod - def get_headers(self, data: Optional[str] = None) -> Dict: - raise NotImplementedError() # pragma: no cover - - -class AppSyncApiKeyAuthentication(AppSyncAuthentication): - """AWS authentication class using an API key""" - - def __init__(self, host: str, api_key: str) -> None: - """ - :param host: the host, something like: - XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com - :param api_key: the API key - """ - self._host = host - self.api_key = api_key - - def get_headers(self, data: Optional[str] = None) -> Dict: - return {"host": self._host, "x-api-key": self.api_key} - - -class AppSyncJWTAuthentication(AppSyncAuthentication): - """AWS authentication class using a JWT access token. - - It can be used either for: - - Amazon Cognito user pools - - OpenID Connect (OIDC) - """ - - def __init__(self, host: str, jwt: str) -> None: - """ - :param host: the host, something like: - XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com - :param jwt: the JWT Access Token - """ - self._host = host - self.jwt = jwt - - def get_headers(self, data: Optional[str] = None) -> Dict: - return {"host": self._host, "Authorization": self.jwt} - - -class AppSyncIAMAuthentication(AppSyncAuthentication): - """AWS authentication class using IAM. - - .. note:: - There is no need for you to use this class directly, you could instead - intantiate the :class:`gql.transport.appsync.AppSyncWebsocketsTransport` - without an auth argument. - - During initialization, this class will use botocore to attempt to - find your IAM credentials, either from environment variables or - from your AWS credentials file. - """ - - def __init__( - self, - host: str, - region_name: Optional[str] = None, - signer: Optional[BaseSigner] = None, - request_creator: Optional[Callable[[Dict[str, Any]], AWSRequest]] = None, - credentials: Optional[Credentials] = None, - session: Optional[botocore.session.Session] = None, - ) -> None: - """Initialize itself, saving the found credentials used - to sign the headers later. - - if no credentials are found, then a NoCredentialsError is raised. - """ - self._host = host - self._session = session if session else get_session() - self._credentials = ( - credentials if credentials else self._session.get_credentials() - ) - self._service_name = "appsync" - self._region_name = region_name or self._detect_region_name() - self._signer = ( - signer - if signer - else SigV4Auth(self._credentials, self._service_name, self._region_name) - ) - self._request_creator = ( - request_creator if request_creator else create_request_object - ) - - def _detect_region_name(self): - """Try to detect the correct region_name. - - First try to extract the region_name from the host. - - If that does not work, then try to get the region_name from - the aws configuration (~/.aws/config file) or the AWS_DEFAULT_REGION - environment variable. - - If no region_name was found, then raise a NoRegionError exception.""" - - # Regular expression from botocore.utils.validate_region - m = re.search( - r"appsync-api\.((?![0-9]+$)(?!-)[a-zA-Z0-9-]{,63}(? Dict: - - headers = { - "accept": "application/json, text/javascript", - "content-encoding": "amz-1.0", - "content-type": "application/json; charset=UTF-8", - } - - request: AWSRequest = self._request_creator( - { - "method": "POST", - "url": f"https://{self._host}/graphql{'' if data else '/connect'}", - "headers": headers, - "context": {}, - "body": data or "{}", - } - ) - - self._signer.add_auth(request) - - headers = dict(request.headers) - - headers["host"] = self._host - - if log.isEnabledFor(logging.DEBUG): - headers_log = [] - headers_log.append("\n\nSigned headers:") - for key, value in headers.items(): - headers_log.append(f" {key}: {value}") - headers_log.append("\n") - log.debug("\n".join(headers_log)) - - return headers +log = logging.getLogger("gql.transport.appsync") class AppSyncWebsocketsTransport(WebsocketsTransportBase): @@ -210,7 +29,7 @@ def __init__( self, url: str, auth: Optional[AppSyncAuthentication] = None, - session: Optional[botocore.session.Session] = None, + session: Optional[Session] = None, ssl: Union[SSLContext, bool] = False, connect_timeout: int = 10, close_timeout: int = 10, @@ -392,7 +211,7 @@ async def execute( ) _initialize = WebsocketsTransport._initialize - _stop_listener = WebsocketsTransport._send_stop_message# type: ignore + _stop_listener = WebsocketsTransport._send_stop_message # type: ignore _send_init_message_and_wait_ack = ( WebsocketsTransport._send_init_message_and_wait_ack ) diff --git a/tests/test_appsync_auth.py b/tests/test_appsync_auth.py index 981d8f13..27750e29 100644 --- a/tests/test_appsync_auth.py +++ b/tests/test_appsync_auth.py @@ -8,10 +8,8 @@ def test_appsync_init_with_minimal_args(fake_session_factory): - from gql.transport.appsync import ( - AppSyncIAMAuthentication, - AppSyncWebsocketsTransport, - ) + from gql.transport.appsync_auth import AppSyncIAMAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport sample_transport = AppSyncWebsocketsTransport( url=mock_transport_url, session=fake_session_factory() @@ -26,7 +24,7 @@ def test_appsync_init_with_minimal_args(fake_session_factory): def test_appsync_init_with_no_credentials(caplog, fake_session_factory): import botocore.exceptions - from gql.transport.appsync import AppSyncWebsocketsTransport + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport with pytest.raises(botocore.exceptions.NoCredentialsError): sample_transport = AppSyncWebsocketsTransport( @@ -42,10 +40,8 @@ def test_appsync_init_with_no_credentials(caplog, fake_session_factory): def test_appsync_init_with_jwt_auth(): - from gql.transport.appsync import ( - AppSyncJWTAuthentication, - AppSyncWebsocketsTransport, - ) + from gql.transport.appsync_auth import AppSyncJWTAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport auth = AppSyncJWTAuthentication(host=mock_transport_host, jwt="some-jwt") sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, auth=auth) @@ -58,10 +54,8 @@ def test_appsync_init_with_jwt_auth(): def test_appsync_init_with_apikey_auth(): - from gql.transport.appsync import ( - AppSyncApiKeyAuthentication, - AppSyncWebsocketsTransport, - ) + from gql.transport.appsync_auth import AppSyncApiKeyAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport auth = AppSyncApiKeyAuthentication(host=mock_transport_host, api_key="some-api-key") sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, auth=auth) @@ -75,10 +69,8 @@ def test_appsync_init_with_apikey_auth(): def test_appsync_init_with_iam_auth_without_creds(fake_session_factory): import botocore.exceptions - from gql.transport.appsync import ( - AppSyncIAMAuthentication, - AppSyncWebsocketsTransport, - ) + from gql.transport.appsync_auth import AppSyncIAMAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport auth = AppSyncIAMAuthentication( host=mock_transport_host, session=fake_session_factory(credentials=None), @@ -88,10 +80,8 @@ def test_appsync_init_with_iam_auth_without_creds(fake_session_factory): def test_appsync_init_with_iam_auth_with_creds(fake_credentials_factory): - from gql.transport.appsync import ( - AppSyncIAMAuthentication, - AppSyncWebsocketsTransport, - ) + from gql.transport.appsync_auth import AppSyncIAMAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport auth = AppSyncIAMAuthentication( host=mock_transport_host, @@ -112,7 +102,7 @@ def test_appsync_init_with_iam_auth_and_no_region( - you have the AWS_DEFAULT_REGION environment variable set """ - from gql.transport.appsync import AppSyncWebsocketsTransport + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport from botocore.exceptions import NoRegionError import logging @@ -139,10 +129,8 @@ def test_appsync_init_with_iam_auth_and_no_region( def test_munge_url(fake_signer_factory, fake_request_factory): - from gql.transport.appsync import ( - AppSyncIAMAuthentication, - AppSyncWebsocketsTransport, - ) + from gql.transport.appsync_auth import AppSyncIAMAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport test_url = "https://appsync-api.aws.example.org/some-other-params" @@ -171,7 +159,7 @@ def test_munge_url_format( fake_credentials_factory, fake_session_factory, ): - from gql.transport.appsync import AppSyncIAMAuthentication + from gql.transport.appsync_auth import AppSyncIAMAuthentication test_url = "https://appsync-api.aws.example.org/some-other-params" diff --git a/tests/test_appsync_subscription.py b/tests/test_appsync_subscription.py index 99a0ed17..1a13907a 100644 --- a/tests/test_appsync_subscription.py +++ b/tests/test_appsync_subscription.py @@ -395,10 +395,8 @@ async def default_transport_test(transport): @pytest.mark.parametrize("server", [realtime_appsync_server_keepalive], indirect=True) async def test_appsync_subscription_api_key(event_loop, server): - from gql.transport.appsync import ( - AppSyncWebsocketsTransport, - AppSyncApiKeyAuthentication, - ) + from gql.transport.appsync_auth import AppSyncApiKeyAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" @@ -416,10 +414,8 @@ async def test_appsync_subscription_api_key(event_loop, server): @pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) async def test_appsync_subscription_iam_with_token(event_loop, server): - from gql.transport.appsync import ( - AppSyncWebsocketsTransport, - AppSyncIAMAuthentication, - ) + from gql.transport.appsync_auth import AppSyncIAMAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport from botocore.credentials import Credentials path = "/graphql" @@ -444,10 +440,8 @@ async def test_appsync_subscription_iam_with_token(event_loop, server): @pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) async def test_appsync_subscription_iam_without_token(event_loop, server): - from gql.transport.appsync import ( - AppSyncWebsocketsTransport, - AppSyncIAMAuthentication, - ) + from gql.transport.appsync_auth import AppSyncIAMAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport from botocore.credentials import Credentials path = "/graphql" @@ -470,10 +464,8 @@ async def test_appsync_subscription_iam_without_token(event_loop, server): @pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) async def test_appsync_execute_method_not_allowed(event_loop, server): - from gql.transport.appsync import ( - AppSyncWebsocketsTransport, - AppSyncIAMAuthentication, - ) + from gql.transport.appsync_auth import AppSyncIAMAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport from botocore.credentials import Credentials path = "/graphql" @@ -517,10 +509,8 @@ async def test_appsync_execute_method_not_allowed(event_loop, server): @pytest.mark.asyncio async def test_appsync_fetch_schema_from_transport_not_allowed(event_loop): - from gql.transport.appsync import ( - AppSyncWebsocketsTransport, - AppSyncIAMAuthentication, - ) + from gql.transport.appsync_auth import AppSyncIAMAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport from botocore.credentials import Credentials dummy_credentials = Credentials( @@ -546,10 +536,8 @@ async def test_appsync_fetch_schema_from_transport_not_allowed(event_loop): @pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) async def test_appsync_subscription_api_key_unauthorized(event_loop, server): - from gql.transport.appsync import ( - AppSyncWebsocketsTransport, - AppSyncApiKeyAuthentication, - ) + from gql.transport.appsync_auth import AppSyncApiKeyAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport from gql.transport.exceptions import TransportServerError path = "/graphql" @@ -572,10 +560,8 @@ async def test_appsync_subscription_api_key_unauthorized(event_loop, server): @pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) async def test_appsync_subscription_api_key_not_allowed(event_loop, server): - from gql.transport.appsync import ( - AppSyncWebsocketsTransport, - AppSyncIAMAuthentication, - ) + from gql.transport.appsync_auth import AppSyncIAMAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport from gql.transport.exceptions import TransportQueryError from botocore.credentials import Credentials @@ -615,10 +601,8 @@ async def test_appsync_subscription_server_sending_a_not_json_answer( event_loop, server ): - from gql.transport.appsync import ( - AppSyncWebsocketsTransport, - AppSyncApiKeyAuthentication, - ) + from gql.transport.appsync_auth import AppSyncApiKeyAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport from gql.transport.exceptions import TransportProtocolError path = "/graphql" @@ -645,10 +629,8 @@ async def test_appsync_subscription_server_sending_an_error_without_an_id( event_loop, server ): - from gql.transport.appsync import ( - AppSyncWebsocketsTransport, - AppSyncApiKeyAuthentication, - ) + from gql.transport.appsync_auth import AppSyncApiKeyAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport from gql.transport.exceptions import TransportServerError path = "/graphql" @@ -673,10 +655,8 @@ async def test_appsync_subscription_variable_values_and_operation_name( event_loop, server ): - from gql.transport.appsync import ( - AppSyncWebsocketsTransport, - AppSyncApiKeyAuthentication, - ) + from gql.transport.appsync_auth import AppSyncApiKeyAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" From 500c635e2a3a4ad110fb7933cd592789a4903af9 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 6 Dec 2021 14:19:03 +0100 Subject: [PATCH 59/68] Put appsync code examples in separate folder --- .../{aws_api_key_mutation.py => appsync/mutation_api_key.py} | 0 .../{aws_iam_mutation.py => appsync/mutation_iam.py} | 0 .../subscription_api_key.py} | 0 .../{aws_iam_subscription.py => appsync/subscription_iam.py} | 0 docs/transports/appsync.rst | 4 ++-- 5 files changed, 2 insertions(+), 2 deletions(-) rename docs/code_examples/{aws_api_key_mutation.py => appsync/mutation_api_key.py} (100%) rename docs/code_examples/{aws_iam_mutation.py => appsync/mutation_iam.py} (100%) rename docs/code_examples/{aws_api_key_subscription.py => appsync/subscription_api_key.py} (100%) rename docs/code_examples/{aws_iam_subscription.py => appsync/subscription_iam.py} (100%) diff --git a/docs/code_examples/aws_api_key_mutation.py b/docs/code_examples/appsync/mutation_api_key.py similarity index 100% rename from docs/code_examples/aws_api_key_mutation.py rename to docs/code_examples/appsync/mutation_api_key.py diff --git a/docs/code_examples/aws_iam_mutation.py b/docs/code_examples/appsync/mutation_iam.py similarity index 100% rename from docs/code_examples/aws_iam_mutation.py rename to docs/code_examples/appsync/mutation_iam.py diff --git a/docs/code_examples/aws_api_key_subscription.py b/docs/code_examples/appsync/subscription_api_key.py similarity index 100% rename from docs/code_examples/aws_api_key_subscription.py rename to docs/code_examples/appsync/subscription_api_key.py diff --git a/docs/code_examples/aws_iam_subscription.py b/docs/code_examples/appsync/subscription_iam.py similarity index 100% rename from docs/code_examples/aws_iam_subscription.py rename to docs/code_examples/appsync/subscription_iam.py diff --git a/docs/transports/appsync.rst b/docs/transports/appsync.rst index 4f18f802..496556f1 100644 --- a/docs/transports/appsync.rst +++ b/docs/transports/appsync.rst @@ -25,7 +25,7 @@ How to use it: Full example with API key authentication from environment variables: -.. literalinclude:: ../code_examples/aws_api_key_subscription.py +.. literalinclude:: ../code_examples/appsync/subscription_api_key.py Reference: :class:`gql.transport.appsync_websockets.AppSyncWebsocketsTransport` @@ -138,4 +138,4 @@ a normal http session and reuse the authentication classes to create the headers Full example with API key authentication from environment variables: -.. literalinclude:: ../code_examples/aws_api_key_mutation.py +.. literalinclude:: ../code_examples/appsync/mutation_api_key.py From ffd4bf172c01a71ddec9fe15bd8138a28c5647db Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 6 Dec 2021 14:24:14 +0100 Subject: [PATCH 60/68] Doc code examples comment debug logging --- docs/code_examples/appsync/mutation_api_key.py | 5 +++-- docs/code_examples/appsync/mutation_iam.py | 5 +++-- docs/code_examples/appsync/subscription_api_key.py | 7 +++++-- docs/code_examples/appsync/subscription_iam.py | 7 +++++-- 4 files changed, 16 insertions(+), 8 deletions(-) diff --git a/docs/code_examples/appsync/mutation_api_key.py b/docs/code_examples/appsync/mutation_api_key.py index 563fd240..052da850 100644 --- a/docs/code_examples/appsync/mutation_api_key.py +++ b/docs/code_examples/appsync/mutation_api_key.py @@ -1,5 +1,4 @@ import asyncio -import logging import os import sys from urllib.parse import urlparse @@ -8,7 +7,9 @@ from gql.transport.aiohttp import AIOHTTPTransport from gql.transport.appsync_auth import AppSyncApiKeyAuthentication -logging.basicConfig(level=logging.DEBUG) +# Uncomment the following lines to enable debug output +# import logging +# logging.basicConfig(level=logging.DEBUG) async def main(): diff --git a/docs/code_examples/appsync/mutation_iam.py b/docs/code_examples/appsync/mutation_iam.py index 8a90598e..327e0d91 100644 --- a/docs/code_examples/appsync/mutation_iam.py +++ b/docs/code_examples/appsync/mutation_iam.py @@ -1,5 +1,4 @@ import asyncio -import logging import os import sys from urllib.parse import urlparse @@ -8,7 +7,9 @@ from gql.transport.aiohttp import AIOHTTPTransport from gql.transport.appsync_auth import AppSyncIAMAuthentication -logging.basicConfig(level=logging.DEBUG) +# Uncomment the following lines to enable debug output +# import logging +# logging.basicConfig(level=logging.DEBUG) async def main(): diff --git a/docs/code_examples/appsync/subscription_api_key.py b/docs/code_examples/appsync/subscription_api_key.py index 10ce272d..87bb3611 100644 --- a/docs/code_examples/appsync/subscription_api_key.py +++ b/docs/code_examples/appsync/subscription_api_key.py @@ -1,5 +1,4 @@ import asyncio -import logging import os import sys from urllib.parse import urlparse @@ -8,7 +7,9 @@ from gql.transport.appsync_auth import AppSyncApiKeyAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport -logging.basicConfig(level=logging.DEBUG) +# Uncomment the following lines to enable debug output +# import logging +# logging.basicConfig(level=logging.DEBUG) async def main(): @@ -43,6 +44,8 @@ async def main(): """ ) + print("Waiting for messages...") + async for result in session.subscribe(subscription): print(result) diff --git a/docs/code_examples/appsync/subscription_iam.py b/docs/code_examples/appsync/subscription_iam.py index 7e57318f..1bb540d0 100644 --- a/docs/code_examples/appsync/subscription_iam.py +++ b/docs/code_examples/appsync/subscription_iam.py @@ -1,12 +1,13 @@ import asyncio -import logging import os import sys from gql import Client, gql from gql.transport.appsync_websockets import AppSyncWebsocketsTransport -logging.basicConfig(level=logging.DEBUG) +# Uncomment the following lines to enable debug output +# import logging +# logging.basicConfig(level=logging.DEBUG) async def main(): @@ -34,6 +35,8 @@ async def main(): """ ) + print("Waiting for messages...") + async for result in session.subscribe(subscription): print(result) From 715bc623c10e6598690ed79e6ea8980a0d807041 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 6 Dec 2021 15:02:14 +0100 Subject: [PATCH 61/68] botocore dependency needed only for IAM auth --- docs/intro.rst | 8 +++-- docs/transports/appsync.rst | 6 ++++ gql/transport/aiohttp.py | 23 +++++---------- gql/transport/appsync_auth.py | 45 ++++++++++++++++++++++------- gql/transport/appsync_websockets.py | 37 +++++++++--------------- setup.py | 8 ++--- tests/conftest.py | 2 +- tests/test_appsync_auth.py | 10 +++++-- tests/test_appsync_subscription.py | 11 +++++-- 9 files changed, 87 insertions(+), 63 deletions(-) diff --git a/docs/intro.rst b/docs/intro.rst index 00c6f87d..1cd3f5c8 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -35,20 +35,22 @@ which needs the :code:`aiohttp` dependency, then you can install GQL with:: pip install --pre gql[aiohttp] -The corresponding between extra dependencies required and the GQL transports is: +The corresponding between extra dependencies required and the GQL classes is: +---------------------+----------------------------------------------------------------+ -| Extra dependencies | Transports | +| Extra dependencies | Classes | +=====================+================================================================+ | aiohttp | :ref:`AIOHTTPTransport ` | +---------------------+----------------------------------------------------------------+ | websockets | :ref:`WebsocketsTransport ` | | | | | | :ref:`PhoenixChannelWebsocketsTransport ` | +| | | +| | :ref:`AppSyncWebsocketsTransport ` | +---------------------+----------------------------------------------------------------+ | requests | :ref:`RequestsHTTPTransport ` | +---------------------+----------------------------------------------------------------+ -| appsync, websockets | :ref:`AppSyncWebsocketsTransport ` | +| botocore | :ref:`AppSyncIAMAuthentication ` | +---------------------+----------------------------------------------------------------+ .. note:: diff --git a/docs/transports/appsync.rst b/docs/transports/appsync.rst index 496556f1..4846e102 100644 --- a/docs/transports/appsync.rst +++ b/docs/transports/appsync.rst @@ -36,6 +36,8 @@ Reference: :class:`gql.transport.appsync_websockets.AppSyncWebsocketsTransport` Authentication methods ---------------------- +.. _appsync_api_key_auth: + API key ^^^^^^^ @@ -55,6 +57,8 @@ Use the :code:`AppSyncApiKeyAuthentication` class to provide your API key: Reference: :class:`gql.transport.appsync_auth.AppSyncApiKeyAuthentication` +.. _appsync_iam_auth: + IAM ^^^ @@ -105,6 +109,8 @@ OR You can also provide the credentials manually by creating the Reference: :class:`gql.transport.appsync_auth.AppSyncIAMAuthentication` +.. _appsync_jwt_auth: + Json Web Tokens (jwt) ^^^^^^^^^^^^^^^^^^^^^ diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index 6000612f..12c57068 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -14,6 +14,7 @@ from graphql import DocumentNode, ExecutionResult, print_ast from ..utils import extract_files +from .appsync_auth import AppSyncAuthentication from .async_transport import AsyncTransport from .exceptions import ( TransportAlreadyConnected, @@ -22,11 +23,6 @@ TransportServerError, ) -try: - from .appsync_auth import AppSyncAuthentication -except ImportError: - pass - log = logging.getLogger(__name__) @@ -95,11 +91,11 @@ async def connect(self) -> None: client_session_args: Dict[str, Any] = { "cookies": self.cookies, "headers": self.headers, + "auth": None + if isinstance(self.auth, AppSyncAuthentication) + else self.auth, } - if isinstance(self.auth, BasicAuth): - client_session_args["auth"] = self.auth - if self.timeout is not None: client_session_args["timeout"] = aiohttp.ClientTimeout( total=self.timeout @@ -275,13 +271,10 @@ async def execute( post_args.update(extra_args) # Add headers for AppSync if requested - try: - if isinstance(self.auth, AppSyncAuthentication): - post_args["headers"] = self.auth.get_headers( - json.dumps(payload), {"content-type": "application/json"}, - ) - except NameError: - pass + if isinstance(self.auth, AppSyncAuthentication): + post_args["headers"] = self.auth.get_headers( + json.dumps(payload), {"content-type": "application/json"}, + ) if self.session is None: raise TransportClosed("Transport is not connected") diff --git a/gql/transport/appsync_auth.py b/gql/transport/appsync_auth.py index af39341f..04c07c10 100644 --- a/gql/transport/appsync_auth.py +++ b/gql/transport/appsync_auth.py @@ -5,11 +5,11 @@ from base64 import b64encode from typing import Any, Callable, Dict, Optional -from botocore.auth import BaseSigner, SigV4Auth -from botocore.awsrequest import AWSRequest, create_request_object -from botocore.credentials import Credentials -from botocore.exceptions import NoRegionError -from botocore.session import Session, get_session +try: + import botocore +except ImportError: # pragma: no cover + # botocore is only needed for the IAM AppSync authentication method + pass log = logging.getLogger("gql.transport.appsync") @@ -103,16 +103,23 @@ def __init__( self, host: str, region_name: Optional[str] = None, - signer: Optional[BaseSigner] = None, - request_creator: Optional[Callable[[Dict[str, Any]], AWSRequest]] = None, - credentials: Optional[Credentials] = None, - session: Optional[Session] = None, + signer: Optional["botocore.auth.BaseSigner"] = None, + request_creator: Optional[ + Callable[[Dict[str, Any]], "botocore.awsrequest.AWSRequest"] + ] = None, + credentials: Optional["botocore.credentials.Credentials"] = None, + session: Optional["botocore.session.Session"] = None, ) -> None: """Initialize itself, saving the found credentials used to sign the headers later. if no credentials are found, then a NoCredentialsError is raised. """ + + from botocore.auth import SigV4Auth + from botocore.awsrequest import create_request_object + from botocore.session import get_session + self._host = host self._session = session if session else get_session() self._credentials = ( @@ -140,6 +147,8 @@ def _detect_region_name(self): If no region_name was found, then raise a NoRegionError exception.""" + from botocore.exceptions import NoRegionError + # Regular expression from botocore.utils.validate_region m = re.search( r"appsync-api\.((?![0-9]+$)(?!-)[a-zA-Z0-9-]{,63}(? Dict[str, Any]: + from botocore.exceptions import NoCredentialsError + # Default headers for a websocket connection headers = headers or { "accept": "application/json, text/javascript", @@ -171,7 +187,7 @@ def get_headers( "content-type": "application/json; charset=UTF-8", } - request: AWSRequest = self._request_creator( + request: "botocore.awsrequest.AWSRequest" = self._request_creator( { "method": "POST", "url": f"https://{self._host}/graphql{'' if data else '/connect'}", @@ -181,7 +197,14 @@ def get_headers( } ) - self._signer.add_auth(request) + try: + self._signer.add_auth(request) + except NoCredentialsError: + log.warning( + "Credentials not found. " + "Do you have default AWS credentials configured?", + ) + raise headers = dict(request.headers) diff --git a/gql/transport/appsync_websockets.py b/gql/transport/appsync_websockets.py index 7c460d19..c7e05a09 100644 --- a/gql/transport/appsync_websockets.py +++ b/gql/transport/appsync_websockets.py @@ -4,8 +4,6 @@ from typing import Any, Dict, Optional, Tuple, Union, cast from urllib.parse import urlparse -from botocore.exceptions import NoCredentialsError, NoRegionError -from botocore.session import Session from graphql import DocumentNode, ExecutionResult, print_ast from .appsync_auth import AppSyncAuthentication, AppSyncIAMAuthentication @@ -14,6 +12,12 @@ log = logging.getLogger("gql.transport.appsync") +try: + import botocore +except ImportError: # pragma: no cover + # botocore is only needed for the IAM AppSync authentication method + pass + class AppSyncWebsocketsTransport(WebsocketsTransportBase): """:ref:`Async Transport ` used to execute GraphQL subscription on @@ -29,7 +33,7 @@ def __init__( self, url: str, auth: Optional[AppSyncAuthentication] = None, - session: Optional[Session] = None, + session: Optional["botocore.session.Session"] = None, ssl: Union[SSLContext, bool] = False, connect_timeout: int = 10, close_timeout: int = 10, @@ -56,31 +60,18 @@ def __init__( a sign of liveness from the server. :param connect_args: Other parameters forwarded to websockets.connect """ - try: - if not auth: - # Extract host from url - host = str(urlparse(url).netloc) + if not auth: - auth = AppSyncIAMAuthentication(host=host, session=session) + # Extract host from url + host = str(urlparse(url).netloc) - self.auth = auth + # May raise NoRegionError or NoCredentialsError or ImportError + auth = AppSyncIAMAuthentication(host=host, session=session) - url = self.auth.get_auth_url(url) + self.auth = auth - except NoCredentialsError: - log.warning( - "Credentials not found. " - "Do you have default AWS credentials configured?", - ) - raise - except NoRegionError: - log.warning( - "Region name not found. " - "It was not possible to detect your region either from the host " - "or from your default AWS configuration." - ) - raise + url = self.auth.get_auth_url(url) super().__init__( url, diff --git a/setup.py b/setup.py index 041881d3..a29cdb0a 100644 --- a/setup.py +++ b/setup.py @@ -50,12 +50,12 @@ "websockets>=10,<11;python_version>'3.6'", ] -install_appsync_requires = [ +install_botocore_requires = [ "botocore>=1.21,<2", -] + install_websockets_requires +] install_all_requires = ( - install_aiohttp_requires + install_requests_requires + install_websockets_requires + install_appsync_requires + install_aiohttp_requires + install_requests_requires + install_websockets_requires + install_botocore_requires ) # Get version from __version__.py file @@ -99,7 +99,7 @@ "aiohttp": install_aiohttp_requires, "requests": install_requests_requires, "websockets": install_websockets_requires, - "appsync": install_appsync_requires, + "botocore": install_botocore_requires, }, include_package_data=True, zip_safe=False, diff --git a/tests/conftest.py b/tests/conftest.py index a2653027..d433c1ca 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,7 +14,7 @@ from gql import Client -all_transport_dependencies = ["aiohttp", "requests", "websockets", "appsync"] +all_transport_dependencies = ["aiohttp", "requests", "websockets", "botocore"] def pytest_addoption(parser): diff --git a/tests/test_appsync_auth.py b/tests/test_appsync_auth.py index 27750e29..17558e35 100644 --- a/tests/test_appsync_auth.py +++ b/tests/test_appsync_auth.py @@ -1,12 +1,10 @@ import pytest -# Marking all tests in this file with the appsync marker -pytestmark = pytest.mark.appsync - mock_transport_host = "appsyncapp.awsgateway.com.example.org" mock_transport_url = f"https://{mock_transport_host}/graphql" +@pytest.mark.botocore def test_appsync_init_with_minimal_args(fake_session_factory): from gql.transport.appsync_auth import AppSyncIAMAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport @@ -22,6 +20,7 @@ def test_appsync_init_with_minimal_args(fake_session_factory): assert sample_transport.connect_args == {} +@pytest.mark.botocore def test_appsync_init_with_no_credentials(caplog, fake_session_factory): import botocore.exceptions from gql.transport.appsync_websockets import AppSyncWebsocketsTransport @@ -67,6 +66,7 @@ def test_appsync_init_with_apikey_auth(): } +@pytest.mark.botocore def test_appsync_init_with_iam_auth_without_creds(fake_session_factory): import botocore.exceptions from gql.transport.appsync_auth import AppSyncIAMAuthentication @@ -79,6 +79,7 @@ def test_appsync_init_with_iam_auth_without_creds(fake_session_factory): AppSyncWebsocketsTransport(url=mock_transport_url, auth=auth) +@pytest.mark.botocore def test_appsync_init_with_iam_auth_with_creds(fake_credentials_factory): from gql.transport.appsync_auth import AppSyncIAMAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport @@ -92,6 +93,7 @@ def test_appsync_init_with_iam_auth_with_creds(fake_credentials_factory): assert sample_transport.auth is auth +@pytest.mark.botocore def test_appsync_init_with_iam_auth_and_no_region( caplog, fake_credentials_factory, fake_session_factory ): @@ -128,6 +130,7 @@ def test_appsync_init_with_iam_auth_and_no_region( assert expected_error in caplog.text +@pytest.mark.botocore def test_munge_url(fake_signer_factory, fake_request_factory): from gql.transport.appsync_auth import AppSyncIAMAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport @@ -153,6 +156,7 @@ def test_munge_url(fake_signer_factory, fake_request_factory): assert sample_transport.url == expected_url +@pytest.mark.botocore def test_munge_url_format( fake_signer_factory, fake_request_factory, diff --git a/tests/test_appsync_subscription.py b/tests/test_appsync_subscription.py index 1a13907a..f510d4a7 100644 --- a/tests/test_appsync_subscription.py +++ b/tests/test_appsync_subscription.py @@ -10,8 +10,8 @@ from .conftest import MS, WebSocketServerHelper -# Marking all tests in this file with the appsync marker -pytestmark = pytest.mark.appsync +# Marking all tests in this file with the websockets marker +pytestmark = pytest.mark.websockets SEND_MESSAGE_DELAY = 20 * MS NB_MESSAGES = 10 @@ -411,6 +411,7 @@ async def test_appsync_subscription_api_key(event_loop, server): @pytest.mark.asyncio +@pytest.mark.botocore @pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) async def test_appsync_subscription_iam_with_token(event_loop, server): @@ -437,6 +438,7 @@ async def test_appsync_subscription_iam_with_token(event_loop, server): @pytest.mark.asyncio +@pytest.mark.botocore @pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) async def test_appsync_subscription_iam_without_token(event_loop, server): @@ -461,6 +463,7 @@ async def test_appsync_subscription_iam_without_token(event_loop, server): @pytest.mark.asyncio +@pytest.mark.botocore @pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) async def test_appsync_execute_method_not_allowed(event_loop, server): @@ -507,6 +510,7 @@ async def test_appsync_execute_method_not_allowed(event_loop, server): @pytest.mark.asyncio +@pytest.mark.botocore async def test_appsync_fetch_schema_from_transport_not_allowed(event_loop): from gql.transport.appsync_auth import AppSyncIAMAuthentication @@ -557,8 +561,9 @@ async def test_appsync_subscription_api_key_unauthorized(event_loop, server): @pytest.mark.asyncio +@pytest.mark.botocore @pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) -async def test_appsync_subscription_api_key_not_allowed(event_loop, server): +async def test_appsync_subscription_iam_not_allowed(event_loop, server): from gql.transport.appsync_auth import AppSyncIAMAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport From 8873afebf65941fc4819c9625df466451b73c29b Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 6 Dec 2021 15:28:45 +0100 Subject: [PATCH 62/68] add websockets_base module to docs --- docs/modules/gql.rst | 1 + docs/modules/transport_websockets_base.rst | 7 +++++++ 2 files changed, 8 insertions(+) create mode 100644 docs/modules/transport_websockets_base.rst diff --git a/docs/modules/gql.rst b/docs/modules/gql.rst index a00c324f..be6f904b 100644 --- a/docs/modules/gql.rst +++ b/docs/modules/gql.rst @@ -27,5 +27,6 @@ Sub-Packages transport_phoenix_channel_websockets transport_requests transport_websockets + transport_websockets_base dsl utilities diff --git a/docs/modules/transport_websockets_base.rst b/docs/modules/transport_websockets_base.rst new file mode 100644 index 00000000..548351eb --- /dev/null +++ b/docs/modules/transport_websockets_base.rst @@ -0,0 +1,7 @@ +gql.transport.websockets_base +============================= + +.. currentmodule:: gql.transport.websockets_base + +.. automodule:: gql.transport.websockets_base + :member-order: bysource From bb319225d8917a782d1cb6a33ec2a5d605411a3e Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 6 Dec 2021 16:12:18 +0100 Subject: [PATCH 63/68] add appsync http iam test Rename test_appsync_subscription.py to tests/test_appsync_websockets.py --- tests/test_appsync_http.py | 80 +++++++++++++++++++ ...cription.py => test_appsync_websockets.py} | 0 2 files changed, 80 insertions(+) create mode 100644 tests/test_appsync_http.py rename tests/{test_appsync_subscription.py => test_appsync_websockets.py} (100%) diff --git a/tests/test_appsync_http.py b/tests/test_appsync_http.py new file mode 100644 index 00000000..a1ba2486 --- /dev/null +++ b/tests/test_appsync_http.py @@ -0,0 +1,80 @@ +import json + +import pytest + +from gql import Client, gql + + +@pytest.mark.asyncio +@pytest.mark.aiohttp +@pytest.mark.botocore +async def test_appsync_iam_mutation( + event_loop, aiohttp_server, fake_credentials_factory +): + from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport + from gql.transport.appsync_auth import AppSyncIAMAuthentication + from urllib.parse import urlparse + + async def handler(request): + data = { + "createMessage": { + "id": "4b436192-aab2-460c-8bdf-4f2605eb63da", + "message": "Hello world!", + "createdAt": "2021-12-06T14:49:55.087Z", + } + } + payload = { + "data": data, + "extensions": {"received_headers": dict(request.headers)}, + } + + return web.Response( + text=json.dumps(payload, separators=(",", ":")), + content_type="application/json", + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + # Extract host from url + host = str(urlparse(url).netloc) + + auth = AppSyncIAMAuthentication( + host=host, credentials=fake_credentials_factory(), region_name="us-east-1", + ) + + sample_transport = AIOHTTPTransport(url=url, auth=auth) + + auth = AppSyncIAMAuthentication(host=host) + + async with Client(transport=sample_transport) as session: + + query = gql( + """ +mutation createMessage($message: String!) { + createMessage(input: {message: $message}) { + id + message + createdAt + } +}""" + ) + + # Execute query asynchronously + execution_result = await session.execute(query, get_execution_result=True) + + result = execution_result.data + message = result["createMessage"]["message"] + + assert message == "Hello world!" + + sent_headers = execution_result.extensions["received_headers"] + + assert sent_headers["X-Amz-Security-Token"] == "fake-token" + assert sent_headers["Authorization"].startswith( + "AWS4-HMAC-SHA256 Credential=fake-access-key/" + ) diff --git a/tests/test_appsync_subscription.py b/tests/test_appsync_websockets.py similarity index 100% rename from tests/test_appsync_subscription.py rename to tests/test_appsync_websockets.py From 25262749ada09e5b82db433389e9c75ca5719768 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 6 Dec 2021 16:21:43 +0100 Subject: [PATCH 64/68] fix mark websockets missing on some appsync auth tests --- tests/test_appsync_auth.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_appsync_auth.py b/tests/test_appsync_auth.py index 17558e35..546e0e6f 100644 --- a/tests/test_appsync_auth.py +++ b/tests/test_appsync_auth.py @@ -38,6 +38,7 @@ def test_appsync_init_with_no_credentials(caplog, fake_session_factory): assert expected_error in caplog.text +@pytest.mark.websockets def test_appsync_init_with_jwt_auth(): from gql.transport.appsync_auth import AppSyncJWTAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport @@ -52,6 +53,7 @@ def test_appsync_init_with_jwt_auth(): } +@pytest.mark.websockets def test_appsync_init_with_apikey_auth(): from gql.transport.appsync_auth import AppSyncApiKeyAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport From 9224701da2c1f8a59ff5c6d25f46471d8948b8b9 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 6 Dec 2021 16:30:24 +0100 Subject: [PATCH 65/68] remove redundant line in test_appsync_http.py --- tests/test_appsync_http.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_appsync_http.py b/tests/test_appsync_http.py index a1ba2486..1f787a68 100644 --- a/tests/test_appsync_http.py +++ b/tests/test_appsync_http.py @@ -49,8 +49,6 @@ async def handler(request): sample_transport = AIOHTTPTransport(url=url, auth=auth) - auth = AppSyncIAMAuthentication(host=host) - async with Client(transport=sample_transport) as session: query = gql( From 1952302b2ff45a9d5021fd9cf7fc3cb27d11133a Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 6 Dec 2021 16:48:32 +0100 Subject: [PATCH 66/68] Docs: Add reference blog post to create the sample app --- docs/transports/appsync.rst | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/docs/transports/appsync.rst b/docs/transports/appsync.rst index 4846e102..7ceb7480 100644 --- a/docs/transports/appsync.rst +++ b/docs/transports/appsync.rst @@ -11,7 +11,8 @@ GQL provides the :code:`AppSyncWebsocketsTransport` transport which implements t for you to allow you to execute subscriptions. .. note:: - It is only possible to execute subscriptions with this transport + It is only possible to execute subscriptions with this transport. + For queries or mutations, See :ref:`AppSync GraphQL Queries and mutations ` How to use it: @@ -23,6 +24,10 @@ How to use it: gql will use by default the :class:`IAM auth ` which will try to authenticate with environment variables or from your aws credentials file. +.. note:: + All the examples in this documentation are based on the sample app created + by following `this AWS blog post`_ + Full example with API key authentication from environment variables: .. literalinclude:: ../code_examples/appsync/subscription_api_key.py @@ -30,6 +35,8 @@ Full example with API key authentication from environment variables: Reference: :class:`gql.transport.appsync_websockets.AppSyncWebsocketsTransport` .. _Building a real-time websocket client: https://docs.aws.amazon.com/appsync/latest/devguide/real-time-websocket-client.html +.. _this AWS blog post: https://aws.amazon.com/fr/blogs/mobile/appsync-realtime/ + .. _appsync_authentication_methods: @@ -135,6 +142,8 @@ For these authentication methods, you can use the :code:`AppSyncJWTAuthenticatio Reference: :class:`gql.transport.appsync_auth.AppSyncJWTAuthentication` +.. _appsync_http: + AppSync GraphQL Queries and mutations ------------------------------------- From 34dbd3e85839e33ca745aff5ae3a071938094fd0 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 6 Dec 2021 16:53:56 +0100 Subject: [PATCH 67/68] README.md add AWS AppSync to features --- README.md | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 2fa37978..0962c80e 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,12 @@ The complete documentation for GQL can be found at The main features of GQL are: -* Execute GraphQL queries using [different protocols](https://gql.readthedocs.io/en/latest/transports/index.html) (http, websockets, ...) +* Execute GraphQL queries using [different protocols](https://gql.readthedocs.io/en/latest/transports/index.html): + * http + * websockets: + * apollo or graphql-ws protocol + * Phoenix channels + * AWS AppSync realtime protocol (experimental) * Possibility to [validate the queries locally](https://gql.readthedocs.io/en/latest/usage/validation.html) using a GraphQL schema provided locally or fetched from the backend using an instrospection query * Supports GraphQL queries, mutations and [subscriptions](https://gql.readthedocs.io/en/latest/usage/subscriptions.html) * Supports [sync or async usage](https://gql.readthedocs.io/en/latest/async/index.html), [allowing concurrent requests](https://gql.readthedocs.io/en/latest/advanced/async_advanced_usage.html#async-advanced-usage) From e35537e79b97e930e86282e8a08cb5f125b95ccd Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 6 Dec 2021 17:05:23 +0100 Subject: [PATCH 68/68] Remove markers from tox.ini The markers are described in conftest.py and can be checked by running: pytest --markers --- tox.ini | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tox.ini b/tox.ini index 77c2c24a..e75b8fac 100644 --- a/tox.ini +++ b/tox.ini @@ -3,11 +3,6 @@ envlist = black,flake8,import-order,mypy,manifest, py{36,37,38,39,310,py3} -[pytest] -markers = - asyncio: Tests belonging to the asyncio transport - appsyncwebsocket: Tests belonging to the AwsAppSyncWebsocket transport - [gh-actions] python = 3.6: py36