| 
39 | 39 | from collections.abc import Iterable  | 
40 | 40 | import functools  | 
41 | 41 | from functools import wraps  | 
 | 42 | +import warnings  | 
42 | 43 | 
 
  | 
43 | 44 | import dask.array as da  | 
44 | 45 | import numpy as np  | 
@@ -591,7 +592,13 @@ def aggregate(self, data, axis, **kwargs):  | 
591 | 592 |             and result is not ma.masked  | 
592 | 593 |         ):  | 
593 | 594 |             fraction_not_missing = data.count(axis=axis) / data.shape[axis]  | 
594 |  | -            mask_update = 1 - mdtol > fraction_not_missing  | 
 | 595 | +            mask_update = np.array(1 - mdtol > fraction_not_missing)  | 
 | 596 | +            if np.array(result).ndim > mask_update.ndim:  | 
 | 597 | +                # call_func created trailing dimension.  | 
 | 598 | +                mask_update = np.broadcast_to(  | 
 | 599 | +                    mask_update.reshape(mask_update.shape + (1,)),  | 
 | 600 | +                    np.array(result).shape,  | 
 | 601 | +                )  | 
595 | 602 |             if ma.isMaskedArray(result):  | 
596 | 603 |                 result.mask = result.mask | mask_update  | 
597 | 604 |             else:  | 
@@ -720,6 +727,25 @@ def __init__(self, units_func=None, **kwargs):  | 
720 | 727 |             **kwargs,  | 
721 | 728 |         )  | 
722 | 729 | 
 
  | 
 | 730 | +    def _base_aggregate(self, data, axis, lazy, **kwargs):  | 
 | 731 | +        """  | 
 | 732 | +        Method to avoid duplication of checks in aggregate and lazy_aggregate.  | 
 | 733 | +        """  | 
 | 734 | +        msg = "{} aggregator requires the mandatory keyword argument {!r}."  | 
 | 735 | +        for arg in self._args:  | 
 | 736 | +            if arg not in kwargs:  | 
 | 737 | +                raise ValueError(msg.format(self.name(), arg))  | 
 | 738 | + | 
 | 739 | +        if kwargs.get("fast_percentile_method", False) and (  | 
 | 740 | +            kwargs.get("mdtol", 1) != 0  | 
 | 741 | +        ):  | 
 | 742 | +            kwargs["error_on_masked"] = True  | 
 | 743 | + | 
 | 744 | +        if lazy:  | 
 | 745 | +            return _Aggregator.lazy_aggregate(self, data, axis, **kwargs)  | 
 | 746 | +        else:  | 
 | 747 | +            return _Aggregator.aggregate(self, data, axis, **kwargs)  | 
 | 748 | + | 
723 | 749 |     def aggregate(self, data, axis, **kwargs):  | 
724 | 750 |         """  | 
725 | 751 |         Perform the percentile aggregation over the given data.  | 
@@ -755,12 +781,7 @@ def aggregate(self, data, axis, **kwargs):  | 
755 | 781 | 
  | 
756 | 782 |         """  | 
757 | 783 | 
 
  | 
758 |  | -        msg = "{} aggregator requires the mandatory keyword argument {!r}."  | 
759 |  | -        for arg in self._args:  | 
760 |  | -            if arg not in kwargs:  | 
761 |  | -                raise ValueError(msg.format(self.name(), arg))  | 
762 |  | - | 
763 |  | -        return _Aggregator.aggregate(self, data, axis, **kwargs)  | 
 | 784 | +        return self._base_aggregate(data, axis, lazy=False, **kwargs)  | 
764 | 785 | 
 
  | 
765 | 786 |     def lazy_aggregate(self, data, axis, **kwargs):  | 
766 | 787 |         """  | 
@@ -794,12 +815,7 @@ def lazy_aggregate(self, data, axis, **kwargs):  | 
794 | 815 | 
  | 
795 | 816 |         """  | 
796 | 817 | 
 
  | 
797 |  | -        msg = "{} aggregator requires the mandatory keyword argument {!r}."  | 
798 |  | -        for arg in self._args:  | 
799 |  | -            if arg not in kwargs:  | 
800 |  | -                raise ValueError(msg.format(self.name(), arg))  | 
801 |  | - | 
802 |  | -        return _Aggregator.lazy_aggregate(self, data, axis, **kwargs)  | 
 | 818 | +        return self._base_aggregate(data, axis, lazy=True, **kwargs)  | 
803 | 819 | 
 
  | 
804 | 820 |     def post_process(self, collapsed_cube, data_result, coords, **kwargs):  | 
805 | 821 |         """  | 
@@ -1281,10 +1297,19 @@ def _calc_percentile(data, percent, fast_percentile_method=False, **kwargs):  | 
1281 | 1297 | 
  | 
1282 | 1298 |     """  | 
1283 | 1299 |     if fast_percentile_method:  | 
1284 |  | -        msg = "Cannot use fast np.percentile method with masked array."  | 
1285 |  | -        if ma.is_masked(data):  | 
1286 |  | -            raise TypeError(msg)  | 
1287 |  | -        result = np.percentile(data, percent, axis=-1)  | 
 | 1300 | +        if kwargs.pop("error_on_masked", False):  | 
 | 1301 | +            msg = (  | 
 | 1302 | +                "Cannot use fast np.percentile method with masked array unless"  | 
 | 1303 | +                " mdtol is 0."  | 
 | 1304 | +            )  | 
 | 1305 | +            if ma.is_masked(data):  | 
 | 1306 | +                raise TypeError(msg)  | 
 | 1307 | +        with warnings.catch_warnings():  | 
 | 1308 | +            warnings.filterwarnings(  | 
 | 1309 | +                "ignore",  | 
 | 1310 | +                "Warning: 'partition' will ignore the 'mask' of the MaskedArray.",  | 
 | 1311 | +            )  | 
 | 1312 | +            result = np.percentile(data, percent, axis=-1)  | 
1288 | 1313 |         result = result.T  | 
1289 | 1314 |     else:  | 
1290 | 1315 |         quantiles = percent / 100.0  | 
@@ -1965,7 +1990,8 @@ def interp_order(length):  | 
1965 | 1990 | * fast_percentile_method (boolean):  | 
1966 | 1991 |     When set to True, uses :func:`numpy.percentile` method as a faster  | 
1967 | 1992 |     alternative to the :func:`scipy.stats.mstats.mquantiles` method.  alphap and  | 
1968 |  | -    betap are ignored. An exception is raised if the data are masked.  | 
 | 1993 | +    betap are ignored. An exception is raised if the data are masked and the  | 
 | 1994 | +    missing data tolerance is not 0.  | 
1969 | 1995 |     Defaults to False.  | 
1970 | 1996 | 
  | 
1971 | 1997 | **For example**:  | 
 | 
0 commit comments