From cc8ef9dd14e546502ea516a9a3d4bd6f231d29b3 Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Wed, 29 Jun 2022 09:20:14 +0200 Subject: [PATCH 1/5] Add support for new temporal packstream encoding (UTC) --- CHANGELOG.md | 13 + bin/dist-functions | 20 +- bin/make-unasync | 2 +- docs/source/conf.py | 2 +- neo4j/__init__.py | 13 +- neo4j/_async/driver.py | 10 +- neo4j/_async/io/_bolt.py | 179 +++++++--- neo4j/_async/io/_bolt3.py | 83 +++-- neo4j/_async/io/_bolt4.py | 138 ++++++-- neo4j/_async/io/_bolt5.py | 94 +++-- neo4j/_async/io/_common.py | 170 +++++---- neo4j/_async/io/_pool.py | 2 +- neo4j/_async/work/result.py | 22 +- neo4j/_async/work/session.py | 13 +- neo4j/_async/work/transaction.py | 4 +- neo4j/_async/work/workspace.py | 8 +- neo4j/_async_compat/network/_bolt_socket.py | 2 +- neo4j/_async_compat/util.py | 2 +- .../common/data => neo4j/_codec}/__init__.py | 0 neo4j/_codec/hydration/__init__.py | 25 ++ neo4j/_codec/hydration/_common.py | 50 +++ neo4j/_codec/hydration/_interface/__init__.py | 29 ++ neo4j/_codec/hydration/v1/__init__.py | 24 ++ .../_codec/hydration/v1/hydration_handler.py | 201 +++++++++++ neo4j/_codec/hydration/v1/spatial.py | 63 ++++ .../hydration/v1/temporal.py} | 4 +- neo4j/_codec/hydration/v2/__init__.py | 23 ++ .../_codec/hydration/v2/hydration_handler.py | 57 +++ neo4j/_codec/hydration/v2/temporal.py | 92 +++++ neo4j/_codec/packstream/__init__.py | 24 ++ neo4j/_codec/packstream/_common.py | 44 +++ .../packstream/v1/__init__.py} | 137 ++++---- neo4j/{data.py => _data.py} | 160 --------- neo4j/{meta.py => _meta.py} | 8 +- neo4j/_spatial/__init__.py | 106 ++++++ neo4j/_sync/driver.py | 10 +- neo4j/_sync/io/_bolt.py | 177 +++++++--- neo4j/_sync/io/_bolt3.py | 83 +++-- neo4j/_sync/io/_bolt4.py | 138 ++++++-- neo4j/_sync/io/_bolt5.py | 94 +++-- neo4j/_sync/io/_common.py | 168 +++++---- neo4j/_sync/io/_pool.py | 2 +- neo4j/_sync/work/result.py | 22 +- neo4j/_sync/work/session.py | 13 +- neo4j/_sync/work/transaction.py | 4 +- 
neo4j/_sync/work/workspace.py | 8 +- neo4j/api.py | 2 +- neo4j/conf.py | 8 +- neo4j/graph/__init__.py | 84 ----- neo4j/spatial/__init__.py | 156 +++------ neo4j/time/__init__.py | 23 +- neo4j/time/{arithmetic.py => _arithmetic.py} | 9 + ...entations.py => _clock_implementations.py} | 7 + .../time/{metaclasses.py => _metaclasses.py} | 7 + setup.cfg | 8 +- setup.py | 2 +- testkitbackend/_async/backend.py | 4 +- testkitbackend/_async/requests.py | 2 +- testkitbackend/_sync/backend.py | 4 +- testkitbackend/_sync/requests.py | 2 +- testkitbackend/test_config.json | 1 + tests/conftest.py | 2 +- tests/env.py | 2 +- tests/unit/async_/io/conftest.py | 82 ++--- tests/unit/async_/io/test__common.py | 40 ++- tests/unit/async_/io/test_class_bolt3.py | 10 +- tests/unit/async_/io/test_class_bolt4x0.py | 22 +- tests/unit/async_/io/test_class_bolt4x1.py | 30 +- tests/unit/async_/io/test_class_bolt4x2.py | 30 +- tests/unit/async_/io/test_class_bolt4x3.py | 30 +- tests/unit/async_/io/test_class_bolt4x4.py | 30 +- tests/unit/async_/work/test_result.py | 160 ++++----- tests/unit/async_/work/test_session.py | 81 +++-- tests/unit/async_/work/test_transaction.py | 17 - tests/unit/common/codec/__init__.py | 16 + tests/unit/common/codec/hydration/__init__.py | 16 + tests/unit/common/codec/hydration/_common.py | 71 ++++ .../common/codec/hydration/v1/__init__.py | 16 + tests/unit/common/codec/hydration/v1/_base.py | 29 ++ .../hydration/v1/test_graph_hydration.py | 67 ++++ .../hydration/v1/test_hydration_handler.py | 78 +++++ .../hydration/v1/test_spacial_dehydration.py | 73 ++++ .../hydration/v1/test_spacial_hydration.py | 77 +++++ .../hydration/v1/test_time_dehydration.py | 193 +++++++++++ .../codec/hydration/v1/test_time_hydration.py | 167 +++++++++ .../common/codec/hydration/v2/__init__.py | 16 + .../hydration/v2/test_graph_hydration.py | 67 ++++ .../hydration/v2/test_hydration_handler.py | 31 ++ .../hydration/v2/test_spacial_dehydration.py | 31 ++ 
.../hydration/v2/test_spacial_hydration.py | 31 ++ .../hydration/v2/test_time_dehydration.py | 74 ++++ .../codec/hydration/v2/test_time_hydration.py | 74 ++++ .../unit/common/codec/packstream/__init__.py | 16 + .../common/codec/packstream/v1/__init__.py | 16 + .../codec/packstream/v1/test_packstream.py | 326 ++++++++++++++++++ tests/unit/common/data/test_packing.py | 284 --------------- .../common/spatial/test_cartesian_point.py | 36 -- tests/unit/common/spatial/test_point.py | 20 -- tests/unit/common/spatial/test_wgs84_point.py | 34 -- tests/unit/common/test_addressing.py | 2 +- tests/unit/common/test_api.py | 122 ------- tests/unit/common/test_data.py | 147 -------- tests/unit/common/test_import_neo4j.py | 8 +- tests/unit/common/test_record.py | 17 +- tests/unit/common/test_types.py | 69 ++-- tests/unit/common/time/__init__.py | 2 +- tests/unit/common/time/test_clock.py | 2 +- tests/unit/common/time/test_datetime.py | 8 +- tests/unit/common/time/test_dehydration.py | 135 -------- tests/unit/common/time/test_duration.py | 1 - tests/unit/common/time/test_hydration.py | 114 ------ tests/unit/common/time/test_time.py | 8 +- tests/unit/mixed/io/test_direct.py | 2 +- tests/unit/sync/io/conftest.py | 82 ++--- tests/unit/sync/io/test__common.py | 38 +- tests/unit/sync/io/test_class_bolt3.py | 10 +- tests/unit/sync/io/test_class_bolt4x0.py | 22 +- tests/unit/sync/io/test_class_bolt4x1.py | 30 +- tests/unit/sync/io/test_class_bolt4x2.py | 30 +- tests/unit/sync/io/test_class_bolt4x3.py | 30 +- tests/unit/sync/io/test_class_bolt4x4.py | 30 +- tests/unit/sync/work/test_result.py | 160 ++++----- tests/unit/sync/work/test_session.py | 81 +++-- tests/unit/sync/work/test_transaction.py | 17 - 124 files changed, 3982 insertions(+), 2442 deletions(-) rename {tests/unit/common/data => neo4j/_codec}/__init__.py (100%) create mode 100644 neo4j/_codec/hydration/__init__.py create mode 100644 neo4j/_codec/hydration/_common.py create mode 100644 
neo4j/_codec/hydration/_interface/__init__.py create mode 100644 neo4j/_codec/hydration/v1/__init__.py create mode 100644 neo4j/_codec/hydration/v1/hydration_handler.py create mode 100644 neo4j/_codec/hydration/v1/spatial.py rename neo4j/{time/hydration.py => _codec/hydration/v1/temporal.py} (98%) create mode 100644 neo4j/_codec/hydration/v2/__init__.py create mode 100644 neo4j/_codec/hydration/v2/hydration_handler.py create mode 100644 neo4j/_codec/hydration/v2/temporal.py create mode 100644 neo4j/_codec/packstream/__init__.py create mode 100644 neo4j/_codec/packstream/_common.py rename neo4j/{packstream.py => _codec/packstream/v1/__init__.py} (78%) rename neo4j/{data.py => _data.py} (66%) rename neo4j/{meta.py => _meta.py} (94%) create mode 100644 neo4j/_spatial/__init__.py rename neo4j/time/{arithmetic.py => _arithmetic.py} (95%) rename neo4j/time/{clock_implementations.py => _clock_implementations.py} (97%) rename neo4j/time/{metaclasses.py => _metaclasses.py} (96%) create mode 100644 tests/unit/common/codec/__init__.py create mode 100644 tests/unit/common/codec/hydration/__init__.py create mode 100644 tests/unit/common/codec/hydration/_common.py create mode 100644 tests/unit/common/codec/hydration/v1/__init__.py create mode 100644 tests/unit/common/codec/hydration/v1/_base.py create mode 100644 tests/unit/common/codec/hydration/v1/test_graph_hydration.py create mode 100644 tests/unit/common/codec/hydration/v1/test_hydration_handler.py create mode 100644 tests/unit/common/codec/hydration/v1/test_spacial_dehydration.py create mode 100644 tests/unit/common/codec/hydration/v1/test_spacial_hydration.py create mode 100644 tests/unit/common/codec/hydration/v1/test_time_dehydration.py create mode 100644 tests/unit/common/codec/hydration/v1/test_time_hydration.py create mode 100644 tests/unit/common/codec/hydration/v2/__init__.py create mode 100644 tests/unit/common/codec/hydration/v2/test_graph_hydration.py create mode 100644 
tests/unit/common/codec/hydration/v2/test_hydration_handler.py create mode 100644 tests/unit/common/codec/hydration/v2/test_spacial_dehydration.py create mode 100644 tests/unit/common/codec/hydration/v2/test_spacial_hydration.py create mode 100644 tests/unit/common/codec/hydration/v2/test_time_dehydration.py create mode 100644 tests/unit/common/codec/hydration/v2/test_time_hydration.py create mode 100644 tests/unit/common/codec/packstream/__init__.py create mode 100644 tests/unit/common/codec/packstream/v1/__init__.py create mode 100644 tests/unit/common/codec/packstream/v1/test_packstream.py delete mode 100644 tests/unit/common/data/test_packing.py delete mode 100644 tests/unit/common/test_data.py delete mode 100644 tests/unit/common/time/test_dehydration.py delete mode 100644 tests/unit/common/time/test_hydration.py diff --git a/CHANGELOG.md b/CHANGELOG.md index ffe36b39f..0835a6e9c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -98,6 +98,19 @@ - ANSI colour codes for log output are now opt-in - Prepend log format with log-level (if colours are disabled) - Prepend log format with thread name and id +- Importing submodules from `neo4j.time` (`neo4j.time.xyz`) has been deprecated. + Everything needed should be imported from `neo4j.time` directly. +- `neo4j.spatial.hydrate_point` and `neo4j.spatial.dehydrate_point` have been + deprecated without replacement. They are internal functions. +- Importing `neo4j.packstream` has been deprecated. It's internal and should not + be used by client code. +- Importing `neo4j.meta` has been deprecated. It's internal and should not + be used by client code. `ExperimentalWarning` should be imported directly from + `neo4j`. `neo4j.meta.version` is exposed through `neo4j.__version__` +- Importing `neo4j.data` has been deprecated. It's internal and should not + be used by client code. `Record` should be imported directly from `neo4j` + instead. 
`neo4j.data.DataHydrator` and `neo4j.data.DataDehydrator` have been + removed without replacement. ## Version 4.4 diff --git a/bin/dist-functions b/bin/dist-functions index f4d6780c2..04688be10 100644 --- a/bin/dist-functions +++ b/bin/dist-functions @@ -6,22 +6,22 @@ DIST="${ROOT}/dist" function get_package { - python -c "from neo4j.meta import package; print(package)" + python -c "from neo4j._meta import package; print(package)" } function set_package { - sed -i 's/^package = .*/package = "'$1'"/g' neo4j/meta.py + sed -i 's/^package = .*/package = "'$1'"/g' neo4j/_meta.py } function get_version { - python -c "from neo4j.meta import version; print(version)" + python -c "from neo4j._meta import version; print(version)" } function set_version { - sed -i 's/^version = .*/version = "'$1'"/g' neo4j/meta.py + sed -i 's/^version = .*/version = "'$1'"/g' neo4j/_meta.py } function check_file @@ -49,8 +49,8 @@ function set_metadata_and_setup ORIGINAL_VERSION=$(get_version) echo "Source code originally configured for package ${ORIGINAL_PACKAGE}/${ORIGINAL_VERSION}" echo "----------------------------------------" - grep "package\s\+=" neo4j/meta.py - grep "version\s\+=" neo4j/meta.py + grep "package\s\+=" neo4j/_meta.py + grep "version\s\+=" neo4j/_meta.py echo "----------------------------------------" function cleanup() { @@ -59,8 +59,8 @@ function set_metadata_and_setup set_version "${ORIGINAL_VERSION}" echo "Source code reconfigured back to original package ${ORIGINAL_PACKAGE}/${ORIGINAL_VERSION}" echo "----------------------------------------" - grep "package\s\+=" neo4j/meta.py - grep "version\s\+=" neo4j/meta.py + grep "package\s\+=" neo4j/_meta.py + grep "version\s\+=" neo4j/_meta.py echo "----------------------------------------" } trap cleanup EXIT @@ -70,8 +70,8 @@ function set_metadata_and_setup set_version "${VERSION}" echo "Source code reconfigured for package ${PACKAGE}/${VERSION}" echo "----------------------------------------" - grep "package\s\+=" 
neo4j/meta.py - grep "version\s\+=" neo4j/meta.py + grep "package\s\+=" neo4j/_meta.py + grep "version\s\+=" neo4j/_meta.py echo "----------------------------------------" # Create source distribution diff --git a/bin/make-unasync b/bin/make-unasync index 6fe47b8c5..d99a51a9b 100755 --- a/bin/make-unasync +++ b/bin/make-unasync @@ -21,10 +21,10 @@ import collections import errno import os -from pathlib import Path import re import sys import tokenize as std_tokenize +from pathlib import Path import isort import isort.files diff --git a/docs/source/conf.py b/docs/source/conf.py index 9b2c2ff4c..d6cfe9fe4 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -21,7 +21,7 @@ sys.path.insert(0, os.path.abspath(os.path.join("..", ".."))) -from neo4j.meta import version as project_version +from neo4j import __version__ as project_version # -- General configuration ------------------------------------------------ diff --git a/neo4j/__init__.py b/neo4j/__init__.py index 435bf3649..1251f0f03 100644 --- a/neo4j/__init__.py +++ b/neo4j/__init__.py @@ -88,6 +88,12 @@ TrustCustomCAs, TrustSystemCAs, ) +from ._data import Record +from ._meta import ( + ExperimentalWarning, + get_user_agent, + version as __version__, +) from ._sync.driver import ( BoltDriver, Driver, @@ -131,13 +137,6 @@ SessionConfig, WorkspaceConfig, ) -from .data import Record -from .meta import ( - experimental, - ExperimentalWarning, - get_user_agent, - version as __version__, -) from .work import ( Query, ResultSummary, diff --git a/neo4j/_async/driver.py b/neo4j/_async/driver.py index 254742ca0..72aee4045 100644 --- a/neo4j/_async/driver.py +++ b/neo4j/_async/driver.py @@ -21,6 +21,11 @@ TrustAll, TrustStore, ) +from .._meta import ( + deprecation_warn, + experimental, + unclosed_resource_warn, +) from ..addressing import Address from ..api import ( READ_ACCESS, @@ -33,11 +38,6 @@ SessionConfig, WorkspaceConfig, ) -from ..meta import ( - deprecation_warn, - experimental, - unclosed_resource_warn, -) 
class AsyncGraphDatabase: diff --git a/neo4j/_async/io/_bolt.py b/neo4j/_async/io/_bolt.py index 8c7e82dea..66f14aa10 100644 --- a/neo4j/_async/io/_bolt.py +++ b/neo4j/_async/io/_bolt.py @@ -24,11 +24,14 @@ from ..._async_compat.network import AsyncBoltSocket from ..._async_compat.util import AsyncUtil +from ..._codec.hydration import v1 as hydration_v1 +from ..._codec.packstream import v1 as packstream_v1 from ..._exceptions import ( BoltError, BoltHandshakeError, SocketDeadlineExceeded, ) +from ..._meta import get_user_agent from ...addressing import Address from ...api import ( ServerInfo, @@ -42,15 +45,10 @@ ServiceUnavailable, SessionExpired, ) -from ...meta import get_user_agent -from ...packstream import ( - Packer, - Unpacker, -) from ._common import ( AsyncInbox, + AsyncOutbox, CommitResponse, - Outbox, ) @@ -68,6 +66,13 @@ class AsyncBolt: the handshake was carried out. """ + # TODO: let packer/unpacker know of hydration (give them hooks?) + # TODO: make sure query parameter dehydration gets clear error message. + + PACKER_CLS = packstream_v1.Packer + UNPACKER_CLS = packstream_v1.Unpacker + HYDRATION_HANDLER_CLS = hydration_v1.HydrationHandler + MAGIC_PREAMBLE = b"\x60\x60\xB0\x17" PROTOCOL_VERSION = None @@ -107,10 +112,16 @@ def __init__(self, unresolved_address, sock, max_connection_lifetime, *, # configuration hint that exists. Therefore, all hints can be stored at # connection level. This might change in the future. 
self.configuration_hints = {} - self.outbox = Outbox() - self.inbox = AsyncInbox(self.socket, on_error=self._set_defunct_read) - self.packer = Packer(self.outbox) - self.unpacker = Unpacker(self.inbox) + self.patch = {} + self.outbox = AsyncOutbox( + self.socket, on_error=self._set_defunct_write, + packer_cls=self.PACKER_CLS + ) + self.inbox = AsyncInbox( + self.socket, on_error=self._set_defunct_read, + unpacker_cls=self.UNPACKER_CLS + ) + self.hydration_handler = self.HYDRATION_HANDLER_CLS() self.responses = deque() self._max_connection_lifetime = max_connection_lifetime self._creation_timestamp = perf_counter() @@ -376,14 +387,17 @@ def der_encoded_server_certificate(self): pass @abc.abstractmethod - async def hello(self): + async def hello(self, dehydration_hooks=None, hydration_hooks=None): """ Appends a HELLO message to the outgoing queue, sends it and consumes all remaining messages. """ pass @abc.abstractmethod - async def route(self, database=None, imp_user=None, bookmarks=None): + async def route( + self, database=None, imp_user=None, bookmarks=None, + dehydration_hooks=None, hydration_hooks=None + ): """ Fetch a routing table from the server for the given `database`. For Bolt 4.3 and above, this appends a ROUTE message; for earlier versions, a procedure call is made via @@ -396,13 +410,22 @@ async def route(self, database=None, imp_user=None, bookmarks=None): Requires Bolt 4.4+. :param bookmarks: iterable of bookmark values after which this transaction should begin - :return: dictionary of raw routing data + :param dehydration_hooks: + Hooks to dehydrate types (dict from type (class) to dehydration + function). Dehydration functions receive the value and returns an + object of type understood by packstream. + :param hydration_hooks: + Hooks to hydrate types (mapping from type (class) to + dehydration function). Dehydration functions receive the value of + type understood by packstream and are free to return anything. 
""" pass @abc.abstractmethod def run(self, query, parameters=None, mode=None, bookmarks=None, - metadata=None, timeout=None, db=None, imp_user=None, **handlers): + metadata=None, timeout=None, db=None, imp_user=None, + dehydration_hooks=None, hydration_hooks=None, + **handlers): """ Appends a RUN message to the output queue. :param query: Cypher query string @@ -415,36 +438,60 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, Requires Bolt 4.0+. :param imp_user: the user to impersonate Requires Bolt 4.4+. + :param dehydration_hooks: + Hooks to dehydrate types (dict from type (class) to dehydration + function). Dehydration functions receive the value and returns an + object of type understood by packstream. + :param hydration_hooks: + Hooks to hydrate types (mapping from type (class) to + dehydration function). Dehydration functions receive the value of + type understood by packstream and are free to return anything. :param handlers: handler functions passed into the returned Response object - :return: Response object """ pass @abc.abstractmethod - def discard(self, n=-1, qid=-1, **handlers): + def discard(self, n=-1, qid=-1, dehydration_hooks=None, + hydration_hooks=None, **handlers): """ Appends a DISCARD message to the output queue. :param n: number of records to discard, default = -1 (ALL) :param qid: query ID to discard for, default = -1 (last query) + :param dehydration_hooks: + Hooks to dehydrate types (dict from type (class) to dehydration + function). Dehydration functions receive the value and returns an + object of type understood by packstream. + :param hydration_hooks: + Hooks to hydrate types (mapping from type (class) to + dehydration function). Dehydration functions receive the value of + type understood by packstream and are free to return anything. 
:param handlers: handler functions passed into the returned Response object - :return: Response object """ pass @abc.abstractmethod - def pull(self, n=-1, qid=-1, **handlers): + def pull(self, n=-1, qid=-1, dehydration_hooks=None, hydration_hooks=None, + **handlers): """ Appends a PULL message to the output queue. :param n: number of records to pull, default = -1 (ALL) :param qid: query ID to pull for, default = -1 (last query) + :param dehydration_hooks: + Hooks to dehydrate types (dict from type (class) to dehydration + function). Dehydration functions receive the value and returns an + object of type understood by packstream. + :param hydration_hooks: + Hooks to hydrate types (mapping from type (class) to + dehydration function). Dehydration functions receive the value of + type understood by packstream and are free to return anything. :param handlers: handler functions passed into the returned Response object - :return: Response object """ pass @abc.abstractmethod def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, - db=None, imp_user=None, **handlers): + db=None, imp_user=None, dehydration_hooks=None, + hydration_hooks=None, **handlers): """ Appends a BEGIN message to the output queue. :param mode: access mode for routing - "READ" or "WRITE" (default) @@ -455,53 +502,99 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, Requires Bolt 4.0+. :param imp_user: the user to impersonate Requires Bolt 4.4+ + :param dehydration_hooks: + Hooks to dehydrate types (dict from type (class) to dehydration + function). Dehydration functions receive the value and returns an + object of type understood by packstream. + :param hydration_hooks: + Hooks to hydrate types (mapping from type (class) to + dehydration function). Dehydration functions receive the value of + type understood by packstream and are free to return anything. 
:param handlers: handler functions passed into the returned Response object :return: Response object """ pass @abc.abstractmethod - def commit(self, **handlers): - """ Appends a COMMIT message to the output queue.""" + def commit(self, dehydration_hooks=None, hydration_hooks=None, **handlers): + """ Appends a COMMIT message to the output queue. + + :param dehydration_hooks: + Hooks to dehydrate types (dict from type (class) to dehydration + function). Dehydration functions receive the value and returns an + object of type understood by packstream. + :param hydration_hooks: + Hooks to hydrate types (mapping from type (class) to + dehydration function). Dehydration functions receive the value of + type understood by packstream and are free to return anything. + """ pass @abc.abstractmethod - def rollback(self, **handlers): - """ Appends a ROLLBACK message to the output queue.""" + def rollback(self, dehydration_hooks=None, hydration_hooks=None, **handlers): + """ Appends a ROLLBACK message to the output queue. + + :param dehydration_hooks: + Hooks to dehydrate types (dict from type (class) to dehydration + function). Dehydration functions receive the value and returns an + object of type understood by packstream. + :param hydration_hooks: + Hooks to hydrate types (mapping from type (class) to + dehydration function). Dehydration functions receive the value of + type understood by packstream and are free to return anything.""" pass @abc.abstractmethod - async def reset(self): + async def reset(self, dehydration_hooks=None, hydration_hooks=None): """ Appends a RESET message to the outgoing queue, sends it and consumes all remaining messages. + + :param dehydration_hooks: + Hooks to dehydrate types (dict from type (class) to dehydration + function). Dehydration functions receive the value and returns an + object of type understood by packstream. + :param hydration_hooks: + Hooks to hydrate types (mapping from type (class) to + dehydration function). 
Dehydration functions receive the value of + type understood by packstream and are free to return anything. """ pass @abc.abstractmethod - def goodbye(self): - """Append a GOODBYE message to the outgoing queue.""" + def goodbye(self, dehydration_hooks=None, hydration_hooks=None): + """Append a GOODBYE message to the outgoing queue. + + :param dehydration_hooks: + Hooks to dehydrate types (dict from type (class) to dehydration + function). Dehydration functions receive the value and returns an + object of type understood by packstream. + :param hydration_hooks: + Hooks to hydrate types (mapping from type (class) to + dehydration function). Dehydration functions receive the value of + type understood by packstream and are free to return anything. + """ pass - def _append(self, signature, fields=(), response=None): + def new_hydration_scope(self): + return self.hydration_handler.new_hydration_scope() + + def _append(self, signature, fields=(), response=None, + dehydration_hooks=None): """ Appends a message to the outgoing queue. :param signature: the signature of the message :param fields: the fields of the message as a tuple :param response: a response object to handle callbacks + :param dehydration_hooks: + Hooks to dehydrate types (dict from type (class) to dehydration + function). Dehydration functions receive the value and returns an + object of type understood by packstream. 
""" - with self.outbox.tmp_buffer(): - self.packer.pack_struct(signature, fields) - self.outbox.wrap_message() + self.outbox.append_message(signature, fields, dehydration_hooks) self.responses.append(response) async def _send_all(self): - data = self.outbox.view() - if data: - try: - await self.socket.sendall(data) - except OSError as error: - await self._set_defunct_write(error) - self.outbox.clear() + if await self.outbox.flush(): self.idle_since = perf_counter() async def send_all(self): @@ -523,8 +616,7 @@ async def send_all(self): await self._send_all() @abc.abstractmethod - async def _process_message(self, details, summary_signature, - summary_metadata): + async def _process_message(self, tag, fields): """ Receive at most one message from the server, if available. :return: 2-tuple of number of detail messages and number of summary @@ -549,11 +641,10 @@ async def fetch_message(self): return 0, 0 # Receive exactly one message - details, summary_signature, summary_metadata = \ - await AsyncUtil.next(self.inbox) - res = await self._process_message( - details, summary_signature, summary_metadata + tag, fields = await self.inbox.pop( + hydration_hooks=self.responses[0].hydration_hooks ) + res = await self._process_message(tag, fields) self.idle_since = perf_counter() return res diff --git a/neo4j/_async/io/_bolt3.py b/neo4j/_async/io/_bolt3.py index b361e1f52..653f2a10a 100644 --- a/neo4j/_async/io/_bolt3.py +++ b/neo4j/_async/io/_bolt3.py @@ -142,7 +142,7 @@ def get_base_headers(self): "user_agent": self.user_agent, } - async def hello(self): + async def hello(self, dehydration_hooks=None, hydration_hooks=None): headers = self.get_base_headers() headers.update(self.auth_dict) logged_headers = dict(headers) @@ -150,13 +150,17 @@ async def hello(self): logged_headers["credentials"] = "*******" log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers) self._append(b"\x01", (headers,), - response=InitResponse(self, "hello", - 
on_success=self.server_info.update)) + response=InitResponse(self, "hello", hydration_hooks, + on_success=self.server_info.update), + dehydration_hooks=dehydration_hooks) await self.send_all() await self.fetch_all() check_supported_server_product(self.server_info.agent) - async def route(self, database=None, imp_user=None, bookmarks=None): + async def route( + self, database=None, imp_user=None, bookmarks=None, + dehydration_hooks=None, hydration_hooks=None + ): if database is not None: raise ConfigurationError( "Database name parameter for selecting database is not " @@ -183,16 +187,20 @@ async def route(self, database=None, imp_user=None, bookmarks=None): "CALL dbms.cluster.routing.getRoutingTable($context)", # This is an internal procedure call. Only available if the Neo4j 3.5 is setup with clustering. {"context": self.routing_context}, mode="r", # Bolt Protocol Version(3, 0) supports mode="r" + dehydration_hooks=dehydration_hooks, + hydration_hooks=hydration_hooks, on_success=metadata.update ) - self.pull(on_success=metadata.update, on_records=records.extend) + self.pull(dehydration_hooks = None, hydration_hooks = None, + on_success=metadata.update, on_records=records.extend) await self.send_all() await self.fetch_all() routing_info = [dict(zip(metadata.get("fields", ()), values)) for values in records] return routing_info def run(self, query, parameters=None, mode=None, bookmarks=None, - metadata=None, timeout=None, db=None, imp_user=None, **handlers): + metadata=None, timeout=None, db=None, imp_user=None, + dehydration_hooks=None, hydration_hooks=None, **handlers): if db is not None: raise ConfigurationError( "Database name parameter for selecting database is not " @@ -231,20 +239,29 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, raise ValueError("Timeout must be a positive number or 0.") fields = (query, parameters, extra) log.debug("[#%04X] C: RUN %s", self.local_port, " ".join(map(repr, fields))) - self._append(b"\x10", fields, 
Response(self, "run", **handlers)) + self._append(b"\x10", fields, + Response(self, "run", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) - def discard(self, n=-1, qid=-1, **handlers): + def discard(self, n=-1, qid=-1, dehydration_hooks=None, + hydration_hooks=None, **handlers): # Just ignore n and qid, it is not supported in the Bolt 3 Protocol. log.debug("[#%04X] C: DISCARD_ALL", self.local_port) - self._append(b"\x2F", (), Response(self, "discard", **handlers)) + self._append(b"\x2F", (), + Response(self, "discard", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) - def pull(self, n=-1, qid=-1, **handlers): + def pull(self, n=-1, qid=-1, dehydration_hooks=None, hydration_hooks=None, + **handlers): # Just ignore n and qid, it is not supported in the Bolt 3 Protocol. log.debug("[#%04X] C: PULL_ALL", self.local_port) - self._append(b"\x3F", (), Response(self, "pull", **handlers)) + self._append(b"\x3F", (), + Response(self, "pull", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, - db=None, imp_user=None, **handlers): + db=None, imp_user=None, dehydration_hooks=None, + hydration_hooks=None, **handlers): if db is not None: raise ConfigurationError( "Database name parameter for selecting database is not " @@ -280,17 +297,25 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, if extra["tx_timeout"] < 0: raise ValueError("Timeout must be a positive number or 0.") log.debug("[#%04X] C: BEGIN %r", self.local_port, extra) - self._append(b"\x11", (extra,), Response(self, "begin", **handlers)) + self._append(b"\x11", (extra,), + Response(self, "begin", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) - def commit(self, **handlers): + def commit(self, dehydration_hooks=None, hydration_hooks=None, **handlers): log.debug("[#%04X] C: COMMIT", self.local_port) - self._append(b"\x12", (), 
CommitResponse(self, "commit", **handlers)) + self._append(b"\x12", (), + CommitResponse(self, "commit", hydration_hooks, + **handlers), + dehydration_hooks=dehydration_hooks) - def rollback(self, **handlers): + def rollback(self, dehydration_hooks=None, hydration_hooks=None, + **handlers): log.debug("[#%04X] C: ROLLBACK", self.local_port) - self._append(b"\x13", (), Response(self, "rollback", **handlers)) + self._append(b"\x13", (), + Response(self, "rollback", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) - async def reset(self): + async def reset(self, dehydration_hooks=None, hydration_hooks=None): """ Add a RESET message to the outgoing queue, send it and consume all remaining messages. """ @@ -299,21 +324,33 @@ def fail(metadata): raise BoltProtocolError("RESET failed %r" % metadata, address=self.unresolved_address) log.debug("[#%04X] C: RESET", self.local_port) - self._append(b"\x0F", response=Response(self, "reset", on_failure=fail)) + self._append(b"\x0F", + response=Response(self, "reset", hydration_hooks, + on_failure=fail), + dehydration_hooks=dehydration_hooks) await self.send_all() await self.fetch_all() - def goodbye(self): + def goodbye(self, dehydration_hooks=None, hydration_hooks=None): log.debug("[#%04X] C: GOODBYE", self.local_port) - self._append(b"\x02", ()) + self._append(b"\x02", (), dehydration_hooks=dehydration_hooks) - async def _process_message(self, details, summary_signature, - summary_metadata): + async def _process_message(self, tag, fields): """ Process at most one message from the server, if available. 
:return: 2-tuple of number of detail messages and number of summary messages fetched """ + details = [] + summary_signature = summary_metadata = None + if tag == b"\x71": # RECORD + details = fields + elif fields: + summary_signature = tag + summary_metadata = fields[0] + else: + summary_signature = tag + if details: log.debug("[#%04X] S: RECORD * %d", self.local_port, len(details)) # Do not log any data await self.responses[0].on_records(details) diff --git a/neo4j/_async/io/_bolt4.py b/neo4j/_async/io/_bolt4.py index d1fbf035b..f82cf7069 100644 --- a/neo4j/_async/io/_bolt4.py +++ b/neo4j/_async/io/_bolt4.py @@ -34,11 +34,11 @@ NotALeader, ServiceUnavailable, ) +from ._bolt import AsyncBolt from ._bolt3 import ( ServerStateManager, ServerStates, ) -from ._bolt import AsyncBolt from ._common import ( check_supported_server_product, CommitResponse, @@ -95,7 +95,7 @@ def get_base_headers(self): "user_agent": self.user_agent, } - async def hello(self): + async def hello(self, dehydration_hooks=None, hydration_hooks=None): headers = self.get_base_headers() headers.update(self.auth_dict) logged_headers = dict(headers) @@ -103,13 +103,19 @@ async def hello(self): logged_headers["credentials"] = "*******" log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers) self._append(b"\x01", (headers,), - response=InitResponse(self, "hello", - on_success=self.server_info.update)) + response=InitResponse( + self, "hello", hydration_hooks, + on_success=self.server_info.update + ), + dehydration_hooks=dehydration_hooks) await self.send_all() await self.fetch_all() check_supported_server_product(self.server_info.agent) - async def route(self, database=None, imp_user=None, bookmarks=None): + async def route( + self, database=None, imp_user=None, bookmarks=None, + dehydration_hooks=None, hydration_hooks=None + ): if imp_user is not None: raise ConfigurationError( "Impersonation is not supported in Bolt Protocol {!r}. 
" @@ -138,14 +144,20 @@ async def route(self, database=None, imp_user=None, bookmarks=None): db=SYSTEM_DATABASE, on_success=metadata.update ) - self.pull(on_success=metadata.update, on_records=records.extend) + self.pull( + dehydration_hooks=dehydration_hooks, + hydration_hooks=hydration_hooks, + on_success=metadata.update, + on_records=records.extend + ) await self.send_all() await self.fetch_all() routing_info = [dict(zip(metadata.get("fields", ()), values)) for values in records] return routing_info def run(self, query, parameters=None, mode=None, bookmarks=None, - metadata=None, timeout=None, db=None, imp_user=None, **handlers): + metadata=None, timeout=None, db=None, imp_user=None, + dehydration_hooks=None, hydration_hooks=None, **handlers): if imp_user is not None: raise ConfigurationError( "Impersonation is not supported in Bolt Protocol {!r}. " @@ -179,24 +191,33 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, raise ValueError("Timeout must be a positive number or 0.") fields = (query, parameters, extra) log.debug("[#%04X] C: RUN %s", self.local_port, " ".join(map(repr, fields))) - self._append(b"\x10", fields, Response(self, "run", **handlers)) + self._append(b"\x10", fields, + Response(self, "run", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) - def discard(self, n=-1, qid=-1, **handlers): + def discard(self, n=-1, qid=-1, dehydration_hooks=None, + hydration_hooks=None, **handlers): extra = {"n": n} if qid != -1: extra["qid"] = qid log.debug("[#%04X] C: DISCARD %r", self.local_port, extra) - self._append(b"\x2F", (extra,), Response(self, "discard", **handlers)) + self._append(b"\x2F", (extra,), + Response(self, "discard", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) - def pull(self, n=-1, qid=-1, **handlers): + def pull(self, n=-1, qid=-1, dehydration_hooks=None, hydration_hooks=None, + **handlers): extra = {"n": n} if qid != -1: extra["qid"] = qid log.debug("[#%04X] C: PULL %r", 
self.local_port, extra) - self._append(b"\x3F", (extra,), Response(self, "pull", **handlers)) + self._append(b"\x3F", (extra,), + Response(self, "pull", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, - db=None, imp_user=None, **handlers): + db=None, imp_user=None, dehydration_hooks=None, + hydration_hooks=None, **handlers): if imp_user is not None: raise ConfigurationError( "Impersonation is not supported in Bolt Protocol {!r}. " @@ -227,17 +248,25 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, if extra["tx_timeout"] < 0: raise ValueError("Timeout must be a positive number or 0.") log.debug("[#%04X] C: BEGIN %r", self.local_port, extra) - self._append(b"\x11", (extra,), Response(self, "begin", **handlers)) + self._append(b"\x11", (extra,), + Response(self, "begin", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) - def commit(self, **handlers): + def commit(self, dehydration_hooks=None, hydration_hooks=None, **handlers): log.debug("[#%04X] C: COMMIT", self.local_port) - self._append(b"\x12", (), CommitResponse(self, "commit", **handlers)) + self._append(b"\x12", (), + CommitResponse(self, "commit", hydration_hooks, + **handlers), + dehydration_hooks=dehydration_hooks) - def rollback(self, **handlers): + def rollback(self, dehydration_hooks=None, hydration_hooks=None, + **handlers): log.debug("[#%04X] C: ROLLBACK", self.local_port) - self._append(b"\x13", (), Response(self, "rollback", **handlers)) + self._append(b"\x13", (), + Response(self, "rollback", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) - async def reset(self): + async def reset(self, dehydration_hooks=None, hydration_hooks=None): """ Add a RESET message to the outgoing queue, send it and consume all remaining messages. 
""" @@ -246,21 +275,33 @@ def fail(metadata): raise BoltProtocolError("RESET failed %r" % metadata, self.unresolved_address) log.debug("[#%04X] C: RESET", self.local_port) - self._append(b"\x0F", response=Response(self, "reset", on_failure=fail)) + self._append(b"\x0F", + response=Response(self, "reset", hydration_hooks, + on_failure=fail), + dehydration_hooks=dehydration_hooks) await self.send_all() await self.fetch_all() - def goodbye(self): + def goodbye(self, dehydration_hooks=None, hydration_hooks=None): log.debug("[#%04X] C: GOODBYE", self.local_port) - self._append(b"\x02", ()) + self._append(b"\x02", (), dehydration_hooks=dehydration_hooks) - async def _process_message(self, details, summary_signature, - summary_metadata): + async def _process_message(self, tag, fields): """ Process at most one message from the server, if available. :return: 2-tuple of number of detail messages and number of summary messages fetched """ + details = [] + summary_signature = summary_metadata = None + if tag == b"\x71": # RECORD + details = fields + elif fields: + summary_signature = tag + summary_metadata = fields[0] + else: + summary_signature = tag + if details: log.debug("[#%04X] S: RECORD * %d", self.local_port, len(details)) # Do not log any data await self.responses[0].on_records(details) @@ -341,7 +382,15 @@ class AsyncBolt4x3(AsyncBolt4x2): PROTOCOL_VERSION = Version(4, 3) - async def route(self, database=None, imp_user=None, bookmarks=None): + def get_base_headers(self): + headers = super().get_base_headers() + headers["patch_bolt"] = ["utc"] + return headers + + async def route( + self, database=None, imp_user=None, bookmarks=None, + dehydration_hooks=None, hydration_hooks=None + ): if imp_user is not None: raise ConfigurationError( "Impersonation is not supported in Bolt Protocol {!r}. 
" @@ -359,13 +408,14 @@ async def route(self, database=None, imp_user=None, bookmarks=None): else: bookmarks = list(bookmarks) self._append(b"\x66", (routing_context, bookmarks, database), - response=Response(self, "route", - on_success=metadata.update)) + response=Response(self, "route", hydration_hooks, + on_success=metadata.update), + dehydration_hooks=dehydration_hooks) await self.send_all() await self.fetch_all() return [metadata.get("rt")] - async def hello(self): + async def hello(self, dehydration_hooks=None, hydration_hooks=None): def on_success(metadata): self.configuration_hints.update(metadata.pop("hints", {})) self.server_info.update(metadata) @@ -380,6 +430,9 @@ def on_success(metadata): "connection.recv_timeout_seconds (%r). Make sure " "the server and network is set up correctly.", self.local_port, recv_timeout) + self.patch = set(metadata.pop("patch_bolt", [])) + if "utc" in self.patch: + self.hydration_handler.patch_utc() headers = self.get_base_headers() headers.update(self.auth_dict) @@ -388,8 +441,9 @@ def on_success(metadata): logged_headers["credentials"] = "*******" log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers) self._append(b"\x01", (headers,), - response=InitResponse(self, "hello", - on_success=on_success)) + response=InitResponse(self, "hello", hydration_hooks, + on_success=on_success), + dehydration_hooks=dehydration_hooks) await self.send_all() await self.fetch_all() check_supported_server_product(self.server_info.agent) @@ -403,7 +457,10 @@ class AsyncBolt4x4(AsyncBolt4x3): PROTOCOL_VERSION = Version(4, 4) - async def route(self, database=None, imp_user=None, bookmarks=None): + async def route( + self, database=None, imp_user=None, bookmarks=None, + dehydration_hooks=None, hydration_hooks=None + ): routing_context = self.routing_context or {} db_context = {} if database is not None: @@ -418,14 +475,16 @@ async def route(self, database=None, imp_user=None, bookmarks=None): else: bookmarks = list(bookmarks) 
self._append(b"\x66", (routing_context, bookmarks, db_context), - response=Response(self, "route", - on_success=metadata.update)) + response=Response(self, "route", hydration_hooks, + on_success=metadata.update), + dehydration_hooks=dehydration_hooks) await self.send_all() await self.fetch_all() return [metadata.get("rt")] def run(self, query, parameters=None, mode=None, bookmarks=None, - metadata=None, timeout=None, db=None, imp_user=None, **handlers): + metadata=None, timeout=None, db=None, imp_user=None, + dehydration_hooks=None, hydration_hooks=None, **handlers): if not parameters: parameters = {} extra = {} @@ -456,10 +515,13 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, fields = (query, parameters, extra) log.debug("[#%04X] C: RUN %s", self.local_port, " ".join(map(repr, fields))) - self._append(b"\x10", fields, Response(self, "run", **handlers)) + self._append(b"\x10", fields, + Response(self, "run", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, - db=None, imp_user=None, **handlers): + db=None, imp_user=None, dehydration_hooks=None, + hydration_hooks=None, **handlers): extra = {} if mode in (READ_ACCESS, "r"): # It will default to mode "w" if nothing is specified @@ -486,4 +548,6 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, if extra["tx_timeout"] < 0: raise ValueError("Timeout must be a positive number or 0.") log.debug("[#%04X] C: BEGIN %r", self.local_port, extra) - self._append(b"\x11", (extra,), Response(self, "begin", **handlers)) + self._append(b"\x11", (extra,), + Response(self, "begin", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) diff --git a/neo4j/_async/io/_bolt5.py b/neo4j/_async/io/_bolt5.py index 99901ff31..aa399e8ac 100644 --- a/neo4j/_async/io/_bolt5.py +++ b/neo4j/_async/io/_bolt5.py @@ -19,28 +19,24 @@ from logging import getLogger from ssl import SSLSocket -from 
..._async_compat.util import AsyncUtil -from ..._exceptions import ( - BoltError, - BoltProtocolError, -) +from ..._codec.hydration import v2 as hydration_v2 +from ..._exceptions import BoltProtocolError from ...api import ( READ_ACCESS, Version, ) from ...exceptions import ( DatabaseUnavailable, - DriverError, ForbiddenOnReadOnlyDatabase, Neo4jError, NotALeader, ServiceUnavailable, ) +from ._bolt import AsyncBolt from ._bolt3 import ( ServerStateManager, ServerStates, ) -from ._bolt import AsyncBolt from ._common import ( check_supported_server_product, CommitResponse, @@ -57,6 +53,8 @@ class AsyncBolt5x0(AsyncBolt): PROTOCOL_VERSION = Version(5, 0) + HYDRATION_HANDLER_CLS = hydration_v2.HydrationHandler + supports_multiple_results = True supports_multiple_databases = True @@ -95,7 +93,7 @@ def get_base_headers(self): headers["routing"] = self.routing_context return headers - async def hello(self): + async def hello(self, dehydration_hooks=None, hydration_hooks=None): def on_success(metadata): self.configuration_hints.update(metadata.pop("hints", {})) self.server_info.update(metadata) @@ -118,13 +116,15 @@ def on_success(metadata): logged_headers["credentials"] = "*******" log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers) self._append(b"\x01", (headers,), - response=InitResponse(self, "hello", - on_success=on_success)) + response=InitResponse(self, "hello", hydration_hooks, + on_success=on_success), + dehydration_hooks=dehydration_hooks) await self.send_all() await self.fetch_all() check_supported_server_product(self.server_info.agent) - async def route(self, database=None, imp_user=None, bookmarks=None): + async def route(self, database=None, imp_user=None, bookmarks=None, + dehydration_hooks=None, hydration_hooks=None): routing_context = self.routing_context or {} db_context = {} if database is not None: @@ -139,14 +139,16 @@ async def route(self, database=None, imp_user=None, bookmarks=None): else: bookmarks = list(bookmarks) 
self._append(b"\x66", (routing_context, bookmarks, db_context), - response=Response(self, "route", - on_success=metadata.update)) + response=Response(self, "route", hydration_hooks, + on_success=metadata.update), + dehydration_hooks=dehydration_hooks) await self.send_all() await self.fetch_all() return [metadata.get("rt")] def run(self, query, parameters=None, mode=None, bookmarks=None, - metadata=None, timeout=None, db=None, imp_user=None, **handlers): + metadata=None, timeout=None, db=None, imp_user=None, + dehydration_hooks=None, hydration_hooks=None, **handlers): if not parameters: parameters = {} extra = {} @@ -177,24 +179,33 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, fields = (query, parameters, extra) log.debug("[#%04X] C: RUN %s", self.local_port, " ".join(map(repr, fields))) - self._append(b"\x10", fields, Response(self, "run", **handlers)) + self._append(b"\x10", fields, + Response(self, "run", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) - def discard(self, n=-1, qid=-1, **handlers): + def discard(self, n=-1, qid=-1, dehydration_hooks=None, + hydration_hooks=None, **handlers): extra = {"n": n} if qid != -1: extra["qid"] = qid log.debug("[#%04X] C: DISCARD %r", self.local_port, extra) - self._append(b"\x2F", (extra,), Response(self, "discard", **handlers)) + self._append(b"\x2F", (extra,), + Response(self, "discard", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) - def pull(self, n=-1, qid=-1, **handlers): + def pull(self, n=-1, qid=-1, dehydration_hooks=None, hydration_hooks=None, + **handlers): extra = {"n": n} if qid != -1: extra["qid"] = qid log.debug("[#%04X] C: PULL %r", self.local_port, extra) - self._append(b"\x3F", (extra,), Response(self, "pull", **handlers)) + self._append(b"\x3F", (extra,), + Response(self, "pull", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, - db=None, 
imp_user=None, **handlers): + db=None, imp_user=None, dehydration_hooks=None, + hydration_hooks=None, **handlers): extra = {} if mode in (READ_ACCESS, "r"): # It will default to mode "w" if nothing is specified @@ -221,17 +232,25 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, if extra["tx_timeout"] < 0: raise ValueError("Timeout must be a number <= 0") log.debug("[#%04X] C: BEGIN %r", self.local_port, extra) - self._append(b"\x11", (extra,), Response(self, "begin", **handlers)) + self._append(b"\x11", (extra,), + Response(self, "begin", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) - def commit(self, **handlers): + def commit(self, dehydration_hooks=None, hydration_hooks=None, **handlers): log.debug("[#%04X] C: COMMIT", self.local_port) - self._append(b"\x12", (), CommitResponse(self, "commit", **handlers)) + self._append(b"\x12", (), + CommitResponse(self, "commit", hydration_hooks, + **handlers), + dehydration_hooks=dehydration_hooks) - def rollback(self, **handlers): + def rollback(self, dehydration_hooks=None, hydration_hooks=None, + **handlers): log.debug("[#%04X] C: ROLLBACK", self.local_port) - self._append(b"\x13", (), Response(self, "rollback", **handlers)) + self._append(b"\x13", (), + Response(self, "rollback", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) - async def reset(self): + async def reset(self, dehydration_hooks=None, hydration_hooks=None): """Reset the connection. 
Add a RESET message to the outgoing queue, send it and consume all @@ -243,22 +262,33 @@ def fail(metadata): self.unresolved_address) log.debug("[#%04X] C: RESET", self.local_port) - self._append(b"\x0F", response=Response(self, "reset", - on_failure=fail)) + self._append(b"\x0F", + response=Response(self, "reset", hydration_hooks, + on_failure=fail), + dehydration_hooks=dehydration_hooks) await self.send_all() await self.fetch_all() - def goodbye(self): + def goodbye(self, dehydration_hooks=None, hydration_hooks=None): log.debug("[#%04X] C: GOODBYE", self.local_port) - self._append(b"\x02", ()) + self._append(b"\x02", (), dehydration_hooks=dehydration_hooks) - async def _process_message(self, details, summary_signature, - summary_metadata): + async def _process_message(self, tag, fields): """Process at most one message from the server, if available. :return: 2-tuple of number of detail messages and number of summary messages fetched """ + details = [] + summary_signature = summary_metadata = None + if tag == b"\x71": # RECORD + details = fields + elif fields: + summary_signature = tag + summary_metadata = fields[0] + else: + summary_signature = tag + if details: # Do not log any data log.debug("[#%04X] S: RECORD * %d", self.local_port, len(details)) diff --git a/neo4j/_async/io/_common.py b/neo4j/_async/io/_common.py index 7ab3cebe9..98744453a 100644 --- a/neo4j/_async/io/_common.py +++ b/neo4j/_async/io/_common.py @@ -17,7 +17,6 @@ import asyncio -from contextlib import contextmanager import logging import socket from struct import pack as struct_pack @@ -30,132 +29,122 @@ SessionExpired, UnsupportedServerProduct, ) -from ...packstream import ( - UnpackableBuffer, - Unpacker, -) log = logging.getLogger("neo4j") -class AsyncMessageInbox: +class AsyncInbox: - def __init__(self, s, on_error): + def __init__(self, sock, on_error, unpacker_cls): self.on_error = on_error - self._local_port = s.getsockname()[1] - self._messages = self._yield_messages(s) - - async def 
_yield_messages(self, sock): + self._local_port = sock.getsockname()[1] + self._socket = sock + self._buffer = unpacker_cls.new_unpackable_buffer() + self._unpacker = unpacker_cls(self._buffer) + self._broken = False + + async def _buffer_one_chunk(self): + assert not self._broken try: - buffer = UnpackableBuffer() - unpacker = Unpacker(buffer) chunk_size = 0 while True: - while chunk_size == 0: # Determine the chunk size and skip noop - await receive_into_buffer(sock, buffer, 2) - chunk_size = buffer.pop_u16() + await receive_into_buffer(self._socket, self._buffer, 2) + chunk_size = self._buffer.pop_u16() if chunk_size == 0: log.debug("[#%04X] S: ", self._local_port) - await receive_into_buffer(sock, buffer, chunk_size + 2) - chunk_size = buffer.pop_u16() + await receive_into_buffer( + self._socket, self._buffer, chunk_size + 2 + ) + chunk_size = self._buffer.pop_u16() if chunk_size == 0: # chunk_size was the end marker for the message - size, tag = unpacker.unpack_structure_header() - fields = [unpacker.unpack() for _ in range(size)] - yield tag, fields - # Reset for new message - unpacker.reset() + return except (OSError, socket.timeout, SocketDeadlineExceeded) as error: + self._broken = True await AsyncUtil.callback(self.on_error, error) + raise - async def pop(self): - return await AsyncUtil.next(self._messages) - - -class AsyncInbox(AsyncMessageInbox): - - async def __anext__(self): - tag, fields = await self.pop() - if tag == b"\x71": - return fields, None, None - elif fields: - return [], tag, fields[0] - else: - return [], tag, None + async def pop(self, hydration_hooks): + await self._buffer_one_chunk() + try: + size, tag = self._unpacker.unpack_structure_header() + fields = [self._unpacker.unpack(hydration_hooks) + for _ in range(size)] + return tag, fields + finally: + # Reset for new message + self._unpacker.reset() -class Outbox: +class AsyncOutbox: - def __init__(self, max_chunk_size=16384): + def __init__(self, sock, on_error, packer_cls, 
max_chunk_size=16384): self._max_chunk_size = max_chunk_size self._chunked_data = bytearray() - self._raw_data = bytearray() - self.write = self._raw_data.extend - self._tmp_buffering = 0 + self._buffer = packer_cls.new_packable_buffer() + self._packer = packer_cls(self._buffer) + self.socket = sock + self.on_error = on_error def max_chunk_size(self): return self._max_chunk_size - def clear(self): - if self._tmp_buffering: - raise RuntimeError("Cannot clear while buffering") + def _clear(self): + assert not self._buffer.is_tmp_buffering() self._chunked_data = bytearray() - self._raw_data.clear() + self._buffer.clear() def _chunk_data(self): - data_len = len(self._raw_data) + data_len = len(self._buffer.data) num_full_chunks, chunk_rest = divmod( data_len, self._max_chunk_size ) num_chunks = num_full_chunks + bool(chunk_rest) - data_view = memoryview(self._raw_data) - header_start = len(self._chunked_data) - data_start = header_start + 2 - raw_data_start = 0 - for i in range(num_chunks): - chunk_size = min(data_len - raw_data_start, - self._max_chunk_size) - self._chunked_data[header_start:data_start] = struct_pack( - ">H", chunk_size - ) - self._chunked_data[data_start:(data_start + chunk_size)] = \ - data_view[raw_data_start:(raw_data_start + chunk_size)] - header_start += chunk_size + 2 + with memoryview(self._buffer.data) as data_view: + header_start = len(self._chunked_data) data_start = header_start + 2 - raw_data_start += chunk_size - del data_view - self._raw_data.clear() - - def wrap_message(self): - if self._tmp_buffering: - raise RuntimeError("Cannot wrap message while buffering") + raw_data_start = 0 + for i in range(num_chunks): + chunk_size = min(data_len - raw_data_start, + self._max_chunk_size) + self._chunked_data[header_start:data_start] = struct_pack( + ">H", chunk_size + ) + self._chunked_data[data_start:(data_start + chunk_size)] = \ + data_view[raw_data_start:(raw_data_start + chunk_size)] + header_start += chunk_size + 2 + data_start = 
header_start + 2 + raw_data_start += chunk_size + self._buffer.clear() + + def _wrap_message(self): + assert not self._buffer.is_tmp_buffering() self._chunk_data() self._chunked_data += b"\x00\x00" - def view(self): - if self._tmp_buffering: - raise RuntimeError("Cannot view while buffering") - self._chunk_data() - return memoryview(self._chunked_data) + def append_message(self, tag, fields, dehydration_hooks): + with self._buffer.tmp_buffer(): + self._packer.pack_struct(tag, fields, dehydration_hooks) + self._wrap_message() - @contextmanager - def tmp_buffer(self): - self._tmp_buffering += 1 - old_len = len(self._raw_data) - try: - yield - except Exception: - del self._raw_data[old_len:] - raise - finally: - self._tmp_buffering -= 1 + async def flush(self): + data = self._chunked_data + if data: + try: + await self.socket.sendall(data) + except OSError as error: + await self.on_error(error) + return False + self._clear() + return True + return False class ConnectionErrorHandler: @@ -218,8 +207,9 @@ class Response: more detail messages followed by one summary message). 
""" - def __init__(self, connection, message, **handlers): + def __init__(self, connection, message, hydration_hooks, **handlers): self.connection = connection + self.hydration_hooks = hydration_hooks self.handlers = handlers self.message = message self.complete = False @@ -294,9 +284,9 @@ async def receive_into_buffer(sock, buffer, n_bytes): end = buffer.used + n_bytes if end > len(buffer.data): buffer.data += bytearray(end - len(buffer.data)) - view = memoryview(buffer.data) - while buffer.used < end: - n = await sock.recv_into(view[buffer.used:end], end - buffer.used) - if n == 0: - raise OSError("No data") - buffer.used += n + with memoryview(buffer.data) as view: + while buffer.used < end: + n = await sock.recv_into(view[buffer.used:end], end - buffer.used) + if n == 0: + raise OSError("No data") + buffer.used += n diff --git a/neo4j/_async/io/_pool.py b/neo4j/_async/io/_pool.py index 198e0f8ac..40837d3f1 100644 --- a/neo4j/_async/io/_pool.py +++ b/neo4j/_async/io/_pool.py @@ -17,12 +17,12 @@ import abc +import logging from collections import ( defaultdict, deque, ) from contextlib import asynccontextmanager -import logging from logging import getLogger from random import choice diff --git a/neo4j/_async/work/result.py b/neo4j/_async/work/result.py index 9a2b6af3d..7f913cb72 100644 --- a/neo4j/_async/work/result.py +++ b/neo4j/_async/work/result.py @@ -20,15 +20,15 @@ from warnings import warn from ..._async_compat.util import AsyncUtil -from ...data import ( - DataDehydrator, +from ..._data import ( + Record, RecordTableRowExporter, ) +from ..._meta import experimental from ...exceptions import ( ResultConsumedError, ResultNotSingleError, ) -from ...meta import experimental from ...time import ( Date, DateTime, @@ -54,10 +54,9 @@ class AsyncResult: :meth:`.AyncSession.run` and :meth:`.AsyncTransaction.run`. 
""" - def __init__(self, connection, hydrant, fetch_size, on_closed, - on_error): + def __init__(self, connection, fetch_size, on_closed, on_error): self._connection = ConnectionErrorHandler(connection, on_error) - self._hydrant = hydrant + self._hydration_scope = connection.new_hydration_scope() self._on_closed = on_closed self._metadata = None self._keys = None @@ -104,7 +103,7 @@ async def _run( query_metadata = getattr(query, "metadata", None) query_timeout = getattr(query, "timeout", None) - parameters = DataDehydrator.fix_parameters(dict(parameters or {}, **kwargs)) + parameters = dict(parameters or {}, **kwargs) self._metadata = { "query": query_text, @@ -135,6 +134,7 @@ async def on_failed_attach(metadata): timeout=query_timeout, db=db, imp_user=imp_user, + dehydration_hooks=self._hydration_scope.dehydration_hooks, on_success=on_attached, on_failure=on_failed_attach, ) @@ -145,7 +145,10 @@ async def on_failed_attach(metadata): def _pull(self): def on_records(records): if not self._discarding: - self._record_buffer.extend(self._hydrant.hydrate_records(self._keys, records)) + self._record_buffer.extend(( + Record(zip(self._keys, record)) + for record in records + )) async def on_summary(): self._attached = False @@ -167,6 +170,7 @@ def on_success(summary_metadata): self._connection.pull( n=self._fetch_size, qid=self._qid, + hydration_hooks=self._hydration_scope.hydration_hooks, on_records=on_records, on_success=on_success, on_failure=on_failure, @@ -479,7 +483,7 @@ async def graph(self): Can raise :exc:`ResultConsumedError`. """ await self._buffer_all() - return self._hydrant.graph + return self._hydration_scope.get_graph() async def value(self, key=0, default=None): """Helper function that return the remainder of the result as a list of values. 
diff --git a/neo4j/_async/work/session.py b/neo4j/_async/work/session.py index fc0e045b5..eed2d9d3a 100644 --- a/neo4j/_async/work/session.py +++ b/neo4j/_async/work/session.py @@ -21,13 +21,16 @@ from time import perf_counter from ..._async_compat import async_sleep +from ..._meta import ( + deprecated, + deprecation_warn, +) from ...api import ( Bookmarks, READ_ACCESS, WRITE_ACCESS, ) from ...conf import SessionConfig -from ...data import DataHydrator from ...exceptions import ( ClientError, DriverError, @@ -36,10 +39,6 @@ SessionExpired, TransactionError, ) -from ...meta import ( - deprecated, - deprecation_warn, -) from ...work import Query from .result import AsyncResult from .transaction import ( @@ -228,10 +227,8 @@ async def run(self, query, parameters=None, **kwargs): protocol_version = cx.PROTOCOL_VERSION server_info = cx.server_info - hydrant = DataHydrator() - self._auto_result = AsyncResult( - cx, hydrant, self._config.fetch_size, self._result_closed, + cx, self._config.fetch_size, self._result_closed, self._result_error ) await self._auto_result._run( diff --git a/neo4j/_async/work/transaction.py b/neo4j/_async/work/transaction.py index e293a8ecd..03b76476e 100644 --- a/neo4j/_async/work/transaction.py +++ b/neo4j/_async/work/transaction.py @@ -19,7 +19,6 @@ from functools import wraps from ..._async_compat.util import AsyncUtil -from ...data import DataHydrator from ...exceptions import TransactionError from ...work import Query from ..io import ConnectionErrorHandler @@ -123,8 +122,7 @@ async def run(self, query, parameters=None, **kwparameters): await self._results[-1]._buffer_all() result = AsyncResult( - self._connection, DataHydrator(), self._fetch_size, - self._result_on_closed_handler, + self._connection, self._fetch_size, self._result_on_closed_handler, self._error_handler ) self._results.append(result) diff --git a/neo4j/_async/work/workspace.py b/neo4j/_async/work/workspace.py index e4d1d3118..3a374f7a1 100644 --- 
a/neo4j/_async/work/workspace.py +++ b/neo4j/_async/work/workspace.py @@ -19,15 +19,15 @@ import asyncio from ..._deadline import Deadline +from ..._meta import ( + deprecation_warn, + unclosed_resource_warn, +) from ...conf import WorkspaceConfig from ...exceptions import ( ServiceUnavailable, SessionExpired, ) -from ...meta import ( - deprecation_warn, - unclosed_resource_warn, -) from ..io import AsyncNeo4jPool diff --git a/neo4j/_async_compat/network/_bolt_socket.py b/neo4j/_async_compat/network/_bolt_socket.py index e8405bc18..475536bc6 100644 --- a/neo4j/_async_compat/network/_bolt_socket.py +++ b/neo4j/_async_compat/network/_bolt_socket.py @@ -20,6 +20,7 @@ import logging import selectors import socket +import struct from socket import ( AF_INET, AF_INET6, @@ -34,7 +35,6 @@ HAS_SNI, SSLError, ) -import struct from time import perf_counter from ... import addressing diff --git a/neo4j/_async_compat/util.py b/neo4j/_async_compat/util.py index 98e29369a..fd510d02b 100644 --- a/neo4j/_async_compat/util.py +++ b/neo4j/_async_compat/util.py @@ -18,7 +18,7 @@ import inspect -from ..meta import experimental +from .._meta import experimental __all__ = [ diff --git a/tests/unit/common/data/__init__.py b/neo4j/_codec/__init__.py similarity index 100% rename from tests/unit/common/data/__init__.py rename to neo4j/_codec/__init__.py diff --git a/neo4j/_codec/hydration/__init__.py b/neo4j/_codec/hydration/__init__.py new file mode 100644 index 000000000..bd4fdb81f --- /dev/null +++ b/neo4j/_codec/hydration/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._common import HydrationScope +from ._interface import HydrationHandlerABC + + +__all__ = [ + "HydrationHandlerABC", + "HydrationScope", +] diff --git a/neo4j/_codec/hydration/_common.py b/neo4j/_codec/hydration/_common.py new file mode 100644 index 000000000..1fc634fa3 --- /dev/null +++ b/neo4j/_codec/hydration/_common.py @@ -0,0 +1,50 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from ...graph import Graph +from ...packstream import Structure + + +class GraphHydrator: + def __init__(self): + self.graph = Graph() + self.struct_hydration_functions = {} + + +class HydrationScope: + + def __init__(self, hydration_handler, graph_hydrator): + self._hydration_handler = hydration_handler + self._graph_hydrator = graph_hydrator + self._struct_hydration_functions = { + **hydration_handler.struct_hydration_functions, + **graph_hydrator.struct_hydration_functions, + } + self.hydration_hooks = { + Structure: self._hydrate_structure, + } + self.dehydration_hooks = hydration_handler.dehydration_functions + + def _hydrate_structure(self, value): + f = self._struct_hydration_functions.get(value.tag) + if not f: + return value + return f(*value.fields) + + def get_graph(self): + return self._graph_hydrator.graph diff --git a/neo4j/_codec/hydration/_interface/__init__.py b/neo4j/_codec/hydration/_interface/__init__.py new file mode 100644 index 000000000..5092d5e0d --- /dev/null +++ b/neo4j/_codec/hydration/_interface/__init__.py @@ -0,0 +1,29 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import abc + + +class HydrationHandlerABC(abc.ABC): + def __init__(self): + self.struct_hydration_functions = {} + self.dehydration_functions = {} + + @abc.abstractmethod + def new_hydration_scope(self): + ... 
diff --git a/neo4j/_codec/hydration/v1/__init__.py b/neo4j/_codec/hydration/v1/__init__.py new file mode 100644 index 000000000..985f6f033 --- /dev/null +++ b/neo4j/_codec/hydration/v1/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .hydration_handler import HydrationHandler + + +__all__ = [ + "HydrationHandler", +] diff --git a/neo4j/_codec/hydration/v1/hydration_handler.py b/neo4j/_codec/hydration/v1/hydration_handler.py new file mode 100644 index 000000000..fd62ad84f --- /dev/null +++ b/neo4j/_codec/hydration/v1/hydration_handler.py @@ -0,0 +1,201 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from datetime import ( + date, + datetime, + time, + timedelta, +) + +from ....graph import ( + Graph, + Node, + Path, +) +from ....spatial import ( + CartesianPoint, + Point, + WGS84Point, +) +from ....time import ( + Date, + DateTime, + Duration, + Time, +) +from .._common import ( + GraphHydrator, + HydrationScope, +) +from .._interface import HydrationHandlerABC +from . import ( + spatial, + temporal, +) + + +class _GraphHydrator(GraphHydrator): + def __init__(self): + super().__init__() + self.struct_hydration_functions = { + **self.struct_hydration_functions, + b"N": self.hydrate_node, + b"R": self.hydrate_relationship, + b"r": self.hydrate_unbound_relationship, + b"P": self.hydrate_path, + } + + def hydrate_node(self, id_, labels=None, + properties=None, element_id=None): + # backwards compatibility with Neo4j < 5.0 + if element_id is None: + element_id = str(id_) + + try: + inst = self.graph._nodes[element_id] + except KeyError: + inst = self.graph._nodes[element_id] = Node( + self.graph, element_id, id_, labels, properties + ) + else: + # If we have already hydrated this node as the endpoint of + # a relationship, it won't have any labels or properties. + # Therefore, we need to add the ones we have here. 
+ if labels: + inst._labels = inst._labels.union(labels) # frozen_set + if properties: + inst._properties.update(properties) + return inst + + def hydrate_relationship(self, id_, n0_id, n1_id, type_, + properties=None, element_id=None, + n0_element_id=None, n1_element_id=None): + # backwards compatibility with Neo4j < 5.0 + if element_id is None: + element_id = str(id_) + if n0_element_id is None: + n0_element_id = str(n0_id) + if n1_element_id is None: + n1_element_id = str(n1_id) + + inst = self.hydrate_unbound_relationship(id_, type_, properties, + element_id) + inst._start_node = self.hydrate_node(n0_id, + element_id=n0_element_id) + inst._end_node = self.hydrate_node(n1_id, element_id=n1_element_id) + return inst + + def hydrate_unbound_relationship(self, id_, type_, properties=None, + element_id=None): + assert isinstance(self.graph, Graph) + # backwards compatibility with Neo4j < 5.0 + if element_id is None: + element_id = str(id_) + + try: + inst = self.graph._relationships[element_id] + except KeyError: + r = self.graph.relationship_type(type_) + inst = self.graph._relationships[element_id] = r( + self.graph, element_id, id_, properties + ) + return inst + + def hydrate_path(self, nodes, relationships, sequence): + assert isinstance(self.graph, Graph) + assert len(nodes) >= 1 + assert len(sequence) % 2 == 0 + last_node = nodes[0] + entities = [last_node] + for i, rel_index in enumerate(sequence[::2]): + assert rel_index != 0 + next_node = nodes[sequence[2 * i + 1]] + if rel_index > 0: + r = relationships[rel_index - 1] + r._start_node = last_node + r._end_node = next_node + entities.append(r) + else: + r = relationships[-rel_index - 1] + r._start_node = next_node + r._end_node = last_node + entities.append(r) + last_node = next_node + return Path(*entities) + + def _hydrate(self, value): + pass + + def _dehydrate(self, value): + raise NotImplementedError( + "GraphHydrationHandler cannot be used for dehydration." 
+ ) + + +class HydrationHandler(HydrationHandlerABC): + def __init__(self): + super().__init__() + self._created_scope = False + self.struct_hydration_functions = { + **self.struct_hydration_functions, + b"X": spatial.hydrate_point, + b"Y": spatial.hydrate_point, + b"D": temporal.hydrate_date, + b"T": temporal.hydrate_time, # time zone offset + b"t": temporal.hydrate_time, # no time zone + b"F": temporal.hydrate_datetime, # time zone offset + b"f": temporal.hydrate_datetime, # time zone name + b"d": temporal.hydrate_datetime, # no time zone + b"E": temporal.hydrate_duration, + } + self.dehydration_functions = { + **self.dehydration_functions, + Point: spatial.dehydrate_point, + CartesianPoint: spatial.dehydrate_point, + WGS84Point: spatial.dehydrate_point, + Date: temporal.dehydrate_date, + date: temporal.dehydrate_date, + Time: temporal.dehydrate_time, + time: temporal.dehydrate_time, + DateTime: temporal.dehydrate_datetime, + datetime: temporal.dehydrate_datetime, + Duration: temporal.dehydrate_duration, + timedelta: temporal.dehydrate_timedelta, + } + + def patch_utc(self): + from ..v2 import temporal as temporal_v2 + + assert not self._created_scope + + del self.struct_hydration_functions[b"F"] + del self.struct_hydration_functions[b"f"] + self.struct_hydration_functions.update({ + b"I": temporal_v2.hydrate_datetime, + b"i": temporal_v2.hydrate_datetime, + }) + + self.dehydration_functions.update({ + DateTime: temporal_v2.dehydrate_datetime, + datetime: temporal_v2.dehydrate_datetime, + }) + + def new_hydration_scope(self): + self._created_scope = True + return HydrationScope(self, _GraphHydrator()) diff --git a/neo4j/_codec/hydration/v1/spatial.py b/neo4j/_codec/hydration/v1/spatial.py new file mode 100644 index 000000000..6e2f7a6f5 --- /dev/null +++ b/neo4j/_codec/hydration/v1/spatial.py @@ -0,0 +1,63 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ...._spatial import ( + Point, + srid_table, +) +from ...packstream import Structure + + +def hydrate_point(srid, *coordinates): + """ Create a new instance of a Point subclass from a raw + set of fields. The subclass chosen is determined by the + given SRID; a ValueError will be raised if no such + subclass can be found. + """ + try: + point_class, dim = srid_table[srid] + except KeyError: + point = Point(coordinates) + point.srid = srid + return point + else: + if len(coordinates) != dim: + raise ValueError("SRID %d requires %d coordinates (%d provided)" % (srid, dim, len(coordinates))) + return point_class(coordinates) + + +def dehydrate_point(value): + """ Dehydrator for Point data. 
+ + :param value: + :type value: Point + :return: + """ + dim = len(value) + if dim == 2: + return Structure(b"X", value.srid, *value) + elif dim == 3: + return Structure(b"Y", value.srid, *value) + else: + raise ValueError("Cannot dehydrate Point with %d dimensions" % dim) + + +__all__ = [ + "hydrate_point", + "dehydrate_point", +] diff --git a/neo4j/time/hydration.py b/neo4j/_codec/hydration/v1/temporal.py similarity index 98% rename from neo4j/time/hydration.py rename to neo4j/_codec/hydration/v1/temporal.py index 056cafda0..a9c511a5a 100644 --- a/neo4j/time/hydration.py +++ b/neo4j/_codec/hydration/v1/temporal.py @@ -22,13 +22,13 @@ timedelta, ) -from neo4j.packstream import Structure -from neo4j.time import ( +from ....time import ( Date, DateTime, Duration, Time, ) +from ...packstream import Structure def get_date_unix_epoch(): diff --git a/neo4j/_codec/hydration/v2/__init__.py b/neo4j/_codec/hydration/v2/__init__.py new file mode 100644 index 000000000..c3cd9e2e8 --- /dev/null +++ b/neo4j/_codec/hydration/v2/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .hydration_handler import HydrationHandler + + +__all__ = [ + "HydrationHandler", +] diff --git a/neo4j/_codec/hydration/v2/hydration_handler.py b/neo4j/_codec/hydration/v2/hydration_handler.py new file mode 100644 index 000000000..092201a07 --- /dev/null +++ b/neo4j/_codec/hydration/v2/hydration_handler.py @@ -0,0 +1,57 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ..v1.hydration_handler import * +from ..v1.hydration_handler import _GraphHydrator +from . 
import temporal + + +class HydrationHandler(HydrationHandlerABC): + def __init__(self): + super().__init__() + self._created_scope = False + self.struct_hydration_functions = { + **self.struct_hydration_functions, + b"X": spatial.hydrate_point, + b"Y": spatial.hydrate_point, + b"D": temporal.hydrate_date, + b"T": temporal.hydrate_time, # time zone offset + b"t": temporal.hydrate_time, # no time zone + b"I": temporal.hydrate_datetime, # time zone offset + b"i": temporal.hydrate_datetime, # time zone name + b"d": temporal.hydrate_datetime, # no time zone + b"E": temporal.hydrate_duration, + } + self.dehydration_functions = { + **self.dehydration_functions, + Point: spatial.dehydrate_point, + CartesianPoint: spatial.dehydrate_point, + WGS84Point: spatial.dehydrate_point, + Date: temporal.dehydrate_date, + date: temporal.dehydrate_date, + Time: temporal.dehydrate_time, + time: temporal.dehydrate_time, + DateTime: temporal.dehydrate_datetime, + datetime: temporal.dehydrate_datetime, + Duration: temporal.dehydrate_duration, + timedelta: temporal.dehydrate_timedelta, + } + + def new_hydration_scope(self): + self._created_scope = True + return HydrationScope(self, _GraphHydrator()) diff --git a/neo4j/_codec/hydration/v2/temporal.py b/neo4j/_codec/hydration/v2/temporal.py new file mode 100644 index 000000000..4741ce9aa --- /dev/null +++ b/neo4j/_codec/hydration/v2/temporal.py @@ -0,0 +1,92 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +from ..v1.temporal import * + + +def hydrate_datetime(seconds, nanoseconds, tz=None): + """ Hydrator for `DateTime` and `LocalDateTime` values. + + :param seconds: + :param nanoseconds: + :param tz: + :return: datetime + """ + import pytz + + minutes, seconds = map(int, divmod(seconds, 60)) + hours, minutes = map(int, divmod(minutes, 60)) + days, hours = map(int, divmod(hours, 24)) + t = DateTime.combine( + Date.from_ordinal(get_date_unix_epoch_ordinal() + days), + Time(hours, minutes, seconds, nanoseconds) + ) + if tz is None: + return t + if isinstance(tz, int): + tz_offset_minutes, tz_offset_seconds = divmod(tz, 60) + zone = pytz.FixedOffset(tz_offset_minutes) + else: + zone = pytz.timezone(tz) + t = t.replace(tzinfo=pytz.UTC) + return t.as_timezone(zone) + + +def dehydrate_datetime(value): + """ Dehydrator for `datetime` values. + + :param value: + :type value: datetime + :return: + """ + + import pytz + + def seconds_and_nanoseconds(dt): + if isinstance(dt, datetime): + dt = DateTime.from_native(dt) + dt = dt.astimezone(pytz.UTC) + utc_epoch = DateTime(1970, 1, 1, tzinfo=pytz.UTC) + dt_clock_time = dt.to_clock_time() + utc_epoch_clock_time = utc_epoch.to_clock_time() + t = dt_clock_time - utc_epoch_clock_time + return t.seconds, t.nanoseconds + + tz = value.tzinfo + if tz is None: + # without time zone + value = pytz.UTC.localize(value) + seconds, nanoseconds = seconds_and_nanoseconds(value) + return Structure(b"d", seconds, nanoseconds) + elif hasattr(tz, "zone") and tz.zone and isinstance(tz.zone, str): + # with named pytz time zone + seconds, nanoseconds = seconds_and_nanoseconds(value) + return Structure(b"i", seconds, nanoseconds, tz.zone) + elif hasattr(tz, "key") and tz.key and isinstance(tz.key, str): + # with named zoneinfo (Python 3.9+) time zone + seconds, nanoseconds = seconds_and_nanoseconds(value) + return Structure(b"i", seconds, 
nanoseconds, tz.key) + else: + # with time offset + seconds, nanoseconds = seconds_and_nanoseconds(value) + offset = tz.utcoffset(value) + if offset.microseconds: + raise ValueError("Bolt protocol does not support sub-second " + "UTC offsets.") + offset_seconds = offset.days * 86400 + offset.seconds + return Structure(b"I", seconds, nanoseconds, offset_seconds) diff --git a/neo4j/_codec/packstream/__init__.py b/neo4j/_codec/packstream/__init__.py new file mode 100644 index 000000000..ba0188b0f --- /dev/null +++ b/neo4j/_codec/packstream/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ._common import Structure + + +__all__ = [ + "Structure", +] diff --git a/neo4j/_codec/packstream/_common.py b/neo4j/_codec/packstream/_common.py new file mode 100644 index 000000000..84403de7c --- /dev/null +++ b/neo4j/_codec/packstream/_common.py @@ -0,0 +1,44 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class Structure: + + def __init__(self, tag, *fields): + self.tag = tag + self.fields = list(fields) + + def __repr__(self): + return "Structure[0x%02X](%s)" % (ord(self.tag), ", ".join(map(repr, self.fields))) + + def __eq__(self, other): + try: + return self.tag == other.tag and self.fields == other.fields + except AttributeError: + return False + + def __ne__(self, other): + return not self.__eq__(other) + + def __len__(self): + return len(self.fields) + + def __getitem__(self, key): + return self.fields[key] + + def __setitem__(self, key, value): + self.fields[key] = value diff --git a/neo4j/packstream.py b/neo4j/_codec/packstream/v1/__init__.py similarity index 78% rename from neo4j/packstream.py rename to neo4j/_codec/packstream/v1/__init__.py index 92bf6b96b..0c74b6879 100644 --- a/neo4j/packstream.py +++ b/neo4j/_codec/packstream/v1/__init__.py @@ -17,11 +17,14 @@ from codecs import decode +from contextlib import contextmanager from struct import ( pack as struct_pack, unpack as struct_unpack, ) +from .._common import Structure + PACKED_UINT_8 = [struct_pack(">B", value) for value in range(0x100)] PACKED_UINT_16 = [struct_pack(">H", value) for value in range(0x10000)] @@ -38,34 +41,6 @@ INT64_MAX = 2 ** 63 -class Structure: - - def __init__(self, tag, *fields): - self.tag = tag - self.fields = list(fields) - - def __repr__(self): - return "Structure[0x%02X](%s)" % (ord(self.tag), ", ".join(map(repr, self.fields))) - - def __eq__(self, other): - try: - return self.tag == other.tag and self.fields == other.fields - except AttributeError: - 
return False - - def __ne__(self, other): - return not self.__eq__(other) - - def __len__(self): - return len(self.fields) - - def __getitem__(self, key): - return self.fields[key] - - def __setitem__(self, key, value): - self.fields[key] = value - - class Packer: def __init__(self, stream): @@ -75,10 +50,7 @@ def __init__(self, stream): def pack_raw(self, data): self._write(data) - def pack(self, value): - return self._pack(value) - - def _pack(self, value): + def pack(self, value, dehydration_hooks=None): write = self._write # None @@ -130,20 +102,26 @@ def _pack(self, value): elif isinstance(value, list): self.pack_list_header(len(value)) for item in value: - self._pack(item) + self.pack(item, dehydration_hooks=dehydration_hooks) # Map elif isinstance(value, dict): self.pack_map_header(len(value)) for key, item in value.items(): - self._pack(key) - self._pack(item) + if not isinstance(key, str): + raise TypeError( + "Map keys must be strings, not {}".format(type(key)) + ) + self.pack(key, dehydration_hooks=dehydration_hooks) + self.pack(item, dehydration_hooks=dehydration_hooks) # Structure elif isinstance(value, Structure): self.pack_struct(value.tag, value.fields) # Other + elif dehydration_hooks and type(value) in dehydration_hooks: + self.pack(dehydration_hooks[type(value)](value)) else: raise ValueError("Values of type %s are not supported" % type(value)) @@ -209,7 +187,7 @@ def pack_map_header(self, size): else: raise OverflowError("Map header size out of range") - def pack_struct(self, signature, fields): + def pack_struct(self, signature, fields, dehydration_hooks=None): if len(signature) != 1 or not isinstance(signature, bytes): raise ValueError("Structure signature must be a single byte value") write = self._write @@ -220,7 +198,35 @@ def pack_struct(self, signature, fields): raise OverflowError("Structure size out of range") write(signature) for field in fields: - self._pack(field) + self.pack(field, dehydration_hooks=dehydration_hooks) + + 
@staticmethod + def new_packable_buffer(): + return PackableBuffer() + + +class PackableBuffer: + def __init__(self): + self.data = bytearray() + # export write method for packer; "inline" for performance + self.write = self.data.extend + self.clear = self.data.clear + self._tmp_buffering = 0 + + @contextmanager + def tmp_buffer(self): + self._tmp_buffering += 1 + old_len = len(self.data) + try: + yield + except Exception: + del self.data[old_len:] + raise + finally: + self._tmp_buffering -= 1 + + def is_tmp_buffering(self): + return bool(self._tmp_buffering) class Unpacker: @@ -237,10 +243,13 @@ def read(self, n=1): def read_u8(self): return self.unpackable.read_u8() - def unpack(self): - return self._unpack() + def unpack(self, hydration_hooks=None): + value = self._unpack(hydration_hooks=hydration_hooks) + if hydration_hooks and type(value) in hydration_hooks: + return hydration_hooks[type(value)](value) + return value - def _unpack(self): + def _unpack(self, hydration_hooks=None): marker = self.read_u8() if marker == -1: @@ -305,82 +314,86 @@ def _unpack(self): # List elif 0x90 <= marker <= 0x9F or 0xD4 <= marker <= 0xD6: - return list(self._unpack_list_items(marker)) + return list(self._unpack_list_items( + marker, hydration_hooks=hydration_hooks) + ) # Map elif 0xA0 <= marker <= 0xAF or 0xD8 <= marker <= 0xDA: - return self._unpack_map(marker) + return self._unpack_map( + marker, hydration_hooks=hydration_hooks + ) # Structure elif 0xB0 <= marker <= 0xBF: size, tag = self._unpack_structure_header(marker) value = Structure(tag, *([None] * size)) for i in range(len(value)): - value[i] = self._unpack() + value[i] = self.unpack(hydration_hooks=hydration_hooks) return value else: raise ValueError("Unknown PackStream marker %02X" % marker) - def _unpack_list_items(self, marker): + def _unpack_list_items(self, marker, hydration_hooks=None): marker_high = marker & 0xF0 if marker_high == 0x90: size = marker & 0x0F if size == 0: return elif size == 1: - yield 
self._unpack() + yield self.unpack(hydration_hooks=hydration_hooks) else: for _ in range(size): - yield self._unpack() + yield self.unpack(hydration_hooks=hydration_hooks) elif marker == 0xD4: # LIST_8: size, = struct_unpack(">B", self.read(1)) for _ in range(size): - yield self._unpack() + yield self.unpack(hydration_hooks=hydration_hooks) elif marker == 0xD5: # LIST_16: size, = struct_unpack(">H", self.read(2)) for _ in range(size): - yield self._unpack() + yield self.unpack(hydration_hooks=hydration_hooks) elif marker == 0xD6: # LIST_32: size, = struct_unpack(">I", self.read(4)) for _ in range(size): - yield self._unpack() + yield self.unpack(hydration_hooks=hydration_hooks) else: return - def unpack_map(self): + def unpack_map(self, hydration_hooks=None): marker = self.read_u8() - return self._unpack_map(marker) + return self._unpack_map(marker, hydration_hooks=hydration_hooks) - def _unpack_map(self, marker): + def _unpack_map(self, marker, hydration_hooks=None): marker_high = marker & 0xF0 if marker_high == 0xA0: size = marker & 0x0F value = {} for _ in range(size): - key = self._unpack() - value[key] = self._unpack() + key = self.unpack(hydration_hooks=hydration_hooks) + value[key] = self.unpack(hydration_hooks=hydration_hooks) return value elif marker == 0xD8: # MAP_8: size, = struct_unpack(">B", self.read(1)) value = {} for _ in range(size): - key = self._unpack() - value[key] = self._unpack() + key = self.unpack(hydration_hooks=hydration_hooks) + value[key] = self.unpack(hydration_hooks=hydration_hooks) return value elif marker == 0xD9: # MAP_16: size, = struct_unpack(">H", self.read(2)) value = {} for _ in range(size): - key = self._unpack() - value[key] = self._unpack() + key = self.unpack(hydration_hooks=hydration_hooks) + value[key] = self.unpack(hydration_hooks=hydration_hooks) return value elif marker == 0xDA: # MAP_32: size, = struct_unpack(">I", self.read(4)) value = {} for _ in range(size): - key = self._unpack() - value[key] = self._unpack() + 
key = self.unpack(hydration_hooks=hydration_hooks) + value[key] = self.unpack(hydration_hooks=hydration_hooks) return value else: return None @@ -400,6 +413,10 @@ def _unpack_structure_header(self, marker): else: raise ValueError("Expected structure, found marker %02X" % marker) + @staticmethod + def new_unpackable_buffer(): + return UnpackableBuffer() + class UnpackableBuffer: diff --git a/neo4j/data.py b/neo4j/_data.py similarity index 66% rename from neo4j/data.py rename to neo4j/_data.py index 7ce3e712e..2daf7be03 100644 --- a/neo4j/data.py +++ b/neo4j/_data.py @@ -25,52 +25,15 @@ Sequence, Set, ) -from datetime import ( - date, - datetime, - time, - timedelta, -) from functools import reduce from operator import xor as xor_operator from .conf import iter_items from .graph import ( - Graph, Node, Path, Relationship, ) -from .packstream import ( - INT64_MAX, - INT64_MIN, - Structure, -) -from .spatial import ( - dehydrate_point, - hydrate_point, - Point, -) -from .time import ( - Date, - DateTime, - Duration, - Time, -) -from .time.hydration import ( - dehydrate_date, - dehydrate_datetime, - dehydrate_duration, - dehydrate_time, - dehydrate_timedelta, - hydrate_date, - hydrate_datetime, - hydrate_duration, - hydrate_time, -) - - -map_type = type(map(str, range(0))) class Record(tuple, Mapping): @@ -348,126 +311,3 @@ def _transform(self, x, prefix): ) else: return {prefix: x} - - -class DataHydrator: - # TODO: extend DataTransformer - - def __init__(self): - super(DataHydrator, self).__init__() - self.graph = Graph() - self.graph_hydrator = Graph.Hydrator(self.graph) - self.hydration_functions = { - b"N": self.graph_hydrator.hydrate_node, - b"R": self.graph_hydrator.hydrate_relationship, - b"r": self.graph_hydrator.hydrate_unbound_relationship, - b"P": self.graph_hydrator.hydrate_path, - b"X": hydrate_point, - b"Y": hydrate_point, - b"D": hydrate_date, - b"T": hydrate_time, # time zone offset - b"t": hydrate_time, # no time zone - b"F": hydrate_datetime, # time 
zone offset - b"f": hydrate_datetime, # time zone name - b"d": hydrate_datetime, # no time zone - b"E": hydrate_duration, - } - - def hydrate(self, values): - """ Convert PackStream values into native values. - """ - - def hydrate_(obj): - if isinstance(obj, Structure): - try: - f = self.hydration_functions[obj.tag] - except KeyError: - # If we don't recognise the structure - # type, just return it as-is - return obj - else: - return f(*map(hydrate_, obj.fields)) - elif isinstance(obj, list): - return list(map(hydrate_, obj)) - elif isinstance(obj, dict): - return {key: hydrate_(value) for key, value in obj.items()} - else: - return obj - - return tuple(map(hydrate_, values)) - - def hydrate_records(self, keys, record_values): - for values in record_values: - yield Record(zip(keys, self.hydrate(values))) - - -class DataDehydrator: - # TODO: extend DataTransformer - - @classmethod - def fix_parameters(cls, parameters): - if not parameters: - return {} - dehydrator = cls() - try: - dehydrated, = dehydrator.dehydrate([parameters]) - except TypeError as error: - value = error.args[0] - raise TypeError("Parameters of type {} are not supported".format(type(value).__name__)) - else: - return dehydrated - - def __init__(self): - self.dehydration_functions = {} - self.dehydration_functions.update({ - Point: dehydrate_point, - Date: dehydrate_date, - date: dehydrate_date, - Time: dehydrate_time, - time: dehydrate_time, - DateTime: dehydrate_datetime, - datetime: dehydrate_datetime, - Duration: dehydrate_duration, - timedelta: dehydrate_timedelta, - }) - # Allow dehydration from any direct Point subclass - self.dehydration_functions.update({cls: dehydrate_point for cls in Point.__subclasses__()}) - - def dehydrate(self, values): - """ Convert native values into PackStream values. 
- """ - - def dehydrate_(obj): - try: - f = self.dehydration_functions[type(obj)] - except KeyError: - pass - else: - return f(obj) - if obj is None: - return None - elif isinstance(obj, bool): - return obj - elif isinstance(obj, int): - if INT64_MIN <= obj <= INT64_MAX: - return obj - raise ValueError("Integer out of bounds (64-bit signed " - "integer values only)") - elif isinstance(obj, float): - return obj - elif isinstance(obj, str): - return obj - elif isinstance(obj, (bytes, bytearray)): - # order is important here - bytes must be checked after str - return obj - elif isinstance(obj, (list, map_type)): - return list(map(dehydrate_, obj)) - elif isinstance(obj, dict): - if any(not isinstance(key, str) for key in obj.keys()): - raise TypeError("Non-string dictionary keys are " - "not supported") - return {key: dehydrate_(value) for key, value in obj.items()} - else: - raise TypeError(obj) - - return tuple(map(dehydrate_, values)) diff --git a/neo4j/meta.py b/neo4j/_meta.py similarity index 94% rename from neo4j/meta.py rename to neo4j/_meta.py index 0e487c2e2..8aea8a95a 100644 --- a/neo4j/meta.py +++ b/neo4j/_meta.py @@ -39,8 +39,8 @@ def get_user_agent(): return template.format(*fields) -def deprecation_warn(message, stack_level=2): - warn(message, category=DeprecationWarning, stacklevel=stack_level) +def deprecation_warn(message, stack_level=1): + warn(message, category=DeprecationWarning, stacklevel=stack_level + 1) def deprecated(message): @@ -57,14 +57,14 @@ def decorator(f): if asyncio.iscoroutinefunction(f): @wraps(f) async def inner(*args, **kwargs): - deprecation_warn(message, stack_level=3) + deprecation_warn(message, stack_level=2) return await f(*args, **kwargs) return inner else: @wraps(f) def inner(*args, **kwargs): - deprecation_warn(message, stack_level=3) + deprecation_warn(message, stack_level=2) return f(*args, **kwargs) return inner diff --git a/neo4j/_spatial/__init__.py b/neo4j/_spatial/__init__.py new file mode 100644 index 
000000000..3c84a0b0f
--- /dev/null
+++ b/neo4j/_spatial/__init__.py
@@ -0,0 +1,106 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [https://neo4j.com]
+#
+# This file is part of Neo4j.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This module defines spatial data types.
+"""
+
+
+from threading import Lock
+
+
+# SRID to subclass mappings
+srid_table = {}
+srid_table_lock = Lock()
+
+
+class Point(tuple):
+    """Base-class for spatial data.
+
+    A point within a geometric space. This type is generally used via its
+    subclasses and should not be instantiated directly unless there is no
+    subclass defined for the required SRID.
+
+    :param iterable:
+        An iterable of coordinates.
+        All items will be converted to :class:`float`.
+    """
+
+    #: The SRID (spatial reference identifier) of the spatial data.
+    #: A number that identifies the coordinate system the spatial type is to be
+    #: interpreted in.
+    #:
+    #: :type: int
+    srid = None
+
+    def __new__(cls, iterable):
+        return tuple.__new__(cls, map(float, iterable))
+
+    def __repr__(self):
+        return "POINT(%s)" % " ".join(map(str, self))
+
+    def __eq__(self, other):
+        try:
+            return type(self) is type(other) and tuple(self) == tuple(other)
+        except (AttributeError, TypeError):
+            return False
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+    def __hash__(self):
+        return hash(type(self)) ^ hash(tuple(self))
+
+
+def point_type(name, fields, srid_map):
+    """ Dynamically create a Point subclass.
+ """ + + def srid(self): + try: + return srid_map[len(self)] + except KeyError: + return None + + attributes = {"srid": property(srid)} + + for index, subclass_field in enumerate(fields): + + def accessor(self, i=index, f=subclass_field): + try: + return self[i] + except IndexError: + raise AttributeError(f) + + for field_alias in {subclass_field, "xyz"[index]}: + attributes[field_alias] = property(accessor) + + cls = type(name, (Point,), attributes) + + with srid_table_lock: + for dim, srid in srid_map.items(): + srid_table[srid] = (cls, dim) + + return cls + + +# Point subclass definitions +CartesianPoint = point_type("CartesianPoint", ["x", "y", "z"], + {2: 7203, 3: 9157}) +WGS84Point = point_type("WGS84Point", ["longitude", "latitude", "height"], + {2: 4326, 3: 4979}) diff --git a/neo4j/_sync/driver.py b/neo4j/_sync/driver.py index 975c7048f..aa03264f1 100644 --- a/neo4j/_sync/driver.py +++ b/neo4j/_sync/driver.py @@ -21,6 +21,11 @@ TrustAll, TrustStore, ) +from .._meta import ( + deprecation_warn, + experimental, + unclosed_resource_warn, +) from ..addressing import Address from ..api import ( READ_ACCESS, @@ -33,11 +38,6 @@ SessionConfig, WorkspaceConfig, ) -from ..meta import ( - deprecation_warn, - experimental, - unclosed_resource_warn, -) class GraphDatabase: diff --git a/neo4j/_sync/io/_bolt.py b/neo4j/_sync/io/_bolt.py index 5093de404..2b6fb8e65 100644 --- a/neo4j/_sync/io/_bolt.py +++ b/neo4j/_sync/io/_bolt.py @@ -24,11 +24,14 @@ from ..._async_compat.network import BoltSocket from ..._async_compat.util import Util +from ..._codec.hydration import v1 as hydration_v1 +from ..._codec.packstream import v1 as packstream_v1 from ..._exceptions import ( BoltError, BoltHandshakeError, SocketDeadlineExceeded, ) +from ..._meta import get_user_agent from ...addressing import Address from ...api import ( ServerInfo, @@ -42,11 +45,6 @@ ServiceUnavailable, SessionExpired, ) -from ...meta import get_user_agent -from ...packstream import ( - Packer, - Unpacker, -) 
from ._common import ( CommitResponse, Inbox, @@ -68,6 +66,13 @@ class Bolt: the handshake was carried out. """ + # TODO: let packer/unpacker know of hydration (give them hooks?) + # TODO: make sure query parameter dehydration gets clear error message. + + PACKER_CLS = packstream_v1.Packer + UNPACKER_CLS = packstream_v1.Unpacker + HYDRATION_HANDLER_CLS = hydration_v1.HydrationHandler + MAGIC_PREAMBLE = b"\x60\x60\xB0\x17" PROTOCOL_VERSION = None @@ -107,10 +112,16 @@ def __init__(self, unresolved_address, sock, max_connection_lifetime, *, # configuration hint that exists. Therefore, all hints can be stored at # connection level. This might change in the future. self.configuration_hints = {} - self.outbox = Outbox() - self.inbox = Inbox(self.socket, on_error=self._set_defunct_read) - self.packer = Packer(self.outbox) - self.unpacker = Unpacker(self.inbox) + self.patch = {} + self.outbox = Outbox( + self.socket, on_error=self._set_defunct_write, + packer_cls=self.PACKER_CLS + ) + self.inbox = Inbox( + self.socket, on_error=self._set_defunct_read, + unpacker_cls=self.UNPACKER_CLS + ) + self.hydration_handler = self.HYDRATION_HANDLER_CLS() self.responses = deque() self._max_connection_lifetime = max_connection_lifetime self._creation_timestamp = perf_counter() @@ -376,14 +387,17 @@ def der_encoded_server_certificate(self): pass @abc.abstractmethod - def hello(self): + def hello(self, dehydration_hooks=None, hydration_hooks=None): """ Appends a HELLO message to the outgoing queue, sends it and consumes all remaining messages. """ pass @abc.abstractmethod - def route(self, database=None, imp_user=None, bookmarks=None): + def route( + self, database=None, imp_user=None, bookmarks=None, + dehydration_hooks=None, hydration_hooks=None + ): """ Fetch a routing table from the server for the given `database`. 
For Bolt 4.3 and above, this appends a ROUTE message; for earlier versions, a procedure call is made via @@ -396,13 +410,22 @@ def route(self, database=None, imp_user=None, bookmarks=None): Requires Bolt 4.4+. :param bookmarks: iterable of bookmark values after which this transaction should begin - :return: dictionary of raw routing data + :param dehydration_hooks: + Hooks to dehydrate types (dict from type (class) to dehydration + function). Dehydration functions receive the value and returns an + object of type understood by packstream. + :param hydration_hooks: + Hooks to hydrate types (mapping from type (class) to + dehydration function). Dehydration functions receive the value of + type understood by packstream and are free to return anything. """ pass @abc.abstractmethod def run(self, query, parameters=None, mode=None, bookmarks=None, - metadata=None, timeout=None, db=None, imp_user=None, **handlers): + metadata=None, timeout=None, db=None, imp_user=None, + dehydration_hooks=None, hydration_hooks=None, + **handlers): """ Appends a RUN message to the output queue. :param query: Cypher query string @@ -415,36 +438,60 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, Requires Bolt 4.0+. :param imp_user: the user to impersonate Requires Bolt 4.4+. + :param dehydration_hooks: + Hooks to dehydrate types (dict from type (class) to dehydration + function). Dehydration functions receive the value and returns an + object of type understood by packstream. + :param hydration_hooks: + Hooks to hydrate types (mapping from type (class) to + dehydration function). Dehydration functions receive the value of + type understood by packstream and are free to return anything. 
:param handlers: handler functions passed into the returned Response object - :return: Response object """ pass @abc.abstractmethod - def discard(self, n=-1, qid=-1, **handlers): + def discard(self, n=-1, qid=-1, dehydration_hooks=None, + hydration_hooks=None, **handlers): """ Appends a DISCARD message to the output queue. :param n: number of records to discard, default = -1 (ALL) :param qid: query ID to discard for, default = -1 (last query) + :param dehydration_hooks: + Hooks to dehydrate types (dict from type (class) to dehydration + function). Dehydration functions receive the value and returns an + object of type understood by packstream. + :param hydration_hooks: + Hooks to hydrate types (mapping from type (class) to + dehydration function). Dehydration functions receive the value of + type understood by packstream and are free to return anything. :param handlers: handler functions passed into the returned Response object - :return: Response object """ pass @abc.abstractmethod - def pull(self, n=-1, qid=-1, **handlers): + def pull(self, n=-1, qid=-1, dehydration_hooks=None, hydration_hooks=None, + **handlers): """ Appends a PULL message to the output queue. :param n: number of records to pull, default = -1 (ALL) :param qid: query ID to pull for, default = -1 (last query) + :param dehydration_hooks: + Hooks to dehydrate types (dict from type (class) to dehydration + function). Dehydration functions receive the value and returns an + object of type understood by packstream. + :param hydration_hooks: + Hooks to hydrate types (mapping from type (class) to + dehydration function). Dehydration functions receive the value of + type understood by packstream and are free to return anything. 
:param handlers: handler functions passed into the returned Response object - :return: Response object """ pass @abc.abstractmethod def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, - db=None, imp_user=None, **handlers): + db=None, imp_user=None, dehydration_hooks=None, + hydration_hooks=None, **handlers): """ Appends a BEGIN message to the output queue. :param mode: access mode for routing - "READ" or "WRITE" (default) @@ -455,53 +502,99 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, Requires Bolt 4.0+. :param imp_user: the user to impersonate Requires Bolt 4.4+ + :param dehydration_hooks: + Hooks to dehydrate types (dict from type (class) to dehydration + function). Dehydration functions receive the value and returns an + object of type understood by packstream. + :param hydration_hooks: + Hooks to hydrate types (mapping from type (class) to + dehydration function). Dehydration functions receive the value of + type understood by packstream and are free to return anything. :param handlers: handler functions passed into the returned Response object :return: Response object """ pass @abc.abstractmethod - def commit(self, **handlers): - """ Appends a COMMIT message to the output queue.""" + def commit(self, dehydration_hooks=None, hydration_hooks=None, **handlers): + """ Appends a COMMIT message to the output queue. + + :param dehydration_hooks: + Hooks to dehydrate types (dict from type (class) to dehydration + function). Dehydration functions receive the value and returns an + object of type understood by packstream. + :param hydration_hooks: + Hooks to hydrate types (mapping from type (class) to + dehydration function). Dehydration functions receive the value of + type understood by packstream and are free to return anything. 
+ """ pass @abc.abstractmethod - def rollback(self, **handlers): - """ Appends a ROLLBACK message to the output queue.""" + def rollback(self, dehydration_hooks=None, hydration_hooks=None, **handlers): + """ Appends a ROLLBACK message to the output queue. + + :param dehydration_hooks: + Hooks to dehydrate types (dict from type (class) to dehydration + function). Dehydration functions receive the value and returns an + object of type understood by packstream. + :param hydration_hooks: + Hooks to hydrate types (mapping from type (class) to + dehydration function). Dehydration functions receive the value of + type understood by packstream and are free to return anything.""" pass @abc.abstractmethod - def reset(self): + def reset(self, dehydration_hooks=None, hydration_hooks=None): """ Appends a RESET message to the outgoing queue, sends it and consumes all remaining messages. + + :param dehydration_hooks: + Hooks to dehydrate types (dict from type (class) to dehydration + function). Dehydration functions receive the value and returns an + object of type understood by packstream. + :param hydration_hooks: + Hooks to hydrate types (mapping from type (class) to + dehydration function). Dehydration functions receive the value of + type understood by packstream and are free to return anything. """ pass @abc.abstractmethod - def goodbye(self): - """Append a GOODBYE message to the outgoing queue.""" + def goodbye(self, dehydration_hooks=None, hydration_hooks=None): + """Append a GOODBYE message to the outgoing queue. + + :param dehydration_hooks: + Hooks to dehydrate types (dict from type (class) to dehydration + function). Dehydration functions receive the value and returns an + object of type understood by packstream. + :param hydration_hooks: + Hooks to hydrate types (mapping from type (class) to + dehydration function). Dehydration functions receive the value of + type understood by packstream and are free to return anything. 
+ """ pass - def _append(self, signature, fields=(), response=None): + def new_hydration_scope(self): + return self.hydration_handler.new_hydration_scope() + + def _append(self, signature, fields=(), response=None, + dehydration_hooks=None): """ Appends a message to the outgoing queue. :param signature: the signature of the message :param fields: the fields of the message as a tuple :param response: a response object to handle callbacks + :param dehydration_hooks: + Hooks to dehydrate types (dict from type (class) to dehydration + function). Dehydration functions receive the value and returns an + object of type understood by packstream. """ - with self.outbox.tmp_buffer(): - self.packer.pack_struct(signature, fields) - self.outbox.wrap_message() + self.outbox.append_message(signature, fields, dehydration_hooks) self.responses.append(response) def _send_all(self): - data = self.outbox.view() - if data: - try: - self.socket.sendall(data) - except OSError as error: - self._set_defunct_write(error) - self.outbox.clear() + if self.outbox.flush(): self.idle_since = perf_counter() def send_all(self): @@ -523,8 +616,7 @@ def send_all(self): self._send_all() @abc.abstractmethod - def _process_message(self, details, summary_signature, - summary_metadata): + def _process_message(self, tag, fields): """ Receive at most one message from the server, if available. 
:return: 2-tuple of number of detail messages and number of summary @@ -549,11 +641,10 @@ def fetch_message(self): return 0, 0 # Receive exactly one message - details, summary_signature, summary_metadata = \ - Util.next(self.inbox) - res = self._process_message( - details, summary_signature, summary_metadata + tag, fields = self.inbox.pop( + hydration_hooks=self.responses[0].hydration_hooks ) + res = self._process_message(tag, fields) self.idle_since = perf_counter() return res diff --git a/neo4j/_sync/io/_bolt3.py b/neo4j/_sync/io/_bolt3.py index ac6e61fb4..1a169f71c 100644 --- a/neo4j/_sync/io/_bolt3.py +++ b/neo4j/_sync/io/_bolt3.py @@ -142,7 +142,7 @@ def get_base_headers(self): "user_agent": self.user_agent, } - def hello(self): + def hello(self, dehydration_hooks=None, hydration_hooks=None): headers = self.get_base_headers() headers.update(self.auth_dict) logged_headers = dict(headers) @@ -150,13 +150,17 @@ def hello(self): logged_headers["credentials"] = "*******" log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers) self._append(b"\x01", (headers,), - response=InitResponse(self, "hello", - on_success=self.server_info.update)) + response=InitResponse(self, "hello", hydration_hooks, + on_success=self.server_info.update), + dehydration_hooks=dehydration_hooks) self.send_all() self.fetch_all() check_supported_server_product(self.server_info.agent) - def route(self, database=None, imp_user=None, bookmarks=None): + def route( + self, database=None, imp_user=None, bookmarks=None, + dehydration_hooks=None, hydration_hooks=None + ): if database is not None: raise ConfigurationError( "Database name parameter for selecting database is not " @@ -183,16 +187,20 @@ def route(self, database=None, imp_user=None, bookmarks=None): "CALL dbms.cluster.routing.getRoutingTable($context)", # This is an internal procedure call. Only available if the Neo4j 3.5 is setup with clustering. 
{"context": self.routing_context}, mode="r", # Bolt Protocol Version(3, 0) supports mode="r" + dehydration_hooks=dehydration_hooks, + hydration_hooks=hydration_hooks, on_success=metadata.update ) - self.pull(on_success=metadata.update, on_records=records.extend) + self.pull(dehydration_hooks = None, hydration_hooks = None, + on_success=metadata.update, on_records=records.extend) self.send_all() self.fetch_all() routing_info = [dict(zip(metadata.get("fields", ()), values)) for values in records] return routing_info def run(self, query, parameters=None, mode=None, bookmarks=None, - metadata=None, timeout=None, db=None, imp_user=None, **handlers): + metadata=None, timeout=None, db=None, imp_user=None, + dehydration_hooks=None, hydration_hooks=None, **handlers): if db is not None: raise ConfigurationError( "Database name parameter for selecting database is not " @@ -231,20 +239,29 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, raise ValueError("Timeout must be a positive number or 0.") fields = (query, parameters, extra) log.debug("[#%04X] C: RUN %s", self.local_port, " ".join(map(repr, fields))) - self._append(b"\x10", fields, Response(self, "run", **handlers)) + self._append(b"\x10", fields, + Response(self, "run", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) - def discard(self, n=-1, qid=-1, **handlers): + def discard(self, n=-1, qid=-1, dehydration_hooks=None, + hydration_hooks=None, **handlers): # Just ignore n and qid, it is not supported in the Bolt 3 Protocol. log.debug("[#%04X] C: DISCARD_ALL", self.local_port) - self._append(b"\x2F", (), Response(self, "discard", **handlers)) + self._append(b"\x2F", (), + Response(self, "discard", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) - def pull(self, n=-1, qid=-1, **handlers): + def pull(self, n=-1, qid=-1, dehydration_hooks=None, hydration_hooks=None, + **handlers): # Just ignore n and qid, it is not supported in the Bolt 3 Protocol. 
log.debug("[#%04X] C: PULL_ALL", self.local_port) - self._append(b"\x3F", (), Response(self, "pull", **handlers)) + self._append(b"\x3F", (), + Response(self, "pull", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, - db=None, imp_user=None, **handlers): + db=None, imp_user=None, dehydration_hooks=None, + hydration_hooks=None, **handlers): if db is not None: raise ConfigurationError( "Database name parameter for selecting database is not " @@ -280,17 +297,25 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, if extra["tx_timeout"] < 0: raise ValueError("Timeout must be a positive number or 0.") log.debug("[#%04X] C: BEGIN %r", self.local_port, extra) - self._append(b"\x11", (extra,), Response(self, "begin", **handlers)) + self._append(b"\x11", (extra,), + Response(self, "begin", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) - def commit(self, **handlers): + def commit(self, dehydration_hooks=None, hydration_hooks=None, **handlers): log.debug("[#%04X] C: COMMIT", self.local_port) - self._append(b"\x12", (), CommitResponse(self, "commit", **handlers)) + self._append(b"\x12", (), + CommitResponse(self, "commit", hydration_hooks, + **handlers), + dehydration_hooks=dehydration_hooks) - def rollback(self, **handlers): + def rollback(self, dehydration_hooks=None, hydration_hooks=None, + **handlers): log.debug("[#%04X] C: ROLLBACK", self.local_port) - self._append(b"\x13", (), Response(self, "rollback", **handlers)) + self._append(b"\x13", (), + Response(self, "rollback", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) - def reset(self): + def reset(self, dehydration_hooks=None, hydration_hooks=None): """ Add a RESET message to the outgoing queue, send it and consume all remaining messages. 
""" @@ -299,21 +324,33 @@ def fail(metadata): raise BoltProtocolError("RESET failed %r" % metadata, address=self.unresolved_address) log.debug("[#%04X] C: RESET", self.local_port) - self._append(b"\x0F", response=Response(self, "reset", on_failure=fail)) + self._append(b"\x0F", + response=Response(self, "reset", hydration_hooks, + on_failure=fail), + dehydration_hooks=dehydration_hooks) self.send_all() self.fetch_all() - def goodbye(self): + def goodbye(self, dehydration_hooks=None, hydration_hooks=None): log.debug("[#%04X] C: GOODBYE", self.local_port) - self._append(b"\x02", ()) + self._append(b"\x02", (), dehydration_hooks=dehydration_hooks) - def _process_message(self, details, summary_signature, - summary_metadata): + def _process_message(self, tag, fields): """ Process at most one message from the server, if available. :return: 2-tuple of number of detail messages and number of summary messages fetched """ + details = [] + summary_signature = summary_metadata = None + if tag == b"\x71": # RECORD + details = fields + elif fields: + summary_signature = tag + summary_metadata = fields[0] + else: + summary_signature = tag + if details: log.debug("[#%04X] S: RECORD * %d", self.local_port, len(details)) # Do not log any data self.responses[0].on_records(details) diff --git a/neo4j/_sync/io/_bolt4.py b/neo4j/_sync/io/_bolt4.py index 2c26af8cd..609115264 100644 --- a/neo4j/_sync/io/_bolt4.py +++ b/neo4j/_sync/io/_bolt4.py @@ -34,11 +34,11 @@ NotALeader, ServiceUnavailable, ) +from ._bolt import Bolt from ._bolt3 import ( ServerStateManager, ServerStates, ) -from ._bolt import Bolt from ._common import ( check_supported_server_product, CommitResponse, @@ -95,7 +95,7 @@ def get_base_headers(self): "user_agent": self.user_agent, } - def hello(self): + def hello(self, dehydration_hooks=None, hydration_hooks=None): headers = self.get_base_headers() headers.update(self.auth_dict) logged_headers = dict(headers) @@ -103,13 +103,19 @@ def hello(self): 
logged_headers["credentials"] = "*******" log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers) self._append(b"\x01", (headers,), - response=InitResponse(self, "hello", - on_success=self.server_info.update)) + response=InitResponse( + self, "hello", hydration_hooks, + on_success=self.server_info.update + ), + dehydration_hooks=dehydration_hooks) self.send_all() self.fetch_all() check_supported_server_product(self.server_info.agent) - def route(self, database=None, imp_user=None, bookmarks=None): + def route( + self, database=None, imp_user=None, bookmarks=None, + dehydration_hooks=None, hydration_hooks=None + ): if imp_user is not None: raise ConfigurationError( "Impersonation is not supported in Bolt Protocol {!r}. " @@ -138,14 +144,20 @@ def route(self, database=None, imp_user=None, bookmarks=None): db=SYSTEM_DATABASE, on_success=metadata.update ) - self.pull(on_success=metadata.update, on_records=records.extend) + self.pull( + dehydration_hooks=dehydration_hooks, + hydration_hooks=hydration_hooks, + on_success=metadata.update, + on_records=records.extend + ) self.send_all() self.fetch_all() routing_info = [dict(zip(metadata.get("fields", ()), values)) for values in records] return routing_info def run(self, query, parameters=None, mode=None, bookmarks=None, - metadata=None, timeout=None, db=None, imp_user=None, **handlers): + metadata=None, timeout=None, db=None, imp_user=None, + dehydration_hooks=None, hydration_hooks=None, **handlers): if imp_user is not None: raise ConfigurationError( "Impersonation is not supported in Bolt Protocol {!r}. 
" @@ -179,24 +191,33 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, raise ValueError("Timeout must be a positive number or 0.") fields = (query, parameters, extra) log.debug("[#%04X] C: RUN %s", self.local_port, " ".join(map(repr, fields))) - self._append(b"\x10", fields, Response(self, "run", **handlers)) + self._append(b"\x10", fields, + Response(self, "run", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) - def discard(self, n=-1, qid=-1, **handlers): + def discard(self, n=-1, qid=-1, dehydration_hooks=None, + hydration_hooks=None, **handlers): extra = {"n": n} if qid != -1: extra["qid"] = qid log.debug("[#%04X] C: DISCARD %r", self.local_port, extra) - self._append(b"\x2F", (extra,), Response(self, "discard", **handlers)) + self._append(b"\x2F", (extra,), + Response(self, "discard", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) - def pull(self, n=-1, qid=-1, **handlers): + def pull(self, n=-1, qid=-1, dehydration_hooks=None, hydration_hooks=None, + **handlers): extra = {"n": n} if qid != -1: extra["qid"] = qid log.debug("[#%04X] C: PULL %r", self.local_port, extra) - self._append(b"\x3F", (extra,), Response(self, "pull", **handlers)) + self._append(b"\x3F", (extra,), + Response(self, "pull", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, - db=None, imp_user=None, **handlers): + db=None, imp_user=None, dehydration_hooks=None, + hydration_hooks=None, **handlers): if imp_user is not None: raise ConfigurationError( "Impersonation is not supported in Bolt Protocol {!r}. 
" @@ -227,17 +248,25 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, if extra["tx_timeout"] < 0: raise ValueError("Timeout must be a positive number or 0.") log.debug("[#%04X] C: BEGIN %r", self.local_port, extra) - self._append(b"\x11", (extra,), Response(self, "begin", **handlers)) + self._append(b"\x11", (extra,), + Response(self, "begin", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) - def commit(self, **handlers): + def commit(self, dehydration_hooks=None, hydration_hooks=None, **handlers): log.debug("[#%04X] C: COMMIT", self.local_port) - self._append(b"\x12", (), CommitResponse(self, "commit", **handlers)) + self._append(b"\x12", (), + CommitResponse(self, "commit", hydration_hooks, + **handlers), + dehydration_hooks=dehydration_hooks) - def rollback(self, **handlers): + def rollback(self, dehydration_hooks=None, hydration_hooks=None, + **handlers): log.debug("[#%04X] C: ROLLBACK", self.local_port) - self._append(b"\x13", (), Response(self, "rollback", **handlers)) + self._append(b"\x13", (), + Response(self, "rollback", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) - def reset(self): + def reset(self, dehydration_hooks=None, hydration_hooks=None): """ Add a RESET message to the outgoing queue, send it and consume all remaining messages. 
""" @@ -246,21 +275,33 @@ def fail(metadata): raise BoltProtocolError("RESET failed %r" % metadata, self.unresolved_address) log.debug("[#%04X] C: RESET", self.local_port) - self._append(b"\x0F", response=Response(self, "reset", on_failure=fail)) + self._append(b"\x0F", + response=Response(self, "reset", hydration_hooks, + on_failure=fail), + dehydration_hooks=dehydration_hooks) self.send_all() self.fetch_all() - def goodbye(self): + def goodbye(self, dehydration_hooks=None, hydration_hooks=None): log.debug("[#%04X] C: GOODBYE", self.local_port) - self._append(b"\x02", ()) + self._append(b"\x02", (), dehydration_hooks=dehydration_hooks) - def _process_message(self, details, summary_signature, - summary_metadata): + def _process_message(self, tag, fields): """ Process at most one message from the server, if available. :return: 2-tuple of number of detail messages and number of summary messages fetched """ + details = [] + summary_signature = summary_metadata = None + if tag == b"\x71": # RECORD + details = fields + elif fields: + summary_signature = tag + summary_metadata = fields[0] + else: + summary_signature = tag + if details: log.debug("[#%04X] S: RECORD * %d", self.local_port, len(details)) # Do not log any data self.responses[0].on_records(details) @@ -341,7 +382,15 @@ class Bolt4x3(Bolt4x2): PROTOCOL_VERSION = Version(4, 3) - def route(self, database=None, imp_user=None, bookmarks=None): + def get_base_headers(self): + headers = super().get_base_headers() + headers["patch_bolt"] = ["utc"] + return headers + + def route( + self, database=None, imp_user=None, bookmarks=None, + dehydration_hooks=None, hydration_hooks=None + ): if imp_user is not None: raise ConfigurationError( "Impersonation is not supported in Bolt Protocol {!r}. 
" @@ -359,13 +408,14 @@ def route(self, database=None, imp_user=None, bookmarks=None): else: bookmarks = list(bookmarks) self._append(b"\x66", (routing_context, bookmarks, database), - response=Response(self, "route", - on_success=metadata.update)) + response=Response(self, "route", hydration_hooks, + on_success=metadata.update), + dehydration_hooks=dehydration_hooks) self.send_all() self.fetch_all() return [metadata.get("rt")] - def hello(self): + def hello(self, dehydration_hooks=None, hydration_hooks=None): def on_success(metadata): self.configuration_hints.update(metadata.pop("hints", {})) self.server_info.update(metadata) @@ -380,6 +430,9 @@ def on_success(metadata): "connection.recv_timeout_seconds (%r). Make sure " "the server and network is set up correctly.", self.local_port, recv_timeout) + self.patch = set(metadata.pop("patch_bolt", [])) + if "utc" in self.patch: + self.hydration_handler.patch_utc() headers = self.get_base_headers() headers.update(self.auth_dict) @@ -388,8 +441,9 @@ def on_success(metadata): logged_headers["credentials"] = "*******" log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers) self._append(b"\x01", (headers,), - response=InitResponse(self, "hello", - on_success=on_success)) + response=InitResponse(self, "hello", hydration_hooks, + on_success=on_success), + dehydration_hooks=dehydration_hooks) self.send_all() self.fetch_all() check_supported_server_product(self.server_info.agent) @@ -403,7 +457,10 @@ class Bolt4x4(Bolt4x3): PROTOCOL_VERSION = Version(4, 4) - def route(self, database=None, imp_user=None, bookmarks=None): + def route( + self, database=None, imp_user=None, bookmarks=None, + dehydration_hooks=None, hydration_hooks=None + ): routing_context = self.routing_context or {} db_context = {} if database is not None: @@ -418,14 +475,16 @@ def route(self, database=None, imp_user=None, bookmarks=None): else: bookmarks = list(bookmarks) self._append(b"\x66", (routing_context, bookmarks, db_context), - 
response=Response(self, "route", - on_success=metadata.update)) + response=Response(self, "route", hydration_hooks, + on_success=metadata.update), + dehydration_hooks=dehydration_hooks) self.send_all() self.fetch_all() return [metadata.get("rt")] def run(self, query, parameters=None, mode=None, bookmarks=None, - metadata=None, timeout=None, db=None, imp_user=None, **handlers): + metadata=None, timeout=None, db=None, imp_user=None, + dehydration_hooks=None, hydration_hooks=None, **handlers): if not parameters: parameters = {} extra = {} @@ -456,10 +515,13 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, fields = (query, parameters, extra) log.debug("[#%04X] C: RUN %s", self.local_port, " ".join(map(repr, fields))) - self._append(b"\x10", fields, Response(self, "run", **handlers)) + self._append(b"\x10", fields, + Response(self, "run", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, - db=None, imp_user=None, **handlers): + db=None, imp_user=None, dehydration_hooks=None, + hydration_hooks=None, **handlers): extra = {} if mode in (READ_ACCESS, "r"): # It will default to mode "w" if nothing is specified @@ -486,4 +548,6 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, if extra["tx_timeout"] < 0: raise ValueError("Timeout must be a positive number or 0.") log.debug("[#%04X] C: BEGIN %r", self.local_port, extra) - self._append(b"\x11", (extra,), Response(self, "begin", **handlers)) + self._append(b"\x11", (extra,), + Response(self, "begin", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) diff --git a/neo4j/_sync/io/_bolt5.py b/neo4j/_sync/io/_bolt5.py index 74fe2d183..a7180e5a5 100644 --- a/neo4j/_sync/io/_bolt5.py +++ b/neo4j/_sync/io/_bolt5.py @@ -19,28 +19,24 @@ from logging import getLogger from ssl import SSLSocket -from ..._async_compat.util import Util -from ..._exceptions import ( - BoltError, - 
BoltProtocolError, -) +from ..._codec.hydration import v2 as hydration_v2 +from ..._exceptions import BoltProtocolError from ...api import ( READ_ACCESS, Version, ) from ...exceptions import ( DatabaseUnavailable, - DriverError, ForbiddenOnReadOnlyDatabase, Neo4jError, NotALeader, ServiceUnavailable, ) +from ._bolt import Bolt from ._bolt3 import ( ServerStateManager, ServerStates, ) -from ._bolt import Bolt from ._common import ( check_supported_server_product, CommitResponse, @@ -57,6 +53,8 @@ class Bolt5x0(Bolt): PROTOCOL_VERSION = Version(5, 0) + HYDRATION_HANDLER_CLS = hydration_v2.HydrationHandler + supports_multiple_results = True supports_multiple_databases = True @@ -95,7 +93,7 @@ def get_base_headers(self): headers["routing"] = self.routing_context return headers - def hello(self): + def hello(self, dehydration_hooks=None, hydration_hooks=None): def on_success(metadata): self.configuration_hints.update(metadata.pop("hints", {})) self.server_info.update(metadata) @@ -118,13 +116,15 @@ def on_success(metadata): logged_headers["credentials"] = "*******" log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers) self._append(b"\x01", (headers,), - response=InitResponse(self, "hello", - on_success=on_success)) + response=InitResponse(self, "hello", hydration_hooks, + on_success=on_success), + dehydration_hooks=dehydration_hooks) self.send_all() self.fetch_all() check_supported_server_product(self.server_info.agent) - def route(self, database=None, imp_user=None, bookmarks=None): + def route(self, database=None, imp_user=None, bookmarks=None, + dehydration_hooks=None, hydration_hooks=None): routing_context = self.routing_context or {} db_context = {} if database is not None: @@ -139,14 +139,16 @@ def route(self, database=None, imp_user=None, bookmarks=None): else: bookmarks = list(bookmarks) self._append(b"\x66", (routing_context, bookmarks, db_context), - response=Response(self, "route", - on_success=metadata.update)) + response=Response(self, "route", 
hydration_hooks, + on_success=metadata.update), + dehydration_hooks=dehydration_hooks) self.send_all() self.fetch_all() return [metadata.get("rt")] def run(self, query, parameters=None, mode=None, bookmarks=None, - metadata=None, timeout=None, db=None, imp_user=None, **handlers): + metadata=None, timeout=None, db=None, imp_user=None, + dehydration_hooks=None, hydration_hooks=None, **handlers): if not parameters: parameters = {} extra = {} @@ -177,24 +179,33 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, fields = (query, parameters, extra) log.debug("[#%04X] C: RUN %s", self.local_port, " ".join(map(repr, fields))) - self._append(b"\x10", fields, Response(self, "run", **handlers)) + self._append(b"\x10", fields, + Response(self, "run", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) - def discard(self, n=-1, qid=-1, **handlers): + def discard(self, n=-1, qid=-1, dehydration_hooks=None, + hydration_hooks=None, **handlers): extra = {"n": n} if qid != -1: extra["qid"] = qid log.debug("[#%04X] C: DISCARD %r", self.local_port, extra) - self._append(b"\x2F", (extra,), Response(self, "discard", **handlers)) + self._append(b"\x2F", (extra,), + Response(self, "discard", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) - def pull(self, n=-1, qid=-1, **handlers): + def pull(self, n=-1, qid=-1, dehydration_hooks=None, hydration_hooks=None, + **handlers): extra = {"n": n} if qid != -1: extra["qid"] = qid log.debug("[#%04X] C: PULL %r", self.local_port, extra) - self._append(b"\x3F", (extra,), Response(self, "pull", **handlers)) + self._append(b"\x3F", (extra,), + Response(self, "pull", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, - db=None, imp_user=None, **handlers): + db=None, imp_user=None, dehydration_hooks=None, + hydration_hooks=None, **handlers): extra = {} if mode in (READ_ACCESS, "r"): # It will default to mode "w" 
if nothing is specified @@ -221,17 +232,25 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, if extra["tx_timeout"] < 0: raise ValueError("Timeout must be a number <= 0") log.debug("[#%04X] C: BEGIN %r", self.local_port, extra) - self._append(b"\x11", (extra,), Response(self, "begin", **handlers)) + self._append(b"\x11", (extra,), + Response(self, "begin", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) - def commit(self, **handlers): + def commit(self, dehydration_hooks=None, hydration_hooks=None, **handlers): log.debug("[#%04X] C: COMMIT", self.local_port) - self._append(b"\x12", (), CommitResponse(self, "commit", **handlers)) + self._append(b"\x12", (), + CommitResponse(self, "commit", hydration_hooks, + **handlers), + dehydration_hooks=dehydration_hooks) - def rollback(self, **handlers): + def rollback(self, dehydration_hooks=None, hydration_hooks=None, + **handlers): log.debug("[#%04X] C: ROLLBACK", self.local_port) - self._append(b"\x13", (), Response(self, "rollback", **handlers)) + self._append(b"\x13", (), + Response(self, "rollback", hydration_hooks, **handlers), + dehydration_hooks=dehydration_hooks) - def reset(self): + def reset(self, dehydration_hooks=None, hydration_hooks=None): """Reset the connection. 
Add a RESET message to the outgoing queue, send it and consume all @@ -243,22 +262,33 @@ def fail(metadata): self.unresolved_address) log.debug("[#%04X] C: RESET", self.local_port) - self._append(b"\x0F", response=Response(self, "reset", - on_failure=fail)) + self._append(b"\x0F", + response=Response(self, "reset", hydration_hooks, + on_failure=fail), + dehydration_hooks=dehydration_hooks) self.send_all() self.fetch_all() - def goodbye(self): + def goodbye(self, dehydration_hooks=None, hydration_hooks=None): log.debug("[#%04X] C: GOODBYE", self.local_port) - self._append(b"\x02", ()) + self._append(b"\x02", (), dehydration_hooks=dehydration_hooks) - def _process_message(self, details, summary_signature, - summary_metadata): + def _process_message(self, tag, fields): """Process at most one message from the server, if available. :return: 2-tuple of number of detail messages and number of summary messages fetched """ + details = [] + summary_signature = summary_metadata = None + if tag == b"\x71": # RECORD + details = fields + elif fields: + summary_signature = tag + summary_metadata = fields[0] + else: + summary_signature = tag + if details: # Do not log any data log.debug("[#%04X] S: RECORD * %d", self.local_port, len(details)) diff --git a/neo4j/_sync/io/_common.py b/neo4j/_sync/io/_common.py index 26a2b5547..1eea4ba34 100644 --- a/neo4j/_sync/io/_common.py +++ b/neo4j/_sync/io/_common.py @@ -17,7 +17,6 @@ import asyncio -from contextlib import contextmanager import logging import socket from struct import pack as struct_pack @@ -30,132 +29,122 @@ SessionExpired, UnsupportedServerProduct, ) -from ...packstream import ( - UnpackableBuffer, - Unpacker, -) log = logging.getLogger("neo4j") -class MessageInbox: +class Inbox: - def __init__(self, s, on_error): + def __init__(self, sock, on_error, unpacker_cls): self.on_error = on_error - self._local_port = s.getsockname()[1] - self._messages = self._yield_messages(s) - - def _yield_messages(self, sock): + 
self._local_port = sock.getsockname()[1] + self._socket = sock + self._buffer = unpacker_cls.new_unpackable_buffer() + self._unpacker = unpacker_cls(self._buffer) + self._broken = False + + def _buffer_one_chunk(self): + assert not self._broken try: - buffer = UnpackableBuffer() - unpacker = Unpacker(buffer) chunk_size = 0 while True: - while chunk_size == 0: # Determine the chunk size and skip noop - receive_into_buffer(sock, buffer, 2) - chunk_size = buffer.pop_u16() + receive_into_buffer(self._socket, self._buffer, 2) + chunk_size = self._buffer.pop_u16() if chunk_size == 0: log.debug("[#%04X] S: ", self._local_port) - receive_into_buffer(sock, buffer, chunk_size + 2) - chunk_size = buffer.pop_u16() + receive_into_buffer( + self._socket, self._buffer, chunk_size + 2 + ) + chunk_size = self._buffer.pop_u16() if chunk_size == 0: # chunk_size was the end marker for the message - size, tag = unpacker.unpack_structure_header() - fields = [unpacker.unpack() for _ in range(size)] - yield tag, fields - # Reset for new message - unpacker.reset() + return except (OSError, socket.timeout, SocketDeadlineExceeded) as error: + self._broken = True Util.callback(self.on_error, error) + raise - def pop(self): - return Util.next(self._messages) - - -class Inbox(MessageInbox): - - def __next__(self): - tag, fields = self.pop() - if tag == b"\x71": - return fields, None, None - elif fields: - return [], tag, fields[0] - else: - return [], tag, None + def pop(self, hydration_hooks): + self._buffer_one_chunk() + try: + size, tag = self._unpacker.unpack_structure_header() + fields = [self._unpacker.unpack(hydration_hooks) + for _ in range(size)] + return tag, fields + finally: + # Reset for new message + self._unpacker.reset() class Outbox: - def __init__(self, max_chunk_size=16384): + def __init__(self, sock, on_error, packer_cls, max_chunk_size=16384): self._max_chunk_size = max_chunk_size self._chunked_data = bytearray() - self._raw_data = bytearray() - self.write = 
self._raw_data.extend - self._tmp_buffering = 0 + self._buffer = packer_cls.new_packable_buffer() + self._packer = packer_cls(self._buffer) + self.socket = sock + self.on_error = on_error def max_chunk_size(self): return self._max_chunk_size - def clear(self): - if self._tmp_buffering: - raise RuntimeError("Cannot clear while buffering") + def _clear(self): + assert not self._buffer.is_tmp_buffering() self._chunked_data = bytearray() - self._raw_data.clear() + self._buffer.clear() def _chunk_data(self): - data_len = len(self._raw_data) + data_len = len(self._buffer.data) num_full_chunks, chunk_rest = divmod( data_len, self._max_chunk_size ) num_chunks = num_full_chunks + bool(chunk_rest) - data_view = memoryview(self._raw_data) - header_start = len(self._chunked_data) - data_start = header_start + 2 - raw_data_start = 0 - for i in range(num_chunks): - chunk_size = min(data_len - raw_data_start, - self._max_chunk_size) - self._chunked_data[header_start:data_start] = struct_pack( - ">H", chunk_size - ) - self._chunked_data[data_start:(data_start + chunk_size)] = \ - data_view[raw_data_start:(raw_data_start + chunk_size)] - header_start += chunk_size + 2 + with memoryview(self._buffer.data) as data_view: + header_start = len(self._chunked_data) data_start = header_start + 2 - raw_data_start += chunk_size - del data_view - self._raw_data.clear() - - def wrap_message(self): - if self._tmp_buffering: - raise RuntimeError("Cannot wrap message while buffering") + raw_data_start = 0 + for i in range(num_chunks): + chunk_size = min(data_len - raw_data_start, + self._max_chunk_size) + self._chunked_data[header_start:data_start] = struct_pack( + ">H", chunk_size + ) + self._chunked_data[data_start:(data_start + chunk_size)] = \ + data_view[raw_data_start:(raw_data_start + chunk_size)] + header_start += chunk_size + 2 + data_start = header_start + 2 + raw_data_start += chunk_size + self._buffer.clear() + + def _wrap_message(self): + assert not self._buffer.is_tmp_buffering() 
self._chunk_data() self._chunked_data += b"\x00\x00" - def view(self): - if self._tmp_buffering: - raise RuntimeError("Cannot view while buffering") - self._chunk_data() - return memoryview(self._chunked_data) + def append_message(self, tag, fields, dehydration_hooks): + with self._buffer.tmp_buffer(): + self._packer.pack_struct(tag, fields, dehydration_hooks) + self._wrap_message() - @contextmanager - def tmp_buffer(self): - self._tmp_buffering += 1 - old_len = len(self._raw_data) - try: - yield - except Exception: - del self._raw_data[old_len:] - raise - finally: - self._tmp_buffering -= 1 + def flush(self): + data = self._chunked_data + if data: + try: + self.socket.sendall(data) + except OSError as error: + self.on_error(error) + return False + self._clear() + return True + return False class ConnectionErrorHandler: @@ -218,8 +207,9 @@ class Response: more detail messages followed by one summary message). """ - def __init__(self, connection, message, **handlers): + def __init__(self, connection, message, hydration_hooks, **handlers): self.connection = connection + self.hydration_hooks = hydration_hooks self.handlers = handlers self.message = message self.complete = False @@ -294,9 +284,9 @@ def receive_into_buffer(sock, buffer, n_bytes): end = buffer.used + n_bytes if end > len(buffer.data): buffer.data += bytearray(end - len(buffer.data)) - view = memoryview(buffer.data) - while buffer.used < end: - n = sock.recv_into(view[buffer.used:end], end - buffer.used) - if n == 0: - raise OSError("No data") - buffer.used += n + with memoryview(buffer.data) as view: + while buffer.used < end: + n = sock.recv_into(view[buffer.used:end], end - buffer.used) + if n == 0: + raise OSError("No data") + buffer.used += n diff --git a/neo4j/_sync/io/_pool.py b/neo4j/_sync/io/_pool.py index ca25a49b2..2de6d090d 100644 --- a/neo4j/_sync/io/_pool.py +++ b/neo4j/_sync/io/_pool.py @@ -17,12 +17,12 @@ import abc +import logging from collections import ( defaultdict, deque, ) from 
contextlib import contextmanager -import logging from logging import getLogger from random import choice diff --git a/neo4j/_sync/work/result.py b/neo4j/_sync/work/result.py index 807096556..888fd2701 100644 --- a/neo4j/_sync/work/result.py +++ b/neo4j/_sync/work/result.py @@ -20,15 +20,15 @@ from warnings import warn from ..._async_compat.util import Util -from ...data import ( - DataDehydrator, +from ..._data import ( + Record, RecordTableRowExporter, ) +from ..._meta import experimental from ...exceptions import ( ResultConsumedError, ResultNotSingleError, ) -from ...meta import experimental from ...time import ( Date, DateTime, @@ -54,10 +54,9 @@ class Result: :meth:`.AyncSession.run` and :meth:`.Transaction.run`. """ - def __init__(self, connection, hydrant, fetch_size, on_closed, - on_error): + def __init__(self, connection, fetch_size, on_closed, on_error): self._connection = ConnectionErrorHandler(connection, on_error) - self._hydrant = hydrant + self._hydration_scope = connection.new_hydration_scope() self._on_closed = on_closed self._metadata = None self._keys = None @@ -104,7 +103,7 @@ def _run( query_metadata = getattr(query, "metadata", None) query_timeout = getattr(query, "timeout", None) - parameters = DataDehydrator.fix_parameters(dict(parameters or {}, **kwargs)) + parameters = dict(parameters or {}, **kwargs) self._metadata = { "query": query_text, @@ -135,6 +134,7 @@ def on_failed_attach(metadata): timeout=query_timeout, db=db, imp_user=imp_user, + dehydration_hooks=self._hydration_scope.dehydration_hooks, on_success=on_attached, on_failure=on_failed_attach, ) @@ -145,7 +145,10 @@ def on_failed_attach(metadata): def _pull(self): def on_records(records): if not self._discarding: - self._record_buffer.extend(self._hydrant.hydrate_records(self._keys, records)) + self._record_buffer.extend(( + Record(zip(self._keys, record)) + for record in records + )) def on_summary(): self._attached = False @@ -167,6 +170,7 @@ def on_success(summary_metadata): 
self._connection.pull( n=self._fetch_size, qid=self._qid, + hydration_hooks=self._hydration_scope.hydration_hooks, on_records=on_records, on_success=on_success, on_failure=on_failure, @@ -479,7 +483,7 @@ def graph(self): Can raise :exc:`ResultConsumedError`. """ self._buffer_all() - return self._hydrant.graph + return self._hydration_scope.get_graph() def value(self, key=0, default=None): """Helper function that return the remainder of the result as a list of values. diff --git a/neo4j/_sync/work/session.py b/neo4j/_sync/work/session.py index 1e8175792..861bfaa52 100644 --- a/neo4j/_sync/work/session.py +++ b/neo4j/_sync/work/session.py @@ -21,13 +21,16 @@ from time import perf_counter from ..._async_compat import sleep +from ..._meta import ( + deprecated, + deprecation_warn, +) from ...api import ( Bookmarks, READ_ACCESS, WRITE_ACCESS, ) from ...conf import SessionConfig -from ...data import DataHydrator from ...exceptions import ( ClientError, DriverError, @@ -36,10 +39,6 @@ SessionExpired, TransactionError, ) -from ...meta import ( - deprecated, - deprecation_warn, -) from ...work import Query from .result import Result from .transaction import ( @@ -228,10 +227,8 @@ def run(self, query, parameters=None, **kwargs): protocol_version = cx.PROTOCOL_VERSION server_info = cx.server_info - hydrant = DataHydrator() - self._auto_result = Result( - cx, hydrant, self._config.fetch_size, self._result_closed, + cx, self._config.fetch_size, self._result_closed, self._result_error ) self._auto_result._run( diff --git a/neo4j/_sync/work/transaction.py b/neo4j/_sync/work/transaction.py index 9b6b29c4b..95dd80332 100644 --- a/neo4j/_sync/work/transaction.py +++ b/neo4j/_sync/work/transaction.py @@ -19,7 +19,6 @@ from functools import wraps from ..._async_compat.util import Util -from ...data import DataHydrator from ...exceptions import TransactionError from ...work import Query from ..io import ConnectionErrorHandler @@ -123,8 +122,7 @@ def run(self, query, parameters=None, 
**kwparameters): self._results[-1]._buffer_all() result = Result( - self._connection, DataHydrator(), self._fetch_size, - self._result_on_closed_handler, + self._connection, self._fetch_size, self._result_on_closed_handler, self._error_handler ) self._results.append(result) diff --git a/neo4j/_sync/work/workspace.py b/neo4j/_sync/work/workspace.py index a177b097c..f8c305930 100644 --- a/neo4j/_sync/work/workspace.py +++ b/neo4j/_sync/work/workspace.py @@ -19,15 +19,15 @@ import asyncio from ..._deadline import Deadline +from ..._meta import ( + deprecation_warn, + unclosed_resource_warn, +) from ...conf import WorkspaceConfig from ...exceptions import ( ServiceUnavailable, SessionExpired, ) -from ...meta import ( - deprecation_warn, - unclosed_resource_warn, -) from ..io import Neo4jPool diff --git a/neo4j/api.py b/neo4j/api.py index 58292ad83..d35f0809d 100644 --- a/neo4j/api.py +++ b/neo4j/api.py @@ -24,8 +24,8 @@ urlparse, ) +from ._meta import deprecated from .exceptions import ConfigurationError -from .meta import deprecated READ_ACCESS = "READ" diff --git a/neo4j/conf.py b/neo4j/conf.py index f93ba30b8..c578cba7c 100644 --- a/neo4j/conf.py +++ b/neo4j/conf.py @@ -24,6 +24,10 @@ TrustCustomCAs, TrustSystemCAs, ) +from ._meta import ( + deprecation_warn, + get_user_agent, +) from .api import ( DEFAULT_DATABASE, TRUST_ALL_CERTIFICATES, @@ -31,10 +35,6 @@ WRITE_ACCESS, ) from .exceptions import ConfigurationError -from .meta import ( - deprecation_warn, - get_user_agent, -) def iter_items(iterable): diff --git a/neo4j/graph/__init__.py b/neo4j/graph/__init__.py index c1721512a..17e1e7b0b 100644 --- a/neo4j/graph/__init__.py +++ b/neo4j/graph/__init__.py @@ -71,90 +71,6 @@ def relationship_type(self, name): cls = self._relationship_types[name] = type(str(name), (Relationship,), {}) return cls - class Hydrator: - - def __init__(self, graph): - self.graph = graph - - def hydrate_node(self, id_, labels=None, - properties=None, element_id=None): - assert 
isinstance(self.graph, Graph) - # backwards compatibility with Neo4j < 5.0 - if element_id is None: - element_id = str(id_) - - try: - inst = self.graph._nodes[element_id] - except KeyError: - inst = self.graph._nodes[element_id] = Node( - self.graph, element_id, id_, labels, properties - ) - else: - # If we have already hydrated this node as the endpoint of - # a relationship, it won't have any labels or properties. - # Therefore, we need to add the ones we have here. - if labels: - inst._labels = inst._labels.union(labels) # frozen_set - if properties: - inst._properties.update(properties) - return inst - - def hydrate_relationship(self, id_, n0_id, n1_id, type_, - properties=None, element_id=None, - n0_element_id=None, n1_element_id=None): - # backwards compatibility with Neo4j < 5.0 - if element_id is None: - element_id = str(id_) - if n0_element_id is None: - n0_element_id = str(n0_id) - if n1_element_id is None: - n1_element_id = str(n1_id) - - inst = self.hydrate_unbound_relationship(id_, type_, properties, - element_id) - inst._start_node = self.hydrate_node(n0_id, - element_id=n0_element_id) - inst._end_node = self.hydrate_node(n1_id, element_id=n1_element_id) - return inst - - def hydrate_unbound_relationship(self, id_, type_, properties=None, - element_id=None): - assert isinstance(self.graph, Graph) - # backwards compatibility with Neo4j < 5.0 - if element_id is None: - element_id = str(id_) - - try: - inst = self.graph._relationships[element_id] - except KeyError: - r = self.graph.relationship_type(type_) - inst = self.graph._relationships[element_id] = r( - self.graph, element_id, id_, properties - ) - return inst - - def hydrate_path(self, nodes, relationships, sequence): - assert isinstance(self.graph, Graph) - assert len(nodes) >= 1 - assert len(sequence) % 2 == 0 - last_node = nodes[0] - entities = [last_node] - for i, rel_index in enumerate(sequence[::2]): - assert rel_index != 0 - next_node = nodes[sequence[2 * i + 1]] - if rel_index > 0: - r = 
relationships[rel_index - 1] - r._start_node = last_node - r._end_node = next_node - entities.append(r) - else: - r = relationships[-rel_index - 1] - r._start_node = next_node - r._end_node = last_node - entities.append(r) - last_node = next_node - return Path(*entities) - class Entity(Mapping): """ Base class for :class:`.Node` and :class:`.Relationship` that diff --git a/neo4j/spatial/__init__.py b/neo4j/spatial/__init__.py index 243f72c5f..f530d3da0 100644 --- a/neo4j/spatial/__init__.py +++ b/neo4j/spatial/__init__.py @@ -30,112 +30,38 @@ "WGS84Point", ] - -from threading import Lock - -from neo4j.packstream import Structure - - -# SRID to subclass mappings -__srid_table = {} -__srid_table_lock = Lock() - - -class Point(tuple): - """Base-class for spatial data. - - A point within a geometric space. This type is generally used via its - subclasses and should not be instantiated directly unless there is no - subclass defined for the required SRID. - - :param iterable: - An iterable of coordinates. - All items will be converted to :class:`float`. - """ - - #: The SRID (spatial reference identifier) of the spatial data. - #: A number that identifies the coordinate system the spatial type is to be - #: interpreted in. - #: - #: :type: int - srid = None - - def __new__(cls, iterable): - return tuple.__new__(cls, map(float, iterable)) - - def __repr__(self): - return "POINT(%s)" % " ".join(map(str, self)) - - def __eq__(self, other): - try: - return type(self) is type(other) and tuple(self) == tuple(other) - except (AttributeError, TypeError): - return False - - def __ne__(self, other): - return not self.__eq__(other) - - def __hash__(self): - return hash(type(self)) ^ hash(tuple(self)) - - -def point_type(name, fields, srid_map): - """ Dynamically create a Point subclass. 
- """ - - def srid(self): - try: - return srid_map[len(self)] - except KeyError: - return None - - attributes = {"srid": property(srid)} - - for index, subclass_field in enumerate(fields): - - def accessor(self, i=index, f=subclass_field): - try: - return self[i] - except IndexError: - raise AttributeError(f) - - for field_alias in {subclass_field, "xyz"[index]}: - attributes[field_alias] = property(accessor) - - cls = type(name, (Point,), attributes) - - with __srid_table_lock: - for dim, srid in srid_map.items(): - __srid_table[srid] = (cls, dim) - - return cls - - -# Point subclass definitions -CartesianPoint = point_type("CartesianPoint", ["x", "y", "z"], - {2: 7203, 3: 9157}) -WGS84Point = point_type("WGS84Point", ["longitude", "latitude", "height"], - {2: 4326, 3: 4979}) - - +from functools import wraps + +from .._codec.hydration.v1 import spatial as _hydration +from .._meta import deprecated +from .._spatial import ( + CartesianPoint, + Point, + point_type as _point_type, + WGS84Point, +) + + +# TODO: 6.0 remove +@deprecated( + "hydrate_point is considered an internal function and will be removed in " + "a future version" +) def hydrate_point(srid, *coordinates): """ Create a new instance of a Point subclass from a raw set of fields. The subclass chosen is determined by the given SRID; a ValueError will be raised if no such subclass can be found. """ - try: - point_class, dim = __srid_table[srid] - except KeyError: - point = Point(coordinates) - point.srid = srid - return point - else: - if len(coordinates) != dim: - raise ValueError("SRID %d requires %d coordinates (%d provided)" % (srid, dim, len(coordinates))) - return point_class(coordinates) + return _hydration.hydrate_point(srid, *coordinates) +# TODO: 6.0 remove +@deprecated( + "hydrate_point is considered an internal function and will be removed in " + "a future version" +) +@wraps(_hydration.dehydrate_point) def dehydrate_point(value): """ Dehydrator for Point data. 
@@ -143,10 +69,30 @@ def dehydrate_point(value): :type value: Point :return: """ - dim = len(value) - if dim == 2: - return Structure(b"X", value.srid, *value) - elif dim == 3: - return Structure(b"Y", value.srid, *value) - else: - raise ValueError("Cannot dehydrate Point with %d dimensions" % dim) + return _hydration.dehydrate_point(value) + + +# TODO: 6.0 remove +@deprecated( + "hydrate_point is considered an internal function and will be removed in " + "a future version" +) +@wraps(_hydration.dehydrate_point) +def dehydrate_point(value): + """ Dehydrator for Point data. + + :param value: + :type value: Point + :return: + """ + return _hydration.dehydrate_point(value) + + +# TODO: 6.0 remove +@deprecated( + "point_type is considered an internal function and will be removed in " + "a future version" +) +@wraps(_point_type) +def point_type(name, fields, srid_map): + return _point_type(name, fields, srid_map) diff --git a/neo4j/time/__init__.py b/neo4j/time/__init__.py index b0302e316..d30d705e3 100644 --- a/neo4j/time/__init__.py +++ b/neo4j/time/__init__.py @@ -37,19 +37,36 @@ struct_time, ) -from neo4j.time.arithmetic import ( +from ._arithmetic import ( nano_add, nano_div, round_half_to_even, symmetric_divmod, ) -from neo4j.time.metaclasses import ( +from ._metaclasses import ( DateTimeType, DateType, TimeType, ) +__all__ = [ + "MIN_INT64", + "MAX_INT64", + "MIN_YEAR", + "MAX_YEAR", + "Duration", + "Date", + "ZeroDate", + "Time", + "Midnight", + "Midday", + "DateTime", + "Never", + "UnixEpoch", +] + + MIN_INT64 = -(2 ** 63) MAX_INT64 = (2 ** 63) - 1 @@ -241,7 +258,7 @@ class Clock: def __new__(cls): if cls.__implementations is None: # Find an available clock with the best precision - import neo4j.time.clock_implementations + import neo4j.time._clock_implementations cls.__implementations = sorted((clock for clock in Clock.__subclasses__() if clock.available()), key=lambda clock: clock.precision(), reverse=True) if not cls.__implementations: diff --git 
a/neo4j/time/arithmetic.py b/neo4j/time/_arithmetic.py similarity index 95% rename from neo4j/time/arithmetic.py rename to neo4j/time/_arithmetic.py index 6ab7b6581..93bfe8eda 100644 --- a/neo4j/time/arithmetic.py +++ b/neo4j/time/_arithmetic.py @@ -16,6 +16,15 @@ # limitations under the License. +__all__ = [ + "nano_add", + "nano_div", + "nano_divmod", + "symmetric_divmod", + "round_half_to_even", +] + + def nano_add(x, y): """ diff --git a/neo4j/time/clock_implementations.py b/neo4j/time/_clock_implementations.py similarity index 97% rename from neo4j/time/clock_implementations.py rename to neo4j/time/_clock_implementations.py index 38e93ab5b..cbfaf0f27 100644 --- a/neo4j/time/clock_implementations.py +++ b/neo4j/time/_clock_implementations.py @@ -32,6 +32,13 @@ from neo4j.time.arithmetic import nano_divmod +__all__ = [ + "SafeClock", + "PEP564Clock", + "LibCClock", +] + + class SafeClock(Clock): """ Clock implementation that should work for any variant of Python. This clock is guaranteed microsecond precision. diff --git a/neo4j/time/metaclasses.py b/neo4j/time/_metaclasses.py similarity index 96% rename from neo4j/time/metaclasses.py rename to neo4j/time/_metaclasses.py index 95be7e96c..cf9022fbb 100644 --- a/neo4j/time/metaclasses.py +++ b/neo4j/time/_metaclasses.py @@ -16,6 +16,13 @@ # limitations under the License. 
+__all__ = [ + "DateType", + "TimeType", + "DateTimeType", +] + + class DateType(type): def __getattr__(cls, name): diff --git a/setup.cfg b/setup.cfg index 421144bec..0f6574a96 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,9 +2,13 @@ combine_as_imports=true ensure_newline_before_comments=true force_grid_wrap=2 -force_sort_within_sections=true +# breaks order of relative imports +# https://github.com/PyCQA/isort/issues/1944 +#force_sort_within_sections=true include_trailing_comma=true -#lines_before_imports=2 # currently broken +# currently broken +# https://github.com/PyCQA/isort/issues/1855 +#lines_before_imports=2 lines_after_imports=2 lines_between_sections=1 multi_line_output=3 diff --git a/setup.py b/setup.py index 31ce78611..7f7f38b7f 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,7 @@ setup, ) -from neo4j.meta import ( +from neo4j._meta import ( package, version, ) diff --git a/testkitbackend/_async/backend.py b/testkitbackend/_async/backend.py index fce7386af..b939e9803 100644 --- a/testkitbackend/_async/backend.py +++ b/testkitbackend/_async/backend.py @@ -17,6 +17,7 @@ import asyncio +import traceback from inspect import ( getmembers, isfunction, @@ -26,7 +27,6 @@ loads, ) from pathlib import Path -import traceback from neo4j._exceptions import BoltError from neo4j.exceptions import ( @@ -35,13 +35,13 @@ UnsupportedServerProduct, ) -from . import requests from .._driver_logger import ( buffer_handler, log, ) from ..backend import Request from ..exceptions import MarkdAsDriverException +from . 
import requests TESTKIT_BACKEND_PATH = Path(__file__).absolute().resolve().parents[1] diff --git a/testkitbackend/_async/requests.py b/testkitbackend/_async/requests.py index 6dbd38561..0ea955cb6 100644 --- a/testkitbackend/_async/requests.py +++ b/testkitbackend/_async/requests.py @@ -17,9 +17,9 @@ import json -from os import path import re import warnings +from os import path import neo4j from neo4j._async_compat.util import AsyncUtil diff --git a/testkitbackend/_sync/backend.py b/testkitbackend/_sync/backend.py index 625e2ee9e..75a23d66d 100644 --- a/testkitbackend/_sync/backend.py +++ b/testkitbackend/_sync/backend.py @@ -17,6 +17,7 @@ import asyncio +import traceback from inspect import ( getmembers, isfunction, @@ -26,7 +27,6 @@ loads, ) from pathlib import Path -import traceback from neo4j._exceptions import BoltError from neo4j.exceptions import ( @@ -35,13 +35,13 @@ UnsupportedServerProduct, ) -from . import requests from .._driver_logger import ( buffer_handler, log, ) from ..backend import Request from ..exceptions import MarkdAsDriverException +from . 
import requests TESTKIT_BACKEND_PATH = Path(__file__).absolute().resolve().parents[1] diff --git a/testkitbackend/_sync/requests.py b/testkitbackend/_sync/requests.py index 798038fd0..776cfa95a 100644 --- a/testkitbackend/_sync/requests.py +++ b/testkitbackend/_sync/requests.py @@ -17,9 +17,9 @@ import json -from os import path import re import warnings +from os import path import neo4j from neo4j._async_compat.util import Util diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index bdf420f5f..51dd43a96 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -40,6 +40,7 @@ "Feature:Bolt:4.3": true, "Feature:Bolt:4.4": true, "Feature:Bolt:5.0": true, + "Feature:Bolt:Patch:UTC": true, "Feature:Impersonation": true, "Feature:TLS:1.1": "Driver blocks TLS 1.1 for security reasons.", "Feature:TLS:1.2": true, diff --git a/tests/conftest.py b/tests/conftest.py index 6a62e1129..4bcc74353 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,9 +17,9 @@ import asyncio +import warnings from functools import wraps from os import environ -import warnings import pytest import pytest_asyncio diff --git a/tests/env.py b/tests/env.py index ea9fb46ee..8e4b077ef 100644 --- a/tests/env.py +++ b/tests/env.py @@ -17,8 +17,8 @@ import abc -from os import environ import sys +from os import environ class _LazyEval(abc.ABC): diff --git a/tests/unit/async_/io/conftest.py b/tests/unit/async_/io/conftest.py index 05bbe8b94..08b6e9c41 100644 --- a/tests/unit/async_/io/conftest.py +++ b/tests/unit/async_/io/conftest.py @@ -24,20 +24,22 @@ import pytest -from neo4j._async.io._common import AsyncMessageInbox -from neo4j.packstream import ( - Packer, - UnpackableBuffer, - Unpacker, +from neo4j._async.io._common import ( + AsyncInbox, + AsyncOutbox, ) class AsyncFakeSocket: - def __init__(self, address): + def __init__(self, address, unpacker_cls=None): self.address = address self.captured = b"" - self.messages = 
AsyncMessageInbox(self, on_error=print) + self.messages = None + if unpacker_cls is not None: + self.messages = AsyncInbox( + self, on_error=print, unpacker_cls=unpacker_cls + ) def getsockname(self): return "127.0.0.1", 0xFFFF @@ -59,16 +61,27 @@ def close(self): return async def pop_message(self): - return await self.messages.pop() + assert self.messages + return await self.messages.pop(None) class AsyncFakeSocket2: - def __init__(self, address=None, on_send=None): + def __init__(self, address=None, on_send=None, + packer_cls=None, unpacker_cls=None): self.address = address self.recv_buffer = bytearray() - self._messages = AsyncMessageInbox(self, on_error=print) + # self.messages = AsyncMessageInbox(self, on_error=print) self.on_send = on_send + self._outbox = self._messages = None + if packer_cls: + self._outbox = AsyncOutbox( + self, on_error=print, packer_cls=packer_cls + ) + if unpacker_cls: + self._messages = AsyncInbox( + self, on_error=print, unpacker_cls=unpacker_cls + ) def getsockname(self): return "127.0.0.1", 0xFFFF @@ -93,50 +106,25 @@ def close(self): def inject(self, data): self.recv_buffer += data - def _pop_chunk(self): - chunk_size, = struct_unpack(">H", self.recv_buffer[:2]) - print("CHUNK SIZE %r" % chunk_size) - end = 2 + chunk_size - chunk_data, self.recv_buffer = self.recv_buffer[2:end], self.recv_buffer[end:] - return chunk_data - async def pop_message(self): - data = bytearray() - while True: - chunk = self._pop_chunk() - print("CHUNK %r" % chunk) - if chunk: - data.extend(chunk) - elif data: - break # end of message - else: - continue # NOOP - header = data[0] - n_fields = header % 0x10 - tag = data[1] - buffer = UnpackableBuffer(data[2:]) - unpacker = Unpacker(buffer) - fields = [unpacker.unpack() for _ in range(n_fields)] - return tag, fields + assert self._messages + return await self._messages.pop(None) async def send_message(self, tag, *fields): - data = self.encode_message(tag, *fields) - await self.sendall(struct_pack(">H", 
len(data)) + data + b"\x00\x00") - - @classmethod - def encode_message(cls, tag, *fields): - b = BytesIO() - packer = Packer(b) - for field in fields: - packer.pack(field) - return bytearray([0xB0 + len(fields), tag]) + b.getvalue() + assert self._outbox + self._outbox.append_message(tag, fields, None) + await self._outbox.flush() class AsyncFakeSocketPair: - def __init__(self, address): - self.client = AsyncFakeSocket2(address) - self.server = AsyncFakeSocket2() + def __init__(self, address, packer_cls=None, unpacker_cls=None): + self.client = AsyncFakeSocket2( + address, packer_cls=packer_cls, unpacker_cls=unpacker_cls + ) + self.server = AsyncFakeSocket2( + packer_cls=packer_cls, unpacker_cls=unpacker_cls + ) self.client.on_send = self.server.inject self.server.on_send = self.client.inject diff --git a/tests/unit/async_/io/test__common.py b/tests/unit/async_/io/test__common.py index bc95738a3..1c14ea202 100644 --- a/tests/unit/async_/io/test__common.py +++ b/tests/unit/async_/io/test__common.py @@ -18,33 +18,43 @@ import pytest -from neo4j._async.io._common import Outbox +from neo4j._async.io._common import AsyncOutbox +from neo4j._codec.packstream.v1 import PackableBuffer + +from ...._async_compat import mark_async_test @pytest.mark.parametrize(("chunk_size", "data", "result"), ( ( 2, - (bytes(range(10, 15)),), + bytes(range(10, 15)), bytes((0, 2, 10, 11, 0, 2, 12, 13, 0, 1, 14)) ), ( 2, - (bytes(range(10, 14)),), + bytes(range(10, 14)), bytes((0, 2, 10, 11, 0, 2, 12, 13)) ), ( 2, - (bytes((5, 6, 7)), bytes((8, 9))), - bytes((0, 2, 5, 6, 0, 2, 7, 8, 0, 1, 9)) + bytes((5,)), + bytes((0, 1, 5)) ), )) -def test_async_outbox_chunking(chunk_size, data, result): - outbox = Outbox(max_chunk_size=chunk_size) - assert bytes(outbox.view()) == b"" - for d in data: - outbox.write(d) - assert bytes(outbox.view()) == result - # make sure this works multiple times - assert bytes(outbox.view()) == result - outbox.clear() - assert bytes(outbox.view()) == b"" +@mark_async_test 
+async def test_async_outbox_chunking(chunk_size, data, result, mocker): + buffer = PackableBuffer() + socket_mock = mocker.AsyncMock() + packer_mock = mocker.Mock() + packer_mock.return_value = packer_mock + packer_mock.new_packable_buffer.return_value = buffer + packer_mock.pack_struct.side_effect = \ + lambda *args, **kwargs: buffer.write(data) + outbox = AsyncOutbox(socket_mock, pytest.fail, packer_mock, chunk_size) + outbox.append_message(None, None, None) + socket_mock.sendall.assert_not_called() + assert await outbox.flush() + socket_mock.sendall.assert_awaited_once_with(result + b"\x00\x00") + + assert not await outbox.flush() + socket_mock.sendall.assert_awaited_once() diff --git a/tests/unit/async_/io/test_class_bolt3.py b/tests/unit/async_/io/test_class_bolt3.py index 8644aaded..22524e02b 100644 --- a/tests/unit/async_/io/test_class_bolt3.py +++ b/tests/unit/async_/io/test_class_bolt3.py @@ -72,7 +72,7 @@ def test_db_extra_not_supported_in_run(fake_socket): @mark_async_test async def test_simple_discard(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt3.UNPACKER_CLS) connection = AsyncBolt3(address, socket, PoolConfig.max_connection_lifetime) connection.discard() await connection.send_all() @@ -84,7 +84,7 @@ async def test_simple_discard(fake_socket): @mark_async_test async def test_simple_pull(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt3.UNPACKER_CLS) connection = AsyncBolt3(address, socket, PoolConfig.max_connection_lifetime) connection.pull() await connection.send_all() @@ -99,9 +99,11 @@ async def test_hint_recv_timeout_seconds_gets_ignored( fake_socket_pair, recv_timeout, mocker ): address = ("127.0.0.1", 7687) - sockets = fake_socket_pair(address) + sockets = fake_socket_pair( + address, AsyncBolt3.PACKER_CLS, AsyncBolt3.UNPACKER_CLS + ) sockets.client.settimeout = mocker.AsyncMock() - await 
sockets.server.send_message(0x70, { + await sockets.server.send_message(b"\x70", { "server": "Neo4j/3.5.0", "hints": {"connection.recv_timeout_seconds": recv_timeout}, }) diff --git a/tests/unit/async_/io/test_class_bolt4x0.py b/tests/unit/async_/io/test_class_bolt4x0.py index 7ba714e8a..64e1a8ddb 100644 --- a/tests/unit/async_/io/test_class_bolt4x0.py +++ b/tests/unit/async_/io/test_class_bolt4x0.py @@ -57,7 +57,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): @mark_async_test async def test_db_extra_in_begin(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x0.UNPACKER_CLS) connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.begin(db="something") await connection.send_all() @@ -70,7 +70,7 @@ async def test_db_extra_in_begin(fake_socket): @mark_async_test async def test_db_extra_in_run(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x0.UNPACKER_CLS) connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.run("", {}, db="something") await connection.send_all() @@ -85,7 +85,7 @@ async def test_db_extra_in_run(fake_socket): @mark_async_test async def test_n_extra_in_discard(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x0.UNPACKER_CLS) connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666) await connection.send_all() @@ -105,7 +105,7 @@ async def test_n_extra_in_discard(fake_socket): @mark_async_test async def test_qid_extra_in_discard(fake_socket, test_input, expected): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x0.UNPACKER_CLS) connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.discard(qid=test_input) await connection.send_all() @@ 
-125,7 +125,7 @@ async def test_qid_extra_in_discard(fake_socket, test_input, expected): @mark_async_test async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x0.UNPACKER_CLS) connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) await connection.send_all() @@ -145,7 +145,7 @@ async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): @mark_async_test async def test_n_extra_in_pull(fake_socket, test_input, expected): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x0.UNPACKER_CLS) connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=test_input) await connection.send_all() @@ -165,7 +165,7 @@ async def test_n_extra_in_pull(fake_socket, test_input, expected): @mark_async_test async def test_qid_extra_in_pull(fake_socket, test_input, expected): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x0.UNPACKER_CLS) connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.pull(qid=test_input) await connection.send_all() @@ -178,7 +178,7 @@ async def test_qid_extra_in_pull(fake_socket, test_input, expected): @mark_async_test async def test_n_and_qid_extras_in_pull(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x0.UNPACKER_CLS) connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) await connection.send_all() @@ -194,9 +194,11 @@ async def test_hint_recv_timeout_seconds_gets_ignored( fake_socket_pair, recv_timeout, mocker ): address = ("127.0.0.1", 7687) - sockets = fake_socket_pair(address) + sockets = fake_socket_pair(address, + 
packer_cls=AsyncBolt4x0.PACKER_CLS, + unpacker_cls=AsyncBolt4x0.UNPACKER_CLS) sockets.client.settimeout = mocker.MagicMock() - await sockets.server.send_message(0x70, { + await sockets.server.send_message(b"\x70", { "server": "Neo4j/4.0.0", "hints": {"connection.recv_timeout_seconds": recv_timeout}, }) diff --git a/tests/unit/async_/io/test_class_bolt4x1.py b/tests/unit/async_/io/test_class_bolt4x1.py index 47cca348e..a647fd5d7 100644 --- a/tests/unit/async_/io/test_class_bolt4x1.py +++ b/tests/unit/async_/io/test_class_bolt4x1.py @@ -57,7 +57,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): @mark_async_test async def test_db_extra_in_begin(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x1.UNPACKER_CLS) connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.begin(db="something") await connection.send_all() @@ -70,7 +70,7 @@ async def test_db_extra_in_begin(fake_socket): @mark_async_test async def test_db_extra_in_run(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x1.UNPACKER_CLS) connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.run("", {}, db="something") await connection.send_all() @@ -85,7 +85,7 @@ async def test_db_extra_in_run(fake_socket): @mark_async_test async def test_n_extra_in_discard(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x1.UNPACKER_CLS) connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666) await connection.send_all() @@ -105,7 +105,7 @@ async def test_n_extra_in_discard(fake_socket): @mark_async_test async def test_qid_extra_in_discard(fake_socket, test_input, expected): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x1.UNPACKER_CLS) connection = 
AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.discard(qid=test_input) await connection.send_all() @@ -126,7 +126,7 @@ async def test_qid_extra_in_discard(fake_socket, test_input, expected): async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_n_and_qid_extras_in_discard address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x1.UNPACKER_CLS) connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) await connection.send_all() @@ -146,7 +146,7 @@ async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): @mark_async_test async def test_n_extra_in_pull(fake_socket, test_input, expected): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x1.UNPACKER_CLS) connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=test_input) await connection.send_all() @@ -167,7 +167,7 @@ async def test_n_extra_in_pull(fake_socket, test_input, expected): async def test_qid_extra_in_pull(fake_socket, test_input, expected): # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_qid_extra_in_pull address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x1.UNPACKER_CLS) connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.pull(qid=test_input) await connection.send_all() @@ -180,7 +180,7 @@ async def test_qid_extra_in_pull(fake_socket, test_input, expected): @mark_async_test async def test_n_and_qid_extras_in_pull(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x1.UNPACKER_CLS) connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) 
await connection.send_all() @@ -193,15 +193,17 @@ async def test_n_and_qid_extras_in_pull(fake_socket): @mark_async_test async def test_hello_passes_routing_metadata(fake_socket_pair): address = ("127.0.0.1", 7687) - sockets = fake_socket_pair(address) - await sockets.server.send_message(0x70, {"server": "Neo4j/4.1.0"}) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt4x1.PACKER_CLS, + unpacker_cls=AsyncBolt4x1.UNPACKER_CLS) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.1.0"}) connection = AsyncBolt4x1( address, sockets.client, PoolConfig.max_connection_lifetime, routing_context={"foo": "bar"} ) await connection.hello() tag, fields = await sockets.server.pop_message() - assert tag == 0x01 + assert tag == b"\x01" assert len(fields) == 1 assert fields[0]["routing"] == {"foo": "bar"} @@ -212,9 +214,11 @@ async def test_hint_recv_timeout_seconds_gets_ignored( fake_socket_pair, recv_timeout, mocker ): address = ("127.0.0.1", 7687) - sockets = fake_socket_pair(address) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt4x1.PACKER_CLS, + unpacker_cls=AsyncBolt4x1.UNPACKER_CLS) sockets.client.settimeout = mocker.AsyncMock() - await sockets.server.send_message(0x70, { + await sockets.server.send_message(b"\x70", { "server": "Neo4j/4.1.0", "hints": {"connection.recv_timeout_seconds": recv_timeout}, }) diff --git a/tests/unit/async_/io/test_class_bolt4x2.py b/tests/unit/async_/io/test_class_bolt4x2.py index bb3921c8b..fb77b3733 100644 --- a/tests/unit/async_/io/test_class_bolt4x2.py +++ b/tests/unit/async_/io/test_class_bolt4x2.py @@ -57,7 +57,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): @mark_async_test async def test_db_extra_in_begin(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x2.UNPACKER_CLS) connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.begin(db="something") await connection.send_all() @@ -70,7 +70,7 @@ 
async def test_db_extra_in_begin(fake_socket): @mark_async_test async def test_db_extra_in_run(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x2.UNPACKER_CLS) connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.run("", {}, db="something") await connection.send_all() @@ -85,7 +85,7 @@ async def test_db_extra_in_run(fake_socket): @mark_async_test async def test_n_extra_in_discard(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x2.UNPACKER_CLS) connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666) await connection.send_all() @@ -105,7 +105,7 @@ async def test_n_extra_in_discard(fake_socket): @mark_async_test async def test_qid_extra_in_discard(fake_socket, test_input, expected): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x2.UNPACKER_CLS) connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.discard(qid=test_input) await connection.send_all() @@ -126,7 +126,7 @@ async def test_qid_extra_in_discard(fake_socket, test_input, expected): async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_n_and_qid_extras_in_discard address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x2.UNPACKER_CLS) connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) await connection.send_all() @@ -146,7 +146,7 @@ async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): @mark_async_test async def test_n_extra_in_pull(fake_socket, test_input, expected): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, 
AsyncBolt4x2.UNPACKER_CLS) connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=test_input) await connection.send_all() @@ -167,7 +167,7 @@ async def test_n_extra_in_pull(fake_socket, test_input, expected): async def test_qid_extra_in_pull(fake_socket, test_input, expected): # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_qid_extra_in_pull address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x2.UNPACKER_CLS) connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.pull(qid=test_input) await connection.send_all() @@ -180,7 +180,7 @@ async def test_qid_extra_in_pull(fake_socket, test_input, expected): @mark_async_test async def test_n_and_qid_extras_in_pull(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x2.UNPACKER_CLS) connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) await connection.send_all() @@ -193,15 +193,17 @@ async def test_n_and_qid_extras_in_pull(fake_socket): @mark_async_test async def test_hello_passes_routing_metadata(fake_socket_pair): address = ("127.0.0.1", 7687) - sockets = fake_socket_pair(address) - await sockets.server.send_message(0x70, {"server": "Neo4j/4.2.0"}) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt4x2.PACKER_CLS, + unpacker_cls=AsyncBolt4x2.UNPACKER_CLS) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.2.0"}) connection = AsyncBolt4x2( address, sockets.client, PoolConfig.max_connection_lifetime, routing_context={"foo": "bar"} ) await connection.hello() tag, fields = await sockets.server.pop_message() - assert tag == 0x01 + assert tag == b"\x01" assert len(fields) == 1 assert fields[0]["routing"] == {"foo": "bar"} @@ -212,9 +214,11 @@ async def test_hint_recv_timeout_seconds_gets_ignored( fake_socket_pair, recv_timeout, mocker ): 
address = ("127.0.0.1", 7687) - sockets = fake_socket_pair(address) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt4x2.PACKER_CLS, + unpacker_cls=AsyncBolt4x2.UNPACKER_CLS) sockets.client.settimeout = mocker.AsyncMock() - await sockets.server.send_message(0x70, { + await sockets.server.send_message(b"\x70", { "server": "Neo4j/4.2.0", "hints": {"connection.recv_timeout_seconds": recv_timeout}, }) diff --git a/tests/unit/async_/io/test_class_bolt4x3.py b/tests/unit/async_/io/test_class_bolt4x3.py index fff16687e..36698de1c 100644 --- a/tests/unit/async_/io/test_class_bolt4x3.py +++ b/tests/unit/async_/io/test_class_bolt4x3.py @@ -59,7 +59,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): @mark_async_test async def test_db_extra_in_begin(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x3.UNPACKER_CLS) connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.begin(db="something") await connection.send_all() @@ -72,7 +72,7 @@ async def test_db_extra_in_begin(fake_socket): @mark_async_test async def test_db_extra_in_run(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x3.UNPACKER_CLS) connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.run("", {}, db="something") await connection.send_all() @@ -87,7 +87,7 @@ async def test_db_extra_in_run(fake_socket): @mark_async_test async def test_n_extra_in_discard(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x3.UNPACKER_CLS) connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666) await connection.send_all() @@ -107,7 +107,7 @@ async def test_n_extra_in_discard(fake_socket): @mark_async_test async def test_qid_extra_in_discard(fake_socket, test_input, expected): address = ("127.0.0.1", 
7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x3.UNPACKER_CLS) connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.discard(qid=test_input) await connection.send_all() @@ -128,7 +128,7 @@ async def test_qid_extra_in_discard(fake_socket, test_input, expected): async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_n_and_qid_extras_in_discard address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x3.UNPACKER_CLS) connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) await connection.send_all() @@ -148,7 +148,7 @@ async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): @mark_async_test async def test_n_extra_in_pull(fake_socket, test_input, expected): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x3.UNPACKER_CLS) connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=test_input) await connection.send_all() @@ -169,7 +169,7 @@ async def test_n_extra_in_pull(fake_socket, test_input, expected): async def test_qid_extra_in_pull(fake_socket, test_input, expected): # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_qid_extra_in_pull address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x3.UNPACKER_CLS) connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.pull(qid=test_input) await connection.send_all() @@ -182,7 +182,7 @@ async def test_qid_extra_in_pull(fake_socket, test_input, expected): @mark_async_test async def test_n_and_qid_extras_in_pull(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x3.UNPACKER_CLS) 
connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) await connection.send_all() @@ -195,15 +195,17 @@ async def test_n_and_qid_extras_in_pull(fake_socket): @mark_async_test async def test_hello_passes_routing_metadata(fake_socket_pair): address = ("127.0.0.1", 7687) - sockets = fake_socket_pair(address) - await sockets.server.send_message(0x70, {"server": "Neo4j/4.3.0"}) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt4x3.PACKER_CLS, + unpacker_cls=AsyncBolt4x3.UNPACKER_CLS) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.0"}) connection = AsyncBolt4x3( address, sockets.client, PoolConfig.max_connection_lifetime, routing_context={"foo": "bar"} ) await connection.hello() tag, fields = await sockets.server.pop_message() - assert tag == 0x01 + assert tag == b"\x01" assert len(fields) == 1 assert fields[0]["routing"] == {"foo": "bar"} @@ -225,10 +227,12 @@ async def test_hint_recv_timeout_seconds( fake_socket_pair, hints, valid, caplog, mocker ): address = ("127.0.0.1", 7687) - sockets = fake_socket_pair(address) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt4x3.PACKER_CLS, + unpacker_cls=AsyncBolt4x3.UNPACKER_CLS) sockets.client.settimeout = mocker.AsyncMock() await sockets.server.send_message( - 0x70, {"server": "Neo4j/4.3.0", "hints": hints} + b"\x70", {"server": "Neo4j/4.3.0", "hints": hints} ) connection = AsyncBolt4x3( address, sockets.client, PoolConfig.max_connection_lifetime diff --git a/tests/unit/async_/io/test_class_bolt4x4.py b/tests/unit/async_/io/test_class_bolt4x4.py index 5507fbc7f..62a0dcec3 100644 --- a/tests/unit/async_/io/test_class_bolt4x4.py +++ b/tests/unit/async_/io/test_class_bolt4x4.py @@ -68,7 +68,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): @mark_async_test async def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, 
AsyncBolt4x4.UNPACKER_CLS) connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.begin(*args, **kwargs) await connection.send_all() @@ -89,7 +89,7 @@ async def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): @mark_async_test async def test_extra_in_run(fake_socket, args, kwargs, expected_fields): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x4.UNPACKER_CLS) connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.run(*args, **kwargs) await connection.send_all() @@ -101,7 +101,7 @@ async def test_extra_in_run(fake_socket, args, kwargs, expected_fields): @mark_async_test async def test_n_extra_in_discard(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x4.UNPACKER_CLS) connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666) await connection.send_all() @@ -121,7 +121,7 @@ async def test_n_extra_in_discard(fake_socket): @mark_async_test async def test_qid_extra_in_discard(fake_socket, test_input, expected): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x4.UNPACKER_CLS) connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.discard(qid=test_input) await connection.send_all() @@ -142,7 +142,7 @@ async def test_qid_extra_in_discard(fake_socket, test_input, expected): async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_n_and_qid_extras_in_discard address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x4.UNPACKER_CLS) connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) await connection.send_all() @@ -162,7 +162,7 @@ 
async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): @mark_async_test async def test_n_extra_in_pull(fake_socket, test_input, expected): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x4.UNPACKER_CLS) connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=test_input) await connection.send_all() @@ -183,7 +183,7 @@ async def test_n_extra_in_pull(fake_socket, test_input, expected): async def test_qid_extra_in_pull(fake_socket, test_input, expected): # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_qid_extra_in_pull address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x4.UNPACKER_CLS) connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.pull(qid=test_input) await connection.send_all() @@ -196,7 +196,7 @@ async def test_qid_extra_in_pull(fake_socket, test_input, expected): @mark_async_test async def test_n_and_qid_extras_in_pull(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, AsyncBolt4x4.UNPACKER_CLS) connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) await connection.send_all() @@ -209,15 +209,17 @@ async def test_n_and_qid_extras_in_pull(fake_socket): @mark_async_test async def test_hello_passes_routing_metadata(fake_socket_pair): address = ("127.0.0.1", 7687) - sockets = fake_socket_pair(address) - await sockets.server.send_message(0x70, {"server": "Neo4j/4.4.0"}) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt4x4.PACKER_CLS, + unpacker_cls=AsyncBolt4x4.UNPACKER_CLS) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.4.0"}) connection = AsyncBolt4x4( address, sockets.client, PoolConfig.max_connection_lifetime, routing_context={"foo": "bar"} ) await connection.hello() tag, fields = await 
sockets.server.pop_message() - assert tag == 0x01 + assert tag == b"\x01" assert len(fields) == 1 assert fields[0]["routing"] == {"foo": "bar"} @@ -239,10 +241,12 @@ async def test_hint_recv_timeout_seconds( fake_socket_pair, hints, valid, caplog, mocker ): address = ("127.0.0.1", 7687) - sockets = fake_socket_pair(address) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt4x4.PACKER_CLS, + unpacker_cls=AsyncBolt4x4.UNPACKER_CLS) sockets.client.settimeout = mocker.MagicMock() await sockets.server.send_message( - 0x70, {"server": "Neo4j/4.3.4", "hints": hints} + b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} ) connection = AsyncBolt4x4( address, sockets.client, PoolConfig.max_connection_lifetime diff --git a/tests/unit/async_/work/test_result.py b/tests/unit/async_/work/test_result.py index 30ffe5832..50fb3c932 100644 --- a/tests/unit/async_/work/test_result.py +++ b/tests/unit/async_/work/test_result.py @@ -14,9 +14,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from re import match -from unittest import mock + + import warnings +from unittest import mock import pandas as pd import pytest @@ -33,9 +34,9 @@ Version, ) from neo4j._async_compat.util import AsyncUtil -from neo4j.data import ( - DataDehydrator, - DataHydrator, +from neo4j._codec.hydration.v1 import HydrationHandler +from neo4j._codec.packstream import Structure +from neo4j._data import ( Node, Relationship, ) @@ -44,7 +45,6 @@ EntitySetView, Graph, ) -from neo4j.packstream import Structure from ...._async_compat import mark_async_test @@ -52,9 +52,24 @@ class Records: def __init__(self, fields, records): self.fields = tuple(fields) + self.hydration_scope = HydrationHandler().new_hydration_scope() self.records = tuple(records) + self._hydrate_records() + assert all(len(self.fields) == len(r) for r in self.records) + def _hydrate_records(self): + def _hydrate(value): + if type(value) in self.hydration_scope.hydration_hooks: + return self.hydration_scope.hydration_hooks[type(value)](value) + if isinstance(value, (list, tuple)): + return type(value)(_hydrate(v) for v in value) + if isinstance(value, dict): + return {k: _hydrate(v) for k, v in value.items()} + return value + + self.records = tuple(_hydrate(r) for r in self.records) + def __len__(self): return self.records.__len__() @@ -113,6 +128,7 @@ def __init__(self, records=None, run_meta=None, summary_meta=None, self.summary_meta = summary_meta AsyncConnectionStub.server_info.update({"server": "Neo4j/4.3.0"}) self.unresolved_address = None + self._new_hydration_scope_called = False async def send_all(self): self.sent += self.queued @@ -187,10 +203,20 @@ def pull(self, *args, **kwargs): def defunct(self): return False + def new_hydration_scope(self): + class FakeHydrationScope: + hydration_hooks = None + dehydration_hooks = None -class HydratorStub(DataHydrator): - def hydrate(self, values): - return values + def get_graph(self): + return Graph() + + if len(self._records) > 1: + return FakeHydrationScope() + 
assert not self._new_hydration_scope_called + assert self._records + self._new_hydration_scope_called = True + return self._records[0].hydration_scope def noop(*_, **__): @@ -254,7 +280,7 @@ async def fetch_and_compare_all_records( @mark_async_test async def test_result_iteration(method, records): connection = AsyncConnectionStub(records=Records(["x"], records)) - result = AsyncResult(connection, HydratorStub(), 2, noop, noop) + result = AsyncResult(connection, 2, noop, noop) await result._run("CYPHER", {}, None, None, "r", None) await fetch_and_compare_all_records(result, "x", records, method) @@ -263,7 +289,7 @@ async def test_result_iteration(method, records): async def test_result_iteration_mixed_methods(): records = [[i] for i in range(10)] connection = AsyncConnectionStub(records=Records(["x"], records)) - result = AsyncResult(connection, HydratorStub(), 4, noop, noop) + result = AsyncResult(connection, 4, noop, noop) await result._run("CYPHER", {}, None, None, "r", None) iter1 = AsyncUtil.iter(result) iter2 = AsyncUtil.iter(result) @@ -299,9 +325,9 @@ async def test_parallel_result_iteration(method, invert_fetch): connection = AsyncConnectionStub( records=(Records(["x"], records1), Records(["x"], records2)) ) - result1 = AsyncResult(connection, HydratorStub(), 2, noop, noop) + result1 = AsyncResult(connection, 2, noop, noop) await result1._run("CYPHER1", {}, None, None, "r", None) - result2 = AsyncResult(connection, HydratorStub(), 2, noop, noop) + result2 = AsyncResult(connection, 2, noop, noop) await result2._run("CYPHER2", {}, None, None, "r", None) if invert_fetch: await fetch_and_compare_all_records( @@ -329,9 +355,9 @@ async def test_interwoven_result_iteration(method, invert_fetch): connection = AsyncConnectionStub( records=(Records(["x"], records1), Records(["y"], records2)) ) - result1 = AsyncResult(connection, HydratorStub(), 2, noop, noop) + result1 = AsyncResult(connection, 2, noop, noop) await result1._run("CYPHER1", {}, None, None, "r", None) - 
result2 = AsyncResult(connection, HydratorStub(), 2, noop, noop) + result2 = AsyncResult(connection, 2, noop, noop) await result2._run("CYPHER2", {}, None, None, "r", None) start = 0 for n in (1, 2, 3, 1, None): @@ -358,7 +384,7 @@ async def test_interwoven_result_iteration(method, invert_fetch): @mark_async_test async def test_result_peek(records, fetch_size): connection = AsyncConnectionStub(records=Records(["x"], records)) - result = AsyncResult(connection, HydratorStub(), fetch_size, noop, noop) + result = AsyncResult(connection, fetch_size, noop, noop) await result._run("CYPHER", {}, None, None, "r", None) for i in range(len(records) + 1): record = await result.peek() @@ -381,7 +407,7 @@ async def test_result_single_non_strict(records, fetch_size, default): kwargs["strict"] = False connection = AsyncConnectionStub(records=Records(["x"], records)) - result = AsyncResult(connection, HydratorStub(), fetch_size, noop, noop) + result = AsyncResult(connection, fetch_size, noop, noop) await result._run("CYPHER", {}, None, None, "r", None) if len(records) == 0: assert await result.single(**kwargs) is None @@ -400,7 +426,7 @@ async def test_result_single_non_strict(records, fetch_size, default): @mark_async_test async def test_result_single_strict(records, fetch_size): connection = AsyncConnectionStub(records=Records(["x"], records)) - result = AsyncResult(connection, HydratorStub(), fetch_size, noop, noop) + result = AsyncResult(connection, fetch_size, noop, noop) await result._run("CYPHER", {}, None, None, "r", None) try: record = await result.single(strict=True) @@ -427,7 +453,7 @@ async def test_result_single_strict(records, fetch_size): @mark_async_test async def test_result_single_exhausts_records(records, fetch_size, strict): connection = AsyncConnectionStub(records=Records(["x"], records)) - result = AsyncResult(connection, HydratorStub(), fetch_size, noop, noop) + result = AsyncResult(connection, fetch_size, noop, noop) await result._run("CYPHER", {}, None, 
None, "r", None) try: with warnings.catch_warnings(): @@ -449,7 +475,7 @@ async def test_result_single_exhausts_records(records, fetch_size, strict): @mark_async_test async def test_result_fetch(records, fetch_size, strict): connection = AsyncConnectionStub(records=Records(["x"], records)) - result = AsyncResult(connection, HydratorStub(), fetch_size, noop, noop) + result = AsyncResult(connection, fetch_size, noop, noop) await result._run("CYPHER", {}, None, None, "r", None) assert await result.fetch(0) == [] assert await result.fetch(-1) == [] @@ -461,7 +487,7 @@ async def test_result_fetch(records, fetch_size, strict): @mark_async_test async def test_keys_are_available_before_and_after_stream(): connection = AsyncConnectionStub(records=Records(["x"], [[1], [2]])) - result = AsyncResult(connection, HydratorStub(), 1, noop, noop) + result = AsyncResult(connection, 1, noop, noop) await result._run("CYPHER", {}, None, None, "r", None) assert list(result.keys()) == ["x"] await AsyncUtil.list(result) @@ -477,7 +503,7 @@ async def test_consume(records, consume_one, summary_meta, consume_times): connection = AsyncConnectionStub( records=Records(["x"], records), summary_meta=summary_meta ) - result = AsyncResult(connection, HydratorStub(), 1, noop, noop) + result = AsyncResult(connection, 1, noop, noop) await result._run("CYPHER", {}, None, None, "r", None) if consume_one: try: @@ -512,7 +538,7 @@ async def test_time_in_summary(t_first, t_last): summary_meta=summary_meta ) - result = AsyncResult(connection, HydratorStub(), 1, noop, noop) + result = AsyncResult(connection, 1, noop, noop) await result._run("CYPHER", {}, None, None, "r", None) summary = await result.consume() @@ -534,7 +560,7 @@ async def test_time_in_summary(t_first, t_last): async def test_counts_in_summary(): connection = AsyncConnectionStub(records=Records(["n"], [[1], [2]])) - result = AsyncResult(connection, HydratorStub(), 1, noop, noop) + result = AsyncResult(connection, 1, noop, noop) await 
result._run("CYPHER", {}, None, None, "r", None) summary = await result.consume() @@ -548,7 +574,7 @@ async def test_query_type(query_type): records=Records(["n"], [[1], [2]]), summary_meta={"type": query_type} ) - result = AsyncResult(connection, HydratorStub(), 1, noop, noop) + result = AsyncResult(connection, 1, noop, noop) await result._run("CYPHER", {}, None, None, "r", None) summary = await result.consume() @@ -563,7 +589,7 @@ async def test_data(num_records): records=Records(["n"], [[i + 1] for i in range(num_records)]) ) - result = AsyncResult(connection, HydratorStub(), 1, noop, noop) + result = AsyncResult(connection, 1, noop, noop) await result._run("CYPHER", {}, None, None, "r", None) await result._buffer_all() records = result._record_buffer.copy() @@ -578,6 +604,7 @@ async def test_data(num_records): assert record.data.called_once_with("hello", "world") +# TODO: dehydration now happens on a much lower level @pytest.mark.parametrize("records", ( Records(["n"], []), Records(["n"], [[42], [69], [420], [1337]]), @@ -603,8 +630,9 @@ async def test_result_graph(records, async_scripted_connection): "on_summary": None }), )) - result = AsyncResult(async_scripted_connection, DataHydrator(), 1, noop, - noop) + async_scripted_connection.new_hydration_scope.return_value = \ + records.hydration_scope + result = AsyncResult(async_scripted_connection, 1, noop, noop) await result._run("CYPHER", {}, None, None, "r", None) graph = await result.graph() assert isinstance(graph, Graph) @@ -702,7 +730,7 @@ async def test_result_graph(records, async_scripted_connection): @mark_async_test async def test_to_df(keys, values, types, instances, test_default_expand): connection = AsyncConnectionStub(records=Records(keys, values)) - result = AsyncResult(connection, DataHydrator(), 1, noop, noop) + result = AsyncResult(connection, 1, noop, noop) await result._run("CYPHER", {}, None, None, "r", None) if test_default_expand: df = await result.to_df() @@ -807,12 +835,12 @@ async def 
test_to_df(keys, values, types, instances, test_default_expand): ( ["n"], list(zip(( - Structure(b"N", 0, ["LABEL_A"], - {"a": 1, "b": 2, "d": 1}, "00"), - Structure(b"N", 2, ["LABEL_B"], - {"a": 1, "c": 1.2, "d": 2}, "02"), - Structure(b"N", 1, ["LABEL_A", "LABEL_B"], - {"a": [1, "a"], "d": 3}, "01"), + Node(None, "00", 0, ["LABEL_A"], + {"a": 1, "b": 2, "d": 1}), + Node(None, "02", 2, ["LABEL_B"], + {"a": 1, "c": 1.2, "d": 2}), + Node(None, "01", 1, ["LABEL_A", "LABEL_B"], + {"a": [1, "a"], "d": 3}), ))), [ "n().element_id", "n().labels", "n().prop.a", "n().prop.b", @@ -848,11 +876,7 @@ async def test_to_df(keys, values, types, instances, test_default_expand): ), ( ["dt"], - [ - DataDehydrator().dehydrate(( - neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6), - )), - ], + [[neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6)]], ["dt"], [[neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6)]], ["object"], @@ -863,7 +887,7 @@ async def test_to_df(keys, values, types, instances, test_default_expand): async def test_to_df_expand(keys, values, expected_columns, expected_rows, expected_types): connection = AsyncConnectionStub(records=Records(keys, values)) - result = AsyncResult(connection, DataHydrator(), 1, noop, noop) + result = AsyncResult(connection, 1, noop, noop) await result._run("CYPHER", {}, None, None, "r", None) df = await result.to_df(expand=True) @@ -895,9 +919,7 @@ async def test_to_df_expand(keys, values, expected_columns, expected_rows, ( ["dt"], [ - DataDehydrator().dehydrate(( - neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6), - )), + [neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6)], ], pd.DataFrame( [[pd.Timestamp("2022-01-02 03:04:05.000000006")]], @@ -908,9 +930,7 @@ async def test_to_df_expand(keys, values, expected_columns, expected_rows, ( ["d"], [ - DataDehydrator().dehydrate(( - neo4j_time.Date(2222, 2, 22), - )), + [neo4j_time.Date(2222, 2, 22)], ], pd.DataFrame( [[pd.Timestamp("2222-02-22")]], @@ -921,11 +941,11 @@ async def test_to_df_expand(keys, values, 
expected_columns, expected_rows, ( ["dt_tz"], [ - DataDehydrator().dehydrate(( + [ pytz.timezone("Europe/Stockholm").localize( neo4j_time.DateTime(1970, 1, 1, 0, 0, 0, 0) ), - )), + ], ], pd.DataFrame( [[ @@ -941,17 +961,13 @@ async def test_to_df_expand(keys, values, expected_columns, expected_rows, ["mixed"], [ [None], - DataDehydrator().dehydrate(( - neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6), - )), - DataDehydrator().dehydrate(( - neo4j_time.Date(2222, 2, 22), - )), - DataDehydrator().dehydrate(( + [neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6)], + [neo4j_time.Date(2222, 2, 22)], + [ pytz.timezone("Europe/Stockholm").localize( neo4j_time.DateTime(1970, 1, 1, 0, 0, 0, 0) ), - )), + ], ], pd.DataFrame( [ @@ -971,18 +987,14 @@ async def test_to_df_expand(keys, values, expected_columns, expected_rows, ( ["mixed"], [ - DataDehydrator().dehydrate(( - neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6), - )), - DataDehydrator().dehydrate(( - neo4j_time.Date(2222, 2, 22), - )), + [neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6)], + [neo4j_time.Date(2222, 2, 22)], [None], - DataDehydrator().dehydrate(( + [ pytz.timezone("Europe/Stockholm").localize( neo4j_time.DateTime(1970, 1, 1, 0, 0, 0, 0) ), - )), + ], ], pd.DataFrame( [ @@ -1002,17 +1014,13 @@ async def test_to_df_expand(keys, values, expected_columns, expected_rows, ( ["mixed"], [ - DataDehydrator().dehydrate(( - neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6), - )), - DataDehydrator().dehydrate(( - neo4j_time.Date(2222, 2, 22), - )), - DataDehydrator().dehydrate(( + [neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6),], + [neo4j_time.Date(2222, 2, 22),], + [ pytz.timezone("Europe/Stockholm").localize( neo4j_time.DateTime(1970, 1, 1, 0, 0, 0, 0) ), - )), + ], [None], ], pd.DataFrame( @@ -1052,9 +1060,7 @@ async def test_to_df_expand(keys, values, expected_columns, expected_rows, ], [ None, - *DataDehydrator().dehydrate(( - neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6), - )), + neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6), 1.234, ], ], @@ 
-1080,7 +1086,7 @@ async def test_to_df_expand(keys, values, expected_columns, expected_rows, @mark_async_test async def test_to_df_parse_dates(keys, values, expected_df, expand): connection = AsyncConnectionStub(records=Records(keys, values)) - result = AsyncResult(connection, DataHydrator(), 1, noop, noop) + result = AsyncResult(connection, 1, noop, noop) await result._run("CYPHER", {}, None, None, "r", None) df = await result.to_df(expand=expand, parse_dates=True) diff --git a/tests/unit/async_/work/test_session.py b/tests/unit/async_/work/test_session.py index 3dcb03828..d44ceb083 100644 --- a/tests/unit/async_/work/test_session.py +++ b/tests/unit/async_/work/test_session.py @@ -37,7 +37,15 @@ @pytest.fixture() def pool(async_fake_connection_generator, mocker): pool = mocker.AsyncMock(spec=AsyncIOPool) - pool.acquire.side_effect = iter(async_fake_connection_generator, 0) + assert not hasattr(pool, "acquired_connection_mocks") + pool.acquired_connection_mocks = [] + + def acquire_side_effect(*_, **__): + connection = async_fake_connection_generator() + pool.acquired_connection_mocks.append(connection) + return connection + + pool.acquire.side_effect = acquire_side_effect return pool @@ -267,57 +275,46 @@ async def test_session_tx_type(pool): assert isinstance(tx, AsyncTransaction) -@pytest.mark.parametrize(("parameters", "error_type"), ( - ({"x": None}, None), - ({"x": True}, None), - ({"x": False}, None), - ({"x": 123456789}, None), - ({"x": 3.1415926}, None), - ({"x": float("nan")}, None), - ({"x": float("inf")}, None), - ({"x": float("-inf")}, None), - ({"x": "foo"}, None), - ({"x": bytearray([0x00, 0x33, 0x66, 0x99, 0xCC, 0xFF])}, None), - ({"x": b"\x00\x33\x66\x99\xcc\xff"}, None), - ({"x": [1, 2, 3]}, None), - ({"x": ["a", "b", "c"]}, None), - ({"x": ["a", 2, 1.234]}, None), - ({"x": ["a", 2, ["c"]]}, None), - ({"x": {"one": "eins", "two": "zwei", "three": "drei"}}, None), - ({"x": {"one": ["eins", "uno", 1], "two": ["zwei", "dos", 2]}}, None), - - # maps 
must have string keys - ({"x": {1: 'eins', 2: 'zwei', 3: 'drei'}}, TypeError), - ({"x": {(1, 2): '1+2i', (2, 0): '2'}}, TypeError), +@pytest.mark.parametrize("parameters", ( + {"x": None}, + {"x": True}, + {"x": False}, + {"x": 123456789}, + {"x": 3.1415926}, + {"x": float("nan")}, + {"x": float("inf")}, + {"x": float("-inf")}, + {"x": "foo"}, + {"x": bytearray([0x00, 0x33, 0x66, 0x99, 0xCC, 0xFF])}, + {"x": b"\x00\x33\x66\x99\xcc\xff"}, + {"x": [1, 2, 3]}, + {"x": ["a", "b", "c"]}, + {"x": ["a", 2, 1.234]}, + {"x": ["a", 2, ["c"]]}, + {"x": {"one": "eins", "two": "zwei", "three": "drei"}}, + {"x": {"one": ["eins", "uno", 1], "two": ["zwei", "dos", 2]}}, )) @pytest.mark.parametrize("run_type", ("auto", "unmanaged", "managed")) @mark_async_test async def test_session_run_with_parameters( - pool, parameters, error_type, run_type + pool, parameters, run_type, mocker ): - @contextmanager - def raises(): - if error_type is not None: - with pytest.raises(error_type) as exc: - yield exc - else: - yield None - async with AsyncSession(pool, SessionConfig()) as session: if run_type == "auto": - with raises(): - await session.run("RETURN $x", **parameters) + await session.run("RETURN $x", **parameters) elif run_type == "unmanaged": tx = await session.begin_transaction() - with raises(): - await tx.run("RETURN $x", **parameters) + await tx.run("RETURN $x", **parameters) elif run_type == "managed": async def work(tx): - with raises() as exc: - await tx.run("RETURN $x", **parameters) - if exc is not None: - raise exc - with raises(): - await session.write_transaction(work) + await tx.run("RETURN $x", **parameters) + await session.write_transaction(work) else: raise ValueError(run_type) + + assert len(pool.acquired_connection_mocks) == 1 + connection_mock = pool.acquired_connection_mocks[0] + assert connection_mock.run.called_once() + call = connection_mock.run.call_args + assert call.args[0] == "RETURN $x" + assert call.kwargs["parameters"] == parameters diff --git 
a/tests/unit/async_/work/test_transaction.py b/tests/unit/async_/work/test_transaction.py index 86e968cf1..7fa36ab76 100644 --- a/tests/unit/async_/work/test_transaction.py +++ b/tests/unit/async_/work/test_transaction.py @@ -113,23 +113,6 @@ class OopsError(RuntimeError): assert tx_.closed() -@pytest.mark.parametrize(("parameters", "error_type"), ( - # maps must have string keys - ({"x": {1: 'eins', 2: 'zwei', 3: 'drei'}}, TypeError), - ({"x": {(1, 2): '1+2i', (2, 0): '2'}}, TypeError), - ({"x": uuid4()}, TypeError), -)) -@mark_async_test -async def test_transaction_run_with_invalid_parameters( - async_fake_connection, parameters, error_type -): - on_closed = MagicMock() - on_error = MagicMock() - tx = AsyncTransaction(async_fake_connection, 2, on_closed, on_error) - with pytest.raises(error_type): - await tx.run("RETURN $x", **parameters) - - @mark_async_test async def test_transaction_run_takes_no_query_object(async_fake_connection): on_closed = MagicMock() diff --git a/tests/unit/common/codec/__init__.py b/tests/unit/common/codec/__init__.py new file mode 100644 index 000000000..c42cc6fb6 --- /dev/null +++ b/tests/unit/common/codec/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/tests/unit/common/codec/hydration/__init__.py b/tests/unit/common/codec/hydration/__init__.py new file mode 100644 index 000000000..c42cc6fb6 --- /dev/null +++ b/tests/unit/common/codec/hydration/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/common/codec/hydration/_common.py b/tests/unit/common/codec/hydration/_common.py new file mode 100644 index 000000000..a0c924c62 --- /dev/null +++ b/tests/unit/common/codec/hydration/_common.py @@ -0,0 +1,71 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from datetime import ( + date, + datetime, + time, + timedelta, +) + +import pytest + +from neo4j._codec.hydration import HydrationScope +from neo4j._codec.hydration.v1 import HydrationHandler as HydrationHandlerV1 +from neo4j._codec.hydration.v2 import HydrationHandler as HydrationHandlerV2 +from neo4j._codec.packstream import Structure +from neo4j.spatial import ( + CartesianPoint, + Point, + WGS84Point, +) +from neo4j.time import ( + Date, + DateTime, + Duration, + Time, +) + + +class HydrationHandlerTestBase: + + @pytest.fixture(params=[HydrationHandlerV1, HydrationHandlerV2]) + def hydration_handler(self, request): + return request.param() + + def test_handler_hydration_scope(self, hydration_handler): + scope = hydration_handler.new_hydration_scope() + assert isinstance(scope, HydrationScope) + + @pytest.fixture + def hydration_scope(self, hydration_handler): + return hydration_handler.new_hydration_scope() + + def test_scope_hydration_keys(self, hydration_scope): + hooks = hydration_scope.hydration_hooks + assert isinstance(hooks, dict) + assert set(hooks.keys()) == {Structure} + + def test_scope_dehydration_keys(self, hydration_scope): + hooks = hydration_scope.dehydration_hooks + assert isinstance(hooks, dict) + assert set(hooks.keys()) == { + date, datetime, time, timedelta, + Date, DateTime, Duration, Time, + CartesianPoint, Point, WGS84Point + } diff --git a/tests/unit/common/codec/hydration/v1/__init__.py b/tests/unit/common/codec/hydration/v1/__init__.py new file mode 100644 index 000000000..c42cc6fb6 --- /dev/null +++ b/tests/unit/common/codec/hydration/v1/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/common/codec/hydration/v1/_base.py b/tests/unit/common/codec/hydration/v1/_base.py new file mode 100644 index 000000000..4400d3480 --- /dev/null +++ b/tests/unit/common/codec/hydration/v1/_base.py @@ -0,0 +1,29 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + + +class HydrationHandlerTestBase: + @pytest.fixture() + def hydration_handler(self): + raise NotImplementedError() + + @pytest.fixture + def hydration_scope(self, hydration_handler): + return hydration_handler.new_hydration_scope() diff --git a/tests/unit/common/codec/hydration/v1/test_graph_hydration.py b/tests/unit/common/codec/hydration/v1/test_graph_hydration.py new file mode 100644 index 000000000..92f0d5afb --- /dev/null +++ b/tests/unit/common/codec/hydration/v1/test_graph_hydration.py @@ -0,0 +1,67 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +from neo4j._codec.hydration.v1 import HydrationHandler +from neo4j._codec.packstream import Structure +from neo4j.graph import ( + Graph, + Node, + Relationship, +) + +from ._base import HydrationHandlerTestBase + + +class TestGraphHydration(HydrationHandlerTestBase): + @pytest.fixture + def hydration_handler(self): + return HydrationHandler() + + def test_can_hydrate_node_structure(self, hydration_scope): + struct = Structure(b'N', 123, ["Person"], {"name": "Alice"}) + alice = hydration_scope.hydration_hooks[Structure](struct) + + assert isinstance(alice, Node) + with pytest.warns(DeprecationWarning, match="element_id"): + assert alice.id == 123 + # for backwards compatibility, the driver should copy the element_id + assert alice.element_id == "123" + assert alice.labels == {"Person"} + assert set(alice.keys()) == {"name"} + assert alice.get("name") == "Alice" + + def test_can_hydrate_relationship_structure(self, hydration_scope): + struct = Structure(b'R', 123, 456, 789, "KNOWS", {"since": 1999}) + rel = hydration_scope.hydration_hooks[Structure](struct) + + assert isinstance(rel, Relationship) + with pytest.warns(DeprecationWarning, match="element_id"): + assert rel.id == 123 + with pytest.warns(DeprecationWarning, match="element_id"): + assert rel.start_node.id == 456 + with pytest.warns(DeprecationWarning, match="element_id"): + assert rel.end_node.id == 789 + # for backwards compatibility,
the driver should copy the element_id + assert rel.element_id == "123" + assert rel.start_node.element_id == "456" + assert rel.end_node.element_id == "789" + assert rel.type == "KNOWS" + assert set(rel.keys()) == {"since"} + assert rel.get("since") == 1999 diff --git a/tests/unit/common/codec/hydration/v1/test_hydration_handler.py b/tests/unit/common/codec/hydration/v1/test_hydration_handler.py new file mode 100644 index 000000000..eccc10e18 --- /dev/null +++ b/tests/unit/common/codec/hydration/v1/test_hydration_handler.py @@ -0,0 +1,78 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ + +from datetime import ( + date, + datetime, + time, + timedelta, +) + +import pytest + +from neo4j._codec.hydration import HydrationScope +from neo4j._codec.hydration.v1 import HydrationHandler +from neo4j._codec.packstream import Structure +from neo4j.graph import Graph +from neo4j.spatial import ( + CartesianPoint, + Point, + WGS84Point, +) +from neo4j.time import ( + Date, + DateTime, + Duration, + Time, +) + +from ._base import HydrationHandlerTestBase + + +class TestHydrationHandler(HydrationHandlerTestBase): + @pytest.fixture + def hydration_handler(self): + return HydrationHandler() + + def test_handler_hydration_scope(self, hydration_handler): + scope = hydration_handler.new_hydration_scope() + assert isinstance(scope, HydrationScope) + + @pytest.fixture + def hydration_scope(self, hydration_handler): + return hydration_handler.new_hydration_scope() + + def test_scope_hydration_keys(self, hydration_scope): + hooks = hydration_scope.hydration_hooks + assert isinstance(hooks, dict) + assert set(hooks.keys()) == {Structure} + + def test_scope_dehydration_keys(self, hydration_scope): + hooks = hydration_scope.dehydration_hooks + assert isinstance(hooks, dict) + assert set(hooks.keys()) == { + date, datetime, time, timedelta, + Date, DateTime, Duration, Time, + CartesianPoint, Point, WGS84Point + } + + def test_scope_get_graph(self, hydration_scope): + graph = hydration_scope.get_graph() + assert isinstance(graph, Graph) + assert not graph.nodes + assert not graph.relationships diff --git a/tests/unit/common/codec/hydration/v1/test_spacial_dehydration.py b/tests/unit/common/codec/hydration/v1/test_spacial_dehydration.py new file mode 100644 index 000000000..6486cea52 --- /dev/null +++ b/tests/unit/common/codec/hydration/v1/test_spacial_dehydration.py @@ -0,0 +1,73 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +from neo4j._codec.hydration.v1 import HydrationHandler +from neo4j._codec.packstream import Structure +from neo4j.spatial import ( + CartesianPoint, + Point, + WGS84Point, +) + +from ._base import HydrationHandlerTestBase + + +class TestSpatialDehydration(HydrationHandlerTestBase): + @pytest.fixture + def hydration_handler(self): + return HydrationHandler() + + def test_cartesian_2d(self, hydration_scope): + point = CartesianPoint((1, 3.1)) + struct = hydration_scope.dehydration_hooks[type(point)](point) + assert struct == Structure(b"X", 7203, 1.0, 3.1) + assert all(isinstance(f, float) for f in struct.fields[1:]) + + def test_cartesian_3d(self, hydration_scope): + point = CartesianPoint((1, -2, 3.1)) + struct = hydration_scope.dehydration_hooks[type(point)](point) + assert struct == Structure(b"Y", 9157, 1.0, -2.0, 3.1) + assert all(isinstance(f, float) for f in struct.fields[1:]) + + def test_wgs84_2d(self, hydration_scope): + point = WGS84Point((1, 3.1)) + struct = hydration_scope.dehydration_hooks[type(point)](point) + assert struct == Structure(b"X", 4326, 1.0, 3.1) + assert all(isinstance(f, float) for f in struct.fields[1:]) + + def test_wgs84_3d(self, hydration_scope): + point = WGS84Point((1, -2, 3.1)) + struct = hydration_scope.dehydration_hooks[type(point)](point) + assert struct == Structure(b"Y", 4979, 1.0, -2.0, 3.1) + assert all(isinstance(f, float) for f in struct.fields[1:]) + + 
def test_custom_point_2d(self, hydration_scope): + point = Point((1, 3.1)) + point.srid = 12345 + struct = hydration_scope.dehydration_hooks[type(point)](point) + assert struct == Structure(b"X", 12345, 1.0, 3.1) + assert all(isinstance(f, float) for f in struct.fields[1:]) + + def test_custom_point_3d(self, hydration_scope): + point = Point((1, -2, 3.1)) + point.srid = 12345 + struct = hydration_scope.dehydration_hooks[type(point)](point) + assert struct == Structure(b"Y", 12345, 1.0, -2.0, 3.1) + assert all(isinstance(f, float) for f in struct.fields[1:]) diff --git a/tests/unit/common/codec/hydration/v1/test_spacial_hydration.py b/tests/unit/common/codec/hydration/v1/test_spacial_hydration.py new file mode 100644 index 000000000..ef4fad6b8 --- /dev/null +++ b/tests/unit/common/codec/hydration/v1/test_spacial_hydration.py @@ -0,0 +1,77 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import pytest + +from neo4j._codec.hydration.v1 import HydrationHandler +from neo4j._codec.packstream import Structure +from neo4j.spatial import ( + CartesianPoint, + Point, + WGS84Point, +) + +from ._base import HydrationHandlerTestBase + + +class TestSpatialHydration(HydrationHandlerTestBase): + @pytest.fixture + def hydration_handler(self): + return HydrationHandler() + + def test_cartesian_2d(self, hydration_scope): + struct = Structure(b"X", 7203, 1.0, 3.1) + point = hydration_scope.hydration_hooks[Structure](struct) + assert isinstance(point, CartesianPoint) + assert point.srid == 7203 + assert tuple(point) == (1.0, 3.1) + + def test_cartesian_3d(self, hydration_scope): + struct = Structure(b"Y", 9157, 1.0, -2.0, 3.1) + point = hydration_scope.hydration_hooks[Structure](struct) + assert isinstance(point, CartesianPoint) + assert point.srid == 9157 + assert tuple(point) == (1.0, -2.0, 3.1) + + def test_wgs84_2d(self, hydration_scope): + struct = Structure(b"X", 4326, 1.0, 3.1) + point = hydration_scope.hydration_hooks[Structure](struct) + assert isinstance(point, WGS84Point) + assert point.srid == 4326 + assert tuple(point) == (1.0, 3.1) + + def test_wgs84_3d(self, hydration_scope): + struct = Structure(b"Y", 4979, 1.0, -2.0, 3.1) + point = hydration_scope.hydration_hooks[Structure](struct) + assert isinstance(point, WGS84Point) + assert point.srid == 4979 + assert tuple(point) == (1.0, -2.0, 3.1) + + def test_custom_point_2d(self, hydration_scope): + struct = Structure(b"X", 12345, 1.0, 3.1) + point = hydration_scope.hydration_hooks[Structure](struct) + assert isinstance(point, Point) + assert point.srid == 12345 + assert tuple(point) == (1.0, 3.1) + + def test_custom_point_3d(self, hydration_scope): + struct = Structure(b"Y", 12345, 1.0, -2.0, 3.1) + point = hydration_scope.hydration_hooks[Structure](struct) + assert isinstance(point, Point) + assert point.srid == 12345 + assert tuple(point) == (1.0, -2.0, 3.1) diff --git 
a/tests/unit/common/codec/hydration/v1/test_time_dehydration.py b/tests/unit/common/codec/hydration/v1/test_time_dehydration.py new file mode 100644 index 000000000..8315d7081 --- /dev/null +++ b/tests/unit/common/codec/hydration/v1/test_time_dehydration.py @@ -0,0 +1,193 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import datetime + +import pytest +import pytz + +from neo4j._codec.hydration.v1 import HydrationHandler +from neo4j._codec.packstream import Structure +from neo4j.time import ( + Date, + DateTime, + Duration, + Time, +) + +from ._base import HydrationHandlerTestBase + + +class TestTimeDehydration(HydrationHandlerTestBase): + @pytest.fixture + def hydration_handler(self): + return HydrationHandler() + + def test_date(self, hydration_scope): + date = Date(1991, 8, 24) + struct = hydration_scope.dehydration_hooks[type(date)](date) + assert struct == Structure(b"D", 7905) + + def test_native_date(self, hydration_scope): + date = datetime.date(1991, 8, 24) + struct = hydration_scope.dehydration_hooks[type(date)](date) + assert struct == Structure(b"D", 7905) + + def test_time(self, hydration_scope): + time = Time(1, 2, 3, 4, pytz.FixedOffset(60)) + struct = hydration_scope.dehydration_hooks[type(time)](time) + assert struct == Structure(b"T", 3723000000004, 3600) + + def test_native_time(self, hydration_scope): + time = datetime.time(1, 2, 3, 4, 
pytz.FixedOffset(60)) + struct = hydration_scope.dehydration_hooks[type(time)](time) + assert struct == Structure(b"T", 3723000004000, 3600) + + def test_local_time(self, hydration_scope): + time = Time(1, 2, 3, 4) + struct = hydration_scope.dehydration_hooks[type(time)](time) + assert struct == Structure(b"t", 3723000000004) + + def test_local_native_time(self, hydration_scope): + time = datetime.time(1, 2, 3, 4) + struct = hydration_scope.dehydration_hooks[type(time)](time) + assert struct == Structure(b"t", 3723000004000) + + def test_date_time(self, hydration_scope): + dt = DateTime(2018, 10, 12, 11, 37, 41, 474716862, + pytz.FixedOffset(60)) + struct = hydration_scope.dehydration_hooks[type(dt)](dt) + assert struct == Structure(b"F", 1539344261, 474716862, 3600) + + def test_native_date_time(self, hydration_scope): + dt = datetime.datetime(2018, 10, 12, 11, 37, 41, 474716, + pytz.FixedOffset(60)) + struct = hydration_scope.dehydration_hooks[type(dt)](dt) + assert struct == Structure(b"F", 1539344261, 474716000, 3600) + + def test_date_time_negative_offset(self, hydration_scope): + dt = DateTime(2018, 10, 12, 11, 37, 41, 474716862, + pytz.FixedOffset(-60)) + struct = hydration_scope.dehydration_hooks[type(dt)](dt) + assert struct == Structure(b"F", 1539344261, 474716862, -3600) + + def test_native_date_time_negative_offset(self, hydration_scope): + dt = datetime.datetime(2018, 10, 12, 11, 37, 41, 474716, + pytz.FixedOffset(-60)) + struct = hydration_scope.dehydration_hooks[type(dt)](dt) + assert struct == Structure(b"F", 1539344261, 474716000, -3600) + + def test_date_time_zone_id(self, hydration_scope): + dt = DateTime(2018, 10, 12, 11, 37, 41, 474716862, + pytz.timezone("Europe/Stockholm")) + struct = hydration_scope.dehydration_hooks[type(dt)](dt) + assert struct == Structure(b"f", 1539344261, 474716862, + "Europe/Stockholm") + + def test_native_date_time_zone_id(self, hydration_scope): + dt = datetime.datetime(2018, 10, 12, 11, 37, 41, 474716, + 
pytz.timezone("Europe/Stockholm")) + struct = hydration_scope.dehydration_hooks[type(dt)](dt) + assert struct == Structure(b"f", 1539344261, 474716000, + "Europe/Stockholm") + + def test_local_date_time(self, hydration_scope): + dt = DateTime(2018, 10, 12, 11, 37, 41, 474716862) + struct = hydration_scope.dehydration_hooks[type(dt)](dt) + assert struct == Structure(b"d", 1539344261, 474716862) + + def test_native_local_date_time(self, hydration_scope): + dt = datetime.datetime(2018, 10, 12, 11, 37, 41, 474716) + struct = hydration_scope.dehydration_hooks[type(dt)](dt) + assert struct == Structure(b"d", 1539344261, 474716000) + + def test_duration(self, hydration_scope): + duration = Duration(months=1, days=2, seconds=3, nanoseconds=4) + struct = hydration_scope.dehydration_hooks[type(duration)](duration) + assert struct == Structure(b"E", 1, 2, 3, 4) + + def test_native_duration(self, hydration_scope): + duration = datetime.timedelta(days=1, seconds=2, microseconds=3) + struct = hydration_scope.dehydration_hooks[type(duration)](duration) + assert struct == Structure(b"E", 0, 1, 2, 3000) + + def test_duration_mixed_sign(self, hydration_scope): + duration = Duration(months=1, days=-2, seconds=3, nanoseconds=4) + struct = hydration_scope.dehydration_hooks[type(duration)](duration) + assert struct == Structure(b"E", 1, -2, 3, 4) + + def test_native_duration_mixed_sign(self, hydration_scope): + duration = datetime.timedelta(days=-1, seconds=2, microseconds=3) + struct = hydration_scope.dehydration_hooks[type(duration)](duration) + assert struct == Structure(b"E", 0, -1, 2, 3000) + + +class TestUTCPatchedTimeDehydration(TestTimeDehydration): + @pytest.fixture + def hydration_handler(self): + handler = HydrationHandler() + handler.patch_utc() + return handler + + def test_date_time(self, hydration_scope): + from ..v2.test_time_dehydration import ( + TestTimeDehydration as TestTimeDehydrationV2, + ) + TestTimeDehydrationV2().test_date_time( + hydration_scope + ) + + def 
test_native_date_time(self, hydration_scope): + from ..v2.test_time_dehydration import ( + TestTimeDehydration as TestTimeDehydrationV2, + ) + TestTimeDehydrationV2().test_native_date_time( + hydration_scope + ) + + def test_date_time_negative_offset(self, hydration_scope): + from ..v2.test_time_dehydration import ( + TestTimeDehydration as TestTimeDehydrationV2, + ) + TestTimeDehydrationV2().test_date_time_negative_offset( + hydration_scope + ) + + def test_native_date_time_negative_offset(self, hydration_scope): + from ..v2.test_time_dehydration import ( + TestTimeDehydration as TestTimeDehydrationV2, + ) + TestTimeDehydrationV2().test_native_date_time_negative_offset( + hydration_scope + ) + + def test_date_time_zone_id(self, hydration_scope): + from ..v2.test_time_dehydration import ( + TestTimeDehydration as TestTimeDehydrationV2, + ) + TestTimeDehydrationV2().test_date_time_zone_id( + hydration_scope + ) + + def test_native_date_time_zone_id(self, hydration_scope): + from ..v2.test_time_dehydration import ( + TestTimeDehydration as TestTimeDehydrationV2, + ) + TestTimeDehydrationV2().test_native_date_time_zone_id( + hydration_scope + ) diff --git a/tests/unit/common/codec/hydration/v1/test_time_hydration.py b/tests/unit/common/codec/hydration/v1/test_time_hydration.py new file mode 100644 index 000000000..3c04c253f --- /dev/null +++ b/tests/unit/common/codec/hydration/v1/test_time_hydration.py @@ -0,0 +1,167 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest +import pytz + +from neo4j._codec.hydration.v1 import HydrationHandler +from neo4j._codec.packstream import Structure +from neo4j.time import ( + Date, + DateTime, + Duration, + Time, +) + +from ._base import HydrationHandlerTestBase + + +class TestTimeHydration(HydrationHandlerTestBase): + @pytest.fixture + def hydration_handler(self): + return HydrationHandler() + + def test_hydrate_date_structure(self, hydration_scope): + struct = Structure(b"D", 7905) + d = hydration_scope.hydration_hooks[Structure](struct) + assert isinstance(d, Date) + assert d.year == 1991 + assert d.month == 8 + assert d.day == 24 + + def test_hydrate_time_structure(self, hydration_scope): + struct = Structure(b"T", 3723000000004, 3600) + t = hydration_scope.hydration_hooks[Structure](struct) + assert isinstance(t, Time) + assert t.hour == 1 + assert t.minute == 2 + assert t.second == 3 + assert t.nanosecond == 4 + assert t.tzinfo == pytz.FixedOffset(60) + + def test_hydrate_local_time_structure(self, hydration_scope): + struct = Structure(b"t", 3723000000004) + t = hydration_scope.hydration_hooks[Structure](struct) + assert isinstance(t, Time) + assert t.hour == 1 + assert t.minute == 2 + assert t.second == 3 + assert t.nanosecond == 4 + assert t.tzinfo is None + + def test_hydrate_date_time_structure_v1(self, hydration_scope): + struct = Structure(b"F", 1539344261, 474716862, 3600) + dt = hydration_scope.hydration_hooks[Structure](struct) + assert isinstance(dt, DateTime) + assert dt.year == 2018 + assert dt.month == 10 + assert dt.day == 12 + assert dt.hour == 11 + assert dt.minute == 37 + assert dt.second == 41 + assert dt.nanosecond == 474716862 + assert dt.tzinfo == pytz.FixedOffset(60) + + def test_hydrate_date_time_structure_v2(self, hydration_scope): + struct = Structure(b"I", 1539344261, 474716862, 3600) + dt = 
hydration_scope.hydration_hooks[Structure](struct) + assert isinstance(dt, Structure) + assert dt == struct + + def test_hydrate_date_time_zone_id_structure_v1(self, hydration_scope): + struct = Structure(b"f", 1539344261, 474716862, "Europe/Stockholm") + dt = hydration_scope.hydration_hooks[Structure](struct) + assert isinstance(dt, DateTime) + assert dt.year == 2018 + assert dt.month == 10 + assert dt.day == 12 + assert dt.hour == 11 + assert dt.minute == 37 + assert dt.second == 41 + assert dt.nanosecond == 474716862 + tz = pytz.timezone("Europe/Stockholm") \ + .localize(dt.replace(tzinfo=None)).tzinfo + assert dt.tzinfo == tz + + def test_hydrate_date_time_zone_id_structure_v2(self, hydration_scope): + struct = Structure(b"i", 1539344261, 474716862, "Europe/Stockholm") + dt = hydration_scope.hydration_hooks[Structure](struct) + assert isinstance(dt, Structure) + assert dt == struct + + def test_hydrate_local_date_time_structure(self, hydration_scope): + struct = Structure(b"d", 1539344261, 474716862) + dt = hydration_scope.hydration_hooks[Structure](struct) + assert isinstance(dt, DateTime) + assert dt.year == 2018 + assert dt.month == 10 + assert dt.day == 12 + assert dt.hour == 11 + assert dt.minute == 37 + assert dt.second == 41 + assert dt.nanosecond == 474716862 + assert dt.tzinfo is None + + def test_hydrate_duration_structure(self, hydration_scope): + struct = Structure(b"E", 1, 2, 3, 4) + d = hydration_scope.hydration_hooks[Structure](struct) + assert isinstance(d, Duration) + assert d.months == 1 + assert d.days == 2 + assert d.seconds == 3 + assert d.nanoseconds == 4 + + +class TestUTCPatchedTimeHydration(TestTimeHydration): + @pytest.fixture + def hydration_handler(self): + handler = HydrationHandler() + handler.patch_utc() + return handler + + def test_hydrate_date_time_structure_v1(self, hydration_scope): + from ..v2.test_time_hydration import ( + TestTimeHydration as TestTimeHydrationV2, + ) + 
TestTimeHydrationV2().test_hydrate_date_time_structure_v1( + hydration_scope + ) + + def test_hydrate_date_time_structure_v2(self, hydration_scope): + from ..v2.test_time_hydration import ( + TestTimeHydration as TestTimeHydrationV2, + ) + TestTimeHydrationV2().test_hydrate_date_time_structure_v2( + hydration_scope + ) + + def test_hydrate_date_time_zone_id_structure_v1(self, hydration_scope): + from ..v2.test_time_hydration import ( + TestTimeHydration as TestTimeHydrationV2, + ) + TestTimeHydrationV2().test_hydrate_date_time_zone_id_structure_v1( + hydration_scope + ) + + def test_hydrate_date_time_zone_id_structure_v2(self, hydration_scope): + from ..v2.test_time_hydration import ( + TestTimeHydration as TestTimeHydrationV2, + ) + TestTimeHydrationV2().test_hydrate_date_time_zone_id_structure_v2( + hydration_scope + ) diff --git a/tests/unit/common/codec/hydration/v2/__init__.py b/tests/unit/common/codec/hydration/v2/__init__.py new file mode 100644 index 000000000..c42cc6fb6 --- /dev/null +++ b/tests/unit/common/codec/hydration/v2/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/tests/unit/common/codec/hydration/v2/test_graph_hydration.py b/tests/unit/common/codec/hydration/v2/test_graph_hydration.py new file mode 100644 index 000000000..588f84c7e --- /dev/null +++ b/tests/unit/common/codec/hydration/v2/test_graph_hydration.py @@ -0,0 +1,67 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +from neo4j._codec.hydration.v1 import HydrationHandler +from neo4j._codec.packstream import Structure +from neo4j.graph import ( + Graph, + Node, + Relationship, +) + +from ..v1.test_graph_hydration import TestGraphHydration as _TestGraphHydration + + +class TestGraphHydration(_TestGraphHydration): + @pytest.fixture + def hydration_handler(self): + return HydrationHandler() + + def test_can_hydrate_node_structure(self, hydration_scope): + struct = Structure(b'N', 123, ["Person"], {"name": "Alice"}, "abc") + alice = hydration_scope.hydration_hooks[Structure](struct) + + assert isinstance(alice, Node) + with pytest.warns(DeprecationWarning, match="element_id"): + assert alice.id == 123 + assert alice.element_id == "abc" + assert alice.labels == {"Person"} + assert set(alice.keys()) == {"name"} + assert alice.get("name") == "Alice" + + def test_can_hydrate_relationship_structure(self, hydration_scope): + struct = Structure(b'R', 123, 456, 789, "KNOWS", {"since": 1999}, + "abc", "def", "ghi") + rel = 
hydration_scope.hydration_hooks[Structure](struct) + + assert isinstance(rel, Relationship) + with pytest.warns(DeprecationWarning, match="element_id"): + assert rel.id == 123 + with pytest.warns(DeprecationWarning, match="element_id"): + assert rel.start_node.id == 456 + with pytest.warns(DeprecationWarning, match="element_id"): + assert rel.end_node.id == 789 + # for backwards compatibility, the driver should copy the element_id + assert rel.element_id == "abc" + assert rel.start_node.element_id == "def" + assert rel.end_node.element_id == "ghi" + assert rel.type == "KNOWS" + assert set(rel.keys()) == {"since"} + assert rel.get("since") == 1999 diff --git a/tests/unit/common/codec/hydration/v2/test_hydration_handler.py b/tests/unit/common/codec/hydration/v2/test_hydration_handler.py new file mode 100644 index 000000000..c28379ea6 --- /dev/null +++ b/tests/unit/common/codec/hydration/v2/test_hydration_handler.py @@ -0,0 +1,31 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import pytest + +from neo4j._codec.hydration.v1 import HydrationHandler + +from ..v1.test_hydration_handler import ( + TestHydrationHandler as TestHydrationHandlerV1, +) + + +class TestHydrationHandler(TestHydrationHandlerV1): + @pytest.fixture + def hydration_handler(self): + return HydrationHandler() diff --git a/tests/unit/common/codec/hydration/v2/test_spacial_dehydration.py b/tests/unit/common/codec/hydration/v2/test_spacial_dehydration.py new file mode 100644 index 000000000..85349dc50 --- /dev/null +++ b/tests/unit/common/codec/hydration/v2/test_spacial_dehydration.py @@ -0,0 +1,31 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +from neo4j._codec.hydration.v1 import HydrationHandler + +from ..v1.test_spacial_dehydration import ( + TestSpatialDehydration as _TestSpatialDehydrationV1, +) + + +class TestSpatialDehydration(_TestSpatialDehydrationV1): + @pytest.fixture + def hydration_handler(self): + return HydrationHandler() diff --git a/tests/unit/common/codec/hydration/v2/test_spacial_hydration.py b/tests/unit/common/codec/hydration/v2/test_spacial_hydration.py new file mode 100644 index 000000000..d905965ca --- /dev/null +++ b/tests/unit/common/codec/hydration/v2/test_spacial_hydration.py @@ -0,0 +1,31 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +from neo4j._codec.hydration.v1 import HydrationHandler + +from ..v1.test_spacial_hydration import ( + TestSpatialHydration as _TestSpatialHydrationV1, +) + + +class TestSpatialHydration(_TestSpatialHydrationV1): + @pytest.fixture + def hydration_handler(self): + return HydrationHandler() diff --git a/tests/unit/common/codec/hydration/v2/test_time_dehydration.py b/tests/unit/common/codec/hydration/v2/test_time_dehydration.py new file mode 100644 index 000000000..021db2eb4 --- /dev/null +++ b/tests/unit/common/codec/hydration/v2/test_time_dehydration.py @@ -0,0 +1,74 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import datetime + +import pytest +import pytz + +from neo4j._codec.hydration.v2 import HydrationHandler +from neo4j._codec.packstream import Structure +from neo4j.time import DateTime + +from ..v1.test_time_dehydration import ( + TestTimeDehydration as _TestTimeDehydrationV1, +) + + +class TestTimeDehydration(_TestTimeDehydrationV1): + @pytest.fixture + def hydration_handler(self): + return HydrationHandler() + + def test_date_time(self, hydration_scope): + dt = DateTime(2018, 10, 12, 11, 37, 41, 474716862, + pytz.FixedOffset(60)) + struct = hydration_scope.dehydration_hooks[type(dt)](dt) + assert struct == Structure(b"I", 1539340661, 474716862, 3600) + + def test_native_date_time(self, hydration_scope): + dt = datetime.datetime(2018, 10, 12, 11, 37, 41, 474716, + pytz.FixedOffset(60)) + struct = hydration_scope.dehydration_hooks[type(dt)](dt) + assert struct == Structure(b"I", 1539340661, 474716000, 3600) + + def test_date_time_negative_offset(self, hydration_scope): + dt = DateTime(2018, 10, 12, 11, 37, 41, 474716862, + pytz.FixedOffset(-60)) + struct = hydration_scope.dehydration_hooks[type(dt)](dt) + assert struct == Structure(b"I", 1539347861, 474716862, -3600) + + def test_native_date_time_negative_offset(self, hydration_scope): + dt = datetime.datetime(2018, 10, 12, 11, 37, 41, 474716, + pytz.FixedOffset(-60)) + struct = hydration_scope.dehydration_hooks[type(dt)](dt) + assert struct == Structure(b"I", 1539347861, 474716000, -3600) + + def test_date_time_zone_id(self, hydration_scope): + dt = DateTime(2018, 10, 12, 11, 37, 41, 474716862, + pytz.timezone("Europe/Stockholm")) + struct = hydration_scope.dehydration_hooks[type(dt)](dt) + assert struct == Structure(b"i", 1539339941, 474716862, + "Europe/Stockholm") + + def test_native_date_time_zone_id(self, hydration_scope): + dt = datetime.datetime(2018, 10, 12, 11, 37, 41, 474716, + pytz.timezone("Europe/Stockholm")) + struct = hydration_scope.dehydration_hooks[type(dt)](dt) + assert struct == 
Structure(b"i", 1539339941, 474716000, + "Europe/Stockholm") diff --git a/tests/unit/common/codec/hydration/v2/test_time_hydration.py b/tests/unit/common/codec/hydration/v2/test_time_hydration.py new file mode 100644 index 000000000..7fe308ec0 --- /dev/null +++ b/tests/unit/common/codec/hydration/v2/test_time_hydration.py @@ -0,0 +1,74 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import datetime + +import pytest +import pytz + +from neo4j._codec.hydration.v2 import HydrationHandler +from neo4j._codec.packstream import Structure +from neo4j.time import DateTime + +from ..v1.test_time_hydration import TestTimeHydration as _TestTimeHydrationV1 + + +class TestTimeHydration(_TestTimeHydrationV1): + @pytest.fixture + def hydration_handler(self): + return HydrationHandler() + + def test_hydrate_date_time_structure_v1(self, hydration_scope): + struct = Structure(b"F", 1539344261, 474716862, 3600) + dt = hydration_scope.hydration_hooks[Structure](struct) + assert isinstance(dt, Structure) + assert dt == struct + + def test_hydrate_date_time_structure_v2(self, hydration_scope): + struct = Structure(b"I", 1539344261, 474716862, 3600) + dt = hydration_scope.hydration_hooks[Structure](struct) + assert isinstance(dt, DateTime) + assert dt.year == 2018 + assert dt.month == 10 + assert dt.day == 12 + assert dt.hour == 12 + assert dt.minute == 37 + assert dt.second == 41 + 
assert dt.nanosecond == 474716862 + assert dt.tzinfo == pytz.FixedOffset(60) + + def test_hydrate_date_time_zone_id_structure_v1(self, hydration_scope): + struct = Structure(b"f", 1539344261, 474716862, "Europe/Stockholm") + dt = hydration_scope.hydration_hooks[Structure](struct) + assert isinstance(dt, Structure) + assert dt == struct + + def test_hydrate_date_time_zone_id_structure_v2(self, hydration_scope): + struct = Structure(b"i", 1539344261, 474716862, "Europe/Stockholm") + dt = hydration_scope.hydration_hooks[Structure](struct) + assert isinstance(dt, DateTime) + assert dt.year == 2018 + assert dt.month == 10 + assert dt.day == 12 + assert dt.hour == 13 + assert dt.minute == 37 + assert dt.second == 41 + assert dt.nanosecond == 474716862 + tz = pytz.timezone("Europe/Stockholm") \ + .localize(dt.replace(tzinfo=None)).tzinfo + assert dt.tzinfo == tz diff --git a/tests/unit/common/codec/packstream/__init__.py b/tests/unit/common/codec/packstream/__init__.py new file mode 100644 index 000000000..c42cc6fb6 --- /dev/null +++ b/tests/unit/common/codec/packstream/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/tests/unit/common/codec/packstream/v1/__init__.py b/tests/unit/common/codec/packstream/v1/__init__.py new file mode 100644 index 000000000..c42cc6fb6 --- /dev/null +++ b/tests/unit/common/codec/packstream/v1/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/common/codec/packstream/v1/test_packstream.py b/tests/unit/common/codec/packstream/v1/test_packstream.py new file mode 100644 index 000000000..14f8fcfb5 --- /dev/null +++ b/tests/unit/common/codec/packstream/v1/test_packstream.py @@ -0,0 +1,326 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import struct +from io import BytesIO +from math import pi +from uuid import uuid4 + +import pytest + +from neo4j._codec.packstream import Structure +from neo4j._codec.packstream.v1 import ( + PackableBuffer, + Packer, + UnpackableBuffer, + Unpacker, +) + + +standard_ascii = [chr(i) for i in range(128)] +not_ascii = "♥O◘♦♥O◘♦" + + +class TestPackStream: + @pytest.fixture + def packer_with_buffer(self): + packable_buffer = Packer.new_packable_buffer() + return Packer(packable_buffer), packable_buffer + + @pytest.fixture + def unpacker_with_buffer(self): + unpackable_buffer = Unpacker.new_unpackable_buffer() + return Unpacker(unpackable_buffer), unpackable_buffer + + def test_packable_buffer(self, packer_with_buffer): + packer, packable_buffer = packer_with_buffer + assert isinstance(packable_buffer, PackableBuffer) + assert packable_buffer is packer.stream + + def test_unpackable_buffer(self, unpacker_with_buffer): + unpacker, unpackable_buffer = unpacker_with_buffer + assert isinstance(unpackable_buffer, UnpackableBuffer) + assert unpackable_buffer is unpacker.unpackable + + @pytest.fixture + def pack(self, packer_with_buffer): + packer, packable_buffer = packer_with_buffer + + def _pack(*values, dehydration_hooks=None): + for value in values: + packer.pack(value, dehydration_hooks=dehydration_hooks) + data = bytearray(packable_buffer.data) + packable_buffer.clear() + return data + + return _pack + + @pytest.fixture + def assert_packable(self, packer_with_buffer, unpacker_with_buffer): + def _assert(value, packed_value): + nonlocal packer_with_buffer, unpacker_with_buffer + packer, packable_buffer = packer_with_buffer + unpacker, unpackable_buffer = unpacker_with_buffer + packable_buffer.clear() + unpackable_buffer.reset() + + packer.pack(value) + packed_data = packable_buffer.data + assert packed_data == packed_value + + unpackable_buffer.data = bytearray(packed_data) + unpackable_buffer.used = len(packed_data) + unpacked_data = unpacker.unpack() + assert 
unpacked_data == value + + return _assert + + def test_none(self, assert_packable): + assert_packable(None, b"\xC0") + + def test_boolean(self, assert_packable): + assert_packable(True, b"\xC3") + assert_packable(False, b"\xC2") + + def test_negative_tiny_int(self, assert_packable): + for z in range(-16, 0): + assert_packable(z, bytes(bytearray([z + 0x100]))) + + def test_positive_tiny_int(self, assert_packable): + for z in range(0, 128): + assert_packable(z, bytes(bytearray([z]))) + + def test_negative_int8(self, assert_packable): + for z in range(-128, -16): + assert_packable(z, bytes(bytearray([0xC8, z + 0x100]))) + + def test_positive_int16(self, assert_packable): + for z in range(128, 32768): + expected = b"\xC9" + struct.pack(">h", z) + assert_packable(z, expected) + + def test_negative_int16(self, assert_packable): + for z in range(-32768, -128): + expected = b"\xC9" + struct.pack(">h", z) + assert_packable(z, expected) + + def test_positive_int32(self, assert_packable): + for e in range(15, 31): + z = 2 ** e + expected = b"\xCA" + struct.pack(">i", z) + assert_packable(z, expected) + + def test_negative_int32(self, assert_packable): + for e in range(15, 31): + z = -(2 ** e + 1) + expected = b"\xCA" + struct.pack(">i", z) + assert_packable(z, expected) + + def test_positive_int64(self, assert_packable): + for e in range(31, 63): + z = 2 ** e + expected = b"\xCB" + struct.pack(">q", z) + assert_packable(z, expected) + + def test_negative_int64(self, assert_packable): + for e in range(31, 63): + z = -(2 ** e + 1) + expected = b"\xCB" + struct.pack(">q", z) + assert_packable(z, expected) + + def test_integer_positive_overflow(self, pack, assert_packable): + with pytest.raises(OverflowError): + pack(2 ** 63 + 1) + + def test_integer_negative_overflow(self, pack, assert_packable): + with pytest.raises(OverflowError): + pack(-(2 ** 63) - 1) + + def test_zero_float64(self, assert_packable): + zero = 0.0 + expected = b"\xC1" + struct.pack(">d", zero) + 
assert_packable(zero, expected) + + def test_tau_float64(self, assert_packable): + tau = 2 * pi + expected = b"\xC1" + struct.pack(">d", tau) + assert_packable(tau, expected) + + def test_positive_float64(self, assert_packable): + for e in range(0, 100): + r = float(2 ** e) + 0.5 + expected = b"\xC1" + struct.pack(">d", r) + assert_packable(r, expected) + + def test_negative_float64(self, assert_packable): + for e in range(0, 100): + r = -(float(2 ** e) + 0.5) + expected = b"\xC1" + struct.pack(">d", r) + assert_packable(r, expected) + + def test_empty_bytes(self, assert_packable): + assert_packable(b"", b"\xCC\x00") + + def test_empty_bytearray(self, assert_packable): + assert_packable(bytearray(), b"\xCC\x00") + + def test_bytes_8(self, assert_packable): + assert_packable(bytearray(b"hello"), b"\xCC\x05hello") + + def test_bytes_16(self, assert_packable): + b = bytearray(40000) + assert_packable(b, b"\xCD\x9C\x40" + b) + + def test_bytes_32(self, assert_packable): + b = bytearray(80000) + assert_packable(b, b"\xCE\x00\x01\x38\x80" + b) + + def test_bytearray_size_overflow(self, assert_packable): + stream_out = BytesIO() + packer = Packer(stream_out) + with pytest.raises(OverflowError): + packer.pack_bytes_header(2 ** 32) + + def test_empty_string(self, assert_packable): + assert_packable(u"", b"\x80") + + def test_tiny_strings(self, assert_packable): + for size in range(0x10): + assert_packable(u"A" * size, bytes(bytearray([0x80 + size]) + (b"A" * size))) + + def test_string_8(self, assert_packable): + t = u"A" * 40 + b = t.encode("utf-8") + assert_packable(t, b"\xD0\x28" + b) + + def test_string_16(self, assert_packable): + t = u"A" * 40000 + b = t.encode("utf-8") + assert_packable(t, b"\xD1\x9C\x40" + b) + + def test_string_32(self, assert_packable): + t = u"A" * 80000 + b = t.encode("utf-8") + assert_packable(t, b"\xD2\x00\x01\x38\x80" + b) + + def test_unicode_string(self, assert_packable): + t = u"héllö" + b = t.encode("utf-8") + assert_packable(t, 
bytes(bytearray([0x80 + len(b)])) + b) + + def test_string_size_overflow(self): + stream_out = BytesIO() + packer = Packer(stream_out) + with pytest.raises(OverflowError): + packer.pack_string_header(2 ** 32) + + def test_empty_list(self, assert_packable): + assert_packable([], b"\x90") + + def test_tiny_lists(self, assert_packable): + for size in range(0x10): + data_out = bytearray([0x90 + size]) + bytearray([1] * size) + assert_packable([1] * size, bytes(data_out)) + + def test_list_8(self, assert_packable): + l = [1] * 40 + assert_packable(l, b"\xD4\x28" + (b"\x01" * 40)) + + def test_list_16(self, assert_packable): + l = [1] * 40000 + assert_packable(l, b"\xD5\x9C\x40" + (b"\x01" * 40000)) + + def test_list_32(self, assert_packable): + l = [1] * 80000 + assert_packable(l, b"\xD6\x00\x01\x38\x80" + (b"\x01" * 80000)) + + def test_nested_lists(self, assert_packable): + assert_packable([[[]]], b"\x91\x91\x90") + + def test_list_size_overflow(self): + stream_out = BytesIO() + packer = Packer(stream_out) + with pytest.raises(OverflowError): + packer.pack_list_header(2 ** 32) + + def test_empty_map(self, assert_packable): + assert_packable({}, b"\xA0") + + @pytest.mark.parametrize("size", range(0x10)) + def test_tiny_maps(self, assert_packable, size): + data_in = dict() + data_out = bytearray([0xA0 + size]) + for el in range(1, size + 1): + data_in[chr(64 + el)] = el + data_out += bytearray([0x81, 64 + el, el]) + assert_packable(data_in, bytes(data_out)) + + def test_map_8(self, pack, assert_packable): + d = dict([(u"A%s" % i, 1) for i in range(40)]) + b = b"".join(pack(u"A%s" % i, 1) for i in range(40)) + assert_packable(d, b"\xD8\x28" + b) + + def test_map_16(self, pack, assert_packable): + d = dict([(u"A%s" % i, 1) for i in range(40000)]) + b = b"".join(pack(u"A%s" % i, 1) for i in range(40000)) + assert_packable(d, b"\xD9\x9C\x40" + b) + + def test_map_32(self, pack, assert_packable): + d = dict([(u"A%s" % i, 1) for i in range(80000)]) + b = b"".join(pack(u"A%s" 
% i, 1) for i in range(80000)) + assert_packable(d, b"\xDA\x00\x01\x38\x80" + b) + + def test_map_size_overflow(self): + stream_out = BytesIO() + packer = Packer(stream_out) + with pytest.raises(OverflowError): + packer.pack_map_header(2 ** 32) + + @pytest.mark.parametrize(("map_", "exc_type"), ( + ({1: "1"}, TypeError), + ({"x": {1: 'eins', 2: 'zwei', 3: 'drei'}}, TypeError), + ({"x": {(1, 2): '1+2i', (2, 0): '2'}}, TypeError), + )) + def test_map_key_type(self, packer_with_buffer, map_, exc_type): + # maps must have string keys + packer, packable_buffer = packer_with_buffer + with pytest.raises(exc_type, match="strings"): + packer.pack(map_) + + def test_illegal_signature(self, assert_packable): + with pytest.raises(ValueError): + assert_packable(Structure(b"XXX"), b"\xB0XXX") + + def test_empty_struct(self, assert_packable): + assert_packable(Structure(b"X"), b"\xB0X") + + def test_tiny_structs(self, assert_packable): + for size in range(0x10): + fields = [1] * size + data_in = Structure(b"A", *fields) + data_out = bytearray([0xB0 + size, 0x41] + fields) + assert_packable(data_in, bytes(data_out)) + + def test_struct_size_overflow(self, pack): + with pytest.raises(OverflowError): + fields = [1] * 16 + pack(Structure(b"X", *fields)) + + def test_illegal_uuid(self, assert_packable): + with pytest.raises(ValueError): + assert_packable(uuid4(), b"\xB0XXX") diff --git a/tests/unit/common/data/test_packing.py b/tests/unit/common/data/test_packing.py deleted file mode 100644 index 8b274b587..000000000 --- a/tests/unit/common/data/test_packing.py +++ /dev/null @@ -1,284 +0,0 @@ -# Copyright (c) "Neo4j" -# Neo4j Sweden AB [https://neo4j.com] -# -# This file is part of Neo4j. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from io import BytesIO -from math import pi -import struct -from unittest import TestCase -from uuid import uuid4 - -from pytest import raises - -from neo4j.packstream import ( - Packer, - Structure, - UnpackableBuffer, - Unpacker, -) - - -class PackStreamTestCase(TestCase): - - @classmethod - def packb(cls, *values): - stream = BytesIO() - packer = Packer(stream) - for value in values: - packer.pack(value) - return stream.getvalue() - - @classmethod - def assert_packable(cls, value, packed_value): - stream_out = BytesIO() - packer = Packer(stream_out) - packer.pack(value) - packed = stream_out.getvalue() - try: - assert packed == packed_value - except AssertionError: - raise AssertionError("Packed value %r is %r instead of expected %r" % - (value, packed, packed_value)) - unpacked = Unpacker(UnpackableBuffer(packed)).unpack() - try: - assert unpacked == value - except AssertionError: - raise AssertionError("Unpacked value %r is not equal to original %r" % (unpacked, value)) - - def test_none(self): - self.assert_packable(None, b"\xC0") - - def test_boolean(self): - self.assert_packable(True, b"\xC3") - self.assert_packable(False, b"\xC2") - - def test_negative_tiny_int(self): - for z in range(-16, 0): - self.assert_packable(z, bytes(bytearray([z + 0x100]))) - - def test_positive_tiny_int(self): - for z in range(0, 128): - self.assert_packable(z, bytes(bytearray([z]))) - - def test_negative_int8(self): - for z in range(-128, -16): - self.assert_packable(z, bytes(bytearray([0xC8, z + 0x100]))) - - def test_positive_int16(self): - for z in range(128, 
32768): - expected = b"\xC9" + struct.pack(">h", z) - self.assert_packable(z, expected) - - def test_negative_int16(self): - for z in range(-32768, -128): - expected = b"\xC9" + struct.pack(">h", z) - self.assert_packable(z, expected) - - def test_positive_int32(self): - for e in range(15, 31): - z = 2 ** e - expected = b"\xCA" + struct.pack(">i", z) - self.assert_packable(z, expected) - - def test_negative_int32(self): - for e in range(15, 31): - z = -(2 ** e + 1) - expected = b"\xCA" + struct.pack(">i", z) - self.assert_packable(z, expected) - - def test_positive_int64(self): - for e in range(31, 63): - z = 2 ** e - expected = b"\xCB" + struct.pack(">q", z) - self.assert_packable(z, expected) - - def test_negative_int64(self): - for e in range(31, 63): - z = -(2 ** e + 1) - expected = b"\xCB" + struct.pack(">q", z) - self.assert_packable(z, expected) - - def test_integer_positive_overflow(self): - with raises(OverflowError): - self.packb(2 ** 63 + 1) - - def test_integer_negative_overflow(self): - with raises(OverflowError): - self.packb(-(2 ** 63) - 1) - - def test_zero_float64(self): - zero = 0.0 - expected = b"\xC1" + struct.pack(">d", zero) - self.assert_packable(zero, expected) - - def test_tau_float64(self): - tau = 2 * pi - expected = b"\xC1" + struct.pack(">d", tau) - self.assert_packable(tau, expected) - - def test_positive_float64(self): - for e in range(0, 100): - r = float(2 ** e) + 0.5 - expected = b"\xC1" + struct.pack(">d", r) - self.assert_packable(r, expected) - - def test_negative_float64(self): - for e in range(0, 100): - r = -(float(2 ** e) + 0.5) - expected = b"\xC1" + struct.pack(">d", r) - self.assert_packable(r, expected) - - def test_empty_bytes(self): - self.assert_packable(b"", b"\xCC\x00") - - def test_empty_bytearray(self): - self.assert_packable(bytearray(), b"\xCC\x00") - - def test_bytes_8(self): - self.assert_packable(bytearray(b"hello"), b"\xCC\x05hello") - - def test_bytes_16(self): - b = bytearray(40000) - 
self.assert_packable(b, b"\xCD\x9C\x40" + b) - - def test_bytes_32(self): - b = bytearray(80000) - self.assert_packable(b, b"\xCE\x00\x01\x38\x80" + b) - - def test_bytearray_size_overflow(self): - stream_out = BytesIO() - packer = Packer(stream_out) - with raises(OverflowError): - packer.pack_bytes_header(2 ** 32) - - def test_empty_string(self): - self.assert_packable(u"", b"\x80") - - def test_tiny_strings(self): - for size in range(0x10): - self.assert_packable(u"A" * size, bytes(bytearray([0x80 + size]) + (b"A" * size))) - - def test_string_8(self): - t = u"A" * 40 - b = t.encode("utf-8") - self.assert_packable(t, b"\xD0\x28" + b) - - def test_string_16(self): - t = u"A" * 40000 - b = t.encode("utf-8") - self.assert_packable(t, b"\xD1\x9C\x40" + b) - - def test_string_32(self): - t = u"A" * 80000 - b = t.encode("utf-8") - self.assert_packable(t, b"\xD2\x00\x01\x38\x80" + b) - - def test_unicode_string(self): - t = u"héllö" - b = t.encode("utf-8") - self.assert_packable(t, bytes(bytearray([0x80 + len(b)])) + b) - - def test_string_size_overflow(self): - stream_out = BytesIO() - packer = Packer(stream_out) - with raises(OverflowError): - packer.pack_string_header(2 ** 32) - - def test_empty_list(self): - self.assert_packable([], b"\x90") - - def test_tiny_lists(self): - for size in range(0x10): - data_out = bytearray([0x90 + size]) + bytearray([1] * size) - self.assert_packable([1] * size, bytes(data_out)) - - def test_list_8(self): - l = [1] * 40 - self.assert_packable(l, b"\xD4\x28" + (b"\x01" * 40)) - - def test_list_16(self): - l = [1] * 40000 - self.assert_packable(l, b"\xD5\x9C\x40" + (b"\x01" * 40000)) - - def test_list_32(self): - l = [1] * 80000 - self.assert_packable(l, b"\xD6\x00\x01\x38\x80" + (b"\x01" * 80000)) - - def test_nested_lists(self): - self.assert_packable([[[]]], b"\x91\x91\x90") - - def test_list_size_overflow(self): - stream_out = BytesIO() - packer = Packer(stream_out) - with raises(OverflowError): - packer.pack_list_header(2 ** 32) - 
- def test_empty_map(self): - self.assert_packable({}, b"\xA0") - - def test_tiny_maps(self): - for size in range(0x10): - data_in = dict() - data_out = bytearray([0xA0 + size]) - for el in range(1, size + 1): - data_in[chr(64 + el)] = el - data_out += bytearray([0x81, 64 + el, el]) - self.assert_packable(data_in, bytes(data_out)) - - def test_map_8(self): - d = dict([(u"A%s" % i, 1) for i in range(40)]) - b = b"".join(self.packb(u"A%s" % i, 1) for i in range(40)) - self.assert_packable(d, b"\xD8\x28" + b) - - def test_map_16(self): - d = dict([(u"A%s" % i, 1) for i in range(40000)]) - b = b"".join(self.packb(u"A%s" % i, 1) for i in range(40000)) - self.assert_packable(d, b"\xD9\x9C\x40" + b) - - def test_map_32(self): - d = dict([(u"A%s" % i, 1) for i in range(80000)]) - b = b"".join(self.packb(u"A%s" % i, 1) for i in range(80000)) - self.assert_packable(d, b"\xDA\x00\x01\x38\x80" + b) - - def test_map_size_overflow(self): - stream_out = BytesIO() - packer = Packer(stream_out) - with raises(OverflowError): - packer.pack_map_header(2 ** 32) - - def test_illegal_signature(self): - with self.assertRaises(ValueError): - self.assert_packable(Structure(b"XXX"), b"\xB0XXX") - - def test_empty_struct(self): - self.assert_packable(Structure(b"X"), b"\xB0X") - - def test_tiny_structs(self): - for size in range(0x10): - fields = [1] * size - data_in = Structure(b"A", *fields) - data_out = bytearray([0xB0 + size, 0x41] + fields) - self.assert_packable(data_in, bytes(data_out)) - - def test_struct_size_overflow(self): - with raises(OverflowError): - fields = [1] * 16 - self.packb(Structure(b"X", *fields)) - - def test_illegal_uuid(self): - with self.assertRaises(ValueError): - self.assert_packable(uuid4(), b"\xB0XXX") diff --git a/tests/unit/common/spatial/test_cartesian_point.py b/tests/unit/common/spatial/test_cartesian_point.py index 742aa7b61..5ec40ac5e 100644 --- a/tests/unit/common/spatial/test_cartesian_point.py +++ b/tests/unit/common/spatial/test_cartesian_point.py @@ 
-16,14 +16,8 @@ # limitations under the License. -import io -import struct from unittest import TestCase -import pytest - -from neo4j.data import DataDehydrator -from neo4j.packstream import Packer from neo4j.spatial import CartesianPoint @@ -48,33 +42,3 @@ def test_alias_2d(self): self.assertEqual(p.y, y) with self.assertRaises(AttributeError): p.z - - def test_dehydration_3d(self): - coordinates = (1, -2, 3.1) - p = CartesianPoint(coordinates) - - dehydrator = DataDehydrator() - buffer = io.BytesIO() - packer = Packer(buffer) - packer.pack(dehydrator.dehydrate((p,))[0]) - self.assertEqual( - buffer.getvalue(), - b"\xB4Y" + - b"\xC9" + struct.pack(">h", 9157) + - b"".join(map(lambda c: b"\xC1" + struct.pack(">d", c), coordinates)) - ) - - def test_dehydration_2d(self): - coordinates = (.1, 0) - p = CartesianPoint(coordinates) - - dehydrator = DataDehydrator() - buffer = io.BytesIO() - packer = Packer(buffer) - packer.pack(dehydrator.dehydrate((p,))[0]) - self.assertEqual( - buffer.getvalue(), - b"\xB3X" + - b"\xC9" + struct.pack(">h", 7203) + - b"".join(map(lambda c: b"\xC1" + struct.pack(">d", c), coordinates)) - ) diff --git a/tests/unit/common/spatial/test_point.py b/tests/unit/common/spatial/test_point.py index fd7f35e98..816ee01c5 100644 --- a/tests/unit/common/spatial/test_point.py +++ b/tests/unit/common/spatial/test_point.py @@ -16,12 +16,8 @@ # limitations under the License. 
-import io -import struct from unittest import TestCase -from neo4j.data import DataDehydrator -from neo4j.packstream import Packer from neo4j.spatial import ( Point, point_type, @@ -42,22 +38,6 @@ def test_number_arguments(self): p = Point(argument) assert tuple(p) == argument - def test_dehydration(self): - MyPoint = point_type("MyPoint", ["x", "y"], {2: 1234}) - coordinates = (.1, 0) - p = MyPoint(coordinates) - - dehydrator = DataDehydrator() - buffer = io.BytesIO() - packer = Packer(buffer) - packer.pack(dehydrator.dehydrate((p,))[0]) - self.assertEqual( - buffer.getvalue(), - b"\xB3X" + - b"\xC9" + struct.pack(">h", 1234) + - b"".join(map(lambda c: b"\xC1" + struct.pack(">d", c), coordinates)) - ) - def test_immutable_coordinates(self): MyPoint = point_type("MyPoint", ["x", "y"], {2: 1234}) coordinates = (.1, 0) diff --git a/tests/unit/common/spatial/test_wgs84_point.py b/tests/unit/common/spatial/test_wgs84_point.py index 43f4f251f..540cfd2c4 100644 --- a/tests/unit/common/spatial/test_wgs84_point.py +++ b/tests/unit/common/spatial/test_wgs84_point.py @@ -16,12 +16,8 @@ # limitations under the License. 
-import io -import struct from unittest import TestCase -from neo4j.data import DataDehydrator -from neo4j.packstream import Packer from neo4j.spatial import WGS84Point @@ -64,33 +60,3 @@ def test_alias_2d(self): p.height with self.assertRaises(AttributeError): p.z - - def test_dehydration_3d(self): - coordinates = (1, -2, 3.1) - p = WGS84Point(coordinates) - - dehydrator = DataDehydrator() - buffer = io.BytesIO() - packer = Packer(buffer) - packer.pack(dehydrator.dehydrate((p,))[0]) - self.assertEqual( - buffer.getvalue(), - b"\xB4Y" + - b"\xC9" + struct.pack(">h", 4979) + - b"".join(map(lambda c: b"\xC1" + struct.pack(">d", c), coordinates)) - ) - - def test_dehydration_2d(self): - coordinates = (.1, 0) - p = WGS84Point(coordinates) - - dehydrator = DataDehydrator() - buffer = io.BytesIO() - packer = Packer(buffer) - packer.pack(dehydrator.dehydrate((p,))[0]) - self.assertEqual( - buffer.getvalue(), - b"\xB3X" + - b"\xC9" + struct.pack(">h", 4326) + - b"".join(map(lambda c: b"\xC1" + struct.pack(">d", c), coordinates)) - ) diff --git a/tests/unit/common/test_addressing.py b/tests/unit/common/test_addressing.py index eafe7f17f..99b730f38 100644 --- a/tests/unit/common/test_addressing.py +++ b/tests/unit/common/test_addressing.py @@ -20,7 +20,7 @@ AF_INET, AF_INET6, ) -import unittest.mock as mock +from unittest import mock import pytest diff --git a/tests/unit/common/test_api.py b/tests/unit/common/test_api.py index f0920d953..5f8576aab 100644 --- a/tests/unit/common/test_api.py +++ b/tests/unit/common/test_api.py @@ -16,14 +16,11 @@ # limitations under the License. 
-from copy import deepcopy import itertools -from uuid import uuid4 import pytest import neo4j.api -from neo4j.data import DataDehydrator from neo4j.exceptions import ConfigurationError @@ -31,125 +28,6 @@ not_ascii = "♥O◘♦♥O◘♦" -def dehydrated_value(value): - return DataDehydrator.fix_parameters({"_": value})["_"] - - -def test_value_dehydration_should_allow_none(): - assert dehydrated_value(None) is None - - -@pytest.mark.parametrize( - "test_input, expected", - [ - (True, True), - (False, False), - ] -) -def test_value_dehydration_should_allow_boolean(test_input, expected): - assert dehydrated_value(test_input) is expected - - -@pytest.mark.parametrize( - "test_input, expected", - [ - (0, 0), - (1, 1), - (0x7F, 0x7F), - (0x7FFF, 0x7FFF), - (0x7FFFFFFF, 0x7FFFFFFF), - (0x7FFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFFF), - ] -) -def test_value_dehydration_should_allow_integer(test_input, expected): - assert dehydrated_value(test_input) == expected - - -@pytest.mark.parametrize( - "test_input, expected", - [ - (0x10000000000000000, ValueError), - (-0x10000000000000000, ValueError), - ] -) -def test_value_dehydration_should_disallow_oversized_integer(test_input, expected): - with pytest.raises(expected): - dehydrated_value(test_input) - - -@pytest.mark.parametrize( - "test_input, expected", - [ - (0.0, 0.0), - (-0.1, -0.1), - (3.1415926, 3.1415926), - (-3.1415926, -3.1415926), - ] -) -def test_value_dehydration_should_allow_float(test_input, expected): - assert dehydrated_value(test_input) == expected - - -@pytest.mark.parametrize( - "test_input, expected", - [ - (u"", u""), - (u"hello, world", u"hello, world"), - ("".join(standard_ascii), "".join(standard_ascii)), - ] -) -def test_value_dehydration_should_allow_string(test_input, expected): - assert dehydrated_value(test_input) == expected - - -@pytest.mark.parametrize( - "test_input, expected", - [ - (bytearray(), bytearray()), - (bytearray([1, 2, 3]), bytearray([1, 2, 3])), - ] -) -def 
test_value_dehydration_should_allow_bytes(test_input, expected): - assert dehydrated_value(test_input) == expected - - -@pytest.mark.parametrize( - "test_input, expected", - [ - ([], []), - ([1, 2, 3], [1, 2, 3]), - ([1, 3.1415926, "string", None], [1, 3.1415926, "string", None]) - ] -) -def test_value_dehydration_should_allow_list(test_input, expected): - assert dehydrated_value(test_input) == expected - - -@pytest.mark.parametrize( - "test_input, expected", - [ - ({}, {}), - ({u"one": 1, u"two": 1, u"three": 1}, {u"one": 1, u"two": 1, u"three": 1}), - ({u"list": [1, 2, 3, [4, 5, 6]], u"dict": {u"a": 1, u"b": 2}}, {u"list": [1, 2, 3, [4, 5, 6]], u"dict": {u"a": 1, u"b": 2}}), - ({"alpha": [1, 3.1415926, "string", None]}, {"alpha": [1, 3.1415926, "string", None]}), - ] -) -def test_value_dehydration_should_allow_dict(test_input, expected): - assert dehydrated_value(test_input) == expected - - -@pytest.mark.parametrize( - "test_input, expected", - [ - (object(), TypeError), - (uuid4(), TypeError), - ] -) -def test_value_dehydration_should_disallow_object(test_input, expected): - with pytest.raises(expected): - dehydrated_value(test_input) - - def test_bookmark_is_deprecated(): with pytest.deprecated_call(): neo4j.Bookmark() diff --git a/tests/unit/common/test_data.py b/tests/unit/common/test_data.py deleted file mode 100644 index 59641aef6..000000000 --- a/tests/unit/common/test_data.py +++ /dev/null @@ -1,147 +0,0 @@ -# Copyright (c) "Neo4j" -# Neo4j Sweden AB [https://neo4j.com] -# -# This file is part of Neo4j. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - - -import pytest - -from neo4j.data import DataHydrator -from neo4j.packstream import Structure - - -# python -m pytest -s -v tests/unit/test_data.py - - -def test_can_hydrate_v1_node_structure(): - hydrant = DataHydrator() - - struct = Structure(b'N', 123, ["Person"], {"name": "Alice"}) - alice, = hydrant.hydrate([struct]) - - with pytest.warns(DeprecationWarning, match="element_id"): - assert alice.id == 123 - # for backwards compatibility, the driver should compy the element_id - assert alice.element_id == "123" - assert alice.labels == {"Person"} - assert set(alice.keys()) == {"name"} - assert alice.get("name") == "Alice" - - -@pytest.mark.parametrize("with_id", (True, False)) -def test_can_hydrate_v2_node_structure(with_id): - hydrant = DataHydrator() - - id_ = 123 if with_id else None - - struct = Structure(b'N', id_, ["Person"], {"name": "Alice"}, "abc") - alice, = hydrant.hydrate([struct]) - - with pytest.warns(DeprecationWarning, match="element_id"): - assert alice.id == id_ - assert alice.element_id == "abc" - assert alice.labels == {"Person"} - assert set(alice.keys()) == {"name"} - assert alice.get("name") == "Alice" - - -def test_can_hydrate_v1_relationship_structure(): - hydrant = DataHydrator() - - struct = Structure(b'R', 123, 456, 789, "KNOWS", {"since": 1999}) - rel, = hydrant.hydrate([struct]) - - with pytest.warns(DeprecationWarning, match="element_id"): - assert rel.id == 123 - with pytest.warns(DeprecationWarning, match="element_id"): - assert rel.start_node.id == 456 - with pytest.warns(DeprecationWarning, match="element_id"): - assert rel.end_node.id == 789 - # for backwards compatibility, the driver should compy the element_id - assert rel.element_id == "123" - assert rel.start_node.element_id == "456" - assert rel.end_node.element_id == "789" - assert rel.type == "KNOWS" - assert set(rel.keys()) == {"since"} - assert rel.get("since") 
== 1999 - - -@pytest.mark.parametrize("with_ids", (True, False)) -def test_can_hydrate_v2_relationship_structure(with_ids): - hydrant = DataHydrator() - - id_ = 123 if with_ids else None - start_id = 456 if with_ids else None - end_id = 789 if with_ids else None - - struct = Structure(b'R', id_, start_id, end_id, "KNOWS", {"since": 1999}, - "abc", "def", "ghi") - - rel, = hydrant.hydrate([struct]) - - with pytest.warns(DeprecationWarning, match="element_id"): - assert rel.id == id_ - with pytest.warns(DeprecationWarning, match="element_id"): - assert rel.start_node.id == start_id - with pytest.warns(DeprecationWarning, match="element_id"): - assert rel.end_node.id == end_id - # for backwards compatibility, the driver should compy the element_id - assert rel.element_id == "abc" - assert rel.start_node.element_id == "def" - assert rel.end_node.element_id == "ghi" - assert rel.type == "KNOWS" - assert set(rel.keys()) == {"since"} - assert rel.get("since") == 1999 - - -def test_hydrating_unknown_structure_returns_same(): - hydrant = DataHydrator() - - struct = Structure(b'?', "foo") - mystery, = hydrant.hydrate([struct]) - - assert mystery == struct - - -def test_can_hydrate_in_list(): - hydrant = DataHydrator() - - struct = Structure(b'N', 123, ["Person"], {"name": "Alice"}) - alice_in_list, = hydrant.hydrate([[struct]]) - - assert isinstance(alice_in_list, list) - - alice, = alice_in_list - - assert alice.id == 123 - assert alice.labels == {"Person"} - assert set(alice.keys()) == {"name"} - assert alice.get("name") == "Alice" - - -def test_can_hydrate_in_dict(): - hydrant = DataHydrator() - - struct = Structure(b'N', 123, ["Person"], {"name": "Alice"}) - alice_in_dict, = hydrant.hydrate([{"foo": struct}]) - - assert isinstance(alice_in_dict, dict) - - alice = alice_in_dict["foo"] - - assert alice.id == 123 - assert alice.labels == {"Person"} - assert set(alice.keys()) == {"name"} - assert alice.get("name") == "Alice" diff --git 
a/tests/unit/common/test_import_neo4j.py b/tests/unit/common/test_import_neo4j.py index 01bfd5905..a42e6207e 100644 --- a/tests/unit/common/test_import_neo4j.py +++ b/tests/unit/common/test_import_neo4j.py @@ -137,7 +137,7 @@ def test_import_poolconfig(): def test_import_graph(): - import neo4j.graph as graph + from neo4j import graph def test_import_graph_node(): @@ -153,12 +153,12 @@ def test_import_graph_graph(): def test_import_spatial(): - import neo4j.spatial as spatial + from neo4j import spatial def test_import_time(): - import neo4j.time as time + from neo4j import time def test_import_exceptions(): - import neo4j.exceptions as exceptions + from neo4j import exceptions diff --git a/tests/unit/common/test_record.py b/tests/unit/common/test_record.py index 8999af832..be62f8b6c 100644 --- a/tests/unit/common/test_record.py +++ b/tests/unit/common/test_record.py @@ -18,10 +18,11 @@ import pytest -from neo4j.data import ( +from neo4j import Record +from neo4j._codec.hydration.v1 import HydrationHandler +from neo4j.graph import ( Graph, Node, - Record, ) @@ -283,8 +284,8 @@ def test_data(raw, keys, serialized): def test_data_relationship(): - g = Graph() - gh = Graph.Hydrator(g) + hydration_scope = HydrationHandler().new_hydration_scope() + gh = hydration_scope._graph_hydrator alice = gh.hydrate_node(1, {"Person"}, {"name": "Alice", "age": 33}) bob = gh.hydrate_node(2, {"Person"}, {"name": "Bob", "age": 44}) alice_knows_bob = gh.hydrate_relationship(1, alice.id, bob.id, "KNOWS", @@ -302,8 +303,8 @@ def test_data_relationship(): def test_data_unbound_relationship(): - g = Graph() - gh = Graph.Hydrator(g) + hydration_scope = HydrationHandler().new_hydration_scope() + gh = hydration_scope._graph_hydrator some_one_knows_some_one = gh.hydrate_relationship( 1, 42, 43, "KNOWS", {"since": 1999} ) @@ -313,8 +314,8 @@ def test_data_unbound_relationship(): @pytest.mark.parametrize("cyclic", (True, False)) def test_data_path(cyclic): - g = Graph() - gh = Graph.Hydrator(g) + 
hydration_scope = HydrationHandler().new_hydration_scope() + gh = hydration_scope._graph_hydrator alice = gh.hydrate_node(1, {"Person"}, {"name": "Alice", "age": 33}) bob = gh.hydrate_node(2, {"Person"}, {"name": "Bob", "age": 44}) if cyclic: diff --git a/tests/unit/common/test_types.py b/tests/unit/common/test_types.py index 8657288c9..e9dfec721 100644 --- a/tests/unit/common/test_types.py +++ b/tests/unit/common/test_types.py @@ -20,6 +20,7 @@ import pytest +from neo4j._codec.hydration.v1 import HydrationHandler from neo4j.graph import ( Graph, Node, @@ -40,8 +41,8 @@ (None, "foobar"), )) def test_can_create_node(id_, element_id): - g = Graph() - gh = Graph.Hydrator(g) + hydration_scope = HydrationHandler().new_hydration_scope() + gh = hydration_scope._graph_hydrator fields = [id_, {"Person"}, {"name": "Alice", "age": 33}] if element_id is not None: @@ -74,8 +75,8 @@ def test_can_create_node(id_, element_id): def test_node_with_null_properties(): - g = Graph() - gh = Graph.Hydrator(g) + hydration_scope = HydrationHandler().new_hydration_scope() + gh = hydration_scope._graph_hydrator stuff = gh.hydrate_node(1, (), {"good": ["puppies", "kittens"], "bad": None}) assert isinstance(stuff, Node) @@ -132,8 +133,8 @@ def test_node_hashing(legacy_id): def test_node_v1_repr(): - g = Graph() - gh = Graph.Hydrator(g) + hydration_scope = HydrationHandler().new_hydration_scope() + gh = hydration_scope._graph_hydrator alice = gh.hydrate_node(1, {"Person"}, {"name": "Alice"}) assert repr(alice) == ( "H", self.recv_buffer[:2]) - print("CHUNK SIZE %r" % chunk_size) - end = 2 + chunk_size - chunk_data, self.recv_buffer = self.recv_buffer[2:end], self.recv_buffer[end:] - return chunk_data - def pop_message(self): - data = bytearray() - while True: - chunk = self._pop_chunk() - print("CHUNK %r" % chunk) - if chunk: - data.extend(chunk) - elif data: - break # end of message - else: - continue # NOOP - header = data[0] - n_fields = header % 0x10 - tag = data[1] - buffer = 
UnpackableBuffer(data[2:]) - unpacker = Unpacker(buffer) - fields = [unpacker.unpack() for _ in range(n_fields)] - return tag, fields + assert self._messages + return self._messages.pop(None) def send_message(self, tag, *fields): - data = self.encode_message(tag, *fields) - self.sendall(struct_pack(">H", len(data)) + data + b"\x00\x00") - - @classmethod - def encode_message(cls, tag, *fields): - b = BytesIO() - packer = Packer(b) - for field in fields: - packer.pack(field) - return bytearray([0xB0 + len(fields), tag]) + b.getvalue() + assert self._outbox + self._outbox.append_message(tag, fields, None) + self._outbox.flush() class FakeSocketPair: - def __init__(self, address): - self.client = FakeSocket2(address) - self.server = FakeSocket2() + def __init__(self, address, packer_cls=None, unpacker_cls=None): + self.client = FakeSocket2( + address, packer_cls=packer_cls, unpacker_cls=unpacker_cls + ) + self.server = FakeSocket2( + packer_cls=packer_cls, unpacker_cls=unpacker_cls + ) self.client.on_send = self.server.inject self.server.on_send = self.client.inject diff --git a/tests/unit/sync/io/test__common.py b/tests/unit/sync/io/test__common.py index 27dad7cb9..0298573b5 100644 --- a/tests/unit/sync/io/test__common.py +++ b/tests/unit/sync/io/test__common.py @@ -18,33 +18,43 @@ import pytest +from neo4j._codec.packstream.v1 import PackableBuffer from neo4j._sync.io._common import Outbox +from ...._async_compat import mark_sync_test + @pytest.mark.parametrize(("chunk_size", "data", "result"), ( ( 2, - (bytes(range(10, 15)),), + bytes(range(10, 15)), bytes((0, 2, 10, 11, 0, 2, 12, 13, 0, 1, 14)) ), ( 2, - (bytes(range(10, 14)),), + bytes(range(10, 14)), bytes((0, 2, 10, 11, 0, 2, 12, 13)) ), ( 2, - (bytes((5, 6, 7)), bytes((8, 9))), - bytes((0, 2, 5, 6, 0, 2, 7, 8, 0, 1, 9)) + bytes((5,)), + bytes((0, 1, 5)) ), )) -def test_async_outbox_chunking(chunk_size, data, result): - outbox = Outbox(max_chunk_size=chunk_size) - assert bytes(outbox.view()) == b"" - for d in 
data: - outbox.write(d) - assert bytes(outbox.view()) == result - # make sure this works multiple times - assert bytes(outbox.view()) == result - outbox.clear() - assert bytes(outbox.view()) == b"" +@mark_sync_test +def test_async_outbox_chunking(chunk_size, data, result, mocker): + buffer = PackableBuffer() + socket_mock = mocker.Mock() + packer_mock = mocker.Mock() + packer_mock.return_value = packer_mock + packer_mock.new_packable_buffer.return_value = buffer + packer_mock.pack_struct.side_effect = \ + lambda *args, **kwargs: buffer.write(data) + outbox = Outbox(socket_mock, pytest.fail, packer_mock, chunk_size) + outbox.append_message(None, None, None) + socket_mock.sendall.assert_not_called() + assert outbox.flush() + socket_mock.sendall.assert_called_once_with(result + b"\x00\x00") + + assert not outbox.flush() + socket_mock.sendall.assert_called_once() diff --git a/tests/unit/sync/io/test_class_bolt3.py b/tests/unit/sync/io/test_class_bolt3.py index bfa63f4fd..e4b878039 100644 --- a/tests/unit/sync/io/test_class_bolt3.py +++ b/tests/unit/sync/io/test_class_bolt3.py @@ -72,7 +72,7 @@ def test_db_extra_not_supported_in_run(fake_socket): @mark_sync_test def test_simple_discard(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, Bolt3.UNPACKER_CLS) connection = Bolt3(address, socket, PoolConfig.max_connection_lifetime) connection.discard() connection.send_all() @@ -84,7 +84,7 @@ def test_simple_discard(fake_socket): @mark_sync_test def test_simple_pull(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, Bolt3.UNPACKER_CLS) connection = Bolt3(address, socket, PoolConfig.max_connection_lifetime) connection.pull() connection.send_all() @@ -99,9 +99,11 @@ def test_hint_recv_timeout_seconds_gets_ignored( fake_socket_pair, recv_timeout, mocker ): address = ("127.0.0.1", 7687) - sockets = fake_socket_pair(address) + sockets = fake_socket_pair( + address, 
Bolt3.PACKER_CLS, Bolt3.UNPACKER_CLS + ) sockets.client.settimeout = mocker.Mock() - sockets.server.send_message(0x70, { + sockets.server.send_message(b"\x70", { "server": "Neo4j/3.5.0", "hints": {"connection.recv_timeout_seconds": recv_timeout}, }) diff --git a/tests/unit/sync/io/test_class_bolt4x0.py b/tests/unit/sync/io/test_class_bolt4x0.py index e1c0a5ccd..c6228e564 100644 --- a/tests/unit/sync/io/test_class_bolt4x0.py +++ b/tests/unit/sync/io/test_class_bolt4x0.py @@ -57,7 +57,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): @mark_sync_test def test_db_extra_in_begin(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, Bolt4x0.UNPACKER_CLS) connection = Bolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.begin(db="something") connection.send_all() @@ -70,7 +70,7 @@ def test_db_extra_in_begin(fake_socket): @mark_sync_test def test_db_extra_in_run(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, Bolt4x0.UNPACKER_CLS) connection = Bolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.run("", {}, db="something") connection.send_all() @@ -85,7 +85,7 @@ def test_db_extra_in_run(fake_socket): @mark_sync_test def test_n_extra_in_discard(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, Bolt4x0.UNPACKER_CLS) connection = Bolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666) connection.send_all() @@ -105,7 +105,7 @@ def test_n_extra_in_discard(fake_socket): @mark_sync_test def test_qid_extra_in_discard(fake_socket, test_input, expected): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, Bolt4x0.UNPACKER_CLS) connection = Bolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.discard(qid=test_input) connection.send_all() @@ -125,7 +125,7 @@ def 
test_qid_extra_in_discard(fake_socket, test_input, expected): @mark_sync_test def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, Bolt4x0.UNPACKER_CLS) connection = Bolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) connection.send_all() @@ -145,7 +145,7 @@ def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): @mark_sync_test def test_n_extra_in_pull(fake_socket, test_input, expected): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, Bolt4x0.UNPACKER_CLS) connection = Bolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=test_input) connection.send_all() @@ -165,7 +165,7 @@ def test_n_extra_in_pull(fake_socket, test_input, expected): @mark_sync_test def test_qid_extra_in_pull(fake_socket, test_input, expected): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, Bolt4x0.UNPACKER_CLS) connection = Bolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.pull(qid=test_input) connection.send_all() @@ -178,7 +178,7 @@ def test_qid_extra_in_pull(fake_socket, test_input, expected): @mark_sync_test def test_n_and_qid_extras_in_pull(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, Bolt4x0.UNPACKER_CLS) connection = Bolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) connection.send_all() @@ -194,9 +194,11 @@ def test_hint_recv_timeout_seconds_gets_ignored( fake_socket_pair, recv_timeout, mocker ): address = ("127.0.0.1", 7687) - sockets = fake_socket_pair(address) + sockets = fake_socket_pair(address, + packer_cls=Bolt4x0.PACKER_CLS, + unpacker_cls=Bolt4x0.UNPACKER_CLS) sockets.client.settimeout = mocker.MagicMock() - sockets.server.send_message(0x70, { + 
sockets.server.send_message(b"\x70", { "server": "Neo4j/4.0.0", "hints": {"connection.recv_timeout_seconds": recv_timeout}, }) diff --git a/tests/unit/sync/io/test_class_bolt4x1.py b/tests/unit/sync/io/test_class_bolt4x1.py index 9a32fa8e3..283c0b475 100644 --- a/tests/unit/sync/io/test_class_bolt4x1.py +++ b/tests/unit/sync/io/test_class_bolt4x1.py @@ -57,7 +57,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): @mark_sync_test def test_db_extra_in_begin(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, Bolt4x1.UNPACKER_CLS) connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.begin(db="something") connection.send_all() @@ -70,7 +70,7 @@ def test_db_extra_in_begin(fake_socket): @mark_sync_test def test_db_extra_in_run(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, Bolt4x1.UNPACKER_CLS) connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.run("", {}, db="something") connection.send_all() @@ -85,7 +85,7 @@ def test_db_extra_in_run(fake_socket): @mark_sync_test def test_n_extra_in_discard(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, Bolt4x1.UNPACKER_CLS) connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666) connection.send_all() @@ -105,7 +105,7 @@ def test_n_extra_in_discard(fake_socket): @mark_sync_test def test_qid_extra_in_discard(fake_socket, test_input, expected): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, Bolt4x1.UNPACKER_CLS) connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.discard(qid=test_input) connection.send_all() @@ -126,7 +126,7 @@ def test_qid_extra_in_discard(fake_socket, test_input, expected): def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): # 
python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_n_and_qid_extras_in_discard address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, Bolt4x1.UNPACKER_CLS) connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) connection.send_all() @@ -146,7 +146,7 @@ def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): @mark_sync_test def test_n_extra_in_pull(fake_socket, test_input, expected): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, Bolt4x1.UNPACKER_CLS) connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=test_input) connection.send_all() @@ -167,7 +167,7 @@ def test_n_extra_in_pull(fake_socket, test_input, expected): def test_qid_extra_in_pull(fake_socket, test_input, expected): # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_qid_extra_in_pull address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, Bolt4x1.UNPACKER_CLS) connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.pull(qid=test_input) connection.send_all() @@ -180,7 +180,7 @@ def test_qid_extra_in_pull(fake_socket, test_input, expected): @mark_sync_test def test_n_and_qid_extras_in_pull(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, Bolt4x1.UNPACKER_CLS) connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) connection.send_all() @@ -193,15 +193,17 @@ def test_n_and_qid_extras_in_pull(fake_socket): @mark_sync_test def test_hello_passes_routing_metadata(fake_socket_pair): address = ("127.0.0.1", 7687) - sockets = fake_socket_pair(address) - sockets.server.send_message(0x70, {"server": "Neo4j/4.1.0"}) + sockets = fake_socket_pair(address, + packer_cls=Bolt4x1.PACKER_CLS, + 
unpacker_cls=Bolt4x1.UNPACKER_CLS) + sockets.server.send_message(b"\x70", {"server": "Neo4j/4.1.0"}) connection = Bolt4x1( address, sockets.client, PoolConfig.max_connection_lifetime, routing_context={"foo": "bar"} ) connection.hello() tag, fields = sockets.server.pop_message() - assert tag == 0x01 + assert tag == b"\x01" assert len(fields) == 1 assert fields[0]["routing"] == {"foo": "bar"} @@ -212,9 +214,11 @@ def test_hint_recv_timeout_seconds_gets_ignored( fake_socket_pair, recv_timeout, mocker ): address = ("127.0.0.1", 7687) - sockets = fake_socket_pair(address) + sockets = fake_socket_pair(address, + packer_cls=Bolt4x1.PACKER_CLS, + unpacker_cls=Bolt4x1.UNPACKER_CLS) sockets.client.settimeout = mocker.Mock() - sockets.server.send_message(0x70, { + sockets.server.send_message(b"\x70", { "server": "Neo4j/4.1.0", "hints": {"connection.recv_timeout_seconds": recv_timeout}, }) diff --git a/tests/unit/sync/io/test_class_bolt4x2.py b/tests/unit/sync/io/test_class_bolt4x2.py index 145bc0850..04885598f 100644 --- a/tests/unit/sync/io/test_class_bolt4x2.py +++ b/tests/unit/sync/io/test_class_bolt4x2.py @@ -57,7 +57,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): @mark_sync_test def test_db_extra_in_begin(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, Bolt4x2.UNPACKER_CLS) connection = Bolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.begin(db="something") connection.send_all() @@ -70,7 +70,7 @@ def test_db_extra_in_begin(fake_socket): @mark_sync_test def test_db_extra_in_run(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, Bolt4x2.UNPACKER_CLS) connection = Bolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.run("", {}, db="something") connection.send_all() @@ -85,7 +85,7 @@ def test_db_extra_in_run(fake_socket): @mark_sync_test def test_n_extra_in_discard(fake_socket): address = ("127.0.0.1", 7687) - 
socket = fake_socket(address) + socket = fake_socket(address, Bolt4x2.UNPACKER_CLS) connection = Bolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666) connection.send_all() @@ -105,7 +105,7 @@ def test_n_extra_in_discard(fake_socket): @mark_sync_test def test_qid_extra_in_discard(fake_socket, test_input, expected): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, Bolt4x2.UNPACKER_CLS) connection = Bolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.discard(qid=test_input) connection.send_all() @@ -126,7 +126,7 @@ def test_qid_extra_in_discard(fake_socket, test_input, expected): def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_n_and_qid_extras_in_discard address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, Bolt4x2.UNPACKER_CLS) connection = Bolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) connection.send_all() @@ -146,7 +146,7 @@ def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): @mark_sync_test def test_n_extra_in_pull(fake_socket, test_input, expected): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, Bolt4x2.UNPACKER_CLS) connection = Bolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=test_input) connection.send_all() @@ -167,7 +167,7 @@ def test_n_extra_in_pull(fake_socket, test_input, expected): def test_qid_extra_in_pull(fake_socket, test_input, expected): # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_qid_extra_in_pull address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, Bolt4x2.UNPACKER_CLS) connection = Bolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.pull(qid=test_input) connection.send_all() @@ -180,7 
+180,7 @@ def test_qid_extra_in_pull(fake_socket, test_input, expected): @mark_sync_test def test_n_and_qid_extras_in_pull(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, Bolt4x2.UNPACKER_CLS) connection = Bolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) connection.send_all() @@ -193,15 +193,17 @@ def test_n_and_qid_extras_in_pull(fake_socket): @mark_sync_test def test_hello_passes_routing_metadata(fake_socket_pair): address = ("127.0.0.1", 7687) - sockets = fake_socket_pair(address) - sockets.server.send_message(0x70, {"server": "Neo4j/4.2.0"}) + sockets = fake_socket_pair(address, + packer_cls=Bolt4x2.PACKER_CLS, + unpacker_cls=Bolt4x2.UNPACKER_CLS) + sockets.server.send_message(b"\x70", {"server": "Neo4j/4.2.0"}) connection = Bolt4x2( address, sockets.client, PoolConfig.max_connection_lifetime, routing_context={"foo": "bar"} ) connection.hello() tag, fields = sockets.server.pop_message() - assert tag == 0x01 + assert tag == b"\x01" assert len(fields) == 1 assert fields[0]["routing"] == {"foo": "bar"} @@ -212,9 +214,11 @@ def test_hint_recv_timeout_seconds_gets_ignored( fake_socket_pair, recv_timeout, mocker ): address = ("127.0.0.1", 7687) - sockets = fake_socket_pair(address) + sockets = fake_socket_pair(address, + packer_cls=Bolt4x2.PACKER_CLS, + unpacker_cls=Bolt4x2.UNPACKER_CLS) sockets.client.settimeout = mocker.Mock() - sockets.server.send_message(0x70, { + sockets.server.send_message(b"\x70", { "server": "Neo4j/4.2.0", "hints": {"connection.recv_timeout_seconds": recv_timeout}, }) diff --git a/tests/unit/sync/io/test_class_bolt4x3.py b/tests/unit/sync/io/test_class_bolt4x3.py index fbde3872e..321494286 100644 --- a/tests/unit/sync/io/test_class_bolt4x3.py +++ b/tests/unit/sync/io/test_class_bolt4x3.py @@ -59,7 +59,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): @mark_sync_test def test_db_extra_in_begin(fake_socket): address = ("127.0.0.1", 
7687) - socket = fake_socket(address) + socket = fake_socket(address, Bolt4x3.UNPACKER_CLS) connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.begin(db="something") connection.send_all() @@ -72,7 +72,7 @@ def test_db_extra_in_begin(fake_socket): @mark_sync_test def test_db_extra_in_run(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, Bolt4x3.UNPACKER_CLS) connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.run("", {}, db="something") connection.send_all() @@ -87,7 +87,7 @@ def test_db_extra_in_run(fake_socket): @mark_sync_test def test_n_extra_in_discard(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, Bolt4x3.UNPACKER_CLS) connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666) connection.send_all() @@ -107,7 +107,7 @@ def test_n_extra_in_discard(fake_socket): @mark_sync_test def test_qid_extra_in_discard(fake_socket, test_input, expected): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, Bolt4x3.UNPACKER_CLS) connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.discard(qid=test_input) connection.send_all() @@ -128,7 +128,7 @@ def test_qid_extra_in_discard(fake_socket, test_input, expected): def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_n_and_qid_extras_in_discard address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, Bolt4x3.UNPACKER_CLS) connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) connection.send_all() @@ -148,7 +148,7 @@ def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): @mark_sync_test def test_n_extra_in_pull(fake_socket, test_input, 
expected): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, Bolt4x3.UNPACKER_CLS) connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=test_input) connection.send_all() @@ -169,7 +169,7 @@ def test_n_extra_in_pull(fake_socket, test_input, expected): def test_qid_extra_in_pull(fake_socket, test_input, expected): # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_qid_extra_in_pull address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, Bolt4x3.UNPACKER_CLS) connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.pull(qid=test_input) connection.send_all() @@ -182,7 +182,7 @@ def test_qid_extra_in_pull(fake_socket, test_input, expected): @mark_sync_test def test_n_and_qid_extras_in_pull(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, Bolt4x3.UNPACKER_CLS) connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) connection.send_all() @@ -195,15 +195,17 @@ def test_n_and_qid_extras_in_pull(fake_socket): @mark_sync_test def test_hello_passes_routing_metadata(fake_socket_pair): address = ("127.0.0.1", 7687) - sockets = fake_socket_pair(address) - sockets.server.send_message(0x70, {"server": "Neo4j/4.3.0"}) + sockets = fake_socket_pair(address, + packer_cls=Bolt4x3.PACKER_CLS, + unpacker_cls=Bolt4x3.UNPACKER_CLS) + sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.0"}) connection = Bolt4x3( address, sockets.client, PoolConfig.max_connection_lifetime, routing_context={"foo": "bar"} ) connection.hello() tag, fields = sockets.server.pop_message() - assert tag == 0x01 + assert tag == b"\x01" assert len(fields) == 1 assert fields[0]["routing"] == {"foo": "bar"} @@ -225,10 +227,12 @@ def test_hint_recv_timeout_seconds( fake_socket_pair, hints, valid, caplog, mocker ): address = ("127.0.0.1", 7687) - 
sockets = fake_socket_pair(address) + sockets = fake_socket_pair(address, + packer_cls=Bolt4x3.PACKER_CLS, + unpacker_cls=Bolt4x3.UNPACKER_CLS) sockets.client.settimeout = mocker.Mock() sockets.server.send_message( - 0x70, {"server": "Neo4j/4.3.0", "hints": hints} + b"\x70", {"server": "Neo4j/4.3.0", "hints": hints} ) connection = Bolt4x3( address, sockets.client, PoolConfig.max_connection_lifetime diff --git a/tests/unit/sync/io/test_class_bolt4x4.py b/tests/unit/sync/io/test_class_bolt4x4.py index 665731727..9f2515b4e 100644 --- a/tests/unit/sync/io/test_class_bolt4x4.py +++ b/tests/unit/sync/io/test_class_bolt4x4.py @@ -68,7 +68,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): @mark_sync_test def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, Bolt4x4.UNPACKER_CLS) connection = Bolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.begin(*args, **kwargs) connection.send_all() @@ -89,7 +89,7 @@ def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): @mark_sync_test def test_extra_in_run(fake_socket, args, kwargs, expected_fields): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, Bolt4x4.UNPACKER_CLS) connection = Bolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.run(*args, **kwargs) connection.send_all() @@ -101,7 +101,7 @@ def test_extra_in_run(fake_socket, args, kwargs, expected_fields): @mark_sync_test def test_n_extra_in_discard(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, Bolt4x4.UNPACKER_CLS) connection = Bolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666) connection.send_all() @@ -121,7 +121,7 @@ def test_n_extra_in_discard(fake_socket): @mark_sync_test def test_qid_extra_in_discard(fake_socket, test_input, expected): address = ("127.0.0.1", 7687) 
- socket = fake_socket(address) + socket = fake_socket(address, Bolt4x4.UNPACKER_CLS) connection = Bolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.discard(qid=test_input) connection.send_all() @@ -142,7 +142,7 @@ def test_qid_extra_in_discard(fake_socket, test_input, expected): def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_n_and_qid_extras_in_discard address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, Bolt4x4.UNPACKER_CLS) connection = Bolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) connection.send_all() @@ -162,7 +162,7 @@ def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): @mark_sync_test def test_n_extra_in_pull(fake_socket, test_input, expected): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, Bolt4x4.UNPACKER_CLS) connection = Bolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=test_input) connection.send_all() @@ -183,7 +183,7 @@ def test_n_extra_in_pull(fake_socket, test_input, expected): def test_qid_extra_in_pull(fake_socket, test_input, expected): # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_qid_extra_in_pull address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, Bolt4x4.UNPACKER_CLS) connection = Bolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.pull(qid=test_input) connection.send_all() @@ -196,7 +196,7 @@ def test_qid_extra_in_pull(fake_socket, test_input, expected): @mark_sync_test def test_n_and_qid_extras_in_pull(fake_socket): address = ("127.0.0.1", 7687) - socket = fake_socket(address) + socket = fake_socket(address, Bolt4x4.UNPACKER_CLS) connection = Bolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) connection.send_all() 
@@ -209,15 +209,17 @@ def test_n_and_qid_extras_in_pull(fake_socket): @mark_sync_test def test_hello_passes_routing_metadata(fake_socket_pair): address = ("127.0.0.1", 7687) - sockets = fake_socket_pair(address) - sockets.server.send_message(0x70, {"server": "Neo4j/4.4.0"}) + sockets = fake_socket_pair(address, + packer_cls=Bolt4x4.PACKER_CLS, + unpacker_cls=Bolt4x4.UNPACKER_CLS) + sockets.server.send_message(b"\x70", {"server": "Neo4j/4.4.0"}) connection = Bolt4x4( address, sockets.client, PoolConfig.max_connection_lifetime, routing_context={"foo": "bar"} ) connection.hello() tag, fields = sockets.server.pop_message() - assert tag == 0x01 + assert tag == b"\x01" assert len(fields) == 1 assert fields[0]["routing"] == {"foo": "bar"} @@ -239,10 +241,12 @@ def test_hint_recv_timeout_seconds( fake_socket_pair, hints, valid, caplog, mocker ): address = ("127.0.0.1", 7687) - sockets = fake_socket_pair(address) + sockets = fake_socket_pair(address, + packer_cls=Bolt4x4.PACKER_CLS, + unpacker_cls=Bolt4x4.UNPACKER_CLS) sockets.client.settimeout = mocker.MagicMock() sockets.server.send_message( - 0x70, {"server": "Neo4j/4.3.4", "hints": hints} + b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} ) connection = Bolt4x4( address, sockets.client, PoolConfig.max_connection_lifetime diff --git a/tests/unit/sync/work/test_result.py b/tests/unit/sync/work/test_result.py index 4edeec99e..e8d05b5aa 100644 --- a/tests/unit/sync/work/test_result.py +++ b/tests/unit/sync/work/test_result.py @@ -14,9 +14,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from re import match -from unittest import mock + + import warnings +from unittest import mock import pandas as pd import pytest @@ -33,9 +34,9 @@ Version, ) from neo4j._async_compat.util import Util -from neo4j.data import ( - DataDehydrator, - DataHydrator, +from neo4j._codec.hydration.v1 import HydrationHandler +from neo4j._codec.packstream import Structure +from neo4j._data import ( Node, Relationship, ) @@ -44,7 +45,6 @@ EntitySetView, Graph, ) -from neo4j.packstream import Structure from ...._async_compat import mark_sync_test @@ -52,9 +52,24 @@ class Records: def __init__(self, fields, records): self.fields = tuple(fields) + self.hydration_scope = HydrationHandler().new_hydration_scope() self.records = tuple(records) + self._hydrate_records() + assert all(len(self.fields) == len(r) for r in self.records) + def _hydrate_records(self): + def _hydrate(value): + if type(value) in self.hydration_scope.hydration_hooks: + return self.hydration_scope.hydration_hooks[type(value)](value) + if isinstance(value, (list, tuple)): + return type(value)(_hydrate(v) for v in value) + if isinstance(value, dict): + return {k: _hydrate(v) for k, v in value.items()} + return value + + self.records = tuple(_hydrate(r) for r in self.records) + def __len__(self): return self.records.__len__() @@ -113,6 +128,7 @@ def __init__(self, records=None, run_meta=None, summary_meta=None, self.summary_meta = summary_meta ConnectionStub.server_info.update({"server": "Neo4j/4.3.0"}) self.unresolved_address = None + self._new_hydration_scope_called = False def send_all(self): self.sent += self.queued @@ -187,10 +203,20 @@ def pull(self, *args, **kwargs): def defunct(self): return False + def new_hydration_scope(self): + class FakeHydrationScope: + hydration_hooks = None + dehydration_hooks = None -class HydratorStub(DataHydrator): - def hydrate(self, values): - return values + def get_graph(self): + return Graph() + + if len(self._records) > 1: + return FakeHydrationScope() + assert not 
self._new_hydration_scope_called + assert self._records + self._new_hydration_scope_called = True + return self._records[0].hydration_scope def noop(*_, **__): @@ -254,7 +280,7 @@ def fetch_and_compare_all_records( @mark_sync_test def test_result_iteration(method, records): connection = ConnectionStub(records=Records(["x"], records)) - result = Result(connection, HydratorStub(), 2, noop, noop) + result = Result(connection, 2, noop, noop) result._run("CYPHER", {}, None, None, "r", None) fetch_and_compare_all_records(result, "x", records, method) @@ -263,7 +289,7 @@ def test_result_iteration(method, records): def test_result_iteration_mixed_methods(): records = [[i] for i in range(10)] connection = ConnectionStub(records=Records(["x"], records)) - result = Result(connection, HydratorStub(), 4, noop, noop) + result = Result(connection, 4, noop, noop) result._run("CYPHER", {}, None, None, "r", None) iter1 = Util.iter(result) iter2 = Util.iter(result) @@ -299,9 +325,9 @@ def test_parallel_result_iteration(method, invert_fetch): connection = ConnectionStub( records=(Records(["x"], records1), Records(["x"], records2)) ) - result1 = Result(connection, HydratorStub(), 2, noop, noop) + result1 = Result(connection, 2, noop, noop) result1._run("CYPHER1", {}, None, None, "r", None) - result2 = Result(connection, HydratorStub(), 2, noop, noop) + result2 = Result(connection, 2, noop, noop) result2._run("CYPHER2", {}, None, None, "r", None) if invert_fetch: fetch_and_compare_all_records( @@ -329,9 +355,9 @@ def test_interwoven_result_iteration(method, invert_fetch): connection = ConnectionStub( records=(Records(["x"], records1), Records(["y"], records2)) ) - result1 = Result(connection, HydratorStub(), 2, noop, noop) + result1 = Result(connection, 2, noop, noop) result1._run("CYPHER1", {}, None, None, "r", None) - result2 = Result(connection, HydratorStub(), 2, noop, noop) + result2 = Result(connection, 2, noop, noop) result2._run("CYPHER2", {}, None, None, "r", None) start = 0 
for n in (1, 2, 3, 1, None): @@ -358,7 +384,7 @@ def test_interwoven_result_iteration(method, invert_fetch): @mark_sync_test def test_result_peek(records, fetch_size): connection = ConnectionStub(records=Records(["x"], records)) - result = Result(connection, HydratorStub(), fetch_size, noop, noop) + result = Result(connection, fetch_size, noop, noop) result._run("CYPHER", {}, None, None, "r", None) for i in range(len(records) + 1): record = result.peek() @@ -381,7 +407,7 @@ def test_result_single_non_strict(records, fetch_size, default): kwargs["strict"] = False connection = ConnectionStub(records=Records(["x"], records)) - result = Result(connection, HydratorStub(), fetch_size, noop, noop) + result = Result(connection, fetch_size, noop, noop) result._run("CYPHER", {}, None, None, "r", None) if len(records) == 0: assert result.single(**kwargs) is None @@ -400,7 +426,7 @@ def test_result_single_non_strict(records, fetch_size, default): @mark_sync_test def test_result_single_strict(records, fetch_size): connection = ConnectionStub(records=Records(["x"], records)) - result = Result(connection, HydratorStub(), fetch_size, noop, noop) + result = Result(connection, fetch_size, noop, noop) result._run("CYPHER", {}, None, None, "r", None) try: record = result.single(strict=True) @@ -427,7 +453,7 @@ def test_result_single_strict(records, fetch_size): @mark_sync_test def test_result_single_exhausts_records(records, fetch_size, strict): connection = ConnectionStub(records=Records(["x"], records)) - result = Result(connection, HydratorStub(), fetch_size, noop, noop) + result = Result(connection, fetch_size, noop, noop) result._run("CYPHER", {}, None, None, "r", None) try: with warnings.catch_warnings(): @@ -449,7 +475,7 @@ def test_result_single_exhausts_records(records, fetch_size, strict): @mark_sync_test def test_result_fetch(records, fetch_size, strict): connection = ConnectionStub(records=Records(["x"], records)) - result = Result(connection, HydratorStub(), fetch_size, 
noop, noop) + result = Result(connection, fetch_size, noop, noop) result._run("CYPHER", {}, None, None, "r", None) assert result.fetch(0) == [] assert result.fetch(-1) == [] @@ -461,7 +487,7 @@ def test_result_fetch(records, fetch_size, strict): @mark_sync_test def test_keys_are_available_before_and_after_stream(): connection = ConnectionStub(records=Records(["x"], [[1], [2]])) - result = Result(connection, HydratorStub(), 1, noop, noop) + result = Result(connection, 1, noop, noop) result._run("CYPHER", {}, None, None, "r", None) assert list(result.keys()) == ["x"] Util.list(result) @@ -477,7 +503,7 @@ def test_consume(records, consume_one, summary_meta, consume_times): connection = ConnectionStub( records=Records(["x"], records), summary_meta=summary_meta ) - result = Result(connection, HydratorStub(), 1, noop, noop) + result = Result(connection, 1, noop, noop) result._run("CYPHER", {}, None, None, "r", None) if consume_one: try: @@ -512,7 +538,7 @@ def test_time_in_summary(t_first, t_last): summary_meta=summary_meta ) - result = Result(connection, HydratorStub(), 1, noop, noop) + result = Result(connection, 1, noop, noop) result._run("CYPHER", {}, None, None, "r", None) summary = result.consume() @@ -534,7 +560,7 @@ def test_time_in_summary(t_first, t_last): def test_counts_in_summary(): connection = ConnectionStub(records=Records(["n"], [[1], [2]])) - result = Result(connection, HydratorStub(), 1, noop, noop) + result = Result(connection, 1, noop, noop) result._run("CYPHER", {}, None, None, "r", None) summary = result.consume() @@ -548,7 +574,7 @@ def test_query_type(query_type): records=Records(["n"], [[1], [2]]), summary_meta={"type": query_type} ) - result = Result(connection, HydratorStub(), 1, noop, noop) + result = Result(connection, 1, noop, noop) result._run("CYPHER", {}, None, None, "r", None) summary = result.consume() @@ -563,7 +589,7 @@ def test_data(num_records): records=Records(["n"], [[i + 1] for i in range(num_records)]) ) - result = 
Result(connection, HydratorStub(), 1, noop, noop) + result = Result(connection, 1, noop, noop) result._run("CYPHER", {}, None, None, "r", None) result._buffer_all() records = result._record_buffer.copy() @@ -578,6 +604,7 @@ def test_data(num_records): assert record.data.called_once_with("hello", "world") +# TODO: dehydration now happens on a much lower level @pytest.mark.parametrize("records", ( Records(["n"], []), Records(["n"], [[42], [69], [420], [1337]]), @@ -603,8 +630,9 @@ def test_result_graph(records, scripted_connection): "on_summary": None }), )) - result = Result(scripted_connection, DataHydrator(), 1, noop, - noop) + scripted_connection.new_hydration_scope.return_value = \ + records.hydration_scope + result = Result(scripted_connection, 1, noop, noop) result._run("CYPHER", {}, None, None, "r", None) graph = result.graph() assert isinstance(graph, Graph) @@ -702,7 +730,7 @@ def test_result_graph(records, scripted_connection): @mark_sync_test def test_to_df(keys, values, types, instances, test_default_expand): connection = ConnectionStub(records=Records(keys, values)) - result = Result(connection, DataHydrator(), 1, noop, noop) + result = Result(connection, 1, noop, noop) result._run("CYPHER", {}, None, None, "r", None) if test_default_expand: df = result.to_df() @@ -807,12 +835,12 @@ def test_to_df(keys, values, types, instances, test_default_expand): ( ["n"], list(zip(( - Structure(b"N", 0, ["LABEL_A"], - {"a": 1, "b": 2, "d": 1}, "00"), - Structure(b"N", 2, ["LABEL_B"], - {"a": 1, "c": 1.2, "d": 2}, "02"), - Structure(b"N", 1, ["LABEL_A", "LABEL_B"], - {"a": [1, "a"], "d": 3}, "01"), + Node(None, "00", 0, ["LABEL_A"], + {"a": 1, "b": 2, "d": 1}), + Node(None, "02", 2, ["LABEL_B"], + {"a": 1, "c": 1.2, "d": 2}), + Node(None, "01", 1, ["LABEL_A", "LABEL_B"], + {"a": [1, "a"], "d": 3}), ))), [ "n().element_id", "n().labels", "n().prop.a", "n().prop.b", @@ -848,11 +876,7 @@ def test_to_df(keys, values, types, instances, test_default_expand): ), ( ["dt"], - 
[ - DataDehydrator().dehydrate(( - neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6), - )), - ], + [[neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6)]], ["dt"], [[neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6)]], ["object"], @@ -863,7 +887,7 @@ def test_to_df(keys, values, types, instances, test_default_expand): def test_to_df_expand(keys, values, expected_columns, expected_rows, expected_types): connection = ConnectionStub(records=Records(keys, values)) - result = Result(connection, DataHydrator(), 1, noop, noop) + result = Result(connection, 1, noop, noop) result._run("CYPHER", {}, None, None, "r", None) df = result.to_df(expand=True) @@ -895,9 +919,7 @@ def test_to_df_expand(keys, values, expected_columns, expected_rows, ( ["dt"], [ - DataDehydrator().dehydrate(( - neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6), - )), + [neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6)], ], pd.DataFrame( [[pd.Timestamp("2022-01-02 03:04:05.000000006")]], @@ -908,9 +930,7 @@ def test_to_df_expand(keys, values, expected_columns, expected_rows, ( ["d"], [ - DataDehydrator().dehydrate(( - neo4j_time.Date(2222, 2, 22), - )), + [neo4j_time.Date(2222, 2, 22)], ], pd.DataFrame( [[pd.Timestamp("2222-02-22")]], @@ -921,11 +941,11 @@ def test_to_df_expand(keys, values, expected_columns, expected_rows, ( ["dt_tz"], [ - DataDehydrator().dehydrate(( + [ pytz.timezone("Europe/Stockholm").localize( neo4j_time.DateTime(1970, 1, 1, 0, 0, 0, 0) ), - )), + ], ], pd.DataFrame( [[ @@ -941,17 +961,13 @@ def test_to_df_expand(keys, values, expected_columns, expected_rows, ["mixed"], [ [None], - DataDehydrator().dehydrate(( - neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6), - )), - DataDehydrator().dehydrate(( - neo4j_time.Date(2222, 2, 22), - )), - DataDehydrator().dehydrate(( + [neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6)], + [neo4j_time.Date(2222, 2, 22)], + [ pytz.timezone("Europe/Stockholm").localize( neo4j_time.DateTime(1970, 1, 1, 0, 0, 0, 0) ), - )), + ], ], pd.DataFrame( [ @@ -971,18 +987,14 @@ def 
test_to_df_expand(keys, values, expected_columns, expected_rows, ( ["mixed"], [ - DataDehydrator().dehydrate(( - neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6), - )), - DataDehydrator().dehydrate(( - neo4j_time.Date(2222, 2, 22), - )), + [neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6)], + [neo4j_time.Date(2222, 2, 22)], [None], - DataDehydrator().dehydrate(( + [ pytz.timezone("Europe/Stockholm").localize( neo4j_time.DateTime(1970, 1, 1, 0, 0, 0, 0) ), - )), + ], ], pd.DataFrame( [ @@ -1002,17 +1014,13 @@ def test_to_df_expand(keys, values, expected_columns, expected_rows, ( ["mixed"], [ - DataDehydrator().dehydrate(( - neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6), - )), - DataDehydrator().dehydrate(( - neo4j_time.Date(2222, 2, 22), - )), - DataDehydrator().dehydrate(( + [neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6),], + [neo4j_time.Date(2222, 2, 22),], + [ pytz.timezone("Europe/Stockholm").localize( neo4j_time.DateTime(1970, 1, 1, 0, 0, 0, 0) ), - )), + ], [None], ], pd.DataFrame( @@ -1052,9 +1060,7 @@ def test_to_df_expand(keys, values, expected_columns, expected_rows, ], [ None, - *DataDehydrator().dehydrate(( - neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6), - )), + neo4j_time.DateTime(2022, 1, 2, 3, 4, 5, 6), 1.234, ], ], @@ -1080,7 +1086,7 @@ def test_to_df_expand(keys, values, expected_columns, expected_rows, @mark_sync_test def test_to_df_parse_dates(keys, values, expected_df, expand): connection = ConnectionStub(records=Records(keys, values)) - result = Result(connection, DataHydrator(), 1, noop, noop) + result = Result(connection, 1, noop, noop) result._run("CYPHER", {}, None, None, "r", None) df = result.to_df(expand=expand, parse_dates=True) diff --git a/tests/unit/sync/work/test_session.py b/tests/unit/sync/work/test_session.py index 92edd9aab..839d24a37 100644 --- a/tests/unit/sync/work/test_session.py +++ b/tests/unit/sync/work/test_session.py @@ -37,7 +37,15 @@ @pytest.fixture() def pool(fake_connection_generator, mocker): pool = mocker.Mock(spec=IOPool) - 
pool.acquire.side_effect = iter(fake_connection_generator, 0) + assert not hasattr(pool, "acquired_connection_mocks") + pool.acquired_connection_mocks = [] + + def acquire_side_effect(*_, **__): + connection = fake_connection_generator() + pool.acquired_connection_mocks.append(connection) + return connection + + pool.acquire.side_effect = acquire_side_effect return pool @@ -267,57 +275,46 @@ def test_session_tx_type(pool): assert isinstance(tx, Transaction) -@pytest.mark.parametrize(("parameters", "error_type"), ( - ({"x": None}, None), - ({"x": True}, None), - ({"x": False}, None), - ({"x": 123456789}, None), - ({"x": 3.1415926}, None), - ({"x": float("nan")}, None), - ({"x": float("inf")}, None), - ({"x": float("-inf")}, None), - ({"x": "foo"}, None), - ({"x": bytearray([0x00, 0x33, 0x66, 0x99, 0xCC, 0xFF])}, None), - ({"x": b"\x00\x33\x66\x99\xcc\xff"}, None), - ({"x": [1, 2, 3]}, None), - ({"x": ["a", "b", "c"]}, None), - ({"x": ["a", 2, 1.234]}, None), - ({"x": ["a", 2, ["c"]]}, None), - ({"x": {"one": "eins", "two": "zwei", "three": "drei"}}, None), - ({"x": {"one": ["eins", "uno", 1], "two": ["zwei", "dos", 2]}}, None), - - # maps must have string keys - ({"x": {1: 'eins', 2: 'zwei', 3: 'drei'}}, TypeError), - ({"x": {(1, 2): '1+2i', (2, 0): '2'}}, TypeError), +@pytest.mark.parametrize("parameters", ( + {"x": None}, + {"x": True}, + {"x": False}, + {"x": 123456789}, + {"x": 3.1415926}, + {"x": float("nan")}, + {"x": float("inf")}, + {"x": float("-inf")}, + {"x": "foo"}, + {"x": bytearray([0x00, 0x33, 0x66, 0x99, 0xCC, 0xFF])}, + {"x": b"\x00\x33\x66\x99\xcc\xff"}, + {"x": [1, 2, 3]}, + {"x": ["a", "b", "c"]}, + {"x": ["a", 2, 1.234]}, + {"x": ["a", 2, ["c"]]}, + {"x": {"one": "eins", "two": "zwei", "three": "drei"}}, + {"x": {"one": ["eins", "uno", 1], "two": ["zwei", "dos", 2]}}, )) @pytest.mark.parametrize("run_type", ("auto", "unmanaged", "managed")) @mark_sync_test def test_session_run_with_parameters( - pool, parameters, error_type, run_type + pool, 
parameters, run_type, mocker ): - @contextmanager - def raises(): - if error_type is not None: - with pytest.raises(error_type) as exc: - yield exc - else: - yield None - with Session(pool, SessionConfig()) as session: if run_type == "auto": - with raises(): - session.run("RETURN $x", **parameters) + session.run("RETURN $x", **parameters) elif run_type == "unmanaged": tx = session.begin_transaction() - with raises(): - tx.run("RETURN $x", **parameters) + tx.run("RETURN $x", **parameters) elif run_type == "managed": def work(tx): - with raises() as exc: - tx.run("RETURN $x", **parameters) - if exc is not None: - raise exc - with raises(): - session.write_transaction(work) + tx.run("RETURN $x", **parameters) + session.write_transaction(work) else: raise ValueError(run_type) + + assert len(pool.acquired_connection_mocks) == 1 + connection_mock = pool.acquired_connection_mocks[0] + assert connection_mock.run.called_once() + call = connection_mock.run.call_args + assert call.args[0] == "RETURN $x" + assert call.kwargs["parameters"] == parameters diff --git a/tests/unit/sync/work/test_transaction.py b/tests/unit/sync/work/test_transaction.py index 3c5dfcbee..9a3440faf 100644 --- a/tests/unit/sync/work/test_transaction.py +++ b/tests/unit/sync/work/test_transaction.py @@ -113,23 +113,6 @@ class OopsError(RuntimeError): assert tx_.closed() -@pytest.mark.parametrize(("parameters", "error_type"), ( - # maps must have string keys - ({"x": {1: 'eins', 2: 'zwei', 3: 'drei'}}, TypeError), - ({"x": {(1, 2): '1+2i', (2, 0): '2'}}, TypeError), - ({"x": uuid4()}, TypeError), -)) -@mark_sync_test -def test_transaction_run_with_invalid_parameters( - fake_connection, parameters, error_type -): - on_closed = MagicMock() - on_error = MagicMock() - tx = Transaction(fake_connection, 2, on_closed, on_error) - with pytest.raises(error_type): - tx.run("RETURN $x", **parameters) - - @mark_sync_test def test_transaction_run_takes_no_query_object(fake_connection): on_closed = MagicMock() From 
e06e920bfbb5395fb30fa41ac11579e05cda623d Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Wed, 29 Jun 2022 10:42:34 +0200 Subject: [PATCH 2/5] Temporary rename --- neo4j/{_conf.py => __conf.py} | 0 neo4j/__init__.py | 10 +++++----- neo4j/_async/driver.py | 4 ++-- neo4j/_sync/driver.py | 4 ++-- neo4j/conf.py | 2 +- 5 files changed, 10 insertions(+), 10 deletions(-) rename neo4j/{_conf.py => __conf.py} (100%) diff --git a/neo4j/_conf.py b/neo4j/__conf.py similarity index 100% rename from neo4j/_conf.py rename to neo4j/__conf.py diff --git a/neo4j/__init__.py b/neo4j/__init__.py index 1251f0f03..9f4d2bbfb 100644 --- a/neo4j/__init__.py +++ b/neo4j/__init__.py @@ -71,6 +71,11 @@ from logging import getLogger +from .__conf import ( + TrustAll, + TrustCustomCAs, + TrustSystemCAs, +) from ._async.driver import ( AsyncBoltDriver, AsyncDriver, @@ -83,11 +88,6 @@ AsyncSession, AsyncTransaction, ) -from ._conf import ( - TrustAll, - TrustCustomCAs, - TrustSystemCAs, -) from ._data import Record from ._meta import ( ExperimentalWarning, diff --git a/neo4j/_async/driver.py b/neo4j/_async/driver.py index 72aee4045..a272b1089 100644 --- a/neo4j/_async/driver.py +++ b/neo4j/_async/driver.py @@ -16,11 +16,11 @@ # limitations under the License. -from .._async_compat.util import AsyncUtil -from .._conf import ( +from ..__conf import ( TrustAll, TrustStore, ) +from .._async_compat.util import AsyncUtil from .._meta import ( deprecation_warn, experimental, diff --git a/neo4j/_sync/driver.py b/neo4j/_sync/driver.py index aa03264f1..5c208f5ee 100644 --- a/neo4j/_sync/driver.py +++ b/neo4j/_sync/driver.py @@ -16,11 +16,11 @@ # limitations under the License. 
-from .._async_compat.util import Util -from .._conf import ( +from ..__conf import ( TrustAll, TrustStore, ) +from .._async_compat.util import Util from .._meta import ( deprecation_warn, experimental, diff --git a/neo4j/conf.py b/neo4j/conf.py index c578cba7c..b2e567831 100644 --- a/neo4j/conf.py +++ b/neo4j/conf.py @@ -19,7 +19,7 @@ from abc import ABCMeta from collections.abc import Mapping -from ._conf import ( +from .__conf import ( TrustAll, TrustCustomCAs, TrustSystemCAs, From 4075ed51531b27905ee4c4a85e33fbe0a231c01a Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Wed, 29 Jun 2022 10:47:13 +0200 Subject: [PATCH 3/5] Make `neo4j.conf` and `neo4j.routing` internal --- neo4j/__conf.py | 86 ---------------------- neo4j/__init__.py | 20 +++-- neo4j/_async/driver.py | 14 ++-- neo4j/_async/io/_bolt.py | 2 +- neo4j/_async/io/_pool.py | 10 +-- neo4j/_async/work/session.py | 2 +- neo4j/_async/work/workspace.py | 2 +- neo4j/{conf.py => _conf.py} | 75 +++++++++++++++++-- neo4j/_data.py | 2 +- neo4j/{routing.py => _routing.py} | 2 +- neo4j/_sync/driver.py | 14 ++-- neo4j/_sync/io/_bolt.py | 2 +- neo4j/_sync/io/_pool.py | 10 +-- neo4j/_sync/work/session.py | 2 +- neo4j/_sync/work/workspace.py | 2 +- tests/unit/async_/io/test_class_bolt3.py | 2 +- tests/unit/async_/io/test_class_bolt4x0.py | 2 +- tests/unit/async_/io/test_class_bolt4x1.py | 2 +- tests/unit/async_/io/test_class_bolt4x2.py | 2 +- tests/unit/async_/io/test_class_bolt4x3.py | 2 +- tests/unit/async_/io/test_class_bolt4x4.py | 2 +- tests/unit/async_/io/test_neo4j_pool.py | 6 +- tests/unit/common/io/test_routing.py | 4 +- tests/unit/common/test_conf.py | 12 +-- tests/unit/sync/io/test_class_bolt3.py | 2 +- tests/unit/sync/io/test_class_bolt4x0.py | 2 +- tests/unit/sync/io/test_class_bolt4x1.py | 2 +- tests/unit/sync/io/test_class_bolt4x2.py | 2 +- tests/unit/sync/io/test_class_bolt4x3.py | 2 +- tests/unit/sync/io/test_class_bolt4x4.py | 2 +- tests/unit/sync/io/test_neo4j_pool.py | 8 +- 31 files changed, 136 
insertions(+), 163 deletions(-) delete mode 100644 neo4j/__conf.py rename neo4j/{conf.py => _conf.py} (87%) rename neo4j/{routing.py => _routing.py} (99%) diff --git a/neo4j/__conf.py b/neo4j/__conf.py deleted file mode 100644 index 6237743e8..000000000 --- a/neo4j/__conf.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright (c) "Neo4j" -# Neo4j Sweden AB [https://neo4j.com] -# -# This file is part of Neo4j. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -class TrustStore: - # Base class for trust stores. For internal type-checking only. - pass - - -class TrustSystemCAs(TrustStore): - """Used to configure the driver to trust system CAs (default). - - Trust server certificates that can be verified against the system - certificate authority. This option is primarily intended for use with - full certificates. - - For example:: - - import neo4j - - driver = neo4j.GraphDatabase.driver( - url, auth=auth, trusted_certificates=neo4j.TrustSystemCAs() - ) - """ - pass - - -class TrustAll(TrustStore): - """Used to configure the driver to trust all certificates. - - Trust any server certificate. This ensures that communication - is encrypted but does not verify the server certificate against a - certificate authority. This option is primarily intended for use with - the default auto-generated server certificate. 
- - - For example:: - - import neo4j - - driver = neo4j.GraphDatabase.driver( - url, auth=auth, trusted_certificates=neo4j.TrustAll() - ) - """ - pass - - -class TrustCustomCAs(TrustStore): - """Used to configure the driver to trust custom CAs. - - Trust server certificates that can be verified against the certificate - authority at the specified paths. This option is primarily intended for - self-signed and custom certificates. - - :param certificates (str): paths to the certificates to trust. - Those are not the certificates you expect to see from the server but - the CA certificates you expect to be used to sign the server's - certificate. - - For example:: - - import neo4j - - driver = neo4j.GraphDatabase.driver( - url, auth=auth, - trusted_certificates=neo4j.TrustCustomCAs( - "/path/to/ca1.crt", "/path/to/ca2.crt", - ) - ) - """ - def __init__(self, *certificates): - self.certs = certificates diff --git a/neo4j/__init__.py b/neo4j/__init__.py index 9f4d2bbfb..76f842457 100644 --- a/neo4j/__init__.py +++ b/neo4j/__init__.py @@ -71,11 +71,6 @@ from logging import getLogger -from .__conf import ( - TrustAll, - TrustCustomCAs, - TrustSystemCAs, -) from ._async.driver import ( AsyncBoltDriver, AsyncDriver, @@ -88,6 +83,15 @@ AsyncSession, AsyncTransaction, ) +from ._conf import ( + Config, + PoolConfig, + SessionConfig, + TrustAll, + TrustCustomCAs, + TrustSystemCAs, + WorkspaceConfig, +) from ._data import Record from ._meta import ( ExperimentalWarning, @@ -131,12 +135,6 @@ Version, WRITE_ACCESS, ) -from .conf import ( - Config, - PoolConfig, - SessionConfig, - WorkspaceConfig, -) from .work import ( Query, ResultSummary, diff --git a/neo4j/_async/driver.py b/neo4j/_async/driver.py index a272b1089..e2857b83b 100644 --- a/neo4j/_async/driver.py +++ b/neo4j/_async/driver.py @@ -16,11 +16,15 @@ # limitations under the License. 
-from ..__conf import ( +from .._async_compat.util import AsyncUtil +from .._conf import ( + Config, + PoolConfig, + SessionConfig, TrustAll, TrustStore, + WorkspaceConfig, ) -from .._async_compat.util import AsyncUtil from .._meta import ( deprecation_warn, experimental, @@ -32,12 +36,6 @@ TRUST_ALL_CERTIFICATES, TRUST_SYSTEM_CA_SIGNED_CERTIFICATES, ) -from ..conf import ( - Config, - PoolConfig, - SessionConfig, - WorkspaceConfig, -) class AsyncGraphDatabase: diff --git a/neo4j/_async/io/_bolt.py b/neo4j/_async/io/_bolt.py index 66f14aa10..7f31f32dd 100644 --- a/neo4j/_async/io/_bolt.py +++ b/neo4j/_async/io/_bolt.py @@ -26,6 +26,7 @@ from ..._async_compat.util import AsyncUtil from ..._codec.hydration import v1 as hydration_v1 from ..._codec.packstream import v1 as packstream_v1 +from ..._conf import PoolConfig from ..._exceptions import ( BoltError, BoltHandshakeError, @@ -37,7 +38,6 @@ ServerInfo, Version, ) -from ...conf import PoolConfig from ...exceptions import ( AuthError, DriverError, diff --git a/neo4j/_async/io/_pool.py b/neo4j/_async/io/_pool.py index 40837d3f1..0a3bfac58 100644 --- a/neo4j/_async/io/_pool.py +++ b/neo4j/_async/io/_pool.py @@ -31,6 +31,10 @@ AsyncRLock, ) from ..._async_compat.network import AsyncNetworkUtil +from ..._conf import ( + PoolConfig, + WorkspaceConfig, +) from ..._deadline import ( connection_deadline, Deadline, @@ -38,14 +42,11 @@ merge_deadlines_and_timeouts, ) from ..._exceptions import BoltError +from ..._routing import RoutingTable from ...api import ( READ_ACCESS, WRITE_ACCESS, ) -from ...conf import ( - PoolConfig, - WorkspaceConfig, -) from ...exceptions import ( ClientError, ConfigurationError, @@ -56,7 +57,6 @@ SessionExpired, WriteServiceUnavailable, ) -from ...routing import RoutingTable from ._bolt import AsyncBolt diff --git a/neo4j/_async/work/session.py b/neo4j/_async/work/session.py index eed2d9d3a..c69d84e95 100644 --- a/neo4j/_async/work/session.py +++ b/neo4j/_async/work/session.py @@ -21,6 +21,7 @@ 
from time import perf_counter from ..._async_compat import async_sleep +from ..._conf import SessionConfig from ..._meta import ( deprecated, deprecation_warn, @@ -30,7 +31,6 @@ READ_ACCESS, WRITE_ACCESS, ) -from ...conf import SessionConfig from ...exceptions import ( ClientError, DriverError, diff --git a/neo4j/_async/work/workspace.py b/neo4j/_async/work/workspace.py index 3a374f7a1..9c589db57 100644 --- a/neo4j/_async/work/workspace.py +++ b/neo4j/_async/work/workspace.py @@ -18,12 +18,12 @@ import asyncio +from ..._conf import WorkspaceConfig from ..._deadline import Deadline from ..._meta import ( deprecation_warn, unclosed_resource_warn, ) -from ...conf import WorkspaceConfig from ...exceptions import ( ServiceUnavailable, SessionExpired, diff --git a/neo4j/conf.py b/neo4j/_conf.py similarity index 87% rename from neo4j/conf.py rename to neo4j/_conf.py index b2e567831..a3e290cc0 100644 --- a/neo4j/conf.py +++ b/neo4j/_conf.py @@ -19,11 +19,6 @@ from abc import ABCMeta from collections.abc import Mapping -from .__conf import ( - TrustAll, - TrustCustomCAs, - TrustSystemCAs, -) from ._meta import ( deprecation_warn, get_user_agent, @@ -52,6 +47,76 @@ def iter_items(iterable): yield key, value +class TrustStore: + # Base class for trust stores. For internal type-checking only. + pass + + +class TrustSystemCAs(TrustStore): + """Used to configure the driver to trust system CAs (default). + + Trust server certificates that can be verified against the system + certificate authority. This option is primarily intended for use with + full certificates. + + For example:: + + import neo4j + + driver = neo4j.GraphDatabase.driver( + url, auth=auth, trusted_certificates=neo4j.TrustSystemCAs() + ) + """ + pass + + +class TrustAll(TrustStore): + """Used to configure the driver to trust all certificates. + + Trust any server certificate. This ensures that communication + is encrypted but does not verify the server certificate against a + certificate authority. 
This option is primarily intended for use with + the default auto-generated server certificate. + + + For example:: + + import neo4j + + driver = neo4j.GraphDatabase.driver( + url, auth=auth, trusted_certificates=neo4j.TrustAll() + ) + """ + pass + + +class TrustCustomCAs(TrustStore): + """Used to configure the driver to trust custom CAs. + + Trust server certificates that can be verified against the certificate + authority at the specified paths. This option is primarily intended for + self-signed and custom certificates. + + :param certificates (str): paths to the certificates to trust. + Those are not the certificates you expect to see from the server but + the CA certificates you expect to be used to sign the server's + certificate. + + For example:: + + import neo4j + + driver = neo4j.GraphDatabase.driver( + url, auth=auth, + trusted_certificates=neo4j.TrustCustomCAs( + "/path/to/ca1.crt", "/path/to/ca2.crt", + ) + ) + """ + def __init__(self, *certificates): + self.certs = certificates + + class DeprecatedAlias: """Used when a config option has been renamed.""" diff --git a/neo4j/_data.py b/neo4j/_data.py index 2daf7be03..9207b50f5 100644 --- a/neo4j/_data.py +++ b/neo4j/_data.py @@ -28,7 +28,7 @@ from functools import reduce from operator import xor as xor_operator -from .conf import iter_items +from ._conf import iter_items from .graph import ( Node, Path, diff --git a/neo4j/routing.py b/neo4j/_routing.py similarity index 99% rename from neo4j/routing.py rename to neo4j/_routing.py index 99364fccc..a073dda3a 100644 --- a/neo4j/routing.py +++ b/neo4j/_routing.py @@ -146,7 +146,7 @@ def should_be_purged_from_memory(self): :return: Returns true if it is old and not used for a while. 
:rtype: bool """ - from neo4j.conf import RoutingConfig + from neo4j._conf import RoutingConfig perf_time = perf_counter() log.debug("[#0000] C: last_updated_time=%r perf_time=%r", self.last_updated_time, perf_time) return self.last_updated_time + self.ttl + RoutingConfig.routing_table_purge_delay <= perf_time diff --git a/neo4j/_sync/driver.py b/neo4j/_sync/driver.py index 5c208f5ee..3c4fbbaf9 100644 --- a/neo4j/_sync/driver.py +++ b/neo4j/_sync/driver.py @@ -16,11 +16,15 @@ # limitations under the License. -from ..__conf import ( +from .._async_compat.util import Util +from .._conf import ( + Config, + PoolConfig, + SessionConfig, TrustAll, TrustStore, + WorkspaceConfig, ) -from .._async_compat.util import Util from .._meta import ( deprecation_warn, experimental, @@ -32,12 +36,6 @@ TRUST_ALL_CERTIFICATES, TRUST_SYSTEM_CA_SIGNED_CERTIFICATES, ) -from ..conf import ( - Config, - PoolConfig, - SessionConfig, - WorkspaceConfig, -) class GraphDatabase: diff --git a/neo4j/_sync/io/_bolt.py b/neo4j/_sync/io/_bolt.py index 2b6fb8e65..b7f9ecd88 100644 --- a/neo4j/_sync/io/_bolt.py +++ b/neo4j/_sync/io/_bolt.py @@ -26,6 +26,7 @@ from ..._async_compat.util import Util from ..._codec.hydration import v1 as hydration_v1 from ..._codec.packstream import v1 as packstream_v1 +from ..._conf import PoolConfig from ..._exceptions import ( BoltError, BoltHandshakeError, @@ -37,7 +38,6 @@ ServerInfo, Version, ) -from ...conf import PoolConfig from ...exceptions import ( AuthError, DriverError, diff --git a/neo4j/_sync/io/_pool.py b/neo4j/_sync/io/_pool.py index 2de6d090d..3cd66a6d3 100644 --- a/neo4j/_sync/io/_pool.py +++ b/neo4j/_sync/io/_pool.py @@ -31,6 +31,10 @@ RLock, ) from ..._async_compat.network import NetworkUtil +from ..._conf import ( + PoolConfig, + WorkspaceConfig, +) from ..._deadline import ( connection_deadline, Deadline, @@ -38,14 +42,11 @@ merge_deadlines_and_timeouts, ) from ..._exceptions import BoltError +from ..._routing import RoutingTable from ...api import 
( READ_ACCESS, WRITE_ACCESS, ) -from ...conf import ( - PoolConfig, - WorkspaceConfig, -) from ...exceptions import ( ClientError, ConfigurationError, @@ -56,7 +57,6 @@ SessionExpired, WriteServiceUnavailable, ) -from ...routing import RoutingTable from ._bolt import Bolt diff --git a/neo4j/_sync/work/session.py b/neo4j/_sync/work/session.py index 861bfaa52..ec82c9a06 100644 --- a/neo4j/_sync/work/session.py +++ b/neo4j/_sync/work/session.py @@ -21,6 +21,7 @@ from time import perf_counter from ..._async_compat import sleep +from ..._conf import SessionConfig from ..._meta import ( deprecated, deprecation_warn, @@ -30,7 +31,6 @@ READ_ACCESS, WRITE_ACCESS, ) -from ...conf import SessionConfig from ...exceptions import ( ClientError, DriverError, diff --git a/neo4j/_sync/work/workspace.py b/neo4j/_sync/work/workspace.py index f8c305930..c10fc912e 100644 --- a/neo4j/_sync/work/workspace.py +++ b/neo4j/_sync/work/workspace.py @@ -18,12 +18,12 @@ import asyncio +from ..._conf import WorkspaceConfig from ..._deadline import Deadline from ..._meta import ( deprecation_warn, unclosed_resource_warn, ) -from ...conf import WorkspaceConfig from ...exceptions import ( ServiceUnavailable, SessionExpired, diff --git a/tests/unit/async_/io/test_class_bolt3.py b/tests/unit/async_/io/test_class_bolt3.py index 22524e02b..aa6aac101 100644 --- a/tests/unit/async_/io/test_class_bolt3.py +++ b/tests/unit/async_/io/test_class_bolt3.py @@ -19,7 +19,7 @@ import pytest from neo4j._async.io._bolt3 import AsyncBolt3 -from neo4j.conf import PoolConfig +from neo4j._conf import PoolConfig from neo4j.exceptions import ConfigurationError from ...._async_compat import mark_async_test diff --git a/tests/unit/async_/io/test_class_bolt4x0.py b/tests/unit/async_/io/test_class_bolt4x0.py index 64e1a8ddb..56b15f4d0 100644 --- a/tests/unit/async_/io/test_class_bolt4x0.py +++ b/tests/unit/async_/io/test_class_bolt4x0.py @@ -19,7 +19,7 @@ import pytest from neo4j._async.io._bolt4 import AsyncBolt4x0 -from 
neo4j.conf import PoolConfig +from neo4j._conf import PoolConfig from ...._async_compat import mark_async_test diff --git a/tests/unit/async_/io/test_class_bolt4x1.py b/tests/unit/async_/io/test_class_bolt4x1.py index a647fd5d7..4371f005e 100644 --- a/tests/unit/async_/io/test_class_bolt4x1.py +++ b/tests/unit/async_/io/test_class_bolt4x1.py @@ -19,7 +19,7 @@ import pytest from neo4j._async.io._bolt4 import AsyncBolt4x1 -from neo4j.conf import PoolConfig +from neo4j._conf import PoolConfig from ...._async_compat import mark_async_test diff --git a/tests/unit/async_/io/test_class_bolt4x2.py b/tests/unit/async_/io/test_class_bolt4x2.py index fb77b3733..804038cb1 100644 --- a/tests/unit/async_/io/test_class_bolt4x2.py +++ b/tests/unit/async_/io/test_class_bolt4x2.py @@ -19,7 +19,7 @@ import pytest from neo4j._async.io._bolt4 import AsyncBolt4x2 -from neo4j.conf import PoolConfig +from neo4j._conf import PoolConfig from ...._async_compat import mark_async_test diff --git a/tests/unit/async_/io/test_class_bolt4x3.py b/tests/unit/async_/io/test_class_bolt4x3.py index 36698de1c..7538e127b 100644 --- a/tests/unit/async_/io/test_class_bolt4x3.py +++ b/tests/unit/async_/io/test_class_bolt4x3.py @@ -21,7 +21,7 @@ import pytest from neo4j._async.io._bolt4 import AsyncBolt4x3 -from neo4j.conf import PoolConfig +from neo4j._conf import PoolConfig from ...._async_compat import mark_async_test diff --git a/tests/unit/async_/io/test_class_bolt4x4.py b/tests/unit/async_/io/test_class_bolt4x4.py index 62a0dcec3..285aa9744 100644 --- a/tests/unit/async_/io/test_class_bolt4x4.py +++ b/tests/unit/async_/io/test_class_bolt4x4.py @@ -21,7 +21,7 @@ import pytest from neo4j._async.io._bolt4 import AsyncBolt4x4 -from neo4j.conf import PoolConfig +from neo4j._conf import PoolConfig from ...._async_compat import mark_async_test diff --git a/tests/unit/async_/io/test_neo4j_pool.py b/tests/unit/async_/io/test_neo4j_pool.py index b98fcd73a..a2c7156cd 100644 --- 
a/tests/unit/async_/io/test_neo4j_pool.py +++ b/tests/unit/async_/io/test_neo4j_pool.py @@ -23,13 +23,13 @@ WRITE_ACCESS, ) from neo4j._async.io import AsyncNeo4jPool -from neo4j._deadline import Deadline -from neo4j.addressing import ResolvedAddress -from neo4j.conf import ( +from neo4j._conf import ( PoolConfig, RoutingConfig, WorkspaceConfig, ) +from neo4j._deadline import Deadline +from neo4j.addressing import ResolvedAddress from neo4j.exceptions import ( ServiceUnavailable, SessionExpired, diff --git a/tests/unit/common/io/test_routing.py b/tests/unit/common/io/test_routing.py index 030768c58..4dc8ecd82 100644 --- a/tests/unit/common/io/test_routing.py +++ b/tests/unit/common/io/test_routing.py @@ -18,11 +18,11 @@ import pytest -from neo4j.api import DEFAULT_DATABASE -from neo4j.routing import ( +from neo4j._routing import ( OrderedSet, RoutingTable, ) +from neo4j.api import DEFAULT_DATABASE VALID_ROUTING_RECORD = { diff --git a/tests/unit/common/test_conf.py b/tests/unit/common/test_conf.py index db3497263..390f2cab7 100644 --- a/tests/unit/common/test_conf.py +++ b/tests/unit/common/test_conf.py @@ -23,18 +23,18 @@ TrustCustomCAs, TrustSystemCAs, ) +from neo4j._conf import ( + Config, + PoolConfig, + SessionConfig, + WorkspaceConfig, +) from neo4j.api import ( READ_ACCESS, TRUST_ALL_CERTIFICATES, TRUST_SYSTEM_CA_SIGNED_CERTIFICATES, WRITE_ACCESS, ) -from neo4j.conf import ( - Config, - PoolConfig, - SessionConfig, - WorkspaceConfig, -) from neo4j.debug import watch from neo4j.exceptions import ConfigurationError diff --git a/tests/unit/sync/io/test_class_bolt3.py b/tests/unit/sync/io/test_class_bolt3.py index e4b878039..87f477d8d 100644 --- a/tests/unit/sync/io/test_class_bolt3.py +++ b/tests/unit/sync/io/test_class_bolt3.py @@ -18,8 +18,8 @@ import pytest +from neo4j._conf import PoolConfig from neo4j._sync.io._bolt3 import Bolt3 -from neo4j.conf import PoolConfig from neo4j.exceptions import ConfigurationError from ...._async_compat import mark_sync_test 
diff --git a/tests/unit/sync/io/test_class_bolt4x0.py b/tests/unit/sync/io/test_class_bolt4x0.py index c6228e564..88f549936 100644 --- a/tests/unit/sync/io/test_class_bolt4x0.py +++ b/tests/unit/sync/io/test_class_bolt4x0.py @@ -18,8 +18,8 @@ import pytest +from neo4j._conf import PoolConfig from neo4j._sync.io._bolt4 import Bolt4x0 -from neo4j.conf import PoolConfig from ...._async_compat import mark_sync_test diff --git a/tests/unit/sync/io/test_class_bolt4x1.py b/tests/unit/sync/io/test_class_bolt4x1.py index 283c0b475..e656cc349 100644 --- a/tests/unit/sync/io/test_class_bolt4x1.py +++ b/tests/unit/sync/io/test_class_bolt4x1.py @@ -18,8 +18,8 @@ import pytest +from neo4j._conf import PoolConfig from neo4j._sync.io._bolt4 import Bolt4x1 -from neo4j.conf import PoolConfig from ...._async_compat import mark_sync_test diff --git a/tests/unit/sync/io/test_class_bolt4x2.py b/tests/unit/sync/io/test_class_bolt4x2.py index 04885598f..d6bff9c23 100644 --- a/tests/unit/sync/io/test_class_bolt4x2.py +++ b/tests/unit/sync/io/test_class_bolt4x2.py @@ -18,8 +18,8 @@ import pytest +from neo4j._conf import PoolConfig from neo4j._sync.io._bolt4 import Bolt4x2 -from neo4j.conf import PoolConfig from ...._async_compat import mark_sync_test diff --git a/tests/unit/sync/io/test_class_bolt4x3.py b/tests/unit/sync/io/test_class_bolt4x3.py index 321494286..474b15857 100644 --- a/tests/unit/sync/io/test_class_bolt4x3.py +++ b/tests/unit/sync/io/test_class_bolt4x3.py @@ -20,8 +20,8 @@ import pytest +from neo4j._conf import PoolConfig from neo4j._sync.io._bolt4 import Bolt4x3 -from neo4j.conf import PoolConfig from ...._async_compat import mark_sync_test diff --git a/tests/unit/sync/io/test_class_bolt4x4.py b/tests/unit/sync/io/test_class_bolt4x4.py index 9f2515b4e..564660966 100644 --- a/tests/unit/sync/io/test_class_bolt4x4.py +++ b/tests/unit/sync/io/test_class_bolt4x4.py @@ -20,8 +20,8 @@ import pytest +from neo4j._conf import PoolConfig from neo4j._sync.io._bolt4 import Bolt4x4 
-from neo4j.conf import PoolConfig from ...._async_compat import mark_sync_test diff --git a/tests/unit/sync/io/test_neo4j_pool.py b/tests/unit/sync/io/test_neo4j_pool.py index 6c9b5db62..3b2cba764 100644 --- a/tests/unit/sync/io/test_neo4j_pool.py +++ b/tests/unit/sync/io/test_neo4j_pool.py @@ -22,14 +22,14 @@ READ_ACCESS, WRITE_ACCESS, ) -from neo4j._deadline import Deadline -from neo4j._sync.io import Neo4jPool -from neo4j.addressing import ResolvedAddress -from neo4j.conf import ( +from neo4j._conf import ( PoolConfig, RoutingConfig, WorkspaceConfig, ) +from neo4j._deadline import Deadline +from neo4j._sync.io import Neo4jPool +from neo4j.addressing import ResolvedAddress from neo4j.exceptions import ( ServiceUnavailable, SessionExpired, From c8c7c6172b150e14203c3aa36a5ef8ce7c00d853 Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Wed, 29 Jun 2022 10:20:15 +0200 Subject: [PATCH 4/5] Add proxies with deprecation warning for internal modules --- CHANGELOG.md | 7 ++ neo4j/__init__.py | 138 ++++++++++++++---------- neo4j/_async/driver.py | 3 +- neo4j/_codec/hydration/_common.py | 2 +- neo4j/_sync/driver.py | 3 +- neo4j/api.py | 5 + neo4j/conf.py | 53 +++++++++ neo4j/data.py | 45 ++++++++ neo4j/graph/__init__.py | 2 +- neo4j/meta.py | 48 +++++++++ neo4j/packstream.py | 58 ++++++++++ neo4j/routing.py | 37 +++++++ neo4j/time/_clock_implementations.py | 4 +- neo4j/time/arithmetic.py | 43 ++++++++ neo4j/time/clock_implementations.py | 39 +++++++ neo4j/time/hydration.py | 57 ++++++++++ neo4j/time/metaclasses.py | 37 +++++++ setup.cfg | 1 + tests/unit/async_/io/test_direct.py | 6 +- tests/unit/async_/io/test_neo4j_pool.py | 24 +++-- tests/unit/async_/test_driver.py | 43 ++++++-- tests/unit/async_/work/test_result.py | 16 +-- tests/unit/async_/work/test_session.py | 15 +-- tests/unit/common/spatial/test_point.py | 2 +- tests/unit/common/test_api.py | 28 ++++- tests/unit/common/test_import_neo4j.py | 10 +- tests/unit/common/test_record.py | 2 +- 
tests/unit/common/test_types.py | 4 +- tests/unit/common/time/test_clock.py | 4 +- tests/unit/sync/io/test_direct.py | 2 +- tests/unit/sync/io/test_neo4j_pool.py | 20 ++-- tests/unit/sync/test_driver.py | 43 ++++++-- tests/unit/sync/work/test_result.py | 16 +-- tests/unit/sync/work/test_session.py | 15 +-- 34 files changed, 690 insertions(+), 142 deletions(-) create mode 100644 neo4j/conf.py create mode 100644 neo4j/data.py create mode 100644 neo4j/meta.py create mode 100644 neo4j/packstream.py create mode 100644 neo4j/routing.py create mode 100644 neo4j/time/arithmetic.py create mode 100644 neo4j/time/clock_implementations.py create mode 100644 neo4j/time/hydration.py create mode 100644 neo4j/time/metaclasses.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 0835a6e9c..f9f7cf9df 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -104,6 +104,13 @@ deprecated without replacement. They are internal functions. - Importing `neo4j.packstream` has been deprecated. It's internal and should not be used by client code. +- Importing `neo4j.routing` has been deprecated. It's internal and should not + be used by client code. +- Importing `neo4j.conf` has been deprecated. It's internal and should not + be used by client code. +- `neo4j.Config`, `neo4j.PoolConfig`, `neo4j.SessionConfig`, and + `neo4j.WorkspaceConfig` have been deprecated without replacement. They are + internal classes. - Importing `neo4j.meta` has been deprecated. It's internal and should not be used by client code. `ExperimentalWarning` should be imported directly from `neo4j`. `neo4j.meta.version` is exposed through `neo4j.__version__` diff --git a/neo4j/__init__.py b/neo4j/__init__.py index 76f842457..89acc690d 100644 --- a/neo4j/__init__.py +++ b/neo4j/__init__.py @@ -16,60 +16,7 @@ # limitations under the License.
-__all__ = [ - "__version__", - "Address", - "AsyncBoltDriver", - "AsyncDriver", - "AsyncGraphDatabase", - "AsyncManagedTransaction", - "AsyncNeo4jDriver", - "AsyncResult", - "AsyncSession", - "AsyncTransaction", - "Auth", - "AuthToken", - "basic_auth", - "bearer_auth", - "BoltDriver", - "Bookmark", - "Bookmarks", - "Config", - "custom_auth", - "DEFAULT_DATABASE", - "Driver", - "ExperimentalWarning", - "get_user_agent", - "GraphDatabase", - "IPv4Address", - "IPv6Address", - "kerberos_auth", - "ManagedTransaction", - "Neo4jDriver", - "PoolConfig", - "Query", - "READ_ACCESS", - "Record", - "Result", - "ResultSummary", - "ServerInfo", - "Session", - "SessionConfig", - "SummaryCounters", - "Transaction", - "TRUST_ALL_CERTIFICATES", - "TRUST_SYSTEM_CA_SIGNED_CERTIFICATES", - "TrustAll", - "TrustCustomCAs", - "TrustSystemCAs", - "unit_of_work", - "Version", - "WorkspaceConfig", - "WRITE_ACCESS", -] - - -from logging import getLogger +from logging import getLogger as _getLogger from ._async.driver import ( AsyncBoltDriver, @@ -84,13 +31,13 @@ AsyncTransaction, ) from ._conf import ( - Config, - PoolConfig, - SessionConfig, + Config as _Config, + PoolConfig as _PoolConfig, + SessionConfig as _SessionConfig, TrustAll, TrustCustomCAs, TrustSystemCAs, - WorkspaceConfig, + WorkspaceConfig as _WorkspaceConfig, ) from ._data import Record from ._meta import ( @@ -143,4 +90,77 @@ ) -log = getLogger("neo4j") +__all__ = [ + "__version__", + "Address", + "AsyncBoltDriver", + "AsyncDriver", + "AsyncGraphDatabase", + "AsyncManagedTransaction", + "AsyncNeo4jDriver", + "AsyncResult", + "AsyncSession", + "AsyncTransaction", + "Auth", + "AuthToken", + "basic_auth", + "bearer_auth", + "BoltDriver", + "Bookmark", + "Bookmarks", + "Config", + "custom_auth", + "DEFAULT_DATABASE", + "Driver", + "ExperimentalWarning", + "get_user_agent", + "GraphDatabase", + "IPv4Address", + "IPv6Address", + "kerberos_auth", + "ManagedTransaction", + "Neo4jDriver", + "PoolConfig", + "Query", + "READ_ACCESS", + 
"Record", + "Result", + "ResultSummary", + "ServerInfo", + "Session", + "SessionConfig", + "SummaryCounters", + "Transaction", + "TRUST_ALL_CERTIFICATES", + "TRUST_SYSTEM_CA_SIGNED_CERTIFICATES", + "TrustAll", + "TrustCustomCAs", + "TrustSystemCAs", + "unit_of_work", + "Version", + "WorkspaceConfig", + "WRITE_ACCESS", +] + + +_log = _getLogger("neo4j") + + +def __getattr__(name): + # TODO 6.0 - remove this + if name in ( + "log", "Config", "PoolConfig", "SessionConfig", "WorkspaceConfig" + ): + from ._meta import deprecation_warn + deprecation_warn( + "Importing {} from neo4j is deprecated without replacement. It's " + "internal and will be removed in a future version." + .format(name), + stack_level=2 + ) + return globals()[f"_{name}"] + raise AttributeError(f"module {__name__} has no attribute {name}") + + +def __dir__(): + return __all__ diff --git a/neo4j/_async/driver.py b/neo4j/_async/driver.py index e2857b83b..1b3504e57 100644 --- a/neo4j/_async/driver.py +++ b/neo4j/_async/driver.py @@ -139,7 +139,8 @@ def driver(cls, uri, *, auth=None, **config): "Creating a direct driver (`bolt://` scheme) with routing " "context (URI parameters) is deprecated. They will be " "ignored. This will raise an error in a future release. 
" - 'Given URI "{}"'.format(uri) + 'Given URI "{}"'.format(uri), + stack_level=2 ) # TODO: 6.0 - raise instead of warning # raise ValueError( diff --git a/neo4j/_codec/hydration/_common.py b/neo4j/_codec/hydration/_common.py index 1fc634fa3..3a51b030d 100644 --- a/neo4j/_codec/hydration/_common.py +++ b/neo4j/_codec/hydration/_common.py @@ -17,7 +17,7 @@ from ...graph import Graph -from ...packstream import Structure +from ..packstream import Structure class GraphHydrator: diff --git a/neo4j/_sync/driver.py b/neo4j/_sync/driver.py index 3c4fbbaf9..0c8f5e097 100644 --- a/neo4j/_sync/driver.py +++ b/neo4j/_sync/driver.py @@ -139,7 +139,8 @@ def driver(cls, uri, *, auth=None, **config): "Creating a direct driver (`bolt://` scheme) with routing " "context (URI parameters) is deprecated. They will be " "ignored. This will raise an error in a future release. " - 'Given URI "{}"'.format(uri) + 'Given URI "{}"'.format(uri), + stack_level=2 ) # TODO: 6.0 - raise instead of warning # raise ValueError( diff --git a/neo4j/api.py b/neo4j/api.py index d35f0809d..7930d1d4e 100644 --- a/neo4j/api.py +++ b/neo4j/api.py @@ -165,6 +165,7 @@ def custom_auth(principal, credentials, realm, scheme, **parameters): return Auth(scheme, principal, credentials, realm, **parameters) +# TODO 6.0 - remove this class class Bookmark: """A Bookmark object contains an immutable list of bookmark string values. @@ -271,6 +272,10 @@ def from_raw_values(cls, values): if not isinstance(value, str): raise TypeError("Raw bookmark values must be str. 
" "Found {}".format(type(value))) + try: + value.encode("ascii") + except UnicodeEncodeError as e: + raise ValueError(f"The value {value} is not ASCII") from e bookmarks.append(value) obj._raw_values = frozenset(bookmarks) return obj diff --git a/neo4j/conf.py b/neo4j/conf.py new file mode 100644 index 000000000..150a58850 --- /dev/null +++ b/neo4j/conf.py @@ -0,0 +1,53 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# TODO: 6.0 - remove this file + + +from ._conf import ( + Config, + ConfigType, + DeprecatedAlias, + DeprecatedAlternative, + iter_items, + PoolConfig, + RoutingConfig, + SessionConfig, + TransactionConfig, + WorkspaceConfig, +) +from ._meta import deprecation_warn as _deprecation_warn + + +__all__ = [ + "Config", + "ConfigType", + "DeprecatedAlias", + "DeprecatedAlternative", + "iter_items", + "PoolConfig", + "RoutingConfig", + "SessionConfig", + "TransactionConfig", + "WorkspaceConfig", +] + +_deprecation_warn( + "The module 'neo4j.conf' was made internal and will " + "no longer be available for import in future versions.", + stack_level=2 +) diff --git a/neo4j/data.py b/neo4j/data.py new file mode 100644 index 000000000..0713ed460 --- /dev/null +++ b/neo4j/data.py @@ -0,0 +1,45 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# TODO: 6.0 - remove this file + + +from ._data import ( + DataTransformer, + Record, + RecordExporter, + RecordTableRowExporter, +) +from ._meta import deprecation_warn + + +map_type = type(map(str, range(0))) + +__all__ = [ + "map_type", + "Record", + "DataTransformer", + "RecordExporter", + "RecordTableRowExporter", +] + +deprecation_warn( + "The module 'neo4j.data' was made internal and will " + "no longer be available for import in future versions. " + "`neo4j.data.Record` should be imported directly from `neo4j`.", + stack_level=2 +) diff --git a/neo4j/graph/__init__.py b/neo4j/graph/__init__.py index 17e1e7b0b..ee7d1beee 100644 --- a/neo4j/graph/__init__.py +++ b/neo4j/graph/__init__.py @@ -31,7 +31,7 @@ from collections.abc import Mapping -from ..meta import ( +from .._meta import ( deprecated, deprecation_warn, ) diff --git a/neo4j/meta.py b/neo4j/meta.py new file mode 100644 index 000000000..05b1be0cc --- /dev/null +++ b/neo4j/meta.py @@ -0,0 +1,48 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# TODO: 6.0 - remove this file + + +from ._meta import ( + deprecated, + deprecation_warn, + experimental, + ExperimentalWarning, + get_user_agent, + package, + version, +) + + +__all__ = [ + "package", + "version", + "get_user_agent", + "deprecation_warn", + "deprecated", + "ExperimentalWarning", + "experimental", +] + +deprecation_warn( + "The module 'neo4j.meta' was made internal and will " + "no longer be available for import in future versions." + "`ExperimentalWarning` can be imported from `neo4j` directly and " + "`neo4j.meta.version` is exposed as `neo4j.__version__`.", + stack_level=2 +) diff --git a/neo4j/packstream.py b/neo4j/packstream.py new file mode 100644 index 000000000..041b644f7 --- /dev/null +++ b/neo4j/packstream.py @@ -0,0 +1,58 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +# TODO: 6.0 - remove this file + + +from ._codec.packstream.v1 import ( + INT64_MAX, + INT64_MIN, + PACKED_UINT_8, + PACKED_UINT_16, + Packer, + Structure, + UnpackableBuffer, + UNPACKED_MARKERS, + UNPACKED_UINT_8, + UNPACKED_UINT_16, + Unpacker, +) +from ._meta import deprecation_warn + + +__all__ = [ + "PACKED_UINT_8", + "PACKED_UINT_16", + "UNPACKED_UINT_8", + "UNPACKED_UINT_16", + "UNPACKED_MARKERS", + "UNPACKED_MARKERS", + "UNPACKED_MARKERS", + "INT64_MIN", + "INT64_MAX", + "Structure", + "Packer", + "Unpacker", + "UnpackableBuffer", +] + +deprecation_warn( + "The module 'neo4j.packstream' was made internal and will " + "no longer be available for import in future versions.", + stack_level=2 +) diff --git a/neo4j/routing.py b/neo4j/routing.py new file mode 100644 index 000000000..1036d92fc --- /dev/null +++ b/neo4j/routing.py @@ -0,0 +1,37 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# TODO: 6.0 - remove this file + + +from ._meta import deprecation_warn as _deprecation_warn +from ._routing import ( + OrderedSet, + RoutingTable, +) + + +__all__ = [ + "OrderedSet", + "RoutingTable", +] + +_deprecation_warn( + "The module 'neo4j.routing' was made internal and will " + "no longer be available for import in future versions.", + stack_level=2 +) diff --git a/neo4j/time/_clock_implementations.py b/neo4j/time/_clock_implementations.py index cbfaf0f27..60f82cc8b 100644 --- a/neo4j/time/_clock_implementations.py +++ b/neo4j/time/_clock_implementations.py @@ -25,11 +25,11 @@ ) from platform import uname -from neo4j.time import ( +from . import ( Clock, ClockTime, ) -from neo4j.time.arithmetic import nano_divmod +from ._arithmetic import nano_divmod __all__ = [ diff --git a/neo4j/time/arithmetic.py b/neo4j/time/arithmetic.py new file mode 100644 index 000000000..7f0961f61 --- /dev/null +++ b/neo4j/time/arithmetic.py @@ -0,0 +1,43 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# TODO: 6.0 - remove this file + + +from .._meta import deprecation_warn +from ._arithmetic import ( + nano_add, + nano_div, + nano_divmod, + round_half_to_even, + symmetric_divmod, +) + + +__all__ = [ + "nano_add", + "nano_div", + "nano_divmod", + "symmetric_divmod", + "round_half_to_even", +] + +deprecation_warn( + "The module 'neo4j.time.arithmetic' was made internal and will " + "no longer be available for import in future versions.", + stack_level=2 +) diff --git a/neo4j/time/clock_implementations.py b/neo4j/time/clock_implementations.py new file mode 100644 index 000000000..facfa5f61 --- /dev/null +++ b/neo4j/time/clock_implementations.py @@ -0,0 +1,39 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# TODO: 6.0 - remove this file + + +from .._meta import deprecation_warn +from ._clock_implementations import ( + LibCClock, + PEP564Clock, + SafeClock, +) + + +__all__ = [ + "SafeClock", + "PEP564Clock", + "LibCClock", +] + +deprecation_warn( + "The module 'neo4j.time.clock_implementations' was made internal and will " + "no longer be available for import in future versions.", + stack_level=2 +) diff --git a/neo4j/time/hydration.py b/neo4j/time/hydration.py new file mode 100644 index 000000000..4681c7c66 --- /dev/null +++ b/neo4j/time/hydration.py @@ -0,0 +1,57 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# TODO: 6.0 - remove this file + + +from .._codec.hydration.v1.temporal import ( + dehydrate_date, + dehydrate_datetime, + dehydrate_duration, + dehydrate_time, + dehydrate_timedelta, + get_date_unix_epoch, + get_date_unix_epoch_ordinal, + get_datetime_unix_epoch_utc, + hydrate_date, + hydrate_datetime, + hydrate_duration, + hydrate_time, +) +from .._meta import deprecation_warn + + +__all__ = [ + "get_date_unix_epoch", + "get_date_unix_epoch_ordinal", + "get_datetime_unix_epoch_utc", + "hydrate_date", + "dehydrate_date", + "hydrate_time", + "dehydrate_time", + "hydrate_datetime", + "dehydrate_datetime", + "hydrate_duration", + "dehydrate_duration", + "dehydrate_timedelta", +] + +deprecation_warn( + "The module 'neo4j.time.hydration' was made internal and will " + "no longer be available for import in future versions.", + stack_level=2 +) diff --git a/neo4j/time/metaclasses.py b/neo4j/time/metaclasses.py new file mode 100644 index 000000000..c23101f1a --- /dev/null +++ b/neo4j/time/metaclasses.py @@ -0,0 +1,37 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .._meta import deprecation_warn +from ._metaclasses import ( + DateTimeType, + DateType, + TimeType, +) + + +__all__ = [ + "DateType", + "TimeType", + "DateTimeType", +] + +deprecation_warn( + "The module 'neo4j.time.metaclasses' was made internal and will " + "no longer be available for import in future versions.", + stack_level=2 +) diff --git a/setup.cfg b/setup.cfg index 0f6574a96..cf2d22207 100644 --- a/setup.cfg +++ b/setup.cfg @@ -18,3 +18,4 @@ use_parentheses=true [tool:pytest] mock_use_standalone_module = true +asyncio_mode = auto diff --git a/tests/unit/async_/io/test_direct.py b/tests/unit/async_/io/test_direct.py index d68082f49..01d37e463 100644 --- a/tests/unit/async_/io/test_direct.py +++ b/tests/unit/async_/io/test_direct.py @@ -18,13 +18,13 @@ import pytest -from neo4j import ( +from neo4j._async.io import AsyncBolt +from neo4j._async.io._pool import AsyncIOPool +from neo4j._conf import ( Config, PoolConfig, WorkspaceConfig, ) -from neo4j._async.io import AsyncBolt -from neo4j._async.io._pool import AsyncIOPool from neo4j._deadline import Deadline from neo4j.exceptions import ( ClientError, diff --git a/tests/unit/async_/io/test_neo4j_pool.py b/tests/unit/async_/io/test_neo4j_pool.py index a2c7156cd..44b1931f4 100644 --- a/tests/unit/async_/io/test_neo4j_pool.py +++ b/tests/unit/async_/io/test_neo4j_pool.py @@ -16,6 +16,8 @@ # limitations under the License. 
+import inspect + import pytest from neo4j import ( @@ -36,7 +38,7 @@ ) from ...._async_compat import mark_async_test -from ..work import async_fake_connection_generator +from ..work import async_fake_connection_generator # needed as fixture ROUTER_ADDRESS = ResolvedAddress(("1.2.3.1", 9001), host_name="host") @@ -44,7 +46,7 @@ WRITER_ADDRESS = ResolvedAddress(("1.2.3.1", 9003), host_name="host") -@pytest.fixture() +@pytest.fixture def opener(async_fake_connection_generator, mocker): async def open_(addr, timeout): connection = async_fake_connection_generator() @@ -156,11 +158,13 @@ async def test_reuses_connection(opener): @pytest.mark.parametrize("break_on_close", (True, False)) @mark_async_test async def test_closes_stale_connections(opener, break_on_close): - def break_connection(): - pool.deactivate(cx1.addr) + async def break_connection(): + await pool.deactivate(cx1.addr) if cx_close_mock_side_effect: - cx_close_mock_side_effect() + res = cx_close_mock_side_effect() + if inspect.isawaitable(res): + return await res pool = AsyncNeo4jPool( opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS @@ -242,8 +246,8 @@ async def test_release_does_not_resets_closed_connections(opener): cx1.is_reset_mock.reset_mock() await pool.release(cx1) cx1.closed.assert_called_once() - cx1.is_reset_mock.asset_not_called() - cx1.reset.asset_not_called() + cx1.is_reset_mock.assert_not_called() + cx1.reset.assert_not_called() @mark_async_test @@ -257,8 +261,8 @@ async def test_release_does_not_resets_defunct_connections(opener): cx1.is_reset_mock.reset_mock() await pool.release(cx1) cx1.defunct.assert_called_once() - cx1.is_reset_mock.asset_not_called() - cx1.reset.asset_not_called() + cx1.is_reset_mock.assert_not_called() + cx1.reset.assert_not_called() @pytest.mark.parametrize("liveness_timeout", (0, 1, 2)) @@ -271,7 +275,7 @@ async def test_acquire_performs_no_liveness_check_on_fresh_connection( ) cx1 = await pool._acquire(READER_ADDRESS, Deadline(30), liveness_timeout) assert 
cx1.addr == READER_ADDRESS - cx1.reset.asset_not_called() + cx1.reset.assert_not_called() @pytest.mark.parametrize("liveness_timeout", (0, 1, 2)) diff --git a/tests/unit/async_/test_driver.py b/tests/unit/async_/test_driver.py index 13242006f..17fbeb1a0 100644 --- a/tests/unit/async_/test_driver.py +++ b/tests/unit/async_/test_driver.py @@ -17,6 +17,7 @@ import ssl +from functools import wraps import pytest @@ -24,12 +25,14 @@ AsyncBoltDriver, AsyncGraphDatabase, AsyncNeo4jDriver, + ExperimentalWarning, TRUST_ALL_CERTIFICATES, TRUST_SYSTEM_CA_SIGNED_CERTIFICATES, TrustAll, TrustCustomCAs, TrustSystemCAs, ) +from neo4j._async_compat.util import AsyncUtil from neo4j.api import ( READ_ACCESS, WRITE_ACCESS, @@ -39,6 +42,21 @@ from ..._async_compat import mark_async_test +@wraps(AsyncGraphDatabase.driver) +def create_driver(*args, **kwargs): + if AsyncUtil.is_async_code: + with pytest.warns(ExperimentalWarning, match="async") as warnings: + driver = AsyncGraphDatabase.driver(*args, **kwargs) + print(warnings) + return driver + else: + return AsyncGraphDatabase.driver(*args, **kwargs) + + +def driver(*args, **kwargs): + return AsyncNeo4jDriver(*args, **kwargs) + + @pytest.mark.parametrize("protocol", ("bolt://", "bolt+s://", "bolt+ssc://")) @pytest.mark.parametrize("host", ("localhost", "127.0.0.1", "[::1]", "[0:0:0:0:0:0:0:1]")) @@ -52,7 +70,7 @@ async def test_direct_driver_constructor(protocol, host, port, params, auth_toke with pytest.warns(DeprecationWarning, match="routing context"): driver = AsyncGraphDatabase.driver(uri, auth=auth_token) else: - driver = AsyncGraphDatabase.driver(uri, auth=auth_token) + driver = create_driver(uri, auth=auth_token) assert isinstance(driver, AsyncBoltDriver) await driver.close() @@ -67,7 +85,7 @@ async def test_direct_driver_constructor(protocol, host, port, params, auth_toke @mark_async_test async def test_routing_driver_constructor(protocol, host, port, params, auth_token): uri = protocol + host + port + params - driver = 
AsyncGraphDatabase.driver(uri, auth=auth_token) + driver = create_driver(uri, auth=auth_token) assert isinstance(driver, AsyncNeo4jDriver) await driver.close() @@ -127,13 +145,20 @@ async def test_routing_driver_constructor(protocol, host, port, params, auth_tok async def test_driver_config_error( test_uri, test_config, expected_failure, expected_failure_message ): + def driver_builder(): + if "trust" in test_config: + with pytest.warns(DeprecationWarning, match="trust"): + return AsyncGraphDatabase.driver(test_uri, **test_config) + else: + return create_driver(test_uri, **test_config) + if "+" in test_uri: # `+s` and `+ssc` are short hand syntax for not having to configure the # encryption behavior of the driver. Specifying both is invalid. with pytest.raises(expected_failure, match=expected_failure_message): - AsyncGraphDatabase.driver(test_uri, **test_config) + driver_builder() else: - driver = AsyncGraphDatabase.driver(test_uri, **test_config) + driver = driver_builder() await driver.close() @@ -144,7 +169,7 @@ async def test_driver_config_error( )) def test_invalid_protocol(test_uri): with pytest.raises(ConfigurationError, match="scheme"): - AsyncGraphDatabase.driver(test_uri) + create_driver(test_uri) @pytest.mark.parametrize( @@ -159,7 +184,7 @@ def test_driver_trust_config_error( test_config, expected_failure, expected_failure_message ): with pytest.raises(expected_failure, match=expected_failure_message): - AsyncGraphDatabase.driver("bolt://127.0.0.1:9001", **test_config) + create_driver("bolt://127.0.0.1:9001", **test_config) @pytest.mark.parametrize("uri", ( @@ -168,7 +193,7 @@ def test_driver_trust_config_error( )) @mark_async_test async def test_driver_opens_write_session_by_default(uri, mocker): - driver = AsyncGraphDatabase.driver(uri) + driver = create_driver(uri) from neo4j import AsyncTransaction # we set a specific db, because else the driver would try to fetch a RT @@ -207,7 +232,7 @@ async def test_driver_opens_write_session_by_default(uri, 
mocker): )) @mark_async_test async def test_verify_connectivity(uri, mocker): - driver = AsyncGraphDatabase.driver(uri) + driver = create_driver(uri) pool_mock = mocker.patch.object(driver, "_pool", autospec=True) try: @@ -233,7 +258,7 @@ async def test_verify_connectivity(uri, mocker): @mark_async_test async def test_verify_connectivity_parameters_are_deprecated(uri, kwargs, mocker): - driver = AsyncGraphDatabase.driver(uri) + driver = create_driver(uri) mocker.patch.object(driver, "_pool", autospec=True) try: diff --git a/tests/unit/async_/work/test_result.py b/tests/unit/async_/work/test_result.py index 50fb3c932..8981ef327 100644 --- a/tests/unit/async_/work/test_result.py +++ b/tests/unit/async_/work/test_result.py @@ -26,6 +26,7 @@ from neo4j import ( Address, AsyncResult, + ExperimentalWarning, Record, ResultSummary, ServerInfo, @@ -732,10 +733,11 @@ async def test_to_df(keys, values, types, instances, test_default_expand): connection = AsyncConnectionStub(records=Records(keys, values)) result = AsyncResult(connection, 1, noop, noop) await result._run("CYPHER", {}, None, None, "r", None) - if test_default_expand: - df = await result.to_df() - else: - df = await result.to_df(expand=False) + with pytest.warns(ExperimentalWarning, match="pandas"): + if test_default_expand: + df = await result.to_df() + else: + df = await result.to_df(expand=False) assert isinstance(df, pd.DataFrame) assert df.keys().to_list() == keys @@ -889,7 +891,8 @@ async def test_to_df_expand(keys, values, expected_columns, expected_rows, connection = AsyncConnectionStub(records=Records(keys, values)) result = AsyncResult(connection, 1, noop, noop) await result._run("CYPHER", {}, None, None, "r", None) - df = await result.to_df(expand=True) + with pytest.warns(ExperimentalWarning, match="pandas"): + df = await result.to_df(expand=True) assert isinstance(df, pd.DataFrame) assert len(set(expected_columns)) == len(expected_columns) @@ -1088,6 +1091,7 @@ async def test_to_df_parse_dates(keys, 
values, expected_df, expand): connection = AsyncConnectionStub(records=Records(keys, values)) result = AsyncResult(connection, 1, noop, noop) await result._run("CYPHER", {}, None, None, "r", None) - df = await result.to_df(expand=expand, parse_dates=True) + with pytest.warns(ExperimentalWarning, match="pandas"): + df = await result.to_df(expand=expand, parse_dates=True) pd.testing.assert_frame_equal(df, expected_df) diff --git a/tests/unit/async_/work/test_session.py b/tests/unit/async_/work/test_session.py index d44ceb083..117cb5cbf 100644 --- a/tests/unit/async_/work/test_session.py +++ b/tests/unit/async_/work/test_session.py @@ -25,13 +25,12 @@ AsyncSession, AsyncTransaction, Bookmarks, - SessionConfig, unit_of_work, ) from neo4j._async.io._pool import AsyncIOPool +from neo4j._conf import SessionConfig from ...._async_compat import mark_async_test -from ._fake_connection import async_fake_connection_generator @pytest.fixture() @@ -52,7 +51,8 @@ def acquire_side_effect(*_, **__): @mark_async_test async def test_session_context_calls_close(mocker): s = AsyncSession(None, SessionConfig()) - mock_close = mocker.patch.object(s, 'close', autospec=True) + mock_close = mocker.patch.object(s, 'close', autospec=True, + side_effect=s.close) async with s: pass mock_close.assert_called_once_with() @@ -203,9 +203,12 @@ async def test_session_returns_bookmarks_directly(pool, bookmark_values): ) @mark_async_test async def test_session_last_bookmark_is_deprecated(pool, bookmarks): - async with AsyncSession(pool, SessionConfig( - bookmarks=bookmarks - )) as session: + if bookmarks is not None: + with pytest.warns(DeprecationWarning): + session = AsyncSession(pool, SessionConfig(bookmarks=bookmarks)) + else: + session = AsyncSession(pool, SessionConfig(bookmarks=bookmarks)) + async with session: with pytest.warns(DeprecationWarning): if bookmarks: assert (await session.last_bookmark()) == bookmarks[-1] diff --git a/tests/unit/common/spatial/test_point.py 
b/tests/unit/common/spatial/test_point.py index 816ee01c5..3122e2de3 100644 --- a/tests/unit/common/spatial/test_point.py +++ b/tests/unit/common/spatial/test_point.py @@ -18,7 +18,7 @@ from unittest import TestCase -from neo4j.spatial import ( +from neo4j._spatial import ( Point, point_type, ) diff --git a/tests/unit/common/test_api.py b/tests/unit/common/test_api.py index 5f8576aab..a0f836796 100644 --- a/tests/unit/common/test_api.py +++ b/tests/unit/common/test_api.py @@ -17,6 +17,7 @@ import itertools +from contextlib import contextmanager import pytest @@ -101,9 +102,27 @@ def test_bookmark_initialization_with_valid_strings(test_input, expected_values, (("bookmark1", chr(129),), ValueError), ] ) -def test_bookmark_initialization_with_invalid_strings(test_input, expected): +@pytest.mark.parametrize(("method", "deprecated", "splat_args"), ( + (neo4j.Bookmark, True, True), + (neo4j.Bookmarks.from_raw_values, False, False), +)) +def test_bookmark_initialization_with_invalid_strings( + test_input, expected, method, deprecated, splat_args +): + @contextmanager + def deprecation_assertion(): + if deprecated: + with pytest.warns(DeprecationWarning): + yield + else: + yield + with pytest.raises(expected): - neo4j.Bookmark(*test_input) + with deprecation_assertion(): + if splat_args: + method(*test_input) + else: + method(test_input) @pytest.mark.parametrize("test_as_generator", [True, False]) @@ -116,7 +135,6 @@ def test_bookmark_initialization_with_invalid_strings(test_input, expected): ("bookmark1", ""), ("bookmark1",), (), - (not_ascii,), )) def test_bookmarks_raw_values(test_as_generator, values): expected = frozenset(values) @@ -140,6 +158,7 @@ def test_bookmarks_raw_values(test_as_generator, values): ((set(),), TypeError), ((frozenset(),), TypeError), ((["bookmark1", "bookmark2"],), TypeError), + ((not_ascii,), ValueError), )) def test_bookmarks_invalid_raw_values(values, exc_type): with pytest.raises(exc_type): @@ -255,7 +274,8 @@ def 
test_serverinfo_initialization(): assert server_info.address is address assert server_info.protocol_version is version assert server_info.agent is None - assert server_info.connection_id is None + with pytest.warns(DeprecationWarning): + assert server_info.connection_id is None @pytest.mark.parametrize( diff --git a/tests/unit/common/test_import_neo4j.py b/tests/unit/common/test_import_neo4j.py index a42e6207e..aa97bea28 100644 --- a/tests/unit/common/test_import_neo4j.py +++ b/tests/unit/common/test_import_neo4j.py @@ -14,6 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import pytest def test_import_dunder_version(): @@ -105,7 +106,8 @@ def test_import_async_session(): def test_import_sessionconfig(): - from neo4j import SessionConfig + with pytest.warns(DeprecationWarning): + from neo4j import SessionConfig def test_import_query(): @@ -129,11 +131,13 @@ def test_import_unit_of_work(): def test_import_config(): - from neo4j import Config + with pytest.warns(DeprecationWarning): + from neo4j import Config def test_import_poolconfig(): - from neo4j import PoolConfig + with pytest.warns(DeprecationWarning): + from neo4j import PoolConfig def test_import_graph(): diff --git a/tests/unit/common/test_record.py b/tests/unit/common/test_record.py index be62f8b6c..9aaf67c8a 100644 --- a/tests/unit/common/test_record.py +++ b/tests/unit/common/test_record.py @@ -288,7 +288,7 @@ def test_data_relationship(): gh = hydration_scope._graph_hydrator alice = gh.hydrate_node(1, {"Person"}, {"name": "Alice", "age": 33}) bob = gh.hydrate_node(2, {"Person"}, {"name": "Bob", "age": 44}) - alice_knows_bob = gh.hydrate_relationship(1, alice.id, bob.id, "KNOWS", + alice_knows_bob = gh.hydrate_relationship(1, alice._id, bob._id, "KNOWS", {"since": 1999}) record = Record(zip(["a", "b", "r"], [alice, bob, alice_knows_bob])) assert record.data() == { 
diff --git a/tests/unit/common/test_types.py b/tests/unit/common/test_types.py index e9dfec721..4564cef8b 100644 --- a/tests/unit/common/test_types.py +++ b/tests/unit/common/test_types.py @@ -478,12 +478,12 @@ def test_path_v2_repr(legacy_id): ) alice_knows_bob = gh.hydrate_relationship( - 1 if legacy_id else None, alice.id, bob.id, "KNOWS", {"since": 1999}, + 1 if legacy_id else None, alice._id, bob._id, "KNOWS", {"since": 1999}, "1" if legacy_id else "alice_knows_bob", alice.element_id, bob.element_id ) carol_dislikes_bob = gh.hydrate_relationship( - 2 if legacy_id else None, carol.id, bob.id, "DISLIKES", {}, + 2 if legacy_id else None, carol._id, bob._id, "DISLIKES", {}, "2" if legacy_id else "carol_dislikes_bob", carol.element_id, bob.element_id ) diff --git a/tests/unit/common/time/test_clock.py b/tests/unit/common/time/test_clock.py index 038d82737..f5fc41b8a 100644 --- a/tests/unit/common/time/test_clock.py +++ b/tests/unit/common/time/test_clock.py @@ -57,7 +57,7 @@ def test_local_offset(self): def test_local_time(self): _ = Clock() for impl in Clock._Clock__implementations: - self.assert_(issubclass(impl, Clock)) + self.assertTrue(issubclass(impl, Clock)) clock = object.__new__(impl) time = clock.local_time() self.assertIsInstance(time, ClockTime) @@ -65,7 +65,7 @@ def test_local_time(self): def test_utc_time(self): _ = Clock() for impl in Clock._Clock__implementations: - self.assert_(issubclass(impl, Clock)) + self.assertTrue(issubclass(impl, Clock)) clock = object.__new__(impl) time = clock.utc_time() self.assertIsInstance(time, ClockTime) diff --git a/tests/unit/sync/io/test_direct.py b/tests/unit/sync/io/test_direct.py index 98a1c5b0e..cddbecef8 100644 --- a/tests/unit/sync/io/test_direct.py +++ b/tests/unit/sync/io/test_direct.py @@ -18,7 +18,7 @@ import pytest -from neo4j import ( +from neo4j._conf import ( Config, PoolConfig, WorkspaceConfig, diff --git a/tests/unit/sync/io/test_neo4j_pool.py b/tests/unit/sync/io/test_neo4j_pool.py index 
3b2cba764..af10dc7da 100644 --- a/tests/unit/sync/io/test_neo4j_pool.py +++ b/tests/unit/sync/io/test_neo4j_pool.py @@ -16,6 +16,8 @@ # limitations under the License. +import inspect + import pytest from neo4j import ( @@ -36,7 +38,7 @@ ) from ...._async_compat import mark_sync_test -from ..work import fake_connection_generator +from ..work import fake_connection_generator # needed as fixture ROUTER_ADDRESS = ResolvedAddress(("1.2.3.1", 9001), host_name="host") @@ -44,7 +46,7 @@ WRITER_ADDRESS = ResolvedAddress(("1.2.3.1", 9003), host_name="host") -@pytest.fixture() +@pytest.fixture def opener(fake_connection_generator, mocker): def open_(addr, timeout): connection = fake_connection_generator() @@ -160,7 +162,9 @@ def break_connection(): pool.deactivate(cx1.addr) if cx_close_mock_side_effect: - cx_close_mock_side_effect() + res = cx_close_mock_side_effect() + if inspect.isawaitable(res): + return res pool = Neo4jPool( opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS @@ -242,8 +246,8 @@ def test_release_does_not_resets_closed_connections(opener): cx1.is_reset_mock.reset_mock() pool.release(cx1) cx1.closed.assert_called_once() - cx1.is_reset_mock.asset_not_called() - cx1.reset.asset_not_called() + cx1.is_reset_mock.assert_not_called() + cx1.reset.assert_not_called() @mark_sync_test @@ -257,8 +261,8 @@ def test_release_does_not_resets_defunct_connections(opener): cx1.is_reset_mock.reset_mock() pool.release(cx1) cx1.defunct.assert_called_once() - cx1.is_reset_mock.asset_not_called() - cx1.reset.asset_not_called() + cx1.is_reset_mock.assert_not_called() + cx1.reset.assert_not_called() @pytest.mark.parametrize("liveness_timeout", (0, 1, 2)) @@ -271,7 +275,7 @@ def test_acquire_performs_no_liveness_check_on_fresh_connection( ) cx1 = pool._acquire(READER_ADDRESS, Deadline(30), liveness_timeout) assert cx1.addr == READER_ADDRESS - cx1.reset.asset_not_called() + cx1.reset.assert_not_called() @pytest.mark.parametrize("liveness_timeout", (0, 1, 2)) diff --git 
a/tests/unit/sync/test_driver.py b/tests/unit/sync/test_driver.py index 0e4670adf..57c9bf9b5 100644 --- a/tests/unit/sync/test_driver.py +++ b/tests/unit/sync/test_driver.py @@ -17,11 +17,13 @@ import ssl +from functools import wraps import pytest from neo4j import ( BoltDriver, + ExperimentalWarning, GraphDatabase, Neo4jDriver, TRUST_ALL_CERTIFICATES, @@ -30,6 +32,7 @@ TrustCustomCAs, TrustSystemCAs, ) +from neo4j._async_compat.util import Util from neo4j.api import ( READ_ACCESS, WRITE_ACCESS, @@ -39,6 +42,21 @@ from ..._async_compat import mark_sync_test +@wraps(GraphDatabase.driver) +def create_driver(*args, **kwargs): + if Util.is_async_code: + with pytest.warns(ExperimentalWarning, match="async") as warnings: + driver = GraphDatabase.driver(*args, **kwargs) + print(warnings) + return driver + else: + return GraphDatabase.driver(*args, **kwargs) + + +def driver(*args, **kwargs): + return Neo4jDriver(*args, **kwargs) + + @pytest.mark.parametrize("protocol", ("bolt://", "bolt+s://", "bolt+ssc://")) @pytest.mark.parametrize("host", ("localhost", "127.0.0.1", "[::1]", "[0:0:0:0:0:0:0:1]")) @@ -52,7 +70,7 @@ def test_direct_driver_constructor(protocol, host, port, params, auth_token): with pytest.warns(DeprecationWarning, match="routing context"): driver = GraphDatabase.driver(uri, auth=auth_token) else: - driver = GraphDatabase.driver(uri, auth=auth_token) + driver = create_driver(uri, auth=auth_token) assert isinstance(driver, BoltDriver) driver.close() @@ -67,7 +85,7 @@ def test_direct_driver_constructor(protocol, host, port, params, auth_token): @mark_sync_test def test_routing_driver_constructor(protocol, host, port, params, auth_token): uri = protocol + host + port + params - driver = GraphDatabase.driver(uri, auth=auth_token) + driver = create_driver(uri, auth=auth_token) assert isinstance(driver, Neo4jDriver) driver.close() @@ -127,13 +145,20 @@ def test_routing_driver_constructor(protocol, host, port, params, auth_token): def test_driver_config_error( 
test_uri, test_config, expected_failure, expected_failure_message ): + def driver_builder(): + if "trust" in test_config: + with pytest.warns(DeprecationWarning, match="trust"): + return GraphDatabase.driver(test_uri, **test_config) + else: + return create_driver(test_uri, **test_config) + if "+" in test_uri: # `+s` and `+ssc` are short hand syntax for not having to configure the # encryption behavior of the driver. Specifying both is invalid. with pytest.raises(expected_failure, match=expected_failure_message): - GraphDatabase.driver(test_uri, **test_config) + driver_builder() else: - driver = GraphDatabase.driver(test_uri, **test_config) + driver = driver_builder() driver.close() @@ -144,7 +169,7 @@ def test_driver_config_error( )) def test_invalid_protocol(test_uri): with pytest.raises(ConfigurationError, match="scheme"): - GraphDatabase.driver(test_uri) + create_driver(test_uri) @pytest.mark.parametrize( @@ -159,7 +184,7 @@ def test_driver_trust_config_error( test_config, expected_failure, expected_failure_message ): with pytest.raises(expected_failure, match=expected_failure_message): - GraphDatabase.driver("bolt://127.0.0.1:9001", **test_config) + create_driver("bolt://127.0.0.1:9001", **test_config) @pytest.mark.parametrize("uri", ( @@ -168,7 +193,7 @@ def test_driver_trust_config_error( )) @mark_sync_test def test_driver_opens_write_session_by_default(uri, mocker): - driver = GraphDatabase.driver(uri) + driver = create_driver(uri) from neo4j import Transaction # we set a specific db, because else the driver would try to fetch a RT @@ -207,7 +232,7 @@ def test_driver_opens_write_session_by_default(uri, mocker): )) @mark_sync_test def test_verify_connectivity(uri, mocker): - driver = GraphDatabase.driver(uri) + driver = create_driver(uri) pool_mock = mocker.patch.object(driver, "_pool", autospec=True) try: @@ -233,7 +258,7 @@ def test_verify_connectivity(uri, mocker): @mark_sync_test def test_verify_connectivity_parameters_are_deprecated(uri, kwargs, mocker): 
- driver = GraphDatabase.driver(uri) + driver = create_driver(uri) mocker.patch.object(driver, "_pool", autospec=True) try: diff --git a/tests/unit/sync/work/test_result.py b/tests/unit/sync/work/test_result.py index e8d05b5aa..a801ca7e7 100644 --- a/tests/unit/sync/work/test_result.py +++ b/tests/unit/sync/work/test_result.py @@ -25,6 +25,7 @@ from neo4j import ( Address, + ExperimentalWarning, Record, Result, ResultSummary, @@ -732,10 +733,11 @@ def test_to_df(keys, values, types, instances, test_default_expand): connection = ConnectionStub(records=Records(keys, values)) result = Result(connection, 1, noop, noop) result._run("CYPHER", {}, None, None, "r", None) - if test_default_expand: - df = result.to_df() - else: - df = result.to_df(expand=False) + with pytest.warns(ExperimentalWarning, match="pandas"): + if test_default_expand: + df = result.to_df() + else: + df = result.to_df(expand=False) assert isinstance(df, pd.DataFrame) assert df.keys().to_list() == keys @@ -889,7 +891,8 @@ def test_to_df_expand(keys, values, expected_columns, expected_rows, connection = ConnectionStub(records=Records(keys, values)) result = Result(connection, 1, noop, noop) result._run("CYPHER", {}, None, None, "r", None) - df = result.to_df(expand=True) + with pytest.warns(ExperimentalWarning, match="pandas"): + df = result.to_df(expand=True) assert isinstance(df, pd.DataFrame) assert len(set(expected_columns)) == len(expected_columns) @@ -1088,6 +1091,7 @@ def test_to_df_parse_dates(keys, values, expected_df, expand): connection = ConnectionStub(records=Records(keys, values)) result = Result(connection, 1, noop, noop) result._run("CYPHER", {}, None, None, "r", None) - df = result.to_df(expand=expand, parse_dates=True) + with pytest.warns(ExperimentalWarning, match="pandas"): + df = result.to_df(expand=expand, parse_dates=True) pd.testing.assert_frame_equal(df, expected_df) diff --git a/tests/unit/sync/work/test_session.py b/tests/unit/sync/work/test_session.py index 
839d24a37..c93646306 100644 --- a/tests/unit/sync/work/test_session.py +++ b/tests/unit/sync/work/test_session.py @@ -24,14 +24,13 @@ Bookmarks, ManagedTransaction, Session, - SessionConfig, Transaction, unit_of_work, ) +from neo4j._conf import SessionConfig from neo4j._sync.io._pool import IOPool from ...._async_compat import mark_sync_test -from ._fake_connection import fake_connection_generator @pytest.fixture() @@ -52,7 +51,8 @@ def acquire_side_effect(*_, **__): @mark_sync_test def test_session_context_calls_close(mocker): s = Session(None, SessionConfig()) - mock_close = mocker.patch.object(s, 'close', autospec=True) + mock_close = mocker.patch.object(s, 'close', autospec=True, + side_effect=s.close) with s: pass mock_close.assert_called_once_with() @@ -203,9 +203,12 @@ def test_session_returns_bookmarks_directly(pool, bookmark_values): ) @mark_sync_test def test_session_last_bookmark_is_deprecated(pool, bookmarks): - with Session(pool, SessionConfig( - bookmarks=bookmarks - )) as session: + if bookmarks is not None: + with pytest.warns(DeprecationWarning): + session = Session(pool, SessionConfig(bookmarks=bookmarks)) + else: + session = Session(pool, SessionConfig(bookmarks=bookmarks)) + with session: with pytest.warns(DeprecationWarning): if bookmarks: assert (session.last_bookmark()) == bookmarks[-1] From 9c4a8b85a04fba4bf24e9431ef4d96637e861cac Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Mon, 4 Jul 2022 11:14:38 +0200 Subject: [PATCH 5/5] Adjust TestKit SubTest parameters --- testkitbackend/test_config.json | 2 +- testkitbackend/test_subtest_skips.py | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index 51dd43a96..6495025e1 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -11,7 +11,7 @@ "'stub.server_side_routing.test_server_side_routing.TestServerSideRouting.test_direct_connection_with_url_params'": "Driver emits 
deprecation warning. Behavior will be unified in 6.0.", "neo4j.datatypes.test_temporal_types.TestDataTypes.test_should_echo_all_timezone_ids": - "test_subtest_skips.tz_id", + "test_subtest_skips.dt_conversion", "neo4j.datatypes.test_temporal_types.TestDataTypes.test_date_time_cypher_created_tz_id": "test_subtest_skips.tz_id" }, diff --git a/testkitbackend/test_subtest_skips.py b/testkitbackend/test_subtest_skips.py index 68b02b471..6dfb6434e 100644 --- a/testkitbackend/test_subtest_skips.py +++ b/testkitbackend/test_subtest_skips.py @@ -25,6 +25,11 @@ """ +import pytz + +from . import fromtestkit + + def tz_id(**params): # We could do this automatically, but with an explicit black list we # make sure we know what we test and what we don't. @@ -51,3 +56,11 @@ def tz_id(**params): return ( "timezone id %s is not supported by the system" % params["tz_id"] ) + + +def dt_conversion(**params): + dt = params["dt"] + try: + fromtestkit.to_param(dt) + except (pytz.UnknownTimeZoneError, ValueError) as e: + return "cannot create desired dt %s: %r" % (dt, e)