diff --git a/dash_slicer/slicer.py b/dash_slicer/slicer.py index cabaf5a..f892635 100644 --- a/dash_slicer/slicer.py +++ b/dash_slicer/slicer.py @@ -4,7 +4,7 @@ from dash.dependencies import Input, Output, State, ALL from dash_core_components import Graph, Slider, Store -from .utils import img_array_to_uri, get_thumbnail_size_from_shape, shape3d_to_size2d +from .utils import img_array_to_uri, get_thumbnail_size, shape3d_to_size2d class VolumeSlicer: @@ -15,34 +15,29 @@ class VolumeSlicer: volume (ndarray): the 3D numpy array to slice through. The dimensions are assumed to be in zyx order. If this is not the case, you can use ``np.swapaxes`` to make it so. - spacing (tuple of floats): The distance between voxels for each dimension (zyx). - The spacing and origin are applied to make the slice drawn in - "scene space" rather than "voxel space". + spacing (tuple of floats): The distance between voxels for each + dimension (zyx).The spacing and origin are applied to make the slice + drawn in "scene space" rather than "voxel space". origin (tuple of floats): The offset for each dimension (zyx). axis (int): the dimension to slice in. Default 0. reverse_y (bool): Whether to reverse the y-axis, so that the origin of the slice is in the top-left, rather than bottom-left. Default True. - (This sets the figure's yaxes ``autorange`` to either "reversed" or True.) + (This sets the figure's yaxes ``autorange`` to "reversed" or True.) scene_id (str): the scene that this slicer is part of. Slicers that have the same scene-id show each-other's positions with - line indicators. By default this is a hash of ``id(volume)``. + line indicators. By default this is derived from ``id(volume)``. This is a placeholder object, not a Dash component. The components - that make up the slicer can be accessed as attributes: + that make up the slicer can be accessed as attributes. These must all + be present in the app layout: - * ``graph``: the Graph object. - * ``slider``: the Slider object. - * ``stores``: a list of Store objects. Some are "public" values, others - used internally. Make sure to put them somewhere in the layout. + * ``graph``: the dcc.Graph object. Use ``graph.figure`` to access the + Plotly figure object. + * ``slider``: the dcc.Slider object, its value represents the slice + index. If you don't want to use the slider, wrap it in a div with + style ``display: none``. + * ``stores``: a list of dcc.Store objects. - Each component is given a dict-id with the following keys: - - * "context": a unique string id for this slicer instance. - * "scene": the scene_id. - * "axis": the int axis. - * "name": the name of the (sub) component. - - TODO: iron out these details, list the stores that are public """ _global_slicer_counter = 0 @@ -58,10 +53,11 @@ def __init__( reverse_y=True, scene_id=None ): - # todo: also implement xyz dim order? + if not isinstance(app, Dash): raise TypeError("Expect first arg to be a Dash app.") self._app = app + # Check and store volume if not (isinstance(volume, np.ndarray) and volume.ndim == 3): raise TypeError("Expected volume to be a 3D numpy array") @@ -70,22 +66,26 @@ def __init__( spacing = float(spacing[0]), float(spacing[1]), float(spacing[2]) origin = (0, 0, 0) if origin is None else origin origin = float(origin[0]), float(origin[1]), float(origin[2]) + # Check and store axis if not (isinstance(axis, int) and 0 <= axis <= 2): raise ValueError("The given axis must be 0, 1, or 2.") self._axis = int(axis) - # Check and store id + self._reverse_y = bool(reverse_y) + + # Check and store scene id, and generate if scene_id is None: scene_id = "volume_" + hex(id(volume))[2:] elif not isinstance(scene_id, str): raise TypeError("scene_id must be a string") - self.scene_id = scene_id + self._scene_id = scene_id + # Get unique id scoped to this slicer object VolumeSlicer._global_slicer_counter += 1 - self.context_id = "slicer_" + str(VolumeSlicer._global_slicer_counter) + self._context_id = "slicer" + str(VolumeSlicer._global_slicer_counter) - # Prepare slice info - info = { + # Prepare slice info that we use at the client side + self._slice_info = { "shape": tuple(volume.shape), "axis": self._axis, "size": shape3d_to_size2d(volume.shape, axis), @@ -93,10 +93,70 @@ def __init__( "spacing": shape3d_to_size2d(spacing, axis), } + # Build the slicer + self._create_dash_components() + self._create_server_callbacks() + self._create_client_callbacks() + + # Note(AK): we could make some stores public, but let's do this only when actual use-cases arise? + + @property + def scene_id(self): + """The id of the "virtual scene" for this slicer. Slicers that have + the same scene_id show each-other's positions. + """ + return self._scene_id + + @property + def axis(self): + """The axis at which the slicer is slicing.""" + return self._axis + + @property + def graph(self): + """The dcc.Graph for this slicer.""" + return self._graph + + @property + def slider(self): + """The dcc.Slider to change the index for this slicer.""" + return self._slider + + @property + def stores(self): + """A list of dcc.Stores that the slicer needs to work. These must + be added to the app layout. + """ + return self._stores + + def _subid(self, name, use_dict=False): + """Given a name, get the full id including the context id prefix.""" + if use_dict: + # A dict-id is nice to query objects with pattern matching callbacks, + # and we use that to show the position of other sliders. But it makes + # the id's very long, which is annoying e.g. in the callback graph. + return { + "context": self._context_id, + "scene": self._scene_id, + "axis": self._axis, + "name": name, + } + else: + return self._context_id + "-" + name + + def _slice(self, index): + """Sample a slice from the volume.""" + indices = [slice(None), slice(None), slice(None)] + indices[self._axis] = index + im = self._volume[tuple(indices)] + return (im.astype(np.float32) * (255 / im.max())).astype(np.uint8) + + def _create_dash_components(self): + """Create the graph, slider, figure, etc.""" + info = self._slice_info + # Prep low-res slices - thumbnail_size = get_thumbnail_size_from_shape( - (info["size"][1], info["size"][0]), 32 - ) + thumbnail_size = get_thumbnail_size(info["size"][:2], (32, 32)) thumbnails = [ img_array_to_uri(self._slice(i), thumbnail_size) for i in range(info["size"][2]) @@ -109,35 +169,35 @@ def __init__( source="", dx=1, dy=1, hovertemplate="(%{x}, %{y})" ) scatter_trace = Scatter(x=[], y=[]) # placeholder - # Create the figure object + + # Create the figure object - can be accessed by user via slicer.graph.figure self._fig = fig = Figure(data=[image_trace, scatter_trace]) fig.update_layout( template=None, margin=dict(l=0, r=0, b=0, t=0, pad=4), ) fig.update_xaxes( - # range=(0, slice_size[0]), showgrid=False, showticklabels=False, zeroline=False, ) fig.update_yaxes( - # range=(slice_size[1], 0), # todo: allow flipping x or y showgrid=False, scaleanchor="x", showticklabels=False, zeroline=False, - autorange="reversed" if reverse_y else True, + autorange="reversed" if self._reverse_y else True, ) - # Wrap the figure in a graph - # todo: or should the user provide this? - self.graph = Graph( + + # Create the graph (graph is a Dash component wrapping a Plotly figure) + self._graph = Graph( id=self._subid("graph"), figure=fig, config={"scrollZoom": True}, ) + # Create a slider object that the user can put in the layout (or not) - self.slider = Slider( + self._slider = Slider( id=self._subid("slider"), min=0, max=info["size"][2] - 1, @@ -146,45 +206,30 @@ def __init__( tooltip={"always_visible": False, "placement": "left"}, updatemode="drag", ) + # Create the stores that we need (these must be present in the layout) - self.stores = [ - Store(id=self._subid("info"), data=info), - Store(id=self._subid("index"), data=volume.shape[self._axis] // 2), - Store(id=self._subid("position"), data=0), - Store(id=self._subid("_requested-slice-index"), data=0), - Store(id=self._subid("_slice-data"), data=""), - Store(id=self._subid("_slice-data-lowres"), data=thumbnails), - Store(id=self._subid("_indicators"), data=[]), + self._info = Store(id=self._subid("info"), data=info) + self._position = Store(id=self._subid("position", True), data=0) + self._requested_index = Store(id=self._subid("req-index"), data=0) + self._request_data = Store(id=self._subid("req-data"), data="") + self._lowres_data = Store(id=self._subid("lowres-data"), data=thumbnails) + self._indicators = Store(id=self._subid("indicators"), data=[]) + self._stores = [ + self._info, + self._position, + self._requested_index, + self._request_data, + self._lowres_data, + self._indicators, ] - self._create_server_callbacks() - self._create_client_callbacks() - - def _subid(self, name): - """Given a subid, get the full id including the slicer's prefix.""" - # return self.context_id + "-" + name - # todo: is there a penalty for using a dict-id vs a string-id? - return { - "context": self.context_id, - "scene": self.scene_id, - "axis": self._axis, - "name": name, - } - - def _slice(self, index): - """Sample a slice from the volume.""" - indices = [slice(None), slice(None), slice(None)] - indices[self._axis] = index - im = self._volume[tuple(indices)] - return (im.astype(np.float32) * (255 / im.max())).astype(np.uint8) - def _create_server_callbacks(self): """Create the callbacks that run server-side.""" app = self._app @app.callback( - Output(self._subid("_slice-data"), "data"), - [Input(self._subid("_requested-slice-index"), "data")], + Output(self._request_data.id, "data"), + [Input(self._requested_index.id, "data")], ) def upload_requested_slice(slice_index): slice = self._slice(slice_index) @@ -194,25 +239,15 @@ def _create_client_callbacks(self): """Create the callbacks that run client-side.""" app = self._app - app.clientside_callback( - """ - function handle_slider_move(index) { - return index; - } - """, - Output(self._subid("index"), "data"), - [Input(self._subid("slider"), "value")], - ) - app.clientside_callback( """ function update_position(index, info) { return info.origin[2] + index * info.spacing[2]; } """, - Output(self._subid("position"), "data"), - [Input(self._subid("index"), "data")], - [State(self._subid("info"), "data")], + Output(self._position.id, "data"), + [Input(self.slider.id, "value")], + [State(self._info.id, "data")], ) app.clientside_callback( @@ -228,21 +263,12 @@ def _create_client_callbacks(self): } } """.replace( - "{{ID}}", self.context_id + "{{ID}}", self._context_id ), - Output(self._subid("_requested-slice-index"), "data"), - [Input(self._subid("index"), "data")], + Output(self._requested_index.id, "data"), + [Input(self.slider.id, "value")], ) - # app.clientside_callback(""" - # function update_slider_pos(index) { - # return index; - # } - # """, - # [Output("index", "data")], - # [State("slider", "value")], - # ) - app.clientside_callback( """ function handle_incoming_slice(index, index_and_data, indicators, ori_figure, lowres, info) { @@ -282,18 +308,18 @@ def _create_client_callbacks(self): return figure; } """.replace( - "{{ID}}", self.context_id + "{{ID}}", self._context_id ), - Output(self._subid("graph"), "figure"), + Output(self.graph.id, "figure"), [ - Input(self._subid("index"), "data"), - Input(self._subid("_slice-data"), "data"), - Input(self._subid("_indicators"), "data"), + Input(self.slider.id, "value"), + Input(self._request_data.id, "data"), + Input(self._indicators.id, "data"), ], [ - State(self._subid("graph"), "figure"), - State(self._subid("_slice-data-lowres"), "data"), - State(self._subid("info"), "data"), + State(self.graph.id, "figure"), + State(self._lowres_data.id, "data"), + State(self._info.id, "data"), ], ) @@ -334,11 +360,11 @@ def _create_client_callbacks(self): }; } """, - Output(self._subid("_indicators"), "data"), + Output(self._indicators.id, "data"), [ Input( { - "scene": self.scene_id, + "scene": self._scene_id, "context": ALL, "name": "position", "axis": axis, @@ -348,7 +374,7 @@ def _create_client_callbacks(self): for axis in axii ], [ - State(self._subid("info"), "data"), - State(self._subid("_indicators"), "data"), + State(self._info.id, "data"), + State(self._indicators.id, "data"), ], ) diff --git a/dash_slicer/utils.py b/dash_slicer/utils.py index 3bb57a1..435ec11 100644 --- a/dash_slicer/utils.py +++ b/dash_slicer/utils.py @@ -1,5 +1,4 @@ import io -import random import base64 import numpy as np @@ -7,11 +6,8 @@ import skimage -def gen_random_id(n=6): - return "".join(random.choice("abcdefghijklmnopqrtsuvwxyz") for i in range(n)) - - def img_array_to_uri(img_array, new_size=None): + """Convert the given image (numpy array) into a base64-encoded PNG.""" img_array = skimage.util.img_as_ubyte(img_array) # todo: leverage this Plotly util once it becomes part of the public API (also drops the Pillow dependency) # from plotly.express._imshow import _array_to_b64str @@ -26,11 +22,13 @@ def img_array_to_uri(img_array, new_size=None): return "data:image/png;base64," + base64_str -def get_thumbnail_size_from_shape(shape, base_size): - base_size = int(base_size) - img_array = np.zeros(shape, np.uint8) +def get_thumbnail_size(size, new_size): + """Given an image size (w, h), and a preferred smaller size, + get the actual size if we let Pillow downscale it. + """ + img_array = np.zeros(list(reversed(size)), np.uint8) img_pil = PIL.Image.fromarray(img_array) - img_pil.thumbnail((base_size, base_size)) + img_pil.thumbnail(new_size) return img_pil.size diff --git a/examples/bring_your_own_slider.py b/examples/bring_your_own_slider.py new file mode 100644 index 0000000..265ef98 --- /dev/null +++ b/examples/bring_your_own_slider.py @@ -0,0 +1,47 @@ +""" +Bring your own slider ... or dropdown. This example shows how to use a +different input element for the slice index. The slider's value is used +as an output, but the slider element itself is hidden. +""" + +import dash +import dash_html_components as html +import dash_core_components as dcc +from dash.dependencies import Input, Output +from dash_slicer import VolumeSlicer +import imageio + + +app = dash.Dash(__name__) + +vol = imageio.volread("imageio:stent.npz") +slicer = VolumeSlicer(app, vol) + +dropdown = dcc.Dropdown( + id="dropdown", + options=[{"label": f"slice {i}", "value": i} for i in range(0, vol.shape[0], 10)], + value=50, +) + + +# Define the layout +app.layout = html.Div( + [ + slicer.graph, + dropdown, + html.Div(slicer.slider, style={"display": "none"}), + *slicer.stores, + ] +) + + +@app.callback( + Output(slicer.slider.id, "value"), + [Input(dropdown.id, "value")], +) +def handle_dropdown_input(index): + return index + + +if __name__ == "__main__": + app.run_server(debug=True) diff --git a/examples/use_components.py b/examples/slicer_customized.py similarity index 70% rename from examples/use_components.py rename to examples/slicer_customized.py index 42d309e..4e4347b 100644 --- a/examples/use_components.py +++ b/examples/slicer_customized.py @@ -1,6 +1,6 @@ """ -A small example showing how to write callbacks involving the slicer's -components. The slicer's components are used as both inputs and outputs. +An example showing how to customize the slicer and write callbacks +involving the slicer's components. """ import dash @@ -15,10 +15,20 @@ vol = imageio.volread("imageio:stent.npz") slicer = VolumeSlicer(app, vol) + # We can access the components, and modify them slicer.slider.value = 0 -# Define the layour, including extra buttons +# The graph can be configured +slicer.graph.config.update({"modeBarButtonsToAdd": ["drawclosedpath", "eraseshape"]}) + +# The plotly figure can be accessed too +slicer.graph.figure.update_layout(margin=dict(l=0, r=0, b=30, t=0, pad=4)) +slicer.graph.figure.update_xaxes(showgrid=True, showticklabels=True) +slicer.graph.figure.update_yaxes(showgrid=True, showticklabels=True) + + +# Define the layout, including extra buttons app.layout = html.Div( [ slicer.graph, diff --git a/setup.py b/setup.py index 343f16e..cabdf35 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ SUMMARY = "A volume slicer for Dash" -with open(f"{NAME}/__init__.py") as fh: +with open(f"{NAME.replace('-', '_')}/__init__.py") as fh: VERSION = re.search(r"__version__ = \"(.*?)\"", fh.read()).group(1) diff --git a/tests/test_slicer.py b/tests/test_slicer.py new file mode 100644 index 0000000..e90ec4f --- /dev/null +++ b/tests/test_slicer.py @@ -0,0 +1,47 @@ +from dash_slicer import VolumeSlicer + +import numpy as np +from pytest import raises +import dash +import dash_core_components as dcc + + +def test_slicer_init(): + app = dash.Dash() + + vol = np.random.uniform(0, 255, (100, 100, 100)).astype(np.uint8) + + # Need a valid volume + with raises(TypeError): + VolumeSlicer(app, [3, 4, 5]) + with raises(TypeError): + VolumeSlicer(app, vol[0]) + + # Need a valid axis + with raises(ValueError): + VolumeSlicer(app, vol, axis=4) + + # This works + s = VolumeSlicer(app, vol) + + # Check properties + assert isinstance(s.graph, dcc.Graph) + assert isinstance(s.slider, dcc.Slider) + assert isinstance(s.stores, list) + assert all(isinstance(store, dcc.Store) for store in s.stores) + + +def test_scene_id_and_context_id(): + app = dash.Dash() + + vol = np.random.uniform(0, 255, (100, 100, 100)).astype(np.uint8) + + s1 = VolumeSlicer(app, vol, axis=0) + s2 = VolumeSlicer(app, vol, axis=0) + s3 = VolumeSlicer(app, vol, axis=1) + + # The scene id's are equal, so indicators will match up + assert s1.scene_id == s2.scene_id and s1.scene_id == s3.scene_id + + # Context id's must be unique + assert s1._context_id != s2._context_id and s1._context_id != s3._context_id diff --git a/tests/test_utils.py b/tests/test_utils.py index e1e20ab..d8ee47e 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,8 +1,31 @@ -from dash_slicer.utils import shape3d_to_size2d +from dash_slicer.utils import img_array_to_uri, get_thumbnail_size, shape3d_to_size2d +import numpy as np from pytest import raises +def test_img_array_to_uri(): + + im = np.random.uniform(0, 255, (100, 100)).astype(np.uint8) + + r1 = img_array_to_uri(im) + r2 = img_array_to_uri(im, (32, 32)) + r3 = img_array_to_uri(im, (8, 8)) + + for r in (r1, r2, r3): + assert isinstance(r, str) + assert r.startswith("data:image/png;base64,") + + assert len(r1) > len(r2) > len(r3) + + +def test_get_thumbnail_size(): + + assert get_thumbnail_size((100, 100), (16, 16)) == (16, 16) + assert get_thumbnail_size((50, 100), (16, 16)) == (8, 16) + assert get_thumbnail_size((100, 100), (8, 16)) == (8, 8) + + def test_shape3d_to_size2d(): # shape -> z, y, x # size -> x, y, out-of-plane