Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 20 additions & 9 deletions lib/iris/common/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,19 @@


__all__ = [
"SERVICES_COMBINE",
"SERVICES_DIFFERENCE",
"SERVICES_EQUAL",
"SERVICES",
"AncillaryVariableMetadata",
"BaseMetadata",
"CellMeasureMetadata",
"CoordMetadata",
"CubeMetadata",
"DimCoordMetadata",
"hexdigest",
"metadata_filter",
"metadata_manager_factory",
"SERVICES",
"SERVICES_COMBINE",
"SERVICES_DIFFERENCE",
"SERVICES_EQUAL",
]


Expand Down Expand Up @@ -1353,13 +1354,13 @@ def metadata_filter(
):
"""
Filter a collection of objects by their metadata to fit the given metadata
criteria. Criteria be one or both of: specific properties / other objects
criteria. Criteria can be one or both of: specific properties / other objects
carrying metadata to be matched.

Args:

* instances
An iterable of objects to be filtered.
One or more objects to be filtered.

Kwargs:

Expand Down Expand Up @@ -1408,6 +1409,10 @@ def metadata_filter(
else:
obj = item

# apply de morgan's law for one less logical operation
if not (isinstance(instances, str) or isinstance(instances, Iterable)):
instances = [instances]

result = instances

if name is not None:
Expand Down Expand Up @@ -1449,10 +1454,16 @@ def attr_filter(instance):

if axis is not None:
axis = axis.upper()

def get_axis(instance):
if hasattr(instance, "axis"):
axis = instance.axis.upper()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❤❤❤

else:
axis = guess_coord_axis(instance)
return axis

result = [
instance
for instance in result
if guess_coord_axis(instance) == axis
instance for instance in result if get_axis(instance) == axis
]

if obj is not None:
Expand Down
14 changes: 10 additions & 4 deletions lib/iris/experimental/ugrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
"""

from abc import ABC, abstractmethod
from collections import namedtuple
from collections import Iterable, namedtuple
from functools import wraps

import dask.array as da
Expand All @@ -20,18 +20,18 @@
from .. import _lazy_data as _lazy
from ..common.metadata import (
BaseMetadata,
metadata_filter,
metadata_manager_factory,
SERVICES,
SERVICES_COMBINE,
SERVICES_EQUAL,
SERVICES_DIFFERENCE,
metadata_filter,
)
from ..common.lenient import _lenient_service as lenient_service
from ..common.mixin import CFVariableMixin
from ..config import get_logger
from ..coords import _DimensionalMetadata, AuxCoord
from ..exceptions import CoordinateNotFoundError, ConnectivityNotFoundError
from ..exceptions import ConnectivityNotFoundError, CoordinateNotFoundError
from ..util import guess_coord_axis


Expand Down Expand Up @@ -831,14 +831,14 @@ def __init__(
self,
topology_dimension,
node_coords_and_axes,
connectivities,
standard_name=None,
long_name=None,
var_name=None,
units=None,
attributes=None,
edge_coords_and_axes=None,
face_coords_and_axes=None,
connectivities=None,
node_dimension=None,
edge_dimension=None,
face_dimension=None,
Expand Down Expand Up @@ -874,6 +874,12 @@ def normalise(location, axis):
raise ValueError(emsg)
return f"{location}_{axis}"

if not isinstance(node_coords_and_axes, Iterable):
node_coords_and_axes = [node_coords_and_axes]

if not isinstance(connectivities, Iterable):
connectivities = [connectivities]

kwargs = {}
for coord, axis in node_coords_and_axes:
kwargs[normalise("node", axis)] = coord
Expand Down
20 changes: 19 additions & 1 deletion lib/iris/tests/unit/common/metadata/test_metadata_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@


class Test_standard(tests.IrisTest):
def test_instances_non_iterable(self):
item = Mock()
item.name.return_value = "one"
result = metadata_filter(item, item="one")
self.assertEqual(1, len(result))
self.assertIn(item, result)

def test_name(self):
name_one = Mock()
name_one.name.return_value = "one"
Expand Down Expand Up @@ -101,14 +108,25 @@ def test_invalid_attributes(self):
attributes="one",
)

def test_axis(self):
def test_axis__by_guess(self):
# see https://docs.python.org/3/library/unittest.mock.html#deleting-attributes
axis_lon = Mock(standard_name="longitude")
del axis_lon.axis
axis_lat = Mock(standard_name="latitude")
del axis_lat.axis
input_list = [axis_lon, axis_lat]
result = metadata_filter(input_list, axis="x")
self.assertIn(axis_lon, result)
self.assertNotIn(axis_lat, result)

def test_axis__by_member(self):
axis_x = Mock(axis="x")
axis_y = Mock(axis="y")
input_list = [axis_x, axis_y]
result = metadata_filter(input_list, axis="x")
self.assertEqual(1, len(result))
self.assertIn(axis_x, result)

def test_multiple_args(self):
coord_one = Mock(__class__=AuxCoord, long_name="one")
coord_two = Mock(__class__=AuxCoord, long_name="two")
Expand Down