diff --git a/lib/iris/cube.py b/lib/iris/cube.py index 85fcd22b34..4df66690e6 100644 --- a/lib/iris/cube.py +++ b/lib/iris/cube.py @@ -972,7 +972,11 @@ def __eq__(self, other): return result # - # Provide a copy method, as for 'dict', but *not* provided by MutableMapping + # Provide methods duplicating those for a 'dict', but which are *not* provided by + # MutableMapping, for compatibility with code which expected a cube.attributes to be + # a :class:`~iris.common.mixin.LimitedAttributeDict`. + # The extra required methods are : + # 'copy', 'update', '__ior__', '__or__', '__ror__' and 'fromkeys'. # def copy(self): """ @@ -983,6 +987,66 @@ def copy(self): """ return CubeAttrsDict(self) + def update(self, *args, **kwargs): + """ + Update by adding items from a mapping arg, or keyword-values. + + If the argument is a split dictionary, preserve the local/global nature of its + keys. + """ + if args and hasattr(args[0], "globals") and hasattr(args[0], "locals"): + dic = args[0] + self.globals.update(dic.globals) + self.locals.update(dic.locals) + else: + super().update(*args) + super().update(**kwargs) + + def __or__(self, arg): + """Implement 'or' via 'update'.""" + if not isinstance(arg, Mapping): + return NotImplemented + new_dict = self.copy() + new_dict.update(arg) + return new_dict + + def __ior__(self, arg): + """Implement 'ior' via 'update'.""" + self.update(arg) + return self + + def __ror__(self, arg): + """ + Implement 'ror' via 'update'. + + This needs to promote, such that the result is a CubeAttrsDict. + """ + if not isinstance(arg, Mapping): + return NotImplemented + result = CubeAttrsDict(arg) + result.update(self) + return result + + @classmethod + def fromkeys(cls, iterable, value=None): + """ + Create a new object with keys taken from an argument, all set to one value. + + If the argument is a split dictionary, preserve the local/global nature of its + keys. + """ + if hasattr(iterable, "globals") and hasattr(iterable, "locals"): + # When main input is a split-attrs dict, create global/local parts from its + # global/local keys + result = cls( + globals=dict.fromkeys(iterable.globals, value), + locals=dict.fromkeys(iterable.locals, value), + ) + else: + # Create from a dict.fromkeys, using default classification of the keys. + result = cls(dict.fromkeys(iterable, value)) + return result + # # The remaining methods are sufficient to generate a complete standard Mapping # API. See - diff --git a/lib/iris/tests/unit/cube/test_CubeAttrsDict.py b/lib/iris/tests/unit/cube/test_CubeAttrsDict.py index 709dc3ccba..540917949b 100644 --- a/lib/iris/tests/unit/cube/test_CubeAttrsDict.py +++ b/lib/iris/tests/unit/cube/test_CubeAttrsDict.py @@ -23,8 +23,17 @@ def sample_attrs() -> CubeAttrsDict: def check_content(attrs, locals=None, globals=None, matches=None): - # Check a CubeAttrsDict for expected properties. - # If locals/globals are set, test for equality and non-identity. + """ + Check a CubeAttrsDict for expected properties. + + Its ".globals" and ".locals" must match 'locals' and 'globals' args + -- except that, if 'matches' is provided, it is a CubeAttrsDict, whose + locals/globals *replace* the 'locals'/'globals' arguments. + + Check that the result is a CubeAttrsDict and, for both local + global parts, + * parts match for *equality* (==) but are *non-identical* (is not) + * order of keys matches expected (N.B. which is *not* required for equality) + """ assert isinstance(attrs, CubeAttrsDict) attr_locals, attr_globals = attrs.locals, attrs.globals assert type(attr_locals) == LimitedAttributeDict @@ -41,6 +50,7 @@ def check(arg, content): # .. we proceed to ensure that the stored content is equal but NOT the same assert content == arg assert content is not arg + assert list(content.keys()) == list(arg.keys()) check(locals, attr_locals) check(globals, attr_globals) @@ -98,14 +108,89 @@ def test_copy(self, sample_attrs): assert copy is not sample_attrs check_content(copy, matches=sample_attrs) - def test_update(self, sample_attrs): - updated = sample_attrs.copy() - updated.update({"q": 77}) - expected_locals = sample_attrs.locals.copy() - expected_locals["q"] = 77 - check_content( - updated, globals=sample_attrs.globals, locals=expected_locals + @pytest.fixture(params=["regular_arg", "split_arg"]) + def update_testcase(self, request): + lhs = CubeAttrsDict(globals={"a": 1, "b": 2}, locals={"b": 3, "c": 4}) + if request.param == "split_arg": + # A set of "update settings", with global/local-specific keys. + rhs = CubeAttrsDict( + globals={"a": 1001, "x": 1007}, + # NOTE: use a global-default key here, to check that type is preserved + locals={"b": 1003, "history": 1099}, + ) + expected_result = CubeAttrsDict( + globals={"a": 1001, "b": 2, "x": 1007}, + locals={"b": 1003, "c": 4, "history": 1099}, + ) + else: + assert request.param == "regular_arg" + # A similar set of update values in a regular dict (so not local/global) + rhs = {"a": 1001, "x": 1007, "b": 1003, "history": 1099} + expected_result = CubeAttrsDict( + globals={"a": 1001, "b": 2, "history": 1099}, + locals={"b": 1003, "c": 4, "x": 1007}, + ) + return lhs, rhs, expected_result + + def test_update(self, update_testcase): + testval, updater, expected = update_testcase + testval.update(updater) + check_content(testval, matches=expected) + + def test___or__(self, update_testcase): + testval, updater, expected = update_testcase + original = testval.copy() + result = testval | updater + assert result is not testval + assert testval == original + check_content(result, matches=expected) + + def test___ior__(self, update_testcase): + testval, updater, expected = update_testcase + testval |= updater + check_content(testval, matches=expected) + + def test___ror__(self): + # Check the "or" operation, when lhs is a regular dictionary + lhs = {"a": 1, "b": 2, "history": 3} + rhs = CubeAttrsDict( + globals={"a": 1001, "x": 1007}, + # NOTE: use a global-default key here, to check that type is preserved + locals={"b": 1003, "history": 1099}, ) + # The lhs should be promoted to a CubeAttrsDict, and then combined. + expected = CubeAttrsDict( + globals={"history": 3, "a": 1001, "x": 1007}, + locals={"a": 1, "b": 1003, "history": 1099}, + ) + result = lhs | rhs + check_content(result, matches=expected) + + @pytest.mark.parametrize("value", [1, None]) + @pytest.mark.parametrize("inputtype", ["regular_arg", "split_arg"]) + def test__fromkeys(self, value, inputtype): + if inputtype == "regular_arg": + # Check when input is a plain iterable of key-names + keys = ["a", "b", "history"] + # Result has keys assigned local/global via default mechanism. + expected = CubeAttrsDict( + globals={"history": value}, + locals={"a": value, "b": value}, + ) + else: + assert inputtype == "split_arg" + # Check when input is a CubeAttrsDict + keys = CubeAttrsDict( + globals={"a": 1}, locals={"b": 2, "history": 3} + ) + # The result preserves the input keys' local/global identity + # N.B. "history" would be global by default (cf. "regular_arg" case) + expected = CubeAttrsDict( + globals={"a": value}, + locals={"b": value, "history": value}, + ) + result = CubeAttrsDict.fromkeys(keys, value) + check_content(result, matches=expected) def test_to_dict(self, sample_attrs): result = dict(sample_attrs)