Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 98 additions & 41 deletions deeptrack/optics.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,11 +713,11 @@ def _pupil(
NA: float,
wavelength: float,
refractive_index_medium: float,
include_aberration: bool = True,
include_aberration: bool = True,
defocus: float | ArrayLike[float] = 0,
**kwargs: Any,
):
"""Calculates the pupil function at different focal points.
) -> NDArray | torch.Tensor:
"""Calculate the complex pupil function at one or more focal points.

Parameters
----------
Expand All @@ -729,85 +729,142 @@ def _pupil(
The wavelength of the scattered light in meters.
refractive_index_medium: float
The refractive index of the medium.
voxel_size: array_like[float (, float, float)]
The distance between pixels in the camera. A third value can be
included to define the resolution in the z-direction.
include_aberration: bool
If True, the aberration is included in the pupil function.
If True, the aberration is included in the pupil function. Default
is `True`.
defocus: float or list[float]
The defocus of the system. If a list is given, the pupil is
calculated for each focal point. Defocus is given in meters.
Default is `0`.

Returns
-------
pupil: array_like[complex]
The pupil function. Shape is (z, y, x).
The complex pupil function(s) at the specified defocus positions.
Shape is (z, y, x).

Notes
-----
The backend (NumPy or PyTorch) is determined by `self.get_backend()`
and can be switched using the global backend configuration.

Examples
--------
Calculating the pupil function:

>>> import deeptrack as dt

Calculating the pupil function:
>>> optics = dt.Optics()
>>> pupil = optics._pupil(
... shape=(128, 128),
... NA=0.8,
... wavelength=0.55e-6,
... refractive_index_medium=1.33,
... )
>>> print(pupil.shape)
(1, 128, 128)

>>> print(pupil.shape, pupil.dtype)
(1, 128, 128) complex128

Calculating the pupil function with a PyTorch backend:
>>> from deeptrack.backend import config
>>> config.set_backend("torch")
>>>
>>> optics = dt.Optics()
>>> pupil = optics._pupil(
... shape=(128, 128),
... NA=0.8,
... wavelength=0.55e-6,
... refractive_index_medium=1.33,
... )
>>> print(pupil.shape, pupil.dtype)
torch.Size([1, 128, 128]) torch.complex128

"""

# Calculates the pupil at each z-position in defocus.
# Calculate the pupil at each z-position in defocus.
voxel_size = get_active_voxel_size()
shape = np.array(shape)

# Pupil radius
R = NA / wavelength * np.array(voxel_size)[:2]
if self.get_backend() == "numpy":
shape = np.array(shape)

# Pupil radius
R = NA / wavelength * np.array(voxel_size)[:2]

elif self.get_backend() == "torch":
shape = torch.tensor(shape, dtype=torch.float64)

# Pupil radius
R = NA / wavelength * torch.tensor(voxel_size)[:2]

else:
raise ValueError(f"Unsupported backend: {self.get_backend()}")

x_radius = R[0] * shape[0]
y_radius = R[1] * shape[1]

x = (np.linspace(-(shape[0] / 2), shape[0] / 2 - 1, shape[0])) / x_radius + 1e-8
y = (np.linspace(-(shape[1] / 2), shape[1] / 2 - 1, shape[1])) / y_radius + 1e-8

W, H = np.meshgrid(y, x)
RHO = (W ** 2 + H ** 2).astype(complex)
pupil_function = Image((RHO < 1) + 0.0j, copy=False)
# Defocus
z_shift = Image(
2
* np.pi
* refractive_index_medium
/ wavelength
* voxel_size[2]
* np.sqrt(1 - (NA / refractive_index_medium) ** 2 * RHO),
copy=False,
)
x = (
xp.linspace(-(shape[0] / 2), shape[0] / 2 - 1, int(shape[0]))
) / x_radius + 1e-8
y = (
xp.linspace(-(shape[1] / 2), shape[1] / 2 - 1, int(shape[1]))
) / y_radius + 1e-8

W, H = xp.meshgrid(y, x, indexing="xy")

z_shift._value[z_shift._value.imag != 0] = 0
if self.get_backend() == "numpy":
RHO = (W**2 + H**2).astype(complex)
pupil_function = Image((RHO < 1) + 0.0j, copy=False)

# Defocus
z_shift = Image(
2
* np.pi
* refractive_index_medium
/ wavelength
* voxel_size[2]
* np.sqrt(1 - (NA / refractive_index_medium) ** 2 * RHO),
copy=False,
)

z_shift._value[z_shift._value.imag != 0] = 0

else:
RHO = W**2 + H**2
pupil_function = (RHO < 1).to(dtype=torch.complex128) + 0.0j
RHO = RHO.to(dtype=torch.complex128)

# Defocus
z_shift = (
2
* xp.pi
* refractive_index_medium
/ wavelength
* voxel_size[2]
* xp.sqrt(1 - (NA / refractive_index_medium) ** 2 * RHO)
)

z_shift[z_shift.imag != 0] = 0

try:
z_shift = np.nan_to_num(z_shift, False, 0, 0, 0)
z_shift = xp.nan_to_num(z_shift, nan=0.0, posinf=None, neginf=None)
except TypeError:
np.nan_to_num(z_shift, z_shift)
xp.nan_to_num(z_shift, z_shift)

if self.get_backend() == "numpy":
defocus = np.reshape(defocus, (-1, 1, 1))
z_shift = defocus * np.expand_dims(z_shift, axis=0)
else:
defocus = torch.reshape(torch.as_tensor(defocus), (-1, 1, 1))
z_shift = defocus * torch.unsqueeze(z_shift, dim=0)

defocus = np.reshape(defocus, (-1, 1, 1))
z_shift = defocus * np.expand_dims(z_shift, axis=0)

if include_aberration:
pupil = self.pupil
if isinstance(pupil, Feature):

pupil_function = pupil(pupil_function)

elif isinstance(pupil, np.ndarray):
elif isinstance(pupil, np.ndarray) or torch.is_tensor(pupil):
pupil_function *= pupil

pupil_functions = pupil_function * np.exp(1j * z_shift)
pupil_functions = pupil_function * xp.exp(1j * z_shift)

return pupil_functions

Expand Down
Loading