Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 98 additions & 55 deletions deeptrack/optics.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,11 @@ def _pad_volume(
from __future__ import annotations

from pint import Quantity
from typing import Any
from typing import Any, TYPE_CHECKING
import warnings

import numpy as np
from numpy.typing import NDArray
from scipy.ndimage import convolve

from deeptrack.backend.units import (
Expand All @@ -158,6 +159,9 @@ def _pad_volume(
from deeptrack import image
from deeptrack import units_registry as u

if TYPE_CHECKING:
import torch


#TODO ***??*** revise Microscope - torch, typing, docstring, unit test
class Microscope(StructuralFeature):
Expand Down Expand Up @@ -207,59 +211,60 @@ class Microscope(StructuralFeature):

__distributed__ = False

_sample: Feature
_objective: Feature

def __init__(
self: Microscope,
sample: Feature,
objective: Feature,
**kwargs: Any,
):
"""Initialize the `Microscope` instance.
"""Initialize a microscope feature combining sample and optics.

This constructor attaches a sample feature (typically a combination of
scatterers) and an objective feature (optical system) to the
microscope.

Parameters
----------
sample: Feature
A feature-set resolving a list of images describing the sample to be
imaged.
Feature that resolves one or more scatterer volumes or fields
representing the sample to be imaged.
objective: Feature
A feature-set defining the optical device that images the sample.
Feature describing the optical system used to image the sample
(e.g., brightfield, fluorescence).
**kwargs: Any
Additional parameters passed to the base `StructuralFeature` class.

Attributes
----------
_sample: Feature
The feature-set defining the sample to be imaged.
_objective: Feature
The feature-set defining the optical system imaging the sample.
Additional keyword arguments passed to the base `StructuralFeature`
class.

"""

super().__init__(**kwargs)

self._sample = self.add_feature(sample)
self._objective = self.add_feature(objective)

#TODO: erase following line when rid of Image
self._sample.store_properties()

def get(
self: Microscope,
image: Image | None,
input: Any = None, # Ignored, kept for API compatibility
**kwargs: Any,
) -> Image:
) -> NDArray[Any] | torch.Tensor:
"""Generate an image of the sample using the defined optical system.

This method processes the sample through the optical system to
produce a simulated image.

Parameters
----------
image: Image | None
The input image to be processed. If None, a new image is created.
image: Any, optional
Ignored. Kept for API compatibility. Defaults to None.
**kwargs: Any
Additional parameters for the imaging process.

Returns
-------
Image: Image
array or tensor
The processed image after applying the optical system.

Examples
Expand All @@ -277,94 +282,131 @@ def get(

"""

# Grab properties from the objective to pass to the sample
additional_sample_kwargs = self._objective.properties()
# Grab objective properties to pass to sample
objective_properties = self._objective.properties()

# Calculate required output image for the given upscale
# This way of providing the upscale will be deprecated in the future
# in favor of dt.Upscale().
_upscale_given_by_optics = additional_sample_kwargs["upscale"]
#TODO: TBE
"""
# Calculate required output image for the given upscale.
# This upscale way will be deprecated in favor of dt.Upscale().
_upscale_given_by_optics = objective_properties["upscale"]
if np.array(_upscale_given_by_optics).size == 1:
_upscale_given_by_optics = (_upscale_given_by_optics,) * 3

with u.context(
create_context(
*additional_sample_kwargs["voxel_size"], *_upscale_given_by_optics
*objective_properties["voxel_size"], *_upscale_given_by_optics
)
):
"""

upscale = np.round(get_active_scale())

output_region = additional_sample_kwargs.pop("output_region")
additional_sample_kwargs["output_region"] = [
int(o * upsc)
for o, upsc in zip(
output_region, (upscale[0], upscale[1], upscale[0], upscale[1])
)
]
with u.context(create_context(*objective_properties["voxel_size"])):

padding = additional_sample_kwargs.pop("padding")
additional_sample_kwargs["padding"] = [
int(p * upsc)
for p, upsc in zip(
padding, (upscale[0], upscale[1], upscale[0], upscale[1])
)
]
upscale = np.round(get_active_scale())

def _scale_region_2d(
region: list[int],
upscale: tuple[float, float, float],
) -> list[int]:
"""Scale a 4-tuple region (x_min, y_min, x_max, y_max) or
padding using the lateral upscale factors (ux, uy)."""
ux, uy, _ = upscale
return [int(v * f) for v, f in zip(region, (ux, uy, ux, uy))]

# Scale output region from optics into sample voxel units.
output_region = objective_properties.pop("output_region")
objective_properties["output_region"] = _scale_region_2d(
output_region,
upscale,
)
self._objective.output_region.set_value(
additional_sample_kwargs["output_region"]
objective_properties["output_region"]
)
self._objective.padding.set_value(additional_sample_kwargs["padding"])

# Scale padding region in the same way (left, top, right, bottom).
padding = objective_properties.pop("padding")
objective_properties["padding"] = _scale_region_2d(
padding,
upscale,
)
self._objective.padding.set_value(objective_properties["padding"])

# Propagate all relevant properties from the objective to the
# sample graph. This ensures scatterers are evaluated in the
# same voxel size, output region, and padding as the optics.
# The extra flag `return_fft=True` is forced here because most
# objectives (e.g., Brightfield, Holography) operate in Fourier
# space, and they require scatterers to provide Fourier-domain
# data in addition to real-space volumes.
propagate_data_to_dependencies(
self._sample, **{"return_fft": True, **additional_sample_kwargs}
self._sample,
**{"return_fft": True, **objective_properties},
)

# Evaluate the sample feature to obtain scatterers.
# The result may be a single scatterer or a list of them.
list_of_scatterers = self._sample()

if not isinstance(list_of_scatterers, list):
list_of_scatterers = [list_of_scatterers]

# All scatterers that are defined as volumes.
volume_samples = [
# Volume scatterers occupy voxels in 3D (e.g. PointParticle).
volume_scatterers = [
scatterer
for scatterer in list_of_scatterers
if not scatterer.get_property("is_field", default=False)
]

# All scatterers that are defined as fields.
field_samples = [
# Field scatterers provide a complex field directly,
# bypassing volume merge.
field_scatterers = [
scatterer
for scatterer in list_of_scatterers
if scatterer.get_property("is_field", default=False)
]








# Merge all volumes into a single volume.
sample_volume, limits = _create_volume(
volume_samples,
**additional_sample_kwargs,
volume_scatterers,
**objective_properties,
)
sample_volume = Image(sample_volume)

# Merge all properties into the volume.
for scatterer in volume_samples + field_samples:
for scatterer in volume_scatterers + field_scatterers:
sample_volume.merge_properties_from(scatterer)

# Let the objective know about the limits of the volume and all the fields.
propagate_data_to_dependencies(
self._objective,
limits=limits,
fields=field_samples,
fields=field_scatterers,
)

imaged_sample = self._objective.resolve(sample_volume)

# Upscale given by the optics needs to be handled separately.
#TODO: TBE
"""
# Handling separately upscale given by optics.
# This upscale way will be deprecated in favor of dt.Upscale().
if _upscale_given_by_optics != (1, 1, 1):
imaged_sample = AveragePooling((*_upscale_given_by_optics[:2], 1))(
imaged_sample
)
"""

return imaged_sample

"""
#TODO: erase rest of the method
# Merge with input
if not image:
if not self._wrap_array_with_image and isinstance(imaged_sample, Image):
Expand All @@ -377,6 +419,7 @@ def get(
for i in range(len(image)):
image[i].merge_properties_from(imaged_sample)
return image
"""

# def _no_wrap_format_input(self, *args, **kwargs) -> list:
# return self._image_wrapped_format_input(*args, **kwargs)
Expand Down
16 changes: 15 additions & 1 deletion deeptrack/tests/test_scatterers.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ def test_Ellipse(self):
self.assertEqual(output_image.shape, (64, 64, 1))

def test_EllipseUpscale(self):
pass #TODO: adapt test with dt.Upscale()
"""
optics = Fluorescence(
NA=0.7,
wavelength=680e-9,
Expand Down Expand Up @@ -105,8 +107,11 @@ def test_EllipseUpscale(self):
imaged_scatterer.resolve()
scatterer_volume = scatterer()
self.assertEqual(scatterer_volume.shape, (39, 79, 1))
"""

def test_EllipseUpscaleAsymmetric(self):
pass #TODO: adapt test with dt.Upscale()
"""
optics = Fluorescence(
NA=0.7,
wavelength=680e-9,
Expand Down Expand Up @@ -144,6 +149,7 @@ def test_EllipseUpscaleAsymmetric(self):
imaged_scatterer.resolve()
scatterer_volume = scatterer()
self.assertEqual(scatterer_volume.shape, (19, 39, 1))
"""

def test_Sphere(self):
optics = Fluorescence(
Expand All @@ -166,7 +172,8 @@ def test_Sphere(self):
self.assertEqual(output_image.shape, (64, 64, 1))

def test_SphereUpscale(self):

pass #TODO: adapt test with dt.Upscale()
"""
optics = Fluorescence(
NA=0.7,
wavelength=680e-9,
Expand All @@ -185,6 +192,7 @@ def test_SphereUpscale(self):
imaged_scatterer.resolve()
scatterer_volume = scatterer()
self.assertEqual(scatterer_volume.shape, (40, 40, 40))
"""

def test_Ellipsoid(self):
optics = Fluorescence(
Expand All @@ -208,6 +216,8 @@ def test_Ellipsoid(self):
self.assertEqual(output_image.shape, (64, 64, 1))

def test_EllipsoidUpscale(self):
pass #TODO: adapt test with dt.Upscale()
"""
optics = Fluorescence(
NA=0.7,
wavelength=680e-9,
Expand All @@ -227,8 +237,11 @@ def test_EllipsoidUpscale(self):
imaged_scatterer.resolve()
scatterer_volume = scatterer()
self.assertEqual(scatterer_volume.shape, (19, 39, 9))
"""

def test_EllipsoidUpscaleAsymmetric(self):
pass #TODO: adapt test with dt.Upscale()
"""
optics = Fluorescence(
NA=0.7,
wavelength=680e-9,
Expand Down Expand Up @@ -288,6 +301,7 @@ def test_EllipsoidUpscaleAsymmetric(self):
imaged_scatterer.resolve()
scatterer_volume = scatterer()
self.assertEqual(scatterer_volume.shape, (19, 39, 19))
"""

def test_MieSphere(self):
optics_1 = Brightfield(
Expand Down
Loading