Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 65 additions & 1 deletion lib/iris/cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,7 +972,11 @@ def __eq__(self, other):
return result

#
# Provide a copy method, as for 'dict', but *not* provided by MutableMapping
# Provide methods duplicating those for a 'dict', but which are *not* provided by
# MutableMapping, for compatibility with code which expected a cube.attributes to be
# a :class:`~iris.common.mixin.LimitedAttributeDict`.
# The extra required methods are :
# 'copy', 'update', '__ior__', '__or__', '__ror__' and 'fromkeys'.
#
def copy(self):
"""
Expand All @@ -983,6 +987,66 @@ def copy(self):
"""
return CubeAttrsDict(self)

def update(self, *args, **kwargs):
"""
Update by adding items from a mapping arg, or keyword-values.

If the argument is a split dictionary, preserve the local/global nature of its
keys.
"""
if args and hasattr(args[0], "globals") and hasattr(args[0], "locals"):
dic = args[0]
self.globals.update(dic.globals)
self.locals.update(dic.locals)
else:
super().update(*args)
super().update(**kwargs)

def __or__(self, arg):
"""Implement 'or' via 'update'."""
if not isinstance(arg, Mapping):
return NotImplemented
new_dict = self.copy()
new_dict.update(arg)
return new_dict

def __ior__(self, arg):
"""Implement 'ior' via 'update'."""
self.update(arg)
return self

def __ror__(self, arg):
"""
Implement 'ror' via 'update'.

This needs to promote, such that the result is a CubeAttrsDict.
"""
if not isinstance(arg, Mapping):
return NotImplemented
result = CubeAttrsDict(arg)
result.update(self)
return result

@classmethod
def fromkeys(cls, iterable, value=None):
"""
Create a new object with keys taken from an argument, all set to one value.

If the argument is a split dictionary, preserve the local/global nature of its
keys.
"""
if hasattr(iterable, "globals") and hasattr(iterable, "locals"):
# When main input is a split-attrs dict, create global/local parts from its
# global/local keys
result = cls(
globals=dict.fromkeys(iterable.globals, value),
locals=dict.fromkeys(iterable.locals, value),
)
else:
# Create from a dict.fromkeys, using default classification of the keys.
result = cls(dict.fromkeys(iterable, value))
return result

#
# The remaining methods are sufficient to generate a complete standard Mapping
# API. See -
Expand Down
103 changes: 94 additions & 9 deletions lib/iris/tests/unit/cube/test_CubeAttrsDict.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,17 @@ def sample_attrs() -> CubeAttrsDict:


def check_content(attrs, locals=None, globals=None, matches=None):
# Check a CubeAttrsDict for expected properties.
# If locals/globals are set, test for equality and non-identity.
"""
Check a CubeAttrsDict for expected properties.

Its ".globals" and ".locals" must match 'locals' and 'globals' args
-- except that, if 'matches' is provided, it is a CubeAttrsDict, whose
locals/globals *replace* the 'locals'/'globals' arguments.

Check that the result is a CubeAttrsDict and, for both local + global parts,
* parts match for *equality* (==) but are *non-identical* (is not)
* order of keys matches expected (N.B. which is *not* required for equality)
"""
assert isinstance(attrs, CubeAttrsDict)
attr_locals, attr_globals = attrs.locals, attrs.globals
assert type(attr_locals) == LimitedAttributeDict
Expand All @@ -41,6 +50,7 @@ def check(arg, content):
# .. we proceed to ensure that the stored content is equal but NOT the same
assert content == arg
assert content is not arg
assert list(content.keys()) == list(arg.keys())

check(locals, attr_locals)
check(globals, attr_globals)
Expand Down Expand Up @@ -98,14 +108,89 @@ def test_copy(self, sample_attrs):
assert copy is not sample_attrs
check_content(copy, matches=sample_attrs)

def test_update(self, sample_attrs):
updated = sample_attrs.copy()
updated.update({"q": 77})
expected_locals = sample_attrs.locals.copy()
expected_locals["q"] = 77
check_content(
updated, globals=sample_attrs.globals, locals=expected_locals
@pytest.fixture(params=["regular_arg", "split_arg"])
def update_testcase(self, request):
lhs = CubeAttrsDict(globals={"a": 1, "b": 2}, locals={"b": 3, "c": 4})
if request.param == "split_arg":
# A set of "update settings", with global/local-specific keys.
rhs = CubeAttrsDict(
globals={"a": 1001, "x": 1007},
# NOTE: use a global-default key here, to check that type is preserved
locals={"b": 1003, "history": 1099},
)
expected_result = CubeAttrsDict(
globals={"a": 1001, "b": 2, "x": 1007},
locals={"b": 1003, "c": 4, "history": 1099},
)
else:
assert request.param == "regular_arg"
# A similar set of update values in a regular dict (so not local/global)
rhs = {"a": 1001, "x": 1007, "b": 1003, "history": 1099}
expected_result = CubeAttrsDict(
globals={"a": 1001, "b": 2, "history": 1099},
locals={"b": 1003, "c": 4, "x": 1007},
)
return lhs, rhs, expected_result

def test_update(self, update_testcase):
testval, updater, expected = update_testcase
testval.update(updater)
check_content(testval, matches=expected)

def test___or__(self, update_testcase):
testval, updater, expected = update_testcase
original = testval.copy()
result = testval | updater
assert result is not testval
assert testval == original
check_content(result, matches=expected)

def test___ior__(self, update_testcase):
testval, updater, expected = update_testcase
testval |= updater
check_content(testval, matches=expected)

def test___ror__(self):
# Check the "or" operation, when lhs is a regular dictionary
lhs = {"a": 1, "b": 2, "history": 3}
rhs = CubeAttrsDict(
globals={"a": 1001, "x": 1007},
# NOTE: use a global-default key here, to check that type is preserved
locals={"b": 1003, "history": 1099},
)
# The lhs should be promoted to a CubeAttrsDict, and then combined.
expected = CubeAttrsDict(
globals={"history": 3, "a": 1001, "x": 1007},
locals={"a": 1, "b": 1003, "history": 1099},
)
result = lhs | rhs
check_content(result, matches=expected)

@pytest.mark.parametrize("value", [1, None])
@pytest.mark.parametrize("inputtype", ["regular_arg", "split_arg"])
def test__fromkeys(self, value, inputtype):
if inputtype == "regular_arg":
# Check when input is a plain iterable of key-names
keys = ["a", "b", "history"]
# Result has keys assigned local/global via default mechanism.
expected = CubeAttrsDict(
globals={"history": value},
locals={"a": value, "b": value},
)
else:
assert inputtype == "split_arg"
# Check when input is a CubeAttrsDict
keys = CubeAttrsDict(
globals={"a": 1}, locals={"b": 2, "history": 3}
)
# The result preserves the input keys' local/global identity
# N.B. "history" would be global by default (cf. "regular_arg" case)
expected = CubeAttrsDict(
globals={"a": value},
locals={"b": value, "history": value},
)
result = CubeAttrsDict.fromkeys(keys, value)
check_content(result, matches=expected)

def test_to_dict(self, sample_attrs):
result = dict(sample_attrs)
Expand Down