From c8e40cf626417b3e32f187df0b30a099bcf7ce40 Mon Sep 17 00:00:00 2001 From: mirjagranfors Date: Mon, 11 Aug 2025 16:56:04 +0200 Subject: [PATCH 1/7] starting to update pupil --- deeptrack/optics.py | 71 ++++++++++++++++++++++++++++++++------------- 1 file changed, 51 insertions(+), 20 deletions(-) diff --git a/deeptrack/optics.py b/deeptrack/optics.py index 5149bdae2..a1de442b1 100644 --- a/deeptrack/optics.py +++ b/deeptrack/optics.py @@ -140,9 +140,12 @@ def _pad_volume( from typing import Any import warnings +import array_api_compat as apc import numpy as np +from numpy.typing import NDArray from scipy.ndimage import convolve +from deeptrack.backend import config, TORCH_AVAILABLE, xp from deeptrack.backend.units import ( ConversionTable, create_context, @@ -158,6 +161,9 @@ def _pad_volume( from deeptrack import image from deeptrack import units_registry as u +if TORCH_AVAILABLE: + import torch + #TODO ***??*** revise Microscope - torch, typing, docstring, unit test class Microscope(StructuralFeature): @@ -694,11 +700,11 @@ def _pupil( NA: float, wavelength: float, refractive_index_medium: float, - include_aberration: bool = True, + include_aberration: bool = True, defocus: float | ArrayLike[float] = 0, **kwargs: Any, ): - """Calculates the pupil function at different focal points. + """Calculate the pupil function at different focal points. Parameters ---------- @@ -739,56 +745,81 @@ def _pupil( ... ) >>> print(pupil.shape) (1, 128, 128) - + """ # Calculates the pupil at each z-position in defocus. voxel_size = get_active_voxel_size() - shape = np.array(shape) - # Pupil radius - R = NA / wavelength * np.array(voxel_size)[:2] + if config.get_backend() == "numpy": + shape = np.array(shape) + + # Pupil radius + R = NA / wavelength * np.array(voxel_size)[:2] + + elif config.get_backend() == "torch": + shape = torch.tensor(shape) + + # Pupil radius + R = NA / wavelength * torch.tensor(voxel_size)[:2] + + else: + raise ValueError(f"Unsupported backend: {config.get_backend()}") + x_radius = R[0] * shape[0] y_radius = R[1] * shape[1] - x = (np.linspace(-(shape[0] / 2), shape[0] / 2 - 1, shape[0])) / x_radius + 1e-8 - y = (np.linspace(-(shape[1] / 2), shape[1] / 2 - 1, shape[1])) / y_radius + 1e-8 + x = ( + xp.linspace(-(shape[0] / 2), shape[0] / 2 - 1, shape[0]) + ) / x_radius + 1e-8 + y = ( + xp.linspace(-(shape[1] / 2), shape[1] / 2 - 1, shape[1]) + ) / y_radius + 1e-8 - W, H = np.meshgrid(y, x) - RHO = (W ** 2 + H ** 2).astype(complex) - pupil_function = Image((RHO < 1) + 0.0j, copy=False) + W, H = xp.meshgrid(y, x, indexing='xy') + + if config.get_backend() == "numpy": + RHO = (W ** 2 + H ** 2).astype(complex) + else: + RHO = (W ** 2 + H ** 2).to(dtype=torch.complex64) + RHO = RHO.numpy() # not to be kept, only to make it compatible with Image at the moment + + pupil_function = Image((RHO < 1) + 0.0j, copy=False) # what should we do about this? # Defocus z_shift = Image( 2 - * np.pi + * np.pi # should be xp.pi * refractive_index_medium / wavelength * voxel_size[2] - * np.sqrt(1 - (NA / refractive_index_medium) ** 2 * RHO), + * np.sqrt(1 - (NA / refractive_index_medium) ** 2 * RHO), # should be xp.sqrt copy=False, ) z_shift._value[z_shift._value.imag != 0] = 0 try: - z_shift = np.nan_to_num(z_shift, False, 0, 0, 0) + z_shift = np.nan_to_num(z_shift, False, 0, 0, 0) # should be xp.nan_to_num except TypeError: - np.nan_to_num(z_shift, z_shift) - - defocus = np.reshape(defocus, (-1, 1, 1)) - z_shift = defocus * np.expand_dims(z_shift, axis=0) + np.nan_to_num(z_shift, z_shift) # should be xp.nan_to_num + defocus = np.reshape(defocus, (-1, 1, 1)) # should be xp.reshape + # if config.get_backend() == "numpy": + z_shift = defocus * np.expand_dims(z_shift, axis=0) + # else: + # z_shift = defocus * torch.unsqueeze(z_shift, dim=0) + if include_aberration: pupil = self.pupil if isinstance(pupil, Feature): pupil_function = pupil(pupil_function) - elif isinstance(pupil, np.ndarray): + elif isinstance(pupil, np.ndarray) or torch.is_tensor(pupil): pupil_function *= pupil - pupil_functions = pupil_function * np.exp(1j * z_shift) + pupil_functions = pupil_function * np.exp(1j * z_shift) # should be xp.exp return pupil_functions From b5d9407061012f5eccfcb1903c6f388b84032676 Mon Sep 17 00:00:00 2001 From: mirjagranfors Date: Wed, 13 Aug 2025 11:48:23 +0200 Subject: [PATCH 2/7] update pupil --- deeptrack/optics.py | 45 +++++++++++++++++++++++---------------------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/deeptrack/optics.py b/deeptrack/optics.py index a1de442b1..f80668fad 100644 --- a/deeptrack/optics.py +++ b/deeptrack/optics.py @@ -751,20 +751,20 @@ def _pupil( # Calculates the pupil at each z-position in defocus. voxel_size = get_active_voxel_size() - if config.get_backend() == "numpy": + if self.get_backend() == "numpy": shape = np.array(shape) # Pupil radius R = NA / wavelength * np.array(voxel_size)[:2] - elif config.get_backend() == "torch": + elif self.get_backend() == "torch": shape = torch.tensor(shape) # Pupil radius R = NA / wavelength * torch.tensor(voxel_size)[:2] else: - raise ValueError(f"Unsupported backend: {config.get_backend()}") + raise ValueError(f"Unsupported backend: {self.get_backend()}") x_radius = R[0] * shape[0] @@ -778,37 +778,38 @@ def _pupil( ) / y_radius + 1e-8 W, H = xp.meshgrid(y, x, indexing='xy') - - if config.get_backend() == "numpy": + + if self.get_backend() == "numpy": RHO = (W ** 2 + H ** 2).astype(complex) + pupil_function = (RHO < 1) + 0.0j else: - RHO = (W ** 2 + H ** 2).to(dtype=torch.complex64) - RHO = RHO.numpy() # not to be kept, only to make it compatible with Image at the moment + RHO = (W ** 2 + H ** 2) + pupil_function = (RHO < 1).to(dtype=torch.complex64) + 0.0j + RHO = RHO.to(dtype=torch.complex64) - pupil_function = Image((RHO < 1) + 0.0j, copy=False) # what should we do about this? # Defocus - z_shift = Image( + z_shift = ( 2 - * np.pi # should be xp.pi + * xp.pi * refractive_index_medium / wavelength * voxel_size[2] - * np.sqrt(1 - (NA / refractive_index_medium) ** 2 * RHO), # should be xp.sqrt - copy=False, + * xp.sqrt(1 - (NA / refractive_index_medium) ** 2 * RHO) ) - z_shift._value[z_shift._value.imag != 0] = 0 + z_shift[z_shift.imag != 0] = 0 try: - z_shift = np.nan_to_num(z_shift, False, 0, 0, 0) # should be xp.nan_to_num + z_shift = xp.nan_to_num(z_shift, nan=0.0, posinf=None, neginf=None) except TypeError: - np.nan_to_num(z_shift, z_shift) # should be xp.nan_to_num - - defocus = np.reshape(defocus, (-1, 1, 1)) # should be xp.reshape - # if config.get_backend() == "numpy": - z_shift = defocus * np.expand_dims(z_shift, axis=0) - # else: - # z_shift = defocus * torch.unsqueeze(z_shift, dim=0) + xp.nan_to_num(z_shift, z_shift) + + if self.get_backend() == "numpy": + defocus = np.reshape(defocus, (-1, 1, 1)) + z_shift = defocus * np.expand_dims(z_shift, axis=0) + else: + defocus = torch.reshape(torch.as_tensor(defocus), (-1, 1, 1)) + z_shift = defocus * torch.unsqueeze(z_shift, dim=0) if include_aberration: pupil = self.pupil @@ -819,7 +820,7 @@ def _pupil( elif isinstance(pupil, np.ndarray) or torch.is_tensor(pupil): pupil_function *= pupil - pupil_functions = pupil_function * np.exp(1j * z_shift) # should be xp.exp + pupil_functions = pupil_function * xp.exp(1j * z_shift) return pupil_functions From 5a8fbe98a8642fae8c64f1ad60147edd31a85f1e Mon Sep 17 00:00:00 2001 From: mirjagranfors Date: Wed, 13 Aug 2025 15:24:29 +0200 Subject: [PATCH 3/7] update pupil --- deeptrack/optics.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/deeptrack/optics.py b/deeptrack/optics.py index f80668fad..51b012563 100644 --- a/deeptrack/optics.py +++ b/deeptrack/optics.py @@ -758,7 +758,7 @@ def _pupil( R = NA / wavelength * np.array(voxel_size)[:2] elif self.get_backend() == "torch": - shape = torch.tensor(shape) + shape = torch.tensor(shape, dtype=torch.float64) # Pupil radius R = NA / wavelength * torch.tensor(voxel_size)[:2] @@ -771,10 +771,10 @@ def _pupil( y_radius = R[1] * shape[1] x = ( - xp.linspace(-(shape[0] / 2), shape[0] / 2 - 1, shape[0]) + xp.linspace(-(shape[0] / 2), shape[0] / 2 - 1, int(shape[0])) ) / x_radius + 1e-8 y = ( - xp.linspace(-(shape[1] / 2), shape[1] / 2 - 1, shape[1]) + xp.linspace(-(shape[1] / 2), shape[1] / 2 - 1, int(shape[1])) ) / y_radius + 1e-8 W, H = xp.meshgrid(y, x, indexing='xy') @@ -784,8 +784,8 @@ def _pupil( pupil_function = (RHO < 1) + 0.0j else: RHO = (W ** 2 + H ** 2) - pupil_function = (RHO < 1).to(dtype=torch.complex64) + 0.0j - RHO = RHO.to(dtype=torch.complex64) + pupil_function = (RHO < 1).to(dtype=torch.complex128) + 0.0j + RHO = RHO.to(dtype=torch.complex128) # Defocus z_shift = ( From 40fb515789ccca12b7bb65ecfd5ff1f3b241fb5c Mon Sep 17 00:00:00 2001 From: mirjagranfors Date: Wed, 13 Aug 2025 15:32:14 +0200 Subject: [PATCH 4/7] update pupil --- deeptrack/optics.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/deeptrack/optics.py b/deeptrack/optics.py index 51b012563..ffaf0eaf4 100644 --- a/deeptrack/optics.py +++ b/deeptrack/optics.py @@ -766,7 +766,6 @@ def _pupil( else: raise ValueError(f"Unsupported backend: {self.get_backend()}") - x_radius = R[0] * shape[0] y_radius = R[1] * shape[1] @@ -777,13 +776,13 @@ def _pupil( xp.linspace(-(shape[1] / 2), shape[1] / 2 - 1, int(shape[1])) ) / y_radius + 1e-8 - W, H = xp.meshgrid(y, x, indexing='xy') + W, H = xp.meshgrid(y, x, indexing="xy") if self.get_backend() == "numpy": - RHO = (W ** 2 + H ** 2).astype(complex) + RHO = (W**2 + H**2).astype(complex) pupil_function = (RHO < 1) + 0.0j else: - RHO = (W ** 2 + H ** 2) + RHO = W**2 + H**2 pupil_function = (RHO < 1).to(dtype=torch.complex128) + 0.0j RHO = RHO.to(dtype=torch.complex128) From 8c6ee961e95f7553d52edb3ca029fb9d7f4124f6 Mon Sep 17 00:00:00 2001 From: mirjagranfors Date: Wed, 13 Aug 2025 16:08:34 +0200 Subject: [PATCH 5/7] update pupil --- deeptrack/optics.py | 34 ++++++++++++++++++++++------------ 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/deeptrack/optics.py b/deeptrack/optics.py index ffaf0eaf4..94395e2a3 100644 --- a/deeptrack/optics.py +++ b/deeptrack/optics.py @@ -702,8 +702,7 @@ def _pupil( refractive_index_medium: float, include_aberration: bool = True, defocus: float | ArrayLike[float] = 0, - **kwargs: Any, - ): + ) -> NDArray | torch.Tensor: """Calculate the pupil function at different focal points. Parameters @@ -716,26 +715,37 @@ def _pupil( The wavelength of the scattered light in meters. refractive_index_medium: float The refractive index of the medium. - voxel_size: array_like[float (, float, float)] - The distance between pixels in the camera. A third value can be - included to define the resolution in the z-direction. include_aberration: bool - If True, the aberration is included in the pupil function. + If True, the aberration is included in the pupil function. Default is `True`. defocus: float or list[float] The defocus of the system. If a list is given, the pupil is - calculated for each focal point. Defocus is given in meters. + calculated for each focal point. Defocus is given in meters. Default is `0`. Returns ------- pupil: array_like[complex] - The pupil function. Shape is (z, y, x). + The complex pupil function(s) at the specified defocus positions. + Shape is (z, y, x). Examples -------- - Calculating the pupil function: - >>> import deeptrack as dt + Calculating the pupil function: + >>> optics = dt.Optics() + >>> pupil = optics._pupil( + ... shape=(128, 128), + ... NA=0.8, + ... wavelength=0.55e-6, + ... refractive_index_medium=1.33, + ... ) + >>> print(pupil.shape, pupil.dtype) + (1, 128, 128) complex128 + + Calculating the pupil function with a PyTorch backend: + >>> from deeptrack.backend import config + >>> config.set_backend('torch') + >>> >>> optics = dt.Optics() >>> pupil = optics._pupil( ... shape=(128, 128), @@ -743,8 +753,8 @@ def _pupil( ... wavelength=0.55e-6, ... refractive_index_medium=1.33, ... ) - >>> print(pupil.shape) - (1, 128, 128) + >>> print(pupil.shape, pupil.dtype) + torch.Size([1, 128, 128]) torch.complex128 """ From f7b8712023f3c28e7cb84c3d26c71076f8864521 Mon Sep 17 00:00:00 2001 From: mirjagranfors Date: Wed, 13 Aug 2025 16:23:02 +0200 Subject: [PATCH 6/7] update pupil --- deeptrack/optics.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/deeptrack/optics.py b/deeptrack/optics.py index 94395e2a3..f48e5905d 100644 --- a/deeptrack/optics.py +++ b/deeptrack/optics.py @@ -702,8 +702,9 @@ def _pupil( refractive_index_medium: float, include_aberration: bool = True, defocus: float | ArrayLike[float] = 0, + **kwargs: Any, ) -> NDArray | torch.Tensor: - """Calculate the pupil function at different focal points. + """Calculate the complex pupil function at one or more focal points. Parameters ---------- @@ -716,10 +717,12 @@ def _pupil( refractive_index_medium: float The refractive index of the medium. include_aberration: bool - If True, the aberration is included in the pupil function. Default is `True`. + If True, the aberration is included in the pupil function. Default + is `True`. defocus: float or list[float] The defocus of the system. If a list is given, the pupil is - calculated for each focal point. Defocus is given in meters. Default is `0`. + calculated for each focal point. Defocus is given in meters. + Default is `0`. Returns ------- @@ -727,6 +730,11 @@ def _pupil( The complex pupil function(s) at the specified defocus positions. Shape is (z, y, x). + Notes + ----- + The backend (NumPy or PyTorch) is determined by `self.get_backend()` + and can be switched using the global backend configuration. + Examples -------- >>> import deeptrack as dt @@ -744,7 +752,7 @@ def _pupil( Calculating the pupil function with a PyTorch backend: >>> from deeptrack.backend import config - >>> config.set_backend('torch') + >>> config.set_backend("torch") >>> >>> optics = dt.Optics() >>> pupil = optics._pupil( @@ -758,7 +766,7 @@ def _pupil( """ - # Calculates the pupil at each z-position in defocus. + # Calculate the pupil at each z-position in defocus. voxel_size = get_active_voxel_size() if self.get_backend() == "numpy": From 421dbdbf481c308f6e0d420b22346c78e707724d Mon Sep 17 00:00:00 2001 From: mirjagranfors Date: Wed, 13 Aug 2025 16:32:17 +0200 Subject: [PATCH 7/7] update pupil --- deeptrack/optics.py | 36 +++++++++++++++++++++++++----------- 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/deeptrack/optics.py b/deeptrack/optics.py index f48e5905d..aab96c75a 100644 --- a/deeptrack/optics.py +++ b/deeptrack/optics.py @@ -798,23 +798,37 @@ def _pupil( if self.get_backend() == "numpy": RHO = (W**2 + H**2).astype(complex) - pupil_function = (RHO < 1) + 0.0j + pupil_function = Image((RHO < 1) + 0.0j, copy=False) + + # Defocus + z_shift = Image( + 2 + * np.pi + * refractive_index_medium + / wavelength + * voxel_size[2] + * np.sqrt(1 - (NA / refractive_index_medium) ** 2 * RHO), + copy=False, + ) + + z_shift._value[z_shift._value.imag != 0] = 0 + else: RHO = W**2 + H**2 pupil_function = (RHO < 1).to(dtype=torch.complex128) + 0.0j RHO = RHO.to(dtype=torch.complex128) - # Defocus - z_shift = ( - 2 - * xp.pi - * refractive_index_medium - / wavelength - * voxel_size[2] - * xp.sqrt(1 - (NA / refractive_index_medium) ** 2 * RHO) - ) + # Defocus + z_shift = ( + 2 + * xp.pi + * refractive_index_medium + / wavelength + * voxel_size[2] + * xp.sqrt(1 - (NA / refractive_index_medium) ** 2 * RHO) + ) - z_shift[z_shift.imag != 0] = 0 + z_shift[z_shift.imag != 0] = 0 try: z_shift = xp.nan_to_num(z_shift, nan=0.0, posinf=None, neginf=None)