diff --git a/docs/src/whatsnew/latest.rst b/docs/src/whatsnew/latest.rst index 7ecce0c5fb..86623ed204 100644 --- a/docs/src/whatsnew/latest.rst +++ b/docs/src/whatsnew/latest.rst @@ -111,6 +111,11 @@ This document explains the changes made to Iris for this release #. `@wjbenfold`_ and `@rcomer`_ (reviewer) corrected the axis on which masking is applied when an aggregator adds a trailing dimension. (:pull:`4755`) +* `@rcomer`_ and `@pp-mo`_ ensured that all methods to create or modify a + :class:`iris.cube.CubeList` check that it only contains cubes. According to + code comments, this was supposedly already the case, but there were several bugs + and loopholes. + 💣 Incompatible Changes ======================= diff --git a/lib/iris/cube.py b/lib/iris/cube.py index f0a79d0965..e3bacf08fc 100644 --- a/lib/iris/cube.py +++ b/lib/iris/cube.py @@ -152,19 +152,13 @@ class CubeList(list): """ - def __new__(cls, list_of_cubes=None): - """Given a :class:`list` of cubes, return a CubeList instance.""" - cube_list = list.__new__(cls, list_of_cubes) - - # Check that all items in the incoming list are cubes. Note that this - # checking does not guarantee that a CubeList instance *always* has - # just cubes in its list as the append & __getitem__ methods have not - # been overridden. - if not all([isinstance(cube, Cube) for cube in cube_list]): - raise ValueError( - "All items in list_of_cubes must be Cube " "instances." - ) - return cube_list + def __init__(self, *args, **kwargs): + """Given an iterable of cubes, return a CubeList instance.""" + # Do whatever a list does, to initialise ourself "as a list" + super().__init__(*args, **kwargs) + # Check that all items in the list are cubes. + for cube in self: + self._assert_is_cube(cube) def __str__(self): """Runs short :meth:`Cube.summary` on every cube.""" @@ -182,13 +176,17 @@ def __repr__(self): """Runs repr on every cube.""" return "[%s]" % ",\n".join([repr(cube) for cube in self]) - def _repr_html_(self): - from iris.experimental.representation import CubeListRepresentation - - representer = CubeListRepresentation(self) - return representer.repr_html() + @staticmethod + def _assert_is_cube(obj): + if not hasattr(obj, "add_aux_coord"): + msg = ( + r"Object {obj} cannot be put in a cubelist, " + "as it is not a Cube." + ) + raise ValueError(msg) # TODO #370 Which operators need overloads? + def __add__(self, other): return CubeList(list.__add__(self, other)) @@ -210,6 +208,48 @@ def __getslice__(self, start, stop): result = CubeList(result) return result + def __iadd__(self, other_cubes): + """ + Add a sequence of cubes to the cubelist in place. + """ + return super(CubeList, self).__iadd__(CubeList(other_cubes)) + + def __setitem__(self, key, cube_or_sequence): + """Set self[key] to cube or sequence of cubes""" + if isinstance(key, int): + # should have single cube. + self._assert_is_cube(cube_or_sequence) + else: + # key is a slice (or exception will come from list method). + cube_or_sequence = CubeList(cube_or_sequence) + + super(CubeList, self).__setitem__(key, cube_or_sequence) + + def append(self, cube): + """ + Append a cube. + """ + self._assert_is_cube(cube) + super(CubeList, self).append(cube) + + def extend(self, other_cubes): + """ + Extend cubelist by appending the cubes contained in other_cubes. + + Args: + + * other_cubes: + A cubelist or other sequence of cubes. + """ + super(CubeList, self).extend(CubeList(other_cubes)) + + def insert(self, index, cube): + """ + Insert a cube before index. + """ + self._assert_is_cube(cube) + super(CubeList, self).insert(index, cube) + def xml(self, checksum=False, order=True, byteorder=True): """Return a string of the XML that this list of cubes represents.""" diff --git a/lib/iris/tests/unit/cube/test_CubeList.py b/lib/iris/tests/unit/cube/test_CubeList.py index 50c1e553bc..771b214fe4 100644 --- a/lib/iris/tests/unit/cube/test_CubeList.py +++ b/lib/iris/tests/unit/cube/test_CubeList.py @@ -10,6 +10,7 @@ import iris.tests as tests # isort:skip import collections +import copy from unittest import mock from cf_units import Unit @@ -23,6 +24,26 @@ from iris.fileformats.pp import STASH import iris.tests.stock +NOT_CUBE_MSG = "cannot be put in a cubelist, as it is not a Cube." +NON_ITERABLE_MSG = "object is not iterable" + + +class Test_append(tests.IrisTest): + def setUp(self): + self.cubelist = iris.cube.CubeList() + self.cube1 = iris.cube.Cube(1, long_name="foo") + self.cube2 = iris.cube.Cube(1, long_name="bar") + + def test_pass(self): + self.cubelist.append(self.cube1) + self.assertEqual(self.cubelist[-1], self.cube1) + self.cubelist.append(self.cube2) + self.assertEqual(self.cubelist[-1], self.cube2) + + def test_fail(self): + with self.assertRaisesRegex(ValueError, NOT_CUBE_MSG): + self.cubelist.append(None) + class Test_concatenate_cube(tests.IrisTest): def setUp(self): @@ -70,6 +91,29 @@ def test_empty(self): CubeList([]).concatenate_cube() +class Test_extend(tests.IrisTest): + def setUp(self): + self.cube1 = iris.cube.Cube(1, long_name="foo") + self.cube2 = iris.cube.Cube(1, long_name="bar") + self.cubelist1 = iris.cube.CubeList([self.cube1]) + self.cubelist2 = iris.cube.CubeList([self.cube2]) + + def test_pass(self): + cubelist = copy.copy(self.cubelist1) + cubelist.extend(self.cubelist2) + self.assertEqual(cubelist, self.cubelist1 + self.cubelist2) + cubelist.extend([self.cube2]) + self.assertEqual(cubelist[-1], self.cube2) + + def test_fail(self): + with self.assertRaisesRegex(TypeError, NON_ITERABLE_MSG): + self.cubelist1.extend(self.cube1) + with self.assertRaisesRegex(TypeError, NON_ITERABLE_MSG): + self.cubelist1.extend(None) + with self.assertRaisesRegex(ValueError, NOT_CUBE_MSG): + self.cubelist1.extend(range(3)) + + class Test_extract_overlapping(tests.IrisTest): def setUp(self): shape = (6, 14, 19) @@ -130,6 +174,44 @@ def test_different_orders(self): self.assertEqual(b.coord("time"), self.cube.coord("time")[2:4]) +class Test_iadd(tests.IrisTest): + def setUp(self): + self.cube1 = iris.cube.Cube(1, long_name="foo") + self.cube2 = iris.cube.Cube(1, long_name="bar") + self.cubelist1 = iris.cube.CubeList([self.cube1]) + self.cubelist2 = iris.cube.CubeList([self.cube2]) + + def test_pass(self): + cubelist = copy.copy(self.cubelist1) + cubelist += self.cubelist2 + self.assertEqual(cubelist, self.cubelist1 + self.cubelist2) + cubelist += [self.cube2] + self.assertEqual(cubelist[-1], self.cube2) + + def test_fail(self): + with self.assertRaisesRegex(TypeError, NON_ITERABLE_MSG): + self.cubelist1 += self.cube1 + with self.assertRaisesRegex(TypeError, NON_ITERABLE_MSG): + self.cubelist1 += 1.0 + with self.assertRaisesRegex(ValueError, NOT_CUBE_MSG): + self.cubelist1 += range(3) + + +class Test_insert(tests.IrisTest): + def setUp(self): + self.cube1 = iris.cube.Cube(1, long_name="foo") + self.cube2 = iris.cube.Cube(1, long_name="bar") + self.cubelist = iris.cube.CubeList([self.cube1] * 3) + + def test_pass(self): + self.cubelist.insert(1, self.cube2) + self.assertEqual(self.cubelist[1], self.cube2) + + def test_fail(self): + with self.assertRaisesRegex(ValueError, NOT_CUBE_MSG): + self.cubelist.insert(0, None) + + class Test_merge_cube(tests.IrisTest): def setUp(self): self.cube1 = Cube([1, 2, 3], "air_temperature", units="K") @@ -274,6 +356,34 @@ def test_combination_with_extra_triple(self): self.assertCML(cube, checksum=False) +class Test_setitem(tests.IrisTest): + def setUp(self): + self.cube1 = iris.cube.Cube(1, long_name="foo") + self.cube2 = iris.cube.Cube(1, long_name="bar") + self.cube3 = iris.cube.Cube(1, long_name="boo") + self.cubelist = iris.cube.CubeList([self.cube1] * 3) + + def test_pass(self): + self.cubelist[1] = self.cube2 + self.assertEqual(self.cubelist[1], self.cube2) + self.cubelist[:2] = (self.cube2, self.cube3) + self.assertEqual( + self.cubelist, + iris.cube.CubeList([self.cube2, self.cube3, self.cube1]), + ) + + def test_fail(self): + with self.assertRaisesRegex(ValueError, NOT_CUBE_MSG): + self.cubelist[0] = None + with self.assertRaisesRegex(ValueError, NOT_CUBE_MSG): + self.cubelist[0:2] = [self.cube3, None] + + with self.assertRaisesRegex(TypeError, NON_ITERABLE_MSG): + self.cubelist[:1] = 2.5 + with self.assertRaisesRegex(TypeError, NON_ITERABLE_MSG): + self.cubelist[:1] = self.cube1 + + class Test_xml(tests.IrisTest): def setUp(self): self.cubes = CubeList([Cube(np.arange(3)), Cube(np.arange(3))])