Skip to content

Tweaking the API and cleaning up #13

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Nov 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
234 changes: 130 additions & 104 deletions dash_slicer/slicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dash.dependencies import Input, Output, State, ALL
from dash_core_components import Graph, Slider, Store

from .utils import img_array_to_uri, get_thumbnail_size_from_shape, shape3d_to_size2d
from .utils import img_array_to_uri, get_thumbnail_size, shape3d_to_size2d


class VolumeSlicer:
Expand All @@ -15,34 +15,29 @@ class VolumeSlicer:
volume (ndarray): the 3D numpy array to slice through. The dimensions
are assumed to be in zyx order. If this is not the case, you can
use ``np.swapaxes`` to make it so.
spacing (tuple of floats): The distance between voxels for each dimension (zyx).
The spacing and origin are applied to make the slice drawn in
"scene space" rather than "voxel space".
spacing (tuple of floats): The distance between voxels for each
dimension (zyx).The spacing and origin are applied to make the slice
drawn in "scene space" rather than "voxel space".
origin (tuple of floats): The offset for each dimension (zyx).
axis (int): the dimension to slice in. Default 0.
reverse_y (bool): Whether to reverse the y-axis, so that the origin of
the slice is in the top-left, rather than bottom-left. Default True.
(This sets the figure's yaxes ``autorange`` to either "reversed" or True.)
(This sets the figure's yaxes ``autorange`` to "reversed" or True.)
scene_id (str): the scene that this slicer is part of. Slicers
that have the same scene-id show each-other's positions with
line indicators. By default this is a hash of ``id(volume)``.
line indicators. By default this is derived from ``id(volume)``.

This is a placeholder object, not a Dash component. The components
that make up the slicer can be accessed as attributes:
that make up the slicer can be accessed as attributes. These must all
be present in the app layout:

* ``graph``: the Graph object.
* ``slider``: the Slider object.
* ``stores``: a list of Store objects. Some are "public" values, others
used internally. Make sure to put them somewhere in the layout.
* ``graph``: the dcc.Graph object. Use ``graph.figure`` to access the
Plotly figure object.
* ``slider``: the dcc.Slider object, its value represents the slice
index. If you don't want to use the slider, wrap it in a div with
style ``display: none``.
* ``stores``: a list of dcc.Store objects.

Each component is given a dict-id with the following keys:

* "context": a unique string id for this slicer instance.
* "scene": the scene_id.
* "axis": the int axis.
* "name": the name of the (sub) component.

TODO: iron out these details, list the stores that are public
"""

_global_slicer_counter = 0
Expand All @@ -58,10 +53,11 @@ def __init__(
reverse_y=True,
scene_id=None
):
# todo: also implement xyz dim order?

if not isinstance(app, Dash):
raise TypeError("Expect first arg to be a Dash app.")
self._app = app

# Check and store volume
if not (isinstance(volume, np.ndarray) and volume.ndim == 3):
raise TypeError("Expected volume to be a 3D numpy array")
Expand All @@ -70,33 +66,97 @@ def __init__(
spacing = float(spacing[0]), float(spacing[1]), float(spacing[2])
origin = (0, 0, 0) if origin is None else origin
origin = float(origin[0]), float(origin[1]), float(origin[2])

# Check and store axis
if not (isinstance(axis, int) and 0 <= axis <= 2):
raise ValueError("The given axis must be 0, 1, or 2.")
self._axis = int(axis)
# Check and store id
self._reverse_y = bool(reverse_y)

# Check and store scene id, and generate
if scene_id is None:
scene_id = "volume_" + hex(id(volume))[2:]
elif not isinstance(scene_id, str):
raise TypeError("scene_id must be a string")
self.scene_id = scene_id
self._scene_id = scene_id

# Get unique id scoped to this slicer object
VolumeSlicer._global_slicer_counter += 1
self.context_id = "slicer_" + str(VolumeSlicer._global_slicer_counter)
self._context_id = "slicer" + str(VolumeSlicer._global_slicer_counter)

# Prepare slice info
info = {
# Prepare slice info that we use at the client side
self._slice_info = {
"shape": tuple(volume.shape),
"axis": self._axis,
"size": shape3d_to_size2d(volume.shape, axis),
"origin": shape3d_to_size2d(origin, axis),
"spacing": shape3d_to_size2d(spacing, axis),
}

# Build the slicer
self._create_dash_components()
self._create_server_callbacks()
self._create_client_callbacks()

# Note(AK): we could make some stores public, but let's do this only when actual use-cases arise?

@property
def scene_id(self):
"""The id of the "virtual scene" for this slicer. Slicers that have
the same scene_id show each-other's positions.
"""
return self._scene_id

@property
def axis(self):
"""The axis at which the slicer is slicing."""
return self._axis

@property
def graph(self):
"""The dcc.Graph for this slicer."""
return self._graph

@property
def slider(self):
"""The dcc.Slider to change the index for this slicer."""
return self._slider

@property
def stores(self):
"""A list of dcc.Stores that the slicer needs to work. These must
be added to the app layout.
"""
return self._stores

def _subid(self, name, use_dict=False):
"""Given a name, get the full id including the context id prefix."""
if use_dict:
# A dict-id is nice to query objects with pattern matching callbacks,
# and we use that to show the position of other sliders. But it makes
# the id's very long, which is annoying e.g. in the callback graph.
return {
"context": self._context_id,
"scene": self._scene_id,
"axis": self._axis,
"name": name,
}
else:
return self._context_id + "-" + name

def _slice(self, index):
"""Sample a slice from the volume."""
indices = [slice(None), slice(None), slice(None)]
indices[self._axis] = index
im = self._volume[tuple(indices)]
return (im.astype(np.float32) * (255 / im.max())).astype(np.uint8)

def _create_dash_components(self):
"""Create the graph, slider, figure, etc."""
info = self._slice_info

# Prep low-res slices
thumbnail_size = get_thumbnail_size_from_shape(
(info["size"][1], info["size"][0]), 32
)
thumbnail_size = get_thumbnail_size(info["size"][:2], (32, 32))
thumbnails = [
img_array_to_uri(self._slice(i), thumbnail_size)
for i in range(info["size"][2])
Expand All @@ -109,35 +169,35 @@ def __init__(
source="", dx=1, dy=1, hovertemplate="(%{x}, %{y})<extra></extra>"
)
scatter_trace = Scatter(x=[], y=[]) # placeholder
# Create the figure object

# Create the figure object - can be accessed by user via slicer.graph.figure
self._fig = fig = Figure(data=[image_trace, scatter_trace])
fig.update_layout(
template=None,
margin=dict(l=0, r=0, b=0, t=0, pad=4),
)
fig.update_xaxes(
# range=(0, slice_size[0]),
showgrid=False,
showticklabels=False,
zeroline=False,
)
fig.update_yaxes(
# range=(slice_size[1], 0), # todo: allow flipping x or y
showgrid=False,
scaleanchor="x",
showticklabels=False,
zeroline=False,
autorange="reversed" if reverse_y else True,
autorange="reversed" if self._reverse_y else True,
)
# Wrap the figure in a graph
# todo: or should the user provide this?
self.graph = Graph(

# Create the graph (graph is a Dash component wrapping a Plotly figure)
self._graph = Graph(
id=self._subid("graph"),
figure=fig,
config={"scrollZoom": True},
)

# Create a slider object that the user can put in the layout (or not)
self.slider = Slider(
self._slider = Slider(
id=self._subid("slider"),
min=0,
max=info["size"][2] - 1,
Expand All @@ -146,45 +206,30 @@ def __init__(
tooltip={"always_visible": False, "placement": "left"},
updatemode="drag",
)

# Create the stores that we need (these must be present in the layout)
self.stores = [
Store(id=self._subid("info"), data=info),
Store(id=self._subid("index"), data=volume.shape[self._axis] // 2),
Store(id=self._subid("position"), data=0),
Store(id=self._subid("_requested-slice-index"), data=0),
Store(id=self._subid("_slice-data"), data=""),
Store(id=self._subid("_slice-data-lowres"), data=thumbnails),
Store(id=self._subid("_indicators"), data=[]),
self._info = Store(id=self._subid("info"), data=info)
self._position = Store(id=self._subid("position", True), data=0)
self._requested_index = Store(id=self._subid("req-index"), data=0)
self._request_data = Store(id=self._subid("req-data"), data="")
self._lowres_data = Store(id=self._subid("lowres-data"), data=thumbnails)
self._indicators = Store(id=self._subid("indicators"), data=[])
self._stores = [
self._info,
self._position,
self._requested_index,
self._request_data,
self._lowres_data,
self._indicators,
]

self._create_server_callbacks()
self._create_client_callbacks()

def _subid(self, name):
"""Given a subid, get the full id including the slicer's prefix."""
# return self.context_id + "-" + name
# todo: is there a penalty for using a dict-id vs a string-id?
return {
"context": self.context_id,
"scene": self.scene_id,
"axis": self._axis,
"name": name,
}

def _slice(self, index):
"""Sample a slice from the volume."""
indices = [slice(None), slice(None), slice(None)]
indices[self._axis] = index
im = self._volume[tuple(indices)]
return (im.astype(np.float32) * (255 / im.max())).astype(np.uint8)

def _create_server_callbacks(self):
"""Create the callbacks that run server-side."""
app = self._app

@app.callback(
Output(self._subid("_slice-data"), "data"),
[Input(self._subid("_requested-slice-index"), "data")],
Output(self._request_data.id, "data"),
[Input(self._requested_index.id, "data")],
)
def upload_requested_slice(slice_index):
slice = self._slice(slice_index)
Expand All @@ -194,25 +239,15 @@ def _create_client_callbacks(self):
"""Create the callbacks that run client-side."""
app = self._app

app.clientside_callback(
"""
function handle_slider_move(index) {
return index;
}
""",
Output(self._subid("index"), "data"),
[Input(self._subid("slider"), "value")],
)

app.clientside_callback(
"""
function update_position(index, info) {
return info.origin[2] + index * info.spacing[2];
}
""",
Output(self._subid("position"), "data"),
[Input(self._subid("index"), "data")],
[State(self._subid("info"), "data")],
Output(self._position.id, "data"),
[Input(self.slider.id, "value")],
[State(self._info.id, "data")],
)

app.clientside_callback(
Expand All @@ -228,21 +263,12 @@ def _create_client_callbacks(self):
}
}
""".replace(
"{{ID}}", self.context_id
"{{ID}}", self._context_id
),
Output(self._subid("_requested-slice-index"), "data"),
[Input(self._subid("index"), "data")],
Output(self._requested_index.id, "data"),
[Input(self.slider.id, "value")],
)

# app.clientside_callback("""
# function update_slider_pos(index) {
# return index;
# }
# """,
# [Output("index", "data")],
# [State("slider", "value")],
# )

app.clientside_callback(
"""
function handle_incoming_slice(index, index_and_data, indicators, ori_figure, lowres, info) {
Expand Down Expand Up @@ -282,18 +308,18 @@ def _create_client_callbacks(self):
return figure;
}
""".replace(
"{{ID}}", self.context_id
"{{ID}}", self._context_id
),
Output(self._subid("graph"), "figure"),
Output(self.graph.id, "figure"),
[
Input(self._subid("index"), "data"),
Input(self._subid("_slice-data"), "data"),
Input(self._subid("_indicators"), "data"),
Input(self.slider.id, "value"),
Input(self._request_data.id, "data"),
Input(self._indicators.id, "data"),
],
[
State(self._subid("graph"), "figure"),
State(self._subid("_slice-data-lowres"), "data"),
State(self._subid("info"), "data"),
State(self.graph.id, "figure"),
State(self._lowres_data.id, "data"),
State(self._info.id, "data"),
],
)

Expand Down Expand Up @@ -334,11 +360,11 @@ def _create_client_callbacks(self):
};
}
""",
Output(self._subid("_indicators"), "data"),
Output(self._indicators.id, "data"),
[
Input(
{
"scene": self.scene_id,
"scene": self._scene_id,
"context": ALL,
"name": "position",
"axis": axis,
Expand All @@ -348,7 +374,7 @@ def _create_client_callbacks(self):
for axis in axii
],
[
State(self._subid("info"), "data"),
State(self._subid("_indicators"), "data"),
State(self._info.id, "data"),
State(self._indicators.id, "data"),
],
)
Loading