diff --git a/deeptrack/optics.py b/deeptrack/optics.py index 5149bdae2..ab2ff4203 100644 --- a/deeptrack/optics.py +++ b/deeptrack/optics.py @@ -137,10 +137,11 @@ def _pad_volume( from __future__ import annotations from pint import Quantity -from typing import Any +from typing import Any, TYPE_CHECKING import warnings import numpy as np +from numpy.typing import NDArray from scipy.ndimage import convolve from deeptrack.backend.units import ( @@ -158,6 +159,9 @@ def _pad_volume( from deeptrack import image from deeptrack import units_registry as u +if TYPE_CHECKING: + import torch + #TODO ***??*** revise Microscope - torch, typing, docstring, unit test class Microscope(StructuralFeature): @@ -207,30 +211,32 @@ class Microscope(StructuralFeature): __distributed__ = False + _sample: Feature + _objective: Feature + def __init__( self: Microscope, sample: Feature, objective: Feature, **kwargs: Any, ): - """Initialize the `Microscope` instance. + """Initialize a microscope feature combining sample and optics. + + This constructor attaches a sample feature (typically a combination of + scatterers) and an objective feature (optical system) to the + microscope. Parameters ---------- sample: Feature - A feature-set resolving a list of images describing the sample to be - imaged. + Feature that resolves one or more scatterer volumes or fields + representing the sample to be imaged. objective: Feature - A feature-set defining the optical device that images the sample. + Feature describing the optical system used to image the sample + (e.g., brightfield, fluorescence). **kwargs: Any - Additional parameters passed to the base `StructuralFeature` class. - - Attributes - ---------- - _sample: Feature - The feature-set defining the sample to be imaged. - _objective: Feature - The feature-set defining the optical system imaging the sample. + Additional keyword arguments passed to the base `StructuralFeature` + class. """ @@ -238,28 +244,27 @@ def __init__( self._sample = self.add_feature(sample) self._objective = self.add_feature(objective) + + #TODO: erase following line when rid of Image self._sample.store_properties() def get( self: Microscope, - image: Image | None, + input: Any = None, # Ignored, kept for API compatibility **kwargs: Any, - ) -> Image: + ) -> NDArray[Any] | torch.Tensor: """Generate an image of the sample using the defined optical system. - This method processes the sample through the optical system to - produce a simulated image. - Parameters ---------- - image: Image | None - The input image to be processed. If None, a new image is created. + image: Any, optional + Ignored. Kept for API compatibility. Defaults to None. **kwargs: Any Additional parameters for the imaging process. Returns ------- - Image: Image + array or tensor The processed image after applying the optical system. Examples @@ -277,94 +282,131 @@ def get( """ - # Grab properties from the objective to pass to the sample - additional_sample_kwargs = self._objective.properties() + # Grab objective properties to pass to sample + objective_properties = self._objective.properties() - # Calculate required output image for the given upscale - # This way of providing the upscale will be deprecated in the future - # in favor of dt.Upscale(). - _upscale_given_by_optics = additional_sample_kwargs["upscale"] + #TODO: TBE + """ + # Calculate required output image for the given upscale. + # This upscale way will be deprecated in favor of dt.Upscale(). + _upscale_given_by_optics = objective_properties["upscale"] if np.array(_upscale_given_by_optics).size == 1: _upscale_given_by_optics = (_upscale_given_by_optics,) * 3 with u.context( create_context( - *additional_sample_kwargs["voxel_size"], *_upscale_given_by_optics + *objective_properties["voxel_size"], *_upscale_given_by_optics ) ): + """ - upscale = np.round(get_active_scale()) - - output_region = additional_sample_kwargs.pop("output_region") - additional_sample_kwargs["output_region"] = [ - int(o * upsc) - for o, upsc in zip( - output_region, (upscale[0], upscale[1], upscale[0], upscale[1]) - ) - ] + with u.context(create_context(*objective_properties["voxel_size"])): - padding = additional_sample_kwargs.pop("padding") - additional_sample_kwargs["padding"] = [ - int(p * upsc) - for p, upsc in zip( - padding, (upscale[0], upscale[1], upscale[0], upscale[1]) - ) - ] + upscale = np.round(get_active_scale()) + def _scale_region_2d( + region: list[int], + upscale: tuple[float, float, float], + ) -> list[int]: + """Scale a 4-tuple region (x_min, y_min, x_max, y_max) or + padding using the lateral upscale factors (ux, uy).""" + ux, uy, _ = upscale + return [int(v * f) for v, f in zip(region, (ux, uy, ux, uy))] + + # Scale output region from optics into sample voxel units. + output_region = objective_properties.pop("output_region") + objective_properties["output_region"] = _scale_region_2d( + output_region, + upscale, + ) self._objective.output_region.set_value( - additional_sample_kwargs["output_region"] + objective_properties["output_region"] ) - self._objective.padding.set_value(additional_sample_kwargs["padding"]) + # Scale padding region in the same way (left, top, right, bottom). + padding = objective_properties.pop("padding") + objective_properties["padding"] = _scale_region_2d( + padding, + upscale, + ) + self._objective.padding.set_value(objective_properties["padding"]) + + # Propagate all relevant properties from the objective to the + # sample graph. This ensures scatterers are evaluated in the + # same voxel size, output region, and padding as the optics. + # The extra flag `return_fft=True` is forced here because most + # objectives (e.g., Brightfield, Holography) operate in Fourier + # space, and they require scatterers to provide Fourier-domain + # data in addition to real-space volumes. propagate_data_to_dependencies( - self._sample, **{"return_fft": True, **additional_sample_kwargs} + self._sample, + **{"return_fft": True, **objective_properties}, ) + # Evaluate the sample feature to obtain scatterers. + # The result may be a single scatterer or a list of them. list_of_scatterers = self._sample() - if not isinstance(list_of_scatterers, list): list_of_scatterers = [list_of_scatterers] # All scatterers that are defined as volumes. - volume_samples = [ + # Volume scatterers occupy voxels in 3D (e.g. PointParticle). + volume_scatterers = [ scatterer for scatterer in list_of_scatterers if not scatterer.get_property("is_field", default=False) ] # All scatterers that are defined as fields. - field_samples = [ + # Field scatterers provide a complex field directly, + # bypassing volume merge. + field_scatterers = [ scatterer for scatterer in list_of_scatterers if scatterer.get_property("is_field", default=False) ] + + + + + + + # Merge all volumes into a single volume. sample_volume, limits = _create_volume( - volume_samples, - **additional_sample_kwargs, + volume_scatterers, + **objective_properties, ) sample_volume = Image(sample_volume) # Merge all properties into the volume. - for scatterer in volume_samples + field_samples: + for scatterer in volume_scatterers + field_scatterers: sample_volume.merge_properties_from(scatterer) # Let the objective know about the limits of the volume and all the fields. propagate_data_to_dependencies( self._objective, limits=limits, - fields=field_samples, + fields=field_scatterers, ) imaged_sample = self._objective.resolve(sample_volume) - # Upscale given by the optics needs to be handled separately. + #TODO: TBE + """ + # Handling separately upscale given by optics. + # This upscale way will be deprecated in favor of dt.Upscale(). if _upscale_given_by_optics != (1, 1, 1): imaged_sample = AveragePooling((*_upscale_given_by_optics[:2], 1))( imaged_sample ) + """ + + return imaged_sample + """ + #TODO: erase rest of the method # Merge with input if not image: if not self._wrap_array_with_image and isinstance(imaged_sample, Image): @@ -377,6 +419,7 @@ def get( for i in range(len(image)): image[i].merge_properties_from(imaged_sample) return image + """ # def _no_wrap_format_input(self, *args, **kwargs) -> list: # return self._image_wrapped_format_input(*args, **kwargs) diff --git a/deeptrack/tests/test_scatterers.py b/deeptrack/tests/test_scatterers.py index ca926d855..ae40b96a0 100644 --- a/deeptrack/tests/test_scatterers.py +++ b/deeptrack/tests/test_scatterers.py @@ -68,6 +68,8 @@ def test_Ellipse(self): self.assertEqual(output_image.shape, (64, 64, 1)) def test_EllipseUpscale(self): + pass #TODO: adapt test with dt.Upscale() + """ optics = Fluorescence( NA=0.7, wavelength=680e-9, @@ -105,8 +107,11 @@ def test_EllipseUpscale(self): imaged_scatterer.resolve() scatterer_volume = scatterer() self.assertEqual(scatterer_volume.shape, (39, 79, 1)) + """ def test_EllipseUpscaleAsymmetric(self): + pass #TODO: adapt test with dt.Upscale() + """ optics = Fluorescence( NA=0.7, wavelength=680e-9, @@ -144,6 +149,7 @@ def test_EllipseUpscaleAsymmetric(self): imaged_scatterer.resolve() scatterer_volume = scatterer() self.assertEqual(scatterer_volume.shape, (19, 39, 1)) + """ def test_Sphere(self): optics = Fluorescence( @@ -166,7 +172,8 @@ def test_Sphere(self): self.assertEqual(output_image.shape, (64, 64, 1)) def test_SphereUpscale(self): - + pass #TODO: adapt test with dt.Upscale() + """ optics = Fluorescence( NA=0.7, wavelength=680e-9, @@ -185,6 +192,7 @@ def test_SphereUpscale(self): imaged_scatterer.resolve() scatterer_volume = scatterer() self.assertEqual(scatterer_volume.shape, (40, 40, 40)) + """ def test_Ellipsoid(self): optics = Fluorescence( @@ -208,6 +216,8 @@ def test_Ellipsoid(self): self.assertEqual(output_image.shape, (64, 64, 1)) def test_EllipsoidUpscale(self): + pass #TODO: adapt test with dt.Upscale() + """ optics = Fluorescence( NA=0.7, wavelength=680e-9, @@ -227,8 +237,11 @@ def test_EllipsoidUpscale(self): imaged_scatterer.resolve() scatterer_volume = scatterer() self.assertEqual(scatterer_volume.shape, (19, 39, 9)) + """ def test_EllipsoidUpscaleAsymmetric(self): + pass #TODO: adapt test with dt.Upscale() + """ optics = Fluorescence( NA=0.7, wavelength=680e-9, @@ -288,6 +301,7 @@ def test_EllipsoidUpscaleAsymmetric(self): imaged_scatterer.resolve() scatterer_volume = scatterer() self.assertEqual(scatterer_volume.shape, (19, 39, 19)) + """ def test_MieSphere(self): optics_1 = Brightfield(