diff --git a/deeptrack/optics.py b/deeptrack/optics.py index 8bb78a7f9..41847ba56 100644 --- a/deeptrack/optics.py +++ b/deeptrack/optics.py @@ -713,11 +713,11 @@ def _pupil( NA: float, wavelength: float, refractive_index_medium: float, - include_aberration: bool = True, + include_aberration: bool = True, defocus: float | ArrayLike[float] = 0, **kwargs: Any, - ): - """Calculates the pupil function at different focal points. + ) -> NDArray | torch.Tensor: + """Calculate the complex pupil function at one or more focal points. Parameters ---------- @@ -729,26 +729,30 @@ def _pupil( The wavelength of the scattered light in meters. refractive_index_medium: float The refractive index of the medium. - voxel_size: array_like[float (, float, float)] - The distance between pixels in the camera. A third value can be - included to define the resolution in the z-direction. include_aberration: bool - If True, the aberration is included in the pupil function. + If True, the aberration is included in the pupil function. Default + is `True`. defocus: float or list[float] The defocus of the system. If a list is given, the pupil is calculated for each focal point. Defocus is given in meters. + Default is `0`. Returns ------- pupil: array_like[complex] - The pupil function. Shape is (z, y, x). + The complex pupil function(s) at the specified defocus positions. + Shape is (z, y, x). + + Notes + ----- + The backend (NumPy or PyTorch) is determined by `self.get_backend()` + and can be switched using the global backend configuration. Examples -------- - Calculating the pupil function: - >>> import deeptrack as dt + Calculating the pupil function: >>> optics = dt.Optics() >>> pupil = optics._pupil( ... shape=(128, 128), @@ -756,58 +760,111 @@ def _pupil( ... wavelength=0.55e-6, ... refractive_index_medium=1.33, ... ) - >>> print(pupil.shape) - (1, 128, 128) - + >>> print(pupil.shape, pupil.dtype) + (1, 128, 128) complex128 + + Calculating the pupil function with a PyTorch backend: + >>> from deeptrack.backend import config + >>> config.set_backend("torch") + >>> + >>> optics = dt.Optics() + >>> pupil = optics._pupil( + ... shape=(128, 128), + ... NA=0.8, + ... wavelength=0.55e-6, + ... refractive_index_medium=1.33, + ... ) + >>> print(pupil.shape, pupil.dtype) + torch.Size([1, 128, 128]) torch.complex128 + """ - # Calculates the pupil at each z-position in defocus. + # Calculate the pupil at each z-position in defocus. voxel_size = get_active_voxel_size() - shape = np.array(shape) - # Pupil radius - R = NA / wavelength * np.array(voxel_size)[:2] + if self.get_backend() == "numpy": + shape = np.array(shape) + + # Pupil radius + R = NA / wavelength * np.array(voxel_size)[:2] + + elif self.get_backend() == "torch": + shape = torch.tensor(shape, dtype=torch.float64) + + # Pupil radius + R = NA / wavelength * torch.tensor(voxel_size)[:2] + + else: + raise ValueError(f"Unsupported backend: {self.get_backend()}") x_radius = R[0] * shape[0] y_radius = R[1] * shape[1] - x = (np.linspace(-(shape[0] / 2), shape[0] / 2 - 1, shape[0])) / x_radius + 1e-8 - y = (np.linspace(-(shape[1] / 2), shape[1] / 2 - 1, shape[1])) / y_radius + 1e-8 - - W, H = np.meshgrid(y, x) - RHO = (W ** 2 + H ** 2).astype(complex) - pupil_function = Image((RHO < 1) + 0.0j, copy=False) - # Defocus - z_shift = Image( - 2 - * np.pi - * refractive_index_medium - / wavelength - * voxel_size[2] - * np.sqrt(1 - (NA / refractive_index_medium) ** 2 * RHO), - copy=False, - ) + x = ( + xp.linspace(-(shape[0] / 2), shape[0] / 2 - 1, int(shape[0])) + ) / x_radius + 1e-8 + y = ( + xp.linspace(-(shape[1] / 2), shape[1] / 2 - 1, int(shape[1])) + ) / y_radius + 1e-8 + + W, H = xp.meshgrid(y, x, indexing="xy") - z_shift._value[z_shift._value.imag != 0] = 0 + if self.get_backend() == "numpy": + RHO = (W**2 + H**2).astype(complex) + pupil_function = Image((RHO < 1) + 0.0j, copy=False) + + # Defocus + z_shift = Image( + 2 + * np.pi + * refractive_index_medium + / wavelength + * voxel_size[2] + * np.sqrt(1 - (NA / refractive_index_medium) ** 2 * RHO), + copy=False, + ) + + z_shift._value[z_shift._value.imag != 0] = 0 + + else: + RHO = W**2 + H**2 + pupil_function = (RHO < 1).to(dtype=torch.complex128) + 0.0j + RHO = RHO.to(dtype=torch.complex128) + + # Defocus + z_shift = ( + 2 + * xp.pi + * refractive_index_medium + / wavelength + * voxel_size[2] + * xp.sqrt(1 - (NA / refractive_index_medium) ** 2 * RHO) + ) + + z_shift[z_shift.imag != 0] = 0 try: - z_shift = np.nan_to_num(z_shift, False, 0, 0, 0) + z_shift = xp.nan_to_num(z_shift, nan=0.0, posinf=None, neginf=None) except TypeError: - np.nan_to_num(z_shift, z_shift) + xp.nan_to_num(z_shift, z_shift) + + if self.get_backend() == "numpy": + defocus = np.reshape(defocus, (-1, 1, 1)) + z_shift = defocus * np.expand_dims(z_shift, axis=0) + else: + defocus = torch.reshape(torch.as_tensor(defocus), (-1, 1, 1)) + z_shift = defocus * torch.unsqueeze(z_shift, dim=0) - defocus = np.reshape(defocus, (-1, 1, 1)) - z_shift = defocus * np.expand_dims(z_shift, axis=0) - if include_aberration: pupil = self.pupil if isinstance(pupil, Feature): pupil_function = pupil(pupil_function) - elif isinstance(pupil, np.ndarray): + elif isinstance(pupil, np.ndarray) or torch.is_tensor(pupil): pupil_function *= pupil - pupil_functions = pupil_function * np.exp(1j * z_shift) + pupil_functions = pupil_function * xp.exp(1j * z_shift) return pupil_functions