diff --git a/docs/iris/src/whatsnew/contributions_2.3.0/bugfix_2018-Dec-07_only_cubes_in_cubelists.txt b/docs/iris/src/whatsnew/contributions_2.3.0/bugfix_2018-Dec-07_only_cubes_in_cubelists.txt new file mode 100644 index 0000000000..2af5b2f537 --- /dev/null +++ b/docs/iris/src/whatsnew/contributions_2.3.0/bugfix_2018-Dec-07_only_cubes_in_cubelists.txt @@ -0,0 +1,3 @@ +* The `append`, `extend` and `insert` methods of :class:`iris.cube.CubeList` +now perform a check to ensure that only :class:`iris.cube.Cube` instances are +added. diff --git a/lib/iris/cube.py b/lib/iris/cube.py index 81152a1293..2f38623c37 100644 --- a/lib/iris/cube.py +++ b/lib/iris/cube.py @@ -1,4 +1,4 @@ -# (C) British Crown Copyright 2010 - 2018, Met Office +# (C) British Crown Copyright 2010 - 2019, Met Office # # This file is part of Iris. # @@ -196,10 +196,7 @@ def __new__(cls, list_of_cubes=None): """Given a :class:`list` of cubes, return a CubeList instance.""" cube_list = list.__new__(cls, list_of_cubes) - # Check that all items in the incoming list are cubes. Note that this - # checking does not guarantee that a CubeList instance *always* has - # just cubes in its list as the append & __getitem__ methods have not - # been overridden. + # Check that all items in the incoming list are cubes. if not all([isinstance(cube, Cube) for cube in cube_list]): raise ValueError('All items in list_of_cubes must be Cube ' 'instances.') @@ -219,7 +216,14 @@ def __repr__(self): """Runs repr on every cube.""" return '[%s]' % ',\n'.join([repr(cube) for cube in self]) + @staticmethod + def _assert_is_cube(obj): + if not isinstance(obj, Cube): + msg = ("Object of type '{}' does not belong in a cubelist.") + raise ValueError(msg.format(type(obj).__name__)) + # TODO #370 Which operators need overloads? + def __add__(self, other): return CubeList(list.__add__(self, other)) @@ -241,6 +245,54 @@ def __getslice__(self, start, stop): result = CubeList(result) return result + def __iadd__(self, other_cubes): + """ + Add a sequence of cubes to the cubelist in place. + """ + super(CubeList, self).__iadd__(CubeList(other_cubes)) + + def __setitem__(self, key, cube_or_sequence): + """Set self[key] to cube or sequence of cubes""" + if isinstance(key, int): + # should have single cube. + self._assert_is_cube(cube_or_sequence) + else: + # key is a slice (or exception will come from list method). + cube_or_sequence = CubeList(cube_or_sequence) + + super(CubeList, self).__setitem__(key, cube_or_sequence) + + # __setslice__ is only required for python2.7 compatibility. + def __setslice__(self, *args): + args_list = list(args) + args_list[-1] = CubeList(args[-1]) + super(CubeList, self).__setslice__(*args_list) + + def append(self, cube): + """ + Append a cube. + """ + self._assert_is_cube(cube) + super(CubeList, self).append(cube) + + def extend(self, other_cubes): + """ + Extend cubelist by appending the cubes contained in other_cubes. + + Args: + + * other_cubes: + A cubelist or other sequence of cubes. + """ + super(CubeList, self).extend(CubeList(other_cubes)) + + def insert(self, index, cube): + """ + Insert a cube before index. + """ + self._assert_is_cube(cube) + super(CubeList, self).insert(index, cube) + def xml(self, checksum=False, order=True, byteorder=True): """Return a string of the XML that this list of cubes represents.""" doc = Document() diff --git a/lib/iris/tests/unit/cube/test_CubeList.py b/lib/iris/tests/unit/cube/test_CubeList.py index 258408c6f8..e0a800f99a 100644 --- a/lib/iris/tests/unit/cube/test_CubeList.py +++ b/lib/iris/tests/unit/cube/test_CubeList.py @@ -1,4 +1,4 @@ -# (C) British Crown Copyright 2014 - 2018, Met Office +# (C) British Crown Copyright 2014 - 2019, Met Office # # This file is part of Iris. # @@ -24,6 +24,8 @@ import iris.tests as tests import iris.tests.stock +import copy + from cf_units import Unit import numpy as np @@ -34,6 +36,26 @@ from iris.fileformats.pp import STASH from iris.tests import mock +NOT_CUBE_MSG = "Object of type '{}' does not belong in a cubelist." + + +class Test_append(tests.IrisTest): + def setUp(self): + self.cubelist = iris.cube.CubeList() + self.cube1 = iris.cube.Cube(1, long_name='foo') + self.cube2 = iris.cube.Cube(1, long_name='bar') + + def test_pass(self): + self.cubelist.append(self.cube1) + self.assertEqual(self.cubelist[-1], self.cube1) + self.cubelist.append(self.cube2) + self.assertEqual(self.cubelist[-1], self.cube2) + + def test_fail(self): + with self.assertRaisesRegexp(ValueError, + NOT_CUBE_MSG.format('NoneType')): + self.cubelist.append(None) + class Test_concatenate_cube(tests.IrisTest): def setUp(self): @@ -64,6 +86,30 @@ def test_empty(self): CubeList([]).concatenate_cube() +class Test_extend(tests.IrisTest): + def setUp(self): + self.cube1 = iris.cube.Cube(1, long_name='foo') + self.cube2 = iris.cube.Cube(1, long_name='bar') + self.cubelist1 = iris.cube.CubeList([self.cube1]) + self.cubelist2 = iris.cube.CubeList([self.cube2]) + + def test_pass(self): + cubelist = copy.copy(self.cubelist1) + cubelist.extend(self.cubelist2) + self.assertEqual(cubelist, self.cubelist1 + self.cubelist2) + cubelist.extend([self.cube2]) + self.assertEqual(cubelist[-1], self.cube2) + + def test_fail(self): + with self.assertRaisesRegexp(TypeError, 'Cube is not iterable'): + self.cubelist1.extend(self.cube1) + msg = "'NoneType' object is not iterable" + with self.assertRaisesRegexp(TypeError, msg): + self.cubelist1.extend(None) + with self.assertRaisesRegexp(ValueError, NOT_CUBE_MSG.format('int')): + self.cubelist1.extend(range(3)) + + class Test_extract_overlapping(tests.IrisTest): def setUp(self): shape = (6, 14, 19) @@ -118,6 +164,47 @@ def test_different_orders(self): self.assertEqual(b.coord('time'), self.cube.coord('time')[2:4]) +class Test_iadd(tests.IrisTest): + def setUp(self): + self.cube1 = iris.cube.Cube(1, long_name='foo') + self.cube2 = iris.cube.Cube(1, long_name='bar') + self.cubelist1 = iris.cube.CubeList([self.cube1]) + self.cubelist2 = iris.cube.CubeList([self.cube2]) + + def test_pass(self): + cubelist = copy.copy(self.cubelist1) + cubelist += self.cubelist2 + self.assertEqual(cubelist, self.cubelist1 + self.cubelist2) + cubelist += [self.cube2] + self.assertEqual(cubelist[-1], self.cube2) + + def test_fail(self): + msg = 'Cube is not iterable' + with self.assertRaisesRegexp(TypeError, msg): + self.cubelist1 += self.cube1 + msg = "'float' object is not iterable" + with self.assertRaisesRegexp(TypeError, msg): + self.cubelist1 += 1. + with self.assertRaisesRegexp(ValueError, NOT_CUBE_MSG.format('int')): + self.cubelist1 += range(3) + + +class Test_insert(tests.IrisTest): + def setUp(self): + self.cube1 = iris.cube.Cube(1, long_name='foo') + self.cube2 = iris.cube.Cube(1, long_name='bar') + self.cubelist = iris.cube.CubeList([self.cube1] * 3) + + def test_pass(self): + self.cubelist.insert(1, self.cube2) + self.assertEqual(self.cubelist[1], self.cube2) + + def test_fail(self): + with self.assertRaisesRegexp(ValueError, + NOT_CUBE_MSG.format('NoneType')): + self.cubelist.insert(0, None) + + class Test_merge_cube(tests.IrisTest): def setUp(self): self.cube1 = Cube([1, 2, 3], "air_temperature", units="K") @@ -241,6 +328,36 @@ def test_combination_with_extra_triple(self): self.assertCML(cube, checksum=False) +class Test_setitem(tests.IrisTest): + def setUp(self): + self.cube1 = iris.cube.Cube(1, long_name='foo') + self.cube2 = iris.cube.Cube(1, long_name='bar') + self.cube3 = iris.cube.Cube(1, long_name='boo') + self.cubelist = iris.cube.CubeList([self.cube1] * 3) + + def test_pass(self): + self.cubelist[1] = self.cube2 + self.assertEqual(self.cubelist[1], self.cube2) + self.cubelist[:2] = (self.cube2, self.cube3) + self.assertEqual( + self.cubelist, + iris.cube.CubeList([self.cube2, self.cube3, self.cube1])) + + def test_fail(self): + with self.assertRaisesRegexp(ValueError, + NOT_CUBE_MSG.format('NoneType')): + self.cubelist[0] = None + with self.assertRaisesRegexp(ValueError, + NOT_CUBE_MSG.format('NoneType')): + self.cubelist[0:2] = [self.cube3, None] + + msg = "can only assign an iterable" + with self.assertRaisesRegexp(TypeError, msg): + self.cubelist[:1] = 2.5 + with self.assertRaisesRegexp(TypeError, msg): + self.cubelist[:1] = self.cube1 + + class Test_xml(tests.IrisTest): def setUp(self): self.cubes = CubeList([Cube(np.arange(3)),