diff --git a/nutkit/frontend/result.py b/nutkit/frontend/result.py index 3c26ac19c..293556ea6 100644 --- a/nutkit/frontend/result.py +++ b/nutkit/frontend/result.py @@ -19,6 +19,16 @@ def single(self): req = protocol.ResultSingle(self._result.id) return self._driver.send_and_receive(req, allow_resolution=True) + def single_optional(self): + """Try leniently to fetch exactly one record. + + Return None, if there is no record. + Return the record, if there is at least one. + Warn, if there are more than one. + """ + req = protocol.ResultSingleOptional(self._result.id) + return self._driver.send_and_receive(req, allow_resolution=True) + def peek(self): """Return the next Record or NullRecord without consuming it.""" req = protocol.ResultPeek(self._result.id) diff --git a/nutkit/protocol/cypher.py b/nutkit/protocol/cypher.py index 24fd2b25a..af76a610d 100644 --- a/nutkit/protocol/cypher.py +++ b/nutkit/protocol/cypher.py @@ -263,3 +263,198 @@ def __eq__(self, other): # More in line with other naming CypherPath = Path + + +class CypherPoint: + def __init__(self, system, x, y, z=None): + self.system = system + self.x = x + self.y = y + self.z = z + if system not in ("cartesian", "wgs84"): + raise ValueError("Invalid system: {}".format(system)) + + def __str__(self): + if self.z is None: + return "CypherPoint(system={}, x={}, y={})".format( + self.system, self.x, self.y + ) + return "CypherPoint(system={}, x={}, y={}, z={})".format( + self.system, self.x, self.y, self.z + ) + + def __repr__(self): + return "<{}(system={}, x={}, y={}, z={})>".format( + self.__class__.__name__, self.system, self.x, self.y, self.z + ) + + def __eq__(self, other): + if not isinstance(other, type(self)): + return False + + return all(getattr(self, attr) == getattr(other, attr) + for attr in ("system", "x", "y", "z")) + + +class CypherDate: + def __init__(self, year, month, day): + self.year = int(year) + self.month = int(month) + self.day = int(day) + for v in ("year", "month", "day"): + if getattr(self, v) != locals()[v]: + raise ValueError("{} must be integer".format(v)) + + def __str__(self): + return "CypherDate(year={}, month={}, day={})".format( + self.year, self.month, self.day + ) + + def __repr__(self): + return "<{}(year={}, month={}, day={})>".format( + self.__class__.__name__, self.year, self.month, self.day + ) + + def __eq__(self, other): + if not isinstance(other, type(self)): + return False + + return all(getattr(self, attr) == getattr(other, attr) + for attr in ("year", "month", "day")) + + +class CypherTime: + def __init__(self, hour, minute, second, nanosecond, utc_offset_s=None): + self.hour = int(hour) + self.minute = int(minute) + self.second = int(second) + self.nanosecond = int(nanosecond) + # seconds east of UTC or None for local time + self.utc_offset_s = utc_offset_s + if self.utc_offset_s is not None: + self.utc_offset_s = int(utc_offset_s) + for v in ("hour", "minute", "second", "nanosecond", "utc_offset_s"): + if getattr(self, v) != locals()[v]: + raise ValueError("{} must be integer".format(v)) + + def __str__(self): + return ( + "CypherTime(hour={}, minute={}, second={}, nanosecond={}, " + "utc_offset_s={})".format( + self.hour, self.minute, self.second, self.nanosecond, + self.utc_offset_s + ) + ) + + def __repr__(self): + return ( + "<{}(hour={}, minute={}, second={}, nanosecond={}, " + "utc_offset_s={})>".format( + self.__class__.__name__, self.hour, self.minute, self.second, + self.nanosecond, self.utc_offset_s + ) + ) + + def __eq__(self, other): + if not isinstance(other, type(self)): + return False + + return all(getattr(self, attr) == getattr(other, attr) + for attr in ("hour", "minute", "second", "nanosecond", + "utc_offset_s")) + + +class CypherDateTime: + def __init__(self, year, month, day, hour, minute, second, nanosecond, + utc_offset_s=None, timezone_id=None): + # The date time is always wall clock time (with or without timezone) + # If timezone_id is given (e.g., "Europe/Stockholm"), utc_offset_s + # must also be provided to avoid ambiguity. + self.year = int(year) + self.month = int(month) + self.day = int(day) + self.hour = int(hour) + self.minute = int(minute) + self.second = int(second) + self.nanosecond = int(nanosecond) + + self.utc_offset_s = utc_offset_s + if self.utc_offset_s is not None: + self.utc_offset_s = int(utc_offset_s) + self.timezone_id = timezone_id + if self.timezone_id is not None: + self.timezone_id = str(timezone_id) + + for v in ("year", "month", "day", "hour", "minute", "second", + "nanosecond", "utc_offset_s"): + if getattr(self, v) != locals()[v]: + raise ValueError("{} must be integer".format(v)) + if timezone_id is not None and utc_offset_s is None: + raise ValueError("utc_offset_s must be provided if timezone_id " + "is given") + + def __str__(self): + return ( + "CypherDateTime(year={}, month={}, day={}, hour={}, minute={}, " + "second={}, nanosecond={}, utc_offset_s={}, timezone_id={})" + .format( + self.year, self.month, self.day, self.hour, self.minute, + self.second, self.nanosecond, self.utc_offset_s, + self.timezone_id + ) + ) + + def __repr__(self): + return ( + "<{}(year={}, month={}, day={}, hour={}, minute={}, second={}, " + "nanosecond={}, utc_offset_s={}, timezone_id={})>" + .format( + self.__class__.__name__, self.year, self.month, self.day, + self.hour, self.minute, self.second, self.nanosecond, + self.utc_offset_s, self.timezone_id + ) + ) + + def __eq__(self, other): + if not isinstance(other, type(self)): + return False + + return all(getattr(self, attr) == getattr(other, attr) + for attr in ("year", "month", "day", "hour", "minute", + "second", "nanosecond", "utc_offset_s", + "timezone_id")) + + +class CypherDuration: + def __init__(self, months, days, seconds, nanoseconds): + self.months = int(months) + self.days = int(days) + seconds, nanoseconds = divmod( + seconds * 1000000000 + nanoseconds, 1000000000 + ) + self.seconds = int(seconds) + self.nanoseconds = int(nanoseconds) + + for v in ("months", "days", "seconds", "nanoseconds"): + if getattr(self, v) != locals()[v]: + raise ValueError("{} must be integer".format(v)) + + def __str__(self): + return ( + "CypherDuration(months={}, days={}, seconds={}, nanoseconds={})" + .format(self.months, self.days, self.seconds, self.nanoseconds) + ) + + def __repr__(self): + return ( + "<{}(months={}, days={}, seconds={}, nanoseconds={})>" + .format(self.__class__.__name__, self.months, self.days, + self.seconds, self.nanoseconds) + ) + + def __eq__(self, other): + if not isinstance(other, type(self)): + return False + + return all(getattr(self, attr) == getattr(other, attr) + for attr in ("months", "days", "seconds", "nanoseconds")) diff --git a/nutkit/protocol/feature.py b/nutkit/protocol/feature.py index b62c791e8..bb654415c 100644 --- a/nutkit/protocol/feature.py +++ b/nutkit/protocol/feature.py @@ -21,6 +21,8 @@ class Feature(Enum): # The driver offers a method for checking if a connection to the remote # server of cluster can be established. API_DRIVER_VERIFY_CONNECTIVITY = "Feature:API:Driver.VerifyConnectivity" + # The driver supports connection liveness check. + API_LIVENESS_CHECK = "Feature:API:Liveness.Check" # The driver offers a method for the result to return all records as a list # or array. This method should exhaust the result. API_RESULT_LIST = "Feature:API:Result.List" @@ -32,8 +34,12 @@ class Feature(Enum): # This method asserts that exactly one record in left in the result # stream, else it will raise an exception. API_RESULT_SINGLE = "Feature:API:Result.Single" - # The driver supports connection liveness check. - API_LIVENESS_CHECK = "Feature:API:Liveness.Check" + # The driver offers a method for the result to retrieve the next record in + # the result stream. If there are no more records left in the result, the + # driver will indicate so by returning None/null/nil/any other empty value. + # If there are more than records, the driver emits a warning. + # This method is supposed to always exhaust the result stream. + API_RESULT_SINGLE_OPTIONAL = "Feature:API:Result.SingleOptional" # The driver implements explicit configuration options for SSL. # - enable / disable SSL # - verify signature against system store / custom cert / not at all @@ -43,6 +49,10 @@ class Feature(Enum): # ...+s: enforce SSL + verify server's signature with system's trust store # ...+ssc: enforce SSL but do not verify the server's signature at all API_SSL_SCHEMES = "Feature:API:SSLSchemes" + # The driver supports sending and receiving geospatial data types. + API_TYPE_SPATIAL = "Feature:API:Type.Spatial" + # The driver supports sending and receiving temporal data types. + API_TYPE_TEMPORAL = "Feature:API:Type.Temporal" # The driver supports single-sign-on (SSO) by providing a bearer auth token # API. AUTH_BEARER = "Feature:Auth:Bearer" diff --git a/nutkit/protocol/requests.py b/nutkit/protocol/requests.py index 32aa30e8a..74fd42700 100644 --- a/nutkit/protocol/requests.py +++ b/nutkit/protocol/requests.py @@ -22,14 +22,34 @@ class StartTest: """ Request the backend to confirm to run a specific test. - The backend should respond with RunTest if the backend wants the test to be - skipped it must respond with SkipTest. + The backend should respond with RunTest unless it wants the test to be + skipped, in that case it must respond with SkipTest. + + The backend might also respond with RunSubTests. In this case, TestKit will + - run the test, if it does not have subtests + - ask for each subtest whether it should be run, if it has subtests + StartSubTest will be sent to the backend, for each subtest """ def __init__(self, test_name): self.testName = test_name +class StartSubTest: + """ + Request the backend to confirm to run a specific subtest. + + See StartTest for when TestKit might emmit this message. + + The backend should respond with RunTest unless it wants the subtest to be + skipped, in that case it must respond with SkipTest. + """ + + def __init__(self, test_name, subtest_arguments: dict): + self.testName = test_name + self.subtestArguments = subtest_arguments + + class GetFeatures: """ Request the backend to the list of features supported by the driver. @@ -394,6 +414,20 @@ def __init__(self, resultId): self.resultId = resultId +class ResultSingleOptional: + """ + Request to expect and return exactly one record in the result stream. + + Furthermore, the method is supposed to fully exhaust the result stream. + + The backend should respond with a RecordOptional or, if any error occurs + while retrieving the records, an Error response should be returned. + """ + + def __init__(self, resultId): + self.resultId = resultId + + class ResultPeek: """ Request to return the next result in the Stream without consuming it. diff --git a/nutkit/protocol/responses.py b/nutkit/protocol/responses.py index a6531d057..4d668b6f1 100644 --- a/nutkit/protocol/responses.py +++ b/nutkit/protocol/responses.py @@ -20,10 +20,6 @@ """ -class RunTest: - """Response to StartTest indicating that the test can be started.""" - - class FeatureList: """ Response to GetFeatures. @@ -37,6 +33,14 @@ def __init__(self, features=None): self.features = features +class RunTest: + """Response to StartTest indicating that the test can be started.""" + + +class RunSubTests: + """Response to StartTest requesting to decide for each subtest.""" + + class SkipTest: """Response to StartTest indicating that the test should be skipped.""" @@ -215,6 +219,24 @@ def __init__(self, records=None): self.records = record_list +class RecordOptional: + """ + Represents an optional record. + + Possible response to the ResultOptionalSingle request. + + Fields: + record: Record values or None (see Record response) + warnings: List of warnings (str) (potentially empty) + """ + + def __init__(self, record, warnings): + self.record = None + if record is not None: + self.record = Record(values=record["values"]) + self.warnings = warnings + + class Summary: """Represents summary returned from a ResultConsume request.""" diff --git a/requirements.txt b/requirements.txt index ed1be8d6b..1331d017e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,4 @@ ifaddr~=0.1.7 lark~=1.0.0 nose~=1.3.7 pre-commit~=2.15.0 +pytz diff --git a/tests/neo4j/datatypes/__init__.py b/tests/neo4j/datatypes/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/neo4j/datatypes/_base.py b/tests/neo4j/datatypes/_base.py new file mode 100644 index 000000000..871458e90 --- /dev/null +++ b/tests/neo4j/datatypes/_base.py @@ -0,0 +1,64 @@ +from nutkit.frontend import ApplicationCodeError +import nutkit.protocol as types +from tests.neo4j.shared import get_driver +from tests.shared import TestkitTestCase + +MIN_INT64 = -(2 ** 63) +MAX_INT64 = (2 ** 63) - 1 + + +class _TestTypesBase(TestkitTestCase): + def setUp(self): + super().setUp() + self._session = None + self._driver = None + + def tearDown(self): + if self._session: + self._session.close() + if self._driver: + self._driver.close() + super().tearDown() + + def _create_driver_and_session(self): + if self._session is not None: + self._session.close() + if self._driver is not None: + self._driver.close() + self._driver = get_driver(self._backend) + self._session = self._driver.session("w") + + def _verify_can_echo(self, val): + def work(tx): + result = tx.run("RETURN $x as y", params={"x": val}) + record_ = result.next() + assert isinstance(result.next(), types.NullRecord) + return record_ + + record = self._session.read_transaction(work) + self.assertEqual(record, types.Record(values=[val])) + + def _read_query_values(self, query, params=None): + def work(tx): + result = tx.run(query, params=params) + record_ = result.next() + assert isinstance(result.next(), types.NullRecord) + return record_.values + + return self._session.read_transaction(work) + + def _write_query_values(self, query, params=None): + values = [] + + def work(tx): + nonlocal values + result = tx.run(query, params=params) + record_ = result.next() + assert isinstance(result.next(), types.NullRecord) + values = record_.values + # rollback + raise ApplicationCodeError + + with self.assertRaises(types.FrontendError): + self._session.write_transaction(work) + return values diff --git a/tests/neo4j/datatypes/_util.py b/tests/neo4j/datatypes/_util.py new file mode 100644 index 000000000..f59f444eb --- /dev/null +++ b/tests/neo4j/datatypes/_util.py @@ -0,0 +1,607 @@ +TZ_IDS = ( + "Africa/Abidjan", + "Africa/Accra", + "Africa/Addis_Ababa", + "Africa/Algiers", + "Africa/Asmara", + "Africa/Asmera", + "Africa/Bamako", + "Africa/Bangui", + "Africa/Banjul", + "Africa/Bissau", + "Africa/Blantyre", + "Africa/Brazzaville", + "Africa/Bujumbura", + "Africa/Cairo", + "Africa/Casablanca", + "Africa/Ceuta", + "Africa/Conakry", + "Africa/Dakar", + "Africa/Dar_es_Salaam", + "Africa/Djibouti", + "Africa/Douala", + "Africa/El_Aaiun", + "Africa/Freetown", + "Africa/Gaborone", + "Africa/Harare", + "Africa/Johannesburg", + "Africa/Juba", + "Africa/Kampala", + "Africa/Khartoum", + "Africa/Kigali", + "Africa/Kinshasa", + "Africa/Lagos", + "Africa/Libreville", + "Africa/Lome", + "Africa/Luanda", + "Africa/Lubumbashi", + "Africa/Lusaka", + "Africa/Malabo", + "Africa/Maputo", + "Africa/Maseru", + "Africa/Mbabane", + "Africa/Mogadishu", + "Africa/Monrovia", + "Africa/Nairobi", + "Africa/Ndjamena", + "Africa/Niamey", + "Africa/Nouakchott", + "Africa/Ouagadougou", + "Africa/Porto-Novo", + "Africa/Sao_Tome", + "Africa/Timbuktu", + "Africa/Tripoli", + "Africa/Tunis", + "Africa/Windhoek", + "America/Adak", + "America/Anchorage", + "America/Anguilla", + "America/Antigua", + "America/Araguaina", + "America/Argentina/Buenos_Aires", + "America/Argentina/Catamarca", + "America/Argentina/ComodRivadavia", + "America/Argentina/Cordoba", + "America/Argentina/Jujuy", + "America/Argentina/La_Rioja", + "America/Argentina/Mendoza", + "America/Argentina/Rio_Gallegos", + "America/Argentina/Salta", + "America/Argentina/San_Juan", + "America/Argentina/San_Luis", + "America/Argentina/Tucuman", + "America/Argentina/Ushuaia", + "America/Aruba", + "America/Asuncion", + "America/Atikokan", + "America/Atka", + "America/Bahia", + "America/Bahia_Banderas", + "America/Barbados", + "America/Belem", + "America/Belize", + "America/Blanc-Sablon", + "America/Boa_Vista", + "America/Bogota", + "America/Boise", + "America/Buenos_Aires", + "America/Cambridge_Bay", + "America/Campo_Grande", + "America/Cancun", + "America/Caracas", + "America/Catamarca", + "America/Cayenne", + "America/Cayman", + "America/Chicago", + "America/Chihuahua", + "America/Coral_Harbour", + "America/Cordoba", + "America/Costa_Rica", + "America/Creston", + "America/Cuiaba", + "America/Curacao", + "America/Danmarkshavn", + "America/Dawson", + "America/Dawson_Creek", + "America/Denver", + "America/Detroit", + "America/Dominica", + "America/Edmonton", + "America/Eirunepe", + "America/El_Salvador", + "America/Ensenada", + "America/Fort_Nelson", + "America/Fort_Wayne", + "America/Fortaleza", + "America/Glace_Bay", + "America/Godthab", + "America/Goose_Bay", + "America/Grand_Turk", + "America/Grenada", + "America/Guadeloupe", + "America/Guatemala", + "America/Guayaquil", + "America/Guyana", + "America/Halifax", + "America/Havana", + "America/Hermosillo", + "America/Indiana/Indianapolis", + "America/Indiana/Knox", + "America/Indiana/Marengo", + "America/Indiana/Petersburg", + "America/Indiana/Tell_City", + "America/Indiana/Vevay", + "America/Indiana/Vincennes", + "America/Indiana/Winamac", + "America/Indianapolis", + "America/Inuvik", + "America/Iqaluit", + "America/Jamaica", + "America/Jujuy", + "America/Juneau", + "America/Kentucky/Louisville", + "America/Kentucky/Monticello", + "America/Knox_IN", + "America/Kralendijk", + "America/La_Paz", + "America/Lima", + "America/Los_Angeles", + "America/Louisville", + "America/Lower_Princes", + "America/Maceio", + "America/Managua", + "America/Manaus", + "America/Marigot", + "America/Martinique", + "America/Matamoros", + "America/Mazatlan", + "America/Mendoza", + "America/Menominee", + "America/Merida", + "America/Metlakatla", + "America/Mexico_City", + "America/Miquelon", + "America/Moncton", + "America/Monterrey", + "America/Montevideo", + "America/Montreal", + "America/Montserrat", + "America/Nassau", + "America/New_York", + "America/Nipigon", + "America/Nome", + "America/Noronha", + "America/North_Dakota/Beulah", + "America/North_Dakota/Center", + "America/North_Dakota/New_Salem", + "America/Nuuk", + "America/Ojinaga", + "America/Panama", + "America/Pangnirtung", + "America/Paramaribo", + "America/Phoenix", + "America/Port-au-Prince", + "America/Port_of_Spain", + "America/Porto_Acre", + "America/Porto_Velho", + "America/Puerto_Rico", + "America/Punta_Arenas", + "America/Rainy_River", + "America/Rankin_Inlet", + "America/Recife", + "America/Regina", + "America/Resolute", + "America/Rio_Branco", + "America/Rosario", + "America/Santa_Isabel", + "America/Santarem", + "America/Santiago", + "America/Santo_Domingo", + "America/Sao_Paulo", + "America/Scoresbysund", + "America/Shiprock", + "America/Sitka", + "America/St_Barthelemy", + "America/St_Johns", + "America/St_Kitts", + "America/St_Lucia", + "America/St_Thomas", + "America/St_Vincent", + "America/Swift_Current", + "America/Tegucigalpa", + "America/Thule", + "America/Thunder_Bay", + "America/Tijuana", + "America/Toronto", + "America/Tortola", + "America/Vancouver", + "America/Virgin", + "America/Whitehorse", + "America/Winnipeg", + "America/Yakutat", + "America/Yellowknife", + "Antarctica/Casey", + "Antarctica/Davis", + "Antarctica/DumontDUrville", + "Antarctica/Macquarie", + "Antarctica/Mawson", + "Antarctica/McMurdo", + "Antarctica/Palmer", + "Antarctica/Rothera", + "Antarctica/South_Pole", + "Antarctica/Syowa", + "Antarctica/Troll", + "Antarctica/Vostok", + "Arctic/Longyearbyen", + "Asia/Aden", + "Asia/Almaty", + "Asia/Amman", + "Asia/Anadyr", + "Asia/Aqtau", + "Asia/Aqtobe", + "Asia/Ashgabat", + "Asia/Ashkhabad", + "Asia/Atyrau", + "Asia/Baghdad", + "Asia/Bahrain", + "Asia/Baku", + "Asia/Bangkok", + "Asia/Barnaul", + "Asia/Beirut", + "Asia/Bishkek", + "Asia/Brunei", + "Asia/Calcutta", + "Asia/Chita", + "Asia/Choibalsan", + "Asia/Chongqing", + "Asia/Chungking", + "Asia/Colombo", + "Asia/Dacca", + "Asia/Damascus", + "Asia/Dhaka", + "Asia/Dili", + "Asia/Dubai", + "Asia/Dushanbe", + "Asia/Famagusta", + "Asia/Gaza", + "Asia/Harbin", + "Asia/Hebron", + "Asia/Ho_Chi_Minh", + "Asia/Hong_Kong", + "Asia/Hovd", + "Asia/Irkutsk", + "Asia/Istanbul", + "Asia/Jakarta", + "Asia/Jayapura", + "Asia/Jerusalem", + "Asia/Kabul", + "Asia/Kamchatka", + "Asia/Karachi", + "Asia/Kashgar", + "Asia/Kathmandu", + "Asia/Katmandu", + "Asia/Khandyga", + "Asia/Kolkata", + "Asia/Krasnoyarsk", + "Asia/Kuala_Lumpur", + "Asia/Kuching", + "Asia/Kuwait", + "Asia/Macao", + "Asia/Macau", + "Asia/Magadan", + "Asia/Makassar", + "Asia/Manila", + "Asia/Muscat", + "Asia/Nicosia", + "Asia/Novokuznetsk", + "Asia/Novosibirsk", + "Asia/Omsk", + "Asia/Oral", + "Asia/Phnom_Penh", + "Asia/Pontianak", + "Asia/Pyongyang", + "Asia/Qatar", + "Asia/Qostanay", + "Asia/Qyzylorda", + "Asia/Rangoon", + "Asia/Riyadh", + "Asia/Saigon", + "Asia/Sakhalin", + "Asia/Samarkand", + "Asia/Seoul", + "Asia/Shanghai", + "Asia/Singapore", + "Asia/Srednekolymsk", + "Asia/Taipei", + "Asia/Tashkent", + "Asia/Tbilisi", + "Asia/Tehran", + "Asia/Tel_Aviv", + "Asia/Thimbu", + "Asia/Thimphu", + "Asia/Tokyo", + "Asia/Tomsk", + "Asia/Ujung_Pandang", + "Asia/Ulaanbaatar", + "Asia/Ulan_Bator", + "Asia/Urumqi", + "Asia/Ust-Nera", + "Asia/Vientiane", + "Asia/Vladivostok", + "Asia/Yakutsk", + "Asia/Yangon", + "Asia/Yekaterinburg", + "Asia/Yerevan", + "Atlantic/Azores", + "Atlantic/Bermuda", + "Atlantic/Canary", + "Atlantic/Cape_Verde", + "Atlantic/Faeroe", + "Atlantic/Faroe", + "Atlantic/Jan_Mayen", + "Atlantic/Madeira", + "Atlantic/Reykjavik", + "Atlantic/South_Georgia", + "Atlantic/St_Helena", + "Atlantic/Stanley", + "Australia/ACT", + "Australia/Adelaide", + "Australia/Brisbane", + "Australia/Broken_Hill", + "Australia/Canberra", + "Australia/Currie", + "Australia/Darwin", + "Australia/Eucla", + "Australia/Hobart", + "Australia/LHI", + "Australia/Lindeman", + "Australia/Lord_Howe", + "Australia/Melbourne", + "Australia/NSW", + "Australia/North", + "Australia/Perth", + "Australia/Queensland", + "Australia/South", + "Australia/Sydney", + "Australia/Tasmania", + "Australia/Victoria", + "Australia/West", + "Australia/Yancowinna", + "Brazil/Acre", + "Brazil/DeNoronha", + "Brazil/East", + "Brazil/West", + "CET", + "CST6CDT", + "Canada/Atlantic", + "Canada/Central", + "Canada/Eastern", + "Canada/Mountain", + "Canada/Newfoundland", + "Canada/Pacific", + "Canada/Saskatchewan", + "Canada/Yukon", + "Chile/Continental", + "Chile/EasterIsland", + "Cuba", + "EET", + "EST", + "EST5EDT", + "Egypt", + "Eire", + "Etc/GMT", + "Etc/GMT+0", + "Etc/GMT+1", + "Etc/GMT+10", + "Etc/GMT+11", + "Etc/GMT+12", + "Etc/GMT+2", + "Etc/GMT+3", + "Etc/GMT+4", + "Etc/GMT+5", + "Etc/GMT+6", + "Etc/GMT+7", + "Etc/GMT+8", + "Etc/GMT+9", + "Etc/GMT-0", + "Etc/GMT-1", + "Etc/GMT-10", + "Etc/GMT-11", + "Etc/GMT-12", + "Etc/GMT-13", + "Etc/GMT-14", + "Etc/GMT-2", + "Etc/GMT-3", + "Etc/GMT-4", + "Etc/GMT-5", + "Etc/GMT-6", + "Etc/GMT-7", + "Etc/GMT-8", + "Etc/GMT-9", + "Etc/GMT0", + "Etc/Greenwich", + "Etc/UCT", + "Etc/UTC", + "Etc/Universal", + "Etc/Zulu", + "Europe/Amsterdam", + "Europe/Andorra", + "Europe/Astrakhan", + "Europe/Athens", + "Europe/Belfast", + "Europe/Belgrade", + "Europe/Berlin", + "Europe/Bratislava", + "Europe/Brussels", + "Europe/Bucharest", + "Europe/Budapest", + "Europe/Busingen", + "Europe/Chisinau", + "Europe/Copenhagen", + "Europe/Dublin", + "Europe/Gibraltar", + "Europe/Guernsey", + "Europe/Helsinki", + "Europe/Isle_of_Man", + "Europe/Istanbul", + "Europe/Jersey", + "Europe/Kaliningrad", + "Europe/Kiev", + "Europe/Kirov", + "Europe/Lisbon", + "Europe/Ljubljana", + "Europe/London", + "Europe/Luxembourg", + "Europe/Madrid", + "Europe/Malta", + "Europe/Mariehamn", + "Europe/Minsk", + "Europe/Monaco", + "Europe/Moscow", + "Europe/Nicosia", + "Europe/Oslo", + "Europe/Paris", + "Europe/Podgorica", + "Europe/Prague", + "Europe/Riga", + "Europe/Rome", + "Europe/Samara", + "Europe/San_Marino", + "Europe/Sarajevo", + "Europe/Saratov", + "Europe/Simferopol", + "Europe/Skopje", + "Europe/Sofia", + "Europe/Stockholm", + "Europe/Tallinn", + "Europe/Tirane", + "Europe/Tiraspol", + "Europe/Ulyanovsk", + "Europe/Uzhgorod", + "Europe/Vaduz", + "Europe/Vatican", + "Europe/Vienna", + "Europe/Vilnius", + "Europe/Volgograd", + "Europe/Warsaw", + "Europe/Zagreb", + "Europe/Zaporozhye", + "Europe/Zurich", + "GB", + "GB-Eire", + "GMT", + "GMT0", + "Greenwich", + "HST", + "Hongkong", + "Iceland", + "Indian/Antananarivo", + "Indian/Chagos", + "Indian/Christmas", + "Indian/Cocos", + "Indian/Comoro", + "Indian/Kerguelen", + "Indian/Mahe", + "Indian/Maldives", + "Indian/Mauritius", + "Indian/Mayotte", + "Indian/Reunion", + "Iran", + "Israel", + "Jamaica", + "Japan", + "Kwajalein", + "Libya", + "MET", + "MST", + "MST7MDT", + "Mexico/BajaNorte", + "Mexico/BajaSur", + "Mexico/General", + "NZ", + "NZ-CHAT", + "Navajo", + "PRC", + "PST8PDT", + "Pacific/Apia", + "Pacific/Auckland", + "Pacific/Bougainville", + "Pacific/Chatham", + "Pacific/Chuuk", + "Pacific/Easter", + "Pacific/Efate", + "Pacific/Enderbury", + "Pacific/Fakaofo", + "Pacific/Fiji", + "Pacific/Funafuti", + "Pacific/Galapagos", + "Pacific/Gambier", + "Pacific/Guadalcanal", + "Pacific/Guam", + "Pacific/Honolulu", + "Pacific/Johnston", + "Pacific/Kanton", + "Pacific/Kiritimati", + "Pacific/Kosrae", + "Pacific/Kwajalein", + "Pacific/Majuro", + "Pacific/Marquesas", + "Pacific/Midway", + "Pacific/Nauru", + "Pacific/Niue", + "Pacific/Norfolk", + "Pacific/Noumea", + "Pacific/Pago_Pago", + "Pacific/Palau", + "Pacific/Pitcairn", + "Pacific/Pohnpei", + "Pacific/Ponape", + "Pacific/Port_Moresby", + "Pacific/Rarotonga", + "Pacific/Saipan", + "Pacific/Samoa", + "Pacific/Tahiti", + "Pacific/Tarawa", + "Pacific/Tongatapu", + "Pacific/Truk", + "Pacific/Wake", + "Pacific/Wallis", + "Pacific/Yap", + "Poland", + "Portugal", + "ROC", + "ROK", + "Singapore", + "SystemV/AST4", + "SystemV/AST4ADT", + "SystemV/CST6", + "SystemV/CST6CDT", + "SystemV/EST5", + "SystemV/EST5EDT", + "SystemV/HST10", + "SystemV/MST7", + "SystemV/MST7MDT", + "SystemV/PST8", + "SystemV/PST8PDT", + "SystemV/YST9", + "SystemV/YST9YDT", + "Turkey", + "UCT", + "US/Alaska", + "US/Aleutian", + "US/Arizona", + "US/Central", + "US/East-Indiana", + "US/Eastern", + "US/Hawaii", + "US/Indiana-Starke", + "US/Michigan", + "US/Mountain", + "US/Pacific", + "US/Samoa", + "UTC", + "Universal", + "W-SU", + "WET", + "Zulu", +) diff --git a/tests/neo4j/test_datatypes.py b/tests/neo4j/datatypes/test_datatypes.py similarity index 84% rename from tests/neo4j/test_datatypes.py rename to tests/neo4j/datatypes/test_datatypes.py index a6b4cf6d5..b47e5f9ff 100644 --- a/tests/neo4j/test_datatypes.py +++ b/tests/neo4j/datatypes/test_datatypes.py @@ -1,38 +1,9 @@ import nutkit.protocol as types -from tests.neo4j.shared import get_driver -from tests.shared import ( - get_driver_name, - TestkitTestCase, -) - - -class TestDataTypes(TestkitTestCase): - def setUp(self): - super().setUp() - self._session = None - self._driver = None - - def tearDown(self): - if self._session: - self._session.close() - if self._driver: - self._driver.close() - super().tearDown() - - def create_driver_and_session(self): - self._driver = get_driver(self._backend) - self._session = self._driver.session("w") - - def verify_can_echo(self, val): - def work(tx): - result = tx.run("RETURN $x as y", params={"x": val}) - record_ = result.next() - assert isinstance(result.next(), types.NullRecord) - return record_ +from tests.neo4j.datatypes._base import _TestTypesBase +from tests.shared import get_driver_name - record = self._session.read_transaction(work) - self.assertEqual(record, types.Record(values=[val])) +class TestDataTypes(_TestTypesBase): def test_should_echo_back(self): vals = [ types.CypherBool(True), @@ -80,7 +51,7 @@ def test_should_echo_back(self): types.CypherBytes(bytearray([0x00, 0x33, 0x66, 0x99, 0xCC, 0xFF])), ] - self.create_driver_and_session() + self._create_driver_and_session() for val in vals: # TODO: remove this block once all languages work if get_driver_name() in ["javascript", "dotnet"]: @@ -100,7 +71,7 @@ def test_should_echo_back(self): if isinstance(val, types.CypherBytes): continue - self.verify_can_echo(val) + self._verify_can_echo(val) def test_should_echo_very_long_list(self): vals = [ @@ -111,18 +82,18 @@ def test_should_echo_very_long_list(self): types.CypherBool(True), ] - self.create_driver_and_session() + self._create_driver_and_session() for val in vals: long_list = [] for _ in range(1000): long_list.append(val) - self.verify_can_echo(types.CypherList(long_list)) + self._verify_can_echo(types.CypherList(long_list)) def test_should_echo_very_long_string(self): - self.create_driver_and_session() + self._create_driver_and_session() long_string = "*" * 10000 - self.verify_can_echo(types.CypherString(long_string)) + self._verify_can_echo(types.CypherString(long_string)) def test_should_echo_nested_lists(self): test_lists = [ @@ -161,8 +132,8 @@ def test_should_echo_nested_lists(self): ]) ] - self.create_driver_and_session() - self.verify_can_echo(types.CypherList(test_lists)) + self._create_driver_and_session() + self._verify_can_echo(types.CypherList(test_lists)) def test_should_echo_node(self): def work(tx): @@ -173,7 +144,7 @@ def work(tx): assert isinstance(result.next(), types.NullRecord) return record_ - self.create_driver_and_session() + self._create_driver_and_session() record = self._session.write_transaction(work) self.assertIsInstance(record, types.Record) @@ -200,7 +171,7 @@ def work(tx): assert isinstance(result.next(), types.NullRecord) return record_ - self.create_driver_and_session() + self._create_driver_and_session() record = self._session.write_transaction(work) self.assertIsInstance(record, types.Record) @@ -232,7 +203,7 @@ def work(tx): assert isinstance(result.next(), types.NullRecord) return record_ - self.create_driver_and_session() + self._create_driver_and_session() record = self._session.write_transaction(work) self.assertIsInstance(record, types.Record) @@ -273,14 +244,14 @@ def test_should_echo_very_long_map(self): types.CypherString("Hello World"), types.CypherBool(True)] - self.create_driver_and_session() + self._create_driver_and_session() long_map = {} for cypher_type in test_list: long_map.clear() for i in range(1000): long_map[str(i)] = cypher_type - self.verify_can_echo(types.CypherMap(long_map)) + self._verify_can_echo(types.CypherMap(long_map)) def test_should_echo_nested_map(self): test_maps = { @@ -306,8 +277,8 @@ def test_should_echo_nested_map(self): } - self.create_driver_and_session() - self.verify_can_echo(types.CypherMap(test_maps)) + self._create_driver_and_session() + self._verify_can_echo(types.CypherMap(test_maps)) def test_should_echo_list_of_maps(self): test_list = [ @@ -320,13 +291,13 @@ def test_should_echo_list_of_maps(self): "d": types.CypherInt(4) }) ] - self.create_driver_and_session() - self.verify_can_echo(types.CypherList(test_list)) + self._create_driver_and_session() + self._verify_can_echo(types.CypherList(test_list)) def test_should_echo_map_of_lists(self): test_map = { "a": types.CypherList([types.CypherInt(1)]), "b": types.CypherList([types.CypherInt(2)]) } - self.create_driver_and_session() - self.verify_can_echo(types.CypherMap(test_map)) + self._create_driver_and_session() + self._verify_can_echo(types.CypherMap(test_map)) diff --git a/tests/neo4j/datatypes/test_spatial_types.py b/tests/neo4j/datatypes/test_spatial_types.py new file mode 100644 index 000000000..6cf62e81f --- /dev/null +++ b/tests/neo4j/datatypes/test_spatial_types.py @@ -0,0 +1,88 @@ +import nutkit.protocol as types +from tests.neo4j.datatypes._base import _TestTypesBase + + +class TestDataTypes(_TestTypesBase): + + required_features = (types.Feature.API_TYPE_SPATIAL,) + + def test_should_echo_spatial_point(self): + coords = [ + (1.1, -2), + (0, 2, 3.5), + (0, 0, 0), + (1.1, -2), + (0, 2, 3.5), + (0, 0, 0), + ] + systems = ["cartesian", "wgs84"] + + self._create_driver_and_session() + for coord in coords: + for system in systems: + self._verify_can_echo(types.CypherPoint(system, *coord)) + + def test_point_components(self): + for system, coords, names in ( + ("cartesian", (1.1, -2), ("x", "y")), + ("cartesian", (1.1, -2.0, 123456.789), ("x", "y", "z")), + ("wgs84", (1.1, -2.0), ("x", "y")), + ("wgs84", (1.1, -2.0), ("longitude", "latitude")), + ("wgs84", (1.1, -2.0, 123456.789), ("x", "y", "z")), + ( + "wgs84", (1.1, -2.0, 123456.789), + ("longitude", "latitude", "height") + ), + ): + with self.subTest(system=system, coords=coords, names=names): + self._create_driver_and_session() + point = types.CypherPoint(system, *coords) + values = self._read_query_values( + "CYPHER runtime=interpreted " + "WITH $point AS point " + f"RETURN [{', '.join(f'point.{n}' for n in names)}]", + params={"point": point}, + ) + self.assertEqual( + values, + [types.CypherList(list(map(types.CypherFloat, coords)))], + ) + + def test_nested_point(self): + for points in ( + [ + types.CypherPoint("cartesian", 1.1, -2.456), + types.CypherPoint("cartesian", -123456789.1, 2.45655), + ], + [ + types.CypherPoint("cartesian", 1.1, -2.456, 123456789.999), + types.CypherPoint("cartesian", -1.1, 2.456, -123456789.999), + ], + [ + types.CypherPoint("wgs84", 1.1, -2.0), + types.CypherPoint("wgs84", 78.23456, -89.45), + ], + [ + types.CypherPoint("wgs84", 1.1, -2.0, 123456.789), + types.CypherPoint("wgs84", 78.23456, -89.45, -123456.789), + ], + ): + with self.subTest(points=points): + self._create_driver_and_session() + data = types.CypherList(points) + values = self._write_query_values( + "CREATE (a {x:$x}) RETURN a.x", params={"x": data}, + ) + self.assertEqual(values, [data]) + + def test_cypher_created_point(self): + for s, p in ( + ("point({x:3, y:4})", ("cartesian", 3, 4)), + ("point({x:3, y:4, z:5})", ("cartesian", 3, 4, 5)), + ("point({longitude:3, latitude:4})", ("wgs84", 3, 4)), + ("point({longitude:3, latitude:4, height:5})", ("wgs84", 3, 4, 5)), + ): + with self.subTest(s=s): + self._create_driver_and_session() + values = self._read_query_values(f"RETURN {s}") + self.assertEqual(values, [types.CypherPoint(*p)]) diff --git a/tests/neo4j/datatypes/test_temporal_types.py b/tests/neo4j/datatypes/test_temporal_types.py new file mode 100644 index 000000000..650618f95 --- /dev/null +++ b/tests/neo4j/datatypes/test_temporal_types.py @@ -0,0 +1,499 @@ +import datetime + +import pytz + +import nutkit.protocol as types +from tests.neo4j.datatypes._base import ( + _TestTypesBase, + MAX_INT64, + MIN_INT64, +) +from tests.neo4j.datatypes._util import TZ_IDS +from tests.shared import get_driver_name + + +class TestDataTypes(_TestTypesBase): + + required_features = (types.Feature.API_TYPE_TEMPORAL,) + + def test_should_echo_temporal_type(self): + vals = [ + types.CypherDate(1, 1, 1), + types.CypherDate(1970, 1, 1), + types.CypherDate(9999, 12, 31), + + types.CypherTime(0, 0, 0, 0), + types.CypherTime(23, 59, 59, 999999999), + types.CypherTime(23, 59, 59, 0, utc_offset_s=60 * 60 * 1), + types.CypherTime(0, 0, 0, 0, utc_offset_s=60 * 60 * -1), + # we are only testing utc offset down to 1 minute precision + # while bolt supports up to 1 second precision, we don't expect + # all driver languages to support that. + types.CypherTime(23, 59, 59, 1, utc_offset_s=60), + types.CypherTime(0, 0, 0, 0, utc_offset_s=-60), + # current real world boundaries + types.CypherTime(23, 59, 59, 0, utc_offset_s=60 * 60 * 14), + types.CypherTime(0, 0, 0, 0, utc_offset_s=60 * 60 * -12), + # server enforced boundaries + types.CypherTime(23, 59, 59, 0, utc_offset_s=60 * 60 * 18), + types.CypherTime(0, 0, 0, 0, utc_offset_s=60 * 60 * -18), + + types.CypherDateTime(1, 1, 1, 0, 0, 0, 0), + types.CypherDateTime(1970, 1, 1, 0, 0, 0, 0), + types.CypherDateTime(9999, 12, 31, 23, 59, 59, 999999999), + types.CypherDateTime(1970, 1, 1, 0, 0, 0, 0, + utc_offset_s=0), + types.CypherDateTime(1970, 1, 1, 0, 0, 0, 0, + utc_offset_s=60 * 60 * 18), + types.CypherDateTime(1970, 1, 1, 0, 0, 0, 0, + utc_offset_s=60 * 60 * -18), + + # # FIXME: this breaks because the bolt protocol needs fixing + # # in the 80's the Swedes thought it was a good idea to introduce + # # daylight saving time. LOL, fools! + # # pre-shift 2:30 am UCT+2 + # types.CypherDateTime(1980, 9, 28, 2, 30, 0, 0, + # utc_offset_s=60 * 60 * 2, + # timezone_id="Europe/Stockholm"), + # # one (non-political) hour later + # # post-shift 2:30 am again but UCT+1 + # types.CypherDateTime(1980, 9, 28, 2, 30, 0, 0, + # utc_offset_s=60 * 60 * 1, + # timezone_id="Europe/Stockholm"), + + types.CypherDuration(0, 0, 0, 0), + types.CypherDuration(0, 0, 0, -999999999), + types.CypherDuration(1, 2, 3, 4), + types.CypherDuration(-4, -3, -2, -1), + types.CypherDuration(0, 0, MAX_INT64, 999999999), + types.CypherDuration(0, MAX_INT64 // 86400, 0, 999999999), + types.CypherDuration(MAX_INT64 // 2629746, 0, 0, 999999999), + types.CypherDuration(0, 0, MIN_INT64, 0), + # Note: `int(a / b) != a // b` + # `int(a / b)` rounds towards zero + # `a // b` rounds down (towards negative infinity) + types.CypherDuration(0, int(MIN_INT64 / 86400), 0, 0), + types.CypherDuration(int(MIN_INT64 / 2629746), 0, 0, 0), + ] + + self._create_driver_and_session() + for val in vals: + with self.subTest(val=val): + if get_driver_name() in ["python"]: + if (isinstance(val, types.CypherDateTime) + and val.utc_offset_s == 0): + self.skipTest( + "timezone library cannot tell the difference " + "between named UTC and 0s offset timezone" + ) + self._verify_can_echo(val) + + def _timezone_server_support(self, tz_id): + def work(tx): + res = tx.run( + f"RETURN datetime('1970-01-01T00:00:00.000[{tz_id}]').timezone" + ) + rec = res.next() + assert isinstance(rec, types.Record) + assert isinstance(rec.values[0], types.CypherString) + return rec.values[0].value + + assert self._driver and self._session + try: + echoed_tz_id = self._session.read_transaction(work) + return echoed_tz_id == tz_id + except types.DriverError as e: + print("timezone %s not supported by server" % tz_id) + print(e) + return False + + def test_should_echo_all_timezone_ids(self): + times = ( + # 1970-01-01 06:00:00 + (1970, 1, 1, 0, 0, 0, 0), + # 2022-06-17 13:24:34.699546224 + (2022, 6, 17, 13, 24, 34, 699546224), + # 2022-01-17 13:24:34.699546224 + (2022, 1, 17, 13, 24, 34, 699546224), + # 0001-01-02 00:00:00 + (1, 1, 2, 0, 0, 0, 0), + ) + + self._create_driver_and_session() + for tz_id in TZ_IDS: + if not self._timezone_server_support(tz_id): + continue + for time in times: + with self.subTest(tz_id=tz_id, time=time): + # TODO: fire cypher query to check if timezone is supported + # TODO: add timezones to Java blacklist + # +33078908-07-16T03:42:09.322689764+13:00[Pacific/Kanton] + # +278812530-07-26T22:36:01.559972143-02:00[America/Nuuk] + if not self.check_timezone_supported(tz_id): + self.skipTest("timezone %s not supported by driver" + % tz_id) + # FIXME: while there is a bug in the bolt protocol that + # makes it incapable of representing datetimes with + # timezone ids when there is ambiguity, we will + # avoid those. + # --------------------------------------------------------- + try: + tz = pytz.timezone(tz_id) + except pytz.UnknownTimeZoneError: + # We will be able to remove this check and test those + # timezones, once we don't need the workaround for the + # ambiguity in the bolt protocol anymore. + self.skipTest("timezone %s not supported by TestKit" + % tz_id) + naive_dt = datetime.datetime(*time[:-1]) + dst_local_dt = tz.localize(naive_dt, is_dst=True) + no_dst_local_dt = tz.localize(naive_dt, is_dst=False) + while dst_local_dt != no_dst_local_dt: + naive_dt += datetime.timedelta(hours=1) + dst_local_dt = tz.localize(naive_dt, is_dst=True) + no_dst_local_dt = tz.localize(naive_dt, is_dst=False) + # --------------------------------------------------------- + + dt = types.CypherDateTime( + naive_dt.year, + naive_dt.month, + naive_dt.day, + naive_dt.hour, + naive_dt.minute, + naive_dt.second, + time[-1], + utc_offset_s=dst_local_dt.utcoffset().total_seconds(), + timezone_id=tz_id + ) + self._verify_can_echo(dt) + + def test_date_time_cypher_created_tz_id(self): + def work(tx): + res = tx.run( + f"WITH datetime('1970-01-01T10:08:09.000000001[{tz_id}]') " + f"AS dt " + f"RETURN dt, dt.year, dt.month, dt.day, dt.hour, dt.minute, " + f"dt.second, dt.nanosecond, dt.offsetSeconds, dt.timezone" + ) + rec = res.next() + assert isinstance(rec, types.Record) + dt_, y_, mo_, d_, h_, m_, s_, ns_, offset_, tz_ = rec.values + assert isinstance(dt_, types.CypherDateTime) + assert isinstance(y_, types.CypherInt) + assert isinstance(mo_, types.CypherInt) + assert isinstance(d_, types.CypherInt) + assert isinstance(h_, types.CypherInt) + assert isinstance(m_, types.CypherInt) + assert isinstance(s_, types.CypherInt) + assert isinstance(ns_, types.CypherInt) + assert isinstance(offset_, types.CypherInt) + assert isinstance(tz_, types.CypherString) + + return map(lambda x: getattr(x, "value", x), rec.values) + + self._create_driver_and_session() + for tz_id in TZ_IDS: + if not self._timezone_server_support(tz_id): + continue + with self.subTest(tz_id=tz_id): + dt, y, mo, d, h, m, s, ns, offset, tz = \ + self._session.read_transaction(work) + self.assertEqual(dt.year, y) + self.assertEqual(dt.month, mo) + self.assertEqual(dt.day, d) + self.assertEqual(dt.hour, h) + self.assertEqual(dt.minute, m) + self.assertEqual(dt.second, s) + self.assertEqual(dt.nanosecond, ns) + self.assertEqual(dt.timezone_id, tz) + + def test_date_components(self): + self._create_driver_and_session() + values = self._read_query_values( + "CYPHER runtime=interpreted WITH $x AS x " + "RETURN [x.year, x.month, x.day]", + params={"x": types.CypherDate(2022, 3, 30)} + ) + self.assertEqual( + values, + [types.CypherList(list(map(types.CypherInt, [2022, 3, 30])))] + ) + + def test_nested_date(self): + data = types.CypherList( + [types.CypherDate(2022, 3, 30), types.CypherDate(1976, 6, 13)] + ) + self._create_driver_and_session() + values = self._write_query_values( + "CREATE (a {x:$x}) RETURN a.x", + params={"x": data} + ) + self.assertEqual(values, [data]) + + def test_cypher_created_date(self): + self._create_driver_and_session() + values = self._read_query_values("RETURN date('1976-06-13')") + self.assertEqual( + values, + [types.CypherDate(1976, 6, 13)] + ) + + def test_time_components(self): + self._create_driver_and_session() + values = self._read_query_values( + "CYPHER runtime=interpreted WITH $x AS x " + "RETURN [x.hour, x.minute, x.second, x.nanosecond]", + params={"x": types.CypherTime(13, 24, 34, 699546224)} + ) + self.assertEqual( + values, + [types.CypherList(list(map(types.CypherInt, + [13, 24, 34, 699546224])))] + ) + + def test_time_with_offset_components(self): + self._create_driver_and_session() + values = self._read_query_values( + "CYPHER runtime=interpreted WITH $x AS x " + "RETURN [x.hour, x.minute, x.second, x.nanosecond, x.offset]", + params={"x": types.CypherTime(13, 24, 34, 699546224, + utc_offset_s=-5520)} + ) + self.assertEqual( + values, + [ + types.CypherList([ + types.CypherInt(13), + types.CypherInt(24), + types.CypherInt(34), + types.CypherInt(699546224), + types.CypherString("-01:32") + ]) + ] + ) + + def test_nested_time(self): + data = types.CypherList([ + types.CypherTime(13, 24, 34, 699546224), + types.CypherTime(23, 25, 34, 699547625) + ]) + t = types.CypherTime(23, 25, 34, 699547625, utc_offset_s=-5520) + self._create_driver_and_session() + values = self._write_query_values( + "CREATE (a {x:$x, y:$y}) RETURN a.x, a.y", + params={"x": data, "y": t} + ) + self.assertEqual(values, [data, t]) + + def test_cypher_created_time(self): + for (s, t) in ( + ( + "time('13:24:34')", + types.CypherTime(13, 24, 34, 0, utc_offset_s=0) + ), + ( + "time('13:24:34.699546224')", + types.CypherTime(13, 24, 34, 699546224, utc_offset_s=0) + ), + ( + "time('12:34:56.789012345+0130')", + types.CypherTime(12, 34, 56, 789012345, utc_offset_s=5400) + ), + ( + "time('12:34:56.789012345-01:30')", + types.CypherTime(12, 34, 56, 789012345, utc_offset_s=-5400) + ), + ( + "time('12:34:56.789012345Z')", + types.CypherTime(12, 34, 56, 789012345, utc_offset_s=0) + ), + ( + "localtime('12:34:56.789012345')", + types.CypherTime(12, 34, 56, 789012345) + ), + ): + with self.subTest(s=s, t=t): + self._create_driver_and_session() + values = self._read_query_values(f"RETURN {s}") + self.assertEqual(values, [t]) + + def test_datetime_components(self): + self._create_driver_and_session() + values = self._read_query_values( + "CYPHER runtime=interpreted WITH $x AS x " + "RETURN [x.year, x.month, x.day, x.hour, x.minute, x.second, " + "x.nanosecond]", + params={"x": types.CypherDateTime(2022, 3, 30, 13, 24, 34, + 699546224)} + ) + self.assertEqual( + values, + [ + types.CypherList(list(map( + types.CypherInt, + [2022, 3, 30, 13, 24, 34, 699546224] + ))) + ] + ) + + def test_datetime_with_offset_components(self): + self._create_driver_and_session() + values = self._read_query_values( + "CYPHER runtime=interpreted WITH $x AS x " + "RETURN [x.year, x.month, x.day, x.hour, x.minute, x.second, " + "x.nanosecond, x.offset]", + params={"x": types.CypherDateTime(2022, 3, 30, 13, 24, 34, + 699546224, utc_offset_s=-5520)} + ) + self.assertEqual( + values, + [ + types.CypherList([ + types.CypherInt(2022), + types.CypherInt(3), + types.CypherInt(30), + types.CypherInt(13), + types.CypherInt(24), + types.CypherInt(34), + types.CypherInt(699546224), + types.CypherString("-01:32") + ]) + ] + ) + + def test_datetime_with_timezone_components(self): + self._create_driver_and_session() + values = self._read_query_values( + "CYPHER runtime=interpreted WITH $x AS x " + "RETURN [x.year, x.month, x.day, x.hour, x.minute, x.second, " + "x.nanosecond, x.offset, x.timezone]", + params={"x": types.CypherDateTime( + 2022, 3, 30, 13, 24, 34, 699546224, + utc_offset_s=-14400, timezone_id="America/Toronto" + )} + ) + self.assertEqual( + values, + [ + types.CypherList([ + types.CypherInt(2022), + types.CypherInt(3), + types.CypherInt(30), + types.CypherInt(13), + types.CypherInt(24), + types.CypherInt(34), + types.CypherInt(699546224), + types.CypherString("-04:00"), + types.CypherString("America/Toronto") + ]) + ] + ) + + def test_nested_datetime(self): + data = types.CypherList([ + types.CypherDateTime(2018, 4, 6, 13, 4, 42, 516120123), + types.CypherDateTime(2022, 3, 30, 0, 0, 0, 0) + ]) + dt1 = types.CypherDateTime(2022, 3, 30, 0, 0, 0, 0, utc_offset_s=-5520) + dt2 = types.CypherDateTime( + 2022, 3, 30, 13, 24, 34, 699546224, + utc_offset_s=-14400, timezone_id="America/Toronto" + ) + self._create_driver_and_session() + values = self._write_query_values( + "CREATE (a {x:$x, y:$y, z:$z}) RETURN a.x, a.y, a.z", + params={"x": data, "y": dt1, "z": dt2} + ) + self.assertEqual(values, [data, dt1, dt2]) + + def test_cypher_created_datetime(self): + for (s, dt) in ( + ( + "datetime('1976-06-13T12:34:56')", + types.CypherDateTime(1976, 6, 13, 12, 34, 56, 0, + utc_offset_s=0, timezone_id="UTC") + ), + ( + "datetime('1976-06-13T12:34:56.999888777')", + types.CypherDateTime(1976, 6, 13, 12, 34, 56, 999888777, + utc_offset_s=0, timezone_id="UTC") + ), + ( + "datetime('1976-06-13T12:34:56.999888777-05:00')", + types.CypherDateTime(1976, 6, 13, 12, 34, 56, 999888777, + utc_offset_s=-18000) + ), + ( + "datetime('1976-06-13T12:34:56.789012345[Europe/London]')", + types.CypherDateTime( + 1976, 6, 13, 12, 34, 56, 789012345, + utc_offset_s=3600, timezone_id="Europe/London" + ) + ), + ( + "localdatetime('1976-06-13T12:34:56')", + types.CypherDateTime(1976, 6, 13, 12, 34, 56, 0) + ), + ( + "localdatetime('1976-06-13T12:34:56.123')", + types.CypherDateTime(1976, 6, 13, 12, 34, 56, 123000000) + ), + ): + with self.subTest(s=s, dt=dt): + self._create_driver_and_session() + values = self._read_query_values(f"RETURN {s}") + self.assertEqual(values, [dt]) + + def test_duration_components(self): + for (mo, d, s, ns_os, ns) in ( + (3, 4, 999, 123456789, 999_123456789), + (0, 0, MAX_INT64, 999999999, -1), # LUL, Cypher overflows + ): + with self.subTest(mo=mo, d=d, s=s, ns=ns): + self._create_driver_and_session() + values = self._read_query_values( + "CYPHER runtime=interpreted WITH $x AS x " + "RETURN [x.months, x.days, x.seconds, " + "x.nanosecondsOfSecond, x.nanoseconds]", + params={"x": types.CypherDuration( + months=mo, days=d, seconds=s, nanoseconds=ns_os + )} + ) + self.assertEqual( + values, + [types.CypherList(list(map(types.CypherInt, + [mo, d, s, ns_os, ns])))] + ) + + def test_nested_duration(self): + data = types.CypherList([ + types.CypherDuration(months=3, days=4, seconds=999, + nanoseconds=123456789), + types.CypherDuration(months=0, days=0, seconds=MAX_INT64, + nanoseconds=999999999) + ]) + self._create_driver_and_session() + values = self._write_query_values("CREATE (a {x:$x}) RETURN a.x", + params={"x": data}) + self.assertEqual(values, [data]) + + def test_cypher_created_duration(self): + for (s, d) in ( + ( + "duration('P1234M567890DT123456789123.999S')", + types.CypherDuration(1234, 567890, 123456789123, 999000000) + ), + ( + "duration('P1Y2M3W4DT5H6M7.999888777S')", + types.CypherDuration( + 12 * 1 + 2, + 7 * 3 + 4, + 5 * 60 * 60 + 6 * 60 + 7, + 999888777 + ) + ), + ): + with self.subTest(s=s, d=d): + self._create_driver_and_session() + values = self._read_query_values(f"RETURN {s}") + self.assertEqual(values, [d]) diff --git a/tests/neo4j/suites.py b/tests/neo4j/suites.py index 7e30f2110..6db8ff71a 100644 --- a/tests/neo4j/suites.py +++ b/tests/neo4j/suites.py @@ -1,18 +1,9 @@ """Defines suites of test to run in different setups.""" +import os import sys import unittest -from tests.neo4j import ( - test_authentication, - test_bookmarks, - test_datatypes, - test_direct_driver, - test_session_run, - test_summary, - test_tx_func_run, - test_tx_run, -) from tests.neo4j.shared import env_neo4j_version from tests.testenv import ( begin_test_suite, @@ -20,20 +11,19 @@ get_test_result_class, ) -loader = unittest.TestLoader() - ####################### # Suite for Neo4j 4.2 # ####################### +loader = unittest.TestLoader() + suite_4x2 = unittest.TestSuite() -suite_4x2.addTests(loader.loadTestsFromModule(test_authentication)) -suite_4x2.addTests(loader.loadTestsFromModule(test_bookmarks)) -suite_4x2.addTests(loader.loadTestsFromModule(test_datatypes)) -suite_4x2.addTests(loader.loadTestsFromModule(test_direct_driver)) -suite_4x2.addTests(loader.loadTestsFromModule(test_session_run)) -suite_4x2.addTests(loader.loadTestsFromModule(test_summary)) -suite_4x2.addTests(loader.loadTestsFromModule(test_tx_func_run)) -suite_4x2.addTests(loader.loadTestsFromModule(test_tx_run)) + +suite_4x2.addTest(loader.discover( + "tests.neo4j", + top_level_dir=os.path.abspath(os.path.join( + os.path.dirname(__file__), "..", ".." + )) +)) ####################### # Suite for Neo4j 4.3 # diff --git a/tests/neo4j/test_session_run.py b/tests/neo4j/test_session_run.py index 64da3215f..ccf7603d8 100644 --- a/tests/neo4j/test_session_run.py +++ b/tests/neo4j/test_session_run.py @@ -260,7 +260,7 @@ def _test(): self._session1 = None for consume in (True, False): - with self.subTest("consume" if consume else "iterate"): + with self.subTest(consume=consume): _test() @cluster_unsafe_test @@ -280,7 +280,7 @@ def _test(): self._session1 = None for consume in (True, False): - with self.subTest("consume" if consume else "no consume"): + with self.subTest(consume=consume): _test() @cluster_unsafe_test diff --git a/tests/neo4j/test_tx_func_run.py b/tests/neo4j/test_tx_func_run.py index a1951f355..f18bd0573 100644 --- a/tests/neo4j/test_tx_func_run.py +++ b/tests/neo4j/test_tx_func_run.py @@ -44,7 +44,7 @@ def _test(): self._session1 = None for consume in (True, False): - with self.subTest("consume" if consume else "iterate"): + with self.subTest(consume=consume): _test() def test_iteration_nested(self): diff --git a/tests/neo4j/test_tx_run.py b/tests/neo4j/test_tx_run.py index caab7e5ac..146ca9fa9 100644 --- a/tests/neo4j/test_tx_run.py +++ b/tests/neo4j/test_tx_run.py @@ -50,8 +50,7 @@ def _test(): for consume in (True, False): for rollback in (True, False): - with self.subTest("consume" if consume else "iterate" - + "_rollback" if rollback else "_commit"): + with self.subTest(consume=consume, rollback=rollback): _test() @cluster_unsafe_test @@ -371,7 +370,7 @@ def _test(): self._session1 = None for invert_fetching in (True, False): - with self.subTest("inverted" if invert_fetching else "in_order"): + with self.subTest(invert_fetching=invert_fetching): _test() @cluster_unsafe_test @@ -418,8 +417,7 @@ def _test(): self._session1 = None for run_q2_before_q1_fetch in (True, False): - with self.subTest("run_q2_before_q1_fetch-%s" - % run_q2_before_q1_fetch): + with self.subTest(run_q2_before_q1_fetch=run_q2_before_q1_fetch): _test() @cluster_unsafe_test @@ -448,5 +446,5 @@ def _test(): self.assertEqual(len(list(res)), commit) for commit in (True, False): - with self.subTest("commit" if commit else "rollback"): + with self.subTest(commit=commit): _test() diff --git a/tests/shared.py b/tests/shared.py index e546045c9..3dd39fd94 100644 --- a/tests/shared.py +++ b/tests/shared.py @@ -9,6 +9,8 @@ TEST_BACKEND_PORT Port on backend host, default is 9876 """ + +from contextlib import contextmanager import inspect import os import re @@ -145,7 +147,10 @@ class TestkitTestCase(unittest.TestCase): def setUp(self): super().setUp() - id_ = re.sub(r"^([^\.]+\.)*?tests\.", "", self.id()) + self._testkit_test_name = id_ = re.sub( + r"^([^\.]+\.)*?tests\.", "", self.id() + ) + self._check_subtests = False self._backend = new_backend() self.addCleanup(self._backend.close) self._driver_features = get_driver_features(self._backend) @@ -198,9 +203,10 @@ def setUp(self): response = self._backend.send_and_receive(protocol.StartTest(id_)) if isinstance(response, protocol.SkipTest): self.skipTest(response.reason) - + elif isinstance(response, protocol.RunSubTests): + self._check_subtests = True elif not isinstance(response, protocol.RunTest): - raise Exception("Should be SkipTest or RunTest, " + raise Exception("Should be SkipTest, RunSubTests, or RunTest, " "received {}: {}".format(type(response), response)) @@ -251,3 +257,27 @@ def skip_if_missing_bolt_support(self, version): def script_path(self, *path): base_path = os.path.dirname(inspect.getfile(self.__class__)) return os.path.join(base_path, "scripts", *path) + + @contextmanager + def subTest(self, **params): # noqa: N802 + assert "msg" not in params + subtest_context = super().subTest(**params) + with subtest_context: + if not self._check_subtests: + yield + return + response = self._backend.send_and_receive( + protocol.StartSubTest(self._testkit_test_name, params) + ) + # we have to run the subtest, but we don't care for the result + # if we want to throw or skip (in fact also a throw) + try: + yield + finally: + if isinstance(response, protocol.SkipTest): + # skipping after the fact :/ + self.skipTest(response.reason) + elif not isinstance(response, protocol.RunTest): + raise Exception("Should be SkipTest, or RunTest, " + "received {}: {}".format(type(response), + response)) diff --git a/tests/stub/authorization/test_authorization.py b/tests/stub/authorization/test_authorization.py index 7560b19d4..625a82ca6 100644 --- a/tests/stub/authorization/test_authorization.py +++ b/tests/stub/authorization/test_authorization.py @@ -952,7 +952,7 @@ def test(): self._server.done() for realm in (None, "", "foobar"): - with self.subTest("realm-%s" % realm): + with self.subTest(realm=realm): test() self._server.reset() diff --git a/tests/stub/bookmarks/test_bookmarks_v4.py b/tests/stub/bookmarks/test_bookmarks_v4.py index b6f878118..326725ea1 100644 --- a/tests/stub/bookmarks/test_bookmarks_v4.py +++ b/tests/stub/bookmarks/test_bookmarks_v4.py @@ -70,7 +70,7 @@ def test(mode_, bm_count_): for mode in ("read", "write"): for bm_count in (0, 1, 2): - with self.subTest(mode + "_%i_bookmarks" % bm_count): + with self.subTest(mode=mode, bm_count=bm_count): test(mode, bm_count) def test_bookmarks_session_run(self): @@ -96,12 +96,11 @@ def test(mode_, bm_count_, check_bms_pre_query_, consume_): # TODO: make a decision if consume should be triggered # implicitly or not. for consume in (False, True)[1:]: - with self.subTest(mode + "_%i_bookmarks%s%s" % ( - bm_count, - "_check_bms_pre_query" if check_bms_pre_query - else "", - "_consume" if consume else "_no_consume" - )): + with self.subTest( + mode=mode, bm_count=bm_count, + check_bms_pre_query=check_bms_pre_query, + consume=consume + ): test(mode, bm_count, check_bms_pre_query, consume) def test_bookmarks_tx_run(self): @@ -130,12 +129,11 @@ def test(mode_, bm_count_, check_bms_pre_query_, consume_): for bm_count in (0, 1, 2): for check_bms_pre_query in (False, True): for consume in (False, True): - with self.subTest(mode + "_%i_bookmarks%s%s" % ( - bm_count, - "_check_bms_pre_query" if check_bms_pre_query - else "", - "_consume" if consume else "_no_consume" - )): + with self.subTest( + mode=mode, bm_count=bm_count, + check_bms_pre_query=check_bms_pre_query, + consume=consume + ): test(mode, bm_count, check_bms_pre_query, consume) def test_bookmarks_tx_func(self): @@ -168,12 +166,11 @@ def test(mode_, bm_count_, check_bms_pre_query_, consume_): for bm_count in (0, 1, 2): for check_bms_pre_query in (False, True): for consume in (False, True): - with self.subTest(mode + "_%i_bookmarks%s%s" % ( - bm_count, - "_check_bms_pre_query" if check_bms_pre_query - else "", - "_consume" if consume else "_no_consume" - )): + with self.subTest( + mode=mode, bm_count=bm_count, + check_bms_pre_query=check_bms_pre_query, + consume=consume + ): test(mode, bm_count, check_bms_pre_query, consume) def test_sequence_of_writing_and_reading_tx(self): diff --git a/tests/stub/bookmarks/test_bookmarks_v5.py b/tests/stub/bookmarks/test_bookmarks_v5.py index 2ed4ec954..f5464f4c1 100644 --- a/tests/stub/bookmarks/test_bookmarks_v5.py +++ b/tests/stub/bookmarks/test_bookmarks_v5.py @@ -36,7 +36,7 @@ def test(): # TODO: decide what we expect to happen when multiple bookmarks are # passed in: return all or only the last one? for bm_count in (0, 1): - with self.subTest(mode + "_%i_bookmarks" % bm_count): + with self.subTest(mode=mode, bm_count=bm_count): test() # Tests that a committed transaction can return the last bookmark diff --git a/tests/stub/homedb/test_homedb.py b/tests/stub/homedb/test_homedb.py index abdfe1dcb..e8f8352b7 100644 --- a/tests/stub/homedb/test_homedb.py +++ b/tests/stub/homedb/test_homedb.py @@ -60,8 +60,7 @@ def _test(): self._reader1.done() for parallel_sessions in (True, False): - with self.subTest("parallel-sessions" if parallel_sessions - else "sequential-sessions"): + with self.subTest(parallel_sessions=parallel_sessions): _test() self._router.reset() self._reader1.reset() @@ -104,8 +103,7 @@ def _test(): self._reader1.done() for parallel_sessions in (True, False): - with self.subTest("parallel-sessions" if parallel_sessions - else "sequential-sessions"): + with self.subTest(parallel_sessions=parallel_sessions): _test() self._router.reset() self._reader1.reset() @@ -148,8 +146,7 @@ def work(tx): self._reader1.done() for parallel_sessions in (True, False): - with self.subTest("parallel-sessions" if parallel_sessions - else "sequential-sessions"): + with self.subTest(parallel_sessions=parallel_sessions): _test() self._router.reset() self._reader1.reset() diff --git a/tests/stub/iteration/test_iteration_session_run.py b/tests/stub/iteration/test_iteration_session_run.py index f3cbfb6c9..f9fd965df 100644 --- a/tests/stub/iteration/test_iteration_session_run.py +++ b/tests/stub/iteration/test_iteration_session_run.py @@ -115,5 +115,5 @@ def test(version_, script_): # background continue for mode in ("write", "read"): - with self.subTest(version + "-" + mode): + with self.subTest(version=version, mode=mode): test(version, script) diff --git a/tests/stub/iteration/test_result_list.py b/tests/stub/iteration/test_result_list.py index 533e76d5f..351e39b35 100644 --- a/tests/stub/iteration/test_result_list.py +++ b/tests/stub/iteration/test_result_list.py @@ -63,7 +63,7 @@ def _test(): ]) for fetch_size in (1, 2): - with self.subTest("fetch_size-%i" % fetch_size): + with self.subTest(fetch_size=fetch_size): _test() def test_result_list_with_disconnect(self): diff --git a/tests/stub/iteration/test_result_optional_single.py b/tests/stub/iteration/test_result_optional_single.py new file mode 100644 index 000000000..7ee00c46b --- /dev/null +++ b/tests/stub/iteration/test_result_optional_single.py @@ -0,0 +1,120 @@ +import nutkit.protocol as types +from tests.shared import get_driver_name + +from ._common import IterationTestBase + + +class TestResultSingleOptional(IterationTestBase): + + required_features = (types.Feature.BOLT_4_4, + types.Feature.API_RESULT_SINGLE_OPTIONAL) + + def _assert_not_exactly_one_record_warning(self, warnings): + self.assertEqual(1, len(warnings)) + warning = warnings[0] + driver = get_driver_name() + if driver in ["python"]: + self.assertIn("multiple", warning) + else: + self.fail("no error mapping is defined for %s driver" % driver) + + def _assert_connection_error(self, error): + self.assertIsInstance(error, types.DriverError) + driver = get_driver_name() + if driver in ["python"]: + self.assertEqual("", + error.errorType) + else: + self.fail("no error mapping is defined for %s driver" % driver) + + def test_result_single_optional_with_0_records(self): + with self._session("yield_0_records.script") as session: + result = session.run("RETURN 1 AS n") + optional_record = result.single_optional() + self.assertIsNone(optional_record.record) + self.assertEqual(optional_record.warnings, []) + + def test_result_single_optional_with_1_records(self): + with self._session("yield_1_record.script") as session: + result = session.run("RETURN 1 AS n") + optional_record = result.single_optional() + record, warnings = optional_record.record, optional_record.warnings + self.assertIsInstance(record, types.Record) + self.assertEqual(record.values, [types.CypherInt(1)]) + self.assertEqual(warnings, []) + + def test_result_single_optional_with_2_records(self): + def _test(): + with self._session("yield_2_records.script", + fetch_size=fetch_size) as session: + result = session.run("RETURN 1 AS n") + + optional_record = result.single_optional() + record = optional_record.record + warnings = optional_record.warnings + self.assertIsInstance(record, types.Record) + self.assertEqual(record.values, [types.CypherInt(1)]) + self._assert_not_exactly_one_record_warning(warnings) + + # single_optional(), should always exhaust the full result + # stream to prevent abusing it as another `next` method. + for _ in range(2): + optional_record = result.single_optional() + self.assertIsNone(optional_record.record) + self.assertEqual(optional_record.warnings, []) + + for fetch_size in (1, 2): + with self.subTest(fetch_size=fetch_size): + _test() + + def test_result_single_optional_with_disconnect(self): + with self._session("disconnect_on_pull.script") as session: + result = session.run("RETURN 1 AS n") + with self.assertRaises(types.DriverError) as exc: + result.single_optional() + self._assert_connection_error(exc.exception) + + def test_result_single_optional_with_failure(self): + err = "Neo.TransientError.Completely.MadeUp" + + with self._session("error_on_pull.script", + vars_={"#ERROR#": err}) as session: + result = session.run("RETURN 1 AS n") + with self.assertRaises(types.DriverError) as exc: + result.single_optional() + self.assertEqual(err, exc.exception.code) + + def test_result_single_optional_with_failure_tx_run(self): + err = "Neo.TransientError.Completely.MadeUp" + + with self._session("tx_error_on_pull.script", + vars_={"#ERROR#": err}) as session: + tx = session.begin_transaction() + result = tx.run("RETURN 1 AS n") + with self.assertRaises(types.DriverError) as exc: + result.single_optional() + self.assertEqual(err, exc.exception.code) + + def test_result_single_optional_with_failure_tx_func_run(self): + err = "Neo.TransientError.Completely.MadeUp" + work_call_count = 0 + + def work(tx): + nonlocal work_call_count + work_call_count += 1 + result = tx.run("RETURN 1 AS n") + if work_call_count == 1: + with self.assertRaises(types.DriverError) as exc: + result.single_optional() + self.assertEqual(err, exc.exception.code) + raise exc.exception + else: + return result.single_optional() + + with self._session("tx_error_on_pull.script", + vars_={"#ERROR#": err}) as session: + optional_record = session.read_transaction(work) + record, warnings = optional_record.record, optional_record.warnings + self.assertIsInstance(record, types.Record) + self.assertEqual(record.values, [types.CypherInt(1)]) + self.assertEqual(warnings, []) diff --git a/tests/stub/iteration/test_result_peek.py b/tests/stub/iteration/test_result_peek.py index c0f59f1bf..e083b19a0 100644 --- a/tests/stub/iteration/test_result_peek.py +++ b/tests/stub/iteration/test_result_peek.py @@ -77,7 +77,7 @@ def _test(): self.assertIsInstance(record, types.NullRecord) for fetch_size in (1, 2): - with self.subTest("fetch_size-%i" % fetch_size): + with self.subTest(fetch_size=fetch_size): _test() @driver_feature(types.Feature.API_RESULT_PEEK) diff --git a/tests/stub/iteration/test_result_single.py b/tests/stub/iteration/test_result_single.py index 84a62cd28..9d4df605f 100644 --- a/tests/stub/iteration/test_result_single.py +++ b/tests/stub/iteration/test_result_single.py @@ -1,15 +1,12 @@ import nutkit.protocol as types -from tests.shared import ( - driver_feature, - get_driver_name, -) +from tests.shared import get_driver_name from ._common import IterationTestBase class TestResultSingle(IterationTestBase): - required_features = types.Feature.BOLT_4_4, + required_features = types.Feature.BOLT_4_4, types.Feature.API_RESULT_SINGLE def _assert_not_exactly_one_record_error(self, error): self.assertIsInstance(error, types.DriverError) @@ -43,7 +40,6 @@ def _assert_connection_error(self, error): else: self.fail("no error mapping is defined for %s driver" % driver) - @driver_feature(types.Feature.API_RESULT_SINGLE) def test_result_single_with_0_records(self): with self._session("yield_0_records.script") as session: result = session.run("RETURN 1 AS n") @@ -51,7 +47,6 @@ def test_result_single_with_0_records(self): result.single() self._assert_not_exactly_one_record_error(exc.exception) - @driver_feature(types.Feature.API_RESULT_SINGLE) def test_result_single_with_1_records(self): with self._session("yield_1_record.script") as session: result = session.run("RETURN 1 AS n") @@ -59,7 +54,6 @@ def test_result_single_with_1_records(self): self.assertIsInstance(record, types.Record) self.assertEqual(record.values, [types.CypherInt(1)]) - @driver_feature(types.Feature.API_RESULT_SINGLE) def test_result_single_with_2_records(self): def _test(): with self._session("yield_2_records.script", @@ -70,10 +64,9 @@ def _test(): self._assert_not_exactly_one_record_error(exc.exception) for fetch_size in (1, 2): - with self.subTest("fetch_size-%i" % fetch_size): + with self.subTest(fetch_size=fetch_size): _test() - @driver_feature(types.Feature.API_RESULT_SINGLE) def test_result_single_with_disconnect(self): with self._session("disconnect_on_pull.script") as session: result = session.run("RETURN 1 AS n") @@ -81,7 +74,6 @@ def test_result_single_with_disconnect(self): result.single() self._assert_connection_error(exc.exception) - @driver_feature(types.Feature.API_RESULT_SINGLE) def test_result_single_with_failure(self): err = "Neo.TransientError.Completely.MadeUp" @@ -92,7 +84,6 @@ def test_result_single_with_failure(self): result.single() self.assertEqual(err, exc.exception.code) - @driver_feature(types.Feature.API_RESULT_SINGLE) def test_result_single_with_failure_tx_run(self): err = "Neo.TransientError.Completely.MadeUp" @@ -104,7 +95,6 @@ def test_result_single_with_failure_tx_run(self): result.single() self.assertEqual(err, exc.exception.code) - @driver_feature(types.Feature.API_RESULT_SINGLE) def test_result_single_with_failure_tx_func_run(self): err = "Neo.TransientError.Completely.MadeUp" work_call_count = 0 diff --git a/tests/stub/optimizations/test_optimizations.py b/tests/stub/optimizations/test_optimizations.py index 108f7c5b5..6b11ead93 100644 --- a/tests/stub/optimizations/test_optimizations.py +++ b/tests/stub/optimizations/test_optimizations.py @@ -51,7 +51,7 @@ def test(): for mode in ("read", "write"): for use_tx in (True, False): - with self.subTest(mode + ("_tx" if use_tx else "")): + with self.subTest(mode=mode, use_tx=use_tx): test() self._server.reset() @@ -138,11 +138,8 @@ def test_reuses_connection(self): for use_tx in (None, "commit", "rollback"): for new_session in (True, False): with self.subTest( - mode - + (("_tx_" + use_tx) if use_tx else "") - + ("_one_shot_session" - if new_session else "_reuse_session") - + ("_routing" if routing else "_direct") + routing=routing, mode=mode, use_tx=use_tx, + new_session=new_session ): self.double_read( mode, new_session, use_tx, routing, @@ -165,14 +162,8 @@ def test_no_reset_on_clean_connection(self): continue for use_tx in (None, "commit", "rollback"): for new_session in (True, False): - with self.subTest(mode - + "_" + version - + (("_tx_" + use_tx) if use_tx - else "") - + ("_one_shot_session" if new_session - else"_reuse_session") - + ("_discard" if consume - else "_pull")): + with self.subTest(version=version, consume=consume, + use_tx=use_tx): self.double_read(mode, new_session, use_tx, False, version=version, consume=consume, check_no_reset=True) @@ -226,12 +217,8 @@ def test(): for fail_on in ("pull", "run", "begin"): if fail_on == "begin" and not use_tx: continue - with self.subTest( - version - + ("_tx" if use_tx else "_autocommit") - + "_{}".format(fail_on) - + ("_routing" if routing else "_no_routing") - ): + with self.subTest(version=version, use_tx=use_tx, + routing=routing, fail_on=fail_on): test() self._server.reset() self._router.reset() @@ -289,11 +276,8 @@ def test(): for use_tx in (True, False): for consume in (True, False): for routing in (True, False): - with self.subTest( - ("tx" if use_tx else "auto_commit") - + ("_discard" if consume else "_pull") - + ("_routing" if routing else "_no_routing") - ): + with self.subTest(version=version, use_tx=use_tx, + consume=consume, routing=routing): test() self._server.reset() self._router.reset() @@ -331,11 +315,8 @@ def test(): for version in ("v4x3", "v4x4"): for consume1 in (True, False): for consume2 in (True, False): - with self.subTest( - version - + ("_discard1" if consume1 else "_pull1") - + ("_discard2" if consume2 else "_pull2") - ): + with self.subTest(version=version, consume1=consume1, + consume2=consume2): test() self._server.reset() self._router.reset() @@ -375,8 +356,7 @@ def test(): continue for consume1 in (True, False): for consume2 in (True, False): - with self.subTest(("discard1" if consume1 else "pull1") - + ("_discard2" - if consume2 else "_pull2")): + with self.subTest(version=version, consume1=consume1, + consume2=consume2): test() self._server.reset() diff --git a/tests/stub/retry/test_retry.py b/tests/stub/retry/test_retry.py index 4e27b5af5..719ab86a4 100644 --- a/tests/stub/retry/test_retry.py +++ b/tests/stub/retry/test_retry.py @@ -171,6 +171,6 @@ def once(tx): self._server.done() for mode in ("read", "write"): - with self.subTest(mode): + with self.subTest(mode=mode): _test() self._server.reset() diff --git a/tests/stub/session_run_parameters/test_session_run_parameters.py b/tests/stub/session_run_parameters/test_session_run_parameters.py index 3809474e0..938d76968 100644 --- a/tests/stub/session_run_parameters/test_session_run_parameters.py +++ b/tests/stub/session_run_parameters/test_session_run_parameters.py @@ -83,21 +83,21 @@ def _run(self, script, routing, session_args=None, session_kwargs=None, @driver_feature(types.Feature.BOLT_4_4) def test_access_mode_read(self): for routing in (True, False): - with self.subTest("routing" if routing else "direct"): + with self.subTest(routing=routing): self._run("access_mode_read", routing, session_args=("r",)) @driver_feature(types.Feature.BOLT_4_4) def test_access_mode_write(self): for routing in (True, False): - with self.subTest("routing" if routing else "direct"): + with self.subTest(routing=routing): self._run("access_mode_write", routing, session_args=("w",)) @driver_feature(types.Feature.BOLT_4_4) def test_parameters(self): for routing in (True, False): - with self.subTest("routing" if routing else "direct"): + with self.subTest(routing=routing): self._run("parameters", routing, session_args=("w",), run_kwargs={"params": {"p": types.CypherInt(1)}}) @@ -105,7 +105,7 @@ def test_parameters(self): @driver_feature(types.Feature.BOLT_4_4) def test_bookmarks(self): for routing in (True, False): - with self.subTest("routing" if routing else "direct"): + with self.subTest(routing=routing): self._run("bookmarks", routing, session_args=("w",), session_kwargs={"bookmarks": ["b1", "b2"]}) @@ -113,7 +113,7 @@ def test_bookmarks(self): @driver_feature(types.Feature.BOLT_4_4) def test_tx_meta(self): for routing in (True, False): - with self.subTest("routing" if routing else "direct"): + with self.subTest(routing=routing): self._run("tx_meta", routing, session_args=("w",), run_kwargs={"tx_meta": {"akey": "aval"}}) @@ -121,7 +121,7 @@ def test_tx_meta(self): @driver_feature(types.Feature.BOLT_4_4) def test_timeout(self): for routing in (True, False): - with self.subTest("routing" if routing else "direct"): + with self.subTest(routing=routing): self._run("timeout_17", routing, session_args=("w",), run_kwargs={"timeout": 17}) @@ -149,7 +149,7 @@ def test_negative_timeout(self): types.Feature.BOLT_4_4) def test_database(self): for routing in (True, False): - with self.subTest("routing" if routing else "direct"): + with self.subTest(routing=routing): self._run("adb", routing, session_args=("w",), session_kwargs={"database": "adb"}) @@ -158,7 +158,7 @@ def test_database(self): types.Feature.BOLT_4_4) def test_impersonation(self): for routing in (True, False): - with self.subTest("routing" if routing else "direct"): + with self.subTest(routing=routing): self._run("imp_user", routing, session_args=("w",), session_kwargs={ @@ -169,7 +169,7 @@ def test_impersonation(self): types.Feature.BOLT_4_3) def test_impersonation_fails_on_v4x3(self): for routing in (True, False): - with self.subTest("routing" if routing else "direct"): + with self.subTest(routing=routing): with self.assertRaises(types.DriverError) as exc: self._run("imp_user_v4x3", routing, session_args=("w",), @@ -199,7 +199,7 @@ def test_impersonation_fails_on_v4x3(self): types.Feature.BOLT_4_4) def test_combined(self): for routing in (True, False): - with self.subTest("routing" if routing else "direct"): + with self.subTest(routing=routing): self._run("combined", routing, session_args=("r",), session_kwargs={ @@ -219,7 +219,7 @@ def test_empty_query(self): if get_driver_name() in ["javascript", "java"]: self.skipTest("rejects empty string") for routing in (True, False): - with self.subTest("routing" if routing else "direct"): + with self.subTest(routing=routing): self._start_servers_and_driver("empty_query", routing, None, None) session = self._driver.session("w") diff --git a/tests/stub/summary/test_summary.py b/tests/stub/summary/test_summary.py index 1f1ceda6f..dd9c87f47 100644 --- a/tests/stub/summary/test_summary.py +++ b/tests/stub/summary/test_summary.py @@ -104,7 +104,7 @@ def _test(): self.assertEqual(summary.query_type, query_type) for query_type in ("r", "w", "rw", "s", None): - with self.subTest(query_type): + with self.subTest(query_type=query_type): _test() @driver_feature(types.Feature.TMP_FULL_SUMMARY) @@ -128,7 +128,7 @@ def _test(): ) for query_type in ("wr",): - with self.subTest(query_type): + with self.subTest(query_type=query_type): _test() @driver_feature(types.Feature.TMP_FULL_SUMMARY) diff --git a/tests/stub/tx_begin_parameters/test_tx_begin_parameters.py b/tests/stub/tx_begin_parameters/test_tx_begin_parameters.py index 94922cb31..98b140c53 100644 --- a/tests/stub/tx_begin_parameters/test_tx_begin_parameters.py +++ b/tests/stub/tx_begin_parameters/test_tx_begin_parameters.py @@ -101,7 +101,7 @@ def work(tx_): @driver_feature(types.Feature.BOLT_4_4) def test_access_mode_read(self): for routing in (True, False): - with self.subTest("routing" if routing else "direct"): + with self.subTest(routing=routing): self._run("access_mode_read", routing, session_args=("r",)) self._server.done() @@ -109,8 +109,8 @@ def test_access_mode_read(self): def test_tx_func_access_mode_read(self): for routing in (True, False): for session_access_mode in ("r", "w"): - with self.subTest("session_mode_" + session_access_mode - + "_routing" if routing else "_direct"): + with self.subTest(routing=routing, + session_access_mode=session_access_mode): self._run("access_mode_read", routing, session_args=(session_access_mode[0],), tx_func_access_mode="r") @@ -118,7 +118,7 @@ def test_tx_func_access_mode_read(self): @driver_feature(types.Feature.BOLT_4_4) def test_access_mode_write(self): for routing in (True, False): - with self.subTest("routing" if routing else "direct"): + with self.subTest(routing=routing): self._run("access_mode_write", routing, session_args=("w",)) self._server.done() @@ -127,7 +127,8 @@ def test_access_mode_write(self): def test_tx_func_access_mode(self): for routing in (True, False): for session_access_mode in ("r", "w"): - with self.subTest("session_mode_" + session_access_mode): + with self.subTest(routing=routing, + session_access_mode=session_access_mode): self._run("access_mode_write", routing, session_args=(session_access_mode[0],), tx_func_access_mode="w") @@ -135,7 +136,7 @@ def test_tx_func_access_mode(self): @driver_feature(types.Feature.BOLT_4_4) def test_parameters(self): for routing in (True, False): - with self.subTest("routing" if routing else "direct"): + with self.subTest(routing=routing): self._run("parameters", routing, session_args=("w",), run_kwargs={"params": {"p": types.CypherInt(1)}}) @@ -143,7 +144,7 @@ def test_parameters(self): @driver_feature(types.Feature.BOLT_4_4) def test_bookmarks(self): for routing in (True, False): - with self.subTest("routing" if routing else "direct"): + with self.subTest(routing=routing): self._run("bookmarks", routing, session_args=("w",), session_kwargs={"bookmarks": ["b1", "b2"]}) @@ -151,7 +152,7 @@ def test_bookmarks(self): @driver_feature(types.Feature.BOLT_4_4) def test_tx_meta(self): for routing in (True, False): - with self.subTest("routing" if routing else "direct"): + with self.subTest(routing=routing): self._run("tx_meta", routing, session_args=("w",), tx_kwargs={"tx_meta": {"akey": "aval"}}) @@ -159,7 +160,7 @@ def test_tx_meta(self): @driver_feature(types.Feature.BOLT_4_4) def test_timeout(self): for routing in (True, False): - with self.subTest("routing" if routing else "direct"): + with self.subTest(routing=routing): self._run("timeout_17", routing, session_args=("w",), tx_kwargs={"timeout": 17}) @@ -187,7 +188,7 @@ def test_default_timeout(self): @driver_feature(types.Feature.BOLT_4_4) def test_database(self): for routing in (True, False)[1:]: - with self.subTest("routing" if routing else "direct"): + with self.subTest(routing=routing): self._run("adb", routing, session_args=("w",), session_kwargs={"database": "adb"}) @@ -196,7 +197,7 @@ def test_database(self): types.Feature.BOLT_4_4) def test_impersonation(self): for routing in (True, False): - with self.subTest("routing" if routing else "direct"): + with self.subTest(routing=routing): self._run("imp_user", routing, session_args=("w",), session_kwargs={ @@ -207,7 +208,7 @@ def test_impersonation(self): types.Feature.BOLT_4_3) def test_impersonation_fails_on_v4x3(self): for routing in (True, False): - with self.subTest("routing" if routing else "direct"): + with self.subTest(routing=routing): with self.assertRaises(types.DriverError) as exc: self._run("imp_user_v4x3", routing, session_args=("w",), @@ -241,7 +242,7 @@ def test_impersonation_fails_on_v4x3(self): types.Feature.BOLT_4_4) def test_combined(self): for routing in (True, False): - with self.subTest("routing" if routing else "direct"): + with self.subTest(routing=routing): self._run("combined", routing, session_args=("r",), run_kwargs={"params": {"p": types.CypherInt(1)}}, diff --git a/tests/stub/versions/test_versions.py b/tests/stub/versions/test_versions.py index 7bc5a0157..b0cac8c16 100644 --- a/tests/stub/versions/test_versions.py +++ b/tests/stub/versions/test_versions.py @@ -141,7 +141,7 @@ def test_server_version(self): for version in ("5x0", "4x4", "4x3", "4x2", "4x1", "3"): if not self.driver_supports_bolt(version): continue - with self.subTest(version): + with self.subTest(version=version): self._run(version, check_version=True) def test_server_agent(self): @@ -169,7 +169,7 @@ def test_server_agent(self): continue if not self.driver_supports_bolt(version): continue - with self.subTest(version + "-" + agent.replace(".", "x")): + with self.subTest(version=version, agent=agent): self._run(version, server_agent=agent, rejected_agent=reject) @@ -180,7 +180,7 @@ def test_server_address_in_summary(self): for version in ("5x0", "4x4", "4x3", "4x2", "4x1", "3"): if not self.driver_supports_bolt(version): continue - with self.subTest(version): + with self.subTest(version=version): self._run(version, check_server_address=True) def test_obtain_summary_twice(self): diff --git a/tests/tls/test_explicit_options.py b/tests/tls/test_explicit_options.py index 8ca375b40..e06778106 100644 --- a/tests/tls/test_explicit_options.py +++ b/tests/tls/test_explicit_options.py @@ -51,5 +51,6 @@ def _test(): if not supports_value_equality or encrypted: cert_options.append("None") for certs in cert_options: - with self.subTest("%s-%s-%s" % (scheme, encrypted, certs)): + with self.subTest(scheme=scheme, encrypted=encrypted, + certs=certs): _test() diff --git a/tests/tls/test_secure_scheme.py b/tests/tls/test_secure_scheme.py index 711d430c5..05187464c 100644 --- a/tests/tls/test_secure_scheme.py +++ b/tests/tls/test_secure_scheme.py @@ -39,7 +39,7 @@ def _start_server(self, cert_suffix, **kwargs): def test_driver_is_encrypted(self): for driver_config in self.extra_driver_configs: for scheme in self.schemes: - with self.subTest(scheme + "-" + str(driver_config)): + with self.subTest(scheme=scheme, driver_config=driver_config): self._test_reports_encrypted(True, scheme, **driver_config) def test_trusted_ca_correct_hostname(self): @@ -47,7 +47,7 @@ def test_trusted_ca_correct_hostname(self): # trusted certificate authority. for driver_config in self.extra_driver_configs: for scheme in self.schemes: - with self.subTest(scheme + "-" + str(driver_config)): + with self.subTest(scheme=scheme, driver_config=driver_config): self._start_server("thehost") self.assertTrue(self._try_connect( self._server, scheme, "thehost", **driver_config @@ -59,7 +59,7 @@ def test_trusted_ca_expired_server_correct_hostname(self): # certificate has expired. Should not connect on expired certificate. for driver_config in self.extra_driver_configs: for scheme in self.schemes: - with self.subTest(scheme + "-" + str(driver_config)): + with self.subTest(scheme=scheme, driver_config=driver_config): self._start_server("thehost_expired") self.assertFalse(self._try_connect( self._server, scheme, "thehost", **driver_config @@ -76,7 +76,7 @@ def test_trusted_ca_wrong_hostname(self): # all. for driver_config in self.extra_driver_configs: for scheme in self.schemes: - with self.subTest(scheme + "-" + str(driver_config)): + with self.subTest(scheme=scheme, driver_config=driver_config): self._start_server("thehost") self.assertFalse(self._try_connect( self._server, scheme, "thehostbutwrong", @@ -89,7 +89,7 @@ def test_untrusted_ca_correct_hostname(self): # trusted for driver_config in self.extra_driver_configs: for scheme in self.schemes: - with self.subTest(scheme + "-" + str(driver_config)): + with self.subTest(scheme=scheme, driver_config=driver_config): self._server = TlsServer("untrustedRoot_thehost") self.assertFalse(self._try_connect( self._server, scheme, "thehost", **driver_config @@ -101,7 +101,7 @@ def test_unencrypted(self): # TLS connections but the server doesn't speak TLS for driver_config in self.extra_driver_configs: for scheme in self.schemes: - with self.subTest(scheme + "-" + str(driver_config)): + with self.subTest(scheme=scheme, driver_config=driver_config): # The server cert doesn't really matter but set it to the # one that would work if TLS happens to be on. self._start_server("thehost", disable_tls=True) diff --git a/tests/tls/test_self_signed_scheme.py b/tests/tls/test_self_signed_scheme.py index 91f90a673..b52322adb 100644 --- a/tests/tls/test_self_signed_scheme.py +++ b/tests/tls/test_self_signed_scheme.py @@ -30,7 +30,7 @@ def tearDown(self): def test_driver_is_encrypted(self): for driver_config in self.extra_driver_configs: for scheme in self.schemes: - with self.subTest(scheme + "-" + str(driver_config)): + with self.subTest(scheme=scheme, driver_config=driver_config): self._test_reports_encrypted(True, scheme, **driver_config) def test_trusted_ca_correct_hostname(self): @@ -38,7 +38,7 @@ def test_trusted_ca_correct_hostname(self): # when configured for self signed. for driver_config in self.extra_driver_configs: for scheme in self.schemes: - with self.subTest(scheme + "-" + str(driver_config)): + with self.subTest(scheme=scheme, driver_config=driver_config): self._server = TlsServer("trustedRoot_thehost") self.assertTrue(self._try_connect( self._server, scheme, "thehost", **driver_config @@ -52,7 +52,7 @@ def test_trusted_ca_expired_server_correct_hostname(self): # enabled, same for all drivers ? for driver_config in self.extra_driver_configs: for scheme in self.schemes: - with self.subTest(scheme + "-" + str(driver_config)): + with self.subTest(scheme=scheme, driver_config=driver_config): self._server = TlsServer("trustedRoot_thehost_expired") self.assertTrue(self._try_connect( self._server, scheme, "thehost", **driver_config @@ -71,7 +71,7 @@ def test_trusted_ca_wrong_hostname(self): # all. for driver_config in self.extra_driver_configs: for scheme in self.schemes: - with self.subTest(scheme + "-" + str(driver_config)): + with self.subTest(scheme=scheme, driver_config=driver_config): self._server = TlsServer("trustedRoot_thehost") self.assertTrue(self._try_connect( self._server, scheme, "thehostbutwrong", @@ -84,7 +84,7 @@ def test_untrusted_ca_correct_hostname(self): # Should connect for driver_config in self.extra_driver_configs: for scheme in self.schemes: - with self.subTest(scheme + "-" + str(driver_config)): + with self.subTest(scheme=scheme, driver_config=driver_config): self._server = TlsServer("untrustedRoot_thehost") self.assertTrue(self._try_connect( self._server, scheme, "thehost", **driver_config @@ -96,7 +96,7 @@ def test_untrusted_ca_wrong_hostname(self): # Should connect for driver_config in self.extra_driver_configs: for scheme in self.schemes: - with self.subTest(scheme + "-" + str(driver_config)): + with self.subTest(scheme=scheme, driver_config=driver_config): self._server = TlsServer("untrustedRoot_thehost") self.assertTrue(self._try_connect( self._server, scheme, "thehostbutwrong", @@ -110,7 +110,7 @@ def test_unencrypted(self): # TLS connections but the server doesn't speak TLS for driver_config in self.extra_driver_configs: for scheme in self.schemes: - with self.subTest(scheme + "-" + str(driver_config)): + with self.subTest(scheme=scheme, driver_config=driver_config): # The server cert doesn't really matter but set it to the # one that would work if TLS happens to be on. self._server = TlsServer("untrustedRoot_thehost", diff --git a/tests/tls/test_unsecure_scheme.py b/tests/tls/test_unsecure_scheme.py index 8c4059512..e4e286096 100644 --- a/tests/tls/test_unsecure_scheme.py +++ b/tests/tls/test_unsecure_scheme.py @@ -32,12 +32,12 @@ def tearDown(self): def test_driver_is_not_encrypted(self): for scheme in schemes: - with self.subTest(scheme): + with self.subTest(scheme=scheme): self._test_reports_encrypted(False, scheme) def test_secure_server(self): for scheme in schemes: - with self.subTest(scheme): + with self.subTest(scheme=scheme): self._server = TlsServer("trustedRoot_thehost") self.assertFalse(self._try_connect( self._server, scheme, "thehost" @@ -47,7 +47,7 @@ def test_secure_server(self): @driver_feature(types.Feature.API_SSL_CONFIG) def test_secure_server_explicitly_disabled_encryption(self): for scheme in schemes: - with self.subTest(scheme): + with self.subTest(scheme=scheme): self._server = TlsServer("trustedRoot_thehost") self.assertFalse(self._try_connect( self._server, scheme, "thehost", encrypted=False