Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
360 changes: 242 additions & 118 deletions lib/iris/experimental/regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,140 +473,264 @@ def _regrid_area_weighted_array(
grid.

"""
# Determine which grid bounds are within src extent.
y_within_bounds = _within_bounds(
src_y_bounds, grid_y_bounds, grid_y_decreasing
)
x_within_bounds = _within_bounds(
src_x_bounds, grid_x_bounds, grid_x_decreasing
)

# Cache which src_bounds are within grid bounds
cached_x_bounds = []
cached_x_indices = []
for (x_0, x_1) in grid_x_bounds:
if grid_x_decreasing:
x_0, x_1 = x_1, x_0
x_bounds, x_indices = _cropped_bounds(src_x_bounds, x_0, x_1)
cached_x_bounds.append(x_bounds)
cached_x_indices.append(x_indices)

# Create empty data array to match the new grid.
# Note that dtype is not preserved and that the array is
# masked to allow for regions that do not overlap.
new_shape = list(src_data.shape)
if x_dim is not None:
new_shape[x_dim] = grid_x_bounds.shape[0]
if y_dim is not None:
new_shape[y_dim] = grid_y_bounds.shape[0]
def _calculate_regrid_area_weighted_weights(
src_x_bounds,
src_y_bounds,
grid_x_bounds,
grid_y_bounds,
grid_x_decreasing,
grid_y_decreasing,
area_func,
circular=False,
):
"""
Compute the area weights used for area-weighted regridding.

"""
# Determine which grid bounds are within src extent.
y_within_bounds = _within_bounds(
src_y_bounds, grid_y_bounds, grid_y_decreasing
)
x_within_bounds = _within_bounds(
src_x_bounds, grid_x_bounds, grid_x_decreasing
)

# Cache which src_bounds are within grid bounds
cached_x_bounds = []
cached_x_indices = []
max_x_indices = 0
for (x_0, x_1) in grid_x_bounds:
if grid_x_decreasing:
x_0, x_1 = x_1, x_0
x_bounds, x_indices = _cropped_bounds(src_x_bounds, x_0, x_1)
cached_x_bounds.append(x_bounds)
cached_x_indices.append(x_indices)
# Keep record of the largest slice
if isinstance(x_indices, slice):
x_indices_size = np.sum(x_indices.stop - x_indices.start)
else: # is tuple of indices
x_indices_size = len(x_indices)
if x_indices_size > max_x_indices:
max_x_indices = x_indices_size

# Cache which y src_bounds areas and weights are within grid bounds
cached_y_indices = []
cached_weights = []
max_y_indices = 0
for j, (y_0, y_1) in enumerate(grid_y_bounds):
# Reverse lower and upper if dest grid is decreasing.
if grid_y_decreasing:
y_0, y_1 = y_1, y_0
y_bounds, y_indices = _cropped_bounds(src_y_bounds, y_0, y_1)
cached_y_indices.append(y_indices)
# Keep record of the largest slice
if isinstance(y_indices, slice):
y_indices_size = np.sum(y_indices.stop - y_indices.start)
else: # is tuple of indices
y_indices_size = len(y_indices)
if y_indices_size > max_y_indices:
max_y_indices = y_indices_size

weights_i = []
for i, (x_0, x_1) in enumerate(grid_x_bounds):
# Reverse lower and upper if dest grid is decreasing.
if grid_x_decreasing:
x_0, x_1 = x_1, x_0
x_bounds = cached_x_bounds[i]
x_indices = cached_x_indices[i]

# Determine whether element i, j overlaps with src and hence
# an area weight should be computed.
# If x_0 > x_1 then we want [0]->x_1 and x_0->[0] + mod in the case
# of wrapped longitudes. However if the src grid is not global
# (i.e. circular) this new cell would include a region outside of
# the extent of the src grid and thus the weight is therefore
# invalid.
outside_extent = x_0 > x_1 and not circular
if (
outside_extent
or not y_within_bounds[j]
or not x_within_bounds[i]
):
weights = False
else:
# Calculate weights based on areas of cropped bounds.
if isinstance(x_indices, tuple) and isinstance(
y_indices, tuple
):
raise RuntimeError(
"Cannot handle split bounds " "in both x and y."
)
weights = area_func(y_bounds, x_bounds)
weights_i.append(weights)
cached_weights.append(weights_i)
return (
tuple(cached_x_indices),
tuple(cached_y_indices),
max_x_indices,
max_y_indices,
tuple(cached_weights),
)

weights_info = _calculate_regrid_area_weighted_weights(
src_x_bounds,
src_y_bounds,
grid_x_bounds,
grid_y_bounds,
grid_x_decreasing,
grid_y_decreasing,
area_func,
circular,
)
(
cached_x_indices,
cached_y_indices,
max_x_indices,
max_y_indices,
cached_weights,
) = weights_info
# Delete variables that are not needed and would not be available
# if _calculate_regrid_area_weighted_weights was refactored further
del src_x_bounds, src_y_bounds, grid_x_bounds, grid_y_bounds
del grid_x_decreasing, grid_y_decreasing
del area_func, circular

# Ensure we have x_dim and y_dim.
x_dim_orig = x_dim
y_dim_orig = y_dim
if y_dim is None:
src_data = np.expand_dims(src_data, axis=src_data.ndim)
y_dim = src_data.ndim - 1
if x_dim is None:
src_data = np.expand_dims(src_data, axis=src_data.ndim)
x_dim = src_data.ndim - 1
# Move y_dim and x_dim to last dimensions
if not x_dim == src_data.ndim - 1:
src_data = np.moveaxis(src_data, x_dim, -1)
if not y_dim == src_data.ndim - 2:
if x_dim < y_dim:
# note: y_dim was shifted along by one position when
# x_dim was moved to the last dimension
src_data = np.moveaxis(src_data, y_dim - 1, -2)
elif x_dim > y_dim:
src_data = np.moveaxis(src_data, y_dim, -2)
x_dim = src_data.ndim - 1
y_dim = src_data.ndim - 2

# Create empty "pre-averaging" data array that will enable the
# src_data data coresponding to a given target grid point,
# to be stacked per point.
# Note that dtype is not preserved and that the array mask
# allows for regions that do not overlap.
new_shape = list(src_data.shape)
new_shape[x_dim] = len(cached_x_indices)
new_shape[y_dim] = len(cached_y_indices)
num_target_pts = len(cached_y_indices) * len(cached_x_indices)
src_areas_shape = list(src_data.shape)
src_areas_shape[y_dim] = max_y_indices
src_areas_shape[x_dim] = max_x_indices
src_areas_shape += [num_target_pts]
# Use input cube dtype or convert values to the smallest possible float
# dtype when necessary.
dtype = np.promote_types(src_data.dtype, np.float16)
# Create empty arrays to hold src_data per target point, and weights
src_area_datas = np.zeros(src_areas_shape, dtype=np.float64)
src_area_weights = np.zeros(
list((max_y_indices, max_x_indices, num_target_pts))
)

# Flag to indicate whether the original data was a masked array.
src_masked = ma.isMaskedArray(src_data)
src_masked = src_data.mask.any() if ma.isMaskedArray(src_data) else False
if src_masked:
new_data = ma.zeros(
new_shape, fill_value=src_data.fill_value, dtype=dtype
)
src_area_masks = np.full(src_areas_shape, True, dtype=np.bool)
else:
new_data = ma.zeros(new_shape, dtype=dtype)
# Assign to mask to explode it, allowing indexed assignment.
new_data.mask = False
new_data_mask = np.full(new_shape, False, dtype=np.bool)

# Axes of data over which the weighted mean is calculated.
axes = []
if y_dim is not None:
axes.append(y_dim)
if x_dim is not None:
axes.append(x_dim)
axis = tuple(axes)

# Simple for loop approach.
indices = [slice(None)] * new_data.ndim
for j, (y_0, y_1) in enumerate(grid_y_bounds):
# Reverse lower and upper if dest grid is decreasing.
if grid_y_decreasing:
y_0, y_1 = y_1, y_0
y_bounds, y_indices = _cropped_bounds(src_y_bounds, y_0, y_1)
for i, (x_0, x_1) in enumerate(grid_x_bounds):
# Reverse lower and upper if dest grid is decreasing.
if grid_x_decreasing:
x_0, x_1 = x_1, x_0
x_bounds = cached_x_bounds[i]
x_indices = cached_x_indices[i]

# Determine whether to mask element i, j based on overlap with
# src.
# If x_0 > x_1 then we want [0]->x_1 and x_0->[0] + mod in the case
# of wrapped longitudes. However if the src grid is not global
# (i.e. circular) this new cell would include a region outside of
# the extent of the src grid and should therefore be masked.
outside_extent = x_0 > x_1 and not circular
if (
outside_extent
or not y_within_bounds[j]
or not x_within_bounds[i]
):
# Mask out element(s) in new_data
if x_dim is not None:
indices[x_dim] = i
if y_dim is not None:
indices[y_dim] = j
new_data[tuple(indices)] = ma.masked
axis = (y_dim, x_dim)

# Stack the src_area data and weights for each target point
target_pt_ji = -1
for j, y_indices in enumerate(cached_y_indices):
for i, x_indices in enumerate(cached_x_indices):
target_pt_ji += 1
# Determine whether to mask element i, j based on whether
# there are valid weights.
weights = cached_weights[j][i]
if isinstance(weights, bool) and not weights:
if not src_masked:
# Cheat! Fill the data with zeros and weights as one.
# The weighted average result will be the same, but
# we avoid dividing by zero.
src_area_weights[..., target_pt_ji] = 1
new_data_mask[..., j, i] = True
else:
# Calculate weighted mean of data points.
# Slice out relevant data (this may or may not be a view()
# depending on x_indices being a slice or not).
if isinstance(x_indices, tuple) and isinstance(
y_indices, tuple
):
raise RuntimeError(
"Cannot handle split bounds " "in both x and y."
)
# Calculate weights based on areas of cropped bounds.
weights = area_func(y_bounds, x_bounds)

if x_dim is not None:
indices[x_dim] = x_indices
if y_dim is not None:
indices[y_dim] = y_indices
data = src_data[tuple(indices)]

# Transpose weights to match dim ordering in data.
weights_shape_y = weights.shape[0]
weights_shape_x = weights.shape[1]
if x_dim is not None and y_dim is not None and x_dim < y_dim:
weights = weights.T
# Broadcast the weights array to allow numpy's ma.average
# to be called.
weights_padded_shape = [1] * data.ndim
if y_dim is not None:
weights_padded_shape[y_dim] = weights_shape_y
if x_dim is not None:
weights_padded_shape[x_dim] = weights_shape_x
# Assign new shape to raise error on copy.
weights.shape = weights_padded_shape
# Broadcast weights to match shape of data.
_, weights = np.broadcast_arrays(data, weights)

# Calculate weighted mean taking into account missing data.
new_data_pt = _weighted_mean_with_mdtol(
data, weights=weights, axis=axis, mdtol=mdtol
)
data = src_data[..., y_indices, x_indices]
len_x = data.shape[-1]
len_y = data.shape[-2]
src_area_datas[..., 0:len_y, 0:len_x, target_pt_ji] = data
src_area_weights[0:len_y, 0:len_x, target_pt_ji] = weights
if src_masked:
src_area_masks[
..., 0:len_y, 0:len_x, target_pt_ji
] = data.mask

# Broadcast the weights array to allow numpy's ma.average
# to be called.
# Assign new shape to raise error on copy.
src_area_weights.shape = src_area_datas.shape[-3:]
# Broadcast weights to match shape of data.
_, src_area_weights = np.broadcast_arrays(src_area_datas, src_area_weights)

# Mask the data points
if src_masked:
src_area_datas = np.ma.array(src_area_datas, mask=src_area_masks)

# Insert data (and mask) values into new array.
if x_dim is not None:
indices[x_dim] = i
if y_dim is not None:
indices[y_dim] = j
new_data[tuple(indices)] = new_data_pt

# Remove new mask if original data was not masked
# and no values in the new array are masked.
if not src_masked and not new_data.mask.any():
new_data = new_data.data
# Calculate weighted mean taking into account missing data.
new_data = _weighted_mean_with_mdtol(
src_area_datas, weights=src_area_weights, axis=axis, mdtol=mdtol
)
new_data = new_data.reshape(new_shape)
if src_masked:
new_data_mask = new_data.mask

# Mask the data if originally masked or if the result has masked points
if ma.isMaskedArray(src_data):
new_data = ma.array(
new_data,
mask=new_data_mask,
fill_value=src_data.fill_value,
dtype=dtype,
)
elif new_data_mask.any():
new_data = ma.array(new_data, mask=new_data_mask, dtype=dtype)
else:
new_data = new_data.astype(dtype)

# Restore data to original form
if x_dim_orig is None and y_dim_orig is None:
new_data = np.squeeze(new_data, axis=x_dim)
new_data = np.squeeze(new_data, axis=y_dim)
elif y_dim_orig is None:
new_data = np.squeeze(new_data, axis=y_dim)
new_data = np.moveaxis(new_data, -1, x_dim_orig)
elif x_dim_orig is None:
new_data = np.squeeze(new_data, axis=x_dim)
new_data = np.moveaxis(new_data, -1, y_dim_orig)
elif x_dim_orig < y_dim_orig:
# move the x_dim back first, so that the y_dim will
# then be moved to its original position
new_data = np.moveaxis(new_data, -1, x_dim_orig)
new_data = np.moveaxis(new_data, -1, y_dim_orig)
else:
# move the y_dim back first, so that the x_dim will
# then be moved to its original position
new_data = np.moveaxis(new_data, -2, y_dim_orig)
new_data = np.moveaxis(new_data, -1, x_dim_orig)

return new_data

Expand Down
Loading