Skip to content

Commit 9f11862

Browse files
Surface plots (#5101)
* Use broadcast_like for 2d plot coordinates Use broadcast_like if either `x` or `y` inputs are 2d to ensure that both have dimensions in the same order as the DataArray being plotted. Convert to numpy arrays after possibly using broadcast_like. Simplifies code, and fixes #5097 (bug when dimensions have the same size). * Update whats-new * Implement 'surface()' plot function Wraps mpl_toolkits.mplot3d.axes3d.plot_surface * Make surface plots work with facet grids * Unit tests for surface plot * Minor fixes for surface plots * Add surface plots to api.rst and api-hidden.rst * Update whats-new * Fix tests * mypy fix * seaborn doesn't work with matplotlib 3d toolkit * Remove cfdatetime surface plot test Does not work because the datetime.timedelta does not work with surface's 'shading'. * Ignore type checks for mpl_toolkits module * Check matplotlib version is new enough for surface plots * version check requires matplotlib * Handle matplotlib not installed for TestSurface version check * fix flake8 error * Don't run test_plot_transposed_nondim_coord for surface plots Too complicated to check matplotlib version is high enough just for surface plots. * Apply suggestions from code review Co-authored-by: Mathias Hauser <[email protected]> * More suggestions from code review * black * isort and flake8 * Make surface plots more backward compatible Following suggestion from Illviljan * Clean up matplotlib requirement * Update xarray/plot/plot.py Co-authored-by: Mathias Hauser <[email protected]> * Apply suggestions from code review Co-authored-by: Mathias Hauser <[email protected]> * Use None as default value * black * More 2D plotting method examples in docs * Fix docs * [skip-ci] Make example surface plot look a bit nicer Co-authored-by: Mathias Hauser <[email protected]> Co-authored-by: Mathias Hauser <[email protected]>
1 parent 3391fec commit 9f11862

File tree

10 files changed

+202
-10
lines changed

10 files changed

+202
-10
lines changed

doc/api-hidden.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -597,6 +597,7 @@
597597
plot.imshow
598598
plot.pcolormesh
599599
plot.scatter
600+
plot.surface
600601

601602
plot.FacetGrid.map_dataarray
602603
plot.FacetGrid.set_titles

doc/api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -588,6 +588,7 @@ Plotting
588588
DataArray.plot.line
589589
DataArray.plot.pcolormesh
590590
DataArray.plot.step
591+
DataArray.plot.surface
591592

592593
.. _api.ufuncs:
593594

doc/user-guide/plotting.rst

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,37 @@ produce plots with nonuniform coordinates.
411411
@savefig plotting_nonuniform_coords.png width=4in
412412
b.plot()
413413
414+
====================
415+
Other types of plot
416+
====================
417+
418+
There are several other options for plotting 2D data.
419+
420+
Contour plot using :py:meth:`DataArray.plot.contour()`
421+
422+
.. ipython:: python
423+
:okwarning:
424+
425+
@savefig plotting_contour.png width=4in
426+
air2d.plot.contour()
427+
428+
Filled contour plot using :py:meth:`DataArray.plot.contourf()`
429+
430+
.. ipython:: python
431+
:okwarning:
432+
433+
@savefig plotting_contourf.png width=4in
434+
air2d.plot.contourf()
435+
436+
Surface plot using :py:meth:`DataArray.plot.surface()`
437+
438+
.. ipython:: python
439+
:okwarning:
440+
441+
@savefig plotting_surface.png width=4in
442+
# transpose just to make the example look a bit nicer
443+
air2d.T.plot.surface()
444+
414445
====================
415446
Calling Matplotlib
416447
====================

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ v0.17.1 (unreleased)
2323
New Features
2424
~~~~~~~~~~~~
2525

26+
- Add :py:meth:`DataArray.plot.surface` which wraps matplotlib's `plot_surface` to make
27+
surface plots (:issue:`#2235` :issue:`#5084` :pull:`5101`).
2628
- Allow passing multiple arrays to :py:meth:`Dataset.__setitem__` (:pull:`5216`).
2729
By `Giacomo Caria <https://github.com/gcaria>`_.
2830
- Add 'cumulative' option to :py:meth:`Dataset.integrate` and

xarray/plot/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from .dataset_plot import scatter
22
from .facetgrid import FacetGrid
3-
from .plot import contour, contourf, hist, imshow, line, pcolormesh, plot, step
3+
from .plot import contour, contourf, hist, imshow, line, pcolormesh, plot, step, surface
44

55
__all__ = [
66
"plot",
@@ -13,4 +13,5 @@
1313
"pcolormesh",
1414
"FacetGrid",
1515
"scatter",
16+
"surface",
1617
]

xarray/plot/facetgrid.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,9 @@ def map_dataarray(self, func, x, y, **kwargs):
263263
if k not in {"cmap", "colors", "cbar_kwargs", "levels"}
264264
}
265265
func_kwargs.update(cmap_params)
266-
func_kwargs.update({"add_colorbar": False, "add_labels": False})
266+
func_kwargs["add_colorbar"] = False
267+
if func.__name__ != "surface":
268+
func_kwargs["add_labels"] = False
267269

268270
# Get x, y labels for the first subplot
269271
x, y = _infer_xy_labels(

xarray/plot/plot.py

Lines changed: 53 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -633,7 +633,11 @@ def newplotfunc(
633633

634634
# Decide on a default for the colorbar before facetgrids
635635
if add_colorbar is None:
636-
add_colorbar = plotfunc.__name__ != "contour"
636+
add_colorbar = True
637+
if plotfunc.__name__ == "contour" or (
638+
plotfunc.__name__ == "surface" and cmap is None
639+
):
640+
add_colorbar = False
637641
imshow_rgb = plotfunc.__name__ == "imshow" and darray.ndim == (
638642
3 + (row is not None) + (col is not None)
639643
)
@@ -646,6 +650,25 @@ def newplotfunc(
646650
darray = _rescale_imshow_rgb(darray, vmin, vmax, robust)
647651
vmin, vmax, robust = None, None, False
648652

653+
if subplot_kws is None:
654+
subplot_kws = dict()
655+
656+
if plotfunc.__name__ == "surface" and not kwargs.get("_is_facetgrid", False):
657+
if ax is None:
658+
# TODO: Importing Axes3D is no longer necessary in matplotlib >= 3.2.
659+
# Remove when minimum requirement of matplotlib is 3.2:
660+
from mpl_toolkits.mplot3d import Axes3D # type: ignore # noqa: F401
661+
662+
# delete so it does not end up in locals()
663+
del Axes3D
664+
665+
# Need to create a "3d" Axes instance for surface plots
666+
subplot_kws["projection"] = "3d"
667+
668+
# In facet grids, shared axis labels don't make sense for surface plots
669+
sharex = False
670+
sharey = False
671+
649672
# Handle facetgrids first
650673
if row or col:
651674
allargs = locals().copy()
@@ -658,6 +681,19 @@ def newplotfunc(
658681

659682
plt = import_matplotlib_pyplot()
660683

684+
if (
685+
plotfunc.__name__ == "surface"
686+
and not kwargs.get("_is_facetgrid", False)
687+
and ax is not None
688+
):
689+
import mpl_toolkits # type: ignore
690+
691+
if not isinstance(ax, mpl_toolkits.mplot3d.Axes3D):
692+
raise ValueError(
693+
"If ax is passed to surface(), it must be created with "
694+
'projection="3d"'
695+
)
696+
661697
rgb = kwargs.pop("rgb", None)
662698
if rgb is not None and plotfunc.__name__ != "imshow":
663699
raise ValueError('The "rgb" keyword is only valid for imshow()')
@@ -674,9 +710,10 @@ def newplotfunc(
674710
xval = darray[xlab]
675711
yval = darray[ylab]
676712

677-
if xval.ndim > 1 or yval.ndim > 1:
713+
if xval.ndim > 1 or yval.ndim > 1 or plotfunc.__name__ == "surface":
678714
# Passing 2d coordinate values, need to ensure they are transposed the same
679-
# way as darray
715+
# way as darray.
716+
# Also surface plots always need 2d coordinates
680717
xval = xval.broadcast_like(darray)
681718
yval = yval.broadcast_like(darray)
682719
dims = darray.dims
@@ -734,8 +771,6 @@ def newplotfunc(
734771
# forbid usage of mpl strings
735772
raise ValueError("plt.imshow's `aspect` kwarg is not available in xarray")
736773

737-
if subplot_kws is None:
738-
subplot_kws = dict()
739774
ax = get_axis(figsize, size, aspect, ax, **subplot_kws)
740775

741776
primitive = plotfunc(
@@ -755,6 +790,8 @@ def newplotfunc(
755790
ax.set_xlabel(label_from_attrs(darray[xlab], xlab_extra))
756791
ax.set_ylabel(label_from_attrs(darray[ylab], ylab_extra))
757792
ax.set_title(darray._title_for_slice())
793+
if plotfunc.__name__ == "surface":
794+
ax.set_zlabel(label_from_attrs(darray))
758795

759796
if add_colorbar:
760797
if add_labels and "label" not in cbar_kwargs:
@@ -987,3 +1024,14 @@ def pcolormesh(x, y, z, ax, infer_intervals=None, **kwargs):
9871024
ax.set_ylim(y[0], y[-1])
9881025

9891026
return primitive
1027+
1028+
1029+
@_plot2d
1030+
def surface(x, y, z, ax, **kwargs):
1031+
"""
1032+
Surface plot of 2d DataArray
1033+
1034+
Wraps :func:`matplotlib:mpl_toolkits.mplot3d.axes3d.plot_surface`
1035+
"""
1036+
primitive = ax.plot_surface(x, y, z, **kwargs)
1037+
return primitive

xarray/plot/utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -804,6 +804,14 @@ def _process_cmap_cbar_kwargs(
804804
cmap_params
805805
cbar_kwargs
806806
"""
807+
if func.__name__ == "surface":
808+
# Leave user to specify cmap settings for surface plots
809+
kwargs["cmap"] = cmap
810+
return {
811+
k: kwargs.get(k, None)
812+
for k in ["vmin", "vmax", "cmap", "extend", "levels", "norm"]
813+
}, {}
814+
807815
cbar_kwargs = {} if cbar_kwargs is None else dict(cbar_kwargs)
808816

809817
if "contour" in func.__name__ and levels is None:

xarray/tests/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ def LooseVersion(vstring):
5959

6060

6161
has_matplotlib, requires_matplotlib = _importorskip("matplotlib")
62+
has_matplotlib_3_3_0, requires_matplotlib_3_3_0 = _importorskip(
63+
"matplotlib", minversion="3.3.0"
64+
)
6265
has_scipy, requires_scipy = _importorskip("scipy")
6366
has_pydap, requires_pydap = _importorskip("pydap.client")
6467
has_netCDF4, requires_netCDF4 = _importorskip("netCDF4")

xarray/tests/test_plot.py

Lines changed: 98 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import inspect
33
from copy import copy
44
from datetime import datetime
5+
from typing import Any, Dict, Union
56

67
import numpy as np
78
import pandas as pd
@@ -27,6 +28,7 @@
2728
requires_cartopy,
2829
requires_cftime,
2930
requires_matplotlib,
31+
requires_matplotlib_3_3_0,
3032
requires_nc_time_axis,
3133
requires_seaborn,
3234
)
@@ -35,6 +37,7 @@
3537
try:
3638
import matplotlib as mpl
3739
import matplotlib.pyplot as plt
40+
import mpl_toolkits # type: ignore
3841
except ImportError:
3942
pass
4043

@@ -131,8 +134,8 @@ def setup(self):
131134
# Remove all matplotlib figures
132135
plt.close("all")
133136

134-
def pass_in_axis(self, plotmethod):
135-
fig, axes = plt.subplots(ncols=2)
137+
def pass_in_axis(self, plotmethod, subplot_kw=None):
138+
fig, axes = plt.subplots(ncols=2, subplot_kw=subplot_kw)
136139
plotmethod(ax=axes[0])
137140
assert axes[0].has_data()
138141

@@ -1106,6 +1109,9 @@ class Common2dMixin:
11061109
Should have the same name as the method.
11071110
"""
11081111

1112+
# Needs to be overridden in TestSurface for facet grid plots
1113+
subplot_kws: Union[Dict[Any, Any], None] = None
1114+
11091115
@pytest.fixture(autouse=True)
11101116
def setUp(self):
11111117
da = DataArray(
@@ -1421,7 +1427,7 @@ def test_colorbar_kwargs(self):
14211427
def test_verbose_facetgrid(self):
14221428
a = easy_array((10, 15, 3))
14231429
d = DataArray(a, dims=["y", "x", "z"])
1424-
g = xplt.FacetGrid(d, col="z")
1430+
g = xplt.FacetGrid(d, col="z", subplot_kws=self.subplot_kws)
14251431
g.map_dataarray(self.plotfunc, "x", "y")
14261432
for ax in g.axes.flat:
14271433
assert ax.has_data()
@@ -1821,6 +1827,95 @@ def test_origin_overrides_xyincrease(self):
18211827
assert plt.ylim()[0] < 0
18221828

18231829

1830+
class TestSurface(Common2dMixin, PlotTestCase):
1831+
1832+
plotfunc = staticmethod(xplt.surface)
1833+
subplot_kws = {"projection": "3d"}
1834+
1835+
def test_primitive_artist_returned(self):
1836+
artist = self.plotmethod()
1837+
assert isinstance(artist, mpl_toolkits.mplot3d.art3d.Poly3DCollection)
1838+
1839+
@pytest.mark.slow
1840+
def test_2d_coord_names(self):
1841+
self.plotmethod(x="x2d", y="y2d")
1842+
# make sure labels came out ok
1843+
ax = plt.gca()
1844+
assert "x2d" == ax.get_xlabel()
1845+
assert "y2d" == ax.get_ylabel()
1846+
assert f"{self.darray.long_name} [{self.darray.units}]" == ax.get_zlabel()
1847+
1848+
def test_xyincrease_false_changes_axes(self):
1849+
# Does not make sense for surface plots
1850+
pytest.skip("does not make sense for surface plots")
1851+
1852+
def test_xyincrease_true_changes_axes(self):
1853+
# Does not make sense for surface plots
1854+
pytest.skip("does not make sense for surface plots")
1855+
1856+
def test_can_pass_in_axis(self):
1857+
self.pass_in_axis(self.plotmethod, subplot_kw={"projection": "3d"})
1858+
1859+
def test_default_cmap(self):
1860+
# Does not make sense for surface plots with default arguments
1861+
pytest.skip("does not make sense for surface plots")
1862+
1863+
def test_diverging_color_limits(self):
1864+
# Does not make sense for surface plots with default arguments
1865+
pytest.skip("does not make sense for surface plots")
1866+
1867+
def test_colorbar_kwargs(self):
1868+
# Does not make sense for surface plots with default arguments
1869+
pytest.skip("does not make sense for surface plots")
1870+
1871+
def test_cmap_and_color_both(self):
1872+
# Does not make sense for surface plots with default arguments
1873+
pytest.skip("does not make sense for surface plots")
1874+
1875+
def test_seaborn_palette_as_cmap(self):
1876+
# seaborn does not work with mpl_toolkits.mplot3d
1877+
with pytest.raises(ValueError):
1878+
super().test_seaborn_palette_as_cmap()
1879+
1880+
# Need to modify this test for surface(), because all subplots should have labels,
1881+
# not just left and bottom
1882+
@pytest.mark.filterwarnings("ignore:tight_layout cannot")
1883+
def test_convenient_facetgrid(self):
1884+
a = easy_array((10, 15, 4))
1885+
d = DataArray(a, dims=["y", "x", "z"])
1886+
g = self.plotfunc(d, x="x", y="y", col="z", col_wrap=2)
1887+
1888+
assert_array_equal(g.axes.shape, [2, 2])
1889+
for (y, x), ax in np.ndenumerate(g.axes):
1890+
assert ax.has_data()
1891+
assert "y" == ax.get_ylabel()
1892+
assert "x" == ax.get_xlabel()
1893+
1894+
# Infering labels
1895+
g = self.plotfunc(d, col="z", col_wrap=2)
1896+
assert_array_equal(g.axes.shape, [2, 2])
1897+
for (y, x), ax in np.ndenumerate(g.axes):
1898+
assert ax.has_data()
1899+
assert "y" == ax.get_ylabel()
1900+
assert "x" == ax.get_xlabel()
1901+
1902+
@requires_matplotlib_3_3_0
1903+
def test_viridis_cmap(self):
1904+
return super().test_viridis_cmap()
1905+
1906+
@requires_matplotlib_3_3_0
1907+
def test_can_change_default_cmap(self):
1908+
return super().test_can_change_default_cmap()
1909+
1910+
@requires_matplotlib_3_3_0
1911+
def test_colorbar_default_label(self):
1912+
return super().test_colorbar_default_label()
1913+
1914+
@requires_matplotlib_3_3_0
1915+
def test_facetgrid_map_only_appends_mappables(self):
1916+
return super().test_facetgrid_map_only_appends_mappables()
1917+
1918+
18241919
class TestFacetGrid(PlotTestCase):
18251920
@pytest.fixture(autouse=True)
18261921
def setUp(self):

0 commit comments

Comments
 (0)