From 7d1fc5b5e997e8defdeb85c298dcc2af5d4dff1f Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Sun, 25 Sep 2022 21:48:56 +0200 Subject: [PATCH 1/5] fix utils.get_axis with kwargs --- xarray/plot/utils.py | 30 ++++++++++++++------- xarray/tests/test_plot.py | 56 ++++++++++++++++++++++++++++++++++----- 2 files changed, 70 insertions(+), 16 deletions(-) diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index f106d56689c..11bd66a6945 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -30,6 +30,8 @@ if TYPE_CHECKING: + from matplotlib.axes import Axes + from ..core.dataarray import DataArray @@ -423,7 +425,13 @@ def _assert_valid_xy(darray: DataArray, xy: None | Hashable, name: str) -> None: ) -def get_axis(figsize=None, size=None, aspect=None, ax=None, **kwargs): +def get_axis( + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + ax: Axes | None = None, + **subplot_kws: Any, +) -> Axes: try: import matplotlib as mpl import matplotlib.pyplot as plt @@ -435,28 +443,32 @@ def get_axis(figsize=None, size=None, aspect=None, ax=None, **kwargs): raise ValueError("cannot provide both `figsize` and `ax` arguments") if size is not None: raise ValueError("cannot provide both `figsize` and `size` arguments") - _, ax = plt.subplots(figsize=figsize) - elif size is not None: + _, ax = plt.subplots(figsize=figsize, subplot_kw=subplot_kws) + return ax + + if size is not None: if ax is not None: raise ValueError("cannot provide both `size` and `ax` arguments") if aspect is None: width, height = mpl.rcParams["figure.figsize"] aspect = width / height figsize = (size * aspect, size) - _, ax = plt.subplots(figsize=figsize) - elif aspect is not None: + _, ax = plt.subplots(figsize=figsize, subplot_kw=subplot_kws) + return ax + + if aspect is not None: raise ValueError("cannot provide `aspect` argument without `size`") - if kwargs and ax is not None: + if subplot_kws and ax is not None: raise ValueError("cannot use subplot_kws with existing ax") if ax is None: - ax = _maybe_gca(**kwargs) + ax = _maybe_gca(**subplot_kws) return ax -def _maybe_gca(**kwargs): +def _maybe_gca(**subplot_kws: Any) -> Axes: import matplotlib.pyplot as plt @@ -468,7 +480,7 @@ def _maybe_gca(**kwargs): # can not pass kwargs to active axes return plt.gca() - return plt.axes(**kwargs) + return plt.axes(**subplot_kws) def _get_units_from_attrs(da) -> str: diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index f37c2fd7508..c3065251e34 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -2955,7 +2955,7 @@ def test_facetgrid_single_contour(): @requires_matplotlib -def test_get_axis(): +def test_get_axis_raises(): # test get_axis works with different args combinations # and return the right type @@ -2975,18 +2975,60 @@ def test_get_axis(): with pytest.raises(ValueError, match="`aspect` argument without `size`"): get_axis(figsize=None, size=None, aspect=4 / 3, ax=None) + # cannot provide axis and subplot_kws + with pytest.raises(ValueError, match="cannot use subplot_kws with existing ax"): + get_axis(figsize=None, size=None, aspect=None, ax=1, something_else=5) + + +@requires_matplotlib +@pytest.mark.parametrize( + ["figsize", "size", "aspect", "ax", "kwargs"], + [ + pytest.param((3, 2), None, None, None, {}, id="figsize"), + pytest.param( + (3.5, 2.5), None, None, None, {"label": "test"}, id="figsize_kwargs" + ), + pytest.param(None, 5, None, None, {}, id="size"), + pytest.param(None, 5.5, None, None, {"label": "test"}, id="size_kwargs"), + pytest.param(None, 5, 1, None, {}, id="size+aspect"), + pytest.param(None, None, None, True, {}, id="ax"), + pytest.param(None, None, None, None, {}, id="default"), + pytest.param(None, None, None, None, {"label": "test"}, id="default_kwargs"), + ], +) +def test_get_axis( + figsize: tuple[float, float] | None, + size: float | None, + aspect: float | None, + ax: bool | None, + kwargs: dict[str, Any], +) -> None: with figure_context(): - ax = get_axis() - assert isinstance(ax, mpl.axes.Axes) + inp_ax = None if ax is None else plt.axes() + out_ax = get_axis( + figsize=figsize, size=size, aspect=aspect, ax=inp_ax, **kwargs + ) + assert isinstance(out_ax, mpl.axes.Axes) +@requires_matplotlib @requires_cartopy -def test_get_axis_cartopy(): - +@pytest.mark.parametrize( + ["figsize", "size", "aspect"], + [ + pytest.param((3, 2), None, None, id="figsize"), + pytest.param(None, 5, None, id="size"), + pytest.param(None, 5, 1, id="size+aspect"), + pytest.param(None, None, None, id="default"), + ], +) +def test_get_axis_cartopy( + figsize: tuple[float, float] | None, size: float | None, aspect: float | None +) -> None: kwargs = {"projection": cartopy.crs.PlateCarree()} with figure_context(): - ax = get_axis(**kwargs) - assert isinstance(ax, cartopy.mpl.geoaxes.GeoAxesSubplot) + out_ax = get_axis(figsize=figsize, size=size, aspect=aspect, **kwargs) + assert isinstance(out_ax, cartopy.mpl.geoaxes.GeoAxesSubplot) @requires_matplotlib From 7ca089cbe23e557be1da05453e55478e0f5c9a1a Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Sun, 25 Sep 2022 21:52:02 +0200 Subject: [PATCH 2/5] add bugfix to whats-new --- doc/whats-new.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 10b59ba1c0c..41e05bd2752 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -75,6 +75,8 @@ Bug fixes - Allow writing NetCDF files including only dimensionless variables using the distributed or multiprocessing scheduler (:issue:`7013`, :pull:`7040`). By `Francesco Nattino `_. +- Fix bug where subplot_kwargs were not working when plotting with figsize, size or aspect (:issue:`7078`, :pull:`7080`) + By `Michael Niklas `_. Documentation ~~~~~~~~~~~~~ From a13fefdf349d83495101b4e82185ceb6c8178abc Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Tue, 27 Sep 2022 21:49:59 +0200 Subject: [PATCH 3/5] better describe current axis use of plot --- xarray/plot/plot.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index ae0adfff00b..ee616b9040e 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -280,7 +280,8 @@ def plot( col_wrap : int, optional Use together with ``col`` to wrap faceted plots. ax : matplotlib axes object, optional - If ``None``, use the current axes. Not applicable when using facets. + Axes on which to plot. By default, use the current axes. + Mutually exclusive with ``size``, ``figsize`` and facets. rtol : float, optional Relative tolerance used to determine if the indexes are uniformly spaced. Usually a small positive number. From 99a1f5cd1ea75777d599ff2815384251b6f0bf3c Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Tue, 27 Sep 2022 21:57:16 +0200 Subject: [PATCH 4/5] add test for get_axis current axis --- xarray/tests/test_plot.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index c3065251e34..eb45fd7c54f 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -2956,8 +2956,7 @@ def test_facetgrid_single_contour(): @requires_matplotlib def test_get_axis_raises(): - # test get_axis works with different args combinations - # and return the right type + # test get_axis raises an error if trying to do invalid things # cannot provide both ax and figsize with pytest.raises(ValueError, match="both `figsize` and `ax`"): @@ -3031,6 +3030,14 @@ def test_get_axis_cartopy( assert isinstance(out_ax, cartopy.mpl.geoaxes.GeoAxesSubplot) +@requires_matplotlib +def test_get_axis_current() -> None: + with figure_context(): + _, ax = plt.subplots() + out_ax = get_axis() + assert ax is out_ax + + @requires_matplotlib def test_maybe_gca(): From 2a2968e4cb9e076ac514c9b3d420709fa7571d3b Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Tue, 27 Sep 2022 21:59:52 +0200 Subject: [PATCH 5/5] simplify test a bit --- xarray/tests/test_plot.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index eb45fd7c54f..ca530bc9cce 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -2983,27 +2983,27 @@ def test_get_axis_raises(): @pytest.mark.parametrize( ["figsize", "size", "aspect", "ax", "kwargs"], [ - pytest.param((3, 2), None, None, None, {}, id="figsize"), + pytest.param((3, 2), None, None, False, {}, id="figsize"), pytest.param( - (3.5, 2.5), None, None, None, {"label": "test"}, id="figsize_kwargs" + (3.5, 2.5), None, None, False, {"label": "test"}, id="figsize_kwargs" ), - pytest.param(None, 5, None, None, {}, id="size"), - pytest.param(None, 5.5, None, None, {"label": "test"}, id="size_kwargs"), - pytest.param(None, 5, 1, None, {}, id="size+aspect"), + pytest.param(None, 5, None, False, {}, id="size"), + pytest.param(None, 5.5, None, False, {"label": "test"}, id="size_kwargs"), + pytest.param(None, 5, 1, False, {}, id="size+aspect"), pytest.param(None, None, None, True, {}, id="ax"), - pytest.param(None, None, None, None, {}, id="default"), - pytest.param(None, None, None, None, {"label": "test"}, id="default_kwargs"), + pytest.param(None, None, None, False, {}, id="default"), + pytest.param(None, None, None, False, {"label": "test"}, id="default_kwargs"), ], ) def test_get_axis( figsize: tuple[float, float] | None, size: float | None, aspect: float | None, - ax: bool | None, + ax: bool, kwargs: dict[str, Any], ) -> None: with figure_context(): - inp_ax = None if ax is None else plt.axes() + inp_ax = plt.axes() if ax else None out_ax = get_axis( figsize=figsize, size=size, aspect=aspect, ax=inp_ax, **kwargs )