Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions plotly_resampler/figure_resampler/figure_resampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,18 @@ def __init__(
if isinstance(figure, BaseFigure): # go.FigureWidget or AbstractFigureAggregator
# A base figure object, we first copy the layout and grid ref
f.layout = figure.layout
f._grid_str = figure._grid_str
f._grid_ref = figure._grid_ref
f.add_traces(figure.data)
elif isinstance(figure, dict) and (
"data" in figure or "layout" in figure # or "frames" in figure # TODO
):
# A dict with data, layout or frames
f.layout = figure.get("layout")
f._grid_str = figure.get("_grid_str")
f._grid_ref = figure.get("_grid_ref")
f.add_traces(figure.get("data"))
# f.add_frames(figure.get("frames")) TODO
elif isinstance(figure, (dict, list)):
# A single trace dict or a list of traces
f.add_traces(figure)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def __init__(
# call __init__ with the correct layout and set the `_grid_ref` of the
# to-be-converted figure
f_ = self._figure_class(layout=figure.layout)
f_._grid_str = figure._grid_str
f_._grid_ref = figure._grid_ref
super().__init__(f_)

Expand Down
9 changes: 9 additions & 0 deletions plotly_resampler/figure_resampler/figurewidget_resampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,17 @@ def __init__(
if isinstance(figure, BaseFigure): # go.Figure or go.FigureWidget or AbstractFigureAggregator
# A base figure object, we first copy the layout and grid ref
f.layout = figure.layout
f._grid_str = figure._grid_str
f._grid_ref = figure._grid_ref
f.add_traces(figure.data)
elif isinstance(figure, dict) and (
"data" in figure or "layout" in figure # or "frames" in figure # TODO
):
f.layout = figure.get("layout")
f._grid_str = figure.get("_grid_str")
f._grid_ref = figure.get("_grid_ref")
f.add_traces(figure.get("data"))
# f.add_frames(figure.get("frames")) TODO
elif isinstance(figure, (dict, list)):
# A single trace dict or a list of traces
f.add_traces(figure)
Expand Down
74 changes: 73 additions & 1 deletion tests/test_figure_resampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,7 @@ def test_fr_add_empty_trace():
assert len(fig.hf_data[0]["y"]) == 0


def test_fr_from_dict():
def test_fr_from_trace_dict():
y = np.array([1] * 10_000)
base_fig = {
"type": "scatter",
Expand All @@ -668,6 +668,24 @@ def test_fr_from_dict():
assert fr_fig.data[0].uid in fr_fig._hf_data


def test_fr_from_figure_dict():
y = np.array([1] * 10_000)
base_fig = go.Figure()
base_fig.add_trace(go.Scatter(y=y))

fr_fig = FigureResampler(base_fig.to_dict(), default_n_shown_samples=1000)
assert len(fr_fig.hf_data) == 1
assert (fr_fig.hf_data[0]["y"] == y).all()
assert len(fr_fig.data) == 1
assert len(fr_fig.data[0]["x"]) == 1_000
assert (fr_fig.data[0]["x"][0] >= 0) & (fr_fig.data[0]["x"][-1] < 10_000)
assert (fr_fig.data[0]["y"] == [1] * 1_000).all()

# assert that all the uuids of data and hf_data match
# this is a proxy for assuring that the dynamic aggregation should work
assert fr_fig.data[0].uid in fr_fig._hf_data


def test_fr_empty_list():
# and empty list -> so no concrete traces were added
fr_fig = FigureResampler([], default_n_shown_samples=1000)
Expand Down Expand Up @@ -927,3 +945,57 @@ def test_fr_object_binary_data():
assert fig.hf_data[0]["y"].dtype == "int64"
assert fig.data[0]["y"].dtype == "int64"
assert np.all(fig.data[0]["y"] == binary_series)


def test_fr_copy_grid():
f = make_subplots(rows=2, cols=1)
f.add_scatter(y=np.arange(2_000), row=1, col=1)
f.add_scatter(y=np.arange(2_000), row=2, col=1)

## go.Figure
assert isinstance(f, go.Figure)
assert f._grid_ref is not None
fr = FigureResampler(f)
assert fr._grid_ref is not None
assert fr._grid_ref == f._grid_ref

## go.FigureWidget
fw = go.FigureWidget(f)
assert fw._grid_ref is not None
assert isinstance(fw, go.FigureWidget)
fr = FigureResampler(fw)
assert fr._grid_ref is not None
assert fr._grid_ref == fw._grid_ref

## FigureResampler
fr_ = FigureResampler(f)
assert fr_._grid_ref is not None
assert isinstance(fr_, FigureResampler)
fr = FigureResampler(fr_)
assert fr._grid_ref is not None
assert fr._grid_ref == fr_._grid_ref

## FigureWidgetResampler
from plotly_resampler import FigureWidgetResampler
fwr = FigureWidgetResampler(f)
assert fwr._grid_ref is not None
assert isinstance(fwr, FigureWidgetResampler)
fr = FigureResampler(fwr)
assert fr._grid_ref is not None
assert fr._grid_ref == fwr._grid_ref

## dict (with no _grid_ref)
f_dict = f.to_dict()
assert isinstance(f_dict, dict)
assert f_dict.get("_grid_ref") is None
fr = FigureResampler(f_dict)
assert fr._grid_ref is f_dict.get("_grid_ref") # both are None

## dict (with _grid_ref)
f_dict = f.to_dict()
f_dict["_grid_ref"] = f._grid_ref
assert isinstance(f_dict, dict)
assert f_dict.get("_grid_ref") is not None
fr = FigureResampler(f_dict)
assert fr._grid_ref is not None
assert fr._grid_ref == f_dict.get("_grid_ref")
90 changes: 81 additions & 9 deletions tests/test_figurewidget_resampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1534,24 +1534,42 @@ def test_fwr_time_based_data_s():
assert (text == -hovertext).sum() == 1000


def test_fwr_from_dict():
def test_fwr_from_trace_dict():
y = np.array([1] * 10_000)
base_fig = {
"type": "scatter",
"y": y,
}

fr_fig = FigureWidgetResampler(base_fig, default_n_shown_samples=1000)
assert len(fr_fig.hf_data) == 1
assert (fr_fig.hf_data[0]["y"] == y).all()
assert len(fr_fig.data) == 1
assert len(fr_fig.data[0]["x"]) == 1_000
assert (fr_fig.data[0]["x"][0] >= 0) & (fr_fig.data[0]["x"][-1] < 10_000)
assert (fr_fig.data[0]["y"] == [1] * 1_000).all()
fwr_fig = FigureWidgetResampler(base_fig, default_n_shown_samples=1000)
assert len(fwr_fig.hf_data) == 1
assert (fwr_fig.hf_data[0]["y"] == y).all()
assert len(fwr_fig.data) == 1
assert len(fwr_fig.data[0]["x"]) == 1_000
assert (fwr_fig.data[0]["x"][0] >= 0) & (fwr_fig.data[0]["x"][-1] < 10_000)
assert (fwr_fig.data[0]["y"] == [1] * 1_000).all()

# assert that all the uuids of data and hf_data match
# this is a proxy for assuring that the dynamic aggregation should work
assert fr_fig.data[0].uid in fr_fig._hf_data
assert fwr_fig.data[0].uid in fwr_fig._hf_data


def test_fwr_from_figure_dict():
y = np.array([1] * 10_000)
base_fig = go.Figure()
base_fig.add_trace(go.Scatter(y=y))

fwr_fig = FigureWidgetResampler(base_fig.to_dict(), default_n_shown_samples=1000)
assert len(fwr_fig.hf_data) == 1
assert (fwr_fig.hf_data[0]["y"] == y).all()
assert len(fwr_fig.data) == 1
assert len(fwr_fig.data[0]["x"]) == 1_000
assert (fwr_fig.data[0]["x"][0] >= 0) & (fwr_fig.data[0]["x"][-1] < 10_000)
assert (fwr_fig.data[0]["y"] == [1] * 1_000).all()

# assert that all the uuids of data and hf_data match
# this is a proxy for assuring that the dynamic aggregation should work
assert fwr_fig.data[0].uid in fwr_fig._hf_data


def test_fwr_empty_list():
Expand Down Expand Up @@ -1796,3 +1814,57 @@ def test_fwr_object_binary_data():
assert fig.hf_data[0]["y"].dtype == "int64"
assert fig.data[0]["y"].dtype == "int64"
assert np.all(fig.data[0]["y"] == binary_series)


def test_fwr_copy_grid():
f = make_subplots(rows=2, cols=1)
f.add_scatter(y=np.arange(2_000), row=1, col=1)
f.add_scatter(y=np.arange(2_000), row=2, col=1)

## go.Figure
assert isinstance(f, go.Figure)
assert f._grid_ref is not None
fwr = FigureWidgetResampler(f)
assert fwr._grid_ref is not None
assert fwr._grid_ref == f._grid_ref

## go.FigureWidget
fw = go.FigureWidget(f)
assert fw._grid_ref is not None
assert isinstance(fw, go.FigureWidget)
fwr = FigureWidgetResampler(fw)
assert fwr._grid_ref is not None
assert fwr._grid_ref == fw._grid_ref

## FigureWidgetResampler
fwr_ = FigureWidgetResampler(f)
assert fwr_._grid_ref is not None
assert isinstance(fwr_, FigureWidgetResampler)
fwr = FigureWidgetResampler(fwr_)
assert fwr._grid_ref is not None
assert fwr._grid_ref == fwr_._grid_ref

## FigureResampler
from plotly_resampler import FigureResampler
fr = FigureResampler(f)
assert fr._grid_ref is not None
assert isinstance(fr, FigureResampler)
fwr = FigureWidgetResampler(fr)
assert fwr._grid_ref is not None
assert fwr._grid_ref == fr._grid_ref

## dict (with no _grid_ref)
f_dict = f.to_dict()
assert isinstance(f_dict, dict)
assert f_dict.get("_grid_ref") is None
fwr = FigureWidgetResampler(f_dict)
assert fwr._grid_ref is f_dict.get("_grid_ref") # both are None

## dict (with _grid_ref)
f_dict = f.to_dict()
f_dict["_grid_ref"] = f._grid_ref
assert isinstance(f_dict, dict)
assert f_dict.get("_grid_ref") is not None
fwr = FigureWidgetResampler(f_dict)
assert fwr._grid_ref is not None
assert fwr._grid_ref == f_dict.get("_grid_ref")