diff --git a/src/imagej/_java.py b/src/imagej/_java.py index 1de17a29..6b4e8bf6 100644 --- a/src/imagej/_java.py +++ b/src/imagej/_java.py @@ -30,6 +30,10 @@ class MyJavaClasses(JavaClasses): significantly easier and more readable. """ + @JavaClasses.java_import + def Double(self): + return "java.lang.Double" + @JavaClasses.java_import def Throwable(self): return "java.lang.Throwable" @@ -50,6 +54,14 @@ def MetadataWrapper(self): def LabelingIOService(self): return "io.scif.labeling.LabelingIOService" + @JavaClasses.java_import + def DefaultLinearAxis(self): + return "net.imagej.axis.DefaultLinearAxis" + + @JavaClasses.java_import + def EnumeratedAxis(self): + return "net.imagej.axis.EnumeratedAxis" + @JavaClasses.java_import def Dataset(self): return "net.imagej.Dataset" diff --git a/src/imagej/dims.py b/src/imagej/dims.py index f009c3a0..03cf29e4 100644 --- a/src/imagej/dims.py +++ b/src/imagej/dims.py @@ -2,7 +2,7 @@ Utility functions for querying and manipulating dimensional axis metadata. """ import logging -from typing import List, Tuple +from typing import List, Tuple, Union import numpy as np import scyjava as sj @@ -177,49 +177,53 @@ def prioritize_rai_axes_order( return permute_order -def _assign_axes(xarr: xr.DataArray): +def _assign_axes( + xarr: xr.DataArray, +) -> List[Union["jc.DefaultLinearAxis", "jc.EnumeratedAxis"]]: """ - Obtain xarray axes names, origin, and scale and convert into ImageJ Axis; - currently supports EnumeratedAxis - :param xarr: xarray that holds the units - :return: A list of ImageJ Axis with the specified origin and scale + Obtain xarray axes names, origin, scale and convert into ImageJ Axis. Supports both + DefaultLinearAxis and the newer EnumeratedAxis. + + Note that, in many cases, there are small discrepancies between the coordinates. + This can either be actually within the data, or it can be from floating point math + errors. In this case, we delegate to numpy.isclose to tell us whether our + coordinates are linear or not. If our coordinates are nonlinear, and the + EnumeratedAxis type is available, we will use it. Otherwise, this function + returns a DefaultLinearAxis. + + :param xarr: xarray that holds the data. + :return: A list of ImageJ Axis with the specified origin and scale. """ - Double = sj.jimport("java.lang.Double") - - axes = [""] * len(xarr.dims) - - # try to get EnumeratedAxis, if not then default to LinearAxis in the loop - try: - EnumeratedAxis = _get_enumerated_axis() - except (JException, TypeError): - EnumeratedAxis = None - + axes = [""] * xarr.ndim for dim in xarr.dims: - axis_str = _convert_dim(dim, direction="java") + axis_str = _convert_dim(dim, "java") ax_type = jc.Axes.get(axis_str) ax_num = _get_axis_num(xarr, dim) - scale = _get_scale(xarr.coords[dim]) + coords_arr = xarr.coords[dim] - if scale is None: + # coerce numeric scale + if not _is_numeric_scale(coords_arr): _logger.warning( - f"The {ax_type.label} axis is non-numeric and is translated " + f"The {ax_type.getLabel()} axis is non-numeric and is translated " "to a linear index." ) - doub_coords = [ - Double(np.double(x)) for x in np.arange(len(xarr.coords[dim])) - ] + coords_arr = [np.double(x) for x in np.arange(len(xarr.coords[dim]))] else: - doub_coords = [Double(np.double(x)) for x in xarr.coords[dim]] - - # EnumeratedAxis is a new axis made for xarray, so is only present in - # ImageJ versions that are released later than March 2020. - # This actually returns a LinearAxis if using an earlier version. - if EnumeratedAxis is not None: - java_axis = EnumeratedAxis(ax_type, sj.to_java(doub_coords)) + coords_arr = coords_arr.to_numpy().astype(np.double) + + # check scale linearity + diffs = np.diff(coords_arr) + linear: bool = diffs.size and np.all(np.isclose(diffs, diffs[0])) + + if not linear: + try: + j_coords = [jc.Double(x) for x in coords_arr] + axes[ax_num] = jc.EnumeratedAxis(ax_type, sj.to_java(j_coords)) + except (JException, TypeError): + # if EnumeratedAxis not available - use DefaultLinearAxis + axes[ax_num] = _get_default_linear_axis(coords_arr, ax_type) else: - java_axis = _get_linear_axis(ax_type, sj.to_java(doub_coords)) - - axes[ax_num] = java_axis + axes[ax_num] = _get_default_linear_axis(coords_arr, ax_type) return axes @@ -274,48 +278,26 @@ def _get_axes_coords( return coords -def _get_scale(axis): +def _get_default_linear_axis(coords_arr: np.ndarray, ax_type: "jc.AxisType"): """ - Get the scale of an axis, assuming it is linear and so the scale is simply - second - first coordinate. + Create a new DefaultLinearAxis with the given coordinate array and axis type. - :param axis: A 1D list like entry accessible with indexing, which contains the - axis coordinates - :return: The scale for this axis or None if it is a non-numeric scale. + :param coords_arr: A 1D NumPy array. + :return: An instance of net.imagej.axis.DefaultLinearAxis. """ - try: - # HACK: This axis length check is a work around for singleton dimensions. - # You can't calculate the slope of a singleton dimension. - # This section will be removed when axis-scale-logic is merged. - if len(axis) <= 1: - return 1 - else: - return axis.values[1] - axis.values[0] - except TypeError: - return None - + scale = coords_arr[1] - coords_arr[0] if len(coords_arr) > 1 else 1 + origin = coords_arr[0] if len(coords_arr) > 0 else 0 + return jc.DefaultLinearAxis(ax_type, jc.Double(scale), jc.Double(origin)) -def _get_enumerated_axis(): - """Get EnumeratedAxis. - EnumeratedAxis is only in releases later than March 2020. If using - an older version of ImageJ without EnumeratedAxis, use - _get_linear_axis() instead. +def _is_numeric_scale(coords_array: np.ndarray) -> bool: """ - return sj.jimport("net.imagej.axis.EnumeratedAxis") - - -def _get_linear_axis(axis_type: "jc.AxisType", values): - """Get linear axis. + Checks if the coordinates array of the given axis is numeric. - This is used if no EnumeratedAxis is found. If EnumeratedAxis - is available, use _get_enumerated_axis() instead. + :param coords_array: A 1D NumPy array. + :return: bool """ - DefaultLinearAxis = sj.jimport("net.imagej.axis.DefaultLinearAxis") - origin = values[0] - scale = values[1] - values[0] - axis = DefaultLinearAxis(axis_type, scale, origin) - return axis + return np.issubdtype(coords_array.dtype, np.number) def _dataset_to_imgplus(rai: "jc.RandomAccessibleInterval") -> "jc.ImgPlus": diff --git a/tests/test_image_conversion.py b/tests/test_image_conversion.py index 977ce47f..34461308 100644 --- a/tests/test_image_conversion.py +++ b/tests/test_image_conversion.py @@ -1,4 +1,5 @@ import random +import string import numpy as np import pytest @@ -7,6 +8,7 @@ import imagej.dims as dims import imagej.images as images +from imagej._java import jc # -- Image helpers -- @@ -94,6 +96,75 @@ def get_xarr(option="C"): return xarr +def get_non_linear_coord_xarr(option="C"): + name: str = "non_linear_coord_data_array" + linear_coord_arr = np.arange(5) + # generate a 1D log scale array + non_linear_coord_arr = np.logspace(0, np.log10(100), num=30) + if option == "C": + xarr = xr.DataArray( + np.random.rand(30, 30, 5), + dims=["row", "col", "ch"], + coords={ + "row": non_linear_coord_arr, + "col": non_linear_coord_arr, + "ch": linear_coord_arr, + }, + attrs={"Hello": "World"}, + name=name, + ) + elif option == "F": + xarr = xr.DataArray( + np.ndarray([30, 30, 5], order="F"), + dims=["row", "col", "ch"], + coords={ + "row": non_linear_coord_arr, + "col": non_linear_coord_arr, + "ch": linear_coord_arr, + }, + attrs={"Hello": "World"}, + name=name, + ) + else: + xarr = xr.DataArray(np.random.rand(30, 30, 5), name=name) + + return xarr + + +def get_non_numeric_coord_xarr(option="C"): + name: str = "non_numeric_coord_data_array" + non_numeric_coord_list = [random.choice(string.ascii_letters) for _ in range(30)] + linear_coord_arr = np.arange(5) + if option == "C": + xarr = xr.DataArray( + np.random.rand(30, 30, 5), + dims=["row", "col", "ch"], + coords={ + "row": non_numeric_coord_list, + "col": non_numeric_coord_list, + "ch": linear_coord_arr, + }, + attrs={"Hello": "World"}, + name=name, + ) + elif option == "F": + xarr = xr.DataArray( + np.ndarray([30, 30, 5], order="F"), + dims=["row", "col", "ch"], + coords={ + "row": non_numeric_coord_list, + "col": non_numeric_coord_list, + "ch": linear_coord_arr, + }, + attrs={"Hello": "World"}, + name=name, + ) + else: + xarr = xr.DataArray(np.random.rand(30, 30, 5), name=name) + + return xarr + + # -- Helpers -- @@ -359,6 +430,34 @@ def test_no_coords_or_dims_in_xarr(ij_fixture): assert_inverted_xarr_equal_to_xarr(dataset, ij_fixture, xarr) +def test_linear_coord_on_xarr_conversion(ij_fixture): + xarr = get_xarr() + dataset = ij_fixture.py.to_java(xarr) + axes = dataset.dim_axes + # all axes should be DefaultLinearAxis + for ax in axes: + assert isinstance(ax, jc.DefaultLinearAxis) + + +def test_non_linear_coord_on_xarr_conversion(ij_fixture): + xarr = get_non_linear_coord_xarr() + dataset = ij_fixture.py.to_java(xarr) + axes = dataset.dim_axes + # axes [0, 1] should be EnumeratedAxis with axis 2 as DefaultLinearAxis + for i in range(2): + assert isinstance(axes[i], jc.EnumeratedAxis) + assert isinstance(axes[-1], jc.DefaultLinearAxis) + + +def test_non_numeric_coord_on_xarr_conversion(ij_fixture): + xarr = get_non_numeric_coord_xarr() + dataset = ij_fixture.py.to_java(xarr) + axes = dataset.dim_axes + # all axes should be DefaultLinearAxis + for ax in axes: + assert isinstance(ax, jc.DefaultLinearAxis) + + dataset_conversion_parameters = [ ( get_img,