@@ -461,135 +461,264 @@ def _regrid_area_weighted_array(src_data, x_dim, y_dim,
461461 grid.
462462
463463 """
464- # Determine which grid bounds are within src extent.
465- y_within_bounds = _within_bounds (
466- src_y_bounds , grid_y_bounds , grid_y_decreasing
467- )
468- x_within_bounds = _within_bounds (
469- src_x_bounds , grid_x_bounds , grid_x_decreasing
470- )
471464
472- # Cache which src_bounds are within grid bounds
473- cached_x_bounds = []
474- cached_x_indices = []
475- for (x_0 , x_1 ) in grid_x_bounds :
476- if grid_x_decreasing :
477- x_0 , x_1 = x_1 , x_0
478- x_bounds , x_indices = _cropped_bounds (src_x_bounds , x_0 , x_1 )
479- cached_x_bounds .append (x_bounds )
480- cached_x_indices .append (x_indices )
481-
482- # Create empty data array to match the new grid.
483- # Note that dtype is not preserved and that the array is
484- # masked to allow for regions that do not overlap.
485- new_shape = list (src_data .shape )
486- if x_dim is not None :
487- new_shape [x_dim ] = grid_x_bounds .shape [0 ]
488- if y_dim is not None :
489- new_shape [y_dim ] = grid_y_bounds .shape [0 ]
465+ def _calculate_regrid_area_weighted_weights (
466+ src_x_bounds ,
467+ src_y_bounds ,
468+ grid_x_bounds ,
469+ grid_y_bounds ,
470+ grid_x_decreasing ,
471+ grid_y_decreasing ,
472+ area_func ,
473+ circular = False ,
474+ ):
475+ """
476+ Compute the area weights used for area-weighted regridding.
490477
478+ """
479+ # Determine which grid bounds are within src extent.
480+ y_within_bounds = _within_bounds (
481+ src_y_bounds , grid_y_bounds , grid_y_decreasing
482+ )
483+ x_within_bounds = _within_bounds (
484+ src_x_bounds , grid_x_bounds , grid_x_decreasing
485+ )
486+
487+ # Cache which src_bounds are within grid bounds
488+ cached_x_bounds = []
489+ cached_x_indices = []
490+ max_x_indices = 0
491+ for (x_0 , x_1 ) in grid_x_bounds :
492+ if grid_x_decreasing :
493+ x_0 , x_1 = x_1 , x_0
494+ x_bounds , x_indices = _cropped_bounds (src_x_bounds , x_0 , x_1 )
495+ cached_x_bounds .append (x_bounds )
496+ cached_x_indices .append (x_indices )
497+ # Keep record of the largest slice
498+ if isinstance (x_indices , slice ):
499+ x_indices_size = np .sum (x_indices .stop - x_indices .start )
500+ else : # is tuple of indices
501+ x_indices_size = len (x_indices )
502+ if x_indices_size > max_x_indices :
503+ max_x_indices = x_indices_size
504+
505+ # Cache the y src_bounds indices and the area weights for each grid cell
506+ cached_y_indices = []
507+ cached_weights = []
508+ max_y_indices = 0
509+ for j , (y_0 , y_1 ) in enumerate (grid_y_bounds ):
510+ # Reverse lower and upper if dest grid is decreasing.
511+ if grid_y_decreasing :
512+ y_0 , y_1 = y_1 , y_0
513+ y_bounds , y_indices = _cropped_bounds (src_y_bounds , y_0 , y_1 )
514+ cached_y_indices .append (y_indices )
515+ # Keep record of the largest slice
516+ if isinstance (y_indices , slice ):
517+ y_indices_size = np .sum (y_indices .stop - y_indices .start )
518+ else : # is tuple of indices
519+ y_indices_size = len (y_indices )
520+ if y_indices_size > max_y_indices :
521+ max_y_indices = y_indices_size
522+
523+ weights_i = []
524+ for i , (x_0 , x_1 ) in enumerate (grid_x_bounds ):
525+ # Reverse lower and upper if dest grid is decreasing.
526+ if grid_x_decreasing :
527+ x_0 , x_1 = x_1 , x_0
528+ x_bounds = cached_x_bounds [i ]
529+ x_indices = cached_x_indices [i ]
530+
531+ # Determine whether element i, j overlaps with src and hence
532+ # an area weight should be computed.
533+ # If x_0 > x_1 then we want [0]->x_1 and x_0->[0] + mod in the case
534+ # of wrapped longitudes. However if the src grid is not global
535+ # (i.e. not circular) this new cell would include a region outside of
536+ # the extent of the src grid and the weight is therefore
537+ # invalid.
538+ outside_extent = x_0 > x_1 and not circular
539+ if (
540+ outside_extent
541+ or not y_within_bounds [j ]
542+ or not x_within_bounds [i ]
543+ ):
544+ weights = False
545+ else :
546+ # Calculate weights based on areas of cropped bounds.
547+ if isinstance (x_indices , tuple ) and isinstance (
548+ y_indices , tuple
549+ ):
550+ raise RuntimeError (
551+ "Cannot handle split bounds " "in both x and y."
552+ )
553+ weights = area_func (y_bounds , x_bounds )
554+ weights_i .append (weights )
555+ cached_weights .append (weights_i )
556+ return (
557+ tuple (cached_x_indices ),
558+ tuple (cached_y_indices ),
559+ max_x_indices ,
560+ max_y_indices ,
561+ tuple (cached_weights ),
562+ )
563+
564+ weights_info = _calculate_regrid_area_weighted_weights (
565+ src_x_bounds ,
566+ src_y_bounds ,
567+ grid_x_bounds ,
568+ grid_y_bounds ,
569+ grid_x_decreasing ,
570+ grid_y_decreasing ,
571+ area_func ,
572+ circular ,
573+ )
574+ (
575+ cached_x_indices ,
576+ cached_y_indices ,
577+ max_x_indices ,
578+ max_y_indices ,
579+ cached_weights ,
580+ ) = weights_info
581+ # Delete variables that are not needed and would not be available
582+ # if _calculate_regrid_area_weighted_weights was refactored further
583+ del src_x_bounds , src_y_bounds , grid_x_bounds , grid_y_bounds
584+ del grid_x_decreasing , grid_y_decreasing
585+ del area_func , circular
586+
587+ # Ensure we have x_dim and y_dim.
588+ x_dim_orig = x_dim
589+ y_dim_orig = y_dim
590+ if y_dim is None :
591+ src_data = np .expand_dims (src_data , axis = src_data .ndim )
592+ y_dim = src_data .ndim - 1
593+ if x_dim is None :
594+ src_data = np .expand_dims (src_data , axis = src_data .ndim )
595+ x_dim = src_data .ndim - 1
596+ # Move y_dim and x_dim to last dimensions
597+ if not x_dim == src_data .ndim - 1 :
598+ src_data = np .moveaxis (src_data , x_dim , - 1 )
599+ if not y_dim == src_data .ndim - 2 :
600+ if x_dim < y_dim :
601+ # note: y_dim was shifted along by one position when
602+ # x_dim was moved to the last dimension
603+ src_data = np .moveaxis (src_data , y_dim - 1 , - 2 )
604+ elif x_dim > y_dim :
605+ src_data = np .moveaxis (src_data , y_dim , - 2 )
606+ x_dim = src_data .ndim - 1
607+ y_dim = src_data .ndim - 2
608+
609+ # Create empty "pre-averaging" data array that will enable the
610+ # src_data data corresponding to a given target grid point,
611+ # to be stacked per point.
612+ # Note that dtype is not preserved and that the array mask
613+ # allows for regions that do not overlap.
614+ new_shape = list (src_data .shape )
615+ new_shape [x_dim ] = len (cached_x_indices )
616+ new_shape [y_dim ] = len (cached_y_indices )
617+ num_target_pts = len (cached_y_indices ) * len (cached_x_indices )
618+ src_areas_shape = list (src_data .shape )
619+ src_areas_shape [y_dim ] = max_y_indices
620+ src_areas_shape [x_dim ] = max_x_indices
621+ src_areas_shape += [num_target_pts ]
491622 # Use input cube dtype or convert values to the smallest possible float
492623 # dtype when necessary.
493624 dtype = np .promote_types (src_data .dtype , np .float16 )
625+ # Create empty arrays to hold src_data per target point, and weights
626+ src_area_datas = np .zeros (src_areas_shape , dtype = np .float64 )
627+ src_area_weights = np .zeros (
628+ list ((max_y_indices , max_x_indices , num_target_pts ))
629+ )
494630
495631 # Flag to indicate whether the original data was a masked array.
496- src_masked = ma .isMaskedArray (src_data )
632+ src_masked = src_data . mask . any () if ma .isMaskedArray (src_data ) else False
497633 if src_masked :
498- new_data = ma .zeros (new_shape , fill_value = src_data .fill_value ,
499- dtype = dtype )
634+ src_area_masks = np .full (src_areas_shape , True , dtype = np .bool )
500635 else :
501- new_data = ma .zeros (new_shape , dtype = dtype )
502- # Assign to mask to explode it, allowing indexed assignment.
503- new_data .mask = False
636+ new_data_mask = np .full (new_shape , False , dtype = np .bool )
504637
505638 # Axes of data over which the weighted mean is calculated.
506- axes = []
507- if y_dim is not None :
508- axes .append (y_dim )
509- if x_dim is not None :
510- axes .append (x_dim )
511- axis = tuple (axes )
512-
513- # Simple for loop approach.
514- indices = [slice (None )] * new_data .ndim
515- for j , (y_0 , y_1 ) in enumerate (grid_y_bounds ):
516- # Reverse lower and upper if dest grid is decreasing.
517- if grid_y_decreasing :
518- y_0 , y_1 = y_1 , y_0
519- y_bounds , y_indices = _cropped_bounds (src_y_bounds , y_0 , y_1 )
520- for i , (x_0 , x_1 ) in enumerate (grid_x_bounds ):
521- # Reverse lower and upper if dest grid is decreasing.
522- if grid_x_decreasing :
523- x_0 , x_1 = x_1 , x_0
524- x_bounds = cached_x_bounds [i ]
525- x_indices = cached_x_indices [i ]
526-
527- # Determine whether to mask element i, j based on overlap with
528- # src.
529- # If x_0 > x_1 then we want [0]->x_1 and x_0->[0] + mod in the case
530- # of wrapped longitudes. However if the src grid is not global
531- # (i.e. circular) this new cell would include a region outside of
532- # the extent of the src grid and should therefore be masked.
533- outside_extent = x_0 > x_1 and not circular
534- if (outside_extent or not y_within_bounds [j ] or not
535- x_within_bounds [i ]):
536- # Mask out element(s) in new_data
537- if x_dim is not None :
538- indices [x_dim ] = i
539- if y_dim is not None :
540- indices [y_dim ] = j
541- new_data [tuple (indices )] = ma .masked
639+ axis = (y_dim , x_dim )
640+
641+ # Stack the src_area data and weights for each target point
642+ target_pt_ji = - 1
643+ for j , y_indices in enumerate (cached_y_indices ):
644+ for i , x_indices in enumerate (cached_x_indices ):
645+ target_pt_ji += 1
646+ # Determine whether to mask element i, j based on whether
647+ # there are valid weights.
648+ weights = cached_weights [j ][i ]
649+ if isinstance (weights , bool ) and not weights :
650+ if not src_masked :
651+ # Cheat! Fill the data with zeros and weights as one.
652+ # The weighted average result will be the same, but
653+ # we avoid dividing by zero.
654+ src_area_weights [..., target_pt_ji ] = 1
655+ new_data_mask [..., j , i ] = True
542656 else :
543657 # Calculate weighted mean of data points.
544658 # Slice out relevant data (this may or may not be a view()
545659 # depending on x_indices being a slice or not).
546- if isinstance (x_indices , tuple ) and isinstance (
547- y_indices , tuple
548- ):
549- raise RuntimeError (
550- "Cannot handle split bounds " "in both x and y."
551- )
552- # Calculate weights based on areas of cropped bounds.
553- weights = area_func (y_bounds , x_bounds )
554-
555- if x_dim is not None :
556- indices [x_dim ] = x_indices
557- if y_dim is not None :
558- indices [y_dim ] = y_indices
559- data = src_data [tuple (indices )]
560-
561- # Transpose weights to match dim ordering in data.
562- weights_shape_y = weights .shape [0 ]
563- weights_shape_x = weights .shape [1 ]
564- if x_dim is not None and y_dim is not None and x_dim < y_dim :
565- weights = weights .T
566- # Broadcast the weights array to allow numpy's ma.average
567- # to be called.
568- weights_padded_shape = [1 ] * data .ndim
569- if y_dim is not None :
570- weights_padded_shape [y_dim ] = weights_shape_y
571- if x_dim is not None :
572- weights_padded_shape [x_dim ] = weights_shape_x
573- # Assign new shape to raise error on copy.
574- weights .shape = weights_padded_shape
575- # Broadcast weights to match shape of data.
576- _ , weights = np .broadcast_arrays (data , weights )
577-
578- # Calculate weighted mean taking into account missing data.
579- new_data_pt = _weighted_mean_with_mdtol (
580- data , weights = weights , axis = axis , mdtol = mdtol )
581-
582- # Insert data (and mask) values into new array.
583- if x_dim is not None :
584- indices [x_dim ] = i
585- if y_dim is not None :
586- indices [y_dim ] = j
587- new_data [tuple (indices )] = new_data_pt
588-
589- # Remove new mask if original data was not masked
590- # and no values in the new array are masked.
591- if not src_masked and not new_data .mask .any ():
592- new_data = new_data .data
660+ data = src_data [..., y_indices , x_indices ]
661+ len_x = data .shape [- 1 ]
662+ len_y = data .shape [- 2 ]
663+ src_area_datas [..., 0 :len_y , 0 :len_x , target_pt_ji ] = data
664+ src_area_weights [0 :len_y , 0 :len_x , target_pt_ji ] = weights
665+ if src_masked :
666+ src_area_masks [
667+ ..., 0 :len_y , 0 :len_x , target_pt_ji
668+ ] = data .mask
669+
670+ # Broadcast the weights array to allow numpy's ma.average
671+ # to be called.
672+ # Assign new shape to raise error on copy.
673+ src_area_weights .shape = src_area_datas .shape [- 3 :]
674+ # Broadcast weights to match shape of data.
675+ _ , src_area_weights = np .broadcast_arrays (src_area_datas , src_area_weights )
676+
677+ # Mask the data points
678+ if src_masked :
679+ src_area_datas = np .ma .array (src_area_datas , mask = src_area_masks )
680+
681+ # Calculate weighted mean taking into account missing data.
682+ new_data = _weighted_mean_with_mdtol (
683+ src_area_datas , weights = src_area_weights , axis = axis , mdtol = mdtol
684+ )
685+ new_data = new_data .reshape (new_shape )
686+ if src_masked :
687+ new_data_mask = new_data .mask
688+
689+ # Mask the data if originally masked or if the result has masked points
690+ if ma .isMaskedArray (src_data ):
691+ new_data = ma .array (
692+ new_data ,
693+ mask = new_data_mask ,
694+ fill_value = src_data .fill_value ,
695+ dtype = dtype ,
696+ )
697+ elif new_data_mask .any ():
698+ new_data = ma .array (new_data , mask = new_data_mask , dtype = dtype )
699+ else :
700+ new_data = new_data .astype (dtype )
701+
702+ # Restore data to original form
703+ if x_dim_orig is None and y_dim_orig is None :
704+ new_data = np .squeeze (new_data , axis = x_dim )
705+ new_data = np .squeeze (new_data , axis = y_dim )
706+ elif y_dim_orig is None :
707+ new_data = np .squeeze (new_data , axis = y_dim )
708+ new_data = np .moveaxis (new_data , - 1 , x_dim_orig )
709+ elif x_dim_orig is None :
710+ new_data = np .squeeze (new_data , axis = x_dim )
711+ new_data = np .moveaxis (new_data , - 1 , y_dim_orig )
712+ elif x_dim_orig < y_dim_orig :
713+ # move the x_dim back first, so that the y_dim will
714+ # then be moved to its original position
715+ new_data = np .moveaxis (new_data , - 1 , x_dim_orig )
716+ new_data = np .moveaxis (new_data , - 1 , y_dim_orig )
717+ else :
718+ # move the y_dim back first, so that the x_dim will
719+ # then be moved to its original position
720+ new_data = np .moveaxis (new_data , - 2 , y_dim_orig )
721+ new_data = np .moveaxis (new_data , - 1 , x_dim_orig )
593722
594723 return new_data
595724
0 commit comments