diff --git a/plotly_resampler/figure_resampler/figure_resampler.py b/plotly_resampler/figure_resampler/figure_resampler.py index b44abf38..8b240e7f 100644 --- a/plotly_resampler/figure_resampler/figure_resampler.py +++ b/plotly_resampler/figure_resampler/figure_resampler.py @@ -11,7 +11,7 @@ __author__ = "Jonas Van Der Donckt, Jeroen Van Der Donckt, Emiel Deprost" import warnings -from typing import Tuple +from typing import Tuple, List import dash import plotly.graph_objects as go @@ -41,6 +41,7 @@ def __init__( show_mean_aggregation_size: bool = True, convert_traces_kwargs: dict | None = None, verbose: bool = False, + show_dash_kwargs: dict | None = None, ): # `pr_props`` is a variable to store properties of a plotly-resampler figure # This variable will only be set when loading a pickled plotly-resampler figure @@ -85,6 +86,8 @@ def __init__( # A single trace dict or a list of traces f.add_traces(figure) + self._show_dash_kwargs = show_dash_kwargs if show_dash_kwargs is not None else {} + super().__init__( f, convert_existing_traces, @@ -155,6 +158,8 @@ def show_dash( See more https://dash.plotly.com/dash-core-components/graph **kwargs: dict Additional app.run_server() kwargs. e.g.: port + Note that these kwargs take precedence over the ones passed to the + constructor via the ``show_dash_kwargs`` argument. """ graph_properties = {} if graph_properties is None else graph_properties @@ -175,14 +180,19 @@ def show_dash( # 2. Run the app if ( - self.layout.height is not None - and mode == "inline" + mode == "inline" and "height" not in kwargs ): - # If figure height is specified -> re-use is for inline dash app height - kwargs["height"] = self.layout.height + 18 + # If app height is not specified -> re-use figure height for inline dash app + # Note: default layout height is 450 (whereas default app height is 650) + # See: https://plotly.com/python/reference/layout/#layout-height + fig_height = self.layout.height if self.layout.height is not None else 450 + kwargs["height"] = fig_height + 18 + + # kwargs take precedence over the show_dash_kwargs + kwargs = {**self._show_dash_kwargs, **kwargs} - # store the app information, so it can be killed + # Store the app information, so it can be killed self._app = app self._host = kwargs.get("host", "127.0.0.1") self._port = kwargs.get("port", "8050") @@ -238,3 +248,11 @@ def register_update_graph_callback( dash.dependencies.Input(graph_id, "relayoutData"), prevent_initial_call=True, )(self.construct_update_data) + + def _get_pr_props_keys(self) -> List[str]: + # Add the additional plotly-resampler properties of this class + return super()._get_pr_props_keys() + ["_show_dash_kwargs"] + + def _ipython_display_(self): + # To display the figure inline as a dash app + self.show_dash(mode="inline") diff --git a/plotly_resampler/figure_resampler/figure_resampler_interface.py b/plotly_resampler/figure_resampler/figure_resampler_interface.py index 4038de90..0c4c1a07 100644 --- a/plotly_resampler/figure_resampler/figure_resampler_interface.py +++ b/plotly_resampler/figure_resampler/figure_resampler_interface.py @@ -1273,6 +1273,25 @@ def _re_matches(regex: re.Pattern, strings: Iterable[str]) -> List[str]: return sorted(matches) ## Magic methods (to use plotly.py words :grin:) + + def _get_pr_props_keys(self) -> List[str]: + """Returns the keys (i.e., the names) of the plotly-resampler properties. + + Note + ---- + This method is used to serialize the object in the `__reduce__` method. + + """ + return [ + "_hf_data", + "_global_n_shown_samples", + "_print_verbose", + "_show_mean_aggregation_size", + "_prefix", + "_suffix", + "_global_downsampler", + ] + def __reduce__(self): """Overwrite the reduce method (which is used to support deep copying and pickling). @@ -1288,15 +1307,6 @@ def __reduce__(self): # Add the plotly-resampler properties props["pr_props"] = {} - pr_keys = [ - "_hf_data", - "_global_n_shown_samples", - "_print_verbose", - "_show_mean_aggregation_size", - "_prefix", - "_suffix", - "_global_downsampler", - ] - for k in pr_keys: + for k in self._get_pr_props_keys(): props["pr_props"][k] = getattr(self, k) return (self.__class__, (props,)) # (props,) to comply with plotly magic diff --git a/tests/test_registering.py b/tests/test_registering.py index c4a1db4d..421872d8 100644 --- a/tests/test_registering.py +++ b/tests/test_registering.py @@ -127,9 +127,12 @@ def test_registering_plotly_express_and_kwargs(registering_cleanup): assert len(fig.data) == 1 assert len(fig.data[0].y) == 500 - register_plotly_resampler(default_n_shown_samples=50) + register_plotly_resampler( + default_n_shown_samples=50, show_dash_kwargs=dict(mode="inline", port=8051) + ) fig = px.scatter(y=np.arange(500)) assert isinstance(fig, FigureResampler) + assert fig._show_dash_kwargs == dict(mode="inline", port=8051) assert len(fig.data) == 1 assert len(fig.data[0].y) == 50 assert len(fig.hf_data) == 1 @@ -138,6 +141,7 @@ def test_registering_plotly_express_and_kwargs(registering_cleanup): register_plotly_resampler() fig = px.scatter(y=np.arange(5000)) assert isinstance(fig, FigureResampler) + assert fig._show_dash_kwargs == dict() assert len(fig.data) == 1 assert len(fig.data[0].y) == 1000 assert len(fig.hf_data) == 1 @@ -201,4 +205,3 @@ def test_compasibility_when_registered(registering_cleanup): assert len(f.data[0].y) == 1000 assert len(f.hf_data) == 1 assert len(f.hf_data[0]["y"]) == 1005 - \ No newline at end of file diff --git a/tests/test_serialization.py b/tests/test_serialization.py index cee4d0f3..182ae788 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -1,9 +1,11 @@ +from hashlib import sha1 import plotly.graph_objects as go import plotly.express as px import numpy as np import pickle import copy +from plotly.subplots import make_subplots from plotly_resampler import FigureResampler, FigureWidgetResampler from plotly_resampler.registering import ( register_plotly_resampler, @@ -23,14 +25,52 @@ def test_pickle_figure_resampler(pickle_figure): nb_traces = 3 nb_samples = 5_007 - fig = FigureResampler(default_n_shown_samples=50) + fig = FigureResampler(default_n_shown_samples=50, show_dash_kwargs=dict(port=8051)) for i in range(nb_traces): fig.add_trace(go.Scattergl(name=f"trace--{i}"), hf_y=np.arange(nb_samples)) + assert fig._show_dash_kwargs["port"] == 8051 pickle.dump(fig, open(pickle_figure, "wb")) fig_pickle = pickle.load(open(pickle_figure, "rb")) assert isinstance(fig_pickle, FigureResampler) + assert fig_pickle._show_dash_kwargs["port"] == 8051 + assert len(fig_pickle.data) == nb_traces + assert len(fig_pickle.hf_data) == nb_traces + for i in range(nb_traces): + trace = fig_pickle.data[i] + assert isinstance(trace, go.Scattergl) + assert len(trace.y) == 50 + assert f"trace--{i}" in trace.name + hf_trace = fig_pickle.hf_data[i] + assert len(hf_trace["y"]) == nb_samples + assert np.all(hf_trace["y"] == np.arange(nb_samples)) + + # Test for figure with subplots (check non-pickled private properties) + fig = FigureResampler( + make_subplots(rows=2, cols=1, shared_xaxes=True), + default_n_shown_samples=50, show_dash_kwargs=dict(port=8051), + ) + for i in range(nb_traces): + fig.add_trace( + go.Scattergl(name=f"trace--{i}"), hf_y=np.arange(nb_samples), + row=(i % 2) + 1, col=1, + ) + assert fig._global_n_shown_samples == 50 + assert fig._show_dash_kwargs["port"] == 8051 + assert fig._figure_class == go.Figure + assert fig._xaxis_list == ['xaxis', 'xaxis2'] + assert fig._yaxis_list == ['yaxis', 'yaxis2'] + + pickle.dump(fig, open(pickle_figure, "wb")) + fig_pickle = pickle.load(open(pickle_figure, "rb")) + + assert isinstance(fig_pickle, FigureResampler) + assert fig_pickle._global_n_shown_samples == 50 + assert fig_pickle._show_dash_kwargs["port"] == 8051 + assert fig_pickle._figure_class == go.Figure + assert fig_pickle._xaxis_list == ['xaxis', 'xaxis2'] + assert fig_pickle._yaxis_list == ['yaxis', 'yaxis2'] assert len(fig_pickle.data) == nb_traces assert len(fig_pickle.hf_data) == nb_traces for i in range(nb_traces): @@ -66,6 +106,40 @@ def test_pickle_figurewidget_resampler(pickle_figure): assert len(hf_trace["y"]) == nb_samples assert np.all(hf_trace["y"] == np.arange(nb_samples)) + # Test for figure with subplots (check non-pickled private properties) + fig = FigureWidgetResampler( + make_subplots(rows=2, cols=1, shared_xaxes=True), + default_n_shown_samples=50, + ) + for i in range(nb_traces): + fig.add_trace( + go.Scattergl(name=f"trace--{i}"), hf_y=np.arange(nb_samples), + row=(i % 2) + 1, col=1, + ) + assert fig._global_n_shown_samples == 50 + assert fig._figure_class == go.FigureWidget + assert fig._xaxis_list == ['xaxis', 'xaxis2'] + assert fig._yaxis_list == ['yaxis', 'yaxis2'] + + pickle.dump(fig, open(pickle_figure, "wb")) + fig_pickle = pickle.load(open(pickle_figure, "rb")) + + assert isinstance(fig_pickle, FigureWidgetResampler) + assert fig_pickle._global_n_shown_samples == 50 + assert fig_pickle._figure_class == go.FigureWidget + assert fig_pickle._xaxis_list == ['xaxis', 'xaxis2'] + assert fig_pickle._yaxis_list == ['yaxis', 'yaxis2'] + assert len(fig_pickle.data) == nb_traces + assert len(fig_pickle.hf_data) == nb_traces + for i in range(nb_traces): + trace = fig_pickle.data[i] + assert isinstance(trace, go.Scattergl) + assert len(trace.y) == 50 + assert f"trace--{i}" in trace.name + hf_trace = fig_pickle.hf_data[i] + assert len(hf_trace["y"]) == nb_samples + assert np.all(hf_trace["y"] == np.arange(nb_samples)) + ## Test pickling when registered @@ -73,13 +147,16 @@ def test_pickle_figure_resampler_registered(registering_cleanup, pickle_figure): nb_traces = 4 nb_samples = 5_043 - register_plotly_resampler(mode="figure", default_n_shown_samples=50) - + register_plotly_resampler( + mode="figure", default_n_shown_samples=50, show_dash_kwargs=dict(port=8051) + ) + fig = go.Figure() for i in range(nb_traces): fig.add_trace(go.Scattergl(name=f"trace--{i}"), hf_y=np.arange(nb_samples)) assert isinstance(fig, FigureResampler) assert not isinstance(fig, FigureWidgetResampler) + assert fig._show_dash_kwargs["port"] == 8051 pickle.dump(fig, open(pickle_figure, "wb")) @@ -87,6 +164,7 @@ def test_pickle_figure_resampler_registered(registering_cleanup, pickle_figure): assert isinstance(go.Figure(), FigureResampler) fig_pickle = pickle.load(open(pickle_figure, "rb")) assert isinstance(fig_pickle, FigureResampler) + assert fig_pickle._show_dash_kwargs["port"] == 8051 assert len(fig_pickle.data) == nb_traces assert len(fig_pickle.hf_data) == nb_traces for i in range(nb_traces): @@ -104,6 +182,7 @@ def test_pickle_figure_resampler_registered(registering_cleanup, pickle_figure): assert not isinstance(go.Figure(), FigureResampler) fig_pickle = pickle.load(open(pickle_figure, "rb")) assert isinstance(fig_pickle, FigureResampler) + assert fig_pickle._show_dash_kwargs["port"] == 8051 assert len(fig_pickle.data) == nb_traces assert len(fig_pickle.hf_data) == nb_traces for i in range(nb_traces): @@ -121,6 +200,7 @@ def test_pickle_figure_resampler_registered(registering_cleanup, pickle_figure): assert not isinstance(go.Figure(), FigureResampler) fig_pickle = pickle.load(open(pickle_figure, "rb")) assert isinstance(fig_pickle, FigureResampler) + assert fig_pickle._show_dash_kwargs["port"] == 8051 assert len(fig_pickle.data) == nb_traces assert len(fig_pickle.hf_data) == nb_traces for i in range(nb_traces): @@ -138,6 +218,7 @@ def test_pickle_figure_resampler_registered(registering_cleanup, pickle_figure): pickle.dump(fig, open(pickle_figure, "wb")) fig_pickle = pickle.load(open(pickle_figure, "rb")) assert isinstance(fig_pickle, FigureResampler) + assert fig_pickle._show_dash_kwargs["port"] == 8051 assert len(fig_pickle.data) == nb_traces assert len(fig_pickle.hf_data) == nb_traces for i in range(nb_traces): @@ -155,7 +236,7 @@ def test_pickle_figurewidget_resampler_registered(registering_cleanup, pickle_fi nb_samples = 3_643 register_plotly_resampler(mode="widget", default_n_shown_samples=50) - + fig = go.Figure() for i in range(nb_traces): fig.add_trace(go.Scattergl(name=f"trace--{i}"), hf_y=np.arange(nb_samples)) @@ -239,13 +320,15 @@ def test_copy_and_deepcopy_figure_resampler(): nb_traces = 3 nb_samples = 3_243 - fig = FigureResampler(default_n_shown_samples=50) + fig = FigureResampler(default_n_shown_samples=50, show_dash_kwargs=dict(port=8051)) for i in range(nb_traces): fig.add_trace(go.Scattergl(name=f"trace--{i}"), hf_y=np.arange(nb_samples)) + assert fig._show_dash_kwargs["port"] == 8051 fig_copy = copy.copy(fig) assert isinstance(fig_copy, FigureResampler) + assert fig_copy._show_dash_kwargs["port"] == 8051 assert len(fig_copy.data) == nb_traces assert len(fig_copy.hf_data) == nb_traces for i in range(nb_traces): @@ -260,6 +343,7 @@ def test_copy_and_deepcopy_figure_resampler(): fig_copy = copy.deepcopy(fig) assert isinstance(fig_copy, FigureResampler) + assert fig_copy._show_dash_kwargs["port"] == 8051 assert len(fig_copy.data) == nb_traces assert len(fig_copy.hf_data) == nb_traces for i in range(nb_traces): @@ -270,7 +354,7 @@ def test_copy_and_deepcopy_figure_resampler(): hf_trace = fig_copy.hf_data[i] assert len(hf_trace["y"]) == nb_samples assert np.all(hf_trace["y"] == np.arange(nb_samples)) - + def test_copy_and_deepcopy_figurewidget_resampler(): nb_traces = 3 @@ -307,25 +391,30 @@ def test_copy_and_deepcopy_figurewidget_resampler(): hf_trace = fig_copy.hf_data[i] assert len(hf_trace["y"]) == nb_samples assert np.all(hf_trace["y"] == np.arange(nb_samples)) - - ## Test basic (deep)copy with PR registered + + +## Test basic (deep)copy with PR registered def test_copy_figure_resampler_registered(): nb_traces = 3 nb_samples = 4_069 - register_plotly_resampler(mode="figure", default_n_shown_samples=50) - + register_plotly_resampler( + mode="figure", default_n_shown_samples=50, show_dash_kwargs=dict(port=8051) + ) + fig = go.Figure() for i in range(nb_traces): fig.add_trace(go.Scattergl(name=f"trace--{i}"), hf_y=np.arange(nb_samples)) assert isinstance(fig, FigureResampler) assert not isinstance(fig, FigureWidgetResampler) + assert fig._show_dash_kwargs["port"] == 8051 # Copy with PR registered fig_copy = copy.copy(fig) assert isinstance(go.Figure(), FigureResampler) assert isinstance(fig_copy, FigureResampler) + assert fig_copy._show_dash_kwargs["port"] == 8051 assert len(fig_copy.data) == nb_traces assert len(fig_copy.hf_data) == nb_traces for i in range(nb_traces): @@ -343,6 +432,7 @@ def test_copy_figure_resampler_registered(): assert not isinstance(go.Figure(), FigureResampler) fig_copy = copy.copy(fig) assert isinstance(fig_copy, FigureResampler) + assert fig_copy._show_dash_kwargs["port"] == 8051 assert len(fig_copy.data) == nb_traces assert len(fig_copy.hf_data) == nb_traces for i in range(nb_traces): @@ -360,6 +450,7 @@ def test_copy_figure_resampler_registered(): assert not isinstance(go.Figure(), FigureResampler) fig_copy = copy.copy(fig) assert isinstance(fig_copy, FigureResampler) + assert fig_copy._show_dash_kwargs["port"] == 8051 assert len(fig_copy.data) == nb_traces assert len(fig_copy.hf_data) == nb_traces for i in range(nb_traces): @@ -376,18 +467,22 @@ def test_deepcopy_figure_resampler_registered(): nb_traces = 4 nb_samples = 3_169 - register_plotly_resampler(mode="figure", default_n_shown_samples=50) - + register_plotly_resampler( + mode="figure", default_n_shown_samples=50, show_dash_kwargs=dict(port=8051) + ) + fig = go.Figure() for i in range(nb_traces): fig.add_trace(go.Scattergl(name=f"trace--{i}"), hf_y=np.arange(nb_samples)) assert isinstance(fig, FigureResampler) assert not isinstance(fig, FigureWidgetResampler) + assert fig._show_dash_kwargs["port"] == 8051 # Copy with PR registered fig_copy = copy.deepcopy(fig) assert isinstance(go.Figure(), FigureResampler) assert isinstance(fig_copy, FigureResampler) + assert fig_copy._show_dash_kwargs["port"] == 8051 assert len(fig_copy.data) == nb_traces assert len(fig_copy.hf_data) == nb_traces for i in range(nb_traces): @@ -405,6 +500,7 @@ def test_deepcopy_figure_resampler_registered(): assert not isinstance(go.Figure(), FigureResampler) fig_copy = copy.deepcopy(fig) assert isinstance(fig_copy, FigureResampler) + assert fig_copy._show_dash_kwargs["port"] == 8051 assert len(fig_copy.data) == nb_traces assert len(fig_copy.hf_data) == nb_traces for i in range(nb_traces): @@ -422,6 +518,7 @@ def test_deepcopy_figure_resampler_registered(): assert not isinstance(go.Figure(), FigureResampler) fig_copy = copy.deepcopy(fig) assert isinstance(fig_copy, FigureResampler) + assert fig_copy._show_dash_kwargs["port"] == 8051 assert len(fig_copy.data) == nb_traces assert len(fig_copy.hf_data) == nb_traces for i in range(nb_traces): @@ -439,7 +536,7 @@ def test_copy_figurewidget_resampler_registered(): nb_samples = 3_012 register_plotly_resampler(mode="widget", default_n_shown_samples=50) - + fig = go.Figure() for i in range(nb_traces): fig.add_trace(go.Scattergl(name=f"trace--{i}"), hf_y=np.arange(nb_samples)) @@ -501,7 +598,7 @@ def test_deepcopy_figurewidget_resampler_registered(): nb_samples = 3_012 register_plotly_resampler(mode="widget", default_n_shown_samples=50) - + fig = go.Figure() for i in range(nb_traces): fig.add_trace(go.Scattergl(name=f"trace--{i}"), hf_y=np.arange(nb_samples))