From 925301cc78c15caa50dfbe681a3ddae50badc680 Mon Sep 17 00:00:00 2001 From: Matthew Brett Date: Wed, 20 Aug 2014 13:22:41 -0400 Subject: [PATCH 01/20] NF: add version of Paul Ivanov's slice viewer Thanks Paul... --- nibabel/viewers.py | 215 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 215 insertions(+) create mode 100644 nibabel/viewers.py diff --git a/nibabel/viewers.py b/nibabel/viewers.py new file mode 100644 index 0000000000..7ae4627d9a --- /dev/null +++ b/nibabel/viewers.py @@ -0,0 +1,215 @@ +""" Utilities for viewing images + +Includes version of OrthoSlicer3D code by our own Paul Ivanov +""" +from __future__ import division, print_function + +import numpy as np + +from .optpkg import optional_package + +plt, _, _ = optional_package('matplotlib.pyplot') +mpl_img, _, _ = optional_package('matplotlib.image') + +# Assumes the following layout +# +# ^ +---------+ ^ +---------+ +# | | | | | | +# | | | | +# z | 2 | z | 3 | +# | | | | +# | | | | | | +# v +---------+ v +---------+ +# <-- x --> <-- y --> +# ^ +---------+ +# | | | +# | | +# y | 1 | +# | | +# | | | +# v +---------+ +# <-- x --> + +class OrthoSlicer3D(object): + """Orthogonal-plane slicer. + + OrthoSlicer3d expects 3-dimensional data, and by default it creates a + figure with 3 axes, one for each slice orientation. + + There are two modes, "following on" and "following off". In "following on" + mode, moving the mouse in any one axis will select out the corresponding + slices in the other two. The mode is "following off" when the figure is + first created. Clicking the left mouse button toggles mouse following and + triggers a full redraw (to update the ticks, for example). Scrolling up and + down moves the slice up and down in the current axis. + + Example + ------- + import numpy as np + a = np.sin(np.linspace(0,np.pi,20)) + b = np.sin(np.linspace(0,np.pi*5,20)) + data = np.outer(a,b)[..., np.newaxis]*a + OrthoSlicer3D(data).show() + """ + def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray', + pcnt_range=None): + """ + Parameters + ---------- + data : 3 dimensional ndarray + The data that will be displayed by the slicer + axes : None or length 3 sequence of mpl.Axes, optional + 3 axes instances for the X, Y, and Z slices, or None (default) + aspect_ratio : float or length 3 sequence, optional + stretch factors for X, Y, Z directions + cmap : colormap identifier, optional + String or cmap instance specifying colormap. Will be passed as + ``cmap`` argument to ``plt.imshow``. + pcnt_range : length 2 sequence, optional + Percentile range over which to scale image for display. If None, + scale between image mean and max. If sequence, min and max + percentile over which to scale image. + """ + data_shape = np.array(data.shape[:3]) # allow trailing RGB dimension + aspect_ratio = np.array(aspect_ratio) + if axes is None: # make the axes + # ^ +---------+ ^ +---------+ + # | | | | | | + # | | | | + # z | 2 | z | 3 | + # | | | | + # | | | | | | + # v +---------+ v +---------+ + # <-- x --> <-- y --> + # ^ +---------+ + # | | | + # | | + # y | 1 | + # | | + # | | | + # v +---------+ + # <-- x --> + fig = plt.figure() + x, y, z = data_shape * aspect_ratio + maxw = float(x + y) + maxh = float(y + z) + yh = y / maxh + xw = x / maxw + yw = y / maxw + zh = z / maxh + # z slice (if usual transverse acquisition => axial slice) + ax1 = fig.add_axes((0., 0., xw, yh)) + # y slice (usually coronal) + ax2 = fig.add_axes((0, yh, xw, zh)) + # x slice (usually sagittal) + ax3 = fig.add_axes((xw, yh, yw, zh)) + axes = (ax1, ax2, ax3) + else: + if not np.all(aspect_ratio == 1): + raise ValueError('Aspect ratio must be 1 for external axes') + ax1, ax2, ax3 = axes + + self.data = data + + if pcnt_range is None: + vmin, vmax = data.min(), data.max() + else: + vmin, vmax = np.percentile(data, pcnt_range) + + kw = dict(vmin=vmin, + vmax=vmax, + aspect='auto', + interpolation='nearest', + cmap=cmap, + origin='lower') + # Start midway through each axis + st_x, st_y, st_z = (data_shape - 1) / 2. + n_x, n_y, n_z = data_shape + z_get_slice = lambda i: self.data[:, :, min(i, n_z-1)].T + y_get_slice = lambda i: self.data[:, min(i, n_y-1), :].T + x_get_slice = lambda i: self.data[min(i, n_x-1), :, :].T + im1 = ax1.imshow(z_get_slice(st_z), **kw) + im2 = ax2.imshow(y_get_slice(st_y), **kw) + im3 = ax3.imshow(x_get_slice(st_x), **kw) + im1.get_slice, im2.get_slice, im3.get_slice = ( + z_get_slice, y_get_slice, x_get_slice) + # idx is the current slice number for each panel + im1.idx, im2.idx, im3.idx = st_z, st_y, st_x + # set the maximum dimensions for indexing + im1.size, im2.size, im3.size = n_z, n_y, n_x + # setup pairwise connections between the slice dimensions + im1.imx = im3 # x move in panel 1 (usually axial) + im1.imy = im2 # y move in panel 1 + im2.imx = im3 # x move in panel 2 (usually coronal) + im2.imy = im1 + im3.imx = im2 # x move in panel 3 (usually sagittal) + im3.imy = im1 + + self.follow = False + self.figs = set([ax.figure for ax in axes]) + for fig in self.figs: + fig.canvas.mpl_connect('button_press_event', self.on_click) + fig.canvas.mpl_connect('scroll_event', self.on_scroll) + fig.canvas.mpl_connect('motion_notify_event', self.on_mousemove) + + def show(self): + """ Show the slicer; convenience for ``plt.show()`` + """ + plt.show() + + def _axis_artist(self, event): + """ Return artist if within axes, and is an image, else None + """ + if not getattr(event, 'inaxes'): + return None + artist = event.inaxes.images[0] + return artist if isinstance(artist, mpl_img.AxesImage) else None + + def on_click(self, event): + if event.button == 1: + self.follow = not self.follow + plt.draw() + + def on_scroll(self, event): + assert event.button in ('up', 'down') + im = self._axis_artist(event) + if im is None: + return + im.idx += 1 if event.button == 'up' else -1 + im.idx %= im.size + im.set_data(im.get_slice(im.idx)) + ax = im.axes + ax.draw_artist(im) + ax.figure.canvas.blit(ax.bbox) + + def on_mousemove(self, event): + if not self.follow: + return + im = self._axis_artist(event) + if im is None: + return + ax = im.axes + imx, imy = im.imx, im.imy + x, y = np.round((event.xdata, event.ydata)).astype(int) + imx.set_data(imx.get_slice(x)) + imy.set_data(imy.get_slice(y)) + imx.idx = x + imy.idx = y + for i in imx, imy: + ax = i.axes + ax.draw_artist(i) + ax.figure.canvas.blit(ax.bbox) + + +if __name__ == '__main__': + a = np.sin(np.linspace(0,np.pi,20)) + b = np.sin(np.linspace(0,np.pi*5,20)) + data = np.outer(a,b)[..., np.newaxis]*a + # all slices + OrthoSlicer3D(data).show() + + # broken out into three separate figures + f, ax1 = plt.subplots() + f, ax2 = plt.subplots() + f, ax3 = plt.subplots() + OrthoSlicer3D(data, axes=(ax1, ax2, ax3)).show() From 778466f5f6fc87c75dad8a7928dd5bcad3e4f876 Mon Sep 17 00:00:00 2001 From: Eric89GXL Date: Sun, 19 Oct 2014 11:59:03 -0700 Subject: [PATCH 02/20] ENH: Add crosshairs, modify mode --- .gitignore | 1 + .travis.yml | 1 + nibabel/__init__.py | 1 + nibabel/tests/test_viewers.py | 51 ++++++++++++ nibabel/viewers.py | 144 ++++++++++++++++++++-------------- 5 files changed, 139 insertions(+), 59 deletions(-) create mode 100644 nibabel/tests/test_viewers.py diff --git a/.gitignore b/.gitignore index 99c6f607c6..ca544aea99 100644 --- a/.gitignore +++ b/.gitignore @@ -58,6 +58,7 @@ dist/ .shelf .tox/ .coverage +cover/ # Logs and databases # ###################### diff --git a/.travis.yml b/.travis.yml index 6e4f97bf30..a37d7736d7 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,6 +3,7 @@ # munges each line before executing it to print out the exit status. It's okay # for it to be on multiple physical lines, so long as you remember: - There # can't be any leading "-"s - All newlines will be removed, so use ";"s + language: python env: global: diff --git a/nibabel/__init__.py b/nibabel/__init__.py index a87469c56b..95fdf4af64 100644 --- a/nibabel/__init__.py +++ b/nibabel/__init__.py @@ -64,6 +64,7 @@ from .imageclasses import class_map, ext_map from . import trackvis from . import mriutils +from . import viewers # be friendly on systems with ancient numpy -- no tests, but at least # importable diff --git a/nibabel/tests/test_viewers.py b/nibabel/tests/test_viewers.py new file mode 100644 index 0000000000..ce0713b4dc --- /dev/null +++ b/nibabel/tests/test_viewers.py @@ -0,0 +1,51 @@ +# emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*- +# vi: set ft=python sts=4 ts=4 sw=4 et: +### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## +# +# See COPYING file distributed along with the NiBabel package for the +# copyright and license terms. +# +### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## + +import numpy as np +from collections import namedtuple as nt + +from ..optpkg import optional_package +from ..viewers import OrthoSlicer3D + +from numpy.testing.decorators import skipif + +from nose.tools import assert_raises + +plt, has_mpl = optional_package('matplotlib.pyplot')[:2] +needs_mpl = skipif(not has_mpl, 'These tests need matplotlib') + + +@needs_mpl +def test_viewer(): + # Test viewer + a = np.sin(np.linspace(0, np.pi, 20)) + b = np.sin(np.linspace(0, np.pi*5, 30)) + data = np.outer(a, b)[..., np.newaxis] * a + viewer = OrthoSlicer3D(data) + plt.draw() + + # fake some events + viewer.on_scroll(nt('event', 'button inaxes')('up', None)) # outside axes + viewer.on_scroll(nt('event', 'button inaxes')('up', plt.gca())) # in axes + # tracking on + viewer.on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, + None, 1)) + viewer.on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, + plt.gca(), 1)) + # tracking off + viewer.on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, + None, None)) + viewer.close() + + # other cases + fig, axes = plt.subplots(1, 3) + plt.close(fig) + OrthoSlicer3D(data, pcnt_range=[0.1, 0.9], axes=axes) + assert_raises(ValueError, OrthoSlicer3D, data, aspect_ratio=[1, 2, 3], + axes=axes) diff --git a/nibabel/viewers.py b/nibabel/viewers.py index 7ae4627d9a..bfdea2a489 100644 --- a/nibabel/viewers.py +++ b/nibabel/viewers.py @@ -1,10 +1,12 @@ """ Utilities for viewing images -Includes version of OrthoSlicer3D code by our own Paul Ivanov +Includes version of OrthoSlicer3D code originally written by our own +Paul Ivanov. """ from __future__ import division, print_function import numpy as np +from functools import partial from .optpkg import optional_package @@ -30,26 +32,32 @@ # v +---------+ # <-- x --> + +def _set_viewer_slice(idx, im): + """Helper to set a viewer slice number""" + im.idx = idx + im.set_data(im.get_slice(im.idx)) + for fun in im.cross_setters: + fun([idx] * 2) + + class OrthoSlicer3D(object): """Orthogonal-plane slicer. OrthoSlicer3d expects 3-dimensional data, and by default it creates a figure with 3 axes, one for each slice orientation. - There are two modes, "following on" and "following off". In "following on" - mode, moving the mouse in any one axis will select out the corresponding - slices in the other two. The mode is "following off" when the figure is - first created. Clicking the left mouse button toggles mouse following and - triggers a full redraw (to update the ticks, for example). Scrolling up and + Clicking and dragging the mouse in any one axis will select out the + corresponding slices in the other two. Scrolling up and down moves the slice up and down in the current axis. Example ------- - import numpy as np - a = np.sin(np.linspace(0,np.pi,20)) - b = np.sin(np.linspace(0,np.pi*5,20)) - data = np.outer(a,b)[..., np.newaxis]*a - OrthoSlicer3D(data).show() + >>> import numpy as np + >>> a = np.sin(np.linspace(0,np.pi,20)) + >>> b = np.sin(np.linspace(0,np.pi*5,20)) + >>> data = np.outer(a,b)[..., np.newaxis]*a + >>> OrthoSlicer3D(data).show() """ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray', pcnt_range=None): @@ -70,9 +78,9 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray', scale between image mean and max. If sequence, min and max percentile over which to scale image. """ - data_shape = np.array(data.shape[:3]) # allow trailing RGB dimension - aspect_ratio = np.array(aspect_ratio) - if axes is None: # make the axes + data_shape = np.array(data.shape[:3]) # allow trailing RGB dimension + aspect_ratio = np.array(aspect_ratio, float) + if axes is None: # make the axes # ^ +---------+ ^ +---------+ # | | | | | | # | | | | @@ -122,8 +130,10 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray', interpolation='nearest', cmap=cmap, origin='lower') + # Start midway through each axis st_x, st_y, st_z = (data_shape - 1) / 2. + sts = (st_x, st_y, st_z) n_x, n_y, n_z = data_shape z_get_slice = lambda i: self.data[:, :, min(i, n_z-1)].T y_get_slice = lambda i: self.data[:, min(i, n_y-1), :].T @@ -133,22 +143,51 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray', im3 = ax3.imshow(x_get_slice(st_x), **kw) im1.get_slice, im2.get_slice, im3.get_slice = ( z_get_slice, y_get_slice, x_get_slice) + self._ims = (im1, im2, im3) + # idx is the current slice number for each panel im1.idx, im2.idx, im3.idx = st_z, st_y, st_x + # set the maximum dimensions for indexing im1.size, im2.size, im3.size = n_z, n_y, n_x + + # set up axis crosshairs + colors = ['r', 'g', 'b'] + for ax, im, idx_1, idx_2 in zip(axes, self._ims, [0, 0, 1], [1, 2, 2]): + im.x_line = ax.plot([sts[idx_1]] * 2, + [-0.5, data.shape[idx_2] - 0.5], + color=colors[idx_1], linestyle='-', + alpha=0.25)[0] + im.y_line = ax.plot([-0.5, data.shape[idx_1] - 0.5], + [sts[idx_2]] * 2, + color=colors[idx_2], linestyle='-', + alpha=0.25)[0] + ax.axis('tight') + ax.patch.set_visible(False) + ax.set_frame_on(False) + ax.axes.get_yaxis().set_visible(False) + ax.axes.get_xaxis().set_visible(False) + + # monkey-patch some functions + im1.set_viewer_slice = partial(_set_viewer_slice, im=im1) + im2.set_viewer_slice = partial(_set_viewer_slice, im=im2) + im3.set_viewer_slice = partial(_set_viewer_slice, im=im3) + # setup pairwise connections between the slice dimensions - im1.imx = im3 # x move in panel 1 (usually axial) - im1.imy = im2 # y move in panel 1 - im2.imx = im3 # x move in panel 2 (usually coronal) - im2.imy = im1 - im3.imx = im2 # x move in panel 3 (usually sagittal) - im3.imy = im1 - - self.follow = False + im1.x_im = im3 # x move in panel 1 (usually axial) + im1.y_im = im2 # y move in panel 1 + im2.x_im = im3 # x move in panel 2 (usually coronal) + im2.y_im = im1 # y move in panel 2 + im3.x_im = im2 # x move in panel 3 (usually sagittal) + im3.y_im = im1 # y move in panel 3 + + # when an index changes, which crosshairs need to be updated + im1.cross_setters = [im2.y_line.set_ydata, im3.y_line.set_ydata] + im2.cross_setters = [im1.y_line.set_ydata, im3.x_line.set_xdata] + im3.cross_setters = [im1.x_line.set_xdata, im2.x_line.set_xdata] + self.figs = set([ax.figure for ax in axes]) for fig in self.figs: - fig.canvas.mpl_connect('button_press_event', self.on_click) fig.canvas.mpl_connect('scroll_event', self.on_scroll) fig.canvas.mpl_connect('motion_notify_event', self.on_mousemove) @@ -157,59 +196,46 @@ def show(self): """ plt.show() + def close(self): + """Close the viewer figures + """ + for f in self.figs: + plt.close(f) + def _axis_artist(self, event): - """ Return artist if within axes, and is an image, else None + """Return artist if within axes, and is an image, else None """ if not getattr(event, 'inaxes'): return None artist = event.inaxes.images[0] return artist if isinstance(artist, mpl_img.AxesImage) else None - def on_click(self, event): - if event.button == 1: - self.follow = not self.follow - plt.draw() - def on_scroll(self, event): assert event.button in ('up', 'down') im = self._axis_artist(event) if im is None: return - im.idx += 1 if event.button == 'up' else -1 - im.idx %= im.size - im.set_data(im.get_slice(im.idx)) - ax = im.axes - ax.draw_artist(im) - ax.figure.canvas.blit(ax.bbox) + idx = (im.idx + (1 if event.button == 'up' else -1)) + idx = max(min(idx, im.size - 1), 0) + im.set_viewer_slice(idx) + self._draw_ims() def on_mousemove(self, event): - if not self.follow: + if event.button != 1: # only enabled while dragging return im = self._axis_artist(event) if im is None: return - ax = im.axes - imx, imy = im.imx, im.imy + x_im, y_im = im.x_im, im.y_im x, y = np.round((event.xdata, event.ydata)).astype(int) - imx.set_data(imx.get_slice(x)) - imy.set_data(imy.get_slice(y)) - imx.idx = x - imy.idx = y - for i in imx, imy: - ax = i.axes - ax.draw_artist(i) + for i, idx in zip((x_im, y_im), (x, y)): + i.set_viewer_slice(idx) + self._draw_ims() + + def _draw_ims(self): + for im in self._ims: + ax = im.axes + ax.draw_artist(im) + ax.draw_artist(im.x_line) + ax.draw_artist(im.y_line) ax.figure.canvas.blit(ax.bbox) - - -if __name__ == '__main__': - a = np.sin(np.linspace(0,np.pi,20)) - b = np.sin(np.linspace(0,np.pi*5,20)) - data = np.outer(a,b)[..., np.newaxis]*a - # all slices - OrthoSlicer3D(data).show() - - # broken out into three separate figures - f, ax1 = plt.subplots() - f, ax2 = plt.subplots() - f, ax3 = plt.subplots() - OrthoSlicer3D(data, axes=(ax1, ax2, ax3)).show() From 0d51d848fc88c4ff0910dff8c272f886e85fd684 Mon Sep 17 00:00:00 2001 From: Eric89GXL Date: Sun, 19 Oct 2014 12:14:43 -0700 Subject: [PATCH 03/20] FIX: Minor fixes --- nibabel/viewers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nibabel/viewers.py b/nibabel/viewers.py index bfdea2a489..53ecfd9f79 100644 --- a/nibabel/viewers.py +++ b/nibabel/viewers.py @@ -57,8 +57,9 @@ class OrthoSlicer3D(object): >>> a = np.sin(np.linspace(0,np.pi,20)) >>> b = np.sin(np.linspace(0,np.pi*5,20)) >>> data = np.outer(a,b)[..., np.newaxis]*a - >>> OrthoSlicer3D(data).show() + >>> OrthoSlicer3D(data).show() # doctest: +SKIP """ + # Skip doctest above b/c not all systems have mpl installed def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray', pcnt_range=None): """ From f2b93f692b90bdb480d66bc73ad5a5f947085e28 Mon Sep 17 00:00:00 2001 From: Eric89GXL Date: Sun, 19 Oct 2014 13:31:35 -0700 Subject: [PATCH 04/20] FIX: Xvfb --- .travis.yml | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index a37d7736d7..df2984af2e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -45,7 +45,12 @@ before_install: - if [[ $DEPENDS == *h5py* ]]; then sudo apt-get install libhdf5-serial-dev; fi -# command to install dependencies + # Create a (fake) display on Travis so that mpl tests work + # TODO: Could also use the `agg` backend instead + - if [[ $DEPENDS == "*matplotlib*" ]]; then + export DISPLAY=:99.0; + /sbin/start-stop-daemon --start --quiet --pidfile /tmp/custom_xvfb_99.pid --make-pidfile --background --exec /usr/bin/Xvfb -- :99 -screen 0 1400x900x24 -ac +extension GLX +render -noreset; + fi; install: - python setup.py install # Point to nibabel data directory From 956f3797a31b787c427a189b0e46f3583506de70 Mon Sep 17 00:00:00 2001 From: Eric89GXL Date: Sun, 19 Oct 2014 19:22:17 -0700 Subject: [PATCH 05/20] ENH: Add set_indices method --- nibabel/tests/test_viewers.py | 10 +++--- nibabel/viewers.py | 57 ++++++++++++++++++++++------------- 2 files changed, 42 insertions(+), 25 deletions(-) diff --git a/nibabel/tests/test_viewers.py b/nibabel/tests/test_viewers.py index ce0713b4dc..ec61c8dbca 100644 --- a/nibabel/tests/test_viewers.py +++ b/nibabel/tests/test_viewers.py @@ -33,14 +33,16 @@ def test_viewer(): # fake some events viewer.on_scroll(nt('event', 'button inaxes')('up', None)) # outside axes viewer.on_scroll(nt('event', 'button inaxes')('up', plt.gca())) # in axes - # tracking on + # "click" outside axes, then once in each axis, then move without click viewer.on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, None, 1)) - viewer.on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, - plt.gca(), 1)) - # tracking off + for im in viewer._ims: + viewer.on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, + im.axes, + 1)) viewer.on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, None, None)) + viewer.set_indices(0, 1, 2) viewer.close() # other cases diff --git a/nibabel/viewers.py b/nibabel/viewers.py index 53ecfd9f79..d71a76976e 100644 --- a/nibabel/viewers.py +++ b/nibabel/viewers.py @@ -35,10 +35,10 @@ def _set_viewer_slice(idx, im): """Helper to set a viewer slice number""" - im.idx = idx + im.idx = max(min(int(round(idx)), im.size - 1), 0) im.set_data(im.get_slice(im.idx)) for fun in im.cross_setters: - fun([idx] * 2) + fun([im.idx] * 2) class OrthoSlicer3D(object): @@ -133,24 +133,21 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray', origin='lower') # Start midway through each axis - st_x, st_y, st_z = (data_shape - 1) / 2. - sts = (st_x, st_y, st_z) - n_x, n_y, n_z = data_shape - z_get_slice = lambda i: self.data[:, :, min(i, n_z-1)].T - y_get_slice = lambda i: self.data[:, min(i, n_y-1), :].T - x_get_slice = lambda i: self.data[min(i, n_x-1), :, :].T - im1 = ax1.imshow(z_get_slice(st_z), **kw) - im2 = ax2.imshow(y_get_slice(st_y), **kw) - im3 = ax3.imshow(x_get_slice(st_x), **kw) + z_get_slice = lambda i: self.data[:, :, i].T + y_get_slice = lambda i: self.data[:, i, :].T + x_get_slice = lambda i: self.data[i, :, :].T + sts = (data_shape - 1) // 2 + im1 = ax1.imshow(z_get_slice(sts[2]), **kw) + im2 = ax2.imshow(y_get_slice(sts[1]), **kw) + im3 = ax3.imshow(x_get_slice(sts[0]), **kw) + # idx is the current slice number for each panel + im1.idx, im2.idx, im3.idx = sts + self._ims = (im1, im2, im3) im1.get_slice, im2.get_slice, im3.get_slice = ( z_get_slice, y_get_slice, x_get_slice) - self._ims = (im1, im2, im3) - - # idx is the current slice number for each panel - im1.idx, im2.idx, im3.idx = st_z, st_y, st_x # set the maximum dimensions for indexing - im1.size, im2.size, im3.size = n_z, n_y, n_x + im1.size, im2.size, im3.size = data_shape # set up axis crosshairs colors = ['r', 'g', 'b'] @@ -191,6 +188,7 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray', for fig in self.figs: fig.canvas.mpl_connect('scroll_event', self.on_scroll) fig.canvas.mpl_connect('motion_notify_event', self.on_mousemove) + fig.canvas.mpl_connect('button_press_event', self.on_mousemove) def show(self): """ Show the slicer; convenience for ``plt.show()`` @@ -203,6 +201,26 @@ def close(self): for f in self.figs: plt.close(f) + def set_indices(self, x=None, y=None, z=None): + """Set current displayed slice indices + + Parameters + ---------- + x : int | None + Index to use. If None, do not change. + y : int | None + Index to use. If None, do not change. + z : int | None + Index to use. If None, do not change. + """ + draw = False + for im, val in zip(self._ims, (z, y, x)): + if val is not None: + im.set_viewer_slice(val) + draw = True + if draw: + self._draw_ims() + def _axis_artist(self, event): """Return artist if within axes, and is an image, else None """ @@ -216,8 +234,7 @@ def on_scroll(self, event): im = self._axis_artist(event) if im is None: return - idx = (im.idx + (1 if event.button == 'up' else -1)) - idx = max(min(idx, im.size - 1), 0) + idx = im.idx + (1 if event.button == 'up' else -1) im.set_viewer_slice(idx) self._draw_ims() @@ -227,9 +244,7 @@ def on_mousemove(self, event): im = self._axis_artist(event) if im is None: return - x_im, y_im = im.x_im, im.y_im - x, y = np.round((event.xdata, event.ydata)).astype(int) - for i, idx in zip((x_im, y_im), (x, y)): + for i, idx in zip((im.x_im, im.y_im), (event.xdata, event.ydata)): i.set_viewer_slice(idx) self._draw_ims() From 1f3a890b66665012c9e07c0b836c7b1f5c8d0731 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Wed, 22 Oct 2014 17:08:03 -0700 Subject: [PATCH 06/20] ENH: Allow time dimension --- nibabel/spatialimages.py | 19 +++ nibabel/viewers.py | 358 +++++++++++++++++++++++---------------- 2 files changed, 228 insertions(+), 149 deletions(-) diff --git a/nibabel/spatialimages.py b/nibabel/spatialimages.py index 73104e4ca5..7380e65ee0 100644 --- a/nibabel/spatialimages.py +++ b/nibabel/spatialimages.py @@ -143,6 +143,7 @@ from .filename_parser import types_filenames, TypesFilenamesError from .fileholders import FileHolder +from .viewers import OrthoSlicer3D from .volumeutils import shape_zoom_affine @@ -744,3 +745,21 @@ def __getitem__(self): raise TypeError("Cannot slice image objects; consider slicing image " "array data with `img.dataobj[slice]` or " "`img.get_data()[slice]`") + + def plot(self, show=True): + """Plot the image using OrthoSlicer3D + + Parameters + ---------- + show : bool + If True, the viewer will be shown. + + Returns + ------- + viewer : instance of OrthoSlicer3D + The viewer. + """ + out = OrthoSlicer3D(self.get_data()) + if show: + out.show() + return out diff --git a/nibabel/viewers.py b/nibabel/viewers.py index d71a76976e..00fea018c2 100644 --- a/nibabel/viewers.py +++ b/nibabel/viewers.py @@ -6,40 +6,12 @@ from __future__ import division, print_function import numpy as np -from functools import partial from .optpkg import optional_package plt, _, _ = optional_package('matplotlib.pyplot') mpl_img, _, _ = optional_package('matplotlib.image') -# Assumes the following layout -# -# ^ +---------+ ^ +---------+ -# | | | | | | -# | | | | -# z | 2 | z | 3 | -# | | | | -# | | | | | | -# v +---------+ v +---------+ -# <-- x --> <-- y --> -# ^ +---------+ -# | | | -# | | -# y | 1 | -# | | -# | | | -# v +---------+ -# <-- x --> - - -def _set_viewer_slice(idx, im): - """Helper to set a viewer slice number""" - im.idx = max(min(int(round(idx)), im.size - 1), 0) - im.set_data(im.get_slice(im.idx)) - for fun in im.cross_setters: - fun([im.idx] * 2) - class OrthoSlicer3D(object): """Orthogonal-plane slicer. @@ -61,26 +33,42 @@ class OrthoSlicer3D(object): """ # Skip doctest above b/c not all systems have mpl installed def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray', - pcnt_range=None): + pcnt_range=(1., 99.), figsize=(8, 8)): """ Parameters ---------- - data : 3 dimensional ndarray - The data that will be displayed by the slicer - axes : None or length 3 sequence of mpl.Axes, optional - 3 axes instances for the X, Y, and Z slices, or None (default) - aspect_ratio : float or length 3 sequence, optional - stretch factors for X, Y, Z directions - cmap : colormap identifier, optional + data : ndarray + The data that will be displayed by the slicer. Should have 3+ + dimensions. + axes : tuple of mpl.Axes | None, optional + 3 or 4 axes instances for the X, Y, Z slices plus volumes, + or None (default). + aspect_ratio : array-like, optional + Stretch factors for X, Y, Z directions. + cmap : str | instance of cmap, optional String or cmap instance specifying colormap. Will be passed as ``cmap`` argument to ``plt.imshow``. - pcnt_range : length 2 sequence, optional + pcnt_range : array-like, optional Percentile range over which to scale image for display. If None, scale between image mean and max. If sequence, min and max percentile over which to scale image. + figsize : tuple + Figure size (in inches) to use if axes are None. """ - data_shape = np.array(data.shape[:3]) # allow trailing RGB dimension - aspect_ratio = np.array(aspect_ratio, float) + ar = np.array(aspect_ratio, float) + if ar.shape != (3,) or np.any(ar <= 0): + raise ValueError('aspect ratio must have exactly 3 elements >= 0') + aspect_ratio = dict(x=ar[0], y=ar[1], z=ar[2]) + data = np.asanyarray(data) + if data.ndim < 3: + raise RuntimeError('data must have at least 3 dimensions') + self._volume_dims = data.shape[3:] + self._current_vol_data = data[:, :, :, 0] if data.ndim > 3 else data + self._data = data + pcnt_range = (0, 100) if pcnt_range is None else pcnt_range + vmin, vmax = np.percentile(data, pcnt_range) + del data + if axes is None: # make the axes # ^ +---------+ ^ +---------+ # | | | | | | @@ -90,105 +78,110 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray', # | | | | | | # v +---------+ v +---------+ # <-- x --> <-- y --> - # ^ +---------+ - # | | | - # | | - # y | 1 | - # | | - # | | | - # v +---------+ - # <-- x --> - fig = plt.figure() - x, y, z = data_shape * aspect_ratio - maxw = float(x + y) - maxh = float(y + z) - yh = y / maxh - xw = x / maxw - yw = y / maxw - zh = z / maxh - # z slice (if usual transverse acquisition => axial slice) - ax1 = fig.add_axes((0., 0., xw, yh)) - # y slice (usually coronal) - ax2 = fig.add_axes((0, yh, xw, zh)) - # x slice (usually sagittal) - ax3 = fig.add_axes((xw, yh, yw, zh)) - axes = (ax1, ax2, ax3) - else: - if not np.all(aspect_ratio == 1): - raise ValueError('Aspect ratio must be 1 for external axes') - ax1, ax2, ax3 = axes - - self.data = data + # ^ +---------+ ^ +---------+ + # | | | | | | + # | | | | + # y | 1 | A | 4 | + # | | | | + # | | | | | | + # v +---------+ v +---------+ + # <-- x --> <-- t --> - if pcnt_range is None: - vmin, vmax = data.min(), data.max() + fig, axes = plt.subplots(2, 2) + fig.set_size_inches(figsize, forward=True) + self._axes = dict(x=axes[0, 1], y=axes[0, 0], z=axes[1, 0], + v=axes[1, 1]) + plt.tight_layout(pad=0.1) + if not self.multi_volume: + fig.delaxes(self._axes['v']) + del self._axes['v'] else: - vmin, vmax = np.percentile(data, pcnt_range) + self._axes = dict(z=axes[0], y=axes[1], x=axes[2]) + if len(axes) > 3: + self._axes['v'] = axes[3] - kw = dict(vmin=vmin, - vmax=vmax, - aspect='auto', - interpolation='nearest', - cmap=cmap, - origin='lower') + kw = dict(vmin=vmin, vmax=vmax, aspect=1, interpolation='nearest', + cmap=cmap, origin='lower') - # Start midway through each axis - z_get_slice = lambda i: self.data[:, :, i].T - y_get_slice = lambda i: self.data[:, i, :].T - x_get_slice = lambda i: self.data[i, :, :].T - sts = (data_shape - 1) // 2 - im1 = ax1.imshow(z_get_slice(sts[2]), **kw) - im2 = ax2.imshow(y_get_slice(sts[1]), **kw) - im3 = ax3.imshow(x_get_slice(sts[0]), **kw) - # idx is the current slice number for each panel - im1.idx, im2.idx, im3.idx = sts - self._ims = (im1, im2, im3) - im1.get_slice, im2.get_slice, im3.get_slice = ( - z_get_slice, y_get_slice, x_get_slice) - - # set the maximum dimensions for indexing - im1.size, im2.size, im3.size = data_shape + # Start midway through each axis, idx is current slice number + self._ims, self._sizes, self._idx = dict(), dict(), dict() + colors = dict() + for k, size in zip('xyz', self._data.shape[:3]): + self._idx[k] = size // 2 + self._ims[k] = self._axes[k].imshow(self._get_slice(k), **kw) + self._sizes[k] = size + colors[k] = (0, 1, 0) + self._idx['v'] = 0 + labels = dict(z='ILSR', y='ALPR', x='AIPS') # set up axis crosshairs - colors = ['r', 'g', 'b'] - for ax, im, idx_1, idx_2 in zip(axes, self._ims, [0, 0, 1], [1, 2, 2]): - im.x_line = ax.plot([sts[idx_1]] * 2, - [-0.5, data.shape[idx_2] - 0.5], - color=colors[idx_1], linestyle='-', - alpha=0.25)[0] - im.y_line = ax.plot([-0.5, data.shape[idx_1] - 0.5], - [sts[idx_2]] * 2, - color=colors[idx_2], linestyle='-', - alpha=0.25)[0] - ax.axis('tight') + for type_, i_1, i_2 in zip('zyx', 'xxy', 'yzz'): + ax = self._axes[type_] + im = self._ims[type_] + label = labels[type_] + # add slice lines + im.vert_line = ax.plot([self._idx[i_1]] * 2, + [-0.5, self._sizes[i_2] - 0.5], + color=colors[i_1], linestyle='-')[0] + im.horiz_line = ax.plot([-0.5, self._sizes[i_1] - 0.5], + [self._idx[i_2]] * 2, + color=colors[i_2], linestyle='-')[0] + # add text labels (top, right, bottom, left) + lims = [0, self._sizes[i_1], 0, self._sizes[i_2]] + bump = 0.01 + poss = [[lims[1] / 2., lims[3]], + [(1 + bump) * lims[1], lims[3] / 2.], + [lims[1] / 2., 0], + [lims[0] - bump * lims[1], lims[3] / 2.]] + anchors = [['center', 'bottom'], ['left', 'center'], + ['center', 'top'], ['right', 'center']] + im.texts = [ax.text(pos[0], pos[1], lab, + horizontalalignment=anchor[0], + verticalalignment=anchor[1]) + for pos, anchor, lab in zip(poss, anchors, label)] + ax.axis(lims) + ax.set_aspect(aspect_ratio[type_]) ax.patch.set_visible(False) ax.set_frame_on(False) ax.axes.get_yaxis().set_visible(False) ax.axes.get_xaxis().set_visible(False) - # monkey-patch some functions - im1.set_viewer_slice = partial(_set_viewer_slice, im=im1) - im2.set_viewer_slice = partial(_set_viewer_slice, im=im2) - im3.set_viewer_slice = partial(_set_viewer_slice, im=im3) + # Set up volumes axis + if self.multi_volume: + ax = self._axes['v'] + ax.set_axis_bgcolor('k') + ax.set_title('Volumes') + n_vols = np.prod(self._volume_dims) + print(n_vols) + y = np.mean(np.mean(np.mean(self._data, 0), 0), 0).ravel() + y = np.concatenate((y, [y[-1]])) + x = np.arange(n_vols + 1) - 0.5 + step = ax.step(x, y, where='post', color='y')[0] + ax.set_xticks(np.unique(np.linspace(0, n_vols - 1, 5).astype(int))) + ax.set_xlim(x[0], x[-1]) + line = ax.plot([0, 0], ax.get_ylim(), color=(0, 1, 0))[0] + self._time_lines = [line, step] # setup pairwise connections between the slice dimensions - im1.x_im = im3 # x move in panel 1 (usually axial) - im1.y_im = im2 # y move in panel 1 - im2.x_im = im3 # x move in panel 2 (usually coronal) - im2.y_im = im1 # y move in panel 2 - im3.x_im = im2 # x move in panel 3 (usually sagittal) - im3.y_im = im1 # y move in panel 3 + self._click_update_keys = dict(x='yz', y='xz', z='xy') # when an index changes, which crosshairs need to be updated - im1.cross_setters = [im2.y_line.set_ydata, im3.y_line.set_ydata] - im2.cross_setters = [im1.y_line.set_ydata, im3.x_line.set_xdata] - im3.cross_setters = [im1.x_line.set_xdata, im2.x_line.set_xdata] + self._cross_setters = dict( + x=[self._ims['z'].vert_line.set_xdata, + self._ims['y'].vert_line.set_xdata], + y=[self._ims['z'].horiz_line.set_ydata, + self._ims['x'].vert_line.set_xdata], + z=[self._ims['y'].horiz_line.set_ydata, + self._ims['x'].horiz_line.set_ydata]) - self.figs = set([ax.figure for ax in axes]) - for fig in self.figs: - fig.canvas.mpl_connect('scroll_event', self.on_scroll) - fig.canvas.mpl_connect('motion_notify_event', self.on_mousemove) - fig.canvas.mpl_connect('button_press_event', self.on_mousemove) + self._figs = set([a.figure for a in self._axes.values()]) + for fig in self._figs: + fig.canvas.mpl_connect('scroll_event', self._on_scroll) + fig.canvas.mpl_connect('motion_notify_event', self._on_mousemove) + fig.canvas.mpl_connect('button_press_event', self._on_mousemove) + fig.canvas.mpl_connect('key_press_event', self._on_keypress) + plt.draw() + self._draw() def show(self): """ Show the slicer; convenience for ``plt.show()`` @@ -198,10 +191,15 @@ def show(self): def close(self): """Close the viewer figures """ - for f in self.figs: + for f in self._figs: plt.close(f) - def set_indices(self, x=None, y=None, z=None): + @property + def multi_volume(self): + """Whether or not the displayed data is multi-volume""" + return len(self._volume_dims) > 0 + + def set_indices(self, x=None, y=None, z=None, v=None): """Set current displayed slice indices Parameters @@ -212,46 +210,108 @@ def set_indices(self, x=None, y=None, z=None): Index to use. If None, do not change. z : int | None Index to use. If None, do not change. + v : int | None + Volume index to use. If None, do not change. """ + x = int(x) if x is not None else None + y = int(y) if y is not None else None + z = int(z) if z is not None else None + v = int(v) if v is not None else None draw = False - for im, val in zip(self._ims, (z, y, x)): + if v is not None: + if not self.multi_volume: + raise RuntimeError('cannot change volume index of ' + 'single-volume image') + self._set_vol_idx(v, draw=False) # delay draw + draw = True + for key, val in zip('zyx', (z, y, x)): if val is not None: - im.set_viewer_slice(val) + self._set_viewer_slice(key, val) draw = True if draw: - self._draw_ims() + self._draw() - def _axis_artist(self, event): - """Return artist if within axes, and is an image, else None - """ - if not getattr(event, 'inaxes'): + def _set_vol_idx(self, idx, draw=True): + """Helper to change which volume is shown""" + max_ = np.prod(self._volume_dims) + self._idx['v'] = max(min(int(round(idx)), max_ - 1), 0) + # Must reset what is shown + self._current_vol_data = self._data[:, :, :, self._idx['v']] + for key in 'xyz': + self._ims[key].set_data(self._get_slice(key)) + self._time_lines[0].set_xdata([self._idx['v']] * 2) + if draw: + self._draw() + + def _get_slice(self, key): + """Helper to get the current slice image""" + ii = dict(x=0, y=1, z=2)[key] + return np.take(self._current_vol_data, self._idx[key], axis=ii).T + + def _set_viewer_slice(self, key, idx): + """Helper to set a viewer slice number""" + self._idx[key] = max(min(int(round(idx)), self._sizes[key] - 1), 0) + self._ims[key].set_data(self._get_slice(key)) + for fun in self._cross_setters[key]: + fun([self._idx[key]] * 2) + + def _in_axis(self, event): + """Return axis key if within one of our axes, else None""" + if getattr(event, 'inaxes') is None: return None - artist = event.inaxes.images[0] - return artist if isinstance(artist, mpl_img.AxesImage) else None + for key, ax in self._axes.items(): + if event.inaxes is ax: + return key + return None - def on_scroll(self, event): + def _on_scroll(self, event): assert event.button in ('up', 'down') - im = self._axis_artist(event) - if im is None: + key = self._in_axis(event) + if key is None: return - idx = im.idx + (1 if event.button == 'up' else -1) - im.set_viewer_slice(idx) - self._draw_ims() + delta = 10 if event.key is not None and 'control' in event.key else 1 + if event.key is not None and 'shift' in event.key: + if not self.multi_volume: + return + key = 'v' # shift: change volume in any axis + idx = self._idx[key] + (delta if event.button == 'up' else -delta) + if key == 'v': + self._set_vol_idx(idx) + else: + self._set_viewer_slice(key, idx) + self._draw() - def on_mousemove(self, event): + def _on_mousemove(self, event): if event.button != 1: # only enabled while dragging return - im = self._axis_artist(event) - if im is None: + key = self._in_axis(event) + if key is None: return - for i, idx in zip((im.x_im, im.y_im), (event.xdata, event.ydata)): - i.set_viewer_slice(idx) - self._draw_ims() + if key == 'v': + self._set_vol_idx(event.xdata) + else: + for sub_key, idx in zip(self._click_update_keys[key], + (event.xdata, event.ydata)): + self._set_viewer_slice(sub_key, idx) + self._draw() - def _draw_ims(self): - for im in self._ims: + def _on_keypress(self, event): + if event.key is not None and 'escape' in event.key: + self.close() + + def _draw(self): + for im in self._ims.values(): ax = im.axes + ax.draw_artist(ax.patch) ax.draw_artist(im) - ax.draw_artist(im.x_line) - ax.draw_artist(im.y_line) + ax.draw_artist(im.vert_line) + ax.draw_artist(im.horiz_line) + ax.figure.canvas.blit(ax.bbox) + for t in im.texts: + ax.draw_artist(t) + if self.multi_volume: + ax = self._axes['v'] + ax.draw_artist(ax.patch) + for artist in self._time_lines: + ax.draw_artist(artist) ax.figure.canvas.blit(ax.bbox) From 40d2b3e9ed6061afdecca9c8ddd209cea5a36907 Mon Sep 17 00:00:00 2001 From: Eric89GXL Date: Thu, 23 Oct 2014 00:30:01 -0700 Subject: [PATCH 07/20] ENH: Better tests --- nibabel/spatialimages.py | 18 +++++++-------- nibabel/tests/test_viewers.py | 43 ++++++++++++++++++++++------------- nibabel/viewers.py | 27 +++++++++++----------- 3 files changed, 49 insertions(+), 39 deletions(-) diff --git a/nibabel/spatialimages.py b/nibabel/spatialimages.py index 7380e65ee0..9bd1cfde3e 100644 --- a/nibabel/spatialimages.py +++ b/nibabel/spatialimages.py @@ -746,20 +746,18 @@ def __getitem__(self): "array data with `img.dataobj[slice]` or " "`img.get_data()[slice]`") - def plot(self, show=True): + def plot(self): """Plot the image using OrthoSlicer3D - Parameters - ---------- - show : bool - If True, the viewer will be shown. - Returns ------- viewer : instance of OrthoSlicer3D The viewer. + + Notes + ----- + This requires matplotlib. If a non-interactive backend is used, + consider using viewer.show() (equivalently plt.show()) to show + the figure. """ - out = OrthoSlicer3D(self.get_data()) - if show: - out.show() - return out + return OrthoSlicer3D(self.get_data()) diff --git a/nibabel/tests/test_viewers.py b/nibabel/tests/test_viewers.py index ec61c8dbca..4cae4211d2 100644 --- a/nibabel/tests/test_viewers.py +++ b/nibabel/tests/test_viewers.py @@ -26,28 +26,39 @@ def test_viewer(): # Test viewer a = np.sin(np.linspace(0, np.pi, 20)) b = np.sin(np.linspace(0, np.pi*5, 30)) - data = np.outer(a, b)[..., np.newaxis] * a + data = (np.outer(a, b)[..., np.newaxis] * a)[:, :, :, np.newaxis] viewer = OrthoSlicer3D(data) plt.draw() - # fake some events - viewer.on_scroll(nt('event', 'button inaxes')('up', None)) # outside axes - viewer.on_scroll(nt('event', 'button inaxes')('up', plt.gca())) # in axes + # fake some events, inside and outside axes + viewer._on_scroll(nt('event', 'button inaxes key')('up', None, None)) + for ax in (viewer._axes['x'], viewer._axes['v']): + viewer._on_scroll(nt('event', 'button inaxes key')('up', ax, None)) + viewer._on_scroll(nt('event', 'button inaxes key')('up', ax, 'shift')) # "click" outside axes, then once in each axis, then move without click - viewer.on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, - None, 1)) - for im in viewer._ims: - viewer.on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, - im.axes, - 1)) - viewer.on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, - None, None)) + viewer._on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, + None, 1)) + for ax in viewer._axes.values(): + viewer._on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, + ax, 1)) + viewer._on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, + None, None)) viewer.set_indices(0, 1, 2) + viewer.set_indices(v=10) viewer.close() + # non-multi-volume + viewer = OrthoSlicer3D(data[:, :, :, 0]) + assert_raises(ValueError, viewer.set_indices, v=10) # not multi-volume + viewer._on_scroll(nt('event', 'button inaxes key')('up', viewer._axes['x'], + 'shift')) + viewer._on_keypress(nt('event', 'key')('escape')) + # other cases - fig, axes = plt.subplots(1, 3) + fig, axes = plt.subplots(1, 4) plt.close(fig) - OrthoSlicer3D(data, pcnt_range=[0.1, 0.9], axes=axes) - assert_raises(ValueError, OrthoSlicer3D, data, aspect_ratio=[1, 2, 3], - axes=axes) + OrthoSlicer3D(data, pcnt_range=[0.1, 0.9], axes=axes, + aspect_ratio=[1, 2, 3]) + OrthoSlicer3D(data, axes=axes[:3]) + assert_raises(ValueError, OrthoSlicer3D, data, aspect_ratio=[1, 2]) + assert_raises(ValueError, OrthoSlicer3D, data[:, :, 0, 0]) diff --git a/nibabel/viewers.py b/nibabel/viewers.py index 00fea018c2..834b23f4e9 100644 --- a/nibabel/viewers.py +++ b/nibabel/viewers.py @@ -11,6 +11,7 @@ plt, _, _ = optional_package('matplotlib.pyplot') mpl_img, _, _ = optional_package('matplotlib.image') +mpl_patch, _, _ = optional_package('matplotlib.patches') class OrthoSlicer3D(object): @@ -61,7 +62,7 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray', aspect_ratio = dict(x=ar[0], y=ar[1], z=ar[2]) data = np.asanyarray(data) if data.ndim < 3: - raise RuntimeError('data must have at least 3 dimensions') + raise ValueError('data must have at least 3 dimensions') self._volume_dims = data.shape[3:] self._current_vol_data = data[:, :, :, 0] if data.ndim > 3 else data self._data = data @@ -147,20 +148,23 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray', ax.axes.get_xaxis().set_visible(False) # Set up volumes axis - if self.multi_volume: + if self.multi_volume and 'v' in self._axes: ax = self._axes['v'] ax.set_axis_bgcolor('k') ax.set_title('Volumes') n_vols = np.prod(self._volume_dims) - print(n_vols) y = np.mean(np.mean(np.mean(self._data, 0), 0), 0).ravel() y = np.concatenate((y, [y[-1]])) x = np.arange(n_vols + 1) - 0.5 step = ax.step(x, y, where='post', color='y')[0] ax.set_xticks(np.unique(np.linspace(0, n_vols - 1, 5).astype(int))) ax.set_xlim(x[0], x[-1]) - line = ax.plot([0, 0], ax.get_ylim(), color=(0, 1, 0))[0] - self._time_lines = [line, step] + lims = ax.get_ylim() + patch = mpl_patch.Rectangle([-0.5, lims[0]], 1., np.diff(lims)[0], + fill=True, facecolor=(0, 1, 0), + edgecolor=(0, 1, 0), alpha=0.25) + ax.add_patch(patch) + self._time_lines = [patch, step] # setup pairwise connections between the slice dimensions self._click_update_keys = dict(x='yz', y='xz', z='xy') @@ -180,11 +184,9 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray', fig.canvas.mpl_connect('motion_notify_event', self._on_mousemove) fig.canvas.mpl_connect('button_press_event', self._on_mousemove) fig.canvas.mpl_connect('key_press_event', self._on_keypress) - plt.draw() - self._draw() def show(self): - """ Show the slicer; convenience for ``plt.show()`` + """ Show the slicer in blocking mode; convenience for ``plt.show()`` """ plt.show() @@ -220,8 +222,8 @@ def set_indices(self, x=None, y=None, z=None, v=None): draw = False if v is not None: if not self.multi_volume: - raise RuntimeError('cannot change volume index of ' - 'single-volume image') + raise ValueError('cannot change volume index of single-volume ' + 'image') self._set_vol_idx(v, draw=False) # delay draw draw = True for key, val in zip('zyx', (z, y, x)): @@ -239,7 +241,7 @@ def _set_vol_idx(self, idx, draw=True): self._current_vol_data = self._data[:, :, :, self._idx['v']] for key in 'xyz': self._ims[key].set_data(self._get_slice(key)) - self._time_lines[0].set_xdata([self._idx['v']] * 2) + self._time_lines[0].set_x(self._idx['v'] - 0.5) if draw: self._draw() @@ -262,7 +264,6 @@ def _in_axis(self, event): for key, ax in self._axes.items(): if event.inaxes is ax: return key - return None def _on_scroll(self, event): assert event.button in ('up', 'down') @@ -309,7 +310,7 @@ def _draw(self): ax.figure.canvas.blit(ax.bbox) for t in im.texts: ax.draw_artist(t) - if self.multi_volume: + if self.multi_volume and 'v' in self._axes: # user might only pass 3 ax = self._axes['v'] ax.draw_artist(ax.patch) for artist in self._time_lines: From f371fb37a04a1abaccbc177064271bb2850c0ecf Mon Sep 17 00:00:00 2001 From: Eric89GXL Date: Sat, 25 Oct 2014 22:46:52 -0700 Subject: [PATCH 08/20] FIX: Update volume plot --- nibabel/tests/test_viewers.py | 1 + nibabel/viewers.py | 78 +++++++++++++++++++++-------------- 2 files changed, 47 insertions(+), 32 deletions(-) diff --git a/nibabel/tests/test_viewers.py b/nibabel/tests/test_viewers.py index 4cae4211d2..64ea1df514 100644 --- a/nibabel/tests/test_viewers.py +++ b/nibabel/tests/test_viewers.py @@ -27,6 +27,7 @@ def test_viewer(): a = np.sin(np.linspace(0, np.pi, 20)) b = np.sin(np.linspace(0, np.pi*5, 30)) data = (np.outer(a, b)[..., np.newaxis] * a)[:, :, :, np.newaxis] + data = data * np.array([1., 2.]) # give it a # of volumes > 1 viewer = OrthoSlicer3D(data) plt.draw() diff --git a/nibabel/viewers.py b/nibabel/viewers.py index 834b23f4e9..540a9b6e67 100644 --- a/nibabel/viewers.py +++ b/nibabel/viewers.py @@ -93,7 +93,7 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray', self._axes = dict(x=axes[0, 1], y=axes[0, 0], z=axes[1, 0], v=axes[1, 1]) plt.tight_layout(pad=0.1) - if not self.multi_volume: + if self.n_volumes <= 1: fig.delaxes(self._axes['v']) del self._axes['v'] else: @@ -109,7 +109,7 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray', colors = dict() for k, size in zip('xyz', self._data.shape[:3]): self._idx[k] = size // 2 - self._ims[k] = self._axes[k].imshow(self._get_slice(k), **kw) + self._ims[k] = self._axes[k].imshow(self._get_slice_data(k), **kw) self._sizes[k] = size colors[k] = (0, 1, 0) self._idx['v'] = 0 @@ -148,23 +148,24 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray', ax.axes.get_xaxis().set_visible(False) # Set up volumes axis - if self.multi_volume and 'v' in self._axes: + if self.n_volumes > 1 and 'v' in self._axes: ax = self._axes['v'] ax.set_axis_bgcolor('k') ax.set_title('Volumes') - n_vols = np.prod(self._volume_dims) - y = np.mean(np.mean(np.mean(self._data, 0), 0), 0).ravel() - y = np.concatenate((y, [y[-1]])) - x = np.arange(n_vols + 1) - 0.5 + y = self._get_voxel_levels() + x = np.arange(self.n_volumes + 1) - 0.5 step = ax.step(x, y, where='post', color='y')[0] - ax.set_xticks(np.unique(np.linspace(0, n_vols - 1, 5).astype(int))) + ax.set_xticks(np.unique(np.linspace(0, self.n_volumes - 1, + 5).astype(int))) ax.set_xlim(x[0], x[-1]) - lims = ax.get_ylim() - patch = mpl_patch.Rectangle([-0.5, lims[0]], 1., np.diff(lims)[0], - fill=True, facecolor=(0, 1, 0), - edgecolor=(0, 1, 0), alpha=0.25) + yl = [self._data.min(), self._data.max()] + yl = [l + s * np.diff(lims)[0] for l, s in zip(yl, [-1.01, 1.01])] + patch = mpl_patch.Rectangle([-0.5, yl[0]], 1., np.diff(yl)[0], + fill=True, facecolor=(0, 1, 0), + edgecolor=(0, 1, 0), alpha=0.25) ax.add_patch(patch) - self._time_lines = [patch, step] + ax.set_ylim(yl) + self._volume_ax_objs = dict(step=step, patch=patch) # setup pairwise connections between the slice dimensions self._click_update_keys = dict(x='yz', y='xz', z='xy') @@ -197,9 +198,9 @@ def close(self): plt.close(f) @property - def multi_volume(self): - """Whether or not the displayed data is multi-volume""" - return len(self._volume_dims) > 0 + def n_volumes(self): + """Number of volumes in the data""" + return int(np.prod(self._volume_dims)) def set_indices(self, x=None, y=None, z=None, v=None): """Set current displayed slice indices @@ -221,31 +222,43 @@ def set_indices(self, x=None, y=None, z=None, v=None): v = int(v) if v is not None else None draw = False if v is not None: - if not self.multi_volume: + if self.n_volumes <= 1: raise ValueError('cannot change volume index of single-volume ' 'image') - self._set_vol_idx(v, draw=False) # delay draw + self._set_vol_idx(v) draw = True for key, val in zip('zyx', (z, y, x)): if val is not None: self._set_viewer_slice(key, val) draw = True if draw: + self._update_voxel_levels() self._draw() - def _set_vol_idx(self, idx, draw=True): - """Helper to change which volume is shown""" + def _get_voxel_levels(self): + """Get levels of the current voxel as a function of volume""" + y = self._data[self._idx['x'], + self._idx['y'], + self._idx['z'], :].ravel() + y = np.concatenate((y, [y[-1]])) + return y + + def _update_voxel_levels(self): + """Update voxel levels in time plot""" + if self.n_volumes > 1: + self._volume_ax_objs['step'].set_ydata(self._get_voxel_levels()) + + def _set_vol_idx(self, idx): + """Change which volume is shown""" max_ = np.prod(self._volume_dims) self._idx['v'] = max(min(int(round(idx)), max_ - 1), 0) # Must reset what is shown self._current_vol_data = self._data[:, :, :, self._idx['v']] for key in 'xyz': - self._ims[key].set_data(self._get_slice(key)) - self._time_lines[0].set_x(self._idx['v'] - 0.5) - if draw: - self._draw() + self._ims[key].set_data(self._get_slice_data(key)) + self._volume_ax_objs['patch'].set_x(self._idx['v'] - 0.5) - def _get_slice(self, key): + def _get_slice_data(self, key): """Helper to get the current slice image""" ii = dict(x=0, y=1, z=2)[key] return np.take(self._current_vol_data, self._idx[key], axis=ii).T @@ -253,7 +266,7 @@ def _get_slice(self, key): def _set_viewer_slice(self, key, idx): """Helper to set a viewer slice number""" self._idx[key] = max(min(int(round(idx)), self._sizes[key] - 1), 0) - self._ims[key].set_data(self._get_slice(key)) + self._ims[key].set_data(self._get_slice_data(key)) for fun in self._cross_setters[key]: fun([self._idx[key]] * 2) @@ -272,7 +285,7 @@ def _on_scroll(self, event): return delta = 10 if event.key is not None and 'control' in event.key else 1 if event.key is not None and 'shift' in event.key: - if not self.multi_volume: + if self.n_volumes <= 1: return key = 'v' # shift: change volume in any axis idx = self._idx[key] + (delta if event.button == 'up' else -delta) @@ -280,6 +293,7 @@ def _on_scroll(self, event): self._set_vol_idx(idx) else: self._set_viewer_slice(key, idx) + self._update_voxel_levels() self._draw() def _on_mousemove(self, event): @@ -294,6 +308,7 @@ def _on_mousemove(self, event): for sub_key, idx in zip(self._click_update_keys[key], (event.xdata, event.ydata)): self._set_viewer_slice(sub_key, idx) + self._update_voxel_levels() self._draw() def _on_keypress(self, event): @@ -303,16 +318,15 @@ def _on_keypress(self, event): def _draw(self): for im in self._ims.values(): ax = im.axes - ax.draw_artist(ax.patch) ax.draw_artist(im) ax.draw_artist(im.vert_line) ax.draw_artist(im.horiz_line) ax.figure.canvas.blit(ax.bbox) for t in im.texts: ax.draw_artist(t) - if self.multi_volume and 'v' in self._axes: # user might only pass 3 + if self.n_volumes > 1 and 'v' in self._axes: # user might only pass 3 ax = self._axes['v'] - ax.draw_artist(ax.patch) - for artist in self._time_lines: - ax.draw_artist(artist) + ax.draw_artist(ax.patch) # axis bgcolor to erase old lines + for key in ('step', 'patch'): + ax.draw_artist(self._volume_ax_objs[key]) ax.figure.canvas.blit(ax.bbox) From f9958c50cfba18f95baf7dea6a7774a3ee01f3f1 Mon Sep 17 00:00:00 2001 From: Eric89GXL Date: Sat, 25 Oct 2014 23:13:25 -0700 Subject: [PATCH 09/20] FIX: Remove monkey patching --- nibabel/tests/test_viewers.py | 37 ++++++++--------- nibabel/viewers.py | 78 +++++++++++++++++++---------------- 2 files changed, 59 insertions(+), 56 deletions(-) diff --git a/nibabel/tests/test_viewers.py b/nibabel/tests/test_viewers.py index 64ea1df514..8be12f104d 100644 --- a/nibabel/tests/test_viewers.py +++ b/nibabel/tests/test_viewers.py @@ -28,32 +28,28 @@ def test_viewer(): b = np.sin(np.linspace(0, np.pi*5, 30)) data = (np.outer(a, b)[..., np.newaxis] * a)[:, :, :, np.newaxis] data = data * np.array([1., 2.]) # give it a # of volumes > 1 - viewer = OrthoSlicer3D(data) + v = OrthoSlicer3D(data) plt.draw() # fake some events, inside and outside axes - viewer._on_scroll(nt('event', 'button inaxes key')('up', None, None)) - for ax in (viewer._axes['x'], viewer._axes['v']): - viewer._on_scroll(nt('event', 'button inaxes key')('up', ax, None)) - viewer._on_scroll(nt('event', 'button inaxes key')('up', ax, 'shift')) + v._on_scroll(nt('event', 'button inaxes key')('up', None, None)) + for ax in (v._axes['x'], v._axes['v']): + v._on_scroll(nt('event', 'button inaxes key')('up', ax, None)) + v._on_scroll(nt('event', 'button inaxes key')('up', ax, 'shift')) # "click" outside axes, then once in each axis, then move without click - viewer._on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, - None, 1)) - for ax in viewer._axes.values(): - viewer._on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, - ax, 1)) - viewer._on_mousemove(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, - None, None)) - viewer.set_indices(0, 1, 2) - viewer.set_indices(v=10) - viewer.close() + v._on_mouse(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, None, 1)) + for ax in v._axes.values(): + v._on_mouse(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, ax, 1)) + v._on_mouse(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, None, None)) + v.set_indices(0, 1, 2) + v.set_indices(v=10) + v.close() # non-multi-volume - viewer = OrthoSlicer3D(data[:, :, :, 0]) - assert_raises(ValueError, viewer.set_indices, v=10) # not multi-volume - viewer._on_scroll(nt('event', 'button inaxes key')('up', viewer._axes['x'], - 'shift')) - viewer._on_keypress(nt('event', 'key')('escape')) + v = OrthoSlicer3D(data[:, :, :, 0]) + assert_raises(ValueError, v.set_indices, v=10) # not multi-volume + v._on_scroll(nt('event', 'button inaxes key')('up', v._axes['x'], 'shift')) + v._on_keypress(nt('event', 'key')('escape')) # other cases fig, axes = plt.subplots(1, 4) @@ -63,3 +59,4 @@ def test_viewer(): OrthoSlicer3D(data, axes=axes[:3]) assert_raises(ValueError, OrthoSlicer3D, data, aspect_ratio=[1, 2]) assert_raises(ValueError, OrthoSlicer3D, data[:, :, 0, 0]) + assert_raises(ValueError, OrthoSlicer3D, data, affine=np.eye(3)) diff --git a/nibabel/viewers.py b/nibabel/viewers.py index 540a9b6e67..d73a0ec547 100644 --- a/nibabel/viewers.py +++ b/nibabel/viewers.py @@ -33,8 +33,8 @@ class OrthoSlicer3D(object): >>> OrthoSlicer3D(data).show() # doctest: +SKIP """ # Skip doctest above b/c not all systems have mpl installed - def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray', - pcnt_range=(1., 99.), figsize=(8, 8)): + def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), affine=None, + cmap='gray', pcnt_range=(1., 99.), figsize=(8, 8)): """ Parameters ---------- @@ -46,13 +46,14 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray', or None (default). aspect_ratio : array-like, optional Stretch factors for X, Y, Z directions. + affine : array-like | None + Affine transform for the data. This is used to determine + how the data should be sliced for plotting into the X, Y, + and Z view axes. If None, identity is assumed. cmap : str | instance of cmap, optional - String or cmap instance specifying colormap. Will be passed as - ``cmap`` argument to ``plt.imshow``. + String or cmap instance specifying colormap. pcnt_range : array-like, optional - Percentile range over which to scale image for display. If None, - scale between image mean and max. If sequence, min and max - percentile over which to scale image. + Percentile range over which to scale image for display. figsize : tuple Figure size (in inches) to use if axes are None. """ @@ -63,6 +64,10 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray', data = np.asanyarray(data) if data.ndim < 3: raise ValueError('data must have at least 3 dimensions') + affine = np.array(affine, float) if affine is not None else np.eye(4) + if affine.ndim != 2 or affine.shape != (4, 4): + raise ValueError('affine must be a 4x4 matrix') + self._affine = affine self._volume_dims = data.shape[3:] self._current_vol_data = data[:, :, :, 0] if data.ndim > 3 else data self._data = data @@ -116,17 +121,16 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray', labels = dict(z='ILSR', y='ALPR', x='AIPS') # set up axis crosshairs + self._crosshairs = dict() for type_, i_1, i_2 in zip('zyx', 'xxy', 'yzz'): - ax = self._axes[type_] - im = self._ims[type_] - label = labels[type_] - # add slice lines - im.vert_line = ax.plot([self._idx[i_1]] * 2, - [-0.5, self._sizes[i_2] - 0.5], - color=colors[i_1], linestyle='-')[0] - im.horiz_line = ax.plot([-0.5, self._sizes[i_1] - 0.5], - [self._idx[i_2]] * 2, - color=colors[i_2], linestyle='-')[0] + ax, label = self._axes[type_], labels[type_] + vert = ax.plot([self._idx[i_1]] * 2, + [-0.5, self._sizes[i_2] - 0.5], + color=colors[i_1], linestyle='-')[0] + horiz = ax.plot([-0.5, self._sizes[i_1] - 0.5], + [self._idx[i_2]] * 2, + color=colors[i_2], linestyle='-')[0] + self._crosshairs[type_] = dict(vert=vert, horiz=horiz) # add text labels (top, right, bottom, left) lims = [0, self._sizes[i_1], 0, self._sizes[i_2]] bump = 0.01 @@ -136,10 +140,10 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray', [lims[0] - bump * lims[1], lims[3] / 2.]] anchors = [['center', 'bottom'], ['left', 'center'], ['center', 'top'], ['right', 'center']] - im.texts = [ax.text(pos[0], pos[1], lab, - horizontalalignment=anchor[0], - verticalalignment=anchor[1]) - for pos, anchor, lab in zip(poss, anchors, label)] + for pos, anchor, lab in zip(poss, anchors, label): + ax.text(pos[0], pos[1], lab, + horizontalalignment=anchor[0], + verticalalignment=anchor[1]) ax.axis(lims) ax.set_aspect(aspect_ratio[type_]) ax.patch.set_visible(False) @@ -172,18 +176,18 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), cmap='gray', # when an index changes, which crosshairs need to be updated self._cross_setters = dict( - x=[self._ims['z'].vert_line.set_xdata, - self._ims['y'].vert_line.set_xdata], - y=[self._ims['z'].horiz_line.set_ydata, - self._ims['x'].vert_line.set_xdata], - z=[self._ims['y'].horiz_line.set_ydata, - self._ims['x'].horiz_line.set_ydata]) + x=[self._crosshairs['z']['vert'].set_xdata, + self._crosshairs['y']['vert'].set_xdata], + y=[self._crosshairs['z']['horiz'].set_ydata, + self._crosshairs['x']['vert'].set_xdata], + z=[self._crosshairs['y']['horiz'].set_ydata, + self._crosshairs['x']['horiz'].set_ydata]) self._figs = set([a.figure for a in self._axes.values()]) for fig in self._figs: fig.canvas.mpl_connect('scroll_event', self._on_scroll) - fig.canvas.mpl_connect('motion_notify_event', self._on_mousemove) - fig.canvas.mpl_connect('button_press_event', self._on_mousemove) + fig.canvas.mpl_connect('motion_notify_event', self._on_mouse) + fig.canvas.mpl_connect('button_press_event', self._on_mouse) fig.canvas.mpl_connect('key_press_event', self._on_keypress) def show(self): @@ -279,6 +283,7 @@ def _in_axis(self, event): return key def _on_scroll(self, event): + """Handle mpl scroll wheel event""" assert event.button in ('up', 'down') key = self._in_axis(event) if key is None: @@ -296,7 +301,8 @@ def _on_scroll(self, event): self._update_voxel_levels() self._draw() - def _on_mousemove(self, event): + def _on_mouse(self, event): + """Handle mpl mouse move and button press events""" if event.button != 1: # only enabled while dragging return key = self._in_axis(event) @@ -312,18 +318,18 @@ def _on_mousemove(self, event): self._draw() def _on_keypress(self, event): + """Handle mpl keypress events""" if event.key is not None and 'escape' in event.key: self.close() def _draw(self): - for im in self._ims.values(): - ax = im.axes + """Update all four (or three) plots""" + for key in 'xyz': + ax, im = self._axes[key], self._ims[key] ax.draw_artist(im) - ax.draw_artist(im.vert_line) - ax.draw_artist(im.horiz_line) + for line in self._crosshairs[key].values(): + ax.draw_artist(line) ax.figure.canvas.blit(ax.bbox) - for t in im.texts: - ax.draw_artist(t) if self.n_volumes > 1 and 'v' in self._axes: # user might only pass 3 ax = self._axes['v'] ax.draw_artist(ax.patch) # axis bgcolor to erase old lines From e2c34f2c9081236c1300d4bc75ec54b712fc1117 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Mon, 27 Oct 2014 16:58:05 -0700 Subject: [PATCH 10/20] WIP --- nibabel/spatialimages.py | 2 +- nibabel/viewers.py | 32 ++++++++++++++++---------------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/nibabel/spatialimages.py b/nibabel/spatialimages.py index 9bd1cfde3e..d132635993 100644 --- a/nibabel/spatialimages.py +++ b/nibabel/spatialimages.py @@ -760,4 +760,4 @@ def plot(self): consider using viewer.show() (equivalently plt.show()) to show the figure. """ - return OrthoSlicer3D(self.get_data()) + return OrthoSlicer3D(self.get_data(), self.get_affine()) diff --git a/nibabel/viewers.py b/nibabel/viewers.py index d73a0ec547..10edd14d09 100644 --- a/nibabel/viewers.py +++ b/nibabel/viewers.py @@ -8,6 +8,7 @@ import numpy as np from .optpkg import optional_package +from .orientations import aff2axcodes, axcodes2ornt plt, _, _ = optional_package('matplotlib.pyplot') mpl_img, _, _ = optional_package('matplotlib.image') @@ -33,23 +34,22 @@ class OrthoSlicer3D(object): >>> OrthoSlicer3D(data).show() # doctest: +SKIP """ # Skip doctest above b/c not all systems have mpl installed - def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), affine=None, - cmap='gray', pcnt_range=(1., 99.), figsize=(8, 8)): + def __init__(self, data, affine=None, axes=None, cmap='gray', + pcnt_range=(1., 99.), figsize=(8, 8)): """ Parameters ---------- data : ndarray The data that will be displayed by the slicer. Should have 3+ dimensions. - axes : tuple of mpl.Axes | None, optional - 3 or 4 axes instances for the X, Y, Z slices plus volumes, - or None (default). - aspect_ratio : array-like, optional - Stretch factors for X, Y, Z directions. affine : array-like | None Affine transform for the data. This is used to determine how the data should be sliced for plotting into the X, Y, - and Z view axes. If None, identity is assumed. + and Z view axes. If None, identity is assumed. The aspect + ratio of the data are inferred from the affine transform. + axes : tuple of mpl.Axes | None, optional + 3 or 4 axes instances for the X, Y, Z slices plus volumes, + or None (default). cmap : str | instance of cmap, optional String or cmap instance specifying colormap. pcnt_range : array-like, optional @@ -57,17 +57,17 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), affine=None, figsize : tuple Figure size (in inches) to use if axes are None. """ - ar = np.array(aspect_ratio, float) - if ar.shape != (3,) or np.any(ar <= 0): - raise ValueError('aspect ratio must have exactly 3 elements >= 0') - aspect_ratio = dict(x=ar[0], y=ar[1], z=ar[2]) data = np.asanyarray(data) if data.ndim < 3: raise ValueError('data must have at least 3 dimensions') affine = np.array(affine, float) if affine is not None else np.eye(4) if affine.ndim != 2 or affine.shape != (4, 4): raise ValueError('affine must be a 4x4 matrix') - self._affine = affine + self._affine = affine.copy() + self._codes = axcodes2ornt(aff2axcodes(self._affine)) # XXX USE FOR ORDERING + print(self._codes) + self._scalers = np.abs(self._affine).max(axis=0)[:3] + self._inv_affine = np.linalg.inv(affine) self._volume_dims = data.shape[3:] self._current_vol_data = data[:, :, :, 0] if data.ndim > 3 else data self._data = data @@ -122,7 +122,7 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), affine=None, # set up axis crosshairs self._crosshairs = dict() - for type_, i_1, i_2 in zip('zyx', 'xxy', 'yzz'): + for type_, i_1, i_2 in zip('xyz', 'yxx', 'zzy'): ax, label = self._axes[type_], labels[type_] vert = ax.plot([self._idx[i_1]] * 2, [-0.5, self._sizes[i_2] - 0.5], @@ -145,7 +145,7 @@ def __init__(self, data, axes=None, aspect_ratio=(1, 1, 1), affine=None, horizontalalignment=anchor[0], verticalalignment=anchor[1]) ax.axis(lims) - ax.set_aspect(aspect_ratio[type_]) + # ax.set_aspect(aspect_ratio[type_]) # XXX FIX ax.patch.set_visible(False) ax.set_frame_on(False) ax.axes.get_yaxis().set_visible(False) @@ -206,7 +206,7 @@ def n_volumes(self): """Number of volumes in the data""" return int(np.prod(self._volume_dims)) - def set_indices(self, x=None, y=None, z=None, v=None): + def set_position(self, x=None, y=None, z=None, v=None): """Set current displayed slice indices Parameters From 3c85ccadf5baf8df643f7f8dbe00854db144c6d6 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Tue, 28 Oct 2014 12:50:18 -0700 Subject: [PATCH 11/20] WIP: Closer to correcting orientation --- nibabel/tests/test_viewers.py | 7 +-- nibabel/viewers.py | 88 ++++++++++++++++++++++------------- 2 files changed, 57 insertions(+), 38 deletions(-) diff --git a/nibabel/tests/test_viewers.py b/nibabel/tests/test_viewers.py index 8be12f104d..6ec5c1b1ef 100644 --- a/nibabel/tests/test_viewers.py +++ b/nibabel/tests/test_viewers.py @@ -41,22 +41,17 @@ def test_viewer(): for ax in v._axes.values(): v._on_mouse(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, ax, 1)) v._on_mouse(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, None, None)) - v.set_indices(0, 1, 2) - v.set_indices(v=10) v.close() # non-multi-volume v = OrthoSlicer3D(data[:, :, :, 0]) - assert_raises(ValueError, v.set_indices, v=10) # not multi-volume v._on_scroll(nt('event', 'button inaxes key')('up', v._axes['x'], 'shift')) v._on_keypress(nt('event', 'key')('escape')) # other cases fig, axes = plt.subplots(1, 4) plt.close(fig) - OrthoSlicer3D(data, pcnt_range=[0.1, 0.9], axes=axes, - aspect_ratio=[1, 2, 3]) + OrthoSlicer3D(data, pcnt_range=[0.1, 0.9], axes=axes) OrthoSlicer3D(data, axes=axes[:3]) - assert_raises(ValueError, OrthoSlicer3D, data, aspect_ratio=[1, 2]) assert_raises(ValueError, OrthoSlicer3D, data[:, :, 0, 0]) assert_raises(ValueError, OrthoSlicer3D, data, affine=np.eye(3)) diff --git a/nibabel/viewers.py b/nibabel/viewers.py index 10edd14d09..8893d67467 100644 --- a/nibabel/viewers.py +++ b/nibabel/viewers.py @@ -29,7 +29,7 @@ class OrthoSlicer3D(object): ------- >>> import numpy as np >>> a = np.sin(np.linspace(0,np.pi,20)) - >>> b = np.sin(np.linspace(0,np.pi*5,20)) + >>> b = np.sin(np.linspace(0,np.pi*5,20))asa >>> data = np.outer(a,b)[..., np.newaxis]*a >>> OrthoSlicer3D(data).show() # doctest: +SKIP """ @@ -44,11 +44,12 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', dimensions. affine : array-like | None Affine transform for the data. This is used to determine - how the data should be sliced for plotting into the X, Y, - and Z view axes. If None, identity is assumed. The aspect - ratio of the data are inferred from the affine transform. + how the data should be sliced for plotting into the saggital, + coronal, and axial view axes. If None, identity is assumed. + The aspect ratio of the data are inferred from the affine + transform. axes : tuple of mpl.Axes | None, optional - 3 or 4 axes instances for the X, Y, Z slices plus volumes, + 3 or 4 axes instances for the 3 slices plus volumes, or None (default). cmap : str | instance of cmap, optional String or cmap instance specifying colormap. @@ -63,39 +64,43 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', affine = np.array(affine, float) if affine is not None else np.eye(4) if affine.ndim != 2 or affine.shape != (4, 4): raise ValueError('affine must be a 4x4 matrix') + # determine our orientation self._affine = affine.copy() - self._codes = axcodes2ornt(aff2axcodes(self._affine)) # XXX USE FOR ORDERING - print(self._codes) + codes = axcodes2ornt(aff2axcodes(self._affine)) + order = np.argsort([c[0] for c in codes]) + flips = np.array([c[1] for c in codes])[order] + self._order = dict(x=int(order[0]), y=int(order[1]), z=int(order[2])) + self._flips = dict(x=flips[0], y=flips[1], z=flips[2]) self._scalers = np.abs(self._affine).max(axis=0)[:3] self._inv_affine = np.linalg.inv(affine) + # current volume info self._volume_dims = data.shape[3:] self._current_vol_data = data[:, :, :, 0] if data.ndim > 3 else data self._data = data - pcnt_range = (0, 100) if pcnt_range is None else pcnt_range vmin, vmax = np.percentile(data, pcnt_range) del data if axes is None: # make the axes # ^ +---------+ ^ +---------+ # | | | | | | + # | Sag | | Cor | + # S | 1 | S | 2 | # | | | | - # z | 2 | z | 3 | # | | | | - # | | | | | | - # v +---------+ v +---------+ - # <-- x --> <-- y --> - # ^ +---------+ ^ +---------+ - # | | | | | | + # +---------+ +---------+ + # A --> <-- R + # ^ +---------+ +---------+ + # | | | | | + # | Axial | | | + # A | 3 | | 4 | # | | | | - # y | 1 | A | 4 | # | | | | - # | | | | | | - # v +---------+ v +---------+ - # <-- x --> <-- t --> + # +---------+ +---------+ + # <-- R <-- t --> fig, axes = plt.subplots(2, 2) fig.set_size_inches(figsize, forward=True) - self._axes = dict(x=axes[0, 1], y=axes[0, 0], z=axes[1, 0], + self._axes = dict(x=axes[0, 0], y=axes[0, 1], z=axes[1, 0], v=axes[1, 1]) plt.tight_layout(pad=0.1) if self.n_volumes <= 1: @@ -111,14 +116,15 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', # Start midway through each axis, idx is current slice number self._ims, self._sizes, self._idx = dict(), dict(), dict() + self._vol = 0 colors = dict() - for k, size in zip('xyz', self._data.shape[:3]): + for k in 'xyz': + size = self._data.shape[self._order[k]] self._idx[k] = size // 2 self._ims[k] = self._axes[k].imshow(self._get_slice_data(k), **kw) self._sizes[k] = size colors[k] = (0, 1, 0) - self._idx['v'] = 0 - labels = dict(z='ILSR', y='ALPR', x='AIPS') + labels = dict(x='SAIP', y='SLIR', z='ALPR') # set up axis crosshairs self._crosshairs = dict() @@ -231,7 +237,7 @@ def set_position(self, x=None, y=None, z=None, v=None): 'image') self._set_vol_idx(v) draw = True - for key, val in zip('zyx', (z, y, x)): + for key, val in zip('xyz', (x, y, z)): if val is not None: self._set_viewer_slice(key, val) draw = True @@ -241,9 +247,11 @@ def set_position(self, x=None, y=None, z=None, v=None): def _get_voxel_levels(self): """Get levels of the current voxel as a function of volume""" - y = self._data[self._idx['x'], - self._idx['y'], - self._idx['z'], :].ravel() + # XXX THIS IS WRONG + #y = self._data[self._idx['x'], + # self._idx['y'], + # self._idx['z'], :].ravel() + y = self._data[0, 0, 0, :].ravel() y = np.concatenate((y, [y[-1]])) return y @@ -255,20 +263,34 @@ def _update_voxel_levels(self): def _set_vol_idx(self, idx): """Change which volume is shown""" max_ = np.prod(self._volume_dims) - self._idx['v'] = max(min(int(round(idx)), max_ - 1), 0) + self._vol = max(min(int(round(idx)), max_ - 1), 0) # Must reset what is shown - self._current_vol_data = self._data[:, :, :, self._idx['v']] + self._current_vol_data = self._data[:, :, :, self._vol] for key in 'xyz': self._ims[key].set_data(self._get_slice_data(key)) - self._volume_ax_objs['patch'].set_x(self._idx['v'] - 0.5) + self._volume_ax_objs['patch'].set_x(self._vol - 0.5) def _get_slice_data(self, key): """Helper to get the current slice image""" - ii = dict(x=0, y=1, z=2)[key] - return np.take(self._current_vol_data, self._idx[key], axis=ii).T + assert key in ['x', 'y', 'z'] + data = np.take(self._current_vol_data, self._idx[key], + axis=self._order[key]) + # saggital: get to S/A + # coronal: get to S/L + # axial: get to A/L + xaxes = dict(x='y', y='x', z='x') + yaxes = dict(x='z', y='z', z='y') + if self._order[xaxes[key]] < self._order[yaxes[key]]: + data = data.T + if self._flips[xaxes[key]]: + data = data[:, ::-1] + if self._flips[yaxes[key]]: + data = data[::-1] + return data def _set_viewer_slice(self, key, idx): """Helper to set a viewer slice number""" + assert key in ['x', 'y', 'z'] self._idx[key] = max(min(int(round(idx)), self._sizes[key] - 1), 0) self._ims[key].set_data(self._get_slice_data(key)) for fun in self._cross_setters[key]: @@ -293,7 +315,9 @@ def _on_scroll(self, event): if self.n_volumes <= 1: return key = 'v' # shift: change volume in any axis - idx = self._idx[key] + (delta if event.button == 'up' else -delta) + assert key in ['x', 'y', 'z', 'v'] + idx = self._idx[key] if key != 'v' else self._vol + idx += delta if event.button == 'up' else -delta if key == 'v': self._set_vol_idx(idx) else: From 7cb4fb54e43e31ca7664a8bcbc48c4e0bbf6694f Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Tue, 28 Oct 2014 12:58:34 -0700 Subject: [PATCH 12/20] WIP: Fixed ratio --- nibabel/viewers.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/nibabel/viewers.py b/nibabel/viewers.py index 8893d67467..f5b611aafb 100644 --- a/nibabel/viewers.py +++ b/nibabel/viewers.py @@ -68,7 +68,7 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', self._affine = affine.copy() codes = axcodes2ornt(aff2axcodes(self._affine)) order = np.argsort([c[0] for c in codes]) - flips = np.array([c[1] for c in codes])[order] + flips = np.array([c[1] < 0 for c in codes])[order] self._order = dict(x=int(order[0]), y=int(order[1]), z=int(order[2])) self._flips = dict(x=flips[0], y=flips[1], z=flips[2]) self._scalers = np.abs(self._affine).max(axis=0)[:3] @@ -128,7 +128,10 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', # set up axis crosshairs self._crosshairs = dict() - for type_, i_1, i_2 in zip('xyz', 'yxx', 'zzy'): + r = [self._scalers[self._order['z']] / self._scalers[self._order['y']], + self._scalers[self._order['z']] / self._scalers[self._order['x']], + self._scalers[self._order['y']] / self._scalers[self._order['x']]] + for type_, i_1, i_2, ratio in zip('xyz', 'yxx', 'zzy', r): ax, label = self._axes[type_], labels[type_] vert = ax.plot([self._idx[i_1]] * 2, [-0.5, self._sizes[i_2] - 0.5], @@ -151,7 +154,7 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', horizontalalignment=anchor[0], verticalalignment=anchor[1]) ax.axis(lims) - # ax.set_aspect(aspect_ratio[type_]) # XXX FIX + ax.set_aspect(ratio) ax.patch.set_visible(False) ax.set_frame_on(False) ax.axes.get_yaxis().set_visible(False) From 96966d80d39df9e7d4e4e24354f5cad501c9a435 Mon Sep 17 00:00:00 2001 From: Eric89GXL Date: Tue, 28 Oct 2014 14:48:18 -0700 Subject: [PATCH 13/20] FIX: FIX time plot --- nibabel/viewers.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/nibabel/viewers.py b/nibabel/viewers.py index f5b611aafb..8be3345172 100644 --- a/nibabel/viewers.py +++ b/nibabel/viewers.py @@ -250,11 +250,10 @@ def set_position(self, x=None, y=None, z=None, v=None): def _get_voxel_levels(self): """Get levels of the current voxel as a function of volume""" - # XXX THIS IS WRONG - #y = self._data[self._idx['x'], - # self._idx['y'], - # self._idx['z'], :].ravel() - y = self._data[0, 0, 0, :].ravel() + idx = [0] * 3 + for key in 'xyz': + idx[self._order[key]] = self._idx[key] + y = self._data[idx[0], idx[1], idx[2], :].ravel() y = np.concatenate((y, [y[-1]])) return y From f5b40feab215ef3131557e3b631447257db9b6c5 Mon Sep 17 00:00:00 2001 From: Eric89GXL Date: Wed, 29 Oct 2014 16:18:48 -0700 Subject: [PATCH 14/20] FIX: Fix orientations and interactions --- nibabel/tests/test_viewers.py | 16 +- nibabel/viewers.py | 320 +++++++++++++++++++++------------- 2 files changed, 208 insertions(+), 128 deletions(-) diff --git a/nibabel/tests/test_viewers.py b/nibabel/tests/test_viewers.py index 6ec5c1b1ef..fa4a336a30 100644 --- a/nibabel/tests/test_viewers.py +++ b/nibabel/tests/test_viewers.py @@ -14,6 +14,7 @@ from ..viewers import OrthoSlicer3D from numpy.testing.decorators import skipif +from numpy.testing import assert_array_equal from nose.tools import assert_raises @@ -29,7 +30,7 @@ def test_viewer(): data = (np.outer(a, b)[..., np.newaxis] * a)[:, :, :, np.newaxis] data = data * np.array([1., 2.]) # give it a # of volumes > 1 v = OrthoSlicer3D(data) - plt.draw() + assert_array_equal(v.position, (0, 0, 0)) # fake some events, inside and outside axes v._on_scroll(nt('event', 'button inaxes key')('up', None, None)) @@ -41,6 +42,8 @@ def test_viewer(): for ax in v._axes.values(): v._on_mouse(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, ax, 1)) v._on_mouse(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, None, None)) + v.set_volume_idx(1) + v.set_volume_idx(1) # should just pass v.close() # non-multi-volume @@ -51,7 +54,14 @@ def test_viewer(): # other cases fig, axes = plt.subplots(1, 4) plt.close(fig) - OrthoSlicer3D(data, pcnt_range=[0.1, 0.9], axes=axes) - OrthoSlicer3D(data, axes=axes[:3]) + v1 = OrthoSlicer3D(data, pcnt_range=[0.1, 0.9], axes=axes) + aff = np.array([[0, 1, 0, 3], [-1, 0, 0, 2], [0, 0, 2, 1], [0, 0, 0, 1]], + float) + v2 = OrthoSlicer3D(data, affine=aff, axes=axes[:3]) assert_raises(ValueError, OrthoSlicer3D, data[:, :, 0, 0]) assert_raises(ValueError, OrthoSlicer3D, data, affine=np.eye(3)) + assert_raises(TypeError, v2.link_to, 1) + v2.link_to(v1) + v2.link_to(v1) # shouldn't do anything + v1.close() + v2.close() diff --git a/nibabel/viewers.py b/nibabel/viewers.py index 8be3345172..f2d43231a8 100644 --- a/nibabel/viewers.py +++ b/nibabel/viewers.py @@ -6,6 +6,7 @@ from __future__ import division, print_function import numpy as np +import weakref from .optpkg import optional_package from .orientations import aff2axcodes, axcodes2ornt @@ -91,7 +92,7 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', # A --> <-- R # ^ +---------+ +---------+ # | | | | | - # | Axial | | | + # | Axial | | Vol | # A | 3 | | 4 | # | | | | # | | | | @@ -111,37 +112,31 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', if len(axes) > 3: self._axes['v'] = axes[3] - kw = dict(vmin=vmin, vmax=vmax, aspect=1, interpolation='nearest', - cmap=cmap, origin='lower') - # Start midway through each axis, idx is current slice number - self._ims, self._sizes, self._idx = dict(), dict(), dict() - self._vol = 0 - colors = dict() - for k in 'xyz': - size = self._data.shape[self._order[k]] - self._idx[k] = size // 2 - self._ims[k] = self._axes[k].imshow(self._get_slice_data(k), **kw) - self._sizes[k] = size - colors[k] = (0, 1, 0) - labels = dict(x='SAIP', y='SLIR', z='ALPR') + self._ims, self._sizes, self._data_idx = dict(), dict(), dict() # set up axis crosshairs self._crosshairs = dict() r = [self._scalers[self._order['z']] / self._scalers[self._order['y']], self._scalers[self._order['z']] / self._scalers[self._order['x']], self._scalers[self._order['y']] / self._scalers[self._order['x']]] - for type_, i_1, i_2, ratio in zip('xyz', 'yxx', 'zzy', r): - ax, label = self._axes[type_], labels[type_] - vert = ax.plot([self._idx[i_1]] * 2, - [-0.5, self._sizes[i_2] - 0.5], - color=colors[i_1], linestyle='-')[0] - horiz = ax.plot([-0.5, self._sizes[i_1] - 0.5], - [self._idx[i_2]] * 2, - color=colors[i_2], linestyle='-')[0] - self._crosshairs[type_] = dict(vert=vert, horiz=horiz) + for k in 'xyz': + self._sizes[k] = self._data.shape[self._order[k]] + for k, xax, yax, ratio, label in zip('xyz', 'yxx', 'zzy', r, + ('SAIP', 'SLIR', 'ALPR')): + ax = self._axes[k] + d = np.zeros((self._sizes[yax], self._sizes[xax])) + self._ims[k] = self._axes[k].imshow(d, vmin=vmin, vmax=vmax, + aspect=1, cmap=cmap, + interpolation='nearest', + origin='lower') + vert = ax.plot([0] * 2, [-0.5, self._sizes[yax] - 0.5], + color=(0, 1, 0), linestyle='-')[0] + horiz = ax.plot([-0.5, self._sizes[xax] - 0.5], [0] * 2, + color=(0, 1, 0), linestyle='-')[0] + self._crosshairs[k] = dict(vert=vert, horiz=horiz) # add text labels (top, right, bottom, left) - lims = [0, self._sizes[i_1], 0, self._sizes[i_2]] + lims = [0, self._sizes[xax], 0, self._sizes[yax]] bump = 0.01 poss = [[lims[1] / 2., lims[3]], [(1 + bump) * lims[1], lims[3] / 2.], @@ -159,13 +154,15 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', ax.set_frame_on(False) ax.axes.get_yaxis().set_visible(False) ax.axes.get_xaxis().set_visible(False) + self._data_idx[k] = 0 + self._data_idx['v'] = -1 # Set up volumes axis if self.n_volumes > 1 and 'v' in self._axes: ax = self._axes['v'] ax.set_axis_bgcolor('k') ax.set_title('Volumes') - y = self._get_voxel_levels() + y = np.zeros(self.n_volumes + 1) x = np.arange(self.n_volumes + 1) - 0.5 step = ax.step(x, y, where='post', color='y')[0] ax.set_xticks(np.unique(np.linspace(0, self.n_volumes - 1, @@ -180,18 +177,6 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', ax.set_ylim(yl) self._volume_ax_objs = dict(step=step, patch=patch) - # setup pairwise connections between the slice dimensions - self._click_update_keys = dict(x='yz', y='xz', z='xy') - - # when an index changes, which crosshairs need to be updated - self._cross_setters = dict( - x=[self._crosshairs['z']['vert'].set_xdata, - self._crosshairs['y']['vert'].set_xdata], - y=[self._crosshairs['z']['horiz'].set_ydata, - self._crosshairs['x']['vert'].set_xdata], - z=[self._crosshairs['y']['horiz'].set_ydata, - self._crosshairs['x']['horiz'].set_ydata]) - self._figs = set([a.figure for a in self._axes.values()]) for fig in self._figs: fig.canvas.mpl_connect('scroll_event', self._on_scroll) @@ -199,8 +184,20 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', fig.canvas.mpl_connect('button_press_event', self._on_mouse) fig.canvas.mpl_connect('key_press_event', self._on_keypress) + # actually set data meaningfully + self._position = np.zeros(4) + self._position[3] = 1. # convenience for affine multn + self._changing = False # keep track of status to avoid loops + self._links = [] # other viewers this one is linked to + for fig in self._figs: + fig.canvas.draw() + self._set_volume_index(0, update_slices=False) + self._set_position(0., 0., 0.) + self._draw() + + # User-level functions ################################################### def show(self): - """ Show the slicer in blocking mode; convenience for ``plt.show()`` + """Show the slicer in blocking mode; convenience for ``plt.show()`` """ plt.show() @@ -209,95 +206,156 @@ def close(self): """ for f in self._figs: plt.close(f) + for link in self._links: + link()._unlink(self) @property def n_volumes(self): """Number of volumes in the data""" return int(np.prod(self._volume_dims)) - def set_position(self, x=None, y=None, z=None, v=None): + @property + def position(self): + """The current coordinates""" + return self._position[:3].copy() + + def link_to(self, other): + """Link positional changes between two canvases + + Parameters + ---------- + other : instance of OrthoSlicer3D + Other viewer to use to link movements. + """ + if not isinstance(other, self.__class__): + raise TypeError('other must be an instance of %s, not %s' + % (self.__class__.__name__, type(other))) + self._link(other, is_primary=True) + + def _link(self, other, is_primary): + """Link a viewer""" + ref = weakref.ref(other) + if ref in self._links: + return + self._links.append(ref) + if is_primary: + other._link(self, is_primary=False) + other.set_position(*self.position) + + def _unlink(self, other): + """Unlink a viewer""" + ref = weakref.ref(other) + if ref in self._links: + self._links.pop(self._links.index(ref)) + ref()._unlink(self) + + def _notify_links(self): + """Notify linked canvases of a position change""" + for link in self._links: + link().set_position(*self.position[:3]) + + def set_position(self, x=None, y=None, z=None): """Set current displayed slice indices Parameters ---------- - x : int | None - Index to use. If None, do not change. - y : int | None - Index to use. If None, do not change. - z : int | None - Index to use. If None, do not change. - v : int | None - Volume index to use. If None, do not change. + x : float | None + X coordinate to use. If None, do not change. + y : float | None + Y coordinate to use. If None, do not change. + z : float | None + Z coordinate to use. If None, do not change. """ - x = int(x) if x is not None else None - y = int(y) if y is not None else None - z = int(z) if z is not None else None - v = int(v) if v is not None else None - draw = False - if v is not None: - if self.n_volumes <= 1: - raise ValueError('cannot change volume index of single-volume ' - 'image') - self._set_vol_idx(v) - draw = True - for key, val in zip('xyz', (x, y, z)): - if val is not None: - self._set_viewer_slice(key, val) - draw = True - if draw: - self._update_voxel_levels() - self._draw() - - def _get_voxel_levels(self): - """Get levels of the current voxel as a function of volume""" - idx = [0] * 3 - for key in 'xyz': - idx[self._order[key]] = self._idx[key] - y = self._data[idx[0], idx[1], idx[2], :].ravel() - y = np.concatenate((y, [y[-1]])) - return y - - def _update_voxel_levels(self): - """Update voxel levels in time plot""" - if self.n_volumes > 1: - self._volume_ax_objs['step'].set_ydata(self._get_voxel_levels()) - - def _set_vol_idx(self, idx): - """Change which volume is shown""" + self._set_position(x, y, z) + self._draw() + + def set_volume_idx(self, v): + """Set current displayed volume index + + Parameters + ---------- + v : int + Volume index. + """ + self._set_volume_index(v) + self._draw() + + def _set_volume_index(self, v, update_slices=True): + """Set the plot data using a volume index""" + v = self._data_idx['v'] if v is None else int(round(v)) + if v == self._data_idx['v']: + return max_ = np.prod(self._volume_dims) - self._vol = max(min(int(round(idx)), max_ - 1), 0) - # Must reset what is shown - self._current_vol_data = self._data[:, :, :, self._vol] + self._data_idx['v'] = max(min(int(round(v)), max_ - 1), 0) + idx = (slice(None), slice(None), slice(None)) + if self._data.ndim > 3: + idx = idx + tuple(np.unravel_index(self._data_idx['v'], + self._volume_dims)) + self._current_vol_data = self._data[idx] + # update all of our slice plots + if update_slices: + self._set_position(None, None, None, notify=False) + + def _set_position(self, x, y, z, notify=True): + """Set the plot data using a physical position""" + # deal with volume first + if self._changing: + return + self._changing = True + x = self._position[0] if x is None else float(x) + y = self._position[1] if y is None else float(y) + z = self._position[2] if z is None else float(z) + + # deal with slicing appropriately + self._position[:3] = [x, y, z] + idxs = np.dot(self._inv_affine, self._position)[:3] + for key, idx in zip('xyz', idxs): + self._data_idx[key] = max(min(int(round(idx)), + self._sizes[key] - 1), 0) for key in 'xyz': - self._ims[key].set_data(self._get_slice_data(key)) - self._volume_ax_objs['patch'].set_x(self._vol - 0.5) - - def _get_slice_data(self, key): - """Helper to get the current slice image""" - assert key in ['x', 'y', 'z'] - data = np.take(self._current_vol_data, self._idx[key], - axis=self._order[key]) - # saggital: get to S/A - # coronal: get to S/L - # axial: get to A/L - xaxes = dict(x='y', y='x', z='x') - yaxes = dict(x='z', y='z', z='y') - if self._order[xaxes[key]] < self._order[yaxes[key]]: - data = data.T - if self._flips[xaxes[key]]: - data = data[:, ::-1] - if self._flips[yaxes[key]]: - data = data[::-1] - return data - - def _set_viewer_slice(self, key, idx): - """Helper to set a viewer slice number""" - assert key in ['x', 'y', 'z'] - self._idx[key] = max(min(int(round(idx)), self._sizes[key] - 1), 0) - self._ims[key].set_data(self._get_slice_data(key)) - for fun in self._cross_setters[key]: - fun([self._idx[key]] * 2) - + # saggital: get to S/A + # coronal: get to S/L + # axial: get to A/L + data = np.take(self._current_vol_data, self._data_idx[key], + axis=self._order[key]) + xax = dict(x='y', y='x', z='x')[key] + yax = dict(x='z', y='z', z='y')[key] + if self._order[xax] < self._order[yax]: + data = data.T + if self._flips[xax]: + data = data[:, ::-1] + if self._flips[yax]: + data = data[::-1] + self._ims[key].set_data(data) + # deal with crosshairs + loc = self._data_idx[key] + if self._flips[key]: + loc = self._sizes[key] - loc + loc = [loc] * 2 + if key == 'x': + self._crosshairs['z']['vert'].set_xdata(loc) + self._crosshairs['y']['vert'].set_xdata(loc) + elif key == 'y': + self._crosshairs['z']['horiz'].set_ydata(loc) + self._crosshairs['x']['vert'].set_xdata(loc) + else: # key == 'z' + self._crosshairs['y']['horiz'].set_ydata(loc) + self._crosshairs['x']['horiz'].set_ydata(loc) + + # Update volume trace + if self.n_volumes > 1 and 'v' in self._axes: + idx = [0] * 3 + for key in 'xyz': + idx[self._order[key]] = self._data_idx[key] + vdata = self._data[idx[0], idx[1], idx[2], :].ravel() + vdata = np.concatenate((vdata, [vdata[-1]])) + self._volume_ax_objs['patch'].set_x(self._data_idx['v'] - 0.5) + self._volume_ax_objs['step'].set_ydata(vdata) + if notify: + self._notify_links() + self._changing = False + + # Matplotlib handlers #################################################### def _in_axis(self, event): """Return axis key if within one of our axes, else None""" if getattr(event, 'inaxes') is None: @@ -312,19 +370,25 @@ def _on_scroll(self, event): key = self._in_axis(event) if key is None: return - delta = 10 if event.key is not None and 'control' in event.key else 1 if event.key is not None and 'shift' in event.key: if self.n_volumes <= 1: return key = 'v' # shift: change volume in any axis assert key in ['x', 'y', 'z', 'v'] - idx = self._idx[key] if key != 'v' else self._vol - idx += delta if event.button == 'up' else -delta + dv = 10. if event.key is not None and 'control' in event.key else 1. + dv *= 1. if event.button == 'up' else -1. + dv *= -1 if self._flips.get(key, False) else 1 + val = self._data_idx[key] + dv if key == 'v': - self._set_vol_idx(idx) + self._set_volume_index(val) else: - self._set_viewer_slice(key, idx) - self._update_voxel_levels() + coords = {key: val} + for k in 'xyz': + if k not in coords: + coords[k] = self._data_idx[k] + coords = np.array([coords['x'], coords['y'], coords['z'], 1.]) + coords = np.dot(self._affine, coords)[:3] + self._set_position(coords[0], coords[1], coords[2]) self._draw() def _on_mouse(self, event): @@ -335,12 +399,18 @@ def _on_mouse(self, event): if key is None: return if key == 'v': - self._set_vol_idx(event.xdata) + # volume plot directly translates + self._set_volume_index(event.xdata) else: - for sub_key, idx in zip(self._click_update_keys[key], - (event.xdata, event.ydata)): - self._set_viewer_slice(sub_key, idx) - self._update_voxel_levels() + # translate click xdata/ydata to physical position + xax, yax = dict(x='yz', y='xz', z='xy')[key] + x, y = event.xdata, event.ydata + x = self._sizes[xax] - x if self._flips[xax] else x + y = self._sizes[yax] - y if self._flips[yax] else y + idxs = {xax: x, yax: y, key: self._data_idx[key]} + idxs = np.array([idxs['x'], idxs['y'], idxs['z'], 1.]) + pos = np.dot(self._affine, idxs)[:3] + self._set_position(*pos) self._draw() def _on_keypress(self, event): From fe4e3c15ecd3fefdb25daf5b8bcf5265aa61f4c2 Mon Sep 17 00:00:00 2001 From: Eric89GXL Date: Wed, 29 Oct 2014 16:26:15 -0700 Subject: [PATCH 15/20] FIX: Minor fixes --- nibabel/viewers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nibabel/viewers.py b/nibabel/viewers.py index f2d43231a8..3870b6f9b2 100644 --- a/nibabel/viewers.py +++ b/nibabel/viewers.py @@ -30,7 +30,7 @@ class OrthoSlicer3D(object): ------- >>> import numpy as np >>> a = np.sin(np.linspace(0,np.pi,20)) - >>> b = np.sin(np.linspace(0,np.pi*5,20))asa + >>> b = np.sin(np.linspace(0,np.pi*5,20)) >>> data = np.outer(a,b)[..., np.newaxis]*a >>> OrthoSlicer3D(data).show() # doctest: +SKIP """ @@ -189,6 +189,7 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', self._position[3] = 1. # convenience for affine multn self._changing = False # keep track of status to avoid loops self._links = [] # other viewers this one is linked to + plt.draw() for fig in self._figs: fig.canvas.draw() self._set_volume_index(0, update_slices=False) From 2d9133b175fb15a5197dc19fc34570e1e305d7cf Mon Sep 17 00:00:00 2001 From: Eric89GXL Date: Wed, 29 Oct 2014 16:34:22 -0700 Subject: [PATCH 16/20] FIX: Fix test --- nibabel/tests/test_viewers.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/nibabel/tests/test_viewers.py b/nibabel/tests/test_viewers.py index fa4a336a30..a5d0ca4709 100644 --- a/nibabel/tests/test_viewers.py +++ b/nibabel/tests/test_viewers.py @@ -10,6 +10,12 @@ import numpy as np from collections import namedtuple as nt +try: + import matplotlib + matplotlib.use('agg') +except Exception: + pass + from ..optpkg import optional_package from ..viewers import OrthoSlicer3D From 049eddc492c7f2852c4de2266ba907a8e4629ca8 Mon Sep 17 00:00:00 2001 From: Eric89GXL Date: Wed, 29 Oct 2014 16:42:40 -0700 Subject: [PATCH 17/20] FIX: Better testing --- nibabel/tests/test_viewers.py | 10 ++++------ nibabel/viewers.py | 12 ++++++++---- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/nibabel/tests/test_viewers.py b/nibabel/tests/test_viewers.py index a5d0ca4709..e639c5e38c 100644 --- a/nibabel/tests/test_viewers.py +++ b/nibabel/tests/test_viewers.py @@ -10,11 +10,6 @@ import numpy as np from collections import namedtuple as nt -try: - import matplotlib - matplotlib.use('agg') -except Exception: - pass from ..optpkg import optional_package from ..viewers import OrthoSlicer3D @@ -24,13 +19,16 @@ from nose.tools import assert_raises -plt, has_mpl = optional_package('matplotlib.pyplot')[:2] +matplotlib, has_mpl = optional_package('matplotlib')[:2] needs_mpl = skipif(not has_mpl, 'These tests need matplotlib') +if has_mpl: + matplotlib.use('Agg') @needs_mpl def test_viewer(): # Test viewer + plt = optional_package('matplotlib.pyplot')[0] a = np.sin(np.linspace(0, np.pi, 20)) b = np.sin(np.linspace(0, np.pi*5, 30)) data = (np.outer(a, b)[..., np.newaxis] * a)[:, :, :, np.newaxis] diff --git a/nibabel/viewers.py b/nibabel/viewers.py index 3870b6f9b2..12d4f581e7 100644 --- a/nibabel/viewers.py +++ b/nibabel/viewers.py @@ -11,10 +11,6 @@ from .optpkg import optional_package from .orientations import aff2axcodes, axcodes2ornt -plt, _, _ = optional_package('matplotlib.pyplot') -mpl_img, _, _ = optional_package('matplotlib.image') -mpl_patch, _, _ = optional_package('matplotlib.patches') - class OrthoSlicer3D(object): """Orthogonal-plane slicer. @@ -59,6 +55,12 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', figsize : tuple Figure size (in inches) to use if axes are None. """ + # Nest imports so that matplotlib.use() has the appropriate + # effect in testing + plt, _, _ = optional_package('matplotlib.pyplot') + mpl_img, _, _ = optional_package('matplotlib.image') + mpl_patch, _, _ = optional_package('matplotlib.patches') + data = np.asanyarray(data) if data.ndim < 3: raise ValueError('data must have at least 3 dimensions') @@ -200,11 +202,13 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', def show(self): """Show the slicer in blocking mode; convenience for ``plt.show()`` """ + plt, _, _ = optional_package('matplotlib.pyplot') plt.show() def close(self): """Close the viewer figures """ + plt, _, _ = optional_package('matplotlib.pyplot') for f in self._figs: plt.close(f) for link in self._links: From 6909049e8cd4650f1d972a957ec770395db2e270 Mon Sep 17 00:00:00 2001 From: Eric89GXL Date: Wed, 29 Oct 2014 17:17:04 -0700 Subject: [PATCH 18/20] STY: Remove dicts in favor of lists --- nibabel/tests/test_viewers.py | 6 +- nibabel/viewers.py | 181 ++++++++++++++++------------------ 2 files changed, 90 insertions(+), 97 deletions(-) diff --git a/nibabel/tests/test_viewers.py b/nibabel/tests/test_viewers.py index e639c5e38c..e0bdfae814 100644 --- a/nibabel/tests/test_viewers.py +++ b/nibabel/tests/test_viewers.py @@ -38,12 +38,12 @@ def test_viewer(): # fake some events, inside and outside axes v._on_scroll(nt('event', 'button inaxes key')('up', None, None)) - for ax in (v._axes['x'], v._axes['v']): + for ax in (v._axes[0], v._axes[3]): v._on_scroll(nt('event', 'button inaxes key')('up', ax, None)) v._on_scroll(nt('event', 'button inaxes key')('up', ax, 'shift')) # "click" outside axes, then once in each axis, then move without click v._on_mouse(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, None, 1)) - for ax in v._axes.values(): + for ax in v._axes: v._on_mouse(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, ax, 1)) v._on_mouse(nt('event', 'xdata ydata inaxes button')(0.5, 0.5, None, None)) v.set_volume_idx(1) @@ -52,7 +52,7 @@ def test_viewer(): # non-multi-volume v = OrthoSlicer3D(data[:, :, :, 0]) - v._on_scroll(nt('event', 'button inaxes key')('up', v._axes['x'], 'shift')) + v._on_scroll(nt('event', 'button inaxes key')('up', v._axes[0], 'shift')) v._on_keypress(nt('event', 'key')('escape')) # other cases diff --git a/nibabel/viewers.py b/nibabel/viewers.py index 12d4f581e7..bfde6bfbfc 100644 --- a/nibabel/viewers.py +++ b/nibabel/viewers.py @@ -70,10 +70,9 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', # determine our orientation self._affine = affine.copy() codes = axcodes2ornt(aff2axcodes(self._affine)) - order = np.argsort([c[0] for c in codes]) - flips = np.array([c[1] < 0 for c in codes])[order] - self._order = dict(x=int(order[0]), y=int(order[1]), z=int(order[2])) - self._flips = dict(x=flips[0], y=flips[1], z=flips[2]) + self._order = np.argsort([c[0] for c in codes]) + self._flips = np.array([c[1] < 0 for c in codes])[self._order] + self._flips = list(self._flips) + [False] # add volume dim self._scalers = np.abs(self._affine).max(axis=0)[:3] self._inv_affine = np.linalg.inv(affine) # current volume info @@ -87,7 +86,7 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', # ^ +---------+ ^ +---------+ # | | | | | | # | Sag | | Cor | - # S | 1 | S | 2 | + # S | 0 | S | 1 | # | | | | # | | | | # +---------+ +---------+ @@ -95,7 +94,7 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', # ^ +---------+ +---------+ # | | | | | # | Axial | | Vol | - # A | 3 | | 4 | + # A | 2 | | 3 | # | | | | # | | | | # +---------+ +---------+ @@ -103,40 +102,38 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', fig, axes = plt.subplots(2, 2) fig.set_size_inches(figsize, forward=True) - self._axes = dict(x=axes[0, 0], y=axes[0, 1], z=axes[1, 0], - v=axes[1, 1]) + self._axes = [axes[0, 0], axes[0, 1], axes[1, 0], axes[1, 1]] plt.tight_layout(pad=0.1) if self.n_volumes <= 1: - fig.delaxes(self._axes['v']) - del self._axes['v'] + fig.delaxes(self._axes[3]) + self._axes.pop(-1) else: - self._axes = dict(z=axes[0], y=axes[1], x=axes[2]) + self._axes = [axes[0], axes[1], axes[2]] if len(axes) > 3: - self._axes['v'] = axes[3] + self._axes.append(axes[3]) # Start midway through each axis, idx is current slice number - self._ims, self._sizes, self._data_idx = dict(), dict(), dict() + self._ims, self._data_idx = list(), list() # set up axis crosshairs - self._crosshairs = dict() - r = [self._scalers[self._order['z']] / self._scalers[self._order['y']], - self._scalers[self._order['z']] / self._scalers[self._order['x']], - self._scalers[self._order['y']] / self._scalers[self._order['x']]] - for k in 'xyz': - self._sizes[k] = self._data.shape[self._order[k]] - for k, xax, yax, ratio, label in zip('xyz', 'yxx', 'zzy', r, - ('SAIP', 'SLIR', 'ALPR')): - ax = self._axes[k] + self._crosshairs = [None] * 3 + r = [self._scalers[self._order[2]] / self._scalers[self._order[1]], + self._scalers[self._order[2]] / self._scalers[self._order[0]], + self._scalers[self._order[1]] / self._scalers[self._order[0]]] + self._sizes = [self._data.shape[o] for o in self._order] + for ii, xax, yax, ratio, label in zip([0, 1, 2], [1, 0, 0], [2, 2, 1], + r, ('SAIP', 'SLIR', 'ALPR')): + ax = self._axes[ii] d = np.zeros((self._sizes[yax], self._sizes[xax])) - self._ims[k] = self._axes[k].imshow(d, vmin=vmin, vmax=vmax, - aspect=1, cmap=cmap, - interpolation='nearest', - origin='lower') + im = self._axes[ii].imshow(d, vmin=vmin, vmax=vmax, aspect=1, + cmap=cmap, interpolation='nearest', + origin='lower') + self._ims.append(im) vert = ax.plot([0] * 2, [-0.5, self._sizes[yax] - 0.5], color=(0, 1, 0), linestyle='-')[0] horiz = ax.plot([-0.5, self._sizes[xax] - 0.5], [0] * 2, color=(0, 1, 0), linestyle='-')[0] - self._crosshairs[k] = dict(vert=vert, horiz=horiz) + self._crosshairs[ii] = dict(vert=vert, horiz=horiz) # add text labels (top, right, bottom, left) lims = [0, self._sizes[xax], 0, self._sizes[yax]] bump = 0.01 @@ -156,12 +153,12 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', ax.set_frame_on(False) ax.axes.get_yaxis().set_visible(False) ax.axes.get_xaxis().set_visible(False) - self._data_idx[k] = 0 - self._data_idx['v'] = -1 + self._data_idx.append(0) + self._data_idx.append(-1) # volume # Set up volumes axis - if self.n_volumes > 1 and 'v' in self._axes: - ax = self._axes['v'] + if self.n_volumes > 1 and len(self._axes) > 3: + ax = self._axes[3] ax.set_axis_bgcolor('k') ax.set_title('Volumes') y = np.zeros(self.n_volumes + 1) @@ -179,7 +176,7 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', ax.set_ylim(yl) self._volume_ax_objs = dict(step=step, patch=patch) - self._figs = set([a.figure for a in self._axes.values()]) + self._figs = set([a.figure for a in self._axes]) for fig in self._figs: fig.canvas.mpl_connect('scroll_event', self._on_scroll) fig.canvas.mpl_connect('motion_notify_event', self._on_mouse) @@ -287,14 +284,14 @@ def set_volume_idx(self, v): def _set_volume_index(self, v, update_slices=True): """Set the plot data using a volume index""" - v = self._data_idx['v'] if v is None else int(round(v)) - if v == self._data_idx['v']: + v = self._data_idx[3] if v is None else int(round(v)) + if v == self._data_idx[3]: return max_ = np.prod(self._volume_dims) - self._data_idx['v'] = max(min(int(round(v)), max_ - 1), 0) + self._data_idx[3] = max(min(int(round(v)), max_ - 1), 0) idx = (slice(None), slice(None), slice(None)) if self._data.ndim > 3: - idx = idx + tuple(np.unravel_index(self._data_idx['v'], + idx = idx + tuple(np.unravel_index(self._data_idx[3], self._volume_dims)) self._current_vol_data = self._data[idx] # update all of our slice plots @@ -314,47 +311,46 @@ def _set_position(self, x, y, z, notify=True): # deal with slicing appropriately self._position[:3] = [x, y, z] idxs = np.dot(self._inv_affine, self._position)[:3] - for key, idx in zip('xyz', idxs): - self._data_idx[key] = max(min(int(round(idx)), - self._sizes[key] - 1), 0) - for key in 'xyz': + for ii, (size, idx) in enumerate(zip(self._sizes, idxs)): + self._data_idx[ii] = max(min(int(round(idx)), size - 1), 0) + for ii in range(3): # saggital: get to S/A # coronal: get to S/L # axial: get to A/L - data = np.take(self._current_vol_data, self._data_idx[key], - axis=self._order[key]) - xax = dict(x='y', y='x', z='x')[key] - yax = dict(x='z', y='z', z='y')[key] + data = np.take(self._current_vol_data, self._data_idx[ii], + axis=self._order[ii]) + xax = [1, 0, 0][ii] + yax = [2, 2, 1][ii] if self._order[xax] < self._order[yax]: data = data.T if self._flips[xax]: data = data[:, ::-1] if self._flips[yax]: data = data[::-1] - self._ims[key].set_data(data) + self._ims[ii].set_data(data) # deal with crosshairs - loc = self._data_idx[key] - if self._flips[key]: - loc = self._sizes[key] - loc + loc = self._data_idx[ii] + if self._flips[ii]: + loc = self._sizes[ii] - loc loc = [loc] * 2 - if key == 'x': - self._crosshairs['z']['vert'].set_xdata(loc) - self._crosshairs['y']['vert'].set_xdata(loc) - elif key == 'y': - self._crosshairs['z']['horiz'].set_ydata(loc) - self._crosshairs['x']['vert'].set_xdata(loc) - else: # key == 'z' - self._crosshairs['y']['horiz'].set_ydata(loc) - self._crosshairs['x']['horiz'].set_ydata(loc) + if ii == 0: + self._crosshairs[2]['vert'].set_xdata(loc) + self._crosshairs[1]['vert'].set_xdata(loc) + elif ii == 1: + self._crosshairs[2]['horiz'].set_ydata(loc) + self._crosshairs[0]['vert'].set_xdata(loc) + else: # ii == 2 + self._crosshairs[1]['horiz'].set_ydata(loc) + self._crosshairs[0]['horiz'].set_ydata(loc) # Update volume trace - if self.n_volumes > 1 and 'v' in self._axes: - idx = [0] * 3 - for key in 'xyz': - idx[self._order[key]] = self._data_idx[key] - vdata = self._data[idx[0], idx[1], idx[2], :].ravel() + if self.n_volumes > 1 and len(self._axes) > 3: + idx = [None, Ellipsis] * 3 + for ii in range(3): + idx[self._order[ii]] = self._data_idx[ii] + vdata = self._data[idx].ravel() vdata = np.concatenate((vdata, [vdata[-1]])) - self._volume_ax_objs['patch'].set_x(self._data_idx['v'] - 0.5) + self._volume_ax_objs['patch'].set_x(self._data_idx[3] - 0.5) self._volume_ax_objs['step'].set_ydata(vdata) if notify: self._notify_links() @@ -362,60 +358,57 @@ def _set_position(self, x, y, z, notify=True): # Matplotlib handlers #################################################### def _in_axis(self, event): - """Return axis key if within one of our axes, else None""" + """Return axis index if within one of our axes, else None""" if getattr(event, 'inaxes') is None: return None - for key, ax in self._axes.items(): + for ii, ax in enumerate(self._axes): if event.inaxes is ax: - return key + return ii def _on_scroll(self, event): """Handle mpl scroll wheel event""" assert event.button in ('up', 'down') - key = self._in_axis(event) - if key is None: + ii = self._in_axis(event) + if ii is None: return if event.key is not None and 'shift' in event.key: if self.n_volumes <= 1: return - key = 'v' # shift: change volume in any axis - assert key in ['x', 'y', 'z', 'v'] + ii = 3 # shift: change volume in any axis + assert ii in range(4) dv = 10. if event.key is not None and 'control' in event.key else 1. dv *= 1. if event.button == 'up' else -1. - dv *= -1 if self._flips.get(key, False) else 1 - val = self._data_idx[key] + dv - if key == 'v': + dv *= -1 if self._flips[ii] else 1 + val = self._data_idx[ii] + dv + if ii == 3: self._set_volume_index(val) else: - coords = {key: val} - for k in 'xyz': - if k not in coords: - coords[k] = self._data_idx[k] - coords = np.array([coords['x'], coords['y'], coords['z'], 1.]) - coords = np.dot(self._affine, coords)[:3] - self._set_position(coords[0], coords[1], coords[2]) + coords = [self._data_idx[k] for k in range(3)] + [1.] + coords[ii] = val + self._set_position(*np.dot(self._affine, coords)[:3]) self._draw() def _on_mouse(self, event): """Handle mpl mouse move and button press events""" if event.button != 1: # only enabled while dragging return - key = self._in_axis(event) - if key is None: + ii = self._in_axis(event) + if ii is None: return - if key == 'v': + if ii == 3: # volume plot directly translates self._set_volume_index(event.xdata) else: # translate click xdata/ydata to physical position - xax, yax = dict(x='yz', y='xz', z='xy')[key] + xax, yax = [[1, 2], [0, 2], [0, 1]][ii] x, y = event.xdata, event.ydata x = self._sizes[xax] - x if self._flips[xax] else x y = self._sizes[yax] - y if self._flips[yax] else y - idxs = {xax: x, yax: y, key: self._data_idx[key]} - idxs = np.array([idxs['x'], idxs['y'], idxs['z'], 1.]) - pos = np.dot(self._affine, idxs)[:3] - self._set_position(*pos) + idxs = [None, None, None, 1.] + idxs[xax] = x + idxs[yax] = y + idxs[ii] = self._data_idx[ii] + self._set_position(*np.dot(self._affine, idxs)[:3]) self._draw() def _on_keypress(self, event): @@ -425,14 +418,14 @@ def _on_keypress(self, event): def _draw(self): """Update all four (or three) plots""" - for key in 'xyz': - ax, im = self._axes[key], self._ims[key] - ax.draw_artist(im) - for line in self._crosshairs[key].values(): + for ii in range(3): + ax = self._axes[ii] + ax.draw_artist(self._ims[ii]) + for line in self._crosshairs[ii].values(): ax.draw_artist(line) ax.figure.canvas.blit(ax.bbox) - if self.n_volumes > 1 and 'v' in self._axes: # user might only pass 3 - ax = self._axes['v'] + if self.n_volumes > 1 and len(self._axes) > 3: + ax = self._axes[3] ax.draw_artist(ax.patch) # axis bgcolor to erase old lines for key in ('step', 'patch'): ax.draw_artist(self._volume_ax_objs[key]) From 8333d0aaba00fb84d3947b4092a93d2a436d3dcf Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Fri, 31 Oct 2014 17:02:37 -0700 Subject: [PATCH 19/20] FIX: Cleanup on close --- nibabel/viewers.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/nibabel/viewers.py b/nibabel/viewers.py index bfde6bfbfc..055227c87b 100644 --- a/nibabel/viewers.py +++ b/nibabel/viewers.py @@ -182,6 +182,7 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', fig.canvas.mpl_connect('motion_notify_event', self._on_mouse) fig.canvas.mpl_connect('button_press_event', self._on_mouse) fig.canvas.mpl_connect('key_press_event', self._on_keypress) + fig.canvas.mpl_connect('close_event', self._cleanup) # actually set data meaningfully self._position = np.zeros(4) @@ -205,9 +206,13 @@ def show(self): def close(self): """Close the viewer figures """ + self._cleanup() plt, _, _ = optional_package('matplotlib.pyplot') for f in self._figs: plt.close(f) + + def _cleanup(self): + """Clean up before closing""" for link in self._links: link()._unlink(self) From 85248cde851f133a8a5ad766145ae9ca839b7d17 Mon Sep 17 00:00:00 2001 From: Eric89GXL Date: Fri, 31 Oct 2014 22:34:54 -0700 Subject: [PATCH 20/20] FIX: Better unlinking --- nibabel/spatialimages.py | 3 ++- nibabel/tests/test_viewers.py | 4 +++- nibabel/viewers.py | 21 ++++++++++++++++++--- 3 files changed, 23 insertions(+), 5 deletions(-) diff --git a/nibabel/spatialimages.py b/nibabel/spatialimages.py index d132635993..25272da482 100644 --- a/nibabel/spatialimages.py +++ b/nibabel/spatialimages.py @@ -760,4 +760,5 @@ def plot(self): consider using viewer.show() (equivalently plt.show()) to show the figure. """ - return OrthoSlicer3D(self.get_data(), self.get_affine()) + return OrthoSlicer3D(self.get_data(), self.get_affine(), + title=self.get_filename()) diff --git a/nibabel/tests/test_viewers.py b/nibabel/tests/test_viewers.py index e0bdfae814..cfaf925647 100644 --- a/nibabel/tests/test_viewers.py +++ b/nibabel/tests/test_viewers.py @@ -17,7 +17,7 @@ from numpy.testing.decorators import skipif from numpy.testing import assert_array_equal -from nose.tools import assert_raises +from nose.tools import assert_raises, assert_true matplotlib, has_mpl = optional_package('matplotlib')[:2] needs_mpl = skipif(not has_mpl, 'These tests need matplotlib') @@ -35,6 +35,7 @@ def test_viewer(): data = data * np.array([1., 2.]) # give it a # of volumes > 1 v = OrthoSlicer3D(data) assert_array_equal(v.position, (0, 0, 0)) + assert_true('OrthoSlicer3D' in repr(v)) # fake some events, inside and outside axes v._on_scroll(nt('event', 'button inaxes key')('up', None, None)) @@ -49,6 +50,7 @@ def test_viewer(): v.set_volume_idx(1) v.set_volume_idx(1) # should just pass v.close() + v._draw() # should be safe # non-multi-volume v = OrthoSlicer3D(data[:, :, :, 0]) diff --git a/nibabel/viewers.py b/nibabel/viewers.py index 055227c87b..735d1e8eef 100644 --- a/nibabel/viewers.py +++ b/nibabel/viewers.py @@ -32,7 +32,7 @@ class OrthoSlicer3D(object): """ # Skip doctest above b/c not all systems have mpl installed def __init__(self, data, affine=None, axes=None, cmap='gray', - pcnt_range=(1., 99.), figsize=(8, 8)): + pcnt_range=(1., 99.), figsize=(8, 8), title=None): """ Parameters ---------- @@ -60,6 +60,8 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', plt, _, _ = optional_package('matplotlib.pyplot') mpl_img, _, _ = optional_package('matplotlib.image') mpl_patch, _, _ = optional_package('matplotlib.patches') + self._title = title + self._closed = False data = np.asanyarray(data) if data.ndim < 3: @@ -107,6 +109,8 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', if self.n_volumes <= 1: fig.delaxes(self._axes[3]) self._axes.pop(-1) + if self._title is not None: + fig.canvas.set_window_title(str(title)) else: self._axes = [axes[0], axes[1], axes[2]] if len(axes) > 3: @@ -196,6 +200,14 @@ def __init__(self, data, affine=None, axes=None, cmap='gray', self._set_position(0., 0., 0.) self._draw() + def __repr__(self): + title = '' if self._title is None else ('%s ' % self._title) + vol = '' if self.n_volumes <= 1 else (', %s' % self.n_volumes) + r = ('<%s: %s(%s, %s, %s%s)>' + % (self.__class__.__name__, title, self._sizes[0], self._sizes[1], + self._sizes[2], vol)) + return r + # User-level functions ################################################### def show(self): """Show the slicer in blocking mode; convenience for ``plt.show()`` @@ -213,8 +225,9 @@ def close(self): def _cleanup(self): """Clean up before closing""" - for link in self._links: - link()._unlink(self) + self._closed = True + for link in list(self._links): # make a copy before iterating + self._unlink(link()) @property def n_volumes(self): @@ -423,6 +436,8 @@ def _on_keypress(self, event): def _draw(self): """Update all four (or three) plots""" + if self._closed: # make sure we don't draw when we shouldn't + return for ii in range(3): ax = self._axes[ii] ax.draw_artist(self._ims[ii])