diff --git a/plotly_resampler/figure_resampler/figure_resampler.py b/plotly_resampler/figure_resampler/figure_resampler.py index ed109d9e..1439f324 100644 --- a/plotly_resampler/figure_resampler/figure_resampler.py +++ b/plotly_resampler/figure_resampler/figure_resampler.py @@ -55,8 +55,18 @@ def __init__( if isinstance(figure, BaseFigure): # go.FigureWidget or AbstractFigureAggregator # A base figure object, we first copy the layout and grid ref f.layout = figure.layout + f._grid_str = figure._grid_str f._grid_ref = figure._grid_ref f.add_traces(figure.data) + elif isinstance(figure, dict) and ( + "data" in figure or "layout" in figure # or "frames" in figure # TODO + ): + # A dict with data, layout or frames + f.layout = figure.get("layout") + f._grid_str = figure.get("_grid_str") + f._grid_ref = figure.get("_grid_ref") + f.add_traces(figure.get("data")) + # f.add_frames(figure.get("frames")) TODO elif isinstance(figure, (dict, list)): # A single trace dict or a list of traces f.add_traces(figure) diff --git a/plotly_resampler/figure_resampler/figure_resampler_interface.py b/plotly_resampler/figure_resampler/figure_resampler_interface.py index fff2f222..13b25b2d 100644 --- a/plotly_resampler/figure_resampler/figure_resampler_interface.py +++ b/plotly_resampler/figure_resampler/figure_resampler_interface.py @@ -113,6 +113,7 @@ def __init__( # call __init__ with the correct layout and set the `_grid_ref` of the # to-be-converted figure f_ = self._figure_class(layout=figure.layout) + f_._grid_str = figure._grid_str f_._grid_ref = figure._grid_ref super().__init__(f_) diff --git a/plotly_resampler/figure_resampler/figurewidget_resampler.py b/plotly_resampler/figure_resampler/figurewidget_resampler.py index c50f6118..b351d303 100644 --- a/plotly_resampler/figure_resampler/figurewidget_resampler.py +++ b/plotly_resampler/figure_resampler/figurewidget_resampler.py @@ -59,8 +59,17 @@ def __init__( if isinstance(figure, BaseFigure): # go.Figure or go.FigureWidget or AbstractFigureAggregator # A base figure object, we first copy the layout and grid ref f.layout = figure.layout + f._grid_str = figure._grid_str f._grid_ref = figure._grid_ref f.add_traces(figure.data) + elif isinstance(figure, dict) and ( + "data" in figure or "layout" in figure # or "frames" in figure # TODO + ): + f.layout = figure.get("layout") + f._grid_str = figure.get("_grid_str") + f._grid_ref = figure.get("_grid_ref") + f.add_traces(figure.get("data")) + # f.add_frames(figure.get("frames")) TODO elif isinstance(figure, (dict, list)): # A single trace dict or a list of traces f.add_traces(figure) diff --git a/tests/test_figure_resampler.py b/tests/test_figure_resampler.py index 677b2c54..8af12d15 100644 --- a/tests/test_figure_resampler.py +++ b/tests/test_figure_resampler.py @@ -648,7 +648,7 @@ def test_fr_add_empty_trace(): assert len(fig.hf_data[0]["y"]) == 0 -def test_fr_from_dict(): +def test_fr_from_trace_dict(): y = np.array([1] * 10_000) base_fig = { "type": "scatter", @@ -668,6 +668,24 @@ def test_fr_from_dict(): assert fr_fig.data[0].uid in fr_fig._hf_data +def test_fr_from_figure_dict(): + y = np.array([1] * 10_000) + base_fig = go.Figure() + base_fig.add_trace(go.Scatter(y=y)) + + fr_fig = FigureResampler(base_fig.to_dict(), default_n_shown_samples=1000) + assert len(fr_fig.hf_data) == 1 + assert (fr_fig.hf_data[0]["y"] == y).all() + assert len(fr_fig.data) == 1 + assert len(fr_fig.data[0]["x"]) == 1_000 + assert (fr_fig.data[0]["x"][0] >= 0) & (fr_fig.data[0]["x"][-1] < 10_000) + assert (fr_fig.data[0]["y"] == [1] * 1_000).all() + + # assert that all the uuids of data and hf_data match + # this is a proxy for assuring that the dynamic aggregation should work + assert fr_fig.data[0].uid in fr_fig._hf_data + + def test_fr_empty_list(): # and empty list -> so no concrete traces were added fr_fig = FigureResampler([], default_n_shown_samples=1000) @@ -927,3 +945,57 @@ def test_fr_object_binary_data(): assert fig.hf_data[0]["y"].dtype == "int64" assert fig.data[0]["y"].dtype == "int64" assert np.all(fig.data[0]["y"] == binary_series) + + +def test_fr_copy_grid(): + f = make_subplots(rows=2, cols=1) + f.add_scatter(y=np.arange(2_000), row=1, col=1) + f.add_scatter(y=np.arange(2_000), row=2, col=1) + + ## go.Figure + assert isinstance(f, go.Figure) + assert f._grid_ref is not None + fr = FigureResampler(f) + assert fr._grid_ref is not None + assert fr._grid_ref == f._grid_ref + + ## go.FigureWidget + fw = go.FigureWidget(f) + assert fw._grid_ref is not None + assert isinstance(fw, go.FigureWidget) + fr = FigureResampler(fw) + assert fr._grid_ref is not None + assert fr._grid_ref == fw._grid_ref + + ## FigureResampler + fr_ = FigureResampler(f) + assert fr_._grid_ref is not None + assert isinstance(fr_, FigureResampler) + fr = FigureResampler(fr_) + assert fr._grid_ref is not None + assert fr._grid_ref == fr_._grid_ref + + ## FigureWidgetResampler + from plotly_resampler import FigureWidgetResampler + fwr = FigureWidgetResampler(f) + assert fwr._grid_ref is not None + assert isinstance(fwr, FigureWidgetResampler) + fr = FigureResampler(fwr) + assert fr._grid_ref is not None + assert fr._grid_ref == fwr._grid_ref + + ## dict (with no _grid_ref) + f_dict = f.to_dict() + assert isinstance(f_dict, dict) + assert f_dict.get("_grid_ref") is None + fr = FigureResampler(f_dict) + assert fr._grid_ref is f_dict.get("_grid_ref") # both are None + + ## dict (with _grid_ref) + f_dict = f.to_dict() + f_dict["_grid_ref"] = f._grid_ref + assert isinstance(f_dict, dict) + assert f_dict.get("_grid_ref") is not None + fr = FigureResampler(f_dict) + assert fr._grid_ref is not None + assert fr._grid_ref == f_dict.get("_grid_ref") diff --git a/tests/test_figurewidget_resampler.py b/tests/test_figurewidget_resampler.py index 49fcbe90..dc26047f 100644 --- a/tests/test_figurewidget_resampler.py +++ b/tests/test_figurewidget_resampler.py @@ -1534,24 +1534,42 @@ def test_fwr_time_based_data_s(): assert (text == -hovertext).sum() == 1000 -def test_fwr_from_dict(): +def test_fwr_from_trace_dict(): y = np.array([1] * 10_000) base_fig = { "type": "scatter", "y": y, } - fr_fig = FigureWidgetResampler(base_fig, default_n_shown_samples=1000) - assert len(fr_fig.hf_data) == 1 - assert (fr_fig.hf_data[0]["y"] == y).all() - assert len(fr_fig.data) == 1 - assert len(fr_fig.data[0]["x"]) == 1_000 - assert (fr_fig.data[0]["x"][0] >= 0) & (fr_fig.data[0]["x"][-1] < 10_000) - assert (fr_fig.data[0]["y"] == [1] * 1_000).all() + fwr_fig = FigureWidgetResampler(base_fig, default_n_shown_samples=1000) + assert len(fwr_fig.hf_data) == 1 + assert (fwr_fig.hf_data[0]["y"] == y).all() + assert len(fwr_fig.data) == 1 + assert len(fwr_fig.data[0]["x"]) == 1_000 + assert (fwr_fig.data[0]["x"][0] >= 0) & (fwr_fig.data[0]["x"][-1] < 10_000) + assert (fwr_fig.data[0]["y"] == [1] * 1_000).all() # assert that all the uuids of data and hf_data match # this is a proxy for assuring that the dynamic aggregation should work - assert fr_fig.data[0].uid in fr_fig._hf_data + assert fwr_fig.data[0].uid in fwr_fig._hf_data + + +def test_fwr_from_figure_dict(): + y = np.array([1] * 10_000) + base_fig = go.Figure() + base_fig.add_trace(go.Scatter(y=y)) + + fwr_fig = FigureWidgetResampler(base_fig.to_dict(), default_n_shown_samples=1000) + assert len(fwr_fig.hf_data) == 1 + assert (fwr_fig.hf_data[0]["y"] == y).all() + assert len(fwr_fig.data) == 1 + assert len(fwr_fig.data[0]["x"]) == 1_000 + assert (fwr_fig.data[0]["x"][0] >= 0) & (fwr_fig.data[0]["x"][-1] < 10_000) + assert (fwr_fig.data[0]["y"] == [1] * 1_000).all() + + # assert that all the uuids of data and hf_data match + # this is a proxy for assuring that the dynamic aggregation should work + assert fwr_fig.data[0].uid in fwr_fig._hf_data def test_fwr_empty_list(): @@ -1796,3 +1814,57 @@ def test_fwr_object_binary_data(): assert fig.hf_data[0]["y"].dtype == "int64" assert fig.data[0]["y"].dtype == "int64" assert np.all(fig.data[0]["y"] == binary_series) + + +def test_fwr_copy_grid(): + f = make_subplots(rows=2, cols=1) + f.add_scatter(y=np.arange(2_000), row=1, col=1) + f.add_scatter(y=np.arange(2_000), row=2, col=1) + + ## go.Figure + assert isinstance(f, go.Figure) + assert f._grid_ref is not None + fwr = FigureWidgetResampler(f) + assert fwr._grid_ref is not None + assert fwr._grid_ref == f._grid_ref + + ## go.FigureWidget + fw = go.FigureWidget(f) + assert fw._grid_ref is not None + assert isinstance(fw, go.FigureWidget) + fwr = FigureWidgetResampler(fw) + assert fwr._grid_ref is not None + assert fwr._grid_ref == fw._grid_ref + + ## FigureWidgetResampler + fwr_ = FigureWidgetResampler(f) + assert fwr_._grid_ref is not None + assert isinstance(fwr_, FigureWidgetResampler) + fwr = FigureWidgetResampler(fwr_) + assert fwr._grid_ref is not None + assert fwr._grid_ref == fwr_._grid_ref + + ## FigureResampler + from plotly_resampler import FigureResampler + fr = FigureResampler(f) + assert fr._grid_ref is not None + assert isinstance(fr, FigureResampler) + fwr = FigureWidgetResampler(fr) + assert fwr._grid_ref is not None + assert fwr._grid_ref == fr._grid_ref + + ## dict (with no _grid_ref) + f_dict = f.to_dict() + assert isinstance(f_dict, dict) + assert f_dict.get("_grid_ref") is None + fwr = FigureWidgetResampler(f_dict) + assert fwr._grid_ref is f_dict.get("_grid_ref") # both are None + + ## dict (with _grid_ref) + f_dict = f.to_dict() + f_dict["_grid_ref"] = f._grid_ref + assert isinstance(f_dict, dict) + assert f_dict.get("_grid_ref") is not None + fwr = FigureWidgetResampler(f_dict) + assert fwr._grid_ref is not None + assert fwr._grid_ref == f_dict.get("_grid_ref")