diff --git a/test/test_bulk_loader.py b/test/test_bulk_loader.py
index 8bb92e7..eacd928 100644
--- a/test/test_bulk_loader.py
+++ b/test/test_bulk_loader.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
 import csv
 import os
 import sys
@@ -29,32 +27,32 @@ def row_count(in_csv):
     return idx


-class TestBulkLoader(unittest.TestCase):
+class TestBulkLoader:
+
+    redis_con = redis.Redis(decode_responses=True)
+
     @classmethod
-    def setUpClass(cls):
-        """
-        Instantiate a new Redis connection
-        """
-        cls.redis_con = redis.Redis(host="localhost", port=6379, decode_responses=True)
+    def setup_class(cls):
         cls.redis_con.flushall()

     @classmethod
-    def tearDownClass(cls):
+    def teardown_class(cls):
         """Delete temporary files"""
-        try:
-            os.remove("/tmp/nodes.tmp")
-            os.remove("/tmp/relations.tmp")
-            os.remove("/tmp/nodes_index.tmp")
-            os.remove("/tmp/nodes_full_text_index.tmp")
-        except OSError:
-            pass
+        for i in [
+            "nodes.tmp",
+            "relations.tmp",
+            "nodes_index.tmp",
+            "nodes_full_text_index.tmp",
+        ]:
+            if os.path.isfile(f"/tmp/{i}"):
+                os.unlink(f"/tmp/{i}")
         cls.redis_con.flushall()

     def validate_exception(self, res, expected_msg):
-        self.assertNotEqual(res.exit_code, 0)
-        self.assertIn(expected_msg, str(res.exception))
+        assert res.exit_code != 0
+        assert expected_msg in str(res.exception)

-    def test01_social_graph(self):
+    def test_social_graph(self):
         """Build the graph in 'example' and validate the created graph."""
         global person_count
         global country_count
@@ -92,17 +90,15 @@ def test01_social_graph(self):
         )

         # The script should report 27 overall node creations and 48 edge creations.
-        self.assertEqual(res.exit_code, 0)
-        self.assertIn("27 nodes created", res.output)
-        self.assertIn("48 relations created", res.output)
+        assert res.exit_code == 0
+        assert "27 nodes created" in res.output
+        assert "48 relations created" in res.output

         # Validate creation count by label/type
-        self.assertIn(person_count + " nodes created with label 'Person'", res.output)
-        self.assertIn(country_count + " nodes created with label 'Country'", res.output)
-        self.assertIn(knows_count + " relations created for type 'KNOWS'", res.output)
-        self.assertIn(
-            visited_count + " relations created for type 'VISITED'", res.output
-        )
+        assert person_count + " nodes created with label 'Person'" in res.output
+        assert country_count + " nodes created with label 'Country'" in res.output
+        assert knows_count + " relations created for type 'KNOWS'" in res.output
+        assert visited_count + " relations created for type 'VISITED'" in res.output

         # Open the constructed graph.
         graph = self.redis_con.graph("social")
@@ -126,7 +122,7 @@ def test01_social_graph(self):
             ["Tal Doron", 32, "male", "single"],
             ["Valerie Abigail Arad", 31, "female", "married"],
         ]
-        self.assertEqual(query_result.result_set, expected_result)
+        assert query_result.result_set == expected_result

         # Verify that the Country label exists, has the correct attributes, and is properly populated.
query_result = graph.query("MATCH (c:Country) RETURN c.name ORDER BY c.name") @@ -145,7 +141,7 @@ def test01_social_graph(self): ["Thailand"], ["USA"], ] - self.assertEqual(query_result.result_set, expected_result) + assert query_result.result_set == expected_result # Validate that the expected relations and properties have been constructed query_result = graph.query( @@ -166,8 +162,7 @@ def test01_social_graph(self): ["Alon Fital", "married", "Lucy Yanfital"], ["Ori Laslo", "married", "Shelly Laslo Rooz"], ] - self.assertEqual(query_result.result_set, expected_result) - + assert query_result.result_set == expected_result query_result = graph.query( "MATCH (a)-[e:VISITED]->(b) RETURN a.name, e.purpose, b.name ORDER BY e.purpose, a.name, b.name" ) @@ -209,9 +204,9 @@ def test01_social_graph(self): ["Valerie Abigail Arad", "pleasure", "Amsterdam"], ["Valerie Abigail Arad", "pleasure", "Russia"], ] - self.assertEqual(query_result.result_set, expected_result) + assert query_result.result_set == expected_result - def test02_private_identifiers(self): + def test_private_identifiers(self): """Validate that private identifiers are not added to the graph.""" graphname = "tmpgraph1" # Write temporary files @@ -241,18 +236,18 @@ def test02_private_identifiers(self): ) # The script should report 3 node creations and 2 edge creations - self.assertEqual(res.exit_code, 0) - self.assertIn("3 nodes created", res.output) - self.assertIn("2 relations created", res.output) + assert res.exit_code == 0 + assert "3 nodes created" in res.output + assert "2 relations created" in res.output tmp_graph = self.redis_con.graph(graphname) # The field "_identifier" should not be a property in the graph query_result = tmp_graph.query("MATCH (a) RETURN a") for propname in query_result.header: - self.assertNotIn("_identifier", propname) + assert "_identifier" not in propname - def test03_reused_identifier(self): + def test_reused_identifier(self): """Expect failure on reused identifiers.""" graphname = "tmpgraph2" # Write temporary files @@ -281,8 +276,8 @@ def test03_reused_identifier(self): ) # The script should fail because a node identifier is reused - self.assertNotEqual(res.exit_code, 0) - self.assertIn("used multiple times", res.output) + assert res.exit_code != 0 + assert "used multiple times" in res.output # Run the script again without creating relations runner = CliRunner() @@ -293,10 +288,10 @@ def test03_reused_identifier(self): ) # The script should succeed and create 3 nodes - self.assertEqual(res.exit_code, 0) - self.assertIn("3 nodes created", res.output) + assert res.exit_code == 0 + assert "3 nodes created" in res.output - def test04_batched_build(self): + def test_batched_build(self): """ Create a graph using many batches. Reuses the inputs of test01_social_graph @@ -333,17 +328,15 @@ def test04_batched_build(self): ) # The script should report 27 overall node creations and 48 edge creations. 
-        self.assertEqual(res.exit_code, 0)
-        self.assertIn("27 nodes created", res.output)
-        self.assertIn("48 relations created", res.output)
+        assert res.exit_code == 0
+        assert "27 nodes created" in res.output
+        assert "48 relations created" in res.output

         # Validate creation count by label/type
-        self.assertIn(person_count + " nodes created with label 'Person'", res.output)
-        self.assertIn(country_count + " nodes created with label 'Country'", res.output)
-        self.assertIn(knows_count + " relations created for type 'KNOWS'", res.output)
-        self.assertIn(
-            visited_count + " relations created for type 'VISITED'", res.output
-        )
+        assert person_count + " nodes created with label 'Person'" in res.output
+        assert country_count + " nodes created with label 'Country'" in res.output
+        assert knows_count + " relations created for type 'KNOWS'" in res.output
+        assert visited_count + " relations created for type 'VISITED'" in res.output

         original_graph = self.redis_con.graph("social")
         new_graph = self.redis_con.graph(graphname)
@@ -353,7 +346,7 @@
             "MATCH (p:Person) RETURN p, ID(p) ORDER BY p.name"
         )
         new_result = new_graph.query("MATCH (p:Person) RETURN p, ID(p) ORDER BY p.name")
-        self.assertEqual(original_result.result_set, new_result.result_set)
+        assert original_result.result_set == new_result.result_set

         original_result = original_graph.query(
             "MATCH (a)-[e:KNOWS]->(b) RETURN a.name, e, b.name ORDER BY e.relation, a.name"
@@ -361,9 +354,9 @@
         new_result = new_graph.query(
             "MATCH (a)-[e:KNOWS]->(b) RETURN a.name, e, b.name ORDER BY e.relation, a.name"
         )
-        self.assertEqual(original_result.result_set, new_result.result_set)
+        assert original_result.result_set == new_result.result_set

-    def test05_script_failures(self):
+    def test_script_failures(self):
         """Validate that the bulk loader fails gracefully on invalid inputs and arguments"""

         graphname = "tmpgraph3"
@@ -425,7 +418,7 @@ def test05_script_failures(self):
         # The script should fail because an invalid node identifier was used
         self.validate_exception(res, "fakeidentifier")

-    def test06_property_types(self):
+    def test_property_types(self):
         """Verify that numeric, boolean, and string types are properly handled"""

         graphname = "tmpgraph4"
@@ -456,9 +449,9 @@
             catch_exceptions=False,
         )

-        self.assertEqual(res.exit_code, 0)
-        self.assertIn("3 nodes created", res.output)
-        self.assertIn("3 relations created", res.output)
+        assert res.exit_code == 0
+        assert "3 nodes created" in res.output
+        assert "3 relations created" in res.output

         graph = self.redis_con.graph(graphname)
         query_result = graph.query(
@@ -471,9 +464,9 @@ def test06_property_types(self):
         ]
         # The graph should have the correct types for all properties
-        self.assertEqual(query_result.result_set, expected_result)
+        assert query_result.result_set == expected_result

-    def test07_utf8(self):
+    def test_utf8(self):
         """Verify that numeric, boolean, and null types are properly handled"""

         graphname = "tmpgraph5"
         # Write temporary files
@@ -516,9 +509,9 @@ def test07_utf8(self):
         ]

         for i, j in zip(query_result.result_set, expected_strs):
-            self.assertEqual(repr(i), repr(j))
+            assert repr(i) == repr(j)

-    def test08_nonstandard_separators(self):
+    def test_nonstandard_separators(self):
         """Validate use of non-comma delimiters in input files."""

         graphname = "tmpgraph6"
@@ -540,8 +533,8 @@ def test08_nonstandard_separators(self):
             catch_exceptions=False,
         )

-        self.assertEqual(res.exit_code, 0)
-        self.assertIn("2 nodes created",
res.output) + assert res.exit_code == 0 + assert "2 nodes created" in res.output graph = self.redis_con.graph(graphname) query_result = graph.query( @@ -550,9 +543,9 @@ def test08_nonstandard_separators(self): expected_result = [["val1", 5, True], [10.5, "a", False]] # The graph should have the correct types for all properties - self.assertEqual(query_result.result_set, expected_result) + assert query_result.result_set == expected_result - def test09_schema(self): + def test_schema(self): """Validate that the enforce-schema argument is respected""" graphname = "tmpgraph7" @@ -569,8 +562,8 @@ def test09_schema(self): catch_exceptions=False, ) - self.assertEqual(res.exit_code, 0) - self.assertIn("2 nodes created", res.output) + assert res.exit_code == 0 + assert "2 nodes created" in res.output graph = self.redis_con.graph(graphname) query_result = graph.query( @@ -579,9 +572,9 @@ def test09_schema(self): expected_result = [["0", 0, True], ["1", 1, False]] # The graph should have the correct types for all properties - self.assertEqual(query_result.result_set, expected_result) + assert query_result.result_set == expected_result - def test10_invalid_schema(self): + def test_invalid_schema(self): """Validate that errors are emitted properly with an invalid CSV schema.""" graphname = "expect_fail" @@ -602,7 +595,7 @@ def test10_invalid_schema(self): # Expect an error. self.validate_exception(res, "Could not parse") - def test11_schema_ignore_columns(self): + def test_schema_ignore_columns(self): """Validate that columns with the type IGNORE are not inserted.""" graphname = "ignore_graph" @@ -619,8 +612,8 @@ def test11_schema_ignore_columns(self): catch_exceptions=False, ) - self.assertEqual(res.exit_code, 0) - self.assertIn("2 nodes created", res.output) + assert res.exit_code == 0 + assert "2 nodes created" in res.output graph = self.redis_con.graph(graphname) query_result = graph.query("MATCH (a) RETURN a ORDER BY a.str_col") @@ -628,10 +621,10 @@ def test11_schema_ignore_columns(self): # The nodes should only have the 'str_col' property node_1 = {"str_col": "str1"} node_2 = {"str_col": "str2"} - self.assertEqual(query_result.result_set[0][0].properties, node_1) - self.assertEqual(query_result.result_set[1][0].properties, node_2) + assert query_result.result_set[0][0].properties == node_1 + assert query_result.result_set[1][0].properties == node_2 - def test12_no_null_values(self): + def test_no_null_values(self): """Validate that NULL inputs are not inserted.""" graphname = "null_graph" @@ -648,8 +641,8 @@ def test12_no_null_values(self): catch_exceptions=False, ) - self.assertEqual(res.exit_code, 0) - self.assertIn("2 nodes created", res.output) + assert res.exit_code == 0 + assert "2 nodes created" in res.output graph = self.redis_con.graph(graphname) query_result = graph.query("MATCH (a) RETURN a ORDER BY a.str_col") @@ -657,10 +650,10 @@ def test12_no_null_values(self): # Only the first node should only have the 'mixed_col' property node_1 = {"str_col": "str1", "mixed_col": True} node_2 = {"str_col": "str2"} - self.assertEqual(query_result.result_set[0][0].properties, node_1) - self.assertEqual(query_result.result_set[1][0].properties, node_2) + assert query_result.result_set[0][0].properties == node_1 + assert query_result.result_set[1][0].properties == node_2 - def test13_id_namespaces(self): + def test_id_namespaces(self): """Validate that ID namespaces allow for scoped identifiers.""" graphname = "namespace_graph" @@ -701,9 +694,9 @@ def test13_id_namespaces(self): 
catch_exceptions=False, ) - self.assertEqual(res.exit_code, 0) - self.assertIn("4 nodes created", res.output) - self.assertIn("2 relations created", res.output) + assert res.exit_code == 0 + assert "4 nodes created" in res.output + assert "2 relations created" in res.output graph = self.redis_con.graph(graphname) query_result = graph.query( @@ -714,9 +707,9 @@ def test13_id_namespaces(self): ["0", "Jeffrey", ["User"], "0", 20, ["Post"]], ["1", "Filipe", ["User"], "1", 40, ["Post"]], ] - self.assertEqual(query_result.result_set, expected_result) + assert query_result.result_set == expected_result - def test14_array_properties_inferred(self): + def test_array_properties_inferred(self): """Validate that array properties are correctly inserted.""" graphname = "arr_graph" @@ -733,18 +726,18 @@ def test14_array_properties_inferred(self): catch_exceptions=False, ) - self.assertEqual(res.exit_code, 0) - self.assertIn("2 nodes created", res.output) + assert res.exit_code == 0 + assert "2 nodes created" in res.output graph = self.redis_con.graph(graphname) query_result = graph.query("MATCH (a) RETURN a ORDER BY a.str_col") node_1 = {"str_col": "str1", "arr_col": [1, 0.2, "nested_str", False]} node_2 = {"str_col": "str2", "arr_col": ["prop1", ["nested_1", "nested_2"], 5]} - self.assertEqual(query_result.result_set[0][0].properties, node_1) - self.assertEqual(query_result.result_set[1][0].properties, node_2) + assert query_result.result_set[0][0].properties == node_1 + assert query_result.result_set[1][0].properties == node_2 - def test15_array_properties_schema_enforced(self): + def test_array_properties_schema_enforced(self): """Validate that array properties are correctly inserted with an enforced schema.""" graphname = "arr_graph_with_schema" @@ -768,18 +761,18 @@ def test15_array_properties_schema_enforced(self): catch_exceptions=False, ) - self.assertEqual(res.exit_code, 0) - self.assertIn("2 nodes created", res.output) + assert res.exit_code == 0 + assert "2 nodes created" in res.output graph = self.redis_con.graph(graphname) query_result = graph.query("MATCH (a) RETURN a ORDER BY a.str_col") node_1 = {"str_col": "str1", "arr_col": [1, 0.2, "nested_str", False]} node_2 = {"str_col": "str2", "arr_col": ["prop1", ["nested_1", "nested_2"], 5]} - self.assertEqual(query_result.result_set[0][0].properties, node_1) - self.assertEqual(query_result.result_set[1][0].properties, node_2) + assert query_result.result_set[0][0].properties == node_1 + assert query_result.result_set[1][0].properties == node_2 - def test16_error_on_schema_failure(self): + def test_error_on_schema_failure(self): """Validate that the loader errors on processing non-conformant CSVs with an enforced schema.""" graphname = "schema_error" @@ -806,10 +799,10 @@ def test16_error_on_schema_failure(self): self.fail() # Should be unreachable except Exception as e: # Verify that the correct exception is raised. 
- self.assertEqual(sys.exc_info()[0].__name__, "SchemaError") - self.assertIn("Could not parse 'strval' as an array", str(e)) + assert sys.exc_info()[0].__name__ == "SchemaError" + assert "Could not parse 'strval' as an array" in str(e) - def test17_ensure_index_is_created(self): + def test_ensure_index_is_created(self): graphname = "index_test" with open("/tmp/nodes_index.tmp", mode="w") as csv_file: out = csv.writer(csv_file, delimiter="|") @@ -835,17 +828,17 @@ def test17_ensure_index_is_created(self): catch_exceptions=False, ) - self.assertEqual(res.exit_code, 0) - self.assertIn("2 nodes created", res.output) - self.assertIn("Indices created: 1", res.output) + assert res.exit_code == 0 + assert "2 nodes created" in res.output + assert "Indices created: 1" in res.output r = redis.Redis(host="localhost", port=6379, decode_responses=True) res = r.execute_command( "GRAPH.EXPLAIN", graphname, "MATCH (p:Person) WHERE p.age > 16 RETURN p" ) - self.assertIn(" Node By Index Scan | (p:Person)", res) + assert " Node By Index Scan | (p:Person)" in res - def test18_ensure_full_text_index_is_created(self): + def test_ensure_full_text_index_is_created(self): graphname = "index_full_text_test" with open("/tmp/nodes_full_text_index.tmp", mode="w") as csv_file: out = csv.writer(csv_file, delimiter="|") @@ -871,9 +864,9 @@ def test18_ensure_full_text_index_is_created(self): catch_exceptions=False, ) - self.assertEqual(res.exit_code, 0) - self.assertIn("4 nodes created", res.output) - self.assertIn("Indices created: 1", res.output) + assert res.exit_code == 0 + assert "4 nodes created" in res.output + assert "Indices created: 1" in res.output graph = self.redis_con.graph(graphname) query_result = graph.query( @@ -886,9 +879,9 @@ def test18_ensure_full_text_index_is_created(self): ] # We should find only the tamarins - self.assertEqual(query_result.result_set, expected_result) + assert query_result.result_set == expected_result - def test19_integer_ids(self): + def test_integer_ids(self): """Validate that IDs can be persisted as integers.""" graphname = "id_integer_graph" @@ -931,9 +924,9 @@ def test19_integer_ids(self): catch_exceptions=False, ) - self.assertEqual(res.exit_code, 0) - self.assertIn("4 nodes created", res.output) - self.assertIn("2 relations created", res.output) + assert res.exit_code == 0 + assert "4 nodes created" in res.output + assert "2 relations created" in res.output graph = self.redis_con.graph(graphname) query_result = graph.query( @@ -945,8 +938,4 @@ def test19_integer_ids(self): [0, "Jeffrey", ["User"], 0, 20, ["Post"]], [1, "Filipe", ["User"], 1, 40, ["Post"]], ] - self.assertEqual(query_result.result_set, expected_result) - - -if __name__ == "__main__": - unittest.main() + assert query_result.result_set == expected_result diff --git a/test/test_bulk_update.py b/test/test_bulk_update.py index 117f7da..15ecc9d 100644 --- a/test/test_bulk_update.py +++ b/test/test_bulk_update.py @@ -11,22 +11,21 @@ from redisgraph_bulk_loader.bulk_update import bulk_update -class TestBulkUpdate(unittest.TestCase): +class TestBulkUpdate: + + redis_con = redis.Redis(decode_responses=True) + @classmethod - def setUpClass(cls): - """ - Instantiate a new Redis connection - """ - cls.redis_con = redis.Redis(host="localhost", port=6379, decode_responses=True) + def setup_class(cls): cls.redis_con.flushall() @classmethod - def tearDownClass(cls): + def teardown_class(cls): """Delete temporary files""" - os.remove("/tmp/csv.tmp") + os.unlink("/tmp/csv.tmp") cls.redis_con.flushall() - def 
test01_simple_updates(self): + def test_simple_updates(self): """Validate that bulk updates work on an empty graph.""" graphname = "tmpgraph1" # Write temporary files @@ -50,17 +49,17 @@ def test01_simple_updates(self): catch_exceptions=False, ) - self.assertEqual(res.exit_code, 0) - self.assertIn("Labels added: 1", res.output) - self.assertIn("Nodes created: 3", res.output) - self.assertIn("Properties set: 6", res.output) + assert res.exit_code == 0 + assert "Labels added: 1" in res.output + assert "Nodes created: 3" in res.output + assert "Properties set: 6" in res.output tmp_graph = self.redis_con.graph(graphname) query_result = tmp_graph.query("MATCH (a) RETURN a.id, a.name ORDER BY a.id") # Validate that the expected results are all present in the graph expected_result = [[0, "a"], [3, "c"], [5, "b"]] - self.assertEqual(query_result.result_set, expected_result) + assert query_result.result_set == expected_result # Attempt to re-insert the entities using MERGE. res = runner.invoke( @@ -76,12 +75,12 @@ def test01_simple_updates(self): ) # No new entities should be created. - self.assertEqual(res.exit_code, 0) - self.assertNotIn("Labels added", res.output) - self.assertNotIn("Nodes created", res.output) - self.assertNotIn("Properties set", res.output) + assert res.exit_code == 0 + assert "Labels added" not in res.output + assert "Nodes created" not in res.output + assert "Properties set" not in res.output - def test02_traversal_updates(self): + def test_traversal_updates(self): """Validate that bulk updates can create edges and perform traversals.""" graphname = "tmpgraph1" # Write temporary files @@ -107,10 +106,10 @@ def test02_traversal_updates(self): catch_exceptions=False, ) - self.assertEqual(res.exit_code, 0) - self.assertIn("Nodes created: 3", res.output) - self.assertIn("Relationships created: 3", res.output) - self.assertIn("Properties set: 6", res.output) + assert res.exit_code == 0 + assert "Nodes created: 3" in res.output + assert "Relationships created: 3" in res.output + assert "Properties set: 6" in res.output tmp_graph = self.redis_con.graph(graphname) query_result = tmp_graph.query( @@ -119,9 +118,9 @@ def test02_traversal_updates(self): # Validate that the expected results are all present in the graph expected_result = [["a", "a2"], ["b", "b2"], ["c", "c2"]] - self.assertEqual(query_result.result_set, expected_result) + assert query_result.result_set == expected_result - def test03_datatypes(self): + def test_datatypes(self): """Validate that all RedisGraph datatypes are supported by the bulk updater.""" graphname = "tmpgraph2" # Write temporary files @@ -143,9 +142,9 @@ def test03_datatypes(self): catch_exceptions=False, ) - self.assertEqual(res.exit_code, 0) - self.assertIn("Nodes created: 1", res.output) - self.assertIn("Properties set: 5", res.output) + assert res.exit_code == 0 + assert "Nodes created: 1" in res.output + assert "Properties set: 5" in res.output tmp_graph = self.redis_con.graph(graphname) query_result = tmp_graph.query( @@ -154,9 +153,9 @@ def test03_datatypes(self): # Validate that the expected results are all present in the graph expected_result = [[0, 1.5, True, "string", "[1,'nested_str']"]] - self.assertEqual(query_result.result_set, expected_result) + assert query_result.result_set == expected_result - def test04_custom_delimiter(self): + def test_custom_delimiter(self): """Validate that non-comma delimiters produce the correct results.""" graphname = "tmpgraph3" # Write temporary files @@ -182,17 +181,17 @@ def test04_custom_delimiter(self): 
catch_exceptions=False, ) - self.assertEqual(res.exit_code, 0) - self.assertIn("Labels added: 1", res.output) - self.assertIn("Nodes created: 3", res.output) - self.assertIn("Properties set: 6", res.output) + assert res.exit_code == 0 + assert "Labels added: 1" in res.output + assert "Nodes created: 3" in res.output + assert "Properties set: 6" in res.output tmp_graph = self.redis_con.graph(graphname) query_result = tmp_graph.query("MATCH (a) RETURN a.id, a.name ORDER BY a.id") # Validate that the expected results are all present in the graph expected_result = [[0, "a"], [3, "c"], [5, "b"]] - self.assertEqual(query_result.result_set, expected_result) + assert query_result.result_set == expected_result # Attempt to re-insert the entities using MERGE. res = runner.invoke( @@ -210,12 +209,12 @@ def test04_custom_delimiter(self): ) # No new entities should be created. - self.assertEqual(res.exit_code, 0) - self.assertNotIn("Labels added", res.output) - self.assertNotIn("Nodes created", res.output) - self.assertNotIn("Properties set", res.output) + assert res.exit_code == 0 + assert "Labels added" not in res.output + assert "Nodes created" not in res.output + assert "Properties set" not in res.output - def test05_custom_variable_name(self): + def test_custom_variable_name(self): """Validate that the user can specify the name of the 'row' query variable.""" graphname = "variable_name" runner = CliRunner() @@ -239,10 +238,10 @@ def test05_custom_variable_name(self): catch_exceptions=False, ) - self.assertEqual(res.exit_code, 0) - self.assertIn("Labels added: 1", res.output) - self.assertIn("Nodes created: 14", res.output) - self.assertIn("Properties set: 56", res.output) + assert res.exit_code == 0 + assert "Labels added: 1" in res.output + assert "Nodes created: 14" in res.output + assert "Properties set: 56" in res.output tmp_graph = self.redis_con.graph(graphname) @@ -266,9 +265,9 @@ def test05_custom_variable_name(self): ["Tal Doron", 32, "male", "single"], ["Valerie Abigail Arad", 31, "female", "married"], ] - self.assertEqual(query_result.result_set, expected_result) + assert query_result.result_set == expected_result - def test06_no_header(self): + def test_no_header(self): """Validate that the '--no-header' option works properly.""" graphname = "tmpgraph4" # Write temporary files @@ -292,19 +291,19 @@ def test06_no_header(self): catch_exceptions=False, ) - self.assertEqual(res.exit_code, 0) - self.assertIn("Labels added: 1", res.output) - self.assertIn("Nodes created: 3", res.output) - self.assertIn("Properties set: 6", res.output) + assert res.exit_code == 0 + assert "Labels added: 1" in res.output + assert "Nodes created: 3" in res.output + assert "Properties set: 6" in res.output tmp_graph = self.redis_con.graph(graphname) query_result = tmp_graph.query("MATCH (a) RETURN a.id, a.name ORDER BY a.id") # Validate that the expected results are all present in the graph expected_result = [[0, "a"], [3, "c"], [5, "b"]] - self.assertEqual(query_result.result_set, expected_result) + assert query_result.result_set == expected_result - def test07_batched_update(self): + def test_batched_update(self): """Validate that updates performed over multiple batches produce the correct results.""" graphname = "batched_update" @@ -331,19 +330,19 @@ def test07_batched_update(self): catch_exceptions=False, ) - self.assertEqual(res.exit_code, 0) - self.assertIn("Labels added: 1", res.output) - self.assertIn("Nodes created: 100000", res.output) - self.assertIn("Properties set: 100000", res.output) + assert 
res.exit_code == 0 + assert "Labels added: 1" in res.output + assert "Nodes created: 100000" in res.output + assert "Properties set: 100000" in res.output tmp_graph = self.redis_con.graph(graphname) query_result = tmp_graph.query("MATCH (a) RETURN DISTINCT a.prop") # Validate that the expected results are all present in the graph expected_result = [[prop_str]] - self.assertEqual(query_result.result_set, expected_result) + assert query_result.result_set == expected_result - def test08_runtime_error(self): + def test_runtime_error(self): """Validate that run-time errors are captured by the bulk updater.""" graphname = "tmpgraph5" @@ -364,10 +363,10 @@ def test08_runtime_error(self): ], ) - self.assertNotEqual(res.exit_code, 0) - self.assertIn("Cannot merge node", str(res.exception)) + assert res.exit_code != 0 + assert "Cannot merge node" in str(res.exception) - def test09_compile_time_error(self): + def test_compile_time_error(self): """Validate that malformed queries trigger an early exit from the bulk updater.""" graphname = "tmpgraph5" runner = CliRunner() @@ -383,10 +382,10 @@ def test09_compile_time_error(self): ], ) - self.assertNotEqual(res.exit_code, 0) - self.assertIn("undefined_identifier not defined", str(res.exception)) + assert res.exit_code != 0 + assert "undefined_identifier not defined" in str(res.exception) - def test10_invalid_inputs(self): + def test_invalid_inputs(self): """Validate that the bulk updater handles invalid inputs incorrectly.""" graphname = "tmpgraph6" @@ -403,5 +402,5 @@ def test10_invalid_inputs(self): ], ) - self.assertNotEqual(res.exit_code, 0) - self.assertIn("No such file", str(res.exception)) + assert res.exit_code != 0 + assert "No such file" in str(res.exception) diff --git a/test/test_config.py b/test/test_config.py index 5ded53f..5868af1 100644 --- a/test/test_config.py +++ b/test/test_config.py @@ -3,22 +3,22 @@ from redisgraph_bulk_loader.config import Config -class TestBulkLoader(unittest.TestCase): - def test01_default_values(self): +class TestBulkLoader: + def test_default_values(self): """Verify the default values in the Config class.""" config = Config() - self.assertEqual(config.max_token_count, 1024 * 1023) - self.assertEqual(config.max_buffer_size, 64_000_000) - self.assertEqual(config.max_token_size, 64_000_000) - self.assertEqual(config.enforce_schema, False) - self.assertEqual(config.id_type, "STRING") - self.assertEqual(config.skip_invalid_nodes, False) - self.assertEqual(config.skip_invalid_edges, False) - self.assertEqual(config.store_node_identifiers, False) - self.assertEqual(config.separator, ",") - self.assertEqual(config.quoting, 3) + assert config.max_token_count == 1024 * 1023 + assert config.max_buffer_size == 64_000_000 + assert config.max_token_size == 64_000_000 + assert config.enforce_schema == False + assert config.id_type == "STRING" + assert not config.skip_invalid_nodes + assert not config.skip_invalid_edges + assert not config.store_node_identifiers + assert config.separator == "," + assert config.quoting == 3 - def test02_modified_values(self): + def test_modified_values(self): """Verify that Config_set updates Config class values accordingly.""" config = Config( max_token_count=10, @@ -31,17 +31,15 @@ def test02_modified_values(self): separator="|", quoting=0, ) - self.assertEqual(config.max_token_count, 10) - self.assertEqual( - config.max_token_size, 200_000_000 - ) # Max token size argument is converted to megabytes - self.assertEqual( - config.max_buffer_size, 500_000_000 - ) # Buffer size argument is 
converted to megabytes - self.assertEqual(config.enforce_schema, True) - self.assertEqual(config.id_type, "INTEGER") - self.assertEqual(config.skip_invalid_nodes, True) - self.assertEqual(config.skip_invalid_edges, True) - self.assertEqual(config.store_node_identifiers, False) - self.assertEqual(config.separator, "|") - self.assertEqual(config.quoting, 0) + assert config.max_token_count == 10 + assert config.max_token_size == 200_000_000 + # Max token size argument is converted to megabytes + assert config.max_buffer_size == 500_000_000 + # Buffer size argument is converted to megabytes + assert config.enforce_schema + assert config.id_type == "INTEGER" + assert config.skip_invalid_nodes + assert config.skip_invalid_edges + assert not config.store_node_identifiers + assert config.separator == "|" + assert config.quoting == 0 diff --git a/test/test_label.py b/test/test_label.py index 70592bc..cb15fce 100644 --- a/test/test_label.py +++ b/test/test_label.py @@ -6,13 +6,13 @@ from redisgraph_bulk_loader.label import Label -class TestBulkLoader(unittest.TestCase): +class TestBulkLoader: @classmethod - def tearDownClass(cls): + def teardown_class(cls): """Delete temporary files""" os.remove("/tmp/labels.tmp") - def test01_process_schemaless_header(self): + def test_process_schemaless_header(self): """Verify that a schema-less header is parsed properly.""" with open("/tmp/labels.tmp", mode="w") as csv_file: out = csv.writer(csv_file) @@ -24,14 +24,14 @@ def test01_process_schemaless_header(self): label = Label(None, "/tmp/labels.tmp", "LabelTest", config) # The '_ID' column will not be stored, as the underscore indicates a private identifier. - self.assertEqual(label.column_names, [None, "prop"]) - self.assertEqual(label.column_count, 2) - self.assertEqual(label.id, 0) - self.assertEqual(label.entity_str, "LabelTest") - self.assertEqual(label.prop_count, 1) - self.assertEqual(label.entities_count, 2) - - def test02_process_header_with_schema(self): + assert label.column_names == [None, "prop"] + assert label.column_count == 2 + assert label.id == 0 + assert label.entity_str == "LabelTest" + assert label.prop_count == 1 + assert label.entities_count == 2 + + def test_process_header_with_schema(self): """Verify that a header with a schema is parsed properly.""" with open("/tmp/labels.tmp", mode="w") as csv_file: out = csv.writer(csv_file) @@ -41,11 +41,11 @@ def test02_process_header_with_schema(self): config = Config(enforce_schema=True, store_node_identifiers=True) label = Label(None, "/tmp/labels.tmp", "LabelTest", config) - self.assertEqual(label.column_names, ["id", "property"]) - self.assertEqual(label.column_count, 2) - self.assertEqual(label.id_namespace, "IDNamespace") - self.assertEqual(label.entity_str, "LabelTest") - self.assertEqual(label.prop_count, 2) - self.assertEqual(label.entities_count, 2) - self.assertEqual(label.types[0].name, "ID_STRING") - self.assertEqual(label.types[1].name, "STRING") + assert label.column_names == ["id", "property"] + assert label.column_count == 2 + assert label.id_namespace == "IDNamespace" + assert label.entity_str == "LabelTest" + assert label.prop_count == 2 + assert label.entities_count == 2 + assert label.types[0].name == "ID_STRING" + assert label.types[1].name == "STRING" diff --git a/test/test_relation_type.py b/test/test_relation_type.py index bbc44ae..95098a3 100644 --- a/test/test_relation_type.py +++ b/test/test_relation_type.py @@ -6,13 +6,13 @@ from redisgraph_bulk_loader.relation_type import RelationType -class 
TestBulkLoader(unittest.TestCase): +class TestBulkLoader: @classmethod - def tearDownClass(cls): + def teardown_class(cls): """Delete temporary files""" os.remove("/tmp/relations.tmp") - def test01_process_schemaless_header(self): + def test_process_schemaless_header(self): """Verify that a schema-less header is parsed properly.""" with open("/tmp/relations.tmp", mode="w") as csv_file: out = csv.writer(csv_file) @@ -22,13 +22,13 @@ def test01_process_schemaless_header(self): config = Config() reltype = RelationType(None, "/tmp/relations.tmp", "RelationTest", config) - self.assertEqual(reltype.start_id, 0) - self.assertEqual(reltype.end_id, 1) - self.assertEqual(reltype.entity_str, "RelationTest") - self.assertEqual(reltype.prop_count, 1) - self.assertEqual(reltype.entities_count, 2) + assert reltype.start_id == 0 + assert reltype.end_id == 1 + assert reltype.entity_str == "RelationTest" + assert reltype.prop_count == 1 + assert reltype.entities_count == 2 - def test02_process_header_with_schema(self): + def test_process_header_with_schema(self): """Verify that a header with a schema is parsed properly.""" with open("/tmp/relations.tmp", mode="w") as csv_file: out = csv.writer(csv_file) @@ -44,13 +44,13 @@ def test02_process_header_with_schema(self): config = Config(enforce_schema=True) reltype = RelationType(None, "/tmp/relations.tmp", "RelationTest", config) - self.assertEqual(reltype.start_id, 1) - self.assertEqual(reltype.start_namespace, "StartNamespace") - self.assertEqual(reltype.end_id, 0) - self.assertEqual(reltype.end_namespace, "EndNamespace") - self.assertEqual(reltype.entity_str, "RelationTest") - self.assertEqual(reltype.prop_count, 1) - self.assertEqual(reltype.entities_count, 2) - self.assertEqual(reltype.types[0].name, "END_ID") - self.assertEqual(reltype.types[1].name, "START_ID") - self.assertEqual(reltype.types[2].name, "STRING") + assert reltype.start_id == 1 + assert reltype.start_namespace == "StartNamespace" + assert reltype.end_id == 0 + assert reltype.end_namespace == "EndNamespace" + assert reltype.entity_str == "RelationTest" + assert reltype.prop_count == 1 + assert reltype.entities_count == 2 + assert reltype.types[0].name == "END_ID" + assert reltype.types[1].name == "START_ID" + assert reltype.types[2].name == "STRING"