diff --git a/doc/whats-new.rst b/doc/whats-new.rst index c70dfd4f3f6..8a771635399 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -41,7 +41,8 @@ New Features - Implement :py:meth:`DataArray.idxmax`, :py:meth:`DataArray.idxmin`, :py:meth:`Dataset.idxmax`, :py:meth:`Dataset.idxmin`. (:issue:`60`, :pull:`3871`) By `Todd Jennings `_ - +- Allow plotting of boolean arrays. (:pull:`3766`) + By `Marek Jacob `_ Bug fixes ~~~~~~~~~ diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 302cac05b05..71a05627692 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -689,7 +689,7 @@ def newplotfunc( xplt, xlab_extra = _resolve_intervals_2dplot(xval, plotfunc.__name__) yplt, ylab_extra = _resolve_intervals_2dplot(yval, plotfunc.__name__) - _ensure_plottable(xplt, yplt) + _ensure_plottable(xplt, yplt, zval) cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs( plotfunc, zval.data, **locals() diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index e6c15037cb8..5bf1382994b 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -534,7 +534,7 @@ def _ensure_plottable(*args): Raise exception if there is anything in args that can't be plotted on an axis by matplotlib. """ - numpy_types = [np.floating, np.integer, np.timedelta64, np.datetime64] + numpy_types = [np.floating, np.integer, np.timedelta64, np.datetime64, np.bool_] other_types = [datetime] try: import cftime @@ -549,10 +549,10 @@ def _ensure_plottable(*args): or _valid_other_type(np.array(x), other_types) ): raise TypeError( - "Plotting requires coordinates to be numeric " - "or dates of type np.datetime64, " + "Plotting requires coordinates to be numeric, boolean, " + "or dates of type numpy.datetime64, " "datetime.datetime, cftime.datetime or " - "pd.Interval." + f"pandas.Interval. Received data of type {np.array(x).dtype} instead." ) if ( _valid_other_type(np.array(x), cftime_datetime) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 7f3f1620133..21fc3a892d8 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -139,6 +139,12 @@ def test1d(self): with raises_regex(ValueError, "None"): self.darray[:, 0, 0].plot(x="dim_1") + with raises_regex(TypeError, "complex128"): + (self.darray[:, 0, 0] + 1j).plot() + + def test_1d_bool(self): + xr.ones_like(self.darray[:, 0, 0], dtype=np.bool).plot() + def test_1d_x_y_kw(self): z = np.arange(10) da = DataArray(np.cos(z), dims=["z"], coords=[z], name="f") @@ -989,6 +995,13 @@ def test_1d_raises_valueerror(self): with raises_regex(ValueError, r"DataArray must be 2d"): self.plotfunc(self.darray[0, :]) + def test_bool(self): + xr.ones_like(self.darray, dtype=np.bool).plot() + + def test_complex_raises_typeerror(self): + with raises_regex(TypeError, "complex128"): + (self.darray + 1j).plot() + def test_3d_raises_valueerror(self): a = DataArray(easy_array((2, 3, 4))) if self.plotfunc.__name__ == "imshow":