Skip to content

Commit 57472bb

Browse files
committed
Improve packstream Structure class
These changes have no affect on the driver's public API. They're targeted at improving the development and debugging experience. * Adjust `repr` to follow Python's recommendations * Fix `__eq__` returning `NotImplementedError` instead of `NotImplemented` * Add type hints * Add tests
1 parent 9816ca0 commit 57472bb

File tree

2 files changed

+152
-8
lines changed

2 files changed

+152
-8
lines changed

src/neo4j/_codec/packstream/_python/_common.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,23 +14,28 @@
1414
# limitations under the License.
1515

1616

17+
from .... import _typing as t
18+
19+
1720
class Structure:
18-
def __init__(self, tag, *fields):
21+
tag: bytes
22+
fields: list[t.Any]
23+
24+
def __init__(self, tag: bytes, *fields: t.Any):
1925
self.tag = tag
2026
self.fields = list(fields)
2127

22-
def __repr__(self):
23-
fields = ", ".join(map(repr, self.fields))
24-
tag_int = ord(self.tag)
25-
return f"Structure[0x{tag_int:02X}]({fields})"
28+
def __repr__(self) -> str:
29+
args = ", ".join(map(repr, (self.tag, *self.fields)))
30+
return f"Structure({args})"
2631

27-
def __eq__(self, other):
32+
def __eq__(self, other) -> bool:
2833
try:
2934
return self.tag == other.tag and self.fields == other.fields
3035
except AttributeError:
31-
return NotImplementedError
36+
return NotImplemented
3237

33-
def __len__(self):
38+
def __len__(self) -> int:
3439
return len(self.fields)
3540

3641
def __getitem__(self, key):
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [https://neo4j.com]
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
17+
import pytest
18+
19+
from neo4j._codec.packstream import Structure
20+
21+
22+
@pytest.mark.parametrize(
23+
"args",
24+
(
25+
(b"T", 1, 2, 3, "abc", 1.2, None, False),
26+
(b"F",),
27+
),
28+
)
29+
def test_structure_accessors(args):
30+
tag = args[0]
31+
fields = list(args[1:])
32+
s1 = Structure(*args)
33+
assert s1.tag == tag
34+
assert s1.fields == fields
35+
36+
37+
@pytest.mark.parametrize(
38+
("other", "expected"),
39+
(
40+
(Structure(b"T", 1, 2, 3, "abc", 1.2, [{"a": "b"}, None]), True),
41+
(Structure(b"T", 1, 2, 3, "abc", 1.2, [{"a": "b"}, 0]), False),
42+
(Structure(b"T", 1, 2, 3, "abc", 1.2, [{"a": "B"}, None]), False),
43+
(Structure(b"T", 1, 2, 3, "abc", 1.2, [{"A": "b"}, None]), False),
44+
(Structure(b"T", 1, 2, 3, "abc", 1.3, [{"a": "b"}, None]), False),
45+
(
46+
Structure(b"T", 1, 2, 3, "aBc", float("Nan"), [{"a": "b"}, None]),
47+
False,
48+
),
49+
(Structure(b"T", 2, 2, 3, "abc", 1.2, [{"a": "b"}, None]), False),
50+
(Structure(b"T", 2, 3, "abc", 1.2, [{"a": "b"}, None]), False),
51+
(Structure(b"T", [1, 2, 3, "abc", 1.2, [{"a": "b"}, None]]), False),
52+
(object(), NotImplemented),
53+
),
54+
)
55+
def test_structure_equality(other, expected):
56+
s1 = Structure(b"T", 1, 2, 3, "abc", 1.2, [{"a": "b"}, None])
57+
assert s1.__eq__(other) is expected # noqa: PLC2801
58+
if expected is NotImplemented:
59+
assert s1.__ne__(other) is NotImplemented # noqa: PLC2801
60+
else:
61+
assert s1.__ne__(other) is not expected # noqa: PLC2801
62+
63+
64+
@pytest.mark.parametrize(
65+
("args", "expected"),
66+
(
67+
((b"F", 1, 2), "Structure(b'F', 1, 2)"),
68+
((b"f", [1, 2]), "Structure(b'f', [1, 2])"),
69+
(
70+
(b"T", 1.3, None, {"a": "b"}),
71+
"Structure(b'T', 1.3, None, {'a': 'b'})",
72+
),
73+
),
74+
)
75+
def test_structure_repr(args, expected):
76+
s1 = Structure(*args)
77+
assert repr(s1) == expected
78+
assert str(s1) == expected
79+
80+
# Ensure that the repr is consistent with the constructor
81+
assert eval(repr(s1)) == s1
82+
assert eval(str(s1)) == s1
83+
84+
85+
@pytest.mark.parametrize(
86+
("fields", "expected"),
87+
(
88+
((), 0),
89+
(([],), 1),
90+
((1, 2), 2),
91+
((1, 2, []), 3),
92+
(([1, 2], {"a": "foo", "b": "bar"}), 2),
93+
),
94+
)
95+
def test_structure_len(fields, expected):
96+
structure = Structure(b"F", *fields)
97+
assert len(structure) == expected
98+
99+
100+
def test_structure_getitem():
101+
fields = [1, 2, 3, "abc", 1.2, None, False, {"a": "b"}]
102+
structure = Structure(b"F", *fields)
103+
for i, field in enumerate(fields):
104+
assert structure[i] == field
105+
assert structure[-len(fields) + i] == field
106+
with pytest.raises(IndexError):
107+
_ = structure[len(fields)]
108+
with pytest.raises(IndexError):
109+
_ = structure[-len(fields) - 1]
110+
111+
112+
def test_structure_setitem():
113+
test_value = object()
114+
fields = [1, 2, 3, "abc", 1.2, None, False, {"a": "b"}]
115+
structure = Structure(b"F", *fields)
116+
for i, original_value in enumerate(fields):
117+
structure[i] = test_value
118+
assert structure[i] == test_value
119+
assert structure[-len(fields) + i] == test_value
120+
assert structure[i] != original_value
121+
assert structure[-len(fields) + i] != original_value
122+
123+
structure[i] = original_value
124+
assert structure[i] == original_value
125+
assert structure[-len(fields) + i] == original_value
126+
127+
structure[-len(fields) + i] = test_value
128+
assert structure[i] == test_value
129+
assert structure[-len(fields) + i] == test_value
130+
assert structure[i] != original_value
131+
assert structure[-len(fields) + i] != original_value
132+
133+
structure[-len(fields) + i] = original_value
134+
assert structure[i] == original_value
135+
assert structure[-len(fields) + i] == original_value
136+
with pytest.raises(IndexError):
137+
structure[len(fields)] = test_value
138+
with pytest.raises(IndexError):
139+
structure[-len(fields) - 1] = test_value

0 commit comments

Comments
 (0)