Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
309 changes: 211 additions & 98 deletions lib/iris/experimental/regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,136 +473,245 @@ def _regrid_area_weighted_array(
grid.

"""
# Determine which grid bounds are within src extent.
y_within_bounds = _within_bounds(
src_y_bounds, grid_y_bounds, grid_y_decreasing
)
x_within_bounds = _within_bounds(
src_x_bounds, grid_x_bounds, grid_x_decreasing
)

# Cache which src_bounds are within grid bounds
cached_x_bounds = []
cached_x_indices = []
for (x_0, x_1) in grid_x_bounds:
if grid_x_decreasing:
x_0, x_1 = x_1, x_0
x_bounds, x_indices = _cropped_bounds(src_x_bounds, x_0, x_1)
cached_x_bounds.append(x_bounds)
cached_x_indices.append(x_indices)
def _calculate_regrid_area_weighted_weights(
src_x_bounds,
src_y_bounds,
grid_x_bounds,
grid_y_bounds,
grid_x_decreasing,
grid_y_decreasing,
area_func,
circular=False,
):
"""
Compute the area weights used for area-weighted regridding.

"""
# Determine which grid bounds are within src extent.
y_within_bounds = _within_bounds(
src_y_bounds, grid_y_bounds, grid_y_decreasing
)
x_within_bounds = _within_bounds(
src_x_bounds, grid_x_bounds, grid_x_decreasing
)

# Cache which src_bounds are within grid bounds
cached_x_bounds = []
cached_x_indices = []
max_x_indices = 0
for (x_0, x_1) in grid_x_bounds:
if grid_x_decreasing:
x_0, x_1 = x_1, x_0
x_bounds, x_indices = _cropped_bounds(src_x_bounds, x_0, x_1)
cached_x_bounds.append(x_bounds)
cached_x_indices.append(x_indices)
# Keep record of the largest slice
if isinstance(x_indices, slice):
x_indices_size = np.sum(x_indices.stop - x_indices.start)
else: # is tuple of indices
x_indices_size = len(x_indices)
if x_indices_size > max_x_indices:
max_x_indices = x_indices_size

# Cache which y src_bounds areas and weights are within grid bounds
cached_y_indices = []
cached_weights = []
max_y_indices = 0
for j, (y_0, y_1) in enumerate(grid_y_bounds):
# Reverse lower and upper if dest grid is decreasing.
if grid_y_decreasing:
y_0, y_1 = y_1, y_0
y_bounds, y_indices = _cropped_bounds(src_y_bounds, y_0, y_1)
cached_y_indices.append(y_indices)
# Keep record of the largest slice
if isinstance(y_indices, slice):
y_indices_size = np.sum(y_indices.stop - y_indices.start)
else: # is tuple of indices
y_indices_size = len(y_indices)
if y_indices_size > max_y_indices:
max_y_indices = y_indices_size

weights_i = []
for i, (x_0, x_1) in enumerate(grid_x_bounds):
# Reverse lower and upper if dest grid is decreasing.
if grid_x_decreasing:
x_0, x_1 = x_1, x_0
x_bounds = cached_x_bounds[i]
x_indices = cached_x_indices[i]

# Determine whether element i, j overlaps with src and hence
# an area weight should be computed.
# If x_0 > x_1 then we want [0]->x_1 and x_0->[0] + mod in the case
# of wrapped longitudes. However if the src grid is not global
# (i.e. circular) this new cell would include a region outside of
# the extent of the src grid and thus the weight is therefore
# invalid.
outside_extent = x_0 > x_1 and not circular
if (
outside_extent
or not y_within_bounds[j]
or not x_within_bounds[i]
):
weights = False
else:
# Calculate weights based on areas of cropped bounds.
if isinstance(x_indices, tuple) and isinstance(
y_indices, tuple
):
raise RuntimeError(
"Cannot handle split bounds " "in both x and y."
)
weights = area_func(y_bounds, x_bounds)
weights_i.append(weights)
cached_weights.append(weights_i)
return (
tuple(cached_x_indices),
tuple(cached_y_indices),
max_x_indices,
max_y_indices,
tuple(cached_weights),
)

weights_info = _calculate_regrid_area_weighted_weights(
src_x_bounds,
src_y_bounds,
grid_x_bounds,
grid_y_bounds,
grid_x_decreasing,
grid_y_decreasing,
area_func,
circular,
)
(
cached_x_indices,
cached_y_indices,
max_x_indices,
max_y_indices,
cached_weights,
) = weights_info
# Delete variables that are not needed and would not be available
# if _calculate_regrid_area_weighted_weights was refactored further
del src_x_bounds, src_y_bounds, grid_x_bounds, grid_y_bounds
del grid_x_decreasing, grid_y_decreasing
del area_func, circular

# Ensure we have x_dim and y_dim.
x_dim_orig = copy.copy(x_dim)
y_dim_orig = copy.copy(y_dim)
x_dim_orig = x_dim
y_dim_orig = y_dim
if y_dim is None:
src_data = np.expand_dims(src_data, axis=src_data.ndim)
y_dim = src_data.ndim - 1
if x_dim is None:
src_data = np.expand_dims(src_data, axis=src_data.ndim)
x_dim = src_data.ndim - 1
# Move y_dim and x_dim to last dimensions
src_data = np.moveaxis(src_data, x_dim, -1)
if x_dim < y_dim:
src_data = np.moveaxis(src_data, y_dim - 1, -2)
elif x_dim > y_dim:
src_data = np.moveaxis(src_data, y_dim, -2)
if not x_dim == src_data.ndim - 1:
src_data = np.moveaxis(src_data, x_dim, -1)
if not y_dim == src_data.ndim - 2:
if x_dim < y_dim:
# note: y_dim was shifted along by one position when
# x_dim was moved to the last dimension
src_data = np.moveaxis(src_data, y_dim - 1, -2)
elif x_dim > y_dim:
src_data = np.moveaxis(src_data, y_dim, -2)
x_dim = src_data.ndim - 1
y_dim = src_data.ndim - 2

# Create empty data array to match the new grid.
# Note that dtype is not preserved and that the array is
# masked to allow for regions that do not overlap.
# Create empty "pre-averaging" data array that will enable the
# src_data data coresponding to a given target grid point,
# to be stacked per point.
# Note that dtype is not preserved and that the array mask
# allows for regions that do not overlap.
new_shape = list(src_data.shape)
new_shape[x_dim] = grid_x_bounds.shape[0]
new_shape[y_dim] = grid_y_bounds.shape[0]

new_shape[x_dim] = len(cached_x_indices)
new_shape[y_dim] = len(cached_y_indices)
num_target_pts = len(cached_y_indices) * len(cached_x_indices)
src_areas_shape = list(src_data.shape)
src_areas_shape[y_dim] = max_y_indices
src_areas_shape[x_dim] = max_x_indices
src_areas_shape += [num_target_pts]
# Use input cube dtype or convert values to the smallest possible float
# dtype when necessary.
dtype = np.promote_types(src_data.dtype, np.float16)
# Create empty arrays to hold src_data per target point, and weights
src_area_datas = np.zeros(src_areas_shape, dtype=np.float64)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should dtype above be used rather than np.float64?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think as it standard the code is using the numpy default value of np.float64, so I've used this here. I don't think we should revisit the dtype handling in a follow up PR as I think it would be beneficial to improve it.

src_area_weights = np.zeros(
list((max_y_indices, max_x_indices, num_target_pts))
)

# Flag to indicate whether the original data was a masked array.
src_masked = ma.isMaskedArray(src_data)
src_masked = src_data.mask.any() if ma.isMaskedArray(src_data) else False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment: there's one test that fulfils this condition.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

However, there isn't a test that uses a masked array but that doesn't have any masked points.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added an extra test.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, thanks!

if src_masked:
new_data = ma.zeros(
new_shape, fill_value=src_data.fill_value, dtype=dtype
)
src_area_masks = np.full(src_areas_shape, True, dtype=np.bool)
else:
new_data = ma.zeros(new_shape, dtype=dtype)
# Assign to mask to explode it, allowing indexed assignment.
new_data.mask = False
new_data_mask = np.full(new_shape, False, dtype=np.bool)

# Axes of data over which the weighted mean is calculated.
axis = (y_dim, x_dim)

# Simple for loop approach.
for j, (y_0, y_1) in enumerate(grid_y_bounds):
# Reverse lower and upper if dest grid is decreasing.
if grid_y_decreasing:
y_0, y_1 = y_1, y_0
y_bounds, y_indices = _cropped_bounds(src_y_bounds, y_0, y_1)
for i, (x_0, x_1) in enumerate(grid_x_bounds):
# Reverse lower and upper if dest grid is decreasing.
if grid_x_decreasing:
x_0, x_1 = x_1, x_0
x_bounds = cached_x_bounds[i]
x_indices = cached_x_indices[i]

# Determine whether to mask element i, j based on overlap with
# src.
# If x_0 > x_1 then we want [0]->x_1 and x_0->[0] + mod in the case
# of wrapped longitudes. However if the src grid is not global
# (i.e. circular) this new cell would include a region outside of
# the extent of the src grid and should therefore be masked.
outside_extent = x_0 > x_1 and not circular
if (
outside_extent
or not y_within_bounds[j]
or not x_within_bounds[i]
):
# Mask out element(s) in new_data
new_data[..., j, i] = ma.masked
# Stack the src_area data and weights for each target point
target_pt_ji = -1
for j, y_indices in enumerate(cached_y_indices):
for i, x_indices in enumerate(cached_x_indices):
target_pt_ji += 1
# Determine whether to mask element i, j based on whether
# there are valid weights.
weights = cached_weights[j][i]
if isinstance(weights, bool) and not weights:
if not src_masked:
# Cheat! Fill the data with zeros and weights as one.
# The weighted average result will be the same, but
# we avoid dividing by zero.
src_area_weights[..., target_pt_ji] = 1
new_data_mask[..., j, i] = True
else:
# Calculate weighted mean of data points.
# Slice out relevant data (this may or may not be a view()
# depending on x_indices being a slice or not).
if isinstance(x_indices, tuple) and isinstance(
y_indices, tuple
):
raise RuntimeError(
"Cannot handle split bounds " "in both x and y."
)
# Calculate weights based on areas of cropped bounds.
weights = area_func(y_bounds, x_bounds)

data = src_data[..., y_indices, x_indices]
len_x = data.shape[-1]
len_y = data.shape[-2]
src_area_datas[..., 0:len_y, 0:len_x, target_pt_ji] = data
src_area_weights[0:len_y, 0:len_x, target_pt_ji] = weights
if src_masked:
src_area_masks[
..., 0:len_y, 0:len_x, target_pt_ji
] = data.mask

# Broadcast the weights array to allow numpy's ma.average
# to be called.
# Assign new shape to raise error on copy.
src_area_weights.shape = src_area_datas.shape[-3:]
# Broadcast weights to match shape of data.
_, src_area_weights = np.broadcast_arrays(src_area_datas, src_area_weights)

# Mask the data points
if src_masked:
src_area_datas = np.ma.array(src_area_datas, mask=src_area_masks)

# Transpose weights to match dim ordering in data.
weights_shape_y = weights.shape[0]
weights_shape_x = weights.shape[1]
# Broadcast the weights array to allow numpy's ma.average
# to be called.
weights_padded_shape = [1] * data.ndim
weights_padded_shape[y_dim] = weights_shape_y
weights_padded_shape[x_dim] = weights_shape_x
# Assign new shape to raise error on copy.
weights.shape = weights_padded_shape
# Broadcast weights to match shape of data.
_, weights = np.broadcast_arrays(data, weights)

# Calculate weighted mean taking into account missing data.
new_data_pt = _weighted_mean_with_mdtol(
data, weights=weights, axis=axis, mdtol=mdtol
)

# Insert data (and mask) values into new array.
new_data[..., j, i] = new_data_pt

# Remove new mask if original data was not masked
# and no values in the new array are masked.
if not src_masked and not new_data.mask.any():
new_data = new_data.data
# Calculate weighted mean taking into account missing data.
new_data = _weighted_mean_with_mdtol(
src_area_datas, weights=src_area_weights, axis=axis, mdtol=mdtol
)
new_data = new_data.reshape(new_shape)
if src_masked:
new_data_mask = new_data.mask

# Mask the data if originally masked or if the result has masked points
if ma.isMaskedArray(src_data):
new_data = ma.array(
new_data,
mask=new_data_mask,
fill_value=src_data.fill_value,
dtype=dtype,
)
elif new_data_mask.any():
new_data = ma.array(new_data, mask=new_data_mask, dtype=dtype)
else:
new_data = new_data.astype(dtype)

# Restore axis to original order
# Restore data to original form
if x_dim_orig is None and y_dim_orig is None:
new_data = np.squeeze(new_data, axis=x_dim)
new_data = np.squeeze(new_data, axis=y_dim)
Expand All @@ -613,9 +722,13 @@ def _regrid_area_weighted_array(
new_data = np.squeeze(new_data, axis=x_dim)
new_data = np.moveaxis(new_data, -1, y_dim_orig)
elif x_dim_orig < y_dim_orig:
# move the x_dim back first, so that the y_dim will
# then be moved to its original position
new_data = np.moveaxis(new_data, -1, x_dim_orig)
new_data = np.moveaxis(new_data, -1, y_dim_orig)
else:
# move the y_dim back first, so that the x_dim will
# then be moved to its original position
new_data = np.moveaxis(new_data, -2, y_dim_orig)
new_data = np.moveaxis(new_data, -1, x_dim_orig)
Copy link
Contributor

@ehogan ehogan Dec 12, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • Consider using x_dim and y_dim rather than -1 and -2 here?
  • These last two loops can be consolidated, since new_data will end in [..., y, x] in either case.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The last two loops can't be consolidated; add a comment explaining why :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -382,13 +382,22 @@ def test_hybrid_height(self):

def test_missing_data(self):
src = self.simple_cube.copy()
src.data = ma.masked_array(src.data)
src.data = ma.masked_array(src.data, fill_value=999)
src.data[1, 2] = ma.masked
dest = _resampled_grid(self.simple_cube, 2.3, 2.4)
res = regrid_area_weighted(src, dest)
mask = np.zeros((7, 9), bool)
mask[slice(2, 5), slice(4, 7)] = True
self.assertArrayEqual(res.data.mask, mask)
self.assertArrayEqual(res.data.fill_value, 999)

def test_masked_data_all_false(self):
src = self.simple_cube.copy()
src.data = ma.masked_array(src.data, mask=False, fill_value=999)
dest = _resampled_grid(self.simple_cube, 2.3, 2.4)
res = regrid_area_weighted(src, dest)
self.assertArrayEqual(res.data.mask, False)
self.assertArrayEqual(res.data.fill_value, 999)

def test_no_x_overlap(self):
src = self.simple_cube
Expand Down