Skip to content

Commit da96028

Browse files
authored
Extra CubeAttrsDict methods to emulate dictionary behaviours. (#5592)
* Extra CubeAttrsDict methods to emulate dictionary behaviours. * Don't use staticmethod on fixture.
1 parent 40f82b5 commit da96028

File tree

2 files changed

+159
-10
lines changed

2 files changed

+159
-10
lines changed

lib/iris/cube.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -972,7 +972,11 @@ def __eq__(self, other):
972972
return result
973973

974974
#
975-
# Provide a copy method, as for 'dict', but *not* provided by MutableMapping
975+
# Provide methods duplicating those for a 'dict', but which are *not* provided by
976+
# MutableMapping, for compatibility with code which expected a cube.attributes to be
977+
# a :class:`~iris.common.mixin.LimitedAttributeDict`.
978+
# The extra required methods are :
979+
# 'copy', 'update', '__ior__', '__or__', '__ror__' and 'fromkeys'.
976980
#
977981
def copy(self):
978982
"""
@@ -983,6 +987,66 @@ def copy(self):
983987
"""
984988
return CubeAttrsDict(self)
985989

990+
def update(self, *args, **kwargs):
991+
"""
992+
Update by adding items from a mapping arg, or keyword-values.
993+
994+
If the argument is a split dictionary, preserve the local/global nature of its
995+
keys.
996+
"""
997+
if args and hasattr(args[0], "globals") and hasattr(args[0], "locals"):
998+
dic = args[0]
999+
self.globals.update(dic.globals)
1000+
self.locals.update(dic.locals)
1001+
else:
1002+
super().update(*args)
1003+
super().update(**kwargs)
1004+
1005+
def __or__(self, arg):
1006+
"""Implement 'or' via 'update'."""
1007+
if not isinstance(arg, Mapping):
1008+
return NotImplemented
1009+
new_dict = self.copy()
1010+
new_dict.update(arg)
1011+
return new_dict
1012+
1013+
def __ior__(self, arg):
1014+
"""Implement 'ior' via 'update'."""
1015+
self.update(arg)
1016+
return self
1017+
1018+
def __ror__(self, arg):
1019+
"""
1020+
Implement 'ror' via 'update'.
1021+
1022+
This needs to promote, such that the result is a CubeAttrsDict.
1023+
"""
1024+
if not isinstance(arg, Mapping):
1025+
return NotImplemented
1026+
result = CubeAttrsDict(arg)
1027+
result.update(self)
1028+
return result
1029+
1030+
@classmethod
1031+
def fromkeys(cls, iterable, value=None):
1032+
"""
1033+
Create a new object with keys taken from an argument, all set to one value.
1034+
1035+
If the argument is a split dictionary, preserve the local/global nature of its
1036+
keys.
1037+
"""
1038+
if hasattr(iterable, "globals") and hasattr(iterable, "locals"):
1039+
# When main input is a split-attrs dict, create global/local parts from its
1040+
# global/local keys
1041+
result = cls(
1042+
globals=dict.fromkeys(iterable.globals, value),
1043+
locals=dict.fromkeys(iterable.locals, value),
1044+
)
1045+
else:
1046+
# Create from a dict.fromkeys, using default classification of the keys.
1047+
result = cls(dict.fromkeys(iterable, value))
1048+
return result
1049+
9861050
#
9871051
# The remaining methods are sufficient to generate a complete standard Mapping
9881052
# API. See -

lib/iris/tests/unit/cube/test_CubeAttrsDict.py

Lines changed: 94 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,17 @@ def sample_attrs() -> CubeAttrsDict:
2323

2424

2525
def check_content(attrs, locals=None, globals=None, matches=None):
26-
# Check a CubeAttrsDict for expected properties.
27-
# If locals/globals are set, test for equality and non-identity.
26+
"""
27+
Check a CubeAttrsDict for expected properties.
28+
29+
Its ".globals" and ".locals" must match 'locals' and 'globals' args
30+
-- except that, if 'matches' is provided, it is a CubeAttrsDict, whose
31+
locals/globals *replace* the 'locals'/'globals' arguments.
32+
33+
Check that the result is a CubeAttrsDict and, for both local + global parts,
34+
* parts match for *equality* (==) but are *non-identical* (is not)
35+
* order of keys matches expected (N.B. which is *not* required for equality)
36+
"""
2837
assert isinstance(attrs, CubeAttrsDict)
2938
attr_locals, attr_globals = attrs.locals, attrs.globals
3039
assert type(attr_locals) == LimitedAttributeDict
@@ -41,6 +50,7 @@ def check(arg, content):
4150
# .. we proceed to ensure that the stored content is equal but NOT the same
4251
assert content == arg
4352
assert content is not arg
53+
assert list(content.keys()) == list(arg.keys())
4454

4555
check(locals, attr_locals)
4656
check(globals, attr_globals)
@@ -98,14 +108,89 @@ def test_copy(self, sample_attrs):
98108
assert copy is not sample_attrs
99109
check_content(copy, matches=sample_attrs)
100110

101-
def test_update(self, sample_attrs):
102-
updated = sample_attrs.copy()
103-
updated.update({"q": 77})
104-
expected_locals = sample_attrs.locals.copy()
105-
expected_locals["q"] = 77
106-
check_content(
107-
updated, globals=sample_attrs.globals, locals=expected_locals
111+
@pytest.fixture(params=["regular_arg", "split_arg"])
112+
def update_testcase(self, request):
113+
lhs = CubeAttrsDict(globals={"a": 1, "b": 2}, locals={"b": 3, "c": 4})
114+
if request.param == "split_arg":
115+
# A set of "update settings", with global/local-specific keys.
116+
rhs = CubeAttrsDict(
117+
globals={"a": 1001, "x": 1007},
118+
# NOTE: use a global-default key here, to check that type is preserved
119+
locals={"b": 1003, "history": 1099},
120+
)
121+
expected_result = CubeAttrsDict(
122+
globals={"a": 1001, "b": 2, "x": 1007},
123+
locals={"b": 1003, "c": 4, "history": 1099},
124+
)
125+
else:
126+
assert request.param == "regular_arg"
127+
# A similar set of update values in a regular dict (so not local/global)
128+
rhs = {"a": 1001, "x": 1007, "b": 1003, "history": 1099}
129+
expected_result = CubeAttrsDict(
130+
globals={"a": 1001, "b": 2, "history": 1099},
131+
locals={"b": 1003, "c": 4, "x": 1007},
132+
)
133+
return lhs, rhs, expected_result
134+
135+
def test_update(self, update_testcase):
136+
testval, updater, expected = update_testcase
137+
testval.update(updater)
138+
check_content(testval, matches=expected)
139+
140+
def test___or__(self, update_testcase):
141+
testval, updater, expected = update_testcase
142+
original = testval.copy()
143+
result = testval | updater
144+
assert result is not testval
145+
assert testval == original
146+
check_content(result, matches=expected)
147+
148+
def test___ior__(self, update_testcase):
149+
testval, updater, expected = update_testcase
150+
testval |= updater
151+
check_content(testval, matches=expected)
152+
153+
def test___ror__(self):
154+
# Check the "or" operation, when lhs is a regular dictionary
155+
lhs = {"a": 1, "b": 2, "history": 3}
156+
rhs = CubeAttrsDict(
157+
globals={"a": 1001, "x": 1007},
158+
# NOTE: use a global-default key here, to check that type is preserved
159+
locals={"b": 1003, "history": 1099},
108160
)
161+
# The lhs should be promoted to a CubeAttrsDict, and then combined.
162+
expected = CubeAttrsDict(
163+
globals={"history": 3, "a": 1001, "x": 1007},
164+
locals={"a": 1, "b": 1003, "history": 1099},
165+
)
166+
result = lhs | rhs
167+
check_content(result, matches=expected)
168+
169+
@pytest.mark.parametrize("value", [1, None])
170+
@pytest.mark.parametrize("inputtype", ["regular_arg", "split_arg"])
171+
def test__fromkeys(self, value, inputtype):
172+
if inputtype == "regular_arg":
173+
# Check when input is a plain iterable of key-names
174+
keys = ["a", "b", "history"]
175+
# Result has keys assigned local/global via default mechanism.
176+
expected = CubeAttrsDict(
177+
globals={"history": value},
178+
locals={"a": value, "b": value},
179+
)
180+
else:
181+
assert inputtype == "split_arg"
182+
# Check when input is a CubeAttrsDict
183+
keys = CubeAttrsDict(
184+
globals={"a": 1}, locals={"b": 2, "history": 3}
185+
)
186+
# The result preserves the input keys' local/global identity
187+
# N.B. "history" would be global by default (cf. "regular_arg" case)
188+
expected = CubeAttrsDict(
189+
globals={"a": value},
190+
locals={"b": value, "history": value},
191+
)
192+
result = CubeAttrsDict.fromkeys(keys, value)
193+
check_content(result, matches=expected)
109194

110195
def test_to_dict(self, sample_attrs):
111196
result = dict(sample_attrs)

0 commit comments

Comments
 (0)