PI-2472: Tweak area weighting regrid move averaging out of loop #3596
Changes from all commits: 264d33b, 0792ad1, 7d99492, b64d6f7, 852c34d, 2d13339, cd31d53
```diff
@@ -473,136 +473,245 @@ def _regrid_area_weighted_array(
         grid.

     """
-    # Determine which grid bounds are within src extent.
-    y_within_bounds = _within_bounds(
-        src_y_bounds, grid_y_bounds, grid_y_decreasing
-    )
-    x_within_bounds = _within_bounds(
-        src_x_bounds, grid_x_bounds, grid_x_decreasing
-    )
-
-    # Cache which src_bounds are within grid bounds
-    cached_x_bounds = []
-    cached_x_indices = []
-    for (x_0, x_1) in grid_x_bounds:
-        if grid_x_decreasing:
-            x_0, x_1 = x_1, x_0
-        x_bounds, x_indices = _cropped_bounds(src_x_bounds, x_0, x_1)
-        cached_x_bounds.append(x_bounds)
-        cached_x_indices.append(x_indices)
+    def _calculate_regrid_area_weighted_weights(
+        src_x_bounds,
+        src_y_bounds,
+        grid_x_bounds,
+        grid_y_bounds,
+        grid_x_decreasing,
+        grid_y_decreasing,
+        area_func,
+        circular=False,
+    ):
+        """
+        Compute the area weights used for area-weighted regridding.
+
+        """
+        # Determine which grid bounds are within src extent.
+        y_within_bounds = _within_bounds(
+            src_y_bounds, grid_y_bounds, grid_y_decreasing
+        )
+        x_within_bounds = _within_bounds(
+            src_x_bounds, grid_x_bounds, grid_x_decreasing
+        )
+
+        # Cache which src_bounds are within grid bounds
+        cached_x_bounds = []
+        cached_x_indices = []
+        max_x_indices = 0
+        for (x_0, x_1) in grid_x_bounds:
+            if grid_x_decreasing:
+                x_0, x_1 = x_1, x_0
+            x_bounds, x_indices = _cropped_bounds(src_x_bounds, x_0, x_1)
+            cached_x_bounds.append(x_bounds)
+            cached_x_indices.append(x_indices)
+            # Keep record of the largest slice
+            if isinstance(x_indices, slice):
+                x_indices_size = np.sum(x_indices.stop - x_indices.start)
+            else:  # is tuple of indices
+                x_indices_size = len(x_indices)
+            if x_indices_size > max_x_indices:
+                max_x_indices = x_indices_size
+
+        # Cache which y src_bounds areas and weights are within grid bounds
+        cached_y_indices = []
+        cached_weights = []
+        max_y_indices = 0
+        for j, (y_0, y_1) in enumerate(grid_y_bounds):
+            # Reverse lower and upper if dest grid is decreasing.
+            if grid_y_decreasing:
+                y_0, y_1 = y_1, y_0
+            y_bounds, y_indices = _cropped_bounds(src_y_bounds, y_0, y_1)
+            cached_y_indices.append(y_indices)
+            # Keep record of the largest slice
+            if isinstance(y_indices, slice):
+                y_indices_size = np.sum(y_indices.stop - y_indices.start)
+            else:  # is tuple of indices
+                y_indices_size = len(y_indices)
+            if y_indices_size > max_y_indices:
+                max_y_indices = y_indices_size
+
+            weights_i = []
+            for i, (x_0, x_1) in enumerate(grid_x_bounds):
+                # Reverse lower and upper if dest grid is decreasing.
+                if grid_x_decreasing:
+                    x_0, x_1 = x_1, x_0
+                x_bounds = cached_x_bounds[i]
+                x_indices = cached_x_indices[i]
+
+                # Determine whether element i, j overlaps with src and hence
+                # an area weight should be computed.
+                # If x_0 > x_1 then we want [0]->x_1 and x_0->[0] + mod in
+                # the case of wrapped longitudes. However if the src grid is
+                # not global (i.e. circular) this new cell would include a
+                # region outside of the extent of the src grid and thus the
+                # weight is therefore invalid.
+                outside_extent = x_0 > x_1 and not circular
+                if (
+                    outside_extent
+                    or not y_within_bounds[j]
+                    or not x_within_bounds[i]
+                ):
+                    weights = False
+                else:
+                    # Calculate weights based on areas of cropped bounds.
+                    if isinstance(x_indices, tuple) and isinstance(
+                        y_indices, tuple
+                    ):
+                        raise RuntimeError(
+                            "Cannot handle split bounds in both x and y."
+                        )
+                    weights = area_func(y_bounds, x_bounds)
+                weights_i.append(weights)
+            cached_weights.append(weights_i)
+        return (
+            tuple(cached_x_indices),
+            tuple(cached_y_indices),
+            max_x_indices,
+            max_y_indices,
+            tuple(cached_weights),
+        )
+
+    weights_info = _calculate_regrid_area_weighted_weights(
+        src_x_bounds,
+        src_y_bounds,
+        grid_x_bounds,
+        grid_y_bounds,
+        grid_x_decreasing,
+        grid_y_decreasing,
+        area_func,
+        circular,
+    )
+    (
+        cached_x_indices,
+        cached_y_indices,
+        max_x_indices,
+        max_y_indices,
+        cached_weights,
+    ) = weights_info
+    # Delete variables that are not needed and would not be available
+    # if _calculate_regrid_area_weighted_weights was refactored further
+    del src_x_bounds, src_y_bounds, grid_x_bounds, grid_y_bounds
+    del grid_x_decreasing, grid_y_decreasing
+    del area_func, circular
```
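The core idea of the refactor, computing every target cell's source indices and area weights once and then reusing them for all passes over the data, can be seen in miniature below. This is a standalone sketch under simplifying assumptions (1-D, increasing, non-circular bounds, and overlap length standing in for cell area); `overlap_weights` and the toy bounds are hypothetical illustrations, not iris API.

```python
import numpy as np

def overlap_weights(src_bounds, grid_bounds):
    """For each target cell, find the src cells that overlap it and the
    overlap lengths to use as weights (1-D, increasing, non-circular)."""
    cached = []
    for g0, g1 in grid_bounds:
        lower = np.maximum(src_bounds[:, 0], g0)
        upper = np.minimum(src_bounds[:, 1], g1)
        overlap = upper - lower
        (indices,) = np.nonzero(overlap > 0)
        cached.append((indices, overlap[indices]))
    return cached

# Toy case: four src cells of width 1 regridded onto two cells of width 2.
src_bounds = np.column_stack([np.arange(4.0), np.arange(1.0, 5.0)])
grid_bounds = np.array([[0.0, 2.0], [2.0, 4.0]])
cache = overlap_weights(src_bounds, grid_bounds)  # computed once ...

data = np.array([1.0, 2.0, 3.0, 4.0])
# ... then reused for every data array defined on the same pair of grids.
means = [np.average(data[idx], weights=w) for idx, w in cache]
print(means)  # [1.5, 3.5]
```

Hoisting this work out of the per-cell loop is what allows the averaging itself, later in the diff, to collapse into a single vectorised call.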
```diff
     # Ensure we have x_dim and y_dim.
-    x_dim_orig = copy.copy(x_dim)
-    y_dim_orig = copy.copy(y_dim)
+    x_dim_orig = x_dim
+    y_dim_orig = y_dim
     if y_dim is None:
         src_data = np.expand_dims(src_data, axis=src_data.ndim)
         y_dim = src_data.ndim - 1
     if x_dim is None:
         src_data = np.expand_dims(src_data, axis=src_data.ndim)
         x_dim = src_data.ndim - 1
     # Move y_dim and x_dim to last dimensions
-    src_data = np.moveaxis(src_data, x_dim, -1)
-    if x_dim < y_dim:
-        src_data = np.moveaxis(src_data, y_dim - 1, -2)
-    elif x_dim > y_dim:
-        src_data = np.moveaxis(src_data, y_dim, -2)
+    if not x_dim == src_data.ndim - 1:
+        src_data = np.moveaxis(src_data, x_dim, -1)
+    if not y_dim == src_data.ndim - 2:
+        if x_dim < y_dim:
+            # note: y_dim was shifted along by one position when
+            # x_dim was moved to the last dimension
+            src_data = np.moveaxis(src_data, y_dim - 1, -2)
+        elif x_dim > y_dim:
+            src_data = np.moveaxis(src_data, y_dim, -2)
     x_dim = src_data.ndim - 1
     y_dim = src_data.ndim - 2
```
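The guarded `np.moveaxis` calls push y and x to the last two axes, and the new inline note records the subtlety: moving x first shifts y's index by one whenever x originally preceded y. A toy shape check (the 5-D layout here is hypothetical):

```python
import numpy as np

# Hypothetical 5-D source laid out as (a, x, b, y, c): x_dim=1, y_dim=3.
src = np.empty((6, 4, 5, 3, 2))
x_dim, y_dim = 1, 3

src = np.moveaxis(src, x_dim, -1)      # (6, 5, 3, 2, 4): x is now last
# Moving x (which sat before y) shifted y left by one, to y_dim - 1 == 2.
src = np.moveaxis(src, y_dim - 1, -2)  # (6, 5, 2, 3, 4): y is now second-last
assert src.shape == (6, 5, 2, 3, 4)
```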
```diff
-    # Create empty data array to match the new grid.
-    # Note that dtype is not preserved and that the array is
-    # masked to allow for regions that do not overlap.
+    # Create empty "pre-averaging" data array that will enable the
+    # src_data corresponding to a given target grid point
+    # to be stacked per point.
+    # Note that dtype is not preserved and that the array mask
+    # allows for regions that do not overlap.
     new_shape = list(src_data.shape)
-    new_shape[x_dim] = grid_x_bounds.shape[0]
-    new_shape[y_dim] = grid_y_bounds.shape[0]
+    new_shape[x_dim] = len(cached_x_indices)
+    new_shape[y_dim] = len(cached_y_indices)
+    num_target_pts = len(cached_y_indices) * len(cached_x_indices)
+    src_areas_shape = list(src_data.shape)
+    src_areas_shape[y_dim] = max_y_indices
+    src_areas_shape[x_dim] = max_x_indices
+    src_areas_shape += [num_target_pts]
     # Use input cube dtype or convert values to the smallest possible float
     # dtype when necessary.
     dtype = np.promote_types(src_data.dtype, np.float16)
+    # Create empty arrays to hold src_data per target point, and weights
+    src_area_datas = np.zeros(src_areas_shape, dtype=np.float64)
+    src_area_weights = np.zeros(
+        list((max_y_indices, max_x_indices, num_target_pts))
+    )
```
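To make a single vectorised average possible, every target cell gets a fixed-size window of at most `max_y_indices` by `max_x_indices` source cells, stacked along a new trailing axis of length `num_target_pts`. A small shape walk-through with made-up sizes:

```python
import numpy as np

# Hypothetical sizes: src data (time=5, y=10, x=12), a 3 x 4 target grid,
# and at most 4 x 5 src cells contributing to any single target cell.
src_shape = [5, 10, 12]
y_dim, x_dim = 1, 2
max_y_indices, max_x_indices = 4, 5
num_target_pts = 3 * 4

src_areas_shape = list(src_shape)
src_areas_shape[y_dim] = max_y_indices
src_areas_shape[x_dim] = max_x_indices
src_areas_shape += [num_target_pts]
print(src_areas_shape)  # [5, 4, 5, 12]: one padded window per target point
```

Windows smaller than the maximum stay zero-padded, which is harmless because the corresponding weights are also initialised to zero.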
```diff
     # Flag to indicate whether the original data was a masked array.
-    src_masked = ma.isMaskedArray(src_data)
+    src_masked = src_data.mask.any() if ma.isMaskedArray(src_data) else False
```
Review thread on the `src_masked` change:
- Reviewer: There's one test that fulfils this condition.
- Reviewer: However, there isn't a test that uses a masked array but that doesn't have any masked points.
- Author: I've added an extra test.
- Reviewer: Looks good, thanks!
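The thread above turns on the difference between an array merely being a masked type and actually containing masked points, which is exactly what the new `src_masked` expression distinguishes. A quick standalone check:

```python
import numpy as np
import numpy.ma as ma

a = ma.masked_array([1.0, 2.0, 3.0])                  # masked type, no masked points
b = ma.masked_array([1.0, 2.0, 3.0], mask=[0, 1, 0])  # one point masked

for arr in (a, b):
    src_masked = arr.mask.any() if ma.isMaskedArray(arr) else False
    print(ma.isMaskedArray(arr), src_masked)
# True False   <- the case the extra test was added for
# True True
```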
```diff
     if src_masked:
-        new_data = ma.zeros(
-            new_shape, fill_value=src_data.fill_value, dtype=dtype
-        )
+        src_area_masks = np.full(src_areas_shape, True, dtype=np.bool)
     else:
-        new_data = ma.zeros(new_shape, dtype=dtype)
-    # Assign to mask to explode it, allowing indexed assignment.
-    new_data.mask = False
+        new_data_mask = np.full(new_shape, False, dtype=np.bool)

     # Axes of data over which the weighted mean is calculated.
     axis = (y_dim, x_dim)

-    # Simple for loop approach.
-    for j, (y_0, y_1) in enumerate(grid_y_bounds):
-        # Reverse lower and upper if dest grid is decreasing.
-        if grid_y_decreasing:
-            y_0, y_1 = y_1, y_0
-        y_bounds, y_indices = _cropped_bounds(src_y_bounds, y_0, y_1)
-        for i, (x_0, x_1) in enumerate(grid_x_bounds):
-            # Reverse lower and upper if dest grid is decreasing.
-            if grid_x_decreasing:
-                x_0, x_1 = x_1, x_0
-            x_bounds = cached_x_bounds[i]
-            x_indices = cached_x_indices[i]
-
-            # Determine whether to mask element i, j based on overlap with
-            # src.
-            # If x_0 > x_1 then we want [0]->x_1 and x_0->[0] + mod in the
-            # case of wrapped longitudes. However if the src grid is not
-            # global (i.e. circular) this new cell would include a region
-            # outside of the extent of the src grid and should therefore be
-            # masked.
-            outside_extent = x_0 > x_1 and not circular
-            if (
-                outside_extent
-                or not y_within_bounds[j]
-                or not x_within_bounds[i]
-            ):
-                # Mask out element(s) in new_data
-                new_data[..., j, i] = ma.masked
+    # Stack the src_area data and weights for each target point
+    target_pt_ji = -1
+    for j, y_indices in enumerate(cached_y_indices):
+        for i, x_indices in enumerate(cached_x_indices):
+            target_pt_ji += 1
+            # Determine whether to mask element i, j based on whether
+            # there are valid weights.
+            weights = cached_weights[j][i]
+            if isinstance(weights, bool) and not weights:
+                if not src_masked:
+                    # Cheat! Fill the data with zeros and weights as one.
+                    # The weighted average result will be the same, but
+                    # we avoid dividing by zero.
+                    src_area_weights[..., target_pt_ji] = 1
+                new_data_mask[..., j, i] = True
             else:
-                # Calculate weighted mean of data points.
                 # Slice out relevant data (this may or may not be a view()
                 # depending on x_indices being a slice or not).
-                if isinstance(x_indices, tuple) and isinstance(
-                    y_indices, tuple
-                ):
-                    raise RuntimeError(
-                        "Cannot handle split bounds in both x and y."
-                    )
-                # Calculate weights based on areas of cropped bounds.
-                weights = area_func(y_bounds, x_bounds)
-
                 data = src_data[..., y_indices, x_indices]
+                len_x = data.shape[-1]
+                len_y = data.shape[-2]
+                src_area_datas[..., 0:len_y, 0:len_x, target_pt_ji] = data
+                src_area_weights[0:len_y, 0:len_x, target_pt_ji] = weights
+                if src_masked:
+                    src_area_masks[
+                        ..., 0:len_y, 0:len_x, target_pt_ji
+                    ] = data.mask
```
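The "Cheat!" for non-overlapping cells, zero data with unit weights, keeps the vectorised average well defined and defers masking to `new_data_mask`. A minimal standalone illustration (the shapes and values are made up, and iris's `_weighted_mean_with_mdtol` is replaced by a plain `np.average`):

```python
import numpy as np
import numpy.ma as ma

# Three stacked 2 x 2 windows; target point 2 has no overlap with the src grid.
datas = np.zeros((2, 2, 3))
weights = np.zeros((2, 2, 3))
datas[..., 0] = [[1.0, 2.0], [3.0, 4.0]]
weights[..., 0] = 1.0
datas[..., 1] = [[10.0, 10.0], [0.0, 0.0]]
weights[..., 1] = [[0.5, 0.5], [0.0, 0.0]]
# The "cheat": zero data with weight one gives a well-defined (zero) average
# instead of a 0/0, and the point is masked afterwards instead.
weights[..., 2] = 1.0

means = np.average(datas.reshape(-1, 3), weights=weights.reshape(-1, 3), axis=0)
result = ma.masked_array(means, mask=[False, False, True])
print(result)  # [2.5 10.0 --]
```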
```diff
+    # Broadcast the weights array to allow numpy's ma.average
+    # to be called.
+    # Assign new shape to raise error on copy.
+    src_area_weights.shape = src_area_datas.shape[-3:]
+    # Broadcast weights to match shape of data.
+    _, src_area_weights = np.broadcast_arrays(src_area_datas, src_area_weights)
+
+    # Mask the data points
+    if src_masked:
+        src_area_datas = np.ma.array(src_area_datas, mask=src_area_masks)
```
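`np.broadcast_arrays` stretches the (max_y_indices, max_x_indices, num_target_pts) weights across any leading data dimensions as a zero-stride view rather than a copy. A small check with made-up sizes:

```python
import numpy as np

datas = np.zeros((5, 4, 3, 7))   # e.g. (time, max_y, max_x, num_target_pts)
weights = np.ones((4, 3, 7))     # one weight per window cell per target point

_, weights_b = np.broadcast_arrays(datas, weights)
print(weights_b.shape)    # (5, 4, 3, 7)
print(weights_b.strides)  # (0, 168, 56, 8): zero stride along time, no copy
```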
```diff
-            # Transpose weights to match dim ordering in data.
-            weights_shape_y = weights.shape[0]
-            weights_shape_x = weights.shape[1]
-            # Broadcast the weights array to allow numpy's ma.average
-            # to be called.
-            weights_padded_shape = [1] * data.ndim
-            weights_padded_shape[y_dim] = weights_shape_y
-            weights_padded_shape[x_dim] = weights_shape_x
-            # Assign new shape to raise error on copy.
-            weights.shape = weights_padded_shape
-            # Broadcast weights to match shape of data.
-            _, weights = np.broadcast_arrays(data, weights)
-
-            # Calculate weighted mean taking into account missing data.
-            new_data_pt = _weighted_mean_with_mdtol(
-                data, weights=weights, axis=axis, mdtol=mdtol
-            )
-
-            # Insert data (and mask) values into new array.
-            new_data[..., j, i] = new_data_pt
-
-    # Remove new mask if original data was not masked
-    # and no values in the new array are masked.
-    if not src_masked and not new_data.mask.any():
-        new_data = new_data.data
+    # Calculate weighted mean taking into account missing data.
+    new_data = _weighted_mean_with_mdtol(
+        src_area_datas, weights=src_area_weights, axis=axis, mdtol=mdtol
+    )
+    new_data = new_data.reshape(new_shape)
+    if src_masked:
+        new_data_mask = new_data.mask
```
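Because the target points were stacked j-major (y outer, x inner), reshaping the flat per-point means with `new_shape` recovers the (..., y, x) grid directly. A toy demonstration with hypothetical sizes:

```python
import numpy as np

# Hypothetical: 2 leading (e.g. time) slices, 6 target points from a 2 x 3 grid.
flat = np.arange(12.0).reshape(2, 6)  # stands in for the per-point means
new_shape = (2, 2, 3)                 # (time, grid_y, grid_x)
grid = flat.reshape(new_shape)
# j-major stacking: grid row j holds flat columns j*3 .. j*3+2.
assert (grid[0, 1] == flat[0, 3:6]).all()
```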
```diff
+    # Mask the data if originally masked or if the result has masked points
+    if ma.isMaskedArray(src_data):
+        new_data = ma.array(
+            new_data,
+            mask=new_data_mask,
+            fill_value=src_data.fill_value,
+            dtype=dtype,
+        )
+    elif new_data_mask.any():
+        new_data = ma.array(new_data, mask=new_data_mask, dtype=dtype)
+    else:
+        new_data = new_data.astype(dtype)

-    # Restore axis to original order
+    # Restore data to original form
     if x_dim_orig is None and y_dim_orig is None:
         new_data = np.squeeze(new_data, axis=x_dim)
         new_data = np.squeeze(new_data, axis=y_dim)

@@ -613,9 +722,13 @@ def _regrid_area_weighted_array(
         new_data = np.squeeze(new_data, axis=x_dim)
         new_data = np.moveaxis(new_data, -1, y_dim_orig)
     elif x_dim_orig < y_dim_orig:
+        # move the x_dim back first, so that the y_dim will
+        # then be moved to its original position
         new_data = np.moveaxis(new_data, -1, x_dim_orig)
         new_data = np.moveaxis(new_data, -1, y_dim_orig)
     else:
+        # move the y_dim back first, so that the x_dim will
+        # then be moved to its original position
         new_data = np.moveaxis(new_data, -2, y_dim_orig)
         new_data = np.moveaxis(new_data, -1, x_dim_orig)
```
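The ordering constraint those new comments describe can be checked with a round trip: when x_dim_orig < y_dim_orig, x must be moved back first or y would land in the wrong slot. Using the same hypothetical 5-D layout as above:

```python
import numpy as np

# Hypothetical layout (a, x, b, y, c): x_dim_orig=1, y_dim_orig=3.
orig = np.random.rand(6, 4, 5, 3, 2)
work = np.moveaxis(orig, 1, -1)   # x to last          -> (6, 5, 3, 2, 4)
work = np.moveaxis(work, 2, -2)   # y to second-last   -> (6, 5, 2, 3, 4)

back = np.moveaxis(work, -1, 1)   # x back first       -> (6, 4, 5, 2, 3)
back = np.moveaxis(back, -1, 3)   # y, now last, back  -> (6, 4, 5, 3, 2)
assert back.shape == orig.shape and (back == orig).all()
```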
Review thread on the restore-order code:
- Reviewer: The last two loops can't be consolidated; add a comment explaining why :)
- Author: Done
Review thread on the `src_area_datas` allocation:
- Reviewer: Should `dtype` above be used rather than `np.float64`?
- Author: I think as it stands the code is using the numpy default value of np.float64, so I've used this here. I do think we should revisit the dtype handling in a follow-up PR, as I think it would be beneficial to improve it.
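For reference, `np.promote_types(src_dtype, np.float16)` behaves as the comment in the diff describes: float inputs are kept as-is, and small integer types are lifted to the smallest adequate float:

```python
import numpy as np

for src in (np.int8, np.int16, np.float16, np.float32, np.float64):
    print(np.dtype(src).name, "->", np.promote_types(src, np.float16).name)
# int8 -> float16, int16 -> float32, float16 -> float16,
# float32 -> float32, float64 -> float64
```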