
Commit 89d2330

abooton authored and pp-mo committed
PI-2472: Optimise the area weighted regridding routine (SciTools#3598)
* PI-2472: Tweak area weighting regrid move xdim and ydim axes (SciTools#3594)
* _regrid_area_weighted_array: Set axis order to y_dim, x_dim last dimensions
* _regrid_area_weighted_array: Extra tests for axes ordering
* PI-2472: Tweak area weighting regrid enforce xdim ydim (SciTools#3595)
* _regrid_area_weighted_array: Set axis order to y_dim, x_dim last dimensions
* _regrid_area_weighted_array: Extra tests for axes ordering
* _regrid_area_weighted_array: Ensure x_dim and y_dim
* PI-2472: Tweak area weighting regrid move averaging out of loop (SciTools#3596)
* _regrid_area_weighted_array: Refactor weights and move averaging outside loop
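
The heart of the optimisation described above: rather than computing a weighted mean per target cell inside the j/i loop, the refactored routine stacks the cropped source block for every target cell along a trailing axis and then takes a single broadcast weighted average. The following is a minimal NumPy sketch of that pattern only; the uniform weights, equal-sized blocks and names such as src, blocks and block_weights are illustrative assumptions, not the Iris code, which additionally handles masks, uneven overlaps and the mdtol tolerance.

    import numpy as np

    rng = np.random.default_rng(0)
    src = rng.random((3, 6, 8))      # e.g. (time, y, x) source data
    ny_t, nx_t = 2, 4                # target grid size
    by, bx = 3, 2                    # source cells contributing to each target cell

    # Stack the source block and its weights for every target point along a
    # trailing axis, instead of averaging inside the j/i loop.
    n_target = ny_t * nx_t
    blocks = np.zeros(src.shape[:-2] + (by, bx, n_target))
    block_weights = np.zeros((by, bx, n_target))
    for j in range(ny_t):
        for i in range(nx_t):
            k = j * nx_t + i
            blocks[..., k] = src[..., j * by:(j + 1) * by, i * bx:(i + 1) * bx]
            block_weights[..., k] = 1.0   # the real routine uses spherical cell areas

    # One broadcast weighted average over the stacked block axes replaces the
    # per-cell averaging call that used to sit inside the loop.
    weights = np.broadcast_to(block_weights, blocks.shape)
    new_data = np.average(blocks, weights=weights, axis=(-3, -2))
    new_data = new_data.reshape(src.shape[:-2] + (ny_t, nx_t))
    assert new_data.shape == (3, 2, 4)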
1 parent cebfb52 commit 89d2330

File tree

4 files changed: +307 -151 lines changed


lib/iris/experimental/regrid.py

Lines changed: 243 additions & 114 deletions
@@ -461,135 +461,264 @@ def _regrid_area_weighted_array(src_data, x_dim, y_dim,
     grid.
 
     """
-    # Determine which grid bounds are within src extent.
-    y_within_bounds = _within_bounds(
-        src_y_bounds, grid_y_bounds, grid_y_decreasing
-    )
-    x_within_bounds = _within_bounds(
-        src_x_bounds, grid_x_bounds, grid_x_decreasing
-    )
 
-    # Cache which src_bounds are within grid bounds
-    cached_x_bounds = []
-    cached_x_indices = []
-    for (x_0, x_1) in grid_x_bounds:
-        if grid_x_decreasing:
-            x_0, x_1 = x_1, x_0
-        x_bounds, x_indices = _cropped_bounds(src_x_bounds, x_0, x_1)
-        cached_x_bounds.append(x_bounds)
-        cached_x_indices.append(x_indices)
-
-    # Create empty data array to match the new grid.
-    # Note that dtype is not preserved and that the array is
-    # masked to allow for regions that do not overlap.
-    new_shape = list(src_data.shape)
-    if x_dim is not None:
-        new_shape[x_dim] = grid_x_bounds.shape[0]
-    if y_dim is not None:
-        new_shape[y_dim] = grid_y_bounds.shape[0]
+    def _calculate_regrid_area_weighted_weights(
+        src_x_bounds,
+        src_y_bounds,
+        grid_x_bounds,
+        grid_y_bounds,
+        grid_x_decreasing,
+        grid_y_decreasing,
+        area_func,
+        circular=False,
+    ):
+        """
+        Compute the area weights used for area-weighted regridding.
 
+        """
+        # Determine which grid bounds are within src extent.
+        y_within_bounds = _within_bounds(
+            src_y_bounds, grid_y_bounds, grid_y_decreasing
+        )
+        x_within_bounds = _within_bounds(
+            src_x_bounds, grid_x_bounds, grid_x_decreasing
+        )
+
+        # Cache which src_bounds are within grid bounds
+        cached_x_bounds = []
+        cached_x_indices = []
+        max_x_indices = 0
+        for (x_0, x_1) in grid_x_bounds:
+            if grid_x_decreasing:
+                x_0, x_1 = x_1, x_0
+            x_bounds, x_indices = _cropped_bounds(src_x_bounds, x_0, x_1)
+            cached_x_bounds.append(x_bounds)
+            cached_x_indices.append(x_indices)
+            # Keep record of the largest slice
+            if isinstance(x_indices, slice):
+                x_indices_size = np.sum(x_indices.stop - x_indices.start)
+            else:  # is tuple of indices
+                x_indices_size = len(x_indices)
+            if x_indices_size > max_x_indices:
+                max_x_indices = x_indices_size
+
+        # Cache which y src_bounds areas and weights are within grid bounds
+        cached_y_indices = []
+        cached_weights = []
+        max_y_indices = 0
+        for j, (y_0, y_1) in enumerate(grid_y_bounds):
+            # Reverse lower and upper if dest grid is decreasing.
+            if grid_y_decreasing:
+                y_0, y_1 = y_1, y_0
+            y_bounds, y_indices = _cropped_bounds(src_y_bounds, y_0, y_1)
+            cached_y_indices.append(y_indices)
+            # Keep record of the largest slice
+            if isinstance(y_indices, slice):
+                y_indices_size = np.sum(y_indices.stop - y_indices.start)
+            else:  # is tuple of indices
+                y_indices_size = len(y_indices)
+            if y_indices_size > max_y_indices:
+                max_y_indices = y_indices_size
+
+            weights_i = []
+            for i, (x_0, x_1) in enumerate(grid_x_bounds):
+                # Reverse lower and upper if dest grid is decreasing.
+                if grid_x_decreasing:
+                    x_0, x_1 = x_1, x_0
+                x_bounds = cached_x_bounds[i]
+                x_indices = cached_x_indices[i]
+
+                # Determine whether element i, j overlaps with src and hence
+                # an area weight should be computed.
+                # If x_0 > x_1 then we want [0]->x_1 and x_0->[0] + mod in the case
+                # of wrapped longitudes. However if the src grid is not global
+                # (i.e. circular) this new cell would include a region outside of
+                # the extent of the src grid and thus the weight is therefore
+                # invalid.
+                outside_extent = x_0 > x_1 and not circular
+                if (
+                    outside_extent
+                    or not y_within_bounds[j]
+                    or not x_within_bounds[i]
+                ):
+                    weights = False
+                else:
+                    # Calculate weights based on areas of cropped bounds.
+                    if isinstance(x_indices, tuple) and isinstance(
+                        y_indices, tuple
+                    ):
+                        raise RuntimeError(
+                            "Cannot handle split bounds " "in both x and y."
+                        )
+                    weights = area_func(y_bounds, x_bounds)
+                weights_i.append(weights)
+            cached_weights.append(weights_i)
+        return (
+            tuple(cached_x_indices),
+            tuple(cached_y_indices),
+            max_x_indices,
+            max_y_indices,
+            tuple(cached_weights),
+        )
+
+    weights_info = _calculate_regrid_area_weighted_weights(
+        src_x_bounds,
+        src_y_bounds,
+        grid_x_bounds,
+        grid_y_bounds,
+        grid_x_decreasing,
+        grid_y_decreasing,
+        area_func,
+        circular,
+    )
+    (
+        cached_x_indices,
+        cached_y_indices,
+        max_x_indices,
+        max_y_indices,
+        cached_weights,
+    ) = weights_info
+    # Delete variables that are not needed and would not be available
+    # if _calculate_regrid_area_weighted_weights was refactored further
+    del src_x_bounds, src_y_bounds, grid_x_bounds, grid_y_bounds
+    del grid_x_decreasing, grid_y_decreasing
+    del area_func, circular
+
+    # Ensure we have x_dim and y_dim.
+    x_dim_orig = x_dim
+    y_dim_orig = y_dim
+    if y_dim is None:
+        src_data = np.expand_dims(src_data, axis=src_data.ndim)
+        y_dim = src_data.ndim - 1
+    if x_dim is None:
+        src_data = np.expand_dims(src_data, axis=src_data.ndim)
+        x_dim = src_data.ndim - 1
+    # Move y_dim and x_dim to last dimensions
+    if not x_dim == src_data.ndim - 1:
+        src_data = np.moveaxis(src_data, x_dim, -1)
+    if not y_dim == src_data.ndim - 2:
+        if x_dim < y_dim:
+            # note: y_dim was shifted along by one position when
+            # x_dim was moved to the last dimension
+            src_data = np.moveaxis(src_data, y_dim - 1, -2)
+        elif x_dim > y_dim:
+            src_data = np.moveaxis(src_data, y_dim, -2)
+    x_dim = src_data.ndim - 1
+    y_dim = src_data.ndim - 2
+
+    # Create empty "pre-averaging" data array that will enable the
+    # src_data data coresponding to a given target grid point,
+    # to be stacked per point.
+    # Note that dtype is not preserved and that the array mask
+    # allows for regions that do not overlap.
+    new_shape = list(src_data.shape)
+    new_shape[x_dim] = len(cached_x_indices)
+    new_shape[y_dim] = len(cached_y_indices)
+    num_target_pts = len(cached_y_indices) * len(cached_x_indices)
+    src_areas_shape = list(src_data.shape)
+    src_areas_shape[y_dim] = max_y_indices
+    src_areas_shape[x_dim] = max_x_indices
+    src_areas_shape += [num_target_pts]
     # Use input cube dtype or convert values to the smallest possible float
     # dtype when necessary.
     dtype = np.promote_types(src_data.dtype, np.float16)
+    # Create empty arrays to hold src_data per target point, and weights
+    src_area_datas = np.zeros(src_areas_shape, dtype=np.float64)
+    src_area_weights = np.zeros(
+        list((max_y_indices, max_x_indices, num_target_pts))
+    )
 
     # Flag to indicate whether the original data was a masked array.
-    src_masked = ma.isMaskedArray(src_data)
+    src_masked = src_data.mask.any() if ma.isMaskedArray(src_data) else False
     if src_masked:
-        new_data = ma.zeros(new_shape, fill_value=src_data.fill_value,
-                            dtype=dtype)
+        src_area_masks = np.full(src_areas_shape, True, dtype=np.bool)
     else:
-        new_data = ma.zeros(new_shape, dtype=dtype)
-        # Assign to mask to explode it, allowing indexed assignment.
-        new_data.mask = False
+        new_data_mask = np.full(new_shape, False, dtype=np.bool)
 
     # Axes of data over which the weighted mean is calculated.
-    axes = []
-    if y_dim is not None:
-        axes.append(y_dim)
-    if x_dim is not None:
-        axes.append(x_dim)
-    axis = tuple(axes)
-
-    # Simple for loop approach.
-    indices = [slice(None)] * new_data.ndim
-    for j, (y_0, y_1) in enumerate(grid_y_bounds):
-        # Reverse lower and upper if dest grid is decreasing.
-        if grid_y_decreasing:
-            y_0, y_1 = y_1, y_0
-        y_bounds, y_indices = _cropped_bounds(src_y_bounds, y_0, y_1)
-        for i, (x_0, x_1) in enumerate(grid_x_bounds):
-            # Reverse lower and upper if dest grid is decreasing.
-            if grid_x_decreasing:
-                x_0, x_1 = x_1, x_0
-            x_bounds = cached_x_bounds[i]
-            x_indices = cached_x_indices[i]
-
-            # Determine whether to mask element i, j based on overlap with
-            # src.
-            # If x_0 > x_1 then we want [0]->x_1 and x_0->[0] + mod in the case
-            # of wrapped longitudes. However if the src grid is not global
-            # (i.e. circular) this new cell would include a region outside of
-            # the extent of the src grid and should therefore be masked.
-            outside_extent = x_0 > x_1 and not circular
-            if (outside_extent or not y_within_bounds[j] or not
-                    x_within_bounds[i]):
-                # Mask out element(s) in new_data
-                if x_dim is not None:
-                    indices[x_dim] = i
-                if y_dim is not None:
-                    indices[y_dim] = j
-                new_data[tuple(indices)] = ma.masked
+    axis = (y_dim, x_dim)
+
+    # Stack the src_area data and weights for each target point
+    target_pt_ji = -1
+    for j, y_indices in enumerate(cached_y_indices):
+        for i, x_indices in enumerate(cached_x_indices):
+            target_pt_ji += 1
+            # Determine whether to mask element i, j based on whether
+            # there are valid weights.
+            weights = cached_weights[j][i]
+            if isinstance(weights, bool) and not weights:
+                if not src_masked:
+                    # Cheat! Fill the data with zeros and weights as one.
+                    # The weighted average result will be the same, but
+                    # we avoid dividing by zero.
+                    src_area_weights[..., target_pt_ji] = 1
+                    new_data_mask[..., j, i] = True
             else:
                 # Calculate weighted mean of data points.
                 # Slice out relevant data (this may or may not be a view()
                 # depending on x_indices being a slice or not).
-                if isinstance(x_indices, tuple) and isinstance(
-                    y_indices, tuple
-                ):
-                    raise RuntimeError(
-                        "Cannot handle split bounds " "in both x and y."
-                    )
-                # Calculate weights based on areas of cropped bounds.
-                weights = area_func(y_bounds, x_bounds)
-
-                if x_dim is not None:
-                    indices[x_dim] = x_indices
-                if y_dim is not None:
-                    indices[y_dim] = y_indices
-                data = src_data[tuple(indices)]
-
-                # Transpose weights to match dim ordering in data.
-                weights_shape_y = weights.shape[0]
-                weights_shape_x = weights.shape[1]
-                if x_dim is not None and y_dim is not None and x_dim < y_dim:
-                    weights = weights.T
-                # Broadcast the weights array to allow numpy's ma.average
-                # to be called.
-                weights_padded_shape = [1] * data.ndim
-                if y_dim is not None:
-                    weights_padded_shape[y_dim] = weights_shape_y
-                if x_dim is not None:
-                    weights_padded_shape[x_dim] = weights_shape_x
-                # Assign new shape to raise error on copy.
-                weights.shape = weights_padded_shape
-                # Broadcast weights to match shape of data.
-                _, weights = np.broadcast_arrays(data, weights)
-
-                # Calculate weighted mean taking into account missing data.
-                new_data_pt = _weighted_mean_with_mdtol(
-                    data, weights=weights, axis=axis, mdtol=mdtol)
-
-                # Insert data (and mask) values into new array.
-                if x_dim is not None:
-                    indices[x_dim] = i
-                if y_dim is not None:
-                    indices[y_dim] = j
-                new_data[tuple(indices)] = new_data_pt
-
-    # Remove new mask if original data was not masked
-    # and no values in the new array are masked.
-    if not src_masked and not new_data.mask.any():
-        new_data = new_data.data
+                data = src_data[..., y_indices, x_indices]
+                len_x = data.shape[-1]
+                len_y = data.shape[-2]
+                src_area_datas[..., 0:len_y, 0:len_x, target_pt_ji] = data
+                src_area_weights[0:len_y, 0:len_x, target_pt_ji] = weights
+                if src_masked:
+                    src_area_masks[
+                        ..., 0:len_y, 0:len_x, target_pt_ji
+                    ] = data.mask
+
+    # Broadcast the weights array to allow numpy's ma.average
+    # to be called.
+    # Assign new shape to raise error on copy.
+    src_area_weights.shape = src_area_datas.shape[-3:]
+    # Broadcast weights to match shape of data.
+    _, src_area_weights = np.broadcast_arrays(src_area_datas, src_area_weights)
+
+    # Mask the data points
+    if src_masked:
+        src_area_datas = np.ma.array(src_area_datas, mask=src_area_masks)
+
+    # Calculate weighted mean taking into account missing data.
+    new_data = _weighted_mean_with_mdtol(
+        src_area_datas, weights=src_area_weights, axis=axis, mdtol=mdtol
+    )
+    new_data = new_data.reshape(new_shape)
+    if src_masked:
+        new_data_mask = new_data.mask
+
+    # Mask the data if originally masked or if the result has masked points
+    if ma.isMaskedArray(src_data):
+        new_data = ma.array(
+            new_data,
+            mask=new_data_mask,
+            fill_value=src_data.fill_value,
+            dtype=dtype,
+        )
+    elif new_data_mask.any():
+        new_data = ma.array(new_data, mask=new_data_mask, dtype=dtype)
+    else:
+        new_data = new_data.astype(dtype)
+
+    # Restore data to original form
+    if x_dim_orig is None and y_dim_orig is None:
+        new_data = np.squeeze(new_data, axis=x_dim)
+        new_data = np.squeeze(new_data, axis=y_dim)
+    elif y_dim_orig is None:
+        new_data = np.squeeze(new_data, axis=y_dim)
+        new_data = np.moveaxis(new_data, -1, x_dim_orig)
+    elif x_dim_orig is None:
+        new_data = np.squeeze(new_data, axis=x_dim)
+        new_data = np.moveaxis(new_data, -1, y_dim_orig)
+    elif x_dim_orig < y_dim_orig:
+        # move the x_dim back first, so that the y_dim will
+        # then be moved to its original position
+        new_data = np.moveaxis(new_data, -1, x_dim_orig)
+        new_data = np.moveaxis(new_data, -1, y_dim_orig)
+    else:
+        # move the y_dim back first, so that the x_dim will
+        # then be moved to its original position
+        new_data = np.moveaxis(new_data, -2, y_dim_orig)
+        new_data = np.moveaxis(new_data, -1, x_dim_orig)
 
     return new_data
 
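The new code path above also pushes the y and x dimensions to the last two axes before stacking and averaging, then undoes that shuffle at the end. The following is a small standalone NumPy illustration of that round trip; the (t, x, y, z) layout and the array shapes are made up for the example and are not taken from the commit, and the regridding step itself is skipped.

    import numpy as np

    # Assumed source layout (t, x, y, z): x_dim=1, y_dim=2 are the horizontal dims.
    data = np.arange(2 * 4 * 5 * 3).reshape(2, 4, 5, 3)
    x_dim, y_dim = 1, 2

    # Push x to the last axis, then y to the second-to-last axis. When
    # x_dim < y_dim, y has already shifted left by one after the first move.
    work = np.moveaxis(data, x_dim, -1)
    work = np.moveaxis(work, y_dim - 1 if x_dim < y_dim else y_dim, -2)
    assert work.shape == (2, 3, 5, 4)   # (..., y, x)

    # ... the regridding itself would replace the trailing (y, x) axes here ...

    # Restore the original order: move x back first so that y then lands in its
    # original position (mirrors the x_dim_orig < y_dim_orig branch above).
    restored = np.moveaxis(work, -1, x_dim)
    restored = np.moveaxis(restored, -1, y_dim)
    assert np.array_equal(restored, data)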