diff --git a/README.md b/README.md index 2fa37978..0962c80e 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,12 @@ The complete documentation for GQL can be found at The main features of GQL are: -* Execute GraphQL queries using [different protocols](https://gql.readthedocs.io/en/latest/transports/index.html) (http, websockets, ...) +* Execute GraphQL queries using [different protocols](https://gql.readthedocs.io/en/latest/transports/index.html): + * http + * websockets: + * apollo or graphql-ws protocol + * Phoenix channels + * AWS AppSync realtime protocol (experimental) * Possibility to [validate the queries locally](https://gql.readthedocs.io/en/latest/usage/validation.html) using a GraphQL schema provided locally or fetched from the backend using an instrospection query * Supports GraphQL queries, mutations and [subscriptions](https://gql.readthedocs.io/en/latest/usage/subscriptions.html) * Supports [sync or async usage](https://gql.readthedocs.io/en/latest/async/index.html), [allowing concurrent requests](https://gql.readthedocs.io/en/latest/advanced/async_advanced_usage.html#async-advanced-usage) diff --git a/docs/code_examples/appsync/mutation_api_key.py b/docs/code_examples/appsync/mutation_api_key.py new file mode 100644 index 00000000..052da850 --- /dev/null +++ b/docs/code_examples/appsync/mutation_api_key.py @@ -0,0 +1,54 @@ +import asyncio +import os +import sys +from urllib.parse import urlparse + +from gql import Client, gql +from gql.transport.aiohttp import AIOHTTPTransport +from gql.transport.appsync_auth import AppSyncApiKeyAuthentication + +# Uncomment the following lines to enable debug output +# import logging +# logging.basicConfig(level=logging.DEBUG) + + +async def main(): + + # Should look like: + # https://XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com/graphql + url = os.environ.get("AWS_GRAPHQL_API_ENDPOINT") + api_key = os.environ.get("AWS_GRAPHQL_API_KEY") + + if url is None or api_key is None: + print("Missing environment variables") + sys.exit() + + # Extract host from url + host = str(urlparse(url).netloc) + + auth = AppSyncApiKeyAuthentication(host=host, api_key=api_key) + + transport = AIOHTTPTransport(url=url, auth=auth) + + async with Client( + transport=transport, fetch_schema_from_transport=False, + ) as session: + + query = gql( + """ +mutation createMessage($message: String!) 
{ + createMessage(input: {message: $message}) { + id + message + createdAt + } +}""" + ) + + variable_values = {"message": "Hello world!"} + + result = await session.execute(query, variable_values=variable_values) + print(result) + + +asyncio.run(main()) diff --git a/docs/code_examples/appsync/mutation_iam.py b/docs/code_examples/appsync/mutation_iam.py new file mode 100644 index 00000000..327e0d91 --- /dev/null +++ b/docs/code_examples/appsync/mutation_iam.py @@ -0,0 +1,53 @@ +import asyncio +import os +import sys +from urllib.parse import urlparse + +from gql import Client, gql +from gql.transport.aiohttp import AIOHTTPTransport +from gql.transport.appsync_auth import AppSyncIAMAuthentication + +# Uncomment the following lines to enable debug output +# import logging +# logging.basicConfig(level=logging.DEBUG) + + +async def main(): + + # Should look like: + # https://XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com/graphql + url = os.environ.get("AWS_GRAPHQL_API_ENDPOINT") + + if url is None: + print("Missing environment variables") + sys.exit() + + # Extract host from url + host = str(urlparse(url).netloc) + + auth = AppSyncIAMAuthentication(host=host) + + transport = AIOHTTPTransport(url=url, auth=auth) + + async with Client( + transport=transport, fetch_schema_from_transport=False, + ) as session: + + query = gql( + """ +mutation createMessage($message: String!) { + createMessage(input: {message: $message}) { + id + message + createdAt + } +}""" + ) + + variable_values = {"message": "Hello world!"} + + result = await session.execute(query, variable_values=variable_values) + print(result) + + +asyncio.run(main()) diff --git a/docs/code_examples/appsync/subscription_api_key.py b/docs/code_examples/appsync/subscription_api_key.py new file mode 100644 index 00000000..87bb3611 --- /dev/null +++ b/docs/code_examples/appsync/subscription_api_key.py @@ -0,0 +1,53 @@ +import asyncio +import os +import sys +from urllib.parse import urlparse + +from gql import Client, gql +from gql.transport.appsync_auth import AppSyncApiKeyAuthentication +from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + +# Uncomment the following lines to enable debug output +# import logging +# logging.basicConfig(level=logging.DEBUG) + + +async def main(): + + # Should look like: + # https://XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com/graphql + url = os.environ.get("AWS_GRAPHQL_API_ENDPOINT") + api_key = os.environ.get("AWS_GRAPHQL_API_KEY") + + if url is None or api_key is None: + print("Missing environment variables") + sys.exit() + + # Extract host from url + host = str(urlparse(url).netloc) + + print(f"Host: {host}") + + auth = AppSyncApiKeyAuthentication(host=host, api_key=api_key) + + transport = AppSyncWebsocketsTransport(url=url, auth=auth) + + async with Client(transport=transport) as session: + + subscription = gql( + """ +subscription onCreateMessage { + onCreateMessage { + message + } +} +""" + ) + + print("Waiting for messages...") + + async for result in session.subscribe(subscription): + print(result) + + +asyncio.run(main()) diff --git a/docs/code_examples/appsync/subscription_iam.py b/docs/code_examples/appsync/subscription_iam.py new file mode 100644 index 00000000..1bb540d0 --- /dev/null +++ b/docs/code_examples/appsync/subscription_iam.py @@ -0,0 +1,44 @@ +import asyncio +import os +import sys + +from gql import Client, gql +from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + +# Uncomment the following lines to enable debug output 
+# import logging +# logging.basicConfig(level=logging.DEBUG) + + +async def main(): + + # Should look like: + # https://XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com/graphql + url = os.environ.get("AWS_GRAPHQL_API_ENDPOINT") + + if url is None: + print("Missing environment variables") + sys.exit() + + # Using implicit auth (IAM) + transport = AppSyncWebsocketsTransport(url=url) + + async with Client(transport=transport) as session: + + subscription = gql( + """ +subscription onCreateMessage { + onCreateMessage { + message + } +} +""" + ) + + print("Waiting for messages...") + + async for result in session.subscribe(subscription): + print(result) + + +asyncio.run(main()) diff --git a/docs/intro.rst b/docs/intro.rst index e377c56e..1cd3f5c8 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -35,19 +35,23 @@ which needs the :code:`aiohttp` dependency, then you can install GQL with:: pip install --pre gql[aiohttp] -The corresponding between extra dependencies required and the GQL transports is: - -+-------------------+----------------------------------------------------------------+ -| Extra dependency | Transports | -+===================+================================================================+ -| aiohttp | :ref:`AIOHTTPTransport ` | -+-------------------+----------------------------------------------------------------+ -| websockets | :ref:`WebsocketsTransport ` | -| | | -| | :ref:`PhoenixChannelWebsocketsTransport ` | -+-------------------+----------------------------------------------------------------+ -| requests | :ref:`RequestsHTTPTransport ` | -+-------------------+----------------------------------------------------------------+ +The corresponding between extra dependencies required and the GQL classes is: + ++---------------------+----------------------------------------------------------------+ +| Extra dependencies | Classes | ++=====================+================================================================+ +| aiohttp | :ref:`AIOHTTPTransport ` | ++---------------------+----------------------------------------------------------------+ +| websockets | :ref:`WebsocketsTransport ` | +| | | +| | :ref:`PhoenixChannelWebsocketsTransport ` | +| | | +| | :ref:`AppSyncWebsocketsTransport ` | ++---------------------+----------------------------------------------------------------+ +| requests | :ref:`RequestsHTTPTransport ` | ++---------------------+----------------------------------------------------------------+ +| botocore | :ref:`AppSyncIAMAuthentication ` | ++---------------------+----------------------------------------------------------------+ .. note:: diff --git a/docs/modules/gql.rst b/docs/modules/gql.rst index 6730e07b..be6f904b 100644 --- a/docs/modules/gql.rst +++ b/docs/modules/gql.rst @@ -20,6 +20,13 @@ Sub-Packages client transport + transport_aiohttp + transport_appsync_auth + transport_appsync_websockets transport_exceptions + transport_phoenix_channel_websockets + transport_requests + transport_websockets + transport_websockets_base dsl utilities diff --git a/docs/modules/transport.rst b/docs/modules/transport.rst index 1b250d7a..d03dbf1f 100644 --- a/docs/modules/transport.rst +++ b/docs/modules/transport.rst @@ -5,14 +5,6 @@ gql.transport .. autoclass:: gql.transport.transport.Transport -.. autoclass:: gql.transport.local_schema.LocalSchemaTransport - -.. autoclass:: gql.transport.requests.RequestsHTTPTransport - .. autoclass:: gql.transport.async_transport.AsyncTransport -.. autoclass:: gql.transport.aiohttp.AIOHTTPTransport - -.. 
autoclass:: gql.transport.websockets.WebsocketsTransport - -.. autoclass:: gql.transport.phoenix_channel_websockets.PhoenixChannelWebsocketsTransport +.. autoclass:: gql.transport.local_schema.LocalSchemaTransport diff --git a/docs/modules/transport_aiohttp.rst b/docs/modules/transport_aiohttp.rst new file mode 100644 index 00000000..41cebd99 --- /dev/null +++ b/docs/modules/transport_aiohttp.rst @@ -0,0 +1,7 @@ +gql.transport.aiohttp +===================== + +.. currentmodule:: gql.transport.aiohttp + +.. automodule:: gql.transport.aiohttp + :member-order: bysource diff --git a/docs/modules/transport_appsync_auth.rst b/docs/modules/transport_appsync_auth.rst new file mode 100644 index 00000000..b8ac42c0 --- /dev/null +++ b/docs/modules/transport_appsync_auth.rst @@ -0,0 +1,7 @@ +gql.transport.appsync_auth +========================== + +.. currentmodule:: gql.transport.appsync_auth + +.. automodule:: gql.transport.appsync_auth + :member-order: bysource diff --git a/docs/modules/transport_appsync_websockets.rst b/docs/modules/transport_appsync_websockets.rst new file mode 100644 index 00000000..f0d9523d --- /dev/null +++ b/docs/modules/transport_appsync_websockets.rst @@ -0,0 +1,7 @@ +gql.transport.appsync_websockets +================================ + +.. currentmodule:: gql.transport.appsync_websockets + +.. automodule:: gql.transport.appsync_websockets + :member-order: bysource diff --git a/docs/modules/transport_phoenix_channel_websockets.rst b/docs/modules/transport_phoenix_channel_websockets.rst new file mode 100644 index 00000000..5f412a33 --- /dev/null +++ b/docs/modules/transport_phoenix_channel_websockets.rst @@ -0,0 +1,7 @@ +gql.transport.phoenix_channel_websockets +======================================== + +.. currentmodule:: gql.transport.phoenix_channel_websockets + +.. automodule:: gql.transport.phoenix_channel_websockets + :member-order: bysource diff --git a/docs/modules/transport_requests.rst b/docs/modules/transport_requests.rst new file mode 100644 index 00000000..78a07a02 --- /dev/null +++ b/docs/modules/transport_requests.rst @@ -0,0 +1,7 @@ +gql.transport.requests +====================== + +.. currentmodule:: gql.transport.requests + +.. automodule:: gql.transport.requests + :member-order: bysource diff --git a/docs/modules/transport_websockets.rst b/docs/modules/transport_websockets.rst new file mode 100644 index 00000000..9a924afd --- /dev/null +++ b/docs/modules/transport_websockets.rst @@ -0,0 +1,7 @@ +gql.transport.websockets +======================== + +.. currentmodule:: gql.transport.websockets + +.. automodule:: gql.transport.websockets + :member-order: bysource diff --git a/docs/modules/transport_websockets_base.rst b/docs/modules/transport_websockets_base.rst new file mode 100644 index 00000000..548351eb --- /dev/null +++ b/docs/modules/transport_websockets_base.rst @@ -0,0 +1,7 @@ +gql.transport.websockets_base +============================= + +.. currentmodule:: gql.transport.websockets_base + +.. automodule:: gql.transport.websockets_base + :member-order: bysource diff --git a/docs/transports/appsync.rst b/docs/transports/appsync.rst new file mode 100644 index 00000000..7ceb7480 --- /dev/null +++ b/docs/transports/appsync.rst @@ -0,0 +1,156 @@ +.. _appsync_transport: + +AppSyncWebsocketsTransport +========================== + +AWS AppSync allows you to execute GraphQL subscriptions on its realtime GraphQL endpoint. + +See `Building a real-time websocket client`_ for an explanation. 
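In short, the client has to open a websocket connection to a separate realtime host
(:code:`appsync-realtime-api` instead of :code:`appsync-api`) and pass its authentication
headers base64-encoded in the :code:`header` query parameter. A rough sketch of that URL
construction (host and API key below are placeholders):

.. code-block:: python

    import json
    from base64 import b64encode

    api_url = "https://example.appsync-api.eu-west-1.amazonaws.com/graphql"
    headers = {
        "host": "example.appsync-api.eu-west-1.amazonaws.com",
        "x-api-key": "da2-example",
    }

    realtime_url = (
        api_url.replace("https://", "wss://").replace("appsync-api", "appsync-realtime-api")
        + "?header=" + b64encode(json.dumps(headers).encode()).decode()
        + "&payload=e30="  # "e30=" is base64 for the empty JSON object "{}"
    )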
+ +GQL provides the :code:`AppSyncWebsocketsTransport` transport which implements this +for you to allow you to execute subscriptions. + +.. note:: + It is only possible to execute subscriptions with this transport. + For queries or mutations, See :ref:`AppSync GraphQL Queries and mutations ` + +How to use it: + + * choose one :ref:`authentication method ` (API key, IAM, Cognito user pools or OIDC) + * instantiate a :code:`AppSyncWebsocketsTransport` with your GraphQL endpoint as url and your auth method + +.. note:: + It is also possible to instantiate the transport without an auth argument. In that case, + gql will use by default the :class:`IAM auth ` + which will try to authenticate with environment variables or from your aws credentials file. + +.. note:: + All the examples in this documentation are based on the sample app created + by following `this AWS blog post`_ + +Full example with API key authentication from environment variables: + +.. literalinclude:: ../code_examples/appsync/subscription_api_key.py + +Reference: :class:`gql.transport.appsync_websockets.AppSyncWebsocketsTransport` + +.. _Building a real-time websocket client: https://docs.aws.amazon.com/appsync/latest/devguide/real-time-websocket-client.html +.. _this AWS blog post: https://aws.amazon.com/fr/blogs/mobile/appsync-realtime/ + + +.. _appsync_authentication_methods: + +Authentication methods +---------------------- + +.. _appsync_api_key_auth: + +API key +^^^^^^^ + +Use the :code:`AppSyncApiKeyAuthentication` class to provide your API key: + +.. code-block:: python + + auth = AppSyncApiKeyAuthentication( + host="XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com", + api_key="YOUR_API_KEY", + ) + + transport = AppSyncWebsocketsTransport( + url="https://XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com/graphql", + auth=auth, + ) + +Reference: :class:`gql.transport.appsync_auth.AppSyncApiKeyAuthentication` + +.. _appsync_iam_auth: + +IAM +^^^ + +For the IAM authentication, you can simply create your transport without +an auth argument. + +The region name will be autodetected from the url or from your AWS configuration +(:code:`.aws/config`) or the environment variable: + +- AWS_DEFAULT_REGION + +The credentials will be detected from your AWS configuration file +(:code:`.aws/credentials`) or from the environment variables: + +- AWS_ACCESS_KEY_ID +- AWS_SECRET_ACCESS_KEY +- AWS_SESSION_TOKEN (optional) + +.. code-block:: python + + transport = AppSyncWebsocketsTransport( + url="https://XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com/graphql", + ) + +OR You can also provide the credentials manually by creating the +:code:`AppSyncIAMAuthentication` class yourself: + +.. code-block:: python + + from botocore.credentials import Credentials + + credentials = Credentials( + access_key = os.environ.get("AWS_ACCESS_KEY_ID"), + secret_key= os.environ.get("AWS_SECRET_ACCESS_KEY"), + token=os.environ.get("AWS_SESSION_TOKEN", None), # Optional + ) + + auth = AppSyncIAMAuthentication( + host="XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com", + credentials=credentials, + region_name="your region" + ) + + transport = AppSyncWebsocketsTransport( + url="https://XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com/graphql", + auth=auth, + ) + +Reference: :class:`gql.transport.appsync_auth.AppSyncIAMAuthentication` + +.. 
_appsync_jwt_auth: + +Json Web Tokens (jwt) +^^^^^^^^^^^^^^^^^^^^^ + +AWS provides json web tokens (jwt) for the authentication methods: + +- Amazon Cognito user pools +- OpenID Connect (OIDC) + +For these authentication methods, you can use the :code:`AppSyncJWTAuthentication` class: + +.. code-block:: python + + auth = AppSyncJWTAuthentication( + host="XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com", + jwt="YOUR_JWT_STRING", + ) + + transport = AppSyncWebsocketsTransport( + url="https://XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com/graphql", + auth=auth, + ) + +Reference: :class:`gql.transport.appsync_auth.AppSyncJWTAuthentication` + +.. _appsync_http: + +AppSync GraphQL Queries and mutations +------------------------------------- + +Queries and mutations are not allowed on the realtime websockets endpoint. +But you can use the :ref:`AIOHTTPTransport ` to create +a normal http session and reuse the authentication classes to create the headers for you. + +Full example with API key authentication from environment variables: + +.. literalinclude:: ../code_examples/appsync/mutation_api_key.py diff --git a/docs/transports/async_transports.rst b/docs/transports/async_transports.rst index 9fb1b017..df8c23cf 100644 --- a/docs/transports/async_transports.rst +++ b/docs/transports/async_transports.rst @@ -12,3 +12,4 @@ Async transports are transports which are using an underlying async library. The aiohttp websockets phoenix + appsync diff --git a/gql/client.py b/gql/client.py index 2236189d..e10f7509 100644 --- a/gql/client.py +++ b/gql/client.py @@ -82,6 +82,12 @@ def __init__( not schema ), "Cannot fetch the schema from transport if is already provided." + assert not type(transport).__name__ == "AppSyncWebsocketsTransport", ( + "fetch_schema_from_transport=True is not allowed " + "for AppSyncWebsocketsTransport " + "because only subscriptions are allowed on the realtime endpoint." + ) + if schema and not transport: transport = LocalSchemaTransport(schema) diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index f34a0066..12c57068 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -14,6 +14,7 @@ from graphql import DocumentNode, ExecutionResult, print_ast from ..utils import extract_files +from .appsync_auth import AppSyncAuthentication from .async_transport import AsyncTransport from .exceptions import ( TransportAlreadyConnected, @@ -43,7 +44,7 @@ def __init__( url: str, headers: Optional[LooseHeaders] = None, cookies: Optional[LooseCookies] = None, - auth: Optional[BasicAuth] = None, + auth: Optional[Union[BasicAuth, "AppSyncAuthentication"]] = None, ssl: Union[SSLContext, bool, Fingerprint] = False, timeout: Optional[int] = None, ssl_close_timeout: Optional[Union[int, float]] = 10, @@ -55,6 +56,7 @@ def __init__( :param headers: Dict of HTTP Headers. :param cookies: Dict of HTTP cookies. :param auth: BasicAuth object to enable Basic HTTP auth if needed + Or Appsync Authentication class :param ssl: ssl_context of the connection. 
Use ssl=False to disable encryption :param ssl_close_timeout: Timeout in seconds to wait for the ssl connection to close properly @@ -67,7 +69,7 @@ def __init__( self.url: str = url self.headers: Optional[LooseHeaders] = headers self.cookies: Optional[LooseCookies] = cookies - self.auth: Optional[BasicAuth] = auth + self.auth: Optional[Union[BasicAuth, "AppSyncAuthentication"]] = auth self.ssl: Union[SSLContext, bool, Fingerprint] = ssl self.timeout: Optional[int] = timeout self.ssl_close_timeout: Optional[Union[int, float]] = ssl_close_timeout @@ -89,7 +91,9 @@ async def connect(self) -> None: client_session_args: Dict[str, Any] = { "cookies": self.cookies, "headers": self.headers, - "auth": self.auth, + "auth": None + if isinstance(self.auth, AppSyncAuthentication) + else self.auth, } if self.timeout is not None: @@ -266,6 +270,12 @@ async def execute( if extra_args: post_args.update(extra_args) + # Add headers for AppSync if requested + if isinstance(self.auth, AppSyncAuthentication): + post_args["headers"] = self.auth.get_headers( + json.dumps(payload), {"content-type": "application/json"}, + ) + if self.session is None: raise TransportClosed("Transport is not connected") diff --git a/gql/transport/appsync_auth.py b/gql/transport/appsync_auth.py new file mode 100644 index 00000000..04c07c10 --- /dev/null +++ b/gql/transport/appsync_auth.py @@ -0,0 +1,221 @@ +import json +import logging +import re +from abc import ABC, abstractmethod +from base64 import b64encode +from typing import Any, Callable, Dict, Optional + +try: + import botocore +except ImportError: # pragma: no cover + # botocore is only needed for the IAM AppSync authentication method + pass + +log = logging.getLogger("gql.transport.appsync") + + +class AppSyncAuthentication(ABC): + """AWS authentication abstract base class + + All AWS authentication class should have a + :meth:`get_headers ` + method which defines the headers used in the authentication process.""" + + def get_auth_url(self, url: str) -> str: + """ + :return: a url with base64 encoded headers used to establish + a websocket connection to the appsync-realtime-api. + """ + headers = self.get_headers() + + encoded_headers = b64encode( + json.dumps(headers, separators=(",", ":")).encode() + ).decode() + + url_base = url.replace("https://", "wss://").replace( + "appsync-api", "appsync-realtime-api" + ) + + return f"{url_base}?header={encoded_headers}&payload=e30=" + + @abstractmethod + def get_headers( + self, data: Optional[str] = None, headers: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + raise NotImplementedError() # pragma: no cover + + +class AppSyncApiKeyAuthentication(AppSyncAuthentication): + """AWS authentication class using an API key""" + + def __init__(self, host: str, api_key: str) -> None: + """ + :param host: the host, something like: + XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com + :param api_key: the API key + """ + self._host = host + self.api_key = api_key + + def get_headers( + self, data: Optional[str] = None, headers: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + return {"host": self._host, "x-api-key": self.api_key} + + +class AppSyncJWTAuthentication(AppSyncAuthentication): + """AWS authentication class using a JWT access token. 
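# Illustrative sketch only (placeholders throughout): what the API key
# authentication class produces for the websocket handshake described above.
from gql.transport.appsync_auth import AppSyncApiKeyAuthentication

auth = AppSyncApiKeyAuthentication(
    host="example.appsync-api.eu-west-1.amazonaws.com",
    api_key="da2-example",
)

# {'host': 'example.appsync-api.eu-west-1.amazonaws.com', 'x-api-key': 'da2-example'}
print(auth.get_headers())

# wss://example.appsync-realtime-api.eu-west-1.amazonaws.com/graphql?header=<base64 headers>&payload=e30=
print(auth.get_auth_url("https://example.appsync-api.eu-west-1.amazonaws.com/graphql"))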
+ + It can be used either for: + - Amazon Cognito user pools + - OpenID Connect (OIDC) + """ + + def __init__(self, host: str, jwt: str) -> None: + """ + :param host: the host, something like: + XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com + :param jwt: the JWT Access Token + """ + self._host = host + self.jwt = jwt + + def get_headers( + self, data: Optional[str] = None, headers: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + return {"host": self._host, "Authorization": self.jwt} + + +class AppSyncIAMAuthentication(AppSyncAuthentication): + """AWS authentication class using IAM. + + .. note:: + There is no need for you to use this class directly, you could instead + intantiate the :class:`gql.transport.appsync.AppSyncWebsocketsTransport` + without an auth argument. + + During initialization, this class will use botocore to attempt to + find your IAM credentials, either from environment variables or + from your AWS credentials file. + """ + + def __init__( + self, + host: str, + region_name: Optional[str] = None, + signer: Optional["botocore.auth.BaseSigner"] = None, + request_creator: Optional[ + Callable[[Dict[str, Any]], "botocore.awsrequest.AWSRequest"] + ] = None, + credentials: Optional["botocore.credentials.Credentials"] = None, + session: Optional["botocore.session.Session"] = None, + ) -> None: + """Initialize itself, saving the found credentials used + to sign the headers later. + + if no credentials are found, then a NoCredentialsError is raised. + """ + + from botocore.auth import SigV4Auth + from botocore.awsrequest import create_request_object + from botocore.session import get_session + + self._host = host + self._session = session if session else get_session() + self._credentials = ( + credentials if credentials else self._session.get_credentials() + ) + self._service_name = "appsync" + self._region_name = region_name or self._detect_region_name() + self._signer = ( + signer + if signer + else SigV4Auth(self._credentials, self._service_name, self._region_name) + ) + self._request_creator = ( + request_creator if request_creator else create_request_object + ) + + def _detect_region_name(self): + """Try to detect the correct region_name. + + First try to extract the region_name from the host. + + If that does not work, then try to get the region_name from + the aws configuration (~/.aws/config file) or the AWS_DEFAULT_REGION + environment variable. + + If no region_name was found, then raise a NoRegionError exception.""" + + from botocore.exceptions import NoRegionError + + # Regular expression from botocore.utils.validate_region + m = re.search( + r"appsync-api\.((?![0-9]+$)(?!-)[a-zA-Z0-9-]{,63}(? Dict[str, Any]: + + from botocore.exceptions import NoCredentialsError + + # Default headers for a websocket connection + headers = headers or { + "accept": "application/json, text/javascript", + "content-encoding": "amz-1.0", + "content-type": "application/json; charset=UTF-8", + } + + request: "botocore.awsrequest.AWSRequest" = self._request_creator( + { + "method": "POST", + "url": f"https://{self._host}/graphql{'' if data else '/connect'}", + "headers": headers, + "context": {}, + "body": data or "{}", + } + ) + + try: + self._signer.add_auth(request) + except NoCredentialsError: + log.warning( + "Credentials not found. 
" + "Do you have default AWS credentials configured?", + ) + raise + + headers = dict(request.headers) + + headers["host"] = self._host + + if log.isEnabledFor(logging.DEBUG): + headers_log = [] + headers_log.append("\n\nSigned headers:") + for key, value in headers.items(): + headers_log.append(f" {key}: {value}") + headers_log.append("\n") + log.debug("\n".join(headers_log)) + + return headers diff --git a/gql/transport/appsync_websockets.py b/gql/transport/appsync_websockets.py new file mode 100644 index 00000000..c7e05a09 --- /dev/null +++ b/gql/transport/appsync_websockets.py @@ -0,0 +1,209 @@ +import json +import logging +from ssl import SSLContext +from typing import Any, Dict, Optional, Tuple, Union, cast +from urllib.parse import urlparse + +from graphql import DocumentNode, ExecutionResult, print_ast + +from .appsync_auth import AppSyncAuthentication, AppSyncIAMAuthentication +from .exceptions import TransportProtocolError, TransportServerError +from .websockets import WebsocketsTransport, WebsocketsTransportBase + +log = logging.getLogger("gql.transport.appsync") + +try: + import botocore +except ImportError: # pragma: no cover + # botocore is only needed for the IAM AppSync authentication method + pass + + +class AppSyncWebsocketsTransport(WebsocketsTransportBase): + """:ref:`Async Transport ` used to execute GraphQL subscription on + AWS appsync realtime endpoint. + + This transport uses asyncio and the websockets library in order to send requests + on a websocket connection. + """ + + auth: Optional[AppSyncAuthentication] + + def __init__( + self, + url: str, + auth: Optional[AppSyncAuthentication] = None, + session: Optional["botocore.session.Session"] = None, + ssl: Union[SSLContext, bool] = False, + connect_timeout: int = 10, + close_timeout: int = 10, + ack_timeout: int = 10, + keep_alive_timeout: Optional[Union[int, float]] = None, + connect_args: Dict[str, Any] = {}, + ) -> None: + """Initialize the transport with the given parameters. + + :param url: The GraphQL endpoint URL. Example: + https://XXXXXXXXXXXXXXXXXXXXXXXXXX.appsync-api.REGION.amazonaws.com/graphql + :param auth: Optional AWS authentication class which will provide the + necessary headers to be correctly authenticated. If this + argument is not provided, then we will try to authenticate + using IAM. + :param ssl: ssl_context of the connection. + :param connect_timeout: Timeout in seconds for the establishment + of the websocket connection. If None is provided this will wait forever. + :param close_timeout: Timeout in seconds for the close. If None is provided + this will wait forever. + :param ack_timeout: Timeout in seconds to wait for the connection_ack message + from the server. If None is provided this will wait forever. + :param keep_alive_timeout: Optional Timeout in seconds to receive + a sign of liveness from the server. 
+ :param connect_args: Other parameters forwarded to websockets.connect + """ + + if not auth: + + # Extract host from url + host = str(urlparse(url).netloc) + + # May raise NoRegionError or NoCredentialsError or ImportError + auth = AppSyncIAMAuthentication(host=host, session=session) + + self.auth = auth + + url = self.auth.get_auth_url(url) + + super().__init__( + url, + ssl=ssl, + connect_timeout=connect_timeout, + close_timeout=close_timeout, + ack_timeout=ack_timeout, + keep_alive_timeout=keep_alive_timeout, + connect_args=connect_args, + ) + + # Using the same 'graphql-ws' protocol as the apollo protocol + self.supported_subprotocols = [ + WebsocketsTransport.APOLLO_SUBPROTOCOL, + ] + self.subprotocol = WebsocketsTransport.APOLLO_SUBPROTOCOL + + def _parse_answer( + self, answer: str + ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: + """Parse the answer received from the server. + + Difference between apollo protocol and aws protocol: + + - aws protocol can return an error without an id + - aws protocol will send start_ack messages + + Returns a list consisting of: + - the answer_type: + - 'connection_ack', + - 'connection_error', + - 'start_ack', + - 'ka', + - 'data', + - 'error', + - 'complete' + - the answer id (Integer) if received or None + - an execution Result if the answer_type is 'data' or None + """ + + answer_type: str = "" + + try: + json_answer = json.loads(answer) + + answer_type = str(json_answer.get("type")) + + if answer_type == "start_ack": + return ("start_ack", None, None) + + elif answer_type == "error" and "id" not in json_answer: + error_payload = json_answer.get("payload") + raise TransportServerError(f"Server error: '{error_payload!r}'") + + else: + + return WebsocketsTransport._parse_answer_apollo( + cast(WebsocketsTransport, self), json_answer + ) + + except ValueError: + raise TransportProtocolError( + f"Server did not return a GraphQL result: {answer}" + ) + + async def _send_query( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + ) -> int: + + query_id = self.next_query_id + + self.next_query_id += 1 + + data: Dict = {"query": print_ast(document)} + + if variable_values: + data["variables"] = variable_values + + if operation_name: + data["operationName"] = operation_name + + serialized_data = json.dumps(data, separators=(",", ":")) + + payload = {"data": serialized_data} + + message: Dict = { + "id": str(query_id), + "type": "start", + "payload": payload, + } + + assert self.auth is not None + + message["payload"]["extensions"] = { + "authorization": self.auth.get_headers(serialized_data) + } + + await self._send(json.dumps(message, separators=(",", ":"),)) + + return query_id + + subscribe = WebsocketsTransportBase.subscribe + """Send a subscription query and receive the results using + a python async generator. + + Only subscriptions are supported, queries and mutations are forbidden. + + The results are sent as an ExecutionResult object. + """ + + async def execute( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + ) -> ExecutionResult: + """This method is not available. + + Only subscriptions are supported on the AWS realtime endpoint. + + :raise: AssertionError""" + raise AssertionError( + "execute method is not allowed for AppSyncWebsocketsTransport " + "because only subscriptions are allowed on the realtime endpoint." 
+ ) + + _initialize = WebsocketsTransport._initialize + _stop_listener = WebsocketsTransport._send_stop_message # type: ignore + _send_init_message_and_wait_ack = ( + WebsocketsTransport._send_init_message_and_wait_ack + ) + _wait_ack = WebsocketsTransport._wait_ack diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 779a3608..41478daf 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -1,85 +1,25 @@ import asyncio import json import logging -import warnings from contextlib import suppress from ssl import SSLContext -from typing import Any, AsyncGenerator, Dict, Optional, Tuple, Union, cast +from typing import Any, Dict, Optional, Tuple, Union, cast -import websockets from graphql import DocumentNode, ExecutionResult, print_ast -from websockets.client import WebSocketClientProtocol from websockets.datastructures import HeadersLike -from websockets.exceptions import ConnectionClosed -from websockets.typing import Data, Subprotocol +from websockets.typing import Subprotocol -from .async_transport import AsyncTransport from .exceptions import ( - TransportAlreadyConnected, - TransportClosed, TransportProtocolError, TransportQueryError, TransportServerError, ) +from .websockets_base import WebsocketsTransportBase log = logging.getLogger(__name__) -ParsedAnswer = Tuple[str, Optional[ExecutionResult]] - -class ListenerQueue: - """Special queue used for each query waiting for server answers - - If the server is stopped while the listener is still waiting, - Then we send an exception to the queue and this exception will be raised - to the consumer once all the previous messages have been consumed from the queue - """ - - def __init__(self, query_id: int, send_stop: bool) -> None: - self.query_id: int = query_id - self.send_stop: bool = send_stop - self._queue: asyncio.Queue = asyncio.Queue() - self._closed: bool = False - - async def get(self) -> ParsedAnswer: - - try: - item = self._queue.get_nowait() - except asyncio.QueueEmpty: - item = await self._queue.get() - - self._queue.task_done() - - # If we receive an exception when reading the queue, we raise it - if isinstance(item, Exception): - self._closed = True - raise item - - # Don't need to save new answers or - # send the stop message if we already received the complete message - answer_type, execution_result = item - if answer_type == "complete": - self.send_stop = False - self._closed = True - - return item - - async def put(self, item: ParsedAnswer) -> None: - - if not self._closed: - await self._queue.put(item) - - async def set_exception(self, exception: Exception) -> None: - - # Put the exception in the queue - await self._queue.put(exception) - - # Don't need to send stop messages in case of error - self.send_stop = False - self._closed = True - - -class WebsocketsTransport(AsyncTransport): +class WebsocketsTransport(WebsocketsTransportBase): """:ref:`Async Transport ` used to execute GraphQL queries on remote servers with websocket connection. 
@@ -133,15 +73,18 @@ def __init__( :param connect_args: Other parameters forwarded to websockets.connect """ - self.url: str = url - self.ssl: Union[SSLContext, bool] = ssl - self.headers: Optional[HeadersLike] = headers - self.init_payload: Dict[str, Any] = init_payload + super().__init__( + url, + headers, + ssl, + init_payload, + connect_timeout, + close_timeout, + ack_timeout, + keep_alive_timeout, + connect_args, + ) - self.connect_timeout: Optional[Union[int, float]] = connect_timeout - self.close_timeout: Optional[Union[int, float]] = close_timeout - self.ack_timeout: Optional[Union[int, float]] = ack_timeout - self.keep_alive_timeout: Optional[Union[int, float]] = keep_alive_timeout self.ping_interval: Optional[Union[int, float]] = ping_interval self.pong_timeout: Optional[Union[int, float]] self.answer_pings: bool = answer_pings @@ -152,38 +95,7 @@ def __init__( else: self.pong_timeout = pong_timeout - self.connect_args = connect_args - - self.websocket: Optional[WebSocketClientProtocol] = None - self.next_query_id: int = 1 - self.listeners: Dict[int, ListenerQueue] = {} - - self.receive_data_task: Optional[asyncio.Future] = None - self.check_keep_alive_task: Optional[asyncio.Future] = None self.send_ping_task: Optional[asyncio.Future] = None - self.close_task: Optional[asyncio.Future] = None - - # We need to set an event loop here if there is none - # Or else we will not be able to create an asyncio.Event() - try: - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", message="There is no current event loop" - ) - self._loop = asyncio.get_event_loop() - except RuntimeError: - self._loop = asyncio.new_event_loop() - asyncio.set_event_loop(self._loop) - - self._wait_closed: asyncio.Event = asyncio.Event() - self._wait_closed.set() - - self._no_more_listeners: asyncio.Event = asyncio.Event() - self._no_more_listeners.set() - - if self.keep_alive_timeout is not None: - self._next_keep_alive_message: asyncio.Event = asyncio.Event() - self._next_keep_alive_message.set() self.ping_received: asyncio.Event = asyncio.Event() """ping_received is an asyncio Event which will fire each time @@ -193,56 +105,11 @@ def __init__( """pong_received is an asyncio Event which will fire each time a pong is received with the graphql-ws protocol""" - self.payloads: Dict[str, Any] = {} - """payloads is a dict which will contain the payloads received - with the graphql-ws protocol. - Possible keys are: 'ping', 'pong', 'connection_ack'""" - - self._connecting: bool = False - - self.close_exception: Optional[Exception] = None - self.supported_subprotocols = [ self.APOLLO_SUBPROTOCOL, self.GRAPHQLWS_SUBPROTOCOL, ] - async def _send(self, message: str) -> None: - """Send the provided message to the websocket connection and log the message""" - - if not self.websocket: - raise TransportClosed( - "Transport is not connected" - ) from self.close_exception - - try: - await self.websocket.send(message) - log.info(">>> %s", message) - except ConnectionClosed as e: - await self._fail(e, clean_close=False) - raise e - - async def _receive(self) -> str: - """Wait the next message from the websocket connection and log the answer""" - - # It is possible that the websocket has been already closed in another task - if self.websocket is None: - raise TransportClosed("Transport is already closed") - - # Wait for the next websocket frame. 
Can raise ConnectionClosed - data: Data = await self.websocket.recv() - - # websocket.recv() can return either str or bytes - # In our case, we should receive only str here - if not isinstance(data, str): - raise TransportProtocolError("Binary data received in the websocket") - - answer: str = data - - log.info("<<< %s", answer) - - return answer - async def _wait_ack(self) -> None: """Wait for the connection_ack message. Keep alive messages are ignored""" @@ -274,6 +141,9 @@ async def _send_init_message_and_wait_ack(self) -> None: # Wait for the connection_ack message or raise a TimeoutError await asyncio.wait_for(self._wait_ack(), self.ack_timeout) + async def _initialize(self): + await self._send_init_message_and_wait_ack() + async def send_ping(self, payload: Optional[Any] = None) -> None: """Send a ping message for the graphql-ws protocol """ @@ -316,7 +186,7 @@ async def _send_complete_message(self, query_id: int) -> None: await self._send(complete_message) - async def _stop_listener(self, query_id: int) -> None: + async def _stop_listener(self, query_id: int): """Stop the listener corresponding to the query_id depending on the detected backend protocol. @@ -326,6 +196,8 @@ async def _stop_listener(self, query_id: int) -> None: For graphql-ws: send a "complete" message and simulate the reception of a "complete" message from the backend """ + log.debug(f"stop listener {query_id}") + if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: await self._send_complete_message(query_id) await self.listeners[query_id].put(("complete", None)) @@ -377,8 +249,12 @@ async def _send_query( return query_id + async def _connection_terminate(self): + if self.subprotocol == self.APOLLO_SUBPROTOCOL: + await self._send_connection_terminate_message() + def _parse_answer_graphqlws( - self, answer: str + self, json_answer: Dict[str, Any] ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: """Parse the answer received from the server if the server supports the graphql-ws protocol. @@ -403,8 +279,6 @@ def _parse_answer_graphqlws( execution_result: Optional[ExecutionResult] = None try: - json_answer = json.loads(answer) - answer_type = str(json_answer.get("type")) if answer_type in ["next", "error", "complete"]: @@ -450,13 +324,13 @@ def _parse_answer_graphqlws( except ValueError as e: raise TransportProtocolError( - f"Server did not return a GraphQL result: {answer}" + f"Server did not return a GraphQL result: {json_answer}" ) from e return answer_type, answer_id, execution_result def _parse_answer_apollo( - self, answer: str + self, json_answer: Dict[str, Any] ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: """Parse the answer received from the server if the server supports the apollo websockets protocol. @@ -473,8 +347,6 @@ def _parse_answer_apollo( execution_result: Optional[ExecutionResult] = None try: - json_answer = json.loads(answer) - answer_type = str(json_answer.get("type")) if answer_type in ["data", "error", "complete"]: @@ -520,7 +392,7 @@ def _parse_answer_apollo( except ValueError as e: raise TransportProtocolError( - f"Server did not return a GraphQL result: {answer}" + f"Server did not return a GraphQL result: {json_answer}" ) from e return answer_type, answer_id, execution_result @@ -531,44 +403,17 @@ def _parse_answer( """Parse the answer received from the server depending on the detected subprotocol. 
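# Hedged usage sketch: keep-alive and graphql-ws ping settings on WebsocketsTransport
# (URL and values are placeholders); the transport is then passed to Client as usual.
from gql.transport.websockets import WebsocketsTransport

transport = WebsocketsTransport(
    url="wss://example.com/graphql",
    keep_alive_timeout=60,  # fail the transport if no sign of liveness is received for 60s
    ping_interval=30,       # graphql-ws subprotocol only: send a ping every 30s
    pong_timeout=10,        # ... and fail if no pong is received within 10s
)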
""" - if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: - return self._parse_answer_graphqlws(answer) - - return self._parse_answer_apollo(answer) - - async def _check_ws_liveness(self) -> None: - """Coroutine which will periodically check the liveness of the connection - through keep-alive messages - """ - try: - while True: - await asyncio.wait_for( - self._next_keep_alive_message.wait(), self.keep_alive_timeout - ) - - # Reset for the next iteration - self._next_keep_alive_message.clear() - - except asyncio.TimeoutError: - # No keep-alive message in the appriopriate interval, close with error - # while trying to notify the server of a proper close (in case - # the keep-alive interval of the client or server was not aligned - # the connection still remains) + json_answer = json.loads(answer) + except ValueError: + raise TransportProtocolError( + f"Server did not return a GraphQL result: {answer}" + ) - # If the timeout happens during a close already in progress, do nothing - if self.close_task is None: - await self._fail( - TransportServerError( - "No keep-alive message has been received within " - "the expected interval ('keep_alive_timeout' parameter)" - ), - clean_close=False, - ) + if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: + return self._parse_answer_graphqlws(json_answer) - except asyncio.CancelledError: - # The client is probably closing, handle it properly - pass + return self._parse_answer_apollo(json_answer) async def _send_ping_coro(self) -> None: """Coroutine to periodically send a ping from the client to the backend. @@ -603,52 +448,6 @@ async def _send_ping_coro(self) -> None: clean_close=False, ) - async def _receive_data_loop(self) -> None: - try: - while True: - - # Wait the next answer from the websocket server - try: - answer = await self._receive() - except (ConnectionClosed, TransportProtocolError) as e: - await self._fail(e, clean_close=False) - break - except TransportClosed: - break - - # Parse the answer - try: - answer_type, answer_id, execution_result = self._parse_answer( - answer - ) - except TransportQueryError as e: - # Received an exception for a specific query - # ==> Add an exception to this query queue - # The exception is raised for this specific query, - # but the transport is not closed. - assert isinstance( - e.query_id, int - ), "TransportQueryError should have a query_id defined here" - try: - await self.listeners[e.query_id].set_exception(e) - except KeyError: - # Do nothing if no one is listening to this query_id - pass - - continue - - except (TransportServerError, TransportProtocolError) as e: - # Received a global exception for this transport - # ==> close the transport - # The exception will be raised for all current queries. - await self._fail(e, clean_close=False) - break - - await self._handle_answer(answer_type, answer_id, execution_result) - - finally: - log.debug("Exiting _receive_data_loop()") - async def _handle_answer( self, answer_type: str, @@ -656,13 +455,8 @@ async def _handle_answer( execution_result: Optional[ExecutionResult], ) -> None: - try: - # Put the answer in the queue - if answer_id is not None: - await self.listeners[answer_id].put((answer_type, execution_result)) - except KeyError: - # Do nothing if no one is listening to this query_id. 
- pass + # Put the answer in the queue + await super()._handle_answer(answer_type, answer_id, execution_result) # Answer pong to ping for graphql-ws protocol if answer_type == "ping": @@ -673,334 +467,34 @@ async def _handle_answer( elif answer_type == "pong": self.pong_received.set() - async def subscribe( - self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, - send_stop: Optional[bool] = True, - ) -> AsyncGenerator[ExecutionResult, None]: - """Send a query and receive the results using a python async generator. - - The query can be a graphql query, mutation or subscription. - - The results are sent as an ExecutionResult object. - """ - - # Send the query and receive the id - query_id: int = await self._send_query( - document, variable_values, operation_name - ) - - # Create a queue to receive the answers for this query_id - listener = ListenerQueue(query_id, send_stop=(send_stop is True)) - self.listeners[query_id] = listener - - # We will need to wait at close for this query to clean properly - self._no_more_listeners.clear() - - try: - # Loop over the received answers - while True: - - # Wait for the answer from the queue of this query_id - # This can raise a TransportError or ConnectionClosed exception. - answer_type, execution_result = await listener.get() - - # If the received answer contains data, - # Then we will yield the results back as an ExecutionResult object - if execution_result is not None: - yield execution_result - - # If we receive a 'complete' answer from the server, - # Then we will end this async generator output without errors - elif answer_type == "complete": - log.debug( - f"Complete received for query {query_id} --> exit without error" - ) - break - - except (asyncio.CancelledError, GeneratorExit) as e: - log.debug(f"Exception in subscribe: {e!r}") - if listener.send_stop: - await self._stop_listener(query_id) - listener.send_stop = False - - finally: - log.debug(f"In subscribe finally for query_id {query_id}") - self._remove_listener(query_id) - - async def execute( - self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, - ) -> ExecutionResult: - """Execute the provided document AST against the configured remote server - using the current session. - - Send a query but close the async generator as soon as we have the first answer. - - The result is sent as an ExecutionResult object. 
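# Hedged sketch of the pattern used by execute(), mirroring the code above: open the
# subscribe() async generator, keep the first result, close the generator explicitly
# so its cleanup code runs, then return that single result.
from gql.transport.exceptions import TransportQueryError

async def execute_via_subscribe(transport, document):
    first_result = None
    generator = transport.subscribe(document, send_stop=False)
    async for result in generator:
        first_result = result
        await generator.aclose()  # make sure the generator's finally block runs
        break
    if first_result is None:
        raise TransportQueryError(
            "Query completed without any answer received from the server"
        )
    return first_result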
- """ - first_result = None - - generator = self.subscribe( - document, variable_values, operation_name, send_stop=False - ) - - async for result in generator: - first_result = result - - # Note: we need to run generator.aclose() here or the finally block in - # the subscribe will not be reached in pypy3 (python version 3.6.1) - await generator.aclose() - - break - - if first_result is None: - raise TransportQueryError( - "Query completed without any answer received from the server" - ) - - return first_result - - async def connect(self) -> None: - """Coroutine which will: - - - connect to the websocket address - - send the init message - - wait for the connection acknowledge from the server - - create an asyncio task which will be used to receive - and parse the websocket answers - - Should be cleaned with a call to the close coroutine - """ - - log.debug("connect: starting") - - if self.websocket is None and not self._connecting: - - # Set connecting to True to avoid a race condition if user is trying - # to connect twice using the same client at the same time - self._connecting = True - - # If the ssl parameter is not provided, - # generate the ssl value depending on the url - ssl: Optional[Union[SSLContext, bool]] - if self.ssl: - ssl = self.ssl - else: - ssl = True if self.url.startswith("wss") else None - - # Set default arguments used in the websockets.connect call - connect_args: Dict[str, Any] = { - "ssl": ssl, - "extra_headers": self.headers, - "subprotocols": self.supported_subprotocols, - } - - # Adding custom parameters passed from init - connect_args.update(self.connect_args) - - # Connection to the specified url - # Generate a TimeoutError if taking more than connect_timeout seconds - # Set the _connecting flag to False after in all cases - try: - self.websocket = await asyncio.wait_for( - websockets.client.connect(self.url, **connect_args), - self.connect_timeout, - ) - finally: - self._connecting = False - - self.websocket = cast(WebSocketClientProtocol, self.websocket) - - # Find the backend subprotocol returned in the response headers - response_headers = self.websocket.response_headers - try: - self.subprotocol = response_headers["Sec-WebSocket-Protocol"] - except KeyError: - # If the server does not send the subprotocol header, using - # the apollo subprotocol by default - self.subprotocol = self.APOLLO_SUBPROTOCOL - - log.debug(f"backend subprotocol returned: {self.subprotocol!r}") - - self.next_query_id = 1 - self.close_exception = None - self._wait_closed.clear() - - # Send the init message and wait for the ack from the server - # Note: This will generate a TimeoutError - # if no ACKs are received within the ack_timeout - try: - await self._send_init_message_and_wait_ack() - except ConnectionClosed as e: - raise e - except (TransportProtocolError, asyncio.TimeoutError) as e: - await self._fail(e, clean_close=False) - raise e - - # If specified, create a task to check liveness of the connection - # through keep-alive messages - if self.keep_alive_timeout is not None: - self.check_keep_alive_task = asyncio.ensure_future( - self._check_ws_liveness() - ) - - # If requested, create a task to send periodic pings to the backend - if ( - self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL - and self.ping_interval is not None - ): - - self.send_ping_task = asyncio.ensure_future(self._send_ping_coro()) - - # Create a task to listen to the incoming websocket messages - self.receive_data_task = asyncio.ensure_future(self._receive_data_loop()) - - else: - raise 
TransportAlreadyConnected("Transport is already connected") - - log.debug("connect: done") - - def _remove_listener(self, query_id) -> None: - """After exiting from a subscription, remove the listener and - signal an event if this was the last listener for the client. - """ - if query_id in self.listeners: - del self.listeners[query_id] - - remaining = len(self.listeners) - log.debug(f"listener {query_id} deleted, {remaining} remaining") - - if remaining == 0: - self._no_more_listeners.set() - - async def _clean_close(self, e: Exception) -> None: - """Coroutine which will: - - - send stop messages for each active subscription to the server - - send the connection terminate message - """ - - # Send 'stop' message for all current queries - for query_id, listener in self.listeners.items(): - - if listener.send_stop: - await self._stop_listener(query_id) - listener.send_stop = False - - # Wait that there is no more listeners (we received 'complete' for all queries) - try: - await asyncio.wait_for(self._no_more_listeners.wait(), self.close_timeout) - except asyncio.TimeoutError: # pragma: no cover - log.debug("Timer close_timeout fired") - - if self.subprotocol == self.APOLLO_SUBPROTOCOL: - # Finally send the 'connection_terminate' message - await self._send_connection_terminate_message() - - async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: - """Coroutine which will: - - - do a clean_close if possible: - - send stop messages for each active query to the server - - send the connection terminate message - - close the websocket connection - - send the exception to all the remaining listeners - """ - - log.debug("_close_coro: starting") + async def _after_connect(self): + # Find the backend subprotocol returned in the response headers + response_headers = self.websocket.response_headers try: + self.subprotocol = response_headers["Sec-WebSocket-Protocol"] + except KeyError: + # If the server does not send the subprotocol header, using + # the apollo subprotocol by default + self.subprotocol = self.APOLLO_SUBPROTOCOL - # We should always have an active websocket connection here - assert self.websocket is not None - - # Properly shut down liveness checker if enabled - if self.check_keep_alive_task is not None: - # More info: https://stackoverflow.com/a/43810272/1113207 - self.check_keep_alive_task.cancel() - with suppress(asyncio.CancelledError): - await self.check_keep_alive_task - - # Properly shut down the send ping task if enabled - if self.send_ping_task is not None: - self.send_ping_task.cancel() - with suppress(asyncio.CancelledError): - await self.send_ping_task - - # Saving exception to raise it later if trying to use the transport - # after it has already closed. 
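# Hedged sketch: re-opening a websocket client after the transport closes or fails
# (endpoint, backoff delay and subscription query are placeholders).
import asyncio

from gql import Client, gql
from gql.transport.exceptions import TransportClosed, TransportServerError
from gql.transport.websockets import WebsocketsTransport

async def subscribe_with_retries(url: str):
    subscription = gql("subscription { somethingChanged { id } }")
    while True:
        transport = WebsocketsTransport(url=url)
        try:
            async with Client(transport=transport) as session:
                async for result in session.subscribe(subscription):
                    print(result)
        except (TransportClosed, TransportServerError) as exc:
            print(f"Connection lost ({exc!r}), reconnecting in 5 seconds")
            await asyncio.sleep(5)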
- self.close_exception = e - - if clean_close: - log.debug("_close_coro: starting clean_close") - try: - await self._clean_close(e) - except Exception as exc: # pragma: no cover - log.warning("Ignoring exception in _clean_close: " + repr(exc)) - - log.debug("_close_coro: sending exception to listeners") - - # Send an exception to all remaining listeners - for query_id, listener in self.listeners.items(): - await listener.set_exception(e) - - log.debug("_close_coro: close websocket connection") - - await self.websocket.close() + log.debug(f"backend subprotocol returned: {self.subprotocol!r}") - log.debug("_close_coro: websocket connection closed") + async def _after_initialize(self): - except Exception as exc: # pragma: no cover - log.warning("Exception catched in _close_coro: " + repr(exc)) + # If requested, create a task to send periodic pings to the backend + if ( + self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL + and self.ping_interval is not None + ): - finally: + self.send_ping_task = asyncio.ensure_future(self._send_ping_coro()) - log.debug("_close_coro: start cleanup") + async def _close_hook(self): - self.websocket = None - self.close_task = None - self.check_keep_alive_task = None + # Properly shut down the send ping task if enabled + if self.send_ping_task is not None: + self.send_ping_task.cancel() + with suppress(asyncio.CancelledError): + await self.send_ping_task self.send_ping_task = None - - self._wait_closed.set() - - log.debug("_close_coro: exiting") - - async def _fail(self, e: Exception, clean_close: bool = True) -> None: - log.debug("_fail: starting with exception: " + repr(e)) - - if self.close_task is None: - - if self.websocket is None: - log.debug("_fail started with self.websocket == None -> already closed") - else: - self.close_task = asyncio.shield( - asyncio.ensure_future(self._close_coro(e, clean_close=clean_close)) - ) - else: - log.debug( - "close_task is not None in _fail. 
Previous exception is: " - + repr(self.close_exception) - + " New exception is: " - + repr(e) - ) - - async def close(self) -> None: - log.debug("close: starting") - - await self._fail(TransportClosed("Websocket GraphQL transport closed by user")) - await self.wait_closed() - - log.debug("close: done") - - async def wait_closed(self) -> None: - log.debug("wait_close: starting") - - await self._wait_closed.wait() - - log.debug("wait_close: done") diff --git a/gql/transport/websockets_base.py b/gql/transport/websockets_base.py new file mode 100644 index 00000000..151e444e --- /dev/null +++ b/gql/transport/websockets_base.py @@ -0,0 +1,666 @@ +import asyncio +import logging +import warnings +from abc import abstractmethod +from contextlib import suppress +from ssl import SSLContext +from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union, cast + +import websockets +from graphql import DocumentNode, ExecutionResult +from websockets.client import WebSocketClientProtocol +from websockets.datastructures import HeadersLike +from websockets.exceptions import ConnectionClosed +from websockets.typing import Data, Subprotocol + +from .async_transport import AsyncTransport +from .exceptions import ( + TransportAlreadyConnected, + TransportClosed, + TransportProtocolError, + TransportQueryError, + TransportServerError, +) + +log = logging.getLogger("gql.transport.websockets") + +ParsedAnswer = Tuple[str, Optional[ExecutionResult]] + + +class ListenerQueue: + """Special queue used for each query waiting for server answers + + If the server is stopped while the listener is still waiting, + Then we send an exception to the queue and this exception will be raised + to the consumer once all the previous messages have been consumed from the queue + """ + + def __init__(self, query_id: int, send_stop: bool) -> None: + self.query_id: int = query_id + self.send_stop: bool = send_stop + self._queue: asyncio.Queue = asyncio.Queue() + self._closed: bool = False + + async def get(self) -> ParsedAnswer: + + try: + item = self._queue.get_nowait() + except asyncio.QueueEmpty: + item = await self._queue.get() + + self._queue.task_done() + + # If we receive an exception when reading the queue, we raise it + if isinstance(item, Exception): + self._closed = True + raise item + + # Don't need to save new answers or + # send the stop message if we already received the complete message + answer_type, execution_result = item + if answer_type == "complete": + self.send_stop = False + self._closed = True + + return item + + async def put(self, item: ParsedAnswer) -> None: + + if not self._closed: + await self._queue.put(item) + + async def set_exception(self, exception: Exception) -> None: + + # Put the exception in the queue + await self._queue.put(exception) + + # Don't need to send stop messages in case of error + self.send_stop = False + self._closed = True + + +class WebsocketsTransportBase(AsyncTransport): + """abstract :ref:`Async Transport ` used to implement + different websockets protocols. + + This transport uses asyncio and the websockets library in order to send requests + on a websocket connection. 
+ """ + + def __init__( + self, + url: str, + headers: Optional[HeadersLike] = None, + ssl: Union[SSLContext, bool] = False, + init_payload: Dict[str, Any] = {}, + connect_timeout: Optional[Union[int, float]] = 10, + close_timeout: Optional[Union[int, float]] = 10, + ack_timeout: Optional[Union[int, float]] = 10, + keep_alive_timeout: Optional[Union[int, float]] = None, + connect_args: Dict[str, Any] = {}, + ) -> None: + """Initialize the transport with the given parameters. + + :param url: The GraphQL server URL. Example: 'wss://server.com:PORT/graphql'. + :param headers: Dict of HTTP Headers. + :param ssl: ssl_context of the connection. Use ssl=False to disable encryption + :param init_payload: Dict of the payload sent in the connection_init message. + :param connect_timeout: Timeout in seconds for the establishment + of the websocket connection. If None is provided this will wait forever. + :param close_timeout: Timeout in seconds for the close. If None is provided + this will wait forever. + :param ack_timeout: Timeout in seconds to wait for the connection_ack message + from the server. If None is provided this will wait forever. + :param keep_alive_timeout: Optional Timeout in seconds to receive + a sign of liveness from the server. + :param connect_args: Other parameters forwarded to websockets.connect + """ + + self.url: str = url + self.headers: Optional[HeadersLike] = headers + self.ssl: Union[SSLContext, bool] = ssl + self.init_payload: Dict[str, Any] = init_payload + + self.connect_timeout: Optional[Union[int, float]] = connect_timeout + self.close_timeout: Optional[Union[int, float]] = close_timeout + self.ack_timeout: Optional[Union[int, float]] = ack_timeout + self.keep_alive_timeout: Optional[Union[int, float]] = keep_alive_timeout + + self.connect_args = connect_args + + self.websocket: Optional[WebSocketClientProtocol] = None + self.next_query_id: int = 1 + self.listeners: Dict[int, ListenerQueue] = {} + + self.receive_data_task: Optional[asyncio.Future] = None + self.check_keep_alive_task: Optional[asyncio.Future] = None + self.close_task: Optional[asyncio.Future] = None + + # We need to set an event loop here if there is none + # Or else we will not be able to create an asyncio.Event() + try: + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message="There is no current event loop" + ) + self._loop = asyncio.get_event_loop() + except RuntimeError: + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) + + self._wait_closed: asyncio.Event = asyncio.Event() + self._wait_closed.set() + + self._no_more_listeners: asyncio.Event = asyncio.Event() + self._no_more_listeners.set() + + if self.keep_alive_timeout is not None: + self._next_keep_alive_message: asyncio.Event = asyncio.Event() + self._next_keep_alive_message.set() + + self.payloads: Dict[str, Any] = {} + """payloads is a dict which will contain the payloads received + for example with the graphql-ws protocol: 'ping', 'pong', 'connection_ack'""" + + self._connecting: bool = False + + self.close_exception: Optional[Exception] = None + + # The list of supported subprotocols should be defined in the subclass + self.supported_subprotocols: List[Subprotocol] = [] + + async def _initialize(self): + """Hook to send the initialization messages after the connection + and potentially wait for the backend ack. + """ + pass # pragma: no cover + + async def _stop_listener(self, query_id: int): + """Hook to stop to listen to a specific query. + Will send a stop message in some subclasses. 
+ """ + pass # pragma: no cover + + async def _after_connect(self): + """Hook to add custom code for subclasses after the connection + has been established. + """ + pass # pragma: no cover + + async def _after_initialize(self): + """Hook to add custom code for subclasses after the initialization + has been done. + """ + pass # pragma: no cover + + async def _close_hook(self): + """Hook to add custom code for subclasses for the connection close + """ + pass # pragma: no cover + + async def _connection_terminate(self): + """Hook to add custom code for subclasses after the initialization + has been done. + """ + pass # pragma: no cover + + async def _send(self, message: str) -> None: + """Send the provided message to the websocket connection and log the message""" + + if not self.websocket: + raise TransportClosed( + "Transport is not connected" + ) from self.close_exception + + try: + await self.websocket.send(message) + log.info(">>> %s", message) + except ConnectionClosed as e: + await self._fail(e, clean_close=False) + raise e + + async def _receive(self) -> str: + """Wait the next message from the websocket connection and log the answer""" + + # It is possible that the websocket has been already closed in another task + if self.websocket is None: + raise TransportClosed("Transport is already closed") + + # Wait for the next websocket frame. Can raise ConnectionClosed + data: Data = await self.websocket.recv() + + # websocket.recv() can return either str or bytes + # In our case, we should receive only str here + if not isinstance(data, str): + raise TransportProtocolError("Binary data received in the websocket") + + answer: str = data + + log.info("<<< %s", answer) + + return answer + + @abstractmethod + async def _send_query( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + ) -> int: + raise NotImplementedError # pragma: no cover + + @abstractmethod + def _parse_answer( + self, answer: str + ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: + raise NotImplementedError # pragma: no cover + + async def _check_ws_liveness(self) -> None: + """Coroutine which will periodically check the liveness of the connection + through keep-alive messages + """ + + try: + while True: + await asyncio.wait_for( + self._next_keep_alive_message.wait(), self.keep_alive_timeout + ) + + # Reset for the next iteration + self._next_keep_alive_message.clear() + + except asyncio.TimeoutError: + # No keep-alive message in the appriopriate interval, close with error + # while trying to notify the server of a proper close (in case + # the keep-alive interval of the client or server was not aligned + # the connection still remains) + + # If the timeout happens during a close already in progress, do nothing + if self.close_task is None: + await self._fail( + TransportServerError( + "No keep-alive message has been received within " + "the expected interval ('keep_alive_timeout' parameter)" + ), + clean_close=False, + ) + + except asyncio.CancelledError: + # The client is probably closing, handle it properly + pass + + async def _receive_data_loop(self) -> None: + """Main asyncio task which will listen to the incoming messages and will + call the parse_answer and handle_answer methods of the subclass.""" + try: + while True: + + # Wait the next answer from the websocket server + try: + answer = await self._receive() + except (ConnectionClosed, TransportProtocolError) as e: + await self._fail(e, clean_close=False) + break + except 
TransportClosed: + break + + # Parse the answer + try: + answer_type, answer_id, execution_result = self._parse_answer( + answer + ) + except TransportQueryError as e: + # Received an exception for a specific query + # ==> Add an exception to this query queue + # The exception is raised for this specific query, + # but the transport is not closed. + assert isinstance( + e.query_id, int + ), "TransportQueryError should have a query_id defined here" + try: + await self.listeners[e.query_id].set_exception(e) + except KeyError: + # Do nothing if no one is listening to this query_id + pass + + continue + + except (TransportServerError, TransportProtocolError) as e: + # Received a global exception for this transport + # ==> close the transport + # The exception will be raised for all current queries. + await self._fail(e, clean_close=False) + break + + await self._handle_answer(answer_type, answer_id, execution_result) + + finally: + log.debug("Exiting _receive_data_loop()") + + async def _handle_answer( + self, + answer_type: str, + answer_id: Optional[int], + execution_result: Optional[ExecutionResult], + ) -> None: + + try: + # Put the answer in the queue + if answer_id is not None: + await self.listeners[answer_id].put((answer_type, execution_result)) + except KeyError: + # Do nothing if no one is listening to this query_id. + pass + + async def subscribe( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + send_stop: Optional[bool] = True, + ) -> AsyncGenerator[ExecutionResult, None]: + """Send a query and receive the results using a python async generator. + + The query can be a graphql query, mutation or subscription. + + The results are sent as an ExecutionResult object. + """ + + # Send the query and receive the id + query_id: int = await self._send_query( + document, variable_values, operation_name + ) + + # Create a queue to receive the answers for this query_id + listener = ListenerQueue(query_id, send_stop=(send_stop is True)) + self.listeners[query_id] = listener + + # We will need to wait at close for this query to clean properly + self._no_more_listeners.clear() + + try: + # Loop over the received answers + while True: + + # Wait for the answer from the queue of this query_id + # This can raise a TransportError or ConnectionClosed exception. + answer_type, execution_result = await listener.get() + + # If the received answer contains data, + # Then we will yield the results back as an ExecutionResult object + if execution_result is not None: + yield execution_result + + # If we receive a 'complete' answer from the server, + # Then we will end this async generator output without errors + elif answer_type == "complete": + log.debug( + f"Complete received for query {query_id} --> exit without error" + ) + break + + except (asyncio.CancelledError, GeneratorExit) as e: + log.debug(f"Exception in subscribe: {e!r}") + if listener.send_stop: + await self._stop_listener(query_id) + listener.send_stop = False + + finally: + log.debug(f"In subscribe finally for query_id {query_id}") + self._remove_listener(query_id) + + async def execute( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + ) -> ExecutionResult: + """Execute the provided document AST against the configured remote server + using the current session. + + Send a query but close the async generator as soon as we have the first answer. 
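+        Internally, this calls :meth:`subscribe` with ``send_stop=False`` and
+        closes the async generator right after the first received result.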
+ + The result is sent as an ExecutionResult object. + """ + first_result = None + + generator = self.subscribe( + document, variable_values, operation_name, send_stop=False + ) + + async for result in generator: + first_result = result + + # Note: we need to run generator.aclose() here or the finally block in + # the subscribe will not be reached in pypy3 (python version 3.6.1) + await generator.aclose() + + break + + if first_result is None: + raise TransportQueryError( + "Query completed without any answer received from the server" + ) + + return first_result + + async def connect(self) -> None: + """Coroutine which will: + + - connect to the websocket address + - send the init message + - wait for the connection acknowledge from the server + - create an asyncio task which will be used to receive + and parse the websocket answers + + Should be cleaned with a call to the close coroutine + """ + + log.debug("connect: starting") + + if self.websocket is None and not self._connecting: + + # Set connecting to True to avoid a race condition if user is trying + # to connect twice using the same client at the same time + self._connecting = True + + # If the ssl parameter is not provided, + # generate the ssl value depending on the url + ssl: Optional[Union[SSLContext, bool]] + if self.ssl: + ssl = self.ssl + else: + ssl = True if self.url.startswith("wss") else None + + # Set default arguments used in the websockets.connect call + connect_args: Dict[str, Any] = { + "ssl": ssl, + "extra_headers": self.headers, + "subprotocols": self.supported_subprotocols, + } + + # Adding custom parameters passed from init + connect_args.update(self.connect_args) + + # Connection to the specified url + # Generate a TimeoutError if taking more than connect_timeout seconds + # Set the _connecting flag to False after in all cases + try: + self.websocket = await asyncio.wait_for( + websockets.client.connect(self.url, **connect_args), + self.connect_timeout, + ) + finally: + self._connecting = False + + self.websocket = cast(WebSocketClientProtocol, self.websocket) + + # Run the after_connect hook of the subclass + await self._after_connect() + + self.next_query_id = 1 + self.close_exception = None + self._wait_closed.clear() + + # Send the init message and wait for the ack from the server + # Note: This should generate a TimeoutError + # if no ACKs are received within the ack_timeout + try: + await self._initialize() + except ConnectionClosed as e: + raise e + except (TransportProtocolError, asyncio.TimeoutError) as e: + await self._fail(e, clean_close=False) + raise e + + # Run the after_init hook of the subclass + await self._after_initialize() + + # If specified, create a task to check liveness of the connection + # through keep-alive messages + if self.keep_alive_timeout is not None: + self.check_keep_alive_task = asyncio.ensure_future( + self._check_ws_liveness() + ) + + # Create a task to listen to the incoming websocket messages + self.receive_data_task = asyncio.ensure_future(self._receive_data_loop()) + + else: + raise TransportAlreadyConnected("Transport is already connected") + + log.debug("connect: done") + + def _remove_listener(self, query_id) -> None: + """After exiting from a subscription, remove the listener and + signal an event if this was the last listener for the client. 
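+        The ``_no_more_listeners`` event set here is what :meth:`_clean_close`
+        waits on during shutdown, so the transport only closes once every
+        pending query has received its ``complete`` answer or the close
+        timeout has expired.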
+ """ + if query_id in self.listeners: + del self.listeners[query_id] + + remaining = len(self.listeners) + log.debug(f"listener {query_id} deleted, {remaining} remaining") + + if remaining == 0: + self._no_more_listeners.set() + + async def _clean_close(self, e: Exception) -> None: + """Coroutine which will: + + - send stop messages for each active subscription to the server + - send the connection terminate message + """ + + # Send 'stop' message for all current queries + for query_id, listener in self.listeners.items(): + + if listener.send_stop: + await self._stop_listener(query_id) + listener.send_stop = False + + # Wait that there is no more listeners (we received 'complete' for all queries) + try: + await asyncio.wait_for(self._no_more_listeners.wait(), self.close_timeout) + except asyncio.TimeoutError: # pragma: no cover + log.debug("Timer close_timeout fired") + + # Calling the subclass hook + await self._connection_terminate() + + async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: + """Coroutine which will: + + - do a clean_close if possible: + - send stop messages for each active query to the server + - send the connection terminate message + - close the websocket connection + - send the exception to all the remaining listeners + """ + + log.debug("_close_coro: starting") + + try: + + # We should always have an active websocket connection here + assert self.websocket is not None + + # Properly shut down liveness checker if enabled + if self.check_keep_alive_task is not None: + # More info: https://stackoverflow.com/a/43810272/1113207 + self.check_keep_alive_task.cancel() + with suppress(asyncio.CancelledError): + await self.check_keep_alive_task + + # Calling the subclass close hook + await self._close_hook() + + # Saving exception to raise it later if trying to use the transport + # after it has already closed. + self.close_exception = e + + if clean_close: + log.debug("_close_coro: starting clean_close") + try: + await self._clean_close(e) + except Exception as exc: # pragma: no cover + log.warning("Ignoring exception in _clean_close: " + repr(exc)) + + log.debug("_close_coro: sending exception to listeners") + + # Send an exception to all remaining listeners + for query_id, listener in self.listeners.items(): + await listener.set_exception(e) + + log.debug("_close_coro: close websocket connection") + + await self.websocket.close() + + log.debug("_close_coro: websocket connection closed") + + except Exception as exc: # pragma: no cover + log.warning("Exception catched in _close_coro: " + repr(exc)) + + finally: + + log.debug("_close_coro: start cleanup") + + self.websocket = None + self.close_task = None + self.check_keep_alive_task = None + self._wait_closed.set() + + log.debug("_close_coro: exiting") + + async def _fail(self, e: Exception, clean_close: bool = True) -> None: + log.debug("_fail: starting with exception: " + repr(e)) + + if self.close_task is None: + + if self.websocket is None: + log.debug("_fail started with self.websocket == None -> already closed") + else: + self.close_task = asyncio.shield( + asyncio.ensure_future(self._close_coro(e, clean_close=clean_close)) + ) + else: + log.debug( + "close_task is not None in _fail. 
Previous exception is: " + + repr(self.close_exception) + + " New exception is: " + + repr(e) + ) + + async def close(self) -> None: + log.debug("close: starting") + + await self._fail(TransportClosed("Websocket GraphQL transport closed by user")) + await self.wait_closed() + + log.debug("close: done") + + async def wait_closed(self) -> None: + log.debug("wait_close: starting") + + await self._wait_closed.wait() + + log.debug("wait_close: done") diff --git a/setup.py b/setup.py index 266fbb0c..7e97f8bc 100644 --- a/setup.py +++ b/setup.py @@ -50,8 +50,12 @@ "websockets>=10,<11;python_version>'3.6'", ] +install_botocore_requires = [ + "botocore>=1.21,<2", +] + install_all_requires = ( - install_aiohttp_requires + install_requests_requires + install_websockets_requires + install_aiohttp_requires + install_requests_requires + install_websockets_requires + install_botocore_requires ) # Get version from __version__.py file @@ -97,6 +101,7 @@ "aiohttp": install_aiohttp_requires, "requests": install_requests_requires, "websockets": install_websockets_requires, + "botocore": install_botocore_requires, }, include_package_data=True, zip_safe=False, diff --git a/tests/conftest.py b/tests/conftest.py index c0101241..d433c1ca 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,11 +14,7 @@ from gql import Client -all_transport_dependencies = [ - "aiohttp", - "requests", - "websockets", -] +all_transport_dependencies = ["aiohttp", "requests", "websockets", "botocore"] def pytest_addoption(parser): @@ -116,10 +112,11 @@ async def ssl_aiohttp_server(): yield server -# Adding debug logs to websocket tests +# Adding debug logs for name in [ "websockets.legacy.server", "gql.transport.aiohttp", + "gql.transport.appsync", "gql.transport.phoenix_channel_websockets", "gql.transport.requests", "gql.transport.websockets", @@ -492,3 +489,11 @@ async def run_sync_test_inner(event_loop, server, test_function): await server.close() return run_sync_test_inner + + +pytest_plugins = [ + "tests.fixtures.aws.fake_credentials", + "tests.fixtures.aws.fake_request", + "tests.fixtures.aws.fake_session", + "tests.fixtures.aws.fake_signer", +] diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/fixtures/aws/__init__.py b/tests/fixtures/aws/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/fixtures/aws/fake_credentials.py b/tests/fixtures/aws/fake_credentials.py new file mode 100644 index 00000000..d8eac834 --- /dev/null +++ b/tests/fixtures/aws/fake_credentials.py @@ -0,0 +1,28 @@ +import pytest + + +class FakeCredentials(object): + def __init__( + self, access_key=None, secret_key=None, method=None, token=None, region=None + ): + self.region = region if region else "us-east-1a" + self.access_key = access_key if access_key else "fake-access-key" + self.secret_key = secret_key if secret_key else "fake-secret-key" + self.method = method if method else "shared-credentials-file" + self.token = token if token else "fake-token" + + +@pytest.fixture +def fake_credentials_factory(): + def _fake_credentials_factory( + access_key=None, secret_key=None, method=None, token=None, region=None + ): + return FakeCredentials( + access_key=access_key, + secret_key=secret_key, + method=method, + token=token, + region=region, + ) + + yield _fake_credentials_factory diff --git a/tests/fixtures/aws/fake_request.py b/tests/fixtures/aws/fake_request.py new file mode 100644 index 00000000..615bc095 --- /dev/null +++ 
b/tests/fixtures/aws/fake_request.py @@ -0,0 +1,22 @@ +import pytest + + +class FakeRequest(object): + headers = None + + def __init__(self, request_props=None): + if not isinstance(request_props, dict): + return + self.method = request_props.get("method") + self.url = request_props.get("url") + self.headers = request_props.get("headers") + self.context = request_props.get("context") + self.body = request_props.get("body") + + +@pytest.fixture +def fake_request_factory(): + def _fake_request_factory(request_props=None): + return FakeRequest(request_props=request_props) + + yield _fake_request_factory diff --git a/tests/fixtures/aws/fake_session.py b/tests/fixtures/aws/fake_session.py new file mode 100644 index 00000000..78e1511a --- /dev/null +++ b/tests/fixtures/aws/fake_session.py @@ -0,0 +1,24 @@ +import pytest + + +class FakeSession(object): + def __init__(self, credentials, region_name): + self._credentials = credentials + self._region_name = region_name + + def get_default_client_config(self): + return + + def get_credentials(self): + return self._credentials + + def _resolve_region_name(self, region_name, client_config): + return region_name if region_name else self._region_name + + +@pytest.fixture +def fake_session_factory(fake_credentials_factory): + def _fake_session_factory(credentials=fake_credentials_factory()): + return FakeSession(credentials=credentials, region_name="fake-region") + + yield _fake_session_factory diff --git a/tests/fixtures/aws/fake_signer.py b/tests/fixtures/aws/fake_signer.py new file mode 100644 index 00000000..ff096745 --- /dev/null +++ b/tests/fixtures/aws/fake_signer.py @@ -0,0 +1,27 @@ +import pytest + + +@pytest.fixture +def fake_signer_factory(fake_request_factory): + def _fake_signer_factory(request=None): + if not request: + request = fake_request_factory() + return FakeSigner(request=request) + + yield _fake_signer_factory + + +class FakeSigner(object): + def __init__(self, request=None) -> None: + self.request = request + + def add_auth(self, request) -> None: + """ + A fake for getting a request object that + :return: + """ + request.headers = {"FakeAuthorization": "a", "FakeTime": "today"} + + def get_headers(self): + self.add_auth(self.request) + return self.request.headers diff --git a/tests/test_appsync_auth.py b/tests/test_appsync_auth.py new file mode 100644 index 00000000..546e0e6f --- /dev/null +++ b/tests/test_appsync_auth.py @@ -0,0 +1,189 @@ +import pytest + +mock_transport_host = "appsyncapp.awsgateway.com.example.org" +mock_transport_url = f"https://{mock_transport_host}/graphql" + + +@pytest.mark.botocore +def test_appsync_init_with_minimal_args(fake_session_factory): + from gql.transport.appsync_auth import AppSyncIAMAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + + sample_transport = AppSyncWebsocketsTransport( + url=mock_transport_url, session=fake_session_factory() + ) + assert isinstance(sample_transport.auth, AppSyncIAMAuthentication) + assert sample_transport.connect_timeout == 10 + assert sample_transport.close_timeout == 10 + assert sample_transport.ack_timeout == 10 + assert sample_transport.ssl is False + assert sample_transport.connect_args == {} + + +@pytest.mark.botocore +def test_appsync_init_with_no_credentials(caplog, fake_session_factory): + import botocore.exceptions + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + + with pytest.raises(botocore.exceptions.NoCredentialsError): + sample_transport = AppSyncWebsocketsTransport( + 
url=mock_transport_url, session=fake_session_factory(credentials=None), + ) + assert sample_transport.auth is None + + expected_error = "Credentials not found" + + print(f"Captured log: {caplog.text}") + + assert expected_error in caplog.text + + +@pytest.mark.websockets +def test_appsync_init_with_jwt_auth(): + from gql.transport.appsync_auth import AppSyncJWTAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + + auth = AppSyncJWTAuthentication(host=mock_transport_host, jwt="some-jwt") + sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, auth=auth) + assert sample_transport.auth is auth + + assert auth.get_headers() == { + "host": mock_transport_host, + "Authorization": "some-jwt", + } + + +@pytest.mark.websockets +def test_appsync_init_with_apikey_auth(): + from gql.transport.appsync_auth import AppSyncApiKeyAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + + auth = AppSyncApiKeyAuthentication(host=mock_transport_host, api_key="some-api-key") + sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, auth=auth) + assert sample_transport.auth is auth + + assert auth.get_headers() == { + "host": mock_transport_host, + "x-api-key": "some-api-key", + } + + +@pytest.mark.botocore +def test_appsync_init_with_iam_auth_without_creds(fake_session_factory): + import botocore.exceptions + from gql.transport.appsync_auth import AppSyncIAMAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + + auth = AppSyncIAMAuthentication( + host=mock_transport_host, session=fake_session_factory(credentials=None), + ) + with pytest.raises(botocore.exceptions.NoCredentialsError): + AppSyncWebsocketsTransport(url=mock_transport_url, auth=auth) + + +@pytest.mark.botocore +def test_appsync_init_with_iam_auth_with_creds(fake_credentials_factory): + from gql.transport.appsync_auth import AppSyncIAMAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + + auth = AppSyncIAMAuthentication( + host=mock_transport_host, + credentials=fake_credentials_factory(), + region_name="us-east-1", + ) + sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, auth=auth) + assert sample_transport.auth is auth + + +@pytest.mark.botocore +def test_appsync_init_with_iam_auth_and_no_region( + caplog, fake_credentials_factory, fake_session_factory +): + """ + + WARNING: this test will fail if: + - you have a default region set in ~/.aws/config + - you have the AWS_DEFAULT_REGION environment variable set + + """ + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + from botocore.exceptions import NoRegionError + import logging + + caplog.set_level(logging.WARNING) + + with pytest.raises(NoRegionError): + session = fake_session_factory(credentials=fake_credentials_factory()) + session._region_name = None + session._credentials.region = None + transport = AppSyncWebsocketsTransport(url=mock_transport_url, session=session) + + # prints the region name in case the test fails + print(f"Region found: {transport.auth._region_name}") + + print(f"Captured: {caplog.text}") + + expected_error = ( + "Region name not found. " + "It was not possible to detect your region either from the host " + "or from your default AWS configuration." 
+ ) + + assert expected_error in caplog.text + + +@pytest.mark.botocore +def test_munge_url(fake_signer_factory, fake_request_factory): + from gql.transport.appsync_auth import AppSyncIAMAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + + test_url = "https://appsync-api.aws.example.org/some-other-params" + + auth = AppSyncIAMAuthentication( + host=test_url, + signer=fake_signer_factory(), + request_creator=fake_request_factory, + ) + sample_transport = AppSyncWebsocketsTransport(url=test_url, auth=auth) + + header_string = ( + "eyJGYWtlQXV0aG9yaXphdGlvbiI6ImEiLCJGYWtlVGltZSI6InRvZGF5" + "IiwiaG9zdCI6Imh0dHBzOi8vYXBwc3luYy1hcGkuYXdzLmV4YW1wbGUu" + "b3JnL3NvbWUtb3RoZXItcGFyYW1zIn0=" + ) + expected_url = ( + "wss://appsync-realtime-api.aws.example.org/" + f"some-other-params?header={header_string}&payload=e30=" + ) + assert sample_transport.url == expected_url + + +@pytest.mark.botocore +def test_munge_url_format( + fake_signer_factory, + fake_request_factory, + fake_credentials_factory, + fake_session_factory, +): + from gql.transport.appsync_auth import AppSyncIAMAuthentication + + test_url = "https://appsync-api.aws.example.org/some-other-params" + + auth = AppSyncIAMAuthentication( + host=test_url, + signer=fake_signer_factory(), + session=fake_session_factory(), + request_creator=fake_request_factory, + credentials=fake_credentials_factory(), + ) + + header_string = ( + "eyJGYWtlQXV0aG9yaXphdGlvbiI6ImEiLCJGYWtlVGltZSI6InRvZGF5" + "IiwiaG9zdCI6Imh0dHBzOi8vYXBwc3luYy1hcGkuYXdzLmV4YW1wbGUu" + "b3JnL3NvbWUtb3RoZXItcGFyYW1zIn0=" + ) + expected_url = ( + "wss://appsync-realtime-api.aws.example.org/" + f"some-other-params?header={header_string}&payload=e30=" + ) + assert auth.get_auth_url(test_url) == expected_url diff --git a/tests/test_appsync_http.py b/tests/test_appsync_http.py new file mode 100644 index 00000000..1f787a68 --- /dev/null +++ b/tests/test_appsync_http.py @@ -0,0 +1,78 @@ +import json + +import pytest + +from gql import Client, gql + + +@pytest.mark.asyncio +@pytest.mark.aiohttp +@pytest.mark.botocore +async def test_appsync_iam_mutation( + event_loop, aiohttp_server, fake_credentials_factory +): + from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport + from gql.transport.appsync_auth import AppSyncIAMAuthentication + from urllib.parse import urlparse + + async def handler(request): + data = { + "createMessage": { + "id": "4b436192-aab2-460c-8bdf-4f2605eb63da", + "message": "Hello world!", + "createdAt": "2021-12-06T14:49:55.087Z", + } + } + payload = { + "data": data, + "extensions": {"received_headers": dict(request.headers)}, + } + + return web.Response( + text=json.dumps(payload, separators=(",", ":")), + content_type="application/json", + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + # Extract host from url + host = str(urlparse(url).netloc) + + auth = AppSyncIAMAuthentication( + host=host, credentials=fake_credentials_factory(), region_name="us-east-1", + ) + + sample_transport = AIOHTTPTransport(url=url, auth=auth) + + async with Client(transport=sample_transport) as session: + + query = gql( + """ +mutation createMessage($message: String!) 
{ + createMessage(input: {message: $message}) { + id + message + createdAt + } +}""" + ) + + # Execute query asynchronously + execution_result = await session.execute(query, get_execution_result=True) + + result = execution_result.data + message = result["createMessage"]["message"] + + assert message == "Hello world!" + + sent_headers = execution_result.extensions["received_headers"] + + assert sent_headers["X-Amz-Security-Token"] == "fake-token" + assert sent_headers["Authorization"].startswith( + "AWS4-HMAC-SHA256 Credential=fake-access-key/" + ) diff --git a/tests/test_appsync_websockets.py b/tests/test_appsync_websockets.py new file mode 100644 index 00000000..f510d4a7 --- /dev/null +++ b/tests/test_appsync_websockets.py @@ -0,0 +1,702 @@ +import asyncio +import json +from base64 import b64decode +from typing import List +from urllib import parse + +import pytest + +from gql import Client, gql + +from .conftest import MS, WebSocketServerHelper + +# Marking all tests in this file with the websockets marker +pytestmark = pytest.mark.websockets + +SEND_MESSAGE_DELAY = 20 * MS +NB_MESSAGES = 10 + +DUMMY_API_KEY = "da2-thisisadummyapikey01234567" +DUMMY_ACCESS_KEY_ID = "DUMMYACCESSKEYID0123" +DUMMY_ACCESS_KEY_ID_NOT_ALLOWED = "DUMMYACCESSKEYID!ALL" +DUMMY_ACCESS_KEY_IDS = [DUMMY_ACCESS_KEY_ID, DUMMY_ACCESS_KEY_ID_NOT_ALLOWED] +DUMMY_SECRET_ACCESS_KEY = "ThisIsADummySecret0123401234012340123401" +DUMMY_SECRET_SESSION_TOKEN = ( + "FwoREDACTEDzEREDACTED+YREDACTEDJLREDACTEDz2REDACTEDH5RE" + "DACTEDbVREDACTEDqwREDACTEDHJREDACTEDxFREDACTEDtMREDACTED5kREDACTEDSwREDACTED0BRED" + "ACTEDuDREDACTEDm4REDACTEDSBREDACTEDaoREDACTEDP2REDACTEDCBREDACTED0wREDACTEDmdREDA" + "CTEDyhREDACTEDSKREDACTEDYbREDACTEDfeREDACTED3UREDACTEDaKREDACTEDi1REDACTEDGEREDAC" + "TED4VREDACTEDjmREDACTEDYcREDACTEDkQREDACTEDyI=" +) +REGION_NAME = "eu-west-3" + +# List which can used to store received messages by the server +logged_messages: List[str] = [] + + +def realtime_appsync_server_factory( + keepalive=False, not_json_answer=False, error_without_id=False +): + def verify_headers(headers, in_query=False): + """Returns an error or None if all is ok""" + + if "x-api-key" in headers: + print("API KEY Authentication detected!") + + if headers["x-api-key"] == DUMMY_API_KEY: + return None + + elif "Authorization" in headers: + if "X-Amz-Security-Token" in headers: + with_token = True + print("IAM Authentication with token detected!") + else: + with_token = False + print("IAM Authentication with token detected!") + print("IAM Authentication without token detected!") + + assert headers["accept"] == "application/json, text/javascript" + assert headers["content-encoding"] == "amz-1.0" + assert headers["content-type"] == "application/json; charset=UTF-8" + assert "X-Amz-Date" in headers + + authorization_fields = headers["Authorization"].split(" ") + + assert authorization_fields[0] == "AWS4-HMAC-SHA256" + + credential_field = authorization_fields[1][:-1].split("=") + assert credential_field[0] == "Credential" + credential_content = credential_field[1].split("/") + assert credential_content[0] in DUMMY_ACCESS_KEY_IDS + + if in_query: + if credential_content[0] == DUMMY_ACCESS_KEY_ID_NOT_ALLOWED: + return { + "errorType": "UnauthorizedException", + "message": "Permission denied", + } + + # assert credential_content[1]== date + # assert credential_content[2]== region + assert credential_content[3] == "appsync" + assert credential_content[4] == "aws4_request" + + signed_headers_field = authorization_fields[2][:-1].split("=") + + assert 
signed_headers_field[0] == "SignedHeaders" + signed_headers = signed_headers_field[1].split(";") + + assert "accept" in signed_headers + assert "content-encoding" in signed_headers + assert "content-type" in signed_headers + assert "host" in signed_headers + assert "x-amz-date" in signed_headers + + if with_token: + assert "x-amz-security-token" in signed_headers + + signature_field = authorization_fields[3].split("=") + + assert signature_field[0] == "Signature" + + return None + + return { + "errorType": "com.amazonaws.deepdish.graphql.auth#UnauthorizedException", + "message": "You are not authorized to make this call.", + "errorCode": 400, + } + + async def realtime_appsync_server_template(ws, path): + import websockets + + logged_messages.clear() + + try: + if not_json_answer: + await ws.send("Something not json") + return + + if error_without_id: + await ws.send( + json.dumps( + { + "type": "error", + "payload": { + "errors": [ + { + "errorType": "Error without id", + "message": ( + "Sometimes AppSync will send you " + "an error without an id" + ), + } + ] + }, + }, + separators=(",", ":"), + ) + ) + return + + print(f"path = {path}") + + path_base, parameters_str = path.split("?") + + assert path_base == "/graphql" + + parameters = parse.parse_qs(parameters_str) + + header_param = parameters["header"][0] + payload_param = parameters["payload"][0] + + assert payload_param == "e30=" + + headers = json.loads(b64decode(header_param).decode()) + + print("\nHeaders received in URL:") + for key, value in headers.items(): + print(f" {key}: {value}") + print("\n") + + error = verify_headers(headers) + + if error is not None: + await ws.send( + json.dumps( + {"payload": {"errors": [error]}, "type": "connection_error"}, + separators=(",", ":"), + ) + ) + return + + await WebSocketServerHelper.send_connection_ack( + ws, payload='{"connectionTimeoutMs":300000}' + ) + + result = await ws.recv() + logged_messages.append(result) + + json_result = json.loads(result) + + query_id = json_result["id"] + assert json_result["type"] == "start" + + payload = json_result["payload"] + + # With appsync, the data field is serialized to string + data_str = payload["data"] + extensions = payload["extensions"] + + data = json.loads(data_str) + + query = data["query"] + variables = data.get("variables", None) + operation_name = data.get("operationName", None) + print(f"Received query: {query}") + print(f"Received variables: {variables}") + print(f"Received operation_name: {operation_name}") + + authorization = extensions["authorization"] + print("\nHeaders received in the extensions of the query:") + for key, value in authorization.items(): + print(f" {key}: {value}") + print("\n") + + error = verify_headers(headers, in_query=True) + + if error is not None: + await ws.send( + json.dumps( + { + "id": str(query_id), + "type": "error", + "payload": {"errors": [error]}, + }, + separators=(",", ":"), + ) + ) + return + + await ws.send( + json.dumps( + {"id": str(query_id), "type": "start_ack"}, separators=(",", ":") + ) + ) + + async def send_message_coro(): + print(" Server: send message task started") + try: + for number in range(NB_MESSAGES): + payload = { + "data": { + "onCreateMessage": {"message": f"Hello world {number}!"} + } + } + + if operation_name or variables: + + payload["extensions"] = {} + + if operation_name: + payload["extensions"]["operation_name"] = operation_name + if variables: + payload["extensions"]["variables"] = variables + + await ws.send( + json.dumps( + { + "id": str(query_id), + "type": 
"data", + "payload": payload, + }, + separators=(",", ":"), + ) + ) + await asyncio.sleep(SEND_MESSAGE_DELAY) + finally: + print(" Server: send message task ended") + + print(" Server: starting send message task") + send_message_task = asyncio.ensure_future(send_message_coro()) + + async def keepalive_coro(): + while True: + await asyncio.sleep(5 * MS) + try: + await WebSocketServerHelper.send_keepalive(ws) + except websockets.exceptions.ConnectionClosed: + break + + if keepalive: + print(" Server: starting keepalive task") + keepalive_task = asyncio.ensure_future(keepalive_coro()) + + async def receiving_coro(): + print(" Server: receiving task started") + try: + nonlocal send_message_task + while True: + + try: + result = await ws.recv() + logged_messages.append(result) + except websockets.exceptions.ConnectionClosed: + break + + finally: + print(" Server: receiving task ended") + if keepalive: + keepalive_task.cancel() + + print(" Server: starting receiving task") + receiving_task = asyncio.ensure_future(receiving_coro()) + + try: + print( + " Server: waiting for sending message task to complete" + ) + await send_message_task + except asyncio.CancelledError: + print(" Server: Now sending message task is cancelled") + + print(" Server: sending complete message") + await WebSocketServerHelper.send_complete(ws, query_id) + + if keepalive: + print(" Server: cancelling keepalive task") + keepalive_task.cancel() + try: + await keepalive_task + except asyncio.CancelledError: + print(" Server: Now keepalive task is cancelled") + + print(" Server: waiting for client to close the connection") + try: + await asyncio.wait_for(receiving_task, 1000 * MS) + except asyncio.TimeoutError: + pass + + print(" Server: cancelling receiving task") + receiving_task.cancel() + + try: + await receiving_task + except asyncio.CancelledError: + print(" Server: Now receiving task is cancelled") + + except websockets.exceptions.ConnectionClosedOK: + pass + except AssertionError as e: + print(f"\n Server: Assertion failed: {e!s}\n") + except Exception as e: + print(f"\n Server: Exception received: {e!s}\n") + finally: + print(" Server: waiting for websocket connection to close") + await ws.wait_closed() + print(" Server: connection closed") + + return realtime_appsync_server_template + + +async def realtime_appsync_server(ws, path): + + server = realtime_appsync_server_factory() + await server(ws, path) + + +async def realtime_appsync_server_keepalive(ws, path): + + server = realtime_appsync_server_factory(keepalive=True) + await server(ws, path) + + +async def realtime_appsync_server_not_json_answer(ws, path): + + server = realtime_appsync_server_factory(not_json_answer=True) + await server(ws, path) + + +async def realtime_appsync_server_error_without_id(ws, path): + + server = realtime_appsync_server_factory(error_without_id=True) + await server(ws, path) + + +on_create_message_subscription_str = """ +subscription onCreateMessage { + onCreateMessage { + message + } +} +""" + + +async def default_transport_test(transport): + client = Client(transport=transport) + + expected_messages = [f"Hello world {number}!" 
for number in range(NB_MESSAGES)] + received_messages = [] + + async with client as session: + subscription = gql(on_create_message_subscription_str) + + async for result in session.subscribe(subscription): + + message = result["onCreateMessage"]["message"] + print(f"Message received: '{message}'") + + received_messages.append(message) + + assert expected_messages == received_messages + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [realtime_appsync_server_keepalive], indirect=True) +async def test_appsync_subscription_api_key(event_loop, server): + + from gql.transport.appsync_auth import AppSyncApiKeyAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + + auth = AppSyncApiKeyAuthentication(host=server.hostname, api_key=DUMMY_API_KEY) + + transport = AppSyncWebsocketsTransport( + url=url, auth=auth, keep_alive_timeout=(5 * SEND_MESSAGE_DELAY) + ) + + await default_transport_test(transport) + + +@pytest.mark.asyncio +@pytest.mark.botocore +@pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) +async def test_appsync_subscription_iam_with_token(event_loop, server): + + from gql.transport.appsync_auth import AppSyncIAMAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + from botocore.credentials import Credentials + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + + dummy_credentials = Credentials( + access_key=DUMMY_ACCESS_KEY_ID, + secret_key=DUMMY_SECRET_ACCESS_KEY, + token=DUMMY_SECRET_SESSION_TOKEN, + ) + + auth = AppSyncIAMAuthentication( + host=server.hostname, credentials=dummy_credentials, region_name=REGION_NAME + ) + + transport = AppSyncWebsocketsTransport(url=url, auth=auth) + + await default_transport_test(transport) + + +@pytest.mark.asyncio +@pytest.mark.botocore +@pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) +async def test_appsync_subscription_iam_without_token(event_loop, server): + + from gql.transport.appsync_auth import AppSyncIAMAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + from botocore.credentials import Credentials + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + + dummy_credentials = Credentials( + access_key=DUMMY_ACCESS_KEY_ID, secret_key=DUMMY_SECRET_ACCESS_KEY, + ) + + auth = AppSyncIAMAuthentication( + host=server.hostname, credentials=dummy_credentials, region_name=REGION_NAME + ) + + transport = AppSyncWebsocketsTransport(url=url, auth=auth) + + await default_transport_test(transport) + + +@pytest.mark.asyncio +@pytest.mark.botocore +@pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) +async def test_appsync_execute_method_not_allowed(event_loop, server): + + from gql.transport.appsync_auth import AppSyncIAMAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + from botocore.credentials import Credentials + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + + dummy_credentials = Credentials( + access_key=DUMMY_ACCESS_KEY_ID, secret_key=DUMMY_SECRET_ACCESS_KEY, + ) + + auth = AppSyncIAMAuthentication( + host=server.hostname, credentials=dummy_credentials, region_name=REGION_NAME + ) + + transport = AppSyncWebsocketsTransport(url=url, auth=auth) + + client = Client(transport=transport) + + async with client as session: + query = gql( + """ +mutation 
createMessage($message: String!) { + createMessage(input: {message: $message}) { + id + message + createdAt + } +}""" + ) + + variable_values = {"message": "Hello world!"} + + with pytest.raises(AssertionError) as exc_info: + await session.execute(query, variable_values=variable_values) + + assert ( + "execute method is not allowed for AppSyncWebsocketsTransport " + "because only subscriptions are allowed on the realtime endpoint." + ) in str(exc_info) + + +@pytest.mark.asyncio +@pytest.mark.botocore +async def test_appsync_fetch_schema_from_transport_not_allowed(event_loop): + + from gql.transport.appsync_auth import AppSyncIAMAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + from botocore.credentials import Credentials + + dummy_credentials = Credentials( + access_key=DUMMY_ACCESS_KEY_ID, secret_key=DUMMY_SECRET_ACCESS_KEY, + ) + + auth = AppSyncIAMAuthentication( + host="something", credentials=dummy_credentials, region_name=REGION_NAME + ) + + transport = AppSyncWebsocketsTransport(url="https://something", auth=auth) + + with pytest.raises(AssertionError) as exc_info: + Client(transport=transport, fetch_schema_from_transport=True) + + assert ( + "fetch_schema_from_transport=True is not allowed for AppSyncWebsocketsTransport" + " because only subscriptions are allowed on the realtime endpoint." + ) in str(exc_info) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) +async def test_appsync_subscription_api_key_unauthorized(event_loop, server): + + from gql.transport.appsync_auth import AppSyncApiKeyAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + from gql.transport.exceptions import TransportServerError + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + + auth = AppSyncApiKeyAuthentication(host=server.hostname, api_key="invalid") + + transport = AppSyncWebsocketsTransport(url=url, auth=auth) + + client = Client(transport=transport) + + with pytest.raises(TransportServerError) as exc_info: + async with client as _: + pass + + assert "You are not authorized to make this call." 
in str(exc_info) + + +@pytest.mark.asyncio +@pytest.mark.botocore +@pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) +async def test_appsync_subscription_iam_not_allowed(event_loop, server): + + from gql.transport.appsync_auth import AppSyncIAMAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + from gql.transport.exceptions import TransportQueryError + from botocore.credentials import Credentials + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + + dummy_credentials = Credentials( + access_key=DUMMY_ACCESS_KEY_ID_NOT_ALLOWED, + secret_key=DUMMY_SECRET_ACCESS_KEY, + token=DUMMY_SECRET_SESSION_TOKEN, + ) + + auth = AppSyncIAMAuthentication( + host=server.hostname, credentials=dummy_credentials, region_name=REGION_NAME + ) + + transport = AppSyncWebsocketsTransport(url=url, auth=auth) + + client = Client(transport=transport) + + async with client as session: + subscription = gql(on_create_message_subscription_str) + + with pytest.raises(TransportQueryError) as exc_info: + + async for result in session.subscribe(subscription): + pass + + assert "Permission denied" in str(exc_info) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "server", [realtime_appsync_server_not_json_answer], indirect=True +) +async def test_appsync_subscription_server_sending_a_not_json_answer( + event_loop, server +): + + from gql.transport.appsync_auth import AppSyncApiKeyAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + from gql.transport.exceptions import TransportProtocolError + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + + auth = AppSyncApiKeyAuthentication(host=server.hostname, api_key=DUMMY_API_KEY) + + transport = AppSyncWebsocketsTransport(url=url, auth=auth) + + client = Client(transport=transport) + + with pytest.raises(TransportProtocolError) as exc_info: + async with client as _: + pass + + assert "Server did not return a GraphQL result: Something not json" in str(exc_info) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "server", [realtime_appsync_server_error_without_id], indirect=True +) +async def test_appsync_subscription_server_sending_an_error_without_an_id( + event_loop, server +): + + from gql.transport.appsync_auth import AppSyncApiKeyAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + from gql.transport.exceptions import TransportServerError + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + + auth = AppSyncApiKeyAuthentication(host=server.hostname, api_key=DUMMY_API_KEY) + + transport = AppSyncWebsocketsTransport(url=url, auth=auth) + + client = Client(transport=transport) + + with pytest.raises(TransportServerError) as exc_info: + async with client as _: + pass + + assert "Sometimes AppSync will send you an error without an id" in str(exc_info) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [realtime_appsync_server_keepalive], indirect=True) +async def test_appsync_subscription_variable_values_and_operation_name( + event_loop, server +): + + from gql.transport.appsync_auth import AppSyncApiKeyAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + + auth = AppSyncApiKeyAuthentication(host=server.hostname, api_key=DUMMY_API_KEY) + + transport = AppSyncWebsocketsTransport( + url=url, auth=auth, keep_alive_timeout=(5 * 
SEND_MESSAGE_DELAY) + ) + + client = Client(transport=transport) + + expected_messages = [f"Hello world {number}!" for number in range(NB_MESSAGES)] + received_messages = [] + + async with client as session: + subscription = gql(on_create_message_subscription_str) + + async for execution_result in session.subscribe( + subscription, + operation_name="onCreateMessage", + variable_values={"key1": "val1"}, + get_execution_result=True, + ): + + result = execution_result.data + message = result["onCreateMessage"]["message"] + print(f"Message received: '{message}'") + + received_messages.append(message) + + print(f"extensions received: {execution_result.extensions}") + + assert execution_result.extensions["operation_name"] == "onCreateMessage" + variables = execution_result.extensions["variables"] + assert variables["key1"] == "val1" + + assert expected_messages == received_messages diff --git a/tox.ini b/tox.ini index 2699744c..e75b8fac 100644 --- a/tox.ini +++ b/tox.ini @@ -3,9 +3,6 @@ envlist = black,flake8,import-order,mypy,manifest, py{36,37,38,39,310,py3} -[pytest] -markers = asyncio - [gh-actions] python = 3.6: py36